checkpointing.identifier.func_call.auto

import copy
import textwrap
from typing import Dict, Optional

from checkpointing._typing import ContextId
from checkpointing.config import defaults
from checkpointing.exceptions import GlobalStatementError, NonlocalStatementError
from checkpointing.hash import hash_anything
from checkpointing.identifier.func_call.base import FuncCallIdentifierBase
from checkpointing.identifier.func_call.context import FuncCallContext
from checkpointing.refactor.funcdef import FunctionDefinitionUnifier

class AutoFuncCallIdentifier(FuncCallIdentifierBase):
    """
    Function call identifier that derives the context ID automatically.

    The ID is a hash of the call's arguments, the nonlocal variables the function
    references, and the unified AST dump produced by `FunctionDefinitionUnifier`.
    Argument and nonlocal names are remapped through the unifier's renaming tables
    before hashing, so the hashed values are keyed consistently with the unified code.
    """

    def __init__(self, algorithm: Optional[str] = None, pickle_protocol: Optional[int] = None) -> None:
        """
        Args:
            algorithm: the hash algorithm to use. If None, use the global default `hash.algorithm`.
            pickle_protocol: the pickle protocol to use. If None, use the global default `hash.pickle_protocol`.
        """

        # Global defaults are read once, at construction time.
        if algorithm is None:
            algorithm = defaults["hash.algorithm"]

        if pickle_protocol is None:
            pickle_protocol = defaults["hash.pickle_protocol"]

        self.algorithm = algorithm
        self.pickle_protocol = pickle_protocol

    def identify(self, context: FuncCallContext) -> ContextId:
        """
        Identifies the context by producing a hash value.

        Returns:
            the function call context identifier

        Raises:
            GlobalStatementError: if the function code contains a `global` statement.
            NonlocalStatementError: if the function code contains a `nonlocal` statement.
        """

        unifier = FunctionDefinitionUnifier(context.code)
        self.__check_unsupported_statements(unifier, context.code)

        return self.__identify_with_unifier(context, unifier)

    def __check_unsupported_statements(self, unifier: FunctionDefinitionUnifier, original_code: str) -> None:
        """
        Reject function code that uses `global` or `nonlocal` statements.

        Args:
            unifier: the unifier built from `original_code`; its statement flags are inspected.
            original_code: the function source, echoed in the error message.

        Raises:
            GlobalStatementError: if `unifier.has_global_statement` is set.
            NonlocalStatementError: if `unifier.has_nonlocal_statement` is set.
        """

        def error_text(stmt_type: str) -> str:
            # Built explicitly rather than with textwrap.dedent: dedent computes the common
            # indentation *after* interpolation, so multi-line user code starting at column 0
            # defeated it and left the message indented. The user's code is kept verbatim.
            return (
                f"\n'{stmt_type}' statement detected in the code. This indicates that you are changing a "
                f"{stmt_type} variable in the function, which is not a use case with checkpointing.\n"
                "Your code:\n"
                f"{original_code}\n"
            )

        if unifier.has_global_statement:
            raise GlobalStatementError(error_text("global"))

        if unifier.has_nonlocal_statement:
            raise NonlocalStatementError(error_text("nonlocal"))

    def __identify_with_unifier(self, context: FuncCallContext, unifier: FunctionDefinitionUnifier) -> ContextId:
        """
        Hash the (renamed) call arguments and nonlocal values together with the unified code.

        Raises:
            KeyError: if `unifier.args_renaming` names an argument missing from `context.arguments`.
        """

        # Shallow copy so the renaming below does not mutate the caller's mapping.
        variables = copy.copy(context.arguments)

        # Pop every renamed argument first, then insert them all under their new names.
        # Renaming in place would overwrite a value whenever a new name equals an old name
        # that has not been processed yet (e.g. a swap a -> b, b -> a).
        renamed = {new_name: variables.pop(old_name) for old_name, new_name in unifier.args_renaming.items()}
        variables.update(renamed)

        for old_name, new_name in unifier.nonlocal_variables_renaming.items():
            var = context.get_nonlocal_variable(old_name)
            # When no value is returned, hash a placeholder that still records the original name.
            # NOTE(review): a nonlocal whose value is genuinely None is indistinguishable from a
            # missing reference here.
            variables[new_name] = var if var is not None else (old_name, "__checkpointing_no_nonlocal_reference__")

        # Sorting by key makes the hash independent of dict insertion order.
        return hash_anything(
            *sorted(variables.items()),
            unifier.unified_ast_dump,
            algorithm=self.algorithm,
            pickle_protocol=self.pickle_protocol,
        )
class AutoFuncCallIdentifier(checkpointing.identifier.func_call.base.FuncCallIdentifierBase):
class AutoFuncCallIdentifier(FuncCallIdentifierBase):
    """
    Function call identifier that derives the context ID automatically.

    The ID is a hash of the call's arguments, the nonlocal variables the function
    references, and the unified AST dump produced by `FunctionDefinitionUnifier`.
    Argument and nonlocal names are remapped through the unifier's renaming tables
    before hashing, so the hashed values are keyed consistently with the unified code.
    """

    def __init__(self, algorithm: Optional[str] = None, pickle_protocol: Optional[int] = None) -> None:
        """
        Args:
            algorithm: the hash algorithm to use. If None, use the global default `hash.algorithm`.
            pickle_protocol: the pickle protocol to use. If None, use the global default `hash.pickle_protocol`.
        """

        # Global defaults are read once, at construction time.
        if algorithm is None:
            algorithm = defaults["hash.algorithm"]

        if pickle_protocol is None:
            pickle_protocol = defaults["hash.pickle_protocol"]

        self.algorithm = algorithm
        self.pickle_protocol = pickle_protocol

    def identify(self, context: FuncCallContext) -> ContextId:
        """
        Identifies the context by producing a hash value.

        Returns:
            the function call context identifier

        Raises:
            GlobalStatementError: if the function code contains a `global` statement.
            NonlocalStatementError: if the function code contains a `nonlocal` statement.
        """

        unifier = FunctionDefinitionUnifier(context.code)
        self.__check_unsupported_statements(unifier, context.code)

        return self.__identify_with_unifier(context, unifier)

    def __check_unsupported_statements(self, unifier: FunctionDefinitionUnifier, original_code: str) -> None:
        """
        Reject function code that uses `global` or `nonlocal` statements.

        Args:
            unifier: the unifier built from `original_code`; its statement flags are inspected.
            original_code: the function source, echoed in the error message.

        Raises:
            GlobalStatementError: if `unifier.has_global_statement` is set.
            NonlocalStatementError: if `unifier.has_nonlocal_statement` is set.
        """

        def error_text(stmt_type: str) -> str:
            # Built explicitly rather than with textwrap.dedent: dedent computes the common
            # indentation *after* interpolation, so multi-line user code starting at column 0
            # defeated it and left the message indented. The user's code is kept verbatim.
            return (
                f"\n'{stmt_type}' statement detected in the code. This indicates that you are changing a "
                f"{stmt_type} variable in the function, which is not a use case with checkpointing.\n"
                "Your code:\n"
                f"{original_code}\n"
            )

        if unifier.has_global_statement:
            raise GlobalStatementError(error_text("global"))

        if unifier.has_nonlocal_statement:
            raise NonlocalStatementError(error_text("nonlocal"))

    def __identify_with_unifier(self, context: FuncCallContext, unifier: FunctionDefinitionUnifier) -> ContextId:
        """
        Hash the (renamed) call arguments and nonlocal values together with the unified code.

        Raises:
            KeyError: if `unifier.args_renaming` names an argument missing from `context.arguments`.
        """

        # Shallow copy so the renaming below does not mutate the caller's mapping.
        variables = copy.copy(context.arguments)

        # Pop every renamed argument first, then insert them all under their new names.
        # Renaming in place would overwrite a value whenever a new name equals an old name
        # that has not been processed yet (e.g. a swap a -> b, b -> a).
        renamed = {new_name: variables.pop(old_name) for old_name, new_name in unifier.args_renaming.items()}
        variables.update(renamed)

        for old_name, new_name in unifier.nonlocal_variables_renaming.items():
            var = context.get_nonlocal_variable(old_name)
            # When no value is returned, hash a placeholder that still records the original name.
            # NOTE(review): a nonlocal whose value is genuinely None is indistinguishable from a
            # missing reference here.
            variables[new_name] = var if var is not None else (old_name, "__checkpointing_no_nonlocal_reference__")

        # Sorting by key makes the hash independent of dict insertion order.
        return hash_anything(
            *sorted(variables.items()),
            unifier.unified_ast_dump,
            algorithm=self.algorithm,
            pickle_protocol=self.pickle_protocol,
        )

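Automatic function call identifier, built on the base class for function call identifiers: it hashes the call arguments, the referenced nonlocal variables, and the unified function definition.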

AutoFuncCallIdentifier(algorithm: str = None, pickle_protocol: int = None)
def __init__(self, algorithm: Optional[str] = None, pickle_protocol: Optional[int] = None) -> None:
    """
    Args:
        algorithm: the hash algorithm to use. If None, use the global default `hash.algorithm`.
        pickle_protocol: the pickle protocol to use. If None, use the global default `hash.pickle_protocol`.
    """

    # Global defaults are read once, at construction time.
    if algorithm is None:
        algorithm = defaults["hash.algorithm"]

    if pickle_protocol is None:
        pickle_protocol = defaults["hash.pickle_protocol"]

    self.algorithm = algorithm
    self.pickle_protocol = pickle_protocol
Args
  • algorithm: the hash algorithm to use. If None, use the global default hash.algorithm.
  • pickle_protocol: the pickle protocol to use. If None, use the global default hash.pickle_protocol.
def identify(self, context: checkpointing.identifier.func_call.context.FuncCallContext) -> ~ContextId:
def identify(self, context: FuncCallContext) -> ContextId:
    """
    Compute the identifier of a function call context as a hash value.

    The function code is first normalized by `FunctionDefinitionUnifier` and checked
    for unsupported statements; the hash is then computed from the unified result.

    Returns:
        the function call context identifier
    """

    normalized = FunctionDefinitionUnifier(context.code)
    # Fail fast on unsupported code before doing any hashing work.
    self.__check_unsupported_statements(normalized, context.code)

    context_id = self.__identify_with_unifier(context, normalized)
    return context_id

Identifies the context by producing a hash value.

Returns

the function call context identifier