no duplicate warnings, only check for path existence if load accepts file

This commit is contained in:
JJJHolscher 2024-08-05 15:52:11 +02:00
parent 35396c80ca
commit 746b2cf6b0
2 changed files with 57 additions and 19 deletions

View File

@ -35,7 +35,9 @@ class load_or_create:
plain_text=False plain_text=False
): ):
self.load=with_filtered_args(load)[0] self.load=with_filtered_args(load)[0]
self.load_arg_names = {p.name for p in inspect.signature(load).parameters.values()}
self.save=with_filtered_args(save)[0] self.save=with_filtered_args(save)[0]
self.save_arg_names = {p.name for p in inspect.signature(save).parameters.values()}
self.path_fn=path_fn self.path_fn=path_fn
self.hash_len=hash_len self.hash_len=hash_len
self.save_args=save_args self.save_args=save_args
@ -53,35 +55,66 @@ class inner(load_or_create):
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
# Store the keyword arguments into json and hash it to get the storage path. # Store the keyword arguments into json and hash it to get the storage path.
json_args = self.to_json(**kwargs)
path = self.path(*args, **kwargs) path = self.path(*args, **kwargs)
# path = self.path(hash)
# If the storage path exists, load the cached object from there. obj = self.load_wrapper(**self.args_to_kwargs(args, kwargs, path=path))
if path.exists(): if obj is not None: return obj
with open(path, "r" if self.plain_text else "rb") as file:
# If specified, the first line is a json of the keyword arguments.
if self.save_args:
file.readline()
return self.load(**self.args_to_kwargs(args, kwargs, path=path, file=file))
# Else, run the function. and store its result.
obj = self.fn(*args, **kwargs) obj = self.fn(*args, **kwargs)
if obj is None: return obj
if obj is None: self.save_wrapper(obj, **self.args_to_kwargs(args, kwargs, path=path))
return obj
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w" if self.plain_text else "wb") as file:
if self.save_args:
file.write(json_args)
self.save(obj, **self.args_to_kwargs(args, kwargs, path=path, file=file))
return obj return obj
def args_to_kwargs(self, args, kwargs, **extra): def args_to_kwargs(self, args, kwargs, **extra):
return extra | kwargs | {self.arg_names[i]: a for i, a in enumerate(args)} return extra | kwargs | {self.arg_names[i]: a for i, a in enumerate(args)}
def load_wrapper(self, **load_args):
"""
load_wrapper returns None or an abject.
Iff it returns None, the object is deemed not-loaded and thus,
self.fn needs to be called to create the object.
It can return None, because self.load returns None or if self.load
expects an open file, but no such file exists.
"""
# Check whether self.load does not expect an open file.
if "file" in load_args or "file" not in self.load_arg_names:
return self.load(**load_args)
path = load_args["path"]
# If self.load expects an open file but there is none, run self.fn.
if not path.exists():
return None
with open(path, "r" if self.plain_text else "rb") as file:
# If specified, the first line is a json of the keyword arguments.
if self.save_args:
file.readline()
return self.load(**{"file": file} | load_args)
def save_wrapper(self, obj, path, **save_args):
"""Only open a file if the save function requests you to.
"""
Path(path).parent.mkdir(parents=True, exist_ok=True)
if "file" not in self.save_arg_names:
return self.save(obj, **{"path": path} | save_args)
if "file" in save_args:
file = save_args["file"]
else:
file = open(path, "w" if self.plain_text else "wb")
if self.save_args:
file.write(self.to_json(**save_args))
self.save(obj, **{"file": file, "path": path} | save_args)
if "file" not in save_args:
file.close()
def to_json(self, **kwargs) -> bytes: def to_json(self, **kwargs) -> bytes:
"""Serialize all keyword arguments to json. """Serialize all keyword arguments to json.
We never serialize positional arguments""" We never serialize positional arguments"""

View File

@ -14,5 +14,10 @@ class ToDoWarning(Warning):
def __str__(self): def __str__(self):
return repr(self.message) return repr(self.message)
PAST_TODO_MESSAGES = set()
def todo(msg): def todo(msg):
warnings.warn(msg, ToDoWarning) global PAST_TODO_MESSAGES
if msg not in PAST_TODO_MESSAGES:
warnings.warn(msg, ToDoWarning)
PAST_TODO_MESSAGES.add(msg)