Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM pytorch/pytorch:latest
FROM python:3.10-slim

WORKDIR /project

Expand All @@ -10,6 +10,6 @@ ENV PYTHONPATH="/mlflow/projects/code/:$PYTHONPATH"

COPY . .

# install requirements
# install requirements with compatible versions
RUN python -m pip install --upgrade pip && \
python -m pip install -r requirements.txt
6 changes: 3 additions & 3 deletions config/local_config.cfg
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
[server]
MLFLOW_S3_ENDPOINT_URL = http://0.0.0.0:8002
MLFLOW_TRACKING_URI = http://0.0.0.0:85
MLFLOW_S3_ENDPOINT_URL = http://localhost:8002
MLFLOW_TRACKING_URI = http://localhost:85
ARTIFACT_PATH = s3://mlflow

[xnat]
USER = admin
PASSWORD = admin
PROJECT = hipposeg
VERIFY = false
SERVER = http://localhost
SERVER = http://localhost/

[project]
NAME = hipposeg
Expand Down
139 changes: 110 additions & 29 deletions project/DataModule.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import List, Optional

import pytorch_lightning
from mlops.data.tools.tools import xnat_build_dataset
from mlops.data.transforms.LoadImageXNATd import LoadImageXNATd
from project.utils.tools import xnat_build_dataset
from project.transforms.LoadImageXNATd import LoadImageXNATd
from monai.data import CacheDataset, pad_list_data_collate
from monai.transforms import (

Check warning on line 7 in project/DataModule.py

View workflow job for this annotation

GitHub Actions / build-and-test

F401 'monai.transforms.Orientationd' imported but unused

Check warning on line 7 in project/DataModule.py

View workflow job for this annotation

GitHub Actions / build-and-test

F401 'monai.transforms.CropForegroundd' imported but unused
EnsureChannelFirstd,
Compose,
LoadImage,
Expand All @@ -20,8 +20,16 @@

class DataModule(pytorch_lightning.LightningDataModule):

def __init__(self, data_dir: str = './', xnat_configuration: dict = None, batch_size: int = 1, num_workers: int = 4,
test_fraction: float = 0.1, train_val_ratio: float = 0.2, test_batch: int = -1):
def __init__(
self,
data_dir: str = "./",
xnat_configuration: dict = None,
batch_size: int = 1,
num_workers: int = 4,
test_fraction: float = 0.1,
train_val_ratio: float = 0.2,
test_batch: int = -1,
):

super().__init__()
self.data_dir = data_dir
Expand All @@ -38,18 +46,24 @@
:param stage:
:return:
"""
# list of tuples defining action functions and their data keys
actions = [(self.fetch_image, 'image'),
(self.fetch_label, 'label')]
actions = [(self.fetch_image, "image"), (self.fetch_label, "label")]

self.xnat_data_list = xnat_build_dataset(self.xnat_configuration, actions=actions, test_batch=self.test_batch)
self.xnat_data_list = xnat_build_dataset(
self.xnat_configuration, actions=actions, test_batch=self.test_batch
)

self.train_samples, self.valid_samples = random_split(self.xnat_data_list, [1-self.train_val_ratio, self.train_val_ratio])
self.train_samples, self.valid_samples = random_split(
self.xnat_data_list, [1 - self.train_val_ratio, self.train_val_ratio]
)

self.train_transforms = Compose(
[
LoadImageXNATd(keys=['data'], xnat_configuration=self.xnat_configuration,
image_loader=LoadImage(image_only=True), expected_filetype_ext='.nii.gz'),
LoadImageXNATd(
keys=["data"],
xnat_configuration=self.xnat_configuration,
image_loader=LoadImage(image_only=True),
expected_filetype_ext=".nii.gz",
),
EnsureChannelFirstd(keys=["image", "label"]),
Spacingd(
keys=["image", "label"],
Expand All @@ -62,8 +76,12 @@

self.val_transforms = Compose(
[
LoadImageXNATd(keys=['data'], xnat_configuration=self.xnat_configuration,
image_loader=LoadImage(image_only=True), expected_filetype_ext='.nii.gz'),
LoadImageXNATd(
keys=["data"],
xnat_configuration=self.xnat_configuration,
image_loader=LoadImage(image_only=True),
expected_filetype_ext=".nii.gz",
),
EnsureChannelFirstd(keys=["image", "label"]),
Spacingd(
keys=["image", "label"],
Expand All @@ -74,8 +92,12 @@
]
)

self.train_dataset = CacheDataset(data=self.train_samples, transform=self.train_transforms)
self.val_dataset = CacheDataset(data=self.valid_samples, transform=self.val_transforms)
self.train_dataset = CacheDataset(
data=self.train_samples, transform=self.train_transforms
)
self.val_dataset = CacheDataset(
data=self.valid_samples, transform=self.val_transforms
)

def prepare_data(self, *args, **kwargs):
pass
Expand All @@ -85,45 +107,104 @@
Define train dataloader
:return:
"""
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
num_workers=self.num_workers, collate_fn=pad_list_data_collate,
pin_memory=is_available())
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
collate_fn=pad_list_data_collate,
pin_memory=is_available(),
)

