diff --git a/.gitignore b/.gitignore index b6e4761..36358ae 100644 --- a/.gitignore +++ b/.gitignore @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ + +logs/ +data/ \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index 24dcc93..4deac72 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,7 +1,80 @@ { "configurations": [ { - "name": "train bball complete (cuda)", + "name": "visualize logged tensor", + "type": "python", + "request": "launch", + "program": "./visualize_logtensor.py", + "console": "integratedTerminal", + "args": [ + "--folder_log", + "logs/SchemaBlocks_10_3_MMNIST_mini_0.0007_num_inp_heads_1_ver_0/intermediate_vars" + ] + }, + { + "name": "minimal: only for debugging code", + "type": "python", + "request": "launch", + "program": "./train_mmnist.py", + "console": "integratedTerminal", + "args": [ + // "--mini", + "--train_dataset", + "balls3curtain64.h5", + "--transfer_dataset", + "balls3curtain64.h5", + "--hidden_size", + "5", + "--should_save_csv", + "False", + "--lr", + "0.0007", + "--num_units", + "5", + "--k", + "1", + "--num_input_heads", + "1", + "--version", + "0", + "--rnn_cell", + "GRU", + "--input_key_size", + "3", + "--input_value_size", + "4", + "--input_query_size", + "3", //actually not used. altijd == input_key_size + "--input_dropout", + "0.1", + "--comm_key_size", + "3", + "--comm_value_size", + "4", //actually not used. altijd == hidden_size because output a hidden_state + "--comm_query_size", + "3", + "--num_comm_heads", + "2", + "--comm_dropout", + "0.1", + "--experiment_name", + "MMnist", + "--batch_size", + "16", + "--epochs", + "10", + "--should_resume", + "False", + "--batch_frequency_to_log_heatmaps", + "1", + "--model_persist_frequency", + "1", + "--log_intm_frequency", + "1" + ] + }, + { + "name": "train mmnist complete (cuda)", "type": "python", "request": "launch", "program": "${file}", @@ -28,7 +101,7 @@ "--input_key_size", "64", "--input_value_size", - "2040", + "400", "--input_query_size", "64", "--input_dropout", @@ -44,17 +117,19 @@ "--comm_dropout", "0.1", "--experiment_name", - "Curtain", + "MMNIST_complete", "--batch_size", "64", "--epochs", - "50", + "200", "--should_resume", "False", "--batch_frequency_to_log_heatmaps", "10", "--model_persist_frequency", - "1" + "5", + "--log_intm_frequency", + "10" ] }, { @@ -88,7 +163,7 @@ "--input_key_size", "10", "--input_value_size", - "10", + "40", "--input_query_size", "10", "--input_dropout", @@ -104,88 +179,19 @@ "--comm_dropout", "0.1", "--experiment_name", - "Curtain", + "MMNIST_mini", "--batch_size", - "16", + "64", "--epochs", - "10", + "100", "--should_resume", - "True", + "False", "--batch_frequency_to_log_heatmaps", "1", "--model_persist_frequency", - "1" - ] - }, - { - "name": "main: RIM (cuda)", - "type": "python", - "request": "launch", - "program": "${file}", - "console": "integratedTerminal", - "args": [ - "--cuda", - "True", - "--epochs", - "6", - "--batch_size", - "64", - "--input_size", - "1", - "--hidden_size", - "600", - "--size", - "14", - "--loadsaved", - "0", - "--model", - "RIM" - ] - }, - { - "name": "main: LSTM", - "type": "python", - "request": "launch", - "program": "${file}", - "console": "integratedTerminal", - "args": [ - "--cuda", - "False", - "--epochs", - "1", - "--batch_size", - "64", - "--input_size", - "1", - "--size", - "14", - "--loadsaved", - "0", - "--model", - "LSTM" - ] - }, - { - "name": "main: RIM", - "type": "python", - "request": "launch", - "program": "${file}", - "console": "integratedTerminal", - "args": [ - "--cuda", - "False", - "--epochs", - "7", - "--batch_size", - "64", - "--input_size", - "14", - "--size", - "14", - "--loadsaved", - "0", - "--model", - "RIM" + "5", + "--log_intm_frequency", + "10" ] } ] diff --git a/RIM.py b/RIM.py index 8727cc4..c4a4728 100644 --- a/RIM.py +++ b/RIM.py @@ -211,6 +211,12 @@ def input_attention_mask(self, x, h): attention_scores = torch.mean(attention_scores, dim = 1) mask_ = torch.zeros(x.size(0), self.num_units).to(self.device) + ''' + attention_scores: (batch_size, num_heads, num_units, 2) + --> mean in dim-1 + --> (batch_size, num_units, 2) + ''' + not_null_scores = attention_scores[:,:, 0] topk1 = torch.topk(not_null_scores,self.k, dim = 1) row_index = np.arange(x.size(0)) @@ -222,7 +228,17 @@ def input_attention_mask(self, x, h): attention_probs = self.input_dropout(nn.Softmax(dim = -1)(attention_scores)) inputs = torch.matmul(attention_probs, value_layer) * mask_.unsqueeze(2) - return inputs, mask_ + """ + parameters: + key_mat: input_size x key_size -> input 2 key + value_mat: input_size x val_size -> input 2 value + query_mat: num_units x hidden_size x key_size -> hidden 2 query + """ + key_mat = get_mat(self.key).transpose(0,1).detach() # TODO detach? + value_mat = get_mat(self.value).transpose(0,1).detach() + query_mat = self.query.w.detach() + + return inputs, mask_, key_mat, value_mat, query_mat, attention_scores def communication_attention(self, h, mask): """ @@ -249,6 +265,10 @@ def communication_attention(self, h, mask): attention_scores = attention_scores / math.sqrt(self.comm_key_size) self.inf_hook(attention_scores) attention_probs = nn.Softmax(dim=-1)(attention_scores) + + """ + attention_scores: (batch_size, num_heads, num_units, num_units) + """ mask = [mask for _ in range(attention_probs.size(1))] mask = torch.stack(mask, dim = 1) @@ -263,8 +283,18 @@ def communication_attention(self, h, mask): context_layer = context_layer.view(*new_context_layer_shape) context_layer = self.comm_attention_output(context_layer) context_layer = context_layer + h - - return context_layer + + """ + parameters: + key_mat: num_units x hidden_size x (key_size*num_heads) -> hidden 2 key + value_mat: num_units x hidden_size x (val_size*num_heads==hidden_size*num_heads) -> hidden 2 (value==new hidden_state) + query_mat: num_units x hidden_size x (key_size*num_heads) -> hidden 2 query + """ + key_mat = self.key_.w.detach() + value_mat = self.value_.w.detach() + query_mat = self.query_.w.detach() + + return context_layer, key_mat, value_mat, query_mat, attention_scores def nan_hook(self, out): nan_mask = torch.isnan(out) @@ -290,7 +320,7 @@ def forward(self, x, hs, cs = None): x = torch.cat((x.unsqueeze(1), null_input), dim = 1) # Compute input attention - inputs, mask = self.input_attention_mask(x, hs) + inputs, mask, *inp_ctx = self.input_attention_mask(x, hs) h_old = hs * 1.0 if cs is not None: c_old = cs * 1.0 @@ -309,15 +339,15 @@ def forward(self, x, hs, cs = None): h_new = blocked_grad.apply(hs, mask) # Compute communication attention - h_new = self.communication_attention(h_new, mask.squeeze(2)) + h_new, *comm_ctx = self.communication_attention(h_new, mask.squeeze(2)) self.nan_hook(h_new) hs = mask * h_new + (1 - mask) * h_old if cs is not None: cs = mask * cs + (1 - mask) * c_old - return hs, cs, None + return hs, cs, None, inp_ctx, comm_ctx self.nan_hook(hs) - return hs, None, None + return hs, None, None, inp_ctx, comm_ctx class RIM(nn.Module): @@ -625,4 +655,9 @@ def __init__(self, scale_factor, mode): def forward(self, x): x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False) - return x \ No newline at end of file + return x + +def get_mat(linear_layer): + '''get linear transformation matrix in a linear layer''' + p_list = [p for p in linear_layer.parameters()] + return p_list[0] diff --git a/argument_parser.py b/argument_parser.py index f71315a..5050c85 100644 --- a/argument_parser.py +++ b/argument_parser.py @@ -71,6 +71,9 @@ def argument_parser(): help='Number of training epochs after which model is to ' 'be persisted. -1 means that the model is not' 'persisted') + parser.add_argument('--log_intm_frequency', type=int, default=10, + metavar="Frequency at which we log the intermediate variables of the model", + help='Just type in a positive integer') parser.add_argument('--batch_frequency_to_log_heatmaps', type=int, default=-1, metavar='Frequency at which the heatmaps are persisted', help='Number of training batches after which we will persit the ' diff --git a/data/MovingMNIST.py b/data/MovingMNIST.py index b9edc08..8fdb7d3 100644 --- a/data/MovingMNIST.py +++ b/data/MovingMNIST.py @@ -35,11 +35,12 @@ class MovingMNIST(data.Dataset): training_file = 'moving_mnist_train.pt' test_file = 'moving_mnist_test.pt' - def __init__(self, root, train=True, split=1000, transform=None, download=False): + def __init__(self, root, train=True, split=1000, transform=None, download=False, mini=False): self.root = os.path.expanduser(root) self.transform = transform self.split = split self.train = train # training set or test set + self.mini = mini if download: self.download() @@ -85,9 +86,15 @@ def _transform_time(data): def __len__(self): if self.train: - return len(self.train_data) + if not self.mini: + return len(self.train_data) + else: + return min(len(self.train_data) // 50 , 100) else: - return len(self.test_data) + if not self.mini: + return len(self.test_data) + else: + return min(len(self.test_data) // 50 , 100) def _check_exists(self): return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \ diff --git a/experiment_mmnist.sh b/experiment_mmnist.sh index 20711bd..8761358 100644 --- a/experiment_mmnist.sh +++ b/experiment_mmnist.sh @@ -12,7 +12,7 @@ num_input_heads=1 version=0 rnn_cell="GRU" input_key_size=64 -input_value_size=2040 +input_value_size=400 input_query_size=64 input_dropout=0.1 comm_key_size=32 @@ -20,11 +20,12 @@ comm_value_size=32 comm_query_size=32 num_comm_heads=4 comm_dropout=0.1 -experiment_name="Curtain" +experiment_name="MMNIST_complete" batch_size=64 -epochs=100 +epochs=200 should_resume="False" batch_frequency_to_log_heatmaps=10 -model_persist_frequency=1 +model_persist_frequency=5 +log_intm_frequency=10 -nohup python3 train_mmnist.py --train_dataset $train_dataset --hidden_size $hidden_size --should_save_csv $should_save_csv --lr $lr --num_units $num_units --k $k --num_input_heads $num_input_heads --version $version --rnn_cell $rnn_cell --input_key_size $input_key_size --input_value_size $input_value_size --input_query_size $input_query_size --input_dropout $input_dropout --comm_key_size $comm_key_size --comm_value_size $comm_value_size --comm_query_size $comm_query_size --num_comm_heads $num_comm_heads --comm_dropout $comm_dropout --experiment_name $experiment_name --batch_size $batch_size --epochs $epochs --should_resume $should_resume --batch_frequency_to_log_heatmaps $batch_frequency_to_log_heatmaps --model_persist_frequency $model_persist_frequency +nohup python3 train_mmnist.py --train_dataset $train_dataset --hidden_size $hidden_size --should_save_csv $should_save_csv --lr $lr --num_units $num_units --k $k --num_input_heads $num_input_heads --version $version --rnn_cell $rnn_cell --input_key_size $input_key_size --input_value_size $input_value_size --input_query_size $input_query_size --input_dropout $input_dropout --comm_key_size $comm_key_size --comm_value_size $comm_value_size --comm_query_size $comm_query_size --num_comm_heads $num_comm_heads --comm_dropout $comm_dropout --experiment_name $experiment_name --batch_size $batch_size --epochs $epochs --should_resume $should_resume --batch_frequency_to_log_heatmaps $batch_frequency_to_log_heatmaps --model_persist_frequency $model_persist_frequency --log_intm_frequency $log_intm_frequency diff --git a/logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/log.txt b/logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/log.txt deleted file mode 100644 index f045fcc..0000000 --- a/logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/log.txt +++ /dev/null @@ -1,37 +0,0 @@ -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:29PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:29PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:29PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:29PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:30PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:30PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:50PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:50PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:52PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:52PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:54PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:54PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:55PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:55PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:56PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:56PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "06:59PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "06:59PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "07:00PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "07:00PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "07:02PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "07:02PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "07:02PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "07:02PM CET Feb 03, 2022"} -{"batch_size": 16, "epochs": 2, "num_input_heads": 1, "sequence_length": 51, "lr": 0.0007, "input_dropout": 0.1, "comm_dropout": 0.1, "kl_coeff": 0.0, "num_units": 3, "num_encoders": 1, "k": 2, "memorytopk": 4, "hidden_size": 10, "n_templates": 0, "share_inp": false, "share_comm": false, "do_rel": false, "memory_slots": 4, "memory_mlp": 4, "attention_out": 340, "id": "SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "rnn_cell": "GRU", "model_persist_frequency": 1, "batch_frequency_to_log_heatmaps": 1, "path_to_load_model": "", "components_to_load": "", "train_dataset": "balls3curtain64.h5", "test_dataset": null, "transfer_dataset": "balls3curtain64.h5", "should_save_csv": false, "should_resume": false, "experiment_name": "Curtain", "version": 0, "do_comm": true, "input_key_size": 10, "input_value_size": 10, "input_query_size": 10, "comm_key_size": 5, "comm_value_size": 5, "comm_query_size": 5, "num_comm_heads": 2, "frame_frequency_to_log_heatmaps": 5, "folder_log": "./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0", "type": "config", "timestamp": "07:03PM CET Feb 03, 2022"} -{"messgae": "Saving args to ./logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args", "experiment_id": 1, "type": "print", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": 30.825395584106445, "mode": "train", "batch_idx": 1, "epoch": 1, "time_taken": 3.2734878063201904, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -143.4939422607422, "mode": "train", "batch_idx": 2, "epoch": 1, "time_taken": 3.170254945755005, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -273.0503234863281, "mode": "train", "batch_idx": 3, "epoch": 1, "time_taken": 3.102560043334961, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -412.27764892578125, "mode": "train", "batch_idx": 4, "epoch": 1, "time_taken": 3.0873279571533203, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -540.3322143554688, "mode": "train", "batch_idx": 5, "epoch": 1, "time_taken": 3.120579242706299, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -716.9647827148438, "mode": "train", "batch_idx": 6, "epoch": 1, "time_taken": 3.1009089946746826, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -709.658203125, "mode": "train", "batch_idx": 7, "epoch": 1, "time_taken": 3.1195900440216064, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -729.649169921875, "mode": "train", "batch_idx": 8, "epoch": 1, "time_taken": 3.0995841026306152, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -716.698974609375, "mode": "train", "batch_idx": 9, "epoch": 1, "time_taken": 3.098029851913452, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -951.6339111328125, "mode": "train", "batch_idx": 10, "epoch": 1, "time_taken": 3.1082980632781982, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} -{"loss": -917.8902587890625, "mode": "train", "batch_idx": 11, "epoch": 1, "time_taken": 3.1051442623138428, "experiment_id": 1, "type": "metric", "timestamp": "07:03PM CET Feb 03, 2022"} diff --git a/logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args b/logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args deleted file mode 100644 index 8fb4449..0000000 Binary files a/logs/SchemaBlocks_10_3_Curtain_0.0007_num_inp_heads_1_ver_0/model/args and /dev/null differ diff --git a/networks.py b/networks.py index 214a447..f51b545 100644 --- a/networks.py +++ b/networks.py @@ -248,14 +248,14 @@ def forward(self, x, h_prev): encoded_input = self.Encoder(x) # encoded_input = clamp(encoded_input) self.nan_hook(encoded_input) - h_new, foo, bar = self.rim_model(encoded_input, h_prev) + h_new, foo, bar, inp_ctx, comm_ctx = self.rim_model(encoded_input, h_prev) # h_new = clamp(h_new) self.nan_hook(h_new) dec_out_ = self.Decoder(h_new.view(h_new.shape[0],-1)) # dec_out_ = clamp(dec_out_) self.nan_hook(dec_out_) - return dec_out_, h_new + return dec_out_, h_new, inp_ctx, comm_ctx, encoded_input def init_hidden(self, batch_size): # assert False, "don't call this" diff --git a/train_mmnist.py b/train_mmnist.py index 91d7d12..18d27b6 100644 --- a/train_mmnist.py +++ b/train_mmnist.py @@ -9,9 +9,11 @@ from argument_parser import argument_parser from logbook.logbook import LogBook from utils.util import set_seed, make_dir -from utils.visualize import VectorLog +from utils.visualize import ScalarLog, VectorLog +from utils.visualize import HeatmapLog from data.MovingMNIST import MovingMNIST from box import Box +from tqdm import tqdm import os from os import listdir @@ -35,12 +37,27 @@ def get_grad_norm(model): return total_norm def train(model, train_loader, optimizer, epoch, logbook, train_batch_idx, args): - grad_norm_log = VectorLog(args.folder_log, "grad_norm") + intm_log_folder = args.folder_log + '/intermediate_vars' + inp_key_log = HeatmapLog(intm_log_folder, "inp key mat") + inp_value_log = HeatmapLog(intm_log_folder, "inp value mat") + inp_query_log = HeatmapLog(intm_log_folder, "inp query mat") + comm_key_log = HeatmapLog(intm_log_folder, "comm key mat") + comm_value_log = HeatmapLog(intm_log_folder, "comm value mat") + comm_query_log = HeatmapLog(intm_log_folder, "comm query mat") + + grad_norm_log = ScalarLog(intm_log_folder, "grad_norm", epoch=epoch) + encoded_log = VectorLog(intm_log_folder, "encoded", epoch=epoch) # TODO put them in a folder + attn_score_log = VectorLog(intm_log_folder, "attn_score", epoch=epoch) + hidden_log = VectorLog(intm_log_folder, "hidden_state", epoch=epoch) model.train() epoch_loss = torch.tensor(0.).to(args.device) - for batch_idx, data in enumerate(train_loader): + for batch_idx, data in enumerate(tqdm(train_loader)): + attn_score_log.reset() + encoded_log.reset() + hidden_log.reset() + hidden = model.init_hidden(data.shape[0]).to(args.device) start_time = time() @@ -52,8 +69,15 @@ def train(model, train_loader, optimizer, epoch, logbook, train_batch_idx, args) # with autograd.detect_anomaly(): if True: for frame in range(data.shape[1]-1): - output, hidden = model(data[:, frame, :, :, :], hidden) - + output, hidden, inp_ctx, comm_ctx, encoded_input = model(data[:, frame, :, :, :], hidden) + # ----- logging ----- + encoded_log.append(encoded_input[-1]) # only take the last sample in a batch + attn_score_cat = torch.cat( + (inp_ctx[3][-1].flatten(),comm_ctx[3][-1].flatten()) + ) + attn_score_log.append(attn_score_cat) + hidden_log.append(hidden[-1].flatten()) + # ----- ------- ----- nan_hook(output) nan_hook(hidden) target = data[:, frame+1, :, :, :] @@ -75,8 +99,24 @@ def train(model, train_loader, optimizer, epoch, logbook, train_batch_idx, args) "time_taken": time() - start_time, } logbook.write_metric_logs(metrics=metrics) - epoch_loss += loss.detach() + + if args.log_intm_frequency > 0 and epoch % args.log_intm_frequency == 0: + """log intermediate variables here""" + pass + # TODO plot inp_ctx + inp_key_log.plot(inp_ctx[0], epoch) # happens in the first batch + inp_value_log.plot(inp_ctx[1], epoch) + inp_query_log.plot(inp_ctx[2], epoch) + pass + # TODO plot comm_ctx + comm_key_log.plot(comm_ctx[0], epoch) + comm_value_log.plot(comm_ctx[1], epoch) + comm_query_log.plot(comm_ctx[2], epoch) + # TODO SAVE logged vectors + encoded_log.save() + attn_score_log.save() + hidden_log.save() epoch_loss = epoch_loss / (batch_idx+1) return train_batch_idx, epoch_loss @@ -98,9 +138,9 @@ def main(): cudable = torch.cuda.is_available() args.device = torch.device("cuda" if cudable else "cpu") - model, optimizer, start_epoch, train_batch_idx = setup_model(args=args, logbook=logbook) + model, optimizer, start_epoch, train_batch_idx, epoch_loss_log = setup_model(args=args, logbook=logbook) - train_set = MovingMNIST(root='./data', train=True, download=True) + train_set = MovingMNIST(root='./data', train=True, download=True, mini=False) test_set = MovingMNIST(root='./data', train=False, download=True) train_loader = torch.utils.data.DataLoader( @@ -114,7 +154,7 @@ def main(): shuffle=False ) transfer_loader = test_loader - epoch_loss_log = VectorLog(args.folder_log, "epoch_loss") + # epoch_loss_log = ScalarLog(args.folder_log+'/intermediate_vars', "epoch_loss") for epoch in range(start_epoch, args.epochs+1): train_batch_idx, epoch_loss = train( model = model, @@ -137,6 +177,7 @@ def main(): 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': epoch_loss, + 'epoch_loss_log': epoch_loss_log }, f"{args.folder_log}/checkpoints/{epoch}") def setup_model(args, logbook): @@ -144,6 +185,7 @@ def setup_model(args, logbook): optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) start_epoch = 1 train_batch_idx = 0 + epoch_loss_log = ScalarLog(args.folder_log+'/intermediate_vars', "epoch_loss") if args.should_resume: # Find the last checkpointed model and resume from that model_dir = f"{args.folder_log}/checkpoints" @@ -160,10 +202,11 @@ def setup_model(args, logbook): optimizer.load_state_dict(checkpoint['optimizer_state_dict']) start_epoch = checkpoint['epoch'] + 1 loss = checkpoint['epoch'] + epoch_loss_log = checkpoint['epoch_loss_log'] logbook.write_message_logs(message=f"Resuming experiment id: {args.id}, from epoch: {start_epoch}") - return model, optimizer, start_epoch, train_batch_idx + return model, optimizer, start_epoch, train_batch_idx, epoch_loss_log if __name__ == '__main__': main() diff --git a/utils/visualize.py b/utils/visualize.py index 8c48230..88d2b35 100644 --- a/utils/visualize.py +++ b/utils/visualize.py @@ -1,5 +1,6 @@ import matplotlib.pyplot as plt import torch +from .util import make_dir def plot_frames(batch_of_pred, batch_of_target, start_frame, end_frame, batch_idx): ''' @@ -20,17 +21,74 @@ def plot_frames(batch_of_pred, batch_of_target, start_frame, end_frame, batch_id axs[1, frame-start_frame].imshow(pred[frame,:,:], cmap="Greys") axs[1, frame-start_frame].axis('off') plt.savefig(f'frames_in_batch_{batch_idx}.png', dpi=120) + plt.close() -def plot_curve(loss): - loss = loss.detach().to(torch.device('cpu')).squeeze() +def plot_curve(vector, save_path, filename): + vector = vector.detach().to(torch.device('cpu')).squeeze() fig, axs = plt.subplots(1,1) - axs.plot(loss) - plt.savefig(f"loss_curve.png",dpi=120) + axs.plot(vector) + plt.savefig(save_path +'/'+ filename, dpi=120) + plt.close() -class VectorLog: - def __init__(self, save_path, var_name): - self.save_path = save_path - self.var_name = var_name+".pt" +def plot_mat(mat, mat_name, epoch): + ''' + deprecated, use HeamapLog instead + ''' + if mat.dim() == 3: + mat_list = [mat[idx_unit,:,:].squeeze().cpu() for idx_unit in range(mat.shape[0])] + else: + mat_list = [mat.cpu()] + fig, axs = plt.subplots(1, len(mat_list), figsize=(2*len(mat_list), 2)) + for idx_mat, mat in enumerate(mat_list): + if len(mat_list) == 1: + axs.imshow(mat, cmap='hot', interpolation='nearest') + else: + axs[idx_mat].imshow(mat, cmap='hot', interpolation='nearest') + fig.suptitle(mat_name.replace("_", " ")+ f' in epoch [{epoch}]') + plt.savefig(mat_name + f'_epoch_{epoch}.png', dpi=120) + plt.close() + +class HeatmapLog: + def __init__(self, folder_log, mat_name): + mat_name = mat_name.replace(' ','_') + make_dir(f"{folder_log}/"+mat_name) + self.save_folder = f"{folder_log}/"+mat_name + self.mat_name = mat_name + + def plot(self, mat, epoch=0): + if mat.dim() == 3: + mat_list = [mat[idx_unit,:,:].squeeze().cpu() for idx_unit in range(mat.shape[0])] + else: + mat_list = [mat.cpu()] + w_h_ratio = (mat_list[0].shape[1]/mat_list[0].shape[0]) + if w_h_ratio >= 1: + fig, axs = plt.subplots(1, len(mat_list), figsize=(0.75+2*len(mat_list)*w_h_ratio, 2)) + cbar_ax = fig.add_axes([0.90, 0.15, 0.03, 0.7]) # left bottom width height + else: + fig, axs = plt.subplots(1, len(mat_list), figsize=(2*len(mat_list), 0.75+2/(w_h_ratio))) + cbar_ax = fig.add_axes([0.10, 0.85, 0.8, 0.03]) # left bottom width height + cbar_ax.orientation = 'horizontal' + for idx_mat, mat in enumerate(mat_list): + if len(mat_list) == 1: + im=axs.imshow(mat, cmap='Greys', interpolation='nearest') + else: + im=axs[idx_mat].imshow(mat, cmap='Greys', interpolation='nearest') + + fig.colorbar(im, cax=cbar_ax) + fig.suptitle(self.mat_name.replace("_", " ")+ f' in epoch [{epoch}]') + plt.savefig(self.save_folder + '/' + self.mat_name + f'_epoch_{epoch}.png', dpi=120) + plt.close() + + +class ScalarLog: + def __init__(self, folder_log, var_name, epoch=None): + self.save_folder = f"{folder_log}/"+var_name + self.var_name = var_name + make_dir(self.save_folder) + self.var = [] + self.epoch = epoch + + def reset(self): self.var = [] def append(self, value): @@ -38,7 +96,37 @@ def append(self, value): def save(self): var_tensor = torch.tensor(self.var) - torch.save(var_tensor, self.save_path+"/"+self.var_name) + if self.epoch is None: + torch.save(var_tensor, self.save_folder +'/'+ self.var_name + '.pt') + else: + torch.save(var_tensor, self.save_folder +'/'+ self.var_name + f'_epoch_{self.epoch}.pt') + +class VectorLog: + def __init__(self, folder_log, var_name, epoch=None): + ''' + log a vector in a list. vector is supposed to be 1-dim + ''' + self.save_folder = f"{folder_log}/"+var_name + self.var_name = var_name + make_dir(self.save_folder) + self.var_stack = None + self.epoch = epoch + + def reset(self): + self.var_stack = None + + def append(self, vector): + if self.var_stack is None: + self.var_stack = vector.detach().unsqueeze(0) + else: + self.var_stack = torch.cat((self.var_stack, vector.detach().unsqueeze(0)), 0) + + def save(self): + if self.epoch is None: + torch.save(self.var_stack, self.save_folder +'/'+ self.var_name + '.pt') + else: + torch.save(self.var_stack, self.save_folder +'/'+ self.var_name + f'_epoch_{self.epoch}.pt') + def main(): data = torch.rand((64,51,1,64,64)) diff --git a/visualize_logtensor.py b/visualize_logtensor.py new file mode 100644 index 0000000..31014c8 --- /dev/null +++ b/visualize_logtensor.py @@ -0,0 +1,80 @@ +''' +this script is intended to plot logged tensors. +entries in the logged tensor should be like: + 1-dim: tensor(idx) == scalar e.g. loss(scalar) vs. epoch(idx) + 2-dim: tensor(idx) == vector e.g. attn_scores(vector) vs. frame +''' +import torch +import matplotlib.pyplot as plt +from utils.visualize import plot_curve, plot_mat, HeatmapLog +from utils.util import make_dir +import argparse + +def arg_parser(): + parser = argparse.ArgumentParser(description='Visualize Logged Tensor') + parser.add_argument('--folder_log', type=str) + args = parser.parse_args() + + return args + +class TensorVisualizer: + def __init__(self, folder_log, tensor_name): + self.folder_log = folder_log + self.tensor_name = tensor_name.replace(' ','_') + self.save_folder = folder_log+'/'+self.tensor_name + self.filename = folder_log+'/'+self.tensor_name+'/'+self.tensor_name + self.tensor = None + + def log_tensor(self, epoch=None): + if epoch is not None: + self.tensor = torch.load(self.filename+f'_epoch_{epoch}.pt') + else: + self.tensor = torch.load(self.filename+f'.pt') + return 0 + + def plot_logged_tensor(self, epoch=None): # NOTE epoch for curve ????/ + save_path = self.save_folder + make_dir(save_path) + if epoch is not None: + figname = self.tensor_name+f'_epoch_{epoch}.png' + else: + figname = self.tensor_name+f'.png' + if self.tensor.dim() == 1: + plot_curve(self.tensor, save_path, figname) + elif self.tensor.dim() == 2: + # TODO plot a matrix as a Heatmap or? + mat_log = HeatmapLog(save_path, 'figures') + mat_log.plot(self.tensor, epoch) + else: + raise ValueError('tensor.dim should either be 1 or 2!') + return 0 + + def __call__(self, epoch=None): + self.log_tensor(epoch) + self.plot_logged_tensor(epoch) + return 0 + + +def main(): + pass + # TODO parse a folder_log + args = arg_parser() + # TODO load the tensor + loss_plot = TensorVisualizer(args.folder_log, "epoch_loss") + gradnorm_plot = TensorVisualizer(args.folder_log, 'grad_norm') + encoded_plot = TensorVisualizer(args.folder_log, "encoded") + attn_plot = TensorVisualizer(args.folder_log, "attn_score") + hidden_plot = TensorVisualizer(args.folder_log, "hidden_state") + # testmat_plot = TensorVisualizer(args.folder_log, "test_mat") + # TODO plot all-epoch tensors + loss_plot() + + # TODO plot per-epoch tensors + for epoch_idx in range(10,110,10): + gradnorm_plot(epoch_idx) + encoded_plot(epoch_idx) + attn_plot(epoch_idx) + hidden_plot(epoch_idx) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/visualize_mmnist.py b/visualize_mmnist.py index 9c7054b..a341549 100644 --- a/visualize_mmnist.py +++ b/visualize_mmnist.py @@ -51,7 +51,7 @@ def test(model, test_loader, args): prediction = torch.zeros_like(data) for frame in range(data.shape[1]-1): - output, hidden = model(data[:, frame, :, :, :], hidden) + output, hidden, *_ = model(data[:, frame, :, :, :], hidden) nan_hook(output) nan_hook(hidden) diff --git a/visualize_result.py b/visualize_result.py deleted file mode 100644 index 85028b0..0000000 --- a/visualize_result.py +++ /dev/null @@ -1,101 +0,0 @@ -"""Main entry point of the code""" -from __future__ import print_function - -import os -from time import time - -import matplotlib.pyplot as plt -import numpy as np -import torch - -from networks import BallModel -# from model_components import GruState -from argument_parser import argument_parser -from dataset import get_dataloaders -from logbook.logbook import LogBook -from utils.util import set_seed, make_dir -from utils.visualize import plot_frames, plot_curve -from box import Box - -import os -from os import listdir -from os.path import isfile, join - -set_seed(0) - -loss_fn = torch.nn.BCELoss() - -def repackage_hidden(ten_): - """Wraps hidden states in new Tensors, to detach them from their history.""" - if isinstance(ten_, torch.Tensor): - return ten_.detach() - else: - return tuple(repackage_hidden(v) for v in ten_) - - -def main(): - """Function to run the experiment""" - args = argument_parser() - args.id = f"SchemaBlocks_{args.hidden_size}_{args.num_units}"+\ - f"_{args.experiment_name}_{args.lr}_num_inp_heads_{args.num_input_heads}"+\ - f"_ver_{args.version}" - # name="SchemaBlocks_"$dim1"_"$block1"_"$topk1"_"$something"_"$lr"_inp_heads_"$inp_heads"_templates_"$templates"_enc_"$encoder"_ver_"$version"_com_"$comm"_Sharing" - print(args) - logbook = LogBook(config=args) - - if not args.should_resume: - # New Experiment - make_dir(f"{args.folder_log}/model") - logbook.write_message_logs(message=f"Saving args to {args.folder_log}/model/args") - torch.save({"args": vars(args)}, f"{args.folder_log}/model/args") - - use_cuda = torch.cuda.is_available() - args.device = torch.device("cuda" if use_cuda else "cpu") - # args.device = torch.device("cpu") - model = setup_model(args=args, logbook=logbook) - - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) - - args.directory = './data' # dataset directory - # args.directory = 'D:\Projecten\Recurrent-Independent-Mechanisms\data' # dataset directory, windows os - train_loader, test_loader, transfer_loader = get_dataloaders(args) - - train_batch_idx = 0 - - start_epoch = 1 - for epoch in range(start_epoch,start_epoch+1): - model.eval() - with torch.no_grad(): - data = next(iter(train_loader)) - hidden = model.init_hidden(data.shape[0]).to(args.device) # NOTE initialize per epoch or per batch [??] - data = data.to(args.device) - hidden = hidden.detach() - pred = torch.zeros_like(data) - for frame in range(49): - output, hidden = model(data[:, frame, :, :, :], hidden) - pred[:,frame+1,:,:,:] = output - - pred = pred[:,1:,:,:,:] - plot_frames(pred, data, 10, 20, 6) - -def setup_model(args, logbook): - """Method to setup the model""" - - model = BallModel(args) - if args.should_resume: - # Find the last checkpointed model and resume from that - model_dir = f"{args.folder_log}/model" - latest_model_idx = max( - [int(model_idx) for model_idx in listdir(model_dir) - if model_idx != "args"] - ) - args.path_to_load_model = f"{model_dir}/{latest_model_idx}" - args.checkpoint = {"epoch": latest_model_idx} - else: - assert False, 'set args.should_resume true!' - - return model - - -if __name__ == '__main__': - main()