diff --git a/src/store.py b/src/store.py index 825e554..f008ea0 100644 --- a/src/store.py +++ b/src/store.py @@ -3,25 +3,24 @@ import inspect import json from pathlib import Path import pickle +from typing import Tuple -def with_filtered_kwargs(fn): +def with_filtered_args(fn): """ Return a version of the input function that ignores instead of errors on - unknown keyword arguments. + unknown arguments. """ - # Enumerate all small_fn's keyword arguments. - filter_kwargs = {param.name for param in inspect.signature(fn).parameters.values() - if param.kind == param.POSITIONAL_OR_KEYWORD - or param.kind == param.KEYWORD_ONLY} + arg_filter = {param.name for param in inspect.signature(fn).parameters.values() + if param.kind == param.POSITIONAL_OR_KEYWORD + or param.kind == param.KEYWORD_ONLY} def inner(*args, **kwargs): # Filter any keyword arguments not present in small_fn. - kwargs = {k: v for k, v in kwargs.items() if k in filter_kwargs} - + kwargs = {k: v for k, v in kwargs.items() if k in arg_filter} return fn(*args, **kwargs) - return inner + return inner, arg_filter class load_or_create: @@ -30,56 +29,160 @@ class load_or_create: self, load=pickle.load, save=pickle.dump, - prefix="./", - suffix=".pkl", + path_fn=lambda hash: hash + ".pkl", hash_len=8, save_args=False, plain_text=False ): - self.load=with_filtered_kwargs(load) - self.save=with_filtered_kwargs(save) - self.prefix=prefix - self.suffix=suffix + self.load=with_filtered_args(load)[0] + self.save=with_filtered_args(save)[0] + self.path_fn=path_fn self.hash_len=hash_len self.save_args=save_args self.plain_text=plain_text def __call__(self, fn): - # fn_args = tuple(inspect.signature(fn).parameters.keys()) + return inner(self, fn) + - def inner(*args, **kwargs): - # Store the keyword arguments into json and hash it to get the storage path. - json_args = ( - json.dumps( - kwargs, - default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x) - ) + "\n" - ).encode("utf-8") - hash = sha256(json_args, usedforsecurity=False).hexdigest()[:self.hash_len] - path = Path(self.prefix + hash + self.suffix) +class inner(load_or_create): + def __init__(self, parent, fn): + self.__dict__.update(parent.__dict__) + self.fn = fn + self.arg_names = [p.name for p in inspect.signature(fn).parameters.values()] - # Convert all arguments to keyword arguments. - # for arg_name, arg in zip(fn_args, args): - # kwargs[arg_name] = arg + def __call__(self, *args, **kwargs): + # Store the keyword arguments into json and hash it to get the storage path. + json_args = self.to_json(**kwargs) + path = self.path(*args, **kwargs) + # path = self.path(hash) - # If the storage path exists, load the cached object from there. - if path.exists(): - with open(path, "r" if self.plain_text else "rb") as f: - if self.save_args: - f.readline() - return self.load(f, **kwargs) - - # Else, run the function and store its result. - out = fn(*args, **kwargs) - - if "/" in self.suffix: - path.parent.mkdir(parents=True) - with open(path, "w" if self.plain_text else "wb") as f: + # If the storage path exists, load the cached object from there. + if path.exists(): + with open(path, "r" if self.plain_text else "rb") as file: + # If specified, the first line is a json of the keyword arguments. if self.save_args: - f.write(json_args) - self.save(f, out, **kwargs) + file.readline() + return self.load(**self.args_to_kwargs(args, kwargs, path=path, file=file)) - return out + # Else, run the function. and store its result. + obj = self.fn(*args, **kwargs) - return inner + if obj is None: + return obj + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w" if self.plain_text else "wb") as file: + if self.save_args: + file.write(json_args) + self.save(obj, **self.args_to_kwargs(args, kwargs, path=path, file=file)) + + return obj + + def args_to_kwargs(self, args, kwargs, **extra): + return extra | kwargs | {self.arg_names[i]: a for i, a in enumerate(args)} + + def to_json(self, **kwargs) -> bytes: + """Serialize all keyword arguments to json. + We never serialize positional arguments""" + return ( + json.dumps( + kwargs, + default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x) + ) + "\n" + ).encode("utf-8") + + def make_hash(self, **kwargs): + """Serialize all keyword arguments to json and hash it.""" + json_args = self.to_json(**kwargs) + hash_str = sha256(json_args, usedforsecurity=False).hexdigest() + return hash_str[:self.hash_len] + + def path_and_hash(self, *args, **kwargs) -> Tuple[Path, str]: + path_fn, arg_filter = with_filtered_args(self.path_fn) + + hash = self.make_hash( + # Do not have arguments in the hash that path_fn uses. + **{k: a for k, a in kwargs.items() if k not in arg_filter} + ) + + path_args = self.args_to_kwargs(args, kwargs, hash=hash) + + path = Path(path_fn(**path_args)) + + return path, hash + + def hash(self, *args, **kwargs) -> str: + """Hash the keyword arguments. + Note that the hash is dependent on the path function of this instance: + All arguments of the path function are excluded from the hash.""" + return self.path_and_hash(*args, **kwargs)[1] + + def path(self, *args, **kwargs) -> Path: + return self.path_and_hash(*args, **kwargs)[0] + + def dir(self, *args, **kwargs) -> Path: + return self.path(*args, **kwargs).parent + + +#%% [markdown] +# A big possible extension is to hash the returned object. +# If we do that, we can do something like this: +# ``` +# fn = load_or_create(orig_fn) +# obj = fn(*args) +# assert fn.path(*args) == fn.obj_path(obj) +# ``` +# +# So, we'd get a _path_ from an object! +# +# How would this work out? +# +# first, after loading and/or storing, we save the hash as we store the object. +# ``` +# obj_hash = get_hash_from_flie(path) +# self.cache[obj_hash] = [path, kwargs] +# ``` +# +# then, in self.obj_path, we fake-save the object using the provided save function. +# instead of saving to a file, we save to some class that inherited io.BytesIO: +# ``` +# class HashingWriter(io.BytesIO): +# def __init__(self): +# super().__init__() # Initialize the BytesIO buffer +# self.hash = hashlib.sha256() # Initialize the SHA-256 hash object +# +# def write(self, b): +# return self.hash.update(b) # Update the hash with the data being written +# +# def get_hash(self): +# return self.hash.hexdigest() # Return the hexadecimal digest of the hash +# ``` +# +# so self.obj_path would look something like: +# ``` +# def obj_path(self, obj): +# with open(HashingWriter(), "wb") as f: +# f.write(obj) +# if f.get_hash() in self.cache: +# return self.cache[hash][0] +# ``` +# In reality, obj_path() can just be part of path(), since path() should expect +# kwargs to be present and can safely assume it's being called as obj_path() +# otherwise. +# +# Caveat, this will only work for save functions that use an open file. +# Save functions that take a path would not work. +# +# More practical case: +# ``` +# fn_a = load_or_create()(fn_a) +# +# def fn_b(obj_a, cfg): +# do_something(obj_a, cfg) +# +# fn_b = load_or_create( +# path_fn=lambda obj_a, hash: fn_a.get_path(obj_a) / hash +# )(fn_b) +# ``` +# This avoids having to make pass obj_a's config to fn_b in order to determine the path!