path_from_obj which can accept the obj or the LoadOrCreateCFG of the obj
This commit is contained in:
parent
746b2cf6b0
commit
b8c72498ba
174
src/store.py
174
src/store.py
|
@ -1,5 +1,6 @@
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
import inspect
|
import inspect
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import pickle
|
import pickle
|
||||||
|
@ -23,6 +24,30 @@ def with_filtered_args(fn):
|
||||||
return inner, arg_filter
|
return inner, arg_filter
|
||||||
|
|
||||||
|
|
||||||
|
class HashingWriter(io.BytesIO):
    """Write-only, file-like sink that hashes its input instead of storing it.

    Every chunk passed to ``write`` is fed into a running SHA-256 hash.
    Nothing is appended to the underlying ``BytesIO`` buffer, so
    ``getvalue()`` stays empty no matter how much is written.
    """

    def __init__(self):
        super().__init__()  # Set up the (unused) BytesIO buffer.
        self.hash = sha256()  # Running SHA-256 over everything written so far.

    def write(self, b):
        """Feed ``b`` into the hash and report it as fully consumed.

        Args:
            b: A bytes-like object (``bytes``, ``bytearray``, ``memoryview``...).

        Returns:
            The number of bytes in ``b``, as file objects are expected to
            report. Returning 0 would tell a caller that checks the result
            that nothing was accepted.

        Raises:
            TypeError: If ``b`` is not bytes-like (for example a ``str``).
        """
        self.hash.update(b)
        # nbytes rather than len(): for multi-byte item formats len() counts
        # items, not bytes.
        return memoryview(b).nbytes

    def writelines(self, lines):
        """Hash every chunk of ``lines`` in order, via ``write``."""
        for line in lines:
            self.write(line)

    def get_hash(self):
        """Return the hexadecimal SHA-256 digest of all data written so far."""
        return self.hash.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class LoadOrCreateCFG:
    """Plain holder for the call arguments of a wrapped function.

    It stores the positional and keyword arguments unchanged. In this file,
    ``path_from_obj`` accepts an instance in place of a real object and
    unpacks ``args`` / ``kwargs`` into ``self.path``.
    """

    def __init__(self, *args, **kwargs):
        self.args = args  # Positional arguments, kept as a tuple.
        self.kwargs = kwargs  # Keyword arguments, kept as a dict.

    def __repr__(self):
        """Return a constructor-like representation, to ease debugging."""
        parts = [repr(arg) for arg in self.args]
        parts += [f"{key}={value!r}" for key, value in self.kwargs.items()]
        return f"{type(self).__name__}({', '.join(parts)})"
