# Piecewise CUDA Graph

## Motivation
Standard CUDA graphs capture the entire model forward pass as a single graph. This works well for decode (fixed batch size), but not for extend/prefill where the number of tokens varies across iterations.
Piecewise CUDA Graph (PCG) solves this by splitting the model’s computation graph into pieces (roughly one per layer) at “split points” (e.g., MoE dispatch ops). Each piece is captured as a separate CUDA graph for a set of pre-defined token lengths. At runtime, the input is padded to the nearest captured size, and each piece is replayed. This eliminates kernel launch overhead for prefill/extend while still supporting dynamic shapes.
Recently we enabled PCG by default, which means that the old `--enable-piecewise-cuda-graph` flag is deprecated. Use `--disable-piecewise-cuda-graph` to turn it off.
## Usage

PCG is enabled by default for supported configurations. No extra flags are needed:

```bash
python3 -m sglang.launch_server \
    --model-path meta-llama/Llama-3.1-8B-Instruct
```
### Disable PCG

```bash
python3 -m sglang.launch_server \
    --model-path meta-llama/Llama-3.1-8B-Instruct \
    --disable-piecewise-cuda-graph
```
### Custom capture sizes

```bash
python3 -m sglang.launch_server \
    --model-path meta-llama/Llama-3.1-8B-Instruct \
    --piecewise-cuda-graph-max-tokens 2048
```
## Server Args

| Argument | Default | Description |
|---|---|---|
| `--disable-piecewise-cuda-graph` | `False` | Disable PCG for extend/prefill. |
| `--enforce-piecewise-cuda-graph` | `False` | Force-enable PCG, skipping all auto-disable conditions. For testing only. |
| `--piecewise-cuda-graph-max-tokens` | auto | Maximum token count to capture. Defaults to `chunked_prefill_size` for non-MLA models and 2048 for MLA models. |
| | auto | Explicit list of token lengths to capture. Auto-generated if not set. |
| | | Compiler backend for the captured subgraphs. Choices: |
| ~~`--enable-piecewise-cuda-graph`~~ | — | Deprecated. PCG is now enabled by default. Use `--disable-piecewise-cuda-graph` instead. |
## Bug Report

PCG is enabled by default but is still in an experimental stage. Since PCG relies on `torch.compile` to trace the model's forward pass, most bugs stem from `torch.compile` tracing failures (e.g., untraceable ops, dynamic control flow, or graph breaks). If you encounter any issues related to PCG, please disable it by adding `--disable-piecewise-cuda-graph` to your launch command and report the bug at GitHub Issues. We greatly appreciate your help in improving this feature.
### For Users

If you see an error message like the following during server startup, it is a PCG bug:

```text
Piecewise CUDA Graph is enabled by default as an experimental feature.
To work around this error, add --disable-piecewise-cuda-graph to your launch command.
Please report this issue at https://github.com/sgl-project/sglang/issues/new/choose
```

To work around it, add `--disable-piecewise-cuda-graph` to your launch command. When filing a bug report, please include:
- The full error traceback
- Model name and quantization method
- Launch command with all arguments
- GPU type and driver version
### For Developers

Since PCG relies on `torch.compile` to trace the model's forward pass, newly developed CUDA kernels (both JIT kernels and sgl-kernels) are typically not compatible with `torch.compile` out of the box. The tracing will fail on untraceable operations such as JIT compilation, file I/O, or dynamic module loading inside the kernel.

To make a kernel compatible with PCG, you need to register it as a custom op using `register_custom_op` from `sglang.srt.utils.custom_op`. This wraps the kernel as an opaque node in the compiled graph so that `torch.compile` will not trace inside it.
Example usage (JIT kernel):

```python
import torch

from sglang.srt.utils.custom_op import register_custom_op

# In-place operator (no return value): mutated arguments are declared so
# torch.compile preserves the side effects of the opaque op.
@register_custom_op(mutates_args=["output_q", "output_s"])
def per_token_group_quant_8bit(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
    # kernel implementation ...
    ...
```
Example usage (operator with output):

```python
# out_shape indicates which argument has the same shape as the output
@register_custom_op(mutates_args=["x"], out_shape=0)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x.add_(y)
```
For wrapping external library functions (e.g., FlashInfer kernels), use `register_custom_op_from_extern` instead. See `python/sglang/srt/utils/custom_op.py` for full API documentation.
## How it works

### Torch compile backend
PCG uses `torch.compile` with a custom backend (`SGLangBackend`) to split and compile the model's forward pass. The flow is:

```text
model.forward wrapper
  → torch.compile(..., backend=SGLangBackend)
  → FX graph
  → split_graph() at registered split ops
  → split_gm (top-level graph that chains the pieces)
  → replace capturable submodules with CUDAPiecewiseBackend
  → runtime dispatch: eager split ops + per-piece capture/replay
```
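The split step in this flow can be illustrated with a stdlib-only toy. Here the op list, the set of split ops, and the function name `split_at` are all made up for illustration; the real `split_graph()` cuts FX graph nodes, not strings:

```python
def split_at(ops, split_ops):
    """Cut a linear sequence of ops at each split op, yielding alternating
    capturable pieces and eager split-op pieces (toy model of split_graph)."""
    pieces, current = [], []
    for op in ops:
        if op in split_ops:
            if current:
                pieces.append(("capturable", current))
            pieces.append(("eager", [op]))  # split op runs eagerly at runtime
            current = []
        else:
            current.append(op)
    if current:
        pieces.append(("capturable", current))
    return pieces

layer_ops = ["embed", "qkv_proj", "attention", "o_proj", "mlp", "attention", "lm_head"]
pieces = split_at(layer_ops, split_ops={"attention"})
# capturable and eager pieces alternate, like submod_0, submod_1, ... in split_gm
assert [kind for kind, _ in pieces] == [
    "capturable", "eager", "capturable", "eager", "capturable"
]
```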
1. **Install**: `install_torch_compiled()` replaces `model.forward` with a wrapper function. When `is_in_piecewise_cuda_graph()` returns True, the wrapper dispatches to the compiled callable; otherwise it falls back to the original forward. The first invocation through this path triggers Dynamo tracing and graph compilation; CUDA graph replay only happens after the capture phase completes.
2. **Split**: When `torch.compile` traces the model, `SGLangBackend` receives the FX graph and calls `split_graph()`. Ops listed in `CompilationConfig.split_ops` are treated as split points, so the graph is cut at each one. These split-op submodules are left to run eagerly at runtime, while the surrounding submodules are compiled and wrapped by `CUDAPiecewiseBackend`. The result is a top-level "stitching graph" (`split_gm`) with children such as `submod_0`, `submod_1`, ... interleaving capturable subgraphs and eager split-op submodules.
3. **Replace**: `PiecewiseCompileInterpreter` iterates over each capturable submodule in `split_gm`, compiles it for general (dynamic) shapes, and replaces it in-place with a `CUDAPiecewiseBackend` instance. Split-op submodules (e.g., attention, all-reduce) are left as-is and run eagerly at runtime.
4. **Dispatch**: At runtime, calling `split_gm` executes the stitching graph, which calls each submodule in order. Split-op submodules run eagerly. Each `CUDAPiecewiseBackend` submodule goes through three phases:
   1. **Compile warmup**: runs the general-shape compiled path.
   2. **Capture**: for each capture size, runs one warmup pass then records a CUDA graph.
   3. **Steady-state replay**: replays the captured CUDA graph for each forward pass.
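The three per-piece phases can be modeled with a stdlib-only toy; the class name and return strings below are illustrative, and real capture records `torch.cuda.CUDAGraph` objects rather than placeholder strings:

```python
class ToyPieceBackend:
    """Toy model of a piece's lifecycle: general-shape compiled path,
    one-time capture per pre-defined size, then steady-state replay."""

    def __init__(self, capture_sizes):
        self.capture_sizes = set(capture_sizes)
        self.graphs = {}                 # size -> recorded graph (stand-in)

    def __call__(self, num_tokens):
        if num_tokens in self.graphs:
            return "replay"              # steady state: replay recorded graph
        if num_tokens in self.capture_sizes:
            # warmup pass, then record a CUDA graph for this size (simulated)
            self.graphs[num_tokens] = f"graph[{num_tokens}]"
            return "capture"
        return "compiled"                # general-shape compiled fallback
```

The first call at a captured size records a graph; subsequent calls at that size replay it, and uncaptured sizes fall back to the compiled path.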
### Piecewise CUDA graph runner

`PiecewiseCudaGraphRunner` orchestrates the full lifecycle through three phases:

1. **Compile**: Warms up JIT kernels with a dummy forward pass, then wraps the model with `torch.compile`, triggering Dynamo tracing to split the FX graph and create `CUDAPiecewiseBackend` instances for each subgraph piece.
2. **Capture**: Iterates over capture sizes in reverse order (largest first). For each size, runs the forward pass twice (one warmup, one CUDA graph capture).
3. **Replay**: At runtime, finds the smallest captured size >= actual token count via binary search, copies inputs into static buffers with zero-padding, replays the captured CUDA graphs, and slices outputs back to the actual token count.
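The size selection and pad-then-slice logic of the replay phase can be sketched with the standard library. The capture sizes, `nearest_capture_size`, and `run_padded` below are illustrative names, not sglang's actual API, and a plain list stands in for the static GPU buffers:

```python
import bisect

CAPTURE_SIZES = [4, 8, 16, 32, 48, 64]  # illustrative, sorted ascending

def nearest_capture_size(num_tokens):
    """Smallest captured size >= num_tokens, or None for the eager fallback."""
    i = bisect.bisect_left(CAPTURE_SIZES, num_tokens)
    return CAPTURE_SIZES[i] if i < len(CAPTURE_SIZES) else None

def run_padded(tokens, replay):
    """Zero-pad to the captured size, run `replay` (standing in for a CUDA
    graph replay), then slice the output back to the actual token count."""
    size = nearest_capture_size(len(tokens))
    if size is None:
        return replay(tokens)            # too large: normal forward path
    padded = tokens + [0] * (size - len(tokens))
    out = replay(padded)
    return out[: len(tokens)]
```

For example, 5 tokens round up to the size-8 graph, and the padded tail is discarded when the output is sliced back to 5 entries.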
### Memory optimization
The memory cost of PCG comes from two parts: torch memory allocator and non-torch memory.
The torch memory allocator overhead is trivial thanks to several optimizations: a global shared memory pool is reused across all CUDA graph runners and capture sizes, capture is done in reverse order (large to small) so smaller graphs reuse memory allocated by larger ones, and output tensors of the last subgraph are stored as weak references to maximize memory reuse.
The main memory overhead comes from non-torch memory: the CUDA graph objects themselves require GPU memory to store the recorded kernel launch parameters and internal state. This overhead scales with the number of captured sizes, which is why `piecewise_cuda_graph_max_tokens` is capped conservatively by default.
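The benefit of capturing in reverse order can be shown with a deliberately simplified toy pool (this is not the real caching allocator; a capture here just needs one block at least as large as its footprint):

```python
class ToyPool:
    """Shared pool: a capture reuses an existing block if one is big enough,
    otherwise it reserves new memory."""

    def __init__(self):
        self.blocks = []   # sizes of blocks already in the pool
        self.total = 0     # total memory reserved by the pool

    def capture(self, size):
        if not any(b >= size for b in self.blocks):
            self.blocks.append(size)
            self.total += size

large_first, small_first = ToyPool(), ToyPool()
for s in sorted([64, 256, 1024], reverse=True):
    large_first.capture(s)   # the 1024 block is reserved once, then reused
for s in sorted([64, 256, 1024]):
    small_first.capture(s)   # each larger capture must reserve new memory

assert large_first.total == 1024
assert small_first.total == 64 + 256 + 1024
```

In this toy model, largest-first capture keeps the pool at the footprint of the largest graph, while smallest-first pays the sum over all sizes.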
### Shape configuration
Piecewise CUDA graph pre-captures graphs for a set of token counts. At runtime, the actual token count is rounded up to the nearest captured size (via binary search), and the corresponding graph is replayed. If the token count exceeds the largest captured size, the runtime falls back to the normal (non-graph) forward path.
The default capture schedule is auto-generated with increasing granularity:
| Token range | Step size |
|---|---|
| 4 – 32 | 4 |
| 48 – 256 | 16 |
| 288 – 512 | 32 |
| 576 – 1024 | 64 |
| 1280 – 4096 | 256 |
| 4096+ | 512 |
For the auto-generated schedule, sizes are capped at `--piecewise-cuda-graph-max-tokens`. The default cap is `chunked_prefill_size` for non-MLA models and 2048 for MLA backend models. If `--max-total-tokens` is set, the cap is further limited to not exceed it. Additionally, Llama-2 models are auto-capped at 4096 tokens as a temporary workaround.
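A sketch of how such a schedule could be generated from the table above. The range boundaries and steps are copied from the table; the function name and structure are illustrative, and sglang's real generator may differ in details:

```python
def default_capture_sizes(max_tokens):
    """Generate capture sizes: fine steps for small token counts,
    coarse steps for large ones, capped at max_tokens."""
    ranges = [
        (4, 32, 4),        # 4 - 32, step 4
        (48, 256, 16),     # 48 - 256, step 16
        (288, 512, 32),    # 288 - 512, step 32
        (576, 1024, 64),   # 576 - 1024, step 64
        (1280, 4096, 256), # 1280 - 4096, step 256
    ]
    sizes = []
    for start, end, step in ranges:
        sizes.extend(range(start, end + 1, step))
    n = 4096 + 512
    while n <= max_tokens:  # 4096+ region, step 512
        sizes.append(n)
        n += 512
    return [s for s in sizes if s <= max_tokens]

# with a cap of 64 tokens, only the fine-grained sizes survive
assert default_capture_sizes(64) == [4, 8, 12, 16, 20, 24, 28, 32, 48, 64]
```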
## Compatibility

PCG is auto-disabled in the following scenarios. We are actively working on expanding compatibility, and support for many of these is coming soon.
- Disabled model architectures (e.g., `DeepseekV32ForCausalLM`)
- Speculative decoding
- DP attention
- Pipeline parallelism (`pp_size > 1`)
- Non-CUDA hardware (AMD ROCm, Ascend NPU)
- MoE A2A backend
- LoRA
- Multimodal / VLM models
- DLLM (diffusion LLM)
- Deterministic inference
- PD disaggregation
- Expert distribution recorder / EPLB
Use `--enforce-piecewise-cuda-graph` to skip all auto-disable checks (for testing/debugging only).
## Code Reference

| File | Description |
|---|---|
| | Main runner: init, capture, replay |
| | Per-subgraph CUDA graph capture/replay |
| | Global context flags and |
| | Capture sizes, split ops, compiler config |
| | Server arguments and auto-disable logic |