This article describes a practical pattern for validating a multi-stage ML inference pipeline using end-to-end (E2E) contract tests in Pytest. The objective is to detect regressions early by asserting that each pipeline stage (not just the final prediction) continues to produce expected outputs within defined tolerances.
Why this pattern exists
Complex ML pipelines typically include multiple steps before inference, such as:
- Session ingestion and schema normalization
- Windowing / segmentation
- Signal formatting and trimming
- Noise / motion filtering
- Feature extraction per model
- Model inference (one or more models)
- Post-processing / ensemble averaging
A regression in any intermediate step can silently degrade results. Traditional “final output only” tests often fail to pinpoint where the problem originated. Contract testing solves this by validating intermediate artifacts against versioned golden outputs.
Core concept
Stage contracts
A contract is what each stage promises to output:
- Required keys / schema
- Expected shapes
- Deterministic transformations (within numeric drift)
- Allowed differences (e.g., timestamps, request IDs)
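To make the idea concrete, here is a minimal, self-contained sketch of a single-record contract check. The field names and the `check_window_contract` helper are illustrative assumptions, not part of the pipeline above:

```python
import numpy as np

# Hypothetical contract for one windowed record (illustrative field names).
REQUIRED_KEYS = {"signal", "start_ts", "sample_rate"}

def check_window_contract(record: dict, expected_len: int) -> None:
    """Assert that a single windowed record satisfies its stage contract."""
    missing = REQUIRED_KEYS - record.keys()
    assert not missing, f"missing required keys: {missing}"
    signal = np.asarray(record["signal"], dtype=np.float64)
    assert signal.shape == (expected_len,), f"unexpected shape: {signal.shape}"
    assert np.isfinite(signal).all(), "non-finite samples in window"

# A conforming record passes; a wrong-length signal would raise AssertionError.
check_window_contract(
    {"signal": [0.1, 0.2, 0.3], "start_ts": 0.0, "sample_rate": 50},
    expected_len=3,
)
```

Volatile keys (timestamps, request IDs) are deliberately absent from `REQUIRED_KEYS`; they are handled by the ignore-list mechanism shown later.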
Golden artifacts
A golden artifact is a stored reference output for a stage, produced from a known-good pipeline version. Tests compare the current output to the golden output.
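A minimal sketch of how goldens can be written and read back, assuming pickle-based storage; the `save_golden` helper and the file layout are illustrative (the snippet below uses `load_artifact` as a placeholder for whatever loader your repo provides):

```python
import os
import pickle
import tempfile

def save_golden(path: str, obj) -> None:
    """Persist a stage output as a golden artifact (pickle handles dicts, arrays, DataFrames)."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)

def load_artifact(path: str):
    """Load a previously stored golden artifact."""
    with open(path, "rb") as f:
        return pickle.load(f)

# Round-trip one stage output under a temporary golden directory.
golden_dir = tempfile.mkdtemp()
path = os.path.join(golden_dir, "stage1_windows_0.pkl")
save_golden(path, [{"signal": [0.1, 0.2], "start_ts": 0.0}])
restored = load_artifact(path)
assert restored == [{"signal": [0.1, 0.2], "start_ts": 0.0}]
```

In practice the golden directory lives in version control (or artifact storage) so that every change to a golden is reviewable.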
Reference snippet
"""
E2E Contract Test Pattern (Pytest)
---------------------------------
Goal:
Validate each pipeline stage (preprocess → features → inference → ensemble)
against versioned "golden" artifacts to catch regressions early.
What you store as goldens:
- Stage outputs (lists of dicts, DataFrames, arrays)
- Feature tensors/arrays for each model
- Aggregate prediction summaries (means, ensembles)
How you compare:
- Structural equality via stable JSON serialization for nested objects
- Numerical closeness via allclose/isclose for float drift
- Ignore known-volatile keys (request_id, timestamps, runtime, etc.)
Replace placeholders:
<E2E_INPUT_DIR>, <GOLDEN_DIR>, run_preprocess_chain, extract_features_*, load_artifact, model_*
"""
import json
import os
import numpy as np
import pandas as pd
import pytest
# ----------------------------
# Helpers: serialization + comparisons
# ----------------------------
def to_jsonable(x):
"""Convert numpy/scalars/arrays into JSON-serializable equivalents."""
if isinstance(x, dict):
return {k: to_jsonable(v) for k, v in x.items()}
if isinstance(x, list):
return [to_jsonable(v) for v in x]
if isinstance(x, tuple):
return [to_jsonable(v) for v in x]
if isinstance(x, np.ndarray):
return x.tolist()
if isinstance(x, (np.integer, np.floating)):
return x.item()
return x
def assert_struct_equal(a, b):
"""
Contract check for nested dict/list outputs.
Uses JSON dumps with stable key ordering to prevent ordering noise.
"""
a2 = json.dumps(to_jsonable(a), sort_keys=True)
b2 = json.dumps(to_jsonable(b), sort_keys=True)
assert a2 == b2
def assert_arrays_close(a, b, *, atol=1e-2, rtol=1e-5):
"""
Contract check for numeric arrays/tensors.
Use tolerances to accommodate minor float drift.
"""
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
assert a.shape == b.shape
assert np.allclose(a, b, atol=atol, rtol=rtol, equal_nan=True)
def drop_keys(record: dict, ignore_keys: set[str]) -> dict:
"""Remove volatile keys (IDs, timestamps, etc.) before structural comparisons."""
return {k: v for k, v in record.items() if k not in ignore_keys}
# ----------------------------
# Discovery: input cases for parametrized execution
# ----------------------------
def list_cases(input_dir: str):
"""
Returns a list of (case_id, file_path) pairs.
case_id can be used to map to corresponding golden artifact names.
"""
files = sorted(f for f in os.listdir(input_dir) if f.endswith(".json"))
return [(i, os.path.join(input_dir, f)) for i, f in enumerate(files)]
# ----------------------------
# Fixture: load models / shared config once per class
# ----------------------------
@pytest.fixture(scope="class", autouse=True)
def setup_pipeline(request):
"""
One-time initialization:
- Load models
- Set input + golden artifact directories
- Optionally configure runtime flags for deterministic behavior
"""
request.cls.model_a = load_model("model_a") # e.g., CNN
request.cls.model_b = load_model("model_b") # e.g., LLM/Transformer
request.cls.e2e_input_dir = "<E2E_INPUT_DIR>"
request.cls.golden_dir = "<GOLDEN_DIR>"
# Keys that are expected to vary between runs (ignore in comparisons)
request.cls.ignore_keys = {"request_id", "timestamp", "runtime_seconds"}
# ----------------------------
# Test: Stage-by-stage contract validation
# ----------------------------
@pytest.mark.usefixtures("setup_pipeline")
class TestPipelineE2EContracts:
@pytest.mark.parametrize(("case_id", "input_file"), list_cases("<E2E_INPUT_DIR>"))
def test_pipeline_contracts(self, case_id: int, input_file: str):
# ---- Stage 0: Load raw input payload ----
payload = json.load(open(input_file))
# Optional: Validate raw payload matches expected input golden
# golden_input = load_artifact(f"{self.golden_dir}/stage0_input_{case_id}.pkl")
# assert_struct_equal(payload, golden_input)
# ---- Stage 1: Preprocessing (window → format → trim → filter) ----
windows = run_preprocess_chain(
payload,
steps=["window", "format", "trim", "filter"],
)
golden_windows = load_artifact(f"{self.golden_dir}/stage1_windows_{case_id}.pkl")
if isinstance(golden_windows, pd.DataFrame):
golden_windows = golden_windows.to_dict("records")
# Normalize records by dropping volatile keys prior to comparison
windows_norm = [drop_keys(r, self.ignore_keys) for r in windows]
golden_norm = [drop_keys(r, self.ignore_keys) for r in golden_windows]
assert_struct_equal(windows_norm, golden_norm)
# ---- Stage 2: Feature extraction per model ----
feat_a = extract_features_a(windows) # e.g., CNN pixel_values
feat_b = extract_features_b(windows) # e.g., LLM input_ids
golden_feat_a = load_artifact(f"{self.golden_dir}/stage2_feat_a_{case_id}.pkl")
golden_feat_b = load_artifact(f"{self.golden_dir}/stage2_feat_b_{case_id}.pkl")
assert_arrays_close(feat_a, golden_feat_a, atol=1e-2, rtol=1e-5)
assert_arrays_close(feat_b, golden_feat_b, atol=1e-2, rtol=1e-5)
# ---- Stage 3: Model inference ----
pred_a = self.model_a(feat_a) # vector per window
pred_b = self.model_b(feat_b) # vector per window
# Compare either full vectors or just stable aggregates (recommended)
pred_a_mean = float(np.mean(pred_a))
pred_b_mean = float(np.mean(pred_b))
# ---- Stage 4: Ensemble / post-processing rule ----
ensemble_mean = 0.5 * (pred_a_mean + pred_b_mean)
golden_preds = load_artifact(f"{self.golden_dir}/stage3_preds_{case_id}.pkl")
# golden_preds example:
# {"model_a_mean": 1.23, "model_b_mean": 1.10, "ensemble_mean": 1.165}
# Use looser tolerances for model outputs if needed
assert np.isclose(pred_a_mean, golden_preds["model_a_mean"], atol=1e-1, rtol=1e-5)
assert np.isclose(pred_b_mean, golden_preds["model_b_mean"], atol=1e-1, rtol=1e-5)
assert np.isclose(ensemble_mean, golden_preds["ensemble_mean"], atol=1e-1, rtol=1e-5)
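The comparison helpers are runnable on their own. Here is a quick standalone demonstration of the volatile-key handling, mirroring `drop_keys` and the stable-JSON comparison from the snippet above (the record contents are made up):

```python
import json

def drop_keys(record: dict, ignore_keys: set) -> dict:
    """Mirror of the helper above: strip volatile keys before comparison."""
    return {k: v for k, v in record.items() if k not in ignore_keys}

IGNORE = {"request_id", "timestamp"}

# Two runs that differ only in key order and volatile fields.
run1 = {"label": "walk", "score": 0.91, "request_id": "a1", "timestamp": 100}
run2 = {"request_id": "b2", "timestamp": 200, "score": 0.91, "label": "walk"}

s1 = json.dumps(drop_keys(run1, IGNORE), sort_keys=True)
s2 = json.dumps(drop_keys(run2, IGNORE), sort_keys=True)
assert s1 == s2  # identical once volatile keys are dropped and keys are sorted
```

`sort_keys=True` is what makes the serialization stable: dict insertion order stops mattering, so only genuine content differences can fail the comparison.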
Notes for teams adopting this pattern
- Prefer asserting intermediate stage outputs rather than only final predictions.
- Keep goldens versioned (e.g., goldens/v1, goldens/v2) and regenerate them only with explicit review.
- Use strict comparisons for deterministic transforms and tolerant comparisons for numeric/model outputs.
- Explicitly document ignored keys and tolerances so failures are explainable.
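Regeneration with explicit review can be wired in as an opt-in switch. A minimal sketch, with the caveat that the `--update-goldens` flag name and the `update_or_compare` helper are assumptions for illustration, not a built-in pytest feature:

```python
import pickle

def update_or_compare(path: str, current, compare, *, update: bool = False) -> None:
    """Rewrite the golden at `path` when update=True; otherwise load it and run `compare`."""
    if update:
        with open(path, "wb") as f:
            pickle.dump(current, f)
        return
    with open(path, "rb") as f:
        golden = pickle.load(f)
    compare(current, golden)

# In conftest.py you might expose the switch as a CLI flag:
# def pytest_addoption(parser):
#     parser.addoption("--update-goldens", action="store_true", default=False)
# ...then pass request.config.getoption("--update-goldens") as `update`.
```

Because regenerated goldens show up as file diffs in version control, the "explicit review" step falls out naturally from the normal code-review workflow.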
Metadata
Category: Engineering → MLOps → Testing & Validation
Tags: pytest, end-to-end-testing, contract-testing, golden-files, regression, ml-pipelines
