diff --git a/.gitignore b/.gitignore index cdffe91..584d441 100644 --- a/.gitignore +++ b/.gitignore @@ -24,7 +24,6 @@ __pycache__/ # Distribution / packaging .Python -env/ build/ develop-eggs/ dist/ @@ -37,6 +36,7 @@ parts/ sdist/ var/ wheels/ +share/python-wheels/ *.egg-info/ .installed.cfg *.egg @@ -55,13 +55,17 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ +.nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml -*,cover +*.cover +*.py,cover .hypothesis/ +.pytest_cache/ +cover/ # Translations *.mo @@ -70,6 +74,8 @@ coverage.xml # Django stuff: *.log local_settings.py +db.sqlite3 +db.sqlite3-journal # Flask stuff: instance/ @@ -82,32 +88,101 @@ instance/ docs/_build/ # PyBuilder +.pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints +# IPython +profile_default/ +ipython_config.py + # pyenv .python-version -# celery beat schedule file +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
+#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff celerybeat-schedule +celerybeat.pid -# dotenv -.env +# SageMath parsed files +*.sage.py -# virtualenv -.venv/ +# Environments +.env +.venv +env/ venv/ ENV/ +env.bak/ +venv.bak/ # Spyder project settings .spyderproject +.spyproject # Rope project settings .ropeproject -# vscode -.vscode/ -.vs/ -© 2019 GitHub, Inc. +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc diff --git a/deeplc/deeplc.py b/deeplc/deeplc.py index ab49568..5fbe9c4 100644 --- a/deeplc/deeplc.py +++ b/deeplc/deeplc.py @@ -12,20 +12,21 @@ "Robbin Bouwmeester", "Ralf Gabriels", "Arthur Declercq", + "Alireza Nameni", "Lennart Martens", "Sven Degroeve", ] - # Default models, will be used if no other is specified. If no best model is # selected during calibration, the first model in the list will be used. 
import os + deeplc_dir = os.path.dirname(os.path.realpath(__file__)) DEFAULT_MODELS = [ - "mods/full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.keras", - "mods/full_hc_PXD005573_pub_8c22d89667368f2f02ad996469ba157e.keras", - "mods/full_hc_PXD005573_pub_cb975cfdd4105f97efa0b3afffe075cc.keras", + "mods/full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.pt", + "mods/full_hc_PXD005573_pub_8c22d89667368f2f02ad996469ba157e.pt", + "mods/full_hc_PXD005573_pub_cb975cfdd4105f97efa0b3afffe075cc.pt", ] DEFAULT_MODELS = [os.path.join(deeplc_dir, dm) for dm in DEFAULT_MODELS] @@ -43,12 +44,10 @@ from configparser import ConfigParser from itertools import chain from tempfile import TemporaryDirectory - from sklearn.preprocessing import SplineTransformer from sklearn.linear_model import LinearRegression from sklearn.pipeline import make_pipeline - # If CLI/GUI/frozen: disable Tensorflow info and warnings before importing IS_CLI_GUI = os.path.basename(sys.argv[0]) in ["deeplc", "deeplc-gui"] IS_FROZEN = getattr(sys, "frozen", False) @@ -59,25 +58,15 @@ warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) -# Supress warnings (or at least try...) 
class DeepLCDataset(Dataset):
    """
    Torch dataset wrapping the four DeepLC input matrices and optional targets.

    All inputs are converted to ``float32`` tensors once, at construction time.

    Parameters
    ----------
    X : array-like
        First model input, one entry per peptide along axis 0.
    X_sum : array-like
        Second model input, one entry per peptide along axis 0.
    X_global : array-like
        Third model input, one entry per peptide along axis 0.
    X_hc : array-like
        Fourth model input, one entry per peptide along axis 0.
    target : array-like, optional
        Target values (retention times), one per peptide. When omitted the
        dataset yields features only (prediction mode).

    Raises
    ------
    ValueError
        If the inputs do not all contain the same number of peptides.
    """

    def __init__(self, X, X_sum, X_global, X_hc, target=None):
        self.X = self._to_float_tensor(X)
        self.X_sum = self._to_float_tensor(X_sum)
        self.X_global = self._to_float_tensor(X_global)
        self.X_hc = self._to_float_tensor(X_hc)
        self.target = None if target is None else self._to_float_tensor(target)

        # Indexing silently mis-pairs rows if the inputs are not aligned, so
        # fail loudly here instead of deep inside a training loop.
        sizes = {
            "X": len(self.X),
            "X_sum": len(self.X_sum),
            "X_global": len(self.X_global),
            "X_hc": len(self.X_hc),
        }
        if self.target is not None:
            sizes["target"] = len(self.target)
        if len(set(sizes.values())) > 1:
            raise ValueError(f"Inputs have inconsistent lengths: {sizes}")

    @staticmethod
    def _to_float_tensor(values):
        """Convert an ndarray (or anything ``np.asarray`` accepts) to float32."""
        return torch.as_tensor(np.asarray(values), dtype=torch.float32)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        features = (
            self.X[idx],
            self.X_sum[idx],
            self.X_global[idx],
            self.X_hc[idx],
        )
        if self.target is None:
            # Prediction mode: features only.
            return features
        # Training mode: features followed by the target.
        return features + (self.target[idx],)


