Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions denoise/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .models import UNet, UNetModel
15 changes: 15 additions & 0 deletions denoise/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .torchscript import TorchScriptModel
from .unet import UNet, UNetModel

def get_model(model_name):
    """Return the model class registered under ``model_name``.

    Args:
        model_name: Architecture identifier, either ``'unet'`` or
            ``'torchscript'``.

    Returns:
        The model *class* (not an instance): ``UNetModel`` for ``'unet'``
        and ``TorchScriptModel`` for ``'torchscript'``.

    Raises:
        NotImplementedError: If ``model_name`` is not a known identifier.
    """
    if model_name == 'unet':
        return UNetModel
    if model_name == 'torchscript':
        return TorchScriptModel
    # Name the offending value and the valid choices so a typo in a script
    # or config is diagnosable from the traceback alone.
    raise NotImplementedError(
        f"Unknown model {model_name!r}; expected 'unet' or 'torchscript'"
    )


# TODO: Add a model that first does colour-based classical denoising and
# then neural denoising (the neural step is then only intensity-based, not
# colour-based), so it could even be applied in greyscale.
15 changes: 15 additions & 0 deletions denoise/models/torchscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch

class TorchScriptModel(torch.nn.Module):
    """``nn.Module`` wrapper around a serialized TorchScript network.

    The loaded network is stored on ``self.net`` and ``forward`` simply
    delegates to it.
    """

    def __init__(self, model_path, map_location=None):
        """Load the TorchScript archive at ``model_path``.

        Args:
            model_path: Path (or file-like object) accepted by
                ``torch.jit.load``.
            map_location: Optional device (string or ``torch.device``) to
                remap the stored tensors to, e.g. ``'cpu'`` for a model
                saved on GPU. ``None`` keeps ``torch.jit.load``'s default
                placement.
        """
        super().__init__()
        self.net = torch.jit.load(model_path, map_location=map_location)

    @classmethod
    def load(cls, model_path, map_location=None):
        """Build a wrapper for ``model_path`` and switch it to eval mode.

        Args:
            model_path: Forwarded to the constructor.
            map_location: Forwarded to the constructor.

        Returns:
            A ``TorchScriptModel`` instance on which ``eval()`` has been
            called.
        """
        model = cls(model_path, map_location=map_location)
        model.eval()
        return model

    def forward(self, x):
        """Run ``x`` through the wrapped TorchScript network."""
        return self.net(x)
22 changes: 18 additions & 4 deletions models/unet.py → denoise/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,21 +107,35 @@ def _block(in_channels, features, name):


class UNetModel(pl.LightningModule):
def __init__(self, checkpoint=None, eval_mode=False, **load_kwargs):
    """Build the UNet, optionally restoring weights from a checkpoint.

    Args:
        checkpoint: Optional path to a Lightning checkpoint file. When
            given, its weights are loaded into this instance.
        eval_mode: When true, ``self.net`` is put into eval mode.
        **load_kwargs: Extra keyword arguments. ``net`` (if present) is
            used as the network instead of a freshly built
            ``UNet(out_channels=3)``; everything else is forwarded to
            ``torch.load`` (e.g. ``map_location``).
    """
    super().__init__()
    # Callers pass ``net=...`` both directly and through
    # ``load_from_checkpoint``; honour it instead of silently dropping it.
    net = load_kwargs.pop('net', None)
    self.net = UNet(out_channels=3) if net is None else net
    self.loss_fn = nn.MSELoss()

    if checkpoint:
        # ``load_from_checkpoint`` is a classmethod that returns a *new*
        # model, so calling it here and discarding the result never
        # restored any weights. Load the state dict into this instance.
        # NOTE(review): assumes the checkpoint is a dict holding the
        # weights under 'state_dict', as Lightning writes it — confirm.
        import torch

        state = torch.load(checkpoint, **load_kwargs)
        self.load_state_dict(state['state_dict'])

    if eval_mode:
        self.net.eval()

@classmethod
def load(cls, checkpoint, **load_kwargs):
    """Alternate constructor: restore a model from ``checkpoint``.

    A new ``UNet(out_channels=3)`` is created and handed to
    ``load_from_checkpoint`` as ``net``, together with any extra
    ``load_kwargs``.

    Returns:
        Whatever ``cls.load_from_checkpoint`` returns.
    """
    fresh_net = UNet(out_channels=3)
    return cls.load_from_checkpoint(checkpoint, net=fresh_net, **load_kwargs)

def training_step(self, batch, batch_idx):
    """Compute and log the training loss for one batch.

    Args:
        batch: An ``(x, y)`` pair: network input and target.
        batch_idx: Index of the batch; not used here.

    Returns:
        The value of ``self.loss_fn`` between the network output for
        ``x`` and the target ``y``.
    """
    x, y = batch
    # The network lives on ``self.net``; the stale ``self.model(x)`` call
    # left over from before the rename is removed.
    y_hat = self.net(x)
    loss = self.loss_fn(y_hat, y)
    self.log("train_loss", loss)
    return loss

def validation_step(self, batch, batch_idx):
    """Compute and log the validation loss for one batch.

    Args:
        batch: An ``(x, y)`` pair: network input and target.
        batch_idx: Index of the batch; not used here.

    The loss is logged as ``"val_loss"``; nothing is returned.
    """
    x, y = batch
    # The network lives on ``self.net``; the stale ``self.model(x)`` call
    # left over from before the rename is removed.
    y_hat = self.net(x)
    loss = self.loss_fn(y_hat, y)
    self.log("val_loss", loss)

