path_from_obj which can accept the obj or the LoadOrCreateCFG of the obj

This commit is contained in:
JJJHolscher 2024-08-06 14:51:30 +02:00
parent 746b2cf6b0
commit b8c72498ba

View File

@ -1,5 +1,6 @@
from hashlib import sha256
import inspect
import io
import json
from pathlib import Path
import pickle
@ -23,6 +24,30 @@ def with_filtered_args(fn):
return inner, arg_filter
class HashingWriter(io.BytesIO):
def __init__(self):
super().__init__() # Initialize the BytesIO buffer
self.hash = sha256() # Initialize the SHA-256 hash object
def write(self, b):
self.hash.update(b) # Update the hash with the data being written
return 0
# return super().write(b) # Write the data to the BytesIO buffer
def writelines(self, lines):
for line in lines:
self.write(line)
def get_hash(self):
return self.hash.hexdigest() # Return the hexadecimal digest of the hash
class LoadOrCreateCFG:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
class load_or_create:
def __init__(
@ -32,6 +57,7 @@ class load_or_create:
path_fn=lambda hash: hash + ".pkl",
hash_len=8,
save_args=False,
save_json=False,
plain_text=False
):
self.load=with_filtered_args(load)[0]
@ -41,6 +67,7 @@ class load_or_create:
self.path_fn=path_fn
self.hash_len=hash_len
self.save_args=save_args
self.save_json=save_json
self.plain_text=plain_text
def __call__(self, fn):
@ -52,18 +79,24 @@ class inner(load_or_create):
self.__dict__.update(parent.__dict__)
self.fn = fn
self.arg_names = [p.name for p in inspect.signature(fn).parameters.values()]
self.obj_to_args = dict()
def __call__(self, *args, **kwargs):
# Store the keyword arguments into json and hash it to get the storage path.
path = self.path(*args, **kwargs)
merged_args = self.args_to_kwargs(args, kwargs, path=path)
obj = self.load_wrapper(**self.args_to_kwargs(args, kwargs, path=path))
if obj is not None: return obj
obj = self.load_wrapper(**merged_args)
if obj is not None:
if "file" in self.save_arg_names: self.hash_obj({"path": path} | kwargs)
return obj
obj = self.fn(*args, **kwargs)
if obj is None: return obj
self.save_wrapper(obj, **self.args_to_kwargs(args, kwargs, path=path))
self.save_wrapper(obj, *args, **{"path": path} | kwargs)
if "file" in self.save_arg_names: self.hash_obj({"path": path} | kwargs)
if self.save_json: path.with_suffix(".json").write_bytes(self.to_json(**kwargs))
return obj
@ -94,24 +127,28 @@ class inner(load_or_create):
return self.load(**{"file": file} | load_args)
def save_wrapper(self, obj, path, **save_args):
def save_wrapper(self, obj, *args, **kwargs):
"""Only open a file if the save function requests you to.
"""
Path(path).parent.mkdir(parents=True, exist_ok=True)
if "file" not in self.save_arg_names:
return self.save(obj, **{"path": path} | save_args)
merged_args = self.args_to_kwargs(args, kwargs)
if "file" in save_args:
file = save_args["file"]
path = Path(merged_args["path"])
path.parent.mkdir(parents=True, exist_ok=True)
if "file" not in self.save_arg_names:
return self.save(obj, **merged_args)
if "file" in merged_args:
file = merged_args["file"]
else:
file = open(path, "w" if self.plain_text else "wb")
if self.save_args:
file.write(self.to_json(**save_args))
file.write(self.to_json(**kwargs))
self.save(obj, **{"file": file, "path": path} | save_args)
self.save(obj, **{"file": file} | merged_args)
if "file" not in save_args:
if "file" not in merged_args:
file.close()
@ -125,7 +162,10 @@ class inner(load_or_create):
) + "\n"
).encode("utf-8")
def make_hash(self, **kwargs):
# Functions to do with hashing keyword arguments and determining an
# object's path from its arguments.
def hash_kwargs(self, **kwargs):
"""Serialize all keyword arguments to json and hash it."""
json_args = self.to_json(**kwargs)
hash_str = sha256(json_args, usedforsecurity=False).hexdigest()
@ -134,7 +174,7 @@ class inner(load_or_create):
def path_and_hash(self, *args, **kwargs) -> Tuple[Path, str]:
path_fn, arg_filter = with_filtered_args(self.path_fn)
hash = self.make_hash(
hash = self.hash_kwargs(
# Do not have arguments in the hash that path_fn uses.
**{k: a for k, a in kwargs.items() if k not in arg_filter}
)
@ -157,65 +197,49 @@ class inner(load_or_create):
def dir(self, *args, **kwargs) -> Path:
return self.path(*args, **kwargs).parent
# Functions to do with inferring an object's arguments by keeping track of
# its hash.
def hash_obj(self, kwargs):
"""After saving or loading the object, we get its hash from the
storage file.
This hash we can later use, to infer the arguments that created this
object.
"""
path = Path(kwargs["path"])
if not path.exists():
return
with open(path, "rb") as file:
if self.save_args:
file.readline()
hash = sha256(file.read()).hexdigest()
self.obj_to_args[hash] = kwargs
def args_from_obj(self, obj, *args, **kwargs):
"""Hash the object using the user-provided self.save.
Then retrieve its arguments by looking up hashes of previously loaded
or saved objects.
Also ssee self.hash_obj
"""
file = HashingWriter()
self.save(obj, **self.args_to_kwargs(args, kwargs), file=file)
hash = file.get_hash()
return self.obj_to_args[hash]
def path_from_obj(self, obj, *args, **kwargs):
if isinstance(obj, LoadOrCreateCFG):
return self.path(*obj.args, **obj.kwargs)
return self.args_from_obj(obj, *args, **kwargs)["path"]
def dir_from_obj(self, obj, *args, **kwargs):
return self.path_from_obj(obj, *args, **kwargs).parent
def cfg(self, *args, **kwargs):
"""If you don't care about some object, but only about it's path,
but you still need to pass an object to some other function in order
to get it's path, you can pass a LoadOrCreateCFG instead, saving you
from loading or creating that object..
"""
return LoadOrCreateCFG(*args, **kwargs)
#%% [markdown]
# A big possible extension is to hash the returned object.
# If we do that, we can do something like this:
# ```
# fn = load_or_create(orig_fn)
# obj = fn(*args)
# assert fn.path(*args) == fn.obj_path(obj)
# ```
#
# So, we'd get a _path_ from an object!
#
# How would this work out?
#
# first, after loading and/or storing, we save the hash as we store the object.
# ```
# obj_hash = get_hash_from_flie(path)
# self.cache[obj_hash] = [path, kwargs]
# ```
#
# then, in self.obj_path, we fake-save the object using the provided save function.
# instead of saving to a file, we save to some class that inherited io.BytesIO:
# ```
# class HashingWriter(io.BytesIO):
# def __init__(self):
# super().__init__() # Initialize the BytesIO buffer
# self.hash = hashlib.sha256() # Initialize the SHA-256 hash object
#
# def write(self, b):
# return self.hash.update(b) # Update the hash with the data being written
#
# def get_hash(self):
# return self.hash.hexdigest() # Return the hexadecimal digest of the hash
# ```
#
# so self.obj_path would look something like:
# ```
# def obj_path(self, obj):
# with open(HashingWriter(), "wb") as f:
# f.write(obj)
# if f.get_hash() in self.cache:
# return self.cache[hash][0]
# ```
# In reality, obj_path() can just be part of path(), since path() should expect
# kwargs to be present and can safely assume it's being called as obj_path()
# otherwise.
#
# Caveat, this will only work for save functions that use an open file.
# Save functions that take a path would not work.
#
# More practical case:
# ```
# fn_a = load_or_create()(fn_a)
#
# def fn_b(obj_a, cfg):
# do_something(obj_a, cfg)
#
# fn_b = load_or_create(
# path_fn=lambda obj_a, hash: fn_a.get_path(obj_a) / hash
# )(fn_b)
# ```
# This avoids having to make pass obj_a's config to fn_b in order to determine the path!