mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
opencl : fix element-wise multiplication (#3656)
This commit is contained in:
parent
cb33f43a2a
commit
1117d06607
@ -1395,75 +1395,46 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
const int64_t ne11 = src1->ne[1];
|
||||
const int64_t ne12 = src1->ne[2];
|
||||
const int64_t ne13 = src1->ne[3];
|
||||
const int64_t nb10 = src1->nb[0];
|
||||
const int nb2 = dst->nb[2];
|
||||
const int nb3 = dst->nb[3];
|
||||
size_t x_size;
|
||||
size_t d_size;
|
||||
|
||||
cl_mem d_X = ggml_cl_pool_malloc(ne0 * sizeof(float), &x_size); // src0
|
||||
cl_mem d_X = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &x_size); // src0
|
||||
cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted.
|
||||
cl_mem d_D = ggml_cl_pool_malloc(ne0 * sizeof(float), &d_size); // dst
|
||||
cl_mem d_D = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &d_size); // dst
|
||||
|
||||
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
const int i0 = i03*ne02 + i02;
|
||||
|
||||
cl_event ev;
|
||||
|
||||
// copy src0 to device
|
||||
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, i0, src0, i03, i02, &ev));
|
||||
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, &ev));
|
||||
|
||||
if (nb10 == sizeof(float)) {
|
||||
// Contiguous, avoid overhead from queueing many kernel runs
|
||||
const int64_t i13 = i03%ne13;
|
||||
const int64_t i12 = i02%ne12;
|
||||
const int i1 = i13*ne12*ne11 + i12*ne11;
|
||||
const int64_t i13 = i03%ne13;
|
||||
const int64_t i12 = i02%ne12;
|
||||
const int i1 = i13*ne12*ne11 + i12*ne11;
|
||||
|
||||
cl_int x_offset = 0;
|
||||
cl_int y_offset = i1*ne10;
|
||||
cl_int d_offset = 0;
|
||||
cl_int x_offset = 0;
|
||||
cl_int y_offset = i1*ne10;
|
||||
cl_int d_offset = 0;
|
||||
|
||||
size_t global = ne00 * ne01;
|
||||
cl_int ky = ne10;
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
|
||||
} else {
|
||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||
const int64_t i13 = i03%ne13;
|
||||
const int64_t i12 = i02%ne12;
|
||||
const int64_t i11 = i01%ne11;
|
||||
const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
|
||||
size_t global = ne00 * ne01;
|
||||
cl_int ky = ne10 * ne11;
|
||||
|
||||
cl_int x_offset = i01*ne00;
|
||||
cl_int y_offset = i1*ne10;
|
||||
cl_int d_offset = i01*ne00;
|
||||
|
||||
// compute
|
||||
size_t global = ne00;
|
||||
cl_int ky = ne10;
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
|
||||
}
|
||||
}
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
|
||||
CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
|
||||
|
||||
CL_CHECK(clReleaseEvent(ev));
|
||||
CL_CHECK(clFinish(queue));
|
||||
|
Loading…
Reference in New Issue
Block a user