path_from_obj which can accept the obj or the LoadOrCreateCFG of the obj
This commit is contained in:
parent
746b2cf6b0
commit
b8c72498ba
174
src/store.py
174
src/store.py
|
@ -1,5 +1,6 @@
|
|||
from hashlib import sha256
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
|
@ -23,6 +24,30 @@ def with_filtered_args(fn):
|
|||
return inner, arg_filter
|
||||
|
||||
|
||||
class HashingWriter(io.BytesIO):
|
||||
def __init__(self):
|
||||
super().__init__() # Initialize the BytesIO buffer
|
||||
self.hash = sha256() # Initialize the SHA-256 hash object
|
||||
|
||||
def write(self, b):
|
||||
self.hash.update(b) # Update the hash with the data being written
|
||||
return 0
|
||||
# return super().write(b) # Write the data to the BytesIO buffer
|
||||
|
||||
def writelines(self, lines):
|
||||
for line in lines:
|
||||
self.write(line)
|
||||
|
||||
def get_hash(self):
|
||||
return self.hash.hexdigest() # Return the hexadecimal digest of the hash
|
||||
|
||||
|
||||
class LoadOrCreateCFG:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
class load_or_create:
|
||||
|
||||
def __init__(
|
||||
|
@ -32,6 +57,7 @@ class load_or_create:
|
|||
path_fn=lambda hash: hash + ".pkl",
|
||||
hash_len=8,
|
||||
save_args=False,
|
||||
save_json=False,
|
||||
plain_text=False
|
||||
):
|
||||
self.load=with_filtered_args(load)[0]
|
||||
|
@ -41,6 +67,7 @@ class load_or_create:
|
|||
self.path_fn=path_fn
|
||||
self.hash_len=hash_len
|
||||
self.save_args=save_args
|
||||
self.save_json=save_json
|
||||
self.plain_text=plain_text
|
||||
|
||||
def __call__(self, fn):
|
||||
|
@ -52,18 +79,24 @@ class inner(load_or_create):
|
|||
self.__dict__.update(parent.__dict__)
|
||||
self.fn = fn
|
||||
self.arg_names = [p.name for p in inspect.signature(fn).parameters.values()]
|
||||
self.obj_to_args = dict()
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# Store the keyword arguments into json and hash it to get the storage path.
|
||||
path = self.path(*args, **kwargs)
|
||||
merged_args = self.args_to_kwargs(args, kwargs, path=path)
|
||||
|
||||
obj = self.load_wrapper(**self.args_to_kwargs(args, kwargs, path=path))
|
||||
if obj is not None: return obj
|
||||
obj = self.load_wrapper(**merged_args)
|
||||
if obj is not None:
|
||||
if "file" in self.save_arg_names: self.hash_obj({"path": path} | kwargs)
|
||||
return obj
|
||||
|
||||
obj = self.fn(*args, **kwargs)
|
||||
if obj is None: return obj
|
||||
|
||||
self.save_wrapper(obj, **self.args_to_kwargs(args, kwargs, path=path))
|
||||
self.save_wrapper(obj, *args, **{"path": path} | kwargs)
|
||||
if "file" in self.save_arg_names: self.hash_obj({"path": path} | kwargs)
|
||||
if self.save_json: path.with_suffix(".json").write_bytes(self.to_json(**kwargs))
|
||||
|
||||
return obj
|
||||
|
||||
|
@ -94,24 +127,28 @@ class inner(load_or_create):
|
|||
return self.load(**{"file": file} | load_args)
|
||||
|
||||
|
||||
def save_wrapper(self, obj, path, **save_args):
|
||||
def save_wrapper(self, obj, *args, **kwargs):
|
||||
"""Only open a file if the save function requests you to.
|
||||
"""
|
||||
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
||||
if "file" not in self.save_arg_names:
|
||||
return self.save(obj, **{"path": path} | save_args)
|
||||
merged_args = self.args_to_kwargs(args, kwargs)
|
||||
|
||||
if "file" in save_args:
|
||||
file = save_args["file"]
|
||||
path = Path(merged_args["path"])
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if "file" not in self.save_arg_names:
|
||||
return self.save(obj, **merged_args)
|
||||
|
||||
if "file" in merged_args:
|
||||
file = merged_args["file"]
|
||||
else:
|
||||
file = open(path, "w" if self.plain_text else "wb")
|
||||
|
||||
if self.save_args:
|
||||
file.write(self.to_json(**save_args))
|
||||
file.write(self.to_json(**kwargs))
|
||||
|
||||
self.save(obj, **{"file": file, "path": path} | save_args)
|
||||
self.save(obj, **{"file": file} | merged_args)
|
||||
|
||||
if "file" not in save_args:
|
||||
if "file" not in merged_args:
|
||||
file.close()
|
||||
|
||||
|
||||
|
@ -125,7 +162,10 @@ class inner(load_or_create):
|
|||
) + "\n"
|
||||
).encode("utf-8")
|
||||
|
||||
def make_hash(self, **kwargs):
|
||||
# Functions to do with hashing keyword arguments and determining an
|
||||
# object's path from its arguments.
|
||||
|
||||
def hash_kwargs(self, **kwargs):
|
||||
"""Serialize all keyword arguments to json and hash it."""
|
||||
json_args = self.to_json(**kwargs)
|
||||
hash_str = sha256(json_args, usedforsecurity=False).hexdigest()
|
||||
|
@ -134,7 +174,7 @@ class inner(load_or_create):
|
|||
def path_and_hash(self, *args, **kwargs) -> Tuple[Path, str]:
|
||||
path_fn, arg_filter = with_filtered_args(self.path_fn)
|
||||
|
||||
hash = self.make_hash(
|
||||
hash = self.hash_kwargs(
|
||||
# Do not have arguments in the hash that path_fn uses.
|
||||
**{k: a for k, a in kwargs.items() if k not in arg_filter}
|
||||
)
|
||||
|
@ -157,65 +197,49 @@ class inner(load_or_create):
|
|||
def dir(self, *args, **kwargs) -> Path:
|
||||
return self.path(*args, **kwargs).parent
|
||||
|
||||
# Functions to do with inferring an object's arguments by keeping track of
|
||||
# its hash.
|
||||
|
||||
def hash_obj(self, kwargs):
|
||||
"""After saving or loading the object, we get its hash from the
|
||||
storage file.
|
||||
This hash we can later use, to infer the arguments that created this
|
||||
object.
|
||||
"""
|
||||
path = Path(kwargs["path"])
|
||||
if not path.exists():
|
||||
return
|
||||
with open(path, "rb") as file:
|
||||
if self.save_args:
|
||||
file.readline()
|
||||
hash = sha256(file.read()).hexdigest()
|
||||
self.obj_to_args[hash] = kwargs
|
||||
|
||||
def args_from_obj(self, obj, *args, **kwargs):
|
||||
"""Hash the object using the user-provided self.save.
|
||||
Then retrieve its arguments by looking up hashes of previously loaded
|
||||
or saved objects.
|
||||
Also ssee self.hash_obj
|
||||
"""
|
||||
file = HashingWriter()
|
||||
self.save(obj, **self.args_to_kwargs(args, kwargs), file=file)
|
||||
hash = file.get_hash()
|
||||
return self.obj_to_args[hash]
|
||||
|
||||
def path_from_obj(self, obj, *args, **kwargs):
|
||||
if isinstance(obj, LoadOrCreateCFG):
|
||||
return self.path(*obj.args, **obj.kwargs)
|
||||
return self.args_from_obj(obj, *args, **kwargs)["path"]
|
||||
|
||||
def dir_from_obj(self, obj, *args, **kwargs):
|
||||
return self.path_from_obj(obj, *args, **kwargs).parent
|
||||
|
||||
def cfg(self, *args, **kwargs):
|
||||
"""If you don't care about some object, but only about it's path,
|
||||
but you still need to pass an object to some other function in order
|
||||
to get it's path, you can pass a LoadOrCreateCFG instead, saving you
|
||||
from loading or creating that object..
|
||||
"""
|
||||
return LoadOrCreateCFG(*args, **kwargs)
|
||||
|
||||
|
||||
#%% [markdown]
|
||||
# A big possible extension is to hash the returned object.
|
||||
# If we do that, we can do something like this:
|
||||
# ```
|
||||
# fn = load_or_create(orig_fn)
|
||||
# obj = fn(*args)
|
||||
# assert fn.path(*args) == fn.obj_path(obj)
|
||||
# ```
|
||||
#
|
||||
# So, we'd get a _path_ from an object!
|
||||
#
|
||||
# How would this work out?
|
||||
#
|
||||
# first, after loading and/or storing, we save the hash as we store the object.
|
||||
# ```
|
||||
# obj_hash = get_hash_from_flie(path)
|
||||
# self.cache[obj_hash] = [path, kwargs]
|
||||
# ```
|
||||
#
|
||||
# then, in self.obj_path, we fake-save the object using the provided save function.
|
||||
# instead of saving to a file, we save to some class that inherited io.BytesIO:
|
||||
# ```
|
||||
# class HashingWriter(io.BytesIO):
|
||||
# def __init__(self):
|
||||
# super().__init__() # Initialize the BytesIO buffer
|
||||
# self.hash = hashlib.sha256() # Initialize the SHA-256 hash object
|
||||
#
|
||||
# def write(self, b):
|
||||
# return self.hash.update(b) # Update the hash with the data being written
|
||||
#
|
||||
# def get_hash(self):
|
||||
# return self.hash.hexdigest() # Return the hexadecimal digest of the hash
|
||||
# ```
|
||||
#
|
||||
# so self.obj_path would look something like:
|
||||
# ```
|
||||
# def obj_path(self, obj):
|
||||
# with open(HashingWriter(), "wb") as f:
|
||||
# f.write(obj)
|
||||
# if f.get_hash() in self.cache:
|
||||
# return self.cache[hash][0]
|
||||
# ```
|
||||
# In reality, obj_path() can just be part of path(), since path() should expect
|
||||
# kwargs to be present and can safely assume it's being called as obj_path()
|
||||
# otherwise.
|
||||
#
|
||||
# Caveat, this will only work for save functions that use an open file.
|
||||
# Save functions that take a path would not work.
|
||||
#
|
||||
# More practical case:
|
||||
# ```
|
||||
# fn_a = load_or_create()(fn_a)
|
||||
#
|
||||
# def fn_b(obj_a, cfg):
|
||||
# do_something(obj_a, cfg)
|
||||
#
|
||||
# fn_b = load_or_create(
|
||||
# path_fn=lambda obj_a, hash: fn_a.get_path(obj_a) / hash
|
||||
# )(fn_b)
|
||||
# ```
|
||||
# This avoids having to make pass obj_a's config to fn_b in order to determine the path!
|
||||
|
|
Loading…
Reference in New Issue
Block a user