add-reference-tests

为 flashinfer_trace 中的参考实现生成 pytest 测试用例,通过比对 FlashInfer 或 SGLang 提供的权威结果来验证数值正确性,覆盖不同算子类型、数据精度与输入规模,并自动适配常量配置与容差策略。

快捷安装

在终端运行此命令,即可一键安装该 Skill 到您的 Claude 中

npx skills add flashinfer-ai/flashinfer-bench --skill "add-reference-tests"

Add Reference Tests

Add tests to validate reference implementations in the HuggingFace dataset clone at tmp/flashinfer-trace/. Ground truth is sourced from FlashInfer repository or SGLang when FlashInfer doesn’t have the implementation.

Description

This skill creates test cases under tmp/flashinfer-trace/tests/references/ (the HF dataset clone — the in-repo flashinfer_trace/ directory was removed in the trace-dataset refactor) to validate that reference implementations in Definition JSON files produce correct outputs. The ground truth comes from:

  1. FlashInfer repository (preferred): Official optimized GPU kernels in tmp/flashinfer/
  2. SGLang repository (fallback): When FlashInfer doesn’t have the kernel, use tmp/sglang/

Usage

# Test all definitions of a specific op_type
/add-reference-tests --op-type mla_paged
/add-reference-tests --op-type moe
/add-reference-tests --op-type gqa_paged
/add-reference-tests --op-type rmsnorm

# Test a specific definition
/add-reference-tests --definition-name mla_paged_decode_h16_ckv512_kpe64_ps1

# Test all definitions in the definitions directory
/add-reference-tests --all

# Test with custom tolerance
/add-reference-tests --definition-name rmsnorm_h4096 --tolerance 1e-4

Parameters

  • definition_name (optional): Specific definition to test (e.g., “mla_paged_decode_h16_ckv512_kpe64_ps1”)
  • op_type (optional): Test all definitions of a specific op_type (e.g., “mla_paged”, “moe”, “rmsnorm”)
  • all (optional): Test all definitions in the definitions directory
  • test_sizes (optional): List of test sizes [“small”, “medium”, “large”] (default: [“small”, “medium”])
  • tolerance (optional): Numerical tolerance for comparison (default: 1e-3 for fp16, 1e-5 for fp32)

Prerequisites

Run /clone-repos first to set up the tmp/ directory with SGLang, FlashInfer, and the HuggingFace trace dataset clone at tmp/flashinfer-trace/ — that clone is the only home for definitions and reference tests.

What This Skill Does

Phase 1: Definition Discovery

  1. Load Target Definitions:

    • If definition_name specified: load single definition
    • If op_type specified: load all definitions matching op_type from tmp/flashinfer-trace/definitions/{op_type}/
    • If all: scan all definitions
  2. Check Existing Tests:

    • Scan tmp/flashinfer-trace/tests/references/ for existing test files
    • Skip definitions that already have tests (unless force=true)
  3. Parse Definition Schema:

    • Extract axes (const/var), inputs, outputs
    • Identify required shapes and dtypes
    • Parse reference implementation code

Phase 2: Ground Truth Discovery

For each definition, locate ground truth implementation using this priority order:

For Model Constants: HuggingFace + SGLang (Required)

Note: See extract-kernel-definitions for detailed guidance on sourcing model constants from HuggingFace and SGLang.

For Ground Truth Execution: FlashInfer API (Primary)

  • When: FlashInfer has the kernel implementation (MOST kernels)
  • Location: tmp/flashinfer/python/flashinfer/
  • Use for: Running optimized GPU kernel as ground truth
  • Examples:
    import flashinfer
    # GQA decode
    flashinfer.BatchDecodeWithPagedKVCacheWrapper(...)
    # GQA prefill
    flashinfer.BatchPrefillWithPagedKVCacheWrapper(...)
    # MLA
    flashinfer.mla.BatchMLAPagedAttentionWrapper(...)
    # RMSNorm
    flashinfer.norm.rmsnorm(...)
  • Important: FlashInfer is the PRIMARY ground truth for correctness validation

For Ground Truth Execution: SGLang (Fallback ONLY)

  • When: FlashInfer does NOT have the kernel (e.g., some MoE variants)
  • Location: tmp/sglang/python/sglang/srt/layers/
  • Use for: Ground truth when FlashInfer unavailable
  • Examples:
    layers/moe/fused_moe.py         # MoE when FlashInfer MoE unavailable
  • Important: ONLY use SGLang as ground truth when FlashInfer doesn’t support the kernel

