checkpointing.decorator.base

  1from abc import ABC, abstractmethod
  2from functools import wraps
  3from typing import Callable, Dict, Generic, List, Tuple, TypeVar
  4from types import FrameType
  5from warnings import warn
  6
  7from checkpointing.exceptions import CheckpointNotExist, ExpensiveOverheadWarning, CheckpointFailedWarning, CheckpointFailedError
  8from checkpointing.util.timing import Timer, timed_run
  9from checkpointing._typing import ReturnValue, ContextId
 10from checkpointing.identifier.func_call.context import FuncCallContext
 11from checkpointing.identifier.func_call import FuncCallIdentifierBase
 12from checkpointing.logging import logger
 13from checkpointing.config import defaults
 14from checkpointing.cache import CacheBase
 15import inspect
 16
 17
 18class DecoratorCheckpoint(ABC, Generic[ReturnValue]):
 19    """The base class for any decorator checkpoint."""
 20
 21    def __init__(self, identifier: FuncCallIdentifierBase, cache: CacheBase, on_error: str = None) -> None:
 22        """
 23        Args:
 24            identifier: the function call identifier that creates an ID for any function call context
 25            cache: the cache that saves and retrieves the return value with a given ID
 26            on_error: the behavior when retrieval or saving raises unexpected exceptions
 27                (exceptions other than checkpointing.CheckpointNotExist). Possible values are:
 28                - `"raise"`, the exception will be raised.
 29                - `"warn"`, a warning will be issued to inform that the checkpointing task has failed.
 30                    But the user function will be invoked and executed as if it wasn't checkpointed.
 31                - `"ignore"`, the exception will be ignored and the user function will be invoked and executed normally.
 32                
 33                If None, use the global default `checkpoint.on_error`.
 34        """
 35
 36        self.__identifier = identifier
 37        """The function call identifier"""
 38
 39        self.__cache = cache
 40        """The cache instance"""
 41
 42        if on_error is None:
 43            on_error = defaults["checkpoint.on_error"]
 44
 45        self.__on_error: str = on_error
 46        """The behavior when identification, saving or retrieval raises unexpected exceptions."""
 47
 48        self.__definition_frame: FrameType = None
 49
 50        self.__validate_params()
 51
 52    def __validate_params(self):
 53        if not isinstance(self.__identifier, FuncCallIdentifierBase):
 54            raise ValueError(f"Invalid type for identifier: {type(self.__identifier)}")
 55
 56        if not isinstance(self.__cache, CacheBase):
 57            raise ValueError(f"Invalid type for cache: {type(self.__cache)}")
 58
 59        error = ["raise", "warn", "ignore"]
 60        if self.__on_error not in error:
 61            raise ValueError(f"Invalid argument value for error: {self.__on_error}, must be one of {error}")
 62
 63    def __call__(self, func: Callable[..., ReturnValue]) -> Callable[..., ReturnValue]:
 64        """Magic method invoked when used as a decorator."""
 65        logger.debug(f"{self.__class__.__name__} created for {func.__qualname__}")
 66
 67        current_frame = inspect.currentframe()
 68        self.__definition_frame = current_frame.f_back if current_frame is not None else None
 69
 70        inner = self.__create_inner(func)
 71
 72        self.__bind_rerun(func, inner)
 73
 74        return inner
 75
 76    def __get_context_and_id(self, func, args, kwargs):
 77        context = FuncCallContext(func, args, kwargs, self.__definition_frame)
 78        context_id = self.__identifier.identify(context)
 79        return context, context_id
 80
 81    def __create_inner(self, func: Callable[..., ReturnValue]) -> Callable[..., ReturnValue]:
 82        @wraps(func)
 83        def inner(*args, **kwargs) -> ReturnValue:
 84
 85            context, context_id = self.__get_context_and_id(func, args, kwargs)
 86            retrieve_success, res, retrieve_time = self.__timed_safe_retrieve(context, context_id)
 87
 88            if retrieve_success:
 89                logger.info(f"Result of {context.qualified_name} with args {context.arguments} retrieved from cache")
 90                return res
 91
 92            else:
 93                logger.info(f"Result of {context.qualified_name} with args {context.arguments} unavailable from cache")
 94
 95                res, run_time = timed_run(func, *args, **kwargs)
 96
 97                save_time = self.__timed_safe_save(context, context_id, res)
 98
 99                self.__warn_if_more_expensive(context, retrieve_time + save_time, run_time)