class DeepLCFineTuner:
    """
    Fine-tune a DeepLC torch model with early stopping.

    Parameters
    ----------
    model : torch.nn.Module
        The model to fine-tune; it is moved to `device`.
    train_data : torch.utils.data.Dataset
        Dataset yielding ``(X, X_sum, X_global, X_hc, target)`` tuples.
    device : str, optional, default='cpu'
        The device on which to run the model and the batches.
    learning_rate : float, optional, default=0.001
        The learning rate for the Adam optimizer.
    epochs : int, optional, default=10
        Maximum number of training epochs.
    batch_size : int, optional, default=256
        Batch size for training and validation.
    validation_data : torch.utils.data.Dataset or None, optional
        If provided, used directly for validation. Otherwise, a fraction of
        `train_data` is held out.
    validation_split : float, optional, default=0.1
        Fraction of `train_data` to reserve for validation when
        `validation_data` is None.
    patience : int, optional, default=5
        Number of epochs without improvement of the monitored loss before
        training stops.
    """

    def __init__(
        self,
        model,
        train_data,
        device="cpu",
        learning_rate=0.001,
        epochs=10,
        batch_size=256,
        validation_data=None,
        validation_split=0.1,
        patience=5,
    ):
        self.model = model.to(device)
        self.train_data = train_data
        self.device = device
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.batch_size = batch_size
        self.validation_data = validation_data
        self.validation_split = validation_split
        self.patience = patience

    def _freeze_layers(self, unfreeze_keywords="33_1"):
        """
        Freeze every parameter whose name matches none of the keywords.

        Parameters
        ----------
        unfreeze_keywords : str or iterable of str, default="33_1"
            Substring(s) of parameter names that should stay trainable.
        """
        if isinstance(unfreeze_keywords, str):
            keywords = (unfreeze_keywords,)
        else:
            keywords = tuple(unfreeze_keywords)

        for name, param in self.model.named_parameters():
            param.requires_grad = any(keyword in name for keyword in keywords)

    def prepare_data(self, data, shuffle=True):
        """Wrap `data` in a ``DataLoader`` using the configured batch size."""
        return DataLoader(data, batch_size=self.batch_size, shuffle=shuffle)

    def _batch_loss(self, batch, loss_fn):
        """Run the model on one batch (moved to ``self.device``) and return the loss."""
        batch_X, batch_X_sum, batch_X_global, batch_X_hc, target = (
            tensor.to(self.device) for tensor in batch
        )
        target = target.view(-1, 1)
        outputs = self.model(batch_X, batch_X_sum, batch_X_global, batch_X_hc)
        # A model returning shape (N,) would broadcast against the (N, 1)
        # target into an (N, N) loss matrix; align shapes when sizes agree.
        if outputs.shape != target.shape and outputs.numel() == target.numel():
            outputs = outputs.reshape(target.shape)
        return loss_fn(outputs, target)

    def fine_tune(self):
        """
        Train the model and restore the weights of the best epoch.

        The monitored quantity is the mean validation batch loss (L1). When
        the validation set is empty (e.g. fewer than ``1 / validation_split``
        training samples) the mean training loss is monitored instead.

        Returns
        -------
        torch.nn.Module
            The fine-tuned model with the best observed weights loaded.

        Raises
        ------
        ValueError
            If there is no training data left after the validation split.
        """
        # Local imports keep this block self-contained.
        import copy
        import logging

        # NOTE(review): assumes the module logger is named after the module --
        # confirm against the file's `logger` definition.
        log = logging.getLogger(__name__)
        log.debug("Starting fine-tuning...")

        if self.validation_data is None:
            # Split the training data into training and validation sets
            val_size = int(len(self.train_data) * self.validation_split)
            train_size = len(self.train_data) - val_size
            train_dataset, val_dataset = torch.utils.data.random_split(
                self.train_data, [train_size, val_size]
            )
        else:
            train_dataset = self.train_data
            val_dataset = self.validation_data

        if len(train_dataset) == 0:
            raise ValueError("Cannot fine-tune: the training set is empty.")

        train_loader = self.prepare_data(train_dataset)
        # An empty validation loader would make the average below divide by zero.
        val_loader = (
            self.prepare_data(val_dataset, shuffle=False)
            if len(val_dataset) > 0
            else None
        )

        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.learning_rate,
        )
        loss_fn = torch.nn.L1Loss()
        best_model_wts = copy.deepcopy(self.model.state_dict())
        best_monitored_loss = float("inf")
        epochs_no_improve = 0

        for epoch in range(self.epochs):
            running_loss = 0.0
            self.model.train()
            for batch in train_loader:
                optimizer.zero_grad()
                loss = self._batch_loss(batch, loss_fn)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            avg_loss = running_loss / len(train_loader)

            self.model.eval()
            if val_loader is not None:
                val_loss = 0.0
                with torch.no_grad():
                    for batch in val_loader:
                        val_loss += self._batch_loss(batch, loss_fn).item()
                avg_val_loss = val_loss / len(val_loader)
                monitored_loss = avg_val_loss
                log.debug(
                    f"Epoch {epoch + 1}/{self.epochs}, Loss: {avg_loss:.4f}, Validation Loss: {avg_val_loss:.4f}"
                )
            else:
                monitored_loss = avg_loss
                log.debug(
                    f"Epoch {epoch + 1}/{self.epochs}, Loss: {avg_loss:.4f} "
                    "(no validation data; monitoring training loss)"
                )

            if monitored_loss < best_monitored_loss:
                best_monitored_loss = monitored_loss
                best_model_wts = copy.deepcopy(self.model.state_dict())
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= self.patience:
                    log.debug(f"Early stopping triggered {epoch + 1}")
                    break

        self.model.load_state_dict(best_model_wts)
        return self.model
