# NOTE(review): the original text here was a whitespace-collapsed git diff. The part that
# adds src/store.py is reconstructed below as plain Python. The src/eqx.py hunk only touches
# the interior of ``sow`` (whose definition is not fully visible in the diff), so it is not
# reconstructed; its changed lines were:
#   - return eqx.filter_jit(store_activation)
#   + return store_activation
#   - todo("make Sow a generic class")
#   + todo("make Sow a generic class, also don't have nested Sows but check whether model is already a sow.")
#   - return activ + [x]
#   + if isinstance(x, list):
#   +     return activ + x
#   + else:
#   +     return activ + [x]

from functools import wraps
from hashlib import sha256
import inspect
import json
from pathlib import Path
import pickle


def arg_filtering(large_fn, small_fn):
    """Wrap ``small_fn`` so it can be called with the arguments of ``large_fn``.

    The returned wrapper accepts arbitrary positional and keyword arguments and forwards
    only those that ``small_fn`` declares:

    * Each positional parameter of ``small_fn`` (positional-only or positional-or-keyword)
      is mapped to an index into the wrapper's positional arguments. If ``large_fn`` has a
      parameter of the same name, that parameter's position in ``large_fn`` is used;
      otherwise the parameter's own position in ``small_fn`` is used.
    * Keyword arguments are forwarded only if ``small_fn`` has a positional-or-keyword or
      keyword-only parameter of that name.

    Args:
        large_fn: Callable whose signature describes how the wrapper will be called.
        small_fn: Callable that actually gets invoked with the filtered arguments.

    Returns:
        A function ``inner(*raw_args, **raw_kwargs)`` returning ``small_fn``'s result.

    Raises:
        AssertionError: If two positional parameters of ``small_fn`` map to the same index.
    """
    large_positions = {
        name: pos for pos, name in enumerate(inspect.signature(large_fn).parameters)
    }
    small_params = inspect.signature(small_fn).parameters

    # (parameter name, index into the wrapper's positional arguments), in small_fn order.
    positional = []
    used_indices = set()
    for own_pos, (name, param) in enumerate(small_params.items()):
        if param.kind not in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
            continue

        # A parameter unknown to large_fn keeps the index it has in small_fn.
        index = large_positions.get(name, own_pos)

        # Raised explicitly (instead of ``assert``) so the check survives ``python -O``;
        # the exception type is kept for backward compatibility.
        if index in used_indices:
            raise AssertionError(
                f"parameter {name!r} of small_fn maps to positional index {index}, "
                "which is already used by another parameter"
            )
        used_indices.add(index)
        positional.append((name, index))

    # Names small_fn accepts by keyword.
    filter_kwargs = {
        name
        for name, param in small_params.items()
        if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
    }

    def inner(*raw_args, **raw_kwargs):
        args = []
        bound = set()
        for name, index in positional:
            # Stop at the first missing positional: skipping it and continuing would shift
            # every later value onto the wrong parameter.
            if index >= len(raw_args):
                break
            args.append(raw_args[index])
            bound.add(name)

        # Drop keywords small_fn does not accept, and keywords whose parameter was already
        # bound positionally (forwarding both raises "got multiple values for argument").
        kwargs = {
            key: value
            for key, value in raw_kwargs.items()
            if key in filter_kwargs and key not in bound
        }

        return small_fn(*args, **kwargs)

    return inner


def _no_parameters():
    """Stand-in ``large_fn`` with an empty signature.

    With no names to match, ``arg_filtering`` maps every positional parameter of
    ``small_fn`` to its own position.
    """


def _pickle_load(file):
    """Adapter for ``pickle.load`` that only accepts the file argument."""
    return pickle.load(file)


def _pickle_dump(file, obj):
    """Adapter giving ``pickle.dump`` the ``(file, obj)`` order used for ``save``."""
    pickle.dump(obj, file)


