2023-04-03 23:21:01 +00:00
|
|
|
import ast
|
|
|
|
import inspect
|
2023-04-04 23:06:31 +00:00
|
|
|
from typing import Callable
|
2023-04-03 23:21:01 +00:00
|
|
|
|
2023-04-04 23:45:21 +00:00
|
|
|
##################
|
|
|
|
# CURSED SECTION #
|
|
|
|
##################
|
|
|
|
|
2023-04-04 19:20:21 +00:00
|
|
|
|
2023-04-04 23:45:21 +00:00
|
|
|
class namespace_inject:
    """
    This is a decorator that injects one or more functions into
    the local namespace of the decorated function.

    It works on source code, not on function objects: the decorated
    function's module is re-parsed, the definitions of the injected
    functions are copied (as AST nodes) into the decorated function's
    body, and the whole module is recompiled and executed in a fresh,
    empty global namespace. The function returned by the decorator is
    taken from that fresh namespace.

    NOTE: this re-runs every top-level statement of the decorated
    function's module, and the decorated function must be defined at
    module top level (it is looked up by name in the new globals).
    """

    class InjectionError(ValueError, AssertionError):
        """
        Raised when a function's definition cannot be located in source.

        Also an AssertionError, for backward compatibility: these checks
        used to be plain ``assert`` statements.
        """

    # functions to be injected
    local_fns: tuple[Callable, ...]

    def __init__(self, *local_fn: Callable) -> None:
        # remember the functions whose definitions will be injected
        self.local_fns = local_fn

    @staticmethod
    def _get_func_def(fn: Callable) -> tuple[ast.Module, ast.FunctionDef]:
        """
        Get the AST representation and the contained ast.FunctionDef
        of a function.

        The whole module containing ``fn`` is parsed, and its
        ast.FunctionDef nodes (at any nesting depth) are searched for one
        named ``fn.__name__``.

        Returns:
            ``(module_ast, function_def)``: the parsed module and the single
            matching function definition node inside it.

        Raises:
            namespace_inject.InjectionError: if the module of ``fn`` cannot
                be determined, or if not exactly one function definition in
                that module is named ``fn.__name__``.
            Exceptions from ``inspect.getsource`` (e.g. OSError) propagate
            when the module source is unavailable.
        """
        # Explicit checks rather than `assert`: under `python -O` an
        # `assert (x := ...)` statement is removed entirely, so the binding
        # would never happen and `object_module` would be unbound.
        object_module = inspect.getmodule(fn)
        if object_module is None:
            raise namespace_inject.InjectionError(
                f"Cannot determine the module of {fn!r}"
            )

        module_ast = ast.parse(inspect.getsource(object_module))

        function_defs = [
            node for node in ast.walk(module_ast)
            if isinstance(node, ast.FunctionDef) and node.name == fn.__name__
        ]
        if len(function_defs) != 1:
            raise namespace_inject.InjectionError(
                f"Function {fn.__name__} must have a unique name in its module!"
            )

        return module_ast, function_defs[0]

    @staticmethod
    def _is_inject_decorator(decorator: ast.expr) -> bool:
        """
        Tell whether a decorator node is a call to this decorator.

        Recognises both ``@namespace_inject(...)`` and
        ``@some_module.namespace_inject(...)``; matching is by name only.
        """
        if not isinstance(decorator, ast.Call):
            return False
        target = decorator.func
        if isinstance(target, ast.Name):
            return target.id == namespace_inject.__name__
        if isinstance(target, ast.Attribute):
            return target.attr == namespace_inject.__name__
        return False

    def __call__(self, fn: Callable) -> Callable:
        """
        The actual decorator function.

        Args:
            fn: the (module top-level) function to decorate.

        Returns:
            A recompiled copy of ``fn`` whose body contains the definitions
            of ``self.local_fns`` ahead of its original statements (after
            its docstring, if it has one).
        """
        fn_ast, fn_func_def = self._get_func_def(fn)

        # prepend the local functions to the decorated function's body,
        # keeping a leading docstring in first position so the recompiled
        # function still has its __doc__
        body = fn_func_def.body
        split = 1 if ast.get_docstring(fn_func_def, clean=False) is not None else 0
        fn_func_def.body = (
            body[:split]
            + [self._get_func_def(lfn)[1] for lfn in self.local_fns]
            + body[split:]
        )

        # remove the @namespace_inject decorator from the decorated function;
        # otherwise re-executing the module would apply it again, and again
        fn_func_def.decorator_list = [
            decorator for decorator in fn_func_def.decorator_list
            if not self._is_inject_decorator(decorator)
        ]

        # recompile and return the decorated function
        if (fn_sourcefile := inspect.getsourcefile(fn)) is None:
            fn_sourcefile = "<string>"

        # The whole (modified) module is executed in a fresh, empty global
        # namespace; the decorated function is then picked out by name.
        fn_globals: dict[str, object] = {}
        exec(
            compile(ast.fix_missing_locations(fn_ast), fn_sourcefile, "exec"),
            fn_globals,
        )
        return fn_globals[fn.__name__]
