From b18d422a414d27fe1079cecde62dbb5f923c8258 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Wed, 18 May 2022 23:48:49 +0100 Subject: [PATCH 1/3] added output specific NN & necessary changes to AutoGuide --- .../AutoAmortisedNormalMessenger.py | 111 ++++--- cell2location/models/_cell2location_module.py | 9 +- cell2location/nn/OutputSpecificNN.py | 309 ++++++++++++++++++ 3 files changed, 383 insertions(+), 46 deletions(-) create mode 100755 cell2location/nn/OutputSpecificNN.py diff --git a/cell2location/distributions/AutoAmortisedNormalMessenger.py b/cell2location/distributions/AutoAmortisedNormalMessenger.py index e9a0f2da..08c2846c 100755 --- a/cell2location/distributions/AutoAmortisedNormalMessenger.py +++ b/cell2location/distributions/AutoAmortisedNormalMessenger.py @@ -261,13 +261,14 @@ def encode(self, name: str, prior: Distribution): to_pyro_module_(one_encoder) deep_setattr(self, "one_encoder", one_encoder) else: + encoder_kwargs = deepcopy(self.encoder_kwargs) + if "n_out" not in encoder_kwargs.keys(): + encoder_kwargs["n_out"] = self.n_hidden["single"] # create encoder instance from encoder class deep_setattr( self, "one_encoder", - self.encoder_class(n_in=self.single_n_in, n_out=self.n_hidden["single"], **self.encoder_kwargs).to( - prior.mean.device - ), + self.encoder_class(n_in=self.single_n_in, **encoder_kwargs).to(prior.mean.device), ) if "multiple" in self.encoder_mode: # determine the number of hidden layers @@ -277,7 +278,14 @@ def encode(self, name: str, prior: Distribution): n_hidden = self.n_hidden["multiple"] multi_encoder_kwargs = deepcopy(self.multi_encoder_kwargs) multi_encoder_kwargs["n_hidden"] = n_hidden - + if "n_out" not in multi_encoder_kwargs.keys(): + multi_encoder_kwargs["n_out"] = n_hidden + elif isinstance(self.multi_encoder_kwargs["n_out"], str): + multi_encoder_kwargs["n_out"] = self.amortised_plate_sites["sites"][name] + if (self._hierarchical_sites is None) or (name in self._hierarchical_sites): + multi_encoder_kwargs["n_out_extra"] = 3 + else: + multi_encoder_kwargs["n_out_extra"] = 2 # create multiple encoders if self.encoder_instance is not None: # copy instances @@ -294,9 +302,7 @@ def encode(self, name: str, prior: Distribution): deep_setattr( self, "multiple_encoders." 
+ name, - self.encoder_class(n_in=self.multiple_n_in, n_out=n_hidden, **multi_encoder_kwargs).to( - prior.mean.device - ), + self.encoder_class(n_in=self.multiple_n_in, **multi_encoder_kwargs).to(prior.mean.device), ) return self.encode(name, prior) @@ -308,18 +314,32 @@ def _get_params(self, name: str, prior: Distribution): args, kwargs = self.args_kwargs # stored as a tuple of (tuple, dict) hidden = self.encode(name, prior) try: - linear_loc = deep_getattr(self.hidden2locs, name) bias_loc = deep_getattr(self.bias4locs, name) - loc = hidden @ linear_loc + bias_loc - linear_scale = deep_getattr(self.hidden2scales, name) + if isinstance(self.multi_encoder_kwargs["n_out"], str): + loc = hidden[:, :, 0] + else: + linear_loc = deep_getattr(self.hidden2locs, name) + loc = hidden @ linear_loc + loc = loc + bias_loc + bias_scale = deep_getattr(self.bias4scales, name) - scale = self.softplus((hidden @ linear_scale) + bias_scale - self._init_scale_unconstrained) + if isinstance(self.multi_encoder_kwargs["n_out"], str): + scale = hidden[:, :, 1] + else: + linear_scale = deep_getattr(self.hidden2scales, name) + scale = hidden @ linear_scale + scale = self.softplus(scale + bias_scale - self._init_scale_unconstrained) + if (self._hierarchical_sites is None) or (name in self._hierarchical_sites): if self.weight_type == "element-wise": # weight is element-wise - linear_weight = deep_getattr(self.hidden2weights, name) bias_weight = deep_getattr(self.bias4weights, name) - weight = self.softplus((hidden @ linear_weight) + bias_weight - self._init_weight_unconstrained) + if isinstance(self.multi_encoder_kwargs["n_out"], str): + weight = hidden[:, :, 2] + else: + linear_weight = deep_getattr(self.hidden2weights, name) + weight = hidden @ linear_weight + weight = self.softplus(weight + bias_weight - self._init_weight_unconstrained) if self.weight_type == "scalar": # weight is a single value parameter weight = deep_getattr(self.weights, name) @@ -330,47 +350,52 @@ def _get_params(self, name: str, prior: Distribution): pass # Initialize. - with torch.no_grad(): - init_scale = torch.full((), self._init_scale) - self._init_scale_unconstrained = self.softplus.inv(init_scale) - init_weight = torch.full((), self._init_weight) - self._init_weight_unconstrained = self.softplus.inv(init_weight) - - # determine the number of hidden layers - if "multiple" in self.encoder_mode: - if name in self.n_hidden.keys(): - n_hidden = self.n_hidden[name] - else: - n_hidden = self.n_hidden["multiple"] - elif "single" in self.encoder_mode: - n_hidden = self.n_hidden["single"] - # determine parameter dimensions - param_dim = (n_hidden, self.amortised_plate_sites["sites"][name]) - bias_dim = (1, self.amortised_plate_sites["sites"][name]) - # generate initial value for linear parameters - init_param = torch.normal( - torch.full(size=param_dim, fill_value=0.0, device=prior.mean.device), - torch.full( - size=param_dim, fill_value=(1 * self.init_param_scale) / np.sqrt(n_hidden), device=prior.mean.device - ), - ) - deep_setattr(self, "hidden2locs." + name, PyroParam(init_param.clone().detach().requires_grad_(True))) - deep_setattr(self, "hidden2scales." + name, PyroParam(init_param.clone().detach().requires_grad_(True))) + bias_dim = (1, self.amortised_plate_sites["sites"][name]) deep_setattr( self, "bias4locs." + name, PyroParam(torch.full(size=bias_dim, fill_value=0.0, device=prior.mean.device)) ) deep_setattr( self, "bias4scales." 
+ name, PyroParam(torch.full(size=bias_dim, fill_value=0.0, device=prior.mean.device)) ) + with torch.no_grad(): + init_scale = torch.full((), self._init_scale) + self._init_scale_unconstrained = self.softplus.inv(init_scale) + init_weight = torch.full((), self._init_weight) + self._init_weight_unconstrained = self.softplus.inv(init_weight) + + if not isinstance(self.multi_encoder_kwargs["n_out"], str): + with torch.no_grad(): + # determine the number of hidden layers + if "multiple" in self.encoder_mode: + if name in self.n_hidden.keys(): + n_hidden = self.n_hidden[name] + else: + n_hidden = self.n_hidden["multiple"] + elif "single" in self.encoder_mode: + n_hidden = self.n_hidden["single"] + # determine parameter dimensions + param_dim = (n_hidden, self.amortised_plate_sites["sites"][name]) + # generate initial value for linear parameters + init_param = torch.normal( + torch.full(size=param_dim, fill_value=0.0, device=prior.mean.device), + torch.full( + size=param_dim, + fill_value=(1 * self.init_param_scale) / np.sqrt(n_hidden), + device=prior.mean.device, + ), + ) + deep_setattr(self, "hidden2locs." + name, PyroParam(init_param.clone().detach().requires_grad_(True))) + deep_setattr(self, "hidden2scales." + name, PyroParam(init_param.clone().detach().requires_grad_(True))) if (self._hierarchical_sites is None) or (name in self._hierarchical_sites): if self.weight_type == "scalar": # weight is a single value parameter deep_setattr(self, "weights." + name, PyroParam(init_weight, constraint=constraints.positive)) if self.weight_type == "element-wise": - # weight is element-wise - deep_setattr( - self, "hidden2weights." + name, PyroParam(init_param.clone().detach().requires_grad_(True)) - ) + if not isinstance(self.multi_encoder_kwargs["n_out"], str): + # weight is element-wise + deep_setattr( + self, "hidden2weights." + name, PyroParam(init_param.clone().detach().requires_grad_(True)) + ) deep_setattr( self, "bias4weights." + name, diff --git a/cell2location/models/_cell2location_module.py b/cell2location/models/_cell2location_module.py index 5c22da79..760741ec 100755 --- a/cell2location/models/_cell2location_module.py +++ b/cell2location/models/_cell2location_module.py @@ -222,14 +222,17 @@ def list_obs_plate_vars(self): return { "name": "obs_plate", - "input": [0, 2], # expression data + (optional) batch index + "input": [ + 0, + # 2, + ], # expression data + (optional) batch index "input_transform": [ torch.log1p, - lambda x: x, + # lambda x: x, ], # how to transform input data before passing to NN "input_normalisation": [ False, - False, + # False, ], # whether to normalise input data before passing to NN "sites": { "n_s_cells_per_location": 1, diff --git a/cell2location/nn/OutputSpecificNN.py b/cell2location/nn/OutputSpecificNN.py new file mode 100755 index 00000000..3079875c --- /dev/null +++ b/cell2location/nn/OutputSpecificNN.py @@ -0,0 +1,309 @@ +from typing import Iterable + +import numpy as np +import pyro +import pyro.distributions as dist +import torch +from pyro.infer.autoguide.utils import deep_getattr, deep_setattr +from pyro.nn import PyroModule, PyroParam +from torch import nn as nn +from torch.distributions import constraints + + +class OutputSpecificNN(PyroModule): + """ + Model which defines small, output dimension-specific NNs. Inspired by DCDI and scvi-tools FCLayers. 
+ + Parameters + ---------- + n_in + The dimensionality of the input + n_out + The dimensionality of the output + n_cat_list + A list containing, for each category of interest, + the number of categories. Each category will be + included using a one-hot encoding. + n_layers + The number of fully-connected hidden layers + n_hidden + The number of nodes per hidden layer + dropout_rate + Dropout rate to apply to each of the hidden layers + use_batch_norm + Whether to have `BatchNorm` layers or not + use_layer_norm + Whether to have `LayerNorm` layers or not + use_activation + Whether to have layer activation or not + use_activation + Whether to have layer activation at last layer or not + bias + Whether to learn bias in linear layers or not + inject_covariates + Whether to inject covariates in each layer, or just the first (default). + activation_fn + Which activation function to use + """ + + def __init__( + self, + n_in: int, + n_out: int, + name: str = "", + n_out_extra: int = 1, + n_cat_list: Iterable[int] = None, + n_layers: int = 2, + n_hidden: int = 4, + dropout_rate: float = 0.1, + bayesian: bool = False, + use_non_negative_weights: bool = True, + use_layer_norm: bool = True, + use_last_layer_norm: bool = False, + use_activation: bool = True, + use_last_activation: bool = False, + use_global_weights: bool = False, + bias: bool = True, + last_bias: bool = True, + inject_covariates: bool = True, + activation_fn: nn.Module = nn.ELU, + weights_prior={"shape": 0.1, "scale": 1.0}, + bias_prior={"mean": -10.0, "sigma": 3.0}, + ): + super().__init__() + + self.name = name + self.n_in = n_in + self.n_out = n_out + self.n_out_extra = n_out_extra + self.n_cat_list = n_cat_list + self.n_layers = n_layers + self.n_hidden = n_hidden + self.dropout_rate = dropout_rate + self.bayesian = bayesian + self.use_non_negative_weights = use_non_negative_weights + self.use_layer_norm = use_layer_norm + self.use_last_layer_norm = use_last_layer_norm + self.use_activation = use_activation + self.use_last_activation = use_last_activation + self.use_global_weights = use_global_weights + self.bias = bias + self.last_bias = last_bias + self.inject_covariates = inject_covariates + self.activation_fn = activation_fn + self.weights_prior = weights_prior + self.bias_prior = bias_prior + + self.weights = PyroModule() + + self.register_buffer("ones", torch.ones((1, 1))) + self.register_buffer("zeros", torch.zeros((1, 1))) + self.register_buffer("weights_prior_shape", torch.tensor(self.weights_prior["shape"])) + self.register_buffer("weights_prior_scale", torch.tensor(self.weights_prior["scale"])) + self.register_buffer("bias_mean_prior", torch.tensor(self.bias_prior["mean"])) + self.register_buffer("bias_sigma_prior", torch.tensor(self.bias_prior["sigma"])) + + def forward( + self, + x: torch.Tensor, + *in_out_effect, + ): + + for layer in range(self.n_layers + 1): + if self.use_global_weights and (layer < self.n_layers): + n_out = 1 + else: + n_out = self.n_out + if layer == 0: + n_in = self.n_in + else: + n_in = self.n_hidden + if layer == self.n_layers: + n_hidden = self.n_out_extra + else: + n_hidden = self.n_hidden + # optionally apply dropout ========== + if self.dropout_rate > 0: + if getattr(self.weights, f"{self.name}_layer_{layer}_dropout", None) is None: + deep_setattr( + self.weights, + f"{self.name}_layer_{layer}_dropout", + nn.Dropout(p=self.dropout_rate), + ) + dropout = deep_getattr(self.weights, f"{self.name}_layer_{layer}_dropout") + x = dropout(x) + + # generate parameters ========== + if 
self.bayesian: + # generate bayesian variables + if self.use_non_negative_weights: + # define Gamma distributed weights and Normal bias + # positive effect of input on output + weights = pyro.sample( + f"{self.name}_weights_layer_{layer}", + # for every TF increase alpha + # (more TFs have less sparse distribution) + dist.Gamma( + self.weights_prior_shape.to(x.device), + self.ones.to(x.device), + ) + .expand([n_hidden, n_out, n_in]) + .to_event(3), + ) + else: + # define laplace prior or horseshoe prior TODO horseshoe + weights = pyro.sample( + f"{self.name}_weights_layer_{layer}", + # for every TF increase alpha + # (more TFs have less sparse distribution) + dist.Laplace( + self.zeros.to(x.device), + self.weights_prior_scale.to(x.device), + ) + .expand([n_hidden, n_out, n_in]) + .to_event(3), + ) + + # bias allows requiring signal from more than one input + bias = pyro.sample( + f"{self.name}_bias_layer_{layer}", + dist.Normal( + self.bias_mean_prior.to(x.device), + self.ones.to(x.device) * self.bias_sigma_prior.to(x.device), + ) + .expand([1, n_out, n_hidden]) + .to_event(3), + ) + else: + if getattr(self.weights, f"{self.name}_layer_{layer}_weights", None) is None: + if self.use_non_negative_weights: + # initialise weights + init_param = torch.normal( + torch.full( + size=(n_hidden, n_out, n_in), + fill_value=0.0, + device=x.device, + ), + torch.full( + size=(n_hidden, n_out, n_in), + fill_value=1 / np.sqrt(n_hidden + n_out), + device=x.device, + ), + ).abs() + deep_setattr( + self.weights, + f"{self.name}_layer_{layer}_weights", + PyroParam( + init_param.clone().detach().requires_grad_(True), + constraint=constraints.positive, + ), + ) + init_param = torch.normal( + torch.full( + size=(1, n_out, n_hidden), + fill_value=0.0, + device=x.device, + ), + torch.full( + size=(1, n_out, n_hidden), + fill_value=1 / np.sqrt(n_hidden + n_out), + device=x.device, + ), + ) + deep_setattr( + self.weights, + f"{self.name}_layer_{layer}_bias", + PyroParam( + init_param.clone().detach().requires_grad_(True), + ), + ) + else: + # initialise weights + init_param = torch.normal( + torch.full( + size=(n_hidden, n_out, n_in), + fill_value=0.0, + device=x.device, + ), + torch.full( + size=(n_hidden, n_out, n_in), + fill_value=1 / np.sqrt(n_hidden + n_out), + device=x.device, + ), + ) + deep_setattr( + self.weights, + f"{self.name}_layer_{layer}_weights", + PyroParam(init_param.clone().detach().requires_grad_(True)), + ) + + init_param = torch.normal( + torch.full( + size=(1, n_out, n_hidden), + fill_value=0.0, + device=x.device, + ), + torch.full( + size=(1, n_out, n_hidden), + fill_value=1 / np.sqrt(n_hidden + n_out), + device=x.device, + ), + ) + deep_setattr( + self.weights, + f"{self.name}_layer_{layer}_bias", + PyroParam(init_param.clone().detach().requires_grad_(True)), + ) + + # extract weights + weights = deep_getattr(self.weights, f"{self.name}_layer_{layer}_weights") + bias = deep_getattr(self.weights, f"{self.name}_layer_{layer}_bias") + if self.use_global_weights: + weights = weights.expand([n_hidden, self.n_out, n_in]) + bias = bias.expand([1, self.n_out, n_hidden]) + + # compute layer weighted sum using einsum ========== + if (len(in_out_effect) == 1) and (layer == 0): + # first layer, apply in_out_effect (fg) + if len(in_out_effect[0].shape) == 2: + in_out_effect = in_out_effect[0].unsqueeze(0) + x = torch.einsum("tfg,lfg,cg->cft", weights, in_out_effect, x) + elif len(in_out_effect[0].shape) == 3: + x = torch.einsum("tfg,cfg,cg->cft", weights, in_out_effect[0], x) + elif (len(in_out_effect) 
== 0) and (layer == 0): + # first layer, without in_out_effect (fg) + x = torch.einsum("tfg,cg->cft", weights, x) + else: + # second layer or more + x = torch.einsum("qft,cft->cfq", weights, x) + if ((layer < self.n_layers) and self.bias) or ((layer == self.n_layers) and self.last_bias and self.bias): + x = x + bias + + # optionally apply layernorm ========== + if ((layer < self.n_layers) and self.use_layer_norm) or ( + (layer == self.n_layers) and self.use_last_layer_norm and self.use_layer_norm + ): + if getattr(self.weights, f"{self.name}_layer_{layer}_layer_norm", None) is None: + deep_setattr( + self.weights, + f"{self.name}_layer_{layer}_layer_norm", + nn.LayerNorm((self.n_out, n_hidden), elementwise_affine=False), + ) + layer_norm = deep_getattr(self.weights, f"{self.name}_layer_{layer}_layer_norm") + x = layer_norm(x) + + # optionally apply activation ========== + if ((layer < self.n_layers) and self.use_activation) or ( + (layer == self.n_layers) and self.use_last_activation and self.use_activation + ): + if getattr(self.weights, f"{self.name}_layer_{layer}_activation_fn", None) is None: + deep_setattr( + self.weights, + f"{self.name}_layer_{layer}_activation_fn", + self.activation_fn(), + ) + activation_fn = deep_getattr(self.weights, f"{self.name}_layer_{layer}_activation_fn") + x = activation_fn(x) + if self.n_out_extra == 1: + x = x.squeeze(-1) + return x From 31e65b5a095eef12eaaf51a5d1581835d2a83b28 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Wed, 18 May 2022 23:59:17 +0100 Subject: [PATCH 2/3] bux fixes --- .../distributions/AutoAmortisedNormalMessenger.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cell2location/distributions/AutoAmortisedNormalMessenger.py b/cell2location/distributions/AutoAmortisedNormalMessenger.py index 08c2846c..580ca39e 100755 --- a/cell2location/distributions/AutoAmortisedNormalMessenger.py +++ b/cell2location/distributions/AutoAmortisedNormalMessenger.py @@ -315,7 +315,7 @@ def _get_params(self, name: str, prior: Distribution): hidden = self.encode(name, prior) try: bias_loc = deep_getattr(self.bias4locs, name) - if isinstance(self.multi_encoder_kwargs["n_out"], str): + if ("n_out" in self.multi_encoder_kwargs) and isinstance(self.multi_encoder_kwargs["n_out"], str): loc = hidden[:, :, 0] else: linear_loc = deep_getattr(self.hidden2locs, name) @@ -323,7 +323,7 @@ def _get_params(self, name: str, prior: Distribution): loc = loc + bias_loc bias_scale = deep_getattr(self.bias4scales, name) - if isinstance(self.multi_encoder_kwargs["n_out"], str): + if ("n_out" in self.multi_encoder_kwargs) and isinstance(self.multi_encoder_kwargs["n_out"], str): scale = hidden[:, :, 1] else: linear_scale = deep_getattr(self.hidden2scales, name) @@ -334,7 +334,7 @@ def _get_params(self, name: str, prior: Distribution): if self.weight_type == "element-wise": # weight is element-wise bias_weight = deep_getattr(self.bias4weights, name) - if isinstance(self.multi_encoder_kwargs["n_out"], str): + if ("n_out" in self.multi_encoder_kwargs) and isinstance(self.multi_encoder_kwargs["n_out"], str): weight = hidden[:, :, 2] else: linear_weight = deep_getattr(self.hidden2weights, name) @@ -363,7 +363,7 @@ def _get_params(self, name: str, prior: Distribution): init_weight = torch.full((), self._init_weight) self._init_weight_unconstrained = self.softplus.inv(init_weight) - if not isinstance(self.multi_encoder_kwargs["n_out"], str): + if not (("n_out" in self.multi_encoder_kwargs) and isinstance(self.multi_encoder_kwargs["n_out"], 
str)): with torch.no_grad(): # determine the number of hidden layers if "multiple" in self.encoder_mode: @@ -391,7 +391,7 @@ def _get_params(self, name: str, prior: Distribution): # weight is a single value parameter deep_setattr(self, "weights." + name, PyroParam(init_weight, constraint=constraints.positive)) if self.weight_type == "element-wise": - if not isinstance(self.multi_encoder_kwargs["n_out"], str): + if not (("n_out" in self.multi_encoder_kwargs) and isinstance(self.multi_encoder_kwargs["n_out"], str)): # weight is element-wise deep_setattr( self, "hidden2weights." + name, PyroParam(init_param.clone().detach().requires_grad_(True)) From 2248ae2c7ca7bd63642a40d81b0c6b01707feea0 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Wed, 6 Jul 2022 00:11:23 +0100 Subject: [PATCH 3/3] changed defaults --- cell2location/nn/OutputSpecificNN.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cell2location/nn/OutputSpecificNN.py b/cell2location/nn/OutputSpecificNN.py index 3079875c..71695d54 100755 --- a/cell2location/nn/OutputSpecificNN.py +++ b/cell2location/nn/OutputSpecificNN.py @@ -54,17 +54,17 @@ def __init__( n_out_extra: int = 1, n_cat_list: Iterable[int] = None, n_layers: int = 2, - n_hidden: int = 4, + n_hidden: int = 16, dropout_rate: float = 0.1, bayesian: bool = False, - use_non_negative_weights: bool = True, + use_non_negative_weights: bool = False, use_layer_norm: bool = True, use_last_layer_norm: bool = False, use_activation: bool = True, use_last_activation: bool = False, use_global_weights: bool = False, bias: bool = True, - last_bias: bool = True, + last_bias: bool = False, inject_covariates: bool = True, activation_fn: nn.Module = nn.ELU, weights_prior={"shape": 0.1, "scale": 1.0},
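
Note on the per-output layers: the sketch below is a standalone illustration (not part of the patch or the library API) of the two einsum contractions used in OutputSpecificNN.forward, with illustrative tensor sizes, to make the shapes concrete. Each output dimension f gets its own small weight matrix, and the last layer maps the per-output hidden units to n_out_extra channels; the optional "lfg" term is the in_out_effect mask applied at the first layer.

import torch

n_cells, n_in, n_out, n_hidden, n_out_extra = 5, 10, 3, 16, 2

x = torch.randn(n_cells, n_in)                  # input: (cells, input features)
w1 = torch.randn(n_hidden, n_out, n_in)         # layer 0: one (n_hidden x n_in) matrix per output f
h = torch.einsum("tfg,cg->cft", w1, x)          # -> (cells, n_out, n_hidden)
assert h.shape == (n_cells, n_out, n_hidden)

w2 = torch.randn(n_out_extra, n_out, n_hidden)  # last layer: one (n_out_extra x n_hidden) matrix per output f
y = torch.einsum("qft,cft->cfq", w2, h)         # -> (cells, n_out, n_out_extra)
assert y.shape == (n_cells, n_out, n_out_extra)

# first-layer variant with an input->output effect mask (all ones here, so it matches h)
mask = torch.ones(1, n_out, n_in)
h_masked = torch.einsum("tfg,lfg,cg->cft", w1, mask, x)
assert torch.allclose(h_masked, h)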
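
Note on the guide changes: when multi_encoder_kwargs["n_out"] is a string, the patched _get_params no longer builds hidden2locs/hidden2scales/hidden2weights linear maps and instead slices the loc, scale and weight channels directly from the encoder output (n_out_extra = 3 for hierarchical sites, 2 otherwise). The minimal sketch below assumes an init scale and init weight of 0.1 and zero-initialised biases; these values and the variable names are illustrative, not taken from the patch.

import torch

softplus = torch.nn.Softplus()
inv_softplus = lambda y: y.expm1().log()   # inverse of softplus, used for the init offsets

n_obs, site_dim = 4, 7
hidden = torch.randn(n_obs, site_dim, 3)   # encoder output: (obs, site dim, [loc, scale, weight])

bias_loc = torch.zeros(1, site_dim)
bias_scale = torch.zeros(1, site_dim)
bias_weight = torch.zeros(1, site_dim)
init_scale_unconstrained = inv_softplus(torch.tensor(0.1))
init_weight_unconstrained = inv_softplus(torch.tensor(0.1))

loc = hidden[:, :, 0] + bias_loc
scale = softplus(hidden[:, :, 1] + bias_scale - init_scale_unconstrained)
weight = softplus(hidden[:, :, 2] + bias_weight - init_weight_unconstrained)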
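
Note on the Bayesian weight option: with bayesian=True and use_non_negative_weights=True, OutputSpecificNN samples non-negative weights from a Gamma prior and biases from a Normal prior. The standalone sketch below mirrors those pyro.sample calls using the default prior values from the patch (weights_prior shape 0.1, rate 1.0; bias mean -10.0, sigma 3.0); the site names and tensor sizes are illustrative.

import pyro
import pyro.distributions as dist
import torch

n_hidden, n_out, n_in = 4, 3, 10

weights = pyro.sample(
    "example_weights_layer_0",
    dist.Gamma(torch.tensor(0.1), torch.tensor(1.0)).expand([n_hidden, n_out, n_in]).to_event(3),
)
bias = pyro.sample(
    "example_bias_layer_0",
    dist.Normal(torch.tensor(-10.0), torch.tensor(3.0)).expand([1, n_out, n_hidden]).to_event(3),
)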