mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
cuda : fix build
This commit is contained in:
parent
013721df2b
commit
6be02b5969
@ -2384,17 +2384,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_op_argsort(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
ggml_cuda_flash_attn_ext(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
|
||||
ggml_cuda_flash_attn_ext(ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
|
||||
} else {
|
||||
func(ctx, tensor->src[0], tensor->src[1], tensor);
|
||||
}
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
|
||||
|
@ -1,5 +1,33 @@
|
||||
#include "fattn.cuh"
|
||||
|
||||
#include <mma.h>
|
||||
|
||||
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
|
||||
}
|
||||
return a;
|
||||
#else
|
||||
GGML_UNUSED(a);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
|
||||
}
|
||||
return x;
|
||||
#else
|
||||
GGML_UNUSED(x);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||
}
|
||||
|
||||
#if __CUDA_ARCH__ >= CC_VOLTA
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
|
||||
@ -10,11 +38,11 @@ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>
|
||||
// based on metal version
|
||||
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char* __restrict__ q,
|
||||
const char* __restrict__ k,
|
||||
const char* __restrict__ v,
|
||||
const char* __restrict__ mask,
|
||||
float* __restrict__ dst,
|
||||
const char * __restrict__ q,
|
||||
const char * __restrict__ k,
|
||||
const char * __restrict__ v,
|
||||
const char * __restrict__ mask,
|
||||
float * __restrict__ dst,
|
||||
float scale,
|
||||
int ne00,
|
||||
int ne01,
|
||||
@ -408,7 +436,15 @@ static __global__ void flash_attn_ext_f16(
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(K->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(V->type == GGML_TYPE_F16);
|
||||
|
@ -1,6 +1,3 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_flash_attn_ext(
|
||||
ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V,
|
||||
const ggml_tensor * mask, ggml_tensor * KQV);
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
Loading…
Reference in New Issue
Block a user