Benchmark
SGLang provides four benchmark tools that operate at different levels of the stack. The table below summarizes their key differences:

| Tool | HTTP Server | Scheduler | Use Case |
|---|---|---|---|
| `bench_serving` | Yes (async HTTP client to a running server) | Yes (indirectly, via server) | Realistic online serving benchmarks with latency metrics (TTFT, TPOT, ITL) |
| `bench_one_batch_server` | Yes (sends HTTP requests to a running server) | Yes (indirectly, via server) | End-to-end single-batch latency including HTTP and scheduler overhead |
| `bench_offline_throughput` | No | Yes (directly uses Engine in-process) | Maximum throughput measurement without HTTP overhead |
| `bench_one_batch` | No | No (directly calls ModelRunner) | Kernel-level latency profiling of a single static batch |
Use `bench_serving` by default unless you have specific needs.
`bench_serving` is an async HTTP load-testing client that sends requests at controlled rates with configurable concurrency to a running server. It measures realistic online serving metrics including time-to-first-token (TTFT), time-per-output-token (TPOT), inter-token latency (ITL), and throughput. Use `--num-prompts >= 5 * --max-concurrency` to measure steady-state performance. Launch a server with `sglang.launch_server` first.
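A typical run might look like the following (a sketch; the model path and request counts are illustrative, and flag availability may vary across SGLang versions):

```shell
# start a server first (illustrative model path)
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct

# then, in another terminal, run the benchmark client
python3 -m sglang.bench_serving --backend sglang \
  --num-prompts 500 --max-concurrency 100
```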
`bench_one_batch_server` sends a single batch as one HTTP request to a running server. Because only a single batch is sent, the server never reaches a steady state, so the metrics will be biased. Launch a server with `sglang.launch_server` first.
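A sketch of a typical invocation, assuming a server is already listening on the default port (the batch and length flags are illustrative and their exact names may differ by version):

```shell
python3 -m sglang.bench_one_batch_server \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --base-url http://localhost:30000 \
  --batch-size 32 --input-len 256 --output-len 32
```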
bench_offline_throughput directly instantiates the Engine object in-process (no HTTP server) and submits all requests at once via engine.generate(). The engine’s scheduler handles batching and execution. This measures maximum achievable throughput without any network overhead.
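For example (a minimal sketch; the model path is illustrative and no server needs to be running):

```shell
python3 -m sglang.bench_offline_throughput \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --num-prompts 10
```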
bench_one_batch is the lowest-level tool. It directly instantiates a ModelRunner and calls extend() / decode() on a fixed static batch, bypassing the scheduler entirely. The prefill and decode phases are run separately, making profiling easier but rendering the metrics unrealistic. Because there is no dynamic batching, it may run out of memory for batch sizes that a real server can handle (a real server chunks prefill into smaller batches). This is best suited for profiling individual kernel performance.
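For example (a sketch with an illustrative model path; keep the batch size small enough to fit in memory, since there is no chunked prefill here):

```shell
python3 -m sglang.bench_one_batch \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --batch-size 32 --input-len 256 --output-len 32
```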
Profile with PyTorch Profiler
PyTorch Profiler is a convenient basic tool to inspect kernel execution time, call stack, kernel overlap, and occupancy.

Profile a server with sglang.bench_serving
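A sketch of the workflow, assuming an illustrative model path and output directory:

```shell
# server side: set the trace output directory before launching
export SGLANG_TORCH_PROFILER_DIR=/tmp/sglang_profile
python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct

# client side: request profiling around the benchmark run
python3 -m sglang.bench_serving --backend sglang --num-prompts 10 --profile
```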
When `bench_serving` is run with `--profile`, the output directory is selected on the client side from `--profile-output-dir` or `SGLANG_TORCH_PROFILER_DIR` (fallback: `/tmp`), then sent in the `/start_profile` request.
If you call `/start_profile` directly and do not provide `output_dir`, the server uses its own `SGLANG_TORCH_PROFILER_DIR` (fallback: `/tmp`).
Setting `SGLANG_TORCH_PROFILER_DIR` on both server and client is still recommended to avoid confusion about where traces are written.
For more details, please refer to Bench Serving Guide.
Profile In PD Disaggregation Mode
When profiling in PD disaggregation mode, prefill and decode workers must be profiled separately due to torch profiler limitations. The `bench_serving` command provides dedicated options for this:
Profile Prefill Workers
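A sketch of profiling prefill workers (the host name is illustrative; exact flag pairing with `--profile` may vary by version):

```shell
python3 -m sglang.bench_serving --backend sglang --num-prompts 100 \
  --profile --profile-prefill-url http://prefill-host:30000
```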
Profile Decode Workers
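Analogously for decode workers (illustrative host name):

```shell
python3 -m sglang.bench_serving --backend sglang --num-prompts 100 \
  --profile --profile-decode-url http://decode-host:30000
```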
Important Notes
- `--profile-prefill-url` and `--profile-decode-url` are mutually exclusive; you cannot profile both at the same time.
- Both options support multiple worker URLs for multi-instance setups.
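For instance, passing several prefill worker URLs at once (a sketch with illustrative host names):

```shell
python3 -m sglang.bench_serving --backend sglang --num-prompts 100 \
  --profile --profile-prefill-url http://prefill-1:30000 http://prefill-2:30000
```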
- Make sure `SGLANG_TORCH_PROFILER_DIR` is set on all worker nodes before starting the servers.
- For more details on setting up PD disaggregation, see the PD Disaggregation Guide.
Profile a server with sglang.bench_offline_throughput
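For example (a sketch; model path and output directory are illustrative):

```shell
export SGLANG_TORCH_PROFILER_DIR=/tmp/sglang_profile
python3 -m sglang.bench_offline_throughput \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --num-prompts 10 --profile
```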
Profile a server with sglang.profiler
When the server is running (e.g., processing a decoding request), you can start live profiling immediately by sending a profile request to the server.
You can do this by running python3 -m sglang.profiler. For example:
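A minimal invocation, assuming the server is running at the default `http://localhost:30000`:

```shell
python3 -m sglang.profiler
```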
Profile a server with HTTP API endpoints
SGLang provides HTTP API endpoints to control profiling on a running server. This allows you to start and stop profiling programmatically, which is useful for capturing specific workload patterns.

Using /start_profile endpoint
The /start_profile endpoint starts profiling on the server. You can control when profiling begins and how long it runs using the following parameters:
Basic usage:
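The simplest call, assuming the server is listening on the default port:

```shell
curl -X POST http://localhost:30000/start_profile
```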
- `output_dir` (optional): Directory where profile traces will be saved. If not specified, uses the `SGLANG_TORCH_PROFILER_DIR` environment variable, or `/tmp` as the default.
- `num_steps` (optional): Number of steps to profile. If not specified, profiling continues until manually stopped with `/stop_profile`.
- `start_step` (optional): Step number at which to start profiling (inclusive). Useful for skipping warmup iterations.
- `activities` (optional): List of activities to profile, e.g., `["CPU", "GPU"]`. Default is `["CPU", "GPU"]`.
- `merge_profiles` (optional): Whether to merge distributed traces. Default is `false`.
Profiling begins at `start_step` (inclusive) and continues for `num_steps` iterations. For example, with `start_step=3` and `num_steps=10`, profiling captures steps 3 through 12 (10 steps total, starting from step 3).
Advanced usage with start_step:
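A sketch combining the parameters above (default port and an illustrative output directory):

```shell
curl -X POST http://localhost:30000/start_profile \
  -H "Content-Type: application/json" \
  -d '{"output_dir": "/tmp/sglang_profile", "start_step": 3, "num_steps": 10}'
```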
Using /stop_profile endpoint
The /stop_profile endpoint stops an ongoing profiling session and saves the trace file.
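Assuming the default port:

```shell
curl -X POST http://localhost:30000/stop_profile
```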
This is only needed if profiling was started without `num_steps`. If `num_steps` is specified, profiling will automatically stop after that many steps.
Example workflow
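A sketch of the full start/load/stop sequence, assuming the default port and an illustrative output directory:

```shell
# 1. start profiling
curl -X POST http://localhost:30000/start_profile \
  -H "Content-Type: application/json" \
  -d '{"output_dir": "/tmp/sglang_profile"}'

# 2. send the workload you want to capture
python3 -m sglang.bench_serving --backend sglang --num-prompts 10

# 3. stop profiling and write the trace
curl -X POST http://localhost:30000/stop_profile
```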
Profiler Trace Merger for Distributed Traces
SGLang supports automatic merging of profiling traces from distributed setups with multiple parallelism types (TP, DP, PP, EP). This feature is particularly useful for analyzing performance across distributed runs.

Multi-Node Profiling and Shared Storage Considerations
Merging profiler output on a single node is fully supported. When profiling in distributed environments spanning multiple nodes, shared storage (e.g., NFS, Lustre) accessible by all nodes must be used for the output directory to enable merging of trace files. If no shared storage is accessible across nodes, automatic merging of trace files during profiling is currently not supported.

HTTP API Usage
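To request merged traces via the API, pass `merge_profiles` (the output path below is an illustrative shared-storage location):

```shell
curl -X POST http://localhost:30000/start_profile \
  -H "Content-Type: application/json" \
  -d '{"merge_profiles": true, "output_dir": "/shared/profiles"}'
```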
Command Line Usage
Command
Output Files
The profile merger generates:
- Individual rank trace files: `{profile_id}-TP-{tp}-DP-{dp}-PP-{pp}-EP-{ep}.trace.json.gz`
- Merged trace file: `merged-{profile_id}.trace.json.gz`
Possible PyTorch bugs
In some cases (for example, when using Qwen2.5-VL) the profiler can raise an error; if so, you can disable `with_stack` with an environment variable.
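A sketch of the workaround; the `SGLANG_PROFILE_WITH_STACK` variable name is an assumption and should be checked against your SGLang version:

```shell
# assumption: this variable disables the profiler's with_stack option
export SGLANG_PROFILE_WITH_STACK=False
```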
View traces
Trace files can be loaded and visualized from:
- https://ui.perfetto.dev/ (any browser)
- chrome://tracing (Chrome browser only)
If the trace file is too large for the browser, reduce the number of requests with the `--num-prompts` argument and limit output sequences to 100 tokens with the `--sharegpt-output-len` argument; this produces a small trace file that the browser can open smoothly.
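For example (a sketch against a running server):

```shell
python3 -m sglang.bench_serving --backend sglang \
  --num-prompts 10 --sharegpt-output-len 100 --profile
```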
Additionally, if you want to locate the SGLang Python source code from a CUDA kernel in the trace, you need to disable CUDA Graph when starting the server by adding the `--disable-cuda-graph` flag to the launch command.
Profile with Nsight
Nsight Systems is an advanced tool that exposes more profiling details, such as register and shared-memory usage, annotated code regions, and low-level CUDA APIs and events.
- Prerequisite: install Nsight Systems using apt, or run inside an NVIDIA Docker container or SGLang Docker container.
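A sketch of the apt route, assuming an Ubuntu-based image where the `nsight-systems-cli` package provides the `nsys` binary:

```shell
apt update
apt install -y nsight-systems-cli
```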
- To profile a single batch, run `bench_one_batch` under `nsys profile`.
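For example (a sketch; the model path and batch flags are illustrative):

```shell
nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node \
  python3 -m sglang.bench_one_batch \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --batch-size 32 --input-len 256 --output-len 32
```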
- To profile a server, launch it under `nsys profile` and then drive it with a benchmark client. In practice, we recommend setting the `--duration` argument to a large value. Whenever you want the server to stop profiling, first run `nsys sessions list` to get the session id (in the form `profile-XXXXX`), then run `nsys stop --session=profile-XXXXX` to stop the profiler and generate the `nsys-rep` files immediately.
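The steps above can be sketched as follows (model path, delay, and duration are illustrative):

```shell
# launch the server under nsys; tune --delay/--duration to your workload
nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node \
  -o sglang.out --delay 60 --duration 70 \
  python3 -m sglang.launch_server \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --disable-cuda-graph

# in another terminal, generate load
python3 -m sglang.bench_serving --backend sglang --num-prompts 1000

# stop profiling manually
nsys sessions list                  # note the profile-XXXXX session id
nsys stop --session=profile-XXXXX   # writes the nsys-rep file immediately
```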
- Use NVTX to annotate code regions, e.g., to see their execution time.
Layer-wise NVTX Profiling with Nsight Systems
SGLang provides built-in layerwise NVTX annotations that can be combined with the CUDA Profiler for detailed per-layer profiling in Nsight Systems. This is particularly useful for identifying performance bottlenecks at the layer level.

Using --enable-layerwise-nvtx-marker with Nsight Systems and /start_profile
The --enable-layerwise-nvtx-marker flag automatically adds NVTX markers to every layer in your model. This is particularly powerful when combined with Nsight Systems profiling to see detailed per-layer performance.
Method 1: Using /start_profile with CUDA_PROFILER (for programmatic control)
This method allows you to control exactly when profiling starts/stops via HTTP API while Nsight Systems is running.
- Launch the server with layerwise NVTX enabled under Nsight Systems. Note: NVTX markers are not emitted for kernel launches captured by CUDA graphs; use `--disable-cuda-graph` to ensure all layerwise NVTX markers appear in the trace.
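A sketch of such a launch (the model path and output name are illustrative):

```shell
nsys profile --capture-range=cudaProfilerApi --capture-range-end=stop \
  -o layerwise_profile \
  python3 -m sglang.launch_server \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --enable-layerwise-nvtx-marker --disable-cuda-graph
```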
- In another terminal, control profiling via `/start_profile` with the `CUDA_PROFILER` activity.
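For example (default port; the step values match the discussion below):

```shell
curl -X POST http://localhost:30000/start_profile \
  -H "Content-Type: application/json" \
  -d '{"activities": ["CUDA_PROFILER"], "start_step": 3, "num_steps": 10}'
```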
- Send requests to generate load.
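Any client works here; one option is the benchmark client (a sketch against a running server):

```shell
python3 -m sglang.bench_serving --backend sglang --num-prompts 100
```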
- Profiling will automatically stop after 10 steps (due to `num_steps: 10`). If you had not specified `num_steps`, you would need to stop it manually with `/stop_profile`.
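The manual stop is a single call (default port):

```shell
curl -X POST http://localhost:30000/stop_profile
```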
The `--capture-range=cudaProfilerApi` option tells Nsight Systems to capture data only between the `cudaProfilerStart()` and `cudaProfilerStop()` calls (triggered by `/start_profile` and `/stop_profile`), reducing overhead and file size. The `start_step` parameter skips the first 3 steps to avoid capturing warmup overhead.
Method 2: Simpler approach without /start_profile API
For simpler use cases where you don’t need fine-grained control over profiling start/stop, you can profile with Nsight Systems capturing the entire workload:
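A sketch of this simpler mode (model path and output name are illustrative):

```shell
nsys profile -o full_profile \
  python3 -m sglang.launch_server \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --enable-layerwise-nvtx-marker --disable-cuda-graph
```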
Open the resulting `.qdrep` file with Nsight Systems.
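For example, assuming the Nsight Systems GUI (`nsys-ui`) is installed on the machine:

```shell
nsys-ui full_profile.qdrep
```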
In the resulting trace you will see:
- NVTX ranges: Each layer appears as a labeled range in the timeline with detailed information in the marker metadata
- CUDA kernels: All GPU kernels are shown alongside the layer annotations
- Layer hierarchy: The full module path (e.g., `meta-llama/Meta-Llama-3.1-8B-Instruct.model.layers.0.self_attn.qkv_proj`) helps identify specific layers. The prefix uses the full model path from `--model-path`.
- Tensor shapes: Input/output dimensions and parameter shapes are included in the NVTX marker data
Benefits of layerwise NVTX profiling:
- Granular visibility: See exactly which layers are taking the most time
- Memory tracking: Identify layers with large memory allocations
- Bottleneck identification: Quickly locate inefficient operations
- Communication overhead: In multi-GPU setups, see per-layer communication costs
- Development debugging: Validate that model architecture changes have the expected performance impact
Other tips
- You can benchmark a model using dummy weights by providing only the `config.json` file. This allows for quick testing of model variants without training. To do so, add `--load-format dummy` to the above commands; then you only need a correct `config.json` under the checkpoint folder.
- You can benchmark a model with modified configs (e.g., fewer layers) by using `--json-model-override-args`. For example, you can benchmark a model with only 2 layers and 2 KV heads.
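A sketch of such a run (model path and batch flags are illustrative; the override keys follow the Hugging Face config schema):

```shell
python3 -m sglang.bench_one_batch \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --batch-size 32 --input-len 256 --output-len 32 \
  --json-model-override-args '{"num_hidden_layers": 2, "num_key_value_heads": 2}'
```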
- You can use `--python-backtrace=cuda` to see the Python call stack for all CUDA kernels, as in PyTorch Profiler. (Caveat: this can cause inaccurately long kernel runtimes for CUDA-event-based timing.)
- For more arguments, see the Nsight Systems User Guide.
