add-cuda-kernel
Step-by-step tutorial for adding new CUDA kernels to FlashInfer
Install
mkdir -p .claude/skills/add-cuda-kernel && curl -L -o skill.zip "https://mcp.directory/api/skills/download/5897" && unzip -o skill.zip -d .claude/skills/add-cuda-kernel && rm skill.zip

Installs to .claude/skills/add-cuda-kernel.
About this skill
Tutorial: Adding a New Kernel to FlashInfer
This tutorial walks through adding a simple element-wise scale operation to FlashInfer. We'll implement scale(x, factor) = x * factor to demonstrate the complete workflow.
Goal
Add a new operation that scales each element of a tensor by a scalar factor:
- Input: tensor `x` and scalar `factor`
- Output: `x * factor` (element-wise)
- Support multiple dtypes (FP16, BF16, FP32)
Step 1: Define CUDA Kernel in include/
Create include/flashinfer/scale.cuh:
```cuda
#pragma once

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace flashinfer {

/*!
 * \brief Element-wise scale kernel
 * \tparam T Data type (half, __nv_bfloat16, float)
 * \param input Input tensor
 * \param output Output tensor
 * \param factor Scale factor
 * \param n Number of elements
 */
template <typename T>
__global__ void ScaleKernel(const T* input, T* output, T factor, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    output[idx] = input[idx] * factor;
  }
}

/*!
 * \brief Launch scale kernel
 * \tparam T Data type
 * \param input Input pointer
 * \param output Output pointer
 * \param factor Scale factor
 * \param n Number of elements
 * \param stream CUDA stream
 */
template <typename T>
cudaError_t ScaleLauncher(const T* input, T* output, T factor, int n,
                          cudaStream_t stream = nullptr) {
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  ScaleKernel<T><<<blocks, threads, 0, stream>>>(input, output, factor, n);
  return cudaGetLastError();
}

}  // namespace flashinfer
```
Key points:
- Framework-agnostic (no Torch headers)
- Uses raw pointers
- Template-based for dtype flexibility
- Only includes what's needed (cuda_runtime, cuda_fp16, cuda_bf16)
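The launch configuration above uses ceiling division so that every element gets exactly one thread, with the last block partially idle (hence the `if (idx < n)` guard in the kernel). A quick Python sketch of the same arithmetic:

```python
def launch_config(n: int, threads: int = 256) -> tuple[int, int]:
    """Mirror ScaleLauncher's grid computation: ceiling-divide n by the block size."""
    blocks = (n + threads - 1) // threads
    return blocks, threads

# 1000 elements -> 4 blocks of 256 threads (last block covers only 232 elements)
print(launch_config(1000))
# Exact multiple -> no idle tail
print(launch_config(256))
```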
Step 2: Create Launcher in csrc/
Create csrc/scale.cu:
```cuda
#include "flashinfer/scale.cuh"
#include "tvm_ffi_utils.h"  // TVM-FFI helpers (only allowed in csrc/)

using namespace flashinfer;

void scale_launcher(TensorView input, TensorView output, float factor) {
  CHECK_INPUT(input);
  CHECK_INPUT(output);
  TVM_FFI_ICHECK_EQ(input.dtype(), output.dtype());

  int n = input.numel();
  auto stream = get_stream(input.device());

  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(input.dtype(), DType, [&] {
    cudaError_t status = ScaleLauncher<DType>(
        input.data_ptr<DType>(),
        output.data_ptr<DType>(),
        static_cast<DType>(factor),
        n,
        stream);
    TVM_FFI_ICHECK(status == cudaSuccess)
        << "Failed to run ScaleLauncher: " << cudaGetErrorString(status);
    return true;
  });
}
```
Key points:
- Includes the TVM-FFI utils header `tvm_ffi_utils.h` (only allowed in csrc/)
- Uses `tvm::ffi::TensorView` as the input and output tensor type
- Uses macros defined in `tvm_ffi_utils.h` to check that input and output are both on a CUDA device, both contiguous, and share the same dtype
- Gets the CUDA stream via TVM-FFI and prepares all scalar inputs for the kernel function
- Dispatches on dtype with macros defined in `tvm_ffi_utils.h` (add a new macro if the combination is not covered)
- Converts `tvm::ffi::TensorView` to raw pointers
- Checks the kernel's result status with `TVM_FFI_ICHECK`
- Adds descriptive error messages with the `<<` operator
- Uses TVM-FFI exceptions: `TVM_FFI_THROW(ErrorType) << "message"` for custom error checking
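Conceptually, the dtype-dispatch macro maps a runtime dtype tag to a compile-time element type and invokes the lambda with it. A rough Python analogy of that flow (the mapping and names here are illustrative, not the actual macro internals):

```python
# Hypothetical analogy of DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16:
# map a runtime dtype tag to a concrete element type, then call the functor.
DTYPE_TO_CTYPE = {
    "float32": "float",
    "float16": "half",
    "bfloat16": "__nv_bfloat16",
}

def dispatch_dtype(dtype: str, fn):
    ctype = DTYPE_TO_CTYPE.get(dtype)
    if ctype is None:
        # Mirrors the macro failing loudly on uncovered dtypes
        raise TypeError(f"Unsupported dtype: {dtype}")
    return fn(ctype)

result = dispatch_dtype("float16", lambda ctype: f"ScaleLauncher<{ctype}>")
print(result)  # ScaleLauncher<half>
```

In C++ the same idea is realized with template instantiation rather than a dictionary lookup: the macro expands to a switch over dtype codes, instantiating the lambda once per covered type.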
TVM-FFI Error Handling:
- `TVM_FFI_THROW(ValueError) << "message"`: throws a ValueError with a custom message
- `TVM_FFI_THROW(TypeError) << "message"`: throws a TypeError
- Use `<<` to chain multiple values in the error message
- Errors are properly propagated back to Python
When to use TVM_FFI_THROW vs TVM_FFI_LOG_AND_THROW:
- `TVM_FFI_THROW`: use for normal runtime error handling. This is the standard way to report errors that will be caught and propagated to Python.
- `TVM_FFI_LOG_AND_THROW`: use only in cases where:
  - The function may be called during object construction (e.g., validation in constructors or setup methods)
  - The exception may not be caught properly (e.g., during module initialization)
  - The error condition almost never occurs in practice (e.g., internal errors, unsupported dtype combinations in dispatch macros)

This variant logs the error message before throwing, ensuring visibility even if the exception doesn't propagate correctly.
Example from fused_moe (see csrc/trtllm_fused_moe_kernel_launcher.cu):
```cuda
// In a setup/validation function that may be called during construction
void check_weights_shape(std::string which_weights) const {
  if (which_weights != "gemm1" && which_weights != "gemm2") {
    // Internal error that should never happen - use LOG_AND_THROW
    TVM_FFI_LOG_AND_THROW(InternalError) << "Internal error: which_weights = " << which_weights;
  }
  // ...
  if (weight_layout_is_unsupported) {  // pseudocode condition
    // Unsupported config during setup - use LOG_AND_THROW
    TVM_FFI_LOG_AND_THROW(NotImplementedError)
        << "Unsupported weight_layout: " << (int)weight_layout;
  }
}

// In a normal runtime function
void scale_run(TensorView input, TensorView output, double factor) {
  if (!input.is_cuda()) {
    // Normal validation error - use TVM_FFI_THROW
    TVM_FFI_THROW(ValueError) << "Input must be a CUDA tensor";
  }
}
```
Step 3: Create TVM-FFI Binding in csrc/
Create csrc/scale_jit_binding.cu:
```cuda
#include "scale.cu"
#include "tvm_ffi_utils.h"

// Forward declaration
void scale_launcher(TensorView input, TensorView output, float factor);

// Export to TVM-FFI
TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, scale_launcher);
```
Key points:
- Forward declare the launcher function first
- Export using
TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, function)
Step 4: Create JIT Generator (No Jinja for Simple Case)
Create flashinfer/jit/scale.py:
```python
import os
import shutil

from . import JitSpec, gen_jit_spec
from . import env as jit_env


def get_scale_uri(dtype_in: str, dtype_out: str) -> str:
    """Generate unique identifier for scale module."""
    return f"scale_dtype_in_{dtype_in}_dtype_out_{dtype_out}"


def gen_scale_module(dtype_in, dtype_out) -> JitSpec:
    """Generate JIT module for scale operation.

    Note: This is a simple example without Jinja templating.
    The dtype dispatch is handled at runtime in the C++ code.
    """
    # Compute URI
    uri = get_scale_uri(dtype_in, dtype_out)

    # Create generation directory
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
    os.makedirs(gen_directory, exist_ok=True)

    # Copy source files (no Jinja needed for this simple case)
    sources = []
    for fname in ["scale.cu", "scale_jit_binding.cu"]:
        src_path = jit_env.FLASHINFER_CSRC_DIR / fname
        dest_path = gen_directory / fname
        shutil.copy(src_path, dest_path)
        sources.append(dest_path)

    # Return JitSpec
    return gen_jit_spec(
        name=uri,
        sources=sources,
        extra_cuda_cflags=[],
    )
```
Key points:
- No Jinja template needed for simple operations
- Just copy source files to the generation directory
- The URI uniquely identifies the module configuration
- NEVER write to package directories (see "JIT Directory Rules" in CLAUDE.md)
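To see what the URI scheme produces, here is the naming helper from above as a standalone snippet (the dtype strings are illustrative placeholders):

```python
def get_scale_uri(dtype_in: str, dtype_out: str) -> str:
    """Same naming scheme as in flashinfer/jit/scale.py above."""
    return f"scale_dtype_in_{dtype_in}_dtype_out_{dtype_out}"

# Each (dtype_in, dtype_out) pair gets its own cached module directory:
print(get_scale_uri("f16", "f16"))   # scale_dtype_in_f16_dtype_out_f16
print(get_scale_uri("bf16", "f32"))  # scale_dtype_in_bf16_dtype_out_f32
```

Because the URI encodes every configuration axis, two calls with the same dtypes reuse the same compiled module while any new combination triggers a fresh JIT build.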
(Optional) Specifying Supported CUDA Architectures
FlashInfer uses CompilationContext to manage CUDA architecture targets. This is critical because some kernels only work on specific GPU architectures (e.g., Hopper SM90, Blackwell SM100).
How CompilationContext Works
Automatic Detection (default):

```python
from flashinfer.compilation_context import CompilationContext

ctx = CompilationContext()
# Automatically detects all GPUs in the system.
# For SM90+, adds the 'a' suffix (e.g., 9.0a for Hopper).
# Result: ctx.TARGET_CUDA_ARCHS = {(9, '0a'), (10, '0a'), ...}
```

Manual Override (via environment variable):

```shell
export FLASHINFER_CUDA_ARCH_LIST="8.0 9.0a 10.0a"
# Now only these architectures will be compiled
```
Specifying Architectures in Your JIT Module
When creating a JIT module, specify which major SM versions are supported:
```python
from flashinfer.jit.core import gen_jit_spec
from flashinfer.jit import current_compilation_context


def gen_my_hopper_only_module():
    """Example: Kernel works on SM90 and later supported architectures."""
    uri = get_my_uri(...)
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
    # ... copy sources ...
    nvcc_flags = current_compilation_context.get_nvcc_flags_list(
        # Explicitly list supported SM versions - no automatic future compatibility
        supported_major_versions=[9, 10, 11, 12]  # SM90, SM100, SM110, SM120
    )
    return gen_jit_spec(
        name=uri,
        sources=sources,
        extra_cuda_cflags=nvcc_flags,
    )


def gen_my_blackwell_only_module():
    """Example: Kernel only works on SM100 (Blackwell)."""
    uri = get_my_uri(...)
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
    # ... copy sources ...
    nvcc_flags = current_compilation_context.get_nvcc_flags_list(
        supported_major_versions=[10]  # SM100 only
    )
    return gen_jit_spec(
        name=uri,
        sources=sources,
        extra_cuda_cflags=nvcc_flags,
    )


def gen_my_universal_module():
    """Example: Kernel works on all architectures."""
    uri = get_my_uri(...)
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
    # ... copy sources ...
    nvcc_flags = current_compilation_context.get_nvcc_flags_list(
        supported_major_versions=None  # All available architectures
    )
    return gen_jit_spec(
        name=uri,
        sources=sources,
        extra_cuda_cflags=nvcc_flags,
    )
```
What Happens:
- ✅ If the user's GPU is SM90 and they call a Hopper-only module → compiles and runs
- ❌ If the user's GPU is SM80 and they call a Hopper-only module → `RuntimeError: No supported CUDA architectures found for major versions [9, 10, 11, 12]`
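Conceptually, flag generation intersects the detected target architectures with the declared supported major versions and fails if nothing survives. A simplified sketch of that filtering (not the actual `CompilationContext` implementation; `filter_archs` is a hypothetical helper):

```python
def filter_archs(target_archs, supported_major_versions):
    """Simplified sketch: restrict detected (major, minor) archs to supported majors."""
    if supported_major_versions is None:
        return set(target_archs)  # all detected architectures
    kept = {(major, minor) for major, minor in target_archs
            if major in supported_major_versions}
    if not kept:
        # Mirrors the RuntimeError described above
        raise RuntimeError(
            "No supported CUDA architectures found for major versions "
            f"{supported_major_versions}")
    return kept

# SM80 machine calling a Hopper-only module -> error, as described above
try:
    filter_archs({(8, "0")}, [9, 10, 11, 12])
except RuntimeError as e:
    print(e)
```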
Real Examples from FlashInfer

(Content truncated.)