Overview

SGL-Kernel-NPU is the official operator library provided by the SGLang framework for Ascend NPU. It includes two types of operator implementations:
  1. Ascend C operators: High-performance C++ kernels written in Ascend C, compiled into libsgl_kernel_npu.so, and registered through PyTorch’s custom operator mechanism (TORCH_LIBRARY_FRAGMENT). SGLang calls them via torch.ops.npu.<op_name>().
  2. Triton operators: Python kernels written in Triton and adapted for Ascend NPU. SGLang imports and calls them directly, e.g., from sgl_kernel_npu.xxx import ...
When SGLang detects an NPU device, it automatically loads sgl_kernel_npu and uses its operators in place of GPU counterparts, providing optimized inference on Ascend hardware.
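SGLang's actual detection code is not shown here. The following is a rough sketch of the pattern: import the kernel package only when NPU support is present, and return None so the caller can keep its default code path. Checking for the torch_npu package is an assumption. A real check may also confirm that a device is actually available, e.g. via torch.npu.is_available(). load_optional_kernels is a hypothetical helper.

```python
import importlib
import importlib.util
from types import ModuleType
from typing import Optional


def load_optional_kernels(device_module: str = "torch_npu",
                          kernel_module: str = "sgl_kernel_npu") -> Optional[ModuleType]:
    """Hypothetical helper: import the NPU kernel package only if NPU support is installed.

    Returns the imported kernel module, or None when either package is missing,
    so the caller can fall back to its default (non-NPU) implementation.
    """
    for name in (device_module, kernel_module):
        # find_spec returns None for a top-level module that is not installed
        if importlib.util.find_spec(name) is None:
            return None
    return importlib.import_module(kernel_module)
```

Note that an installed package does not guarantee an attached device, which is why a runtime availability check is a sensible addition in practice.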

Directory Structure

sgl-kernel-npu/
├── csrc/                          # Ascend C operator C++ sources
│   ├── CMakeLists.txt             # Build configuration
│   ├── pytorch_extensions.cpp     # PyTorch op registration (core integration file)
│   └── <op_name>/                 # One directory per operator
│       ├── op_host/               # Host-side code (validation, tiling, launch)
│       │   ├── <op_name>.cpp
│       │   └── tiling/            # Optional: tiling data
│       └── op_kernel/             # Device-side code (Ascend C kernel on AICore)
│           └── <op_name>_kernel.cpp
├── include/
│   └── sgl_kenel_npu_ops.h        # C++ interface declarations
├── python/
│   └── sgl_kernel_npu/
│       └── sgl_kernel_npu/
│           ├── __init__.py        # Loads libsgl_kernel_npu.so
│           ├── attention/         # Triton attention kernels
│           ├── norm/              # Triton normalization kernels
│           ├── activation/        # Triton activation kernels
│           ├── fla/               # Triton linear attention kernels
│           ├── mamba/             # Triton Mamba kernels
│           ├── moe/               # Triton MoE kernels
│           └── sample/            # Triton speculative decoding kernels
├── tests/
│   └── python/sgl_kernel_npu/     # One test file per operator
├── build.sh                       # Build script
└── CMakeLists.txt                 # Root CMake configuration

Developing Ascend C Operators

A complete Ascend C operator consists of two parts:
  • Device part: Kernel code running on the NPU AICore, responsible for actual computation. Written using the Ascend C API.
  • Host part: Code running on the CPU, responsible for parameter validation, data pre-processing, tiling, and kernel launch.
We recommend starting with the helloworld example, a simple operator that performs element-wise addition on two tensors.

Step 1: Create the operator directory and files

Create a new operator directory under csrc/, following the op_host/ + op_kernel/ structure:
csrc/<op_name>/
├── op_host/
│   └── <op_name>.cpp
└── op_kernel/
    └── <op_name>_kernel.cpp
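A small helper script can create this layout so the file names match what Steps 3 and 6 expect. It is a hypothetical convenience and not part of the repository. scaffold_operator and its snake_case name check are assumptions.

```python
import re
from pathlib import Path


def scaffold_operator(csrc_dir, op_name):
    """Hypothetical helper: create csrc/<op_name>/op_host and op_kernel source files.

    Returns (host_source, kernel_source). Existing files keep their contents.
    """
    # Assumed convention: operator names are snake_case identifiers
    if not re.fullmatch(r"[a-z][a-z0-9_]*", op_name):
        raise ValueError(f"operator name must be snake_case: {op_name!r}")
    op_dir = Path(csrc_dir) / op_name
    host_src = op_dir / "op_host" / f"{op_name}.cpp"
    kernel_src = op_dir / "op_kernel" / f"{op_name}_kernel.cpp"
    for path in (host_src, kernel_src):
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()  # creates an empty file; does not truncate an existing one
    return host_src, kernel_src
```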

Step 2: Write the Device-side Kernel (op_kernel)

Device-side code runs on AICore and follows the Ascend C programming model. The core structure is a class with Init() and Process() methods, plus an extern "C" entry function. Using helloworld as an example:
// csrc/helloworld/op_kernel/helloworld_kernel.cpp
#include "kernel_operator.h"

constexpr int32_t BUFFER_NUM = 2;

class KernelHelloworld {
public:
    __aicore__ inline KernelHelloworld() {}

    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, uint32_t totalLength)
    {
        // Compute workload for current block (integer division: totalLength is assumed to be
        // a multiple of blockNum * tileNum * BUFFER_NUM, otherwise the tail is skipped)
        this->blockLength = totalLength / AscendC::GetBlockNum();
        this->tileNum = 8;
        this->tileLength = this->blockLength / this->tileNum / BUFFER_NUM;

        // Set global memory buffers
        xGm.SetGlobalBuffer((__gm__ half *)x + this->blockLength * AscendC::GetBlockIdx(), this->blockLength);
        yGm.SetGlobalBuffer((__gm__ half *)y + this->blockLength * AscendC::GetBlockIdx(), this->blockLength);
        zGm.SetGlobalBuffer((__gm__ half *)z + this->blockLength * AscendC::GetBlockIdx(), this->blockLength);

        // Initialize pipeline queues
        pipe.InitBuffer(inQueueX, BUFFER_NUM, this->tileLength * sizeof(half));
        pipe.InitBuffer(inQueueY, BUFFER_NUM, this->tileLength * sizeof(half));
        pipe.InitBuffer(outQueueZ, BUFFER_NUM, this->tileLength * sizeof(half));
    }

    __aicore__ inline void Process()
    {
        int32_t loopCount = this->tileNum * BUFFER_NUM;
        for (int32_t i = 0; i < loopCount; i++) {
            CopyIn(i);    // Move data from Global Memory to Local Memory
            Compute(i);   // Compute on Local Memory
            CopyOut(i);   // Move results back to Global Memory
        }
    }

private:
    __aicore__ inline void CopyIn(int32_t progress) { /* data copy-in... */ }
    __aicore__ inline void Compute(int32_t progress) { /* core computation... */ }
    __aicore__ inline void CopyOut(int32_t progress) { /* data copy-out... */ }

private:
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueX, inQueueY;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> outQueueZ;
    AscendC::GlobalTensor<half> xGm, yGm, zGm;
    uint32_t blockLength, tileNum, tileLength;
};

// Entry function: the compile tool auto-generates aclrtlaunch_<op_name>.h from this name
extern "C" __global__ __aicore__ void helloworld(
    GM_ADDR x, GM_ADDR y, GM_ADDR z, uint32_t totalLength)
{
    KernelHelloworld op;
    op.Init(x, y, z, totalLength);
    op.Process();
}
Key points:
  • Class methods must be marked with __aicore__, indicating they run on AICore.
  • Use AscendC::TPipe + AscendC::TQue to build a pipeline that overlaps data movement and computation.
  • The entry function must be declared extern "C" __global__ __aicore__. The compile tool generates a host-callable launch header aclrtlaunch_<func_name>.h from the function name.
  • Simple operators (e.g., helloworld, cache_assign, lora) do not need extra workspace memory. Complex operators (e.g., mla_preprocess, alloc_extend, build_tree) require temporary workspace memory and are compiled separately in CMakeLists.txt.
For more in-depth Ascend C programming knowledge, refer to the Ascend C Kernel Development Guide.
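Because Init() splits the work with integer division, it helps to check the tiling numbers on the host before launching. The Python sketch below replays that arithmetic for the values used in this example (tileNum = 8, BUFFER_NUM = 2, fp16 elements). The 32-byte alignment flag reflects the usual requirement of the basic Ascend C DataCopy interface. Treat it as an assumption to verify against the Ascend C documentation. helloworld_tiling is a hypothetical helper.

```python
BUFFER_NUM = 2  # double buffering, as in the kernel
TILE_NUM = 8    # tiles per block, as set in KernelHelloworld::Init()


def helloworld_tiling(total_length, block_num, elem_bytes=2):
    """Hypothetical helper: replay the integer arithmetic of Init()/Process().

    elem_bytes=2 corresponds to half (float16).
    """
    block_length = total_length // block_num              # elements per AICore block
    tile_length = block_length // TILE_NUM // BUFFER_NUM  # elements per CopyIn/Compute/CopyOut step
    loop_count = TILE_NUM * BUFFER_NUM                    # iterations of the Process() loop
    covered = block_num * loop_count * tile_length        # elements actually processed
    tile_bytes = tile_length * elem_bytes
    return {
        "block_length": block_length,
        "tile_length": tile_length,
        "loop_count": loop_count,
        "dropped": total_length - covered,
        "tile_bytes": tile_bytes,
        # Assumption: basic DataCopy transfers are expected to be 32-byte aligned
        "aligned_32b": tile_bytes % 32 == 0,
    }
```

With the host's blockDim = 8, totalLength must be a multiple of 8 * 8 * 2 = 128 for every element to be processed. A size such as 1000 leaves 104 elements untouched, so a kernel that accepts arbitrary sizes needs explicit tail handling.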

Step 3: Write the Host-side Code (op_host)

Host-side code is responsible for passing PyTorch Tensors to the kernel and launching it. The key macro is EXEC_KERNEL_CMD (located in csrc/utils/torch_helper.h).
// csrc/helloworld/op_host/helloworld.cpp
#include "defines.h"                   // Provides HOST_API macro
#include "torch_helper.h"              // Provides EXEC_KERNEL_CMD macro
#include "aclrtlaunch_helloworld.h"    // Auto-generated by compile tool

namespace sglang {
namespace npu_kernel {

HOST_API at::Tensor helloworld(const at::Tensor &x, const at::Tensor &y)
{
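    // Validate inputs: the kernel reads two contiguous float16 buffers of equal length
    TORCH_CHECK(x.sizes() == y.sizes(), "helloworld: x and y must have the same shape");
    TORCH_CHECK(x.scalar_type() == at::kHalf && y.scalar_type() == at::kHalf,
                "helloworld: only float16 inputs are supported");
    TORCH_CHECK(x.is_contiguous() && y.is_contiguous(), "helloworld: inputs must be contiguous");

    // Create output tensor
    at::Tensor z = at::empty_like(x);

    // Define block count (number of AICore blocks to launch)
    uint32_t blockDim = 8;

    // Compute total element count
    uint32_t totalLength = static_cast<uint32_t>(x.numel());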

    // Launch kernel via EXEC_KERNEL_CMD macro
    EXEC_KERNEL_CMD(helloworld, blockDim, x, y, z, totalLength);
    return z;
}

} // namespace npu_kernel
} // namespace sglang
Key points:
  • The namespace must be sglang::npu_kernel.
  • Function signatures follow the pattern at::Tensor <op_name>(const at::Tensor &input, ...).
  • For operators with multiple outputs, use std::tuple<at::Tensor, at::Tensor, ...> or non-const reference parameters.

Step 4: Declare the C++ Interface (include/sgl_kenel_npu_ops.h)

Add the operator function declaration in include/sgl_kenel_npu_ops.h:
// include/sgl_kenel_npu_ops.h
namespace sglang {
namespace npu_kernel {

at::Tensor helloworld(const at::Tensor &x, const at::Tensor &y);

} // namespace npu_kernel
} // namespace sglang

Step 5: Register the PyTorch Custom Operator (pytorch_extensions.cpp)

Register the operator in csrc/pytorch_extensions.cpp in two steps: define the schema and bind the implementation.
// csrc/pytorch_extensions.cpp
namespace {

// 1. Define operator schema (used by torch.compile, etc.)
TORCH_LIBRARY_FRAGMENT(npu, m)
{
    m.def("helloworld(Tensor x, Tensor y) -> Tensor");
    // ... other operator schemas ...
}

// 2. Bind implementation for the PrivateUse1 device (i.e., NPU)
TORCH_LIBRARY_IMPL(npu, PrivateUse1, m)
{
    m.impl("helloworld", TORCH_FN(sglang::npu_kernel::helloworld));
    // ... other operator implementations ...
}

} // namespace
Schema conventions:
  • The namespace is fixed to npu. In SGLang, operators are called via torch.ops.npu.<op_name>().
  • Tensors that the operator writes in place (e.g., preallocated outputs) use the Tensor(a!) alias annotation.
  • Optional tensor parameters use the Tensor? annotation and arrive in the implementation as c10::optional<at::Tensor>.
  • For detailed schema syntax, see the PyTorch Schema Reference.
Implementation binding rules:
  • The device name is fixed to PrivateUse1 (PyTorch NPU backend identifier).
  • Use the TORCH_FN macro to bind to the implementation function.
  • For complex operators with optional parameters, use lambda expressions to unpack the arguments.
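The sketch below makes the schema conventions concrete. It reads a schema string and lists the C++ parameter type each argument conventionally takes in the bound implementation: Tensor as const at::Tensor &, Tensor(a!) as at::Tensor &, Tensor? as c10::optional<at::Tensor>, int as int64_t, and float as double. cpp_params_from_schema and the my_op schema are hypothetical. The parser handles only the simple argument types listed.

```python
import re

# Conventional C++ parameter types for common schema argument types
SCHEMA_TO_CPP = {
    "Tensor": "const at::Tensor &",
    "Tensor(a!)": "at::Tensor &",
    "Tensor?": "const c10::optional<at::Tensor> &",
    "int": "int64_t",
    "float": "double",
    "bool": "bool",
}

# op_name(<args>) -> <returns>
_SCHEMA_RE = re.compile(r"^\s*(\w+)\((.*)\)\s*->")


def cpp_params_from_schema(schema):
    """Hypothetical helper: return (op_name, [(arg_name, cpp_type), ...])."""
    match = _SCHEMA_RE.match(schema)
    if match is None:
        raise ValueError(f"not a schema string: {schema!r}")
    op_name, arg_list = match.groups()
    params = []
    for arg in (a.strip() for a in arg_list.split(",")):
        if not arg:
            continue  # empty argument list
        arg_type, arg_name = arg.split("=")[0].split()  # drop any default value
        if arg_type not in SCHEMA_TO_CPP:
            raise ValueError(f"unsupported schema type {arg_type!r} in this sketch")
        params.append((arg_name, SCHEMA_TO_CPP[arg_type]))
    return op_name, params
```

For the hypothetical schema "my_op(Tensor x, Tensor(a!) out, Tensor? bias=None, int n) -> ()", this suggests an implementation such as void my_op(const at::Tensor &x, at::Tensor &out, const c10::optional<at::Tensor> &bias, int64_t n). It returns nothing because the schema returns ().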

Step 6: Update the Build Configuration (csrc/CMakeLists.txt)

Add the new operator’s source files to csrc/CMakeLists.txt. For operators that do not require workspace (simple operators), add the kernel source to no_workspace_kernel:
ascendc_library(no_workspace_kernel STATIC
    # ... existing kernel files ...
    ${PROJECT_OP_SRC_BASE}/<op_name>/op_kernel/<op_name>_kernel.cpp
)
For operators that require workspace (complex operators), add the kernel source to workspace_kernel, which is built with the -DHAVE_WORKSPACE and -DHAVE_TILING compile flags:
ascendc_library(workspace_kernel STATIC
    # ... existing kernel files ...
    ${PROJECT_OP_SRC_BASE}/<op_name>/op_kernel/<op_name>_kernel.cpp
)
Add host source files to OP_SRCS:
FILE(GLOB OP_SRCS
    # ... existing host files ...
    ${PROJECT_OP_SRC_BASE}/<op_name>/op_host/<op_name>.cpp
)
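It is easy to forget the CMake entry or to add it to the wrong library. A hypothetical check like the one below reports which ascendc_library target lists a given kernel source. It uses a simple regular expression and assumes the source list contains no parentheses. The example file names in the test are illustrative.

```python
import re

# Matches ascendc_library(<target> STATIC <sources...>); assumes no ')' inside the source list
_LIBRARY_RE = re.compile(r"ascendc_library\(\s*(\w+)\s+STATIC(.*?)\)", re.DOTALL)


def kernel_library_for(cmake_text, op_name):
    """Hypothetical helper: return the target listing <op_name>_kernel.cpp, or None."""
    # Leading '/' avoids matching an operator whose name merely ends with op_name
    kernel_src = f"/{op_name}/op_kernel/{op_name}_kernel.cpp"
    for target, sources in _LIBRARY_RE.findall(cmake_text):
        if kernel_src in sources:
            return target
    return None
```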

Step 7: Build

Build following the steps in python/sgl_kernel_npu/README.md:
cd sgl-kernel-npu

# Build all modules
bash build.sh

# Install the sgl_kernel_npu wheel
pip install output/sgl_kernel_npu*.whl
The compiled libsgl_kernel_npu.so is copied into python/sgl_kernel_npu/sgl_kernel_npu/lib/ and loaded by the Python package.
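On the Python side, loading the compiled library typically comes down to one call to torch.ops.load_library, which runs the TORCH_LIBRARY registrations inside the .so. The sketch below is an assumption about how a package __init__.py could do this, not the actual sgl_kernel_npu code. load_kernel_library is hypothetical. The loader is passed in as an argument so the logic can be exercised without an NPU.

```python
from pathlib import Path
from typing import Callable

LIB_NAME = "libsgl_kernel_npu.so"


def load_kernel_library(package_dir, loader: Callable[[str], None]) -> Path:
    """Hypothetical helper: load <package_dir>/lib/libsgl_kernel_npu.so with `loader`."""
    lib_path = Path(package_dir) / "lib" / LIB_NAME
    if not lib_path.is_file():
        raise FileNotFoundError(
            f"{lib_path} not found; run build.sh and reinstall the sgl_kernel_npu wheel")
    loader(str(lib_path))  # e.g. torch.ops.load_library in a real package
    return lib_path
```

Inside a package, this would be called as load_kernel_library(Path(__file__).parent, torch.ops.load_library). Raising an error instead of skipping silently means a missing build surfaces at import time, not later as a missing torch.ops.npu attribute.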

Developing Triton Operators

Triton operators are located under python/sgl_kernel_npu/sgl_kernel_npu/, organized by function category:
python/sgl_kernel_npu/sgl_kernel_npu/
├── attention/     # Attention (decode_attention, sinks_attention)
├── norm/          # Normalization (rmsnorm, fused_qk_norm, l1_norm)
├── activation/    # Activation (swiglu_oai, swiglu_quant)
├── fla/           # Linear attention (chunk, cumsum, wy_fast)
├── mamba/         # Mamba-related (causal_conv1d, state_update)
├── moe/           # MoE-related (mul_add, zero_experts)
└── sample/        # Speculative decoding (verify_tree_greedy)
Development steps:
  1. Create a new .py file in the appropriate category subdirectory.
  2. Write the kernel using the Triton language, using existing operators in the same directory as templates.
  3. Export functions in the corresponding __init__.py if needed.
  4. Write tests under tests/python/sgl_kernel_npu/.
Note: Many Triton operators are adapted from SGLang’s GPU Triton kernels (e.g., comments in fla/utils.py note the original source). Pay special attention to differences between NPU and GPU when adapting.
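Simple elementwise Triton kernels usually follow a 1-D pattern. The launch grid has ceil(n / BLOCK_SIZE) program instances, and each instance masks off offsets past the end of the tensor. The pure-Python model below reproduces that indexing without requiring Triton, which makes non-power-of-2 sizes easy to reason about. launch_plan is a hypothetical helper, and the Triton calls appear only in the comment.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division for positive ints (what triton.cdiv computes)."""
    return -(-a // b)


def launch_plan(n_elements: int, block_size: int):
    """Hypothetical model of a 1-D Triton launch: grid size plus per-program (offsets, mask).

    Inside a real kernel, each program instance would compute:
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
    """
    grid = cdiv(n_elements, block_size)
    plan = []
    for pid in range(grid):
        offsets = [pid * block_size + i for i in range(block_size)]
        mask = [o < n_elements for o in offsets]  # False entries are not loaded or stored
        plan.append((offsets, mask))
    return grid, plan
```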

Integrating Operators into SGLang

Ascend C Operator Integration

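After completing Steps 1-7 above (writing the kernel, registering the torch op, building, and installing the sgl-kernel-npu wheel), call the operator in SGLang as follows:
import torch
import torch_npu  # Ascend PyTorch adapter; makes the "npu" device available
import sgl_kernel_npu  # Importing the package loads libsgl_kernel_npu.so and registers torch.ops.npu.*

x = torch.randn(8 * 2048, dtype=torch.float16, device="npu")
y = torch.randn(8 * 2048, dtype=torch.float16, device="npu")

# Call the operator
result = torch.ops.npu.helloworld(x, y)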
Real-world usage in SGLang (from sglang/srt/speculative/eagle_utils.py):
torch.ops.npu.build_tree_kernel_efficient(
    parent_list, selected_index, verified_seq_len, tree_mask,
    positions, retrive_index, retrive_next_token,
    retrive_next_sibling, topk, depth, draft_token_num, tree_mask_mode
)

Triton Operator Integration

Import and call directly via Python:
from sgl_kernel_npu.attention.decode_attention import decode_attention_fwd
from sgl_kernel_npu.norm.rmsnorm_bias import rmsnorm_bias
from sgl_kernel_npu.mamba.causal_conv1d import causal_conv1d_fwd

# Direct function call
output = decode_attention_fwd(q, k, v, ...)
Real-world usage in SGLang (from sglang/srt/models/llama.py):
from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope

sgl-kernel-npu Wheel Update Process

Since SGLang and sgl-kernel-npu are separate Python packages, dependency updates require a multi-PR workflow:
  1. Submit sgl-kernel-npu PR: Add/modify operators in the sgl-kernel-npu repository, ensuring all tests pass.
  2. Bump sgl-kernel-npu version: Update the version number in sgl-kernel-npu. Merging triggers an automatic PyPI release.
  3. Reference the new version in SGLang:
    • Update the sgl-kernel-npu version requirement in SGLang’s python/pyproject.toml.
    • Use the new operator in SGLang code.
If not urgent, you can wait for a regular release (typically within one week).

Writing Unit Tests

Each operator needs a corresponding unit test under tests/python/sgl_kernel_npu/, using Python’s unittest framework. Test file naming convention: test_<op_name>.py
# tests/python/sgl_kernel_npu/test_helloworld.py
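import unittest

import torch
import torch_npu  # noqa: F401  registers the "npu" device
import sgl_kernel_npu  # noqa: F401  loads libsgl_kernel_npu.so

class TestHelloworld(unittest.TestCase):
    def test_helloworld_basic(self):
        # The kernel computes in half (float16); 8 * 2048 elements gives evenly sized tiles
        x = torch.randn(8 * 2048, dtype=torch.float16, device="npu")
        y = torch.randn(8 * 2048, dtype=torch.float16, device="npu")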

        z = torch.ops.npu.helloworld(x, y)
        expected = x + y

        torch.testing.assert_close(z, expected)

if __name__ == "__main__":
    unittest.main()
Run tests:
python tests/python/sgl_kernel_npu/test_helloworld.py
Testing checklist:
  • Cover typical input shapes (power-of-2 sizes and non-standard sizes).
  • Cover different data types (bf16 / fp16, etc.).
  • For operators that write in place, check the contents of the mutated tensors, not just the return value.
  • Compare against PyTorch native computation to verify accuracy.
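The checklist maps naturally onto a shape-by-dtype matrix driven by unittest's subTest. The stdlib sketch below builds that matrix. It also restates the closeness rule used by torch.testing.assert_close, |actual - expected| <= atol + rtol * |expected|, with that function's default tolerances for float16 and bfloat16. The shape list and helper names are hypothetical.

```python
import itertools

# Default (rtol, atol) that torch.testing.assert_close applies to these dtypes
DEFAULT_TOLERANCES = {"float16": (1e-3, 1e-5), "bfloat16": (1.6e-2, 1e-5)}

# Hypothetical shapes: powers of two plus sizes that exercise tail handling
SHAPES = [(1024,), (16384,), (1000,), (3, 777)]


def shape_dtype_cases(shapes=SHAPES, dtypes=tuple(DEFAULT_TOLERANCES)):
    """Every (shape, dtype) combination to run through the operator test."""
    return list(itertools.product(shapes, dtypes))


def within_tolerance(actual, expected, rtol, atol):
    """Elementwise |actual - expected| <= atol + rtol * |expected| (assert_close's rule)."""
    if len(actual) != len(expected):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))
```

In a test method, loop over shape_dtype_cases() inside with self.subTest(shape=shape, dtype=dtype): so that each failing combination is reported separately. bfloat16 needs looser tolerances than float16.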

Code Style

Pre-commit Checks

sgl-kernel-npu uses pre-commit for consistent code style:
pip3 install pre-commit
cd sgl-kernel-npu
pre-commit install
pre-commit run --all-files
Note: The first pre-commit run --all-files may fail because some hooks rewrite files to fix lint errors; run it again to confirm everything passes. All code must pass the checks before you submit a PR.

C++ Code Style

  • Use the C++17 standard.
  • Place all operator implementations under the sglang::npu_kernel namespace.
  • Follow existing code style; format using .clang-format.
  • Use TORCH_CHECK for error checking.
  • Do not include unnecessary GE registration code (e.g., OP_ADD() macros).

Python Code Style

  • Follow PEP 8.
  • Use snake_case for file and function names.
  • When adapting from SGLang GPU code, note the original source in the file header.

General Principles

  • Avoid code duplication: extract shared functions for any repeated code blocks over 5 lines.
  • Minimize device synchronization: reduce CPU-NPU sync operations like tensor.item() or tensor.cpu().
  • Keep functions pure: avoid in-place argument modification.
  • Keep files concise: split files exceeding 2,000 lines.

Submitting a PR

  1. Fork the repo: Fork sgl-kernel-npu on GitHub, then clone locally.
  2. Create a branch: Create a new branch from main, e.g., feature/add-my-op.
  3. Develop and test: Develop the operator and write tests following the steps above. Ensure all tests pass.
  4. Run pre-commit: Ensure code formatting compliance.
  5. Commit and push:
    git add .
    git commit -m "feat: add <op_name> operator"
    git push origin feature/add-my-op
    
  6. Create a PR: Open a Pull Request on GitHub from your branch to sgl-project/sgl-kernel-npu:main.
  7. Wait for CI and review: CI checks include linting, compilation, and operator tests. After passing, wait for maintainer review and merge.

References