store now allows positional arguments to be passed to the path function

This commit is contained in:
JJJHolscher 2024-08-04 21:25:09 +02:00
parent fe8ae1fda7
commit 35396c80ca

View File

@ -3,25 +3,24 @@ import inspect
import json import json
from pathlib import Path from pathlib import Path
import pickle import pickle
from typing import Tuple
def with_filtered_kwargs(fn): def with_filtered_args(fn):
""" """
Return a version of the input function that ignores instead of errors on Return a version of the input function that ignores instead of errors on
unknown keyword arguments. unknown arguments.
""" """
# Enumerate all small_fn's keyword arguments. arg_filter = {param.name for param in inspect.signature(fn).parameters.values()
filter_kwargs = {param.name for param in inspect.signature(fn).parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD
if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.KEYWORD_ONLY}
or param.kind == param.KEYWORD_ONLY}
def inner(*args, **kwargs): def inner(*args, **kwargs):
# Filter any keyword arguments not present in small_fn. # Filter any keyword arguments not present in small_fn.
kwargs = {k: v for k, v in kwargs.items() if k in filter_kwargs} kwargs = {k: v for k, v in kwargs.items() if k in arg_filter}
return fn(*args, **kwargs) return fn(*args, **kwargs)
return inner return inner, arg_filter
class load_or_create: class load_or_create:
@ -30,56 +29,160 @@ class load_or_create:
self, self,
load=pickle.load, load=pickle.load,
save=pickle.dump, save=pickle.dump,
prefix="./", path_fn=lambda hash: hash + ".pkl",
suffix=".pkl",
hash_len=8, hash_len=8,
save_args=False, save_args=False,
plain_text=False plain_text=False
): ):
self.load=with_filtered_kwargs(load) self.load=with_filtered_args(load)[0]
self.save=with_filtered_kwargs(save) self.save=with_filtered_args(save)[0]
self.prefix=prefix self.path_fn=path_fn
self.suffix=suffix
self.hash_len=hash_len self.hash_len=hash_len
self.save_args=save_args self.save_args=save_args
self.plain_text=plain_text self.plain_text=plain_text
def __call__(self, fn): def __call__(self, fn):
# fn_args = tuple(inspect.signature(fn).parameters.keys()) return inner(self, fn)
def inner(*args, **kwargs): class inner(load_or_create):
# Store the keyword arguments into json and hash it to get the storage path. def __init__(self, parent, fn):
json_args = ( self.__dict__.update(parent.__dict__)
json.dumps( self.fn = fn
kwargs, self.arg_names = [p.name for p in inspect.signature(fn).parameters.values()]
default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x)
) + "\n"
).encode("utf-8")
hash = sha256(json_args, usedforsecurity=False).hexdigest()[:self.hash_len]
path = Path(self.prefix + hash + self.suffix)
# Convert all arguments to keyword arguments. def __call__(self, *args, **kwargs):
# for arg_name, arg in zip(fn_args, args): # Store the keyword arguments into json and hash it to get the storage path.
# kwargs[arg_name] = arg json_args = self.to_json(**kwargs)
path = self.path(*args, **kwargs)
# path = self.path(hash)
# If the storage path exists, load the cached object from there. # If the storage path exists, load the cached object from there.
if path.exists(): if path.exists():
with open(path, "r" if self.plain_text else "rb") as f: with open(path, "r" if self.plain_text else "rb") as file:
if self.save_args: # If specified, the first line is a json of the keyword arguments.
f.readline()
return self.load(f, **kwargs)
# Else, run the function and store its result.
out = fn(*args, **kwargs)
if "/" in self.suffix:
path.parent.mkdir(parents=True)
with open(path, "w" if self.plain_text else "wb") as f:
if self.save_args: if self.save_args:
f.write(json_args) file.readline()
self.save(f, out, **kwargs) return self.load(**self.args_to_kwargs(args, kwargs, path=path, file=file))
return out # Else, run the function and store its result.
obj = self.fn(*args, **kwargs)
return inner if obj is None:
return obj
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w" if self.plain_text else "wb") as file:
if self.save_args:
file.write(json_args)
self.save(obj, **self.args_to_kwargs(args, kwargs, path=path, file=file))
return obj
def args_to_kwargs(self, args, kwargs, **extra):
    """Merge positional and keyword arguments into one keyword dictionary.

    Positional values are paired, in order, with the names stored in
    ``self.arg_names``.

    Args:
        args: Positional values of the call.
        kwargs: Keyword arguments of the call.
        **extra: Additional entries to include in the result.

    Returns:
        A new dict. On a key clash the precedence is
        ``extra`` < ``kwargs`` < named positional values.
    """
    # zip() stops at the shorter sequence, so positional values beyond the
    # known names (e.g. those collected by a ``*args`` parameter) are left
    # out instead of raising IndexError.
    named_positionals = dict(zip(self.arg_names, args))
    return extra | kwargs | named_positionals
