todos as warnings, eqx.py now is able to capture intermediate activations and insert modules

This commit is contained in:
JJJHolscher 2023-10-27 16:13:28 +02:00
parent bf2c478ac8
commit 070bf9c3c2
3 changed files with 91 additions and 2 deletions

View File

@ -1,7 +1,7 @@
[project] [project]
name = "jo3util" name = "jo3util"
version = "0.0.2" version = "0.0.3"
description = "" description = ""
dependencies = [] dependencies = []
dynamic = ["readme"] dynamic = ["readme"]

View File

@ -1,10 +1,12 @@
#! /usr/bin/env python3 #! /usr/bin/env python3
from typing import Callable, Optional from typing import Callable, Optional, Tuple
import equinox as eqx import equinox as eqx
import jax
from .fn import compose from .fn import compose
from .warning import todo
def create_hook( def create_hook(
@ -42,3 +44,72 @@ def create_hook(
return Hook(node) return Hook(node)
return _create_hook return _create_hook
def sow(where: Callable, model: eqx.Module) -> eqx.Module:
"""Capture intermediate activations that the argument modules output
and return them together after model output"""
activ = []
def install_sow(node: Callable):
node_call = type(node).__call__
if isinstance(node, eqx.Module):
todo("make StoreActivation a generic class")
class StoreActivation(type(node)):
def __init__(self, node):
self.__dict__.update(node.__dict__)
def __call__(self, *args, **kwargs):
x = node_call(self, *args, **kwargs)
activ.append(x)
return x
return StoreActivation(node)
else:
def store_activation(*args, **kwargs):
x = node(*args, **kwargs)
activ.append(x)
return x
return eqx.filter_jit(store_activation)
model = eqx.tree_at(where, model, replace_fn=install_sow)
model_call = type(model).__call__
todo("make Sow a generic class")
class Sow(type(model)):
def __init__(self, model):
self.__dict__.update(model.__dict__)
def __call__(self, *args, **kwargs):
activ.clear() # empty the list
x = model_call(self, *args, **kwargs)
return x, activ
return Sow(model)
def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Module:
"""Place a callable immediately after the argument modules"""
class Ensemble(eqx.Module):
children: Tuple
def __init__(self, node):
self.children = (node, func)
def __call__(self, *args, **kwargs):
x = self.children[0](*args, **kwargs)
x = self.children[1](x)
return x
def __getitem__(self, items):
return self.children[items]
model = eqx.tree_at(where, model, replace_fn=Ensemble)
return model

18
src/warning.py Normal file
View File

@ -0,0 +1,18 @@
#! /usr/bin/env python3
# vim:fenc=utf-8
"""
"""
import warnings
class ToDoWarning(Warning):
def __init__(self, msg):
self.message = msg
def __str__(self):
return repr(self.message)
def todo(msg):
warnings.warn(msg, ToDoWarning)