Contrastive Learning
Contrastive learning is a powerful self-supervised learning technique that has revolutionized how we train deep learning models without labeled data. This tutorial covers the fundamentals, architectures, and practical implementations using PyTorch.
Introduction
Contrastive learning is a machine learning technique that learns representations by contrasting positive pairs against negative pairs. The core idea is simple yet powerful: similar data points should have similar representations, while dissimilar points should have different representations.
Contrastive learning teaches models to understand what makes things similar or different rather than memorizing specific labels. This makes it particularly effective for self-supervised learning scenarios.
Why Contrastive Learning?
Traditional supervised learning requires large amounts of labeled data, which is:
- Expensive to obtain
- Time-consuming to annotate
- Sometimes impossible to label accurately
Contrastive learning addresses these challenges by:
- Learning from unlabeled data
- Creating powerful representations
- Achieving performance competitive with or superior to supervised methods
- Reducing dependency on large labeled datasets
Core Concepts
1. Similarity Learning
The foundation of contrastive learning is measuring similarity between data points, typically with cosine similarity (a short example follows this list). We want to:
- Pull together representations of similar samples (positive pairs)
- Push apart representations of dissimilar samples (negative pairs)
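A minimal sketch of this idea using cosine similarity (the tensors below are random stand-ins for learned embeddings):

import torch
import torch.nn.functional as F

a = torch.randn(128)
b = a + 0.1 * torch.randn(128)  # slight perturbation of a -> plays the "positive" for a
c = torch.randn(128)            # unrelated vector -> plays a "negative" for a

sim_pos = F.cosine_similarity(a, b, dim=0)  # close to 1
sim_neg = F.cosine_similarity(a, c, dim=0)  # close to 0
print(f"positive pair: {sim_pos:.3f}, negative pair: {sim_neg:.3f}")

A contrastive loss rewards the model when the first number is high and the second is low.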
2. Positive and Negative Pairs
Positive Pairs: Two augmented versions of the same sample
Original Image → [Crop, Flip]         → Augmented View 1
Original Image → [Color Jitter, Blur] → Augmented View 2
Negative Pairs: Different samples in the batch
Image A → Representation A
Image B → Representation B (different from A)
3. The Contrastive Loss Landscape
The loss function creates a landscape where:
- Similar samples cluster together
- Different samples spread apart
- The model learns meaningful features automatically
Common Loss Functions
1. InfoNCE Loss (Noise-Contrastive Estimation)
InfoNCE is one of the most popular contrastive losses, used in methods like SimCLR and MoCo.
Mathematical Formulation:
ℓ(i,j) = -log[ exp(sim(z_i, z_j)/τ) / Σ_{k≠i} exp(sim(z_i, z_k)/τ) ]
Where:
- z_i, z_j are the representations of a positive pair
- the sum over k runs over all other samples in the batch (the positive and every negative)
- sim(·,·) is the similarity function (usually cosine similarity)
- τ is the temperature parameter
PyTorch Implementation:
import torch
import torch.nn.functional as F
def info_nce_loss(features, temperature=0.5):
"""
InfoNCE Loss for contrastive learning.
Args:
features: [batch_size * 2, embedding_dim] tensor
(contains pairs of augmented views)
temperature: Temperature parameter for scaling
Returns:
loss: Scalar loss value
"""
batch_size = features.shape[0] // 2
# Normalize features
features = F.normalize(features, dim=1)
# Compute similarity matrix
similarity_matrix = torch.matmul(features, features.T)
# Create labels: positive pairs are at indices (i, i+batch_size)
labels = torch.cat([torch.arange(batch_size) + batch_size,
torch.arange(batch_size)]).to(features.device)
# Mask to remove self-similarity
mask = torch.eye(batch_size * 2, dtype=torch.bool).to(features.device)
similarity_matrix = similarity_matrix.masked_fill(mask, -9e15)
# Apply temperature scaling
similarity_matrix = similarity_matrix / temperature
# Compute loss using cross-entropy
loss = F.cross_entropy(similarity_matrix, labels)
return loss
# Example usage
batch_size = 64
embedding_dim = 128
# Simulated features from two augmented views
features = torch.randn(batch_size * 2, embedding_dim)
loss = info_nce_loss(features)
print(f"InfoNCE Loss: {loss.item():.4f}")
2. NT-Xent Loss (Normalized Temperature-scaled Cross-Entropy)
NT-Xent is used in SimCLR and is essentially InfoNCE with specific normalization.
import torch
import torch.nn as nn
import torch.nn.functional as F
class NTXentLoss(nn.Module):
"""
NT-Xent (Normalized Temperature-scaled Cross Entropy) Loss.
Used in SimCLR paper.
"""
def __init__(self, temperature=0.5):
super(NTXentLoss, self).__init__()
self.temperature = temperature
def forward(self, z_i, z_j):
"""
Args:
z_i: [batch_size, embedding_dim] - First view embeddings
z_j: [batch_size, embedding_dim] - Second view embeddings
"""
batch_size = z_i.shape[0]
# Normalize
z_i = F.normalize(z_i, dim=1)
z_j = F.normalize(z_j, dim=1)
# Concatenate both views
representations = torch.cat([z_i, z_j], dim=0)
# Compute similarity matrix
similarity_matrix = torch.mm(representations, representations.T)
# Create positive pairs mask
# For index i, positive pair is at i + batch_size (and vice versa)
positives = torch.cat([
torch.arange(batch_size, 2 * batch_size),
torch.arange(batch_size)
]).to(z_i.device)
        # Get positive similarity scores
        positives_sim = similarity_matrix[torch.arange(2 * batch_size), positives].unsqueeze(1)
        # Mask out self-similarity AND the positive pair, so the remaining
        # columns contain negatives only (otherwise the positive would be
        # counted twice in the logits)
        mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z_i.device)
        mask[torch.arange(2 * batch_size), positives] = True
        negatives_sim = similarity_matrix.masked_fill(mask, -9e15)
# Concatenate positive and negative similarities
logits = torch.cat([positives_sim, negatives_sim], dim=1)
# Apply temperature
logits = logits / self.temperature
# Labels: positive is always at index 0
labels = torch.zeros(2 * batch_size, dtype=torch.long).to(z_i.device)
# Cross-entropy loss
loss = F.cross_entropy(logits, labels)
return loss
# Example usage
criterion = NTXentLoss(temperature=0.5)
z_i = torch.randn(32, 128) # First augmented view
z_j = torch.randn(32, 128) # Second augmented view
loss = criterion(z_i, z_j)
print(f"NT-Xent Loss: {loss.item():.4f}")
3. Triplet Loss
Triplet loss works with anchor, positive, and negative samples.
import torch
import torch.nn as nn
import torch.nn.functional as F
class TripletLoss(nn.Module):
"""
Triplet Loss for contrastive learning.
"""
def __init__(self, margin=1.0):
super(TripletLoss, self).__init__()
self.margin = margin
def forward(self, anchor, positive, negative):
"""
Args:
anchor: [batch_size, embedding_dim]
positive: [batch_size, embedding_dim]
negative: [batch_size, embedding_dim]
"""
# Calculate distances
distance_positive = F.pairwise_distance(anchor, positive, p=2)
distance_negative = F.pairwise_distance(anchor, negative, p=2)
# Triplet loss with margin
losses = F.relu(distance_positive - distance_negative + self.margin)
return losses.mean()
# Example usage
criterion = TripletLoss(margin=1.0)
anchor = torch.randn(32, 128)
positive = torch.randn(32, 128)
negative = torch.randn(32, 128)
loss = criterion(anchor, positive, negative)
print(f"Triplet Loss: {loss.item():.4f}")
Popular Architectures
1. SimCLR (Simple Framework for Contrastive Learning)
SimCLR is one of the most influential contrastive learning frameworks. It uses:
- Strong data augmentation
- Large batch sizes
- NT-Xent loss
- A projection head (MLP) after the encoder
Architecture Overview:
Input Image
↓
[Data Augmentation]
↓ ↓
View 1 View 2
↓ ↓
[Encoder (ResNet)]
↓ ↓
h_i h_j
↓ ↓
[Projection Head (MLP)]
↓ ↓
z_i z_j
↓ ↓
[Contrastive Loss]
PyTorch Implementation:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
class SimCLR(nn.Module):
"""
SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
"""
def __init__(self, base_encoder='resnet50', projection_dim=128, hidden_dim=2048):
super(SimCLR, self).__init__()
# Base encoder (typically ResNet)
if base_encoder == 'resnet50':
            self.encoder = models.resnet50(weights=None)  # random init (torchvision >= 0.13 API)
encoder_dim = self.encoder.fc.in_features
# Remove the final classification layer
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
else:
raise NotImplementedError(f"Encoder {base_encoder} not implemented")
# Projection head (MLP with one hidden layer)
self.projection_head = nn.Sequential(
nn.Linear(encoder_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, projection_dim)
)
def forward(self, x):
"""
Args:
x: [batch_size, 3, H, W] - Input images
Returns:
z: [batch_size, projection_dim] - Projected embeddings
"""
# Get representation from encoder
h = self.encoder(x)
h = torch.flatten(h, start_dim=1)
# Project to contrastive space
z = self.projection_head(h)
return z
def get_representation(self, x):
"""
Get the encoder representation (without projection head).
Used for downstream tasks.
"""
h = self.encoder(x)
h = torch.flatten(h, start_dim=1)
return h
# Example usage
model = SimCLR(base_encoder='resnet50', projection_dim=128)
# Simulated batch of augmented views
batch_size = 32
view1 = torch.randn(batch_size, 3, 224, 224)
view2 = torch.randn(batch_size, 3, 224, 224)
# Forward pass
z1 = model(view1)
z2 = model(view2)
print(f"View 1 embeddings shape: {z1.shape}")
print(f"View 2 embeddings shape: {z2.shape}")
# Compute loss
criterion = NTXentLoss(temperature=0.5)
loss = criterion(z1, z2)
print(f"SimCLR Loss: {loss.item():.4f}")
2. MoCo (Momentum Contrast)
MoCo uses a queue of negative samples and a momentum encoder for stable training.
import copy

import torch
import torch.nn as nn
import torchvision.models as models
class MoCo(nn.Module):
"""
Momentum Contrast for Unsupervised Visual Representation Learning
"""
def __init__(self, base_encoder='resnet50', projection_dim=128,
queue_size=65536, momentum=0.999, temperature=0.07):
super(MoCo, self).__init__()
self.queue_size = queue_size
self.momentum = momentum
self.temperature = temperature
# Query encoder
if base_encoder == 'resnet50':
            encoder = models.resnet50(weights=None)  # random init (torchvision >= 0.13 API)
encoder_dim = encoder.fc.in_features
encoder = nn.Sequential(*list(encoder.children())[:-1])
else:
raise NotImplementedError
        self.encoder_q = encoder
        # The key encoder must be a separate copy, not an alias of the query
        # encoder; otherwise freezing it below would also freeze the query encoder
        self.encoder_k = copy.deepcopy(encoder)
# Projection heads
self.projection_q = nn.Sequential(
nn.Linear(encoder_dim, encoder_dim),
nn.ReLU(),
nn.Linear(encoder_dim, projection_dim)
)
self.projection_k = nn.Sequential(
nn.Linear(encoder_dim, encoder_dim),
nn.ReLU(),
nn.Linear(encoder_dim, projection_dim)
)
# Initialize key encoder with query encoder weights
for param_q, param_k in zip(self.encoder_q.parameters(),
self.encoder_k.parameters()):
param_k.data.copy_(param_q.data)
param_k.requires_grad = False # Key encoder doesn't get gradient
for param_q, param_k in zip(self.projection_q.parameters(),
self.projection_k.parameters()):
param_k.data.copy_(param_q.data)
param_k.requires_grad = False
# Queue for negative samples
self.register_buffer("queue", torch.randn(projection_dim, queue_size))
self.queue = nn.functional.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
@torch.no_grad()
def _momentum_update_key_encoder(self):
"""
Momentum update of the key encoder
"""
for param_q, param_k in zip(self.encoder_q.parameters(),
self.encoder_k.parameters()):
param_k.data = param_k.data * self.momentum + param_q.data * (1. - self.momentum)
for param_q, param_k in zip(self.projection_q.parameters(),
self.projection_k.parameters()):
param_k.data = param_k.data * self.momentum + param_q.data * (1. - self.momentum)
@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
"""
Update queue with new keys
"""
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
# Replace the oldest batch in the queue
if ptr + batch_size <= self.queue_size:
self.queue[:, ptr:ptr + batch_size] = keys.T
else:
# Wrap around if necessary
remaining = self.queue_size - ptr
self.queue[:, ptr:] = keys[:remaining].T
self.queue[:, :batch_size - remaining] = keys[remaining:].T
ptr = (ptr + batch_size) % self.queue_size
self.queue_ptr[0] = ptr
def forward(self, im_q, im_k):
"""
Args:
im_q: Query images [batch_size, 3, H, W]
im_k: Key images [batch_size, 3, H, W]
Returns:
logits: [batch_size, 1 + queue_size]
labels: [batch_size]
"""
# Compute query features
h_q = self.encoder_q(im_q)
h_q = torch.flatten(h_q, start_dim=1)
q = self.projection_q(h_q)
q = nn.functional.normalize(q, dim=1)
# Compute key features (no gradient)
with torch.no_grad():
self._momentum_update_key_encoder()
h_k = self.encoder_k(im_k)
h_k = torch.flatten(h_k, start_dim=1)
k = self.projection_k(h_k)
k = nn.functional.normalize(k, dim=1)
# Compute logits
# Positive logits: [batch_size, 1]
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
# Negative logits: [batch_size, queue_size]
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
# Logits: [batch_size, 1 + queue_size]
logits = torch.cat([l_pos, l_neg], dim=1)
# Apply temperature
logits /= self.temperature
# Labels: positives are at index 0
labels = torch.zeros(logits.shape[0], dtype=torch.long).to(logits.device)
# Update queue
self._dequeue_and_enqueue(k)
return logits, labels
# Example usage
model = MoCo(base_encoder='resnet50', projection_dim=128)
batch_size = 32
im_q = torch.randn(batch_size, 3, 224, 224)
im_k = torch.randn(batch_size, 3, 224, 224)
logits, labels = model(im_q, im_k)
print(f"Logits shape: {logits.shape}")
print(f"Labels shape: {labels.shape}")
# Compute loss
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)
print(f"MoCo Loss: {loss.item():.4f}")
3. BYOL (Bootstrap Your Own Latent)
BYOL is unique - it doesn't use negative pairs!
import copy

import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
class BYOL(nn.Module):
"""
Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning
"""
def __init__(self, base_encoder='resnet50', projection_dim=256,
hidden_dim=4096, momentum=0.996):
super(BYOL, self).__init__()
self.momentum = momentum
# Online network
if base_encoder == 'resnet50':
            encoder = models.resnet50(weights=None)  # random init (torchvision >= 0.13 API)
encoder_dim = encoder.fc.in_features
encoder = nn.Sequential(*list(encoder.children())[:-1])
else:
raise NotImplementedError
self.online_encoder = encoder
# Online projector
self.online_projector = nn.Sequential(
nn.Linear(encoder_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, projection_dim)
)
# Online predictor (only in online network)
self.predictor = nn.Sequential(
nn.Linear(projection_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, projection_dim)
)
        # Target network (no gradients). It must be a deep copy, not an
        # alias of the online encoder
        self.target_encoder = copy.deepcopy(encoder)
self.target_projector = nn.Sequential(
nn.Linear(encoder_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, projection_dim)
)
# Initialize target network with online network weights
for param_o, param_t in zip(self.online_encoder.parameters(),
self.target_encoder.parameters()):
param_t.data.copy_(param_o.data)
param_t.requires_grad = False
for param_o, param_t in zip(self.online_projector.parameters(),
self.target_projector.parameters()):
param_t.data.copy_(param_o.data)
param_t.requires_grad = False
@torch.no_grad()
def _update_target_network(self):
"""
Momentum update of target network
"""
for param_o, param_t in zip(self.online_encoder.parameters(),
self.target_encoder.parameters()):
param_t.data = param_t.data * self.momentum + param_o.data * (1. - self.momentum)
for param_o, param_t in zip(self.online_projector.parameters(),
self.target_projector.parameters()):
param_t.data = param_t.data * self.momentum + param_o.data * (1. - self.momentum)
def forward(self, x1, x2):
"""
Args:
x1: First augmented view [batch_size, 3, H, W]
x2: Second augmented view [batch_size, 3, H, W]
Returns:
loss: BYOL loss
"""
# Online network forward pass for both views
h1_online = self.online_encoder(x1)
h1_online = torch.flatten(h1_online, start_dim=1)
z1_online = self.online_projector(h1_online)
p1 = self.predictor(z1_online)
h2_online = self.online_encoder(x2)
h2_online = torch.flatten(h2_online, start_dim=1)
z2_online = self.online_projector(h2_online)
p2 = self.predictor(z2_online)
# Target network forward pass (no gradients)
with torch.no_grad():
self._update_target_network()
h1_target = self.target_encoder(x1)
h1_target = torch.flatten(h1_target, start_dim=1)
z1_target = self.target_projector(h1_target)
h2_target = self.target_encoder(x2)
h2_target = torch.flatten(h2_target, start_dim=1)
z2_target = self.target_projector(h2_target)
# Compute loss (mean squared error with L2 normalization)
loss1 = self._loss_fn(p1, z2_target)
loss2 = self._loss_fn(p2, z1_target)
loss = (loss1 + loss2) / 2
return loss
    def _loss_fn(self, x, y):
        """
        BYOL loss: negative cosine similarity.
        For L2-normalized vectors, ||x - y||^2 = 2 - 2 * <x, y>, so this is
        the mean squared error between normalized prediction and target.
        """
        x = F.normalize(x, dim=-1, p=2)
        y = F.normalize(y, dim=-1, p=2)
        return 2 - 2 * (x * y).sum(dim=-1).mean()
# Example usage
model = BYOL(base_encoder='resnet50', projection_dim=256)
batch_size = 32
x1 = torch.randn(batch_size, 3, 224, 224)
x2 = torch.randn(batch_size, 3, 224, 224)
loss = model(x1, x2)
print(f"BYOL Loss: {loss.item():.4f}")
Data Augmentation for Contrastive Learning
Data augmentation is crucial for contrastive learning. Here are common augmentations:
import torch
import torchvision.transforms as transforms
from PIL import Image
class ContrastiveTransform:
"""
Transform for contrastive learning.
Applies two different random augmentations to create positive pairs.
"""
def __init__(self, image_size=224):
# Color jitter parameters
color_jitter = transforms.ColorJitter(
brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2
)
        # Single augmentation pipeline, applied independently twice to
        # produce the two views of a positive pair
self.transform = transforms.Compose([
transforms.RandomResizedCrop(size=image_size, scale=(0.2, 1.0)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomApply([color_jitter], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __call__(self, x):
"""
Apply two different augmentations
"""
return self.transform(x), self.transform(x)
# Example usage
transform = ContrastiveTransform(image_size=224)
# Simulated PIL image
image = Image.new('RGB', (256, 256), color='red')
# Get two augmented views
view1, view2 = transform(image)
print(f"View 1 shape: {view1.shape}")
print(f"View 2 shape: {view2.shape}")
Complete Training Example
Here's a complete training loop for SimCLR:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
class SimCLRTrainer:
"""
Complete trainer for SimCLR
"""
def __init__(self, model, device='cuda', learning_rate=0.001,
temperature=0.5, batch_size=256, epochs=100):
self.model = model.to(device)
self.device = device
self.temperature = temperature
self.batch_size = batch_size
self.epochs = epochs
# Optimizer
self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
# Loss function
self.criterion = NTXentLoss(temperature=temperature)
def train_epoch(self, dataloader):
"""
Train for one epoch
"""
self.model.train()
total_loss = 0
for batch_idx, ((view1, view2), _) in enumerate(dataloader):
view1 = view1.to(self.device)
view2 = view2.to(self.device)
# Forward pass
z1 = self.model(view1)
z2 = self.model(view2)
# Compute loss
loss = self.criterion(z1, z2)
# Backward pass
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
total_loss += loss.item()
if batch_idx % 10 == 0:
print(f"Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")
return total_loss / len(dataloader)
def train(self, dataloader):
"""
Complete training loop
"""
for epoch in range(self.epochs):
print(f"\nEpoch {epoch+1}/{self.epochs}")
avg_loss = self.train_epoch(dataloader)
print(f"Average Loss: {avg_loss:.4f}")
# Save checkpoint
if (epoch + 1) % 10 == 0:
self.save_checkpoint(f"simclr_epoch_{epoch+1}.pth")
def save_checkpoint(self, filename):
"""
Save model checkpoint
"""
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
}, filename)
print(f"Checkpoint saved: {filename}")
# Data loading with contrastive transforms
class ContrastiveDataset(torch.utils.data.Dataset):
"""
Wrapper dataset that applies contrastive transforms
"""
def __init__(self, base_dataset, transform):
self.base_dataset = base_dataset
self.transform = transform
def __len__(self):
return len(self.base_dataset)
def __getitem__(self, idx):
image, label = self.base_dataset[idx]
view1, view2 = self.transform(image)
return (view1, view2), label
# Example training script
if __name__ == "__main__":
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Create transforms
transform = ContrastiveTransform(image_size=32)
# Load CIFAR-10 dataset
base_dataset = CIFAR10(root='./data', train=True, download=True)
dataset = ContrastiveDataset(base_dataset, transform)
# Create dataloader
dataloader = DataLoader(
dataset,
batch_size=256,
shuffle=True,
num_workers=4,
pin_memory=True,
drop_last=True
)
# Create model
model = SimCLR(base_encoder='resnet50', projection_dim=128)
# Create trainer
trainer = SimCLRTrainer(
model=model,
device=device,
learning_rate=0.001,
temperature=0.5,
batch_size=256,
epochs=100
)
# Train
print("Starting training...")
trainer.train(dataloader)
Evaluation and Downstream Tasks
After training with contrastive learning, we evaluate the learned representations:
Linear Evaluation Protocol
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
class LinearEvaluator:
"""
Linear evaluation protocol for contrastive learning
"""
def __init__(self, encoder, num_classes=10, device='cuda'):
self.encoder = encoder.to(device)
self.device = device
# Freeze encoder
for param in self.encoder.parameters():
param.requires_grad = False
# Get encoder output dimension
with torch.no_grad():
dummy_input = torch.randn(1, 3, 224, 224).to(device)
encoder_dim = self.encoder.get_representation(dummy_input).shape[1]
# Linear classifier
self.classifier = nn.Linear(encoder_dim, num_classes).to(device)
self.optimizer = optim.Adam(self.classifier.parameters(), lr=0.001)
self.criterion = nn.CrossEntropyLoss()
def train_epoch(self, dataloader):
"""
Train classifier for one epoch
"""
self.classifier.train()
total_loss = 0
correct = 0
total = 0
for images, labels in dataloader:
images = images.to(self.device)
labels = labels.to(self.device)
# Get frozen features
with torch.no_grad():
features = self.encoder.get_representation(images)
# Forward pass through classifier
outputs = self.classifier(features)
loss = self.criterion(outputs, labels)
# Backward pass
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# Statistics
total_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
accuracy = 100. * correct / total
avg_loss = total_loss / len(dataloader)
return avg_loss, accuracy
def evaluate(self, dataloader):
"""
Evaluate classifier
"""
self.classifier.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in dataloader:
images = images.to(self.device)
labels = labels.to(self.device)
features = self.encoder.get_representation(images)
outputs = self.classifier(features)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
accuracy = 100. * correct / total
return accuracy
# Example usage
if __name__ == "__main__":
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load pretrained SimCLR model
model = SimCLR(base_encoder='resnet50', projection_dim=128)
    # checkpoint = torch.load('simclr_epoch_100.pth')
    # model.load_state_dict(checkpoint['model_state_dict'])
# Prepare data
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform_test)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=4)
# Linear evaluation
evaluator = LinearEvaluator(model, num_classes=10, device=device)
print("Starting linear evaluation...")
for epoch in range(100):
train_loss, train_acc = evaluator.train_epoch(train_loader)
test_acc = evaluator.evaluate(test_loader)
print(f"Epoch {epoch+1}/100 - Train Loss: {train_loss:.4f}, "
f"Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%")
Applications of Contrastive Learning
1. Computer Vision
- Image Classification: Learn powerful representations for downstream classification
- Object Detection: Pretrain backbones for detection models
- Semantic Segmentation: Transfer learned features to segmentation tasks
2. Natural Language Processing
- Sentence Embeddings: Learn sentence representations (e.g., SimCSE)
- Document Similarity: Find similar documents in large corpora
- Information Retrieval: Improve search and retrieval systems
3. Multimodal Learning
- Vision-Language Models: CLIP, ALIGN - connecting images and text
- Audio-Visual Learning: Learning from audio and video simultaneously
- Medical Imaging: Learning from multiple imaging modalities
4. Specialized Domains
- Anomaly Detection: Detecting outliers in data
- Few-Shot Learning: Learning from limited labeled examples
- Domain Adaptation: Transferring knowledge across domains
Best Practices and Tips
1. Batch Size Matters
- Large batches are crucial: More negatives = better contrastive learning
- Use batch sizes of 256-4096 if possible
- Employ gradient accumulation if GPU memory is limited
# Gradient accumulation example
accumulation_steps = 4
optimizer.zero_grad()
for i, ((view1, view2), _) in enumerate(dataloader):  # matches ContrastiveDataset above
z1 = model(view1)
z2 = model(view2)
loss = criterion(z1, z2) / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
2. Temperature Parameter
- Lower temperature (0.1-0.5): Sharper distributions, more discriminative
- Higher temperature (0.5-1.0): Softer distributions, more exploration
- Typical values: 0.5 for SimCLR, 0.07 for MoCo (the snippet below shows the effect)
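A quick, self-contained illustration of how τ reshapes the softmax over similarity scores (the scores here are made-up values for the demo):

import torch
import torch.nn.functional as F

sims = torch.tensor([0.9, 0.5, 0.1])  # a query's similarity to three candidates

for tau in [0.07, 0.5, 1.0]:
    probs = F.softmax(sims / tau, dim=0)
    print(f"tau={tau}: {probs.tolist()}")
# Lower tau concentrates probability mass on the most similar candidate
# (sharper, more discriminative); higher tau spreads it out.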
3. Data Augmentation
- Strong augmentation is key: The quality of augmentations directly impacts performance
- Use multiple augmentations: crop, flip, color jitter, blur
- Avoid augmentations that change semantic meaning
4. Training Duration
- Contrastive learning requires longer training: 200-1000 epochs
- Be patient - representations improve gradually
- Monitor validation metrics to avoid overfitting
5. Evaluation
- Linear evaluation: Standard protocol to assess representation quality
- Fine-tuning: Can further improve performance on downstream tasks
- Transfer learning: Test on multiple datasets to assess generalization
Comparison of Methods
| Method | Key Feature | Pros | Cons |
|---|---|---|---|
| SimCLR | Large batches, strong augmentation | Simple, effective | Requires large batch sizes |
| MoCo | Momentum encoder, queue | Memory efficient | More complex implementation |
| BYOL | No negative pairs | Stable training | Requires careful tuning |
| SwAV | Clustering + contrastive | High performance | Complex architecture |
Common Pitfalls and Solutions
1. Collapsed Representations
Problem: All embeddings become identical
Solutions:
- Use proper normalization
- Ensure diverse augmentations
- Check temperature parameter
- Increase batch size (a quick collapse check is sketched below)
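A simple heuristic for spotting collapse during training (a common monitoring trick, not part of any of the methods above): track the per-dimension standard deviation of the normalized embeddings, which drops toward zero when all samples map to the same point.

import torch
import torch.nn.functional as F

@torch.no_grad()
def embedding_std(z):
    """Mean per-dimension std of L2-normalized embeddings.
    Roughly 1/sqrt(D) for well-spread D-dim embeddings; near 0 when collapsed."""
    z = F.normalize(z, dim=1)
    return z.std(dim=0).mean().item()

healthy = torch.randn(256, 128)
collapsed = torch.ones(256, 128) + 1e-4 * torch.randn(256, 128)
print(f"healthy:   {embedding_std(healthy):.4f}")    # ~ 1/sqrt(128) ≈ 0.088
print(f"collapsed: {embedding_std(collapsed):.4f}")  # ~ 0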
2. Training Instability
Problem: Loss oscillates or diverges
Solutions:
- Lower learning rate
- Use gradient clipping
- Check for NaN values
- Ensure proper initialization
# Gradient clipping: call between loss.backward() and optimizer.step()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
3. Poor Downstream Performance
Problem: Good contrastive loss but poor downstream accuracy
Solutions:
- Increase training epochs
- Improve augmentation strategy
- Adjust temperature parameter
- Try different architectures, or fine-tune the whole network (a sketch follows)
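If linear evaluation plateaus, full fine-tuning is the usual next step. A minimal sketch, assuming the SimCLR model defined earlier and a new (hypothetical) linear head; 2048 is ResNet-50's feature dimension and 10 the assumed number of classes:

import torch.nn as nn
import torch.optim as optim

# Unfreeze the pretrained encoder and attach a fresh task head
for param in model.encoder.parameters():
    param.requires_grad = True
classifier = nn.Linear(2048, 10)  # 2048 = ResNet-50 feature dim, 10 = num classes

# Discriminative learning rates: smaller for the pretrained encoder,
# larger for the randomly initialized head
optimizer = optim.Adam([
    {'params': model.encoder.parameters(), 'lr': 1e-4},
    {'params': classifier.parameters(), 'lr': 1e-3},
])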
Conclusion
Contrastive learning has revolutionized self-supervised learning by enabling models to learn powerful representations without labeled data. Key takeaways:
- Learn from similarity: Contrastive learning teaches models to distinguish similar from dissimilar samples
- Strong augmentations: Data augmentation is crucial for creating meaningful positive pairs
- Architecture matters: SimCLR, MoCo, and BYOL offer different trade-offs
- Evaluation is important: Use proper protocols to assess representation quality
- Patience required: Contrastive learning needs longer training than supervised learning
Further Reading
Seminal Papers
- SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
- MoCo: Momentum Contrast for Unsupervised Visual Representation Learning
- BYOL: Bootstrap Your Own Latent
- SwAV: Unsupervised Learning of Visual Features by Contrasting Cluster Assignments
Advanced Topics
- Barlow Twins: Self-Supervised Learning via Redundancy Reduction
- VICReg: Variance-Invariance-Covariance Regularization
- SimSiam: Exploring Simple Siamese Representation Learning
Multimodal
- CLIP: Learning Transferable Visual Models From Natural Language Supervision
- ALIGN: Scaling Up Visual and Vision-Language Representation Learning
Happy Learning! 🚀
This tutorial provides a comprehensive foundation for understanding and implementing contrastive learning. Experiment with different architectures, loss functions, and augmentation strategies to find what works best for your specific use case.