Fixed OpenLLaMA 3b CUDA mul_mat_vec_q (#2144)

This commit is contained in:
Johannes Gäßler 2023-07-08 20:01:44 +02:00 committed by GitHub
parent 061f5f8d21
commit 64639555ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -208,6 +208,7 @@ typedef struct {
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding"); static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
#define WARP_SIZE 32 #define WARP_SIZE 32
#define MATRIX_ROW_PADDING 256 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
#define CUDA_ADD_BLOCK_SIZE 256 #define CUDA_ADD_BLOCK_SIZE 256
#define CUDA_MUL_BLOCK_SIZE 256 #define CUDA_MUL_BLOCK_SIZE 256
@ -1171,7 +1172,7 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
v.y = x[ib + iqs + 1]; v.y = x[ib + iqs + 1];
} }
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int k) { static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int ndata, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x; const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) { if (i >= k) {
@ -1180,10 +1181,10 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
block_q8_1 * y = (block_q8_1 *) vy; block_q8_1 * y = (block_q8_1 *) vy;
const int ib = i / QK8_0; // block index const int ib = i / QK8_1; // block index
const int iqs = i % QK8_0; // quant index const int iqs = i % QK8_1; // quant index
const float xi = x[i]; const float xi = i < ndata ? x[i] : 0.0f;
float amax = fabsf(xi); float amax = fabsf(xi);
float sum = xi; float sum = xi;
@ -1714,9 +1715,9 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols); rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
} }
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int k, cudaStream_t stream) { static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, k); quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, ndata, k);
} }
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@ -2359,9 +2360,11 @@ inline void ggml_cuda_op_mul_mat_vec(
#endif #endif
if (use_mul_mat_vec_q) { if (use_mul_mat_vec_q) {
int64_t padded_row_size = ne00 + MATRIX_ROW_PADDING - 1;
padded_row_size -= padded_row_size % MATRIX_ROW_PADDING;
size_t as; size_t as;
void * src1_q8_1 = ggml_cuda_pool_malloc(ne00*sizeof(block_q8_1)/QK8_1, &as); void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, cudaStream_main); quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main);
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
@ -3105,7 +3108,11 @@ void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
int nrows = ggml_nrows(tensor); int nrows = ggml_nrows(tensor);
const int64_t ne0 = tensor->ne[0];
const size_t nb1 = tensor->nb[1]; const size_t nb1 = tensor->nb[1];
ggml_backend backend = tensor->backend; ggml_backend backend = tensor->backend;
struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu; struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
memset(extra, 0, sizeof(*extra)); memset(extra, 0, sizeof(*extra));
@ -3134,11 +3141,24 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
int64_t nrows_split = row_high - row_low; int64_t nrows_split = row_high - row_low;
const size_t offset_split = row_low*nb1; const size_t offset_split = row_low*nb1;
const size_t size = ggml_nbytes_split(tensor, nrows_split); size_t size = ggml_nbytes_split(tensor, nrows_split);
const size_t original_size = size;
void * buf; // pad last row to a multiple of 256 elements to avoid out-of-bounds memory accesses
if (ne0 % MATRIX_ROW_PADDING != 0) {
size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
* ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
}
char * buf;
CUDA_CHECK(cudaMalloc(&buf, size)); CUDA_CHECK(cudaMalloc(&buf, size));
void * buf_host = (char*)data + offset_split; char * buf_host = (char*)data + offset_split;
// set padding to 0 to avoid possible NaN values
if (size > original_size) {
CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));
}
cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice); cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);