design pattern · 2025-04-11 · 12 min read

ML Model Monitoring: Detecting Drift Before It Becomes a Problem

Build a production ML monitoring system. Learn the difference between data drift and concept drift, how to detect them statistically, and how to set up automated alerting and retraining triggers.

model monitoring · data drift · concept drift · MLOps · production ML

Models Degrade Silently

Software services fail loudly: 500 errors, latency spikes, crash reports. ML models fail quietly: predictions get gradually worse while the service appears healthy. By the time users complain, the model may have been degraded for weeks.

The root cause: the real world changes, but your model was trained on the past.

Types of Model Degradation

Data Drift (Covariate Shift)

The distribution of input features changes. The relationship between features and labels stays the same, but the model sees inputs that differ from the data it was trained on.

Example: A fraud model trained on desktop web transactions starts receiving mobile transactions. Feature distributions (session length, click patterns) shift significantly.

# Detect data drift with Kolmogorov-Smirnov test
from scipy import stats
import numpy as np

def ks_drift_test(reference: np.ndarray, current: np.ndarray, threshold: float = 0.05) -> dict:
    """
    Test if two distributions are significantly different.
    p_value < threshold → drift detected.
    """
    stat, p_value = stats.ks_2samp(reference, current)
    return {
        "statistic": stat,
        "p_value": p_value,
        "drift_detected": p_value < threshold,
        "reference_mean": reference.mean(),
        "current_mean": current.mean(),
    }

# Run on each feature. Assumes training_data and last_week_predictions_df
# are DataFrames and feature_columns lists the monitored features.
for feature in feature_columns:
    result = ks_drift_test(
        reference=training_data[feature].values,
        current=last_week_predictions_df[feature].values,
    )
    if result["drift_detected"]:
        print(f"DRIFT: {feature} — KS={result['statistic']:.3f}, p={result['p_value']:.4f}")
        print(f"  Reference mean: {result['reference_mean']:.3f}")
        print(f"  Current mean:   {result['current_mean']:.3f}")
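
One caveat the loop above glosses over: running one KS test per feature multiplies false alarms. With 50 features at a 0.05 threshold, you should expect a couple of spurious drift flags on every run even with no real drift. A standard mitigation (an addition beyond the article's loop, shown here with synthetic data and hypothetical feature names) is a Bonferroni correction, dividing the threshold by the number of tests:

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
feature_columns = ["session_length", "click_rate", "pages_viewed"]  # hypothetical

# Synthetic reference and current windows; only click_rate actually drifts.
reference = {f: rng.normal(0.0, 1.0, 5_000) for f in feature_columns}
current = {f: rng.normal(0.0, 1.0, 5_000) for f in feature_columns}
current["click_rate"] = rng.normal(0.5, 1.0, 5_000)

# Bonferroni correction: divide the family-wise threshold by the number of
# tests so the chance of any false alarm across all features stays near 5%.
alpha = 0.05 / len(feature_columns)

drifted = [
    f for f in feature_columns
    if stats.ks_2samp(reference[f], current[f]).pvalue < alpha
]
print(drifted)  # click_rate is expected to be flagged
```

The correction is conservative; with hundreds of correlated features, a false-discovery-rate procedure is a gentler alternative.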

Concept Drift

The relationship between features and labels changes. Feature distributions may be stable, but the model's predictions are increasingly wrong.

Example: A churn model trained in 2022 — when users churned because of pricing — starts failing in 2024 when users churn because of content quality. Same features, different signal.

Concept drift is harder to detect without labels. Options:

  1. Use delayed labels: Wait for ground truth, compare to model predictions
  2. Proxy metrics: If model predicts click probability, monitor actual click rates
  3. Output distribution monitoring: If the model's prediction distribution shifts, investigate

import numpy as np
from scipy.stats import chi2_contingency

def prediction_distribution_drift(
    reference_predictions: np.ndarray,
    current_predictions: np.ndarray,
    n_bins: int = 10,
    threshold: float = 0.05,
) -> dict:
    """
    Detect a shift in the model's output distribution with a Chi-squared
    test over binned predictions. Bins span the combined observed range,
    so this works for classification probabilities and regression outputs.
    """
    lo = min(reference_predictions.min(), current_predictions.min())
    hi = max(reference_predictions.max(), current_predictions.max())
    bins = np.linspace(lo, hi, n_bins + 1)

    ref_counts = np.histogram(reference_predictions, bins=bins)[0] + 1  # +1 smoothing
    cur_counts = np.histogram(current_predictions, bins=bins)[0] + 1

    chi2, p_value, _, _ = chi2_contingency(
        np.array([ref_counts, cur_counts])
    )

    return {
        "chi2": chi2,
        "p_value": p_value,
        "drift_detected": p_value < threshold,
        "reference_mean_pred": reference_predictions.mean(),
        "current_mean_pred": current_predictions.mean(),
    }
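
Option 2, proxy metrics, can also be made concrete. A rough sketch (not from the original article): if the model outputs click probabilities, test whether the observed click count is plausible given the mean predicted probability. Treating the mean prediction as a single binomial rate is an approximation when per-user probabilities vary, but it catches large calibration breaks without waiting for labels:

```python
import numpy as np
from scipy.stats import binomtest

def proxy_metric_check(predicted_probs: np.ndarray, actual_clicks: np.ndarray,
                       threshold: float = 0.05) -> dict:
    """Compare the mean predicted click probability to the observed click rate."""
    n = len(actual_clicks)
    k = int(np.sum(actual_clicks))  # observed positive events
    expected_rate = float(np.mean(predicted_probs))
    result = binomtest(k, n, p=expected_rate)
    return {
        "expected_rate": expected_rate,
        "observed_rate": k / n,
        "p_value": result.pvalue,
        "drift_detected": result.pvalue < threshold,
    }

# Synthetic example: the model believes ~20% click rate, users click ~35%.
rng = np.random.default_rng(1)
preds = rng.uniform(0.1, 0.3, 2_000)
clicks = rng.random(2_000) < 0.35
print(proxy_metric_check(preds, clicks))
```

A significant mismatch here is a strong hint of concept drift (or a calibration problem) long before delayed labels arrive.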

A Production Monitoring System

import pandas as pd
import numpy as np
from scipy import stats
from dataclasses import dataclass
from typing import Optional
from datetime import datetime

@dataclass
class DriftAlert:
    feature: str
    drift_type: str  # "data", "prediction", "performance"
    severity: str    # "warning", "critical"
    statistic: float
    p_value: float
    reference_value: float
    current_value: float
    timestamp: datetime

class ModelMonitor:
    def __init__(
        self,
        model_name: str,
        reference_data: pd.DataFrame,
        feature_cols: list[str],
        prediction_col: str = "prediction",
        label_col: Optional[str] = None,
    ):
        self.model_name = model_name
        self.reference_data = reference_data
        self.feature_cols = feature_cols
        self.prediction_col = prediction_col
        self.label_col = label_col

        # Pre-compute reference statistics
        self.reference_stats = {
            col: {
                "values": reference_data[col].dropna().values,
                "mean": reference_data[col].mean(),
                "std": reference_data[col].std(),
                "p5": reference_data[col].quantile(0.05),
                "p95": reference_data[col].quantile(0.95),
            }
            for col in feature_cols
        }

    def check_data_drift(
        self,
        current_data: pd.DataFrame,
        p_value_threshold: float = 0.05,
    ) -> list[DriftAlert]:
        alerts = []

        for col in self.feature_cols:
            if col not in current_data.columns:
                continue

            current_values = current_data[col].dropna().values
            ref_values = self.reference_stats[col]["values"]

            stat, p_value = stats.ks_2samp(ref_values, current_values)

            if p_value < p_value_threshold:
                severity = "critical" if p_value < 0.001 else "warning"
                alerts.append(DriftAlert(
                    feature=col,
                    drift_type="data",
                    severity=severity,
                    statistic=stat,
                    p_value=p_value,
                    reference_value=self.reference_stats[col]["mean"],
                    current_value=current_values.mean(),
                    timestamp=datetime.now(),
                ))

        return alerts

    def check_prediction_drift(
        self,
        current_predictions: np.ndarray,
        p_value_threshold: float = 0.05,
    ) -> Optional[DriftAlert]:
        ref_preds = self.reference_data[self.prediction_col].values
        stat, p_value = stats.ks_2samp(ref_preds, current_predictions)

        if p_value < p_value_threshold:
            return DriftAlert(
                feature="predictions",
                drift_type="prediction",
                severity="critical" if p_value < 0.001 else "warning",
                statistic=stat,
                p_value=p_value,
                reference_value=ref_preds.mean(),
                current_value=current_predictions.mean(),
                timestamp=datetime.now(),
            )
        return None

    def check_performance_drift(
        self,
        current_data: pd.DataFrame,
        performance_threshold: float = 0.05,
    ) -> Optional[DriftAlert]:
        """Requires labels — use when delayed labels become available."""
        if self.label_col is None or self.label_col not in current_data.columns:
            return None

        from sklearn.metrics import roc_auc_score
        current_auc = roc_auc_score(
            current_data[self.label_col],
            current_data[self.prediction_col],
        )
        reference_auc = roc_auc_score(
            self.reference_data[self.label_col],
            self.reference_data[self.prediction_col],
        )

        performance_drop = reference_auc - current_auc
        if performance_drop > performance_threshold:
            return DriftAlert(
                feature="model_performance",
                drift_type="performance",
                severity="critical",
                statistic=performance_drop,
                p_value=0.0,
                reference_value=reference_auc,
                current_value=current_auc,
                timestamp=datetime.now(),
            )
        return None

    def run_full_check(self, current_data: pd.DataFrame) -> list[DriftAlert]:
        alerts = []
        alerts.extend(self.check_data_drift(current_data))

        if self.prediction_col in current_data.columns:
            pred_alert = self.check_prediction_drift(current_data[self.prediction_col].values)
            if pred_alert:
                alerts.append(pred_alert)

        perf_alert = self.check_performance_drift(current_data)
        if perf_alert:
            alerts.append(perf_alert)

        return alerts

Alerting and Retraining Triggers

class AlertingSystem:
    def process_alerts(self, alerts: list[DriftAlert], monitor: ModelMonitor):
        if not alerts:
            return

        # Categorize
        critical = [a for a in alerts if a.severity == "critical"]
        warnings = [a for a in alerts if a.severity == "warning"]

        # Log all
        for alert in alerts:
            self.log_alert(alert)

        # Page on-call for critical alerts
        if critical:
            self.send_pagerduty_alert(
                title=f"[CRITICAL] Model drift detected: {monitor.model_name}",
                body=self.format_alert_body(critical),
            )

        # Trigger retraining if multiple features drift or performance drops
        if len(warnings) >= 3 or any(a.drift_type == "performance" for a in critical):
            self.trigger_retraining_pipeline(monitor.model_name)

    def log_alert(self, alert: DriftAlert):
        """Stub: wire this to your logging/metrics backend."""
        print(f"[{alert.severity.upper()}] {alert.drift_type} drift: "
              f"{alert.feature} (p={alert.p_value:.4f})")

    def format_alert_body(self, alerts: list[DriftAlert]) -> str:
        return "\n".join(
            f"{a.feature}: {a.reference_value:.3f} -> {a.current_value:.3f} "
            f"(p={a.p_value:.4f})"
            for a in alerts
        )

    def send_pagerduty_alert(self, title: str, body: str):
        """Stub: call the PagerDuty Events API (or your pager of choice) here."""
        print(f"PAGE: {title}\n{body}")

    def trigger_retraining_pipeline(self, model_name: str):
        """Kick off automated retraining."""
        import subprocess
        subprocess.run([
            "python", "src/train.py",
            "--model", model_name,
            "--trigger", "drift_detected",
            "--date", datetime.now().strftime("%Y-%m-%d"),
        ], check=True)  # raise if the training job fails
        print(f"Retraining triggered for {model_name}")

Running Monitoring on a Schedule

# monitoring/daily_check.py — run via cron or Airflow
# (load_training_data, load_production_predictions, and FEATURE_COLUMNS are
# project-specific helpers assumed to be importable here)
from datetime import datetime, timedelta

monitor = ModelMonitor(
    model_name="churn-v3",
    reference_data=load_training_data(),
    feature_cols=FEATURE_COLUMNS,
    prediction_col="churn_probability",
    label_col="churned",
)

alerting = AlertingSystem()

# Load yesterday's predictions + any newly available labels
yesterday = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d")
current_data = load_production_predictions(date=yesterday)

alerts = monitor.run_full_check(current_data)
alerting.process_alerts(alerts, monitor)

print(f"Monitoring check complete: {len(alerts)} alerts generated")

The Monitoring Dashboard Checklist

Every production ML model should expose:

  • Prediction volume (requests/minute) — detect serving failures
  • Prediction distribution (mean, percentiles) — detect output drift
  • Feature distribution for top 10 features — detect input drift
  • Model latency (p50, p99) — detect performance degradation
  • Error rate (null predictions, timeouts) — detect pipeline issues
  • Model performance metrics (with label delay) — detect accuracy degradation
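
Most of these checklist numbers fall out of a single pass over one monitoring window. A minimal sketch (the function name and synthetic data are illustrative, not from the article) that computes the volume, distribution, latency, and error-rate entries:

```python
import numpy as np

def metrics_snapshot(predictions: np.ndarray, latencies_ms: np.ndarray) -> dict:
    """Summarize one monitoring window into dashboard-ready numbers."""
    valid = predictions[~np.isnan(predictions)]  # drop failed predictions
    return {
        "prediction_volume": int(predictions.size),
        "prediction_mean": float(valid.mean()),
        "prediction_p5": float(np.percentile(valid, 5)),
        "prediction_p95": float(np.percentile(valid, 95)),
        "latency_p50_ms": float(np.percentile(latencies_ms, 50)),
        "latency_p99_ms": float(np.percentile(latencies_ms, 99)),
        "null_prediction_rate": float(np.isnan(predictions).mean()),
    }

# Synthetic window: 1,000 requests, 10 of which returned no prediction.
rng = np.random.default_rng(2)
preds = rng.uniform(size=1_000)
preds[:10] = np.nan
lat = rng.lognormal(mean=3.0, sigma=0.5, size=1_000)
snap = metrics_snapshot(preds, lat)
print(snap)
```

Emitting a snapshot like this per window gives the time series that feature-level drift detection and alerting sit on top of; only the label-delayed performance metrics need a separate pipeline.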

Build out the full production ML pipeline with our guides to MLOps and CI/CD and to feature stores.
