From fb325a61736c5bfb0971ceaa5af85c8052d44b83 Mon Sep 17 00:00:00 2001 From: lala8 Date: Wed, 20 Nov 2024 06:37:57 +0000 Subject: [PATCH] added robustness analysis --- src/polygraph/models.py | 82 +++++++++++++++++++++++++++++++++++---- src/polygraph/sequence.py | 34 +++++++++++----- 2 files changed, 99 insertions(+), 17 deletions(-) diff --git a/src/polygraph/models.py b/src/polygraph/models.py index f8f9da3..09610e7 100644 --- a/src/polygraph/models.py +++ b/src/polygraph/models.py @@ -279,24 +279,92 @@ def ism_score(model, seqs, batch_size, device="cpu", task=None): assert check_equal_lens(seqs) + # Predictions on original sequences + preds = predict(seqs=seqs, model=model, batch_size=batch_size, device=device) + assert preds.ndim < 3 + + # Select relevant task/cell type, or average predictions + if task is None: + if preds.ndim == 2: + preds = preds.mean(1, keepdims=True) + else: + preds = preds[:, [task]] + # Mutate sequences ism = ISM(seqs) # N x L x 4 # Make predictions on mutated sequences - preds = predict(seqs=ism, model=model, batch_size=batch_size, device=device) - assert preds.ndim < 3 + ism_preds = predict(seqs=ism, model=model, batch_size=batch_size, device=device) # Select relevant task/cell type, or average predictions if task is None: - if preds.ndim == 2: - preds = preds.mean(1) + if ism_preds.ndim == 2: + ism_preds = ism_preds.mean(1) else: - preds = preds[:, task] + ism_preds = ism_preds[:, task] # Reshape predictions : N, L, 4 - preds = preds.reshape(len(seqs), len(ism) // (len(seqs) * 4), 4) + ism_preds = ism_preds.reshape(len(seqs), len(ism) // (len(seqs) * 4), 4) # Compute base-level importance score - preds = np.log2(preds / preds.mean(-1, keepdims=True)) + preds = np.log2(ism_preds / preds) preds = np.abs(preds).max(-1) return preds + + +def robustness(model, seqs, batch_size, device="cpu", task=None, aggfunc="mean"): + """ + Get robustness scores for given sequence(s) using ISM + + Args: + seqs (list, pd.DataFrame): List of sequences or dataframe + containing sequences in the column "Sequence". + model (nn.Sequential): trained model + batch_size (int): Batch size for inference + device (str, int): ID of GPU to perform inference. + aggfunc (str): Either 'mean' or 'max'. Determines how to aggregate the + effect of all possible single-base mutations. + + Returns: + (pd.DataFrame): DataFrame of shape (n_seqs x n_outputs) + """ + from polygraph.sequence import ISM + from polygraph.utils import check_equal_lens + + assert check_equal_lens(seqs) + + # Predictions on original sequences + preds = predict(seqs=seqs, model=model, batch_size=batch_size, device=device) + assert preds.ndim < 3 + + # Select relevant task/cell type, or average predictions + if task is None: + if preds.ndim == 2: + preds = preds.mean(1, keepdims=True) + else: + preds = preds[:, [task]] + + # Mutate sequences + ism = ISM(seqs, drop_ref=True) # N x L x 3 + + # Make predictions on mutated sequences + ism_preds = predict(seqs=ism, model=model, batch_size=batch_size, device=device) + + # Select relevant task/cell type, or average predictions + if task is None: + if ism_preds.ndim == 2: + ism_preds = ism_preds.mean(1) + else: + ism_preds = ism_preds[:, task] + + # Reshape predictions : N, Lx3 + ism_preds = ism_preds.reshape(len(seqs), len(ism) // len(seqs)) + + # Compare mutated sequences to originals + deltas = np.abs((ism_preds / preds) - 1) + + # Aggregate over all possible mutations + if aggfunc == "mean": + return np.mean(deltas, 1) + elif aggfunc == "max": + return np.max(deltas, 1) diff --git a/src/polygraph/sequence.py b/src/polygraph/sequence.py index 3579909..c2366dd 100644 --- a/src/polygraph/sequence.py +++ b/src/polygraph/sequence.py @@ -261,13 +261,14 @@ def fastsk(seqs, k=5, m=2): return np.array(kernel.get_train_kernel()) -def ISM(seqs): +def ISM(seqs, drop_ref=False): """ Perform in-silico mutagenesis on given DNA sequence(s) Args: seqs (str, list, pd.DataFrame): A DNA sequence, list of sequences or dataframe containing sequences in the column "Sequence". + drop_ref (bool): If True, do not return the original sequence. Returns: (list): A list of all possible single-base mutated sequences @@ -275,22 +276,35 @@ def ISM(seqs): """ # ISM for a single sequence if isinstance(seqs, str): - return list( - np.concatenate( - [ - [seqs[:pos] + base + seqs[pos + 1 :] for base in STANDARD_BASES] - for pos in range(len(seqs)) - ] + if drop_ref: + return list( + np.concatenate( + [ + [ + seqs[:pos] + base + seqs[pos + 1 :] + for base in [x for x in STANDARD_BASES if x != b] + ] + for pos, b in enumerate(seqs) + ] + ) + ) + else: + return list( + np.concatenate( + [ + [seqs[:pos] + base + seqs[pos + 1 :] for base in STANDARD_BASES] + for pos in range(len(seqs)) + ] + ) ) - ) # Multiple sequences elif isinstance(seqs, list): - return list(np.concatenate([ISM(seq) for seq in seqs])) + return list(np.concatenate([ISM(seq, drop_ref=drop_ref) for seq in seqs])) # For a dataframe, copy the index elif isinstance(seqs, pd.DataFrame): - return ISM(seqs.Sequence.tolist()) + return ISM(seqs.Sequence.tolist(), drop_ref=drop_ref) else: raise TypeError("seqs must be a string, list or dataframe.")