From 78131453810fe7fe5e9591080ea7bcff1b25868c Mon Sep 17 00:00:00 2001 From: Daniel Nobbe Date: Tue, 23 Jan 2024 17:34:24 +0100 Subject: [PATCH 1/8] add torchscript model --- denoise/models/torchscript.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 denoise/models/torchscript.py diff --git a/denoise/models/torchscript.py b/denoise/models/torchscript.py new file mode 100644 index 0000000..9b631a3 --- /dev/null +++ b/denoise/models/torchscript.py @@ -0,0 +1,9 @@ +import torch + +class TorchScriptModel(torch.nn.Module): + def __init__(self, model_path): + super().__init__() + self.net = torch.jit.load(model_path) + + def forward(self, x): + return self.net(x) \ No newline at end of file From fb55ff95d7462fe0cad428ea20670e11276c3d56 Mon Sep 17 00:00:00 2001 From: Daniel Nobbe Date: Tue, 23 Jan 2024 17:34:40 +0100 Subject: [PATCH 2/8] refactor folder structure --- denoise/models/__init__.py | 10 ++++++++++ {models => denoise/models}/unet.py | 0 {models => denoise/noise}/noise.py | 0 3 files changed, 10 insertions(+) create mode 100644 denoise/models/__init__.py rename {models => denoise/models}/unet.py (100%) rename {models => denoise/noise}/noise.py (100%) diff --git a/denoise/models/__init__.py b/denoise/models/__init__.py new file mode 100644 index 0000000..b80e1b8 --- /dev/null +++ b/denoise/models/__init__.py @@ -0,0 +1,10 @@ +from torchscript import TorchScriptModel +from unet import UNet, UNetModel + +def get_model(model_name): + if model_name == 'unet': + return UNet + elif model_name == 'torchscript': + return TorchScriptModel + else: + raise NotImplementedError \ No newline at end of file diff --git a/models/unet.py b/denoise/models/unet.py similarity index 100% rename from models/unet.py rename to denoise/models/unet.py diff --git a/models/noise.py b/denoise/noise/noise.py similarity index 100% rename from models/noise.py rename to denoise/noise/noise.py From f0861b138f9f93a388754d4bdf0b001f93827837 Mon Sep 17 00:00:00 2001 From: Daniel Nobbe Date: Tue, 23 Jan 2024 17:42:17 +0100 Subject: [PATCH 3/8] use net inside unet module --- denoise/models/unet.py | 8 ++++---- scripts/infer.py | 8 ++++---- scripts/train.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/denoise/models/unet.py b/denoise/models/unet.py index 25abce4..901f72d 100644 --- a/denoise/models/unet.py +++ b/denoise/models/unet.py @@ -107,21 +107,21 @@ def _block(in_channels, features, name): class UNetModel(pl.LightningModule): - def __init__(self, model): + def __init__(self, net): super().__init__() - self.model = model + self.net = net self.loss_fn = nn.MSELoss() def training_step(self, batch, batch_idx): x, y = batch - y_hat = self.model(x) + y_hat = self.net(x) loss = self.loss_fn(y_hat, y) self.log("train_loss", loss) return loss def validation_step(self, batch, batch_idx): x, y = batch - y_hat = self.model(x) + y_hat = self.net(x) loss = self.loss_fn(y_hat, y) self.log("val_loss", loss) diff --git a/scripts/infer.py b/scripts/infer.py index 1e1d3e3..d72ff74 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -51,7 +51,7 @@ def infer(img): unet = UNet(out_channels=3) model = UNetModel.load_from_checkpoint( 'colab_lightning_logs/version_0/checkpoints/epoch=115-step=1856.ckpt', - model=unet, + net=unet, map_location=torch.device('cpu') ) @@ -70,7 +70,7 @@ def infer(img): results = [] for patch in patches: # use batches? results.append( - model.model(patch.unsqueeze(0)) + model.net(patch.unsqueeze(0)) # TODO: Disable batchnorm (needs 4D input) ) results = torch.stack(results, dim=1) @@ -86,7 +86,7 @@ def main(): unet = UNet(out_channels=3) model = UNetModel.load_from_checkpoint( 'colab_lightning_logs/version_0/checkpoints/epoch=115-step=1856.ckpt', - model=unet, + net=unet, map_location=torch.device('cpu') ) @@ -112,7 +112,7 @@ def main(): results = [] for patch in patches: # use batches? results.append( - model.model(patch.unsqueeze(0)) + model.net(patch.unsqueeze(0)) # TODO: Disable batchnorm (needs 4D input) ) results = torch.stack(results, dim=1) diff --git a/scripts/train.py b/scripts/train.py index 3941fb4..44a0f04 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -56,7 +56,7 @@ def main(): logger = pl.loggers.TensorBoardLogger("lightning_logs") unet = UNet(out_channels=3) - model = UNetModel(model=unet) + model = UNetModel(net=unet) input_transforms = tv.transforms.Compose( [ From 390c7b7680c000f629e4724639c44a58b7eca0d9 Mon Sep 17 00:00:00 2001 From: Daniel Nobbe Date: Tue, 23 Jan 2024 17:43:33 +0100 Subject: [PATCH 4/8] fix wording in infer_model --- scripts/infer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/infer.py b/scripts/infer.py index d72ff74..2e995cd 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -34,7 +34,7 @@ def infer_model(img, model, crop_size): # NOTE: Only compatible with a single image for now results = [] for patch in patches: # use batches? - result = model.model(patch.unsqueeze(0)) + result = model.net(patch.unsqueeze(0)) results.append( result ) @@ -83,15 +83,17 @@ def main(): crop_size = (256, 256) output_dir = 'infer-results' os.makedirs(output_dir, exist_ok=True) + + unet = UNet(out_channels=3) model = UNetModel.load_from_checkpoint( 'colab_lightning_logs/version_0/checkpoints/epoch=115-step=1856.ckpt', net=unet, map_location=torch.device('cpu') ) - model.eval() + test_image_path = 'data/div2k/subset/val/0804.png' tensor_transform = tv.transforms.ToTensor() From 72336251948f6ebf5375cf9261c9a25914afb3fc Mon Sep 17 00:00:00 2001 From: Daniel Nobbe Date: Tue, 23 Jan 2024 19:07:31 +0100 Subject: [PATCH 5/8] refactor folder structure --- denoise/__init__.py | 1 + denoise/models/__init__.py | 4 ++-- denoise/models/unet.py | 9 ++++++++- denoise/noise/__init__.py | 1 + scripts/__init__.py | 0 scripts/infer.py | 10 +++++----- 6 files changed, 17 insertions(+), 8 deletions(-) create mode 100644 denoise/__init__.py create mode 100644 denoise/noise/__init__.py delete mode 100644 scripts/__init__.py diff --git a/denoise/__init__.py b/denoise/__init__.py new file mode 100644 index 0000000..1b2b1c5 --- /dev/null +++ b/denoise/__init__.py @@ -0,0 +1 @@ +from .models import UNet, UNetModel \ No newline at end of file diff --git a/denoise/models/__init__.py b/denoise/models/__init__.py index b80e1b8..2448602 100644 --- a/denoise/models/__init__.py +++ b/denoise/models/__init__.py @@ -1,5 +1,5 @@ -from torchscript import TorchScriptModel -from unet import UNet, UNetModel +from .torchscript import TorchScriptModel +from .unet import UNet, UNetModel def get_model(model_name): if model_name == 'unet': diff --git a/denoise/models/unet.py b/denoise/models/unet.py index 901f72d..abc03f4 100644 --- a/denoise/models/unet.py +++ b/denoise/models/unet.py @@ -107,10 +107,17 @@ def _block(in_channels, features, name): class UNetModel(pl.LightningModule): - def __init__(self, net): + def __init__(self, net, load_from_checkpoint='', eval_mode=False): super().__init__() self.net = net self.loss_fn = nn.MSELoss() + + if eval_mode: + self.net.eval() + + # if load_from_checkpoint: + + def training_step(self, batch, batch_idx): x, y = batch diff --git a/denoise/noise/__init__.py b/denoise/noise/__init__.py new file mode 100644 index 0000000..5b67476 --- /dev/null +++ b/denoise/noise/__init__.py @@ -0,0 +1 @@ +from .noise import * \ No newline at end of file diff --git a/scripts/__init__.py b/scripts/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/scripts/infer.py b/scripts/infer.py index 2e995cd..74a136a 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -1,7 +1,7 @@ import torch -from models.unet import UNet, UNetModel +from denoise.models.unet import UNet, UNetModel import torchvision as tv -from models.noise import ShotNoise +from denoise.noise import ShotNoise from PIL import Image from kornia.contrib import extract_tensor_patches, combine_tensor_patches import os @@ -50,7 +50,7 @@ def infer(img): unet = UNet(out_channels=3) model = UNetModel.load_from_checkpoint( - 'colab_lightning_logs/version_0/checkpoints/epoch=115-step=1856.ckpt', + 'colab_lightning_logs/version_0/checkpoints/netKEY-epoch=115-step=1856.ckpt', net=unet, map_location=torch.device('cpu') ) @@ -83,11 +83,11 @@ def main(): crop_size = (256, 256) output_dir = 'infer-results' os.makedirs(output_dir, exist_ok=True) - + model_type = 'unet' unet = UNet(out_channels=3) model = UNetModel.load_from_checkpoint( - 'colab_lightning_logs/version_0/checkpoints/epoch=115-step=1856.ckpt', + 'colab_lightning_logs/version_0/checkpoints/netKEY-epoch=115-step=1856.ckpt', net=unet, map_location=torch.device('cpu') ) From 3917fe099013ee77a775710baacbcb8c67390bb3 Mon Sep 17 00:00:00 2001 From: Daniel Nobbe Date: Tue, 23 Jan 2024 20:26:14 +0100 Subject: [PATCH 6/8] initialise unetmodel directly with unet net --- scripts/infer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scripts/infer.py b/scripts/infer.py index 74a136a..2d81488 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -84,11 +84,10 @@ def main(): output_dir = 'infer-results' os.makedirs(output_dir, exist_ok=True) model_type = 'unet' + # model_load_args = {'map_location': torch.device('cpu')} - unet = UNet(out_channels=3) model = UNetModel.load_from_checkpoint( 'colab_lightning_logs/version_0/checkpoints/netKEY-epoch=115-step=1856.ckpt', - net=unet, map_location=torch.device('cpu') ) model.eval() From 9f3977d5625a4b55b328a57e3c448cc8c0ae4eeb Mon Sep 17 00:00:00 2001 From: Daniel Nobbe Date: Tue, 23 Jan 2024 20:45:20 +0100 Subject: [PATCH 7/8] select model by string --- denoise/models/__init__.py | 2 +- denoise/models/unet.py | 17 ++++++++++++----- scripts/infer.py | 10 ++++------ 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/denoise/models/__init__.py b/denoise/models/__init__.py index 2448602..d959db2 100644 --- a/denoise/models/__init__.py +++ b/denoise/models/__init__.py @@ -3,7 +3,7 @@ def get_model(model_name): if model_name == 'unet': - return UNet + return UNetModel elif model_name == 'torchscript': return TorchScriptModel else: diff --git a/denoise/models/unet.py b/denoise/models/unet.py index abc03f4..5856dc3 100644 --- a/denoise/models/unet.py +++ b/denoise/models/unet.py @@ -107,18 +107,25 @@ def _block(in_channels, features, name): class UNetModel(pl.LightningModule): - def __init__(self, net, load_from_checkpoint='', eval_mode=False): + def __init__(self, checkpoint=None, eval_mode=False, **load_kwargs): + # checkpoint: path to checkpoint file (lightning) super().__init__() - self.net = net + self.net = UNet(out_channels=3) self.loss_fn = nn.MSELoss() if eval_mode: self.net.eval() - # if load_from_checkpoint: - - + if checkpoint: + self.load_from_checkpoint(checkpoint, **load_kwargs) + @classmethod + def load(cls, checkpoint, **load_kwargs): + net = UNet(out_channels=3) + model = cls.load_from_checkpoint(checkpoint, net=net, **load_kwargs) + + return model + def training_step(self, batch, batch_idx): x, y = batch y_hat = self.net(x) diff --git a/scripts/infer.py b/scripts/infer.py index 2d81488..a779e1c 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -1,5 +1,6 @@ import torch from denoise.models.unet import UNet, UNetModel +from denoise.models import get_model import torchvision as tv from denoise.noise import ShotNoise from PIL import Image @@ -84,14 +85,11 @@ def main(): output_dir = 'infer-results' os.makedirs(output_dir, exist_ok=True) model_type = 'unet' - # model_load_args = {'map_location': torch.device('cpu')} + model_load_args = {'map_location': torch.device('cpu')} - model = UNetModel.load_from_checkpoint( - 'colab_lightning_logs/version_0/checkpoints/netKEY-epoch=115-step=1856.ckpt', - map_location=torch.device('cpu') - ) - model.eval() + arch = get_model(model_type) + model = arch.load(checkpoint='colab_lightning_logs/version_0/checkpoints/netKEY-epoch=115-step=1856.ckpt', eval_mode=True, **model_load_args) test_image_path = 'data/div2k/subset/val/0804.png' From ba115ed19fff1ee425dc9e8b2e879ab53b4fa3f3 Mon Sep 17 00:00:00 2001 From: Daniel Nobbe Date: Tue, 23 Jan 2024 21:20:22 +0100 Subject: [PATCH 8/8] load torchscript model --- denoise/models/__init__.py | 7 +++- denoise/models/torchscript.py | 6 +++ scripts/infer.py | 74 ++++++++++++++++++----------------- 3 files changed, 50 insertions(+), 37 deletions(-) diff --git a/denoise/models/__init__.py b/denoise/models/__init__.py index d959db2..e041c9b 100644 --- a/denoise/models/__init__.py +++ b/denoise/models/__init__.py @@ -7,4 +7,9 @@ def get_model(model_name): elif model_name == 'torchscript': return TorchScriptModel else: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError + + +# TODO: Add a model that first does colour based classical denoising +# and then does neural denoising (now only intensity based, not colour based) + # could even be applied in greyscale \ No newline at end of file diff --git a/denoise/models/torchscript.py b/denoise/models/torchscript.py index 9b631a3..7659e6d 100644 --- a/denoise/models/torchscript.py +++ b/denoise/models/torchscript.py @@ -5,5 +5,11 @@ def __init__(self, model_path): super().__init__() self.net = torch.jit.load(model_path) + @classmethod + def load(cls, model_path): + model = cls(model_path) + model.eval() + return model + def forward(self, x): return self.net(x) \ No newline at end of file diff --git a/scripts/infer.py b/scripts/infer.py index a779e1c..aaa38b1 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -2,7 +2,7 @@ from denoise.models.unet import UNet, UNetModel from denoise.models import get_model import torchvision as tv -from denoise.noise import ShotNoise +from denoise.noise import ShotNoise, GaussianNoise from PIL import Image from kornia.contrib import extract_tensor_patches, combine_tensor_patches import os @@ -84,43 +84,45 @@ def main(): crop_size = (256, 256) output_dir = 'infer-results' os.makedirs(output_dir, exist_ok=True) - model_type = 'unet' - model_load_args = {'map_location': torch.device('cpu')} + model_type = 'torchscript' + model_path = 'pretrained-models/NAFNet/NAFNet.pt' + # model_load_args = {'map_location': torch.device('cpu')} # TODO: Add short yaml files to get these per-model load settings arch = get_model(model_type) - - model = arch.load(checkpoint='colab_lightning_logs/version_0/checkpoints/netKEY-epoch=115-step=1856.ckpt', eval_mode=True, **model_load_args) - - test_image_path = 'data/div2k/subset/val/0804.png' - - tensor_transform = tv.transforms.ToTensor() - noise_transform = ShotNoise(sensitivity=-4.8, sensitivity_sigma=1.2) - - img = Image.open(test_image_path) - image = tensor_transform(img) - if apply_noise: - image = noise_transform(image) - tv.utils.save_image(image, os.path.join(output_dir, 'noisy.png')) - - image = image.unsqueeze(0) - patches = extract_tensor_patches(image, crop_size, stride=256, allow_auto_padding=True) - - original_size = img.size[::-1] # this is with padding - patches = patches.squeeze(0) - # NOTE: Only compatible with a single image for now - results = [] - for patch in patches: # use batches? - results.append( - model.net(patch.unsqueeze(0)) - # TODO: Disable batchnorm (needs 4D input) - ) - results = torch.stack(results, dim=1) - print(f"Results shape: {results.shape}") - result = combine_tensor_patches(results, original_size=original_size, window_size=crop_size, stride=crop_size[0], allow_auto_unpadding=True) - - print(result.shape) - - tv.utils.save_image(result, os.path.join(output_dir, 'result.png')) + model = arch.load(model_path=model_path) + # model = arch.load(checkpoint='colab_lightning_logs/version_0/checkpoints/netKEY-epoch=115-step=1856.ckpt', eval_mode=True, **model_load_args) + with torch.no_grad(): + test_image_path = 'data/div2k/subset/val/0805.png' + + tensor_transform = tv.transforms.ToTensor() + # noise_transform = ShotNoise(sensitivity=-4.8, sensitivity_sigma=1.2) + noise_transform = GaussianNoise(rgb_variance=(0.005,0.005,0.005)) + + img = Image.open(test_image_path) + image = tensor_transform(img) + if apply_noise: + image = noise_transform(image) + tv.utils.save_image(image, os.path.join(output_dir, 'noisy.png')) + + image = image.unsqueeze(0) + patches = extract_tensor_patches(image, crop_size, stride=256, allow_auto_padding=True) + + original_size = img.size[::-1] # this is with padding + patches = patches.squeeze(0) + # NOTE: Only compatible with a single image for now + results = [] + for patch in patches: # use batches? + results.append( + model.net(patch.unsqueeze(0)) + # TODO: Disable batchnorm (needs 4D input) + ) + results = torch.stack(results, dim=1) + print(f"Results shape: {results.shape}") + result = combine_tensor_patches(results, original_size=original_size, window_size=crop_size, stride=crop_size[0], allow_auto_unpadding=True) + + print(result.shape) + + tv.utils.save_image(result, os.path.join(output_dir, 'result.png')) if __name__ == "__main__":