todos as warnings, eqx.py now is able to capture intermediate activations and insert modules
This commit is contained in:
parent
bf2c478ac8
commit
070bf9c3c2
|
@ -1,7 +1,7 @@
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "jo3util"
|
name = "jo3util"
|
||||||
version = "0.0.2"
|
version = "0.0.3"
|
||||||
description = ""
|
description = ""
|
||||||
dependencies = []
|
dependencies = []
|
||||||
dynamic = ["readme"]
|
dynamic = ["readme"]
|
||||||
|
|
73
src/eqx.py
73
src/eqx.py
|
@ -1,10 +1,12 @@
|
||||||
#! /usr/bin/env python3
|
#! /usr/bin/env python3
|
||||||
|
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional, Tuple
|
||||||
|
|
||||||
import equinox as eqx
|
import equinox as eqx
|
||||||
|
import jax
|
||||||
|
|
||||||
from .fn import compose
|
from .fn import compose
|
||||||
|
from .warning import todo
|
||||||
|
|
||||||
|
|
||||||
def create_hook(
|
def create_hook(
|
||||||
|
@ -42,3 +44,72 @@ def create_hook(
|
||||||
return Hook(node)
|
return Hook(node)
|
||||||
|
|
||||||
return _create_hook
|
return _create_hook
|
||||||
|
|
||||||
|
|
||||||
|
def sow(where: Callable, model: eqx.Module) -> eqx.Module:
|
||||||
|
"""Capture intermediate activations that the argument modules output
|
||||||
|
and return them together after model output"""
|
||||||
|
activ = []
|
||||||
|
|
||||||
|
def install_sow(node: Callable):
|
||||||
|
node_call = type(node).__call__
|
||||||
|
|
||||||
|
if isinstance(node, eqx.Module):
|
||||||
|
todo("make StoreActivation a generic class")
|
||||||
|
|
||||||
|
class StoreActivation(type(node)):
|
||||||
|
def __init__(self, node):
|
||||||
|
self.__dict__.update(node.__dict__)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
x = node_call(self, *args, **kwargs)
|
||||||
|
activ.append(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
return StoreActivation(node)
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def store_activation(*args, **kwargs):
|
||||||
|
x = node(*args, **kwargs)
|
||||||
|
activ.append(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
return eqx.filter_jit(store_activation)
|
||||||
|
|
||||||
|
model = eqx.tree_at(where, model, replace_fn=install_sow)
|
||||||
|
|
||||||
|
model_call = type(model).__call__
|
||||||
|
todo("make Sow a generic class")
|
||||||
|
|
||||||
|
class Sow(type(model)):
|
||||||
|
def __init__(self, model):
|
||||||
|
self.__dict__.update(model.__dict__)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
activ.clear() # empty the list
|
||||||
|
x = model_call(self, *args, **kwargs)
|
||||||
|
return x, activ
|
||||||
|
|
||||||
|
return Sow(model)
|
||||||
|
|
||||||
|
|
||||||
|
def insert_after(where: Callable, model: eqx.Module, func: Callable) -> eqx.Module:
|
||||||
|
"""Place a callable immediately after the argument modules"""
|
||||||
|
|
||||||
|
class Ensemble(eqx.Module):
|
||||||
|
children: Tuple
|
||||||
|
|
||||||
|
def __init__(self, node):
|
||||||
|
self.children = (node, func)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
x = self.children[0](*args, **kwargs)
|
||||||
|
x = self.children[1](x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def __getitem__(self, items):
|
||||||
|
return self.children[items]
|
||||||
|
|
||||||
|
model = eqx.tree_at(where, model, replace_fn=Ensemble)
|
||||||
|
return model
|
||||||
|
|
18
src/warning.py
Normal file
18
src/warning.py
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
#! /usr/bin/env python3
|
||||||
|
# vim:fenc=utf-8
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
class ToDoWarning(Warning):
|
||||||
|
def __init__(self, msg):
|
||||||
|
self.message = msg
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return repr(self.message)
|
||||||
|
|
||||||
|
def todo(msg):
|
||||||
|
warnings.warn(msg, ToDoWarning)
|
Loading…
Reference in New Issue
Block a user