add-cuda-kernel


Step-by-step tutorial for adding new CUDA kernels to FlashInfer

Install

mkdir -p .claude/skills/add-cuda-kernel && curl -L -o skill.zip "https://mcp.directory/api/skills/download/5897" && unzip -o skill.zip -d .claude/skills/add-cuda-kernel && rm skill.zip

Installs to .claude/skills/add-cuda-kernel

About this skill

Tutorial: Adding a New Kernel to FlashInfer

This tutorial walks through adding a simple element-wise scale operation to FlashInfer. We'll implement scale(x, factor) = x * factor to demonstrate the complete workflow.

Goal

Add a new operation that scales each element of a tensor by a scalar factor:

  • Input: tensor x and scalar factor
  • Output: x * factor (element-wise)
  • Support multiple dtypes (FP16, BF16, FP32)
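Before touching CUDA, it helps to pin down the expected semantics. A pure-Python reference (illustrative only) that the kernel's output can later be checked against:

```python
def scale_reference(x, factor):
    """Reference semantics: scale(x, factor) = x * factor, element-wise."""
    return [v * factor for v in x]

print(scale_reference([1.0, 2.0, 3.0], 2.0))  # [2.0, 4.0, 6.0]
```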

Step 1: Define CUDA Kernel in include/

Create include/flashinfer/scale.cuh:

#pragma once
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace flashinfer {

/*!
 * \brief Element-wise scale kernel
 * \tparam T Data type (half, __nv_bfloat16, float)
 * \param input Input tensor
 * \param output Output tensor
 * \param factor Scale factor
 * \param n Number of elements
 */
template <typename T>
__global__ void ScaleKernel(const T* input, T* output, T factor, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    output[idx] = input[idx] * factor;
  }
}

/*!
 * \brief Launch scale kernel
 * \tparam T Data type
 * \param input Input pointer
 * \param output Output pointer
 * \param factor Scale factor
 * \param n Number of elements
 * \param stream CUDA stream
 */
template <typename T>
cudaError_t ScaleLauncher(const T* input, T* output, T factor, int n,
                          cudaStream_t stream = nullptr) {
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;

  ScaleKernel<T><<<blocks, threads, 0, stream>>>(input, output, factor, n);

  return cudaGetLastError();
}

}  // namespace flashinfer

Key points:

  • Framework-agnostic (no Torch headers)
  • Uses raw pointers
  • Template-based for dtype flexibility
  • Only includes what's needed (cuda_runtime, cuda_fp16, cuda_bf16)
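The launch configuration in ScaleLauncher is plain ceiling division. A quick Python check of that arithmetic, mirroring the C++ expression `(n + threads - 1) / threads`:

```python
def num_blocks(n, threads=256):
    """Ceiling division: the smallest block count with threads * blocks >= n."""
    return (n + threads - 1) // threads

# 1000 elements need 4 blocks (1024 threads); the `if (idx < n)` guard in
# ScaleKernel masks off the 24 surplus threads.
print(num_blocks(1000))  # 4
print(num_blocks(256))   # 1
print(num_blocks(257))   # 2
```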

Step 2: Create Launcher in csrc/

Create csrc/scale.cu:

#include "flashinfer/scale.cuh"
#include "tvm_ffi_utils.h"

using namespace flashinfer;
using tvm::ffi::TensorView;

void scale_launcher(TensorView input, TensorView output,
                    float factor) {
  CHECK_INPUT(input);
  CHECK_INPUT(output);
  TVM_FFI_ICHECK_EQ(input.dtype(), output.dtype());
  int n = input.numel();
  auto stream = get_stream(input.device());

  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP32_FP16(input.dtype(), DType, [&] {
    cudaError_t status = ScaleLauncher<DType>(
      input.data_ptr<DType>(),
      output.data_ptr<DType>(),
      static_cast<DType>(factor),
      n,
      stream
    );
    TVM_FFI_ICHECK(status == cudaSuccess)
        << "Failed to run ScaleLauncher: " << cudaGetErrorString(status);
    return true;
  });
}

Key points:

  • Includes tvm_ffi_utils.h (TVM FFI utility headers are only allowed in csrc/)
  • Uses tvm::ffi::TensorView for the input and output tensor arguments
  • Uses the check macros defined in tvm_ffi_utils.h to verify that input and output are both on a CUDA device, both contiguous, and share the same dtype
  • Gets the CUDA stream via TVM FFI and prepares all scalar inputs for the kernel
  • Dispatches on dtype with the macros defined in tvm_ffi_utils.h (add a new dispatch macro if your dtype combination is not covered)
  • Converts tvm::ffi::TensorView to raw pointers before calling the framework-agnostic launcher
  • Checks the kernel's return status with TVM_FFI_ICHECK
  • Adds descriptive error messages with the << operator
  • Uses TVM-FFI exceptions (TVM_FFI_THROW(ErrorType) << "message") for custom error checking

TVM-FFI Error Handling:

  • TVM_FFI_THROW(ValueError) << "message" - Throw ValueError with custom message
  • TVM_FFI_THROW(TypeError) << "message" - Throw TypeError
  • Use << to chain multiple values in the error message
  • Errors are properly propagated back to Python

When to use TVM_FFI_THROW vs TVM_FFI_LOG_AND_THROW:

  • TVM_FFI_THROW: Use for normal runtime error handling. This is the standard way to report errors that will be caught and propagated to Python.

  • TVM_FFI_LOG_AND_THROW: Use only in cases where:

    1. The function may be called during object construction time (e.g., validation in constructors or setup methods)
    2. The exception may not be caught properly (e.g., during module initialization)
    3. The error condition almost never fails in practice (e.g., internal errors, unsupported dtype combinations in dispatch macros)

    This variant logs the error message before throwing, ensuring visibility even if the exception doesn't propagate correctly.

Example from fused_moe (see csrc/trtllm_fused_moe_kernel_launcher.cu):

// In a setup/validation function that may be called during construction
void check_weights_shape(std::string which_weights) const {
  if (which_weights != "gemm1" && which_weights != "gemm2") {
    // Internal error that should never happen - use LOG_AND_THROW
    TVM_FFI_LOG_AND_THROW(InternalError) << "Internal error: which_weights = " << which_weights;
  }
  // ...
  if (weight_layout_is_unsupported) {  // pseudocode condition
    // Unsupported config during setup - use LOG_AND_THROW
    TVM_FFI_LOG_AND_THROW(NotImplementedError)
        << "Unsupported weight_layout: " << (int)weight_layout;
  }
}

// In a normal runtime function
void scale_run(TensorView input, TensorView output, double factor) {
  if (!input.is_cuda()) {
    // Normal validation error - use TVM_FFI_THROW
    TVM_FFI_THROW(ValueError) << "Input must be a CUDA tensor";
  }
}

Step 3: Create TVM-FFI Binding in csrc/

Create csrc/scale_jit_binding.cu:

#include "tvm_ffi_utils.h"

using tvm::ffi::TensorView;

// Forward declaration (defined in scale.cu, which is compiled alongside this file)
void scale_launcher(TensorView input, TensorView output, float factor);

// Export to TVM-FFI
TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, scale_launcher);

Key points:

  • Forward declare the launcher function first
  • Export using TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, function)

Step 4: Create JIT Generator (No Jinja for Simple Case)

Create flashinfer/jit/scale.py:

import os
import shutil
from pathlib import Path

from . import JitSpec, gen_jit_spec
from . import env as jit_env
from .core import write_if_different


def get_scale_uri(dtype_in: str, dtype_out: str) -> str:
    """Generate unique identifier for scale module."""
    return f"scale_dtype_in_{dtype_in}_dtype_out_{dtype_out}"


def gen_scale_module(dtype_in, dtype_out):
    """
    Generate JIT module for scale operation.

    Note: This is a simple example without Jinja templating.
    The dtype dispatch is handled at runtime in the C++ code.
    """
    # Compute URI
    uri = get_scale_uri(dtype_in, dtype_out)

    # Create generation directory
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
    os.makedirs(gen_directory, exist_ok=True)

    # Copy source files (no Jinja needed for this simple case)
    sources = []
    for fname in ["scale.cu", "scale_jit_binding.cu"]:
        src_path = jit_env.FLASHINFER_CSRC_DIR / fname
        dest_path = gen_directory / fname
        shutil.copy(src_path, dest_path)
        sources.append(dest_path)

    # Return JitSpec
    return gen_jit_spec(
        name=uri,
        sources=sources,
        extra_cuda_cflags=[],
    )

Key points:

  • No Jinja template needed for simple operations
  • Just copy source files to generation directory
  • URI uniquely identifies the module configuration
  • NEVER write to package directories - see "JIT Directory Rules" in CLAUDE.md
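The URI doubles as the JIT cache key, so it must be deterministic: the same dtype configuration has to map to the same string. A standalone sketch of the naming convention used by `get_scale_uri` above:

```python
def get_scale_uri(dtype_in, dtype_out):
    # Same convention as the generator above: one module directory per dtype pair.
    return f"scale_dtype_in_{dtype_in}_dtype_out_{dtype_out}"

# Identical configurations map to the same URI, so the cached build is reused.
print(get_scale_uri("f16", "f16"))  # scale_dtype_in_f16_dtype_out_f16
assert get_scale_uri("f16", "f32") == get_scale_uri("f16", "f32")
```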

(Optional) Specifying Supported CUDA Architectures

FlashInfer uses CompilationContext to manage CUDA architecture targets. This is critical because some kernels only work on specific GPU architectures (e.g., Hopper SM90, Blackwell SM100).

How CompilationContext Works

Automatic Detection (default):

from flashinfer.compilation_context import CompilationContext

ctx = CompilationContext()
# Automatically detects all GPUs in the system
# For SM90+, adds 'a' suffix (e.g., 9.0a for Hopper)
# Result: ctx.TARGET_CUDA_ARCHS = {(9, '0a'), (10, '0a'), ...}

Manual Override (via environment variable):

export FLASHINFER_CUDA_ARCH_LIST="8.0 9.0a 10.0a"
# Now only these architectures will be compiled
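A hedged sketch (not FlashInfer's actual parser) of how such an override string maps onto the `(major, minor)` tuples shown above:

```python
def parse_arch_list(arch_list):
    """Parse a space-separated arch string like '8.0 9.0a 10.0a'."""
    archs = set()
    for token in arch_list.split():
        major, minor = token.split(".", 1)  # keep the 'a' suffix with the minor part
        archs.add((int(major), minor))
    return archs

assert parse_arch_list("8.0 9.0a 10.0a") == {(8, "0"), (9, "0a"), (10, "0a")}
```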

Specifying Architectures in Your JIT Module

When creating a JIT module, specify which major SM versions are supported:

from flashinfer.jit.core import gen_jit_spec
from flashinfer.jit import current_compilation_context

def gen_my_hopper_only_module():
    """Example: Kernel works on SM90 and later supported architectures."""
    uri = get_my_uri(...)
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
    # ... copy sources ...

    nvcc_flags = current_compilation_context.get_nvcc_flags_list(
        # Explicitly list supported SM versions - no automatic future compatibility
        supported_major_versions=[9, 10, 11, 12]  # SM90, SM100, SM110, SM120
    )

    return gen_jit_spec(
        name=uri,
        sources=sources,
        extra_cuda_cflags=nvcc_flags,
    )

def gen_my_blackwell_only_module():
    """Example: Kernel only works on SM100 (Blackwell)"""
    uri = get_my_uri(...)
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
    # ... copy sources ...

    nvcc_flags = current_compilation_context.get_nvcc_flags_list(
        supported_major_versions=[10]  # SM100 only
    )

    return gen_jit_spec(
        name=uri,
        sources=sources,
        extra_cuda_cflags=nvcc_flags,
    )

def gen_my_universal_module():
    """Example: Kernel works on all architectures"""
    uri = get_my_uri(...)
    gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri
    # ... copy sources ...

    nvcc_flags = current_compilation_context.get_nvcc_flags_list(
        supported_major_versions=None  # All available architectures
    )

    return gen_jit_spec(
        name=uri,
        sources=sources,
        extra_cuda_cflags=nvcc_flags,
    )

What Happens:

  • ✅ If user's GPU is SM90 and they call a Hopper-only module → Compiles and runs
  • ❌ If user's GPU is SM80 and they call a Hopper-only module → RuntimeError: No supported CUDA architectures found for major versions [9, 10, 11, 12]
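That success/failure behavior boils down to a set intersection; a minimal sketch (hypothetical helper, not the real CompilationContext code) of the check:

```python
def select_archs(target_archs, supported_major_versions):
    """Keep the detected architectures whose major version is supported."""
    if supported_major_versions is None:  # None means "all architectures"
        return set(target_archs)
    selected = {a for a in target_archs if a[0] in supported_major_versions}
    if not selected:
        raise RuntimeError(
            "No supported CUDA architectures found for major versions "
            f"{supported_major_versions}"
        )
    return selected

# SM90 GPU + Hopper-only module: compiles for (9, '0a').
print(select_archs({(9, "0a")}, [9, 10, 11, 12]))

# SM80 GPU + Hopper-only module: raises RuntimeError.
try:
    select_archs({(8, "0")}, [9, 10, 11, 12])
except RuntimeError as e:
    print(e)
```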

Real Examples from FlashInfer


Content truncated.

