Performance no longer terrible

This commit is contained in:
JohannesGaessler 2023-05-11 23:27:06 +02:00
parent 4b12881329
commit d882d1c2fe

View File

@ -488,6 +488,34 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
}
}
static cudaError_t ggml_cuda_h2d_tensor_2d_hack(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream, void * wdata) {
const uint64_t ne0 = src->ne[0];
const uint64_t ne1 = src->ne[1];
const uint64_t nb0 = src->nb[0];
const uint64_t nb1 = src->nb[1];
const uint64_t nb2 = src->nb[2];
const uint64_t nb3 = src->nb[3];
const enum ggml_type type = src->type;
const size_t ts = ggml_type_size(type);
const size_t bs = ggml_blck_size(type);
const void * x = (const void *) ((const char *) wdata + i2*nb2 + i3*nb3);
if (nb0 == ts && nb1 == ts*ne0/bs) {
return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
} else if (nb0 == ts) {
return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
} else {
for (uint64_t i1 = 0; i1 < ne1; i1++) {
const void * rx = (const void *) ((const char *) x + i1*nb1);
void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
// pretend the row is a matrix with cols=1
cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
if (r != cudaSuccess) return r;
}
return cudaSuccess;
}
}
static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
@ -695,13 +723,13 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
// copy src1 to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
CUDA_CHECK(ggml_cuda_h2d_tensor_2d_hack(c_Y, src1, i03, i02, cudaStream, wdata));
// wait for data
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
// compute
dequantize_mul_mat_q4_0_cuda(c_Q, wdata + i * QK8_0, c_D, ne00, ne01, cudaStream);
dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
CUDA_CHECK(cudaGetLastError());
} else {