checkpointing.refactor.funcdef

  1from typing import Union, Dict, List, Any
  2from checkpointing.exceptions import RefactorFailedError
  3from checkpointing.refactor.util import local_variable_names_generator, nonlocal_variable_names_generator
  4import ast
  5from collections import deque, ChainMap
  6import textwrap
  7import copy
  8
  9
 10class FunctionDefinitionUnifier:
 11    def __init__(self, func_definition: str) -> None:
 12        tree = ast.parse(textwrap.dedent(func_definition), mode="exec")
 13        if len(tree.body) > 1 or not isinstance(
 14            tree.body[0],
 15            (
 16                ast.FunctionDef,
 17                ast.AsyncFunctionDef,
 18            ),
 19        ):
 20            raise RefactorFailedError(f"The given code is not a single function definition: {func_definition}")
 21
 22        self.transformer = _FunctionDefinitionTransformer()
 23        self.unified_tree = self.transformer.visit(tree)
 24
 25    @property
 26    def args_renaming(self) -> Dict[str, str]:
 27        """
 28        Dictionary of the renaming of function arguments.
 29
 30        >>> code = '''
 31        ...     def foo(a):
 32        ...         pass
 33        ...     '''
 34        >>>
 35        >>> u = FunctionDefinitionUnifier(code)
 36        >>> u.args_renaming
 37        {'a': '__checkpointing_local_var_1__'}
 38        """
 39
 40        return self.transformer.root_function_args_renaming
 41
 42    @property
 43    def nonlocal_variables_renaming(self) -> Dict[str, str]:
 44        """
 45        Dictionary of the renaming of nonlocal variables referenced by the function.
 46
 47        >>> code = '''
 48        ...     def foo():
 49        ...         a = b + 1 # b is some global variable defined elsewhere
 50        ...     '''
 51        >>>
 52        >>> u = FunctionDefinitionUnifier(code)
 53        >>> u.nonlocal_variables_renaming
 54        {'b': '__checkpointing_nonlocal_var_0__'}
 55        """
 56
 57        return self.transformer.nonlocal_variables
 58
 59    @property
 60    def has_global_statement(self) -> bool:
 61        return self.transformer.has_global_statement
 62
 63    @property
 64    def has_nonlocal_statement(self) -> bool:
 65        return self.transformer.has_nonlocal_statement
 66
 67    @property
 68    def unified_ast_dump(self) -> str:
 69        """
 70        Returns:
 71            the dump string of the unified AST of the function definition.
 72
 73        By unified, it means that
 74        - Type annotations are ignored
 75        - Arguments, position-only arguments, and keyword-only arguments are renamed based on their
 76          lexicographic order
 77        - Varargs and kwargs are renamed with a unique name
 78        - Default values of the arguments
 79        - Local variables are renamed based on their order of occurrence
 80        - Global variables are renamed based on their order of occurrence
 81        - Function name is considered as a local variable, and is also renamed
 82        - AugAssign statements (a += 1) are replaced with normal assign statements (a = a + 1)
 83        - All decorators are removed. TODO: This is essentially a problem, but there is no good fix for the moment
 84
 85        Therefore, changing any aspect mentioned above will not change the returned AST dump.
 86        Also trivially, AST dump is not affected by the code formatting and comments.
 87
 88        The criteria of the ignored/renamed/unified items above is:
 89
 90        Given that the arguments are provided in a keyword-specified way, taking the renaming of
 91        arguments into account, what changes will not cause the function return value to change.
 92
 93        This is useful for judging whether two function definitions, given the same input,
 94        can produce the same output.
 95        """
 96
 97        return ast.dump(
 98            self.unified_tree,
 99            annotate_fields=False,
100            include_attributes=False,
101        )
102
103
class _FunctionDefinitionTransformer(ast.NodeTransformer):
    """AST transformer that normalizes a function definition for comparison.

    Names bound in the visited code (function and class names, arguments, assignment and
    deletion targets) are renamed with names drawn from ``local_variable_names_generator()``.
    Names that are read without a visible binding (globals, builtins, enclosing-scope
    variables) are renamed with names drawn from ``nonlocal_variable_names_generator()``.
    Annotations, argument defaults and function decorators are dropped. Nodes are
    largely rewritten in place.
    """

    def __init__(self) -> None:
        super().__init__()

        # Original -> renamed local names; one map per scope being visited, innermost
        # first, on top of a base map that holds the outermost definition's name.
        self.local_variables = ChainMap()
        # Argument renaming of the first (outermost) function definition visited.
        self.root_function_args_renaming = None
        # Original -> renamed names for names read without a visible local binding.
        self.nonlocal_variables = {}

        self.has_global_statement = False
        self.has_nonlocal_statement = False

        # Sources of fresh replacement names.
        self.local_names = local_variable_names_generator()
        self.nonlocal_names = nonlocal_variable_names_generator()

    def visit_AnyClosure(
        self,
        node: Union[
            ast.FunctionDef,
            ast.AsyncFunctionDef,
            ast.Lambda,
            ast.ClassDef,
        ],
        initial_map: Dict,
    ) -> Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda, ast.ClassDef]:
        """Visit the children of ``node`` inside a new innermost local scope.

        Args:
            node: the function, lambda or class whose body opens a new scope.
            initial_map: names already bound in the new scope (e.g. renamed arguments).

        Returns:
            ``node`` itself, transformed in place.
        """

        self.local_variables.maps.insert(0, initial_map)
        self.generic_visit(node)
        self.local_variables.maps.pop(0)

        return node

    @property
    def current_closure_local_variables(self):
        """The local-name map of the innermost scope currently being visited."""
        return self.local_variables.maps[0]

    def unify_arg(self, node: ast.arg, local_vars: Dict, new_name: Union[str, None] = None) -> None:
        """Rename argument ``node`` in place, drop its annotation, and record it in ``local_vars``.

        A fresh local name is drawn when ``new_name`` is not given.
        """
        old_name = node.arg

        if new_name is None:
            new_name = next(self.local_names)

        node.annotation = None
        node.arg = new_name
        local_vars[old_name] = new_name

    def unify_name(self, node: Any, local_vars: Dict) -> str:
        """Rename the ``name`` of a function/class node to a fresh local name.

        The mapping is recorded in ``local_vars`` and the new name is returned.
        """
        old_name = node.name
        new_name = next(self.local_names)

        node.name = new_name
        local_vars[old_name] = new_name

        return new_name

    def visit_AnyFunctionDef(
        self,
        node: Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda],
    ) -> Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda]:
        """Unify a function or lambda definition, then visit its body in a new scope.

        Decorators, argument annotations, argument defaults and the return annotation are
        dropped. The function name is renamed in the enclosing scope. Arguments are renamed
        in lexicographic order of their original names and detached from the tree, so the
        dump does not depend on argument order or kind.
        """

        node.decorator_list = []

        if isinstance(node, ast.Lambda):
            # Lambdas have no name; a fresh one is still drawn and used below to name
            # *args / **kwargs.
            new_function_name = next(self.local_names)
        else:
            new_function_name = self.unify_name(node, self.current_closure_local_variables)

        local_vars = {}
        args: List[ast.arg] = []

        for attrname in ["posonlyargs", "args", "kwonlyargs"]:
            if hasattr(node.args, attrname) and getattr(node.args, attrname) is not None:
                arglist = getattr(node.args, attrname)
                args.extend(arglist)
            # The renamed arguments are not re-attached: only their renaming is kept.
            setattr(node.args, attrname, [])

        for arg in sorted(args, key=lambda x: x.arg):
            self.unify_arg(arg, local_vars)

        for attrname in ["vararg", "kwarg"]:
            if hasattr(node.args, attrname) and getattr(node.args, attrname) is not None:
                self.unify_arg(
                    getattr(node.args, attrname),
                    local_vars,
                    f"{new_function_name}_{attrname}",
                )
            setattr(node.args, attrname, None)

        # Defaults are dropped without being visited, so names they read are not recorded.
        node.args.defaults = []
        node.args.kw_defaults = []
        node.returns = None

        # The first function definition visited is the outermost one.
        if self.root_function_args_renaming is None:
            self.root_function_args_renaming = copy.deepcopy(local_vars)

        return self.visit_AnyClosure(node, local_vars)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        return self.visit_AnyFunctionDef(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
        return self.visit_AnyFunctionDef(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        """Rename the class in the enclosing scope and visit it in a new scope (decorators are kept)."""
        self.unify_name(node, self.current_closure_local_variables)
        return self.visit_AnyClosure(node, {})

    def visit_Name(self, node: ast.Name) -> ast.Name:
        """Return a renamed copy of ``node``, registering a new local or nonlocal name if needed."""

        if node.id in self.local_variables:  # Local variables that are already renamed
            return ast.Name(id=self.local_variables[node.id], ctx=node.ctx)

        elif node.id in self.nonlocal_variables:  # Nonlocal variables that are already renamed
            return ast.Name(id=self.nonlocal_variables[node.id], ctx=node.ctx)

        elif isinstance(node.ctx, (ast.Store, ast.Del)):
            # New local variable: like an assignment, `del x` binds `x` in the current scope.
            # (Without the Del case this method returned None and the target was deleted.)
            new_name = next(self.local_names)
            self.current_closure_local_variables[node.id] = new_name
            return ast.Name(id=new_name, ctx=node.ctx)

        elif isinstance(node.ctx, ast.Load):  # New nonlocal variable reference
            new_name = next(self.nonlocal_names)
            self.nonlocal_variables[node.id] = new_name
            return ast.Name(id=new_name, ctx=node.ctx)

        # Unknown context: keep the node rather than silently dropping it from the tree.
        return node

    def visit_Global(self, node: ast.Global) -> ast.Global:
        self.has_global_statement = True
        return node

    def visit_Nonlocal(self, node: ast.Nonlocal) -> ast.Nonlocal:
        self.has_nonlocal_statement = True
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.Assign:
        """Replace ``target: annotation [= value]`` with ``target = value`` (``None`` if no value)."""
        # `AnnAssign.value` always exists as a field and is None when there is no value,
        # so test for None instead of using hasattr (which is always true).
        assign = ast.Assign(
            targets=[node.target],
            value=node.value if node.value is not None else ast.Constant(value=None),
            type_comment=None,
        )

        self.generic_visit(assign)
        return assign

    def visit_AugAssign(self, node: ast.AugAssign) -> ast.Assign:
        """Replace ``target op= value`` with ``target = target op value``."""
        # Deep-copy the target so that the Load-context copy shares no sub-nodes or child
        # lists with the Store-context original. Otherwise, the renaming done while visiting
        # one of them would be seen (and re-renamed) when visiting the other, e.g. for
        # `a[f(x)] += 1` or `a[i, j] += 1`.
        new_target = _ContextStoreToLoadTransformer().visit(copy.deepcopy(node.target))

        assign = ast.Assign(
            targets=[node.target],
            value=ast.BinOp(
                left=new_target,
                op=node.op,
                right=node.value,
            ),
            type_comment=None,
        )

        self.generic_visit(assign)
        return assign

    def visit_Lambda(self, node: ast.Lambda) -> ast.Lambda:
        return self.visit_AnyFunctionDef(node)
267
268
class _ContextStoreToLoadTransformer(ast.NodeTransformer):
    """Build a copy of an expression in which every context-carrying node uses ``Load``.

    Each Attribute, Subscript, Starred, Name, List or Tuple node is shallow-copied
    before its ``ctx`` is replaced, so those input nodes keep their own context. Child
    lists, and nodes without a ``ctx`` field, are still shared with the input and are
    updated in place by ``generic_visit``. Pass a deep copy if the input must stay intact.
    """

    def visit_node_with_context(self, node: Union[ast.Attribute, ast.Subscript, ast.Starred, ast.Name, ast.List, ast.Tuple]):
        replacement = copy.copy(node)
        replacement.ctx = ast.Load()
        self.generic_visit(replacement)
        return replacement

    # Every node type that carries an expression context is handled the same way.
    visit_Attribute = visit_node_with_context
    visit_Subscript = visit_node_with_context
    visit_Starred = visit_node_with_context
    visit_Name = visit_node_with_context
    visit_List = visit_node_with_context
    visit_Tuple = visit_node_with_context
class FunctionDefinitionUnifier:
 11class FunctionDefinitionUnifier:
 12    def __init__(self, func_definition: str) -> None:
 13        tree = ast.parse(textwrap.dedent(func_definition), mode="exec")
 14        if len(tree.body) > 1 or not isinstance(
 15            tree.body[0],
 16            (
 17                ast.FunctionDef,
 18                ast.AsyncFunctionDef,
 19            ),
 20        ):
 21            raise RefactorFailedError(f"The given code is not a single function definition: {func_definition}")
 22
 23        self.transformer = _FunctionDefinitionTransformer()
 24        self.unified_tree = self.transformer.visit(tree)
 25
 26    @property
 27    def args_renaming(self) -> Dict[str, str]:
 28        """
 29        Dictionary of the renaming of function arguments.
 30
 31        >>> code = '''
 32        ...     def foo(a):
 33        ...         pass
 34        ...     '''
 35        >>>
 36        >>> u = FunctionDefinitionUnifier(code)
 37        >>> u.args_renaming
 38        {'a': '__checkpointing_local_var_1__'}
 39        """
 40
 41        return self.transformer.root_function_args_renaming
 42
 43    @property
 44    def nonlocal_variables_renaming(self) -> Dict[str, str]:
 45        """
 46        Dictionary of the renaming of nonlocal variables referenced by the function.
 47
 48        >>> code = '''
 49        ...     def foo():
 50        ...         a = b + 1 # b is some global variable defined elsewhere
 51        ...     '''
 52        >>>
 53        >>> u = FunctionDefinitionUnifier(code)
 54        >>> u.nonlocal_variables_renaming
 55        {'b': '__checkpointing_nonlocal_var_0__'}
 56        """
 57
 58        return self.transformer.nonlocal_variables
 59
 60    @property
 61    def has_global_statement(self) -> bool:
 62        return self.transformer.has_global_statement
 63
 64    @property
 65    def has_nonlocal_statement(self) -> bool:
 66        return self.transformer.has_nonlocal_statement
 67
 68    @property
 69    def unified_ast_dump(self) -> str:
 70        """
 71        Returns:
 72            the dump string of the unified AST of the function definition.
 73
 74        By unified, it means that
 75        - Type annotations are ignored
 76        - Arguments, position-only arguments, and keyword-only arguments are renamed based on their
 77          lexicographic order
 78        - Varargs and kwargs are renamed with a unique name
 79        - Default values of the arguments
 80        - Local variables are renamed based on their order of occurrence
 81        - Global variables are renamed based on their order of occurrence
 82        - Function name is considered as a local variable, and is also renamed
 83        - AugAssign statements (a += 1) are replaced with normal assign statements (a = a + 1)
 84        - All decorators are removed. TODO: This is essentially a problem, but there is no good fix for the moment
 85
 86        Therefore, changing any aspect mentioned above will not change the returned AST dump.
 87        Also trivially, AST dump is not affected by the code formatting and comments.
 88
 89        The criteria of the ignored/renamed/unified items above is:
 90
 91        Given that the arguments are provided in a keyword-specified way, taking the renaming of
 92        arguments into account, what changes will not cause the function return value to change.
 93
 94        This is useful for judging whether two function definitions, given the same input,
 95        can produce the same output.
 96        """
 97
 98        return ast.dump(
 99            self.unified_tree,
100            annotate_fields=False,
101            include_attributes=False,
102        )
FunctionDefinitionUnifier(func_definition: str)
def __init__(self, func_definition: str) -> None:
    """Parse and unify a single function definition.

    Args:
        func_definition: source code of exactly one ``def`` or ``async def`` statement.
            It is dedented before parsing, so it may be indented.

    Raises:
        SyntaxError: if the code cannot be parsed (raised by ``ast.parse``).
        RefactorFailedError: if the code does not consist of exactly one function
            definition, including when it contains no statement at all.
    """
    tree = ast.parse(textwrap.dedent(func_definition), mode="exec")
    # `!= 1` rather than `> 1`: empty (or comment-only) input must be rejected here
    # instead of failing with an IndexError on `tree.body[0]`.
    if len(tree.body) != 1 or not isinstance(
        tree.body[0],
        (
            ast.FunctionDef,
            ast.AsyncFunctionDef,
        ),
    ):
        raise RefactorFailedError(f"The given code is not a single function definition: {func_definition}")

    self.transformer = _FunctionDefinitionTransformer()
    self.unified_tree = self.transformer.visit(tree)
args_renaming: Dict[str, str]

Dictionary of the renaming of function arguments.

>>> code = '''
...     def foo(a):
...         pass
...     '''
>>>
>>> u = FunctionDefinitionUnifier(code)
>>> u.args_renaming
{'a': '__checkpointing_local_var_1__'}
nonlocal_variables_renaming: Dict[str, str]

Dictionary of the renaming of nonlocal variables referenced by the function.

>>> code = '''
...     def foo():
...         a = b + 1 # b is some global variable defined elsewhere
...     '''
>>>
>>> u = FunctionDefinitionUnifier(code)
>>> u.nonlocal_variables_renaming
{'b': '__checkpointing_nonlocal_var_0__'}
has_global_statement: bool
has_nonlocal_statement: bool
unified_ast_dump: str
Returns

the dump string of the unified AST of the function definition.

By unified, it means that

  • Type annotations are ignored
  • Arguments, positional-only arguments, and keyword-only arguments are renamed based on their lexicographic order and detached from the argument list, so argument order and kind do not affect the dump
  • Varargs and kwargs are renamed with a unique name
  • Default values of the arguments are ignored
  • Local variables are renamed based on their order of occurrence
  • Names that are read but not bound in the function (globals, builtins, variables of enclosing scopes) are renamed based on their order of occurrence
  • Function name is considered as a local variable, and is also renamed
  • AugAssign statements (a += 1) are replaced with normal assign statements (a = a + 1)
  • All function decorators are removed. TODO: This is a known limitation, since a decorator can change behavior, but there is no good fix for the moment

Therefore, changing any aspect mentioned above will not change the returned AST dump. Trivially, the AST dump is also not affected by code formatting and comments.

The criterion for the ignored/renamed/unified items above is:

Assuming the arguments are passed by keyword, and taking the renaming of the arguments into account: which changes leave the function's return value unchanged.

This is useful for judging whether two function definitions, given the same input, can produce the same output.