septiembre 17, 2021

~ 9 MIN

PerceiverIO - Clasificación de Imágenes

< Blog RSS

Open In Colab

PerceiverIO

En el post anterior vimos la arquitectura Perceiver, base del modelo más avanzado PerceiverIO, el cual aprenderemos a implementar en este post.

A grandes rasgos, PerceiverIO es muy similar a Perceiver. Utiliza exactamente la misma arquitectura para la codificación y procesado de datos. La única diferencia es en la salida, ya que introduce un decoder que es capaz de generar outputs para diferentes tareas. Gracias a este desacople entre entradas/salidas y el procesado de características, la red neuronal es capaz de llevar a cabo multitud de tareas en vez de tener que recurrir a modelos específicos (CNNs, RNNs, ...) en función del tipo de datos y tarea que queramos llevar a cabo.

# https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
# https://github.com/esceptico/perceiver-io/blob/master/src/perceiver/attention.py

import math
import torch

class MultiHeadAttention(torch.nn.Module):

    def __init__(self, kv_dim, q_dim, n_heads=1, attn_pdrop=0., resid_pdrop=0.):
        super().__init__()
        self.n_embd = q_dim
        self.n_heads = n_heads
        assert self.n_embd % self.n_heads == 0
        # key, query, value projections
        self.key = torch.nn.Linear(kv_dim, self.n_embd)
        self.query = torch.nn.Linear(q_dim, self.n_embd)
        self.value = torch.nn.Linear(kv_dim, self.n_embd)
        # regularization
        self.attn_drop = torch.nn.Dropout(attn_pdrop)
        self.resid_drop = torch.nn.Dropout(resid_pdrop)
        # output projection
        self.proj = torch.nn.Linear(self.n_embd, q_dim)

    def forward(self, kv, q, mask = None):
        B, M, C = kv.size()
        B, N, D = q.size()
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(kv).view(B, M, self.n_heads, D // self.n_heads).transpose(1, 2) # (B, nh, M, hs)
        q = self.query(q).view(B, N, self.n_heads, D // self.n_heads).transpose(1, 2) # (B, nh, N, hs)
        v = self.value(kv).view(B, M, self.n_heads, D // self.n_heads).transpose(1, 2) # (B, nh, M, hs)
        # attention (B, nh, N, hs) x (B, nh, hs, M) -> (B, nh, N, M)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        if mask is not None:
            att = att.masked_fill(self.mask[:,:,:N,:M] == 0, float('-inf'))
        att = torch.nn.functional.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v # (B, nh, N, M) x (B, nh, M, hs) -> (B, nh, N, hs)
        y = y.transpose(1, 2).contiguous().view(B, N, D) # re-assemble all head outputs side by side
        return self.resid_drop(self.proj(y)) # B, N, D

class Block(torch.nn.Module):
    def __init__(self, kv_dim, q_dim, n_heads=1, attn_pdrop=0., resid_pdrop=0.):
        super().__init__()
        self.ln1_kv = torch.nn.LayerNorm(kv_dim)
        self.ln1_q = torch.nn.LayerNorm(q_dim)
        self.ln2 = torch.nn.LayerNorm(q_dim)
        self.attn = MultiHeadAttention(kv_dim, q_dim, n_heads, attn_pdrop, resid_pdrop)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(q_dim, 4 * q_dim),
            torch.nn.GELU(),
            torch.nn.Linear(4 * q_dim, q_dim),
            torch.nn.Dropout(resid_pdrop),
        )

    def forward(self, kv, q, mask=None):
        x = q + self.attn(self.ln1_kv(kv), self.ln1_q(q), mask)
        x = x + self.mlp(self.ln2(x))
        return x

class PerceiverEncoder(torch.nn.Module):
    # ejemplo sin recurrencia para clasificación
    def __init__(self, num_latents, latent_dim, input_dim, num_blocks, n_heads=1, attn_pdrop=0., resid_pdrop=0.):
        # se podrían separar los params en función de si es cross o self ...
        super().__init__()
        self.num_blocks = num_blocks
        self.latents = torch.nn.Parameter(torch.randn(num_latents, latent_dim))
        # encode
        self.cross_attn = Block(
            kv_dim=input_dim,
            q_dim=latent_dim,
            n_heads=n_heads,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop
        )
        # process
        self.self_attention_blocks = torch.nn.ModuleList([
            Block( # se podrían hacer bloques diferenciados para en el forward pasar solo x
                kv_dim=latent_dim,
                q_dim=latent_dim,
                n_heads=n_heads,
                attn_pdrop=attn_pdrop,
                resid_pdrop=resid_pdrop
            ) for _ in range(num_blocks)
        ])

    def forward(self, x, mask = None):
        B = x.size(0)
        if mask is not None:
            mask = mask[None, None, :, :] # esto no se si está bien :S
        x = self.cross_attn(
            kv=x,
            q=self.latents.repeat(B, 1, 1),
            mask=mask
        )
        for _ in range(self.num_blocks):
            for self_attn_layer in self.self_attention_blocks:
                x = self_attn_layer(x, x)
        return x

Para este primer ejemplo vamos a implementar un decoder para clasificación de imágenes. Este decoder recibirá la salida del encoder, que ya implementamos en el post anterior, y que usaremos como keys y values. En cuanto a las query, usaremos un tensor con las dimensiones a las que queramos proyectar las salidas del encoder, en nuestro caso al número de clases a las cuales queramos hacer la clasificación. De la misma manera que el modelo aprenderá la mejor representación para las query a la entrada, lo mismo hacemos a la salida.

# https://github.com/esceptico/perceiver-io/blob/master/src/perceiver/decoders.py

class ClassificationDecoder(torch.nn.Module):
    def __init__(self, num_classes, latent_dim, n_heads=1, attn_pdrop=0., resid_pdrop=0.):
        super().__init__()
        self.task_ids = torch.nn.Parameter(torch.randn(1, num_classes))
        self.decoder = Block(
            kv_dim=latent_dim,
            q_dim=num_classes,
            n_heads=n_heads,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop
        )

    def forward(self, latents):
        b = latents.size(0)
        logits = self.decoder(
            kv=latents,
            q=self.task_ids.repeat(b, 1, 1)
        )
        return logits.squeeze(1)
decoder = ClassificationDecoder(10, 512)

decoder
ClassificationDecoder(
  (decoder): Block(
    (ln1_kv): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (ln1_q): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
    (ln2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
    (attn): MultiHeadAttention(
      (key): Linear(in_features=512, out_features=10, bias=True)
      (query): Linear(in_features=10, out_features=10, bias=True)
      (value): Linear(in_features=512, out_features=10, bias=True)
      (attn_drop): Dropout(p=0.0, inplace=False)
      (resid_drop): Dropout(p=0.0, inplace=False)
      (proj): Linear(in_features=10, out_features=10, bias=True)
    )
    (mlp): Sequential(
      (0): Linear(in_features=10, out_features=40, bias=True)
      (1): GELU()
      (2): Linear(in_features=40, out_features=10, bias=True)
      (3): Dropout(p=0.0, inplace=False)
    )
  )
)
output = decoder(torch.randn(32, 128, 512))

output.shape
torch.Size([32, 10])

Como puedes ver, podemos adaptar el decoder de manera muy sencilla para generar outputs de la dimensión que queramos, haciendo del PerceiverIO una arquitectura muy flexible. Podríamos tener diversos decoders para generar múltiples outputs o incluso entrenar un solo modelo con un solo decoder pero diferentes query para diferentes tareas (multi-tasking). Este hecho, junto a la flexibilidad a la hora de aceptar diferentes tipos de inputs que ya hereda del Perciever convierte a este modelo en uno de los más flexibles a día de hoy.

Clasificación de imágenes

Vamos a ver el mismo ejemplo que vimos en el post anterior: clasificación de imágenes con el dataset CIFAR10. Esta vez, con el nuevo modelo.

from einops import rearrange, repeat
from math import pi, log

class FourierEncoder(torch.nn.Module):
    def __init__(self, max_freq, num_freq_bands, freq_base=2):
        super().__init__()
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands
        self.freq_base = freq_base

    def fourier_encode(self, x):
        x = x.unsqueeze(-1)
        device, dtype, orig_x = x.device, x.dtype, x
        scales = torch.logspace(0., log(self.max_freq / 2) / log(self.freq_base), self.num_freq_bands, base = self.freq_base, device = device, dtype = dtype)
        scales = scales[(*((None,) * (len(x.shape) - 1)), ...)]
        x = x * scales * pi
        x = torch.cat([x.sin(), x.cos()], dim=-1)
        x = torch.cat((x, orig_x), dim = -1)
        return x

    def forward(self, x):
         # fourier encoding
        b, *axis, _, device = *x.shape, x.device
        axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
        pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1)
        enc_pos = self.fourier_encode(pos)
        enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
        enc_pos = repeat(enc_pos, '... -> b ...', b = b)
        x = torch.cat((x, enc_pos), dim = -1)
        x = rearrange(x, 'b ... d -> b (...) d')
        return x
class PerceiverIO(torch.nn.Module):
    def __init__(self, num_classes, max_freq ,num_freq_bands, num_latents, latent_dim, input_dim, num_blocks, freq_base = 2, n_heads=1, attn_pdrop=0., resid_pdrop=0.):
        super().__init__()
        fourier_channels = (2 * ((num_freq_bands * 2) + 1))
        input_dim += fourier_channels
        self.fourier_encoder = FourierEncoder(
            max_freq=max_freq,
            num_freq_bands=num_freq_bands,
            freq_base=freq_base
        )
        self.encoder = PerceiverEncoder(
            num_latents=num_latents,
            latent_dim=latent_dim,
            input_dim=input_dim,
            num_blocks=num_blocks,
            n_heads=n_heads,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop
        )
        self.decoder = ClassificationDecoder(
            num_classes=num_classes,
            latent_dim=latent_dim,
            n_heads=n_heads,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop
        )

    def forward(self, x, mask = None):
        x = self.fourier_encoder(x)
        if mask is not None:
            mask = mask[None, None, :, :] # esto no se si está bien :S
        x = self.encoder(x, mask)
        return self.decoder(x)
perceiver = PerceiverIO(num_classes=10, max_freq=10, num_freq_bands=6, num_latents=256, latent_dim=512, input_dim=3, num_blocks=2)

output = perceiver(torch.randn(32, 32, 32, 3))

output.shape
torch.Size([32, 10])
import torchvision
import numpy as np

class Dataset(torch.utils.data.Dataset):
    def __init__(self, train=True):
        trainset = torchvision.datasets.CIFAR10(root='./data', train=train, download=True)
        self.classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
        self.imgs, self.labels = np.array([np.array(i[0]) for i in trainset]), np.array([i[1] for i in trainset])
    def __len__(self):
        return len(self.imgs)
    def __getitem__(self, ix):
        img = self.imgs[ix]
        return torch.from_numpy(img / 255.).float(), torch.tensor(self.labels[ix]).long()
ds = {
    'train': Dataset(),
    'test': Dataset(train=False)
}

batch_size = 64
num_workers = 0

dl = {
    'train': torch.utils.data.DataLoader(ds['train'], batch_size=batch_size, shuffle=True, num_workers=num_workers),
    'test': torch.utils.data.DataLoader(ds['test'], batch_size=batch_size, shuffle=False, num_workers=num_workers)
}

imgs, labels = next(iter(dl['train']))
imgs.shape, labels.shape
Files already downloaded and verified
Files already downloaded and verified





(torch.Size([64, 32, 32, 3]), torch.Size([64]))
import matplotlib.pyplot as plt

fig = plt.figure(dpi=200)
c, r = 6, 4
for j in range(r):
    for i in range(c):
        ix = j*c + i
        ax = plt.subplot(r, c, ix + 1)
        img, label = imgs[ix], labels[ix]
        ax.imshow(img)
        ax.set_title(ds['train'].classes[labels[ix].item()], fontsize=6)
        ax.axis('off')
plt.tight_layout()
plt.show()

png

from tqdm import tqdm
import torch.nn.functional as F

def step(model, batch, device):
    x, y = batch
    x, y = x.to(device), y.to(device)
    y_hat = model(x)
    loss = F.cross_entropy(y_hat, y)
    acc = (torch.argmax(y_hat, axis=1) == y).sum().item() / y.size(0)
    return loss, acc

def train(model, dl, optimizer, epochs=10, device="cuda", use_amp=True, overfit=False):
    model.to(device)
    hist = {'loss': [], 'acc': [], 'test_loss': [], 'test_acc': []}
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    for e in range(1, epochs+1):
        # train
        model.train()
        l, a = [], []
        bar = tqdm(dl['train'])
        # como cambiar el shuffle a False si overfit ?
        for batch in bar:
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_amp):
                loss, acc = step(model, batch, device)
            scaler.scale(loss).backward()
            # gradient clipping
            #torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
            scaler.step(optimizer)
            scaler.update()
            l.append(loss.item())
            a.append(acc)
            if overfit:
                break
            bar.set_description(f"training... loss {np.mean(l):.4f} acc {np.mean(a):.4f}")
        hist['loss'].append(np.mean(l))
        hist['acc'].append(np.mean(a))
        # eval
        model.eval()
        l, a = [], []
        if overfit:
            bar = tqdm(dl['train'])
        else:
            bar = tqdm(dl['test'])
        with torch.inference_mode():
            for batch in bar:
                loss, acc = step(model, batch, device)
                l.append(loss.item())
                a.append(acc)
                if overfit:
                    break
                bar.set_description(f"testing... loss {np.mean(l):.4f} acc {np.mean(a):.4f}")
        hist['test_loss'].append(np.mean(l))
        hist['test_acc'].append(np.mean(a))
        if not overfit:
            # log
            log = f'Epoch {e}/{epochs}'
            for k, v in hist.items():
                log += f' {k} {v[-1]:.4f}'
            print(log)
    return hist
import pandas as pd

def plot_hist(hist):
    fig = plt.figure(figsize=(10, 3), dpi=100)
    df = pd.DataFrame(hist)
    ax = plt.subplot(1, 2, 1)
    df[['loss', 'test_loss']].plot(ax=ax)
    ax.grid(True)
    ax = plt.subplot(1, 2, 2)
    df[['acc', 'test_acc']].plot(ax=ax)
    ax.grid(True)
    plt.show()
model = PerceiverIO(num_classes=10, max_freq=5, num_freq_bands=3, num_latents=32, latent_dim=32, input_dim=3, num_blocks=3)


optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

dl = {
    'train': torch.utils.data.DataLoader(ds['train'], batch_size=512, shuffle=True, num_workers=0),
    'test': torch.utils.data.DataLoader(ds['test'], batch_size=512, shuffle=False, num_workers=0)
}
hist = train(model, dl, optimizer, epochs=100, overfit=False)
plot_hist(hist)

png

Tras 100 epochs nuestro modelo empieza a mostras síntomas de overffiting, podríamos pensar en activar el dropout del modelo, usar data augmentation, y otras técnicas de regularización.

imgs, labels = next(iter(dl['train']))

model.eval()
with torch.inference_mode():
    logits = model(imgs.cuda())
    preds = torch.argmax(logits, axis=1).cpu()

preds.shape
torch.Size([512])
fig = plt.figure(dpi=200)
c, r = 6, 4
for j in range(r):
    for i in range(c):
        ix = j*c + i
        ax = plt.subplot(r, c, ix + 1)
        img, label = imgs[ix], labels[ix]
        ax.imshow(img)
        gt = ds['test'].classes[labels[ix].item()]
        pr = ds['test'].classes[preds[ix].item()]
        ax.set_title(f'{pr}/{gt}', fontsize=6, color='green' if pr == gt else 'red')
        ax.axis('off')
plt.tight_layout()
plt.show()

png

Resumen

En este post hemos implementado nuestro primer modelo basado en la arquitectura PercieverIO. Para ello, hemos reutilizado el código implementado en el post anterior para el Perceiver y añadido la parte del decoder, capaz de generar outputs de dimensiones determinadas que podemos controlar. De esta manera, nuestro modelo desacopla totalmente las entradas/salidas de la parte de extracción y procesado de características convirtiendo al PerceiverIO en una red neuronal muy flexible capaz de trabajar con imágenes, texto, audio, video e incluso combinaciones de los mismos a la vez que puede usarse para diferentes tareas como la clasificación, generación de texto, detección de objetos en imágenes, etc. En los siguientes posts exploraremos más en detalle estas características.

Recursos

< Blog RSS