diff --git a/src/store.py b/src/store.py index e773a8e..5b8e4ff 100644 --- a/src/store.py +++ b/src/store.py @@ -1,5 +1,6 @@ from hashlib import sha256 import inspect +import io import json from pathlib import Path import pickle @@ -23,6 +24,30 @@ def with_filtered_args(fn): return inner, arg_filter +class HashingWriter(io.BytesIO): + def __init__(self): + super().__init__() # Initialize the BytesIO buffer + self.hash = sha256() # Initialize the SHA-256 hash object + + def write(self, b): + self.hash.update(b) # Update the hash with the data being written + return 0 + # return super().write(b) # Write the data to the BytesIO buffer + + def writelines(self, lines): + for line in lines: + self.write(line) + + def get_hash(self): + return self.hash.hexdigest() # Return the hexadecimal digest of the hash + + +class LoadOrCreateCFG: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + class load_or_create: def __init__( @@ -32,6 +57,7 @@ class load_or_create: path_fn=lambda hash: hash + ".pkl", hash_len=8, save_args=False, + save_json=False, plain_text=False ): self.load=with_filtered_args(load)[0] @@ -41,6 +67,7 @@ class load_or_create: self.path_fn=path_fn self.hash_len=hash_len self.save_args=save_args + self.save_json=save_json self.plain_text=plain_text def __call__(self, fn): @@ -52,18 +79,24 @@ class inner(load_or_create): self.__dict__.update(parent.__dict__) self.fn = fn self.arg_names = [p.name for p in inspect.signature(fn).parameters.values()] + self.obj_to_args = dict() def __call__(self, *args, **kwargs): # Store the keyword arguments into json and hash it to get the storage path. 
path = self.path(*args, **kwargs) + merged_args = self.args_to_kwargs(args, kwargs, path=path) - obj = self.load_wrapper(**self.args_to_kwargs(args, kwargs, path=path)) - if obj is not None: return obj + obj = self.load_wrapper(**merged_args) + if obj is not None: + if "file" in self.save_arg_names: self.hash_obj({"path": path} | kwargs) + return obj obj = self.fn(*args, **kwargs) if obj is None: return obj - self.save_wrapper(obj, **self.args_to_kwargs(args, kwargs, path=path)) + self.save_wrapper(obj, *args, **{"path": path} | kwargs) + if "file" in self.save_arg_names: self.hash_obj({"path": path} | kwargs) + if self.save_json: path.with_suffix(".json").write_bytes(self.to_json(**kwargs)) return obj @@ -94,24 +127,28 @@ class inner(load_or_create): return self.load(**{"file": file} | load_args) - def save_wrapper(self, obj, path, **save_args): + def save_wrapper(self, obj, *args, **kwargs): """Only open a file if the save function requests you to. """ - Path(path).parent.mkdir(parents=True, exist_ok=True) - if "file" not in self.save_arg_names: - return self.save(obj, **{"path": path} | save_args) + merged_args = self.args_to_kwargs(args, kwargs) - if "file" in save_args: - file = save_args["file"] + path = Path(merged_args["path"]) + path.parent.mkdir(parents=True, exist_ok=True) + + if "file" not in self.save_arg_names: + return self.save(obj, **merged_args) + + if "file" in merged_args: + file = merged_args["file"] else: file = open(path, "w" if self.plain_text else "wb") if self.save_args: - file.write(self.to_json(**save_args)) + file.write(self.to_json(**kwargs)) - self.save(obj, **{"file": file, "path": path} | save_args) + self.save(obj, **{"file": file} | merged_args) - if "file" not in save_args: + if "file" not in merged_args: file.close() @@ -125,7 +162,10 @@ class inner(load_or_create): ) + "\n" ).encode("utf-8") - def make_hash(self, **kwargs): + # Functions to do with hashing keyword arguments and determining an + # object's path from its 
arguments. + + def hash_kwargs(self, **kwargs): """Serialize all keyword arguments to json and hash it.""" json_args = self.to_json(**kwargs) hash_str = sha256(json_args, usedforsecurity=False).hexdigest() @@ -134,7 +174,7 @@ class inner(load_or_create): def path_and_hash(self, *args, **kwargs) -> Tuple[Path, str]: path_fn, arg_filter = with_filtered_args(self.path_fn) - hash = self.make_hash( + hash = self.hash_kwargs( # Do not have arguments in the hash that path_fn uses. **{k: a for k, a in kwargs.items() if k not in arg_filter} ) @@ -157,65 +197,49 @@ class inner(load_or_create): def dir(self, *args, **kwargs) -> Path: return self.path(*args, **kwargs).parent + # Functions to do with inferring an object's arguments by keeping track of + # its hash. + + def hash_obj(self, kwargs): + """After saving or loading the object, we get its hash from the + storage file. + This hash we can later use, to infer the arguments that created this + object. + """ + path = Path(kwargs["path"]) + if not path.exists(): + return + with open(path, "rb") as file: + if self.save_args: + file.readline() + hash = sha256(file.read()).hexdigest() + self.obj_to_args[hash] = kwargs + + def args_from_obj(self, obj, *args, **kwargs): + """Hash the object using the user-provided self.save. + Then retrieve its arguments by looking up hashes of previously loaded + or saved objects. 
+ Also see self.hash_obj + """ + file = HashingWriter() + self.save(obj, **self.args_to_kwargs(args, kwargs), file=file) + hash = file.get_hash() + return self.obj_to_args[hash] + + def path_from_obj(self, obj, *args, **kwargs): + if isinstance(obj, LoadOrCreateCFG): + return self.path(*obj.args, **obj.kwargs) + return self.args_from_obj(obj, *args, **kwargs)["path"] + + def dir_from_obj(self, obj, *args, **kwargs): + return self.path_from_obj(obj, *args, **kwargs).parent + + def cfg(self, *args, **kwargs): + """If you don't care about some object, but only about its path, + but you still need to pass an object to some other function in order + to get its path, you can pass a LoadOrCreateCFG instead, saving you + from loading or creating that object. + """ + return LoadOrCreateCFG(*args, **kwargs) + -#%% [markdown] -# A big possible extension is to hash the returned object. -# If we do that, we can do something like this: -# ``` -# fn = load_or_create(orig_fn) -# obj = fn(*args) -# assert fn.path(*args) == fn.obj_path(obj) -# ``` -# -# So, we'd get a _path_ from an object! -# -# How would this work out? -# -# first, after loading and/or storing, we save the hash as we store the object. -# ``` -# obj_hash = get_hash_from_flie(path) -# self.cache[obj_hash] = [path, kwargs] -# ``` -# -# then, in self.obj_path, we fake-save the object using the provided save function.
-# instead of saving to a file, we save to some class that inherited io.BytesIO: -# ``` -# class HashingWriter(io.BytesIO): -# def __init__(self): -# super().__init__() # Initialize the BytesIO buffer -# self.hash = hashlib.sha256() # Initialize the SHA-256 hash object -# -# def write(self, b): -# return self.hash.update(b) # Update the hash with the data being written -# -# def get_hash(self): -# return self.hash.hexdigest() # Return the hexadecimal digest of the hash -# ``` -# -# so self.obj_path would look something like: -# ``` -# def obj_path(self, obj): -# with open(HashingWriter(), "wb") as f: -# f.write(obj) -# if f.get_hash() in self.cache: -# return self.cache[hash][0] -# ``` -# In reality, obj_path() can just be part of path(), since path() should expect -# kwargs to be present and can safely assume it's being called as obj_path() -# otherwise. -# -# Caveat, this will only work for save functions that use an open file. -# Save functions that take a path would not work. -# -# More practical case: -# ``` -# fn_a = load_or_create()(fn_a) -# -# def fn_b(obj_a, cfg): -# do_something(obj_a, cfg) -# -# fn_b = load_or_create( -# path_fn=lambda obj_a, hash: fn_a.get_path(obj_a) / hash -# )(fn_b) -# ``` -# This avoids having to make pass obj_a's config to fn_b in order to determine the path!