The model passes all tests. Accuracy is 94%. The PR gets approved.
Three days later, production is on fire. Predictions are nonsense. Customer complaints flood in. The post-mortem reveals: a feature column changed upstream. The schema was the same. The values were garbage.
No test caught it because no test checked for it.
ML testing isn’t just unit tests and accuracy metrics. Production ML systems fail in ways that traditional software testing doesn’t anticipate — data drift, distribution shift, feature corruption, silent model degradation. A comprehensive testing strategy addresses these failure modes at every layer.
The ML Testing Pyramid
Like the traditional testing pyramid, ML testing has layers. Unlike traditional testing, the base isn’t unit tests — it’s data validation.
- Data validation (base, most tests): Schema checks, distribution validation, drift detection, anomaly flagging. Data problems cause most ML failures. Test data first.
- Unit tests: Preprocessing functions, feature engineering, data transformations. Traditional software tests for the code that touches data.
- Model validation: Performance metrics, slice analysis, fairness checks, regression tests. Does the model behave correctly on known cases?
- Integration tests: End-to-end pipeline tests, API contract tests, latency validation. Does everything work together?
- Production testing (top, fewest but critical): Shadow deployment, canary analysis, A/B testing. Does it work with real traffic?
Invest proportionally: most effort at the base, less as you go up. Data tests catch more bugs than model tests. Model tests catch more bugs than integration tests. But you need all layers.
Layer 1: Data Validation
Data validation catches problems before they reach your model. This is where testing effort pays the highest returns.
Schema Validation
The minimum viable data test: does the data have the expected structure?
```python
from dataclasses import dataclass
from typing import Any, List, Optional, Set

import pandas as pd


class DataValidationError(Exception):
    """Raised when incoming data fails schema validation."""


@dataclass
class ColumnSchema:
    name: str
    dtype: str
    nullable: bool = False
    allowed_values: Optional[Set[Any]] = None
    min_value: Optional[float] = None
    max_value: Optional[float] = None


class SchemaValidator:
    def __init__(self, schema: List[ColumnSchema]):
        self.schema = {col.name: col for col in schema}

    def validate(self, df: pd.DataFrame) -> List[str]:
        errors = []

        # Check for missing columns
        missing = set(self.schema.keys()) - set(df.columns)
        if missing:
            errors.append(f"Missing columns: {missing}")

        # Check for unexpected columns
        unexpected = set(df.columns) - set(self.schema.keys())
        if unexpected:
            errors.append(f"Unexpected columns: {unexpected}")

        # Validate each column
        for col_name, col_schema in self.schema.items():
            if col_name not in df.columns:
                continue
            col = df[col_name]

            # Dtype
            if str(col.dtype) != col_schema.dtype:
                errors.append(f"{col_name}: dtype {col.dtype}, expected {col_schema.dtype}")

            # Nullability
            if not col_schema.nullable and col.isna().any():
                null_count = col.isna().sum()
                errors.append(f"{col_name}: {null_count} null values (not allowed)")

            # Allowed values (for categoricals)
            if col_schema.allowed_values:
                invalid = set(col.dropna().unique()) - col_schema.allowed_values
                if invalid:
                    errors.append(f"{col_name}: invalid values {invalid}")

            # Range checks
            if col_schema.min_value is not None:
                below_min = (col < col_schema.min_value).sum()
                if below_min:
                    errors.append(f"{col_name}: {below_min} values below {col_schema.min_value}")
            if col_schema.max_value is not None:
                above_max = (col > col_schema.max_value).sum()
                if above_max:
                    errors.append(f"{col_name}: {above_max} values above {col_schema.max_value}")

        return errors


# Usage
schema = [
    ColumnSchema("user_id", "int64", nullable=False),
    ColumnSchema("amount", "float64", nullable=False, min_value=0),
    ColumnSchema("category", "object", allowed_values={"A", "B", "C"}),
    ColumnSchema("timestamp", "datetime64[ns]", nullable=False),
]

validator = SchemaValidator(schema)
errors = validator.validate(incoming_data)
if errors:
    raise DataValidationError(errors)
```

Run schema validation on every data ingestion. Block the pipeline if validation fails. Silent data corruption is worse than a failed pipeline.
Distribution Validation
Schema can be correct while data is garbage. Distribution checks catch semantic problems:
```python
from typing import Dict, List

import pandas as pd


class DistributionValidator:
    def __init__(self, reference_stats: Dict[str, Dict]):
        self.reference = reference_stats

    def validate(self, df: pd.DataFrame, threshold: float = 0.1) -> List[str]:
        warnings = []
        for col_name, ref_stats in self.reference.items():
            if col_name not in df.columns:
                continue
            col = df[col_name].dropna()

            # Numeric columns: check mean/std shift
            if 'mean' in ref_stats:
                current_mean = col.mean()
                ref_mean = ref_stats['mean']
                ref_std = ref_stats['std']
                # Z-score of mean shift
                if ref_std > 0:
                    z_score = abs(current_mean - ref_mean) / ref_std
                    if z_score > 3:
                        warnings.append(
                            f"{col_name}: mean shifted from {ref_mean:.2f} to {current_mean:.2f} "
                            f"(z={z_score:.1f})"
                        )

            # Categorical columns: check distribution shift
            if 'value_counts' in ref_stats:
                current_dist = col.value_counts(normalize=True).to_dict()
                ref_dist = ref_stats['value_counts']
                # Simple distribution comparison
                for value, ref_pct in ref_dist.items():
                    current_pct = current_dist.get(value, 0)
                    if abs(current_pct - ref_pct) > threshold:
                        warnings.append(
                            f"{col_name}={value}: {ref_pct:.1%} -> {current_pct:.1%}"
                        )
        return warnings

    @classmethod
    def from_dataframe(cls, df: pd.DataFrame) -> 'DistributionValidator':
        """Create validator from reference dataframe."""
        reference = {}
        for col in df.columns:
            if df[col].dtype in ['int64', 'float64']:
                reference[col] = {
                    'mean': df[col].mean(),
                    'std': df[col].std(),
                    'min': df[col].min(),
                    'max': df[col].max(),
                }
            else:
                reference[col] = {
                    'value_counts': df[col].value_counts(normalize=True).to_dict()
                }
        return cls(reference)
```

Generate reference statistics from known-good training data. Compare incoming data against the reference. Alert on significant drift.
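The z-score check above only looks at the mean of each numeric column. A two-sample Kolmogorov–Smirnov test compares entire distributions, catching shifts in shape or variance that leave the mean untouched. A minimal sketch using `scipy.stats.ks_2samp` (the 0.05 significance level is an illustrative choice, not a recommendation):

```python
import numpy as np
from scipy import stats


def detect_drift(reference: np.ndarray, current: np.ndarray,
                 alpha: float = 0.05) -> dict:
    """Two-sample KS test: could these samples come from the same distribution?"""
    result = stats.ks_2samp(reference, current)
    return {
        "statistic": float(result.statistic),
        "p_value": float(result.pvalue),
        # Reject "same distribution" at significance level alpha
        "drifted": bool(result.pvalue < alpha),
    }


rng = np.random.default_rng(42)
baseline = rng.normal(loc=0.0, scale=1.0, size=5000)

same = detect_drift(baseline, rng.normal(loc=0.0, scale=1.0, size=5000))
shifted = detect_drift(baseline, rng.normal(loc=0.5, scale=1.0, size=5000))
```

One caveat: with large samples the KS test flags even tiny, practically irrelevant shifts, so pair the p-value with an effect-size threshold on the statistic before paging anyone.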
Data Expectations
For more sophisticated validation, define expectations declaratively:
```python
# Using a Great Expectations-style approach
expectations = [
    # Column existence
    {"type": "expect_column_to_exist", "column": "user_id"},
    {"type": "expect_column_to_exist", "column": "amount"},

    # Type checks
    {"type": "expect_column_values_to_be_of_type", "column": "amount", "type": "float"},

    # Value constraints
    {"type": "expect_column_values_to_be_between", "column": "amount", "min": 0, "max": 1000000},
    {"type": "expect_column_values_to_not_be_null", "column": "user_id"},

    # Distribution checks
    {"type": "expect_column_mean_to_be_between", "column": "amount", "min": 50, "max": 500},
    {"type": "expect_column_proportion_of_unique_values_to_be_between",
     "column": "category", "min": 0.001, "max": 0.1},

    # Cardinality
    {"type": "expect_column_unique_value_count_to_be_between",
     "column": "category", "min": 3, "max": 10},
]
```

Declarative expectations serve as documentation and tests simultaneously. They're readable, maintainable, and version-controllable.
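A few of these expectation types can be evaluated with a small dispatcher. This is a hedged sketch of the idea, not the Great Expectations API, which provides far richer validation logic, result metadata, and data-source connectors:

```python
from typing import Any, Dict, List

import pandas as pd


def run_expectations(df: pd.DataFrame, expectations: List[Dict[str, Any]]) -> List[str]:
    """Evaluate a subset of declarative expectations; return failure messages."""
    failures = []
    for exp in expectations:
        kind, col = exp["type"], exp["column"]
        if kind == "expect_column_to_exist":
            if col not in df.columns:
                failures.append(f"{col}: column missing")
        elif kind == "expect_column_values_to_not_be_null":
            if df[col].isna().any():
                failures.append(f"{col}: contains nulls")
        elif kind == "expect_column_values_to_be_between":
            outside = ~df[col].between(exp["min"], exp["max"])
            if outside.any():
                failures.append(f"{col}: {outside.sum()} values outside [{exp['min']}, {exp['max']}]")
        elif kind == "expect_column_mean_to_be_between":
            mean = df[col].mean()
            if not (exp["min"] <= mean <= exp["max"]):
                failures.append(f"{col}: mean {mean:.2f} outside [{exp['min']}, {exp['max']}]")
        else:
            failures.append(f"{col}: unsupported expectation {kind}")
    return failures


df = pd.DataFrame({"user_id": [1, 2, 3], "amount": [10.0, 250.0, 9000.0]})
checks = [
    {"type": "expect_column_to_exist", "column": "user_id"},
    {"type": "expect_column_values_to_be_between", "column": "amount", "min": 0, "max": 1000},
]
failures = run_expectations(df, checks)  # one failure: amount has a value above 1000
```

Treating unknown expectation types as failures (rather than skipping them) keeps a typo in an expectation name from silently disabling a check.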
Layer 2: Unit Tests
Unit tests cover the code that transforms data: preprocessing, feature engineering, custom layers.
```python
import numpy as np
import pytest


class TestFeatureEngineering:
    def test_normalize_standard_case(self):
        result = normalize(50, min_val=0, max_val=100)
        assert result == 0.5

    def test_normalize_edge_cases(self):
        assert normalize(0, min_val=0, max_val=100) == 0.0
        assert normalize(100, min_val=0, max_val=100) == 1.0

    def test_normalize_out_of_range(self):
        # Should clip, not crash
        assert normalize(150, min_val=0, max_val=100) == 1.0
        assert normalize(-50, min_val=0, max_val=100) == 0.0

    def test_extract_features_shape(self):
        transaction = {"amount": 100, "category": "retail", "hour": 14}
        features = extract_features(transaction)
        assert features.shape == (128,)
        assert features.dtype == np.float32

    def test_extract_features_deterministic(self):
        transaction = {"amount": 100, "category": "retail", "hour": 14}
        f1 = extract_features(transaction)
        f2 = extract_features(transaction)
        np.testing.assert_array_equal(f1, f2)

    def test_extract_features_missing_field(self):
        transaction = {"amount": 100}  # Missing category and hour
        with pytest.raises(ValueError, match="missing required field"):
            extract_features(transaction)
```

Property-based testing catches edge cases you didn't think of:
```python
import numpy as np
import pytest
from hypothesis import given, strategies as st


@given(st.floats(min_value=-1e6, max_value=1e6, allow_nan=False))
def test_normalize_always_returns_valid_range(value):
    result = normalize(value, min_val=0, max_val=100)
    assert 0.0 <= result <= 1.0
    assert not np.isnan(result)


@given(st.dictionaries(
    keys=st.sampled_from(["amount", "category", "hour"]),
    values=st.one_of(st.integers(), st.floats(), st.text())
))
def test_extract_features_never_crashes(data):
    # Should either return valid features or raise a clean exception
    try:
        features = extract_features(data)
        assert features.shape == (128,)
    except ValueError:
        pass  # Expected for invalid input
    except Exception as e:
        pytest.fail(f"Unexpected exception: {e}")
```

Layer 3: Model Validation
Model validation goes beyond aggregate accuracy to check behaviour on important slices and edge cases.
Slice-Based Evaluation
Aggregate metrics hide problems. A model with 95% accuracy overall might have 60% accuracy on a critical segment:
```python
from typing import Dict

from sklearn.metrics import accuracy_score, precision_score, recall_score


def evaluate_by_slice(model, X, y, slice_fn, slice_name: str) -> Dict:
    """Evaluate model performance on a specific data slice."""
    mask = slice_fn(X)
    X_slice = X[mask]
    y_slice = y[mask]
    if len(X_slice) == 0:
        return {"slice": slice_name, "count": 0, "metrics": None}
    y_pred = model.predict(X_slice)
    return {
        "slice": slice_name,
        "count": len(X_slice),
        "metrics": {
            "accuracy": accuracy_score(y_slice, y_pred),
            "precision": precision_score(y_slice, y_pred, average='weighted'),
            "recall": recall_score(y_slice, y_pred, average='weighted'),
        },
    }


# Define critical slices
slices = [
    ("new_users", lambda X: X["account_age_days"] < 30),
    ("high_value", lambda X: X["lifetime_value"] > 1000),
    ("mobile", lambda X: X["device_type"] == "mobile"),
    ("international", lambda X: X["country"] != "US"),
]

# Evaluate and assert minimums
for slice_name, slice_fn in slices:
    result = evaluate_by_slice(model, X_test, y_test, slice_fn, slice_name)
    if result["count"] > 100:  # Only check slices with enough data
        assert result["metrics"]["accuracy"] > 0.85, \
            f"Accuracy on {slice_name} below threshold: {result['metrics']['accuracy']}"
```

Regression Testing
Keep a set of known examples with expected predictions. Test that the model still handles them correctly:
`regression_cases.json`:

```json
[
  {
    "id": "obvious_fraud_1",
    "input": {"amount": 9999, "country": "XX", "velocity": 50},
    "expected_class": "fraud",
    "min_confidence": 0.9
  },
  {
    "id": "obvious_legitimate_1",
    "input": {"amount": 25, "country": "US", "velocity": 1},
    "expected_class": "legitimate",
    "min_confidence": 0.8
  },
  {
    "id": "edge_case_high_amount_trusted_user",
    "input": {"amount": 5000, "country": "US", "velocity": 1, "account_age": 3650},
    "expected_class": "legitimate",
    "min_confidence": 0.6
  }
]
```

```python
import json


def test_regression_cases(model, cases_file: str):
    with open(cases_file) as f:
        cases = json.load(f)

    failures = []
    for case in cases:
        # Assumes predict_proba returns a mapping from class label to probability
        pred = model.predict_proba(case["input"])
        pred_class = model.predict(case["input"])
        if pred_class != case["expected_class"]:
            failures.append(f"{case['id']}: expected {case['expected_class']}, got {pred_class}")
        elif pred[case["expected_class"]] < case["min_confidence"]:
            failures.append(
                f"{case['id']}: confidence {pred[case['expected_class']]:.2f} "
                f"below threshold {case['min_confidence']}"
            )
    assert not failures, "Regression failures:\n" + "\n".join(failures)
```

Add cases when you find bugs. The regression suite grows over time, encoding institutional knowledge about edge cases.
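One way to make "every bug becomes a test" cheap is a helper that appends to the case file. A minimal sketch; the field names mirror the example file above, and the helper itself is a hypothetical convenience, not part of any library:

```python
import json
from pathlib import Path


def add_regression_case(cases_file: str, case_id: str, input_data: dict,
                        expected_class: str, min_confidence: float = 0.5) -> None:
    """Append a new known-answer case to the regression suite."""
    path = Path(cases_file)
    cases = json.loads(path.read_text()) if path.exists() else []
    # Refuse duplicates so case ids stay stable references in bug reports
    if any(c["id"] == case_id for c in cases):
        raise ValueError(f"case id already exists: {case_id}")
    cases.append({
        "id": case_id,
        "input": input_data,
        "expected_class": expected_class,
        "min_confidence": min_confidence,
    })
    path.write_text(json.dumps(cases, indent=2))
```

After a post-mortem, one call captures the incident as a permanent test case, and the deliberately low default confidence threshold avoids flaky failures while the model is still learning the pattern.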
Fairness Checks
Test for unintended bias across protected attributes:
```python
def test_demographic_parity(model, X, sensitive_attr: str, threshold: float = 0.1):
    """Check that positive prediction rate is similar across groups."""
    predictions = model.predict(X)
    groups = X[sensitive_attr].unique()

    positive_rates = {}
    for group in groups:
        mask = X[sensitive_attr] == group
        positive_rates[group] = predictions[mask].mean()

    max_rate = max(positive_rates.values())
    min_rate = min(positive_rates.values())
    disparity = max_rate - min_rate
    assert disparity < threshold, \
        f"Demographic disparity {disparity:.2%} exceeds threshold. Rates: {positive_rates}"
```

Layer 4: Integration Tests
Integration tests verify the full pipeline works end-to-end:
```python
import os
import time

from fastapi.testclient import TestClient  # assuming a FastAPI service


def test_inference_pipeline_e2e():
    """Test complete inference pipeline with realistic request."""
    # Setup
    client = TestClient(app)
    sample_request = load_fixture("sample_inference_request.json")

    # Execute
    start = time.monotonic()
    response = client.post("/predict", json=sample_request)
    latency_ms = (time.monotonic() - start) * 1000

    # Verify
    assert response.status_code == 200
    assert "prediction" in response.json()
    assert "confidence" in response.json()
    assert 0 <= response.json()["confidence"] <= 1
    assert latency_ms < 100, f"Latency {latency_ms}ms exceeds 100ms SLA"


def test_training_pipeline_e2e():
    """Test complete training pipeline produces valid model."""
    # Run training on small dataset
    result = training_pipeline.run(
        data_path="test_fixtures/small_training_set.parquet",
        output_path="/tmp/test_model",
        epochs=2,
    )

    # Verify outputs
    assert result.status == "success"
    assert os.path.exists("/tmp/test_model/model.pt")
    assert os.path.exists("/tmp/test_model/metrics.json")

    # Verify model loads and predicts
    model = load_model("/tmp/test_model")
    test_input = load_fixture("sample_input.json")
    prediction = model.predict(test_input)
    assert prediction is not None
```

Layer 5: Production Testing
The final layer tests in production, with real traffic.
Shadow Deployment
Run the new model alongside production, compare outputs, don’t serve results:
```python
import random


class ShadowDeployment:
    def __init__(self, prod_model, shadow_model, sample_rate: float = 0.1):
        self.prod = prod_model
        self.shadow = shadow_model
        self.sample_rate = sample_rate
        self.comparisons = []

    def predict(self, input_data):
        # Always get production prediction
        prod_result = self.prod.predict(input_data)

        # Conditionally run shadow
        if random.random() < self.sample_rate:
            shadow_result = self.shadow.predict(input_data)
            self.comparisons.append({
                "input_hash": hash_input(input_data),
                "prod": prod_result,
                "shadow": shadow_result,
                "match": prod_result == shadow_result,
            })

        # Always return production result
        return prod_result
```

Canary Analysis
Route small percentage of traffic to new model, monitor for degradation:
```python
from typing import Dict


def analyze_canary(control_metrics: Dict, canary_metrics: Dict, threshold: float = 0.05):
    """Compare canary against control, return go/no-go decision."""
    checks = []

    # Error rate comparison
    error_increase = canary_metrics["error_rate"] - control_metrics["error_rate"]
    checks.append({
        "metric": "error_rate",
        "control": control_metrics["error_rate"],
        "canary": canary_metrics["error_rate"],
        "pass": error_increase < threshold,
    })

    # Latency comparison
    latency_increase = (
        (canary_metrics["p99_latency"] - control_metrics["p99_latency"])
        / control_metrics["p99_latency"]
    )
    checks.append({
        "metric": "p99_latency",
        "control": control_metrics["p99_latency"],
        "canary": canary_metrics["p99_latency"],
        "pass": latency_increase < 0.2,  # Allow 20% latency increase
    })

    all_passed = all(c["pass"] for c in checks)
    return {"decision": "proceed" if all_passed else "rollback", "checks": checks}
```

The Testing Mindset
ML testing isn’t a phase — it’s continuous. Data changes. Models drift. Production reveals failure modes that tests missed.
Build testing into the pipeline:
- Data validation on every ingestion
- Unit tests on every commit
- Model validation on every training run
- Integration tests on every deployment
- Production testing on every release
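As a sketch, that mapping can live in a single registry that CI hooks consult at each stage. Every name below is a hypothetical placeholder for your project's real entry points:

```python
from typing import Callable, Dict, List

# Placeholder checks -- in a real pipeline these would invoke pytest,
# the schema validator, model validation, integration suites, etc.
def validate_data() -> bool: return True
def run_unit_tests() -> bool: return True
def validate_model() -> bool: return True
def run_integration_tests() -> bool: return True

# Which checks run at which pipeline stage
STAGE_CHECKS: Dict[str, List[Callable[[], bool]]] = {
    "ingestion": [validate_data],
    "commit": [run_unit_tests],
    "training": [validate_data, run_unit_tests, validate_model],
    "deployment": [run_integration_tests],
}


def gate(stage: str) -> bool:
    """Return True only if every check registered for this stage passes."""
    return all(check() for check in STAGE_CHECKS[stage])
```

Keeping the stage-to-check mapping in one place makes it easy to audit which failure modes each pipeline event actually guards against.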
The goal isn’t perfect coverage — it’s catching problems before users do. Every production incident should add a test. Over time, your test suite encodes everything you’ve learned about how your system fails.
For more on ML system reliability, see The Observability Blind Spot for monitoring production systems, and Graceful Degradation in ML Systems for handling failures gracefully. For deterministic testing approaches, explore certifiable-harness.