From 22e93da91f13f149914972ebd91fe730b0b9b74b Mon Sep 17 00:00:00 2001 From: JJJHolscher Date: Wed, 13 Dec 2023 12:03:04 +0100 Subject: [PATCH] import error --- pyproject.toml | 2 +- src/eqx.py | 4 ++-- src/root.py | 25 ++++++++++++++++++++----- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1795671..d3fb10c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "jo3util" -version = "0.0.7" +version = "0.0.10" description = "" dependencies = [] dynamic = ["readme"] diff --git a/src/eqx.py b/src/eqx.py index d13e74c..e05ea5b 100644 --- a/src/eqx.py +++ b/src/eqx.py @@ -50,7 +50,7 @@ def create_hook( def sow(where: Callable, model: eqx.Module) -> eqx.Module: """Capture intermediate activations that the argument modules output - and return them together after model output""" + and return them together with the model output""" activ = [] def install_sow(node: Callable): @@ -91,7 +91,7 @@ def sow(where: Callable, model: eqx.Module) -> eqx.Module: def __call__(self, *args, **kwargs): activ.clear() # empty the list x = model_call(self, *args, **kwargs) - return x, activ + return activ + [x] return Sow(model) diff --git a/src/root.py b/src/root.py index 298a274..c82fa6a 100644 --- a/src/root.py +++ b/src/root.py @@ -2,15 +2,18 @@ import inspect import os -import sys -from importlib.resources import files +import json +from importlib.resources import files, as_file from pathlib import Path from typing import Optional import __main__ +from .string import hash_string +from .warning import todo def walk_stack() -> Optional[Path]: + todo("check if this is being called from the python shell") stack = inspect.stack() i = 0 file_name = "<>" @@ -23,12 +26,13 @@ def walk_stack() -> Optional[Path]: return Path(file_name) -def root_dir(): +def root_dir() -> Path: if "FILE_PATH" in os.environ: - return Path(os.environ["FILE_PATH"]).parent() + return Path(os.environ["FILE_PATH"]).parent if __main__.__package__: - return files(__main__.__package__) + with as_file(files(__main__.__package__)) as path: + return path if "__file__" in dir(__main__): return Path(__main__.__file__).parent @@ -75,3 +79,14 @@ def root_file(): return out raise NotImplementedError + + +def run_dir(obj, subdir=Path("run")): + run_id: str = json.dumps(obj, default=lambda x: vars(x)) + run_id: str = hash_string(run_id) + + path = root_dir() + if subdir: + path /= subdir + path /= run_id + return path