import error
This commit is contained in:
parent
46669bed3d
commit
22e93da91f
|
@ -1,7 +1,7 @@
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "jo3util"
|
name = "jo3util"
|
||||||
version = "0.0.7"
|
version = "0.0.10"
|
||||||
description = ""
|
description = ""
|
||||||
dependencies = []
|
dependencies = []
|
||||||
dynamic = ["readme"]
|
dynamic = ["readme"]
|
||||||
|
|
|
@ -50,7 +50,7 @@ def create_hook(
|
||||||
|
|
||||||
def sow(where: Callable, model: eqx.Module) -> eqx.Module:
|
def sow(where: Callable, model: eqx.Module) -> eqx.Module:
|
||||||
"""Capture intermediate activations that the argument modules output
|
"""Capture intermediate activations that the argument modules output
|
||||||
and return them together after model output"""
|
and return them together with the model output"""
|
||||||
activ = []
|
activ = []
|
||||||
|
|
||||||
def install_sow(node: Callable):
|
def install_sow(node: Callable):
|
||||||
|
@ -91,7 +91,7 @@ def sow(where: Callable, model: eqx.Module) -> eqx.Module:
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
activ.clear() # empty the list
|
activ.clear() # empty the list
|
||||||
x = model_call(self, *args, **kwargs)
|
x = model_call(self, *args, **kwargs)
|
||||||
return x, activ
|
return activ + [x]
|
||||||
|
|
||||||
return Sow(model)
|
return Sow(model)
|
||||||
|
|
||||||
|
|
25
src/root.py
25
src/root.py
|
@ -2,15 +2,18 @@
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import sys
|
import json
|
||||||
from importlib.resources import files
|
from importlib.resources import files, as_file
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import __main__
|
import __main__
|
||||||
|
from .string import hash_string
|
||||||
|
from .warning import todo
|
||||||
|
|
||||||
|
|
||||||
def walk_stack() -> Optional[Path]:
|
def walk_stack() -> Optional[Path]:
|
||||||
|
todo("check if this is being called from the python shell")
|
||||||
stack = inspect.stack()
|
stack = inspect.stack()
|
||||||
i = 0
|
i = 0
|
||||||
file_name = "<>"
|
file_name = "<>"
|
||||||
|
@ -23,12 +26,13 @@ def walk_stack() -> Optional[Path]:
|
||||||
return Path(file_name)
|
return Path(file_name)
|
||||||
|
|
||||||
|
|
||||||
def root_dir():
|
def root_dir() -> Path:
|
||||||
if "FILE_PATH" in os.environ:
|
if "FILE_PATH" in os.environ:
|
||||||
return Path(os.environ["FILE_PATH"]).parent()
|
return Path(os.environ["FILE_PATH"]).parent
|
||||||
|
|
||||||
if __main__.__package__:
|
if __main__.__package__:
|
||||||
return files(__main__.__package__)
|
with as_file(files(__main__.__package__)) as path:
|
||||||
|
return path
|
||||||
|
|
||||||
if "__file__" in dir(__main__):
|
if "__file__" in dir(__main__):
|
||||||
return Path(__main__.__file__).parent
|
return Path(__main__.__file__).parent
|
||||||
|
@ -75,3 +79,14 @@ def root_file():
|
||||||
return out
|
return out
|
||||||
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def run_dir(obj, subdir=Path("run")):
|
||||||
|
run_id: str = json.dumps(obj, default=lambda x: vars(x))
|
||||||
|
run_id: str = hash_string(run_id)
|
||||||
|
|
||||||
|
path = root_dir()
|
||||||
|
if subdir:
|
||||||
|
path /= subdir
|
||||||
|
path /= run_id
|
||||||
|
return path
|
||||||
|
|
Loading…
Reference in New Issue
Block a user