Ahora es el momento de algunos aprendizaje contrastivoPara aliviar el problema de las etiquetas de anotación insuficientes y explotar plenamente la gran cantidad de datos sin etiquetar, se podría utilizar el aprendizaje contrastivo para ayudar de manera efectiva a la columna vertebral a aprender representaciones de datos sin una tarea específica. La columna vertebral podría congelarse para una tarea posterior determinada y entrenar solo una red poco profunda en un conjunto de datos limitado y anotado para lograr resultados satisfactorios.
Los enfoques de aprendizaje contrastivo más utilizados incluyen SimCLR, SimSiam y MOCO (consulte mi artículo anterior sobre MOCO). Aquí comparamos SimCLR y SimSiam.
SimCLR calcula pares positivos y negativos en el lote de datos, lo que requiere una extracción negativa rigurosa, pérdida de NT-Xent (que extiende la pérdida de similitud del coseno en un lote) y un tamaño de lote grande. SimCLR también requiere que el optimizador LARS admita lotes de gran tamaño.
SimSiam, Sin embargo, SimSiam utiliza una arquitectura siamesa, que evita el uso de pares negativos y evita aún más la necesidad de lotes de gran tamaño. Las diferencias entre SimSiam y SimCLR se muestran en la siguiente tabla.
Podemos ver en la figura anterior que la arquitectura SimSiam contiene solo dos partes: el codificador/backbone y el predictor. Durante el tiempo de entrenamiento, se detiene la propagación del gradiente de la parte siamesa y se calcula la similitud del coseno entre las salidas de los predictores y la red troncal.
Entonces, ¿cómo implementar esta arquitectura en la realidad? Sigamos con el diseño de clasificación supervisada, mantener la misma estructura y solo modificar la capa MLP. En la arquitectura de aprendizaje supervisado, el MLP genera un vector de 10 elementos que indica las probabilidades de las 10 clases. Pero para SimSiam, el objetivo no es realizar una “clasificación” sino aprender la “representación”, por lo que necesitamos que la salida tenga la misma dimensión que la salida principal para calcular las pérdidas. Y la similitud_coseno_negativo se proporciona a continuación:
import torch.nn as nn
import matplotlib.pyplot as pltclass SimSiam(nn.Module):
def __init__(self):
super(SimSiam, self).__init__()
self.backbone = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.BatchNorm2d(32),
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.BatchNorm2d(64),
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
nn.ReLU(),
nn.BatchNorm2d(128),
)
self.prediction_mlp = nn.Sequential(nn.Linear(128*4*4, 64),
nn.BatchNorm1d(64),
nn.ReLU(),
nn.Linear(64, 128*4*4),
)
def forward(self, x):
x = self.backbone(x)
x = x.view(-1, 128 * 4 * 4)
pred_output = self.prediction_mlp(x)
return x, pred_output
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
def negative_cosine_similarity_stopgradient(pred, proj):
return -cos(pred, proj.detach()).mean()
El pseudocódigo para entrenar SimSiam se proporciona en el artículo original a continuación:
Y lo convertimos en código de entrenamiento real:
import tqdmimport torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import RandAugment
import wandb
wandb_config = {
"learning_rate": 0.0001,
"architecture": "simsiam",
"dataset": "FashionMNIST",
"epochs": 100,
"batch_size": 256,
}
wandb.init(
# set the wandb project where this run will be logged
project="simsiam",
# track hyperparameters and run metadata
config=wandb_config,
)
# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
simsiam = SimSiam()
random_augmenter = RandAugment(num_ops=5)
optimizer = optim.SGD(simsiam.parameters(),
lr=wandb_config["learning_rate"],
momentum=0.9,
weight_decay=1e-5,
)
train_dataloader = DataLoader(train_dataset, batch_size=wandb_config["batch_size"], shuffle=True)
# Training loop
for epoch in range(wandb_config["epochs"]):
simsiam.train()
print(f"Epoch {epoch}")
train_loss = 0
for batch_idx, (image, _) in enumerate(tqdm.tqdm(train_dataloader, total=len(train_dataloader))):
optimizer.zero_grad()
aug1, aug2 = random_augmenter((image*255).to(dtype=torch.uint8)).to(dtype=torch.float32) / 255.0, \
random_augmenter((image*255).to(dtype=torch.uint8)).to(dtype=torch.float32) / 255.0
proj1, pred1 = simsiam(aug1)
proj2, pred2 = simsiam(aug2)
loss = negative_cosine_similarity_stopgradient(pred1, proj2) / 2 + negative_cosine_similarity_stopgradient(pred2, proj1) / 2
loss.backward()
optimizer.step()
wandb.log({"training loss": loss})
if (epoch+1) % 10 == 0:
torch.save(simsiam.state_dict(), f"weights/simsiam_epoch{epoch+1}.pt")
Entrenamos en 100 épocas para una comparación justa con el entrenamiento supervisado limitado; La pérdida de entrenamiento se muestra a continuación. Nota: Debido a su diseño siamés, SimSiam podría ser muy sensible a hiperparámetros como la tasa de aprendizaje y las capas ocultas de MLP. El documento original de SimSiam proporciona una configuración detallada para la red troncal ResNet50. Para la red troncal basada en ViT, recomendamos leer el Documento MOCO v3que adopta el modelo SimSiam en un esquema de actualización de impulso.
A continuación, ejecutamos SimSiam entrenado en el conjunto de prueba y visualizamos las representaciones usando la reducción UMAP:
import tqdm
import numpy as npimport torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
simsiam = SimSiam()
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
simsiam.load_state_dict(torch.load("weights/simsiam_epoch100.pt"))
simsiam.eval()
simsiam.to(device)
features = []
labels = []
for batch_idx, (image, target) in enumerate(tqdm.tqdm(test_dataloader, total=len(test_dataloader))):
with torch.no_grad():
proj, pred = simsiam(image.to(device))
features.extend(np.squeeze(pred.detach().cpu().numpy()).tolist())
labels.extend(target.detach().cpu().numpy().tolist())
import plotly.express as px
import umap.umap_ as umap
reducer = umap.UMAP(n_components=3, n_neighbors=10, metric="cosine")
projections = reducer.fit_transform(np.array(features))
px.scatter(projections, x=0, y=1,
color=labels, labels={'color': 'Fashion MNIST Labels'}
)
Es interesante ver que hay dos pequeñas islas en el mapa de dimensiones reducidas de arriba: clases 5, 7, 8 y 9. Si extraemos la lista de clases FashionMNIST, sabemos que estas clases corresponden a zapatos como “Sandal”, “Baloncesto”, “Bolsa” y “Bota”. El gran grupo violeta corresponde a clases de ropa como «camiseta/top», «pantalones», «suéter», «vestido», «abrigo» y «camisa». SimSiam demuestra el aprendizaje de una representación significativa en el dominio de la visión.