equinox Hook, is probably useless, but works
This commit is contained in:
parent
8d948fd2ca
commit
607c64351f
|
@ -41,3 +41,4 @@ formats = "ipynb,py"
|
|||
exclude = ".venv"
|
||||
venvPath = "."
|
||||
venv = ".venv"
|
||||
reportMissingImports = false
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
#! /usr/bin/env python3
|
||||
# Remove this file if you don't have a main.py.
|
||||
from .main import *
|
44
src/eqx.py
Normal file
44
src/eqx.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
#! /usr/bin/env python3
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import equinox as eqx
|
||||
|
||||
from .fn import compose
|
||||
|
||||
|
||||
def create_hook(
|
||||
fwd_pre: Callable = lambda *arg, **kwarg: None,
|
||||
fwd_post: Callable = lambda *arg, **kwarg: None,
|
||||
bwd_pre: Callable = lambda *arg, **kwarg: None,
|
||||
bwd_post: Callable = lambda *arg, **kwarg: None,
|
||||
) -> Callable:
|
||||
def _create_hook(node: eqx.Module) -> eqx.Module:
|
||||
node_call = type(node).__call__
|
||||
|
||||
@eqx.filter_custom_jvp
|
||||
def fwd(hook, *args, **kwargs):
|
||||
fwd_pre(*args, **kwargs)
|
||||
out = node_call(hook, *args, **kwargs)
|
||||
fwd_post(out)
|
||||
return out
|
||||
|
||||
@fwd.def_jvp
|
||||
def bwd(primals, tangents):
|
||||
bwd_pre(*primals, *tangents)
|
||||
primals_out, tangents_out = eqx.filter_jvp(
|
||||
node_call, primals, tangents
|
||||
)
|
||||
bwd_post(primals_out, tangents_out)
|
||||
return primals_out, tangents_out
|
||||
|
||||
class Hook(type(node)):
|
||||
def __init__(self, node):
|
||||
self.__dict__.update(node.__dict__)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return fwd(self, *args, **kwargs)
|
||||
|
||||
return Hook(node)
|
||||
|
||||
return _create_hook
|
26
src/fn.py
Normal file
26
src/fn.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
#! /usr/bin/env python3
|
||||
# vim:fenc=utf-8
|
||||
|
||||
from functools import reduce
|
||||
|
||||
|
||||
|
||||
def compose(*func):
|
||||
"""Create a single function out of multiple functions."""
|
||||
|
||||
def compose_(f, g):
|
||||
return lambda *args, **kwargs: f(g(*args, **kwargs))
|
||||
|
||||
return reduce(compose_, func, lambda out: out)
|
||||
if __name__ == "__main__":
|
||||
|
||||
def fn0(x):
|
||||
print("fn0", x + 1)
|
||||
return x + 1
|
||||
|
||||
def fn1(x):
|
||||
print("fn1", -x)
|
||||
return -x
|
||||
|
||||
fn2 = compose(fn1, fn0)
|
||||
assert fn2(5) == -6
|
Loading…
Reference in New Issue
Block a user