From 3c318039055575f0e4814bfdcd0260214303eaad Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 8 Sep 2020 20:58:56 +0200 Subject: [PATCH 01/44] adds weight normalizer and atleast_kd --- tests/test_data_processing.py | 18 +++++++++++++ utils/data_processing.py | 48 +++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py index a615b4d0..cec59546 100644 --- a/tests/test_data_processing.py +++ b/tests/test_data_processing.py @@ -237,3 +237,21 @@ def test_label_conversion(self): np.testing.assert_equal(func_dense, dense_labels) func_one_hot = dp.dense_to_one_hot(torch.tensor(dense_labels), num_classes).numpy() np.testing.assert_equal(func_one_hot, one_hot_labels) + + def test_atleastkd(self): + x = np.random.standard_normal([2, 3, 4]) + ks = [0, 3, 8, 10] + for k in ks: + new_x = dp.atleast_kd(torch.tensor(x), k).numpy() + test_nd = np.maximum(k, x.ndim) + np.testing.assert_equal(new_x.ndim, test_nd) + + def test_l2_weight_norm(self): + w_fc = np.random.standard_normal([24, 38]) + w_conv = np.random.standard_normal([38, 24, 8, 8]) + for w in [w_fc, w_conv, 0*w_fc, 0*w_conv]: + w_norm = dp.get_weights_l2_norm(torch.tensor(w), eps=1e-12).numpy() + normed_w = dp.l2_normalize_weights(torch.tensor(w), eps=1e-12).numpy() + normed_w_norm = dp.get_weights_l2_norm(torch.tensor(normed_w), eps=1e-12).numpy() + np.testing.assert_allclose(normed_w_norm, 1.0, rtol=1e-10) + np.testing.assert_allclose(w / w_norm, normed_w, rtol=1e-10) diff --git a/utils/data_processing.py b/utils/data_processing.py index fff2fddc..85dc23c9 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -1,5 +1,6 @@ import numpy as np import torch +import warn def reshape_data(data, flatten=None, out_shape=None): @@ -212,3 +213,50 @@ def dense_to_one_hot(labels_dense, num_classes): labels_one_hot = torch.zeros((num_labels, num_classes)) labels_one_hot.view(-1)[index_offset + labels_dense.view(-1)] = 1 return labels_one_hot + +def atleast_kd(x, k): + """ + return x reshaped to append singleton dimensions such that x.ndim is at least k + Inputs: + x [Tensor or numpy ndarray] + k [int] minimum number of dimensions + Outputs: + x [same as input x] reshaped input to have at least k dimensions + """ + shape = x.shape + (1,) * (k - x.ndim) + return x.reshape(shape) + +def get_weights_l2_norm(w, eps=1e-12): + """ + get l2 norm of weight matrix + Inputs: + w [Tensor] assumed to have shape [inC, outC] or [outC, inC, kernH, kernW] + norm is calculated over vectorized version of inC in the first case or inC*kernH*kernW in the second + eps [float] minimum value to prevent division by zero + Outputs: + norm [Tensor] norm of each of the outC weight vectors + """ + if w.ndim == 2: # fully-connected, [inputs, outputs] + norms = torch.norm(w, dim=0, keepdim=True) + elif w.ndim == 4: # convolutional, [out_channels, in_channels, kernel_height, kernel_width] + norms = torch.norm(w.flatten(start_dim=1), dim=-1, keepdim=True) + else: + assert False, (f'input w must have ndim = 2 or 4, not {w.ndim}') + if(torch.max(norms) <= eps): #TODO: Warnings + print(f'Warning: input gradient is less than or equal to {eps}') + norms = torch.max(norms, eps*torch.ones_like(norms)) # prevent div by 0 # TODO: Change to torch.maximum when it is stable + norms = atleast_kd(norms, w.ndim) + return norms + +def l2_normalize_weights(w, eps=1e-12): + """ + l2 normalize weight matrix + Inputs: + w [Tensor] assumed to have shape [inC, outC] or [outC, inC, kernH, kernW] + norm is calculated over vectorized version of inC in the first case or inC*kernH*kernW in the second + eps [float] minimum value to prevent division by zero + Outputs: + w [Tensor] same type and shape as input w, but with unitary l2 norm when computed over all input dimensions + """ + norms = get_weights_l2_norm(w, eps) + return w / norms From 15bf4d6660399a774fc56f2b0d01ed5b39e8ec03 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 8 Sep 2020 20:59:46 +0200 Subject: [PATCH 02/44] updates weight init to use new normalizer util --- modules/lca_module.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/modules/lca_module.py b/modules/lca_module.py index 9444a5d6..bd397c7c 100644 --- a/modules/lca_module.py +++ b/modules/lca_module.py @@ -3,16 +3,15 @@ import torch.nn.functional as F from DeepSparseCoding.modules.activations import lca_threshold +import DeepSparseCoding.utils.data_processing as dp class LcaModule(nn.Module): def setup_module(self, params): self.params = params - self.w = nn.Parameter( - F.normalize( - torch.randn(self.params.num_pixels, self.params.num_latent), - p=2, dim=0), - requires_grad=True) + w_init = torch.randn([self.params.num_pixels, self.params.num_latent]) + w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps) + self.w = nn.Parameter(w_init_normed, requires_grad=True) def preprocess_data(self, input_tensor): input_tensor = input_tensor.view(-1, self.params.num_pixels) @@ -44,7 +43,6 @@ def infer_coefficients(self, input_tensor): u_list = [torch.zeros([input_tensor.shape[0], self.params.num_latent], device=self.params.device)] a_list = [self.threshold_units(u_list[0])] - # TODO: look into redoing this with a register_buffer that gets updated? look up simple RNN code... for step in range(self.params.num_steps-1): u = self.step_inference(u_list[step], a_list[step], lca_b, lca_g, step)[0] u_list.append(u) From 793abf5b92a37a76ecfbed1add21bf1ca5a96d22 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 8 Sep 2020 21:01:40 +0200 Subject: [PATCH 03/44] changes weight norm to use new util; removes deprecated argument to scheduler.step --- utils/run_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/utils/run_utils.py b/utils/run_utils.py index 085fb6ca..aa472db2 100644 --- a/utils/run_utils.py +++ b/utils/run_utils.py @@ -1,5 +1,7 @@ import torch +import DeepSparseCoding.utils.data_processing as dp + def train_single_model(model, loss): model.optimizer.zero_grad() # clear gradietns of all optimized variables @@ -7,7 +9,7 @@ def train_single_model(model, loss): model.optimizer.step() if(hasattr(model.params, 'renormalize_weights') and model.params.renormalize_weights): with torch.no_grad(): # tell autograd to not record this operation - model.w.div_(torch.norm(model.w, dim=0, keepdim=True)) + model.w.div_(dp.get_weights_l2_norm(model.w)) def train_epoch(epoch, model, loader): @@ -36,9 +38,9 @@ def train_epoch(epoch, model, loader): input_data=inputs[0], input_labels=target, batch_step=batch_step) if(model.params.model_type.lower() == 'ensemble'): for submodule in model: - submodule.scheduler.step(epoch) + submodule.scheduler.step() else: - model.scheduler.step(epoch) + model.scheduler.step() def test_single_model(model, data, target, epoch): From 8e64b9fb6704ca79041d4b333edae35a35460b63 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 8 Sep 2020 21:02:04 +0200 Subject: [PATCH 04/44] removes unnecessary imports --- params/lca_mnist_params.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/params/lca_mnist_params.py b/params/lca_mnist_params.py index dd9cdf34..ff95de40 100644 --- a/params/lca_mnist_params.py +++ b/params/lca_mnist_params.py @@ -1,9 +1,5 @@ -import os import types -import numpy as np -import torch - from DeepSparseCoding.params.base_params import BaseParams From 3ae4997c9e84943753c9b9e365305c21d86cdd04 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 8 Sep 2020 21:03:29 +0200 Subject: [PATCH 05/44] adds convolutional lca; training gets NaN after 14 epochs --- models/conv_lca_model.py | 47 +++++++++++++++++++ modules/conv_lca_module.py | 83 +++++++++++++++++++++++++++++++++ params/conv_lca_mnist_params.py | 44 +++++++++++++++++ params/test_params.py | 12 +++++ tests/test_models.py | 2 +- utils/loaders.py | 6 +++ 6 files changed, 193 insertions(+), 1 deletion(-) create mode 100644 models/conv_lca_model.py create mode 100644 modules/conv_lca_module.py create mode 100644 params/conv_lca_mnist_params.py diff --git a/models/conv_lca_model.py b/models/conv_lca_model.py new file mode 100644 index 00000000..f9d96742 --- /dev/null +++ b/models/conv_lca_model.py @@ -0,0 +1,47 @@ +import numpy as np +import torch + +from DeepSparseCoding.models.base_model import BaseModel +from DeepSparseCoding.modules.conv_lca_module import ConvLcaModule +import DeepSparseCoding.modules.losses as losses + + +class ConvLcaModel(BaseModel, ConvLcaModule): + def setup(self, params, logger=None): + super(ConvLcaModel, self).setup(params, logger) + self.setup_module(params) + self.setup_optimizer() + + def get_total_loss(self, input_tuple): + input_tensor, input_labels = input_tuple + latents = self.get_encodings(input_tensor) + recon = self.get_recon_from_latents(latents) + recon_loss = losses.half_squared_l2(input_tensor, recon) + sparse_loss = self.params.sparse_mult * losses.l1_norm(latents) + total_loss = recon_loss + sparse_loss + return total_loss + + def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None): + if update_dict is None: + update_dict = super(ConvLcaModel, self).generate_update_dict(input_data, input_labels, batch_step) + epoch = batch_step / self.params.batches_per_epoch + stat_dict = { + 'epoch':int(epoch), + 'batch_step':batch_step, + 'train_progress':np.round(batch_step/self.params.num_batches, 3), + 'weight_lr':self.scheduler.get_lr()[0]} + latents = self.get_encodings(input_data) + recon = self.get_recon_from_latents(latents) + recon_loss = losses.half_squared_l2(input_data, recon).item() + sparse_loss = self.params.sparse_mult * losses.l1_norm(latents).item() + stat_dict['loss_recon'] = recon_loss + stat_dict['loss_sparse'] = sparse_loss + stat_dict['loss_total'] = recon_loss + sparse_loss + stat_dict['input_max_mean_min'] = [ + input_data.max().item(), input_data.mean().item(), input_data.min().item()] + stat_dict['recon_max_mean_min'] = [ + recon.max().item(), recon.mean().item(), recon.min().item()] + latent_nnz = torch.sum(latents != 0).item() # TODO: github issue 23907 requests torch.count_nonzero + stat_dict['latents_fraction_active'] = latent_nnz / latents.numel() + update_dict.update(stat_dict) + return update_dict diff --git a/modules/conv_lca_module.py b/modules/conv_lca_module.py new file mode 100644 index 00000000..5ab5297a --- /dev/null +++ b/modules/conv_lca_module.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from DeepSparseCoding.modules.lca_module import LcaModule +import DeepSparseCoding.utils.data_processing as dp + + +class ConvLcaModule(LcaModule): + """ + Parameters + ----------------------------- + data_shape [list of int] by default it is set to [h, w, c], however pytorch conv wants [c, h, w] so it is permuted in this module + Assumes h = w (i.e. square inputs) + in_channels [int] - Number of channels in the input image + Automatically set to params.num_pixels + out_channels [int] - Number of channels produced by the convolution + Automatically set to params.num_latent + kernel_size [int] - Edge size of the square convolving kernel + stride [int] - Vertical and horizontal stride of the convolution. + padding [int] - Zero-padding added to both sides of the input. + """ + def setup_module(self, params): + self.params = params + self.params.data_shape = [self.params.data_shape[2], self.params.data_shape[0], self.params.data_shape[1]] + self.input_shape = [self.params.batch_size] + self.params.data_shape + self.w_shape = [ + self.params.out_channels, + self.params.in_channels, + self.params.kernel_size, + self.params.kernel_size + ] + dilation = 1 + conv_hout = int(1 + (self.input_shape[2] + 2 * self.params.padding - dilation * (self.params.kernel_size - 1) - 1) / self.params.stride) + conv_wout = conv_hout # Assumes square images + self.output_shape = [self.params.batch_size, self.params.out_channels, conv_hout, conv_wout] + w_init = torch.randn(self.w_shape) + w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps) + self.w = nn.Parameter(w_init_normed, requires_grad=True) + + def preprocess_data(self, input_tensor): + return input_tensor.permute(0, 3, 1, 2) + + def get_recon_from_latents(self, a_in): + recon = F.conv_transpose2d( + input=a_in, + weight=self.w, + bias=None, + stride=self.params.stride, + padding=self.params.padding + ) + return recon + + def step_inference(self, input_tensor, u_in, a_in, step): + recon = self.get_recon_from_latents(a_in) + recon_error = input_tensor - recon + error_injection = F.conv2d( + input=recon_error, + weight=self.w, + bias=None, + stride=self.params.stride, + padding=self.params.padding + ) + du = error_injection + a_in - u_in + u_out = u_in + self.params.step_size * du + return u_out + + def infer_coefficients(self, input_tensor): + u_list = [torch.zeros(self.output_shape, device=self.params.device)] + a_list = [self.threshold_units(u_list[0])] + for step in range(self.params.num_steps-1): + u = self.step_inference(input_tensor, u_list[step], a_list[step], step) + u_list.append(u) + a_list.append(self.threshold_units(u)) + return (u_list, a_list) + + def get_encodings(self, input_tensor): + u_list, a_list = self.infer_coefficients(input_tensor) + return a_list[-1] + + def forward(self, input_tensor): + latents = self.get_encodings(input_tensor) + return latents diff --git a/params/conv_lca_mnist_params.py b/params/conv_lca_mnist_params.py new file mode 100644 index 00000000..d1e35e50 --- /dev/null +++ b/params/conv_lca_mnist_params.py @@ -0,0 +1,44 @@ +import types + +from DeepSparseCoding.params.base_params import BaseParams + + +class params(BaseParams): + def set_params(self): + super(params, self).set_params() + self.model_type = 'conv_lca' + self.model_name = 'conv_lca_mnist' + self.version = '0' + self.dataset = 'mnist' + self.standardize_data = False + self.rescale_data_to_one = True + self.num_pixels = 784 + self.batch_size = 100 + self.num_epochs = 100 + self.weight_decay = 0. + self.weight_lr = 0.1 + self.train_logs_per_epoch = 6 + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.5 + self.renormalize_weights = True + self.dt = 0.001 + self.tau = 0.03 + self.num_steps = 75 + self.rectify_a = True + self.thresh_type = 'soft' + self.sparse_mult = 0.25 + self.kernel_size = 8 + self.stride = 2 + self.padding = 0 + self.num_latent = 128 + self.compute_helper_params() + + def compute_helper_params(self): + super(params, self).compute_helper_params() + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] + self.step_size = self.dt / self.tau + self.out_channels = self.num_latent + self.in_channels = 1 diff --git a/params/test_params.py b/params/test_params.py index 01a687b5..48c7f260 100644 --- a/params/test_params.py +++ b/params/test_params.py @@ -70,6 +70,18 @@ def set_params(self): for frac in self.optimizer.lr_annealing_milestone_frac] self.step_size = self.dt / self.tau +class conv_lca_params(lca_params): + def set_params(self): + super(conv_lca_params, self).set_params() + self.kernel_size = 8 + self.stride = 2 + self.padding = 0 + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] + self.step_size = self.dt / self.tau + self.out_channels = self.num_latent + self.in_channels = 1 + class mlp_params(BaseParams): def set_params(self): diff --git a/tests/test_models.py b/tests/test_models.py index a5e16ccd..1ff329da 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -20,7 +20,7 @@ def setUp(self): def test_model_loading(self): for model_type in self.model_list: - model_type = ''.join(model_type.split('_')[:-1]) # remove '_model' at the end + model_type = '_'.join(model_type.split('_')[:-1]) # remove '_model' at the end model = loaders.load_model(model_type) params = loaders.load_params(self.test_params_file, key=model_type+'_params') train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params) diff --git a/utils/loaders.py b/utils/loaders.py index 9b73a8fb..e79553ad 100644 --- a/utils/loaders.py +++ b/utils/loaders.py @@ -42,6 +42,9 @@ def load_model_class(model_type): elif(model_type.lower() == 'lca'): py_module_name = 'LcaModel' file_name = os.path.join(*[dsc_dir, 'models', 'lca_model.py']) + elif(model_type.lower() == 'conv_lca'): + py_module_name = 'ConvLcaModel' + file_name = os.path.join(*[dsc_dir, 'models', 'conv_lca_model.py']) elif(model_type.lower() == 'ensemble'): py_module_name = 'EnsembleModel' file_name = os.path.join(*[dsc_dir, 'models', 'ensemble_model.py']) @@ -66,6 +69,9 @@ def load_module(module_type): elif(module_type.lower() == 'lca'): py_module_name = 'LcaModule' file_name = os.path.join(*[dsc_dir, 'modules', 'lca_module.py']) + elif(module_type.lower() == 'conv_lca'): + py_module_name = 'ConvLcaModule' + file_name = os.path.join(*[dsc_dir, 'modules', 'conv_lca_module.py']) elif(module_type.lower() == 'ensemble'): py_module_name = 'EnsembleModule' file_name = os.path.join(*[dsc_dir, 'modules', 'ensemble_module.py']) From 26bddf325a8972b2a60d9756def419b95b37f5bf Mon Sep 17 00:00:00 2001 From: Dylan Date: Thu, 10 Sep 2020 15:42:55 +0200 Subject: [PATCH 06/44] adds fastMNIST dataset; bugfixes; import fixes * adds fastMNIST dataset and parameters * fixes conv lca params to match between pytorch & tf1x - no longer getting NAN * fixes relative imports for utils/loaders.py - need to propagate to other files * fixes minor bugs in notebooks and run_utils * improves some error messaging --- modules/conv_lca_module.py | 2 + notebooks/monitor_training.ipynb | 11 ++-- notebooks/visualize_model_weights.ipynb | 33 +++++++----- params/base_params.py | 2 + params/conv_lca_mnist_params.py | 13 ++--- params/lca_mlp_mnist_params.py | 1 + params/lca_mnist_params.py | 1 + params/mlp_mnist_params.py | 1 + tf1x/params/lca_conv_params.py | 11 ++-- tf1x/utils/data_processing.py | 2 +- utils/data_processing.py | 1 - utils/dataset_utils.py | 67 ++++++++++++++++++++----- utils/loaders.py | 3 +- utils/run_utils.py | 4 +- 14 files changed, 108 insertions(+), 44 deletions(-) diff --git a/modules/conv_lca_module.py b/modules/conv_lca_module.py index 5ab5297a..84adba0b 100644 --- a/modules/conv_lca_module.py +++ b/modules/conv_lca_module.py @@ -24,6 +24,8 @@ def setup_module(self, params): self.params = params self.params.data_shape = [self.params.data_shape[2], self.params.data_shape[0], self.params.data_shape[1]] self.input_shape = [self.params.batch_size] + self.params.data_shape + assert (self.input_shape[-1] % self.params.stride == 0), ( + f'Stride = {self.params.stride} must divide evenly into input edge size = {self.input_shape[-1]}') self.w_shape = [ self.params.out_channels, self.params.in_channels, diff --git a/notebooks/monitor_training.ipynb b/notebooks/monitor_training.ipynb index 2c3ecad2..3796865e 100644 --- a/notebooks/monitor_training.ipynb +++ b/notebooks/monitor_training.ipynb @@ -23,7 +23,8 @@ "outputs": [], "source": [ "workspace_dir = os.path.expanduser('~')+'/Work/'\n", - "log_file = workspace_dir+'/Torch_projects/lca_768_mlp_mnist/logfiles/lca_768_mlp_mnist_v0.log'\n", + "model_name = 'conv_lca_mnist'\n", + "log_file = workspace_dir+'/Torch_projects/{}/logfiles/{}_v0.log'.format(model_name, model_name)\n", "logger = Logger(log_file, overwrite=False)\n", "\n", "log_text = logger.load_file()\n", @@ -41,9 +42,9 @@ "outputs": [], "source": [ "x_key = 'epoch'\n", - "y_keys = ['lca_loss_recon', 'lca_loss_sparse', 'lca_loss_total', 'mlp_loss', 'mlp_train_accuracy']\n", - "y_labels = ['Recon loss', 'Sparse loss', 'Total LCA loss', 'Total MLP loss', 'MLP train accuracy']\n", - "stats_fig = pf.plot_stats(model_stats, x_key=x_key, y_keys=y_keys, y_labels=y_labels, start_index=0)" + "#y_keys = ['lca_loss_recon', 'lca_loss_sparse', 'lca_loss_total', 'mlp_loss', 'mlp_train_accuracy']\n", + "#y_labels = ['Recon loss', 'Sparse loss', 'Total LCA loss', 'Total MLP loss', 'MLP train accuracy']\n", + "stats_fig = pf.plot_stats(model_stats, x_key=x_key)#, y_keys=y_keys, y_labels=y_labels, start_index=0)" ] }, { @@ -70,7 +71,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.7" + "version": "3.6.8" }, "varInspector": { "cols": { diff --git a/notebooks/visualize_model_weights.ipynb b/notebooks/visualize_model_weights.ipynb index 84678ccb..a6bb3460 100644 --- a/notebooks/visualize_model_weights.ipynb +++ b/notebooks/visualize_model_weights.ipynb @@ -29,16 +29,12 @@ "outputs": [], "source": [ "workspace_dir = os.path.expanduser(\"~\")+\"/Work/\"\n", - "model_name = 'lca_dsprites'\n", - "num_epochs = 100\n", - "sparse_mult = 0.05\n", - "model_name += '_{}_{}'.format(sparse_mult, num_epochs)\n", + "model_name = 'conv_lca_mnist'\n", + "#num_epochs = 200\n", + "#sparse_mult = 0.25\n", + "#model_name += '_{}_{}'.format(sparse_mult, num_epochs)\n", "log_file = workspace_dir+'/Torch_projects/{}/logfiles/{}_v0.log'.format(model_name, model_name)\n", "logger = Logger(log_file, overwrite=False)\n", - "\n", - "target_index = 1\n", - "\n", - "logger = Logger(log_files[target_index], overwrite=False)\n", "log_text = logger.load_file()\n", "params = logger.read_params(log_text)[-1]" ] @@ -94,7 +90,7 @@ " paired_pics = [paired_pics[i, :, :, :] for i in range(paired_pics.shape[0])]\n", " print(np.array(paired_pics).shape)\n", " visualize_util.grid_save_images(paired_pics, os.path.join('', \"reconstructions.jpg\"))\n", - "else:\n", + "elif model.params.model_type.lower() in ['mlp', 'ensemble']:\n", " test_results = run_utils.test_epoch(0, model, test_loader, log_to_file=False)\n", " print(test_results)" ] @@ -111,8 +107,14 @@ "else:\n", " weights = list(model.parameters())[0].data.cpu().numpy()\n", "\n", - "num_neurons, num_pixels = weights.shape\n", - "weights = np.reshape(weights, [num_neurons, int(np.sqrt(num_pixels)), int(np.sqrt(num_pixels))])" + "if weights.ndim == 4:\n", + " num_neurons, num_channels, num_h, num_w = weights.shape\n", + " num_pixels = num_channels * num_h * num_w\n", + "elif weights.ndim == 2:\n", + " num_neurons, num_pixels = weights.shape\n", + " weights = np.reshape(weights, [num_neurons, 1, int(np.sqrt(num_pixels)), int(np.sqrt(num_pixels))])\n", + "else:\n", + " assert False, (f'weights.ndim == {weights.ndim} must be 2 or 4')\n" ] }, { @@ -141,7 +143,11 @@ "def pad_matrix_to_image(matrix, pad_size=0, pad_value=0, normalize=False):\n", " if normalize:\n", " matrix = normalize_data_with_max(matrix)[0]\n", - " num_weights, img_h, img_w = matrix.shape\n", + " num_weights, img_c, img_h, img_w = matrix.shape\n", + " if img_c == 1:\n", + " matrix = matrix.squeeze()\n", + " else:\n", + " assert False, (f'Multiple color channels are not currently supported') # TODO\n", " num_extra_images = int(np.ceil(np.sqrt(num_weights))**2 - num_weights)\n", " if num_extra_images > 0:\n", " matrix = np.concatenate(\n", @@ -181,7 +187,8 @@ "\n", "tfpf.plot_image(pad_matrix_to_image(weights), vmin=None, vmax=None, title=\"\", save_filename=model.params.disp_dir+\"/weights_plot_image.png\")\n", "tfpf.plot_weights(weights, save_filename=model.params.disp_dir+\"/weights_plot_weights.png\")\n", - "tfpf.plot_data_tiled(weights[..., None], save_filename=model.params.disp_dir+\"/weights_plot_data_tiled.png\")" + "tfpf.plot_data_tiled(np.transpose(weights, (0, 2, 3, 1)),\n", + " save_filename=model.params.disp_dir+\"/weights_plot_data_tiled.png\")" ] }, { diff --git a/params/base_params.py b/params/base_params.py index c82fe486..142bdb5a 100644 --- a/params/base_params.py +++ b/params/base_params.py @@ -13,6 +13,8 @@ class BaseParams(object): device [str] which device to run on dtype [torch dtype] dtype for network variables eps [float] small value to avoid division by zero + fast_mnist [bool] if True, use the fastMNIST dataset, + which loads faster but does not allow for torchvision transforms like flip and rotate lib_root_dir [str] system location of this library directory log_to_file [bool] if set, log to file, else log to stderr model_name [str] name for model (can be anything) diff --git a/params/conv_lca_mnist_params.py b/params/conv_lca_mnist_params.py index d1e35e50..a129258b 100644 --- a/params/conv_lca_mnist_params.py +++ b/params/conv_lca_mnist_params.py @@ -10,18 +10,19 @@ def set_params(self): self.model_name = 'conv_lca_mnist' self.version = '0' self.dataset = 'mnist' + self.fast_mnist = True self.standardize_data = False self.rescale_data_to_one = True self.num_pixels = 784 - self.batch_size = 100 - self.num_epochs = 100 - self.weight_decay = 0. - self.weight_lr = 0.1 + self.batch_size = 50 + self.num_epochs = 500 + self.weight_decay = 0.0 + self.weight_lr = 0.001 self.train_logs_per_epoch = 6 self.optimizer = types.SimpleNamespace() self.optimizer.name = 'sgd' - self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs - self.optimizer.lr_decay_rate = 0.5 + self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 self.renormalize_weights = True self.dt = 0.001 self.tau = 0.03 diff --git a/params/lca_mlp_mnist_params.py b/params/lca_mlp_mnist_params.py index 212de71f..5cfb6b11 100644 --- a/params/lca_mlp_mnist_params.py +++ b/params/lca_mlp_mnist_params.py @@ -14,6 +14,7 @@ def __init__(self): self.model_name = 'lca_768_mlp_mnist' self.version = '0' self.dataset = 'mnist' + self.fast_mnist = True self.standardize_data = False self.num_pixels = 28*28*1 self.batch_size = 100 diff --git a/params/lca_mnist_params.py b/params/lca_mnist_params.py index ff95de40..881c75ff 100644 --- a/params/lca_mnist_params.py +++ b/params/lca_mnist_params.py @@ -10,6 +10,7 @@ def set_params(self): self.model_name = 'lca_768_mnist' self.version = '0' self.dataset = 'mnist' + self.fast_mnist = True self.standardize_data = False self.num_pixels = 784 self.batch_size = 100 diff --git a/params/mlp_mnist_params.py b/params/mlp_mnist_params.py index 901ab60e..808078fd 100644 --- a/params/mlp_mnist_params.py +++ b/params/mlp_mnist_params.py @@ -14,6 +14,7 @@ def set_params(self): self.model_name = 'mlp_768_mnist' self.version = '0' self.dataset = 'mnist' + self.fast_mnist = True self.standardize_data = False self.rescale_data_to_one = False self.num_pixels = 28*28*1 diff --git a/tf1x/params/lca_conv_params.py b/tf1x/params/lca_conv_params.py index cedd0d96..55f492d2 100644 --- a/tf1x/params/lca_conv_params.py +++ b/tf1x/params/lca_conv_params.py @@ -57,19 +57,24 @@ def set_data_params(self, data_type): self.data_type = data_type if data_type.lower() == "mnist": self.model_name += "_mnist" + self.log_int = 200 self.rescale_data = True self.center_data = False self.whiten_data = False self.lpf_data = False # only for ZCA + self.num_val = 0 + self.batch_size = 50 self.lpf_cutoff = 0.7 - self.num_neurons = 768 + self.num_neurons = 128 + self.num_steps = 75 self.stride_y = 2 self.stride_x = 2 self.patch_size_y = 8 # weight receptive field self.patch_size_x = 8 for schedule_idx in range(len(self.schedule)): - self.schedule[schedule_idx]["sparse_mult"] = 0.21 - self.schedule[schedule_idx]["weight_lr"] = [0.1] + self.schedule[schedule_idx]["num_batches"] = int(6e5) + self.schedule[schedule_idx]["sparse_mult"] = 0.25 + self.schedule[schedule_idx]["weight_lr"] = [0.001] elif data_type.lower() == "vanhateren": self.model_name += "_vh" diff --git a/tf1x/utils/data_processing.py b/tf1x/utils/data_processing.py index 1fd62606..6118a0bf 100644 --- a/tf1x/utils/data_processing.py +++ b/tf1x/utils/data_processing.py @@ -88,7 +88,7 @@ def reshape_data(data, flatten=None, out_shape=None): if flatten == True: data = np.reshape(data, (num_examples, num_rows*num_cols*num_channels)) else: - assert False, ("Data must have 1, 2, 3, or 4 dimensions.") + assert False, (f'Data must have 1, 2, 3, or 4 dimensions, not {orig_ndim}') else: num_examples = None; num_rows=None; num_cols=None; num_channels=None data = np.reshape(data, out_shape) diff --git a/utils/data_processing.py b/utils/data_processing.py index 85dc23c9..a0697454 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -1,6 +1,5 @@ import numpy as np import torch -import warn def reshape_data(data, flatten=None, out_shape=None): diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py index cd46541b..6ffb8aae 100644 --- a/utils/dataset_utils.py +++ b/utils/dataset_utils.py @@ -3,7 +3,8 @@ import numpy as np import torch -from torchvision import datasets, transforms +from torchvision import transforms +from torchvision.datasets import MNIST ROOT_DIR = os.path.dirname(os.getcwd()) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) @@ -12,6 +13,30 @@ import DeepSparseCoding.datasets.synthetic as synthetic +class FastMNIST(MNIST): + """ + The torchvision MNIST dataset has additional overhead that slows it down. + This loads the entire dataset onto the specified device at init, resulting in a considerable speedup + """ + def __init__(self, *args, **kwargs): + device = kwargs.pop('device', 'cpu') + super().__init__(*args, **kwargs) + # Scale data to [0,1] + self.data = self.data.unsqueeze(-1).float().div(255) + # Put both data and targets on GPU in advance + self.data, self.targets = self.data.to(device), self.targets.to(device) + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is index of the target class. + """ + img, target = self.data[index], self.targets[index] + return img, target + + class CustomTensorDataset(torch.utils.data.Dataset): def __init__(self, data_tensor): self.data_tensor = data_tensor @@ -36,17 +61,33 @@ def load_dataset(params): if params.rescale_data_to_one: preprocessing_pipeline.append( transforms.Lambda(lambda x: dp.rescale_data_to_one(x, eps=params.eps, samplewise=True)[0])) - train_loader = torch.utils.data.DataLoader( - datasets.MNIST(root=params.data_dir, train=True, download=True, - transform=transforms.Compose(preprocessing_pipeline)), - batch_size=params.batch_size, shuffle=params.shuffle_data, - num_workers=0, pin_memory=False) - val_loader = None - test_loader = torch.utils.data.DataLoader( - datasets.MNIST(root=params.data_dir, train=False, download=True, - transform=transforms.Compose(preprocessing_pipeline)), - batch_size=params.batch_size, shuffle=params.shuffle_data, - num_workers=0, pin_memory=False) + kwargs = { + 'root':params.data_dir, + 'download':True, + 'transform':transforms.Compose(preprocessing_pipeline) + } + if hasattr(params, 'fast_mnist') and params.fast_mnist: + kwargs['device'] = params.device + kwargs['train'] = True + train_loader = torch.utils.data.DataLoader( + FastMNIST(**kwargs), batch_size=params.batch_size, + shuffle=params.shuffle_data, num_workers=0, pin_memory=False) + kwargs['train'] = False + val_loader = None + test_loader = torch.utils.data.DataLoader( + FastMNIST(**kwargs), batch_size=params.batch_size, + shuffle=params.shuffle_data, num_workers=0, pin_memory=False) + else: + kwargs['train'] = True + train_loader = torch.utils.data.DataLoader( + MNIST(**kwargs), batch_size=params.batch_size, + shuffle=params.shuffle_data, num_workers=0, pin_memory=True) + kwargs['train'] = False + val_loader = None + test_loader = torch.utils.data.DataLoader( + MNIST(**kwargs), batch_size=params.batch_size, + shuffle=params.shuffle_data, num_workers=0, pin_memory=True) + elif(params.dataset.lower() == 'dsprites'): root = os.path.join(*[params.data_dir]) dsprites_file = os.path.join(*[root, 'dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz']) @@ -67,6 +108,7 @@ def load_dataset(params): pin_memory=False) val_loader = None test_loader = None + elif(params.dataset.lower() == 'synthetic'): preprocessing_pipeline = [transforms.ToTensor(), transforms.Lambda(lambda x: x.permute(1, 2, 0)) # channels last @@ -80,6 +122,7 @@ def load_dataset(params): val_loader = None test_loader = None new_params["num_pixels"] = params.data_edge_size**2 + else: assert False, (f'Supported datasets are ["mnist", "dsprites", "synthetic"], not {dataset_name}') new_params = {} diff --git a/utils/loaders.py b/utils/loaders.py index e79553ad..21b23af5 100644 --- a/utils/loaders.py +++ b/utils/loaders.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.getcwd()) +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import DeepSparseCoding.utils.file_utils as file_utils diff --git a/utils/run_utils.py b/utils/run_utils.py index aa472db2..958f1a8f 100644 --- a/utils/run_utils.py +++ b/utils/run_utils.py @@ -33,7 +33,7 @@ def train_epoch(epoch, model, loader): train_single_model(model, loss) if model.params.train_logs_per_epoch is not None: if(batch_idx % int(num_batches/model.params.train_logs_per_epoch) == 0.): - batch_step = epoch * model.params.batches_per_epoch + batch_idx + batch_step = int((epoch - 1) * model.params.batches_per_epoch + batch_idx) model.print_update( input_data=inputs[0], input_labels=target, batch_step=batch_step) if(model.params.model_type.lower() == 'ensemble'): @@ -46,7 +46,7 @@ def train_epoch(epoch, model, loader): def test_single_model(model, data, target, epoch): output = model(data) #test_loss = torch.nn.functional.nll_loss(output, target, reduction='sum').item() - test_loss = torch.nn.CorssEntropyLoss()(output, target) + test_loss = torch.nn.CrossEntropyLoss()(output, target) pred = output.max(1, keepdim=True)[1] correct = pred.eq(target.view_as(pred)).sum().item() return (test_loss, correct) From 8889ebe0b81e61da75ea96f0dea3f4344ec0ba6a Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 30 Sep 2020 13:14:59 +0200 Subject: [PATCH 07/44] adds proper error message and IPython embed on js load fail --- tf1x/utils/logger.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tf1x/utils/logger.py b/tf1x/utils/logger.py index c527e488..edfb404e 100644 --- a/tf1x/utils/logger.py +++ b/tf1x/utils/logger.py @@ -90,7 +90,13 @@ def read_js(self, tokens, text): assert len(tokens) == 2, ("Input variable tokens must be a list of length 2") matches = re.findall(re.escape(tokens[0])+"([\s\S]*?)"+re.escape(tokens[1]), text) if len(matches) > 1: - js_matches = [js.loads(match) for match in matches] + js_matches = [] + for match_idx, match in enumerate(matches): + try: + js_matches.append(js.loads(match)) + except: + print(f'ERROR: JSON load failed on match index {match_idx}') + import IPython; IPython.embed(); raise SystemExit else: js_matches = [js.loads(matches[0])] return js_matches From bb17d46d18595910a49270b4ab497cac5a59d00d Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 30 Sep 2020 15:10:05 +0200 Subject: [PATCH 08/44] updates plots for final JOV submission --- tf1x/utils/jov_funcs.py | 402 ++++++++++++-------- tf1x/vis/JOV_Euler_Attacks.ipynb | 83 ++++- tf1x/vis/JOV_figs.ipynb | 617 ++++++++++++++++--------------- 3 files changed, 658 insertions(+), 444 deletions(-) diff --git a/tf1x/utils/jov_funcs.py b/tf1x/utils/jov_funcs.py index be46539e..5dd8f78e 100644 --- a/tf1x/utils/jov_funcs.py +++ b/tf1x/utils/jov_funcs.py @@ -119,10 +119,6 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh width=arrow_width, head_width=arrow_head_width, head_length=arrow_head_length, fc='k', ec='k', linestyle='-', linewidth=arrow_linewidth) tenth_range_shift = ((max(analysis_dict['x_range']) - min(analysis_dict['x_range']))/10) # For shifting labels - #text_handle = curve_axes[-1].text( - # target_vector_x+(tenth_range_shift*phi_k_text_x_offset), - # target_vector_y+(tenth_range_shift*phi_k_text_y_offset), - # r'$\Phi_{k}$', horizontalalignment='center', verticalalignment='center') # plot comparison neuron arrow & label proj_comparison = analysis_dict['contour_dataset']['proj_comparison_vect'][neuron_index][orth_index] comparison_vector_x = proj_comparison[0].item() @@ -130,10 +126,6 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh curve_axes[-1].arrow(0, 0, comparison_vector_x, comparison_vector_y, width=arrow_width, head_width=arrow_head_width, head_length=arrow_head_length, fc='k', ec='k', linestyle='-', linewidth=arrow_linewidth) - #text_handle = curve_axes[-1].text( - # comparison_vector_x+(tenth_range_shift*phi_j_text_x_offset), - # comparison_vector_y+(tenth_range_shift*phi_j_text_y_offset), - # r'$\Phi_{j}$', horizontalalignment='center', verticalalignment='center') # Plot orthogonal vector Nu proj_orth = analysis_dict['contour_dataset']['proj_orth_vect'][neuron_index][orth_index] orth_vector_x = proj_orth[0].item() @@ -141,10 +133,6 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh curve_axes[-1].arrow(0, 0, orth_vector_x, orth_vector_y, width=arrow_width, head_width=arrow_head_width, head_length=arrow_head_length, fc='k', ec='k', linestyle='-', linewidth=arrow_linewidth) - #text_handle = curve_axes[-1].text( - # orth_vector_x+(tenth_range_shift*nu_text_x_offset), - # orth_vector_y+(tenth_range_shift*nu_text_y_offset), - # r'$\nu$', horizontalalignment='center', verticalalignment='center') # Plot axes curve_axes[-1].set_aspect('equal') curve_axes[-1].plot(analysis_dict['x_range'], [0,0], color='k', linewidth=arrow_linewidth/2) @@ -153,14 +141,11 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh k_idx = analysis_dict["target_neuron_ids"][neuron_index] curv_val = curvatures[y_id, x_id] curve_axes[-1].set_title( - #f'k={k_idx}; j={j_idx}\nC={curv_val:.3f}', f'C={curv_val:.3f}', fontsize=rcParams['axes.titlesize']/2, pad=2, horizontalalignment='center' ) - #curve_axes[-1].text(x=-0.08, y=1.75, s=f'C={curvatures[y_id, x_id]:.3f}', - # horizontalalignment='right', verticalalignment='center', fontsize=6) if y_id==0: curve_axes[-1].set_ylabel(str(neuron_index), visible=True) # Add colorbar @@ -180,6 +165,123 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh return fig, contour_handles +def plot_curvature_histograms(activity, contour_pts, contour_angle, view_elevation, contour_text_loc, hist_list, + label_list, color_list, mesh_color, bin_centers, title, xlabel, curve_lims, + scatter, log=True, text_width=200, width_ratio=1.0, dpi=100): + gs0_wspace = 0.5 + hspace_hist = 0.7 + wspace_hist = 0.08 + iso_response_line_thickness = 2 + respone_attenuation_line_thickness = 2 + num_y_plots = 2 if log else 1 + num_x_plots = 1 + fig = plt.figure(figsize=set_size(text_width, width_ratio, [num_y_plots, num_x_plots]), dpi=dpi) + gs_base = gridspec.GridSpec(num_y_plots, num_x_plots, wspace=gs0_wspace) + if log: + curve_ax = fig.add_subplot(gs_base[0], projection='3d') + curve_ax.minorticks_off() + x_mesh, y_mesh = np.meshgrid(*contour_pts) + curve_ax.set_zlim(0, 1) + curve_ax.set_xlim3d(5, 200) + curve_ax.grid(False) + curve_ax.set_xticklabels([]) + curve_ax.set_yticklabels([]) + curve_ax.set_zticklabels([]) + curve_ax.zaxis.set_rotate_label(False) + if scatter: + curve_ax.scatter(x_mesh, y_mesh, activity, color=mesh_color, s=0.01) + else: + curve_ax.plot_wireframe(x_mesh, y_mesh, activity, rcount=100, ccount=100, color=mesh_color, zorder=1, + linestyles='dotted', linewidths=0.3, alpha=1.0) + # Plane vector visualizations + v = Arrow3D([-200/3., -200/3.], [200/2., 200/2.+200/16.], + [0, 0.0], mutation_scale=10, + lw=0.5, arrowstyle='-|>', color='red', linestyle='dashed') + curve_ax.add_artist(v) + curve_ax.text(-300/3., 280/3.0, 0.0, r'$\nu$', color='red') + phi_k = Arrow3D([-200/3., 0.], [200/2., 200/2.], + [0, 0.0], mutation_scale=10, + lw=1, arrowstyle='-|>', color='red', linestyle = 'dashed') + curve_ax.add_artist(phi_k) + curve_ax.text(-175/3., 250/3.0, 0.0, r'${\phi}_{k}$', color='red') + # Iso-response curve + loc0, loc1, loc2 = contour_text_loc[0] + curve_ax.text(loc0, loc1, loc2, 'Iso-\nresponse', color='black', weight='bold', zorder=10) + lines = np.array([0.2, 0.203, 0.197]) - 0.1 + for i in lines: + curve_ax.contour3D(x_mesh, y_mesh, activity, [i], colors='black', linewidths=2, zorder=2) + # Response attenuation curve + loc0, loc1, loc2 = contour_text_loc[1] + curve_ax.text(loc0, loc1, loc2, 'Response\nAttenuation', color='black', weight='bold', zorder=10) + att_line_offset = 165 + x, y = contour_pts + curve_ax.plot(np.zeros_like(x)+att_line_offset, y, activity[:, att_line_offset], + color='black', lw=2, zorder=2) + # Activity label + #loc0, loc1, loc2 = contour_text_loc[2] + #curve_ax.text(loc0, loc1, loc2, 'Activity', color='black', weight='bold', zorder=10, zdir='z') + # Additional settings + curve_ax.view_init(view_elevation, contour_angle) + scaling = np.array([getattr(curve_ax, 'get_{}lim'.format(dim))() for dim in 'xyz']) + curve_ax.auto_scale_xyz(*[[np.min(scaling), np.max(scaling)]]*3) # square aspect + curve_ax._axis3don = False + gs_base_idx = 1 if log else 0 + # Histogram plots + num_hist_y_plots = 2 + num_hist_x_plots = 2 + gs_hist = gridspec.GridSpecFromSubplotSpec(num_hist_y_plots, num_hist_x_plots, gs_base[gs_base_idx], + hspace=hspace_hist, wspace=wspace_hist) + orig_ax = fig.add_subplot(gs_hist[0,0]) + axes = [] + for sub_plt_y in range(0, num_hist_y_plots): + axes.append([]) + for sub_plt_x in range(0, num_hist_x_plots): + if (sub_plt_x, sub_plt_y) == (0,0): + axes[sub_plt_y].append(orig_ax) + else: + axes[sub_plt_y].append(fig.add_subplot(gs_hist[sub_plt_y, sub_plt_x], sharey=orig_ax)) + all_x_lists = zip(hist_list, label_list, color_list, bin_centers, title) + for axis_x, (curvature_hist, sub_label, sub_color, sub_bins, sub_title) in enumerate(all_x_lists): + sub_bins = np.squeeze(sub_bins) + all_y_lists = zip(curvature_hist, sub_label, sub_color, xlabel) + for axis_y, (dataset_hist, axis_labels, axis_colors, sub_xlabel) in enumerate(all_y_lists): + axes[axis_y][axis_x].spines['top'].set_visible(False) + axes[axis_y][axis_x].spines['right'].set_visible(False) + axes[axis_y][axis_x].set_xticks(sub_bins, minor=True) + axes[axis_y][axis_x].set_xticks(sub_bins[::int(len(sub_bins)/4)], minor=False) + axes[axis_y][axis_x].xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.3f')) + for neuron_hist, label, color in zip(dataset_hist, axis_labels, axis_colors): + neuron_hist = np.squeeze(neuron_hist) + if log: + axes[axis_y][axis_x].semilogy(sub_bins, neuron_hist, color=color, linestyle='-', + drawstyle='steps-mid', label=label) + axes[axis_y][axis_x].yaxis.set_major_formatter(matplotlib.ticker.LogFormatterSciNotation()) + else: + axes[axis_y][axis_x].plot(sub_bins, neuron_hist, color=color, linestyle='-', drawstyle='steps-mid', label=label) + axes[axis_y][axis_x].axvline(0.0, color='black', linestyle='dashed', linewidth=1) + if axis_y == 0: + axes[axis_y][axis_x].set_title(sub_title) + axes[axis_y][axis_x].set_xlabel(sub_xlabel) + if axis_x == 0: + if log: + axes[axis_y][axis_x].set_ylabel('Relative\nLog Frequency') + else: + axes[axis_y][axis_x].set_ylabel('Relative\nFrequency') + ax_handles, ax_labels = axes[axis_y][axis_x].get_legend_handles_labels() + legend = axes[axis_y][axis_x].legend(handles=ax_handles, labels=ax_labels, loc='upper right', + ncol=3, borderaxespad=0., borderpad=0., handlelength=0., columnspacing=-0.5, + labelspacing=0., bbox_to_anchor=(0.95, 0.95)) + legend.get_frame().set_linewidth(0.0) + for text, color in zip(legend.get_texts(), axis_colors): + text.set_color(color) + for item in legend.legendHandles: + item.set_visible(False) + if axis_x == 1: + axes[axis_y][axis_x].tick_params(axis='y', labelleft=False) + plt.show() + return fig + + def plot_group_iso_contours(analyzer_list, neuron_indices, orth_indices, num_levels, x_range, y_range, show_contours=True, curvature=None, text_width=200, width_fraction=1.0, dpi=100): arrow_width = 0.0 @@ -421,140 +523,142 @@ def compute_curvature_hists(analyzer_list, num_bins): rand_hist, _ = np.histogram(flat_rand_curvatures, attn_bins, density=False) analyzer.attn_rand_hist = rand_hist / len(flat_rand_curvatures) -def plot_curvature_histograms(activity, contour_pts, contour_angle, contour_text_loc, hist_list, label_list, - color_list, mesh_color, bin_centers, title, xlabel, curve_lims, log=True, - text_width=200, width_ratio=1.0, dpi=100): - gs0_wspace = 0.5 - hspace_hist = 0.7 - wspace_hist = 0.08 - view_elevation = 30 - iso_response_line_thickness = 2 - respone_attenuation_line_thickness = 2 - num_y_plots = 2 - num_x_plots = 1 - fig = plt.figure(figsize=set_size(text_width, width_ratio, [num_y_plots, num_x_plots]), dpi=dpi) - gs_base = gridspec.GridSpec(num_y_plots, num_x_plots, wspace=gs0_wspace) - curve_ax = fig.add_subplot(gs_base[0], projection='3d') - x_mesh, y_mesh = np.meshgrid(*contour_pts) - curve_ax.set_zlim(0, 1) - curve_ax.set_xlim3d(5, 200) - curve_ax.grid(b=False, zorder=0) - x_ticks = curve_ax.get_xticks().tolist() - x_ticks = np.round(np.linspace(curve_lims['x'][0], curve_lims['x'][1], - len(x_ticks)), 1).astype(str) - a_x = [' ']*len(x_ticks) - a_x[1] = x_ticks[1] - a_x[-1] = x_ticks[-1] - curve_ax.set_xticklabels(a_x) - y_ticks = curve_ax.get_yticks().tolist() - y_ticks = np.round(np.linspace(curve_lims['y'][0], curve_lims['y'][1], - len(y_ticks)), 1).astype(str) - a_y = [' ']*len(y_ticks) - a_y[1] = y_ticks[1] - a_y[-1] = y_ticks[-1] - curve_ax.set_yticklabels(a_y) - curve_ax.set_zticklabels([]) - curve_ax.zaxis.set_rotate_label(False) - curve_ax.set_zlabel('Normalized\nActivity', rotation=95, labelpad=-12., position=(-10., 0.)) - #curve_ax.scatter(x_mesh, y_mesh, activity, color=mesh_color, s=0.01) - curve_ax.plot_wireframe(x_mesh, y_mesh, activity, rcount=100, ccount=100, color=mesh_color, zorder=1, - linestyles='dotted', linewidths=0.5, alpha=1.0) - # Plane vector visualizations - v = Arrow3D([-200/3., -200/3.], [200/2., 200/2.+200/16.], - [0, 0.0], mutation_scale=10, - lw=1, arrowstyle='-|>', color='red', linestyle='dashed') - curve_ax.add_artist(v) - curve_ax.text(-300/3., 280/3.0, 0.0, r'$\nu$', color='red') - phi_k = Arrow3D([-200/3., 0.], [200/2., 200/2.], - [0, 0.0], mutation_scale=10, - lw=1, arrowstyle='-|>', color='red', linestyle = 'dashed') - curve_ax.add_artist(phi_k) - curve_ax.text(-175/3., 250/3.0, 0.0, r'${\phi}_{k}$', color='red') - # Iso-response curve - loc0, loc1, loc2 = contour_text_loc[0] - curve_ax.text(loc0, loc1, loc2, 'Iso-\nresponse', color='black', weight='bold', zorder=10) - iso_line_offset = 165 - x, y = contour_pts - curve_ax.plot(np.zeros_like(x)+iso_line_offset, y, activity[:, iso_line_offset], - color='black', lw=2, zorder=2) - # Response attenuation curve - loc0, loc1, loc2 = contour_text_loc[1] - curve_ax.text(loc0, loc1, loc2, 'Response\nAttenuation', color='black', weight='bold', zorder=10) - lines = np.array([0.2, 0.203, 0.197]) - 0.1 - for i in lines: - curve_ax.contour3D(x_mesh, y_mesh, activity, [i], colors='black', linewidths=2, zorder=2) - # Additional settings - curve_ax.view_init(view_elevation, contour_angle) - scaling = np.array([getattr(curve_ax, 'get_{}lim'.format(dim))() for dim in 'xyz']) - curve_ax.auto_scale_xyz(*[[np.min(scaling), np.max(scaling)]]*3) # make sure it has a square aspect - num_hist_y_plots = 2 - num_hist_x_plots = 2 - gs_hist = gridspec.GridSpecFromSubplotSpec(num_hist_y_plots, num_hist_x_plots, gs_base[1], - hspace=hspace_hist, wspace=wspace_hist) - orig_ax = fig.add_subplot(gs_hist[0,0]) - axes = [] - for sub_plt_y in range(0, num_hist_y_plots): - axes.append([]) - for sub_plt_x in range(0, num_hist_x_plots): - if (sub_plt_x, sub_plt_y) == (0,0): - axes[sub_plt_y].append(orig_ax) - else: - axes[sub_plt_y].append(fig.add_subplot(gs_hist[sub_plt_y, sub_plt_x], sharey=orig_ax)) - #[curvature type] [iso/att] - #[dataset type] [comp/rand] - #[target neuron id] - all_x_lists = zip(hist_list, label_list, color_list, bin_centers, title) - for axis_x, (curvature_hist, sub_label, sub_color, sub_bins, sub_title) in enumerate(all_x_lists): - sub_bins = np.squeeze(sub_bins) - #max_hist_val = 0.001 - #min_hist_val = 100 - all_y_lists = zip(curvature_hist, sub_label, sub_color, xlabel) - for axis_y, (dataset_hist, axis_labels, axis_colors, sub_xlabel) in enumerate(all_y_lists): - axes[axis_y][axis_x].spines['top'].set_visible(False) - axes[axis_y][axis_x].spines['right'].set_visible(False) - axes[axis_y][axis_x].set_xticks(sub_bins, minor=True) - axes[axis_y][axis_x].set_xticks(sub_bins[::int(len(sub_bins)/4)], minor=False) - axes[axis_y][axis_x].xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.3f')) - for neuron_hist, label, color in zip(dataset_hist, axis_labels, axis_colors): - neuron_hist = np.squeeze(neuron_hist) - plot_hist = [] - - if log: - axes[axis_y][axis_x].semilogy(sub_bins, neuron_hist, color=color, linestyle='-', drawstyle='steps-mid', label=label) - #axes[axis_y][axis_x].set_yscale('log') - axes[axis_y][axis_x].yaxis.set_major_formatter( - ticker.FuncFormatter( - lambda y,pos: ('{{:.{:1d}f}}'.format(int(np.maximum(-np.log10(y),0)))).format(y) - ) - ) - else: - axes[axis_y][axis_x].plot(sub_bins, neuron_hist, color=color, linestyle='-', drawstyle='steps-mid', label=label) - #if np.max(hist) > max_hist_val: - # max_hist_val = np.max(hist) - #if np.min(hist) < min_hist_val: - # min_hist_val = np.min(hist) - axes[axis_y][axis_x].axvline(0.0, color='black', linestyle='dashed', linewidth=1) - if axis_y == 0: - axes[axis_y][axis_x].set_title(sub_title) - axes[axis_y][axis_x].set_xlabel(sub_xlabel) - if axis_x == 0: - if log: - axes[axis_y][axis_x].set_ylabel('Relative\nLog Frequency') - else: - axes[axis_y][axis_x].set_ylabel('Relative\nFrequency') - ax_handles, ax_labels = axes[axis_y][axis_x].get_legend_handles_labels() - legend = axes[axis_y][axis_x].legend(handles=ax_handles, labels=ax_labels, loc='upper right', - ncol=3, borderaxespad=0., borderpad=0., handlelength=0., columnspacing=-0.5, - labelspacing=0., bbox_to_anchor=(0.95, 0.95)) - legend.get_frame().set_linewidth(0.0) - for text, color in zip(legend.get_texts(), axis_colors): - text.set_color(color) - for item in legend.legendHandles: - item.set_visible(False) - if axis_x == 1: - axes[axis_y][axis_x].tick_params(axis='y', labelleft=False) - plt.show() - return fig + +#def plot_curvature_histograms(activity, contour_pts, contour_angle, contour_text_loc, hist_list, label_list, +# color_list, mesh_color, bin_centers, title, xlabel, curve_lims, log=True, +# text_width=200, width_ratio=1.0, dpi=100): +# gs0_wspace = 0.5 +# hspace_hist = 0.7 +# wspace_hist = 0.08 +# view_elevation = 30 +# iso_response_line_thickness = 2 +# respone_attenuation_line_thickness = 2 +# num_y_plots = 2 +# num_x_plots = 1 +# fig = plt.figure(figsize=set_size(text_width, width_ratio, [num_y_plots, num_x_plots]), dpi=dpi) +# gs_base = gridspec.GridSpec(num_y_plots, num_x_plots, wspace=gs0_wspace) +# curve_ax = fig.add_subplot(gs_base[0], projection='3d') +# x_mesh, y_mesh = np.meshgrid(*contour_pts) +# curve_ax.set_zlim(0, 1) +# curve_ax.set_xlim3d(5, 200) +# curve_ax.grid(b=False, zorder=0) +# x_ticks = curve_ax.get_xticks().tolist() +# x_ticks = np.round(np.linspace(curve_lims['x'][0], curve_lims['x'][1], +# len(x_ticks)), 1).astype(str) +# a_x = [' ']*len(x_ticks) +# a_x[1] = x_ticks[1] +# a_x[-1] = x_ticks[-1] +# curve_ax.set_xticklabels(a_x) +# y_ticks = curve_ax.get_yticks().tolist() +# y_ticks = np.round(np.linspace(curve_lims['y'][0], curve_lims['y'][1], +# len(y_ticks)), 1).astype(str) +# a_y = [' ']*len(y_ticks) +# a_y[1] = y_ticks[1] +# a_y[-1] = y_ticks[-1] +# curve_ax.set_yticklabels(a_y) +# curve_ax.set_zticklabels([]) +# curve_ax.zaxis.set_rotate_label(False) +# curve_ax.set_zlabel('Normalized\nActivity', rotation=95, labelpad=-12., position=(-10., 0.)) +# #curve_ax.scatter(x_mesh, y_mesh, activity, color=mesh_color, s=0.01) +# curve_ax.plot_wireframe(x_mesh, y_mesh, activity, rcount=100, ccount=100, color=mesh_color, zorder=1, +# linestyles='dotted', linewidths=0.5, alpha=1.0) +# # Plane vector visualizations +# v = Arrow3D([-200/3., -200/3.], [200/2., 200/2.+200/16.], +# [0, 0.0], mutation_scale=10, +# lw=1, arrowstyle='-|>', color='red', linestyle='dashed') +# curve_ax.add_artist(v) +# curve_ax.text(-300/3., 280/3.0, 0.0, r'$\nu$', color='red') +# phi_k = Arrow3D([-200/3., 0.], [200/2., 200/2.], +# [0, 0.0], mutation_scale=10, +# lw=1, arrowstyle='-|>', color='red', linestyle = 'dashed') +# curve_ax.add_artist(phi_k) +# curve_ax.text(-175/3., 250/3.0, 0.0, r'${\phi}_{k}$', color='red') +# # Iso-response curve +# loc0, loc1, loc2 = contour_text_loc[0] +# curve_ax.text(loc0, loc1, loc2, 'Iso-\nresponse', color='black', weight='bold', zorder=10) +# iso_line_offset = 165 +# x, y = contour_pts +# curve_ax.plot(np.zeros_like(x)+iso_line_offset, y, activity[:, iso_line_offset], +# color='black', lw=2, zorder=2) +# # Response attenuation curve +# loc0, loc1, loc2 = contour_text_loc[1] +# curve_ax.text(loc0, loc1, loc2, 'Response\nAttenuation', color='black', weight='bold', zorder=10) +# lines = np.array([0.2, 0.203, 0.197]) - 0.1 +# for i in lines: +# curve_ax.contour3D(x_mesh, y_mesh, activity, [i], colors='black', linewidths=2, zorder=2) +# # Additional settings +# curve_ax.view_init(view_elevation, contour_angle) +# scaling = np.array([getattr(curve_ax, 'get_{}lim'.format(dim))() for dim in 'xyz']) +# curve_ax.auto_scale_xyz(*[[np.min(scaling), np.max(scaling)]]*3) # make sure it has a square aspect +# num_hist_y_plots = 2 +# num_hist_x_plots = 2 +# gs_hist = gridspec.GridSpecFromSubplotSpec(num_hist_y_plots, num_hist_x_plots, gs_base[1], +# hspace=hspace_hist, wspace=wspace_hist) +# orig_ax = fig.add_subplot(gs_hist[0,0]) +# axes = [] +# for sub_plt_y in range(0, num_hist_y_plots): +# axes.append([]) +# for sub_plt_x in range(0, num_hist_x_plots): +# if (sub_plt_x, sub_plt_y) == (0,0): +# axes[sub_plt_y].append(orig_ax) +# else: +# axes[sub_plt_y].append(fig.add_subplot(gs_hist[sub_plt_y, sub_plt_x], sharey=orig_ax)) +# #[curvature type] [iso/att] +# #[dataset type] [comp/rand] +# #[target neuron id] +# all_x_lists = zip(hist_list, label_list, color_list, bin_centers, title) +# for axis_x, (curvature_hist, sub_label, sub_color, sub_bins, sub_title) in enumerate(all_x_lists): +# sub_bins = np.squeeze(sub_bins) +# #max_hist_val = 0.001 +# #min_hist_val = 100 +# all_y_lists = zip(curvature_hist, sub_label, sub_color, xlabel) +# for axis_y, (dataset_hist, axis_labels, axis_colors, sub_xlabel) in enumerate(all_y_lists): +# axes[axis_y][axis_x].spines['top'].set_visible(False) +# axes[axis_y][axis_x].spines['right'].set_visible(False) +# axes[axis_y][axis_x].set_xticks(sub_bins, minor=True) +# axes[axis_y][axis_x].set_xticks(sub_bins[::int(len(sub_bins)/4)], minor=False) +# axes[axis_y][axis_x].xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.3f')) +# for neuron_hist, label, color in zip(dataset_hist, axis_labels, axis_colors): +# neuron_hist = np.squeeze(neuron_hist) +# plot_hist = [] +# +# if log: +# axes[axis_y][axis_x].semilogy(sub_bins, neuron_hist, color=color, linestyle='-', drawstyle='steps-mid', label=label) +# #axes[axis_y][axis_x].set_yscale('log') +# axes[axis_y][axis_x].yaxis.set_major_formatter( +# ticker.FuncFormatter( +# lambda y,pos: ('{{:.{:1d}f}}'.format(int(np.maximum(-np.log10(y),0)))).format(y) +# ) +# ) +# else: +# axes[axis_y][axis_x].plot(sub_bins, neuron_hist, color=color, linestyle='-', drawstyle='steps-mid', label=label) +# #if np.max(hist) > max_hist_val: +# # max_hist_val = np.max(hist) +# #if np.min(hist) < min_hist_val: +# # min_hist_val = np.min(hist) +# axes[axis_y][axis_x].axvline(0.0, color='black', linestyle='dashed', linewidth=1) +# if axis_y == 0: +# axes[axis_y][axis_x].set_title(sub_title) +# axes[axis_y][axis_x].set_xlabel(sub_xlabel) +# if axis_x == 0: +# if log: +# axes[axis_y][axis_x].set_ylabel('Relative\nLog Frequency') +# else: +# axes[axis_y][axis_x].set_ylabel('Relative\nFrequency') +# ax_handles, ax_labels = axes[axis_y][axis_x].get_legend_handles_labels() +# legend = axes[axis_y][axis_x].legend(handles=ax_handles, labels=ax_labels, loc='upper right', +# ncol=3, borderaxespad=0., borderpad=0., handlelength=0., columnspacing=-0.5, +# labelspacing=0., bbox_to_anchor=(0.95, 0.95)) +# legend.get_frame().set_linewidth(0.0) +# for text, color in zip(legend.get_texts(), axis_colors): +# text.set_color(color) +# for item in legend.legendHandles: +# item.set_visible(False) +# if axis_x == 1: +# axes[axis_y][axis_x].tick_params(axis='y', labelleft=False) +# plt.show() +# return fig + def plot_contrast_orientation_tuning(bf_indices, contrasts, orientations, activations, figsize=(32,32)): ''' diff --git a/tf1x/vis/JOV_Euler_Attacks.ipynb b/tf1x/vis/JOV_Euler_Attacks.ipynb index efef9419..60a7d42c 100644 --- a/tf1x/vis/JOV_Euler_Attacks.ipynb +++ b/tf1x/vis/JOV_Euler_Attacks.ipynb @@ -32,11 +32,13 @@ "outputs": [], "source": [ "import autograd.numpy as np\n", + "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "\n", "root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))\n", "if root_path not in sys.path: sys.path.append(root_path)\n", "\n", + "import DeepSparseCoding.tf1x.utils.jov_funcs as jov\n", "import DeepSparseCoding.tf1x.analysis.analysis_picker as ap\n", "from DeepSparseCoding.tf1x.data.dataset import Dataset\n", "import DeepSparseCoding.tf1x.utils.data_processing as dp" @@ -65,7 +67,8 @@ " import schematic_utils\n", "except ImportError:\n", " import sys\n", - " sys.path.append(\"../schematic_figure/\")\n", + " usr = os.path.expanduser('~')\n", + " sys.path.append(usr+'/Work/DeepSparseCoding/tf1x/')\n", " import schematic_utils" ] }, @@ -75,9 +78,30 @@ "metadata": {}, "outputs": [], "source": [ - "figsize = (16, 16)\n", + "text_width = 540.60236 #pt 416.83269 #pt = 14.65cm\n", + "text_width_cm = 18.9973 # 14.705\n", + "fontsize = 10\n", + "dpi = 300\n", + "file_extensions = ['.pdf']#, '.eps', '.png']\n", + "#figsize = (16, 16)\n", + "num_y_plots = 3\n", + "num_x_plots = 1\n", + "width_ratio = 1.0\n", + "figsize = jov.set_size(text_width, width_ratio, [num_y_plots, num_x_plots])\n", "fontsize = 20\n", - "dpi = 200" + "font_settings = {\n", + " \"text.usetex\": True,\n", + " \"font.family\": 'serif',\n", + " \"font.serif\": 'Computer Modern Roman',\n", + " \"axes.labelsize\": fontsize,\n", + " \"axes.titlesize\": fontsize,\n", + " \"figure.titlesize\": fontsize+2,\n", + " \"font.size\": fontsize,\n", + " \"legend.fontsize\": fontsize,\n", + " \"xtick.labelsize\": fontsize-2,\n", + " \"ytick.labelsize\": fontsize-2,\n", + "}\n", + "mpl.rcParams.update(font_settings)" ] }, { @@ -364,7 +388,8 @@ "metadata": {}, "outputs": [], "source": [ - "f = plt.figure(figsize=(2*figsize[0],figsize[1]), dpi=dpi)\n", + "figsize = (2*16,16)#(figsize[0], figsize[1])\n", + "f = plt.figure(figsize=figsize, dpi=dpi)\n", "fig_shape = (1, 4)\n", "\n", "mlp_ax = plt.subplot2grid(fig_shape, loc=(0, 0), colspan=1, fig=f)\n", @@ -404,11 +429,57 @@ "metadata": {}, "outputs": [], "source": [ - "for ext in [\".png\", \".eps\"]:\n", - " save_name = (analyzer.analysis_out_dir+\"/vis/contours_and_gradients_schematic\"\n", + "for ext in file_extensions:#[\".png\", \".eps\"]:\n", + " save_name = (analyzer.analysis_out_dir+\"/vis/contours_and_gradients_schematic_new\"\n", " +\"_\"+analyzer.analysis_params.save_info+ext)\n", " f.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig_size_inches = f.get_size_inches()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig_size_inches" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f.set_size_inches(fig_size_inches/2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for ext in file_extensions:#[\".png\", \".eps\"]:\n", + " save_name = (analyzer.analysis_out_dir+\"/vis/contours_and_gradients_schematic_small\"\n", + " +\"_\"+analyzer.analysis_params.save_info+ext)\n", + " f.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/tf1x/vis/JOV_figs.ipynb b/tf1x/vis/JOV_figs.ipynb index 7c92c4ca..10c8bb02 100644 --- a/tf1x/vis/JOV_figs.ipynb +++ b/tf1x/vis/JOV_figs.ipynb @@ -85,10 +85,16 @@ "metadata": {}, "outputs": [], "source": [ - "text_width = 416.83269 #pt = 14.65cm\n", - "text_width_cm = 14.705\n", - "fontsize = 12\n", - "dpi = 1200" + "\"\"\"\n", + "textwidth in pt: 540.60236pt\n", + "textwidth in cm: 18.9973cm\n", + "textwidth in in: 7.48178in\n", + "\"\"\"\n", + "text_width = 540.60236 #pt 416.83269 #pt = 14.65cm\n", + "text_width_cm = 18.9973 # 14.705\n", + "fontsize = 10\n", + "dpi = 300\n", + "file_extensions = ['.pdf']#, '.eps', '.png']" ] }, { @@ -99,10 +105,11 @@ "source": [ "font_settings = {\n", " \"text.usetex\": True,\n", - " \"font.family\": \"serif\",\n", + " \"font.family\": 'serif',\n", + " \"font.serif\": 'Computer Modern Roman',\n", " \"axes.labelsize\": fontsize,\n", " \"axes.titlesize\": fontsize,\n", - " \"figure.titlesize\": fontsize,\n", + " \"figure.titlesize\": fontsize+2,\n", " \"font.size\": fontsize,\n", " \"legend.fontsize\": fontsize,\n", " \"xtick.labelsize\": fontsize-2,\n", @@ -305,18 +312,8 @@ "for analyzer, save_name in zip(analyzer_list, save_names):\n", " analyzer.iso_params = np.load(analyzer.analysis_out_dir+'savefiles/iso_params_'+save_name\n", " +analyzer.analysis_params.save_info+'.npz', allow_pickle=True)['data'].item()\n", - " #min_angle = analyzer.iso_params['min_angle']\n", - " #batch_size = analyzer.iso_params['batch_size']\n", - " #vh_image_scale = analyzer.iso_params['vh_image_scale']\n", - " #comparison_method = analyzer.iso_params['comparison_method']\n", - " #num_neurons = analyzer.iso_params['num_neurons']\n", - " #analyzer.num_comparison_vectors = analyzer.iso_params['num_comparisons']\n", " x_range = analyzer.iso_params['x_range']\n", " y_range = analyzer.iso_params['y_range']\n", - " #num_images = analyzer.iso_params['num_images']\n", - " #params_list = analyzer.iso_params['params_list']\n", - " #iso_save_name = analyzer.iso_params['iso_save_name']\n", - " #target_neuron_ids = analyzer.iso_params['target_neuron_ids']\n", "\n", " iso_vectors = np.load(analyzer.analysis_out_dir+'savefiles/iso_vectors_'+save_name\n", " +analyzer.analysis_params.save_info+'.npz', allow_pickle=True)['data'].item()\n", @@ -383,7 +380,7 @@ " num_levels, x_range, y_range, show_contours, curvature, text_width, width_fraction, dpi)\n", "\n", "for analyzer, neuron_index, orth_index, save_suffix in zip(analyzer_list, neuron_indices, orth_indices, save_names):\n", - " for ext in [\".eps\"]:#[\".png\", \".eps\"]:\n", + " for ext in file_extensions:\n", " neuron_str = str(analyzer.target_neuron_ids[neuron_index])\n", " orth_str = str(analyzer.comparison_neuron_ids[neuron_index][orth_index])\n", " save_name = analyzer.analysis_out_dir+\"/vis/iso_contour_comparison_\"\n", @@ -432,11 +429,11 @@ " num_y,\n", " show_contours,\n", " text_width,\n", - " width_fraction,\n", + " 1.00,\n", " dpi\n", ")\n", "\n", - "for ext in [\".eps\"]:#[\".png\", \".eps\"]:\n", + "for ext in file_extensions:\n", " save_name = analyzer.analysis_out_dir+\"/vis/scaled_iso_contours_set_\"\n", " if not show_contours:\n", " save_name += \"continuous_\"\n", @@ -550,154 +547,6 @@ "full_xlabel = [\"Curvature (Comparison)\", \"Curvature (Random)\"]" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def plot_curvature_histograms(activity, contour_pts, contour_angle, view_elevation, contour_text_loc, hist_list,\n", - " label_list, color_list, mesh_color, bin_centers, title, xlabel, curve_lims,\n", - " scatter, log=True, text_width=200, width_ratio=1.0, dpi=100):\n", - " gs0_wspace = 0.5\n", - " hspace_hist = 0.7\n", - " wspace_hist = 0.08\n", - " iso_response_line_thickness = 2\n", - " respone_attenuation_line_thickness = 2\n", - " num_y_plots = 2\n", - " num_x_plots = 1\n", - " fig = plt.figure(figsize=nc.set_size(text_width, width_ratio, [num_y_plots, num_x_plots]), dpi=dpi)\n", - " gs_base = gridspec.GridSpec(num_y_plots, num_x_plots, wspace=gs0_wspace)\n", - " \n", - " curve_ax = fig.add_subplot(gs_base[0], projection='3d')\n", - " curve_ax.minorticks_off()\n", - " x_mesh, y_mesh = np.meshgrid(*contour_pts)\n", - " curve_ax.set_zlim(0, 1)\n", - " curve_ax.set_xlim3d(5, 200)\n", - " curve_ax.grid(False)\n", - " #x_ticks = curve_ax.get_xticks().tolist()\n", - " #x_ticks = np.round(np.linspace(curve_lims['x'][0], curve_lims['x'][1],\n", - " # len(x_ticks)), 1).astype(str)\n", - " #a_x = [' ']*len(x_ticks)\n", - " #a_x[1] = x_ticks[1]\n", - " #a_x[-1] = x_ticks[-1]\n", - " curve_ax.set_xticklabels([])#a_x)\n", - " #y_ticks = curve_ax.get_yticks().tolist()\n", - " #y_ticks = np.round(np.linspace(curve_lims['y'][0], curve_lims['y'][1],\n", - " # len(y_ticks)), 1).astype(str)\n", - " #a_y = [' ']*len(y_ticks)\n", - " #a_y[1] = y_ticks[1]\n", - " #a_y[-1] = y_ticks[-1]\n", - " curve_ax.set_yticklabels([])#a_y)\n", - " curve_ax.set_zticklabels([])\n", - " curve_ax.zaxis.set_rotate_label(False)\n", - " #curve_ax.set_zlabel('Activity', rotation=95, labelpad=-15., position=(-10., 0.))\n", - " if scatter:\n", - " curve_ax.scatter(x_mesh, y_mesh, activity, color=mesh_color, s=0.01)\n", - " else:\n", - " curve_ax.plot_wireframe(x_mesh, y_mesh, activity, rcount=100, ccount=100, color=mesh_color, zorder=1,\n", - " linestyles='dotted', linewidths=0.3, alpha=1.0)\n", - " \n", - " # Plane vector visualizations\n", - " v = nc.Arrow3D([-200/3., -200/3.], [200/2., 200/2.+200/16.], \n", - " [0, 0.0], mutation_scale=10, \n", - " lw=0.5, arrowstyle='-|>', color='red', linestyle='dashed')\n", - " curve_ax.add_artist(v)\n", - " curve_ax.text(-300/3., 280/3.0, 0.0, r'$\\nu$', color='red')\n", - " phi_k = nc.Arrow3D([-200/3., 0.], [200/2., 200/2.], \n", - " [0, 0.0], mutation_scale=10, \n", - " lw=1, arrowstyle='-|>', color='red', linestyle = 'dashed')\n", - " curve_ax.add_artist(phi_k)\n", - " curve_ax.text(-175/3., 250/3.0, 0.0, r'${\\phi}_{k}$', color='red')\n", - " \n", - " # Iso-response curve\n", - " loc0, loc1, loc2 = contour_text_loc[0]\n", - " curve_ax.text(loc0, loc1, loc2, 'Iso-\\nresponse', color='black', weight='bold', zorder=10)\n", - " lines = np.array([0.2, 0.203, 0.197]) - 0.1\n", - " for i in lines:\n", - " curve_ax.contour3D(x_mesh, y_mesh, activity, [i], colors='black', linewidths=2, zorder=2)\n", - " \n", - " # Response attenuation curve\n", - " loc0, loc1, loc2 = contour_text_loc[1]\n", - " curve_ax.text(loc0, loc1, loc2, 'Response\\nAttenuation', color='black', weight='bold', zorder=10)\n", - " att_line_offset = 165\n", - " x, y = contour_pts\n", - " curve_ax.plot(np.zeros_like(x)+att_line_offset, y, activity[:, att_line_offset],\n", - " color='black', lw=2, zorder=2)\n", - " \n", - " # Activity label\n", - " #loc0, loc1, loc2 = contour_text_loc[2]\n", - " #curve_ax.text(loc0, loc1, loc2, 'Activity', color='black', weight='bold', zorder=10, zdir='z')\n", - " \n", - " # Additional settings\n", - " curve_ax.view_init(view_elevation, contour_angle)\n", - " scaling = np.array([getattr(curve_ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])\n", - " curve_ax.auto_scale_xyz(*[[np.min(scaling), np.max(scaling)]]*3) # square aspect\n", - " #curve_ax.get_xaxis().set_visible(False)\n", - " #curve_ax.get_yaxis().set_visible(False)\n", - " curve_ax._axis3don = False\n", - " #y_aspect = 2\n", - " #scale_x = [np.min(scaling), np.max(scaling)]\n", - " #scale_y = [y_aspect*np.min(scaling), y_aspect*np.max(scaling)]\n", - " #scale_z = [np.min(scaling), np.max(scaling)]\n", - " #curve_ax.auto_scale_xyz(scale_x, scale_y, scale_z)\n", - " \n", - " # Histogram plots\n", - " num_hist_y_plots = 2\n", - " num_hist_x_plots = 2\n", - " gs_hist = gridspec.GridSpecFromSubplotSpec(num_hist_y_plots, num_hist_x_plots, gs_base[1],\n", - " hspace=hspace_hist, wspace=wspace_hist)\n", - " orig_ax = fig.add_subplot(gs_hist[0,0])\n", - " axes = []\n", - " for sub_plt_y in range(0, num_hist_y_plots):\n", - " axes.append([])\n", - " for sub_plt_x in range(0, num_hist_x_plots):\n", - " if (sub_plt_x, sub_plt_y) == (0,0):\n", - " axes[sub_plt_y].append(orig_ax)\n", - " else:\n", - " axes[sub_plt_y].append(fig.add_subplot(gs_hist[sub_plt_y, sub_plt_x], sharey=orig_ax))\n", - " all_x_lists = zip(hist_list, label_list, color_list, bin_centers, title)\n", - " for axis_x, (curvature_hist, sub_label, sub_color, sub_bins, sub_title) in enumerate(all_x_lists):\n", - " sub_bins = np.squeeze(sub_bins)\n", - " all_y_lists = zip(curvature_hist, sub_label, sub_color, xlabel)\n", - " for axis_y, (dataset_hist, axis_labels, axis_colors, sub_xlabel) in enumerate(all_y_lists):\n", - " axes[axis_y][axis_x].spines['top'].set_visible(False)\n", - " axes[axis_y][axis_x].spines['right'].set_visible(False)\n", - " axes[axis_y][axis_x].set_xticks(sub_bins, minor=True)\n", - " axes[axis_y][axis_x].set_xticks(sub_bins[::int(len(sub_bins)/4)], minor=False)\n", - " axes[axis_y][axis_x].xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.3f'))\n", - " for neuron_hist, label, color in zip(dataset_hist, axis_labels, axis_colors):\n", - " neuron_hist = np.squeeze(neuron_hist)\n", - " if log:\n", - " axes[axis_y][axis_x].semilogy(sub_bins, neuron_hist, color=color, linestyle='-',\n", - " drawstyle='steps-mid', label=label)\n", - " axes[axis_y][axis_x].yaxis.set_major_formatter(matplotlib.ticker.LogFormatterSciNotation())\n", - " else:\n", - " axes[axis_y][axis_x].plot(sub_bins, neuron_hist, color=color, linestyle='-', drawstyle='steps-mid', label=label)\n", - " axes[axis_y][axis_x].axvline(0.0, color='black', linestyle='dashed', linewidth=1)\n", - " if axis_y == 0:\n", - " axes[axis_y][axis_x].set_title(sub_title)\n", - " axes[axis_y][axis_x].set_xlabel(sub_xlabel)\n", - " if axis_x == 0:\n", - " if log:\n", - " axes[axis_y][axis_x].set_ylabel('Relative\\nLog Frequency')\n", - " else:\n", - " axes[axis_y][axis_x].set_ylabel('Relative\\nFrequency')\n", - " ax_handles, ax_labels = axes[axis_y][axis_x].get_legend_handles_labels()\n", - " legend = axes[axis_y][axis_x].legend(handles=ax_handles, labels=ax_labels, loc='upper right',\n", - " ncol=3, borderaxespad=0., borderpad=0., handlelength=0., columnspacing=-0.5,\n", - " labelspacing=0., bbox_to_anchor=(0.95, 0.95))\n", - " legend.get_frame().set_linewidth(0.0)\n", - " for text, color in zip(legend.get_texts(), axis_colors):\n", - " text.set_color(color)\n", - " for item in legend.legendHandles:\n", - " item.set_visible(False)\n", - " if axis_x == 1:\n", - " axes[axis_y][axis_x].tick_params(axis='y', labelleft=False)\n", - " plt.show()\n", - " return fig" - ] - }, { "cell_type": "code", "execution_count": null, @@ -718,12 +567,12 @@ "activity_loc = [-27, 150, 1.5]\n", "contour_text_loc = [iso_resp_loc, resp_att_loc, activity_loc]\n", "\n", - "curvature_log_fig = plot_curvature_histograms(contour_activity, contour_pts, contour_angle, view_elevation, \n", + "curvature_log_fig = nc.plot_curvature_histograms(contour_activity, contour_pts, contour_angle, view_elevation, \n", " contour_text_loc, full_hist_list, full_label_list, full_color_list, mesh_color, full_bin_centers,\n", - " full_title, full_xlabel, curve_lims, scatter, log=True, text_width=text_width, width_ratio=1.0, dpi=dpi)\n", + " full_title, full_xlabel, curve_lims, scatter, log=True, text_width=text_width, width_ratio=0.75, dpi=dpi)\n", "\n", "for analyzer in analyzer_list:\n", - " for ext in [\".pdf\"]:\n", + " for ext in file_extensions:\n", " save_name = (analyzer.analysis_out_dir+\"/vis/\"+iso_save_name+\"curvatures_and_histograms_logy\"\n", " +\"_\"+analyzer.analysis_params.save_info+ext)\n", " curvature_log_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.05, dpi=dpi)" @@ -737,12 +586,12 @@ }, "outputs": [], "source": [ - "curvature_lin_fig = plot_curvature_histograms(contour_activity, contour_pts, contour_angle, view_elevation,\n", + "curvature_lin_fig = nc.plot_curvature_histograms(contour_activity, contour_pts, contour_angle, view_elevation,\n", " contour_text_loc, full_hist_list, full_label_list, full_color_list, mesh_color, full_bin_centers,\n", - " full_title, full_xlabel, curve_lims, scatter, log=False, text_width=text_width, width_ratio=1.0, dpi=dpi)\n", + " full_title, full_xlabel, curve_lims, scatter, log=False, text_width=text_width, width_ratio=0.75, dpi=dpi)\n", "\n", "for analyzer in analyzer_list:\n", - " for ext in [\".eps\"]:\n", + " for ext in file_extensions:\n", " save_name = (analyzer.analysis_out_dir+\"/vis/\"+iso_save_name+\"curvatures_and_histograms_liny\"\n", " +\"_\"+analyzer.analysis_params.save_info+ext)\n", " curvature_lin_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.05, dpi=dpi)" @@ -820,10 +669,10 @@ "density = False\n", "\n", "circ_var_fig = nc.plot_circ_variance_histogram(analyzer_list, circ_var_list, color_list, label_list, num_bins,\n", - " density, width_ratios, height_ratios, text_width=text_width, width_ratio=1.0, dpi=dpi)\n", + " density, width_ratios, height_ratios, text_width=text_width, width_ratio=0.75, dpi=dpi)\n", "\n", "for analyzer in analyzer_list:\n", - " for ext in [\".png\", \".eps\"]:\n", + " for ext in file_extensions:\n", " save_name = (analyzer.analysis_out_dir+\"/vis/circular_variance_combo\"\n", " +\"_\"+analyzer.analysis_params.save_info+ext)\n", " circ_var_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)" @@ -838,7 +687,7 @@ "spatial_frequencies = np.stack([np.array(analyzer.bf_stats[\"spatial_frequencies\"]) for analyzer in analyzer_list], axis=0)\n", "circular_variances = np.stack([variance for variance in circ_var_list], axis=0)\n", "\n", - "cv_vs_sf_fig = plt.figure(figsize=nc.set_size(text_width), dpi=dpi)\n", + "cv_vs_sf_fig = plt.figure(figsize=nc.set_size(text_width, fraction=0.75), dpi=dpi)\n", "ax = cv_vs_sf_fig.add_subplot()\n", "for analyzer_idx in range(len(analyzer_list)):\n", " ax.scatter(spatial_frequencies[analyzer_idx, :], circular_variances[analyzer_idx, :],\n", @@ -858,7 +707,7 @@ "plt.show()\n", "\n", "for analyzer in analyzer_list:\n", - " for ext in [\".eps\"]:\n", + " for ext in file_extensions:\n", " save_name = (analyzer.analysis_out_dir+\"/vis/spatial_freq_vs_circular_variance\"\n", " +\"_\"+analyzer.analysis_params.save_info+ext)\n", " cv_vs_sf_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)" @@ -871,7 +720,7 @@ "outputs": [], "source": [ "params_list = [lca_512_vh_params(), lca_1024_vh_params(), lca_2560_vh_params()]\n", - "display_names = [\"512 Neurons\", \"768 Neurons\", \"1024 Neurons\"]#, \"2560 Neurons\"]\n", + "display_names = [\"512 Neurons\", \"1024 Neurons\", \"2560 Neurons\"]\n", "for params, display_name in zip(params_list, display_names):\n", " params.display_name = display_name\n", " params.model_dir = (os.path.expanduser(\"~\")+\"/Work/Projects/\"+params.model_name)\n", @@ -925,10 +774,10 @@ "density = True\n", "\n", "oc_vs_cv_fig = nc.plot_circ_variance_histogram(analyzer_list, circ_var_list, color_list, label_list, num_bins,\n", - " density, width_ratios, height_ratios, text_width=text_width, width_ratio=1.0, dpi=dpi)\n", + " density, width_ratios, height_ratios, text_width=text_width, width_ratio=0.75, dpi=dpi)\n", "\n", "for analyzer in analyzer_list:\n", - " for ext in [\".png\", \".eps\"]:\n", + " for ext in file_extensions:\n", " save_name = (analyzer.analysis_out_dir+\"/vis/overcompleteness_vs_circular_variance\"\n", " +\"_\"+analyzer.analysis_params.save_info+ext)\n", " oc_vs_cv_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.05, dpi=dpi)" @@ -944,7 +793,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ "params_list = [lca_512_vh_params(), lca_768_vh_params(), lca_2560_vh_params()]\n", @@ -970,6 +821,156 @@ " analyzer_list.append(analyzer)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def closest_val_in_array(num, arr):\n", + " curr = arr[0]\n", + " for val in arr:\n", + " if abs(num - val) < abs(num - curr):\n", + " curr = val\n", + " curr_idx = np.argwhere(np.array(arr) == curr).item()\n", + " return arr[curr_idx]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "code_folding": [], + "scrolled": false + }, + "outputs": [], + "source": [ + "num_interesting_vals = [\n", + " np.array([analyzer.nat_selectivity['num_interesting_img_nl'],\n", + " analyzer.nat_selectivity['num_interesting_img_l']])\n", + " for analyzer in analyzer_list]\n", + "\n", + "num_interesting_medians = np.stack(\n", + " [np.array([np.median(np.array(analyzer.nat_selectivity['num_interesting_img_nl'])),\n", + " np.median(np.array(analyzer.nat_selectivity['num_interesting_img_l']))])\n", + " for analyzer in analyzer_list], axis=0)\n", + "\n", + "num_interesting_means = np.stack(\n", + " [np.array([analyzer.nat_selectivity['num_interesting_img_nl_mean'],\n", + " analyzer.nat_selectivity['num_interesting_img_l_mean']])\n", + " for analyzer in analyzer_list], axis=0)\n", + "\n", + "num_interesting_stds = np.stack(\n", + " [np.array([analyzer.nat_selectivity['num_interesting_img_nl_std'],\n", + " analyzer.nat_selectivity['num_interesting_img_l_std']])\n", + " for analyzer in analyzer_list], axis=0)\n", + "\n", + "array = [\n", + " [1, 2, 3],\n", + " [4, 5, 6],\n", + "]\n", + "\n", + "scale = 1\n", + "rc_kwargs = {\n", + " 'fontsize':scale*matplotlib.rcParams['font.size'],\n", + " 'fontfamily':scale*matplotlib.rcParams['font.family'],\n", + " 'legend.fontsize': scale*matplotlib.rcParams['font.size'],\n", + " 'text.labelsize': scale*matplotlib.rcParams['font.size']\n", + "}\n", + "figsize = nc.set_size(text_width, fraction=1.00)\n", + "with plot.rc.context(**rc_kwargs):\n", + " interesting_imgs_fig, axs = plot.subplots(array, sharey=False, sharex=False, aspect=3.0, figsize=figsize)\n", + " for ovc_idx, overcompleteness in enumerate(num_interesting_vals):\n", + " ax = axs[ovc_idx]\n", + " df = pd.DataFrame(\n", + " overcompleteness.T,\n", + " columns=pd.Index(['Sparse Coding', 'Linear'])#, name='xlabel')\n", + " )\n", + " box_parts = ax.boxplot(\n", + " df,\n", + " notch=True,\n", + " fill=False,\n", + " whis=(5, 95),\n", + " marker='*',\n", + " markersize=1.0,\n", + " lw=1.2\n", + " )\n", + " colors = ['md_red', 'md_green']\n", + " for pc_idx, box in enumerate(box_parts['boxes']):\n", + " box.set_color(color_vals[colors[pc_idx]])\n", + " ax.format(\n", + " ylocator=50,\n", + " ylim=[0, np.max([np.max(val) for val in num_interesting_vals])],\n", + " title=analyzer_list[ovc_idx].nat_selectivity['oc_label'],\n", + " ylabel='Average number of\\nintersting images',\n", + " xtickminor=False,\n", + " xgrid=False\n", + " )\n", + "\n", + " for idx, analyzer in enumerate(analyzer_list):\n", + " ax = axs[idx+3]\n", + " angle_min = 0.0\n", + " angle_max = 90.0\n", + " nbins=20\n", + " bins = np.linspace(angle_min, angle_max, nbins)\n", + " lin_data = [mean for mean in analyzer.nat_selectivity['lin_means'] if mean>0]\n", + " non_lin_data = [mean for mean in analyzer.nat_selectivity['lca_means'] if mean>0]\n", + " hist_list = []\n", + " color_list = [color_vals['md_green'], color_vals['md_red']]\n", + " label_list = ['Linear Autoencoder', 'Sparse Coding']\n", + " handles = []\n", + " hist_max_list = []\n", + " for angles, label, color in zip([lin_data, non_lin_data], label_list, color_list):\n", + " # density means the y vals are probability density function at the bin, normalized such that the integral over the range is 1.\n", + " hist, bin_edges = np.histogram(np.array(angles).flatten(), bins, density=False)\n", + " hist_max_list.append(hist.max())\n", + " hist_list.append(hist)\n", + " bin_left, bin_right = bin_edges[:-1], bin_edges[1:]\n", + " bin_centers = bin_left + (bin_right - bin_left)/2\n", + " handles.append(ax.plot(bin_centers, hist, linestyle='-', drawstyle='steps-mid', color=color, label=label))\n", + " oc = analyzer.nat_selectivity['oc_label']\n", + " ax.spines['top'].set_visible(False)\n", + " ax.spines['right'].set_visible(False)\n", + " ax.set_xticks(bin_left, minor=True)\n", + " ax.set_xticks(bin_left[::2], minor=False)\n", + " ax.xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.0f'))\n", + " ax.set_xticks([angle_min, angle_max//2, angle_max])\n", + " mid_val = max(hist_max_list)//2\n", + " max_val = int(max(hist_max_list))\n", + " #interval_list = list(range(0, mid_val+51, 50))\n", + " #new_mid = closest_val_in_array(mid_val, interval_list)\n", + " interval_list = list(range(0, max_val+51, 50))\n", + " new_max = closest_val_in_array(max_val, interval_list)\n", + " new_mid = new_max//2\n", + " ax.set_ylim([0, new_max+0.1*new_max])\n", + " ax.set_yticks([0, new_mid, new_max])\n", + " #axs[-1].legend(handles, ncol=1, frameon=False, loc='ur', bbox_to_anchor=[1, 1.02])\n", + " hist_ax_idx = 3\n", + " axs[hist_ax_idx].format(ylabel='Total number of\\ninteresting images')\n", + " axs[hist_ax_idx:].format(\n", + " suptitle='Sparse Coding Increases Neuron Selectivity for Natural Signals',\n", + " xlabel='Mean image-to-weight angle',\n", + " xlim=[0, 90],\n", + " ygrid=False\n", + " )\n", + "plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "for analyzer in analyzer_list:\n", + " for ext in file_extensions:\n", + " save_name = (analyzer.analysis_out_dir+'/vis/natural_img_selectivity_box_'\n", + " +analyzer.analysis_params.save_info+ext)\n", + " interesting_imgs_fig.savefig(save_name, transparent=False, pad_inches=0.005, dpi=dpi)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -984,6 +985,16 @@ "color_list = [color_vals['md_green'], color_vals['md_red']]\n", "label_list = ['Linear Autoencoder', 'Sparse Coding']\n", "\n", + "num_interesting_vals = [\n", + " np.array([analyzer.nat_selectivity['num_interesting_img_nl'],\n", + " analyzer.nat_selectivity['num_interesting_img_l']])\n", + " for analyzer in analyzer_list]\n", + "\n", + "num_interesting_medians = np.stack(\n", + " [np.array([analyzer.nat_selectivity['num_interesting_img_nl_mean'],\n", + " analyzer.nat_selectivity['num_interesting_img_l_mean']])\n", + " for analyzer in analyzer_list], axis=0)\n", + "\n", "num_interesting_means = np.stack(\n", " [np.array([analyzer.nat_selectivity['num_interesting_img_nl_mean'],\n", " analyzer.nat_selectivity['num_interesting_img_l_mean']])\n", @@ -1010,10 +1021,11 @@ " 'fontsize':scale*matplotlib.rcParams['font.size'],\n", " 'fontfamily':scale*matplotlib.rcParams['font.family'],\n", " 'legend.fontsize': scale*matplotlib.rcParams['font.size'],\n", - " 'text.labelsize': scale*matplotlib.rcParams['font.size']-2\n", + " 'text.labelsize': scale*matplotlib.rcParams['font.size']\n", "}\n", + "figsize = nc.set_size(text_width, fraction=1.00)\n", "with plot.rc.context(**rc_kwargs):\n", - " interesting_imgs_fig, axs = plot.subplots(array, sharey=False, aspect=3.0, width=0.4*text_width_cm)\n", + " interesting_imgs_fig, axs = plot.subplots(array, sharey=False, aspect=3.0, figsize=figsize)#, width=0.4*text_width_cm)\n", " ax = axs[0]\n", " obj = ax.bar(\n", " df,\n", @@ -1033,7 +1045,7 @@ " xlocator=1,\n", " xminorlocator=0.5,\n", " ytickminor=False,\n", - " ylim=[0, np.max(num_interesting_means)+np.max(num_interesting_stds)],\n", + " #ylim=[0, np.max(num_interesting_means)+np.max(num_interesting_stds)],\n", " #suptitle='Average number of intersting images'\n", " ylabel='Average number of\\nintersting images',\n", " xgrid=False\n", @@ -1062,7 +1074,8 @@ " ax.set_xticks(bin_left[::2], minor=False)\n", " ax.xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.0f'))\n", " ax.set_xticks([angle_min, angle_max//2, angle_max])\n", - " ax.set_ylim([0.0, max(hist_max_list)+0.1*max(hist_max_list)])\n", + " ax.set_ylim([0, max(hist_max_list)+0.1*max(hist_max_list)])\n", + " ax.set_yticks([0, max(hist_max_list)//2, int(max(hist_max_list))])\n", " ax.format(title=f'{oc}\\n')#, ygrid=False)\n", " #ax.grid(b=False, which='both', axis='both')\n", " axs[1].format(ylabel='Total number of\\ninteresting images')\n", @@ -1084,10 +1097,10 @@ "outputs": [], "source": [ "for analyzer in analyzer_list:\n", - " for ext in ['.png', '.pdf', '.eps']:\n", - " save_name = (analyzer.analysis_out_dir+'/vis/natural_img_selectivity_'\n", + " for ext in file_extensions:\n", + " save_name = (analyzer.analysis_out_dir+'/vis/natural_img_selectivity_bar_'\n", " +analyzer.analysis_params.save_info+ext)\n", - " interesting_imgs_fig.savefig(save_name, transparent=False, pad_inches=0.005, dpi=interesting_imgs_fig.dpi)" + " interesting_imgs_fig.savefig(save_name, transparent=False, pad_inches=0.005, dpi=dpi)" ] }, { @@ -1480,7 +1493,7 @@ " COLORS, inner_group_names, outer_group_names, titles, text_width, width_ratio=1.0, dpi=dpi)\n", "\n", "#for analyzer in analyzer_list:\n", - "for ext in [\".png\", \".eps\"]:\n", + "for ext in file_extensions:\n", " save_name = (output_dir+'/adv_mse_comparison_boxplots'+ext)\n", " adv_fig.savefig(save_name, transparent=False, bbox_inches='tight', pad_inches=0.05, dpi=dpi)" ] @@ -1785,7 +1798,7 @@ "outputs": [], "source": [ "#for analyzer in analyzer_list:\n", - "for ext in [\".png\", \".eps\"]:\n", + "for ext in file_extensions:\n", " save_name = (output_dir+'/adv_mse_comparison_example_images'+ext)\n", " adv_img_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.05, dpi=dpi)" ] @@ -1839,9 +1852,8 @@ "\n", "conf_fig = plot_average_conf_step(files, names)\n", "#for analyzer in analyzer_list:\n", - "for ext in [\".png\", \".eps\"]:\n", - " save_name = (output_dir+'adv_mse_comparison_example_images'\n", - " +\"_\"+analyzer.analysis_params.save_info+ext)\n", + "for ext in file_extensions:\n", + " save_name = (output_dir+'adv_mse_comparison_example_images'+ext)\n", " conf_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)" ] }, @@ -2082,66 +2094,75 @@ "metadata": {}, "outputs": [], "source": [ - "def plot_average_mse_step(analysis_files, colors, title, model_names, bar_width, hatches, figsize, dpi):\n", + "def plot_average_mse_step(analysis_files, recons, confs, colors, title, model_names, bar_width, hatches, figsize, dpi):\n", " fig = plt.figure(figsize=figsize, dpi=dpi)\n", - " gs0 = gridspec.GridSpec(1, 2, wspace=0.2, width_ratios = [2, 1])#, hspace=0.3)\n", - " left_gs = gridspec.GridSpecFromSubplotSpec(1, 2, gs0[0], wspace=1.3)\n", - " right_gs = gridspec.GridSpecFromSubplotSpec(1, 1, gs0[1], wspace=0.9)\n", - " group_data = []\n", - " group_means = []\n", - " handles = []\n", + " num_conditions = len(analysis_files)\n", + " gs_top = gridspec.GridSpec(num_conditions, num_conditions)\n", " axes = []\n", - " for x_ax_idx, key in enumerate(['input_adv_mses', 'adversarial_outputs']):\n", - " axes.append(fig.add_subplot(left_gs[x_ax_idx]))\n", - " for file_idx, (file, name) in enumerate(zip(analysis_files, model_names)):\n", - " analysis = np.load(file, allow_pickle=True)[\"data\"].item()\n", - " adv_conf = 100*np.max(np.squeeze(analysis['adversarial_outputs']), axis=-1)\n", - " if x_ax_idx == 0:\n", - " axes[-1].set_ylabel('Adversarial\\nConfidence')\n", - " axes[-1].axhline(90.0, color='black', linestyle='dashed', linewidth=1) \n", - " axes[-1].set_ylim([0, 100.1])\n", - " mean_vals = np.mean(adv_conf, axis=-1)[1:]\n", - " std_vals = np.std(adv_conf, axis=-1)[1:]\n", - " else:\n", - " adv_mse = np.squeeze(analysis['input_adv_mses'])\n", - " axes[-1].set_ylabel('Adversarial Mean\\nSquared Distance')\n", - " axes[-1].yaxis.set_major_formatter(ticker.FormatStrFormatter('%.3f'))\n", - " thresh_indices = np.argwhere(np.mean(adv_conf, axis=-1)>90)\n", - " first_adv_cross = np.min(thresh_indices[thresh_indices>2]) # first couple are original label\n", - " axes[-1].axvline(first_adv_cross, color=colors[file_idx][0], linestyle='dashed', linewidth=1)\n", - " mean_vals = np.mean(adv_mse, axis=-1)[1:]\n", - " std_vals = np.std(adv_mse, axis=-1)[1:]\n", - " group_data.append(adv_mse[first_adv_cross, :])\n", - " group_means.append(mean_vals[first_adv_cross])\n", - " max_val = 0.020#np.max(mean_vals)+std_vals[np.argmax(mean_vals)]\n", - " axes[-1].set_ylim([0, max_val])\n", - " axes[-1].plot(range(len(mean_vals)), mean_vals, label=name,\n", - " lw=2, color=colors[file_idx][0], zorder=1)\n", - " axes[-1].fill_between(range(len(mean_vals)), mean_vals + std_vals , mean_vals - std_vals,\n", - " edgecolor=colors[file_idx][1], alpha=1.0, zorder=0, facecolor=\"none\", hatch=hatches[file_idx],\n", - " rasterized=False)\n", - " axes[-1].set_xlabel('Attack Step')\n", + " for condition, (condition_analysis_files, recon, conf) in enumerate(zip(analysis_files, recons, confs)):\n", + " #gs0 = gridspec.GridSpec(1, 2, wspace=0.2, width_ratios = [2, 1])#, hspace=0.3)\n", + " gs0 = gridspec.GridSpecFromSubplotSpec(1, 2, gs_top[condition, :],\n", + " wspace=0.2, width_ratios = [2, 1])#, hspace=0.3)\n", + " left_gs = gridspec.GridSpecFromSubplotSpec(1, 2, gs0[0], wspace=1.3)\n", + " right_gs = gridspec.GridSpecFromSubplotSpec(1, 1, gs0[1], wspace=0.9)\n", + " group_data = []\n", + " group_means = []\n", + " handles = []\n", + " for x_ax_idx, key in enumerate(['input_adv_mses', 'adversarial_outputs']):\n", + " axes.append(fig.add_subplot(left_gs[x_ax_idx]))\n", + " for file_idx, (file, name) in enumerate(zip(condition_analysis_files, model_names)):\n", + " analysis = np.load(file, allow_pickle=True)[\"data\"].item()\n", + " adv_conf = 100*np.max(np.squeeze(analysis['adversarial_outputs']), axis=-1)\n", + " if x_ax_idx == 0:\n", + " if condition == 0:\n", + " axes[-1].set_ylabel('Adversarial\\nConfidence')\n", + " axes[-1].axhline(90.0, color='black', linestyle='dashed', linewidth=1) \n", + " axes[-1].set_ylim([0, 100.1])\n", + " mean_vals = np.mean(adv_conf, axis=-1)[1:]\n", + " std_vals = np.std(adv_conf, axis=-1)[1:]\n", + " else:\n", + " if condition == 0:\n", + " axes[-1].set_ylabel('Adversarial Mean\\nSquared Distance')\n", + " adv_mse = np.squeeze(analysis['input_adv_mses'])\n", + " axes[-1].yaxis.set_major_formatter(ticker.FormatStrFormatter('%.3f'))\n", + " thresh_indices = np.argwhere(np.mean(adv_conf, axis=-1)>90)\n", + " first_adv_cross = np.min(thresh_indices[thresh_indices>2]) # first couple are original label\n", + " axes[-1].axvline(first_adv_cross, color=colors[file_idx][0], linestyle='dashed', linewidth=1)\n", + " mean_vals = np.mean(adv_mse, axis=-1)[1:]\n", + " std_vals = np.std(adv_mse, axis=-1)[1:]\n", + " group_data.append(adv_mse[first_adv_cross, :])\n", + " group_means.append(mean_vals[first_adv_cross])\n", + " max_val = 0.03#np.max(mean_vals)+std_vals[np.argmax(mean_vals)]\n", + " axes[-1].set_ylim([0, max_val])\n", + " axes[-1].plot(range(len(mean_vals)), mean_vals, label=name,\n", + " lw=2, color=colors[file_idx][0], zorder=1)\n", + " axes[-1].fill_between(range(len(mean_vals)), mean_vals + std_vals , mean_vals - std_vals,\n", + " edgecolor=colors[file_idx][1], alpha=1.0, zorder=0, facecolor=\"none\",\n", + " hatch=hatches[file_idx], rasterized=False)\n", + " if condition == num_conditions-1:\n", + " axes[-1].set_xlabel('Attack Step')\n", + " axes[-1].grid(False)\n", + " axes.append(fig.add_subplot(right_gs[0]))\n", + " x_pos = np.arange(2) + 2 * bar_width\n", + " linewidth = 1\n", + " medianprops = dict(linestyle='--', linewidth=linewidth, color='k')\n", + " meanprops = dict(linestyle='-', linewidth=linewidth, color='k')\n", + " float_colors = [[52/255, 152/255, 219/255], [231/255, 76/255, 60/255]] # blue, red\n", + " axes[-1].set_title(f'c={recon}, '+r'$\\kappa$'+f'={conf}')\n", + " for data, means, pos, color, name in zip(group_data, group_means, x_pos, float_colors, model_names):\n", + " boxprops = dict(linestyle='-', linewidth=linewidth, color=color)\n", + " whiskerprops = boxprops\n", + " capprops = boxprops\n", + " handles.append(axes[-1].boxplot(data, sym='', positions=[pos],\n", + " whis=(5, 95), widths=bar_width, meanline=True, showmeans=True, boxprops=boxprops,\n", + " whiskerprops=whiskerprops, capprops=capprops, medianprops=medianprops,\n", + " meanprops=meanprops\n", + " ))\n", + " axes[-1].set_ylim([0, max_val])\n", + " axes[-1].set_yticklabels('')\n", + " axes[-1].get_xaxis().set_ticks([])\n", " axes[-1].grid(False)\n", - " axes.append(fig.add_subplot(right_gs[0]))\n", - " x_pos = np.arange(2) + 2 * bar_width\n", - " linewidth = 1\n", - " medianprops = dict(linestyle='--', linewidth=linewidth, color='k')\n", - " meanprops = dict(linestyle='-', linewidth=linewidth, color='k')\n", - " colors = [[52/255, 152/255, 219/255], [231/255, 76/255, 60/255]] # blue, red\n", - " for data, means, pos, color, name in zip(group_data, group_means, x_pos, colors, model_names):\n", - " boxprops = dict(linestyle='-', linewidth=linewidth, color=color)\n", - " whiskerprops = boxprops\n", - " capprops = boxprops\n", - " handles.append(axes[-1].boxplot(data, sym='', positions=[pos],\n", - " whis=(5, 95), widths=bar_width, meanline=True, showmeans=True, boxprops=boxprops,\n", - " whiskerprops=whiskerprops, capprops=capprops, medianprops=medianprops,\n", - " meanprops=meanprops\n", - " ))\n", - " axes[-1].set_ylim([0, max_val])\n", - " axes[-1].set_yticklabels('')\n", - " axes[-1].get_xaxis().set_ticks([])\n", - " axes[-1].grid(False)\n", - " axes[-1].text(pos, 0.002, name, horizontalalignment='center', verticalalignment='center')\n", + " axes[-1].text(pos, 0.0025, name, horizontalalignment='center', verticalalignment='center')\n", " fig.subplots_adjust(top=0.8)\n", " fig.suptitle(title, y=0.98)\n", " return fig, axes" @@ -2156,35 +2177,44 @@ "outputs": [], "source": [ "colors = [[color_vals['md_blue'], color_vals['lt_blue']], [color_vals['md_red'], color_vals['lt_red']]]\n", - "model_names = ['w/o\\nLCA', 'w/\\nLCA']\n", + "model_names = ['w/o LCA', 'w/ LCA']\n", "hatches = ['///', '\\\\\\\\\\\\']\n", "\n", - "k_file_path = analysis_dir+'savefiles/class_adversary_analysis_test_temp_kurakin_targeted.npz'\n", - "k_img_path = analysis_dir+'savefiles/class_adversary_images_analysis_test_temp_kurakin_targeted.npz'\n", - "k_mlp_files = [projects_dir + model_name + k_file_path for model_name in [mnist_mlp_768_2layer]]\n", - "k_lca_files = [projects_dir + model_name + k_file_path for model_name in [mnist_lca_768_2layer]]\n", - "k_files = k_mlp_files + k_lca_files\n", - "\n", - "run_number = '5'\n", - "c_file_path = (analysis_dir+'savefiles/class_adversary_analysis_test_temp'\n", - " +str(run_number)+'_carlini_targeted.npz')\n", - "c_img_path = (analysis_dir+'savefiles/class_adversary_images_analysis_test_temp'\n", - " +str(run_number)+'_carlini_targeted.npz')\n", - "c_mlp_files = [projects_dir + model_name + c_file_path for model_name in [mnist_mlp_768_2layer]]\n", - "c_lca_files = [projects_dir + model_name + c_file_path for model_name in [mnist_lca_768_2layer]]\n", - "c_files = c_mlp_files + c_lca_files\n", - "\n", - "figsize = nc.set_size(text_width, fraction=1.0, subplot=[2, 3])\n", "#carlini_title = 'Networks with an LCA layer require larger\\nperturbations for equal confidence with the Carlini attack'\n", "#carlini_title = 'Networks with an LCA layer are more robust than without'\n", - "carlini_title = ''#Networks with an LCA layer are more robust than without'\n", - "fig, ax = plot_average_mse_step(c_files, colors, carlini_title,\n", + "carlini_title = ''\n", + "\n", + "all_recons = []\n", + "all_confs = []\n", + "all_files = []\n", + "for recon in ['0.5', '1.0']:\n", + " for conf in ['0.0', '10.0']:\n", + " if conf == '10.0':\n", + " extra_str = '_'\n", + " temp = '1.00'\n", + " else:\n", + " extra_str = ''\n", + " temp = '1.0'\n", + " c_file_path = (f'{analysis_dir}savefiles/class_adversary_analysis_test'+\n", + " f'{extra_str}temp{temp}_conf{conf}_recon{recon}_carlini_targeted.npz')\n", + " c_mlp_files = [projects_dir + model_name + c_file_path for model_name in [mnist_mlp_768_2layer]]\n", + " temp = '0.65'\n", + " c_file_path = (f'{analysis_dir}savefiles/class_adversary_analysis_test'+\n", + " f'{extra_str}temp{temp}_conf{conf}_recon{recon}_carlini_targeted.npz')\n", + " c_lca_files = [projects_dir + model_name + c_file_path for model_name in [mnist_lca_768_2layer]]\n", + " c_files = c_mlp_files + c_lca_files\n", + " all_recons.append(recon)\n", + " all_confs.append(conf)\n", + " all_files.append(c_files)\n", + "\n", + "figsize = nc.set_size(text_width, fraction=1.0, subplot=[2*2, 3])\n", + "fig, ax = plot_average_mse_step(all_files, all_recons, all_confs, colors, carlini_title,\n", " model_names, bar_width, hatches, figsize, dpi)\n", "\n", - "out_list = [projects_dir + model_name + '/analysis/0.0/vis/kurakin_carlini_mse_vs_iteration_temp' + str(run_number)\n", + "out_list = [projects_dir + model_name + '/analysis/0.0/vis/carlini_mse_vs_iteration_k0.0-10.0_conditions'\n", " for model_name in [mnist_lca_768_2layer, mnist_mlp_768_2layer]]\n", "for out_name in out_list:\n", - " for ext in [\".png\", \".eps\"]:\n", + " for ext in file_extensions:\n", " save_name = out_name+ext\n", " fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)" ] @@ -2334,7 +2364,7 @@ "\n", "out_list += [path + lista + \"/analysis/0.0/vis/lista_adv_transferability\"]\n", "for out_name in out_list:\n", - " for ext in [\".png\", \".eps\"]:\n", + " for ext in file_extensions:\n", " save_name = out_name+ext\n", " fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)" ] @@ -2367,7 +2397,7 @@ " image_idx = category_idx + start_idx\n", " orig_ax = ax_orig_list[category_idx]\n", " if category_idx == 0:\n", - " orig_ax.set_title(\"Unperturbed\", y=orig_y_adj, fontsize=fontsize)\n", + " orig_ax.set_title(\"Unperturbed\", y=orig_y_adj)#, fontsize=fontsize)\n", " orig_img = crop(np.squeeze(mlp_grp[0][0][image_idx, ...]), crop_ammount)\n", " orig_im_handle = show_image_with_label(orig_ax, orig_img, orig_labels[0][0][image_idx], cmap=cmap)\n", " for model_idx, gs_sub in enumerate([gs_sub0_list[category_idx], gs_sub1_list[category_idx]]):\n", @@ -2388,27 +2418,27 @@ " vmax = 1.0\n", " if j == 0: # top left image\n", " if model_idx == 0:\n", - " current_ax.set_ylabel(r\"$s^{*}_{T}$\", fontsize=fontsize)\n", + " current_ax.set_ylabel(r\"$s^{*}_{T}$\")#, fontsize=fontsize)\n", " current_target_label = target_labels[j][model_idx][image_idx]\n", " if category_idx == 0: # top category only\n", " x_loc = group_name_loc[0]\n", " y_loc = group_name_loc[1]\n", - " text_handle = current_ax.text(x_loc, y_loc, group_names[j+model_idx], fontsize=fontsize,\n", + " text_handle = current_ax.text(x_loc, y_loc, group_names[j+model_idx],#, fontsize=fontsize,\n", " horizontalalignment='left', verticalalignment='bottom')\n", " else: # i == 1\n", " vmin = np.round(diff_vmin, 2)\n", " vmax = np.round(diff_vmax, 2)\n", " if j == 0 and model_idx == 0:\n", - " current_ax.set_ylabel(r\"$s-s^{*}_{T}$\", fontsize=fontsize)\n", + " current_ax.set_ylabel(r\"$s-s^{*}_{T}$\")#, fontsize=fontsize)\n", " if j == 0 and category_idx == num_categories-1: # bottom left\n", - " current_ax.set_xlabel(\"w/o\\nLCA\", fontsize=fontsize)\n", + " current_ax.set_xlabel(\"w/o\\nLCA\")#, fontsize=fontsize)\n", " elif j == 1 and category_idx == num_categories-1: # bottom right\n", - " current_ax.set_xlabel(\"w/\\nLCA\", fontsize=fontsize)\n", + " current_ax.set_xlabel(\"w/\\nLCA\")#, fontsize=fontsize)\n", " im_handle = show_image_with_label(current_ax, current_image, current_target_label, vmin=vmin, vmax=vmax, cmap=cmap)\n", " if j == 1:\n", - " pf.add_colorbar_to_im(im_handle, aspect=10, ax=current_ax, ticks=[vmin, vmax], labelsize=fontsize/2)\n", + " pf.add_colorbar_to_im(im_handle, aspect=10, ax=current_ax, ticks=[vmin, vmax])#, labelsize=fontsize/2)\n", "\n", - "def plot_adv_images_with_figsize(image_groups, labels, mnist_start_idx, cifar_start_idx, figsize, fontsize, dpi=100):\n", + "def plot_adv_images_with_figsize(image_groups, labels, mnist_start_idx, cifar_start_idx, figsize, dpi=100):\n", " mnist_mlp_grp, mnist_lca_grp = image_groups[0]\n", " cifar_mlp_grp, cifar_lca_grp = image_groups[1]\n", " mnist_orig_labels, mnist_target_labels, mnist_img_labels = labels[0]\n", @@ -2420,22 +2450,23 @@ " sub_wspace = 0.2\n", " orig_y_adj = 1.10\n", " img_label_loc = [-8.0, -8.0] # [x, y]\n", - " fig2 = plt.figure(figsize=[figsize[0]/2, figsize[1]], dpi=dpi)\n", - " gs0 = plt.GridSpec(2, 1, figure=fig2, hspace=0.3)\n", + " fig = plt.figure(figsize=[figsize[0]/2, figsize[1]], dpi=dpi)\n", + " gs0 = plt.GridSpec(2, 1, figure=fig, hspace=0.3)\n", " \n", " num_categories=3\n", " \n", " gs_mnist = gridspec.GridSpecFromSubplotSpec(num_categories, 6, gs0[0], hspace=hspace, wspace=wspace)\n", - " make_grid_subplots_with_fontsize(fig2, gs_mnist, mnist_mlp_grp, mnist_lca_grp, mnist_orig_labels,\n", + " make_grid_subplots_with_fontsize(fig, gs_mnist, mnist_mlp_grp, mnist_lca_grp, mnist_orig_labels,\n", " mnist_target_labels, mnist_img_labels, img_label_loc, orig_y_adj, mnist_start_idx, num_categories,\n", - " hspace=sub_hspace, wspace=sub_wspace, cmap=\"Greys\", fontsize=fontsize)\n", + " hspace=sub_hspace, wspace=sub_wspace, cmap=\"Greys\")\n", " \n", " gs_cifar = gridspec.GridSpecFromSubplotSpec(num_categories, 6, gs0[1], hspace=hspace, wspace=wspace)\n", - " make_grid_subplots_with_fontsize(fig2, gs_cifar, cifar_mlp_grp, cifar_lca_grp, cifar_orig_labels,\n", + " make_grid_subplots_with_fontsize(fig, gs_cifar, cifar_mlp_grp, cifar_lca_grp, cifar_orig_labels,\n", " cifar_target_labels, cifar_img_labels, img_label_loc, orig_y_adj, 0, num_categories,\n", - " hspace=sub_hspace, wspace=sub_wspace, cmap=\"Greys_r\", fontsize=fontsize)\n", + " hspace=sub_hspace, wspace=sub_wspace, cmap=\"Greys_r\")\n", " \n", - " plt.show()" + " plt.show()\n", + " return fig " ] }, { @@ -2446,9 +2477,17 @@ }, "outputs": [], "source": [ - "full_adv_img_fig = plot_adv_images_with_figsize(image_groups, labels, mnist_start_idx=44, cifar_start_idx=0,\n", - " figsize=(16, 16), fontsize=20, dpi=dpi)" + "figsize = nc.set_size(text_width, fraction=1.0, subplot=[16, 16])\n", + "full_adv_img_fig = plot_adv_images_with_figsize(image_groups, label_groups, mnist_start_idx=44, cifar_start_idx=0,\n", + " figsize=figsize, dpi=dpi)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From 352eeaacd42a0c04b8444111edc4647f908faaa8 Mon Sep 17 00:00:00 2001 From: Dylan Date: Fri, 15 Jan 2021 12:32:03 +0100 Subject: [PATCH 09/44] adds scaffolding for SMT model --- models/smt_model.py | 65 +++++++++++++++++++++++++++++++++++++++++++ modules/losses.py | 8 ++++++ modules/mlp_module.py | 19 +++++++++++-- 3 files changed, 89 insertions(+), 3 deletions(-) create mode 100644 models/smt_model.py diff --git a/models/smt_model.py b/models/smt_model.py new file mode 100644 index 00000000..2f027ffd --- /dev/null +++ b/models/smt_model.py @@ -0,0 +1,65 @@ +import torch + +import DeepSparseCoding.utils.loaders as loaders +from DeepSparseCoding.models.base_model import BaseModel +from DeepSparseCoding.modules.ensemble_module import EnsembleModule + + +class SmtModel(BaseModel, EnsembleModule): + def setup(self, params, logger=None): + """ + Setup required model components + """ + super(SmtModel, self).setup(params, logger) + self.setup_module(params) + self.setup_optimizer() + + def setup_module(self, params): + for subparams in params.ensemble_params: + subparams.epoch_size = params.epoch_size + subparams.batches_per_epoch = params.batches_per_epoch + subparams.num_batches = params.num_batches + #subparams.num_val_images = params.num_val_images + #subparams.num_test_images = params.num_test_images + subparams.data_shape = params.data_shape + super(SmtModel, self).setup_ensemble_module(params) + self.submodel_classes = [] + for submodel_params in self.params.ensemble_params: + self.submodel_classes.append(loaders.load_model_class(submodel_params.model_type)) + + def setup_optimizer(self): + for module in self: + module.optimizer = self.get_optimizer( + optimizer_params=module.params, + trainable_variables=module.parameters()) + module.scheduler = torch.optim.lr_scheduler.MultiStepLR( + module.optimizer, + milestones=module.params.optimizer.milestones, + gamma=module.params.optimizer.lr_decay_rate) + + def preprocess_data(self, data): + """ + We assume that only the first submodel will be preprocessing the input data + """ + submodule = self.__getitem__(0) + return self.submodel_classes[0].preprocess_data(submodule, data) + + def get_total_loss(self, input_tuple, ensemble_index): + submodule = self.__getitem__(ensemble_index) + submodel_class = self.submodel_classes[ensemble_index] + return submodel_class.get_total_loss(submodule, input_tuple) + + def generate_update_dict(self, input_data, input_labels=None, batch_step=0): + update_dict = super(SmtModel, self).generate_update_dict(input_data, + input_labels, batch_step) + x = input_data.clone() # TODO: Do I need to clone it? If not then don't. + for ensemble_index, submodel_class in enumerate(self.submodel_classes): + submodule = self.__getitem__(ensemble_index) + submodel_update_dict = submodel_class.generate_update_dict(submodule, x, + input_labels, batch_step, update_dict=dict()) + for key, value in submodel_update_dict.items(): + if key not in ['epoch', 'batch_step']: + key = submodule.params.model_type+'_'+key + update_dict[key] = value + x = submodule.get_encodings(x) + return update_dict diff --git a/modules/losses.py b/modules/losses.py index d9ea61b7..e73e9cea 100644 --- a/modules/losses.py +++ b/modules/losses.py @@ -3,6 +3,14 @@ import DeepSparseCoding.utils.data_processing as dp +#def l2_flatness(z1, z2, z3, w): +# """ +# Minimized when a straight line can be drawn through [z1, z2, z3]. +# Extended from equations 8 and 12 in +# Chen, Paiton, Olshausen (2018) - The Sparse Manifold Transform +# """ +# z_mat = + def half_squared_l2(x1, x2): """ Computes the standard reconstruction loss. It will average over batch dimensions. diff --git a/modules/mlp_module.py b/modules/mlp_module.py index 4877d8eb..c99967b5 100644 --- a/modules/mlp_module.py +++ b/modules/mlp_module.py @@ -17,11 +17,24 @@ def setup_module(self, params): in_features = self.params.layer_channels[layer_index], out_features = self.params.layer_channels[layer_index+1], bias = True) - self.register_parameter('fc'+str(layer_index)+'_w', layer.weight) - self.register_parameter('fc'+str(layer_index)+'_b', layer.bias) - self.layers.append(layer) + elif layer_type == 'conv': + w_shape = [ + self.params.out_channels[layer_index], + self.params.in_channels[layer_index], + self.params.kernel_size[layer_index], + self.params.kernel_size[layer_index]] + layer = nn.Conv2d( + in_channels = self.params.in_channels[layer_index], + out_channels = self.params.out_channels[layer_index], + kernel_size = w_shape, + stride = self.parmas.stride[layer_index], + padding = self.params.padding[layer_index], + bias=True) else: assert False, ('layer_type parameter must be "fc", not %g'%(layer_type)) + self.register_parameter(layer_type+str(layer_index)+'_w', layer.weight) + self.register_parameter(layer_type+str(layer_index)+'_b', layer.bias) + self.layers.append(layer) self.dropout.append(nn.Dropout(p=self.params.dropout_rate[layer_index])) def preprocess_data(self, input_tensor): From 5d56472eb720e23ffd958e8cf4f44ce944ebabf5 Mon Sep 17 00:00:00 2001 From: Dylan Date: Fri, 15 Jan 2021 12:33:14 +0100 Subject: [PATCH 10/44] final updates for JOV paper figures --- tf1x/analysis/iso_response_analysis.py | 19 ++++++++------ tf1x/utils/data_processing.py | 6 ++--- tf1x/vis/JOV_figs.ipynb | 36 +++++++++++++++++++++----- 3 files changed, 44 insertions(+), 17 deletions(-) diff --git a/tf1x/analysis/iso_response_analysis.py b/tf1x/analysis/iso_response_analysis.py index 754dd044..21cbd701 100644 --- a/tf1x/analysis/iso_response_analysis.py +++ b/tf1x/analysis/iso_response_analysis.py @@ -219,29 +219,32 @@ def __init__(self): cont_analysis['min_angle'] = 15 cont_analysis['batch_size'] = 100 cont_analysis['vh_image_scale'] = 31.773287 # Mean of the l2 norm of the training set - cont_analysis['comparison_method'] = 'closest' # rand or closest - + cont_analysis['comparison_method'] = 'rand' # rand or closest + cont_analysis['measure_upper_right'] = False + cont_analysis['bounds'] = ((-1, 1), (-1, 1)) + cont_analysis['target_act'] = 0.5 cont_analysis['num_neurons'] = 100 # How many neurons to plot cont_analysis['num_comparisons'] = 300 # How many planes to construct (None is all of them) cont_analysis['x_range'] = [-2.0, 2.0] cont_analysis['y_range'] = [-2.0, 2.0] cont_analysis['num_images'] = int(30**2) - cont_analysis['params_list'] = [lca_512_vh_params()] + cont_analysis['params_list'] = [lca_1024_vh_params(), lca_2560_vh_params()] #cont_analysis['params_list'] = [lca_768_vh_params()] #cont_analysis['params_list'] = [lca_1024_vh_params()] #cont_analysis['params_list'] = [lca_2560_vh_params()] #cont_analysis['iso_save_name'] = "iso_curvature_xrange1.3_yrange-2.2_" #cont_analysis['iso_save_name'] = "iso_curvature_ryan_" - cont_analysis['iso_save_name'] = "rescaled_closecomp_" + cont_analysis['iso_save_name'] = "newfits_rescaled_randomcomp_" #cont_analysis['iso_save_name'] = '' - np.savez(save_root+'iso_params_'+cont_analysis['iso_save_name']+params.save_info+".npz", - data=cont_analysis) analyzer_list = [load_analyzer(params) for params in cont_analysis['params_list']] for analyzer, params in zip(analyzer_list, cont_analysis['params_list']): + save_root=analyzer.analysis_out_dir+'savefiles/' + np.savez(save_root+'iso_params_'+cont_analysis['iso_save_name']+params.save_info+".npz", + data=cont_analysis) print(analyzer.analysis_params.display_name) print("Computing the iso-response vectors...") cont_analysis['target_neuron_ids'] = iso_data.get_rand_target_neuron_ids( @@ -297,7 +300,6 @@ def __init__(self): datapoints, get_dsc_activations_cell, activation_function_kwargs) - save_root=analyzer.analysis_out_dir+'savefiles/' if use_rand_orth_vects: np.savez(save_root+'iso_rand_activations_'+cont_analysis['iso_save_name']+params.save_info+'.npz', data=activations) @@ -310,10 +312,11 @@ def __init__(self): data=contour_dataset) cont_analysis['comparison_neuron_ids'] = analyzer.comparison_neuron_ids cont_analysis['contour_dataset'] = contour_dataset + cont_analysis['activations'] = activations curvatures, fits = hist_funcs.iso_response_curvature_poly_fits( cont_analysis['activations'], target_act=cont_analysis['target_act'], - measure_upper_right=False + bounds=cont_analysis['bounds'] ) cont_analysis['curvatures'] = np.stack(np.stack(curvatures, axis=0), axis=0) np.savez(save_root+'group_iso_vectors_'+cont_analysis['iso_save_name']+params.save_info+'.npz', diff --git a/tf1x/utils/data_processing.py b/tf1x/utils/data_processing.py index 6118a0bf..999a8dcf 100644 --- a/tf1x/utils/data_processing.py +++ b/tf1x/utils/data_processing.py @@ -284,7 +284,7 @@ def generate_grating(patch_edge_size, location, diameter, orientation, frequency """ vals = np.linspace(-np.pi, np.pi, patch_edge_size) X, Y = np.meshgrid(vals, vals) - Xr = np.cos(orientation)*X + -np.sin(orientation)*Y # countercloclwise + Xr = np.cos(orientation)*X + -np.sin(orientation)*Y # counterclockwise Yr = np.sin(orientation)*X + np.cos(orientation)*Y stim = contrast*np.sin(Yr*frequency+phase) if diameter > 0: # Generate mask @@ -958,13 +958,13 @@ def pca_reduction(data, num_pcs=-1): data_mean = data.mean(axis=(1))[:,None] data -= data_mean Cov = np.cov(data.T) # Covariace matrix - U, S, V = np.linalg.svd(Cov) # SVD decomposition + U, S, VT = np.linalg.svd(Cov) # SVD decomposition diagS = np.diag(S) if num_pcs <= 0: n = num_rows else: n = num_pcs - data_reduc = np.dot(data, np.dot(np.dot(U[:, :n], diagS[:n, :n]), V[:n, :])) + data_reduc = np.dot(data, np.dot(np.dot(U[:, :n], diagS[:n, :n]), VT[:n, :])) return data_reduc def compute_power_spectrum(data): diff --git a/tf1x/vis/JOV_figs.ipynb b/tf1x/vis/JOV_figs.ipynb index 10c8bb02..76111b37 100644 --- a/tf1x/vis/JOV_figs.ipynb +++ b/tf1x/vis/JOV_figs.ipynb @@ -337,6 +337,28 @@ " analyzer = add_analyzer_keys(analyzer)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target_act = 0.5 # target activity spot between min & max value of normalized activity (btwn 0 and 1)\n", + "lca_activations = analyzer_list[-1].comp_activations\n", + "curvatures, fits, contours = ha.iso_response_curvature_poly_fits(\n", + " lca_activations,\n", + " target_act=target_act\n", + ")\n", + "max_comp_indices = []\n", + "max_vals = []\n", + "for target_neuron_id in range(len(curvatures)):\n", + " max_idx = np.argmax(curvatures[target_neuron_id])\n", + " max_comp_indices.append(max_idx)\n", + " max_vals.append(curvatures[target_neuron_id][max_idx])\n", + "max_target_id = np.argmax(max_vals)\n", + "max_comparison_id = max_comp_indices[max_target_id]" + ] + }, { "cell_type": "code", "execution_count": null, @@ -360,22 +382,24 @@ "min_comparison_id = target_min_idx[min_target_id]\n", "\n", "# 8(.039), 17(.028) 23(.037), 25(.036), 41(.033), 48(.035), 49(0.039)\n", - "neuron_indices = [0, 0, 0, min_target_id]\n", - "orth_indices = [0, 0, 0, min_comparison_id]\n", - "target_act = 0.5 # target activity spot between min & max value of normalized activity (btwn 0 and 1)\n", + "neuron_indices = [0, 0, 0, max_target_id]#min_target_id]\n", + "orth_indices = [0, 0, 0, max_comparison_id]#min_comparison_id]\n", "num_plots_y = 2\n", "num_plots_x = 2\n", "width_fraction = 1.0\n", "show_contours = True\n", "\n", "lca_activations = analyzer_list[-1].comp_activations[neuron_indices[-1], orth_indices[-1], ...][None, None, ...]\n", - "curvatures, fits = ha.iso_response_curvature_poly_fits(\n", + "curvatures, fits, contours = ha.iso_response_curvature_poly_fits(\n", " lca_activations,\n", - " target_act=target_act,\n", - " measure_upper_right=False\n", + " target_act=target_act\n", ")\n", "curvature = [None, None, None, curvatures[0][0]]\n", "\n", + "#for analyzer in analyzer_list:\n", + "# analyzer.comp_activations = analyzer.comp_activations - analyzer.comp_activations.min()\n", + "# analyzer.comp_activations = analyzer.comp_activations / analyzer.comp_activations.max()\n", + "\n", "contour_fig, contour_handles = nc.plot_group_iso_contours(analyzer_list, neuron_indices, orth_indices,\n", " num_levels, x_range, y_range, show_contours, curvature, text_width, width_fraction, dpi)\n", "\n", From 5ce4c9baa6cfeeb0f0c90a4fea43876868c2418f Mon Sep 17 00:00:00 2001 From: Dylan Date: Thu, 28 Jan 2021 16:12:56 +0100 Subject: [PATCH 11/44] updates relative imports removes incomplete aversarial_analysis script all pytorch tests pass --- adversarial_analysis.py | 121 -------------------- datasets/synthetic.py | 7 +- params/test_params.py | 7 +- tests/test_data_processing.py | 9 +- tests/test_datasets.py | 7 +- tests/test_foolbox.py | 9 +- tests/test_models.py | 7 +- tests/test_param_loading.py | 3 +- tf1x/analyze_model.py | 7 +- tf1x/tests/analysis/atas_test.py | 3 +- tf1x/tests/data/data_selector_test.py | 3 +- tf1x/tests/models/build_test.py | 3 +- tf1x/tests/models/comb_test.py | 3 +- tf1x/tests/models/run_test.py | 3 +- tf1x/tests/utils/checkpoint_test.py | 3 +- tf1x/tests/utils/contrast_normalize_test.py | 3 +- tf1x/tests/utils/patches_test.py | 3 +- tf1x/tests/utils/reshape_data_test.py | 3 +- tf1x/tests/utils/standardize_data_test.py | 3 +- tf1x/train_model.py | 7 +- tf1x/vis/tsne_analysis.py | 7 +- tf1x/vis/vis_class_adversarial.py | 8 +- tf1x/vis/vis_conv_lca.py | 8 +- tf1x/vis/vis_corrupt.py | 7 +- tf1x/vis/vis_recon_adversarial.py | 7 +- train_model.py | 3 +- utils/dataset_utils.py | 7 +- 27 files changed, 83 insertions(+), 178 deletions(-) delete mode 100644 adversarial_analysis.py diff --git a/adversarial_analysis.py b/adversarial_analysis.py deleted file mode 100644 index 716d6d5f..00000000 --- a/adversarial_analysis.py +++ /dev/null @@ -1,121 +0,0 @@ -import os -import sys - -ROOT_DIR = os.path.dirname(os.getcwd()) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - -import numpy as np -import proplot as plot -import torch - -from DeepSparseCoding.utils.file_utils import Logger -import DeepSparseCoding.utils.loaders as loaders -import DeepSparseCoding.utils.run_utils as run_utils -import DeepSparseCoding.utils.dataset_utils as dataset_utils -import DeepSparseCoding.utils.run_utils as ru -import DeepSparseCoding.utils.plot_functions as pf - -import eagerpy as ep -from foolbox import PyTorchModel, accuracy, samples -import foolbox.attacks as fa - - -log_files = [ - os.path.join(*[ROOT_DIR, 'Torch_projects', 'mlp_768_mnist', 'logfiles', 'mlp_768_mnist_v0.log']), - os.path.join(*[ROOT_DIR, 'Torch_projects', 'lca_768_mlp_mnist', 'logfiles', 'lca_768_mlp_mnist_v0.log']) - ] - -cp_latest_filenames = [ - os.path.join(*[ROOT_DIR,'Torch_projects', 'mlp_768_mnist', 'checkpoints', 'mlp_768_mnist_latest_checkpoint_v0.pt']), - os.path.join(*[ROOT_DIR, 'Torch_projects', 'lca_768_mlp_mnist', 'checkpoints', 'lca_768_mlp_mnist_latest_checkpoint_v0.pt']) - ] - -attack_params = { - 'linfPGD': { - 'abs_stepsize':0.01, - 'steps':5000 - } -} - -attacks = [ - #fa.FGSM(), - fa.LinfPGD(**attack_params['linfPGD']), - #fa.LinfBasicIterativeAttack(), - #fa.LinfAdditiveUniformNoiseAttack(), - #fa.LinfDeepFoolAttack(), -] - -epsilons = [ # allowed perturbation size - 0.0, - 0.05, - 0.1, - 0.15, - 0.2, - 0.25, - 0.3, - 0.35, - #0.4, - 0.5, - #0.8, - 1.0 -] - -num_models = len(log_files) -for model_index in range(num_models): - logger = Logger(log_files[model_index], overwrite=False) - log_text = logger.load_file() - params = logger.read_params(log_text)[-1] - params.cp_latest_filename = cp_latest_filenames[model_index] - train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params) - for key, value in data_params.items(): - setattr(params, key, value) - model = loaders.load_model(params.model_type) - model.setup(params, logger) - model.params.analysis_out_dir = os.path.join( - *[model.params.model_out_dir, 'analysis', model.params.version]) - model.params.analysis_save_dir = os.path.join(model.params.analysis_out_dir, 'savefiles') - if not os.path.exists(model.params.analysis_save_dir): - os.makedirs(model.params.analysis_save_dir) - model.to(params.device) - model.load_checkpoint() - fmodel = PyTorchModel(model.eval(), bounds=(0, 1)) - print('\n', '~' * 79) - num_batches = len(test_loader.dataset) // model.params.batch_size - attack_success = np.zeros( - (len(attacks), len(epsilons), num_batches, model.params.batch_size), dtype=np.bool) - for batch_index, (data, target) in enumerate(test_loader): - data = model.preprocess_data(data.to(model.params.device)) - target = target.to(model.params.device) - images, labels = ep.astensors(*(data, target)) - del data; del target - print(f'Model type: {model.params.model_type} [{model_index+1} out of {len(log_files)}]') - print(f'Batch {batch_index+1} out of {num_batches}') - print(f'accuracy {accuracy(fmodel, images, labels)}') - for attack_index, attack in enumerate(attacks): - advs, inputs, success = attack(fmodel, images, labels, epsilons=epsilons) - assert success.shape == (len(epsilons), len(images)) - success_ = success.numpy() - assert success_.dtype == np.bool - attack_success[attack_index, :, batch_index, :] = success_ - print('\n', attack) - print(' ', 1.0 - success_.mean(axis=-1).round(2)) - np.savez('tmp_perturbations.npz', data=advs[0].numpy()) - np.savez('tmp_images.npz', data=images.numpy()) - np.savez('tmp_inputs.npz', data=inputs[0].numpy()) - import IPython; IPython.embed(); raise SystemExit - robust_accuracy = 1.0 - attack_success[:, :, batch_index, :].max(axis=0).mean(axis=-1) - print('\n', '-' * 79, '\n') - print('worst case (best attack per-sample)') - print(' ', robust_accuracy.round(2)) - print('-' * 79) - attack_success = attack_success.reshape( - (len(attacks), len(epsilons), num_batches*model.params.batch_size)) - attack_types = [str(type(attack)).split('.')[-1][:-2] for attack in attacks] - output_filename = os.path.join(model.params.analysis_save_dir, - f'linf_adversarial_analysis.npz') - out_dict = { - 'adversarial_analysis':attack_success, - 'attack_types':attack_types, - 'epsilons':epsilons, - 'attack_params':attack_params} - np.savez(output_filename, data=out_dict) diff --git a/datasets/synthetic.py b/datasets/synthetic.py index db82e635..3af48bc6 100644 --- a/datasets/synthetic.py +++ b/datasets/synthetic.py @@ -1,5 +1,9 @@ import os import sys +from os.path import dirname as up + +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np from scipy.stats import norm @@ -7,9 +11,6 @@ import torch import torchvision -ROOT_DIR = os.path.dirname(os.getcwd()) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.utils.data_processing as dp class SyntheticImages(torchvision.datasets.vision.VisionDataset): diff --git a/params/test_params.py b/params/test_params.py index 48c7f260..a9ea2da2 100644 --- a/params/test_params.py +++ b/params/test_params.py @@ -1,13 +1,14 @@ import os import sys import types +from os.path import dirname as up + +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np import torch -ROOT_DIR = os.path.dirname(os.getcwd()) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - from DeepSparseCoding.params.base_params import BaseParams from DeepSparseCoding.params.lca_mnist_params import params as LcaParams from DeepSparseCoding.params.mlp_mnist_params import params as MlpParams diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py index cec59546..004ee80e 100644 --- a/tests/test_data_processing.py +++ b/tests/test_data_processing.py @@ -1,15 +1,14 @@ import os import sys import unittest +from os.path import dirname as up +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import torch import numpy as np - -ROOT_DIR = os.path.dirname(os.getcwd()) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.utils.data_processing as dp @@ -249,7 +248,7 @@ def test_atleastkd(self): def test_l2_weight_norm(self): w_fc = np.random.standard_normal([24, 38]) w_conv = np.random.standard_normal([38, 24, 8, 8]) - for w in [w_fc, w_conv, 0*w_fc, 0*w_conv]: + for w in [w_fc, w_conv]: w_norm = dp.get_weights_l2_norm(torch.tensor(w), eps=1e-12).numpy() normed_w = dp.l2_normalize_weights(torch.tensor(w), eps=1e-12).numpy() normed_w_norm = dp.get_weights_l2_norm(torch.tensor(normed_w), eps=1e-12).numpy() diff --git a/tests/test_datasets.py b/tests/test_datasets.py index d1c0cde4..85463134 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -2,13 +2,14 @@ import sys import unittest import types +from os.path import dirname as up + +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np from torchvision import datasets -ROOT_DIR = os.path.dirname(os.getcwd()) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.utils.dataset_utils as dataset_utils class TestDatasets(unittest.TestCase): diff --git a/tests/test_foolbox.py b/tests/test_foolbox.py index 02dc807d..b7f9cc53 100644 --- a/tests/test_foolbox.py +++ b/tests/test_foolbox.py @@ -1,15 +1,16 @@ import os import sys import unittest +from os.path import dirname as up + +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) #import numpy as np import eagerpy as ep from foolbox import PyTorchModel, accuracy, samples import foolbox.attacks as fa -ROOT_DIR = os.path.dirname(os.getcwd()) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.utils.loaders as loaders #import DeepSparseCoding.utils.dataset_utils as datasets #import DeepSparseCoding.utils.run_utils as run_utils @@ -42,4 +43,4 @@ # fmodel = PyTorchModel(model.eval(), bounds=(0, 1)) # model_output = fmodel.forward() # adv_model_outputs, adv_images, success = attack(fmodel, train_data_batch, train_target_batch, epsilons=epsilons) -# \ No newline at end of file +# diff --git a/tests/test_models.py b/tests/test_models.py index 1ff329da..426a903c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,12 +1,13 @@ import os import sys import unittest +from os.path import dirname as up -import numpy as np - -ROOT_DIR = os.path.dirname(os.getcwd()) +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) +import numpy as np + import DeepSparseCoding.utils.loaders as loaders import DeepSparseCoding.utils.dataset_utils as datasets import DeepSparseCoding.utils.run_utils as run_utils diff --git a/tests/test_param_loading.py b/tests/test_param_loading.py index 721b6a59..26377ad2 100644 --- a/tests/test_param_loading.py +++ b/tests/test_param_loading.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.getcwd()) +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import DeepSparseCoding.utils.loaders as loaders diff --git a/tf1x/analyze_model.py b/tf1x/analyze_model.py index c8db1e28..a30975aa 100644 --- a/tf1x/analyze_model.py +++ b/tf1x/analyze_model.py @@ -1,13 +1,14 @@ import os import sys import argparse +from os.path import dirname as up + +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np import tensorflow as tf -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - from DeepSparseCoding.tf1x.utils.logger import Logger import DeepSparseCoding.tf1x.utils.data_processing as dp import DeepSparseCoding.tf1x.data.data_selector as ds diff --git a/tf1x/tests/analysis/atas_test.py b/tf1x/tests/analysis/atas_test.py index af071a2b..061a12aa 100644 --- a/tf1x/tests/analysis/atas_test.py +++ b/tf1x/tests/analysis/atas_test.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/data/data_selector_test.py b/tf1x/tests/data/data_selector_test.py index 43864e26..76ec089b 100644 --- a/tf1x/tests/data/data_selector_test.py +++ b/tf1x/tests/data/data_selector_test.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/models/build_test.py b/tf1x/tests/models/build_test.py index c2acca84..70ffb516 100644 --- a/tf1x/tests/models/build_test.py +++ b/tf1x/tests/models/build_test.py @@ -1,8 +1,9 @@ import copy import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/models/comb_test.py b/tf1x/tests/models/comb_test.py index f54396f4..6d10e833 100644 --- a/tf1x/tests/models/comb_test.py +++ b/tf1x/tests/models/comb_test.py @@ -1,8 +1,9 @@ import copy import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/models/run_test.py b/tf1x/tests/models/run_test.py index 20f379be..7342d63a 100644 --- a/tf1x/tests/models/run_test.py +++ b/tf1x/tests/models/run_test.py @@ -1,8 +1,9 @@ import copy import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/utils/checkpoint_test.py b/tf1x/tests/utils/checkpoint_test.py index 4833d59e..bd7c5617 100644 --- a/tf1x/tests/utils/checkpoint_test.py +++ b/tf1x/tests/utils/checkpoint_test.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/utils/contrast_normalize_test.py b/tf1x/tests/utils/contrast_normalize_test.py index 3ec003fe..32712405 100644 --- a/tf1x/tests/utils/contrast_normalize_test.py +++ b/tf1x/tests/utils/contrast_normalize_test.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/utils/patches_test.py b/tf1x/tests/utils/patches_test.py index 4639a162..ad0e7c78 100644 --- a/tf1x/tests/utils/patches_test.py +++ b/tf1x/tests/utils/patches_test.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/utils/reshape_data_test.py b/tf1x/tests/utils/reshape_data_test.py index ff6922ad..18b76eed 100644 --- a/tf1x/tests/utils/reshape_data_test.py +++ b/tf1x/tests/utils/reshape_data_test.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/utils/standardize_data_test.py b/tf1x/tests/utils/standardize_data_test.py index 524d3f18..e918bd1a 100644 --- a/tf1x/tests/utils/standardize_data_test.py +++ b/tf1x/tests/utils/standardize_data_test.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/train_model.py b/tf1x/train_model.py index a11ff070..bc2cd08c 100644 --- a/tf1x/train_model.py +++ b/tf1x/train_model.py @@ -2,15 +2,16 @@ import sys import time as ti import argparse +from os.path import dirname as up + +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import matplotlib matplotlib.use("Agg") import numpy as np import tensorflow as tf -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.tf1x.params.param_picker as pp import DeepSparseCoding.tf1x.models.model_picker as mp import DeepSparseCoding.tf1x.data.data_selector as ds diff --git a/tf1x/vis/tsne_analysis.py b/tf1x/vis/tsne_analysis.py index f0f4430e..b60d7220 100644 --- a/tf1x/vis/tsne_analysis.py +++ b/tf1x/vis/tsne_analysis.py @@ -1,5 +1,9 @@ import os import sys +from os.path import dirname as up + +ROOT_DIR = up(up(up(up(os.path.realpath(__file__))))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np import pickle @@ -7,9 +11,6 @@ from tensorflow.contrib.tensorboard.plugins import projector from scipy.misc import imsave -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.tf1x.data.data_selector as ds import DeepSparseCoding.tf1x.analysis.analysis_picker as ap import DeepSparseCoding.tf1x.utils.data_processing as dp diff --git a/tf1x/vis/vis_class_adversarial.py b/tf1x/vis/vis_class_adversarial.py index 7bbd72a7..ac41ce23 100644 --- a/tf1x/vis/vis_class_adversarial.py +++ b/tf1x/vis/vis_class_adversarial.py @@ -2,6 +2,11 @@ matplotlib.use('Agg') import os import sys +from os.path import dirname as up + +ROOT_DIR = up(up(up(up(os.path.realpath(__file__))))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) + import numpy as np import matplotlib @@ -13,9 +18,6 @@ import pandas as pd import pdb -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.tf1x.data.data_selector as ds import DeepSparseCoding.tf1x.utils.data_processing as dp import DeepSparseCoding.tf1x.utils.plot_functions as pf diff --git a/tf1x/vis/vis_conv_lca.py b/tf1x/vis/vis_conv_lca.py index 567dea57..08fbda10 100644 --- a/tf1x/vis/vis_conv_lca.py +++ b/tf1x/vis/vis_conv_lca.py @@ -1,6 +1,11 @@ # In[1]: import os +from os.path import dirname as up + +ROOT_DIR = up(up(up(up(os.path.realpath(__file__))))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) + import numpy as np import matplotlib matplotlib.use('Agg') @@ -11,9 +16,6 @@ import tensorflow as tf import pdb -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.tf1x.data.data_selector as ds import DeepSparseCoding.tf1x.utils.data_processing as dp import DeepSparseCoding.tf1x.utils.plot_functions as pf diff --git a/tf1x/vis/vis_corrupt.py b/tf1x/vis/vis_corrupt.py index 8b6b3df4..dee47ad2 100644 --- a/tf1x/vis/vis_corrupt.py +++ b/tf1x/vis/vis_corrupt.py @@ -2,6 +2,10 @@ matplotlib.use('Agg') import os import sys +from os.path import dirname as up + +ROOT_DIR = up(up(up(up(os.path.realpath(__file__))))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np import matplotlib @@ -13,9 +17,6 @@ import pandas as pd import pickle -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.tf1x.data.data_selector as ds import DeepSparseCoding.tf1x.utils.data_processing as dp import DeepSparseCoding.tf1x.utils.plot_functions as pf diff --git a/tf1x/vis/vis_recon_adversarial.py b/tf1x/vis/vis_recon_adversarial.py index ab364e0e..ae926b89 100644 --- a/tf1x/vis/vis_recon_adversarial.py +++ b/tf1x/vis/vis_recon_adversarial.py @@ -3,6 +3,10 @@ import os import sys import pdb +from os.path import dirname as up + +ROOT_DIR = up(up(up(up(os.path.realpath(__file__))))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np import matplotlib @@ -11,9 +15,6 @@ import matplotlib.gridspec as gridspec from skimage.measure import compare_psnr -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - from DeepSparseCoding.tf1x.data.dataset import Dataset import DeepSparseCoding.tf1x.data.data_selector as ds import DeepSparseCoding.tf1x.utils.data_processing as dp diff --git a/train_model.py b/train_model.py index db2b2778..ef0193c3 100644 --- a/train_model.py +++ b/train_model.py @@ -2,8 +2,9 @@ import sys import argparse import time as ti +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.getcwd()) +ROOT_DIR = up(up(os.path.realpath(__file__))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import DeepSparseCoding.utils.loaders as loaders diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py index 6ffb8aae..0ca0f3cd 100644 --- a/utils/dataset_utils.py +++ b/utils/dataset_utils.py @@ -1,14 +1,15 @@ import os import sys +from os.path import dirname as up + +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np import torch from torchvision import transforms from torchvision.datasets import MNIST -ROOT_DIR = os.path.dirname(os.getcwd()) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.utils.data_processing as dp import DeepSparseCoding.datasets.synthetic as synthetic From e1b32b7cbfe1996aa507d4297beb7c45d5052eaf Mon Sep 17 00:00:00 2001 From: Dylan Date: Thu, 4 Feb 2021 12:57:30 +0100 Subject: [PATCH 12/44] new util for extracting image patches --- tests/test_data_processing.py | 19 ++++ train_model.py | 1 + utils/data_processing.py | 206 +++++++++++++++++++++++++++++++--- 3 files changed, 210 insertions(+), 16 deletions(-) diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py index 004ee80e..7fa116f3 100644 --- a/tests/test_data_processing.py +++ b/tests/test_data_processing.py @@ -254,3 +254,22 @@ def test_l2_weight_norm(self): normed_w_norm = dp.get_weights_l2_norm(torch.tensor(normed_w), eps=1e-12).numpy() np.testing.assert_allclose(normed_w_norm, 1.0, rtol=1e-10) np.testing.assert_allclose(w / w_norm, normed_w, rtol=1e-10) + + def test_patches(self): + err = 1e-6 + rand_mean = 0; rand_var = 1 + num_im = 10; im_edge = 512; im_chan = 1; patch_edge = 16 + num_patches = np.int(num_im * (im_edge / patch_edge)**2) + rand_seed = 1234 + rand_state = np.random.RandomState(rand_seed) + data = np.stack([rand_state.normal(rand_mean, rand_var, size=[im_edge, im_edge, im_chan]) + for _ in range(num_im)]) + data_shape = list(data.shape) + patch_shape = [patch_edge, patch_edge, im_chan] + datapoint = torch.tensor(data[0, ...]) + datapoint_patches = dp.single_image_to_patches(datapoint, patch_shape) + datapoint_recon = dp.patches_to_single_image(datapoint_patches, data_shape[1:]) + np.testing.assert_allclose(datapoint.numpy(), datapoint_recon.numpy(), rtol=err) + patches = dp.images_to_patches(torch.tensor(data), patch_shape) + data_recon = dp.patches_to_images(patches, data_shape[1:]) + np.testing.assert_allclose(data, data_recon.numpy(), rtol=err) diff --git a/train_model.py b/train_model.py index ef0193c3..d4e5d727 100644 --- a/train_model.py +++ b/train_model.py @@ -40,6 +40,7 @@ model.log_info(f'Completed epoch {epoch}/{model.params.num_epochs}') print(f'Completed epoch {epoch}/{model.params.num_epochs}') +# Final outputs t1 = ti.time() tot_time=float(t1-t0) tot_images = model.params.num_epochs*len(train_loader.dataset) diff --git a/utils/data_processing.py b/utils/data_processing.py index a0697454..aa936e52 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -4,8 +4,9 @@ def reshape_data(data, flatten=None, out_shape=None): """ - Helper function to reshape input data for processing and return data shape - Inputs: + Reshape input data for processing and return data shape + + Keyword arguments: data: [tensor] data of shape: n is num_examples, i is num_rows, j is num_cols, k is num_channels, l is num_examples = i*j*k if out_shape is not specified, it is assumed that i == j @@ -22,6 +23,7 @@ def reshape_data(data, flatten=None, out_shape=None): If data is flat and flatten==True, or !flat and flatten==False, then None condition will apply out_shape: [list or tuple] containing the desired output shape This will overwrite flatten, and return the input reshaped according to out_shape + Outputs: tuple containing: data: [tensor] data with new shape @@ -98,7 +100,8 @@ def reshape_data(data, flatten=None, out_shape=None): def check_all_same_shape(tensor_list): """ Verify that all tensors in the tensor list have the same shape - Args: + + Keyword arguments: tensor_list: list of tensors to be checked Returns: raises error if the tensors are not the same shape @@ -114,7 +117,8 @@ def check_all_same_shape(tensor_list): def flatten_feature_map(feature_map): """ Flatten input tensor from [batch, y, x, f] to [batch, y*x*f] - Args: + + Keyword arguments: feature_map: tensor with shape [batch, y, x, f] Returns: reshaped_map: tensor with shape [batch, y*x*f] @@ -134,12 +138,15 @@ def flatten_feature_map(feature_map): def standardize(data, eps=None, samplewise=True): """ Standardize each image data to have zero mean and unit standard-deviation (z-score) - Uses population standard deviation data.sum() / N, where N = data.shape[0]. - Inputs: + + This function uses population standard deviation data.sum() / N, where N = data.shape[0]. + + Keyword arguments: data: [tensor] unnormalized data eps: [float] if the std(data) is less than eps, then divide by eps instead of std(data) samplewise: [bool] if True, standardize each sample individually; akin to contrast-normalization if False, compute mean and std over entire batch + Outputs: data: [tensor] normalized data """ @@ -164,10 +171,12 @@ def standardize(data, eps=None, samplewise=True): def rescale_data_to_one(data, eps=None, samplewise=True): """ Rescale input data to be between 0 and 1 - Inputs: + + Keyword arguments: data: [tensor] unnormalized data eps: [float] if the std(data) is less than eps, then divide by eps instead of std(data) samplewise: [bool] if True, compute it per-sample, otherwise normalize entire batch + Outputs: data: [tensor] centered data of shape (n, i, j, k) or (n, l) """ @@ -186,11 +195,14 @@ def rescale_data_to_one(data, eps=None, samplewise=True): data = (data - data_min) / data_range return data, data_min, data_max + def one_hot_to_dense(one_hot_labels): """ - converts a matrix of one-hot labels to a list of dense labels - Inputs: + Convert a matrix of one-hot labels to a list of dense labels + + Keyword arguments: one_hot_labels: one-hot torch tensor of shape [num_labels, num_classes] + Outputs: dense_labels: 1D torch tensor array of labels The integer value indicates the class and 0 is assumed to be a class. @@ -202,10 +214,17 @@ def one_hot_to_dense(one_hot_labels): dense_labels[label_id] = torch.nonzero(one_hot_labels[label_id, :] == 1) return dense_labels + def dense_to_one_hot(labels_dense, num_classes): """ - converts a (np.ndarray) vector of dense labels to a (np.ndarray) matrix of one-hot labels - e.g. [0, 1, 1, 3] -> [00, 01, 01, 11] + Converts a (np.ndarray) vector of dense labels to a (np.ndarray) matrix of one-hot labels. E.g. [0, 1, 1, 3] -> [00, 01, 01, 11] + + Keyword arguments: + labels_dense: dense torch tensor of shape [num_classes], where each entry is an integer indicating the class label + num-classes: The total number of classes in the dataset + + Outputs: + one_hot_labels: one-hot torch tensor of shape [num_labels, num_classes] """ num_labels = labels_dense.shape[0] index_offset = torch.arange(end=num_labels, dtype=torch.int32) * num_classes @@ -213,25 +232,31 @@ def dense_to_one_hot(labels_dense, num_classes): labels_one_hot.view(-1)[index_offset + labels_dense.view(-1)] = 1 return labels_one_hot + def atleast_kd(x, k): """ - return x reshaped to append singleton dimensions such that x.ndim is at least k - Inputs: + Return x reshaped to append singleton dimensions such that x.ndim is at least k + + Keyword arguments: x [Tensor or numpy ndarray] k [int] minimum number of dimensions + Outputs: x [same as input x] reshaped input to have at least k dimensions """ shape = x.shape + (1,) * (k - x.ndim) return x.reshape(shape) + def get_weights_l2_norm(w, eps=1e-12): """ - get l2 norm of weight matrix - Inputs: + Return l2 norm of weight matrix + + Keyword arguments: w [Tensor] assumed to have shape [inC, outC] or [outC, inC, kernH, kernW] norm is calculated over vectorized version of inC in the first case or inC*kernH*kernW in the second eps [float] minimum value to prevent division by zero + Outputs: norm [Tensor] norm of each of the outC weight vectors """ @@ -247,15 +272,164 @@ def get_weights_l2_norm(w, eps=1e-12): norms = atleast_kd(norms, w.ndim) return norms + def l2_normalize_weights(w, eps=1e-12): """ l2 normalize weight matrix - Inputs: + + Keyword arguments: w [Tensor] assumed to have shape [inC, outC] or [outC, inC, kernH, kernW] norm is calculated over vectorized version of inC in the first case or inC*kernH*kernW in the second eps [float] minimum value to prevent division by zero + Outputs: w [Tensor] same type and shape as input w, but with unitary l2 norm when computed over all input dimensions """ norms = get_weights_l2_norm(w, eps) return w / norms + + +def single_image_to_patches(image, patch_shape): + """ + Extract patches from a single image + + Keyword arguments: + image [torch tensor] of shape [im_height, im_width, im_chan] + patch_shape [tuple or list] containing the output shape + [patch_height, patch_width, patch_chan] + patch_chan must be the same as im_chan + + It is recommended, though not required, that the patch height and width divide evenly into + the image height and width, respectively. + + Outputs: + patches [torch tensor] of patches of shape [num_patches]+list(patch_shape) + """ + try: + im_height, im_width, im_chan = image.shape + patch_height, patch_width, patch_chan = patch_shape + except Exception as e: + raise ValueError( + f'This function requires that: ' + +f'1) The input variable "image" must have shape [im_height, im_width, im_chan], and is {image.shape}' + +f'and 2) the input variable "patch_shape" must have shape [patch_height, patch_width, patch_chan], and is {patch_shape}.' + ) from e + num_row_patches = np.floor(im_height / patch_height) + num_col_patches = np.floor(im_width / patch_width) + num_patches = int(num_row_patches * num_col_patches) + patches = torch.zeros((num_patches, patch_height, patch_width, patch_chan)) + row_id = 0 + col_id = 0 + for patch_idx in range(num_patches): + row_end = row_id + patch_height + col_end = col_id + patch_width + try: + patches[patch_idx, ...] = image[row_id:row_end, col_id:col_end, :] + except Exception as e: + raise ValueError('This function requires that im_chan equal patch_chan.') from e + row_id += patch_height + if row_id >= im_height: + row_id = 0 + col_id += patch_width + if col_id >= im_width: + col_id = 0 + return patches + + +def patches_to_single_image(patches, image_shape): + """ + Convert patches input into a single ouput + + Keyword arguments: + patches [torch tensor] of shape [num_patches, patch_height, patch_width, patch_chan] + image_shape [list or tuple] of length 2 containing the image shape [im_height, im_width, im_chan] + + im_chan is assumed to equal patch_chan + + Outputs: + image [torch tensor] of shape [im_height, im_width, im_chan] + """ + try: + num_patches, patch_height, patch_width, patch_chan = patches.shape + im_height, im_width, im_chan = image_shape + except Exception as e: + raise ValueError( + f'This funciton requires that input patches has shape' + f' [num_patches, patch_height, patch_width, patch_chan] and is {patches.shape}' + f' and input image_shape is a list or tuple of integers of length 3 containing [im_height, im_width, im_chan] and is {image_shape}' + ) from e + image = torch.zeros((im_height, im_width, im_chan)) + row_id = 0 + col_id = 0 + for patch_idx in range(num_patches): + row_end = row_id + patch_height + col_end = col_id + patch_width + image[row_id:row_end, col_id:col_end, :] = patches[patch_idx, ...] + row_id += patch_height + if row_id >= im_height: + row_id = 0 + col_id += patch_width + if col_id >= im_width: + col_id = 0 + return image + +def images_to_patches(images, patch_shape): + """ + Extract evenly distributed non-overlapping patches from an image dataset + + Keyword arguments: + images [torch tensor] of shape [num_images, im_height, im_width, im_chan] or [im_height, im_width, im_chan] for a single image + patch_shape [tuple or list] containing the output shape + [patch_height, patch_width, patch_chan] + patch_chan must be the same as im_chan + + It is recommended, though not required, that the patch height and width divide evenly into the image height and width, respectively. + + Outputs: + patches [np.ndarray] of patches of shape [num_patches]+list(patch_shape) + """ + if images.ndim == 3: # single image + return single_image_to_patches(images, patch_shape) + num_im, im_height, im_width, im_chan = images.shape + patch_height, patch_width, patch_chan = patch_shape + num_row_patches = np.floor(im_height / patch_height) + num_col_patches = np.floor(im_width / patch_width) + num_patches_per_im = int(num_row_patches * num_col_patches) + tot_num_patches = int(num_patches_per_im * num_im) + patches = torch.zeros([tot_num_patches, ]+list(patch_shape)) + patch_id = 0 + for im_id in range(num_im): + image = images[im_id, ...] + image_patches = single_image_to_patches(image, patch_shape) + patch_end = patch_id + num_patches_per_im + patches[patch_id:patch_end, ...] = image_patches + patch_id += num_patches_per_im + return patches + +def patches_to_images(patches, image_shape): + """ + Recombine patches tensor into a dataset of images + + Keyword arguments: + patches [torch tensor] holding square patch data of shape [num_patches, patch_height, patch_width, patch_chan] + image_shape [list or tuple] containing the image dataset shape [im_height, im_width, im_chan] + + It is assumed that im_chan equals patch_chan + + Outputs: + images [torch tensor] holding the recombined image dataset + """ + tot_num_patches, patch_height, patch_width, patch_chan = patches.shape + im_height, im_width, im_chan = image_shape + num_row_patches = np.floor(im_height / patch_height) + num_col_patches = np.floor(im_width / patch_width) + num_patches_per_im = int(num_row_patches * num_col_patches) + num_im = tot_num_patches // num_patches_per_im + images = torch.zeros([num_im]+image_shape) + patch_id = 0 + for im_id in range(num_im): + patch_end = patch_id + num_patches_per_im + patch_batch = patches[patch_id:patch_end, ...] + images[im_id, ...] = patches_to_single_image(patch_batch, image_shape) + patch_id += num_patches_per_im + return images From 15128d9a9ce064205f800eeafe759ce4d6486ea6 Mon Sep 17 00:00:00 2001 From: Dylan Date: Thu, 4 Feb 2021 12:59:52 +0100 Subject: [PATCH 13/44] adds util for converting a batch of images into a single tiled image --- utils/plot_functions.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/utils/plot_functions.py b/utils/plot_functions.py index 150ab1eb..8374db1b 100644 --- a/utils/plot_functions.py +++ b/utils/plot_functions.py @@ -94,3 +94,29 @@ def plot_stats(data, x_key, x_label=None, y_keys=None, y_labels=None, start_inde return None plot.show() return fig + +def pad_images(images, pad_values=1): + """ + Convert an array of images into a single tiled image with padded border + + Keyword arguments: + images: [np.ndarray] of shape [num_samples, im_height, im_width, im_chan] + pad_values: [int] specifying what value will be used for padding + + Outputs: + padded_images: [np.ndarray] padded version of input + """ + n = int(np.ceil(np.sqrt(images.shape[0]))) + padding = (((0, n ** 2 - images.shape[0]), + (1, 1), (1, 1)) # add some space between filters + + ((0, 0),) * (images.ndim - 3)) # don't pad last dimension (if there is one) + padded_images = np.pad(images, padding, mode="constant", + constant_values=pad_values) + # tile the filters into an image + padded_images = padded_images.reshape(( + (n, n) + padded_images.shape[1:])).transpose(( + (0, 2, 1, 3) + tuple(range(4, padded_images.ndim + 1)))) + padded_images = padded_images.reshape((n * padded_images.shape[1], + n * padded_images.shape[3]) + padded_images.shape[4:]) + return padded_images + From db57cd1f4bca92c6921626f077cd5a0b6738ccfa Mon Sep 17 00:00:00 2001 From: Dylan Date: Mon, 8 Feb 2021 15:08:03 +0100 Subject: [PATCH 14/44] renames variables for clarity; docs additions; adds num_validation --- params/base_params.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/params/base_params.py b/params/base_params.py index 142bdb5a..743278a3 100644 --- a/params/base_params.py +++ b/params/base_params.py @@ -9,6 +9,7 @@ class BaseParams(object): """ all models batch_size [int] number of images in a training batch + center_dataset [bool] if True, subtract the mean dataset image from all datapoints data_dir [str] location of dataset folders device [str] which device to run on dtype [torch dtype] dtype for network variables @@ -36,6 +37,7 @@ class BaseParams(object): standardize_data [bool] if set, z-score data to have mean=0 and standard deviation=1 using numpy operators train_logs_per_epoch [int or None] how often to send updates to the logfile workspace_dir [str] system directory that is the parent to the primary repository directory + num_validation [int] number of images to reserve for the validation set (only works with some datasets) mlp activation_functions [list of str] strings correspond to activation functions for layers. @@ -54,7 +56,7 @@ class BaseParams(object): num_steps [int] number of lca inference steps to take rectify_a [bool] if set, rectify the layer 1 neuron activity sparse_mult [float] multiplyer placed in front of the sparsity loss term - tau [float] LCA time constant + tau [float] LCA time constant; larger values result in smaller step sizes (i.e. slower convergence) lca update rule (step_size) is multiplied by dt/tau thresh_type [str] specifying LCA threshold function; can be "hard" or "soft" @@ -68,8 +70,9 @@ def __init__(self): self.compute_helper_params() def set_params(self): - self.standardize_data = False + self.center_dataset = False self.rescale_data_to_one = False + self.standardize_data = False self.model_type = None self.log_to_file = True self.train_logs_per_epoch = None @@ -77,6 +80,7 @@ def set_params(self): self.shuffle_data = True self.eps = 1e-12 self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.num_validation = 0 self.rand_seed = 123456789 self.rand_state = np.random.RandomState(self.rand_seed) self.workspace_dir = os.path.join(os.path.expanduser('~'), 'Work') From b193860af26233fc11987ab23a06cc12877473e2 Mon Sep 17 00:00:00 2001 From: Dylan Date: Mon, 8 Feb 2021 15:10:28 +0100 Subject: [PATCH 15/44] adds normalization options and cifar10 dataset --- utils/data_processing.py | 73 +++++++++++++++++++++++++++++++++++++--- utils/dataset_utils.py | 52 +++++++++++++++++++++++++--- 2 files changed, 115 insertions(+), 10 deletions(-) diff --git a/utils/data_processing.py b/utils/data_processing.py index aa936e52..8c828b57 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -135,7 +135,69 @@ def flatten_feature_map(feature_map): return resh_map -def standardize(data, eps=None, samplewise=True): +def get_std_from_dataloader(loader, dataset_mean): + """ + TODO: Calculate the standard deviation from all entries in a pytorch data loader + + Keyword arguments: + loader: [pytorch DataLoader] containing the full dataset. + This function assumes there is always a target label, i.e. loader.next() returns (data, target) + dataset_mean: [torch tensor] of the same shape as a single dataset sample + + Outputs: + dataset_std: [torch tensor] of the same shape as a single dataset sample + """ + #dataset_std = torch.zeros(next(iter(loader)).shape[1:]) + #for data, target in loader: + # std_sum_squares += (data - dataset_mean)**2 + #dataset_std = torch.sqrt(std_sum_squares / len(loader.dataset)) + #return dataset_std + raise NotImplementedError + + +def get_mean_from_dataloader(loader): + """ + Calculate the mean datapoint from all entries in a pytorch data loader + + Keyword arguments: + loader: [pytorch DataLoader] containing the full dataset. + This function assumes there is always a target label, i.e. loader.next() returns (data, target) + + Outputs: + dataset_mean: [torch tensor] of the same shape as a single dataset sample + """ + dataset_mean = torch.zeros(next(iter(loader))[0].shape[1:]) # don't include batch dimension + num_batches = 0 + for data, target in loader: + dataset_mean += data.mean(axis=0, keepdims=False) + num_batches += 1 + return dataset_mean / num_batches + + +def center(data, samplewise=False, batch_size=100): + """ + Center image dataset to have zero mean + + Keyword arguments: + data: [tensor] unnormalized data + samplewise: [bool] if True, center each sample individually; if False, compute mean over entire batch + + Outputs: + data: [tensor] centered data + """ + data, orig_shape = reshape_data(data, flatten=True)[:2] # Adds channel dimension if it's missing + if(samplewise): # center each input sample individually + data_axis = tuple(range(data.ndim)[1:]) + data_mean = torch.mean(data, dim=data_axis, keepdim=True) + else: # center the entire population + data_mean = torch.mean(data, dim=0) + data = data - data_mean + if(data.shape != orig_shape): + data = reshape_data(data, out_shape=orig_shape)[0] + return data, data_mean + + +def standardize(data, eps=None, samplewise=False, batch_size=100): """ Standardize each image data to have zero mean and unit standard-deviation (z-score) @@ -144,6 +206,7 @@ def standardize(data, eps=None, samplewise=True): Keyword arguments: data: [tensor] unnormalized data eps: [float] if the std(data) is less than eps, then divide by eps instead of std(data) + defaults to 1/sqrt(data_dim) where data_dim is the total size of a data vector samplewise: [bool] if True, standardize each sample individually; akin to contrast-normalization if False, compute mean and std over entire batch @@ -154,12 +217,12 @@ def standardize(data, eps=None, samplewise=True): eps = 1.0 / np.sqrt(data[0,...].numel()) data, orig_shape = reshape_data(data, flatten=True)[:2] # Adds channel dimension if it's missing num_examples = data.shape[0] - if(samplewise): # standardize the entire population - data_axis = tuple(range(data.ndim)[1:]) # standardize each example individually + if(samplewise): # standardize each input sample individually + data_axis = tuple(range(data.ndim)[1:]) data_mean = torch.mean(data, dim=data_axis, keepdim=True) data_true_std = torch.std(data, unbiased=False, dim=data_axis, keepdim=True) - else: # standardize each input sample individually - data_mean = torch.mean(data) + else: # standardize the entire population + data_mean = torch.mean(data, dim=0) data_true_std = torch.std(data, unbiased=False) data_std = torch.where(data_true_std >= eps, data_true_std, eps*torch.ones_like(data_true_std)) data = (data - data_mean) / data_std diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py index 0ca0f3cd..fd05bfc9 100644 --- a/utils/dataset_utils.py +++ b/utils/dataset_utils.py @@ -8,13 +8,13 @@ import numpy as np import torch from torchvision import transforms -from torchvision.datasets import MNIST +import torchvision.datasets import DeepSparseCoding.utils.data_processing as dp import DeepSparseCoding.datasets.synthetic as synthetic -class FastMNIST(MNIST): +class FastMNIST(torchvision.datasets.MNIST): """ The torchvision MNIST dataset has additional overhead that slows it down. This loads the entire dataset onto the specified device at init, resulting in a considerable speedup @@ -64,7 +64,7 @@ def load_dataset(params): transforms.Lambda(lambda x: dp.rescale_data_to_one(x, eps=params.eps, samplewise=True)[0])) kwargs = { 'root':params.data_dir, - 'download':True, + 'download':False, 'transform':transforms.Compose(preprocessing_pipeline) } if hasattr(params, 'fast_mnist') and params.fast_mnist: @@ -81,14 +81,56 @@ def load_dataset(params): else: kwargs['train'] = True train_loader = torch.utils.data.DataLoader( - MNIST(**kwargs), batch_size=params.batch_size, + torchvision.datasets.MNIST(**kwargs), batch_size=params.batch_size, shuffle=params.shuffle_data, num_workers=0, pin_memory=True) kwargs['train'] = False val_loader = None test_loader = torch.utils.data.DataLoader( - MNIST(**kwargs), batch_size=params.batch_size, + torchvision.datasets.MNIST(**kwargs), batch_size=params.batch_size, shuffle=params.shuffle_data, num_workers=0, pin_memory=True) + elif(params.dataset.lower() == 'cifar10'): + preprocessing_pipeline = [ + transforms.ToTensor(), + transforms.Lambda(lambda x: x.permute(1, 2, 0)), # channels last + ] + kwargs = { + 'root': os.path.join(*[params.data_dir,'cifar10']), + 'download': False, + 'train': True, + 'transform': transforms.Compose(preprocessing_pipeline) + } + if params.center_dataset: + dataset = torchvision.datasets.CIFAR10(**kwargs) + data_loader = torch.utils.data.DataLoader(dataset, batch_size=params.batch_size, + shuffle=False, num_workers=0, pin_memory=True) + dataset_mean_image = dp.get_mean_from_dataloader(data_loader) + preprocessing_pipeline.append( + transforms.Lambda(lambda x: x - dataset_mean_image)) + if params.standardize_data: + preprocessing_pipeline.append( + transforms.Lambda( + lambda x: dp.standardize(x, eps=params.eps, samplewise=True, batch_size=params.batch_size)[0] + ) + ) + if params.rescale_data_to_one: + preprocessing_pipeline.append( + transforms.Lambda(lambda x: dp.rescale_data_to_one(x, eps=params.eps, samplewise=True)[0])) + kwargs['transform'] = transforms.Compose(preprocessing_pipeline) + kwargs['train'] = True + dataset = torchvision.datasets.CIFAR10(**kwargs) + kwargs['train'] = False + testset = torchvision.datasets.CIFAR10(**kwargs) + num_train = len(dataset) - params.num_validation + trainset, valset = torch.utils.data.random_split(dataset, + [num_train, params.num_validation]) + train_loader = torch.utils.data.DataLoader(trainset, batch_size=params.batch_size, + shuffle=params.shuffle_data, num_workers=0, pin_memory=True) + val_loader = torch.utils.data.DataLoader(valset, batch_size=params.batch_size, + shuffle=False, num_workers=0, pin_memory=True) + test_loader = torch.utils.data.DataLoader(testset, batch_size=params.batch_size, + shuffle=False, num_workers=0, pin_memory=True) + elif(params.dataset.lower() == 'dsprites'): root = os.path.join(*[params.data_dir]) dsprites_file = os.path.join(*[root, 'dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz']) From b45f3f72815ad8544b30502ddd3a1abdb7e3f48f Mon Sep 17 00:00:00 2001 From: Dylan Date: Mon, 8 Feb 2021 15:10:47 +0100 Subject: [PATCH 16/44] working cifar10 conv lca params --- params/conv_lca_cifar10_params.py | 46 +++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 params/conv_lca_cifar10_params.py diff --git a/params/conv_lca_cifar10_params.py b/params/conv_lca_cifar10_params.py new file mode 100644 index 00000000..e04b8e41 --- /dev/null +++ b/params/conv_lca_cifar10_params.py @@ -0,0 +1,46 @@ +import types + +from DeepSparseCoding.params.base_params import BaseParams + + +class params(BaseParams): + def set_params(self): + super(params, self).set_params() + self.model_type = 'conv_lca' + self.model_name = 'conv_lca_cifar10' + self.version = '0' + self.dataset = 'cifar10' + self.num_validation = 10000 + self.standardize_data = True + self.rescale_data_to_one = False + self.center_dataset = False + self.batch_size = 25 + self.num_epochs = 250 + self.weight_decay = 0.0 + self.weight_lr = 0.001 + self.train_logs_per_epoch = 6 + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.renormalize_weights = True + self.dt = 0.001 + self.tau = 0.2 + self.num_steps = 75 + self.rectify_a = True + self.thresh_type = 'hard' + self.sparse_mult = 0.30 + self.kernel_size = 8 + self.stride = 2 + self.padding = 0 + self.num_latent = 512 + self.compute_helper_params() + + def compute_helper_params(self): + super(params, self).compute_helper_params() + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] + self.step_size = self.dt / self.tau + self.out_channels = self.num_latent + self.num_pixels = 3072 + self.in_channels = 3 From cdcecb11fb4feaf5dcad334a654d16625ef8e820 Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 10 Feb 2021 11:55:01 +0100 Subject: [PATCH 17/44] more epochs for training, still not totally converged --- params/conv_lca_cifar10_params.py | 2 +- params/conv_lca_mnist_params.py | 45 ------------------------------- 2 files changed, 1 insertion(+), 46 deletions(-) delete mode 100644 params/conv_lca_mnist_params.py diff --git a/params/conv_lca_cifar10_params.py b/params/conv_lca_cifar10_params.py index e04b8e41..63e24d47 100644 --- a/params/conv_lca_cifar10_params.py +++ b/params/conv_lca_cifar10_params.py @@ -15,7 +15,7 @@ def set_params(self): self.rescale_data_to_one = False self.center_dataset = False self.batch_size = 25 - self.num_epochs = 250 + self.num_epochs = 500 self.weight_decay = 0.0 self.weight_lr = 0.001 self.train_logs_per_epoch = 6 diff --git a/params/conv_lca_mnist_params.py b/params/conv_lca_mnist_params.py deleted file mode 100644 index a129258b..00000000 --- a/params/conv_lca_mnist_params.py +++ /dev/null @@ -1,45 +0,0 @@ -import types - -from DeepSparseCoding.params.base_params import BaseParams - - -class params(BaseParams): - def set_params(self): - super(params, self).set_params() - self.model_type = 'conv_lca' - self.model_name = 'conv_lca_mnist' - self.version = '0' - self.dataset = 'mnist' - self.fast_mnist = True - self.standardize_data = False - self.rescale_data_to_one = True - self.num_pixels = 784 - self.batch_size = 50 - self.num_epochs = 500 - self.weight_decay = 0.0 - self.weight_lr = 0.001 - self.train_logs_per_epoch = 6 - self.optimizer = types.SimpleNamespace() - self.optimizer.name = 'sgd' - self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs - self.optimizer.lr_decay_rate = 0.8 - self.renormalize_weights = True - self.dt = 0.001 - self.tau = 0.03 - self.num_steps = 75 - self.rectify_a = True - self.thresh_type = 'soft' - self.sparse_mult = 0.25 - self.kernel_size = 8 - self.stride = 2 - self.padding = 0 - self.num_latent = 128 - self.compute_helper_params() - - def compute_helper_params(self): - super(params, self).compute_helper_params() - self.optimizer.milestones = [frac * self.num_epochs - for frac in self.optimizer.lr_annealing_milestone_frac] - self.step_size = self.dt / self.tau - self.out_channels = self.num_latent - self.in_channels = 1 From 0203c3aa4fc4cfda32d1aace7434a50dfc739639 Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 10 Feb 2021 11:55:29 +0100 Subject: [PATCH 18/44] no need to specify conv in the params filenames --- params/{conv_lca_cifar10_params.py => lca_cifar10_params.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename params/{conv_lca_cifar10_params.py => lca_cifar10_params.py} (100%) diff --git a/params/conv_lca_cifar10_params.py b/params/lca_cifar10_params.py similarity index 100% rename from params/conv_lca_cifar10_params.py rename to params/lca_cifar10_params.py From dcdc5a5ccaaca6d365ba5cb530ba65742b85c7a9 Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 10 Feb 2021 11:56:59 +0100 Subject: [PATCH 19/44] combines conv and fc lca params --- params/lca_mnist_params.py | 43 ++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/params/lca_mnist_params.py b/params/lca_mnist_params.py index 881c75ff..9d575e4b 100644 --- a/params/lca_mnist_params.py +++ b/params/lca_mnist_params.py @@ -3,33 +3,50 @@ from DeepSparseCoding.params.base_params import BaseParams +CONV = False + + class params(BaseParams): def set_params(self): super(params, self).set_params() - self.model_type = 'lca' - self.model_name = 'lca_768_mnist' self.version = '0' self.dataset = 'mnist' self.fast_mnist = True self.standardize_data = False self.num_pixels = 784 - self.batch_size = 100 - self.num_epochs = 1000 - self.weight_decay = 0. - self.weight_lr = 0.1 - self.train_logs_per_epoch = 6 - self.optimizer = types.SimpleNamespace() - self.optimizer.name = 'sgd' - self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs - self.optimizer.lr_decay_rate = 0.5 - self.renormalize_weights = True self.dt = 0.001 self.tau = 0.03 self.num_steps = 75 self.rectify_a = True self.thresh_type = 'soft' self.sparse_mult = 0.25 - self.num_latent = 768#self.num_pixels*4 + self.renormalize_weights = True + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.num_epochs = 1000 + self.weight_decay = 0.0 + self.train_logs_per_epoch = 6 + if CONV: + self.model_type = 'conv_lca' + self.model_name = 'conv_lca_mnist' + self.rescale_data_to_one = True + self.batch_size = 50 + self.weight_lr = 0.001 + self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.kernel_size = 8 + self.stride = 2 + self.padding = 0 + self.num_latent = 128 + else: + self.model_type = 'lca' + self.model_name = 'lca_768_mnist' + self.rescale_data_to_one = False + self.batch_size = 100 + self.weight_lr = 0.1 + self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.5 + self.num_latent = 768 #self.num_pixels * 4 self.compute_helper_params() def compute_helper_params(self): From 2db240eba45b42a7cc677e87e93738975118ce2c Mon Sep 17 00:00:00 2001 From: Dylan Date: Thu, 11 Feb 2021 09:28:31 +0000 Subject: [PATCH 20/44] adds conv mlp; updates checkpointing; new params adds convolutional MLP model with max pooling reorders all expected datashapes to have channels first moves typical log outputs to base model adds optimizer to checkpoint writing adds ability to load checkpoints from a log file adds ability to boot from checkpoint at the start of training (untested) adds ability to ignore gradients or include them when training ensemble models (untested) datasets now include a 'num_pixels' parameter in their output --- models/base_model.py | 59 +++++++++++++--- models/conv_lca_model.py | 11 ++- models/ensemble_model.py | 12 +++- models/lca_model.py | 12 ++-- models/mlp_model.py | 12 ++-- modules/mlp_module.py | 80 ++++++++++++++++++++-- params/base_params.py | 11 +++ params/lca_mlp_cifar10_params.py | 72 ++++++++++++++++++++ params/lca_mlp_mnist_params.py | 7 +- params/mlp_cifar10_params.py | 43 ++++++++++++ params/mlp_mnist_params.py | 3 +- params/test_params.py | 2 + tests/test_data_processing.py | 32 ++++----- tests/test_datasets.py | 7 +- tests/test_models.py | 19 +++++- train_model.py | 2 +- utils/data_processing.py | 112 +++++++++++++++---------------- utils/dataset_utils.py | 11 ++- utils/loaders.py | 1 + utils/run_utils.py | 16 +++-- 20 files changed, 400 insertions(+), 124 deletions(-) create mode 100644 params/lca_mlp_cifar10_params.py create mode 100644 params/mlp_cifar10_params.py diff --git a/models/base_model.py b/models/base_model.py index c1e383fa..4bef1ca0 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -1,9 +1,11 @@ import os +import pprint import numpy as np import torch from DeepSparseCoding.utils.file_utils import Logger +import DeepSparseCoding.utils.loaders as loaders class BaseModel(object): @@ -96,20 +98,61 @@ def log_info(self, string): """Log input string""" self.logger.log_info(string) - def write_checkpoint(self): - """Write checkpoints""" - torch.save(self.state_dict(), self.params.cp_latest_filename) + def get_train_stats(self, batch_step=None): + """ + Get default statistics about current training run + + Keyword arguments: + batch_step: [int] current batch iteration. The default assumes that training has finished. + """ + if batch_step is None: + batch_step = self.params.num_batches + epoch = batch_step / self.params.batches_per_epoch + stat_dict = { + 'epoch':int(epoch), + 'batch_step':batch_step, + 'train_progress':np.round(batch_step/self.params.num_batches, 3), + } + return stat_dict + + def write_checkpoint(self, batch_step=None): + """ + Write checkpoints + + Keyword arguments: + batch_step: [int] current batch iteration. The default assumes that training has finished. + """ + output_dict = { + 'model_state_dict': self.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + } + training_stats = self.get_train_stats(batch_step) + output_dict.update(training_stats) + torch.save(output_dict, self.params.cp_latest_filename) self.log_info('Full model saved in file %s'%self.params.cp_latest_filename) - def load_checkpoint(self, cp_file=None): + def get_checkpoint_from_log(self, logfile): + model_params = loaders.load_params(logfile) + checkpoint = torch.load(model_params.cp_latest_filename) + return checkpoint + + def load_checkpoint(self, cp_file=None, load_optimizer=False): """ Load checkpoint - Inputs: - model_dir: String specifying the path to the checkpoint + Keyword arguments: + model_dir: [str] specifying the path to the checkpoint """ if cp_file is None: cp_file = self.params.cp_latest_filename - return self.load_state_dict(torch.load(cp_file)) + checkpoint = torch.load(cp_file) + if load_optimizer: + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + self.load_state_dict(checkpoint['model_state_dict']) + _ = checkpoint.pop('optimizer_state_dict', None) + _ = checkpoint.pop('model_state_dict', None) + training_status = pprint.pformat(checkpoint, compact=True)#, sort_dicts=True #TODO: Python 3.8 adds the sort_dicts parameter + out_str = f'Loaded checkpoint from {cp_file} with the following stats:\n{training_status}' + return out_str def get_optimizer(self, optimizer_params, trainable_variables): optimizer_name = optimizer_params.optimizer.name @@ -157,7 +200,7 @@ def generate_update_dict(self, input_data, input_labels=None, batch_step=0, upda Generates a dictionary to be logged in the print_update function """ if update_dict is None: - update_dict = dict() + update_dict = self.get_train_stats(batch_step) for param_name, param_var in self.named_parameters(): grad = param_var.grad update_dict[param_name+'_grad_max_mean_min'] = [ diff --git a/models/conv_lca_model.py b/models/conv_lca_model.py index f9d96742..36f54086 100644 --- a/models/conv_lca_model.py +++ b/models/conv_lca_model.py @@ -11,6 +11,10 @@ def setup(self, params, logger=None): super(ConvLcaModel, self).setup(params, logger) self.setup_module(params) self.setup_optimizer() + if params.checkpoint_boot_log != '': + checkpoint = self.get_checkpoint_from_log(params.checkpoint_boot_log) + self.module.load_state_dict(checkpoint['model_state_dict']) + self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) def get_total_loss(self, input_tuple): input_tensor, input_labels = input_tuple @@ -24,16 +28,11 @@ def get_total_loss(self, input_tuple): def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None): if update_dict is None: update_dict = super(ConvLcaModel, self).generate_update_dict(input_data, input_labels, batch_step) - epoch = batch_step / self.params.batches_per_epoch - stat_dict = { - 'epoch':int(epoch), - 'batch_step':batch_step, - 'train_progress':np.round(batch_step/self.params.num_batches, 3), - 'weight_lr':self.scheduler.get_lr()[0]} latents = self.get_encodings(input_data) recon = self.get_recon_from_latents(latents) recon_loss = losses.half_squared_l2(input_data, recon).item() sparse_loss = self.params.sparse_mult * losses.l1_norm(latents).item() + stat_dict['weight_lr'] = self.scheduler.get_lr()[0] stat_dict['loss_recon'] = recon_loss stat_dict['loss_sparse'] = sparse_loss stat_dict['loss_total'] = recon_loss + sparse_loss diff --git a/models/ensemble_model.py b/models/ensemble_model.py index 60bbb343..9e296f60 100644 --- a/models/ensemble_model.py +++ b/models/ensemble_model.py @@ -24,14 +24,22 @@ def setup_module(self, params): subparams.data_shape = params.data_shape super(EnsembleModel, self).setup_ensemble_module(params) self.submodel_classes = [] - for submodel_params in self.params.ensemble_params: - self.submodel_classes.append(loaders.load_model_class(submodel_params.model_type)) + for ensemble_index, submodel_params in enumerate(self.params.ensemble_params): + submodule_class = loaders.load_model_class(submodel_params.model_type) + self.submodel_classes.append(submodule_class) + if submodel_params.checkpoint_boot_log != '': + checkpoint = self.get_checkpoint_from_log(submodule_params.checkpoint_boot_log) + submodule = self.__getitem__(ensemble_index) + submodule.load_state_dict(checkpoint['model_state_dict']) def setup_optimizer(self): for module in self: module.optimizer = self.get_optimizer( optimizer_params=module.params, trainable_variables=module.parameters()) + if module.params.checkpoint_boot_log != '': + checkpoint = self.get_checkpoint_from_log(module.params.checkpoint_boot_log) + module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) module.scheduler = torch.optim.lr_scheduler.MultiStepLR( module.optimizer, milestones=module.params.optimizer.milestones, diff --git a/models/lca_model.py b/models/lca_model.py index ec13c014..7d1a66e2 100644 --- a/models/lca_model.py +++ b/models/lca_model.py @@ -11,6 +11,10 @@ def setup(self, params, logger=None): super(LcaModel, self).setup(params, logger) self.setup_module(params) self.setup_optimizer() + if params.checkpoint_boot_log != '': + checkpoint = self.get_checkpoint_from_log(params.checkpoint_boot_log) + self.module.load_state_dict(checkpoint['model_state_dict']) + self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) def get_total_loss(self, input_tuple): input_tensor, input_labels = input_tuple @@ -24,16 +28,12 @@ def get_total_loss(self, input_tuple): def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None): if update_dict is None: update_dict = super(LcaModel, self).generate_update_dict(input_data, input_labels, batch_step) - epoch = batch_step / self.params.batches_per_epoch - stat_dict = { - 'epoch':int(epoch), - 'batch_step':batch_step, - 'train_progress':np.round(batch_step/self.params.num_batches, 3), - 'weight_lr':self.scheduler.get_lr()[0]} + stat_dict = dict() latents = self.get_encodings(input_data) recon = self.get_recon_from_latents(latents) recon_loss = losses.half_squared_l2(input_data, recon).item() sparse_loss = self.params.sparse_mult * losses.l1_norm(latents).item() + stat_dict['weight_lr'] = self.scheduler.get_lr()[0] stat_dict['loss_recon'] = recon_loss stat_dict['loss_sparse'] = sparse_loss stat_dict['loss_total'] = recon_loss + sparse_loss diff --git a/models/mlp_model.py b/models/mlp_model.py index bf755d12..dc6b97f7 100644 --- a/models/mlp_model.py +++ b/models/mlp_model.py @@ -10,6 +10,10 @@ def setup(self, params, logger=None): super(MlpModel, self).setup(params, logger) self.setup_module(params) self.setup_optimizer() + if params.checkpoint_boot_log != '': + checkpoint = self.get_checkpoint_from_log(params.checkpoint_boot_log) + self.module.load_state_dict(checkpoint['model_state_dict']) + self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) def get_total_loss(self, input_tuple): input_tensor, input_label = input_tuple @@ -20,16 +24,12 @@ def get_total_loss(self, input_tuple): def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None): if update_dict is None: update_dict = super(MlpModel, self).generate_update_dict(input_data, input_labels, batch_step) - epoch = batch_step / self.params.batches_per_epoch - stat_dict = { - 'epoch':int(epoch), - 'batch_step':batch_step, - 'train_progress':np.round(batch_step/self.params.num_batches, 3)} + stat_dict = dict() pred = self.forward(input_data) - #total_loss = F.nll_loss(pred, input_labels) total_loss = self.loss_fn(pred, input_labels) pred = pred.max(1, keepdim=True)[1] correct = pred.eq(input_labels.view_as(pred)).sum().item() + stat_dict['weight_lr'] = self.scheduler.get_lr()[0] stat_dict['loss'] = total_loss.item() stat_dict['train_accuracy'] = 100. * correct / self.params.batch_size update_dict.update(stat_dict) diff --git a/modules/mlp_module.py b/modules/mlp_module.py index 4877d8eb..1847b483 100644 --- a/modules/mlp_module.py +++ b/modules/mlp_module.py @@ -1,36 +1,102 @@ +import numpy as np import torch.nn as nn import torch.nn.functional as F from DeepSparseCoding.modules.activations import activation_picker +import DeepSparseCoding.utils.data_processing as dp class MlpModule(nn.Module): def setup_module(self, params): + def compute_conv_output_shape(in_length, kernel_size, stride, padding=0, dilation=1): + out_shape = ((in_length + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1 + return np.floor(out_shape).astype(np.int) self.params = params self.act_funcs = [activation_picker(act_func_str) for act_func_str in self.params.activation_functions] + self.layer_output_shapes = [self.params.data_shape] # [channels, height, width] self.layers = [] + self.pooling = [] self.dropout = [] for layer_index, layer_type in enumerate(self.params.layer_types): if layer_type == 'fc': + if(layer_index > 0 and self.params.layer_types[layer_index-1] == 'conv'): + in_features = np.prod(self.layer_output_shapes[-1]).astype(np.int) + else: + in_features = self.params.layer_channels[layer_index] layer = nn.Linear( - in_features = self.params.layer_channels[layer_index], - out_features = self.params.layer_channels[layer_index+1], - bias = True) + in_features=in_features, + out_features=self.params.layer_channels[layer_index + 1], + bias=True) self.register_parameter('fc'+str(layer_index)+'_w', layer.weight) self.register_parameter('fc'+str(layer_index)+'_b', layer.bias) self.layers.append(layer) + self.layer_output_shapes.append(self.params.layer_channels[layer_index + 1]) + elif layer_type == 'conv': + layer = nn.Conv2d( + in_channels=self.params.layer_channels[layer_index], + out_channels=self.params.layer_channels[layer_index + 1], + kernel_size=self.params.kernel_sizes[layer_index], + stride=self.params.strides[layer_index], + padding=0, + dilation=1, + bias=True) + self.register_parameter('conv'+str(layer_index)+'_w', layer.weight) + self.register_parameter('conv'+str(layer_index)+'_b', layer.bias) + self.layers.append(layer) + output_channels = self.params.layer_channels[layer_index + 1] + output_height = compute_conv_output_shape( + self.layer_output_shapes[-1][1], + self.params.kernel_sizes[layer_index], + self.params.strides[layer_index], + padding=0, + dilation=1) + output_width = compute_conv_output_shape( + self.layer_output_shapes[-1][2], + self.params.kernel_sizes[layer_index], + self.params.strides[layer_index], + padding=0, + dilation=1) + self.layer_output_shapes.append([output_channels, output_height, output_width]) + else: + assert False, ('layer_type parameter must be "fc" or "conv", not %g'%(layer_type)) + if(self.params.max_pool[layer_index] and layer_type == 'conv'): + self.pooling.append(nn.MaxPool2d( + kernel_size=self.params.pool_ksizes[layer_index], + stride=self.params.pool_strides[layer_index], + padding=0, + dilation=1)) + output_channels = self.params.layer_channels[layer_index + 1] + output_height = compute_conv_output_shape( + self.layer_output_shapes[-1][1], + self.params.pool_ksizes[layer_index], + self.params.pool_strides[layer_index], + padding=0, + dilation=1) + output_width = compute_conv_output_shape( + self.layer_output_shapes[-1][2], + self.params.pool_ksizes[layer_index], + self.params.pool_strides[layer_index], + padding=0, + dilation=1) + self.layer_output_shapes.append([output_channels, output_height, output_width]) else: - assert False, ('layer_type parameter must be "fc", not %g'%(layer_type)) + self.pooling.append(nn.Identity()) # do nothing self.dropout.append(nn.Dropout(p=self.params.dropout_rate[layer_index])) def preprocess_data(self, input_tensor): - input_tensor = input_tensor.view(-1, self.params.layer_channels[0]) + if self.params.layer_types[0] == 'fc': + input_tensor = input_tensor.view(self.params.batch_size, -1) # flatten input return input_tensor def forward(self, x): - for dropout, act_func, layer in zip(self.dropout, self.act_funcs, self.layers): - x = dropout(act_func(layer(x))) + layer_zip = zip(self.dropout, self.pooling, self.act_funcs, self.layers) + for layer_index, (dropout, pooling, act_func, layer) in enumerate(layer_zip): + prev_layer = self.params.layer_types[layer_index - 1] + current_layer = self.params.layer_types[layer_index] + if(layer_index > 0 and current_layer == 'fc' and prev_layer == 'conv'): + x = dp.flatten_feature_map(x) + x = dropout(pooling(act_func(layer(x)))) return x def get_encodings(self, input_tensor): diff --git a/params/base_params.py b/params/base_params.py index 743278a3..056dd3c9 100644 --- a/params/base_params.py +++ b/params/base_params.py @@ -10,6 +10,8 @@ class BaseParams(object): all models batch_size [int] number of images in a training batch center_dataset [bool] if True, subtract the mean dataset image from all datapoints + checkpoint_boot_log [str] path to a training log file for booting from checkpoint + if set, all specified model params must mach those in the log file #TODO: meaningful errors if not data_dir [str] location of dataset folders device [str] which device to run on dtype [torch dtype] dtype for network variables @@ -39,6 +41,9 @@ class BaseParams(object): workspace_dir [str] system directory that is the parent to the primary repository directory num_validation [int] number of images to reserve for the validation set (only works with some datasets) + ensemble + allow_parent_grads [bool] if True, allow loss gradients to propagate through all members of the ensemble + mlp activation_functions [list of str] strings correspond to activation functions for layers. len must equal the len of layer_types @@ -48,6 +53,11 @@ class BaseParams(object): layer_types [list of str] weight connectivity type, either "conv" or "fc" len must be equal to the len of layer_channels - 1 layer_channels [list of int] number of outputs per layer, including the input layer + kernel_sizes [list of ints] number of pixels on the edge of a square kernel, only used if layer_types is "conv" + strides [list of ints] number of pixels for the convolutional stride, assumes equal horizontal and vertical strides and is only used if layer_types is "conv" + max_pool [list of bools] if True, the network includes a max pooling op after the conv/fc op and before the dropout op + pool_ksizes [list of ints] number of pixels on the edge of a square max pooling kernel + pool_strides [list of ints] number of pixels in pooling stride, assumes equal and horizontal strides lca dt [float] discrete global time constant for neuron dynamics @@ -71,6 +81,7 @@ def __init__(self): def set_params(self): self.center_dataset = False + self.checkpoint_boot_log = '' self.rescale_data_to_one = False self.standardize_data = False self.model_type = None diff --git a/params/lca_mlp_cifar10_params.py b/params/lca_mlp_cifar10_params.py new file mode 100644 index 00000000..66ed1a32 --- /dev/null +++ b/params/lca_mlp_cifar10_params.py @@ -0,0 +1,72 @@ +import os +import types +import numpy as np +import torch + +from DeepSparseCoding.params.base_params import BaseParams +from DeepSparseCoding.params.lca_mnist_params import params as LcaParams +from DeepSparseCoding.params.mlp_mnist_params import params as MlpParams + + +class shared_params(object): + def __init__(self): + self.model_type = 'ensemble' + self.model_name = 'lca_mlp_cifar10' + self.version = '0' + self.dataset = 'cifar10' + self.standardize_data = True + self.batch_size = 25 + self.num_epochs = 500 + self.train_logs_per_epoch = 4 + self.allow_parent_grads = False + + +class lca_params(LcaParams): + def set_params(self): + super(lca_params, self).set_params() + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) + self.model_type = 'lca' + self.weight_decay = 0.0 + self.weight_lr = 0.001 + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.renormalize_weights = True + self.dt = 0.001 + self.tau = 0.2 + self.num_steps = 75 + self.rectify_a = True + self.thresh_type = 'hard' + self.sparse_mult = 0.30 + self.num_latent = 512 + self.checkpoint_boot_log = '' + self.compute_helper_params() + + +class mlp_params(MlpParams): + def set_params(self): + super(mlp_params, self).set_params() + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) + self.model_type = 'mlp' + self.weight_lr = 2e-3 + self.weight_decay = 1e-6 + self.layer_types = ['fc'] + self.layer_channels = [512, 10] + self.activation_functions = ['identity'] + self.dropout_rate = [0.0] # probability of value being set to zero + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'adam' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.compute_helper_params() + + +class params(BaseParams): + def set_params(self): + super(params, self).set_params() + self.ensemble_params = [lca_params(), mlp_params()] + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) diff --git a/params/lca_mlp_mnist_params.py b/params/lca_mlp_mnist_params.py index 5cfb6b11..537e96ba 100644 --- a/params/lca_mlp_mnist_params.py +++ b/params/lca_mlp_mnist_params.py @@ -20,6 +20,7 @@ def __init__(self): self.batch_size = 100 self.num_epochs = 1200 self.train_logs_per_epoch = 4 + self.allow_parent_grads = False class lca_params(LcaParams): @@ -41,8 +42,8 @@ def set_params(self): self.rectify_a = True self.thresh_type = 'soft' self.sparse_mult = 0.25 - self.num_latent = 768#self.num_pixels*4 - #self.allow_parent_grads = False # TODO: enable this param + self.num_latent = 768 + self.checkpoint_boot_log = '' self.compute_helper_params() @@ -55,7 +56,7 @@ def set_params(self): self.weight_lr = 1e-4 self.weight_decay = 0.0 self.layer_types = ['fc'] - self.layer_channels = [768, 10]#[self.num_pixels*4, 10] + self.layer_channels = [768, 10] self.activation_functions = ['identity'] self.dropout_rate = [0.0] # probability of value being set to zero self.optimizer = types.SimpleNamespace() diff --git a/params/mlp_cifar10_params.py b/params/mlp_cifar10_params.py new file mode 100644 index 00000000..24aa8f29 --- /dev/null +++ b/params/mlp_cifar10_params.py @@ -0,0 +1,43 @@ +import os +import types + +import numpy as np +import torch + +from DeepSparseCoding.params.base_params import BaseParams + + +class params(BaseParams): + def set_params(self): + super(params, self).set_params() + self.model_type = 'mlp' + self.model_name = 'mlp_cifar10' + self.version = '0' + self.dataset = 'cifar10' + self.standardize_data = True + self.rescale_data_to_one = False + self.center_data = False + self.num_validation = 1000 + self.batch_size = 50 + self.num_epochs = 500 + self.weight_decay = 3e-6 + self.weight_lr = 2e-3 + self.layer_types = ['conv', 'fc'] + self.layer_channels = [3, 512, 10] + self.kernel_sizes = [8, None] + self.strides = [2, None] + self.activation_functions = ['lrelu', 'identity'] + self.dropout_rate = [0.5, 0.0] # probability of value being set to zero + self.max_pool = [True, False] + self.pool_ksizes = [5, None] + self.pool_strides = [4, None] + self.train_logs_per_epoch = 4 + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'adam' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.1 + + def compute_helper_params(self): + super(params, self).compute_helper_params() + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] diff --git a/params/mlp_mnist_params.py b/params/mlp_mnist_params.py index 808078fd..dcc7e195 100644 --- a/params/mlp_mnist_params.py +++ b/params/mlp_mnist_params.py @@ -17,15 +17,16 @@ def set_params(self): self.fast_mnist = True self.standardize_data = False self.rescale_data_to_one = False - self.num_pixels = 28*28*1 self.batch_size = 50 self.num_epochs = 300 self.weight_lr = 5e-4 self.weight_decay = 2e-6 self.layer_types = ['fc', 'fc'] + self.num_pixels = 28*28*1 self.layer_channels = [self.num_pixels, 768, 10] self.activation_functions = ['lrelu', 'identity'] self.dropout_rate = [0.5, 0.0] # probability of value being set to zero + self.max_pool = [False, False] self.train_logs_per_epoch = 4 self.optimizer = types.SimpleNamespace() self.optimizer.name = 'adam' diff --git a/params/test_params.py b/params/test_params.py index a9ea2da2..be0af4e9 100644 --- a/params/test_params.py +++ b/params/test_params.py @@ -36,6 +36,7 @@ def __init__(self): self.num_test_images = 0 self.standardize_data = False self.rescale_data_to_one = False + self.allow_parent_grads = False self.num_epochs = 3 self.train_logs_per_epoch = 1 @@ -96,6 +97,7 @@ def set_params(self): self.layer_channels = [128, 10] self.activation_functions = ['identity'] self.dropout_rate = [0.0] # probability of value being set to zero + self.max_pool = [False] self.optimizer = types.SimpleNamespace() self.optimizer.name = 'adam' self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py index 7fa116f3..96da0b9e 100644 --- a/tests/test_data_processing.py +++ b/tests/test_data_processing.py @@ -20,9 +20,9 @@ def test_reshape_data(self): function call: reshape_data(data, flatten=None, out_shape=None): 24 possible conditions: data: [np.ndarray] data of shape: - n is num_examples, i is num_rows, j is num_cols, k is num_channels, l is num_examples = i*j*k - (l) - single data point of shape l, assumes 1 color channel - (n, l) - n data points, each of shape l (flattened) + n is num_examples, i is num_channels, j is num_rows, k is num_cols + (i) - single data point of shape i + (n, i) - n data points, each of shape i (flattened) (i, j, k) - single datapoint of of shape (i, j, k) (n, i, j, k) - n data points, each of shape (i,j,k) flatten: True, False, None @@ -43,8 +43,8 @@ def test_reshape_data(self): input_array_list = [ np.zeros((num_elements)), # assumed num_examples == 1 np.zeros((num_examples, num_elements)), - np.zeros((num_rows, num_cols, num_channels)), # assumed num_examples == 1 - np.zeros((num_examples, num_rows, num_cols, num_channels))] + np.zeros((num_channels, num_rows, num_cols)), # assumed num_examples == 1 + np.zeros((num_examples, num_channels, num_rows, num_cols))] for input_array in input_array_list: input_shape = input_array.shape input_ndim = input_array.ndim @@ -53,7 +53,7 @@ def test_reshape_data(self): out_shape_list = [ None, (num_elements,), - (num_rows, num_cols, num_channels)] + (num_channels, num_rows, num_cols)] if(num_channels == 1): out_shape_list.append((num_rows, num_cols)) else: @@ -61,7 +61,7 @@ def test_reshape_data(self): out_shape_list = [ None, (num_examples, num_elements), - (num_examples, num_rows, num_cols, num_channels)] + (num_examples, num_channels, num_rows, num_cols)] if(num_channels == 1): out_shape_list.append((num_examples, num_rows, num_cols)) for out_shape in out_shape_list: @@ -82,7 +82,7 @@ def test_reshape_data(self): reshaped_array = reshape_outputs[0].numpy() err_msg += f'\nreshaped_array.shape={reshaped_array.shape}' self.assertEqual(reshape_outputs[1], input_shape, err_msg) # orig_shape - (resh_num_examples, resh_num_rows, resh_num_cols, resh_num_channels) = reshape_outputs[2:] + (resh_num_examples, resh_num_channels, resh_num_rows, resh_num_cols) = reshape_outputs[2:] err_msg += (f'\nfunction_shape_outputs={reshape_outputs[2:]}') if(out_shape is None): if(flatten is None): @@ -104,26 +104,26 @@ def test_reshape_data(self): expected_out_shape, err_msg) self.assertEqual( - resh_num_rows*resh_num_cols*resh_num_channels, + resh_num_channels * resh_num_rows * resh_num_cols, expected_out_shape[1], err_msg) elif(flatten == False): - expected_out_shape = (num_examples, num_rows, num_cols, num_channels) + expected_out_shape = (num_examples, num_channels, num_rows, num_cols) err_msg += f'\nexpected_out_shape={expected_out_shape}' self.assertEqual( reshaped_array.shape, expected_out_shape, err_msg) self.assertEqual( - resh_num_rows, + resh_num_channels, expected_out_shape[1], err_msg) self.assertEqual( - resh_num_cols, + resh_num_rows, expected_out_shape[2], err_msg) self.assertEqual( - resh_num_channels, + resh_num_cols, expected_out_shape[3], err_msg) else: @@ -149,7 +149,7 @@ def test_standardize(self): unflat_shape = [8, 4, 4, 3] flat_shape = [8, 4*4*3] shape_options = [unflat_shape, flat_shape] - eps_options = [1e-6, None] + eps_options = [1e-8, None] samplewise_options = [True, False] for shape in shape_options: for eps_val in eps_options: @@ -262,10 +262,10 @@ def test_patches(self): num_patches = np.int(num_im * (im_edge / patch_edge)**2) rand_seed = 1234 rand_state = np.random.RandomState(rand_seed) - data = np.stack([rand_state.normal(rand_mean, rand_var, size=[im_edge, im_edge, im_chan]) + data = np.stack([rand_state.normal(rand_mean, rand_var, size=[im_chan, im_edge, im_edge]) for _ in range(num_im)]) data_shape = list(data.shape) - patch_shape = [patch_edge, patch_edge, im_chan] + patch_shape = [im_chan, patch_edge, patch_edge] datapoint = torch.tensor(data[0, ...]) datapoint_patches = dp.single_image_to_patches(datapoint, patch_shape) datapoint_recon = dp.patches_to_single_image(datapoint_patches, data_shape[1:]) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 85463134..dc2f1fc1 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -66,5 +66,10 @@ def test_synthetic(self): setattr(params, key, value) assert len(train_loader.dataset) == epoch_size for batch_idx, (data, target) in enumerate(train_loader): - assert data.numpy().shape == (params.batch_size, params.data_edge_size, params.data_edge_size, 1) + expected_size = ( + params.batch_size, + 1, + params.data_edge_size, + params.data_edge_size) + assert data.numpy().shape == expected_size assert batch_idx + 1 == epoch_size // params.batch_size diff --git a/tests/test_models.py b/tests/test_models.py index 426a903c..25aa3842 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -19,6 +19,7 @@ def setUp(self): self.model_list = loaders.get_model_list(self.dsc_dir) self.test_params_file = os.path.join(*[self.dsc_dir, 'params', 'test_params.py']) + ### TODO - add ability to test multiple options (e.g. 'conv' and 'fc') from test params def test_model_loading(self): for model_type in self.model_list: model_type = '_'.join(model_type.split('_')[:-1]) # remove '_model' at the end @@ -29,12 +30,28 @@ def test_model_loading(self): setattr(params, key, value) model.setup(params) - ### TODO - more basic test to compute gradients per model### + + ### TODO - more basic test to compute gradients per model #def test_gradients(self): # for model_type in self.model_list: # model_type = ''.join(model_type.split('_')[:-1]) # remove '_model' at the end # model = loaders.load_model(model_type) + ### TODO - test for gradient blocking + #def test_get_module_encodings(self): + # """ + # Test for gradient blocking in the get_module_encodings function + + # construct test model1 & model2 + # construct test ensemble model = model1 -> model2 + # get encoding & grads for allow_grads={True, False} + # False: compare grads for model1 alone vs model1 in ensemble + # True: ensure that grad is different from model1 alone + # * Should also manually compute grads to compare? + # """ + # # test should utilize run_utils.get_module_encodings() + + def test_lca_ensemble_gradients(self): params = {} models = {} diff --git a/train_model.py b/train_model.py index d4e5d727..e8aed99d 100644 --- a/train_model.py +++ b/train_model.py @@ -35,7 +35,7 @@ # Train model for epoch in range(1, model.params.num_epochs+1): run_utils.train_epoch(epoch, model, train_loader) - if(model.params.model_type.lower() in ['mlp', 'ensemble']): + if(model.params.model_type.lower() in ['mlp', 'ensemble']): # TODO: use to validation set here; test at the end of training run_utils.test_epoch(epoch, model, test_loader) model.log_info(f'Completed epoch {epoch}/{model.params.num_epochs}') print(f'Completed epoch {epoch}/{model.params.num_epochs}') diff --git a/utils/data_processing.py b/utils/data_processing.py index 8c828b57..0679928a 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -8,18 +8,18 @@ def reshape_data(data, flatten=None, out_shape=None): Keyword arguments: data: [tensor] data of shape: - n is num_examples, i is num_rows, j is num_cols, k is num_channels, l is num_examples = i*j*k + n is num_examples, i is num_channels (a.k.a. vector length for fattened inputs), j is num_rows, k is num_cols if out_shape is not specified, it is assumed that i == j - (l) - single data point of shape l, assumes 1 color channel - (n, l) - n data points, each of shape l (flattened) - (i, j, k) - single datapoint of of shape (i,j, k) - (n, i, j, k) - n data points, each of shape (i,j,k) + (i) - single data point of shape i (flattened; assumed j = k = 1) + (n, i) - n data points, each of shape i (flattened; assumed j = k = 1) + (i, j, k) - single datapoint of of shape (i, j, k) + (n, i, j, k) - n data points, each of shape (i, j, k) flatten: [bool or None] specify the shape of the output If out_shape is not None, this arg has no effect If None, do not reshape data, but add num_examples dimension if necessary - If True, return ravelled data of shape (num_examples, num_elements) - If False, return unravelled data of shape (num_examples, sqrt(l), sqrt(l), 1) - where l is the number of elements (dimensionality) of the datapoints + If True, return ravelled data of shape (num_examples, num_elements), where num_elements=i*j*k + If False, return unravelled data of shape (num_examples, 1, sqrt(i), sqrt(i)) + where i is the number of elements (size) of the data points and 1 is the assumed number of channels If data is flat and flatten==True, or !flat and flatten==False, then None condition will apply out_shape: [list or tuple] containing the desired output shape This will overwrite flatten, and return the input reshaped according to out_shape @@ -27,13 +27,13 @@ def reshape_data(data, flatten=None, out_shape=None): Outputs: tuple containing: data: [tensor] data with new shape - (num_examples, num_rows, num_cols, num_channels) if flatten==False - (num_examples, num_elements) if flatten==True + (num_examples, num_channels, num_rows, num_cols) if flatten==False + (num_examples, num_elements) if flatten==True, where num_elements = num_channels*num_rows*num_cols orig_shape: [tuple of int32] original shape of the input data num_examples: [int32] number of data examples or None if out_shape is specified + num_channels: [int32] number of data channels or None if out_shape is specified num_rows: [int32] number of data rows or None if out_shape is specified num_cols: [int32] number of data cols or None if out_shape is specified - num_channels: [int32] number of data channels or None if out_shape is specified """ orig_shape = data.shape orig_ndim = data.ndim @@ -50,7 +50,7 @@ def reshape_data(data, flatten=None, out_shape=None): elif flatten == True: num_rows = num_elements num_cols = 1 - data = torch.reshape(data, (num_examples, num_rows*num_cols*num_channels)) + data = torch.reshape(data, (num_examples, num_channels * num_rows * num_cols)) else: # flatten == False sqrt_num_elements = np.sqrt(num_elements) assert np.floor(sqrt_num_elements) == np.ceil(sqrt_num_elements), ( @@ -59,7 +59,7 @@ def reshape_data(data, flatten=None, out_shape=None): +' and data_shape='+str(orig_shape)) num_rows = int(sqrt_num_elements) num_cols = num_rows - data = torch.reshape(data, (num_examples, num_rows, num_cols, num_channels)) + data = torch.reshape(data, (num_examples, num_channels, num_rows, num_cols)) elif orig_ndim == 2: # already flattened (num_examples, num_elements) = data.shape if flatten is None or flatten == True: # don't reshape data @@ -73,28 +73,28 @@ def reshape_data(data, flatten=None, out_shape=None): num_rows = int(sqrt_num_elements) num_cols = num_rows num_channels = 1 - data = torch.reshape(data, (num_examples, num_rows, num_cols, num_channels)) + data = torch.reshape(data, (num_examples, num_channels, num_rows, num_cols)) else: assert False, ('flatten argument must be True, False, or None') elif orig_ndim == 3: # single data point num_examples = 1 - num_rows, num_cols, num_channels = data.shape + num_channels, num_rows, num_cols = data.shape if flatten == True: - data = torch.reshape(data, (num_examples, num_rows * num_cols * num_channels)) + data = torch.reshape(data, (num_examples, num_channels * num_rows * num_cols)) elif flatten is None or flatten == False: # already not flat - data = data[None, ...] + data = data[None, ...] # add singleton num_examples dimension else: assert False, ('flatten argument must be True, False, or None') - elif orig_ndim == 4: # not flat - num_examples, num_rows, num_cols, num_channels = data.shape + elif orig_ndim == 4: # multiple data points, not flat + num_examples, num_channels, num_rows, num_cols = data.shape if flatten == True: - data = torch.reshape(data, (num_examples, num_rows*num_cols*num_channels)) + data = torch.reshape(data, (num_examples, num_channels * num_rows * num_cols)) else: assert False, ('Data must have 1, 2, 3, or 4 dimensions.') else: num_examples = None; num_rows=None; num_cols=None; num_channels=None data = torch.reshape(data, out_shape) - return (data, orig_shape, num_examples, num_rows, num_cols, num_channels) + return (data, orig_shape, num_examples, num_channels, num_rows, num_cols) def check_all_same_shape(tensor_list): @@ -116,17 +116,17 @@ def check_all_same_shape(tensor_list): def flatten_feature_map(feature_map): """ - Flatten input tensor from [batch, y, x, f] to [batch, y*x*f] + Flatten input tensor from [batch, c, y, x] to [batch, c * y * x] Keyword arguments: - feature_map: tensor with shape [batch, y, x, f] + feature_map: tensor with shape [batch, c, y, x] Returns: - reshaped_map: tensor with shape [batch, y*x*f] + reshaped_map: tensor with shape [batch, c * y * x] """ map_shape = feature_map.shape if(len(map_shape) == 4): - (batch, y, x, f) = map_shape - prev_input_features = int(y * x * f) + (batch, c, y, x) = map_shape + prev_input_features = int(c * y * x) resh_map = torch.reshape(feature_map, [-1, prev_input_features]) elif(len(map_shape) == 2): resh_map = feature_map @@ -185,7 +185,7 @@ def center(data, samplewise=False, batch_size=100): Outputs: data: [tensor] centered data """ - data, orig_shape = reshape_data(data, flatten=True)[:2] # Adds channel dimension if it's missing + data, orig_shape = reshape_data(data, flatten=True)[:2] if(samplewise): # center each input sample individually data_axis = tuple(range(data.ndim)[1:]) data_mean = torch.mean(data, dim=data_axis, keepdim=True) @@ -214,16 +214,16 @@ def standardize(data, eps=None, samplewise=False, batch_size=100): data: [tensor] normalized data """ if(eps is None): - eps = 1.0 / np.sqrt(data[0,...].numel()) - data, orig_shape = reshape_data(data, flatten=True)[:2] # Adds channel dimension if it's missing + eps = 1.0 / data[0,...].numel() + data, orig_shape = reshape_data(data, flatten=True)[:2] num_examples = data.shape[0] if(samplewise): # standardize each input sample individually data_axis = tuple(range(data.ndim)[1:]) data_mean = torch.mean(data, dim=data_axis, keepdim=True) data_true_std = torch.std(data, unbiased=False, dim=data_axis, keepdim=True) else: # standardize the entire population - data_mean = torch.mean(data, dim=0) - data_true_std = torch.std(data, unbiased=False) + data_mean = torch.mean(data, dim=0, keepdim=True) + data_true_std = torch.std(data, dim=0, unbiased=False, keepdim=True) data_std = torch.where(data_true_std >= eps, data_true_std, eps*torch.ones_like(data_true_std)) data = (data - data_mean) / data_std if(data.shape != orig_shape): @@ -357,9 +357,9 @@ def single_image_to_patches(image, patch_shape): Extract patches from a single image Keyword arguments: - image [torch tensor] of shape [im_height, im_width, im_chan] + image [torch tensor] of shape [im_chan, im_height, im_width] patch_shape [tuple or list] containing the output shape - [patch_height, patch_width, patch_chan] + [patch_chan, patch_height, patch_width] patch_chan must be the same as im_chan It is recommended, though not required, that the patch height and width divide evenly into @@ -369,25 +369,25 @@ def single_image_to_patches(image, patch_shape): patches [torch tensor] of patches of shape [num_patches]+list(patch_shape) """ try: - im_height, im_width, im_chan = image.shape - patch_height, patch_width, patch_chan = patch_shape + im_chan, im_height, im_width = image.shape + patch_chan, patch_height, patch_width = patch_shape except Exception as e: raise ValueError( f'This function requires that: ' - +f'1) The input variable "image" must have shape [im_height, im_width, im_chan], and is {image.shape}' - +f'and 2) the input variable "patch_shape" must have shape [patch_height, patch_width, patch_chan], and is {patch_shape}.' + +f'1) The input variable "image" must have shape [im_chan, im_height, im_width], and is {image.shape}' + +f'and 2) the input variable "patch_shape" must have shape [patch_chan, patch_height, patch_width], and is {patch_shape}.' ) from e num_row_patches = np.floor(im_height / patch_height) num_col_patches = np.floor(im_width / patch_width) num_patches = int(num_row_patches * num_col_patches) - patches = torch.zeros((num_patches, patch_height, patch_width, patch_chan)) + patches = torch.zeros((num_patches, patch_chan, patch_height, patch_width)) row_id = 0 col_id = 0 for patch_idx in range(num_patches): row_end = row_id + patch_height col_end = col_id + patch_width try: - patches[patch_idx, ...] = image[row_id:row_end, col_id:col_end, :] + patches[patch_idx, ...] = image[:, row_id:row_end, col_id:col_end] except Exception as e: raise ValueError('This function requires that im_chan equal patch_chan.') from e row_id += patch_height @@ -404,30 +404,30 @@ def patches_to_single_image(patches, image_shape): Convert patches input into a single ouput Keyword arguments: - patches [torch tensor] of shape [num_patches, patch_height, patch_width, patch_chan] - image_shape [list or tuple] of length 2 containing the image shape [im_height, im_width, im_chan] + patches [torch tensor] of shape [num_patches, patch_chan, patch_height, patch_width] + image_shape [list or tuple] of length 2 containing the image shape [im_chan, im_height, im_width] im_chan is assumed to equal patch_chan Outputs: - image [torch tensor] of shape [im_height, im_width, im_chan] + image [torch tensor] of shape [im_chan, im_height, im_width] """ try: - num_patches, patch_height, patch_width, patch_chan = patches.shape - im_height, im_width, im_chan = image_shape + num_patches, patch_chan, patch_height, patch_width = patches.shape + im_chan, im_height, im_width = image_shape except Exception as e: raise ValueError( f'This funciton requires that input patches has shape' - f' [num_patches, patch_height, patch_width, patch_chan] and is {patches.shape}' - f' and input image_shape is a list or tuple of integers of length 3 containing [im_height, im_width, im_chan] and is {image_shape}' + f' [num_patches, patch_chan, patch_height, patch_width] and is {patches.shape}' + f' and input image_shape is a list or tuple of integers of length 3 containing [im_chan, im_height, im_width] and is {image_shape}' ) from e - image = torch.zeros((im_height, im_width, im_chan)) + image = torch.zeros((im_chan, im_height, im_width)) row_id = 0 col_id = 0 for patch_idx in range(num_patches): row_end = row_id + patch_height col_end = col_id + patch_width - image[row_id:row_end, col_id:col_end, :] = patches[patch_idx, ...] + image[:, row_id:row_end, col_id:col_end] = patches[patch_idx, ...] row_id += patch_height if row_id >= im_height: row_id = 0 @@ -441,9 +441,9 @@ def images_to_patches(images, patch_shape): Extract evenly distributed non-overlapping patches from an image dataset Keyword arguments: - images [torch tensor] of shape [num_images, im_height, im_width, im_chan] or [im_height, im_width, im_chan] for a single image + images [torch tensor] of shape [num_images, im_chan, im_height, im_width] or [im_chan, im_height, im_width] for a single image patch_shape [tuple or list] containing the output shape - [patch_height, patch_width, patch_chan] + [patch_chan, patch_height, patch_width] patch_chan must be the same as im_chan It is recommended, though not required, that the patch height and width divide evenly into the image height and width, respectively. @@ -453,8 +453,8 @@ def images_to_patches(images, patch_shape): """ if images.ndim == 3: # single image return single_image_to_patches(images, patch_shape) - num_im, im_height, im_width, im_chan = images.shape - patch_height, patch_width, patch_chan = patch_shape + num_im, im_chan, im_height, im_width = images.shape + patch_chan, patch_height, patch_width = patch_shape num_row_patches = np.floor(im_height / patch_height) num_col_patches = np.floor(im_width / patch_width) num_patches_per_im = int(num_row_patches * num_col_patches) @@ -474,16 +474,16 @@ def patches_to_images(patches, image_shape): Recombine patches tensor into a dataset of images Keyword arguments: - patches [torch tensor] holding square patch data of shape [num_patches, patch_height, patch_width, patch_chan] - image_shape [list or tuple] containing the image dataset shape [im_height, im_width, im_chan] + patches [torch tensor] holding square patch data of shape [num_patches, patch_chan, patch_height, patch_width] + image_shape [list or tuple] containing the image dataset shape [im_chan, im_height, im_width] It is assumed that im_chan equals patch_chan Outputs: images [torch tensor] holding the recombined image dataset """ - tot_num_patches, patch_height, patch_width, patch_chan = patches.shape - im_height, im_width, im_chan = image_shape + tot_num_patches, patch_chan, patch_height, patch_width = patches.shape + im_chan, im_height, im_width = image_shape num_row_patches = np.floor(im_height / patch_height) num_col_patches = np.floor(im_width / patch_width) num_patches_per_im = int(num_row_patches * num_col_patches) diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py index fd05bfc9..64571152 100644 --- a/utils/dataset_utils.py +++ b/utils/dataset_utils.py @@ -54,8 +54,7 @@ def load_dataset(params): if(params.dataset.lower() == 'mnist'): preprocessing_pipeline = [ transforms.ToTensor(), - transforms.Lambda(lambda x: x.permute(1, 2, 0)) # channels last - ] + ] if params.standardize_data: preprocessing_pipeline.append( transforms.Lambda(lambda x: dp.standardize(x, eps=params.eps)[0])) @@ -92,7 +91,6 @@ def load_dataset(params): elif(params.dataset.lower() == 'cifar10'): preprocessing_pipeline = [ transforms.ToTensor(), - transforms.Lambda(lambda x: x.permute(1, 2, 0)), # channels last ] kwargs = { 'root': os.path.join(*[params.data_dir,'cifar10']), @@ -153,9 +151,9 @@ def load_dataset(params): test_loader = None elif(params.dataset.lower() == 'synthetic'): - preprocessing_pipeline = [transforms.ToTensor(), - transforms.Lambda(lambda x: x.permute(1, 2, 0)) # channels last - ] + preprocessing_pipeline = [ + transforms.ToTensor(), + ] train_loader = torch.utils.data.DataLoader( synthetic.SyntheticImages(params.epoch_size, params.data_edge_size, params.dist_type, params.rand_state, params.num_classes, @@ -181,4 +179,5 @@ def load_dataset(params): else: new_params['num_test_images'] = len(test_loader.dataset) new_params['data_shape'] = list(next(iter(train_loader))[0].shape)[1:] + new_params['num_pixels'] = np.prod(new_params['data_shape']) return (train_loader, val_loader, test_loader, new_params) diff --git a/utils/loaders.py b/utils/loaders.py index 21b23af5..8e71480e 100644 --- a/utils/loaders.py +++ b/utils/loaders.py @@ -7,6 +7,7 @@ import DeepSparseCoding.utils.file_utils as file_utils + def get_dir_list(target_dir, target_string): dir_list = [filename.split('.')[0] for filename in os.listdir(target_dir) diff --git a/utils/run_utils.py b/utils/run_utils.py index 958f1a8f..10817f90 100644 --- a/utils/run_utils.py +++ b/utils/run_utils.py @@ -3,6 +3,13 @@ import DeepSparseCoding.utils.data_processing as dp +def get_module_encodings(module, data, allow_grads=False): + if allow_grads: + return module.get_encodings(data) + else: + return module.get_encodings(data).detach() + + def train_single_model(model, loss): model.optimizer.zero_grad() # clear gradietns of all optimized variables loss.backward() # backward pass @@ -20,13 +27,14 @@ def train_epoch(epoch, model, loader): for batch_idx, (data, target) in enumerate(loader): data, target = data.to(model.params.device), target.to(model.params.device) inputs = [] - if(model.params.model_type.lower() == 'ensemble'): # TODO: Move this to train_model + if(model.params.model_type.lower() == 'ensemble'): inputs.append(model[0].preprocess_data(data)) # First model preprocesses the input for submodule_idx, submodule in enumerate(model): loss = model.get_total_loss((inputs[-1], target), submodule_idx) train_single_model(submodule, loss) - # TODO: include optional parameter to allow gradients to propagate through the entire ensemble. - inputs.append(submodule.get_encodings(inputs[-1]).detach()) # must detach to prevent gradient leaking + encodings = get_module_encodings(submodule, inputs[-1], + model.params.allow_parent_grads) + inputs.append(encodings) else: inputs.append(model.preprocess_data(data)) loss = model.get_total_loss((inputs[-1], target)) @@ -78,7 +86,7 @@ def test_epoch(epoch, model, loader, log_to_file=True): test_accuracy = 100. * correct / len(loader.dataset) stat_dict = { 'test_epoch':epoch, - 'test_loss':test_loss, + 'test_loss':test_loss.item(), 'test_correct':correct, 'test_total':len(loader.dataset), 'test_accuracy':test_accuracy} From 75a6a5ec02e72a5ae188f706adbcb9fcbd53d0ea Mon Sep 17 00:00:00 2001 From: Dylan Date: Thu, 11 Feb 2021 12:47:49 +0000 Subject: [PATCH 21/44] integrates conv_lca model & module into lca adds utility to laod parameters from a log file adds preprocessing capability to FastMNIST minor typo fixes --- models/base_model.py | 2 +- models/conv_lca_model.py | 46 --------------- models/ensemble_model.py | 8 +-- modules/conv_lca_module.py | 85 --------------------------- modules/lca_module.py | 110 +++++++++++++++++++++++++++++------ modules/mlp_module.py | 4 +- params/lca_cifar10_params.py | 3 +- params/lca_mnist_params.py | 6 +- params/test_params.py | 25 ++++---- tests/test_foolbox.py | 2 +- tests/test_models.py | 6 +- tests/test_param_loading.py | 2 +- train_model.py | 2 +- utils/data_processing.py | 3 +- utils/dataset_utils.py | 15 ++++- utils/loaders.py | 9 ++- utils/run_utils.py | 6 ++ 17 files changed, 152 insertions(+), 182 deletions(-) delete mode 100644 models/conv_lca_model.py delete mode 100644 modules/conv_lca_module.py diff --git a/models/base_model.py b/models/base_model.py index 4bef1ca0..dd0bc8d6 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -132,7 +132,7 @@ def write_checkpoint(self, batch_step=None): self.log_info('Full model saved in file %s'%self.params.cp_latest_filename) def get_checkpoint_from_log(self, logfile): - model_params = loaders.load_params(logfile) + model_params = loaders.load_params_from_log(logfile) checkpoint = torch.load(model_params.cp_latest_filename) return checkpoint diff --git a/models/conv_lca_model.py b/models/conv_lca_model.py deleted file mode 100644 index 36f54086..00000000 --- a/models/conv_lca_model.py +++ /dev/null @@ -1,46 +0,0 @@ -import numpy as np -import torch - -from DeepSparseCoding.models.base_model import BaseModel -from DeepSparseCoding.modules.conv_lca_module import ConvLcaModule -import DeepSparseCoding.modules.losses as losses - - -class ConvLcaModel(BaseModel, ConvLcaModule): - def setup(self, params, logger=None): - super(ConvLcaModel, self).setup(params, logger) - self.setup_module(params) - self.setup_optimizer() - if params.checkpoint_boot_log != '': - checkpoint = self.get_checkpoint_from_log(params.checkpoint_boot_log) - self.module.load_state_dict(checkpoint['model_state_dict']) - self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - - def get_total_loss(self, input_tuple): - input_tensor, input_labels = input_tuple - latents = self.get_encodings(input_tensor) - recon = self.get_recon_from_latents(latents) - recon_loss = losses.half_squared_l2(input_tensor, recon) - sparse_loss = self.params.sparse_mult * losses.l1_norm(latents) - total_loss = recon_loss + sparse_loss - return total_loss - - def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None): - if update_dict is None: - update_dict = super(ConvLcaModel, self).generate_update_dict(input_data, input_labels, batch_step) - latents = self.get_encodings(input_data) - recon = self.get_recon_from_latents(latents) - recon_loss = losses.half_squared_l2(input_data, recon).item() - sparse_loss = self.params.sparse_mult * losses.l1_norm(latents).item() - stat_dict['weight_lr'] = self.scheduler.get_lr()[0] - stat_dict['loss_recon'] = recon_loss - stat_dict['loss_sparse'] = sparse_loss - stat_dict['loss_total'] = recon_loss + sparse_loss - stat_dict['input_max_mean_min'] = [ - input_data.max().item(), input_data.mean().item(), input_data.min().item()] - stat_dict['recon_max_mean_min'] = [ - recon.max().item(), recon.mean().item(), recon.min().item()] - latent_nnz = torch.sum(latents != 0).item() # TODO: github issue 23907 requests torch.count_nonzero - stat_dict['latents_fraction_active'] = latent_nnz / latents.numel() - update_dict.update(stat_dict) - return update_dict diff --git a/models/ensemble_model.py b/models/ensemble_model.py index 9e296f60..196532fe 100644 --- a/models/ensemble_model.py +++ b/models/ensemble_model.py @@ -24,11 +24,11 @@ def setup_module(self, params): subparams.data_shape = params.data_shape super(EnsembleModel, self).setup_ensemble_module(params) self.submodel_classes = [] - for ensemble_index, submodel_params in enumerate(self.params.ensemble_params): - submodule_class = loaders.load_model_class(submodel_params.model_type) + for ensemble_index, subparams in enumerate(self.params.ensemble_params): + submodule_class = loaders.load_model_class(subparams.model_type) self.submodel_classes.append(submodule_class) - if submodel_params.checkpoint_boot_log != '': - checkpoint = self.get_checkpoint_from_log(submodule_params.checkpoint_boot_log) + if subparams.checkpoint_boot_log != '': + checkpoint = self.get_checkpoint_from_log(subparams.checkpoint_boot_log) submodule = self.__getitem__(ensemble_index) submodule.load_state_dict(checkpoint['model_state_dict']) diff --git a/modules/conv_lca_module.py b/modules/conv_lca_module.py deleted file mode 100644 index 84adba0b..00000000 --- a/modules/conv_lca_module.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from DeepSparseCoding.modules.lca_module import LcaModule -import DeepSparseCoding.utils.data_processing as dp - - -class ConvLcaModule(LcaModule): - """ - Parameters - ----------------------------- - data_shape [list of int] by default it is set to [h, w, c], however pytorch conv wants [c, h, w] so it is permuted in this module - Assumes h = w (i.e. square inputs) - in_channels [int] - Number of channels in the input image - Automatically set to params.num_pixels - out_channels [int] - Number of channels produced by the convolution - Automatically set to params.num_latent - kernel_size [int] - Edge size of the square convolving kernel - stride [int] - Vertical and horizontal stride of the convolution. - padding [int] - Zero-padding added to both sides of the input. - """ - def setup_module(self, params): - self.params = params - self.params.data_shape = [self.params.data_shape[2], self.params.data_shape[0], self.params.data_shape[1]] - self.input_shape = [self.params.batch_size] + self.params.data_shape - assert (self.input_shape[-1] % self.params.stride == 0), ( - f'Stride = {self.params.stride} must divide evenly into input edge size = {self.input_shape[-1]}') - self.w_shape = [ - self.params.out_channels, - self.params.in_channels, - self.params.kernel_size, - self.params.kernel_size - ] - dilation = 1 - conv_hout = int(1 + (self.input_shape[2] + 2 * self.params.padding - dilation * (self.params.kernel_size - 1) - 1) / self.params.stride) - conv_wout = conv_hout # Assumes square images - self.output_shape = [self.params.batch_size, self.params.out_channels, conv_hout, conv_wout] - w_init = torch.randn(self.w_shape) - w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps) - self.w = nn.Parameter(w_init_normed, requires_grad=True) - - def preprocess_data(self, input_tensor): - return input_tensor.permute(0, 3, 1, 2) - - def get_recon_from_latents(self, a_in): - recon = F.conv_transpose2d( - input=a_in, - weight=self.w, - bias=None, - stride=self.params.stride, - padding=self.params.padding - ) - return recon - - def step_inference(self, input_tensor, u_in, a_in, step): - recon = self.get_recon_from_latents(a_in) - recon_error = input_tensor - recon - error_injection = F.conv2d( - input=recon_error, - weight=self.w, - bias=None, - stride=self.params.stride, - padding=self.params.padding - ) - du = error_injection + a_in - u_in - u_out = u_in + self.params.step_size * du - return u_out - - def infer_coefficients(self, input_tensor): - u_list = [torch.zeros(self.output_shape, device=self.params.device)] - a_list = [self.threshold_units(u_list[0])] - for step in range(self.params.num_steps-1): - u = self.step_inference(input_tensor, u_list[step], a_list[step], step) - u_list.append(u) - a_list.append(self.threshold_units(u)) - return (u_list, a_list) - - def get_encodings(self, input_tensor): - u_list, a_list = self.infer_coefficients(input_tensor) - return a_list[-1] - - def forward(self, input_tensor): - latents = self.get_encodings(input_tensor) - return latents diff --git a/modules/lca_module.py b/modules/lca_module.py index bd397c7c..aee8d2be 100644 --- a/modules/lca_module.py +++ b/modules/lca_module.py @@ -3,54 +3,128 @@ import torch.nn.functional as F from DeepSparseCoding.modules.activations import lca_threshold +from DeepSparseCoding.utils.run_utils import compute_conv_output_shape import DeepSparseCoding.utils.data_processing as dp class LcaModule(nn.Module): + """ + Keyword arguments: + params: [dict] with keys: + data_shape [list of int] of shape [elements, channels, height, width]; Assumes h = w (i.e. square inputs) + The remaining keys are only used layer_type is "conv": + kernel_size: [int] edge size of the square convolving kernel + stride: [int] vertical and horizontal stride of the convolution + padding: [int] zero-padding added to both sides of the input + """ def setup_module(self, params): self.params = params - w_init = torch.randn([self.params.num_pixels, self.params.num_latent]) + if self.params.layer_type == 'fc': + self.layer_output_shapes = [[self.params.num_latent]] + self.w_shape = [self.params.num_pixels, self.params.num_latent] + else: + self.layer_output_shapes = [self.params.data_shape] # [channels, height, width] + assert (self.params.data_shape[-1] % self.params.stride == 0), ( + f'Stride = {self.params.stride} must divide evenly into input edge size = {self.params.data_shape[-1]}') + self.w_shape = [ + self.params.num_latent, + self.params.data_shape[0], # channels = 1 + self.params.kernel_size, + self.params.kernel_size + ] + output_height = compute_conv_output_shape( + self.layer_output_shapes[-1][1], + self.params.kernel_size, + self.params.stride, + self.params.padding, + dilation=1) + output_width = compute_conv_output_shape( + self.layer_output_shapes[-1][2], + self.params.kernel_size, + self.params.stride, + self.params.padding, + dilation=1) + self.layer_output_shapes.append([self.params.num_latent, output_height, output_width]) + w_init = torch.randn(self.w_shape) w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps) self.w = nn.Parameter(w_init_normed, requires_grad=True) def preprocess_data(self, input_tensor): - input_tensor = input_tensor.view(-1, self.params.num_pixels) + if self.params.layer_type == 'fc': + input_tensor = input_tensor.view(self.params.batch_size, -1) return input_tensor - def compute_excitatory_current(self, input_tensor): - return torch.matmul(input_tensor, self.w) + def compute_excitatory_current(self, input_tensor, a_in): + if self.params.layer_type == 'fc': + excitatory_current = torch.matmul(input_tensor, self.w) + else: + recon = self.get_recon_from_latents(a_in) + recon_error = input_tensor - recon + error_injection = F.conv2d( + input=recon_error, + weight=self.w, + bias=None, + stride=self.params.stride, + padding=self.params.padding + ) + excitatory_current = error_injection + a_in + return excitatory_current def compute_inhibitory_connectivity(self): - lca_g = torch.matmul(torch.transpose(self.w, dim0=0, dim1=1), - self.w) - torch.eye(self.params.num_latent, - requires_grad=True, device=self.params.device) - return lca_g + if self.params.layer_type == 'fc': + inhibitory_connectivity = torch.matmul(torch.transpose(self.w, dim0=0, dim1=1), + self.w) - torch.eye(self.params.num_latent, + requires_grad=True, device=self.params.device) + else: + inhibitory_connectivity = 0 # TODO: return Grammian along channel dim for a single kernel location + return inhibitory_connectivity def threshold_units(self, u_in): a_out = lca_threshold(u_in, self.params.thresh_type, self.params.rectify_a, self.params.sparse_mult) return a_out - def step_inference(self, u_in, a_in, b, g, step): - lca_explain_away = torch.matmul(a_in, g) - du = b - lca_explain_away - u_in + def step_inference(self, u_in, a_in, excitatory_current, inhibitory_connectivity, step): + if self.params.layer_type == 'fc': + lca_explain_away = torch.matmul(a_in, inhibitory_connectivity) + else: + lca_explain_away = 0 # already computed in excitatory_current + du = excitatory_current - lca_explain_away - u_in u_out = u_in + self.params.step_size * du return u_out, lca_explain_away def infer_coefficients(self, input_tensor): - lca_b = self.compute_excitatory_current(input_tensor) - lca_g = self.compute_inhibitory_connectivity() - u_list = [torch.zeros([input_tensor.shape[0], self.params.num_latent], - device=self.params.device)] + output_shape = [input_tensor.shape[0]] + self.layer_output_shapes[-1] + u_list = [torch.zeros(output_shape, device=self.params.device)] a_list = [self.threshold_units(u_list[0])] + excitatory_current = self.compute_excitatory_current(input_tensor, a_list[-1]) + inhibitory_connectivity = self.compute_inhibitory_connectivity() for step in range(self.params.num_steps-1): - u = self.step_inference(u_list[step], a_list[step], lca_b, lca_g, step)[0] + u = self.step_inference( + u_list[step], + a_list[step], + excitatory_current, + inhibitory_connectivity, + step + )[0] u_list.append(u) a_list.append(self.threshold_units(u)) + if self.params.layer_type == 'conv': + excitatory_current = self.compute_excitatory_current(input_tensor, a_list[-1]) return (u_list, a_list) - def get_recon_from_latents(self, latents): - return torch.matmul(latents, torch.transpose(self.w, dim0=0, dim1=1)) + def get_recon_from_latents(self, a_in): + if self.params.layer_type == 'fc': + recon = torch.matmul(a_in, torch.transpose(self.w, dim0=0, dim1=1)) + else: + recon = F.conv_transpose2d( + input=a_in, + weight=self.w, + bias=None, + stride=self.params.stride, + padding=self.params.padding + ) + return recon def get_encodings(self, input_tensor): u_list, a_list = self.infer_coefficients(input_tensor) diff --git a/modules/mlp_module.py b/modules/mlp_module.py index 1847b483..3ad59b77 100644 --- a/modules/mlp_module.py +++ b/modules/mlp_module.py @@ -3,14 +3,12 @@ import torch.nn.functional as F from DeepSparseCoding.modules.activations import activation_picker +from DeepSparseCoding.utils.run_utils import compute_conv_output_shape import DeepSparseCoding.utils.data_processing as dp class MlpModule(nn.Module): def setup_module(self, params): - def compute_conv_output_shape(in_length, kernel_size, stride, padding=0, dilation=1): - out_shape = ((in_length + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1 - return np.floor(out_shape).astype(np.int) self.params = params self.act_funcs = [activation_picker(act_func_str) for act_func_str in self.params.activation_functions] diff --git a/params/lca_cifar10_params.py b/params/lca_cifar10_params.py index 63e24d47..f25edce2 100644 --- a/params/lca_cifar10_params.py +++ b/params/lca_cifar10_params.py @@ -6,10 +6,11 @@ class params(BaseParams): def set_params(self): super(params, self).set_params() - self.model_type = 'conv_lca' + self.model_type = 'lca' self.model_name = 'conv_lca_cifar10' self.version = '0' self.dataset = 'cifar10' + self.layer_type = 'conv' self.num_validation = 10000 self.standardize_data = True self.rescale_data_to_one = False diff --git a/params/lca_mnist_params.py b/params/lca_mnist_params.py index 9d575e4b..9c46cc1b 100644 --- a/params/lca_mnist_params.py +++ b/params/lca_mnist_params.py @@ -3,12 +3,13 @@ from DeepSparseCoding.params.base_params import BaseParams -CONV = False +CONV = True class params(BaseParams): def set_params(self): super(params, self).set_params() + self.model_type = 'lca' self.version = '0' self.dataset = 'mnist' self.fast_mnist = True @@ -27,7 +28,7 @@ def set_params(self): self.weight_decay = 0.0 self.train_logs_per_epoch = 6 if CONV: - self.model_type = 'conv_lca' + self.layer_type = 'conv' self.model_name = 'conv_lca_mnist' self.rescale_data_to_one = True self.batch_size = 50 @@ -39,6 +40,7 @@ def set_params(self): self.padding = 0 self.num_latent = 128 else: + self.layer_type = 'fc' self.model_type = 'lca' self.model_name = 'lca_768_mnist' self.rescale_data_to_one = False diff --git a/params/test_params.py b/params/test_params.py index be0af4e9..529c1522 100644 --- a/params/test_params.py +++ b/params/test_params.py @@ -56,6 +56,7 @@ def set_params(self): self.model_type = 'lca' self.weight_decay = 0.0 self.weight_lr = 0.1 + self.layer_type = 'fc' self.optimizer = types.SimpleNamespace() self.optimizer.name = 'sgd' self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs @@ -72,17 +73,19 @@ def set_params(self): for frac in self.optimizer.lr_annealing_milestone_frac] self.step_size = self.dt / self.tau -class conv_lca_params(lca_params): - def set_params(self): - super(conv_lca_params, self).set_params() - self.kernel_size = 8 - self.stride = 2 - self.padding = 0 - self.optimizer.milestones = [frac * self.num_epochs - for frac in self.optimizer.lr_annealing_milestone_frac] - self.step_size = self.dt / self.tau - self.out_channels = self.num_latent - self.in_channels = 1 +# TODO: Add ability to test multiple param values +#class conv_lca_params(lca_params): +# def set_params(self): +# super(conv_lca_params, self).set_params() +# self.layer_type = 'conv' +# self.kernel_size = 8 +# self.stride = 2 +# self.padding = 0 +# self.optimizer.milestones = [frac * self.num_epochs +# for frac in self.optimizer.lr_annealing_milestone_frac] +# self.step_size = self.dt / self.tau +# self.out_channels = self.num_latent +# self.in_channels = 1 class mlp_params(BaseParams): diff --git a/tests/test_foolbox.py b/tests/test_foolbox.py index b7f9cc53..3faf224d 100644 --- a/tests/test_foolbox.py +++ b/tests/test_foolbox.py @@ -29,7 +29,7 @@ # 'steps':3}} # max perturbation it can reach is 0.5 # attack = fa.LinfPGD(**attack_params['linfPGD']) # epsilons = [0.3] # allowed perturbation size -# params['ensemble'] = loaders.load_params(self.test_params_file, key='ensemble_params') +# params['ensemble'] = loaders.load_params_file(self.test_params_file, key='ensemble_params') # params['ensemble'].train_logs_per_epoch = None # params['ensemble'].shuffle_data = False # train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params['ensemble']) diff --git a/tests/test_models.py b/tests/test_models.py index 25aa3842..67b80732 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -24,7 +24,7 @@ def test_model_loading(self): for model_type in self.model_list: model_type = '_'.join(model_type.split('_')[:-1]) # remove '_model' at the end model = loaders.load_model(model_type) - params = loaders.load_params(self.test_params_file, key=model_type+'_params') + params = loaders.load_params_file(self.test_params_file, key=model_type+'_params') train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params) for key, value in data_params.items(): setattr(params, key, value) @@ -55,7 +55,7 @@ def test_model_loading(self): def test_lca_ensemble_gradients(self): params = {} models = {} - params['lca'] = loaders.load_params(self.test_params_file, key='lca_params') + params['lca'] = loaders.load_params_file(self.test_params_file, key='lca_params') params['lca'].train_logs_per_epoch = None params['lca'].shuffle_data = False train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params['lca']) @@ -64,7 +64,7 @@ def test_lca_ensemble_gradients(self): models['lca'] = loaders.load_model(params['lca'].model_type) models['lca'].setup(params['lca']) models['lca'].to(params['lca'].device) - params['ensemble'] = loaders.load_params(self.test_params_file, key='ensemble_params') + params['ensemble'] = loaders.load_params_file(self.test_params_file, key='ensemble_params') for key, value in data_params.items(): setattr(params['ensemble'], key, value) err_msg = f'\ndata_shape={params["ensemble"].data_shape}' diff --git a/tests/test_param_loading.py b/tests/test_param_loading.py index 26377ad2..7d488a73 100644 --- a/tests/test_param_loading.py +++ b/tests/test_param_loading.py @@ -13,4 +13,4 @@ def test_param_loading(): for params_name in params_list: if 'test_' not in params_name: params_file = os.path.join(*[dsc_dir, 'params', params_name+'.py']) - params = loaders.load_params(params_file, key='params') + params = loaders.load_params_file(params_file, key='params') diff --git a/train_model.py b/train_model.py index e8aed99d..b2c8f2f4 100644 --- a/train_model.py +++ b/train_model.py @@ -20,7 +20,7 @@ t0 = ti.time() # Load params -params = loaders.load_params(param_file) +params = loaders.load_params_file(param_file) # Load data train_loader, val_loader, test_loader, data_stats = dataset_utils.load_dataset(params) diff --git a/utils/data_processing.py b/utils/data_processing.py index 0679928a..ef2d7437 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -110,8 +110,7 @@ def check_all_same_shape(tensor_list): for index, tensor in enumerate(tensor_list): if tensor.shape != first_shape: raise ValueError( - 'Tensor entry %g in input list has shape %g, but should have shape %g'%( - index, tensor.shape, first_shape)) + f'Tensor entry {index} in input list has shape {tensor.shape}, but should have shape {first_shape}') def flatten_feature_map(feature_map): diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py index 64571152..ef6b38a0 100644 --- a/utils/dataset_utils.py +++ b/utils/dataset_utils.py @@ -5,6 +5,7 @@ ROOT_DIR = up(up(up(os.path.realpath(__file__)))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) +from PIL import Image import numpy as np import torch from torchvision import transforms @@ -24,6 +25,16 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Scale data to [0,1] self.data = self.data.unsqueeze(-1).float().div(255) + self.data = self.data.permute(0, 3, 1, 2) # channels first + if self.transform is not None: + # doing this so that it is consistent with all other datasets + # to return a PIL Image + for data_idx in range(self.data.shape[0]): + self.data[data_idx, ...] = self.transform( + Image.fromarray( + self.data[data_idx, ...].numpy().squeeze(), mode='L'))[None, ...] + if self.target_transform is not None: + self.targets = [self.target_transform(int(target)) for target in self.targets] # Put both data and targets on GPU in advance self.data, self.targets = self.data.to(device), self.targets.to(device) @@ -66,7 +77,7 @@ def load_dataset(params): 'download':False, 'transform':transforms.Compose(preprocessing_pipeline) } - if hasattr(params, 'fast_mnist') and params.fast_mnist: + if(hasattr(params, 'fast_mnist') and params.fast_mnist): kwargs['device'] = params.device kwargs['train'] = True train_loader = torch.utils.data.DataLoader( @@ -93,7 +104,7 @@ def load_dataset(params): transforms.ToTensor(), ] kwargs = { - 'root': os.path.join(*[params.data_dir,'cifar10']), + 'root': os.path.join(*[params.data_dir, 'cifar10']), 'download': False, 'train': True, 'transform': transforms.Compose(preprocessing_pipeline) diff --git a/utils/loaders.py b/utils/loaders.py index 8e71480e..03063f64 100644 --- a/utils/loaders.py +++ b/utils/loaders.py @@ -86,7 +86,14 @@ def load_module(module_type): return py_module_class() -def load_params(file_name, key='params'): +def load_params_from_log(log_file): + logger = file_utils.Logger(log_file, overwrite=False) + log_text = logger.load_file() + params = logger.read_params(log_text)[-1] + return params + + +def load_params_file(file_name, key='params'): params_module = file_utils.python_module_from_file(key, file_name) params = getattr(params_module, key)() return params diff --git a/utils/run_utils.py b/utils/run_utils.py index 10817f90..72d58a98 100644 --- a/utils/run_utils.py +++ b/utils/run_utils.py @@ -1,8 +1,14 @@ +import numpy as np import torch import DeepSparseCoding.utils.data_processing as dp +def compute_conv_output_shape(in_length, kernel_size, stride, padding=0, dilation=1): + out_shape = ((in_length + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1 + return np.floor(out_shape).astype(np.int) + + def get_module_encodings(module, data, allow_grads=False): if allow_grads: return module.get_encodings(data) From 17d9295ae0d4bbb1f7f395b26d099aabd36d2243 Mon Sep 17 00:00:00 2001 From: Dylan Date: Fri, 12 Feb 2021 10:00:56 +0100 Subject: [PATCH 22/44] simplifies conv -> fc logic --- modules/ensemble_module.py | 3 +++ modules/lca_module.py | 16 ++++++++-------- modules/mlp_module.py | 15 +++++++-------- params/lca_cifar10_params.py | 10 +++++----- params/lca_mlp_cifar10_params.py | 32 ++++++++++++++++++++++++++------ params/lca_mnist_params.py | 4 ++-- utils/data_processing.py | 2 +- 7 files changed, 52 insertions(+), 30 deletions(-) diff --git a/modules/ensemble_module.py b/modules/ensemble_module.py index 9e7b7932..ee8df6db 100644 --- a/modules/ensemble_module.py +++ b/modules/ensemble_module.py @@ -1,6 +1,7 @@ import torch.nn as nn import DeepSparseCoding.utils.loaders as loaders +from DeepSparseCoding.utils.data_processing import flatten_feature_map class EnsembleModule(nn.Sequential): @@ -17,6 +18,8 @@ def setup_ensemble_module(self, params): def forward(self, x): self.layer_list = [x] for module in self: + if module.params.layer_types[0] == 'fc': + self.layer_list[-1] = flatten_feature_map(self.layer_list[-1]) self.layer_list.append(module.get_encodings(self.layer_list[-1])) # latent encodings return self.layer_list[-1] diff --git a/modules/lca_module.py b/modules/lca_module.py index aee8d2be..058e6277 100644 --- a/modules/lca_module.py +++ b/modules/lca_module.py @@ -12,14 +12,14 @@ class LcaModule(nn.Module): Keyword arguments: params: [dict] with keys: data_shape [list of int] of shape [elements, channels, height, width]; Assumes h = w (i.e. square inputs) - The remaining keys are only used layer_type is "conv": + The remaining keys are only used layer_types[0] is "conv": kernel_size: [int] edge size of the square convolving kernel stride: [int] vertical and horizontal stride of the convolution padding: [int] zero-padding added to both sides of the input """ def setup_module(self, params): self.params = params - if self.params.layer_type == 'fc': + if self.params.layer_types[0] == 'fc': self.layer_output_shapes = [[self.params.num_latent]] self.w_shape = [self.params.num_pixels, self.params.num_latent] else: @@ -50,12 +50,12 @@ def setup_module(self, params): self.w = nn.Parameter(w_init_normed, requires_grad=True) def preprocess_data(self, input_tensor): - if self.params.layer_type == 'fc': + if self.params.layer_types[0] == 'fc': input_tensor = input_tensor.view(self.params.batch_size, -1) return input_tensor def compute_excitatory_current(self, input_tensor, a_in): - if self.params.layer_type == 'fc': + if self.params.layer_types[0] == 'fc': excitatory_current = torch.matmul(input_tensor, self.w) else: recon = self.get_recon_from_latents(a_in) @@ -71,7 +71,7 @@ def compute_excitatory_current(self, input_tensor, a_in): return excitatory_current def compute_inhibitory_connectivity(self): - if self.params.layer_type == 'fc': + if self.params.layer_types[0] == 'fc': inhibitory_connectivity = torch.matmul(torch.transpose(self.w, dim0=0, dim1=1), self.w) - torch.eye(self.params.num_latent, requires_grad=True, device=self.params.device) @@ -85,7 +85,7 @@ def threshold_units(self, u_in): return a_out def step_inference(self, u_in, a_in, excitatory_current, inhibitory_connectivity, step): - if self.params.layer_type == 'fc': + if self.params.layer_types[0] == 'fc': lca_explain_away = torch.matmul(a_in, inhibitory_connectivity) else: lca_explain_away = 0 # already computed in excitatory_current @@ -109,12 +109,12 @@ def infer_coefficients(self, input_tensor): )[0] u_list.append(u) a_list.append(self.threshold_units(u)) - if self.params.layer_type == 'conv': + if self.params.layer_types[0] == 'conv': excitatory_current = self.compute_excitatory_current(input_tensor, a_list[-1]) return (u_list, a_list) def get_recon_from_latents(self, a_in): - if self.params.layer_type == 'fc': + if self.params.layer_types[0] == 'fc': recon = torch.matmul(a_in, torch.transpose(self.w, dim0=0, dim1=1)) else: recon = F.conv_transpose2d( diff --git a/modules/mlp_module.py b/modules/mlp_module.py index 3ad59b77..727334d2 100644 --- a/modules/mlp_module.py +++ b/modules/mlp_module.py @@ -4,7 +4,7 @@ from DeepSparseCoding.modules.activations import activation_picker from DeepSparseCoding.utils.run_utils import compute_conv_output_shape -import DeepSparseCoding.utils.data_processing as dp +from DeepSparseCoding.utils.data_processing import flatten_feature_map class MlpModule(nn.Module): @@ -84,16 +84,15 @@ def setup_module(self, params): def preprocess_data(self, input_tensor): if self.params.layer_types[0] == 'fc': - input_tensor = input_tensor.view(self.params.batch_size, -1) # flatten input + input_tensor = flatten_feature_map(input_tensor) return input_tensor def forward(self, x): - layer_zip = zip(self.dropout, self.pooling, self.act_funcs, self.layers) - for layer_index, (dropout, pooling, act_func, layer) in enumerate(layer_zip): - prev_layer = self.params.layer_types[layer_index - 1] - current_layer = self.params.layer_types[layer_index] - if(layer_index > 0 and current_layer == 'fc' and prev_layer == 'conv'): - x = dp.flatten_feature_map(x) + layer_zip = zip(self.dropout, self.pooling, self.act_funcs, + self.layers, self.params.layer_types) + for dropout, pooling, act_func, layer, layer_type in layer_zip: + if layer_type == 'fc': + x = flatten_feature_map(x) x = dropout(pooling(act_func(layer(x)))) return x diff --git a/params/lca_cifar10_params.py b/params/lca_cifar10_params.py index f25edce2..545f34bf 100644 --- a/params/lca_cifar10_params.py +++ b/params/lca_cifar10_params.py @@ -10,21 +10,23 @@ def set_params(self): self.model_name = 'conv_lca_cifar10' self.version = '0' self.dataset = 'cifar10' - self.layer_type = 'conv' + self.layer_types = ['conv'] self.num_validation = 10000 self.standardize_data = True self.rescale_data_to_one = False self.center_dataset = False self.batch_size = 25 self.num_epochs = 500 + self.train_logs_per_epoch = 6 + self.renormalize_weights = True + self.stride = 2 + self.padding = 0 self.weight_decay = 0.0 self.weight_lr = 0.001 - self.train_logs_per_epoch = 6 self.optimizer = types.SimpleNamespace() self.optimizer.name = 'sgd' self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.8 - self.renormalize_weights = True self.dt = 0.001 self.tau = 0.2 self.num_steps = 75 @@ -32,8 +34,6 @@ def set_params(self): self.thresh_type = 'hard' self.sparse_mult = 0.30 self.kernel_size = 8 - self.stride = 2 - self.padding = 0 self.num_latent = 512 self.compute_helper_params() diff --git a/params/lca_mlp_cifar10_params.py b/params/lca_mlp_cifar10_params.py index 66ed1a32..9e539360 100644 --- a/params/lca_mlp_cifar10_params.py +++ b/params/lca_mlp_cifar10_params.py @@ -6,19 +6,20 @@ from DeepSparseCoding.params.base_params import BaseParams from DeepSparseCoding.params.lca_mnist_params import params as LcaParams from DeepSparseCoding.params.mlp_mnist_params import params as MlpParams +from DeepSparseCoding.utils.run_utils import compute_conv_output_shape class shared_params(object): def __init__(self): self.model_type = 'ensemble' self.model_name = 'lca_mlp_cifar10' - self.version = '0' + self.version = '1' self.dataset = 'cifar10' self.standardize_data = True self.batch_size = 25 self.num_epochs = 500 self.train_logs_per_epoch = 4 - self.allow_parent_grads = False + self.allow_parent_grads = True class lca_params(LcaParams): @@ -27,13 +28,16 @@ def set_params(self): for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'lca' + self.layer_types = ['conv'] self.weight_decay = 0.0 self.weight_lr = 0.001 + self.renormalize_weights = True + self.stride = 2 + self.padding = 0 self.optimizer = types.SimpleNamespace() self.optimizer.name = 'sgd' self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.8 - self.renormalize_weights = True self.dt = 0.001 self.tau = 0.2 self.num_steps = 75 @@ -41,7 +45,7 @@ def set_params(self): self.thresh_type = 'hard' self.sparse_mult = 0.30 self.num_latent = 512 - self.checkpoint_boot_log = '' + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/conv_lca_cifar10/logfiles/conv_lca_cifar10_v1.log' self.compute_helper_params() @@ -54,7 +58,7 @@ def set_params(self): self.weight_lr = 2e-3 self.weight_decay = 1e-6 self.layer_types = ['fc'] - self.layer_channels = [512, 10] + self.layer_channels = [None, 10] self.activation_functions = ['identity'] self.dropout_rate = [0.0] # probability of value being set to zero self.optimizer = types.SimpleNamespace() @@ -67,6 +71,22 @@ def set_params(self): class params(BaseParams): def set_params(self): super(params, self).set_params() - self.ensemble_params = [lca_params(), mlp_params()] + lca_params_inst = lca_params() + mlp_params_inst = mlp_params() + lca_output_height = compute_conv_output_shape( + 32, # TODO: infer this? currently hardcoded CIFAR10 size + lca_params_inst.kernel_size, + lca_params_inst.stride, + lca_params_inst.padding, + dilation=1) + lca_output_width = compute_conv_output_shape( + 32, + lca_params_inst.kernel_size, + lca_params_inst.stride, + lca_params_inst.padding, + dilation=1) + lca_output_shape = [lca_params_inst.num_latent, lca_output_height, lca_output_width] + mlp_params_inst.layer_channels[0] = np.prod(lca_output_shape) + self.ensemble_params = [lca_params_inst, mlp_params_inst] for key, value in shared_params().__dict__.items(): setattr(self, key, value) diff --git a/params/lca_mnist_params.py b/params/lca_mnist_params.py index 9c46cc1b..e3a7ebc2 100644 --- a/params/lca_mnist_params.py +++ b/params/lca_mnist_params.py @@ -28,7 +28,7 @@ def set_params(self): self.weight_decay = 0.0 self.train_logs_per_epoch = 6 if CONV: - self.layer_type = 'conv' + self.layer_types = ['conv'] self.model_name = 'conv_lca_mnist' self.rescale_data_to_one = True self.batch_size = 50 @@ -40,7 +40,7 @@ def set_params(self): self.padding = 0 self.num_latent = 128 else: - self.layer_type = 'fc' + self.layer_types = ['fc'] self.model_type = 'lca' self.model_name = 'lca_768_mnist' self.rescale_data_to_one = False diff --git a/utils/data_processing.py b/utils/data_processing.py index ef2d7437..0afdcc8f 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -126,7 +126,7 @@ def flatten_feature_map(feature_map): if(len(map_shape) == 4): (batch, c, y, x) = map_shape prev_input_features = int(c * y * x) - resh_map = torch.reshape(feature_map, [-1, prev_input_features]) + resh_map = torch.reshape(feature_map, [batch, prev_input_features]) elif(len(map_shape) == 2): resh_map = feature_map else: From aef8b9f06c7b2b14ef4735d2db22290ca1943356 Mon Sep 17 00:00:00 2001 From: Dylan Date: Mon, 22 Feb 2021 08:36:52 +0000 Subject: [PATCH 23/44] bugfix in optimizer checkpoint loading for ensemble models --- models/base_model.py | 8 +++++++- models/ensemble_model.py | 10 ++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/models/base_model.py b/models/base_model.py index dd0bc8d6..36d89827 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -124,8 +124,14 @@ def write_checkpoint(self, batch_step=None): """ output_dict = { 'model_state_dict': self.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict(), } + if(self.params.model_type.lower() == 'ensemble'): + for module in self: + module_state_dict_name = module.params.submodule_name+'_optimizer_state_dict' + output_dict[module_state_dict_name] = module.optimizer.state_dict(), + else: + module_state_dict_name = 'optimizer_state_dict' + output_dict[module_state_dict_name] = self.optimizer.state_dict(), training_stats = self.get_train_stats(batch_step) output_dict.update(training_stats) torch.save(output_dict, self.params.cp_latest_filename) diff --git a/models/ensemble_model.py b/models/ensemble_model.py index 196532fe..19997d5e 100644 --- a/models/ensemble_model.py +++ b/models/ensemble_model.py @@ -15,7 +15,9 @@ def setup(self, params, logger=None): self.setup_optimizer() def setup_module(self, params): - for subparams in params.ensemble_params: + for sub_index, subparams in enumerate(params.ensemble_params): + submodule_name = subparams.model_type + f'_{sub_index:02}' + subparams.submodule_name = submodule_name subparams.epoch_size = params.epoch_size subparams.batches_per_epoch = params.batches_per_epoch subparams.num_batches = params.num_batches @@ -39,7 +41,11 @@ def setup_optimizer(self): trainable_variables=module.parameters()) if module.params.checkpoint_boot_log != '': checkpoint = self.get_checkpoint_from_log(module.params.checkpoint_boot_log) - module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + module_state_dict_name = module.params.submodule_name+'_optimizer_state_dict' + if module_state_dict_name in checkpoint.keys(): # It was already in an ensemble + module.optimizer.load_state_dict(checkpoint[module_state_dict_name]) + else: # it was trained on its own + module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) module.scheduler = torch.optim.lr_scheduler.MultiStepLR( module.optimizer, milestones=module.params.optimizer.milestones, From a1515e01e7e25d829518e9c5283b404fbd177cb5 Mon Sep 17 00:00:00 2001 From: Dylan Date: Mon, 22 Feb 2021 08:38:05 +0000 Subject: [PATCH 24/44] adds manifold pooling layer --- models/pooling_model.py | 32 +++++++++++ modules/losses.py | 31 ++++++++--- modules/pooling_module.py | 77 ++++++++++++++++++++++++++ params/lca_pool_cifar10_params.py | 89 +++++++++++++++++++++++++++++++ utils/data_processing.py | 25 +++++++++ utils/loaders.py | 18 ++++--- 6 files changed, 258 insertions(+), 14 deletions(-) create mode 100644 models/pooling_model.py create mode 100644 modules/pooling_module.py create mode 100644 params/lca_pool_cifar10_params.py diff --git a/models/pooling_model.py b/models/pooling_model.py new file mode 100644 index 00000000..7b1bc149 --- /dev/null +++ b/models/pooling_model.py @@ -0,0 +1,32 @@ +import torch + +import DeepSparseCoding.modules.losses as losses + +from DeepSparseCoding.models.base_model import BaseModel +from DeepSparseCoding.modules.pooling_module import PoolingModule + +class PoolingModel(BaseModel, PoolingModule): + def setup(self, params, logger=None): + self.setup_module(params) + self.setup_optimizer() + if params.checkpoint_boot_log != '': + checkpoint = self.get_checkpoint_from_log(params.checkpoint_boot_log) + self.module.load_state_dict(checkpoint['model_state_dict']) + self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + def get_total_loss(self, input_tuple): + input_tensor, input_label = input_tuple + rep = self.forward(input_tensor) + self.loss_fn = losses.trace_covariance + return self.loss_fn(rep) + + def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None): + if update_dict is None: + update_dict = super(PoolinModel, self).generate_update_dict(input_data, input_labels, batch_step) + stat_dict = dict() + rep = self.forward(input_data) + total_loss = self.loss_fn(rep) + stat_dict['weight_lr'] = self.scheduler.get_lr()[0] + stat_dict['loss'] = total_loss.item() + update_dict.update(stat_dict) + return update_dict diff --git a/modules/losses.py b/modules/losses.py index d9ea61b7..62f71007 100644 --- a/modules/losses.py +++ b/modules/losses.py @@ -6,10 +6,10 @@ def half_squared_l2(x1, x2): """ Computes the standard reconstruction loss. It will average over batch dimensions. - Args: + Keyword arguments: x1: Tensor with original input image x2: Tensor with reconstructed image for comparison - Returns: + Outputs: recon_loss: Tensor representing the squared l2 distance between the inputs, averaged over batch """ dp.check_all_same_shape([x1, x2]) @@ -22,9 +22,9 @@ def half_squared_l2(x1, x2): def half_weight_norm_squared(weight_list): """ Computes a loss that encourages each weight in the list of weights to have unit l2 norm. - Args: + Keyword arguments: weight_list: List of torch variables - Returns: + Outputs: w_norm_loss: 0.5 * sum of (1 - l2_norm(w))^2 for each w in weight_list """ w_norm_list = [] @@ -39,9 +39,9 @@ def half_weight_norm_squared(weight_list): def weight_decay(weight_list): """ Computes typical weight decay loss - Args: + Keyword arguments: weight_list: List of torch variables - Returns: + Outputs: decay_loss: 0.5 * sum of w^2 for each w in weight_list """ decay_loss = 0.5 * torch.sum([torch.sum(torch.pow(w, 2.)) for w in weight_list]) @@ -52,11 +52,26 @@ def l1_norm(latents): """ Computes the L1 norm of for a batch of input vector This is the sparsity loss for a Laplacian prior - Args: + Keyword arguments: latents: torch tensor of any shape, but where first index is always batch - Returns: + Outputs: sparse_loss: sum of abs of latents, averaged over the batch """ reduc_dim = list(range(1, len(latents.shape))) sparse_loss = torch.mean(torch.sum(torch.abs(latents), dim=reduc_dim, keepdim=False)) return sparse_loss + + +def trace_covariance(latents): + """ + Loss is the trace of the covariance matrix of the latents + Keyword arguments: + latents: torch tensor of shape [num_batch, num_latents] or [num_batch, num_channels, latents_h, latents_w] + Outputs: + """ + corvariance = dp.covariance(latents) # [num_channels, num_channels] + if latenst.ndim == 4: + num_batch, num_channels, latents_h, latents_w = latents.shape + covariance = covariance / (latents_h * latents_w - 1.0) + trace = torch.trace(covariance) + return -1 * trace diff --git a/modules/pooling_module.py b/modules/pooling_module.py new file mode 100644 index 00000000..6b66525f --- /dev/null +++ b/modules/pooling_module.py @@ -0,0 +1,77 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import DeepSparseCoding.utils.data_processing as dp + + +class PoolingModule(nn.Module): + def setup_module(self, params): + self.params = params + if self.params.layer_type == 'fc': + layer = nn.Linear( + in_features=self.params.layer_channels[0], + out_features=self.params.layer_channels[1], + bias=False) + self.w = layer.weight + self.register_parameter('fc_pool_'+self.params.pool_name+'_w', layer.weight) + + elif self.params.layer_type == 'conv': + layer = nn.Conv2d( + in_channels=self.params.layer_channels[0], + out_channels=self.params.layer_channels[1], + kernel_size=self.params.pool_ksize, + stride=self.params.pool_stride, + padding=0, + dilation=1, + bias=False) + self.w = layer.weight + self.register_parameter('conv_pool_'+self.params.pool_name+'_w', layer.weight) + + elif self.params.layer_type == 'orth_conv': + """ + Based on Orthogonal Convolutional Neural Networks + https://arxiv.org/abs/1911.12207 + https://github.com/samaonline/Orthogonal-Convolutional-Neural-Networks + """ + self.w_shape = [ + self.params.layer_channels[1], + self.params.layer_channels[0], + self.params.pool_ksize, + self.params.pool_ksize # assumed square kernel + ] + w_init = torch.randn(self.w_shape) + w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps) + self.w = nn.Parameter(w_init_normed, requires_grad=True) + kernel_width = self.params.pool_ksize + in_channels = self.params.layer_channels[0] + new_stride = self.params.pool_stride * (kernel_width-1) + kernel_width + identity = torch.eye( + n=int(new_stride * new_stride * in_channels), + requires_grad=True, + device=self.params.device) + identity = identity.reshape( + (new_stride * new_stride * in_channels, in_channels, new_stride, new_stride)) + conv_out = F.conv2d( + identity, + self.w, + stride=self.params.pool_stride, + padding=0, + dilation=1) + conv_out = conv_out.reshape((new_stride * new_stride * in_channels, -1)) + Vmat = conv_out[np.floor(new_stride**2 / 2).astype(int)::new_stride**2, :] + dbt_mask = torch.zeors(in_channels, in_channels * new_stride**2) + for i in range(in_channels): + dbt_mask[i, np.floor(new_stride**2 / 2).astype(int) + new_stride**2 * i] = 1 + layer = torch.norm(torch.dot(Vmat, conv_out.transpose()) - dbt_mask, p='fro') + + else: + assert False, ('layer_type parameter must be "fc", "conv", or "orth_conv", not %g'%(layer_type)) + + def forward(self, x): + if self.params.layer_type == 'fc': + x = dp.flatten_feature_map(x) + return layer(x) + + def get_encodings(self, input_tensor): + return self.forward(input_tensor) diff --git a/params/lca_pool_cifar10_params.py b/params/lca_pool_cifar10_params.py new file mode 100644 index 00000000..871311c5 --- /dev/null +++ b/params/lca_pool_cifar10_params.py @@ -0,0 +1,89 @@ +import os +import types +import numpy as np +import torch + +from DeepSparseCoding.params.base_params import BaseParams +from DeepSparseCoding.params.lca_cifar10_params import params as LcaParams +from DeepSparseCoding.utils.run_utils import compute_conv_output_shape + + +class shared_params(object): + def __init__(self): + self.model_type = 'ensemble' + self.model_name = 'lca_pool_cifar10' + self.version = '0' + self.dataset = 'cifar10' + self.standardize_data = True + self.batch_size = 25 + self.num_epochs = 5 + self.train_logs_per_epoch = 4 + self.allow_parent_grads = False + + +class lca_params(LcaParams): + def set_params(self): + super(lca_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + self.model_type = 'lca' + self.layer_types = ['conv'] + self.weight_decay = 0.0 + self.weight_lr = 0.001 + self.renormalize_weights = True + self.stride = 2 + self.padding = 0 + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.dt = 0.001 + self.tau = 0.2 + self.num_steps = 75 + self.rectify_a = True + self.thresh_type = 'hard' + self.sparse_mult = 0.30 + self.num_latent = 512 + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/conv_lca_cifar10/logfiles/conv_lca_cifar10_v1.log' + self.compute_helper_params() + + +class pool_params(BaseParams): + def set_params(self): + super(pool_params, self).set_params() + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) + self.model_type = 'pool' + self.weight_lr = 1e-3 + self.layer_type = 'orth_conv' + self.layer_channels = [None, 10] + self.pool_ksize = 4 + self.pool_stride = 2 + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.compute_helper_params() + + +class params(BaseParams): + def set_params(self): + super(params, self).set_params() + lca_params_inst = lca_params() + pool_params_inst = pool_params() + lca_output_height = compute_conv_output_shape( + 32, # TODO: infer this? currently hardcoded CIFAR10 size + lca_params_inst.kernel_size, + lca_params_inst.stride, + lca_params_inst.padding, + dilation=1) + lca_output_width = compute_conv_output_shape( + 32, + lca_params_inst.kernel_size, + lca_params_inst.stride, + lca_params_inst.padding, + dilation=1) + lca_output_shape = [lca_params_inst.num_latent, lca_output_height, lca_output_width] + pool_params_inst.layer_channels[0] = np.prod(lca_output_shape) + self.ensemble_params = [lca_params_inst, pool_params_inst] + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) diff --git a/utils/data_processing.py b/utils/data_processing.py index 0afdcc8f..0c2fc175 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -119,6 +119,7 @@ def flatten_feature_map(feature_map): Keyword arguments: feature_map: tensor with shape [batch, c, y, x] + Returns: reshaped_map: tensor with shape [batch, c * y * x] """ @@ -495,3 +496,27 @@ def patches_to_images(patches, image_shape): images[im_id, ...] = patches_to_single_image(patch_batch, image_shape) patch_id += num_patches_per_im return images + + +def covariance(tensor): + """ + Returns the covariance matrix of the input tensor + + Keyword arguments: + tensor [torch tensor] of shape [num_batch, num_channels] or [num_batch, num_channels, elements_h, elements_w] + if tensor.ndim is 2 then the covariance is computed for each element over the batch dimension + if tensor.ndim is 4 then the covariance is computed over spatial dimensions for each channel and each batch instance + + Outputs: + covariance matrix [torch tensor] either of shape [num_channels, num_channels] + """ + if tensor.ndim == 2: # [num_batch, num_channels] + centered_tensor = tensor - tensor.mean(dim=0, keep_dims=True) # subtract mean vector + corvariance = torch.dot(centered_tensor.T, centered_tensor) # sum over batch + elif tensor.ndim == 4: # [num_batch, num_channels, elements_h, elements_w] + num_batch, num_channels, elements_h, elements_w = centered_tensor.shape + flat_map = centered_tensor.view(num_batch, num_channels, elements_h * elements_w) + cent_flat_map = flat_map - flat_map.mean(dim=2, keep_dims=True) # subtract mean vector + covariance = torch.bmm(cent_flat_map, torch.transpose(cent_flat_map, 1, 2)) # sum over space + covariance = covariance.mean(dim=0, keepdims=False) # avg cov over batch + return covariance diff --git a/utils/loaders.py b/utils/loaders.py index 03063f64..192664f0 100644 --- a/utils/loaders.py +++ b/utils/loaders.py @@ -44,9 +44,12 @@ def load_model_class(model_type): elif(model_type.lower() == 'lca'): py_module_name = 'LcaModel' file_name = os.path.join(*[dsc_dir, 'models', 'lca_model.py']) - elif(model_type.lower() == 'conv_lca'): - py_module_name = 'ConvLcaModel' - file_name = os.path.join(*[dsc_dir, 'models', 'conv_lca_model.py']) + #elif(model_type.lower() == 'conv_lca'): + # py_module_name = 'ConvLcaModel' + # file_name = os.path.join(*[dsc_dir, 'models', 'conv_lca_model.py']) + elif(model_type.lower() == 'pool'): + py_module_name = 'PoolingModel' + file_name = os.path.join(*[dsc_dir, 'models', 'pooling_model.py']) elif(model_type.lower() == 'ensemble'): py_module_name = 'EnsembleModel' file_name = os.path.join(*[dsc_dir, 'models', 'ensemble_model.py']) @@ -71,9 +74,12 @@ def load_module(module_type): elif(module_type.lower() == 'lca'): py_module_name = 'LcaModule' file_name = os.path.join(*[dsc_dir, 'modules', 'lca_module.py']) - elif(module_type.lower() == 'conv_lca'): - py_module_name = 'ConvLcaModule' - file_name = os.path.join(*[dsc_dir, 'modules', 'conv_lca_module.py']) + #elif(module_type.lower() == 'conv_lca'): + # py_module_name = 'ConvLcaModule' + # file_name = os.path.join(*[dsc_dir, 'modules', 'conv_lca_module.py']) + elif(module_type.lower() == 'pool'): + py_module_name = 'PoolingModule' + file_name = os.path.join(*[dsc_dir, 'modules', 'pooling_module.py']) elif(module_type.lower() == 'ensemble'): py_module_name = 'EnsembleModule' file_name = os.path.join(*[dsc_dir, 'modules', 'ensemble_module.py']) From 079543f00a998a3fe633ed55c8f842efe21298e2 Mon Sep 17 00:00:00 2001 From: Dylan Date: Mon, 22 Feb 2021 16:41:48 +0100 Subject: [PATCH 25/44] pooling weight orthogonalization is now specified as a loss; bugfixes --- models/pooling_model.py | 12 +++++-- modules/losses.py | 46 ++++++++++++++++++++++++-- modules/pooling_module.py | 55 ++++++------------------------- params/lca_pool_cifar10_params.py | 39 +++++++++++++--------- train_model.py | 5 +-- utils/data_processing.py | 18 +++++----- utils/run_utils.py | 4 +-- 7 files changed, 99 insertions(+), 80 deletions(-) diff --git a/models/pooling_model.py b/models/pooling_model.py index 7b1bc149..9fdd7458 100644 --- a/models/pooling_model.py +++ b/models/pooling_model.py @@ -15,10 +15,16 @@ def setup(self, params, logger=None): self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) def get_total_loss(self, input_tuple): + def loss_fn(model_output): + output_loss = losses.trace_covariance(model_output) + w_stride = self.params.pool_stride + w_padding = 0 + weight_loss = losses.weight_orthogonality(self.w, stride=w_stride, padding=w_padding) + return output_loss + weight_loss input_tensor, input_label = input_tuple - rep = self.forward(input_tensor) - self.loss_fn = losses.trace_covariance - return self.loss_fn(rep) + layer_output = self.forward(input_tensor) + self.loss_fn = loss_fn + return self.loss_fn(layer_output) def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None): if update_dict is None: diff --git a/modules/losses.py b/modules/losses.py index 62f71007..d065b374 100644 --- a/modules/losses.py +++ b/modules/losses.py @@ -1,3 +1,4 @@ +import numpy as np import torch import DeepSparseCoding.utils.data_processing as dp @@ -64,14 +65,53 @@ def l1_norm(latents): def trace_covariance(latents): """ - Loss is the trace of the covariance matrix of the latents + Returns loss that is the trace of the covariance matrix of the latents + Keyword arguments: latents: torch tensor of shape [num_batch, num_latents] or [num_batch, num_channels, latents_h, latents_w] Outputs: + loss """ - corvariance = dp.covariance(latents) # [num_channels, num_channels] - if latenst.ndim == 4: + covariance = dp.covariance(latents) # [num_channels, num_channels] + if latents.ndim == 4: num_batch, num_channels, latents_h, latents_w = latents.shape covariance = covariance / (latents_h * latents_w - 1.0) trace = torch.trace(covariance) return -1 * trace + + +def weight_orthogonality(weights, stride=1, padding=0): + """ + Returns l2 loss that is minimized when the weights are orthogonal + + Keyword arguments: + weights [torch tensor] layer weights, either fully connected or 2d convolutional + stride [int] layer stride for convolutional layers + padding [int] layer padding for convolutional layers + + Outputs: + loss + + Note: + Convolutional orthogonalization loss is based on + Orthogonal Convolutional Neural Networks + https://arxiv.org/abs/1911.12207 + https://github.com/samaonline/Orthogonal-Convolutional-Neural-Networks + """ + w_shape = weights.shape + if weights.ndim == 2: # fully-connected, [inputs, outputs] + loss = torch.norm(torch.matmul(weights.transpose(), weights) - torch.eye(w_shape[1])) + elif weights.ndim == 4: # convolutional, [output_channels, input_channels, height, width] + out_channels, in_channels, in_height, in_width = w_shape + output = torch.conv2d(weights, weights, stride=stride, padding=padding) + out_height = output.shape[-2] + out_width = output.shape[-1] + target = torch.zeros((out_channels, out_channels, out_height, out_width), + device=weights.device) + center_h = int(np.floor(out_height / 2)) + center_w = int(np.floor(out_width / 2)) + target[:, :, center_h, center_w] = torch.eye(out_channels, device=weights.device) + loss = torch.norm(output - target, p='fro') + else: + assert False, (f'weights ndim must be 2 or 4, not {weights.ndim}') + return loss diff --git a/modules/pooling_module.py b/modules/pooling_module.py index 6b66525f..59265662 100644 --- a/modules/pooling_module.py +++ b/modules/pooling_module.py @@ -7,17 +7,18 @@ class PoolingModule(nn.Module): def setup_module(self, params): + params.weight_decay = 0 # used by base model; pooling layer never has weight decay self.params = params if self.params.layer_type == 'fc': - layer = nn.Linear( + self.layer = nn.Linear( in_features=self.params.layer_channels[0], out_features=self.params.layer_channels[1], bias=False) - self.w = layer.weight - self.register_parameter('fc_pool_'+self.params.pool_name+'_w', layer.weight) + self.w = self.layer.weight + self.register_parameter('fc_pool_'+self.params.layer_name+'_w', self.layer.weight) elif self.params.layer_type == 'conv': - layer = nn.Conv2d( + self.layer = nn.Conv2d( in_channels=self.params.layer_channels[0], out_channels=self.params.layer_channels[1], kernel_size=self.params.pool_ksize, @@ -25,53 +26,17 @@ def setup_module(self, params): padding=0, dilation=1, bias=False) - self.w = layer.weight - self.register_parameter('conv_pool_'+self.params.pool_name+'_w', layer.weight) - - elif self.params.layer_type == 'orth_conv': - """ - Based on Orthogonal Convolutional Neural Networks - https://arxiv.org/abs/1911.12207 - https://github.com/samaonline/Orthogonal-Convolutional-Neural-Networks - """ - self.w_shape = [ - self.params.layer_channels[1], - self.params.layer_channels[0], - self.params.pool_ksize, - self.params.pool_ksize # assumed square kernel - ] - w_init = torch.randn(self.w_shape) - w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps) - self.w = nn.Parameter(w_init_normed, requires_grad=True) - kernel_width = self.params.pool_ksize - in_channels = self.params.layer_channels[0] - new_stride = self.params.pool_stride * (kernel_width-1) + kernel_width - identity = torch.eye( - n=int(new_stride * new_stride * in_channels), - requires_grad=True, - device=self.params.device) - identity = identity.reshape( - (new_stride * new_stride * in_channels, in_channels, new_stride, new_stride)) - conv_out = F.conv2d( - identity, - self.w, - stride=self.params.pool_stride, - padding=0, - dilation=1) - conv_out = conv_out.reshape((new_stride * new_stride * in_channels, -1)) - Vmat = conv_out[np.floor(new_stride**2 / 2).astype(int)::new_stride**2, :] - dbt_mask = torch.zeors(in_channels, in_channels * new_stride**2) - for i in range(in_channels): - dbt_mask[i, np.floor(new_stride**2 / 2).astype(int) + new_stride**2 * i] = 1 - layer = torch.norm(torch.dot(Vmat, conv_out.transpose()) - dbt_mask, p='fro') + nn.init.orthogonal_(self.layer.weight) # initialize to orthogonal matrix + self.w = self.layer.weight + self.register_parameter('conv_pool_'+self.params.layer_name+'_w', self.layer.weight) else: - assert False, ('layer_type parameter must be "fc", "conv", or "orth_conv", not %g'%(layer_type)) + assert False, ('layer_type parameter must be "fc", "conv", not %g'%(layer_type)) def forward(self, x): if self.params.layer_type == 'fc': x = dp.flatten_feature_map(x) - return layer(x) + return self.layer(x) def get_encodings(self, input_tensor): return self.forward(input_tensor) diff --git a/params/lca_pool_cifar10_params.py b/params/lca_pool_cifar10_params.py index 871311c5..c1d1073b 100644 --- a/params/lca_pool_cifar10_params.py +++ b/params/lca_pool_cifar10_params.py @@ -53,9 +53,10 @@ def set_params(self): for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'pool' + self.layer_name = 'pool_1' self.weight_lr = 1e-3 - self.layer_type = 'orth_conv' - self.layer_channels = [None, 10] + self.layer_type = 'conv' + self.layer_channels = [512, 10] self.pool_ksize = 4 self.pool_stride = 2 self.optimizer = types.SimpleNamespace() @@ -64,26 +65,32 @@ def set_params(self): self.optimizer.lr_decay_rate = 0.8 self.compute_helper_params() + def compute_helper_params(self): + super(pool_params, self).compute_helper_params() + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] + class params(BaseParams): def set_params(self): super(params, self).set_params() lca_params_inst = lca_params() pool_params_inst = pool_params() - lca_output_height = compute_conv_output_shape( - 32, # TODO: infer this? currently hardcoded CIFAR10 size - lca_params_inst.kernel_size, - lca_params_inst.stride, - lca_params_inst.padding, - dilation=1) - lca_output_width = compute_conv_output_shape( - 32, - lca_params_inst.kernel_size, - lca_params_inst.stride, - lca_params_inst.padding, - dilation=1) - lca_output_shape = [lca_params_inst.num_latent, lca_output_height, lca_output_width] - pool_params_inst.layer_channels[0] = np.prod(lca_output_shape) + if(pool_params_inst.layer_type == 'fc' and lca_params_inst.layer_type == 'conv'): + lca_output_height = compute_conv_output_shape( + 32, # TODO: infer this? currently hardcoded CIFAR10 size + lca_params_inst.kernel_size, + lca_params_inst.stride, + lca_params_inst.padding, + dilation=1) + lca_output_width = compute_conv_output_shape( + 32, + lca_params_inst.kernel_size, + lca_params_inst.stride, + lca_params_inst.padding, + dilation=1) + lca_output_shape = [lca_params_inst.num_latent, lca_output_height, lca_output_width] + pool_params_inst.layer_channels[0] = np.prod(lca_output_shape) self.ensemble_params = [lca_params_inst, pool_params_inst] for key, value in shared_params().__dict__.items(): setattr(self, key, value) diff --git a/train_model.py b/train_model.py index b2c8f2f4..40b0414d 100644 --- a/train_model.py +++ b/train_model.py @@ -35,8 +35,9 @@ # Train model for epoch in range(1, model.params.num_epochs+1): run_utils.train_epoch(epoch, model, train_loader) - if(model.params.model_type.lower() in ['mlp', 'ensemble']): # TODO: use to validation set here; test at the end of training - run_utils.test_epoch(epoch, model, test_loader) + # TODO: Ensemble models might not actually have a classification objective / need validation + #if(model.params.model_type.lower() in ['mlp', 'ensemble']): # TODO: use to validation set here; test at the end of training + # run_utils.test_epoch(epoch, model, test_loader) model.log_info(f'Completed epoch {epoch}/{model.params.num_epochs}') print(f'Completed epoch {epoch}/{model.params.num_epochs}') diff --git a/utils/data_processing.py b/utils/data_processing.py index 0c2fc175..684864bc 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -169,7 +169,7 @@ def get_mean_from_dataloader(loader): dataset_mean = torch.zeros(next(iter(loader))[0].shape[1:]) # don't include batch dimension num_batches = 0 for data, target in loader: - dataset_mean += data.mean(axis=0, keepdims=False) + dataset_mean += data.mean(dim=0, keepdim=False) num_batches += 1 return dataset_mean / num_batches @@ -247,9 +247,9 @@ def rescale_data_to_one(data, eps=None, samplewise=True): eps = 1.0 / np.sqrt(data[0,...].numel()) if(samplewise): data_min = torch.min(data.view(-1, np.prod(data.shape[1:])), - axis=1, keepdims=False)[0].view(-1, *[1]*(data.ndim-1)) + axis=1, keepdim=False)[0].view(-1, *[1]*(data.ndim-1)) data_max = torch.max(data.view(-1, np.prod(data.shape[1:])), - axis=1, keepdims=False)[0].view(-1, *[1]*(data.ndim-1)) + axis=1, keepdim=False)[0].view(-1, *[1]*(data.ndim-1)) else: data_min = torch.min(data) data_max = torch.max(data) @@ -508,15 +508,15 @@ def covariance(tensor): if tensor.ndim is 4 then the covariance is computed over spatial dimensions for each channel and each batch instance Outputs: - covariance matrix [torch tensor] either of shape [num_channels, num_channels] + covariance matrix [torch tensor] of shape [num_channels, num_channels] """ if tensor.ndim == 2: # [num_batch, num_channels] - centered_tensor = tensor - tensor.mean(dim=0, keep_dims=True) # subtract mean vector + centered_tensor = tensor - tensor.mean(dim=0, keepdim=True) # subtract mean vector corvariance = torch.dot(centered_tensor.T, centered_tensor) # sum over batch elif tensor.ndim == 4: # [num_batch, num_channels, elements_h, elements_w] - num_batch, num_channels, elements_h, elements_w = centered_tensor.shape - flat_map = centered_tensor.view(num_batch, num_channels, elements_h * elements_w) - cent_flat_map = flat_map - flat_map.mean(dim=2, keep_dims=True) # subtract mean vector + num_batch, num_channels, elements_h, elements_w = tensor.shape + flat_map = tensor.view(num_batch, num_channels, elements_h * elements_w) + cent_flat_map = flat_map - flat_map.mean(dim=2, keepdim=True) # subtract mean vector covariance = torch.bmm(cent_flat_map, torch.transpose(cent_flat_map, 1, 2)) # sum over space - covariance = covariance.mean(dim=0, keepdims=False) # avg cov over batch + covariance = covariance.mean(dim=0, keepdim=False) # avg cov over batch return covariance diff --git a/utils/run_utils.py b/utils/run_utils.py index 72d58a98..b68e62f9 100644 --- a/utils/run_utils.py +++ b/utils/run_utils.py @@ -4,8 +4,8 @@ import DeepSparseCoding.utils.data_processing as dp -def compute_conv_output_shape(in_length, kernel_size, stride, padding=0, dilation=1): - out_shape = ((in_length + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1 +def compute_conv_output_shape(in_length, kernel_length, stride, padding=0, dilation=1): + out_shape = ((in_length + 2 * padding - dilation * (kernel_length - 1) - 1) / stride) + 1 return np.floor(out_shape).astype(np.int) From 49259f57c546f73131e1c1ee098f4dfcd81ae456 Mon Sep 17 00:00:00 2001 From: Dylan Date: Tue, 23 Feb 2021 09:20:51 +0100 Subject: [PATCH 26/44] not using smt_model file currently --- models/smt_model.py | 65 --------------------------------------------- 1 file changed, 65 deletions(-) delete mode 100644 models/smt_model.py diff --git a/models/smt_model.py b/models/smt_model.py deleted file mode 100644 index 2f027ffd..00000000 --- a/models/smt_model.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch - -import DeepSparseCoding.utils.loaders as loaders -from DeepSparseCoding.models.base_model import BaseModel -from DeepSparseCoding.modules.ensemble_module import EnsembleModule - - -class SmtModel(BaseModel, EnsembleModule): - def setup(self, params, logger=None): - """ - Setup required model components - """ - super(SmtModel, self).setup(params, logger) - self.setup_module(params) - self.setup_optimizer() - - def setup_module(self, params): - for subparams in params.ensemble_params: - subparams.epoch_size = params.epoch_size - subparams.batches_per_epoch = params.batches_per_epoch - subparams.num_batches = params.num_batches - #subparams.num_val_images = params.num_val_images - #subparams.num_test_images = params.num_test_images - subparams.data_shape = params.data_shape - super(SmtModel, self).setup_ensemble_module(params) - self.submodel_classes = [] - for submodel_params in self.params.ensemble_params: - self.submodel_classes.append(loaders.load_model_class(submodel_params.model_type)) - - def setup_optimizer(self): - for module in self: - module.optimizer = self.get_optimizer( - optimizer_params=module.params, - trainable_variables=module.parameters()) - module.scheduler = torch.optim.lr_scheduler.MultiStepLR( - module.optimizer, - milestones=module.params.optimizer.milestones, - gamma=module.params.optimizer.lr_decay_rate) - - def preprocess_data(self, data): - """ - We assume that only the first submodel will be preprocessing the input data - """ - submodule = self.__getitem__(0) - return self.submodel_classes[0].preprocess_data(submodule, data) - - def get_total_loss(self, input_tuple, ensemble_index): - submodule = self.__getitem__(ensemble_index) - submodel_class = self.submodel_classes[ensemble_index] - return submodel_class.get_total_loss(submodule, input_tuple) - - def generate_update_dict(self, input_data, input_labels=None, batch_step=0): - update_dict = super(SmtModel, self).generate_update_dict(input_data, - input_labels, batch_step) - x = input_data.clone() # TODO: Do I need to clone it? If not then don't. - for ensemble_index, submodel_class in enumerate(self.submodel_classes): - submodule = self.__getitem__(ensemble_index) - submodel_update_dict = submodel_class.generate_update_dict(submodule, x, - input_labels, batch_step, update_dict=dict()) - for key, value in submodel_update_dict.items(): - if key not in ['epoch', 'batch_step']: - key = submodule.params.model_type+'_'+key - update_dict[key] = value - x = submodule.get_encodings(x) - return update_dict From 088d3f74736d9a2b35ee71a84c619317b1afa2f8 Mon Sep 17 00:00:00 2001 From: Dylan Date: Fri, 26 Feb 2021 17:38:35 +0000 Subject: [PATCH 27/44] adds new logging features; bugfixes; pooling tests logging now includes computer environment and model architecture details lca now outputs fraction nonzero for channel location as well as convolutional spatial map fixes a checkpoint loading bug for ensemble models fixes a bug with ensemble modules that caused submodules of the same type to clobber removes flatten_feature_map function from utils.data_processing in favor of one-line option reorganizes mlp and lca forward function calls for cleaner integration into ensembles renames lca num_latent params to layer_channels to match mlp specification adds pooling params to the test suite changed some logging apis to be more intuitive / general --- models/base_model.py | 67 +++++++++---- models/ensemble_model.py | 52 ++++++++-- models/lca_model.py | 17 +++- models/pooling_model.py | 5 +- modules/ensemble_module.py | 13 +-- modules/lca_module.py | 34 +++---- modules/losses.py | 36 +++---- modules/mlp_module.py | 31 ++++-- modules/pooling_module.py | 21 ++-- params/lca_cifar10_params.py | 14 +-- params/lca_dsprites_params.py | 3 +- params/lca_mlp_cifar10_params.py | 18 ++-- params/lca_mlp_mnist_params.py | 4 +- params/lca_mnist_params.py | 4 +- params/lca_pool_cifar10_params.py | 40 ++++---- params/lca_pool_lca_cifar10_params.py | 133 ++++++++++++++++++++++++++ params/test_params.py | 28 +++++- tests/test_data_processing.py | 10 -- tests/test_models.py | 14 +-- train_model.py | 7 +- utils/data_processing.py | 28 +----- utils/file_utils.py | 122 ++++++++++++++++++++--- utils/loaders.py | 14 +-- utils/run_utils.py | 5 +- 24 files changed, 521 insertions(+), 199 deletions(-) create mode 100644 params/lca_pool_lca_cifar10_params.py diff --git a/models/base_model.py b/models/base_model.py index 36d89827..2295333e 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -1,10 +1,13 @@ import os +import subprocess import pprint import numpy as np import torch +from DeepSparseCoding.utils.file_utils import summary_string from DeepSparseCoding.utils.file_utils import Logger +from DeepSparseCoding.utils.run_utils import compute_conv_output_shape import DeepSparseCoding.utils.loaders as loaders @@ -12,7 +15,6 @@ class BaseModel(object): def setup(self, params, logger=None): """ Setup required model components - #TODO: log system info, including git commit hash """ self.load_params(params) self.check_params() @@ -22,6 +24,7 @@ def setup(self, params, logger=None): self.log_params() else: self.logger = logger + self.logger.log_info(self.get_env_details()) def load_params(self, params): """ @@ -94,10 +97,6 @@ def log_params(self, params=None): dump_obj = self.params.__dict__ self.logger.log_params(dump_obj) - def log_info(self, string): - """Log input string""" - self.logger.log_info(string) - def get_train_stats(self, batch_step=None): """ Get default statistics about current training run @@ -115,6 +114,43 @@ def get_train_stats(self, batch_step=None): } return stat_dict + def get_env_details(self): + env = {} + for k in ['SYSTEMROOT', 'PATH']: + v = os.environ.get(k) + if v is not None: + env[k] = v + commit_cmd = ['git', 'rev-parse', 'HEAD'] + commit = subprocess.Popen(commit_cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + commit = commit.strip().decode('ascii') + branch_cmd = ['git', 'rev-parse', '--abbrev-ref', 'HEAD'] + branch = subprocess.Popen(branch_cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + branch = branch.strip().decode('ascii') + system_details = os.uname() + out_dict = { + 'current_branch':branch, + 'current_commit_hash':commit, + 'sysname':system_details.sysname, + 'release':system_details.release, + 'machine':system_details.machine + } + if torch.cuda.is_available(): + out_dict['gpu_device'] = torch.cuda.get_device_name(0) + return out_dict + + def log_architecture_details(self): + """ + Log model architecture with computed output sizes and number of parameters for each layer + """ + architecture_string = '\n'+summary_string( + self, + input_size=tuple(self.params.data_shape), + batch_size=self.params.batch_size, + device=self.params.device, + dtype=torch.FloatTensor + )[0] + self.logger.log_string(architecture_string) + def write_checkpoint(self, batch_step=None): """ Write checkpoints @@ -122,20 +158,20 @@ def write_checkpoint(self, batch_step=None): Keyword arguments: batch_step: [int] current batch iteration. The default assumes that training has finished. """ - output_dict = { - 'model_state_dict': self.state_dict(), - } + output_dict = {} if(self.params.model_type.lower() == 'ensemble'): for module in self: - module_state_dict_name = module.params.submodule_name+'_optimizer_state_dict' - output_dict[module_state_dict_name] = module.optimizer.state_dict(), + module_name = module.params.submodule_name + output_dict[module_name+'_module_state_dict'] = module.state_dict() + output_dict[module_name+'_optimizer_state_dict'] = module.optimizer.state_dict() else: + output_dict['model_state_dict'] = self.state_dict() module_state_dict_name = 'optimizer_state_dict' output_dict[module_state_dict_name] = self.optimizer.state_dict(), training_stats = self.get_train_stats(batch_step) output_dict.update(training_stats) torch.save(output_dict, self.params.cp_latest_filename) - self.log_info('Full model saved in file %s'%self.params.cp_latest_filename) + self.logger.log_string('Full model saved in file %s'%self.params.cp_latest_filename) def get_checkpoint_from_log(self, logfile): model_params = loaders.load_params_from_log(logfile) @@ -192,14 +228,11 @@ def print_update(self, input_data, input_labels=None, batch_step=0): input_data: data object containing the current image batch input_labels: data object containing the current label batch batch_step: current batch number within the schedule - NOTE: For the analysis code to parse update statistics, the self.js_dumpstring() call - must receive a dict object. Additionally, the self.js_dumpstring() output must be - logged with tags. - For example: logging.info(''+self.js_dumpstring(output_dictionary)+'') + NOTE: For the analysis code to parse update statistics, + the logger.log_stats() function must be used """ update_dict = self.generate_update_dict(input_data, input_labels, batch_step) - js_str = self.js_dumpstring(update_dict) - self.log_info(''+js_str+'') + self.logger.log_stats(update_dict) def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None): """ diff --git a/models/ensemble_model.py b/models/ensemble_model.py index 19997d5e..51a6dc9b 100644 --- a/models/ensemble_model.py +++ b/models/ensemble_model.py @@ -1,3 +1,5 @@ +import pprint + import torch import DeepSparseCoding.utils.loaders as loaders @@ -15,15 +17,19 @@ def setup(self, params, logger=None): self.setup_optimizer() def setup_module(self, params): + layer_names = [] # TODO: Make this submodule_name=model_type+layer_name is unique, not layer_name is unique for sub_index, subparams in enumerate(params.ensemble_params): - submodule_name = subparams.model_type + f'_{sub_index:02}' - subparams.submodule_name = submodule_name + layer_names.append(subparams.layer_name) + assert len(set(layer_names)) == len(layer_names), ( + 'The "layer_name" parameter must be unique for each module in the ensemble.') + subparams.submodule_name = subparams.model_type + '_' + subparams.layer_name subparams.epoch_size = params.epoch_size subparams.batches_per_epoch = params.batches_per_epoch subparams.num_batches = params.num_batches #subparams.num_val_images = params.num_val_images #subparams.num_test_images = params.num_test_images - subparams.data_shape = params.data_shape + if not hasattr(subparams, 'data_shape'): # TODO: This is a workaround for a dependency on data_shape in lca module + subparams.data_shape = params.data_shape super(EnsembleModel, self).setup_ensemble_module(params) self.submodel_classes = [] for ensemble_index, subparams in enumerate(self.params.ensemble_params): @@ -32,7 +38,11 @@ def setup_module(self, params): if subparams.checkpoint_boot_log != '': checkpoint = self.get_checkpoint_from_log(subparams.checkpoint_boot_log) submodule = self.__getitem__(ensemble_index) - submodule.load_state_dict(checkpoint['model_state_dict']) + module_state_dict_name = subparams.submodule_name+'_module_state_dict' + if module_state_dict_name in checkpoint.keys(): # It was already in an ensemble + submodule.load_state_dict(checkpoint[module_state_dict_name]) + else: # it was trained on its own + submodule.load_state_dict(checkpoint['model_state_dict']) def setup_optimizer(self): for module in self: @@ -45,12 +55,42 @@ def setup_optimizer(self): if module_state_dict_name in checkpoint.keys(): # It was already in an ensemble module.optimizer.load_state_dict(checkpoint[module_state_dict_name]) else: # it was trained on its own - module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + module.optimizer.load_state_dict(checkpoint['optimizer_state_dict'][0]) #TODO: For some reason this is a tuple of size 1 containing the dictionary. It should just be the dictionary module.scheduler = torch.optim.lr_scheduler.MultiStepLR( module.optimizer, milestones=module.params.optimizer.milestones, gamma=module.params.optimizer.lr_decay_rate) + def load_checkpoint(self, cp_file=None, load_optimizer=False): + """ + Load checkpoint + Keyword arguments: + model_dir: [str] specifying the path to the checkpoint + """ + if cp_file is None: + cp_file = self.params.cp_latest_filename + checkpoint = torch.load(cp_file) + for module in self: + module_state_dict_name = module.params.submodule_name+'_module_state_dict' + if module_state_dict_name in checkpoint.keys(): # It was already in an ensemble + module.load_state_dict(checkpoint[module_state_dict_name]) + _ = checkpoint.pop(module_state_dict_name, None) + else: # it was trained on its own + module.load_state_dict(checkpoint['model_state_dict']) + _ = checkpoint.pop('optimizer_state_dict', None) + if load_optimizer: + module_state_dict_name = module.params.submodule_name+'_optimizer_state_dict' + if module_state_dict_name in checkpoint.keys(): # It was already in an ensemble + module.optimizer.load_state_dict(checkpoint[module_state_dict_name]) + _ = checkpoint.pop(module_state_dict_name, None) + else: + module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + _ = checkpoint.pop('optimizer_state_dict', None) + _ = checkpoint.pop('model_state_dict', None) + training_status = pprint.pformat(checkpoint, compact=True)#, sort_dicts=True #TODO: Python 3.8 adds the sort_dicts parameter + out_str = f'Loaded checkpoint from {cp_file} with the following stats:\n{training_status}' + return out_str + def preprocess_data(self, data): """ We assume that only the first submodel will be preprocessing the input data @@ -73,7 +113,7 @@ def generate_update_dict(self, input_data, input_labels=None, batch_step=0): input_labels, batch_step, update_dict=dict()) for key, value in submodel_update_dict.items(): if key not in ['epoch', 'batch_step']: - key = submodule.params.model_type+'_'+key + key = submodule.params.submodule_name + '_' + key update_dict[key] = value x = submodule.get_encodings(x) return update_dict diff --git a/models/lca_model.py b/models/lca_model.py index 7d1a66e2..38444b6a 100644 --- a/models/lca_model.py +++ b/models/lca_model.py @@ -41,7 +41,20 @@ def generate_update_dict(self, input_data, input_labels=None, batch_step=0, upda input_data.max().item(), input_data.mean().item(), input_data.min().item()] stat_dict['recon_max_mean_min'] = [ recon.max().item(), recon.mean().item(), recon.min().item()] - latent_nnz = torch.sum(latents != 0).item() # TODO: github issue 23907 requests torch.count_nonzero - stat_dict['latents_fraction_active'] = latent_nnz / latents.numel() + def count_nonzero(array, dim): + # TODO: github issue 23907 requests torch.count_nonzero, integrated in torch 1.7 + return torch.sum(array !=0, dim=dim, dtype=torch.float) + latent_dims = tuple([i for i in range(len(latents.shape))]) + latent_nnz = count_nonzero(latents, dim=latent_dims).item() + stat_dict['fraction_active_all_latents'] = latent_nnz / latents.numel() + if self.params.layer_types[0] == 'conv': + latent_map_dims = latent_dims[2:] + latent_map_size = np.prod(list(latents.shape[2:])) + latent_channel_nnz = count_nonzero(latents, dim=latent_map_dims)/latent_map_size + latent_channel_mean_nnz = torch.mean(latent_channel_nnz).item() + stat_dict['fraction_active_latents_per_channel'] = latent_channel_mean_nnz + num_channels = latents.shape[1] + latent_patch_mean_nnz = torch.mean(count_nonzero(latents, dim=1)/num_channels).item() + stat_dict['fraction_active_latents_per_patch'] = latent_patch_mean_nnz update_dict.update(stat_dict) return update_dict diff --git a/models/pooling_model.py b/models/pooling_model.py index 9fdd7458..5ed4c9fb 100644 --- a/models/pooling_model.py +++ b/models/pooling_model.py @@ -6,6 +6,9 @@ from DeepSparseCoding.modules.pooling_module import PoolingModule class PoolingModel(BaseModel, PoolingModule): + """ + TODO: rename pool_ksize and pool_stride to just kernel_size and stride + """ def setup(self, params, logger=None): self.setup_module(params) self.setup_optimizer() @@ -19,7 +22,7 @@ def loss_fn(model_output): output_loss = losses.trace_covariance(model_output) w_stride = self.params.pool_stride w_padding = 0 - weight_loss = losses.weight_orthogonality(self.w, stride=w_stride, padding=w_padding) + weight_loss = losses.weight_orthogonality(self.weight, stride=w_stride, padding=w_padding) return output_loss + weight_loss input_tensor, input_label = input_tuple layer_output = self.forward(input_tensor) diff --git a/modules/ensemble_module.py b/modules/ensemble_module.py index ee8df6db..27e60f77 100644 --- a/modules/ensemble_module.py +++ b/modules/ensemble_module.py @@ -1,27 +1,22 @@ import torch.nn as nn import DeepSparseCoding.utils.loaders as loaders -from DeepSparseCoding.utils.data_processing import flatten_feature_map class EnsembleModule(nn.Sequential): - def __init__(self): # do not do Sequential's init - super(nn.Sequential, self).__init__() - def setup_ensemble_module(self, params): self.params = params for subparams in params.ensemble_params: submodule = loaders.load_module(subparams.model_type) submodule.setup_module(subparams) - self.add_module(subparams.model_type, submodule) + self.add_module(subparams.layer_name, submodule) def forward(self, x): - self.layer_list = [x] for module in self: if module.params.layer_types[0] == 'fc': - self.layer_list[-1] = flatten_feature_map(self.layer_list[-1]) - self.layer_list.append(module.get_encodings(self.layer_list[-1])) # latent encodings - return self.layer_list[-1] + x = x.view(x.size(0), -1) #flat + x = module(x) + return x def get_encodings(self, x): return self.forward(x) diff --git a/modules/lca_module.py b/modules/lca_module.py index 058e6277..b8195626 100644 --- a/modules/lca_module.py +++ b/modules/lca_module.py @@ -16,18 +16,22 @@ class LcaModule(nn.Module): kernel_size: [int] edge size of the square convolving kernel stride: [int] vertical and horizontal stride of the convolution padding: [int] zero-padding added to both sides of the input + TODO: Inference process should be streamlined by defining only a single step and iterating it in forward() as is done here: + https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html + + TODO: Remove dependency on data_shape to make more intuitive in a hierarchy. i.e. use layer_channels as is done in the mlp """ def setup_module(self, params): self.params = params if self.params.layer_types[0] == 'fc': - self.layer_output_shapes = [[self.params.num_latent]] - self.w_shape = [self.params.num_pixels, self.params.num_latent] + self.layer_output_shapes = [[self.params.layer_channels]] + self.w_shape = [self.params.num_pixels, self.params.layer_channels] else: self.layer_output_shapes = [self.params.data_shape] # [channels, height, width] assert (self.params.data_shape[-1] % self.params.stride == 0), ( f'Stride = {self.params.stride} must divide evenly into input edge size = {self.params.data_shape[-1]}') self.w_shape = [ - self.params.num_latent, + self.params.layer_channels, self.params.data_shape[0], # channels = 1 self.params.kernel_size, self.params.kernel_size @@ -44,10 +48,10 @@ def setup_module(self, params): self.params.stride, self.params.padding, dilation=1) - self.layer_output_shapes.append([self.params.num_latent, output_height, output_width]) + self.layer_output_shapes.append([self.params.layer_channels, output_height, output_width]) w_init = torch.randn(self.w_shape) w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps) - self.w = nn.Parameter(w_init_normed, requires_grad=True) + self.weight = nn.Parameter(w_init_normed, requires_grad=True) def preprocess_data(self, input_tensor): if self.params.layer_types[0] == 'fc': @@ -56,13 +60,13 @@ def preprocess_data(self, input_tensor): def compute_excitatory_current(self, input_tensor, a_in): if self.params.layer_types[0] == 'fc': - excitatory_current = torch.matmul(input_tensor, self.w) + excitatory_current = torch.matmul(input_tensor, self.weight) else: recon = self.get_recon_from_latents(a_in) recon_error = input_tensor - recon error_injection = F.conv2d( input=recon_error, - weight=self.w, + weight=self.weight, bias=None, stride=self.params.stride, padding=self.params.padding @@ -72,8 +76,8 @@ def compute_excitatory_current(self, input_tensor, a_in): def compute_inhibitory_connectivity(self): if self.params.layer_types[0] == 'fc': - inhibitory_connectivity = torch.matmul(torch.transpose(self.w, dim0=0, dim1=1), - self.w) - torch.eye(self.params.num_latent, + inhibitory_connectivity = torch.matmul(torch.transpose(self.weight, dim0=0, dim1=1), + self.weight) - torch.eye(self.params.layer_channels, requires_grad=True, device=self.params.device) else: inhibitory_connectivity = 0 # TODO: return Grammian along channel dim for a single kernel location @@ -115,22 +119,20 @@ def infer_coefficients(self, input_tensor): def get_recon_from_latents(self, a_in): if self.params.layer_types[0] == 'fc': - recon = torch.matmul(a_in, torch.transpose(self.w, dim0=0, dim1=1)) + recon = torch.matmul(a_in, torch.transpose(self.weight, dim0=0, dim1=1)) else: recon = F.conv_transpose2d( input=a_in, - weight=self.w, + weight=self.weight, bias=None, stride=self.params.stride, padding=self.params.padding ) return recon - def get_encodings(self, input_tensor): + def forward(self, input_tensor): u_list, a_list = self.infer_coefficients(input_tensor) return a_list[-1] - def forward(self, input_tensor): - latents = self.get_encodings(input_tensor) - reconstruction = self.get_recon_from_latents(latents) - return reconstruction + def get_encodings(self, input_tensor): + return self.forward(input_tensor) diff --git a/modules/losses.py b/modules/losses.py index 4d723bd2..1ecdd4b6 100644 --- a/modules/losses.py +++ b/modules/losses.py @@ -4,7 +4,7 @@ import DeepSparseCoding.utils.data_processing as dp -#def l2_flatness(z1, z2, z3, w): +#def l2_flatness(z1, z2, z3, weight): # """ # Minimized when a straight line can be drawn through [z1, z2, z3]. # Extended from equations 8 and 12 in @@ -34,12 +34,12 @@ def half_weight_norm_squared(weight_list): Keyword arguments: weight_list: List of torch variables Outputs: - w_norm_loss: 0.5 * sum of (1 - l2_norm(w))^2 for each w in weight_list + w_norm_loss: 0.5 * sum of (1 - l2_norm(weight))^2 for each weight in weight_list """ w_norm_list = [] - for w in weight_list: - reduc_dim = list(range(1, len(w.shape))) - w_norm = torch.sum(torch.pow(1 - torch.sqrt(torch.sum(tf.pow(w, 2.), axis=reduc_dim)), 2.)) + for weight in weight_list: + reduc_dim = list(range(1, len(weight.shape))) + w_norm = torch.sum(torch.pow(1 - torch.sqrt(torch.sum(tf.pow(weight, 2.), axis=reduc_dim)), 2.)) w_norm_list.append(w_norm) norm_loss = 0.5 * torch.sum(w_norm_list) return norm_loss @@ -51,9 +51,9 @@ def weight_decay(weight_list): Keyword arguments: weight_list: List of torch variables Outputs: - decay_loss: 0.5 * sum of w^2 for each w in weight_list + decay_loss: 0.5 * sum of weight^2 for each weight in weight_list """ - decay_loss = 0.5 * torch.sum([torch.sum(torch.pow(w, 2.)) for w in weight_list]) + decay_loss = 0.5 * torch.sum([torch.sum(torch.pow(weight, 2.)) for weight in weight_list]) return decay_loss @@ -88,12 +88,12 @@ def trace_covariance(latents): return -1 * trace -def weight_orthogonality(weights, stride=1, padding=0): +def weight_orthogonality(weight, stride=1, padding=0): """ - Returns l2 loss that is minimized when the weights are orthogonal + Returns l2 loss that is minimized when the weight are orthogonal Keyword arguments: - weights [torch tensor] layer weights, either fully connected or 2d convolutional + weight [torch tensor] layer weight, either fully connected or 2d convolutional stride [int] layer stride for convolutional layers padding [int] layer padding for convolutional layers @@ -106,20 +106,20 @@ def weight_orthogonality(weights, stride=1, padding=0): https://arxiv.org/abs/1911.12207 https://github.com/samaonline/Orthogonal-Convolutional-Neural-Networks """ - w_shape = weights.shape - if weights.ndim == 2: # fully-connected, [inputs, outputs] - loss = torch.norm(torch.matmul(weights.transpose(), weights) - torch.eye(w_shape[1])) - elif weights.ndim == 4: # convolutional, [output_channels, input_channels, height, width] + w_shape = weight.shape + if weight.ndim == 2: # fully-connected, [inputs, outputs] + loss = torch.norm(torch.matmul(weight.transpose(), weight) - torch.eye(w_shape[1])) + elif weight.ndim == 4: # convolutional, [output_channels, input_channels, height, width] out_channels, in_channels, in_height, in_width = w_shape - output = torch.conv2d(weights, weights, stride=stride, padding=padding) + output = torch.conv2d(weight, weight, stride=stride, padding=padding) out_height = output.shape[-2] out_width = output.shape[-1] target = torch.zeros((out_channels, out_channels, out_height, out_width), - device=weights.device) + device=weight.device) center_h = int(np.floor(out_height / 2)) center_w = int(np.floor(out_width / 2)) - target[:, :, center_h, center_w] = torch.eye(out_channels, device=weights.device) + target[:, :, center_h, center_w] = torch.eye(out_channels, device=weight.device) loss = torch.norm(output - target, p='fro') else: - assert False, (f'weights ndim must be 2 or 4, not {weights.ndim}') + assert False, (f'weight ndim must be 2 or 4, not {weight.ndim}') return loss diff --git a/modules/mlp_module.py b/modules/mlp_module.py index 727334d2..cebda2c3 100644 --- a/modules/mlp_module.py +++ b/modules/mlp_module.py @@ -1,10 +1,11 @@ +from collections import OrderedDict + import numpy as np import torch.nn as nn import torch.nn.functional as F from DeepSparseCoding.modules.activations import activation_picker from DeepSparseCoding.utils.run_utils import compute_conv_output_shape -from DeepSparseCoding.utils.data_processing import flatten_feature_map class MlpModule(nn.Module): @@ -81,19 +82,33 @@ def setup_module(self, params): else: self.pooling.append(nn.Identity()) # do nothing self.dropout.append(nn.Dropout(p=self.params.dropout_rate[layer_index])) + conv_module_dict = OrderedDict() + fc_module_dict = OrderedDict() + layer_zip = zip(self.params.layer_types, self.layers, self.act_funcs, self.pooling, + self.dropout) + for layer_idx, full_layer in enumerate(layer_zip): + for component_idx, layer_component in enumerate(full_layer[1:]): + component_id = f'{layer_idx:02}-{component_idx:02}' + if full_layer[0] == 'fc': + fc_module_dict[full_layer[0] + component_id] = layer_component + else: + conv_module_dict[full_layer[0] + component_id] = layer_component + self.conv_sequential = lambda x: x # identity by default + self.fc_sequential = lambda x: x # identity by default + if len(conv_module_dict) > 0: + self.conv_sequential = nn.Sequential(conv_module_dict) + if len(fc_module_dict) > 0: + self.fc_sequential = nn.Sequential(fc_module_dict) def preprocess_data(self, input_tensor): if self.params.layer_types[0] == 'fc': - input_tensor = flatten_feature_map(input_tensor) + input_tensor = input_tensor.view(input_tensor.size(0), -1) #flat return input_tensor def forward(self, x): - layer_zip = zip(self.dropout, self.pooling, self.act_funcs, - self.layers, self.params.layer_types) - for dropout, pooling, act_func, layer, layer_type in layer_zip: - if layer_type == 'fc': - x = flatten_feature_map(x) - x = dropout(pooling(act_func(layer(x)))) + x = self.conv_sequential(x) + x = x.view(x.size(0), -1) #flat + x = self.fc_sequential(x) return x def get_encodings(self, input_tensor): diff --git a/modules/pooling_module.py b/modules/pooling_module.py index 59265662..8b181cc2 100644 --- a/modules/pooling_module.py +++ b/modules/pooling_module.py @@ -1,23 +1,20 @@ import torch import torch.nn as nn -import torch.nn.functional as F - -import DeepSparseCoding.utils.data_processing as dp class PoolingModule(nn.Module): def setup_module(self, params): params.weight_decay = 0 # used by base model; pooling layer never has weight decay self.params = params - if self.params.layer_type == 'fc': + if self.params.layer_types[0] == 'fc': self.layer = nn.Linear( in_features=self.params.layer_channels[0], out_features=self.params.layer_channels[1], bias=False) - self.w = self.layer.weight - self.register_parameter('fc_pool_'+self.params.layer_name+'_w', self.layer.weight) + self.weight = self.layer.weight + #self.register_parameter('fc_pool_'+self.params.layer_name+'_w', self.layer.weight) - elif self.params.layer_type == 'conv': + elif self.params.layer_types[0] == 'conv': self.layer = nn.Conv2d( in_channels=self.params.layer_channels[0], out_channels=self.params.layer_channels[1], @@ -27,15 +24,15 @@ def setup_module(self, params): dilation=1, bias=False) nn.init.orthogonal_(self.layer.weight) # initialize to orthogonal matrix - self.w = self.layer.weight - self.register_parameter('conv_pool_'+self.params.layer_name+'_w', self.layer.weight) + self.weight = self.layer.weight + #self.register_parameter('conv_pool_'+self.params.layer_name+'_w', self.layer.weight) else: - assert False, ('layer_type parameter must be "fc", "conv", not %g'%(layer_type)) + assert False, ('layer_types[0] parameter must be "fc", "conv", not %g'%(layer_types[0])) def forward(self, x): - if self.params.layer_type == 'fc': - x = dp.flatten_feature_map(x) + if self.params.layer_types[0] == 'fc': + x = x.view(x.shape[0], -1) # flat return self.layer(x) def get_encodings(self, input_tensor): diff --git a/params/lca_cifar10_params.py b/params/lca_cifar10_params.py index 545f34bf..791a9d13 100644 --- a/params/lca_cifar10_params.py +++ b/params/lca_cifar10_params.py @@ -7,7 +7,7 @@ class params(BaseParams): def set_params(self): super(params, self).set_params() self.model_type = 'lca' - self.model_name = 'conv_lca_cifar10' + self.model_name = 'lca_cifar10' self.version = '0' self.dataset = 'cifar10' self.layer_types = ['conv'] @@ -19,6 +19,8 @@ def set_params(self): self.num_epochs = 500 self.train_logs_per_epoch = 6 self.renormalize_weights = True + self.layer_channels = 128 + self.kernel_size = 8 self.stride = 2 self.padding = 0 self.weight_decay = 0.0 @@ -28,13 +30,11 @@ def set_params(self): self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.8 self.dt = 0.001 - self.tau = 0.2 - self.num_steps = 75 + self.tau = 0.1#0.2 + self.num_steps = 37#75 self.rectify_a = True self.thresh_type = 'hard' - self.sparse_mult = 0.30 - self.kernel_size = 8 - self.num_latent = 512 + self.sparse_mult = 0.35#0.30 self.compute_helper_params() def compute_helper_params(self): @@ -42,6 +42,6 @@ def compute_helper_params(self): self.optimizer.milestones = [frac * self.num_epochs for frac in self.optimizer.lr_annealing_milestone_frac] self.step_size = self.dt / self.tau - self.out_channels = self.num_latent + self.out_channels = self.layer_channels self.num_pixels = 3072 self.in_channels = 3 diff --git a/params/lca_dsprites_params.py b/params/lca_dsprites_params.py index d498b076..a932ec63 100644 --- a/params/lca_dsprites_params.py +++ b/params/lca_dsprites_params.py @@ -12,6 +12,7 @@ def set_params(self): super(params, self).set_params() self.model_type = 'lca' self.model_name = 'lca_dsprites' + self.layer_types = ['fc'] self.version = '0' self.dataset = 'dsprites' self.standardize_data = False @@ -26,13 +27,13 @@ def set_params(self): self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.5 self.renormalize_weights = True + self.layer_channels = int(self.num_pixels*1.5) self.dt = 0.001 self.tau = 0.03 self.num_steps = 75 self.rectify_a = False self.thresh_type = 'soft' self.sparse_mult = 0.25 - self.num_latent = int(self.num_pixels*1.5) self.compute_helper_params() def compute_helper_params(self): diff --git a/params/lca_mlp_cifar10_params.py b/params/lca_mlp_cifar10_params.py index 9e539360..e90c5402 100644 --- a/params/lca_mlp_cifar10_params.py +++ b/params/lca_mlp_cifar10_params.py @@ -13,11 +13,11 @@ class shared_params(object): def __init__(self): self.model_type = 'ensemble' self.model_name = 'lca_mlp_cifar10' - self.version = '1' + self.version = '0' self.dataset = 'cifar10' self.standardize_data = True self.batch_size = 25 - self.num_epochs = 500 + self.num_epochs = 10#00 self.train_logs_per_epoch = 4 self.allow_parent_grads = True @@ -25,13 +25,15 @@ def __init__(self): class lca_params(LcaParams): def set_params(self): super(lca_params, self).set_params() - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) + for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'lca' + self.layer_name = 'lca' self.layer_types = ['conv'] self.weight_decay = 0.0 self.weight_lr = 0.001 self.renormalize_weights = True + self.layer_channels = 512 + self.kernel_size = 8 self.stride = 2 self.padding = 0 self.optimizer = types.SimpleNamespace() @@ -44,8 +46,7 @@ def set_params(self): self.rectify_a = True self.thresh_type = 'hard' self.sparse_mult = 0.30 - self.num_latent = 512 - self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/conv_lca_cifar10/logfiles/conv_lca_cifar10_v1.log' + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/conv_lca_cifar10/logfiles/lca_cifar10_v0.log' self.compute_helper_params() @@ -55,6 +56,7 @@ def set_params(self): for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'mlp' + self.layer_name = 'classifier' self.weight_lr = 2e-3 self.weight_decay = 1e-6 self.layer_types = ['fc'] @@ -62,7 +64,7 @@ def set_params(self): self.activation_functions = ['identity'] self.dropout_rate = [0.0] # probability of value being set to zero self.optimizer = types.SimpleNamespace() - self.optimizer.name = 'adam' + self.optimizer.name = 'sgd' self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.8 self.compute_helper_params() @@ -85,7 +87,7 @@ def set_params(self): lca_params_inst.stride, lca_params_inst.padding, dilation=1) - lca_output_shape = [lca_params_inst.num_latent, lca_output_height, lca_output_width] + lca_output_shape = [lca_params_inst.layer_channels, lca_output_height, lca_output_width] mlp_params_inst.layer_channels[0] = np.prod(lca_output_shape) self.ensemble_params = [lca_params_inst, mlp_params_inst] for key, value in shared_params().__dict__.items(): diff --git a/params/lca_mlp_mnist_params.py b/params/lca_mlp_mnist_params.py index 537e96ba..8c8c6d4d 100644 --- a/params/lca_mlp_mnist_params.py +++ b/params/lca_mlp_mnist_params.py @@ -29,6 +29,7 @@ def set_params(self): for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'lca' + self.layer_name = 'lca' self.weight_decay = 0.0 self.weight_lr = 0.1 self.optimizer = types.SimpleNamespace() @@ -36,13 +37,13 @@ def set_params(self): self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.5 self.renormalize_weights = True + self.layer_channels = 768 self.dt = 0.001 self.tau = 0.03 self.num_steps = 75 self.rectify_a = True self.thresh_type = 'soft' self.sparse_mult = 0.25 - self.num_latent = 768 self.checkpoint_boot_log = '' self.compute_helper_params() @@ -53,6 +54,7 @@ def set_params(self): for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'mlp' + self.layer_name = 'classifier' self.weight_lr = 1e-4 self.weight_decay = 0.0 self.layer_types = ['fc'] diff --git a/params/lca_mnist_params.py b/params/lca_mnist_params.py index e3a7ebc2..eb248c9d 100644 --- a/params/lca_mnist_params.py +++ b/params/lca_mnist_params.py @@ -35,10 +35,10 @@ def set_params(self): self.weight_lr = 0.001 self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.8 + self.layer_channels = 128 self.kernel_size = 8 self.stride = 2 self.padding = 0 - self.num_latent = 128 else: self.layer_types = ['fc'] self.model_type = 'lca' @@ -48,7 +48,7 @@ def set_params(self): self.weight_lr = 0.1 self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.5 - self.num_latent = 768 #self.num_pixels * 4 + self.layer_channels = 768 #self.num_pixels * 4 self.compute_helper_params() def compute_helper_params(self): diff --git a/params/lca_pool_cifar10_params.py b/params/lca_pool_cifar10_params.py index c1d1073b..e252448c 100644 --- a/params/lca_pool_cifar10_params.py +++ b/params/lca_pool_cifar10_params.py @@ -16,7 +16,7 @@ def __init__(self): self.dataset = 'cifar10' self.standardize_data = True self.batch_size = 25 - self.num_epochs = 5 + self.num_epochs = 10 self.train_logs_per_epoch = 4 self.allow_parent_grads = False @@ -26,10 +26,13 @@ def set_params(self): super(lca_params, self).set_params() for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'lca' + self.layer_name = 'lca_1' self.layer_types = ['conv'] self.weight_decay = 0.0 self.weight_lr = 0.001 self.renormalize_weights = True + self.layer_channels = 128 + self.kernel_size = 8 self.stride = 2 self.padding = 0 self.optimizer = types.SimpleNamespace() @@ -37,28 +40,27 @@ def set_params(self): self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.8 self.dt = 0.001 - self.tau = 0.2 - self.num_steps = 75 + self.tau = 0.1#0.2 + self.num_steps = 37#75 self.rectify_a = True self.thresh_type = 'hard' - self.sparse_mult = 0.30 - self.num_latent = 512 - self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/conv_lca_cifar10/logfiles/conv_lca_cifar10_v1.log' + self.sparse_mult = 0.35#0.30 + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_cifar10/logfiles/lca_cifar10_v0.log' self.compute_helper_params() -class pool_params(BaseParams): +class pooling_params(BaseParams): def set_params(self): - super(pool_params, self).set_params() + super(pooling_params, self).set_params() for key, value in shared_params().__dict__.items(): setattr(self, key, value) - self.model_type = 'pool' + self.model_type = 'pooling' self.layer_name = 'pool_1' self.weight_lr = 1e-3 - self.layer_type = 'conv' - self.layer_channels = [512, 10] - self.pool_ksize = 4 - self.pool_stride = 2 + self.layer_types = ['conv'] + self.layer_channels = [128, 32] + self.pool_ksize = 2 + self.pool_stride = 2 # non-overlapping self.optimizer = types.SimpleNamespace() self.optimizer.name = 'sgd' self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs @@ -66,7 +68,7 @@ def set_params(self): self.compute_helper_params() def compute_helper_params(self): - super(pool_params, self).compute_helper_params() + super(pooling_params, self).compute_helper_params() self.optimizer.milestones = [frac * self.num_epochs for frac in self.optimizer.lr_annealing_milestone_frac] @@ -75,10 +77,10 @@ class params(BaseParams): def set_params(self): super(params, self).set_params() lca_params_inst = lca_params() - pool_params_inst = pool_params() - if(pool_params_inst.layer_type == 'fc' and lca_params_inst.layer_type == 'conv'): + pooling_params_inst = pooling_params() + if(pooling_params_inst.layer_types[0] == 'fc' and lca_params_inst.layer_types[0] == 'conv'): lca_output_height = compute_conv_output_shape( - 32, # TODO: infer this? currently hardcoded CIFAR10 size + 32, lca_params_inst.kernel_size, lca_params_inst.stride, lca_params_inst.padding, @@ -90,7 +92,7 @@ def set_params(self): lca_params_inst.padding, dilation=1) lca_output_shape = [lca_params_inst.num_latent, lca_output_height, lca_output_width] - pool_params_inst.layer_channels[0] = np.prod(lca_output_shape) - self.ensemble_params = [lca_params_inst, pool_params_inst] + pooling_params_inst.layer_channels[0] = np.prod(lca_output_shape) + self.ensemble_params = [lca_params_inst, pooling_params_inst] for key, value in shared_params().__dict__.items(): setattr(self, key, value) diff --git a/params/lca_pool_lca_cifar10_params.py b/params/lca_pool_lca_cifar10_params.py new file mode 100644 index 00000000..56f4d0e2 --- /dev/null +++ b/params/lca_pool_lca_cifar10_params.py @@ -0,0 +1,133 @@ +import os +import types + +import numpy as np +import torch + +from DeepSparseCoding.params.base_params import BaseParams +from DeepSparseCoding.params.lca_cifar10_params import params as LcaParams +from DeepSparseCoding.utils.run_utils import compute_conv_output_shape + + +class shared_params(object): + def __init__(self): + self.model_type = 'ensemble' + self.model_name = 'lca_pool_lca_cifar10' + self.version = '0' + self.dataset = 'cifar10' + self.standardize_data = True + self.batch_size = 25 + self.num_epochs = 250 + self.train_logs_per_epoch = 4 + self.allow_parent_grads = False + + +class lca_1_params(LcaParams): + def set_params(self): + super(lca_1_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + self.model_type = 'lca' + self.layer_name = 'lca_1' + self.layer_types = ['conv'] + self.weight_decay = 0.0 + self.weight_lr = 0.001 + self.renormalize_weights = True + self.layer_channels = 128 + self.kernel_size = 8 + self.stride = 2 + self.padding = 0 + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.dt = 0.001 + self.tau = 0.1#0.2 + self.num_steps = 37#75 + self.rectify_a = True + self.thresh_type = 'hard' + self.sparse_mult = 0.35#0.30 + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_cifar10/logfiles/lca_cifar10_v0.log' + self.compute_helper_params() + + +class pooling_params(BaseParams): + def set_params(self): + super(pooling_params, self).set_params() + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) + self.model_type = 'pooling' + self.layer_name = 'pool_1' + self.weight_lr = 1e-3 + self.layer_types = ['conv'] + self.layer_channels = [128, 32] + self.pool_ksize = 2 + self.pool_stride = 2 # non-overlapping + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_cifar10/logfiles/lca_pool_cifar10_v0.log' + self.compute_helper_params() + + def compute_helper_params(self): + super(pooling_params, self).compute_helper_params() + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] + + +class lca_2_params(LcaParams): + def set_params(self): + super(lca_2_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + for key, value in lca_1_params().__dict__.items(): setattr(self, key, value) + self.layer_name = 'lca_2' + self.layer_channels = 256 + self.kernel_size = 6 + self.stride = 1 + self.padding = 0 + self.sparse_mult = 0.10 + self.checkpoint_boot_log = '' + self.compute_helper_params() + + +class params(BaseParams): + def set_params(self): + super(params, self).set_params() + lca_1_params_inst = lca_1_params() + pooling_params_inst = pooling_params() + lca_2_params_inst = lca_2_params() + lca_1_output_height = compute_conv_output_shape( + 32, + lca_1_params_inst.kernel_size, + lca_1_params_inst.stride, + lca_1_params_inst.padding, + dilation=1) + lca_1_output_width = compute_conv_output_shape( + 32, + lca_1_params_inst.kernel_size, + lca_1_params_inst.stride, + lca_1_params_inst.padding, + dilation=1) + pooling_output_height = compute_conv_output_shape( + lca_1_output_height, + pooling_params_inst.pool_ksize, + pooling_params_inst.pool_stride, + padding=0, + dilation=1) + pooling_output_width = compute_conv_output_shape( + lca_1_output_width, + pooling_params_inst.pool_ksize, + pooling_params_inst.pool_stride, + padding=0, + dilation=1) + lca_2_params_inst.data_shape = [ + int(pooling_params_inst.layer_channels[-1]), + int(pooling_output_height), + int(pooling_output_width)] + self.ensemble_params = [ + lca_1_params_inst, + pooling_params_inst, + lca_2_params_inst + ] + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) diff --git a/params/test_params.py b/params/test_params.py index 529c1522..9e78c701 100644 --- a/params/test_params.py +++ b/params/test_params.py @@ -56,7 +56,7 @@ def set_params(self): self.model_type = 'lca' self.weight_decay = 0.0 self.weight_lr = 0.1 - self.layer_type = 'fc' + self.layer_types = ['fc'] self.optimizer = types.SimpleNamespace() self.optimizer.name = 'sgd' self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs @@ -68,7 +68,7 @@ def set_params(self): self.rectify_a = True self.thresh_type = 'soft' self.sparse_mult = 0.25 - self.num_latent = 128 + self.layer_channels = 128 self.optimizer.milestones = [frac * self.num_epochs for frac in self.optimizer.lr_annealing_milestone_frac] self.step_size = self.dt / self.tau @@ -77,17 +77,37 @@ def set_params(self): #class conv_lca_params(lca_params): # def set_params(self): # super(conv_lca_params, self).set_params() -# self.layer_type = 'conv' +# self.layer_types = ['conv'] # self.kernel_size = 8 # self.stride = 2 # self.padding = 0 # self.optimizer.milestones = [frac * self.num_epochs # for frac in self.optimizer.lr_annealing_milestone_frac] # self.step_size = self.dt / self.tau -# self.out_channels = self.num_latent +# self.out_channels = self.layer_channels # self.in_channels = 1 +class pooling_params(BaseParams): + def set_params(self): + super(pooling_params, self).set_params() + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) + self.model_type = 'pooling' + self.layer_name = 'test_pool_1' + self.weight_lr = 1e-3 + self.layer_types = ['conv'] + self.layer_channels = [128, 32] + self.pool_ksize = 2 + self.pool_stride = 2 # non-overlapping + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] + + class mlp_params(BaseParams): def set_params(self): super(mlp_params, self).set_params() diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py index 96da0b9e..7f6d43bf 100644 --- a/tests/test_data_processing.py +++ b/tests/test_data_processing.py @@ -134,16 +134,6 @@ def test_reshape_data(self): self.assertEqual(reshaped_array.shape, expected_out_shape, err_msg) self.assertEqual(resh_num_examples, None, err_msg) - - def test_flatten_feature_map(self): - unflat_shape = [8, 4, 4, 3] - flat_shape = [8, 4*4*3] - shapes = [unflat_shape, flat_shape] - for shape in shapes: - test_map = torch.zeros(shape) - flat_map = dp.flatten_feature_map(test_map).numpy() - self.assertEqual(list(flat_map.shape), flat_shape) - def test_standardize(self): num_tolerance_decimals = 5 unflat_shape = [8, 4, 4, 3] diff --git a/tests/test_models.py b/tests/test_models.py index 67b80732..080cdd46 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -75,7 +75,7 @@ def test_lca_ensemble_gradients(self): models['ensemble'].setup(params['ensemble']) models['ensemble'].to(params['ensemble'].device) ensemble_state_dict = models['ensemble'].state_dict() - ensemble_state_dict['lca.w'] = models['lca'].w.clone() + ensemble_state_dict['lca.weight'] = models['lca'].weight.clone() models['ensemble'].load_state_dict(ensemble_state_dict) data, target = next(iter(train_loader)) train_data_batch = models['lca'].preprocess_data(data.to(params['lca'].device)) @@ -93,18 +93,18 @@ def test_lca_ensemble_gradients(self): ensemble_losses[0].backward() ensemble_losses[1].backward() lca_loss_val = lca_loss.cpu().detach().numpy() - lca_w_grad = models['lca'].w.grad.cpu().numpy() + lca_w_grad = models['lca'].weight.grad.cpu().numpy() ensemble_loss_val = ensemble_losses[0].cpu().detach().numpy() - ensemble_w_grad = models['ensemble'][0].w.grad.cpu().numpy() + ensemble_w_grad = models['ensemble'][0].weight.grad.cpu().numpy() assert lca_loss_val == ensemble_loss_val, (err_msg+'\n' +'Losses should be equal, but are lca={lca_loss_val} and ensemble={ensemble_loss_val}') assert np.all(lca_w_grad == ensemble_w_grad), (err_msg+'\nGrads should be equal, but are not.') - lca_pre_train_w = models['lca'].w.cpu().detach().numpy().copy() - ensemble_pre_train_w = models['ensemble'][0].w.cpu().detach().numpy().copy() + lca_pre_train_w = models['lca'].weight.cpu().detach().numpy().copy() + ensemble_pre_train_w = models['ensemble'][0].weight.cpu().detach().numpy().copy() run_utils.train_epoch(1, models['lca'], train_loader) run_utils.train_epoch(1, models['ensemble'], train_loader) - lca_w = models['lca'].w.cpu().detach().numpy().copy() - ensemble_w = models['ensemble'][0].w.cpu().detach().numpy().copy() + lca_w = models['lca'].weight.cpu().detach().numpy().copy() + ensemble_w = models['ensemble'][0].weight.cpu().detach().numpy().copy() assert np.all(lca_pre_train_w == ensemble_pre_train_w), (err_msg+'\n' +"lca & ensemble weights are not equal before one epoch of training") assert not np.all(lca_pre_train_w == lca_w), (err_msg+'\n' diff --git a/train_model.py b/train_model.py index 40b0414d..6a3bc5d6 100644 --- a/train_model.py +++ b/train_model.py @@ -7,6 +7,8 @@ ROOT_DIR = up(up(os.path.realpath(__file__))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) +import torch + import DeepSparseCoding.utils.loaders as loaders import DeepSparseCoding.utils.run_utils as run_utils import DeepSparseCoding.utils.dataset_utils as dataset_utils @@ -31,6 +33,7 @@ model = loaders.load_model(params.model_type) model.setup(params) model.to(params.device) +model.log_architecture_details() # Train model for epoch in range(1, model.params.num_epochs+1): @@ -38,7 +41,7 @@ # TODO: Ensemble models might not actually have a classification objective / need validation #if(model.params.model_type.lower() in ['mlp', 'ensemble']): # TODO: use to validation set here; test at the end of training # run_utils.test_epoch(epoch, model, test_loader) - model.log_info(f'Completed epoch {epoch}/{model.params.num_epochs}') + model.logger.log_string(f'Completed epoch {epoch}/{model.params.num_epochs}') print(f'Completed epoch {epoch}/{model.params.num_epochs}') # Final outputs @@ -46,7 +49,7 @@ tot_time=float(t1-t0) tot_images = model.params.num_epochs*len(train_loader.dataset) out_str = f'Training on {tot_images} images is complete. Total time was {tot_time} seconds.\n' -model.log_info(out_str) +model.logger.log_string(out_str) print('Training Complete\n') model.write_checkpoint() diff --git a/utils/data_processing.py b/utils/data_processing.py index 684864bc..933a6305 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -113,28 +113,6 @@ def check_all_same_shape(tensor_list): f'Tensor entry {index} in input list has shape {tensor.shape}, but should have shape {first_shape}') -def flatten_feature_map(feature_map): - """ - Flatten input tensor from [batch, c, y, x] to [batch, c * y * x] - - Keyword arguments: - feature_map: tensor with shape [batch, c, y, x] - - Returns: - reshaped_map: tensor with shape [batch, c * y * x] - """ - map_shape = feature_map.shape - if(len(map_shape) == 4): - (batch, c, y, x) = map_shape - prev_input_features = int(c * y * x) - resh_map = torch.reshape(feature_map, [batch, prev_input_features]) - elif(len(map_shape) == 2): - resh_map = feature_map - else: - raise ValueError('Input feature_map has incorrect ndims') - return resh_map - - def get_std_from_dataloader(loader, dataset_mean): """ TODO: Calculate the standard deviation from all entries in a pytorch data loader @@ -177,7 +155,7 @@ def get_mean_from_dataloader(loader): def center(data, samplewise=False, batch_size=100): """ Center image dataset to have zero mean - + Keyword arguments: data: [tensor] unnormalized data samplewise: [bool] if True, center each sample individually; if False, compute mean over entire batch @@ -200,7 +178,7 @@ def center(data, samplewise=False, batch_size=100): def standardize(data, eps=None, samplewise=False, batch_size=100): """ Standardize each image data to have zero mean and unit standard-deviation (z-score) - + This function uses population standard deviation data.sum() / N, where N = data.shape[0]. Keyword arguments: @@ -329,7 +307,7 @@ def get_weights_l2_norm(w, eps=1e-12): norms = torch.norm(w.flatten(start_dim=1), dim=-1, keepdim=True) else: assert False, (f'input w must have ndim = 2 or 4, not {w.ndim}') - if(torch.max(norms) <= eps): #TODO: Warnings + if(torch.max(norms) <= eps): #TODO: raise proper warnings print(f'Warning: input gradient is less than or equal to {eps}') norms = torch.max(norms, eps*torch.ones_like(norms)) # prevent div by 0 # TODO: Change to torch.maximum when it is stable norms = atleast_kd(norms, w.ndim) diff --git a/utils/file_utils.py b/utils/file_utils.py index 6dd4b7f8..4a587aba 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -3,6 +3,7 @@ import types import os from copy import deepcopy +from collections import OrderedDict import importlib import numpy as np @@ -27,6 +28,16 @@ def js_dumpstring(self, obj): """Dump json string with special CustomEncoder""" return js.dumps(obj, sort_keys=True, indent=2, cls=CustomEncoder) + def log_string(self, string): + """Log input string""" + now = time.localtime(time.time()) + time_str = time.strftime('%m/%d/%y %H:%M:%S', now) + out_str = '\n' + time_str + ' -- ' + str(string) + if(self.log_to_file): + self.file_obj.write(out_str) + else: + print(out_str) + def log_trainable_variables(self, name_list): """ Use logging to write names of trainable variables in model @@ -34,7 +45,7 @@ def log_trainable_variables(self, name_list): name_list: list containing variable names """ js_str = self.js_dumpstring(name_list) - self.log_info(''+js_str+'') + self.log_string(''+js_str+'') def log_params(self, params): """ @@ -54,17 +65,17 @@ def log_params(self, params): if('rand_state' in out_params.keys()): del out_params['rand_state'] js_str = self.js_dumpstring(out_params) - self.log_info(''+js_str+'') + self.log_string(''+js_str+'') - def log_info(self, string): - """Log input string""" - now = time.localtime(time.time()) - time_str = time.strftime('%m/%d/%y %H:%M:%S', now) - out_str = '\n' + time_str + ' -- ' + str(string) - if(self.log_to_file): - self.file_obj.write(out_str) - else: - print(out_str) + def log_stats(self, stat_dict): + """Log dictionary of training / testing statistics""" + js_str = self.js_dumpstring(stat_dict) + self.log_string(''+js_str+'') + + def log_info(self, info_dict): + """Log input dictionary in tags""" + js_str = self.js_dumpstring(info_dict) + self.log_string(''+js_str+'') def load_file(self, filename=None): """ @@ -199,3 +210,92 @@ def python_module_from_file(py_module_name, file_name): py_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(py_module) return py_module + +def summary_string(model, input_size, batch_size=2, device=torch.device('cuda:0'), dtype=torch.FloatTensor): + """ + Returns a string that summarizees the model architecture, including the number of parameters + and layer output sizes + + Code is modified from: + https://github.com/sksq96/pytorch-summary + + Keyword arguments: + model [torch module, module subclass, or EnsembleModel] model to summarize + input_size [tuple or list of tuples] must not include the batch dimension; if it is a list + of tuples then the architecture will be computed for each option + batch_size [positive int] how many images to feed into the model. + The default of 2 will ensure that batch norm works. + devie [torch.device] which device to run the test on + dtype [torch.dtype] for the artificially generated inputs + """ + def register_hook(module): + def hook(module, input, output): + class_name = str(module.__class__).split('.')[-1].split("'")[0] + module_idx = len(summary) + m_key = '%s-%i' % (class_name, module_idx + 1) + summary[m_key] = OrderedDict() + summary[m_key]['input_shape'] = list(input[0].size()) + summary[m_key]['input_shape'][0] = batch_size + if isinstance(output, (list, tuple)): + summary[m_key]['output_shape'] = [ + [-1] + list(o.size())[1:] for o in output + ] + else: + summary[m_key]['output_shape'] = list(output.size()) + summary[m_key]['output_shape'][0] = batch_size + params = 0 + if hasattr(module, 'weight') and hasattr(module.weight, 'size'): + params += torch.prod(torch.LongTensor(list(module.weight.size()))) + summary[m_key]['trainable'] = module.weight.requires_grad + if hasattr(module, 'bias') and hasattr(module.bias, 'size'): + params += torch.prod(torch.LongTensor(list(module.bias.size()))) + summary[m_key]['nb_params'] = params + summary[m_key]['gpu_mem'] = round(torch.cuda.memory_allocated(0)/1024**3, 1) + if len(list(module.children())) == 0: # only apply hooks at child modules to avoid applying them twice + hooks.append(module.register_forward_hook(hook)) + x = torch.rand(batch_size, *input_size).type(dtype).to(device=device) + summary = OrderedDict() # used within hook function to store properties + hooks = [] # used within hook function to store resgistered hooks + model.apply(register_hook) # recursively apply register_hook function to model and all children + model(x) # make a forward pass + for h in hooks: + h.remove() # remove the hooks so they are not used at run time + summary_str = '----------------------------------------------------------------\n' + line_new = '{:>20} {:>25} {:>15}'.format('Layer (type)', 'Output Shape', 'Param #') + summary_str += line_new + '\n' + summary_str += '================================================================\n' + total_params = 0 + total_output = 0 + trainable_params = 0 + for layer in summary: + line_new = '{:>20} {:>25} {:>15}'.format( + layer, + str(summary[layer]['output_shape']), + '{0:,}'.format(summary[layer]['nb_params']), + ) # input_shape, output_shape, trainable, nb_params + total_params += summary[layer]['nb_params'] + total_output += np.prod(summary[layer]['output_shape']) + if 'trainable' in summary[layer]: + if summary[layer]['trainable'] == True: + trainable_params += summary[layer]['nb_params'] + summary_str += line_new + '\n' + # assume 4 bytes/number (float on cuda). + total_input_size = abs(np.prod(input_size)) * batch_size * 4. / (1024 ** 2.) + total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients + total_params_size = abs(total_params * 4. / (1024 ** 2.)) + total_size = total_params_size + total_output_size + total_input_size + summary_str += '================================================================\n' + summary_str += f'Total params: {total_params}\n' + summary_str += f'Trainable params: {trainable_params}\n' + param_diff = total_params - trainable_params + summary_str += f'Non-trainable params: {param_diff}\n' + summary_str += '----------------------------------------------------------------\n' + summary_str += f'Input size (MB): {total_input_size:0.2f}\n' + summary_str += f'Forward/backward pass size (MB): {total_output_size:0.2f}\n' + summary_str += f'Params size (MB): {total_params_size:0.2f}\n' + summary_str += f'Estimated total size (MB): {total_size:0.2f}\n' + ## TODO: Update pytorch for this to work + #device_memory = torch.cuda.memory_summary(device, abbreviated=True) + #summary_str += f'Device memory allocated with batch of inputs (GB): {device_memory}\n' + summary_str += '----------------------------------------------------------------\n' + return summary_str, (total_params, trainable_params) diff --git a/utils/loaders.py b/utils/loaders.py index 192664f0..8b395e2a 100644 --- a/utils/loaders.py +++ b/utils/loaders.py @@ -44,10 +44,7 @@ def load_model_class(model_type): elif(model_type.lower() == 'lca'): py_module_name = 'LcaModel' file_name = os.path.join(*[dsc_dir, 'models', 'lca_model.py']) - #elif(model_type.lower() == 'conv_lca'): - # py_module_name = 'ConvLcaModel' - # file_name = os.path.join(*[dsc_dir, 'models', 'conv_lca_model.py']) - elif(model_type.lower() == 'pool'): + elif(model_type.lower() == 'pooling'): py_module_name = 'PoolingModel' file_name = os.path.join(*[dsc_dir, 'models', 'pooling_model.py']) elif(model_type.lower() == 'ensemble'): @@ -56,7 +53,7 @@ def load_model_class(model_type): else: accepted_names = [''.join(name.split('_')[:-1]) for name in get_module_list(dsc_dir)] assert False, ( - 'Acceptible model_types are %s, not %s'%(','.join(accepted_names), model_type)) + 'Acceptible model_types are %s, not %s'%('; '.join(accepted_names), model_type)) py_module = file_utils.python_module_from_file(py_module_name, file_name) py_module_class = getattr(py_module, py_module_name) return py_module_class @@ -74,10 +71,7 @@ def load_module(module_type): elif(module_type.lower() == 'lca'): py_module_name = 'LcaModule' file_name = os.path.join(*[dsc_dir, 'modules', 'lca_module.py']) - #elif(module_type.lower() == 'conv_lca'): - # py_module_name = 'ConvLcaModule' - # file_name = os.path.join(*[dsc_dir, 'modules', 'conv_lca_module.py']) - elif(module_type.lower() == 'pool'): + elif(module_type.lower() == 'pooling'): py_module_name = 'PoolingModule' file_name = os.path.join(*[dsc_dir, 'modules', 'pooling_module.py']) elif(module_type.lower() == 'ensemble'): @@ -86,7 +80,7 @@ def load_module(module_type): else: accepted_names = [''.join(name.split('_')[:-1]) for name in get_module_list(dsc_dir)] assert False, ( - 'Acceptible model_types are %s, not %s'%(','.join(accepted_names), module_type)) + 'Acceptible model_types are %s, not %s'%('; '.join(accepted_names), module_type)) py_module = file_utils.python_module_from_file(py_module_name, file_name) py_module_class = getattr(py_module, py_module_name) return py_module_class() diff --git a/utils/run_utils.py b/utils/run_utils.py index b68e62f9..a9d83756 100644 --- a/utils/run_utils.py +++ b/utils/run_utils.py @@ -22,7 +22,7 @@ def train_single_model(model, loss): model.optimizer.step() if(hasattr(model.params, 'renormalize_weights') and model.params.renormalize_weights): with torch.no_grad(): # tell autograd to not record this operation - model.w.div_(dp.get_weights_l2_norm(model.w)) + model.weight.div_(dp.get_weights_l2_norm(model.weight)) def train_epoch(epoch, model, loader): @@ -97,8 +97,7 @@ def test_epoch(epoch, model, loader, log_to_file=True): 'test_total':len(loader.dataset), 'test_accuracy':test_accuracy} if log_to_file: - js_str = model.js_dumpstring(stat_dict) - model.log_info(''+js_str+'') + model.logger.log_stats(stat_dict) else: return stat_dict From 4bc2db8f32c8877a56ab120c68481b2d06d0380e Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 3 Mar 2021 15:33:57 +0100 Subject: [PATCH 28/44] adds tag to architecture logging for easy retrieval --- models/base_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/models/base_model.py b/models/base_model.py index 2295333e..07cb2e32 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -22,9 +22,9 @@ def setup(self, params, logger=None): if logger is None: self.init_logging() self.log_params() + self.logger.log_info(self.get_env_details()) else: self.logger = logger - self.logger.log_info(self.get_env_details()) def load_params(self, params): """ @@ -142,13 +142,14 @@ def log_architecture_details(self): """ Log model architecture with computed output sizes and number of parameters for each layer """ - architecture_string = '\n'+summary_string( + architecture_string = '\n'+summary_string( self, input_size=tuple(self.params.data_shape), batch_size=self.params.batch_size, device=self.params.device, dtype=torch.FloatTensor )[0] + architecture_string += '\n' self.logger.log_string(architecture_string) def write_checkpoint(self, batch_step=None): From da4bf52a307021b0865614c147fa82dfa578afd9 Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 3 Mar 2021 15:35:14 +0100 Subject: [PATCH 29/44] minor linting change --- models/pooling_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/models/pooling_model.py b/models/pooling_model.py index 5ed4c9fb..093d0b1f 100644 --- a/models/pooling_model.py +++ b/models/pooling_model.py @@ -21,8 +21,7 @@ def get_total_loss(self, input_tuple): def loss_fn(model_output): output_loss = losses.trace_covariance(model_output) w_stride = self.params.pool_stride - w_padding = 0 - weight_loss = losses.weight_orthogonality(self.weight, stride=w_stride, padding=w_padding) + weight_loss = losses.weight_orthogonality(self.weight, stride=w_stride, padding=0) return output_loss + weight_loss input_tensor, input_label = input_tuple layer_output = self.forward(input_tensor) From 9080984c6b7ed3ecd24f223511c40c7fdc7f9848 Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 3 Mar 2021 15:38:14 +0100 Subject: [PATCH 30/44] updated trace loss to have minimum of 0 --- modules/losses.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/losses.py b/modules/losses.py index 1ecdd4b6..4176a10e 100644 --- a/modules/losses.py +++ b/modules/losses.py @@ -85,7 +85,8 @@ def trace_covariance(latents): num_batch, num_channels, latents_h, latents_w = latents.shape covariance = covariance / (latents_h * latents_w - 1.0) trace = torch.trace(covariance) - return -1 * trace + target = torch.trace(torch.eye(covariance.size(0), device=trace.device)) # should = trace.size[0] + return torch.norm(trace - target, p='fro') def weight_orthogonality(weight, stride=1, padding=0): @@ -108,7 +109,7 @@ def weight_orthogonality(weight, stride=1, padding=0): """ w_shape = weight.shape if weight.ndim == 2: # fully-connected, [inputs, outputs] - loss = torch.norm(torch.matmul(weight.transpose(), weight) - torch.eye(w_shape[1])) + loss = torch.norm(torch.mm(weight.T, weight) - torch.eye(w_shape[1], device=weight.device)) elif weight.ndim == 4: # convolutional, [output_channels, input_channels, height, width] out_channels, in_channels, in_height, in_width = w_shape output = torch.conv2d(weight, weight, stride=stride, padding=padding) From 9a8b0f2e6e72efca4f6063693347cf5bcb3ef91c Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 3 Mar 2021 15:42:13 +0100 Subject: [PATCH 31/44] full smt params; minor param bugfixes --- params/lca_mlp_cifar10_params.py | 3 +- params/lca_pool_cifar10_params.py | 2 +- params/lca_pool_lca_cifar10_params.py | 15 +- params/lca_pool_lca_pool_cifar10_params.py | 170 +++++++++++++++ .../lca_pool_lca_pool_mlp_cifar10_params.py | 193 ++++++++++++++++++ 5 files changed, 372 insertions(+), 11 deletions(-) create mode 100644 params/lca_pool_lca_pool_cifar10_params.py create mode 100644 params/lca_pool_lca_pool_mlp_cifar10_params.py diff --git a/params/lca_mlp_cifar10_params.py b/params/lca_mlp_cifar10_params.py index e90c5402..979696ca 100644 --- a/params/lca_mlp_cifar10_params.py +++ b/params/lca_mlp_cifar10_params.py @@ -53,8 +53,7 @@ def set_params(self): class mlp_params(MlpParams): def set_params(self): super(mlp_params, self).set_params() - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) + for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'mlp' self.layer_name = 'classifier' self.weight_lr = 2e-3 diff --git a/params/lca_pool_cifar10_params.py b/params/lca_pool_cifar10_params.py index e252448c..8a05fefb 100644 --- a/params/lca_pool_cifar10_params.py +++ b/params/lca_pool_cifar10_params.py @@ -91,7 +91,7 @@ def set_params(self): lca_params_inst.stride, lca_params_inst.padding, dilation=1) - lca_output_shape = [lca_params_inst.num_latent, lca_output_height, lca_output_width] + lca_output_shape = [lca_params_inst.layer_channels, lca_output_height, lca_output_width] pooling_params_inst.layer_channels[0] = np.prod(lca_output_shape) self.ensemble_params = [lca_params_inst, pooling_params_inst] for key, value in shared_params().__dict__.items(): diff --git a/params/lca_pool_lca_cifar10_params.py b/params/lca_pool_lca_cifar10_params.py index 56f4d0e2..ceb1037c 100644 --- a/params/lca_pool_lca_cifar10_params.py +++ b/params/lca_pool_lca_cifar10_params.py @@ -46,15 +46,14 @@ def set_params(self): self.rectify_a = True self.thresh_type = 'hard' self.sparse_mult = 0.35#0.30 - self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_cifar10/logfiles/lca_cifar10_v0.log' + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_cifar10/logfiles/lca_pool_cifar10_v0.log' self.compute_helper_params() -class pooling_params(BaseParams): +class pooling_1_params(BaseParams): def set_params(self): - super(pooling_params, self).set_params() - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) + super(pooling_1_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'pooling' self.layer_name = 'pool_1' self.weight_lr = 1e-3 @@ -70,7 +69,7 @@ def set_params(self): self.compute_helper_params() def compute_helper_params(self): - super(pooling_params, self).compute_helper_params() + super(pooling_1_params, self).compute_helper_params() self.optimizer.milestones = [frac * self.num_epochs for frac in self.optimizer.lr_annealing_milestone_frac] @@ -85,7 +84,7 @@ def set_params(self): self.kernel_size = 6 self.stride = 1 self.padding = 0 - self.sparse_mult = 0.10 + self.sparse_mult = 0.15 self.checkpoint_boot_log = '' self.compute_helper_params() @@ -94,7 +93,7 @@ class params(BaseParams): def set_params(self): super(params, self).set_params() lca_1_params_inst = lca_1_params() - pooling_params_inst = pooling_params() + pooling_params_inst = pooling_1_params() lca_2_params_inst = lca_2_params() lca_1_output_height = compute_conv_output_shape( 32, diff --git a/params/lca_pool_lca_pool_cifar10_params.py b/params/lca_pool_lca_pool_cifar10_params.py new file mode 100644 index 00000000..01053ca8 --- /dev/null +++ b/params/lca_pool_lca_pool_cifar10_params.py @@ -0,0 +1,170 @@ +import os +import types + +import numpy as np +import torch + +from DeepSparseCoding.params.base_params import BaseParams +from DeepSparseCoding.params.lca_cifar10_params import params as LcaParams +from DeepSparseCoding.utils.run_utils import compute_conv_output_shape + + +class shared_params(object): + def __init__(self): + self.model_type = 'ensemble' + self.model_name = 'lca_pool_lca_pool_cifar10' + self.version = '0' + self.dataset = 'cifar10' + self.standardize_data = True + self.batch_size = 25 + self.num_epochs = 150 + self.train_logs_per_epoch = 4 + self.allow_parent_grads = False + + +class lca_1_params(LcaParams): + def set_params(self): + super(lca_1_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + self.model_type = 'lca' + self.layer_name = 'lca_1' + self.layer_types = ['conv'] + self.weight_decay = 0.0 + self.weight_lr = 0.001 + self.renormalize_weights = True + self.layer_channels = 128 + self.kernel_size = 8 + self.stride = 2 + self.padding = 0 + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.dt = 0.001 + self.tau = 0.1#0.2 + self.num_steps = 37#75 + self.rectify_a = True + self.thresh_type = 'hard' + self.sparse_mult = 0.35#0.30 + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_lca_cifar10/logfiles/lca_pool_lca_cifar10_v0.log' + self.compute_helper_params() + + +class pooling_1_params(BaseParams): + def set_params(self): + super(pooling_1_params, self).set_params() + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) + self.model_type = 'pooling' + self.layer_name = 'pool_1' + self.weight_lr = 1e-3 + self.layer_types = ['conv'] + self.layer_channels = [128, 32] + self.pool_ksize = 2 + self.pool_stride = 2 # non-overlapping + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_lca_cifar10/logfiles/lca_pool_lca_cifar10_v0.log' + self.compute_helper_params() + + def compute_helper_params(self): + super(pooling_1_params, self).compute_helper_params() + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] + + +class lca_2_params(LcaParams): + def set_params(self): + super(lca_2_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + for key, value in lca_1_params().__dict__.items(): setattr(self, key, value) + self.layer_name = 'lca_2' + self.layer_channels = 256 + self.kernel_size = 6 + self.stride = 1 + self.padding = 0 + self.sparse_mult = 0.15 + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_lca_cifar10/logfiles/lca_pool_lca_cifar10_v0.log' + self.compute_helper_params() + +class pooling_2_params(BaseParams): + def set_params(self): + super(pooling_2_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + for key, value in pooling_1_params().__dict__.items(): setattr(self, key, value) + self.layer_name = 'pool_2' + self.weight_lr = 1e-3 + self.layer_types = ['fc'] + self.layer_channels = [None, 64] + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.checkpoint_boot_log = '' + self.compute_helper_params() + + def compute_helper_params(self): + super(pooling_2_params, self).compute_helper_params() + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] + + +class params(BaseParams): + def set_params(self): + super(params, self).set_params() + lca_1_params_inst = lca_1_params() + pooling_1_params_inst = pooling_1_params() + lca_2_params_inst = lca_2_params() + pooling_2_params_inst = pooling_2_params() + lca_1_output_height = compute_conv_output_shape( + 32, + lca_1_params_inst.kernel_size, + lca_1_params_inst.stride, + lca_1_params_inst.padding, + dilation=1) + lca_1_output_width = compute_conv_output_shape( + 32, + lca_1_params_inst.kernel_size, + lca_1_params_inst.stride, + lca_1_params_inst.padding, + dilation=1) + pooling_1_output_height = compute_conv_output_shape( + lca_1_output_height, + pooling_1_params_inst.pool_ksize, + pooling_1_params_inst.pool_stride, + padding=0, + dilation=1) + pooling_1_output_width = compute_conv_output_shape( + lca_1_output_width, + pooling_1_params_inst.pool_ksize, + pooling_1_params_inst.pool_stride, + padding=0, + dilation=1) + lca_2_params_inst.data_shape = [ + int(pooling_1_params_inst.layer_channels[-1]), + int(pooling_1_output_height), + int(pooling_1_output_width)] + lca_2_output_height = compute_conv_output_shape( + pooling_1_output_height, + lca_2_params_inst.kernel_size, + lca_2_params_inst.stride, + lca_2_params_inst.padding, + dilation=1) + lca_2_output_width = compute_conv_output_shape( + pooling_1_output_width, + lca_2_params_inst.kernel_size, + lca_2_params_inst.stride, + lca_2_params_inst.padding, + dilation=1) + lca_2_flat_dim = lca_2_params_inst.layer_channels*lca_2_output_height*lca_2_output_width + pooling_2_params_inst.layer_channels[0] = lca_2_flat_dim + self.ensemble_params = [ + lca_1_params_inst, + pooling_1_params_inst, + lca_2_params_inst, + pooling_2_params_inst + ] + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) diff --git a/params/lca_pool_lca_pool_mlp_cifar10_params.py b/params/lca_pool_lca_pool_mlp_cifar10_params.py new file mode 100644 index 00000000..f4b322df --- /dev/null +++ b/params/lca_pool_lca_pool_mlp_cifar10_params.py @@ -0,0 +1,193 @@ +import os +import types + +import numpy as np +import torch + +from DeepSparseCoding.params.base_params import BaseParams +from DeepSparseCoding.params.lca_cifar10_params import params as LcaParams +from DeepSparseCoding.params.mlp_mnist_params import params as MlpParams +from DeepSparseCoding.utils.run_utils import compute_conv_output_shape + + +class shared_params(object): + def __init__(self): + self.model_type = 'ensemble' + self.model_name = 'lca_pool_lca_pool_mlp_cifar10' + self.version = '0' + self.dataset = 'cifar10' + self.standardize_data = True + self.batch_size = 25 + self.num_epochs = 150 + self.train_logs_per_epoch = 4 + self.allow_parent_grads = False + + +class lca_1_params(LcaParams): + def set_params(self): + super(lca_1_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + self.model_type = 'lca' + self.layer_name = 'lca_1' + self.layer_types = ['conv'] + self.weight_decay = 0.0 + self.weight_lr = 0#1e-3 + self.renormalize_weights = True + self.layer_channels = 128 + self.kernel_size = 8 + self.stride = 2 + self.padding = 0 + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.dt = 0.001 + self.tau = 0.2 + self.num_steps = 75 + self.rectify_a = True + self.thresh_type = 'hard' + self.sparse_mult = 0.35#0.30 + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_lca_pool_cifar10/logfiles/lca_pool_lca_pool_cifar10_v0.log' + self.compute_helper_params() + + +class pooling_1_params(BaseParams): + def set_params(self): + super(pooling_1_params, self).set_params() + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) + self.model_type = 'pooling' + self.layer_name = 'pool_1' + self.weight_lr = 0#1e-3 + self.layer_types = ['conv'] + self.layer_channels = [128, 32] + self.pool_ksize = 2 + self.pool_stride = 2 # non-overlapping + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_lca_pool_cifar10/logfiles/lca_pool_lca_pool_cifar10_v0.log' + self.compute_helper_params() + + def compute_helper_params(self): + super(pooling_1_params, self).compute_helper_params() + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] + + +class lca_2_params(LcaParams): + def set_params(self): + super(lca_2_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + for key, value in lca_1_params().__dict__.items(): setattr(self, key, value) + self.layer_name = 'lca_2' + self.weight_lr = 0#1e-3 + self.layer_channels = 256 + self.kernel_size = 6 + self.stride = 1 + self.padding = 0 + self.sparse_mult = 0.20 + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_lca_pool_cifar10/logfiles/lca_pool_lca_pool_cifar10_v0.log' + self.compute_helper_params() + +class pooling_2_params(BaseParams): + def set_params(self): + super(pooling_2_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + for key, value in pooling_1_params().__dict__.items(): setattr(self, key, value) + self.layer_name = 'pool_2' + self.weight_lr = 0#1e-3 + self.layer_types = ['fc'] + self.layer_channels = [None, 64] + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_lca_pool_cifar10/logfiles/lca_pool_lca_pool_cifar10_v0.log' + self.compute_helper_params() + + def compute_helper_params(self): + super(pooling_2_params, self).compute_helper_params() + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] + + +class mlp_params(MlpParams): + def set_params(self): + super(mlp_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + self.model_type = 'mlp' + self.layer_name = 'classifier' + self.weight_lr = 2e-3 + self.weight_decay = 1e-6 + self.layer_types = ['fc'] + self.layer_channels = [64, 10] + self.activation_functions = ['identity'] + self.dropout_rate = [0.0] # probability of value being set to zero + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.compute_helper_params() + + +class params(BaseParams): + def set_params(self): + super(params, self).set_params() + lca_1_params_inst = lca_1_params() + pooling_1_params_inst = pooling_1_params() + lca_2_params_inst = lca_2_params() + pooling_2_params_inst = pooling_2_params() + mlp_params_inst = mlp_params() + lca_1_output_height = compute_conv_output_shape( + 32, + lca_1_params_inst.kernel_size, + lca_1_params_inst.stride, + lca_1_params_inst.padding, + dilation=1) + lca_1_output_width = compute_conv_output_shape( + 32, + lca_1_params_inst.kernel_size, + lca_1_params_inst.stride, + lca_1_params_inst.padding, + dilation=1) + pooling_1_output_height = compute_conv_output_shape( + lca_1_output_height, + pooling_1_params_inst.pool_ksize, + pooling_1_params_inst.pool_stride, + padding=0, + dilation=1) + pooling_1_output_width = compute_conv_output_shape( + lca_1_output_width, + pooling_1_params_inst.pool_ksize, + pooling_1_params_inst.pool_stride, + padding=0, + dilation=1) + lca_2_params_inst.data_shape = [ + int(pooling_1_params_inst.layer_channels[-1]), + int(pooling_1_output_height), + int(pooling_1_output_width)] + lca_2_output_height = compute_conv_output_shape( + pooling_1_output_height, + lca_2_params_inst.kernel_size, + lca_2_params_inst.stride, + lca_2_params_inst.padding, + dilation=1) + lca_2_output_width = compute_conv_output_shape( + pooling_1_output_width, + lca_2_params_inst.kernel_size, + lca_2_params_inst.stride, + lca_2_params_inst.padding, + dilation=1) + lca_2_flat_dim = lca_2_params_inst.layer_channels*lca_2_output_height*lca_2_output_width + pooling_2_params_inst.layer_channels[0] = lca_2_flat_dim + self.ensemble_params = [ + lca_1_params_inst, + pooling_1_params_inst, + lca_2_params_inst, + pooling_2_params_inst, + mlp_params_inst + ] + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) From d4c7fb1b37e193e80474109fc4b17ecd4d030bb7 Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 3 Mar 2021 15:44:09 +0100 Subject: [PATCH 32/44] updates standardization preprocessing standardization can now use the dataset mean & std dataset_utils now outputs original dataset mean & std --- train_model.py | 2 +- utils/data_processing.py | 51 ++++++++++++++++++++++++++++++++++------ utils/dataset_utils.py | 39 +++++++++++++++++++----------- 3 files changed, 70 insertions(+), 22 deletions(-) diff --git a/train_model.py b/train_model.py index 6a3bc5d6..b84c7c66 100644 --- a/train_model.py +++ b/train_model.py @@ -25,7 +25,7 @@ params = loaders.load_params_file(param_file) # Load data -train_loader, val_loader, test_loader, data_stats = dataset_utils.load_dataset(params) +train_loader, val_loader, test_loader, data_stats = dataset_utils.load_dataset(params)[:4] for key, value in data_stats.items(): setattr(params, key, value) diff --git a/utils/data_processing.py b/utils/data_processing.py index 933a6305..933ee369 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -133,6 +133,27 @@ def get_std_from_dataloader(loader, dataset_mean): raise NotImplementedError +def get_std_from_dataloader(loader, dataset_mean=None): + """ + Calculate the standard deviation from the mean for all entries in a pytorch data loader + + Keyword arguments: + loader: [pytorch DataLoader] containing the full dataset. + This function assumes there is always a target label, i.e. loader.next() returns (data, target) + + Outputs: + dataset_std: [torch tensor] of the same shape as a single dataset sample + """ + if dataset_mean is None: + dataset_mean = get_mean_from_dataloader(loader) + dataset_std = torch.zeros(next(iter(loader))[0].shape[1:]) # don't include batch dimension + num_batches = 0 + for data, target in loader: + dataset_std += torch.std(data - dataset_mean, dim=0, keepdim=False) + num_batches += 1 + return dataset_std / num_batches + + def get_mean_from_dataloader(loader): """ Calculate the mean datapoint from all entries in a pytorch data loader @@ -175,7 +196,7 @@ def center(data, samplewise=False, batch_size=100): return data, data_mean -def standardize(data, eps=None, samplewise=False, batch_size=100): +def standardize(data, eps=None, samplewise=False, batch_size=100, sample_mean=None, sample_std=None): """ Standardize each image data to have zero mean and unit standard-deviation (z-score) @@ -187,6 +208,10 @@ def standardize(data, eps=None, samplewise=False, batch_size=100): defaults to 1/sqrt(data_dim) where data_dim is the total size of a data vector samplewise: [bool] if True, standardize each sample individually; akin to contrast-normalization if False, compute mean and std over entire batch + sample_mean: [tensor] to be used as the dataset mean instead of calculating it, + it should be the same shape as a single data element + sample_std: [tensor] to be used as the dataset mean instead of calculating it, + it should be the same shape as a single data element Outputs: data: [tensor] normalized data @@ -196,12 +221,23 @@ def standardize(data, eps=None, samplewise=False, batch_size=100): data, orig_shape = reshape_data(data, flatten=True)[:2] num_examples = data.shape[0] if(samplewise): # standardize each input sample individually - data_axis = tuple(range(data.ndim)[1:]) - data_mean = torch.mean(data, dim=data_axis, keepdim=True) - data_true_std = torch.std(data, unbiased=False, dim=data_axis, keepdim=True) + if sample_mean is None: + data_mean = torch.mean(data, dim=1, keepdim=True) # [num_examples, 1] + else: + data_mean = sample_mean.mean().repeat(num_examples, 1) + if sample_std is None: + data_true_std = torch.std(data - data_mean, unbiased=False, dim=1, keepdim=True) + else: + data_true_std = sample_std.mean().repeat(num_examples, 1) else: # standardize the entire population - data_mean = torch.mean(data, dim=0, keepdim=True) - data_true_std = torch.std(data, dim=0, unbiased=False, keepdim=True) + if sample_mean is None: + data_mean = torch.mean(data, dim=0, keepdim=True) # [1, sample_dim] + else: + data_mean = sample_mean.view(1, -1) + if sample_std is None: + data_true_std = torch.std(data - data_mean, dim=0, unbiased=False, keepdim=True) + else: + data_true_std = sample_std.view(1, -1) data_std = torch.where(data_true_std >= eps, data_true_std, eps*torch.ones_like(data_true_std)) data = (data - data_mean) / data_std if(data.shape != orig_shape): @@ -490,7 +526,8 @@ def covariance(tensor): """ if tensor.ndim == 2: # [num_batch, num_channels] centered_tensor = tensor - tensor.mean(dim=0, keepdim=True) # subtract mean vector - corvariance = torch.dot(centered_tensor.T, centered_tensor) # sum over batch + covariance = torch.mm(centered_tensor.T, centered_tensor) # sum over batch + covariance = covariance / (centered_tensor.shape[0]-1) elif tensor.ndim == 4: # [num_batch, num_channels, elements_h, elements_w] num_batch, num_channels, elements_h, elements_w = tensor.shape flat_map = tensor.view(num_batch, num_channels, elements_h * elements_w) diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py index ef6b38a0..8a6a6f7e 100644 --- a/utils/dataset_utils.py +++ b/utils/dataset_utils.py @@ -61,7 +61,8 @@ def __len__(self): def load_dataset(params): - new_params = {} + data_stats = {} # dataset statistics + extra_outputs = {} # depending on parameters may include dataset_mean, dataset_std, etc if(params.dataset.lower() == 'mnist'): preprocessing_pipeline = [ transforms.ToTensor(), @@ -116,10 +117,23 @@ def load_dataset(params): dataset_mean_image = dp.get_mean_from_dataloader(data_loader) preprocessing_pipeline.append( transforms.Lambda(lambda x: x - dataset_mean_image)) + extra_outputs['dataset_mean_image'] = dataset_mean_image if params.standardize_data: + dataset = torchvision.datasets.CIFAR10(**kwargs) + data_loader = torch.utils.data.DataLoader(dataset, batch_size=params.batch_size, + shuffle=False, num_workers=0, pin_memory=True) + dataset_mean_image = dp.get_mean_from_dataloader(data_loader) + extra_outputs['dataset_mean_image'] = dataset_mean_image + dataset_std_image = dp.get_std_from_dataloader(data_loader, dataset_mean_image) + extra_outputs['dataset_std_image'] = dataset_std_image preprocessing_pipeline.append( transforms.Lambda( - lambda x: dp.standardize(x, eps=params.eps, samplewise=True, batch_size=params.batch_size)[0] + lambda x: dp.standardize(x, + eps=params.eps, + samplewise=False, + batch_size=params.batch_size, + sample_mean=dataset_mean_image, + sample_std=dataset_std_image)[0] ) ) if params.rescale_data_to_one: @@ -131,8 +145,7 @@ def load_dataset(params): kwargs['train'] = False testset = torchvision.datasets.CIFAR10(**kwargs) num_train = len(dataset) - params.num_validation - trainset, valset = torch.utils.data.random_split(dataset, - [num_train, params.num_validation]) + trainset, valset = torch.utils.data.random_split(dataset, [num_train, params.num_validation]) train_loader = torch.utils.data.DataLoader(trainset, batch_size=params.batch_size, shuffle=params.shuffle_data, num_workers=0, pin_memory=True) val_loader = torch.utils.data.DataLoader(valset, batch_size=params.batch_size, @@ -173,22 +186,20 @@ def load_dataset(params): num_workers=0, pin_memory=False) val_loader = None test_loader = None - new_params["num_pixels"] = params.data_edge_size**2 else: assert False, (f'Supported datasets are ["mnist", "dsprites", "synthetic"], not {dataset_name}') - new_params = {} - new_params['epoch_size'] = len(train_loader.dataset) + data_stats['epoch_size'] = len(train_loader.dataset) if(not hasattr(params, 'num_val_images')): if val_loader is None: - new_params['num_val_images'] = 0 + data_stats['num_val_images'] = 0 else: - new_params['num_val_images'] = len(val_loader.dataset) + data_stats['num_val_images'] = len(val_loader.dataset) if(not hasattr(params, 'num_test_images')): if test_loader is None: - new_params['num_test_images'] = 0 + data_stats['num_test_images'] = 0 else: - new_params['num_test_images'] = len(test_loader.dataset) - new_params['data_shape'] = list(next(iter(train_loader))[0].shape)[1:] - new_params['num_pixels'] = np.prod(new_params['data_shape']) - return (train_loader, val_loader, test_loader, new_params) + data_stats['num_test_images'] = len(test_loader.dataset) + data_stats['data_shape'] = list(next(iter(train_loader))[0].shape)[1:] + data_stats['num_pixels'] = np.prod(data_stats['data_shape']) + return (train_loader, val_loader, test_loader, data_stats, extra_outputs) From ed00ea124ad34f6e45155295944c5b684c692719 Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 3 Mar 2021 15:46:26 +0100 Subject: [PATCH 33/44] notebook for visualizing SMT outptus & weights --- notebooks/visualize_pooling_weights.ipynb | 1297 +++++++++++++++++++++ 1 file changed, 1297 insertions(+) create mode 100644 notebooks/visualize_pooling_weights.ipynb diff --git a/notebooks/visualize_pooling_weights.ipynb b/notebooks/visualize_pooling_weights.ipynb new file mode 100644 index 00000000..f459cf22 --- /dev/null +++ b/notebooks/visualize_pooling_weights.ipynb @@ -0,0 +1,1297 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "\n", + "ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))\n", + "if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)\n", + "\n", + "import scipy\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import proplot as plot\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.gridspec as gridspec\n", + "from matplotlib.colors import LinearSegmentedColormap\n", + "\n", + "from DeepSparseCoding.utils.file_utils import Logger\n", + "import DeepSparseCoding.utils.run_utils as run_utils\n", + "import DeepSparseCoding.utils.dataset_utils as dataset_utils\n", + "import DeepSparseCoding.utils.loaders as loaders\n", + "import DeepSparseCoding.utils.plot_functions as pf\n", + "import DeepSparseCoding.utils.data_processing as dp\n", + "from DeepSparseCoding.params.lca_cifar10_params import params as LcaParams" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "workspace_dir = '/mnt/qb/bethge/dpaiton/'\n", + "model_name = 'lca_pool_lca_pool_cifar10'\n", + "model_version = '0'\n", + "log_file = workspace_dir + os.path.join(*['Projects', model_name, 'logfiles', f'{model_name}_v{model_version}.log'])\n", + "logger = Logger(log_file, overwrite=False)\n", + "log_text = logger.load_file()\n", + "params = logger.read_params(log_text)[-1]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "model_stats = logger.read_stats(log_text)\n", + "x_key = \"epoch\"\n", + "y_keys = [key for key in list(model_stats.keys()) if 'test_' not in key]\n", + "stats_fig = pf.plot_stats(model_stats, x_key, y_keys=y_keys)\n", + "\n", + "if 'test_epoch' in list(model_stats.keys()):\n", + " x_key = \"test_epoch\"\n", + " y_keys = [key for key in list(model_stats.keys()) if 'test_' in key]\n", + " test_stats_fig = pf.plot_stats(model_stats, x_key, y_keys=y_keys)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = loaders.load_model(params.model_type)\n", + "model.setup(params, logger)\n", + "model.to(params.device)\n", + "model_state_str = model.load_checkpoint()\n", + "print(model_state_str)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_loader, val_loader, test_loader, data_stats, data_mean_std = dataset_utils.load_dataset(params)\n", + "train_mean_image = data_mean_std['dataset_mean_image'].to(model.params.device)\n", + "train_std_image = data_mean_std['dataset_std_image'].to(model.params.device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lca_weights = model.lca_1.weight.detach().cpu().numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "def normalize_data_with_max(data):\n", + " \"\"\"\n", + " Normalize data by dividing by abs(max(data))\n", + " If abs(max(data)) is zero, then output is zero\n", + " Inputs:\n", + " data: [np.ndarray] data to be normalized\n", + " Outputs:\n", + " norm_data: [np.ndarray] normalized data\n", + " data_max: [float] max that was divided out\n", + " \"\"\"\n", + " data_max = np.max(np.abs(data), axis=(1,2), keepdims=True)\n", + " norm_data = np.divide(data, data_max, out=np.zeros_like(data), where=data_max!=0)\n", + " return norm_data, data_max\n", + "\n", + "def pad_matrix_to_image(matrix, pad_size=0, pad_value=0, normalize=False):\n", + " if normalize:\n", + " #matrix = normalize_data_with_max(matrix)[0]\n", + " matrix = dp.rescale_data_to_one(torch.from_numpy(matrix), eps=1e-10, samplewise=True)[0].numpy()\n", + " num_weights, img_c, img_h, img_w = matrix.shape\n", + " #if img_c == 1:\n", + " # matrix = matrix.squeeze()\n", + " #else:\n", + " # # TODO: separate channels, pad each individual one, then recombine.\n", + " # assert False, (f'Multiple color channels are not currently supported') \n", + " num_extra_images = int(np.ceil(np.sqrt(num_weights))**2 - num_weights)\n", + " matrices = []\n", + " for channel_idx in range(img_c):\n", + " channel_matrix = matrix[:, channel_idx, ...].copy()\n", + " if num_extra_images > 0:\n", + " channel_matrix = np.concatenate(\n", + " [channel_matrix, np.zeros((num_extra_images, img_h, img_w))], axis=0)\n", + " channel_matrix = np.pad(channel_matrix,\n", + " pad_width=((0,0), (num_pad_pix, num_pad_pix), (num_pad_pix, num_pad_pix)),\n", + " mode='constant', constant_values=pad_value)\n", + " padded_img_h, padded_img_w = channel_matrix.shape[1:]\n", + " num_edge_tiles = int(np.sqrt(channel_matrix.shape[0]))\n", + " tiles = channel_matrix.reshape(num_edge_tiles, num_edge_tiles, padded_img_h, padded_img_w)\n", + " tiles = tiles.swapaxes(1, 2)\n", + " matrices.append(tiles.reshape(num_edge_tiles * padded_img_h, num_edge_tiles * padded_img_w))\n", + " padded_matrix = np.stack(matrices, axis=0) # channel dim first\n", + " return padded_matrix\n", + " \n", + "def plot_matrix(matrix, title='', cmap=None):\n", + " fig, ax = plot.subplots(figsize=(10,10))\n", + " ax = pf.clear_axis(ax)\n", + " ax.imshow(matrix, cmap=cmap)#, vmin=0.0, vmax=1.0)#, cmap='greys_r')\n", + " ax.format(title=title)\n", + " plot.show()\n", + " return fig\n", + "\n", + "pad_value = 0.5\n", + "num_pad_pix = 1\n", + "padded_matrix = pad_matrix_to_image(lca_weights, num_pad_pix, pad_value, normalize=True)\n", + "fig = plot_matrix(np.transpose(padded_matrix, axes=[1, 2, 0]), title=f'{model.params.model_name} weights')\n", + "fig.savefig(\n", + " f'{model.params.disp_dir}/weights_plot_matrix.png',\n", + " transparent=False,\n", + " bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def rgb_to_gray(rgb):\n", + " num, chan, height, width = rgb.shape\n", + " gray = np.zeros((num, 1, height, width))\n", + " for neuron_idx in range(num):\n", + " gray[neuron_idx, ...] = 0.2125 * rgb[neuron_idx, 0, ...]\n", + " gray[neuron_idx, ...] += 0.7154 * rgb[neuron_idx, 1, ...]\n", + " gray[neuron_idx, ...] += 0.0721 * rgb[neuron_idx, 2, ...]\n", + " return gray\n", + "\n", + "gray_lca_weights = rgb_to_gray(lca_weights)\n", + "pad_value = 0.5\n", + "num_pad_pix = 1\n", + "padded_matrix = pad_matrix_to_image(gray_lca_weights, num_pad_pix, pad_value, normalize=True)\n", + "fig = plot_matrix(np.squeeze(np.transpose(padded_matrix, axes=[1, 2, 0])), title=f'{model.params.model_name} weights', cmap='grays_r')\n", + "fig.savefig(\n", + " f'{model.params.disp_dir}/weights_grayscale_plot_matrix.png',\n", + " transparent=False,\n", + " bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_gaussian(shape, mean, cov):\n", + " \"\"\"\n", + " Generate a Gaussian PDF from given mean & cov\n", + " Inputs:\n", + " shape: [tuple] specifying (num_rows, num_cols)\n", + " mean: [np.ndarray] of shape (2,) specifying the 2-D Gaussian center\n", + " cov: [np.ndarray] of shape (2,2) specifying the 2-D Gaussian covariance matrix\n", + " Outputs:\n", + " tuple containing (Gaussian PDF, grid_points used to generate PDF)\n", + " grid_points are specified as a tuple of (y,x) points\n", + " \"\"\"\n", + " (y_size, x_size) = shape\n", + " y = np.linspace(0, y_size, np.int32(np.floor(y_size)))\n", + " x = np.linspace(0, x_size, np.int32(np.floor(x_size)))\n", + " y, x = np.meshgrid(y, x)\n", + " pos = np.empty(x.shape + (2,)) #x.shape == y.shape\n", + " pos[:, :, 0] = y; pos[:, :, 1] = x\n", + " gauss = scipy.stats.multivariate_normal(mean, cov)\n", + " return (gauss.pdf(pos), (y,x))\n", + "\n", + "\n", + "def gaussian_fit(pyx):\n", + " \"\"\"\n", + " Compute the expected mean & covariance matrix for a 2-D gaussian fit of input distribution\n", + " Inputs:\n", + " pyx: [np.ndarray] of shape [num_rows, num_cols] that indicates the probability function to fit\n", + " Outputs:\n", + " mean: [np.ndarray] of shape (2,) specifying the 2-D Gaussian center\n", + " cov: [np.ndarray] of shape (2,2) specifying the 2-D Gaussian covariance matrix\n", + " \"\"\"\n", + " assert pyx.ndim == 2, (\n", + " \"Input must have 2 dimensions specifying [num_rows, num_cols]\")\n", + " mean = np.zeros((1,2), dtype=np.float32) # [mu_y, mu_x]\n", + " for idx in np.ndindex(pyx.shape): # [y, x] ticks columns (x) first, then rows (y)\n", + " mean += np.asarray([pyx[idx]*idx[0], pyx[idx]*idx[1]])[None,:]\n", + " cov = np.zeros((2,2), dtype=np.float32)\n", + " for idx in np.ndindex(pyx.shape): # ticks columns first, then rows\n", + " cov += np.dot((idx-mean).T, (idx-mean))*pyx[idx] # typically an outer-product\n", + " return (np.squeeze(mean), cov)\n", + "\n", + "\n", + "def get_gauss_fit(prob_map, num_attempts=1, perc_mean=0.33):\n", + " \"\"\"\n", + " Returns a gaussian fit for a given probability map\n", + " Fitting is done via robust regression, where a fit is\n", + " continuously refined by deleting outliers num_attempts times\n", + " Inputs:\n", + " prob_map: 2-D probability map to be fit\n", + " num_attempts: Number of times to fit & remove outliers\n", + " perc_mean: All probability values below perc_mean*mean(gauss_fit) will be\n", + " considered outliers for repeated attempts\n", + " Outputs:\n", + " gauss_fit: [np.ndarray] specifying the 2-D Gaussian PDF\n", + " grid: [tuple] containing (y,x) points with which the Gaussian PDF can be plotted\n", + " gauss_mean: [np.ndarray] of shape (2,) specifying the 2-D Gaussian center\n", + " gauss_cov: [np.ndarray] of shape (2,2) specifying the 2-D Gaussian covariance matrix\n", + " \"\"\"\n", + " assert prob_map.ndim==2, (\n", + " \"get_gauss_fit: Input prob_map must have 2 dimension specifying [num_rows, num_cols\")\n", + " if num_attempts < 1:\n", + " num_attempts = 1\n", + " orig_prob_map = prob_map.copy()\n", + " gauss_success = False\n", + " while not gauss_success:\n", + " prob_map = orig_prob_map.copy()\n", + " try:\n", + " for i in range(num_attempts):\n", + " map_min = np.min(prob_map)\n", + " prob_map -= map_min\n", + " map_sum = np.sum(prob_map)\n", + " if map_sum != 1.0:\n", + " prob_map /= map_sum\n", + " gauss_mean, gauss_cov = gaussian_fit(prob_map)\n", + " gauss_fit, grid = generate_gaussian(prob_map.shape, gauss_mean, gauss_cov)\n", + " gauss_fit = (gauss_fit * map_sum) + map_min\n", + " if i < num_attempts-1:\n", + " gauss_mask = gauss_fit.copy().T\n", + " mask_slice = np.where(gauss_mask0)] = 1\n", + " prob_map *= gauss_mask\n", + " gauss_success = True\n", + " except np.linalg.LinAlgError: # Usually means cov matrix is singular\n", + " print(\"get_gauss_fit: Failed to fit Gaussian at attempt \",i,\", trying again.\"+\n", + " \"\\n To avoid this try decreasing perc_mean.\")\n", + " num_attempts = i-1\n", + " if num_attempts <= 0:\n", + " assert False, (\"get_gauss_fit: np.linalg.LinAlgError - Unable to fit gaussian.\")\n", + " return (gauss_fit, grid, gauss_mean, gauss_cov)\n", + "\n", + "\n", + "def hilbert_amplitude(weights, padding=None):\n", + " \"\"\"\n", + " Compute Hilbert amplitude envelope of weight matrix\n", + " Inputs:\n", + " weights: [np.ndarray] of shape [num_inputs, num_outputs]\n", + " num_inputs must have an even square root\n", + " padding: [int] specifying how much 0-padding to use for FFT\n", + " default is the closest power of 2 of sqrt(num_inputs)\n", + " Outputs:\n", + " env: [np.ndarray] of shape [num_outputs, num_inputs]\n", + " Hilbert envelope\n", + " bff_filt: [np.ndarray] of shape [num_outputs, padded_num_inputs]\n", + " Filtered Fourier transform of basis function\n", + " hil_filt: [np.ndarray] of shape [num_outputs, sqrt(num_inputs), sqrt(num_inputs)]\n", + " Hilbert filter to be applied in Fourier space\n", + " bffs: [np.ndarray] of shape [num_outputs, padded_num_inputs, padded_num_inputs]\n", + " Fourier transform of input weights\n", + " \"\"\"\n", + " cart2pol = lambda x,y: (np.arctan2(y,x), np.hypot(x, y))\n", + " num_inputs, num_outputs = weights.shape\n", + " assert np.sqrt(num_inputs) == np.floor(np.sqrt(num_inputs)), (\n", + " \"weights.shape[0] must have an even square root.\")\n", + " patch_edge_size = int(np.sqrt(num_inputs))\n", + " if padding is None or padding <= patch_edge_size:\n", + " # Amount of zero padding for fft2 (closest power of 2)\n", + " N = np.int(2**(np.ceil(np.log2(patch_edge_size))))\n", + " else:\n", + " N = np.int(padding)\n", + " # Analytic signal envelope for weights\n", + " # (Hilbet transform of each basis function)\n", + " env = np.zeros((num_outputs, num_inputs), dtype=complex)\n", + " # Fourier transform of weights\n", + " bffs = np.zeros((num_outputs, N, N), dtype=complex)\n", + " # Filtered Fourier transform of weights\n", + " bff_filt = np.zeros((num_outputs, N**2), dtype=complex)\n", + " # Hilbert filters\n", + " hil_filt = np.zeros((num_outputs, N, N))\n", + " # Grid for creating filter\n", + " f = (2/N) * np.pi * np.arange(-N/2.0, N/2.0)\n", + " (fx, fy) = np.meshgrid(f, f)\n", + " (theta, r) = cart2pol(fx, fy)\n", + " for neuron_idx in range(num_outputs):\n", + " # Grab single basis function, reshape to a square image\n", + " bf = weights[:, neuron_idx].reshape(patch_edge_size, patch_edge_size)\n", + " # Convert basis function into DC-centered Fourier domain\n", + " bff = np.fft.fftshift(np.fft.fft2(bf-np.mean(bf), [N, N]))\n", + " bffs[neuron_idx, ...] = bff\n", + " # Find indices of the peak amplitude\n", + " max_ys = np.abs(bff).argmax(axis=0) # Returns row index for each col\n", + " max_x = np.argmax(np.abs(bff).max(axis=0))\n", + " # Convert peak amplitude location into angle in freq domain\n", + " fx_ang = f[max_x]\n", + " fy_ang = f[max_ys[max_x]]\n", + " theta_max = np.arctan2(fy_ang, fx_ang)\n", + " # Define the half-plane with respect to the maximum\n", + " ang_diff = np.abs(theta-theta_max)\n", + " idx = (ang_diff>np.pi).nonzero()\n", + " ang_diff[idx] = 2.0 * np.pi - ang_diff[idx]\n", + " hil_filt[neuron_idx, ...] = (ang_diff < np.pi/2.0).astype(int)\n", + " # Create analytic signal from the inverse FT of the half-plane filtered bf\n", + " abf = np.fft.ifft2(np.fft.fftshift(hil_filt[neuron_idx, ...]*bff))\n", + " env[neuron_idx, ...] = abf[0:patch_edge_size, 0:patch_edge_size].reshape(num_inputs)\n", + " bff_filt[neuron_idx, ...] = (hil_filt[neuron_idx, ...]*bff).reshape(N**2)\n", + " return (env, bff_filt, hil_filt, bffs)\n", + "\n", + "\n", + "def get_dictionary_stats(weights, padding=None, num_gauss_fits=20, gauss_thresh=0.2):\n", + " \"\"\"\n", + " Compute summary statistics on dictionary elements using Hilbert amplitude envelope\n", + " Inputs:\n", + " weights: [np.ndarray] of shape [num_inputs, num_outputs]\n", + " padding: [int] total image size to pad out to in the FFT computation\n", + " num_gauss_fits: [int] total number of attempts to make when fitting the BFs\n", + " gauss_thresh: All probability values below gauss_thresh*mean(gauss_fit) will be\n", + " considered outliers for repeated fits\n", + " Outputs:\n", + " The function output is a dictionary containing the keys for each type of analysis\n", + " Each key dereferences a list of len num_outputs (i.e. one entry for each weight vector)\n", + " The keys and their list entries are as follows:\n", + " basis_functions: [np.ndarray] of shape [patch_edge_size, patch_edge_size]\n", + " envelopes: [np.ndarray] of shape [N, N], where N is the amount of padding\n", + " for the hilbert_amplitude function\n", + " envelope_centers: [tuples of ints] indicating the (y, x) position of the\n", + " center of the Hilbert envelope\n", + " gauss_fits: [list of np.ndarrays] containing (gaussian_fit, grid) where gaussian_fit\n", + " is returned from get_gauss_fit and specifies the 2D Gaussian PDF fit to the Hilbert\n", + " envelope and grid is a tuple containing (y,x) points with which the Gaussian PDF\n", + " can be plotted\n", + " gauss_centers: [list of ints] containing the (y,x) position of the center of\n", + " the Gaussian fit\n", + " gauss_orientations: [list of np.ndarrays] containing the (eigenvalues, eigenvectors) of\n", + " the covariance matrix for the Gaussian fit of the Hilbert amplitude envelope. They are\n", + " both sorted according to the highest to lowest Eigenvalue.\n", + " fourier_centers: [list of ints] containing the (y,x) position of the center (max) of\n", + " the Fourier amplitude map\n", + " num_inputs: [int] dim[0] of input weights\n", + " num_outputs: [int] dim[1] of input weights\n", + " patch_edge_size: [int] int(floor(sqrt(num_inputs)))\n", + " areas: [list of floats] area of enclosed ellipse\n", + " spatial_frequncies: [list of floats] dominant spatial frequency for basis function\n", + " \"\"\"\n", + " envelope, bff_filt, hil_filter, bffs = hilbert_amplitude(weights, padding)\n", + " num_inputs, num_outputs = weights.shape\n", + " patch_edge_size = np.int(np.floor(np.sqrt(num_inputs)))\n", + " basis_funcs = [None]*num_outputs\n", + " envelopes = [None]*num_outputs\n", + " gauss_fits = [None]*num_outputs\n", + " gauss_centers = [None]*num_outputs\n", + " diameters = [None]*num_outputs\n", + " gauss_orientations = [None]*num_outputs\n", + " envelope_centers = [None]*num_outputs\n", + " fourier_centers = [None]*num_outputs\n", + " ellipse_orientations = [None]*num_outputs\n", + " fourier_maps = [None]*num_outputs\n", + " spatial_frequencies = [None]*num_outputs\n", + " areas = [None]*num_outputs\n", + " phases = [None]*num_outputs\n", + " for bf_idx in range(num_outputs):\n", + " # Reformatted individual basis function\n", + " basis_funcs[bf_idx] = weights.T[bf_idx,...].reshape((patch_edge_size, patch_edge_size))\n", + " # Reformatted individual envelope filter\n", + " envelopes[bf_idx] = np.abs(envelope[bf_idx,...]).reshape((patch_edge_size, patch_edge_size))\n", + " # Basis function center\n", + " max_ys = envelopes[bf_idx].argmax(axis=0) # Returns row index for each col\n", + " max_x = np.argmax(envelopes[bf_idx].max(axis=0))\n", + " y_cen = max_ys[max_x]\n", + " x_cen = max_x\n", + " envelope_centers[bf_idx] = (y_cen, x_cen)\n", + " # Gaussian fit to Hilbet amplitude envelope\n", + " gauss_fit, grid, gauss_mean, gauss_cov = get_gauss_fit(envelopes[bf_idx],\n", + " num_gauss_fits, gauss_thresh)\n", + " gauss_fits[bf_idx] = (gauss_fit, grid)\n", + " gauss_centers[bf_idx] = gauss_mean\n", + " evals, evecs = np.linalg.eigh(gauss_cov)\n", + " sort_indices = np.argsort(evals)[::-1]\n", + " gauss_orientations[bf_idx] = (evals[sort_indices], evecs[:,sort_indices])\n", + " width, height = evals[sort_indices] # Width & height are relative to orientation\n", + " diameters[bf_idx] = np.sqrt(width**2+height**2)\n", + " # Fourier function center, spatial frequency, orientation\n", + " fourier_map = np.sqrt(np.real(bffs[bf_idx, ...])**2+np.imag(bffs[bf_idx, ...])**2)\n", + " fourier_maps[bf_idx] = fourier_map\n", + " N = fourier_map.shape[0]\n", + " center_freq = int(np.floor(N/2))\n", + " fourier_map[center_freq, center_freq] = 0 # remove DC component\n", + " max_fys = fourier_map.argmax(axis=0)\n", + " max_fx = np.argmax(fourier_map.max(axis=0))\n", + " fy_cen = (max_fys[max_fx] - (N/2)) * (patch_edge_size/N)\n", + " fx_cen = (max_fx - (N/2)) * (patch_edge_size/N)\n", + " fourier_centers[bf_idx] = [fy_cen, fx_cen]\n", + " # NOTE: we flip fourier_centers because fx_cen is the peak of the x frequency,\n", + " # which would be a y coordinate\n", + " ellipse_orientations[bf_idx] = np.arctan2(*fourier_centers[bf_idx][::-1])\n", + " spatial_frequencies[bf_idx] = np.sqrt(fy_cen**2 + fx_cen**2)\n", + " areas[bf_idx] = np.pi * np.prod(evals)\n", + " phases[bf_idx] = np.angle(bffs[bf_idx])[y_cen, x_cen]\n", + " output = {\"basis_functions\":basis_funcs, \"envelopes\":envelopes, \"gauss_fits\":gauss_fits,\n", + " \"gauss_centers\":gauss_centers, \"gauss_orientations\":gauss_orientations, \"areas\":areas,\n", + " \"fourier_centers\":fourier_centers, \"fourier_maps\":fourier_maps, \"num_inputs\":num_inputs,\n", + " \"spatial_frequencies\":spatial_frequencies, \"envelope_centers\":envelope_centers,\n", + " \"num_outputs\":num_outputs, \"patch_edge_size\":patch_edge_size, \"phases\":phases,\n", + " \"ellipse_orientations\":ellipse_orientations, \"diameters\":diameters}\n", + " return output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bf_stats = get_dictionary_stats(\n", + " gray_lca_weights.reshape(gray_lca_weights.shape[0], -1).T,\n", + " padding=32,\n", + " num_gauss_fits=20,\n", + " gauss_thresh=0.2)\n", + "\n", + "np.savez(\n", + " model.params.save_dir+'bf_summary_stats.npz',\n", + " data={'bf_stats':bf_stats})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def clear_axis(ax, spines=\"none\"):\n", + " for ax_loc in [\"top\", \"bottom\", \"left\", \"right\"]:\n", + " ax.spines[ax_loc].set_color(spines)\n", + " ax.set_yticklabels([])\n", + " ax.set_xticklabels([])\n", + " ax.get_xaxis().set_visible(False)\n", + " ax.get_yaxis().set_visible(False)\n", + " ax.tick_params(axis=\"both\", bottom=False, top=False, left=False, right=False)\n", + " return ax\n", + "\n", + "def plot_ellipse(axis, center, shape, angle, color_val=\"auto\", alpha=1.0, lines=False,\n", + " fill_ellipse=False):\n", + " \"\"\"\n", + " Add an ellipse to given axis\n", + " Inputs:\n", + " axis [matplotlib.axes._subplots.AxesSubplot] axis on which ellipse should be drawn\n", + " center [tuple or list] specifying [y, x] center coordinates\n", + " shape [tuple or list] specifying [width, height] shape of ellipse\n", + " angle [float] specifying angle of ellipse\n", + " color_val [matplotlib color spec] specifying the color of the edge & face of the ellipse\n", + " alpha [float] specifying the transparency of the ellipse\n", + " lines [bool] if true, output will be a line, where the secondary axis of the ellipse\n", + " is collapsed\n", + " fill_ellipse [bool] if true and lines is false then a filled ellipse will be plotted\n", + " Outputs:\n", + " ellipse [matplotlib.patches.ellipse] ellipse object\n", + " \"\"\"\n", + " if fill_ellipse:\n", + " face_color_val = \"none\" if color_val==\"auto\" else color_val\n", + " else:\n", + " face_color_val = \"none\"\n", + " y_cen, x_cen = center\n", + " width, height = shape\n", + " if lines:\n", + " min_length = 0.1\n", + " if width < height:\n", + " width = min_length\n", + " elif width > height:\n", + " height = min_length\n", + " ellipse = matplotlib.patches.Ellipse(xy=[x_cen, y_cen], width=width,\n", + " height=height, angle=angle, edgecolor=color_val, facecolor=face_color_val,\n", + " alpha=alpha, fill=True)\n", + " axis.add_artist(ellipse)\n", + " ellipse.set_clip_box(axis.bbox)\n", + " return ellipse\n", + "\n", + "def plot_ellipse_summaries(bf_stats, num_bf=-1, lines=False, rand_bf=False):\n", + " \"\"\"\n", + " Plot basis functions with summary ellipses drawn over them\n", + " Inputs:\n", + " bf_stats [dict] output of dp.get_dictionary_stats()\n", + " num_bf [int] number of basis functions to plot (<=0 is all; >total is all)\n", + " lines [bool] If true, will plot lines instead of ellipses\n", + " rand_bf [bool] If true, will choose a random set of basis functions\n", + " \"\"\"\n", + " tot_num_bf = len(bf_stats[\"basis_functions\"])\n", + " if num_bf <= 0 or num_bf > tot_num_bf:\n", + " num_bf = tot_num_bf\n", + " SFs = np.asarray([np.sqrt(fcent[0]**2 + fcent[1]**2)\n", + " for fcent in bf_stats[\"fourier_centers\"]], dtype=np.float32)\n", + " sf_sort_indices = np.argsort(SFs)\n", + " if rand_bf:\n", + " bf_range = np.random.choice([i for i in range(tot_num_bf)], num_bf, replace=False)\n", + " num_plots_y = int(np.ceil(np.sqrt(num_bf)))\n", + " num_plots_x = int(np.ceil(np.sqrt(num_bf)))\n", + " gs = gridspec.GridSpec(num_plots_y, num_plots_x)\n", + " fig = plt.figure(figsize=(17,17))\n", + " filter_idx = 0\n", + " for plot_id in np.ndindex((num_plots_y, num_plots_x)):\n", + " ax = clear_axis(fig.add_subplot(gs[plot_id]))\n", + " if filter_idx < tot_num_bf and filter_idx < num_bf:\n", + " if rand_bf:\n", + " bf_idx = bf_range[filter_idx]\n", + " else:\n", + " bf_idx = filter_idx\n", + " bf = bf_stats[\"basis_functions\"][bf_idx]\n", + " ax.imshow(bf, interpolation=\"Nearest\", cmap=\"grays_r\")\n", + " ax.set_title(str(bf_idx), fontsize=\"8\")\n", + " center = bf_stats[\"gauss_centers\"][bf_idx]\n", + " evals, evecs = bf_stats[\"gauss_orientations\"][bf_idx]\n", + " orientations = bf_stats[\"fourier_centers\"][bf_idx]\n", + " angle = np.rad2deg(np.pi/2 + np.arctan2(*orientations))\n", + " alpha = 1.0\n", + " ellipse = plot_ellipse(ax, center, evals, angle, color_val=\"b\", alpha=alpha, lines=lines)\n", + " filter_idx += 1\n", + " ax.set_aspect(\"equal\")\n", + " plt.show()\n", + " return fig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = plot_ellipse_summaries(bf_stats, lines=False)\n", + "\n", + "fig.savefig(\n", + " f'{model.params.disp_dir}/basis_function_fits.png',\n", + " transparent=False,\n", + " bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def bgr_colormap():\n", + " \"\"\"\n", + " In cdict, the first column is interpolated between 0.0 & 1.0 - this indicates the value to be plotted\n", + " the second column specifies how interpolation should be done from below\n", + " the third column specifies how interpolation should be done from above\n", + " if the second column does not equal the third, then there will be a break in the colors\n", + " \"\"\"\n", + " darkness = 0.85 #0 is black, 1 is white\n", + " cdict = {\n", + " 'red': ((0.0, 0.0, 0.0),\n", + " (0.5, darkness, darkness),\n", + " (1.0, 1.0, 1.0)),\n", + " 'green': ((0.0, 0.0, 0.0),\n", + " (0.5, darkness, darkness),\n", + " (1.0, 0.0, 0.0)),\n", + " 'blue': ((0.0, 1.0, 1.0),\n", + " (0.5, darkness, darkness),\n", + " (1.0, 0.0, 0.0))\n", + " }\n", + " return LinearSegmentedColormap(\"bgr\", cdict)\n", + "\n", + "def plot_pooling_centers(bf_stats, pooling_filters, num_pooling_filters, num_connected_weights,\n", + " spot_size=10, figsize=None):\n", + " \"\"\"\n", + " Plot 2nd layer (fully-connected) weights in terms of spatial/frequency centers of\n", + " 1st layer weights\n", + " Inputs:\n", + " bf_stats [dict] Output of dp.get_dictionary_stats() which was run on the 1st layer weights\n", + " pooling_filters [np.ndarray] 2nd layer weights\n", + " should be shape [num_1st_layer_neurons, num_2nd_layer_neurons]\n", + " num_pooling_filters [int] How many 2nd layer neurons to plot\n", + " figsize [tuple] Containing the (width, height) of the figure, in inches\n", + " spot_size [int] How big to make the points\n", + " \"\"\"\n", + " num_filters_y = int(np.ceil(np.sqrt(num_pooling_filters)))\n", + " num_filters_x = int(np.ceil(np.sqrt(num_pooling_filters)))\n", + " tot_pooling_filters = pooling_filters.shape[1]\n", + " #filter_indices = np.random.choice(tot_pooling_filters, num_pooling_filters, replace=False)\n", + " filter_indices = np.arange(tot_pooling_filters, dtype=np.int32)\n", + " cmap = plt.get_cmap(bgr_colormap())# Could also use \"nipy_spectral\", \"coolwarm\", \"bwr\"\n", + " cNorm = matplotlib.colors.SymLogNorm(linthresh=0.03, linscale=0.01, vmin=-1.0, vmax=1.0)\n", + " scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap)\n", + " x_p_cent = [x for (y,x) in bf_stats[\"gauss_centers\"]]# Get raw points\n", + " y_p_cent = [y for (y,x) in bf_stats[\"gauss_centers\"]]\n", + " x_f_cent = [x for (y,x) in bf_stats[\"fourier_centers\"]]\n", + " y_f_cent = [y for (y,x) in bf_stats[\"fourier_centers\"]]\n", + " max_sf = np.max(np.abs(x_f_cent+y_f_cent))\n", + " pair_w_gap = 0.01\n", + " group_w_gap = 0.03\n", + " h_gap = 0.03\n", + " plt_w = (num_filters_x/num_pooling_filters)\n", + " plt_h = plt_w\n", + " if figsize is None:\n", + " fig = plt.figure()\n", + " figsize = (fig.get_figwidth(), fig.get_figheight())\n", + " else:\n", + " fig = plt.figure(figsize=figsize) #figsize is (w,h)\n", + " axes = []\n", + " filter_id = 0\n", + " for plot_id in np.ndindex((num_filters_y, num_filters_x)):\n", + " if all(pid == 0 for pid in plot_id):\n", + " axes.append(clear_axis(fig.add_axes([0, plt_h+h_gap, 2*plt_w, plt_h])))\n", + " scalarMap._A = []\n", + " cbar = fig.colorbar(scalarMap, ax=axes[-1], ticks=[-1, 0, 1], aspect=10, location=\"bottom\")\n", + " cbar.ax.set_xticklabels([\"-1\", \"0\", \"1\"])\n", + " cbar.ax.xaxis.set_ticks_position('top')\n", + " cbar.ax.xaxis.set_label_position('top')\n", + " for label in cbar.ax.xaxis.get_ticklabels():\n", + " label.set_weight(\"bold\")\n", + " label.set_fontsize(10+figsize[0])\n", + " if (filter_id < num_pooling_filters):\n", + " example_filter = pooling_filters[:, filter_indices[filter_id]]\n", + " top_indices = np.argsort(np.abs(example_filter))[::-1] #descending\n", + " selected_indices = top_indices[:num_connected_weights][::-1] #select top, plot weakest first\n", + " filter_norm = np.max(np.abs(example_filter))\n", + " connection_colors = [scalarMap.to_rgba(example_filter[bf_idx]/filter_norm)\n", + " for bf_idx in range(bf_stats[\"num_outputs\"])]\n", + " if num_connected_weights < top_indices.size:\n", + " black_indices = top_indices[num_connected_weights:][::-1]\n", + " xp = [x_p_cent[i] for i in black_indices]+[x_p_cent[i] for i in selected_indices]\n", + " yp = [y_p_cent[i] for i in black_indices]+[y_p_cent[i] for i in selected_indices]\n", + " xf = [x_f_cent[i] for i in black_indices]+[x_f_cent[i] for i in selected_indices]\n", + " yf = [y_f_cent[i] for i in black_indices]+[y_f_cent[i] for i in selected_indices]\n", + " c = [(0.1,0.1,0.1,1.0) for i in black_indices]+[connection_colors[i] for i in selected_indices]\n", + " else:\n", + " xp = [x_p_cent[i] for i in selected_indices]\n", + " yp = [y_p_cent[i] for i in selected_indices]\n", + " xf = [x_f_cent[i] for i in selected_indices]\n", + " yf = [y_f_cent[i] for i in selected_indices]\n", + " c = [connection_colors[i] for i in selected_indices]\n", + " (y_id, x_id) = plot_id\n", + " if x_id == 0:\n", + " ax_l = 0\n", + " ax_b = - y_id * (plt_h+h_gap)\n", + " else:\n", + " bbox = axes[-1].get_position().get_points()[0]#bbox is [[x0,y0],[x1,y1]]\n", + " prev_l = bbox[0]\n", + " prev_b = bbox[1]\n", + " ax_l = prev_l + plt_w + group_w_gap\n", + " ax_b = prev_b\n", + " ax_w = plt_w\n", + " ax_h = plt_h\n", + " axes.append(clear_axis(fig.add_axes([ax_l, ax_b, ax_w, ax_h])))\n", + " axes[-1].invert_yaxis()\n", + " axes[-1].scatter(xp, yp, c=c, s=spot_size, alpha=0.8)\n", + " axes[-1].set_xlim(0, bf_stats[\"patch_edge_size\"]-1)\n", + " axes[-1].set_ylim(bf_stats[\"patch_edge_size\"]-1, 0)\n", + " axes[-1].set_aspect(\"equal\")\n", + " axes[-1].set_facecolor(\"w\")\n", + " axes.append(clear_axis(fig.add_axes([ax_l+ax_w+pair_w_gap, ax_b, ax_w, ax_h])))\n", + " axes[-1].scatter(xf, yf, c=c, s=spot_size, alpha=0.8)\n", + " axes[-1].set_xlim([-max_sf, max_sf])\n", + " axes[-1].set_ylim([-max_sf, max_sf])\n", + " axes[-1].xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))\n", + " axes[-1].yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))\n", + " axes[-1].set_aspect(\"equal\")\n", + " axes[-1].set_facecolor(\"w\")\n", + " filter_id += 1\n", + " plt.show()\n", + " return fig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "kernel_pos = 0\n", + "pool_weights = model.pool_1.layer.weight.detach().cpu().numpy()\n", + "outputs, inputs, kernel_h, kernel_w = pool_weights.shape\n", + "\n", + "fig = plot_pooling_centers(\n", + " bf_stats,\n", + " pool_weights[:, :, kernel_pos, kernel_pos].T,\n", + " num_pooling_filters=outputs,\n", + " num_connected_weights=inputs,\n", + " spot_size=3,\n", + " figsize=(5, 5))\n", + "\n", + "fig.savefig(\n", + " f'{model.params.disp_dir}/pooling_spots.png',\n", + " transparent=False,\n", + " bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_pooling_summaries(bf_stats, pooling_filters, num_pooling_filters,\n", + " num_connected_weights, lines=False, figsize=None):\n", + " \"\"\"\n", + " Plot 2nd layer (fully-connected) weights in terms of connection strengths to 1st layer weights\n", + " Inputs:\n", + " bf_stats [dict] output of dp.get_dictionary_stats() which was run on the 1st layer weights\n", + " pooling_filters [np.ndarray] 2nd layer weights\n", + " should be shape [num_1st_layer_neurons, num_2nd_layer_neurons]\n", + " num_pooling_filters [int] How many 2nd layer neurons to plot\n", + " num_connected_weights [int] How many 1st layer weight summaries to include\n", + " for a given 2nd layer neuron\n", + " lines [bool] if True, 1st layer weight summaries will appear as lines instead of ellipses\n", + " \"\"\"\n", + " num_inputs = bf_stats[\"num_inputs\"]\n", + " num_outputs = bf_stats[\"num_outputs\"]\n", + " tot_pooling_filters = pooling_filters.shape[1]\n", + " patch_edge_size = np.int32(np.sqrt(num_inputs))\n", + " filter_idx_list = np.arange(num_pooling_filters, dtype=np.int32)\n", + " assert num_pooling_filters <= num_outputs, (\n", + " \"num_pooling_filters must be less than or equal to bf_stats['num_outputs']\")\n", + " cmap = bgr_colormap()#plt.get_cmap('bwr')\n", + " cNorm = matplotlib.colors.SymLogNorm(linthresh=0.03, linscale=0.01, vmin=-1.0, vmax=1.0)\n", + " scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap)\n", + " num_plots_y = np.int32(np.ceil(np.sqrt(num_pooling_filters)))\n", + " num_plots_x = np.int32(np.ceil(np.sqrt(num_pooling_filters)))+1 # +cbar col\n", + " gs_widths = [1 for _ in range(num_plots_x-1)]+[0.3]\n", + " gs = gridspec.GridSpec(num_plots_y, num_plots_x, width_ratios=gs_widths)\n", + " if figsize is None:\n", + " fig = plt.figure()\n", + " figsize = (fig.get_figwidth(), fig.get_figheight())\n", + " else:\n", + " fig = plt.figure(figsize=figsize)\n", + " filter_total = 0\n", + " for plot_id in np.ndindex((num_plots_y, num_plots_x-1)):\n", + " (y_id, x_id) = plot_id\n", + " ax = fig.add_subplot(gs[plot_id])\n", + " if (filter_total < num_pooling_filters and x_id != num_plots_x-1):\n", + " ax = clear_axis(ax, spines=\"k\")\n", + " filter_idx = filter_idx_list[filter_total]\n", + " example_filter = pooling_filters[:, filter_idx]\n", + " top_indices = np.argsort(np.abs(example_filter))[::-1] #descending\n", + " filter_norm = np.max(np.abs(example_filter))\n", + " SFs = np.asarray([np.sqrt(fcent[0]**2 + fcent[1]**2)\n", + " for fcent in bf_stats[\"fourier_centers\"]], dtype=np.float32)\n", + " # Plot weakest of the top connected filters first because of occlusion\n", + " for bf_idx in top_indices[:num_connected_weights][::-1]:\n", + " connection_strength = example_filter[bf_idx]/filter_norm\n", + " color_val = scalarMap.to_rgba(connection_strength)\n", + " center = bf_stats[\"gauss_centers\"][bf_idx]\n", + " evals, evecs = bf_stats[\"gauss_orientations\"][bf_idx]\n", + " orientations = bf_stats[\"fourier_centers\"][bf_idx]\n", + " angle = np.rad2deg(np.pi/2 + np.arctan2(*orientations))\n", + " alpha = 0.5#todo:spatial_freq for filled ellipses?\n", + " ellipse = plot_ellipse(ax, center, evals, angle, color_val, alpha=alpha, lines=lines)\n", + " ax.set_xlim(0, patch_edge_size-1)\n", + " ax.set_ylim(patch_edge_size-1, 0)\n", + " filter_total += 1\n", + " else:\n", + " ax = clear_axis(ax, spines=\"none\")\n", + " ax.set_aspect(\"equal\")\n", + " scalarMap._A = []\n", + " ax = clear_axis(fig.add_subplot(gs[0, -1]))\n", + " cbar = fig.colorbar(scalarMap, ax=ax, ticks=[-1, 0, 1])\n", + " cbar.ax.set_yticklabels([\"-1\", \"0\", \"1\"])\n", + " for label in cbar.ax.yaxis.get_ticklabels():\n", + " label.set_weight(\"bold\")\n", + " label.set_fontsize(14)\n", + " plt.show()\n", + " return fig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = plot_pooling_summaries(\n", + " bf_stats,\n", + " pool_weights[:, :, kernel_pos, kernel_pos].T,\n", + " num_pooling_filters=outputs,\n", + " num_connected_weights=40,\n", + " lines=True,\n", + " figsize=(18,18))\n", + "\n", + "fig.savefig(\n", + " f'{model.params.disp_dir}/pooling_lines.png',\n", + " transparent=False,\n", + " bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "P = pool_weights[:, :, kernel_pos, kernel_pos] # [inputs, outputs]\n", + "p_norm = np.linalg.norm(P, ord=2, axis=0)\n", + "affinity = np.dot(P.T, P) # cosyne similarity of neurons in embedded space\n", + "for i in range(affinity.shape[0]):\n", + " for j in range(affinity.shape[1]):\n", + " affinity[i, j] = affinity[i, j] / (p_norm[i] * p_norm[j])\n", + "affinity = affinity.T # [inputs, inputs]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = plot_pooling_centers(\n", + " bf_stats,\n", + " affinity,\n", + " num_pooling_filters=outputs,\n", + " num_connected_weights=128, \n", + " spot_size=30,\n", + " figsize=(5, 5))\n", + "\n", + "fig.savefig(\n", + " f'{model.params.disp_dir}/affinity_spots.png',\n", + " transparent=False,\n", + " bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = plot_pooling_summaries(\n", + " bf_stats,\n", + " affinity,\n", + " num_pooling_filters=outputs,\n", + " num_connected_weights=20,\n", + " lines=True,\n", + " figsize=(10, 10))\n", + "\n", + "fig.savefig(\n", + " f'{model.params.disp_dir}/affinity_lines.png',\n", + " transparent=False,\n", + " bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "example_batch = next(iter(train_loader))[0].to(model.params.device)\n", + "example_batch = model[0].preprocess_data(example_batch)\n", + "example_batch *= train_std_image\n", + "example_batch += train_mean_image\n", + "batch_min = example_batch.min().item()\n", + "batch_max = example_batch.max().item()\n", + "\n", + "example_image = example_batch[0, ...]\n", + "print(\n", + " f'min = {example_image.min().item()}'+\n", + " f'\\nmean = {example_image.mean().item()}'+\n", + " f'\\nmax = {example_image.max().item()}'+\n", + " f'\\nstd = {example_image.std().item()}')\n", + "\n", + "plot_example_image = ((example_image * train_std_image) + train_mean_image).cpu().numpy().transpose(1,2,0)\n", + "fig, ax = plot.subplots(nrows=1, ncols=1)\n", + "ax = pf.clear_axis(ax)\n", + "ax.imshow(plot_example_image, vmin=0, vmax=1)\n", + "plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "beta_2 = model(example_image[None,...])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class lca_2_recon_params(LcaParams):\n", + " def set_params(self):\n", + " super(lca_2_recon_params, self).set_params()\n", + " self.model_type = 'lca'\n", + " self.model_name = 'lca_2_recon'\n", + " self.version = '0'\n", + " self.layer_types = ['fc']\n", + " self.standardize_data = False\n", + " self.rescale_data_to_one = False\n", + " self.center_dataset = False\n", + " self.batch_size = 1\n", + " self.dt = 0.001\n", + " self.tau = 0.2\n", + " self.num_steps = 75\n", + " self.rectify_a = True\n", + " self.thresh_type = 'hard'\n", + " self.compute_helper_params()\n", + " \n", + "params = lca_2_recon_params()\n", + "params.set_params()\n", + "params.layer_channels = model.lca_2.params.layer_channels\n", + "params.sparse_mult = model.lca_2.params.sparse_mult\n", + "params.data_shape = list(beta_2.shape)\n", + "params.epoch_size = 1\n", + "params.num_pixels = np.prod(params.data_shape)\n", + "\n", + "lca_2_recon_model = loaders.load_model(params.model_type)\n", + "lca_2_recon_model.setup(params)\n", + "lca_2_recon_model.to(params.device)\n", + "lca_2_recon_model.eval()\n", + "with torch.no_grad():\n", + " lca_2_recon_model.weight = nn.Parameter(model.pool_2.weight)\n", + "alpha_2_hat = lca_2_recon_model(beta_2)[:, :, None, None]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " beta_1_hat = F.conv_transpose2d(\n", + " input=alpha_2_hat,\n", + " weight=model.lca_2.weight,\n", + " bias=None,\n", + " stride=model.lca_2.params.stride,\n", + " padding=model.lca_2.params.padding)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from DeepSparseCoding.modules.lca_module import LcaModule\n", + "from DeepSparseCoding.models.base_model import BaseModel\n", + "from DeepSparseCoding.utils.run_utils import compute_deconv_output_shape\n", + "import DeepSparseCoding.modules.losses as losses\n", + "\n", + "class TransposedLcaModule(LcaModule):\n", + " def setup_module(self, params):\n", + " super(TransposedLcaModule, self).setup_module(params)\n", + " if self.params.layer_types[0] == 'conv':\n", + " self.layer_output_shapes = [self.params.data_shape] # [channels, height, width]\n", + " assert (self.params.data_shape[-1] % self.params.stride == 0), (\n", + " f'Stride = {self.params.stride} must divide evenly into input edge size = {self.params.data_shape[-1]}')\n", + " self.w_shape = [\n", + " self.params.layer_channels,\n", + " self.params.data_shape[0], # channels = 1\n", + " self.params.kernel_size,\n", + " self.params.kernel_size\n", + " ]\n", + " output_height = compute_deconv_output_shape(\n", + " self.layer_output_shapes[-1][1],\n", + " self.params.kernel_size,\n", + " self.params.stride,\n", + " self.params.padding,\n", + " output_padding=self.params.output_padding,\n", + " dilation=1)\n", + " output_width = compute_deconv_output_shape(\n", + " self.layer_output_shapes[-1][2],\n", + " self.params.kernel_size,\n", + " self.params.stride,\n", + " self.params.padding,\n", + " output_padding=self.params.output_padding,\n", + " dilation=1)\n", + " self.layer_output_shapes.append([self.params.layer_channels, output_height, output_width])\n", + " w_init = torch.randn(self.w_shape)\n", + " w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps)\n", + " self.weight = nn.Parameter(w_init_normed, requires_grad=True)\n", + "\n", + " def compute_excitatory_current(self, input_tensor, a_in, weight=None):\n", + " if weight is None:\n", + " weight = self.weight\n", + " if self.params.layer_types[0] == 'fc':\n", + " excitatory_current = torch.matmul(input_tensor, weight)\n", + " else:\n", + " recon = self.get_recon_from_latents(a_in, weight)\n", + " recon_error = input_tensor - recon\n", + " error_injection = F.conv_transpose2d(\n", + " input=recon_error,\n", + " weight=weight,\n", + " bias=None,\n", + " stride=self.params.stride,\n", + " padding=self.params.padding,\n", + " output_padding=self.params.output_padding,\n", + " dilation=1\n", + " )\n", + " excitatory_current = error_injection + a_in\n", + " return excitatory_current\n", + "\n", + " def get_recon_from_latents(self, a_in, weight=None):\n", + " if weight is None:\n", + " weight = self.weight\n", + " if self.params.layer_types[0] == 'fc':\n", + " recon = torch.matmul(a_in, torch.transpose(weight, dim0=0, dim1=1))\n", + " else:\n", + " recon = F.conv2d(\n", + " input=a_in,\n", + " weight=weight,\n", + " bias=None,\n", + " stride=self.params.stride,\n", + " padding=self.params.padding,\n", + " dilation=1\n", + " )\n", + " return recon\n", + "\n", + "class TransposedLcaModel(BaseModel, TransposedLcaModule):\n", + " def setup(self, params, logger=None):\n", + " super(TransposedLcaModel, self).setup(params, logger)\n", + " self.setup_module(params)\n", + " self.setup_optimizer()\n", + " if params.checkpoint_boot_log != '':\n", + " checkpoint = self.get_checkpoint_from_log(params.checkpoint_boot_log)\n", + " self.module.load_state_dict(checkpoint['model_state_dict'])\n", + " self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n", + "\n", + " def get_total_loss(self, input_tuple):\n", + " input_tensor, input_labels = input_tuple\n", + " latents = self.get_encodings(input_tensor)\n", + " recon = self.get_recon_from_latents(latents)\n", + " recon_loss = losses.half_squared_l2(input_tensor, recon)\n", + " sparse_loss = self.params.sparse_mult * losses.l1_norm(latents)\n", + " total_loss = recon_loss + sparse_loss\n", + " return total_loss\n", + "\n", + " def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None):\n", + " if update_dict is None:\n", + " update_dict = super(TransposedLcaModel, self).generate_update_dict(input_data, input_labels, batch_step)\n", + " stat_dict = dict()\n", + " latents = self.get_encodings(input_data)\n", + " recon = self.get_recon_from_latents(latents)\n", + " recon_loss = losses.half_squared_l2(input_data, recon).item()\n", + " sparse_loss = self.params.sparse_mult * losses.l1_norm(latents).item()\n", + " stat_dict['weight_lr'] = self.scheduler.get_lr()[0]\n", + " stat_dict['loss_recon'] = recon_loss\n", + " stat_dict['loss_sparse'] = sparse_loss\n", + " stat_dict['loss_total'] = recon_loss + sparse_loss\n", + " stat_dict['input_max_mean_min'] = [\n", + " input_data.max().item(), input_data.mean().item(), input_data.min().item()]\n", + " stat_dict['recon_max_mean_min'] = [\n", + " recon.max().item(), recon.mean().item(), recon.min().item()]\n", + " def count_nonzero(array, dim):\n", + " # TODO: github issue 23907 requests torch.count_nonzero, integrated in torch 1.7\n", + " return torch.sum(array !=0, dim=dim, dtype=torch.float)\n", + " latent_dims = tuple([i for i in range(len(latents.shape))])\n", + " latent_nnz = count_nonzero(latents, dim=latent_dims).item()\n", + " stat_dict['fraction_active_all_latents'] = latent_nnz / latents.numel()\n", + " if self.params.layer_types[0] == 'conv':\n", + " latent_map_dims = latent_dims[2:]\n", + " latent_map_size = np.prod(list(latents.shape[2:]))\n", + " latent_channel_nnz = count_nonzero(latents, dim=latent_map_dims)/latent_map_size\n", + " latent_channel_mean_nnz = torch.mean(latent_channel_nnz).item()\n", + " stat_dict['fraction_active_latents_per_channel'] = latent_channel_mean_nnz\n", + " num_channels = latents.shape[1]\n", + " latent_patch_mean_nnz = torch.mean(count_nonzero(latents, dim=1)/num_channels).item()\n", + " stat_dict['fraction_active_latents_per_patch'] = latent_patch_mean_nnz\n", + " update_dict.update(stat_dict)\n", + " return update_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class lca_1_recon_params(LcaParams):\n", + " def set_params(self):\n", + " super(lca_1_recon_params, self).set_params()\n", + " self.model_type = 'lca'\n", + " self.model_name = 'lca_1_recon'\n", + " self.version = '0'\n", + " self.layer_types = ['conv']\n", + " self.standardize_data = False\n", + " self.rescale_data_to_one = False\n", + " self.center_dataset = False\n", + " self.batch_size = 1\n", + " self.dt = 0.001\n", + " self.tau = 0.2\n", + " self.num_steps = 75\n", + " self.rectify_a = True\n", + " self.thresh_type = 'hard'\n", + " self.compute_helper_params()\n", + " \n", + "params = lca_1_recon_params()\n", + "params.set_params()\n", + "params.layer_channels = model.pool_1.params.layer_channels[0]\n", + "params.kernel_size = model.pool_1.params.pool_ksize\n", + "params.stride = model.pool_1.params.pool_stride\n", + "params.padding = 0\n", + "params.sparse_mult = 0.01#model.lca_1.params.sparse_mult\n", + "params.data_shape = list(beta_1_hat.shape[1:])\n", + "params.epoch_size = 1\n", + "params.output_padding = 1\n", + "params.num_pixels = np.prod(params.data_shape)\n", + "\n", + "lca_1_recon_model = TransposedLcaModel()\n", + "lca_1_recon_model.setup(params)\n", + "lca_1_recon_model.to(params.device)\n", + "lca_1_recon_model.eval()\n", + "with torch.no_grad():\n", + " lca_1_recon_model.weight = nn.Parameter(model.pool_1.weight)\n", + "alpha_1_hat = lca_1_recon_model(beta_1_hat)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " recon = F.conv_transpose2d(\n", + " input=alpha_1_hat,\n", + " weight=model.lca_1.weight,\n", + " bias=None,\n", + " stride=model.lca_1.params.stride,\n", + " padding=model.lca_1.params.padding)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alpha_2_nnz = torch.sum(alpha_2_hat !=0,\n", + " dim=tuple([i for i in range(len(alpha_2_hat.shape))]),\n", + " dtype=torch.float)/alpha_2_hat.numel()\n", + "alpha_1_nnz = torch.sum(alpha_1_hat !=0,\n", + " dim=tuple([i for i in range(len(alpha_1_hat.shape))]),\n", + " dtype=torch.float)/alpha_1_hat.numel()\n", + "print(\n", + " f'beta2 shape = {beta_2.shape}' + \n", + " f'\\nalpha2^ nnz = {alpha_2_nnz}'+\n", + " f'\\nalpha2^ shape = {alpha_2_hat.shape}'+\n", + " f'\\nbeta1^ shape = {beta_1_hat.shape}'\n", + " f'\\nalpha1^ nnz = {alpha_1_nnz}'+\n", + " f'\\nalpha1^ shape = {alpha_1_hat.shape}'+\n", + " f'\\nimage^ shape = {recon.shape}'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(\n", + " f'recon min = {recon.min().item()}'+\n", + " f'\\nrecon mean = {recon.mean().item()}'+\n", + " f'\\nrecon max = {recon.max().item()}'+\n", + " f'\\nrecon std = {recon.std().item()}')\n", + "\n", + "plot_recon = ((recon.squeeze() * train_std_image) + train_mean_image).cpu().numpy().transpose(1,2,0)\n", + "fig, ax = plot.subplots(nrows=1, ncols=1)\n", + "ax = pf.clear_axis(ax)\n", + "ax.imshow(plot_recon, vmin=0, vmax=1)\n", + "plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_recon = recon.squeeze().cpu().numpy().transpose(1,2,0)\n", + "plot_recon = (plot_recon - plot_recon.min()) / (plot_recon.max() - plot_recon.min())\n", + "fig, ax = plot.subplots(nrows=1, ncols=1)\n", + "ax = pf.clear_axis(ax)\n", + "ax.imshow(plot_recon)\n", + "plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.7" + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 160ace4ea393fa9a313cdbfee2239add3ddb8e45 Mon Sep 17 00:00:00 2001 From: Dylan Date: Thu, 4 Mar 2021 11:47:24 +0000 Subject: [PATCH 34/44] updates to lca module for clarity & consistency removes unnecessary thresholding operators from modules/activations cleans up inhibitory connectivity function & adds convolutional version switched to torch.mm instead of torch.matmul when matrices are 2D layer_channels behaves like mlp, and now must include input channels minor comment addition/removal in pooling_module returned cifar preprocessing to be samplewise standardization fixed bug in tests with new dataset outputs updates ensemble lca test with comments & fixes ensemble state dict loading bug --- modules/activations.py | 21 ++++------ modules/lca_module.py | 36 ++++++++-------- modules/pooling_module.py | 4 +- params/lca_cifar10_params.py | 6 +-- params/lca_dsprites_params.py | 2 +- params/lca_mlp_cifar10_params.py | 4 +- params/lca_mlp_mnist_params.py | 2 +- params/lca_mnist_params.py | 4 +- params/lca_pool_cifar10_params.py | 4 +- params/lca_pool_lca_cifar10_params.py | 4 +- params/lca_pool_lca_pool_cifar10_params.py | 6 +-- .../lca_pool_lca_pool_mlp_cifar10_params.py | 6 +-- params/test_params.py | 23 +++++------ tests/test_data_processing.py | 2 +- tests/test_datasets.py | 4 +- tests/test_models.py | 41 ++++++++++++------- train_model.py | 3 +- utils/data_processing.py | 6 +-- utils/dataset_utils.py | 22 +++++----- 19 files changed, 101 insertions(+), 99 deletions(-) diff --git a/modules/activations.py b/modules/activations.py index 98446536..56a502a2 100644 --- a/modules/activations.py +++ b/modules/activations.py @@ -1,17 +1,5 @@ import torch import torch.nn as nn -import torch.nn.functional as F - -def activation_picker(activation_function): - if activation_function == 'identity': - return lambda x: x - if activation_function == 'relu': - return F.relu - if activation_function == 'lrelu' or activation_function == 'leaky_relu': - return F.leaky_relu - if activation_function == 'lca_threshold': - return lca_threshold - assert False, (f'Activation function {activation_function} is not supported.') def lca_threshold(u_in, thresh_type, rectify, sparse_threshold): u_zeros = torch.zeros_like(u_in) @@ -40,3 +28,12 @@ def lca_threshold(u_in, thresh_type, rectify, sparse_threshold): else: assert False, (f'Parameter thresh_type must be "soft" or "hard", not {thresh_type}') return a_out + +def activation_picker(activation_function): + if activation_function == 'identity': + return nn.Identity() + if activation_function == 'relu': + return nn.ReLU() + if activation_function == 'lrelu' or activation_function == 'leaky_relu': + return nn.LeakyReLU() + assert False, (f'Activation function {activation_function} is not supported.') diff --git a/modules/lca_module.py b/modules/lca_module.py index b8195626..b8446aae 100644 --- a/modules/lca_module.py +++ b/modules/lca_module.py @@ -16,39 +16,34 @@ class LcaModule(nn.Module): kernel_size: [int] edge size of the square convolving kernel stride: [int] vertical and horizontal stride of the convolution padding: [int] zero-padding added to both sides of the input - TODO: Inference process should be streamlined by defining only a single step and iterating it in forward() as is done here: - https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html - - TODO: Remove dependency on data_shape to make more intuitive in a hierarchy. i.e. use layer_channels as is done in the mlp """ def setup_module(self, params): self.params = params if self.params.layer_types[0] == 'fc': - self.layer_output_shapes = [[self.params.layer_channels]] - self.w_shape = [self.params.num_pixels, self.params.layer_channels] + self.w_shape = self.params.layer_channels[::-1] #[outputs, inputs] + self.layer_output_shape = [self.params.layer_channels[-1]] else: - self.layer_output_shapes = [self.params.data_shape] # [channels, height, width] assert (self.params.data_shape[-1] % self.params.stride == 0), ( f'Stride = {self.params.stride} must divide evenly into input edge size = {self.params.data_shape[-1]}') self.w_shape = [ - self.params.layer_channels, - self.params.data_shape[0], # channels = 1 + self.params.layer_channels[1], + self.params.layer_channels[0], self.params.kernel_size, self.params.kernel_size ] output_height = compute_conv_output_shape( - self.layer_output_shapes[-1][1], + self.params.data_shape[1], self.params.kernel_size, self.params.stride, self.params.padding, dilation=1) output_width = compute_conv_output_shape( - self.layer_output_shapes[-1][2], + self.params.data_shape[2], self.params.kernel_size, self.params.stride, self.params.padding, dilation=1) - self.layer_output_shapes.append([self.params.layer_channels, output_height, output_width]) + self.layer_output_shape = [self.params.layer_channels[1], output_height, output_width] w_init = torch.randn(self.w_shape) w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps) self.weight = nn.Parameter(w_init_normed, requires_grad=True) @@ -60,7 +55,7 @@ def preprocess_data(self, input_tensor): def compute_excitatory_current(self, input_tensor, a_in): if self.params.layer_types[0] == 'fc': - excitatory_current = torch.matmul(input_tensor, self.weight) + excitatory_current = torch.mm(input_tensor, self.weight.T) else: recon = self.get_recon_from_latents(a_in) recon_error = input_tensor - recon @@ -75,12 +70,13 @@ def compute_excitatory_current(self, input_tensor, a_in): return excitatory_current def compute_inhibitory_connectivity(self): + identity = torch.eye(self.params.layer_channels[1], + requires_grad=True, device=self.params.device) if self.params.layer_types[0] == 'fc': - inhibitory_connectivity = torch.matmul(torch.transpose(self.weight, dim0=0, dim1=1), - self.weight) - torch.eye(self.params.layer_channels, - requires_grad=True, device=self.params.device) + inhibitory_connectivity = torch.mm(self.weight, self.weight.T) - identity else: - inhibitory_connectivity = 0 # TODO: return Grammian along channel dim for a single kernel location + conv_kernels = self.weight.view(1, -1) + inhibitory_connectivity = torch.mm(conv_kernels, conv_kernels.T) - identity return inhibitory_connectivity def threshold_units(self, u_in): @@ -90,7 +86,7 @@ def threshold_units(self, u_in): def step_inference(self, u_in, a_in, excitatory_current, inhibitory_connectivity, step): if self.params.layer_types[0] == 'fc': - lca_explain_away = torch.matmul(a_in, inhibitory_connectivity) + lca_explain_away = torch.mm(a_in, inhibitory_connectivity) else: lca_explain_away = 0 # already computed in excitatory_current du = excitatory_current - lca_explain_away - u_in @@ -98,7 +94,7 @@ def step_inference(self, u_in, a_in, excitatory_current, inhibitory_connectivity return u_out, lca_explain_away def infer_coefficients(self, input_tensor): - output_shape = [input_tensor.shape[0]] + self.layer_output_shapes[-1] + output_shape = [input_tensor.shape[0]] + self.layer_output_shape u_list = [torch.zeros(output_shape, device=self.params.device)] a_list = [self.threshold_units(u_list[0])] excitatory_current = self.compute_excitatory_current(input_tensor, a_list[-1]) @@ -119,7 +115,7 @@ def infer_coefficients(self, input_tensor): def get_recon_from_latents(self, a_in): if self.params.layer_types[0] == 'fc': - recon = torch.matmul(a_in, torch.transpose(self.weight, dim0=0, dim1=1)) + recon = torch.mm(a_in, self.weight) else: recon = F.conv_transpose2d( input=a_in, diff --git a/modules/pooling_module.py b/modules/pooling_module.py index 8b181cc2..80f2bca2 100644 --- a/modules/pooling_module.py +++ b/modules/pooling_module.py @@ -11,8 +11,7 @@ def setup_module(self, params): in_features=self.params.layer_channels[0], out_features=self.params.layer_channels[1], bias=False) - self.weight = self.layer.weight - #self.register_parameter('fc_pool_'+self.params.layer_name+'_w', self.layer.weight) + self.weight = self.layer.weight # [outputs, inputs] elif self.params.layer_types[0] == 'conv': self.layer = nn.Conv2d( @@ -25,7 +24,6 @@ def setup_module(self, params): bias=False) nn.init.orthogonal_(self.layer.weight) # initialize to orthogonal matrix self.weight = self.layer.weight - #self.register_parameter('conv_pool_'+self.params.layer_name+'_w', self.layer.weight) else: assert False, ('layer_types[0] parameter must be "fc", "conv", not %g'%(layer_types[0])) diff --git a/params/lca_cifar10_params.py b/params/lca_cifar10_params.py index 791a9d13..40390315 100644 --- a/params/lca_cifar10_params.py +++ b/params/lca_cifar10_params.py @@ -19,7 +19,7 @@ def set_params(self): self.num_epochs = 500 self.train_logs_per_epoch = 6 self.renormalize_weights = True - self.layer_channels = 128 + self.layer_channels = [3, 128] self.kernel_size = 8 self.stride = 2 self.padding = 0 @@ -42,6 +42,6 @@ def compute_helper_params(self): self.optimizer.milestones = [frac * self.num_epochs for frac in self.optimizer.lr_annealing_milestone_frac] self.step_size = self.dt / self.tau - self.out_channels = self.layer_channels self.num_pixels = 3072 - self.in_channels = 3 + self.in_channels = self.layer_channels[0] + self.out_channels = self.layer_channels[1] diff --git a/params/lca_dsprites_params.py b/params/lca_dsprites_params.py index a932ec63..f0090e40 100644 --- a/params/lca_dsprites_params.py +++ b/params/lca_dsprites_params.py @@ -27,7 +27,7 @@ def set_params(self): self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.5 self.renormalize_weights = True - self.layer_channels = int(self.num_pixels*1.5) + self.layer_channels = [1, int(self.num_pixels*1.5)] self.dt = 0.001 self.tau = 0.03 self.num_steps = 75 diff --git a/params/lca_mlp_cifar10_params.py b/params/lca_mlp_cifar10_params.py index 979696ca..475fdee2 100644 --- a/params/lca_mlp_cifar10_params.py +++ b/params/lca_mlp_cifar10_params.py @@ -32,7 +32,7 @@ def set_params(self): self.weight_decay = 0.0 self.weight_lr = 0.001 self.renormalize_weights = True - self.layer_channels = 512 + self.layer_channels = [3, 512] self.kernel_size = 8 self.stride = 2 self.padding = 0 @@ -86,7 +86,7 @@ def set_params(self): lca_params_inst.stride, lca_params_inst.padding, dilation=1) - lca_output_shape = [lca_params_inst.layer_channels, lca_output_height, lca_output_width] + lca_output_shape = [lca_params_inst.layer_channels[1], lca_output_height, lca_output_width] mlp_params_inst.layer_channels[0] = np.prod(lca_output_shape) self.ensemble_params = [lca_params_inst, mlp_params_inst] for key, value in shared_params().__dict__.items(): diff --git a/params/lca_mlp_mnist_params.py b/params/lca_mlp_mnist_params.py index 8c8c6d4d..6f230aac 100644 --- a/params/lca_mlp_mnist_params.py +++ b/params/lca_mlp_mnist_params.py @@ -37,7 +37,7 @@ def set_params(self): self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.5 self.renormalize_weights = True - self.layer_channels = 768 + self.layer_channels = [1, 768] self.dt = 0.001 self.tau = 0.03 self.num_steps = 75 diff --git a/params/lca_mnist_params.py b/params/lca_mnist_params.py index eb248c9d..6274ef54 100644 --- a/params/lca_mnist_params.py +++ b/params/lca_mnist_params.py @@ -35,7 +35,7 @@ def set_params(self): self.weight_lr = 0.001 self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.8 - self.layer_channels = 128 + self.layer_channels = [1, 128] self.kernel_size = 8 self.stride = 2 self.padding = 0 @@ -48,7 +48,7 @@ def set_params(self): self.weight_lr = 0.1 self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.5 - self.layer_channels = 768 #self.num_pixels * 4 + self.layer_channels = [1, 768] #self.num_pixels * 4 self.compute_helper_params() def compute_helper_params(self): diff --git a/params/lca_pool_cifar10_params.py b/params/lca_pool_cifar10_params.py index 8a05fefb..740ff86d 100644 --- a/params/lca_pool_cifar10_params.py +++ b/params/lca_pool_cifar10_params.py @@ -31,7 +31,7 @@ def set_params(self): self.weight_decay = 0.0 self.weight_lr = 0.001 self.renormalize_weights = True - self.layer_channels = 128 + self.layer_channels = [3, 128] self.kernel_size = 8 self.stride = 2 self.padding = 0 @@ -91,7 +91,7 @@ def set_params(self): lca_params_inst.stride, lca_params_inst.padding, dilation=1) - lca_output_shape = [lca_params_inst.layer_channels, lca_output_height, lca_output_width] + lca_output_shape = [lca_params_inst.layer_channels[1], lca_output_height, lca_output_width] pooling_params_inst.layer_channels[0] = np.prod(lca_output_shape) self.ensemble_params = [lca_params_inst, pooling_params_inst] for key, value in shared_params().__dict__.items(): diff --git a/params/lca_pool_lca_cifar10_params.py b/params/lca_pool_lca_cifar10_params.py index ceb1037c..5d6acfe2 100644 --- a/params/lca_pool_lca_cifar10_params.py +++ b/params/lca_pool_lca_cifar10_params.py @@ -32,7 +32,7 @@ def set_params(self): self.weight_decay = 0.0 self.weight_lr = 0.001 self.renormalize_weights = True - self.layer_channels = 128 + self.layer_channels = [3, 128] self.kernel_size = 8 self.stride = 2 self.padding = 0 @@ -80,7 +80,7 @@ def set_params(self): for key, value in shared_params().__dict__.items(): setattr(self, key, value) for key, value in lca_1_params().__dict__.items(): setattr(self, key, value) self.layer_name = 'lca_2' - self.layer_channels = 256 + self.layer_channels = [32, 256] self.kernel_size = 6 self.stride = 1 self.padding = 0 diff --git a/params/lca_pool_lca_pool_cifar10_params.py b/params/lca_pool_lca_pool_cifar10_params.py index 01053ca8..2ccce21d 100644 --- a/params/lca_pool_lca_pool_cifar10_params.py +++ b/params/lca_pool_lca_pool_cifar10_params.py @@ -32,7 +32,7 @@ def set_params(self): self.weight_decay = 0.0 self.weight_lr = 0.001 self.renormalize_weights = True - self.layer_channels = 128 + self.layer_channels = [3, 128] self.kernel_size = 8 self.stride = 2 self.padding = 0 @@ -81,7 +81,7 @@ def set_params(self): for key, value in shared_params().__dict__.items(): setattr(self, key, value) for key, value in lca_1_params().__dict__.items(): setattr(self, key, value) self.layer_name = 'lca_2' - self.layer_channels = 256 + self.layer_channels = [32, 256] self.kernel_size = 6 self.stride = 1 self.padding = 0 @@ -158,7 +158,7 @@ def set_params(self): lca_2_params_inst.stride, lca_2_params_inst.padding, dilation=1) - lca_2_flat_dim = lca_2_params_inst.layer_channels*lca_2_output_height*lca_2_output_width + lca_2_flat_dim = lca_2_params_inst.layer_channels[1]*lca_2_output_height*lca_2_output_width pooling_2_params_inst.layer_channels[0] = lca_2_flat_dim self.ensemble_params = [ lca_1_params_inst, diff --git a/params/lca_pool_lca_pool_mlp_cifar10_params.py b/params/lca_pool_lca_pool_mlp_cifar10_params.py index f4b322df..45ae965a 100644 --- a/params/lca_pool_lca_pool_mlp_cifar10_params.py +++ b/params/lca_pool_lca_pool_mlp_cifar10_params.py @@ -33,7 +33,7 @@ def set_params(self): self.weight_decay = 0.0 self.weight_lr = 0#1e-3 self.renormalize_weights = True - self.layer_channels = 128 + self.layer_channels = [3, 128] self.kernel_size = 8 self.stride = 2 self.padding = 0 @@ -83,7 +83,7 @@ def set_params(self): for key, value in lca_1_params().__dict__.items(): setattr(self, key, value) self.layer_name = 'lca_2' self.weight_lr = 0#1e-3 - self.layer_channels = 256 + self.layer_channels = [32, 256] self.kernel_size = 6 self.stride = 1 self.padding = 0 @@ -180,7 +180,7 @@ def set_params(self): lca_2_params_inst.stride, lca_2_params_inst.padding, dilation=1) - lca_2_flat_dim = lca_2_params_inst.layer_channels*lca_2_output_height*lca_2_output_width + lca_2_flat_dim = lca_2_params_inst.layer_channels[1]*lca_2_output_height*lca_2_output_width pooling_2_params_inst.layer_channels[0] = lca_2_flat_dim self.ensemble_params = [ lca_1_params_inst, diff --git a/params/test_params.py b/params/test_params.py index 9e78c701..80a25b6b 100644 --- a/params/test_params.py +++ b/params/test_params.py @@ -44,15 +44,13 @@ def __init__(self): class base_params(BaseParams): def set_params(self): super(base_params, self).set_params() - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) + for key, value in shared_params().__dict__.items(): setattr(self, key, value) class lca_params(BaseParams): def set_params(self): super(lca_params, self).set_params() - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) + for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'lca' self.weight_decay = 0.0 self.weight_lr = 0.1 @@ -68,7 +66,7 @@ def set_params(self): self.rectify_a = True self.thresh_type = 'soft' self.sparse_mult = 0.25 - self.layer_channels = 128 + self.layer_channels = [64, 128] self.optimizer.milestones = [frac * self.num_epochs for frac in self.optimizer.lr_annealing_milestone_frac] self.step_size = self.dt / self.tau @@ -91,8 +89,7 @@ def set_params(self): class pooling_params(BaseParams): def set_params(self): super(pooling_params, self).set_params() - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) + for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'pooling' self.layer_name = 'test_pool_1' self.weight_lr = 1e-3 @@ -111,8 +108,7 @@ def set_params(self): class mlp_params(BaseParams): def set_params(self): super(mlp_params, self).set_params() - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) + for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'mlp' self.weight_lr = 1e-4 self.weight_decay = 0.0 @@ -132,6 +128,9 @@ def set_params(self): class ensemble_params(BaseParams): def set_params(self): super(ensemble_params, self).set_params() - self.ensemble_params = [lca_params(), mlp_params()] - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) + layer1_params = lca_params() + layer1_params.layer_name = 'layer1' + layer2_params = mlp_params() + layer2_params.layer_name = 'layer2' + self.ensemble_params = [layer1_params, layer2_params] + for key, value in shared_params().__dict__.items(): setattr(self, key, value) diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py index 7f6d43bf..42092eca 100644 --- a/tests/test_data_processing.py +++ b/tests/test_data_processing.py @@ -236,7 +236,7 @@ def test_atleastkd(self): np.testing.assert_equal(new_x.ndim, test_nd) def test_l2_weight_norm(self): - w_fc = np.random.standard_normal([24, 38]) + w_fc = np.random.standard_normal([38, 24]) w_conv = np.random.standard_normal([38, 24, 8, 8]) for w in [w_fc, w_conv]: w_norm = dp.get_weights_l2_norm(torch.tensor(w), eps=1e-12).numpy() diff --git a/tests/test_datasets.py b/tests/test_datasets.py index dc2f1fc1..c2ac7ea2 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -34,7 +34,7 @@ def test_mnist(self): params.dataset = 'mnist' params.shuffle_data = True params.batch_size = 10000 - train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params) + train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params)[:4] for key, value in data_params.items(): setattr(params, key, value) assert len(train_loader.dataset) == params.epoch_size @@ -61,7 +61,7 @@ def test_synthetic(self): params.dist_type = dist_type params.num_classes = num_classes params.rand_state = rand_state - train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params) + train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params)[:4] for key, value in data_params.items(): setattr(params, key, value) assert len(train_loader.dataset) == epoch_size diff --git a/tests/test_models.py b/tests/test_models.py index 080cdd46..2bcc3877 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -19,13 +19,14 @@ def setUp(self): self.model_list = loaders.get_model_list(self.dsc_dir) self.test_params_file = os.path.join(*[self.dsc_dir, 'params', 'test_params.py']) + ### TODO - define endpoint function for checkpoint loading & test independently ### TODO - add ability to test multiple options (e.g. 'conv' and 'fc') from test params def test_model_loading(self): for model_type in self.model_list: model_type = '_'.join(model_type.split('_')[:-1]) # remove '_model' at the end model = loaders.load_model(model_type) params = loaders.load_params_file(self.test_params_file, key=model_type+'_params') - train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params) + train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params)[:4] for key, value in data_params.items(): setattr(params, key, value) model.setup(params) @@ -53,20 +54,19 @@ def test_model_loading(self): def test_lca_ensemble_gradients(self): + ## Load models params = {} models = {} params['lca'] = loaders.load_params_file(self.test_params_file, key='lca_params') params['lca'].train_logs_per_epoch = None params['lca'].shuffle_data = False - train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params['lca']) - for key, value in data_params.items(): - setattr(params['lca'], key, value) + train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params['lca'])[:4] + for key, value in data_params.items(): setattr(params['lca'], key, value) models['lca'] = loaders.load_model(params['lca'].model_type) models['lca'].setup(params['lca']) models['lca'].to(params['lca'].device) params['ensemble'] = loaders.load_params_file(self.test_params_file, key='ensemble_params') - for key, value in data_params.items(): - setattr(params['ensemble'], key, value) + for key, value in data_params.items(): setattr(params['ensemble'], key, value) err_msg = f'\ndata_shape={params["ensemble"].data_shape}' err_msg += f'\nnum_pixels={params["ensemble"].num_pixels}' err_msg += f'\nbatch_size={params["ensemble"].batch_size}' @@ -74,9 +74,11 @@ def test_lca_ensemble_gradients(self): models['ensemble'] = loaders.load_model(params['ensemble'].model_type) models['ensemble'].setup(params['ensemble']) models['ensemble'].to(params['ensemble'].device) + ## Overwrite weight initialization so that they have the same weights ensemble_state_dict = models['ensemble'].state_dict() - ensemble_state_dict['lca.weight'] = models['lca'].weight.clone() + ensemble_state_dict['layer1.weight'] = models['lca'].weight.clone() models['ensemble'].load_state_dict(ensemble_state_dict) + ## Load data data, target = next(iter(train_loader)) train_data_batch = models['lca'].preprocess_data(data.to(params['lca'].device)) train_target_batch = target.to(params['lca'].device) @@ -84,21 +86,31 @@ def test_lca_ensemble_gradients(self): for submodel in models['ensemble']: submodel.optimizer.zero_grad() inputs = [train_data_batch] # only the first model acts on input + ## Verify feedforward encodings + lca_encoding = models['lca'](inputs[0]).cpu().detach().numpy() + ensemble_encoding = models['ensemble'][0].get_encodings(inputs[0]).cpu().detach().numpy() + assert np.all(lca_encoding == ensemble_encoding), (err_msg+'\n' + +f'Forward encodings for lca and ensemble[0] should be equal, but are not') + ## Verify LCA loss + lca_loss = models['lca'].get_total_loss((inputs[0], train_target_batch)) + ensemble_losses = [models['ensemble'].get_total_loss((inputs[0], train_target_batch), 0)] + lca_loss_val = lca_loss.cpu().detach().numpy() + ensemble_loss_val = ensemble_losses[0].cpu().detach().numpy() + assert lca_loss_val == ensemble_loss_val, (err_msg+'\n' + +f'Losses should be equal, but are lca={lca_loss_val} and ensemble={ensemble_loss_val}') + ## Compute remaining ensemble outputs for submodel in models['ensemble']: inputs.append(submodel.get_encodings(inputs[-1]).detach()) - lca_loss = models['lca'].get_total_loss((train_data_batch, train_target_batch)) - ensemble_losses = [models['ensemble'].get_total_loss((inputs[0], train_target_batch), 0)] ensemble_losses.append(models['ensemble'].get_total_loss((inputs[1], train_target_batch), 1)) + ## Verify lca grad & ensemble grad are equal lca_loss.backward() ensemble_losses[0].backward() ensemble_losses[1].backward() - lca_loss_val = lca_loss.cpu().detach().numpy() lca_w_grad = models['lca'].weight.grad.cpu().numpy() - ensemble_loss_val = ensemble_losses[0].cpu().detach().numpy() ensemble_w_grad = models['ensemble'][0].weight.grad.cpu().numpy() - assert lca_loss_val == ensemble_loss_val, (err_msg+'\n' - +'Losses should be equal, but are lca={lca_loss_val} and ensemble={ensemble_loss_val}') - assert np.all(lca_w_grad == ensemble_w_grad), (err_msg+'\nGrads should be equal, but are not.') + assert np.all(lca_w_grad == ensemble_w_grad), (err_msg+'\n' + +f'Grads should be equal, but are not.') + ## Verify weight updates are equal lca_pre_train_w = models['lca'].weight.cpu().detach().numpy().copy() ensemble_pre_train_w = models['ensemble'][0].weight.cpu().detach().numpy().copy() run_utils.train_epoch(1, models['lca'], train_loader) @@ -113,3 +125,4 @@ def test_lca_ensemble_gradients(self): +"ensemble weights are not different from init after one epoch of training") assert np.all(lca_w == ensemble_w), (err_msg+'\n' +"lca & ensemble weights are not equal after one epoch of training") + diff --git a/train_model.py b/train_model.py index b84c7c66..1bad1364 100644 --- a/train_model.py +++ b/train_model.py @@ -26,8 +26,7 @@ # Load data train_loader, val_loader, test_loader, data_stats = dataset_utils.load_dataset(params)[:4] -for key, value in data_stats.items(): - setattr(params, key, value) +for key, value in data_stats.items(): setattr(params, key, value) # Load model model = loaders.load_model(params.model_type) diff --git a/utils/data_processing.py b/utils/data_processing.py index 933ee369..16482dfa 100644 --- a/utils/data_processing.py +++ b/utils/data_processing.py @@ -330,15 +330,15 @@ def get_weights_l2_norm(w, eps=1e-12): Return l2 norm of weight matrix Keyword arguments: - w [Tensor] assumed to have shape [inC, outC] or [outC, inC, kernH, kernW] + w [Tensor] assumed to have shape [outC, inC] or [outC, inC, kernH, kernW] norm is calculated over vectorized version of inC in the first case or inC*kernH*kernW in the second eps [float] minimum value to prevent division by zero Outputs: norm [Tensor] norm of each of the outC weight vectors """ - if w.ndim == 2: # fully-connected, [inputs, outputs] - norms = torch.norm(w, dim=0, keepdim=True) + if w.ndim == 2: # fully-connected, [outputs, inputs] + norms = torch.norm(w, dim=1, keepdim=True) elif w.ndim == 4: # convolutional, [out_channels, in_channels, kernel_height, kernel_width] norms = torch.norm(w.flatten(start_dim=1), dim=-1, keepdim=True) else: diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py index 8a6a6f7e..8c17df11 100644 --- a/utils/dataset_utils.py +++ b/utils/dataset_utils.py @@ -119,21 +119,21 @@ def load_dataset(params): transforms.Lambda(lambda x: x - dataset_mean_image)) extra_outputs['dataset_mean_image'] = dataset_mean_image if params.standardize_data: - dataset = torchvision.datasets.CIFAR10(**kwargs) - data_loader = torch.utils.data.DataLoader(dataset, batch_size=params.batch_size, - shuffle=False, num_workers=0, pin_memory=True) - dataset_mean_image = dp.get_mean_from_dataloader(data_loader) - extra_outputs['dataset_mean_image'] = dataset_mean_image - dataset_std_image = dp.get_std_from_dataloader(data_loader, dataset_mean_image) - extra_outputs['dataset_std_image'] = dataset_std_image + #dataset = torchvision.datasets.CIFAR10(**kwargs) + #data_loader = torch.utils.data.DataLoader(dataset, batch_size=params.batch_size, + # shuffle=False, num_workers=0, pin_memory=True) + #dataset_mean_image = dp.get_mean_from_dataloader(data_loader) + #extra_outputs['dataset_mean_image'] = dataset_mean_image + #dataset_std_image = dp.get_std_from_dataloader(data_loader, dataset_mean_image) + #extra_outputs['dataset_std_image'] = dataset_std_image preprocessing_pipeline.append( transforms.Lambda( lambda x: dp.standardize(x, eps=params.eps, - samplewise=False, - batch_size=params.batch_size, - sample_mean=dataset_mean_image, - sample_std=dataset_std_image)[0] + samplewise=True,#False, + batch_size=params.batch_size)[0] + #sample_mean=dataset_mean_image, + #sample_std=dataset_std_image)[0] ) ) if params.rescale_data_to_one: From 4ee8bc2475b5b53ecea5c7404106726e80247373 Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 10 Mar 2021 13:12:43 +0000 Subject: [PATCH 35/44] adds reconstructions from each layer --- notebooks/visualize_pooling_weights.ipynb | 349 ++++++++++++++++------ 1 file changed, 261 insertions(+), 88 deletions(-) diff --git a/notebooks/visualize_pooling_weights.ipynb b/notebooks/visualize_pooling_weights.ipynb index f459cf22..c79282e9 100644 --- a/notebooks/visualize_pooling_weights.ipynb +++ b/notebooks/visualize_pooling_weights.ipynb @@ -49,8 +49,8 @@ "outputs": [], "source": [ "workspace_dir = '/mnt/qb/bethge/dpaiton/'\n", - "model_name = 'lca_pool_lca_pool_cifar10'\n", - "model_version = '0'\n", + "model_name = 'smt_cifar10'\n", + "model_version = 'lplp'\n", "log_file = workspace_dir + os.path.join(*['Projects', model_name, 'logfiles', f'{model_name}_v{model_version}.log'])\n", "logger = Logger(log_file, overwrite=False)\n", "log_text = logger.load_file()\n", @@ -89,17 +89,6 @@ "print(model_state_str)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "train_loader, val_loader, test_loader, data_stats, data_mean_std = dataset_utils.load_dataset(params)\n", - "train_mean_image = data_mean_std['dataset_mean_image'].to(model.params.device)\n", - "train_std_image = data_mean_std['dataset_std_image'].to(model.params.device)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -885,7 +874,7 @@ " bf_stats,\n", " affinity,\n", " num_pooling_filters=outputs,\n", - " num_connected_weights=20,\n", + " num_connected_weights=15,\n", " lines=True,\n", " figsize=(10, 10))\n", "\n", @@ -895,6 +884,18 @@ " bbox_inches='tight')" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "params.standardize_data = False\n", + "train_loader, val_loader, test_loader, data_stats, data_mean_std = dataset_utils.load_dataset(params)\n", + "train_mean_image = 0#data_mean_std['dataset_mean_image'].to(model.params.device)\n", + "train_std_image = 1#data_mean_std['dataset_std_image'].to(model.params.device)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -903,22 +904,34 @@ "source": [ "example_batch = next(iter(train_loader))[0].to(model.params.device)\n", "example_batch = model[0].preprocess_data(example_batch)\n", - "example_batch *= train_std_image\n", - "example_batch += train_mean_image\n", + "#example_batch *= train_std_image\n", + "#example_batch += train_mean_image\n", "batch_min = example_batch.min().item()\n", "batch_max = example_batch.max().item()\n", "\n", "example_image = example_batch[0, ...]\n", "print(\n", - " f'min = {example_image.min().item()}'+\n", - " f'\\nmean = {example_image.mean().item()}'+\n", - " f'\\nmax = {example_image.max().item()}'+\n", - " f'\\nstd = {example_image.std().item()}')\n", + " f'example image min = {example_image.min().item()}'+\n", + " f'\\nexample image mean = {example_image.mean().item()}'+\n", + " f'\\nexample image max = {example_image.max().item()}'+\n", + " f'\\nexample image std = {example_image.std().item()}')\n", + "preproc_image, example_image_mean, example_image_std = dp.standardize(example_image[None, ...], samplewise=True)\n", + "print(\n", + " f'preproc image min = {preproc_image.min().item()}'+\n", + " f'\\npreproc image mean = {preproc_image.mean().item()}'+\n", + " f'\\npreproc image max = {preproc_image.max().item()}'+\n", + " f'\\npreproc image std = {preproc_image.std().item()}')\n", "\n", "plot_example_image = ((example_image * train_std_image) + train_mean_image).cpu().numpy().transpose(1,2,0)\n", - "fig, ax = plot.subplots(nrows=1, ncols=1)\n", - "ax = pf.clear_axis(ax)\n", + "plot_preproc_image = ((preproc_image - preproc_image.min())/(preproc_image.max() - preproc_image.min()))[0,...].cpu().numpy().transpose(1,2,0)\n", + "\n", + "fig, axs = plot.subplots(nrows=1, ncols=2)\n", + "ax = pf.clear_axis(axs[0])\n", "ax.imshow(plot_example_image, vmin=0, vmax=1)\n", + "ax.format(title='Original')\n", + "ax = pf.clear_axis(axs[1])\n", + "ax.format(title='Preprocessed')\n", + "ax.imshow(plot_preproc_image, vmin=0, vmax=1)\n", "plot.show()" ] }, @@ -928,7 +941,10 @@ "metadata": {}, "outputs": [], "source": [ - "beta_2 = model(example_image[None,...])" + "alpha_1 = model.lca_1.get_encodings(preproc_image)\n", + "beta_1 = model.pool_1.get_encodings(alpha_1)\n", + "alpha_2 = model.lca_2.get_encodings(beta_1)\n", + "beta_2 = model(preproc_image)" ] }, { @@ -957,8 +973,8 @@ " \n", "params = lca_2_recon_params()\n", "params.set_params()\n", - "params.layer_channels = model.lca_2.params.layer_channels\n", - "params.sparse_mult = model.lca_2.params.sparse_mult\n", + "params.layer_channels = list(model.pool_2.weight.shape)\n", + "params.sparse_mult = 0.0#model.lca_2.params.sparse_mult\n", "params.data_shape = list(beta_2.shape)\n", "params.epoch_size = 1\n", "params.num_pixels = np.prod(params.data_shape)\n", @@ -968,8 +984,50 @@ "lca_2_recon_model.to(params.device)\n", "lca_2_recon_model.eval()\n", "with torch.no_grad():\n", - " lca_2_recon_model.weight = nn.Parameter(model.pool_2.weight)\n", - "alpha_2_hat = lca_2_recon_model(beta_2)[:, :, None, None]" + " lca_2_recon_model.weight = nn.Parameter(model.pool_2.weight.T)\n", + "a2h_b2 = lca_2_recon_model(beta_2)[:, :, None, None]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alpha_2_bin = alpha_2.detach().cpu().numpy()>0\n", + "alpha_2_hat_bin = a2h_b2.detach().cpu().numpy()>0\n", + "\n", + "alpha_2_nnz = np.count_nonzero(alpha_2_bin)\n", + "alpha_2_hat_nnz = np.count_nonzero(alpha_2_hat_bin)\n", + "\n", + "hamming_dist = np.abs(np.sum(alpha_2_bin * alpha_2_hat_bin))\n", + "\n", + "num_alpha_2 = np.prod(list(a2h_b2.shape))\n", + "alpha_2_edge = int(np.sqrt(num_alpha_2))\n", + "\n", + "plot_alpha_2_hat = ((a2h_b2 - a2h_b2.min())/(a2h_b2.max() - a2h_b2.min())).reshape(alpha_2_edge, alpha_2_edge).detach().cpu().numpy()\n", + "\n", + "plot_alpha_2 = ((alpha_2 - alpha_2.min()) / (alpha_2.max() - alpha_2.min())).reshape(alpha_2_edge, alpha_2_edge).detach().cpu().numpy()\n", + "\n", + "alpha_diff = np.abs(plot_alpha_2 - plot_alpha_2_hat)\n", + "\n", + "fig, axs = plot.subplots(nrows=1, ncols=3)\n", + "\n", + "ax = pf.clear_axis(axs[0])\n", + "ax.imshow(alpha_diff, vmin=0, vmax=1)\n", + "ax.format(title='alpha 2 differences')\n", + "\n", + "ax = pf.clear_axis(axs[1])\n", + "ax.imshow(plot_alpha_2_hat, vmin=0, vmax=1)\n", + "ax.format(title='alpha 2 hat')\n", + "\n", + "ax = pf.clear_axis(axs[2])\n", + "m = ax.imshow(plot_alpha_2, vmin=0, vmax=1)\n", + "ax.format(title='alpha 2')\n", + "\n", + "ax.colorbar(m, ax=ax)\n", + "axs.format(suptitle=f'active index overlap = {hamming_dist}; alpha 2 nnz = {alpha_2_nnz}; alpha 2 hat nnz = {alpha_2_hat_nnz}')\n", + "plot.show()" ] }, { @@ -979,14 +1037,59 @@ "outputs": [], "source": [ "with torch.no_grad():\n", - " beta_1_hat = F.conv_transpose2d(\n", - " input=alpha_2_hat,\n", + " b1h_a2h_b2 = F.conv_transpose2d(\n", + " input=a2h_b2,\n", + " weight=model.lca_2.weight,\n", + " bias=None,\n", + " stride=model.lca_2.params.stride,\n", + " padding=model.lca_2.params.padding)\n", + " \n", + " b1h_a2 = F.conv_transpose2d(\n", + " input=alpha_2,\n", " weight=model.lca_2.weight,\n", " bias=None,\n", " stride=model.lca_2.params.stride,\n", " padding=model.lca_2.params.padding)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "squared_error = torch.pow(beta_1 - b1h_a2h_b2, 2.)\n", + "l2_dist = 0.5 * torch.mean(squared_error).detach().cpu().numpy()\n", + "\n", + "num_beta_1 = np.prod(list(b1h_a2h_b2.shape))\n", + "beta_1_edge = int(np.floor(np.sqrt(num_beta_1)))\n", + "beta_1_resh = int(beta_1_edge**2)\n", + "\n", + "plot_beta_1_hat = ((b1h_a2h_b2 - b1h_a2h_b2.min())/(b1h_a2h_b2.max() - b1h_a2h_b2.min())).view(-1)[:beta_1_resh].reshape(beta_1_edge, beta_1_edge).detach().cpu().numpy()\n", + "\n", + "plot_beta_1 = ((beta_1 - beta_1.min()) / (beta_1.max() - beta_1.min())).view(-1)[:beta_1_resh].reshape(beta_1_edge, beta_1_edge).detach().cpu().numpy()\n", + "\n", + "beta_diff = np.abs(plot_beta_1 - plot_beta_1_hat)\n", + "\n", + "fig, axs = plot.subplots(nrows=1, ncols=3)\n", + "\n", + "ax = pf.clear_axis(axs[0])\n", + "ax.imshow(beta_diff, vmin=0, vmax=1)\n", + "ax.format(title='beta 1 differences')\n", + "\n", + "ax = pf.clear_axis(axs[1])\n", + "ax.imshow(plot_beta_1_hat, vmin=0, vmax=1)\n", + "ax.format(title='beta 1 hat')\n", + "\n", + "ax = pf.clear_axis(axs[2])\n", + "m = ax.imshow(plot_beta_1, vmin=0, vmax=1)\n", + "ax.format(title='beta 1')\n", + "\n", + "ax.colorbar(m, ax=ax)\n", + "axs.format(suptitle=f'l2 distance = {l2_dist:0.5f}')\n", + "plot.show()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1002,45 +1105,42 @@ " def setup_module(self, params):\n", " super(TransposedLcaModule, self).setup_module(params)\n", " if self.params.layer_types[0] == 'conv':\n", - " self.layer_output_shapes = [self.params.data_shape] # [channels, height, width]\n", " assert (self.params.data_shape[-1] % self.params.stride == 0), (\n", " f'Stride = {self.params.stride} must divide evenly into input edge size = {self.params.data_shape[-1]}')\n", " self.w_shape = [\n", - " self.params.layer_channels,\n", - " self.params.data_shape[0], # channels = 1\n", + " self.params.layer_channels[1],\n", + " self.params.layer_channels[0],\n", " self.params.kernel_size,\n", " self.params.kernel_size\n", " ]\n", " output_height = compute_deconv_output_shape(\n", - " self.layer_output_shapes[-1][1],\n", + " self.params.data_shape[1],\n", " self.params.kernel_size,\n", " self.params.stride,\n", " self.params.padding,\n", " output_padding=self.params.output_padding,\n", " dilation=1)\n", " output_width = compute_deconv_output_shape(\n", - " self.layer_output_shapes[-1][2],\n", + " self.params.data_shape[2],\n", " self.params.kernel_size,\n", " self.params.stride,\n", " self.params.padding,\n", " output_padding=self.params.output_padding,\n", " dilation=1)\n", - " self.layer_output_shapes.append([self.params.layer_channels, output_height, output_width])\n", + " self.layer_output_shape = [self.params.layer_channels[1], output_height, output_width]\n", " w_init = torch.randn(self.w_shape)\n", " w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps)\n", " self.weight = nn.Parameter(w_init_normed, requires_grad=True)\n", "\n", - " def compute_excitatory_current(self, input_tensor, a_in, weight=None):\n", - " if weight is None:\n", - " weight = self.weight\n", + " def compute_excitatory_current(self, input_tensor, a_in):\n", " if self.params.layer_types[0] == 'fc':\n", - " excitatory_current = torch.matmul(input_tensor, weight)\n", + " excitatory_current = torch.matmul(input_tensor, self.weight.T)\n", " else:\n", - " recon = self.get_recon_from_latents(a_in, weight)\n", + " recon = self.get_recon_from_latents(a_in)\n", " recon_error = input_tensor - recon\n", " error_injection = F.conv_transpose2d(\n", " input=recon_error,\n", - " weight=weight,\n", + " weight=self.weight,\n", " bias=None,\n", " stride=self.params.stride,\n", " padding=self.params.padding,\n", @@ -1050,15 +1150,13 @@ " excitatory_current = error_injection + a_in\n", " return excitatory_current\n", "\n", - " def get_recon_from_latents(self, a_in, weight=None):\n", - " if weight is None:\n", - " weight = self.weight\n", + " def get_recon_from_latents(self, a_in):\n", " if self.params.layer_types[0] == 'fc':\n", - " recon = torch.matmul(a_in, torch.transpose(weight, dim0=0, dim1=1))\n", + " recon = torch.matmul(a_in, self.weight)\n", " else:\n", " recon = F.conv2d(\n", " input=a_in,\n", - " weight=weight,\n", + " weight=self.weight,\n", " bias=None,\n", " stride=self.params.stride,\n", " padding=self.params.padding,\n", @@ -1146,12 +1244,12 @@ " \n", "params = lca_1_recon_params()\n", "params.set_params()\n", - "params.layer_channels = model.pool_1.params.layer_channels[0]\n", + "params.layer_channels = model.pool_1.params.layer_channels[::-1]\n", "params.kernel_size = model.pool_1.params.pool_ksize\n", "params.stride = model.pool_1.params.pool_stride\n", "params.padding = 0\n", - "params.sparse_mult = 0.01#model.lca_1.params.sparse_mult\n", - "params.data_shape = list(beta_1_hat.shape[1:])\n", + "params.sparse_mult = 0.00#model.lca_1.params.sparse_mult\n", + "params.data_shape = list(b1h_a2h_b2.shape[1:])\n", "params.epoch_size = 1\n", "params.output_padding = 1\n", "params.num_pixels = np.prod(params.data_shape)\n", @@ -1162,7 +1260,50 @@ "lca_1_recon_model.eval()\n", "with torch.no_grad():\n", " lca_1_recon_model.weight = nn.Parameter(model.pool_1.weight)\n", - "alpha_1_hat = lca_1_recon_model(beta_1_hat)" + "a1h_b1h_a2h_b2 = lca_1_recon_model(b1h_a2h_b2)\n", + "a1h_b1h_a2 = lca_1_recon_model(b1h_a2)\n", + "a1h_b1 = lca_1_recon_model(beta_1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alpha_1_bin = alpha_1.detach().cpu().numpy()>0\n", + "alpha_1_hat_bin = a1h_b1h_a2h_b2.detach().cpu().numpy()>0\n", + "hamming_dist = np.abs(np.sum(alpha_1_bin * alpha_1_hat_bin))\n", + "alpha_1_nnz = np.count_nonzero(alpha_1_bin)\n", + "alpha_1_hat_nnz = np.count_nonzero(alpha_1_hat_bin)\n", + "\n", + "num_alpha_1 = np.prod(list(a1h_b1h_a2h_b2.shape))\n", + "alpha_1_edge = int(np.floor(np.sqrt(num_alpha_1)))\n", + "alpha_1_resh = int(alpha_1_edge**2)\n", + "\n", + "plot_alpha_1_hat = ((a1h_b1h_a2h_b2 - a1h_b1h_a2h_b2.min())/(a1h_b1h_a2h_b2.max() - a1h_b1h_a2h_b2.min())).view(-1)[:alpha_1_resh].reshape(alpha_1_edge, alpha_1_edge).detach().cpu().numpy()\n", + "\n", + "plot_alpha_1 = ((alpha_1 - alpha_1.min()) / (alpha_1.max() - alpha_1.min())).view(-1)[:alpha_1_resh].reshape(alpha_1_edge, alpha_1_edge).detach().cpu().numpy()\n", + "\n", + "alpha_diff = np.abs(plot_alpha_1 - plot_alpha_1_hat)\n", + "\n", + "fig, axs = plot.subplots(nrows=1, ncols=3)\n", + "\n", + "ax = pf.clear_axis(axs[0])\n", + "ax.imshow(alpha_diff, vmin=0, vmax=1)\n", + "ax.format(title='alpha 1 differences')\n", + "\n", + "ax = pf.clear_axis(axs[1])\n", + "ax.imshow(plot_alpha_1_hat, vmin=0, vmax=1)\n", + "ax.format(title='alpha 1 hat')\n", + "\n", + "ax = pf.clear_axis(axs[2])\n", + "m = ax.imshow(plot_alpha_1, vmin=0, vmax=1)\n", + "ax.format(title='alpha 1')\n", + "\n", + "ax.colorbar(m, ax=ax)\n", + "axs.format(suptitle=f'active index overlap = {hamming_dist}; alpha 1 nnz = {alpha_1_nnz}; alpha 1 hat nnz = {alpha_1_hat_nnz}')\n", + "plot.show()" ] }, { @@ -1172,8 +1313,29 @@ "outputs": [], "source": [ "with torch.no_grad():\n", - " recon = F.conv_transpose2d(\n", - " input=alpha_1_hat,\n", + " recon_from_alpha_1 = F.conv_transpose2d(\n", + " input=alpha_1,\n", + " weight=model.lca_1.weight,\n", + " bias=None,\n", + " stride=model.lca_1.params.stride,\n", + " padding=model.lca_1.params.padding)\n", + " \n", + " recon_from_beta_1 = F.conv_transpose2d(\n", + " input=a1h_b1,\n", + " weight=model.lca_1.weight,\n", + " bias=None,\n", + " stride=model.lca_1.params.stride,\n", + " padding=model.lca_1.params.padding)\n", + " \n", + " recon_from_alpha_2 = F.conv_transpose2d(\n", + " input=a1h_b1h_a2,\n", + " weight=model.lca_1.weight,\n", + " bias=None,\n", + " stride=model.lca_1.params.stride,\n", + " padding=model.lca_1.params.padding)\n", + " \n", + " recon_from_beta_2 = F.conv_transpose2d(\n", + " input=a1h_b1h_a2h_b2,\n", " weight=model.lca_1.weight,\n", " bias=None,\n", " stride=model.lca_1.params.stride,\n", @@ -1186,39 +1348,57 @@ "metadata": {}, "outputs": [], "source": [ - "alpha_2_nnz = torch.sum(alpha_2_hat !=0,\n", - " dim=tuple([i for i in range(len(alpha_2_hat.shape))]),\n", - " dtype=torch.float)/alpha_2_hat.numel()\n", - "alpha_1_nnz = torch.sum(alpha_1_hat !=0,\n", - " dim=tuple([i for i in range(len(alpha_1_hat.shape))]),\n", - " dtype=torch.float)/alpha_1_hat.numel()\n", + "def plot_func(x):\n", + " x *= example_image_std\n", + " x += example_image_mean\n", + " x = ((x - x.min()) / (x.max() - x.min())).squeeze().cpu().numpy().transpose(1,2,0)\n", + " return(x)\n", + "\n", + "alpha_2_nnz = torch.sum(a2h_b2 !=0,\n", + " dim=tuple([i for i in range(len(a2h_b2.shape))]),\n", + " dtype=torch.float)/a2h_b2.numel()\n", + "alpha_1_nnz = torch.sum(a1h_b1h_a2h_b2 !=0,\n", + " dim=tuple([i for i in range(len(a1h_b1h_a2h_b2.shape))]),\n", + " dtype=torch.float)/a1h_b1h_a2h_b2.numel()\n", "print(\n", " f'beta2 shape = {beta_2.shape}' + \n", " f'\\nalpha2^ nnz = {alpha_2_nnz}'+\n", - " f'\\nalpha2^ shape = {alpha_2_hat.shape}'+\n", - " f'\\nbeta1^ shape = {beta_1_hat.shape}'\n", + " f'\\nalpha2^ shape = {a2h_b2.shape}'+\n", + " f'\\nbeta1^ shape = {b1h_a2h_b2.shape}'\n", " f'\\nalpha1^ nnz = {alpha_1_nnz}'+\n", - " f'\\nalpha1^ shape = {alpha_1_hat.shape}'+\n", - " f'\\nimage^ shape = {recon.shape}'\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + " f'\\nalpha1^ shape = {a1h_b1h_a2h_b2.shape}'+\n", + " f'\\nimage^ shape = {recon_from_beta_2.shape}'\n", + ")\n", "print(\n", - " f'recon min = {recon.min().item()}'+\n", - " f'\\nrecon mean = {recon.mean().item()}'+\n", - " f'\\nrecon max = {recon.max().item()}'+\n", - " f'\\nrecon std = {recon.std().item()}')\n", - "\n", - "plot_recon = ((recon.squeeze() * train_std_image) + train_mean_image).cpu().numpy().transpose(1,2,0)\n", - "fig, ax = plot.subplots(nrows=1, ncols=1)\n", - "ax = pf.clear_axis(ax)\n", - "ax.imshow(plot_recon, vmin=0, vmax=1)\n", + " f'recon min = {recon_from_beta_2.min().item()}'+\n", + " f'\\nrecon mean = {recon_from_beta_2.mean().item()}'+\n", + " f'\\nrecon max = {recon_from_beta_2.max().item()}'+\n", + " f'\\nrecon std = {recon_from_beta_2.std().item()}'\n", + ")\n", + "\n", + "fig, axs = plot.subplots(nrows=2, ncols=3)\n", + "\n", + "ax = pf.clear_axis(axs[0,0])\n", + "ax.imshow(plot_preproc_image, vmin=0, vmax=1)\n", + "ax.format(title='original')\n", + "\n", + "ax = pf.clear_axis(axs[0,1])\n", + "ax.imshow(plot_func(recon_from_alpha_1), vmin=0, vmax=1)\n", + "ax.format(title='recon from alpha 1')\n", + "\n", + "ax = pf.clear_axis(axs[0,2])\n", + "ax.imshow(plot_func(recon_from_beta_1), vmin=0, vmax=1)\n", + "ax.format(title='recon from beta 1')\n", + "\n", + "ax = pf.clear_axis(axs[1,0])\n", + "ax.imshow(plot_func(recon_from_alpha_2), vmin=0, vmax=1)\n", + "ax.format(title='recon from alpha 2')\n", + "\n", + "ax = pf.clear_axis(axs[1,1])\n", + "ax.imshow(plot_func(recon_from_beta_2), vmin=0, vmax=1)\n", + "ax.format(title='recon from beta 2')\n", + "\n", + "ax = pf.clear_axis(axs[1,2])\n", "plot.show()" ] }, @@ -1227,14 +1407,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "plot_recon = recon.squeeze().cpu().numpy().transpose(1,2,0)\n", - "plot_recon = (plot_recon - plot_recon.min()) / (plot_recon.max() - plot_recon.min())\n", - "fig, ax = plot.subplots(nrows=1, ncols=1)\n", - "ax = pf.clear_axis(ax)\n", - "ax.imshow(plot_recon)\n", - "plot.show()" - ] + "source": [] }, { "cell_type": "code", From 893f7ab130ec2e46648f66b73c9fb32958e6d5cd Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 10 Mar 2021 13:13:05 +0000 Subject: [PATCH 36/44] more appropriate name --- notebooks/{visualize_pooling_weights.ipynb => smt_analysis.ipynb} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename notebooks/{visualize_pooling_weights.ipynb => smt_analysis.ipynb} (100%) diff --git a/notebooks/visualize_pooling_weights.ipynb b/notebooks/smt_analysis.ipynb similarity index 100% rename from notebooks/visualize_pooling_weights.ipynb rename to notebooks/smt_analysis.ipynb From b9d0368ad884e2864caa4a9845b4301842ff642f Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 10 Mar 2021 13:15:28 +0000 Subject: [PATCH 37/44] temp fix for learning_rate checkpoint bug --- models/ensemble_model.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/models/ensemble_model.py b/models/ensemble_model.py index 51a6dc9b..d1b334fd 100644 --- a/models/ensemble_model.py +++ b/models/ensemble_model.py @@ -42,7 +42,14 @@ def setup_module(self, params): if module_state_dict_name in checkpoint.keys(): # It was already in an ensemble submodule.load_state_dict(checkpoint[module_state_dict_name]) else: # it was trained on its own - submodule.load_state_dict(checkpoint['model_state_dict']) + if 'model_state_dict' in checkpoint.keys(): + submodule.load_state_dict(checkpoint['model_state_dict']) + else: + assert False, ( + f'subparams {subparams} has checkpoint_boot_log set to ' + +f'{subparams.checkpoint_boot_log}, but that log does not have the ' + +f'appropriate key. The key "{module_state_dict_name}" must be in ' + +f'checkpoint.keys() = {checkpoint.keys}') def setup_optimizer(self): for module in self: @@ -56,6 +63,11 @@ def setup_optimizer(self): module.optimizer.load_state_dict(checkpoint[module_state_dict_name]) else: # it was trained on its own module.optimizer.load_state_dict(checkpoint['optimizer_state_dict'][0]) #TODO: For some reason this is a tuple of size 1 containing the dictionary. It should just be the dictionary + for group in module.optimizer.param_groups: # overwrite learning rates + group['lr'] = module.params.weight_lr + group['initial_lr'] = module.params.weight_lr + ## TODO: load scheduler state dict with checkpoint, set last_epoch correctly + ## https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html module.scheduler = torch.optim.lr_scheduler.MultiStepLR( module.optimizer, milestones=module.params.optimizer.milestones, @@ -86,6 +98,10 @@ def load_checkpoint(self, cp_file=None, load_optimizer=False): else: module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) _ = checkpoint.pop('optimizer_state_dict', None) + for group in module.optimizer.param_groups: # overwrite learning rates + group['lr'] = module.params.weight_lr + group['initial_lr'] = module.params.weight_lr + ## TODO: Load scheduler state dict as well _ = checkpoint.pop('model_state_dict', None) training_status = pprint.pformat(checkpoint, compact=True)#, sort_dicts=True #TODO: Python 3.8 adds the sort_dicts parameter out_str = f'Loaded checkpoint from {cp_file} with the following stats:\n{training_status}' From a4092109a0934ade4a58ff171fbe7e233db9b7a4 Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 10 Mar 2021 13:16:17 +0000 Subject: [PATCH 38/44] minor commenting & new train progress outputs --- models/base_model.py | 5 +++-- models/mlp_model.py | 2 +- models/pooling_model.py | 6 ++++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/models/base_model.py b/models/base_model.py index 07cb2e32..904cb7ea 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -169,6 +169,7 @@ def write_checkpoint(self, batch_step=None): output_dict['model_state_dict'] = self.state_dict() module_state_dict_name = 'optimizer_state_dict' output_dict[module_state_dict_name] = self.optimizer.state_dict(), + ## TODO: Save scheduler state dict as well training_stats = self.get_train_stats(batch_step) output_dict.update(training_stats) torch.save(output_dict, self.params.cp_latest_filename) @@ -215,8 +216,8 @@ def get_optimizer(self, optimizer_params, trainable_variables): def setup_optimizer(self): self.optimizer = self.get_optimizer( - optimizer_params=self.params, - trainable_variables=self.parameters()) + optimizer_params=self.params, + trainable_variables=self.parameters()) self.scheduler = torch.optim.lr_scheduler.MultiStepLR( self.optimizer, milestones=self.params.optimizer.milestones, diff --git a/models/mlp_model.py b/models/mlp_model.py index dc6b97f7..ed4ce8d3 100644 --- a/models/mlp_model.py +++ b/models/mlp_model.py @@ -29,7 +29,7 @@ def generate_update_dict(self, input_data, input_labels=None, batch_step=0, upda total_loss = self.loss_fn(pred, input_labels) pred = pred.max(1, keepdim=True)[1] correct = pred.eq(input_labels.view_as(pred)).sum().item() - stat_dict['weight_lr'] = self.scheduler.get_lr()[0] + stat_dict['weight_lr'] = self.scheduler.get_lr()[0] # one LR for all parameters stat_dict['loss'] = total_loss.item() stat_dict['train_accuracy'] = 100. * correct / self.params.batch_size update_dict.update(stat_dict) diff --git a/models/pooling_model.py b/models/pooling_model.py index 093d0b1f..55763552 100644 --- a/models/pooling_model.py +++ b/models/pooling_model.py @@ -33,6 +33,12 @@ def generate_update_dict(self, input_data, input_labels=None, batch_step=0, upda update_dict = super(PoolinModel, self).generate_update_dict(input_data, input_labels, batch_step) stat_dict = dict() rep = self.forward(input_data) + def count_nonzero(array, dim): + # TODO: github issue 23907 requests torch.count_nonzero, integrated in torch 1.7 + return torch.sum(array !=0, dim=dim, dtype=torch.float) + rep_dims = tuple([i for i in range(len(rep.shape))]) + rep_nnz = count_nonzero(rep, dim=rep_dims).item() + stat_dict['fraction_active_all_latents'] = rep_nnz / rep.numel() total_loss = self.loss_fn(rep) stat_dict['weight_lr'] = self.scheduler.get_lr()[0] stat_dict['loss'] = total_loss.item() From e4564a63825739ba0597901f7e6b95fa9a3a2a99 Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 10 Mar 2021 13:16:53 +0000 Subject: [PATCH 39/44] adds function to read architecture information --- utils/file_utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/utils/file_utils.py b/utils/file_utils.py index 4a587aba..2ca41dc9 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -56,7 +56,7 @@ def log_params(self, params): out_params = deepcopy(params) if('ensemble_params' in out_params.keys()): for sub_idx, sub_params in enumerate(out_params['ensemble_params']): - sub_params.set_params() + #sub_params.set_params() for key, value in sub_params.__dict__.items(): if(key != 'rand_state'): new_dict_key = f'{sub_idx}_{key}' @@ -179,6 +179,18 @@ def read_stats(self, text): stats[key] = [js_match[key]] return stats + def read_architecture(self, text): + """ + Generate dictionary of lists that contain stats from log text + Outpus: + stats: [dict] containing run statistics + Inputs: + text: [str] containing text to parse, can be obtained by calling load_file() + """ + tokens = ['', ''] + js_match = self.read_js(tokens, text) + return js_match + def __del__(self): if(self.log_to_file and hasattr(self, 'file_obj')): self.file_obj.close() From ceb357705f92838da3e1c5d81b158a7c9fac59c2 Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 10 Mar 2021 13:17:12 +0000 Subject: [PATCH 40/44] adds funciton to compute deconvolutional output shape --- utils/run_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/utils/run_utils.py b/utils/run_utils.py index a9d83756..c204ae66 100644 --- a/utils/run_utils.py +++ b/utils/run_utils.py @@ -9,6 +9,11 @@ def compute_conv_output_shape(in_length, kernel_length, stride, padding=0, dilat return np.floor(out_shape).astype(np.int) +def compute_deconv_output_shape(in_length, kernel_length, stride, padding=0, output_padding=0, dilation=1): + out_shape = (in_length - 1) * stride - 2 * padding + dilation * (kernel_length - 1) + output_padding + 1 + return np.floor(out_shape).astype(np.int) + + def get_module_encodings(module, data, allow_grads=False): if allow_grads: return module.get_encodings(data) From 13423693e74f7d6bcf726b8244f6174580195235 Mon Sep 17 00:00:00 2001 From: Dylan Date: Wed, 10 Mar 2021 13:18:16 +0000 Subject: [PATCH 41/44] integrated hierarchical params into one file --- params/lca_pool_cifar10_params.py | 98 ---------- params/lca_pool_lca_cifar10_params.py | 132 -------------- params/lca_pool_lca_pool_cifar10_params.py | 170 ------------------ ...ifar10_params.py => smt_cifar10_params.py} | 133 ++++++++++---- 4 files changed, 100 insertions(+), 433 deletions(-) delete mode 100644 params/lca_pool_cifar10_params.py delete mode 100644 params/lca_pool_lca_cifar10_params.py delete mode 100644 params/lca_pool_lca_pool_cifar10_params.py rename params/{lca_pool_lca_pool_mlp_cifar10_params.py => smt_cifar10_params.py} (60%) diff --git a/params/lca_pool_cifar10_params.py b/params/lca_pool_cifar10_params.py deleted file mode 100644 index 740ff86d..00000000 --- a/params/lca_pool_cifar10_params.py +++ /dev/null @@ -1,98 +0,0 @@ -import os -import types -import numpy as np -import torch - -from DeepSparseCoding.params.base_params import BaseParams -from DeepSparseCoding.params.lca_cifar10_params import params as LcaParams -from DeepSparseCoding.utils.run_utils import compute_conv_output_shape - - -class shared_params(object): - def __init__(self): - self.model_type = 'ensemble' - self.model_name = 'lca_pool_cifar10' - self.version = '0' - self.dataset = 'cifar10' - self.standardize_data = True - self.batch_size = 25 - self.num_epochs = 10 - self.train_logs_per_epoch = 4 - self.allow_parent_grads = False - - -class lca_params(LcaParams): - def set_params(self): - super(lca_params, self).set_params() - for key, value in shared_params().__dict__.items(): setattr(self, key, value) - self.model_type = 'lca' - self.layer_name = 'lca_1' - self.layer_types = ['conv'] - self.weight_decay = 0.0 - self.weight_lr = 0.001 - self.renormalize_weights = True - self.layer_channels = [3, 128] - self.kernel_size = 8 - self.stride = 2 - self.padding = 0 - self.optimizer = types.SimpleNamespace() - self.optimizer.name = 'sgd' - self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs - self.optimizer.lr_decay_rate = 0.8 - self.dt = 0.001 - self.tau = 0.1#0.2 - self.num_steps = 37#75 - self.rectify_a = True - self.thresh_type = 'hard' - self.sparse_mult = 0.35#0.30 - self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_cifar10/logfiles/lca_cifar10_v0.log' - self.compute_helper_params() - - -class pooling_params(BaseParams): - def set_params(self): - super(pooling_params, self).set_params() - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) - self.model_type = 'pooling' - self.layer_name = 'pool_1' - self.weight_lr = 1e-3 - self.layer_types = ['conv'] - self.layer_channels = [128, 32] - self.pool_ksize = 2 - self.pool_stride = 2 # non-overlapping - self.optimizer = types.SimpleNamespace() - self.optimizer.name = 'sgd' - self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs - self.optimizer.lr_decay_rate = 0.8 - self.compute_helper_params() - - def compute_helper_params(self): - super(pooling_params, self).compute_helper_params() - self.optimizer.milestones = [frac * self.num_epochs - for frac in self.optimizer.lr_annealing_milestone_frac] - - -class params(BaseParams): - def set_params(self): - super(params, self).set_params() - lca_params_inst = lca_params() - pooling_params_inst = pooling_params() - if(pooling_params_inst.layer_types[0] == 'fc' and lca_params_inst.layer_types[0] == 'conv'): - lca_output_height = compute_conv_output_shape( - 32, - lca_params_inst.kernel_size, - lca_params_inst.stride, - lca_params_inst.padding, - dilation=1) - lca_output_width = compute_conv_output_shape( - 32, - lca_params_inst.kernel_size, - lca_params_inst.stride, - lca_params_inst.padding, - dilation=1) - lca_output_shape = [lca_params_inst.layer_channels[1], lca_output_height, lca_output_width] - pooling_params_inst.layer_channels[0] = np.prod(lca_output_shape) - self.ensemble_params = [lca_params_inst, pooling_params_inst] - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) diff --git a/params/lca_pool_lca_cifar10_params.py b/params/lca_pool_lca_cifar10_params.py deleted file mode 100644 index 5d6acfe2..00000000 --- a/params/lca_pool_lca_cifar10_params.py +++ /dev/null @@ -1,132 +0,0 @@ -import os -import types - -import numpy as np -import torch - -from DeepSparseCoding.params.base_params import BaseParams -from DeepSparseCoding.params.lca_cifar10_params import params as LcaParams -from DeepSparseCoding.utils.run_utils import compute_conv_output_shape - - -class shared_params(object): - def __init__(self): - self.model_type = 'ensemble' - self.model_name = 'lca_pool_lca_cifar10' - self.version = '0' - self.dataset = 'cifar10' - self.standardize_data = True - self.batch_size = 25 - self.num_epochs = 250 - self.train_logs_per_epoch = 4 - self.allow_parent_grads = False - - -class lca_1_params(LcaParams): - def set_params(self): - super(lca_1_params, self).set_params() - for key, value in shared_params().__dict__.items(): setattr(self, key, value) - self.model_type = 'lca' - self.layer_name = 'lca_1' - self.layer_types = ['conv'] - self.weight_decay = 0.0 - self.weight_lr = 0.001 - self.renormalize_weights = True - self.layer_channels = [3, 128] - self.kernel_size = 8 - self.stride = 2 - self.padding = 0 - self.optimizer = types.SimpleNamespace() - self.optimizer.name = 'sgd' - self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs - self.optimizer.lr_decay_rate = 0.8 - self.dt = 0.001 - self.tau = 0.1#0.2 - self.num_steps = 37#75 - self.rectify_a = True - self.thresh_type = 'hard' - self.sparse_mult = 0.35#0.30 - self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_cifar10/logfiles/lca_pool_cifar10_v0.log' - self.compute_helper_params() - - -class pooling_1_params(BaseParams): - def set_params(self): - super(pooling_1_params, self).set_params() - for key, value in shared_params().__dict__.items(): setattr(self, key, value) - self.model_type = 'pooling' - self.layer_name = 'pool_1' - self.weight_lr = 1e-3 - self.layer_types = ['conv'] - self.layer_channels = [128, 32] - self.pool_ksize = 2 - self.pool_stride = 2 # non-overlapping - self.optimizer = types.SimpleNamespace() - self.optimizer.name = 'sgd' - self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs - self.optimizer.lr_decay_rate = 0.8 - self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_cifar10/logfiles/lca_pool_cifar10_v0.log' - self.compute_helper_params() - - def compute_helper_params(self): - super(pooling_1_params, self).compute_helper_params() - self.optimizer.milestones = [frac * self.num_epochs - for frac in self.optimizer.lr_annealing_milestone_frac] - - -class lca_2_params(LcaParams): - def set_params(self): - super(lca_2_params, self).set_params() - for key, value in shared_params().__dict__.items(): setattr(self, key, value) - for key, value in lca_1_params().__dict__.items(): setattr(self, key, value) - self.layer_name = 'lca_2' - self.layer_channels = [32, 256] - self.kernel_size = 6 - self.stride = 1 - self.padding = 0 - self.sparse_mult = 0.15 - self.checkpoint_boot_log = '' - self.compute_helper_params() - - -class params(BaseParams): - def set_params(self): - super(params, self).set_params() - lca_1_params_inst = lca_1_params() - pooling_params_inst = pooling_1_params() - lca_2_params_inst = lca_2_params() - lca_1_output_height = compute_conv_output_shape( - 32, - lca_1_params_inst.kernel_size, - lca_1_params_inst.stride, - lca_1_params_inst.padding, - dilation=1) - lca_1_output_width = compute_conv_output_shape( - 32, - lca_1_params_inst.kernel_size, - lca_1_params_inst.stride, - lca_1_params_inst.padding, - dilation=1) - pooling_output_height = compute_conv_output_shape( - lca_1_output_height, - pooling_params_inst.pool_ksize, - pooling_params_inst.pool_stride, - padding=0, - dilation=1) - pooling_output_width = compute_conv_output_shape( - lca_1_output_width, - pooling_params_inst.pool_ksize, - pooling_params_inst.pool_stride, - padding=0, - dilation=1) - lca_2_params_inst.data_shape = [ - int(pooling_params_inst.layer_channels[-1]), - int(pooling_output_height), - int(pooling_output_width)] - self.ensemble_params = [ - lca_1_params_inst, - pooling_params_inst, - lca_2_params_inst - ] - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) diff --git a/params/lca_pool_lca_pool_cifar10_params.py b/params/lca_pool_lca_pool_cifar10_params.py deleted file mode 100644 index 2ccce21d..00000000 --- a/params/lca_pool_lca_pool_cifar10_params.py +++ /dev/null @@ -1,170 +0,0 @@ -import os -import types - -import numpy as np -import torch - -from DeepSparseCoding.params.base_params import BaseParams -from DeepSparseCoding.params.lca_cifar10_params import params as LcaParams -from DeepSparseCoding.utils.run_utils import compute_conv_output_shape - - -class shared_params(object): - def __init__(self): - self.model_type = 'ensemble' - self.model_name = 'lca_pool_lca_pool_cifar10' - self.version = '0' - self.dataset = 'cifar10' - self.standardize_data = True - self.batch_size = 25 - self.num_epochs = 150 - self.train_logs_per_epoch = 4 - self.allow_parent_grads = False - - -class lca_1_params(LcaParams): - def set_params(self): - super(lca_1_params, self).set_params() - for key, value in shared_params().__dict__.items(): setattr(self, key, value) - self.model_type = 'lca' - self.layer_name = 'lca_1' - self.layer_types = ['conv'] - self.weight_decay = 0.0 - self.weight_lr = 0.001 - self.renormalize_weights = True - self.layer_channels = [3, 128] - self.kernel_size = 8 - self.stride = 2 - self.padding = 0 - self.optimizer = types.SimpleNamespace() - self.optimizer.name = 'sgd' - self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs - self.optimizer.lr_decay_rate = 0.8 - self.dt = 0.001 - self.tau = 0.1#0.2 - self.num_steps = 37#75 - self.rectify_a = True - self.thresh_type = 'hard' - self.sparse_mult = 0.35#0.30 - self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_lca_cifar10/logfiles/lca_pool_lca_cifar10_v0.log' - self.compute_helper_params() - - -class pooling_1_params(BaseParams): - def set_params(self): - super(pooling_1_params, self).set_params() - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) - self.model_type = 'pooling' - self.layer_name = 'pool_1' - self.weight_lr = 1e-3 - self.layer_types = ['conv'] - self.layer_channels = [128, 32] - self.pool_ksize = 2 - self.pool_stride = 2 # non-overlapping - self.optimizer = types.SimpleNamespace() - self.optimizer.name = 'sgd' - self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs - self.optimizer.lr_decay_rate = 0.8 - self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_lca_cifar10/logfiles/lca_pool_lca_cifar10_v0.log' - self.compute_helper_params() - - def compute_helper_params(self): - super(pooling_1_params, self).compute_helper_params() - self.optimizer.milestones = [frac * self.num_epochs - for frac in self.optimizer.lr_annealing_milestone_frac] - - -class lca_2_params(LcaParams): - def set_params(self): - super(lca_2_params, self).set_params() - for key, value in shared_params().__dict__.items(): setattr(self, key, value) - for key, value in lca_1_params().__dict__.items(): setattr(self, key, value) - self.layer_name = 'lca_2' - self.layer_channels = [32, 256] - self.kernel_size = 6 - self.stride = 1 - self.padding = 0 - self.sparse_mult = 0.15 - self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_lca_cifar10/logfiles/lca_pool_lca_cifar10_v0.log' - self.compute_helper_params() - -class pooling_2_params(BaseParams): - def set_params(self): - super(pooling_2_params, self).set_params() - for key, value in shared_params().__dict__.items(): setattr(self, key, value) - for key, value in pooling_1_params().__dict__.items(): setattr(self, key, value) - self.layer_name = 'pool_2' - self.weight_lr = 1e-3 - self.layer_types = ['fc'] - self.layer_channels = [None, 64] - self.optimizer = types.SimpleNamespace() - self.optimizer.name = 'sgd' - self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs - self.optimizer.lr_decay_rate = 0.8 - self.checkpoint_boot_log = '' - self.compute_helper_params() - - def compute_helper_params(self): - super(pooling_2_params, self).compute_helper_params() - self.optimizer.milestones = [frac * self.num_epochs - for frac in self.optimizer.lr_annealing_milestone_frac] - - -class params(BaseParams): - def set_params(self): - super(params, self).set_params() - lca_1_params_inst = lca_1_params() - pooling_1_params_inst = pooling_1_params() - lca_2_params_inst = lca_2_params() - pooling_2_params_inst = pooling_2_params() - lca_1_output_height = compute_conv_output_shape( - 32, - lca_1_params_inst.kernel_size, - lca_1_params_inst.stride, - lca_1_params_inst.padding, - dilation=1) - lca_1_output_width = compute_conv_output_shape( - 32, - lca_1_params_inst.kernel_size, - lca_1_params_inst.stride, - lca_1_params_inst.padding, - dilation=1) - pooling_1_output_height = compute_conv_output_shape( - lca_1_output_height, - pooling_1_params_inst.pool_ksize, - pooling_1_params_inst.pool_stride, - padding=0, - dilation=1) - pooling_1_output_width = compute_conv_output_shape( - lca_1_output_width, - pooling_1_params_inst.pool_ksize, - pooling_1_params_inst.pool_stride, - padding=0, - dilation=1) - lca_2_params_inst.data_shape = [ - int(pooling_1_params_inst.layer_channels[-1]), - int(pooling_1_output_height), - int(pooling_1_output_width)] - lca_2_output_height = compute_conv_output_shape( - pooling_1_output_height, - lca_2_params_inst.kernel_size, - lca_2_params_inst.stride, - lca_2_params_inst.padding, - dilation=1) - lca_2_output_width = compute_conv_output_shape( - pooling_1_output_width, - lca_2_params_inst.kernel_size, - lca_2_params_inst.stride, - lca_2_params_inst.padding, - dilation=1) - lca_2_flat_dim = lca_2_params_inst.layer_channels[1]*lca_2_output_height*lca_2_output_width - pooling_2_params_inst.layer_channels[0] = lca_2_flat_dim - self.ensemble_params = [ - lca_1_params_inst, - pooling_1_params_inst, - lca_2_params_inst, - pooling_2_params_inst - ] - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) diff --git a/params/lca_pool_lca_pool_mlp_cifar10_params.py b/params/smt_cifar10_params.py similarity index 60% rename from params/lca_pool_lca_pool_mlp_cifar10_params.py rename to params/smt_cifar10_params.py index 45ae965a..59d77030 100644 --- a/params/lca_pool_lca_pool_mlp_cifar10_params.py +++ b/params/smt_cifar10_params.py @@ -13,12 +13,15 @@ class shared_params(object): def __init__(self): self.model_type = 'ensemble' - self.model_name = 'lca_pool_lca_pool_mlp_cifar10' - self.version = '0' + self.model_name = 'test_smt_cifar10' + #self.version = 'lplpm' + self.version = '2lp' self.dataset = 'cifar10' self.standardize_data = True - self.batch_size = 25 - self.num_epochs = 150 + self.rescale_data_to_one = False + self.center_dataset = False + self.batch_size = 30 + self.num_epochs = 200 self.train_logs_per_epoch = 4 self.allow_parent_grads = False @@ -31,23 +34,31 @@ def set_params(self): self.layer_name = 'lca_1' self.layer_types = ['conv'] self.weight_decay = 0.0 - self.weight_lr = 0#1e-3 + #self.weight_lr = 1e-3 + self.weight_lr = 0.0 # For next layer training self.renormalize_weights = True - self.layer_channels = [3, 128] + #self.layer_channels = [3, 128] + self.layer_channels = [3, 256] self.kernel_size = 8 - self.stride = 2 + #self.stride = 2 + self.stride = 1 self.padding = 0 self.optimizer = types.SimpleNamespace() self.optimizer.name = 'sgd' - self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.8 self.dt = 0.001 - self.tau = 0.2 + #self.tau = 0.2#0.10 + self.tau = 0.25 + #self.num_steps = 75#37 self.num_steps = 75 self.rectify_a = True self.thresh_type = 'hard' - self.sparse_mult = 0.35#0.30 - self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_lca_pool_cifar10/logfiles/lca_pool_lca_pool_cifar10_v0.log' + #self.sparse_mult = 0.35 + self.sparse_mult = 0.28 + #self.checkpoint_boot_log = '' + #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_vlplp.log' + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_v2l.log' self.compute_helper_params() @@ -58,16 +69,22 @@ def set_params(self): setattr(self, key, value) self.model_type = 'pooling' self.layer_name = 'pool_1' - self.weight_lr = 0#1e-3 self.layer_types = ['conv'] - self.layer_channels = [128, 32] - self.pool_ksize = 2 - self.pool_stride = 2 # non-overlapping + self.weight_lr = 1e-3 + #self.weight_lr = 0.0 # For next layer training + #self.layer_channels = [128, 32] + self.layer_channels = [256, 32] + #self.pool_ksize = 2 + self.pool_ksize = 4 + self.pool_stride = 2 + self.renormalize_weights = True self.optimizer = types.SimpleNamespace() self.optimizer.name = 'sgd' self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.8 - self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_lca_pool_cifar10/logfiles/lca_pool_lca_pool_cifar10_v0.log' + self.checkpoint_boot_log = '' + #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_vlplp.log' + #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_v2lp.log' self.compute_helper_params() def compute_helper_params(self): @@ -82,13 +99,18 @@ def set_params(self): for key, value in shared_params().__dict__.items(): setattr(self, key, value) for key, value in lca_1_params().__dict__.items(): setattr(self, key, value) self.layer_name = 'lca_2' - self.weight_lr = 0#1e-3 - self.layer_channels = [32, 256] - self.kernel_size = 6 + self.weight_lr = 1e-3 + #self.weight_lr = 0.0 # For next layer training + #self.layer_channels = [32, 256] + self.layer_channels = [32, 512] + #self.kernel_size = 6 + self.kernel_size = 8 self.stride = 1 self.padding = 0 - self.sparse_mult = 0.20 - self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_lca_pool_cifar10/logfiles/lca_pool_lca_pool_cifar10_v0.log' + self.sparse_mult = 0.15 + self.tau = 0.20 + self.checkpoint_boot_log = '' + #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_vlplp.log' self.compute_helper_params() class pooling_2_params(BaseParams): @@ -97,14 +119,20 @@ def set_params(self): for key, value in shared_params().__dict__.items(): setattr(self, key, value) for key, value in pooling_1_params().__dict__.items(): setattr(self, key, value) self.layer_name = 'pool_2' - self.weight_lr = 0#1e-3 - self.layer_types = ['fc'] - self.layer_channels = [None, 64] + self.weight_lr = 1e-3 + #self.weight_lr = 0.0 # For next layer training + #self.layer_types = ['fc'] + self.layer_types = ['conv'] + #self.layer_channels = [None, 64] + self.layer_channels = [512, 150] + self.pool_ksize = 4 + self.pool_stride = 1 self.optimizer = types.SimpleNamespace() self.optimizer.name = 'sgd' self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.8 - self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/lca_pool_lca_pool_cifar10/logfiles/lca_pool_lca_pool_cifar10_v0.log' + self.checkpoint_boot_log = '' + #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_vlplp.log' self.compute_helper_params() def compute_helper_params(self): @@ -119,10 +147,11 @@ def set_params(self): for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'mlp' self.layer_name = 'classifier' - self.weight_lr = 2e-3 + self.weight_lr = 1e-2 self.weight_decay = 1e-6 self.layer_types = ['fc'] - self.layer_channels = [64, 10] + #self.layer_channels = [64, 10] + self.layer_channels = [150, 10] self.activation_functions = ['identity'] self.dropout_rate = [0.0] # probability of value being set to zero self.optimizer = types.SimpleNamespace() @@ -140,18 +169,24 @@ def set_params(self): lca_2_params_inst = lca_2_params() pooling_2_params_inst = pooling_2_params() mlp_params_inst = mlp_params() + data_shape = [3, 32, 32] lca_1_output_height = compute_conv_output_shape( - 32, + data_shape[1], lca_1_params_inst.kernel_size, lca_1_params_inst.stride, lca_1_params_inst.padding, dilation=1) lca_1_output_width = compute_conv_output_shape( - 32, + data_shape[2], lca_1_params_inst.kernel_size, lca_1_params_inst.stride, lca_1_params_inst.padding, dilation=1) + lca_1_shape = [ + lca_1_params_inst.layer_channels[-1], + lca_1_output_height, + lca_1_output_width + ] pooling_1_output_height = compute_conv_output_shape( lca_1_output_height, pooling_1_params_inst.pool_ksize, @@ -164,6 +199,11 @@ def set_params(self): pooling_1_params_inst.pool_stride, padding=0, dilation=1) + pooling_1_shape = [ + pooling_1_params_inst.layer_channels[-1], + pooling_1_output_height, + pooling_1_output_width + ] lca_2_params_inst.data_shape = [ int(pooling_1_params_inst.layer_channels[-1]), int(pooling_1_output_height), @@ -180,14 +220,41 @@ def set_params(self): lca_2_params_inst.stride, lca_2_params_inst.padding, dilation=1) - lca_2_flat_dim = lca_2_params_inst.layer_channels[1]*lca_2_output_height*lca_2_output_width + lca_2_shape = [ + lca_2_params_inst.layer_channels[-1], + lca_2_output_height, + lca_2_output_width + ] + lca_2_flat_dim = int(np.prod(lca_2_shape)) pooling_2_params_inst.layer_channels[0] = lca_2_flat_dim + pooling_2_output_height = compute_conv_output_shape( + lca_2_output_height, + pooling_2_params_inst.pool_ksize, + pooling_2_params_inst.pool_stride, + padding=0, + dilation=1) + pooling_2_output_width = compute_conv_output_shape( + lca_2_output_width, + pooling_2_params_inst.pool_ksize, + pooling_2_params_inst.pool_stride, + padding=0, + dilation=1) + pooling_2_shape = [ + pooling_2_params_inst.layer_channels[-1], + pooling_2_output_height, + pooling_2_output_width + ] + l1_overcompleteness = np.prod(lca_1_shape) / np.prod(data_shape) + p1_overcompleteness = np.prod(pooling_1_shape) / np.prod(lca_1_shape) + l2_overcompleteness = np.prod(lca_2_shape) / np.prod(pooling_1_shape) + p2_overcompleteness = np.prod(pooling_2_shape) / np.prod(lca_2_shape) + import IPython; IPython.embed(); raise SystemExit self.ensemble_params = [ lca_1_params_inst, pooling_1_params_inst, - lca_2_params_inst, - pooling_2_params_inst, - mlp_params_inst + #lca_2_params_inst, + #pooling_2_params_inst, + #mlp_params_inst ] for key, value in shared_params().__dict__.items(): setattr(self, key, value) From ff9e4537fb0929a433c43bb71e047e8289c928ea Mon Sep 17 00:00:00 2001 From: Dylan Date: Thu, 2 Dec 2021 17:36:53 -0700 Subject: [PATCH 42/44] updates so that tests pass with latest pytorch --- models/lca_model.py | 2 +- models/mlp_model.py | 2 +- models/pooling_model.py | 2 +- params/smt_cifar10_params.py | 1 - tests/test_data_processing.py | 2 +- tests/test_foolbox.py | 8 ++++---- tests/test_models.py | 1 - utils/run_utils.py | 4 ++-- 8 files changed, 10 insertions(+), 12 deletions(-) diff --git a/models/lca_model.py b/models/lca_model.py index 38444b6a..ad78e3a8 100644 --- a/models/lca_model.py +++ b/models/lca_model.py @@ -33,7 +33,7 @@ def generate_update_dict(self, input_data, input_labels=None, batch_step=0, upda recon = self.get_recon_from_latents(latents) recon_loss = losses.half_squared_l2(input_data, recon).item() sparse_loss = self.params.sparse_mult * losses.l1_norm(latents).item() - stat_dict['weight_lr'] = self.scheduler.get_lr()[0] + stat_dict['weight_lr'] = self.scheduler.get_last_lr()[0] stat_dict['loss_recon'] = recon_loss stat_dict['loss_sparse'] = sparse_loss stat_dict['loss_total'] = recon_loss + sparse_loss diff --git a/models/mlp_model.py b/models/mlp_model.py index ed4ce8d3..28d0f261 100644 --- a/models/mlp_model.py +++ b/models/mlp_model.py @@ -29,7 +29,7 @@ def generate_update_dict(self, input_data, input_labels=None, batch_step=0, upda total_loss = self.loss_fn(pred, input_labels) pred = pred.max(1, keepdim=True)[1] correct = pred.eq(input_labels.view_as(pred)).sum().item() - stat_dict['weight_lr'] = self.scheduler.get_lr()[0] # one LR for all parameters + stat_dict['weight_lr'] = self.scheduler.get_last_lr()[0] # one LR for all parameters stat_dict['loss'] = total_loss.item() stat_dict['train_accuracy'] = 100. * correct / self.params.batch_size update_dict.update(stat_dict) diff --git a/models/pooling_model.py b/models/pooling_model.py index 55763552..3f5caf1a 100644 --- a/models/pooling_model.py +++ b/models/pooling_model.py @@ -40,7 +40,7 @@ def count_nonzero(array, dim): rep_nnz = count_nonzero(rep, dim=rep_dims).item() stat_dict['fraction_active_all_latents'] = rep_nnz / rep.numel() total_loss = self.loss_fn(rep) - stat_dict['weight_lr'] = self.scheduler.get_lr()[0] + stat_dict['weight_lr'] = self.scheduler.get_last_lr()[0] stat_dict['loss'] = total_loss.item() update_dict.update(stat_dict) return update_dict diff --git a/params/smt_cifar10_params.py b/params/smt_cifar10_params.py index 59d77030..940d4fcf 100644 --- a/params/smt_cifar10_params.py +++ b/params/smt_cifar10_params.py @@ -248,7 +248,6 @@ def set_params(self): p1_overcompleteness = np.prod(pooling_1_shape) / np.prod(lca_1_shape) l2_overcompleteness = np.prod(lca_2_shape) / np.prod(pooling_1_shape) p2_overcompleteness = np.prod(pooling_2_shape) / np.prod(lca_2_shape) - import IPython; IPython.embed(); raise SystemExit self.ensemble_params = [ lca_1_params_inst, pooling_1_params_inst, diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py index 42092eca..7a3f8af6 100644 --- a/tests/test_data_processing.py +++ b/tests/test_data_processing.py @@ -249,7 +249,7 @@ def test_patches(self): err = 1e-6 rand_mean = 0; rand_var = 1 num_im = 10; im_edge = 512; im_chan = 1; patch_edge = 16 - num_patches = np.int(num_im * (im_edge / patch_edge)**2) + num_patches = int(num_im * (im_edge / patch_edge)**2) rand_seed = 1234 rand_state = np.random.RandomState(rand_seed) data = np.stack([rand_state.normal(rand_mean, rand_var, size=[im_chan, im_edge, im_edge]) diff --git a/tests/test_foolbox.py b/tests/test_foolbox.py index 3faf224d..24ca9774 100644 --- a/tests/test_foolbox.py +++ b/tests/test_foolbox.py @@ -7,11 +7,11 @@ if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) #import numpy as np -import eagerpy as ep -from foolbox import PyTorchModel, accuracy, samples -import foolbox.attacks as fa +#import eagerpy as ep +#from foolbox import PyTorchModel, accuracy, samples +#import foolbox.attacks as fa -import DeepSparseCoding.utils.loaders as loaders +#import DeepSparseCoding.utils.loaders as loaders #import DeepSparseCoding.utils.dataset_utils as datasets #import DeepSparseCoding.utils.run_utils as run_utils diff --git a/tests/test_models.py b/tests/test_models.py index 2bcc3877..7fdd172a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -125,4 +125,3 @@ def test_lca_ensemble_gradients(self): +"ensemble weights are not different from init after one epoch of training") assert np.all(lca_w == ensemble_w), (err_msg+'\n' +"lca & ensemble weights are not equal after one epoch of training") - diff --git a/utils/run_utils.py b/utils/run_utils.py index c204ae66..20a1f7ea 100644 --- a/utils/run_utils.py +++ b/utils/run_utils.py @@ -6,12 +6,12 @@ def compute_conv_output_shape(in_length, kernel_length, stride, padding=0, dilation=1): out_shape = ((in_length + 2 * padding - dilation * (kernel_length - 1) - 1) / stride) + 1 - return np.floor(out_shape).astype(np.int) + return np.floor(out_shape).astype(int) def compute_deconv_output_shape(in_length, kernel_length, stride, padding=0, output_padding=0, dilation=1): out_shape = (in_length - 1) * stride - 2 * padding + dilation * (kernel_length - 1) + output_padding + 1 - return np.floor(out_shape).astype(np.int) + return np.floor(out_shape).astype(int) def get_module_encodings(module, data, allow_grads=False): From 4362b04668242c48482ce95c0f79dac3e86cd3c0 Mon Sep 17 00:00:00 2001 From: Dylan Date: Thu, 2 Dec 2021 17:43:34 -0700 Subject: [PATCH 43/44] moved tf requirements out of main list --- tf1x/additional_requirements.txt | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 tf1x/additional_requirements.txt diff --git a/tf1x/additional_requirements.txt b/tf1x/additional_requirements.txt new file mode 100644 index 00000000..8e5be756 --- /dev/null +++ b/tf1x/additional_requirements.txt @@ -0,0 +1,5 @@ +tensorflow-gpu==1.15.2 +tensorflow-estimator==1.15.1 +tensorboard==1.15 +tensorflow-probability==0.8.0 +tensorflow-compression \ No newline at end of file From a2dd266b2e672158e4e19875aa6e609104a0bfea Mon Sep 17 00:00:00 2001 From: Dylan Date: Thu, 2 Dec 2021 17:44:03 -0700 Subject: [PATCH 44/44] moved tf requirements out of main list --- requirements.txt | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/requirements.txt b/requirements.txt index f7e76024..02e1795b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,9 +11,4 @@ Pillow>=5.3.0 scikit-image>=0.14.1 scikit-learn>=0.20.0 scipy>=1.1.0 -seaborn>=0.9.0 -tensorflow-gpu==1.15.2 -tensorflow-estimator==1.15.1 -tensorboard==1.15 -tensorflow-probability==0.8.0 -tensorflow-compression +seaborn>=0.9.0 \ No newline at end of file