From 6a75fa480960c8a91f4f69440c2377e9b88d82c0 Mon Sep 17 00:00:00 2001 From: faizan khan Date: Wed, 8 Apr 2020 20:57:36 -0400 Subject: [PATCH 1/9] WIP: self attention --- script/dfs_traversal_2.py | 3 - script/parent_node_pairs.py | 500 ++++++++++++++++++++++++++++++++++++ src/utils/codegen.py | 326 ++++++++++++++++------- src/utils/my_ast.py | 11 +- 4 files changed, 739 insertions(+), 101 deletions(-) create mode 100644 script/parent_node_pairs.py diff --git a/script/dfs_traversal_2.py b/script/dfs_traversal_2.py index 97a56e45..d250e102 100644 --- a/script/dfs_traversal_2.py +++ b/script/dfs_traversal_2.py @@ -1,7 +1,4 @@ -from yapf.yapflib import pytree_utils - # from src.dpu_utils.utils import RichPath -from src.utils.my_utils import DotDict from src.utils import my_ast from src.utils.codegen2 import * import json_lines diff --git a/script/parent_node_pairs.py b/script/parent_node_pairs.py new file mode 100644 index 00000000..2d1814cf --- /dev/null +++ b/script/parent_node_pairs.py @@ -0,0 +1,500 @@ + +from dpu_utils.utils import RichPath +from src.utils import my_ast +from src.utils.codegen import * +import subprocess +import pandas as pd +import os + +# path = 'resources/data/python/final/jsonl/valid_old/temp_train_10.jsonl.gz' +# # s_path = 'resources/data/python/final/jsonl/valid/temp_valid_10.jsonl.gz' +# +# a = RichPath.create(path) +# s = RichPath.create(s_path) +# +# print('started') +# b = list(a.read_as_jsonl()) + + +count = 0 +def convert_code_to_tokens(code): + global count + tree ='' + # tree = my_ast.parse(code) + + try: + tree = my_ast.parse(code) + except: + try: + f = open('temp.py', 'w+') + f.write(code) + f.close() + subprocess.run(['2to3', '-w', 'temp.py']) + f = open('temp.py', 'r') + code = f.read() + # print(code) + tree = my_ast.parse(code) + # os.rmdir('temp.py') + except: + pass + if tree!='': + an = SourceGenerator(' ') + an.visit(tree) + return an.result, an.parents + else: + return [] +# + +# templist = [] +# for idx, sample in enumerate(b): +# print("sample {} in progress".format(idx)) +# # print(sample['code']) +# if idx==3282: +# print(sample['code']) +# +# tokenization = convert_code_to_tokens(sample['code']) +# if tokenization == []: +# templist.append(idx) +# else: +# b[idx]['code_tokens'] = tokenization +# # tree = my_ast.parse(sample['code']) +# # an = SourceGenerator(' ') +# # an.visit(tree) +# # b[idx]['code_tokens'] = an.result +# +# s.save_as_compressed_file(b) +# print('finished', templist, len(templist), tokenization) + +import ast +import sys +import json +# def parse_file(code): +# global c, d +# tree = ast.parse(code) +# +# json_tree = [] +# +# def gen_identifier(identifier, node_type='identifier'): +# pos = len(json_tree) +# json_node = {} +# json_tree.append(json_node) +# json_node['type'] = node_type +# json_node['value'] = identifier +# return pos +# +# def traverse_list(l, node_type='list'): +# pos = len(json_tree) +# json_node = {} +# json_tree.append(json_node) +# json_node['type'] = node_type +# children = [] +# for item in l: +# children.append(traverse(item)) +# if (len(children) != 0): +# json_node['children'] = children +# return pos +# +# def traverse(node): +# pos = len(json_tree) +# json_node = {} +# json_tree.append(json_node) +# json_node['type'] = type(node).__name__ +# children = [] +# if isinstance(node, ast.Name): +# json_node['value'] = node.id +# elif isinstance(node, ast.Num): +# json_node['value'] = unicode(node.n) +# elif isinstance(node, ast.Str): +# json_node['value'] = node.s.decode('utf-8') +# elif 
isinstance(node, ast.alias): +# json_node['value'] = unicode(node.name) +# if node.asname: +# children.append(gen_identifier(node.asname)) +# elif isinstance(node, ast.FunctionDef): +# json_node['value'] = unicode(node.name) +# elif isinstance(node, ast.ClassDef): +# json_node['value'] = unicode(node.name) +# elif isinstance(node, ast.ImportFrom): +# if node.module: +# json_node['value'] = unicode(node.module) +# elif isinstance(node, ast.Global): +# for n in node.names: +# children.append(gen_identifier(n)) +# elif isinstance(node, ast.keyword): +# json_node['value'] = unicode(node.arg) +# +# # Process children. +# if isinstance(node, ast.For): +# children.append(traverse(node.target)) +# children.append(traverse(node.iter)) +# children.append(traverse_list(node.body, 'body')) +# if node.orelse: +# children.append(traverse_list(node.orelse, 'orelse')) +# elif isinstance(node, ast.If) or isinstance(node, ast.While): +# children.append(traverse(node.test)) +# children.append(traverse_list(node.body, 'body')) +# if node.orelse: +# children.append(traverse_list(node.orelse, 'orelse')) +# elif isinstance(node, ast.With): +# children.append(traverse(node.context_expr)) +# if node.optional_vars: +# children.append(traverse(node.optional_vars)) +# children.append(traverse_list(node.body, 'body')) +# elif isinstance(node, ast.Try): +# children.append(traverse_list(node.body, 'body')) +# children.append(traverse_list(node.handlers, 'handlers')) +# if node.orelse: +# children.append(traverse_list(node.orelse, 'orelse')) +# if node.finalbody: +# children.append(traverse_list(node.finalbody, 'finalbody')) +# elif isinstance(node, ast.arguments): +# children.append(traverse_list(node.args, 'args')) +# children.append(traverse_list(node.defaults, 'defaults')) +# if node.vararg: +# children.append(gen_identifier(node.vararg, 'vararg')) +# if node.kwarg: +# children.append(gen_identifier(node.kwarg, 'kwarg')) +# elif isinstance(node, ast.ExceptHandler): +# if node.type: +# children.append(traverse_list([node.type], 'type')) +# if node.name: +# children.append(traverse_list([node.name], 'name')) +# children.append(traverse_list(node.body, 'body')) +# elif isinstance(node, ast.ClassDef): +# children.append(traverse_list(node.bases, 'bases')) +# children.append(traverse_list(node.body, 'body')) +# children.append(traverse_list(node.decorator_list, 'decorator_list')) +# elif isinstance(node, ast.FunctionDef): +# children.append(traverse(node.args)) +# children.append(traverse_list(node.body, 'body')) +# children.append(traverse_list(node.decorator_list, 'decorator_list')) +# else: +# # Default handling: iterate over children. +# for child in ast.iter_child_nodes(node): +# if isinstance(child, ast.expr_context) or isinstance(child, ast.operator) or isinstance(child, +# ast.boolop) or isinstance( +# child, ast.unaryop) or isinstance(child, ast.cmpop): +# # Directly include expr_context, and operators into the type instead of creating a child. 
+# json_node['type'] = json_node['type'] + type(child).__name__ +# else: +# children.append(traverse(child)) +# +# if isinstance(node, ast.Attribute): +# children.append(gen_identifier(node.attr, 'attr')) +# +# if (len(children) != 0): +# json_node['children'] = children +# return pos +# +# traverse(tree) +# return json_tree + +# def updated_parse_file(code): +# global c, d +# tree = ast.parse(code) +# +# json_tree = [] +# +# def gen_identifier(identifier, node_type='identifier', parent=None): +# pos = len(json_tree) +# json_node = {} +# json_tree.append(json_node) +# # json_node['type'] = node_type +# json_node[node_type] = identifier +# if parent: +# json_node['parent'] = type(parent).__name__ +# else: +# json_node['parent'] = None +# return pos +# +# def traverse_list(l, node_type='list', parent=None): +# pos = len(json_tree) +# json_node = {} +# json_tree.append(json_node) +# json_node[node_type] = [] +# if parent: +# json_node['parent'] = type(parent).__name__ +# else: +# json_node['parent'] = None +# children = [] +# for item in l: +# children.append(traverse(item)) +# if (len(children) != 0): +# json_node[node_type] = children +# return pos +# +# def traverse(node, parent=None): +# pos = len(json_tree) +# json_node = {} +# json_tree.append(json_node) +# json_node[type(node).__name__] = [] +# if parent: +# json_node['parent'] = type(parent).__name__ +# else: +# json_node['parent'] = None +# children = [] +# if isinstance(node, ast.Name): +# json_node[type(node).__name__] = node.id +# elif isinstance(node, ast.Num): +# json_node[type(node).__name__] = unicode(node.n) +# elif isinstance(node, ast.Str): +# json_node[type(node).__name__] = node.s.decode('utf-8') +# elif isinstance(node, ast.alias): +# json_node[type(node).__name__] = unicode(node.name) +# if node.asname: +# children.append(gen_identifier(node.asname)) +# elif isinstance(node, ast.FunctionDef): +# json_node[type(node).__name__] = unicode(node.name) +# elif isinstance(node, ast.ClassDef): +# json_node[type(node).__name__] = unicode(node.name) +# elif isinstance(node, ast.ImportFrom): +# if node.module: +# json_node[type(node).__name__] = unicode(node.module) +# elif isinstance(node, ast.Global): +# for n in node.names: +# children.append(gen_identifier(n)) +# elif isinstance(node, ast.keyword): +# json_node[type(node).__name__] = unicode(node.arg) +# +# # Process children. 
+# if isinstance(node, ast.For): +# children.append(traverse(node.target, node)) +# children.append(traverse(node.iter, node)) +# children.append(traverse_list(node.body, 'body', node)) +# if node.orelse: +# children.append(traverse_list(node.orelse, 'orelse', node)) +# elif isinstance(node, ast.If) or isinstance(node, ast.While): +# children.append(traverse(node.test, node)) +# children.append(traverse_list(node.body, 'body', node)) +# if node.orelse: +# children.append(traverse_list(node.orelse, 'orelse', node)) +# elif isinstance(node, ast.With): +# children.append(traverse(node.context_expr, node)) +# if node.optional_vars: +# children.append(traverse(node.optional_vars, node)) +# children.append(traverse_list(node.body, 'body', node)) +# elif isinstance(node, ast.Try): +# children.append(traverse_list(node.body, 'body', node)) +# children.append(traverse_list(node.handlers, 'handlers', node)) +# if node.orelse: +# children.append(traverse_list(node.orelse, 'orelse', node)) +# if node.finalbody: +# children.append(traverse_list(node.finalbody, 'finalbody', node)) +# elif isinstance(node, ast.arguments): +# children.append(traverse_list(node.args, 'args', node)) +# children.append(traverse_list(node.defaults, 'defaults', node)) +# if node.vararg: +# children.append(gen_identifier(node.vararg, 'vararg')) +# if node.kwarg: +# children.append(gen_identifier(node.kwarg, 'kwarg')) +# elif isinstance(node, ast.ExceptHandler): +# if node.type: +# children.append(traverse_list([node.type], 'type', node)) +# if node.name: +# children.append(traverse_list([node.name], 'name', node)) +# children.append(traverse_list(node.body, 'body', node)) +# elif isinstance(node, ast.ClassDef): +# children.append(traverse_list(node.bases, 'bases', node)) +# children.append(traverse_list(node.body, 'body', node)) +# children.append(traverse_list(node.decorator_list, 'decorator_list', node)) +# elif isinstance(node, ast.FunctionDef): +# children.append(traverse(node.args, node)) +# children.append(traverse_list(node.body, 'body',node)) +# children.append(traverse_list(node.decorator_list, 'decorator_list',node)) +# else: +# # Default handling: iterate over children. +# for child in ast.iter_child_nodes(node): +# if isinstance(child, ast.expr_context) or isinstance(child, ast.operator) or isinstance(child, +# ast.boolop) or isinstance( +# child, ast.unaryop) or isinstance(child, ast.cmpop): +# # Directly include expr_context, and operators into the type instead of creating a child. 
+# json_node[type(node).__name__ + type(child).__name__] = json_node[type(node).__name__]
+# del json_node[type(node).__name__]
+# else:
+# children.append(traverse(child,node))
+#
+# if isinstance(node, ast.Attribute):
+# children.append(gen_identifier(node.attr, 'Attr'))
+#
+# if (len(children) != 0):
+# if type(node).__name__ not in json_node.keys():
+# json_node[type(node).__name__ + type(child).__name__] = children
+# else:
+# json_node[type(node).__name__] = children
+# return pos
+#
+# traverse(tree)
+# return json_tree
+# # return json.dumps(json_tree, separators=(',', ':'), ensure_ascii=False)
+
+def parse_file_with_parents(code):
+    global c, d
+    unicode = str  # Py3: alias so the Py2-era unicode() calls below keep working
+    tree = ast.parse(code)
+
+    json_tree = []
+
+    def gen_identifier(identifier, node_type='identifier', parent=None):
+        pos = len(json_tree)
+        json_node = {}
+        json_tree.append(json_node)
+        json_node['type'] = node_type
+        json_node['value'] = identifier
+        if parent:
+            json_node['parent'] = type(parent).__name__
+        else:
+            json_node['parent'] = None
+        return pos
+
+    def traverse_list(l, node_type='list', parent=None):
+        pos = len(json_tree)
+        json_node = {}
+        json_tree.append(json_node)
+        json_node['type'] = node_type
+        if parent:
+            json_node['parent'] = type(parent).__name__
+        else:
+            json_node['parent'] = None
+        children = []
+        for item in l:
+            children.append(traverse(item))
+        if (len(children) != 0):
+            json_node['children'] = children
+        return pos
+
+    def traverse(node, parent=None):
+        pos = len(json_tree)
+        json_node = {}
+        json_tree.append(json_node)
+        json_node['type'] = type(node).__name__
+        if parent:
+            json_node['parent'] = type(parent).__name__
+        else:
+            json_node['parent'] = None
+        children = []
+        if isinstance(node, ast.Name):
+            json_node['value'] = node.id
+        elif isinstance(node, ast.Num):
+            json_node['value'] = unicode(node.n)
+        elif isinstance(node, ast.Str):
+            json_node['value'] = node.s  # Py3: node.s is already str, no .decode()
+        elif isinstance(node, ast.alias):
+            json_node['value'] = unicode(node.name)
+            if node.asname:
+                children.append(gen_identifier(node.asname))
+        elif isinstance(node, ast.FunctionDef):
+            json_node['value'] = unicode(node.name)
+        elif isinstance(node, ast.ClassDef):
+            json_node['value'] = unicode(node.name)
+        elif isinstance(node, ast.ImportFrom):
+            if node.module:
+                json_node['value'] = unicode(node.module)
+        elif isinstance(node, ast.Global):
+            for n in node.names:
+                children.append(gen_identifier(n))
+        elif isinstance(node, ast.keyword):
+            json_node['value'] = unicode(node.arg)
+
+        # Process children. 
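+        # (Inferred from the branches below: compound statements -- For,
+        # If/While, With, Try, arguments, ExceptHandler, ClassDef,
+        # FunctionDef -- get hand-written handling so their child lists
+        # carry names such as 'body', 'orelse' or 'handlers'; every other
+        # node type falls through to generic child iteration at the end.)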
+ if isinstance(node, ast.For): + children.append(traverse(node.target)) + children.append(traverse(node.iter)) + children.append(traverse_list(node.body, 'body')) + if node.orelse: + children.append(traverse_list(node.orelse, 'orelse')) + elif isinstance(node, ast.If) or isinstance(node, ast.While): + children.append(traverse(node.test)) + children.append(traverse_list(node.body, 'body')) + if node.orelse: + children.append(traverse_list(node.orelse, 'orelse')) + elif isinstance(node, ast.With): + children.append(traverse(node.context_expr)) + if node.optional_vars: + children.append(traverse(node.optional_vars)) + children.append(traverse_list(node.body, 'body')) + elif isinstance(node, ast.Try): + children.append(traverse_list(node.body, 'body')) + children.append(traverse_list(node.handlers, 'handlers')) + if node.orelse: + children.append(traverse_list(node.orelse, 'orelse')) + if node.finalbody: + children.append(traverse_list(node.finalbody, 'finalbody')) + elif isinstance(node, ast.arguments): + children.append(traverse_list(node.args, 'args')) + children.append(traverse_list(node.defaults, 'defaults')) + if node.vararg: + children.append(gen_identifier(node.vararg, 'vararg')) + if node.kwarg: + children.append(gen_identifier(node.kwarg, 'kwarg')) + elif isinstance(node, ast.ExceptHandler): + if node.type: + children.append(traverse_list([node.type], 'type')) + if node.name: + children.append(traverse_list([node.name], 'name')) + children.append(traverse_list(node.body, 'body')) + elif isinstance(node, ast.ClassDef): + children.append(traverse_list(node.bases, 'bases')) + children.append(traverse_list(node.body, 'body')) + children.append(traverse_list(node.decorator_list, 'decorator_list')) + elif isinstance(node, ast.FunctionDef): + children.append(traverse(node.args)) + children.append(traverse_list(node.body, 'body')) + children.append(traverse_list(node.decorator_list, 'decorator_list')) + else: + # Default handling: iterate over children. + for child in ast.iter_child_nodes(node): + if isinstance(child, ast.expr_context) or isinstance(child, ast.operator) or isinstance(child, + ast.boolop) or isinstance( + child, ast.unaryop) or isinstance(child, ast.cmpop): + # Directly include expr_context, and operators into the type instead of creating a child. 
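+                    # (Illustrative: this folding turns a loaded name into
+                    # 'NameLoad' and an attribute access into 'AttributeLoad',
+                    # as in the sample tree for `ip = socket.gethostbyname(host)`
+                    # shown in the comment near the bottom of this file.)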
+ json_node['type'] = json_node['type'] + type(child).__name__ + else: + children.append(traverse(child)) + + if isinstance(node, ast.Attribute): + children.append(gen_identifier(node.attr, 'attr')) + + if (len(children) != 0): + json_node['children'] = children + return pos + + traverse(tree) + return json_tree + + +# [{'children': [1], 'type': 'Module'}, +# {'children': [2, 3], 'type': 'Assign'}, +# {'type': 'NameStore', 'value': 'ip'}, +# {'children': [4, 7], 'type': 'Call'}, +# {'children': [5, 6], 'type': 'AttributeLoad'}, +# {'type': 'NameLoad', 'value': 'socket'}, +# {'type': 'attr', 'value': 'gethostbyname'}, +# {'type': 'NameLoad', 'value': 'host'}] + + +from pprint import pprint +if __name__=='__main__': + print('something') + +# code ='''print('something') +# try: +# a+1 +# except IOError: +# return 1 +# else: +# a+2 +# finally: +# return 2''' + + # code= '''def f(a, b=1, c=2, *d, e, f=3, **g): + # pass''' + + code = '''ip = socket.gethostbyname(host)''' + # code = '''func(a, b=c, *d, **e)''' + # a, b = convert_code_to_tokens(code) + # df = pd.DataFrame([a, b]) + # print(df.T) + + result_tree = updated_parse_file(code) + + # print(pd.read_json(result_tree)) + print(result_tree) \ No newline at end of file diff --git a/src/utils/codegen.py b/src/utils/codegen.py index adeae537..db74600e 100644 --- a/src/utils/codegen.py +++ b/src/utils/codegen.py @@ -38,6 +38,7 @@ class SourceGenerator(NodeVisitor): def __init__(self, indent_with, add_line_information=False): self.result = [] + self.parents = [] self.indent_with = indent_with self.add_line_information = add_line_information self.indentation = 0 @@ -58,22 +59,26 @@ def newline(self, node=None, extra=0): self.new_lines = max(self.new_lines, 1 + extra) if node is not None and self.add_line_information: # self.write('# line: %s' % node.lineno) + self.parents.append(type(node).__name__) # self.new_lines = 1 self.write('%s: ' % node.lineno) + self.parents.append(type(node).__name__) - def body(self, statements): + def body(self, statements, parent=None): + node = parent self.new_line = True self.indentation += 1 for stmt in statements: - self.visit(stmt) + self.visit(stmt, node) self.indentation -= 1 def body_or_else(self, node): - self.body(node.body) + self.body(node.body, node) if node.orelse: self.newline() self.write('else:') - self.body(node.orelse) + self.parents.append(type(node).__name__) + self.body(node.orelse, node) def signature(self, node): want_comma = [] @@ -81,6 +86,7 @@ def signature(self, node): def write_comma(): if want_comma: self.write(', ') + self.parents.append(type(node).__name__) else: want_comma.append(True) @@ -88,24 +94,30 @@ def write_comma(): for arg, default in zip(node.args, padding + node.defaults): write_comma() self.write(arg.arg) + self.parents.append(type(node).__name__) if default is not None: self.write('=') - self.visit(default) + self.parents.append(type(node).__name__) + self.visit(default, node) if node.vararg is not None: write_comma() self.write('*' + node.vararg.arg) + self.parents.append(type(node).__name__) for arg, default in zip(node.kwonlyargs, node.kw_defaults): write_comma() self.write(arg.arg) + self.parents.append(type(node).__name__) if default is not None: self.write('=') - self.visit(default) + self.parents.append(type(node).__name__) + self.visit(default, node) if node.kwarg is not None: write_comma() self.write('**' + node.kwarg.arg) + self.parents.append(type(node).__name__) def decorators(self, node): if not node: @@ -116,7 +128,8 @@ def decorators(self, node): for 
decorator in node.decorator_list: self.newline(decorator) self.write('@') - self.visit(decorator) + self.parents.append(type(node).__name__) + self.visit(decorator, node) # Statements @@ -125,31 +138,38 @@ def visit_Assign(self, node): for idx, target in enumerate(node.targets): if idx: self.write(', ') - self.visit(target) + self.parents.append(type(node).__name__) + self.visit(target, node) # self.write(' = ') - self.visit(node.value) + self.parents.append(type(node).__name__) + self.visit(node.value, node) def visit_AugAssign(self, node): self.newline(node) - self.visit(node.target) + self.visit(node.target, node) # self.write(BINOP_SYMBOLS[type(node.op)] + '=') - self.visit(node.value) + self.parents.append(type(node).__name__) + self.visit(node.value, node) def visit_ImportFrom(self, node): self.newline(node) self.write('from %s%s import ' % ('.' * node.level, node.module)) + self.parents.append(type(node).__name__) for idx, item in enumerate(node.names): if idx: self.write(', ') - self.visit(item) + self.parents.append(type(node).__name__) + self.visit(item, node) def visit_Import(self, node): self.newline(node) self.write('import ') + self.parents.append(type(node).__name__) for idx, item in enumerate(node.names): if idx: self.write(', ') - self.visit(item) + self.parents.append(type(node).__name__) + self.visit(item, node) def visit_Expr(self, node): self.newline(node) @@ -163,13 +183,18 @@ def visit_FunctionDef(self, node): if node.decorator_list: for decorator in node.decorator_list: self.write('@') - self.visit(decorator) + self.parents.append(type(node).__name__) + self.visit(decorator, node) self.write('def ') + self.parents.append(type(node).__name__) self.write('%s' % node.name) + self.parents.append(type(node).__name__) self.write('(') + self.parents.append(type(node).__name__) self.signature(node.args) self.write('):') - self.body(node.body) + self.parents.append(type(node).__name__) + self.body(node.body, node) def visit_ClassDef(self, node): have_args = [] @@ -177,54 +202,66 @@ def visit_ClassDef(self, node): def paren_or_comma(): if have_args: self.write(', ') + self.parents.append(type(node).__name__) else: have_args.append(True) self.write('(') + self.parents.append(type(node).__name__) self.newline(extra=2) self.decorators(node) self.newline(node) self.write('class %s' % node.name) + self.parents.append(type(node).__name__) for base in node.bases: paren_or_comma() - self.visit(base) + self.visit(base, node) # XXX: the if here is used to keep this module compatible # with python 2.6. 
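+            # (Version note: the hasattr() probes here guard fields that exist
+            # only on some Python versions -- `keywords` is absent on 2.6,
+            # while `starargs`/`kwargs` were dropped from ClassDef in 3.5.)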
if hasattr(node, 'keywords'): for keyword in node.keywords: paren_or_comma() self.write(keyword.arg + '=') - self.visit(keyword.value) + self.parents.append(type(node).__name__) + self.visit(keyword.value, node) if hasattr(node, 'starargs') and node.starargs is not None: paren_or_comma() self.write('*') - self.visit(node.starargs) + self.parents.append(type(node).__name__) + self.visit(node.starargs, node) if hasattr(node, 'kwargs') and node.kwargs is not None: paren_or_comma() self.write('**') - self.visit(node.kwargs) + self.parents.append(type(node).__name__) + self.visit(node.kwargs, node) self.write(have_args and '):' or ':') - self.body(node.body) + self.parents.append(type(node).__name__) + self.body(node.body, node) def visit_If(self, node): self.newline(node) self.write('if ') - self.visit(node.test) + self.parents.append(type(node).__name__) + self.visit(node.test, node) self.write(':') - self.body(node.body) + self.parents.append(type(node).__name__) + self.body(node.body, node) while True: else_ = node.orelse if len(else_) == 1 and isinstance(else_[0], If): node = else_[0] self.newline() self.write('elif ') - self.visit(node.test) + self.parents.append(type(node).__name__) + self.visit(node.test, node) self.write(':') - self.body(node.body) + self.parents.append(type(node).__name__) + self.body(node.body, node) elif else_: self.newline() self.write('else:') - self.body(else_) + self.parents.append(type(node).__name__) + self.body(else_, node) break else: break @@ -232,144 +269,177 @@ def visit_If(self, node): def visit_For(self, node): self.newline(node) self.write('for ') - self.visit(node.target) + self.parents.append(type(node).__name__) + self.visit(node.target, node) self.write(' in ') - self.visit(node.iter) + self.parents.append(type(node).__name__) + self.visit(node.iter, node) self.write(':') - self.body_or_else(node) + self.parents.append(type(node).__name__) + self.body_or_else(node, node) def visit_While(self, node): self.newline(node) self.write('while ') - self.visit(node.test) + self.parents.append(type(node).__name__) + self.visit(node.test, node) self.write(':') - self.body_or_else(node) + self.parents.append(type(node).__name__) + self.body_or_else(node, node) def visit_With(self, node): self.newline(node) self.write('with ') + self.parents.append(type(node).__name__) for item in node.items: - self.visit(item.context_expr) + self.visit(item.context_expr, node) if item.optional_vars is not None: self.write(' as ') - self.visit(item.optional_vars) + self.parents.append(type(node).__name__) + self.visit(item.optional_vars, node) self.write(':') - self.body(node.body) + self.parents.append(type(node).__name__) + self.body(node.body, node) def visit_Pass(self, node): self.newline(node) self.write('pass') + self.parents.append(type(node).__name__) def visit_Print(self, node): # XXX: python 2.6 only self.newline(node) self.write('print ') + self.parents.append(type(node).__name__) want_comma = False if node.dest is not None: self.write(' >> ') - self.visit(node.dest) + self.parents.append(type(node).__name__) + self.visit(node.dest, node) want_comma = True for value in node.values: if want_comma: self.write(', ') - self.visit(value) + self.parents.append(type(node).__name__) + self.visit(value, node) want_comma = True if not node.nl: self.write(',') + self.parents.append(type(node).__name__) def visit_Delete(self, node): self.newline(node) self.write('del ') + self.parents.append(type(node).__name__) for idx, target in enumerate(node.targets): if idx: self.write(', 
') - self.visit(target) + self.parents.append(type(node).__name__) + self.visit(target, node) def visit_Try(self, node): self.newline(node) #try block self.write('try:') - self.body(node.body) + self.parents.append(type(node).__name__) + self.body(node.body, node) self.newline(node) #except block for handler in node.handlers: - self.visit(handler) + self.visit(handler, node) #except else if len(node.orelse): self.write('else:') - self.body(node.orelse) + self.parents.append(type(node).__name__) + self.body(node.orelse, node) #except finally if len(node.finalbody): self.write('finally:') - self.body(node.finalbody) + self.parents.append(type(node).__name__) + self.body(node.finalbody, node) def visit_TryExcept(self, node): self.newline(node) self.write('try:') - self.body(node.body) + self.parents.append(type(node).__name__) + self.body(node.body, node) for handler in node.handlers: - self.visit(handler) + self.visit(handler, node) def visit_TryFinally(self, node): self.newline(node) self.write('try:') - self.body(node.body) + self.parents.append(type(node).__name__) + self.body(node.body, node) self.newline(node) self.write('finally:') - self.body(node.finalbody) + self.parents.append(type(node).__name__) + self.body(node.finalbody, node) def visit_Global(self, node): self.newline(node) self.write('global ' + ', '.join(node.names)) + self.parents.append(type(node).__name__) def visit_Nonlocal(self, node): self.newline(node) self.write('nonlocal ' + ', '.join(node.names)) + self.parents.append(type(node).__name__) def visit_Return(self, node): self.newline(node) self.write('return ') + self.parents.append(type(node).__name__) if node.value: - self.visit(node.value) + self.visit(node.value, node) def visit_Break(self, node): self.newline(node) self.write('break') + self.parents.append(type(node).__name__) def visit_Continue(self, node): self.newline(node) self.write('continue') + self.parents.append(type(node).__name__) def visit_Raise(self, node): # XXX: Python 2.6 / 3.0 compatibility self.newline(node) self.write('raise') + self.parents.append(type(node).__name__) if hasattr(node, 'exc') and node.exc is not None: self.write(' ') - self.visit(node.exc) + self.parents.append(type(node).__name__) + self.visit(node.exc, node) if node.cause is not None: self.write(' from ') - self.visit(node.cause) + self.parents.append(type(node).__name__) + self.visit(node.cause, node) elif hasattr(node, 'type') and node.type is not None: - self.visit(node.type) + self.visit(node.type, node) if node.inst is not None: self.write(', ') - self.visit(node.inst) + self.parents.append(type(node).__name__) + self.visit(node.inst, node) if node.tback is not None: self.write(', ') - self.visit(node.tback) + self.parents.append(type(node).__name__) + self.visit(node.tback, node) # Expressions def visit_Attribute(self, node): - self.visit(node.value) + self.visit(node.value, node) self.write('.') + self.parents.append(type(node).__name__) self.write(node.attr) + self.parents.append(type(node).__name__) def visit_Call(self, node): want_comma = [] @@ -377,63 +447,80 @@ def visit_Call(self, node): def write_comma(): if want_comma: self.write(', ') + self.parents.append(type(node).__name__) else: want_comma.append(True) - self.visit(node.func) + self.visit(node.func, node) self.write('(') + self.parents.append(type(node).__name__) for arg in node.args: write_comma() - self.visit(arg) + self.visit(arg, node) for keyword in node.keywords: write_comma() if keyword.arg: self.write(keyword.arg + '=') - self.visit(keyword.value) 
+ self.parents.append(type(node).__name__) + self.visit(keyword.value, node) else: self.write('**') - self.visit(keyword.value) + self.parents.append(type(node).__name__) + self.visit(keyword.value, node) # if hasattr(node, 'starargs') and node.starargs is not None: # write_comma() # self.write('*') - # self.visit(node.starargs) + self.parents.append(type(node).__name__) + # self.visit(node.starargs, node) # if hasattr(node, 'kwargs') and node.kwargs is not None: # write_comma() # self.write('**') - # self.visit(node.kwargs) + self.parents.append(type(node).__name__) + # self.visit(node.kwargs, node) self.write(')') + self.parents.append(type(node).__name__) def visit_Name(self, node): self.write(node.id) + self.parents.append(type(node).__name__) def visit_Str(self, node): if self.docstring != node.s: self.write(repr(node.s)) + self.parents.append(type(node).__name__) def visit_Bytes(self, node): self.write(repr(node.s)) + self.parents.append(type(node).__name__) def visit_Num(self, node): self.write(repr(node.n)) + self.parents.append(type(node).__name__) def visit_Tuple(self, node): self.write('(') + self.parents.append(type(node).__name__) idx = -1 for idx, item in enumerate(node.elts): if idx: self.write(',') - self.visit(item) + self.parents.append(type(node).__name__) + self.visit(item, node) self.write(idx and ')' or ',)') + self.parents.append(type(node).__name__) def sequence_visit(left, right): def visit(self, node): self.write(left) + self.parents.append(type(node).__name__) for idx, item in enumerate(node.elts): if idx: self.write(', ') - self.visit(item) + self.parents.append(type(node).__name__) + self.visit(item, node) self.write(right) + self.parents.append(type(node).__name__) return visit @@ -443,92 +530,119 @@ def visit(self, node): def visit_Dict(self, node): self.write('{') + self.parents.append(type(node).__name__) for idx, (key, value) in enumerate(zip(node.keys, node.values)): if idx: self.write(', ') + self.parents.append(type(node).__name__) if key!=None: - self.visit(key) + self.visit(key, node) self.write(': ') - self.visit(value) + self.parents.append(type(node).__name__) + self.visit(value, node) elif key==None: self.write('**') - self.visit(value) + self.parents.append(type(node).__name__) + self.visit(value, node) self.write('}') + self.parents.append(type(node).__name__) def visit_BinOp(self, node): - self.visit(node.left) + self.visit(node.left, node) # self.write('%s' % BINOP_SYMBOLS[type(node.op)]) - self.visit(node.right) + self.parents.append(type(node).__name__) + self.visit(node.right, node) def visit_BoolOp(self, node): self.write('(') + self.parents.append(type(node).__name__) for idx, value in enumerate(node.values): if idx: self.write(' %s ' % BOOLOP_SYMBOLS[type(node.op)]) - self.visit(value) + self.parents.append(type(node).__name__) + self.visit(value, node) self.write(')') + self.parents.append(type(node).__name__) def visit_Compare(self, node): self.write('(') - self.visit(node.left) + self.parents.append(type(node).__name__) + self.visit(node.left, node) for op, right in zip(node.ops, node.comparators): self.write(' %s ' % CMPOP_SYMBOLS[type(op)]) - self.visit(right) + self.parents.append(type(node).__name__) + self.visit(right, node) self.write(')') + self.parents.append(type(node).__name__) def visit_UnaryOp(self, node): self.write('(') + self.parents.append(type(node).__name__) op = UNARYOP_SYMBOLS[type(node.op)] self.write(op) + self.parents.append(type(node).__name__) if op == 'not': self.write(' ') - self.visit(node.operand) + 
self.parents.append(type(node).__name__) + self.visit(node.operand, node) self.write(')') + self.parents.append(type(node).__name__) def visit_Subscript(self, node): - self.visit(node.value) + self.visit(node.value, node) self.write('[') - self.visit(node.slice) + self.parents.append(type(node).__name__) + self.visit(node.slice, node) self.write(']') + self.parents.append(type(node).__name__) def visit_Slice(self, node): if node.lower is not None: - self.visit(node.lower) + self.visit(node.lower, node) self.write(':') + self.parents.append(type(node).__name__) if node.upper is not None: - self.visit(node.upper) + self.visit(node.upper, node) if node.step is not None: self.write(':') + self.parents.append(type(node).__name__) if not (isinstance(node.step, Name) and node.step.id == 'None'): - self.visit(node.step) + self.visit(node.step, node) def visit_ExtSlice(self, node): for idx, item in enumerate(node.dims): if idx: self.write(', ') - self.visit(item) + self.parents.append(type(node).__name__) + self.visit(item, node) def visit_Yield(self, node): self.write('yield ') + self.parents.append(type(node).__name__) if node.value: - self.visit(node.value) + self.visit(node.value, node) def visit_Lambda(self, node): self.write('lambda ') + self.parents.append(type(node).__name__) self.signature(node.args) self.write(': ') - self.visit(node.body) + self.parents.append(type(node).__name__) + self.visit(node.body, node) def visit_Ellipsis(self, node): self.write('Ellipsis') + self.parents.append(type(node).__name__) def generator_visit(left, right): def visit(self, node): self.write(left) - self.visit(node.elt) + self.parents.append(type(node).__name__) + self.visit(node.elt, node) for comprehension in node.generators: - self.visit(comprehension) + self.visit(comprehension, node) self.write(right) + self.parents.append(type(node).__name__) return visit @@ -539,67 +653,89 @@ def visit(self, node): def visit_DictComp(self, node): self.write('{') - self.visit(node.key) + self.parents.append(type(node).__name__) + self.visit(node.key, node) self.write(': ') - self.visit(node.value) + self.parents.append(type(node).__name__) + self.visit(node.value, node) for comprehension in node.generators: - self.visit(comprehension) + self.visit(comprehension, node) self.write('}') + self.parents.append(type(node).__name__) def visit_IfExp(self, node): - self.visit(node.body) + self.visit(node.body, node) self.write(' if ') - self.visit(node.test) + self.parents.append(type(node).__name__) + self.visit(node.test, node) self.write(' else ') - self.visit(node.orelse) + self.parents.append(type(node).__name__) + self.visit(node.orelse, node) def visit_Starred(self, node): self.write('*') - self.visit(node.value) + self.parents.append(type(node).__name__) + self.visit(node.value, node) def visit_Repr(self, node): # XXX: python 2.6 only self.write('`') - self.visit(node.value) + self.parents.append(type(node).__name__) + self.visit(node.value, node) self.write('`') + self.parents.append(type(node).__name__) # Helper Nodes def visit_alias(self, node): self.write(node.name) + self.parents.append(type(node).__name__) if node.asname is not None: self.write(' as ' + node.asname) + self.parents.append(type(node).__name__) def visit_comprehension(self, node): self.write(' for ') - self.visit(node.target) + self.parents.append(type(node).__name__) + self.visit(node.target, node) self.write(' in ') - self.visit(node.iter) + self.parents.append(type(node).__name__) + self.visit(node.iter, node) if node.ifs: for if_ in node.ifs: 
self.write(' if ') - self.visit(if_) + self.parents.append(type(node).__name__) + self.visit(if_, node) def visit_ExceptHandler(self, node): self.newline(node) self.write('except') + self.parents.append(type(node).__name__) if node.type is not None: self.write(' ') - self.visit(node.type) + self.parents.append(type(node).__name__) + self.visit(node.type, node) if node.name is not None: self.write(', ') + self.parents.append(type(node).__name__) self.write(node.name) + self.parents.append(type(node).__name__) self.write(':') - self.body(node.body) + self.parents.append(type(node).__name__) + self.body(node.body, node) # def visit_exceptHandler(self, node): # self.newline(node) # self.write('except') + # self.parents.append(type(node).__name__) # if node.type is not None: # self.write(' ') - # self.visit(node.type) + # self.parents.append(type(node).__name__) + # self.visit(node.type, node) # if node.name is not None: - # self.write(', ') - # self.visit(node.name) + # # self.write(', ') + # self.parents.append(type(node).__name__) + # self.visit(node.name, node) # self.write(':') - # self.body(node.body) \ No newline at end of file + # self.parents.append(type(node).__name__) + # self.body(node.body, node) \ No newline at end of file diff --git a/src/utils/my_ast.py b/src/utils/my_ast.py index e13a2be8..277789fc 100644 --- a/src/utils/my_ast.py +++ b/src/utils/my_ast.py @@ -304,16 +304,21 @@ class name of the node. So a `TryFinally` node visit function would """ first_type = True - def visit(self, node): + def visit(self, node, parent=None): """Visit a node.""" method = 'visit_' + node.__class__.__name__ visitor = getattr(self, method, self.generic_visit) + if type(node) ==Expr: pass elif type(node)==Str and self.docstring == node.s: pass else: self.write(node.__class__.__name__) + if parent: + self.parents.append(type(parent).__name__) + else: + self.parents.append(None) return visitor(node) def generic_visit(self, node): @@ -322,9 +327,9 @@ def generic_visit(self, node): if isinstance(value, list): for item in value: if isinstance(item, AST): - self.visit(item) + self.visit(item, node) elif isinstance(value, AST): - self.visit(value) + self.visit(value, node) class NodeTransformer(NodeVisitor): From 0b646d2f40a9ae930abf7069612fcd5c2ebc4da1 Mon Sep 17 00:00:00 2001 From: faizan khan Date: Wed, 8 Apr 2020 21:27:17 -0400 Subject: [PATCH 2/9] finalized removal of buggy samples --- script/dfs_traversal.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/script/dfs_traversal.py b/script/dfs_traversal.py index 96ef25b9..9e74aa18 100644 --- a/script/dfs_traversal.py +++ b/script/dfs_traversal.py @@ -5,8 +5,8 @@ import subprocess import os -path = 'resources/data/python/final/jsonl/valid_old/python_valid_0.jsonl.gz' -s_path = 'resources/data/python/final/jsonl/valid/python_valid_0_updated.jsonl.gz' +path = '../resources/data/python/final/jsonl/valid_old/python_valid_0.jsonl.gz' +s_path = '../resources/data/python/final/jsonl/valid/python_valid_0_updated.jsonl.gz' a = RichPath.create(path) s = RichPath.create(s_path) @@ -14,6 +14,7 @@ print('started') b = list(a.read_as_jsonl()) +c=[] count = 0 def convert_code_to_tokens(code): @@ -48,7 +49,7 @@ def convert_code_to_tokens(code): for idx, sample in enumerate(b): print("sample {} in progress".format(idx)) # print(sample['code']) - if idx==3282: + if idx==5306: print(sample['code']) tokenization = convert_code_to_tokens(sample['code']) @@ -56,12 +57,13 @@ def convert_code_to_tokens(code): templist.append(idx) else: 
b[idx]['code_tokens'] = tokenization + c.append(b[idx]) # tree = my_ast.parse(sample['code']) # an = SourceGenerator(' ') # an.visit(tree) # b[idx]['code_tokens'] = an.result -s.save_as_compressed_file(b) +s.save_as_compressed_file(c) print('finished', templist, len(templist), tokenization) From fcaee6a1997c4d6c4450319bea54f7ba038ed3d7 Mon Sep 17 00:00:00 2001 From: faizan khan Date: Fri, 10 Apr 2020 18:14:15 -0400 Subject: [PATCH 3/9] final working parent-node-extraction code --- script/parent_node_pairs.py | 496 ++++------------------------ script/parent_node_parse_helpers.py | 337 +++++++++++++++++++ 2 files changed, 393 insertions(+), 440 deletions(-) create mode 100644 script/parent_node_parse_helpers.py diff --git a/script/parent_node_pairs.py b/script/parent_node_pairs.py index 2d1814cf..31cbe743 100644 --- a/script/parent_node_pairs.py +++ b/script/parent_node_pairs.py @@ -3,17 +3,11 @@ from src.utils import my_ast from src.utils.codegen import * import subprocess + +from parent_node_parse_helpers import dfs_traversal_with_parents import pandas as pd import os -# path = 'resources/data/python/final/jsonl/valid_old/temp_train_10.jsonl.gz' -# # s_path = 'resources/data/python/final/jsonl/valid/temp_valid_10.jsonl.gz' -# -# a = RichPath.create(path) -# s = RichPath.create(s_path) -# -# print('started') -# b = list(a.read_as_jsonl()) count = 0 @@ -37,443 +31,66 @@ def convert_code_to_tokens(code): # os.rmdir('temp.py') except: pass - if tree!='': - an = SourceGenerator(' ') - an.visit(tree) - return an.result, an.parents + if tree!='' and tree != None: + return dfs_traversal_with_parents(tree) else: - return [] -# - -# templist = [] -# for idx, sample in enumerate(b): -# print("sample {} in progress".format(idx)) -# # print(sample['code']) -# if idx==3282: -# print(sample['code']) -# -# tokenization = convert_code_to_tokens(sample['code']) -# if tokenization == []: -# templist.append(idx) -# else: -# b[idx]['code_tokens'] = tokenization -# # tree = my_ast.parse(sample['code']) -# # an = SourceGenerator(' ') -# # an.visit(tree) -# # b[idx]['code_tokens'] = an.result -# -# s.save_as_compressed_file(b) -# print('finished', templist, len(templist), tokenization) - -import ast -import sys -import json -# def parse_file(code): -# global c, d -# tree = ast.parse(code) -# -# json_tree = [] -# -# def gen_identifier(identifier, node_type='identifier'): -# pos = len(json_tree) -# json_node = {} -# json_tree.append(json_node) -# json_node['type'] = node_type -# json_node['value'] = identifier -# return pos -# -# def traverse_list(l, node_type='list'): -# pos = len(json_tree) -# json_node = {} -# json_tree.append(json_node) -# json_node['type'] = node_type -# children = [] -# for item in l: -# children.append(traverse(item)) -# if (len(children) != 0): -# json_node['children'] = children -# return pos -# -# def traverse(node): -# pos = len(json_tree) -# json_node = {} -# json_tree.append(json_node) -# json_node['type'] = type(node).__name__ -# children = [] -# if isinstance(node, ast.Name): -# json_node['value'] = node.id -# elif isinstance(node, ast.Num): -# json_node['value'] = unicode(node.n) -# elif isinstance(node, ast.Str): -# json_node['value'] = node.s.decode('utf-8') -# elif isinstance(node, ast.alias): -# json_node['value'] = unicode(node.name) -# if node.asname: -# children.append(gen_identifier(node.asname)) -# elif isinstance(node, ast.FunctionDef): -# json_node['value'] = unicode(node.name) -# elif isinstance(node, ast.ClassDef): -# json_node['value'] = unicode(node.name) -# elif 
isinstance(node, ast.ImportFrom): -# if node.module: -# json_node['value'] = unicode(node.module) -# elif isinstance(node, ast.Global): -# for n in node.names: -# children.append(gen_identifier(n)) -# elif isinstance(node, ast.keyword): -# json_node['value'] = unicode(node.arg) -# -# # Process children. -# if isinstance(node, ast.For): -# children.append(traverse(node.target)) -# children.append(traverse(node.iter)) -# children.append(traverse_list(node.body, 'body')) -# if node.orelse: -# children.append(traverse_list(node.orelse, 'orelse')) -# elif isinstance(node, ast.If) or isinstance(node, ast.While): -# children.append(traverse(node.test)) -# children.append(traverse_list(node.body, 'body')) -# if node.orelse: -# children.append(traverse_list(node.orelse, 'orelse')) -# elif isinstance(node, ast.With): -# children.append(traverse(node.context_expr)) -# if node.optional_vars: -# children.append(traverse(node.optional_vars)) -# children.append(traverse_list(node.body, 'body')) -# elif isinstance(node, ast.Try): -# children.append(traverse_list(node.body, 'body')) -# children.append(traverse_list(node.handlers, 'handlers')) -# if node.orelse: -# children.append(traverse_list(node.orelse, 'orelse')) -# if node.finalbody: -# children.append(traverse_list(node.finalbody, 'finalbody')) -# elif isinstance(node, ast.arguments): -# children.append(traverse_list(node.args, 'args')) -# children.append(traverse_list(node.defaults, 'defaults')) -# if node.vararg: -# children.append(gen_identifier(node.vararg, 'vararg')) -# if node.kwarg: -# children.append(gen_identifier(node.kwarg, 'kwarg')) -# elif isinstance(node, ast.ExceptHandler): -# if node.type: -# children.append(traverse_list([node.type], 'type')) -# if node.name: -# children.append(traverse_list([node.name], 'name')) -# children.append(traverse_list(node.body, 'body')) -# elif isinstance(node, ast.ClassDef): -# children.append(traverse_list(node.bases, 'bases')) -# children.append(traverse_list(node.body, 'body')) -# children.append(traverse_list(node.decorator_list, 'decorator_list')) -# elif isinstance(node, ast.FunctionDef): -# children.append(traverse(node.args)) -# children.append(traverse_list(node.body, 'body')) -# children.append(traverse_list(node.decorator_list, 'decorator_list')) -# else: -# # Default handling: iterate over children. -# for child in ast.iter_child_nodes(node): -# if isinstance(child, ast.expr_context) or isinstance(child, ast.operator) or isinstance(child, -# ast.boolop) or isinstance( -# child, ast.unaryop) or isinstance(child, ast.cmpop): -# # Directly include expr_context, and operators into the type instead of creating a child. 
-# json_node['type'] = json_node['type'] + type(child).__name__ -# else: -# children.append(traverse(child)) -# -# if isinstance(node, ast.Attribute): -# children.append(gen_identifier(node.attr, 'attr')) -# -# if (len(children) != 0): -# json_node['children'] = children -# return pos -# -# traverse(tree) -# return json_tree - -# def updated_parse_file(code): -# global c, d -# tree = ast.parse(code) -# -# json_tree = [] -# -# def gen_identifier(identifier, node_type='identifier', parent=None): -# pos = len(json_tree) -# json_node = {} -# json_tree.append(json_node) -# # json_node['type'] = node_type -# json_node[node_type] = identifier -# if parent: -# json_node['parent'] = type(parent).__name__ -# else: -# json_node['parent'] = None -# return pos -# -# def traverse_list(l, node_type='list', parent=None): -# pos = len(json_tree) -# json_node = {} -# json_tree.append(json_node) -# json_node[node_type] = [] -# if parent: -# json_node['parent'] = type(parent).__name__ -# else: -# json_node['parent'] = None -# children = [] -# for item in l: -# children.append(traverse(item)) -# if (len(children) != 0): -# json_node[node_type] = children -# return pos + return [], [] # -# def traverse(node, parent=None): -# pos = len(json_tree) -# json_node = {} -# json_tree.append(json_node) -# json_node[type(node).__name__] = [] -# if parent: -# json_node['parent'] = type(parent).__name__ -# else: -# json_node['parent'] = None -# children = [] -# if isinstance(node, ast.Name): -# json_node[type(node).__name__] = node.id -# elif isinstance(node, ast.Num): -# json_node[type(node).__name__] = unicode(node.n) -# elif isinstance(node, ast.Str): -# json_node[type(node).__name__] = node.s.decode('utf-8') -# elif isinstance(node, ast.alias): -# json_node[type(node).__name__] = unicode(node.name) -# if node.asname: -# children.append(gen_identifier(node.asname)) -# elif isinstance(node, ast.FunctionDef): -# json_node[type(node).__name__] = unicode(node.name) -# elif isinstance(node, ast.ClassDef): -# json_node[type(node).__name__] = unicode(node.name) -# elif isinstance(node, ast.ImportFrom): -# if node.module: -# json_node[type(node).__name__] = unicode(node.module) -# elif isinstance(node, ast.Global): -# for n in node.names: -# children.append(gen_identifier(n)) -# elif isinstance(node, ast.keyword): -# json_node[type(node).__name__] = unicode(node.arg) -# -# # Process children. 
-# if isinstance(node, ast.For): -# children.append(traverse(node.target, node)) -# children.append(traverse(node.iter, node)) -# children.append(traverse_list(node.body, 'body', node)) -# if node.orelse: -# children.append(traverse_list(node.orelse, 'orelse', node)) -# elif isinstance(node, ast.If) or isinstance(node, ast.While): -# children.append(traverse(node.test, node)) -# children.append(traverse_list(node.body, 'body', node)) -# if node.orelse: -# children.append(traverse_list(node.orelse, 'orelse', node)) -# elif isinstance(node, ast.With): -# children.append(traverse(node.context_expr, node)) -# if node.optional_vars: -# children.append(traverse(node.optional_vars, node)) -# children.append(traverse_list(node.body, 'body', node)) -# elif isinstance(node, ast.Try): -# children.append(traverse_list(node.body, 'body', node)) -# children.append(traverse_list(node.handlers, 'handlers', node)) -# if node.orelse: -# children.append(traverse_list(node.orelse, 'orelse', node)) -# if node.finalbody: -# children.append(traverse_list(node.finalbody, 'finalbody', node)) -# elif isinstance(node, ast.arguments): -# children.append(traverse_list(node.args, 'args', node)) -# children.append(traverse_list(node.defaults, 'defaults', node)) -# if node.vararg: -# children.append(gen_identifier(node.vararg, 'vararg')) -# if node.kwarg: -# children.append(gen_identifier(node.kwarg, 'kwarg')) -# elif isinstance(node, ast.ExceptHandler): -# if node.type: -# children.append(traverse_list([node.type], 'type', node)) -# if node.name: -# children.append(traverse_list([node.name], 'name', node)) -# children.append(traverse_list(node.body, 'body', node)) -# elif isinstance(node, ast.ClassDef): -# children.append(traverse_list(node.bases, 'bases', node)) -# children.append(traverse_list(node.body, 'body', node)) -# children.append(traverse_list(node.decorator_list, 'decorator_list', node)) -# elif isinstance(node, ast.FunctionDef): -# children.append(traverse(node.args, node)) -# children.append(traverse_list(node.body, 'body',node)) -# children.append(traverse_list(node.decorator_list, 'decorator_list',node)) -# else: -# # Default handling: iterate over children. -# for child in ast.iter_child_nodes(node): -# if isinstance(child, ast.expr_context) or isinstance(child, ast.operator) or isinstance(child, -# ast.boolop) or isinstance( -# child, ast.unaryop) or isinstance(child, ast.cmpop): -# # Directly include expr_context, and operators into the type instead of creating a child. 
-# json_node[type(node).__name__ + type(child).__name__] = json_node[type(node).__name__] -# del json_node[type(node).__name__] -# else: -# children.append(traverse(child,node)) -# -# if isinstance(node, ast.Attribute): -# children.append(gen_identifier(node.attr, 'Attr')) -# -# if (len(children) != 0): -# if type(node).__name__ not in json_node.keys(): -# json_node[type(node).__name__ + type(child).__name__] = children -# else: -# json_node[type(node).__name__] = children -# return pos -# -# traverse(tree) -# return json_tree -# # return json.dumps(json_tree, separators=(',', ':'), ensure_ascii=False) -def parse_file_with_parents(code): - global c, d - tree = ast.parse(code) - json_tree = [] +from pprint import pprint +if __name__=='__main__': + print('something') - def gen_identifier(identifier, node_type='identifier', parent=None): - pos = len(json_tree) - json_node = {} - json_tree.append(json_node) - json_node['type'] = node_type - json_node['value'] = identifier - if parent: - json_node['parent'] = type(parent).__name__ - else: - json_node['parent'] = None - return pos + #[26045, 28475] - def traverse_list(l, node_type='list'): - pos = len(json_tree - json_node = {} - json_tree.append(json_node) - json_node['type'] = node_type - if parent: - json_node['parent'] = type(parent).__name__ - else: - json_node['parent'] = None - children = [] - for item in l: - children.append(traverse(item)) - if (len(children) != 0): - json_node['children'] = children - return pos + path = '../resources/data/python/final/jsonl/train_old/python_train_0.jsonl.gz' + s_path = '../resources/data/python/final/jsonl/train/python_train_0_dfs_parent.jsonl.gz' - def traverse(node): - pos = len(json_tree) - json_node = {} - json_tree.append(json_node) - json_node['type'] = type(node).__name__ - if parent: - json_node['parent'] = type(parent).__name__ - else: - json_node['parent'] = None - children = [] - if isinstance(node, ast.Name): - json_node['value'] = node.id - elif isinstance(node, ast.Num): - json_node['value'] = unicode(node.n) - elif isinstance(node, ast.Str): - json_node['value'] = node.s.decode('utf-8') - elif isinstance(node, ast.alias): - json_node['value'] = unicode(node.name) - if node.asname: - children.append(gen_identifier(node.asname)) - elif isinstance(node, ast.FunctionDef): - json_node['value'] = unicode(node.name) - elif isinstance(node, ast.ClassDef): - json_node['value'] = unicode(node.name) - elif isinstance(node, ast.ImportFrom): - if node.module: - json_node['value'] = unicode(node.module) - elif isinstance(node, ast.Global): - for n in node.names: - children.append(gen_identifier(n)) - elif isinstance(node, ast.keyword): - json_node['value'] = unicode(node.arg) + a = RichPath.create(path) + s = RichPath.create(s_path) - # Process children. 
- if isinstance(node, ast.For): - children.append(traverse(node.target)) - children.append(traverse(node.iter)) - children.append(traverse_list(node.body, 'body')) - if node.orelse: - children.append(traverse_list(node.orelse, 'orelse')) - elif isinstance(node, ast.If) or isinstance(node, ast.While): - children.append(traverse(node.test)) - children.append(traverse_list(node.body, 'body')) - if node.orelse: - children.append(traverse_list(node.orelse, 'orelse')) - elif isinstance(node, ast.With): - children.append(traverse(node.context_expr)) - if node.optional_vars: - children.append(traverse(node.optional_vars)) - children.append(traverse_list(node.body, 'body')) - elif isinstance(node, ast.Try): - children.append(traverse_list(node.body, 'body')) - children.append(traverse_list(node.handlers, 'handlers')) - if node.orelse: - children.append(traverse_list(node.orelse, 'orelse')) - if node.finalbody: - children.append(traverse_list(node.finalbody, 'finalbody')) - elif isinstance(node, ast.arguments): - children.append(traverse_list(node.args, 'args')) - children.append(traverse_list(node.defaults, 'defaults')) - if node.vararg: - children.append(gen_identifier(node.vararg, 'vararg')) - if node.kwarg: - children.append(gen_identifier(node.kwarg, 'kwarg')) - elif isinstance(node, ast.ExceptHandler): - if node.type: - children.append(traverse_list([node.type], 'type')) - if node.name: - children.append(traverse_list([node.name], 'name')) - children.append(traverse_list(node.body, 'body')) - elif isinstance(node, ast.ClassDef): - children.append(traverse_list(node.bases, 'bases')) - children.append(traverse_list(node.body, 'body')) - children.append(traverse_list(node.decorator_list, 'decorator_list')) - elif isinstance(node, ast.FunctionDef): - children.append(traverse(node.args)) - children.append(traverse_list(node.body, 'body')) - children.append(traverse_list(node.decorator_list, 'decorator_list')) - else: - # Default handling: iterate over children. - for child in ast.iter_child_nodes(node): - if isinstance(child, ast.expr_context) or isinstance(child, ast.operator) or isinstance(child, - ast.boolop) or isinstance( - child, ast.unaryop) or isinstance(child, ast.cmpop): - # Directly include expr_context, and operators into the type instead of creating a child. 
- json_node['type'] = json_node['type'] + type(child).__name__ - else: - children.append(traverse(child)) + print('started') + b = list(a.read_as_jsonl()) - if isinstance(node, ast.Attribute): - children.append(gen_identifier(node.attr, 'attr')) + b = sorted(b, key=lambda v: len(v['code_tokens'])) - if (len(children) != 0): - json_node['children'] = children - return pos + templist = [] + c = [] + for idx, sample in enumerate(b[10000:30000],10000): + print("sample {} in progress".format(idx)) + # print(sample['code']) - traverse(tree) - return json_tree + if idx == 19 or sample['sha']=='618d6bff71073c8c93501ab7392c3cc579730f0b': + print(sample['code']) + dfs, parent_dfs = convert_code_to_tokens(sample['code']) + if dfs == [] or parent_dfs==[]: + templist.append(idx) + else: + b[idx]['code_tokens'] = dfs + b[idx]['parent_dfs'] = parent_dfs + c.append(b[idx]) -# [{'children': [1], 'type': 'Module'}, -# {'children': [2, 3], 'type': 'Assign'}, -# {'type': 'NameStore', 'value': 'ip'}, -# {'children': [4, 7], 'type': 'Call'}, -# {'children': [5, 6], 'type': 'AttributeLoad'}, -# {'type': 'NameLoad', 'value': 'socket'}, -# {'type': 'attr', 'value': 'gethostbyname'}, -# {'type': 'NameLoad', 'value': 'host'}] + s.save_as_compressed_file(c) + # df = pd.DataFrame([dfs, parent_dfs]) + # print(parent_dfs) + print('finished', templist, len(templist), len(c)) -from pprint import pprint -if __name__=='__main__': - print('something') + # code= '''def f(a, b=1, c=2, *d, e, f=3, **g): + # pass''' + # + # code = b[2]['code'] + # print(code) + # code = '''ip = socket.gethostbyname(host)''' + # + # code = '''ip = socket.gethostbyname(host)\n[ port , request_size , num_requests , num_conns ] = map ( + # string .atoi , sys . argv [2:] + # )\nchain = build_request_chain ( num_requests , host , request_size )''' + + # code = '''from foo import bar as b, car as c, dar as d''' + # print(convert_code_to_tokens(code)) # code ='''print('something') # try: @@ -485,16 +102,15 @@ def traverse(node): # finally: # return 2''' - # code= '''def f(a, b=1, c=2, *d, e, f=3, **g): - # pass''' - code = '''ip = socket.gethostbyname(host)''' - # code = '''func(a, b=c, *d, **e)''' - # a, b = convert_code_to_tokens(code) - # df = pd.DataFrame([a, b]) - # print(df.T) - result_tree = updated_parse_file(code) - # print(pd.read_json(result_tree)) - print(result_tree) \ No newline at end of file +# # code = '''func(a, b=c, *d, **e)''' +# # a, b = parse_file_with_parents(code) +# # df = pd.DataFrame([a, b]) +# # print(df.T) +# +# result_tree = parse_file_with_parents(code) + # # + # # # print(pd.read_json(result_tree)) + # pprint(result_tree) \ No newline at end of file diff --git a/script/parent_node_parse_helpers.py b/script/parent_node_parse_helpers.py new file mode 100644 index 00000000..09bcb00f --- /dev/null +++ b/script/parent_node_parse_helpers.py @@ -0,0 +1,337 @@ + +unicode = lambda s: str(s) +import ast +from pprint import pprint +import pandas as pd + +def create_tree_without_parents(code): + global c, d + tree = ast.parse(code) + + json_tree = [] + + def gen_identifier(identifier, node_type='identifier'): + pos = len(json_tree) + json_node = {} + json_tree.append(json_node) + json_node['type'] = node_type + json_node['value'] = identifier + return pos + + def traverse_list(l, node_type='list'): + pos = len(json_tree) + json_node = {} + json_tree.append(json_node) + json_node['type'] = node_type + children = [] + for item in l: + children.append(traverse(item)) + if (len(children) != 0): + json_node['children'] = children + return 
pos + + def traverse(node): + pos = len(json_tree) + json_node = {} + json_tree.append(json_node) + json_node['type'] = type(node).__name__ + children = [] + if isinstance(node, ast.Name): + json_node['value'] = node.id + elif isinstance(node, ast.Num): + json_node['value'] = unicode(node.n) + elif isinstance(node, ast.Str): + json_node['value'] = node.s + elif isinstance(node, ast.alias): + json_node['value'] = unicode(node.name) + if node.asname: + children.append(gen_identifier(node.asname)) + elif isinstance(node, ast.FunctionDef): + json_node['value'] = unicode(node.name) + elif isinstance(node, ast.ClassDef): + json_node['value'] = unicode(node.name) + elif isinstance(node, ast.ImportFrom): + if node.module: + json_node['value'] = unicode(node.module) + elif isinstance(node, ast.Global): + for n in node.names: + children.append(gen_identifier(n)) + elif isinstance(node, ast.keyword): + json_node['value'] = unicode(node.arg) + + # Process children. + if isinstance(node, ast.For): + children.append(traverse(node.target)) + children.append(traverse(node.iter)) + children.append(traverse_list(node.body, 'body')) + if node.orelse: + children.append(traverse_list(node.orelse, 'orelse')) + elif isinstance(node, ast.If) or isinstance(node, ast.While): + children.append(traverse(node.test)) + children.append(traverse_list(node.body, 'body')) + if node.orelse: + children.append(traverse_list(node.orelse, 'orelse')) + elif isinstance(node, ast.With): + children.append(traverse(node.context_expr)) + if node.optional_vars: + children.append(traverse(node.optional_vars)) + children.append(traverse_list(node.body, 'body')) + elif isinstance(node, ast.Try): + children.append(traverse_list(node.body, 'body')) + children.append(traverse_list(node.handlers, 'handlers')) + if node.orelse: + children.append(traverse_list(node.orelse, 'orelse')) + if node.finalbody: + children.append(traverse_list(node.finalbody, 'finalbody')) + elif isinstance(node, ast.arguments): + children.append(traverse_list(node.args, 'args')) + children.append(traverse_list(node.defaults, 'defaults')) + if node.vararg: + children.append(gen_identifier(node.vararg, 'vararg')) + if node.kwarg: + children.append(gen_identifier(node.kwarg, 'kwarg')) + elif isinstance(node, ast.ExceptHandler): + if node.type: + children.append(traverse_list([node.type], 'type')) + if node.name: + children.append(traverse_list([node.name], 'name')) + children.append(traverse_list(node.body, 'body')) + elif isinstance(node, ast.ClassDef): + children.append(traverse_list(node.bases, 'bases')) + children.append(traverse_list(node.body, 'body')) + children.append(traverse_list(node.decorator_list, 'decorator_list')) + elif isinstance(node, ast.FunctionDef): + children.append(traverse(node.args)) + children.append(traverse_list(node.body, 'body')) + children.append(traverse_list(node.decorator_list, 'decorator_list')) + else: + # Default handling: iterate over children. + for child in ast.iter_child_nodes(node): + if isinstance(child, ast.expr_context) or isinstance(child, ast.operator) or isinstance(child, + ast.boolop) or isinstance( + child, ast.unaryop) or isinstance(child, ast.cmpop): + # Directly include expr_context, and operators into the type instead of creating a child. 
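+                    # e.g. a Name node whose ctx is Load becomes a single node of
+                    # type "NameLoad" instead of a Name node with a separate ctx child.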
+                    json_node['type'] = json_node['type'] + type(child).__name__
+                else:
+                    children.append(traverse(child))
+
+        if isinstance(node, ast.Attribute):
+            children.append(gen_identifier(node.attr, 'attr'))
+
+        if (len(children) != 0):
+            json_node['children'] = children
+        return pos
+
+    traverse(tree)
+    return json_tree
+
+
+def get_docstring(node, clean=True):
+    """
+    Return the docstring for the given node or None if no docstring can
+    be found. If the node provided cannot have docstrings, a TypeError
+    will be raised.
+
+    If *clean* is `True`, all tabs are expanded to spaces and any whitespace
+    that can be uniformly removed from the second line onwards is removed.
+    """
+    if not isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
+        raise TypeError("%r can't have docstrings" % node.__class__.__name__)
+    if not (node.body and isinstance(node.body[0], ast.Expr)):
+        return None
+    node = node.body[0].value
+    if isinstance(node, ast.Str):
+        text = node.s
+    # elif isinstance(node, Constant) and isinstance(node.value, str):
+    #     text = node.value
+    else:
+        return None
+    if clean:
+        import inspect
+        text = inspect.cleandoc(text)
+    return text
+
+
+def dfs_traversal_with_parents(tree):
+    global c, d
+
+    docstring = ''
+    json_tree = []
+
+    def gen_identifier(identifier, node_type='identifier', parent=None):
+        pos = len(json_tree)
+        json_node = {}
+        json_tree.append(json_node)
+        json_node['type'] = node_type
+        json_node['value'] = identifier
+
+        if parent:
+            if hasattr(parent, 'ctx'):
+                json_node['parent'] = type(parent).__name__ + type(parent.ctx).__name__
+            else:
+                json_node['parent'] = type(parent).__name__
+        else:
+            json_node['parent'] = None
+        return pos
+
+    def traverse_list(l, node_type='list', parent=None):
+        pos = len(json_tree)
+        json_node = {}
+        json_tree.append(json_node)
+        json_node['type'] = node_type
+        if parent:
+            if hasattr(parent, 'ctx'):
+                json_node['parent'] = type(parent).__name__ + type(parent.ctx).__name__
+            else:
+                json_node['parent'] = type(parent).__name__
+        else:
+            json_node['parent'] = None
+        children = []
+        for item in l:
+            if item:
+                children.append(traverse(item, node_type))
+        if (len(children) != 0):
+            json_node['children'] = children
+        return pos
+
+    def traverse(node, parent=None):
+        # `docstring` lives in the enclosing dfs_traversal_with_parents scope,
+        # so it must be declared nonlocal (not global) for the update below to work.
+        nonlocal docstring
+        pos = len(json_tree)
+        if not (isinstance(node, ast.Str) and docstring == node.s):
+            json_node = {}
+            json_tree.append(json_node)
+            json_node['type'] = type(node).__name__
+            if parent:
+                if type(parent) == str:
+                    json_node['parent'] = parent
+                elif hasattr(parent, 'ctx'):
+                    json_node['parent'] = type(parent).__name__ + type(parent.ctx).__name__
+                else:
+                    json_node['parent'] = type(parent).__name__
+            else:
+                json_node['parent'] = None
+            children = []
+            if isinstance(node, ast.Name):
+                json_node['value'] = node.id
+            elif isinstance(node, ast.Num):
+                json_node['value'] = unicode(node.n)
+            elif isinstance(node, ast.Str):
+                if docstring != node.s:
+                    json_node['value'] = node.s
+            elif isinstance(node, ast.alias):
+                json_node['value'] = unicode(node.name)
+                if node.asname:
+                    json_node['value'] = unicode(node.name) + " as " + str(node.asname)
+                    # children.append(gen_identifier(node.asname, 'asname', node))
+            elif isinstance(node, ast.FunctionDef):
+                docstring = get_docstring(node, clean=False)
+                json_node['value'] = unicode(node.name)
+            elif isinstance(node, ast.ClassDef):
+                json_node['value'] = unicode(node.name)
+            elif isinstance(node, ast.ImportFrom):
+                if node.module:
+                    json_node['value'] = unicode(node.module)
+                # if node.names:
+                # 
children.append(traverse_list(node.names, 'imports', node)) + + elif isinstance(node, ast.Global): + for n in node.names: + children.append(gen_identifier(n, 'name', node)) + elif isinstance(node, ast.keyword): + json_node['value'] = unicode(node.arg) + elif isinstance(node, ast.arg): + json_node['value'] = unicode(node.arg) + + # Process children. + if isinstance(node, ast.For): + children.append(traverse(node.target, node)) + children.append(traverse(node.iter, node)) + children.append(traverse_list(node.body, 'body', node)) + if node.orelse: + children.append(traverse_list(node.orelse, 'orelse', node)) + elif isinstance(node, ast.If) or isinstance(node, ast.While): + children.append(traverse(node.test, node)) + children.append(traverse_list(node.body, 'body', node)) + if node.orelse: + children.append(traverse_list(node.orelse, 'orelse', node)) + elif isinstance(node, ast.With): + for item in node.items: + children.append(traverse(item.context_expr, node)) + if item.optional_vars: + children.append(traverse(item.optional_vars, node)) + children.append(traverse_list(node.body, 'body', node)) + elif isinstance(node, ast.Try): + children.append(traverse_list(node.body, 'body', node)) + children.append(traverse_list(node.handlers, 'handlers', node)) + if node.orelse: + children.append(traverse_list(node.orelse, 'orelse', node)) + if node.finalbody: + children.append(traverse_list(node.finalbody, 'finalbody', node)) + elif isinstance(node, ast.arguments): + if node.args: + children.append(traverse_list(node.args, 'args', node)) + if node.defaults: + children.append(traverse_list(node.defaults, 'defaults', node)) + if node.vararg: + children.append(gen_identifier(node.vararg.arg, 'vararg', node)) + if node.kwarg: + children.append(gen_identifier(node.kwarg.arg, 'kwarg', node)) + if node.kwonlyargs: + children.append(traverse_list(node.kwonlyargs, 'kwonlyargs', node)) + if node.kw_defaults: + children.append(traverse_list(node.kw_defaults, 'kw_defaults', node)) + + elif isinstance(node, ast.ExceptHandler): + if node.type: + children.append(traverse(node.type)) + # if node.name: + # children.append(traverse(node.name)) + children.append(traverse_list(node.body, 'body', node)) + elif isinstance(node, ast.ClassDef): + children.append(traverse_list(node.bases, 'bases', node)) + children.append(traverse_list(node.body, 'body', node)) + children.append(traverse_list(node.decorator_list, 'decorator_list', node)) + elif isinstance(node, ast.FunctionDef): + children.append(traverse(node.args, node)) + children.append(traverse_list(node.body, 'body', node)) + if node.decorator_list: + children.append(traverse_list(node.decorator_list, 'decorator_list', node)) + else: + # Default handling: iterate over children. + for child in ast.iter_child_nodes(node): + if isinstance(child, ast.expr_context) or isinstance(child, ast.operator) or isinstance(child, + ast.boolop) or isinstance( + child, ast.unaryop) or isinstance(child, ast.cmpop): + # Directly include expr_context, and operators into the type instead of creating a child. 
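+                        # Same folding as create_tree_without_parents (e.g. "NameStore");
+                        # the folded type is also what child nodes record as their parent.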
+ json_node['type'] = json_node['type'] + type(child).__name__ + else: + children.append(traverse(child, node)) + + if isinstance(node, ast.Attribute): + children.append(gen_identifier(node.attr, 'attribute', node)) + + if (len(children) != 0): + json_node['children'] = children + return pos + + traverse(tree) + + dfs_list = [] + parent_dfs = [] + for node in json_tree: + parent_dfs.append(node['parent']) + dfs_list.append(node['type']) + value = node.get('value', None) + if value: + dfs_list.append(value) + parent_dfs.append(node['type']) + + # df = pd.DataFrame([dfs_list, parent_dfs]) + # print(df.T) + + # pprint(json_tree) + + return dfs_list, parent_dfs + # return json_tree + From 262ed9d2c4d73a6664d45349028e463b4b2e0f29 Mon Sep 17 00:00:00 2001 From: faizan khan Date: Fri, 10 Apr 2020 18:20:08 -0400 Subject: [PATCH 4/9] resolve the indexing issue --- script/parent_node_pairs.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/script/parent_node_pairs.py b/script/parent_node_pairs.py index 31cbe743..50219a53 100644 --- a/script/parent_node_pairs.py +++ b/script/parent_node_pairs.py @@ -44,8 +44,8 @@ def convert_code_to_tokens(code): #[26045, 28475] - path = '../resources/data/python/final/jsonl/train_old/python_train_0.jsonl.gz' - s_path = '../resources/data/python/final/jsonl/train/python_train_0_dfs_parent.jsonl.gz' + path = '../resources/data/python/final/jsonl/train_old/temp_train_10.jsonl.gz' + s_path = '../resources/data/python/final/jsonl/train/temp_train_10_dfs_parent.jsonl.gz' a = RichPath.create(path) s = RichPath.create(s_path) @@ -54,10 +54,9 @@ def convert_code_to_tokens(code): b = list(a.read_as_jsonl()) b = sorted(b, key=lambda v: len(v['code_tokens'])) - templist = [] c = [] - for idx, sample in enumerate(b[10000:30000],10000): + for idx, sample in enumerate(b): print("sample {} in progress".format(idx)) # print(sample['code']) From 3c9373f27abce583f5efa9646e2a96ad7ec80b8e Mon Sep 17 00:00:00 2001 From: faizan khan Date: Mon, 13 Apr 2020 03:09:41 -0400 Subject: [PATCH 5/9] final self-attention with parent nodes --- script/parent_node_pairs.py | 4 +- src/encoders/masked_seq_encoder.py | 9 +- src/encoders/self_att_encoder.py | 20 ++- src/encoders/seq_encoder.py | 19 ++- src/encoders/utils/bert_self_attention.py | 192 ++++++++++++++++++++-- src/models/model.py | 13 +- src/models/self_att_model.py | 1 + src/train.py | 9 +- 8 files changed, 237 insertions(+), 30 deletions(-) diff --git a/script/parent_node_pairs.py b/script/parent_node_pairs.py index 50219a53..c2f99406 100644 --- a/script/parent_node_pairs.py +++ b/script/parent_node_pairs.py @@ -44,8 +44,8 @@ def convert_code_to_tokens(code): #[26045, 28475] - path = '../resources/data/python/final/jsonl/train_old/temp_train_10.jsonl.gz' - s_path = '../resources/data/python/final/jsonl/train/temp_train_10_dfs_parent.jsonl.gz' + path = '../resources/data/python/final/jsonl/valid_old/temp_valid_10.jsonl.gz' + s_path = '../resources/data/python/final/jsonl/valid/temp_valid_10_dfs_parent.jsonl.gz' a = RichPath.create(path) s = RichPath.create(s_path) diff --git a/src/encoders/masked_seq_encoder.py b/src/encoders/masked_seq_encoder.py index 0d696e6b..66f8e5e0 100755 --- a/src/encoders/masked_seq_encoder.py +++ b/src/encoders/masked_seq_encoder.py @@ -29,12 +29,19 @@ def _make_placeholders(self): shape=[None, self.get_hyper('max_num_tokens')], name='tokens_mask') - def init_minibatch(self, batch_data: Dict[str, Any]) -> None: + def init_minibatch(self, batch_data: Dict[str, Any], code=True) -> None: 
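+        # `code` marks the code-side encoder, which also buffers parent tokens;
+        # the query encoder calls this with code=False (see src/models/model.py).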
super().init_minibatch(batch_data) batch_data['tokens'] = [] batch_data['tokens_mask'] = [] + if self.hyperparameters['use_parent'] and code: + batch_data['parent_tokens'] = [] + batch_data['parent_tokens_mask'] = [] def minibatch_to_feed_dict(self, batch_data: Dict[str, Any], feed_dict: Dict[tf.Tensor, Any], is_train: bool) -> None: super().minibatch_to_feed_dict(batch_data, feed_dict, is_train) write_to_feed_dict(feed_dict, self.placeholders['tokens'], batch_data['tokens']) write_to_feed_dict(feed_dict, self.placeholders['tokens_mask'], batch_data['tokens_mask']) + + if self.hyperparameters['use_parent'] and batch_data.get('parent_tokens', None): + write_to_feed_dict(feed_dict, self.placeholders['parent_tokens'], batch_data['parent_tokens']) + write_to_feed_dict(feed_dict, self.placeholders['parent_tokens_mask'], batch_data['parent_tokens_mask']) diff --git a/src/encoders/self_att_encoder.py b/src/encoders/self_att_encoder.py index 363dc1b0..a8e1f062 100755 --- a/src/encoders/self_att_encoder.py +++ b/src/encoders/self_att_encoder.py @@ -15,7 +15,7 @@ def get_default_hyperparameters(cls) -> Dict[str, Any]: 'self_attention_intermediate_size': 512, 'self_attention_num_layers': 3, 'self_attention_num_heads': 8, - 'self_attention_pool_mode': 'weighted_mean', + 'self_attention_pool_mode': 'weighted_mean' } hypers = super().get_default_hyperparameters() hypers.update(encoder_hypers) @@ -32,6 +32,18 @@ def make_model(self, is_train: bool = False) -> tf.Tensor: with tf.variable_scope("self_attention_encoder"): self._make_placeholders() + if self.label == "code" and self.hyperparameters['use_parent']: + self.placeholders['parent_tokens'] = tf.placeholder(tf.int32, + shape=[None, self.get_hyper('max_num_tokens')], + name='parent_tokens') + + self.placeholders['parent_tokens_mask'] = tf.placeholder(tf.int32, + shape=[None, self.get_hyper('max_num_tokens')], + name='parent_tokens_mask') + else: + self.placeholders['parent_tokens'] = None + self.placeholders['parent_tokens_mask'] = None + config = BertConfig(vocab_size=self.get_hyper('token_vocab_size'), hidden_size=self.get_hyper('self_attention_hidden_size'), num_hidden_layers=self.get_hyper('self_attention_num_layers'), @@ -40,9 +52,11 @@ def make_model(self, is_train: bool = False) -> tf.Tensor: model = BertModel(config=config, is_training=is_train, - input_ids=self.placeholders['tokens'], + input_ids= self.placeholders['tokens'], input_mask=self.placeholders['tokens_mask'], - use_one_hot_embeddings=False) + use_one_hot_embeddings=False, + parent_ids=self.placeholders['parent_tokens'], + parent_mask=self.placeholders['parent_tokens_mask']) output_pool_mode = self.get_hyper('self_attention_pool_mode').lower() if output_pool_mode == 'bert': diff --git a/src/encoders/seq_encoder.py b/src/encoders/seq_encoder.py index 79f96e2c..b3db6661 100755 --- a/src/encoders/seq_encoder.py +++ b/src/encoders/seq_encoder.py @@ -128,7 +128,8 @@ def load_data_from_sample(cls, data_to_load: Any, function_name: Optional[str], result_holder: Dict[str, Any], - is_test: bool = True) -> bool: + is_test: bool = True, + parent_tokens=False) -> bool: """ Saves two versions of both the code and the query: one using the docstring as the query and the other using the function-name as the query, and replacing the function name in the code with an out-of-vocab token. 
@@ -168,6 +169,17 @@ def load_data_from_sample(cls, result_holder[f'{encoder_label}_tokens_mask_{key}'] = tokens_mask result_holder[f'{encoder_label}_tokens_length_{key}'] = int(np.sum(tokens_mask)) + if parent_tokens: + parent_tokens[0] = Vocabulary.get_unk() + tokens, tokens_mask = \ + convert_and_pad_token_sequence(metadata['token_vocab'], list(parent_tokens), + hyperparameters[f'{encoder_label}_max_num_tokens']) + # Note that we share the result_holder with different encoders, and so we need to make our identifiers + # unique-ish + result_holder[f'{encoder_label}_parent_tokens_{key}'] = tokens + result_holder[f'{encoder_label}_parent_tokens_mask_{key}'] = tokens_mask + result_holder[f'{encoder_label}_parent_tokens_length_{key}'] = int(np.sum(tokens_mask)) + if result_holder[f'{encoder_label}_tokens_mask_{QueryType.DOCSTRING.value}'] is None or \ int(np.sum(result_holder[f'{encoder_label}_tokens_mask_{QueryType.DOCSTRING.value}'])) == 0: return False @@ -187,6 +199,11 @@ def extend_minibatch_by_sample(self, batch_data: Dict[str, Any], sample: Dict[st current_sample['tokens_mask'] = sample[f'{self.label}_tokens_mask_{query_type}'] current_sample['tokens_lengths'] = sample[f'{self.label}_tokens_length_{query_type}'] + if self.label == 'code': + current_sample['parent_tokens'] = sample[f'{self.label}_parent_tokens_{query_type}'] + current_sample['parent_tokens_mask'] = sample[f'{self.label}_parent_tokens_mask_{query_type}'] + current_sample['parent_tokens_lengths'] = sample[f'{self.label}_parent_tokens_length_{query_type}'] + # In the query, randomly add high-frequency tokens: # TODO: Add tokens with frequency proportional to their frequency in the vocabulary if is_train and self.label == 'query' and self.hyperparameters['query_random_token_frequency'] > 0.: diff --git a/src/encoders/utils/bert_self_attention.py b/src/encoders/utils/bert_self_attention.py index 76918e95..4cbeb997 100755 --- a/src/encoders/utils/bert_self_attention.py +++ b/src/encoders/utils/bert_self_attention.py @@ -18,6 +18,9 @@ from __future__ import division from __future__ import print_function +import tensorflow.contrib.eager as tfe + + import collections import copy import json @@ -137,7 +140,9 @@ def __init__(self, token_type_ids=None, use_one_hot_embeddings=True, scope=None, - embedded_input=None): + embedded_input=None, + parent_ids=None, + parent_mask=None): """Constructor for BertModel. Args: @@ -160,6 +165,7 @@ def __init__(self, ValueError: The config is invalid or one of the input tensor shapes is invalid. """ + config = copy.deepcopy(config) if not is_training: config.hidden_dropout_prob = 0.0 @@ -179,6 +185,7 @@ def __init__(self, with tf.variable_scope("embeddings"): if embedded_input is None: # Perform embedding lookup on the word ids. + #returns a vector of B x SeqLength x hidden_size (self.embedding_output, self.embedding_table) = embedding_lookup( input_ids=input_ids, vocab_size=config.vocab_size, @@ -186,6 +193,7 @@ def __init__(self, initializer_range=config.initializer_range, word_embedding_name="word_embeddings", use_one_hot_embeddings=use_one_hot_embeddings) + else: self.embedding_output = embedded_input @@ -212,18 +220,65 @@ def __init__(self, # Run the stacked transformer. # `sequence_output` shape = [batch_size, seq_length, hidden_size]. 
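+        # When parent_ids are given, the parent stream below gets its own embedding
+        # lookup and position/segment post-processing, and is handed to
+        # transformer_model as parent_tensor; otherwise the stock call is used.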
- self.all_encoder_layers = transformer_model( - input_tensor=self.embedding_output, - attention_mask=attention_mask, - hidden_size=config.hidden_size, - num_hidden_layers=config.num_hidden_layers, - num_attention_heads=config.num_attention_heads, - intermediate_size=config.intermediate_size, - intermediate_act_fn=get_activation(config.hidden_act), - hidden_dropout_prob=config.hidden_dropout_prob, - attention_probs_dropout_prob=config.attention_probs_dropout_prob, - initializer_range=config.initializer_range, - do_return_all_layers=True) + + if parent_ids is not None: + + if parent_mask is None: + parent_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32) + + parent_attention_mask = create_attention_mask_from_input_mask( + parent_ids, parent_mask) + + + parent_embedding_output, parent_embedding_table = embedding_lookup( + input_ids=parent_ids, + vocab_size=config.vocab_size, + embedding_size=config.hidden_size, + initializer_range=config.initializer_range, + word_embedding_name="word_embeddings", + use_one_hot_embeddings=use_one_hot_embeddings) + + parent_embedding_output = embedding_postprocessor( + input_tensor=parent_embedding_output, + use_token_type=True, + token_type_ids=token_type_ids, + token_type_vocab_size=config.type_vocab_size, + token_type_embedding_name="token_type_embeddings", + use_position_embeddings=True, + position_embedding_name="position_embeddings", + initializer_range=config.initializer_range, + max_position_embeddings=config.max_position_embeddings, + dropout_prob=config.hidden_dropout_prob) + + + + self.all_encoder_layers = transformer_model( + input_tensor=self.embedding_output, + attention_mask=attention_mask, + hidden_size=config.hidden_size, + num_hidden_layers=config.num_hidden_layers, + num_attention_heads=config.num_attention_heads, + intermediate_size=config.intermediate_size, + intermediate_act_fn=get_activation(config.hidden_act), + hidden_dropout_prob=config.hidden_dropout_prob, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + initializer_range=config.initializer_range, + do_return_all_layers=True, + parent_tensor=parent_embedding_output, + parent_attention_mask=parent_attention_mask) + else: + self.all_encoder_layers = transformer_model( + input_tensor=self.embedding_output, + attention_mask=attention_mask, + hidden_size=config.hidden_size, + num_hidden_layers=config.num_hidden_layers, + num_attention_heads=config.num_attention_heads, + intermediate_size=config.intermediate_size, + intermediate_act_fn=get_activation(config.hidden_act), + hidden_dropout_prob=config.hidden_dropout_prob, + attention_probs_dropout_prob=config.attention_probs_dropout_prob, + initializer_range=config.initializer_range, + do_return_all_layers=True) self.sequence_output = self.all_encoder_layers[-1] # The "pooler" converts the encoded sequence tensor of shape @@ -519,7 +574,6 @@ def create_attention_mask_from_input_mask(from_tensor, to_mask): return mask - def attention_layer(from_tensor, to_tensor, attention_mask=None, @@ -686,7 +740,7 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads, # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
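+  # NOTE: attention-probability dropout is disabled for these experiments; the
+  # original call is kept below, commented out.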
- attention_probs = dropout(attention_probs, attention_probs_dropout_prob) + # attention_probs = dropout(attention_probs, attention_probs_dropout_prob) # `value_layer` = [B, T, N, H] value_layer = tf.reshape( @@ -726,7 +780,9 @@ def transformer_model(input_tensor, hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, initializer_range=0.02, - do_return_all_layers=False): + do_return_all_layers=False, + parent_tensor=None, + parent_attention_mask=None): """Multi-headed, multi-layer Transformer from "Attention is All You Need". This is almost an exact implementation of the original Transformer encoder. @@ -786,6 +842,7 @@ def transformer_model(input_tensor, # the GPU/CPU but may not be free on the TPU, so we want to minimize them to # help the optimizer. prev_output = reshape_to_matrix(input_tensor) + old_attention_head = num_attention_heads all_layer_outputs = [] for layer_idx in range(num_hidden_layers): @@ -794,6 +851,37 @@ def transformer_model(input_tensor, with tf.variable_scope("attention"): attention_heads = [] + num_attention_heads = old_attention_head + if parent_tensor is not None: + with tf.variable_scope('parent'): + attention_head_size = int(hidden_size / num_attention_heads) + parent_shape = get_shape_list(parent_tensor, expected_rank=3) + parent_input_width = parent_shape[2] + + if parent_input_width != hidden_size: + raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % + (parent_input_width, hidden_size)) + + parent_reshaped_tensor = reshape_to_matrix(parent_tensor) + + attention_head = attention_layer( + from_tensor=layer_input, + to_tensor=parent_reshaped_tensor, + attention_mask=parent_attention_mask, + num_attention_heads=1, + size_per_head=attention_head_size, + attention_probs_dropout_prob=attention_probs_dropout_prob, + initializer_range=initializer_range, + do_return_2d_tensor=True, + batch_size=batch_size, + from_seq_length=seq_length, + to_seq_length=seq_length) + + attention_heads.append(attention_head) + old_attention_head = num_attention_heads + num_attention_heads = num_attention_heads-1 + parent_tensor = None + with tf.variable_scope("self"): attention_head = attention_layer( from_tensor=layer_input, @@ -948,4 +1036,74 @@ def assert_rank(tensor, expected_rank, name=None): raise ValueError( "For the tensor `%s` in scope `%s`, the actual rank " "`%d` (shape = %s) is not equal to the expected rank `%s`" % - (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) \ No newline at end of file + (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) + + + + +if __name__=='__main__': + print('something') + + x = tf.placeholder(tf.float32, shape=[None, 4]) + y = tf.placeholder(tf.float32, shape=[None, 4]) + + # test = tf.layers.dense( + # x, + # 2 * 3, + # activation=None, + # kernel_initializer=create_initializer(0.02)) + + + from_tensor = x + to_tensor = y + context = attention_layer(from_tensor, to_tensor, num_attention_heads=1, size_per_head=4, batch_size=1, from_seq_length=2, to_seq_length=2) + # transformer = transformer_model(from_tensor, + # attention_mask=None, + # hidden_size=4, + # num_hidden_layers=1, + # num_attention_heads=1, + # intermediate_size=4, + # intermediate_act_fn=get_activation('gelu'), + # hidden_dropout_prob=0.1, + # attention_probs_dropout_prob=0.1, + # initializer_range=0.02, + # do_return_all_layers=False) + # + # input_ids = tf.constant([[31, 51, 99, 100], [15, 5, 0, 200]]) + # input_mask = tf.constant([[1, 1, 1, 1], [1, 1, 0, 1]]) + # token_type_ids = tf.constant([[0, 0, 1, 1], 
[0, 2, 0, 1]])

+    # config = BertConfig(vocab_size=4, hidden_size=1,
+    #                     num_hidden_layers=1, num_attention_heads=1, intermediate_size=4)
+    #
+    # model = BertModel(config=config, is_training=True,
+    #                   input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
+    #
+    # return_value = model.get_pooled_output()

+    '''python
+    # Already been converted into WordPiece token ids
+    input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
+    input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
+    token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])

+    config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
+                                 num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)

+    model = modeling.BertModel(config=config, is_training=True,
+                               input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)

+    label_embeddings = tf.get_variable(...)
+    pooled_output = model.get_pooled_output()
+    logits = tf.matmul(pooled_output, label_embeddings)'''


+    sess = tf.InteractiveSession()

+    tf.global_variables_initializer().run()

+    # print('context(x,y) = {}'.format(sess.run([model], {x: [["t","2","a","a"], ["a","b","c","d"]]})))

+    print('context(x,y) = {}'.format(sess.run([context], {x:[[1,2,3,4], [11,12,13,19]], y:[[1,2,3,4], [11,12,13,19]]})))

+    print('context(x,y) = {}'.format(sess.run([context], {x:[[1,2,3,4], [11,12,13,19]], y:[[6,7,8,9], [20,21,22,23]]})))
\ No newline at end of file
diff --git a/src/models/model.py b/src/models/model.py
index 2420a19d..626fdef4 100755
--- a/src/models/model.py
+++ b/src/models/model.py
@@ -7,6 +7,7 @@
 from collections import defaultdict, OrderedDict
 from enum import Enum, auto
 from typing import List, Dict, Any, Iterable, Tuple, Optional, Union, Callable, Type, DefaultDict
+from tensorflow.python import debug as tf_debug
 
 import numpy as np
 import wandb
@@ -19,7 +20,7 @@
 
 LoadedSamples = Dict[str, List[Dict[str, Any]]]
 SampleId = Tuple[str, int]
 
-
+from pprint import pprint
 
 class RepresentationType(Enum):
     CODE = auto()
@@ -62,7 +63,8 @@ def parse_data_file(hyperparameters: Dict[str, Any],
                                                           raw_sample['code_tokens'],
                                                           function_name,
                                                           sample,
-                                                          is_test)
+                                                          is_test,
+                                                          raw_sample['parent_dfs'])
 
         use_query_flag = query_encoder_class.load_data_from_sample("query",
                                                                    hyperparameters,
@@ -71,6 +73,7 @@ def parse_data_file(hyperparameters: Dict[str, Any],
                                                                    function_name,
                                                                    sample,
                                                                    is_test)
+
         use_example = use_code_flag and use_query_flag
         results[language].append((use_example, sample))
     return results
@@ -152,7 +155,7 @@ def __init__(self,
 
         graph = tf.Graph()
         self.__sess = tf.Session(graph=graph, config=config)
-
+        # Save directory for TensorBoard logs.
self.__tensorboard_dir = log_save_dir
 
@@ -505,7 +508,7 @@ def __init_minibatch(self) -> Dict[str, Any]:
         for (language, language_encoder) in self.__code_encoders.items():
             batch_data['per_language_query_data'][language] = {}
             batch_data['per_language_query_data'][language]['query_sample_ids'] = []
-            self.__query_encoder.init_minibatch(batch_data['per_language_query_data'][language])
+            self.__query_encoder.init_minibatch(batch_data['per_language_query_data'][language], code=False)
             batch_data['per_language_code_data'][language] = {}
             batch_data['per_language_code_data'][language]['code_sample_ids'] = []
             language_encoder.init_minibatch(batch_data['per_language_code_data'][language])
@@ -723,6 +726,8 @@ def __run_epoch_in_batches(self, data: LoadedSamples, epoch_name: str, is_train:
             ops_to_run = {'loss': self.__ops['loss'], 'mrr': self.__ops['mrr']}
             if is_train:
                 ops_to_run['train_step'] = self.__ops['train_step']
+
+            # print(batch_data_dict)
             op_results = self.__sess.run(ops_to_run, feed_dict=batch_data_dict)
             assert not np.isnan(op_results['loss'])
 
diff --git a/src/models/self_att_model.py b/src/models/self_att_model.py
index c47fa200..269b7cd3 100755
--- a/src/models/self_att_model.py
+++ b/src/models/self_att_model.py
@@ -16,6 +16,7 @@ def get_default_hyperparameters(cls) -> Dict[str, Any]:
                         'code_use_subtokens': False,
                         'code_mark_subtoken_end': False,
                         'batch_size': 450,
+                        'use_parent': True
                         }
         hypers.update(super().get_default_hyperparameters())
         hypers.update(model_hypers)
diff --git a/src/train.py b/src/train.py
index f0eeaba6..d1d13cb3 100755
--- a/src/train.py
+++ b/src/train.py
@@ -134,6 +134,10 @@ def run(arguments, tag_in_vcs=False) -> None:
     batch_size = arguments.get('--batch-size')
     if batch_size:
         hyperparameters['batch_size'] = int(batch_size)
+    use_parent = arguments.get('--use-parent')
+    if not use_parent:
+        hyperparameters['use_parent'] = bool(use_parent)
+
 
     # hyperparameters['code_use_bpe'] = False
     # hyperparameters['query_use_bpe'] = False
@@ -201,9 +205,10 @@ def run(arguments, tag_in_vcs=False) -> None:
     args = docopt(__doc__)
     args['--model'] = 'selfatt'
     args['--dryrun'] = True
-    # args['--testrun'] = True
+    # args['--testrun'] = False
     args['--sequential'] = True
-    args['--max_epoch'] = 20
+    args['--max_num_epochs'] = 2
+    # args['--use-parent'] = False
    args['--batch-size'] = 2
 
     # run_and_debug(lambda: run(args), args['--debug'])

From 00e2e67cb11449e8fc0d9b6934b830a37b7f1672 Mon Sep 17 00:00:00 2001
From: faizan khan
Date: Mon, 13 Apr 2020 03:30:37 -0400
Subject: [PATCH 6/9] using parent node connection

---
 src/encoders/self_att_encoder.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/encoders/self_att_encoder.py b/src/encoders/self_att_encoder.py
index a8e1f062..338742ce 100755
--- a/src/encoders/self_att_encoder.py
+++ b/src/encoders/self_att_encoder.py
@@ -33,6 +33,7 @@ def make_model(self, is_train: bool = False) -> tf.Tensor:
             self._make_placeholders()
 
             if self.label == "code" and self.hyperparameters['use_parent']:
+                print("USING PARENT NODE CONNECTIONS")
                 self.placeholders['parent_tokens'] = tf.placeholder(tf.int32,
                                                                     shape=[None, self.get_hyper('max_num_tokens')],
                                                                     name='parent_tokens')

From 396cee8836adbaa92b8f455d804a8d8968dcae0a Mon Sep 17 00:00:00 2001
From: Faizan KHAN
Date: Mon, 13 Apr 2020 05:19:05 -0400
Subject: [PATCH 7/9] resolve script issues, seq encoder, update requirements

---
 script/parent_node_pairs.py | 8 ++++----
 src/encoders/seq_encoder.py | 2 +-
 src/gpurequirements.tx      | 1 +
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git
a/script/parent_node_pairs.py b/script/parent_node_pairs.py
index c2f99406..96e447d0 100644
--- a/script/parent_node_pairs.py
+++ b/script/parent_node_pairs.py
@@ -4,7 +4,7 @@
 from src.utils.codegen import *
 import subprocess
 
-from parent_node_parse_helpers import dfs_traversal_with_parents
+from .parent_node_parse_helpers import dfs_traversal_with_parents
 import pandas as pd
 import os
 
@@ -44,8 +44,8 @@ def convert_code_to_tokens(code):
 
     #[26045, 28475]
 
-    path = '../resources/data/python/final/jsonl/valid_old/temp_valid_10.jsonl.gz'
-    s_path = '../resources/data/python/final/jsonl/valid/temp_valid_10_dfs_parent.jsonl.gz'
+    path = 'resources/data/python/final/jsonl/old_train/python_train_0.jsonl.gz'
+    s_path = 'resources/data/python/final/jsonl/train/python_train_0_dfs_parent.jsonl.gz'
 
     a = RichPath.create(path)
     s = RichPath.create(s_path)
@@ -112,4 +112,4 @@ def convert_code_to_tokens(code):
 #     result_tree = parse_file_with_parents(code)
 # # #
 # # # print(pd.read_json(result_tree))
-    # pprint(result_tree)
\ No newline at end of file
+    # pprint(result_tree)
diff --git a/src/encoders/seq_encoder.py b/src/encoders/seq_encoder.py
index b3db6661..576721da 100755
--- a/src/encoders/seq_encoder.py
+++ b/src/encoders/seq_encoder.py
@@ -170,7 +170,7 @@ def load_data_from_sample(cls,
             result_holder[f'{encoder_label}_tokens_length_{key}'] = int(np.sum(tokens_mask))
 
         if parent_tokens:
-            parent_tokens[0] = Vocabulary.get_unk()
+            parent_tokens = [Vocabulary.get_unk() if token is None else token for token in parent_tokens]
             tokens, tokens_mask = \
                 convert_and_pad_token_sequence(metadata['token_vocab'], list(parent_tokens),
                                                hyperparameters[f'{encoder_label}_max_num_tokens'])
diff --git a/src/gpurequirements.tx b/src/gpurequirements.tx
index 8b19cbd1..b056805e 100644
--- a/src/gpurequirements.tx
+++ b/src/gpurequirements.tx
@@ -1,3 +1,4 @@
+2to3==1.0
 absl-py==0.9.0
 altair==3.2.0
 annoy==1.16.0

From a160cdd2cb0880abb6425b723d94538056edc5a5 Mon Sep 17 00:00:00 2001
From: faizan khan
Date: Mon, 13 Apr 2020 15:18:33 -0400
Subject: [PATCH 8/9] resolve the None issue

---
 script/parent_node_pairs.py | 4 ++--
 src/encoders/seq_encoder.py | 2 +-
 src/train.py                | 6 +++---
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/script/parent_node_pairs.py b/script/parent_node_pairs.py
index c2f99406..a0a39eba 100644
--- a/script/parent_node_pairs.py
+++ b/script/parent_node_pairs.py
@@ -44,8 +44,8 @@ def convert_code_to_tokens(code):
 
     #[26045, 28475]
 
-    path = '../resources/data/python/final/jsonl/valid_old/temp_valid_10.jsonl.gz'
-    s_path = '../resources/data/python/final/jsonl/valid/temp_valid_10_dfs_parent.jsonl.gz'
+    path = '../resources/data/python/final/jsonl/train_old/python_train_0.jsonl.gz'
+    s_path = '../resources/data/python/final/jsonl/train/python_train_0_parent_dfs.jsonl.gz'
 
     a = RichPath.create(path)
     s = RichPath.create(s_path)
diff --git a/src/encoders/seq_encoder.py b/src/encoders/seq_encoder.py
index b3db6661..576721da 100755
--- a/src/encoders/seq_encoder.py
+++ b/src/encoders/seq_encoder.py
@@ -170,7 +170,7 @@ def load_data_from_sample(cls,
             result_holder[f'{encoder_label}_tokens_length_{key}'] = int(np.sum(tokens_mask))
 
         if parent_tokens:
-            parent_tokens[0] = Vocabulary.get_unk()
+            parent_tokens = [Vocabulary.get_unk() if token is None else token for token in parent_tokens]
             tokens, tokens_mask = \
                 convert_and_pad_token_sequence(metadata['token_vocab'], list(parent_tokens),
                                                hyperparameters[f'{encoder_label}_max_num_tokens'])
diff --git a/src/train.py b/src/train.py
index d1d13cb3..f213617e 100755
---
a/src/train.py +++ b/src/train.py @@ -134,9 +134,9 @@ def run(arguments, tag_in_vcs=False) -> None: batch_size = arguments.get('--batch-size') if batch_size: hyperparameters['batch_size'] = int(batch_size) - use_parent = arguments.get('--use-parent') - if not use_parent: - hyperparameters['use_parent'] = bool(use_parent) + # use_parent = arguments.get('--use-parent') + # if not use_parent: + # hyperparameters['use_parent'] = bool(use_parent) # hyperparameters['code_use_bpe'] = False # hyperparameters['query_use_bpe'] = False From 1f86dc0425f90c8ec738a44e231ad6247f700863 Mon Sep 17 00:00:00 2001 From: faizan khan Date: Mon, 13 Apr 2020 20:24:13 -0400 Subject: [PATCH 9/9] add parent attention to immediate parent only --- src/encoders/utils/bert_self_attention.py | 25 +++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/encoders/utils/bert_self_attention.py b/src/encoders/utils/bert_self_attention.py index 4cbeb997..ecbbb066 100755 --- a/src/encoders/utils/bert_self_attention.py +++ b/src/encoders/utils/bert_self_attention.py @@ -229,6 +229,10 @@ def __init__(self, parent_attention_mask = create_attention_mask_from_input_mask( parent_ids, parent_mask) + identity = tf.eye(seq_length) + identity = tf.reshape(identity, [1, seq_length, seq_length]) + + parent_attention_mask = tf.tile(identity, [batch_size,1,1]) parent_embedding_output, parent_embedding_table = embedding_lookup( input_ids=parent_ids, @@ -718,10 +722,20 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads, # attention scores. # `attention_scores` = [B, N, F, T] attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + # if num_attention_heads == 1: + # attention_scores = tf.Print(attention_scores, [attention_scores[0][0][i] for i in range(10)], + # "ATTENTION MASK: unscaled attention scores\n", summarize=10) + # query_layer = tf.Print(query_layer, [query_layer[0][0][i] for i in range(10)], + # "ATTENTION MASK: query\n", summarize=10) + # key_layer = tf.Print(key_layer, [key_layer[0][0][i] for i in range(10)], + # "ATTENTION MASK: key\n", summarize=10) attention_scores = tf.multiply(attention_scores, 1.0 / math.sqrt(float(size_per_head))) if attention_mask is not None: + # if num_attention_heads==1: + # attention_mask = tf.Print(attention_mask, [attention_mask[0][i] for i in range(10)], "ATTENTION MASK: original_attention_masks\n", summarize=10) + # `attention_mask` = [B, 1, F, T] attention_mask = tf.expand_dims(attention_mask, axis=[1]) @@ -730,14 +744,25 @@ def transpose_for_scores(input_tensor, batch_size, num_attention_heads, # positions we want to attend and -10000.0 for masked positions. adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0 + # if num_attention_heads==1: + # adder = tf.Print(adder, [adder[0][0][i] for i in range(10)], "ATTENTION MASK: adders\n", summarize=10) + # attention_scores = tf.Print(attention_scores, [attention_scores[0][0][i] for i in range(10)], "ATTENTION MASK: raw attention_scores\n", summarize=10) + # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. attention_scores += adder + # if num_attention_heads==1: + # attention_scores = tf.Print(attention_scores, [attention_scores[0][0][i] for i in range(10)], "ATTENTION MASK: added attention_scores\n", summarize=10) # Normalize the attention scores to probabilities. 
# `attention_probs` = [B, N, F, T]
   attention_probs = tf.nn.softmax(attention_scores)
 
+  # if num_attention_heads == 1:
+  #     attention_probs = tf.Print(attention_probs, [attention_probs[0][0][i] for i in range(10)], "ATTENTION MASK: parent attention probs\n", summarize=10)
+  # else:
+  #     attention_probs = tf.Print(attention_probs, [attention_probs[0][0][i] for i in range(10)],
+  #                                "ATTENTION MASK: normal attention probs\n", summarize=10)
 
   # This is actually dropping out entire tokens to attend to, which might
   # seem a bit unusual, but is taken from the original Transformer paper.
   # attention_probs = dropout(attention_probs, attention_probs_dropout_prob)
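
Taken together, these patches change the encoder as follows: in the first transformer layer, one attention head attends from the code-token stream to the embedded parent-type stream while the remaining heads do ordinary self-attention, and PATCH 9/9 replaces that head's mask with an identity matrix so position i can only attend to parent i. A minimal NumPy sketch of such an identity-masked parent head (illustrative names and shapes, not the repository code):

import numpy as np

def parent_attention_head(tokens, parents, size_per_head=4, seed=0):
    """One attention head from tokens (B x T x H) to parents (B x T x H).

    With an identity attention mask, position i may only attend to parent i:
    each softmax row becomes (numerically) one-hot, so the head returns a
    projection of each token's own parent embedding.
    """
    _, T, H = tokens.shape
    rng = np.random.default_rng(seed)
    w_query = rng.normal(scale=0.02, size=(H, size_per_head))
    w_key = rng.normal(scale=0.02, size=(H, size_per_head))
    w_value = rng.normal(scale=0.02, size=(H, size_per_head))

    queries = tokens @ w_query                    # B x T x d
    keys = parents @ w_key                        # B x T x d
    values = parents @ w_value                    # B x T x d

    scores = queries @ keys.transpose(0, 2, 1)    # B x T x T
    scores /= np.sqrt(size_per_head)
    identity_mask = np.eye(T)[None]               # cf. the tf.eye + tf.tile in PATCH 9/9
    scores += (1.0 - identity_mask) * -10000.0    # BERT-style additive mask

    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)    # ~one-hot on the diagonal
    return probs @ values                         # row i is (almost exactly) values[:, i]

tokens = np.random.rand(2, 5, 8)                  # batch of 2, seq length 5, hidden 8
parents = np.random.rand(2, 5, 8)                 # parent-type embeddings, same shape
context = parent_attention_head(tokens, parents)
print(context.shape)                              # (2, 5, 4)

Under the identity mask the query/key projections are effectively redundant (a softmax over one unmasked position is always 1.0 there), so the head simply injects a learned value projection of each token's own parent embedding; this is the quantity the commented-out tf.Print calls above inspect.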