sequence accuracy for k-grammar task #7
base: master
check_correct.py
@@ -0,0 +1,35 @@
def correct(src, target):
    '''
    Return True if the target is a valid output for the given source.
    Args:
        src (str)
        target (str)
    Formats should match those in the data files.
    '''
    grammar_vocab = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'T',
                     'U', 'V', 'W', 'X', 'Y', 'Z', 'AS', 'BS', 'CS', 'DS', 'ES', 'FS', 'GS', 'HS', 'IS', 'JS',
                     'KS', 'LS', 'MS', 'NS', 'OS']
    inpt = src.split()
    pred = target.split()
    all_correct = False
    # Check that the prediction has exactly three output tokens per input token
    length_check = len(pred) == 3 * len(inpt)
    if not length_check:
        return False
    # Check that every three-token span carries the right bucket index (the 1-based
    # position of the corresponding input token in grammar_vocab) and contains each
    # of A, B and C, i.e. no repeats within the span.
    for idx, inp in enumerate(inpt):
        vocab_idx = grammar_vocab.index(inp) + 1
        span = pred[idx * 3:idx * 3 + 3]
        span_str = " ".join(span)
        if (not all(int(item.replace("A", "").replace("B", "").replace("C", "").split("_")[0]) == vocab_idx for item in span)
                or not ("A" in span_str and "B" in span_str and "C" in span_str)):
            all_correct = False
            break
        else:
            all_correct = True
    return all_correct


#print(correct('L Z Z Z Z', 'A12_2 B12_1 C12_1 B25_1 C25_2 A25_2 C25_2 B25_2 A25_2 C25_2 A25_2 B25_1 C25_1 B25_2 A25_2'))
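For reference, the commented-out call above should evaluate to True under this checker: 'L' is the 12th item of grammar_vocab and 'Z' the 25th, and each input token is followed by a three-token span containing A, B and C once each with the matching index. A small sanity check along those lines (the first call reproduces the commented-out example; the second is an invented failing case):

```python
from check_correct import correct

# The example from the commented-out line above: every input token maps to a
# three-token span with the right bucket index and one each of A, B and C.
print(correct('L Z Z Z Z',
              'A12_2 B12_1 C12_1 B25_1 C25_2 A25_2 C25_2 B25_2 A25_2 '
              'C25_2 A25_2 B25_1 C25_1 B25_2 A25_2'))  # True

# Hypothetical failing case: only two output tokens for a single input token.
print(correct('L', 'A12_1 B12_2'))  # False
```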
@@ -0,0 +1,73 @@
import os
import argparse
import logging
import pandas as pd
import torch

from check_correct import correct

import seq2seq
from seq2seq.evaluator import Predictor
from seq2seq.util.checkpoint import Checkpoint

try:
    raw_input          # Python 2
except NameError:
    raw_input = input  # Python 3


parser = argparse.ArgumentParser()

parser.add_argument('--checkpoint_path', help='Give the checkpoint path from which to load the model')
parser.add_argument('--test_folder', help='Give the path to the folder containing test files')
parser.add_argument('--cuda_device', default=0, type=int, help='set cuda device to use')

opt = parser.parse_args()

test_files = os.listdir(opt.test_folder)

if torch.cuda.is_available():
    print("Cuda device set to %i" % opt.cuda_device)
    torch.cuda.set_device(opt.cuda_device)

#################################################################################
# load model

logging.info("loading checkpoint from {}".format(os.path.join(opt.checkpoint_path)))
checkpoint = Checkpoint.load(opt.checkpoint_path)
seq2seq = checkpoint.model
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

#################################################################################
# Generate predictor
predictor = Predictor(seq2seq, input_vocab, output_vocab)
#seq_acc = {}
for tf in test_files:
    data_arr = pd.read_csv(os.path.join(opt.test_folder, tf), delimiter='\t', header=None).values
    count1 = 0
    count2 = 0
    for i in range(data_arr.shape[0]):
        src = data_arr[i, 0].strip().split()
        tgt = predictor.predict(src)
        tgt1 = tgt[:tgt.index('<eos>')]   # prediction truncated at the first <eos>
        flag2 = False
        if len(tgt) >= 3 * len(src):
            tgt2 = tgt[:3 * len(src)]     # prediction truncated to the expected length
            flag2 = correct(data_arr[i, 0], ' '.join(map(str, tgt2)))
        flag1 = correct(data_arr[i, 0], ' '.join(map(str, tgt1)))
        if flag1:
            count1 += 1
        if flag2:
            count2 += 1
    sa1 = count1 / data_arr.shape[0]   # accuracy when truncating at <eos>
    sa2 = count2 / data_arr.shape[0]   # accuracy when truncating to 3 * len(src) tokens
    print("Sequence Accuracy for {}: without <eos> = {} and with <eos> = {}".format(tf, sa1, sa2))
    #seq_acc[tf] = (sa1, sa2)


# while True:
#     seq_str = raw_input("Type in a source sequence:")
#     seq = seq_str.strip().split()
#     print(predictor.predict(seq))
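The script would presumably be run as `python <script>.py --checkpoint_path <checkpoint> --test_folder <folder>`, where each file in the test folder is tab-separated with no header and the source sequence in the first column. A minimal sketch of writing one such file; the file name and the example row are hypothetical:

```python
import pandas as pd

# Hypothetical test file: one tab-separated source/target row in the k-grammar
# format accepted by check_correct (name and contents are illustrative only).
rows = [('L Z', 'A12_2 B12_1 C12_1 B25_1 C25_2 A25_2')]
pd.DataFrame(rows).to_csv('k_grammar_example.tsv', sep='\t', header=False, index=False)
```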
@@ -0,0 +1,219 @@
import numpy as np
import random
import argparse
import string

mfolder = '../CommaiMini-^$'
try:
    raw_input          # Python 2
except NameError:
    raw_input = input  # Python 3

parser = argparse.ArgumentParser()
parser.add_argument('--max_train_com', type=int, help='max length of compositions in train', default=4)
parser.add_argument('--max_test_com', type=int, help='max length of compositions in test', default=7)

opt = parser.parse_args()

# Vocabulary: all printable characters except the six trailing whitespace characters.
alphabets = [string.printable[i] for i in range(len(string.printable) - 6)]
# alphabets = [chr(i) for i in range(ord('A'), ord('Z') + 1)]
# vocab1 = list(itertools.product(alphabets, repeat=2))
# vocab2 = list(itertools.product(alphabets, repeat=3))
# vocab = [''.join(map(str, v)) for v in vocab1]
# vocab.extend([''.join(map(str, v)) for v in vocab2])
# alphabets.extend(vocab)
# Split the vocabulary into two random halves; test splits combine words across the halves.
pidx = int(len(alphabets) / 2)
random.shuffle(alphabets)
subset1 = alphabets[:pidx]
subset2 = alphabets[pidx:]
operators = ['and', 'or', 'not']
def and_gate(ps, token):
    # Input "<token> p1 and p2 and ... pn"; output is all of ps in a random order.
    temp_ipt = [token]
    for i, s in enumerate(ps):
        temp_ipt.append(s)
        if i != len(ps) - 1:
            temp_ipt.append('and')
    random.shuffle(ps)
    temp_attn = []
    for s in ps:
        temp_attn.append(temp_ipt.index(s))
    #ps.append('<eos>')
    return (ps, temp_ipt, temp_attn)
def or_gate(ps, token):
    # Input "<token> p1 or p2 or ... pn"; output is a random non-empty subset of ps.
    temp_ipt = [token]
    for i, s in enumerate(ps):
        temp_ipt.append(s)
        if i != len(ps) - 1:
            temp_ipt.append('or')
    size = random.sample(np.arange(1, len(ps) + 1, dtype=int).tolist(), 1)[0]
    out_str = np.random.choice(ps, size=size, replace=False).tolist()
    temp_attn = []
    for s in out_str:
        temp_attn.append(temp_ipt.index(s))
    #out_str.append('<eos>')
    return (out_str, temp_ipt, temp_attn)
def not_gate(ps, token):
    # Input "<token> [not] p1 and [not] p2 and ...": negated symbols are replaced in the
    # output by a randomly chosen different symbol, the others are copied through.
    temp_ipt = [token]
    temp_opt = []
    cvocab = []
    num_nots = random.sample(np.arange(1, len(ps) + 1, dtype=int).tolist(), 1)[0]
    not_pfx = random.sample(ps, num_nots)
    for pf in not_pfx:
        temp_alpha = list(set(alphabets) - set([pf]))
        cvocab.append(temp_alpha)
    for i, s in enumerate(ps):
        if s in not_pfx:
            temp_ipt.append('not')
            temp_ipt.append(s)
            if i != len(ps) - 1:
                temp_ipt.append('and')
            temp_opt.append(random.sample(cvocab[not_pfx.index(s)], 1)[0])
        else:
            temp_ipt.append(s)
            if i != len(ps) - 1:
                temp_ipt.append('and')
            temp_opt.append(s)

    temp_attn = []
    for s in ps:
        temp_attn.append(temp_ipt.index(s))
    #temp_opt.append('<eos>')
    return (temp_opt, temp_ipt, temp_attn)
def io_strings(word, all_words, comp_len, token):
    # Build one (input, output, attention) triple per operator, each composing `word`
    # with companions drawn from all_words at a composition length drawn from comp_len.
    ipt = []
    out = []
    attn = []
    comps = np.random.choice(comp_len, size=len(operators))
    operations = np.random.choice(operators, size=len(operators), replace=False).tolist()
    random.shuffle(operations)
    for i in range(len(comps)):
        ps = [word]
        ps.extend(np.random.choice(all_words, size=comps[i] - 1).tolist())
        if operations[i] == 'and':
            str_tup = and_gate(ps, token)
        elif operations[i] == 'or':
            str_tup = or_gate(ps, token)
        else:
            str_tup = not_gate(ps, token)
        out.append(' '.join(map(str, str_tup[0])))
        ipt.append(' '.join(map(str, str_tup[1])))
        attn.append(' '.join(map(str, str_tup[2])))
    return (ipt, out, attn)
def train(words, size):
    # Training split: compositions of length 2..max_train_com within a single word subset.
    comp_lens = np.arange(2, opt.max_train_com + 1, dtype=int).tolist()
    data = np.zeros((size, 2), dtype=object)
    idx = 0
    try:
        while idx < data.shape[0]:
            random.shuffle(words)
            for w in words:
                tup = io_strings(w, words, comp_lens, 'produce')
                data[idx:idx + len(tup[0]), 0] = tup[0]
                data[idx:idx + len(tup[0]), 1] = tup[1]
                #data[idx:idx + len(tup[0]), 2] = tup[2]
                idx += len(tup[0])
                if idx > data.shape[0] - len(operators):
                    raise StopIteration()
    except StopIteration:
        pass

    return data
def unseen(words1, words2, size):
    # Unseen-composition test split: words from one half composed with words from the
    # other half, at training composition lengths.
    comp_lens = np.arange(2, opt.max_train_com + 1, dtype=int).tolist()
    data = np.zeros((size, 2), dtype=object)
    idx = 0
    try:
        while idx < data.shape[0]:
            random.shuffle(words1)
            for w in words1:
                tup = io_strings(w, words2, comp_lens, 'produce')
                data[idx:idx + len(tup[0]), 0] = tup[0]
                data[idx:idx + len(tup[0]), 1] = tup[1]
                #data[idx:idx + len(tup[0]), 2] = tup[2]
                idx += len(tup[0])
                if idx > data.shape[0] - len(operators):
                    raise StopIteration()

            random.shuffle(words2)
            for w in words2:
                tup = io_strings(w, words1, comp_lens, 'produce')
                data[idx:idx + len(tup[0]), 0] = tup[0]
                data[idx:idx + len(tup[0]), 1] = tup[1]
                #data[idx:idx + len(tup[0]), 2] = tup[2]
                idx += len(tup[0])
                if idx > data.shape[0] - len(operators):
                    raise StopIteration()
    except StopIteration:
        pass

    return data
def longer(words, size):
    # Longer-composition test split: lengths from max_train_com+1 up to max_test_com.
    comp_lens = np.arange(opt.max_train_com + 1, opt.max_test_com + 1, dtype=int).tolist()
    data = np.zeros((size, 2), dtype=object)
    idx = 0
    try:
        while idx < data.shape[0]:
            random.shuffle(words)
            for w in words:
                tup = io_strings(w, words, comp_lens, 'produce')
                data[idx:idx + len(tup[0]), 0] = tup[0]
                data[idx:idx + len(tup[0]), 1] = tup[1]
                #data[idx:idx + len(tup[0]), 2] = tup[2]
                idx += len(tup[0])
                if idx > data.shape[0] - len(operators):
                    raise StopIteration()
    except StopIteration:
        pass
    return data
def long_unseen(words1, words2, size):
    # Longer, unseen-composition test split: cross-subset compositions at test lengths.
    comp_lens = np.arange(opt.max_train_com + 1, opt.max_test_com + 1, dtype=int).tolist()
    data = np.zeros((size, 2), dtype=object)
    idx = 0
    try:
        while idx < data.shape[0]:
            random.shuffle(words1)
            for w in words1:
                tup = io_strings(w, words2, comp_lens, 'produce')
                data[idx:idx + len(tup[0]), 0] = tup[0]
                data[idx:idx + len(tup[0]), 1] = tup[1]
                #data[idx:idx + len(tup[0]), 2] = tup[2]
                idx += len(tup[0])
                if idx > data.shape[0] - len(operators):
                    raise StopIteration()
            random.shuffle(words2)
            for w in words2:
                tup = io_strings(w, words1, comp_lens, 'produce')
                data[idx:idx + len(tup[0]), 0] = tup[0]
                data[idx:idx + len(tup[0]), 1] = tup[1]
                #data[idx:idx + len(tup[0]), 2] = tup[2]
                idx += len(tup[0])
                if idx > data.shape[0] - len(operators):
                    raise StopIteration()
    except StopIteration:
        pass

    return data
def get_data(num_samples):
    tr1 = train(subset1, int(num_samples / 2))
    tr2 = train(subset2, int(num_samples / 2))
    train_data = np.vstack((tr1, tr2))
    unseen_test = unseen(subset1, subset2, num_samples)
    lg1 = longer(subset1, int(num_samples / 2))
    lg2 = longer(subset2, int(num_samples / 2))
    longer_test = np.vstack((lg1, lg2))
    unseen_long_test = long_unseen(subset1, subset2, num_samples)

    return (train_data, unseen_long_test, unseen_test, longer_test)
Review comment: As this script is in a folder merely called "scripts", it should have a more descriptive name; check_correct for what?