// kernels for ggml-cuda

#include <cuda_fp16.h>
#include <climits>

#ifndef WARP_SIZE
#define WARP_SIZE 32 // warp size assumed by the per-warp reduction loops below
#endif

template<typename dst_t>
using to_t_cuda_t = void (*)(const void * x, dst_t * y, int k, cudaStream_t stream);

// support for vector types in generic code
template<typename T> struct vec2_t_impl;
template<> struct vec2_t_impl<half>  { typedef half2  type; };
template<> struct vec2_t_impl<float> { typedef float2 type; };

template<typename T> using vec2_t = typename vec2_t_impl<T>::type;

template<typename T> inline __host__ __device__ vec2_t<T> make_vec2_t(const T & x, const T & y);

template<> inline __host__ __device__ vec2_t<half>  make_vec2_t(const half  & x, const half  & y) { return make_half2 (x, y); }
template<> inline __host__ __device__ vec2_t<float> make_vec2_t(const float & x, const float & y) { return make_float2(x, y); }

// the cuda headers define operators for half2, but not for float2
// they are defined here to simplify generic code
inline __host__ __device__ float2 operator+(const float2 & a, const float2 & b) { return make_float2(a.x + b.x, a.y + b.y); }
inline __host__ __device__ float2 operator-(const float2 & a, const float2 & b) { return make_float2(a.x - b.x, a.y - b.y); }
inline __host__ __device__ float2 operator*(const float2 & a, const float2 & b) { return make_float2(a.x * b.x, a.y * b.y); }
inline __host__ __device__ float2 operator/(const float2 & a, const float2 & b) { return make_float2(a.x / b.x, a.y / b.y); }
inline __host__ __device__ float2 & operator+=(float2 & a, const float2 & b) { a.x += b.x; a.y += b.y; return a; }
inline __host__ __device__ float2 & operator-=(float2 & a, const float2 & b) { a.x -= b.x; a.y -= b.y; return a; }
inline __host__ __device__ float2 & operator*=(float2 & a, const float2 & b) { a.x *= b.x; a.y *= b.y; return a; }
inline __host__ __device__ float2 & operator/=(float2 & a, const float2 & b) { a.x /= b.x; a.y /= b.y; return a; }

template<typename dst_t>
using dequantize_kernel_t = void (*)(const void * vx, const int ib, const int iqs, vec2_t<dst_t> & v);

// half and half2 overloads of the math functions used in the kernels
__device__ half  sqrt(const half x) { return hsqrt(x); }
__device__ half  exp (const half x) { return hexp(x); }
__device__ half2 exp (const half2 x) { return h2exp(x); }
__device__ half  cos (const half x) { return hcos(x); }
__device__ half  sin (const half x) { return hsin(x); }
__device__ half  max (const half x, const half y) { return __hmax(x, y); }
__device__ half2 max (const half2 x, const half2 y) { return __hmax2(x, y); }

template<typename T> struct op_max { __device__ T operator()(T a, T b) const { return max(a, b); } };
template<typename T> struct op_sum { __device__ T operator()(T a, T b) const { return a + b; } };

template<template<typename> class op_t, typename T>
static inline __device__ T warp_reduce_all(T val) {
    op_t<T> op;
#pragma unroll
    for (int mask = warpSize/2; mask > 0; mask /= 2) {
        val = op(val, __shfl_xor_sync(0xffffffff, val, mask, 32));
    }
    return val;
}

template<typename T> static __device__ T zero_init() { return T(0); }
template<> __device__ half2 zero_init<half2>() { return half2(0.0f, 0.0f); }

template<template<typename> class op_t, typename T>
static __device__ T block_reduce_all(const T val, const T init = zero_init<T>()) {
    const int warp_id   = threadIdx.x / warpSize; // warp id within the block
    const int lane_id   = threadIdx.x % warpSize; // lane id within the warp
    const int num_warps = blockDim.x / warpSize;  // number of warps in the block

    __shared__ T lane_result[32]; // max 32 warps per block

    // reduce warps
    T warp_reduction = warp_reduce_all<op_t>(val);

    __syncthreads();

    // first thread within a warp writes reduction to shared memory
    if (lane_id == 0) {
        lane_result[warp_id] = warp_reduction;
    }

    // wait for all warps to finish writing their reductions
    __syncthreads();

    // reduce the results of all warps
    T block_reduction = init;
    if (lane_id < num_warps) {
        block_reduction = lane_result[lane_id];
    }

    block_reduction = warp_reduce_all<op_t>(block_reduction);

    return block_reduction;
}
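// Usage sketch (not part of the original kernels): a minimal example of how the reduction
// helpers above are meant to be composed. The kernel name and shapes are assumptions made
// for illustration only. Each block reduces one row of n floats to a single sum; it assumes
// the block size is a multiple of the warp size and has at most 32 warps, as required by
// block_reduce_all.
static __global__ void k_example_row_sum(const float * x, float * sums, const int n) {
    // per-thread partial sum over the row assigned to this block
    float partial = 0.0f;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        partial += x[blockIdx.x*n + i];
    }

    // combine the partial sums of all threads in the block
    const float total = block_reduce_all<op_sum>(partial);

    if (threadIdx.x == 0) {
        sums[blockIdx.x] = total;
    }
}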
template<typename dst_t>
static __device__ void convert_fp16(const void * vx, const int ib, const int iqs, vec2_t<dst_t> & v) {
    const half * x = (const half *) vx;

    v.x = (dst_t)(x[ib + iqs + 0]);
    v.y = (dst_t)(x[ib + iqs + 1]);
}

template<typename dst_t>
static __device__ void convert_fp32(const void * vx, const int ib, const int iqs, vec2_t<dst_t> & v) {
    const float * x = (const float *) vx;

    v.x = (dst_t)(x[ib + iqs + 0]);
    v.y = (dst_t)(x[ib + iqs + 1]);
}

template<typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_mul_mat_p021(const src0_t * vx, const src1_t * y, dst_t * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
    const src0_t * x = vx;

    // const int col_x = blockDim.x*blockIdx.x + threadIdx.x;
    // const int row_x = blockDim.y*blockIdx.y + threadIdx.y;

    const int row_x   = blockDim.y*blockIdx.y + threadIdx.y;
    const int channel = blockDim.z*blockIdx.z + threadIdx.z;

    const int nrows_y   = ncols_x;
    const int nrows_dst = nrows_x;
    const int row_dst   = row_x;

    dst_t tmp = 0;

    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
        const int col_x = col_x0 + threadIdx.x;

        if (col_x >= ncols_x) {
            break;
        }

        // x is transposed and permuted
        const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
        const dst_t xi = (dst_t)(x[ix]);

        const int row_y = col_x;

        // y is not transposed but permuted
        const int iy = channel*nrows_y + row_y;

        tmp += xi * y[iy];
    }

    // dst is not transposed and not permuted
    const int idst = channel*nrows_dst + row_dst;

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    if (threadIdx.x == 0) {
        dst[idst] = tmp;
    }
}

template<typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_mul_mat_vec_nc(
    const src0_t * vx, const src1_t * y, dst_t * dst, const int ncols_x, const int nrows_x,
    const int row_stride_x, const int nchannels_x, const int channel_stride_x) {

    const src0_t * x = vx;

    const int row_x   = blockDim.y*blockIdx.y + threadIdx.y;
    const int channel = blockDim.z*blockIdx.z + threadIdx.z;

    const int nrows_y   = ncols_x;
    const int nrows_dst = nrows_x;
    const int row_dst   = row_x;

    const int idst = channel*nrows_dst + row_dst;

    dst_t tmp = 0;

    for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
        const int col_x = col_x0 + threadIdx.x;

        if (col_x >= ncols_x) {
            break;
        }

        const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
        const dst_t xi = (dst_t)(x[ix]);

        const int row_y = col_x;

        const int iy = channel*nrows_y + row_y;

        tmp += xi * y[iy];
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    if (threadIdx.x == 0) {
        dst[idst] = tmp;
    }
}

template<typename src_t, typename dst_t>
static __global__ void k_cpy(const char * cx, char * cdst, const int ne,
                             const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
                             const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= ne) {
        return;
    }

    const int i02 = i / (ne00*ne01);
    const int i01 = (i - i02*ne01*ne00) / ne00;
    const int i00 = i - i02*ne01*ne00 - i01*ne00;
    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;

    const int i12 = i / (ne10*ne11);
    const int i11 = (i - i12*ne10*ne11) / ne10;
    const int i10 = i - i12*ne10*ne11 - i11*ne10;
    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;

    *(dst_t *)(cdst + dst_offset) = *(const src_t *)(cx + x_offset);
}
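// Usage sketch (not part of the original source): a possible host-side wrapper for k_cpy that
// converts a contiguous ne00 x ne01 x ne02 f32 tensor to f16. The wrapper name, the block size
// of 256 and the contiguous byte strides are assumptions made for illustration only.
static void cpy_f32_f16_cuda_example(const char * cx, char * cdst, const int ne00, const int ne01, const int ne02, cudaStream_t stream) {
    const int ne = ne00*ne01*ne02;

    // byte strides of a contiguous f32 source and a contiguous f16 destination
    const int nb00 = sizeof(float); const int nb10 = sizeof(half);
    const int nb01 = ne00*nb00;     const int nb11 = ne00*nb10;
    const int nb02 = ne01*nb01;     const int nb12 = ne01*nb11;

    const int block_size = 256;
    const int num_blocks = (ne + block_size - 1) / block_size;

    k_cpy<float, half><<<num_blocks, block_size, 0, stream>>>(cx, cdst, ne,
        ne00, ne01, nb00, nb01, nb02,
        ne00, ne01, nb10, nb11, nb12);
}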
template<typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_add(const src0_t * x, const src1_t * y, dst_t * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }
    dst[i] = (dst_t)x[i] + (dst_t)y[i];
}

template<typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_mul(const src0_t * x, const src1_t * y, dst_t * dst, const int kx, const int ky) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= kx) {
        return;
    }
    dst[i] = (dst_t)x[i] * (dst_t)y[i%ky];
}

template<typename src0_t, typename dst_t>
static __global__ void k_silu(const src0_t * x, dst_t * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }
    dst[i] = x[i] / (src0_t(1) + exp(-x[i]));
}

// TODO: unstable with f16 compute, using f32 compute for now
template<typename src0_t, typename dst_t>
static __global__ void k_rms_norm(const src0_t * x, dst_t * dst, const int ncols) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    const float eps = 1e-6;

    float tmp = 0; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += WARP_SIZE) {
        const float xi = x[row*ncols + col];
        tmp += xi * xi;
    }

    // sum up partial sums
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    const float mean = tmp / (float)ncols;
    const float scale = 1.0f / sqrtf(mean + eps);

    for (int col = tid; col < ncols; col += WARP_SIZE) {
        dst[row*ncols + col] = scale * (float)x[row*ncols + col];
    }
}

template<typename src0_t, typename dst_t>
static __global__ void k_rope(const src0_t * x, dst_t * dst, const int ncols, const float p, const float theta_scale) {
    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);

    if (col >= ncols) {
        return;
    }

    const int row = blockDim.y*blockIdx.y + threadIdx.y;
    const int i = row*ncols + col;

    const dst_t theta = p * powf(theta_scale, col/2);
    const dst_t sin_theta = sin(theta);
    const dst_t cos_theta = cos(theta);

    const dst_t x0 = x[i + 0];
    const dst_t x1 = x[i + 1];

    dst[i + 0] = (dst_t)x0*cos_theta - (dst_t)x1*sin_theta;
    dst[i + 1] = (dst_t)x0*sin_theta + (dst_t)x1*cos_theta;
}

template<typename src0_t, typename dst_t>
static __global__ void k_diag_mask_inf(const src0_t * x, dst_t * dst, const int ncols, const int rows_per_channel, const int n_past) {
    const int col = blockDim.x*blockIdx.x + threadIdx.x;
    const int row = blockDim.y*blockIdx.y + threadIdx.y;

    if (col >= ncols) {
        return;
    }

    const int i = row*ncols + col;
    //dst[i] = col > (n_past + row % rows_per_channel) ? (dst_t)-INFINITY : (dst_t)x[i];
    dst[i] = (dst_t)x[i] - (dst_t)((col > n_past + row % rows_per_channel) * INT_MAX); // equivalent within rounding error but slightly faster on GPU
}
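// Usage sketch (not part of the original source): a possible host-side launcher for the
// element-wise kernels above, shown for k_add with an f16 src0, an f32 src1 and an f32 dst.
// The wrapper name and the block size of 256 are assumptions made for illustration only.
static void add_f16_f32_f32_cuda_example(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) {
    const int block_size = 256;
    const int num_blocks = (k + block_size - 1) / block_size;
    k_add<half, float, float><<<num_blocks, block_size, 0, stream>>>(x, y, dst, k);
}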
// TODO: numerically stable version - low prio since the softmax is computed in the fused attention kernel
// check: https://arxiv.org/pdf/2001.04438.pdf
template<typename src0_t, typename dst_t>
static __global__ void k_soft_max_orig(const src0_t * x, dst_t * dst, const int ncols) {
    const int row = blockDim.y*blockIdx.y + threadIdx.y;
    const int block_size = blockDim.x;
    const int tid = threadIdx.x;

    float tmp = 0;

    for (int block_start = 0; block_start < ncols; block_start += block_size) {
        const int col = block_start + tid;

        if (col >= ncols) {
            break;
        }

        const int i = row*ncols + col;
        const float val = expf(x[i]);
        tmp += val;
        dst[i] = val;
    }

    // sum up partial sums
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
    }

    for (int block_start = 0; block_start < ncols; block_start += block_size) {
        const int col = block_start + tid;

        if (col >= ncols) {
            break;
        }

        const int i = row*ncols + col;
        dst[i] /= tmp;
    }
}

template<typename src_t, typename dst_t, int pack_size, int block_size>
static __global__ void k_soft_max(const src_t * x, dst_t * dst, const int64_t nrows, const int64_t ncols) {
    //assert(ncols % pack_size == 0);

    const int tid = threadIdx.x;

    const int num_packs = ncols / pack_size;

    for (int row = blockIdx.x; row < nrows; row += gridDim.x) {
        // per-thread row max
        src_t th_max = -INFINITY;
#pragma unroll
        for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
            // load pack
            src_t pack[pack_size];
#pragma unroll
            for (int i = 0; i < pack_size; i++) {
                pack[i] = x[row * ncols + pack_id * pack_size + i];
            }
            // reduce max pack
#pragma unroll
            for (int i = 0; i < pack_size; ++i) {
                th_max = max(th_max, pack[i]);
            }
        }

        // reduce the row max across all threads in the block
        src_t row_max = block_reduce_all<op_max>(th_max, (src_t)-INFINITY);

        // per-thread row exp sum
        src_t th_sum = 0;
#pragma unroll
        for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
            // load pack
            src_t pack[pack_size];
#pragma unroll
            for (int i = 0; i < pack_size; i++) {
                pack[i] = x[row * ncols + pack_id * pack_size + i];
            }
            // reduce pack
#pragma unroll
            for (int i = 0; i < pack_size; ++i) {
                th_sum += exp(pack[i] - row_max);
            }
        }

        // reduce the row exp sum across all threads in the block
        src_t row_sum = block_reduce_all<op_sum>(th_sum);

        // store exp(x - row_max) / row_sum
#pragma unroll
        for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
            // load pack
            src_t pack[pack_size];
#pragma unroll
            for (int i = 0; i < pack_size; i++) {
                pack[i] = x[row * ncols + pack_id * pack_size + i];
            }
            // normalize pack
#pragma unroll
            for (int i = 0; i < pack_size; ++i) {
                pack[i] = exp(pack[i] - row_max) / row_sum;
            }
            // store pack
#pragma unroll
            for (int i = 0; i < pack_size; i++) {
                dst[row * ncols + pack_id * pack_size + i] = pack[i];
            }
        }
    }
}

template<typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_scale(const src0_t * x, dst_t * dst, const src1_t * scale, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = (dst_t)(*scale) * (dst_t)x[i];
}
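// Usage sketch (not part of the original source): a possible launcher for k_soft_max with f32
// compute. The pack size of 1, the block size of 128 and the grid cap of 1024 are assumptions
// made for illustration only; the block_size template argument must match the launch blockDim.x
// and be a multiple of the warp size so that block_reduce_all only sees full warps, and ncols
// must be divisible by pack_size.
static void soft_max_f32_cuda_example(const float * x, float * dst, const int64_t nrows, const int64_t ncols, cudaStream_t stream) {
    const int block_size = 128;
    const int num_blocks = (int) (nrows < 1024 ? nrows : 1024); // each block strides over rows by gridDim.x
    k_soft_max<float, float, 1, block_size><<<num_blocks, block_size, 0, stream>>>(x, dst, nrows, ncols);
}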
template<typename dst_t, int qk, int qr, dequantize_kernel_t<dst_t> dequantize_kernel>
static __global__ void k_get_rows(const void * x, const int * y, dst_t * dst, const int ncols) {
    const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
    const int row = blockDim.y*blockIdx.y + threadIdx.y;

    if (col >= ncols) {
        return;
    }

    const int r = y[row];

    // copy x[r*ncols + col] to dst[row*ncols + col]
    const int xi = r*ncols + col;
    const int di = row*ncols + col;

    const int ib = xi/qk;        // block index
    const int iqs = (xi%qk)/qr;  // quant index
    const int iybs = di - di%qk; // y block start index
    const int y_offset = qr == 1 ? 1 : qk/2;

    // dequantize
    vec2_t<dst_t> v;
    dequantize_kernel(x, ib, iqs, v);

    dst[iybs + iqs + 0]        = v.x;
    dst[iybs + iqs + y_offset] = v.y;
}
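// Usage sketch (not part of the original source): one way k_get_rows could be instantiated for
// an f16 source tensor, using convert_fp16 as the dequantize kernel with qk == 1 and qr == 1
// (every "block" holds a single value). The wrapper name and the launch geometry are assumptions
// made for illustration only; ncols is assumed to be even since each thread writes two values.
static void get_rows_f16_f32_cuda_example(const void * x, const int * rows, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
    const dim3 block_dims(32, 1, 1);
    const dim3 grid_dims((ncols/2 + 31)/32, nrows, 1);
    k_get_rows<float, 1, 1, convert_fp16<float>><<<grid_dims, block_dims, 0, stream>>>(x, rows, dst, ncols);
}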