120,000 changes: 120,000 additions & 0 deletions CommaiMini-^$/Long/Verify_Produce_longer.tsv

Large diffs are not rendered by default.

108,000 changes: 108,000 additions & 0 deletions CommaiMini-^$/Long/Verify_Produce_train.tsv

Large diffs are not rendered by default.

120,000 changes: 120,000 additions & 0 deletions CommaiMini-^$/Long/Verify_Produce_unseen.tsv

Large diffs are not rendered by default.

120,000 changes: 120,000 additions & 0 deletions CommaiMini-^$/Long/Verify_Produce_unseen_longer.tsv

Large diffs are not rendered by default.

12,000 changes: 12,000 additions & 0 deletions CommaiMini-^$/Long/Verify_Produce_validation.tsv

Large diffs are not rendered by default.

12,000 changes: 12,000 additions & 0 deletions CommaiMini-^$/Short/Verify_Produce_longer.tsv

Large diffs are not rendered by default.

10,800 changes: 10,800 additions & 0 deletions CommaiMini-^$/Short/Verify_Produce_train.tsv

Large diffs are not rendered by default.

12,000 changes: 12,000 additions & 0 deletions CommaiMini-^$/Short/Verify_Produce_unseen.tsv

Large diffs are not rendered by default.

12,000 changes: 12,000 additions & 0 deletions CommaiMini-^$/Short/Verify_Produce_unseen_longer.tsv

Large diffs are not rendered by default.

1,200 changes: 1,200 additions & 0 deletions CommaiMini-^$/Short/Verify_Produce_validation.tsv

Large diffs are not rendered by default.

Empty file added scripts/__init__.py
Empty file.
35 changes: 35 additions & 0 deletions scripts/check_correct.py
@@ -0,0 +1,35 @@
Review comment (Member): As this script is in a folder merely called "scripts", it should have a more descriptive name: check_correct for what?

def correct(src, target):
    '''
    Return True if the target is a valid output given the source.

    Args:
        src (str): source sequence
        target (str): predicted target sequence
    Formats should match those in the data files.
    '''
    grammar_vocab = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'T',
                     'U', 'V', 'W', 'X', 'Y', 'Z', 'AS', 'BS', 'CS', 'DS', 'ES', 'FS', 'GS', 'HS', 'IS', 'JS',
                     'KS', 'LS', 'MS', 'NS', 'OS']
    inpt = src.split()
    pred = target.split()
    all_correct = False
    # Check whether the length is correct (computed but not currently enforced)
    length_check = len(pred) == 3 * len(inpt)
    # Check that every 3-symbol span falls in the right bucket (vocabulary index) and covers A, B and C with no repeats
    for idx, inp in enumerate(inpt):
        vocab_idx = grammar_vocab.index(inp) + 1
        span = pred[idx * 3:idx * 3 + 3]
        span_str = " ".join(span)
        if (not all(int(item.replace("A", "").replace("B", "").replace("C", "").split("_")[0]) == vocab_idx for item in span)
                or not ("A" in span_str and "B" in span_str and "C" in span_str)):
            all_correct = False
            break
        else:
            all_correct = True
    return all_correct

#print(correct('L Z Z Z Z', 'A12_2 B12_1 C12_1 B25_1 C25_2 A25_2 C25_2 B25_2 A25_2 C25_2 A25_2 B25_1 C25_1 B25_2 A25_2'))
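To make the expected target format concrete, here is a minimal usage sketch (not part of the PR; run from inside scripts/ or adjust the import path). Each source token must be answered by a three-symbol span whose numeric part equals the token's 1-based position in grammar_vocab and which contains one A-, one B- and one C-prefixed symbol:

from check_correct import correct

print(correct('A', 'A1_1 B1_2 C1_1'))                      # True: index 1 matches 'A' and the span covers A, B, C
print(correct('A L', 'A1_1 B1_2 C1_1 A12_2 B12_1 C12_1'))  # True: 'L' is the 12th vocabulary entry
print(correct('A', 'A2_1 B2_2 C2_1'))                      # False: index 2 does not correspond to 'A'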

73 changes: 73 additions & 0 deletions scripts/eval_kgrammar.py
@@ -0,0 +1,73 @@
import os
import argparse
import logging
import pandas as pd
import torch

from check_correct import correct

import seq2seq
from seq2seq.evaluator import Predictor
from seq2seq.util.checkpoint import Checkpoint

try:
    raw_input  # Python 2
except NameError:
    raw_input = input  # Python 3



parser = argparse.ArgumentParser()

parser.add_argument('--checkpoint_path', help='Give the checkpoint path from which to load the model')
parser.add_argument('--test_folder', help='Give the path to the folder containing test files')
parser.add_argument('--cuda_device', default=0, type=int, help='set cuda device to use')

opt = parser.parse_args()

test_files = os.listdir(opt.test_folder)

if torch.cuda.is_available():
    print("Cuda device set to %i" % opt.cuda_device)
    torch.cuda.set_device(opt.cuda_device)

#################################################################################
# load model

logging.info("loading checkpoint from {}".format(os.path.join(opt.checkpoint_path)))
checkpoint = Checkpoint.load(opt.checkpoint_path)
seq2seq = checkpoint.model
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

#################################################################################
# Generate predictor
predictor = Predictor(seq2seq, input_vocab, output_vocab)
#seq_acc = {}
for tf in test_files:
    data_arr = pd.read_csv(os.path.join(opt.test_folder, tf), delimiter='\t', header=None).values
    count1 = 0
    count2 = 0
    for i in range(data_arr.shape[0]):
        src = data_arr[i, 0].strip().split()
        tgt = predictor.predict(src)
        tgt1 = tgt[:tgt.index('<eos>')] if '<eos>' in tgt else tgt  # prediction truncated at <eos>
        flag2 = False
        if len(tgt) >= 3 * len(src):
            tgt2 = tgt[:3 * len(src)]  # prediction truncated to the expected output length (may still contain <eos>)
            flag2 = correct(data_arr[i, 0], ' '.join(map(str, tgt2)))  # with <eos>
        flag1 = correct(data_arr[i, 0], ' '.join(map(str, tgt1)))  # without <eos>
        if flag1:
            count1 += 1
        if flag2:
            count2 += 1
    sa1 = count1 / data_arr.shape[0]  # sequence accuracy without <eos>
    sa2 = count2 / data_arr.shape[0]  # sequence accuracy with <eos>
    print("Sequence Accuracy for {}: without <eos> = {} and with <eos> = {}".format(tf, sa1, sa2))
    #seq_acc[tf] = (sa1, sa2)


# while True:
#     seq_str = raw_input("Type in a source sequence:")
#     seq = seq_str.strip().split()
#     print(predictor.predict(seq))
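The evaluation script is driven entirely by the three argparse flags defined above; a hypothetical invocation (the checkpoint path is a placeholder, the test folder is one of the directories added in this PR) could look like:

python scripts/eval_kgrammar.py --checkpoint_path <path/to/checkpoint> --test_folder 'CommaiMini-^$/Short' --cuda_device 0

One sequence-accuracy line is printed per .tsv file found in the test folder.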
219 changes: 219 additions & 0 deletions scripts/produce.py
@@ -0,0 +1,219 @@
import numpy as np
import random
import argparse
import string

mfolder = '../CommaiMini-^$'
try:
    raw_input  # Python 2
except NameError:
    raw_input = input  # Python 3

parser = argparse.ArgumentParser()
parser.add_argument('--max_train_com', type=int, help= 'max length of compositions in train', default=4)
parser.add_argument('--max_test_com', type=int, help= 'max length of compositions in test', default=7)


opt = parser.parse_args()

alphabets = [string.printable[i] for i in range(len(string.printable)-6)]
# alphabets = [chr(i) for i in range(ord('A'),ord('Z')+1)]
# vocab1 = list(itertools.product(alphabets, repeat=2))
# vocab2 = list(itertools.product(alphabets, repeat=3))
# vocab = [''.join(map(str,v)) for v in vocab1]
# vocab.extend([''.join(map(str,v)) for v in vocab2])
# alphabets.extend(vocab)
pidx = int(len(alphabets)/2)
random.shuffle(alphabets)
subset1 = alphabets[:pidx]
subset2 = alphabets[pidx:]
operators = ['and', 'or', 'not']

def and_gate(ps, token):
    temp_ipt = [token]
    for s in ps:
        temp_ipt.append(s)
        if ps.index(s) != len(ps) - 1:
            temp_ipt.append('and')
    random.shuffle(ps)
    temp_attn = []
    for s in ps:
        temp_attn.append(temp_ipt.index(s))
    #ps.append('<eos>')
    return (ps, temp_ipt, temp_attn)

def or_gate(ps, token):
    temp_ipt = [token]
    for s in ps:
        temp_ipt.append(s)
        if ps.index(s) != len(ps) - 1:
            temp_ipt.append('or')
    size = random.sample(np.arange(1, len(ps) + 1, dtype=int).tolist(), 1)
    out_str = np.random.choice(ps, size=size, replace=False).tolist()
    temp_attn = []
    for s in out_str:
        temp_attn.append(temp_ipt.index(s))
    #out_str.append('<eos>')
    return (out_str, temp_ipt, temp_attn)

def not_gate(ps, token):
    temp_ipt = [token]
    temp_opt = []
    cvocab = []
    num_nots = random.sample(np.arange(1, len(ps) + 1, dtype=int).tolist(), 1)[0]
    not_pfx = random.sample(ps, num_nots)
    for pf in not_pfx:
        temp_alpha = list(set(alphabets) - set([pf]))
        cvocab.append(temp_alpha)
    for s in ps:
        if s in not_pfx:
            temp_ipt.append('not')
            temp_ipt.append(s)
            if ps.index(s) != len(ps) - 1:
                temp_ipt.append('and')
            temp_opt.append(random.sample(cvocab[not_pfx.index(s)], 1)[0])
        else:
            temp_ipt.append(s)
            if ps.index(s) != len(ps) - 1:
                temp_ipt.append('and')
            temp_opt.append(s)

    temp_attn = []
    for s in ps:
        temp_attn.append(temp_ipt.index(s))
    #temp_opt.append('<eos>')
    return (temp_opt, temp_ipt, temp_attn)
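To make the three gates concrete, here is one possible draw for ps = ['x', 'y'] and token = 'produce' (outputs are sampled, so these lines are illustrative only; the third element of each returned tuple lists, for every output symbol, its position in the generated input):

# and_gate(['x', 'y'], 'produce')
#   input:  produce x and y        output: a random permutation of x and y
# or_gate(['x', 'y'], 'produce')
#   input:  produce x or y         output: a non-empty random subset of x and y
# not_gate(['x', 'y'], 'produce')   (a draw in which only x is negated)
#   input:  produce not x and y    output: a random symbol other than x, followed by y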



def io_strings(word, all_words, comp_len, token):
    ipt = []
    out = []
    attn = []
    comps = np.random.choice(comp_len, size=len(operators))
    operations = np.random.choice(operators, size=len(operators), replace=False).tolist()
    random.shuffle(operations)
    for i in range(len(comps)):
        ps = [word]
        ps.extend(np.random.choice(all_words, size=comps[i] - 1).tolist())
        if operations[i] == 'and':
            str_tup = and_gate(ps, token)
        elif operations[i] == 'or':
            str_tup = or_gate(ps, token)
        else:
            str_tup = not_gate(ps, token)
        out.append(' '.join(map(str, str_tup[0])))
        ipt.append(' '.join(map(str, str_tup[1])))
        attn.append(' '.join(map(str, str_tup[2])))
    return (ipt, out, attn)

def train(words, size):
    comp_lens = np.arange(2, opt.max_train_com + 1, dtype=int).tolist()
    data = np.zeros((size, 2), dtype=object)
    idx = 0
    try:
        while idx < data.shape[0]:
            random.shuffle(words)
            for w in words:
                tup = io_strings(w, words, comp_lens, 'produce')
                data[idx:idx + len(tup[0]), 0] = tup[0]
                data[idx:idx + len(tup[0]), 1] = tup[1]
                #data[idx:idx + len(tup[0]), 2] = tup[2]
                idx += len(tup[0])
                if idx > data.shape[0] - len(operators):
                    raise StopIteration()
    except StopIteration:
        pass

    return data

def unseen(words1, words2, size):
    comp_lens = np.arange(2, opt.max_train_com + 1, dtype=int).tolist()
    data = np.zeros((size, 2), dtype=object)
    idx = 0
    try:
        while idx < data.shape[0]:
            random.shuffle(words1)
            for w in words1:
                tup = io_strings(w, words2, comp_lens, 'produce')
                data[idx:idx + len(tup[0]), 0] = tup[0]
                data[idx:idx + len(tup[0]), 1] = tup[1]
                #data[idx:idx + len(tup[0]), 2] = tup[2]
                idx += len(tup[0])
                if idx > data.shape[0] - len(operators):
                    raise StopIteration()

            random.shuffle(words2)
            for w in words2:
                tup = io_strings(w, words1, comp_lens, 'produce')
                data[idx:idx + len(tup[0]), 0] = tup[0]
                data[idx:idx + len(tup[0]), 1] = tup[1]
                #data[idx:idx + len(tup[0]), 2] = tup[2]
                idx += len(tup[0])
                if idx > data.shape[0] - len(operators):
                    raise StopIteration()
    except StopIteration:
        pass

    return data

def longer(words, size):
    comp_lens = np.arange(opt.max_train_com + 1, opt.max_test_com + 1, dtype=int).tolist()
    data = np.zeros((size, 2), dtype=object)
    idx = 0
    try:
        while idx < data.shape[0]:
            random.shuffle(words)
            for w in words:
                tup = io_strings(w, words, comp_lens, 'produce')
                data[idx:idx + len(tup[0]), 0] = tup[0]
                data[idx:idx + len(tup[0]), 1] = tup[1]
                #data[idx:idx + len(tup[0]), 2] = tup[2]
                idx += len(tup[0])
                if idx > data.shape[0] - len(operators):
                    raise StopIteration()
    except StopIteration:
        pass
    return data

def long_unseen(words1, words2, size):
    comp_lens = np.arange(opt.max_train_com + 1, opt.max_test_com + 1, dtype=int).tolist()
    data = np.zeros((size, 2), dtype=object)
    idx = 0
    try:
        while idx < data.shape[0]:
            random.shuffle(words1)
            for w in words1:
                tup = io_strings(w, words2, comp_lens, 'produce')
                data[idx:idx + len(tup[0]), 0] = tup[0]
                data[idx:idx + len(tup[0]), 1] = tup[1]
                #data[idx:idx + len(tup[0]), 2] = tup[2]
                idx += len(tup[0])
                if idx > data.shape[0] - len(operators):
                    raise StopIteration()
            random.shuffle(words2)
            for w in words2:
                tup = io_strings(w, words1, comp_lens, 'produce')
                data[idx:idx + len(tup[0]), 0] = tup[0]
                data[idx:idx + len(tup[0]), 1] = tup[1]
                #data[idx:idx + len(tup[0]), 2] = tup[2]
                idx += len(tup[0])
                if idx > data.shape[0] - len(operators):
                    raise StopIteration()
    except StopIteration:
        pass

    return data

def get_data(num_samples):
    tr1 = train(subset1, int(num_samples / 2))
    tr2 = train(subset2, int(num_samples / 2))
    train_data = np.vstack((tr1, tr2))
    unseen_test = unseen(subset1, subset2, num_samples)
    lg1 = longer(subset1, int(num_samples / 2))
    lg2 = longer(subset2, int(num_samples / 2))
    longer_test = np.vstack((lg1, lg2))
    unseen_long_test = long_unseen(subset1, subset2, num_samples)

    return (train_data, unseen_long_test, unseen_test, longer_test)
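The visible part of the diff ends with get_data, so the code that writes the Verify_Produce_*.tsv splits is not shown here; a minimal hypothetical driver (sample count and target folder are placeholders, not taken from the PR) could look like:

import os

train_data, unseen_long_test, unseen_test, longer_test = get_data(12000)
out_dir = os.path.join(mfolder, 'Short')  # mfolder is defined at the top of the script
np.savetxt(os.path.join(out_dir, 'Verify_Produce_train.tsv'), train_data, fmt='%s', delimiter='\t')
np.savetxt(os.path.join(out_dir, 'Verify_Produce_unseen.tsv'), unseen_test, fmt='%s', delimiter='\t')
np.savetxt(os.path.join(out_dir, 'Verify_Produce_longer.tsv'), longer_test, fmt='%s', delimiter='\t')
np.savetxt(os.path.join(out_dir, 'Verify_Produce_unseen_longer.tsv'), unseen_long_test, fmt='%s', delimiter='\t')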
