Research project comparing the performance of unimodal (text-only) models against multimodal (text + image) models in classification tasks.
Research Question
When does adding visual information significantly improve text classification, and when is it just computational overhead?
Models Evaluated
Unimodal Models (Text)
| Model | Parameters | Description |
|---|---|---|
| BERT-base | 110M | Pretrained transformer encoder |
| RoBERTa | 125M | BERT with improved pretraining (more data, dynamic masking) |
| BETO | 110M | BERT pretrained on Spanish corpora |
Multimodal Models
| Model | Parameters | Modalities |
|---|---|---|
| CLIP | 400M | Image + Text (contrastive) |
| FLAVA | 350M | Image + Text (fusion) |
| ViLT | 113M | Image + Text (single-stream, patch embeddings, no CNN backbone) |
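For reference, a minimal sketch of how the two model families might be instantiated with Hugging Face `transformers`. The checkpoint names and the two-class head are illustrative assumptions, not necessarily the exact configurations used in the experiments.

```python
from transformers import AutoModelForSequenceClassification, CLIPModel, CLIPProcessor

# Unimodal baseline: BETO with a standard sequence-classification head
# (checkpoint name is an assumption for illustration).
text_model = AutoModelForSequenceClassification.from_pretrained(
    "dccuchile/bert-base-spanish-wwm-cased",  # BETO
    num_labels=2,
)

# Multimodal example: CLIP as a dual encoder; a classification head would be
# added on top of the pooled image/text embeddings.
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
```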
Methodology
Training Pipeline
```python
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset


class MultimodalTrainer:
    def __init__(self, model_name: str, dataset: Dataset, is_multimodal: bool = False):
        self.model = self._load_model(model_name)
        self.dataset = dataset
        self.is_multimodal = is_multimodal
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def train(self, epochs: int = 10, batch_size: int = 32):
        self.model.to(self.device)
        optimizer = AdamW(self.model.parameters(), lr=2e-5)
        dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)

        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0
            for batch in dataloader:
                # Forward pass: multimodal models also receive pixel values
                if self.is_multimodal:
                    outputs = self.model(
                        input_ids=batch['input_ids'].to(self.device),
                        pixel_values=batch['pixel_values'].to(self.device),
                        labels=batch['labels'].to(self.device)
                    )
                else:
                    outputs = self.model(
                        input_ids=batch['input_ids'].to(self.device),
                        labels=batch['labels'].to(self.device)
                    )
                loss = outputs.loss
                total_loss += loss.item()

                # Backward pass
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            print(f"Epoch {epoch}: Loss = {total_loss / len(dataloader):.4f}")
```
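A hypothetical invocation (dataset construction omitted); the model name and the `hateful_memes_train` dataset object are placeholders, assuming a preprocessed PyTorch dataset that yields the tensors shown above:

```python
# Illustrative usage only; 'flava' and hateful_memes_train are placeholders.
trainer = MultimodalTrainer('flava', dataset=hateful_memes_train, is_multimodal=True)
trainer.train(epochs=10, batch_size=16)
```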
Datasets Used
- Hateful Memes (Facebook): hate-content detection in memes, where the label depends on the text-image combination (a minimal loading sketch follows this list)
- SNLI-VE: visual entailment (does the image premise support the textual hypothesis?)
- VQA v2: Question answering about images
- Custom News Dataset: Chilean news classification with images
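A minimal loading sketch for Hateful Memes, assuming the standard jsonl release layout (one record per line with `img`, `text`, and `label` fields); field names and paths should be checked against the actual download. It reuses the `MultimodalProcessor` defined in the Preprocessing section below.

```python
import json
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import Dataset


class HatefulMemesDataset(Dataset):
    """Sketch: pairs each meme image with its text and binary label."""

    def __init__(self, jsonl_path: str, image_root: str, processor):
        self.records = [json.loads(line) for line in open(jsonl_path, encoding='utf-8')]
        self.image_root = Path(image_root)
        self.processor = processor  # e.g. MultimodalProcessor (see Preprocessing)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        image = Image.open(self.image_root / record['img']).convert('RGB')
        inputs = self.processor.process(record['text'], image)
        # Drop the batch dimension added by return_tensors='pt'
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs['labels'] = torch.tensor(record['label'])
        return inputs
```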
Preprocessing
```python
from typing import Dict

from PIL import Image
from transformers import AutoTokenizer, AutoImageProcessor


class MultimodalProcessor:
    def __init__(self, model_name: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.image_processor = AutoImageProcessor.from_pretrained(model_name)

    def process(self, text: str, image: Image.Image) -> Dict:
        # Tokenize text
        text_inputs = self.tokenizer(
            text,
            max_length=512,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        # Process image
        image_inputs = self.image_processor(
            image,
            return_tensors='pt'
        )
        return {
            'input_ids': text_inputs['input_ids'],
            'attention_mask': text_inputs['attention_mask'],
            'pixel_values': image_inputs['pixel_values']
        }
```
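A brief usage sketch; the checkpoint name and file path are assumptions for illustration:

```python
from PIL import Image

# "facebook/flava-full" and "meme.png" are illustrative, not the exact setup used here.
processor = MultimodalProcessor("facebook/flava-full")
example = processor.process("Text of the post", Image.open("meme.png").convert("RGB"))
print({k: v.shape for k, v in example.items()})
```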
CUDA Optimizations
Mixed Precision Training
```python
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    # Run the forward pass in float16 where it is numerically safe
    with autocast():
        outputs = model(**batch)
        loss = outputs.loss
    # Scale the loss to avoid gradient underflow in float16
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
Gradient Checkpointing
```python
# Reduce memory usage in large models: activations are recomputed during the backward pass
model.gradient_checkpointing_enable()
```
Multi-GPU with DataParallel
```python
# Replicate the model on each visible GPU and split every batch across them
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
    print(f"Using {torch.cuda.device_count()} GPUs")
```
Results
Hateful Memes Dataset
| Model | Accuracy | F1 | AUC-ROC |
|---|---|---|---|
| BERT | 62.3% | 0.58 | 0.67 |
| CLIP | 68.7% | 0.65 | 0.74 |
| FLAVA | 71.2% | 0.68 | 0.77 |
| ViLT | 69.5% | 0.66 | 0.75 |
News Classification (Custom)
| Model | Accuracy | F1 Macro |
|---|---|---|
| BETO | 87.2% | 0.86 |
| CLIP | 85.4% | 0.84 |
| FLAVA | 86.8% | 0.85 |
Conclusions
- Multimodal helps when the image is relevant: on Hateful Memes, where meaning depends on the text-image combination, FLAVA outperforms BERT by ~9 percentage points in accuracy.
- Text is sufficient in many cases: in news classification, where the images are generic, BETO matches or exceeds the multimodal models.
- Computational trade-off: in our runs, multimodal models required ~4x more VRAM and ~3x more training time (a measurement sketch follows the recommendation below).
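A minimal sketch of how epoch time and peak VRAM can be measured to reproduce the trade-off comparison; the numbers above come from the actual training runs, this snippet only illustrates the measurement approach, and `measure_epoch` is a hypothetical helper.

```python
import time
import torch

def measure_epoch(trainer: "MultimodalTrainer") -> None:
    """Report wall-clock time and peak GPU memory for one training epoch."""
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    trainer.train(epochs=1)
    elapsed = time.perf_counter() - start
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"epoch time: {elapsed:.1f}s, peak VRAM: {peak_gb:.2f} GiB")
```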
Recommendation:

```
├── Image is part of meaning → Multimodal (FLAVA)
├── Image is decorative      → Unimodal (BERT/BETO)
└── Limited resources        → ViLT (more efficient)
```
Code and Reproducibility
The repository includes:

- Training scripts for all models
- Results analysis notebooks
- Hyperparameter configurations
- Trained model checkpoints
```bash
# Train FLAVA on Hateful Memes
python train.py \
    --model flava \
    --dataset hateful_memes \
    --epochs 10 \
    --batch_size 16 \
    --learning_rate 2e-5
```