metal-kernel
Write Metal/MPS kernels for PyTorch operators. Use when adding MPS device support to operators, implementing Metal shaders, or porting CUDA kernels to Apple Silicon. Covers native_functions.yaml dispatch, host-side operators, and Metal kernel implementation.
Install
mkdir -p .claude/skills/metal-kernel && curl -L -o skill.zip "https://mcp.directory/api/skills/download/2590" && unzip -o skill.zip -d .claude/skills/metal-kernel && rm skill.zip
Installs to .claude/skills/metal-kernel
About this skill
Metal Kernel Writing Guide
This skill guides you through implementing Metal kernels for PyTorch operators on Apple Silicon.
Important: The goal of this skill is to use native Metal capabilities via the c10/metal/ infrastructure, NOT MPSGraph. Native Metal kernels provide better control, performance, and maintainability.
Overview
There are two workflows covered by this skill:
- Adding new MPS support - Implementing a new operator from scratch
- Migrating from MPSGraph - Converting existing MPSGraph-based operators to native Metal
Both workflows involve:
1. Update the dispatch in aten/src/ATen/native/native_functions.yaml
2. Write the Metal kernel in aten/src/ATen/native/mps/kernels/
3. Implement the host-side stub in aten/src/ATen/native/mps/operations/
Step 1: Update native_functions.yaml
Location: aten/src/ATen/native/native_functions.yaml
For New Operators
Find the operator entry and add MPS dispatch:
# Simple MPS-specific implementation
- func: my_op(Tensor self) -> Tensor
  dispatch:
    CPU: my_op_cpu
    CUDA: my_op_cuda
    MPS: my_op_mps

# Shared implementation across devices (preferred for structured kernels)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU, CUDA, MPS: my_op_out

# Structured kernel (preferred for new ops)
- func: my_op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: my_op_out
For Migrating from MPSGraph
When migrating an existing operator from MPSGraph to native Metal, consolidate the dispatch entry:
# BEFORE (MPSGraph-based, separate dispatch)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA: atan2_out
    MPS: atan2_out_mps  # Separate MPS implementation

# AFTER (native Metal, shared dispatch via stub)
- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  dispatch:
    CPU, CUDA, MPS: atan2_out  # MPS now uses the same stub mechanism
Key change: remove the separate MPS: my_op_out_mps entry and add MPS to the shared dispatch line (e.g., CPU, CUDA, MPS: my_op_out).
Dispatch naming conventions:
- MPS: function_name_mps - MPS-specific implementation (old MPSGraph pattern)
- CPU, CUDA, MPS: function_name - shared stub implementation (native Metal pattern)
Step 2: Implement Metal Kernel
Location: aten/src/ATen/native/mps/kernels/
Unary Kernel Pattern
// MyKernel.metal
#include <c10/metal/indexing.h>
#include <c10/metal/utils.h>
#include <metal_stdlib>

using namespace metal;
using namespace c10::metal;

// Define the operation as a functor
struct my_op_functor {
  template <typename T>
  inline T operator()(const T x) {
    return /* your operation */;
  }
};

// Register for supported types
REGISTER_UNARY_OP(my_op, float, float);
REGISTER_UNARY_OP(my_op, half, half);
REGISTER_UNARY_OP(my_op, bfloat, bfloat);
Binary Kernel Pattern
struct my_binary_functor {
  template <typename T>
  inline T operator()(const T a, const T b) {
    return /* your operation */;
  }
};

REGISTER_BINARY_OP(my_binary, float, float);
REGISTER_BINARY_OP(my_binary, half, half);
Binary Kernel Type Registration Macros
For binary operations, use the convenience macros defined in BinaryKernel.metal:
// Floating-point types only (float, half, bfloat)
REGISTER_FLOAT_BINARY_OP(my_op);
// Integral types with float output (for math ops like atan2, copysign)
// Registers: long->float, int->float, short->float, uchar->float, char->float, bool->float
REGISTER_INT2FLOAT_BINARY_OP(my_op);
// Integral types with same-type output (for bitwise/logical ops)
// Registers: long, int, short, uchar, char, bool
REGISTER_INTEGER_BINARY_OP(my_op);
// Floating-point with opmath precision (for ops needing higher precision)
REGISTER_OPMATH_FLOAT_BINARY_OP(my_op);
Common patterns:
- Math functions (atan2, copysign, logaddexp): use both REGISTER_FLOAT_BINARY_OP and REGISTER_INT2FLOAT_BINARY_OP
- Comparison/logical ops (maximum, minimum): use both REGISTER_FLOAT_BINARY_OP and REGISTER_INTEGER_BINARY_OP
- Arithmetic ops (add, sub, mul): use both REGISTER_FLOAT_BINARY_OP and REGISTER_INTEGER_BINARY_OP
Example for atan2 (supports both float and int inputs):
struct atan2_functor {
  template <typename T, enable_if_t<is_floating_point_v<T>, bool> = true>
  inline T operator()(const T a, const T b) {
    return static_cast<T>(precise::atan2(float(a), float(b)));
  }

  template <typename T, enable_if_t<is_integral_v<T>, bool> = true>
  inline float operator()(const T a, const T b) {
    return precise::atan2(float(a), float(b));
  }
};
REGISTER_FLOAT_BINARY_OP(atan2);
REGISTER_INT2FLOAT_BINARY_OP(atan2);
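Together these two registrations reproduce PyTorch's type-promotion rule for atan2: floating inputs keep their dtype, while integral inputs are computed in and returned as float32. A pure-Python sketch of that rule, useful when sanity-checking output dtypes (the helper name and string dtypes are illustrative, not part of the codebase):

```python
import math

def atan2_ref(a, b, dtype):
    """Reference promotion rule mirrored by the Metal registrations:
    integral dtypes -> float32 output; floating dtypes keep their dtype."""
    integral = dtype in {"bool", "uint8", "int8", "int16", "int32", "int64"}
    out_dtype = "float32" if integral else dtype
    return math.atan2(float(a), float(b)), out_dtype

val, out_dtype = atan2_ref(1, 1, "int64")  # pi/4, promoted to float32
```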
With Scalar Parameter
struct my_alpha_functor {
  template <typename T>
  inline T operator()(const T a, const T b, const T alpha) {
    return a + c10::metal::mul(alpha, b);
  }
};

REGISTER_UNARY_ALPHA_OP(my_alpha, float, float, float);
REGISTER_UNARY_ALPHA_OP(my_alpha, half, half, half);
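The alpha functor follows the same contract as the alpha argument of torch.add. A one-line Python reference of the assumed semantics (a + alpha * b):

```python
def add_alpha_ref(a, b, alpha):
    """Reference for the alpha functor: a + alpha * b,
    matching the semantics of torch.add(a, b, alpha=alpha)."""
    return a + alpha * b

result = add_alpha_ref(1.0, 2.0, 0.5)  # 1.0 + 0.5 * 2.0 = 2.0
```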
Type-Specialized Functor
struct special_functor {
  // Floating-point types
  template <typename T, enable_if_t<is_scalar_floating_point_v<T>, bool> = true>
  inline T operator()(const T x) {
    return precise::exp(x); // Use precise math
  }

  // Integral types
  template <typename T, enable_if_t<is_scalar_integral_v<T>, bool> = true>
  inline float operator()(const T x) {
    return precise::exp(float(x));
  }

  // Complex types (float2 for cfloat, half2 for chalf)
  template <typename T, enable_if_t<is_complex_v<T>, bool> = true>
  inline T operator()(const T x) {
    // x.x = real, x.y = imaginary
    return T(/* real */, /* imag */);
  }
};
Note on complex types: complex numbers in Metal are represented as vector types:
- c10::complex<float> maps to float2 (x = real, y = imaginary)
- c10::complex<half> maps to half2
Use is_complex_v<T> to specialize for complex types in functors.
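Because a complex value arrives in the functor as a two-component vector, arithmetic must be written component-wise. A pure-Python sketch of complex multiplication in that (real, imag) layout, mirroring what a Metal functor would do with float2 (not code from the repository):

```python
def cmul(a, b):
    """Component-wise complex multiply on (real, imag) pairs,
    the same layout a Metal functor sees for float2/half2 values."""
    ar, ai = a
    br, bi = b
    return (ar * br - ai * bi, ar * bi + ai * br)

# (1 + 2i) * (3 + 4i) = -5 + 10i
product = cmul((1.0, 2.0), (3.0, 4.0))
```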
Available c10/metal Utilities
utils.h:
- opmath_t<T> - operation math type (half -> float)
- accum_t<T> - accumulation type for reductions
- max(), min() with NaN propagation

special_math.h:
- precise::exp(), precise::log(), precise::sqrt()
- precise::sin(), precise::cos(), precise::tan()
- erf(), erfc(), erfinv()

indexing.h:
- REGISTER_UNARY_OP(name, in_type, out_type)
- REGISTER_BINARY_OP(name, in_type, out_type)
- REGISTER_UNARY_ALPHA_OP(name, in_type, alpha_type, out_type)
Step 3: Implement Host-Side Stub
Location: aten/src/ATen/native/mps/operations/
Choose or create an appropriate file based on operation type:
- UnaryKernel.mm - single-input operations via stub dispatch
- BinaryKernel.mm - two-input operations via stub dispatch
- UnaryOps.mm / BinaryOps.mm - legacy MPSGraph implementations (for reference)
- ReduceOps.mm - reductions (sum, mean, max, etc.)
- Create a new file for distinct operation categories
Stub Registration Pattern (Preferred for Native Metal)
For structured kernels that use the TensorIterator pattern:
// In BinaryKernel.mm (or appropriate file)
static void my_op_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_binary_kernel(iter, "my_op"); // "my_op" matches the functor name in the .metal file
}

// Register the MPS stub - this connects to the dispatch system
REGISTER_DISPATCH(my_op_stub, &my_op_mps_kernel)
For unary operations:
static void my_unary_mps_kernel(TensorIteratorBase& iter) {
  lib.exec_unary_kernel(iter, "my_unary");
}

REGISTER_DISPATCH(my_unary_stub, &my_unary_mps_kernel)
Migration: Removing Old MPSGraph Implementation
When migrating from MPSGraph, also remove the old implementation:
1. Remove from BinaryOps.mm (or UnaryOps.mm):
   - Delete the TORCH_IMPL_FUNC(my_op_out_mps) implementation
   - Remove the corresponding #include <ATen/ops/my_op_native.h> header
2. Add to BinaryKernel.mm (or UnaryKernel.mm):
   - Add the static kernel function
   - Add the REGISTER_DISPATCH call
Step 4: Compile
After making changes, compile to verify everything builds correctly:
cd build && ninja torch_cpu
Testing
Basic operator support is already tested by test_output_match in test/test_mps.py. After implementing an operator, enable testing by removing expected failures:
1. Remove from common_mps.py
Location: torch/testing/_internal/common_mps.py
Find and remove the operator from skip/xfail lists:
# Remove entries like:
MPS_XFAILLIST = {
    "my_op": ...,  # Remove this line
}
MPS_SKIPLIST = {
    "my_op": ...,  # Remove this line
}
2. Remove from OpInfo decorators
Location: torch/testing/_internal/common_methods_invocations.py (or related files)
Remove MPS-specific decorators from the OpInfo:
OpInfo(
    "my_op",
    # Remove decorators like:
    # decorators=[skipMPS, expectedFailureMPS("reason")],
    ...
)
3. Run tests to verify
# Run the specific operator test
python test/test_mps.py -k test_output_match_my_op
# Or run full MPS test suite
python test/test_mps.py
Debugging Metal Kernels with torch.mps.compile_shader
Use torch.mps.compile_shader to JIT-compile and test individual Metal kernels in isolation. This is invaluable for debugging multi-kernel pipelines where you need to verify each stage independently.
Basic Usage
import torch

source = '''
#include <metal_stdlib>
using namespace metal;

kernel void my_kernel(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    uint tid [[thread_position_in_grid]]) {
  output[tid] = input[tid] * 2.0;
}
'''

lib = torch.mps.compile_shader(source)
inp = torch.tensor([1.0, 2.0, 3.0], device='mps')
out = torch.zeros(3, device='mps')
lib.my_kernel(inp, out, threads=[3, 1, 1], group_size=[3, 1, 1])
torch.mps.synchronize()
print(out)  # tensor([2., 4., 6.], device='mps:0')
Dispatch Semantics
compile_shader uses dispatchThreads semantics (same as mtl_dispatch1DJob in PyTorch):
- threads=[N, 1, 1]: the total number of threads (NOT threadgroups)
- group_size=[G, 1, 1]: threads per threadgroup
This differs from the dispatchThreadgroups API used by some host-side code. To match a launch of [encoder dispatchThreadgroups:MTLSizeMake(num_tgs, num_slices, 1) threadsPerThreadgroup:MTLSizeMake(TG_SIZE, 1, 1)]:
# Equivalent compile_shader call:
lib.kernel(args...,
threads=[num_tgs * TG_SIZE, num_slices, 1],
group_size=[TG_SIZE, 1, 1])
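This conversion is easy to get wrong by hand; a tiny helper (the function name is illustrative) that maps dispatchThreadgroups-style arguments onto the threads/group_size pair compile_shader expects:

```python
def threads_for_groups(num_tgs, tg_size, num_slices=1):
    """Convert a dispatchThreadgroups launch (num_tgs groups of tg_size
    threads, over num_slices) into dispatchThreads semantics."""
    threads = [num_tgs * tg_size, num_slices, 1]
    group_size = [tg_size, 1, 1]
    return threads, group_size

# 5 threadgroups of 256 threads -> 1280 total threads
threads, group_size = threads_for_groups(5, 256)
```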
Constant Buffer Parameters
Pass scalar constants as single-element tensors:
slice_size = torch.tensor([1024], dtype=torch.int32, device='mps')
lib.my_kernel(data, output, slice_size, threads=[1024, 1, 1], group_size=[256, 1, 1])
Debugging Strategy for Multi-Kernel Pipelines
When a pipeline of kernels (e.g., histogram → prefix_sum → scatter) produces wrong results, test each kernel individually and verify its output against a Python/NumPy reference:
# 1. Run GPU kernel
lib.histogram(keys, hist, ..., threads=[N, 1, 1], group_size=[256, 1, 1])
torch.mps.synchronize()
# 2. Compute reference in Python
ref_hist = compute_histogram_cpu(keys.cpu().numpy(), ...)
# 3. Compare
assert np.array_equal(hist.cpu().numpy(), ref_hist), "Histogram mismatch!"
This isolates which kernel in the pipeline is broken, rather than debugging the entire pipeline at once.
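For the histogram stage, the Python reference can be as simple as the sketch below (key values and bin count are illustrative; for non-negative integer keys, np.bincount gives the same answer):

```python
import numpy as np

def histogram_ref(keys, num_bins):
    """Pure-Python reference for a histogram kernel: count how many
    keys fall into each of num_bins bins. Compare elementwise against
    the GPU buffer to isolate a broken pipeline stage."""
    hist = np.zeros(num_bins, dtype=np.int64)
    for k in keys:
        hist[k] += 1
    return hist

ref = histogram_ref(np.array([0, 1, 1, 3]), num_bins=4)  # [1, 2, 0, 1]
```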
Common Pitfalls
- Wrong threads count: threads is the total number of threads, not threadgroups. For 5 threadgroups of 256, use threads=[1280, 1, 1].
- Threadgroup memory: compile_shader doesn't support [[threadgroup(N)]] parameters directly. If your kernel needs threadgroup memory, restructure it to use threadgroup arrays declared inside the kernel body instead.
Checklist
- Added MPS dispatch to native_functions.yaml
- Implemented the Metal kernel in kernels/
- Implemented the host-side operator in operations/
- Handles empty tensors
- Handles non-contiguous tensors
- Supports required dtypes (float32, float16, bfloat16, and often complex types via float2/half2)
- Removed expected failures from torch/testing/_internal/common_mps.py
- Removed skip/xfail decorators from OpInfo (if applicable)