equinox save

This commit is contained in:
JJJHolscher 2023-12-12 14:55:03 +01:00
parent 55d4db51a0
commit d19ea6fdfe
2 changed files with 16 additions and 1 deletions

View File

@ -1,7 +1,7 @@
[project]
name = "jo3util"
version = "0.0.3"
version = "0.0.6"
description = ""
dependencies = []
dynamic = ["readme"]

View File

@ -1,5 +1,6 @@
#! /usr/bin/env python3
import json
from typing import Callable, Optional, Tuple
import equinox as eqx
@ -113,3 +114,17 @@ def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Modu
model = eqx.tree_at(where, model, replace_fn=Ensemble)
return model
def save(path, pytree, hyperparameters={}):
with open(path, "wb") as f:
hyperparameters = json.dumps(hyperparameters)
f.write((hyperparameters + "\n").encode())
eqx.tree_serialise_leaves(f, pytree)
def load(path, type_):
with open(path, "rb") as f:
hyperparams = json.loads(f.readline().decode())
pytree = type_(**hyperparams)
return eqx.tree_deserialise_leaves(f, pytree)