Compare commits
10 Commits: 086fad8602 ... 90146757f1

90146757f1
7048bb058a
b8c72498ba
746b2cf6b0
35396c80ca
fe8ae1fda7
58a965bc8e
85cc1373cb
b52d72d087
b6be9cee2d

README.md (218)
@@ -1,12 +1,212 @@
# README

## files

A bunch of functionality in here, but the main one is `store.load_or_create`.

It might get its own package someday, when it's less of a pre-alpha sneak-peek demo than it is now.

| file | description |
|---|---|
| .gitignore | a list of unshared files |
| makefile | dev tools for installing, publishing, etc. |
| pyproject.toml | project metadata |
| requirements.txt | python dependencies |
| setup.py | necessary for `pip install -e .` |
| src/main.py | first file that gets called |

## load_or_create

The main idea is that an object's storage path should be inferable from the arguments with which it was created.
In reality we often keep track of an object and its path separately, which can work out poorly.
The `load_or_create` function bundles an object's create, load, save and path functions together, so that anyone working with the code has to think about storage locations less often and can focus on the functionality of their code more often.

### the status quo

A common pattern has you write the following 4 functions for objects that are expensive to create:
```python
from pathlib import Path
import pickle  # We use pickle as an example here.


def create_obj(name, some_kwarg=1):
    """This function takes long, so we don't want to run it redundantly."""
    obj = ...  # the expensive work happens here
    return obj


def save_obj(obj, path):
    """Instead we write a save function..."""
    with open(path, "wb") as file:
        pickle.dump(obj, file)


def load_obj(path):
    """...and a load function so we only need to create the object once."""
    with open(path, "rb") as file:
        return pickle.load(file)


def infer_path(name):
    """We need to keep track of where the created object is stored."""
    return "./obj/" + str(name) + ".pkl"


# When you have the above 4 functions, this pattern will start to occur.
name = "MyObject"
path = infer_path(name)
if Path(path).exists():
    obj = load_obj(path)
else:
    obj = create_obj(name, some_kwarg=0)
    save_obj(obj, path)
```

In some cases, you want to create and save many variations of the object.
It might be better to hash its characteristics and use that as part of the path.
```python
from hashlib import sha256
import json


def infer_path(name, **some_other_kwargs):
    hash = sha256(json.dumps(some_other_kwargs).encode()).hexdigest()
    return "./obj/" + hash + ".pkl"
```
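The property this buys you can be shown in a standalone sketch (independent of the library; the `depth`/`seed` kwargs are hypothetical, and `sort_keys` is an extra precaution, not taken above, that makes the hash independent of keyword order):

```python
from hashlib import sha256
import json


def infer_path(name, **some_other_kwargs):
    # Hash the kwargs so every variation of the object gets its own file.
    digest = sha256(json.dumps(some_other_kwargs, sort_keys=True).encode()).hexdigest()
    return "./obj/" + digest + ".pkl"


# The same kwargs always map to the same path...
assert infer_path("a", depth=3, seed=42) == infer_path("a", seed=42, depth=3)
# ...while different kwargs map to different paths.
assert infer_path("a", depth=3) != infer_path("a", depth=4)
```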

### the problem

The above is fine and dandy, but when someone wants to use your obj, they'd need to keep track of your 4 separate functions.

You can dress it up as such:
```python
def get_obj(name, some_kwarg):
    path = Path(infer_path(name))
    if path.exists():
        obj = load_obj(path)
    else:
        obj = create_obj(name, some_kwarg=some_kwarg)
        save_obj(obj, path)
    return obj
```

But that takes a lot of freedom away from your user, who might have their own ideas on where and how the object should be loaded or stored.

### the solution
```python
from jo3util.store import load_or_create

get_obj = load_or_create(
    load=load_obj,
    save=save_obj,
    path_fn=infer_path,
)(create_obj)

obj = get_obj(name, some_kwarg=0)

# We can now infer the path of an object from its creation arguments.
path = get_obj.path(name, some_kwarg=0)

# We can also use `get_obj.path_from_obj` to recall the path of any object
# that `get_obj` returned in the past.
assert path == get_obj.path_from_obj(obj)
```

You can now elegantly pack the four functions together.
But you still have the flexibility to alter the path function on the fly:
```python
get_obj.path_fn = lambda hash: f"./{hash}.pkl"
```

Now, storing different objects, one of which depends on the other, becomes intuitive and elegant:
```python
# This code is written at the library level.

get_human = load_or_create(
    path_fn=lambda name: "./" + name + "/body.pkl"
    # If you omit the save and load functions, load_or_create will use pickle.
)(lambda name: name)

get_finger_print = load_or_create(
    path_fn=lambda human, finger: get_human.dir_from_obj(human) / f"{finger}.print"
)(lambda human, finger: f"{human}'s finger the {finger}")

# This code is what a user can work with.

assert not get_human.path("john").exists()  # ./john/body.pkl
human = get_human("john")
assert get_human.path("john").exists()

finger_print = get_finger_print(human, "thumb")
assert get_finger_print.path(human, "thumb") == Path("./john/thumb.print")
```

The fingerprint is now always stored in the same directory as the human's `body.pkl`.
You don't need to keep track of the location of `body.pkl`.

### four functions in one

The main trick is to match the parameter names of the `create` function (in our case `create_obj`) with those of the three other subfunctions (in our case `load_obj`, `save_obj` and `infer_path`).

The allowed parameters of the three subfunctions are mostly a non-strict superset of the create function's parameters.

When you call the `load_or_create`-wrapped `get_obj`, something like this happens:
```python
def call_fn_with_filtered_arguments(fn, *args, **kwargs):
    """Call `fn` with only the subset of `args` and `kwargs` that it expects.

    This is necessary, as python will complain if a function receives any
    argument for which there is no function parameter. So

        def fn(a):
            pass

        fn(a=0, b=1)

    will error, and we need to remove b before calling fn.

    This example function is wrong; if you're curious, you need to check
    the source code.
    """
    # Get the names of the parameters that `fn` accepts.
    path_parameters = get_parameters_that_fn_expects(fn)
    # Filter for positional arguments that `fn` accepts.
    args = [a for i, a in enumerate(args) if name_of_positional(i, fn) in path_parameters]
    # Filter for keyword arguments that `fn` accepts.
    kwargs = {k: a for k, a in kwargs.items() if k in path_parameters}
    # Call `fn` with the filtered subset of the original args and kwargs.
    return fn(*args, **kwargs)


def get_obj_pseudo_code(*args, **kwargs):
    hash = some_hash_fn(*args, **kwargs)
    path = call_fn_with_filtered_arguments(infer_path, *args, hash=hash, **kwargs)
    if path.exists():
        return call_fn_with_filtered_arguments(
            load_obj,
            *args,
            path=path,
            file=open(path, "rb"),
            **kwargs
        )

    obj = create_obj(*args, **kwargs)
    call_fn_with_filtered_arguments(
        save_obj,
        obj,
        *args,
        path=path,
        file=open(path, "wb"),
        **kwargs
    )
    return obj
```

So, the load, save and path functions you provide do not have to have the same signature as the create function, but you can call them _as if_ they were the create function.
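The argument filtering itself is easy to sketch with `inspect` (a hypothetical `filter_kwargs` helper for illustration only; the library's actual implementation differs):

```python
import inspect


def filter_kwargs(fn, kwargs):
    # Keep only the keyword arguments that fn's signature accepts.
    accepted = set(inspect.signature(fn).parameters)
    return {k: v for k, v in kwargs.items() if k in accepted}


def load_obj(path):
    return f"loaded from {path}"


# `load_obj` only takes `path`, yet we can call it with the full set of
# creation arguments, as if it were the create function.
kwargs = {"name": "MyObject", "some_kwarg": 0, "path": "./obj/MyObject.pkl"}
print(load_obj(**filter_kwargs(load_obj, kwargs)))  # loaded from ./obj/MyObject.pkl
```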

### philosophy

The main idea is that an object's storage location should be inferable from the arguments of its creation call.

In reality, we tend to keep track of an object's path, its arguments and the object itself separately.
This tends to go bad when we need to load, save or create the object in some other context.
It becomes easy to forget where some object ought to be stored.
Or different places that handle the same object can have different opinions on its storage location.

This can lead to duplicates, to forgetting where the object was stored, or to losing a whole folder of data because it is too unwieldy to salvage.

By packaging a function with its load and save counterparts and a default storage location, we don't need to worry about the storage location anymore and can focus on creating and using our objects.

If we ever do change our minds on the ideal storage location, there is an obvious central place for changing it, and that change immediately applies to _all_ the places where that object's path needs to be determined.

@@ -1,7 +1,7 @@
 [project]
 name = "jo3util"
-version = "0.0.13"
+version = "0.0.18"
 description = ""
 dependencies = []
 dynamic = ["readme"]
@@ -77,12 +77,12 @@ def sow(where: Callable, model: eqx.Module) -> eqx.Module:
             activ.append(x)
             return x

-        return eqx.filter_jit(store_activation)
+        return store_activation

     model = eqx.tree_at(where, model, replace_fn=install_sow)

     model_call = type(model).__call__
-    todo("make Sow a generic class")
+    todo("make Sow a generic class, also don't have nested Sows but check whether model is already a sow.")

     class Sow(type(model)):
         def __init__(self, model):
@@ -91,7 +91,10 @@ def sow(where: Callable, model: eqx.Module) -> eqx.Module:
         def __call__(self, *args, **kwargs):
             activ.clear()  # empty the list
             x = model_call(self, *args, **kwargs)
-            return activ + [x]
+            if isinstance(x, list):
+                return activ + x
+            else:
+                return activ + [x]

     return Sow(model)
src/root.py (15)
@@ -5,7 +5,7 @@ import os
 import json
 from importlib.resources import files, as_file
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union

 import __main__
 from .string import hash_string
@@ -81,10 +81,13 @@ def root_file():
     raise NotImplementedError


-def run_dir(obj, root=Path(os.path.abspath("./run"))):
-    run_id: str = json.dumps(
+def run_dir(
+    obj,
+    root: Union[Path, str] = root_dir() / "run",
+    name_len=8
+):
+    obj_hash: str = hash_string(json.dumps(
         obj,
         default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x)
-    )
-    run_id: str = hash_string(run_id)
-    return root / run_id
+    ))[:name_len]
+    return root / Path(obj_hash)
src/store.py (245, new file)

@@ -0,0 +1,245 @@
from hashlib import sha256
import inspect
import io
import json
from pathlib import Path
import pickle
from typing import Tuple


def with_filtered_args(fn):
    """
    Return a version of the input function that ignores instead of errors on
    unknown arguments.
    """
    arg_filter = {param.name for param in inspect.signature(fn).parameters.values()
                  if param.kind == param.POSITIONAL_OR_KEYWORD
                  or param.kind == param.KEYWORD_ONLY}

    def inner(*args, **kwargs):
        # Filter out any keyword arguments that fn does not accept.
        kwargs = {k: v for k, v in kwargs.items() if k in arg_filter}
        return fn(*args, **kwargs)

    return inner, arg_filter

class HashingWriter(io.BytesIO):
    def __init__(self):
        super().__init__()  # Initialize the BytesIO buffer
        self.hash = sha256()  # Initialize the SHA-256 hash object

    def write(self, b):
        self.hash.update(b)  # Update the hash with the data being written
        return 0
        # return super().write(b)  # Write the data to the BytesIO buffer

    def writelines(self, lines):
        for line in lines:
            self.write(line)

    def get_hash(self):
        return self.hash.hexdigest()  # Return the hexadecimal digest of the hash

class LoadOrCreateCFG:
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

class load_or_create:

    def __init__(
        self,
        load=pickle.load,
        save=pickle.dump,
        path_fn=lambda hash: hash + ".pkl",
        hash_len=8,
        save_args=False,
        save_json=False,
        plain_text=False
    ):
        self.load = with_filtered_args(load)[0]
        self.load_arg_names = {p.name for p in inspect.signature(load).parameters.values()}
        self.save = with_filtered_args(save)[0]
        self.save_arg_names = {p.name for p in inspect.signature(save).parameters.values()}
        self.path_fn = path_fn
        self.hash_len = hash_len
        self.save_args = save_args
        self.save_json = save_json
        self.plain_text = plain_text

    def __call__(self, fn):
        return inner(self, fn)

class inner(load_or_create):
    def __init__(self, parent, fn):
        self.__dict__.update(parent.__dict__)
        self.fn = fn
        self.arg_names = [p.name for p in inspect.signature(fn).parameters.values()]
        self.obj_to_args = dict()

    def __call__(self, *args, **kwargs):
        # Store the keyword arguments into json and hash it to get the storage path.
        path = self.path(*args, **kwargs)
        merged_args = self.args_to_kwargs(args, kwargs, path=path)

        obj = self.load_wrapper(**merged_args)
        if obj is not None:
            if "file" in self.save_arg_names: self.hash_obj({"path": path} | kwargs)
            return obj

        obj = self.fn(*args, **kwargs)
        if obj is None: return obj

        self.save_wrapper(obj, *args, **{"path": path} | kwargs)
        if "file" in self.save_arg_names: self.hash_obj({"path": path} | kwargs)
        if self.save_json: path.with_suffix(".kwargs.json").write_bytes(self.to_json(**kwargs))

        return obj

    def args_to_kwargs(self, args, kwargs, **extra):
        return extra | kwargs | {self.arg_names[i]: a for i, a in enumerate(args)}

    def load_wrapper(self, **load_args):
        """
        load_wrapper returns None or an object.
        Iff it returns None, the object is deemed not-loaded and thus,
        self.fn needs to be called to create the object.
        It can return None because self.load returns None, or because
        self.load expects an open file but no such file exists.
        """
        # Check whether self.load does not expect an open file.
        if "file" in load_args or "file" not in self.load_arg_names:
            return self.load(**load_args)

        path = load_args["path"]
        # If self.load expects an open file but there is none, run self.fn.
        if not path.exists():
            return None

        with open(path, "r" if self.plain_text else "rb") as file:
            # If specified, the first line is a json of the keyword arguments.
            if self.save_args:
                file.readline()
            return self.load(**{"file": file} | load_args)

    def save_wrapper(self, obj, *args, **kwargs):
        """Only open a file if the save function requests one."""
        merged_args = self.args_to_kwargs(args, kwargs)

        path = Path(merged_args["path"])
        path.parent.mkdir(parents=True, exist_ok=True)

        if "file" not in self.save_arg_names:
            return self.save(obj, **merged_args)

        if "file" in merged_args:
            file = merged_args["file"]
        else:
            file = open(path, "w" if self.plain_text else "wb")

        if self.save_args:
            file.write(self.to_json(**kwargs))

        self.save(obj, **{"file": file} | merged_args)

        if "file" not in merged_args:
            file.close()

    def to_json(self, **kwargs) -> bytes:
        """Serialize all keyword arguments to json.
        We never serialize positional arguments."""
        return (
            json.dumps(
                kwargs,
                default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x)
            ) + "\n"
        ).encode("utf-8")

    # Functions to do with hashing keyword arguments and determining an
    # object's path from its arguments.

    def hash_kwargs(self, **kwargs):
        """Serialize all keyword arguments to json and hash it."""
        json_args = self.to_json(**kwargs)
        hash_str = sha256(json_args, usedforsecurity=False).hexdigest()
        return hash_str[:self.hash_len]

    def path_and_hash(self, *args, **kwargs) -> Tuple[Path, str]:
        path_fn, arg_filter = with_filtered_args(self.path_fn)

        hash = self.hash_kwargs(
            # Do not have arguments in the hash that path_fn uses.
            **{k: a for k, a in kwargs.items() if k not in arg_filter}
        )

        path_args = self.args_to_kwargs(args, kwargs, hash=hash)

        path = Path(path_fn(**path_args))

        return path, hash

    def hash(self, *args, **kwargs) -> str:
        """Hash the keyword arguments.
        Note that the hash depends on the path function of this instance:
        all arguments of the path function are excluded from the hash."""
        return self.path_and_hash(*args, **kwargs)[1]

    def path(self, *args, **kwargs) -> Path:
        return self.path_and_hash(*args, **kwargs)[0]

    def dir(self, *args, **kwargs) -> Path:
        return self.path(*args, **kwargs).parent

    # Functions to do with inferring an object's arguments by keeping track of
    # its hash.

    def hash_obj(self, kwargs):
        """After saving or loading the object, we get its hash from the
        storage file.
        This hash we can later use to infer the arguments that created this
        object.
        """
        path = Path(kwargs["path"])
        if not path.exists():
            return
        with open(path, "rb") as file:
            if self.save_args:
                file.readline()
            hash = sha256(file.read()).hexdigest()
        self.obj_to_args[hash] = kwargs

    def args_from_obj(self, obj, *args, **kwargs):
        """Hash the object using the user-provided self.save.
        Then retrieve its arguments by looking up the hashes of previously
        loaded or saved objects.
        Also see self.hash_obj.
        """
        file = HashingWriter()
        self.save(obj, **self.args_to_kwargs(args, kwargs), file=file)
        hash = file.get_hash()
        return self.obj_to_args[hash]

    def path_from_obj(self, obj, *args, **kwargs):
        if isinstance(obj, LoadOrCreateCFG):
            return self.path(*obj.args, **obj.kwargs)
        return self.args_from_obj(obj, *args, **kwargs)["path"]

    def dir_from_obj(self, obj, *args, **kwargs):
        return self.path_from_obj(obj, *args, **kwargs).parent

    def cfg(self, *args, **kwargs):
        """If you don't care about some object but only about its path,
        yet you still need to pass an object to some other function in order
        to get its path, you can pass a LoadOrCreateCFG instead, saving you
        from loading or creating that object.
        """
        return LoadOrCreateCFG(*args, **kwargs)

@@ -14,5 +14,10 @@ class ToDoWarning(Warning):
     def __str__(self):
         return repr(self.message)

+PAST_TODO_MESSAGES = set()
+
 def todo(msg):
-    warnings.warn(msg, ToDoWarning)
+    global PAST_TODO_MESSAGES
+    if msg not in PAST_TODO_MESSAGES:
+        warnings.warn(msg, ToDoWarning)
+        PAST_TODO_MESSAGES.add(msg)