Development Guide for JIT Kernels#
Environment Setup#
We strongly recommend using clangd as the language server for JIT kernel development.
For Ubuntu/Debian, you can install clangd from apt.llvm.org.
If you are using VS Code, we recommend installing the clangd extension for better IDE integration.
All JIT-related files are located in python/sglang/jit_kernel.
Unlike sgl-kernel, which compiles CUDA/C++ binaries ahead of time (AOT), just-in-time (JIT) kernels are compiled at runtime.
Consequently, a static compile_commands.json cannot be generated.
To enable code completion with clangd, run python -m sglang.jit_kernel to generate a .clangd configuration file in your current directory.
After generating the file, restart the clangd language server. It should now recognize all JIT kernel files.
Code Structure#
C++ Implementation#
C++ source code is located in python/sglang/jit_kernel/csrc.
Reusable functions should be placed in python/sglang/jit_kernel/include.
We use tvm-ffi for efficient foreign language bindings.
Refer to the tvm-ffi documentation for advanced usage, such as exporting C++ objects.
Typically, tvm::ffi::TensorView is sufficient for passing PyTorch Tensors from Python.
Python Interface#
Python interfaces are defined in python/sglang/jit_kernel.
The load_jit utility function in python/sglang/jit_kernel/utils.py loads and returns the compiled module.
To export a C++ function (e.g., cpp_func), pass cuda_wrappers=[("func", "cpp_func")] to load_jit.
The function can then be called in Python as module.func.
For caching compiled modules, prefer sglang.jit_kernel.utils.cache_once over functools.lru_cache.
functools.lru_cache is not compatible with torch.compile.
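As a quick reference, the following minimal sketch shows the typical pattern (it mirrors the full add_constant walkthrough later in this guide; the module name my_kernel, the file my_kernel.cuh, and the wrapped function names are placeholders):

from sglang.jit_kernel.utils import cache_once, load_jit


@cache_once  # compile once and cache the module (preferred over functools.lru_cache)
def _jit_my_kernel_module():
    return load_jit(
        "my_kernel",                           # module name (placeholder)
        cuda_files=["my_kernel.cuh"],          # source under jit_kernel/csrc (placeholder)
        cuda_wrappers=[("func", "cpp_func")],  # expose C++ `cpp_func` as `module.func`
    )


# Usage: module = _jit_my_kernel_module(); then call module.func(...) with tensor arguments.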
C++ Utilities#
The following C++ utilities are available:
Integer Range#
Similar to PyTorch's c10::irange, we provide an irange function for iterating over an integer range.
#include <sgl_kernel/utils.h>

void test() {
  for (auto i : host::irange(100)) {  // [0, 100)
    // do something
  }
  for (auto i : host::irange(0, 100)) {  // [0, 100)
    // do something
  }
}
Runtime Checking#
RuntimeCheck validates conditions at runtime. It accepts optional arguments for error reporting.
If the check fails, these arguments are output to aid debugging.
RuntimeDeviceCheck verifies the status of the last kernel launch.
#include <sgl_kernel/utils.h>
#include <sgl_kernel/utils.cuh>

void test() {
  host::RuntimeCheck(1 + 1 == 2, 1 + 1, " != ", 2);
  host::RuntimeDeviceCheck();
  // check the provided `cudaError_t`
  host::RuntimeDeviceCheck(cudaGetLastError());
}
Tensor Checking#
TensorMatcher provides a readable way to validate and extract tensor shape information.
#include <sgl_kernel/tensor.h>

void test(const tvm::ffi::TensorView k_cache, const tvm::ffi::TensorView v_cache) {
  using namespace host;
  auto D = SymbolicSize{"D"};  // cache dimension
  auto N = SymbolicSize{"N"};  // kvcache stride
  auto dtype = SymbolicDType{};
  auto device = SymbolicDevice{};
  TensorMatcher({-1, D})  //
      .with_strides({N, 1})
      .with_dtype<int32_t, int64_t>(dtype)
      .with_device<kDLCUDA, kDLCPU>(device)
      .verify(k_cache)
      .verify(v_cache);
}
Configure the TensorMatcher with expected stride, dtype, and device properties before verification.
If with_strides is omitted, the tensor is expected to be contiguous.
Template arguments in with_dtype restrict the allowed data types.
Template arguments in with_device restrict the allowed devices.
Values passed to with_xxx methods enforce equality checks.
Passing -1 for a size or stride allows matching any value.
A Symbolic variable must resolve to the same value across all verifications.
Use .unwrap() to retrieve the matched value after verification.
Note: TensorMatcher is a temporary expression and should not be stored in a variable.
Tip: Add // at the end of the TensorMatcher chain to enforce proper indentation.
Kernel Launching#
Kernels are launched with LaunchKernel, which takes the grid size, the block size, and either a DLDevice or an explicit cudaStream_t, plus an optional dynamic shared memory size.
LaunchKernel::resolve_device retrieves the current cudaStream_t for a given DLDevice from PyTorch.
#include <sgl_kernel/utils.cuh>
#include <dlpack/dlpack.h>

__global__ void kernel() {}

void test() {
  const auto num_blocks = 1;
  const auto num_threads = 32;
  const auto dynamic_smem = 0;
  DLDevice dev;  // suppose this is initialized properly
  host::LaunchKernel(num_blocks, num_threads, dev)(kernel);
  cudaStream_t stream = host::LaunchKernel::resolve_device(dev);
  host::LaunchKernel(num_blocks, num_threads, stream, dynamic_smem)(kernel);
}
Adding New Kernels#
This section walks through a complete, end-to-end example of adding a new JIT kernel to the system. We use a simple add_constant kernel as a running example, which adds a constant integer value to every element of an input tensor.
Conceptually, the Python interface looks like this:
def add_constant(src: torch.Tensor, c: int):
    return src + c
STEP 1: Write the C++ kernel#
Write your CUDA kernel in jit_kernel/csrc/add_constant.cuh. For demonstration purposes, we pass the constant value as a template parameter.
#include <sgl_kernel/tensor.h>   // For TensorMatcher, SymbolicSize, SymbolicDevice
#include <sgl_kernel/utils.cuh>  // For LaunchKernel
#include <sgl_kernel/utils.h>    // For div_ceil, RuntimeCheck
#include <dlpack/dlpack.h>
#include <tvm/ffi/container/tensor.h>
#include <cstddef>
#include <cstdint>

namespace {

template <int32_t kConstant>
__global__ void add_constant_kernel(int32_t* dst, const int32_t* src, size_t length) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < length) {
    dst[idx] = src[idx] + kConstant;
  }
}

constexpr size_t kBlockSize = 256;

// You can also use a struct with a static method as an alternative
template <int32_t kConstant>
void add_constant(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) {
  using namespace host;

  // 1. Validate input tensors
  SymbolicSize N = {"num_elements"};
  SymbolicDevice device_;
  TensorMatcher({N})                  // 1D tensor, must be contiguous
      .with_dtype<int32_t>()          // must be int32
      .with_device<kDLCUDA>(device_)  // must be on CUDA device
      .verify(dst)                    // check tensor dst
      .verify(src);                   // check tensor src

  // 2. Extract required parameters, prepare for kernel launch
  const size_t num_elements = N.unwrap();
  const size_t grid_size = div_ceil(num_elements, kBlockSize);
  const DLDevice device = device_.unwrap();

  // some extra runtime checks using host::RuntimeCheck
  RuntimeCheck(num_elements > 0, "We only support non-empty tensors, got num_elements = ", num_elements);

  // 3. Launch the kernel. Error code will be automatically checked.
  LaunchKernel(grid_size, kBlockSize, device /*, dynamic_smem*/)(
      // kernel function
      add_constant_kernel<kConstant>,
      // kernel arguments
      static_cast<int32_t*>(dst.data_ptr()),
      static_cast<int32_t*>(src.data_ptr()),
      num_elements);
}

}  // namespace
STEP 2: Create Python Interfaces#
Next, expose the kernel through a Python wrapper. Create a new file at jit_kernel/add_constant.py and define the required interfaces.
from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from sglang.jit_kernel.utils import cache_once, load_jit, make_cpp_args

if TYPE_CHECKING:
    from tvm_ffi.module import Module


@cache_once
def _jit_add_constant_module(constant: int) -> Module:
    args = make_cpp_args(constant)  # pass all template arguments
    return load_jit(
        "add_constant",
        *args,
        cuda_files=["add_constant.cuh"],
        cuda_wrappers=[("add_constant", f"add_constant<{args}>")],
    )


def add_constant(src: torch.Tensor, constant: int) -> torch.Tensor:
    dst = torch.empty_like(src)
    module = _jit_add_constant_module(constant)
    module.add_constant(dst, src)
    return dst
STEP 3: Use your kernel#
Finally, import and use the kernel like a regular Python function:
from sglang.jit_kernel.add_constant import add_constant
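For example, a quick sanity check might look like this (assuming a CUDA device is available; the matcher in STEP 1 requires a non-empty, contiguous int32 tensor):

import torch

from sglang.jit_kernel.add_constant import add_constant

src = torch.arange(8, dtype=torch.int32, device="cuda")
out = add_constant(src, 3)  # the first call for a given constant triggers JIT compilation
assert torch.equal(out, src + 3)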
For a complete, runnable example, refer to test_add_constant.py.