diff --git a/cursed_snakes/namespace_inject.py b/cursed_snakes/namespace_inject.py index 0a8044b..583406c 100644 --- a/cursed_snakes/namespace_inject.py +++ b/cursed_snakes/namespace_inject.py @@ -1,6 +1,5 @@ import ast import inspect -from textwrap import dedent from typing import Callable ################## @@ -21,17 +20,22 @@ class namespace_inject: self.local_fns = local_fn @staticmethod - def _get_func_def(object: Callable) -> tuple[ast.Module, ast.FunctionDef]: + def _get_func_def(fn: Callable) -> tuple[ast.Module, ast.FunctionDef]: """ Get the AST representation and the contained ast.FunctionDef of a function """ - object_ast = ast.parse(dedent(inspect.getsource(object))) + assert (object_module := inspect.getmodule(fn)) is not None + module_ast = ast.parse(inspect.getsource(object_module)) - return object_ast, next( - x for x in ast.walk(object_ast) - if isinstance(x, ast.FunctionDef) and x.name == object.__name__ - ) + function_defs = [ + x for x in ast.walk(module_ast) + if isinstance(x, ast.FunctionDef) and x.name == fn.__name__ + ] + assert len(function_defs) == 1, \ + f"Function {fn.__name__} must have a unique name in its module!" + + return module_ast, function_defs[0] def __call__(self, fn: Callable) -> Callable: """ @@ -71,7 +75,7 @@ class namespace_inject: def foo() -> None: - # foo, but in global namespace it uses an unbound variable + # foo in global namespace uses unbound variable `pstr` print(f"{pstr = }") # type:ignore # noqa: F821 @@ -91,11 +95,13 @@ def call_foo_local(pstr: str) -> None: In local namespace, `pstr` is defined. """ - def foo() -> None: - # foo in local namespace is fine + def foo_local() -> None: + # foo in local namespace is fine (same code as global foo) print(f"{pstr = }") + foo = foo_local + foo()