EmptySAE
This commit is contained in:
parent
38b13808cf
commit
ff6775f2a1
21
LICENSE
21
LICENSE
|
@ -1,21 +0,0 @@
|
||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) [year] [fullname]
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
1
doc/.quarto/idx/imagenet.ipynb.json
Normal file
1
doc/.quarto/idx/imagenet.ipynb.json
Normal file
File diff suppressed because one or more lines are too long
5
doc/.quarto/xref/INDEX
Normal file
5
doc/.quarto/xref/INDEX
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
{
|
||||||
|
"imagenet.ipynb": {
|
||||||
|
"imagenet.html": "c213c7e9"
|
||||||
|
}
|
||||||
|
}
|
1
doc/.quarto/xref/c213c7e9
Normal file
1
doc/.quarto/xref/c213c7e9
Normal file
|
@ -0,0 +1 @@
|
||||||
|
{"entries":[{"caption":"","order":{"number":1,"section":[0,0,0,0,0,0,0]},"key":"fig-label-balance"}],"headings":["image-compression"]}
|
245
doc/reconstruction-error.qmd
Normal file
245
doc/reconstruction-error.qmd
Normal file
|
@ -0,0 +1,245 @@
|
||||||
|
|
||||||
|
Load the features and the generations.
|
||||||
|
|
||||||
|
```{python}
|
||||||
|
from pathlib import Path
|
||||||
|
features2image_diffusion_dir = Path("../features2image_diffusion")
|
||||||
|
run_dir = features2image_diffusion_dir / "xyz/run/4409b6282a7d05f0b08880228d6d6564011fa40be412073ff05aff8bf2dc49fa"
|
||||||
|
batch_size = 64
|
||||||
|
shuffle = True
|
||||||
|
```
|
||||||
|
|
||||||
|
```{python}
|
||||||
|
from collections import OrderedDict
|
||||||
|
from itertools import islice
|
||||||
|
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
import torchvision
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
class EditedFeaturesAndGenerations(Dataset):
|
||||||
|
|
||||||
|
def __init__(self, gen_path, mnist_path="./res/MNIST", progress_bar=True, transform=None):
|
||||||
|
self.dir = gen_path
|
||||||
|
self.subdirs = OrderedDict()
|
||||||
|
|
||||||
|
mnist_transform = torchvision.transforms.Compose([
|
||||||
|
torchvision.transforms.ToTensor(),
|
||||||
|
torchvision.transforms.Normalize((0.5,), (0.5,)),
|
||||||
|
])
|
||||||
|
self.mnist = torchvision.datasets.MNIST(
|
||||||
|
str(mnist_path),
|
||||||
|
train=True,
|
||||||
|
download=True,
|
||||||
|
transform=mnist_transform,
|
||||||
|
)
|
||||||
|
|
||||||
|
a_generations_file = next(iter(filter(
|
||||||
|
lambda p: p.name.split("-")[-1] == "generations.npy",
|
||||||
|
Path(gen_path).rglob("*")
|
||||||
|
)))
|
||||||
|
self.generations_per_file = self.load_numpy(a_generations_file).shape[0]
|
||||||
|
|
||||||
|
self.len = 0
|
||||||
|
subdir_iter = filter(lambda d: d.is_dir(), Path(gen_path).iterdir())
|
||||||
|
if progress_bar:
|
||||||
|
subdir_iter = tqdm(
|
||||||
|
list(subdir_iter),
|
||||||
|
desc="loading features and generations",
|
||||||
|
postfix={"len": self.len}
|
||||||
|
)
|
||||||
|
|
||||||
|
for d in subdir_iter:
|
||||||
|
# convert /dir/parent/id-type.suffix to ./parent/id
|
||||||
|
files = sorted(map(
|
||||||
|
lambda f: (f.parent / f.stem.split("-")[0]).relative_to(self.dir),
|
||||||
|
# only iterate over generated images
|
||||||
|
filter(lambda f: f.suffix == ".png", d.rglob("*"))
|
||||||
|
))
|
||||||
|
self.subdirs[int(d.name)] = files
|
||||||
|
self.len += len(files) * self.generations_per_file
|
||||||
|
if progress_bar:
|
||||||
|
subdir_iter.set_postfix({"len": self.len}) # pyright: ignore[reportAttributeAccessIssue]
|
||||||
|
|
||||||
|
self.transform = torchvision.transforms.Compose(
|
||||||
|
[
|
||||||
|
# torchvision.transforms.ToTensor(),
|
||||||
|
torchvision.transforms.Normalize((0.5,), (0.5,)),
|
||||||
|
]
|
||||||
|
) if transform is None else transform
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_numpy(path):
|
||||||
|
return np.load(path, allow_pickle=False)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
# find which directory to sample from
|
||||||
|
subdirs = iter(self.subdirs.items())
|
||||||
|
d, files = next(subdirs)
|
||||||
|
while idx // self.generations_per_file > len(files):
|
||||||
|
idx -= len(files) * self.generations_per_file
|
||||||
|
d, files = next(subdirs)
|
||||||
|
|
||||||
|
original, label = self.mnist[d]
|
||||||
|
file = files[idx // self.generations_per_file]
|
||||||
|
features = self.load_numpy(self.dir / file.parent / (file.stem + "-features.npy"))
|
||||||
|
generation = self.load_numpy(self.dir / file.parent / (file.stem + "-generations.npy"))
|
||||||
|
# We only get a single generation from the file that contains multiple generations.
|
||||||
|
generation = generation[idx % self.generations_per_file]
|
||||||
|
# Temporarily convert the array to a torch tensor to apply the transformations.
|
||||||
|
generation = self.transform(torch.Tensor(generation)).numpy()
|
||||||
|
|
||||||
|
return str(file), generation, original, features, label
|
||||||
|
|
||||||
|
def iter_unedited(self):
|
||||||
|
for d in self.subdirs:
|
||||||
|
original, label = self.mnist[d]
|
||||||
|
d = str(d)
|
||||||
|
features = self.load_numpy(self.dir / d / "unedited-features.npy")
|
||||||
|
for generation in self.load_numpy(self.dir / d / "unedited-generations.npy"):
|
||||||
|
generation = self.transform(torch.Tensor(generation)).numpy()
|
||||||
|
yield d + "/unedited", generation, original, features, label
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.len
|
||||||
|
|
||||||
|
dataset = EditedFeaturesAndGenerations(run_dir)
|
||||||
|
data_loader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
collate_fn=lambda batch: tuple(map(np.array, zip(*batch))),
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=shuffle
|
||||||
|
)
|
||||||
|
path, generation, original, features, label = next(iter(data_loader))
|
||||||
|
print(path.shape, generation.shape, original.shape, features.shape, label.shape)
|
||||||
|
```
|
||||||
|
|
||||||
|
```{python}
|
||||||
|
_, gen_img, orig_img, _, _ = next(dataset.iter_unedited())
|
||||||
|
print(orig_img.max(), gen_img.max())
|
||||||
|
print(orig_img.min(), gen_img.min())
|
||||||
|
```
|
||||||
|
|
||||||
|
```{python}
|
||||||
|
import jo3mnist
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
_, axes = plt.subplots(1, 2, figsize=(10, 5))
|
||||||
|
axes[0].imshow(jo3mnist.to_img(orig_img))
|
||||||
|
axes[1].imshow(jo3mnist.to_img(gen_img))
|
||||||
|
(ax.axis('off') for ax in axes)
|
||||||
|
plt.show()
|
||||||
|
```
|
||||||
|
|
||||||
|
Load the CNN and the SAE and sow the CNN.
|
||||||
|
|
||||||
|
```{python}
|
||||||
|
import tomllib
|
||||||
|
# the config.toml in the features2image_diffusion's run_dir describes which SAE run results it used.
|
||||||
|
sae_run_dir = features2image_diffusion_dir / tomllib.loads((run_dir / "config.toml").read_text())["eval"][0]["feature_dir"]
|
||||||
|
# the sae_run_dir contains configuration infarmation.
|
||||||
|
# We need this for knowing which cnn and sae to load and at what place to intercept cnn activations.
|
||||||
|
sae_config = tomllib.loads((sae_run_dir / "config.toml").read_text())
|
||||||
|
cnn_path = sae_config["cnn_storage"]
|
||||||
|
sae_path = sae_run_dir / "sae.eqx"
|
||||||
|
sow_layer = sae_config["sae"][0]["layer"]
|
||||||
|
|
||||||
|
print(cnn_path, sae_path, sow_layer, sep="\n")
|
||||||
|
```
|
||||||
|
|
||||||
|
```{python}
|
||||||
|
import jax
|
||||||
|
import equinox as eqx
|
||||||
|
import jo3util.eqx as jo3eqx
|
||||||
|
|
||||||
|
from src.sae import SAE
|
||||||
|
from src.cnn import CNN
|
||||||
|
|
||||||
|
sae = jo3eqx.load(sae_path, SAE)
|
||||||
|
|
||||||
|
# equinox wants to see an example cnn before loading the weights into it
|
||||||
|
cnn = CNN(jax.random.PRNGKey(0))
|
||||||
|
# load the weights into the example cnn
|
||||||
|
cnn = eqx.tree_deserialise_leaves(cnn_path, cnn)
|
||||||
|
# sow the cnn, so we can intercept intermediate activations every forward pass
|
||||||
|
cnn = jo3eqx.sow(lambda m: m.layers[sow_layer], cnn)
|
||||||
|
```
|
||||||
|
|
||||||
|
Now reconstruct features from the generated images and compare the reconstructed features with the original.
|
||||||
|
|
||||||
|
```{python}
|
||||||
|
path, gens, orig, orig_features, label = tuple(map(np.array, zip(*dataset.iter_unedited())))
|
||||||
|
print(path.shape, gens.shape, orig_features.shape)
|
||||||
|
# Pass the images generated by the diffusion through the cnn and fetch the intermediate activations.
|
||||||
|
activ, pred = jax.vmap(cnn)(gens) # pyright: ignore[reportArgumentType]
|
||||||
|
print(activ.shape, pred.shape)
|
||||||
|
# Reconstruct the features by encoding the intermediate activations with the SAE.
|
||||||
|
recon_features = jax.vmap(sae.encode)(activ)
|
||||||
|
print(recon_features.shape)
|
||||||
|
# Look at the difference between the original features and the reconstructed features.
|
||||||
|
mse = jnp.sum((orig_features - recon_features)**2, axis=-1) # pyright: ignore[reportAssignmentType]
|
||||||
|
print(mse)
|
||||||
|
```
|
||||||
|
|
||||||
|
If we use generated images from edited features instead, the reconstruction difference is larger.
|
||||||
|
|
||||||
|
```{python}
|
||||||
|
path, gens, orig, orig_features, label = next(iter(data_loader))
|
||||||
|
activ, pred = jax.vmap(cnn)(gens) # pyright: ignore[reportArgumentType]
|
||||||
|
recon_features = jax.vmap(sae.encode)(activ)
|
||||||
|
mse = jnp.sum((orig_features - recon_features)**2, axis=-1) # pyright: ignore[reportAssignmentType]
|
||||||
|
print(mse)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
res/cnn.eqx
BIN
res/cnn.eqx
Binary file not shown.
BIN
res/sae.eqx
BIN
res/sae.eqx
Binary file not shown.
|
@ -1,8 +0,0 @@
|
||||||
batch_size = 64
|
|
||||||
steps = 10000
|
|
||||||
print_every = 500
|
|
||||||
seed = 1
|
|
||||||
cnn_storage = "/home/jo3/p/sparse_autoencoder/sparse_autoencoder/../res/cnn.eqx"
|
|
||||||
sae = [
|
|
||||||
{ layer = 6, hidden_size = 256, input_size = 64, learning_rate = 0.0001, l1 = 0.0003 },
|
|
||||||
]
|
|
Binary file not shown.
|
@ -1,5 +0,0 @@
|
||||||
layer = 6
|
|
||||||
hidden_size = 256
|
|
||||||
input_size = 64
|
|
||||||
learning_rate = 0.0001
|
|
||||||
l1 = 0.0003
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user