mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 10:24:35 +00:00
ggml : fix quant dot product with odd number of blocks (#8549)
* ggml : fix iq4_nl dot product with odd number of blocks * ggml : fix odd blocks for ARM_NEON (#8556) * ggml : fix iq4_nl dot product with odd number of blocks * ggml : fix q4_1 * ggml : fix q5_0 * ggml : fix q5_1 * ggml : fix iq4_nl metal ggml-ci * ggml : fix q4_0 * ggml : fix q8_0 ggml-ci * ggml : remove special Q4_0 code for first 2 blocks * ggml : fix sumf redefinition --------- Co-authored-by: slaren <slarengh@gmail.com> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
57b1d4f9eb
commit
87e397d00b
@ -1786,10 +1786,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
}
|
||||
};
|
||||
|
||||
if (ggml_is_quantized(src0t)) {
|
||||
GGML_ASSERT(ne00 >= nth0*nth1);
|
||||
}
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
|
@ -4757,7 +4757,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
device const float4 * y4 = (device const float4 *)yb;
|
||||
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
|
||||
|
||||
device const block_iq4_nl & xb = x[row*nb + ib];
|
||||
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
|
||||
@ -4789,7 +4789,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
yb += 16 * QK4_NL;
|
||||
}
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -79,8 +79,16 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
|
||||
im = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], im);
|
||||
GGML_ASSERT(ggml_validate_row_data(tensor->type, dataq.data(), dataq.size()));
|
||||
// TODO: other cases
|
||||
//#pragma omp parallel for
|
||||
//for (int i = 0; i < tensor->ne[1]; i++) {
|
||||
// ggml_quantize_chunk(tensor->type, data.data(), dataq.data(),
|
||||
// i * tensor->ne[0], 1, tensor->ne[0], im);
|
||||
//}
|
||||
|
||||
ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
|
||||
} else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
|
||||
// This is going to create some weird integers though.
|
||||
@ -2220,6 +2228,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
|
||||
}
|
||||
|
||||
#if 1
|
||||
for (ggml_type type_a : base_types) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
|
||||
@ -2239,6 +2248,24 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
|
||||
}
|
||||
}
|
||||
#else
|
||||
// m = a rows
|
||||
// n = b rows
|
||||
// k = cols
|
||||
std::uniform_int_distribution<> dist_m(1, 128);
|
||||
std::uniform_int_distribution<> dist_n(16, 128);
|
||||
std::uniform_int_distribution<> dist_k(1, 16);
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
for (ggml_type type_a : all_types) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
int m = dist_m(rng);
|
||||
int n = dist_n(rng);
|
||||
int k = dist_k(rng) * ggml_blck_size(type_a);
|
||||
test_cases.emplace_back(new test_mul_mat(type_a, type_b, m, n, k, { 1, 1}, {1, 1}));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
for (ggml_type type_a : other_types) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32}) {
|
||||
|
Loading…
Reference in New Issue
Block a user