Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-25 02:44:36 +00:00)
ggml : add ggml_soft_max_ext (#4256)
* metal : implement soft_max_ext
* cuda : implement soft_max_ext
* ggml : implement soft_max_ext (CPU)
* batched-bench : print threads

ggml-ci

* metal : simplify soft_max encoding

ggml-ci

* cuda : use 512 threads for soft_max instead of 32
* ggml : update soft max cpu
* cuda : do warp-based block reduce
* cuda : increase max block size to 1024
* cuda : fix warp reduction initialization of shared mem
* metal : warp-based reduction for soft max kernel
* metal : warp-based reduce for rms_norm
* metal : simplify soft max kernel

ggml-ci

* alloc : fix build with debug
This commit is contained in:
Parent: 1d144112c0
Commit: ef47ec18da
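Before the per-file hunks, a minimal host-side C++ sketch of the semantics the new fused op is meant to compute: for every row of x, soft_max(x*scale + mask), with the optional mask broadcast across rows. The function name and layout here are illustrative, not code from the commit.

#include <algorithm>
#include <cmath>
#include <vector>

// Reference of dst = soft_max(x*scale + mask); mask (if present) has nrows_mask rows
// and is repeated over the rows of x, mirroring ggml_soft_max_ext's broadcast rule.
static void soft_max_ext_reference(const std::vector<float> & x,
                                   const std::vector<float> * mask,
                                   std::vector<float> & dst,
                                   int nrows, int ncols, int nrows_mask, float scale) {
    dst.resize(x.size());
    for (int r = 0; r < nrows; ++r) {
        const float * xr = x.data()   + (size_t) r*ncols;
        const float * mr = mask ? mask->data() + (size_t)(r % nrows_mask)*ncols : nullptr;
        float       * dr = dst.data() + (size_t) r*ncols;

        // fused scale + mask, then the usual max-subtraction for numerical stability
        float max_val = -INFINITY;
        for (int c = 0; c < ncols; ++c) {
            dr[c]   = xr[c]*scale + (mr ? mr[c] : 0.0f);
            max_val = std::max(max_val, dr[c]);
        }
        float sum = 0.0f;
        for (int c = 0; c < ncols; ++c) {
            dr[c] = std::exp(dr[c] - max_val);
            sum  += dr[c];
        }
        const float inv_sum = 1.0f/sum;
        for (int c = 0; c < ncols; ++c) {
            dr[c] *= inv_sum;
        }
    }
}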
@@ -155,7 +155,7 @@ int main(int argc, char ** argv) {
     }

     LOG_TEE("\n");
-    LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq);
+    LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %d, n_threads_batch = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch);
     LOG_TEE("\n");

     LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
@@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {

 #ifdef GGML_ALLOCATOR_DEBUG
     add_allocated_tensor(alloc, tensor);
-    size_t cur_max = (char*)addr - (char*)alloc->data + size;
+    size_t cur_max = (char*)addr - (char*)alloc->base + size;
     if (cur_max > alloc->max_size) {
         printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
         for (int i = 0; i < 1024; i++) {
ggml-cuda.cu: 130 changed lines
@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -501,6 +502,31 @@ static size_t g_scratch_offset = 0;

 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;

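The helpers above implement a butterfly (XOR) reduction: at each step every lane combines its value with the lane whose index differs in one bit, so after log2(32) = 5 steps all 32 lanes hold the full result. A host-side C++ emulation of the same pattern, illustrative only and not part of the commit:

#include <cstdio>

// Emulates what warp_reduce_sum does with __shfl_xor_sync across 32 lanes.
int main() {
    float lane[32];
    for (int i = 0; i < 32; ++i) lane[i] = (float) i;      // lane i starts with value i

    for (int mask = 16; mask > 0; mask >>= 1) {
        float next[32];
        for (int i = 0; i < 32; ++i) {
            next[i] = lane[i] + lane[i ^ mask];             // "shuffle" in the value of lane i^mask
        }
        for (int i = 0; i < 32; ++i) lane[i] = next[i];
    }

    printf("every lane now holds %.0f (expected %d)\n", lane[0], 31*32/2);
    return 0;
}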
@@ -577,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }

-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
 template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -624,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }

-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4717,45 +4726,74 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }

-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = blockDim.x;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];

     float max_val = -INFINITY;

     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        max_val = max(max_val, x[i]);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }

     // find the max value in the block
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
+        }
+        __syncthreads();
+
+        max_val = buf[lane_id];
+        max_val = warp_reduce_max(max_val);
     }

     float tmp = 0.f;

     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        const float val = expf(x[i] - max_val);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
-        dst[i] = val;
+        dst[ix] = val;
     }

-    // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
+        }
+        __syncthreads();
+
+        tmp = buf[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }

     const float inv_tmp = 1.f / tmp;

     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
+        const int i = rowx*ncols + col;
         dst[i] *= inv_tmp;
     }
 }
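When the block is wider than one warp, the kernel reduces in two levels: each warp reduces its 32 lanes with the shuffle helper, lane 0 of each warp writes its partial into the shared buffer, and the first warp then reduces those partials. A host-side C++ sketch of that scheme, with made-up sizes and data, not code from the commit:

#include <algorithm>
#include <cstdio>

int main() {
    const int WARP_SIZE  = 32;
    const int BLOCK_SIZE = 128;                 // 4 "warps" for the illustration
    float val[BLOCK_SIZE];
    for (int t = 0; t < BLOCK_SIZE; ++t) val[t] = (float)((t*37) % 100);

    float buf[BLOCK_SIZE / WARP_SIZE];          // plays the role of __shared__ buf[]

    // level 1: per-warp reduction (stands in for warp_reduce_max)
    for (int w = 0; w < BLOCK_SIZE / WARP_SIZE; ++w) {
        float m = -1e30f;
        for (int l = 0; l < WARP_SIZE; ++l) m = std::max(m, val[w*WARP_SIZE + l]);
        buf[w] = m;                             // "lane 0 writes buf[warp_id]"
    }

    // level 2: the first warp reduces the per-warp partials
    float block_max = -1e30f;
    for (int w = 0; w < BLOCK_SIZE / WARP_SIZE; ++w) block_max = std::max(block_max, buf[w]);

    printf("block max = %.0f\n", block_max);
    return 0;
}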
@@ -5792,10 +5830,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }

-static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }

 static void im2col_f32_f16_cuda(const float * x, half * dst,
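The launcher now sizes the block per row: it starts at one warp and doubles until the width covers ncols_x or reaches CUDA_SOFT_MAX_BLOCK_SIZE (1024). A small host-side C++ illustration of the same selection rule, with assumed constants:

#include <cstdio>

// Same doubling rule as soft_max_f32_cuda: one warp minimum, 1024-thread cap.
static int pick_block_width(int ncols, int warp_size = 32, int cap = 1024) {
    int nth = warp_size;
    while (nth < ncols && nth < cap) nth *= 2;
    return nth;
}

int main() {
    for (int ncols : {16, 100, 512, 4096}) {
        printf("ncols = %4d -> block width = %d\n", ncols, pick_block_width(ncols));
    }
    return 0;   // prints 32, 128, 512, 1024
}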
@@ -6846,14 +6886,18 @@ inline void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;

-    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+    float scale = 1.0f;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);

-    (void) src1;
     (void) dst;
-    (void) src1_dd;
 }

 inline void ggml_cuda_op_scale(
ggml-metal.m: 31 changed lines
@@ -1028,20 +1028,27 @@ void ggml_metal_graph_compute(
     int nth = 32; // SIMD width

     if (ne00%4 == 0) {
+        while (nth < ne00/4 && nth < 256) {
+            nth *= 2;
+        }
         [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
     } else {
-        do {
+        while (nth < ne00 && nth < 1024) {
             nth *= 2;
-        } while (nth <= ne00 && nth <= 1024);
-        nth /= 2;
+        }
         [encoder setComputePipelineState:ctx->pipeline_soft_max];
     }
+
+    const float scale = ((float *) dst->op_params)[0];
+
     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-    [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+    [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+    [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

     [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
@@ -1351,7 +1358,11 @@ void ggml_metal_graph_compute(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));

-    const int nth = MIN(512, ne00);
+    int nth = 32; // SIMD width
+
+    while (nth < ne00/4 && nth < 1024) {
+        nth *= 2;
+    }

     [encoder setComputePipelineState:ctx->pipeline_rms_norm];
     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1359,7 +1370,7 @@ void ggml_metal_graph_compute(
     [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
     [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
     [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-    [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+    [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

     const int64_t nrows = ggml_nrows(src0);

ggml-metal.metal: 148 changed lines
@@ -39,6 +39,8 @@ typedef struct {
     int8_t qs[QK8_0]; // quants
 } block_q8_0;

+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
 // general-purpose kernel for addition of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
 // cons: not very efficient
@@ -180,10 +182,12 @@ kernel void kernel_gelu(

 kernel void kernel_soft_max(
         device const float * src0,
+        device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
+        constant     float & scale,
         threadgroup float  * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
@@ -195,72 +199,76 @@ kernel void kernel_soft_max(
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

     device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

     // parallel max
-    float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+    float lmax = -INFINITY;

-    for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
-        lmax = MAX(lmax, psrc0[i00]);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
     }

-    float max = simd_max(lmax);
-    if (tiisg == 0) {
-        buf[sgitg] = max;
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
-        }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    max = buf[0];
+    // find the max value in the block
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }

     // parallel sum
     float lsum = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp(psrc0[i00] - max);
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum += exp_psrc0;
-        // Remember the result of exp here. exp is expensive, so we really do not
-        // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }

     float sum = simd_sum(lsum);
-    if (tiisg == 0) {
-        buf[sgitg] = sum;
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] += buf[tpitg + i];
-        }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    sum = buf[0];
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+
+    const float inv_sum = 1.0f/sum;

     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        pdst[i00] /= sum;
+        pdst[i00] *= inv_sum;
     }
 }

 kernel void kernel_soft_max_4(
         device const float * src0,
+        device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
+        constant     float & scale,
         threadgroup float  * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
@@ -272,63 +280,67 @@ kernel void kernel_soft_max_4(
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

     device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
     device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

     // parallel max
-    float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+    float4 lmax4 = -INFINITY;

-    for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]);
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
     }

     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
-    float max = simd_max(lmax);
-    if (tiisg == 0) {
-        buf[sgitg] = max;
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
-        }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    max = buf[0];
+
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }

     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }

     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
     float sum = simd_sum(lsum);
-    if (tiisg == 0) {
-        buf[sgitg] = sum;
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            buf[tpitg] += buf[tpitg + i];
-        }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    sum = buf[0];
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+
+    const float inv_sum = 1.0f/sum;

     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        pdst4[i00] /= sum;
+        pdst4[i00] *= inv_sum;
     }
 }

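kernel_soft_max_4 walks the row four columns at a time and only collapses the four running components into one scalar when a value is needed for the SIMD-group reduction. A host-side C++ sketch of that traversal pattern, with plain arrays standing in for Metal's float4 (illustrative, not from the commit):

#include <algorithm>
#include <cstdio>

int main() {
    const int ne00 = 64;                     // row width, divisible by 4 as the kernel requires
    float row[ne00];
    for (int i = 0; i < ne00; ++i) row[i] = (float)((i*13) % 29);

    float lmax4[4] = {-1e30f, -1e30f, -1e30f, -1e30f};
    for (int i = 0; i < ne00/4; ++i) {
        for (int k = 0; k < 4; ++k) {
            lmax4[k] = std::max(lmax4[k], row[4*i + k]);   // fmax over the four "lanes"
        }
    }
    const float lmax = std::max(std::max(lmax4[0], lmax4[1]), std::max(lmax4[2], lmax4[3]));
    printf("row max = %.0f\n", lmax);
    return 0;
}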
@@ -435,14 +447,13 @@ kernel void kernel_rms_norm(
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
         constant     float & eps,
-        threadgroup float  * sum [[threadgroup(0)]],
+        threadgroup float  * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint   ntg[[threads_per_threadgroup]]) {
     device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
-    device const float  * x_scalar = (device const float *) x;

     float4 sumf = 0;
     float all_sum = 0;
@@ -453,40 +464,30 @@ kernel void kernel_rms_norm(
     }
     all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
     all_sum = simd_sum(all_sum);
-    if (tiisg == 0) {
-        sum[sgitg] = all_sum;
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            sum[tpitg] += sum[tpitg + i];
-        }
-    }
-    if (tpitg == 0) {
-        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
-            sum[0] += x_scalar[i];
-        }
-        sum[0] /= ne00;
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float mean  = sum[0];
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = all_sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        all_sum = buf[tiisg];
+        all_sum = simd_sum(all_sum);
+    }
+
+    const float mean  = all_sum/ne00;
     const float scale = 1.0f/sqrt(mean + eps);

     device float4 * y = (device float4 *) (dst + tgpig*ne00);
-    device float * y_scalar = (device float *) y;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         y[i00] = x[i00] * scale;
     }
-    if (tpitg == 0) {
-        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
-            y_scalar[i00] = x_scalar[i00] * scale;
-        }
-    }
 }

 // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -576,7 +577,6 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
-#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
 //      giard against the number of rows not being divisible by
ggml.c: 69 changed lines
@@ -4826,7 +4826,17 @@ struct ggml_tensor * ggml_diag_mask_zero_inplace(
 static struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale,
         bool                  inplace) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    if (mask) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
+        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+    }
+
     bool is_node = false;

     if (a->grad) {
@@ -4835,9 +4845,13 @@ static struct ggml_tensor * ggml_soft_max_impl(

     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

+    float params[] = { scale };
+    ggml_set_op_params(result, params, sizeof(params));
+
     result->op   = GGML_OP_SOFT_MAX;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
+    result->src[1] = mask;

     return result;
 }
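The scale is carried to every backend through the tensor's op_params block and read back with memcpy, which avoids type-punning through the int32 storage. A standalone C++ sketch of the same round trip (buffer size and names are illustrative, not ggml internals):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    int32_t op_params[16] = {0};                  // stand-in for ggml_tensor::op_params

    const float scale_in = 0.125f;
    float params[] = { scale_in };
    memcpy(op_params, params, sizeof(params));    // producer side (ggml_set_op_params)

    float scale_out = 1.0f;
    memcpy(&scale_out, op_params, sizeof(float)); // consumer side (CPU/CUDA/Metal backend)
    printf("scale = %f\n", scale_out);
    return 0;
}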
@@ -4845,13 +4859,21 @@ static struct ggml_tensor * ggml_soft_max_impl(
 struct ggml_tensor * ggml_soft_max(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_soft_max_impl(ctx, a, false);
+    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
 }

 struct ggml_tensor * ggml_soft_max_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_soft_max_impl(ctx, a, true);
+    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
+}
+
+struct ggml_tensor * ggml_soft_max_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale) {
+    return ggml_soft_max_impl(ctx, a, mask, scale, false);
 }

 // ggml_soft_max_back
@@ -10551,20 +10573,25 @@ static void ggml_compute_forward_diag_mask_zero(
 static void ggml_compute_forward_soft_max_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    assert(ggml_is_contiguous(dst));
+    assert(ggml_are_same_shape(src0, dst));

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }

+    float scale = 1.0f;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
     // TODO: handle transposed/permuted matrices

     const int ith = params->ith;
     const int nth = params->nth;

+    const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
     const int nc = src0->ne[0];
     const int nr = ggml_nrows(src0);

@@ -10575,29 +10602,40 @@ static void ggml_compute_forward_soft_max_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);

+    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);

+        // broadcast the mask across rows
+        float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+
+        ggml_vec_cpy_f32  (nc, wp, sp);
+        ggml_vec_scale_f32(nc, wp, scale);
+        if (mp) {
+            ggml_vec_acc_f32(nc, wp, mp);
+        }
+
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(sp[i]));
+            assert(!isnan(wp[i]));
         }
 #endif

         float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, sp);
+        ggml_vec_max_f32(nc, &max, wp);

         ggml_float sum = 0.0;

         uint16_t scvt;
         for (int i = 0; i < nc; i++) {
-            if (sp[i] == -INFINITY) {
+            if (wp[i] == -INFINITY) {
                 dp[i] = 0.0f;
             } else {
-                // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
-                ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
+                // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
+                ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
                 memcpy(&scvt, &s, sizeof(scvt));
                 const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
                 sum += (ggml_float)val;
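The CPU path now stages each row in a per-thread scratch slice of params->wdata so scale and mask can be applied before the exponentials; each slice is padded by a cache line worth of floats so neighbouring threads do not touch the same cache line. A small C++ sketch of that partitioning (the constants are assumptions for the example; ggml sizes the buffer in ggml_graph_plan):

#include <cstdio>
#include <vector>

int main() {
    const int n_threads = 4;
    const int nc        = 4096;                       // row width (ne00)
    const int CACHE_LINE_SIZE_F32 = 64 / sizeof(float);

    // one contiguous work buffer, one padded slice per thread
    std::vector<float> wdata((size_t)(nc + CACHE_LINE_SIZE_F32) * n_threads);

    for (int ith = 0; ith < n_threads; ++ith) {
        float * wp = wdata.data() + (size_t)(nc + CACHE_LINE_SIZE_F32) * ith;
        printf("thread %d scratch starts at float offset %zu\n", ith, (size_t)(wp - wdata.data()));
    }
    return 0;
}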
@@ -10622,11 +10660,12 @@ static void ggml_compute_forward_soft_max_f32(
 static void ggml_compute_forward_soft_max(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_f32(params, src0, dst);
+                ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -13863,7 +13902,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX:
             {
-                ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
+                ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
@@ -15899,6 +15938,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                     cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                 }
             } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
+
+                cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 GGML_ASSERT(node->src[0]->ne[3] == 1);
ggml.h: 8 changed lines
@@ -1282,6 +1282,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

+    // fused soft_max(a*scale + mask)
+    // mask is optional
+    GGML_API struct ggml_tensor * ggml_soft_max_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * mask,
+            float                 scale);
+
     GGML_API struct ggml_tensor * ggml_soft_max_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
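A hedged C++ usage sketch of the new API, based on the ggml C API of this era (ggml_init, ggml_new_tensor_2d, ggml_new_graph, ggml_build_forward_expand, ggml_graph_compute_with_ctx); the tensor shapes, mask contents, and memory size are made up for the example:

#include <cmath>
#include <cstdio>
#include "ggml.h"

int main() {
    struct ggml_init_params ip = { /*.mem_size =*/ 64*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    const int n_kv = 8, n_tokens = 4, n_embd_head = 64;
    struct ggml_tensor * kq   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_kv, n_tokens);
    struct ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_kv, n_tokens);

    for (int i = 0; i < n_kv*n_tokens; ++i) {
        ((float *) kq->data)[i]   = 0.1f*i;
        ((float *) mask->data)[i] = (i % n_kv) < n_kv - 1 ? 0.0f : -INFINITY; // toy mask: hide the last kv slot
    }

    // one fused node instead of scale + add(mask) + soft_max
    struct ggml_tensor * probs = ggml_soft_max_ext(ctx, kq, mask, 1.0f/std::sqrt((float) n_embd_head));

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, probs);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    const float * p = ggml_get_data_f32(probs);
    float row_sum = 0.0f;
    for (int j = 0; j < n_kv; ++j) row_sum += p[j];
    printf("first row sums to %f\n", row_sum);   // should print 1.0

    ggml_free(ctx);
    return 0;
}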
llama.cpp

@@ -3704,6 +3704,8 @@ static struct ggml_tensor * llm_build_kqv(
     struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
     cb(kq, "kq", il);

+    if (max_alibi_bias > 0.0f) {
+        // temporary branch until we figure out how to handle ggml_alibi through ggml_add
         kq = ggml_scale(ctx, kq, kq_scale);
         cb(kq, "kq_scaled", il);

@@ -3720,6 +3722,10 @@ static struct ggml_tensor * llm_build_kqv(

         kq = ggml_soft_max(ctx, kq);
         cb(kq, "kq_soft_max", il);
+    } else {
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head)));
+        cb(kq, "kq_soft_max_ext", il);
+    }

     // split cached v into n_head heads
     struct ggml_tensor * v =
@@ -5041,6 +5047,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "kq_scaled_alibi",  OFFLOAD_FUNC_KQ },
     { "kq_masked",        OFFLOAD_FUNC_KQ },
     { "kq_soft_max",      OFFLOAD_FUNC_V  },
+    { "kq_soft_max_ext",  OFFLOAD_FUNC_V  },
     { "v",                OFFLOAD_FUNC_V  },
     { "kqv",              OFFLOAD_FUNC_V  },
     { "kqv_merged",       OFFLOAD_FUNC_V  },