possibility to add args to kwargs

This commit is contained in:
JJJHolscher 2024-08-01 14:37:52 +02:00
parent 58a965bc8e
commit fe8ae1fda7

View File

@ -5,45 +5,21 @@ from pathlib import Path
import pickle import pickle
def arg_filtering(large_fn, small_fn): def with_filtered_kwargs(fn):
large_sig, small_sig = inspect.signature(large_fn), inspect.signature(small_fn) """
large_keys = tuple(large_sig.parameters.keys()) Return a version of the input function that ignores instead of errors on
unknown keyword arguments.
# Create a mapping from small_fn's positional arguments to large_fn's positional arguments. """
indices = []
for i, (key, value) in enumerate(small_sig.parameters.items()):
if value.kind not in (value.POSITIONAL_OR_KEYWORD, value.POSITIONAL_ONLY):
continue
try:
print(large_keys.index(key))
i = large_keys.index(key)
# If an argument from small_fn is not present in large_fn, just note the index of that
# argument as it is in small_fn.
except ValueError:
print("not found")
assert i not in indices
indices.append(i)
# Enumerate all small_fn's keyword arguments. # Enumerate all small_fn's keyword arguments.
filter_kwargs = {param.name for param in small_sig.parameters.values() filter_kwargs = {param.name for param in inspect.signature(fn).parameters.values()
if param.kind == param.POSITIONAL_OR_KEYWORD if param.kind == param.POSITIONAL_OR_KEYWORD
or param.kind == param.KEYWORD_ONLY} or param.kind == param.KEYWORD_ONLY}
def inner(*raw_args, **raw_kwargs): def inner(*args, **kwargs):
# For each argument in small_fn, find the argument's value from large_fn by index.
args = []
for i in indices:
if i >= len(raw_args):
break
args.append(raw_args[i])
# Filter any keyword arguments not present in small_fn. # Filter any keyword arguments not present in small_fn.
kwargs = {k: v for k, v in raw_kwargs.items() if k in filter_kwargs} kwargs = {k: v for k, v in kwargs.items() if k in filter_kwargs}
return small_fn(*args, **kwargs) return fn(*args, **kwargs)
return inner return inner
@ -60,39 +36,33 @@ class load_or_create:
save_args=False, save_args=False,
plain_text=False plain_text=False
): ):
self.load=load self.load=with_filtered_kwargs(load)
self.save=save self.save=with_filtered_kwargs(save)
self.prefix=prefix self.prefix=prefix
self.suffix=suffix self.suffix=suffix
self.hash_len=hash_len self.hash_len=hash_len
self.save_args=save_args self.save_args=save_args
self.plain_text=plain_text self.plain_text=plain_text
self.not_called=True
def __call__(self, fn): def __call__(self, fn):
assert self.not_called # fn_args = tuple(inspect.signature(fn).parameters.keys())
self.not_called = False
fn_args = tuple(inspect.signature(fn).parameters.keys())
self.load = arg_filtering(fn, self.load)
self.save = arg_filtering(fn, self.save)
def inner(*args, **kwargs): def inner(*args, **kwargs):
# Store all arguments into json and hash it to get the storage path. # Store the keyword arguments into json and hash it to get the storage path.
json_args = ( json_args = (
json.dumps( json.dumps(
(args, kwargs), kwargs,
default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x) default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x)
) + "\n" ) + "\n"
).encode("utf-8") ).encode("utf-8")
hash = sha256(json_args, usedforsecurity=False).hexdigest()[:self.hash_len] hash = sha256(json_args, usedforsecurity=False).hexdigest()[:self.hash_len]
path = Path(self.prefix + hash + self.suffix) path = Path(self.prefix + hash + self.suffix)
for i, arg in enumerate(args): # Convert all arguments to keyword arguments.
kwargs[fn_args[i]] = arg # for arg_name, arg in zip(fn_args, args):
# kwargs[arg_name] = arg
# If the path exists, load the object from there. # If the storage path exists, load the cached object from there.
if path.exists(): if path.exists():
with open(path, "r" if self.plain_text else "rb") as f: with open(path, "r" if self.plain_text else "rb") as f:
if self.save_args: if self.save_args:
@ -100,7 +70,7 @@ class load_or_create:
return self.load(f, **kwargs) return self.load(f, **kwargs)
# Else, run the function and store its result. # Else, run the function and store its result.
out = fn(**kwargs) out = fn(*args, **kwargs)
if "/" in self.suffix: if "/" in self.suffix:
path.parent.mkdir(parents=True) path.parent.mkdir(parents=True)