Testing AI Systems: Beyond Unit Tests
Unit tests verify deterministic behavior. AI systems are probabilistic. Traditional assertions fail when correct outputs vary. “Generate a product description” has infinite valid responses.
Testing AI requires different approaches. Behavioral verification over exact matching. Property-based testing over example-based. Visual validation for UI outputs. Integration testing across the entire pipeline.
Testing Model Outputs
Problem: Non-Deterministic Results
# This test will fail randomly
def test_generate_description():
    description = model.generate("laptop")
    assert description == "A portable computer for work and entertainment"
    # Fails when the model returns the equally valid "Lightweight computing device for productivity"
Solution: Property Testing
def test_generate_description():
    description = model.generate("laptop")
    # Verify properties, not exact output
    assert len(description) > 20  # Substantial content
    assert len(description) < 200  # Not too verbose
    assert "laptop" in description.lower()  # On topic
    assert not any(word in description.lower() for word in ["phone", "tablet"])  # Category correct
    assert description[0].isupper()  # Properly capitalized
    assert description[-1] in ".!?"  # Complete sentence
Properties define correctness without specifying exact output.
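The intro calls for property-based testing over example-based testing; a minimal sketch of that idea using the Hypothesis library (an assumption, not shown in the original) generates many inputs and checks the same properties on each:
from hypothesis import given, settings, strategies as st

# Sketch only: `model` is the same hypothetical generator used above.
@settings(max_examples=20, deadline=None)
@given(product=st.sampled_from(["laptop", "headphones", "monitor", "keyboard"]))
def test_description_properties(product):
    description = model.generate(product)
    assert 20 < len(description) < 200     # substantial but not verbose
    assert product in description.lower()  # on topic for the generated input
    assert description[0].isupper()        # properly capitalized
    assert description[-1] in ".!?"        # complete sentence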
Testing Classification Models
Confusion Matrix Testing:
def test_sentiment_classifier():
    test_cases = [
        ("I love this product!", "positive"),
        ("This is terrible", "negative"),
        ("It's okay I guess", "neutral"),
    ]
    predictions = [model.classify(text) for text, _ in test_cases]
    expected = [label for _, label in test_cases]
    # Calculate accuracy
    accuracy = sum(p == e for p, e in zip(predictions, expected)) / len(test_cases)
    assert accuracy >= 0.8  # 80% minimum
    # Check for systematic bias toward the positive class
    false_positives = sum(1 for p, e in zip(predictions, expected)
                          if p == "positive" and e != "positive")
    assert false_positives <= 1  # At most 1 false positive
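The heading promises a confusion matrix, so here is a hedged sketch that actually builds one with scikit-learn (an assumption; the original computes accuracy and false positives by hand):
from sklearn.metrics import confusion_matrix

def test_confusion_matrix_diagonal():
    labels = ["positive", "negative", "neutral"]
    texts = ["I love this product!", "This is terrible", "It's okay I guess"]
    expected = ["positive", "negative", "neutral"]
    predicted = [model.classify(t) for t in texts]  # same hypothetical classifier as above
    cm = confusion_matrix(expected, predicted, labels=labels)
    # Correct predictions sit on the diagonal; require at least 2 of 3
    assert cm.trace() >= 2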
Testing with Multiple Examples
import pytest

@pytest.mark.parametrize("input_text,expected_category", [
    ("Buy now for 50% off!", "spam"),
    ("Your package will arrive tomorrow", "notification"),
    ("Meeting at 3pm", "calendar"),
    ("Invoice #12345 is overdue", "billing"),
])
def test_email_categorization(input_text, expected_category):
    result = model.categorize(input_text)
    assert result == expected_category
Parametrized tests cover multiple scenarios without duplication.
Testing Extraction Tasks
from datetime import datetime

def test_extract_dates():
    text = "Meeting on Jan 15, 2025 and follow-up on January 20th"
    dates = model.extract_dates(text)
    # Verify structure
    assert isinstance(dates, list)
    assert len(dates) == 2
    # Verify content
    assert all(isinstance(d, datetime) for d in dates)
    assert dates[0].year == 2025
    assert dates[0].month == 1
    assert dates[0].day == 15
Structural verification first, then content verification.
Integration Testing ML Pipelines
def test_full_pipeline():
    """Test data flow from ingestion to inference"""
    # Stage 1: Data ingestion
    raw_data = load_test_data("sample.json")
    ingested = pipeline.ingest(raw_data)
    assert len(ingested) == 100
    # Stage 2: Preprocessing
    preprocessed = pipeline.preprocess(ingested)
    assert all(item.has_required_fields() for item in preprocessed)
    # Stage 3: Feature extraction
    features = pipeline.extract_features(preprocessed)
    assert features.shape == (100, 512)  # Expected dimensions
    # Stage 4: Model inference
    predictions = pipeline.predict(features)
    assert len(predictions) == 100
    assert all(0 <= p <= 1 for p in predictions)  # Valid probabilities
    # End-to-end latency
    assert pipeline.get_latency() < 1.0  # Sub-second processing
Each stage verified independently, then full pipeline verified end-to-end.
Visual Testing with Playwright MCP
Traditional UI testing checks DOM structure:
def test_button_exists():
    button = page.locator("button#submit")
    assert button.is_visible()
    assert button.text_content() == "Submit"
Visual testing with Playwright MCP checks rendered output:
def test_button_visual():
    # AI can see the actual rendered button
    screenshot = page.screenshot()
    # Verify visual properties
    assert_visual_match(screenshot, "expected_button.png", threshold=0.95)
    # Or describe what should be visible
    description = ai.describe_screenshot(screenshot)
    assert "blue submit button" in description.lower()
    assert "center of screen" in description.lower()
AI sees what users see, catches visual regressions text assertions miss.
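The assert_visual_match helper above is not a Playwright built-in; a minimal sketch of one, assuming Pillow and NumPy and a naive pixel diff, might look like this:
from io import BytesIO

import numpy as np
from PIL import Image, ImageChops

# Hypothetical helper: naive pixel-level comparison against a baseline image.
# Real visual-testing tools use perceptual diffing, masks, and per-region tolerances.
def assert_visual_match(screenshot_bytes, baseline_path, threshold=0.95):
    actual = Image.open(BytesIO(screenshot_bytes)).convert("RGB")
    expected = Image.open(baseline_path).convert("RGB").resize(actual.size)
    diff = np.asarray(ImageChops.difference(actual, expected), dtype=float)
    similarity = 1.0 - diff.mean() / 255.0  # 1.0 means pixel-identical
    assert similarity >= threshold, f"Visual similarity {similarity:.3f} below {threshold}"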
Performance Testing
import time
import numpy as np

def test_inference_latency():
    """Model must respond within SLA"""
    latencies = []
    for _ in range(100):
        start = time.perf_counter()  # monotonic clock for timing
        model.predict(test_input)
        latencies.append(time.perf_counter() - start)
    p50 = np.percentile(latencies, 50)
    p95 = np.percentile(latencies, 95)
    p99 = np.percentile(latencies, 99)
    assert p50 < 0.100  # 100ms median
    assert p95 < 0.200  # 200ms p95
    assert p99 < 0.500  # 500ms p99
Measure percentiles, not averages. Tail latency matters.
Data Quality Testing
def test_training_data_quality():
    """Verify training data meets quality standards"""
    dataset = load_training_data()
    # No missing values in required fields
    assert dataset['label'].isna().sum() == 0
    # Balanced classes: each class between 40% and 60% (assumes a binary label)
    class_distribution = dataset['label'].value_counts(normalize=True)
    assert all(0.4 <= ratio <= 0.6 for ratio in class_distribution)
    # No duplicate records
    assert len(dataset) == len(dataset.drop_duplicates())
    # Valid value ranges
    assert dataset['age'].between(0, 120).all()
    assert dataset['score'].between(0, 1).all()
Catch data quality issues before they affect model training.
Model Drift Detection
from sklearn.metrics import accuracy_score

def test_no_model_drift():
    """Production model performance hasn't degraded"""
    # Load reference dataset
    reference_data = load_reference_data()
    # Get current model predictions
    predictions = model.predict(reference_data.features)
    # Compare to baseline performance
    baseline_accuracy = 0.92
    current_accuracy = accuracy_score(reference_data.labels, predictions)
    # Alert if accuracy drops more than 5% relative to baseline
    assert current_accuracy >= baseline_accuracy * 0.95
Run daily against reference dataset. Catch drift before users notice.
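Accuracy needs labels, which often arrive late. A complementary sketch (not in the original) compares prediction-score distributions with SciPy's two-sample KS test, so input drift shows up before labeled accuracy drops; get_recent_production_scores is a hypothetical helper:
from scipy.stats import ks_2samp

def test_no_prediction_distribution_drift():
    """Scores on fresh traffic still resemble the reference distribution"""
    reference_scores = model.predict(load_reference_data().features)
    production_scores = get_recent_production_scores(last_24h=True)  # hypothetical helper
    statistic, p_value = ks_2samp(reference_scores, production_scores)
    # A tiny p-value means the two score distributions likely differ
    assert p_value > 0.01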
Testing Embedding Quality
def test_embedding_similarity():
    """Semantically similar text should have similar embeddings"""
    text1 = "The weather is beautiful today"
    text2 = "It's a gorgeous day outside"
    text3 = "Database connection failed"
    emb1 = model.embed(text1)
    emb2 = model.embed(text2)
    emb3 = model.embed(text3)
    # Similar sentences should be close
    similarity_12 = cosine_similarity(emb1, emb2)
    assert similarity_12 > 0.7
    # Dissimilar sentences should be distant
    similarity_13 = cosine_similarity(emb1, emb3)
    assert similarity_13 < 0.3
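The test assumes a cosine_similarity helper; a minimal NumPy sketch follows (scikit-learn's sklearn.metrics.pairwise.cosine_similarity works too, but expects 2-D arrays):
import numpy as np

def cosine_similarity(a, b):
    # Cosine of the angle between two embedding vectors, in [-1, 1]
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))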
Testing RAG Systems
def test_rag_retrieval():
    """RAG should retrieve relevant context"""
    query = "How do I reset my password?"
    # Retrieve relevant documents
    docs = rag.retrieve(query, k=3)
    # Verify retrieval quality
    assert len(docs) == 3
    assert all("password" in doc.content.lower() for doc in docs)
    assert any("reset" in doc.content.lower() for doc in docs)
    # Generate answer with context
    answer = rag.generate(query, docs)
    # Verify answer quality
    assert len(answer) > 50
    assert "password" in answer.lower()
    assert any(keyword in answer.lower() for keyword in ["reset", "change", "recover"])
Testing with Golden Datasets
def test_against_golden_dataset():
    """Model maintains performance on curated test set"""
    golden_data = load_golden_dataset()
    results = []
    for item in golden_data:
        prediction = model.predict(item.input)
        results.append({
            'input': item.input,
            'expected': item.expected_output,
            'predicted': prediction,
            'match': prediction == item.expected_output
        })
    # Overall accuracy
    accuracy = sum(r['match'] for r in results) / len(results)
    assert accuracy >= 0.90
    # Log failures for analysis
    failures = [r for r in results if not r['match']]
    if failures:
        log_failures(failures)
Golden datasets capture edge cases discovered over time.
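One possible shape for load_golden_dataset (an assumption, not the author's implementation): a JSONL file that grows by one line per newly discovered edge case:
import json
from dataclasses import dataclass

@dataclass
class GoldenItem:
    input: str
    expected_output: str

def load_golden_dataset(path="golden_cases.jsonl"):  # hypothetical file name
    with open(path) as f:
        return [GoldenItem(**json.loads(line)) for line in f if line.strip()]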
Testing API Endpoints
def test_inference_endpoint():
    """API contract remains stable"""
    response = client.post("/api/classify", json={
        "text": "This is a test message"
    })
    # Status code
    assert response.status_code == 200
    # Response structure
    data = response.json()
    assert "prediction" in data
    assert "confidence" in data
    assert "latency_ms" in data
    # Value types
    assert isinstance(data["prediction"], str)
    assert isinstance(data["confidence"], float)
    assert 0 <= data["confidence"] <= 1
    # Performance
    assert data["latency_ms"] < 500
Regression Testing
def test_historical_predictions():
    """Known inputs should maintain stable outputs"""
    test_cases = [
        # (input, expected_output, reasoning)
        ("sunny day", "positive", "historically consistent"),
        ("system failure", "negative", "historically consistent"),
    ]
    for input_text, expected, reason in test_cases:
        result = model.predict(input_text)
        assert result == expected, f"Regression: {reason}"
Lock in correct behavior once discovered.
Continuous Testing in Production
# Production monitoring test
def test_production_health():
    """Production model meets SLAs"""
    metrics = get_production_metrics(last_24h=True)
    # Latency
    assert metrics.p99_latency_ms < 1000
    # Error rate
    assert metrics.error_rate < 0.01  # Less than 1%
    # Throughput
    assert metrics.requests_per_second > 10
    # Resource usage
    assert metrics.memory_usage_percent < 80
    assert metrics.cpu_usage_percent < 70
Run in CI against production endpoints. Catch degradation early.