static const int GGML_CUDA_MAX_SUBSTREAMS = 1; static const bool GGML_CUDA_SEQ_COMPUTE = true; #define WARP_SIZE 32 #define CUDA_ADD_BLOCK_SIZE 256 #define CUDA_MUL_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_CPY_BLOCK_SIZE 32 #define CUDA_SCALE_BLOCK_SIZE 256 #define CUDA_ROPE_BLOCK_SIZE 256 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 #define CUDA_GET_ROWS_BLOCK_SIZE 256 #define CUDA_QUANTIZE_BLOCK_SIZE 256 // dmmv = dequantize_mul_mat_vec #ifndef GGML_CUDA_DMMV_X #define GGML_CUDA_DMMV_X 32 #endif #ifndef GGML_CUDA_DMMV_Y #define GGML_CUDA_DMMV_Y 1 #endif #ifndef GGML_CUDA_MMV_Y #define GGML_CUDA_MMV_Y 1 #endif #ifndef K_QUANTS_PER_ITERATION #define K_QUANTS_PER_ITERATION 2 #else static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); #endif #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "ggml.h" #include "ggml-cuda.h" #include "ggml-cuda-kern.h" #include "ggml-cuda-quant.h" #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); #define CUDA_CHECK(err) \ do { \ cudaError_t err_ = (err); \ if (err_ != cudaSuccess) { \ fprintf(stderr, "CUDA error %d at %s (%s:%d): %s\n", err_, \ __func__, __FILE__, __LINE__, cudaGetErrorString(err_)); \ exit(1); \ } \ } while (0) #if CUDART_VERSION >= 12000 #define CUBLAS_CHECK(err) \ do { \ cublasStatus_t err_ = (err); \ if (err_ != CUBLAS_STATUS_SUCCESS) { \ fprintf(stderr, "\ncuBLAS error %d at %s (%s:%d): %s\n", err_, \ __func__, __FILE__, __LINE__, cublasGetStatusString(err_)); \ exit(1); \ } \ } while (0) #else #define CUBLAS_CHECK(err) \ do { \ cublasStatus_t err_ = (err); \ if (err_ != CUBLAS_STATUS_SUCCESS) { \ fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \ exit(1); \ } \ } while (0) #endif // CUDART_VERSION >= 12000 #define UNUSED(x) (void)(x) typedef void (*ggml_cuda_op_t)( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * src0_d, void * src1_d, void * dst_d, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t cudaStream_main); struct cuda_pool_buffer { void * ptr; size_t size; }; static std::unordered_map> g_cuda_stream_pools; static size_t g_cuda_pool_size = 0; static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size, cudaStream_t stream) { std::vector& pool = g_cuda_stream_pools[stream]; // find existing for (size_t i = 0; i < pool.size(); ++i) { cuda_pool_buffer& b = pool[i]; if (b.size >= size && b.ptr != nullptr) { void * ptr = b.ptr; *actual_size = b.size; pool.erase(pool.begin() + i); return ptr; } } // allocate new void * ptr; CUDA_CHECK(cudaMalloc(&ptr, size)); *actual_size = size; g_cuda_pool_size += size; //fprintf(stderr, "cuda pool size: %.2f MB (allocating now: %.2f MB)\n", g_cuda_pool_size / 1024.0 / 1024.0, size / 1024.0 / 1024.0); return ptr; } static void ggml_cuda_pool_free(void * ptr, size_t size, cudaStream_t stream) { std::vector& pool = g_cuda_stream_pools[stream]; pool.push_back({ ptr, size }); } static void ggml_cuda_pool_free_all() { for (auto& p : g_cuda_stream_pools) { for (auto& b : p.second) { if (b.ptr != nullptr) { CUDA_CHECK(cudaFree(b.ptr)); } } } g_cuda_stream_pools.clear(); } template static void quantize_row_q8_1_cuda(const src_t * x, void * vy, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; quantize_q8_1<<>>(x, vy, k); } template static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block><<>>(vx, y, k); } template static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block><<>>(vx, y, k); } template static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block><<>>(vx, y, k); } template static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block><<>>(vx, y, k); } template static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block><<>>(vx, y, k); } /* static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_q2_K<<>>(vx, y); } static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_q3_K<<>>(vx, y); } template static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_q4_K<<>>(vx, y); } static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_q5_K<<>>(vx, y); } */ template static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_q6_K<<>>(vx, y); } template static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); dequantize_mul_mat_vec> <<>>(vx, y, dst, ncols, nrows); } template static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); dequantize_mul_mat_vec> <<>>(vx, y, dst, ncols, nrows); } template static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); dequantize_mul_mat_vec> <<>>(vx, y, dst, ncols, nrows); } template static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); dequantize_mul_mat_vec> <<>>(vx, y, dst, ncols, nrows); } template static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); dequantize_mul_mat_vec> <<>>(vx, y, dst, ncols, nrows); } /* template static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int ny = 2; const int block_num_y = (nrows + ny - 1) / ny; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(32, ny, 1); dequantize_mul_mat_vec_q2_k<<>>(vx, y, dst, ncols, nrows); } template static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const dim3 block_dims(32, 1, 1); dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols); } template static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const dim3 block_dims(32, 1, 1); dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols); } template static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const dim3 block_dims(32, 1, 1); dequantize_mul_mat_vec_q5_k<<>>(vx, y, dst, ncols); } */ template static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(32, ny, 1); dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); } template static void convert_mul_mat_vec_f16_cuda(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); dequantize_mul_mat_vec><<>>(vx, y, dst, ncols, nrows); } template static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } template static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } template static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } template static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } template static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q <<>>(vx, vy, dst, ncols, nrows); } template static void convert_fp16_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block><<>>(vx, y, k); } template static to_t_cuda_t ggml_get_to_t_cuda(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; case GGML_TYPE_Q4_1: return dequantize_row_q4_1_cuda; case GGML_TYPE_Q5_0: return dequantize_row_q5_0_cuda; case GGML_TYPE_Q5_1: return dequantize_row_q5_1_cuda; case GGML_TYPE_Q8_0: return dequantize_row_q8_0_cuda; /* case GGML_TYPE_Q2_K: return dequantize_row_q2_K_cuda; case GGML_TYPE_Q3_K: return dequantize_row_q3_K_cuda; case GGML_TYPE_Q4_K: return dequantize_row_q4_K_cuda; case GGML_TYPE_Q5_K: return dequantize_row_q5_K_cuda; */ case GGML_TYPE_Q6_K: return dequantize_row_q6_K_cuda; case GGML_TYPE_F16: return convert_fp16_cuda; default: return nullptr; } } template static void ggml_mul_mat_p021_cuda(const src0_t * vx, const src1_t * y, dst_t * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) { const dim3 block_nums(1, nrows_x, nchannels_x); const dim3 block_dims(WARP_SIZE, 1, 1); k_mul_mat_p021<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x); } template static void ggml_mul_mat_vec_nc_cuda( const src0_t * vx, const src1_t * y, dst_t * dst, const int ncols_x, const int nrows_x, const int row_stride_x, const int nchannels_x, const int channel_stride_x, cudaStream_t stream) { const dim3 block_nums(1, nrows_x, nchannels_x); const dim3 block_dims(WARP_SIZE, 1, 1); k_mul_mat_vec_nc<<>> (vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x); } template static void ggml_cpy_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int nb00, const int nb01, const int nb02, const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; k_cpy<<>> (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); } template static void add_cuda(const src0_t * x, const src1_t * y, dst_t * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; k_add<<>>(x, y, dst, k); } template static void mul_cuda(const src0_t * x, const src1_t * y, dst_t * dst, const int kx, const int ky, cudaStream_t stream) { const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; k_mul<<>>(x, y, dst, kx, ky); } template static void silu_cuda(const src0_t * x, dst_t * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; k_silu<<>>(x, dst, k); } template static void rms_norm_cuda(const src0_t * x, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); const dim3 block_dims(WARP_SIZE, 1, 1); k_rms_norm<<>>(x, dst, ncols); } template static void scale_cuda(const src0_t * x, dst_t * dst, const src1_t * scale, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; k_scale<<>>(x, dst, scale, k); } template static void rope_cuda(const src0_t * x, dst_t * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) { GGML_ASSERT(nrows % 2 == 0); const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(num_blocks_x, nrows, 1); k_rope<<>>(x, dst, ncols, p, theta_scale); } template static void diag_mask_inf_cuda(const src0_t * x, dst_t * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) { const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1); const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE; const dim3 block_nums(block_num_x, nrows_x, 1); k_diag_mask_inf<<>>(x, dst, ncols_x, rows_per_channel, n_past); } template static void soft_max_cuda(const src0_t * x, dst_t * dst, const int ncols, const int nrows, cudaStream_t stream) { // TODO: implement fast numerically stable version for small ncols //if (ncols >= 1024) { int num_blocks = nrows; if (ncols % 2 == 0) { k_soft_max <<>>(x, dst, nrows, ncols); } else { k_soft_max <<>>(x, dst, nrows, ncols); } //} //else { // const dim3 block_dims(WARP_SIZE, 1, 1); // const dim3 block_nums(1, nrows, 1); // k_soft_max_orig<<>>(x, dst, ncols); //} } template dq> static void get_rows_cuda(const void * x, const int * y, dst_t * dst, const int nrows, const int ncols, cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num = (ncols/2 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; const dim3 block_nums(block_num, nrows, 1); k_get_rows<<>>(x, y, dst, ncols); } // TODO: move to context static cublasHandle_t g_cublas_handle = nullptr; static cudaStream_t g_cudaStream_main = nullptr; static cudaEvent_t g_cudaEvent_main = nullptr; static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_SUBSTREAMS] = { }; static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_SUBSTREAMS] = { }; #define GGML_CUDA_MAX_DEVICES 16 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES]; static void ggml_init_cublas() { static bool initialized = false; if (!initialized) { int device_count; CUDA_CHECK(cudaGetDeviceCount(&device_count)); int64_t total_vram = 0; fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, device_count); for (int id = 0; id < device_count; ++id) { cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); fprintf(stderr, " Device %d: %s (%.0f GB)\n", id, prop.name, prop.totalGlobalMem / 1024.0 / 1024.0 / 1024.0); total_vram += prop.totalGlobalMem; g_compute_capabilities[id] = 100*prop.major + 10*prop.minor; } // create main stream and event CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream_main, cudaStreamNonBlocking)); CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent_main, cudaEventDisableTiming)); // create secondary streams and events for (int i = 0; i < GGML_CUDA_MAX_SUBSTREAMS; ++i) { CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking)); CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming)); } // create cublas handle CUBLAS_CHECK(cublasCreate(&g_cublas_handle)); //CUBLAS_CHECK(cublasSetMathMode(g_cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); // configure logging to stdout //CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); initialized = true; } } void * ggml_cuda_host_malloc(size_t size) { if (getenv("GGML_CUDA_NO_PINNED") != nullptr) { return nullptr; } void * ptr = nullptr; cudaError_t err = cudaMallocHost((void **) &ptr, size); if (err != cudaSuccess) { // The allocation error can be bypassed. A null ptr will assigned out of this function. // This can fixed the OOM error in WSL. cudaGetLastError(); fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n", size/1024.0/1024.0, cudaGetErrorString(err)); return nullptr; } return ptr; } void ggml_cuda_host_free(void * ptr) { CUDA_CHECK(cudaFreeHost(ptr)); } void ggml_cuda_host_register(void * ptr, size_t size) { CUDA_CHECK(cudaHostRegister(ptr, size, 0)); } void ggml_cuda_host_unregister(void * ptr) { CUDA_CHECK(cudaHostUnregister(ptr)); } template static void ggml_cuda_op_add( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * src0_d, void * src1_d, void * dst_d, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t stream) { const int64_t ne0 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; // compute add_cuda((src0_t *)src0_d, (src1_t *) src1_d, (dst_t *) dst_d, ne0*i01_diff, stream); CUDA_CHECK(cudaGetLastError()); UNUSED(src1); UNUSED(dst); UNUSED(i02); UNUSED(i1); } template static void ggml_cuda_op_mul( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * src0_d, void * src1_d, void * dst_d, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; const int64_t ne10 = src1->ne[0]; const int64_t ne11 = src1->ne[1]; for (int64_t i01 = i01_low; i01 < i01_high; i01++) { const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0 src0_t * src0_d_i01 = (src0_t *) src0_d + i01*ne00; src1_t * src1_d_i01 = (src1_t *) src1_d + i11*ne10; dst_t * dst_d_i01 = (dst_t *) dst_d + i01*ne00; // compute mul_cuda(src0_d_i01, src1_d_i01, dst_d_i01, ne00, ne10, stream); CUDA_CHECK(cudaGetLastError()); } UNUSED(dst); UNUSED(i02); } template static void ggml_cuda_op_silu( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * src0_d, void * src1_d, void * dst_d, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; // compute silu_cuda((src0_t *)src0_d, (dst_t *)dst_d, ne00*i01_diff, stream); CUDA_CHECK(cudaGetLastError()); UNUSED(src1); UNUSED(src1_d); UNUSED(dst); UNUSED(i02); UNUSED(i1); } template static void ggml_cuda_op_rms_norm( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * src0_d, void * src1_d, void * dst_d, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; // compute rms_norm_cuda((src0_t *)src0_d, (dst_t *)dst_d, ne00, i01_diff, stream); CUDA_CHECK(cudaGetLastError()); UNUSED(src1); UNUSED(src1_d); UNUSED(dst); UNUSED(i02); UNUSED(i1); } template static void ggml_cuda_op_dequantize_mul_mat_vec( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * src0_d, void * src1_d, void * dst_d, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; const int64_t nrows = i01_high - i01_low; #ifdef GGML_CUDA_FORCE_DMMV const bool use_mul_mat_vec_q = false; #else int id; CUDA_CHECK(cudaGetDevice(&id)); const bool mul_mat_vec_q_implemented = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || src0->type == GGML_TYPE_Q8_0; // The integer intrinsics used in mul_mat_vec_q are available with compute capability 6. // However, they have bad performance with Pascal cards. // Therefore, in a multi GPU setting decide at runtime which GPUs should use mul_mat_vec_q. const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented; #endif if (use_mul_mat_vec_q) { size_t as; void * src1_q8_1 = ggml_cuda_pool_malloc(ne00*sizeof(block_q8_1)/QK8_1, &as, stream); quantize_row_q8_1_cuda((src1_t *)src1_d, src1_q8_1, ne00, stream); switch (src0->type) { case GGML_TYPE_Q4_0: mul_mat_vec_q4_0_q8_1_cuda(src0_d, src1_q8_1, (dst_t *)dst_d, ne00, nrows, stream); break; case GGML_TYPE_Q4_1: mul_mat_vec_q4_1_q8_1_cuda(src0_d, src1_q8_1, (dst_t *)dst_d, ne00, nrows, stream); break; case GGML_TYPE_Q5_0: mul_mat_vec_q5_0_q8_1_cuda(src0_d, src1_q8_1, (dst_t *)dst_d, ne00, nrows, stream); break; case GGML_TYPE_Q5_1: mul_mat_vec_q5_1_q8_1_cuda(src0_d, src1_q8_1, (dst_t *)dst_d, ne00, nrows, stream); break; case GGML_TYPE_Q8_0: mul_mat_vec_q8_0_q8_1_cuda(src0_d, src1_q8_1, (dst_t *)dst_d, ne00, nrows, stream); break; default: GGML_ASSERT(false); break; } ggml_cuda_pool_free(src1_q8_1, as, stream); } else { switch (src0->type) { case GGML_TYPE_Q4_0: dequantize_mul_mat_vec_q4_0_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, stream); break; case GGML_TYPE_Q4_1: dequantize_mul_mat_vec_q4_1_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, stream); break; case GGML_TYPE_Q5_0: dequantize_mul_mat_vec_q5_0_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, stream); break; case GGML_TYPE_Q5_1: dequantize_mul_mat_vec_q5_1_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, stream); break; case GGML_TYPE_Q8_0: dequantize_mul_mat_vec_q8_0_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, stream); break; /* case GGML_TYPE_Q2_K: dequantize_mul_mat_vec_q2_K_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, cudaStream_main); break; case GGML_TYPE_Q3_K: dequantize_mul_mat_vec_q3_K_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, cudaStream_main); break; case GGML_TYPE_Q4_K: dequantize_mul_mat_vec_q4_K_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, cudaStream_main); break; case GGML_TYPE_Q5_K: dequantize_mul_mat_vec_q5_K_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, cudaStream_main); break; */ case GGML_TYPE_Q6_K: dequantize_mul_mat_vec_q6_K_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, stream); break; case GGML_TYPE_F16: convert_mul_mat_vec_f16_cuda(src0_d, (src1_t *)src1_d, (dst_t *)dst_d, ne00, nrows, stream); break; default: GGML_ASSERT(false); break; } } CUDA_CHECK(cudaGetLastError()); UNUSED(src1); UNUSED(dst); UNUSED(i02); UNUSED(i1); } template static void ggml_cuda_op_rope( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * src0_d, void * src1_d, void * dst_d, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; //const int n_ctx = ((int32_t *) dst->params)[3]; GGML_ASSERT(mode == 0); const float theta_scale = powf(10000.0, -2.0f/n_dims); const float p = ((mode & 1) == 0 ? n_past + i02 : i02); // compute rope_cuda((src0_t *)src0_d, (dst_t *)dst_d, ne00, i01_diff, p, theta_scale, stream); CUDA_CHECK(cudaGetLastError()); UNUSED(dst); UNUSED(src1); UNUSED(src1_d); UNUSED(i1); } template static void ggml_cuda_op_diag_mask_inf( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * src0_d, void * src1_d, void * dst_d, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; const int64_t i01_diff = i01_high - i01_low; const int n_past = ((int32_t *) dst->op_params)[0]; // compute diag_mask_inf_cuda((src0_t *)src0_d, (dst_t *)dst_d, ne00, i01_diff, ne01, n_past, stream); CUDA_CHECK(cudaGetLastError()); UNUSED(dst); UNUSED(src1); UNUSED(src1_d); UNUSED(i02); UNUSED(i1); } template static void ggml_cuda_op_soft_max( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * src0_d, void * src1_d, void * dst_d, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t stream) { const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; // compute soft_max_cuda((src0_t *)src0_d, (dst_t *)dst_d, ne00, i01_diff, stream); CUDA_CHECK(cudaGetLastError()); UNUSED(src1); UNUSED(src1_d); UNUSED(dst); UNUSED(i02); UNUSED(i1); } template static void ggml_cuda_op_scale( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * src0_d, void * src1_d, void * dst_d, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t stream) { //const src1_t scale = ((src1_t *) src1->data)[0]; const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; // compute scale_cuda((src0_t *)src0_d, (dst_t *)dst_d, (src1_t *)src1_d, ne00*i01_diff, stream); CUDA_CHECK(cudaGetLastError()); UNUSED(src1); UNUSED(src1_d); UNUSED(dst); UNUSED(i02); UNUSED(i1); } template static void ggml_cuda_op_get_rows( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * src0_d, void * src1_d, void * dst_d, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t stream) { GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src1)); GGML_ASSERT(ggml_is_contiguous(dst)); const int ncols = src0->ne[0]; const int nrows = ggml_nelements(src1); switch (src0->type) { case GGML_TYPE_F16: get_rows_cuda>(src0_d, (int *) src1_d, (dst_t *)dst_d, nrows, ncols, stream); break; case GGML_TYPE_F32: get_rows_cuda>(src0_d, (int *) src1_d, (dst_t *)dst_d, nrows, ncols, stream); break; case GGML_TYPE_Q4_0: get_rows_cuda>(src0_d, (int *) src1_d, (dst_t *)dst_d, nrows, ncols, stream); break; case GGML_TYPE_Q4_1: get_rows_cuda>(src0_d, (int *) src1_d, (dst_t *)dst_d, nrows, ncols, stream); break; case GGML_TYPE_Q5_0: get_rows_cuda>(src0_d, (int *) src1_d, (dst_t *)dst_d, nrows, ncols, stream); break; case GGML_TYPE_Q5_1: get_rows_cuda>(src0_d, (int *) src1_d, (dst_t *)dst_d, nrows, ncols, stream); break; case GGML_TYPE_Q8_0: get_rows_cuda>(src0_d, (int *) src1_d, (dst_t *)dst_d, nrows, ncols, stream); break; default: GGML_ASSERT(false); break; } CUDA_CHECK(cudaGetLastError()); UNUSED(i02); UNUSED(i01_low); UNUSED(i01_high); UNUSED(i1); } //////////////////////////////////////////////////////////////////////////////////////////////////// struct ggml_cuda_buffer { const char * name; void * data; size_t size; void * device; }; struct ggml_cuda_context { std::vector buffers; }; ggml_cuda_context * ggml_cuda_init() { ggml_init_cublas(); ggml_cuda_context * ctx = new ggml_cuda_context; return ctx; } void ggml_cuda_free(ggml_cuda_context * ctx) { for (size_t n = 0; n < ctx->buffers.size(); ++n) { if (ctx->buffers[n].device != nullptr) { CUDA_CHECK(cudaFree(ctx->buffers[n].device)); } } // this will free the global memory pool for all contexts ggml_cuda_pool_free_all(); delete ctx; } static void * ggml_cuda_get_buffer(ggml_cuda_context * ctx, ggml_tensor * t) { return t->data; UNUSED(ctx); } static cudaError_t ggml_cuda_cpy_tensor_2d(ggml_cuda_context * ctx, void * dst, ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { cudaMemcpyKind kind = cudaMemcpyDeviceToDevice; const char * src_ptr = (const char *) ggml_cuda_get_buffer(ctx, src); char * dst_ptr = (char *) dst; const int64_t ne0 = src->ne[0]; const int64_t nb0 = src->nb[0]; const int64_t nb1 = src->nb[1]; const int64_t nb2 = src->nb[2]; const int64_t nb3 = src->nb[3]; const enum ggml_type type = src->type; const int64_t ts = ggml_type_size(type); const int64_t bs = ggml_blck_size(type); int64_t i1_diff = i1_high - i1_low; GGML_ASSERT(i1_low == 0); const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; if (nb0 == ts && nb1 == ts*ne0/bs) { return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream); } else if (nb0 == ts) { return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream); } else { for (int64_t i1 = 0; i1 < i1_diff; i1++) { const void * rx = (const void *) ((const char *) x + i1*nb1); void * rd = (void *) (dst_ptr + i1*ts*ne0/bs); // pretend the row is a matrix with cols=1 cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream); if (r != cudaSuccess) return r; } return cudaSuccess; } } static const ggml_type GGML_TYPE_NONE = GGML_TYPE_COUNT; struct ggml_cuda_op_dispatch_t { ggml_cuda_op_t d[GGML_TYPE_COUNT][GGML_TYPE_COUNT+1][GGML_TYPE_COUNT] = { nullptr }; }; template