save function accepts more hyperparameter types now

This commit is contained in:
JJJHolscher 2023-12-12 17:33:28 +01:00
parent 238e1fb880
commit 46669bed3d
3 changed files with 13 additions and 12 deletions

View File

@ -1,7 +1,7 @@
[project] [project]
name = "jo3util" name = "jo3util"
version = "0.0.6" version = "0.0.7"
description = "" description = ""
dependencies = [] dependencies = []
dynamic = ["readme"] dynamic = ["readme"]

View File

@ -1,7 +1,8 @@
#! /usr/bin/env python3 #! /usr/bin/env python3
import json import json
from typing import Callable, Optional, Tuple from typing import Callable, Optional, Tuple, Union
from pathlib import Path
import equinox as eqx import equinox as eqx
import jax import jax
@ -95,7 +96,11 @@ def sow(where: Callable, model: eqx.Module) -> eqx.Module:
return Sow(model) return Sow(model)
def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Module: def insert_after(
where: Callable,
model: eqx.Module,
func: Callable
) -> eqx.Module:
"""Place a callable immediately after the argument modules""" """Place a callable immediately after the argument modules"""
class Ensemble(eqx.Module): class Ensemble(eqx.Module):
@ -116,9 +121,12 @@ def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Modu
return model return model
def save(path, pytree, hyperparameters={}): def save(path: Union[Path, str], pytree, hyperparameters={}):
with open(path, "wb") as f: with open(path, "wb") as f:
hyperparameters = json.dumps(hyperparameters) hyperparameters = json.dumps(
hyperparameters,
default=lambda h: vars(h)
)
f.write((hyperparameters + "\n").encode()) f.write((hyperparameters + "\n").encode())
eqx.tree_serialise_leaves(f, pytree) eqx.tree_serialise_leaves(f, pytree)

View File

@ -1,7 +0,0 @@
#! /usr/bin/env python3
# vim:fenc=utf-8
"""
"""