mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 19:04:35 +00:00
cuBLAS: refactor and optimize f16 mat mul performance (#1259)
* cuBLAS: refactor, convert fp16 to fp32 on device * cuBLAS: use multiple streams, choose smartly between mul_mat_q and mul_mat_f16 * fix build * cuBLAS: update block_q5_1
This commit is contained in:
parent
ea3a0ad6b6
commit
58b367c2d7
429
ggml-cuda.cu
429
ggml-cuda.cu
@ -1,11 +1,38 @@
|
|||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <cuda_fp16.h>
|
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include "ggml-cuda.h"
|
|
||||||
|
|
||||||
typedef uint16_t ggml_fp16_t;
|
#include <cuda_runtime.h>
|
||||||
static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
#include <cublas_v2.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
#include "ggml-cuda.h"
|
||||||
|
#include "ggml.h"
|
||||||
|
|
||||||
|
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||||
|
|
||||||
|
#define CUDA_CHECK(err) \
|
||||||
|
do { \
|
||||||
|
cudaError_t err_ = (err); \
|
||||||
|
if (err_ != cudaSuccess) { \
|
||||||
|
fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
||||||
|
cudaGetErrorString(err_)); \
|
||||||
|
exit(1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
#define CUBLAS_CHECK(err) \
|
||||||
|
do { \
|
||||||
|
cublasStatus_t err_ = (err); \
|
||||||
|
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
||||||
|
fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
|
||||||
|
exit(1); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
|
||||||
|
|
||||||
#define QK4_0 32
|
#define QK4_0 32
|
||||||
typedef struct {
|
typedef struct {
|
||||||
@ -24,14 +51,14 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 b
|
|||||||
|
|
||||||
#define QK4_2 16
|
#define QK4_2 16
|
||||||
typedef struct {
|
typedef struct {
|
||||||
__half d; // delta
|
half d; // delta
|
||||||
uint8_t qs[QK4_2 / 2]; // nibbles / quants
|
uint8_t qs[QK4_2 / 2]; // nibbles / quants
|
||||||
} block_q4_2;
|
} block_q4_2;
|
||||||
static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
|
static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
|
||||||
|
|
||||||
#define QK5_0 32
|
#define QK5_0 32
|
||||||
typedef struct {
|
typedef struct {
|
||||||
__half d; // delta
|
half d; // delta
|
||||||
uint8_t qh[4]; // 5-th bit of quants
|
uint8_t qh[4]; // 5-th bit of quants
|
||||||
uint8_t qs[QK5_0 / 2]; // nibbles / quants
|
uint8_t qs[QK5_0 / 2]; // nibbles / quants
|
||||||
} block_q5_0;
|
} block_q5_0;
|
||||||
@ -39,9 +66,9 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5
|
|||||||
|
|
||||||
#define QK5_1 32
|
#define QK5_1 32
|
||||||
typedef struct {
|
typedef struct {
|
||||||
__half d; // delta
|
half d; // delta
|
||||||
__half m; // min
|
half m; // min
|
||||||
uint32_t qh; // 5-th bit of quants
|
uint8_t qh[4]; // 5-th bit of quants
|
||||||
uint8_t qs[QK5_1 / 2]; // nibbles / quants
|
uint8_t qs[QK5_1 / 2]; // nibbles / quants
|
||||||
} block_q5_1;
|
} block_q5_1;
|
||||||
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
|
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
|
||||||
@ -162,7 +189,8 @@ static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
|
|||||||
|
|
||||||
const uint8_t * pp = x[i].qs;
|
const uint8_t * pp = x[i].qs;
|
||||||
|
|
||||||
const uint32_t qh = x[i].qh;
|
uint32_t qh;
|
||||||
|
memcpy(&qh, x[i].qh, sizeof(qh));
|
||||||
|
|
||||||
for (int l = 0; l < QK5_1; l += 2) {
|
for (int l = 0; l < QK5_1; l += 2) {
|
||||||
const uint8_t vi = pp[l/2];
|
const uint8_t vi = pp[l/2];
|
||||||
@ -197,37 +225,50 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||||
const int nb = k / QK4_0;
|
const int nb = k / QK4_0;
|
||||||
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
|
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||||
const int nb = k / QK4_1;
|
const int nb = k / QK4_1;
|
||||||
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
|
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
static void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||||
const int nb = k / QK4_2;
|
const int nb = k / QK4_2;
|
||||||
dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
|
dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||||
const int nb = k / QK5_0;
|
const int nb = k / QK5_0;
|
||||||
dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
|
dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||||
const int nb = k / QK5_1;
|
const int nb = k / QK5_1;
|
||||||
dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
|
dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||||
const int nb = k / QK8_0;
|
const int nb = k / QK8_0;
|
||||||
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
|
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
|
// TODO: optimize
|
||||||
|
static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
|
||||||
|
const half * x = (const half *) vx;
|
||||||
|
|
||||||
|
const int i = blockIdx.x;
|
||||||
|
|
||||||
|
y[i] = __half2float(x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) {
|
||||||
|
convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
return dequantize_row_q4_0_cuda;
|
return dequantize_row_q4_0_cuda;
|
||||||
@ -241,6 +282,8 @@ dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(ggml_type type) {
|
|||||||
return dequantize_row_q5_1_cuda;
|
return dequantize_row_q5_1_cuda;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
return dequantize_row_q8_0_cuda;
|
return dequantize_row_q8_0_cuda;
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
return convert_fp16_to_fp32_cuda;
|
||||||
default:
|
default:
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@ -271,7 +314,7 @@ struct cuda_buffer {
|
|||||||
static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
|
static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
|
||||||
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
|
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
|
||||||
|
|
||||||
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
||||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||||
|
|
||||||
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
||||||
@ -290,7 +333,7 @@ void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
|||||||
return ptr;
|
return ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_pool_free(void * ptr, size_t size) {
|
static void ggml_cuda_pool_free(void * ptr, size_t size) {
|
||||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||||
|
|
||||||
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
||||||
@ -305,28 +348,55 @@ void ggml_cuda_pool_free(void * ptr, size_t size) {
|
|||||||
CUDA_CHECK(cudaFree(ptr));
|
CUDA_CHECK(cudaFree(ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
cublasHandle_t g_cublasH = nullptr;
|
#define GGML_CUDA_MAX_STREAMS 8
|
||||||
cudaStream_t g_cudaStream = nullptr;
|
#define GGML_CUDA_MAX_EVENTS 64
|
||||||
cudaStream_t g_cudaStream2 = nullptr;
|
static cublasHandle_t g_cublasH = nullptr;
|
||||||
cudaEvent_t g_cudaEvent = nullptr;
|
static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = { nullptr };
|
||||||
|
static cudaStream_t g_cudaStreams2[GGML_CUDA_MAX_STREAMS] = { nullptr };
|
||||||
|
static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_EVENTS] = { nullptr };
|
||||||
|
|
||||||
void ggml_init_cublas() {
|
void ggml_init_cublas() {
|
||||||
if (g_cublasH == nullptr) {
|
if (g_cublasH == nullptr) {
|
||||||
// create cublas handle, bind a stream
|
// create streams
|
||||||
CUBLAS_CHECK(cublasCreate(&g_cublasH));
|
for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
|
||||||
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
|
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking));
|
||||||
CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
|
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[i], cudaStreamNonBlocking));
|
||||||
|
}
|
||||||
|
// create events
|
||||||
|
for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
|
||||||
|
CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming));
|
||||||
|
}
|
||||||
|
|
||||||
// create additional stream and event for synchronization
|
// create cublas handle
|
||||||
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking));
|
CUBLAS_CHECK(cublasCreate(&g_cublasH));
|
||||||
CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming));
|
CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH));
|
||||||
|
|
||||||
// configure logging to stdout
|
// configure logging to stdout
|
||||||
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
|
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
|
void * ggml_cuda_host_malloc(size_t size) {
|
||||||
|
if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void * ptr = nullptr;
|
||||||
|
cudaError_t err = cudaMallocHost((void **) &ptr, size);
|
||||||
|
if (err != cudaSuccess) {
|
||||||
|
fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
|
||||||
|
size/1024.0/1024.0, cudaGetErrorString(err));
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_host_free(void * ptr) {
|
||||||
|
CUDA_CHECK(cudaFreeHost(ptr));
|
||||||
|
}
|
||||||
|
|
||||||
|
static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
|
||||||
const uint64_t ne0 = src->ne[0];
|
const uint64_t ne0 = src->ne[0];
|
||||||
const uint64_t ne1 = src->ne[1];
|
const uint64_t ne1 = src->ne[1];
|
||||||
const uint64_t nb0 = src->nb[0];
|
const uint64_t nb0 = src->nb[0];
|
||||||
@ -354,22 +424,293 @@ cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void * ggml_cuda_host_malloc(size_t size) {
|
static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
|
const int64_t ne00 = src0->ne[0];
|
||||||
return nullptr;
|
const int64_t ne01 = src0->ne[1];
|
||||||
|
const int64_t ne02 = src0->ne[2];
|
||||||
|
const int64_t ne03 = src0->ne[3];
|
||||||
|
|
||||||
|
const int64_t ne10 = src1->ne[0];
|
||||||
|
const int64_t ne11 = src1->ne[1];
|
||||||
|
|
||||||
|
const int nb2 = dst->nb[2];
|
||||||
|
const int nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
const float alpha = 1.0f;
|
||||||
|
const float beta = 0.0f;
|
||||||
|
const int x_ne = ne01 * ne00;
|
||||||
|
const int y_ne = ne11 * ne10;
|
||||||
|
const int d_ne = ne11 * ne01;
|
||||||
|
const int n_mm = ne03 * ne02;
|
||||||
|
|
||||||
|
size_t x_size, y_size, d_size;
|
||||||
|
float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
|
||||||
|
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
|
||||||
|
float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
|
||||||
|
|
||||||
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
|
int i = i03*ne02 + i02;
|
||||||
|
cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
|
||||||
|
|
||||||
|
float * c_X = d_X + i * x_ne;
|
||||||
|
float * c_Y = d_Y + i * y_ne;
|
||||||
|
float * c_D = d_D + i * d_ne;
|
||||||
|
|
||||||
|
// copy data to device
|
||||||
|
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
|
||||||
|
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
|
||||||
|
|
||||||
|
// compute
|
||||||
|
CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
|
||||||
|
CUBLAS_CHECK(
|
||||||
|
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
|
ne01, ne11, ne10,
|
||||||
|
&alpha, c_X, ne00,
|
||||||
|
c_Y, ne10,
|
||||||
|
&beta, c_D, ne01));
|
||||||
|
|
||||||
|
// copy dst to host
|
||||||
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void * ptr = nullptr;
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
cudaError_t err = cudaMallocHost((void **) &ptr, size);
|
ggml_cuda_pool_free(d_X, x_size);
|
||||||
if (err != cudaSuccess) {
|
ggml_cuda_pool_free(d_Y, y_size);
|
||||||
fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
|
ggml_cuda_pool_free(d_D, d_size);
|
||||||
size/1024.0/1024.0, cudaGetErrorString(err));
|
}
|
||||||
return nullptr;
|
|
||||||
|
static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
|
||||||
|
const int64_t ne00 = src0->ne[0];
|
||||||
|
const int64_t ne01 = src0->ne[1];
|
||||||
|
const int64_t ne02 = src0->ne[2];
|
||||||
|
const int64_t ne03 = src0->ne[3];
|
||||||
|
|
||||||
|
const int64_t ne10 = src1->ne[0];
|
||||||
|
const int64_t ne11 = src1->ne[1];
|
||||||
|
|
||||||
|
const int nb10 = src1->nb[0];
|
||||||
|
const int nb11 = src1->nb[1];
|
||||||
|
const int nb12 = src1->nb[2];
|
||||||
|
const int nb13 = src1->nb[3];
|
||||||
|
|
||||||
|
const int nb2 = dst->nb[2];
|
||||||
|
const int nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
const float alpha = 1.0f;
|
||||||
|
const float beta = 0.0f;
|
||||||
|
const int x_ne = ne01 * ne00;
|
||||||
|
const int y_ne = ne11 * ne10;
|
||||||
|
const int d_ne = ne11 * ne01;
|
||||||
|
const int n_mm = ne03 * ne02;
|
||||||
|
|
||||||
|
size_t x_size, y_size, d_size;
|
||||||
|
half * d_X = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
|
||||||
|
half * d_Y = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
|
||||||
|
float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
|
||||||
|
|
||||||
|
bool src1_cont_rows = nb10 == sizeof(float);
|
||||||
|
bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
|
||||||
|
|
||||||
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
|
int i = i03*ne02 + i02;
|
||||||
|
cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
|
||||||
|
|
||||||
|
half * c_X = d_X + i * x_ne;
|
||||||
|
half * c_Y = d_Y + i * y_ne;
|
||||||
|
float * c_D = d_D + i * d_ne;
|
||||||
|
|
||||||
|
// copy src0 to device
|
||||||
|
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
|
||||||
|
|
||||||
|
// convert src1 to fp16
|
||||||
|
// TODO: use multiple threads
|
||||||
|
ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
|
||||||
|
char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
|
||||||
|
if (src1_cont_rows) {
|
||||||
|
if (src1_cont_cols) {
|
||||||
|
ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (int64_t i01 = 0; i01 < ne11; i01++) {
|
||||||
|
ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (int64_t i01 = 0; i01 < ne11; i01++) {
|
||||||
|
for (int64_t i00 = 0; i00 < ne10; i00++) {
|
||||||
|
// very slow due to no inlining
|
||||||
|
tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ptr;
|
// copy src1 to device
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||||
|
|
||||||
|
// compute
|
||||||
|
CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
|
||||||
|
CUBLAS_CHECK(
|
||||||
|
cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
|
ne01, ne11, ne10,
|
||||||
|
&alpha, c_X, CUDA_R_16F, ne00,
|
||||||
|
c_Y, CUDA_R_16F, ne10,
|
||||||
|
&beta, c_D, CUDA_R_32F, ne01,
|
||||||
|
CUBLAS_COMPUTE_32F_FAST_16F,
|
||||||
|
CUBLAS_GEMM_DEFAULT));
|
||||||
|
|
||||||
|
// copy dst to host
|
||||||
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
ggml_cuda_pool_free(d_X, x_size);
|
||||||
|
ggml_cuda_pool_free(d_Y, y_size);
|
||||||
|
ggml_cuda_pool_free(d_D, d_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_host_free(void * ptr) {
|
static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
CUDA_CHECK(cudaFreeHost(ptr));
|
const int64_t ne00 = src0->ne[0];
|
||||||
|
const int64_t ne01 = src0->ne[1];
|
||||||
|
const int64_t ne02 = src0->ne[2];
|
||||||
|
const int64_t ne03 = src0->ne[3];
|
||||||
|
|
||||||
|
const int64_t ne10 = src1->ne[0];
|
||||||
|
const int64_t ne11 = src1->ne[1];
|
||||||
|
|
||||||
|
const int nb2 = dst->nb[2];
|
||||||
|
const int nb3 = dst->nb[3];
|
||||||
|
const ggml_type type = src0->type;
|
||||||
|
|
||||||
|
const float alpha = 1.0f;
|
||||||
|
const float beta = 0.0f;
|
||||||
|
const int x_ne = ne01 * ne00;
|
||||||
|
const int y_ne = ne11 * ne10;
|
||||||
|
const int d_ne = ne11 * ne01;
|
||||||
|
const int n_mm = ne03 * ne02;
|
||||||
|
const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
|
||||||
|
|
||||||
|
size_t x_size, y_size, d_size, q_size;
|
||||||
|
float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
|
||||||
|
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
|
||||||
|
float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
|
||||||
|
char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
|
||||||
|
|
||||||
|
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
|
||||||
|
GGML_ASSERT(to_fp32_cuda != nullptr);
|
||||||
|
|
||||||
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
|
int i = i03*ne02 + i02;
|
||||||
|
cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
|
||||||
|
cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
|
||||||
|
cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
|
||||||
|
|
||||||
|
float * c_X = d_X + i * x_ne;
|
||||||
|
float * c_Y = d_Y + i * y_ne;
|
||||||
|
float * c_D = d_D + i * d_ne;
|
||||||
|
char * c_Q = d_Q + i * q_sz;
|
||||||
|
|
||||||
|
// copy src0 and convert to fp32 on device
|
||||||
|
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
|
||||||
|
to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
|
||||||
|
|
||||||
|
// copy src1 to device
|
||||||
|
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
|
||||||
|
|
||||||
|
// wait for conversion
|
||||||
|
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
|
||||||
|
|
||||||
|
// compute
|
||||||
|
CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
|
||||||
|
CUBLAS_CHECK(
|
||||||
|
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
|
ne01, ne11, ne10,
|
||||||
|
&alpha, c_X, ne00,
|
||||||
|
c_Y, ne10,
|
||||||
|
&beta, c_D, ne01));
|
||||||
|
|
||||||
|
// copy dst to host
|
||||||
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||||
|
CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
ggml_cuda_pool_free(d_X, x_size);
|
||||||
|
ggml_cuda_pool_free(d_Y, y_size);
|
||||||
|
ggml_cuda_pool_free(d_D, d_size);
|
||||||
|
ggml_cuda_pool_free(d_Q, q_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
|
||||||
|
const int64_t ne10 = src1->ne[0];
|
||||||
|
|
||||||
|
const int64_t ne0 = dst->ne[0];
|
||||||
|
const int64_t ne1 = dst->ne[1];
|
||||||
|
|
||||||
|
// TODO: find the optimal values for these
|
||||||
|
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
|
||||||
|
src1->type == GGML_TYPE_F32 &&
|
||||||
|
dst->type == GGML_TYPE_F32 &&
|
||||||
|
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
|
||||||
|
size_t src0_sz = ggml_nbytes(src0);
|
||||||
|
size_t src1_sz = ggml_nbytes(src1);
|
||||||
|
|
||||||
|
// mul_mat_q: src0 is converted to fp32 on device
|
||||||
|
size_t mul_mat_q_transfer = src0_sz + src1_sz;
|
||||||
|
|
||||||
|
// mul_mat_f16: src1 is converted to fp16 on cpu
|
||||||
|
size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1);
|
||||||
|
|
||||||
|
// choose the smaller one to transfer to the device
|
||||||
|
// TODO: this is not always the best choice due to the overhead of converting to fp16
|
||||||
|
return mul_mat_f16_transfer < mul_mat_q_transfer;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
|
||||||
|
GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
|
||||||
|
|
||||||
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
|
ggml_cuda_mul_mat_f32(src0, src1, dst);
|
||||||
|
}
|
||||||
|
else if (src0->type == GGML_TYPE_F16) {
|
||||||
|
if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
|
||||||
|
ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
ggml_cuda_mul_mat_q_f32(src0, src1, dst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (ggml_is_quantized(src0->type)) {
|
||||||
|
ggml_cuda_mul_mat_q_f32(src0, src1, dst);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
|
||||||
|
if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
|
||||||
|
return ggml_nelements(src1) * sizeof(ggml_fp16_t);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
47
ggml-cuda.h
47
ggml-cuda.h
@ -1,54 +1,19 @@
|
|||||||
#include <cublas_v2.h>
|
|
||||||
#include <cuda_runtime.h>
|
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define CUDA_CHECK(err) \
|
|
||||||
do { \
|
|
||||||
cudaError_t err_ = (err); \
|
|
||||||
if (err_ != cudaSuccess) { \
|
|
||||||
fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
|
||||||
cudaGetErrorString(err_)); \
|
|
||||||
exit(1); \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
#define CUBLAS_CHECK(err) \
|
|
||||||
do { \
|
|
||||||
cublasStatus_t err_ = (err); \
|
|
||||||
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
|
||||||
fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
|
|
||||||
exit(1); \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
extern cublasHandle_t g_cublasH;
|
|
||||||
extern cudaStream_t g_cudaStream;
|
|
||||||
extern cudaStream_t g_cudaStream2;
|
|
||||||
extern cudaEvent_t g_cudaEvent;
|
|
||||||
|
|
||||||
void ggml_init_cublas(void);
|
void ggml_init_cublas(void);
|
||||||
|
|
||||||
|
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
||||||
|
size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
||||||
|
void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
|
||||||
|
|
||||||
|
// TODO: export these with GGML_API
|
||||||
void * ggml_cuda_host_malloc(size_t size);
|
void * ggml_cuda_host_malloc(size_t size);
|
||||||
void ggml_cuda_host_free(void * ptr);
|
void ggml_cuda_host_free(void * ptr);
|
||||||
|
|
||||||
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
|
|
||||||
void ggml_cuda_pool_free(void * ptr, size_t size);
|
|
||||||
|
|
||||||
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
|
|
||||||
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
|
|
||||||
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
|
|
||||||
void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
|
|
||||||
void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
|
|
||||||
void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
|
|
||||||
|
|
||||||
cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream);
|
|
||||||
|
|
||||||
typedef void (*dequantize_row_q_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
|
|
||||||
dequantize_row_q_cuda_t ggml_get_dequantize_row_q_cuda(enum ggml_type type);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
252
ggml.c
252
ggml.c
@ -135,14 +135,6 @@ inline static void* ggml_aligned_malloc(size_t size) {
|
|||||||
#define UNUSED(x) (void)(x)
|
#define UNUSED(x) (void)(x)
|
||||||
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
|
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
|
||||||
|
|
||||||
#define GGML_ASSERT(x) \
|
|
||||||
do { \
|
|
||||||
if (!(x)) { \
|
|
||||||
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
|
||||||
abort(); \
|
|
||||||
} \
|
|
||||||
} while (0)
|
|
||||||
|
|
||||||
#if defined(GGML_USE_ACCELERATE)
|
#if defined(GGML_USE_ACCELERATE)
|
||||||
#include <Accelerate/Accelerate.h>
|
#include <Accelerate/Accelerate.h>
|
||||||
#elif defined(GGML_USE_OPENBLAS)
|
#elif defined(GGML_USE_OPENBLAS)
|
||||||
@ -370,6 +362,32 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
|
|||||||
return GGML_FP32_TO_FP16(x);
|
return GGML_FP32_TO_FP16(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n) {
|
||||||
|
for (size_t i = 0; i < n; i++) {
|
||||||
|
y[i] = GGML_FP16_TO_FP32(x[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
|
||||||
|
size_t i = 0;
|
||||||
|
#if defined(__F16C__)
|
||||||
|
for (; i + 7 < n; i += 8) {
|
||||||
|
__m256 x_vec = _mm256_loadu_ps(x + i);
|
||||||
|
__m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
|
||||||
|
_mm_storeu_si128((__m128i *)(y + i), y_vec);
|
||||||
|
}
|
||||||
|
for(; i + 3 < n; i += 4) {
|
||||||
|
__m128 x_vec = _mm_loadu_ps(x + i);
|
||||||
|
__m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
|
||||||
|
_mm_storel_epi64((__m128i *)(y + i), y_vec);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
for (; i < n; i++) {
|
||||||
|
y[i] = GGML_FP32_TO_FP16(x[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// timing
|
// timing
|
||||||
//
|
//
|
||||||
@ -4325,12 +4343,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|||||||
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
|
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize cuBLAS
|
#if defined(GGML_USE_CUBLAS)
|
||||||
#if defined(GGML_USE_CUBLAS)
|
|
||||||
ggml_init_cublas();
|
ggml_init_cublas();
|
||||||
#elif defined(GGML_USE_CLBLAST)
|
#elif defined(GGML_USE_CLBLAST)
|
||||||
ggml_cl_init();
|
ggml_cl_init();
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
is_first_call = false;
|
is_first_call = false;
|
||||||
}
|
}
|
||||||
@ -8101,7 +8118,7 @@ static void ggml_compute_forward_rms_norm(
|
|||||||
|
|
||||||
// ggml_compute_forward_mul_mat
|
// ggml_compute_forward_mul_mat
|
||||||
|
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||||
// helper function to determine if it is better to use BLAS or not
|
// helper function to determine if it is better to use BLAS or not
|
||||||
// for large matrices, BLAS is faster
|
// for large matrices, BLAS is faster
|
||||||
static bool ggml_compute_forward_mul_mat_use_blas(
|
static bool ggml_compute_forward_mul_mat_use_blas(
|
||||||
@ -8117,12 +8134,9 @@ static bool ggml_compute_forward_mul_mat_use_blas(
|
|||||||
const int64_t ne1 = dst->ne[1];
|
const int64_t ne1 = dst->ne[1];
|
||||||
|
|
||||||
// TODO: find the optimal values for these
|
// TODO: find the optimal values for these
|
||||||
if (
|
if (ggml_is_contiguous(src0) &&
|
||||||
#if !defined(GGML_USE_CUBLAS)
|
|
||||||
ggml_is_contiguous(src0) &&
|
|
||||||
ggml_is_contiguous(src1) &&
|
ggml_is_contiguous(src1) &&
|
||||||
#endif
|
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
|
||||||
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
|
|
||||||
|
|
||||||
/*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
|
/*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
|
||||||
return true;
|
return true;
|
||||||
@ -8130,7 +8144,6 @@ static bool ggml_compute_forward_mul_mat_use_blas(
|
|||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static void ggml_compute_forward_mul_mat_f32(
|
static void ggml_compute_forward_mul_mat_f32(
|
||||||
@ -8146,7 +8159,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||||||
const int64_t ne02 = src0->ne[2];
|
const int64_t ne02 = src0->ne[2];
|
||||||
const int64_t ne03 = src0->ne[3];
|
const int64_t ne03 = src0->ne[3];
|
||||||
|
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||||
const int64_t ne10 = src1->ne[0];
|
const int64_t ne10 = src1->ne[0];
|
||||||
#endif
|
#endif
|
||||||
const int64_t ne11 = src1->ne[1];
|
const int64_t ne11 = src1->ne[1];
|
||||||
@ -8203,7 +8216,16 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||||||
// nb01 >= nb00 - src0 is not transposed
|
// nb01 >= nb00 - src0 is not transposed
|
||||||
// compute by src0 rows
|
// compute by src0 rows
|
||||||
|
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
|
||||||
|
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||||
|
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
||||||
if (params->ith != 0) {
|
if (params->ith != 0) {
|
||||||
return;
|
return;
|
||||||
@ -8217,43 +8239,13 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
|
||||||
const float alpha = 1.0f;
|
|
||||||
const float beta = 0.0f;
|
|
||||||
const int x_ne = ne01 * ne00;
|
|
||||||
const int y_ne = ne11 * ne10;
|
|
||||||
const int d_ne = ne11 * ne01;
|
|
||||||
|
|
||||||
size_t x_size, y_size, d_size;
|
|
||||||
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
|
||||||
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
|
||||||
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
#if !defined(GGML_USE_CUBLAS)
|
|
||||||
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
|
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
|
||||||
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
||||||
#endif
|
|
||||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CLBLAST)
|
||||||
// copy data to device
|
|
||||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
|
|
||||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
|
|
||||||
|
|
||||||
// compute
|
|
||||||
CUBLAS_CHECK(
|
|
||||||
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
|
||||||
ne01, ne11, ne10,
|
|
||||||
&alpha, d_X, ne00,
|
|
||||||
d_Y, ne10,
|
|
||||||
&beta, d_D, ne01));
|
|
||||||
|
|
||||||
// copy data to host
|
|
||||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
|
||||||
#elif defined(GGML_USE_CLBLAST)
|
|
||||||
// zT = y * xT
|
// zT = y * xT
|
||||||
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
|
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
|
||||||
ne11, ne01, ne10,
|
ne11, ne01, ne10,
|
||||||
@ -8270,12 +8262,6 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#if defined(GGML_USE_CUBLAS)
|
|
||||||
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
|
|
||||||
ggml_cuda_pool_free(d_X, x_size);
|
|
||||||
ggml_cuda_pool_free(d_Y, y_size);
|
|
||||||
ggml_cuda_pool_free(d_D, d_size);
|
|
||||||
#endif
|
|
||||||
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
||||||
|
|
||||||
return;
|
return;
|
||||||
@ -8405,7 +8391,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||||||
// nb01 >= nb00 - src0 is not transposed
|
// nb01 >= nb00 - src0 is not transposed
|
||||||
// compute by src0 rows
|
// compute by src0 rows
|
||||||
|
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
|
||||||
|
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||||
|
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
||||||
GGML_ASSERT(nb10 == sizeof(float));
|
GGML_ASSERT(nb10 == sizeof(float));
|
||||||
|
|
||||||
@ -8421,37 +8416,8 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
|
||||||
const float alpha = 1.0f;
|
|
||||||
const float beta = 0.0f;
|
|
||||||
const int x_ne = ne01 * ne00;
|
|
||||||
const int y_ne = ne11 * ne10;
|
|
||||||
const int d_ne = ne11 * ne01;
|
|
||||||
|
|
||||||
size_t x_size, y_size, d_size;
|
|
||||||
ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
|
||||||
ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
|
||||||
float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
|
||||||
#endif
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
#if defined(GGML_USE_CUBLAS)
|
|
||||||
// copy src0 while converting src1
|
|
||||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
|
|
||||||
|
|
||||||
// with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
|
|
||||||
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02);
|
|
||||||
{
|
|
||||||
size_t id = 0;
|
|
||||||
for (int64_t i01 = 0; i01 < ne11; ++i01) {
|
|
||||||
for (int64_t i00 = 0; i00 < ne10; ++i00) {
|
|
||||||
wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(id*sizeof(ggml_fp16_t) <= params->wsize);
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
float * const wdata = params->wdata;
|
float * const wdata = params->wdata;
|
||||||
{
|
{
|
||||||
size_t id = 0;
|
size_t id = 0;
|
||||||
@ -8463,28 +8429,8 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||||||
|
|
||||||
assert(id*sizeof(float) <= params->wsize);
|
assert(id*sizeof(float) <= params->wsize);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CLBLAST)
|
||||||
const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
|
|
||||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
|
||||||
|
|
||||||
// copy data to device
|
|
||||||
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
|
||||||
|
|
||||||
// compute
|
|
||||||
CUBLAS_CHECK(
|
|
||||||
cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
|
||||||
ne01, ne11, ne10,
|
|
||||||
&alpha, d_X, CUDA_R_16F, ne00,
|
|
||||||
d_Y, CUDA_R_16F, ne10,
|
|
||||||
&beta, d_D, CUDA_R_32F, ne01,
|
|
||||||
CUBLAS_COMPUTE_32F,
|
|
||||||
CUBLAS_GEMM_DEFAULT));
|
|
||||||
|
|
||||||
// copy data to host
|
|
||||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
|
||||||
#elif defined(GGML_USE_CLBLAST)
|
|
||||||
const float * x = wdata;
|
const float * x = wdata;
|
||||||
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
||||||
|
|
||||||
@ -8513,12 +8459,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
|
||||||
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
|
|
||||||
ggml_cuda_pool_free(d_X, x_size);
|
|
||||||
ggml_cuda_pool_free(d_Y, y_size);
|
|
||||||
ggml_cuda_pool_free(d_D, d_size);
|
|
||||||
#endif
|
|
||||||
/*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
|
/*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
|
||||||
|
|
||||||
return;
|
return;
|
||||||
@ -8671,7 +8611,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||||||
// nb01 >= nb00 - src0 is not transposed
|
// nb01 >= nb00 - src0 is not transposed
|
||||||
// compute by src0 rows
|
// compute by src0 rows
|
||||||
|
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
|
||||||
|
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||||
|
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
||||||
if (params->ith != 0) {
|
if (params->ith != 0) {
|
||||||
return;
|
return;
|
||||||
@ -8685,25 +8634,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
|
||||||
const float alpha = 1.0f;
|
|
||||||
const float beta = 0.0f;
|
|
||||||
const int x_ne = ne01 * ne00;
|
|
||||||
const int y_ne = ne11 * ne10;
|
|
||||||
const int d_ne = ne11 * ne01;
|
|
||||||
|
|
||||||
size_t x_size, y_size, d_size, q_size;
|
|
||||||
float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
|
||||||
float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
|
||||||
float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
|
||||||
void * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
|
|
||||||
|
|
||||||
const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type);
|
|
||||||
GGML_ASSERT(dequantize_row_q_cuda != NULL);
|
|
||||||
#else
|
|
||||||
float * const wdata = params->wdata;
|
float * const wdata = params->wdata;
|
||||||
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
||||||
#endif
|
|
||||||
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
@ -8711,14 +8643,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||||||
|
|
||||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CLBLAST)
|
||||||
// copy and dequantize on device
|
|
||||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream2));
|
|
||||||
|
|
||||||
dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
|
|
||||||
CUDA_CHECK(cudaGetLastError());
|
|
||||||
CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2));
|
|
||||||
#elif defined(GGML_USE_CLBLAST)
|
|
||||||
const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
|
const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
|
||||||
#else
|
#else
|
||||||
{
|
{
|
||||||
@ -8734,24 +8659,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||||||
const float * x = wdata;
|
const float * x = wdata;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
#if defined(GGML_USE_CLBLAST)
|
||||||
// copy data to device
|
|
||||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
|
|
||||||
|
|
||||||
// wait for dequantization
|
|
||||||
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0));
|
|
||||||
|
|
||||||
// compute
|
|
||||||
CUBLAS_CHECK(
|
|
||||||
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
|
||||||
ne01, ne11, ne10,
|
|
||||||
&alpha, d_X, ne00,
|
|
||||||
d_Y, ne10,
|
|
||||||
&beta, d_D, ne01));
|
|
||||||
|
|
||||||
// copy data to host
|
|
||||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
|
||||||
#elif defined(GGML_USE_CLBLAST)
|
|
||||||
// zT = y * xT
|
// zT = y * xT
|
||||||
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
|
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
|
||||||
ne11, ne01, ne10,
|
ne11, ne01, ne10,
|
||||||
@ -8769,13 +8677,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(GGML_USE_CUBLAS)
|
|
||||||
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
|
|
||||||
ggml_cuda_pool_free(d_X, x_size);
|
|
||||||
ggml_cuda_pool_free(d_Y, y_size);
|
|
||||||
ggml_cuda_pool_free(d_D, d_size);
|
|
||||||
ggml_cuda_pool_free(d_Q, q_size);
|
|
||||||
#endif
|
|
||||||
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
||||||
|
|
||||||
return;
|
return;
|
||||||
@ -11759,18 +11660,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||||||
|
|
||||||
size_t cur = 0;
|
size_t cur = 0;
|
||||||
|
|
||||||
|
#if defined(GGML_USE_CUBLAS)
|
||||||
|
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
|
||||||
|
node->n_tasks = 1; // TODO: this actually is doing nothing
|
||||||
|
// the threads are still spinning
|
||||||
|
cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
#endif
|
||||||
if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
|
if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
||||||
node->n_tasks = 1; // TODO: this actually is doing nothing
|
node->n_tasks = 1; // TODO: this actually is doing nothing
|
||||||
// the threads are still spinning
|
// the threads are still spinning
|
||||||
#if defined(GGML_USE_CUBLAS)
|
|
||||||
// with cuBLAS, we need memory for the full 3D / 4D data of src1
|
|
||||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
|
|
||||||
#else
|
|
||||||
// here we need memory just for single 2D matrix from src0
|
// here we need memory just for single 2D matrix from src0
|
||||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
||||||
#endif
|
|
||||||
} else {
|
} else {
|
||||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
|
||||||
}
|
}
|
||||||
@ -11779,13 +11683,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||||||
#endif
|
#endif
|
||||||
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
|
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
|
||||||
cur = 0;
|
cur = 0;
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
||||||
node->n_tasks = 1;
|
node->n_tasks = 1;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
} else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
|
} else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
||||||
node->n_tasks = 1;
|
node->n_tasks = 1;
|
||||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
||||||
|
11
ggml.h
11
ggml.h
@ -197,6 +197,14 @@
|
|||||||
#define GGML_MAX_OPT 4
|
#define GGML_MAX_OPT 4
|
||||||
#define GGML_DEFAULT_N_THREADS 4
|
#define GGML_DEFAULT_N_THREADS 4
|
||||||
|
|
||||||
|
#define GGML_ASSERT(x) \
|
||||||
|
do { \
|
||||||
|
if (!(x)) { \
|
||||||
|
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
||||||
|
abort(); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
@ -212,6 +220,9 @@ extern "C" {
|
|||||||
GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
|
GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
|
||||||
GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
|
GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
|
||||||
|
|
||||||
|
GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n);
|
||||||
|
GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n);
|
||||||
|
|
||||||
struct ggml_object;
|
struct ggml_object;
|
||||||
struct ggml_context;
|
struct ggml_context;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user