Ground Truth Source Mapping

Op TypeGround Truth SourceFlashInfer APIFallback (if FlashInfer unavailable)
rmsnormFlashInferflashinfer.norm.rmsnormN/A (FlashInfer has it)
fused_add_rmsnormFlashInferflashinfer.norm.fused_add_rmsnormN/A (FlashInfer has it)
gqa_pagedFlashInferflashinfer.BatchDecodeWithPagedKVCacheWrapper, flashinfer.BatchPrefillWithPagedKVCacheWrapperN/A
gqa_raggedFlashInferflashinfer.BatchPrefillWithRaggedKVCacheWrapperN/A
mla_pagedFlashInferflashinfer.mla.BatchMLAPagedAttentionWrapperN/A (FlashInfer has it)
moeSGLang (fallback)N/A (FlashInfer MoE may not cover all variants)sglang/layers/moe/fused_moe.py
gemmPyTorchN/Atorch.nn.functional.linear
samplingFlashInferflashinfer.sampling.*N/A
ropeFlashInferflashinfer.apply_rope_with_cos_sin_cache_inplaceN/A

Reference run() Function Sources

Note: For detailed guidance on sourcing reference implementations, see the extract-kernel-definitions skill’s “Reference Implementation Sources” section.

Quick Reference:

  • Primary: FlashInfer unit tests at tmp/flashinfer/tests/ (e.g., test_batch_decode.py, test_norm.py)
  • Fallback: SGLang vanilla implementations at tmp/sglang/python/sglang/srt/layers/ (only when FlashInfer unavailable)

Phase 3: Test Generation

For each definition, generate test file following the standards below.

Test File Standards

File Structure

Each test file should follow this structure:

import json
import math
from pathlib import Path

import numpy as np
import pytest
import torch

# Ground truth imports (with availability checks)
try:
    import flashinfer
    from flashinfer.xxx import some_kernel
    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False

# Module-level constants from definition
HIDDEN_SIZE = 7168
NUM_EXPERTS = 256
# ... other constants

TRACE_ROOT = Path(__file__).resolve().parents[2]
WORKLOAD_JSONL_PATH = TRACE_ROOT / "workloads" / "op_type" / "definition_name.jsonl"


@torch.no_grad()
def run(...):
    """Reference implementation matching the definition."""
    # Check constants
    assert hidden_size == HIDDEN_SIZE
    ...


def generate_random_inputs(..., device="cuda"):
    """Generate random inputs for testing."""
    ...
    return {...}


def test_correctness(..., atol=1e-2, rtol=5e-2):
    """Test correctness of reference implementation against ground truth."""
    ...


def main():
    """Run comprehensive tests."""
    ...


if __name__ == "__main__":
    main()

Coding Style Patterns

  1. Constants at Module Level: Define model-specific constants at the top

    # DeepSeek V3/R1 MoE constants
    HIDDEN_SIZE = 7168
    INTERMEDIATE_SIZE = 2048
    NUM_EXPERTS_GLOBAL = 256
    NUM_LOCAL_EXPERTS = 32  # EP=8
  2. Use @torch.no_grad() Decorator: For all reference implementations and test functions

  3. Input Generator Function: Separate function generate_random_inputs(...) that returns a dict

  4. Test Function Pattern:

    def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2):
        """Test correctness of reference implementation against ground truth."""
        print(f"\n{'='*60}")
        print(f"Testing {description}: {params}")
        print(f"{'='*60}")
    
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cpu":
            print("WARNING: CUDA not available, skipping test")
            return
    
        # Generate inputs
        inputs = generate_random_inputs(...)
    
        # Run reference implementation
        print("\nRunning reference implementation...")
        ref_output = run(**inputs)
    
        # Run ground truth
        print("Running ground truth (FlashInfer/SGLang)...")
        gt_output = ground_truth_fn(**inputs)
    
        # Compare outputs
        print("\nComparing outputs...")
        # ... detailed comparison

Correctness Checking Patterns

Standard Tolerance Check (FP16/BF16)

# Convert to float32 for comparison
ref_f32 = ref_output.float()
gt_f32 = gt_output.float()

# Compute detailed error metrics
abs_diff = torch.abs(ref_f32 - gt_f32)
rel_diff = abs_diff / (torch.abs(gt_f32) + 1e-8)

max_abs_diff = abs_diff.max().item()
max_rel_diff = rel_diff.max().item()
mean_abs_diff = abs_diff.mean().item()
mean_rel_diff = rel_diff.mean().item()