Expand Down
1 change: 1 addition & 0 deletions denoise/noise/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .noise import *
File renamed without changes.
Empty file removed scripts/__init__.py
Empty file.
91 changes: 46 additions & 45 deletions scripts/infer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch
from models.unet import UNet, UNetModel
from denoise.models.unet import UNet, UNetModel
from denoise.models import get_model
import torchvision as tv
from models.noise import ShotNoise
from denoise.noise import ShotNoise, GaussianNoise
from PIL import Image
from kornia.contrib import extract_tensor_patches, combine_tensor_patches
import os
Expand Down Expand Up @@ -34,7 +35,7 @@ def infer_model(img, model, crop_size):
# NOTE: Only compatible with a single image for now
results = []
for patch in patches: # use batches?
result = model.model(patch.unsqueeze(0))
result = model.net(patch.unsqueeze(0))
results.append(
result
)
Expand All @@ -50,8 +51,8 @@ def infer(img):

unet = UNet(out_channels=3)
model = UNetModel.load_from_checkpoint(
'colab_lightning_logs/version_0/checkpoints/epoch=115-step=1856.ckpt',
model=unet,
'colab_lightning_logs/version_0/checkpoints/netKEY-epoch=115-step=1856.ckpt',
net=unet,
map_location=torch.device('cpu')
)

Expand All @@ -70,7 +71,7 @@ def infer(img):
results = []
for patch in patches: # use batches?
results.append(
model.model(patch.unsqueeze(0))
model.net(patch.unsqueeze(0))
# TODO: Disable batchnorm (needs 4D input)
)
results = torch.stack(results, dim=1)
Expand All @@ -83,45 +84,45 @@ def main():
crop_size = (256, 256)
output_dir = 'infer-results'
os.makedirs(output_dir, exist_ok=True)
unet = UNet(out_channels=3)
model = UNetModel.load_from_checkpoint(
'colab_lightning_logs/version_0/checkpoints/epoch=115-step=1856.ckpt',
model=unet,
map_location=torch.device('cpu')
)

model.eval()

test_image_path = 'data/div2k/subset/val/0804.png'

tensor_transform = tv.transforms.ToTensor()
noise_transform = ShotNoise(sensitivity=-4.8, sensitivity_sigma=1.2)

img = Image.open(test_image_path)
image = tensor_transform(img)
if apply_noise:
image = noise_transform(image)
tv.utils.save_image(image, os.path.join(output_dir, 'noisy.png'))

image = image.unsqueeze(0)
patches = extract_tensor_patches(image, crop_size, stride=256, allow_auto_padding=True)

original_size = img.size[::-1] # this is with padding
patches = patches.squeeze(0)
# NOTE: Only compatible with a single image for now
results = []
for patch in patches: # use batches?
results.append(
model.model(patch.unsqueeze(0))
# TODO: Disable batchnorm (needs 4D input)
)
results = torch.stack(results, dim=1)
print(f"Results shape: {results.shape}")
result = combine_tensor_patches(results, original_size=original_size, window_size=crop_size, stride=crop_size[0], allow_auto_unpadding=True)

print(result.shape)

tv.utils.save_image(result, os.path.join(output_dir, 'result.png'))
model_type = 'torchscript'
model_path = 'pretrained-models/NAFNet/NAFNet.pt'
# model_load_args = {'map_location': torch.device('cpu')} # TODO: Add short yaml files to get these per-model load settings

arch = get_model(model_type)
model = arch.load(model_path=model_path)
# model = arch.load(checkpoint='colab_lightning_logs/version_0/checkpoints/netKEY-epoch=115-step=1856.ckpt', eval_mode=True, **model_load_args)
with torch.no_grad():
test_image_path = 'data/div2k/subset/val/0805.png'

tensor_transform = tv.transforms.ToTensor()
# noise_transform = ShotNoise(sensitivity=-4.8, sensitivity_sigma=1.2)
noise_transform = GaussianNoise(rgb_variance=(0.005,0.005,0.005))

img = Image.open(test_image_path)
image = tensor_transform(img)
if apply_noise:
image = noise_transform(image)
tv.utils.save_image(image, os.path.join(output_dir, 'noisy.png'))
image = image.unsqueeze(0)
patches = extract_tensor_patches(image, crop_size, stride=256, allow_auto_padding=True)

original_size = img.size[::-1] # this is with padding
patches = patches.squeeze(0)
# NOTE: Only compatible with a single image for now
results = []
for patch in patches: # use batches?
results.append(
model.net(patch.unsqueeze(0))
# TODO: Disable batchnorm (needs 4D input)
)
results = torch.stack(results, dim=1)
print(f"Results shape: {results.shape}")
result = combine_tensor_patches(results, original_size=original_size, window_size=crop_size, stride=crop_size[0], allow_auto_unpadding=True)

print(result.shape)

tv.utils.save_image(result, os.path.join(output_dir, 'result.png'))


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def main():
logger = pl.loggers.TensorBoardLogger("lightning_logs")

unet = UNet(out_channels=3)
model = UNetModel(model=unet)
model = UNetModel(net=unet)

input_transforms = tv.transforms.Compose(
[
Expand Down