
Testing ML Systems: Beyond Unit Tests and Accuracy Metrics

A practical testing strategy for production machine learning

Published January 24, 2026 00:30
[Figure: ML testing pyramid showing data validation at the base, then unit tests, model validation, integration tests, and production testing at the top]

The model passes all tests. Accuracy is 94%. The PR gets approved.

Three days later, production is on fire. Predictions are nonsense. Customer complaints flood in. The post-mortem reveals: a feature column changed upstream. The schema was the same. The values were garbage.

No test caught it because no test checked for it.

ML testing isn’t just unit tests and accuracy metrics. Production ML systems fail in ways that traditional software testing doesn’t anticipate — data drift, distribution shift, feature corruption, silent model degradation. A comprehensive testing strategy addresses these failure modes at every layer.

The ML Testing Pyramid

Like the traditional testing pyramid, ML testing has layers. Unlike traditional testing, the base isn’t unit tests — it’s data validation.

Data validation (base, most tests): Schema checks, distribution validation, drift detection, anomaly flagging. Data problems cause most ML failures. Test data first.

Unit tests: Preprocessing functions, feature engineering, data transformations. Traditional software tests for the code that touches data.

Model validation: Performance metrics, slice analysis, fairness checks, regression tests. Does the model behave correctly on known cases?

Integration tests: End-to-end pipeline tests, API contract tests, latency validation. Does everything work together?

Production testing (top, fewest but critical): Shadow deployment, canary analysis, A/B testing. Does it work with real traffic?

Invest proportionally: most effort at the base, less as you go up. Data tests catch more bugs than model tests. Model tests catch more bugs than integration tests. But you need all layers.

Layer 1: Data Validation

Data validation catches problems before they reach your model. This is where testing effort pays the highest returns.

Schema Validation

The minimum viable data test: does the data have the expected structure?

from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Set
import pandas as pd

class DataValidationError(Exception):
    """Raised when incoming data fails validation."""

@dataclass
class ColumnSchema:
    name: str
    dtype: str                                # expected pandas dtype name, e.g. "int64"
    nullable: bool = False
    allowed_values: Optional[Set[Any]] = None
    min_value: Optional[float] = None
    max_value: Optional[float] = None

class SchemaValidator:
    def __init__(self, schema: List[ColumnSchema]):
        self.schema = {col.name: col for col in schema}
    
    def validate(self, df: pd.DataFrame) -> List[str]:
        errors = []
        
        # Check for missing columns
        missing = set(self.schema.keys()) - set(df.columns)
        if missing:
            errors.append(f"Missing columns: {missing}")
        
        # Check for unexpected columns
        unexpected = set(df.columns) - set(self.schema.keys())
        if unexpected:
            errors.append(f"Unexpected columns: {unexpected}")
        
        # Validate each column
        for col_name, col_schema in self.schema.items():
            if col_name not in df.columns:
                continue
            
            col = df[col_name]
            
            # Type check; skip value checks on a mistyped column, since
            # comparing values of the wrong type can raise or mislead
            if str(col.dtype) != col_schema.dtype:
                errors.append(f"{col_name}: dtype {col.dtype}, expected {col_schema.dtype}")
                continue
            
            # Nullability
            if not col_schema.nullable and col.isna().any():
                null_count = col.isna().sum()
                errors.append(f"{col_name}: {null_count} null values (not allowed)")
            
            # Allowed values (for categoricals)
            if col_schema.allowed_values is not None:
                invalid = set(col.dropna().unique()) - col_schema.allowed_values
                if invalid:
                    errors.append(f"{col_name}: invalid values {invalid}")
            
            # Range checks
            if col_schema.min_value is not None:
                below_min = (col < col_schema.min_value).sum()
                if below_min:
                    errors.append(f"{col_name}: {below_min} values below {col_schema.min_value}")
            
            if col_schema.max_value is not None:
                above_max = (col > col_schema.max_value).sum()
                if above_max:
                    errors.append(f"{col_name}: {above_max} values above {col_schema.max_value}")
        
        return errors

# Usage (incoming_data is the DataFrame being ingested)
schema = [
    ColumnSchema("user_id", "int64", nullable=False),
    ColumnSchema("amount", "float64", nullable=False, min_value=0),
    ColumnSchema("category", "object", allowed_values={"A", "B", "C"}),
    ColumnSchema("timestamp", "datetime64[ns]", nullable=False),
]

validator = SchemaValidator(schema)
errors = validator.validate(incoming_data)
if errors:
    raise DataValidationError(errors)

Run schema validation on every data ingestion. Block the pipeline if validation fails. Silent data corruption is worse than a failed pipeline.

Distribution Validation

Schema can be correct while data is garbage. Distribution checks catch semantic problems:

from typing import Dict, List
import pandas as pd

class DistributionValidator:
    def __init__(self, reference_stats: Dict[str, Dict]):
        self.reference = reference_stats
    
    def validate(self, df: pd.DataFrame, threshold: float = 0.1) -> List[str]:
        """Compare df to the reference; threshold is the max change in a category's share."""
        warnings = []
        
        for col_name, ref_stats in self.reference.items():
            if col_name not in df.columns:
                continue
            
            col = df[col_name].dropna()
            
            # Numeric columns: check mean shift
            if 'mean' in ref_stats:
                current_mean = col.mean()
                ref_mean = ref_stats['mean']
                ref_std = ref_stats['std']
                
                # Mean shift in units of the reference standard deviation. This is a
                # coarse heuristic, not a formal z-test (which would divide by the
                # standard error), so it only flags large shifts.
                if ref_std > 0:
                    z_score = abs(current_mean - ref_mean) / ref_std
                    if z_score > 3:
                        warnings.append(
                            f"{col_name}: mean shifted from {ref_mean:.2f} to {current_mean:.2f} "
                            f"(z={z_score:.1f})"
                        )
            
            # Categorical columns: check distribution shift
            if 'value_counts' in ref_stats:
                current_dist = col.value_counts(normalize=True).to_dict()
                ref_dist = ref_stats['value_counts']
                
                # Compare each reference category's share; a category that is new
                # in the current data only shows up as drops in existing shares
                for value, ref_pct in ref_dist.items():
                    current_pct = current_dist.get(value, 0)
                    if abs(current_pct - ref_pct) > threshold:
                        warnings.append(
                            f"{col_name}={value}: {ref_pct:.1%} -> {current_pct:.1%}"
                        )
        
        return warnings
    
    @classmethod
    def from_dataframe(cls, df: pd.DataFrame) -> 'DistributionValidator':
        """Create validator from a known-good reference dataframe."""
        reference = {}
        for col in df.columns:
            if pd.api.types.is_datetime64_any_dtype(df[col]):
                continue  # timestamps need time-aware checks, not value counts
            if pd.api.types.is_numeric_dtype(df[col]):
                reference[col] = {
                    'mean': df[col].mean(),
                    'std': df[col].std(),
                    'min': df[col].min(),
                    'max': df[col].max()
                }
            else:
                reference[col] = {
                    'value_counts': df[col].value_counts(normalize=True).to_dict()
                }
        return cls(reference)

Generate reference statistics from known-good training data. Compare incoming data against the reference. Alert on significant drift.
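The mean and share checks above are deliberately simple. A common complementary drift score for numeric features is the Population Stability Index (PSI), which compares binned distributions rather than a single moment. The sketch below is an illustrative addition, not part of the validator above; the function name, the quantile binning and the `eps` floor are assumptions. A frequently quoted rule of thumb reads PSI below 0.1 as stable and above 0.25 as a significant shift, but thresholds are best tuned per feature.

```python
import numpy as np

def population_stability_index(reference, current, n_bins: int = 10, eps: float = 1e-6) -> float:
    """PSI between a reference sample and a current sample of one numeric feature.

    Illustrative sketch: bins come from reference quantiles, and empty bins are
    floored at eps so the log term stays finite.
    """
    reference = np.asarray(reference, dtype=float)
    current = np.asarray(current, dtype=float)

    # Interior cut points at reference quantiles; np.unique drops duplicate
    # cut points that appear when the feature has many tied values.
    cuts = np.unique(np.quantile(reference, np.linspace(0, 1, n_bins + 1)[1:-1]))
    n = len(cuts) + 1  # number of bins; values beyond the cuts land in the end bins

    ref_counts = np.bincount(np.searchsorted(cuts, reference, side="right"), minlength=n)
    cur_counts = np.bincount(np.searchsorted(cuts, current, side="right"), minlength=n)

    ref_pct = np.clip(ref_counts / ref_counts.sum(), eps, None)
    cur_pct = np.clip(cur_counts / cur_counts.sum(), eps, None)

    # PSI = sum over bins of (current - reference) * ln(current / reference)
    return float(np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct)))

rng = np.random.default_rng(0)
baseline = rng.normal(100, 15, 10_000)
print(population_stability_index(baseline, baseline))       # 0.0
print(population_stability_index(baseline, baseline + 30))  # well above 0.25
```

Quantile-based bins give each reference bin roughly equal mass, so no single bin dominates the score when the reference is skewed.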

Data Expectations

For more sophisticated validation, define expectations declaratively:

# Using a Great Expectations-style approach
expectations = [
    # Column existence
    {"type": "expect_column_to_exist", "column": "user_id"},
    {"type": "expect_column_to_exist", "column": "amount"},
    
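    # Type checks ("type_" rather than a second "type" key, which would
    # silently overwrite the expectation name in this dict)
    {"type": "expect_column_values_to_be_of_type", "column": "amount", "type_": "float64"},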
    
    # Value constraints
    {"type": "expect_column_values_to_be_between", "column": "amount", "min": 0, "max": 1000000},
    {"type": "expect_column_values_to_not_be_null", "column": "user_id"},
    
    # Distribution checks
    {"type": "expect_column_mean_to_be_between", "column": "amount", "min": 50, "max": 500},
    {"type": "expect_column_proportion_of_unique_values_to_be_between", 
     "column": "category", "min": 0.001, "max": 0.1},
    
    # Cardinality
    {"type": "expect_column_unique_value_count_to_be_between",
     "column": "category", "min": 3, "max": 10},
]

Declarative expectations serve as documentation and tests simultaneously. They’re readable, maintainable, and version-controllable.
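To show how a declarative list like this can drive checks, here is a minimal, hypothetical evaluator for a few of the expectation types above, written directly against pandas. It is not the Great Expectations API, which has its own validators and result objects; `CHECKS` and `run_expectations` are illustrative names, and the `min`/`max` keys follow the list above.

```python
from typing import Callable, Dict, List
import pandas as pd

# Each check takes the column and the expectation dict and returns True on pass.
CHECKS: Dict[str, Callable[[pd.Series, dict], bool]] = {
    "expect_column_values_to_not_be_null":
        lambda s, e: not s.isna().any(),
    "expect_column_values_to_be_between":
        lambda s, e: bool(s.dropna().between(e["min"], e["max"]).all()),
    "expect_column_mean_to_be_between":
        lambda s, e: e["min"] <= s.mean() <= e["max"],
    "expect_column_unique_value_count_to_be_between":
        lambda s, e: e["min"] <= s.nunique() <= e["max"],
    "expect_column_proportion_of_unique_values_to_be_between":
        lambda s, e: s.count() > 0 and e["min"] <= s.nunique() / s.count() <= e["max"],
}

def run_expectations(df: pd.DataFrame, expectations: List[dict]) -> List[str]:
    """Return failure messages; an empty list means every expectation passed."""
    failures = []
    for exp in expectations:
        kind, column = exp["type"], exp["column"]
        if column not in df.columns:
            failures.append(f"{kind}: column {column!r} is missing")
            continue
        if kind == "expect_column_to_exist":
            continue  # column exists, nothing more to check
        check = CHECKS.get(kind)
        if check is None:
            # Fail loudly rather than silently skipping an expectation we can't evaluate
            failures.append(f"{kind}: unsupported expectation type")
        elif not check(df[column], exp):
            failures.append(f"{kind}: failed on column {column!r}")
    return failures

df = pd.DataFrame({
    "user_id": [1, 2, 3, None],
    "amount": [10.0, 200.0, 300.0, 90.0],
})
expectations = [
    {"type": "expect_column_to_exist", "column": "user_id"},
    {"type": "expect_column_values_to_not_be_null", "column": "user_id"},
    {"type": "expect_column_values_to_be_between", "column": "amount", "min": 0, "max": 1000000},
    {"type": "expect_column_mean_to_be_between", "column": "amount", "min": 50, "max": 500},
]
print(run_expectations(df, expectations))
# ["expect_column_values_to_not_be_null: failed on column 'user_id'"]
```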

Layer 2: Unit Tests

Unit tests cover the code that transforms data: preprocessing, feature engineering, custom layers.

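import pytest
import numpy as np

# Hypothetical module path for the code under test; adjust to your project
from features import normalize, extract_features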

class TestFeatureEngineering:
    
    def test_normalize_standard_case(self):
        result = normalize(50, min_val=0, max_val=100)
        assert result == 0.5
    
    def test_normalize_edge_cases(self):
        assert normalize(0, min_val=0, max_val=100) == 0.0
        assert normalize(100, min_val=0, max_val=100) == 1.0
    
    def test_normalize_out_of_range(self):
        # Should clip, not crash
        assert normalize(150, min_val=0, max_val=100) == 1.0
        assert normalize(-50, min_val=0, max_val=100) == 0.0
    
    def test_extract_features_shape(self):
        transaction = {"amount": 100, "category": "retail", "hour": 14}
        features = extract_features(transaction)
        assert features.shape == (128,)
        assert features.dtype == np.float32
    
    def test_extract_features_deterministic(self):
        transaction = {"amount": 100, "category": "retail", "hour": 14}
        f1 = extract_features(transaction)
        f2 = extract_features(transaction)
        np.testing.assert_array_equal(f1, f2)
    
    def test_extract_features_missing_field(self):
        transaction = {"amount": 100}  # Missing category and hour
        with pytest.raises(ValueError, match="missing required field"):
            extract_features(transaction)
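The tests above pin down a contract for `normalize` without showing it. A minimal sketch that satisfies them, assuming plain min-max scaling with clipping (the article does not show the real implementation), could look like this:

```python
import math

def normalize(value: float, min_val: float, max_val: float) -> float:
    """Min-max scale value into [0, 1], clipping anything outside [min_val, max_val].

    Sketch only: rejects NaN and a degenerate range instead of returning garbage.
    """
    if math.isnan(value):
        raise ValueError("value is NaN")
    if max_val <= min_val:
        raise ValueError("max_val must be greater than min_val")
    scaled = (value - min_val) / (max_val - min_val)
    return min(max(scaled, 0.0), 1.0)  # clip instead of crashing on out-of-range input
```

Raising on NaN is a choice: without the explicit check, `min(max(nan, 0.0), 1.0)` would quietly return NaN.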

Property-based testing catches edge cases you didn’t think of:

from hypothesis import given, strategies as st

@given(st.floats(min_value=-1e6, max_value=1e6, allow_nan=False))
def test_normalize_always_returns_valid_range(value):
    result = normalize(value, min_val=0, max_val=100)
    assert 0.0 <= result <= 1.0
    assert not np.isnan(result)

@given(st.dictionaries(
    keys=st.sampled_from(["amount", "category", "hour"]),
    values=st.one_of(st.integers(), st.floats(), st.text())
))
def test_extract_features_never_crashes(data):
    # Should either return valid features or raise clean exception
    try:
        features = extract_features(data)
        assert features.shape == (128,)
    except ValueError:
        pass  # Expected for invalid input
    except Exception as e:
        pytest.fail(f"Unexpected exception: {e}")
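Hypothesis will generate inputs a hand-written test would not, such as integers too large to convert to a float (`float(10**400)` raises `OverflowError`, not `ValueError`). The sketch below is one hypothetical `extract_features` that satisfies both the example-based and property-based tests; the 128-slot layout (log amount, one-hot hour, hashed category) is an assumption for illustration, not the article's actual feature set.

```python
import zlib
import numpy as np

FEATURE_DIM = 128
REQUIRED_FIELDS = ("amount", "category", "hour")
HOUR_OFFSET = 1                      # slots 1..24: one-hot hour of day
CATEGORY_OFFSET = HOUR_OFFSET + 24   # slots 25..127: hashed category

def extract_features(transaction: dict) -> np.ndarray:
    """Map a transaction dict to a fixed-size float32 vector (hypothetical layout)."""
    missing = [f for f in REQUIRED_FIELDS if f not in transaction]
    if missing:
        raise ValueError(f"missing required field(s): {missing}")

    # Normalize every conversion failure to ValueError, the one exception
    # callers (and the property test) are told to expect.
    try:
        amount = float(transaction["amount"])
        hour = int(transaction["hour"])
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError(f"invalid field value: {exc}") from exc
    if not np.isfinite(amount):
        raise ValueError("amount must be finite")
    if not 0 <= hour <= 23:
        raise ValueError("hour must be in [0, 23]")

    features = np.zeros(FEATURE_DIM, dtype=np.float32)
    features[0] = np.log1p(max(amount, 0.0))  # negative amounts clipped to 0 in this sketch
    features[HOUR_OFFSET + hour] = 1.0
    # zlib.crc32 is stable across processes, unlike the built-in hash() of a str,
    # so the output stays deterministic between runs.
    bucket = zlib.crc32(str(transaction["category"]).encode("utf-8")) % (FEATURE_DIM - CATEGORY_OFFSET)
    features[CATEGORY_OFFSET + bucket] = 1.0
    return features
```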

Layer 3: Model Validation

Model validation goes beyond aggregate accuracy to check behaviour on important slices and edge cases.

Slice-Based Evaluation

Aggregate metrics hide problems. A model with 95% accuracy overall might have 60% accuracy on a critical segment:

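from typing import Dict
from sklearn.metrics import accuracy_score, precision_score, recall_score

def evaluate_by_slice(model, X, y, slice_fn, slice_name: str) -> Dict:
    """Evaluate model performance on a specific data slice (X is a pandas DataFrame)."""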
    mask = slice_fn(X)
    X_slice = X[mask]
    y_slice = y[mask]
    
    if len(X_slice) == 0:
        return {"slice": slice_name, "count": 0, "metrics": None}
    
    y_pred = model.predict(X_slice)
    
    return {
        "slice": slice_name,
        "count": len(X_slice),
        "metrics": {
            "accuracy": accuracy_score(y_slice, y_pred),
            "precision": precision_score(y_slice, y_pred, average='weighted'),
            "recall": recall_score(y_slice, y_pred, average='weighted'),
        }
    }

# Define critical slices
slices = [
    ("new_users", lambda X: X["account_age_days"] < 30),
    ("high_value", lambda X: X["lifetime_value"] > 1000),
    ("mobile", lambda X: X["device_type"] == "mobile"),
    ("international", lambda X: X["country"] != "US"),
]

# Evaluate and assert minimums
for slice_name, slice_fn in slices:
    result = evaluate_by_slice(model, X_test, y_test, slice_fn, slice_name)
    if result["count"] > 100:  # Only check slices with enough data
        assert result["metrics"]["accuracy"] > 0.85, \
            f"Accuracy on {slice_name} below threshold: {result['metrics']['accuracy']}"
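The arithmetic behind "aggregate metrics hide problems" is short. Assuming, for illustration, that the weak segment is 5% of traffic at 60% accuracy, the remaining 95% of traffic only needs about 96.8% accuracy for the blended figure to read 95%:

```python
def implied_rest_accuracy(overall: float, slice_share: float, slice_acc: float) -> float:
    """Accuracy the non-slice traffic must have for the blend to equal `overall`.

    overall = slice_share * slice_acc + (1 - slice_share) * rest
    """
    return (overall - slice_share * slice_acc) / (1 - slice_share)

rest = implied_rest_accuracy(overall=0.95, slice_share=0.05, slice_acc=0.60)
print(f"{rest:.4f}")  # 0.9684
```

A small segment failing badly barely moves the aggregate, which is why the loop above asserts a floor per slice.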

Regression Testing

Keep a set of known examples with expected predictions. Test that the model still handles them correctly:

# regression_cases.json
[
    {
        "id": "obvious_fraud_1",
        "input": {"amount": 9999, "country": "XX", "velocity": 50},
        "expected_class": "fraud",
        "min_confidence": 0.9
    },
    {
        "id": "obvious_legitimate_1", 
        "input": {"amount": 25, "country": "US", "velocity": 1},
        "expected_class": "legitimate",
        "min_confidence": 0.8
    },
    {
        "id": "edge_case_high_amount_trusted_user",
        "input": {"amount": 5000, "country": "US", "velocity": 1, "account_age": 3650},
        "expected_class": "legitimate",
        "min_confidence": 0.6
    }
]

import json

def test_regression_cases(model, cases_file: str):
    # Assumes a model wrapper whose predict_proba(input_dict) returns
    # {class_label: probability}; adapt for array-based APIs such as scikit-learn.
    with open(cases_file) as f:
        cases = json.load(f)
    
    failures = []
    for case in cases:
        proba = model.predict_proba(case["input"])
        pred_class = max(proba, key=proba.get)  # most probable class
        expected = case["expected_class"]
        
        if pred_class != expected:
            failures.append(f"{case['id']}: expected {expected}, got {pred_class}")
        elif proba[expected] < case["min_confidence"]:
            failures.append(
                f"{case['id']}: confidence {proba[expected]:.2f} "
                f"below threshold {case['min_confidence']}"
            )
    
    assert not failures, "Regression failures:\n" + "\n".join(failures)

Add cases when you find bugs. The regression suite grows over time, encoding institutional knowledge about edge cases.

Fairness Checks

Test for unintended bias across protected attributes:

import numpy as np

def test_demographic_parity(model, X, sensitive_attr: str, threshold: float = 0.1):
    """Check that positive prediction rate is similar across groups.

    Assumes binary predictions encoded as 0/1 and X as a pandas DataFrame.
    """
    predictions = np.asarray(model.predict(X))
    
    groups = X[sensitive_attr].dropna().unique()
    positive_rates = {}
    
    for group in groups:
        mask = (X[sensitive_attr] == group).to_numpy()
        positive_rates[group] = predictions[mask].mean()
    
    max_rate = max(positive_rates.values())
    min_rate = min(positive_rates.values())
    disparity = max_rate - min_rate
    
    assert disparity < threshold, \
        f"Demographic disparity {disparity:.2%} exceeds threshold {threshold:.2%}. Rates: {positive_rates}"
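Stripped of the model, the metric is a difference of per-group positive rates. A hypothetical helper on plain arrays makes the arithmetic easy to check by hand; the toy predictions and groups are made up for illustration:

```python
import numpy as np

def demographic_parity_gap(predictions, groups) -> float:
    """Largest difference in positive-prediction rate between any two groups.

    Assumes binary predictions encoded as 0/1.
    """
    predictions = np.asarray(predictions, dtype=float)
    groups = np.asarray(groups)
    rates = [predictions[groups == g].mean() for g in np.unique(groups)]
    return float(max(rates) - min(rates))

preds = [1, 1, 1, 0, 1, 0, 0, 0]
groups = ["a", "a", "a", "a", "b", "b", "b", "b"]
print(demographic_parity_gap(preds, groups))  # 0.5: group a is 3/4 positive, group b 1/4
```

A gap of 0.5 would fail the default 0.1 threshold above.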

Layer 4: Integration Tests

Integration tests verify the full pipeline works end-to-end:

import time

# Assumes a FastAPI service; app, load_fixture, training_pipeline and
# load_model stand in for project-specific objects.
from fastapi.testclient import TestClient

def test_inference_pipeline_e2e():
    """Test complete inference pipeline with realistic request."""
    # Setup
    client = TestClient(app)
    sample_request = load_fixture("sample_inference_request.json")
    
    # Execute (in-process, so network time is excluded and one sample is noisy)
    start = time.monotonic()
    response = client.post("/predict", json=sample_request)
    latency_ms = (time.monotonic() - start) * 1000
    
    # Verify
    assert response.status_code == 200
    body = response.json()
    assert "prediction" in body
    assert "confidence" in body
    assert 0 <= body["confidence"] <= 1
    assert latency_ms < 100, f"Latency {latency_ms:.1f}ms exceeds 100ms SLA"

def test_training_pipeline_e2e(tmp_path):
    """Test complete training pipeline produces valid model."""
    # tmp_path is pytest's per-test temporary directory (avoids a shared /tmp path)
    output_path = tmp_path / "test_model"
    
    # Run training on small dataset
    result = training_pipeline.run(
        data_path="test_fixtures/small_training_set.parquet",
        output_path=str(output_path),
        epochs=2
    )
    
    # Verify outputs
    assert result.status == "success"
    assert (output_path / "model.pt").exists()
    assert (output_path / "metrics.json").exists()
    
    # Verify model loads and predicts
    model = load_model(str(output_path))
    test_input = load_fixture("sample_input.json")
    prediction = model.predict(test_input)
    assert prediction is not None
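A single timed request is noisy: one garbage-collection pause can fail the test, and one fast request can hide a slow tail. A hedged alternative is to time several calls and assert on a percentile. `measure_latency_ms` and `percentile` are illustrative helpers, not part of any testing library, and nearest-rank is one of several common percentile definitions:

```python
import math
import time
from typing import Callable, List

def measure_latency_ms(call: Callable[[], object], n: int = 50, warmup: int = 5) -> List[float]:
    """Time n calls of `call` in milliseconds, after a few untimed warm-up calls."""
    for _ in range(warmup):
        call()
    samples = []
    for _ in range(n):
        start = time.perf_counter()
        call()
        samples.append((time.perf_counter() - start) * 1000)
    return samples

def percentile(samples: List[float], pct: float) -> float:
    """Nearest-rank percentile: smallest sample with at least pct% of samples at or below it."""
    ordered = sorted(samples)
    rank = max(1, math.ceil(pct * len(ordered) / 100))
    return ordered[rank - 1]

# e.g. inside test_inference_pipeline_e2e:
# samples = measure_latency_ms(lambda: client.post("/predict", json=sample_request))
# assert percentile(samples, 95) < 100
```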

Layer 5: Production Testing

The final layer tests in production, with real traffic.

Shadow Deployment

Run the new model alongside production and compare outputs, but don't serve its results:

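import hashlib
import logging
import random
from collections import deque

logger = logging.getLogger(__name__)

def hash_input(input_data) -> str:
    # Illustrative: hashes repr(); a canonical serialization (e.g. sorted JSON) is more robust
    return hashlib.sha256(repr(input_data).encode("utf-8")).hexdigest()

class ShadowDeployment:
    def __init__(self, prod_model, shadow_model, sample_rate: float = 0.1,
                 max_comparisons: int = 10_000):
        self.prod = prod_model
        self.shadow = shadow_model
        self.sample_rate = sample_rate
        # Bounded buffer: the oldest comparisons are dropped instead of growing forever
        self.comparisons = deque(maxlen=max_comparisons)
    
    def predict(self, input_data):
        # Always get production prediction
        prod_result = self.prod.predict(input_data)
        
        # Conditionally run shadow (synchronously here, which adds latency to
        # sampled requests; an async queue avoids that). A shadow failure must
        # never affect the response.
        if random.random() < self.sample_rate:
            try:
                shadow_result = self.shadow.predict(input_data)
            except Exception:
                logger.exception("Shadow model failed")
            else:
                self.comparisons.append({
                    "input_hash": hash_input(input_data),
                    "prod": prod_result,
                    "shadow": shadow_result,
                    "match": prod_result == shadow_result
                })
        
        # Always return production result
        return prod_result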
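The comparisons only pay off once something reads them. A minimal, hypothetical summary over the recorded dicts (same keys as above) reports the agreement rate and collects disagreeing input hashes for review. A low agreement rate is a prompt to investigate rather than an automatic failure, since the new model may disagree precisely because it is better:

```python
from typing import Dict, Iterable

def summarize_shadow(comparisons: Iterable[Dict], max_examples: int = 20) -> Dict:
    """Agreement rate between production and shadow, plus sample disagreements."""
    records = list(comparisons)
    if not records:
        return {"count": 0, "agreement_rate": None, "disagreements": []}
    matches = sum(1 for r in records if r["match"])
    disagreements = [r["input_hash"] for r in records if not r["match"]][:max_examples]
    return {
        "count": len(records),
        "agreement_rate": matches / len(records),
        "disagreements": disagreements,
    }

sample = [
    {"input_hash": "a1", "prod": "legitimate", "shadow": "legitimate", "match": True},
    {"input_hash": "b2", "prod": "legitimate", "shadow": "fraud", "match": False},
    {"input_hash": "c3", "prod": "fraud", "shadow": "fraud", "match": True},
    {"input_hash": "d4", "prod": "legitimate", "shadow": "legitimate", "match": True},
]
print(summarize_shadow(sample))
# {'count': 4, 'agreement_rate': 0.75, 'disagreements': ['b2']}
```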

Canary Analysis

Route a small percentage of traffic to the new model and monitor for degradation:

from typing import Dict

def analyze_canary(control_metrics: Dict, canary_metrics: Dict, threshold: float = 0.05,
                   max_latency_increase: float = 0.2):
    """Compare canary against control, return go/no-go decision.

    threshold: allowed absolute error-rate increase (0.05 = 5 percentage points).
    max_latency_increase: allowed relative p99 increase (0.2 = 20%).
    """
    
    checks = []
    
    # Error rate comparison (absolute difference)
    error_increase = canary_metrics["error_rate"] - control_metrics["error_rate"]
    checks.append({
        "metric": "error_rate",
        "control": control_metrics["error_rate"],
        "canary": canary_metrics["error_rate"],
        "pass": error_increase < threshold
    })
    
    # Latency comparison (relative to control)
    latency_increase = (canary_metrics["p99_latency"] - control_metrics["p99_latency"]) / control_metrics["p99_latency"]
    checks.append({
        "metric": "p99_latency",
        "control": control_metrics["p99_latency"],
        "canary": canary_metrics["p99_latency"],
        "pass": latency_increase < max_latency_increase
    })
    
    all_passed = all(c["pass"] for c in checks)
    return {"decision": "proceed" if all_passed else "rollback", "checks": checks}
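Comparing raw error rates works when the canary sees plenty of traffic, but with only a small share its error rate is noisy. One standard way to ask whether an increase is real is a one-sided two-proportion z-test. This is a hedged addition rather than part of the function above; the function name is illustrative, and the normal approximation needs reasonably large counts (tens of errors, not two or three):

```python
import math

def error_rate_increase_pvalue(control_errors: int, control_total: int,
                               canary_errors: int, canary_total: int) -> float:
    """One-sided p-value for H1: canary error rate > control error rate.

    Two-proportion z-test with a pooled variance estimate (normal approximation).
    """
    p_control = control_errors / control_total
    p_canary = canary_errors / canary_total
    pooled = (control_errors + canary_errors) / (control_total + canary_total)
    se = math.sqrt(pooled * (1 - pooled) * (1 / control_total + 1 / canary_total))
    if se == 0:
        return 1.0  # no errors (or only errors) in both groups: no evidence of an increase
    z = (p_canary - p_control) / se
    # Upper-tail probability of the standard normal: P(Z > z) = erfc(z / sqrt(2)) / 2
    return 0.5 * math.erfc(z / math.sqrt(2))

# Canary at 3.0% errors on 1,000 requests vs control at 0.5% on 10,000
p = error_rate_increase_pvalue(50, 10_000, 30, 1_000)
print(p < 0.01)  # True: an increase this large is very unlikely to be noise
```

A gate could combine this with the absolute `threshold` above, rolling back only when an increase is both practically large and statistically convincing, or on either condition if false alarms are cheap.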

The Testing Mindset

ML testing isn’t a phase — it’s continuous. Data changes. Models drift. Production reveals failure modes that tests missed.

Build testing into the pipeline:

  • Data validation on every ingestion
  • Unit tests on every commit
  • Model validation on every training run
  • Integration tests on every deployment
  • Production testing on every release

The goal isn’t perfect coverage — it’s catching problems before users do. Every production incident should add a test. Over time, your test suite encodes everything you’ve learned about how your system fails.


For more on ML system reliability, see The Observability Blind Spot for monitoring production systems, and Graceful Degradation in ML Systems for handling failures gracefully. For deterministic testing approaches, explore certifiable-harness.

About the Author

William Murray is a Regenerative Systems Architect with 30 years of UNIX infrastructure experience, specializing in deterministic computing for safety-critical systems. Based in the Scottish Highlands, he operates SpeyTech and maintains several open-source projects including C-Sentinel and c-from-scratch.
