equinox Hook, is probably useless, but works
This commit is contained in:
parent
8d948fd2ca
commit
607c64351f
|
@ -41,3 +41,4 @@ formats = "ipynb,py"
|
||||||
exclude = ".venv"
|
exclude = ".venv"
|
||||||
venvPath = "."
|
venvPath = "."
|
||||||
venv = ".venv"
|
venv = ".venv"
|
||||||
|
reportMissingImports = false
|
||||||
|
|
|
@ -1,3 +0,0 @@
|
||||||
#! /usr/bin/env python3
|
|
||||||
# Remove this file if you don't have a main.py.
|
|
||||||
from .main import *
|
|
44
src/eqx.py
Normal file
44
src/eqx.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
#! /usr/bin/env python3
|
||||||
|
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import equinox as eqx
|
||||||
|
|
||||||
|
from .fn import compose
|
||||||
|
|
||||||
|
|
||||||
|
def create_hook(
|
||||||
|
fwd_pre: Callable = lambda *arg, **kwarg: None,
|
||||||
|
fwd_post: Callable = lambda *arg, **kwarg: None,
|
||||||
|
bwd_pre: Callable = lambda *arg, **kwarg: None,
|
||||||
|
bwd_post: Callable = lambda *arg, **kwarg: None,
|
||||||
|
) -> Callable:
|
||||||
|
def _create_hook(node: eqx.Module) -> eqx.Module:
|
||||||
|
node_call = type(node).__call__
|
||||||
|
|
||||||
|
@eqx.filter_custom_jvp
|
||||||
|
def fwd(hook, *args, **kwargs):
|
||||||
|
fwd_pre(*args, **kwargs)
|
||||||
|
out = node_call(hook, *args, **kwargs)
|
||||||
|
fwd_post(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
@fwd.def_jvp
|
||||||
|
def bwd(primals, tangents):
|
||||||
|
bwd_pre(*primals, *tangents)
|
||||||
|
primals_out, tangents_out = eqx.filter_jvp(
|
||||||
|
node_call, primals, tangents
|
||||||
|
)
|
||||||
|
bwd_post(primals_out, tangents_out)
|
||||||
|
return primals_out, tangents_out
|
||||||
|
|
||||||
|
class Hook(type(node)):
|
||||||
|
def __init__(self, node):
|
||||||
|
self.__dict__.update(node.__dict__)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return fwd(self, *args, **kwargs)
|
||||||
|
|
||||||
|
return Hook(node)
|
||||||
|
|
||||||
|
return _create_hook
|
26
src/fn.py
Normal file
26
src/fn.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
#! /usr/bin/env python3
|
||||||
|
# vim:fenc=utf-8
|
||||||
|
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def compose(*func):
|
||||||
|
"""Create a single function out of multiple functions."""
|
||||||
|
|
||||||
|
def compose_(f, g):
|
||||||
|
return lambda *args, **kwargs: f(g(*args, **kwargs))
|
||||||
|
|
||||||
|
return reduce(compose_, func, lambda out: out)
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
def fn0(x):
|
||||||
|
print("fn0", x + 1)
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
def fn1(x):
|
||||||
|
print("fn1", -x)
|
||||||
|
return -x
|
||||||
|
|
||||||
|
fn2 = compose(fn1, fn0)
|
||||||
|
assert fn2(5) == -6
|
Loading…
Reference in New Issue
Block a user