From d19ea6fdfea6532c7f2641425b82c65c36e702d6 Mon Sep 17 00:00:00 2001 From: JJJHolscher Date: Tue, 12 Dec 2023 14:55:03 +0100 Subject: [PATCH] equinox save --- pyproject.toml | 2 +- src/eqx.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index aa0ebb8..bb9d64e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "jo3util" -version = "0.0.3" +version = "0.0.6" description = "" dependencies = [] dynamic = ["readme"] diff --git a/src/eqx.py b/src/eqx.py index ab11538..11b4f09 100644 --- a/src/eqx.py +++ b/src/eqx.py @@ -1,5 +1,6 @@ #! /usr/bin/env python3 +import json from typing import Callable, Optional, Tuple import equinox as eqx @@ -113,3 +114,17 @@ def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Modu model = eqx.tree_at(where, model, replace_fn=Ensemble) return model + + +def save(path, pytree, hyperparameters={}): + with open(path, "wb") as f: + hyperparameters = json.dumps(hyperparameters) + f.write((hyperparameters + "\n").encode()) + eqx.tree_serialise_leaves(f, pytree) + + +def load(path, type_): + with open(path, "rb") as f: + hyperparams = json.loads(f.readline().decode()) + pytree = type_(**hyperparams) + return eqx.tree_deserialise_leaves(f, pytree)