100                return res
101
102        return inner
103
104    def __bind_rerun(self, original_func: Callable[..., ReturnValue], inner_func: Callable[..., ReturnValue]) -> None:
105        def rerun(*args, **kwargs) -> ReturnValue:
106            context, context_id = self.__get_context_and_id(original_func, args, kwargs)
107
108            logger.info(f"Forcing rerun of {context.full_name} with args {context.arguments}")
109
110            res, run_time = timed_run(original_func, *args, **kwargs)
111
112            save_time = self.__timed_safe_save(context, context_id, res)
113
114            self.__warn_if_more_expensive(context, save_time, run_time)
115            return res
116
117        inner_func.rerun = rerun
118
119    def __warn_if_more_expensive(self, context: FuncCallContext, checkpoint_time: float, run_time: float, tol: float = 0.1) -> None:
120        """
121        Warn the user if retrieval takes longer than running the function.
122
123        Args:
124            checkpoint_time: approximate time for retrieving and saving the cached result
125            run_time: time for running the function
126            tol: tolerance of the difference between checkpoint_time and run_time in seconds.
127                Larger value indicates more tolerance of slow checkpointing, compared to actual function running, without raising an error.
128                Negative value indicates checkpoint should take less time than function running to avoid raising an error.
129        """
130
131        if checkpoint_time > run_time + tol:
132            warn(
133                f"The overhead for checkpointing '{context.full_name}' could possibly take more time than the function call itself "
134                f"({checkpoint_time:.2f}s > {run_time:.2f}s). "
135                "Consider optimize the checkpoint or just remove it, and let the function execute every time.",
136                category=ExpensiveOverheadWarning,
137            )
138
139    def __timed_safe_retrieve(self, context: FuncCallContext, context_id: ContextId) -> Tuple[bool, ReturnValue, float]:
140        """
141        Retrieve the cached result, tracking the time and capturing any error,
142        dealing with them according to the level specified by `self.__on_error`
143
144        Returns:
145            A tuple of three elements:
146            - bool: whether the retrival succeeds or not
147            - ReturnValue: the extracted return value, if successful, otherwise None
148            - float: the time (seconds) it takes to retrieve the result
149        """
150
151        timer = Timer().start()
152        try:
153            res = self.__cache.retrieve(context_id)
154            return True, res, timer.time
155
156        except CheckpointNotExist:
157            return False, None, timer.time
158
159        except Exception as e:
160            self.__handle_unexpected_error(context, e)
161            return False, None, timer.time
162
163    def __handle_unexpected_error(self, context: FuncCallContext, error: Exception):
164        """
165        Handle the unexpected error according to the level specified by `self.__on_error`.
166
167        Args:
168            error: the raised exception. Note that checkpointing.exceptions.CheckpointNotExist should NOT be handled by this method.
169                    It should be dealt within the saving/retrieving methods.
170        """
171        if self.__on_error == "raise":
172            raise CheckpointFailedError(f"Checkpointing for {context.full_name} failed because of the following error: {str(error)}", error)
173
174        elif self.__on_error == "warn":
175            warn(
176                f"Checkpointing for {context.full_name} failed because of the following error: {str(error)}. "
177                "The function is called to compute the return value.",
178                CheckpointFailedWarning,
179            )
180
181        else:  # self.__on_error == "ignore"
182            pass
183
184    def __timed_safe_save(self, context: FuncCallContext, context_id: ContextId, result: ReturnValue) -> float:
185        """
186        Save the result, tracking the time and capturing any error,
187        dealing with them according to the level specified by `self.__on_error`
188
189        Returns:
190            the time (seconds) it takes to save the result
191        """
192
193        timer = Timer().start()
194        try:
195            self.__cache.save(context_id, result)
196            logger.info(f"Result of {context.qualified_name} with args {context.arguments} saved to cache")
197
198        except Exception as e:
199            self.__handle_unexpected_error(context, e)
200
201        return timer.time
class DecoratorCheckpoint(abc.ABC, typing.Generic[~ReturnValue]):
 19class DecoratorCheckpoint(ABC, Generic[ReturnValue]):
 20    """The base class for any decorator checkpoint."""
 21
 22    def __init__(self, identifier: FuncCallIdentifierBase, cache: CacheBase, on_error: str = None) -> None:
 23        """
 24        Args:
 25            identifier: the function call identifier that creates an ID for any function call context
 26            cache: the cache that saves and retrieves the return value with a given ID
 27            on_error: the behavior when retrieval or saving raises unexpected exceptions
 28                (exceptions other than checkpointing.CheckpointNotExist). Possible values are:
 29                - `"raise"`, the exception will be raised.
 30                - `"warn"`, a warning will be issued to inform that the checkpointing task has failed.
 31                    But the user function will be invoked and executed as if it wasn't checkpointed.
 32                - `"ignore"`, the exception will be ignored and the user function will be invoked and executed normally.
 33                
 34                If None, use the global default `checkpoint.on_error`.
 35        """
 36
 37        self.__identifier = identifier
 38        """The function call identifier"""
 39
 40        self.__cache = cache
 41        """The cache instance"""
 42
 43        if on_error is None:
 44            on_error = defaults["checkpoint.on_error"]
 45
 46        self.__on_error: str = on_error
 47        """The behavior when identification, saving or retrieval raises unexpected exceptions."""
 48
 49        self.__definition_frame: FrameType = None
 50
 51        self.__validate_params()
 52
 53    def __validate_params(self):
 54        if not isinstance(self.__identifier, FuncCallIdentifierBase):
 55            raise ValueError(f"Invalid type for identifier: {type(self.__identifier)}")
 56
 57        if not isinstance(self.__cache, CacheBase):
 58            raise ValueError(f"Invalid type for cache: {type(self.__cache)}")
 59
 60        error = ["raise", "warn", "ignore"]
 61        if self.__on_error not in error:
 62            raise ValueError(f"Invalid argument value for error: {self.__on_error}, must be one of {error}")
 63
 64    def __call__(self, func: Callable[..., ReturnValue]) -> Callable[..., ReturnValue]:
 65        """Magic method invoked when used as a decorator."""
 66        logger.debug(f"{self.__class__.__name__} created for {func.__qualname__}")
 67
 68        current_frame = inspect.currentframe()
 69        self.__definition_frame = current_frame.f_back if current_frame is not None else None
 70
 71        inner = self.__create_inner(func)
 72
 73        self.__bind_rerun(func, inner)
 74
 75        return inner
 76
 77    def __get_context_and_id(self, func, args, kwargs):
 78        context = FuncCallContext(func, args, kwargs, self.__definition_frame)
 79        context_id = self.__identifier.identify(context)
 80        return context, context_id
 81
 82    def __create_inner(self, func: Callable[..., ReturnValue]) -> Callable[..., ReturnValue]:
 83        @wraps(func)
 84        def inner(*args, **kwargs) -> ReturnValue:
 85
 86            context, context_id = self.__get_context_and_id(func, args, kwargs)
 87            retrieve_success, res, retrieve_time = self.__timed_safe_retrieve(context, context_id)
 88
 89            if retrieve_success:
 90                logger.info(f"Result of {context.qualified_name} with args {context.arguments} retrieved from cache")
 91                return res
 92
 93            else:
 94                logger.info(f"Result of {context.qualified_name} with args {context.arguments} unavailable from cache")
 95
 96                res, run_time = timed_run(func, *args, **kwargs)
 97
 98                save_time = self.__timed_safe_save(context, context_id, res)
 99
