diff --git a/namespace-inject.py b/namespace-inject.py index fbad954..786b3c1 100644 --- a/namespace-inject.py +++ b/namespace-inject.py @@ -1,20 +1,64 @@ import ast import inspect from textwrap import dedent -from types import FunctionType from typing import Any, Callable +################# +# MAGIC SECTION # +################# + def ast_dump(sobject) -> None: + print(f"Dumping AST of {sobject}") print(ast.dump(ast.parse( - source=dedent(inspect.getsource(sobject)), + dedent(inspect.getsource(sobject)), ), indent=4)) -def func_info(f: Callable) -> None: - print(f"{f.__name__ } = {f}, {f.__code__.co_varnames = }") - f("bar") - f("baz") +def shoehorn(f: Callable) -> Callable: + # ast_dump(f) + + def foo() -> None: + print(f"{pstr = }") # type:ignore # noqa: F821 + + # ast_dump(foo) + foo_ast = ast.parse(dedent(inspect.getsource(foo))) + foo_fn = next( + x + for x in ast.walk(foo_ast) + if isinstance(x, ast.FunctionDef) and x.name == "foo" + ) + + class Shoehorn(ast.NodeTransformer): + def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: + if node.name == f.__name__: + node.body.insert(0, foo_fn) + node.decorator_list = [ + decorator + for decorator in node.decorator_list + if not ( + isinstance(decorator, ast.Name) + and decorator.id == shoehorn.__name__ + ) + ] + return node + + f_ast = ast.parse(dedent(inspect.getsource(f))) + new_f_ast = ast.fix_missing_locations(Shoehorn().visit(f_ast)) + # print(ast.dump(new_f_ast, indent=4)) + + new_f_scope = {} + exec(compile(new_f_ast, "", "exec"), new_f_scope) + + return new_f_scope[f.__name__] + +################## +# NORMAL SECTION # +################## + + +def foo(): + raise NotImplementedError() def func1(pstr: str) -> None: @@ -24,49 +68,25 @@ def func1(pstr: str) -> None: foo() -def shoehorn(f: Callable) -> Callable: - ast_dump(f) - f_ast = ast.parse(source=dedent(inspect.getsource(f))) - - def foo() -> None: - print(f"{pstr = }") # type:ignore # noqa: F821 - - ast_dump(foo) - # foo_ast = ast.parse(source=dedent(inspect.getsource(foo))) - - class Shoehorn(ast.NodeTransformer): - def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: - print(ast.dump(node, indent=4)) - return node - - new_f_ast = ast.fix_missing_locations(Shoehorn().visit(f_ast)) - print(ast.dump(new_f_ast, indent=4)) - - code = f.__code__.replace( - co_nlocals=2, - co_varnames=("pstr", "foo"), - ) - - return FunctionType( - name=f.__name__, - code=code, - globals=func1.__globals__, - ) +def func2(pstr: str) -> None: + foo() @shoehorn -def func2(pstr: str) -> None: - foo() # type:ignore # noqa: F821 +def func3(pstr: str) -> None: + foo() + + +def func_info(f: Callable) -> None: + try: + print(f"{f.__name__} = {f}, {f.__code__.co_varnames = }") + f("bar") + f("baz") + except Exception: + print(f"Function {f.__name__} is broken.") if __name__ == "__main__": func_info(func1) - # ast_dump(func1) func_info(func2) - - import json - - print(json.dumps({ - name: repr(getattr(func2.__code__, name)) - for name in dir(func2.__code__) - }, indent=2)) + func_info(func3)