default=False, optional + predict CCS values + n_epochs : int, default=20, optional + number of epochs for model fine-tuning (if applicable) + single_model_mode : bool, default=True, optional + use single model mode (if applicable) + """ + + """ Methods ------- @@ -175,29 +348,30 @@ class DeepLC: # TODO have a CCS flag here def __init__( - self, - main_path=os.path.dirname(os.path.realpath(__file__)), - path_model=None, - verbose=True, - bin_dist=2, - dict_cal_divider=50, - split_cal=50, - n_jobs=None, - config_file=None, - f_extractor=None, - cnn_model=True, - batch_num=250000, - batch_num_tf=1024, - write_library=False, - use_library=None, - reload_library=False, - pygam_calibration=True, - deepcallc_mod=False, - deeplc_retrain=False, - predict_ccs=False, - n_epochs=20, - single_model_mode=True, + self, + main_path=os.path.dirname(os.path.realpath(__file__)), + path_model=None, + verbose=True, + bin_dist=2, + dict_cal_divider=50, + split_cal=50, + n_jobs=None, + config_file=None, + f_extractor=None, + cnn_model=True, + batch_num=250000, + batch_num_tl=128, + write_library=False, + use_library=None, + reload_library=False, + pygam_calibration=True, + deepcallc_mod=False, + deeplc_retrain=False, + predict_ccs=False, + n_epochs=100, + single_model_mode=True, ): + # if a config file is defined overwrite standard parameters if config_file: cparser = ConfigParser() @@ -214,15 +388,15 @@ def __init__( self.calibrate_max = 0 self.n_epochs = n_epochs self.cnn_model = cnn_model - + self.model_cache = {} self.batch_num = batch_num - self.batch_num_tf = batch_num_tf + self.batch_num_tf = batch_num_tl self.dict_cal_divider = dict_cal_divider self.split_cal = split_cal self.n_jobs = n_jobs if self.n_jobs == None: - max_threads = multiprocessing.cpu_count() + max_threads = 1 self.n_jobs = max_threads self.use_library = use_library @@ -230,15 +404,8 @@ def __init__( self.reload_library = reload_library - try: - tf.config.threading.set_intra_op_parallelism_threads(n_jobs) - 
def _prepare_feature_matrices(self, psm_list):
    """
    Extract features and assemble the four model input arrays.

    Features are obtained from ``self.do_f_extraction_psm_list_parallel`` and
    the per-peptide entries of each feature group are stacked along a new
    leading axis.

    Parameters
    ----------
    psm_list : list of PSM
        Peptide-spectrum matches for which to extract features.

    Returns
    -------
    X : ndarray
        Stacked ``"matrix"`` features; axis 0 indexes peptides.
    X_sum : ndarray
        Stacked ``"matrix_sum"`` features; axis 0 indexes peptides.
    X_global : ndarray
        Stacked ``"matrix_all"`` and ``"pos_matrix"`` features, concatenated
        along axis 1.
    X_hc : ndarray
        Stacked ``"matrix_hc"`` features; axis 0 indexes peptides.

    Raises
    ------
    ValueError
        If no features were extracted (for example for an empty `psm_list`).
    """
    feats = self.do_f_extraction_psm_list_parallel(psm_list)

    def _stacked(key):
        # NOTE(review): presumably each group is a dict keyed by PSM index --
        # confirm against the feature extractor.
        per_peptide = list(feats[key].values())
        if not per_peptide:
            # np.stack would fail with an opaque "need at least one array".
            raise ValueError(
                f"No '{key}' features were extracted; is `psm_list` empty?"
            )
        return np.stack(per_peptide)

    X = _stacked("matrix")
    X_sum = _stacked("matrix_sum")
    X_global = np.concatenate(
        (_stacked("matrix_all"), _stacked("pos_matrix")),
        axis=1,
    )
    X_hc = _stacked("matrix_hc")
    return X, X_sum, X_global, X_hc
+ """ cal_preds = [] if len(uncal_preds) == 0: return np.array(cal_preds) @@ -503,10 +720,10 @@ def calibration_core(self, uncal_preds, cal_dict, cal_min, cal_max): # Replace predictions outside the range with the linear model predictions cal_preds[~within_range & (uncal_preds.ravel() < cal_min)] = y_pred_left[ ~within_range & (uncal_preds.ravel() < cal_min) - ] + ] cal_preds[~within_range & (uncal_preds.ravel() > cal_max)] = y_pred_right[ ~within_range & (uncal_preds.ravel() > cal_max) - ] + ] else: for uncal_pred in uncal_preds: try: @@ -527,6 +744,23 @@ def calibration_core(self, uncal_preds, cal_dict, cal_min, cal_max): return np.array(cal_preds) def make_preds_core_library(self, psm_list=[], calibrate=True, mod_name=None): + """ + Make predictions for sequences using a pre-computed library. + + Parameters + ---------- + psm_list : list of PSM, optional + A list of PSM objects for which predictions are to be made. + calibrate : bool, optional, default=True + Whether to calibrate the predictions or not. + mod_name : str, optional + The model name to use for prediction. + + Returns + ------- + np.array + The predicted retention times for the peptides. 
+ """ ret_preds = [] for psm in psm_list: ret_preds.append(LIBRARY[psm.peptidoform.proforma + "|" + mod_name]) @@ -594,31 +828,25 @@ def make_preds_core( if len(X) == 0 and len(psm_list) > 0: if self.verbose: logger.debug("Extracting features for the CNN model ...") - X = self.do_f_extraction_psm_list_parallel(psm_list) - - X_sum = np.stack(list(X["matrix_sum"].values())) - X_global = np.concatenate( - ( - np.stack(list(X["matrix_all"].values())), - np.stack(list(X["pos_matrix"].values())), - ), - axis=1, - ) - X_hc = np.stack(list(X["matrix_hc"].values())) - X = np.stack(list(X["matrix"].values())) + X, X_sum, X_global, X_hc = self._prepare_feature_matrices(psm_list) elif len(X) == 0 and len(psm_list) == 0: return [] + dataset = DeepLCDataset(X, X_sum, X_global, X_hc) + loader = DataLoader(dataset, batch_size=self.batch_num_tf, shuffle=False) + ret_preds = [] - mod = load_model(mod_name) + mod = torch.load(mod_name, weights_only=False, map_location=torch.device("cpu")) + mod.eval() try: - X - ret_preds = mod.predict( - [X, X_sum, X_global, X_hc], - batch_size=self.batch_num_tf, - verbose=int(self.verbose), - ).flatten() + with torch.no_grad(): + for batch in loader: + batch_X, batch_X_sum, batch_X_global, batch_X_hc = batch + batch_preds = mod(batch_X, batch_X_sum, batch_X_global, batch_X_hc) + ret_preds.append(batch_preds.detach().cpu().numpy()) + + ret_preds = np.concatenate(ret_preds, axis=0) except UnboundLocalError: logger.debug("X is empty, skipping...") ret_preds = [] @@ -717,17 +945,7 @@ def make_preds( if self.verbose: logger.debug("Extracting features for the CNN model ...") - X = self.do_f_extraction_psm_list_parallel(psm_list_t) - X_sum = np.stack(list(X["matrix_sum"].values())) - X_global = np.concatenate( - ( - np.stack(list(X["matrix_all"].values())), - np.stack(list(X["pos_matrix"].values())), - ), - axis=1, - ) - X_hc = np.stack(list(X["matrix_hc"].values())) - X = np.stack(list(X["matrix"].values())) + X, X_sum, X_global, X_hc = 
self._prepare_feature_matrices(psm_list_t) else: return [] @@ -1163,36 +1381,36 @@ def calibrate_preds( temp_pred = [] if self.deeplc_retrain: - # The following code is not required in most cases, but here it is used to clear variables that might cause problems - _ = tf.Variable([1]) + logger.debug("Preparing for model fine-tuning...") - context._context = None - context._create_context() + X, X_sum, X_global, X_hc = self._prepare_feature_matrices(psm_list) + dataset = DeepLCDataset(X, X_sum, X_global, X_hc, np.array(measured_tr)) - tf.config.threading.set_inter_op_parallelism_threads(1) + base_model_path = self.model[0] if isinstance(self.model, list) else self.model + base_model = torch.load(base_model_path, weights_only=False, map_location=torch.device("cpu")) + base_model.eval() - if len(location_retraining_models) > 0: - t_dir_models = TemporaryDirectory().name - os.mkdir(t_dir_models) + fine_tuner = DeepLCFineTuner( + model=base_model, + train_data=dataset, + batch_size=self.batch_num_tf, + epochs=self.n_epochs, + ) + # fine_tuner._freeze_layers() + fine_tuned_model = fine_tuner.fine_tune() + + if not location_retraining_models: + temp_dir_obj = TemporaryDirectory() + t_dir_models = temp_dir_obj.name + self._temp_dir_obj = temp_dir_obj else: t_dir_models = location_retraining_models - try: - os.mkdir(t_dir_models) - except: - pass - - # Here we will apply transfer learning we specify previously trained models in the 'mods_transfer_learning' - models = deeplcretrainer.retrain( - {"deeplc_transferlearn": psm_list}, - outpath=t_dir_models, - mods_transfer_learning=self.model, - freeze_layers=True, - n_epochs=self.n_epochs, - freeze_after_concat=1, - verbose=self.verbose, - ) + os.makedirs(t_dir_models, exist_ok=True) - self.model = models + # Define path to save fine-tuned model + fine_tuned_model_path = os.path.join(t_dir_models, "fine_tuned_model.pth") + torch.save(fine_tuned_model, fine_tuned_model_path) + self.model = [fine_tuned_model_path] if 
isinstance(sample_for_calibration_curve, int): psm_list = random.sample(list(psm_list), sample_for_calibration_curve) diff --git a/deeplc/mods/full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.keras b/deeplc/mods/full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.keras deleted file mode 100644 index 1a1c6af..0000000 Binary files a/deeplc/mods/full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.keras and /dev/null differ diff --git a/deeplc/mods/full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.pt b/deeplc/mods/full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.pt new file mode 100644 index 0000000..82399a4 Binary files /dev/null and b/deeplc/mods/full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.pt differ diff --git a/deeplc/mods/full_hc_PXD005573_pub_8c22d89667368f2f02ad996469ba157e.keras b/deeplc/mods/full_hc_PXD005573_pub_8c22d89667368f2f02ad996469ba157e.keras deleted file mode 100644 index 7016886..0000000 Binary files a/deeplc/mods/full_hc_PXD005573_pub_8c22d89667368f2f02ad996469ba157e.keras and /dev/null differ diff --git a/deeplc/mods/full_hc_PXD005573_pub_8c22d89667368f2f02ad996469ba157e.pt b/deeplc/mods/full_hc_PXD005573_pub_8c22d89667368f2f02ad996469ba157e.pt new file mode 100644 index 0000000..2d67b07 Binary files /dev/null and b/deeplc/mods/full_hc_PXD005573_pub_8c22d89667368f2f02ad996469ba157e.pt differ diff --git a/deeplc/mods/full_hc_PXD005573_pub_cb975cfdd4105f97efa0b3afffe075cc.keras b/deeplc/mods/full_hc_PXD005573_pub_cb975cfdd4105f97efa0b3afffe075cc.keras deleted file mode 100644 index e2d23ba..0000000 Binary files a/deeplc/mods/full_hc_PXD005573_pub_cb975cfdd4105f97efa0b3afffe075cc.keras and /dev/null differ diff --git a/deeplc/mods/full_hc_PXD005573_pub_cb975cfdd4105f97efa0b3afffe075cc.pt b/deeplc/mods/full_hc_PXD005573_pub_cb975cfdd4105f97efa0b3afffe075cc.pt new file mode 100644 index 0000000..292a21a Binary files /dev/null and b/deeplc/mods/full_hc_PXD005573_pub_cb975cfdd4105f97efa0b3afffe075cc.pt differ diff --git 
"""Convert the bundled Keras DeepLC models to PyTorch models via ONNX.

