mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
CUDA: fix MMQ writeback for int8 tensor cores (#8100)
This commit is contained in:
parent
a818f3028d
commit
3b099bcd9c
@ -2054,15 +2054,13 @@ static __device__ __forceinline__ void mmq_write_back_mma(
|
|||||||
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
|
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
|
||||||
#endif // INT8_MMA_AVAILABLE
|
#endif // INT8_MMA_AVAILABLE
|
||||||
|
|
||||||
dst += (threadIdx.y % ntx) * mma_C::J*stride;
|
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
|
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int n = 0; n < ntx; ++n) {
|
for (int n = 0; n < ntx; ++n) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int l = 0; l < mma_C::ne; ++l) {
|
for (int l = 0; l < mma_C::ne; ++l) {
|
||||||
const int j = j0 + mma_C::get_j(l);
|
const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
|
||||||
|
|
||||||
if (j > j_max) {
|
if (j > j_max) {
|
||||||
continue;
|
continue;
|
||||||
|
Loading…
Reference in New Issue
Block a user