imatrix : handle partial entries (#7833)

This commit is contained in:
Georgi Gerganov 2024-06-09 20:19:35 +03:00 committed by GitHub
parent 57bf62ce7c
commit e95beeb1fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -218,20 +218,64 @@ void IMatrixCollector::save_imatrix(int ncall) const {
fname += std::to_string(ncall); fname += std::to_string(ncall);
} }
// avoid writing imatrix entries that do not have full data
// this can happen with MoE models where some of the experts end up not being exercised by the provided training data
int n_entries = 0;
std::vector<std::string> to_store;
bool is_first = true; // for printing
for (const auto & kv : m_stats) {
const int n_all = kv.second.counts.size();
if (n_all == 0) {
continue;
}
int n_zeros = 0;
for (const int c : kv.second.counts) {
if (c == 0) {
n_zeros++;
}
}
if (n_zeros != 0 && is_first) {
fprintf(stderr, "\n");
is_first = false;
}
if (n_zeros == n_all) {
fprintf(stderr, "%s: entry '%40s' has no data - skipping\n", __func__, kv.first.c_str());
continue;
}
if (n_zeros > 0) {
fprintf(stderr, "%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all);
continue;
}
n_entries++;
to_store.push_back(kv.first);
}
if (to_store.size() < m_stats.size()) {
fprintf(stderr, "%s: warning: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size());
}
std::ofstream out(fname, std::ios::binary); std::ofstream out(fname, std::ios::binary);
int n_entries = m_stats.size();
out.write((const char *) &n_entries, sizeof(n_entries)); out.write((const char *) &n_entries, sizeof(n_entries));
for (const auto & p : m_stats) { for (const auto & name : to_store) {
int len = p.first.size(); const auto & stat = m_stats.at(name);
int len = name.size();
out.write((const char *) &len, sizeof(len)); out.write((const char *) &len, sizeof(len));
out.write(p.first.c_str(), len); out.write(name.c_str(), len);
out.write((const char *) &p.second.ncall, sizeof(p.second.ncall)); out.write((const char *) &stat.ncall, sizeof(stat.ncall));
int nval = p.second.values.size(); int nval = stat.values.size();
out.write((const char *) &nval, sizeof(nval)); out.write((const char *) &nval, sizeof(nval));
if (nval > 0) { if (nval > 0) {
std::vector<float> tmp(nval); std::vector<float> tmp(nval);
for (int i = 0; i < nval; i++) { for (int i = 0; i < nval; i++) {
tmp[i] = (p.second.values[i] / static_cast<float>(p.second.counts[i])) * static_cast<float>(p.second.ncall); tmp[i] = (stat.values[i] / static_cast<float>(stat.counts[i])) * static_cast<float>(stat.ncall);
} }
out.write((const char*)tmp.data(), nval*sizeof(float)); out.write((const char*)tmp.data(), nval*sizeof(float));
} }