store now allows positional arguments to be passed to the path function
This commit is contained in:
parent
fe8ae1fda7
commit
35396c80ca
195
src/store.py
195
src/store.py
|
@ -3,25 +3,24 @@ import inspect
|
|||
import json
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def with_filtered_kwargs(fn):
|
||||
def with_filtered_args(fn):
|
||||
"""
|
||||
Return a version of the input function that ignores instead of errors on
|
||||
unknown keyword arguments.
|
||||
unknown arguments.
|
||||
"""
|
||||
# Enumerate all small_fn's keyword arguments.
|
||||
filter_kwargs = {param.name for param in inspect.signature(fn).parameters.values()
|
||||
if param.kind == param.POSITIONAL_OR_KEYWORD
|
||||
or param.kind == param.KEYWORD_ONLY}
|
||||
arg_filter = {param.name for param in inspect.signature(fn).parameters.values()
|
||||
if param.kind == param.POSITIONAL_OR_KEYWORD
|
||||
or param.kind == param.KEYWORD_ONLY}
|
||||
|
||||
def inner(*args, **kwargs):
|
||||
# Filter any keyword arguments not present in small_fn.
|
||||
kwargs = {k: v for k, v in kwargs.items() if k in filter_kwargs}
|
||||
|
||||
kwargs = {k: v for k, v in kwargs.items() if k in arg_filter}
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
return inner, arg_filter
|
||||
|
||||
|
||||
class load_or_create:
|
||||
|
@ -30,56 +29,160 @@ class load_or_create:
|
|||
self,
|
||||
load=pickle.load,
|
||||
save=pickle.dump,
|
||||
prefix="./",
|
||||
suffix=".pkl",
|
||||
path_fn=lambda hash: hash + ".pkl",
|
||||
hash_len=8,
|
||||
save_args=False,
|
||||
plain_text=False
|
||||
):
|
||||
self.load=with_filtered_kwargs(load)
|
||||
self.save=with_filtered_kwargs(save)
|
||||
self.prefix=prefix
|
||||
self.suffix=suffix
|
||||
self.load=with_filtered_args(load)[0]
|
||||
self.save=with_filtered_args(save)[0]
|
||||
self.path_fn=path_fn
|
||||
self.hash_len=hash_len
|
||||
self.save_args=save_args
|
||||
self.plain_text=plain_text
|
||||
|
||||
def __call__(self, fn):
|
||||
# fn_args = tuple(inspect.signature(fn).parameters.keys())
|
||||
return inner(self, fn)
|
||||
|
||||
|
||||
def inner(*args, **kwargs):
|
||||
# Store the keyword arguments into json and hash it to get the storage path.
|
||||
json_args = (
|
||||
json.dumps(
|
||||
kwargs,
|
||||
default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x)
|
||||
) + "\n"
|
||||
).encode("utf-8")
|
||||
hash = sha256(json_args, usedforsecurity=False).hexdigest()[:self.hash_len]
|
||||
path = Path(self.prefix + hash + self.suffix)
|
||||
class inner(load_or_create):
|
||||
def __init__(self, parent, fn):
|
||||
self.__dict__.update(parent.__dict__)
|
||||
self.fn = fn
|
||||
self.arg_names = [p.name for p in inspect.signature(fn).parameters.values()]
|
||||
|
||||
# Convert all arguments to keyword arguments.
|
||||
# for arg_name, arg in zip(fn_args, args):
|
||||
# kwargs[arg_name] = arg
|
||||
def __call__(self, *args, **kwargs):
|
||||
# Store the keyword arguments into json and hash it to get the storage path.
|
||||
json_args = self.to_json(**kwargs)
|
||||
path = self.path(*args, **kwargs)
|
||||
# path = self.path(hash)
|
||||
|
||||
# If the storage path exists, load the cached object from there.
|
||||
if path.exists():
|
||||
with open(path, "r" if self.plain_text else "rb") as f:
|
||||
if self.save_args:
|
||||
f.readline()
|
||||
return self.load(f, **kwargs)
|
||||
|
||||
# Else, run the function and store its result.
|
||||
out = fn(*args, **kwargs)
|
||||
|
||||
if "/" in self.suffix:
|
||||
path.parent.mkdir(parents=True)
|
||||
with open(path, "w" if self.plain_text else "wb") as f:
|
||||
# If the storage path exists, load the cached object from there.
|
||||
if path.exists():
|
||||
with open(path, "r" if self.plain_text else "rb") as file:
|
||||
# If specified, the first line is a json of the keyword arguments.
|
||||
if self.save_args:
|
||||
f.write(json_args)
|
||||
self.save(f, out, **kwargs)
|
||||
file.readline()
|
||||
return self.load(**self.args_to_kwargs(args, kwargs, path=path, file=file))
|
||||
|
||||
return out
|
||||
# Else, run the function. and store its result.
|
||||
obj = self.fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
if obj is None:
|
||||
return obj
|
||||
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "w" if self.plain_text else "wb") as file:
|
||||
if self.save_args:
|
||||
file.write(json_args)
|
||||
self.save(obj, **self.args_to_kwargs(args, kwargs, path=path, file=file))
|
||||
|
||||
return obj
|
||||
|
||||
def args_to_kwargs(self, args, kwargs, **extra):
|
||||
return extra | kwargs | {self.arg_names[i]: a for i, a in enumerate(args)}
|
||||
|
||||
def to_json(self, **kwargs) -> bytes:
|
||||
"""Serialize all keyword arguments to json.
|
||||
We never serialize positional arguments"""
|
||||
return (
|
||||
json.dumps(
|
||||
kwargs,
|
||||
default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x)
|
||||
) + "\n"
|
||||
).encode("utf-8")
|
||||
|
||||
def make_hash(self, **kwargs):
|
||||
"""Serialize all keyword arguments to json and hash it."""
|
||||
json_args = self.to_json(**kwargs)
|
||||
hash_str = sha256(json_args, usedforsecurity=False).hexdigest()
|
||||
return hash_str[:self.hash_len]
|
||||
|
||||
def path_and_hash(self, *args, **kwargs) -> Tuple[Path, str]:
|
||||
path_fn, arg_filter = with_filtered_args(self.path_fn)
|
||||
|
||||
hash = self.make_hash(
|
||||
# Do not have arguments in the hash that path_fn uses.
|
||||
**{k: a for k, a in kwargs.items() if k not in arg_filter}
|
||||
)
|
||||
|
||||
path_args = self.args_to_kwargs(args, kwargs, hash=hash)
|
||||
|
||||
path = Path(path_fn(**path_args))
|
||||
|
||||
return path, hash
|
||||
|
||||
def hash(self, *args, **kwargs) -> str:
|
||||
"""Hash the keyword arguments.
|
||||
Note that the hash is dependent on the path function of this instance:
|
||||
All arguments of the path function are excluded from the hash."""
|
||||
return self.path_and_hash(*args, **kwargs)[1]
|
||||
|
||||
def path(self, *args, **kwargs) -> Path:
|
||||
return self.path_and_hash(*args, **kwargs)[0]
|
||||
|
||||
def dir(self, *args, **kwargs) -> Path:
|
||||
return self.path(*args, **kwargs).parent
|
||||
|
||||
|
||||
#%% [markdown]
|
||||
# A big possible extension is to hash the returned object.
|
||||
# If we do that, we can do something like this:
|
||||
# ```
|
||||
# fn = load_or_create(orig_fn)
|
||||
# obj = fn(*args)
|
||||
# assert fn.path(*args) == fn.obj_path(obj)
|
||||
# ```
|
||||
#
|
||||
# So, we'd get a _path_ from an object!
|
||||
#
|
||||
# How would this work out?
|
||||
#
|
||||
# first, after loading and/or storing, we save the hash as we store the object.
|
||||
# ```
|
||||
# obj_hash = get_hash_from_flie(path)
|
||||
# self.cache[obj_hash] = [path, kwargs]
|
||||
# ```
|
||||
#
|
||||
# then, in self.obj_path, we fake-save the object using the provided save function.
|
||||
# instead of saving to a file, we save to some class that inherited io.BytesIO:
|
||||
# ```
|
||||
# class HashingWriter(io.BytesIO):
|
||||
# def __init__(self):
|
||||
# super().__init__() # Initialize the BytesIO buffer
|
||||
# self.hash = hashlib.sha256() # Initialize the SHA-256 hash object
|
||||
#
|
||||
# def write(self, b):
|
||||
# return self.hash.update(b) # Update the hash with the data being written
|
||||
#
|
||||
# def get_hash(self):
|
||||
# return self.hash.hexdigest() # Return the hexadecimal digest of the hash
|
||||
# ```
|
||||
#
|
||||
# so self.obj_path would look something like:
|
||||
# ```
|
||||
# def obj_path(self, obj):
|
||||
# with open(HashingWriter(), "wb") as f:
|
||||
# f.write(obj)
|
||||
# if f.get_hash() in self.cache:
|
||||
# return self.cache[hash][0]
|
||||
# ```
|
||||
# In reality, obj_path() can just be part of path(), since path() should expect
|
||||
# kwargs to be present and can safely assume it's being called as obj_path()
|
||||
# otherwise.
|
||||
#
|
||||
# Caveat, this will only work for save functions that use an open file.
|
||||
# Save functions that take a path would not work.
|
||||
#
|
||||
# More practical case:
|
||||
# ```
|
||||
# fn_a = load_or_create()(fn_a)
|
||||
#
|
||||
# def fn_b(obj_a, cfg):
|
||||
# do_something(obj_a, cfg)
|
||||
#
|
||||
# fn_b = load_or_create(
|
||||
# path_fn=lambda obj_a, hash: fn_a.get_path(obj_a) / hash
|
||||
# )(fn_b)
|
||||
# ```
|
||||
# This avoids having to make pass obj_a's config to fn_b in order to determine the path!
|
||||
|
|
Loading…
Reference in New Issue
Block a user