todos as warnings, eqx.py now is able to capture intermediate activations and insert modules

This commit is contained in:
JJJHolscher 2023-10-27 16:13:28 +02:00
parent bf2c478ac8
commit 070bf9c3c2
3 changed files with 91 additions and 2 deletions

View File

@ -1,7 +1,7 @@
[project]
name = "jo3util"
version = "0.0.2"
version = "0.0.3"
description = ""
dependencies = []
dynamic = ["readme"]

View File

@ -1,10 +1,12 @@
#! /usr/bin/env python3
from typing import Callable, Optional
from typing import Callable, Optional, Tuple
import equinox as eqx
import jax
from .fn import compose
from .warning import todo
def create_hook(
@ -42,3 +44,72 @@ def create_hook(
return Hook(node)
return _create_hook
def sow(where: Callable, model: eqx.Module) -> eqx.Module:
"""Capture intermediate activations that the argument modules output
and return them together after model output"""
activ = []
def install_sow(node: Callable):
node_call = type(node).__call__
if isinstance(node, eqx.Module):
todo("make StoreActivation a generic class")
class StoreActivation(type(node)):
def __init__(self, node):
self.__dict__.update(node.__dict__)
def __call__(self, *args, **kwargs):
x = node_call(self, *args, **kwargs)
activ.append(x)
return x
return StoreActivation(node)
else:
def store_activation(*args, **kwargs):
x = node(*args, **kwargs)
activ.append(x)
return x
return eqx.filter_jit(store_activation)
model = eqx.tree_at(where, model, replace_fn=install_sow)
model_call = type(model).__call__
todo("make Sow a generic class")
class Sow(type(model)):
def __init__(self, model):
self.__dict__.update(model.__dict__)
def __call__(self, *args, **kwargs):
activ.clear() # empty the list
x = model_call(self, *args, **kwargs)
return x, activ
return Sow(model)
def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Module:
"""Place a callable immediately after the argument modules"""
class Ensemble(eqx.Module):
children: Tuple
def __init__(self, node):
self.children = (node, func)
def __call__(self, *args, **kwargs):
x = self.children[0](*args, **kwargs)
x = self.children[1](x)
return x
def __getitem__(self, items):
return self.children[items]
model = eqx.tree_at(where, model, replace_fn=Ensemble)
return model

18
src/warning.py Normal file
View File

@ -0,0 +1,18 @@
#! /usr/bin/env python3
# vim:fenc=utf-8
"""
"""
import warnings
class ToDoWarning(Warning):
def __init__(self, msg):
self.message = msg
def __str__(self):
return repr(self.message)
def todo(msg):
warnings.warn(msg, ToDoWarning)