1
0
Fork 0

namespace_inject: fix line numbers

This commit is contained in:
Jörn-Michael Miehe 2023-04-05 01:07:40 +00:00
parent fa150ff1d0
commit 004334f5d9

View file

@ -1,6 +1,5 @@
import ast
import inspect
from textwrap import dedent
from typing import Callable
##################
@ -21,17 +20,22 @@ class namespace_inject:
self.local_fns = local_fn
@staticmethod
def _get_func_def(object: Callable) -> tuple[ast.Module, ast.FunctionDef]:
def _get_func_def(fn: Callable) -> tuple[ast.Module, ast.FunctionDef]:
"""
Get the AST representation and the contained ast.FunctionDef
of a function
"""
object_ast = ast.parse(dedent(inspect.getsource(object)))
assert (object_module := inspect.getmodule(fn)) is not None
module_ast = ast.parse(inspect.getsource(object_module))
return object_ast, next(
x for x in ast.walk(object_ast)
if isinstance(x, ast.FunctionDef) and x.name == object.__name__
)
function_defs = [
x for x in ast.walk(module_ast)
if isinstance(x, ast.FunctionDef) and x.name == fn.__name__
]
assert len(function_defs) == 1, \
f"Function {fn.__name__} must have a unique name in its module!"
return module_ast, function_defs[0]
def __call__(self, fn: Callable) -> Callable:
"""
@ -71,7 +75,7 @@ class namespace_inject:
def foo() -> None:
# foo, but in global namespace it uses an unbound variable
# foo in global namespace uses unbound variable `pstr`
print(f"{pstr = }") # type:ignore # noqa: F821
@ -91,11 +95,13 @@ def call_foo_local(pstr: str) -> None:
In local namespace, `pstr` is defined.
"""
def foo() -> None:
# foo in local namespace is fine
def foo_local() -> None:
# foo in local namespace is fine (same code as global foo)
print(f"{pstr = }")
foo = foo_local
foo()