import error

This commit is contained in:
JJJHolscher 2023-12-13 12:03:04 +01:00
parent 46669bed3d
commit 22e93da91f
3 changed files with 23 additions and 8 deletions

View File

@ -1,7 +1,7 @@
[project] [project]
name = "jo3util" name = "jo3util"
version = "0.0.7" version = "0.0.10"
description = "" description = ""
dependencies = [] dependencies = []
dynamic = ["readme"] dynamic = ["readme"]

View File

@ -50,7 +50,7 @@ def create_hook(
def sow(where: Callable, model: eqx.Module) -> eqx.Module: def sow(where: Callable, model: eqx.Module) -> eqx.Module:
"""Capture intermediate activations that the argument modules output """Capture intermediate activations that the argument modules output
and return them together after model output""" and return them together with the model output"""
activ = [] activ = []
def install_sow(node: Callable): def install_sow(node: Callable):
@ -91,7 +91,7 @@ def sow(where: Callable, model: eqx.Module) -> eqx.Module:
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
activ.clear() # empty the list activ.clear() # empty the list
x = model_call(self, *args, **kwargs) x = model_call(self, *args, **kwargs)
return x, activ return activ + [x]
return Sow(model) return Sow(model)

View File

@ -2,15 +2,18 @@
import inspect import inspect
import os import os
import sys import json
from importlib.resources import files from importlib.resources import files, as_file
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import __main__ import __main__
from .string import hash_string
from .warning import todo
def walk_stack() -> Optional[Path]: def walk_stack() -> Optional[Path]:
todo("check if this is being called from the python shell")
stack = inspect.stack() stack = inspect.stack()
i = 0 i = 0
file_name = "<>" file_name = "<>"
@ -23,12 +26,13 @@ def walk_stack() -> Optional[Path]:
return Path(file_name) return Path(file_name)
def root_dir(): def root_dir() -> Path:
if "FILE_PATH" in os.environ: if "FILE_PATH" in os.environ:
return Path(os.environ["FILE_PATH"]).parent() return Path(os.environ["FILE_PATH"]).parent
if __main__.__package__: if __main__.__package__:
return files(__main__.__package__) with as_file(files(__main__.__package__)) as path:
return path
if "__file__" in dir(__main__): if "__file__" in dir(__main__):
return Path(__main__.__file__).parent return Path(__main__.__file__).parent
@ -75,3 +79,14 @@ def root_file():
return out return out
raise NotImplementedError raise NotImplementedError
def run_dir(obj, subdir=Path("run")):
run_id: str = json.dumps(obj, default=lambda x: vars(x))
run_id: str = hash_string(run_id)
path = root_dir()
if subdir:
path /= subdir
path /= run_id
return path