Overview
SGL-Kernel-NPU is the official operator library provided by the SGLang framework for Ascend NPU. It includes two types of operator implementations:
- Ascend C operators: High-performance C++ kernels written in Ascend C, compiled into libsgl_kernel_npu.so, and registered through PyTorch’s custom operator mechanism (TORCH_LIBRARY_FRAGMENT). Called in SGLang via torch.ops.npu.<op_name>().
- Triton operators: Python kernels written in Triton, adapted for Ascend NPU. Called directly via from sgl_kernel_npu.xxx import ....
On Ascend NPU, SGLang imports sgl_kernel_npu and uses its operators in place of GPU counterparts, providing optimized inference on Ascend hardware.
Directory Structure
Developing Ascend C Operators
A complete Ascend C operator consists of two parts:
- Device part: Kernel code running on the NPU AICore, responsible for actual computation. Written using the Ascend C API.
- Host part: Code running on the CPU, responsible for parameter validation, data pre-processing, tiling, and kernel launch.
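To make the host-side tiling step more concrete, the following is a minimal sketch in plain Python of one common way to split a 1-D workload across cores. The helper name split_evenly and the core count are hypothetical and not part of sgl-kernel-npu; real operators compute tiling in the C++ host code, and the scheme differs per operator.

```python
def split_evenly(total_elements, num_cores):
    """Return one (offset, length) pair per core for a 1-D workload.

    Hypothetical helper for illustration only: the first `remainder`
    cores take one extra element, so lengths differ by at most one.
    """
    if total_elements < 0 or num_cores <= 0:
        raise ValueError("total_elements must be >= 0 and num_cores > 0")
    base, remainder = divmod(total_elements, num_cores)
    tiles = []
    offset = 0
    for core in range(num_cores):
        length = base + (1 if core < remainder else 0)
        tiles.append((offset, length))
        offset += length
    return tiles


tiles = split_evenly(total_elements=10, num_cores=4)
print(tiles)  # [(0, 3), (3, 3), (6, 2), (8, 2)]
```

Each core would then process only its own slice, which is why the host has to decide the split before launching the kernel.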
Step 1: Create the operator directory and files
Create a new operator directory under csrc/, following the op_host/ + op_kernel/ structure:
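As a rough sketch of that layout, the script below creates the two subdirectories for a new operator with one empty source file in each. The operator name my_op and the file names are assumptions for illustration; follow the naming used by existing operators in the repository.

```python
import tempfile
from pathlib import Path


def scaffold_operator(repo_root, op_name):
    """Create csrc/<op_name>/op_host and csrc/<op_name>/op_kernel.

    The placeholder file names below are assumptions for illustration,
    not the names the repository requires.
    """
    op_dir = Path(repo_root) / "csrc" / op_name
    created = []
    for part, suffix in (("op_host", "_host.cpp"), ("op_kernel", "_kernel.cpp")):
        folder = op_dir / part
        folder.mkdir(parents=True, exist_ok=True)
        source = folder / f"{op_name}{suffix}"
        source.touch(exist_ok=True)
        created.append(source.relative_to(repo_root).as_posix())
    return created


# Demonstrate in a throwaway directory instead of a real checkout.
with tempfile.TemporaryDirectory() as tmp:
    print(scaffold_operator(tmp, "my_op"))
```

Run against a real checkout, repo_root would be the root of the sgl-kernel-npu repository.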
Step 2: Write the Device-side Kernel (op_kernel)
Device-side code runs on AICore and follows the Ascend C programming model. The core structure is a class with Init() and Process() methods, plus an
extern "C" entry function.
Using helloworld as an example:
- Class methods must be marked with __aicore__, indicating they run on AICore.
- Use AscendC::TPipe + AscendC::TQue to build a pipeline that overlaps data movement and computation.
- The entry function must be declared extern "C" __global__ __aicore__. The compile tool generates a host-callable launch header aclrtlaunch_<func_name>.h from the function name.
- Simple operators (e.g., helloworld, cache_assign, lora) do not need extra workspace memory. Complex operators (e.g., mla_preprocess, alloc_extend, build_tree) require temporary workspace memory and are compiled separately in CMakeLists.txt.
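The queue-based pipeline can be pictured with a small plain-Python model. This is only a conceptual sketch: the copy-in, compute, and copy-out stage names are an assumption based on the usual Ascend C pattern, the function run_pipeline is hypothetical, and the stages run sequentially here, whereas on hardware the queues allow them to overlap across tiles.

```python
from collections import deque


def run_pipeline(global_input, tile_size):
    """Model a copy-in -> compute -> copy-out pipeline over fixed-size tiles.

    Plain-Python illustration only; it mirrors the data flow between
    queues, not the hardware overlap of the stages.
    """
    in_queue = deque()   # tiles staged in local memory, waiting for compute
    out_queue = deque()  # computed tiles waiting to be written back
    global_output = []

    for start in range(0, len(global_input), tile_size):
        # Copy-in stage: move one tile from "global memory" into the queue.
        in_queue.append(global_input[start:start + tile_size])

        # Compute stage: take a staged tile and produce a result tile.
        tile = in_queue.popleft()
        out_queue.append([value + 1 for value in tile])

        # Copy-out stage: write the result tile back to "global memory".
        global_output.extend(out_queue.popleft())

    return global_output


print(run_pipeline([1, 2, 3, 4, 5], tile_size=2))  # [2, 3, 4, 5, 6]
```

The last tile is shorter than tile_size, which is the kind of tail case a real kernel also has to handle.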
Step 3: Write the Host-side Code (op_host)
Host-side code is responsible for passing PyTorch Tensors to the kernel and launching it. The key macro is EXEC_KERNEL_CMD (located in
csrc/utils/torch_helper.h).
- The namespace must be sglang::npu_kernel.
- Function signatures follow the pattern at::Tensor <op_name>(const at::Tensor &input, ...).
- For operators with multiple outputs, use std::tuple<at::Tensor, at::Tensor, ...> or non-const reference parameters.
Step 4: Declare the C++ Interface (include/sgl_kenel_npu_ops.h)
Add the operator function declaration in include/sgl_kenel_npu_ops.h:
Step 5: Register the PyTorch Custom Operator (pytorch_extensions.cpp)
Register the operator in csrc/pytorch_extensions.cpp in two steps: define the
schema and bind the implementation.
- The namespace is fixed to npu. In SGLang, operators are called via torch.ops.npu.<op_name>().
- Output tensor parameters use the Tensor(a!) mutating annotation.
- Optional parameters use the Tensor? annotation, with c10::optional<T> handling in the impl.
- For detailed schema syntax, see the PyTorch Schema Reference.
- The device name is fixed to PrivateUse1 (PyTorch NPU backend identifier).
- Use the TORCH_FN macro to bind to the implementation function.
- For complex operators with optional parameters, use lambda expressions to unpack the arguments.
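To show how the annotations above combine in one schema string, here is a small Python sketch that formats one. The helper build_schema, the operator name my_op, and its parameters are hypothetical; in the repository the schema is written as a string literal in the C++ registration code.

```python
def build_schema(op_name, params, returns="()"):
    """Format a PyTorch-style operator schema string.

    `params` is a list of (type_annotation, name) pairs. Hypothetical
    helper for illustration; it does not register anything.
    """
    args = ", ".join(f"{annotation} {name}" for annotation, name in params)
    return f"{op_name}({args}) -> {returns}"


schema = build_schema(
    "my_op",  # hypothetical operator name
    [
        ("Tensor", "x"),        # ordinary input tensor
        ("Tensor(a!)", "out"),  # output tensor, mutated in place
        ("Tensor?", "bias"),    # optional tensor, may be absent
    ],
)
print(schema)  # my_op(Tensor x, Tensor(a!) out, Tensor? bias) -> ()
```

An operator registered under the npu namespace with such a schema would be reached from Python as torch.ops.npu.my_op(...).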
Step 6: Update the Build Configuration (csrc/CMakeLists.txt)
Add the new operator’s source files to csrc/CMakeLists.txt:
For operators not requiring workspace (simple operators), add kernel source
to no_workspace_kernel:
For operators requiring workspace (complex operators), add kernel source
to workspace_kernel with the -DHAVE_WORKSPACE -DHAVE_TILING compile flags:
Add the host-side source to OP_SRCS:
Step 7: Build
Build following the steps in python/sgl_kernel_npu/README.md. After the build, libsgl_kernel_npu.so is copied into
python/sgl_kernel_npu/sgl_kernel_npu/lib/ and loaded by the Python package.
Developing Triton Operators
Triton operators are located under python/sgl_kernel_npu/sgl_kernel_npu/,
organized by function category:
- Create a new .py file in the appropriate category subdirectory.
- Write the kernel using the Triton language, using existing operators in the same directory as templates.
- Export functions in the corresponding __init__.py if needed.
- Write tests under tests/python/sgl_kernel_npu/.
When adapting an operator from existing GPU code, note the original source in the file (as fla/utils.py does). Pay special
attention to differences between NPU and GPU when adapting.
Integrating Operators into SGLang
Ascend C Operator Integration
After completing Steps 1-7 above (writing the kernel, registering the torch op, building) and installing the sgl-kernel-npu wheel, call the operator in SGLang as follows:
Example (sglang/srt/speculative/eagle_utils.py):
Triton Operator Integration
Import and call directly via Python:
Example (sglang/srt/models/llama.py):
sgl-kernel-npu Wheel Update Process
Since SGLang and sgl-kernel-npu are separate Python packages, dependency updates require a multi-PR workflow:
- Submit sgl-kernel-npu PR: Add/modify operators in the sgl-kernel-npu repository, ensuring all tests pass.
- Bump sgl-kernel-npu version: Update the version number in sgl-kernel-npu. Merging triggers an automatic PyPI release.
- Reference the new version in SGLang:
  - Update the sgl-kernel-npu version requirement in SGLang’s python/pyproject.toml.
  - Use the new operator in SGLang code.
Writing Unit Tests
Each operator needs a corresponding unit test under tests/python/sgl_kernel_npu/, using Python’s unittest framework.
Test file naming convention: test_<op_name>.py
- Cover typical input shapes (power-of-2 sizes and non-standard sizes).
- Cover different data types (bf16 / fp16, etc.).
- For operators with in-place behavior, verify correctness of output tensors.
- Compare against PyTorch native computation to verify accuracy.
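The checklist above could be laid out roughly as follows. This sketch uses only the standard library so it runs without an NPU: scale_under_test is a stand-in for the operator call, reference_scale stands in for the PyTorch native computation, and to_fp16 emulates half-precision rounding. All three names are hypothetical; a real test would use torch tensors on the NPU device.

```python
import math
import struct
import unittest


def to_fp16(value):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


def reference_scale(values, factor):
    """Reference result computed in full precision."""
    return [v * factor for v in values]


def scale_under_test(values, factor):
    """Stand-in for the operator; a real test would call the NPU op here."""
    return [to_fp16(v * factor) for v in values]


class TestScale(unittest.TestCase):
    def test_matches_reference(self):
        # Mix power-of-2 sizes with non-standard sizes.
        for size in (1, 7, 64, 100, 1024):
            with self.subTest(size=size):
                values = [to_fp16(0.01 * i) for i in range(size)]
                expected = reference_scale(values, 2.0)
                actual = scale_under_test(values, 2.0)
                self.assertEqual(len(actual), len(expected))
                for a, e in zip(actual, expected):
                    self.assertTrue(math.isclose(a, e, rel_tol=1e-3, abs_tol=1e-3))


def run_suite():
    """Run the test case and return the unittest result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestScale)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

Saved as test_scale.py under the tests directory, a file like this matches the test_<op_name>.py naming convention and can be run with python -m unittest.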
Code Style
Pre-commit Checks
sgl-kernel-npu uses pre-commit for consistent code style. If pre-commit run --all-files fails the first time, run it again to
ensure all lint errors are auto-fixed. All code must pass checks before
submitting a PR.
C++ Code Style
- Use the C++17 standard.
- Place all operator implementations under the sglang::npu_kernel namespace.
- Follow existing code style; format using .clang-format.
- Use TORCH_CHECK for error checking (not the standard GE OP_ADD macros).
- Do not include unnecessary GE registration code (e.g., OP_ADD() macros).
Python Code Style
- Follow PEP 8.
- Use snake_case for file and function names.
- When adapting from SGLang GPU code, note the original source in the file header.
General Principles
- Avoid code duplication: extract shared functions for any repeated code blocks over 5 lines.
- Minimize device synchronization: reduce CPU-NPU sync operations like tensor.item() or tensor.cpu().
- Keep functions pure: avoid in-place argument modification.
- Keep files concise: split files exceeding 2,000 lines.
Submitting a PR
- Fork the repo: Fork sgl-kernel-npu on GitHub, then clone locally.
- Create a branch: Create a new branch from main, e.g., feature/add-my-op.
- Develop and test: Develop the operator and write tests following the steps above. Ensure all tests pass.
- Run pre-commit: Ensure code formatting compliance.
- Commit and push:
- Create a PR: Open a Pull Request on GitHub from your branch to sgl-project/sgl-kernel-npu:main.
- Wait for CI and review: CI checks include linting, compilation, and operator tests. After passing, wait for maintainer review and merge.
