generated from Yavook.de/vscode-python3
namespace-inject: decorator class
This commit is contained in:
parent
e9d0ad4f91
commit
3b1778314c
1 changed files with 40 additions and 39 deletions
|
@ -1,56 +1,53 @@
|
||||||
import ast
|
import ast
|
||||||
import inspect
|
import inspect
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any, Callable
|
from typing import Callable
|
||||||
|
|
||||||
#################
|
#################
|
||||||
# MAGIC SECTION #
|
# MAGIC SECTION #
|
||||||
#################
|
#################
|
||||||
|
|
||||||
|
|
||||||
def ast_dump(sobject) -> None:
|
class Shoehorn:
|
||||||
print(f"Dumping AST of {sobject}")
|
local_fns: tuple[Callable, ...]
|
||||||
print(ast.dump(ast.parse(
|
|
||||||
dedent(inspect.getsource(sobject)),
|
|
||||||
), indent=4))
|
|
||||||
|
|
||||||
|
def __init__(self, *local_fn: Callable) -> None:
|
||||||
|
self.local_fns = local_fn
|
||||||
|
|
||||||
def shoehorn(f: Callable) -> Callable:
|
@staticmethod
|
||||||
# ast_dump(f)
|
def _get_func_def(object: Callable) -> tuple[ast.FunctionDef, ast.Module]:
|
||||||
|
object_ast = ast.parse(dedent(inspect.getsource(object)))
|
||||||
|
return next(
|
||||||
|
x for x in ast.walk(object_ast)
|
||||||
|
if isinstance(x, ast.FunctionDef) and x.name == object.__name__
|
||||||
|
), object_ast
|
||||||
|
|
||||||
def foo() -> None:
|
def __call__(self, fn: Callable) -> Callable:
|
||||||
print(f"{pstr = }") # type:ignore # noqa: F821
|
fn_func_def, fn_ast = self._get_func_def(fn)
|
||||||
|
|
||||||
# ast_dump(foo)
|
fn_func_def.body = [
|
||||||
foo_ast = ast.parse(dedent(inspect.getsource(foo)))
|
self._get_func_def(lfn)[0]
|
||||||
foo_fn = next(
|
for lfn in self.local_fns
|
||||||
x
|
] + fn_func_def.body
|
||||||
for x in ast.walk(foo_ast)
|
|
||||||
if isinstance(x, ast.FunctionDef) and x.name == "foo"
|
|
||||||
)
|
|
||||||
|
|
||||||
class Shoehorn(ast.NodeTransformer):
|
fn_func_def.decorator_list = [
|
||||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
|
||||||
if node.name == f.__name__:
|
|
||||||
node.body.insert(0, foo_fn)
|
|
||||||
node.decorator_list = [
|
|
||||||
decorator
|
decorator
|
||||||
for decorator in node.decorator_list
|
for decorator in fn_func_def.decorator_list
|
||||||
if not (
|
if not (
|
||||||
isinstance(decorator, ast.Name)
|
isinstance(decorator, ast.Call)
|
||||||
and decorator.id == shoehorn.__name__
|
and isinstance(decorator.func, ast.Name)
|
||||||
|
and decorator.func.id == Shoehorn.__name__
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
return node
|
|
||||||
|
|
||||||
f_ast = ast.parse(dedent(inspect.getsource(f)))
|
if (fn_sourcefile := inspect.getsourcefile(fn)) is None:
|
||||||
new_f_ast = ast.fix_missing_locations(Shoehorn().visit(f_ast))
|
fn_sourcefile = "<string>"
|
||||||
# print(ast.dump(new_f_ast, indent=4))
|
|
||||||
|
|
||||||
new_f_scope = {}
|
exec(
|
||||||
exec(compile(new_f_ast, "<string>", "exec"), new_f_scope)
|
compile(fn_ast, fn_sourcefile, "exec"),
|
||||||
|
fn_globals := {},
|
||||||
return new_f_scope[f.__name__]
|
)
|
||||||
|
return fn_globals[fn.__name__]
|
||||||
|
|
||||||
##################
|
##################
|
||||||
# NORMAL SECTION #
|
# NORMAL SECTION #
|
||||||
|
@ -72,9 +69,13 @@ def func2(pstr: str) -> None:
|
||||||
foo()
|
foo()
|
||||||
|
|
||||||
|
|
||||||
@shoehorn
|
def foo_local() -> None:
|
||||||
|
print(f"{pstr = }") # type:ignore # noqa: F821
|
||||||
|
|
||||||
|
|
||||||
|
@Shoehorn(foo_local)
|
||||||
def func3(pstr: str) -> None:
|
def func3(pstr: str) -> None:
|
||||||
foo()
|
foo_local()
|
||||||
|
|
||||||
|
|
||||||
def func_info(f: Callable) -> None:
|
def func_info(f: Callable) -> None:
|
||||||
|
|
Loading…
Reference in a new issue