I decided I needed to play more. Specifically, I wanted an excuse to use Google's new T5Gemma2 encoder-decoder model, except I did not want to use the decoder. So I split the model off at the encoder and added a linear classifier head trained with BCE. For data, I picked an old Kaggle competition: https://www.kaggle.com/competitions/plant-pathology-2020-fgvc7/overview
The goal was not to beat the leaderboard but to walk through training, data manipulation, and handling of low-sample datasets. One of the most common criticisms of Kaggle is that entries are overfit. I won't be submitting, but I did want to put an accuracy bound on the top submissions so we had a rough target. The top entry sits at \~0.99 AUC; since I recorded F1 and accuracy as well, I wanted benchmarks for those too.
Beyond the 0.99 AUC, I found one of the top entries on GitHub. It predicted a single class with high confidence on 95.002% of the test rows, so our upper bound is \~95% accuracy. A lower bound that doesn't destroy the AUC would need to be above 85%. My own experiments support this range: with a static AUC, accuracy often swung \~2% between evaluation runs.
The first run was a naive baseline; I wanted to set up the environment and check basic assumptions. I rented an A6000 on Prime Intellect and ran with a batch size of 8, an LR of 3e-4 with cosine decay, and 5 epochs, using the 270m encoder of T5Gemma2. Since the competition's train set is the only labelled data, I split it 2:1 train:val. Only the classifier head was trained; the encoder weights were frozen. We ended with 86% accuracy, 0.514 F1, and 0.924 AUC.
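Mechanically, the baseline boils down to a frozen backbone feeding a trainable linear head under BCE. A minimal sketch of that wiring (the `FrozenEncoderClassifier` name and the toy encoder are stand-ins, not the actual T5Gemma2 plumbing):

```python
import torch
import torch.nn as nn

class FrozenEncoderClassifier(nn.Module):
    """Frozen feature extractor + trainable linear head, trained with BCE."""

    def __init__(self, encoder: nn.Module, hidden_size: int, num_labels: int):
        super().__init__()
        self.encoder = encoder
        for p in self.encoder.parameters():
            p.requires_grad = False  # only the head receives gradient updates
        self.head = nn.Linear(hidden_size, num_labels)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, x, labels=None):
        with torch.no_grad():
            hidden = self.encoder(x)    # (B, T, H) token features
        pooled = hidden.mean(dim=1)     # mean-pool over the sequence
        logits = self.head(pooled)
        loss = self.loss_fn(logits, labels) if labels is not None else None
        return loss, logits

# Toy stand-in: a per-token linear map acting as the "encoder".
encoder = nn.Linear(16, 32)
model = FrozenEncoderClassifier(encoder, hidden_size=32, num_labels=4)
x = torch.randn(8, 10, 16)               # batch of 8, 10 "tokens"
y = torch.randint(0, 2, (8, 4)).float()  # 4 binary labels per image
loss, logits = model(x, y)
```

The real script (below) pulls the encoder out of `AutoModelForSeq2SeqLM` and mean-pools with an attention mask, but the training surface is the same: one linear layer's worth of parameters.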
I then wanted to look at gradient noise (to pick a batch size) and at the impact of extending the dataset with augmentations. I used flips, reversals, shifts, blurs, and contrast adjustments. Unsurprisingly, the encoder is well trained and the augmentations had minimal effect; if the encoder were untrained, we would expect them to matter more. The gradient noise measurement, however, showed the SNR-optimal batch size was \~1.
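The batch-size question comes down to a simple two-batch gradient noise estimate: compare gradients from two independent batches against their average. A sketch of just the arithmetic (the full version, wired into the model, appears in the script at the end):

```python
import torch

def gradient_noise_scale(g1: torch.Tensor, g2: torch.Tensor, batch_size: int) -> float:
    """Two-batch gradient noise scale estimate.

    g1, g2: flattened gradients from two independent batches of equal size.
    Large values suggest noisy gradients (bigger batches would help);
    values near zero suggest the per-batch gradients mostly agree.
    """
    g = 0.5 * (g1 + g2)                     # averaged gradient estimate
    num = (g1 - g2).pow(2).sum()            # disagreement between batches
    denom = 2.0 * (g.pow(2).sum() + 1e-12)  # signal magnitude
    return float((num / denom) * batch_size)

# Identical gradients carry no noise at all.
g = torch.randn(1000)
assert gradient_noise_scale(g, g.clone(), batch_size=8) == 0.0
```

An SNR-optimal batch size near 1 means the noise term is tiny relative to the signal, which is consistent with a clean dataset or a very expressive frozen encoder.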
The batch size change made an immediate impact: F1 increased to 0.61, accuracy to 0.915, and AUC to 0.945. Again, there were periods where AUC sat at \~0.945 while accuracy was lower (\~0.900). The SNR and low optimal batch size indicate that the dataset is fairly clean, or that the base encoder is quite expressive; I think it is a bit of both. Regardless, the next step was to see if we could clean up the dataset a little more.
In the next run, I used the encoder to create embeddings of our training data, then used cosine similarity to flag any pairs crossing a 0.995 threshold (effectively the same image). We tagged 24 pairs corresponding to 23 images, and removed those images and labels from training. These could have been doubly diseased leaves, accidental duplications, or mislabelled duplicates; note that deduplication does not explicitly fix mislabelled data. Again, we saw training improvements: F1 increased to 0.63, accuracy to 92.6%, and AUC to 0.952.
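The dedup check itself is small: normalize the embeddings, take pairwise cosine similarities, and keep the upper triangle above the threshold. A compact version (the script below chunks the matrix product for memory; the function name here is illustrative):

```python
import torch
import torch.nn.functional as F

def near_duplicate_pairs(embeddings: torch.Tensor, threshold: float = 0.995):
    """Return index pairs (i, j) with i < j whose cosine similarity >= threshold."""
    emb = F.normalize(embeddings.float(), dim=1)
    sims = emb @ emb.T                   # (N, N) cosine similarities
    sims = torch.triu(sims, diagonal=1)  # keep upper triangle; drop self-sims
    idx = (sims >= threshold).nonzero(as_tuple=False)
    return [(int(i), int(j)) for i, j in idx]

# Toy check: rows 0 and 2 are identical, row 1 is different.
e = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
assert near_duplicate_pairs(e) == [(0, 2)]
```

For a couple thousand images the full N×N similarity matrix fits comfortably in memory; chunking only matters at larger scales.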
In terms of data, the next question was whether we could increase our gains by handling noisy labels. One method is to use gradient noise to identify image-label pairs that are difficult to learn; some are genuinely hard, but often they are mislabelled, which is why the model struggles with them. The other option is K-folds, where you soften the labels using trained models. Since this is a toy project and I am out of practice identifying apple diseases, I used 5 folds: I trained 5 models, each labelling a unique held-out ⅕ of the data, then combined these predicted labels with the “ground truth” labels at a 3:7 ratio. Explicitly, the model-predicted labels constituted 0.3 of the label used for training on the full set. This did not noticeably improve performance.
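The blending step is a single convex combination of the out-of-fold predictions and the one-hot labels. An illustrative sketch (the function name is mine, not from the script):

```python
import numpy as np

def blend_soft_labels(oof_probs: np.ndarray, hard_labels: np.ndarray,
                      model_weight: float = 0.3) -> np.ndarray:
    """soft = model_weight * OOF predictions + (1 - model_weight) * hard labels."""
    return model_weight * oof_probs + (1.0 - model_weight) * hard_labels

hard = np.array([[1.0, 0.0, 0.0, 0.0]])  # "ground truth" one-hot label
oof = np.array([[0.6, 0.1, 0.2, 0.1]])   # held-out model prediction
soft = blend_soft_labels(oof, hard)      # -> [[0.88, 0.03, 0.06, 0.03]]
```

Because each fold's model never sees its own validation slice, the predictions are out-of-fold and the softened targets don't leak the label being corrected.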
The final experiment for the 270m model was test-time augmentation (TTA) to improve the AUC. After the training-sample augmentations had no noticeable effect, I did not expect TTA to improve the scores either, but I wanted to give it a fair shot. I implemented 5 TTA passes mixing horizontal/vertical flips and brightness adjustments. It did not meaningfully increase performance.
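TTA here amounts to averaging sigmoid probabilities over a few views of each image. A minimal sketch with deterministic flip transforms (the tiny model is a placeholder; the script below randomizes its views instead):

```python
import torch
import torch.nn as nn

def tta_predict(model: nn.Module, x: torch.Tensor, transforms) -> torch.Tensor:
    """Average sigmoid probabilities over several augmented views of x."""
    model.eval()
    with torch.no_grad():
        probs = [torch.sigmoid(model(t(x))) for t in transforms]
    return torch.stack(probs).mean(dim=0)

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3 * 4 * 4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x.flatten(1))

transforms = [
    lambda x: x,                         # identity view
    lambda x: torch.flip(x, dims=[-1]),  # horizontal flip
    lambda x: torch.flip(x, dims=[-2]),  # vertical flip
]
net = TinyNet()
x = torch.randn(2, 3, 4, 4)  # (B, C, H, W) toy batch
probs = tta_predict(net, x, transforms)
```

Averaging probabilities (rather than logits) keeps each view's contribution bounded; the script's TTA path averages logits instead, which behaves similarly for well-calibrated heads.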
Scaling is always fun, so I ran one training run with the 4B model using my optimal setup (deduplication, batch size of 1). Interestingly, the larger model removed only 21 samples (as opposed to 23), and its measured optimal batch size was higher, at 3. I still ran with 1 because the noise test and the batch-size setting are not wired together, so I left some capacity on the table by not increasing the batch size.
If this were more than a toy demo, I would begin digging into these training examples:
[train] step=2410 epoch=3 loss=0.1239 lr=2.00e-04
[train] step=2420 epoch=3 loss=0.1199 lr=1.99e-04
[train] step=2430 epoch=3 loss=0.4373 lr=1.98e-04
[train] step=2440 epoch=3 loss=0.0358 lr=1.97e-04
[train] step=2450 epoch=3 loss=0.8764 lr=1.96e-04
[train] step=2460 epoch=3 loss=0.3298 lr=1.96e-04
[train] step=2470 epoch=3 loss=0.7009 lr=1.95e-04
[train] step=2480 epoch=3 loss=0.0971 lr=1.94e-04
[train] step=2490 epoch=3 loss=0.6982 lr=1.93e-04
[train] step=2500 epoch=3 loss=0.0729 lr=1.92e-04
[train] step=2510 epoch=3 loss=0.1496 lr=1.91e-04
[train] step=2520 epoch=3 loss=0.1850 lr=1.91e-04
[train] step=2530 epoch=3 loss=0.4256 lr=1.90e-04
[train] step=2540 epoch=3 loss=0.0152 lr=1.89e-04
Why are these losses so high? Our validation loss is \~0.19, so there is something interesting going on in the data, and I could likely push the results higher by chasing it down. However, scaling also gives us a decent bump.
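A first triage pass would be as simple as flagging samples whose training loss dwarfs the validation loss; using the \~0.19 figure above as the reference (the multiplier is an arbitrary starting point, not something tuned):

```python
import torch

def flag_high_loss_samples(losses: torch.Tensor, reference: float,
                           factor: float = 3.0) -> list[int]:
    """Return indices of samples whose loss exceeds factor * reference."""
    return (losses > factor * reference).nonzero(as_tuple=True)[0].tolist()

# Per-sample losses like the log excerpt above, against val loss ~0.19.
losses = torch.tensor([0.12, 0.44, 0.04, 0.88, 0.33, 0.70])
suspects = flag_high_loss_samples(losses, reference=0.19)  # -> [3, 5]
```

The flagged indices are exactly the candidates for manual review: genuinely hard examples, mislabelled images, or the multi-disease edge cases.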
My final F1 is 0.663, with an accuracy of 94.4% and an AUC of 0.969. This is only good for \~450th on the leaderboard, but we're effectively saturating the dataset. I could continue playing with the data to eke out the final gains, but that is more time than I will dedicate today. The intent was to take the new T5 model for a spin and sharpen my classification skills.
Here is the Validation AUC chart from WandB.
Addendum: I went back to see whether I could score at the top of the leaderboard. Adding an MLP head and decreasing dropout to 0 got the AUC to 0.9872 with only the 270m-parameter encoder, effectively saturating this benchmark. The final run and performance are below:
uv run python3 vision_dis.py \
--data-dir ./pp2020/unzipped --batch-size 1 \
--eval-every 200 --epochs 5 --remove-duplicates \
--dedupe-threshold 0.995 --classifier-head mlp --lr 1e-3 --classifier-dropout 0
[eval] step=4200 loss=0.1057 f1=0.6880 acc=0.9596 auc=0.9872
The code can be found below.
from __future__ import annotations
import argparse
import csv
import json
import math
import os
import random
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import wandb
from transformers import (
AutoConfig,
AutoImageProcessor,
AutoProcessor,
AutoTokenizer,
AutoModelForSeq2SeqLM,
get_cosine_schedule_with_warmup,
)
# Optional AUROC support
try:
from sklearn.metrics import roc_auc_score
except Exception:
roc_auc_score = None
LABEL_COLS = ["healthy", "multiple_diseases", "rust", "scab"]
@dataclass
class EncoderClassifierOutput:
loss: Optional[torch.Tensor]
logits: torch.Tensor
def build_pp2020_augmentations():
try:
import albumentations as A
import cv2
except Exception as e:
raise RuntimeError(
"PP2020 augmentations require `albumentations` and `opencv-python`."
) from e
if hasattr(A, "RandomBrightness") and hasattr(A, "RandomContrast"):
brightness_contrast = A.OneOf(
[A.RandomBrightness(limit=0.1, p=1), A.RandomContrast(limit=0.1, p=1)],
p=1,
)
else:
brightness_contrast = A.RandomBrightnessContrast(
brightness_limit=0.1, contrast_limit=0.1, p=1
)
# Resize/normalize are handled by the processor to avoid double preprocessing.
return A.Compose(
[
brightness_contrast,
A.OneOf(
[
A.MotionBlur(blur_limit=3),
A.MedianBlur(blur_limit=3),
A.GaussianBlur(blur_limit=3),
],
p=0.5,
),
A.VerticalFlip(p=0.5),
A.HorizontalFlip(p=0.5),
A.ShiftScaleRotate(
shift_limit=0.2,
scale_limit=0.2,
rotate_limit=20,
interpolation=cv2.INTER_LINEAR,
border_mode=cv2.BORDER_REFLECT_101,
p=1.0,
),
]
)
def build_pp2020_tta_augmentations():
try:
import albumentations as A
except Exception as e:
raise RuntimeError(
"TTA augmentations require `albumentations`."
) from e
# Keep TTA light to avoid distribution shift.
return A.Compose(
[
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.RandomBrightnessContrast(brightness_limit=0.05, contrast_limit=0.05, p=0.3),
]
)
def apply_albumentations(augment, img: Image.Image) -> Image.Image:
arr = np.array(img)
out = augment(image=arr)["image"]
if out.dtype != np.uint8:
max_val = float(out.max()) if out.size else 1.0
if max_val <= 1.0:
out = (np.clip(out, 0.0, 1.0) * 255.0).round()
else:
out = np.clip(out, 0.0, 255.0)
out = out.astype(np.uint8)
return Image.fromarray(out)
def infer_encoder_hidden_size(encoder, config) -> int:
for attr in ("hidden_size", "d_model"):
val = getattr(encoder.config, attr, None)
if isinstance(val, int):
return val
text_config = getattr(encoder.config, "text_config", None)
if text_config is not None:
val = getattr(text_config, "hidden_size", None)
if isinstance(val, int):
return val
for attr in ("hidden_size", "d_model"):
val = getattr(config, attr, None)
if isinstance(val, int):
return val
text_config = getattr(config, "text_config", None)
if text_config is not None:
val = getattr(text_config, "hidden_size", None)
if isinstance(val, int):
return val
get_embeddings = getattr(encoder, "get_input_embeddings", None)
if callable(get_embeddings):
emb = get_embeddings()
if emb is not None:
if hasattr(emb, "embedding_dim"):
return int(emb.embedding_dim)
if hasattr(emb, "weight") and emb.weight is not None:
return int(emb.weight.shape[1])
raise ValueError("Could not infer encoder hidden size for classification head.")
class EncoderClassifier(nn.Module):
def __init__(
self,
model_id: str,
num_labels: int,
classifier_dropout: float = 0.1,
classifier_head: str = "linear",
classifier_hidden_dim: int = 0,
):
super().__init__()
self.backbone = AutoModelForSeq2SeqLM.from_pretrained(model_id)
self.num_labels = num_labels
encoder = self.backbone.get_encoder()
hidden_size = infer_encoder_hidden_size(encoder, self.backbone.config)
self.dropout = nn.Dropout(p=classifier_dropout)
backbone_dtype = next(self.backbone.parameters()).dtype
head = classifier_head.lower()
if head == "linear":
self.classifier = nn.Linear(hidden_size, num_labels)
elif head == "mlp":
mlp_dim = classifier_hidden_dim if classifier_hidden_dim > 0 else hidden_size
self.classifier = nn.Sequential(
nn.Linear(hidden_size, mlp_dim),
nn.GELU(),
nn.Dropout(p=classifier_dropout),
nn.Linear(mlp_dim, num_labels),
)
else:
raise ValueError(f"Unknown classifier_head={classifier_head!r}.")
self.classifier.to(dtype=backbone_dtype)
self.loss_fn = nn.BCEWithLogitsLoss()
def _classifier_dtype(self) -> torch.dtype:
for p in self.classifier.parameters():
return p.dtype
return torch.float32
def _attention_mask_dtype(self) -> torch.dtype:
if torch.is_autocast_enabled():
if torch.cuda.is_available():
get_dtype = getattr(torch, "get_autocast_dtype", None)
if callable(get_dtype):
return get_dtype("cuda")
get_dtype = getattr(torch, "get_autocast_gpu_dtype", None)
if callable(get_dtype):
return get_dtype()
if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
return torch.bfloat16
return torch.float16
get_dtype = getattr(torch, "get_autocast_dtype", None)
if callable(get_dtype):
return get_dtype("cpu")
get_dtype = getattr(torch, "get_autocast_cpu_dtype", None)
if callable(get_dtype):
return get_dtype()
emb = self.backbone.get_input_embeddings()
if emb is not None and hasattr(emb, "weight") and emb.weight is not None:
return emb.weight.dtype
return torch.float32
def _build_full_attention_mask(
self, attention_mask: torch.Tensor, dtype: torch.dtype
) -> torch.Tensor:
# Build an additive mask: 0 for keep, large negative for masked keys.
mask = (1.0 - attention_mask.to(dtype)) * torch.finfo(dtype).min
return mask[:, None, None, :].expand(-1, 1, attention_mask.shape[1], -1)
def _pool(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
if attention_mask is None:
return hidden_states.mean(dim=1)
mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
**kwargs,
) -> EncoderClassifierOutput:
orig_attention_mask = attention_mask
attn_mask_mapping = None
if attention_mask is not None:
mask_dtype = self._attention_mask_dtype()
full_mask = self._build_full_attention_mask(attention_mask, mask_dtype)
else:
full_mask = None
# Avoid transformer mask helpers that require torch>=2.6 by passing a dict.
attn_mask_mapping = {"full_attention": full_mask, "sliding_attention": full_mask}
encoder_kwargs = {}
if input_ids is not None:
encoder_kwargs["input_ids"] = input_ids
encoder_kwargs["attention_mask"] = attn_mask_mapping
if pixel_values is not None:
encoder_kwargs["pixel_values"] = pixel_values
if "position_ids" in kwargs and kwargs["position_ids"] is not None:
encoder_kwargs["position_ids"] = kwargs["position_ids"]
if "inputs_embeds" in kwargs and kwargs["inputs_embeds"] is not None:
encoder_kwargs["inputs_embeds"] = kwargs["inputs_embeds"]
encoder = self.backbone.get_encoder()
enc_out = encoder(return_dict=True, **encoder_kwargs)
pooled = self._pool(enc_out.last_hidden_state, orig_attention_mask)
pooled = pooled.to(self._classifier_dtype())
logits = self.classifier(self.dropout(pooled))
loss = self.loss_fn(logits.float(), labels) if labels is not None else None
return EncoderClassifierOutput(loss=loss, logits=logits)
def read_train_csv(train_csv: Path) -> List[Dict[str, str]]:
with train_csv.open("r", encoding="utf-8", newline="") as f:
reader = csv.DictReader(f)
rows = list(reader)
if not rows:
raise ValueError(f"Empty train.csv: {train_csv}")
for col in ["image_id"] + LABEL_COLS:
if col not in rows[0]:
raise ValueError(f"Expected column '{col}' in {train_csv}, got columns={list(rows[0].keys())}")
return rows
def seed_everything(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def resolve_pp2020_image(images_dir: Path, image_id: str) -> Path:
# In PP2020, image files are typically images/<image_id>.jpg
cand = images_dir / f"{image_id}.jpg"
if cand.exists():
return cand
# fallback: try png/jpeg variants
for ext in [".jpeg", ".png", ".bmp", ".webp"]:
cand2 = images_dir / f"{image_id}{ext}"
if cand2.exists():
return cand2
raise FileNotFoundError(f"Could not find image for image_id={image_id} under {images_dir}")
class PP2020Dataset(Dataset):
def __init__(
self,
rows: List[Dict[str, str]],
images_dir: Path,
processor,
max_items: int = 0,
augment=None,
use_soft_labels: bool = False,
):
self.rows = rows[: max_items] if max_items and max_items > 0 else rows
self.images_dir = images_dir
self.processor = processor
self.prompt = "<start_of_image>"
self.augment = augment
self.use_soft_labels = use_soft_labels
def __len__(self) -> int:
return len(self.rows)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
r = self.rows[idx]
img_path = resolve_pp2020_image(self.images_dir, r["image_id"])
try:
img = Image.open(img_path).convert("RGB")
if self.augment is not None:
img = apply_albumentations(self.augment, img)
except Exception:
# Defensive fallback to a black image (won't crash training)
img = Image.fromarray(np.zeros((896, 896, 3), dtype=np.uint8), mode="RGB")
enc = self.processor(text=self.prompt, images=img, return_tensors="pt")
batch = {k: v.squeeze(0) for k, v in enc.items() if isinstance(v, torch.Tensor)}
if self.use_soft_labels and "soft_labels" in r:
y_vals = r["soft_labels"]
else:
y_vals = [float(r[c]) for c in LABEL_COLS]
y = torch.tensor(y_vals, dtype=torch.float32)
batch["labels"] = y
return batch
class ImageOnlyDataset(Dataset):
def __init__(
self,
rows: List[Dict[str, str]],
images_dir: Path,
processor,
max_items: int = 0,
augment=None,
):
self.rows = rows[: max_items] if max_items and max_items > 0 else rows
self.images_dir = images_dir
self.processor = processor
self.augment = augment
self.prompt = "<start_of_image>"
def __len__(self) -> int:
return len(self.rows)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]:
r = self.rows[idx]
img_path = resolve_pp2020_image(self.images_dir, r["image_id"])
try:
img = Image.open(img_path).convert("RGB")
if self.augment is not None:
img = apply_albumentations(self.augment, img)
except Exception:
img = Image.fromarray(np.zeros((896, 896, 3), dtype=np.uint8), mode="RGB")
enc = self.processor(text=self.prompt, images=img, return_tensors="pt")
pixel_values = enc["pixel_values"].squeeze(0)
return pixel_values, r["image_id"]
def collate_fn(examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
# Token padding (prompt is tiny, but pad defensively)
input_ids = [e["input_ids"] for e in examples]
attn = [e.get("attention_mask", torch.ones_like(e["input_ids"])) for e in examples]
pix = [e["pixel_values"] for e in examples]
labels = [e["labels"] for e in examples]
max_len = max(x.shape[-1] for x in input_ids)
def pad_1d(x: torch.Tensor, pad_val: int = 0) -> torch.Tensor:
if x.shape[-1] == max_len:
return x
out = torch.full((max_len,), pad_val, dtype=x.dtype)
out[: x.shape[-1]] = x
return out
return {
"input_ids": torch.stack([pad_1d(x, 0) for x in input_ids], dim=0),
"attention_mask": torch.stack([pad_1d(x, 0) for x in attn], dim=0),
"pixel_values": torch.stack(pix, dim=0),
"labels": torch.stack(labels, dim=0),
}
def freeze_all_but_classifier(model: nn.Module) -> List[str]:
for p in model.parameters():
p.requires_grad = False
trainable: List[str] = []
# HF convention for sequence classification heads varies; include several patterns.
head_patterns = ["classifier", "classification_head", "score"]
for name, p in model.named_parameters():
if any(pat in name for pat in head_patterns):
p.requires_grad = True
trainable.append(name)
# Fallback: unfreeze last Linear module found
if not trainable:
last_linear_prefix = None
for name, m in model.named_modules():
if isinstance(m, nn.Linear):
last_linear_prefix = name
if last_linear_prefix is not None:
for name, p in model.named_parameters():
if name.startswith(last_linear_prefix):
p.requires_grad = True
trainable.append(name)
if not trainable:
raise RuntimeError("Failed to identify classification head parameters to train.")
return trainable
def unfreeze_all(model: nn.Module) -> List[str]:
trainable: List[str] = []
for name, p in model.named_parameters():
p.requires_grad = True
trainable.append(name)
return trainable
@torch.no_grad()
def compute_metrics_multilabel(logits: torch.Tensor, labels: torch.Tensor) -> Dict[str, float]:
"""
logits: (N, 4), labels: (N, 4) float {0,1}
"""
probs = torch.sigmoid(logits.float()).cpu().numpy()
y = labels.cpu().numpy().astype(np.int32)
# Macro-F1 at threshold 0.5
pred = (probs >= 0.5).astype(np.int32)
eps = 1e-9
tp = (pred & y).sum(axis=0).astype(np.float64)
fp = (pred & (1 - y)).sum(axis=0).astype(np.float64)
fn = ((1 - pred) & y).sum(axis=0).astype(np.float64)
precision = tp / (tp + fp + eps)
recall = tp / (tp + fn + eps)
f1 = (2.0 * precision * recall) / (precision + recall + eps)
f1_macro = float(np.mean(f1))
# Exact match (all 4 labels match)
exact = float(np.mean(np.all(pred == y, axis=1)))
# Label-wise accuracy (fraction of correctly predicted labels)
accuracy = float(np.mean(pred == y))
# Mean ROC-AUC across 4 columns (Kaggle metric)
mean_auc = float("nan")
if roc_auc_score is not None:
aucs = []
for k in range(y.shape[1]):
# roc_auc requires both classes present
if len(np.unique(y[:, k])) < 2:
continue
aucs.append(roc_auc_score(y[:, k], probs[:, k]))
mean_auc = float(np.mean(aucs)) if aucs else float("nan")
return {
"val/f1_macro": f1_macro,
"val/exact_match": exact,
"val/accuracy": accuracy,
"val/mean_roc_auc": mean_auc,
}
@torch.no_grad()
def evaluate(
model,
loader,
device,
tta_loader=None,
tta_repeats: int = 0,
) -> Dict[str, float]:
def _eval_pass(pass_loader):
losses: List[float] = []
all_logits: List[torch.Tensor] = []
all_labels: List[torch.Tensor] = []
for batch in pass_loader:
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
out = model(**batch)
losses.append(float(out.loss.detach().cpu()))
all_logits.append(out.logits.detach().float().cpu())
all_labels.append(batch["labels"].detach().cpu())
if not losses:
return None, None, None
return losses, torch.cat(all_logits, dim=0), torch.cat(all_labels, dim=0)
model.eval()
if tta_loader is not None and tta_repeats > 0:
total_logits = None
total_losses: List[float] = []
labels = None
for _ in range(tta_repeats):
losses, logits, pass_labels = _eval_pass(tta_loader)
if losses is None or logits is None or pass_labels is None:
continue
total_losses.extend(losses)
if labels is None:
labels = pass_labels
if total_logits is None:
total_logits = logits
else:
total_logits += logits
if total_logits is None or labels is None:
return {"val/loss": float("nan")}
logits = total_logits / float(max(1, tta_repeats))
metrics = {
"val/loss": float(np.mean(total_losses)) if total_losses else float("nan"),
"val/tta_repeats": float(tta_repeats),
}
metrics.update(compute_metrics_multilabel(logits, labels))
return metrics
losses, logits, labels = _eval_pass(loader)
if losses is None or logits is None or labels is None:
return {"val/loss": float("nan")}
metrics = {"val/loss": float(np.mean(losses))}
metrics.update(compute_metrics_multilabel(logits, labels))
return metrics
def estimate_gradient_noise_scale(
model,
loader,
device,
steps: int,
amp_device: str,
amp_dtype: torch.dtype,
) -> Tuple[float, float]:
params = [p for p in model.parameters() if p.requires_grad]
if not params:
return float("nan"), float("nan")
was_training = model.training
model.train()
data_iter = iter(loader)
gns_values: List[float] = []
def next_batch():
nonlocal data_iter
try:
return next(data_iter)
except StopIteration:
data_iter = iter(loader)
return next(data_iter)
def grad_vector(batch) -> torch.Tensor:
model.zero_grad(set_to_none=True)
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
with torch.amp.autocast(amp_device, enabled=amp_device == "cuda", dtype=amp_dtype):
out = model(**batch)
if out.loss is None:
raise RuntimeError("Gradient noise scale estimation requires labels.")
loss = out.loss
loss.backward()
grads = [p.grad.detach().float().flatten() for p in params if p.grad is not None]
if not grads:
raise RuntimeError("No gradients collected for gradient noise scale estimation.")
return torch.cat(grads)
for _ in range(max(1, steps)):
b1 = next_batch()
b2 = next_batch()
g1 = grad_vector(b1)
g2 = grad_vector(b2)
g = 0.5 * (g1 + g2)
denom = 2.0 * (g.pow(2).sum() + 1e-12)
num = (g1 - g2).pow(2).sum()
batch_size = int(b1["labels"].shape[0])
gns = (num / denom) * batch_size
gns_values.append(float(gns))
model.zero_grad(set_to_none=True)
if not was_training:
model.eval()
return float(np.mean(gns_values)), float(np.std(gns_values))
@torch.no_grad()
def compute_image_embeddings(
rows: List[Dict[str, str]],
images_dir: Path,
processor,
model,
device,
batch_size: int,
num_workers: int,
amp_device: str,
amp_dtype: torch.dtype,
) -> Tuple[torch.Tensor, List[str]]:
encoder = model.backbone.get_encoder()
if not hasattr(encoder, "get_image_features"):
raise RuntimeError("Encoder does not expose get_image_features for duplicate checking.")
ds = ImageOnlyDataset(rows, images_dir, processor, max_items=0, augment=None)
loader = DataLoader(
ds,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=torch.cuda.is_available(),
drop_last=False,
)
was_training = model.training
model.eval()
embeddings: List[torch.Tensor] = []
image_ids: List[str] = []
for pixel_values, ids in loader:
pixel_values = pixel_values.to(device, non_blocking=True)
with torch.amp.autocast(amp_device, enabled=amp_device == "cuda", dtype=amp_dtype):
feats = encoder.get_image_features(pixel_values)
pooled = feats.mean(dim=1)
embeddings.append(pooled.float().cpu())
image_ids.extend(list(ids))
if was_training:
model.train()
return torch.cat(embeddings, dim=0), image_ids
def find_duplicate_pairs(
embeddings: torch.Tensor,
image_ids: List[str],
threshold: float,
max_pairs: int,
chunk_size: int,
) -> Tuple[List[Dict[str, float]], int]:
if embeddings.numel() == 0:
return [], 0
emb = F.normalize(embeddings.float(), dim=1)
total_pairs = 0
pairs: List[Dict[str, float]] = []
n = emb.shape[0]
for start in range(0, n, max(1, chunk_size)):
end = min(start + max(1, chunk_size), n)
sims = emb[start:end] @ emb.T
for row in range(end - start):
idx = start + row
sim_row = sims[row]
            # Mask self-similarity and earlier indices unconditionally so each
            # pair is counted once and no row can match itself.
            sim_row[: idx + 1] = -1.0
hits = (sim_row >= threshold).nonzero(as_tuple=False).squeeze(1)
total_pairs += int(hits.numel())
if len(pairs) >= max_pairs:
continue
for j in hits.tolist():
if len(pairs) >= max_pairs:
break
pairs.append(
{
"image_id_a": image_ids[idx],
"image_id_b": image_ids[j],
"cosine_sim": float(sim_row[j].item()),
}
)
return pairs, total_pairs
def scan_duplicate_pairs(
embeddings: torch.Tensor,
image_ids: List[str],
threshold: float,
max_pairs: int,
chunk_size: int,
build_groups: bool = False,
) -> Tuple[List[Dict[str, float]], int, Optional[Dict[int, List[int]]]]:
if embeddings.numel() == 0:
return [], 0, {} if build_groups else None
emb = F.normalize(embeddings.float(), dim=1)
total_pairs = 0
pairs: List[Dict[str, float]] = []
n = emb.shape[0]
parent = list(range(n)) if build_groups else None
def find(x: int) -> int:
while parent[x] != x:
parent[x] = parent[parent[x]]
x = parent[x]
return x
def union(a: int, b: int) -> None:
ra, rb = find(a), find(b)
if ra != rb:
if ra < rb:
parent[rb] = ra
else:
parent[ra] = rb
for start in range(0, n, max(1, chunk_size)):
end = min(start + max(1, chunk_size), n)
sims = emb[start:end] @ emb.T
for row in range(end - start):
idx = start + row
sim_row = sims[row]
            # Mask self-similarity and earlier indices unconditionally so each
            # pair is counted once and no row can match itself.
            sim_row[: idx + 1] = -1.0
hits = (sim_row >= threshold).nonzero(as_tuple=False).squeeze(1)
total_pairs += int(hits.numel())
if build_groups and hits.numel() > 0:
for j in hits.tolist():
union(idx, j)
if len(pairs) < max_pairs:
for j in hits.tolist():
if len(pairs) >= max_pairs:
break
pairs.append(
{
"image_id_a": image_ids[idx],
"image_id_b": image_ids[j],
"cosine_sim": float(sim_row[j].item()),
}
)
groups = None
if build_groups:
groups = {}
for i in range(n):
root = find(i)
groups.setdefault(root, []).append(i)
return pairs, total_pairs, groups
def build_kfold_splits(n_items: int, kfolds: int, seed: int) -> List[np.ndarray]:
if kfolds < 2:
raise ValueError("kfolds must be >= 2 to build folds.")
if kfolds > n_items:
raise ValueError(f"kfolds={kfolds} is greater than dataset size={n_items}.")
rng = np.random.RandomState(seed)
idxs = np.arange(n_items)
rng.shuffle(idxs)
return [fold for fold in np.array_split(idxs, kfolds) if len(fold) > 0]
def _get_int_attr(obj, names: List[str]) -> Optional[int]:
for name in names:
val = getattr(obj, name, None)
if isinstance(val, int) and val > 0:
return val
return None
class ImageTokenProcessor:
def __init__(self, image_processor, image_token_id: int, mm_tokens_per_image: int):
self.image_processor = image_processor
self.image_token_id = int(image_token_id)
self.mm_tokens_per_image = int(mm_tokens_per_image)
def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
if images is None:
raise ValueError("images must be provided to the processor.")
enc = self.image_processor(images=images, return_tensors=return_tensors)
pixel_values = enc.get("pixel_values")
if pixel_values is None:
raise ValueError("image processor did not return pixel_values.")
batch_size = int(pixel_values.shape[0])
input_ids = torch.full(
(batch_size, self.mm_tokens_per_image),
self.image_token_id,
dtype=torch.long,
)
attention_mask = torch.ones((batch_size, self.mm_tokens_per_image), dtype=torch.long)
enc["input_ids"] = input_ids
enc["attention_mask"] = attention_mask
return enc
def build_fallback_processor(model_id: str) -> ImageTokenProcessor:
config = AutoConfig.from_pretrained(model_id)
encoder_cfg = getattr(config, "encoder", None) or config
image_token_id = _get_int_attr(encoder_cfg, ["image_token_id", "image_token_index"])
mm_tokens_per_image = _get_int_attr(encoder_cfg, ["mm_tokens_per_image"])
if image_token_id is None:
image_token_id = _get_int_attr(config, ["image_token_id", "image_token_index"])
if mm_tokens_per_image is None:
mm_tokens_per_image = _get_int_attr(config, ["mm_tokens_per_image"])
if image_token_id is None:
try:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
for tok in ["<image>", "<start_of_image>"]:
tok_id = tokenizer.convert_tokens_to_ids(tok)
if isinstance(tok_id, int) and tok_id != tokenizer.unk_token_id:
image_token_id = tok_id
break
except Exception:
image_token_id = None
if image_token_id is None or mm_tokens_per_image is None:
raise RuntimeError(
"Could not resolve image_token_id/mm_tokens_per_image for fallback processor."
)
image_processor = AutoImageProcessor.from_pretrained(model_id)
print(
"[processor] using image-only fallback processor; text tokens will be ignored.",
flush=True,
)
return ImageTokenProcessor(image_processor, image_token_id, mm_tokens_per_image)
def load_processor(model_id: str, use_fast: bool):
if use_fast:
try:
return AutoProcessor.from_pretrained(model_id, use_fast=True)
except Exception as exc:
print(
f"[processor] fast processor failed ({exc.__class__.__name__}), falling back to slow.",
flush=True,
)
try:
return AutoProcessor.from_pretrained(model_id, use_fast=False)
except Exception as exc:
print(
f"[processor] AutoProcessor failed ({exc.__class__.__name__}); using fallback.",
flush=True,
)
return build_fallback_processor(model_id)
def train_simple(
model,
train_loader,
args,
device,
amp_device: str,
amp_dtype: torch.dtype,
use_grad_scaler: bool,
) -> None:
trainable_params = [p for p in model.parameters() if p.requires_grad]
optim = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=args.weight_decay)
steps_per_epoch = math.ceil(len(train_loader) / max(1, args.grad_accum))
total_steps = steps_per_epoch * args.epochs
warmup_steps = int(total_steps * args.warmup_ratio)
sched = get_cosine_schedule_with_warmup(
optim, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)
scaler = torch.amp.GradScaler(amp_device, enabled=use_grad_scaler)
for _ in range(1, args.epochs + 1):
model.train()
optim.zero_grad(set_to_none=True)
for step, batch in enumerate(train_loader, start=1):
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
with torch.amp.autocast(
amp_device,
enabled=amp_device == "cuda",
dtype=amp_dtype,
):
out = model(**batch)
loss = out.loss / max(1, args.grad_accum)
if use_grad_scaler:
scaler.scale(loss).backward()
else:
loss.backward()
if step % args.grad_accum == 0:
if use_grad_scaler:
scaler.step(optim)
scaler.update()
else:
optim.step()
optim.zero_grad(set_to_none=True)
sched.step()
@torch.no_grad()
def predict_probs(
model,
loader,
device,
amp_device: str,
amp_dtype: torch.dtype,
) -> torch.Tensor:
model.eval()
probs: List[torch.Tensor] = []
for batch in loader:
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
with torch.amp.autocast(
amp_device,
enabled=amp_device == "cuda",
dtype=amp_dtype,
):
out = model(**batch)
probs.append(torch.sigmoid(out.logits.float()).cpu())
if not probs:
return torch.empty((0, len(LABEL_COLS)), dtype=torch.float32)
return torch.cat(probs, dim=0)
def build_oof_soft_labels(
rows: List[Dict[str, str]],
images_dir: Path,
processor,
args,
device,
amp_device: str,
amp_dtype: torch.dtype,
use_grad_scaler: bool,
train_augment,
log_cli,
) -> np.ndarray:
folds = build_kfold_splits(len(rows), args.kfolds, args.seed)
oof_probs = np.zeros((len(rows), len(LABEL_COLS)), dtype=np.float32)
for fold_idx, val_idx in enumerate(folds, start=1):
train_idx = np.setdiff1d(np.arange(len(rows)), val_idx)
fold_train_rows = [rows[i] for i in train_idx]
fold_val_rows = [rows[i] for i in val_idx]
fold_model = EncoderClassifier(
model_id=args.model_id,
num_labels=len(LABEL_COLS),
classifier_dropout=args.classifier_dropout,
classifier_head=args.classifier_head,
classifier_hidden_dim=args.classifier_mlp_dim,
)
if args.train_full_model:
unfreeze_all(fold_model)
else:
freeze_all_but_classifier(fold_model)
fold_model.to(device)
fold_train_ds = PP2020Dataset(
fold_train_rows,
images_dir,
processor,
max_items=0,
augment=train_augment,
use_soft_labels=False,
)
fold_val_ds = PP2020Dataset(
fold_val_rows,
images_dir,
processor,
max_items=0,
augment=None,
use_soft_labels=False,
)
fold_train_loader = DataLoader(
fold_train_ds,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=torch.cuda.is_available(),
collate_fn=collate_fn,
drop_last=False,
)
fold_val_loader = DataLoader(
fold_val_ds,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
pin_memory=torch.cuda.is_available(),
collate_fn=collate_fn,
drop_last=False,
)
log_cli(f"[kfold] fold={fold_idx}/{len(folds)} train={len(fold_train_rows)} val={len(fold_val_rows)}")
train_simple(
fold_model,
fold_train_loader,
args,
device,
amp_device,
amp_dtype,
use_grad_scaler,
)
fold_probs = predict_probs(
fold_model,
fold_val_loader,
device,
amp_device,
amp_dtype,
)
oof_probs[val_idx] = fold_probs.numpy()
del fold_model
if torch.cuda.is_available():
torch.cuda.empty_cache()
return oof_probs
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--data-dir", type=str, default="./pp2020")
ap.add_argument("--model-id", type=str, default="google/t5gemma-2-270m-270m")
ap.add_argument("--classifier-head", type=str, choices=["linear", "mlp"], default="linear")
ap.add_argument("--classifier-mlp-dim", type=int, default=0)
ap.add_argument("--classifier-dropout", type=float, default=0.1)
ap.add_argument("--epochs", type=int, default=5)
ap.add_argument("--batch-size", type=int, default=8)
ap.add_argument("--num-workers", type=int, default=4)
ap.add_argument("--lr", type=float, default=3e-4)
ap.add_argument("--weight-decay", type=float, default=0.0)
ap.add_argument("--warmup-ratio", type=float, default=0.05)
ap.add_argument("--grad-accum", type=int, default=1)
ap.add_argument("--log-every", type=int, default=10)
ap.add_argument("--eval-every", type=int, default=25)
ap.add_argument("--augment", action="store_true")
ap.add_argument("--tta-repeats", type=int, default=0)
ap.add_argument("--estimate-gns", action="store_true")
ap.add_argument("--gns-steps", type=int, default=5)
ap.add_argument("--gns-batch-size", type=int, default=8)
ap.add_argument("--check-duplicates", action="store_true")
ap.add_argument("--remove-duplicates", action="store_true")
ap.add_argument("--train-full-model", action="store_true")
ap.add_argument("--kfolds", type=int, default=0)
ap.add_argument("--label-mix-alpha", type=float, default=0.0)
ap.add_argument("--dedupe-threshold", type=float, default=0.995)
ap.add_argument("--dedupe-batch-size", type=int, default=124)
ap.add_argument("--dedupe-max-pairs", type=int, default=20)
ap.add_argument("--dedupe-chunk-size", type=int, default=256)
ap.add_argument("--dedupe-output", type=str, default="duplicate_report.json")
ap.add_argument("--seed", type=int, default=42)
ap.add_argument("--val-split", type=float, default=1.0 / 3.0)
ap.add_argument("--max-items", type=int, default=0)
ap.add_argument("--project", type=str, default="pp2020-t5gemma2")
ap.add_argument("--run-name", type=str, default="")
ap.add_argument("--wandb-mode", type=str, choices=["online", "offline", "disabled"], default="online")
ap.add_argument("--wandb-init-timeout", type=int, default=180)
ap.add_argument("--use-fast-processor", action=argparse.BooleanOptionalAction, default=True)
args = ap.parse_args()
seed_everything(args.seed)
# W&B
    os.environ["WANDB_MODE"] = args.wandb_mode
wandb.init(
project=args.project,
name=args.run_name,
settings=wandb.Settings(init_timeout=args.wandb_init_timeout),
config={
"dataset": "pp2020",
"model_id": args.model_id,
"classifier_head": args.classifier_head,
"classifier_mlp_dim": args.classifier_mlp_dim,
"classifier_dropout": args.classifier_dropout,
"epochs": args.epochs,
"batch_size": args.batch_size,
"lr": args.lr,
"weight_decay": args.weight_decay,
"warmup_ratio": args.warmup_ratio,
"grad_accum": args.grad_accum,
"log_every": args.log_every,
"eval_every": args.eval_every,
"augment": args.augment,
"tta_repeats": args.tta_repeats,
"estimate_gns": args.estimate_gns,
"gns_steps": args.gns_steps,
"gns_batch_size": args.gns_batch_size,
"check_duplicates": args.check_duplicates,
"remove_duplicates": args.remove_duplicates,
"train_full_model": args.train_full_model,
"kfolds": args.kfolds,
"label_mix_alpha": args.label_mix_alpha,
"dedupe_threshold": args.dedupe_threshold,
"dedupe_batch_size": args.dedupe_batch_size,
"dedupe_max_pairs": args.dedupe_max_pairs,
"dedupe_chunk_size": args.dedupe_chunk_size,
"dedupe_output": args.dedupe_output,
"seed": args.seed,
"val_split": args.val_split,
"max_items": args.max_items,
"labels": LABEL_COLS,
"wandb_init_timeout": args.wandb_init_timeout,
"use_fast_processor": args.use_fast_processor,
},
)
# Data directory should contain train.csv and images/
root = Path(args.data_dir).resolve()
train_csv = root / "train.csv"
images_dir = root / "images"
if not train_csv.exists():
raise FileNotFoundError(f"Missing {train_csv}")
if not images_dir.exists():
raise FileNotFoundError(f"Missing {images_dir}")
rows = read_train_csv(train_csv)
train_augment = None
if args.augment:
train_augment = build_pp2020_augmentations()
# Model + processor
processor = load_processor(args.model_id, args.use_fast_processor)
model = EncoderClassifier(
model_id=args.model_id,
num_labels=len(LABEL_COLS),
classifier_dropout=args.classifier_dropout,
classifier_head=args.classifier_head,
classifier_hidden_dim=args.classifier_mlp_dim,
)
if args.train_full_model:
trainable_names = unfreeze_all(model)
else:
trainable_names = freeze_all_but_classifier(model)
wandb.config.update({"trainable_param_names": trainable_names}, allow_val_change=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
amp_device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
amp_dtype = torch.bfloat16
else:
amp_dtype = torch.float16
else:
amp_dtype = torch.float32
use_grad_scaler = torch.cuda.is_available() and amp_dtype == torch.float16
scaler = torch.amp.GradScaler(amp_device, enabled=use_grad_scaler)
def log_cli(msg: str) -> None:
print(msg, flush=True)
if args.check_duplicates or args.remove_duplicates:
embeddings, image_ids = compute_image_embeddings(
rows,
images_dir,
processor,
model,
device,
args.dedupe_batch_size,
args.num_workers,
amp_device,
amp_dtype,
)
max_pairs = args.dedupe_max_pairs if args.check_duplicates else 0
pairs, total_pairs, groups = scan_duplicate_pairs(
embeddings,
image_ids,
args.dedupe_threshold,
max_pairs,
args.dedupe_chunk_size,
build_groups=args.remove_duplicates,
)
report = {
"threshold": args.dedupe_threshold,
"checked_images": len(image_ids),
"total_pairs": total_pairs,
"pairs": pairs,
}
        if args.remove_duplicates:
            # Keep the first image of each duplicate group and drop the rest.
            # Track removals rather than survivors: images that belong to no
            # group must be kept, and filtering on group representatives alone
            # would discard every non-duplicate image.
            removed_idx: set = set()
            removed_ids: List[str] = []
            if groups:
                for group in groups.values():
                    group_sorted = sorted(group)
                    for idx in group_sorted[1:]:
                        removed_idx.add(idx)
                        removed_ids.append(image_ids[idx])
            if removed_ids:
                rows = [rows[i] for i in range(len(rows)) if i not in removed_idx]
report["removed_count"] = len(removed_ids)
report["kept_count"] = len(rows)
report["removed_image_ids"] = removed_ids
log_cli(f"[dedupe] removed={len(removed_ids)} kept={len(rows)}")
wandb.log({"dedupe/removed_count": len(removed_ids)}, step=0)
Path(args.dedupe_output).write_text(json.dumps(report, indent=2), encoding="utf-8")
wandb.log(
{
"dedupe/checked_images": len(image_ids),
"dedupe/total_pairs": total_pairs,
"dedupe/threshold": args.dedupe_threshold,
},
step=0,
)
log_cli(
f"[dedupe] checked={len(image_ids)} pairs={total_pairs} report={args.dedupe_output}"
)
for pair in pairs:
log_cli(
f"[dedupe] {pair['image_id_a']} <-> {pair['image_id_b']} sim={pair['cosine_sim']:.4f}"
)
# Random holdout split (no group ids available in PP2020 package)
idxs = np.arange(len(rows))
np.random.shuffle(idxs)
n_val = max(1, int(len(rows) * args.val_split))
val_set = set(idxs[:n_val].tolist())
train_rows = [rows[i] for i in range(len(rows)) if i not in val_set]
val_rows = [rows[i] for i in range(len(rows)) if i in val_set]
if args.max_items and args.max_items > 0:
train_rows = train_rows[: args.max_items]
val_rows = val_rows[: max(1, int(args.max_items * args.val_split))]
if args.label_mix_alpha < 0.0 or args.label_mix_alpha > 1.0:
raise ValueError("--label-mix-alpha must be in [0, 1].")
if args.tta_repeats < 0:
raise ValueError("--tta-repeats must be >= 0.")
if args.classifier_mlp_dim < 0:
raise ValueError("--classifier-mlp-dim must be >= 0.")
if args.classifier_dropout < 0.0 or args.classifier_dropout > 1.0:
raise ValueError("--classifier-dropout must be in [0, 1].")
use_soft_labels = args.kfolds >= 2 and args.label_mix_alpha > 0.0
if args.label_mix_alpha > 0.0 and args.kfolds < 2:
raise ValueError("--label-mix-alpha requires --kfolds >= 2 for OOF predictions.")
if use_soft_labels:
oof_probs = build_oof_soft_labels(
train_rows,
images_dir,
processor,
args,
device,
amp_device,
amp_dtype,
use_grad_scaler,
train_augment,
log_cli,
)
for idx, row in enumerate(train_rows):
y = np.array([float(row[c]) for c in LABEL_COLS], dtype=np.float32)
soft = (1.0 - args.label_mix_alpha) * y + args.label_mix_alpha * oof_probs[idx]
row["soft_labels"] = soft.tolist()
log_cli(
f"[kfold] mixed labels alpha={args.label_mix_alpha:.2f} folds={args.kfolds}"
)
# Data loaders
train_ds = PP2020Dataset(
train_rows,
images_dir,
processor,
max_items=0,
augment=train_augment,
use_soft_labels=use_soft_labels,
)
val_ds = PP2020Dataset(val_rows, images_dir, processor, max_items=0, augment=None)
train_loader = DataLoader(
train_ds,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=torch.cuda.is_available(),
collate_fn=collate_fn,
drop_last=False,
)
val_loader = DataLoader(
val_ds,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
pin_memory=torch.cuda.is_available(),
collate_fn=collate_fn,
drop_last=False,
)
val_tta_loader = None
if args.tta_repeats > 0:
val_tta_ds = PP2020Dataset(
val_rows,
images_dir,
processor,
max_items=0,
augment=build_pp2020_tta_augmentations(),
)
val_tta_loader = DataLoader(
val_tta_ds,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers,
pin_memory=torch.cuda.is_available(),
collate_fn=collate_fn,
drop_last=False,
)
# Optim + schedule (head-only)
trainable_params = [p for p in model.parameters() if p.requires_grad]
optim = torch.optim.AdamW(trainable_params, lr=args.lr, weight_decay=args.weight_decay)
steps_per_epoch = math.ceil(len(train_loader) / max(1, args.grad_accum))
total_steps = steps_per_epoch * args.epochs
warmup_steps = int(total_steps * args.warmup_ratio)
sched = get_cosine_schedule_with_warmup(
optim, num_warmup_steps=warmup_steps, num_training_steps=total_steps
)
if args.augment:
log_cli("[augment] enabled=pp2020")
if args.tta_repeats > 0:
log_cli(f"[tta] enabled repeats={args.tta_repeats}")
if args.estimate_gns:
gns_loader = DataLoader(
train_ds,
batch_size=args.gns_batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=torch.cuda.is_available(),
collate_fn=collate_fn,
drop_last=True,
)
gns_mean, gns_std = estimate_gradient_noise_scale(
model,
gns_loader,
device,
args.gns_steps,
amp_device,
amp_dtype,
)
gns_opt = int(max(1.0, round(gns_mean))) if not math.isnan(gns_mean) else 0
wandb.log(
{
"gns/scale": gns_mean,
"gns/std": gns_std,
"gns/batch_size": args.gns_batch_size,
"gns/steps": args.gns_steps,
"gns/optimal_batch_size": gns_opt,
},
step=0,
)
log_cli(f"[gns] scale={gns_mean:.4f} std={gns_std:.4f} opt_batch~{gns_opt}")
# Train
global_step = 0
best_auc = -1.0
best_path = Path("best_pp2020_t5gemma2_head.pt").resolve()
last_eval_step = -1
for epoch in range(1, args.epochs + 1):
model.train()
optim.zero_grad(set_to_none=True)
running = 0.0
for step, batch in enumerate(train_loader, start=1):
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
with torch.amp.autocast(
amp_device,
enabled=torch.cuda.is_available(),
dtype=amp_dtype,
):
out = model(**batch)
loss = out.loss / max(1, args.grad_accum)
if use_grad_scaler:
scaler.scale(loss).backward()
else:
loss.backward()
running += float(loss.detach().cpu())
if step % args.grad_accum == 0:
if use_grad_scaler:
scaler.step(optim)
scaler.update()
else:
optim.step()
optim.zero_grad(set_to_none=True)
sched.step()
                global_step += 1
                lr = sched.get_last_lr()[0]
                # `running` holds the mean loss of the effective batch (each
                # micro-loss was pre-divided by grad_accum), so log it once per
                # optimizer step and reset it, instead of letting partial sums
                # accumulate until the next log_every boundary.
                step_loss = running
                running = 0.0
                wandb.log(
                    {
                        "train/loss": step_loss,
                        "train/lr": lr,
                        "train/epoch": epoch,
                        "train/step": global_step,
                    },
                    step=global_step,
                )
                if args.log_every > 0 and global_step % args.log_every == 0:
                    log_cli(
                        f"[train] step={global_step} epoch={epoch} loss={step_loss:.4f} lr={lr:.2e}"
                    )
if args.eval_every > 0 and global_step % args.eval_every == 0:
metrics = evaluate(
model,
val_loader,
device,
tta_loader=val_tta_loader,
tta_repeats=args.tta_repeats,
)
metrics["val/epoch"] = epoch
metrics["val/step"] = global_step
wandb.log(metrics, step=global_step)
log_cli(
"[eval] "
f"step={global_step} loss={metrics.get('val/loss', float('nan')):.4f} "
f"f1={metrics.get('val/f1_macro', float('nan')):.4f} "
f"acc={metrics.get('val/accuracy', float('nan')):.4f} "
f"auc={metrics.get('val/mean_roc_auc', float('nan')):.4f}"
)
sel_auc = metrics.get("val/mean_roc_auc", float("nan"))
if not math.isnan(sel_auc) and sel_auc > best_auc:
best_auc = float(sel_auc)
torch.save(
{
"model_id": args.model_id,
"state_dict": model.state_dict(),
"labels": LABEL_COLS,
"processor_id": args.model_id,
"metrics": metrics,
},
best_path,
)
wandb.log({"val/best_mean_roc_auc": best_auc}, step=global_step)
model.train()
last_eval_step = global_step
# Eval at epoch end if no recent eval ran on the last step.
if last_eval_step != global_step:
metrics = evaluate(
model,
val_loader,
device,
tta_loader=val_tta_loader,
tta_repeats=args.tta_repeats,
)
metrics["val/epoch"] = epoch
metrics["val/step"] = global_step
wandb.log(metrics, step=global_step)
log_cli(
"[eval] "
f"step={global_step} loss={metrics.get('val/loss', float('nan')):.4f} "
f"f1={metrics.get('val/f1_macro', float('nan')):.4f} "
f"acc={metrics.get('val/accuracy', float('nan')):.4f} "
f"auc={metrics.get('val/mean_roc_auc', float('nan')):.4f}"
)
sel_auc = metrics.get("val/mean_roc_auc", float("nan"))
if not math.isnan(sel_auc) and sel_auc > best_auc:
best_auc = float(sel_auc)
torch.save(
{
"model_id": args.model_id,
"state_dict": model.state_dict(),
"labels": LABEL_COLS,
"processor_id": args.model_id,
"metrics": metrics,
},
best_path,
)
wandb.log({"val/best_mean_roc_auc": best_auc}, step=global_step)
last_eval_step = global_step
# Log artifacts
    if best_path.exists():
        # An empty --run-name would yield the invalid artifact name "-best".
        art = wandb.Artifact(name=f"{args.run_name or 'pp2020-t5gemma2'}-best", type="model")
        art.add_file(str(best_path))
        wandb.log_artifact(art)
manifest = {
"dataset": "pp2020",
"data_root": str(root),
"train_csv": str(train_csv),
"images_dir": str(images_dir),
"model_id": args.model_id,
"classifier_head": args.classifier_head,
"classifier_mlp_dim": args.classifier_mlp_dim,
"classifier_dropout": args.classifier_dropout,
"labels": LABEL_COLS,
"best_mean_roc_auc": best_auc,
}
Path("run_manifest.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8")
wandb.save("run_manifest.json")
if __name__ == "__main__":
main()
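To make the `--label-mix-alpha` step concrete: before the final fit, each training row's hard label is blended with its out-of-fold prediction. A standalone sketch with made-up numbers (the 4-way layout mirrors `LABEL_COLS`, but `y`, `oof`, and `alpha` here are toy values):

```python
import numpy as np

# One-hot ground truth for one image, and its out-of-fold prediction
# from the k-fold models that never saw this image during training.
y = np.array([0.0, 1.0, 0.0, 0.0], dtype=np.float32)
oof = np.array([0.10, 0.70, 0.15, 0.05], dtype=np.float32)

alpha = 0.3  # --label-mix-alpha: 0 = hard labels, 1 = pure OOF probabilities
soft = (1.0 - alpha) * y + alpha * oof
print(soft)  # [0.03  0.91  0.045 0.015]
```

With a BCE loss, these soft targets act as a data-driven label smoothing: confidently-predicted rows stay near one-hot while ambiguous rows get softened.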
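The duplicate scan compares encoder embeddings by cosine similarity against `--dedupe-threshold`. A minimal sketch of the pairwise check (toy 2-D vectors stand in for real embeddings; `scan_duplicate_pairs` additionally chunks the matrix to bound memory):

```python
import numpy as np

# L2-normalise embeddings so a plain dot product equals cosine similarity.
emb = np.array([[1.0, 0.0], [0.999, 0.02], [0.0, 1.0]], dtype=np.float32)
emb /= np.linalg.norm(emb, axis=1, keepdims=True)

sim = emb @ emb.T                    # pairwise cosine similarities
iu = np.triu_indices(len(emb), k=1)  # upper triangle: count each pair once
dupes = [(int(a), int(b)) for a, b in zip(*iu) if sim[a, b] >= 0.995]
print(dupes)  # [(0, 1)]
```

Near-identical images (rows 0 and 1) clear the 0.995 default threshold; the orthogonal row 2 does not.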
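For the `--estimate-gns` path, here is a back-of-envelope sketch of the "simple" gradient noise scale from McCandlish et al.'s "An Empirical Model of Large-Batch Training", which motivates reading the result as an optimal batch size. The gradient-norm statistics below are made up; `estimate_gradient_noise_scale`'s internals are not reproduced here:

```python
# Suppose we measured the expected squared gradient norm at two batch sizes.
b_small, b_big = 2, 8
g2_small, g2_big = 3.0, 1.5  # E[|g|^2] at b_small and b_big (toy numbers)

# Unbiased estimates of the true squared gradient norm |G|^2 and the
# per-example noise S; their ratio is the "simple" noise scale, which
# approximates the critical (compute-optimal) batch size.
G2 = (b_big * g2_big - b_small * g2_small) / (b_big - b_small)
S = (g2_small - g2_big) / (1.0 / b_small - 1.0 / b_big)
b_simple = S / G2
print(G2, S, b_simple)  # 1.0 4.0 4.0
```

A noise scale near 1, as observed in the runs above, says the gradient signal dominates the noise even at tiny batches, so there is little throughput to gain from scaling the batch up.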
Here are the requirements:
--extra-index-url https://download.pytorch.org/whl/cu118
torch==2.5.1+cu118
torchvision==0.20.1+cu118
# T5Gemma2 support via transformers git (matches current setup)
transformers @ git+https://github.com/huggingface/transformers.git@ad7f4d0103599ff098bb33c11b9c1a73d97262fd
huggingface-hub==1.3.1
sentencepiece
protobuf
pillow
numpy
wandb
kaggle
scikit-learn
albumentations
opencv-python
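Assuming the list above is saved as `requirements.txt` and the script as `train_pp2020.py` (both filenames are my assumption), a baseline head-only run like the one described earlier maps onto the script's flags roughly as:

```shell
pip install -r requirements.txt
python train_pp2020.py \
  --data-dir ./pp2020 \
  --batch-size 8 --lr 3e-4 --epochs 5 \
  --augment --estimate-gns
```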