From 67be2ce1015d070b3b2cd488bcb041eefb61de72 Mon Sep 17 00:00:00 2001 From: slaren Date: Sun, 3 Mar 2024 14:26:18 +0100 Subject: [PATCH 1/8] cuda : fix data race in soft max (#5853) --- ggml-cuda.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7ed97430f..04c6cb1b8 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6904,6 +6904,7 @@ static __global__ void soft_max_f32(const float * x, const float * mask, const f // find the sum of exps in the block tmp = warp_reduce_sum(tmp); if (block_size > WARP_SIZE) { + __syncthreads(); if (warp_id == 0) { buf_iw[lane_id] = 0.0f; } From 5a51cc1bb4592f0d71f9af89cd08b11a066ba447 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?DAN=E2=84=A2?= Date: Mon, 4 Mar 2024 02:57:20 -0500 Subject: [PATCH 2/8] main : support special tokens as reverse/anti prompt (#5847) * Support special tokens as reverse/anti prompt. * Tokenize antiprompts only once. * main : minor --------- Co-authored-by: Georgi Gerganov --- examples/main/main.cpp | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 34e84d0d4..47059e582 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -511,6 +511,14 @@ int main(int argc, char ** argv) { std::vector embd; std::vector embd_guidance; + // tokenized antiprompts + std::vector> antiprompt_ids; + + antiprompt_ids.reserve(params.antiprompt.size()); + for (const std::string & antiprompt : params.antiprompt) { + antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); + } + struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); while ((n_remain != 0 && !is_antiprompt) || params.interactive) { @@ -769,6 +777,18 @@ int main(int argc, char ** argv) { } } + // check for reverse prompt using special tokens + llama_token last_token = llama_sampling_last(ctx_sampling); + for (std::vector ids : antiprompt_ids) { + if (ids.size() == 1 && last_token == ids[0]) { + if (params.interactive) { + is_interacting = true; + } + is_antiprompt = true; + break; + } + } + if (is_antiprompt) { LOG("found antiprompt: %s\n", last_output.c_str()); } From 82f3e668adafba647de703f835991e91a96b5ac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?DAN=E2=84=A2?= Date: Mon, 4 Mar 2024 03:08:19 -0500 Subject: [PATCH 3/8] common : use LLAMA_DEFAULT_SEED (#5855) --- common/common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/common.h b/common/common.h index d3682b7ad..b2868833b 100644 --- a/common/common.h +++ b/common/common.h @@ -43,7 +43,7 @@ extern char const *LLAMA_BUILD_TARGET; int32_t get_num_physical_cores(); struct gpt_params { - uint32_t seed = -1; // RNG seed + uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed int32_t n_threads = get_num_physical_cores(); int32_t n_threads_draft = -1; From 7d43c585dc174bb586775c22c15e5db9242b5b4b Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 3 Mar 2024 20:23:52 +0800 Subject: [PATCH 4/8] add some new ops, fix some operators and add batch operations to certain operators. (ggml/747) * cuda: fix group_norm * cuda: add batch inference support for ggml_pad/ggml_upscale * add ggml_arrange * add ggml_timestep_embedding * update ggml_arange/ggml_timestep_embedding tests * cuda: fix im2col * add ggml_arange/ggml_timestep_embbeding support for metal backend * fix some bugs * fix some bugs * Update ggml.h Co-authored-by: Georgi Gerganov * Update ggml-cuda.cu Co-authored-by: Georgi Gerganov * Update ggml-metal.m Co-authored-by: Georgi Gerganov * Update ggml-metal.m Co-authored-by: Georgi Gerganov * Update ggml-metal.metal Co-authored-by: Georgi Gerganov * modify according to the review comments * ggml : fix compile warnings + code style * ggml : normalize compute_forward calls + fix seg fault in debug * minor --------- Co-authored-by: Georgi Gerganov Co-authored-by: slaren --- ggml-cuda.cu | 227 ++++++++++++++++++++++++++++++------- ggml-metal.m | 62 +++++++++- ggml-metal.metal | 43 +++++++ ggml.c | 207 +++++++++++++++++++++++++++++++-- ggml.h | 17 +++ tests/test-backend-ops.cpp | 46 ++++++++ 6 files changed, 550 insertions(+), 52 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 04c6cb1b8..7d027a30a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -616,6 +616,8 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_fp16_t) + sizeof(uint16_t) + Q #define CUDA_UPSCALE_BLOCK_SIZE 256 #define CUDA_CONCAT_BLOCK_SIZE 256 #define CUDA_PAD_BLOCK_SIZE 256 +#define CUDA_ARANGE_BLOCK_SIZE 256 +#define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256 #define CUDA_ACC_BLOCK_SIZE 256 #define CUDA_IM2COL_BLOCK_SIZE 256 #define CUDA_POOL2D_BLOCK_SIZE 256 @@ -990,17 +992,21 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst, nidx + blockIdx.y * ne0 + blockIdx.z * ne0 * gridDim.y; - dst[offset_dst] = x[offset_src]; + dst[offset_dst] = x[offset_src]; } else { int offset_src = nidx + blockIdx.y * ne0 + (blockIdx.z - ne02) * ne0 * gridDim.y; - dst[offset_dst] = y[offset_src]; + dst[offset_dst] = y[offset_src]; } } -static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int nb02, const int scale_factor) { +static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int ne00xne01, const int scale_factor) { + // blockIdx.z: idx of ne02*ne03 + // blockIdx.y: idx of ne01*scale_factor, aka ne1 + // blockIDx.x: idx of ne00*scale_factor / BLOCK_SIZE + // ne00xne01: ne00 * ne01 int ne0 = ne00 * scale_factor; int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { @@ -1012,7 +1018,7 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00, int offset_src = i00 + i01 * ne00 + - blockIdx.z * nb02; + blockIdx.z * ne00xne01; int offset_dst = nidx + blockIdx.y * ne0 + @@ -1020,7 +1026,10 @@ static __global__ void upscale_f32(const float * x, float * dst, const int ne00, dst[offset_dst] = x[offset_src]; } -static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02) { +static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) { + // blockIdx.z: idx of ne2*ne3, aka ne02*ne03 + // blockIdx.y: idx of ne1 + // blockIDx.x: idx of ne0 / BLOCK_SIZE int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { return; @@ -1031,19 +1040,53 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons nidx + blockIdx.y * ne0 + blockIdx.z * ne0 * gridDim.y; - if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) { + if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) { int offset_src = nidx + blockIdx.y * ne00 + blockIdx.z * ne00 * ne01; - dst[offset_dst] = x[offset_src]; + dst[offset_dst] = x[offset_src]; } else { dst[offset_dst] = 0.0f; } } +static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) { + // blockIDx.x: idx of ne0 / BLOCK_SIZE + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + dst[nidx] = start + step * nidx; +} + +static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) { + // blockIDx.y: idx of timesteps->ne[0] + // blockIDx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE + int i = blockIdx.y; + int j = threadIdx.x + blockIdx.x * blockDim.x; + float * embed_data = (float *)((char *)dst + i*nb1); + + if (dim % 2 != 0 && j == ((dim + 1) / 2)) { + embed_data[dim] = 0.f; + } + + int half = dim / 2; + if (j >= half) { + return; + } + + float timestep = timesteps[i]; + float freq = (float)expf(-logf(max_period) * j / half); + float arg = timestep * freq; + embed_data[j] = cosf(arg); + embed_data[j + half] = sinf(arg); +} + template static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) { + // blockIdx.x: num_groups idx + // threadIdx.x: block_size idx int start = blockIdx.x * group_size; int end = start + group_size; @@ -6448,7 +6491,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; + const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { return; @@ -6456,17 +6499,17 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor // then combine those indices with the corresponding byte offsets to get the total offsets - const int i03 = i/(ne00 * ne01 * ne02); - const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); - const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; - const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; - const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; - const int i13 = i/(ne10 * ne11 * ne12); - const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); - const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; - const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; - const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13; + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13; cpy_1(cx + x_offset, cdst + dst_offset); } @@ -6956,23 +6999,23 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min, template static __global__ void im2col_kernel( - const float * x, T * dst, int batch_offset, - int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW, + const float * x, T * dst, int64_t batch_offset, + int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1) { - const int i = threadIdx.x + blockIdx.x * blockDim.x; + const int64_t i = threadIdx.x + blockIdx.x * blockDim.x; if (i >= pelements) { return; } - const int ksize = OW * (KH > 1 ? KW : 1); - const int kx = i / ksize; - const int kd = kx * ksize; - const int ky = (i - kd) / OW; - const int ix = i % OW; + const int64_t ksize = OW * (KH > 1 ? KW : 1); + const int64_t kx = i / ksize; + const int64_t kd = kx * ksize; + const int64_t ky = (i - kd) / OW; + const int64_t ix = i % OW; - const int oh = blockIdx.y; - const int batch = blockIdx.z / IC; - const int ic = blockIdx.z % IC; + const int64_t oh = blockIdx.y; + const int64_t batch = blockIdx.z / IC; + const int64_t ic = blockIdx.z % IC; const int64_t iiw = ix * s0 + kx * d0 - p0; const int64_t iih = oh * s1 + ky * d1 - p1; @@ -7298,19 +7341,33 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, const concat_f32<<>>(x, y, dst, ne0, ne02); } -static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) { +static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int ne03, + const int scale_factor, cudaStream_t stream) { int ne0 = (ne00 * scale_factor); int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; - dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02); + dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02*ne03); upscale_f32<<>>(x, dst, ne00, ne00 * ne01, scale_factor); } static void pad_f32_cuda(const float * x, float * dst, - const int ne00, const int ne01, const int ne02, - const int ne0, const int ne1, const int ne2, cudaStream_t stream) { + const int ne00, const int ne01, const int ne02, const int ne03, + const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) { int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; - dim3 gridDim(num_blocks, ne1, ne2); - pad_f32<<>>(x, dst, ne0, ne00, ne01, ne02); + dim3 gridDim(num_blocks, ne1, ne2*ne3); + pad_f32<<>>(x, dst, ne0, ne00, ne01, ne02, ne03); +} + +static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) { + int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE; + arange_f32<<>>(dst, ne0, start, step); +} + +static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1, + const int dim, const int max_period, cudaStream_t stream) { + int half_ceil = (dim + 1) / 2; + int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE; + dim3 gridDim(num_blocks, ne00, 1); + timestep_embedding_f32<<>>(x, dst, nb1, dim, max_period); } static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { @@ -8443,8 +8500,8 @@ static void soft_max_f32_cuda(const float * x, const float * mask, const float * template static void im2col_cuda(const float* x, T* dst, - int IW, int IH, int OW, int OH, int KW, int KH, int IC, - int batch, int batch_offset, int offset_delta, + int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC, + int64_t batch, int64_t batch_offset, int64_t offset_delta, int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { const int parallel_elements = OW * KW * KH; const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; @@ -9123,7 +9180,7 @@ static void ggml_cuda_op_group_norm( int num_groups = dst->op_params[0]; int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); - group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream); + group_norm_f32_cuda(src0_dd, dst_dd, num_groups * src0->ne[3], group_size, ggml_nelements(src0), main_stream); (void) src1; (void) dst; @@ -9156,7 +9213,7 @@ static void ggml_cuda_op_upscale( const int scale_factor = dst->op_params[0]; - upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream); + upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], scale_factor, main_stream); (void) src1; (void) dst; @@ -9172,8 +9229,49 @@ static void ggml_cuda_op_pad( GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors pad_f32_cuda(src0_dd, dst_dd, - src0->ne[0], src0->ne[1], src0->ne[2], - dst->ne[0], dst->ne[1], dst->ne[2], main_stream); + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +static void ggml_cuda_op_arange( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + float start; + float stop; + float step; + memcpy(&start, (float *)dst->op_params + 0, sizeof(float)); + memcpy(&stop, (float *)dst->op_params + 1, sizeof(float)); + memcpy(&step, (float *)dst->op_params + 2, sizeof(float)); + + int64_t steps = (int64_t)ceil((stop - start) / step); + GGML_ASSERT(ggml_nelements(dst) == steps); + + arange_f32_cuda(dst_dd, dst->ne[0], start, step, main_stream); + + (void) src0; + (void) src1; + (void) src0_dd; + (void) src1_dd; +} + +static void ggml_cuda_op_timestep_embedding( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int dim = dst->op_params[0]; + const int max_period = dst->op_params[1]; + + timestep_embedding_f32_cuda(src0_dd, dst_dd, src0->ne[0], dst->nb[1], dim, max_period, main_stream); (void) src1; (void) dst; @@ -10458,6 +10556,45 @@ static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, gg ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad); } +static void ggml_cuda_arange(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + + const bool dst_on_device = dst->backend == GGML_BACKEND_TYPE_GPU; + + // dd = data device + float * src0_ddf = nullptr; + float * src1_ddf = nullptr; + float * dst_ddf = nullptr; + + cuda_pool_alloc dst_f; + + ggml_cuda_set_device(g_main_device); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + if (dst_on_device) { + dst_ddf = (float *) dst_extra->data_device[g_main_device]; + } else { + dst_ddf = dst_f.alloc(ggml_nelements(dst)); + } + + // do the computation + ggml_cuda_op_arange(src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream); + CUDA_CHECK(cudaGetLastError()); + + // copy dst to host if necessary + if (!dst_on_device) { + CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream)); + } + + if (dst->backend == GGML_BACKEND_TYPE_CPU) { + CUDA_CHECK(cudaDeviceSynchronize()); + } +} + +static void ggml_cuda_timestep_embedding(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_timestep_embedding); +} + static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm); } @@ -11358,6 +11495,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st case GGML_OP_PAD: func = ggml_cuda_pad; break; + case GGML_OP_ARANGE: + func = ggml_cuda_arange; + break; + case GGML_OP_TIMESTEP_EMBEDDING: + func = ggml_cuda_timestep_embedding; + break; case GGML_OP_LEAKY_RELU: func = ggml_cuda_leaky_relu; break; @@ -12253,6 +12396,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_GROUP_NORM: case GGML_OP_UPSCALE: case GGML_OP_PAD: + case GGML_OP_ARANGE: + case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: return true; default: diff --git a/ggml-metal.m b/ggml-metal.m index 71fcca560..6b5a8fdf5 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -163,6 +163,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_IM2COL_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32, GGML_METAL_KERNEL_TYPE_PAD_F32, + GGML_METAL_KERNEL_TYPE_ARANGE_F32, + GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, @@ -569,6 +571,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); @@ -697,6 +701,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const return false; case GGML_OP_UPSCALE: case GGML_OP_PAD: + case GGML_OP_ARANGE: + case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: return true; @@ -1091,7 +1097,8 @@ static bool ggml_metal_graph_compute( { GGML_ASSERT(ggml_is_contiguous(src0)); - const float scale = *(const float *) dst->op_params; + float scale; + memcpy(&scale, dst->op_params, sizeof(scale)); int64_t n = ggml_nelements(dst); @@ -1250,11 +1257,15 @@ static bool ggml_metal_graph_compute( pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline; } - const float scale = ((float *) dst->op_params)[0]; - const float max_bias = ((float *) dst->op_params)[1]; + float scale; + float max_bias; + + memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); const int64_t nrows_x = ggml_nrows(src0); const int64_t nrows_y = src0->ne[1]; + const uint32_t n_head_kv = nrows_x/nrows_y; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); @@ -2086,6 +2097,7 @@ static bool ggml_metal_graph_compute( //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); @@ -2300,6 +2312,50 @@ static bool ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; + case GGML_OP_ARANGE: + { + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + float start; + float step; + + memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; + [encoder setBytes:&start length:sizeof(start) atIndex:2]; + [encoder setBytes:&step length:sizeof(step) atIndex:3]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const int dim = dst->op_params[0]; + const int max_period = dst->op_params[1]; + + const int half = dim / 2; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; + [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; + + const int nth = MIN(1024, half); + + [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_ARGSORT: { GGML_ASSERT(src0->type == GGML_TYPE_F32); diff --git a/ggml-metal.metal b/ggml-metal.metal index 8b9488437..a65d12641 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1959,6 +1959,49 @@ kernel void kernel_pad_f32( } } +kernel void kernel_arange_f32( + device char * dst, + constant int64_t & ne0, + constant float & start, + constant float & step, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + device float * dst_ptr = (device float *) dst; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = start + step * i0; + } +} + +kernel void kernel_timestep_embedding_f32( + device const char * src0, + device char * dst, + constant uint64_t & nb1, + constant int & dim, + constant int & max_period, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + int i = tgpig.x; + device float * embed_data = (device float *)(dst + i*nb1); + + int half_ = dim / 2; + for (int j = tpitg.x; j < half_; j += ntg.x) { + float timestep = ((device float *)src0)[i]; + float freq = (float)exp(-log((float)max_period) * j / half_); + float arg = timestep * freq; + embed_data[j ] = cos(arg); + embed_data[j + half_] = sin(arg); + } + + if (dim % 2 != 0 && tpitg.x == 0) { + embed_data[dim] = 0.f; + } +} + // bitonic sort implementation following the CUDA kernels as reference typedef void (argsort_t)( device const float * x, diff --git a/ggml.c b/ggml.c index f29b9f13f..870e41614 100644 --- a/ggml.c +++ b/ggml.c @@ -1822,6 +1822,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "POOL_2D", "UPSCALE", "PAD", + "ARANGE", + "TIMESTEP_EMBEDDING", "ARGSORT", "LEAKY_RELU", @@ -1850,7 +1852,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); +static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1908,6 +1910,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "pool_2d(x)", "upscale(x)", "pad(x)", + "arange(start, stop, step)", + "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", "leaky_relu(x)", @@ -1936,7 +1940,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72"); +static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -2895,11 +2899,21 @@ static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_ return ((const int32_t *)(tensor->op_params))[i]; } +static float ggml_get_op_params_f32(const struct ggml_tensor * tensor, uint32_t i) { + assert(i < GGML_MAX_OP_PARAMS / sizeof(float)); + return ((const float *)(tensor->op_params))[i]; +} + static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) { assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t)); ((int32_t *)(tensor->op_params))[i] = value; } +static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, float value) { + assert(i < GGML_MAX_OP_PARAMS / sizeof(float)); + ((float *)(tensor->op_params))[i] = value; +} + struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { memset(tensor->data, 0, ggml_nbytes(tensor)); return tensor; @@ -5898,6 +5912,55 @@ struct ggml_tensor * ggml_upscale( return ggml_upscale_impl(ctx, a, scale_factor); } +struct ggml_tensor * ggml_arange( + struct ggml_context * ctx, + float start, + float stop, + float step) { + + GGML_ASSERT(stop > start); + + const int64_t steps = (int64_t) ceilf((stop - start) / step); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps); + + result->op = GGML_OP_ARANGE; + ggml_set_op_params_f32(result, 0, start); + ggml_set_op_params_f32(result, 1, stop); + ggml_set_op_params_f32(result, 2, step); + + return result; +} + +struct ggml_tensor * ggml_timestep_embedding( + struct ggml_context * ctx, + struct ggml_tensor * timesteps, + int dim, + int max_period) { + bool is_node = false; + + if (timesteps->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + int actual_dim = dim; + if (dim % 2 != 0) { + actual_dim = dim + 1; + } + + struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]); + + result->op = GGML_OP_TIMESTEP_EMBEDDING; + ggml_set_op_params_i32(result, 0, dim); + ggml_set_op_params_i32(result, 1, max_period); + + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = timesteps; + + return result; +} + // ggml_argsort struct ggml_tensor * ggml_argsort( @@ -10231,7 +10294,7 @@ static void ggml_compute_forward_group_norm_f32( int n_channels = src0->ne[2]; int n_groups = dst->op_params[0]; int n_channels_per_group = (n_channels + n_groups - 1) / n_groups; - for (int i = ith; i < n_groups; i+=nth) { + for (int i = ith; i < n_groups; i += nth) { int start = i * n_channels_per_group; int end = start + n_channels_per_group; if (end > n_channels) { @@ -10245,28 +10308,32 @@ static void ggml_compute_forward_group_norm_f32( for (int64_t i01 = 0; i01 < ne01; i01++) { const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); + ggml_float sumr = 0.0; for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)x[i00]; + sumr += (ggml_float)x[i00]; } + sum += sumr; } } - float mean = sum / (ne00 * ne01 * step); - ggml_float sum2 = 0.0; + const float mean = sum / (ne00 * ne01 * step); + ggml_float sum2 = 0.0; for (int64_t i02 = start; i02 < end; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); + ggml_float sumr = 0.0; for (int64_t i00 = 0; i00 < ne00; i00++) { float v = x[i00] - mean; y[i00] = v; - sum2 += (ggml_float)(v * v); + sumr += (ggml_float)(v * v); } + sum2 += sumr; } } - float variance = sum2 / (ne00 * ne01 * step); + const float variance = sum2 / (ne00 * ne01 * step); const float scale = 1.0f / sqrtf(variance + eps); for (int64_t i02 = start; i02 < end; i02++) { @@ -13547,6 +13614,106 @@ static void ggml_compute_forward_pad( } } + +// ggml_compute_forward_arange + +static void ggml_compute_forward_arange_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + GGML_ASSERT(dst->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const float start = ggml_get_op_params_f32(dst, 0); + const float stop = ggml_get_op_params_f32(dst, 1); + const float step = ggml_get_op_params_f32(dst, 2); + + const int64_t steps = (int64_t) ceilf((stop - start) / step); + + GGML_ASSERT(ggml_nelements(dst) == steps); + + for (int64_t i = ith; i < steps; i+= nth) { + float value = start + step * i; + ((float *)dst->data)[i] = value; + } +} + +static void ggml_compute_forward_arange( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + switch (dst->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_arange_f32(params, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + +static void ggml_compute_forward_timestep_embedding_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + const int dim = ggml_get_op_params_i32(dst, 0); + const int max_period = ggml_get_op_params_i32(dst, 1); + + int half = dim / 2; + + for (int64_t i = 0; i < ne00; i++) { + float * embed_data = (float *)((char *) dst->data + i*nb1); + for (int64_t j = ith; j < half; j += nth) { + float timestep = ((float *)src0->data)[i]; + float freq = (float)expf(-logf(max_period) * j / half); + float arg = timestep * freq; + embed_data[j] = cosf(arg); + embed_data[j + half] = sinf(arg); + } + if (dim % 2 != 0 && ith == 0) { + embed_data[dim] = 0.f; + } + } +} + +static void ggml_compute_forward_timestep_embedding( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_timestep_embedding_f32(params, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_argsort static void ggml_compute_forward_argsort_f32( @@ -15615,6 +15782,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_pad(params, tensor); } break; + case GGML_OP_ARANGE: + { + ggml_compute_forward_arange(params, tensor); + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + ggml_compute_forward_timestep_embedding(params, tensor); + } break; case GGML_OP_ARGSORT: { ggml_compute_forward_argsort(params, tensor); @@ -16617,6 +16792,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_ARANGE: + { + GGML_ASSERT(false); // TODO: not implemented + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_ARGSORT: { GGML_ASSERT(false); // TODO: not implemented @@ -17368,6 +17551,14 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { { n_tasks = n_threads; } break; + case GGML_OP_ARANGE: + { + n_tasks = n_threads; + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + n_tasks = n_threads; + } break; case GGML_OP_ARGSORT: { n_tasks = n_threads; diff --git a/ggml.h b/ggml.h index 0a6d3c051..98cfc7bf8 100644 --- a/ggml.h +++ b/ggml.h @@ -454,6 +454,8 @@ extern "C" { GGML_OP_POOL_2D, GGML_OP_UPSCALE, // nearest interpolate GGML_OP_PAD, + GGML_OP_ARANGE, + GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, GGML_OP_LEAKY_RELU, @@ -1661,6 +1663,15 @@ extern "C" { int p2, int p3); + // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151 + // timesteps: [N,] + // return: [N, dim] + GGML_API struct ggml_tensor * ggml_timestep_embedding( + struct ggml_context * ctx, + struct ggml_tensor * timesteps, + int dim, + int max_period); + // sort rows enum ggml_sort_order { GGML_SORT_ORDER_ASC, @@ -1672,6 +1683,12 @@ extern "C" { struct ggml_tensor * a, enum ggml_sort_order order); + GGML_API struct ggml_tensor * ggml_arange( + struct ggml_context * ctx, + float start, + float stop, + float step); + // top k elements per row GGML_API struct ggml_tensor * ggml_top_k( struct ggml_context * ctx, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d4cea805f..8a6999f21 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1412,6 +1412,50 @@ struct test_pad : public test_case { } }; +// GGML_OP_ARANGE +struct test_arange : public test_case { + const ggml_type type; + const float start; + const float stop; + const float step; + + std::string vars() override { + return VARS_TO_STR4(type, start, stop, step); + } + + test_arange(ggml_type type = GGML_TYPE_F32, + float start = 0.f, float stop = 10.f, float step = 1.f) + : type(type), start(start), stop(stop), step(step) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * out = ggml_arange(ctx, start, stop, step); + return out; + } +}; + +// GGML_OP_TIMESTEP_EMBEDDING +struct test_timestep_embedding : public test_case { + const ggml_type type; + const std::array ne_a; + const int dim; + const int max_period; + + std::string vars() override { + return VARS_TO_STR4(type, ne_a, dim, max_period); + } + + test_timestep_embedding(ggml_type type = GGML_TYPE_F32, + std::array ne_a = {2, 1, 1, 1}, + int dim = 320, int max_period=10000) + : type(type), ne_a(ne_a), dim(dim), max_period(max_period) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_tensor * out = ggml_timestep_embedding(ctx, a, dim, max_period); + return out; + } +}; + // GGML_OP_LEAKY_RELU struct test_leaky_relu : public test_case { const ggml_type type; @@ -2126,6 +2170,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_group_norm()); test_cases.emplace_back(new test_acc()); test_cases.emplace_back(new test_pad()); + test_cases.emplace_back(new test_arange()); + test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); // these tests are disabled to save execution time, but they can be handy for debugging From a0fc62661f0fd2a9edd10ae5617345bbbf972f42 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 4 Mar 2024 10:40:04 +0200 Subject: [PATCH 5/8] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 389c0bdfe..4a1b0bab4 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -b458250b736a7473f7ff3560d47c93f1644f3290 +274680868e12427373bab4bec87554431b954704 From 4ffcdce2ff877ebb683cd217ea38faf20faa5ffe Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 4 Mar 2024 12:22:08 +0100 Subject: [PATCH 6/8] add alias for chat template (#5858) --- examples/server/server.cpp | 2 +- llama.cpp | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 0ca388f47..208edd571 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -413,7 +413,7 @@ struct llama_server_context int res = llama_chat_apply_template(model, nullptr, chat, 1, true, buf.data(), buf.size()); if (res < 0) { LOG_ERROR("The chat template comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {}); - sparams.chat_template = "<|im_start|>"; // llama_chat_apply_template only checks if <|im_start|> exist in the template + sparams.chat_template = "chatml"; } } diff --git a/llama.cpp b/llama.cpp index c1f015791..de579d9e3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13282,7 +13282,7 @@ static int32_t llama_chat_apply_template_internal( std::string & dest, bool add_ass) { // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 std::stringstream ss; - if (tmpl.find("<|im_start|>") != std::string::npos) { + if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) { // chatml template for (auto message : chat) { ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n"; @@ -13290,7 +13290,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|im_start|>assistant\n"; } - } else if (tmpl.find("[INST]") != std::string::npos) { + } else if (tmpl == "llama2" || tmpl.find("[INST]") != std::string::npos) { // llama2 template and its variants // [variant] support system message bool support_system_message = tmpl.find("<>") != std::string::npos; @@ -13325,7 +13325,7 @@ static int32_t llama_chat_apply_template_internal( } } // llama2 templates seem to not care about "add_generation_prompt" - } else if (tmpl.find("<|user|>") != std::string::npos) { + } else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) { // zephyr template for (auto message : chat) { ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n"; @@ -13333,7 +13333,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>\n"; } - } else if (tmpl.find("bos_token + message['role']") != std::string::npos) { + } else if (tmpl == "monarch" || tmpl.find("bos_token + message['role']") != std::string::npos) { // mlabonne/AlphaMonarch-7B template (the is included inside history) for (auto message : chat) { std::string bos = (message == chat.front()) ? "" : ""; // skip BOS for first message @@ -13342,7 +13342,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "assistant\n"; } - } else if (tmpl.find("") != std::string::npos) { + } else if (tmpl == "gemma" || tmpl.find("") != std::string::npos) { // google/gemma-7b-it std::string system_prompt = ""; for (auto message : chat) { @@ -13389,7 +13389,7 @@ LLAMA_API int32_t llama_chat_apply_template( int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); if (res < 0) { // worst case: there is no information about template, we will use chatml by default - curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal + curr_tmpl = "chatml"; // see llama_chat_apply_template_internal } else { curr_tmpl = std::string(model_template.data(), model_template.size()); } From 6d341ab6c53cd51f2921d986d0090cc8b049b39a Mon Sep 17 00:00:00 2001 From: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> Date: Tue, 5 Mar 2024 03:24:00 +0900 Subject: [PATCH 7/8] speculative : implement stochastic speculative sampling (#5625) * (WIP) Implement stochastic speculative decoding * sample from residual distribution on draft accept failure * fix #5657: force greedy sampling with probs when temp is 0 * remove p_accept parameter * fix style * remove unused variables * add srand() in speculative.cpp * replace use of rand() with mt19937 sampling * fixes based on review (@JohannesGaessler) * fix r random generation * randomly select next sequence to verify + fix bug in memory freeing * fix bug in active_seqs sync * fix uniform int distribution initialization * remove warnings from comparison between int and size_t * check grammar in `llama_sample_probability_distribution_impl` * remove malloc code by utilizing vectors * add PR link to README --- common/common.cpp | 7 - common/common.h | 3 +- common/sampling.cpp | 79 ++++++++++ common/sampling.h | 7 + examples/speculative/README.md | 1 + examples/speculative/speculative.cpp | 224 ++++++++++++++++++++------- 6 files changed, 260 insertions(+), 61 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index dbe7e9229..036a98134 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -513,12 +513,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.n_sequences = std::stoi(argv[i]); - } else if (arg == "--p-accept" || arg == "-pa") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.p_accept = std::stof(argv[i]); } else if (arg == "--p-split" || arg == "-ps") { if (++i >= argc) { invalid_param = true; @@ -1044,7 +1038,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel); printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); - printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept); printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split); printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); diff --git a/common/common.h b/common/common.h index b2868833b..977ce419f 100644 --- a/common/common.h +++ b/common/common.h @@ -53,11 +53,10 @@ struct gpt_params { int32_t n_ctx = 512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 8; // number of tokens to draft during speculative decoding + int32_t n_draft = 5; // number of tokens to draft during speculative decoding int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) int32_t n_parallel = 1; // number of parallel sequences to decode int32_t n_sequences = 1; // number of sequences to decode - float p_accept = 0.5f; // speculative decoding accept probability float p_split = 0.1f; // speculative decoding split probability int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) diff --git a/common/sampling.cpp b/common/sampling.cpp index e67096bea..823031feb 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -295,6 +295,77 @@ static llama_token llama_sampling_sample_impl( return id; } +static llama_token_data_array llama_sample_probability_distribution_impl( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + const int idx) { + const llama_sampling_params & params = ctx_sampling->params; + + const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + + const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; + const float penalty_repeat = params.penalty_repeat; + const float penalty_freq = params.penalty_freq; + const float penalty_present = params.penalty_present; + const bool penalize_nl = params.penalize_nl; + + auto & prev = ctx_sampling->prev; + auto & cur = ctx_sampling->cur; + + // Get a pointer to the logits + float * logits = llama_get_logits_ith(ctx_main, idx); + + // Declare original_logits at the beginning of the function scope + std::vector original_logits; + + // apply params.logit_bias map + for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { + logits[it->first] += it->second; + } + + if (ctx_cfg) { + float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); + llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale); + } + + cur.clear(); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + + // apply penalties + const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev; + const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); + if (penalty_tokens_used_size) { + const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; + + llama_sample_repetition_penalties(ctx_main, &cur_p, + penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, + penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); + + if (!penalize_nl) { + for (size_t idx = 0; idx < cur_p.size; idx++) { + if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { + cur_p.data[idx].logit = nl_logit; + break; + } + } + } + } + + // apply grammar checks + if (ctx_sampling->grammar != NULL) { + llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar); + } + + llama_sample_softmax(ctx_main, &cur_p); + return cur_p; +} + llama_token llama_sampling_sample( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, @@ -304,6 +375,14 @@ llama_token llama_sampling_sample( return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false); } +llama_token_data_array llama_sampling_probability_distribution( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + const int idx) { + return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx); +} + void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, diff --git a/common/sampling.h b/common/sampling.h index 95d875394..48b2459d1 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -131,6 +131,13 @@ llama_token llama_sampling_sample( struct llama_context * ctx_cfg, int idx = 0); +// returns the probability that token of given id will be sampled +llama_token_data_array llama_sampling_probability_distribution( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + int idx = 0); + void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, diff --git a/examples/speculative/README.md b/examples/speculative/README.md index 814efa592..a6608c5fe 100644 --- a/examples/speculative/README.md +++ b/examples/speculative/README.md @@ -6,3 +6,4 @@ More info: - https://github.com/ggerganov/llama.cpp/pull/2926 - https://github.com/ggerganov/llama.cpp/pull/3624 +- https://github.com/ggerganov/llama.cpp/pull/5625 diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 3848791d4..85bc0a762 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 @@ -18,6 +19,7 @@ struct seq_draft { std::vector i_batch_tgt; std::vector tokens; + std::vector> dists; struct llama_sampling_context * ctx_sampling; }; @@ -37,12 +39,15 @@ int main(int argc, char ** argv) { // max number of parallel drafting sequences (i.e. tree branches) const int n_seq_dft = params.n_parallel; - // probability threshold for accepting a token from the draft model - const float p_accept = params.p_accept; - // probability threshold for splitting a draft branch (only for n_seq_dft > 1) const float p_split = params.p_split; + if (params.seed == LLAMA_DEFAULT_SEED) { + params.seed = time(NULL); + } + std::default_random_engine rng(params.seed); + std::uniform_real_distribution<> u_dist; + #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("speculative", "log")); LOG_TEE("Log start\n"); @@ -166,7 +171,9 @@ int main(int argc, char ** argv) { std::vector drafts(n_seq_dft); params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar - params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model + if (params.sparams.temp == 0) { + params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model + } for (int s = 0; s < n_seq_dft; ++s) { drafts[s].ctx_sampling = llama_sampling_init(params.sparams); @@ -182,12 +189,15 @@ int main(int argc, char ** argv) { drafts[0].i_batch_tgt[0] = 0; while (true) { + std::set active_seqs = {}; + // print current draft sequences for (int s = 0; s < n_seq_dft; ++s) { if (!drafts[s].active) { continue; } + active_seqs.insert(s); const auto & tokens = drafts[s].tokens; LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str()); @@ -196,48 +206,156 @@ int main(int argc, char ** argv) { int i_dft = 0; int s_keep = 0; + llama_token token_id; + std::string token_str; + + // loop until we fail to accept a drafted token or we run out of drafted tokens while (true) { - LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); - - // sample from the target model - llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); - - llama_sampling_accept(ctx_sampling, ctx_tgt, id, true); - - //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); - - const std::string token_str = llama_token_to_piece(ctx_tgt, id); - - if (!params.use_color) { - printf("%s", token_str.c_str()); - } - - if (id == llama_token_eos(model_tgt)) { - has_eos = true; - } - - ++n_predict; // check if the target token matches any of the drafts + // for stochastic sampling, attempt to match the token with the drafted tokens { - bool matches = false; + bool accept = false; + if (params.sparams.temp > 0) { + // stochastic verification - for (int s = 0; s < n_seq_dft; ++s) { - if (!drafts[s].active) { - continue; + llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + float p_tgt = 0, p_dft = 0; + + // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); + + while (active_seqs.size() > 0) { + // randomly select a sequence to verify from active sequences + std::uniform_int_distribution u_int_dist(0, active_seqs.size() - 1); + int s = *std::next(active_seqs.begin(), u_int_dist(rng)); + if (i_dft >= (int) drafts[s].tokens.size()) { + drafts[s].active = false; + active_seqs.erase(s); + continue; + } + if (accept) { + // if we already accepted a token, we can skip the rest + if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) { + drafts[s].active = false; + active_seqs.erase(s); + } + continue; + } + LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size()); + float r = u_dist(rng); + llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true }; + // acquire the token probabilities assigned by the draft and target models + for (size_t i = 0; i < dist_tgt.size; i++) { + if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) { + p_tgt = dist_tgt.data[i].p; + } + if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) { + p_dft = dist_dft.data[i].p; + } + if (p_tgt && p_dft) { + break; + } + } + LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt); + if (r <= p_tgt / p_dft) { + s_keep = s; + accept = true; + token_id = drafts[s].tokens[i_dft]; + token_str = llama_token_to_piece(ctx_tgt, token_id); + llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + + LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str()); + break; + } else { + LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str()); + drafts[s].active = false; + + // calculate residual probability + GGML_ASSERT(dist_tgt.sorted); + GGML_ASSERT(dist_dft.sorted); + float sum_probs = 0.0f; + + // sort dist by id + std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.id < b.id; + }); + std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.id < b.id; + }); + + for (size_t i = 0; i < dist_tgt.size; i++) { + dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p); + sum_probs += dist_tgt.data[i].p; + } + for (size_t i = 0; i < dist_tgt.size; i++) { + dist_tgt.data[i].p /= sum_probs; + } + + // sort dist_tgt by p desc + std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.p > b.p; + }); + } + + active_seqs.erase(s); + for(int i = 0; i < n_seq_dft; i++) { + if (i == s) { + continue; + } + if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) { + // synchronize active status for sequences with the same drafted token + drafts[i].active = drafts[i].active && accept; + if (!drafts[i].active) { + active_seqs.erase(s); + } + } + } } - if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) { - LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str()); + if (!accept) { + // all drafted tokens were rejected + // sample from the target model + LOG("all drafted tokens were rejected, sampling from residual distribution\n"); + token_id = llama_sample_token(ctx_tgt, &dist_tgt); + llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + token_str = llama_token_to_piece(ctx_tgt, token_id); + } - s_keep = s; - matches = true; - } else { - drafts[s].active = false; + } else { + // greedy verification + + // sample from the target model + LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); + token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + + llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + + //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); + + token_str = llama_token_to_piece(ctx_tgt, token_id); + + for (int s = 0; s < n_seq_dft; ++s) { + if (!drafts[s].active) { + continue; + } + + if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) { + LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str()); + + s_keep = s; + accept = true; + } else { + drafts[s].active = false; + } } } - if (matches) { + if (token_id == llama_token_eos(model_tgt)) { + has_eos = true; + } + ++n_predict; + + if (accept) { ++n_accept; ++n_past_tgt; ++n_past_dft; @@ -245,17 +363,21 @@ int main(int argc, char ** argv) { if (params.use_color) { // Color token according to its origin sequence printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str()); - fflush(stdout); + } else { + printf("%s", token_str.c_str()); } + fflush(stdout); continue; + } else { + printf("%s", token_str.c_str()); + fflush(stdout); + break; } } - if (params.use_color) { - printf("%s", token_str.c_str()); - } - fflush(stdout); + } - LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); + { + LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str()); // TODO: simplify { @@ -275,21 +397,21 @@ int main(int argc, char ** argv) { drafts[s].active = false; drafts[s].tokens.clear(); drafts[s].i_batch_tgt.clear(); + drafts[s].dists.clear(); } // note: will be erased after the speculation phase - drafts[0].tokens.push_back(id); + drafts[0].tokens.push_back(token_id); + drafts[0].dists.push_back(std::vector()); drafts[0].i_batch_tgt.push_back(0); llama_batch_clear(batch_dft); - llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true); + llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); - llama_decode (ctx_dft, batch_dft); + llama_decode(ctx_dft, batch_dft); ++n_past_dft; - - break; } if (n_predict > params.n_predict || has_eos) { @@ -334,12 +456,6 @@ int main(int argc, char ** argv) { k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str()); } - if (cur_p[0].p < p_accept) { - LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept); - drafts[s].drafting = false; - continue; - } - std::vector sa(1, s); // attempt to split the branch if the probability is high enough @@ -367,6 +483,7 @@ int main(int argc, char ** argv) { drafts[n_seq_cur].skip = true; drafts[n_seq_cur].tokens = drafts[s].tokens; + drafts[n_seq_cur].dists = drafts[s].dists; drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; @@ -389,6 +506,8 @@ int main(int argc, char ** argv) { llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true); drafts[s].tokens.push_back(id); + // save cur_p.data into drafts[s].dists + drafts[s].dists.push_back(cur_p); // add unique drafted tokens to the target batch drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); @@ -440,6 +559,7 @@ int main(int argc, char ** argv) { } drafts[s].tokens.erase(drafts[s].tokens.begin()); + drafts[s].dists.erase(drafts[s].dists.begin()); } } From fe52be11e35358d2fd249f19d7ef5b6f9c08b16b Mon Sep 17 00:00:00 2001 From: Dane Madsen Date: Tue, 5 Mar 2024 05:26:55 +1100 Subject: [PATCH 8/8] cmake : handle cases where git index is not found in .git (#5844) * Update CMakeLists.txt * Update CMakeLists.txt --- common/CMakeLists.txt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index f79acfef1..350bbdf7f 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -19,7 +19,12 @@ if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../.git") endif() endif() - set(GIT_INDEX "${GIT_DIR}/index") + if(EXISTS "${GIT_DIR}/index") + set(GIT_INDEX "${GIT_DIR}/index") + else() + message(WARNING "Git index not found in git repository.") + set(GIT_INDEX "") + endif() else() message(WARNING "Git repository not found; to enable automatic generation of build info, make sure Git is installed and the project is a Git repository.") set(GIT_INDEX "")