diff --git a/pyproject.toml b/pyproject.toml index bb9d64e..1795671 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "jo3util" -version = "0.0.6" +version = "0.0.7" description = "" dependencies = [] dynamic = ["readme"] diff --git a/src/eqx.py b/src/eqx.py index 11b4f09..d13e74c 100644 --- a/src/eqx.py +++ b/src/eqx.py @@ -1,7 +1,8 @@ #! /usr/bin/env python3 import json -from typing import Callable, Optional, Tuple +from typing import Callable, Optional, Tuple, Union +from pathlib import Path import equinox as eqx import jax @@ -95,7 +96,11 @@ def sow(where: Callable, model: eqx.Module) -> eqx.Module: return Sow(model) -def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Module: +def insert_after( + where: Callable, + model: eqx.Module, + func: Callable +) -> eqx.Module: """Place a callable immediately after the argument modules""" class Ensemble(eqx.Module): @@ -116,9 +121,12 @@ def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Modu return model -def save(path, pytree, hyperparameters={}): +def save(path: Union[Path, str], pytree, hyperparameters={}): with open(path, "wb") as f: - hyperparameters = json.dumps(hyperparameters) + hyperparameters = json.dumps( + hyperparameters, + default=lambda h: vars(h) + ) f.write((hyperparameters + "\n").encode()) eqx.tree_serialise_leaves(f, pytree) diff --git a/src/main.py b/src/main.py deleted file mode 100644 index 9789cc1..0000000 --- a/src/main.py +++ /dev/null @@ -1,7 +0,0 @@ -#! /usr/bin/env python3 -# vim:fenc=utf-8 - -""" - -""" -