1
0
Fork 0

namespace-inject: decorator class

This commit is contained in:
Jörn-Michael Miehe 2023-04-04 23:06:31 +00:00
parent e9d0ad4f91
commit 3b1778314c

View file

@ -1,56 +1,53 @@
import ast import ast
import inspect import inspect
from textwrap import dedent from textwrap import dedent
from typing import Any, Callable from typing import Callable
################# #################
# MAGIC SECTION # # MAGIC SECTION #
################# #################
def ast_dump(sobject) -> None: class Shoehorn:
print(f"Dumping AST of {sobject}") local_fns: tuple[Callable, ...]
print(ast.dump(ast.parse(
dedent(inspect.getsource(sobject)),
), indent=4))
def __init__(self, *local_fn: Callable) -> None:
self.local_fns = local_fn
def shoehorn(f: Callable) -> Callable: @staticmethod
# ast_dump(f) def _get_func_def(object: Callable) -> tuple[ast.FunctionDef, ast.Module]:
object_ast = ast.parse(dedent(inspect.getsource(object)))
return next(
x for x in ast.walk(object_ast)
if isinstance(x, ast.FunctionDef) and x.name == object.__name__
), object_ast
def foo() -> None: def __call__(self, fn: Callable) -> Callable:
print(f"{pstr = }") # type:ignore # noqa: F821 fn_func_def, fn_ast = self._get_func_def(fn)
# ast_dump(foo) fn_func_def.body = [
foo_ast = ast.parse(dedent(inspect.getsource(foo))) self._get_func_def(lfn)[0]
foo_fn = next( for lfn in self.local_fns
x ] + fn_func_def.body
for x in ast.walk(foo_ast)
if isinstance(x, ast.FunctionDef) and x.name == "foo"
)
class Shoehorn(ast.NodeTransformer): fn_func_def.decorator_list = [
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
if node.name == f.__name__:
node.body.insert(0, foo_fn)
node.decorator_list = [
decorator decorator
for decorator in node.decorator_list for decorator in fn_func_def.decorator_list
if not ( if not (
isinstance(decorator, ast.Name) isinstance(decorator, ast.Call)
and decorator.id == shoehorn.__name__ and isinstance(decorator.func, ast.Name)
and decorator.func.id == Shoehorn.__name__
) )
] ]
return node
f_ast = ast.parse(dedent(inspect.getsource(f))) if (fn_sourcefile := inspect.getsourcefile(fn)) is None:
new_f_ast = ast.fix_missing_locations(Shoehorn().visit(f_ast)) fn_sourcefile = "<string>"
# print(ast.dump(new_f_ast, indent=4))
new_f_scope = {} exec(
exec(compile(new_f_ast, "<string>", "exec"), new_f_scope) compile(fn_ast, fn_sourcefile, "exec"),
fn_globals := {},
return new_f_scope[f.__name__] )
return fn_globals[fn.__name__]
################## ##################
# NORMAL SECTION # # NORMAL SECTION #
@ -72,9 +69,13 @@ def func2(pstr: str) -> None:
foo() foo()
@shoehorn def foo_local() -> None:
print(f"{pstr = }") # type:ignore # noqa: F821
@Shoehorn(foo_local)
def func3(pstr: str) -> None: def func3(pstr: str) -> None:
foo() foo_local()
def func_info(f: Callable) -> None: def func_info(f: Callable) -> None: