3 changes: 3 additions & 0 deletions .gitignore
@@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

logs/
data/
170 changes: 88 additions & 82 deletions .vscode/launch.json
@@ -1,7 +1,80 @@
{
"configurations": [
{
"name": "train bball complete (cuda)",
"name": "visualize logged tensor",
"type": "python",
"request": "launch",
"program": "./visualize_logtensor.py",
"console": "integratedTerminal",
"args": [
"--folder_log",
"logs/SchemaBlocks_10_3_MMNIST_mini_0.0007_num_inp_heads_1_ver_0/intermediate_vars"
]
},
{
"name": "minimal: only for debugging code",
"type": "python",
"request": "launch",
"program": "./train_mmnist.py",
"console": "integratedTerminal",
"args": [
// "--mini",
"--train_dataset",
"balls3curtain64.h5",
"--transfer_dataset",
"balls3curtain64.h5",
"--hidden_size",
"5",
"--should_save_csv",
"False",
"--lr",
"0.0007",
"--num_units",
"5",
"--k",
"1",
"--num_input_heads",
"1",
"--version",
"0",
"--rnn_cell",
"GRU",
"--input_key_size",
"3",
"--input_value_size",
"4",
"--input_query_size",
"3", //actually not used. altijd == input_key_size
"--input_dropout",
"0.1",
"--comm_key_size",
"3",
"--comm_value_size",
"4", //actually not used. altijd == hidden_size because output a hidden_state
"--comm_query_size",
"3",
"--num_comm_heads",
"2",
"--comm_dropout",
"0.1",
"--experiment_name",
"MMnist",
"--batch_size",
"16",
"--epochs",
"10",
"--should_resume",
"False",
"--batch_frequency_to_log_heatmaps",
"1",
"--model_persist_frequency",
"1",
"--log_intm_frequency",
"1"
]
},
{
"name": "train mmnist complete (cuda)",
"type": "python",
"request": "launch",
"program": "${file}",
@@ -28,7 +101,7 @@
"--input_key_size",
"64",
"--input_value_size",
"2040",
"400",
"--input_query_size",
"64",
"--input_dropout",
@@ -44,17 +117,19 @@
"--comm_dropout",
"0.1",
"--experiment_name",
"Curtain",
"MMNIST_complete",
"--batch_size",
"64",
"--epochs",
"50",
"200",
"--should_resume",
"False",
"--batch_frequency_to_log_heatmaps",
"10",
"--model_persist_frequency",
"1"
"5",
"--log_intm_frequency",
"10"
]
},
{
@@ -88,7 +163,7 @@
"--input_key_size",
"10",
"--input_value_size",
"10",
"40",
"--input_query_size",
"10",
"--input_dropout",
@@ -104,88 +179,19 @@
"--comm_dropout",
"0.1",
"--experiment_name",
"Curtain",
"MMNIST_mini",
"--batch_size",
"16",
"64",
"--epochs",
"10",
"100",
"--should_resume",
"True",
"False",
"--batch_frequency_to_log_heatmaps",
"1",
"--model_persist_frequency",
"1"
]
},
{
"name": "main: RIM (cuda)",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"args": [
"--cuda",
"True",
"--epochs",
"6",
"--batch_size",
"64",
"--input_size",
"1",
"--hidden_size",
"600",
"--size",
"14",
"--loadsaved",
"0",
"--model",
"RIM"
]
},
{
"name": "main: LSTM",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"args": [
"--cuda",
"False",
"--epochs",
"1",
"--batch_size",
"64",
"--input_size",
"1",
"--size",
"14",
"--loadsaved",
"0",
"--model",
"LSTM"
]
},
{
"name": "main: RIM",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"args": [
"--cuda",
"False",
"--epochs",
"7",
"--batch_size",
"64",
"--input_size",
"14",
"--size",
"14",
"--loadsaved",
"0",
"--model",
"RIM"
"5",
"--log_intm_frequency",
"10"
]
}
]
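
For orientation, the "minimal: only for debugging code" entry above deliberately uses tiny attention dimensions. The sketch below is not part of the PR; it just spells out the projection-matrix shapes those numbers imply, following the shape conventions documented in the RIM.py docstrings further down in this diff. input_size depends on the encoder and is left symbolic here.

# Values taken from the debug launch entry above.
num_units, hidden_size = 5, 5
input_key_size, input_value_size = 3, 4
comm_key_size, num_comm_heads = 3, 2

# Input attention (see the docstring added to input_attention_mask):
print("input key_mat  :", ("input_size", input_key_size))            # input -> key
print("input value_mat:", ("input_size", input_value_size))          # input -> value
print("input query_mat:", (num_units, hidden_size, input_key_size))  # hidden -> query

# Communication attention (see the docstring added to communication_attention):
print("comm key_mat  :", (num_units, hidden_size, comm_key_size * num_comm_heads))  # (5, 5, 6)
print("comm value_mat:", (num_units, hidden_size, hidden_size * num_comm_heads))    # (5, 5, 10)
print("comm query_mat:", (num_units, hidden_size, comm_key_size * num_comm_heads))  # (5, 5, 6)
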
51 changes: 43 additions & 8 deletions RIM.py
@@ -211,6 +211,12 @@ def input_attention_mask(self, x, h):
attention_scores = torch.mean(attention_scores, dim = 1)
mask_ = torch.zeros(x.size(0), self.num_units).to(self.device)

'''
attention_scores: (batch_size, num_heads, num_units, 2)
--> mean over dim 1 (num_heads)
--> (batch_size, num_units, 2)
'''

not_null_scores = attention_scores[:,:, 0]
topk1 = torch.topk(not_null_scores,self.k, dim = 1)
row_index = np.arange(x.size(0))
@@ -222,7 +228,17 @@ def input_attention_mask(self, x, h):
attention_probs = self.input_dropout(nn.Softmax(dim = -1)(attention_scores))
inputs = torch.matmul(attention_probs, value_layer) * mask_.unsqueeze(2)

return inputs, mask_
"""
parameters:
key_mat: input_size x key_size -> input 2 key
value_mat: input_size x val_size -> input 2 value
query_mat: num_units x hidden_size x key_size -> hidden 2 query
"""
key_mat = get_mat(self.key).transpose(0,1).detach() # TODO detach?
value_mat = get_mat(self.value).transpose(0,1).detach()
query_mat = self.query.w.detach()

return inputs, mask_, key_mat, value_mat, query_mat, attention_scores

def communication_attention(self, h, mask):
"""
@@ -249,6 +265,10 @@ def communication_attention(self, h, mask):
attention_scores = attention_scores / math.sqrt(self.comm_key_size)
self.inf_hook(attention_scores)
attention_probs = nn.Softmax(dim=-1)(attention_scores)

"""
attention_scores: (batch_size, num_heads, num_units, num_units)
"""

mask = [mask for _ in range(attention_probs.size(1))]
mask = torch.stack(mask, dim = 1)
@@ -263,8 +283,18 @@ def communication_attention(self, h, mask):
context_layer = context_layer.view(*new_context_layer_shape)
context_layer = self.comm_attention_output(context_layer)
context_layer = context_layer + h

return context_layer

"""
parameters:
key_mat: num_units x hidden_size x (key_size*num_heads) -> hidden 2 key
value_mat: num_units x hidden_size x (val_size*num_heads==hidden_size*num_heads) -> hidden 2 (value==new hidden_state)
query_mat: num_units x hidden_size x (key_size*num_heads) -> hidden 2 query
"""
key_mat = self.key_.w.detach()
value_mat = self.value_.w.detach()
query_mat = self.query_.w.detach()

return context_layer, key_mat, value_mat, query_mat, attention_scores

def nan_hook(self, out):
nan_mask = torch.isnan(out)
@@ -290,7 +320,7 @@ def forward(self, x, hs, cs = None):
x = torch.cat((x.unsqueeze(1), null_input), dim = 1)

# Compute input attention
inputs, mask = self.input_attention_mask(x, hs)
inputs, mask, *inp_ctx = self.input_attention_mask(x, hs)
h_old = hs * 1.0
if cs is not None:
c_old = cs * 1.0
@@ -309,15 +339,15 @@ def forward(self, x, hs, cs = None):
h_new = blocked_grad.apply(hs, mask)

# Compute communication attention
h_new = self.communication_attention(h_new, mask.squeeze(2))
h_new, *comm_ctx = self.communication_attention(h_new, mask.squeeze(2))
self.nan_hook(h_new)

hs = mask * h_new + (1 - mask) * h_old
if cs is not None:
cs = mask * cs + (1 - mask) * c_old
return hs, cs, None
return hs, cs, None, inp_ctx, comm_ctx
self.nan_hook(hs)
return hs, None, None
return hs, None, None, inp_ctx, comm_ctx


class RIM(nn.Module):
@@ -625,4 +655,9 @@ def __init__(self, scale_factor, mode):

def forward(self, x):
x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False)
return x
return x

def get_mat(linear_layer):
'''Return the weight matrix of an nn.Linear layer (its first parameter, shape: out_features x in_features).'''
p_list = [p for p in linear_layer.parameters()]
return p_list[0]
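
Two points in the RIM.py changes above are worth spelling out: get_mat relies on an nn.Linear's first parameter being its weight of shape (out_features, in_features), which is why input_attention_mask transposes it to get an (input_size, key_size) matrix, and forward now returns two extra context tuples. The snippet below is a minimal sketch of both; the caller side lives in train_mmnist.py, which this diff does not show, so the logging lines are hypothetical and the variable names are assumptions.

import torch.nn as nn

# get_mat(layer) == first parameter of the layer == the nn.Linear weight,
# shape (out_features, in_features); transposing gives (in_features, out_features).
key_layer = nn.Linear(10, 3)               # e.g. input_size=10, input_key_size=3
weight = list(key_layer.parameters())[0]   # shape (3, 10)
assert weight.transpose(0, 1).shape == (10, 3)

# Hypothetical caller of the updated forward (not the PR's actual training loop):
# hs, cs, _, inp_ctx, comm_ctx = cell(x_t, hs, cs)
# inp_ctx:  [key_mat, value_mat, query_mat, attention_scores] from input attention
# comm_ctx: [key_mat, value_mat, query_mat, attention_scores] from communication attention
# if batch_idx % args.log_intm_frequency == 0:
#     torch.save({"inp": inp_ctx, "comm": comm_ctx},
#                f"{log_dir}/intermediate_vars/batch_{batch_idx}.pt")
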
3 changes: 3 additions & 0 deletions argument_parser.py
@@ -71,6 +71,9 @@ def argument_parser():
help='Number of training epochs after which model is to '
'be persisted. -1 means that the model is not'
'persisted')
parser.add_argument('--log_intm_frequency', type=int, default=10,
metavar="Frequency at which we log the intermediate variables of the model",
help='Must be a positive integer')
parser.add_argument('--batch_frequency_to_log_heatmaps', type=int, default=-1,
metavar='Frequency at which the heatmaps are persisted',
help='Number of training batches after which we will persit the '
13 changes: 10 additions & 3 deletions data/MovingMNIST.py
@@ -35,11 +35,12 @@ class MovingMNIST(data.Dataset):
training_file = 'moving_mnist_train.pt'
test_file = 'moving_mnist_test.pt'

def __init__(self, root, train=True, split=1000, transform=None, download=False):
def __init__(self, root, train=True, split=1000, transform=None, download=False, mini=False):
self.root = os.path.expanduser(root)
self.transform = transform
self.split = split
self.train = train # training set or test set
self.mini = mini

if download:
self.download()
@@ -85,9 +86,15 @@ def _transform_time(data):

def __len__(self):
if self.train:
return len(self.train_data)
if not self.mini:
return len(self.train_data)
else:
return min(len(self.train_data) // 50 , 100)
else:
return len(self.test_data)
if not self.mini:
return len(self.test_data)
else:
return min(len(self.test_data) // 50 , 100)

def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
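
The mini flag is the dataset-side hook for the "minimal" debug configuration: it caps the reported length so a debugging epoch touches only a small slice of Moving MNIST. A short usage sketch follows; the actual call site is in train_mmnist.py, which this diff does not show, so root and download are placeholder values.

from data.MovingMNIST import MovingMNIST

# With mini=True, __len__ returns min(len(split) // 50, 100) instead of the full
# length, so DataLoader iteration stops after at most 100 samples per epoch.
train_set = MovingMNIST(root="data", train=True, download=True, mini=True)
print(len(train_set))   # <= 100
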
11 changes: 6 additions & 5 deletions experiment_mmnist.sh
@@ -12,19 +12,20 @@ num_input_heads=1
version=0
rnn_cell="GRU"
input_key_size=64
input_value_size=2040
input_value_size=400
input_query_size=64
input_dropout=0.1
comm_key_size=32
comm_value_size=32
comm_query_size=32
num_comm_heads=4
comm_dropout=0.1
experiment_name="Curtain"
experiment_name="MMNIST_complete"
batch_size=64
epochs=100
epochs=200
should_resume="False"
batch_frequency_to_log_heatmaps=10
model_persist_frequency=1
model_persist_frequency=5
log_intm_frequency=10

nohup python3 train_mmnist.py --train_dataset $train_dataset --hidden_size $hidden_size --should_save_csv $should_save_csv --lr $lr --num_units $num_units --k $k --num_input_heads $num_input_heads --version $version --rnn_cell $rnn_cell --input_key_size $input_key_size --input_value_size $input_value_size --input_query_size $input_query_size --input_dropout $input_dropout --comm_key_size $comm_key_size --comm_value_size $comm_value_size --comm_query_size $comm_query_size --num_comm_heads $num_comm_heads --comm_dropout $comm_dropout --experiment_name $experiment_name --batch_size $batch_size --epochs $epochs --should_resume $should_resume --batch_frequency_to_log_heatmaps $batch_frequency_to_log_heatmaps --model_persist_frequency $model_persist_frequency
nohup python3 train_mmnist.py --train_dataset $train_dataset --hidden_size $hidden_size --should_save_csv $should_save_csv --lr $lr --num_units $num_units --k $k --num_input_heads $num_input_heads --version $version --rnn_cell $rnn_cell --input_key_size $input_key_size --input_value_size $input_value_size --input_query_size $input_query_size --input_dropout $input_dropout --comm_key_size $comm_key_size --comm_value_size $comm_value_size --comm_query_size $comm_query_size --num_comm_heads $num_comm_heads --comm_dropout $comm_dropout --experiment_name $experiment_name --batch_size $batch_size --epochs $epochs --should_resume $should_resume --batch_frequency_to_log_heatmaps $batch_frequency_to_log_heatmaps --model_persist_frequency $model_persist_frequency --log_intm_frequency $log_intm_frequency