def val_dataloader(self):
    """
    Define validation dataloader

    The loader reads from ``self.val_dataset`` with a fixed batch size of 1
    (``self.batch_size`` is not used here) and pads samples to a common shape
    through ``pad_list_data_collate``.

    :return: DataLoader over the validation dataset
    """
    # A single return replaces the duplicated (old + reformatted) return
    # statements; the second one could never be reached.
    return DataLoader(
        self.val_dataset,
        batch_size=1,
        num_workers=self.num_workers,
        collate_fn=pad_list_data_collate,
        # NOTE(review): is_available is imported outside this view; presumably
        # torch.cuda.is_available, i.e. pin memory only when CUDA is usable.
        pin_memory=is_available(),
    )

@staticmethod
def fetch_image(subject_data: SubjectData = None) -> List[ImageScanData]:
    """
    Collect the scan of a subject whose id contains 'image'.

    Every experiment of ``subject_data`` is visited and every scan whose
    ``id`` contains the substring ``"image"`` (case-insensitive) is kept.
    Experiments whose scans cannot be listed and scans whose id cannot be
    read are skipped with a logged warning instead of aborting the subject.

    :param subject_data: object exposing an ``experiments`` mapping; each
        experiment exposes a ``scans`` mapping. ``None`` yields an empty list.
    :return: list holding zero or one matching scan object
    :raises TypeError: if more than one matching scan is found
    """
    import logging

    logger = logging.getLogger(__name__)

    if subject_data is None:
        return []

    def _members(collection):
        # Mapping-like listings expose .values(); otherwise fall back to
        # key-based indexing.
        if hasattr(collection, "values"):
            return list(collection.values())
        return [collection[key] for key in collection.keys()]

    matches = []
    for experiment in _members(subject_data.experiments):
        try:
            scans = _members(experiment.scans)
        except Exception:
            # Best effort: one unreadable experiment must not hide the rest.
            logger.warning(
                "fetch_image: skipping experiment %r, scans could not be listed",
                experiment,
                exc_info=True,
            )
            continue

        for scan in scans:
            try:
                scan_id = scan.id.lower()
            except Exception:
                logger.warning(
                    "fetch_image: skipping scan %r, id could not be read",
                    scan,
                    exc_info=True,
                )
                continue
            if "image" in scan_id:
                matches.append(scan)

    if len(matches) > 1:
        raise TypeError(
            f"expected at most one 'image' scan per subject, found {len(matches)}"
        )
    return matches

@staticmethod
def fetch_label(subject_data: SubjectData = None) -> List[ImageScanData]:
    """
    Collect the scan of a subject whose id contains 'label'.

    Every experiment of ``subject_data`` is visited and every scan whose
    ``id`` contains the substring ``"label"`` (case-insensitive) is kept.
    Experiments whose scans cannot be listed and scans whose id cannot be
    read are skipped with a logged warning instead of aborting the subject.

    :param subject_data: object exposing an ``experiments`` mapping; each
        experiment exposes a ``scans`` mapping. ``None`` yields an empty list.
    :return: list holding zero or one matching scan object
    :raises TypeError: if more than one matching scan is found
    """
    import logging

    logger = logging.getLogger(__name__)

    if subject_data is None:
        return []

    def _members(collection):
        # Mapping-like listings expose .values(); otherwise fall back to
        # key-based indexing.
        if hasattr(collection, "values"):
            return list(collection.values())
        return [collection[key] for key in collection.keys()]

    matches = []
    for experiment in _members(subject_data.experiments):
        try:
            scans = _members(experiment.scans)
        except Exception:
            # Best effort: one unreadable experiment must not hide the rest.
            logger.warning(
                "fetch_label: skipping experiment %r, scans could not be listed",
                experiment,
                exc_info=True,
            )
            continue

        for scan in scans:
            try:
                scan_id = scan.id.lower()
            except Exception:
                logger.warning(
                    "fetch_label: skipping scan %r, id could not be read",
                    scan,
                    exc_info=True,
                )
                continue
            if "label" in scan_id:
                matches.append(scan)

    if len(matches) > 1:
        raise TypeError(
            f"expected at most one 'label' scan per subject, found {len(matches)}"
        )
    return matches
