equinox save
This commit is contained in:
parent
55d4db51a0
commit
d19ea6fdfe
|
@ -1,7 +1,7 @@
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "jo3util"
|
name = "jo3util"
|
||||||
version = "0.0.3"
|
version = "0.0.6"
|
||||||
description = ""
|
description = ""
|
||||||
dependencies = []
|
dependencies = []
|
||||||
dynamic = ["readme"]
|
dynamic = ["readme"]
|
||||||
|
|
15
src/eqx.py
15
src/eqx.py
|
@ -1,5 +1,6 @@
|
||||||
#! /usr/bin/env python3
|
#! /usr/bin/env python3
|
||||||
|
|
||||||
|
import json
|
||||||
from typing import Callable, Optional, Tuple
|
from typing import Callable, Optional, Tuple
|
||||||
|
|
||||||
import equinox as eqx
|
import equinox as eqx
|
||||||
|
@ -113,3 +114,17 @@ def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Modu
|
||||||
|
|
||||||
model = eqx.tree_at(where, model, replace_fn=Ensemble)
|
model = eqx.tree_at(where, model, replace_fn=Ensemble)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def save(path, pytree, hyperparameters={}):
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
hyperparameters = json.dumps(hyperparameters)
|
||||||
|
f.write((hyperparameters + "\n").encode())
|
||||||
|
eqx.tree_serialise_leaves(f, pytree)
|
||||||
|
|
||||||
|
|
||||||
|
def load(path, type_):
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
hyperparams = json.loads(f.readline().decode())
|
||||||
|
pytree = type_(**hyperparams)
|
||||||
|
return eqx.tree_deserialise_leaves(f, pytree)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user