100                self.__warn_if_more_expensive(context, retrieve_time + save_time, run_time)
101                return res
102
103        return inner
104
105    def __bind_rerun(self, original_func: Callable[..., ReturnValue], inner_func: Callable[..., ReturnValue]) -> None:
106        def rerun(*args, **kwargs) -> ReturnValue:
107            context, context_id = self.__get_context_and_id(original_func, args, kwargs)
108
109            logger.info(f"Forcing rerun of {context.full_name} with args {context.arguments}")
110
111            res, run_time = timed_run(original_func, *args, **kwargs)
112
113            save_time = self.__timed_safe_save(context, context_id, res)
114
115            self.__warn_if_more_expensive(context, save_time, run_time)
116            return res
117
118        inner_func.rerun = rerun
119
120    def __warn_if_more_expensive(self, context: FuncCallContext, checkpoint_time: float, run_time: float, tol: float = 0.1) -> None:
121        """
122        Warn the user if retrieval takes longer than running the function.
123
124        Args:
125            checkpoint_time: approximate time for retrieving and saving the cached result
126            run_time: time for running the function
127            tol: tolerance of the difference between checkpoint_time and run_time in seconds.
128                Larger value indicates more tolerance of slow checkpointing, compared to actual function running, without raising an error.
129                Negative value indicates checkpoint should take less time than function running to avoid raising an error.
130        """
131
132        if checkpoint_time > run_time + tol:
133            warn(
134                f"The overhead for checkpointing '{context.full_name}' could possibly take more time than the function call itself "
135                f"({checkpoint_time:.2f}s > {run_time:.2f}s). "
136                "Consider optimize the checkpoint or just remove it, and let the function execute every time.",
137                category=ExpensiveOverheadWarning,
138            )
139
140    def __timed_safe_retrieve(self, context: FuncCallContext, context_id: ContextId) -> Tuple[bool, ReturnValue, float]:
141        """
142        Retrieve the cached result, tracking the time and capturing any error,
143        dealing with them according to the level specified by `self.__on_error`
144
145        Returns:
146            A tuple of three elements:
147            - bool: whether the retrival succeeds or not
148            - ReturnValue: the extracted return value, if successful, otherwise None
149            - float: the time (seconds) it takes to retrieve the result
150        """
151
152        timer = Timer().start()
153        try:
154            res = self.__cache.retrieve(context_id)
155            return True, res, timer.time
156
157        except CheckpointNotExist:
158            return False, None, timer.time
159
160        except Exception as e:
161            self.__handle_unexpected_error(context, e)
162            return False, None, timer.time
163
164    def __handle_unexpected_error(self, context: FuncCallContext, error: Exception):
165        """
166        Handle the unexpected error according to the level specified by `self.__on_error`.
167
168        Args:
169            error: the raised exception. Note that checkpointing.exceptions.CheckpointNotExist should NOT be handled by this method.
170                    It should be dealt within the saving/retrieving methods.
171        """
172        if self.__on_error == "raise":
173            raise CheckpointFailedError(f"Checkpointing for {context.full_name} failed because of the following error: {str(error)}", error)
174
175        elif self.__on_error == "warn":
176            warn(
177                f"Checkpointing for {context.full_name} failed because of the following error: {str(error)}. "
178                "The function is called to compute the return value.",
179                CheckpointFailedWarning,
180            )
181
182        else:  # self.__on_error == "ignore"
183            pass
184
185    def __timed_safe_save(self, context: FuncCallContext, context_id: ContextId, result: ReturnValue) -> float:
186        """
187        Save the result, tracking the time and capturing any error,
188        dealing with them according to the level specified by `self.__on_error`
189
190        Returns:
191            the time (seconds) it takes to save the result
192        """
193
194        timer = Timer().start()
195        try:
196            self.__cache.save(context_id, result)
197            logger.info(f"Result of {context.qualified_name} with args {context.arguments} saved to cache")
198
199        except Exception as e:
200            self.__handle_unexpected_error(context, e)
201
202        return timer.time

