ggml-cuda : add rope f16, restore performance with parallel decoding (#3272)

* ggml-cuda : add rope f16, restore performance

* offload KQ_mask with all models

* fix rope shift

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
slaren 2023-09-20 13:00:28 +02:00 committed by GitHub
parent db0fc2da06
commit e04dc51988
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 110 additions and 67 deletions

View File

@ -439,7 +439,6 @@ static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullpt
struct ggml_tensor_extra_gpu {
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
bool copied;
};
// this is faster on Windows
@ -4357,8 +4356,9 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
// rope == RoPE == rotary positional embedding
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale) {
template<typename T, bool has_pos>
static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (col >= ncols) {
@ -4369,8 +4369,8 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
const int i = row*ncols + col;
const int i2 = row/p_delta_rows;
const int p = pos != nullptr ? pos[i2] : 0;
const float p0 = p * freq_scale;
const int p = has_pos ? pos[i2] : 0;
const float p0 = p*freq_scale;
const float theta = p0*powf(theta_scale, col/2);
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta);
@ -4382,8 +4382,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale) {
template<typename T, bool has_pos>
static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (col >= ncols) {
@ -4394,8 +4395,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
const int i = row*ncols + col/2;
const int i2 = row/p_delta_rows;
const int p = pos != nullptr ? pos[i2] : 0;
const float p0 = p * freq_scale;
const int p = has_pos ? pos[i2] : 0;
const float p0 = p*freq_scale;
const float theta = p0*powf(theta_scale, col/2);
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta);
@ -5371,22 +5372,32 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
}
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
template<typename T>
static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1);
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
if (pos == nullptr) {
rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
} else {
rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
}
}
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
template<typename T>
static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1);
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
if (pos == nullptr) {
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
} else {
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
}
}
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
@ -6036,7 +6047,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
const int64_t ne0 = dst->ne[0];
const int64_t row_diff = row_high - row_low;
float * src0_ddq_as_f32;
float * src0_ddq_as_f32 = nullptr;
size_t src0_as = 0;
if (src0->type != GGML_TYPE_F32) {
@ -6074,8 +6085,9 @@ inline void ggml_cuda_op_rope(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
@ -6093,23 +6105,12 @@ inline void ggml_cuda_op_rope(
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
const float theta_scale = powf(freq_base, -2.0f/n_dims);
// const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(src1->ne[0] == ne2);
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
int id;
CUDA_CHECK(cudaGetDevice(&id));
int * pos = nullptr;
const int32_t * pos = nullptr;
if ((mode & 1) == 0) {
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
pos = (int *) src1_extra->data_device[id];
if (!src1_extra->copied) {
CUDA_CHECK(cudaMemcpyAsync(pos, src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream));
src1_extra->copied = true;
}
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(src1->ne[0] == ne2);
pos = (const int32_t *) src1_dd;
}
const bool is_neox = mode & 2;
@ -6121,9 +6122,21 @@ inline void ggml_cuda_op_rope(
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
} else if (is_neox) {
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
} else {
GGML_ASSERT(false);
}
} else {
rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
if (src0->type == GGML_TYPE_F32) {
rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
} else {
GGML_ASSERT(false);
}
}
(void) src1;
@ -6294,7 +6307,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
}
}
void ggml_cuda_set_peer_access(const int n_tokens) {
static void ggml_cuda_set_peer_access(const int n_tokens) {
static bool peer_access_enabled = false;
const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
@ -6622,27 +6635,27 @@ static void ggml_cuda_op_mul_mat(
}
}
void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
}
void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
}
void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
}
void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
}
void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
}
void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
}
@ -6663,7 +6676,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
return false;
}
void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
@ -6692,7 +6705,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
}
void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
GGML_ASSERT(!ggml_is_permuted(src0));
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
@ -6726,7 +6739,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
}
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
@ -6770,11 +6783,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
}
}
void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
}
void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
@ -6822,29 +6835,29 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
(void) dst;
}
void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_cpy(src0, dst, nullptr);
(void) src1;
}
void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf);
}
void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max);
}
void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope);
}
void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
}
void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
(void) src0;
(void) src1;
(void) dst;
@ -6967,11 +6980,13 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
return extra;
}
void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
if (scratch && g_scratch_size == 0) {
return;
}
tensor->backend = GGML_BACKEND_GPU;
// recursively assign CUDA buffers until a compute tensor is found
if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
const ggml_op src0_op = tensor->src[0]->op;
@ -6983,8 +6998,6 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
}
tensor->backend = GGML_BACKEND_GPU;
if (scratch && no_alloc) {
return;
}
@ -7069,6 +7082,15 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
tensor->extra = extra;
}
void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) {
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
GGML_ASSERT(ggml_is_contiguous(tensor));
struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice));
}
void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
ggml_cuda_assign_buffers_impl(tensor, true, false, false);
}