|
2023-04-04 19:20:21 +00:00
|
|
|
|
|
|
|
##################
|
|
|
|
# NORMAL SECTION #
|
|
|
|
##################
|
|
|
|
|
|
|
|
|
2023-04-04 23:45:21 +00:00
|
|
|
def foo() -> None:
    # foo in global namespace uses the undefined name `pstr`: this module
    # defines no global `pstr`, so calling foo directly raises NameError.
    # The same code placed inside a function that has a `pstr` parameter
    # (call_foo_local, call_foo_injected) works, because there `pstr`
    # resolves through the enclosing function's scope.
    print(f"{pstr = }")  # type:ignore # noqa: F821
|
2023-04-04 19:20:21 +00:00
|
|
|
|
2023-04-04 23:45:21 +00:00
|
|
|
|
|
|
|
def call_foo_global(pstr: str) -> None:
    """
    Just call foo, so the module-level (global) foo is called.

    This should fail with NameError: foo does not know about `pstr`.
    A function's free names are resolved in its own enclosing scopes and
    the module globals, never in its caller's local variables.
    """
    foo()
|
2023-04-03 16:50:34 +00:00
|
|
|
|
|
|
|
|
2023-04-04 23:45:21 +00:00
|
|
|
def call_foo_local(pstr: str) -> None:
    """
    Redefine foo locally, then call it - the local foo is called.

    foo_local is defined inside this function, so `pstr` in its body
    resolves to this function's parameter (a closure variable) and the
    call succeeds.
    """
    def foo_local() -> None:
        # foo in local namespace is fine (same code as global foo)
        print(f"{pstr = }")

    # Assigning to `foo` makes it a local variable of this function,
    # shadowing the module-level foo for the call below.
    foo = foo_local

    foo()
|
2023-04-04 23:06:31 +00:00
|
|
|
|
|
|
|
|
2023-04-04 23:45:21 +00:00
|
|
|
@namespace_inject(foo)
def call_foo_injected(pstr: str) -> None:
    """
    Inject the global foo into the local namespace.

    This behaves like `call_foo_local`. namespace_inject copies foo's
    definition (re-parsed from this module's source, not the function
    object) into this function's body, ahead of the original statements.
    The injected foo therefore resolves `pstr` as a closure variable of
    this function.
    """
    # `foo` here is the injected local definition, not the module-level
    # foo. Note that the decorated function is recompiled from source and
    # runs with a fresh copy of this module's globals (see
    # namespace_inject.__call__).
    foo()
|
2023-04-04 19:20:21 +00:00
|
|
|
|
|
|
|
|
|
|
|
def func_info(f: Callable, sample_args: tuple[str, ...] = ("bar", "baz")) -> None:
    """
    Print a short description of ``f``, then try calling it.

    Prints ``f``'s name, the function object, and the local variable names
    of its code object (``co_varnames``). It then calls ``f`` once with each
    entry of ``sample_args`` as the single positional argument (by default
    ``"bar"``, then ``"baz"``). Calls stop at the first failure.

    Some demo functions are expected to fail, so any ``Exception`` (for
    example the NameError from call_foo_global, or an AttributeError for a
    callable without ``__code__``) is caught and reported rather than
    propagated. The report names the exception so the cause is visible.

    Args:
        f: the function to describe and call.
        sample_args: the arguments to call ``f`` with, one per call.
    """
    try:
        print(f"{f.__name__} = {f}, {f.__code__.co_varnames = }")
        for arg in sample_args:
            f(arg)
    except Exception as err:  # deliberate demo boundary: report, don't crash
        print(f"Function {f.__name__} is broken. ({type(err).__name__}: {err})")
|
2023-04-03 16:50:34 +00:00
|
|
|
|
|
|
|
|
2023-04-05 00:21:32 +00:00
|
|
|
def main() -> None:
    """Describe and exercise each demo caller with func_info, in order:
    global foo, locally redefined foo, then injected foo."""
    demos = (call_foo_global, call_foo_local, call_foo_injected)
    for demo in demos:
        func_info(demo)
|
2023-04-05 00:21:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Run the demo only when executed as a script. When namespace_inject
# re-executes this module's source in a fresh, empty global namespace,
# `__name__` there is not "__main__", so main() does not run again.
if __name__ == "__main__":
    main()
|