def to_json(self, **kwargs) -> bytes:
    """Encode the keyword arguments as a single UTF-8 JSON line.

    Positional arguments are never serialized. A value the ``json`` module
    cannot encode natively is replaced by ``vars(value)`` when it has a
    ``__dict__`` and by ``str(value)`` otherwise.

    Returns:
        The JSON text followed by a newline, encoded as UTF-8 bytes.
    """
    def fallback(value):
        if hasattr(value, "__dict__"):
            return vars(value)
        return str(value)

    # NOTE(review): keys keep the order in which they were passed (no
    # sort_keys), so the output depends on keyword-argument order.
    text = json.dumps(kwargs, default=fallback)
    return f"{text}\n".encode("utf-8")
def make_hash(self, **kwargs):
    """Return a truncated SHA-256 hex digest of the keyword arguments.

    The keyword arguments are serialized with ``self.to_json`` and the
    resulting hex digest is cut to ``self.hash_len`` characters.
    """
    digest = sha256(self.to_json(**kwargs), usedforsecurity=False)
    return digest.hexdigest()[:self.hash_len]
def path_and_hash(self, *args, **kwargs) -> Tuple[Path, str]:
    """Compute the storage path and the argument hash for one call.

    Keyword arguments whose names are parameters of ``self.path_fn`` are
    left out of the hash. The path function is then called with the merged
    call arguments plus the hash under the keyword ``hash``.

    Returns:
        ``(path, hash)``: the path function's result wrapped in ``Path``,
        and the hash string.
    """
    filtered_path_fn, path_fn_params = with_filtered_args(self.path_fn)
    # Do not have arguments in the hash that path_fn uses.
    hashed_kwargs = {
        name: value
        for name, value in kwargs.items()
        if name not in path_fn_params
    }
    digest = self.make_hash(**hashed_kwargs)
    path_kwargs = self.args_to_kwargs(args, kwargs, hash=digest)
    return Path(filtered_path_fn(**path_kwargs)), digest
def hash(self, *args, **kwargs) -> str:
    """Return the hash of the call's keyword arguments.

    The hash depends on this instance's path function: every argument the
    path function accepts is excluded from the hash.
    """
    _, digest = self.path_and_hash(*args, **kwargs)
    return digest
def path(self, *args, **kwargs) -> Path:
    """Return the storage path computed for the given call arguments."""
    location, _ = self.path_and_hash(*args, **kwargs)
    return location
def dir(self, *args, **kwargs) -> Path:
    """Return the parent directory of the storage path for these arguments."""
    storage_path = self.path(*args, **kwargs)
    return storage_path.parent
#%% [markdown]
# A big possible extension is to hash the returned object.
# If we do that, we can do something like this:
# ```
# fn = load_or_create(orig_fn)
# obj = fn(*args)
# assert fn.path(*args) == fn.obj_path(obj)
# ```
#
# So, we'd get a _path_ from an object!
#
# How would this work out?
#
# First, after loading and/or storing, we save the hash as we store the object.
# ```
# obj_hash = get_hash_from_file(path)
# self.cache[obj_hash] = [path, kwargs]
# ```
#
# Then, in self.obj_path, we fake-save the object using the provided save function.
# Instead of saving to a file, we save to an instance of a class that inherits from io.BytesIO:
# ```
# class HashingWriter(io.BytesIO):
# def __init__(self):
# super().__init__() # Initialize the BytesIO buffer
# self.hash = hashlib.sha256() # Initialize the SHA-256 hash object
#
# def write(self, b):
# return self.hash.update(b) # Update the hash with the data being written
#
# def get_hash(self):
# return self.hash.hexdigest() # Return the hexadecimal digest of the hash
# ```
#
# so self.obj_path would look something like:
# ```
# def obj_path(self, obj):
# with open(HashingWriter(), "wb") as f:
# f.write(obj)
# if f.get_hash() in self.cache:
# return self.cache[hash][0]
# ```
# In reality, obj_path() can just be part of path(), since path() should expect
# kwargs to be present and can safely assume it's being called as obj_path()
# otherwise.
#
# Caveat, this will only work for save functions that use an open file.
# Save functions that take a path would not work.
#
# More practical case:
# ```
# fn_a = load_or_create()(fn_a)
#
# def fn_b(obj_a, cfg):
# do_something(obj_a, cfg)
#
# fn_b = load_or_create(
#     path_fn=lambda obj_a, hash: fn_a.obj_path(obj_a) / hash
# )(fn_b)
# ```
# This avoids having to pass obj_a's config to fn_b in order to determine the path!