checkpointing.hash.generic

 1import inspect
 2from types import FunctionType, GeneratorType, ModuleType
 3from typing import Any
 4from warnings import warn
 5
 6import dill
 7from checkpointing.exceptions import HashFailedWarning
 8from checkpointing.hash.stream import HashStream
 9from checkpointing.refactor.funcdef import FunctionDefinitionUnifier
10from checkpointing.util import pickle
11from checkpointing.logging import logger
12
13
def hash_with_dill(stream: HashStream, obj: Any, pickle_protocol: int) -> None:
    """
    Serialize `obj` with dill and write the serialized bytes into `stream`.

    The requested `pickle_protocol` is capped at `dill.HIGHEST_PROTOCOL`.
    dill is invoked with `byref=True` and `recurse=False`.
    """
    # Never ask dill for a protocol newer than the one it supports.
    protocol = min(pickle_protocol, dill.HIGHEST_PROTOCOL)
    dill.dump(obj, stream, protocol=protocol, byref=True, recurse=False)
22
23
def hash_with_pickle(stream: HashStream, obj: Any, pickle_protocol: int) -> None:
    """
    Pickle `obj` with the given protocol and write the pickled bytes into `stream`.
    """
    pickle.dump(
        obj,
        stream,
        protocol=pickle_protocol,
    )
26
27
def hash_string(stream: HashStream, s: str) -> None:
    """
    Write the UTF-8 encoding of `s` into `stream`.
    """
    stream.write(s.encode("utf-8"))
31
32
def hash_with_qualname(stream: HashStream, type_: str, obj: Any) -> None:
    """
    Hash `obj` by its qualified name only.

    The string `<type_>::<obj.__qualname__>` is passed to `hash_string`,
    so `obj` must expose a `__qualname__` attribute.
    """
    tagged_name = f"{type_}::{obj.__qualname__}"
    hash_string(stream, tagged_name)
35
36
def hash_generic(stream: HashStream, obj: Any, pickle_protocol: int) -> None:
    """
    Hash an arbitrary object by writing a byte representation of it into `stream`.

    Strategies are tried in this order:

    1. Generators are hashed by their qualified name only (`generator::<qualname>`).
    2. `hash_with_pickle`, then `hash_with_dill`; the first one that completes
       without raising is used.
    3. If both serializers fail, a `HashFailedWarning` is issued and `repr(obj)`
       is hashed instead.

    Note: a serializer that fails part-way may already have written some bytes
    into `stream` before the next strategy is tried.
    """

    for test, type_ in [
        (inspect.isgenerator, "generator"),
    ]:
        if test(obj):
            hash_with_qualname(stream, type_, obj)
            return

    for hasher in [hash_with_pickle, hash_with_dill]:
        try:
            hasher(stream, obj, pickle_protocol)
        except Exception:
            # A failing serializer is expected here: fall through to the next strategy.
            # Only `Exception` is caught so that KeyboardInterrupt and SystemExit
            # still propagate instead of being silently swallowed.
            continue
        return

    warn(
        f"No generic hasher found for object: {str(obj)} of type: {type(obj)}, using its __repr__ as hash value. "
        "This could lead to incorrect results",
        category=HashFailedWarning,
    )
    hash_string(stream, repr(obj))
def hash_with_dill(stream: HashStream, obj: Any, pickle_protocol: int) -> None:
    """
    Serialize `obj` with dill and write the serialized bytes into `stream`.

    The requested `pickle_protocol` is capped at `dill.HIGHEST_PROTOCOL`.
    dill is invoked with `byref=True` and `recurse=False`.
    """
    # Never ask dill for a protocol newer than the one it supports.
    protocol = min(pickle_protocol, dill.HIGHEST_PROTOCOL)
    dill.dump(obj, stream, protocol=protocol, byref=True, recurse=False)
def hash_with_pickle(stream: HashStream, obj: Any, pickle_protocol: int) -> None:
    """
    Pickle `obj` with the given protocol and write the pickled bytes into `stream`.
    """
    pickle.dump(
        obj,
        stream,
        protocol=pickle_protocol,
    )
def hash_string(stream: HashStream, s: str) -> None:
    """
    Write the UTF-8 encoding of `s` into `stream`.
    """
    stream.write(s.encode("utf-8"))
def hash_with_qualname(stream: HashStream, type_: str, obj: Any) -> None:
    """
    Hash `obj` by its qualified name only.

    The string `<type_>::<obj.__qualname__>` is passed to `hash_string`,
    so `obj` must expose a `__qualname__` attribute.
    """
    tagged_name = f"{type_}::{obj.__qualname__}"
    hash_string(stream, tagged_name)
def hash_generic(stream: HashStream, obj: Any, pickle_protocol: int) -> None:
    """
    Hash an arbitrary object by writing a byte representation of it into `stream`.

    Strategies are tried in this order:

    1. Generators are hashed by their qualified name only (`generator::<qualname>`).
    2. `hash_with_pickle`, then `hash_with_dill`; the first one that completes
       without raising is used.
    3. If both serializers fail, a `HashFailedWarning` is issued and `repr(obj)`
       is hashed instead.

    Note: a serializer that fails part-way may already have written some bytes
    into `stream` before the next strategy is tried.
    """

    for test, type_ in [
        (inspect.isgenerator, "generator"),
    ]:
        if test(obj):
            hash_with_qualname(stream, type_, obj)
            return

    for hasher in [hash_with_pickle, hash_with_dill]:
        try:
            hasher(stream, obj, pickle_protocol)
        except Exception:
            # A failing serializer is expected here: fall through to the next strategy.
            # Only `Exception` is caught so that KeyboardInterrupt and SystemExit
            # still propagate instead of being silently swallowed.
            continue
        return

    warn(
        f"No generic hasher found for object: {str(obj)} of type: {type(obj)}, using its __repr__ as hash value. "
        "This could lead to incorrect results",
        category=HashFailedWarning,
    )
    hash_string(stream, repr(obj))