Testing AI Systems: Beyond Unit Tests

Unit tests verify deterministic behavior. AI systems are probabilistic. Traditional assertions fail when correct outputs vary: “Generate a product description” has infinitely many valid responses.

Testing AI requires different approaches. Behavioral verification over exact matching. Property-based testing over example-based. Visual validation for UI outputs. Integration testing across the entire pipeline.

Testing Model Outputs

Problem: Non-Deterministic Results

# This test will fail randomly
def test_generate_description():
    description = model.generate("laptop")
    assert description == "A portable computer for work and entertainment"
    # Fails when the model returns an equally valid "Lightweight computing device for productivity"

Solution: Property Testing

def test_generate_description():
    description = model.generate("laptop")

    # Verify properties, not exact output
    assert len(description) > 20  # Substantial content
    assert len(description) < 200  # Not too verbose
    assert "laptop" in description.lower()  # On topic
    assert not any(word in description.lower() for word in ["phone", "tablet"])  # Category correct
    assert description[0].isupper()  # Properly capitalized
    assert description[-1] in ".!?"  # Complete sentence

Properties define correctness without specifying exact output.
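
The same property checks scale to generated inputs. A minimal property-based sketch with the hypothesis library, assuming model.generate accepts arbitrary product names (the PRODUCTS list is illustrative):

from hypothesis import given, settings, strategies as st

# Hypothetical set of product names to sample from
PRODUCTS = ["laptop", "monitor", "keyboard", "webcam"]

@given(product=st.sampled_from(PRODUCTS))
@settings(max_examples=20, deadline=None)  # cap model calls, skip per-example timeouts
def test_description_properties(product):
    description = model.generate(product)

    # Same property checks, applied to every sampled product
    assert 20 < len(description) < 200
    assert product in description.lower()
    assert description[0].isupper()
    assert description[-1] in ".!?"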

Testing Classification Models

Confusion Matrix Testing:

def test_sentiment_classifier():
    test_cases = [
        ("I love this product!", "positive"),
        ("This is terrible", "negative"),
        ("It's okay I guess", "neutral"),
    ]

    predictions = [model.classify(text) for text, _ in test_cases]
    expected = [label for _, label in test_cases]

    # Calculate accuracy
    accuracy = sum(p == e for p, e in zip(predictions, expected)) / len(test_cases)
    assert accuracy >= 0.8  # 80% minimum

    # Check no systematic bias
    false_positives = sum(1 for p, e in zip(predictions, expected)
                         if p == "positive" and e != "positive")
    assert false_positives <= 1  # At most 1 false positive
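
For more than a handful of cases, sklearn's confusion_matrix makes the per-class error pattern explicit. A sketch, where load_labeled_sentiment_cases is an assumed helper returning (text, label) pairs that cover every class:

from sklearn.metrics import confusion_matrix

def test_sentiment_confusion_matrix():
    labels = ["positive", "negative", "neutral"]
    texts, expected = zip(*load_labeled_sentiment_cases())  # assumed helper
    predictions = [model.classify(text) for text in texts]

    cm = confusion_matrix(expected, predictions, labels=labels)

    # Diagonal = correct predictions; off-diagonal = specific confusion pairs
    for i, label in enumerate(labels):
        per_class_accuracy = cm[i, i] / cm[i].sum()
        assert per_class_accuracy >= 0.7, f"{label} accuracy dropped to {per_class_accuracy:.0%}"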

Testing with Multiple Examples

@pytest.mark.parametrize("input_text,expected_category", [
    ("Buy now for 50% off!", "spam"),
    ("Your package will arrive tomorrow", "notification"),
    ("Meeting at 3pm", "calendar"),
    ("Invoice #12345 is overdue", "billing"),
])
def test_email_categorization(input_text, expected_category):
    result = model.categorize(input_text)
    assert result == expected_category

Parametrized tests cover multiple scenarios without duplication.

Testing Extraction Tasks

def test_extract_dates():
    text = "Meeting on Jan 15, 2025 and follow-up on January 20th"
    dates = model.extract_dates(text)

    # Verify structure
    assert isinstance(dates, list)
    assert len(dates) == 2

    # Verify content
    assert all(isinstance(d, datetime) for d in dates)
    assert dates[0].year == 2025
    assert dates[0].month == 1
    assert dates[0].day == 15

Structural verification first, then content verification.

Integration Testing ML Pipelines

def test_full_pipeline():
    """Test data flow from ingestion to inference"""

    # Stage 1: Data ingestion
    raw_data = load_test_data("sample.json")
    ingested = pipeline.ingest(raw_data)
    assert len(ingested) == 100

    # Stage 2: Preprocessing
    preprocessed = pipeline.preprocess(ingested)
    assert all(item.has_required_fields() for item in preprocessed)

    # Stage 3: Feature extraction
    features = pipeline.extract_features(preprocessed)
    assert features.shape == (100, 512)  # Expected dimensions

    # Stage 4: Model inference
    predictions = pipeline.predict(features)
    assert len(predictions) == 100
    assert all(0 <= p <= 1 for p in predictions)  # Valid probabilities

    # End-to-end latency
    assert pipeline.get_latency() < 1.0  # Sub-second processing

Each stage is verified independently, then the full pipeline end-to-end.
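
To get that stage-level isolation in pytest itself, module-scoped fixtures let a later stage reuse an earlier stage's output without re-running it, and a failure points at the exact stage. A sketch under the same assumed pipeline API:

import pytest

@pytest.fixture(scope="module")
def ingested():
    return pipeline.ingest(load_test_data("sample.json"))

@pytest.fixture(scope="module")
def preprocessed(ingested):
    return pipeline.preprocess(ingested)

def test_ingestion_count(ingested):
    assert len(ingested) == 100

def test_preprocessing_fields(preprocessed):
    assert all(item.has_required_fields() for item in preprocessed)

def test_feature_dimensions(preprocessed):
    assert pipeline.extract_features(preprocessed).shape == (100, 512)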

Visual Testing with Playwright MCP

Traditional UI testing checks DOM structure:

def test_button_exists():
    button = page.locator("button#submit")
    assert button.is_visible()
    assert button.text_content() == "Submit"

Visual testing with Playwright MCP checks rendered output:

def test_button_visual():
    # AI can see the actual rendered button
    screenshot = page.screenshot()

    # Verify visual properties
    assert_visual_match(screenshot, "expected_button.png", threshold=0.95)

    # Or describe what should be visible
    description = ai.describe_screenshot(screenshot)
    assert "blue submit button" in description.lower()
    assert "center of screen" in description.lower()

AI sees what users see and catches the visual regressions that text assertions miss.
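
assert_visual_match above is not a Playwright built-in. One possible implementation is a plain pixel diff with Pillow, sketched here; the 10-point per-channel tolerance is an arbitrary choice:

from io import BytesIO
from PIL import Image, ImageChops

def assert_visual_match(screenshot_bytes, baseline_path, threshold=0.95):
    """Possible pixel-diff implementation of the helper used above."""
    actual = Image.open(BytesIO(screenshot_bytes)).convert("RGB")
    baseline = Image.open(baseline_path).convert("RGB")
    assert actual.size == baseline.size, "Screenshot dimensions changed"

    # Count pixels that are (nearly) identical between actual and baseline
    diff = ImageChops.difference(actual, baseline)
    matching = sum(1 for px in diff.getdata() if max(px) <= 10)
    ratio = matching / (actual.size[0] * actual.size[1])
    assert ratio >= threshold, f"Only {ratio:.1%} of pixels match the baseline"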

Performance Testing

def test_inference_latency():
    """Model must respond within SLA"""

    latencies = []
    for _ in range(100):
        start = time.perf_counter()
        model.predict(test_input)
        latencies.append(time.perf_counter() - start)

    p50 = np.percentile(latencies, 50)
    p95 = np.percentile(latencies, 95)
    p99 = np.percentile(latencies, 99)

    assert p50 < 0.100  # 100ms median
    assert p95 < 0.200  # 200ms p95
    assert p99 < 0.500  # 500ms p99

Measure percentiles, not averages. Tail latency matters.
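
A quick illustration of how an average hides the tail:

import numpy as np

# 95 requests at 50ms plus 5 requests at 2s
latencies = [0.050] * 95 + [2.0] * 5

print(np.mean(latencies))            # ~0.15s -- looks acceptable
print(np.percentile(latencies, 99))  # 2.0s -- what the slowest users actually experience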

Data Quality Testing

def test_training_data_quality():
    """Verify training data meets quality standards"""

    dataset = load_training_data()

    # No missing values in required fields
    assert dataset['label'].isna().sum() == 0

    # Roughly balanced classes: each label within 10 percentage points of a 50/50 split
    class_distribution = dataset['label'].value_counts(normalize=True)
    assert all(0.4 <= ratio <= 0.6 for ratio in class_distribution)

    # No duplicate records
    assert len(dataset) == len(dataset.drop_duplicates())

    # Valid value ranges
    assert dataset['age'].between(0, 120).all()
    assert dataset['score'].between(0, 1).all()

Catch data quality issues before they affect model training.

Model Drift Detection

def test_no_model_drift():
    """Production model performance hasn't degraded"""

    # Load reference dataset
    reference_data = load_reference_data()

    # Get current model predictions
    predictions = model.predict(reference_data.features)

    # Compare to baseline performance
    baseline_accuracy = 0.92
    current_accuracy = accuracy_score(reference_data.labels, predictions)

    # Alert if accuracy drops more than 5%
    assert current_accuracy >= baseline_accuracy * 0.95

Run daily against a reference dataset. Catch drift before users notice.
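
Fresh labels often lag production traffic. When they do, one hedged alternative is to compare the distribution of current prediction scores against a snapshot captured at deployment; load_baseline_scores and load_recent_inputs below are assumed helpers:

from scipy.stats import ks_2samp

def test_no_prediction_drift():
    """Compare score distributions when labels are not yet available."""
    baseline_scores = load_baseline_scores()               # captured at deploy time
    current_scores = model.predict(load_recent_inputs())   # recent production inputs

    # Two-sample Kolmogorov-Smirnov test on the two score distributions
    statistic, p_value = ks_2samp(baseline_scores, current_scores)
    assert p_value > 0.01, f"Prediction distribution shifted (KS statistic {statistic:.3f})"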

Testing Embedding Quality

def test_embedding_similarity():
    """Semantically similar text should have similar embeddings"""

    text1 = "The weather is beautiful today"
    text2 = "It's a gorgeous day outside"
    text3 = "Database connection failed"

    emb1 = model.embed(text1)
    emb2 = model.embed(text2)
    emb3 = model.embed(text3)

    # Similar sentences should be close
    similarity_12 = cosine_similarity(emb1, emb2)
    assert similarity_12 > 0.7

    # Dissimilar sentences should be distant
    similarity_13 = cosine_similarity(emb1, emb3)
    assert similarity_13 < 0.3
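
The cosine_similarity helper is left undefined above; a minimal numpy version, assuming the embeddings come back as 1-D vectors:

import numpy as np

def cosine_similarity(a, b):
    """Cosine of the angle between two 1-D embedding vectors."""
    a, b = np.asarray(a), np.asarray(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))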

Testing RAG Systems

def test_rag_retrieval():
    """RAG should retrieve relevant context"""

    query = "How do I reset my password?"

    # Retrieve relevant documents
    docs = rag.retrieve(query, k=3)

    # Verify retrieval quality
    assert len(docs) == 3
    assert all("password" in doc.content.lower() for doc in docs)
    assert any("reset" in doc.content.lower() for doc in docs)

    # Generate answer with context
    answer = rag.generate(query, docs)

    # Verify answer quality
    assert len(answer) > 50
    assert "password" in answer.lower()
    assert any(keyword in answer.lower() for keyword in ["reset", "change", "recover"])
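
Keyword checks only go so far. If a handful of queries are labeled with the document that should come back (the labels and the doc.id field below are assumptions), recall@k makes a sturdier retrieval metric:

def test_retrieval_recall_at_k():
    """Fraction of labeled queries whose known-relevant document appears in the top k."""
    k = 3
    labeled_queries = [
        # (query, id of the document that should be retrieved) -- assumed labels
        ("How do I reset my password?", "doc_password_reset"),
        ("What is your refund policy?", "doc_refunds"),
    ]

    hits = 0
    for query, relevant_id in labeled_queries:
        retrieved_ids = [doc.id for doc in rag.retrieve(query, k=k)]
        hits += relevant_id in retrieved_ids

    assert hits / len(labeled_queries) >= 0.9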

Testing with Golden Datasets

def test_against_golden_dataset():
    """Model maintains performance on curated test set"""

    golden_data = load_golden_dataset()

    results = []
    for item in golden_data:
        prediction = model.predict(item.input)
        results.append({
            'input': item.input,
            'expected': item.expected_output,
            'predicted': prediction,
            'match': prediction == item.expected_output
        })

    # Overall accuracy
    accuracy = sum(r['match'] for r in results) / len(results)
    assert accuracy >= 0.90

    # Log failures for analysis
    failures = [r for r in results if not r['match']]
    if failures:
        log_failures(failures)

Golden datasets capture edge cases discovered over time.
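
One lightweight way to let the golden set grow is a JSONL file in the repo, with one record added per incident. A sketch of the load_golden_dataset helper under that assumed layout:

import json
from collections import namedtuple

GoldenItem = namedtuple("GoldenItem", ["input", "expected_output"])

def load_golden_dataset(path="tests/golden.jsonl"):
    """Each line holds {"input": ..., "expected_output": ...} -- assumed file layout."""
    with open(path) as f:
        return [GoldenItem(**json.loads(line)) for line in f if line.strip()]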

Testing API Endpoints

def test_inference_endpoint():
    """API contract remains stable"""

    response = client.post("/api/classify", json={
        "text": "This is a test message"
    })

    # Status code
    assert response.status_code == 200

    # Response structure
    data = response.json()
    assert "prediction" in data
    assert "confidence" in data
    assert "latency_ms" in data

    # Value types
    assert isinstance(data["prediction"], str)
    assert isinstance(data["confidence"], float)
    assert 0 <= data["confidence"] <= 1

    # Performance
    assert data["latency_ms"] < 500
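
The field-by-field checks can also be centralized in a response schema that every contract test reuses; a sketch with pydantic, where the ClassifyResponse model is an assumption rather than part of the API:

from pydantic import BaseModel, Field

class ClassifyResponse(BaseModel):
    prediction: str
    confidence: float = Field(ge=0, le=1)
    latency_ms: float = Field(lt=500)

def test_inference_endpoint_schema():
    response = client.post("/api/classify", json={"text": "This is a test message"})
    assert response.status_code == 200
    ClassifyResponse(**response.json())  # raises ValidationError if the contract breaks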

Regression Testing

def test_historical_predictions():
    """Known inputs should maintain stable outputs"""

    test_cases = [
        # (input, expected_output, reasoning)
        ("sunny day", "positive", "historically consistent"),
        ("system failure", "negative", "historically consistent"),
    ]

    for input_text, expected, reason in test_cases:
        result = model.predict(input_text)
        assert result == expected, f"Regression: {reason}"

Lock in correct behavior once discovered.

Continuous Testing in Production

# Production monitoring test
def test_production_health():
    """Production model meets SLAs"""

    metrics = get_production_metrics(last_24h=True)

    # Latency
    assert metrics.p99_latency_ms < 1000

    # Error rate
    assert metrics.error_rate < 0.01  # Less than 1%

    # Throughput
    assert metrics.requests_per_second > 10

    # Resource usage
    assert metrics.memory_usage_percent < 80
    assert metrics.cpu_usage_percent < 70

Run in CI against production endpoints. Catch degradation early.