How to Support a New Language Model
To support a new model in SGLang, you only need to add a single file under the SGLang Models Directory. You can learn from existing model implementations and create a new file for your model. For most models, you should be able to find a similar model to start with (e.g., starting from Llama). Also refer to the section Port a Model from vLLM to SGLang below.
How to Support a New Multimodal Large Language Model
To support a new multimodal large language model (MLLM) in SGLang, there are several key components in addition to the standard LLM support:
- Register your new model as multimodal: Extend is_multimodal_model in model_config.py to return True for your model.
- Register a new chat template: Only when your default chat template is unable to accept images as input, register a new chat template in conversation.py and the corresponding matching function.
- Multimodal Data Processor: Define a new Processor class that inherits from BaseMultimodalProcessor and register this processor as your model's dedicated processor. See multimodal_processor.py for more details.
- Handle Multimodal Tokens: Implement a pad_input_ids function for your new model (see the sketch after this list). In this function, multimodal tokens in the prompt should be expanded (if necessary) and padded with multimodal-data hashes so that SGLang can recognize different multimodal data with RadixAttention.
- Handle Image Feature Extraction: Implement a get_image_feature function for your new model, which extracts image features from raw image data and converts them into the embeddings used by the language model.
- Adapt to Vision Attention: Adapt the multi-headed Attention of ViT with SGLang's VisionAttention.
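The real pad_input_ids is a method on your model class and operates on SGLang's multimodal input structures; the standalone sketch below only illustrates the idea, and names such as IMAGE_TOKEN_ID and TOKENS_PER_IMAGE are hypothetical: each image placeholder is expanded and filled with hash-derived ids so that RadixAttention can tell different images apart.

```python
from typing import List

IMAGE_TOKEN_ID = 32000    # hypothetical placeholder token id for <image>
TOKENS_PER_IMAGE = 576    # hypothetical number of vision tokens per image


def pad_input_ids(input_ids: List[int], image_hashes: List[int]) -> List[int]:
    """Expand each image placeholder and fill it with hash-derived ids so the
    radix cache can tell different images apart."""
    out: List[int] = []
    image_index = 0
    for token in input_ids:
        if token == IMAGE_TOKEN_ID and image_index < len(image_hashes):
            h = image_hashes[image_index]
            # One padding id per vision token, derived from the image's hash.
            out.extend((h + offset) % (1 << 30) for offset in range(TOKENS_PER_IMAGE))
            image_index += 1
        else:
            out.append(token)
    return out
```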
Testing and Debugging
Please note all your testing and benchmarking results in the PR description.
Interactive Debugging
For interactive debugging, compare the outputs of Hugging Face/Transformers and SGLang. The following two commands should give the same text output and very similar prefill logits:
- Get the reference output:
Command
- Get the SGLang output:
Command
Add the Model to the Test Suite
To ensure the new model is well maintained, add it to the test suite by including it in the ALL_OTHER_MODELS list in the test_generation_models.py file, test the new model on your local machine, and report the results on demonstrative benchmarks (GSM8K, MMLU, MMMU, MMMU-Pro, etc.) in your PR.
For VLMs, also include a test in test_vision_openai_server_{x}.py (e.g. test_vision_openai_server_a.py).
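For orientation, an entry in ALL_OTHER_MODELS is keyed by the HuggingFace model id; the exact entry format (plain strings vs. wrapper dataclasses) varies across SGLang versions, so treat the names below as placeholders.

```python
# In test/srt/models/test_generation_models.py (sketch; check the current file
# for the real entry format).
ALL_OTHER_MODELS = [
    "meta-llama/Llama-3.1-8B-Instruct",   # existing entries...
    "my-org/my-new-model",                # hypothetical id of the new model
]
```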
Here is an example command for testing a new model on your local machine:
Run Test
Benchmark
- (Required) MMMU: Follow the MMMU benchmark README.md to get an SGLang vs. HF Transformers accuracy comparison. The accuracy score from the SGLang run should not be much lower than that from the HF Transformers run. Similarly, follow https://docs.sglang.io/developer_guide/benchmark_and_profiling.html to get a performance comparison: TTFT and throughput must meet or exceed the baselines (e.g., HF Transformers).
- (Optional) Other evals: If you ran other evals, please note the results in the PR description.
Port a Model from vLLM to SGLang
The vLLM Models Directory is a valuable resource, as vLLM covers many models. SGLang reuses vLLM's interface and some layers, making it easier to port models from vLLM to SGLang. To port a model from vLLM to SGLang:
- Compare these two files for guidance:
- The major differences include:
  - Replace vLLM's Attention with RadixAttention (ensure you pass layer_id to RadixAttention; see the sketch after this list).
  - Replace vLLM's LogitsProcessor with SGLang's LogitsProcessor.
  - Replace the multi-headed Attention of ViT with SGLang's VisionAttention.
  - Replace other vLLM layers (such as RMSNorm, SiluAndMul) with SGLang layers.
  - Remove Sample.
  - Change the forward() functions and add a forward_batch() method.
  - Add EntryClass at the end.
  - Ensure that the new implementation uses only SGLang components and does not rely on any vLLM components.
Registering an External Model Implementation
In addition to the methods above, you can register your new model with the ModelRegistry before launching the server.
This allows you to integrate your model without modifying the source code.
For example:
Register Model
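As a minimal sketch (the import path of ModelRegistry and the module my_package.my_model are assumptions; adjust them to your setup), registration simply maps an architecture name to your class before the server starts:

```python
from sglang.srt.models.registry import ModelRegistry

# MyModelForCausalLM is your own implementation (hypothetical module path);
# the key must match the "architectures" entry in the model's config.json.
from my_package.my_model import MyModelForCausalLM

ModelRegistry.models["MyModelForCausalLM"] = MyModelForCausalLM
```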
Example: Implementing and Serving a Llama Wrapper Model
Below is an introductory, step-by-step walkthrough on how to implement a new model end-to-end in SGLang and then run it via the Offline Engine.
Implementing Our Model
To keep things simple, this new model will be a thin wrapper around Llama 3.1-8B-Instruct, and our goal will simply be to bias the output logits on each forward call by taking the square root of each individual logit.
Let’s start by defining our model in a file called llama_wrapper.py.
The first step is to import the necessary libraries from SRT, which is SGLang’s internal backend.
Example
Next, we define a class for our model and have it inherit from LlamaForCausalLM, which allows our model to access LlamaForCausalLM's predefined modules and layers, such as LlamaAttention and LlamaMLP.
Note that almost all model implementations take in config and quant_config as arguments for their __init__ method; config and quant_config are passed in via model_loader/loader.py.
Because we have inherited from LlamaForCausalLM, we can pass our parameters directly to its constructor, which will set the member variables for us.
Class Definition
Next, we define the forward method, which is what will be called at inference time.
Note that the signature for forward is essentially the same for any model; you can take a look at the other models defined in the models directory for references.
To see where exactly forward is called in the SGLang runtime’s internals, take a look at forward_decode and forward_extend in the ModelRunner class.
Forward Method Signature
Inside forward, we first invoke the __call__ method of self.model (a member variable that LlamaForCausalLM defines in its __init__ method), which eventually calls LlamaModel's forward method.
After that, we feed the hidden_states into our model’s LogitsProcessor (again defined in LlamaForCausalLM).
Call Model and LogitsProcessor
Logit Biasing
With that, our LlamaWrapper model is created and ready to be served!
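For reference, here is a consolidated, hedged sketch of what llama_wrapper.py could look like end to end. It is not the exact code from the blocks above: in particular, it takes the square root of the logit magnitudes while keeping the sign (to avoid NaNs on negative logits), and the import paths and LogitsProcessor output fields may differ between SGLang versions.

```python
from typing import Optional

import torch
from transformers import LlamaConfig

from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM


class LlamaWrapper(LlamaForCausalLM):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        # LlamaForCausalLM sets up self.model, self.lm_head, and
        # self.logits_processor for us.
        super().__init__(config=config, quant_config=quant_config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        # Run the underlying LlamaModel to get the hidden states.
        hidden_states = self.model(input_ids, positions, forward_batch)
        # Compute logits with the LogitsProcessor defined by LlamaForCausalLM.
        out = self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )
        # Bias the logits: a plain square root is undefined for negative values,
        # so this sketch keeps the sign and takes the root of the magnitude.
        logits = out.next_token_logits
        out.next_token_logits = torch.sign(logits) * torch.sqrt(torch.abs(logits))
        return out
```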
Serving Our Model Via SGLang’s Offline Engine
The next step of this walkthrough involves hosting our new model offline, so that it can be served locally and without an HTTP server. First, create a new file called run.py.
Now, we must ensure that SGLang’s ModelRegistry can find our model.
To do this, we first download the model’s configuration and weights from Huggingface.
Example
Next, we indicate that we want to use LlamaWrapper by changing the architectures field in ./llama_ckpt/config.json to be LlamaWrapper.
That way, when we pass in the path of our model checkpoint to SGLang, it will know that we want to use “LlamaWrapper” instead of “LlamaForCausalLM” as our model.
Example
However, if we don't map our LlamaWrapper class to the "LlamaWrapper" registry keyword, SGLang won't be able to find our model.
Thus, to register our LlamaWrapper, we want to follow the steps in the above section titled “Registering an External Model Implementation”.
Register LlamaWrapper
When we construct the Engine, we just pass in the path to the local model directory.
Then, our LlamaWrapper is ready to be served; for this walkthrough, we will use SGLang Engine’s non-streaming asynchronous generation endpoint.
Example
Finally, when we run python run.py, we will get the outputs of our newly created model!
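Putting the pieces of run.py together, here is a hedged end-to-end sketch; the registry import path, Engine arguments, and output format may differ across SGLang versions.

```python
# run.py -- sketch of serving the wrapper via the Offline Engine.
import asyncio

import sglang as sgl
from sglang.srt.models.registry import ModelRegistry

from llama_wrapper import LlamaWrapper

# Map the "LlamaWrapper" architecture name (from config.json) to our class.
ModelRegistry.models["LlamaWrapper"] = LlamaWrapper


async def main():
    # Point the engine at the local checkpoint whose config.json names LlamaWrapper.
    engine = sgl.Engine(model_path="./llama_ckpt")
    prompt = "The capital of France is"
    sampling_params = {"temperature": 0.8, "max_new_tokens": 32}
    # Non-streaming asynchronous generation.
    output = await engine.async_generate(prompt, sampling_params)
    print(output["text"])
    engine.shutdown()


if __name__ == "__main__":
    asyncio.run(main())
```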
Serving External Models via the Standard CLI
The previous sections show how to register a model programmatically via ModelRegistry and serve it through the Offline Engine. Similar to vLLM's model plugins, there is an alternative that lets you keep using the standard python -m sglang.launch_server CLI without modifying any SGLang source code: you can register your model using the SGLANG_EXTERNAL_MODEL_PACKAGE environment variable.
The EntryClass Variable
When SGLang scans a model package, it looks for the variable EntryClass at the module level of your Python file. The model registry imports your file, checks for EntryClass, and registers the class assigned to it. If you are using a model based on HuggingFace, the name of this class needs to match the "architectures" field in your model’s config.json.
For example, if you are implementing a Llama wrapper, add this line at the end of your model file:
Example
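Concretely, for the Llama wrapper the last line of the model file would look like this:

```python
# At the end of llama_wrapper.py: the registry scans for this module-level name.
EntryClass = LlamaWrapper
```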
Example: Text-Only Model
Using the same Llama wrapper from the previous section, here is how to package and serve it via the CLI.
- Create your project. In setup.py:
Example
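A minimal sketch of what setup.py might contain (your real packaging setup, e.g. a pyproject.toml or extra dependencies, may differ):

```python
# sglang_custom_project/setup.py
from setuptools import find_packages, setup

setup(
    name="custom_llm",
    version="0.1.0",
    # Picks up the custom_llm/ package containing llama_wrapper.py.
    packages=find_packages(),
)
```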
- Write your model code. In llama_wrapper.py, write your model and include EntryClass:
Example
- Install your package. Run the following from the sglang_custom_project directory to install your code into the active Python environment:
Command
- Update your config.json. Edit the config.json under your HuggingFace model checkpoint directory so the architectures field matches your class name:
Config
- Launch the server:
Command
SGLANG_EXTERNAL_MODEL_PACKAGE should be the parent folder name containing your model-related code. In this example, it should be custom_llm.
Example: Multimodal Model
If you are working with multimodal models, setting SGLANG_EXTERNAL_MODEL_PACKAGE alone is not enough. SGLang also needs to recognize your architecture as multimodal to enable the image/video processing pipelines, and it needs a custom processor.
You can handle this by setting two additional environment variables:
- SGLANG_EXTERNAL_MM_MODEL_ARCH: Adds your architecture name to SGLang's internal list of multimodal models.
- SGLANG_EXTERNAL_MM_PROCESSOR_PACKAGE: Tells SGLang where to find your custom processor class.
In setup.py:
Example
In qwenvl_wrapper.py:
Example
You do not need an EntryClass for the custom processor as long as you associate the processor with the specific model class.
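A heavily hedged sketch of that association is shown below; the base-class import path and the models class attribute are assumptions based on existing SGLang multimodal processors, so check multimodal_processor.py in the version you target.

```python
# Sketch: associate a custom processor with the wrapper model class (placed in
# the package pointed to by SGLANG_EXTERNAL_MM_PROCESSOR_PACKAGE). The import
# path below is approximate and the `models` attribute is an assumption.
from sglang.srt.managers.multimodal_processor import BaseMultimodalProcessor

from custom_vlm.qwenvl_wrapper import QwenVLWrapper  # hypothetical wrapper model class


class QwenVLWrapperProcessor(BaseMultimodalProcessor):
    # SGLang resolves the processor through the model classes listed here.
    models = [QwenVLWrapper]
```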
Install the package, update config.json, and launch:
Command
Config
Command
Documentation
Add your model to the table of supported models in generative_models.md or multimodal_language_models.md.
By following these guidelines, you can add support for new language models and multimodal large language models in SGLang and ensure they are thoroughly tested and easily integrated into the system.
