possibility to add args to kwargs
This commit is contained in:
parent
58a965bc8e
commit
fe8ae1fda7
68
src/store.py
68
src/store.py
|
@ -5,45 +5,21 @@ from pathlib import Path
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
def arg_filtering(large_fn, small_fn):
|
def with_filtered_kwargs(fn):
|
||||||
large_sig, small_sig = inspect.signature(large_fn), inspect.signature(small_fn)
|
"""
|
||||||
large_keys = tuple(large_sig.parameters.keys())
|
Return a version of the input function that ignores instead of errors on
|
||||||
|
unknown keyword arguments.
|
||||||
# Create a mapping from small_fn's positional arguments to large_fn's positional arguments.
|
"""
|
||||||
indices = []
|
|
||||||
for i, (key, value) in enumerate(small_sig.parameters.items()):
|
|
||||||
if value.kind not in (value.POSITIONAL_OR_KEYWORD, value.POSITIONAL_ONLY):
|
|
||||||
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
print(large_keys.index(key))
|
|
||||||
i = large_keys.index(key)
|
|
||||||
# If an argument from small_fn is not present in large_fn, just note the index of that
|
|
||||||
# argument as it is in small_fn.
|
|
||||||
except ValueError:
|
|
||||||
print("not found")
|
|
||||||
|
|
||||||
assert i not in indices
|
|
||||||
indices.append(i)
|
|
||||||
|
|
||||||
# Enumerate all small_fn's keyword arguments.
|
# Enumerate all small_fn's keyword arguments.
|
||||||
filter_kwargs = {param.name for param in small_sig.parameters.values()
|
filter_kwargs = {param.name for param in inspect.signature(fn).parameters.values()
|
||||||
if param.kind == param.POSITIONAL_OR_KEYWORD
|
if param.kind == param.POSITIONAL_OR_KEYWORD
|
||||||
or param.kind == param.KEYWORD_ONLY}
|
or param.kind == param.KEYWORD_ONLY}
|
||||||
|
|
||||||
def inner(*raw_args, **raw_kwargs):
|
def inner(*args, **kwargs):
|
||||||
# For each argument in small_fn, find the argument's value from large_fn by index.
|
|
||||||
args = []
|
|
||||||
for i in indices:
|
|
||||||
if i >= len(raw_args):
|
|
||||||
break
|
|
||||||
args.append(raw_args[i])
|
|
||||||
|
|
||||||
# Filter any keyword arguments not present in small_fn.
|
# Filter any keyword arguments not present in small_fn.
|
||||||
kwargs = {k: v for k, v in raw_kwargs.items() if k in filter_kwargs}
|
kwargs = {k: v for k, v in kwargs.items() if k in filter_kwargs}
|
||||||
|
|
||||||
return small_fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
return inner
|
return inner
|
||||||
|
|
||||||
|
@ -60,39 +36,33 @@ class load_or_create:
|
||||||
save_args=False,
|
save_args=False,
|
||||||
plain_text=False
|
plain_text=False
|
||||||
):
|
):
|
||||||
self.load=load
|
self.load=with_filtered_kwargs(load)
|
||||||
self.save=save
|
self.save=with_filtered_kwargs(save)
|
||||||
self.prefix=prefix
|
self.prefix=prefix
|
||||||
self.suffix=suffix
|
self.suffix=suffix
|
||||||
self.hash_len=hash_len
|
self.hash_len=hash_len
|
||||||
self.save_args=save_args
|
self.save_args=save_args
|
||||||
self.plain_text=plain_text
|
self.plain_text=plain_text
|
||||||
self.not_called=True
|
|
||||||
|
|
||||||
def __call__(self, fn):
|
def __call__(self, fn):
|
||||||
assert self.not_called
|
# fn_args = tuple(inspect.signature(fn).parameters.keys())
|
||||||
self.not_called = False
|
|
||||||
|
|
||||||
fn_args = tuple(inspect.signature(fn).parameters.keys())
|
|
||||||
|
|
||||||
self.load = arg_filtering(fn, self.load)
|
|
||||||
self.save = arg_filtering(fn, self.save)
|
|
||||||
|
|
||||||
def inner(*args, **kwargs):
|
def inner(*args, **kwargs):
|
||||||
# Store all arguments into json and hash it to get the storage path.
|
# Store the keyword arguments into json and hash it to get the storage path.
|
||||||
json_args = (
|
json_args = (
|
||||||
json.dumps(
|
json.dumps(
|
||||||
(args, kwargs),
|
kwargs,
|
||||||
default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x)
|
default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x)
|
||||||
) + "\n"
|
) + "\n"
|
||||||
).encode("utf-8")
|
).encode("utf-8")
|
||||||
hash = sha256(json_args, usedforsecurity=False).hexdigest()[:self.hash_len]
|
hash = sha256(json_args, usedforsecurity=False).hexdigest()[:self.hash_len]
|
||||||
path = Path(self.prefix + hash + self.suffix)
|
path = Path(self.prefix + hash + self.suffix)
|
||||||
|
|
||||||
for i, arg in enumerate(args):
|
# Convert all arguments to keyword arguments.
|
||||||
kwargs[fn_args[i]] = arg
|
# for arg_name, arg in zip(fn_args, args):
|
||||||
|
# kwargs[arg_name] = arg
|
||||||
|
|
||||||
# If the path exists, load the object from there.
|
# If the storage path exists, load the cached object from there.
|
||||||
if path.exists():
|
if path.exists():
|
||||||
with open(path, "r" if self.plain_text else "rb") as f:
|
with open(path, "r" if self.plain_text else "rb") as f:
|
||||||
if self.save_args:
|
if self.save_args:
|
||||||
|
@ -100,7 +70,7 @@ class load_or_create:
|
||||||
return self.load(f, **kwargs)
|
return self.load(f, **kwargs)
|
||||||
|
|
||||||
# Else, run the function and store its result.
|
# Else, run the function and store its result.
|
||||||
out = fn(**kwargs)
|
out = fn(*args, **kwargs)
|
||||||
|
|
||||||
if "/" in self.suffix:
|
if "/" in self.suffix:
|
||||||
path.parent.mkdir(parents=True)
|
path.parent.mkdir(parents=True)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user