pass function pointers to kernels

I was reading Lei Mao's blog about passing function pointers to CUDA kernels and thought I should update the code with modern C++17 syntax and some other things that I thought could be improved.

First, let's understand the problem. Passing function pointers is useful when we want to dynamically select which device function a kernel calls from the host. But when we try to pass a raw function pointer to a CUDA kernel, we're passing a host-side memory address that doesn't exist on the device — the device can't dereference it.

This is the problem we need to fix:

#include <cuda_runtime.h>

// Device-side binary op we would like to select from the host.
__device__ float add(float a, float b) {
    return a + b;
}

// Applies `op` elementwise, one thread per element (launched with a single
// block of 10 threads below, so no bounds check is done here).
__global__ void kernel(float (*op)(float, float), float* a, float* b, float* out) {
    int idx = threadIdx.x;
    out[idx] = op(a[idx], b[idx]);  // op points to host memory, not device
}

// Deliberately broken: demonstrates why a host-side function pointer cannot
// be handed to a kernel.
int main() {
    float *d_a, *d_b, *d_out;
    cudaMalloc(&d_a, sizeof(float) * 10);
    cudaMalloc(&d_b, sizeof(float) * 10);
    cudaMalloc(&d_out, sizeof(float) * 10);

    // This will crash or produce garbage results, as the device doesn't
    // understand the raw host-side pointer sent to it.
    // Here clangd gives us the error: Reference to __device__ function 'add' in __host__ function
    kernel<<<1, 10>>>(add, d_a, d_b, d_out);

    return 0;
}

The original code works around this with static device pointers and manual copying. But I had some readability issues with it so I wanted to improve it a bit.

how do you pass a function as an argument to a CUDA kernel?

Fundamentally, we have to do four things:

  1. Declare device functions (__device__)
  2. Store their addresses in device-side global variables (__device__ BinaryOp d_add = add_func)
  3. Copy that address to the host using cudaMemcpyFromSymbol
  4. Pass the copied address into the kernel launch

I had a few goals:

  • Make it correct first
  • Make it safe (RAII, error checks, type constraints)
  • Make it clear (good names, less duplication)
  • Make it fast (restrict, compiler hints)
#include <cstdlib>     // std::exit, EXIT_FAILURE
#include <iostream>
#include <type_traits> // std::is_arithmetic_v
#include <utility>     // std::exchange

#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <driver_types.h>

// Abort with a descriptive message when a CUDA API call returns an error.
inline void cuda_check(cudaError_t err, const char *msg) {
  if (err == cudaSuccess)
    return;
  std::cerr << msg << ": " << cudaGetErrorString(err) << '\n';
  std::exit(EXIT_FAILURE);
}

// since there's ownership transfer of raw pointers, we follow Rule of Five
// RAII wrapper owning a single T in device memory.
// Since there's ownership transfer of raw pointers, we follow the Rule of
// Five: copy deleted, move transfers ownership.
template <typename T> struct DBuf {
  T *ptr = nullptr;

  DBuf() { cuda_check(cudaMalloc(&ptr, sizeof(T)), "DBuf cudaMalloc"); }
  ~DBuf() {
    // No cuda_check here: a destructor that can exit() mid-unwind is worse
    // than a leaked handle at shutdown.
    if (ptr)
      cudaFree(ptr);
  }

  DBuf(const DBuf &) = delete;
  DBuf &operator=(const DBuf &) = delete;
  DBuf(DBuf &&o) noexcept : ptr(std::exchange(o.ptr, nullptr)) {}
  // Move assignment was missing: declaring the move ctor suppresses it, so
  // the original DBuf was not move-assignable (breaking e.g. std::swap).
  DBuf &operator=(DBuf &&o) noexcept {
    if (this != &o) {
      if (ptr)
        cudaFree(ptr);
      ptr = std::exchange(o.ptr, nullptr);
    }
    return *this;
  }

  // Blocking host-to-device copy of one value.
  void h2d(const T &val) {
    cuda_check(cudaMemcpy(ptr, &val, sizeof(T), cudaMemcpyHostToDevice),
               "DBuf h2d");
  }
  // Blocking device-to-host copy; also synchronizes with prior work on the
  // default stream, so the value is safe to read on return.
  T d2h() const {
    T val;
    cuda_check(cudaMemcpy(&val, ptr, sizeof(T), cudaMemcpyDeviceToHost),
               "DBuf d2h");
    return val;
  }
};

