Speculative Decoding#

SGLang now provides an EAGLE-based (EAGLE-2/EAGLE-3) speculative decoding option. Our implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.

Performance Highlights#

The table below shows the substantial throughput improvements that EAGLE-3 decoding achieves for Llama 3.1 8B Instruct on MT-Bench. For further details, please see the EAGLE-3 paper.

| Method | Throughput (tokens/s) |
| --- | --- |
| SGLang (w/o speculative, 1x H100) | 158.34 |
| SGLang + EAGLE-2 (1x H100) | 244.10 |
| SGLang + EAGLE-3 (1x H100) | 373.25 |

EAGLE Decoding#

To enable EAGLE speculative decoding, the following parameters are relevant:

  • speculative_draft_model_path: Specifies the draft model. This parameter is required.

  • speculative_num_steps: Depth of autoregressive drafting. Increases the speculation range but risks rejection cascades. Default is 5.

  • speculative_eagle_topk: Branching factor per step. A larger value improves candidate diversity and acceptance rate, but also leads to higher memory/compute consumption. Default is 4.

  • speculative_num_draft_tokens: Maximum parallel verification capacity. Allows deeper tree evaluation but leads to higher GPU memory usage. Default is 8.

These parameters are the same for EAGLE-2 and EAGLE-3.

You can find the best combinations of these parameters with bench_speculative.py.

In the documentation below, we set --cuda-graph-max-bs to a small value for faster engine startup. For your own workloads, please tune the above parameters together with --cuda-graph-max-bs, --max-running-requests, and --mem-fraction-static for the best performance.
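If you want a quick manual sweep in addition to bench_speculative.py, a rough sketch like the one below can help. It reuses the launch/terminate helpers used throughout this page; the candidate values and the single-request timing are illustrative only, not a rigorous benchmark.

```python
import time

import openai
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import terminate_process, wait_for_server

# Illustrative (steps, topk, num_draft_tokens) candidates; tune for your model/hardware.
candidates = [(3, 4, 16), (5, 8, 32), (5, 8, 64)]

for steps, topk, draft_tokens in candidates:
    server_process, port = launch_server_cmd(
        f"""
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps {steps} \
    --speculative-eagle-topk {topk} --speculative-num-draft-tokens {draft_tokens} \
    --cuda-graph-max-bs 8 --log-level warning
"""
    )
    wait_for_server(f"http://localhost:{port}")

    # Time a single generation and report a rough tokens/s figure.
    client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")
    start = time.time()
    response = client.chat.completions.create(
        model="meta-llama/Llama-2-7b-chat-hf",
        messages=[{"role": "user", "content": "Write a short story about a robot."}],
        temperature=0,
        max_tokens=256,
    )
    elapsed = time.time() - start
    tokens = response.usage.completion_tokens
    print(f"steps={steps} topk={topk} draft={draft_tokens}: {tokens / elapsed:.1f} tokens/s")

    terminate_process(server_process)
```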

EAGLE-2 decoding#

You can enable EAGLE-2 decoding by setting --speculative-algorithm EAGLE and choosing an appropriate model.

[1]:
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process

import openai
[2026-01-09 04:07:23] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:07:23] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:07:23] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf  --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 \
    --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --log-level warning
"""
)

wait_for_server(f"http://localhost:{port}")
[2026-01-09 04:07:31] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:07:31] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:07:31] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2026-01-09 04:07:34] INFO server_args.py:1616: Attention backend not specified. Use flashinfer backend by default.
[2026-01-09 04:07:34] WARNING server_args.py:2078: Overlap scheduler is disabled when spec v2 is off or using unsupported speculative algorithm. You can set env SGLANG_ENABLE_SPEC_V2=True to enable the experimental overlap scheduler.
[2026-01-09 04:07:34] INFO server_args.py:2513: Set soft_watchdog_timeout since in CI
[2026-01-09 04:07:45] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:07:45] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:07:45] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2026-01-09 04:07:45] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:07:45] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:07:45] INFO utils.py:164: NumExpr defaulting to 16 threads.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2026-01-09 04:07:50] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/usr/local/lib/python3.10/dist-packages/transformers/__init__.py)
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.86s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.32s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.41s/it]

Capturing batches (bs=1 avail_mem=23.62 GB): 100%|██████████| 4/4 [00:00<00:00,  6.98it/s]
[2026-01-09 04:07:57] SPECULATIVE_MOE_RUNNER_BACKEND is not initialized, using auto backend
[2026-01-09 04:07:57] SPECULATIVE_MOE_A2A_BACKEND is not initialized, using none backend
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.44s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.44s/it]

Capturing batches (bs=1 avail_mem=22.44 GB): 100%|██████████| 4/4 [00:09<00:00,  2.37s/it]
Capturing batches (bs=1 avail_mem=15.95 GB): 100%|██████████| 4/4 [00:00<00:00, 19.30it/s]


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server, the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
[3]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
Response: ChatCompletion(id='f82fcbe9349343a18d6f1652be0840a8', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=' Sure! Here are three countries and their capitals:\n\n1. Country: France\nCapital: Paris\n2. Country: Japan\nCapital: Tokyo\n3. Country: Brazil\nCapital: Brasília', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=2)], created=1767931697, model='meta-llama/Llama-2-7b-chat-hf', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=48, prompt_tokens=17, total_tokens=65, completion_tokens_details=None, prompt_tokens_details=None, reasoning_tokens=0), metadata={'weight_version': 'default'})
[4]:
terminate_process(server_process)

EAGLE-2 Decoding with torch.compile#

You can also enable torch.compile for further optimizations and optionally set --torch-compile-max-bs:

[5]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf  --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \
        --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \
            --enable-torch-compile --torch-compile-max-bs 2 --log-level warning
"""
)

wait_for_server(f"http://localhost:{port}")
[2026-01-09 04:08:22] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:08:22] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:08:22] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2026-01-09 04:08:25] INFO server_args.py:1616: Attention backend not specified. Use flashinfer backend by default.
[2026-01-09 04:08:25] WARNING server_args.py:2078: Overlap scheduler is disabled when spec v2 is off or using unsupported speculative algorithm. You can set env SGLANG_ENABLE_SPEC_V2=True to enable the experimental overlap scheduler.
[2026-01-09 04:08:25] INFO server_args.py:2513: Set soft_watchdog_timeout since in CI
[2026-01-09 04:08:31] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:08:31] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:08:31] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2026-01-09 04:08:32] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:08:32] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:08:32] INFO utils.py:164: NumExpr defaulting to 16 threads.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2026-01-09 04:08:38] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/usr/local/lib/python3.10/dist-packages/transformers/__init__.py)
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.70s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.22s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.29s/it]

[2026-01-09 04:08:41] max_total_tokens=20480 is larger than the profiled value 6689. Use the profiled value instead.
Capturing batches (bs=2 avail_mem=30.41 GB):  25%|██▌       | 1/4 [00:00<00:00,  6.71it/s]/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py:1692: UserWarning: Dynamo detected a call to a `functools.lru_cache`-wrapped function. Dynamo ignores the cache wrapper and directly traces the wrapped function. Silent incorrectness is only a *potential* risk, not something we have observed. Enable TORCH_LOGS="+dynamo" for a DEBUG stack trace.
  torch._dynamo.utils.warn_once(msg)
Autotune Choices Stats:
{"num_choices": 20, "num_triton_choices": 19, "best_kernel": "mm", "best_time": 0.04835199937224388, "best_triton_pos": 1, "best_triton_time": 0.049375999718904495, "best_triton_kernel": "triton_mm_18", "best_triton_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8"}
AUTOTUNE mm(128x4096, 4096x12288)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  mm 0.0484 ms 100.0%
  triton_mm_18 0.0494 ms 97.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_11 0.0524 ms 92.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_12 0.0528 ms 91.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_8 0.0551 ms 87.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_7 0.0560 ms 86.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_17 0.0580 ms 83.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_10 0.0673 ms 71.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_14 0.0687 ms 70.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_4 0.0711 ms 68.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.4427 seconds and 0.2569 seconds precompiling for 20 choices
Autotune Choices Stats:
{"num_choices": 20, "num_triton_choices": 19, "best_kernel": "mm", "best_time": 0.022016000002622604, "best_triton_pos": 1, "best_triton_time": 0.02319999970495701, "best_triton_kernel": "triton_mm_27", "best_triton_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4"}
AUTOTUNE mm(128x4096, 4096x4096)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  mm 0.0220 ms 100.0%
  triton_mm_27 0.0232 ms 94.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_31 0.0266 ms 82.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_23 0.0302 ms 72.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_37 0.0318 ms 69.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_26 0.0394 ms 55.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_30 0.0411 ms 53.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_22 0.0418 ms 52.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_20 0.0424 ms 52.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_36 0.0428 ms 51.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.2888 seconds and 0.4154 seconds precompiling for 20 choices
Autotune Choices Stats:
{"num_choices": 20, "num_triton_choices": 19, "best_kernel": "triton_mm_49", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4", "best_time": 0.07468800246715546, "best_triton_pos": 0}
AUTOTUNE mm(128x4096, 4096x22016)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  triton_mm_49 0.0747 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  mm 0.0755 ms 98.9%
  triton_mm_55 0.0774 ms 96.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_50 0.0799 ms 93.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_56 0.0830 ms 89.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_45 0.0931 ms 80.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_46 0.0961 ms 77.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_47 0.0970 ms 77.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_48 0.1000 ms 74.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_54 0.1016 ms 73.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.4733 seconds and 0.2499 seconds precompiling for 20 choices
Autotune Choices Stats:
{"num_choices": 20, "num_triton_choices": 19, "best_kernel": "mm", "best_time": 0.046751998364925385, "best_triton_pos": 1, "best_triton_time": 0.04972799867391586, "best_triton_kernel": "triton_mm_65", "best_triton_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4"}
AUTOTUNE mm(128x11008, 11008x4096)
strides: [11008, 1], [1, 11008]
dtypes: torch.float16, torch.float16
  mm 0.0468 ms 100.0%
  triton_mm_65 0.0497 ms 94.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_69 0.0547 ms 85.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_61 0.0651 ms 71.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_75 0.0689 ms 67.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_64 0.0881 ms 53.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_68 0.0922 ms 50.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_60 0.0958 ms 48.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_74 0.0969 ms 48.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_58 0.0991 ms 47.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.5766 seconds and 0.0002 seconds precompiling for 20 choices
Autotune Choices Stats:
{"num_choices": 20, "num_triton_choices": 19, "best_kernel": "triton_mm_93", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4", "best_time": 0.10201600193977356, "best_triton_pos": 0}
AUTOTUNE mm(128x4096, 4096x32000)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  triton_mm_93 0.1020 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_94 0.1045 ms 97.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  mm 0.1046 ms 97.5%
  triton_mm_88 0.1085 ms 94.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_83 0.1156 ms 88.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_87 0.1163 ms 87.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_92 0.1242 ms 82.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_84 0.1284 ms 79.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_85 0.1349 ms 75.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_89 0.1418 ms 71.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.5310 seconds and 0.3282 seconds precompiling for 20 choices
Capturing batches (bs=1 avail_mem=61.55 GB):  75%|███████▌  | 3/4 [00:20<00:07,  7.66s/it]Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "mm", "best_time": 0.047520000487565994, "best_triton_pos": 1, "best_triton_time": 0.04809600114822388, "best_triton_kernel": "triton_mm_107", "best_triton_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4"}
AUTOTUNE mm(64x4096, 4096x12288)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  mm 0.0475 ms 100.0%
  triton_mm_107 0.0481 ms 98.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_111 0.0482 ms 98.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_103 0.0492 ms 96.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_99 0.0507 ms 93.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_102 0.0550 ms 86.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_106 0.0554 ms 85.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_96 0.0566 ms 84.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_98 0.0614 ms 77.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_97 0.0641 ms 74.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.2855 seconds and 0.2314 seconds precompiling for 18 choices
Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "mm", "best_time": 0.02208000048995018, "best_triton_pos": 1, "best_triton_time": 0.02236800082027912, "best_triton_kernel": "triton_mm_116", "best_triton_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4"}
AUTOTUNE mm(64x4096, 4096x4096)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  mm 0.0221 ms 100.0%
  triton_mm_116 0.0224 ms 98.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_120 0.0233 ms 94.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_124 0.0265 ms 83.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_128 0.0297 ms 74.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_119 0.0393 ms 56.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_113 0.0393 ms 56.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
  triton_mm_123 0.0407 ms 54.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_115 0.0408 ms 54.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_114 0.0419 ms 52.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.2577 seconds and 0.2918 seconds precompiling for 18 choices
Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "triton_mm_136", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8", "best_time": 0.07411199808120728, "best_triton_pos": 0}
AUTOTUNE mm(64x4096, 4096x22016)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  triton_mm_136 0.0741 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_137 0.0746 ms 99.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_140 0.0747 ms 99.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  mm 0.0757 ms 97.8%
  triton_mm_141 0.0761 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_145 0.0793 ms 93.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_133 0.0827 ms 89.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_139 0.0899 ms 82.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_143 0.0917 ms 80.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=8
  triton_mm_142 0.0925 ms 80.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.3776 seconds and 0.1382 seconds precompiling for 18 choices
Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "triton_mm_150", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4", "best_time": 0.047775998711586, "best_triton_pos": 0}
AUTOTUNE mm(64x11008, 11008x4096)
strides: [11008, 1], [1, 11008]
dtypes: torch.float16, torch.float16
  triton_mm_150 0.0478 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  mm 0.0487 ms 98.1%
  triton_mm_154 0.0494 ms 96.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_158 0.0560 ms 85.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_162 0.0646 ms 74.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_153 0.0882 ms 54.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  triton_mm_149 0.0889 ms 53.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_157 0.0922 ms 51.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_148 0.0922 ms 51.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_147 0.0931 ms 51.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.4025 seconds and 0.0003 seconds precompiling for 18 choices
Autotune Choices Stats:
{"num_choices": 18, "num_triton_choices": 17, "best_kernel": "triton_mm_175", "best_kernel_desc": "ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4", "best_time": 0.10060799866914749, "best_triton_pos": 0}
AUTOTUNE mm(64x4096, 4096x32000)
strides: [4096, 1], [1, 4096]
dtypes: torch.float16, torch.float16
  triton_mm_175 0.1006 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=4, num_warps=4
  triton_mm_179 0.1011 ms 99.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=8
  triton_mm_171 0.1018 ms 98.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_174 0.1018 ms 98.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_170 0.1025 ms 98.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=8
  mm 0.1036 ms 97.2%
  triton_mm_167 0.1106 ms 90.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=5, num_warps=4
  triton_mm_172 0.1193 ms 84.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_176 0.1220 ms 82.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=3, num_warps=4
  triton_mm_164 0.1334 ms 75.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=True, GROUP_M=8, USE_FAST_ACCUM=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.4342 seconds and 0.2807 seconds precompiling for 18 choices
Capturing batches (bs=1 avail_mem=61.55 GB): 100%|██████████| 4/4 [00:38<00:00,  9.74s/it]
[2026-01-09 04:09:22] SPECULATIVE_MOE_RUNNER_BACKEND is not initialized, using auto backend
[2026-01-09 04:09:22] SPECULATIVE_MOE_A2A_BACKEND is not initialized, using none backend
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.16s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.16s/it]

[2026-01-09 04:09:23] max_total_tokens=20480 is larger than the profiled value 6689. Use the profiled value instead.
Capturing batches (bs=1 avail_mem=22.90 GB): 100%|██████████| 4/4 [00:12<00:00,  3.15s/it]
Capturing batches (bs=1 avail_mem=43.21 GB): 100%|██████████| 4/4 [00:00<00:00, 46.80it/s]


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server, the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
[6]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
Response: ChatCompletion(id='f290712dc4e14705a16169fa2da33ac3', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=' Sure! Here are three countries and their capitals:\n\n1. Country: France\nCapital: Paris\n2. Country: Japan\nCapital: Tokyo\n3. Country: Brazil\nCapital: Brasília', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=2)], created=1767931784, model='meta-llama/Llama-2-7b-chat-hf', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=48, prompt_tokens=17, total_tokens=65, completion_tokens_details=None, prompt_tokens_details=None, reasoning_tokens=0), metadata={'weight_version': 'default'})
[7]:
terminate_process(server_process)

EAGLE-2 Decoding via Frequency-Ranked Speculative Sampling#

By employing a truncated high-frequency token vocabulary in the draft model, EAGLE speculative decoding reduces the lm_head computational overhead and accelerates the pipeline without quality degradation. For more details, check out the paper.

In our implementation, set --speculative-token-map to enable this optimization. You can get the FR-Spec high-frequency token map from this model, or obtain the high-frequency tokens by downloading them directly from this repo.

Thanks to Weilin Zhao and Zhousx for the contribution.
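For intuition, restricting the draft model's lm_head to a high-frequency subset of the vocabulary looks roughly like the toy sketch below. This is not SGLang's implementation; the subset here is made up, whereas the real token map comes from files such as freq_32768.pt.

```python
import torch

# Toy sketch of frequency-ranked drafting (illustrative shapes and names only).
vocab_size, hidden_size = 128_256, 64
lm_head_weight = torch.randn(vocab_size, hidden_size)  # full lm_head of the draft model

# Stand-in for the high-frequency token map (e.g. loaded from freq_32768.pt).
freq_token_ids = torch.arange(32_768)

hidden = torch.randn(hidden_size)

# Drafting scores only the truncated vocabulary, so the matmul is much cheaper...
draft_logits = lm_head_weight[freq_token_ids] @ hidden
draft_token = freq_token_ids[draft_logits.argmax()]    # map back to a full-vocab token id

# ...while the target model still verifies draft tokens with its full lm_head,
# so output quality is preserved.
print(int(draft_token))
```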

[8]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \
    --speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \
    --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \
    --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16  --log-level warning
"""
)

wait_for_server(f"http://localhost:{port}")
[2026-01-09 04:09:53] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:09:53] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:09:53] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2026-01-09 04:09:55] WARNING model_config.py:1014: Casting torch.bfloat16 to torch.float16.
[2026-01-09 04:09:55] INFO server_args.py:1616: Attention backend not specified. Use flashinfer backend by default.
[2026-01-09 04:09:55] WARNING server_args.py:2078: Overlap scheduler is disabled when spec v2 is off or using unsupported speculative algorithm. You can set env SGLANG_ENABLE_SPEC_V2=True to enable the experimental overlap scheduler.
[2026-01-09 04:09:55] INFO server_args.py:2513: Set soft_watchdog_timeout since in CI
[2026-01-09 04:09:56] Casting torch.bfloat16 to torch.float16.
[2026-01-09 04:10:02] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:10:02] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:10:02] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2026-01-09 04:10:02] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:10:02] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:10:02] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2026-01-09 04:10:04] Casting torch.bfloat16 to torch.float16.
[2026-01-09 04:10:04] Casting torch.bfloat16 to torch.float16.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2026-01-09 04:10:08] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/usr/local/lib/python3.10/dist-packages/transformers/__init__.py)
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:04<00:12,  4.33s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:08<00:08,  4.48s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:13<00:04,  4.50s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:14<00:00,  3.23s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:14<00:00,  3.68s/it]

Capturing batches (bs=1 avail_mem=43.00 GB): 100%|██████████| 4/4 [00:00<00:00, 12.72it/s]
[2026-01-09 04:10:25] SPECULATIVE_MOE_RUNNER_BACKEND is not initialized, using auto backend
[2026-01-09 04:10:25] SPECULATIVE_MOE_A2A_BACKEND is not initialized, using none backend
[2026-01-09 04:10:25] Warning: Target model's context_length (8192) is greater than the derived context_length (2048). This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config.
[2026-01-09 04:10:25] Overriding the draft model's max_position_embeddings to 8192.
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.11s/it]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.11s/it]

Capturing batches (bs=1 avail_mem=41.01 GB): 100%|██████████| 4/4 [00:07<00:00,  1.80s/it]
Capturing batches (bs=1 avail_mem=40.87 GB): 100%|██████████| 4/4 [00:00<00:00, 52.93it/s]


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server, the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
[9]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
Response: ChatCompletion(id='7ce7447a2cda45ba9934bbb506fe4933', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Here are 3 countries and their capitals:\n\n1. **France** - **Paris**\n2. **Japan** - **Tokyo**\n3. **Australia** - **Canberra**', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=128009)], created=1767931843, model='meta-llama/Meta-Llama-3-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=39, prompt_tokens=18, total_tokens=57, completion_tokens_details=None, prompt_tokens_details=None, reasoning_tokens=0), metadata={'weight_version': 'default'})
[10]:
terminate_process(server_process)

EAGLE-3 Decoding#

You can enable EAGLE-3 decoding by setting --speculative-algorithm EAGLE3 and choosing an appropriate model.

[11]:
server_process, port = launch_server_cmd(
    """
python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct  --speculative-algorithm EAGLE3 \
    --speculative-draft-model-path jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B --speculative-num-steps 5 \
        --speculative-eagle-topk 8 --speculative-num-draft-tokens 32 --mem-fraction 0.6 \
        --cuda-graph-max-bs 2 --dtype float16 --log-level warning
"""
)

wait_for_server(f"http://localhost:{port}")
[2026-01-09 04:10:49] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:10:49] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:10:49] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2026-01-09 04:10:51] WARNING model_config.py:1014: Casting torch.bfloat16 to torch.float16.
[2026-01-09 04:10:51] INFO server_args.py:1616: Attention backend not specified. Use flashinfer backend by default.
[2026-01-09 04:10:51] WARNING server_args.py:2078: Overlap scheduler is disabled when spec v2 is off or using unsupported speculative algorithm. You can set env SGLANG_ENABLE_SPEC_V2=True to enable the experimental overlap scheduler.
[2026-01-09 04:10:51] INFO server_args.py:2513: Set soft_watchdog_timeout since in CI
[2026-01-09 04:10:52] Casting torch.bfloat16 to torch.float16.
[2026-01-09 04:10:59] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:10:59] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:10:59] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2026-01-09 04:10:59] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:10:59] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:10:59] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2026-01-09 04:11:01] Casting torch.bfloat16 to torch.float16.
[2026-01-09 04:11:02] Casting torch.bfloat16 to torch.float16.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2026-01-09 04:11:05] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/usr/local/lib/python3.10/dist-packages/transformers/__init__.py)
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:04<00:12,  4.33s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:08<00:08,  4.45s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:13<00:04,  4.39s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:14<00:00,  3.13s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:14<00:00,  3.60s/it]

Capturing batches (bs=1 avail_mem=59.16 GB): 100%|██████████| 4/4 [00:00<00:00,  4.63it/s]
[2026-01-09 04:11:24] SPECULATIVE_MOE_RUNNER_BACKEND is not initialized, using auto backend
[2026-01-09 04:11:24] SPECULATIVE_MOE_A2A_BACKEND is not initialized, using none backend
[2026-01-09 04:11:24] Warning: Target model's context_length (131072) is greater than the derived context_length (2048). This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config.
[2026-01-09 04:11:24] Overriding the draft model's max_position_embeddings to 131072.
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.64it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.64it/s]

Capturing batches (bs=1 avail_mem=7.55 GB): 100%|██████████| 4/4 [00:05<00:00,  1.32s/it]
Capturing batches (bs=1 avail_mem=7.00 GB): 100%|██████████| 4/4 [00:00<00:00, 104.46it/s]


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server, the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
[12]:
client = openai.Client(base_url=f"http://127.0.0.1:{port}/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)

print_highlight(f"Response: {response}")
Response: ChatCompletion(id='d9e6eb2190bc4f2385ada0f3b1754699', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Here are 3 countries and their capitals:\n\n1. Country: Japan\n Capital: Tokyo\n\n2. Country: Australia\n Capital: Canberra\n\n3. Country: Brazil\n Capital: Brasília', refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=None, reasoning_content=None), matched_stop=128009)], created=1767931898, model='meta-llama/Meta-Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=43, prompt_tokens=43, total_tokens=86, completion_tokens_details=None, prompt_tokens_details=None, reasoning_tokens=0), metadata={'weight_version': 'default'})
[13]:
terminate_process(server_process)

Multi-Token Prediction#

We support MTP (Multi-Token Prediction) in SGLang via speculative decoding. We use the XiaomiMiMo/MiMo-7B-RL model as an example here (for DeepSeek MTP usage, refer to the DeepSeek docs).

[14]:
server_process, port = launch_server_cmd(
    """
    python3 -m sglang.launch_server --model-path XiaomiMiMo/MiMo-7B-RL --host 0.0.0.0 --trust-remote-code \
    --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 \
    --mem-fraction 0.5 --log-level warning
"""
)

wait_for_server(f"http://localhost:{port}")
[2026-01-09 04:11:44] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:11:44] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:11:44] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2026-01-09 04:11:46] INFO server_args.py:1616: Attention backend not specified. Use flashinfer backend by default.
[2026-01-09 04:11:46] WARNING server_args.py:2078: Overlap scheduler is disabled when spec v2 is off or using unsupported speculative algorithm. You can set env SGLANG_ENABLE_SPEC_V2=True to enable the experimental overlap scheduler.
[2026-01-09 04:11:46] INFO server_args.py:2513: Set soft_watchdog_timeout since in CI
[2026-01-09 04:11:52] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:11:52] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:11:52] INFO utils.py:164: NumExpr defaulting to 16 threads.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[2026-01-09 04:11:58] Ignore import error when loading sglang.srt.models.glmasr: cannot import name 'GlmAsrConfig' from 'transformers' (/usr/local/lib/python3.10/dist-packages/transformers/__init__.py)
[2026-01-09 04:11:59] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:11:59] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:11:59] INFO utils.py:164: NumExpr defaulting to 16 threads.
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:01<00:04,  1.55s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:02<00:02,  1.24s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:03<00:01,  1.10s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:04<00:00,  1.02it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:04<00:00,  1.07s/it]

Capturing batches (bs=1 avail_mem=60.22 GB): 100%|██████████| 4/4 [00:00<00:00,  9.96it/s]
[2026-01-09 04:12:05] SPECULATIVE_MOE_RUNNER_BACKEND is not initialized, using auto backend
[2026-01-09 04:12:05] SPECULATIVE_MOE_A2A_BACKEND is not initialized, using none backend
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:00,  5.13it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:00<00:00,  7.55it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:00<00:00,  7.29it/s]

Capturing batches (bs=1 avail_mem=59.37 GB): 100%|██████████| 4/4 [00:00<00:00, 60.10it/s]


NOTE: Typically, the server runs in a separate terminal.
In this notebook, we run the server and notebook code together, so their outputs are combined.
To improve clarity, the server logs are displayed in the original black color, while the notebook outputs are highlighted in blue.
To reduce the log length, we set the log level to warning for the server, the default log level is info.
We are running those notebooks in a CI environment, so the throughput is not representative of the actual performance.
[15]:
import requests

url = f"http://localhost:{port}/v1/chat/completions"

data = {
    "model": "XiaomiMiMo/MiMo-7B-RL",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
}

response = requests.post(url, json=data)
print_highlight(response.json())
{'id': '2140c2b4d866446381270dfc29624513', 'object': 'chat.completion', 'created': 1767931934, 'model': 'XiaomiMiMo/MiMo-7B-RL', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': "\nOkay, the user is asking for the capital of France. Let me start by recalling the basic geography of France. I know that France is a country in Europe. The capital city... Hmm, Paris comes to mind. Wait, is that right? Yes, Paris is the largest city in France and also the capital. But maybe I should double-check to be sure. Sometimes people confuse it with other European capitals, like London or Rome, but no, France's capital is definitely Paris. Let me think if there's any other major city in France that could be the capital. Lyon is a significant city, but it's not the capital. So yes, the answer should be Paris. I should confirm the spelling too. P-A-R-I-S, yes. Paris. Okay, that seems correct. No contradictions in my memory. So the answer is Paris.\n\nThe capital of France is **Paris**. This historic city is known for its iconic landmarks like the Eiffel Tower, Louvre Museum, and the prestigious Panthéon. It remains a global hub for culture, politics, and diplomacy.", 'reasoning_content': None, 'tool_calls': None}, 'logprobs': None, 'finish_reason': 'stop', 'matched_stop': 151645}], 'usage': {'prompt_tokens': 26, 'total_tokens': 255, 'completion_tokens': 229, 'prompt_tokens_details': None, 'reasoning_tokens': 0}, 'metadata': {'weight_version': 'default'}}
[16]:
terminate_process(server_process)

References#

The EAGLE process is as follows:

  • Within EAGLE, the draft model predicts the next feature vector, i.e., the last hidden state of the original LLM, using the feature sequence \((f_1, ..., f_k)\) and the token sequence \((t_2, ..., t_{k+1})\).

  • The next token is then sampled from \(p_{k+2}=\text{LMHead}(f_{k+1})\). Afterwards, the two sequences are extended in a tree style, branching into multiple potential continuations (the branching factor per step is controlled by the speculative_eagle_topk parameter) to ensure a more coherent connection of context, and are fed back as input.

  • EAGLE-2 additionally uses the draft model to evaluate how probable certain branches in the draft tree are, dynamically stopping the expansion of unlikely branches. After the expansion phase, reranking is employed to select only the top speculative_num_draft_tokens final nodes as draft tokens.

  • EAGLE-3 removes the feature prediction objective, incorporates low and mid-layer features, and is trained in an on-policy manner.

This enhances drafting accuracy: operating on features instead of tokens gives the draft model more regular inputs, and additionally passing the tokens from the next timestep minimizes the randomness introduced by sampling. Furthermore, the dynamic adjustment of the draft tree and the reranked selection of final nodes further increase the acceptance rate of draft tokens. For more details, see the EAGLE-2 and EAGLE-3 papers.
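To make the first two steps above concrete, here is a minimal toy sketch of a single EAGLE-style draft step: predict the next feature from \((f_1, ..., f_k)\) and \((t_2, ..., t_{k+1})\), apply the LM head, and branch into the top-k candidate tokens. The DraftModel, shapes, and fusion scheme are made up for illustration; this is not SGLang's code.

```python
import torch

d_model, vocab = 16, 100
lm_head = torch.nn.Linear(d_model, vocab, bias=False)  # stands in for the target model's LM head


class DraftModel(torch.nn.Module):
    """Hypothetical one-layer stand-in that predicts the next feature f_{k+1}."""

    def __init__(self):
        super().__init__()
        self.token_emb = torch.nn.Embedding(vocab, d_model)
        self.proj = torch.nn.Linear(2 * d_model, d_model)

    def forward(self, features, tokens):
        # Fuse each feature f_i with the embedding of the token one step ahead (t_{i+1}),
        # then read the prediction for f_{k+1} off the last position.
        fused = torch.cat([features, self.token_emb(tokens)], dim=-1)
        return self.proj(fused)[-1]


def draft_step(draft_model, features, tokens, topk=4):
    f_next = draft_model(features, tokens)            # predicted feature f_{k+1}
    probs = torch.softmax(lm_head(f_next), dim=-1)    # p_{k+2} = LMHead(f_{k+1})
    cand_probs, cand_tokens = probs.topk(topk)        # expand top-k branches of the draft tree
    return f_next, cand_tokens, cand_probs


# One step with k = 3 previous positions and random inputs.
features = torch.randn(3, d_model)                    # f_1..f_3
tokens = torch.randint(0, vocab, (3,))                # t_2..t_4
f_next, cand_tokens, cand_probs = draft_step(DraftModel(), features, tokens)
print(cand_tokens, cand_probs)
```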

For guidance on how to train your own EAGLE model, please see the EAGLE repo.