todos as warnings, eqx.py now is able to capture intermediate activations and insert modules
This commit is contained in:
parent
bf2c478ac8
commit
070bf9c3c2
|
@ -1,7 +1,7 @@
|
|||
|
||||
[project]
|
||||
name = "jo3util"
|
||||
version = "0.0.2"
|
||||
version = "0.0.3"
|
||||
description = ""
|
||||
dependencies = []
|
||||
dynamic = ["readme"]
|
||||
|
|
73
src/eqx.py
73
src/eqx.py
|
@ -1,10 +1,12 @@
|
|||
#! /usr/bin/env python3
|
||||
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import equinox as eqx
|
||||
import jax
|
||||
|
||||
from .fn import compose
|
||||
from .warning import todo
|
||||
|
||||
|
||||
def create_hook(
|
||||
|
@ -42,3 +44,72 @@ def create_hook(
|
|||
return Hook(node)
|
||||
|
||||
return _create_hook
|
||||
|
||||
|
||||
def sow(where: Callable, model: eqx.Module) -> eqx.Module:
|
||||
"""Capture intermediate activations that the argument modules output
|
||||
and return them together after model output"""
|
||||
activ = []
|
||||
|
||||
def install_sow(node: Callable):
|
||||
node_call = type(node).__call__
|
||||
|
||||
if isinstance(node, eqx.Module):
|
||||
todo("make StoreActivation a generic class")
|
||||
|
||||
class StoreActivation(type(node)):
|
||||
def __init__(self, node):
|
||||
self.__dict__.update(node.__dict__)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
x = node_call(self, *args, **kwargs)
|
||||
activ.append(x)
|
||||
return x
|
||||
|
||||
return StoreActivation(node)
|
||||
|
||||
else:
|
||||
|
||||
def store_activation(*args, **kwargs):
|
||||
x = node(*args, **kwargs)
|
||||
activ.append(x)
|
||||
return x
|
||||
|
||||
return eqx.filter_jit(store_activation)
|
||||
|
||||
model = eqx.tree_at(where, model, replace_fn=install_sow)
|
||||
|
||||
model_call = type(model).__call__
|
||||
todo("make Sow a generic class")
|
||||
|
||||
class Sow(type(model)):
|
||||
def __init__(self, model):
|
||||
self.__dict__.update(model.__dict__)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
activ.clear() # empty the list
|
||||
x = model_call(self, *args, **kwargs)
|
||||
return x, activ
|
||||
|
||||
return Sow(model)
|
||||
|
||||
|
||||
def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Module:
|
||||
"""Place a callable immediately after the argument modules"""
|
||||
|
||||
class Ensemble(eqx.Module):
|
||||
children: Tuple
|
||||
|
||||
def __init__(self, node):
|
||||
self.children = (node, func)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
x = self.children[0](*args, **kwargs)
|
||||
x = self.children[1](x)
|
||||
return x
|
||||
|
||||
def __getitem__(self, items):
|
||||
return self.children[items]
|
||||
|
||||
model = eqx.tree_at(where, model, replace_fn=Ensemble)
|
||||
return model
|
||||
|
|
18
src/warning.py
Normal file
18
src/warning.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
#! /usr/bin/env python3
|
||||
# vim:fenc=utf-8
|
||||
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
class ToDoWarning(Warning):
|
||||
def __init__(self, msg):
|
||||
self.message = msg
|
||||
|
||||
def __str__(self):
|
||||
return repr(self.message)
|
||||
|
||||
def todo(msg):
|
||||
warnings.warn(msg, ToDoWarning)
|
Loading…
Reference in New Issue
Block a user