From 31ec7bab25c00d1007393211908d55d7d6e719b2 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Wed, 23 Mar 2022 23:25:00 +0000 Subject: [PATCH 1/3] Posterior quantile detects minibatch plate vars --- cell2location/models/base/_pyro_mixin.py | 96 +++++++++++++----------- 1 file changed, 53 insertions(+), 43 deletions(-) diff --git a/cell2location/models/base/_pyro_mixin.py b/cell2location/models/base/_pyro_mixin.py index 45e194e2..56797e84 100755 --- a/cell2location/models/base/_pyro_mixin.py +++ b/cell2location/models/base/_pyro_mixin.py @@ -167,7 +167,9 @@ def optim_param(module_name, param_name): return optim_param @torch.no_grad() - def _posterior_quantile_amortised(self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = None): + def _posterior_quantile_minibatch( + self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = None, use_median: bool = False + ): """ Compute median of the posterior distribution of each parameter, separating local (minibatch) variable and global variables, which is necessary when performing amortised inference. @@ -183,10 +185,12 @@ def _posterior_quantile_amortised(self, q: float = 0.5, batch_size: int = 2048, number of observations per batch use_gpu Bool, use gpu? 
+ use_median + Bool, when q=0.5 use median rather than quantile method of the guide Returns ------- - dictionary {variable_name: posterior median} + dictionary {variable_name: posterior quantile} """ @@ -206,35 +210,27 @@ def _posterior_quantile_amortised(self, q: float = 0.5, batch_size: int = 2048, self.to_device(device) if i == 0: - - means = self.module.guide.quantiles([q], *args, **kwargs) - means = { - k: means[k].cpu().numpy() - for k in means.keys() - if k in self.module.model.list_obs_plate_vars()["sites"] - } - + # find plate sites + obs_plate_sites = self._get_obs_plate_sites(args, kwargs, return_observed=True) + if len(obs_plate_sites) == 0: + # if no local variables - don't sample + break # find plate dimension - trace = poutine.trace(self.module.model).get_trace(*args, **kwargs) - # print(trace.nodes[self.module.model.list_obs_plate_vars()['name']]) - obs_plate = { - name: site["cond_indep_stack"][0].dim - for name, site in trace.nodes.items() - if site["type"] == "sample" - if any(f.name == self.module.model.list_obs_plate_vars()["name"] for f in site["cond_indep_stack"]) - } + obs_plate_dim = list(obs_plate_sites.values())[0] + if use_median and q == 0.5: + means = self.module.guide.median(*args, **kwargs) + else: + means = self.module.guide.quantiles([q], *args, **kwargs) + means = {k: means[k].cpu().numpy() for k in means.keys() if k in obs_plate_sites} else: + if use_median and q == 0.5: + means_ = self.module.guide.median(*args, **kwargs) + else: + means_ = self.module.guide.quantiles([q], *args, **kwargs) - means_ = self.module.guide.quantiles([q], *args, **kwargs) - means_ = { - k: means_[k].cpu().numpy() - for k in means_.keys() - if k in list(self.module.model.list_obs_plate_vars()["sites"].keys()) - } - means = { - k: np.concatenate([means[k], means_[k]], axis=list(obs_plate.values())[0]) for k in means.keys() - } + means_ = {k: means_[k].cpu().numpy() for k in means_.keys() if k in obs_plate_sites} + means = {k: np.concatenate([means[k], 
means_[k]], axis=obs_plate_dim) for k in means.keys()} i += 1 # sample global parameters @@ -244,12 +240,11 @@ def _posterior_quantile_amortised(self, q: float = 0.5, batch_size: int = 2048, kwargs = {k: v.to(device) for k, v in kwargs.items()} self.to_device(device) - global_means = self.module.guide.quantiles([q], *args, **kwargs) - global_means = { - k: global_means[k].cpu().numpy() - for k in global_means.keys() - if k not in list(self.module.model.list_obs_plate_vars()["sites"].keys()) - } + if use_median and q == 0.5: + global_means = self.module.guide.median(*args, **kwargs) + else: + global_means = self.module.guide.quantiles([q], *args, **kwargs) + global_means = {k: global_means[k].cpu().numpy() for k in global_means.keys() if k not in obs_plate_sites} for k in global_means.keys(): means[k] = global_means[k] @@ -259,26 +254,31 @@ def _posterior_quantile_amortised(self, q: float = 0.5, batch_size: int = 2048, return means @torch.no_grad() - def _posterior_quantile(self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = None): + def _posterior_quantile( + self, q: float = 0.5, batch_size: int = None, use_gpu: bool = None, use_median: bool = False + ): """ Compute median of the posterior distribution of each parameter pyro models trained without amortised inference. Parameters ---------- q - quantile to compute + Quantile to compute use_gpu Bool, use gpu? 
+ use_median + Bool, when q=0.5 use median rather than quantile method of the guide Returns ------- - dictionary {variable_name: posterior median} + dictionary {variable_name: posterior quantile} """ self.module.eval() gpus, device = parse_use_gpu_arg(use_gpu) - + if batch_size is None: + batch_size = self.adata.n_obs train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size) # sample global parameters tensor_dict = next(iter(train_dl)) @@ -287,30 +287,40 @@ def _posterior_quantile(self, q: float = 0.5, batch_size: int = 2048, use_gpu: b kwargs = {k: v.to(device) for k, v in kwargs.items()} self.to_device(device) - means = self.module.guide.quantiles([q], *args, **kwargs) + if use_median and q == 0.5: + means = self.module.guide.median(*args, **kwargs) + else: + means = self.module.guide.quantiles([q], *args, **kwargs) means = {k: means[k].cpu().detach().numpy() for k in means.keys()} return means - def posterior_quantile(self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = None): + def posterior_quantile( + self, q: float = 0.5, batch_size: int = 2048, use_gpu: bool = None, use_median: bool = False + ): """ Compute median of the posterior distribution of each parameter. Parameters ---------- q - quantile to compute + Quantile to compute use_gpu + Bool, use gpu? 
+ use_median + Bool, when q=0.5 use median rather than quantile method of the guide Returns ------- """ - if self.module.is_amortised: - return self._posterior_quantile_amortised(q=q, batch_size=batch_size, use_gpu=use_gpu) + if batch_size is not None: + return self._posterior_quantile_minibatch( + q=q, batch_size=batch_size, use_gpu=use_gpu, use_median=use_median + ) else: - return self._posterior_quantile(q=q, batch_size=batch_size, use_gpu=use_gpu) + return self._posterior_quantile(q=q, batch_size=batch_size, use_gpu=use_gpu, use_median=use_median) class PltExportMixin: From 66f1bbf2cfd618af6ecd3cd61300425e0691a292 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Wed, 23 Mar 2022 23:34:08 +0000 Subject: [PATCH 2/3] Update _pyro_mixin.py --- cell2location/models/base/_pyro_mixin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cell2location/models/base/_pyro_mixin.py b/cell2location/models/base/_pyro_mixin.py index 56797e84..00d3e817 100755 --- a/cell2location/models/base/_pyro_mixin.py +++ b/cell2location/models/base/_pyro_mixin.py @@ -7,7 +7,6 @@ import pandas as pd import pyro import torch -from pyro import poutine from pyro.infer.autoguide import AutoNormal, init_to_mean from scipy.sparse import issparse from scvi import _CONSTANTS From b9238d57376ad36d1c6fa9118969e22b3b1c7a8e Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Wed, 23 Mar 2022 23:43:49 +0000 Subject: [PATCH 3/3] added test --- tests/test_cell2location.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_cell2location.py b/tests/test_cell2location.py index 02e4f587..213ce829 100644 --- a/tests/test_cell2location.py +++ b/tests/test_cell2location.py @@ -63,6 +63,8 @@ def test_cell2location(): dataset = st_model.export_posterior(dataset, sample_kwargs={"num_samples": 10, "batch_size": 50}) # test computing any quantile of the posterior distribution st_model.posterior_quantile(q=0.5) + quant = st_model.posterior_quantile(q=0.5, batch_size=50, use_median=True) + 
assert quant['w_sf'].shape[0] == dataset.n_obs # test computing expected expression per cell type st_model.module.model.compute_expected_per_cell_type(st_model.samples["post_sample_q05"], st_model.adata) ### test amortised inference with default cell2location model ###