Define and optimize RDNA1 (#8085)

This commit is contained in:
Daniele 2024-07-03 23:02:58 +00:00 committed by GitHub
parent 5f2d4e60e2
commit d23287f122
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 3 deletions

View File

@ -227,6 +227,10 @@ typedef float2 dfloat2;
#define RDNA2 #define RDNA2
#endif #endif
#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif
#ifndef __has_builtin #ifndef __has_builtin
#define __has_builtin(x) 0 #define __has_builtin(x) 0
#endif #endif

View File

@ -60,12 +60,16 @@ static constexpr __device__ int get_mmq_x_max_device() {
} }
static constexpr int get_mmq_y_host(const int cc) { static constexpr int get_mmq_y_host(const int cc) {
return int8_mma_available(cc) || cc >= CC_VOLTA ? 128 : 64; return cc >= CC_OFFSET_AMD ? (cc == CC_RDNA1 ? 64 : 128) : (cc >= CC_VOLTA ? 128 : 64);
} }
static constexpr __device__ int get_mmq_y_device() { static constexpr __device__ int get_mmq_y_device() {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA1)
return 64;
#else
return 128; return 128;
#endif // defined RDNA1
#else #else
#if __CUDA_ARCH__ >= CC_VOLTA #if __CUDA_ARCH__ >= CC_VOLTA
return 128; return 128;
@ -2259,9 +2263,9 @@ static __device__ void mul_mat_q_process_tile(
template <ggml_type type, int mmq_x, int nwarps, bool need_check> template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2) #if defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
__launch_bounds__(WARP_SIZE*nwarps, 2) __launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2) #endif // defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
#else #else
#if __CUDA_ARCH__ >= CC_VOLTA #if __CUDA_ARCH__ >= CC_VOLTA
__launch_bounds__(WARP_SIZE*nwarps, 1) __launch_bounds__(WARP_SIZE*nwarps, 1)