This is a maintenance script. The conversion needs ``tensorflow``,
``tf2onnx`` and ``onnx2torch``; they are imported lazily so that importing
this module does not require them.
"""

import logging
import os

logger = logging.getLogger(__name__)

deeplc_dir = os.path.dirname(os.path.realpath(__file__))
DEFAULT_MODELS = [
    "mods/full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.keras",
    "mods/full_hc_PXD005573_pub_8c22d89667368f2f02ad996469ba157e.keras",
    "mods/full_hc_PXD005573_pub_cb975cfdd4105f97efa0b3afffe075cc.keras",
]
DEFAULT_MODELS = [os.path.join(deeplc_dir, dm) for dm in DEFAULT_MODELS]


def _torch_model_path(keras_path):
    """Return `keras_path` with its file extension replaced by ``.pt``."""
    # Only the extension is swapped; a ".keras" elsewhere in the path (e.g. in
    # a directory name) is left untouched.
    root, _ = os.path.splitext(keras_path)
    return root + ".pt"


def _convert_to_onnx(model_paths=None):
    """
    Convert Keras models to PyTorch models (through ONNX) and save them.

    Each existing model is loaded with Keras, exported to ONNX (opset 13) with
    a fixed four-input signature, converted with ``onnx2torch`` and written
    next to the source file with a ``.pt`` extension.

    Parameters
    ----------
    model_paths : iterable of str, optional
        Paths of the Keras models to convert. Defaults to ``DEFAULT_MODELS``.

    Returns
    -------
    list of str
        Paths of the ``.pt`` files that were written. Missing input files are
        skipped with a warning.
    """
    candidates = DEFAULT_MODELS if model_paths is None else list(model_paths)

    existing = []
    for model_path in candidates:
        if os.path.exists(model_path):
            existing.append(model_path)
        else:
            logger.warning("Skipping missing model file: %s", model_path)
    if not existing:
        return []

    # Heavy, conversion-only dependencies: import only when there is work.
    import tensorflow as tf
    import tf2onnx
    import torch
    from onnx2torch import convert
    from tensorflow.keras.models import load_model

    # Input signature used for the ONNX export; identical for every model.
    spec = [
        tf.TensorSpec([None, 60, 6], tf.float32, name="input_1"),
        tf.TensorSpec([None, 30, 6], tf.float32, name="input_2"),
        tf.TensorSpec([None, 55], tf.float32, name="input_3"),
        tf.TensorSpec([None, 60, 20], tf.float32, name="input_4"),
    ]

    written = []
    for model_path in existing:
        mod = load_model(model_path)
        onnx_model, _ = tf2onnx.convert.from_keras(mod, input_signature=spec, opset=13)
        torch_model = convert(onnx_model)
        out_path = _torch_model_path(model_path)
        # NOTE: this pickles the whole module object, so loading it later
        # requires torch.load(..., weights_only=False) on a trusted file.
        torch.save(torch_model, out_path)
        written.append(out_path)
    return written


if __name__ == "__main__":
    _convert_to_onnx()