// Device-side binary ops, instantiated per arithmetic type T.
template <typename T> __device__ T add_func(T x, T y) { return x + y; }
template <typename T> __device__ T mul_func(T x, T y) { return x * y; }

// Device-resident function-pointer symbols. The initializers run on the
// device side, so these hold *device* addresses of the ops; the host reads
// them back with cudaMemcpyFromSymbol (see dptr2hptr) because it cannot take
// the address of a __device__ function directly.
template <typename T> using BinaryOp = T (*)(T, T); // type alias
template <typename T> __device__ BinaryOp<T> dptr_add = add_func<T>;
template <typename T> __device__ BinaryOp<T> dptr_mul = mul_func<T>;

// Copy the device function pointer stored in the __device__ symbol `dptr`
// back to the host so it can be passed as a kernel argument. The argument
// must be the device variable itself (cudaMemcpyFromSymbol identifies the
// symbol by its address). The original left this call unchecked: an
// invalid-symbol error would have silently handed garbage to the kernel.
template <typename T> BinaryOp<T> dptr2hptr(BinaryOp<T> &dptr) {
  BinaryOp<T> h_ptr;
  cuda_check(cudaMemcpyFromSymbol(&h_ptr, dptr, sizeof(BinaryOp<T>)),
             "cudaMemcpyFromSymbol");
  return h_ptr;
}

// kernel
// Single-thread kernel: applies the binary op to the scalars at x and y and
// stores the result at out. Intended to be launched <<<1, 1>>>.
template <typename T>
__global__ void apply_op(BinaryOp<T> op, const T *__restrict__ x,
                         const T *__restrict__ y, T *__restrict__ out) {
  out[0] = op(x[0], y[0]);
}

// Round-trips a pair of scalars through the device, applying the add and mul
// ops selected dynamically from the host, and prints both results.
template <typename T> void test(T x, T y) {
  static_assert(std::is_arithmetic_v<T>, "T must be arithmetic");

  // 1. send data to device
  DBuf<T> dx, dy, d_res;
  dx.h2d(x);
  dy.h2d(y);

  // 2. get pointer from device to host before dynamic apply_op call
  auto hptr_add = dptr2hptr(dptr_add<T>);
  auto hptr_mul = dptr2hptr(dptr_mul<T>);

  // 3. pass op dynamically from host. Kernel launches don't return a status,
  // so cudaGetLastError() is needed to catch launch-config errors; the
  // synchronize then surfaces any asynchronous execution errors.
  apply_op<T><<<1, 1>>>(hptr_add, dx.ptr, dy.ptr, d_res.ptr);
  cuda_check(cudaGetLastError(), "add kernel launch");
  cuda_check(cudaDeviceSynchronize(), "add kernel sync check");
  std::cout << "   Sum:   " << d_res.d2h() << '\n';

  apply_op<T><<<1, 1>>>(hptr_mul, dx.ptr, dy.ptr, d_res.ptr);
  cuda_check(cudaGetLastError(), "mul kernel launch");
  cuda_check(cudaDeviceSynchronize(), "mul kernel sync check");
  std::cout << "   Product:   " << d_res.d2h() << '\n';
}

// Exercise the dynamic-dispatch path for each supported arithmetic type.
// Template arguments are deduced from the literal types.
int main() {
  std::cout << "int:\n";
  test(2, 10);

  std::cout << "float:\n";
  test(2.05f, 10.0f);

  std::cout << "double:\n";
  test(2.05, 10.0);
}
$ nvcc main.cu  -o main && ./main
int:
   Sum:   12
   Product:   20
float:
   Sum:   12.05
   Product:   20.5
double:
   Sum:   12.05
   Product:   20.5