mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-31 22:04:35 +00:00
Improve performance with better q4_k and q5_k dequant and store unrolling
This commit is contained in:
parent
2455bbc8bd
commit
c704d58112
@ -1861,7 +1861,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||||||
if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&
|
if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&
|
||||||
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&
|
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&
|
||||||
(vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
|
(vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
|
||||||
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32
|
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32 &&
|
||||||
|
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
|
||||||
) {
|
) {
|
||||||
device->coop_mat_m = prop.MSize;
|
device->coop_mat_m = prop.MSize;
|
||||||
device->coop_mat_n = prop.NSize;
|
device->coop_mat_n = prop.NSize;
|
||||||
|
@ -365,15 +365,20 @@ void main() {
|
|||||||
|
|
||||||
const vec2 loadd = vec2(data_a[ib].d);
|
const vec2 loadd = vec2(data_a[ib].d);
|
||||||
|
|
||||||
uint8_t sc;
|
const uint scidx0 = (is < 4) ? is : (is + 4);
|
||||||
uint8_t mbyte;
|
const uint scidx1 = (is < 4) ? is : (is - 4);
|
||||||
if (is < 4) {
|
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||||
sc = uint8_t(data_a[ib].scales[is ] & 63);
|
const uint scidxshift1 = (is < 4) ? 0 : 2;
|
||||||
mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
|
const uint mbidx0 = is + 4;
|
||||||
} else {
|
const uint mbidx1 = (is < 4) ? is + 4 : is;
|
||||||
sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
|
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
||||||
mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
|
const uint mbidxshift0 = (is < 4) ? 0 : 4;
|
||||||
}
|
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||||
|
const uint mbidxshift1 = (is < 4) ? 0 : 2;
|
||||||
|
|
||||||
|
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
|
||||||
|
const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
|
||||||
|
|
||||||
const float d = loadd.x * sc;
|
const float d = loadd.x * sc;
|
||||||
const float m = -loadd.y * mbyte;
|
const float m = -loadd.y * mbyte;
|
||||||
|
|
||||||
@ -396,15 +401,20 @@ void main() {
|
|||||||
|
|
||||||
const vec2 loadd = vec2(data_a[ib].d);
|
const vec2 loadd = vec2(data_a[ib].d);
|
||||||
|
|
||||||
uint8_t sc;
|
const uint scidx0 = (is < 4) ? is : (is + 4);
|
||||||
uint8_t mbyte;
|
const uint scidx1 = (is < 4) ? is : (is - 4);
|
||||||
if (is < 4) {
|
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||||
sc = uint8_t(data_a[ib].scales[is ] & 63);
|
const uint scidxshift1 = (is < 4) ? 0 : 2;
|
||||||
mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
|
const uint mbidx0 = is + 4;
|
||||||
} else {
|
const uint mbidx1 = (is < 4) ? is + 4 : is;
|
||||||
sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
|
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
||||||
mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
|
const uint mbidxshift0 = (is < 4) ? 0 : 4;
|
||||||
}
|
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||||
|
const uint mbidxshift1 = (is < 4) ? 0 : 2;
|
||||||
|
|
||||||
|
const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
|
||||||
|
const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
|
||||||
|
|
||||||
const float d = loadd.x * sc;
|
const float d = loadd.x * sc;
|
||||||
const float m = -loadd.y * mbyte;
|
const float m = -loadd.y * mbyte;
|
||||||
|
|
||||||
@ -547,8 +557,8 @@ void main() {
|
|||||||
|
|
||||||
#ifdef COOPMAT
|
#ifdef COOPMAT
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||||
for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
|
||||||
|
|
||||||
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
|
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
|
||||||
@ -564,8 +574,8 @@ void main() {
|
|||||||
#else
|
#else
|
||||||
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
|
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
|
||||||
|
|
||||||
for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||||
for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||||
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
|
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
|
||||||
|
|
||||||
if (is_aligned && is_in_bounds) {
|
if (is_aligned && is_in_bounds) {
|
||||||
|
Loading…
Reference in New Issue
Block a user