Environment Setup
We strongly recommend using clangd as the language server for JIT kernel development.
For Ubuntu/Debian, you can download clangd from apt.llvm.org.
If you are using VS Code, we recommend installing the clangd extension for better IDE integration.
All JIT-related files are located in python/sglang/jit_kernel.
Unlike sgl-kernel, which compiles CUDA/C++ binaries ahead of time (AOT), just-in-time (JIT) kernels are compiled at runtime.
Consequently, a static compile_commands.json cannot be generated.
To enable code completion with clangd, run python -m sglang.jit_kernel to generate a .clangd configuration file in your current directory.
After generating the file, restart the clangd language server. It should now recognize all JIT kernel files.
Code Structure
C++ Implementation
C++ source code is located in python/sglang/jit_kernel/csrc.
Reusable functions should be placed in python/sglang/jit_kernel/include.
We use tvm-ffi for efficient foreign language bindings.
Refer to the documentation for advanced usage, such as exporting C++ objects.
Typically, tvm::ffi::TensorView is sufficient for passing PyTorch Tensors from Python.
Python Interface
Python interfaces are defined in python/sglang/jit_kernel.
The load_jit utility function in python/sglang/jit_kernel/utils.py loads and returns the compiled module.
To export a C++ function (e.g., cpp_func), pass cuda_wrappers=[("func", "cpp_func")] to load_jit.
The function can then be called in Python as module.func.
For caching compiled modules, prefer sglang.jit_kernel.utils.cache_once over functools.lru_cache.
functools.lru_cache is not compatible with torch.compile.
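A minimal sketch of this pattern, assuming cache_once works as a decorator (like functools.lru_cache) and that load_jit takes the kernel name as its first positional argument; only the cuda_wrappers mapping is taken from the description above:

```python
# Hypothetical wrapper: exports the C++ function `cpp_func` as `module.func`.
from sglang.jit_kernel.utils import cache_once, load_jit


@cache_once  # torch.compile-friendly caching of the compiled module
def _module():
    # The first argument (kernel name) is an assumption about load_jit's signature.
    return load_jit("my_kernel", cuda_wrappers=[("func", "cpp_func")])


def func(*args):
    return _module().func(*args)
```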
C++ Utilities
The following C++ utilities are available:
Integer Range
Similar to PyTorch, we provide an irange function to represent an integer range.
Example
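A brief usage sketch, assuming irange mirrors PyTorch's c10::irange (the host namespace follows the header reference table in this document; exact signatures may differ):

```cpp
#include "sgl_kernel/utils.h"

void fill_squares(int* out, int n) {
  // Iterates i = 0, 1, ..., n - 1, like Python's range(n).
  for (const auto i : host::irange(n)) {
    out[i] = i * i;
  }
}
```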
Runtime Checking
RuntimeCheck validates conditions at runtime. It accepts optional arguments for error reporting.
If the check fails, these arguments are output to aid debugging.
RuntimeDeviceCheck verifies the status of the last kernel launch.
Example
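A sketch of both checks; the exact call shapes are assumptions based on the descriptions above (condition first, then optional arguments that are printed on failure):

```cuda
#include "sgl_kernel/utils.h"
#include "sgl_kernel/utils.cuh"

__global__ void my_kernel(const float* input);

void launch(const float* input, int num_tokens, int hidden_dim) {
  // On failure, the trailing arguments are printed to aid debugging.
  host::RuntimeCheck(hidden_dim % 32 == 0, num_tokens, hidden_dim);

  my_kernel<<<num_tokens, hidden_dim>>>(input);

  // Verifies the status of the kernel launch above.
  RuntimeDeviceCheck();
}
```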
Tensor Checking
TensorMatcher provides a readable way to validate and extract tensor shape information.
Example
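An illustrative chain pieced together from the notes in this section. The entry point, the with_shape name, and the terminal verify() call are guesses; with_dtype, with_device, symbolic variables, and unwrap() follow the descriptions below:

```cpp
#include "sgl_kernel/tensor.h"
#include "sgl_kernel/utils.cuh"  // assumed source of fp16_t / bf16_t aliases

void validate(const tvm::ffi::TensorView& input, const tvm::ffi::TensorView& weight) {
  host::SymbolicSize num_tokens, hidden;

  // with_strides omitted => input is expected to be contiguous.
  host::TensorMatcher(input)
      .with_shape({num_tokens, hidden})  // hypothetical name for size matching
      .with_dtype<fp16_t, bf16_t>()      // restrict the allowed dtypes
      .with_device<kDLCUDA>()            // restrict the allowed devices
      .verify();                         // hypothetical terminal call

  // `hidden` must resolve to the same value as in the previous match.
  host::TensorMatcher(weight)
      .with_shape({hidden})
      .verify();

  const int64_t n = num_tokens.unwrap();  // retrieve the matched value
  (void)n;
}
```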
Construct a TensorMatcher with the expected stride, dtype, and device properties before verification.

- If with_strides is omitted, the tensor is expected to be contiguous.
- Template arguments in with_dtype restrict the allowed data types.
- Template arguments in with_device restrict the allowed devices.
- Values passed to with_xxx methods enforce equality checks.
- Passing -1 for a size or stride matches any value.
A symbolic variable must resolve to the same value across all verifications.
Use .unwrap() to retrieve the matched value after verification.
Note: TensorMatcher is a temporary expression and should not be stored in a variable.
Tip: Add // at the end of the TensorMatcher chain to enforce proper indentation.
Kernel Launching
Kernels can also be launched directly using LaunchKernel. The LaunchKernel::resolve_device helper retrieves the current CUDA stream from PyTorch.
Example
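A sketch of a direct launch; the LaunchKernel call shape here is an assumption, with resolve_device supplying PyTorch's current CUDA stream as described above:

```cuda
#include "sgl_kernel/utils.cuh"

__global__ void scale_kernel(float* data, float factor, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

void scale(float* data, float factor, int n) {
  const dim3 grid((n + 255) / 256);
  const dim3 block(256);
  // Hypothetical call shape: grid, block, and the stream from resolve_device.
  LaunchKernel(grid, block, LaunchKernel::resolve_device())(scale_kernel, data, factor, n);
}
```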
Adding New Kernels
This section walks through a complete, end-to-end example of adding a new JIT kernel to the system. We use a simple add_constant kernel as a running example; it adds a constant integer value to every element of an input tensor. Conceptually, the Python interface looks like this:
Example
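The intended semantics can be pinned down with a plain PyTorch reference (this describes only the behavior to implement, not the JIT kernel itself):

```python
import torch


def add_constant(x: torch.Tensor, value: int) -> torch.Tensor:
    """Reference semantics: return x with `value` added to every element."""
    return x + value


y = add_constant(torch.tensor([1, 2, 3]), 4)  # tensor([5, 6, 7])
```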
STEP 1: Write the C++ kernel
Write your CUDA kernel in jit_kernel/csrc/add_constant.cuh. For demonstration purposes, we pass the constant value as a template parameter.
Example
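A sketch of what the file might contain; the TensorView accessors (data_ptr, numel) and the surrounding boilerplate are assumptions:

```cuda
// jit_kernel/csrc/add_constant.cuh (sketch; real boilerplate may differ)
#pragma once
#include "sgl_kernel/utils.cuh"

// The constant is a template parameter: each value gets its own specialization.
template <int kValue>
__global__ void add_constant_kernel(int* __restrict__ data, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += kValue;
}

// Host entry point exported through tvm-ffi; takes the tensor as a TensorView.
template <int kValue>
void add_constant_entry(tvm::ffi::TensorView x) {
  const int n = static_cast<int>(x.numel());  // assumed TensorView accessor
  add_constant_kernel<kValue><<<(n + 255) / 256, 256>>>(
      static_cast<int*>(x.data_ptr()), n);    // assumed TensorView accessor
}
```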
STEP 2: Create Python Interfaces
Next, expose the kernel through a Python wrapper. Create a new file at jit_kernel/add_constant.py that defines the needed interfaces.
Example
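A sketch of the wrapper. The load_jit positional argument and the template-instantiated symbol name are assumptions, and cache_once is assumed to memoize per argument like functools.lru_cache; the cuda_wrappers mapping follows the earlier section:

```python
# jit_kernel/add_constant.py (sketch)
import torch

from sglang.jit_kernel.utils import cache_once, load_jit


@cache_once  # preferred over functools.lru_cache for torch.compile compatibility
def _jit_module(value: int):
    # Export the C++ entry `add_constant_entry<value>` as `module.add_constant`.
    return load_jit(
        f"add_constant_{value}",
        cuda_wrappers=[("add_constant", f"add_constant_entry<{value}>")],
    )


def add_constant(x: torch.Tensor, value: int) -> torch.Tensor:
    """Add `value` to every element of `x` in place via the JIT kernel."""
    _jit_module(value).add_constant(x)
    return x
```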
STEP 3: Use your kernel
Finally, import and use the kernel like a regular Python function:
Example
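Assuming the wrapper exposes a function named add_constant in jit_kernel/add_constant.py (names are illustrative):

```python
import torch

from sglang.jit_kernel.add_constant import add_constant  # hypothetical export

x = torch.zeros(4, dtype=torch.int32, device="cuda")
y = add_constant(x, 3)  # first call triggers JIT compilation; later calls reuse the cache
```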
C++ Include Library Reference
The JIT kernel framework provides a set of reusable C++ headers in python/sglang/jit_kernel/include/sgl_kernel/. Each header is designed to be lightweight and self-contained. Below is a summary of each header and its key APIs.
Core Utilities
| Header | Namespace | Purpose |
|---|---|---|
| utils.h | host | Host-side essentials: RuntimeCheck, Panic, div_ceil, irange |
| utils.cuh | device / host | Type aliases (fp16_t, bf16_t, …), SGL_DEVICE macro, PDL helpers, LaunchKernel, RuntimeDeviceCheck |
| source_location.h | (global) | Portable std::source_location wrapper for error reporting |
| runtime.cuh | host::runtime | CUDA runtime queries: get_blocks_per_sm, get_sm_count, get_cc_major, get_runtime_version, get_available_dynamic_smem_per_block |
Tensor Validation
| Header | Namespace | Purpose |
|---|---|---|
| tensor.h | host | TensorMatcher, SymbolicSize, SymbolicDType, SymbolicDevice |
Math & Type System
| Header | Namespace | Purpose |
|---|---|---|
| math.cuh | device::math | max, min, abs, sqrt, rsqrt, exp, sin, cos, constants |
| type.cuh | (global) / device | dtype_trait<T>, packed_t<T>, device::cast<To>(from) |
Memory Access
| Header | Namespace | Purpose |
|---|---|---|
| vec.cuh | device | AlignedVector<T, N> - vectorized load/store (up to 128-bit; 256-bit requires Blackwell GPUs) |
| tile.cuh | device::tile | Memory<T> - cooperative tiled memory I/O (thread/warp/CTA) |
Parallel Primitives
| Header | Namespace | Purpose |
|---|---|---|
| warp.cuh | device::warp | reduce_sum, reduce_max via __shfl_xor_sync |
| cta.cuh | device::cta | reduce_max across warps via shared memory |
| atomic.cuh | device::atomic | max - atomic float max (CUDA + ROCm fallback) |
Reusable Kernel Templates
| Header | Namespace | Purpose |
|---|---|---|
| impl/norm.cuh | host::norm / device::norm | RMSNorm building blocks (warp & CTA paths, StorageType) |
