equinox Hook, is probably useless, but works

This commit is contained in:
JJJHolscher 2023-10-03 19:41:30 +02:00
parent 8d948fd2ca
commit 607c64351f
4 changed files with 71 additions and 3 deletions

View File

@ -41,3 +41,4 @@ formats = "ipynb,py"
exclude = ".venv"
venvPath = "."
venv = ".venv"
reportMissingImports = false

View File

@ -1,3 +0,0 @@
#! /usr/bin/env python3
# Remove this file if you don't have a main.py.
from .main import *

44
src/eqx.py Normal file
View File

@ -0,0 +1,44 @@
#! /usr/bin/env python3
from typing import Callable, Optional
import equinox as eqx
from .fn import compose
def create_hook(
fwd_pre: Callable = lambda *arg, **kwarg: None,
fwd_post: Callable = lambda *arg, **kwarg: None,
bwd_pre: Callable = lambda *arg, **kwarg: None,
bwd_post: Callable = lambda *arg, **kwarg: None,
) -> Callable:
def _create_hook(node: eqx.Module) -> eqx.Module:
node_call = type(node).__call__
@eqx.filter_custom_jvp
def fwd(hook, *args, **kwargs):
fwd_pre(*args, **kwargs)
out = node_call(hook, *args, **kwargs)
fwd_post(out)
return out
@fwd.def_jvp
def bwd(primals, tangents):
bwd_pre(*primals, *tangents)
primals_out, tangents_out = eqx.filter_jvp(
node_call, primals, tangents
)
bwd_post(primals_out, tangents_out)
return primals_out, tangents_out
class Hook(type(node)):
def __init__(self, node):
self.__dict__.update(node.__dict__)
def __call__(self, *args, **kwargs):
return fwd(self, *args, **kwargs)
return Hook(node)
return _create_hook

26
src/fn.py Normal file
View File

@ -0,0 +1,26 @@
#! /usr/bin/env python3
# vim:fenc=utf-8
from functools import reduce
def compose(*func):
"""Create a single function out of multiple functions."""
def compose_(f, g):
return lambda *args, **kwargs: f(g(*args, **kwargs))
return reduce(compose_, func, lambda out: out)
if __name__ == "__main__":
def fn0(x):
print("fn0", x + 1)
return x + 1
def fn1(x):
print("fn1", -x)
return -x
fn2 = compose(fn1, fn0)
assert fn2(5) == -6