mirror of https://github.com/ggerganov/llama.cpp.git

commit 6be02b5969
parent 013721df2b

cuda : fix build

@@ -2384,17 +2384,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_argsort(ctx, dst);
             break;
         case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cuda_flash_attn_ext(ctx, dst);
             break;
         default:
             return false;
     }
 
-    if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
-        ggml_cuda_flash_attn_ext(ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
-    } else {
-        func(ctx, tensor->src[0], tensor->src[1], tensor);
-    }
-
     cudaError_t err = cudaGetLastError();
     if (err != cudaSuccess) {
         fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
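
Note: the hunk above routes GGML_OP_FLASH_ATTN_EXT through the regular op switch and drops the special-cased branch that used to follow it. For reference, the call-site change amounts to the following (taken from the removed and added lines above):

    // before: dedicated branch after the switch, operands passed explicitly
    ggml_cuda_flash_attn_ext(ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);

    // after: dispatched like any other op; the handler reads its operands from dst->src[]
    ggml_cuda_flash_attn_ext(ctx, dst);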

@@ -1,5 +1,33 @@
 #include "fattn.cuh"
 
+#include <mma.h>
+
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+    }
+    return a;
+#else
+    GGML_UNUSED(a);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#else
+    GGML_UNUSED(x);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}
+
 #if __CUDA_ARCH__ >= CC_VOLTA
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
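
Note: warp_reduce_sum and warp_reduce_max above use the standard XOR-shuffle butterfly: each of the five steps folds together lanes whose ids differ in one bit, so after log2(32) = 5 steps every lane holds the reduction of the whole warp. Below is a minimal standalone sketch of the same pattern, using float instead of half2 so it builds without the CC_PASCAL / CUDART_HMAX guards; the kernel and names are illustrative only and not part of this patch.

    #include <cstdio>
    #include <cuda_runtime.h>

    // same XOR-shuffle butterfly as warp_reduce_sum above, but on plain float
    static __device__ __forceinline__ float warp_reduce_sum_f32(float x) {
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            // fold in the value held by the lane whose id differs in one bit
            x += __shfl_xor_sync(0xffffffff, x, mask, 32);
        }
        return x; // every lane now holds the full warp sum
    }

    static __global__ void warp_sum_demo(float * out) {
        const float sum = warp_reduce_sum_f32(1.0f); // 32 lanes contribute 1.0 each
        if (threadIdx.x == 0) {
            *out = sum;
        }
    }

    int main() {
        float * d_out = nullptr;
        cudaMalloc(&d_out, sizeof(float));
        warp_sum_demo<<<1, 32>>>(d_out);
        float h_out = 0.0f;
        cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
        printf("warp sum = %.1f\n", h_out); // expected: 32.0
        cudaFree(d_out);
        return 0;
    }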

@@ -10,11 +38,11 @@ typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half>
 // based on metal version
 template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
 static __global__ void flash_attn_ext_f16(
-        const char* __restrict__ q,
-        const char* __restrict__ k,
-        const char* __restrict__ v,
-        const char* __restrict__ mask,
-        float* __restrict__ dst,
+        const char * __restrict__ q,
+        const char * __restrict__ k,
+        const char * __restrict__ v,
+        const char * __restrict__ mask,
+        float * __restrict__ dst,
         float scale,
         int ne00,
         int ne01,

@@ -408,7 +436,15 @@ static __global__ void flash_attn_ext_f16(
 #endif
 }
 
-void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    const ggml_tensor * mask = dst->src[3];
+
+    ggml_tensor * KQV = dst;
+
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(K->type == GGML_TYPE_F16);
     GGML_ASSERT(V->type == GGML_TYPE_F16);
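
Note: with the new signature, ggml_cuda_flash_attn_ext looks like every other CUDA op handler: it receives only the backend context and the destination tensor, and pulls Q, K, V and the mask out of dst->src[]. Keeping every handler on the same (ctx, dst) shape is what lets the switch in ggml_cuda_compute_forward (first hunk) stay uniform. A schematic sketch of the pattern, with simplified toy types rather than ggml's real definitions:

    // schematic only: mirrors the (ctx, dst) handler convention with toy types
    struct tensor {
        int      op;
        tensor * src[4]; // operands, e.g. Q, K, V, mask for flash attention
    };

    struct cuda_context { /* streams, pools, ... */ };

    static void op_flash_attn_ext(cuda_context & ctx, tensor * dst) {
        tensor * Q    = dst->src[0];
        tensor * K    = dst->src[1];
        tensor * V    = dst->src[2];
        tensor * mask = dst->src[3];
        // ... launch the kernels on ctx's stream, write the result into dst ...
        (void) ctx; (void) Q; (void) K; (void) V; (void) mask;
    }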

@@ -1,6 +1,3 @@
 #include "common.cuh"
 
-void ggml_cuda_flash_attn_ext(
-    ggml_backend_cuda_context & ctx,
-    const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V,
-    const ggml_tensor * mask, ggml_tensor * KQV);
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);