checkpointing.decorator.base
from abc import ABC, abstractmethod
from functools import wraps
from typing import Callable, Dict, Generic, List, Tuple, TypeVar
from types import FrameType
from warnings import warn

from checkpointing.exceptions import CheckpointNotExist, ExpensiveOverheadWarning, CheckpointFailedWarning, CheckpointFailedError
from checkpointing.util.timing import Timer, timed_run
from checkpointing._typing import ReturnValue, ContextId
from checkpointing.identifier.func_call.context import FuncCallContext
from checkpointing.identifier.func_call import FuncCallIdentifierBase
from checkpointing.logging import logger
from checkpointing.config import defaults
from checkpointing.cache import CacheBase
import inspect


class DecoratorCheckpoint(ABC, Generic[ReturnValue]):
    """The base class for any decorator checkpoint."""

    def __init__(self, identifier: FuncCallIdentifierBase, cache: CacheBase, on_error: str = None) -> None:
        """
        Args:
            identifier: the function call identifier that creates an ID for any function call context
            cache: the cache that saves and retrieves the return value with a given ID
            on_error: the behavior when retrieval or saving raises unexpected exceptions
                (exceptions other than checkpointing.CheckpointNotExist). Possible values are:
                - `"raise"`, the exception will be raised.
                - `"warn"`, a warning will be issued to inform that the checkpointing task has failed.
                    But the user function will be invoked and executed as if it wasn't checkpointed.
                - `"ignore"`, the exception will be ignored and the user function will be invoked and executed normally.

                If None, use the global default `checkpoint.on_error`.

        Raises:
            ValueError: if `identifier` or `cache` has the wrong type, or `on_error` is not a supported value.
        """

        self.__identifier = identifier
        """The function call identifier"""

        self.__cache = cache
        """The cache instance"""

        if on_error is None:
            on_error = defaults["checkpoint.on_error"]

        self.__on_error: str = on_error
        """The behavior when saving or retrieval raises unexpected exceptions."""

        self.__definition_frame: FrameType = None
        """The frame in which the most recently decorated function was defined (kept for backward compatibility)."""

        self.__validate_params()

    def __validate_params(self):
        """Raise ValueError if any constructor argument has the wrong type or an unsupported value."""
        if not isinstance(self.__identifier, FuncCallIdentifierBase):
            raise ValueError(f"Invalid type for identifier: {type(self.__identifier)}")

        if not isinstance(self.__cache, CacheBase):
            raise ValueError(f"Invalid type for cache: {type(self.__cache)}")

        error = ["raise", "warn", "ignore"]
        if self.__on_error not in error:
            raise ValueError(f"Invalid argument value for error: {self.__on_error}, must be one of {error}")

    def __call__(self, func: Callable[..., ReturnValue]) -> Callable[..., ReturnValue]:
        """
        Magic method invoked when used as a decorator.

        The frame in which `func` is decorated is captured here and bound to this particular
        wrapper (and its `rerun`), so one checkpoint instance can decorate several functions
        without them sharing the definition frame of whichever function was decorated last.
        """
        logger.debug(f"{self.__class__.__name__} created for {func.__qualname__}")

        current_frame = inspect.currentframe()
        try:
            definition_frame = current_frame.f_back if current_frame is not None else None
        finally:
            # Drop the reference to our own frame to avoid a frame <-> local reference cycle.
            del current_frame

        # Still recorded on the instance for backward compatibility; the wrappers use their own copy.
        self.__definition_frame = definition_frame

        inner = self.__create_inner(func, definition_frame)

        self.__bind_rerun(func, inner, definition_frame)

        return inner

    def __get_context_and_id(self, func, args, kwargs, definition_frame: FrameType):
        """
        Build the call context for `func(*args, **kwargs)` and compute its cache ID.

        Exceptions raised here (by context construction or identification) are NOT routed
        through `on_error`; they propagate to the caller unchanged.
        """
        context = FuncCallContext(func, args, kwargs, definition_frame)
        context_id = self.__identifier.identify(context)
        return context, context_id

    def __create_inner(self, func: Callable[..., ReturnValue], definition_frame: FrameType) -> Callable[..., ReturnValue]:
        """Create the checkpointed wrapper: return the cached result if available, otherwise run `func` and save."""

        @wraps(func)
        def inner(*args, **kwargs) -> ReturnValue:

            context, context_id = self.__get_context_and_id(func, args, kwargs, definition_frame)
            retrieve_success, res, retrieve_time = self.__timed_safe_retrieve(context, context_id)

            if retrieve_success:
                logger.info(f"Result of {context.qualified_name} with args {context.arguments} retrieved from cache")
                return res

            logger.info(f"Result of {context.qualified_name} with args {context.arguments} unavailable from cache")

            res, run_time = timed_run(func, *args, **kwargs)

            save_time = self.__timed_safe_save(context, context_id, res)

            self.__warn_if_more_expensive(context, retrieve_time + save_time, run_time)
            return res

        return inner

    def __bind_rerun(
        self,
        original_func: Callable[..., ReturnValue],
        inner_func: Callable[..., ReturnValue],
        definition_frame: FrameType,
    ) -> None:
        """Attach a `rerun` attribute to `inner_func` that always runs `original_func` and overwrites the cache."""

        def rerun(*args, **kwargs) -> ReturnValue:
            context, context_id = self.__get_context_and_id(original_func, args, kwargs, definition_frame)

            logger.info(f"Forcing rerun of {context.full_name} with args {context.arguments}")

            res, run_time = timed_run(original_func, *args, **kwargs)

            save_time = self.__timed_safe_save(context, context_id, res)

            self.__warn_if_more_expensive(context, save_time, run_time)
            return res

        inner_func.rerun = rerun

    def __warn_if_more_expensive(self, context: FuncCallContext, checkpoint_time: float, run_time: float, tol: float = 0.1) -> None:
        """
        Warn the user if the checkpointing overhead (retrieval and/or saving) takes longer than running the function.

        Args:
            context: the function call context, used to name the function in the warning
            checkpoint_time: approximate time for retrieving and saving the cached result
            run_time: time for running the function
            tol: tolerance of the difference between checkpoint_time and run_time in seconds.
                Larger value indicates more tolerance of slow checkpointing, compared to actual function running, without issuing a warning.
                Negative value indicates checkpoint should take less time than function running to avoid a warning.
        """

        if checkpoint_time > run_time + tol:
            warn(
                f"The overhead for checkpointing '{context.full_name}' could possibly take more time than the function call itself "
                f"({checkpoint_time:.2f}s > {run_time:.2f}s). "
                "Consider optimize the checkpoint or just remove it, and let the function execute every time.",
                category=ExpensiveOverheadWarning,
                # this method <- inner/rerun <- user's call site
                stacklevel=3,
            )

    def __timed_safe_retrieve(self, context: FuncCallContext, context_id: ContextId) -> Tuple[bool, ReturnValue, float]:
        """
        Retrieve the cached result, tracking the time and capturing any error,
        dealing with them according to the level specified by `self.__on_error`

        Returns:
            A tuple of three elements:
            - bool: whether the retrieval succeeds or not
            - ReturnValue: the extracted return value, if successful, otherwise None
            - float: the time (seconds) it takes to retrieve the result
        """

        timer = Timer().start()
        try:
            res = self.__cache.retrieve(context_id)
            return True, res, timer.time

        except CheckpointNotExist:
            return False, None, timer.time

        except Exception as e:
            self.__handle_unexpected_error(context, e)
            return False, None, timer.time

    def __handle_unexpected_error(self, context: FuncCallContext, error: Exception):
        """
        Handle the unexpected error according to the level specified by `self.__on_error`.

        Args:
            context: the function call context, used to name the function in the message
            error: the raised exception. Note that checkpointing.exceptions.CheckpointNotExist should NOT be handled by this method.
                It should be dealt with inside the saving/retrieving methods.

        Raises:
            CheckpointFailedError: if `on_error` is "raise"; the original error is chained as its cause.
        """
        if self.__on_error == "raise":
            raise CheckpointFailedError(f"Checkpointing for {context.full_name} failed because of the following error: {str(error)}", error) from error

        elif self.__on_error == "warn":
            warn(
                f"Checkpointing for {context.full_name} failed because of the following error: {str(error)}. "
                "The function is called to compute the return value.",
                CheckpointFailedWarning,
                # this method <- __timed_safe_retrieve/__timed_safe_save <- inner/rerun <- user's call site
                stacklevel=4,
            )

        else:  # self.__on_error == "ignore"
            pass

    def __timed_safe_save(self, context: FuncCallContext, context_id: ContextId, result: ReturnValue) -> float:
        """
        Save the result, tracking the time and capturing any error,
        dealing with them according to the level specified by `self.__on_error`

        Returns:
            the time (seconds) it takes to save the result
        """

        timer = Timer().start()
        try:
            self.__cache.save(context_id, result)
            logger.info(f"Result of {context.qualified_name} with args {context.arguments} saved to cache")

        except Exception as e:
            self.__handle_unexpected_error(context, e)

        return timer.time