|
||||||
|
|
||||||
|
|
||||||
class load_or_create:
|
class load_or_create:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -32,6 +57,7 @@ class load_or_create:
|
||||||
path_fn=lambda hash: hash + ".pkl",
|
path_fn=lambda hash: hash + ".pkl",
|
||||||
hash_len=8,
|
hash_len=8,
|
||||||
save_args=False,
|
save_args=False,
|
||||||
|
save_json=False,
|
||||||
plain_text=False
|
plain_text=False
|
||||||
):
|
):
|
||||||
self.load=with_filtered_args(load)[0]
|
self.load=with_filtered_args(load)[0]
|
||||||
|
@ -41,6 +67,7 @@ class load_or_create:
|
||||||
self.path_fn=path_fn
|
self.path_fn=path_fn
|
||||||
self.hash_len=hash_len
|
self.hash_len=hash_len
|
||||||
self.save_args=save_args
|
self.save_args=save_args
|
||||||
|
self.save_json=save_json
|
||||||
self.plain_text=plain_text
|
self.plain_text=plain_text
|
||||||
|
|
||||||
def __call__(self, fn):
|
def __call__(self, fn):
|
||||||
|
@ -52,18 +79,24 @@ class inner(load_or_create):
|
||||||
self.__dict__.update(parent.__dict__)
|
self.__dict__.update(parent.__dict__)
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
self.arg_names = [p.name for p in inspect.signature(fn).parameters.values()]
|
self.arg_names = [p.name for p in inspect.signature(fn).parameters.values()]
|
||||||
|
self.obj_to_args = dict()
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
# Store the keyword arguments into json and hash it to get the storage path.
|
# Store the keyword arguments into json and hash it to get the storage path.
|
||||||
path = self.path(*args, **kwargs)
|
path = self.path(*args, **kwargs)
|
||||||
|
merged_args = self.args_to_kwargs(args, kwargs, path=path)
|
||||||
|
|
||||||
obj = self.load_wrapper(**self.args_to_kwargs(args, kwargs, path=path))
|
obj = self.load_wrapper(**merged_args)
|
||||||
if obj is not None: return obj
|
if obj is not None:
|
||||||
|
if "file" in self.save_arg_names: self.hash_obj({"path": path} | kwargs)
|
||||||
|
return obj
|
||||||
|
|
||||||
obj = self.fn(*args, **kwargs)
|
obj = self.fn(*args, **kwargs)
|
||||||
if obj is None: return obj
|
if obj is None: return obj
|
||||||
|
|
||||||
self.save_wrapper(obj, **self.args_to_kwargs(args, kwargs, path=path))
|
self.save_wrapper(obj, *args, **{"path": path} | kwargs)
|
||||||
|
if "file" in self.save_arg_names: self.hash_obj({"path": path} | kwargs)
|
||||||
|
if self.save_json: path.with_suffix(".json").write_bytes(self.to_json(**kwargs))
|
||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
@ -94,24 +127,28 @@ class inner(load_or_create):
|
||||||
return self.load(**{"file": file} | load_args)
|
return self.load(**{"file": file} | load_args)
|
||||||
|
|
||||||
|
|
||||||
def save_wrapper(self, obj, path, **save_args):
|
def save_wrapper(self, obj, *args, **kwargs):
|
||||||
"""Only open a file if the save function requests you to.
|
"""Only open a file if the save function requests you to.
|
||||||
"""
|
"""
|
||||||
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
merged_args = self.args_to_kwargs(args, kwargs)
|
||||||
if "file" not in self.save_arg_names:
|
|
||||||
return self.save(obj, **{"path": path} | save_args)
|
|
||||||
|
|
||||||
if "file" in save_args:
|
path = Path(merged_args["path"])
|
||||||
file = save_args["file"]
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if "file" not in self.save_arg_names:
|
||||||
|
return self.save(obj, **merged_args)
|
||||||
|
|
||||||
|
if "file" in merged_args:
|
||||||
|
file = merged_args["file"]
|
||||||
else:
|
else:
|
||||||
file = open(path, "w" if self.plain_text else "wb")
|
file = open(path, "w" if self.plain_text else "wb")
|
||||||
|
|
||||||
if self.save_args:
|
if self.save_args:
|
||||||
file.write(self.to_json(**save_args))
|
file.write(self.to_json(**kwargs))
|
||||||
|
|
||||||
self.save(obj, **{"file": file, "path": path} | save_args)
|
self.save(obj, **{"file": file} | merged_args)
|
||||||
|
|
||||||
if "file" not in save_args:
|
if "file" not in merged_args:
|
||||||
file.close()
|
file.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -125,7 +162,10 @@ class inner(load_or_create):
|
||||||
) + "\n"
|
) + "\n"
|
||||||
).encode("utf-8")
|
).encode("utf-8")
|
||||||
|
|
||||||
def make_hash(self, **kwargs):
|
# Functions to do with hashing keyword arguments and determining an
|
||||||
|
# object's path from its arguments.
|
||||||
|
|
||||||
|
def hash_kwargs(self, **kwargs):
|
||||||
"""Serialize all keyword arguments to json and hash it."""
|
"""Serialize all keyword arguments to json and hash it."""
|
||||||
json_args = self.to_json(**kwargs)
|
json_args = self.to_json(**kwargs)
|
||||||
hash_str = sha256(json_args, usedforsecurity=False).hexdigest()
|
hash_str = sha256(json_args, usedforsecurity=False).hexdigest()
|
||||||
|
@ -134,7 +174,7 @@ class inner(load_or_create):
|
||||||
def path_and_hash(self, *args, **kwargs) -> Tuple[Path, str]:
|
def path_and_hash(self, *args, **kwargs) -> Tuple[Path, str]:
|
||||||
path_fn, arg_filter = with_filtered_args(self.path_fn)
|
path_fn, arg_filter = with_filtered_args(self.path_fn)
|
||||||
|
|
||||||
hash = self.make_hash(
|
hash = self.hash_kwargs(
|
||||||
# Do not have arguments in the hash that path_fn uses.
|
# Do not have arguments in the hash that path_fn uses.
|
||||||
**{k: a for k, a in kwargs.items() if k not in arg_filter}
|
**{k: a for k, a in kwargs.items() if k not in arg_filter}
|
||||||
)
|
)
|
||||||
|
@ -157,65 +197,49 @@ class inner(load_or_create):
|
||||||
def dir(self, *args, **kwargs) -> Path:
|
def dir(self, *args, **kwargs) -> Path:
|
||||||
return self.path(*args, **kwargs).parent
|
return self.path(*args, **kwargs).parent
|
||||||
|
|
||||||
|
# Functions to do with inferring an object's arguments by keeping track of
|
||||||
|
# its hash.
|
||||||
|
|
||||||
|
def hash_obj(self, kwargs):
    """Record the content hash of an object's storage file.

    After saving or loading an object, its storage file is hashed and the
    hash is mapped to ``kwargs`` in ``self.obj_to_args``. ``args_from_obj``
    later uses that mapping to recover the arguments of an object.

    Args:
        kwargs: Mapping that must contain a ``"path"`` entry pointing at the
            storage file. The whole mapping is stored as the lookup value.

    Returns:
        None. If the storage file does not exist, nothing is recorded.
    """
    path = Path(kwargs["path"])
    digest = sha256()
    try:
        with open(path, "rb") as file:
            if self.save_args:
                # Skip the first line: save_wrapper writes the JSON-encoded
                # arguments there when save_args is set, and they are not
                # part of the object's own bytes.
                file.readline()
            # Hash in chunks so that large files are never held in memory
            # as a whole.
            while chunk := file.read(1 << 20):
                digest.update(chunk)
    except FileNotFoundError:
        return
    self.obj_to_args[digest.hexdigest()] = kwargs
