We profiled Qwen3-Next-80B running on NVIDIA RTX PRO 6000 Blackwell GPUs with TensorRT-LLM. The MoE layers consume 73.6% of all GPU time during token generation. Within each MoE layer, 29% of that time is matrix multiplication. The remaining 71% is spent copying data between formats, quantizing activations, scattering inputs to experts, running separate activation kernels, and combining results.
Five of the eight pipeline stages in CUTLASS's MoE path do zero useful compute at batch size 1. They reshape data for the expert-centric GEMM libraries.
We wrote CUDA kernels that skip all of it by organizing the computation around outputs instead of experts. Our SM120 implementation reaches 162 tok/s, a 1.67x improvement over TRT-LLM's CUTLASS baseline, with no accuracy loss.
We optimized this specifically for Blackwell, and more specifically for the RTX PRO 6000.
Why RTX PRO 6000?
- Very poor existing kernel support, which means headroom for custom kernels
- Very high availability
- Very low cost - theoretical FLOPs/$ can't be beat
Cost per token
The RTX PRO 6000 costs roughly $7,000. An H100 costs $25,000+. A B200 costs $35,000+.
| GPU | Price | MoE 80B tok/s | Cost per M tokens* |
|---|---|---|---|
| RTX PRO 6000 (warp-decode) | ~$7K | 162 | ~$0.012 |
| RTX PRO 6000 (baseline) | ~$7K | 97 | ~$0.020 |
| H100 SXM (TRT-LLM) | ~$25K | ~120 | ~$0.035 |
| B200 (est.) | ~$35K | ~180 | ~$0.028 |
*$0.50/GPU-hour amortized, output tokens only, batch=1. B200 and H100 numbers from public benchmarks.
At 162 tok/s, warp decode delivers 1.67x the throughput of the CUTLASS baseline on the same $7K card. Against the H100, per-token cost is roughly 3x lower on hardware that costs under a third as much.
Where the time goes
Profiled with nsys (--cuda-graph-trace=node) across 64 decode tokens:
Inside each MoE layer (145 us × 48 layers):
The marked items total 48.1 us per layer. The GEMMs total 41.4 us. More time reshaping data than multiplying matrices.
Warp decode
At batch size 1, one token routes to 10 of 512 experts. Each expert processes one input vector. The expert-centric GEMM pipeline scatters that single vector 10 times, quantizes it to FP4, runs 10 small GEMMs with tile-based CUTLASS kernels, runs a separate SiLU kernel, then combines the outputs.
Warp decode assigns each GPU warp (32 threads) to one output element. The warp streams the weight rows it needs directly from memory, dequantizes FP4 values in registers, and writes one scalar. No scatter, no intermediate buffers, no separate activation kernel.
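A CPU sketch of this output-centric layout, in Python with illustrative shapes (not the model's real dimensions, and bf16/FP4 handling omitted), makes the contrast concrete: each "warp" owns one output element and touches only the weight row it needs.

```python
import numpy as np

def warp_decode_matvec(weight, x):
    """One 'warp' per output element: stream the needed weight row,
    dot it with the shared activation, write one scalar. No scatter,
    no intermediate buffers, no separate activation kernel."""
    out = np.empty(weight.shape[0], dtype=np.float32)
    for row in range(weight.shape[0]):   # one warp per output element
        out[row] = weight[row] @ x       # row streamed once from memory
    return out

rng = np.random.default_rng(0)
w = rng.standard_normal((512, 256)).astype(np.float32)  # illustrative shapes
x = rng.standard_normal(256).astype(np.float32)
y = warp_decode_matvec(w, x)
```

The expert-centric pipeline instead copies `x` once per expert before any math happens; here the activation is simply shared by every row's dot product.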
Gate+Up kernel
5,120 warps (10 experts × 512 neurons). Each warp:
Gate and up share a single activation load. SiLU folds into the epilogue. One kernel replaces three CUTLASS launches.
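A minimal Python model of that fusion (hypothetical shapes; the real kernel dequantizes FP4 weights in registers): gate and up read the same activation, and SiLU runs in the epilogue rather than as a separate launch.

```python
import numpy as np

def silu(v):
    return v / (1.0 + np.exp(-v))

def gate_up_fused(w_gate, w_up, x):
    """One 'warp' per (expert, neuron) pair: x is loaded once and
    shared by the gate and up dot products; SiLU is applied in the
    epilogue instead of a separate kernel."""
    g = w_gate @ x            # gate projection
    u = w_up @ x              # up projection, sharing the x load
    return silu(g) * u        # fused activation epilogue

rng = np.random.default_rng(1)
wg = rng.standard_normal((10, 512, 256)).astype(np.float32)  # illustrative
wu = rng.standard_normal((10, 512, 256)).astype(np.float32)
x = rng.standard_normal(256).astype(np.float32)
h = gate_up_fused(wg, wu, x)  # (10, 512): one value per warp
```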
Down kernel
2,048 warps (one per hidden dimension). Each warp loops over 10 experts:
The routing weight combination happens inside the accumulator. The per-expert outputs never materialize in memory.
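Sketched in Python with illustrative shapes, the down projection's cross-expert accumulation looks like this; because the routing weight folds into the accumulator, no per-expert output tensor is ever written:

```python
import numpy as np

def down_combine(w_down, h, route_w):
    """One 'warp' per hidden dimension, looping over the active
    experts. The routing weight is folded into the accumulator, so
    per-expert outputs never hit memory."""
    experts, hidden, _ = w_down.shape
    out = np.zeros(hidden, dtype=np.float32)
    for d in range(hidden):               # one warp per output element
        acc = np.float32(0.0)
        for e in range(experts):          # loop over routed experts
            acc += route_w[e] * (w_down[e, d] @ h[e])
        out[d] = acc
    return out

rng = np.random.default_rng(2)
wd = rng.standard_normal((10, 64, 32)).astype(np.float32)  # illustrative
inter = rng.standard_normal((10, 32)).astype(np.float32)   # expert intermediates
rw = rng.standard_normal(10).astype(np.float32)            # routing weights
y = down_combine(wd, inter, rw)
```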
What changes
| CUTLASS stage | us/layer | After warp decode |
|---|---|---|
| Expert scatter (copy input 10×) | 18.9 | Eliminated |
| FP4 activation quantize | 5.7 | Bypassed (bf16 input) |
| SiLU activation kernel | 13.7 | Fused in gate_up |
| TMA stride computation | 6.7 | Eliminated |
| Scale format conversion | 3.1 | Computed in registers |
| Gate+Up + Down GEMMs | 41.4 | 18.5 (warp-decode) |
| Total | 145 | ~38 |
FP4 dequantization without a lookup table
NVFP4 packs two 4-bit weights per byte. The standard approach uses a 16-entry lookup table in CUDA constant memory. When 5,120 warps access the LUT simultaneously with divergent nibble values, the constant cache serializes. This added 41 us to a 6 us kernel.
We construct the IEEE 754 float directly from the 4-bit encoding:
Three shifts, two masks, one conditional move. The compiler emits a predicated FSEL. Gate+Up kernel latency dropped from 47 us to 10.3 us.
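A Python sketch of the bit construction, assuming the standard E2M1 nibble layout (1 sign bit, 2 exponent bits, 1 mantissa bit); this mirrors the idea, not the kernel's exact instruction sequence:

```python
import struct

def fp4_e2m1_to_float(nibble: int) -> float:
    """Decode one E2M1 nibble by building the IEEE 754 float32 bit
    pattern directly -- no lookup table."""
    s = (nibble >> 3) & 1
    e = (nibble >> 1) & 3
    m = nibble & 1
    if e != 0:
        # Normal value: E2M1 bias is 1, float32 bias is 127, so the
        # float32 exponent field is 126 + e; the single mantissa bit
        # lands at float32 bit 22.
        bits = (s << 31) | ((126 + e) << 23) | (m << 22)
    else:
        # Subnormal: 0.0 or 0.5 -- the conditional-move case.
        bits = (s << 31) | ((126 << 23) if m else 0)
    return struct.unpack('<f', struct.pack('<I', bits))[0]

# Reproduces the 16-entry value table: +/-{0, 0.5, 1, 1.5, 2, 3, 4, 6}
values = [fp4_e2m1_to_float(n) for n in range(16)]
```

Because every warp computes its own value from its own nibble, there is no shared table to contend on.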
Strategies evaluated (not used)
- Async bulk memory ops: Extra setup costs outweighed any acceleration for our typical size.
- Specialized tensor core instructions: Hardware tile size restrictions make direct use unwieldy at small dimensions.
- Shared memory buffering: Experiments with explicit staging added unnecessary synchronization and no observable gain.
- Built-in FP4 decode primitives: Hardware implementations required extra conversions and ended up slower than custom math.
- Batching beyond a single token: for larger batches (16+), established grouped-GEMM libraries increasingly outperform; our approach is optimized for minimal-batch decode.
TRT-LLM CUDA graph integration
TRT-LLM captures the forward pass as a CUDA graph during warmup and replays it during inference, eliminating Python dispatch overhead. Any custom kernel must be captured correctly during graph construction and replay with identical tensor addresses.
We inject at ConfigurableMoE._forward_chunk_impl inside TRT-LLM's moe_custom_op boundary. At this level, CUDA graphs capture our kernel launches during the batch-1 decode warmup pass. For batch=1 bf16, the patch bypasses Steps 4-5-6 (quantization + CUTLASS) and runs two warp-decode kernel calls. For other batch sizes, it falls through to CUTLASS.
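The dispatch rule reduces to a sketch like the following (names are illustrative, not TRT-LLM's actual API):

```python
# Hypothetical sketch of the dispatch described above. Batch-1 bf16
# decode takes the warp-decode path; everything else falls through
# to the CUTLASS grouped-GEMM pipeline.
def select_moe_path(batch_size: int, dtype: str) -> str:
    if batch_size == 1 and dtype == "bfloat16":
        return "warp_decode"           # two fused kernel launches
    return "cutlass_grouped_gemm"      # quantize + grouped-GEMM pipeline
```

Because warmup runs a batch-1 decode pass, CUDA graph capture records the warp-decode branch, and replay hits those kernels with the captured tensor addresses.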
The extension registers via TORCH_LIBRARY with Meta dispatch stubs for torch.compile compatibility.
Getting this right took 22 integration iterations (v5 through v22). The first 21 either produced correct output without CUDA graphs or fast output with garbled text. v22 was the first to achieve both.
Results
| Configuration | tok/s | Speedup | Output |
|---|---|---|---|
| TRT-LLM baseline (CUTLASS) | 97 | 1.0x | Correct |
| Warp-decode v22 | 162 | 1.67x | Correct |
Bypassing FP4 activation quantization (bf16 activations, FP32 accumulators throughout) means our output is closer to FP32 ground truth than the CUTLASS FP4×FP4 path. Cursor observed the same effect: 1.4x closer to 32-bit reference.
The speedup comes from the warp-per-element decode kernel eliminating the scatter/gather/padding overhead of grouped GEMM at small batch sizes. Grouped GEMM must reshape inputs into expert-shaped tiles regardless of how many tokens route to each expert. At batch=1, that reshaping dominates. At batch=8 the advantage narrows as expected: grouped GEMM amortizes the reshape cost across more rows, and CUTLASS tile utilization improves with larger M dimensions.
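As a rough illustration, treating the ~48.1 us of per-layer reshape overhead measured above as approximately batch-independent, its per-token share falls as 1/M:

```python
# Rough model: the reshape overhead is paid once per layer per step,
# so its per-token share shrinks as batch size M grows. 48.1 us is
# the per-layer figure measured above; assuming it stays flat with
# batch size is an approximation.
reshape_us = 48.1
per_token = {m: reshape_us / m for m in (1, 2, 4, 8, 16)}
for m, us in per_token.items():
    print(f"batch={m:2d}: reshape overhead per token ~ {us:.1f} us")
```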
Next
The router GEMV is 34.1 us/layer (26% of remaining decode time). We built a warp-decode bf16 GEMV at 4.1 us but have not integrated it into production. Adding it projects to ~199 tok/s, a further 23% improvement.
Here at Morph, we're building specialized models and specialized inference engines for each one. If you're
All measurements on NVIDIA RTX PRO 6000 Blackwell Server Edition (SM120, 96 GB GDDR7, ~1.5 TB/s measured bandwidth). Model: Qwen3-Next-80B-A3B-Instruct-NVFP4 on TensorRT-LLM 1.3.0rc10.
