Fix CUDA softmax by subtracting max value before exp (#2665)

This commit is contained in:
Jiahao Li 2023-08-23 02:27:06 +08:00 committed by GitHub
parent deb7dfca4b
commit 800c9635b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3979,24 +3979,29 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
// the CUDA soft max implementation differs from the CPU implementation // the CUDA soft max implementation differs from the CPU implementation
// instead of doubles floats are used // instead of doubles floats are used
// values are also not normalized to the maximum value by subtracting it in the exponential function
// theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) { static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
const int row = blockDim.x*blockIdx.x + threadIdx.x; const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int block_size = blockDim.y; const int block_size = blockDim.y;
const int tid = threadIdx.y; const int tid = threadIdx.y;
float tmp = 0.0; float max_val = -INFINITY;
for (int block_start = 0; block_start < ncols; block_start += block_size) { for (int col = tid; col < ncols; col += block_size) {
const int col = block_start + tid; const int i = row*ncols + col;
max_val = max(max_val, x[i]);
if (col >= ncols) {
break;
} }
// find the max value in the block
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
}
float tmp = 0.f;
for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col; const int i = row*ncols + col;
const float val = expf(x[i]); const float val = expf(x[i] - max_val);
tmp += val; tmp += val;
dst[i] = val; dst[i] = val;
} }
@ -4007,15 +4012,11 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
} }
for (int block_start = 0; block_start < ncols; block_start += block_size) { const float inv_tmp = 1.f / tmp;
const int col = block_start + tid;
if (col >= ncols) {
break;
}
for (int col = tid; col < ncols; col += block_size) {
const int i = row*ncols + col; const int i = row*ncols + col;
dst[i] /= tmp; dst[i] *= inv_tmp;
} }
} }