diff --git a/cell2location/models/_cell2location_module.py b/cell2location/models/_cell2location_module.py index 5c22da79..859ec465 100755 --- a/cell2location/models/_cell2location_module.py +++ b/cell2location/models/_cell2location_module.py @@ -98,6 +98,7 @@ def __init__( init_vals: Optional[dict] = None, init_alpha=20.0, dropout_p=0.0, + location_sampling: bool = False, ): super().__init__() @@ -122,6 +123,8 @@ def __init__( if self.dropout_p is not None: self.dropout = torch.nn.Dropout(p=self.dropout_p) + self.location_sampling = location_sampling + if (init_vals is not None) & (type(init_vals) is dict): self.np_init_vals = init_vals for k in init_vals.keys(): @@ -192,6 +195,9 @@ def __init__( self.register_buffer("n_groups_tensor", torch.tensor(self.n_groups)) self.register_buffer("ones", torch.ones((1, 1))) + self.register_buffer("zeros", torch.zeros((1, 1))) + self.register_buffer("five", torch.tensor(5.0)) + self.register_buffer("ten", torch.tensor(10.0)) self.register_buffer("ones_1_n_groups", torch.ones((1, self.n_groups))) self.register_buffer("ones_n_batch_1", torch.ones((self.n_batch, 1))) self.register_buffer("eps", torch.tensor(1e-8)) @@ -460,12 +466,47 @@ def forward(self, x_data, idx, batch_index): # =====================Expected expression ======================= # if not self.training_wo_observed: # expected expression - mu = ((w_sf @ self.cell_state) * m_g + (obs2sample @ s_g_gene_add)) * detection_y_s - alpha = obs2sample @ (self.ones / alpha_g_inverse.pow(2)) - # convert mean and overdispersion to total count and logits - # total_count, logits = _convert_mean_disp_to_counts_logits( - # mu, alpha, eps=self.eps - # ) + if not self.location_sampling: + mu = ((w_sf @ self.cell_state) * m_g + (obs2sample @ s_g_gene_add)) * detection_y_s + alpha = obs2sample @ (self.ones / alpha_g_inverse.pow(2)) + else: + # sampling location-specific cell state signatures ==== + cell_state_sigma = pyro.sample( + "cell_state_sigma", + dist.Exponential(self.ten * self.ten).expand([1, 1, 1]).to_event(2), + ) # (1, 1).squeeze() + cell_state = torch.exp( + torch.log(self.cell_state.unsqueeze(-3)) + + dist.Normal(self.zeros, self.ones).sample([w_sf.shape[0], self.n_factors, self.n_vars]).squeeze() + * cell_state_sigma + ) + biol_mu = torch.einsum("sf,sfg->sg", w_sf, cell_state) + + # sampling location-specific background counts ==== + s_g_gene_add_sigma = pyro.sample( + "s_g_gene_add_sigma", + dist.Exponential(self.ten * self.ten).expand([self.n_batch, 1]).to_event(2), + ) # (1, 1) + s_g_gene_add_sigma = (obs2sample @ s_g_gene_add_sigma).unsqueeze(-1) + s_g_gene_add = torch.exp( + torch.log(s_g_gene_add.unsqueeze(-3)) + + dist.Normal(self.zeros, self.ones).sample([w_sf.shape[0], self.n_batch, self.n_vars]).squeeze() + * s_g_gene_add_sigma + ) + background_mu = torch.einsum("se,seg->sg", obs2sample, s_g_gene_add) + + # sampling location-specific technology sensitivity ==== + if False: + m_g_alpha = pyro.sample( + "m_g_alpha", + dist.Exponential(self.ten).expand([self.n_batch, 1]).to_event(2), + ) # (1, 1) + m_g_alpha = obs2sample @ (self.ones / m_g_alpha.pow(2)) + m_g = dist.Gamma(m_g_alpha, m_g_alpha / m_g).sample([1]).squeeze(-3) + + # compute expected value ==== + mu = (biol_mu * m_g + background_mu) * detection_y_s + alpha = obs2sample @ (self.ones / alpha_g_inverse.pow(2)) # =====================DATA likelihood ======================= # # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial @@ -481,8 +522,12 @@ def forward(self, x_data, idx, batch_index): # =====================Compute mRNA count from each factor in locations ======================= # with obs_plate: - mRNA = w_sf * (self.cell_state * m_g).sum(-1) - pyro.deterministic("u_sf_mRNA_factors", mRNA) + if not self.location_sampling: + mRNA = w_sf * (self.cell_state * m_g).sum(-1) + pyro.deterministic("u_sf_mRNA_factors", mRNA) + else: + mRNA = torch.einsum("sf,sfg,sg->sf", w_sf, cell_state, m_g) + pyro.deterministic("u_sf_mRNA_factors", mRNA) def compute_expected(self, samples, adata_manager, ind_x=None): r"""Compute expected expression of each gene in each location. Useful for evaluating how well