50 changes: 39 additions & 11 deletions project/Network.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,18 @@ def __init__(self, **kwargs):
norm=Norm.BATCH,
)
self.loss_function = DiceLoss(to_onehot_y=True, softmax=True)
self.post_pred = Compose([EnsureType("tensor", device=torch.device("cpu")), AsDiscrete(argmax=True, to_onehot=3)])
self.post_label = Compose([EnsureType("tensor", device=torch.device("cpu")), AsDiscrete(to_onehot=3)])
self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
self.post_pred = Compose(
[
EnsureType("tensor", device=torch.device("cpu")),
AsDiscrete(argmax=True, to_onehot=3),
]
)
self.post_label = Compose(
[EnsureType("tensor", device=torch.device("cpu")), AsDiscrete(to_onehot=3)]
)
self.dice_metric = DiceMetric(
include_background=False, reduction="mean", get_not_nans=False
)
self.best_val_dice = 0
self.best_val_epoch = 0

Expand All @@ -52,31 +61,50 @@ def training_step(self, batch, batch_idx):
images, labels = batch["image"], batch["label"]
output = self.forward(images)
loss = self.loss_function(output, labels)
self.log('train_loss', loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True)
self.log(
"train_loss",
loss.item(),
on_step=True,
on_epoch=True,
logger=True,
sync_dist=True,
)
return {"loss": loss}

def validation_step(self, batch, batch_idx):
    """
    Run inference on one validation batch and record its loss.

    The prediction and label are post-processed with ``self.post_pred`` /
    ``self.post_label`` and fed to ``self.dice_metric``. The per-batch result
    is also appended to ``self.validation_step_outputs`` so that
    ``on_validation_epoch_end`` can aggregate it.

    :param batch: dict with "image" and "label" entries
    :param batch_idx: index of the batch (unused)
    :return: dict with "val_loss" (loss value) and "val_number" (number of
        decollated predictions in this batch)
    """
    images, labels = batch["image"], batch["label"]
    # NOTE(review): non-positive roi_size entries presumably make MONAI use
    # the full image extent for that axis (whole-volume inference) - confirm.
    roi_size = (-1, -1, -1)
    sw_batch_size = 1
    outputs = sliding_window_inference(
        images, roi_size, sw_batch_size, self.forward
    )
    loss = self.loss_function(outputs, labels)
    outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
    labels = [self.post_label(i) for i in decollate_batch(labels)]
    self.dice_metric(y_pred=outputs, y=labels)
    output_dict = {"val_loss": loss, "val_number": len(outputs)}

    # The store is created lazily because it may not have been set up before
    # the first validation batch arrives.
    if not hasattr(self, "validation_step_outputs"):
        self.validation_step_outputs = []
    self.validation_step_outputs.append(output_dict)

    return output_dict

def on_validation_epoch_end(self):
    """
    Aggregate the stored validation results and log the epoch means.

    Reads the dicts stored in ``self.validation_step_outputs``, logs
    "mean_val_dice" (from ``self.dice_metric``) and "mean_val_loss" (summed
    loss divided by the summed "val_number"), then resets the metric and
    clears the stored outputs.

    If nothing was stored for this epoch, nothing is logged: the metric is
    reset and the method returns, instead of dividing by zero.

    :return: None
    """
    outputs = getattr(self, "validation_step_outputs", [])
    val_loss = sum(output["val_loss"].sum().item() for output in outputs)
    num_items = sum(output["val_number"] for output in outputs)

    # Clear the stored outputs so the next epoch starts from an empty list.
    self.validation_step_outputs = []

    if num_items == 0:
        # No validation samples were recorded; a mean loss is undefined.
        self.dice_metric.reset()
        return

    mean_val_dice = self.dice_metric.aggregate().item()
    self.dice_metric.reset()
    mean_val_loss = torch.tensor(val_loss / num_items)
    self.log_dict(
        {
            "mean_val_dice": mean_val_dice,
            "mean_val_loss": mean_val_loss,
        }
    )
Loading
Loading