diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index fe8acc803..2dd4fc739 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -149,6 +149,7 @@
 set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                    "ggml: max. batch size for using peer access")
 option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
 option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
+option(GGML_CUDA_FA "ggml: compile with FlashAttention" ON)
 option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
 option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index 14761650f..a8df3a221 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -28,24 +28,35 @@ if (CUDAToolkit_FOUND)
     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
 
     file(GLOB GGML_SOURCES_CUDA "*.cu")
-    file(GLOB SRCS "template-instances/fattn-wmma*.cu")
-    list(APPEND GGML_SOURCES_CUDA ${SRCS})
-    file(GLOB SRCS "template-instances/mmq*.cu")
-    list(APPEND GGML_SOURCES_CUDA ${SRCS})
-
-    if (GGML_CUDA_FA_ALL_QUANTS)
-        file(GLOB SRCS "template-instances/fattn-vec*.cu")
+    if (GGML_CUDA_FA)
+        file(GLOB SRCS "template-instances/fattn-wmma*.cu")
         list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
     else()
-        file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
-        list(APPEND GGML_SOURCES_CUDA ${SRCS})
-        file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
+        list(FILTER GGML_SOURCES_CUDA EXCLUDE REGEX ".*fattn.*")
+        list(FILTER GGML_HEADERS_CUDA EXCLUDE REGEX ".*fattn.*")
+        # message(FATAL_ERROR ${GGML_SOURCES_CUDA})
+    endif()
+    if (NOT GGML_CUDA_FORCE_CUBLAS)
+        file(GLOB SRCS "template-instances/mmq*.cu")
         list(APPEND GGML_SOURCES_CUDA ${SRCS})
     endif()
 
+    if (GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_FA)
+        if (GGML_CUDA_FA_ALL_QUANTS)
+            file(GLOB SRCS "template-instances/fattn-vec*.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+        else()
+            file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+            file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
+            list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        endif()
+    endif()
+
     ggml_add_backend_library(ggml-cuda
                              ${GGML_HEADERS_CUDA}
                              ${GGML_SOURCES_CUDA}
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 2c0a56226..9508a5262 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -151,6 +151,10 @@ typedef float2 dfloat2;
 #define FLASH_ATTN_AVAILABLE
 #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 
+#if !defined(GGML_CUDA_FA)
+#undef FLASH_ATTN_AVAILABLE
+#endif
+
 static constexpr bool fast_fp16_available(const int cc) {
     return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
 }
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 8476ee1bc..43cba1cf7 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -16,7 +16,9 @@
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/cross-entropy-loss.cuh"
 #include "ggml-cuda/diagmask.cuh"
+#ifdef FLASH_ATTN_AVAILABLE
 #include "ggml-cuda/fattn.cuh"
+#endif
 #include "ggml-cuda/getrows.cuh"
#include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmq.cuh" @@ -2160,8 +2162,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_argsort(ctx, dst); break; case GGML_OP_FLASH_ATTN_EXT: +#ifdef FLASH_ATTN_AVAILABLE ggml_cuda_flash_attn_ext(ctx, dst); break; +#else + return false; +#endif case GGML_OP_CROSS_ENTROPY_LOSS: ggml_cuda_cross_entropy_loss(ctx, dst); break; diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 270251df4..4bd540f48 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -1,5 +1,12 @@ #include "mmq.cuh" +#ifdef GGML_CUDA_FORCE_CUBLAS +void ggml_cuda_op_mul_mat_q( + ggml_backend_cuda_context &, + const ggml_tensor *, const ggml_tensor *, ggml_tensor *, const char *, const float *, + const char *, float *, const int64_t, const int64_t, const int64_t, + const int64_t, cudaStream_t) {} +#else void ggml_cuda_op_mul_mat_q( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, @@ -94,6 +101,7 @@ void ggml_cuda_op_mul_mat_q( GGML_UNUSED(dst); GGML_UNUSED(src1_ddf_i); } +#endif // GGML_CUDA_FORCE_CUBLAS bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { #ifdef GGML_CUDA_FORCE_CUBLAS diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 3cd508a1d..6c150dffb 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2906,6 +2906,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda #define DECL_MMQ_CASE(type) \ template void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \ +#if !defined(GGML_CUDA_FORCE_CUBLAS) extern DECL_MMQ_CASE(GGML_TYPE_Q4_0); extern DECL_MMQ_CASE(GGML_TYPE_Q4_1); extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); @@ -2924,6 +2925,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); +#endif // -------------------------------------------------------------------------------------------------------------------------