print(f"\nOutput tensor comparison:")
print(f"Max absolute difference: {max_abs_diff:.6e}")
print(f"Max relative difference: {max_rel_diff:.6e}")
print(f"Mean absolute difference: {mean_abs_diff:.6e}")
print(f"Mean relative difference: {mean_rel_diff:.6e}")

# Cosine similarity and MSE
cos_sim = torch.nn.functional.cosine_similarity(
    ref_f32.flatten(), gt_f32.flatten(), dim=0
).item()
mse = torch.mean((ref_f32 - gt_f32) ** 2).item()
print(f"Cosine similarity: {cos_sim:.6f}")
print(f"MSE: {mse:.6e}")

# Check tolerance
all_close = torch.allclose(ref_f32, gt_f32, atol=atol, rtol=rtol)
if all_close:
    print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})")
else:
    print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})")

Hit Ratio Check (for FP8/Quantized Kernels)

For quantized kernels with higher variance, use hit ratio instead of strict allclose:

# Check what percentage of elements pass the tolerance check
left = (ref_f32 - gt_f32).abs()
right = atol + rtol * gt_f32.abs()
ok = left <= right
hit_ratio = ok.float().mean().item()
print(f"\nHit ratio: {hit_ratio * 100:.2f}%  (need >= {percent * 100:.2f}%)")

return hit_ratio >= percent  # e.g., 85%

Error Location Debugging

When tests fail, show top error locations:

if not all_close:
    flat = abs_diff.flatten()
    k = min(5, flat.numel())
    topv, topi = torch.topk(flat, k)
    print(f"\nTop-{k} absolute error locations:")
    for rank in range(k):
        idx = topi[rank].item()
        # Convert flat index to multi-dimensional
        # ... compute indices
        print(f"  [{indices}]: ref={ref_val:.6e}, gt={gt_val:.6e}, diff={topv[rank].item():.6e}")

Tolerance Guidelines

Data TypeatolrtolNotes
float321e-51e-5Strictest
float161e-31e-3Standard
bfloat168e-31e-20.8% abs, 1% rel
float8_e4m3fn1e-12e-1Use hit ratio ≥85%
nvfp41e-12e-1Use hit ratio ≥85%

Multi-Ground-Truth Testing

Pattern for Multiple Ground Truths

When both FlashInfer and SGLang implementations are available, test against both:

# Ground truth imports
try:
    from flashinfer.xxx import flashinfer_kernel
    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False

try:
    from sglang.srt.layers.xxx import sglang_kernel
    SGLANG_AVAILABLE = True
except ImportError:
    SGLANG_AVAILABLE = False


def test_correctness_vs_flashinfer(...):
    """Test reference against FlashInfer ground truth."""
    if not FLASHINFER_AVAILABLE:
        pytest.skip("FlashInfer not available")

    ref_output = run(**inputs)
    fi_output = flashinfer_kernel(**inputs)

    assert_close(ref_output, fi_output, atol=atol, rtol=rtol)


def test_correctness_vs_sglang(...):
    """Test reference against SGLang ground truth."""
    if not SGLANG_AVAILABLE:
        pytest.skip("SGLang not available")

    ref_output = run(**inputs)
    sg_output = sglang_kernel(**inputs)

    assert_close(ref_output, sg_output, atol=atol, rtol=rtol)


def test_ground_truths_match(...):
    """Test that FlashInfer and SGLang produce consistent results."""
    if not (FLASHINFER_AVAILABLE and SGLANG_AVAILABLE):
        pytest.skip("Both ground truths not available")

    fi_output = flashinfer_kernel(**inputs)
    sg_output = sglang_kernel(**inputs)

    assert_close(fi_output, sg_output, atol=atol, rtol=rtol)

Comprehensive Test Runner

def main():
    """Run comprehensive tests against all available ground truths."""
    print("Testing Reference Implementation")

    # Test configurations
    test_configs = [
        (1, 16),   # Small
        (4, 32),   # Medium
        (8, 64),   # Large
    ]

    results = {"flashinfer": [], "sglang": []}

    for config in test_configs:
        if FLASHINFER_AVAILABLE:
            try:
                ok = test_correctness_vs_flashinfer(*config)
                results["flashinfer"].append(ok)
            except Exception as e:
                print(f"FlashInfer test failed: {e}")
                results["flashinfer"].append(False)

        if SGLANG_AVAILABLE:
            try:
                ok = test_correctness_vs_sglang(*config)
                results["sglang"].append(ok)
            except Exception as e:
                print(f"SGLang test failed: {e}")
                results["sglang"].append(False)

    # Summary
    print(f"\n{'='*60}")
    print("Summary:")
    for source, passed_list in results.items():
        if passed_list:
            print(f"  {source}: {sum(passed_list)}/{len(passed_list)} tests passed")
    print(f"{'='*60}")

