Audience: ML engineers working on PyTorch + Transformer-based models
Scope: Three frequently used snippets for (1) freezing backbone parameters, (2) converting logits → probability/label for binary classification, and (3) performing a targeted debug forward pass on a validation dataset sample.
1) Freezing Transformer Parameters (Backbone Lock)
When to use
Freeze the transformer/backbone when you want to:
- Train only a classification head (fast iteration, lower GPU memory, less risk of catastrophic forgetting)
- Run ablations to isolate head behavior
- Debug training stability without full fine-tuning
Snippet
import torch

def set_trainable(module: torch.nn.Module, trainable: bool) -> None:
    for p in module.parameters():
        p.requires_grad = trainable

# Example usage:
# set_trainable(model.backbone, trainable=False)  # freeze backbone
# set_trainable(model.classifier, trainable=True)  # keep head trainable
Recommended “production-safe” form
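Use explicit helpers and log what is frozen/unfrozen.
import logging
import torch

logger = logging.getLogger(__name__)

def freeze_module(module: torch.nn.Module) -> None:
    for name, p in module.named_parameters():
        p.requires_grad = False
        logger.info("Froze %s", name)

def unfreeze_module(module: torch.nn.Module) -> None:
    for name, p in module.named_parameters():
        p.requires_grad = True
        logger.info("Unfroze %s", name)

# Example (inside the model class, after the submodules are created):
# freeze_module(self.transformer)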
Best practice: build optimizer over trainable params only
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)
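To confirm that the freeze and the optimizer agree, it can help to count parameters on both sides. The following is a minimal sketch on a toy model; the `backbone`/`classifier` names, the layer sizes, and the `count_parameters` helper are illustrative assumptions, not part of any real model contract.

```python
import torch
from torch import nn

# Hypothetical toy model; "backbone" and "classifier" are illustrative names.
model = nn.ModuleDict({
    "backbone": nn.Linear(8, 4),    # 8*4 weights + 4 biases = 36 params
    "classifier": nn.Linear(4, 2),  # 4*2 weights + 2 biases = 10 params
})

# Freeze the backbone only.
for p in model["backbone"].parameters():
    p.requires_grad = False


def count_parameters(module: nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts for a module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    return trainable, total


trainable, total = count_parameters(model)

# Build the optimizer over trainable params only, then count what it holds.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)
n_opt = sum(p.numel() for g in optimizer.param_groups for p in g["params"])

print(f"trainable={trainable} total={total} in_optimizer={n_opt}")
```

If `trainable` and `in_optimizer` differ, the optimizer was probably built before the freeze was applied.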
Common pitfalls
- Optimizer still includes frozen params: Usually harmless, since the optimizer skips parameters whose .grad is None, but it is cleaner to construct the optimizer from filter(lambda p: p.requires_grad, model.parameters()).
- LayerNorm / bias fine-tuning: Sometimes you want partial freezing (e.g., unfreeze norms). Make it explicit.
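One way to make partial freezing explicit is a helper that re-enables only the normalization layers and reports what it touched. This is a sketch under assumptions: `unfreeze_layernorm` is a hypothetical helper, the `nn.Sequential` stands in for a real backbone, and the `isinstance` check only matches `torch.nn.LayerNorm` (a model with a custom norm class would need a different check).

```python
import torch
from torch import nn


def unfreeze_layernorm(module: nn.Module) -> list[str]:
    """Set requires_grad=True on every nn.LayerNorm parameter; return their names."""
    unfrozen = []
    for mod_name, sub in module.named_modules():
        if isinstance(sub, nn.LayerNorm):
            # recurse=False: only the parameters owned directly by this layer
            for p_name, p in sub.named_parameters(recurse=False):
                p.requires_grad = True
                unfrozen.append(f"{mod_name}.{p_name}" if mod_name else p_name)
    return unfrozen


# Hypothetical stand-in for a transformer backbone.
backbone = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 8))

# Freeze everything first, then selectively unfreeze the norms.
for p in backbone.parameters():
    p.requires_grad = False

names = unfreeze_layernorm(backbone)
print(names)
```

Returning the names keeps the operation auditable: the list can be logged next to the trainable parameter count.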
2) Binary Classification: Logits → Probability → Label
When to use
Convert model outputs (logits) into:
- A probability for the positive class
- A predicted label using a threshold (e.g., 0.5)
Snippets
Case A: Two-logit output (shape (B, 2)) — Softmax
Use this when the model outputs a logit per class.
import torch

def predict_binary_from_two_logits(logits: torch.Tensor, threshold: float = 0.5):
    """
    logits: Tensor of shape (B, 2)
    Returns:
        probs_pos: Tensor of shape (B,) with P(class=1)
        preds: Tensor of shape (B,) with {0,1}
    """
    probs = torch.softmax(logits, dim=-1)  # (B, 2)
    probs_pos = probs[:, 1]  # (B,)
    preds = (probs_pos >= threshold).long()
    return probs_pos, preds
Interpretation
- Assumes logits has shape (B, 2) (batch size × number of classes).
- probs[:, 1] extracts the probability of class 1 for every item in the batch; softmax(...)[0, 1] would return it for the first item only.
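The two-logit softmax and the single-logit sigmoid are closely related: for two classes, the softmax probability of class 1 equals the sigmoid of the logit difference (class 1 minus class 0). The sketch below checks this numerically on random logits; the seed and batch size are arbitrary choices.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 2)  # (B, 2), random stand-in for model outputs

# P(class=1) via softmax over both logits
probs_softmax = torch.softmax(logits, dim=-1)[:, 1]

# P(class=1) via sigmoid of the logit difference z1 - z0
probs_sigmoid = torch.sigmoid(logits[:, 1] - logits[:, 0])

same = torch.allclose(probs_softmax, probs_sigmoid, atol=1e-6)
print(same)
```

One consequence is that a threshold of 0.5 on the softmax probability picks class 1 exactly when its logit is at least as large as the class 0 logit.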
Case B: Single-logit output (sigmoid head)
If your model outputs (B,) or (B,1) logits for binary classification, use sigmoid:
import torch

def predict_binary_from_single_logit(logits: torch.Tensor, threshold: float = 0.5):
    """
    logits: Tensor of shape (B,) or (B, 1)
    Returns:
        probs_pos: Tensor of shape (B,) with P(class=1)
        preds: Tensor of shape (B,) with {0,1}
    """
    # Flatten to (B,); reshape also works on non-contiguous tensors, unlike view
    logits = logits.reshape(-1)
    probs_pos = torch.sigmoid(logits)  # (B,)
    preds = (probs_pos >= threshold).long()
    return probs_pos, preds
Common pitfalls
- Softmax vs Sigmoid mismatch:
  - Use softmax if you have 2 output logits (class 0 and class 1).
  - Use sigmoid if you have 1 output logit (positive class).
- Threshold is not universal: For imbalanced data, 0.5 may be suboptimal. Consider tuning threshold on a validation set using ROC/PR trade-offs.
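As a rough illustration of threshold tuning, the sketch below sweeps candidate thresholds and keeps the one with the best F1 on validation data. F1 is only one possible criterion (the precision/recall trade-off you want may differ), and `best_f1_threshold`, the candidate grid, and the sample values are all assumptions made for the example. It takes plain Python lists, so tensors would need `.tolist()` first.

```python
def best_f1_threshold(probs, labels, candidates=None):
    """Return (threshold, f1) for the candidate threshold with the highest F1."""
    if candidates is None:
        candidates = [i / 100 for i in range(5, 100, 5)]  # 0.05, 0.10, ..., 0.95
    best_t, best_f1 = 0.5, -1.0
    for t in candidates:
        tp = sum(1 for p, y in zip(probs, labels) if p >= t and y == 1)
        fp = sum(1 for p, y in zip(probs, labels) if p >= t and y == 0)
        fn = sum(1 for p, y in zip(probs, labels) if p < t and y == 1)
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        if f1 > best_f1:  # strict ">" keeps the lowest threshold on ties
            best_t, best_f1 = t, f1
    return best_t, best_f1


# Made-up validation probabilities and labels (3 positives out of 8).
val_probs = [0.03, 0.12, 0.22, 0.31, 0.37, 0.41, 0.83, 0.91]
val_labels = [0, 0, 0, 0, 1, 0, 1, 1]

t, f1 = best_f1_threshold(val_probs, val_labels)
print(f"threshold={t:.2f} f1={f1:.3f}")
```

In this toy data one positive sits at 0.37, so the default 0.5 would miss it; the sweep settles below 0.5 to recover it at the cost of one false positive.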
3) Debug Forward Pass on a Random Validation Sample
When to use
This is useful for quickly validating:
-
Input shape/device correctness
-
Model forward compatibility
-
Logits shape and label alignment
-
Unexpected dtype or tensor rank issues
Snippet
import torch
import random

def debug_single_forward(model, dataset, device, idx: int | None = None):
    model.eval()
    with torch.no_grad():
        if idx is None:
            idx = random.randrange(len(dataset))
        sample = dataset[idx]  # expected: dict-like with tensors/arrays
        # Typical required inputs (adjust to your model contract)
        input_ids = torch.as_tensor(sample["input_ids"]).unsqueeze(0).to(device)
        batch = {"input_ids": input_ids}
        # Optional fields commonly used in transformer pipelines
        if "attention_mask" in sample:
            batch["attention_mask"] = torch.as_tensor(sample["attention_mask"]).unsqueeze(0).to(device)
        outputs = model(**batch)
        # Support both dict-like outputs and attribute outputs
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
    label = sample.get("label", sample.get("labels", None))
    print(f"Index: {idx}")
    print(f"input_ids: shape={tuple(input_ids.shape)} dtype={input_ids.dtype} device={input_ids.device}")
    if "attention_mask" in batch:
        am = batch["attention_mask"]
        print(f"attention_mask: shape={tuple(am.shape)} dtype={am.dtype} device={am.device}")
    print(f"logits: shape={tuple(logits.shape)} dtype={logits.dtype} device={logits.device}")
    print(f"label: {label}")
    return logits, label
Recommended “clean” debug block
A flat, script-style version of the same check. It keeps model.eval() and torch.no_grad(), and also moves labels to the device when present.
import numpy as np
import torch

model.eval()
with torch.no_grad():
    idx = np.random.randint(0, len(val_torch_ds))
    inputs = val_torch_ds[idx]
    input_ids = inputs["input_ids"].unsqueeze(0).to(device)  # (1, T)
    outputs = model(input_ids=input_ids)
    logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
    labels = inputs.get("labels", None)
    if labels is not None:
        labels = torch.as_tensor(labels).to(device)

print("idx:", idx)
print("input_ids:", tuple(input_ids.shape), input_ids.dtype, input_ids.device)
print("logits:", tuple(logits.shape), logits.dtype, logits.device)
print("labels:", labels)
Notes on clone() and unsqueeze(0)
- unsqueeze(0) is required if the dataset returns a single example of shape (T,) and the model expects batch-first (B, T).
- clone() is optional; it can be helpful if you suspect the dataset reuses memory or you are mutating tensors in-place.
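The memory-sharing concern can be seen in a small sketch. `unsqueeze(0)` returns a view of the same storage, so an in-place write on the batched tensor also changes the original; `clone()` breaks that link. Here `stored` merely stands in for a tensor held by a dataset, and whether a real dataset hands out shared tensors depends on its implementation.

```python
import torch

# Stand-in for a tensor that a dataset keeps and returns from __getitem__.
stored = torch.tensor([101, 2023, 102])

view = stored.unsqueeze(0)   # (1, T); shares memory with `stored`
view[0, 0] = 0               # in-place write leaks into `stored`
leaked = stored[0].item()

stored = torch.tensor([101, 2023, 102])
safe = stored.clone().unsqueeze(0)  # independent copy, then add batch dim
safe[0, 0] = 0                      # does not touch `stored`
kept = stored[0].item()

print(leaked, kept)
```

Note that `.to(device)` already copies when it moves data to a different device, so the aliasing mainly matters when the dataset tensor and the model live on the same device.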
Common pitfalls
- Forgetting attention masks: If your model expects attention_mask, include it:
  attention_mask = inputs["attention_mask"].unsqueeze(0).to(device)
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
- Model returns tuple vs dict: Hugging Face models often return a ModelOutput that supports both attribute and dict-like access; handle both as shown above. With return_dict=False they return a plain tuple instead, which the snippets above do not cover.
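The dict/attribute check used in the snippets above could be extended to tuple outputs with a small helper. This is a sketch only: `extract_logits` is a hypothetical name, and the tuple branch assumes logits are the first element, which is typical when no labels are passed to the model but should be verified against your model's output contract.

```python
from types import SimpleNamespace


def extract_logits(outputs):
    """Return logits from dict-like, attribute-style, or tuple model outputs."""
    if isinstance(outputs, dict):
        return outputs["logits"]
    if hasattr(outputs, "logits"):
        return outputs.logits
    if isinstance(outputs, (tuple, list)):
        # Assumption: logits come first (no loss prepended to the tuple).
        return outputs[0]
    raise TypeError(f"Unsupported output type: {type(outputs).__name__}")


# Plain Python stand-ins for the three output styles.
fake_logits = [[0.2, 1.3]]
from_dict = extract_logits({"logits": fake_logits})
from_attr = extract_logits(SimpleNamespace(logits=fake_logits))
from_tuple = extract_logits((fake_logits, "other_outputs"))

print(from_dict is from_attr is from_tuple)
```

Raising on unknown types keeps a debug run from silently printing the wrong tensor.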
Quick Reference Checklist
Freeze backbone
- requires_grad=False set for intended parameters
- Optimizer built only over trainable params (recommended)
- Confirm trainable parameter count
Binary prediction
- Softmax for (B, 2) logits; sigmoid for (B,) or (B, 1) logits
- Use batch-safe indexing ([:, 1])
- Threshold tuned if imbalance exists
Forward-pass debug
- model.eval() + torch.no_grad()
- Batch dimension added (unsqueeze(0))
- All required tensors moved to the same device
- Print shapes/dtypes to confirm expectations
Metadata
Category: Engineering → ML Platform / Model Training
Tags: pytorch, transformers, debugging, binary-classification, fine-tuning