class load_or_create:
    """Decorator that caches a function's result on disk, keyed by its arguments.

    On each call the positional and keyword arguments are serialised to JSON and hashed;
    the cache file is ``prefix + hash[:hash_len] + suffix``. If that file exists, the result
    is read with ``load``; otherwise the function runs and its result is written with
    ``save``.

    Calling conventions (the file object always comes first):

    * ``load(file, **named)`` must return the cached object.
    * ``save(file, obj, **named)`` must write ``obj`` to ``file``.

    ``named`` holds those arguments of the decorated function whose names appear in the
    signature of ``load`` / ``save``; all other arguments are filtered out.
    ``pickle.load`` and ``pickle.dump`` (the defaults) are adapted to this convention
    automatically, since ``pickle.dump`` natively takes ``(obj, file)``.

    Args:
        load: Callable used to read a cached result.
        save: Callable used to write a new result.
        prefix: String prepended to the hash to form the path (may contain directories).
        suffix: String appended to the hash to form the path (may contain directories).
        hash_len: Number of hex digits of the SHA-256 digest used in the file name.
        save_args: If true, the JSON-encoded arguments are written as the first line of
            the cache file and skipped again before ``load`` is called.
        plain_text: If true, the cache file is opened in text mode instead of binary mode.

    Notes:
        * The hash is computed over ``(args, kwargs)`` exactly as passed, so ``f(1)`` and
          ``f(x=1)`` use different cache files.
        * Arguments that are not JSON-serialisable are hashed via ``vars(x)`` or ``str(x)``.
        * The decorated function is invoked with keyword arguments only.
        * The default loader unpickles the cache file; only use it on trusted cache
          directories, since unpickling can execute arbitrary code.
    """

    def __init__(
        self,
        load=pickle.load,
        save=pickle.dump,
        prefix="./",
        suffix=".pkl",
        hash_len=8,
        save_args=False,
        plain_text=False
    ):
        self.load = load
        self.save = save
        self.prefix = prefix
        self.suffix = suffix
        self.hash_len = hash_len
        self.save_args = save_args
        self.plain_text = plain_text
        # Flipped to False by the first __call__; an instance decorates one function only.
        self.not_called = True

    def __call__(self, fn):
        """Return a caching wrapper around ``fn``.

        Raises:
            AssertionError: If this instance has already been used as a decorator.
        """
        # Explicit raise instead of ``assert`` so the guard survives ``python -O``.
        # Re-wrapping the already filtered load/save would silently drop the file argument.
        if not self.not_called:
            raise AssertionError("a load_or_create instance can only decorate one function")
        self.not_called = False

        fn_args = tuple(inspect.signature(fn).parameters)

        # pickle's own functions do not follow the (file, obj) convention; adapt them.
        load = _pickle_load if self.load is pickle.load else self.load
        save = _pickle_dump if self.save is pickle.dump else self.save

        # load/save are called as (file[, obj], **fn_arguments_by_name): the leading values
        # are passed by position and fn's arguments only by name. An empty stand-in
        # signature is used so that fn's positional order is never confused with the
        # positions of ``file`` / ``obj``.
        self.load = arg_filtering(_no_parameters, load)
        self.save = arg_filtering(_no_parameters, save)

        @wraps(fn)
        def inner(*args, **kwargs):
            # Store all arguments into json and hash it to get the storage path.
            json_text = json.dumps(
                (args, kwargs),
                default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x)
            ) + "\n"
            json_args = json_text.encode("utf-8")
            digest = sha256(json_args, usedforsecurity=False).hexdigest()[:self.hash_len]
            path = Path(self.prefix + digest + self.suffix)

            # From here on every argument is addressed by name.
            for i, arg in enumerate(args):
                kwargs[fn_args[i]] = arg

            # If the path exists, load the object from there.
            if path.exists():
                with open(path, "r" if self.plain_text else "rb") as f:
                    if self.save_args:
                        f.readline()  # skip the stored-arguments header line
                    return self.load(f, **kwargs)

            # Else, run the function and store its result.
            out = fn(**kwargs)

            # prefix and suffix may both contain directories.
            path.parent.mkdir(parents=True, exist_ok=True)
            try:
                with open(path, "w" if self.plain_text else "wb") as f:
                    if self.save_args:
                        # Text-mode files need str, binary-mode files need bytes.
                        f.write(json_text if self.plain_text else json_args)
                    self.save(f, out, **kwargs)
            except BaseException:
                # Do not leave a truncated file behind: it would be treated as a valid
                # cache entry by the next call.
                path.unlink(missing_ok=True)
                raise

            return out

        return inner