Standard Test Generation

For each definition, generate test file:

  1. Create Test Class with:

    • Fixture for loading definition
    • Fixture for compiling reference implementation
    • Fixture for ground truth function
  2. Generate Test Inputs:

    • Parse definition schema for input shapes and dtypes
    • Generate random tensors matching specifications
    • Handle both constant and variable axes
  3. Create Test Methods:

    • test_output_shape: Verify output shapes match definition
    • test_output_dtype: Verify output dtypes match definition
    • test_numerical_correctness: Compare reference vs ground truth
    • test_determinism: Verify reproducible results

Phase 4: Test Cases

Generate multiple test cases with varying sizes:

# Small: Quick smoke tests
SMALL_SIZES = {
    "batch_size": [1, 2],
    "seq_len": [1, 16],
    "num_pages": [1, 4],
}

# Medium: Standard tests
MEDIUM_SIZES = {
    "batch_size": [4, 8, 16],
    "seq_len": [64, 128, 256],
    "num_pages": [16, 32, 64],
}

Phase 5: Write Test Files

Output to tmp/flashinfer-trace/tests/references/ (the HF dataset clone — committed and PR’d against flashinfer-ai/flashinfer-trace).

Output Structure

tmp/flashinfer-trace/tests/references/
├── conftest.py                    # Shared fixtures and utilities
├── test_rmsnorm.py                # RMSNorm tests
├── test_gqa_paged.py              # GQA paged tests
├── test_mla_paged.py              # MLA paged tests
├── test_moe.py                    # MoE tests
└── test_gemm.py                   # GEMM tests

Test File Template

"""Tests for {definition_name} reference implementation."""
import math
from pathlib import Path

import numpy as np
import torch

# Ground truth imports with availability checks
try:
    import flashinfer
    from flashinfer.xxx import some_kernel
    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False

try:
    from sglang.srt.layers.xxx import sglang_kernel
    SGLANG_AVAILABLE = True
except ImportError:
    SGLANG_AVAILABLE = False


# Module-level constants (from definition)
NUM_QO_HEADS = 32
NUM_KV_HEADS = 8
HEAD_DIM = 128
PAGE_SIZE = 1

TRACE_ROOT = Path(__file__).resolve().parents[2]


