mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 10:54:36 +00:00
ggml : add ALiBi support for ggml_soft_max_ext (#5488)
* ggml : avoid recomputing alibi slopes (CPU) * llama : reuse hparams.f_max_alibi_bias in all cases ggml-ci * ggml : support alibi bias in ggml_soft_max_ext (CPU + Metal) ggml-ci * ggml : handle all SRCs (do not break on first null) ggml-ci * tests : do not use slope for large soft_max accumulates too much error ggml-ci * ggml : alternative ALiBi without extra tensor We compute the slopes in the kernel ggml-ci * cuda : add ALiBi support in ggml_soft_max_ext ggml-ci * ggml : deprecate ggml_alibi * ggml : support multi-sequence ALiBi (Metal) ggml-ci * cuda : add multi-seq ALiBi + remote F16 soft_max ggml-ci * ggml : update deprecation message * ggml : fix pos ptr when no ALiBi ggml-ci * cuda : fix performance (pow -> powf) * cuda : precompute ALiBi constants * metal : pre-compute ALiBi slopes ggml-ci * llama : init kq_pos only if needed ggml-ci * test-backend-ops : add null pos test to soft_max test-backend-ops : replace soft_max tests ggml-ci --------- Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
6e4e973b26
commit
8f1be0d42f
@ -551,7 +551,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
||||
}
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
if (graph->nodes[i]->src[j] == NULL) {
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
if (graph->nodes[i]->src[j]->flags & GGML_TENSOR_FLAG_INPUT) {
|
||||
ggml_gallocr_allocate_node(galloc, graph->nodes[i]->src[j], get_node_buffer_id(node_buffer_ids, i));
|
||||
@ -787,7 +787,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
struct ggml_tensor * src = node->src[j];
|
||||
if (src == NULL) {
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
if (!ggml_gallocr_node_needs_realloc(galloc, src, node_alloc, &node_alloc->src[j])) {
|
||||
#ifndef NDEBUG
|
||||
@ -833,7 +833,7 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph)
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
struct ggml_tensor * src = node->src[j];
|
||||
if (src == NULL) {
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
ggml_gallocr_init_tensor(galloc, src, node_alloc, &node_alloc->src[j]);
|
||||
}
|
||||
|
@ -1041,7 +1041,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
const struct ggml_tensor * src = tensor->src[i];
|
||||
if (src == NULL) {
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
|
||||
int src_backend = ggml_backend_sched_backend_from_buffer(sched, src->buffer);
|
||||
@ -1088,7 +1088,7 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
struct ggml_tensor * src = node->src[j];
|
||||
if (src == NULL) {
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
ggml_backend_t src_backend = tensor_backend(src);
|
||||
fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name,
|
||||
@ -1144,7 +1144,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
struct ggml_tensor * src = node->src[j];
|
||||
if (src == NULL) {
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
if (tensor_backend_id(src) == -1) {
|
||||
tensor_backend_id(src) = ggml_backend_sched_backend_id_from_cur(sched, src);
|
||||
@ -1256,7 +1256,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
struct ggml_tensor * src = node->src[j];
|
||||
if (src == NULL) {
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
int src_backend_id = tensor_backend_id(src);
|
||||
if (src_backend_id == -1) {
|
||||
@ -1315,7 +1315,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
struct ggml_tensor * src = node->src[j];
|
||||
if (src == NULL) {
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
int src_backend_id = tensor_backend_id(src);
|
||||
assert(src_backend_id != -1); // all inputs should be assigned by now
|
||||
@ -1362,7 +1362,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
struct ggml_tensor * src = node->src[j];
|
||||
if (src == NULL) {
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
ggml_backend_t src_backend = tensor_backend(src);
|
||||
if (src_backend != tensor_backend /* && src_backend != NULL */) {
|
||||
@ -1668,7 +1668,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set,
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
struct ggml_tensor * s = src->src[i];
|
||||
if (s == NULL) {
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
dst->src[i] = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
|
||||
}
|
||||
@ -1697,7 +1697,7 @@ static void graph_copy_init_tensor(struct ggml_hash_set hash_set, struct ggml_te
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
struct ggml_tensor * s = src->src[i];
|
||||
if (s == NULL) {
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
graph_copy_init_tensor(hash_set, node_copies, node_init, s);
|
||||
}
|
||||
|
263
ggml-cuda.cu
263
ggml-cuda.cu
@ -5956,149 +5956,31 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
|
||||
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
||||
}
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
|
||||
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int rowx = blockIdx.x;
|
||||
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
||||
|
||||
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
|
||||
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
||||
extern __shared__ half data_soft_max_f16[];
|
||||
half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
|
||||
// (shared memory) buffer to cache values between iterations:
|
||||
half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
|
||||
// if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
|
||||
// in that case col_smem == col_data must be enforced to avoid race conditions
|
||||
|
||||
half2 max_val = make_half2(-INFINITY, -INFINITY);
|
||||
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
|
||||
const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
|
||||
const int col_smem = vals_smem ? col0 + tid : col_data;
|
||||
|
||||
const int ix = rowx*ncols_data + col_data;
|
||||
const int iy = rowy*ncols_data + col_data;
|
||||
|
||||
half2 val;
|
||||
if (need_check && col_data + 0 >= ncols_data) {
|
||||
val.x = -INFINITY;
|
||||
} else {
|
||||
val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
|
||||
}
|
||||
if (need_check && col_data + WARP_SIZE >= ncols_data) {
|
||||
val.y = -INFINITY;
|
||||
} else {
|
||||
val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
|
||||
}
|
||||
if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
|
||||
vals[col_smem] = val;
|
||||
}
|
||||
max_val = __hmax2(max_val, val);
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
max_val = warp_reduce_max(max_val);
|
||||
if (block_size > WARP_SIZE) {
|
||||
if (warp_id == 0) {
|
||||
buf_iw[lane_id] = -INFINITY;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (lane_id == 0) {
|
||||
buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
max_val = __half2half2(buf_iw[lane_id]);
|
||||
max_val = warp_reduce_max(max_val);
|
||||
} else {
|
||||
max_val = __half2half2(__hmax(max_val.x, max_val.y));
|
||||
}
|
||||
|
||||
half2 tmp = make_half2(0.0f, 0.0f); // partial sums
|
||||
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
|
||||
const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;
|
||||
|
||||
if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
|
||||
break;
|
||||
}
|
||||
|
||||
const half2 val = h2exp(vals[col_smem] - max_val);
|
||||
|
||||
tmp += val;
|
||||
vals[col_smem] = val;
|
||||
}
|
||||
|
||||
// find the sum of exps in the block
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
if (block_size > WARP_SIZE) {
|
||||
if (warp_id == 0) {
|
||||
buf_iw[lane_id] = 0.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (lane_id == 0) {
|
||||
buf_iw[warp_id] = tmp.x + tmp.y;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
tmp = __half2half2(buf_iw[lane_id]);
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
} else {
|
||||
tmp = __half2half2(tmp.x + tmp.y);
|
||||
}
|
||||
|
||||
const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
|
||||
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
|
||||
const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
|
||||
const int col_smem = vals_smem ? col0 + tid : col_data;
|
||||
|
||||
const int idst = rowx*ncols_data + col_data;
|
||||
const half2 result = vals[col_smem] * inv_sum;
|
||||
|
||||
if (need_check && col_data + 0 >= ncols_data) {
|
||||
return;
|
||||
}
|
||||
dst[idst] = result.x;
|
||||
|
||||
if (need_check && col_data + WARP_SIZE >= ncols_data) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[idst + WARP_SIZE] = result.y;
|
||||
}
|
||||
#else
|
||||
(void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||
}
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
||||
static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
|
||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int rowx = blockIdx.x;
|
||||
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
||||
const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
|
||||
|
||||
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
|
||||
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
||||
float slope = 0.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const int h = rowx/nrows_y; // head index
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = powf(base, exp);
|
||||
}
|
||||
|
||||
extern __shared__ float data_soft_max_f32[];
|
||||
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
||||
// shared memory buffer to cache values between iterations:
|
||||
@ -6117,7 +5999,8 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
|
||||
const int ix = rowx*ncols + col;
|
||||
const int iy = rowy*ncols + col;
|
||||
|
||||
const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
|
||||
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + slope*pos[col];
|
||||
|
||||
vals[col] = val;
|
||||
max_val = max(max_val, val);
|
||||
}
|
||||
@ -7589,89 +7472,53 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
|
||||
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
|
||||
}
|
||||
|
||||
static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
||||
int nth = WARP_SIZE;
|
||||
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||
const dim3 block_dims(nth, 1, 1);
|
||||
const dim3 block_nums(nrows_x, 1, 1);
|
||||
const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
|
||||
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
||||
if (shmem <= g_device_caps[g_main_device].smpb) {
|
||||
switch (ncols_x) {
|
||||
case 32:
|
||||
soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 64:
|
||||
soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 128:
|
||||
soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 256:
|
||||
soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 512:
|
||||
soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 1024:
|
||||
soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 2048:
|
||||
soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
case 4096:
|
||||
soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
default:
|
||||
soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
const size_t shmem_low = WARP_SIZE*sizeof(half);
|
||||
soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
}
|
||||
}
|
||||
|
||||
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
||||
static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
|
||||
int nth = WARP_SIZE;
|
||||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||
const dim3 block_dims(nth, 1, 1);
|
||||
const dim3 block_nums(nrows_x, 1, 1);
|
||||
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
|
||||
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
||||
|
||||
const uint32_t n_head_kv = nrows_x/nrows_y;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
if (shmem < g_device_caps[g_main_device].smpb) {
|
||||
switch (ncols_x) {
|
||||
case 32:
|
||||
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||
break;
|
||||
case 64:
|
||||
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||
break;
|
||||
case 128:
|
||||
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||
break;
|
||||
case 256:
|
||||
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||
break;
|
||||
case 512:
|
||||
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||
break;
|
||||
case 1024:
|
||||
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||
break;
|
||||
case 2048:
|
||||
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||
break;
|
||||
case 4096:
|
||||
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||
break;
|
||||
default:
|
||||
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
const size_t shmem_low = WARP_SIZE*sizeof(float);
|
||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
||||
}
|
||||
}
|
||||
|
||||
@ -9090,30 +8937,36 @@ static void ggml_cuda_op_soft_max(
|
||||
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
|
||||
const int64_t nrows_y = src0->ne[1];
|
||||
|
||||
float scale = 1.0f;
|
||||
memcpy(&scale, dst->op_params, sizeof(float));
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_HMAX
|
||||
#ifdef GGML_CUDA_F16
|
||||
const bool use_f16_soft_max = true;
|
||||
#else
|
||||
const bool use_f16_soft_max = false;
|
||||
#endif // GGML_CUDA_F16
|
||||
#else
|
||||
const bool use_f16_soft_max = false;
|
||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX
|
||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
||||
|
||||
if (use_f16_soft_max) {
|
||||
soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
||||
} else {
|
||||
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
|
||||
// positions tensor
|
||||
float * src2_dd = dst_dd; // default to avoid null checks in the kernel
|
||||
cuda_pool_alloc<float> src2_f;
|
||||
|
||||
ggml_tensor * src2 = dst->src[2];
|
||||
const bool use_src2 = src2 != nullptr;
|
||||
|
||||
if (use_src2) {
|
||||
const bool src2_on_device = use_src2 && src2->backend == GGML_BACKEND_GPU;
|
||||
ggml_tensor_extra_gpu * src2_extra = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
|
||||
|
||||
if (src2_on_device) {
|
||||
src2_dd = (float *) src2_extra->data_device[g_main_device];
|
||||
} else {
|
||||
src2_dd = src2_f.alloc(ggml_nelements(src2));
|
||||
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
|
||||
}
|
||||
}
|
||||
|
||||
(void) dst;
|
||||
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream);
|
||||
}
|
||||
|
||||
static void ggml_cuda_op_scale(
|
||||
|
35
ggml-metal.m
35
ggml-metal.m
@ -728,6 +728,7 @@ static bool ggml_metal_graph_compute(
|
||||
|
||||
size_t offs_src0 = 0;
|
||||
size_t offs_src1 = 0;
|
||||
size_t offs_src2 = 0;
|
||||
size_t offs_dst = 0;
|
||||
|
||||
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
||||
@ -746,6 +747,7 @@ static bool ggml_metal_graph_compute(
|
||||
|
||||
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
||||
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
||||
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
||||
struct ggml_tensor * dst = gf->nodes[i];
|
||||
|
||||
switch (dst->op) {
|
||||
@ -807,6 +809,7 @@ static bool ggml_metal_graph_compute(
|
||||
|
||||
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
|
||||
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
|
||||
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
|
||||
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
||||
|
||||
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
||||
@ -1188,7 +1191,16 @@ static bool ggml_metal_graph_compute(
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
||||
}
|
||||
|
||||
const float scale = ((float *) dst->op_params)[0];
|
||||
const float scale = ((float *) dst->op_params)[0];
|
||||
const float max_bias = ((float *) dst->op_params)[1];
|
||||
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src0->ne[1];
|
||||
const uint32_t n_head_kv = nrows_x/nrows_y;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
@ -1197,11 +1209,20 @@ static bool ggml_metal_graph_compute(
|
||||
} else {
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
}
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:6];
|
||||
if (id_src2) {
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
} else {
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
||||
}
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
|
||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:7];
|
||||
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
|
||||
[encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
|
||||
[encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
|
||||
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
@ -1514,8 +1535,6 @@ static bool ggml_metal_graph_compute(
|
||||
// max size of the src1ids array in the kernel stack
|
||||
GGML_ASSERT(ne11 <= 512);
|
||||
|
||||
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
||||
|
||||
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
||||
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
||||
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
||||
|
@ -351,12 +351,17 @@ kernel void kernel_sum_rows(
|
||||
kernel void kernel_soft_max(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device const float * src2,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant float & scale,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
@ -368,13 +373,26 @@ kernel void kernel_soft_max(
|
||||
|
||||
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
||||
device const float * ppos = src2 != src0 ? src2 : nullptr;
|
||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
float slope = 0.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const int64_t h = i02;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float lmax = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
@ -399,7 +417,7 @@ kernel void kernel_soft_max(
|
||||
// parallel sum
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
||||
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
||||
lsum += exp_psrc0;
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
@ -437,12 +455,17 @@ kernel void kernel_soft_max(
|
||||
kernel void kernel_soft_max_4(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device const float * src2,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant float & scale,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
@ -454,13 +477,25 @@ kernel void kernel_soft_max_4(
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
||||
device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
|
||||
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
|
||||
float slope = 0.0f;
|
||||
|
||||
if (max_bias > 0.0f) {
|
||||
const int64_t h = i02;
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// parallel max
|
||||
float4 lmax4 = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
||||
}
|
||||
|
||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||
@ -486,7 +521,7 @@ kernel void kernel_soft_max_4(
|
||||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
}
|
||||
|
118
ggml.c
118
ggml.c
@ -5096,16 +5096,28 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * mask,
|
||||
struct ggml_tensor * pos,
|
||||
float scale,
|
||||
float max_bias,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
|
||||
if (mask) {
|
||||
GGML_ASSERT(ggml_is_contiguous(mask));
|
||||
GGML_ASSERT(mask->ne[2] == 1);
|
||||
GGML_ASSERT(mask->ne[3] == 1);
|
||||
GGML_ASSERT(ggml_is_matrix(mask));
|
||||
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
|
||||
}
|
||||
|
||||
if (pos) {
|
||||
GGML_ASSERT(ggml_is_vector(pos));
|
||||
GGML_ASSERT(pos->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(pos->ne[0] == a->ne[0]);
|
||||
}
|
||||
|
||||
if (max_bias > 0.0f) {
|
||||
GGML_ASSERT(pos);
|
||||
}
|
||||
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
@ -5114,13 +5126,14 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
float params[] = { scale };
|
||||
float params[] = { scale, max_bias };
|
||||
ggml_set_op_params(result, params, sizeof(params));
|
||||
|
||||
result->op = GGML_OP_SOFT_MAX;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = a;
|
||||
result->src[1] = mask;
|
||||
result->src[2] = pos;
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -5128,21 +5141,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
||||
struct ggml_tensor * ggml_soft_max(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
|
||||
return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_soft_max_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
|
||||
return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_soft_max_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * mask,
|
||||
float scale) {
|
||||
return ggml_soft_max_impl(ctx, a, mask, scale, false);
|
||||
struct ggml_tensor * pos,
|
||||
float scale,
|
||||
float max_bias) {
|
||||
return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
|
||||
}
|
||||
|
||||
// ggml_soft_max_back
|
||||
@ -11495,6 +11510,7 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
const struct ggml_tensor * src2,
|
||||
struct ggml_tensor * dst) {
|
||||
assert(ggml_is_contiguous(dst));
|
||||
assert(ggml_are_same_shape(src0, dst));
|
||||
@ -11503,16 +11519,29 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
return;
|
||||
}
|
||||
|
||||
float scale = 1.0f;
|
||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
||||
|
||||
// TODO: handle transposed/permuted matrices
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
|
||||
const int64_t ne11 = src1 ? src1->ne[1] : 1;
|
||||
|
||||
// TODO: is this supposed to be ceil instead of floor?
|
||||
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
|
||||
const uint32_t n_head_kv = ne02;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const int nc = src0->ne[0];
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
@ -11525,6 +11554,9 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
|
||||
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
||||
|
||||
// when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
|
||||
float * pos = src2 ? (float *) src2->data : src0->data;
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
||||
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||
@ -11538,6 +11570,16 @@ static void ggml_compute_forward_soft_max_f32(
|
||||
ggml_vec_acc_f32(nc, wp, mp);
|
||||
}
|
||||
|
||||
// ALiBi bias
|
||||
if (max_bias > 0.0f) {
|
||||
const uint32_t h = (i1/ne01)%ne02; // head
|
||||
const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
|
||||
|
||||
for (int i = 0; i < nc; i++) {
|
||||
wp[i] = wp[i] + slope*pos[i];
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int i = 0; i < nc; ++i) {
|
||||
//printf("p[%d] = %f\n", i, p[i]);
|
||||
@ -11582,11 +11624,12 @@ static void ggml_compute_forward_soft_max(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
const struct ggml_tensor * src2,
|
||||
struct ggml_tensor * dst) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
|
||||
ggml_compute_forward_soft_max_f32(params, src0, src1, src2, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
@ -11730,22 +11773,20 @@ static void ggml_compute_forward_alibi_f32(
|
||||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
||||
|
||||
for (int64_t i = 0; i < ne0; i++) {
|
||||
for (int64_t j = 0; j < ne1; j++) {
|
||||
for (int64_t k = 0; k < ne2_ne3; k++) {
|
||||
for (int64_t k = 0; k < ne2_ne3; k++) {
|
||||
// TODO: k*nb2 or k*nb3
|
||||
float m_k;
|
||||
|
||||
if (k < n_heads_log2_floor) {
|
||||
m_k = powf(m0, k + 1);
|
||||
} else {
|
||||
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < ne0; i++) {
|
||||
for (int64_t j = 0; j < ne1; j++) {
|
||||
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
||||
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
||||
|
||||
// TODO: k*nb2 or k*nb3
|
||||
|
||||
float m_k;
|
||||
|
||||
if (k < n_heads_log2_floor) {
|
||||
m_k = powf(m0, k + 1);
|
||||
} else {
|
||||
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
||||
}
|
||||
|
||||
pdst[0] = i * m_k + src[0];
|
||||
}
|
||||
}
|
||||
@ -11790,21 +11831,20 @@ static void ggml_compute_forward_alibi_f16(
|
||||
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
||||
|
||||
for (int i = 0; i < ne0; i++) {
|
||||
for (int j = 0; j < ne1; j++) {
|
||||
for (int k = 0; k < ne2_ne3; k++) {
|
||||
for (int k = 0; k < ne2_ne3; k++) {
|
||||
// TODO: k*nb2 or k*nb3
|
||||
float m_k;
|
||||
|
||||
if (k < n_heads_log2_floor) {
|
||||
m_k = powf(m0, k + 1);
|
||||
} else {
|
||||
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
||||
}
|
||||
|
||||
for (int i = 0; i < ne0; i++) {
|
||||
for (int j = 0; j < ne1; j++) {
|
||||
ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
||||
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
||||
|
||||
// TODO: k*nb2 or k*nb3
|
||||
|
||||
float m_k;
|
||||
|
||||
if (k < n_heads_log2_floor) {
|
||||
m_k = powf(m0, k + 1);
|
||||
} else {
|
||||
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
||||
}
|
||||
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
||||
|
||||
// we return F32
|
||||
pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
|
||||
@ -15116,7 +15156,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
|
||||
ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
{
|
||||
|
13
ggml.h
13
ggml.h
@ -1383,13 +1383,17 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// fused soft_max(a*scale + mask)
|
||||
// fused soft_max(a*scale + mask + pos[i]*(ALiBi slope))
|
||||
// mask is optional
|
||||
// pos is required when max_bias > 0.0f
|
||||
// max_bias = 0.0f for no ALiBi
|
||||
GGML_API struct ggml_tensor * ggml_soft_max_ext(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * mask,
|
||||
float scale);
|
||||
struct ggml_tensor * pos,
|
||||
float scale,
|
||||
float max_bias);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_soft_max_back(
|
||||
struct ggml_context * ctx,
|
||||
@ -1491,12 +1495,13 @@ extern "C" {
|
||||
|
||||
// alibi position embedding
|
||||
// in-place, returns view(a)
|
||||
GGML_API struct ggml_tensor * ggml_alibi(
|
||||
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past,
|
||||
int n_head,
|
||||
float bias_max);
|
||||
float bias_max),
|
||||
"use ggml_soft_max_ext instead (will be removed in Mar 2024)");
|
||||
|
||||
// clamp
|
||||
// in-place, returns view(a)
|
||||
|
133
llama.cpp
133
llama.cpp
@ -1557,12 +1557,13 @@ struct llama_hparams {
|
||||
uint32_t n_yarn_orig_ctx;
|
||||
int32_t rope_scaling_type_train;
|
||||
|
||||
float f_clamp_kqv;
|
||||
float f_max_alibi_bias;
|
||||
float f_clamp_kqv = 0.0f;
|
||||
float f_max_alibi_bias = 0.0f;
|
||||
|
||||
bool causal_attn = true;
|
||||
uint32_t pooling_type = LLAMA_POOLING_NONE;
|
||||
bool need_kq_pos = false;
|
||||
|
||||
uint32_t pooling_type = LLAMA_POOLING_NONE;
|
||||
|
||||
bool operator!=(const llama_hparams & other) const {
|
||||
if (this->vocab_only != other.vocab_only) return true;
|
||||
@ -1923,6 +1924,7 @@ struct llama_context {
|
||||
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
|
||||
struct ggml_tensor * inp_pos; // I32 [n_batch]
|
||||
struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch]
|
||||
struct ggml_tensor * inp_KQ_pos; // F32 [n_ctx]
|
||||
struct ggml_tensor * inp_K_shift; // I32 [n_ctx]
|
||||
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
|
||||
struct ggml_tensor * inp_cls; // I32 [n_batch]
|
||||
@ -3054,6 +3056,11 @@ static void llm_load_hparams(
|
||||
case 40: model.type = e_model::MODEL_13B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
|
||||
if (model.type == e_model::MODEL_13B) {
|
||||
// TODO: become GGUF KV parameter
|
||||
hparams.f_max_alibi_bias = 8.0f;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER:
|
||||
{
|
||||
@ -3081,6 +3088,9 @@ static void llm_load_hparams(
|
||||
case 32: model.type = e_model::MODEL_1B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
|
||||
// TODO: become GGUF KV parameter
|
||||
hparams.f_max_alibi_bias = 8.0f;
|
||||
} break;
|
||||
case LLM_ARCH_BERT:
|
||||
{
|
||||
@ -3126,11 +3136,12 @@ static void llm_load_hparams(
|
||||
case 4096: model.type = e_model::MODEL_7B; break;
|
||||
} break;
|
||||
}
|
||||
|
||||
// TODO: become GGUF KV parameter
|
||||
hparams.f_max_alibi_bias = 8.0f;
|
||||
} break;
|
||||
case LLM_ARCH_MPT:
|
||||
{
|
||||
hparams.f_clamp_kqv = 0.0f;
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false);
|
||||
ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
|
||||
@ -3232,6 +3243,10 @@ static void llm_load_hparams(
|
||||
}
|
||||
|
||||
model.ftype = ml.ftype;
|
||||
|
||||
if (hparams.f_max_alibi_bias > 0.0f) {
|
||||
hparams.need_kq_pos = true;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: This should probably be in llama.h
|
||||
@ -4774,10 +4789,10 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
struct ggml_tensor * wo_b,
|
||||
struct ggml_tensor * q_cur,
|
||||
struct ggml_tensor * kq_mask,
|
||||
struct ggml_tensor * kq_pos,
|
||||
int64_t n_ctx,
|
||||
int32_t n_tokens,
|
||||
int32_t n_kv,
|
||||
float max_alibi_bias,
|
||||
float kq_scale,
|
||||
const llm_build_cb & cb,
|
||||
int il) {
|
||||
@ -4807,26 +4822,26 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
||||
}
|
||||
|
||||
if (max_alibi_bias > 0.0f) {
|
||||
// temporary branch until we figure out how to handle ggml_alibi through ggml_add
|
||||
#if defined(GGML_USE_VULKAN) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_SYCL)
|
||||
#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Vulkan, Kompute, and SYCL")
|
||||
#pragma message(" Falling back to ggml_alibi(). Will become an error in Mar 2024")
|
||||
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488")
|
||||
if (hparams.f_max_alibi_bias > 0.0f) {
|
||||
kq = ggml_scale(ctx, kq, kq_scale);
|
||||
cb(kq, "kq_scaled", il);
|
||||
|
||||
if (max_alibi_bias > 0.0f) {
|
||||
// TODO: n_head or n_head_kv
|
||||
// TODO: K-shift is likely not working
|
||||
// TODO: change to ggml_add
|
||||
kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias);
|
||||
cb(kq, "kq_scaled_alibi", il);
|
||||
}
|
||||
kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_scaled_alibi", il);
|
||||
|
||||
kq = ggml_add(ctx, kq, kq_mask);
|
||||
cb(kq, "kq_masked", il);
|
||||
|
||||
kq = ggml_soft_max(ctx, kq);
|
||||
cb(kq, "kq_soft_max", il);
|
||||
} else {
|
||||
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias);
|
||||
cb(kq, "kq_soft_max_ext", il);
|
||||
}
|
||||
|
||||
@ -4874,11 +4889,11 @@ static struct ggml_tensor * llm_build_kv(
|
||||
struct ggml_tensor * v_cur,
|
||||
struct ggml_tensor * q_cur,
|
||||
struct ggml_tensor * kq_mask,
|
||||
struct ggml_tensor * kq_pos,
|
||||
int64_t n_ctx,
|
||||
int32_t n_tokens,
|
||||
int32_t kv_head,
|
||||
int32_t n_kv,
|
||||
float max_alibi_bias,
|
||||
float kq_scale,
|
||||
const llm_build_cb & cb,
|
||||
int il) {
|
||||
@ -4892,9 +4907,8 @@ static struct ggml_tensor * llm_build_kv(
|
||||
llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
cur = llm_build_kqv(ctx, model, hparams, kv, graph,
|
||||
wo, wo_b,
|
||||
q_cur, kq_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, kq_scale, cb, il);
|
||||
cur = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b,
|
||||
q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
return cur;
|
||||
@ -5062,7 +5076,7 @@ struct llm_build_context {
|
||||
}
|
||||
|
||||
Qcur = ggml_rope_custom(
|
||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
|
||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
|
||||
hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
@ -5077,7 +5091,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -5207,6 +5221,10 @@ struct llm_build_context {
|
||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
||||
cb(KQ_mask, "KQ_mask", -1);
|
||||
|
||||
// positions of the tokens in the KV cache
|
||||
struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
|
||||
cb(KQ_pos, "KQ_pos", -1);
|
||||
|
||||
// shift the entire K-cache if needed
|
||||
if (do_rope_shift) {
|
||||
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
|
||||
@ -5255,12 +5273,9 @@ struct llm_build_context {
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
|
||||
// apply ALiBi for 13B model
|
||||
const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -5384,7 +5399,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -5483,7 +5498,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -5688,7 +5703,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Q, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -5750,6 +5765,10 @@ struct llm_build_context {
|
||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
||||
cb(KQ_mask, "KQ_mask", -1);
|
||||
|
||||
// positions of the tokens in the KV cache
|
||||
struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
|
||||
cb(KQ_pos, "KQ_pos", -1);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
@ -5777,7 +5796,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -5878,7 +5897,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
} else {
|
||||
// compute Q and K and RoPE them
|
||||
@ -5909,7 +5928,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -5985,6 +6004,10 @@ struct llm_build_context {
|
||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
||||
cb(KQ_mask, "KQ_mask", -1);
|
||||
|
||||
// positions of the tokens in the KV cache
|
||||
struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
|
||||
cb(KQ_pos, "KQ_pos", -1);
|
||||
|
||||
inpL = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.tok_norm,
|
||||
model.tok_norm_b,
|
||||
@ -6018,7 +6041,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -6078,6 +6101,10 @@ struct llm_build_context {
|
||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
||||
cb(KQ_mask, "KQ_mask", -1);
|
||||
|
||||
// positions of the tokens in the KV cache
|
||||
struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
|
||||
cb(KQ_pos, "KQ_pos", -1);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * attn_norm;
|
||||
|
||||
@ -6111,7 +6138,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -6233,7 +6260,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -6348,7 +6375,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -6469,7 +6496,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -6596,7 +6623,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f, cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -6699,7 +6726,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
struct ggml_tensor * sa_out = cur;
|
||||
@ -6798,7 +6825,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -6907,7 +6934,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -7025,7 +7052,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -7144,7 +7171,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -7276,7 +7303,7 @@ struct llm_build_context {
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
@ -7507,6 +7534,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
|
||||
}
|
||||
}
|
||||
|
||||
if (hparams.need_kq_pos) {
|
||||
const int64_t n_kv = kv_self.n;
|
||||
|
||||
assert(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
|
||||
|
||||
float * data = (float *) lctx.inp_KQ_pos->data;
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
data[i] = float(lctx.kv_self.cells[i].pos);
|
||||
}
|
||||
}
|
||||
|
||||
if (kv_self.has_shift) {
|
||||
const int64_t n_ctx = cparams.n_ctx;
|
||||
|
||||
@ -11434,7 +11473,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
// graph inputs
|
||||
{
|
||||
ggml_init_params init_params = {
|
||||
/* .mem_size */ ggml_tensor_overhead()*7,
|
||||
/* .mem_size */ ggml_tensor_overhead()*8,
|
||||
/* .mem_buffer */ nullptr,
|
||||
/* .no_alloc */ true,
|
||||
};
|
||||
@ -11444,6 +11483,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
|
||||
ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
|
||||
ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
|
||||
ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
|
||||
ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
|
||||
ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
|
||||
ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
|
||||
@ -11452,6 +11492,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
ggml_set_name(ctx->inp_embd, "inp_embd");
|
||||
ggml_set_name(ctx->inp_pos, "inp_pos");
|
||||
ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask");
|
||||
ggml_set_name(ctx->inp_KQ_pos, "inp_KQ_pos");
|
||||
ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
|
||||
ggml_set_name(ctx->inp_mean, "inp_mean");
|
||||
ggml_set_name(ctx->inp_cls, "inp_cls");
|
||||
|
@ -1085,24 +1085,32 @@ struct test_diag_mask_inf : public test_case {
|
||||
struct test_soft_max : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const float scale;
|
||||
const bool mask;
|
||||
const float scale;
|
||||
const float max_bias;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne, scale, mask);
|
||||
return VARS_TO_STR5(type, ne, mask, scale, max_bias);
|
||||
}
|
||||
|
||||
test_soft_max(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {10, 10, 10, 10},
|
||||
bool mask = false,
|
||||
float scale = 1.0f,
|
||||
bool mask = false)
|
||||
: type(type), ne(ne), scale(scale), mask(mask) {}
|
||||
float max_bias = 0.0f)
|
||||
: type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * b = nullptr;
|
||||
if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); }
|
||||
ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale);
|
||||
ggml_tensor * mask = nullptr;
|
||||
if (this->mask) {
|
||||
mask = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
|
||||
}
|
||||
ggml_tensor * pos = nullptr;
|
||||
if (max_bias > 0.0f) {
|
||||
pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
|
||||
}
|
||||
ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
@ -1147,30 +1155,6 @@ struct test_rope : public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_ALIBI
|
||||
struct test_alibi : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
int n_past;
|
||||
int n_head;
|
||||
float bias_max;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR5(type, ne, n_past, n_head, bias_max);
|
||||
}
|
||||
|
||||
test_alibi(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {10, 10, 10, 10},
|
||||
int n_past = 512, int n_head = 10, float bias_max = 0.5f)
|
||||
: type(type), ne(ne), n_past(n_past), n_head(n_head), bias_max(bias_max) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * out = ggml_alibi(ctx, a, n_past, n_head, bias_max);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_POOL2D
|
||||
struct test_pool2d : public test_case {
|
||||
enum ggml_op_pool pool_type;
|
||||
@ -1488,7 +1472,7 @@ struct test_moe : public test_case {
|
||||
ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
|
||||
|
||||
ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur);
|
||||
ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, 1.0f/sqrtf(n_embd));
|
||||
ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, nullptr, 1.0f/sqrtf(n_embd), 0.0f);
|
||||
|
||||
// select experts
|
||||
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok);
|
||||
@ -1617,7 +1601,6 @@ public:
|
||||
ggml_cpy(ctx, v_cur_t, v_cache_view);
|
||||
}
|
||||
|
||||
// if max_alibi_bias > 0 then apply ALiBi
|
||||
struct ggml_tensor * llm_build_kqv(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * k_l,
|
||||
@ -1636,7 +1619,7 @@ public:
|
||||
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||
|
||||
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
|
||||
kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale, 0.0f);
|
||||
|
||||
// split cached v into n_head heads
|
||||
struct ggml_tensor * v =
|
||||
@ -2083,6 +2066,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 1}, 5));
|
||||
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));
|
||||
|
||||
#if 0
|
||||
std::uniform_int_distribution<> dist_ne1(1, 50);
|
||||
int exponent = 1;
|
||||
while (exponent < (1 << 17)) {
|
||||
@ -2091,14 +2075,29 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
for (int n = 0; n < 10; ++n) {
|
||||
int64_t ne0 = dist_ne0(rng);
|
||||
int64_t ne1 = dist_ne1(rng);
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f));
|
||||
}
|
||||
|
||||
exponent <<= 1;
|
||||
}
|
||||
#endif
|
||||
for (bool mask : {false, true}) {
|
||||
for (float max_bias : {0.0f, 8.0f}) {
|
||||
for (float scale : {1.0f, 0.1f}) {
|
||||
for (int64_t ne0 : {16, 1024}) {
|
||||
for (int64_t ne1 : {16, 1024}) {
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, 0.1f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, 0.1f, true));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
|
||||
|
||||
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
|
||||
@ -2113,7 +2112,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512)); // neox (phi-2)
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_alibi());
|
||||
test_cases.emplace_back(new test_concat(GGML_TYPE_F32));
|
||||
test_cases.emplace_back(new test_concat(GGML_TYPE_I32));
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user