class DecoratorCheckpoint(abc.ABC, typing.Generic[~ReturnValue]):
class DecoratorCheckpoint(ABC, Generic[ReturnValue]):
    """The base class for any decorator checkpoint."""

    def __init__(self, identifier: FuncCallIdentifierBase, cache: CacheBase, on_error: str = None) -> None:
        """
        Args:
            identifier: the function call identifier that creates an ID for any function call context
            cache: the cache that saves and retrieves the return value with a given ID
            on_error: the behavior when retrieval or saving raises unexpected exceptions
                (exceptions other than checkpointing.CheckpointNotExist). Possible values are:
                - `"raise"`, the exception will be raised.
                - `"warn"`, a warning will be issued to inform that the checkpointing task has failed.
                    But the user function will be invoked and executed as if it wasn't checkpointed.
                - `"ignore"`, the exception will be ignored and the user function will be invoked and executed normally.

                If None, use the global default `checkpoint.on_error`.

        Raises:
            ValueError: if `identifier` or `cache` has the wrong type, or `on_error` is not a supported value.
        """

        self.__identifier = identifier
        """The function call identifier"""

        self.__cache = cache
        """The cache instance"""

        if on_error is None:
            on_error = defaults["checkpoint.on_error"]

        self.__on_error: str = on_error
        """The behavior when saving or retrieval raises unexpected exceptions."""

        self.__definition_frame: FrameType = None
        """The frame in which the most recently decorated function was defined (kept for backward compatibility)."""

        self.__validate_params()

    def __validate_params(self):
        """Raise ValueError if any constructor argument has the wrong type or an unsupported value."""
        if not isinstance(self.__identifier, FuncCallIdentifierBase):
            raise ValueError(f"Invalid type for identifier: {type(self.__identifier)}")

        if not isinstance(self.__cache, CacheBase):
            raise ValueError(f"Invalid type for cache: {type(self.__cache)}")

        error = ["raise", "warn", "ignore"]
        if self.__on_error not in error:
            raise ValueError(f"Invalid argument value for error: {self.__on_error}, must be one of {error}")

    def __call__(self, func: Callable[..., ReturnValue]) -> Callable[..., ReturnValue]:
        """
        Magic method invoked when used as a decorator.

        The frame in which `func` is decorated is captured here and bound to this particular
        wrapper (and its `rerun`), so one checkpoint instance can decorate several functions
        without them sharing the definition frame of whichever function was decorated last.
        """
        logger.debug(f"{self.__class__.__name__} created for {func.__qualname__}")

        current_frame = inspect.currentframe()
        try:
            definition_frame = current_frame.f_back if current_frame is not None else None
        finally:
            # Drop the reference to our own frame to avoid a frame <-> local reference cycle.
            del current_frame

        # Still recorded on the instance for backward compatibility; the wrappers use their own copy.
        self.__definition_frame = definition_frame

        inner = self.__create_inner(func, definition_frame)

        self.__bind_rerun(func, inner, definition_frame)

        return inner

    def __get_context_and_id(self, func, args, kwargs, definition_frame: FrameType):
        """
        Build the call context for `func(*args, **kwargs)` and compute its cache ID.

        Exceptions raised here (by context construction or identification) are NOT routed
        through `on_error`; they propagate to the caller unchanged.
        """
        context = FuncCallContext(func, args, kwargs, definition_frame)
        context_id = self.__identifier.identify(context)
        return context, context_id

    def __create_inner(self, func: Callable[..., ReturnValue], definition_frame: FrameType) -> Callable[..., ReturnValue]:
        """Create the checkpointed wrapper: return the cached result if available, otherwise run `func` and save."""

        @wraps(func)
        def inner(*args, **kwargs) -> ReturnValue:

            context, context_id = self.__get_context_and_id(func, args, kwargs, definition_frame)
            retrieve_success, res, retrieve_time = self.__timed_safe_retrieve(context, context_id)

            if retrieve_success:
                logger.info(f"Result of {context.qualified_name} with args {context.arguments} retrieved from cache")
                return res

            logger.info(f"Result of {context.qualified_name} with args {context.arguments} unavailable from cache")

            res, run_time = timed_run(func, *args, **kwargs)

            save_time = self.__timed_safe_save(context, context_id, res)

            self.__warn_if_more_expensive(context, retrieve_time + save_time, run_time)
            return res

        return inner

    def __bind_rerun(
        self,
        original_func: Callable[..., ReturnValue],
        inner_func: Callable[..., ReturnValue],
        definition_frame: FrameType,
    ) -> None:
        """Attach a `rerun` attribute to `inner_func` that always runs `original_func` and overwrites the cache."""

        def rerun(*args, **kwargs) -> ReturnValue:
            context, context_id = self.__get_context_and_id(original_func, args, kwargs, definition_frame)

            logger.info(f"Forcing rerun of {context.full_name} with args {context.arguments}")

            res, run_time = timed_run(original_func, *args, **kwargs)

            save_time = self.__timed_safe_save(context, context_id, res)

            self.__warn_if_more_expensive(context, save_time, run_time)
            return res

        inner_func.rerun = rerun

    def __warn_if_more_expensive(self, context: FuncCallContext, checkpoint_time: float, run_time: float, tol: float = 0.1) -> None:
        """
        Warn the user if the checkpointing overhead (retrieval and/or saving) takes longer than running the function.

        Args:
            context: the function call context, used to name the function in the warning
            checkpoint_time: approximate time for retrieving and saving the cached result
            run_time: time for running the function
            tol: tolerance of the difference between checkpoint_time and run_time in seconds.
                Larger value indicates more tolerance of slow checkpointing, compared to actual function running, without issuing a warning.
                Negative value indicates checkpoint should take less time than function running to avoid a warning.
        """

        if checkpoint_time > run_time + tol:
            warn(
                f"The overhead for checkpointing '{context.full_name}' could possibly take more time than the function call itself "
                f"({checkpoint_time:.2f}s > {run_time:.2f}s). "
                "Consider optimize the checkpoint or just remove it, and let the function execute every time.",
                category=ExpensiveOverheadWarning,
                # this method <- inner/rerun <- user's call site
                stacklevel=3,
            )

    def __timed_safe_retrieve(self, context: FuncCallContext, context_id: ContextId) -> Tuple[bool, ReturnValue, float]:
        """
        Retrieve the cached result, tracking the time and capturing any error,
        dealing with them according to the level specified by `self.__on_error`

        Returns:
            A tuple of three elements:
            - bool: whether the retrieval succeeds or not
            - ReturnValue: the extracted return value, if successful, otherwise None
            - float: the time (seconds) it takes to retrieve the result
        """

        timer = Timer().start()
        try:
            res = self.__cache.retrieve(context_id)
            return True, res, timer.time

        except CheckpointNotExist:
            return False, None, timer.time

        except Exception as e:
            self.__handle_unexpected_error(context, e)
            return False, None, timer.time

    def __handle_unexpected_error(self, context: FuncCallContext, error: Exception):
        """
        Handle the unexpected error according to the level specified by `self.__on_error`.

        Args:
            context: the function call context, used to name the function in the message
            error: the raised exception. Note that checkpointing.exceptions.CheckpointNotExist should NOT be handled by this method.
                It should be dealt with inside the saving/retrieving methods.

        Raises:
            CheckpointFailedError: if `on_error` is "raise"; the original error is chained as its cause.
        """
        if self.__on_error == "raise":
            raise CheckpointFailedError(f"Checkpointing for {context.full_name} failed because of the following error: {str(error)}", error) from error

        elif self.__on_error == "warn":
            warn(
                f"Checkpointing for {context.full_name} failed because of the following error: {str(error)}. "
                "The function is called to compute the return value.",
                CheckpointFailedWarning,
                # this method <- __timed_safe_retrieve/__timed_safe_save <- inner/rerun <- user's call site
                stacklevel=4,
            )

        else:  # self.__on_error == "ignore"
            pass

    def __timed_safe_save(self, context: FuncCallContext, context_id: ContextId, result: ReturnValue) -> float:
        """
        Save the result, tracking the time and capturing any error,
        dealing with them according to the level specified by `self.__on_error`

        Returns:
            the time (seconds) it takes to save the result
        """

        timer = Timer().start()
        try:
            self.__cache.save(context_id, result)
            logger.info(f"Result of {context.qualified_name} with args {context.arguments} saved to cache")

        except Exception as e:
            self.__handle_unexpected_error(context, e)

        return timer.time