View File

@ -31,6 +31,7 @@ GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tens
GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
GGML_API void ggml_cuda_copy_to_device(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_set_main_device(int main_device);
GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);

2
ggml.c
View File

@ -6343,7 +6343,7 @@ static struct ggml_tensor * ggml_cpy_impl(
}
// make a view of the destination
struct ggml_tensor * result = ggml_view_tensor(ctx, b);
struct ggml_tensor * result = b->op == GGML_OP_NONE ? b : ggml_view_tensor(ctx, b);
if (strlen(b->name) > 0) {
ggml_format_name(result, "%s (copy of %s)", b->name, a->name);
} else {

View File

@ -1256,10 +1256,10 @@ static bool llama_kv_cache_init(
(void) n_gpu_layers;
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > n_layer + 1) {
if (n_gpu_layers > (int)n_layer + 1) {
ggml_cuda_assign_buffers_no_scratch(cache.v);
}
if (n_gpu_layers > n_layer + 2) {
if (n_gpu_layers > (int)n_layer + 2) {
ggml_cuda_assign_buffers_no_scratch(cache.k);
}
#endif // GGML_USE_CUBLAS
@ -2692,14 +2692,16 @@ static struct ggml_cgraph * llm_build_llama(
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
ggml_allocr_alloc(lctx.alloc, KQ_scale);
if (!ggml_allocr_is_measure(lctx.alloc)) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head)));
}
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
offload_func_kq(KQ_mask);
ggml_set_name(KQ_mask, "KQ_mask");
ggml_allocr_alloc(lctx.alloc, KQ_mask);
if (!ggml_allocr_is_measure(lctx.alloc)) {
float * data = (float *) KQ_mask->data;
@ -2722,6 +2724,7 @@ static struct ggml_cgraph * llm_build_llama(
// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
offload_func_kq(KQ_pos);
ggml_set_name(KQ_pos, "KQ_pos");
ggml_allocr_alloc(lctx.alloc, KQ_pos);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) KQ_pos->data;
@ -2734,6 +2737,7 @@ static struct ggml_cgraph * llm_build_llama(
if (do_rope_shift) {
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
offload_func_kq(K_shift);
ggml_set_name(K_shift, "K_shift");
ggml_allocr_alloc(lctx.alloc, K_shift);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) K_shift->data;
@ -2743,14 +2747,16 @@ static struct ggml_cgraph * llm_build_llama(
}
for (int il = 0; il < n_layer; ++il) {
ggml_build_forward_expand(gf,
struct ggml_tensor * tmp =
ggml_rope_custom_inplace(ctx0,
ggml_view_3d(ctx0, kv_self.k,
n_embd_head, n_head_kv, n_ctx,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
K_shift, n_embd_head, 0, 0, freq_base, freq_scale));
K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(tmp);
ggml_build_forward_expand(gf, tmp);
}
}
@ -3078,14 +3084,16 @@ static struct ggml_cgraph * llm_build_baichaun(
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
ggml_allocr_alloc(lctx.alloc, KQ_scale);
if (!ggml_allocr_is_measure(lctx.alloc)) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
}
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
offload_func_kq(KQ_mask);
ggml_set_name(KQ_mask, "KQ_mask");
ggml_allocr_alloc(lctx.alloc, KQ_mask);
if (!ggml_allocr_is_measure(lctx.alloc)) {
float * data = (float *) KQ_mask->data;
@ -3108,6 +3116,7 @@ static struct ggml_cgraph * llm_build_baichaun(
// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
offload_func_kq(KQ_pos);
ggml_set_name(KQ_pos, "KQ_pos");
ggml_allocr_alloc(lctx.alloc, KQ_pos);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) KQ_pos->data;
@ -3120,6 +3129,7 @@ static struct ggml_cgraph * llm_build_baichaun(
if (do_rope_shift) {
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
offload_func_kq(K_shift);
ggml_set_name(K_shift, "K_shift");
ggml_allocr_alloc(lctx.alloc, K_shift);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) K_shift->data;
@ -3129,14 +3139,16 @@ static struct ggml_cgraph * llm_build_baichaun(
}
for (int il = 0; il < n_layer; ++il) {
ggml_build_forward_expand(gf,
struct ggml_tensor * tmp =
ggml_rope_custom_inplace(ctx0,
ggml_view_3d(ctx0, kv_self.k,
n_embd_head, n_head_kv, n_ctx,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
K_shift, n_embd_head, 0, 0, freq_base, freq_scale));
K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(tmp);
ggml_build_forward_expand(gf, tmp);
}
}
@ -3484,14 +3496,16 @@ static struct ggml_cgraph * llm_build_falcon(
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
ggml_allocr_alloc(lctx.alloc, KQ_scale);
if (!ggml_allocr_is_measure(lctx.alloc)) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
}
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
offload_func_kq(KQ_mask);
ggml_set_name(KQ_mask, "KQ_mask");
ggml_allocr_alloc(lctx.alloc, KQ_mask);
if (!ggml_allocr_is_measure(lctx.alloc)) {
float * data = (float *) KQ_mask->data;
@ -3514,6 +3528,7 @@ static struct ggml_cgraph * llm_build_falcon(
// KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
offload_func_kq(KQ_pos);
ggml_set_name(KQ_pos, "KQ_pos");
ggml_allocr_alloc(lctx.alloc, KQ_pos);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) KQ_pos->data;
@ -3526,6 +3541,7 @@ static struct ggml_cgraph * llm_build_falcon(
if (do_rope_shift) {
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
offload_func_kq(K_shift);
ggml_set_name(K_shift, "K_shift");
ggml_allocr_alloc(lctx.alloc, K_shift);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) K_shift->data;
@ -3535,14 +3551,16 @@ static struct ggml_cgraph * llm_build_falcon(
}
for (int il = 0; il < n_layer; ++il) {
ggml_build_forward_expand(gf,
struct ggml_tensor * tmp =
ggml_rope_custom_inplace(ctx0,
ggml_view_3d(ctx0, kv_self.k,
n_embd_head, n_head_kv, n_ctx,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
K_shift, n_embd_head, 2, 0, freq_base, freq_scale));
K_shift, n_embd_head, 2, 0, freq_base, freq_scale);
offload_func_kq(tmp);
ggml_build_forward_expand(gf, tmp);
}
}
@ -3832,14 +3850,15 @@ static struct ggml_cgraph * llm_build_starcoder(
// KQ_scale
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
ggml_allocr_alloc(lctx.alloc, KQ_scale);
if (!ggml_allocr_is_measure(lctx.alloc)) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
}
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
ggml_set_name(KQ_mask, "KQ_mask");
ggml_allocr_alloc(lctx.alloc, KQ_mask);
if (!ggml_allocr_is_measure(lctx.alloc)) {
float * data = (float *) KQ_mask->data;
@ -4118,6 +4137,7 @@ static int llama_decode_internal(
ggml_tensor * node = gf->leafs[i];
if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
ggml_cuda_copy_to_device(node);
}
}