Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ install_requires =
scikit_posthocs
scipy
plotnine
hotelling
# fastsk
upsetplot


[options.packages.find]
Expand Down
61 changes: 56 additions & 5 deletions src/polygraph/embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
import scanpy as sc
from hotelling.stats import hotelling_t2
from scipy.stats import fisher_exact
from sklearn.metrics import pairwise_distances
from sklearn.neighbors import NearestNeighbors
Expand Down Expand Up @@ -89,7 +90,7 @@ def differential_analysis(ad, reference_group, group_col="Group"):
return ad


def groupwise_1nn(ad, reference_group, group_col="Group", use_pca=False):
def reference_1nn(ad, reference_group, group_col="Group", use_pca=False):
"""
For each sequence, find its nearest neighbor among its own group or
the reference group based on the sequence embeddings.
Expand Down Expand Up @@ -175,7 +176,7 @@ def groupwise_1nn(ad, reference_group, group_col="Group", use_pca=False):
return ad


def joint_1nn(ad, reference_group, group_col="Group", use_pca=False):
def all_1nn(ad, reference_group, group_col="Group", use_pca=False):
"""
Find the group ID of each sequence's 1-nearest neighbor statistics based on the
sequence embeddings. Compare all groups to all other groups.
Expand Down Expand Up @@ -234,7 +235,7 @@ def joint_1nn(ad, reference_group, group_col="Group", use_pca=False):
return ad


def within_group_knn_dist(ad, n_neighbors=10, group_col="Group", use_pca=False):
def group_diversity(ad, n_neighbors=10, group_col="Group", use_pca=False):
"""
Calculates the mean distance of each sequence to its k nearest neighbors in the
same group, in the embedding space. Metric of diversity
Expand Down Expand Up @@ -366,6 +367,53 @@ def dist_to_reference(ad, reference_group, group_col="Group", use_pca=False):
return ad


def distribution_shift(ad, reference_group, group_col="Group", use_pca=False):
    """
    Compare the distribution of sequences in each group to the distribution
    of reference sequences, in the embedding space. Performs Hotelling's T2
    test to compare multivariate distributions.

    Args:
        ad (anndata.AnnData): Anndata object containing sequence embeddings
            of shape (n_seqs x n_vars)
        reference_group (str): ID of group to use as reference
        group_col (str): Name of column in .obs containing group ID
        use_pca (bool): Whether to run the test on the PCA coordinates in
            .obsm['X_pca'] instead of the embeddings in .X

    Returns:
        ad (anndata.AnnData): Modified anndata object containing the test
            results in .uns['dist_shift_test'], a dataframe indexed by
            group ID with columns 't2_stat', 'fval', 'pval' and 'padj'.
    """
    # Select the embedding matrix once; the choice is the same for all groups
    X = ad.obsm["X_pca"] if use_pca else ad.X

    # Get reference sequences
    in_ref = ad.obs[group_col] == reference_group
    ref_X = X[in_ref, :]

    # NOTE(review): the reference group is itself part of the listed groups,
    # so it is tested against itself and counted in the multiple-testing
    # correction below -- confirm this is intended.
    rows = []
    for group in ad.obs[group_col].unique():
        # Get group sequences
        in_group = ad.obs[group_col] == group
        group_X = X[in_group, :]

        # Perform Hotelling's T2 test to compare to the reference. The last
        # returned element is dropped; presumably it is the pooled covariance
        # matrix, leaving (T2 statistic, F value, p-value) -- verify against
        # the hotelling package documentation.
        test_result = hotelling_t2(group_X, ref_X)
        rows.append([group] + list(test_result[:-1]))

    # Format dataframe
    res = pd.DataFrame(rows, columns=[group_col, "t2_stat", "fval", "pval"])
    # Second element returned by fdrcorrection is stored as the adjusted p-value
    res["padj"] = fdrcorrection(res.pval)[1]
    ad.uns["dist_shift_test"] = res.set_index(group_col)
    return ad


def embedding_analysis(
matrix,
seqs,
Expand Down Expand Up @@ -425,14 +473,17 @@ def embedding_analysis(
ad = differential_analysis(ad, reference_group, group_col)

print("1-NN statistics")
ad = groupwise_1nn(ad, reference_group, group_col, use_pca=use_pca)
ad = reference_1nn(ad, reference_group, group_col, use_pca=use_pca)

print("Within-group KNN diversity")
ad = within_group_knn_dist(ad, n_neighbors, group_col, use_pca=use_pca)
ad = group_diversity(ad, n_neighbors, group_col, use_pca=use_pca)

print("Euclidean distance to nearest reference")
ad = dist_to_reference(ad, reference_group, group_col, use_pca=use_pca)

print("Distribution shift")
ad = distribution_shift(ad, reference_group, group_col, use_pca=use_pca)

print("Train groupwise classifiers")
ad = groupwise_svm(
ad,
Expand Down
6,622 changes: 5,491 additions & 1,131 deletions tutorials/1_yeast_tutorial.ipynb

Large diffs are not rendered by default.

Loading