Offline Engine API#

SGLang provides a direct inference engine that runs without an HTTP server, for use cases where the additional HTTP layer adds unnecessary complexity or overhead. There are two general use cases:

  • Offline Batch Inference

  • Custom Server on Top of the Engine

This document focuses on offline batch inference and demonstrates four inference modes:

  • Non-streaming synchronous generation

  • Streaming synchronous generation

  • Non-streaming asynchronous generation

  • Streaming asynchronous generation

Additionally, you can build a custom server on top of the SGLang offline engine. A detailed, working example in a Python script can be found in custom_server.
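For illustration only, here is a minimal, hypothetical sketch of such a server (it is not the custom_server example itself). It assumes FastAPI is installed; the /generate route and the payload shape are inventions for this sketch:

from fastapi import FastAPI

import sglang as sgl

app = FastAPI()
llm = sgl.Engine(model_path="qwen/qwen2.5-0.5b-instruct")


@app.post("/generate")
async def generate(payload: dict):
    # Illustrative payload shape: {"prompt": "...", "sampling_params": {...}}
    output = await llm.async_generate(
        payload["prompt"], payload.get("sampling_params", {})
    )
    return {"text": output["text"]}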

Nest Asyncio#

Note that if you want to use the offline engine in IPython or other code that already runs an event loop, you need to add the following code first:

import nest_asyncio

nest_asyncio.apply()

Advanced Usage#

The engine supports vision-language model (VLM) inference as well as extracting hidden states.

Please see the examples for further use cases.
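As a rough illustration, hidden-state extraction combines an engine-level switch with a per-request flag. The sketch below is an untested assumption based on the enable_return_hidden_states server argument (visible in the startup log below) and a return_hidden_states keyword on generate; the exact output layout, here meta_info["hidden_states"], may differ between versions, so consult the linked examples for the authoritative version.

import sglang as sgl

# Run this in a fresh process: a second engine on the same GPU may not fit in memory.
hs_llm = sgl.Engine(
    model_path="qwen/qwen2.5-0.5b-instruct",
    enable_return_hidden_states=True,  # engine-level switch (assumed kwarg)
)

out = hs_llm.generate(
    "The capital of France is",
    {"temperature": 0.8, "max_new_tokens": 8},
    return_hidden_states=True,  # per-request flag (assumed kwarg)
)
print(out["meta_info"]["hidden_states"])  # output field name is an assumption
hs_llm.shutdown()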

Offline Batch Inference#

The SGLang offline engine supports batch inference with efficient scheduling.

[1]:
# launch the offline engine
import asyncio

import sglang as sgl
import sglang.test.doc_patch
from sglang.utils import async_stream_and_merge, stream_and_merge

llm = sgl.Engine(model_path="qwen/qwen2.5-0.5b-instruct")
[2026-01-09 04:12:36] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2026-01-09 04:12:36] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2026-01-09 04:12:36] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2026-01-09 04:12:38] INFO server_args.py:1616: Attention backend not specified. Use fa3 backend by default.
[2026-01-09 04:12:38] INFO server_args.py:2513: Set soft_watchdog_timeout since in CI
[2026-01-09 04:12:38] INFO engine.py:153: server_args=ServerArgs(model_path='qwen/qwen2.5-0.5b-instruct', tokenizer_path='qwen/qwen2.5-0.5b-instruct', tokenizer_mode='auto', tokenizer_worker_num=1, skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=30000, fastapi_root_path='', grpc_mode=False, skip_server_warmup=False, warmups=None, nccl_port=None, checkpoint_engine_wait_weights_before_ready=False, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', enable_fp32_lm_head=False, modelopt_quant=None, modelopt_checkpoint_restore_path=None, modelopt_checkpoint_save_path=None, modelopt_export_path=None, quantize_and_serve=False, rl_quant_profile=None, mem_fraction_static=0.835, max_running_requests=128, max_queued_requests=None, max_total_tokens=20480, chunked_prefill_size=8192, enable_dynamic_chunking=False, max_prefill_tokens=16384, prefill_max_requests=None, schedule_policy='fcfs', enable_priority_scheduling=False, abort_on_priority_when_disabled=False, schedule_low_priority_values_first=False, priority_scheduling_preemption_threshold=10, schedule_conservativeness=1.0, page_size=1, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, radix_eviction_policy='lru', device='cuda', tp_size=1, pp_size=1, pp_max_micro_batch_size=None, pp_async_batch_depth=0, stream_interval=1, stream_output=False, random_seed=362184846, constrained_json_whitespace_pattern=None, constrained_json_disable_any_whitespace=False, watchdog_timeout=300, soft_watchdog_timeout=300, dist_timeout=None, download_dir=None, model_checksum=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, custom_sigquit_handler=None, log_level='error', log_level_http=None, log_requests=False, log_requests_level=2, log_requests_format='text', log_requests_target=None, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, tokenizer_metrics_custom_labels_header='x-custom-labels', tokenizer_metrics_allowed_custom_labels=None, bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, prompt_tokens_buckets=None, generation_tokens_buckets=None, gc_warning_threshold_secs=0.0, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, enable_trace=False, otlp_traces_endpoint='localhost:4317', export_metrics_to_file=False, export_metrics_to_file_dir=None, api_key=None, served_model_name='qwen/qwen2.5-0.5b-instruct', weight_version='default', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, sampling_defaults='model', dp_size=1, load_balance_method='round_robin', dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_eviction_policy='lru', lora_backend='csgmv', max_lora_chunk_size=16, attention_backend='fa3', decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, fp8_gemm_runner_backend='auto', nsa_prefill_backend='flashmla_sparse', nsa_decode_backend='fa3', disable_flashinfer_autotune=False, 
speculative_algorithm=None, speculative_draft_model_path=None, speculative_draft_model_revision=None, speculative_draft_load_format=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, speculative_attention_mode='prefill', speculative_draft_attention_backend=None, speculative_moe_runner_backend='auto', speculative_moe_a2a_backend=None, speculative_draft_model_quantization=None, speculative_ngram_min_match_window_size=1, speculative_ngram_max_match_window_size=12, speculative_ngram_min_bfs_breadth=1, speculative_ngram_max_bfs_breadth=10, speculative_ngram_match_type='BFS', speculative_ngram_branch_length=18, speculative_ngram_capacity=10000000, enable_multi_layer_eagle=False, ep_size=1, moe_a2a_backend='none', moe_runner_backend='auto', flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm=None, init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, elastic_ep_backend=None, mooncake_ib_device=None, max_mamba_cache_size=None, mamba_ssm_dtype='float32', mamba_full_memory_ratio=0.9, mamba_scheduler_strategy='no_buffer', mamba_track_interval=256, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, hierarchical_sparse_attention_extra_config=None, enable_lmcache=False, kt_weight_path=None, kt_method=None, kt_cpuinfer=None, kt_threadpool_count=None, kt_num_gpu_experts=None, kt_max_deferred_experts_per_token=None, dllm_algorithm=None, dllm_algorithm_config=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', multi_item_scoring_delimiter=None, disable_radix_cache=False, cuda_graph_max_bs=4, cuda_graph_bs=[1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_layerwise_nvtx_marker=False, enable_nccl_nvls=False, enable_symm_mem=False, disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_tokenizer_batch_decode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, enable_torch_symm_mem=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_single_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, enable_piecewise_cuda_graph=False, enable_torch_compile_debug_mode=False, torch_compile_max_bs=32, piecewise_cuda_graph_max_tokens=8192, 
piecewise_cuda_graph_tokens=[4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048, 2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584, 3712, 3840, 3968, 4096, 4352, 4608, 4864, 5120, 5376, 5632, 5888, 6144, 6400, 6656, 6912, 7168, 7424, 7680, 7936, 8192], piecewise_cuda_graph_compiler='eager', torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, triton_attention_split_tile_size=None, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, enable_weights_cpu_backup=False, enable_draft_weights_cpu_backup=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, keep_mm_feature_on_device=False, enable_return_hidden_states=False, enable_return_routed_experts=False, scheduler_recv_interval=1, numa_node=None, enable_deterministic_inference=False, rl_on_policy_target=None, enable_attn_tp_input_scattered=False, enable_nsa_prefill_context_parallel=False, nsa_prefill_cp_mode='in-seq-split', enable_fused_qk_norm_rope=False, enable_precise_embedding_interpolation=False, enable_dynamic_batch_tokenizer=False, dynamic_batch_tokenizer_batch_size=32, dynamic_batch_tokenizer_batch_timeout=0.002, debug_tensor_dump_output_folder=None, debug_tensor_dump_layers=None, debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, disaggregation_decode_enable_offload_kvcache=False, disaggregation_decode_enable_fake_auto=False, num_reserved_decode_tokens=512, disaggregation_decode_polling_interval=1, encoder_only=False, language_only=False, encoder_transfer_backend='zmq_to_scheduler', encoder_urls=[], custom_weight_loader=[], weight_loader_disable_mmap=False, remote_instance_weight_loader_seed_instance_ip=None, remote_instance_weight_loader_seed_instance_service_port=None, remote_instance_weight_loader_send_weights_group_ports=None, remote_instance_weight_loader_backend='nccl', remote_instance_weight_loader_start_seed_via_transfer_engine=False, enable_pdmux=False, pdmux_config_path=None, sm_group_num=8, mm_max_concurrent_calls=32, mm_per_request_timeout=10.0, enable_broadcast_mm_inputs_process=False, enable_prefix_mm_cache=False, mm_enable_dp_encoder=False, mm_process_config={}, limit_mm_data_per_request=None, decrypted_config_file=None, decrypted_draft_config_file=None, forward_hooks=None)
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  5.78it/s]

Capturing batches (bs=1 avail_mem=74.06 GB): 100%|██████████| 20/20 [00:00<00:00, 20.49it/s]

Non-streaming Synchronous Generation#

[2]:
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = {"temperature": 0.8, "top_p": 0.95}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Hello, my name is
Generated text:  Nanette and I'm a super friendly and personable doctor who specializes in providing care to clients with a wide range of medical conditions. I am a compassionate and attentive doctor who offers a wide variety of diagnostic and treatment services, including but not limited to, internal medicine, cardiovascular medicine, gastroenterology, and women's health. I'm also an expert in managing patients with chronic diseases and offering lifestyle modification programs to help them live a healthy and fulfilling life.
My goal is to help my patients feel confident in their health and well-being. I believe that effective communication, patience, and compassion are key to building trust with my patients and promoting
===============================
Prompt: The president of the United States is
Generated text:  a popular TV show, where viewers give him the choice to become the next president, or vote for him. This is how a president is elected.
The process of choosing the next president is called the "electoral process" and it is also called the "electoral college" in the United States.
The president will be elected for 4 years and the president will be re-elected every four years. This is the process in which the popular vote is not counted. However, a state governor or a court can override a state governor's decision to be president.
The president will have the power to appoint vice president, attorneys and other federal
===============================
Prompt: The capital of France is
Generated text: :
A) Paris
B) London
C) Berlin
D) Moscow

The capital of France is:
A) Paris

Paris, officially the City of Paris, is the capital and most populous city of France. It is located in the southwestern region of the country and is situated on the River Seine. The city's history dates back to ancient times and has been the seat of government, culture, and religion for more than 2,000 years.

Paris is known for its iconic landmarks such as the Eiffel Tower, Louvre Museum, Notre-Dame Cathedral, and the Arc de Triomphe.
===============================
Prompt: The future of AI is
Generated text:  unclear, but there are some key trends that may help to shape the way we see and use technology in the coming years.

One of the key trends that is shaping the future of AI is the increasing use of data. With the increasing amount of data being generated and collected, the need for AI systems to be able to understand and process this data has become more important. As AI systems become more sophisticated, they will need to be able to learn from the data they are exposed to, and adapt to changing conditions in the world.

Another key trend is the development of deep learning. Deep learning is a type of AI that relies on neural networks

Streaming Synchronous Generation#

[3]:
prompts = [
    "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
    "Provide a concise factual statement about France’s capital city. The capital of France is",
    "Explain possible future trends in artificial intelligence. The future of AI is",
]

sampling_params = {
    "temperature": 0.2,
    "top_p": 0.9,
}

print("\n=== Testing synchronous streaming generation with overlap removal ===\n")

for prompt in prompts:
    print(f"Prompt: {prompt}")
    merged_output = stream_and_merge(llm, prompt, sampling_params)
    print("Generated text:", merged_output)
    print()

=== Testing synchronous streaming generation with overlap removal ===

Prompt: Write a short, neutral self-introduction for a fictional character. Hello, my name is
Generated text:  [Name], and I'm a [job title] at [company name]. I'm excited to meet you and learn more about you. What can you tell me about yourself? I'm a [job title] at [company name], and I'm passionate about [job title] and [job title]. I enjoy [job title] because [reason for interest]. What's your favorite hobby or activity? I enjoy [hobby or activity]. What's your favorite book or movie? I love [book or movie]. What's your favorite place to go? I love [favorite place]. What's your favorite color? I love [

Prompt: Provide a concise factual statement about France’s capital city. The capital of France is
Generated text:  Paris, which is known for its iconic landmarks such as the Eiffel Tower, Notre-Dame Cathedral, and the Louvre Museum. It is also home to the French Parliament and the French National Museum of Modern Art. Paris is a cultural and historical center with a rich history dating back to the Roman Empire and the French Revolution. It is a major transportation hub and a popular tourist destination, attracting millions of visitors each year. The city is known for its cuisine, fashion, and art, and is home to many world-renowned museums, theaters, and galleries. Paris is a vibrant and dynamic city that continues to evolve and thrive in

Prompt: Explain possible future trends in artificial intelligence. The future of AI is
Generated text:  likely to be characterized by rapid advancements in areas such as machine learning, natural language processing, and computer vision. Some possible future trends in AI include:

1. Increased use of AI in healthcare: AI is already being used in healthcare to improve patient outcomes, reduce costs, and increase efficiency. As AI technology continues to improve, we can expect to see even more widespread use of AI in healthcare.

2. AI in finance: AI is already being used in finance to improve fraud detection, risk management, and trading algorithms. As AI technology continues to improve, we can expect to see even more widespread use of AI in finance.

3. AI
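Under the hood, stream_and_merge wraps the engine's native streaming mode. If you want the raw stream instead, a minimal sketch follows; raw chunks may repeat overlapping text, which is exactly what the helper trims away:

for prompt in prompts:
    print(f"Prompt: {prompt}")
    print("Generated text: ", end="", flush=True)
    # stream=True turns generate into an iterator of chunk dicts
    for chunk in llm.generate(prompt, sampling_params, stream=True):
        print(chunk["text"], end="", flush=True)
    print()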

Non-streaming Asynchronous Generation#

[4]:
prompts = [
    "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
    "Provide a concise factual statement about France’s capital city. The capital of France is",
    "Explain possible future trends in artificial intelligence. The future of AI is",
]

sampling_params = {"temperature": 0.8, "top_p": 0.95}

print("\n=== Testing asynchronous batch generation ===")


async def main():
    outputs = await llm.async_generate(prompts, sampling_params)

    for prompt, output in zip(prompts, outputs):
        print(f"\nPrompt: {prompt}")
        print(f"Generated text: {output['text']}")


asyncio.run(main())

=== Testing asynchronous batch generation ===

Prompt: Write a short, neutral self-introduction for a fictional character. Hello, my name is
Generated text:  __________ and I am a/an _________ . I am here to __________ and __________. I hope that you have a pleasant day. Please feel free to say anything you want to say in the review.

As an AI language model, I am programmed to understand and process natural language text. When I hear your name and your role, I am ready to assist you in generating an appropriate self-introduction. How can I assist you today? Please provide me with more details about your role and your purpose in the conversation. Once I have your information, I will be able to generate a neutral self-introduction that is tailored to your

Prompt: Provide a concise factual statement about France’s capital city. The capital of France is
Generated text:  Paris.

France's capital is Paris.

Prompt: Explain possible future trends in artificial intelligence. The future of AI is
Generated text:  bound to be exciting and diverse, with new technologies constantly emerging and evolving. Here are some potential trends we can expect to see in AI in the coming years:

1. AI in healthcare: AI will continue to play a vital role in healthcare, with more doctors using AI to diagnose and treat patients. AI can help doctors detect diseases early, reduce the risk of medical errors, and improve patient outcomes.

2. AI in finance: AI can help finance companies to analyze large amounts of data, identify fraud, and improve risk management. AI can also help finance companies to personalize financial advice and recommendations to individual customers.

3. AI in transportation:
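Because async_generate is a coroutine, you can also fan the prompts out with asyncio.gather instead of passing the whole list in one call; the engine still batches concurrent requests internally. A minimal sketch under that assumption:

async def run_one(prompt):
    output = await llm.async_generate(prompt, sampling_params)
    return output["text"]


async def run_all():
    # Issue all prompts concurrently and keep the original order
    texts = await asyncio.gather(*(run_one(p) for p in prompts))
    for prompt, text in zip(prompts, texts):
        print(f"\nPrompt: {prompt}\nGenerated text: {text}")


asyncio.run(run_all())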

Streaming Asynchronous Generation#

[5]:
prompts = [
    "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
    "Provide a concise factual statement about France’s capital city. The capital of France is",
    "Explain possible future trends in artificial intelligence. The future of AI is",
]

sampling_params = {"temperature": 0.8, "top_p": 0.95}

print("\n=== Testing asynchronous streaming generation (no repeats) ===")


async def main():
    for prompt in prompts:
        print(f"\nPrompt: {prompt}")
        print("Generated text: ", end="", flush=True)

        # Replace direct calls to async_generate with our custom overlap-aware version
        async for cleaned_chunk in async_stream_and_merge(llm, prompt, sampling_params):
            print(cleaned_chunk, end="", flush=True)

        print()  # New line after each prompt


asyncio.run(main())

=== Testing asynchronous streaming generation (no repeats) ===

Prompt: Write a short, neutral self-introduction for a fictional character. Hello, my name is
Generated text:  [Name], and I'm a [career] at [company name]. I love [majority of your favorite hobby], and I find myself [reason for passion]. I'm excited to meet you and learn more about you.
**[Name]** is a fictional character, specifically a **business analyst** at a **legal firm**. He has a strong passion for law and enjoys solving complex legal problems, especially those involving digital and international transactions. His career is centered around understanding legal precedents and interpreting complex contracts and regulations. He is enthusiastic and always eager to contribute to the legal team. He is known for his analytical skills and his

Prompt: Provide a concise factual statement about France’s capital city. The capital of France is
Generated text:  Paris, also known as the City of Light. It is located in the south of France and is the largest city in the country. Paris is known for its iconic landmarks such as the Eiffel Tower, Notre-Dame Cathedral, Louvre Museum, and the Champs-Elysees, and is home to many world-renowned artists, architects, and writers. The city is also known for its diverse cultural scene, including music, art, and theater, and is a major financial hub in Europe. Paris is a global cultural and economic hub, and has a long and storied history as a major center of power and culture

Prompt: Explain possible future trends in artificial intelligence. The future of AI is
Generated text:  likely to be highly dynamic and unpredictable, as new technologies and applications are constantly being developed and refined. Here are some potential future trends in AI that are currently being discussed and predicted:

1. Increased Use of Machine Learning: As AI continues to mature, machine learning techniques will be applied to more and more areas, including healthcare, finance, and transportation. These techniques will help machines learn from data, improve their performance, and adapt to new situations.

2. AI Ethics and Privacy: As AI becomes more integrated into our daily lives, there will be growing concerns about its ethical implications and potential privacy issues. Governments and organizations will need to develop policies
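As in the synchronous case, async_stream_and_merge wraps the engine's raw asynchronous stream. A minimal sketch of the direct form (again, raw chunks may repeat overlapping text):

async def raw_stream(prompt):
    # With stream=True, awaiting async_generate yields an async chunk generator
    generator = await llm.async_generate(prompt, sampling_params, stream=True)
    async for chunk in generator:
        print(chunk["text"], end="", flush=True)
    print()


asyncio.run(raw_stream(prompts[0]))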
[6]:
llm.shutdown()