Custom Gather-scatter Operator by CUTLASS

This post logs my experience building an efficient custom operator on top of CUTLASS. Jump to Final Implementation for the complete gather-and-scatter matrix multiplication operator.

Intro

Implementing an efficient CUDA kernel is challenging: it requires a thorough understanding of GPU architecture and takes a lot of design time. CUTLASS provides a collection of abstractions for GEMM-based operations. It exploits the GPU's memory hierarchy, swizzling data to maximize memory bandwidth. With CUTLASS, we can build high-performance operators with much less effort.

Simple Example

The developers provide some examples in cutlass/examples/python at main · NVIDIA/cutlass · GitHub, which show how to generate your own GEMM operator with a custom data type and layout. For example,

import cutlass
import torch

dtype = torch.float16
plan = cutlass.Gemm(element=dtype, layout=cutlass.LayoutType.ColumnMajor, element_accumulator=torch.float32)
op = plan.construct()
gs_gemm = cutlass.emit.pytorch(op, name='gemm', cc=plan.cc, sourcedir='gemm_out')

Then you will find the CUDA, PyBind, and setup code under the gemm_out/ directory. To use it, run

cd gemm_out
pip install -e . # install in the current env

Now you can directly import your operator in your code, for example,

import torch
import gemm # import your op
import cutlass

dtype = torch.float16
torch.manual_seed(2024)  # torch.randint draws from torch's RNG, not Python's random module

# Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K
def initialize(dtype, M, N, K):
    sizes = [(M, K), (K, N), (M, N)]
    return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]

A, B, C = initialize(dtype, 128, 128, 128)
cutlass_out = gemm.run(A, B, C)

This is cool, but as I am going to show, cutlass.Gemm does not expose every option of the underlying CUDA operator.

Motivation

I was building a gather-and-scatter matrix multiplication operator several months ago. Although we chose Triton for easy kernel fusion, its performance is still not as good as the CUTLASS version provided in cutlass/examples/36_gather_scatter_fusion at main · NVIDIA/cutlass · GitHub.

Unfortunately, the template used by the Python emit function doesn’t include the necessary options (GatherA/B and ScatterD) or the input indices. So my plan is to generate a normal GEMM kernel and modify the generated C++ code directly, instead of changing the CUTLASS source code.

Checking the source code in “python/cutlass/backend/gemm_operation.py”, the C++ code template looks like this:

// Gemm operator ${operation_name}
using ${operation_name}_base =
  typename cutlass::gemm::kernel::DefaultGemmUniversal<
    ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
    ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
    ${element_c}, ${layout_c},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor},
    ${swizzling_functor},
    ${stages},
    ${math_operation}
  >::GemmKernel;

Compared with the full DefaultGemmUniversal, it doesn’t include the last six options.

    /// Gather operand A by using an index array
    bool GatherA = false,
    /// Gather operand B by using an index array
    bool GatherB = false,
    /// Scatter result D by using an index array
    bool ScatterD = false,
    /// Permute result D
    typename PermuteDLayout = layout::NoPermute,
    /// Permute operand A
    typename PermuteALayout_ = layout::NoPermute,
    /// Permute operand B
    typename PermuteBLayout_ = layout::NoPermute,
    ///

Because changing the CUTLASS source can be risky and complicated, I decided to modify the code generated by the CUTLASS emitter instead.
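To make the limitation concrete, here is a minimal sketch of placeholder substitution. It assumes nothing about the emitter's internals: Python's `string.Template` is swapped in only because it understands the same `${name}` syntax, the template is heavily abbreviated, and the `gather_b` key is hypothetical. The point is that a value with no matching placeholder never reaches the emitted C++.

```python
from string import Template

# Abbreviated stand-in for the emitter's GEMM template; only three of the
# real placeholders are kept, and the values below are illustrative.
gemm_template = Template(
    "using ${operation_name}_base =\n"
    "  typename cutlass::gemm::kernel::DefaultGemmUniversal<\n"
    "    /* ... other parameters ... */\n"
    "    ${stages},\n"
    "    ${math_operation}\n"
    "  >::GemmKernel;\n"
)

values = {
    "operation_name": "gs_gemm",
    "stages": "3",
    "math_operation": "cutlass::arch::OpMultiplyAdd",
    # Hypothetical key: the template has no ${gather_b} placeholder,
    # so this entry is silently ignored by substitute().
    "gather_b": "true",
}

emitted = gemm_template.substitute(values)
print(emitted)
```

Since the trailing template parameters are simply absent from the text, the only ways to set them are to edit the template or, as done in this post, to edit the generated file.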

Implement Custom Operator

Our first step is to generate normal GEMM code with CUTLASS, just as we did in the Simple Example, but with the name changed to gs_gemm (gather-scatter GEMM):

import cutlass
import torch

dtype = torch.float16
plan = cutlass.Gemm(element=dtype, layout=cutlass.LayoutType.ColumnMajor, element_accumulator=torch.float32)
op = plan.construct()
gs_gemm = cutlass.emit.pytorch(op, name='gs_gemm', cc=plan.cc, sourcedir='gs_out')

After that, go to the generated folder, which looks like

.
├── gs_gemm.cpp
├── gs_gemm_kernel.cu
└── setup.py

We start with the .cpp file and change the interface of the operator.

// This file was automatically generated by the CUTLASS 3.5.0 Python interface (https://github.com/nvidia/cutlass/python)

#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>

// CUDA forward declarations
// indices shape is (gather_size, 1)
at::Tensor gs_gemm_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, at::optional<const at::Tensor> Indices=at::nullopt, float alpha=1.f, float beta=0.f);

// C++ interface
at::Tensor gs_gemm(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, at::optional<const at::Tensor> Indices=at::nullopt, float alpha=1.f, float beta=0.f) {
  return gs_gemm_kernel(A, B, C, Indices, alpha, beta);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run", py::overload_cast<const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>, at::optional<const at::Tensor>, float, float>(&gs_gemm), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("Indices") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
}

I added a new argument Indices, which specifies the columns we are going to select in matrix B.
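As a rough mental model of what the Indices argument is meant to do, the sketch below spells out the math in plain Python with nested lists. It assumes alpha = 1 and beta = 0, uses the operator's own view of the matrices (not the PyTorch view), and the function name is hypothetical. It is a reference for the intended result, not how the CUTLASS kernel is implemented.

```python
def gather_scatter_gemm_reference(A, B, indices, n_out):
    """For each j in indices: D[:, j] = A @ B[:, j]. Other columns stay zero."""
    M, K = len(A), len(A[0])
    D = [[0] * n_out for _ in range(M)]
    for j in indices:
        for i in range(M):
            # Gather column j of B, then scatter the result into column j of D.
            D[i][j] = sum(A[i][k] * B[k][j] for k in range(K))
    return D


A = [[1, 2],
     [3, 4]]
B = [[1, 0, 2],
     [0, 1, 3]]

D = gather_scatter_gemm_reference(A, B, indices=[0, 2], n_out=3)
print(D)  # [[1, 0, 8], [3, 0, 18]]
```

Column 1 of D is never touched because index 1 is not selected, which is why the same index vector can serve both the gather on B and the scatter on D.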

Then let’s go to the kernel code. First we need to set GatherB and ScatterD to true when specifying the operator,

#include "cutlass/gemm/device/gemm_universal.h"


// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_256x128_64x3_tt_align8
using DeviceKernel =
    typename cutlass::gemm::device::GemmUniversal<
        // Data type and layout of operand A
        cutlass::half_t, cutlass::layout::ColumnMajor,
        // Data type and layout of operand B
        cutlass::half_t, cutlass::layout::ColumnMajor,
        // Data type and layout of operand C
        cutlass::half_t, cutlass::layout::ColumnMajor,
        // Data type of accumulator
        float,
        // Class of operation
        cutlass::arch::OpClassTensorOp,
        // Compute capability of the target kernel
        cutlass::arch::Sm80,
        // Threadblock tile shape
        cutlass::gemm::GemmShape<256, 128, 64>,
        // Warp tile shape
        cutlass::gemm::GemmShape<64, 64, 64>,
        // Instruction shape
        cutlass::gemm::GemmShape<16, 8, 16>,
        // Epilogue functor
        cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,
        // Swizzling function
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
        // Number of pipeline stages
        3,
        // Alignment of operands A and B
        8, 8,
        // Type of math operation
        cutlass::arch::OpMultiplyAdd,
        // Complex transform types of operands A and B
        cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone,
+        false, /*GatherA*/
+        true,  /*GatherB*/
+        true   /*ScatterD*/
    >;
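The tile shapes in this definition also fix how many warps and threads each threadblock uses. A small worked calculation, assuming the usual rule that the warp count is the element-wise ratio of threadblock tile to warp tile and that a warp has 32 threads (the helper name is hypothetical):

```python
def warps_per_threadblock(threadblock_shape, warp_shape):
    """Multiply the per-dimension ratios threadblock_tile / warp_tile."""
    total = 1
    for tb, w in zip(threadblock_shape, warp_shape):
        if tb % w != 0:
            raise ValueError(f"threadblock extent {tb} is not a multiple of warp extent {w}")
        total *= tb // w
    return total


warps = warps_per_threadblock((256, 128, 64), (64, 64, 64))
threads = warps * 32  # 32 threads per warp

print(warps, threads)  # 8 256
```

So the 256x128x64 threadblock tile is split into 4 x 2 x 1 warp tiles, giving 8 warps (256 threads) per threadblock.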

Then, in the kernel launch function that follows, add the new argument and pass it into the constructor of the kernel’s Arguments struct:

using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
cutlass::Status gs_gemm_kernel_run(int M, int N, int K,
                        const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, 
 +                       const int* Indices, 
                        DeviceKernel::ElementC* D,
                        ElementCompute alpha, ElementCompute beta) {

  typename DeviceKernel::Arguments arguments {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {M, N, K},                                        // problem size
      1,                                                // split k dimension
      {alpha, beta},
      A, B, C, D,
      0, 0, 0, 0,                                       // batch strides
      DeviceKernel::LayoutA::packed({M, K}).stride(0),  // lda
      DeviceKernel::LayoutB::packed({K, N}).stride(0),  // ldb
      DeviceKernel::LayoutC::packed({M, N}).stride(0),  // ldc
      DeviceKernel::LayoutC::packed({M, N}).stride(0),  // ldd
+      nullptr,                                          // <- pointer to index vector to gather A on device
+      Indices,                                          // <- pointer to index vector to gather B on device
+      Indices                                         // <- pointer to index vector to scatter D on device
  };
  // keep the rest of the kernel launcher
  ...
  return status;
}
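The lda/ldb/ldc/ldd entries are leading dimensions. For a packed column-major matrix, the leading dimension is the number of rows, and element (row, col) lives at offset row + col * ld. The sketch below models that addressing in plain Python; the helper names are hypothetical and only illustrate what `LayoutA::packed({M, K}).stride(0)` evaluates to for a column-major layout.

```python
def packed_column_major_ld(rows, cols):
    """Leading dimension of a densely packed column-major matrix."""
    return rows


def column_major_offset(row, col, ld):
    """Linear offset of element (row, col) in a column-major buffer."""
    return row + col * ld


M, K = 4, 3
lda = packed_column_major_ld(M, K)

# Walk the matrix column by column: offsets should be 0, 1, 2, ... with no gaps.
offsets = [column_major_offset(r, c, lda) for c in range(K) for r in range(M)]
print(lda, offsets)
```

Consecutive elements of one column are adjacent in memory, and moving to the next column jumps by `ld` elements.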

Finally, we also need to add the Indices argument to the function that converts tensors to raw pointers.

at::Tensor gs_gemm_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, 
+                        at::optional<const at::Tensor> Indices, 
                        float alpha, float beta) {
    int M = A.size(0);
    int N = B.size(1);
    int K = A.size(1);

    typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
                                            nullptr :
                                            reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
+    int* ptrIndices = (Indices == at::nullopt) ?
+                      nullptr :
+                      reinterpret_cast<int*>(Indices->contiguous().data_ptr());
    at::Tensor D = B.new_empty({M, N}, torch::kF16);

    cutlass::Status status = gs_gemm_kernel_run(M, N, K,
                                                reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
                                                reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
                                                ptrC,
+                                                ptrIndices,
                                                reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
                                                ElementCompute(alpha), ElementCompute(beta));

    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}

Here I made Indices an optional input. During my tests, however, this turned out to be a bad idea: because we set GatherB to true, the kernel faults if Indices is a nullptr. I will update the code later.
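Until the C++ side is changed, one possible stopgap is to reject a missing index tensor on the Python side before the kernel is ever launched. The wrapper below is a hypothetical sketch, not part of the generated code: `run_fn` stands for `gs_gemm.run`, and `fake_run` is a stub so the example runs without the compiled extension.

```python
def run_gather_scatter(run_fn, A, B, indices, C=None, alpha=1.0, beta=0.0):
    """Call run_fn(A, B, C, indices, alpha, beta) only if indices are given."""
    if indices is None:
        raise ValueError(
            "indices is required: the kernel is compiled with GatherB and "
            "ScatterD enabled and cannot run with a null index pointer"
        )
    return run_fn(A, B, C, indices, alpha, beta)


def fake_run(A, B, C, indices, alpha, beta):
    # Stub standing in for gs_gemm.run; it only echoes what it received.
    return ("ran", list(indices))


print(run_gather_scatter(fake_run, "A", "B", [0, 1, 2]))

try:
    run_gather_scatter(fake_run, "A", "B", None)
except ValueError as err:
    print("rejected:", err)
```

With the real extension the call would look like `run_gather_scatter(gs_gemm.run, A, B, indices)`, turning a device-side fault into an ordinary Python exception.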

Experiment

Now let’s see if the new kernel works. Start by compiling the kernel with pip install -e . and you will get

.
├── build
│   ├── lib.linux-x86_64-cpython-39
│   │   └── gs_gemm.cpython-39-x86_64-linux-gnu.so
│   └── temp.linux-x86_64-cpython-39
│       ├── build.ninja
│       ├── gs_gemm_kernel.o
│       └── gs_gemm.o
├── compiled_cache.db
├── gs_gemm.cpp
├── gs_gemm.cpython-39-x86_64-linux-gnu.so
├── gs_gemm.egg-info
│   ├── dependency_links.txt
│   ├── PKG-INFO
│   ├── SOURCES.txt
│   └── top_level.txt
├── gs_gemm_kernel.cu
└── setup.py

I wrote a test script to check whether the results are correct:

import torch
import gs_gemm
import cutlass

dtype = torch.float16

torch.manual_seed(2023)  # torch.randint draws from torch's RNG, not Python's random module

# Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K
def initialize(dtype, M, N, K):
    sizes = [(M, K), (K, N), (M, N)]
    return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]

A, B, C = initialize(dtype, 128, 128, 128)

# select the first 12 columns
indices = torch.arange(12, device='cuda', dtype=torch.int32).reshape(1, -1)

cutlass_out = gs_gemm.run(A.clone(), B.clone(), None, indices)
torch.cuda.synchronize()
cutlass_out = cutlass_out.cpu()
print(cutlass_out)

The output looks like

tensor([[ 98.,  22.,  30.,  ...,  24.,  43.,  31.],
        [ 67.,   0.,  42.,  ..., -27., -21.,  23.],
        [ 18., -15.,  64.,  ...,  57.,  39.,  33.],
        ...,
        [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
        [  0.,   0.,   0.,  ...,   0.,   0.,   0.],
        [  0.,   0.,   0.,  ...,   0.,   0.,   0.]], dtype=torch.float16)

You can verify the correctness of the first 12 rows by comparing with

torch_out = torch.zeros_like(C)
torch_out[indices[0]] += (A.T @ B[indices[0]].T).T
print(torch_out)

The output is the same.

I used several transposes here because the matrices in our operator are declared column-major for better coalescing during gathering. You can find a more detailed explanation in my previous post, Efficient Gather-and-scatter Matrix Multiplication Kernel with Triton - Xueshen Liu.
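The reason transposes appear can be seen without a GPU. A row-major buffer of shape (rows, cols), read as column-major with extents (cols, rows), is exactly the transpose; in the test every matrix is 128 x 128, so the two extents coincide. The sketch below, in plain Python with hypothetical helper names, checks that fact and the identity behind the reference expression: a product of transposed inputs whose result is read back transposed equals B @ A.

```python
def as_row_major(buf, rows, cols):
    return [[buf[r * cols + c] for c in range(cols)] for r in range(rows)]


def as_column_major(buf, rows, cols):
    return [[buf[r + c * rows] for c in range(cols)] for r in range(rows)]


def transpose(m):
    return [list(col) for col in zip(*m)]


def matmul(x, y):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


# The same six numbers seen through both layouts.
buf = [1, 2, 3, 4, 5, 6]
torch_view = as_row_major(buf, 2, 3)      # what PyTorch sees
kernel_view = as_column_major(buf, 3, 2)  # what a column-major kernel sees
assert kernel_view == transpose(torch_view)

# (A^T @ B^T)^T == B @ A, the identity behind the transposes in the check.
A = [[1, 2], [3, 4]]
B = [[0, 1], [1, 0]]
lhs = transpose(matmul(transpose(A), transpose(B)))
rhs = matmul(B, A)
print(lhs, rhs)
```

Both sides print the same matrix, which is why the PyTorch reference wraps the product in `.T` on the inputs and on the result.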

Final Implementation

gs_gemm.cpp

// This file was automatically generated by the CUTLASS 3.5.0 Python interface (https://github.com/nvidia/cutlass/python)

#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>

// CUDA forward declarations
// indices shape is (gather_size, 1)
at::Tensor gs_gemm_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, at::optional<const at::Tensor> Indices=at::nullopt, float alpha=1.f, float beta=0.f);

// C++ interface
at::Tensor gs_gemm(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, at::optional<const at::Tensor> Indices=at::nullopt, float alpha=1.f, float beta=0.f) {
  return gs_gemm_kernel(A, B, C, Indices, alpha, beta);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run", py::overload_cast<const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>, at::optional<const at::Tensor>, float, float>(&gs_gemm), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("Indices") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
}

gs_gemm_kernel.cu

// This file was automatically generated by the CUTLASS 3.5.0 Python interface (https://github.com/nvidia/cutlass/python)

#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"

// helper function allocating the memory
void* device_memory_allocation(size_t size, int device_id=0) {
    if (size > 0) {
        torch::Device device(torch::kCUDA, device_id);
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
        at::Tensor device_tensor = torch::empty({(long)size,}, options);
        return reinterpret_cast<void*>(device_tensor.data_ptr());
    } else {
        return nullptr;
    }
}


#include "cutlass/gemm/device/gemm_universal.h"


// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_256x128_64x3_tt_align8
using DeviceKernel =
    typename cutlass::gemm::device::GemmUniversal<
        // Data type and layout of operand A
        cutlass::half_t, cutlass::layout::ColumnMajor,
        // Data type and layout of operand B
        cutlass::half_t, cutlass::layout::ColumnMajor,
        // Data type and layout of operand C
        cutlass::half_t, cutlass::layout::ColumnMajor,
        // Data type of accumulator
        float,
        // Class of operation
        cutlass::arch::OpClassTensorOp,
        // Compute capability of the target kernel
        cutlass::arch::Sm80,
        // Threadblock tile shape
        cutlass::gemm::GemmShape<256, 128, 64>,
        // Warp tile shape
        cutlass::gemm::GemmShape<64, 64, 64>,
        // Instruction shape
        cutlass::gemm::GemmShape<16, 8, 16>,
        // Epilogue functor
        cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, float, float>,
        // Swizzling function
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
        // Number of pipeline stages
        3,
        // Alignment of operands A and B
        8, 8,
        // Type of math operation
        cutlass::arch::OpMultiplyAdd,
        // Complex transform types of operands A and B
        cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone,
        false, /*GatherA*/
        true,  /*GatherB*/
        true   /*ScatterD*/
    >;


using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
cutlass::Status gs_gemm_kernel_run(int M, int N, int K,
                        const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, 
                        const int* Indices, 
                        DeviceKernel::ElementC* D,
                        ElementCompute alpha, ElementCompute beta) {

  typename DeviceKernel::Arguments arguments {
      cutlass::gemm::GemmUniversalMode::kGemm,
      {M, N, K},                                        // problem size
      1,                                                // split k dimension
      {alpha, beta},
      A, B, C, D,
      0, 0, 0, 0,                                       // batch strides
      DeviceKernel::LayoutA::packed({M, K}).stride(0),  // lda
      DeviceKernel::LayoutB::packed({K, N}).stride(0),  // ldb
      DeviceKernel::LayoutC::packed({M, N}).stride(0),  // ldc
      DeviceKernel::LayoutC::packed({M, N}).stride(0),  // ldd
      nullptr,                                          // <- pointer to index vector to gather A on device
      Indices,                                          // <- pointer to index vector to gather B on device
      Indices                                         // <- pointer to index vector to scatter D on device
  };

  size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  DeviceKernel gemm_op;
  cutlass::Status status = gemm_op.initialize(arguments,
                                              workspace.get(),
                                              nullptr);     // CUDA stream

  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  status = gemm_op();
  return status;
}

at::Tensor gs_gemm_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, 
                        at::optional<const at::Tensor> Indices, 
                        float alpha, float beta) {
    int M = A.size(0);
    int N = B.size(1);
    int K = A.size(1);

    typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
                                            nullptr :
                                            reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
    int* ptrIndices = (Indices == at::nullopt) ?
                      nullptr :
                      reinterpret_cast<int*>(Indices->contiguous().data_ptr());
    at::Tensor D = B.new_empty({M, N}, torch::kF16);

    cutlass::Status status = gs_gemm_kernel_run(M, N, K,
                                                reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
                                                reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
                                                ptrC,
                                                ptrIndices,
                                                reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
                                                ElementCompute(alpha), ElementCompute(beta));

    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}