The base class for any decorator checkpoint.

DecoratorCheckpoint(identifier: checkpointing.identifier.func_call.base.FuncCallIdentifierBase, cache: checkpointing.cache.base.CacheBase, on_error: str = None)
22    def __init__(self, identifier: FuncCallIdentifierBase, cache: CacheBase, on_error: str = None) -> None:
23        """
24        Args:
25            identifier: the function call identifier that creates an ID for any function call context
26            cache: the cache that saves and retrieves the return value with a given ID
27            on_error: the behavior when retrieval or saving raises unexpected exceptions
28                (exceptions other than checkpointing.CheckpointNotExist). Possible values are:
29                - `"raise"`, the exception will be raised.
30                - `"warn"`, a warning will be issued to inform that the checkpointing task has failed.
31                    But the user function will be invoked and executed as if it wasn't checkpointed.
32                - `"ignore"`, the exception will be ignored and the user function will be invoked and executed normally.
33                
34                If None, use the global default `checkpoint.on_error`.
35        """
36
37        self.__identifier = identifier
38        """The function call identifier"""
39
40        self.__cache = cache
41        """The cache instance"""
42
43        if on_error is None:
44            on_error = defaults["checkpoint.on_error"]
45
46        self.__on_error: str = on_error
47        """The behavior when identification, saving or retrieval raises unexpected exceptions."""
48
49        self.__definition_frame: FrameType = None
50
51        self.__validate_params()
Args
  • identifier: the function call identifier that creates an ID for any function call context
  • cache: the cache that saves and retrieves the return value with a given ID
  • on_error: the behavior when retrieval or saving raises unexpected exceptions (exceptions other than checkpointing.CheckpointNotExist). Possible values are:

    • "raise", the exception will be raised.
    • "warn", a warning will be issued to inform that the checkpointing task has failed. But the user function will be invoked and executed as if it wasn't checkpointed.
    • "ignore", the exception will be ignored and the user function will be invoked and executed normally.

    If None, use the global default checkpoint.on_error.