diff --git a/pyproject.toml b/pyproject.toml index e0aad04..aa0ebb8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "jo3util" -version = "0.0.2" +version = "0.0.3" description = "" dependencies = [] dynamic = ["readme"] diff --git a/src/eqx.py b/src/eqx.py index 41e63d8..ab11538 100644 --- a/src/eqx.py +++ b/src/eqx.py @@ -1,10 +1,12 @@ #! /usr/bin/env python3 -from typing import Callable, Optional +from typing import Callable, Optional, Tuple import equinox as eqx +import jax from .fn import compose +from .warning import todo def create_hook( @@ -42,3 +44,72 @@ def create_hook( return Hook(node) return _create_hook + + +def sow(where: Callable, model: eqx.Module) -> eqx.Module: + """Capture intermediate activations that the argument modules output + and return them together after model output""" + activ = [] + + def install_sow(node: Callable): + node_call = type(node).__call__ + + if isinstance(node, eqx.Module): + todo("make StoreActivation a generic class") + + class StoreActivation(type(node)): + def __init__(self, node): + self.__dict__.update(node.__dict__) + + def __call__(self, *args, **kwargs): + x = node_call(self, *args, **kwargs) + activ.append(x) + return x + + return StoreActivation(node) + + else: + + def store_activation(*args, **kwargs): + x = node(*args, **kwargs) + activ.append(x) + return x + + return eqx.filter_jit(store_activation) + + model = eqx.tree_at(where, model, replace_fn=install_sow) + + model_call = type(model).__call__ + todo("make Sow a generic class") + + class Sow(type(model)): + def __init__(self, model): + self.__dict__.update(model.__dict__) + + def __call__(self, *args, **kwargs): + activ.clear() # empty the list + x = model_call(self, *args, **kwargs) + return x, activ + + return Sow(model) + + +def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Module: + """Place a callable immediately after the argument modules""" + + class Ensemble(eqx.Module): + children: Tuple + + def __init__(self, node): + self.children = (node, func) + + def __call__(self, *args, **kwargs): + x = self.children[0](*args, **kwargs) + x = self.children[1](x) + return x + + def __getitem__(self, items): + return self.children[items] + + model = eqx.tree_at(where, model, replace_fn=Ensemble) + return model diff --git a/src/warning.py b/src/warning.py new file mode 100644 index 0000000..bf9781b --- /dev/null +++ b/src/warning.py @@ -0,0 +1,18 @@ +#! /usr/bin/env python3 +# vim:fenc=utf-8 + +""" + +""" + +import warnings + +class ToDoWarning(Warning): + def __init__(self, msg): + self.message = msg + + def __str__(self): + return repr(self.message) + +def todo(msg): + warnings.warn(msg, ToDoWarning)