save function accepts more hyperparameter types now

This commit is contained in:
JJJHolscher 2023-12-12 17:33:28 +01:00
parent 238e1fb880
commit 46669bed3d
3 changed files with 13 additions and 12 deletions

View File

@ -1,7 +1,7 @@
[project]
name = "jo3util"
version = "0.0.6"
version = "0.0.7"
description = ""
dependencies = []
dynamic = ["readme"]

View File

@ -1,7 +1,8 @@
#! /usr/bin/env python3
import json
from typing import Callable, Optional, Tuple
from typing import Callable, Optional, Tuple, Union
from pathlib import Path
import equinox as eqx
import jax
@ -95,7 +96,11 @@ def sow(where: Callable, model: eqx.Module) -> eqx.Module:
return Sow(model)
def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Module:
def insert_after(
where: Callable,
model: eqx.Module,
func: Callable
) -> eqx.Module:
"""Place a callable immediately after the argument modules"""
class Ensemble(eqx.Module):
@ -116,9 +121,12 @@ def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Modu
return model
def save(path, pytree, hyperparameters={}):
def save(path: Union[Path, str], pytree, hyperparameters={}):
with open(path, "wb") as f:
hyperparameters = json.dumps(hyperparameters)
hyperparameters = json.dumps(
hyperparameters,
default=lambda h: vars(h)
)
f.write((hyperparameters + "\n").encode())
eqx.tree_serialise_leaves(f, pytree)

View File

@ -1,7 +0,0 @@
#! /usr/bin/env python3
# vim:fenc=utf-8
"""
"""