diff --git a/pyproject.toml b/pyproject.toml index 4fb0ee7..e0aad04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,3 +41,4 @@ formats = "ipynb,py" exclude = ".venv" venvPath = "." venv = ".venv" +reportMissingImports = false diff --git a/src/__init__.py b/src/__init__.py index 00e3147..e69de29 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,3 +0,0 @@ -#! /usr/bin/env python3 -# Remove this file if you don't have a main.py. -from .main import * diff --git a/src/eqx.py b/src/eqx.py new file mode 100644 index 0000000..41e63d8 --- /dev/null +++ b/src/eqx.py @@ -0,0 +1,44 @@ +#! /usr/bin/env python3 + +from typing import Callable, Optional + +import equinox as eqx + +from .fn import compose + + +def create_hook( + fwd_pre: Callable = lambda *arg, **kwarg: None, + fwd_post: Callable = lambda *arg, **kwarg: None, + bwd_pre: Callable = lambda *arg, **kwarg: None, + bwd_post: Callable = lambda *arg, **kwarg: None, +) -> Callable: + def _create_hook(node: eqx.Module) -> eqx.Module: + node_call = type(node).__call__ + + @eqx.filter_custom_jvp + def fwd(hook, *args, **kwargs): + fwd_pre(*args, **kwargs) + out = node_call(hook, *args, **kwargs) + fwd_post(out) + return out + + @fwd.def_jvp + def bwd(primals, tangents): + bwd_pre(*primals, *tangents) + primals_out, tangents_out = eqx.filter_jvp( + node_call, primals, tangents + ) + bwd_post(primals_out, tangents_out) + return primals_out, tangents_out + + class Hook(type(node)): + def __init__(self, node): + self.__dict__.update(node.__dict__) + + def __call__(self, *args, **kwargs): + return fwd(self, *args, **kwargs) + + return Hook(node) + + return _create_hook diff --git a/src/fn.py b/src/fn.py new file mode 100644 index 0000000..bb3e490 --- /dev/null +++ b/src/fn.py @@ -0,0 +1,26 @@ +#! /usr/bin/env python3 +# vim:fenc=utf-8 + +from functools import reduce + + + +def compose(*func): + """Create a single function out of multiple functions.""" + + def compose_(f, g): + return lambda *args, **kwargs: f(g(*args, **kwargs)) + + return reduce(compose_, func, lambda out: out) +if __name__ == "__main__": + + def fn0(x): + print("fn0", x + 1) + return x + 1 + + def fn1(x): + print("fn1", -x) + return -x + + fn2 = compose(fn1, fn0) + assert fn2(5) == -6