111 changes: 68 additions & 43 deletions cell2location/distributions/AutoAmortisedNormalMessenger.py
@@ -261,13 +261,14 @@ def encode(self, name: str, prior: Distribution):
to_pyro_module_(one_encoder)
deep_setattr(self, "one_encoder", one_encoder)
else:
encoder_kwargs = deepcopy(self.encoder_kwargs)
if "n_out" not in encoder_kwargs.keys():
encoder_kwargs["n_out"] = self.n_hidden["single"]
# create encoder instance from encoder class
deep_setattr(
self,
"one_encoder",
self.encoder_class(n_in=self.single_n_in, n_out=self.n_hidden["single"], **self.encoder_kwargs).to(
prior.mean.device
),
self.encoder_class(n_in=self.single_n_in, **encoder_kwargs).to(prior.mean.device),
)
if "multiple" in self.encoder_mode:
# determine the number of hidden layers
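
A minimal sketch of the kwargs-defaulting pattern this hunk introduces for the single encoder (DummyEncoder and build_single_encoder are hypothetical stand-ins, not cell2location code): the caller's encoder_kwargs may now carry its own "n_out", and the guide falls back to the hidden size only when the key is absent, instead of always forcing n_out=n_hidden["single"].

from copy import deepcopy

class DummyEncoder:
    def __init__(self, n_in, n_out, **kwargs):
        self.n_in, self.n_out = n_in, n_out

def build_single_encoder(single_n_in, single_n_hidden, encoder_kwargs):
    kwargs = deepcopy(encoder_kwargs)            # never mutate the caller's dict
    kwargs.setdefault("n_out", single_n_hidden)  # default only when absent
    return DummyEncoder(n_in=single_n_in, **kwargs)

build_single_encoder(100, 256, {})              # n_out defaults to 256
build_single_encoder(100, 256, {"n_out": 32})   # caller's value wins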
@@ -277,7 +278,14 @@ def encode(self, name: str, prior: Distribution):
n_hidden = self.n_hidden["multiple"]
multi_encoder_kwargs = deepcopy(self.multi_encoder_kwargs)
multi_encoder_kwargs["n_hidden"] = n_hidden

if "n_out" not in multi_encoder_kwargs.keys():
multi_encoder_kwargs["n_out"] = n_hidden
elif isinstance(self.multi_encoder_kwargs["n_out"], str):
multi_encoder_kwargs["n_out"] = self.amortised_plate_sites["sites"][name]
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
multi_encoder_kwargs["n_out_extra"] = 3
else:
multi_encoder_kwargs["n_out_extra"] = 2
# create multiple encoders
if self.encoder_instance is not None:
# copy instances
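
The multi-encoder branch gains a similar default plus a sentinel: a string-valued "n_out" switches the encoder into emitting per-site parameters directly, with n_out set to the site dimension and n_out_extra reserving trailing channels for loc/scale (2) or loc/scale/weight (3) on hierarchical sites. A self-contained reading of that logic (the nesting of the n_out_extra branch under the elif is assumed from the surrounding code; site_dim stands for self.amortised_plate_sites["sites"][name]):

from copy import deepcopy

def resolve_multi_encoder_kwargs(name, multi_encoder_kwargs, n_hidden,
                                 site_dim, hierarchical_sites):
    kwargs = deepcopy(multi_encoder_kwargs)
    if "n_out" not in kwargs:
        kwargs["n_out"] = n_hidden  # classic mode: encoder emits hidden features
    elif isinstance(kwargs["n_out"], str):
        kwargs["n_out"] = site_dim  # direct mode: encoder emits per-site params
        hierarchical = hierarchical_sites is None or name in hierarchical_sites
        kwargs["n_out_extra"] = 3 if hierarchical else 2
    return kwargs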
@@ -294,9 +302,7 @@ def encode(self, name: str, prior: Distribution):
deep_setattr(
self,
"multiple_encoders." + name,
self.encoder_class(n_in=self.multiple_n_in, n_out=n_hidden, **multi_encoder_kwargs).to(
prior.mean.device
),
self.encoder_class(n_in=self.multiple_n_in, **multi_encoder_kwargs).to(prior.mean.device),
)
return self.encode(name, prior)

@@ -308,18 +314,32 @@ def _get_params(self, name: str, prior: Distribution):
args, kwargs = self.args_kwargs # stored as a tuple of (tuple, dict)
hidden = self.encode(name, prior)
try:
linear_loc = deep_getattr(self.hidden2locs, name)
bias_loc = deep_getattr(self.bias4locs, name)
loc = hidden @ linear_loc + bias_loc
linear_scale = deep_getattr(self.hidden2scales, name)
if ("n_out" in self.multi_encoder_kwargs) and isinstance(self.multi_encoder_kwargs["n_out"], str):
loc = hidden[:, :, 0]
else:
linear_loc = deep_getattr(self.hidden2locs, name)
loc = hidden @ linear_loc
loc = loc + bias_loc

bias_scale = deep_getattr(self.bias4scales, name)
scale = self.softplus((hidden @ linear_scale) + bias_scale - self._init_scale_unconstrained)
if ("n_out" in self.multi_encoder_kwargs) and isinstance(self.multi_encoder_kwargs["n_out"], str):
scale = hidden[:, :, 1]
else:
linear_scale = deep_getattr(self.hidden2scales, name)
scale = hidden @ linear_scale
scale = self.softplus(scale + bias_scale - self._init_scale_unconstrained)

if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
if self.weight_type == "element-wise":
# weight is element-wise
linear_weight = deep_getattr(self.hidden2weights, name)
bias_weight = deep_getattr(self.bias4weights, name)
weight = self.softplus((hidden @ linear_weight) + bias_weight - self._init_weight_unconstrained)
if ("n_out" in self.multi_encoder_kwargs) and isinstance(self.multi_encoder_kwargs["n_out"], str):
weight = hidden[:, :, 2]
else:
linear_weight = deep_getattr(self.hidden2weights, name)
weight = hidden @ linear_weight
weight = self.softplus(weight + bias_weight - self._init_weight_unconstrained)
if self.weight_type == "scalar":
# weight is a single value parameter
weight = deep_getattr(self.weights, name)
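
In _get_params, the direct mode reads loc, scale and (for hierarchical sites) weight straight off the trailing channel axis of the encoder output instead of applying per-site linear heads; the shared bias and softplus lines appear to run in both modes, since the bias parameters are still created unconditionally. A minimal, self-contained sketch of that path (not cell2location's code):

import torch
import torch.nn.functional as F

def direct_params(hidden, bias_loc, bias_scale, bias_weight, init_unconstrained):
    # hidden: (batch, site_dim, 3); channels are (loc, scale, weight)
    loc = hidden[..., 0] + bias_loc
    scale = F.softplus(hidden[..., 1] + bias_scale - init_unconstrained)
    weight = F.softplus(hidden[..., 2] + bias_weight - init_unconstrained)
    return loc, scale, weight

hidden = torch.randn(8, 50, 3)
loc, scale, weight = direct_params(hidden, 0.0, 0.0, 0.0, 1.0)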
@@ -330,47 +350,52 @@ def _get_params(self, name: str, prior: Distribution):
pass

# Initialize.
with torch.no_grad():
init_scale = torch.full((), self._init_scale)
self._init_scale_unconstrained = self.softplus.inv(init_scale)
init_weight = torch.full((), self._init_weight)
self._init_weight_unconstrained = self.softplus.inv(init_weight)

# determine the number of hidden layers
if "multiple" in self.encoder_mode:
if name in self.n_hidden.keys():
n_hidden = self.n_hidden[name]
else:
n_hidden = self.n_hidden["multiple"]
elif "single" in self.encoder_mode:
n_hidden = self.n_hidden["single"]
# determine parameter dimensions
param_dim = (n_hidden, self.amortised_plate_sites["sites"][name])
bias_dim = (1, self.amortised_plate_sites["sites"][name])
# generate initial value for linear parameters
init_param = torch.normal(
torch.full(size=param_dim, fill_value=0.0, device=prior.mean.device),
torch.full(
size=param_dim, fill_value=(1 * self.init_param_scale) / np.sqrt(n_hidden), device=prior.mean.device
),
)
deep_setattr(self, "hidden2locs." + name, PyroParam(init_param.clone().detach().requires_grad_(True)))
deep_setattr(self, "hidden2scales." + name, PyroParam(init_param.clone().detach().requires_grad_(True)))
bias_dim = (1, self.amortised_plate_sites["sites"][name])
deep_setattr(
self, "bias4locs." + name, PyroParam(torch.full(size=bias_dim, fill_value=0.0, device=prior.mean.device))
)
deep_setattr(
self, "bias4scales." + name, PyroParam(torch.full(size=bias_dim, fill_value=0.0, device=prior.mean.device))
)
with torch.no_grad():
init_scale = torch.full((), self._init_scale)
self._init_scale_unconstrained = self.softplus.inv(init_scale)
init_weight = torch.full((), self._init_weight)
self._init_weight_unconstrained = self.softplus.inv(init_weight)

if not (("n_out" in self.multi_encoder_kwargs) and isinstance(self.multi_encoder_kwargs["n_out"], str)):
with torch.no_grad():
# determine the number of hidden layers
if "multiple" in self.encoder_mode:
if name in self.n_hidden.keys():
n_hidden = self.n_hidden[name]
else:
n_hidden = self.n_hidden["multiple"]
elif "single" in self.encoder_mode:
n_hidden = self.n_hidden["single"]
# determine parameter dimensions
param_dim = (n_hidden, self.amortised_plate_sites["sites"][name])
# generate initial value for linear parameters
init_param = torch.normal(
torch.full(size=param_dim, fill_value=0.0, device=prior.mean.device),
torch.full(
size=param_dim,
fill_value=(1 * self.init_param_scale) / np.sqrt(n_hidden),
device=prior.mean.device,
),
)
deep_setattr(self, "hidden2locs." + name, PyroParam(init_param.clone().detach().requires_grad_(True)))
deep_setattr(self, "hidden2scales." + name, PyroParam(init_param.clone().detach().requires_grad_(True)))
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
if self.weight_type == "scalar":
# weight is a single value parameter
deep_setattr(self, "weights." + name, PyroParam(init_weight, constraint=constraints.positive))
if self.weight_type == "element-wise":
# weight is element-wise
deep_setattr(
self, "hidden2weights." + name, PyroParam(init_param.clone().detach().requires_grad_(True))
)
if not (("n_out" in self.multi_encoder_kwargs) and isinstance(self.multi_encoder_kwargs["n_out"], str)):
# weight is element-wise
deep_setattr(
self, "hidden2weights." + name, PyroParam(init_param.clone().detach().requires_grad_(True))
)
deep_setattr(
self,
"bias4weights." + name,
…
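
The sentinel check ("n_out" in self.multi_encoder_kwargs) and isinstance(self.multi_encoder_kwargs["n_out"], str) now appears in five places across encode, _get_params and the initialisation block; a tiny helper like this hypothetical one (not part of the PR) would name the mode once and keep the branches in sync:

def uses_direct_encoder_output(multi_encoder_kwargs: dict) -> bool:
    # True when a string-valued "n_out" asks the encoder to emit
    # per-site parameters directly instead of hidden features
    return isinstance(multi_encoder_kwargs.get("n_out"), str)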
9 changes: 6 additions & 3 deletions cell2location/models/_cell2location_module.py
@@ -222,14 +222,17 @@ def list_obs_plate_vars(self):

return {
"name": "obs_plate",
"input": [0, 2], # expression data + (optional) batch index
"input": [
0,
# 2,
], # expression data + (optional) batch index
"input_transform": [
torch.log1p,
lambda x: x,
# lambda x: x,
], # how to transform input data before passing to NN
"input_normalisation": [
False,
False,
# False,
], # whether to normalise input data before passing to NN
"sites": {
"n_s_cells_per_location": 1,
…
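
For context, a hedged sketch of how a consumer of list_obs_plate_vars typically uses these fields (hypothetical helper, not code from this PR): each index in "input" selects a positional model argument, which is transformed and optionally normalised before being concatenated into the encoder input, so commenting out the batch-index entry drops it from the amortisation network's input.

import torch

def encoder_input(model_args, plate_vars):
    cols = []
    for idx, transform, normalise in zip(
        plate_vars["input"],
        plate_vars["input_transform"],
        plate_vars["input_normalisation"],
    ):
        x = transform(model_args[idx])
        if normalise:  # placeholder normalisation; the real rule lives elsewhere
            x = x / x.sum(dim=-1, keepdim=True).clamp(min=1e-8)
        cols.append(x)
    return torch.cat(cols, dim=-1)

plate_vars = {"input": [0], "input_transform": [torch.log1p], "input_normalisation": [False]}
x = encoder_input((torch.rand(8, 100),), plate_vars)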