|
||||||
|
|
||||||
|
def args_from_obj(self, obj, *args, **kwargs):
    """Recover the arguments that an object was loaded or created with.

    The object is "saved" with the user-provided ``self.save`` into a
    ``HashingWriter``, which hashes the bytes instead of storing them. The
    resulting hash is looked up among the hashes that ``self.hash_obj``
    recorded for previously loaded or saved objects.

    Args:
        obj: The object whose arguments should be recovered.
        *args, **kwargs: Extra arguments, merged with ``self.args_to_kwargs``
            and forwarded to ``self.save``.

    Returns:
        The mapping that ``self.hash_obj`` stored for this object's hash.

    Raises:
        KeyError: If no object with the same serialized content was recorded.
    """
    writer = HashingWriter()
    # Force the hashing writer as ``file`` even if the merged arguments
    # already carry a ``file`` entry; passing both would raise a TypeError
    # for a duplicated keyword argument.
    save_kwargs = {**self.args_to_kwargs(args, kwargs), "file": writer}
    self.save(obj, **save_kwargs)
    digest = writer.get_hash()
    try:
        return self.obj_to_args[digest]
    except KeyError:
        raise KeyError(
            f"no arguments recorded for an object with content hash {digest}; "
            "it has to be loaded or created through this wrapper first"
        ) from None
|
||||||
|
|
||||||
|
def path_from_obj(self, obj, *args, **kwargs):
    """Return the storage path that belongs to ``obj``.

    Args:
        obj: Either a ``LoadOrCreateCFG`` holding the arguments of a call,
            or an actual object that was previously loaded or saved.
        *args, **kwargs: Only used for actual objects; forwarded to
            ``self.args_from_obj``.

    Returns:
        The path as a ``pathlib.Path``.

    Raises:
        KeyError: Propagated from the ``self.args_from_obj`` lookup when an
            actual object has no recorded arguments.
    """
    if isinstance(obj, LoadOrCreateCFG):
        # The arguments are known, so the path is computed directly without
        # loading or creating the object.
        return Path(self.path(*obj.args, **obj.kwargs))
    # The recorded "path" entry may be a plain string (hash_obj itself
    # coerces it with Path), so normalise it before handing it out.
    return Path(self.args_from_obj(obj, *args, **kwargs)["path"])
