diff --git a/denoise/__init__.py b/denoise/__init__.py
new file mode 100644
index 0000000..1b2b1c5
--- /dev/null
+++ b/denoise/__init__.py
@@ -0,0 +1 @@
+from .models import UNet, UNetModel
\ No newline at end of file
diff --git a/denoise/models/__init__.py b/denoise/models/__init__.py
new file mode 100644
index 0000000..e041c9b
--- /dev/null
+++ b/denoise/models/__init__.py
@@ -0,0 +1,15 @@
+from .torchscript import TorchScriptModel
+from .unet import UNet, UNetModel
+
+def get_model(model_name):
+    if model_name == 'unet':
+        return UNetModel
+    elif model_name == 'torchscript':
+        return TorchScriptModel
+    else:
+        raise NotImplementedError
+
+
+# TODO: Add a model that first does colour based classical denoising
+# and then does neural denoising (now only intensity based, not colour based)
+# could even be applied in greyscale
\ No newline at end of file
diff --git a/denoise/models/torchscript.py b/denoise/models/torchscript.py
new file mode 100644
index 0000000..7659e6d
--- /dev/null
+++ b/denoise/models/torchscript.py
@@ -0,0 +1,15 @@
+import torch
+
+class TorchScriptModel(torch.nn.Module):
+    def __init__(self, model_path):
+        super().__init__()
+        self.net = torch.jit.load(model_path)
+
+    @classmethod
+    def load(cls, model_path):
+        model = cls(model_path)
+        model.eval()
+        return model
+
+    def forward(self, x):
+        return self.net(x)
\ No newline at end of file
diff --git a/models/unet.py b/denoise/models/unet.py
similarity index 90%
rename from models/unet.py
rename to denoise/models/unet.py
index 25abce4..5856dc3 100644
--- a/models/unet.py
+++ b/denoise/models/unet.py
@@ -107,21 +107,35 @@ def _block(in_channels, features, name):
 
 class UNetModel(pl.LightningModule):
-    def __init__(self, model):
+    def __init__(self, checkpoint=None, eval_mode=False, **load_kwargs):
+        # checkpoint: path to checkpoint file (lightning)
         super().__init__()
-        self.model = model
+        self.net = UNet(out_channels=3)
         self.loss_fn = nn.MSELoss()
+
+        if eval_mode:
+            self.net.eval()
+
+        if checkpoint:
+            self.load_from_checkpoint(checkpoint, **load_kwargs)
+
+    @classmethod
+    def load(cls, checkpoint, **load_kwargs):
+        net = UNet(out_channels=3)
+        model = cls.load_from_checkpoint(checkpoint, net=net, **load_kwargs)
+
+        return model
+
     def training_step(self, batch, batch_idx):
         x, y = batch
-        y_hat = self.model(x)
+        y_hat = self.net(x)
         loss = self.loss_fn(y_hat, y)
         self.log("train_loss", loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
         x, y = batch
-        y_hat = self.model(x)
+        y_hat = self.net(x)
         loss = self.loss_fn(y_hat, y)
         self.log("val_loss", loss)
diff --git a/denoise/noise/__init__.py b/denoise/noise/__init__.py
new file mode 100644
index 0000000..5b67476
--- /dev/null
+++ b/denoise/noise/__init__.py
@@ -0,0 +1 @@
+from .noise import *
\ No newline at end of file
diff --git a/models/noise.py b/denoise/noise/noise.py
similarity index 100%
rename from models/noise.py
rename to denoise/noise/noise.py
diff --git a/scripts/__init__.py b/scripts/__init__.py
deleted file mode 100644
index e69de29..0000000
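For orientation, here is roughly how the new `get_model` registry and the two wrapper classes above compose. This is a sketch, not part of the patch, and the paths are placeholders. One caveat: Lightning's `load_from_checkpoint` is a classmethod that returns a new instance, so the `self.load_from_checkpoint(...)` call inside `UNetModel.__init__` will not populate `self.net` in place; `UNetModel.load` is the path that actually restores weights.

```python
# Illustrative loading flow (paths are placeholders, not from this repo).
from denoise.models import get_model

# TorchScript route: the registry returns the class; .load() wraps
# torch.jit.load and switches the module to eval mode.
arch = get_model('torchscript')
model = arch.load(model_path='pretrained-models/example.pt')

# Lightning route: UNetModel.load builds a fresh UNet and restores it
# through the load_from_checkpoint classmethod.
# arch = get_model('unet')
# model = arch.load(checkpoint='lightning_logs/example.ckpt', map_location='cpu')
```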
diff --git a/scripts/infer.py b/scripts/infer.py
index 1e1d3e3..aaa38b1 100644
--- a/scripts/infer.py
+++ b/scripts/infer.py
@@ -1,7 +1,8 @@
 import torch
-from models.unet import UNet, UNetModel
+from denoise.models.unet import UNet, UNetModel
+from denoise.models import get_model
 import torchvision as tv
-from models.noise import ShotNoise
+from denoise.noise import ShotNoise, GaussianNoise
 from PIL import Image
 from kornia.contrib import extract_tensor_patches, combine_tensor_patches
 import os
@@ -34,7 +35,7 @@ def infer_model(img, model, crop_size):
     # NOTE: Only compatible with a single image for now
     results = []
     for patch in patches: # use batches?
-        result = model.model(patch.unsqueeze(0))
+        result = model.net(patch.unsqueeze(0))
         results.append(
             result
         )
@@ -50,8 +51,8 @@ def infer(img):
     unet = UNet(out_channels=3)
     model = UNetModel.load_from_checkpoint(
-        'colab_lightning_logs/version_0/checkpoints/epoch=115-step=1856.ckpt',
-        model=unet,
+        'colab_lightning_logs/version_0/checkpoints/netKEY-epoch=115-step=1856.ckpt',
+        net=unet,
         map_location=torch.device('cpu')
     )
@@ -70,7 +71,7 @@ def infer(img):
     results = []
     for patch in patches: # use batches?
         results.append(
-            model.model(patch.unsqueeze(0))
+            model.net(patch.unsqueeze(0))
             # TODO: Disable batchnorm (needs 4D input)
         )
     results = torch.stack(results, dim=1)
@@ -83,45 +84,45 @@ def main():
     crop_size = (256, 256)
     output_dir = 'infer-results'
     os.makedirs(output_dir, exist_ok=True)
-    unet = UNet(out_channels=3)
-    model = UNetModel.load_from_checkpoint(
-        'colab_lightning_logs/version_0/checkpoints/epoch=115-step=1856.ckpt',
-        model=unet,
-        map_location=torch.device('cpu')
-    )
-
-    model.eval()
-
-    test_image_path = 'data/div2k/subset/val/0804.png'
-
-    tensor_transform = tv.transforms.ToTensor()
-    noise_transform = ShotNoise(sensitivity=-4.8, sensitivity_sigma=1.2)
-
-    img = Image.open(test_image_path)
-    image = tensor_transform(img)
-    if apply_noise:
-        image = noise_transform(image)
-        tv.utils.save_image(image, os.path.join(output_dir, 'noisy.png'))
-
-    image = image.unsqueeze(0)
-    patches = extract_tensor_patches(image, crop_size, stride=256, allow_auto_padding=True)
-
-    original_size = img.size[::-1] # this is with padding
-    patches = patches.squeeze(0)
-    # NOTE: Only compatible with a single image for now
-    results = []
-    for patch in patches: # use batches?
-        results.append(
-            model.model(patch.unsqueeze(0))
-            # TODO: Disable batchnorm (needs 4D input)
-        )
-    results = torch.stack(results, dim=1)
-    print(f"Results shape: {results.shape}")
-    result = combine_tensor_patches(results, original_size=original_size, window_size=crop_size, stride=crop_size[0], allow_auto_unpadding=True)
-
-    print(result.shape)
-
-    tv.utils.save_image(result, os.path.join(output_dir, 'result.png'))
+    model_type = 'torchscript'
+    model_path = 'pretrained-models/NAFNet/NAFNet.pt'
+    # model_load_args = {'map_location': torch.device('cpu')} # TODO: Add short yaml files to get these per-model load settings
+
+    arch = get_model(model_type)
+    model = arch.load(model_path=model_path)
+    # model = arch.load(checkpoint='colab_lightning_logs/version_0/checkpoints/netKEY-epoch=115-step=1856.ckpt', eval_mode=True, **model_load_args)
+    with torch.no_grad():
+        test_image_path = 'data/div2k/subset/val/0805.png'
+
+        tensor_transform = tv.transforms.ToTensor()
+        # noise_transform = ShotNoise(sensitivity=-4.8, sensitivity_sigma=1.2)
+        noise_transform = GaussianNoise(rgb_variance=(0.005,0.005,0.005))
+
+        img = Image.open(test_image_path)
+        image = tensor_transform(img)
+        if apply_noise:
+            image = noise_transform(image)
+            tv.utils.save_image(image, os.path.join(output_dir, 'noisy.png'))
+
+        image = image.unsqueeze(0)
+        patches = extract_tensor_patches(image, crop_size, stride=256, allow_auto_padding=True)
+
+        original_size = img.size[::-1] # this is with padding
+        patches = patches.squeeze(0)
+        # NOTE: Only compatible with a single image for now
+        results = []
+        for patch in patches: # use batches?
+            results.append(
+                model.net(patch.unsqueeze(0))
+                # TODO: Disable batchnorm (needs 4D input)
+            )
+        results = torch.stack(results, dim=1)
+        print(f"Results shape: {results.shape}")
+        result = combine_tensor_patches(results, original_size=original_size, window_size=crop_size, stride=crop_size[0], allow_auto_unpadding=True)
+
+        print(result.shape)
+
+        tv.utils.save_image(result, os.path.join(output_dir, 'result.png'))
 
 if __name__ == "__main__":
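Distilled from the rewritten `main()` above, the tile-and-stitch flow reads roughly as below. `denoise_tiled` is a hypothetical helper, not part of the patch; it assumes a model exposing `.net`, as both wrappers do.

```python
# Hypothetical helper distilling the tiled-inference pattern from main().
import torch
from kornia.contrib import extract_tensor_patches, combine_tensor_patches

def denoise_tiled(model, image, crop=256):
    """image: (1, C, H, W) tensor in [0, 1]; returns the stitched result."""
    # Auto-pad and cut the image into non-overlapping crop x crop tiles.
    patches = extract_tensor_patches(image, (crop, crop), stride=crop, allow_auto_padding=True)
    with torch.no_grad():
        # patches: (1, N, C, h, w); run tiles one at a time (batching is the obvious next step).
        outs = [model.net(p.unsqueeze(0)) for p in patches.squeeze(0)]
    outs = torch.stack(outs, dim=1)
    # Stitch the tiles back together and strip the auto-padding.
    return combine_tensor_patches(outs, original_size=image.shape[-2:],
                                  window_size=(crop, crop), stride=crop,
                                  allow_auto_unpadding=True)
```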
diff --git a/scripts/train.py b/scripts/train.py
index 3941fb4..44a0f04 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -56,7 +56,7 @@ def main():
     logger = pl.loggers.TensorBoardLogger("lightning_logs")
 
     unet = UNet(out_channels=3)
-    model = UNetModel(model=unet)
+    model = UNetModel(net=unet)
 
     input_transforms = tv.transforms.Compose(
         [
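The matching training-side construction under the new signature, as a sketch (the `Trainer` arguments and dataloaders are placeholders). Note that `UNetModel.__init__` no longer declares a `net` parameter, so `net=unet` is swallowed by `**load_kwargs` and the wrapper trains the `UNet` it builds internally.

```python
# Illustrative training setup; max_epochs and dataloaders are placeholders.
import pytorch_lightning as pl
from denoise.models.unet import UNet, UNetModel

unet = UNet(out_channels=3)
model = UNetModel(net=unet)  # 'net' ends up in **load_kwargs; self.net is built internally

trainer = pl.Trainer(max_epochs=100, logger=pl.loggers.TensorBoardLogger("lightning_logs"))
# trainer.fit(model, train_dataloader, val_dataloader)  # dataloaders defined elsewhere
```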