diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..06e8c78 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# git-archive: ignore .gitignore, .gitattributes, .github/ etc. +.git* export-ignore diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml new file mode 100644 index 0000000..dc6bdda --- /dev/null +++ b/.github/workflows/CI.yml @@ -0,0 +1,53 @@ +name: CI + +on: + push: + branches: [ "master" ] + pull_request: + branches: [ "master" ] + +jobs: + test_2004: + name: ${{ matrix.os }} - Python ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-20.04, windows-latest] + python-version: ["3.6", "3.7"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + env: + PIP_TRUSTED_HOST: "pypi.python.org pypi.org files.pythonhosted.org" + - name: Install reloading + run: python -m pip install . + - name: Test with unittest + run: python -m unittest + test_2204: + name: ${{ matrix.os }} - Python ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-22.04, windows-latest] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install reloading + run: python -m pip install ".[development]" + - name: Lint with flake8 + run: flake8 + - name: Lint with ruff + run: ruff check . + - name: Type check with pyright + run: pyright + - name: Type check with mypy + run: mypy . + - name: Test with unittest + run: python -m unittest diff --git a/MANIFEST b/MANIFEST deleted file mode 100644 index d4a6512..0000000 --- a/MANIFEST +++ /dev/null @@ -1,6 +0,0 @@ -# file GENERATED by distutils, do NOT edit -setup.cfg -setup.py -reloading/__init__.py -reloading/reloading.py -reloading/test_reloading.py diff --git a/Makefile b/Makefile deleted file mode 100644 index e3a8b06..0000000 --- a/Makefile +++ /dev/null @@ -1,12 +0,0 @@ -PWD := $(shell pwd) - -.PHONY: test -test: test3.6 test3.7 test3.8 - -test3.6: - docker run -w /app -v $(PWD):/app python:3.6.10-alpine3.11 python -m unittest -test3.7: - docker run -w /app -v $(PWD):/app python:3.7.7-alpine3.11 python -m unittest -test3.8: - docker run -w /app -v $(PWD):/app python:3.8.3-alpine3.11 python -m unittest - diff --git a/README.md b/README.md index 623757d..76cc0a9 100644 --- a/README.md +++ b/README.md @@ -1,156 +1,161 @@ -# reloading -[![pypi badge](https://img.shields.io/pypi/v/reloading?color=%230c0)](https://pypi.org/project/reloading/) +# Reloading +[![CI](https://github.com/nneskildsf/reloading/actions/workflows/CI.yml/badge.svg)](https://github.com/nneskildsf/reloading/actions/workflows/CI.yml) +[![Checked with mypy](https://www.mypy-lang.org/static/mypy_badge.svg)](https://mypy-lang.org/) +[![Linting: Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) -A Python utility to reload a loop body from source on each iteration without -losing state +A Python utility to reload a function or loop body from source on each iteration without losing state. -Useful for editing source code during training of deep learning models. This lets -you e.g. add logging, print statistics or save the model without restarting the -training and, therefore, without losing the training progress. +## Installing Reloading and Supported Versions +This fork of reloading is *not* available on PyPi. Install it from Github: +```console +$ pip install https://github.com/nneskildsf/reloading/archive/refs/heads/master.zip +``` -![Demo](https://github.com/julvo/reloading/blob/master/examples/demo/demo.gif) +This fork of reloading supports Python 3.6+. -## Install -``` -pip install reloading -``` +## Supported Features +- Reload functions, `for` loops and `while` loops +- Works in Jupyter Notebook +- `break` and `continue` in loops +- Multiple reloading functions and loops in one file +- Reloaded functions preserve their original call signature +- Only reload source code when changed for faster performance +- Comprehensive exceptions and logging +- Exports locals of reloaded loops to parent locals (Python 3.13 and newer) ## Usage -To reload the body of a `for` loop from source before each iteration, simply -wrap the iterator with `reloading`, e.g. +### For Loop +To reload the body of a `for` loop from source before each iteration, wrap the iterator with `reloading`: ```python from reloading import reloading for i in reloading(range(10)): - # this code will be reloaded before each iteration + # This code will be reloaded before each iteration print(i) +``` + +### While Loop +To reload the body and condition of a `while` loop from source before each iteration, wrap the condition with `reloading`: +```python +from reloading import reloading +i = 0 +while reloading(i<10): + # This code and the condition (i<10) will be reloaded before each iteration + print(i) + i += 1 ``` +### Function To reload a function from source before each execution, decorate the function -definition with `@reloading`, e.g. +definition with `@reloading`: ```python from reloading import reloading @reloading -def some_function(): - # this code will be reloaded before each invocation +def function(): + # This code will be reloaded before each function call pass ``` -## Additional Options - -Pass the keyword argument `every` to reload only on every n-th invocation or iteration. E.g. +It is also possible to mark a function for reload after defining it: ```python -for i in reloading(range(1000), every=10): - # this code will only be reloaded before every 10th iteration - # this can help to speed-up tight loops - pass +from reloading import reloading -@reloading(every=10) -def some_function(): - # this code with only be reloaded before every 10th invocation +def function(): + # This function will be reloaded before each function call pass -``` -Pass `forever=True` instead of an iterable to create an endless reloading loop, e.g. -```python -for i in reloading(forever=True): - # this code will loop forever and reload from source before each iteration - pass +function = reloading(function) ``` -## Examples - -Here are the short snippets of how to use reloading with your favourite library. -For complete examples, check out the [examples folder](https://github.com/julvo/reloading/blob/master/examples). +## Additional Options -### PyTorch +### Interactive Exception Handling +Exceptions are handled interactively by default to avoid losing state. +When an exception occurs you will be notified and have the opportunity +to rectify the issue and continue. However, if reloading is +used in a setting where exceptions are better handled in the application +using reloading then it can be disabled by setting `interactive_exception` +to `False`. Example: ```python -for epoch in reloading(range(NB_EPOCHS)): - # the code inside this outer loop will be reloaded before each epoch - - for images, targets in dataloader: - optimiser.zero_grad() - predictions = model(images) - loss = F.cross_entropy(predictions, targets) - loss.backward() - optimiser.step() -``` -[Here](https://github.com/julvo/reloading/blob/master/examples/pytorch/train.py) is a full PyTorch example. +from reloading import reloading -### fastai -```python -@reloading -def update_learner(learner): - # this function will be reloaded from source before each epoch so that you - # can make changes to the learner while the training is running +@reloading(interactive_exception=False) +def reloading_function(): pass -class LearnerUpdater(LearnerCallback): - def on_epoch_begin(self, **kwargs): - update_learner(self.learn) +for i in reloading(range(10), interactive_exception=False): + pass -path = untar_data(URLs.MNIST_SAMPLE) -data = ImageDataBunch.from_folder(path) -learn = cnn_learner(data, models.resnet18, metrics=accuracy, - callback_fns=[LearnerUpdater]) -learn.fit(10) +j = 0 +while reloading(j<10, interactive_exception=False): + j += 1 ``` -[Here](https://github.com/julvo/reloading/blob/master/examples/fastai/train.py) is a full fastai example. -### Keras +### Iterate Forever in For Loop +To iterate forever in a `for` loop you can omit the argument: ```python -@reloading -def update_model(model): - # this function will be reloaded from source before each epoch so that you - # can make changes to the model while the training is running using - # K.set_value() - pass +from reloading import reloading -class ModelUpdater(Callback): - def on_epoch_begin(self, epoch, logs=None): - update_model(self.model) +for _ in reloading(): + # This code will loop forever and reload from source before each iteration + pass +``` -model = Sequential() -model.add(Dense(64, activation='relu', input_dim=20)) -model.add(Dense(10, activation='softmax')) +### Code Changes Logged +On Python 3.9 and newer, a diff is logged when the source code is updated. +Consider the following code as an example. +```python +from reloading import reloading +from time import sleep +import logging -sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) -model.compile(loss='categorical_crossentropy', - optimizer=sgd, - metrics=['accuracy']) +log = logging.getLogger("reloading") +log.setLevel(logging.DEBUG) -model.fit(x_train, y_train, - epochs=200, - batch_size=128, - callbacks=[ModelUpdater()]) +for i in reloading(range(100)): + print(i) + sleep(1.0) +``` +After some time the code is edited. `i = 2*i` is added before `print(i)`, +resulting in the following log output: +```console +INFO:reloading:For loop at line 10 of file "../example.py" has been reloaded. +DEBUG:reloading:Code changes: ++i = i * 2 + print(i) + sleep(1.0) ``` -[Here](https://github.com/julvo/reloading/blob/master/examples/keras/train.py) is a full Keras example. -### TensorFlow +## Known Issus + +On Python version [less than 3.13](https://docs.python.org/3/reference/datamodel.html#frame.f_locals) it is not possible to properly export the local variables from a loop to parent locals. The following example demonstrates this: ```python -for epoch in reloading(range(NB_EPOCHS)): - # the code inside this outer loop will be reloaded from source - # before each epoch so that you can change it during training - - train_loss.reset_states() - train_accuracy.reset_states() - test_loss.reset_states() - test_accuracy.reset_states() - - for images, labels in tqdm(train_ds): - train_step(images, labels) - - for test_images, test_labels in tqdm(test_ds): - test_step(test_images, test_labels) -``` -[Here](https://github.com/julvo/reloading/blob/master/examples/tensorflow/train.py) is a full TensorFlow example. +from reloading import reloading -## Testing +def function(): + i = 0 + while reloading(i < 10): + i += 1 + print(i) -Make sure you have `python` and `python3` available in your path, then run: +function() # Prints 0. Not 10 as expected. Fixed in Python 3.13. +``` +A warning is emitted when the issue arises: +```console +WARNING:reloading:Variable(s) "i" in reloaded loop were not exported to the scope which called the reloaded loop at line... ``` -$ python3 reloading/test_reloading.py + +## Lint, Type Check and Testing + +Run: +```console +$ pip install -e ".[development]" +$ ruff check . +$ flake8 +$ pyright +$ mypy . +$ python -m unittest ``` diff --git a/examples/demo/demo.gif b/examples/demo/demo.gif deleted file mode 100644 index 803e46f..0000000 Binary files a/examples/demo/demo.gif and /dev/null differ diff --git a/examples/demo/demo.py b/examples/demo/demo.py deleted file mode 100644 index 57c881b..0000000 --- a/examples/demo/demo.py +++ /dev/null @@ -1,14 +0,0 @@ -import time -import sys -sys.path.insert(0, '../..') -from reloading import reloading - -epochs = 10000 -loss = 100 -model = { 'weights': [0.2, 0.1, 0.4, 0.8, 0.1] } - -for i in reloading(range(epochs)): - time.sleep(2) - loss /= 2 - - print('Epoch:', i, 'Loss:', loss) diff --git a/examples/fastai/requirements.txt b/examples/fastai/requirements.txt deleted file mode 100644 index 8a8b543..0000000 --- a/examples/fastai/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -fastai \ No newline at end of file diff --git a/examples/fastai/train.py b/examples/fastai/train.py deleted file mode 100644 index badab9e..0000000 --- a/examples/fastai/train.py +++ /dev/null @@ -1,42 +0,0 @@ -import sys -sys.path.insert(0, '../..') -from reloading import reloading - -from fastai.basic_train import LearnerCallback -from fastai.vision import (URLs, untar_data, ImageDataBunch, - cnn_learner, models, accuracy) - - -@reloading -def set_learning_rate(learner): - # Change the learning rate below during the training - learner.opt.opt.lr = 1e-3 - print('Set LR to', learner.opt.opt.lr) - -class LearningRateSetter(LearnerCallback): - def on_epoch_begin(self, **kwargs): - set_learning_rate(self.learn) - - -@reloading -def print_model_statistics(model): - # Uncomment the following lines after during the training - # to start printing statistics - # - # print('{: <28} {: <7} {: <7}'.format('NAME', ' MEAN', ' STDDEV')) - # for name, param in model.named_parameters(): - # mean = param.mean().item() - # std = param.std().item() - # print('{: <28} {: 6.4f} {: 6.4f}'.format(name, mean, std)) - pass - -class ModelStatsPrinter(LearnerCallback): - def on_epoch_begin(self, **kwargs): - print_model_statistics(self.learn.model) - - -path = untar_data(URLs.MNIST_SAMPLE) -data = ImageDataBunch.from_folder(path) -learn = cnn_learner(data, models.resnet18, metrics=accuracy, - callback_fns=[ModelStatsPrinter, LearningRateSetter]) -learn.fit(10) diff --git a/examples/keras/requirements.txt b/examples/keras/requirements.txt deleted file mode 100644 index 068d5df..0000000 --- a/examples/keras/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -keras -numpy \ No newline at end of file diff --git a/examples/keras/train.py b/examples/keras/train.py deleted file mode 100644 index 4d71bd1..0000000 --- a/examples/keras/train.py +++ /dev/null @@ -1,45 +0,0 @@ -# Example taken from https://keras.io/getting-started/sequential-model-guide/#examples -import sys -sys.path.insert(0, '../..') -from reloading import reloading - -import keras -from keras import backend as K -from keras.models import Sequential -from keras.layers import Dense, Activation -from keras.optimizers import SGD -from keras.callbacks import Callback - - -@reloading -def set_learning_rate(model): - # Change the below value during training and see how it updates - K.set_value(model.optimizer.lr, 1e-3) - print('Set LR to', K.get_value(model.optimizer.lr)) - -class LearningRateSetter(Callback): - def on_epoch_begin(self, epoch, logs=None): - set_learning_rate(self.model) - - -# Generate dummy data -import numpy as np -x_train = np.random.random((10000, 20)) -y_train = keras.utils.to_categorical(np.random.randint(10, size=(10000, 1)), num_classes=10) -x_test = np.random.random((1000, 20)) -y_test = keras.utils.to_categorical(np.random.randint(10, size=(1000, 1)), num_classes=10) - -model = Sequential() -model.add(Dense(64, activation='relu', input_dim=20)) -model.add(Dense(10, activation='softmax')) - -sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True) -model.compile(loss='categorical_crossentropy', - optimizer=sgd, - metrics=['accuracy']) - -model.fit(x_train, y_train, - epochs=200, - batch_size=128, - callbacks=[LearningRateSetter()]) -score = model.evaluate(x_test, y_test, batch_size=128) \ No newline at end of file diff --git a/examples/pytorch/requirements.txt b/examples/pytorch/requirements.txt deleted file mode 100644 index be3d297..0000000 --- a/examples/pytorch/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -torch -torchvision -tqdm diff --git a/examples/pytorch/train.py b/examples/pytorch/train.py deleted file mode 100644 index b3cd304..0000000 --- a/examples/pytorch/train.py +++ /dev/null @@ -1,44 +0,0 @@ -import sys -sys.path.insert(0, '../..') -from reloading import reloading - -from torch import nn -from torch.optim import Adam -import torch.nn.functional as F -from torchvision.models import resnet18 -from torchvision.datasets import FashionMNIST -from torchvision.transforms import ToTensor -from torch.utils.data import DataLoader -from tqdm import tqdm - - -dataset = FashionMNIST('.', download=True, transform=ToTensor()) -dataloader = DataLoader(dataset, batch_size=8) - -model = resnet18(pretrained=True) -model.fc = nn.Linear(model.fc.in_features, 10) - -optimiser = Adam(model.parameters()) - -for epoch in reloading(range(1000)): - # Try to change the code inside this loop during the training and see how the - # changes are applied without restarting the training - - model.train() - losses = [] - - for images, targets in tqdm(dataloader): - losses.append(1) - - optimiser.zero_grad() - predictions = model(images.expand(8, 3, 28, 28)) - loss = F.cross_entropy(predictions, targets) - loss.backward() - optimiser.step() - losses.append(loss.item()) - - # Here would be your validation code - - print(f'Epoch {epoch} - Loss {sum(losses) / len(losses)}') - - diff --git a/examples/tensorflow/requirements.txt b/examples/tensorflow/requirements.txt deleted file mode 100644 index af0795f..0000000 --- a/examples/tensorflow/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -tensorflow -tqdm \ No newline at end of file diff --git a/examples/tensorflow/train.py b/examples/tensorflow/train.py deleted file mode 100644 index 5fb32e8..0000000 --- a/examples/tensorflow/train.py +++ /dev/null @@ -1,95 +0,0 @@ -# Example from https://www.tensorflow.org/tutorials/quickstart/advanced - -from __future__ import absolute_import, division, print_function, unicode_literals - -import sys -sys.path.insert(0, '../..') -from reloading import reloading - -import tensorflow as tf -from tensorflow.keras.layers import Dense, Flatten, Conv2D -from tensorflow.keras import Model - -from tqdm import tqdm - -(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() -x_train, x_test = x_train / 255.0, x_test / 255.0 - -# Add a channels dimension -x_train = x_train[..., tf.newaxis] -x_test = x_test[..., tf.newaxis] - -train_ds = tf.data.Dataset.from_tensor_slices( - (x_train, y_train)).shuffle(10000).batch(32) - -test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32) - -class MyModel(Model): - def __init__(self): - super(MyModel, self).__init__() - self.conv1 = Conv2D(32, 3, activation='relu') - self.flatten = Flatten() - self.d1 = Dense(128, activation='relu') - self.d2 = Dense(10) - - def call(self, x): - x = self.conv1(x) - x = self.flatten(x) - x = self.d1(x) - return self.d2(x) - -model = MyModel() - -loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) - -optimizer = tf.keras.optimizers.Adam() - -train_loss = tf.keras.metrics.Mean(name='train_loss') -train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy') - -test_loss = tf.keras.metrics.Mean(name='test_loss') -test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy') - -@tf.function -def train_step(images, labels): - with tf.GradientTape() as tape: - predictions = model(images, training=True) - loss = loss_object(labels, predictions) - gradients = tape.gradient(loss, model.trainable_variables) - optimizer.apply_gradients(zip(gradients, model.trainable_variables)) - - train_loss(loss) - train_accuracy(labels, predictions) - -@tf.function -def test_step(images, labels): - predictions = model(images, training=False) - t_loss = loss_object(labels, predictions) - - test_loss(t_loss) - test_accuracy(labels, predictions) - -EPOCHS = 5 - -for epoch in reloading(range(EPOCHS)): - # Try to change the source code inside this loop during the training to - # see how the changes are applied without restarting the training. - # You can use it e.g. to inspect the model or changing the learning rate. - - train_loss.reset_states() - train_accuracy.reset_states() - test_loss.reset_states() - test_accuracy.reset_states() - - for images, labels in tqdm(train_ds): - train_step(images, labels) - - for test_images, test_labels in tqdm(test_ds): - test_step(test_images, test_labels) - - template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}' - print(template.format(epoch+1, - train_loss.result(), - train_accuracy.result()*100, - test_loss.result(), - test_accuracy.result()*100)) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8cc7f45 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,50 @@ +[build-system] +requires = ["flit_core >=3.2,<4"] +build-backend = "flit_core.buildapi" + +[project] +name = "reloading" +dependencies = [] +requires-python = ">=3.6" +authors = [{name = "Julian Vossen", email = "pypi@julianvossen.de"}] +maintainers = [{name = "Eskild Schroll-Fleischer", email = "eyfl@novonordisk.com"}] +version="1.2.0" +license = {file = "LICENSE.txt"} +readme = "README.md" +description = "Reloads source code of a running program without losing state." +keywords = ["reload", "reloading", "refresh", "loop", "decorator"] +classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Topic :: Utilities", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] + +[project.urls] +Repository = "https://github.com/eskildsf/reloading" +DOWNLOAD = "https://github.com/nneskildsf/reloading/archive/refs/heads/master.zip" + +[project.optional-dependencies] +development = [ + "nbformat", + "flake8", + "pyright", + "mypy", + "ruff", + "build", +] + +[tool.pyright] +include = ["reloading"] +exclude = [ + "**/__pycache__", +] diff --git a/reloading/__init__.py b/reloading/__init__.py index dffa171..d899ad5 100644 --- a/reloading/__init__.py +++ b/reloading/__init__.py @@ -1 +1 @@ -from .reloading import reloading +from .reloading import reloading as reloading # noqa diff --git a/reloading/reloading.py b/reloading/reloading.py index 59dbd1c..16e9a16 100644 --- a/reloading/reloading.py +++ b/reloading/reloading.py @@ -1,66 +1,186 @@ +import ast +import difflib +import functools import inspect +import itertools +import logging +import os import sys -import ast import traceback import types -from itertools import chain -from functools import partial, update_wrapper +from typing import (Any, + Callable, + Dict, + Iterable, + List, + Optional, + overload, + Union) + +log = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +def get_diff_text(ast_before: ast.Module, ast_after: ast.Module): + """ + Calculate difference between two versions of reloaded code. + """ + # Unparse was introduced in Python 3.9. + if sys.version_info.major >= 3 and sys.version_info.minor >= 9: + # mypy complains that unparse is not available in some + # versions of Python. + unparse = getattr(ast, "unparse") + code_before = unparse(ast_before) + code_after = unparse(ast_after) + diff = difflib.unified_diff(code_before.splitlines(), + code_after.splitlines(), + lineterm="") + # Omit first three lines because they contain superfluous information. + changes = list(diff)[3:] + if len(changes): + return "\n".join(["Code changes:"]+changes) + else: + return "No code changes." + else: + return "Cannot compute code changes. Requires Python > 3.9." + + +class ReloadingException(Exception): + pass + + +@overload +def reloading(fn_or_seq_or_bool: Iterable) -> Iterable: ... + + +@overload +def reloading(fn_or_seq_or_bool: Iterable, + interactive_exception: bool) -> Iterable: ... + + +@overload +def reloading(fn_or_seq_or_bool: bool) -> Iterable: ... + + +@overload +def reloading(fn_or_seq_or_bool: bool, + interactive_exception: bool) -> Iterable: ... + + +@overload +def reloading(fn_or_seq_or_bool: Callable) -> Callable: ... -# have to make our own partial in case someone wants to use reloading as a iterator without any arguments -# they would get a partial back because a call without a iterator argument is assumed to be a decorator. -# getting a "TypeError: 'functools.partial' object is not iterable" -# which is not really descriptive. -# hence we overwrite the iter to make sure that the error makes sense. -class no_iter_partial(partial): - def __iter__(self): - raise TypeError("Nothing to iterate over. Please pass an iterable to reloading.") +@overload +def reloading(fn_or_seq_or_bool: Callable, + interactive_exception: bool) -> Callable: ... -def reloading(fn_or_seq=None, every=1, forever=None): - """Wraps a loop iterator or decorates a function to reload the source code +@overload +def reloading(fn_or_seq_or_bool: None) -> Iterable: ... + + +@overload +def reloading(*, interactive_exception: bool) -> Callable: ... + + +@overload +def reloading() -> Callable: ... + + +def reloading(fn_or_seq_or_bool: Optional[ + Union[Iterable, + Callable, + bool]] = None, + interactive_exception: bool = True) -> Union[Iterable, + Callable]: + """ + Wraps a loop iterator or decorates a function to reload the source code before every loop iteration or function invocation. When wrapped around the outermost iterator in a `for` loop, e.g. `for i in reloading(range(10))`, causes the loop body to reload from source before every iteration while keeping the state. + When wrapped around the condition of a `while` loop, e.g. + `while reloading(i<10)`, causes the loop body and condition to reload from + source before every iteration while keeping the state. When used as a function decorator, the decorated function is reloaded from source before each execution. - Pass the integer keyword argument `every` to reload the source code - only every n-th iteration/invocation. - Args: - fn_or_seq (function | iterable): A function or loop iterator which should - be reloaded from source before each invocation or iteration, - respectively - every (int, Optional): After how many iterations/invocations to reload - forever (bool, Optional): Pass `forever=true` instead of an iterator to - create an endless loop - + fn_or_seq_or_bool: + A function, iterator or condition which should be reloaded from + source before each invocation or iteration, respectively. + interactive_exception: + Exceptions raised from reloading code are caught and you can fix + the code without losing state. """ - if fn_or_seq: - if isinstance(fn_or_seq, types.FunctionType): - return _reloading_function(fn_or_seq, every=every) - return _reloading_loop(fn_or_seq, every=every) - if forever: - return _reloading_loop(iter(int, 1), every=every) - - # return this function with the keyword arguments partialed in, - # so that the return value can be used as a decorator - decorator = update_wrapper(no_iter_partial(reloading, every=every), reloading) - return decorator - -def unique_name(used): - # get the longest element of the used names and append a "0" - return max(used, key=len) + "0" + def wrap(x): + if callable(x): + return _reloading_function(x, interactive_exception) + else: + raise TypeError(f'reloading expected function, got' + f', "{type(fn_or_seq_or_bool)}"') + + if fn_or_seq_or_bool is not None: + if callable(fn_or_seq_or_bool): + return _reloading_function(fn_or_seq_or_bool, + interactive_exception) + elif (isinstance(fn_or_seq_or_bool, Iterable) or + isinstance(fn_or_seq_or_bool, bool)): + return _reloading_loop(fn_or_seq_or_bool, + interactive_exception) + else: + raise TypeError( + f'reloading expected function. iterable or bool' + f', got "{type(fn_or_seq_or_bool)}"' + ) + else: + # If reloading was called as a decorator with an argument, + # then we expect fn_or_seq_or_bool to be None, which is OK. + # However, if reloading was not called as a decorator and it + # did not get an argument then we assume that the user desired + # infinite iteration for a loop. + # Source: https://stackoverflow.com/questions/52191968/ + current_frame = inspect.currentframe() + assert isinstance(current_frame, types.FrameType) + assert isinstance(current_frame.f_back, types.FrameType) + frame = inspect.getframeinfo(current_frame.f_back, context=1) + assert frame.code_context is not None + # Remove whitespace due to indentation before .startswith + if frame.code_context[0].strip().startswith("@"): + return wrap + else: + return _reloading_loop(itertools.count(), interactive_exception) + + +def unique_name(seq: itertools.chain) -> str: + """ + Function to generate string which is unique + relative to the supplied sequence. + """ + return max(seq, key=len) + "0" -def format_itervars(ast_node): - """Formats an `ast_node` of loop iteration variables as string, e.g. 'a, b'""" +def format_iteration_variables(ast_node: Union[ast.Name, + ast.Tuple, + ast.List, + None]) -> str: + """ + Formats an `ast_node` of loop iteration variables as string. + """ + # ast.Name corresponds to cases where the iterator returns + # a single element. + # Example: + # for i in range(10): + # pass + # ast.Tuple/ast.List corresponds to multiple elements: + # for i, j in zip(range(10), range(10)): + # pass + if ast_node is None: + return "" - # handle the case that there only is a single loop var if isinstance(ast_node, ast.Name): return ast_node.id @@ -69,195 +189,638 @@ def format_itervars(ast_node): if isinstance(child, ast.Name): names.append(child.id) elif isinstance(child, ast.Tuple) or isinstance(child, ast.List): - # if its another tuple, like "a, (b, c)", recurse - names.append("({})".format(format_itervars(child))) + # Recursion to handle additional tuples such as "a, (b, c)" + names.append("("+format_iteration_variables(child)+")") return ", ".join(names) -def load_file(path): +def load_file(filename: str) -> str: + """ + Read contents of file containing reloading code. + Handle case of file appearing empty on read. + """ src = "" - # while loop here since while saving, the file may sometimes be empty. - while (src == ""): - with open(path, "r") as f: - src = f.read() - return src + "\n" - - -def parse_file_until_successful(path): - source = load_file(path) + while True: + with open(filename, "r") as f: + if filename.endswith(".ipynb"): + import nbformat + # Read Jupyter Notebook v. 4 + notebook = nbformat.read(f, 4) + # Create list of all code blocks + blocks = [cell.source for cell in notebook.cells + if cell["cell_type"] == "code"] + # Join all blocks (a block is a multiline string of code) + jupyter_code = "\n".join(blocks) + # Jupyter has a magic meaning of !. Lines which start + # with "!" are not Python code. Comment them out. + lines = [line.replace("!", "# !", 1) if line.startswith("!") + else line for line in jupyter_code.split("\n")] + src = "\n".join(lines) + else: + src = f.read() + if len(src): + return src + "\n" + + +def parse_file_until_successful(filename: str, + interactive_exception: bool) -> ast.Module: + """ + Parse source code of file containing reloading code. + File may appear incomplete as as it is read so retry until successful. + """ + source = load_file(filename) while True: try: tree = ast.parse(source) return tree except SyntaxError: - handle_exception(path) - source = load_file(path) - - -def isolate_loop_body_and_get_itervars(tree, lineno, loop_id): - """Modifies tree inplace as unclear how to create ast.Module. - Returns itervars""" - candidate_nodes = [] - for node in ast.walk(tree): - if ( - isinstance(node, ast.For) - and isinstance(node.iter, ast.Call) - and node.iter.func.id == "reloading" - and ( - (loop_id is not None and loop_id == get_loop_id(node)) - or getattr(node, "lineno", None) == lineno - ) - ): - candidate_nodes.append(node) - - if len(candidate_nodes) > 1: - raise LookupError( - "The reloading loop is ambigious. Use `reloading` only once per line and make sure that the code in that line is unique within the source file." - ) - - if len(candidate_nodes) < 1: - raise LookupError( - "Could not locate reloading loop. Please make sure the code in the line that uses `reloading` doesn't change between reloads." - ) - - loop_node = candidate_nodes[0] - tree.body = loop_node.body - return loop_node.target, get_loop_id(loop_node) - - -def get_loop_id(ast_node): - """Generates a unique identifier for an `ast_node` of type ast.For to find the loop in the changed source file - """ - return ast.dump(ast_node.target) + "__" + ast.dump(ast_node.iter) + handle_exception(filename, interactive_exception) + source = load_file(filename) -def get_loop_code(loop_frame_info, loop_id): - fpath = loop_frame_info[1] - while True: - tree = parse_file_until_successful(fpath) - try: - itervars, found_loop_id = isolate_loop_body_and_get_itervars(tree, lineno=loop_frame_info[2], loop_id=loop_id) - return compile(tree, filename="", mode="exec"), format_itervars(itervars), found_loop_id - except LookupError: - handle_exception(fpath) +break_ast = ast.parse('raise Exception("break")').body +continue_ast = ast.parse('raise Exception("continue")').body -def handle_exception(fpath): - exc = traceback.format_exc() - exc = exc.replace('File ""', 'File "{}"'.format(fpath)) - sys.stderr.write(exc + "\n") - print("Edit {} and press return to continue".format(fpath)) - sys.stdin.readline() +class ReplaceBreakContinueWithExceptions(ast.NodeTransformer): + def visit_Break(self, node): + return break_ast + def visit_Continue(self, node): + return continue_ast -def _reloading_loop(seq, every=1): - loop_frame_info = inspect.stack()[2] - fpath = loop_frame_info[1] - caller_globals = loop_frame_info[0].f_globals - caller_locals = loop_frame_info[0].f_locals +class WhileLoop: + """ + Object to hold ast and test-function for a reloading while loop. + """ + def __init__(self, + ast_module: ast.Module, + test: ast.Call, + filename: str, + node_id: str): + self.ast: ast.Module = ast_module + self.test: ast.Call = test + self.id: str = node_id + # Replace "break" and "continue" with custom exceptions. + # Otherwise SyntaxError is raised because these instructions + # are called outside a loop. + ReplaceBreakContinueWithExceptions().visit(ast_module) + ast.fix_missing_locations(ast_module) + self.compiled_body = compile(ast_module, + filename=filename, + mode="exec") + # If no argument was supplied, then loop forever + ast_condition = ast.Expression(body=ast.Constant(True)) + if len(test.args) > 0: + # Create expression to evaluate condition + # reloading only takes one argument, so we can + # pick the first element of the args. + ast_condition = ast.Expression(body=test.args[0]) + ast.fix_missing_locations(ast_condition) + self.condition = compile(ast_condition, filename="", mode="eval") + + +class ForLoop: + """ + Object to hold ast and iteration variables for a reloading for loop. + """ + def __init__(self, + ast_module: ast.Module, + iteration_variables: Union[ast.Name, + ast.Tuple, + ast.List], + filename: str, + node_id: str): + self.ast: ast.Module = ast_module + # Replace "break" and "continue" with custom exceptions. + # Otherwise SyntaxError is raised because these instructions + # are called outside a loop. + ReplaceBreakContinueWithExceptions().visit(ast_module) + ast.fix_missing_locations(ast_module) + self.compiled_body = compile(ast_module, + filename=filename, + mode="exec") + self.iteration_variables: Union[ast.Name, + ast.Tuple, + ast.List] = iteration_variables + self.iteration_variables_str: str = format_iteration_variables( + iteration_variables) + self.id: str = node_id + + +def get_loop_object(loop_frame_info: inspect.FrameInfo, + reloaded_file_ast: ast.Module, + filename: str, + loop_id: Union[None, str]) -> Union[WhileLoop, ForLoop]: + """ + Traverse AST for the entire reloaded file in a search for the + loop which is reloaded. + """ + Parentage().visit(reloaded_file_ast) + candidates: List[Union[ast.For, ast.While]] = [] + for node in ast.walk(reloaded_file_ast): + if isinstance(node, ast.For) and isinstance(node.iter, ast.Call): + # Handle "for reloading(iter)" as well as + # "for reloading.reloading(iter)". + if (getattr(node.iter.func, "id", "") == "reloading" or + getattr(node.iter.func, "attr", "") == "reloading") and ( + (loop_id is not None and loop_id == get_loop_id(node)) + or getattr(node, "lineno", None) == loop_frame_info.lineno): + candidates.append(node) + if isinstance(node, ast.While) and isinstance(node.test, ast.Call): + if (getattr(node.test.func, "id", "") == "reloading" or + getattr(node.test.func, "attr", "") == "reloading") and ( + (loop_id is not None and loop_id == get_loop_id(node)) + or getattr(node, "lineno", None) == loop_frame_info.lineno): + candidates.append(node) + + if len(candidates) == 0 and loop_id is None: + raise ReloadingException("Reloading used outside a loop.") + elif len(candidates) == 0 and loop_id: + raise ReloadingException( + f'Unable to reload loop initially defined at line ' + f'{loop_frame_info.lineno} ' + f'in file "{filename}". ' + 'The loop might have been removed.' + ) - # create a unique name in the caller namespace that we can safely write - # the values of the iteration variables into - unique = unique_name(chain(caller_locals.keys(), caller_globals.keys())) - loop_id = None + # Select the candidate node which is closest to function_frame_info + def sorting_function(candidate): + return abs(candidate.lineno - loop_frame_info.lineno) + candidate = min(candidates, key=sorting_function) + loop_node_ast = ast.Module(candidate.body, type_ignores=[]) + if isinstance(candidate, ast.For): + assert isinstance(candidate.target, (ast.Name, ast.Tuple, ast.List)) + return ForLoop(loop_node_ast, + candidate.target, + filename, + get_loop_id(candidate)) + elif isinstance(candidate, ast.While): + assert isinstance(candidate.test, ast.Call) + return WhileLoop(loop_node_ast, + candidate.test, + filename, + get_loop_id(candidate)) + raise ReloadingException("No loop node found.") + + +def get_loop_id(ast_node: Union[ast.For, ast.While]) -> str: + """ + Generates a unique identifier for an `ast_node`. + Used to identify the loop in the changed source file. + """ + if isinstance(ast_node, ast.For): + return "_".join([get_node_id(ast_node), + ast.dump(ast_node.target), + ast.dump(ast_node.iter)]) + elif isinstance(ast_node, ast.While): + return "_".join([get_node_id(ast_node), + ast.dump(ast_node.test)]) + + +def get_loop_code(loop_frame_info: inspect.FrameInfo, + loop_id: Union[None, str], + filename: str, + interactive_exception: bool) -> Union[WhileLoop, ForLoop]: + while True: + reloaded_file_ast: ast.Module = parse_file_until_successful( + filename, + interactive_exception + ) + try: + return get_loop_object( + loop_frame_info, reloaded_file_ast, filename, loop_id=loop_id + ) + except (LookupError, ReloadingException): + handle_exception(filename, interactive_exception=False) - for i, itervar_values in enumerate(seq): - if i % every == 0: - compiled_body, itervars, loop_id = get_loop_code(loop_frame_info, loop_id=loop_id) - caller_locals[unique] = itervar_values - exec(itervars + " = " + unique, caller_globals, caller_locals) +def handle_exception(filename: str, interactive_exception): + """ + Output helpful error message to user regarding exception in reloaded code. + """ + # Report traceback stack starting with the reloaded file. + # This avoids listing stack frames from this library. + # Source: https://stackoverflow.com/questions/45771299 + frame_summaries = traceback.extract_tb(sys.exc_info()[2]) + count = len(frame_summaries) + # find the first occurrence of the module file name + for frame_summary in frame_summaries: + if frame_summary.filename == filename: + break + count -= 1 + exception_text = traceback.format_exc(limit=-count) + # Even though "filename" is passed to calls to compile, + # we still have to replace the filename when an error + # occours during compiling. + exception_text = exception_text.replace('File ""', + f'File "{filename}"') + exception_text = exception_text.replace('File ""', + f'File "{filename}"') + if sys.stdin.isatty() and interactive_exception: + log.error("Exception occourred. Press to continue " + "or to exit:\n" + exception_text) try: - # run main loop body - exec(compiled_body, caller_globals, caller_locals) - except Exception: - handle_exception(fpath) - + sys.stdin.readline() + except KeyboardInterrupt: + log.info("\nExiting...") + sys.exit(1) + else: + raise + + +def execute_for_loop(seq: Iterable, + loop_frame_info: inspect.FrameInfo, + filename: str, + interactive_exception: bool): + caller_globals: Dict[str, Any] = loop_frame_info.frame.f_globals + caller_locals: Dict[str, Any] = loop_frame_info.frame.f_locals + + # Initialize variables + file_stat: int = 0 + vacant_variable_name: str = "" + assign_compiled = compile("", filename="", mode="exec") + for_loop = get_loop_code(loop_frame_info, + None, + filename, + interactive_exception) + + for i, iteration_variable_values in enumerate(seq): + # Reload code if possibly modified + file_stat_: int = os.stat(filename).st_mtime_ns + if file_stat != file_stat_: + if i > 0: + log.info(f'For loop at line {loop_frame_info.lineno} of file ' + f'"{filename}" has been reloaded.') + ast_before = for_loop.ast + for_loop = get_loop_code(loop_frame_info, + for_loop.id, + filename, + interactive_exception) + assert isinstance(for_loop, ForLoop) + ast_after = for_loop.ast + log.debug(get_diff_text(ast_before, ast_after)) + file_stat = file_stat_ + # Make up a name for a variable which is not already present in + # the global or local namespace. + vacant_variable_name = unique_name( + itertools.chain(caller_locals.keys(), caller_globals.keys()) + ) + # Reassign variable values from vacant variable in local scope + assign = ast.Module([ + ast.Assign(targets=[for_loop.iteration_variables], + value=ast.Name(vacant_variable_name, ast.Load()))], + type_ignores=[]) + ast.fix_missing_locations(assign) + assign_compiled = compile(assign, filename='', mode='exec') + # Store iteration variable values in vacant variable in local scope + caller_locals[vacant_variable_name] = iteration_variable_values + exec(assign_compiled, caller_globals, caller_locals) + # Clean up namespace + del caller_locals[vacant_variable_name] + try: + exec(for_loop.compiled_body, caller_globals, caller_locals) + except Exception as exception: + # A "break" inside the loop body will cause a SyntaxError + # because the code is executed outside the scope of a loop. + # We catch the exception and break *this* loop. + if exception.args == ("break",): + break + if exception.args == ("continue",): + continue + else: + handle_exception(filename, interactive_exception) + + +def execute_while_loop(loop_frame_info: inspect.FrameInfo, + filename: str, + interactive_exception: bool): + caller_globals: Dict[str, Any] = loop_frame_info.frame.f_globals + caller_locals: Dict[str, Any] = loop_frame_info.frame.f_locals + + file_stat: int = os.stat(filename).st_mtime_ns + while_loop = get_loop_code(loop_frame_info, + None, + filename, + interactive_exception) + + def condition(while_loop): + return eval(while_loop.condition, caller_globals, caller_locals) + + i = 0 + while condition(while_loop): + i += 1 + # Reload code if possibly modified + file_stat_: int = os.stat(filename).st_mtime_ns + if file_stat != file_stat_: + log.info(f'While loop at line {loop_frame_info.lineno} of file ' + f'"{filename}" has been reloaded.') + ast_before = while_loop.ast + while_loop = get_loop_code(loop_frame_info, + while_loop.id, + filename, + interactive_exception) + ast_after = while_loop.ast + log.debug(get_diff_text(ast_before, ast_after)) + file_stat = file_stat_ + try: + exec(while_loop.compiled_body, caller_globals, caller_locals) + except Exception as exception: + # A "break" inside the loop body will cause a SyntaxError + # because the code is executed outside the scope of a loop. + # We catch the exception and break *this* loop. + if exception.args == ("break",): + break + if exception.args == ("continue",): + continue + else: + handle_exception(filename, interactive_exception) + + +def _reloading_loop(seq: Union[Iterable, bool], + interactive_exception) -> Iterable: + stack: List[inspect.FrameInfo] = inspect.stack() + # The first element on the stack is the caller of inspect.stack() + # i.e. _reloading_loop + assert stack[0].function == "_reloading_loop" + # The second element is the caller of the first, i.e. reloading + assert stack[1].function == "reloading" + # The third element is the loop which called reloading. + loop_frame_info: inspect.FrameInfo = stack[2] + filename: str = loop_frame_info.filename + # If we are running in Jupyter Notebook then the filename + # of the current notebook is stored in the __session__ variable. + if ".ipynb" in loop_frame_info.frame.f_globals.get("__session__", ""): + filename = str(loop_frame_info.frame.f_globals.get("__session__")) + # Replace filename with Jupyter Notebook file name + # if we are in a Jupyter Notebook session. + loop_object = get_loop_code(loop_frame_info, + None, + filename, + interactive_exception) + + if isinstance(loop_object, ForLoop): + assert isinstance(seq, Iterable) + execute_for_loop(seq, + loop_frame_info, + filename, + interactive_exception) + elif isinstance(loop_object, WhileLoop): + execute_while_loop(loop_frame_info, + filename, + interactive_exception) + + # If there is a third element, then it is the scope which called the loop. + # If this is the main scope, then all is good. Howver, if we are in a + # scope within the main scope then it is only possible to modify + # variables in this scope since Python 3.13. + if (len(stack) > 3 and + sys.version_info.major >= 3 and + sys.version_info.minor >= 13): + # Copy locals from loop to caller of loop. + # This ensures that the following results in '9': + # for i in reloading(range(10)): + # pass + # print(i) + loop_caller_frame: inspect.FrameInfo = stack[3] + loop_caller_frame.frame.f_locals.update(loop_frame_info.frame.f_locals) + elif (len(stack) > 3 and + loop_frame_info.frame.f_locals.get("__name__", "") != "__main__"): + variables = ", ".join( + f'"{k}"' for k in loop_frame_info.frame.f_locals.keys()) + log.warning(f"Variable(s) {variables} in reloaded loop were not " + "exported to the scope which called the reloaded loop " + f'initially defined at line {loop_frame_info.lineno} in ' + f'file "{filename}".') return [] -def get_decorator_name_or_none(dec_node): - if hasattr(dec_node, "id"): - return dec_node.id - elif hasattr(dec_node.func, "id"): - return dec_node.func.id - elif hasattr(dec_node.func.value, "id"): - return dec_node.func.value.id +def get_decorator_name_or_none(decorator_node): + if hasattr(decorator_node, "id"): + return decorator_node.id + elif hasattr(decorator_node, "attr"): + return decorator_node.attr + elif hasattr(decorator_node.func, "id"): + return decorator_node.func.id + elif hasattr(decorator_node.func.value, "id"): + return decorator_node.func.value.id else: return None -def strip_reloading_decorator(func): - """Remove the 'reloading' decorator and all decorators before it""" - decorator_names = [get_decorator_name(dec) for dec in func.decorator_list] - reloading_idx = decorator_names.index("reloading") - func.decorator_list = func.decorator_list[reloading_idx + 1:] - - -def isolate_function_def(funcname, tree): - """Strip everything but the function definition from the ast in-place. - Also strips the reloading decorator from the function definition""" - for node in ast.walk(tree): - if ( - isinstance(node, ast.FunctionDef) - and node.name == funcname - and "reloading" in [ - get_decorator_name_or_none(dec) - for dec in node.decorator_list - ] - ): - strip_reloading_decorator(node) - tree.body = [ node ] - return True - return False - - -def get_function_def_code(fpath, fn): - tree = parse_file_until_successful(fpath) - found = isolate_function_def(fn.__name__, tree) - if not found: - return None - compiled = compile(tree, filename="", mode="exec") - return compiled - - -def get_reloaded_function(caller_globals, caller_locals, fpath, fn): - code = get_function_def_code(fpath, fn) - if code is None: - return None - # need to copy locals, otherwise the exec will overwrite the decorated with the undecorated new version - # this became a need after removing the reloading decorator from the newly defined version - caller_locals_copy = caller_locals.copy() - exec(code, caller_globals, caller_locals_copy) - func = caller_locals_copy[fn.__name__] - return func - - -def _reloading_function(fn, every=1): - stack = inspect.stack() - frame, fpath = stack[2][:2] - caller_locals = frame.f_locals - caller_globals = frame.f_globals - - # crutch to use dict as python2 doesn't support nonlocal - state = { - "func": None, - "reloads": 0, - } - +def strip_reloading_decorator(function_with_decorator: ast.FunctionDef): + """ + Remove the 'reloading' decorator and all decorators before it. + """ + # Create shorthand for readability + fwod = function_with_decorator + # Find decorator names + decorator_names = [get_decorator_name_or_none(decorator) + for decorator + in fwod.decorator_list] + if "reloading" in decorator_names: + # Find index of "reloading" decorator + reloading_index = decorator_names.index("reloading") + fwod.decorator_list = fwod.decorator_list[reloading_index + 1:] + function_without_decorator = fwod + return function_without_decorator + + +# Source: https://stackoverflow.com/questions/34570992/ +class Parentage(ast.NodeTransformer): + # current parent (module) + parent = None + + def visit(self, node: Any): + # set parent attribute for this node + node.parent = self.parent + # This node becomes the new parent + self.parent = node + # Do any work required by super class + node = super().visit(node) + # If we have a valid node (ie. node not being removed) + if isinstance(node, ast.AST): + # update the parent, since this may have been transformed + # to a different node by super + self.parent = getattr(node, "parent") + return node + + +def get_node_id(node) -> str: + path = "" + while node.parent: + path = node.__class__.__name__+("." if path else "")+path + node = node.parent + return path + + +class Function: + def __init__(self, + function_name: str, + function_frame_info: inspect.FrameInfo, + ast_module: ast.Module, + filename: str, + node_id: str): + self.ast = ast_module + self.id = node_id + self.name = function_name + caller_locals = function_frame_info.frame.f_locals + caller_globals = function_frame_info.frame.f_globals + # Copy locals to avoid exec overwriting the decorated function with + # the new undecorated function. + caller_locals_copy = caller_locals.copy() + caller_globals_copy = caller_globals.copy() + # Variables that are local to the calling scope + # are global to the function. + caller_globals_copy.update(caller_locals_copy) + compiled_body = compile(ast_module, filename=filename, mode="exec") + exec(compiled_body, caller_globals_copy, caller_locals_copy) + self.function = caller_locals_copy[function_name] + + +def get_function_object(function_frame_info: inspect.FrameInfo, + function: Callable, + reloaded_file_ast: ast.Module, + filename: str, + function_id: Union[None, str] = None) -> Function: + """ + Traverse AST of the entire reloaded file in a search for the + function (minus the reloading decorator) which is reloaded. + """ + qualname = function.__qualname__ + function_name = qualname.split(".")[-1] + + candidate = None + Parentage().visit(reloaded_file_ast) + relevant_nodes = [] + for node in ast.walk(reloaded_file_ast): + if (isinstance(node, ast.FunctionDef) and + node.name == function.__name__): + relevant_nodes.append(node) + if function_id is None: + # If we don't have an ID, then it is because this is the + # first time we get the function object. In this case, we + # can assume that the function object and the AST are in sync. + # That is, if the line numbers match then it's all good. + if node.lineno == function.__code__.co_firstlineno: + candidate = node + break + # Okay, so the line numbers don't match exactly. This could be + # because of decorators. Check if the function is decorated + # for reloading and that the line numbers are plausible. + node_l = node.lineno + function_l = function.__code__.co_firstlineno + if all(["reloading" in [get_decorator_name_or_none(decorator) + for decorator in node.decorator_list], + node_l > function_l, + node_l - function_l <= len(node.decorator_list)]): + candidate = node + break + # If the node IDs match then its a sure thing. + if get_node_id(node) == function_id: + candidate = node + break + if candidate is None and len(relevant_nodes) == 1: + candidate = relevant_nodes[0] + elif candidate is None and len(relevant_nodes) > 1: + raise ReloadingException( + f'File "{filename}" contains ' + f'{len(relevant_nodes)} definitions of function ' + f'"{function_name}" and it is not possible ' + f'to determine which to reload.') + # Select the candidate node which is closest to function_frame_info + if not candidate: + raise ReloadingException( + f'Unable to reload function "{function_name}" ' + f'in file "{filename}". ' + 'The function might have been renamed or the ' + 'decorator might have been removed.' + ) + function_id = get_node_id(candidate) + function_node = strip_reloading_decorator(candidate) + function_node_ast = ast.Module([function_node], type_ignores=[]) + return Function(function.__name__, + function_frame_info, + function_node_ast, + filename, + function_id) + + +def get_reloaded_function(function_frame_info: inspect.FrameInfo, + function: Callable, + filename: str, + function_id: Union[None, str], + interactive_exception: bool) -> Function: + reloaded_file_ast: ast.Module = parse_file_until_successful( + filename, + interactive_exception + ) + return get_function_object(function_frame_info, + function, + reloaded_file_ast, + filename, + function_id) + + +def _reloading_function(function: Callable, + interactive_exception: bool) -> Callable: + stack: List[inspect.FrameInfo] = inspect.stack() + # The first element on the stack is the caller of inspect.stack() + # That is, this very function. + assert stack[0].function == "_reloading_function" + index = 2 + if stack[1].function == "reloading": + pass + elif stack[1].function == "wrap" and stack[2].function == "reloading": + index = 3 + # The third/fourth element or later is the function which called reloading. + # Assume it's the third. + function_frame_info: inspect.FrameInfo = stack[index] + # Look to see if theres a better frame in the stack. + for frame_info in stack[index:]: + names_global = set(frame_info.frame.f_globals.keys()) + names_local = set(frame_info.frame.f_locals.keys()) + variables = names_local | names_global + if all([function.__name__ in variables, + frame_info.filename == function.__code__.co_filename]): + function_frame_info = frame_info + filename: str = function.__code__.co_filename + # If we are running in Jupyter Notebook then the filename + # of the current notebook is stored in the __session__ variable. + if ".ipynb" in function_frame_info.frame.f_globals.get("__session__", ""): + filename = str(function_frame_info.frame.f_globals.get("__session__")) + + file_stat: int = os.stat(filename).st_mtime_ns + function_object = get_reloaded_function(function_frame_info, + function, + filename, + None, + interactive_exception) + + @functools.wraps(function) def wrapped(*args, **kwargs): - if state["reloads"] % every == 0: - state["func"] = get_reloaded_function(caller_globals, caller_locals, fpath, fn) or state["func"] - state["reloads"] += 1 + nonlocal file_stat, function_object + # Reload code if possibly modified + file_stat_: int = os.stat(filename).st_mtime_ns + if file_stat != file_stat_: + log.info(f'Function "{function.__name__}" initially defined at ' + f'line {function_frame_info.lineno} ' + f'of file "{filename}" has been reloaded.') + ast_before = function_object.ast + function_object = get_reloaded_function(function_frame_info, + function, + filename, + function_object.id, + interactive_exception) + ast_after = function_object.ast + log.debug(get_diff_text(ast_before, ast_after)) + file_stat = file_stat_ while True: try: - result = state["func"](*args, **kwargs) - return result + return function_object.function(*args, **kwargs) except Exception: - handle_exception(fpath) - state["func"] = get_reloaded_function(caller_globals, caller_locals, fpath, fn) or state["func"] + handle_exception(filename, interactive_exception) - caller_locals[fn.__name__] = wrapped return wrapped diff --git a/reloading/test_reloading.py b/reloading/test_reloading.py index e541114..162cddc 100644 --- a/reloading/test_reloading.py +++ b/reloading/test_reloading.py @@ -1,22 +1,508 @@ import unittest +import sys import os -import subprocess as sp +import subprocess import time +import importlib from reloading import reloading SRC_FILE_NAME = "temporary_testing_file.py" -TEST_CHANGING_SOURCE_LOOP_CONTENT = """ + +def run_and_update_source(init_src, updated_src=None, update_after=0.5): + """Runs init_src in a subprocess and updates source to updated_src after + update_after seconds. Returns the standard output of the subprocess and + whether the subprocess produced an uncaught exception. + """ + with open(SRC_FILE_NAME, "w", encoding="utf-8") as f: + f.write(init_src) + + cmd = ["python", SRC_FILE_NAME] + with subprocess.Popen(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) as proc: + if updated_src is not None: + time.sleep(update_after) + with open(SRC_FILE_NAME, "w", encoding="utf-8") as f: + f.write(updated_src) + f.flush() + + try: + stdout_bytes, _ = proc.communicate(timeout=2) + stdout = stdout_bytes.decode("utf-8") + has_error = False + except Exception: + stdout = "" + has_error = True + proc.terminate() + + if os.path.isfile(SRC_FILE_NAME): + os.remove(SRC_FILE_NAME) + + return stdout, has_error + + +class TestReloadingForLoopWithoutChanges(unittest.TestCase): + def test_no_argument(self): + i = 0 + for _ in reloading(): # type: ignore + i += 1 + if i > 10: + break + + if sys.version_info.major >= 3 and sys.version_info.minor >= 13: + self.assertEqual(i, 11) + + def test_range_pass(self): + for _ in reloading(range(10)): + pass + + def test_range_pass_not_interactive(self): + for _ in reloading(range(10), interactive_exception=False): + pass + + def test_module_import(self): + import reloading + for _ in reloading.reloading(range(10)): + pass + + def test_range_body(self): + i = 0 + for _ in reloading(range(10)): + i += 1 + + if sys.version_info.major >= 3 and sys.version_info.minor >= 13: + self.assertEqual(i, 10) + + def test_complex_iteration_variables(self): + i = 0 + j = 0 + for j, (a, b) in reloading(enumerate(zip(range(10), range(10)))): + i += 1 + + if sys.version_info.major >= 3 and sys.version_info.minor >= 13: + self.assertEqual(i, 10) + self.assertEqual(j, 9) + + def test_empty_iterator(self): + i = 0 + for _ in reloading([]): + i += 1 + + if sys.version_info.major >= 3 and sys.version_info.minor >= 13: + self.assertEqual(i, 0) + + def test_use_outside_loop(self): + with self.assertRaises(Exception): + reloading(range(10)) + + def test_continue(self): + i = 0 + j = 0 + for i in range(10): + i += 1 + if i > 5: + continue + j = i + + self.assertEqual(i, 10) + self.assertEqual(j, 5) + + def test_iterator_in_library(self): + with open("temporary_library.py", "w") as f: + f.write(""" +def iterator(): + return range(10) +""") + import temporary_library # type: ignore + importlib.reload(temporary_library) + + i = 0 + for i in reloading(temporary_library.iterator()): + pass + + if sys.version_info.major >= 3 and sys.version_info.minor >= 13: + self.assertEqual(i, 9) + if os.path.isfile("temporary_library.py"): + os.remove("temporary_library.py") + + +class TestReloadingWhileLoopWithoutChanges(unittest.TestCase): + def test_no_argument(self): + i = 0 + while reloading(): + i += 1 + if i == 10: + break + + if sys.version_info.major >= 3 and sys.version_info.minor >= 13: + self.assertEqual(i, 10) + + def test_not_interactive(self): + i = 0 + while reloading(i < 10, interactive_exception=False): + i += 1 + + def test_false(self): + i = 0 + while reloading(False): + i += 1 + + if sys.version_info.major >= 3 and sys.version_info.minor >= 13: + self.assertEqual(i, 0) + + def test_true_break(self): + i = 0 + while reloading(True): + i += 1 + if i > 9: + break + + if sys.version_info.major >= 3 and sys.version_info.minor >= 13: + self.assertEqual(i, 10) + + def test_condition_changes(self): + i = 0 + + def condition(): # type: ignore + return True + + while reloading(condition()): + i += 1 + if i > 9: + def condition(): + return False + + if sys.version_info.major >= 3 and sys.version_info.minor >= 13: + self.assertEqual(i, 10) + + def test_condition(self): + i = 0 + while reloading(i < 10): + i += 1 + + if sys.version_info.major >= 3 and sys.version_info.minor >= 13: + self.assertEqual(i, 10) + + def test_use_outside_loop(self): + with self.assertRaises(Exception): + reloading(True) + + def test_continue(self): + i = 0 + j = 0 + while reloading(i < 10): + i += 1 + if i > 5: + continue + j = i + + if sys.version_info.major >= 3 and sys.version_info.minor >= 13: + self.assertEqual(i, 10) + self.assertEqual(j, 5) + + def test_module_import(self): + import reloading + i = 0 + while reloading.reloading(i < 10): + i += 1 + + def test_condition_in_library(self): + with open("temporary_library.py", "w") as f: + f.write(""" +def condition(x): + return x < 10 +""") + import temporary_library # type: ignore + importlib.reload(temporary_library) + + i = 0 + while reloading(temporary_library.condition(i)): + i += 1 + + if sys.version_info.major >= 3 and sys.version_info.minor >= 13: + self.assertEqual(i, 10) + if os.path.isfile("temporary_library.py"): + os.remove("temporary_library.py") + + +class TestReloadingFunctionWithoutChanges(unittest.TestCase): + def test_empty_function_definition(self): + @reloading + def function(): + pass + + def test_empty_function_wrapped(self): + def function(): + pass + + function = reloading(function) + + def test_empty_function_wrapped_not_interactive(self): + def function(): + pass + + function = reloading(function, interactive_exception=False) + + def test_empty_function_run(self): + @reloading + def function(): + pass + + function() + + def test_module_import(self): + import reloading + + @reloading.reloading + def function(): + pass + + function() + + def test_function_return_value(self): + @reloading + def function(): + return "1" + + self.assertEqual(function(), "1") + + def test_function_return_value_wrapped(self): + def function(): + return "2" + + function = reloading(function) + self.assertEqual(function(), "2") + + def test_nested_function(self): + def outer(): + @reloading + def inner(): + return "3" + return inner() + + self.assertEqual(outer(), "3") + + def test_function_signature_is_preserved(self): + @reloading + def some_func(a, b, c): + return "4" + + import inspect + self.assertEqual(str(inspect.signature(some_func)), "(a, b, c)") + self.assertEqual(some_func(1, 2, 3), "4") + + def test_decorated_function(self): + def decorator(f): + def wrap(): + return "<"+str(f())+">" + return wrap + + @decorator + @reloading + def f(): + return 6 + + self.assertEqual(f(), "<6>") + + def test_function_in_library(self): + with open("temporary_library.py", "w") as f: + f.write(""" +def function_not_marked(x): + return x +""") + import temporary_library # type: ignore + importlib.reload(temporary_library) + self.assertEqual(temporary_library.function_not_marked("7"), "7") + g = reloading(temporary_library.function_not_marked) + self.assertEqual(g("8"), "8") + if os.path.isfile("temporary_library.py"): + os.remove("temporary_library.py") + + def test_reloading_function_in_library(self): + with open("temporary_library.py", "w") as f: + f.write(""" +from reloading import reloading + +@reloading +def function_marked(x): + return x +""") + import temporary_library # type: ignore + importlib.reload(temporary_library) + self.assertEqual(temporary_library.function_marked("9"), "9") + if os.path.isfile("temporary_library.py"): + os.remove("temporary_library.py") + + def test_deep_call_stack_make_locals_globals(self): + a = 2 + + def f(x): + b = 2 + return x*a*b + + def g(): + return reloading(f) + + function = g() + result = function("9") + + self.assertEqual(result, "9999") + + def test_deep_call_stack_prioritise_locals(self): + a = 2 + # Now flake8 does not complain about unused a + print(a) + + def f(x): + a = 3 + b = 2 + return x*a*b + + def g(): + return reloading(f) + + function = g() + result = function("9") + + self.assertEqual(result, "999999") + + +class TestReloadingFunctionDecoratorArgumentsWithoutChanges(unittest.TestCase): + def test_empty_function_definition(self): + @reloading(interactive_exception=False) + def function1(): + pass + + @reloading() + def function2(): + pass + + def test_empty_function_run(self): + @reloading(interactive_exception=False) + def function(): + pass + + function() + + def test_module_import(self): + import reloading + + @reloading.reloading(interactive_exception=False) + def function(): + pass + + function() + + def test_function_return_value(self): + @reloading(interactive_exception=False) + def function(): + return "1" + + self.assertEqual(function(), "1") + + def test_nested_function(self): + def outer(): + @reloading(interactive_exception=False) + def inner(): + return "3" + return inner() + + self.assertEqual(outer(), "3") + + def test_function_signature_is_preserved(self): + @reloading(interactive_exception=False) + def some_func(a, b, c): + return "4" + + import inspect + self.assertEqual(str(inspect.signature(some_func)), "(a, b, c)") + self.assertEqual(some_func(1, 2, 3), "4") + + def test_decorated_function(self): + def decorator(f): + def wrap(): + return "<"+str(f())+">" + return wrap + + @decorator + @reloading(interactive_exception=False) + def f(): + return 6 + + self.assertEqual(f(), "<6>") + + def test_reloading_function_in_library(self): + with open("temporary_library.py", "w") as f: + f.write(""" +from reloading import reloading + +@reloading(interactive_exception=False) +def function_marked(x): + return x +""") + import temporary_library # type: ignore + importlib.reload(temporary_library) + self.assertEqual(temporary_library.function_marked("9"), "9") + if os.path.isfile("temporary_library.py"): + os.remove("temporary_library.py") + + def test_deep_call_stack_make_locals_globals(self): + a = 2 + + @reloading(interactive_exception=False) + def f(x): + b = 2 + return x*a*b + + def g(): + return f + + function = g() + result = function("9") + + self.assertEqual(result, "9999") + + def test_deep_call_stack_prioritise_locals(self): + a = 2 + # Now flake8 does not complain about unused a + print(a) + + @reloading(interactive_exception=False) + def f(x): + a = 3 + b = 2 + return x*a*b + + def g(): + return f + + function = g() + result = function("9") + + self.assertEqual(result, "999999") + + +class TestReloadingForLoopWithChanges(unittest.TestCase): + def test_changing_source_loop(self): + code = """ from reloading import reloading from time import sleep -for epoch in reloading(range(10)): +for epoch1, epoch2 in reloading(zip(range(10), range(1,11))): sleep(0.2) print('INITIAL_FILE_CONTENTS') """ + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("INITIAL", "CHANGED"), + ) + self.assertIn("INITIAL_FILE_CONTENTS", stdout) + self.assertIn("CHANGED_FILE_CONTENTS", stdout) -TEST_CHANGING_LINE_NUMBER_OF_LOOP = """ + def test_changing_line_number_of_loop(self): + code = """ from reloading import reloading from time import sleep @@ -26,30 +512,35 @@ sleep(0.2) print('INITIAL_FILE_CONTENTS') """ + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("pass", "pass\npass\n"). + replace("INITIAL", "CHANGED"), + ) + self.assertIn("INITIAL_FILE_CONTENTS", stdout) + self.assertIn("CHANGED_FILE_CONTENTS", stdout) -TEST_CHANGING_SOURCE_FN_CONTENT = """ + def test_keep_local_variables(self): + code = """ from reloading import reloading from time import sleep -@reloading -def reload_this_fn(): - print('INITIAL_FILE_CONTENTS') - +text = "DON'T CHANGE ME" for epoch in reloading(range(10)): sleep(0.2) - reload_this_fn() -""" - -TEST_KEEP_LOCAL_VARIABLES_CONTENT = """ -from reloading import reloading -from time import sleep - -fpath = "DON'T CHANGE ME" -for epoch in reloading(range(1)): - assert fpath == "DON'T CHANGE ME" + print('INITIAL_FILE_CONTENTS') + assert text == "DON'T CHANGE ME" """ + stdout, has_error = run_and_update_source( + init_src=code, + updated_src=code.replace("INITIAL", "CHANGED") + ) + self.assertFalse(has_error) + self.assertIn("INITIAL_FILE_CONTENTS", stdout) + self.assertIn("CHANGED_FILE_CONTENTS", stdout) -TEST_PERSIST_AFTER_LOOP = """ + def test_persist_after_loop(self): + code = """ from reloading import reloading from time import sleep @@ -59,8 +550,11 @@ def reload_this_fn(): assert state == 'CHANGED' """ + _, has_error = run_and_update_source(init_src=code) + self.assertFalse(has_error) -TEST_COMMENT_AFTER_LOOP_CONTENT = """ + def test_comment_after_loop(self): + code = """ from reloading import reloading from time import sleep @@ -70,8 +564,16 @@ def reload_this_fn(): # a comment here should not cause an error """ + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("INITIAL", "CHANGED"), + ) -TEST_FORMAT_STR_IN_LOOP_CONTENT = """ + self.assertIn("INITIAL_FILE_CONTENTS", stdout) + self.assertIn("CHANGED_FILE_CONTENTS", stdout) + + def test_format_str_in_loop(self): + code = """ from reloading import reloading from time import sleep @@ -80,8 +582,142 @@ def reload_this_fn(): file_contents = 'FILE_CONTENTS' print(f'INITIAL_{file_contents}') """ + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("INITIAL", "CHANGED").rstrip("\n"), + ) + + self.assertIn("INITIAL_FILE_CONTENTS", stdout) + self.assertIn("CHANGED_FILE_CONTENTS", stdout) + + def test_unicode(self): + code = """ +from reloading import reloading +from time import sleep + +for epoch, ep in reloading(zip(range(10), range(1,11))): + sleep(0.2) + print('INITIAL_FILE_CONTENTS'+'😊') +""" + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("INITIAL", "CHANGED"), + ) + + self.assertIn("INITIAL_FILE_CONTENTS", stdout) + self.assertIn("CHANGED_FILE_CONTENTS", stdout) + + def test_nested_loop(self): + code = """ +from reloading import reloading +from time import sleep + +for i in range(1): + static = 'A' + for j in reloading(range(10, 20)): + dynamic = 'B' + print(static+dynamic) + sleep(0.2) +""" + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("dynamic = 'B'", "dynamic = 'C'"). + replace("static = 'A'", "static = 'D'"), + ) + + self.assertIn("AB", stdout) + self.assertIn("AC", stdout) + + def test_function_in_loop(self): + code = """ +from reloading import reloading +from time import sleep + +def f(): + return 'f' + +for i in reloading(range(10)): + print(f()+'g') + sleep(0.2) +""" + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("'f'", "'F'"). + replace("'g'", "'G'"), + ) + + self.assertIn("fg", stdout) + self.assertIn("fG", stdout) + self.assertNotIn("FG", stdout) + + +class TestReloadingWhileLoopWithChanges(unittest.TestCase): + def test_changing_source_loop(self): + code = """ +from reloading import reloading +from time import sleep + +i = 0 +while reloading(i < 100): + sleep(0.2) + print(i) + i += 1 +""" + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("i < 100", "i < 10"), + ) + max_i = stdout.strip().split("\n")[-1] + self.assertEqual(max_i, "9") + + def test_function_in_loop(self): + code = """ +from reloading import reloading +from time import sleep + +def f(): + return 'f' + +i = 0 +while reloading(i<10): + print(f()+'g') + sleep(0.2) + i += 1 +""" + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("'f'", "'F'"). + replace("'g'", "'G'"), + ) + + self.assertIn("fg", stdout) + self.assertIn("fG", stdout) + self.assertNotIn("FG", stdout) + + +class TestReloadingFunctionsWithChanges(unittest.TestCase): + def test_changing_source_function(self): + code = """ +from reloading import reloading +from time import sleep + +@reloading +def reload_this_fn(): + print('INITIAL_FILE_CONTENTS') + +for epoch in reloading(range(10)): + sleep(0.2) + reload_this_fn() +""" + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("INITIAL", "CHANGED"), + ) + self.assertIn("INITIAL_FILE_CONTENTS", stdout) + self.assertIn("CHANGED_FILE_CONTENTS", stdout) -TEST_FUNCTION_AFTER = """ + def test_reloading_function(self): + code = """ from reloading import reloading from time import sleep @@ -93,124 +729,160 @@ def some_func(a, b): for _ in range(10): some_func(2,1) """ + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("a+b", "a-b"), + ) + self.assertIn("3", stdout) + self.assertIn("1", stdout) -def run_and_update_source(init_src, updated_src=None, update_after=0.5, bin="python3"): - """Runs init_src in a subprocess and updates source to updated_src after - update_after seconds. Returns the standard output of the subprocess and - whether the subprocess produced an uncaught exception. - """ - with open(SRC_FILE_NAME, "w") as f: - f.write(init_src) + def test_nested_function(self): + code = """ +from reloading import reloading +from time import sleep - cmd = [bin, SRC_FILE_NAME] - with sp.Popen(cmd, stdout=sp.PIPE, stderr=sp.PIPE) as proc: - if updated_src is not None: - time.sleep(update_after) - with open(SRC_FILE_NAME, "w") as f: - f.write(updated_src) +def outer(): + static = 'A' + @reloading + def inner(x): + dynamic = 'B' + return x + dynamic + return inner(static) - try: - stdout, _ = proc.communicate(timeout=2) - stdout = stdout.decode("utf-8") - has_error = False - except: - stdout = "" - has_error = True - proc.terminate() +for i in range(10): + print(outer()) + sleep(0.2) +""" + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("dynamic = 'B'", "dynamic = 'C'"). + replace("static = 'D'", "static = 'D'"), + ) - if os.path.isfile(SRC_FILE_NAME): - os.remove(SRC_FILE_NAME) + self.assertIn("AB", stdout) + self.assertIn("AC", stdout) - return stdout, has_error + def test_multiple_function(self): + code = """ +from reloading import reloading +from time import sleep +@reloading +def f(): + return 'f' -class TestReloading(unittest.TestCase): - def test_simple_looping(self): - iters = 0 - for _ in reloading(range(10)): - iters += 1 +@reloading +def g(): + return 'g' - def test_changing_source_loop(self): - for bin in ["python", "python3"]: - stdout, _ = run_and_update_source( - init_src=TEST_CHANGING_SOURCE_LOOP_CONTENT, - updated_src=TEST_CHANGING_SOURCE_LOOP_CONTENT.replace("INITIAL", "CHANGED").rstrip("\n"), - bin=bin, - ) +for i in range(10): + print(f()+g()) + sleep(0.2) +""" + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("'f'", "'F'"). + replace("'g'", "'G'"), + ) - self.assertTrue("INITIAL_FILE_CONTENTS" in stdout and "CHANGED_FILE_CONTENTS" in stdout) + self.assertIn("fg", stdout) + self.assertIn("FG", stdout) - def test_changing_line_number_of_loop(self): - for bin in ["python", "python3"]: - stdout, _ = run_and_update_source( - init_src=TEST_CHANGING_LINE_NUMBER_OF_LOOP, - updated_src=( - TEST_CHANGING_LINE_NUMBER_OF_LOOP - .replace("pass", "pass\npass\n") - .replace("INITIAL", "CHANGED") - .rstrip("\n") - ), - bin=bin, - ) - - self.assertTrue("INITIAL_FILE_CONTENTS" in stdout and "CHANGED_FILE_CONTENTS" in stdout) + def test_multiple_functions_not_decorated(self): + code = """ +from reloading import reloading +from time import sleep - def test_comment_after_loop(self): - for bin in ["python", "python3"]: - stdout, _ = run_and_update_source( - init_src=TEST_COMMENT_AFTER_LOOP_CONTENT, - updated_src=TEST_COMMENT_AFTER_LOOP_CONTENT.replace("INITIAL", "CHANGED").rstrip("\n"), - bin=bin, - ) +def f(): + return 'f' - self.assertTrue("INITIAL_FILE_CONTENTS" in stdout and "CHANGED_FILE_CONTENTS" in stdout) +def g(): + return 'g' - def test_format_str_in_loop(self): +f = reloading(f) +g = reloading(g) + +for i in range(10): + print(f()+g()) + sleep(0.2) +""" stdout, _ = run_and_update_source( - init_src=TEST_FORMAT_STR_IN_LOOP_CONTENT, - updated_src=TEST_FORMAT_STR_IN_LOOP_CONTENT.replace("INITIAL", "CHANGED").rstrip("\n"), - bin="python3", + init_src=code, + updated_src=code.replace("'f'", "'F'"). + replace("'g'", "'G'"), ) - self.assertTrue("INITIAL_FILE_CONTENTS" in stdout and "CHANGED_FILE_CONTENTS" in stdout) + self.assertIn("fg", stdout) + self.assertIn("FG", stdout) - def test_keep_local_variables(self): - for bin in ["python", "python3"]: - _, has_error = run_and_update_source(init_src=TEST_KEEP_LOCAL_VARIABLES_CONTENT, bin=bin) - self.assertFalse(has_error) + def test_class_decorates_methods(self): + code = """ +from reloading import reloading +from time import sleep - def test_persist_after_loop(self): - for bin in ["python", "python3"]: - _, has_error = run_and_update_source(init_src=TEST_PERSIST_AFTER_LOOP, bin=bin) - self.assertFalse(has_error) +def get_subclass_methods(cls): + methods = set(dir(cls(_get_subclass_methods=True))) + unique_methods = methods.difference( + *(dir(base()) for base in cls.__bases__) + ) + return list(unique_methods) + +class ClassWhichMarksSubclassMethodsForReload: + def __init__(self, *args, **kwargs): + if (self.__class__.__name__ != super().__thisclass__.__name__ + and not '_get_subclass_methods' in kwargs): + methods_of_subclass = get_subclass_methods(self.__class__) + for method in methods_of_subclass: + setattr(self.__class__, method, + reloading(getattr(self.__class__, method))) + def f(self): + return 'f' + +class Subclass(ClassWhichMarksSubclassMethodsForReload): + def g(self): + return 'g' + +obj = Subclass() + +for i in range(10): + print(obj.f()+obj.g()) + sleep(0.2) +""" + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("'f'", "'F'"). + replace("'g'", "'G'"), + ) - def test_simple_function(self): - @reloading - def some_func(): - return "result" + self.assertIn("fg", stdout) + self.assertIn("fG", stdout) + self.assertNotIn("FG", stdout) - self.assertTrue(some_func() == "result") + def test_function_while_loop(self): + code = """ +from reloading import reloading +from time import sleep - def test_reloading_function(self): - for bin in ["python", "python3"]: - stdout, _ = run_and_update_source( - init_src=TEST_FUNCTION_AFTER, - updated_src=TEST_FUNCTION_AFTER.replace("a+b", "a-b"), - bin=bin, - ) - self.assertTrue("3" in stdout and "1" in stdout) +@reloading +def f(): + return 'f' - def test_changing_source_function(self): - for bin in ["python", "python3"]: - stdout, _ = run_and_update_source( - init_src=TEST_CHANGING_SOURCE_FN_CONTENT, - updated_src=TEST_CHANGING_SOURCE_FN_CONTENT.replace("INITIAL", "CHANGED").rstrip("\n"), - bin=bin, - ) +i = 0 +while reloading(i<10): + print(f()+'g') + sleep(0.2) + i += 1 +""" + stdout, _ = run_and_update_source( + init_src=code, + updated_src=code.replace("'f'", "'F'"). + replace("'g'", "'G'"), + ) - self.assertTrue("INITIAL_FILE_CONTENTS" in stdout and "CHANGED_FILE_CONTENTS" in stdout) + self.assertIn("fg", stdout) + self.assertIn("FG", stdout) if __name__ == "__main__": - unittest.main() + unittest.main(verbosity=2) diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index b88034e..0000000 --- a/setup.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[metadata] -description-file = README.md diff --git a/setup.py b/setup.py deleted file mode 100644 index 74f5255..0000000 --- a/setup.py +++ /dev/null @@ -1,35 +0,0 @@ -from setuptools import setup -from os import path - -this_directory = path.abspath(path.dirname(__file__)) -with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: - long_description = f.read() - -setup( - name='reloading', - packages=['reloading'], - version='1.1.2', - license='MIT', - description='Reloads source code of a running program without losing state', - long_description=long_description, - long_description_content_type='text/markdown', - author='Julian Vossen', - author_email='pypi@julianvossen.de', - url='https://github.com/julvo/reloading', - download_url='https://github.com/julvo/reloading/archive/v1.1.2.tar.gz', - keywords=['reload', 'reloading', 'refresh', 'loop', 'decorator'], - install_requires=[], - classifiers=[ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Topic :: Utilities', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - ], -)