15 changes: 4 additions & 11 deletions .github/workflows/typing-checks.yml
@@ -11,7 +11,7 @@ on:
- "tests/ignite/**"
- "requirements-dev.txt"
- "pyproject.toml"
- "mypy.ini"
- "pyrefly.toml"
- ".github/workflows/typing-checks.yml"
pull_request:
paths:
@@ -20,7 +20,7 @@ on:
- "tests/ignite/**"
- "requirements-dev.txt"
- "pyproject.toml"
- "mypy.ini"
- "pyrefly.toml"
- ".github/workflows/typing-checks.yml"
workflow_dispatch:

@@ -30,7 +30,7 @@ concurrency:
cancel-in-progress: true

jobs:
mypy:
pyrefly:
runs-on: ubuntu-latest
strategy:
matrix:
@@ -60,14 +60,7 @@ jobs:
run: uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu

- name: Install dependencies
run: |
uv pip install -r requirements-dev.txt
uv pip install .
uv pip install mypy
uv pip install pyrefly

- name: Run MyPy type checking
run: mypy
run: uv pip install -r requirements-dev.txt

- name: Run Pyrefly type checking
run: pyrefly check
10 changes: 5 additions & 5 deletions CONTRIBUTING.md
@@ -40,7 +40,7 @@ into the following categories:
- [Formatting the code with a pre-commit hook](#formatting-the-code-with-a-pre-commit-hook)
- [Run tests](#run-tests)
- [Run distributed tests only on CPU](#run-distributed-tests-only-on-cpu)
- [Run Mypy checks](#run-mypy-checks)
- [Run Pyrefly checks](#run-pyrefly-checks)
- [Send a PR](#send-a-pr)
- [Sync up with the upstream](#sync-up-with-the-upstream)
- [Writing documentation](#writing-documentation)
@@ -214,15 +214,15 @@ export WORLD_SIZE=2
CUDA_VISIBLE_DEVICES="" pytest --dist=each --tx $WORLD_SIZE*popen//python=python tests/ -m distributed -vvv
```

#### Run Mypy checks:
#### Run Pyrefly checks:

To run mypy to check the optional static type:
To run pyrefly to check the optional static type:

```bash
mypy
pyrefly check
```

To change any config for specif folder, please see the file mypy.ini
To change any config for specific folder, please see the file pyrefly.toml

#### Send a PR

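Alongside the project-level configuration in pyrefly.toml that CONTRIBUTING.md now points to, individual diagnostics can be suppressed inline with a `# pyrefly: ignore [error-code]` comment on the preceding line, the same form this PR removes elsewhere wherever tightened annotations made it unnecessary. A minimal sketch using the `bad-argument-type` code seen in this PR (the function and values are illustrative only):

```python
def describe(count: int) -> str:
    return f"{count} item(s)"


raw_count = "3"  # e.g. a value read from user input, so it is a str
# pyrefly: ignore [bad-argument-type]
print(describe(raw_count))  # suppressed: a str is passed where an int is declared
```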
1 change: 0 additions & 1 deletion ignite/contrib/engines/tbptt.py
@@ -117,6 +117,5 @@ def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> float:
return sum(loss_list) / len(loss_list)

engine = Engine(_update)
# pyrefly: ignore [bad-argument-type]
engine.register_events(*Tbptt_Events)
return engine
10 changes: 6 additions & 4 deletions ignite/distributed/comp_models/horovod.py
@@ -1,6 +1,6 @@
import os
import warnings
from typing import Any, Callable, cast, List, Mapping, Optional, Tuple
from typing import Any, Callable, cast, List, Mapping, Optional, Tuple, TYPE_CHECKING

import torch

@@ -20,6 +20,11 @@
except ImportError:
has_hvd_support = False

if TYPE_CHECKING:
# Tell the type checker that hvd imports are always defined.
import horovod.torch as hvd
from horovod import run as hvd_mp_spawn


if has_hvd_support:
HOROVOD = "horovod"
@@ -171,11 +176,8 @@ def _setup_group(self, group: Any) -> hvd.ProcessSet:
return group

_reduce_op_map = {
# pyrefly: ignore [unbound-name]
"SUM": hvd.mpi_ops.Sum,
# pyrefly: ignore [unbound-name]
"AVERAGE": hvd.mpi_ops.Average,
# pyrefly: ignore [unbound-name]
"ADASUM": hvd.mpi_ops.Adasum,
}

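The new `TYPE_CHECKING` block mirrors the optional runtime import so the checker always sees `hvd` and `hvd_mp_spawn` as bound names, which is what lets the three `unbound-name` suppressions in `_reduce_op_map` go away. A minimal sketch of the same pattern for a generic optional dependency (`fastlib` and its API are purely illustrative):

```python
from typing import TYPE_CHECKING, Any

try:
    import fastlib  # hypothetical optional accelerator backend

    has_fastlib_support = True
except ImportError:
    has_fastlib_support = False

if TYPE_CHECKING:
    # Seen only by static type checkers, never executed at runtime,
    # so `fastlib` is always a bound name during analysis.
    import fastlib


def fast_sum(values: Any) -> Any:
    if not has_fastlib_support:
        raise RuntimeError("fastlib is not installed")
    return fastlib.sum(values)
```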
12 changes: 4 additions & 8 deletions ignite/distributed/launcher.py
@@ -322,19 +322,15 @@ def __enter__(self) -> "Parallel":
idist.initialize(self.backend, init_method=self.init_method)

# The logger can be setup from now since idist.initialize() has been called (if needed)
self._logger = setup_logger(__name__ + "." + self.__class__.__name__) # type: ignore[assignment]
self._logger = setup_logger(__name__ + "." + self.__class__.__name__)

if self.backend is not None:
if self._spawn_params is None:
self._logger.info( # type: ignore[attr-defined]
f"Initialized processing group with backend: '{self.backend}'"
)
self._logger.info(f"Initialized processing group with backend: '{self.backend}'")
else:
self._logger.info( # type: ignore[attr-defined]
f"Initialized distributed launcher with backend: '{self.backend}'"
)
self._logger.info(f"Initialized distributed launcher with backend: '{self.backend}'")
msg = "\n\t".join([f"{k}: {v}" for k, v in self._spawn_params.items() if v is not None])
self._logger.info(f"- Parameters to spawn processes: \n\t{msg}") # type: ignore[attr-defined]
self._logger.info(f"- Parameters to spawn processes: \n\t{msg}")

return self

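The annotation change that makes these `self._logger` assignments and `.info()` calls type-check is not visible in this hunk. A minimal sketch, assuming the attribute is declared as `Optional[logging.Logger]` and relying on flow-based narrowing, of why the per-line ignores become unnecessary (the class here is a simplified stand-in, not ignite's actual `Parallel` implementation):

```python
import logging
from typing import Optional


class LauncherSketch:
    """Simplified stand-in for a context manager that sets up its logger lazily."""

    def __init__(self) -> None:
        # Declaring the attribute as Optional[Logger] lets later assignments
        # and method calls type-check without per-line ignore comments.
        self._logger: Optional[logging.Logger] = None

    def __enter__(self) -> "LauncherSketch":
        self._logger = logging.getLogger(__name__ + ".LauncherSketch")
        # After the assignment above, the checker narrows `self._logger`
        # to `logging.Logger`, so `.info()` needs no `attr-defined` ignore.
        self._logger.info("Initialized")
        return self

    def __exit__(self, *args: object) -> None:
        if self._logger is not None:
            self._logger.info("Finalized")
```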
23 changes: 9 additions & 14 deletions ignite/engine/engine.py
@@ -148,12 +148,11 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
self.should_interrupt = False
self.state = State()
self._state_dict_user_keys: List[str] = []
self._allowed_events: List[EventEnum] = []
self._allowed_events: List[Union[str, EventEnum]] = []

self._dataloader_iter: Optional[Iterator[Any]] = None
self._init_iter: Optional[int] = None

# pyrefly: ignore [bad-argument-type]
self.register_events(*Events)

if self._process_function is None:
@@ -164,9 +163,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
# generator provided by self._internal_run_as_gen
self._internal_run_generator: Optional[Generator[Any, None, State]] = None

def register_events(
self, *event_names: Union[List[str], List[EventEnum]], event_to_attr: Optional[dict] = None
) -> None:
def register_events(self, *event_names: Union[str, EventEnum], event_to_attr: Optional[dict] = None) -> None:
"""Add events that can be fired.

Registering an event will let the user trigger these events at any point.
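With `*event_names` now typed as individual `Union[str, EventEnum]` values rather than lists, call sites such as `self.register_events(*Events)` and `engine.register_events(*Tbptt_Events)` type-check directly, which is why their `bad-argument-type` suppressions are dropped in this PR. A minimal usage sketch (the custom event class and event names are illustrative):

```python
from ignite.engine import Engine
from ignite.engine.events import EventEnum


class BackpropEvents(EventEnum):
    # User-defined events fired from a custom update function.
    BACKWARD_STARTED = "backward_started"
    BACKWARD_COMPLETED = "backward_completed"


engine = Engine(lambda eng, batch: None)

# Each positional argument is a single str or EventEnum member,
# so unpacking a whole EventEnum class needs no suppression comment.
engine.register_events(*BackpropEvents)
engine.register_events("manual_validation_requested")
```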
@@ -451,7 +448,7 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
first, others = ((resolved_engine,), args[1:])
else:
# metrics do not provide engine when registered
first, others = (tuple(), args) # type: ignore[assignment]
first, others = (tuple(), args)

func(*first, *(event_args + others), **kwargs)

@@ -990,9 +987,9 @@ def _internal_run(self) -> State:
def _internal_run_as_gen(self) -> Generator[Any, None, State]:
self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
self._init_timers(self.state)
start_time = time.time()
try:
try:
start_time = time.time()
self._fire_event(Events.STARTED)
yield from self._maybe_terminate_or_interrupt()

@@ -1011,7 +1008,7 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
# time is available for handlers but must be updated after fire
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

if self.should_terminate_single_epoch != "skip_epoch_completed": # type: ignore[comparison-overlap]
if self.should_terminate_single_epoch != "skip_epoch_completed":
handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
@@ -1039,13 +1036,12 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
"https://github.com/pytorch/ignite/issues/new/choose"
)

# pyrefly: ignore [unbound-name]
time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken

# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
if self.should_terminate != "skip_completed": # type: ignore[comparison-overlap]
if self.should_terminate != "skip_completed":
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
time_taken += time.time() - handlers_start_time
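Binding `start_time` before the `try` block means the name is defined on every path that can reach the timing code after the exception handling, so the `unbound-name` suppression on `time_taken = time.time() - start_time` is no longer needed. A self-contained illustration of the same pattern, not ignite's actual run loop:

```python
import time


def do_work() -> None:
    pass  # stand-in for a body that may raise


def timed_run() -> float:
    # Assigned before the try block, so the name is bound on every path,
    # including the one that reaches the except handler below.
    start_time = time.time()
    try:
        do_work()
    except Exception:
        print(f"Run failed after {time.time() - start_time:.3f}s")
        raise
    return time.time() - start_time


print(f"Run took {timed_run():.3f}s")
```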
@@ -1191,9 +1187,9 @@ def _internal_run_legacy(self) -> State:
# internal_run without generator for BC
self.should_terminate = self.should_terminate_single_epoch = self.should_interrupt = False
self._init_timers(self.state)
start_time = time.time()
try:
try:
start_time = time.time()
self._fire_event(Events.STARTED)
self._maybe_terminate_legacy()

@@ -1212,7 +1208,7 @@ def _internal_run_legacy(self) -> State:
# time is available for handlers but must be updated after fire
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

if self.should_terminate_single_epoch != "skip_epoch_completed": # type: ignore[comparison-overlap]
if self.should_terminate_single_epoch != "skip_epoch_completed":
handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
epoch_time_taken += time.time() - handlers_start_time
@@ -1240,13 +1236,12 @@ def _internal_run_legacy(self) -> State:
"https://github.com/pytorch/ignite/issues/new/choose"
)

# pyrefly: ignore [unbound-name]
time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken

# do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
if self.should_terminate != "skip_completed": # type: ignore[comparison-overlap]
if self.should_terminate != "skip_completed":
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
time_taken += time.time() - handlers_start_time
7 changes: 2 additions & 5 deletions ignite/handlers/checkpoint.py
@@ -315,15 +315,15 @@ def on_checkpoint_saved(engine):
"""

SAVED_CHECKPOINT = CheckpointEvents.SAVED_CHECKPOINT
Item = NamedTuple("Item", [("priority", int), ("filename", str)])
Item = NamedTuple("Item", [("priority", Union[int, float]), ("filename", str)])
_state_dict_all_req_keys = ("_saved",)

def __init__(
self,
to_save: Mapping,
save_handler: Union[str, Path, Callable, BaseSaveHandler],
filename_prefix: str = "",
score_function: Optional[Callable] = None,
score_function: Optional[Callable[[Engine], Union[int, float]]] = None,
score_name: Optional[str] = None,
n_saved: Union[int, None] = 1,
global_step_transform: Optional[Callable] = None,
@@ -440,7 +440,6 @@ def _compare_fn(self, new: Union[int, float]) -> bool:

def __call__(self, engine: Engine) -> None:
if not engine.has_registered_events(CheckpointEvents.SAVED_CHECKPOINT):
# pyrefly: ignore [bad-argument-type]
engine.register_events(*CheckpointEvents)
global_step = None
if self.global_step_transform is not None:
@@ -455,7 +454,6 @@ def __call__(self, engine: Engine) -> None:
global_step = engine.state.get_event_attrib_value(Events.ITERATION_COMPLETED)
priority = global_step

# pyrefly: ignore [bad-argument-type]
if self._check_lt_n_saved() or self._compare_fn(priority):
priority_str = f"{priority}" if isinstance(priority, numbers.Integral) else f"{priority:.4f}"

@@ -497,7 +495,6 @@ def __call__(self, engine: Engine) -> None:
if isinstance(self.save_handler, BaseSaveHandler):
self.save_handler.remove(item.filename)

# pyrefly: ignore [bad-argument-type]
self._saved.append(Checkpoint.Item(priority, filename))
self._saved.sort(key=lambda it: it[0])

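Widening `Item.priority` and the `score_function` return type to `Union[int, float]` matches how the handler is actually used: with a score function the priority is typically a float metric, while without one it is an integer step count. A minimal usage sketch with a float-valued score; the placeholder model, directory, and metric name `nll` are assumptions for illustration:

```python
import torch.nn as nn

from ignite.engine import Engine
from ignite.handlers import Checkpoint, DiskSaver

model = nn.Linear(4, 2)  # placeholder model, only for illustration


def neg_val_loss(engine: Engine) -> float:
    # Float-valued score: Checkpoint keeps the checkpoints with the highest
    # score, so a loss-like metric is negated.
    return -float(engine.state.metrics["nll"])


handler = Checkpoint(
    {"model": model},
    DiskSaver("/tmp/checkpoints", create_dir=True, require_empty=False),
    score_function=neg_val_loss,
    score_name="neg_nll",
    n_saved=2,
)
# Typically attached to an evaluation engine, e.g.
# evaluator.add_event_handler(Events.COMPLETED, handler)
```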
4 changes: 2 additions & 2 deletions ignite/handlers/clearml_logger.py
@@ -862,7 +862,7 @@ def _setup_check_clearml(self, logger: ClearMLLogger, output_uri: str) -> None:
except ImportError:
try:
# Backwards-compatibility for legacy Trains SDK
from trains import Task # type: ignore[no-redef]
from trains import Task
except ImportError:
raise ModuleNotFoundError(
"This contrib module requires clearml to be installed. "
@@ -937,7 +937,7 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mappin
except ImportError:
try:
# Backwards-compatibility for legacy Trains SDK
from trains.binding.frameworks import WeightsFileHandler # type: ignore[no-redef]
from trains.binding.frameworks import WeightsFileHandler
except ImportError:
raise ModuleNotFoundError(
"This contrib module requires clearml to be installed. "
22 changes: 10 additions & 12 deletions ignite/handlers/lr_finder.py
@@ -98,16 +98,18 @@ def _run(
self._best_loss = None
self._diverge_flag = False

assert trainer.state.epoch_length is not None
assert trainer.state.max_epochs is not None

# attach LRScheduler to trainer.
if num_iter is None:
# pyrefly: ignore [unsupported-operation]
num_iter = trainer.state.epoch_length * trainer.state.max_epochs
else:
max_iter = trainer.state.epoch_length * trainer.state.max_epochs # type: ignore[operator]
max_iter = trainer.state.epoch_length * trainer.state.max_epochs
if max_iter < num_iter:
max_iter = num_iter
trainer.state.max_iters = num_iter
trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length) # type: ignore[operator]
trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)

if not trainer.has_event_handler(self._reached_num_iterations):
trainer.add_event_handler(Events.ITERATION_COMPLETED, self._reached_num_iterations, num_iter)
@@ -179,18 +181,14 @@ def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f
loss = idist.all_reduce(loss)
lr = self._lr_schedule.get_param()
self._history["lr"].append(lr)
if trainer.state.iteration == 1:
self._best_loss = loss # type: ignore[assignment]
else:
if smooth_f > 0:
loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
# pyrefly: ignore [unsupported-operation]
if loss < self._best_loss:
self._best_loss = loss
if trainer.state.iteration != 1 and smooth_f > 0:
loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
if self._best_loss is None or loss < self._best_loss:
self._best_loss = loss
self._history["loss"].append(loss)

# Check if the loss has diverged; if it has, stop the trainer
if self._history["loss"][-1] > diverge_th * self._best_loss: # type: ignore[operator]
if self._history["loss"][-1] > diverge_th * self._best_loss:
self._diverge_flag = True
self.logger.info("Stopping early, the loss has diverged")
trainer.terminate()
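The explicit asserts let the checker narrow `trainer.state.epoch_length` and `trainer.state.max_epochs` from `Optional[int]` to `int` before they are multiplied, which is why the `operator` and `unsupported-operation` suppressions can be dropped. A minimal illustration of the narrowing pattern; the function name and the exact rounding rule are illustrative, not ignite's implementation:

```python
from math import ceil
from typing import Optional


def resolve_run_length(epoch_length: Optional[int], max_epochs: Optional[int],
                       num_iter: Optional[int]) -> int:
    # After these asserts the checker treats both values as plain int,
    # so the arithmetic below needs no per-line suppressions.
    assert epoch_length is not None
    assert max_epochs is not None

    if num_iter is None:
        return epoch_length * max_epochs
    # When num_iter is given, round the run length up to whole epochs,
    # as the LR finder does via ceil(num_iter / epoch_length).
    return ceil(num_iter / epoch_length) * epoch_length


print(resolve_run_length(100, 3, None))  # 300
print(resolve_run_length(100, 3, 250))   # 300
```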