mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 21:39:52 +00:00
cuda : replace asserts in wrong architecture checks with __trap (#4556)
* cuda : replace asserts in wrong architecture checks with __trap * make bad_arch noreturn, remove returns
This commit is contained in:
parent
d3223afdad
commit
1398823922
82
ggml-cuda.cu
82
ggml-cuda.cu
@ -512,6 +512,14 @@ static size_t g_scratch_offset = 0;
|
||||
|
||||
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
||||
|
||||
[[noreturn]]
|
||||
static __device__ void bad_arch() {
|
||||
printf("ERROR: ggml-cuda was compiled without support for the current GPU architecture.\n");
|
||||
__trap();
|
||||
|
||||
(void) bad_arch; // suppress unused function warning
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float warp_reduce_sum(float x) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
@ -1972,8 +1980,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
|
||||
// second part effectively subtracts 8 from each quant value
|
||||
return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2010,8 +2017,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
|
||||
// scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
|
||||
return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2046,8 +2052,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
|
||||
// second part effectively subtracts 16 from each quant value
|
||||
return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2092,8 +2097,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
|
||||
return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
|
||||
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2114,8 +2118,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
|
||||
|
||||
return d8_0*d8_1 * sumi;
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2145,8 +2148,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
|
||||
// scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
|
||||
return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2181,8 +2183,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
|
||||
|
||||
return dm2f.x*sumf_d - dm2f.y*sumf_m;
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2219,8 +2220,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
|
||||
|
||||
return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2260,8 +2260,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
|
||||
|
||||
return d3 * sumf;
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2286,8 +2285,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
|
||||
|
||||
return d3*d8 * sumi;
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2320,8 +2318,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
|
||||
return dm4f.x*sumf_d - dm4f.y*sumf_m;
|
||||
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2354,8 +2351,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
|
||||
return dm4f.x*sumf_d - dm4f.y*sumf_m;
|
||||
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2395,8 +2391,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
|
||||
return dm5f.x*sumf_d - dm5f.y*sumf_m;
|
||||
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2429,8 +2424,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
|
||||
return dm4f.x*sumf_d - dm4f.y*sumf_m;
|
||||
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2460,8 +2454,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
|
||||
|
||||
return d*sumf;
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -2492,8 +2485,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
|
||||
return d6 * sumf_d;
|
||||
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
}
|
||||
|
||||
@ -3359,8 +3351,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
|
||||
return dall * sumf_d - dmin * sumf_m;
|
||||
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
|
||||
#endif
|
||||
@ -3543,8 +3534,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
|
||||
return d * sumf_d;
|
||||
|
||||
#else
|
||||
assert(false);
|
||||
return 0.0f; // only to satisfy the compiler
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
|
||||
|
||||
#endif
|
||||
@ -3954,7 +3944,7 @@ template <bool need_check> static __global__ void
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
#else
|
||||
(void) vec_dot_q4_0_q8_1_mul_mat;
|
||||
assert(false);
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||
}
|
||||
|
||||
@ -4023,7 +4013,7 @@ template <bool need_check> static __global__ void
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
#else
|
||||
(void) vec_dot_q4_1_q8_1_mul_mat;
|
||||
assert(false);
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||
}
|
||||
|
||||
@ -4090,7 +4080,7 @@ template <bool need_check> static __global__ void
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
#else
|
||||
(void) vec_dot_q5_0_q8_1_mul_mat;
|
||||
assert(false);
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||
}
|
||||
|
||||
@ -4157,7 +4147,7 @@ mul_mat_q5_1(
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
#else
|
||||
(void) vec_dot_q5_1_q8_1_mul_mat;
|
||||
assert(false);
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||
}
|
||||
|
||||
@ -4224,7 +4214,7 @@ template <bool need_check> static __global__ void
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
#else
|
||||
(void) vec_dot_q8_0_q8_1_mul_mat;
|
||||
assert(false);
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||
}
|
||||
|
||||
@ -4291,7 +4281,7 @@ mul_mat_q2_K(
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
#else
|
||||
(void) vec_dot_q2_K_q8_1_mul_mat;
|
||||
assert(false);
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||
}
|
||||
|
||||
@ -4360,7 +4350,7 @@ template <bool need_check> static __global__ void
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
#else
|
||||
(void) vec_dot_q3_K_q8_1_mul_mat;
|
||||
assert(false);
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||
}
|
||||
|
||||
@ -4429,7 +4419,7 @@ template <bool need_check> static __global__ void
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
#else
|
||||
(void) vec_dot_q4_K_q8_1_mul_mat;
|
||||
assert(false);
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||
}
|
||||
|
||||
@ -4496,7 +4486,7 @@ mul_mat_q5_K(
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
#else
|
||||
(void) vec_dot_q5_K_q8_1_mul_mat;
|
||||
assert(false);
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||
}
|
||||
|
||||
@ -4565,7 +4555,7 @@ template <bool need_check> static __global__ void
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
#else
|
||||
(void) vec_dot_q6_K_q8_1_mul_mat;
|
||||
assert(false);
|
||||
bad_arch();
|
||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user