mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 21:39:52 +00:00
cuda : support broadcast add & mul (#2192)
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
4304bd3cde
commit
206e01de11
33
ggml-cuda.cu
33
ggml-cuda.cu
@ -252,13 +252,13 @@ struct ggml_tensor_extra_gpu {
|
|||||||
cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
|
cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
|
||||||
};
|
};
|
||||||
|
|
||||||
static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
|
static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
|
||||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
if (i >= k) {
|
if (i >= kx) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dst[i] = x[i] + y[i];
|
dst[i] = x[i] + y[i%ky];
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
|
static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
|
||||||
@ -1996,9 +1996,9 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
|
|||||||
dst[i] = scale * x[i];
|
dst[i] = scale * x[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
|
static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
|
||||||
const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
|
const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
|
||||||
add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
|
add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
|
static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
|
||||||
@ -2610,17 +2610,15 @@ inline void ggml_cuda_op_add(
|
|||||||
GGML_ASSERT(src1_ddf_i != nullptr);
|
GGML_ASSERT(src1_ddf_i != nullptr);
|
||||||
GGML_ASSERT(dst_ddf_i != nullptr);
|
GGML_ASSERT(dst_ddf_i != nullptr);
|
||||||
|
|
||||||
// TODO: support broadcasting
|
|
||||||
GGML_ASSERT(ggml_nelements(src0) == ggml_nelements(src1));
|
|
||||||
|
|
||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
const int64_t i01_diff = i01_high - i01_low;
|
const int64_t i01_diff = i01_high - i01_low;
|
||||||
|
|
||||||
// const int64_t ne10 = src1->ne[0];
|
const int64_t ne10 = src1->ne[0];
|
||||||
|
const int64_t ne11 = src1->ne[1];
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||||
add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
|
add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
|
||||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||||
add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
|
add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
|
||||||
} else {
|
} else {
|
||||||
@ -2644,19 +2642,12 @@ inline void ggml_cuda_op_mul(
|
|||||||
GGML_ASSERT(dst_ddf_i != nullptr);
|
GGML_ASSERT(dst_ddf_i != nullptr);
|
||||||
|
|
||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
|
const int64_t i01_diff = i01_high - i01_low;
|
||||||
|
|
||||||
const int64_t ne10 = src1->ne[0];
|
const int64_t ne10 = src1->ne[0];
|
||||||
const int64_t ne11 = src1->ne[1];
|
const int64_t ne11 = src1->ne[1];
|
||||||
|
|
||||||
for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
|
mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
|
||||||
const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
|
|
||||||
|
|
||||||
float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
|
|
||||||
float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
|
|
||||||
float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
|
|
||||||
|
|
||||||
// compute
|
|
||||||
mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
|
|
||||||
}
|
|
||||||
|
|
||||||
(void) dst;
|
(void) dst;
|
||||||
(void) src0_ddq_i;
|
(void) src0_ddq_i;
|
||||||
|
Loading…
Reference in New Issue
Block a user