diff --git a/reloading/reloading.py b/reloading/reloading.py index 59dbd1c..192d7da 100644 --- a/reloading/reloading.py +++ b/reloading/reloading.py @@ -7,18 +7,23 @@ from functools import partial, update_wrapper -# have to make our own partial in case someone wants to use reloading as a iterator without any arguments -# they would get a partial back because a call without a iterator argument is assumed to be a decorator. -# getting a "TypeError: 'functools.partial' object is not iterable" -# which is not really descriptive. -# hence we overwrite the iter to make sure that the error makes sense. -class no_iter_partial(partial): +class NoIterPartial(partial): + """ + have to make our own partial in case someone wants to use reloading as a iterator without any arguments + they would get a partial back because a call without a iterator argument is assumed to be a decorator. + getting a "TypeError: 'functools.partial' object is not iterable" + which is not really descriptive. + hence we overwrite the iter to make sure that the error makes sense. + """ def __iter__(self): - raise TypeError("Nothing to iterate over. Please pass an iterable to reloading.") + raise TypeError( + "Nothing to iterate over. Please pass an iterable to reloading." + ) def reloading(fn_or_seq=None, every=1, forever=None): - """Wraps a loop iterator or decorates a function to reload the source code + """ + Wraps a loop iterator or decorates a function to reload the source code before every loop iteration or function invocation. When wrapped around the outermost iterator in a `for` loop, e.g. @@ -37,29 +42,39 @@ def reloading(fn_or_seq=None, every=1, forever=None): every (int, Optional): After how many iterations/invocations to reload forever (bool, Optional): Pass `forever=true` instead of an iterator to create an endless loop - """ - if fn_or_seq: - if isinstance(fn_or_seq, types.FunctionType): - return _reloading_function(fn_or_seq, every=every) - return _reloading_loop(fn_or_seq, every=every) + if forever and fn_or_seq is not None: + raise ValueError( + "Cannot use `forever=True` and pass an iterator at the same time" + ) if forever: return _reloading_loop(iter(int, 1), every=every) + if fn_or_seq: + if isinstance(fn_or_seq, types.FunctionType): + return _reloading_function(fn_or_seq, every=every) + if hasattr(fn_or_seq, "__iter__"): + return _reloading_loop(fn_or_seq, every=every) + raise TypeError( + f"{reloading.__name__} expected function or iterable, got {type(fn_or_seq)}" + ) # return this function with the keyword arguments partialed in, # so that the return value can be used as a decorator - decorator = update_wrapper(no_iter_partial(reloading, every=every), reloading) + decorator = update_wrapper(NoIterPartial(reloading, every=every), reloading) return decorator def unique_name(used): - # get the longest element of the used names and append a "0" + """ + Get the longest element of the used names and append a "0" + """ return max(used, key=len) + "0" def format_itervars(ast_node): - """Formats an `ast_node` of loop iteration variables as string, e.g. 'a, b'""" - + """ + Formats an `ast_node` of loop iteration variables as string, e.g. 'a, b' + """ # handle the case that there only is a single loop var if isinstance(ast_node, ast.Name): return ast_node.id @@ -78,7 +93,7 @@ def format_itervars(ast_node): def load_file(path): src = "" # while loop here since while saving, the file may sometimes be empty. - while (src == ""): + while src == "": with open(path, "r") as f: src = f.read() return src + "\n" @@ -96,8 +111,10 @@ def parse_file_until_successful(path): def isolate_loop_body_and_get_itervars(tree, lineno, loop_id): - """Modifies tree inplace as unclear how to create ast.Module. - Returns itervars""" + """ + Modifies tree inplace as unclear how to create ast.Module. + Returns itervars + """ candidate_nodes = [] for node in ast.walk(tree): if ( @@ -105,10 +122,10 @@ def isolate_loop_body_and_get_itervars(tree, lineno, loop_id): and isinstance(node.iter, ast.Call) and node.iter.func.id == "reloading" and ( - (loop_id is not None and loop_id == get_loop_id(node)) - or getattr(node, "lineno", None) == lineno - ) - ): + (loop_id is not None and loop_id == get_loop_id(node)) + or getattr(node, "lineno", None) == lineno + ) + ): candidate_nodes.append(node) if len(candidate_nodes) > 1: @@ -127,7 +144,8 @@ def isolate_loop_body_and_get_itervars(tree, lineno, loop_id): def get_loop_id(ast_node): - """Generates a unique identifier for an `ast_node` of type ast.For to find the loop in the changed source file + """ + Generates a unique identifier for an `ast_node` of type ast.For to find the loop in the changed source file """ return ast.dump(ast_node.target) + "__" + ast.dump(ast_node.iter) @@ -137,18 +155,39 @@ def get_loop_code(loop_frame_info, loop_id): while True: tree = parse_file_until_successful(fpath) try: - itervars, found_loop_id = isolate_loop_body_and_get_itervars(tree, lineno=loop_frame_info[2], loop_id=loop_id) - return compile(tree, filename="", mode="exec"), format_itervars(itervars), found_loop_id + itervars, found_loop_id = isolate_loop_body_and_get_itervars( + tree, lineno=loop_frame_info[2], loop_id=loop_id + ) + return ( + compile(tree, filename="", mode="exec"), + format_itervars(itervars), + found_loop_id, + ) except LookupError: handle_exception(fpath) def handle_exception(fpath): exc = traceback.format_exc() - exc = exc.replace('File ""', 'File "{}"'.format(fpath)) + exc = exc.replace('File ""', f'File "{fpath}"') sys.stderr.write(exc + "\n") - print("Edit {} and press return to continue".format(fpath)) - sys.stdin.readline() + + if sys.stdin.isatty(): + print( + f"An error occurred. Please edit the file '{fpath}' to fix the issue and press return to continue or Ctrl+C to exit." + ) + try: + sys.stdin.readline() + except KeyboardInterrupt: + print("\nExiting...") + sys.exit(1) + else: + # get error line number + line_number = int(exc.split(", line ")[-1].split(",")[0]) + print(line_number) + raise Exception( + f"An error occurred. Please fix the issue in the file '{fpath}' and run the script again." + ) def _reloading_loop(seq, every=1): @@ -158,19 +197,21 @@ def _reloading_loop(seq, every=1): caller_globals = loop_frame_info[0].f_globals caller_locals = loop_frame_info[0].f_locals - # create a unique name in the caller namespace that we can safely write - # the values of the iteration variables into unique = unique_name(chain(caller_locals.keys(), caller_globals.keys())) loop_id = None for i, itervar_values in enumerate(seq): if i % every == 0: - compiled_body, itervars, loop_id = get_loop_code(loop_frame_info, loop_id=loop_id) + compiled_body, itervars, loop_id = get_loop_code( + loop_frame_info, loop_id=loop_id + ) caller_locals[unique] = itervar_values exec(itervars + " = " + unique, caller_globals, caller_locals) + print(itervars) try: # run main loop body + # print(f"{caller_locals.keys()}") exec(compiled_body, caller_globals, caller_locals) except Exception: handle_exception(fpath) @@ -191,32 +232,36 @@ def get_decorator_name_or_none(dec_node): def strip_reloading_decorator(func): """Remove the 'reloading' decorator and all decorators before it""" - decorator_names = [get_decorator_name(dec) for dec in func.decorator_list] + decorator_names = [get_decorator_name_or_none(dec) for dec in func.decorator_list] reloading_idx = decorator_names.index("reloading") - func.decorator_list = func.decorator_list[reloading_idx + 1:] + func.decorator_list = func.decorator_list[reloading_idx + 1 :] -def isolate_function_def(funcname, tree): +def isolate_function_def(qualname, fn, tree): """Strip everything but the function definition from the ast in-place. Also strips the reloading decorator from the function definition""" + length = len(qualname.split(".")) + funcname = qualname.split(".")[-1] + classname = qualname.split(".")[length - 2] if length > 1 else None + + found = False for node in ast.walk(tree): - if ( - isinstance(node, ast.FunctionDef) - and node.name == funcname - and "reloading" in [ - get_decorator_name_or_none(dec) - for dec in node.decorator_list - ] - ): - strip_reloading_decorator(node) - tree.body = [ node ] - return True - return False + if isinstance(node, ast.ClassDef) and node.name == classname: + for subnode in node.body: + if isinstance(subnode, ast.FunctionDef) and subnode.name == funcname: + if "reloading" in [ + get_decorator_name_or_none(dec) + for dec in subnode.decorator_list + ]: + strip_reloading_decorator(subnode) + tree.body = [subnode] + found = True + return found def get_function_def_code(fpath, fn): tree = parse_file_until_successful(fpath) - found = isolate_function_def(fn.__name__, tree) + found = isolate_function_def(fn.__qualname__, fn, tree) if not found: return None compiled = compile(tree, filename="", mode="exec") @@ -243,13 +288,16 @@ def _reloading_function(fn, every=1): # crutch to use dict as python2 doesn't support nonlocal state = { - "func": None, + "func": fn, "reloads": 0, } def wrapped(*args, **kwargs): if state["reloads"] % every == 0: - state["func"] = get_reloaded_function(caller_globals, caller_locals, fpath, fn) or state["func"] + state["func"] = ( + get_reloaded_function(caller_globals, caller_locals, fpath, fn) + or state["func"] + ) state["reloads"] += 1 while True: try: @@ -257,7 +305,10 @@ def wrapped(*args, **kwargs): return result except Exception: handle_exception(fpath) - state["func"] = get_reloaded_function(caller_globals, caller_locals, fpath, fn) or state["func"] + state["func"] = ( + get_reloaded_function(caller_globals, caller_locals, fpath, fn) + or state["func"] + ) caller_locals[fn.__name__] = wrapped return wrapped