equinox save
This commit is contained in:
parent
55d4db51a0
commit
d19ea6fdfe
|
@ -1,7 +1,7 @@
|
|||
|
||||
[project]
|
||||
name = "jo3util"
|
||||
version = "0.0.3"
|
||||
version = "0.0.6"
|
||||
description = ""
|
||||
dependencies = []
|
||||
dynamic = ["readme"]
|
||||
|
|
15
src/eqx.py
15
src/eqx.py
|
@ -1,5 +1,6 @@
|
|||
#! /usr/bin/env python3
|
||||
|
||||
import json
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import equinox as eqx
|
||||
|
@ -113,3 +114,17 @@ def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Modu
|
|||
|
||||
model = eqx.tree_at(where, model, replace_fn=Ensemble)
|
||||
return model
|
||||
|
||||
|
||||
def save(path, pytree, hyperparameters={}):
|
||||
with open(path, "wb") as f:
|
||||
hyperparameters = json.dumps(hyperparameters)
|
||||
f.write((hyperparameters + "\n").encode())
|
||||
eqx.tree_serialise_leaves(f, pytree)
|
||||
|
||||
|
||||
def load(path, type_):
|
||||
with open(path, "rb") as f:
|
||||
hyperparams = json.loads(f.readline().decode())
|
||||
pytree = type_(**hyperparams)
|
||||
return eqx.tree_deserialise_leaves(f, pytree)
|
||||
|
|
Loading…
Reference in New Issue
Block a user