This commit is contained in:
Milot Mirdita 2025-01-12 23:46:06 +05:00 committed by GitHub
commit da74ed47a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 45 additions and 13 deletions

View File

@@ -149,6 +149,7 @@ set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
"ggml: max. batch size for using peer access")
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
option(GGML_CUDA_FA "ggml: compile with FlashAttention" ON)
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})

View File

@@ -28,11 +28,21 @@ if (CUDAToolkit_FOUND)
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
file(GLOB GGML_SOURCES_CUDA "*.cu")
if (GGML_CUDA_FA)
file(GLOB SRCS "template-instances/fattn-wmma*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
else()
list(FILTER GGML_SOURCES_CUDA EXCLUDE REGEX ".*fattn.*")
list(FILTER GGML_HEADERS_CUDA EXCLUDE REGEX ".*fattn.*")
endif()
if (NOT GGML_CUDA_FORCE_CUBLAS)
file(GLOB SRCS "template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
endif()
if (GGML_CUDA_FA)
add_compile_definitions(GGML_CUDA_FA)
if (GGML_CUDA_FA_ALL_QUANTS)
file(GLOB SRCS "template-instances/fattn-vec*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
@@ -45,6 +55,7 @@ if (CUDAToolkit_FOUND)
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
endif()
endif()
ggml_add_backend_library(ggml-cuda
${GGML_HEADERS_CUDA}

View File

@@ -151,6 +151,10 @@ typedef float2 dfloat2;
#define FLASH_ATTN_AVAILABLE
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#if !defined(GGML_CUDA_FA)
#undef FLASH_ATTN_AVAILABLE
#endif
// Reports whether the device identified by compute capability `cc`
// qualifies for the fast-FP16 code paths: Pascal or newer, with
// compute capability 610 explicitly excluded.
static constexpr bool fast_fp16_available(const int cc) {
    const bool at_least_pascal = cc >= GGML_CUDA_CC_PASCAL;
    const bool is_excluded_cc  = cc == 610;
    return at_least_pascal && !is_excluded_cc;
}

View File

@@ -16,7 +16,9 @@
#include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cross-entropy-loss.cuh"
#include "ggml-cuda/diagmask.cuh"
#ifdef FLASH_ATTN_AVAILABLE
#include "ggml-cuda/fattn.cuh"
#endif
#include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh"
#include "ggml-cuda/mmq.cuh"
@@ -2160,8 +2162,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_argsort(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
#ifdef FLASH_ATTN_AVAILABLE
ggml_cuda_flash_attn_ext(ctx, dst);
break;
#else
return false;
#endif
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;

View File

@@ -1,5 +1,12 @@
#include "mmq.cuh"
#ifdef GGML_CUDA_FORCE_CUBLAS
// No-op stub compiled when GGML_CUDA_FORCE_CUBLAS is defined (see the
// #ifdef above): the mul-mat-q path is not built in that configuration,
// so this empty definition keeps the symbol available to callers.
// All parameters are intentionally unnamed since the body uses none of them.
void ggml_cuda_op_mul_mat_q(
ggml_backend_cuda_context &,
const ggml_tensor *, const ggml_tensor *, ggml_tensor *, const char *, const float *,
const char *, float *, const int64_t, const int64_t, const int64_t,
const int64_t, cudaStream_t) {}
#else
void ggml_cuda_op_mul_mat_q(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
@@ -94,6 +101,7 @@ void ggml_cuda_op_mul_mat_q(
GGML_UNUSED(dst);
GGML_UNUSED(src1_ddf_i);
}
#endif // GGML_CUDA_FORCE_CUBLAS
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
#ifdef GGML_CUDA_FORCE_CUBLAS

View File

@@ -2906,6 +2906,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
#define DECL_MMQ_CASE(type) \
template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
#if !defined(GGML_CUDA_FORCE_CUBLAS)
extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
@@ -2924,6 +2925,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
#endif
// -------------------------------------------------------------------------------------------------------------------------