This commit is contained in:
JJJHolscher 2024-08-17 15:49:39 +02:00
parent 38b13808cf
commit ff6775f2a1
70021 changed files with 271 additions and 41 deletions

LICENSE

@@ -1,21 +0,0 @@
MIT License
Copyright (c) [year] [fullname]
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

File diff suppressed because one or more lines are too long

doc/.quarto/xref/INDEX (new file)

@@ -0,0 +1,5 @@
{
"imagenet.ipynb": {
"imagenet.html": "c213c7e9"
}
}


@@ -0,0 +1 @@
{"entries":[{"caption":"","order":{"number":1,"section":[0,0,0,0,0,0,0]},"key":"fig-label-balance"}],"headings":["image-compression"]}


@@ -0,0 +1,245 @@
Load the features and the generations.
```{python}
from pathlib import Path
features2image_diffusion_dir = Path("../features2image_diffusion")
run_dir = features2image_diffusion_dir / "xyz/run/4409b6282a7d05f0b08880228d6d6564011fa40be412073ff05aff8bf2dc49fa"
batch_size = 64
shuffle = True
```
```{python}
from collections import OrderedDict

import jax.numpy as jnp
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class EditedFeaturesAndGenerations(Dataset):
    """Pairs SAE features with the images the diffusion model generated from them."""

    def __init__(self, gen_path, mnist_path="./res/MNIST", progress_bar=True, transform=None):
        self.dir = gen_path
        self.subdirs = OrderedDict()
        mnist_transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5,), (0.5,)),
        ])
        self.mnist = torchvision.datasets.MNIST(
            str(mnist_path),
            train=True,
            download=True,
            transform=mnist_transform,
        )
        # Every generations file holds the same number of generations,
        # so peek at one of them to learn that number.
        a_generations_file = next(filter(
            lambda p: p.name.split("-")[-1] == "generations.npy",
            Path(gen_path).rglob("*")
        ))
        self.generations_per_file = self.load_numpy(a_generations_file).shape[0]
        self.len = 0
        subdir_iter = filter(lambda d: d.is_dir(), Path(gen_path).iterdir())
        if progress_bar:
            subdir_iter = tqdm(
                list(subdir_iter),
                desc="loading features and generations",
                postfix={"len": self.len},
            )
        for d in subdir_iter:
            # convert /dir/parent/id-type.suffix to ./parent/id
            files = sorted(map(
                lambda f: (f.parent / f.stem.split("-")[0]).relative_to(self.dir),
                # only iterate over generated images
                filter(lambda f: f.suffix == ".png", d.rglob("*"))
            ))
            self.subdirs[int(d.name)] = files
            self.len += len(files) * self.generations_per_file
            if progress_bar:
                subdir_iter.set_postfix({"len": self.len})  # pyright: ignore[reportAttributeAccessIssue]
        self.transform = torchvision.transforms.Compose(
            [
                # torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.5,), (0.5,)),
            ]
        ) if transform is None else transform

    @staticmethod
    def load_numpy(path):
        return np.load(path, allow_pickle=False)

    def __getitem__(self, idx):
        # Find which subdirectory the flat index falls into.
        subdirs = iter(self.subdirs.items())
        d, files = next(subdirs)
        while idx // self.generations_per_file >= len(files):
            idx -= len(files) * self.generations_per_file
            d, files = next(subdirs)
        original, label = self.mnist[d]
        file = files[idx // self.generations_per_file]
        features = self.load_numpy(self.dir / file.parent / (file.stem + "-features.npy"))
        generation = self.load_numpy(self.dir / file.parent / (file.stem + "-generations.npy"))
        # The file holds multiple generations; take only the one indexed here.
        generation = generation[idx % self.generations_per_file]
        # Temporarily convert the array to a torch tensor to apply the transformations.
        generation = self.transform(torch.Tensor(generation)).numpy()
        return str(file), generation, original, features, label

    def iter_unedited(self):
        for d in self.subdirs:
            original, label = self.mnist[d]
            d = str(d)
            features = self.load_numpy(self.dir / d / "unedited-features.npy")
            for generation in self.load_numpy(self.dir / d / "unedited-generations.npy"):
                generation = self.transform(torch.Tensor(generation)).numpy()
                yield d + "/unedited", generation, original, features, label

    def __len__(self):
        return self.len


dataset = EditedFeaturesAndGenerations(run_dir)
data_loader = DataLoader(
    dataset,
    collate_fn=lambda batch: tuple(map(np.array, zip(*batch))),
    batch_size=batch_size,
    shuffle=shuffle,
)
path, generation, original, features, label = next(iter(data_loader))
print(path.shape, generation.shape, original.shape, features.shape, label.shape)
```
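The index arithmetic in `__getitem__` maps a flat dataset index onto a file and a generation within that file. A minimal worked example with made-up numbers (10 generations per file; the values are hypothetical, chosen only for illustration):
```{python}
# Worked example of the __getitem__ index arithmetic; numbers are hypothetical.
gpf = 10  # generations_per_file
for idx in (0, 9, 10, 25):
    print(f"local idx {idx:2d} -> file {idx // gpf}, generation {idx % gpf}")
```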
```{python}
_, gen_img, orig_img, _, _ = next(dataset.iter_unedited())
print(orig_img.max(), gen_img.max())
print(orig_img.min(), gen_img.min())
```
```{python}
import jo3mnist
import matplotlib.pyplot as plt
_, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(jo3mnist.to_img(orig_img))
axes[1].imshow(jo3mnist.to_img(gen_img))
for ax in axes:
    ax.axis("off")
plt.show()
```
Load the CNN and the SAE, then sow the CNN so its intermediate activations can be intercepted.
```{python}
import tomllib
# The config.toml in the features2image_diffusion run_dir records which SAE run it used.
sae_run_dir = features2image_diffusion_dir / tomllib.loads(
    (run_dir / "config.toml").read_text()
)["eval"][0]["feature_dir"]
# The sae_run_dir holds the configuration information we need to know which CNN
# and SAE to load, and at which layer to intercept the CNN activations.
sae_config = tomllib.loads((sae_run_dir / "config.toml").read_text())
cnn_path = sae_config["cnn_storage"]
sae_path = sae_run_dir / "sae.eqx"
sow_layer = sae_config["sae"][0]["layer"]
print(cnn_path, sae_path, sow_layer, sep="\n")
```
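For reference, a minimal sketch of the config.toml structure these lookups rely on; the values are copied from the SAE config file deleted elsewhere in this commit, with the `cnn_storage` path abridged:
```{python}
# Sketch of the expected config structure; values copied from the config file
# shown elsewhere in this commit, cnn_storage path abridged.
example_config = tomllib.loads("""
cnn_storage = ".../res/cnn.eqx"
sae = [
    { layer = 6, hidden_size = 256, input_size = 64, learning_rate = 0.0001, l1 = 0.0003 },
]
""")
print(example_config["cnn_storage"], example_config["sae"][0]["layer"])
```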
```{python}
import jax
import equinox as eqx
import jo3util.eqx as jo3eqx
from src.sae import SAE
from src.cnn import CNN
sae = jo3eqx.load(sae_path, SAE)
# equinox wants to see an example cnn before loading the weights into it
cnn = CNN(jax.random.PRNGKey(0))
# load the weights into the example cnn
cnn = eqx.tree_deserialise_leaves(cnn_path, cnn)
# sow the cnn, so we can intercept intermediate activations every forward pass
cnn = jo3eqx.sow(lambda m: m.layers[sow_layer], cnn)
```
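A quick way to see what the sowed model returns is a dummy forward pass; a sketch, assuming MNIST-shaped 1×28×28 inputs (that shape is an assumption, not stated above):
```{python}
# Sketch: a sowed forward pass returns (intercepted activations, final output).
# The 1x28x28 input shape is an assumption based on the MNIST setting.
dummy = jnp.zeros((1, 28, 28))
act, out = cnn(dummy)
print(act.shape, out.shape)
```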
Now reconstruct the features from the generated images and compare them with the originals.
```{python}
path, gens, orig, orig_features, label = tuple(map(np.array, zip(*dataset.iter_unedited())))
print(path.shape, gens.shape, orig_features.shape)
# Pass the images generated by the diffusion through the cnn and fetch the intermediate activations.
activ, pred = jax.vmap(cnn)(gens) # pyright: ignore[reportArgumentType]
print(activ.shape, pred.shape)
# Reconstruct the features by encoding the intermediate activations with the SAE.
recon_features = jax.vmap(sae.encode)(activ)
print(recon_features.shape)
# Per-sample squared error between the original and the reconstructed features,
# summed over the feature dimension.
mse = jnp.sum((orig_features - recon_features)**2, axis=-1) # pyright: ignore[reportAssignmentType]
print(mse)
```
If we instead use images generated from edited features, the reconstruction error is larger.
```{python}
path, gens, orig, orig_features, label = next(iter(data_loader))
activ, pred = jax.vmap(cnn)(gens) # pyright: ignore[reportArgumentType]
recon_features = jax.vmap(sae.encode)(activ)
mse = jnp.sum((orig_features - recon_features)**2, axis=-1) # pyright: ignore[reportAssignmentType]
print(mse)
```
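A single summary number makes the two cases easier to compare; a minimal sketch, averaging over the batch (running the same line after the unedited cell above gives the baseline):
```{python}
# Mean reconstruction error over this batch of edited-feature generations.
print(float(jnp.mean(mse)))
```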

Binary file not shown.

Binary file not shown.


@@ -1,8 +0,0 @@
batch_size = 64
steps = 10000
print_every = 500
seed = 1
cnn_storage = "/home/jo3/p/sparse_autoencoder/sparse_autoencoder/../res/cnn.eqx"
sae = [
{ layer = 6, hidden_size = 256, input_size = 64, learning_rate = 0.0001, l1 = 0.0003 },
]


@@ -1,5 +0,0 @@
layer = 6
hidden_size = 256
input_size = 64
learning_rate = 0.0001
l1 = 0.0003

Some files were not shown because too many files have changed in this diff.