save function accepts more hyperparameter types now
This commit is contained in:
parent
238e1fb880
commit
46669bed3d
|
@ -1,7 +1,7 @@
|
|||
|
||||
[project]
|
||||
name = "jo3util"
|
||||
version = "0.0.6"
|
||||
version = "0.0.7"
|
||||
description = ""
|
||||
dependencies = []
|
||||
dynamic = ["readme"]
|
||||
|
|
16
src/eqx.py
16
src/eqx.py
|
@ -1,7 +1,8 @@
|
|||
#! /usr/bin/env python3
|
||||
|
||||
import json
|
||||
from typing import Callable, Optional, Tuple
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
from pathlib import Path
|
||||
|
||||
import equinox as eqx
|
||||
import jax
|
||||
|
@ -95,7 +96,11 @@ def sow(where: Callable, model: eqx.Module) -> eqx.Module:
|
|||
return Sow(model)
|
||||
|
||||
|
||||
def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Module:
|
||||
def insert_after(
|
||||
where: Callable,
|
||||
model: eqx.Module,
|
||||
func: Callable
|
||||
) -> eqx.Module:
|
||||
"""Place a callable immediately after the argument modules"""
|
||||
|
||||
class Ensemble(eqx.Module):
|
||||
|
@ -116,9 +121,12 @@ def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Modu
|
|||
return model
|
||||
|
||||
|
||||
def save(path, pytree, hyperparameters={}):
|
||||
def save(path: Union[Path, str], pytree, hyperparameters={}):
|
||||
with open(path, "wb") as f:
|
||||
hyperparameters = json.dumps(hyperparameters)
|
||||
hyperparameters = json.dumps(
|
||||
hyperparameters,
|
||||
default=lambda h: vars(h)
|
||||
)
|
||||
f.write((hyperparameters + "\n").encode())
|
||||
eqx.tree_serialise_leaves(f, pytree)
|
||||
|
||||
|
|
|
@ -1,7 +0,0 @@
|
|||
#! /usr/bin/env python3
|
||||
# vim:fenc=utf-8
|
||||
|
||||
"""
|
||||
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user