Common PyTorch / Transformers Debug Snippets (Freeze, Predict, Forward-Pass Probe)

Audience: ML engineers working on PyTorch + Transformer-based models
Scope: Three frequently used snippets for (1) freezing backbone parameters, (2) converting logits → probability/label for binary classification, and (3) performing a targeted debug forward pass on a validation dataset sample.


1) Freezing Transformer Parameters (Backbone Lock)

When to use

Freeze the transformer/backbone when you want to:

  • Train only a classification head (fast iteration, lower GPU memory, less risk of catastrophic forgetting)

  • Run ablations to isolate head behavior

  • Debug training stability without full fine-tuning

Snippet

import torch

def set_trainable(module: torch.nn.Module, trainable: bool) -> None:
    for p in module.parameters():
        p.requires_grad = trainable

# Example usage:
# set_trainable(model.backbone, trainable=False)   # freeze backbone
# set_trainable(model.classifier, trainable=True)  # keep head trainable

Recommended “production-safe” form

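Use explicit helpers and log what is frozen/unfrozen.

import logging

import torch

logger = logging.getLogger(__name__)

def freeze_module(module: torch.nn.Module) -> None:
    for name, p in module.named_parameters():
        p.requires_grad = False
        logger.debug("froze %s", name)

def unfreeze_module(module: torch.nn.Module) -> None:
    for name, p in module.named_parameters():
        p.requires_grad = True
        logger.debug("unfroze %s", name)

# Example (inside a model class, pass self.transformer instead)
freeze_module(model.backbone)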

Best practice: build optimizer over trainable params only

# Use a list rather than a generator so the collection can be reused (e.g., for counting).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)
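
A quick way to confirm that freezing took effect is to compare trainable and total parameter counts, then check what the optimizer actually holds. The sketch below is illustrative only: `count_parameters` and the toy `backbone`/`classifier` layout are hypothetical names, not part of PyTorch or of any particular model.

```python
from typing import Tuple

import torch
from torch import nn

def count_parameters(module: nn.Module) -> Tuple[int, int]:
    """Return (trainable, total) parameter counts for a module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    total = sum(p.numel() for p in module.parameters())
    return trainable, total

# Toy stand-in for a real model (hypothetical layout: backbone + classifier).
model = nn.ModuleDict({
    "backbone": nn.Linear(8, 4),    # 8*4 weights + 4 biases = 36 parameters
    "classifier": nn.Linear(4, 2),  # 4*2 weights + 2 biases = 10 parameters
})

# Freeze the backbone only.
for p in model["backbone"].parameters():
    p.requires_grad = False

trainable, total = count_parameters(model)
print(f"trainable={trainable} total={total}")

# Build the optimizer over trainable parameters, then count what it holds.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)
n_optimized = sum(p.numel() for g in optimizer.param_groups for p in g["params"])
print(f"optimized={n_optimized}")
```

If the trainable count equals the total, the freeze did not apply to the module you intended.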

Common pitfalls

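  • Optimizer still includes frozen params: Usually harmless, because parameters without gradients are skipped at step time, but it is cleaner to build the optimizer from trainable parameters only, as above. If you unfreeze parameters later, add them to the optimizer (e.g., with optimizer.add_param_group), since an optimizer built this way does not know about them.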

  • LayerNorm / bias fine-tuning: Sometimes you want partial freezing (e.g., unfreeze norms). Make it explicit.
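
One way to make partial freezing explicit is to freeze everything first and then re-enable only the pieces you want. The following is a minimal sketch under stated assumptions: `freeze_except_norms_and_biases` is a hypothetical helper, norms are detected by `isinstance(..., nn.LayerNorm)`, and biases are detected by parameter name, which matches standard PyTorch layers but may not match custom ones.

```python
from typing import List

import torch
from torch import nn

def freeze_except_norms_and_biases(module: nn.Module) -> List[str]:
    """Freeze all parameters, then re-enable LayerNorm parameters and biases.

    Returns the names of the parameters left trainable, for logging.
    """
    for p in module.parameters():
        p.requires_grad = False

    # Re-enable every parameter owned by a LayerNorm submodule.
    for sub in module.modules():
        if isinstance(sub, nn.LayerNorm):
            for p in sub.parameters():
                p.requires_grad = True

    # Re-enable bias terms (assumption: they are named "...bias").
    for name, p in module.named_parameters():
        if name.endswith("bias"):
            p.requires_grad = True

    return [n for n, p in module.named_parameters() if p.requires_grad]

# Toy backbone used only for illustration.
backbone = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 4))
trainable_names = freeze_except_norms_and_biases(backbone)
print(trainable_names)
```

Returning the list of trainable names makes it easy to log exactly which parameters will receive updates.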


2) Binary Classification: Logits → Probability → Label

When to use

Convert model outputs (logits) into:

  • A probability for the positive class

  • A predicted label using a threshold (e.g., 0.5)

Snippets

Case A: Two-logit output of shape (B, 2), softmax

Use this when the model outputs a logit per class.

import torch

def predict_binary_from_two_logits(logits: torch.Tensor, threshold: float = 0.5):
    """
    logits: Tensor of shape (B, 2)
    Returns:
      probs_pos: Tensor of shape (B,) with P(class=1)
      preds: Tensor of shape (B,) with {0,1}
    """
    probs = torch.softmax(logits, dim=-1)          # (B, 2)
    probs_pos = probs[:, 1]                        # (B,)
    preds = (probs_pos >= threshold).long()
    return probs_pos, preds


Interpretation

  • Assumes logits shape is (B, 2) (batch size × number of classes)

  • probs[:, 1] extracts the probability of class 1 for every item in the batch. Indexing with [0, 1] instead reads only the first item, which is correct only when B = 1.

Case B: Single-logit output of shape (B,) or (B, 1), sigmoid

If your model outputs (B,) or (B,1) logits for binary classification, use sigmoid:

import torch

def predict_binary_from_single_logit(logits: torch.Tensor, threshold: float = 0.5):
    """
    logits: Tensor of shape (B,) or (B, 1)
    Returns:
      probs_pos: Tensor of shape (B,) with P(class=1)
      preds: Tensor of shape (B,) with {0,1}
    """
    logits = logits.reshape(-1)                    # (B, 1) -> (B,); also works on non-contiguous tensors
    probs_pos = torch.sigmoid(logits)              # (B,)
    preds = (probs_pos >= threshold).long()
    return probs_pos, preds
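
The two cases are closely related: for a two-logit output (z0, z1), the softmax probability of class 1 equals sigmoid(z1 - z0). The short check below illustrates this on random logits; the tensor values are arbitrary and the variable names are only for illustration.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(5, 2)  # (B, 2) two-logit output

# P(class=1) from the softmax over both logits.
p_softmax = torch.softmax(logits, dim=-1)[:, 1]

# The same quantity from a sigmoid over the logit difference.
p_sigmoid = torch.sigmoid(logits[:, 1] - logits[:, 0])

# The two agree up to floating-point error.
same = torch.allclose(p_softmax, p_sigmoid, atol=1e-6)
print(same)
```

This is why a threshold of 0.5 on the two-logit head is the same decision as taking the argmax over the two classes.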


Common pitfalls

  • Softmax vs Sigmoid mismatch:

    • Use softmax if you have 2 output logits (class 0 and class 1).

    • Use sigmoid if you have 1 output logit (positive class).

  • Threshold is not universal: For imbalanced data, 0.5 may be suboptimal. Consider tuning threshold on a validation set using ROC/PR trade-offs.
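
As one possible way to tune the threshold, you can sweep candidate values on validation predictions and keep the one that maximizes a metric such as F1. This is a sketch, not a recommended procedure for every task: `best_f1_threshold` is a hypothetical helper, F1 is only one choice of criterion, and the validation data below is made up.

```python
import torch

def best_f1_threshold(probs_pos: torch.Tensor, labels: torch.Tensor, num_steps: int = 99):
    """Sweep thresholds in (0, 1) and return (threshold, f1) with the highest F1."""
    labels = labels.bool()
    best_t, best_f1 = 0.5, -1.0
    for t in torch.linspace(0.01, 0.99, num_steps):
        preds = probs_pos >= t
        tp = (preds & labels).sum().item()   # true positives
        fp = (preds & ~labels).sum().item()  # false positives
        fn = (~preds & labels).sum().item()  # false negatives
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom > 0 else 0.0
        if f1 > best_f1:  # strict comparison keeps the lowest best threshold
            best_t, best_f1 = t.item(), f1
    return best_t, best_f1

# Toy validation data (made up for illustration).
val_probs = torch.tensor([0.05, 0.10, 0.20, 0.30, 0.35, 0.40, 0.45, 0.80])
val_labels = torch.tensor([0, 0, 0, 0, 1, 0, 1, 1])

threshold, f1 = best_f1_threshold(val_probs, val_labels)
print(f"threshold={threshold:.2f} f1={f1:.3f}")
```

Tune on a validation split only, and report final metrics on a separate test split so the chosen threshold is not fitted to the data it is evaluated on.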


3) Debug Forward Pass on a Random Validation Sample

When to use

This is useful for quickly validating:

  • Input shape/device correctness

  • Model forward compatibility

  • Logits shape and label alignment

  • Unexpected dtype or tensor rank issues

Snippet

import random
from typing import Optional

import torch

def debug_single_forward(model, dataset, device, idx: Optional[int] = None):
    model.eval()
    with torch.no_grad():
        if idx is None:
            idx = random.randrange(len(dataset))

        sample = dataset[idx]  # expected: dict-like with tensors/arrays

        # Typical required inputs (adjust to your model contract)
        input_ids = torch.as_tensor(sample["input_ids"]).unsqueeze(0).to(device)

        batch = {"input_ids": input_ids}

        # Optional fields commonly used in transformer pipelines
        if "attention_mask" in sample:
            batch["attention_mask"] = torch.as_tensor(sample["attention_mask"]).unsqueeze(0).to(device)

        outputs = model(**batch)

        # Support both dict-like outputs and attribute outputs
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits

        label = sample.get("label", sample.get("labels", None))

        print(f"Index: {idx}")
        print(f"input_ids: shape={tuple(input_ids.shape)} dtype={input_ids.dtype} device={input_ids.device}")
        if "attention_mask" in batch:
            am = batch["attention_mask"]
            print(f"attention_mask: shape={tuple(am.shape)} dtype={am.dtype} device={am.device}")
        print(f"logits: shape={tuple(logits.shape)} dtype={logits.dtype} device={logits.device}")
        print(f"label: {label}")

        return logits, label


Recommended “clean” debug block

An inline version of the same check, convenient in a notebook. It runs under eval() and torch.no_grad() and also moves labels to the device when present.

import numpy as np
import torch

model.eval()
with torch.no_grad():
    idx = np.random.randint(0, len(val_torch_ds))
    inputs = val_torch_ds[idx]

    input_ids = inputs["input_ids"].unsqueeze(0).to(device)  # (1, T)
    outputs = model(input_ids=input_ids)

    logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits

    labels = inputs.get("labels", None)
    if labels is not None:
        labels = torch.as_tensor(labels).to(device)

    print("idx:", idx)
    print("input_ids:", tuple(input_ids.shape), input_ids.dtype, input_ids.device)
    print("logits:", tuple(logits.shape), logits.dtype, logits.device)
    print("labels:", labels)
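
Printing shapes is often enough, but a small explicit check can catch misaligned logits and labels earlier. The helper below is a sketch under assumptions: `check_logits_and_label` is a hypothetical name, and it assumes a single-example batch with a (1, num_classes) classification head and an integer class label. Adjust it for a single-logit head.

```python
import torch

def check_logits_and_label(logits: torch.Tensor, label, num_classes: int) -> None:
    """Raise ValueError if a single-example forward pass looks inconsistent.

    Assumes a (1, num_classes) classification head and an integer class label.
    """
    if tuple(logits.shape) != (1, num_classes):
        raise ValueError(
            f"expected logits of shape (1, {num_classes}), got {tuple(logits.shape)}"
        )
    if not torch.isfinite(logits).all():
        raise ValueError("logits contain NaN or Inf")
    if label is not None:
        # Accept Python ints, 0-dim tensors, or 1-element tensors/arrays.
        label_value = int(torch.as_tensor(label).reshape(-1)[0])
        if not 0 <= label_value < num_classes:
            raise ValueError(f"label {label_value} is outside [0, {num_classes})")

# Usage with stand-in values; in practice pass the logits and labels from a debug pass.
check_logits_and_label(torch.tensor([[0.3, -1.2]]), torch.tensor(1), num_classes=2)
```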

Notes on clone() and unsqueeze(0)

  • unsqueeze(0) is required if the dataset returns a single example of shape (T,) and the model expects batch-first (B, T).

  • clone() is optional and is not used in the snippets above. Add it before unsqueeze(0) (e.g., inputs["input_ids"].clone().unsqueeze(0)) if you suspect the dataset reuses memory or you are mutating tensors in-place.
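
The reason clone() can matter is that unsqueeze(0) returns a view that shares memory with the original tensor. The small demonstration below uses made-up token ids and illustrative variable names. Note that .to(device) already copies when it moves data to a different device, so this mostly matters when everything stays on the same device (typically CPU).

```python
import torch

# Stand-in for a tensor held by the dataset (values are made up).
stored_a = torch.tensor([101, 2023, 102])

view = stored_a.unsqueeze(0)   # (1, T) view: shares memory with stored_a
view[0, 0] = 0                 # in-place edit leaks back into stored_a
print(stored_a.tolist())       # [0, 2023, 102]

stored_b = torch.tensor([101, 2023, 102])

copy = stored_b.clone().unsqueeze(0)  # (1, T) independent copy
copy[0, 0] = 0                        # stored_b is unaffected
print(stored_b.tolist())              # [101, 2023, 102]
```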

Common pitfalls

  • Forgetting attention masks: If your model expects attention_mask, include it:

    attention_mask = inputs["attention_mask"].unsqueeze(0).to(device)
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    
    
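  • Model returns tuple vs dict: Hugging Face models return a ModelOutput by default, which supports both attribute and dict-like access, so the check shown above covers it. With return_dict=False they return a plain tuple instead; when no labels are passed, the logits are typically the first element (outputs[0]).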
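
If you need to support several output styles in one place, a small helper can centralize the logic. This is a sketch: `extract_logits` is a hypothetical helper, and the tuple branch assumes the logits come first, which generally holds only when no labels were passed to the model.

```python
from types import SimpleNamespace

import torch

def extract_logits(outputs) -> torch.Tensor:
    """Pull logits out of tensor, dict-like, attribute-style, or tuple outputs."""
    if isinstance(outputs, torch.Tensor):
        return outputs  # model returned raw logits
    if isinstance(outputs, dict):
        return outputs["logits"]  # plain dicts and dict-like outputs
    if hasattr(outputs, "logits"):
        return outputs.logits  # attribute-style outputs
    if isinstance(outputs, (tuple, list)):
        # Assumption: logits are the first element (no loss prepended).
        return outputs[0]
    raise TypeError(f"unsupported output type: {type(outputs).__name__}")

# Stand-in outputs for illustration.
fake_logits = torch.zeros(1, 2)
for out in ({"logits": fake_logits}, SimpleNamespace(logits=fake_logits), (fake_logits,)):
    print(type(out).__name__, tuple(extract_logits(out).shape))
```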


Quick Reference Checklist

Freeze backbone

  • requires_grad=False set for intended parameters

  • Optimizer built only over trainable params (recommended)

  • Confirm trainable parameter count

Binary prediction

  • Softmax for (B, 2) logits; sigmoid for (B,) or (B, 1) logits

  • Use batch-safe indexing ([:, 1])

  • Threshold tuned if imbalance exists

Forward-pass debug

  • model.eval() + torch.no_grad()

  • Batch dimension added (unsqueeze(0))

  • All required tensors moved to the same device

  • Print shapes/dtypes to confirm expectations


Metadata

Category: Engineering → ML Platform / Model Training
Tags: pytorch, transformers, debugging, binary-classification, fine-tuning