diff --git a/namespace-inject.py b/namespace-inject.py index 6293394..6e35ec1 100644 --- a/namespace-inject.py +++ b/namespace-inject.py @@ -3,48 +3,64 @@ import inspect from textwrap import dedent from typing import Callable -################# -# MAGIC SECTION # -################# +################## +# CURSED SECTION # +################## -class Shoehorn: +class namespace_inject: + """ + This is a decorator that injects one or more functions into + the local namespace of the decorated function. + """ + + # functions to be injected local_fns: tuple[Callable, ...] def __init__(self, *local_fn: Callable) -> None: self.local_fns = local_fn @staticmethod - def _get_func_def(object: Callable) -> tuple[ast.FunctionDef, ast.Module]: + def _get_func_def(object: Callable) -> tuple[ast.Module, ast.FunctionDef]: + """ + Get the AST representation and the contained ast.FunctionDef + of a function + """ object_ast = ast.parse(dedent(inspect.getsource(object))) - return next( + + return object_ast, next( x for x in ast.walk(object_ast) if isinstance(x, ast.FunctionDef) and x.name == object.__name__ - ), object_ast + ) def __call__(self, fn: Callable) -> Callable: - fn_func_def, fn_ast = self._get_func_def(fn) + """ + The actual decorator function + """ + fn_ast, fn_func_def = self._get_func_def(fn) + # prepend the local functions to the decorated function's body fn_func_def.body = [ - self._get_func_def(lfn)[0] + self._get_func_def(lfn)[1] for lfn in self.local_fns ] + fn_func_def.body + # remove the @namespace_inject decorator from the decorated function fn_func_def.decorator_list = [ - decorator - for decorator in fn_func_def.decorator_list + decorator for decorator in fn_func_def.decorator_list if not ( isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Name) - and decorator.func.id == Shoehorn.__name__ + and decorator.func.id == namespace_inject.__name__ ) ] + # recompile and return the decorated function if (fn_sourcefile := inspect.getsourcefile(fn)) is None: fn_sourcefile = "" exec( - compile(fn_ast, fn_sourcefile, "exec"), + compile(ast.fix_missing_locations(fn_ast), fn_sourcefile, "exec"), fn_globals := {}, ) return fn_globals[fn.__name__] @@ -54,30 +70,46 @@ class Shoehorn: ################## -def foo(): - raise NotImplementedError() +def foo() -> None: + # foo, but in global namespace it uses an unbound variable + + print(f"{pstr = }") # type:ignore # noqa: F821 -def func1(pstr: str) -> None: +def call_foo_global(pstr: str) -> None: + """ + Just call foo, so global foo is called. + This should fail: foo does not know about `pstr`. + """ + + foo() + + +def call_foo_local(pstr: str) -> None: + """ + Redefine foo locally, then call - local foo is called. + In local namespace, `pstr` is defined. + """ + # call_local_foo got its own foo, so it calls locally + def foo() -> None: + # foo in local namespace is fine + print(f"{pstr = }") foo() -def func2(pstr: str) -> None: +@namespace_inject(foo) +def call_foo_injected(pstr: str) -> None: + """ + Inject the global foo into the local namespace. + This behaves like `call_foo_local`. + """ + foo() -def foo_local() -> None: - print(f"{pstr = }") # type:ignore # noqa: F821 - - -@Shoehorn(foo_local) -def func3(pstr: str) -> None: - foo_local() - - def func_info(f: Callable) -> None: try: print(f"{f.__name__} = {f}, {f.__code__.co_varnames = }") @@ -88,6 +120,6 @@ def func_info(f: Callable) -> None: if __name__ == "__main__": - func_info(func1) - func_info(func2) - func_info(func3) + func_info(call_foo_global) + func_info(call_foo_local) + func_info(call_foo_injected)