From cd1da1135d79b9fa4e327dbf24991e0697b9a101 Mon Sep 17 00:00:00 2001
From: Nan
Date: Mon, 7 Feb 2022 18:07:21 +0100
Subject: [PATCH 01/10] plot attn mats

---
 .vscode/launch.json | 141 +++++++++++++++++++++-----------------
 RIM.py              |  41 ++++++++++---
 argument_parser.py  |   3 +
 data/MovingMNIST.py |  13 +++-
 networks.py         |   4 +-
 train_mmnist.py     |  28 ++++++++-
 utils/visualize.py  |  38 ++++++++++++
 7 files changed, 179 insertions(+), 89 deletions(-)

diff --git a/.vscode/launch.json b/.vscode/launch.json
index 24dcc93..aa2e0c9 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -1,5 +1,67 @@
 {
     "configurations": [
+        {
+            "name": "minimal: only for debugging code",
+            "type": "python",
+            "request": "launch",
+            "program": "./train_mmnist.py",
+            "console": "integratedTerminal",
+            "args": [
+                // "--mini",
+                "--train_dataset",
+                "balls3curtain64.h5",
+                "--transfer_dataset",
+                "balls3curtain64.h5",
+                "--hidden_size",
+                "5",
+                "--should_save_csv",
+                "False",
+                "--lr",
+                "0.0007",
+                "--num_units",
+                "2",
+                "--k",
+                "1",
+                "--num_input_heads",
+                "1",
+                "--version",
+                "0",
+                "--rnn_cell",
+                "GRU",
+                "--input_key_size",
+                "3",
+                "--input_value_size",
+                "4",
+                "--input_query_size",
+                "3", // actually not used; always == input_key_size
+                "--input_dropout",
+                "0.1",
+                "--comm_key_size",
+                "3",
+                "--comm_value_size",
+                "4", // actually not used; always == hidden_size because the output is a hidden state
+                "--comm_query_size",
+                "3",
+                "--num_comm_heads",
+                "2",
+                "--comm_dropout",
+                "0.1",
+                "--experiment_name",
+                "MMnist",
+                "--batch_size",
+                "16",
+                "--epochs",
+                "10",
+                "--should_resume",
+                "False",
+                "--batch_frequency_to_log_heatmaps",
+                "1",
+                "--model_persist_frequency",
+                "1",
+                "--log_intm_frequency",
+                "1"
+            ]
+        },
         {
             "name": "train bball complete (cuda)",
             "type": "python",
@@ -44,7 +106,7 @@
                 "--comm_dropout",
                 "0.1",
                 "--experiment_name",
-                "Curtain",
+                "MMnist",
                 "--batch_size",
                 "64",
                 "--epochs",
@@ -54,6 +116,8 @@
                 "--batch_frequency_to_log_heatmaps",
                 "10",
                 "--model_persist_frequency",
+                "1",
+                "--log_intm_frequency",
                 "1"
             ]
         },
@@ -104,7 +168,7 @@
                 "--comm_dropout",
                 "0.1",
                 "--experiment_name",
-                "Curtain",
+                "MMnist",
                 "--batch_size",
                 "16",
                 "--epochs",
@@ -114,78 +178,9 @@
                 "--batch_frequency_to_log_heatmaps",
                 "1",
                 "--model_persist_frequency",
-                "1"
-            ]
-        },
-        {
-            "name": "main: RIM (cuda)",
-            "type": "python",
-            "request": "launch",
-            "program": "${file}",
-            "console": "integratedTerminal",
-            "args": [
-                "--cuda",
-                "True",
-                "--epochs",
-                "6",
-                "--batch_size",
-                "64",
-                "--input_size",
-                "1",
-                "--hidden_size",
-                "600",
-                "--size",
-                "14",
-                "--loadsaved",
-                "0",
-                "--model",
-                "RIM"
-            ]
-        },
-        {
-            "name": "main: LSTM",
-            "type": "python",
-            "request": "launch",
-            "program": "${file}",
-            "console": "integratedTerminal",
-            "args": [
-                "--cuda",
-                "False",
-                "--epochs",
-                "1",
-                "--batch_size",
-                "64",
-                "--input_size",
-                "1",
-                "--size",
-                "14",
-                "--loadsaved",
-                "0",
-                "--model",
-                "LSTM"
-            ]
-        },
-        {
-            "name": "main: RIM",
-            "type": "python",
-            "request": "launch",
-            "program": "${file}",
-            "console": "integratedTerminal",
-            "args": [
-                "--cuda",
-                "False",
-                "--epochs",
-                "7",
-                "--batch_size",
-                "64",
-                "--input_size",
-                "14",
-                "--size",
-                "14",
-                "--loadsaved",
-                "0",
-                "--model",
-                "RIM"
+                "1",
+                "--log_intm_frequency",
+                "1"
             ]
         }
     ]
diff --git a/RIM.py b/RIM.py
index 8727cc4..f696f91 100644
--- a/RIM.py
+++ b/RIM.py
@@ -222,7 +222,17 @@ def input_attention_mask(self, x, h):
         attention_probs = self.input_dropout(nn.Softmax(dim = -1)(attention_scores))
         inputs = torch.matmul(attention_probs, value_layer) *
mask_.unsqueeze(2) - return inputs, mask_ + """ + parameters: + key_mat: input_size x key_size -> input 2 key + value_mat: input_size x val_size -> input 2 value + query_mat: num_units x hidden_size x key_size -> hidden 2 query + """ + key_mat = get_mat(self.key).transpose(0,1).detach() # TODO detach? + value_mat = get_mat(self.value).transpose(0,1).detach() + query_mat = self.query.w.detach() + + return inputs, mask_, key_mat, value_mat, query_mat def communication_attention(self, h, mask): """ @@ -263,8 +273,18 @@ def communication_attention(self, h, mask): context_layer = context_layer.view(*new_context_layer_shape) context_layer = self.comm_attention_output(context_layer) context_layer = context_layer + h - - return context_layer + + """ + parameters: + key_mat: num_units x hidden_size x (key_size*num_heads) -> hidden 2 key + value_mat: num_units x hidden_size x (val_size*num_heads==hidden_size*num_heads) -> hidden 2 (value==new hidden_state) + query_mat: num_units x hidden_size x (key_size*num_heads) -> hidden 2 query + """ + key_mat = self.key_.w.detach() + value_mat = self.value_.w.detach() + query_mat = self.query_.w.detach() + + return context_layer, key_mat, value_mat, query_mat def nan_hook(self, out): nan_mask = torch.isnan(out) @@ -290,7 +310,7 @@ def forward(self, x, hs, cs = None): x = torch.cat((x.unsqueeze(1), null_input), dim = 1) # Compute input attention - inputs, mask = self.input_attention_mask(x, hs) + inputs, mask, *inp_ctx = self.input_attention_mask(x, hs) h_old = hs * 1.0 if cs is not None: c_old = cs * 1.0 @@ -309,15 +329,15 @@ def forward(self, x, hs, cs = None): h_new = blocked_grad.apply(hs, mask) # Compute communication attention - h_new = self.communication_attention(h_new, mask.squeeze(2)) + h_new, *comm_ctx = self.communication_attention(h_new, mask.squeeze(2)) self.nan_hook(h_new) hs = mask * h_new + (1 - mask) * h_old if cs is not None: cs = mask * cs + (1 - mask) * c_old - return hs, cs, None + return hs, cs, None, inp_ctx, comm_ctx self.nan_hook(hs) - return hs, None, None + return hs, None, None, inp_ctx, comm_ctx class RIM(nn.Module): @@ -625,4 +645,9 @@ def __init__(self, scale_factor, mode): def forward(self, x): x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False) - return x \ No newline at end of file + return x + +def get_mat(linear_layer): + '''get linear transformation matrix in a linear layer''' + p_list = [p for p in linear_layer.parameters()] + return p_list[0] diff --git a/argument_parser.py b/argument_parser.py index f71315a..5050c85 100644 --- a/argument_parser.py +++ b/argument_parser.py @@ -71,6 +71,9 @@ def argument_parser(): help='Number of training epochs after which model is to ' 'be persisted. 
-1 means that the model is not' 'persisted')
+    parser.add_argument('--log_intm_frequency', type=int, default=10,
+                        metavar='Frequency at which we log the intermediate variables of the model',
+                        help='Log the intermediate variables every this many epochs (a positive integer)')
     parser.add_argument('--batch_frequency_to_log_heatmaps', type=int, default=-1,
                         metavar='Frequency at which the heatmaps are persisted',
                         help='Number of training batches after which we will persist the '
diff --git a/data/MovingMNIST.py b/data/MovingMNIST.py
index b9edc08..9f197e3 100644
--- a/data/MovingMNIST.py
+++ b/data/MovingMNIST.py
@@ -35,11 +35,12 @@ class MovingMNIST(data.Dataset):
     training_file = 'moving_mnist_train.pt'
     test_file = 'moving_mnist_test.pt'

-    def __init__(self, root, train=True, split=1000, transform=None, download=False):
+    def __init__(self, root, train=True, split=1000, transform=None, download=False, mini=False):
         self.root = os.path.expanduser(root)
         self.transform = transform
         self.split = split
         self.train = train  # training set or test set
+        self.mini = mini

         if download:
             self.download()
@@ -85,9 +86,15 @@ def _transform_time(data):

     def __len__(self):
         if self.train:
-            return len(self.train_data)
+            if not self.mini:
+                return len(self.train_data)
+            else:
+                return len(self.train_data) // 10
         else:
-            return len(self.test_data)
+            if not self.mini:
+                return len(self.test_data)
+            else:
+                return len(self.test_data) // 10

     def _check_exists(self):
         return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
diff --git a/networks.py b/networks.py
index 214a447..849ffbd 100644
--- a/networks.py
+++ b/networks.py
@@ -248,14 +248,14 @@ def forward(self, x, h_prev):
         encoded_input = self.Encoder(x)
         # encoded_input = clamp(encoded_input)
         self.nan_hook(encoded_input)
-        h_new, foo, bar = self.rim_model(encoded_input, h_prev)
+        h_new, foo, bar, inp_ctx, comm_ctx = self.rim_model(encoded_input, h_prev)
         # h_new = clamp(h_new)
         self.nan_hook(h_new)

         dec_out_ = self.Decoder(h_new.view(h_new.shape[0],-1))
         # dec_out_ = clamp(dec_out_)
         self.nan_hook(dec_out_)

-        return dec_out_, h_new
+        return dec_out_, h_new, inp_ctx, comm_ctx

     def init_hidden(self, batch_size):
         # assert False, "don't call this"
diff --git a/train_mmnist.py b/train_mmnist.py
index 91d7d12..3c4a97f 100644
--- a/train_mmnist.py
+++ b/train_mmnist.py
@@ -10,8 +10,10 @@
 from logbook.logbook import LogBook
 from utils.util import set_seed, make_dir
 from utils.visualize import VectorLog
+from utils.visualize import HeatmapLog
 from data.MovingMNIST import MovingMNIST
 from box import Box
+from tqdm import tqdm

 import os
 from os import listdir
@@ -36,11 +38,17 @@ def get_grad_norm(model):

 def train(model, train_loader, optimizer, epoch, logbook, train_batch_idx, args):
     grad_norm_log = VectorLog(args.folder_log, "grad_norm")
+    inp_key_log = HeatmapLog(args.folder_log, "inp key mat")
+    inp_value_log = HeatmapLog(args.folder_log, "inp value mat")
+    inp_query_log = HeatmapLog(args.folder_log, "inp query mat")
+    comm_key_log = HeatmapLog(args.folder_log, "comm key mat")
+    comm_value_log = HeatmapLog(args.folder_log, "comm value mat")
+    comm_query_log = HeatmapLog(args.folder_log, "comm query mat")

     model.train()

     epoch_loss = torch.tensor(0.).to(args.device)
-    for batch_idx, data in enumerate(train_loader):
+    for batch_idx, data in enumerate(tqdm(train_loader)):
         hidden = model.init_hidden(data.shape[0]).to(args.device)

         start_time = time()
@@ -52,7 +60,7 @@ def train(model, train_loader, optimizer, epoch, logbook, train_batch_idx, args)
         # with autograd.detect_anomaly():
         if True:
             for
frame in range(data.shape[1]-1): - output, hidden = model(data[:, frame, :, :, :], hidden) + output, hidden, inp_ctx, comm_ctx = model(data[:, frame, :, :, :], hidden) nan_hook(output) nan_hook(hidden) @@ -76,6 +84,20 @@ def train(model, train_loader, optimizer, epoch, logbook, train_batch_idx, args) } logbook.write_metric_logs(metrics=metrics) + # if True: + if args.log_intm_frequency > 0 and epoch % args.log_intm_frequency == 0: + """log intermediate variables here""" + pass + # TODO plot inp_ctx + inp_key_log.plot(inp_ctx[0], epoch) # happens in the first batch + inp_value_log.plot(inp_ctx[1], epoch) + inp_query_log.plot(inp_ctx[2], epoch) + pass + # TODO plot comm_ctx + comm_key_log.plot(comm_ctx[0], epoch) + comm_value_log.plot(comm_ctx[1], epoch) + comm_query_log.plot(comm_ctx[2], epoch) + epoch_loss += loss.detach() epoch_loss = epoch_loss / (batch_idx+1) @@ -100,7 +122,7 @@ def main(): model, optimizer, start_epoch, train_batch_idx = setup_model(args=args, logbook=logbook) - train_set = MovingMNIST(root='./data', train=True, download=True) + train_set = MovingMNIST(root='./data', train=True, download=True, mini=True) test_set = MovingMNIST(root='./data', train=False, download=True) train_loader = torch.utils.data.DataLoader( diff --git a/utils/visualize.py b/utils/visualize.py index 8c48230..0a00b5e 100644 --- a/utils/visualize.py +++ b/utils/visualize.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import torch +from .util import make_dir def plot_frames(batch_of_pred, batch_of_target, start_frame, end_frame, batch_idx): ''' @@ -27,6 +28,43 @@ def plot_curve(loss): axs.plot(loss) plt.savefig(f"loss_curve.png",dpi=120) +def plot_mat(mat, mat_name, epoch): + if mat.dim() == 3: + mat_list = [mat[idx_unit,:,:].squeeze().cpu() for idx_unit in range(mat.shape[0])] + else: + mat_list = [mat.cpu()] + fig, axs = plt.subplots(1, len(mat_list), figsize=(2*len(mat_list), 2)) + for idx_mat, mat in enumerate(mat_list): + if len(mat_list) == 1: + axs.imshow(mat, cmap='hot', interpolation='nearest') + else: + axs[idx_mat].imshow(mat, cmap='hot', interpolation='nearest') + fig.suptitle(mat_name.replace("_", " ")+ f' in epoch [{epoch}]') + plt.savefig(mat_name + f'_epoch_{epoch}.png', dpi=120) + plt.close() + +class HeatmapLog: + def __init__(self, folder_log, mat_name): + make_dir(f"{folder_log}/intermediate_vars/"+mat_name) + self.save_folder = f"{folder_log}/intermediate_vars/"+mat_name + self.mat_name = mat_name + + def plot(self, mat, epoch): + if mat.dim() == 3: + mat_list = [mat[idx_unit,:,:].squeeze().cpu() for idx_unit in range(mat.shape[0])] + else: + mat_list = [mat.cpu()] + fig, axs = plt.subplots(1, len(mat_list), figsize=(2*len(mat_list), 2)) + for idx_mat, mat in enumerate(mat_list): + if len(mat_list) == 1: + axs.imshow(mat, cmap='hot', interpolation='nearest') + else: + axs[idx_mat].imshow(mat, cmap='hot', interpolation='nearest') + fig.suptitle(self.mat_name.replace("_", " ")+ f' in epoch [{epoch}]') + plt.savefig(self.save_folder + '/' + self.mat_name.replace(' ','_') + f'_epoch_{epoch}.png', dpi=120) + plt.close() + + class VectorLog: def __init__(self, save_path, var_name): self.save_path = save_path From 98bbdf1d4eb402c889d4a13b0c19e3243d92beae Mon Sep 17 00:00:00 2001 From: Nan Date: Mon, 7 Feb 2022 18:15:08 +0100 Subject: [PATCH 02/10] renaming --- .vscode/launch.json | 4 ++-- train_mmnist.py | 6 +++--- utils/visualize.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index aa2e0c9..f900c63 100644 --- 
a/.vscode/launch.json +++ b/.vscode/launch.json @@ -118,7 +118,7 @@ "--model_persist_frequency", "1", "--log_intm_frequency", - "1" + "10" ] }, { @@ -180,7 +180,7 @@ "--model_persist_frequency", "1", "--log_intm_frequency", - "1" + "10" ] } ] diff --git a/train_mmnist.py b/train_mmnist.py index 3c4a97f..db651f3 100644 --- a/train_mmnist.py +++ b/train_mmnist.py @@ -9,7 +9,7 @@ from argument_parser import argument_parser from logbook.logbook import LogBook from utils.util import set_seed, make_dir -from utils.visualize import VectorLog +from utils.visualize import ScalarLog from utils.visualize import HeatmapLog from data.MovingMNIST import MovingMNIST from box import Box @@ -37,7 +37,7 @@ def get_grad_norm(model): return total_norm def train(model, train_loader, optimizer, epoch, logbook, train_batch_idx, args): - grad_norm_log = VectorLog(args.folder_log, "grad_norm") + grad_norm_log = ScalarLog(args.folder_log, "grad_norm") inp_key_log = HeatmapLog(args.folder_log, "inp key mat") inp_value_log = HeatmapLog(args.folder_log, "inp value mat") inp_query_log = HeatmapLog(args.folder_log, "inp query mat") @@ -136,7 +136,7 @@ def main(): shuffle=False ) transfer_loader = test_loader - epoch_loss_log = VectorLog(args.folder_log, "epoch_loss") + epoch_loss_log = ScalarLog(args.folder_log, "epoch_loss") for epoch in range(start_epoch, args.epochs+1): train_batch_idx, epoch_loss = train( model = model, diff --git a/utils/visualize.py b/utils/visualize.py index 0a00b5e..0b47945 100644 --- a/utils/visualize.py +++ b/utils/visualize.py @@ -65,7 +65,7 @@ def plot(self, mat, epoch): plt.close() -class VectorLog: +class ScalarLog: def __init__(self, save_path, var_name): self.save_path = save_path self.var_name = var_name+".pt" From 10ba81f8e539eef967896b3ed058065cb84d60e1 Mon Sep 17 00:00:00 2001 From: Nan Date: Wed, 9 Feb 2022 15:10:01 +0100 Subject: [PATCH 03/10] finished log and save and plot --- .vscode/launch.json | 11 +++ RIM.py | 4 +- data/MovingMNIST.py | 4 +- .../log.txt | 37 --------- .../model/args | Bin 1583 -> 0 bytes networks.py | 2 +- train_mmnist.py | 67 +++++++++------ utils/visualize.py | 71 +++++++++++++--- visualize_logtensor.py | 78 ++++++++++++++++++ 9 files changed, 195 insertions(+), 79 deletions(-) delete mode 100644 logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/log.txt delete mode 100644 logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args create mode 100644 visualize_logtensor.py diff --git a/.vscode/launch.json b/.vscode/launch.json index f900c63..9a47491 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,5 +1,16 @@ { "configurations": [ + { + "name": "visualize logged tensor", + "type": "python", + "request": "launch", + "program": "./visualize_logtensor.py", + "console": "integratedTerminal", + "args": [ + "--folder_log", + "logs/SchemaBlocks_5_2_MMnist_0.0007_num_inp_heads_1_ver_0/intermediate_vars" + ] + }, { "name": "minimal: only for debugging code", "type": "python", diff --git a/RIM.py b/RIM.py index f696f91..aa5d533 100644 --- a/RIM.py +++ b/RIM.py @@ -232,7 +232,7 @@ def input_attention_mask(self, x, h): value_mat = get_mat(self.value).transpose(0,1).detach() query_mat = self.query.w.detach() - return inputs, mask_, key_mat, value_mat, query_mat + return inputs, mask_, key_mat, value_mat, query_mat, attention_scores def communication_attention(self, h, mask): """ @@ -284,7 +284,7 @@ def communication_attention(self, h, mask): value_mat = self.value_.w.detach() query_mat = self.query_.w.detach() - return context_layer, 
key_mat, value_mat, query_mat + return context_layer, key_mat, value_mat, query_mat, attention_scores def nan_hook(self, out): nan_mask = torch.isnan(out) diff --git a/data/MovingMNIST.py b/data/MovingMNIST.py index 9f197e3..8fdb7d3 100644 --- a/data/MovingMNIST.py +++ b/data/MovingMNIST.py @@ -89,12 +89,12 @@ def __len__(self): if not self.mini: return len(self.train_data) else: - return len(self.train_data) // 10 + return min(len(self.train_data) // 50 , 100) else: if not self.mini: return len(self.test_data) else: - return len(self.test_data) // 10 + return min(len(self.test_data) // 50 , 100) def _check_exists(self): return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \ diff --git a/logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/log.txt b/logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/log.txt deleted file mode 100644 index f045fcc..0000000 --- a/logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/log.txt +++ /dev/null @@ -1,37 +0,0 @@ -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:29PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:29PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:29PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", 
"experiment_id": 1, "type": "print", "timestamp": "06:29PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:30PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:30PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:50PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:50PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, 
"comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:52PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:52PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:54PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:54PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:55PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:55PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", 
"components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:56PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:56PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:59PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:59PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "07:00PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "07:00PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, 
"hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "07:02PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "07:02PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "07:02PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "07:02PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "07:03PM CET Feb 03, 2022"} -{"messgae": "Saving args to 
./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": 30.825395584106445, "mode": "train", "batch_idx": 1, "epoch": 1, "time_taken": 3.2734878063201904, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -143.4939422607422, "mode": "train", "batch_idx": 2, "epoch": 1, "time_taken": 3.170254945755005, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -273.0503234863281, "mode": "train", "batch_idx": 3, "epoch": 1, "time_taken": 3.102560043334961, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -412.27764892578125, "mode": "train", "batch_idx": 4, "epoch": 1, "time_taken": 3.0873279571533203, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -540.3322143554688, "mode": "train", "batch_idx": 5, "epoch": 1, "time_taken": 3.120579242706299, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -716.9647827148438, "mode": "train", "batch_idx": 6, "epoch": 1, "time_taken": 3.1009089946746826, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -709.658203125, "mode": "train", "batch_idx": 7, "epoch": 1, "time_taken": 3.1195900440216064, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -729.649169921875, "mode": "train", "batch_idx": 8, "epoch": 1, "time_taken": 3.0995841026306152, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -716.698974609375, "mode": "train", "batch_idx": 9, "epoch": 1, "time_taken": 3.098029851913452, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -951.6339111328125, "mode": "train", "batch_idx": 10, "epoch": 1, "time_taken": 3.1082980632781982, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -917.8902587890625, "mode": "train", "batch_idx": 11, "epoch": 1, "time_taken": 3.1051442623138428, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} diff --git a/logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args b/logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args deleted file mode 100644 index 8fb4449ec582ffbc509fc0371f337e0d1e261669..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1583 zcmbVM%We}f6rCh(`T*Ke9_3lyP@qjy0%e7eiqtMfps0Xkv9ZjI(;3Axo;&sg3aZ2g zi7fa4SnwUd53pjv2k;Mk0@rqurbXRgX*9m)z8}6#zcS)^uXcqi#_LEc?Ud#CQx=uL1_od3uF?>mrMr{Rq-1R6@P|Q5#pH)A_XIUnM@;@ z772~htk5*#EL1S+mn^P06g&+$6+GS25mfyKDWkl|VQk~!`k(5@uit-=BZxTdFqfGu zG>mUN{PxM5&*U&N1DPbW&zsognN-_?29md0Fu8%0J!ArZCZr;bwSq%_#ab~41|vhB zD>#f1mQbt>M=%GIWx^AgZ)=&gVagvNhY*V5FyyJreA=J1SV}cdGQl)gFynjHmx@@< zO#$F&*V=M`2^(g+wqT*8ITvuOYeR9qs1g$6JC5FV1E7h>-~>XZDoksh>R6`KRORI3 zXC+&{ID|Q}iu|WR#1pnAWYAW0rA{04QITsFr?g(H*X#FbuQ9GIbcJ?!PU~=rn9Aj8 zN&_wgoF)|{H#VQc88U~r1WP1nhGmMCrmfsHaU0HJ%`PC{&BFkvT2djmj1Qf#Ou;#n zn?<>d>BuozNUc#gPZ0ND-tw_NnM^T3)t7<`(Y#Gu=Ozg=tBPy5XvJSLAygx939jC) zMfc#6ttPfY4`eStAv5R#4Wz1t+4gN*ImpIUH>r^jRlW(v5qYZ@b~Tjl*T4 zc5sJ@f;-0I9tJ+9yr;2*3^O72X$WG_g%$Qgm_k36INti~r za!hY`0RCz89aJ}y+mZqwl^FtPkOv6VmhrBZ|8qR=7OK|t$Hw1Zef9Js?j!Rx4_t23 z-s)hv^{T$Uc4yyi+q38nAR3d#>8~$a&&`T5U(+v{-(aT>2V~l{_w3)W`?ZHHo29vX zeq(}t?OqB^zgk`FNz7N42C_CDr27B2J)J$ diff --git a/networks.py b/networks.py index 849ffbd..f51b545 100644 --- a/networks.py +++ 
b/networks.py @@ -255,7 +255,7 @@ def forward(self, x, h_prev): # dec_out_ = clamp(dec_out_) self.nan_hook(dec_out_) - return dec_out_, h_new, inp_ctx, comm_ctx + return dec_out_, h_new, inp_ctx, comm_ctx, encoded_input def init_hidden(self, batch_size): # assert False, "don't call this" diff --git a/train_mmnist.py b/train_mmnist.py index db651f3..7827429 100644 --- a/train_mmnist.py +++ b/train_mmnist.py @@ -9,7 +9,7 @@ from argument_parser import argument_parser from logbook.logbook import LogBook from utils.util import set_seed, make_dir -from utils.visualize import ScalarLog +from utils.visualize import ScalarLog, VectorLog from utils.visualize import HeatmapLog from data.MovingMNIST import MovingMNIST from box import Box @@ -37,18 +37,27 @@ def get_grad_norm(model): return total_norm def train(model, train_loader, optimizer, epoch, logbook, train_batch_idx, args): - grad_norm_log = ScalarLog(args.folder_log, "grad_norm") - inp_key_log = HeatmapLog(args.folder_log, "inp key mat") - inp_value_log = HeatmapLog(args.folder_log, "inp value mat") - inp_query_log = HeatmapLog(args.folder_log, "inp query mat") - comm_key_log = HeatmapLog(args.folder_log, "comm key mat") - comm_value_log = HeatmapLog(args.folder_log, "comm value mat") - comm_query_log = HeatmapLog(args.folder_log, "comm query mat") + intm_log_folder = args.folder_log + '/intermediate_vars' + inp_key_log = HeatmapLog(intm_log_folder, "inp key mat") + inp_value_log = HeatmapLog(intm_log_folder, "inp value mat") + inp_query_log = HeatmapLog(intm_log_folder, "inp query mat") + comm_key_log = HeatmapLog(intm_log_folder, "comm key mat") + comm_value_log = HeatmapLog(intm_log_folder, "comm value mat") + comm_query_log = HeatmapLog(intm_log_folder, "comm query mat") + + grad_norm_log = ScalarLog(intm_log_folder, "grad_norm") + encoded_log = VectorLog(intm_log_folder, "encoded", epoch=epoch) # TODO put them in a folder + attn_score_log = VectorLog(intm_log_folder, "attn_score", epoch=epoch) + hidden_log = VectorLog(intm_log_folder, "hidden_state", epoch=epoch) model.train() epoch_loss = torch.tensor(0.).to(args.device) for batch_idx, data in enumerate(tqdm(train_loader)): + attn_score_log.reset() + encoded_log.reset() + hidden_log.reset() + hidden = model.init_hidden(data.shape[0]).to(args.device) start_time = time() @@ -60,8 +69,15 @@ def train(model, train_loader, optimizer, epoch, logbook, train_batch_idx, args) # with autograd.detect_anomaly(): if True: for frame in range(data.shape[1]-1): - output, hidden, inp_ctx, comm_ctx = model(data[:, frame, :, :, :], hidden) - + output, hidden, inp_ctx, comm_ctx, encoded_input = model(data[:, frame, :, :, :], hidden) + # ----- logging ----- + encoded_log.append(encoded_input[-1]) # only take the last sample in a batch + attn_score_cat = torch.cat( + (inp_ctx[2][-1].flatten(),comm_ctx[2][-1].flatten()) + ) + attn_score_log.append(attn_score_cat) + hidden_log.append(hidden[-1].flatten()) + # ----- ------- ----- nan_hook(output) nan_hook(hidden) target = data[:, frame+1, :, :, :] @@ -84,19 +100,22 @@ def train(model, train_loader, optimizer, epoch, logbook, train_batch_idx, args) } logbook.write_metric_logs(metrics=metrics) - # if True: - if args.log_intm_frequency > 0 and epoch % args.log_intm_frequency == 0: - """log intermediate variables here""" - pass - # TODO plot inp_ctx - inp_key_log.plot(inp_ctx[0], epoch) # happens in the first batch - inp_value_log.plot(inp_ctx[1], epoch) - inp_query_log.plot(inp_ctx[2], epoch) - pass - # TODO plot comm_ctx - comm_key_log.plot(comm_ctx[0], epoch) - 
comm_value_log.plot(comm_ctx[1], epoch)
-            comm_query_log.plot(comm_ctx[2], epoch)
+        if args.log_intm_frequency > 0 and epoch % args.log_intm_frequency == 0:
+            """log intermediate variables here"""
+            pass
+            # TODO plot inp_ctx
+            inp_key_log.plot(inp_ctx[0], epoch) # happens in the first batch
+            inp_value_log.plot(inp_ctx[1], epoch)
+            inp_query_log.plot(inp_ctx[2], epoch)
+            pass
+            # TODO plot comm_ctx
+            comm_key_log.plot(comm_ctx[0], epoch)
+            comm_value_log.plot(comm_ctx[1], epoch)
+            comm_query_log.plot(comm_ctx[2], epoch)
+            # TODO SAVE logged vectors
+            encoded_log.save()
+            attn_score_log.save()
+            hidden_log.save()

         epoch_loss += loss.detach()
@@ -136,7 +155,7 @@ def main():
                      shuffle=False
     )
     transfer_loader = test_loader
-    epoch_loss_log = ScalarLog(args.folder_log, "epoch_loss")
+    epoch_loss_log = ScalarLog(args.folder_log+'/intermediate_vars', "epoch_loss")
     for epoch in range(start_epoch, args.epochs+1):
         train_batch_idx, epoch_loss = train(
             model = model,
diff --git a/utils/visualize.py b/utils/visualize.py
index 0b47945..64893c9 100644
--- a/utils/visualize.py
+++ b/utils/visualize.py
@@ -21,14 +21,19 @@ def plot_frames(batch_of_pred, batch_of_target, start_frame, end_frame, batch_id
         axs[1, frame-start_frame].imshow(pred[frame,:,:], cmap="Greys")
         axs[1, frame-start_frame].axis('off')
     plt.savefig(f'frames_in_batch_{batch_idx}.png', dpi=120)
+    plt.close()

-def plot_curve(loss):
-    loss = loss.detach().to(torch.device('cpu')).squeeze()
+def plot_curve(vector, save_path, filename):
+    vector = vector.detach().to(torch.device('cpu')).squeeze()
     fig, axs = plt.subplots(1,1)
-    axs.plot(loss)
-    plt.savefig(f"loss_curve.png",dpi=120)
+    axs.plot(vector)
+    plt.savefig(save_path +'/'+ filename, dpi=120)
+    plt.close()

 def plot_mat(mat, mat_name, epoch):
+    '''
+    deprecated, use HeatmapLog instead
+    '''
     if mat.dim() == 3:
         mat_list = [mat[idx_unit,:,:].squeeze().cpu() for idx_unit in range(mat.shape[0])]
     else:
         mat_list = [mat.cpu()]
@@ -45,30 +50,40 @@
 class HeatmapLog:
     def __init__(self, folder_log, mat_name):
-        make_dir(f"{folder_log}/intermediate_vars/"+mat_name)
-        self.save_folder = f"{folder_log}/intermediate_vars/"+mat_name
+        mat_name = mat_name.replace(' ','_')
+        make_dir(f"{folder_log}/"+mat_name)
+        self.save_folder = f"{folder_log}/"+mat_name
         self.mat_name = mat_name

-    def plot(self, mat, epoch):
+    def plot(self, mat, epoch=0):
         if mat.dim() == 3:
             mat_list = [mat[idx_unit,:,:].squeeze().cpu() for idx_unit in range(mat.shape[0])]
         else:
             mat_list = [mat.cpu()]
-        fig, axs = plt.subplots(1, len(mat_list), figsize=(2*len(mat_list), 2))
+        w_h_ratio = (mat_list[0].shape[1]/mat_list[0].shape[0])
+        if w_h_ratio >= 1:
+            fig, axs = plt.subplots(1, len(mat_list), figsize=(2*len(mat_list)*w_h_ratio, 2))
+        else:
+            fig, axs = plt.subplots(1, len(mat_list), figsize=(2*len(mat_list), 2/(w_h_ratio)))
         for idx_mat, mat in enumerate(mat_list):
             if len(mat_list) == 1:
                 axs.imshow(mat, cmap='hot', interpolation='nearest')
             else:
                 axs[idx_mat].imshow(mat, cmap='hot', interpolation='nearest')
         fig.suptitle(self.mat_name.replace("_", " ")+ f' in epoch [{epoch}]')
-        plt.savefig(self.save_folder + '/' + self.mat_name.replace(' ','_') + f'_epoch_{epoch}.png', dpi=120)
+        plt.savefig(self.save_folder + '/' + self.mat_name + f'_epoch_{epoch}.png', dpi=120)
         plt.close()


 class ScalarLog:
-    def __init__(self, save_path, var_name):
-        self.save_path = save_path
-        self.var_name = var_name+".pt"
+    def __init__(self, folder_log, var_name, epoch=None):
+        self.save_folder = f"{folder_log}/"+var_name
+        self.var_name = var_name
+
make_dir(self.save_folder) + self.var = [] + self.epoch = epoch + + def reset(self): self.var = [] def append(self, value): @@ -76,7 +91,37 @@ def append(self, value): def save(self): var_tensor = torch.tensor(self.var) - torch.save(var_tensor, self.save_path+"/"+self.var_name) + if self.epoch is None: + torch.save(var_tensor, self.save_folder +'/'+ self.var_name + '.pt') + else: + torch.save(var_tensor, self.save_folder +'/'+ self.var_name + f'_epoch_{epoch}.pt') + +class VectorLog: + def __init__(self, folder_log, var_name, epoch=None): + ''' + log a vector in a list. vector is supposed to be 1-dim + ''' + self.save_folder = f"{folder_log}/"+var_name + self.var_name = var_name + make_dir(self.save_folder) + self.var_stack = None + self.epoch = epoch + + def reset(self): + self.var_stack = None + + def append(self, vector): + if self.var_stack is None: + self.var_stack = vector.detach().unsqueeze(0) + else: + self.var_stack = torch.cat((self.var_stack, vector.detach().unsqueeze(0)), 0) + + def save(self): + if self.epoch is None: + torch.save(self.var_stack, self.save_folder +'/'+ self.var_name + '.pt') + else: + torch.save(self.var_stack, self.save_folder +'/'+ self.var_name + f'_epoch_{self.epoch}.pt') + def main(): data = torch.rand((64,51,1,64,64)) diff --git a/visualize_logtensor.py b/visualize_logtensor.py new file mode 100644 index 0000000..041d532 --- /dev/null +++ b/visualize_logtensor.py @@ -0,0 +1,78 @@ +''' +this script is intended to plot logged tensors. +entries in the logged tensor should be like: + 1-dim: tensor(idx) == scalar e.g. loss(scalar) vs. epoch(idx) + 2-dim: tensor(idx) == vector e.g. attn_scores(vector) vs. frame +''' +import torch +import matplotlib.pyplot as plt +from utils.visualize import plot_curve, plot_mat, HeatmapLog +from utils.util import make_dir +import argparse + +def arg_parser(): + parser = argparse.ArgumentParser(description='Visualize Logged Tensor') + parser.add_argument('--folder_log', type=str) + args = parser.parse_args() + + return args + +class TensorVisualizer: + def __init__(self, folder_log, tensor_name): + self.folder_log = folder_log + self.tensor_name = tensor_name.replace(' ','_') + self.save_folder = folder_log+'/'+self.tensor_name + self.filename = folder_log+'/'+self.tensor_name+'/'+self.tensor_name + self.tensor = None + + def log_tensor(self, epoch=None): + if epoch is not None: + self.tensor = torch.load(self.filename+f'_epoch_{epoch}.pt') + else: + self.tensor = torch.load(self.filename+f'.pt') + return 0 + + def plot_logged_tensor(self, epoch=None): # NOTE epoch for curve ????/ + save_path = self.save_folder + make_dir(save_path) + if epoch is not None: + figname = self.tensor_name+f'_epoch_{epoch}.png' + else: + figname = self.tensor_name+f'.png' + if self.tensor.dim() == 1: + plot_curve(self.tensor, save_path, figname) + elif self.tensor.dim() == 2: + # TODO plot a matrix as a Heatmap or? 
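# Editor's aside, not a line of this patch: a heatmap is a reasonable default
# for the 2-dim case because, per the module docstring above, each row of the
# logged tensor is the vector recorded for one frame, so rows stay aligned
# with frames and columns with vector entries; per-entry line plots become
# unreadable once the vectors are long.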
+ mat_log = HeatmapLog(save_path, 'figures') + mat_log.plot(self.tensor, epoch) + else: + raise ValueError('tensor.dim should either be 1 or 2!') + return 0 + + def __call__(self, epoch=None): + self.log_tensor(epoch) + self.plot_logged_tensor(epoch) + return 0 + + +def main(): + pass + # TODO parse a folder_log + args = arg_parser() + # TODO load the tensor + loss_plot = TensorVisualizer(args.folder_log, "epoch_loss") + gradnorm_plot = TensorVisualizer(args.folder_log, 'grad_norm') + encoded_plot = TensorVisualizer(args.folder_log, "encoded") + attn_plot = TensorVisualizer(args.folder_log, "attn_score") + # testmat_plot = TensorVisualizer(args.folder_log, "test_mat") + # TODO plot all-epoch tensors + loss_plot() + gradnorm_plot() + + # TODO plot per-epoch tensors + for epoch_idx in range(1,11): + encoded_plot(epoch_idx) + attn_plot(epoch_idx) + +if __name__ == "__main__": + main() \ No newline at end of file From 8d30090456b8cd522ea259ce0ad5bc0afea6a96c Mon Sep 17 00:00:00 2001 From: Nan Date: Wed, 9 Feb 2022 15:20:15 +0100 Subject: [PATCH 04/10] renaming --- .vscode/launch.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 9a47491..77f7306 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -74,7 +74,7 @@ ] }, { - "name": "train bball complete (cuda)", + "name": "train mmnist complete (cuda)", "type": "python", "request": "launch", "program": "${file}", From c9c04540c9ee6f297761bc7fedca9ae804932a1b Mon Sep 17 00:00:00 2001 From: Nan Date: Wed, 9 Feb 2022 17:12:56 +0100 Subject: [PATCH 05/10] better figure --- .gitignore | 3 ++ .vscode/launch.json | 16 +++---- RIM.py | 10 +++++ experiment_mmnist.sh | 4 +- train_mmnist.py | 2 +- utils/visualize.py | 12 +++-- visualize_mmnist.py | 2 +- visualize_result.py | 101 ------------------------------------------- 8 files changed, 33 insertions(+), 117 deletions(-) delete mode 100644 visualize_result.py diff --git a/.gitignore b/.gitignore index b6e4761..36358ae 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ + +logs/ +data/ \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index 77f7306..0780b02 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -8,7 +8,7 @@ "console": "integratedTerminal", "args": [ "--folder_log", - "logs/SchemaBlocks_5_2_MMnist_0.0007_num_inp_heads_1_ver_0/intermediate_vars" + "logs/SchemaBlocks_5_5_MMnist_0.0007_num_inp_heads_1_ver_0/intermediate_vars" ] }, { @@ -30,7 +30,7 @@ "--lr", "0.0007", "--num_units", - "2", + "5", "--k", "1", "--num_input_heads", @@ -101,7 +101,7 @@ "--input_key_size", "64", "--input_value_size", - "2040", + "400", "--input_query_size", "64", "--input_dropout", @@ -117,17 +117,17 @@ "--comm_dropout", "0.1", "--experiment_name", - "MMnist", + "MMNIST_complete", "--batch_size", "64", "--epochs", - "50", + "200", "--should_resume", "False", "--batch_frequency_to_log_heatmaps", "10", "--model_persist_frequency", - "1", + "5", "--log_intm_frequency", "10" ] @@ -179,11 +179,11 @@ "--comm_dropout", "0.1", "--experiment_name", - "MMnist", + "MMNIST_mini", "--batch_size", "16", "--epochs", - "10", + "100", "--should_resume", "True", "--batch_frequency_to_log_heatmaps", diff --git a/RIM.py b/RIM.py index aa5d533..c4a4728 100644 --- a/RIM.py +++ b/RIM.py @@ -211,6 +211,12 @@ def input_attention_mask(self, x, h): attention_scores = torch.mean(attention_scores, dim = 1) mask_ = torch.zeros(x.size(0), self.num_units).to(self.device) + ''' + 
attention_scores: (batch_size, num_heads, num_units, 2) + --> mean in dim-1 + --> (batch_size, num_units, 2) + ''' + not_null_scores = attention_scores[:,:, 0] topk1 = torch.topk(not_null_scores,self.k, dim = 1) row_index = np.arange(x.size(0)) @@ -259,6 +265,10 @@ def communication_attention(self, h, mask): attention_scores = attention_scores / math.sqrt(self.comm_key_size) self.inf_hook(attention_scores) attention_probs = nn.Softmax(dim=-1)(attention_scores) + + """ + attention_scores: (batch_size, num_heads, num_units, num_units) + """ mask = [mask for _ in range(attention_probs.size(1))] mask = torch.stack(mask, dim = 1) diff --git a/experiment_mmnist.sh b/experiment_mmnist.sh index 20711bd..01e2954 100644 --- a/experiment_mmnist.sh +++ b/experiment_mmnist.sh @@ -12,7 +12,7 @@ num_input_heads=1 version=0 rnn_cell="GRU" input_key_size=64 -input_value_size=2040 +input_value_size=400 input_query_size=64 input_dropout=0.1 comm_key_size=32 @@ -20,7 +20,7 @@ comm_value_size=32 comm_query_size=32 num_comm_heads=4 comm_dropout=0.1 -experiment_name="Curtain" +experiment_name="MMNIST_complete" batch_size=64 epochs=100 should_resume="False" diff --git a/train_mmnist.py b/train_mmnist.py index 7827429..afbc823 100644 --- a/train_mmnist.py +++ b/train_mmnist.py @@ -73,7 +73,7 @@ def train(model, train_loader, optimizer, epoch, logbook, train_batch_idx, args) # ----- logging ----- encoded_log.append(encoded_input[-1]) # only take the last sample in a batch attn_score_cat = torch.cat( - (inp_ctx[2][-1].flatten(),comm_ctx[2][-1].flatten()) + (inp_ctx[3][-1].flatten(),comm_ctx[3][-1].flatten()) ) attn_score_log.append(attn_score_cat) hidden_log.append(hidden[-1].flatten()) diff --git a/utils/visualize.py b/utils/visualize.py index 64893c9..e3853f2 100644 --- a/utils/visualize.py +++ b/utils/visualize.py @@ -62,14 +62,18 @@ def plot(self, mat, epoch=0): mat_list = [mat.cpu()] w_h_ratio = (mat_list[0].shape[1]/mat_list[0].shape[0]) if w_h_ratio >= 1: - fig, axs = plt.subplots(1, len(mat_list), figsize=(2*len(mat_list)*w_h_ratio, 2)) + fig, axs = plt.subplots(1, len(mat_list), figsize=(0.5+2*len(mat_list)*w_h_ratio, 2)) + cbar_ax = fig.add_axes([0.90, 0.15, 0.03, 0.7]) # left bottom width height else: - fig, axs = plt.subplots(1, len(mat_list), figsize=(2*len(mat_list), 2/(w_h_ratio))) + fig, axs = plt.subplots(1, len(mat_list), figsize=(2*len(mat_list), 0.5+2/(w_h_ratio))) + cbar_ax = fig.add_axes([0.10, 0.8, 0.7, 0.05]) # left bottom width height for idx_mat, mat in enumerate(mat_list): if len(mat_list) == 1: - axs.imshow(mat, cmap='hot', interpolation='nearest') + im=axs.imshow(mat, cmap='Greys', interpolation='nearest') else: - axs[idx_mat].imshow(mat, cmap='hot', interpolation='nearest') + im=axs[idx_mat].imshow(mat, cmap='Greys', interpolation='nearest') + + fig.colorbar(im, cax=cbar_ax) fig.suptitle(self.mat_name.replace("_", " ")+ f' in epoch [{epoch}]') plt.savefig(self.save_folder + '/' + self.mat_name + f'_epoch_{epoch}.png', dpi=120) plt.close() diff --git a/visualize_mmnist.py b/visualize_mmnist.py index 9c7054b..a341549 100644 --- a/visualize_mmnist.py +++ b/visualize_mmnist.py @@ -51,7 +51,7 @@ def test(model, test_loader, args): prediction = torch.zeros_like(data) for frame in range(data.shape[1]-1): - output, hidden = model(data[:, frame, :, :, :], hidden) + output, hidden, *_ = model(data[:, frame, :, :, :], hidden) nan_hook(output) nan_hook(hidden) diff --git a/visualize_result.py b/visualize_result.py deleted file mode 100644 index 85028b0..0000000 --- a/visualize_result.py +++ 
/dev/null @@ -1,101 +0,0 @@ -"""Main entry point of the code""" -from __future__ import print_function - -import os -from time import time - -import matplotlib.pyplot as plt -import numpy as np -import torch - -from networks import BallModel -# from model_components import GruState -from argument_parser import argument_parser -from dataset import get_dataloaders -from logbook.logbook import LogBook -from utils.util import set_seed, make_dir -from utils.visualize import plot_frames, plot_curve -from box import Box - -import os -from os import listdir -from os.path import isfile, join - -set_seed(0) - -loss_fn = torch.nn.BCELoss() - -def repackage_hidden(ten_): - """Wraps hidden states in new Tensors, to detach them from their history.""" - if isinstance(ten_, torch.Tensor): - return ten_.detach() - else: - return tuple(repackage_hidden(v) for v in ten_) - - -def main(): - """Function to run the experiment""" - args = argument_parser() - args.id = f"SchemaBlocks_{args.hidden_size}_{args.num_units}"+\ - f"_{args.experiment_name}_{args.lr}_num_inp_heads_{args.num_input_heads}"+\ - f"_ver_{args.version}" - # name="SchemaBlocks_"$dim1"_"$block1"_"$topk1"_"$something"_"$lr"_inp_heads_"$inp_heads"_templates_"$templates"_enc_"$encoder"_ver_"$version"_com_"$comm"_Sharing" - print(args) - logbook = LogBook(config=args) - - if not args.should_resume: - # New Experiment - make_dir(f"{args.folder_log}/model") - logbook.write_message_logs(message=f"Saving args to {args.folder_log}/model/args") - torch.save({"args": vars(args)}, f"{args.folder_log}/model/args") - - use_cuda = torch.cuda.is_available() - args.device = torch.device("cuda" if use_cuda else "cpu") - # args.device = torch.device("cpu") - model = setup_model(args=args, logbook=logbook) - - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) - - args.directory = './data' # dataset directory - # args.directory = 'D:\Projecten\Recurrent-Independent-Mechanisms\data' # dataset directory, windows os - train_loader, test_loader, transfer_loader = get_dataloaders(args) - - train_batch_idx = 0 - - start_epoch = 1 - for epoch in range(start_epoch,start_epoch+1): - model.eval() - with torch.no_grad(): - data = next(iter(train_loader)) - hidden = model.init_hidden(data.shape[0]).to(args.device) # NOTE initialize per epoch or per batch [??] - data = data.to(args.device) - hidden = hidden.detach() - pred = torch.zeros_like(data) - for frame in range(49): - output, hidden = model(data[:, frame, :, :, :], hidden) - pred[:,frame+1,:,:,:] = output - - pred = pred[:,1:,:,:,:] - plot_frames(pred, data, 10, 20, 6) - -def setup_model(args, logbook): - """Method to setup the model""" - - model = BallModel(args) - if args.should_resume: - # Find the last checkpointed model and resume from that - model_dir = f"{args.folder_log}/model" - latest_model_idx = max( - [int(model_idx) for model_idx in listdir(model_dir) - if model_idx != "args"] - ) - args.path_to_load_model = f"{model_dir}/{latest_model_idx}" - args.checkpoint = {"epoch": latest_model_idx} - else: - assert False, 'set args.should_resume true!' 
From b283f54a4b67929056e4bd754fb1e5e6aba04926 Mon Sep 17 00:00:00 2001
From: Nan
Date: Wed, 9 Feb 2022 23:13:37 +0100
Subject: [PATCH 06/10] small fixes

---
 .vscode/launch.json    | 10 +++++-----
 train_mmnist.py        |  2 +-
 visualize_logtensor.py |  2 +-
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/.vscode/launch.json b/.vscode/launch.json
index 0780b02..4deac72 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -8,7 +8,7 @@
             "console": "integratedTerminal",
             "args": [
                 "--folder_log",
-                "logs/SchemaBlocks_5_5_MMnist_0.0007_num_inp_heads_1_ver_0/intermediate_vars"
+                "logs/SchemaBlocks_10_3_MMNIST_mini_0.0007_num_inp_heads_1_ver_0/intermediate_vars"
             ]
         },
         {
@@ -163,7 +163,7 @@
             "--input_key_size",
             "10",
             "--input_value_size",
-            "10",
+            "40",
             "--input_query_size",
             "10",
             "--input_dropout",
@@ -181,15 +181,15 @@
             "--experiment_name",
             "MMNIST_mini",
             "--batch_size",
-            "16",
+            "64",
             "--epochs",
             "100",
             "--should_resume",
-            "True",
+            "False",
             "--batch_frequency_to_log_heatmaps",
             "1",
             "--model_persist_frequency",
-            "1",
+            "5",
             "--log_intm_frequency",
             "10"
         ]
diff --git a/train_mmnist.py b/train_mmnist.py
index afbc823..96c7cf1 100644
--- a/train_mmnist.py
+++ b/train_mmnist.py
@@ -141,7 +141,7 @@ def main():
 
     model, optimizer, start_epoch, train_batch_idx = setup_model(args=args, logbook=logbook)
 
-    train_set = MovingMNIST(root='./data', train=True, download=True, mini=True)
+    train_set = MovingMNIST(root='./data', train=True, download=True, mini=False)
     test_set = MovingMNIST(root='./data', train=False, download=True)
 
     train_loader = torch.utils.data.DataLoader(
diff --git a/visualize_logtensor.py b/visualize_logtensor.py
index 041d532..f60f110 100644
--- a/visualize_logtensor.py
+++ b/visualize_logtensor.py
@@ -70,7 +70,7 @@ def main():
     gradnorm_plot()
 
     # TODO plot per-epoch tensors
-    for epoch_idx in range(1,11):
+    for epoch_idx in range(10,110,10):
         encoded_plot(epoch_idx)
         attn_plot(epoch_idx)
 
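Why this patch changes the plotting loop to range(10,110,10): intermediate tensors are only written on epochs divisible by the logging frequency, so only those epochs have files to plot. A small runnable check, assuming the values used in the configs above:

# Intermediate tensors are saved only when epoch % log_intm_frequency == 0,
# so the plotting stride must match the logging stride.
log_intm_frequency, epochs = 10, 100
logged_epochs = [e for e in range(1, epochs + 1) if e % log_intm_frequency == 0]
assert logged_epochs == list(range(10, 110, 10))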
From c81e2885ca13720f8c556bf9b70e6d3dfee75f2b Mon Sep 17 00:00:00 2001
From: Nan
Date: Wed, 9 Feb 2022 23:37:16 +0100
Subject: [PATCH 07/10] epoch_loss_log fixed & cmap

---
 train_mmnist.py        | 13 ++++++++-----
 utils/visualize.py     |  6 +++---
 visualize_logtensor.py |  4 +++-
 3 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/train_mmnist.py b/train_mmnist.py
index 96c7cf1..884e3ba 100644
--- a/train_mmnist.py
+++ b/train_mmnist.py
@@ -45,7 +45,7 @@ def train(model, train_loader, optimizer, epoch, logbook, train_batch_idx, args)
     comm_value_log = HeatmapLog(intm_log_folder, "comm value mat")
     comm_query_log = HeatmapLog(intm_log_folder, "comm query mat")
 
-    grad_norm_log = ScalarLog(intm_log_folder, "grad_norm")
+    grad_norm_log = ScalarLog(intm_log_folder, "grad_norm", epoch=epoch)
     encoded_log = VectorLog(intm_log_folder, "encoded", epoch=epoch) # TODO put them in a folder
     attn_score_log = VectorLog(intm_log_folder, "attn_score", epoch=epoch)
     hidden_log = VectorLog(intm_log_folder, "hidden_state", epoch=epoch)
@@ -139,9 +139,9 @@ def main():
     cudable = torch.cuda.is_available()
     args.device = torch.device("cuda" if cudable else "cpu")
 
-    model, optimizer, start_epoch, train_batch_idx = setup_model(args=args, logbook=logbook)
+    model, optimizer, start_epoch, train_batch_idx, epoch_loss_log = setup_model(args=args, logbook=logbook)
 
-    train_set = MovingMNIST(root='./data', train=True, download=True, mini=False)
+    train_set = MovingMNIST(root='./data', train=True, download=True, mini=True)
     test_set = MovingMNIST(root='./data', train=False, download=True)
 
     train_loader = torch.utils.data.DataLoader(
@@ -155,7 +155,7 @@ def main():
         shuffle=False
     )
     transfer_loader = test_loader
-    epoch_loss_log = ScalarLog(args.folder_log+'/intermediate_vars', "epoch_loss")
+    # epoch_loss_log = ScalarLog(args.folder_log+'/intermediate_vars', "epoch_loss")
     for epoch in range(start_epoch, args.epochs+1):
         train_batch_idx, epoch_loss = train(
             model = model,
@@ -178,6 +178,7 @@ def main():
                 'model_state_dict': model.state_dict(),
                 'optimizer_state_dict': optimizer.state_dict(),
                 'loss': epoch_loss,
+                'epoch_loss_log': epoch_loss_log
             }, f"{args.folder_log}/checkpoints/{epoch}")
 
 def setup_model(args, logbook):
@@ -185,6 +186,7 @@ def setup_model(args, logbook):
     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
     start_epoch = 1
     train_batch_idx = 0
+    epoch_loss_log = ScalarLog(args.folder_log+'/intermediate_vars', "epoch_loss")
     if args.should_resume:
         # Find the last checkpointed model and resume from that
         model_dir = f"{args.folder_log}/checkpoints"
@@ -201,10 +203,11 @@ def setup_model(args, logbook):
         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
         start_epoch = checkpoint['epoch'] + 1
         loss = checkpoint['epoch']
+        epoch_loss_log = checkpoint['epoch_loss_log']
 
     logbook.write_message_logs(message=f"Resuming experiment id: {args.id}, from epoch: {start_epoch}")
 
-    return model, optimizer, start_epoch, train_batch_idx
+    return model, optimizer, start_epoch, train_batch_idx, epoch_loss_log
 
 if __name__ == '__main__':
     main()
diff --git a/utils/visualize.py b/utils/visualize.py
index e3853f2..55147fd 100644
--- a/utils/visualize.py
+++ b/utils/visualize.py
@@ -62,10 +62,10 @@ def plot(self, mat, epoch=0):
             mat_list = [mat.cpu()]
         w_h_ratio = (mat_list[0].shape[1]/mat_list[0].shape[0])
         if w_h_ratio >= 1:
-            fig, axs = plt.subplots(1, len(mat_list), figsize=(0.5+2*len(mat_list)*w_h_ratio, 2))
+            fig, axs = plt.subplots(1, len(mat_list), figsize=(0.75+2*len(mat_list)*w_h_ratio, 2))
             cbar_ax = fig.add_axes([0.90, 0.15, 0.03, 0.7]) # left bottom width height
         else:
-            fig, axs = plt.subplots(1, len(mat_list), figsize=(2*len(mat_list), 0.5+2/(w_h_ratio)))
+            fig, axs = plt.subplots(1, len(mat_list), figsize=(2*len(mat_list), 0.75+2/(w_h_ratio)))
             cbar_ax = fig.add_axes([0.10, 0.8, 0.7, 0.05]) # left bottom width height
         for idx_mat, mat in enumerate(mat_list):
             if len(mat_list) == 1:
@@ -98,7 +98,7 @@ def save(self):
         if self.epoch is None:
             torch.save(var_tensor, self.save_folder +'/'+ self.var_name + '.pt')
         else:
-            torch.save(var_tensor, self.save_folder +'/'+ self.var_name + f'_epoch_{epoch}.pt')
+            torch.save(var_tensor, self.save_folder +'/'+ self.var_name + f'_epoch_{self.epoch}.pt')
 
 class VectorLog:
     def __init__(self, folder_log, var_name, epoch=None):
diff --git a/visualize_logtensor.py b/visualize_logtensor.py
index f60f110..31014c8 100644
--- a/visualize_logtensor.py
+++ b/visualize_logtensor.py
@@ -64,15 +64,17 @@ def main():
     gradnorm_plot = TensorVisualizer(args.folder_log, 'grad_norm')
    encoded_plot = TensorVisualizer(args.folder_log, "encoded")
    attn_plot = TensorVisualizer(args.folder_log, "attn_score")
+    hidden_plot = TensorVisualizer(args.folder_log, "hidden_state")
    # testmat_plot = TensorVisualizer(args.folder_log, "test_mat")
 
    # TODO plot all-epoch tensors
    loss_plot()
-    gradnorm_plot()
 
    # TODO plot per-epoch tensors
    for epoch_idx in range(10,110,10):
+        gradnorm_plot(epoch_idx)
        encoded_plot(epoch_idx)
        attn_plot(epoch_idx)
+        hidden_plot(epoch_idx)
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
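This patch's intent in one place: persist the running epoch_loss_log inside the checkpoint and restore it when resuming, so the loss history survives restarts. A sketch of that round-trip; the function names here are illustrative, and the ScalarLog object is simply pickled along with the state dicts by torch.save:

import torch

def save_checkpoint(path, epoch, model, optimizer, epoch_loss, epoch_loss_log):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': epoch_loss,
        'epoch_loss_log': epoch_loss_log,  # keeps the whole loss history across resumes
    }, path)

def load_checkpoint(path, model, optimizer):
    ckpt = torch.load(path)
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    return ckpt['epoch'] + 1, ckpt['epoch_loss_log']  # resume at the next epoch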
From 03e8367d52db89c817de279b2abe8262651e4adc Mon Sep 17 00:00:00 2001
From: Nan
Date: Wed, 9 Feb 2022 23:56:29 +0100
Subject: [PATCH 08/10] bash fix

---
 experiment_mmnist.sh | 7 ++++---
 train_mmnist.py      | 2 +-
 utils/visualize.py   | 3 ++-
 3 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/experiment_mmnist.sh b/experiment_mmnist.sh
index 01e2954..8761358 100644
--- a/experiment_mmnist.sh
+++ b/experiment_mmnist.sh
@@ -22,9 +22,10 @@ num_comm_heads=4
 comm_dropout=0.1
 experiment_name="MMNIST_complete"
 batch_size=64
-epochs=100
+epochs=200
 should_resume="False"
 batch_frequency_to_log_heatmaps=10
-model_persist_frequency=1
+model_persist_frequency=5
+log_intm_frequency=10
 
-nohup python3 train_mmnist.py --train_dataset $train_dataset --hidden_size $hidden_size --should_save_csv $should_save_csv --lr $lr --num_units $num_units --k $k --num_input_heads $num_input_heads --version $version --rnn_cell $rnn_cell --input_key_size $input_key_size --input_value_size $input_value_size --input_query_size $input_query_size --input_dropout $input_dropout --comm_key_size $comm_key_size --comm_value_size $comm_value_size --comm_query_size $comm_query_size --num_comm_heads $num_comm_heads --comm_dropout $comm_dropout --experiment_name $experiment_name --batch_size $batch_size --epochs $epochs --should_resume $should_resume --batch_frequency_to_log_heatmaps $batch_frequency_to_log_heatmaps --model_persist_frequency $model_persist_frequency
+nohup python3 train_mmnist.py --train_dataset $train_dataset --hidden_size $hidden_size --should_save_csv $should_save_csv --lr $lr --num_units $num_units --k $k --num_input_heads $num_input_heads --version $version --rnn_cell $rnn_cell --input_key_size $input_key_size --input_value_size $input_value_size --input_query_size $input_query_size --input_dropout $input_dropout --comm_key_size $comm_key_size --comm_value_size $comm_value_size --comm_query_size $comm_query_size --num_comm_heads $num_comm_heads --comm_dropout $comm_dropout --experiment_name $experiment_name --batch_size $batch_size --epochs $epochs --should_resume $should_resume --batch_frequency_to_log_heatmaps $batch_frequency_to_log_heatmaps --model_persist_frequency $model_persist_frequency --log_intm_frequency $log_intm_frequency
diff --git a/train_mmnist.py b/train_mmnist.py
index 884e3ba..569533d 100644
--- a/train_mmnist.py
+++ b/train_mmnist.py
@@ -141,7 +141,7 @@ def main():
 
     model, optimizer, start_epoch, train_batch_idx, epoch_loss_log = setup_model(args=args, logbook=logbook)
 
-    train_set = MovingMNIST(root='./data', train=True, download=True, mini=True)
+    train_set = MovingMNIST(root='./data', train=True, download=True, mini=False)
     test_set = MovingMNIST(root='./data', train=False, download=True)
 
     train_loader = torch.utils.data.DataLoader(
diff --git a/utils/visualize.py b/utils/visualize.py
index 55147fd..1b33bb4 100644
--- a/utils/visualize.py
+++ b/utils/visualize.py
@@ -66,7 +66,8 @@ def plot(self, mat, epoch=0):
             cbar_ax = fig.add_axes([0.90, 0.15, 0.03, 0.7]) # left bottom width height
         else:
             fig, axs = plt.subplots(1, len(mat_list), figsize=(2*len(mat_list), 0.75+2/(w_h_ratio)))
-            cbar_ax = fig.add_axes([0.10, 0.8, 0.7, 0.05]) # left bottom width height
+            cbar_ax = fig.add_axes([0.10, 0.85, 0.8, 0.03]) # left bottom width height
+            char_ax.orientation = 'horizontal'
         for idx_mat, mat in enumerate(mat_list):
             if len(mat_list) == 1:
                 im=axs.imshow(mat, cmap='Greys', interpolation='nearest')
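On the colorbar placement this hunk (and the next patch's typo fix) wrestle with: assigning to cbar_ax.orientation is not part of the matplotlib API and has no effect; orientation is a keyword argument of fig.colorbar. A minimal sketch of a horizontal colorbar in a dedicated axes, with made-up data:

import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(figsize=(4, 2.75))
im = ax.imshow(np.random.rand(3, 12), cmap='Greys', interpolation='nearest')
cbar_ax = fig.add_axes([0.10, 0.85, 0.8, 0.03])  # left, bottom, width, height
fig.colorbar(im, cax=cbar_ax, orientation='horizontal')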
From 5de47dce3df59e5bc8fa9681b3d5b881894b7008 Mon Sep 17 00:00:00 2001
From: Nan
Date: Thu, 10 Feb 2022 00:03:45 +0100
Subject: [PATCH 09/10] typo fix

---
 utils/visualize.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils/visualize.py b/utils/visualize.py
index 1b33bb4..88d2b35 100644
--- a/utils/visualize.py
+++ b/utils/visualize.py
@@ -67,7 +67,7 @@ def plot(self, mat, epoch=0):
         else:
             fig, axs = plt.subplots(1, len(mat_list), figsize=(2*len(mat_list), 0.75+2/(w_h_ratio)))
             cbar_ax = fig.add_axes([0.10, 0.85, 0.8, 0.03]) # left bottom width height
-            char_ax.orientation = 'horizontal'
+            cbar_ax.orientation = 'horizontal'
         for idx_mat, mat in enumerate(mat_list):
             if len(mat_list) == 1:
                 im=axs.imshow(mat, cmap='Greys', interpolation='nearest')

From 00261d98288e2a0b008fd3d79aa58c816e80da0a Mon Sep 17 00:00:00 2001
From: Nan
Date: Thu, 10 Feb 2022 16:47:15 +0100
Subject: [PATCH 10/10] epoch_loss_log fix

---
 train_mmnist.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/train_mmnist.py b/train_mmnist.py
index 569533d..18d27b6 100644
--- a/train_mmnist.py
+++ b/train_mmnist.py
@@ -99,6 +99,7 @@ def train(model, train_loader, optimizer, epoch, logbook, train_batch_idx, args)
             "time_taken": time() - start_time,
         }
         logbook.write_metric_logs(metrics=metrics)
+        epoch_loss += loss.detach()
 
         if args.log_intm_frequency > 0 and epoch % args.log_intm_frequency == 0:
             """log intermediate variables here"""
@@ -116,8 +117,6 @@ def train(model, train_loader, optimizer, epoch, logbook, train_batch_idx, args)
             encoded_log.save()
             attn_score_log.save()
             hidden_log.save()
-
-            epoch_loss += loss.detach()
 
     epoch_loss = epoch_loss / (batch_idx+1)
     return train_batch_idx, epoch_loss
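The final patch in miniature: the per-batch loss must accumulate unconditionally, while intermediate-variable logging stays behind the epoch-frequency guard; otherwise epoch_loss only counts batches from logged epochs. A toy, runnable illustration with stand-in values:

import torch

log_intm_frequency, epoch = 10, 3   # epoch 3: no intermediate logging happens
losses = [torch.tensor(0.5), torch.tensor(0.25), torch.tensor(0.75)]

epoch_loss = 0.0
for batch_idx, loss in enumerate(losses):
    epoch_loss += loss.detach()     # unconditional, as in the fixed code
    if log_intm_frequency > 0 and epoch % log_intm_frequency == 0:
        pass                        # intermediate-variable logging would go here
epoch_loss = epoch_loss / (batch_idx + 1)
print(epoch_loss)                   # tensor(0.5000), the mean over all batches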