import ast
import inspect
from textwrap import dedent
from typing import Callable, Union

#################
# MAGIC SECTION #
#################

# Either flavour of function-definition node that Shoehorn can rebuild.
_FuncNode = Union[ast.FunctionDef, ast.AsyncFunctionDef]


class Shoehorn:
    """Decorator that recompiles a function with extra functions nested inside it.

    ``@Shoehorn(helper_a, helper_b)`` re-parses the decorated function's
    source and prepends each helper's ``def`` to its body. The defs go after
    the docstring, if there is one. The result is then recompiled, so the
    helpers become local functions of the decorated function and can read
    its parameters and locals through ordinary closures.

    Limitations (all follow from rebuilding the function from source):

    * every function involved must have retrievable source
      (``inspect.getsourcelines``), so lambdas and REPL definitions fail;
    * the rebuilt function shares the original's module globals, but
      enclosing-function (closure) variables cannot be recreated, so
      decorating a closure that relies on them raises ``TypeError``;
    * decorators listed *below* ``@Shoehorn(...)`` are re-applied to the
      rebuilt function and should use ``functools.wraps``;
    * ``Shoehorn`` must appear as a ``Shoehorn(...)`` or
      ``module.Shoehorn(...)`` call in the decorator list, or be called
      directly; stacking several ``@Shoehorn`` decorators on one function
      is not supported -- pass all helpers to a single ``Shoehorn``.
    """

    local_fns: tuple[Callable, ...]

    def __init__(self, *local_fn: Callable) -> None:
        # Functions whose definitions will be injected into the decorated one.
        self.local_fns = local_fn

    @staticmethod
    def _get_func_def(func: Callable) -> tuple[_FuncNode, ast.Module]:
        """Parse *func*'s source and return ``(its def node, the parsed module)``.

        Line numbers in the returned tree are shifted to match the real
        source file, so tracebacks and ``co_firstlineno`` stay accurate.

        Raises:
            TypeError: if no ``def``/``async def`` named like *func* is found
                in its source (e.g. *func* is a lambda).
        """
        func = inspect.unwrap(func)
        source_lines, first_lineno = inspect.getsourcelines(func)
        module = ast.parse(dedent("".join(source_lines)))
        # ast.parse numbers lines from 1; re-anchor them to the source file.
        ast.increment_lineno(module, max(first_lineno - 1, 0))
        for node in ast.walk(module):
            if (
                isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
                and node.name == func.__name__
            ):
                return node, module
        raise TypeError(
            f"cannot find a def statement named {func.__name__!r} "
            f"in the source of {func!r}"
        )

    @staticmethod
    def _is_shoehorn_call(node: ast.expr) -> bool:
        """Return True for a ``Shoehorn(...)`` or ``<module>.Shoehorn(...)`` node."""
        if not isinstance(node, ast.Call):
            return False
        target = node.func
        if isinstance(target, ast.Name):
            return target.id == Shoehorn.__name__
        if isinstance(target, ast.Attribute):
            return target.attr == Shoehorn.__name__
        return False

    def __call__(self, fn: Callable) -> Callable:
        """Return a recompiled copy of *fn* with the helper functions defined inside it.

        Raises:
            TypeError: if a source lookup fails (see ``_get_func_def``), or
                if *fn* closes over enclosing-scope variables that the
                helpers do not provide.
        """
        target = inspect.unwrap(fn)
        fn_def, fn_ast = self._get_func_def(fn)
        local_defs = [self._get_func_def(lfn)[0] for lfn in self.local_fns]

        # After recompiling at module level, free variables of the original
        # code would silently become global lookups -- unless an injected
        # helper now defines them locally.
        code = getattr(target, "__code__", None)
        if code is not None:
            unresolved = set(code.co_freevars) - {d.name for d in local_defs}
            if unresolved:
                raise TypeError(
                    f"{target.__qualname__} closes over {sorted(unresolved)}; "
                    "Shoehorn cannot recreate enclosing-scope variables"
                )

        # Insert the helpers after the docstring so __doc__ survives.
        body = fn_def.body
        has_docstring = (
            bool(body)
            and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)
        )
        split = 1 if has_docstring else 0
        fn_def.body = body[:split] + local_defs + body[split:]

        # Decorators run bottom-up. Those below the Shoehorn call were already
        # applied to `fn` and must be re-applied to the rebuilt function; the
        # Shoehorn call and everything above it are applied by the normal
        # decoration flow once we return, so drop them to avoid doubling up
        # (and to stop Shoehorn from re-invoking itself).
        decorators = fn_def.decorator_list
        shoehorn_positions = [
            index
            for index, decorator in enumerate(decorators)
            if self._is_shoehorn_call(decorator)
        ]
        if shoehorn_positions:
            fn_def.decorator_list = decorators[shoehorn_positions[-1] + 1 :]

        filename = inspect.getsourcefile(target) or "<shoehorn>"
        namespace: dict = {}
        # Run against the original module globals so module-level names and
        # imports still resolve, but bind the new def in a private namespace
        # instead of overwriting anything in that module.
        exec(compile(fn_ast, filename, "exec"), target.__globals__, namespace)
        return namespace[fn_def.name]


##################
# NORMAL SECTION #
##################


def foo():
    # Module-level fallback that func2 ends up calling; always fails.
    raise NotImplementedError()


def func1(pstr: str) -> None:
    """Works: the helper is written inline, so it closes over ``pstr``."""

    def foo() -> None:
        print(f"{pstr = }")

    foo()


def func2(pstr: str) -> None:
    """Broken on purpose: calls the module-level ``foo``, which raises."""
    foo()


def foo_local() -> None:
    # Only meaningful once shoehorned into a function that has ``pstr``;
    # called directly it raises NameError.
    print(f"{pstr = }")  # type:ignore # noqa: F821


@Shoehorn(foo_local)
def func3(pstr: str) -> None:
    """Works: Shoehorn recompiles this with ``foo_local`` defined inside it."""
    foo_local()


def func_info(f: Callable) -> None:
    """Print *f*'s name and local variable names, then call it twice.

    Any exception from the calls is reported as "broken" instead of
    propagating, so every demo function gets exercised.
    """
    try:
        print(f"{f.__name__} = {f}, {f.__code__.co_varnames = }")
        f("bar")
        f("baz")
    except Exception:  # demo boundary: report and keep going
        print(f"Function {f.__name__} is broken.")


if __name__ == "__main__":
    func_info(func1)
    func_info(func2)
    func_info(func3)