import error
This commit is contained in:
parent
46669bed3d
commit
22e93da91f
|
@ -1,7 +1,7 @@
|
|||
|
||||
[project]
|
||||
name = "jo3util"
|
||||
version = "0.0.7"
|
||||
version = "0.0.10"
|
||||
description = ""
|
||||
dependencies = []
|
||||
dynamic = ["readme"]
|
||||
|
|
|
@ -50,7 +50,7 @@ def create_hook(
|
|||
|
||||
def sow(where: Callable, model: eqx.Module) -> eqx.Module:
|
||||
"""Capture intermediate activations that the argument modules output
|
||||
and return them together after model output"""
|
||||
and return them together with the model output"""
|
||||
activ = []
|
||||
|
||||
def install_sow(node: Callable):
|
||||
|
@ -91,7 +91,7 @@ def sow(where: Callable, model: eqx.Module) -> eqx.Module:
|
|||
def __call__(self, *args, **kwargs):
|
||||
activ.clear() # empty the list
|
||||
x = model_call(self, *args, **kwargs)
|
||||
return x, activ
|
||||
return activ + [x]
|
||||
|
||||
return Sow(model)
|
||||
|
||||
|
|
25
src/root.py
25
src/root.py
|
@ -2,15 +2,18 @@
|
|||
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from importlib.resources import files
|
||||
import json
|
||||
from importlib.resources import files, as_file
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import __main__
|
||||
from .string import hash_string
|
||||
from .warning import todo
|
||||
|
||||
|
||||
def walk_stack() -> Optional[Path]:
|
||||
todo("check if this is being called from the python shell")
|
||||
stack = inspect.stack()
|
||||
i = 0
|
||||
file_name = "<>"
|
||||
|
@ -23,12 +26,13 @@ def walk_stack() -> Optional[Path]:
|
|||
return Path(file_name)
|
||||
|
||||
|
||||
def root_dir():
|
||||
def root_dir() -> Path:
|
||||
if "FILE_PATH" in os.environ:
|
||||
return Path(os.environ["FILE_PATH"]).parent()
|
||||
return Path(os.environ["FILE_PATH"]).parent
|
||||
|
||||
if __main__.__package__:
|
||||
return files(__main__.__package__)
|
||||
with as_file(files(__main__.__package__)) as path:
|
||||
return path
|
||||
|
||||
if "__file__" in dir(__main__):
|
||||
return Path(__main__.__file__).parent
|
||||
|
@ -75,3 +79,14 @@ def root_file():
|
|||
return out
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def run_dir(obj, subdir=Path("run")):
|
||||
run_id: str = json.dumps(obj, default=lambda x: vars(x))
|
||||
run_id: str = hash_string(run_id)
|
||||
|
||||
path = root_dir()
|
||||
if subdir:
|
||||
path /= subdir
|
||||
path /= run_id
|
||||
return path
|
||||
|
|
Loading…
Reference in New Issue
Block a user