diff --git a/ggml/src/ggml-cuda/sum.cu b/ggml/src/ggml-cuda/sum.cu index 21da63509..0583e4fe0 100644 --- a/ggml/src/ggml-cuda/sum.cu +++ b/ggml/src/ggml-cuda/sum.cu @@ -1,9 +1,13 @@ -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700 +#define USE_CUB +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700 + +#ifdef USE_CUB // On Windows CUB uses libraries with variables called CC_PASCAL which conflict with the define in common.cuh. // For this reason CUB must be included BEFORE anything else. #include using namespace cub; -#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) +#endif // USE_CUB #include "sumrows.cuh" #include "sum.cuh" @@ -11,7 +15,7 @@ using namespace cub; #include void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) { -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) +#ifdef USE_CUB size_t tmp_size = 0; DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream); ggml_cuda_pool_alloc tmp_alloc(pool, tmp_size); @@ -21,7 +25,7 @@ void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14. sum_rows_f32_cuda(x, dst, ne, 1, stream); GGML_UNUSED(pool); -#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) +#endif // USE_CUB } void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {