1
0
Fork 0

namespace_inject: fix line numbers

This commit is contained in:
Jörn-Michael Miehe 2023-04-05 01:07:40 +00:00
parent fa150ff1d0
commit 004334f5d9

View file

@ -1,6 +1,5 @@
import ast import ast
import inspect import inspect
from textwrap import dedent
from typing import Callable from typing import Callable
################## ##################
@ -21,17 +20,22 @@ class namespace_inject:
self.local_fns = local_fn self.local_fns = local_fn
@staticmethod @staticmethod
def _get_func_def(object: Callable) -> tuple[ast.Module, ast.FunctionDef]: def _get_func_def(fn: Callable) -> tuple[ast.Module, ast.FunctionDef]:
""" """
Get the AST representation and the contained ast.FunctionDef Get the AST representation and the contained ast.FunctionDef
of a function of a function
""" """
object_ast = ast.parse(dedent(inspect.getsource(object))) assert (object_module := inspect.getmodule(fn)) is not None
module_ast = ast.parse(inspect.getsource(object_module))
return object_ast, next( function_defs = [
x for x in ast.walk(object_ast) x for x in ast.walk(module_ast)
if isinstance(x, ast.FunctionDef) and x.name == object.__name__ if isinstance(x, ast.FunctionDef) and x.name == fn.__name__
) ]
assert len(function_defs) == 1, \
f"Function {fn.__name__} must have a unique name in its module!"
return module_ast, function_defs[0]
def __call__(self, fn: Callable) -> Callable: def __call__(self, fn: Callable) -> Callable:
""" """
@ -71,7 +75,7 @@ class namespace_inject:
def foo() -> None: def foo() -> None:
# foo, but in global namespace it uses an unbound variable # foo in global namespace uses unbound variable `pstr`
print(f"{pstr = }") # type:ignore # noqa: F821 print(f"{pstr = }") # type:ignore # noqa: F821
@ -91,11 +95,13 @@ def call_foo_local(pstr: str) -> None:
In local namespace, `pstr` is defined. In local namespace, `pstr` is defined.
""" """
def foo() -> None: def foo_local() -> None:
# foo in local namespace is fine # foo in local namespace is fine (same code as global foo)
print(f"{pstr = }") print(f"{pstr = }")
foo = foo_local
foo() foo()