save function accepts more hyperparameter types now
This commit is contained in:
parent
238e1fb880
commit
46669bed3d
|
@ -1,7 +1,7 @@
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "jo3util"
|
name = "jo3util"
|
||||||
version = "0.0.6"
|
version = "0.0.7"
|
||||||
description = ""
|
description = ""
|
||||||
dependencies = []
|
dependencies = []
|
||||||
dynamic = ["readme"]
|
dynamic = ["readme"]
|
||||||
|
|
16
src/eqx.py
16
src/eqx.py
|
@ -1,7 +1,8 @@
|
||||||
#! /usr/bin/env python3
|
#! /usr/bin/env python3
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Callable, Optional, Tuple
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import equinox as eqx
|
import equinox as eqx
|
||||||
import jax
|
import jax
|
||||||
|
@ -95,7 +96,11 @@ def sow(where: Callable, model: eqx.Module) -> eqx.Module:
|
||||||
return Sow(model)
|
return Sow(model)
|
||||||
|
|
||||||
|
|
||||||
def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Module:
|
def insert_after(
|
||||||
|
where: Callable,
|
||||||
|
model: eqx.Module,
|
||||||
|
func: Callable
|
||||||
|
) -> eqx.Module:
|
||||||
"""Place a callable immediately after the argument modules"""
|
"""Place a callable immediately after the argument modules"""
|
||||||
|
|
||||||
class Ensemble(eqx.Module):
|
class Ensemble(eqx.Module):
|
||||||
|
@ -116,9 +121,12 @@ def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Modu
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def save(path, pytree, hyperparameters={}):
|
def save(path: Union[Path, str], pytree, hyperparameters={}):
|
||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
hyperparameters = json.dumps(hyperparameters)
|
hyperparameters = json.dumps(
|
||||||
|
hyperparameters,
|
||||||
|
default=lambda h: vars(h)
|
||||||
|
)
|
||||||
f.write((hyperparameters + "\n").encode())
|
f.write((hyperparameters + "\n").encode())
|
||||||
eqx.tree_serialise_leaves(f, pytree)
|
eqx.tree_serialise_leaves(f, pytree)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +0,0 @@
|
||||||
#! /usr/bin/env python3
|
|
||||||
# vim:fenc=utf-8
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user