checkpointing.refactor.funcdef
from typing import Union, Dict, List, Any, Optional
from checkpointing.exceptions import RefactorFailedError
from checkpointing.refactor.util import local_variable_names_generator, nonlocal_variable_names_generator
import ast
from collections import deque, ChainMap
import textwrap
import copy


class FunctionDefinitionUnifier:
    """
    Normalize the source of a single function definition into a canonical ("unified") AST.

    See ``unified_ast_dump`` for the list of differences that the unification erases.
    """

    def __init__(self, func_definition: str) -> None:
        """
        Args:
            func_definition: source code of exactly one ``def`` / ``async def`` statement.
                Common leading indentation is removed before parsing.

        Raises:
            RefactorFailedError: if the code is empty, contains more than one top-level
                statement, or its only statement is not a function definition.
            SyntaxError: propagated from ``ast.parse`` for code that does not parse.
        """
        tree = ast.parse(textwrap.dedent(func_definition), mode="exec")

        # `!= 1` (not `> 1`) so that empty input is reported as RefactorFailedError
        # instead of raising IndexError on `tree.body[0]`.
        if len(tree.body) != 1 or not isinstance(
            tree.body[0],
            (
                ast.FunctionDef,
                ast.AsyncFunctionDef,
            ),
        ):
            raise RefactorFailedError(f"The given code is not a single function definition: {func_definition}")

        self.transformer = _FunctionDefinitionTransformer()
        self.unified_tree = self.transformer.visit(tree)

    @property
    def args_renaming(self) -> Dict[str, str]:
        """
        Dictionary of the renaming of function arguments.

        >>> code = '''
        ... def foo(a):
        ...     pass
        ... '''
        >>>
        >>> u = FunctionDefinitionUnifier(code)
        >>> u.args_renaming
        {'a': '__checkpointing_local_var_1__'}
        """

        return self.transformer.root_function_args_renaming

    @property
    def nonlocal_variables_renaming(self) -> Dict[str, str]:
        """
        Dictionary of the renaming of nonlocal variables referenced by the function.

        >>> code = '''
        ... def foo():
        ...     a = b + 1  # b is some global variable defined elsewhere
        ... '''
        >>>
        >>> u = FunctionDefinitionUnifier(code)
        >>> u.nonlocal_variables_renaming
        {'b': '__checkpointing_nonlocal_var_0__'}
        """

        return self.transformer.nonlocal_variables

    @property
    def has_global_statement(self) -> bool:
        """Whether a ``global`` statement appears anywhere in the function, including nested scopes."""
        return self.transformer.has_global_statement

    @property
    def has_nonlocal_statement(self) -> bool:
        """Whether a ``nonlocal`` statement appears anywhere in the function, including nested scopes."""
        return self.transformer.has_nonlocal_statement

    @property
    def unified_ast_dump(self) -> str:
        """
        Returns:
            the dump string of the unified AST of the function definition.

        By unified, it means that
        - Type annotations are ignored
        - Arguments, position-only arguments, and keyword-only arguments are renamed based on their
          lexicographic order
        - Varargs and kwargs are renamed with a unique name
        - Default values of the arguments are ignored
        - Local variables are renamed based on their order of occurrence
        - Global (and other nonlocal) variables are renamed based on their order of occurrence
        - Function name is considered as a local variable, and is also renamed
        - AugAssign statements (a += 1) are replaced with normal assign statements (a = a + 1)
        - All decorators are removed. TODO: decorators can change the function's behavior, so
          removing them can hide real differences, but there is no good fix for the moment

        Therefore, changing any aspect mentioned above will not change the returned AST dump.
        Also trivially, AST dump is not affected by the code formatting and comments.

        The criterion for the ignored/renamed/unified items above is:

        given that the arguments are passed by keyword, and taking the renaming of arguments
        into account, the change must not alter the function's return value.

        This is useful for judging whether two function definitions, given the same input,
        can produce the same output.
        """

        return ast.dump(
            self.unified_tree,
            annotate_fields=False,
            include_attributes=False,
        )


class _FunctionDefinitionTransformer(ast.NodeTransformer):
    """
    AST transformer that performs the renaming/stripping described in
    ``FunctionDefinitionUnifier.unified_ast_dump``.

    Many nodes (decorators, arguments, defaults) are mutated in place, so the visited
    tree should not be reused for anything else.
    """

    def __init__(self) -> None:
        super().__init__()

        # Stack of {original name: new name} maps, one per scope; maps[0] is the innermost.
        self.local_variables = ChainMap()
        # Argument renaming of the first (outermost) function visited; None until then.
        self.root_function_args_renaming: Optional[Dict[str, str]] = None
        # {original name: new name} for every name loaded before being bound in any visible scope.
        self.nonlocal_variables: Dict[str, str] = {}

        self.has_global_statement = False
        self.has_nonlocal_statement = False

        self.local_names = local_variable_names_generator()
        self.nonlocal_names = nonlocal_variable_names_generator()

    def visit_AnyClosure(
        self,
        node: Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda],
        initial_map: Dict[str, str],
    ) -> Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda]:
        """Visit the children of ``node`` inside a new innermost scope seeded with ``initial_map``."""
        self.local_variables.maps.insert(0, initial_map)
        try:
            self.generic_visit(node)
        finally:
            # Always leave the scope, even if visiting the body raised.
            self.local_variables.maps.pop(0)

        return node

    @property
    def current_closure_local_variables(self) -> Dict[str, str]:
        """The renaming map of the innermost scope currently being visited."""
        return self.local_variables.maps[0]

    def unify_arg(self, node: ast.arg, local_vars: Dict[str, str], new_name: Optional[str] = None) -> None:
        """
        Rename argument ``node`` and drop its annotation.

        A fresh local name is drawn unless ``new_name`` is given. The mapping
        ``old name -> new name`` is recorded in ``local_vars``.
        """
        old_name = node.arg

        if new_name is None:
            new_name = next(self.local_names)

        node.annotation = None
        node.arg = new_name
        local_vars[old_name] = new_name

    def unify_name(self, node: Any, local_vars: Dict[str, str]) -> str:
        """Rename a def/class ``node`` to a fresh local name, record it in ``local_vars``, and return it."""
        old_name = node.name
        new_name = next(self.local_names)

        node.name = new_name
        local_vars[old_name] = new_name

        return new_name

    def visit_AnyFunctionDef(
        self,
        node: Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda],
    ) -> Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda]:
        """
        Unify a function or lambda.

        Decorators, the return annotation, argument annotations and default values are
        dropped. All named arguments are removed from the signature and renamed in
        lexicographic order. The body is then visited in a new scope.
        """
        if isinstance(node, ast.Lambda):
            # Lambdas are anonymous, but still draw a name so vararg/kwarg renaming is unique.
            new_function_name = next(self.local_names)
        else:
            node.decorator_list = []
            node.returns = None
            # The function name is bound in the enclosing scope, so register it there.
            new_function_name = self.unify_name(node, self.current_closure_local_variables)

        local_vars: Dict[str, str] = {}
        args: List[ast.arg] = []

        # Positional-only, regular and keyword-only arguments are pooled and removed from the
        # signature: callers are assumed to pass them by keyword, so only their names matter.
        for attrname in ("posonlyargs", "args", "kwonlyargs"):
            arglist = getattr(node.args, attrname, None)
            if arglist is not None:
                args.extend(arglist)
                setattr(node.args, attrname, [])

        # Lexicographic order makes the renaming independent of argument order.
        for arg in sorted(args, key=lambda a: a.arg):
            self.unify_arg(arg, local_vars)

        for attrname in ("vararg", "kwarg"):
            arg = getattr(node.args, attrname, None)
            if arg is not None:
                self.unify_arg(arg, local_vars, f"{new_function_name}_{attrname}")
                setattr(node.args, attrname, None)

        # Default values are discarded without being visited.
        node.args.defaults = []
        node.args.kw_defaults = []

        if self.root_function_args_renaming is None:
            # Snapshot now: `local_vars` keeps growing while the body is visited.
            self.root_function_args_renaming = dict(local_vars)

        return self.visit_AnyClosure(node, local_vars)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        return self.visit_AnyFunctionDef(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
        return self.visit_AnyFunctionDef(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
        # The class name belongs to the enclosing scope; the class body gets a fresh scope.
        self.unify_name(node, self.current_closure_local_variables)
        return self.visit_AnyClosure(node, {})

    def visit_Name(self, node: ast.Name) -> ast.Name:
        """Replace ``node`` with a renamed copy according to the local/nonlocal maps."""
        if node.id in self.local_variables:  # Local variables that are already renamed
            return ast.Name(id=self.local_variables[node.id], ctx=node.ctx)

        if node.id in self.nonlocal_variables:  # Nonlocal variables that are already renamed
            return ast.Name(id=self.nonlocal_variables[node.id], ctx=node.ctx)

        if isinstance(node.ctx, ast.Store):  # New local variable declaration
            new_name = next(self.local_names)
            self.current_closure_local_variables[node.id] = new_name
            return ast.Name(id=new_name, ctx=node.ctx)

        if isinstance(node.ctx, ast.Load):  # New nonlocal variable reference
            new_name = next(self.nonlocal_names)
            self.nonlocal_variables[node.id] = new_name
            return ast.Name(id=new_name, ctx=node.ctx)

        # Remaining case: `del` of a name never seen before (e.g. after a `global` statement).
        # Keep it verbatim; returning None would make NodeTransformer silently drop the node.
        return node

    def visit_Global(self, node: ast.Global) -> ast.Global:
        self.has_global_statement = True
        return node

    def visit_Nonlocal(self, node: ast.Nonlocal) -> ast.Nonlocal:
        self.has_nonlocal_statement = True
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.Assign:
        """Turn ``target: ann = value`` into ``target = value``, discarding the annotation."""
        # For a bare declaration such as `x: int`, `node.value` is None. It is kept as None
        # (not Constant(None)) so the declaration stays distinguishable from `x = None`.
        assign = ast.Assign(
            targets=[node.target],
            value=node.value,
            type_comment=None,
        )

        self.generic_visit(assign)
        return assign

    def visit_AugAssign(self, node: ast.AugAssign) -> ast.Assign:
        """Turn ``target OP= value`` into ``target = target OP value``."""
        # The read side needs its own *deep* copy of the target. A shallow copy shares
        # sub-nodes (e.g. the BinOp in `x[i + 1] += 1`) between both sides, and generic_visit
        # would then rename them twice, registering a renamed name as a new nonlocal.
        load_target = _ContextStoreToLoadTransformer().visit(copy.deepcopy(node.target))

        assign = ast.Assign(
            targets=[node.target],
            value=ast.BinOp(
                left=load_target,
                op=node.op,
                right=node.value,
            ),
            type_comment=None,
        )

        self.generic_visit(assign)
        return assign

    def visit_Lambda(self, node: ast.Lambda) -> ast.Lambda:
        return self.visit_AnyFunctionDef(node)


class _ContextStoreToLoadTransformer(ast.NodeTransformer):
    """
    Switch every context-carrying node of an assignment target to ``ast.Load``.

    Each such node is shallow-copied before its context is changed.
    """

    def visit_node_with_context(self, node: Union[ast.Attribute, ast.Subscript, ast.Starred, ast.Name, ast.List, ast.Tuple]):
        cp = copy.copy(node)
        cp.ctx = ast.Load()
        self.generic_visit(cp)
        return cp

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        return self.visit_node_with_context(node)

    def visit_Subscript(self, node: ast.Subscript) -> Any:
        return self.visit_node_with_context(node)

    def visit_Starred(self, node: ast.Starred) -> Any:
        return self.visit_node_with_context(node)

    def visit_Name(self, node: ast.Name) -> Any:
        return self.visit_node_with_context(node)

    def visit_List(self, node: ast.List) -> Any:
        return self.visit_node_with_context(node)

    def visit_Tuple(self, node: ast.Tuple) -> Any:
        return self.visit_node_with_context(node)
class
FunctionDefinitionUnifier:
class FunctionDefinitionUnifier:
    """
    Normalize the source of a single function definition into a canonical ("unified") AST.

    See ``unified_ast_dump`` for the list of differences that the unification erases.
    """

    def __init__(self, func_definition: str) -> None:
        """
        Args:
            func_definition: source code of exactly one ``def`` / ``async def`` statement.
                Common leading indentation is removed before parsing.

        Raises:
            RefactorFailedError: if the code is empty, contains more than one top-level
                statement, or its only statement is not a function definition.
            SyntaxError: propagated from ``ast.parse`` for code that does not parse.
        """
        tree = ast.parse(textwrap.dedent(func_definition), mode="exec")

        # `!= 1` (not `> 1`) so that empty input is reported as RefactorFailedError
        # instead of raising IndexError on `tree.body[0]`.
        if len(tree.body) != 1 or not isinstance(
            tree.body[0],
            (
                ast.FunctionDef,
                ast.AsyncFunctionDef,
            ),
        ):
            raise RefactorFailedError(f"The given code is not a single function definition: {func_definition}")

        self.transformer = _FunctionDefinitionTransformer()
        self.unified_tree = self.transformer.visit(tree)

    @property
    def args_renaming(self) -> Dict[str, str]:
        """
        Dictionary of the renaming of function arguments.

        >>> code = '''
        ... def foo(a):
        ...     pass
        ... '''
        >>>
        >>> u = FunctionDefinitionUnifier(code)
        >>> u.args_renaming
        {'a': '__checkpointing_local_var_1__'}
        """

        return self.transformer.root_function_args_renaming

    @property
    def nonlocal_variables_renaming(self) -> Dict[str, str]:
        """
        Dictionary of the renaming of nonlocal variables referenced by the function.

        >>> code = '''
        ... def foo():
        ...     a = b + 1  # b is some global variable defined elsewhere
        ... '''
        >>>
        >>> u = FunctionDefinitionUnifier(code)
        >>> u.nonlocal_variables_renaming
        {'b': '__checkpointing_nonlocal_var_0__'}
        """

        return self.transformer.nonlocal_variables

    @property
    def has_global_statement(self) -> bool:
        """Whether a ``global`` statement was seen while unifying the function."""
        return self.transformer.has_global_statement

    @property
    def has_nonlocal_statement(self) -> bool:
        """Whether a ``nonlocal`` statement was seen while unifying the function."""
        return self.transformer.has_nonlocal_statement

    @property
    def unified_ast_dump(self) -> str:
        """
        Returns:
            the dump string of the unified AST of the function definition.

        By unified, it means that
        - Type annotations are ignored
        - Arguments, position-only arguments, and keyword-only arguments are renamed based on their
          lexicographic order
        - Varargs and kwargs are renamed with a unique name
        - Default values of the arguments are ignored
        - Local variables are renamed based on their order of occurrence
        - Global (and other nonlocal) variables are renamed based on their order of occurrence
        - Function name is considered as a local variable, and is also renamed
        - AugAssign statements (a += 1) are replaced with normal assign statements (a = a + 1)
        - All decorators are removed. TODO: decorators can change the function's behavior, so
          removing them can hide real differences, but there is no good fix for the moment

        Therefore, changing any aspect mentioned above will not change the returned AST dump.
        Also trivially, AST dump is not affected by the code formatting and comments.

        The criterion for the ignored/renamed/unified items above is:

        given that the arguments are passed by keyword, and taking the renaming of arguments
        into account, the change must not alter the function's return value.

        This is useful for judging whether two function definitions, given the same input,
        can produce the same output.
        """

        return ast.dump(
            self.unified_tree,
            annotate_fields=False,
            include_attributes=False,
        )
FunctionDefinitionUnifier(func_definition: str)
def __init__(self, func_definition: str) -> None:
    """
    Parse ``func_definition`` and build its unified AST.

    Args:
        func_definition: source code of exactly one ``def`` / ``async def`` statement.
            Common leading indentation is removed before parsing.

    Raises:
        RefactorFailedError: if the code is empty, contains more than one top-level
            statement, or its only statement is not a function definition.
        SyntaxError: propagated from ``ast.parse`` for code that does not parse.
    """
    tree = ast.parse(textwrap.dedent(func_definition), mode="exec")

    # `!= 1` (not `> 1`) so that empty input is reported as RefactorFailedError
    # instead of raising IndexError on `tree.body[0]`.
    if len(tree.body) != 1 or not isinstance(
        tree.body[0],
        (
            ast.FunctionDef,
            ast.AsyncFunctionDef,
        ),
    ):
        raise RefactorFailedError(f"The given code is not a single function definition: {func_definition}")

    self.transformer = _FunctionDefinitionTransformer()
    self.unified_tree = self.transformer.visit(tree)
args_renaming: Dict[str, str]
Dictionary of the renaming of function arguments.
>>> code = '''
... def foo(a):
... pass
... '''
>>>
>>> u = FunctionDefinitionUnifier(code)
>>> u.args_renaming
{'a': '__checkpointing_local_var_1__'}
nonlocal_variables_renaming: Dict[str, str]
Dictionary of the renaming of nonlocal variables referenced by the function.
>>> code = '''
... def foo():
... a = b + 1 # b is some global variable defined elsewhere
... '''
>>>
>>> u = FunctionDefinitionUnifier(code)
>>> u.nonlocal_variables_renaming
{'b': '__checkpointing_nonlocal_var_0__'}
unified_ast_dump: str
Returns
the dump string of the unified AST of the function definition.
By unified, it means that
- Type annotations are ignored
- Arguments, position-only arguments, and keyword-only arguments are renamed based on their lexicographic order
- Varargs and kwargs are renamed with a unique name
- Default values of the arguments are ignored
- Local variables are renamed based on their order of occurrence
- Global (and other nonlocal) variables are renamed based on their order of occurrence
- Function name is considered as a local variable, and is also renamed
- AugAssign statements (a += 1) are replaced with normal assign statements (a = a + 1)
- All decorators are removed. TODO: decorators can change the function's behavior, so removing them can hide real differences, but there is no good fix for the moment
Therefore, changing any aspect mentioned above will not change the returned AST dump. Also trivially, AST dump is not affected by the code formatting and comments.
The criterion for the ignored/renamed/unified items above is:
Given that the arguments are passed by keyword, and taking the renaming of arguments into account, the change must not alter the function's return value.
This is useful for judging whether two function definitions, given the same input, can produce the same output.