Offline Engine API
SGLang provides a direct inference engine that runs without an HTTP server. This is useful wherever an extra HTTP server would add unnecessary complexity or overhead. Here are two general use cases:
Offline Batch Inference
Custom Server on Top of the Engine
This document focuses on offline batch inference and demonstrates four inference modes:
Non-streaming synchronous generation
Streaming synchronous generation
Non-streaming asynchronous generation
Streaming asynchronous generation
You can also build a custom server on top of the SGLang offline engine. A detailed example in a Python script can be found in custom_server.
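The custom-server idea can be sketched with only the Python standard library. This is a minimal illustration under stated assumptions, not the custom_server example: make_app, call and EchoEngine are hypothetical names, EchoEngine stands in for sgl.Engine so the code runs without a GPU, and the request fields "text" and "sampling_params" are our own choice. It also assumes that calling generate with a single prompt string returns one dict with a "text" key, like the per-prompt dicts in the batch example.

```python
import io
import json
from wsgiref.util import setup_testing_defaults


class EchoEngine:
    """Hypothetical stand-in for sgl.Engine: only mimics generate()."""

    def generate(self, prompt, sampling_params=None):
        # Assumption: a single-string prompt yields one dict with a "text" key.
        return {"text": f"<completion for {prompt!r}>"}


def make_app(engine):
    """Build a WSGI app that exposes POST /generate on top of `engine`."""

    def app(environ, start_response):
        if environ["REQUEST_METHOD"] != "POST" or environ.get("PATH_INFO") != "/generate":
            start_response("404 Not Found", [("Content-Type", "text/plain")])
            return [b"not found"]
        length = int(environ.get("CONTENT_LENGTH") or 0)
        payload = json.loads(environ["wsgi.input"].read(length) or b"{}")
        # The field names "text" and "sampling_params" are this sketch's own choice.
        output = engine.generate(payload["text"], payload.get("sampling_params", {}))
        body = json.dumps({"text": output["text"]}).encode("utf-8")
        start_response(
            "200 OK",
            [("Content-Type", "application/json"), ("Content-Length", str(len(body)))],
        )
        return [body]

    return app


def call(app, method, path, payload=None):
    """Invoke a WSGI app in-process (no sockets) and return (status, body)."""
    raw = json.dumps(payload).encode("utf-8") if payload is not None else b""
    environ = {
        "REQUEST_METHOD": method,
        "PATH_INFO": path,
        "CONTENT_LENGTH": str(len(raw)),
        "wsgi.input": io.BytesIO(raw),
    }
    setup_testing_defaults(environ)  # fill in the remaining required WSGI keys
    captured = {}

    def start_response(status, headers):
        captured["status"] = status

    body = b"".join(app(environ, start_response))
    return captured["status"], body


app = make_app(EchoEngine())
status, body = call(app, "POST", "/generate", {"text": "The capital of France is"})
print(status, json.loads(body))
```

To actually listen on a port you could pass make_app(llm) to make_server("127.0.0.1", 8000, ...) from wsgiref.simple_server and call serve_forever(); that server handles one request at a time, so it only shows the wiring between HTTP and the engine.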
Nest Asyncio
If you use the offline engine in IPython or in other code that already runs an event loop (nested loops), add the following code:
import nest_asyncio
nest_asyncio.apply()
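Why this is needed: Jupyter and IPython already run an asyncio event loop, and the asynchronous examples below call asyncio.run(main()), which refuses to start while another loop is running; nest_asyncio.apply() patches asyncio so that this nesting is allowed. The sketch below reproduces the error in plain Python, with a hypothetical fake_generate coroutine standing in for the engine, and shows that awaiting the coroutine directly also works when you are already inside a running loop (IPython accepts a top-level await main() for the same reason).

```python
import asyncio


async def fake_generate(prompt):
    """Hypothetical stand-in for an engine coroutine such as llm.async_generate."""
    await asyncio.sleep(0)
    return prompt + " ..."


async def notebook_cell():
    # An event loop is already running here, just as it is inside Jupyter/IPython.
    coro = fake_generate("Hello")
    try:
        asyncio.run(coro)  # the pattern the later examples use at top level
        error = ""
    except RuntimeError as exc:
        coro.close()  # the coroutine never started; close it to avoid a warning
        error = str(exc)
    # Awaiting directly reuses the loop that is already running.
    result = await fake_generate("Hello")
    return error, result


error, result = asyncio.run(notebook_cell())
print("asyncio.run() inside a running loop ->", error)
print("direct await ->", result)
```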
Advanced Usage
The engine also supports vision-language model (VLM) inference and extracting hidden states.
Please see the examples for further use cases.
Offline Batch Inference
The SGLang offline engine supports batch inference with efficient scheduling.
# launch the offline engine
import asyncio
import sglang as sgl
import sglang.test.doc_patch
from sglang.utils import async_stream_and_merge, stream_and_merge
llm = sgl.Engine(model_path="qwen/qwen2.5-0.5b-instruct")
[2025-11-23 21:42:02] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-23 21:42:02] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-23 21:42:02] INFO utils.py:164: NumExpr defaulting to 16 threads.
[2025-11-23 21:42:04] WARNING server_args.py:1286: Attention backend not explicitly specified. Use fa3 backend by default.
[2025-11-23 21:42:04] INFO engine.py:123: server_args=ServerArgs(model_path='qwen/qwen2.5-0.5b-instruct', tokenizer_path='qwen/qwen2.5-0.5b-instruct', tokenizer_mode='auto', tokenizer_worker_num=1, skip_tokenizer_init=False, load_format='auto', model_loader_extra_config='{}', trust_remote_code=False, context_length=None, is_embedding=False, enable_multimodal=None, revision=None, model_impl='auto', host='127.0.0.1', port=30000, fastapi_root_path='', grpc_mode=False, skip_server_warmup=False, warmups=None, nccl_port=None, checkpoint_engine_wait_weights_before_ready=False, dtype='auto', quantization=None, quantization_param_path=None, kv_cache_dtype='auto', enable_fp32_lm_head=False, modelopt_quant=None, modelopt_checkpoint_restore_path=None, modelopt_checkpoint_save_path=None, modelopt_export_path=None, quantize_and_serve=False, mem_fraction_static=0.835, max_running_requests=128, max_queued_requests=None, max_total_tokens=20480, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='fcfs', enable_priority_scheduling=False, abort_on_priority_when_disabled=False, schedule_low_priority_values_first=False, priority_scheduling_preemption_threshold=10, schedule_conservativeness=1.0, page_size=1, hybrid_kvcache_ratio=None, swa_full_tokens_ratio=0.8, disable_hybrid_swa_memory=False, radix_eviction_policy='lru', device='cuda', tp_size=1, pp_size=1, pp_max_micro_batch_size=None, stream_interval=1, stream_output=False, random_seed=319157589, constrained_json_whitespace_pattern=None, constrained_json_disable_any_whitespace=False, watchdog_timeout=300, dist_timeout=None, download_dir=None, base_gpu_id=0, gpu_id_step=1, sleep_on_idle=False, mm_process_config={}, log_level='error', log_level_http=None, log_requests=False, log_requests_level=2, crash_dump_folder=None, show_time_cost=False, enable_metrics=False, enable_metrics_for_all_schedulers=False, tokenizer_metrics_custom_labels_header='x-custom-labels', tokenizer_metrics_allowed_custom_labels=None, 
bucket_time_to_first_token=None, bucket_inter_token_latency=None, bucket_e2e_request_latency=None, collect_tokens_histogram=False, prompt_tokens_buckets=None, generation_tokens_buckets=None, gc_warning_threshold_secs=0.0, decode_log_interval=40, enable_request_time_stats_logging=False, kv_events_config=None, enable_trace=False, otlp_traces_endpoint='localhost:4317', export_metrics_to_file=False, export_metrics_to_file_dir=None, api_key=None, served_model_name='qwen/qwen2.5-0.5b-instruct', weight_version='default', chat_template=None, completion_template=None, file_storage_path='sglang_storage', enable_cache_report=False, reasoning_parser=None, tool_call_parser=None, tool_server=None, sampling_defaults='model', dp_size=1, load_balance_method='round_robin', load_watch_interval=0.1, prefill_round_robin_balance=False, dist_init_addr=None, nnodes=1, node_rank=0, json_model_override_args='{}', preferred_sampling_params=None, enable_lora=None, max_lora_rank=None, lora_target_modules=None, lora_paths=None, max_loaded_loras=None, max_loras_per_batch=8, lora_eviction_policy='lru', lora_backend='csgmv', max_lora_chunk_size=16, attention_backend='fa3', decode_attention_backend=None, prefill_attention_backend=None, sampling_backend='flashinfer', grammar_backend='xgrammar', mm_attention_backend=None, nsa_prefill_backend='flashmla_sparse', nsa_decode_backend='fa3', speculative_algorithm=None, speculative_draft_model_path=None, speculative_draft_model_revision=None, speculative_draft_load_format=None, speculative_num_steps=None, speculative_eagle_topk=None, speculative_num_draft_tokens=None, speculative_accept_threshold_single=1.0, speculative_accept_threshold_acc=1.0, speculative_token_map=None, speculative_attention_mode='prefill', speculative_moe_runner_backend=None, speculative_ngram_min_match_window_size=1, speculative_ngram_max_match_window_size=12, speculative_ngram_min_bfs_breadth=1, speculative_ngram_max_bfs_breadth=10, speculative_ngram_match_type='BFS', 
speculative_ngram_branch_length=18, speculative_ngram_capacity=10000000, ep_size=1, moe_a2a_backend='none', moe_runner_backend='auto', flashinfer_mxfp4_moe_precision='default', enable_flashinfer_allreduce_fusion=False, deepep_mode='auto', ep_num_redundant_experts=0, ep_dispatch_algorithm='static', init_expert_location='trivial', enable_eplb=False, eplb_algorithm='auto', eplb_rebalance_num_iterations=1000, eplb_rebalance_layers_per_chunk=None, eplb_min_rebalancing_utilization_threshold=1.0, expert_distribution_recorder_mode=None, expert_distribution_recorder_buffer_size=1000, enable_expert_distribution_metrics=False, deepep_config=None, moe_dense_tp_size=None, elastic_ep_backend=None, mooncake_ib_device=None, max_mamba_cache_size=None, mamba_ssm_dtype='float32', mamba_full_memory_ratio=0.9, enable_hierarchical_cache=False, hicache_ratio=2.0, hicache_size=0, hicache_write_policy='write_through', hicache_io_backend='kernel', hicache_mem_layout='layer_first', hicache_storage_backend=None, hicache_storage_prefetch_policy='best_effort', hicache_storage_backend_extra_config=None, enable_lmcache=False, kt_weight_path=None, kt_method=None, kt_cpuinfer=None, kt_threadpool_count=None, kt_num_gpu_experts=None, kt_max_deferred_experts_per_token=None, enable_double_sparsity=False, ds_channel_config_path=None, ds_heavy_channel_num=32, ds_heavy_token_num=256, ds_heavy_channel_type='qk', ds_sparse_decode_threshold=4096, cpu_offload_gb=0, offload_group_size=-1, offload_num_in_group=1, offload_prefetch_step=1, offload_mode='cpu', multi_item_scoring_delimiter=None, disable_radix_cache=False, cuda_graph_max_bs=4, cuda_graph_bs=[1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256], disable_cuda_graph=False, disable_cuda_graph_padding=False, enable_profile_cuda_graph=False, enable_cudagraph_gc=False, enable_layerwise_nvtx_marker=False, enable_nccl_nvls=False, enable_symm_mem=False, 
disable_flashinfer_cutlass_moe_fp4_allgather=False, enable_tokenizer_batch_encode=False, disable_tokenizer_batch_decode=False, disable_outlines_disk_cache=False, disable_custom_all_reduce=False, enable_mscclpp=False, enable_torch_symm_mem=False, disable_overlap_schedule=False, enable_mixed_chunk=False, enable_dp_attention=False, enable_dp_lm_head=False, enable_two_batch_overlap=False, enable_single_batch_overlap=False, tbo_token_distribution_threshold=0.48, enable_torch_compile=False, enable_piecewise_cuda_graph=False, enable_torch_compile_debug_mode=False, torch_compile_max_bs=32, piecewise_cuda_graph_max_tokens=4096, piecewise_cuda_graph_tokens=[4, 8, 12, 16, 20, 24, 28, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 288, 320, 352, 384, 416, 448, 480, 512, 640, 768, 896, 1024, 1152, 1280, 1408, 1536, 1664, 1792, 1920, 2048, 2176, 2304, 2432, 2560, 2688, 2816, 2944, 3072, 3200, 3328, 3456, 3584, 3712, 3840, 3968, 4096], piecewise_cuda_graph_compiler='eager', torchao_config='', enable_nan_detection=False, enable_p2p_check=False, triton_attention_reduce_in_fp32=False, triton_attention_num_kv_splits=8, triton_attention_split_tile_size=None, num_continuous_decode_steps=1, delete_ckpt_after_loading=False, enable_memory_saver=False, enable_weights_cpu_backup=False, enable_draft_weights_cpu_backup=False, allow_auto_truncate=False, enable_custom_logit_processor=False, flashinfer_mla_disable_ragged=False, disable_shared_experts_fusion=False, disable_chunked_prefix_cache=False, disable_fast_image_processor=False, keep_mm_feature_on_device=False, enable_return_hidden_states=False, scheduler_recv_interval=1, numa_node=None, enable_deterministic_inference=False, rl_on_policy_target=None, enable_attn_tp_input_scattered=False, enable_nsa_prefill_context_parallel=False, enable_dynamic_batch_tokenizer=False, dynamic_batch_tokenizer_batch_size=32, dynamic_batch_tokenizer_batch_timeout=0.002, debug_tensor_dump_output_folder=None, debug_tensor_dump_layers=None, 
debug_tensor_dump_input_file=None, debug_tensor_dump_inject=False, disaggregation_mode='null', disaggregation_transfer_backend='mooncake', disaggregation_bootstrap_port=8998, disaggregation_decode_tp=None, disaggregation_decode_dp=None, disaggregation_prefill_pp=1, disaggregation_ib_device=None, disaggregation_decode_enable_offload_kvcache=False, num_reserved_decode_tokens=512, disaggregation_decode_polling_interval=1, custom_weight_loader=[], weight_loader_disable_mmap=False, remote_instance_weight_loader_seed_instance_ip=None, remote_instance_weight_loader_seed_instance_service_port=None, remote_instance_weight_loader_send_weights_group_ports=None, enable_pdmux=False, pdmux_config_path=None, sm_group_num=8, mm_max_concurrent_calls=32, mm_per_request_timeout=10.0, enable_broadcast_mm_inputs_process=False, decrypted_config_file=None, decrypted_draft_config_file=None, mm_enable_dp_encoder=False, hooks=None)
[2025-11-23 21:42:11] INFO utils.py:148: Note: detected 112 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
[2025-11-23 21:42:11] INFO utils.py:151: Note: NumExpr detected 112 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 16.
[2025-11-23 21:42:11] INFO utils.py:164: NumExpr defaulting to 16 threads.
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
Loading safetensors checkpoint shards: 0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 4.83it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00, 4.82it/s]
Capturing batches (bs=1 avail_mem=76.22 GB): 100%|██████████| 20/20 [00:01<00:00, 17.83it/s]
Non-streaming Synchronous Generation
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = {"temperature": 0.8, "top_p": 0.95}
outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
    print("===============================")
    print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
===============================
Prompt: Hello, my name is
Generated text: Sarah and I'm from the United States. I'm here to share with you what my life has been like so far and what I have learned throughout this wonderful journey. As I've said before, I have had a great time in China. I have been here for over a year now. The school I'm studying is in Nantong. It's a beautiful city. It's home to some amazing places and people. On the first day of school, we were learning to get along with different cultures. We were learning to be more open to new ideas. I think that was very important. It's not easy to be open
===============================
Prompt: The president of the United States is
Generated text: a title given to the highest official of the government in the executive branch of the federal government of the United States. It is the most senior of the two branches of the government. The vice president is the second most senior official of the government.
The president is elected to a two-year term, except when it is determined that the incumbent president will resign, be removed from office, or is unable to discharge the powers of the office due to a medical condition. If the term expires, the president is eligible to reelect by a 30-day period in which they must receive a majority vote from the electors within the state they are
===============================
Prompt: The capital of France is
Generated text: _________. A. Paris B. Paris C. Paris D. Paris
Paris is the capital of France. So the correct answer is A. Paris.
The capital city of France is Paris, which is also known as the City of Light and the capital of France. It is a bustling city with a rich history and culture. Paris is famous for its iconic landmarks such as the Eiffel Tower, Louvre Museum, and Notre-Dame Cathedral, which have been a part of the city's identity since its founding as the capital of France.
Option B and C are not correct as they refer to other cities. Option
===============================
Prompt: The future of AI is
Generated text: here: it's the end of your time
In this video, we take a look at the key drivers behind this change, including advances in machine learning, quantum computing, and robotics.
• 1 of 4
## 10 key drivers for AI
In this video, we take a look at the key drivers behind the rapid pace of change in AI. The latest developments in machine learning, quantum computing, and robotics are driving this change.
• 1 of 4
## Machine learning is reshaping AI
With the ever increasing volume of data that's being generated, it's time for a new way to transform
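The sampling_params dict above sets temperature and top_p. As a rough guide to what these parameters conventionally mean, here is a plain-Python sketch of temperature scaling followed by top-p (nucleus) filtering over made-up logits for a four-token vocabulary. It is not SGLang's sampler, which runs on the GPU (the startup log reports sampling_backend='flashinfer'), and it does not handle temperature=0. The names nucleus and sample are hypothetical.

```python
import math
import random


def nucleus(logits, temperature, top_p):
    """Return [(token_id, prob)] kept after temperature scaling and top-p filtering."""
    scaled = [x / temperature for x in logits]  # temperature < 1 sharpens, > 1 flattens
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]  # softmax
    kept, cumulative = [], 0.0
    # Keep the most likely tokens until their combined probability reaches top_p.
    for token_id in sorted(range(len(probs)), key=lambda i: probs[i], reverse=True):
        kept.append(token_id)
        cumulative += probs[token_id]
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    return [(i, probs[i] / mass) for i in kept]  # renormalize inside the nucleus


def sample(logits, temperature, top_p, rng):
    """Draw one token id from the nucleus."""
    candidates = nucleus(logits, temperature, top_p)
    ids = [i for i, _ in candidates]
    weights = [p for _, p in candidates]
    return rng.choices(ids, weights=weights, k=1)[0]


toy_logits = [2.0, 1.0, 0.1, -1.0]  # made-up scores for a 4-token vocabulary
print(nucleus(toy_logits, temperature=0.8, top_p=0.95))  # keeps tokens 0, 1, 2
print(nucleus(toy_logits, temperature=0.2, top_p=0.9))  # sharper: only token 0 survives
print(sample(toy_logits, 0.8, 0.95, random.Random(0)))
```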
Streaming Synchronous Generation
prompts = [
    "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
    "Provide a concise factual statement about France’s capital city. The capital of France is",
    "Explain possible future trends in artificial intelligence. The future of AI is",
]
sampling_params = {
    "temperature": 0.2,
    "top_p": 0.9,
}
print("\n=== Testing synchronous streaming generation with overlap removal ===\n")
for prompt in prompts:
    print(f"Prompt: {prompt}")
    merged_output = stream_and_merge(llm, prompt, sampling_params)
    print("Generated text:", merged_output)
    print()
=== Testing synchronous streaming generation with overlap removal ===
Prompt: Write a short, neutral self-introduction for a fictional character. Hello, my name is
Generated text: [Name], and I'm a [job title] at [company name]. I'm excited to meet you and learn more about you. What can you tell me about yourself? I'm a [age], [gender], [nationality], [occupation], and I have [number] years of experience in [field of work]. I'm always looking for new challenges and opportunities to grow and learn. What do you do for a living? I'm always looking for new challenges and opportunities to grow and learn. What do you enjoy doing? I enjoy [job title], and I'm always looking for new challenges and opportunities to grow and
Prompt: Provide a concise factual statement about France’s capital city. The capital of France is
Generated text: Paris, known for its iconic landmarks such as the Eiffel Tower, Notre-Dame Cathedral, and the Louvre Museum. It is also home to the French Parliament, the French Academy of Sciences, and the French Quarter. Paris is a bustling city with a rich cultural heritage and is a popular tourist destination. It is also known for its cuisine, including its famous croissants and its famous French fries. The city is home to many famous French artists, including Picasso and Van Gogh, and is a major center for the arts. Paris is a city of contrasts, with its modern architecture and historical landmarks blending seamlessly. It is
Prompt: Explain possible future trends in artificial intelligence. The future of AI is
Generated text: likely to be characterized by rapid advancements in several key areas, including:
1. Increased integration with human intelligence: AI is likely to become more integrated with human intelligence, allowing machines to learn and adapt to human behavior and preferences.
2. Enhanced machine learning: Machine learning algorithms will become even more sophisticated, allowing AI systems to learn from data and make more accurate predictions and decisions.
3. Improved natural language processing: Natural language processing will become even more advanced, allowing AI systems to understand and respond to human language in ways that are more intuitive and natural.
4. Increased use of AI in healthcare: AI will be used to improve the accuracy and
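The banner above mentions overlap removal. We have not reproduced the source of sglang.utils.stream_and_merge here; the following is only a sketch of the suffix/prefix overlap-trimming idea, applied to a hard-coded list of made-up chunks instead of a live stream. trim_overlap and merge_chunks are hypothetical names.

```python
def trim_overlap(merged, chunk):
    """Drop the longest prefix of `chunk` that already ends `merged`."""
    for size in range(min(len(merged), len(chunk)), 0, -1):
        if merged.endswith(chunk[:size]):
            return chunk[size:]
    return chunk


def merge_chunks(chunks):
    """Accumulate streamed chunks, appending only the non-overlapping tail."""
    merged = ""
    for chunk in chunks:
        merged += trim_overlap(merged, chunk)
    return merged


# Made-up chunks where each one repeats the end of the previous output.
print(merge_chunks(["Paris", "Paris is", " is the capital", "capital of France."]))
# -> Paris is the capital of France.
print(merge_chunks(["ha", "ha"]))  # caveat: genuine repetition is swallowed too
# -> ha
```

The heuristic cannot tell an echoed chunk from text the model really repeats, which the second call shows; that trade-off is inherent to trimming by string overlap alone.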
Non-streaming Asynchronous Generation
prompts = [
    "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
    "Provide a concise factual statement about France’s capital city. The capital of France is",
    "Explain possible future trends in artificial intelligence. The future of AI is",
]
sampling_params = {"temperature": 0.8, "top_p": 0.95}
print("\n=== Testing asynchronous batch generation ===")
async def main():
    outputs = await llm.async_generate(prompts, sampling_params)
    for prompt, output in zip(prompts, outputs):
        print(f"\nPrompt: {prompt}")
        print(f"Generated text: {output['text']}")
asyncio.run(main())
=== Testing asynchronous batch generation ===
Prompt: Write a short, neutral self-introduction for a fictional character. Hello, my name is
Generated text: [insert first name] and I'm a [insert profession or career] who have a strong passion for [insert something about your career or interests]. I have always been interested in learning more about the world and have always been drawn to [insert what you like to do]. I enjoy [insert why you enjoy doing what you do] and I strive to [insert what you plan to do next]. What’s your name? How do you get started? Here’s how I get started: [insert how you get started] I hope this short, neutral self-introduction is a good start. You can expand on your personality and interests by
Prompt: Provide a concise factual statement about France’s capital city. The capital of France is
Generated text: Paris, the world’s third largest city and the largest metropolitan area in the European Union. It is also the seat of government, of the French Government, of the European Parliament, of the French Chamber of Deputies, of the French Senate, and of the French Supreme Council for Foreign Affairs. Paris is the oldest capital city in the world and is the birthplace of modern art, cinema and French literature, with many famous historical and cultural landmarks. It is also the cultural and economic capital of France, attracting millions of visitors every year. Paris is home to many cultural institutions and museums, including the Louvre Museum and the Musée d
Prompt: Explain possible future trends in artificial intelligence. The future of AI is
Generated text: likely to be characterized by rapid advancements and increasing integration into various fields. Here are some possible trends that are likely to shape the future of AI:
1. Increased focus on AI ethics: With the increasing amount of data generated by AI, there is a growing concern about the impact of AI on society. Therefore, it is likely that ethical considerations will become increasingly important in the development and deployment of AI systems.
2. Increased integration of AI with other technologies: AI is already integrated into a variety of technologies, such as smartphones, smart homes, and autonomous vehicles. It is likely that this integration will continue as more technologies become integrated with AI.
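The cell above passes the whole prompt list to one async_generate call. When prompts arrive one at a time (for example in a custom server), a common asyncio pattern is to issue one call per prompt and await them together with asyncio.gather. This sketch uses a hypothetical FakeAsyncEngine in place of llm and assumes the real engine accepts concurrent async_generate calls with a single prompt string; the in-flight counter only exists to show that the calls overlap.

```python
import asyncio


class FakeAsyncEngine:
    """Hypothetical stand-in for sgl.Engine that records request concurrency."""

    def __init__(self):
        self.in_flight = 0
        self.max_in_flight = 0

    async def async_generate(self, prompt, sampling_params=None):
        self.in_flight += 1
        self.max_in_flight = max(self.max_in_flight, self.in_flight)
        await asyncio.sleep(0.01)  # pretend to wait on the scheduler
        self.in_flight -= 1
        return {"text": f"<completion for {prompt!r}>"}


async def generate_concurrently(engine, prompts, sampling_params):
    # One coroutine per prompt; gather runs them together and keeps input order.
    return await asyncio.gather(*(engine.async_generate(p, sampling_params) for p in prompts))


engine = FakeAsyncEngine()
demo_prompts = ["Hello, my name is", "The capital of France is", "The future of AI is"]
demo_outputs = asyncio.run(
    generate_concurrently(engine, demo_prompts, {"temperature": 0.8, "top_p": 0.95})
)
for prompt, output in zip(demo_prompts, demo_outputs):
    print(prompt, "->", output["text"])
print("max concurrent requests:", engine.max_in_flight)
```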
Streaming Asynchronous Generation
prompts = [
    "Write a short, neutral self-introduction for a fictional character. Hello, my name is",
    "Provide a concise factual statement about France’s capital city. The capital of France is",
    "Explain possible future trends in artificial intelligence. The future of AI is",
]
sampling_params = {"temperature": 0.8, "top_p": 0.95}
print("\n=== Testing asynchronous streaming generation (no repeats) ===")
async def main():
    for prompt in prompts:
        print(f"\nPrompt: {prompt}")
        print("Generated text: ", end="", flush=True)
        # Replace direct calls to async_generate with our custom overlap-aware version
        async for cleaned_chunk in async_stream_and_merge(llm, prompt, sampling_params):
            print(cleaned_chunk, end="", flush=True)
        print()  # New line after each prompt
asyncio.run(main())
=== Testing asynchronous streaming generation (no repeats) ===
Prompt: Write a short, neutral self-introduction for a fictional character. Hello, my name is
Generated text: __________ and I'm a/an _____________________.
I'm excited to meet you!
As an AI language model, I'm here to provide information and assist you with any questions you may have. How can I help you today?
I'm happy to introduce myself as an AI language model. My name is Caffeinated AI and I'm here to help you with any questions or concerns you may have. How can I assist you today?
I'm confident that I can provide you with useful information and answer any questions you may have. Let me know if there's anything specific you'd like to know or if you have any other
Prompt: Provide a concise factual statement about France’s capital city. The capital of France is
Generated text: Paris. Its population is about 2.7 million, and it is the largest city in Europe. Paris is known for its rich history, famous landmarks, and annual Parisian festivals. It is also the seat of government and culture for much of France. The city is characterized by its stunning architecture, vibrant arts scene, and cultural diversity. It is often referred to as the "City of Light" due to its numerous cinemas, theaters, and nightclubs. Paris is a major international hub for fashion, entertainment, and technology, and has played a significant role in shaping French and global culture. It is also the birthplace of Napoleon
Prompt: Explain possible future trends in artificial intelligence. The future of AI is
Generated text: likely to be characterized by several key trends, including:
1. Increased automation: AI is expected to become more integrated into our daily lives, and will likely automate many tasks that are currently done by humans. This may include tasks such as logistics, manufacturing, and healthcare, which are currently done by people. However, it's also possible that AI will also be used to automate certain jobs that are repetitive or can be done by machines, thus freeing up more human time for other tasks.
2. Improved privacy and security: As AI becomes more integrated into our daily lives, there is a risk that it may also be used to collect and analyze
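The loop above streams the prompts one after another. To stream several prompts at once, one option is to run one consumer per prompt under asyncio.gather and buffer each prompt's text, because printing chunks as they arrive would interleave the outputs. In the notebook you would pass async_stream_and_merge(llm, prompt, sampling_params) as the stream, assuming the engine serves concurrent streams; here a hypothetical fake_stream async generator replaces it so the sketch runs on its own, and collect and stream_all are likewise hypothetical helpers.

```python
import asyncio


async def fake_stream(pieces):
    """Hypothetical stand-in for async_stream_and_merge(llm, prompt, sampling_params)."""
    for piece in pieces:
        await asyncio.sleep(0)  # yield control, as a real stream would between chunks
        yield piece


async def collect(prompt, stream):
    # Buffer one prompt's chunks so concurrent streams do not interleave on screen.
    parts = []
    async for piece in stream:
        parts.append(piece)
    return prompt, "".join(parts)


async def stream_all(jobs):
    results = await asyncio.gather(
        *(collect(prompt, fake_stream(pieces)) for prompt, pieces in jobs.items())
    )
    return dict(results)


jobs = {
    "The capital of France is": [" Paris", "."],
    "The future of AI is": [" hard", " to", " predict", "."],
}
for prompt, text in asyncio.run(stream_all(jobs)).items():
    print(f"{prompt}{text}")
```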
llm.shutdown()
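If a cell raises before it reaches llm.shutdown(), the engine stays alive and keeps holding GPU memory. One way to guard against that is a small helper that wraps creation and shutdown in try/finally. The sketch below uses a hypothetical FakeEngine so it runs anywhere; engine_session is our own helper, not an SGLang API, and with the real library you would pass sgl.Engine as the factory together with its keyword arguments (for example model_path).

```python
from contextlib import contextmanager


class FakeEngine:
    """Hypothetical stand-in for sgl.Engine that tracks whether it was shut down."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.is_shut_down = False

    def generate(self, prompts, sampling_params=None):
        return [{"text": "..."} for _ in prompts]

    def shutdown(self):
        self.is_shut_down = True


@contextmanager
def engine_session(factory, **kwargs):
    """Create an engine and guarantee shutdown() runs, even if the body raises."""
    engine = factory(**kwargs)
    try:
        yield engine
    finally:
        engine.shutdown()


with engine_session(FakeEngine, model_path="qwen/qwen2.5-0.5b-instruct") as llm_demo:
    print(llm_demo.generate(["The capital of France is"]))
print("shut down:", llm_demo.is_shut_down)
```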