@torch.no_grad()
def run(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale):
    """Reference implementation matching the definition schema."""
    batch_size, num_qo_heads, head_dim = q.shape
    _, page_size, num_kv_heads, _ = k_cache.shape

    # Check constants
    assert num_qo_heads == NUM_QO_HEADS
    assert num_kv_heads == NUM_KV_HEADS
    assert head_dim == HEAD_DIM
    assert page_size == PAGE_SIZE

    device = q.device

    # Reference computation (pure PyTorch)
    output = torch.zeros((batch_size, num_qo_heads, head_dim), dtype=torch.bfloat16, device=device)
    lse = torch.full((batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=device)

    # ... detailed step-by-step computation ...

    return output, lse


def generate_random_inputs(
    batch_size,
    max_seq_len,
    num_attention_heads=NUM_QO_HEADS,
    num_key_value_heads=NUM_KV_HEADS,
    head_dim=HEAD_DIM,
    page_size=PAGE_SIZE,
    device="cuda",
):
    """Generate random inputs for testing."""
    # Generate tensors matching definition schema
    q = torch.randn(batch_size, num_attention_heads, head_dim, dtype=torch.bfloat16, device=device)
    # ... generate other inputs ...

    return {
        "q": q,
        "k_cache": k_cache,
        "v_cache": v_cache,
        "kv_indptr": kv_indptr,
        "kv_indices": kv_indices,
        "sm_scale": sm_scale,
    }


def test_correctness(batch_size=4, max_seq_len=64, atol=1e-2, rtol=5e-2):
    """Test correctness of reference implementation against FlashInfer."""
    print(f"\n{'='*60}")
    print(f"Testing batch_size={batch_size}, max_seq_len={max_seq_len}")
    print(f"{'='*60}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("WARNING: CUDA not available, skipping test")
        return

    # Generate inputs
    inputs = generate_random_inputs(batch_size, max_seq_len, device=device)
    print(f"Generated sequences with shapes: q={inputs['q'].shape}")

    # Run reference implementation
    print("\nRunning reference implementation...")
    ref_output, ref_lse = run(**{k: v for k, v in inputs.items() if k != 'extra_keys'})

    # Run FlashInfer ground truth
    print("Running FlashInfer...")
    # ... FlashInfer setup and execution ...

    # Compare outputs
    print("\nComparing outputs...")

    # Convert to float32 for comparison
    ref_f32 = ref_output.float()
    fi_f32 = fi_output.float()

    # Compute errors
    abs_diff = torch.abs(ref_f32 - fi_f32)
    rel_diff = abs_diff / (torch.abs(fi_f32) + 1e-8)

    max_abs_diff = abs_diff.max().item()
    max_rel_diff = rel_diff.max().item()
    mean_abs_diff = abs_diff.mean().item()
    mean_rel_diff = rel_diff.mean().item()

    print(f"\nOutput tensor comparison:")
    print(f"Max absolute difference: {max_abs_diff:.6e}")
    print(f"Max relative difference: {max_rel_diff:.6e}")
    print(f"Mean absolute difference: {mean_abs_diff:.6e}")
    print(f"Mean relative difference: {mean_rel_diff:.6e}")

    # Cosine similarity and MSE
    cos_sim = torch.nn.functional.cosine_similarity(
        ref_f32.flatten(), fi_f32.flatten(), dim=0
    ).item()
    mse = torch.mean((ref_f32 - fi_f32) ** 2).item()
    print(f"Cosine similarity: {cos_sim:.6f}")
    print(f"MSE: {mse:.6e}")

    # Check if outputs match within tolerance
    all_close = torch.allclose(ref_f32, fi_f32, atol=atol, rtol=rtol)

    if all_close:
        print(f"\n✓ PASSED: Outputs match within tolerance (atol={atol}, rtol={rtol})")
    else:
        print(f"\n✗ FAILED: Outputs differ beyond tolerance (atol={atol}, rtol={rtol})")

        # Show top error locations for debugging
        flat = abs_diff.flatten()
        k = min(5, flat.numel())
        topv, topi = torch.topk(flat, k)
        print(f"\nTop-{k} absolute error locations:")
        for rank in range(k):
            idx = topi[rank].item()
            print(f"  idx={idx}: ref={ref_f32.flatten()[idx].item():.6e}, "
                  f"fi={fi_f32.flatten()[idx].item():.6e}, diff={topv[rank].item():.6e}")

    return all_close


def main():
    """Run comprehensive tests."""
    print("Testing Reference Implementation vs FlashInfer")

    # Test configurations
    test_configs = [
        (1, 16),   # Single batch
        (4, 32),   # Small batch
        (8, 64),   # Medium batch
        (16, 128), # Large batch
    ]

    passed = 0
    total = len(test_configs)

    for batch_size, max_seq_len in test_configs:
        try:
            if test_correctness(batch_size, max_seq_len):
                passed += 1
        except Exception as e:
            print(f"✗ Test failed with exception: {str(e)}")
            import traceback
            traceback.print_exc()

    print(f"\n{'='*60}")
    print(f"Summary: {passed}/{total} tests passed")
    print(f"{'='*60}")

    if passed == total:
        print("✓ All tests passed!")
    else:
        print(f"✗ {total - passed} tests failed")


if __name__ == "__main__":
    main()

conftest.py Template

"""Shared test fixtures for reference implementation tests."""
import json
import math
import pytest
import torch
from pathlib import Path


DEFINITIONS_DIR = Path(__file__).parent.parent / "definitions"
WORKLOADS_DIR = Path(__file__).parent.parent / "workloads"


@pytest.fixture
def device():
    """Get test device (CUDA if available)."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def load_definition(name: str) -> dict:
    """Load a definition JSON by name."""
    for op_dir in DEFINITIONS_DIR.iterdir():
        if op_dir.is_dir():
            def_file = op_dir / f"{name}.json"
            if def_file.exists():
                with open(def_file) as f:
                    return json.load(f)
    raise FileNotFoundError(f"Definition {name} not found")


def compile_reference(reference_code: str):
    """Compile reference implementation to callable function."""
    namespace = {"torch": torch, "math": math, "np": __import__("numpy")}
    exec(reference_code, namespace)
    return namespace["run"]


def assert_close(actual, expected, rtol=1e-3, atol=1e-3):
    """Assert tensors are close within tolerance."""
    if isinstance(actual, tuple):
        for a, e in zip(actual, expected):
            assert_close(a, e, rtol, atol)
    elif isinstance(actual, dict):
        for k in actual:
            assert_close(actual[k], expected[k], rtol, atol)
    else:
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)


def compute_error_metrics(ref, gt, name="output"):
    """Compute and print detailed error metrics."""
    ref_f32 = ref.float()
    gt_f32 = gt.float()

    abs_diff = torch.abs(ref_f32 - gt_f32)
    rel_diff = abs_diff / (torch.abs(gt_f32) + 1e-8)

    print(f"\n{name} comparison:")
    print(f"  Max absolute difference: {abs_diff.max().item():.6e}")
    print(f"  Max relative difference: {rel_diff.max().item():.6e}")
    print(f"  Mean absolute difference: {abs_diff.mean().item():.6e}")
    print(f"  Mean relative difference: {rel_diff.mean().item():.6e}")

    cos_sim = torch.nn.functional.cosine_similarity(
        ref_f32.flatten(), gt_f32.flatten(), dim=0
    ).item()
    mse = torch.mean((ref_f32 - gt_f32) ** 2).item()
    print(f"  Cosine similarity: {cos_sim:.6f}")
    print(f"  MSE: {mse:.6e}")

    return abs_diff, rel_diff


def check_hit_ratio(ref, gt, atol, rtol, required_percent=0.85):
    """Check if hit ratio meets threshold (for quantized kernels)."""
    ref_f32 = ref.float()
    gt_f32 = gt.float()

    left = (ref_f32 - gt_f32).abs()
    right = atol + rtol * gt_f32.abs()
    ok = left <= right
    hit_ratio = ok.float().mean().item()

    print(f"\nHit ratio: {hit_ratio * 100:.2f}%  (need >= {required_percent * 100:.2f}%)")
    return hit_ratio >= required_percent

Implementation Steps

When executing this skill:

  1. Identify definitions to test:

    ls tmp/flashinfer-trace/definitions/{op_type}/
  2. Check for existing tests:

    ls tmp/flashinfer-trace/tests/references/
  3. For each definition:

    • Read the definition JSON
    • Identify ground truth source (FlashInfer or SGLang)
    • Generate test class with appropriate fixtures
    • Generate test methods for shape, dtype, and numerical correctness
  4. Create test file:

    # Create tests directory if needed
    mkdir -p tmp/flashinfer-trace/tests/references/
  5. Write test file:

    • If testing multiple definitions of same op_type, combine into one file
    • Each definition gets its own test class
  6. Create/update conftest.py with shared fixtures

Running Tests

After generating tests, run from the HF dataset clone:

# Run all reference tests
pytest tmp/flashinfer-trace/tests/references/ -v

# Run specific test file
pytest tmp/flashinfer-trace/tests/references/test_mla_paged.py -v

# Run with GPU
pytest tmp/flashinfer-trace/tests/references/ -v --device cuda

# Run with verbose output
pytest tmp/flashinfer-trace/tests/references/ -v -s

Error Handling

Ground Truth Not Available

  • Error: Ground truth implementation not found
  • Handling: Follow priority order:
    1. Check SGLang vanilla implementation
    2. Check SGLang FlashInfer API integration
    3. Check FlashInfer API directly
    4. If none available, mark test as skip with reason

Definition Parse Error

  • Error: Invalid definition JSON
  • Handling: Report validation errors, skip test generation

Shape Mismatch

  • Error: Reference output shape doesn’t match definition
  • Handling: Create failing test, flag for investigation

Numerical Divergence

  • Error: Reference differs from ground truth beyond tolerance
  • Handling: Create failing test with detailed diff report

Integration with Other Skills

# Complete workflow
/clone-repos

# Extract definitions
/extract-kernel-definitions --model-name deepseek_v3

# Add tests for new definitions
/add-reference-tests --op-type mla_paged
/add-reference-tests --op-type moe

# Run tests from the HF dataset clone
pytest tmp/flashinfer-trace/tests/references/ -v

Notes

  • Tests run on GPU by default; CPU fallback for CI environments
  • Tolerance varies by dtype: looser for fp16 (1e-3), stricter for fp32 (1e-5)
  • Some kernels may not have FlashInfer ground truth yet
  • Test parametrization covers common batch/sequence sizes
  • Tests marked with @pytest.mark.slow for large sizes

Maintaining This Document

Update this file when changing ground truth sources, test patterns, tolerance values, or adding new op_types.

See Also