store now allows positional arguments to be passed to the path function
This commit is contained in:
parent
fe8ae1fda7
commit
35396c80ca
195
src/store.py
195
src/store.py
|
@ -3,25 +3,24 @@ import inspect
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pickle
|
import pickle
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
def with_filtered_kwargs(fn):
|
def with_filtered_args(fn):
|
||||||
"""
|
"""
|
||||||
Return a version of the input function that ignores instead of errors on
|
Return a version of the input function that ignores instead of errors on
|
||||||
unknown keyword arguments.
|
unknown arguments.
|
||||||
"""
|
"""
|
||||||
# Enumerate all small_fn's keyword arguments.
|
arg_filter = {param.name for param in inspect.signature(fn).parameters.values()
|
||||||
filter_kwargs = {param.name for param in inspect.signature(fn).parameters.values()
|
if param.kind == param.POSITIONAL_OR_KEYWORD
|
||||||
if param.kind == param.POSITIONAL_OR_KEYWORD
|
or param.kind == param.KEYWORD_ONLY}
|
||||||
or param.kind == param.KEYWORD_ONLY}
|
|
||||||
|
|
||||||
def inner(*args, **kwargs):
|
def inner(*args, **kwargs):
|
||||||
# Filter any keyword arguments not present in fn.
|
# Filter any keyword arguments not present in fn.
|
||||||
kwargs = {k: v for k, v in kwargs.items() if k in filter_kwargs}
|
kwargs = {k: v for k, v in kwargs.items() if k in arg_filter}
|
||||||
|
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
return inner
|
return inner, arg_filter
|
||||||
|
|
||||||
|
|
||||||
class load_or_create:
|
class load_or_create:
|
||||||
|
@ -30,56 +29,160 @@ class load_or_create:
|
||||||
self,
|
self,
|
||||||
load=pickle.load,
|
load=pickle.load,
|
||||||
save=pickle.dump,
|
save=pickle.dump,
|
||||||
prefix="./",
|
path_fn=lambda hash: hash + ".pkl",
|
||||||
suffix=".pkl",
|
|
||||||
hash_len=8,
|
hash_len=8,
|
||||||
save_args=False,
|
save_args=False,
|
||||||
plain_text=False
|
plain_text=False
|
||||||
):
|
):
|
||||||
self.load=with_filtered_kwargs(load)
|
self.load=with_filtered_args(load)[0]
|
||||||
self.save=with_filtered_kwargs(save)
|
self.save=with_filtered_args(save)[0]
|
||||||
self.prefix=prefix
|
self.path_fn=path_fn
|
||||||
self.suffix=suffix
|
|
||||||
self.hash_len=hash_len
|
self.hash_len=hash_len
|
||||||
self.save_args=save_args
|
self.save_args=save_args
|
||||||
self.plain_text=plain_text
|
self.plain_text=plain_text
|
||||||
|
|
||||||
def __call__(self, fn):
|
def __call__(self, fn):
|
||||||
# fn_args = tuple(inspect.signature(fn).parameters.keys())
|
return inner(self, fn)
|
||||||
|
|
||||||
|
|
||||||
def inner(*args, **kwargs):
|
class inner(load_or_create):
|
||||||
# Store the keyword arguments into json and hash it to get the storage path.
|
def __init__(self, parent, fn):
|
||||||
json_args = (
|
self.__dict__.update(parent.__dict__)
|
||||||
json.dumps(
|
self.fn = fn
|
||||||
kwargs,
|
self.arg_names = [p.name for p in inspect.signature(fn).parameters.values()]
|
||||||
default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x)
|
|
||||||
) + "\n"
|
|
||||||
).encode("utf-8")
|
|
||||||
hash = sha256(json_args, usedforsecurity=False).hexdigest()[:self.hash_len]
|
|
||||||
path = Path(self.prefix + hash + self.suffix)
|
|
||||||
|
|
||||||
# Convert all arguments to keyword arguments.
|
def __call__(self, *args, **kwargs):
|
||||||
# for arg_name, arg in zip(fn_args, args):
|
# Store the keyword arguments into json and hash it to get the storage path.
|
||||||
# kwargs[arg_name] = arg
|
json_args = self.to_json(**kwargs)
|
||||||
|
path = self.path(*args, **kwargs)
|
||||||
|
# path = self.path(hash)
|
||||||
|
|
||||||
# If the storage path exists, load the cached object from there.
|
# If the storage path exists, load the cached object from there.
|
||||||
if path.exists():
|
if path.exists():
|
||||||
with open(path, "r" if self.plain_text else "rb") as f:
|
with open(path, "r" if self.plain_text else "rb") as file:
|
||||||
if self.save_args:
|
# If specified, the first line is a json of the keyword arguments.
|
||||||
f.readline()
|
|
||||||
return self.load(f, **kwargs)
|
|
||||||
|
|
||||||
# Else, run the function and store its result.
|
|
||||||
out = fn(*args, **kwargs)
|
|
||||||
|
|
||||||
if "/" in self.suffix:
|
|
||||||
path.parent.mkdir(parents=True)
|
|
||||||
with open(path, "w" if self.plain_text else "wb") as f:
|
|
||||||
if self.save_args:
|
if self.save_args:
|
||||||
f.write(json_args)
|
file.readline()
|
||||||
self.save(f, out, **kwargs)
|
return self.load(**self.args_to_kwargs(args, kwargs, path=path, file=file))
|
||||||
|
|
||||||
return out
|
# Else, run the function and store its result.
|
||||||
|
obj = self.fn(*args, **kwargs)
|
||||||
|
|
||||||
return inner
|
if obj is None:
|
||||||
|
return obj
|
||||||
|
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(path, "w" if self.plain_text else "wb") as file:
|
||||||
|
if self.save_args:
|
||||||
|
file.write(json_args)
|
||||||
|
self.save(obj, **self.args_to_kwargs(args, kwargs, path=path, file=file))
|
||||||
|
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def args_to_kwargs(self, args, kwargs, **extra):
|
||||||
|
return extra | kwargs | {self.arg_names[i]: a for i, a in enumerate(args)}
|
||||||
|
|
||||||
|
def to_json(self, **kwargs) -> bytes:
|
||||||
|
"""Serialize all keyword arguments to json.
|
||||||
|
We never serialize positional arguments"""
|
||||||
|
return (
|
||||||
|
json.dumps(
|
||||||
|
kwargs,
|
||||||
|
default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x)
|
||||||
|
) + "\n"
|
||||||
|
).encode("utf-8")
|
||||||
|
|
||||||
|
def make_hash(self, **kwargs):
|
||||||
|
"""Serialize all keyword arguments to json and hash it."""
|
||||||
|
json_args = self.to_json(**kwargs)
|
||||||
|
hash_str = sha256(json_args, usedforsecurity=False).hexdigest()
|
||||||
|
return hash_str[:self.hash_len]
|
||||||
|
|
||||||
|
def path_and_hash(self, *args, **kwargs) -> Tuple[Path, str]:
|
||||||
|
path_fn, arg_filter = with_filtered_args(self.path_fn)
|
||||||
|
|
||||||
|
hash = self.make_hash(
|
||||||
|
# Do not have arguments in the hash that path_fn uses.
|
||||||
|
**{k: a for k, a in kwargs.items() if k not in arg_filter}
|
||||||
|
)
|
||||||
|
|
||||||
|
path_args = self.args_to_kwargs(args, kwargs, hash=hash)
|
||||||
|
|
||||||
|
path = Path(path_fn(**path_args))
|
||||||
|
|
||||||
|
return path, hash
|
||||||
|
|
||||||
|
def hash(self, *args, **kwargs) -> str:
|
||||||
|
"""Hash the keyword arguments.
|
||||||
|
Note that the hash is dependent on the path function of this instance:
|
||||||
|
All arguments of the path function are excluded from the hash."""
|
||||||
|
return self.path_and_hash(*args, **kwargs)[1]
|
||||||
|
|
||||||
|
def path(self, *args, **kwargs) -> Path:
|
||||||
|
return self.path_and_hash(*args, **kwargs)[0]
|
||||||
|
|
||||||
|
def dir(self, *args, **kwargs) -> Path:
|
||||||
|
return self.path(*args, **kwargs).parent
|
||||||
|
|
||||||
|
|
||||||
|
#%% [markdown]
|
||||||
|
# A big possible extension is to hash the returned object.
|
||||||
|
# If we do that, we can do something like this:
|
||||||
|
# ```
|
||||||
|
# fn = load_or_create(orig_fn)
|
||||||
|
# obj = fn(*args)
|
||||||
|
# assert fn.path(*args) == fn.obj_path(obj)
|
||||||
|
# ```
|
||||||
|
#
|
||||||
|
# So, we'd get a _path_ from an object!
|
||||||
|
#
|
||||||
|
# How would this work out?
|
||||||
|
#
|
||||||
|
# first, after loading and/or storing, we save the hash as we store the object.
|
||||||
|
# ```
|
||||||
|
# obj_hash = get_hash_from_file(path)
|
||||||
|
# self.cache[obj_hash] = [path, kwargs]
|
||||||
|
# ```
|
||||||
|
#
|
||||||
|
# then, in self.obj_path, we fake-save the object using the provided save function.
|
||||||
|
# instead of saving to a file, we save to some class that inherits from io.BytesIO:
|
||||||
|
# ```
|
||||||
|
# class HashingWriter(io.BytesIO):
|
||||||
|
# def __init__(self):
|
||||||
|
# super().__init__() # Initialize the BytesIO buffer
|
||||||
|
# self.hash = hashlib.sha256() # Initialize the SHA-256 hash object
|
||||||
|
#
|
||||||
|
# def write(self, b):
|
||||||
|
# return self.hash.update(b) # Update the hash with the data being written
|
||||||
|
#
|
||||||
|
# def get_hash(self):
|
||||||
|
# return self.hash.hexdigest() # Return the hexadecimal digest of the hash
|
||||||
|
# ```
|
||||||
|
#
|
||||||
|
# so self.obj_path would look something like:
|
||||||
|
# ```
|
||||||
|
# def obj_path(self, obj):
|
||||||
|
# with open(HashingWriter(), "wb") as f:
|
||||||
|
# f.write(obj)
|
||||||
|
# if f.get_hash() in self.cache:
|
||||||
|
# return self.cache[hash][0]
|
||||||
|
# ```
|
||||||
|
# In reality, obj_path() can just be part of path(), since path() should expect
|
||||||
|
# kwargs to be present and can safely assume it's being called as obj_path()
|
||||||
|
# otherwise.
|
||||||
|
#
|
||||||
|
# Caveat, this will only work for save functions that use an open file.
|
||||||
|
# Save functions that take a path would not work.
|
||||||
|
#
|
||||||
|
# More practical case:
|
||||||
|
# ```
|
||||||
|
# fn_a = load_or_create()(fn_a)
|
||||||
|
#
|
||||||
|
# def fn_b(obj_a, cfg):
|
||||||
|
# do_something(obj_a, cfg)
|
||||||
|
#
|
||||||
|
# fn_b = load_or_create(
|
||||||
|
# path_fn=lambda obj_a, hash: fn_a.get_path(obj_a) / hash
|
||||||
|
# )(fn_b)
|
||||||
|
# ```
|
||||||
|
# This avoids having to pass obj_a's config to fn_b in order to determine the path!
|
||||||
|
|
Loading…
Reference in New Issue
Block a user