Add Reference Tests
Add tests that validate reference implementations in the HuggingFace dataset clone at tmp/flashinfer-trace/. Ground truth comes from the FlashInfer repository, or from SGLang when FlashInfer doesn’t have the implementation.
Description
This skill creates test cases under tmp/flashinfer-trace/tests/references/ (the HF dataset clone — the in-repo flashinfer_trace/ directory was removed in the trace-dataset refactor) to validate that reference implementations in Definition JSON files produce correct outputs. The ground truth comes from:
- FlashInfer repository (preferred): Official optimized GPU kernels in tmp/flashinfer/
- SGLang repository (fallback): When FlashInfer doesn’t have the kernel, use tmp/sglang/
Usage
# Test all definitions of a specific op_type
/add-reference-tests --op-type mla_paged
/add-reference-tests --op-type moe
/add-reference-tests --op-type gqa_paged
/add-reference-tests --op-type rmsnorm
# Test a specific definition
/add-reference-tests --definition-name mla_paged_decode_h16_ckv512_kpe64_ps1
# Test all definitions in the definitions directory
/add-reference-tests --all
# Test with custom tolerance
/add-reference-tests --definition-name rmsnorm_h4096 --tolerance 1e-4
Parameters
- definition_name (optional): Specific definition to test (e.g., “mla_paged_decode_h16_ckv512_kpe64_ps1”)
- op_type (optional): Test all definitions of a specific op_type (e.g., “mla_paged”, “moe”, “rmsnorm”)
- all (optional): Test all definitions in the definitions directory
- test_sizes (optional): List of test sizes [“small”, “medium”, “large”] (default: [“small”, “medium”])
- tolerance (optional): Numerical tolerance for comparison (default: 1e-3 for fp16, 1e-5 for fp32)
Prerequisites
Run /clone-repos first to set up the tmp/ directory with SGLang, FlashInfer, and the HuggingFace trace dataset clone at tmp/flashinfer-trace/ — that clone is the only home for definitions and reference tests.
What This Skill Does
Phase 1: Definition Discovery
- Load Target Definitions:
  - If definition_name specified: load single definition
  - If op_type specified: load all definitions matching op_type from tmp/flashinfer-trace/definitions/{op_type}/
  - If all: scan all definitions
- Check Existing Tests:
  - Scan tmp/flashinfer-trace/tests/references/ for existing test files
  - Skip definitions that already have tests (unless force=true)
- Parse Definition Schema:
  - Extract axes (const/var), inputs, outputs
  - Identify required shapes and dtypes
  - Parse reference implementation code
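The discovery steps above can be sketched roughly as follows. This is a minimal illustration, not the skill's actual code: the helper names `discover_definitions` and `has_existing_test` are hypothetical, and the rule "a definition counts as tested if its name appears in test_{op_type}.py" is an assumed heuristic.

```python
import json
from pathlib import Path


def discover_definitions(trace_root, definition_name=None, op_type=None):
    """Return definition JSON paths selected by name, by op_type, or all of them."""
    definitions_dir = Path(trace_root) / "definitions"
    if definition_name is not None:
        # A single definition lives in some op_type subdirectory.
        return sorted(definitions_dir.glob(f"*/{definition_name}.json"))
    if op_type is not None:
        return sorted((definitions_dir / op_type).glob("*.json"))
    # Neither given: behave like --all and scan every op_type directory.
    return sorted(definitions_dir.glob("*/*.json"))


def has_existing_test(definition_path, trace_root):
    """Assumed heuristic: the definition name is mentioned in test_{op_type}.py."""
    op_type = definition_path.parent.name
    test_file = Path(trace_root) / "tests" / "references" / f"test_{op_type}.py"
    return test_file.exists() and definition_path.stem in test_file.read_text()


def load_schema(definition_path):
    """Read the parts of a definition that test generation needs."""
    definition = json.loads(Path(definition_path).read_text())
    # The key names below are assumptions about the Definition JSON layout.
    return {key: definition.get(key) for key in ("axes", "inputs", "outputs", "reference")}
```

A caller would typically filter the discovered paths with `has_existing_test` unless force=true was requested, then pass each remaining path to `load_schema`.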
Phase 2: Ground Truth Discovery
For each definition, locate ground truth implementation using this priority order:
For Model Constants: HuggingFace + SGLang (Required)
Note: See extract-kernel-definitions for detailed guidance on sourcing model constants from HuggingFace and SGLang.
For Ground Truth Execution: FlashInfer API (Primary)
- When: FlashInfer has the kernel implementation (MOST kernels)
- Location: tmp/flashinfer/python/flashinfer/
- Use for: Running optimized GPU kernel as ground truth
- Examples:
    import flashinfer
    # GQA decode
    flashinfer.BatchDecodeWithPagedKVCacheWrapper(...)
    # GQA prefill
    flashinfer.BatchPrefillWithPagedKVCacheWrapper(...)
    # MLA
    flashinfer.mla.BatchMLAPagedAttentionWrapper(...)
    # RMSNorm
    flashinfer.norm.rmsnorm(...)
- Important: FlashInfer is the PRIMARY ground truth for correctness validation
For Ground Truth Execution: SGLang (Fallback ONLY)
- When: FlashInfer does NOT have the kernel (e.g., some MoE variants)
- Location: tmp/sglang/python/sglang/srt/layers/
- Use for: Ground truth when FlashInfer unavailable
- Examples:
    layers/moe/fused_moe.py  # MoE when FlashInfer MoE unavailable
- Important: ONLY use SGLang as ground truth when FlashInfer doesn’t support the kernel
Ground Truth Source Mapping
| Op Type | Ground Truth Source | FlashInfer API | Fallback (if FlashInfer unavailable) |
|---|---|---|---|
| rmsnorm | FlashInfer | flashinfer.norm.rmsnorm | N/A (FlashInfer has it) |
| fused_add_rmsnorm | FlashInfer | flashinfer.norm.fused_add_rmsnorm | N/A (FlashInfer has it) |
| gqa_paged | FlashInfer | flashinfer.BatchDecodeWithPagedKVCacheWrapper, flashinfer.BatchPrefillWithPagedKVCacheWrapper | N/A |
| gqa_ragged | FlashInfer | flashinfer.BatchPrefillWithRaggedKVCacheWrapper | N/A |
| mla_paged | FlashInfer | flashinfer.mla.BatchMLAPagedAttentionWrapper | N/A (FlashInfer has it) |
| moe | SGLang (fallback) | N/A (FlashInfer MoE may not cover all variants) | sglang/layers/moe/fused_moe.py |
| gemm | PyTorch | N/A | torch.nn.functional.linear |
| sampling | FlashInfer | flashinfer.sampling.* | N/A |
| rope | FlashInfer | flashinfer.apply_rope_with_cos_sin_cache_inplace | N/A |
Reference run() Function Sources
Note: For detailed guidance on sourcing reference implementations, see the extract-kernel-definitions skill’s “Reference Implementation Sources” section.
Quick Reference:
- Primary: FlashInfer unit tests at tmp/flashinfer/tests/ (e.g., test_batch_decode.py, test_norm.py)
- Fallback: SGLang vanilla implementations at tmp/sglang/python/sglang/srt/layers/ (only when FlashInfer unavailable)
Phase 3: Test Generation
For each definition, generate test file following the standards below.
Test File Standards
File Structure
Each test file should follow this structure:
import json
import math
from pathlib import Path

import numpy as np
import pytest
import torch

# Ground truth imports (with availability checks)
try:
    import flashinfer
    from flashinfer.xxx import some_kernel
    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False

# Module-level constants from definition
HIDDEN_SIZE = 7168
NUM_EXPERTS = 256
# ... other constants

TRACE_ROOT = Path(__file__).resolve().parents[2]
WORKLOAD_JSONL_PATH = TRACE_ROOT / "workloads" / "op_type" / "definition_name.jsonl"

@torch.no_grad()
def run(...):
    """Reference implementation matching the definition."""
    # Check constants
    assert hidden_size == HIDDEN_SIZE
    ...

def generate_random_inputs(..., device="cuda"):
    """Generate random inputs for testing."""
    ...
    return {...}

def test_correctness(..., atol=1e-2, rtol=5e-2):
    """Test correctness of reference implementation against ground truth."""
    ...

def main():
    """Run comprehensive tests."""
    ...

if __name__ == "__main__":
    main()
Coding Style Patterns
- Constants at Module Level: Define model-specific constants at the top
    # DeepSeek V3/R1 MoE constants
    HIDDEN_SIZE = 7168
    INTERMEDIATE_SIZE = 2048
    NUM_EXPERTS_GLOBAL = 256
    NUM_LOCAL_EXPERTS = 32  # EP=8
- Use @torch.no_grad() Decorator: For all reference implementations and test functions
- Input Generator Function: Separate function generate_random_inputs(...) that returns a dict
- Test Function Pattern:
    def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2):
        """Test correctness of reference implementation against ground truth."""
        print(f"\n{'='*60}")
        print(f"Testing {description}: {params}")
        print(f"{'='*60}")

        device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cpu":
            print("WARNING: CUDA not available, skipping test")
            return

        # Generate inputs
        inputs = generate_random_inputs(...)

        # Run reference implementation
        print("\nRunning reference implementation...")
        ref_output = run(**inputs)

        # Run ground truth
        print("Running ground truth (FlashInfer/SGLang)...")
        gt_output = ground_truth_fn(**inputs)

        # Compare outputs
        print("\nComparing outputs...")
        # ... detailed comparison
Correctness Checking Patterns
Standard Tolerance Check (FP16/BF16)
# Convert to float32 for comparison
ref_f32 = ref_output.float()
gt_f32 = gt_output.float()
# Compute detailed error metrics
abs_diff = torch.abs(ref_f32 - gt_f32)
rel_diff = abs_diff / (torch.abs(gt_f32) + 1e-8)
max_abs_diff = abs_diff.max().item()
max_rel_diff = rel_diff.max().item()
mean_abs_diff = abs_diff.mean().item()
mean_rel_diff = rel_diff.mean().item()
print(f"\nOutput tensor comparison:")
print(f"Max absolute difference: {max_abs_diff:.6e}")
print(f"Max relative difference: {max_rel_diff:.6e}")
print(f"Mean absolute difference: {mean_abs_diff:.6e}")
print(f"Mean relative difference: {mean_rel_diff:.6e}")
# Cosine similarity and MSE
cos_sim = torch.nn.functional.cosine_similarity(
    ref_f32.flatten(), gt_f32.flatten(), dim=0
).item()
mse = torch.mean((ref_f32 - gt_f32) ** 2).item()
print(f"Cosine similarity: {cos_sim:.6f}")
print(f"MSE: {mse:.6e}")
# Check tolerance
all_close = torch.allclose(ref_f32, gt_f32, atol=atol, rtol=rtol)
if all_close:
    print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})")
else:
    print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})")
Hit Ratio Check (for FP8/Quantized Kernels)
For quantized kernels with higher variance, use hit ratio instead of strict allclose:
# Check what percentage of elements pass the tolerance check
left = (ref_f32 - gt_f32).abs()
right = atol + rtol * gt_f32.abs()
ok = left <= right
hit_ratio = ok.float().mean().item()
print(f"\nHit ratio: {hit_ratio * 100:.2f}% (need >= {percent * 100:.2f}%)")
return hit_ratio >= percent # e.g., 85%
Error Location Debugging
When tests fail, show top error locations:
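if not all_close:
    flat = abs_diff.flatten()
    k = min(5, flat.numel())
    topv, topi = torch.topk(flat, k)
    print(f"\nTop-{k} absolute error locations:")
    for rank in range(k):
        idx = topi[rank].item()
        # Convert flat index to multi-dimensional (row-major, matching flatten());
        # relies on `import numpy as np` from the file structure above
        coords = np.unravel_index(idx, tuple(abs_diff.shape))
        indices = ", ".join(str(int(c)) for c in coords)
        ref_val = ref_f32.flatten()[idx].item()
        gt_val = gt_f32.flatten()[idx].item()
        print(f" [{indices}]: ref={ref_val:.6e}, gt={gt_val:.6e}, diff={topv[rank].item():.6e}")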
Tolerance Guidelines
| Data Type | atol | rtol | Notes |
|---|---|---|---|
| float32 | 1e-5 | 1e-5 | Strictest |
| float16 | 1e-3 | 1e-3 | Standard |
| bfloat16 | 8e-3 | 1e-2 | 0.8% abs, 1% rel |
| float8_e4m3fn | 1e-1 | 2e-1 | Use hit ratio ≥85% |
| nvfp4 | 1e-1 | 2e-1 | Use hit ratio ≥85% |
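As an illustration of how the table might be consumed, the sketch below encodes the rows as a lookup keyed by dtype name and applies either a strict check or the hit-ratio rule. The names `TOLERANCES`, `tolerance_for`, and `within_tolerance` are hypothetical, and plain Python floats stand in for tensors so the example runs without a GPU.

```python
# Values copied from the tolerance table; min_hit_ratio=None means every element must pass.
TOLERANCES = {
    "float32": {"atol": 1e-5, "rtol": 1e-5, "min_hit_ratio": None},
    "float16": {"atol": 1e-3, "rtol": 1e-3, "min_hit_ratio": None},
    "bfloat16": {"atol": 8e-3, "rtol": 1e-2, "min_hit_ratio": None},
    "float8_e4m3fn": {"atol": 1e-1, "rtol": 2e-1, "min_hit_ratio": 0.85},
    "nvfp4": {"atol": 1e-1, "rtol": 2e-1, "min_hit_ratio": 0.85},
}


def tolerance_for(dtype_name):
    """Look up tolerances by dtype name; accepts 'bfloat16' or 'torch.bfloat16'."""
    key = str(dtype_name).removeprefix("torch.")
    return TOLERANCES[key]  # raises KeyError for dtypes missing from the table


def within_tolerance(ref, gt, dtype_name):
    """Apply |ref - gt| <= atol + rtol * |gt| elementwise to flat float sequences."""
    tol = tolerance_for(dtype_name)
    ok = [abs(r - g) <= tol["atol"] + tol["rtol"] * abs(g) for r, g in zip(ref, gt)]
    hit_ratio = sum(ok) / len(ok)
    required = 1.0 if tol["min_hit_ratio"] is None else tol["min_hit_ratio"]
    return hit_ratio >= required
```

Keeping the thresholds in one table-shaped structure makes it harder for an individual test file to drift from the guidelines.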
Multi-Ground-Truth Testing
Pattern for Multiple Ground Truths
When both FlashInfer and SGLang implementations are available, test against both:
# Ground truth imports
try:
    from flashinfer.xxx import flashinfer_kernel
    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False

try:
    from sglang.srt.layers.xxx import sglang_kernel
    SGLANG_AVAILABLE = True
except ImportError:
    SGLANG_AVAILABLE = False

def test_correctness_vs_flashinfer(...):
    """Test reference against FlashInfer ground truth."""
    if not FLASHINFER_AVAILABLE:
        pytest.skip("FlashInfer not available")
    ref_output = run(**inputs)
    fi_output = flashinfer_kernel(**inputs)
    assert_close(ref_output, fi_output, atol=atol, rtol=rtol)

def test_correctness_vs_sglang(...):
    """Test reference against SGLang ground truth."""
    if not SGLANG_AVAILABLE:
        pytest.skip("SGLang not available")
    ref_output = run(**inputs)
    sg_output = sglang_kernel(**inputs)
    assert_close(ref_output, sg_output, atol=atol, rtol=rtol)

def test_ground_truths_match(...):
    """Test that FlashInfer and SGLang produce consistent results."""
    if not (FLASHINFER_AVAILABLE and SGLANG_AVAILABLE):
        pytest.skip("Both ground truths not available")
    fi_output = flashinfer_kernel(**inputs)
    sg_output = sglang_kernel(**inputs)
    assert_close(fi_output, sg_output, atol=atol, rtol=rtol)
Comprehensive Test Runner
def main():
    """Run comprehensive tests against all available ground truths."""
    print("Testing Reference Implementation")
    # Test configurations
    test_configs = [
        (1, 16),  # Small
        (4, 32),  # Medium
        (8, 64),  # Large
    ]
    results = {"flashinfer": [], "sglang": []}
    for config in test_configs:
        if FLASHINFER_AVAILABLE:
            try:
                # The test functions assert instead of returning a bool,
                # so reaching the next line means the comparison passed.
                test_correctness_vs_flashinfer(*config)
                results["flashinfer"].append(True)
            except Exception as e:
                print(f"FlashInfer test failed: {e}")
                results["flashinfer"].append(False)
        if SGLANG_AVAILABLE:
            try:
                test_correctness_vs_sglang(*config)
                results["sglang"].append(True)
            except Exception as e:
                print(f"SGLang test failed: {e}")
                results["sglang"].append(False)
    # Summary
    print(f"\n{'='*60}")
    print("Summary:")
    for source, passed_list in results.items():
        if passed_list:
            print(f" {source}: {sum(passed_list)}/{len(passed_list)} tests passed")
    print(f"{'='*60}")
Standard Test Generation
For each definition, generate test file:
- Create Test Class with:
  - Fixture for loading definition
  - Fixture for compiling reference implementation
  - Fixture for ground truth function
- Generate Test Inputs:
  - Parse definition schema for input shapes and dtypes
  - Generate random tensors matching specifications
  - Handle both constant and variable axes
- Create Test Methods:
  - test_output_shape: Verify output shapes match definition
  - test_output_dtype: Verify output dtypes match definition
  - test_numerical_correctness: Compare reference vs ground truth
  - test_determinism: Verify reproducible results
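To make the "constant and variable axes" step concrete, here is a small sketch that turns a definition's axes and inputs into concrete shapes. The JSON layout it assumes (axes with "type" and "value", inputs with "shape" as a list of axis names and "dtype" as a string) is an assumption about the Definition schema, and `resolve_input_shapes` is a hypothetical helper name.

```python
def resolve_input_shapes(definition, var_values):
    """Map each input name to (concrete_shape, dtype_name).

    Const axes take their size from the definition; var axes take it from
    var_values, which a test supplies per test case (e.g. batch_size=4).
    """
    sizes = {}
    for axis_name, spec in definition["axes"].items():
        if spec["type"] == "const":
            sizes[axis_name] = spec["value"]
        elif axis_name in var_values:
            sizes[axis_name] = var_values[axis_name]
        else:
            raise KeyError(f"no value supplied for variable axis {axis_name!r}")

    resolved = {}
    for input_name, spec in definition["inputs"].items():
        dims = spec.get("shape") or []  # a missing or null shape is treated as a scalar
        resolved[input_name] = (tuple(sizes[d] for d in dims), spec["dtype"])
    return resolved


# Illustrative definition fragment; the field values are made up for this example.
example_definition = {
    "axes": {
        "batch_size": {"type": "var"},
        "hidden_size": {"type": "const", "value": 4096},
    },
    "inputs": {
        "hidden_states": {"shape": ["batch_size", "hidden_size"], "dtype": "bfloat16"},
        "weight": {"shape": ["hidden_size"], "dtype": "bfloat16"},
        "eps": {"shape": None, "dtype": "float32"},
    },
}

shapes = resolve_input_shapes(example_definition, {"batch_size": 4})
```

The resolved shapes can then be handed to a tensor factory such as torch.randn, and the same mapping applied to outputs gives the expected shapes for test_output_shape.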
Phase 4: Test Cases
Generate multiple test cases with varying sizes:
# Small: Quick smoke tests
SMALL_SIZES = {
    "batch_size": [1, 2],
    "seq_len": [1, 16],
    "num_pages": [1, 4],
}
# Medium: Standard tests
MEDIUM_SIZES = {
    "batch_size": [4, 8, 16],
    "seq_len": [64, 128, 256],
    "num_pages": [16, 32, 64],
}
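One possible way to turn these size grids into individual test cases is a Cartesian product over the listed values. The sketch below assumes that every combination is meaningful, which may not hold for all kernels (for example, a seq_len that does not fit in num_pages); `expand_sizes` is a hypothetical helper.

```python
import itertools

SMALL_SIZES = {
    "batch_size": [1, 2],
    "seq_len": [1, 16],
    "num_pages": [1, 4],
}


def expand_sizes(size_grid, keep=lambda case: True):
    """Expand a {axis: [values]} grid into a list of {axis: value} test cases.

    `keep` is an optional filter for dropping combinations a kernel cannot run.
    """
    axes = list(size_grid)
    cases = []
    for combo in itertools.product(*(size_grid[axis] for axis in axes)):
        case = dict(zip(axes, combo))
        if keep(case):
            cases.append(case)
    return cases


small_cases = expand_sizes(SMALL_SIZES)
# Example filter: with page_size=1, keep only cases whose pages can hold the sequence.
feasible_cases = expand_sizes(SMALL_SIZES, keep=lambda c: c["seq_len"] <= c["num_pages"])
```

The resulting list of dicts can be fed to pytest.mark.parametrize or iterated over directly in a main() runner.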
Phase 5: Write Test Files
Output to tmp/flashinfer-trace/tests/references/ (the HF dataset clone — committed and PR’d against flashinfer-ai/flashinfer-trace).
Output Structure
tmp/flashinfer-trace/tests/references/
├── conftest.py # Shared fixtures and utilities
├── test_rmsnorm.py # RMSNorm tests
├── test_gqa_paged.py # GQA paged tests
├── test_mla_paged.py # MLA paged tests
├── test_moe.py # MoE tests
└── test_gemm.py # GEMM tests
Test File Template
"""Tests for {definition_name} reference implementation."""
import math
from pathlib import Path
import numpy as np
import torch
# Ground truth imports with availability checks
try:
    import flashinfer
    from flashinfer.xxx import some_kernel
    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False

try:
    from sglang.srt.layers.xxx import sglang_kernel
    SGLANG_AVAILABLE = True
except ImportError:
    SGLANG_AVAILABLE = False
# Module-level constants (from definition)
NUM_QO_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
PAGE_SIZE = 1
TRACE_ROOT = Path(__file__).resolve().parents[2]
@torch.no_grad()
def run(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale):
    """Reference implementation matching the definition schema."""
    batch_size, num_qo_heads, head_dim = q.shape
    _, page_size, num_kv_heads, _ = k_cache.shape
    # Check constants
    assert num_qo_heads == NUM_QO_HEADS
    assert num_kv_heads == NUM_KV_HEADS
    assert head_dim == HEAD_DIM
    assert page_size == PAGE_SIZE
    device = q.device
    # Reference computation (pure PyTorch)
    output = torch.zeros((batch_size, num_qo_heads, head_dim), dtype=torch.bfloat16, device=device)
    lse = torch.full((batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=device)
    # ... detailed step-by-step computation ...
    return output, lse
def generate_random_inputs(
    batch_size,
    max_seq_len,
    num_attention_heads=NUM_QO_HEADS,
    num_key_value_heads=NUM_KV_HEADS,
    head_dim=HEAD_DIM,
    page_size=PAGE_SIZE,
    device="cuda",
):
    """Generate random inputs for testing."""
    # Generate tensors matching definition schema
    q = torch.randn(batch_size, num_attention_heads, head_dim, dtype=torch.bfloat16, device=device)
    # ... generate other inputs ...
    return {
        "q": q,
        "k_cache": k_cache,
        "v_cache": v_cache,
        "kv_indptr": kv_indptr,
        "kv_indices": kv_indices,
        "sm_scale": sm_scale,
    }
def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2):
    """Test correctness of reference implementation against FlashInfer."""
    print(f"\n{'='*60}")
    print(f"Testing batch_size={batch_size}, max_seq_len={max_seq_len}")
    print(f"{'='*60}")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("WARNING: CUDA not available, skipping test")
        return
    # Generate inputs
    inputs = generate_random_inputs(batch_size, max_seq_len, device=device)
    print(f"Generated sequences with shapes: q={inputs['q'].shape}")
    # Run reference implementation
    print("\nRunning reference implementation...")
    ref_output, ref_lse = run(**{k: v for k, v in inputs.items() if k != 'extra_keys'})
    # Run FlashInfer ground truth
    print("Running FlashInfer...")
    # ... FlashInfer setup and execution ...
    # Compare outputs
    print("\nComparing outputs...")
    # Convert to float32 for comparison
    ref_f32 = ref_output.float()
    fi_f32 = fi_output.float()
    # Compute errors
    abs_diff = torch.abs(ref_f32 - fi_f32)
    rel_diff = abs_diff / (torch.abs(fi_f32) + 1e-8)
    max_abs_diff = abs_diff.max().item()
    max_rel_diff = rel_diff.max().item()
    mean_abs_diff = abs_diff.mean().item()
    mean_rel_diff = rel_diff.mean().item()
    print(f"\nOutput tensor comparison:")
    print(f"Max absolute difference: {max_abs_diff:.6e}")
    print(f"Max relative difference: {max_rel_diff:.6e}")
    print(f"Mean absolute difference: {mean_abs_diff:.6e}")
    print(f"Mean relative difference: {mean_rel_diff:.6e}")
    # Cosine similarity and MSE
    cos_sim = torch.nn.functional.cosine_similarity(
        ref_f32.flatten(), fi_f32.flatten(), dim=0
    ).item()
    mse = torch.mean((ref_f32 - fi_f32) ** 2).item()
    print(f"Cosine similarity: {cos_sim:.6f}")
    print(f"MSE: {mse:.6e}")
    # Check if outputs match within tolerance
    all_close = torch.allclose(ref_f32, fi_f32, atol=atol, rtol=rtol)
    if all_close:
        print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})")
    else:
        print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})")
        # Show top error locations for debugging
        flat = abs_diff.flatten()
        k = min(5, flat.numel())
        topv, topi = torch.topk(flat, k)
        print(f"\nTop-{k} absolute error locations:")
        for rank in range(k):
            idx = topi[rank].item()
            print(f" idx={idx}: ref={ref_f32.flatten()[idx].item():.6e}, "
                  f"fi={fi_f32.flatten()[idx].item():.6e}, diff={topv[rank].item():.6e}")
    return all_close
def main():
    """Run comprehensive tests."""
    print("Testing Reference Implementation vs FlashInfer")
    # Test configurations
    test_configs = [
        (1, 16),  # Single batch
        (4, 32),  # Small batch
        (8, 64),  # Medium batch
        (16, 128),  # Large batch
    ]
    passed = 0
    total = len(test_configs)
    for batch_size, max_seq_len in test_configs:
        try:
            if test_correctness(batch_size, max_seq_len):
                passed += 1
        except Exception as e:
            print(f"✗ Test failed with exception: {str(e)}")
            import traceback
            traceback.print_exc()
    print(f"\n{'='*60}")
    print(f"Summary: {passed}/{total} tests passed")
    print(f"{'='*60}")
    if passed == total:
        print("✓ All tests passed!")
    else:
        print(f"✗ {total - passed} tests failed")

if __name__ == "__main__":
    main()
conftest.py Template
"""Shared test fixtures for reference implementation tests."""
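import json
import math
from pathlib import Path

import pytest
import torch

# conftest.py lives in <trace root>/tests/references/, so the dataset root
# (which holds definitions/ and workloads/) is two directories further up.
TRACE_ROOT = Path(__file__).resolve().parents[2]
DEFINITIONS_DIR = TRACE_ROOT / "definitions"
WORKLOADS_DIR = TRACE_ROOT / "workloads"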
@pytest.fixture
def device():
    """Get test device (CUDA if available)."""
    return "cuda" if torch.cuda.is_available() else "cpu"

def load_definition(name: str) -> dict:
    """Load a definition JSON by name."""
    for op_dir in DEFINITIONS_DIR.iterdir():
        if op_dir.is_dir():
            def_file = op_dir / f"{name}.json"
            if def_file.exists():
                with open(def_file) as f:
                    return json.load(f)
    raise FileNotFoundError(f"Definition {name} not found")

def compile_reference(reference_code: str):
    """Compile reference implementation to callable function."""
    namespace = {"torch": torch, "math": math, "np": __import__("numpy")}
    exec(reference_code, namespace)
    return namespace["run"]

def assert_close(actual, expected, rtol=1e-3, atol=1e-3):
    """Assert tensors are close within tolerance."""
    if isinstance(actual, tuple):
        for a, e in zip(actual, expected):
            assert_close(a, e, rtol, atol)
    elif isinstance(actual, dict):
        for k in actual:
            assert_close(actual[k], expected[k], rtol, atol)
    else:
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

def compute_error_metrics(ref, gt, name="output"):
    """Compute and print detailed error metrics."""
    ref_f32 = ref.float()
    gt_f32 = gt.float()
    abs_diff = torch.abs(ref_f32 - gt_f32)
    rel_diff = abs_diff / (torch.abs(gt_f32) + 1e-8)
    print(f"\n{name} comparison:")
    print(f" Max absolute difference: {abs_diff.max().item():.6e}")
    print(f" Max relative difference: {rel_diff.max().item():.6e}")
    print(f" Mean absolute difference: {abs_diff.mean().item():.6e}")
    print(f" Mean relative difference: {rel_diff.mean().item():.6e}")
    cos_sim = torch.nn.functional.cosine_similarity(
        ref_f32.flatten(), gt_f32.flatten(), dim=0
    ).item()
    mse = torch.mean((ref_f32 - gt_f32) ** 2).item()
    print(f" Cosine similarity: {cos_sim:.6f}")
    print(f" MSE: {mse:.6e}")
    return abs_diff, rel_diff

def check_hit_ratio(ref, gt, atol, rtol, required_percent=0.85):
    """Check if hit ratio meets threshold (for quantized kernels)."""
    ref_f32 = ref.float()
    gt_f32 = gt.float()
    left = (ref_f32 - gt_f32).abs()
    right = atol + rtol * gt_f32.abs()
    ok = left <= right
    hit_ratio = ok.float().mean().item()
    print(f"\nHit ratio: {hit_ratio * 100:.2f}% (need >= {required_percent * 100:.2f}%)")
    return hit_ratio >= required_percent
Implementation Steps
When executing this skill:
- Identify definitions to test:
    ls tmp/flashinfer-trace/definitions/{op_type}/
- Check for existing tests:
    ls tmp/flashinfer-trace/tests/references/
- For each definition:
  - Read the definition JSON
  - Identify ground truth source (FlashInfer or SGLang)
  - Generate test class with appropriate fixtures
  - Generate test methods for shape, dtype, and numerical correctness
- Create test file:
    # Create tests directory if needed
    mkdir -p tmp/flashinfer-trace/tests/references/
- Write test file:
  - If testing multiple definitions of same op_type, combine into one file
  - Each definition gets its own test class
- Create/update conftest.py with shared fixtures
Running Tests
After generating tests, run from the HF dataset clone:
# Run all reference tests
pytest tmp/flashinfer-trace/tests/references/ -v
# Run specific test file
pytest tmp/flashinfer-trace/tests/references/test_mla_paged.py -v
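# Run with GPU (--device is not a built-in pytest flag; it only works if conftest.py registers it via pytest_addoption)
pytest tmp/flashinfer-trace/tests/references/ -v --device cuda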
# Run with verbose output
pytest tmp/flashinfer-trace/tests/references/ -v -s
Error Handling
Ground Truth Not Available
- Error: Ground truth implementation not found
- Handling: Follow the Phase 2 priority order (FlashInfer first, SGLang only as fallback):
  - Check FlashInfer API directly
  - Check SGLang FlashInfer API integration
  - Check SGLang vanilla implementation
  - If none available, mark test as skip with reason
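A rough sketch of the availability probe behind this handling is shown below. It follows the Phase 2 rule that FlashInfer is primary and SGLang is the fallback, and uses the standard-library importlib.util.find_spec to check whether a package is installed without importing it. The helper name `pick_ground_truth` and the injectable `is_available` argument are hypothetical conveniences for illustration.

```python
import importlib.util

# Package names in priority order: FlashInfer is primary, SGLang is the fallback.
GROUND_TRUTH_PRIORITY = ("flashinfer", "sglang")


def _installed(package_name):
    """True if the top-level package can be found on the import path."""
    return importlib.util.find_spec(package_name) is not None


def pick_ground_truth(candidates=GROUND_TRUTH_PRIORITY, is_available=_installed):
    """Return (package_name, skip_reason); exactly one of the two is None."""
    for name in candidates:
        if is_available(name):
            return name, None
    reason = "no ground truth available (tried: " + ", ".join(candidates) + ")"
    return None, reason
```

In a pytest-based test, a non-None reason would be passed to pytest.skip(reason) so the missing ground truth is reported instead of silently ignored. Note that an installed package does not guarantee that the specific kernel exists in it, so a per-kernel import check is still needed.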
Definition Parse Error
- Error: Invalid definition JSON
- Handling: Report validation errors, skip test generation
Shape Mismatch
- Error: Reference output shape doesn’t match definition
- Handling: Create failing test, flag for investigation
Numerical Divergence
- Error: Reference differs from ground truth beyond tolerance
- Handling: Create failing test with detailed diff report
Integration with Other Skills
# Complete workflow
/clone-repos
# Extract definitions
/extract-kernel-definitions --model-name deepseek_v3
# Add tests for new definitions
/add-reference-tests --op-type mla_paged
/add-reference-tests --op-type moe
# Run tests from the HF dataset clone
pytest tmp/flashinfer-trace/tests/references/ -v
Notes
- Tests run on GPU by default; CPU fallback for CI environments
- Tolerance varies by dtype: looser for fp16 (1e-3), stricter for fp32 (1e-5)
- Some kernels may not have FlashInfer ground truth yet
- Test parametrization covers common batch/sequence sizes
- Tests marked with @pytest.mark.slow for large sizes
Maintaining This Document
Update this file when changing ground truth sources, test patterns, tolerance values, or adding new op_types.