diff --git a/setup.cfg b/setup.cfg index a009334..dca5be2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -64,6 +64,7 @@ install_requires = hotelling # fastsk upsetplot + geosketch [options.packages.find] diff --git a/src/polygraph/embedding.py b/src/polygraph/embedding.py index f34b136..99b779e 100644 --- a/src/polygraph/embedding.py +++ b/src/polygraph/embedding.py @@ -414,6 +414,44 @@ def distribution_shift(ad, reference_group, group_col="Group", use_pca=False): return ad +def geometric_sketch(ad, N, groups=None, group_col="Group", use_pca=True): + """ + Applies geometric sketching (Hie, Brian et al. Cell Systems, Volume 8, + Issue 6, 483 - 493.e7) to sample a subset of sequences that represent + the diversity in the specified groups. + + Args: + ad (anndata.AnnData): Anndata object containing sequence embeddings + of shape (n_seqs x n_vars) + N (int): Number of sequences to sample from each group + groups (list): Names of groups from which to sample sequences. If None, + all groups are used. + group_col (str): Name of column in .obs containing group ID + use_pca (bool): Whether to use PCA distances + + Returns: + ad (anndata.AnnData): Modified anndata object containing selections in + ad.obs['selected']. + """ + from geosketch import gs + + ad.obs["selected"] = False + groups = groups or ad.obs[group_col].unique() + + for group in groups: + in_group = ad.obs[group_col] == group + group_idx = ad.obs_names[in_group].tolist() + if use_pca: + group_X = ad.obsm["X_pca"][in_group, :] + else: + group_X = ad.X[in_group, :] + + sketch_index = gs(group_X, N=N, replace=False) + ad.obs.loc[[group_idx[x] for x in sketch_index], "selected"] = True + + return ad + + def embedding_analysis( matrix, seqs,