|
||||||
|
|
||||||
|
def dir_from_obj(self, obj, *args, **kwargs):
    """Return the directory that contains the storage path of ``obj``.

    All arguments are forwarded to ``self.path_from_obj``.

    Returns:
        The parent directory of that path, as a ``pathlib.Path``.
    """
    # Coerce to Path first: the path recorded for an object may be a plain
    # string, which has no ``.parent``.
    return Path(self.path_from_obj(obj, *args, **kwargs)).parent
|
||||||
|
|
||||||
|
def cfg(self, *args, **kwargs):
    """Bundle call arguments into a ``LoadOrCreateCFG`` placeholder.

    Use this when you only care about an object's path, not the object
    itself, but still have to hand "an object" to some other function so
    that it can work out that path. Passing the placeholder instead saves
    you from loading or creating the object.
    """
    placeholder = LoadOrCreateCFG(*args, **kwargs)
    return placeholder
|
||||||
|
|
||||||
|
|
||||||
#%% [markdown]
|
|
||||||
# A big possible extension is to hash the returned object.
|
|
||||||
# If we do that, we can do something like this:
|
|
||||||
# ```
|
|
||||||
# fn = load_or_create(orig_fn)
|
|
||||||
# obj = fn(*args)
|
|
||||||
# assert fn.path(*args) == fn.obj_path(obj)
|
|
||||||
# ```
|
|
||||||
#
|
|
||||||
# So, we'd get a _path_ from an object!
|
|
||||||
#
|
|
||||||
# How would this work out?
|
|
||||||
#
|
|
||||||
# first, after loading and/or storing, we save the hash as we store the object.
|
|
||||||
# ```
|
|
||||||
# obj_hash = get_hash_from_file(path)
|
|
||||||
# self.cache[obj_hash] = [path, kwargs]
|
|
||||||
# ```
|
|
||||||
#
|
|
||||||
# then, in self.obj_path, we fake-save the object using the provided save function.
|
|
||||||
# instead of saving to a file, we save to some class that inherited io.BytesIO:
|
|
||||||
# ```
|
|
||||||
# class HashingWriter(io.BytesIO):
|
|
||||||
# def __init__(self):
|
|
||||||
# super().__init__() # Initialize the BytesIO buffer
|
|
||||||
# self.hash = hashlib.sha256() # Initialize the SHA-256 hash object
|
|
||||||
#
|
|
||||||
# def write(self, b):
|
|
||||||
# return self.hash.update(b) # Update the hash with the data being written
|
|
||||||
#
|
|
||||||
# def get_hash(self):
|
|
||||||
# return self.hash.hexdigest() # Return the hexadecimal digest of the hash
|
|
||||||
# ```
|
|
||||||
#
|
|
||||||
# so self.obj_path would look something like:
|
|
||||||
# ```
|
|
||||||
# def obj_path(self, obj):
|
|
||||||
# with open(HashingWriter(), "wb") as f:
|
|
||||||
# f.write(obj)
|
|
||||||
# if f.get_hash() in self.cache:
|
|
||||||
# return self.cache[hash][0]
|
|
||||||
# ```
|
|
||||||
# In reality, obj_path() can just be part of path(), since path() should expect
|
|
||||||
# kwargs to be present and can safely assume it's being called as obj_path()
|
|
||||||
# otherwise.
|
|
||||||
#
|
|
||||||
# Caveat, this will only work for save functions that use an open file.
|
|
||||||
# Save functions that take a path would not work.
|
|
||||||
#
|
|
||||||
# More practical case:
|
|
||||||
# ```
|
|
||||||
# fn_a = load_or_create()(fn_a)
|
|
||||||
#
|
|
||||||
# def fn_b(obj_a, cfg):
|
|
||||||
# do_something(obj_a, cfg)
|
|
||||||
#
|
|
||||||
# fn_b = load_or_create(
|
|
||||||
# path_fn=lambda obj_a, hash: fn_a.obj_path(obj_a) / hash
|
|
||||||
# )(fn_b)
|
|
||||||
# ```
|
|
||||||
# This avoids having to pass obj_a's config to fn_b just to determine the path!
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user