diff --git a/src/store.py b/src/store.py index c5b34bd..825e554 100644 --- a/src/store.py +++ b/src/store.py @@ -5,45 +5,21 @@ from pathlib import Path import pickle -def arg_filtering(large_fn, small_fn): - large_sig, small_sig = inspect.signature(large_fn), inspect.signature(small_fn) - large_keys = tuple(large_sig.parameters.keys()) - - # Create a mapping from small_fn's positional arguments to large_fn's positional arguments. - indices = [] - for i, (key, value) in enumerate(small_sig.parameters.items()): - if value.kind not in (value.POSITIONAL_OR_KEYWORD, value.POSITIONAL_ONLY): - - continue - - try: - print(large_keys.index(key)) - i = large_keys.index(key) - # If an argument from small_fn is not present in large_fn, just note the index of that - # argument as it is in small_fn. - except ValueError: - print("not found") - - assert i not in indices - indices.append(i) - +def with_filtered_kwargs(fn): + """ + Return a version of the input function that ignores instead of errors on + unknown keyword arguments. + """ # Enumerate all small_fn's keyword arguments. - filter_kwargs = {param.name for param in small_sig.parameters.values() + filter_kwargs = {param.name for param in inspect.signature(fn).parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.KEYWORD_ONLY} - def inner(*raw_args, **raw_kwargs): - # For each argument in small_fn, find the argument's value from large_fn by index. - args = [] - for i in indices: - if i >= len(raw_args): - break - args.append(raw_args[i]) - + def inner(*args, **kwargs): # Filter any keyword arguments not present in small_fn. - kwargs = {k: v for k, v in raw_kwargs.items() if k in filter_kwargs} + kwargs = {k: v for k, v in kwargs.items() if k in filter_kwargs} - return small_fn(*args, **kwargs) + return fn(*args, **kwargs) return inner @@ -60,39 +36,33 @@ class load_or_create: save_args=False, plain_text=False ): - self.load=load - self.save=save + self.load=with_filtered_kwargs(load) + self.save=with_filtered_kwargs(save) self.prefix=prefix self.suffix=suffix self.hash_len=hash_len self.save_args=save_args self.plain_text=plain_text - self.not_called=True def __call__(self, fn): - assert self.not_called - self.not_called = False - - fn_args = tuple(inspect.signature(fn).parameters.keys()) - - self.load = arg_filtering(fn, self.load) - self.save = arg_filtering(fn, self.save) + # fn_args = tuple(inspect.signature(fn).parameters.keys()) def inner(*args, **kwargs): - # Store all arguments into json and hash it to get the storage path. + # Store the keyword arguments into json and hash it to get the storage path. json_args = ( json.dumps( - (args, kwargs), + kwargs, default=lambda x: vars(x) if hasattr(x, "__dict__") else str(x) ) + "\n" ).encode("utf-8") hash = sha256(json_args, usedforsecurity=False).hexdigest()[:self.hash_len] path = Path(self.prefix + hash + self.suffix) - for i, arg in enumerate(args): - kwargs[fn_args[i]] = arg + # Convert all arguments to keyword arguments. + # for arg_name, arg in zip(fn_args, args): + # kwargs[arg_name] = arg - # If the path exists, load the object from there. + # If the storage path exists, load the cached object from there. if path.exists(): with open(path, "r" if self.plain_text else "rb") as f: if self.save_args: @@ -100,7 +70,7 @@ class load_or_create: return self.load(f, **kwargs) # Else, run the function and store its result. - out = fn(**kwargs) + out = fn(*args, **kwargs) if "/" in self.suffix: path.parent.mkdir(parents=True)