Skip to content

Commit

Permalink
imatrix : handle partial entries (ggerganov#7833)
Browse files Browse the repository at this point in the history
  • Loading branch information
Georgi Gerganov authored Jun 9, 2024
1 parent 57bf62c commit e95beeb
Showing 1 changed file with 51 additions and 7 deletions.
58 changes: 51 additions & 7 deletions examples/imatrix/imatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,20 +218,64 @@ void IMatrixCollector::save_imatrix(int ncall) const {
fname += std::to_string(ncall);
}

// avoid writing imatrix entries that do not have full data
// this can happen with MoE models where some of the experts end up not being exercised by the provided training data

int n_entries = 0;
std::vector<std::string> to_store;

bool is_first = true; // for printing
for (const auto & kv : m_stats) {
const int n_all = kv.second.counts.size();

if (n_all == 0) {
continue;
}

int n_zeros = 0;
for (const int c : kv.second.counts) {
if (c == 0) {
n_zeros++;
}
}

if (n_zeros != 0 && is_first) {
fprintf(stderr, "\n");
is_first = false;
}

if (n_zeros == n_all) {
fprintf(stderr, "%s: entry '%40s' has no data - skipping\n", __func__, kv.first.c_str());
continue;
}

if (n_zeros > 0) {
fprintf(stderr, "%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all);
continue;
}

n_entries++;
to_store.push_back(kv.first);
}

if (to_store.size() < m_stats.size()) {
fprintf(stderr, "%s: warning: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size());
}

std::ofstream out(fname, std::ios::binary);
int n_entries = m_stats.size();
out.write((const char *) &n_entries, sizeof(n_entries));
for (const auto & p : m_stats) {
int len = p.first.size();
for (const auto & name : to_store) {
const auto & stat = m_stats.at(name);
int len = name.size();
out.write((const char *) &len, sizeof(len));
out.write(p.first.c_str(), len);
out.write((const char *) &p.second.ncall, sizeof(p.second.ncall));
int nval = p.second.values.size();
out.write(name.c_str(), len);
out.write((const char *) &stat.ncall, sizeof(stat.ncall));
int nval = stat.values.size();
out.write((const char *) &nval, sizeof(nval));
if (nval > 0) {
std::vector<float> tmp(nval);
for (int i = 0; i < nval; i++) {
tmp[i] = (p.second.values[i] / static_cast<float>(p.second.counts[i])) * static_cast<float>(p.second.ncall);
tmp[i] = (stat.values[i] / static_cast<float>(stat.counts[i])) * static_cast<float>(stat.ncall);
}
out.write((const char*)tmp.data(), nval*sizeof(float));
}
Expand Down

0 comments on commit e95beeb

Please sign in to comment.