Reasoning Parser

SGLang can separate reasoning content from normal content in the output of reasoning models such as DeepSeek-R1.

Supported Models & Parsers

Model                  Reasoning tags                         Parser                   Notes
DeepSeek-R1 series     <think></think>                        deepseek-r1              Supports all variants (R1, R1-0528, R1-Distill)
DeepSeek-V3 series     <think></think>                        deepseek-v3              Includes DeepSeek-V3.1 and V3.2; supports the thinking parameter
Standard Qwen3 models  <think></think>                        qwen3                    Supports the enable_thinking parameter
Qwen3-Thinking models  <think></think>                        qwen3 or qwen3-thinking  Always generates thinking content
Kimi models            ◁think▷◁/think▷                        kimi                     Uses special thinking delimiters
GPT OSS                <|channel|>analysis<|message|><|end|>  gpt-oss                  N/A

Model-Specific Behaviors

DeepSeek-R1 Family:

  • DeepSeek-R1: Does not emit a <think> start tag; the output begins directly with the thinking content

  • DeepSeek-R1-0528: Generates both <think> start and </think> end tags

  • Both are handled by the same deepseek-r1 parser

DeepSeek-V3 Family:

  • DeepSeek-V3.1/V3.2: Hybrid models that support both thinking and non-thinking modes. Use the deepseek-v3 parser and the thinking parameter (NOTE: not enable_thinking)

Qwen3 Family:

  • Standard Qwen3 (e.g., Qwen3-2507): Use the qwen3 parser; supports enable_thinking in chat templates

  • Qwen3-Thinking (e.g., Qwen3-235B-A22B-Thinking-2507): Use the qwen3 or qwen3-thinking parser; always generates thinking content
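DeepSeek-V3.1/V3.2 and standard Qwen3 name their thinking switch differently, which is easy to get wrong. The sketch below builds a per-request extra_body for the OpenAI client. It assumes the server forwards chat_template_kwargs to the model's chat template (check this against your SGLang version); thinking_extra_body is a hypothetical helper, not part of SGLang.

```python
def thinking_extra_body(parser: str, enable: bool) -> dict:
    """Build an extra_body dict that switches thinking on or off per request.

    Hypothetical helper; assumes the server forwards chat_template_kwargs
    to the model's chat template.
    """
    if parser == "deepseek-v3":
        # DeepSeek-V3.1/V3.2 templates read `thinking`, not `enable_thinking`
        template_kwargs = {"thinking": enable}
    elif parser == "qwen3":
        # Standard (hybrid) Qwen3 templates read `enable_thinking`
        template_kwargs = {"enable_thinking": enable}
    else:
        # e.g. Qwen3-Thinking always thinks, so there is nothing to toggle
        raise ValueError(f"no thinking switch known for parser {parser!r}")
    return {"chat_template_kwargs": template_kwargs}


# Possible usage with the OpenAI client set up in the Usage section (not run here):
# client.chat.completions.create(
#     model=model_name,
#     messages=messages,
#     extra_body=thinking_extra_body("deepseek-v3", True),
# )
```

Raising for unknown parsers is deliberate: silently sending the wrong key would leave thinking in its default state without any error.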

Kimi:

  • Kimi: Uses special ◁think▷ and ◁/think▷ tags

GPT OSS:

  • GPT OSS: Uses special <|channel|>analysis<|message|> and <|end|> tags
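To make the delimiter handling concrete, here is a minimal pure-Python sketch of separating a complete (non-streamed) output. It is not SGLang's implementation, which lives in ReasoningParser and its detectors; DELIMITERS and split_reasoning are hypothetical names, and the force_reasoning flag is how this sketch models DeepSeek-R1 outputs that start without a <think> tag. The GPT OSS entry is simplified, since its real output format can contain several channels.

```python
from typing import Dict, Tuple

# Start/end delimiters per parser, copied from the table above. GPT OSS is
# simplified: its real output can contain several channels, ignored here.
DELIMITERS: Dict[str, Tuple[str, str]] = {
    "deepseek-r1": ("<think>", "</think>"),
    "deepseek-v3": ("<think>", "</think>"),
    "qwen3": ("<think>", "</think>"),
    "kimi": ("◁think▷", "◁/think▷"),
    "gpt-oss": ("<|channel|>analysis<|message|>", "<|end|>"),
}


def split_reasoning(text: str, parser: str, force_reasoning: bool = False) -> Tuple[str, str]:
    """Split a complete model output into (reasoning, content).

    force_reasoning=True treats the output as already inside the reasoning
    block, mimicking DeepSeek-R1 outputs that omit the start tag.
    """
    start, end = DELIMITERS[parser]
    stripped = text.lstrip()
    if stripped.startswith(start):
        body = stripped[len(start):]  # start tag present (e.g. R1-0528)
    elif force_reasoning:
        body = text  # start tag omitted (e.g. R1)
    else:
        return "", text.strip()  # no reasoning block at all
    reasoning, found_end, content = body.partition(end)
    if not found_end:
        # Output stopped before the end tag (e.g. max_new_tokens reached)
        return body.strip(), ""
    return reasoning.strip(), content.strip()
```

For example, split_reasoning("1 plus 3 is 4.\n</think>\n\nThe answer is 4.", "deepseek-r1", force_reasoning=True) returns ("1 plus 3 is 4.", "The answer is 4.").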

Usage

Launching the Server

Specify the --reasoning-parser option when launching the server.

[1]:
import requests
from openai import OpenAI
from sglang.test.doc_patch import launch_server_cmd
from sglang.utils import wait_for_server, print_highlight, terminate_process

server_process, port = launch_server_cmd(
    "python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --host 0.0.0.0 --reasoning-parser deepseek-r1 --log-level warning"
)

wait_for_server(f"http://localhost:{port}")


NOTE: Typically, the server runs in a separate terminal.
In this notebook, the server and notebook code run together, so their outputs are combined.
To improve clarity, server logs are displayed in the original black color, while notebook outputs are highlighted in blue.
To reduce log length, the server log level is set to warning; the default log level is info.
These notebooks run in a CI environment, so the throughput is not representative of actual performance.

Note that --reasoning-parser selects the parser used to separate reasoning content from the final answer in model responses.

OpenAI Compatible API

With the OpenAI-compatible API, responses follow the format the DeepSeek API introduced with the release of DeepSeek-R1:

  • reasoning_content: the chain-of-thought (CoT) content.

  • content: the final answer.
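Clients that call the HTTP endpoint directly (for example, POSTing to /v1/chat/completions with requests) rather than going through the OpenAI SDK find both fields side by side in choices[0].message. The sketch below parses a hand-written sample body shaped like an OpenAI chat completion; extract_reasoning_fields is a hypothetical helper and the sample values are illustrative, not real server output.

```python
from typing import Optional, Tuple


def extract_reasoning_fields(body: dict) -> Tuple[Optional[str], Optional[str]]:
    """Return (reasoning_content, content) from a chat completion JSON body."""
    message = body["choices"][0]["message"]
    # reasoning_content may be missing or null (for example, when separation is disabled)
    return message.get("reasoning_content"), message.get("content")


# Hand-written, abbreviated body for illustration (real responses also carry
# id, model, usage, finish_reason, ...). With requests it would come from
# requests.post(f"http://localhost:{port}/v1/chat/completions", json=...).json()
sample_body = {
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "reasoning_content": "1 plus 3 equals 4.",
                "content": "The answer is 4.",
            },
        }
    ]
}
reasoning, answer = extract_reasoning_fields(sample_body)
```

Using dict.get rather than indexing keeps the helper usable for responses that carry no reasoning_content key.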

[2]:
# Initialize OpenAI-like client
client = OpenAI(api_key="None", base_url=f"http://0.0.0.0:{port}/v1")
model_name = client.models.list().data[0].id

messages = [
    {
        "role": "user",
        "content": "What is 1+3?",
    }
]

Non-Streaming Request

[3]:
response_non_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=False,  # Non-streaming
    extra_body={"separate_reasoning": True},
)
print_highlight("==== Reasoning ====")
print_highlight(response_non_stream.choices[0].message.reasoning_content)

print_highlight("==== Text ====")
print_highlight(response_non_stream.choices[0].message.content)
==== Reasoning ====
First, I recognize that the problem is asking for the sum of the numbers 1 and 3.

Next, I add the two numbers together: 1 plus 3 equals 4.

Therefore, the final answer is 4.
==== Text ====
Sure! Let's solve the problem step by step.

**Question:** What is \(1 + 3\)?

**Solution:**

1. **Identify the numbers to add:** We have the numbers 1 and 3.

2. **Add the numbers together:**
\[
1 + 3 = 4
\]

**Final Answer:**
\[
\boxed{4}
\]

Streaming Request

[4]:
response_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=True,  # Streaming
    extra_body={"separate_reasoning": True},
)

reasoning_content = ""
content = ""
for chunk in response_stream:
    if chunk.choices[0].delta.content:
        content += chunk.choices[0].delta.content
    if chunk.choices[0].delta.reasoning_content:
        reasoning_content += chunk.choices[0].delta.reasoning_content

print_highlight("==== Reasoning ====")
print_highlight(reasoning_content)

print_highlight("==== Text ====")
print_highlight(content)
==== Reasoning ====
First, I recognize that the problem is asking for the sum of the numbers 1 and 3.

To solve this, I start by identifying the two numbers involved: 1 and 3.

Next, I perform the addition operation by combining these two numbers.

Finally, I calculate the result to find the total sum.
==== Text ====


To solve the problem \(1 + 3\), follow these simple steps:

1. **Identify the numbers to add:**
\[
1 \quad \text{and} \quad 3
\]

2. **Perform the addition:**
\[
1 + 3 = 4
\]

3. **Present the final answer:**
\[
\boxed{4}
\]
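The loop above assumes every chunk has at least one choice and that every delta exposes reasoning_content. A slightly more defensive version of the same accumulation is sketched below; whether a server ever sends chunks without choices (for example, usage-only chunks) or deltas without the field depends on its version and request options, so treat these guards as precautions. accumulate_stream and fake_chunk are hypothetical names, and SimpleNamespace objects stand in for the SDK's chunk objects so the sketch runs without a server.

```python
from types import SimpleNamespace
from typing import Iterable, Tuple


def accumulate_stream(chunks: Iterable) -> Tuple[str, str]:
    """Concatenate reasoning_content and content deltas from a chat stream."""
    reasoning_parts, content_parts = [], []
    for chunk in chunks:
        if not chunk.choices:
            continue  # e.g. a usage-only chunk with no choices
        delta = chunk.choices[0].delta
        # getattr tolerates deltas that lack reasoning_content altogether
        reasoning = getattr(delta, "reasoning_content", None)
        content = getattr(delta, "content", None)
        if reasoning:
            reasoning_parts.append(reasoning)
        if content:
            content_parts.append(content)
    return "".join(reasoning_parts), "".join(content_parts)


def fake_chunk(**delta_fields):
    """Stand-in for an SDK chunk object, so the sketch runs without a server."""
    return SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(**delta_fields))])


fake_stream = [
    fake_chunk(reasoning_content="1 plus 3 ", content=None),
    fake_chunk(reasoning_content="equals 4."),
    fake_chunk(content="The answer is 4."),  # no reasoning_content attribute
    SimpleNamespace(choices=[]),
]
reasoning_content, content = accumulate_stream(fake_stream)
```

With the real client, accumulate_stream(response_stream) would replace the loop; joining a list of parts once at the end also avoids repeated string concatenation.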

Optionally, set stream_reasoning to False to buffer the reasoning content and deliver it all at once, in the last reasoning chunk (or the first chunk after the reasoning content).

[5]:
response_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=True,  # Streaming
    extra_body={"separate_reasoning": True, "stream_reasoning": False},
)

reasoning_content = ""
content = ""
for chunk in response_stream:
    if chunk.choices[0].delta.content:
        content += chunk.choices[0].delta.content
    if chunk.choices[0].delta.reasoning_content:
        reasoning_content += chunk.choices[0].delta.reasoning_content

print_highlight("==== Reasoning ====")
print_highlight(reasoning_content)

print_highlight("==== Text ====")
print_highlight(content)
==== Reasoning ====
First, I recognize that the problem is asking for the sum of the numbers 1 and 3.

I start by identifying the two numbers involved in the addition: 1 and 3.

Next, I add these two numbers together: 1 plus 3 equals 4.

Finally, I conclude that the result of the addition is 4.
==== Text ====


Sure! Let's solve the problem step by step.

**Problem:**
What is \(1 + 3\)?

**Solution:**
To find the sum of 1 and 3, follow these simple steps:

1. **Start with the first number:**
\[
1
\]

2. **Add the second number:**
\[
1 + 3
\]

3. **Calculate the sum:**
\[
1 + 3 = 4
\]

**Final Answer:**
\[
\boxed{4}
\]

Reasoning separation is enabled by default when --reasoning-parser is specified. To disable it, set the separate_reasoning option to False in the request.

[6]:
response_non_stream = client.chat.completions.create(
    model=model_name,
    messages=messages,
    temperature=0.6,
    top_p=0.95,
    stream=False,  # Non-streaming
    extra_body={"separate_reasoning": False},
)

print_highlight("==== Original Output ====")
print_highlight(response_non_stream.choices[0].message.content)
==== Original Output ====
I need to add the numbers 1 and 3.

First, I'll identify the two numbers to be added.

Then, I'll perform the addition to find the sum.

Finally, I'll provide the result as the answer.


Sure! Let's solve the addition step by step.

**Question:** What is \(1 + 3\)?

**Solution:**

1. **Identify the numbers to add:**
\[
1 \quad \text{and} \quad 3
\]

2. **Perform the addition:**
\[
1 + 3 = 4
\]

**Answer:**
\[
\boxed{4}
\]

SGLang Native API

[7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
input = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True, return_dict=False
)

gen_url = f"http://localhost:{port}/generate"
gen_data = {
    "text": input,
    "sampling_params": {
        "skip_special_tokens": False,
        "max_new_tokens": 1024,
        "temperature": 0.6,
        "top_p": 0.95,
    },
}
gen_response = requests.post(gen_url, json=gen_data).json()["text"]

print_highlight("==== Original Output ====")
print_highlight(gen_response)

parse_url = f"http://localhost:{port}/separate_reasoning"
separate_reasoning_data = {
    "text": gen_response,
    "reasoning_parser": "deepseek-r1",
}
separate_reasoning_response_json = requests.post(
    parse_url, json=separate_reasoning_data
).json()
print_highlight("==== Reasoning ====")
print_highlight(separate_reasoning_response_json["reasoning_text"])
print_highlight("==== Text ====")
print_highlight(separate_reasoning_response_json["text"])
==== Original Output ====
First, I need to identify the two numbers in the problem, which are 1 and 3.

Next, I will add these two numbers together.

Adding 1 and 3 gives a sum of 4.

Therefore, the final answer is 4.


**Solution:**

We need to calculate the sum of the numbers 1 and 3.

1. **Identify the numbers to add:**
\[
1 \quad \text{and} \quad 3
\]

2. **Add the two numbers:**
\[
1 + 3 = 4
\]

**Final Answer:**
\[
\boxed{4}
\]
==== Reasoning ====
First, I need to identify the two numbers in the problem, which are 1 and 3.

Next, I will add these two numbers together.

Adding 1 and 3 gives a sum of 4.

Therefore, the final answer is 4.
==== Text ====
**Solution:**

We need to calculate the sum of the numbers 1 and 3.

1. **Identify the numbers to add:**
\[
1 \quad \text{and} \quad 3
\]

2. **Add the two numbers:**
\[
1 + 3 = 4
\]

**Final Answer:**
\[
\boxed{4}
\]
[8]:
terminate_process(server_process)

Offline Engine API

[9]:
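import sglang as sgl
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.utils import print_highlight
from transformers import AutoTokenizer

# `messages` is the chat history defined in the OpenAI Compatible API section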

llm = sgl.Engine(model_path="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
input = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True, return_dict=False
)
sampling_params = {
    "max_new_tokens": 1024,
    "skip_special_tokens": False,
    "temperature": 0.6,
    "top_p": 0.95,
}
result = llm.generate(prompt=input, sampling_params=sampling_params)

generated_text = result["text"]  # Assume there is only one prompt

print_highlight("==== Original Output ====")
print_highlight(generated_text)

parser = ReasoningParser("deepseek-r1")
reasoning_text, text = parser.parse_non_stream(generated_text)
print_highlight("==== Reasoning ====")
print_highlight(reasoning_text)
print_highlight("==== Text ====")
print_highlight(text)
==== Original Output ====
First, I recognize that the problem is asking for the sum of 1 and 3.

Next, I add the two numbers together to find their total.

Finally, I conclude that the result of 1 plus 3 is 4.


**Solution:**

We need to find the sum of 1 and 3.

\[
1 + 3 = 4
\]

Therefore, the final answer is \(\boxed{4}\).
==== Reasoning ====
First, I recognize that the problem is asking for the sum of 1 and 3.

Next, I add the two numbers together to find their total.

Finally, I conclude that the result of 1 plus 3 is 4.
==== Text ====
**Solution:**

We need to find the sum of 1 and 3.

\[
1 + 3 = 4
\]

Therefore, the final answer is \(\boxed{4}\).
[10]:
llm.shutdown()

Supporting New Reasoning Model Schemas

For future reasoning models, you can implement the reasoning parser as a subclass of BaseReasoningFormatDetector in python/sglang/srt/parser/reasoning_parser.py and specify the new parser for the new reasoning model schema accordingly.
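Because the server also separates reasoning in streaming responses, a parser for a new schema has to cope with tags that arrive split across chunks (for example "</th" followed by "ink>"). The class below is a hypothetical, self-contained sketch of that bookkeeping in plain Python. It does not use or reproduce BaseReasoningFormatDetector, and its names (StreamingThinkSplitter, feed, flush, run) are illustrative only.

```python
from typing import List, Tuple


class StreamingThinkSplitter:
    """Incrementally split streamed text into (reasoning, content) deltas.

    Hypothetical sketch: assumes the output starts inside the reasoning
    block, with or without an explicit start tag, and that reasoning ends
    at the first end tag. Possible partial tags are held back until the
    next chunk resolves them.
    """

    def __init__(self, start: str = "<think>", end: str = "</think>"):
        self.start, self.end = start, end
        self.started = False  # whether the start-tag question is settled
        self.in_reasoning = True  # still before the end tag
        self.buffer = ""  # text not yet emitted

    def feed(self, chunk: str) -> Tuple[str, str]:
        """Return the (reasoning, content) text that is safe to emit now."""
        if not self.in_reasoning:
            return "", chunk
        self.buffer += chunk
        if not self.started:
            stripped = self.buffer.lstrip()
            if stripped.startswith(self.start):
                self.buffer = stripped[len(self.start):]
                self.started = True
            elif self.start.startswith(stripped):
                return "", ""  # may be a partial start tag: wait for more
            else:
                self.started = True  # no start tag (DeepSeek-R1 style)
        idx = self.buffer.find(self.end)
        if idx != -1:
            reasoning = self.buffer[:idx]
            content = self.buffer[idx + len(self.end):]
            self.buffer, self.in_reasoning = "", False
            return reasoning, content
        # Hold back the longest suffix that could be the start of the end tag
        hold = 0
        for k in range(min(len(self.end) - 1, len(self.buffer)), 0, -1):
            if self.end.startswith(self.buffer[-k:]):
                hold = k
                break
        cut = len(self.buffer) - hold
        emit, self.buffer = self.buffer[:cut], self.buffer[cut:]
        return emit, ""

    def flush(self) -> Tuple[str, str]:
        """Emit anything still buffered once the stream has ended."""
        rest, self.buffer = self.buffer, ""
        return (rest, "") if self.in_reasoning else ("", rest)


def run(chunks: List[str]) -> Tuple[str, str]:
    """Feed all chunks, flush, and return the accumulated (reasoning, content)."""
    splitter = StreamingThinkSplitter()
    parts = [splitter.feed(piece) for piece in chunks] + [splitter.flush()]
    return "".join(r for r, _ in parts), "".join(c for _, c in parts)
```

Holding back only a possible tag prefix, rather than the whole buffer, keeps reasoning text flowing to the client with minimal delay while still never leaking a half-received tag into either stream.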