mistaken attempt at arg filtering
This commit is contained in:
parent
85cc1373cb
commit
58a965bc8e
|
@ -77,12 +77,12 @@ def sow(where: Callable, model: eqx.Module) -> eqx.Module:
|
|||
activ.append(x)
|
||||
return x
|
||||
|
||||
return eqx.filter_jit(store_activation)
|
||||
return store_activation
|
||||
|
||||
model = eqx.tree_at(where, model, replace_fn=install_sow)
|
||||
|
||||
model_call = type(model).__call__
|
||||
todo("make Sow a generic class")
|
||||
todo("make Sow a generic class, also don't have nested Sows but check whether model is already a sow.")
|
||||
|
||||
class Sow(type(model)):
|
||||
def __init__(self, model):
|
||||
|
@ -91,6 +91,9 @@ def sow(where: Callable, model: eqx.Module) -> eqx.Module:
|
|||
def __call__(self, *args, **kwargs):
|
||||
activ.clear() # empty the list
|
||||
x = model_call(self, *args, **kwargs)
|
||||
if isinstance(x, list):
|
||||
return activ + x
|
||||
else:
|
||||
return activ + [x]
|
||||
|
||||
return Sow(model)
|
||||
|
|
115
src/store.py
Normal file
115
src/store.py
Normal file
|
@ -0,0 +1,115 @@
|
|||
from hashlib import sha256
|
||||
import inspect
|
||||
import json
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
|
||||
|
||||
def arg_filtering(large_fn, small_fn):
|
||||
large_sig, small_sig = inspect.signature(large_fn), inspect.signature(small_fn)
|
||||
large_keys = tuple(large_sig.parameters.keys())
|
||||
|
||||
# Create a mapping from small_fn's positional arguments to large_fn's positional arguments.
|
||||
indices = []
|
||||
for i, (key, value) in enumerate(small_sig.parameters.items()):
|
||||
if value.kind not in (value.POSITIONAL_OR_KEYWORD, value.POSITIONAL_ONLY):
|
||||
|
||||
continue
|
||||
|
||||
try:
|
||||
print(large_keys.index(key))
|
||||
i = large_keys.index(key)
|
||||
# If an argument from small_fn is not present in large_fn, just note the index of that
|
||||
# argument as it is in small_fn.
|
||||
except ValueError:
|
||||
print("not found")
|
||||
|
||||
assert i not in indices
|
||||
indices.append(i)
|
||||
|
||||
# Enumerate all small_fn's keyword arguments.
|
||||
filter_kwargs = {param.name for param in small_sig.parameters.values()
|
||||
if param.kind == param.POSITIONAL_OR_KEYWORD
|
||||
or param.kind == param.KEYWORD_ONLY}
|
||||
|
||||
def inner(*raw_args, **raw_kwargs):
|
||||
# For each argument in small_fn, find the argument's value from large_fn by index.
|
||||
args = []
|
||||
for i in indices:
|
||||
if i >= len(raw_args):
|
||||
break
|
||||
args.append(raw_args[i])
|
||||
|
||||
# Filter any keyword arguments not present in small_fn.
|
||||
kwargs = {k: v for k, v in raw_kwargs.items() if k in filter_kwargs}
|
||||
|
||||
return small_fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
class load_or_create:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
load=pickle.load,
|
||||
save=pickle.dump,
|
||||
prefix="./",
|
||||
suffix=".pkl",
|
||||
hash_len=8,
|
||||
save_args=False,
|
||||
plain_text=False
|
||||
):
|
||||
self.load=load
|
||||
self.save=save
|
||||
self.prefix=prefix
|
||||
self.suffix=suffix
|
||||
self.hash_len=hash_len
|
||||
self.save_args=save_args
|
||||
self.plain_text=plain_text
|
||||
self.not_called=True
|
||||
|
||||
def __call__(self, fn):
|
||||
assert self.not_called
|
||||
self.not_called = False
|
||||
|
||||
fn_args = tuple(inspect.signature(fn).parameters.keys())
|
||||
|
||||
self.load = arg_filtering(fn, self.load)
|
||||
self.save = arg_filtering(fn, self.save)
|
||||
|
||||
def inner(*args, **kwargs):
|
||||
# Store all arguments into json and hash it to get the storage path.
|
||||
json_args = (
|
||||
json.dumps(
|
||||
(args, kwargs),
|
||||
default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x)
|
||||
) + "\n"
|
||||
).encode("utf-8")
|
||||
hash = sha256(json_args, usedforsecurity=False).hexdigest()[:self.hash_len]
|
||||
path = Path(self.prefix + hash + self.suffix)
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
kwargs[fn_args[i]] = arg
|
||||
|
||||
# If the path exists, load the object from there.
|
||||
if path.exists():
|
||||
with open(path, "r" if self.plain_text else "rb") as f:
|
||||
if self.save_args:
|
||||
f.readline()
|
||||
return self.load(f, **kwargs)
|
||||
|
||||
# Else, run the function and store its result.
|
||||
out = fn(**kwargs)
|
||||
|
||||
if "/" in self.suffix:
|
||||
path.parent.mkdir(parents=True)
|
||||
with open(path, "w" if self.plain_text else "wb") as f:
|
||||
if self.save_args:
|
||||
f.write(json_args)
|
||||
self.save(f, out, **kwargs)
|
||||
|
||||
return out
|
||||
|
||||
return inner
|
||||
|
Loading…
Reference in New Issue
Block a user