Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 75 additions & 7 deletions src/polygraph/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,24 +279,92 @@ def ism_score(model, seqs, batch_size, device="cpu", task=None):

assert check_equal_lens(seqs)

# Predictions on original sequences
preds = predict(seqs=seqs, model=model, batch_size=batch_size, device=device)
assert preds.ndim < 3

# Select relevant task/cell type, or average predictions
if task is None:
if preds.ndim == 2:
preds = preds.mean(1, keepdims=True)
else:
preds = preds[:, [task]]

# Mutate sequences
ism = ISM(seqs) # N x L x 4

# Make predictions on mutated sequences
preds = predict(seqs=ism, model=model, batch_size=batch_size, device=device)
assert preds.ndim < 3
ism_preds = predict(seqs=ism, model=model, batch_size=batch_size, device=device)

# Select relevant task/cell type, or average predictions
if task is None:
if preds.ndim == 2:
preds = preds.mean(1)
if ism_preds.ndim == 2:
ism_preds = ism_preds.mean(1)
else:
preds = preds[:, task]
ism_preds = ism_preds[:, task]

# Reshape predictions : N, L, 4
preds = preds.reshape(len(seqs), len(ism) // (len(seqs) * 4), 4)
ism_preds = ism_preds.reshape(len(seqs), len(ism) // (len(seqs) * 4), 4)

# Compute base-level importance score
preds = np.log2(preds / preds.mean(-1, keepdims=True))
preds = np.log2(ism_preds / preds)
preds = np.abs(preds).max(-1)
return preds


def robustness(model, seqs, batch_size, device="cpu", task=None, aggfunc="mean"):
    """
    Get robustness scores for given sequence(s) using ISM

    Every single-base substitution of every sequence is scored with the model,
    and each mutant prediction is compared to the prediction for its original
    sequence as ``abs(mutant / original - 1)``. These relative changes are then
    aggregated per sequence.

    Args:
        model (nn.Sequential): trained model
        seqs (list, pd.DataFrame): List of sequences or dataframe
            containing sequences in the column "Sequence". All sequences
            must have the same length.
        batch_size (int): Batch size for inference
        device (str, int): ID of GPU to perform inference.
        task (int, optional): Index of the model output (column of the
            prediction matrix) to score. If None, predictions are averaged
            over all outputs.
        aggfunc (str): Either 'mean' or 'max'. Determines how to aggregate the
            effect of all possible single-base mutations.

    Returns:
        One robustness score per input sequence, i.e. an array of shape
        (n_seqs,). Presumably a NumPy array, since the predictions are
        processed with NumPy functions -- confirm against ``predict``.

    Raises:
        ValueError: If ``aggfunc`` is neither 'mean' nor 'max'.
    """
    # Fail fast: reject a bad aggfunc before any import or inference work,
    # instead of silently returning None after the full ISM run.
    if aggfunc not in ("mean", "max"):
        raise ValueError(f"aggfunc must be 'mean' or 'max', got {aggfunc!r}")

    from polygraph.sequence import ISM
    from polygraph.utils import check_equal_lens

    assert check_equal_lens(seqs)

    # Predictions on original sequences
    preds = predict(seqs=seqs, model=model, batch_size=batch_size, device=device)
    assert preds.ndim < 3

    # Select relevant task/cell type, or average predictions.
    # Either way the result is kept as a column (N x 1) so that it
    # broadcasts row-wise against the N x (L*3) mutant predictions below.
    if task is None:
        if preds.ndim == 2:
            preds = preds.mean(1, keepdims=True)
        else:
            # Single-output model: a flat (N,) vector would broadcast along
            # the wrong axis, so promote it to a column.
            preds = preds[:, None]
    else:
        preds = preds[:, [task]]

    # Mutate sequences
    ism = ISM(seqs, drop_ref=True)  # N x L x 3

    # Make predictions on mutated sequences
    ism_preds = predict(seqs=ism, model=model, batch_size=batch_size, device=device)

    # Select relevant task/cell type, or average predictions
    if task is None:
        if ism_preds.ndim == 2:
            ism_preds = ism_preds.mean(1)
    else:
        ism_preds = ism_preds[:, task]

    # Reshape predictions : N, Lx3
    ism_preds = ism_preds.reshape(len(seqs), len(ism) // len(seqs))

    # Compare mutated sequences to originals (relative change in prediction).
    # NOTE: an original prediction of exactly 0 yields inf/nan here.
    deltas = np.abs((ism_preds / preds) - 1)

    # Aggregate over all possible mutations
    if aggfunc == "mean":
        return np.mean(deltas, 1)
    return np.max(deltas, 1)
34 changes: 24 additions & 10 deletions src/polygraph/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,36 +261,50 @@ def fastsk(seqs, k=5, m=2):
return np.array(kernel.get_train_kernel())


def ISM(seqs, drop_ref=False):
    """
    Perform in-silico mutagenesis on given DNA sequence(s)

    Args:
        seqs (str, list, pd.DataFrame): A DNA sequence, list of sequences
            or dataframe containing sequences in the column "Sequence".
        drop_ref (bool): If True, do not return the original sequence, i.e.
            at each position skip the substitution that equals the base
            already present there.

    Returns:
        (list): A list of all possible single-base mutated sequences
        derived from the original sequences, ordered by input sequence,
        then by position, then by base in STANDARD_BASES order. Empty
        input (an empty string or an empty list) yields an empty list.

    Raises:
        TypeError: If seqs is not a string, list or dataframe.
    """
    # ISM for a single sequence
    if isinstance(seqs, str):
        return [
            seqs[:pos] + base + seqs[pos + 1 :]
            for pos, ref in enumerate(seqs)
            for base in STANDARD_BASES
            # With drop_ref, omit the "mutation" that reproduces the input
            if not (drop_ref and base == ref)
        ]

    # Multiple sequences: flatten the per-sequence mutant lists in order
    elif isinstance(seqs, list):
        return [
            mutant for seq in seqs for mutant in ISM(seq, drop_ref=drop_ref)
        ]

    # For a dataframe, mutate the contents of its "Sequence" column
    elif isinstance(seqs, pd.DataFrame):
        return ISM(seqs.Sequence.tolist(), drop_ref=drop_ref)

    else:
        raise TypeError("seqs must be a string, list or dataframe.")
Loading