CUDA: fix MMQ writeback for int8 tensor cores (#8100)

This commit is contained in:
Johannes Gäßler 2024-06-24 22:15:33 +02:00 committed by GitHub
parent a818f3028d
commit 3b099bcd9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2054,15 +2054,13 @@ static __device__ __forceinline__ void mmq_write_back_mma(
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
#endif // INT8_MMA_AVAILABLE #endif // INT8_MMA_AVAILABLE
dst += (threadIdx.y % ntx) * mma_C::J*stride;
#pragma unroll #pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
#pragma unroll #pragma unroll
for (int n = 0; n < ntx; ++n) { for (int n = 0; n < ntx; ++n) {
#pragma unroll #pragma unroll
for (int l = 0; l < mma_C::ne; ++l) { for (int l = 0; l < mma_C::ne; ++l) {
const int j = j0 + mma_C::get_j(l); const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
if (j > j_max) { if (j > j_max) {
continue; continue;