import error

This commit is contained in:
JJJHolscher 2023-12-13 12:03:04 +01:00
parent 46669bed3d
commit 22e93da91f
3 changed files with 23 additions and 8 deletions

View File

@ -1,7 +1,7 @@
[project]
name = "jo3util"
version = "0.0.7"
version = "0.0.10"
description = ""
dependencies = []
dynamic = ["readme"]

View File

@ -50,7 +50,7 @@ def create_hook(
def sow(where: Callable, model: eqx.Module) -> eqx.Module:
"""Capture intermediate activations that the argument modules output
and return them together after model output"""
and return them together with the model output"""
activ = []
def install_sow(node: Callable):
@ -91,7 +91,7 @@ def sow(where: Callable, model: eqx.Module) -> eqx.Module:
def __call__(self, *args, **kwargs):
activ.clear() # empty the list
x = model_call(self, *args, **kwargs)
return x, activ
return activ + [x]
return Sow(model)

View File

@ -2,15 +2,18 @@
import inspect
import os
import sys
from importlib.resources import files
import json
from importlib.resources import files, as_file
from pathlib import Path
from typing import Optional
import __main__
from .string import hash_string
from .warning import todo
def walk_stack() -> Optional[Path]:
todo("check if this is being called from the python shell")
stack = inspect.stack()
i = 0
file_name = "<>"
@ -23,12 +26,13 @@ def walk_stack() -> Optional[Path]:
return Path(file_name)
def root_dir():
def root_dir() -> Path:
if "FILE_PATH" in os.environ:
return Path(os.environ["FILE_PATH"]).parent()
return Path(os.environ["FILE_PATH"]).parent
if __main__.__package__:
return files(__main__.__package__)
with as_file(files(__main__.__package__)) as path:
return path
if "__file__" in dir(__main__):
return Path(__main__.__file__).parent
@ -75,3 +79,14 @@ def root_file():
return out
raise NotImplementedError
def run_dir(obj, subdir=Path("run")):
run_id: str = json.dumps(obj, default=lambda x: vars(x))
run_id: str = hash_string(run_id)
path = root_dir()
if subdir:
path /= subdir
path /= run_id
return path