diff --git a/.github/workflows/typing-checks.yml b/.github/workflows/typing-checks.yml
index 18dc8f8011b9..26410d0501ba 100644
--- a/.github/workflows/typing-checks.yml
+++ b/.github/workflows/typing-checks.yml
@@ -11,7 +11,7 @@ on:
       - "tests/ignite/**"
       - "requirements-dev.txt"
       - "pyproject.toml"
-      - "mypy.ini"
+      - "pyrefly.toml"
       - ".github/workflows/typing-checks.yml"
   pull_request:
     paths:
@@ -20,7 +20,7 @@ on:
       - "tests/ignite/**"
       - "requirements-dev.txt"
       - "pyproject.toml"
-      - "mypy.ini"
+      - "pyrefly.toml"
       - ".github/workflows/typing-checks.yml"
   workflow_dispatch:

@@ -30,7 +30,7 @@ concurrency:
   cancel-in-progress: true

 jobs:
-  mypy:
+  pyrefly:
     runs-on: ubuntu-latest
     strategy:
       matrix:
@@ -60,14 +60,7 @@ jobs:
         run: uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu

       - name: Install dependencies
-        run: |
-          uv pip install -r requirements-dev.txt
-          uv pip install .
-          uv pip install mypy
-          uv pip install pyrefly
-
-      - name: Run MyPy type checking
-        run: mypy
+        run: uv pip install -r requirements-dev.txt

       - name: Run Pyrefly type checking
         run: pyrefly check
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index c8f58e707de6..b2068bc4990f 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -40,7 +40,7 @@ into the following categories:
   - [Formatting the code with a pre-commit hook](#formatting-the-code-with-a-pre-commit-hook)
   - [Run tests](#run-tests)
   - [Run distributed tests only on CPU](#run-distributed-tests-only-on-cpu)
-  - [Run Mypy checks](#run-mypy-checks)
+  - [Run Pyrefly checks](#run-pyrefly-checks)
   - [Send a PR](#send-a-pr)
   - [Sync up with the upstream](#sync-up-with-the-upstream)
 - [Writing documentation](#writing-documentation)
@@ -214,15 +214,15 @@ export WORLD_SIZE=2
 CUDA_VISIBLE_DEVICES="" pytest --dist=each --tx $WORLD_SIZE*popen//python=python tests/ -m distributed -vvv
 ```

-#### Run Mypy checks:
+#### Run Pyrefly checks:

-To run mypy to check the optional static type:
+To run pyrefly for optional static type checking:

 ```bash
-mypy
+pyrefly check
 ```

-To change any config for specif folder, please see the file mypy.ini
+To change the config for a specific folder, please see the pyrefly.toml file.

 #### Send a PR
diff --git a/ignite/contrib/engines/tbptt.py b/ignite/contrib/engines/tbptt.py
index d4b7fbf7f867..6860f587df7b 100644
--- a/ignite/contrib/engines/tbptt.py
+++ b/ignite/contrib/engines/tbptt.py
@@ -117,6 +117,5 @@ def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> float:
         return sum(loss_list) / len(loss_list)

     engine = Engine(_update)
-    # pyrefly: ignore [bad-argument-type]
     engine.register_events(*Tbptt_Events)
     return engine
diff --git a/ignite/distributed/comp_models/horovod.py b/ignite/distributed/comp_models/horovod.py
index 844196ba6693..cf9a76ebb92b 100644
--- a/ignite/distributed/comp_models/horovod.py
+++ b/ignite/distributed/comp_models/horovod.py
@@ -1,6 +1,6 @@
 import os
 import warnings
-from typing import Any, Callable, cast, List, Mapping, Optional, Tuple
+from typing import Any, Callable, cast, List, Mapping, Optional, Tuple, TYPE_CHECKING

 import torch

@@ -20,6 +20,11 @@ except ImportError:
     has_hvd_support = False
+    if TYPE_CHECKING:
+        # Tell the type checker that hvd imports are always defined.
+        import horovod.torch as hvd
+        from horovod import run as hvd_mp_spawn
+


 if has_hvd_support:

     HOROVOD = "horovod"
@@ -171,11 +176,8 @@ def _setup_group(self, group: Any) -> hvd.ProcessSet:
         return group

     _reduce_op_map = {
-        # pyrefly: ignore [unbound-name]
         "SUM": hvd.mpi_ops.Sum,
-        # pyrefly: ignore [unbound-name]
         "AVERAGE": hvd.mpi_ops.Average,
-        # pyrefly: ignore [unbound-name]
         "ADASUM": hvd.mpi_ops.Adasum,
     }

diff --git a/ignite/distributed/launcher.py b/ignite/distributed/launcher.py
index 4c641a73623b..324533c618c5 100644
--- a/ignite/distributed/launcher.py
+++ b/ignite/distributed/launcher.py
@@ -322,19 +322,15 @@ def __enter__(self) -> "Parallel":
             idist.initialize(self.backend, init_method=self.init_method)

         # The logger can be setup from now since idist.initialize() has been called (if needed)
-        self._logger = setup_logger(__name__ + "." + self.__class__.__name__)  # type: ignore[assignment]
+        self._logger = setup_logger(__name__ + "." + self.__class__.__name__)
         if self.backend is not None:
             if self._spawn_params is None:
-                self._logger.info(  # type: ignore[attr-defined]
-                    f"Initialized processing group with backend: '{self.backend}'"
-                )
+                self._logger.info(f"Initialized processing group with backend: '{self.backend}'")
             else:
-                self._logger.info(  # type: ignore[attr-defined]
-                    f"Initialized distributed launcher with backend: '{self.backend}'"
-                )
+                self._logger.info(f"Initialized distributed launcher with backend: '{self.backend}'")
                 msg = "\n\t".join([f"{k}: {v}" for k, v in self._spawn_params.items() if v is not None])
-                self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")  # type: ignore[attr-defined]
+                self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")

         return self
diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py
index 7f0fd9824904..052a8bbe6b1b 100644
--- a/ignite/engine/engine.py
+++ b/ignite/engine/engine.py
@@ -148,12 +148,11 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
         self.should_interrupt = False
         self.state = State()
         self._state_dict_user_keys: List[str] = []
-        self._allowed_events: List[EventEnum] = []
+        self._allowed_events: List[Union[str, EventEnum]] = []

         self._dataloader_iter: Optional[Iterator[Any]] = None
         self._init_iter: Optional[int] = None

-        # pyrefly: ignore [bad-argument-type]
         self.register_events(*Events)

         if self._process_function is None:
@@ -164,9 +163,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
         # generator provided by self._internal_run_as_gen
         self._internal_run_generator: Optional[Generator[Any, None, State]] = None

-    def register_events(
-        self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
-    ) -> None:
+    def register_events(self, *event_names: Union[str, EventEnum], event_to_attr: Optional[dict] = None) -> None:
         """Add events that can be fired.

         Registering an event will let the user trigger these events at any point.
@@ -451,7 +448,7 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
                 first, others = ((resolved_engine,), args[1:])
             else:
                 # metrics do not provide engine when registered
-                first, others = (tuple(), args)  # type: ignore[assignment]
+                first, others = (tuple(), args)

             func(*first, *(event_args + others), **kwargs)

@@ -990,9 +987,9 @@ def _internal_run(self) -> State:

     def _internal_run_as_gen(self) -> Generator[Any, None, State]:
         self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
         self._init_timers(self.state)
+        start_time = time.time()
         try:
             try:
-                start_time = time.time()
                 self._fire_event(Events.STARTED)
                 yield from self._maybe_terminate_or_interrupt()

@@ -1011,7 +1008,7 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
                 # time is available for handlers but must be updated after fire
                 self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

-                if self.should_terminate_single_epoch != "skip_epoch_completed":  # type: ignore[comparison-overlap]
+                if self.should_terminate_single_epoch != "skip_epoch_completed":
                     handlers_start_time = time.time()
                     self._fire_event(Events.EPOCH_COMPLETED)
                     epoch_time_taken += time.time() - handlers_start_time
@@ -1039,13 +1036,12 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
                     "https://github.com/pytorch/ignite/issues/new/choose"
                 )

-            # pyrefly: ignore [unbound-name]
             time_taken = time.time() - start_time
             # time is available for handlers but must be updated after fire
             self.state.times[Events.COMPLETED.name] = time_taken

             # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
-            if self.should_terminate != "skip_completed":  # type: ignore[comparison-overlap]
+            if self.should_terminate != "skip_completed":
                 handlers_start_time = time.time()
                 self._fire_event(Events.COMPLETED)
                 time_taken += time.time() - handlers_start_time
@@ -1191,9 +1187,9 @@ def _internal_run_legacy(self) -> State:
         # internal_run without generator for BC
         self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
         self._init_timers(self.state)
+        start_time = time.time()
         try:
             try:
-                start_time = time.time()
                 self._fire_event(Events.STARTED)
                 self._maybe_terminate_legacy()

@@ -1212,7 +1208,7 @@ def _internal_run_legacy(self) -> State:
                 # time is available for handlers but must be updated after fire
                 self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

-                if self.should_terminate_single_epoch != "skip_epoch_completed":  # type: ignore[comparison-overlap]
+                if self.should_terminate_single_epoch != "skip_epoch_completed":
                     handlers_start_time = time.time()
                     self._fire_event(Events.EPOCH_COMPLETED)
                     epoch_time_taken += time.time() - handlers_start_time
@@ -1240,13 +1236,12 @@ def _internal_run_legacy(self) -> State:
                     "https://github.com/pytorch/ignite/issues/new/choose"
                 )

-            # pyrefly: ignore [unbound-name]
             time_taken = time.time() - start_time
             # time is available for handlers but must be updated after fire
             self.state.times[Events.COMPLETED.name] = time_taken

             # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
-            if self.should_terminate != "skip_completed":  # type: ignore[comparison-overlap]
+            if self.should_terminate != "skip_completed":
                 handlers_start_time = time.time()
                 self._fire_event(Events.COMPLETED)
                 time_taken += time.time() - handlers_start_time
diff --git a/ignite/handlers/checkpoint.py b/ignite/handlers/checkpoint.py
index bfd658fba454..c133d4ba2c13 100644
--- a/ignite/handlers/checkpoint.py
+++ b/ignite/handlers/checkpoint.py
@@ -315,7 +315,7 @@ def on_checkpoint_saved(engine):
     """

     SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
-    Item = NamedTuple("Item", [("priority", int), ("filename", str)])
+    Item = NamedTuple("Item", [("priority", Union[int, float]), ("filename", str)])
     _state_dict_all_req_keys = ("_saved",)

     def __init__(
@@ -323,7 +323,7 @@ def __init__(
         to_save: Mapping,
         save_handler: Union[str, Path, Callable, BaseSaveHandler],
         filename_prefix: str = "",
-        score_function: Optional[Callable] = None,
+        score_function: Optional[Callable[[Engine], Union[int, float]]] = None,
         score_name: Optional[str] = None,
         n_saved: Union[int, None] = 1,
         global_step_transform: Optional[Callable] = None,
@@ -440,7 +440,6 @@ def _compare_fn(self, new: Union[int, float]) -> bool:

     def __call__(self, engine: Engine) -> None:
         if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT):
-            # pyrefly: ignore [bad-argument-type]
             engine.register_events(*CheckpointEvents)
         global_step = None
         if self.global_step_transform is not None:
@@ -455,7 +454,6 @@ def __call__(self, engine: Engine) -> None:
                 global_step = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)
             priority = global_step

-        # pyrefly: ignore [bad-argument-type]
         if self._check_lt_n_saved() or self._compare_fn(priority):
             priority_str = f"{priority}" if isinstance(priority, numbers.Integral) else f"{priority:.4f}"

@@ -497,7 +495,6 @@ def __call__(self, engine: Engine) -> None:
                 if isinstance(self.save_handler, BaseSaveHandler):
                     self.save_handler.remove(item.filename)

-            # pyrefly: ignore [bad-argument-type]
            self._saved.append(Checkpoint.Item(priority, filename))
            self._saved.sort(key=lambda it: it[0])

diff --git a/ignite/handlers/clearml_logger.py b/ignite/handlers/clearml_logger.py
index 00af4531aa10..ff394afe1713 100644
--- a/ignite/handlers/clearml_logger.py
+++ b/ignite/handlers/clearml_logger.py
@@ -862,7 +862,7 @@ def _setup_check_clearml(self, logger: ClearMLLogger, output_uri: str) -> None:
         except ImportError:
             try:
                 # Backwards-compatibility for legacy Trains SDK
-                from trains import Task  # type: ignore[no-redef]
+                from trains import Task
             except ImportError:
                 raise ModuleNotFoundError(
                     "This contrib module requires clearml to be installed. "
@@ -937,7 +937,7 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mappin
         except ImportError:
             try:
                 # Backwards-compatibility for legacy Trains SDK
-                from trains.binding.frameworks import WeightsFileHandler  # type: ignore[no-redef]
+                from trains.binding.frameworks import WeightsFileHandler
             except ImportError:
                 raise ModuleNotFoundError(
                     "This contrib module requires clearml to be installed. "
diff --git a/ignite/handlers/lr_finder.py b/ignite/handlers/lr_finder.py
index 5d9ba9cba358..10b3dd3dc67d 100644
--- a/ignite/handlers/lr_finder.py
+++ b/ignite/handlers/lr_finder.py
@@ -98,16 +98,18 @@ def _run(
         self._best_loss = None
         self._diverge_flag = False

+        assert trainer.state.epoch_length is not None
+        assert trainer.state.max_epochs is not None
+
         # attach LRScheduler to trainer.
         if num_iter is None:
-            # pyrefly: ignore [unsupported-operation]
             num_iter = trainer.state.epoch_length * trainer.state.max_epochs
         else:
-            max_iter = trainer.state.epoch_length * trainer.state.max_epochs  # type: ignore[operator]
+            max_iter = trainer.state.epoch_length * trainer.state.max_epochs
             if max_iter < num_iter:
                 max_iter = num_iter
                 trainer.state.max_iters = num_iter
-                trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)  # type: ignore[operator]
+                trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)

         if not trainer.has_event_handler(self._reached_num_iterations):
             trainer.add_event_handler(Events.ITERATION_COMPLETED, self._reached_num_iterations, num_iter)
@@ -179,18 +181,14 @@ def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f
             loss = idist.all_reduce(loss)
         lr = self._lr_schedule.get_param()
         self._history["lr"].append(lr)
-        if trainer.state.iteration == 1:
-            self._best_loss = loss  # type: ignore[assignment]
-        else:
-            if smooth_f > 0:
-                loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
-            # pyrefly: ignore [unsupported-operation]
-            if loss < self._best_loss:
-                self._best_loss = loss
+        if trainer.state.iteration != 1 and smooth_f > 0:
+            loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
+        if self._best_loss is None or loss < self._best_loss:
+            self._best_loss = loss
         self._history["loss"].append(loss)

         # Check if the loss has diverged; if it has, stop the trainer
-        if self._history["loss"][-1] > diverge_th * self._best_loss:  # type: ignore[operator]
+        if self._history["loss"][-1] > diverge_th * self._best_loss:
             self._diverge_flag = True
             self.logger.info("Stopping early, the loss has diverged")
             trainer.terminate()
diff --git a/ignite/handlers/param_scheduler.py b/ignite/handlers/param_scheduler.py
index fccbc1c673ff..43bee495a9e4 100644
--- a/ignite/handlers/param_scheduler.py
+++ b/ignite/handlers/param_scheduler.py
@@ -1122,13 +1122,14 @@ def print_lr():
             f"but given {type(lr_scheduler)}"
         )

-    if not isinstance(warmup_duration, numbers.Integral):
+    if not isinstance(warmup_duration, int):
         raise TypeError(f"Argument warmup_duration should be integer, but given {warmup_duration}")

     if not (warmup_duration > 1):
         raise ValueError(f"Argument warmup_duration should be at least 2 events, but given {warmup_duration}")

     warmup_schedulers: List[ParamScheduler] = []
+    milestones_values: List[Tuple[int, float]] = []

     for param_group_index, param_group in enumerate(lr_scheduler.optimizer.param_groups):
         if warmup_end_value is None:
@@ -1154,7 +1155,6 @@ def print_lr():
             init_lr = lr_scheduler.get_param()
             if init_lr == param_group_warmup_end_value:
                 if warmup_duration > 2:
-                    # pyrefly: ignore [unsupported-operation]
                     d = (param_group_warmup_end_value - warmup_start_value) / (warmup_duration - 1)
                     milestones_values[-1] = (warmup_duration - 2, param_group_warmup_end_value - d)
                 else:
@@ -1164,7 +1164,6 @@ def print_lr():
             PiecewiseLinear(
                 lr_scheduler.optimizer,
                 param_name="lr",
-                # pyrefly: ignore [bad-argument-type]
                 milestones_values=milestones_values,
                 param_group_index=param_group_index,
                 save_history=save_history,
@@ -1177,7 +1176,6 @@ def print_lr():
         warmup_scheduler,
         lr_scheduler,
     ]
-    # pyrefly: ignore [unbound-name, unsupported-operation]
     durations = [milestones_values[-1][0] + 1]
     # pyrefly: ignore [bad-argument-type]
     combined_scheduler = ConcatScheduler(schedulers, durations=durations, save_history=save_history)
@@ -1655,13 +1653,13 @@ def __init__(
         self.trainer = trainer
         self.optimizer = optimizer

+        min_lr: Union[float, List[float]]
         if "min_lr" in scheduler_kwargs and param_group_index is not None:
             min_lr = scheduler_kwargs["min_lr"]
             if not isinstance(min_lr, float):
                 raise TypeError(f"When param_group_index is given, min_lr should be a float, but given {type(min_lr)}")
             _min_lr = min_lr
             min_lr = [0] * len(optimizer.param_groups)
-            # pyrefly: ignore [unsupported-operation]
             min_lr[param_group_index] = _min_lr
         else:
             min_lr = 0
@@ -1676,11 +1674,11 @@ def __init__(
             _scheduler_kwargs["verbose"] = False

         self.scheduler = ReduceLROnPlateau(optimizer, **_scheduler_kwargs)
-        self.scheduler._reduce_lr = self._reduce_lr  # type: ignore[method-assign]
+        self.scheduler._reduce_lr = self._reduce_lr

         self._state_attrs += ["metric_name", "scheduler"]

-    def __call__(self, engine: Engine, name: Optional[str] = None) -> None:  # type: ignore[override]
+    def __call__(self, engine: Engine, name: Optional[str] = None) -> None:
         if not hasattr(engine.state, "metrics") or self.metric_name not in engine.state.metrics:
             raise ValueError(
                 "Argument engine should have in its 'state', attribute 'metrics' "
diff --git a/ignite/handlers/state_param_scheduler.py b/ignite/handlers/state_param_scheduler.py
index 3509086502ca..5cca2365ae30 100644
--- a/ignite/handlers/state_param_scheduler.py
+++ b/ignite/handlers/state_param_scheduler.py
@@ -1,7 +1,7 @@
 import numbers
 import warnings
 from bisect import bisect_right
-from typing import Any, List, Sequence, Tuple, Union
+from typing import Any, Callable, List, Sequence, Tuple, Union

 from ignite.engine import CallableEventWithFilter, Engine, Events, EventsList
 from ignite.handlers.param_scheduler import BaseParamScheduler
@@ -183,7 +183,13 @@ def print_param():

     """

-    def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False, create_new: bool = False):
+    def __init__(
+        self,
+        lambda_obj: Callable[[int], Union[List[float], float]],
+        param_name: str,
+        save_history: bool = False,
+        create_new: bool = False,
+    ):
         super(LambdaStateScheduler, self).__init__(param_name, save_history, create_new)

         if not callable(lambda_obj):
@@ -193,7 +199,6 @@ def __init__(self, lambda_obj: Any, param_name: str, save_history: bool = False,

         self._state_attrs += ["lambda_obj"]

     def get_param(self) -> Union[List[float], float]:
-        # pyrefly: ignore [bad-return]
         return self.lambda_obj(self.event_index)

diff --git a/ignite/handlers/time_profilers.py b/ignite/handlers/time_profilers.py
index 7561de25b2aa..3aa85e115ec5 100644
--- a/ignite/handlers/time_profilers.py
+++ b/ignite/handlers/time_profilers.py
@@ -500,14 +500,14 @@ def __init__(self) -> None:
         self.dataflow_times: List[float] = []
         self.processing_times: List[float] = []
-        self.event_handlers_times: Dict[EventEnum, Dict[str, List[float]]] = {}
+        self.event_handlers_times: Dict[Union[str, EventEnum], Dict[str, List[float]]] = {}

     @staticmethod
     def _get_callable_name(handler: Callable) -> str:
         # get name of the callable handler
         return getattr(handler, "__qualname__", handler.__class__.__name__)

-    def _create_wrapped_handler(self, handler: Callable, event: EventEnum) -> Callable:
+    def _create_wrapped_handler(self, handler: Callable, event: Union[str, EventEnum]) -> Callable:
         @functools.wraps(handler)
         def _timeit_handler(*args: Any, **kwargs: Any) -> None:
             self._event_handlers_timer.reset()
@@ -532,7 +532,7 @@ def _timeit_dataflow(self) -> None:
         t = self._dataflow_timer.value()
         self.dataflow_times.append(t)

-    def _reset(self, event_handlers_names: Mapping[EventEnum, List[str]]) -> None:
+    def _reset(self, event_handlers_names: Mapping[Union[str, EventEnum], List[str]]) -> None:
         # reset the variables used for profiling
         self.dataflow_times = []
         self.processing_times = []
diff --git a/ignite/handlers/tqdm_logger.py b/ignite/handlers/tqdm_logger.py
index 19bfe12c2e48..342c866985f7 100644
--- a/ignite/handlers/tqdm_logger.py
+++ b/ignite/handlers/tqdm_logger.py
@@ -223,7 +223,7 @@ def attach(  # type: ignore[override]
         super(ProgressBar, self).attach(engine, log_handler, event_name)
         engine.add_event_handler(closing_event_name, self._close)

-    def attach_opt_params_handler(  # type: ignore[empty-body]
+    def attach_opt_params_handler(
         self,
         engine: Engine,
         event_name: Union[str, Events],
diff --git a/ignite/handlers/visdom_logger.py b/ignite/handlers/visdom_logger.py
index 11053e10c5e7..9a10300c15a2 100644
--- a/ignite/handlers/visdom_logger.py
+++ b/ignite/handlers/visdom_logger.py
@@ -1,7 +1,7 @@
 """Visdom logger and its helper handlers."""

 import os
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, TYPE_CHECKING

 import torch
 import torch.nn as nn
@@ -165,7 +165,7 @@ def __init__(
                 "pip install git+https://github.com/fossasia/visdom.git"
             )

-        if num_workers > 0:
+        if num_workers > 0 or TYPE_CHECKING:
             # If visdom is installed, one of its dependencies `tornado`
             # requires also `futures` to be installed.
             # Let's check anyway if we can import it.
@@ -199,7 +199,6 @@ def __init__(

         self.executor: Union[_DummyExecutor, "ThreadPoolExecutor"] = _DummyExecutor()
         if num_workers > 0:
-            # pyrefly: ignore [unbound-name]
             self.executor = ThreadPoolExecutor(max_workers=num_workers)

     def _save(self) -> None:
diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py
index 47d72869319f..a4ea1fba364c 100644
--- a/ignite/metrics/accuracy.py
+++ b/ignite/metrics/accuracy.py
@@ -254,10 +254,10 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
             y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
             y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
             correct = torch.all(y == y_pred.type_as(y), dim=-1)
+        else:
+            raise ValueError(f"Unexpected type: {self._type}")

-        # pyrefly: ignore [unbound-name]
         self._num_correct += torch.sum(correct).to(self._device)
-        # pyrefly: ignore [unbound-name]
         self._num_examples += correct.shape[0]

     @sync_all_reduce("_num_examples", "_num_correct")
diff --git a/ignite/metrics/nlp/rouge.py b/ignite/metrics/nlp/rouge.py
index 39a89438e920..47c0ac7a9e32 100644
--- a/ignite/metrics/nlp/rouge.py
+++ b/ignite/metrics/nlp/rouge.py
@@ -1,6 +1,5 @@
 from abc import ABCMeta, abstractmethod
-from collections import namedtuple
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union

 import torch

@@ -13,24 +12,25 @@
 __all__ = ["Rouge", "RougeN", "RougeL"]


-# pyrefly: ignore [invalid-inheritance]
-class Score(namedtuple("Score", ["match", "candidate", "reference"])):
+class Score(NamedTuple):
     r"""
     Computes precision and recall for given matches, candidate and reference lengths.
     """

+    match: int
+    candidate: int
+    reference: int
+
     def precision(self) -> float:
         """
         Calculates precision.
         """
-        # pyrefly: ignore [missing-attribute]
         return self.match / self.candidate if self.candidate > 0 else 0

     def recall(self) -> float:
         """
         Calculates recall.
""" - # pyrefly: ignore [missing-attribute] return self.match / self.reference if self.reference > 0 else 0 diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 489b3a3fd28c..000000000000 --- a/mypy.ini +++ /dev/null @@ -1,79 +0,0 @@ -[mypy] -files = ignite -pretty = True -show_error_codes = True - -check_untyped_defs = True -; a lot of work needed to fix issues -disallow_any_generics = False -disallow_incomplete_defs = True -disallow_subclassing_any = True -; due to missing types in pytorch set to False -disallow_untyped_calls = False -disallow_untyped_decorators = True -disallow_untyped_defs = True -no_implicit_optional = True -; would need a more precise import of pytorch classes and methods, which is not possible, therefore set to False -no_implicit_reexport = False -strict_equality = True -warn_redundant_casts = True -; due to missing types in multiple libs set to False -warn_return_any = False -; results in too many false positives, therefore set to False -warn_unreachable = False -warn_unused_configs = True -warn_unused_ignores = True - -[mypy-apex.*] -ignore_missing_imports = True - -[mypy-clearml.*] -ignore_missing_imports = True - -[mypy-horovod.*] -ignore_missing_imports = True - -[mypy-matplotlib.*] -ignore_missing_imports = True - -[mypy-mlflow.*] -ignore_missing_imports = True - -[mypy-neptune.*] -ignore_missing_imports = True - -[mypy-numpy.*] -ignore_missing_imports = True - -[mypy-pandas.*] -ignore_missing_imports = True - -[mypy-sklearn.*] -ignore_missing_imports = True - -[mypy-polyaxon.*] -ignore_missing_imports = True - -[mypy-polyaxon_client.*] -ignore_missing_imports = True - -[mypy-pynvml.*] -ignore_missing_imports = True - -[mypy-tensorboardX.*] -ignore_missing_imports = True - -[mypy-torch_xla.*] -ignore_missing_imports = True - -[mypy-trains.*] -ignore_missing_imports = True - -[mypy-tqdm.*] -ignore_missing_imports = True - -[mypy-scipy.*] -ignore_missing_imports = True - -[mypy-torchvision.*] -ignore_missing_imports = True diff --git a/requirements-dev.txt b/requirements-dev.txt index 8fe4bcbb02b7..ae66183dae01 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,7 @@ # Tests dill filelock -mypy +pyrefly numpy pre-commit pytest diff --git a/tests/ignite/handlers/test_ema_handler.py b/tests/ignite/handlers/test_ema_handler.py index 2210de43a5a7..58608451cbf5 100644 --- a/tests/ignite/handlers/test_ema_handler.py +++ b/tests/ignite/handlers/test_ema_handler.py @@ -97,7 +97,7 @@ def check_ema_momentum(engine: Engine, momentum_warmup, final_momentum, warmup_i def test_ema_invalid_model(): with pytest.raises(ValueError, match="model should be an instance of nn.Module or its subclasses"): model = "Invalid Model" - EMAHandler(model) # type: ignore + EMAHandler(model) @pytest.mark.distributed