The base class for any decorator checkpoint.
DecoratorCheckpoint( identifier: checkpointing.identifier.func_call.base.FuncCallIdentifierBase, cache: checkpointing.cache.base.CacheBase, on_error: str = None)
def __init__(self, identifier: FuncCallIdentifierBase, cache: CacheBase, on_error: str = None) -> None:
    """
    Set up a decorator checkpoint from an identifier, a cache and an error policy.

    Args:
        identifier: produces an ID for every function call context
        cache: stores the return value under that ID and looks it up again
        on_error: what to do when retrieval or saving raises an unexpected exception
            (anything other than checkpointing.CheckpointNotExist). Accepted values:
            - `"raise"`: the exception is raised.
            - `"warn"`: a warning reports that checkpointing failed; the user function
              is still invoked and executed as if it were not checkpointed.
            - `"ignore"`: the exception is ignored and the user function runs normally.

            When None, the global default `checkpoint.on_error` is used.

    Raises:
        ValueError: if `identifier`/`cache` has the wrong type or `on_error` is unsupported.
    """
    self.__identifier = identifier
    """The function call identifier"""

    self.__cache = cache
    """The cache instance"""

    # Only consult the global defaults when the caller did not choose a policy.
    self.__on_error: str = defaults["checkpoint.on_error"] if on_error is None else on_error
    """The behavior when saving or retrieval raises unexpected exceptions."""

    # Filled in later, when the instance is applied as a decorator.
    self.__definition_frame: FrameType = None

    self.__validate_params()
Args
- identifier: the function call identifier that creates an ID for any function call context
- cache: the cache that saves and retrieves the return value with a given ID
- on_error: the behavior when retrieval or saving raises unexpected exceptions (exceptions other than checkpointing.CheckpointNotExist). Possible values are:
    - "raise": the exception will be raised.
    - "warn": a warning will be issued to inform that the checkpointing task has failed, but the user function will be invoked and executed as if it wasn't checkpointed.
    - "ignore": the exception will be ignored and the user function will be invoked and executed normally.
If None, use the global default `checkpoint.on_error`.