From 3b1778314cb6ba2c2b55ced89dd5d87da88149cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn-Michael=20Miehe?= Date: Tue, 4 Apr 2023 23:06:31 +0000 Subject: [PATCH] namespace-inject: decorator class --- namespace-inject.py | 79 +++++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 39 deletions(-) diff --git a/namespace-inject.py b/namespace-inject.py index 786b3c1..6293394 100644 --- a/namespace-inject.py +++ b/namespace-inject.py @@ -1,56 +1,53 @@ import ast import inspect from textwrap import dedent -from typing import Any, Callable +from typing import Callable ################# # MAGIC SECTION # ################# -def ast_dump(sobject) -> None: - print(f"Dumping AST of {sobject}") - print(ast.dump(ast.parse( - dedent(inspect.getsource(sobject)), - ), indent=4)) +class Shoehorn: + local_fns: tuple[Callable, ...] + def __init__(self, *local_fn: Callable) -> None: + self.local_fns = local_fn -def shoehorn(f: Callable) -> Callable: - # ast_dump(f) + @staticmethod + def _get_func_def(object: Callable) -> tuple[ast.FunctionDef, ast.Module]: + object_ast = ast.parse(dedent(inspect.getsource(object))) + return next( + x for x in ast.walk(object_ast) + if isinstance(x, ast.FunctionDef) and x.name == object.__name__ + ), object_ast - def foo() -> None: - print(f"{pstr = }") # type:ignore # noqa: F821 + def __call__(self, fn: Callable) -> Callable: + fn_func_def, fn_ast = self._get_func_def(fn) - # ast_dump(foo) - foo_ast = ast.parse(dedent(inspect.getsource(foo))) - foo_fn = next( - x - for x in ast.walk(foo_ast) - if isinstance(x, ast.FunctionDef) and x.name == "foo" - ) + fn_func_def.body = [ + self._get_func_def(lfn)[0] + for lfn in self.local_fns + ] + fn_func_def.body - class Shoehorn(ast.NodeTransformer): - def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: - if node.name == f.__name__: - node.body.insert(0, foo_fn) - node.decorator_list = [ - decorator - for decorator in node.decorator_list - if not ( - isinstance(decorator, ast.Name) - and decorator.id == shoehorn.__name__ - ) - ] - return node + fn_func_def.decorator_list = [ + decorator + for decorator in fn_func_def.decorator_list + if not ( + isinstance(decorator, ast.Call) + and isinstance(decorator.func, ast.Name) + and decorator.func.id == Shoehorn.__name__ + ) + ] - f_ast = ast.parse(dedent(inspect.getsource(f))) - new_f_ast = ast.fix_missing_locations(Shoehorn().visit(f_ast)) - # print(ast.dump(new_f_ast, indent=4)) + if (fn_sourcefile := inspect.getsourcefile(fn)) is None: + fn_sourcefile = "" - new_f_scope = {} - exec(compile(new_f_ast, "", "exec"), new_f_scope) - - return new_f_scope[f.__name__] + exec( + compile(fn_ast, fn_sourcefile, "exec"), + fn_globals := {}, + ) + return fn_globals[fn.__name__] ################## # NORMAL SECTION # @@ -72,9 +69,13 @@ def func2(pstr: str) -> None: foo() -@shoehorn +def foo_local() -> None: + print(f"{pstr = }") # type:ignore # noqa: F821 + + +@Shoehorn(foo_local) def func3(pstr: str) -> None: - foo() + foo_local() def func_info(f: Callable) -> None: