checkpointing.identifier.func_call.auto
from checkpointing.identifier.func_call.base import FuncCallIdentifierBase
from checkpointing.identifier.func_call.context import FuncCallContext
from checkpointing._typing import ContextId
from checkpointing.config import defaults
from checkpointing.refactor.funcdef import FunctionDefinitionUnifier
from checkpointing.exceptions import GlobalStatementError, NonlocalStatementError
from typing import Dict
import textwrap
import copy

from checkpointing.hash import hash_anything


class AutoFuncCallIdentifier(FuncCallIdentifierBase):
    """
    Function call identifier that derives the context ID automatically by hashing.

    The hashed inputs are the call's argument values, the values of the nonlocal
    variables referenced by the function, and the unified AST dump that
    ``FunctionDefinitionUnifier`` produces for the function's code. Argument and
    nonlocal names are replaced by the unifier's renamed names before hashing.
    """

    def __init__(self, algorithm: str = None, pickle_protocol: int = None) -> None:
        """
        Args:
            algorithm: the hash algorithm to use. If None, use the global default `hash.algorithm`.
            pickle_protocol: the pickle protocol to use. If None, use the global default `hash.pickle_protocol`
        """

        if algorithm is None:
            algorithm = defaults["hash.algorithm"]

        if pickle_protocol is None:
            pickle_protocol = defaults["hash.pickle_protocol"]

        self.algorithm = algorithm
        self.pickle_protocol = pickle_protocol

    def identify(self, context: FuncCallContext) -> ContextId:
        """
        Identifies the context by producing a hash value.

        Raises:
            GlobalStatementError: if the function's code contains a ``global`` statement.
            NonlocalStatementError: if the function's code contains a ``nonlocal`` statement.

        Returns:
            the function call context identifier
        """

        unifier = FunctionDefinitionUnifier(context.code)
        self.__check_unsupported_statements(unifier, context.code)

        return self.__identify_with_unifier(context, unifier)

    def __check_unsupported_statements(self, unifier: FunctionDefinitionUnifier, original_code: str) -> None:
        """
        Reject function definitions that rebind variables outside their own scope.

        Args:
            unifier: the unifier built from ``original_code``.
            original_code: the function's source, echoed back in the error message.

        Raises:
            GlobalStatementError: if ``unifier.has_global_statement`` is true.
            NonlocalStatementError: if ``unifier.has_nonlocal_statement`` is true.
        """
        # Dedent the template *before* substituting the user's code. If the multi-line
        # code were interpolated first, its lines would take part in dedent's
        # common-prefix computation: the message lines would keep stray indentation
        # and the code would lose part of its own.
        template = textwrap.dedent(
            """
            '{stmt_type}' statement detected in the code. This indicates that you are changing a {stmt_type} variable in the function, which is not a use case with checkpointing.
            Your code:
            {original_code}
            """
        )

        def error_text(stmt_type: str) -> str:
            # str.format does not re-parse substituted values, so braces in the code are safe.
            return template.format(stmt_type=stmt_type, original_code=original_code)

        if unifier.has_global_statement:
            raise GlobalStatementError(error_text("global"))

        if unifier.has_nonlocal_statement:
            raise NonlocalStatementError(error_text("nonlocal"))

    def __identify_with_unifier(self, context: FuncCallContext, unifier: FunctionDefinitionUnifier) -> ContextId:
        """
        Hash the call's variables together with the unified function definition.

        Args:
            context: the function call context (arguments and nonlocal lookup).
            unifier: the unifier built from ``context.code``.

        Returns:
            the hash of the sorted (name, value) pairs plus ``unifier.unified_ast_dump``.
        """
        variables = copy.copy(context.arguments)

        # Pop every renamed argument before inserting any new name. Renaming in place
        # one at a time would overwrite an argument whose original name equals another
        # argument's new name (e.g. a swap {"a": "b", "b": "a"}) before it is moved.
        renamed_args = {new_name: variables.pop(old_name) for old_name, new_name in unifier.args_renaming.items()}
        variables.update(renamed_args)

        for old_name, new_name in unifier.nonlocal_variables_renaming.items():
            var = context.get_nonlocal_variable(old_name)
            # NOTE(review): a nonlocal whose value really is None is hashed the same as a
            # missing reference -- confirm get_nonlocal_variable's contract.
            variables[new_name] = var if var is not None else (old_name, "__checkpointing_no_nonlocal_reference__")

        # Sorting makes the hash independent of dict insertion order.
        return hash_anything(
            *sorted(variables.items()),
            unifier.unified_ast_dump,
            algorithm=self.algorithm,
            pickle_protocol=self.pickle_protocol,
        )
class AutoFuncCallIdentifier(FuncCallIdentifierBase):
    """
    Function call identifier that derives the context ID automatically by hashing.

    The hashed inputs are the call's argument values, the values of the nonlocal
    variables referenced by the function, and the unified AST dump that
    ``FunctionDefinitionUnifier`` produces for the function's code. Argument and
    nonlocal names are replaced by the unifier's renamed names before hashing.
    """

    def __init__(self, algorithm: str = None, pickle_protocol: int = None) -> None:
        """
        Args:
            algorithm: the hash algorithm to use. If None, use the global default `hash.algorithm`.
            pickle_protocol: the pickle protocol to use. If None, use the global default `hash.pickle_protocol`
        """

        if algorithm is None:
            algorithm = defaults["hash.algorithm"]

        if pickle_protocol is None:
            pickle_protocol = defaults["hash.pickle_protocol"]

        self.algorithm = algorithm
        self.pickle_protocol = pickle_protocol

    def identify(self, context: FuncCallContext) -> ContextId:
        """
        Identifies the context by producing a hash value.

        Raises:
            GlobalStatementError: if the function's code contains a ``global`` statement.
            NonlocalStatementError: if the function's code contains a ``nonlocal`` statement.

        Returns:
            the function call context identifier
        """

        unifier = FunctionDefinitionUnifier(context.code)
        self.__check_unsupported_statements(unifier, context.code)

        return self.__identify_with_unifier(context, unifier)

    def __check_unsupported_statements(self, unifier: FunctionDefinitionUnifier, original_code: str) -> None:
        """
        Reject function definitions that rebind variables outside their own scope.

        Args:
            unifier: the unifier built from ``original_code``.
            original_code: the function's source, echoed back in the error message.

        Raises:
            GlobalStatementError: if ``unifier.has_global_statement`` is true.
            NonlocalStatementError: if ``unifier.has_nonlocal_statement`` is true.
        """
        # Dedent the template *before* substituting the user's code. If the multi-line
        # code were interpolated first, its lines would take part in dedent's
        # common-prefix computation: the message lines would keep stray indentation
        # and the code would lose part of its own.
        template = textwrap.dedent(
            """
            '{stmt_type}' statement detected in the code. This indicates that you are changing a {stmt_type} variable in the function, which is not a use case with checkpointing.
            Your code:
            {original_code}
            """
        )

        def error_text(stmt_type: str) -> str:
            # str.format does not re-parse substituted values, so braces in the code are safe.
            return template.format(stmt_type=stmt_type, original_code=original_code)

        if unifier.has_global_statement:
            raise GlobalStatementError(error_text("global"))

        if unifier.has_nonlocal_statement:
            raise NonlocalStatementError(error_text("nonlocal"))

    def __identify_with_unifier(self, context: FuncCallContext, unifier: FunctionDefinitionUnifier) -> ContextId:
        """
        Hash the call's variables together with the unified function definition.

        Args:
            context: the function call context (arguments and nonlocal lookup).
            unifier: the unifier built from ``context.code``.

        Returns:
            the hash of the sorted (name, value) pairs plus ``unifier.unified_ast_dump``.
        """
        variables = copy.copy(context.arguments)

        # Pop every renamed argument before inserting any new name. Renaming in place
        # one at a time would overwrite an argument whose original name equals another
        # argument's new name (e.g. a swap {"a": "b", "b": "a"}) before it is moved.
        renamed_args = {new_name: variables.pop(old_name) for old_name, new_name in unifier.args_renaming.items()}
        variables.update(renamed_args)

        for old_name, new_name in unifier.nonlocal_variables_renaming.items():
            var = context.get_nonlocal_variable(old_name)
            # NOTE(review): a nonlocal whose value really is None is hashed the same as a
            # missing reference -- confirm get_nonlocal_variable's contract.
            variables[new_name] = var if var is not None else (old_name, "__checkpointing_no_nonlocal_reference__")

        # Sorting makes the hash independent of dict insertion order.
        return hash_anything(
            *sorted(variables.items()),
            unifier.unified_ast_dump,
            algorithm=self.algorithm,
            pickle_protocol=self.pickle_protocol,
        )
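Function call identifier that hashes the call's arguments, the nonlocal variables the function references, and the unified function definition. It inherits from `FuncCallIdentifierBase`, the base class for function call identifiers.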
AutoFuncCallIdentifier(algorithm: str = None, pickle_protocol: int = None)
def __init__(self, algorithm: str = None, pickle_protocol: int = None) -> None:
    """
    Configure how call contexts are hashed.

    Args:
        algorithm: name of the hash algorithm. When None, the global default `hash.algorithm` is used.
        pickle_protocol: pickle protocol passed to the hasher. When None, the global default `hash.pickle_protocol` is used.
    """
    self.algorithm = algorithm if algorithm is not None else defaults["hash.algorithm"]
    self.pickle_protocol = (
        pickle_protocol if pickle_protocol is not None else defaults["hash.pickle_protocol"]
    )
Args
- algorithm: the hash algorithm to use. If None, use the global default `hash.algorithm`.
- pickle_protocol: the pickle protocol to use. If None, use the global default `hash.pickle_protocol`.
def identify(self, context: checkpointing.identifier.func_call.context.FuncCallContext) -> ~ContextId:
def identify(self, context: FuncCallContext) -> ContextId:
    """
    Compute the identifier of a function call context as a hash value.

    The function's code is first wrapped in a ``FunctionDefinitionUnifier`` and
    checked for unsupported ``global``/``nonlocal`` statements; the identifier is
    then derived from that unifier and the context.

    Raises:
        GlobalStatementError: if the code contains a ``global`` statement.
        NonlocalStatementError: if the code contains a ``nonlocal`` statement.

    Returns:
        the function call context identifier
    """
    source = context.code
    unifier = FunctionDefinitionUnifier(source)
    self.__check_unsupported_statements(unifier, source)
    identifier = self.__identify_with_unifier(context, unifier)
    return identifier
Identifies the context by producing a hash value.
Returns
the function call context identifier