mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 04:00:16 +00:00
Merge 8dfe3d8e97
into 924518e2e5
This commit is contained in:
commit
da74ed47a4
@ -149,6 +149,7 @@ set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
|
||||
"ggml: max. batch size for using peer access")
|
||||
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
|
||||
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
|
||||
option(GGML_CUDA_FA "ggml: compile with FlashAttention" ON)
|
||||
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
|
||||
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
|
||||
|
||||
|
@ -28,24 +28,35 @@ if (CUDAToolkit_FOUND)
|
||||
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
|
||||
|
||||
file(GLOB GGML_SOURCES_CUDA "*.cu")
|
||||
file(GLOB SRCS "template-instances/fattn-wmma*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/mmq*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
|
||||
if (GGML_CUDA_FA_ALL_QUANTS)
|
||||
file(GLOB SRCS "template-instances/fattn-vec*.cu")
|
||||
if (GGML_CUDA_FA)
|
||||
file(GLOB SRCS "template-instances/fattn-wmma*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
|
||||
else()
|
||||
file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
|
||||
list(FILTER GGML_SOURCES_CUDA EXCLUDE REGEX ".*fattn.*")
|
||||
list(FILTER GGML_HEADERS_CUDA EXCLUDE REGEX ".*fattn.*")
|
||||
# message(FATAL_ERROR ${GGML_SOURCES_CUDA})
|
||||
endif()
|
||||
if (NOT GGML_CUDA_FORCE_CUBLAS)
|
||||
file(GLOB SRCS "template-instances/mmq*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_FA)
|
||||
add_compile_definitions(GGML_CUDA_FA)
|
||||
if (GGML_CUDA_FA_ALL_QUANTS)
|
||||
file(GLOB SRCS "template-instances/fattn-vec*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
|
||||
else()
|
||||
file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
ggml_add_backend_library(ggml-cuda
|
||||
${GGML_HEADERS_CUDA}
|
||||
${GGML_SOURCES_CUDA}
|
||||
|
@ -151,6 +151,10 @@ typedef float2 dfloat2;
|
||||
#define FLASH_ATTN_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
|
||||
#if !defined(GGML_CUDA_FA)
|
||||
#undef FLASH_ATTN_AVAILABLE
|
||||
#endif
|
||||
|
||||
static constexpr bool fast_fp16_available(const int cc) {
|
||||
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
|
||||
}
|
||||
|
@ -16,7 +16,9 @@
|
||||
#include "ggml-cuda/cpy.cuh"
|
||||
#include "ggml-cuda/cross-entropy-loss.cuh"
|
||||
#include "ggml-cuda/diagmask.cuh"
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
#include "ggml-cuda/fattn.cuh"
|
||||
#endif
|
||||
#include "ggml-cuda/getrows.cuh"
|
||||
#include "ggml-cuda/im2col.cuh"
|
||||
#include "ggml-cuda/mmq.cuh"
|
||||
@ -2160,8 +2162,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
ggml_cuda_op_argsort(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
ggml_cuda_flash_attn_ext(ctx, dst);
|
||||
break;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
ggml_cuda_cross_entropy_loss(ctx, dst);
|
||||
break;
|
||||
|
@ -1,5 +1,12 @@
|
||||
#include "mmq.cuh"
|
||||
|
||||
#ifdef GGML_CUDA_FORCE_CUBLAS
|
||||
void ggml_cuda_op_mul_mat_q(
|
||||
ggml_backend_cuda_context &,
|
||||
const ggml_tensor *, const ggml_tensor *, ggml_tensor *, const char *, const float *,
|
||||
const char *, float *, const int64_t, const int64_t, const int64_t,
|
||||
const int64_t, cudaStream_t) {}
|
||||
#else
|
||||
void ggml_cuda_op_mul_mat_q(
|
||||
ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
||||
@ -94,6 +101,7 @@ void ggml_cuda_op_mul_mat_q(
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(src1_ddf_i);
|
||||
}
|
||||
#endif // GGML_CUDA_FORCE_CUBLAS
|
||||
|
||||
bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
||||
#ifdef GGML_CUDA_FORCE_CUBLAS
|
||||
|
@ -2906,6 +2906,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
||||
#define DECL_MMQ_CASE(type) \
|
||||
template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
|
||||
|
||||
#if !defined(GGML_CUDA_FORCE_CUBLAS)
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
|
||||
@ -2924,6 +2925,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
|
||||
extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
|
||||
#endif
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user