# Breakable CUDA Graph

## Motivation

Standard CUDA graphs capture an entire forward pass as a single, opaque graph. This is great for performance, but creates two problems:

1. **Debugging is hard.** When something goes wrong inside a captured graph (wrong outputs, numerical mismatches, crashes), there is no way to step through the operations or insert print statements because the graph replays as a monolithic unit.

2. **Some ops are incompatible.** Certain operations — dynamic control flow, host-device synchronization, JIT compilation, or ops that change behavior across iterations — cannot be captured into a CUDA graph at all. Today, the only workaround is to disable CUDA graphs entirely, which sacrifices the kernel launch overhead savings for the rest of the model.

**Breakable CUDA Graph** solves both problems by allowing graph breaks to be inserted at specific points. The computation is split into multiple captured graph segments with eager (non-graph) execution in between. This preserves most of the CUDA graph performance benefit while allowing targeted operations to run outside the graph.

## Usage

### Debug Mode: Run Everything Eagerly

The simplest use case is debugging. The `--debug-cuda-graph` flag wraps the entire decode forward pass in a graph break, so every operation runs eagerly while still going through the full CUDA graph capture/replay code path. This lets you debug CUDA graph issues without changing model code.

```bash
python -m sglang.launch_server \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --debug-cuda-graph
```

This mode is intended for debugging only — it eliminates the performance benefit of CUDA graphs since every op runs eagerly.

### Selective Graph Breaks in Model Code

For production use, you can mark specific functions as "non-graphable" using the `@eager_on_graph` decorator. During CUDA graph capture, these functions run eagerly between captured graph segments. Outside of capture, they behave normally.

```python
from sglang.srt.model_executor.breakable_cuda_graph.breakable_cuda_graph import eager_on_graph

@eager_on_graph(enable=True)
def my_dynamic_op(x):
    # This op is incompatible with CUDA graph capture
    return some_dynamic_operation(x)
```

You can also insert a bare graph break (no computation) using the `break_graph()` helper:

```python
from sglang.srt.model_executor.breakable_cuda_graph.breakable_cuda_graph import break_graph

def forward(self, x):
    x = self.layer1(x)
    break_graph()  # force a segment split here
    x = self.layer2(x)
    return x
```

To enable breakable CUDA graph at the environment level (without debug mode), set the environment variable:

```bash
export SGLANG_USE_BREAKABLE_CUDA_GRAPH=1
python -m sglang.launch_server \
    --model meta-llama/Llama-3.1-8B-Instruct
```

### Server Args

| Argument | Default | Description |
|---|---|---|
| `--debug-cuda-graph` | `False` | Enable debug/eager mode. Wraps the entire forward pass in a graph break so every op runs eagerly through the capture/replay path. |
| `SGLANG_USE_BREAKABLE_CUDA_GRAPH` | `0` | Environment variable. Enables breakable CUDA graph without debug mode. Required for `@eager_on_graph` decorators to take effect. |

## How It Works

### Capture

Breakable CUDA graph extends PyTorch's `torch.cuda.CUDAGraph` by splitting a single capture into multiple segments separated by graph breaks.

During capture, the flow is:

```
Begin capture (segment 1)
  ... graphable ops ...
  @eager_on_graph function encountered:
    1. End current capture segment
    2. Run the function eagerly (allocates output tensors)
    3. Record the function for later replay
    4. Begin new capture segment
  ... more graphable ops ...
End capture (segment N)
```

Each segment is independently instantiated as a CUDA graph executable. The non-graph functions and their argument references are stored for replay.
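The capture flow above can be sketched in pure Python, with plain lists standing in for captured graph segments and recorded functions; the class and method names here are illustrative, not the actual SGLang API:

```python
class BreakableCaptureSketch:
    def __init__(self):
        self.segments = [[]]   # each inner list stands in for one captured graph segment
        self.eager_fns = []    # non-graph functions recorded for replay

    def record_op(self, op):
        # Graphable op: captured into the currently open segment.
        self.segments[-1].append(op)

    def graph_break(self, fn, *args):
        # Steps 1-4 from the flow above:
        # 1. the open segment is ended (nothing more is appended to it);
        # 2. run the function eagerly so its output tensors get allocated;
        out = fn(*args)
        # 3. record the function for later replay;
        self.eager_fns.append(fn)
        # 4. begin a new capture segment.
        self.segments.append([])
        return out

cap = BreakableCaptureSketch()
cap.record_op("matmul")
cap.record_op("softmax")
cap.graph_break(lambda: "dynamic_op")
cap.record_op("matmul")
```

After this toy capture, `cap.segments` holds two segments with one recorded eager function between them, mirroring the segment/break structure described above.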

### Replay

During replay:

```
For each segment i = 1..N-1:
  1. Launch CUDA graph segment i
  2. Run the recorded non-graph function i eagerly
Launch final CUDA graph segment N
```

The non-graph functions are re-invoked with the same tensor references as at capture time. Because these references point to the CUDA graph's static input/output buffers, they see updated values on each replay.
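The replay interleaving can be sketched the same way; `replay` and its event log are illustrative stand-ins, with tuples in place of real `cudaGraphLaunch` calls:

```python
def replay(segments, eager_fns):
    """Interleave graph launches with recorded eager functions (illustrative)."""
    events = []
    # N segments are separated by N-1 recorded non-graph functions.
    for graph, fn in zip(segments[:-1], eager_fns):
        events.append(("launch", graph))  # stands in for launching segment i
        events.append(("eager", fn()))    # re-invoke the recorded function i
    events.append(("launch", segments[-1]))  # final captured segment
    return events

log = replay(["seg1", "seg2"], [lambda: "dyn_out"])
```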

### Output Writeback

When a non-graph function produces output during replay, the result must be written back into the same tensor buffers that downstream graph segments reference. The mechanism handles:

- **Plain tensors**: In-place `copy_()` into the original buffer.
- **Structured outputs** (dataclasses, objects with tensor attributes): Tensor fields are copied in-place; non-tensor fields are replaced.
- **Dicts of tensors**: Tensor values are copied in-place; non-tensor values are replaced.
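A minimal pure-Python sketch of these writeback rules, assuming a stand-in `FakeTensor` whose `copy_` mimics `torch.Tensor.copy_`; the dispatch mirrors the three cases above, but the names are illustrative:

```python
class FakeTensor:
    """Minimal stand-in for a CUDA tensor; copy_ mimics torch.Tensor.copy_."""
    def __init__(self, data):
        self.data = list(data)
    def copy_(self, other):
        self.data[:] = other.data  # overwrite contents, keep buffer identity

def write_back(dst, src):
    if isinstance(dst, FakeTensor):
        dst.copy_(src)                     # plain tensor: in-place copy
    elif isinstance(dst, dict):
        for k, v in src.items():
            if isinstance(dst.get(k), FakeTensor):
                dst[k].copy_(v)            # tensor values copied in-place
            else:
                dst[k] = v                 # non-tensor values replaced
    else:
        for name, v in vars(src).items():  # structured object with attributes
            cur = getattr(dst, name, None)
            if isinstance(cur, FakeTensor):
                cur.copy_(v)               # tensor attributes copied in-place
            else:
                setattr(dst, name, v)      # non-tensor attributes replaced
    return dst

buf = FakeTensor([0, 0])
write_back(buf, FakeTensor([1, 2]))
```

The key property is that `buf` keeps its identity after writeback, so downstream graph segments that captured a reference to it still see the new values.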

### Stream Fork/Join Tracking

Some models fork work onto secondary CUDA streams (e.g., for overlapped computation). Breakable CUDA graph hooks `torch.cuda.Stream.wait_stream` to track which streams are forked from the capture stream. When a graph break occurs, all forked streams are automatically joined back before ending the capture segment, and re-forked after beginning the next segment.
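The fork/join bookkeeping can be sketched as follows; the real code hooks `torch.cuda.Stream.wait_stream`, while this stand-in `StreamTracker` (an illustrative name, not the SGLang API) just records the same information:

```python
class StreamTracker:
    """Tracks side streams forked from the capture stream (illustrative)."""
    def __init__(self, capture_stream):
        self.capture_stream = capture_stream
        self.forked = []

    def on_wait_stream(self, waiter, waited_on):
        # Hooked wait_stream: a side stream waiting on the capture stream
        # is the fork pattern, so remember the side stream.
        if waited_on is self.capture_stream and waiter not in self.forked:
            self.forked.append(waiter)

    def join_for_break(self):
        # At a graph break, every forked stream must be joined back before
        # the capture segment ends; return them so the caller can re-fork
        # after the next segment begins capture.
        to_refork, self.forked = self.forked, []
        return to_refork

tracker = StreamTracker("capture")
tracker.on_wait_stream("side", "capture")  # side stream forks off
```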

## Compatibility

- **NVIDIA CUDA only.** Breakable CUDA graph is not supported on ROCm/HIP or other non-CUDA platforms. On unsupported platforms, `--debug-cuda-graph` is automatically disabled with a warning.
- **Requires `cuda-python`.** The `cuda.bindings` package must be installed (`pip install cuda-python`).
- **Not compatible with memory saver mode.** Cannot be used together with `SGLANG_MEMORY_SAVER_CUDA_GRAPH`.
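A hedged sketch of how a guard for these requirements might look; `breakable_cuda_graph_available` is a hypothetical helper written for illustration, not the actual SGLang check:

```python
def breakable_cuda_graph_available(is_cuda_platform, warn=print):
    """Hypothetical guard: True only when platform and dependencies allow it."""
    if not is_cuda_platform:
        # e.g. ROCm/HIP: fall back with a warning rather than fail.
        warn("breakable CUDA graph requires NVIDIA CUDA; disabling")
        return False
    try:
        import cuda.bindings  # noqa: F401  (from `pip install cuda-python`)
    except ImportError:
        warn("cuda-python (cuda.bindings) is not installed; disabling")
        return False
    return True
```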

## Performance

When no graph breaks are inserted, breakable CUDA graph has minimal overhead compared to standard CUDA graph — the capture/replay path is nearly identical.

Each graph break adds:
- One `cudaGraphLaunch` call (to replay the segment before the break)
- One eager Python function call
- One `cudaStreamBeginCapture` / `cudaStreamEndCapture` pair during capture

For typical use cases with a small number of graph breaks, the overhead is negligible compared to the saved kernel launch overhead from the captured segments.

## Code Reference

| File | Description |
|---|---|
| `python/sglang/srt/model_executor/breakable_cuda_graph/breakable_cuda_graph.py` | Core implementation: `eager_on_graph`, `BreakableCUDAGraph`, `BreakableCUDAGraphCapture` |
| `python/sglang/srt/model_executor/breakable_cuda_graph/cuda_utils.py` | CUDA runtime binding utilities |
| `python/sglang/srt/model_executor/cuda_graph_runner.py` | Integration with the main CUDA graph runner |
| `python/sglang/srt/server_args.py` | `--debug-cuda-graph` flag and environment variable handling |
| `python/sglang/srt/environ.py` | `SGLANG_USE_BREAKABLE_CUDA_GRAPH` environment variable definition |
