From 69e380fd841eea7e5af07bdf5e9a600b408c7c27 Mon Sep 17 00:00:00 2001 From: AgatheZ Date: Sun, 22 Jun 2025 23:42:31 +0200 Subject: [PATCH 1/2] update pipeline --- Dockerfile | 4 +- config/local_config.cfg | 6 +- project/DataModule.py | 139 ++++++++++--- project/Network.py | 50 +++-- project/transforms/LoadImageXNATd.py | 160 +++++++++++++++ project/{util => utils}/__init__.py | 0 project/utils/tools.py | 284 +++++++++++++++++++++++++++ project/{util => utils}/visualise.py | 31 ++- scripts/train.py | 53 ++--- 9 files changed, 649 insertions(+), 78 deletions(-) create mode 100644 project/transforms/LoadImageXNATd.py rename project/{util => utils}/__init__.py (100%) create mode 100644 project/utils/tools.py rename project/{util => utils}/visualise.py (62%) diff --git a/Dockerfile b/Dockerfile index e923c5a..18ca07f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM pytorch/pytorch:latest +FROM python:3.10-slim WORKDIR /project @@ -10,6 +10,6 @@ ENV PYTHONPATH="/mlflow/projects/code/:$PYTHONPATH" COPY . . -# install requirements +# install requirements with compatible versions RUN python -m pip install --upgrade pip && \ python -m pip install -r requirements.txt \ No newline at end of file diff --git a/config/local_config.cfg b/config/local_config.cfg index a308e3c..3ebe29a 100644 --- a/config/local_config.cfg +++ b/config/local_config.cfg @@ -1,6 +1,6 @@ [server] -MLFLOW_S3_ENDPOINT_URL = http://0.0.0.0:8002 -MLFLOW_TRACKING_URI = http://0.0.0.0:85 +MLFLOW_S3_ENDPOINT_URL = http://localhost:8002 +MLFLOW_TRACKING_URI = http://localhost:85 ARTIFACT_PATH = s3://mlflow [xnat] @@ -8,7 +8,7 @@ USER = admin PASSWORD = admin PROJECT = hipposeg VERIFY = false -SERVER = http://localhost +SERVER = http://localhost/ [project] NAME = hipposeg diff --git a/project/DataModule.py b/project/DataModule.py index 7d0cd15..4083df9 100644 --- a/project/DataModule.py +++ b/project/DataModule.py @@ -1,8 +1,8 @@ from typing import List, Optional import pytorch_lightning -from mlops.data.tools.tools import xnat_build_dataset -from mlops.data.transforms.LoadImageXNATd import LoadImageXNATd +from project.utils.tools import xnat_build_dataset +from project.transforms.LoadImageXNATd import LoadImageXNATd from monai.data import CacheDataset, pad_list_data_collate from monai.transforms import ( EnsureChannelFirstd, @@ -20,8 +20,16 @@ class DataModule(pytorch_lightning.LightningDataModule): - def __init__(self, data_dir: str = './', xnat_configuration: dict = None, batch_size: int = 1, num_workers: int = 4, - test_fraction: float = 0.1, train_val_ratio: float = 0.2, test_batch: int = -1): + def __init__( + self, + data_dir: str = "./", + xnat_configuration: dict = None, + batch_size: int = 1, + num_workers: int = 4, + test_fraction: float = 0.1, + train_val_ratio: float = 0.2, + test_batch: int = -1, + ): super().__init__() self.data_dir = data_dir @@ -38,18 +46,24 @@ def setup(self, stage: Optional[str] = None): :param stage: :return: """ - # list of tuples defining action functions and their data keys - actions = [(self.fetch_image, 'image'), - (self.fetch_label, 'label')] + actions = [(self.fetch_image, "image"), (self.fetch_label, "label")] - self.xnat_data_list = xnat_build_dataset(self.xnat_configuration, actions=actions, test_batch=self.test_batch) + self.xnat_data_list = xnat_build_dataset( + self.xnat_configuration, actions=actions, test_batch=self.test_batch + ) - self.train_samples, self.valid_samples = random_split(self.xnat_data_list, [1-self.train_val_ratio, self.train_val_ratio]) + 
self.train_samples, self.valid_samples = random_split( + self.xnat_data_list, [1 - self.train_val_ratio, self.train_val_ratio] + ) self.train_transforms = Compose( [ - LoadImageXNATd(keys=['data'], xnat_configuration=self.xnat_configuration, - image_loader=LoadImage(image_only=True), expected_filetype_ext='.nii.gz'), + LoadImageXNATd( + keys=["data"], + xnat_configuration=self.xnat_configuration, + image_loader=LoadImage(image_only=True), + expected_filetype_ext=".nii.gz", + ), EnsureChannelFirstd(keys=["image", "label"]), Spacingd( keys=["image", "label"], @@ -62,8 +76,12 @@ def setup(self, stage: Optional[str] = None): self.val_transforms = Compose( [ - LoadImageXNATd(keys=['data'], xnat_configuration=self.xnat_configuration, - image_loader=LoadImage(image_only=True), expected_filetype_ext='.nii.gz'), + LoadImageXNATd( + keys=["data"], + xnat_configuration=self.xnat_configuration, + image_loader=LoadImage(image_only=True), + expected_filetype_ext=".nii.gz", + ), EnsureChannelFirstd(keys=["image", "label"]), Spacingd( keys=["image", "label"], @@ -74,8 +92,12 @@ def setup(self, stage: Optional[str] = None): ] ) - self.train_dataset = CacheDataset(data=self.train_samples, transform=self.train_transforms) - self.val_dataset = CacheDataset(data=self.valid_samples, transform=self.val_transforms) + self.train_dataset = CacheDataset( + data=self.train_samples, transform=self.train_transforms + ) + self.val_dataset = CacheDataset( + data=self.valid_samples, transform=self.val_transforms + ) def prepare_data(self, *args, **kwargs): pass @@ -85,18 +107,27 @@ def train_dataloader(self): Define train dataloader :return: """ - return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, - num_workers=self.num_workers, collate_fn=pad_list_data_collate, - pin_memory=is_available()) + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + collate_fn=pad_list_data_collate, + pin_memory=is_available(), + ) def val_dataloader(self): """ Define validation dataloader :return: """ - return DataLoader(self.val_dataset, batch_size=1, num_workers=self.num_workers, collate_fn=pad_list_data_collate, - pin_memory=is_available()) - + return DataLoader( + self.val_dataset, + batch_size=1, + num_workers=self.num_workers, + collate_fn=pad_list_data_collate, + pin_memory=is_available(), + ) @staticmethod def fetch_image(subject_data: SubjectData = None) -> List[ImageScanData]: @@ -105,10 +136,35 @@ def fetch_image(subject_data: SubjectData = None) -> List[ImageScanData]: along with the 'key' that it will be used to access it. 
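+
+        A scan is selected when its id contains the substring 'image'
+        (case-insensitive); a TypeError is raised if more than one scan matches.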
""" output = [] - for exp in subject_data.experiments: - for scan in subject_data.experiments[exp].scans: - if 'image' in subject_data.experiments[exp].scans[scan].id.lower(): - output.append(subject_data.experiments[exp].scans[scan]) + + if hasattr(subject_data.experiments, "values"): + experiments = subject_data.experiments.values() + else: + experiments = [ + subject_data.experiments[exp_id] + for exp_id in subject_data.experiments.keys() + ] + + for experiment in experiments: + try: + if hasattr(experiment.scans, "values"): + scans = experiment.scans.values() + else: + scans = [ + experiment.scans[scan_id] for scan_id in experiment.scans.keys() + ] + + for scan_obj in scans: + try: + scan_name = scan_obj.id.lower() + if "image" in scan_name: + output.append(scan_obj) + except Exception: + continue + + except Exception: + continue + if len(output) > 1: raise TypeError return output @@ -120,10 +176,35 @@ def fetch_label(subject_data: SubjectData = None) -> List[ImageScanData]: along with the 'key' that it will be used to access it. """ output = [] - for exp in subject_data.experiments: - for scan in subject_data.experiments[exp].scans: - if 'label' in subject_data.experiments[exp].scans[scan].id.lower(): - output.append(subject_data.experiments[exp].scans[scan]) + + if hasattr(subject_data.experiments, "values"): + experiments = subject_data.experiments.values() + else: + experiments = [ + subject_data.experiments[exp_id] + for exp_id in subject_data.experiments.keys() + ] + + for experiment in experiments: + try: + if hasattr(experiment.scans, "values"): + scans = experiment.scans.values() + else: + scans = [ + experiment.scans[scan_id] for scan_id in experiment.scans.keys() + ] + + for scan_obj in scans: + try: + scan_name = scan_obj.id.lower() + if "label" in scan_name: + output.append(scan_obj) + except Exception: + continue + + except Exception: + continue + if len(output) > 1: raise TypeError return output diff --git a/project/Network.py b/project/Network.py index 3ff11a3..964147e 100644 --- a/project/Network.py +++ b/project/Network.py @@ -33,9 +33,18 @@ def __init__(self, **kwargs): norm=Norm.BATCH, ) self.loss_function = DiceLoss(to_onehot_y=True, softmax=True) - self.post_pred = Compose([EnsureType("tensor", device=torch.device("cpu")), AsDiscrete(argmax=True, to_onehot=3)]) - self.post_label = Compose([EnsureType("tensor", device=torch.device("cpu")), AsDiscrete(to_onehot=3)]) - self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False) + self.post_pred = Compose( + [ + EnsureType("tensor", device=torch.device("cpu")), + AsDiscrete(argmax=True, to_onehot=3), + ] + ) + self.post_label = Compose( + [EnsureType("tensor", device=torch.device("cpu")), AsDiscrete(to_onehot=3)] + ) + self.dice_metric = DiceMetric( + include_background=False, reduction="mean", get_not_nans=False + ) self.best_val_dice = 0 self.best_val_epoch = 0 @@ -52,7 +61,14 @@ def training_step(self, batch, batch_idx): images, labels = batch["image"], batch["label"] output = self.forward(images) loss = self.loss_function(output, labels) - self.log('train_loss', loss.item(), on_step=True, on_epoch=True, logger=True, sync_dist=True) + self.log( + "train_loss", + loss.item(), + on_step=True, + on_epoch=True, + logger=True, + sync_dist=True, + ) return {"loss": loss} def validation_step(self, batch, batch_idx): @@ -60,23 +76,35 @@ def validation_step(self, batch, batch_idx): roi_size = (-1, -1, -1) sw_batch_size = 1 outputs = sliding_window_inference( - images, roi_size, 
sw_batch_size, self.forward) + images, roi_size, sw_batch_size, self.forward + ) loss = self.loss_function(outputs, labels) outputs = [self.post_pred(i) for i in decollate_batch(outputs)] labels = [self.post_label(i) for i in decollate_batch(labels)] self.dice_metric(y_pred=outputs, y=labels) - return {"val_loss": loss, "val_number": len(outputs)} + output_dict = {"val_loss": loss, "val_number": len(outputs)} + + if not hasattr(self, "validation_step_outputs"): + self.validation_step_outputs = [] + self.validation_step_outputs.append(output_dict) - def validation_epoch_end(self, outputs): + return output_dict + + def on_validation_epoch_end(self): val_loss, num_items = 0, 0 + outputs = getattr(self, "validation_step_outputs", []) for output in outputs: val_loss += output["val_loss"].sum().item() num_items += output["val_number"] mean_val_dice = self.dice_metric.aggregate().item() self.dice_metric.reset() mean_val_loss = torch.tensor(val_loss / num_items) - self.log_dict({ - "mean_val_dice": mean_val_dice, - "mean_val_loss": mean_val_loss, - }) + self.log_dict( + { + "mean_val_dice": mean_val_dice, + "mean_val_loss": mean_val_loss, + } + ) + # Clear the stored outputs + self.validation_step_outputs = [] return diff --git a/project/transforms/LoadImageXNATd.py b/project/transforms/LoadImageXNATd.py new file mode 100644 index 0000000..8b03963 --- /dev/null +++ b/project/transforms/LoadImageXNATd.py @@ -0,0 +1,160 @@ +""" +MONAI MapTransform for importing image data from XNAT +""" + +import glob +import logging +import os +import tempfile +import time + +import xnat +from monai.config import KeysCollection +from monai.transforms import MapTransform, LoadImage +from monai.transforms import Transform + +logger = logging.getLogger(__name__) + + +class LoadImageXNATd(MapTransform): + """ + MapTransform for importing image data from XNAT + """ + + def __init__( + self, + keys: KeysCollection, + xnat_configuration: dict = None, + image_loader: Transform = LoadImage(), + validate_data: bool = False, + expected_filetype_ext: str = ".dcm", + image_series_option: str = "error", + verbose=False, + ): + super().__init__(keys) + self.image_loader = image_loader + self.xnat_configuration = xnat_configuration + self.expected_filetype = expected_filetype_ext + self.validate_data = validate_data + self.image_series_option = image_series_option + self.verbose = verbose + + def __call__(self, data): + """ + Checks for the requested keys in the input data dictionary. + + If specified key is found then will loop over actions in action list and apply each to the value of the + requested key, if no action is triggered then raise a warning actions arefunctions that return any of projects, + subjects, experiments, scans, or resources XNAT object along with a key to be used in the data dictionary + + Each action function should locate a single image object in XNAT. This image object is then downloaded to a + temporary directory and loaded into memory as the value defined by key set by the actions' data_label. + + If validate_data is true then NO data will be downloaded. In this case the transform will loop over the actions + but will instead return a true/false value for each data sample. This can be used to remove samples where the + data is not present in XNAT. 
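+
+        A sketch of one entry expected under a requested key, showing the fields
+        this transform reads (the URI value is illustrative only):
+
+            {'source_action': 'fetch_image',
+             'action_data': '/data/experiments/EXP_0001/scans/1',
+             'data_type': 'xnat_uri',
+             'data_label': 'image'}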
+ + :param data: dictionary of data + :return: + + """ + + d = dict(data) + + for key in self.keys: + + if key in data: + + adjusted_xnat_config = self.xnat_configuration.copy() + if "localhost" in adjusted_xnat_config.get("server", ""): + adjusted_xnat_config["server"] = adjusted_xnat_config[ + "server" + ].replace("localhost", "host.docker.internal") + + with xnat.connect( + server=adjusted_xnat_config["server"], + user=adjusted_xnat_config["user"], + password=adjusted_xnat_config["password"], + verify=adjusted_xnat_config["verify"], + loglevel="ERROR", + ) as session: + + "Check data list has no duplicate keys" + if len(set([x["data_label"] for x in d[key]])) != len( + [x["data_label"] for x in d[key]] + ): + logger.warning("Multiple images with identical labels found") + raise + + "Download image from XNAT" + for item in d[key]: + data_label = item["data_label"] + if item["data_type"] == "value": + d[data_label] = item["action_data"] + continue + + attempts = 0 + while attempts < 3: + try: + with tempfile.TemporaryDirectory() as tmpdirname: + session_obj = session.create_object( + item["action_data"] + ) + session_obj.download_dir( + tmpdirname, verbose=self.verbose + ) + + images_path = glob.glob( + os.path.join( + tmpdirname, "**/*" + self.expected_filetype + ), + recursive=True, + ) + + # image loader needs full path to load single images + # logger.info(f"Downloading images: {images_path}") + if not images_path: + raise Exception + + if len(images_path) == 1: + image = self.image_loader(images_path) + + # image loader needs directory path to load 3D images + else: + "find unique directories in list of image paths" + image_dirs = list( + set( + os.path.dirname(image_path) + for image_path in images_path + ) + ) + if len(image_dirs) > 1: + if self.image_series_option == "error": + raise ValueError( + f"More than one image series found in {images_path}" + ) + elif ( + self.image_series_option == "keepfirst" + ): + image = self.image_loader(image_dirs[0]) + elif self.image_series_option == "keeplast": + image = self.image_loader( + image_dirs[-1] + ) + else: + raise ValueError( + f"More than one image series found in {images_path} and no option specified" + ) + break + + except Exception as e: + attempts += 1 + time.sleep(1.0) + if attempts == 3: + raise Exception( + f"Image loader failed on {item} after 3 retries due to {e}" + ) + + d[data_label] = image + + return d diff --git a/project/util/__init__.py b/project/utils/__init__.py similarity index 100% rename from project/util/__init__.py rename to project/utils/__init__.py diff --git a/project/utils/tools.py b/project/utils/tools.py new file mode 100644 index 0000000..f2fa0a6 --- /dev/null +++ b/project/utils/tools.py @@ -0,0 +1,284 @@ +import asyncio +import logging +import os +from concurrent.futures import ThreadPoolExecutor +from itertools import chain + +import mlflow +import pandas as pd +import tqdm +import xnat + +logger = logging.getLogger(__name__) + + +class DataBuilderXNAT: + + def __init__( + self, + xnat_configuration: dict, + actions: list = None, + flatten_output=True, + test_batch: int = -1, + num_workers: int = 1, + validate_data=True, + ): + self.xnat_configuration = xnat_configuration + self.actions = actions + self.flatten_output = flatten_output + self.test_batch = test_batch + self.missing_data_log = [] + self.num_workers = num_workers + self.validate_data = validate_data + + self.dataset = [] + + def fetch_data(self): + asyncio.run(self.start_async_process()) + + async def start_async_process(self): + with 
ThreadPoolExecutor(max_workers=self.num_workers) as executor: + adjusted_xnat_config = self.xnat_configuration.copy() + if "localhost" in adjusted_xnat_config.get("server", ""): + adjusted_xnat_config["server"] = adjusted_xnat_config["server"].replace( + "localhost", "host.docker.internal" + ) + + with xnat.connect( + server=adjusted_xnat_config["server"], + user=adjusted_xnat_config["user"], + password=adjusted_xnat_config["password"], + verify=adjusted_xnat_config["verify"], + loglevel="ERROR", + ) as session: + + logger.info( + f"Collecting XNAT project: {adjusted_xnat_config['project']}" + ) + project = session.projects[adjusted_xnat_config["project"]] + + if 0 < self.test_batch < len(project.subjects): + from random import sample + + project_subjects = sample(project.subjects[:], self.test_batch) + else: + project_subjects = project.subjects[:] + + loop = asyncio.get_event_loop() + tasks = [ + loop.run_in_executor( + executor, self.process_subject, *(project, subject_i) + ) + for subject_i in project_subjects + ] + + responses = [ + await f + for f in tqdm.tqdm(asyncio.as_completed(tasks), total=len(tasks)) + ] + + for response in await asyncio.gather(*tasks): + pass + + if self.validate_data: + self.dataset = [ + item + for item in self.dataset + if len(item["data"]) >= len(self.actions) + ] + + def process_subject(self, project, subject_i): + subject = project.subjects.data[subject_i.id] + data_sample = { + "subject_uri": project.subjects[subject.id].uri, + "subject_id": project.subjects[subject.id].label, + "data": [], + } + + if self.actions: + try: + data = [] + for action, data_label in self.actions: + action_data = [] + + xnat_obj = action(project.subjects[subject.id]) + + if ( + xnat_obj is None + or type(xnat_obj) == list + and len(xnat_obj) == 0 + ): + self.missing_data_log.append( + { + "subject_id": subject_i.id, + "action_data": subject_i.label, + "failed_action": action, + } + ) + logger.debug( + f"No data found for {subject_i}: action {action} removing sample" + ) + raise Exception + + elif type(xnat_obj) == list: + for obj in xnat_obj: + if obj.cache_id[0] == "ResourceCatalog": + action_data.append( + { + "source_action": action.__name__, + "action_data": obj.uri, + "data_type": "xnat_uri", + "data_label": data_label, + "resource_files": list(obj.files), + } + ) # retrieve filenames if xnat_obj contains Resource + else: + action_data.append( + { + "source_action": action.__name__, + "action_data": obj.uri, + "data_type": "xnat_uri", + "data_label": data_label, + } + ) + + else: + action_data.append( + { + "source_action": action.__name__, + "action_data": xnat_obj, + "data_type": "value", + "data_label": data_label, + } + ) + + data.append(action_data) + + except Exception as e: + logger.debug( + f"No data found for {subject_i}; removing sample. Exception {e}" + ) + pass + + if self.flatten_output: + data = list(chain(*data)) + + data_sample["data"] = data + + self.dataset.append((data_sample)) + + +def xnat_build_dataset( + xnat_configuration: dict, + actions: list = None, + flatten_output=True, + test_batch: int = -1, +): + """ + ***NOTE THE ASYNC VERSION ABOVE WILL BE MUCH FASTER FOR LARGE DATASETS*** + + Builds a dictionary that describes the XNAT project dataset using XNAT data hierarchy: project/subject/experiment/scan + + structured output returned flattened dataset. Can set structured_output to True to perform custom flattening. 
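+
+    A sketch of one returned sample (field values are illustrative only):
+
+        {'subject_uri': '/data/subjects/SUBJ_0001',
+         'subject_id': 'sub-001',
+         'data': [{'source_action': 'fetch_image',
+                   'action_data': '/data/experiments/EXP_0001/scans/1',
+                   'data_type': 'xnat_uri',
+                   'data_label': 'image'}]}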
+ + """ + + adjusted_xnat_config = xnat_configuration.copy() + if "localhost" in adjusted_xnat_config.get("server", ""): + adjusted_xnat_config["server"] = adjusted_xnat_config["server"].replace( + "localhost", "host.docker.internal" + ) + + with xnat.connect( + server=adjusted_xnat_config["server"], + user=adjusted_xnat_config["user"], + password=adjusted_xnat_config["password"], + verify=adjusted_xnat_config["verify"], + loglevel="ERROR", + ) as session: + + logger.info(f"Collecting XNAT project: {adjusted_xnat_config['project']}") + project = session.projects[adjusted_xnat_config["project"]] + + dataset = [] + + if 0 < test_batch < len(project.subjects): + from random import sample + + project_subjects = sample(project.subjects[:], test_batch) + else: + project_subjects = project.subjects[:] + + missing_data_log = [] + for subject_i in project_subjects: + subject = project.subjects.data[subject_i.id] + data_sample = { + "subject_uri": project.subjects[subject.id].uri, + "subject_id": project.subjects[subject.id].label, + "data": [], + } + + if actions: + try: + data = [] + for action, data_label in actions: + action_data = [] + + xnat_obj = action(project.subjects[subject.id]) + + if type(xnat_obj) == list: + if len(xnat_obj) == 0: + missing_data_log.append( + { + "subject_id": subject_i.id, + "action_data": subject_i.label, + "failed_action": action, + } + ) + logger.debug( + f"No data found for {subject_i}: action {action} removing sample" + ) + raise Exception + + for obj in xnat_obj: + action_data.append( + { + "source_action": action.__name__, + "action_data": obj.uri, + "data_type": "xnat_uri", + "data_label": data_label, + } + ) + + elif type(xnat_obj) == str: + action_data.append( + { + "source_action": action.__name__, + "action_data": xnat_obj, + "data_type": "value", + "data_label": data_label, + } + ) + + data.append(action_data) + + except Exception as e: + logger.debug( + f"No data found for {subject_i}; removing sample. 
Exception {e}" + ) + continue + + if flatten_output: + data = list(chain(*data)) + + data_sample["data"] = data + + dataset.append((data_sample)) + + if missing_data_log: + df = pd.DataFrame(missing_data_log) + df.to_csv("missing_data_log.csv") + mlflow.log_artifact("missing_data_log.csv") + mlflow.log_param("N_failed_xnat_samples", len(df)) + + return dataset diff --git a/project/util/visualise.py b/project/utils/visualise.py similarity index 62% rename from project/util/visualise.py rename to project/utils/visualise.py index 4f2dabe..40931d5 100644 --- a/project/util/visualise.py +++ b/project/utils/visualise.py @@ -3,25 +3,33 @@ import itertools import random import mlflow -from mlops.utils.logger import logger +import logging import numpy as np from monai.inferers import sliding_window_inference +logger = logging.getLogger(__name__) + def plot_inference_test(net, dm, n_samples_plot=4): net.eval() - device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0") + device = ( + torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0") + ) net.to(device) with torch.no_grad(): dl = dm.val_dataloader() n_samples_total = len(dl) - sample_idx = random.sample(range(n_samples_total), min(n_samples_total, n_samples_plot)) + sample_idx = random.sample( + range(n_samples_total), min(n_samples_total, n_samples_plot) + ) if n_samples_total == 0: - logger.warning(f'Unable to create preview figure: Dataloader {dl} is empty') + logger.warning(f"Unable to create preview figure: Dataloader {dl} is empty") return - fig, axs = plt.subplots(len(sample_idx), 4, figsize=(int(5*len(sample_idx)), 20), dpi=80) + fig, axs = plt.subplots( + len(sample_idx), 4, figsize=(int(5 * len(sample_idx)), 20), dpi=80 + ) for i, n in enumerate(sample_idx): k = int(np.floor(n / dl.batch_size)) @@ -29,11 +37,15 @@ def plot_inference_test(net, dm, n_samples_plot=4): roi_size = (-1, -1, -1) sw_batch_size = 1 - val_outputs = sliding_window_inference(val_data["image"].to(device), roi_size, sw_batch_size, net) + val_outputs = sliding_window_inference( + val_data["image"].to(device), roi_size, sw_batch_size, net + ) test_slice = random.randint(0, val_data["image"].shape[4]) input_image = val_data["image"].cpu().numpy()[0, 0, :, :, test_slice] label_image = val_data["label"].cpu().numpy()[0, 0, :, :, test_slice] - pred_image = np.array(torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, test_slice]) + pred_image = np.array( + torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, test_slice] + ) axs[i, 0].set_title(f"image {k} slice{test_slice}") axs[i, 0].imshow(input_image, cmap="gray") @@ -49,6 +61,5 @@ def plot_inference_test(net, dm, n_samples_plot=4): plt.tight_layout() plt.show() - plt.savefig('example_inference.png') - mlflow.log_artifact('example_inference.png') - + plt.savefig("example_inference.png") + mlflow.log_artifact("example_inference.png") diff --git a/scripts/train.py b/scripts/train.py index 4aae59e..3ee24dc 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -8,7 +8,7 @@ from project.DataModule import DataModule from project.Network import Network -from project.util.visualise import plot_inference_test +from project.utils.visualise import plot_inference_test def train(config): @@ -19,49 +19,56 @@ def train(config): n_gpu = 4 log_torchscript = True - xnat_configuration = {'server': config['xnat']['SERVER'], - 'user': config['xnat']['USER'], - 'password': config['xnat']['PASSWORD'], - 'project': config['xnat']['PROJECT'], - 'verify': config.getboolean('xnat', 
'VERIFY')} + xnat_configuration = { + "server": config["xnat"]["SERVER"], + "user": config["xnat"]["USER"], + "password": config["xnat"]["PASSWORD"], + "project": config["xnat"]["PROJECT"], + "verify": config.getboolean("xnat", "VERIFY"), + } - print('Creating Network and DataModule') - dm = DataModule(xnat_configuration=xnat_configuration, batch_size=batch_size, test_batch=test_batch) + print("Creating Network and DataModule") + dm = DataModule( + xnat_configuration=xnat_configuration, + batch_size=batch_size, + test_batch=test_batch, + ) net = Network(dropout=0.2) - print('Starting logged run') + print("Starting logged run") mlflow.pytorch.autolog(log_models=False) - with mlflow.start_run(run_name='training') as run: + with mlflow.start_run(run_name="training") as run: mlf_logger = MLFlowLogger( - experiment_name=mlflow.get_experiment(mlflow.active_run().info.experiment_id).name, + experiment_name=mlflow.get_experiment( + mlflow.active_run().info.experiment_id + ).name, tracking_uri=mlflow.get_tracking_uri(), run_id=mlflow.active_run().info.run_id, ) - trainer = pl.Trainer(logger=mlf_logger, - auto_select_gpus=True, - precision=16 if cuda_available() else 32, - accelerator='gpu' if cuda_available() else None, - devices=n_gpu if cuda_available() else None, - max_epochs=max_epochs, - log_every_n_steps=1, - strategy="ddp" if cuda_available() else None, - ) + trainer = pl.Trainer( + logger=mlf_logger, + precision=32, + accelerator="cpu", + max_epochs=max_epochs, + log_every_n_steps=1, + strategy="auto", + ) trainer.fit(net, dm) plot_inference_test(net, dm) if log_torchscript: - scripted_model = net.to_torchscript(file_path='model.ts') + scripted_model = net.to_torchscript(file_path="model.ts") mlflow.pytorch.log_model(scripted_model, "model") -if __name__ == '__main__': +if __name__ == "__main__": if len(sys.argv) > 0: config_path = sys.argv[1] else: - config_path = '../config/config.cfg' + config_path = "../config/config.cfg" config = configparser.ConfigParser() config.read(config_path) From cf7460878391fc2af4b55722ade54f9921ba06ca Mon Sep 17 00:00:00 2001 From: AgatheZ Date: Sun, 22 Jun 2025 23:42:44 +0200 Subject: [PATCH 2/2] update requs --- requirements.txt | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/requirements.txt b/requirements.txt index b3b0364..0545d34 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,16 @@ -pytorch_lightning -mlflow==2.0.1 +pytorch-lightning==2.5.1 +mlflow==2.10.0 protobuf~=3.19.0 # temporary solution to avoid breaking changes made to 3rd party streamlit (necessary for mlflow) https://discuss.streamlit.io/t/typeerror-descriptors-cannot-not-be-created-directly/25639 -csc-mlops -numpy -torchvision +csc-mlops==0.9.23 +numpy==1.26.4 +torch==2.1.2 +torchvision==0.16.2 monai itk tqdm pandas matplotlib xnat -boto3 docker pytest fsspec