From 7122afed5a89c082fac028ab152cc50af3e57386 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 4 Apr 2024 07:16:41 -0700 Subject: [PATCH] Add note on weight update and improve error message PiperOrigin-RevId: 621849989 --- README.md | 5 +++++ compression/blob_store.cc | 8 +++++++- gemma.cc | 6 ++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 5b0df182..7c9b3949 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,11 @@ For additional information about Gemma, see specific artifacts, are [available on kaggle](https://www.kaggle.com/models/google/gemma). +NOTE: 2024-04-04: if using 2B models, please re-download weights from Kaggle and +ensure you have the latest version (-mqa or version 3). We are changing the code +to match the new weights. If you wish to use old weights, change `ConfigGemma2B` +in `configs.h` back to `kVocabSize = 256128` and `kKVHeads = 8`. + ## Who is this project for? Modern LLM inference engines are sophisticated systems, often with bespoke diff --git a/compression/blob_store.cc b/compression/blob_store.cc index 050dfbd4..b47515a8 100644 --- a/compression/blob_store.cc +++ b/compression/blob_store.cc @@ -378,7 +378,13 @@ BlobError BlobReader::Enqueue(hwy::uint128_t key, void* data, size_t size) { uint64_t offset; size_t actual_size; if (!blob_store_->FindKey(key, offset, actual_size)) return __LINE__; - if (actual_size != size) return __LINE__; + if (actual_size != size) { + fprintf(stderr, + "Mismatch between expected %d and actual %d KiB size. Please see " + "README.md on how to update the weights.\n", + static_cast(size >> 10), static_cast(actual_size >> 10)); + return __LINE__; + } EnqueueChunkRequests(offset, actual_size, reinterpret_cast(data), requests_); diff --git a/gemma.cc b/gemma.cc index 23fc5971..bfaa8122 100644 --- a/gemma.cc +++ b/gemma.cc @@ -109,6 +109,12 @@ hwy::AlignedUniquePtr> LoadWeights(const Path& checkpoint) { weights->layers = hwy::MakeUniqueAlignedArray>(TConfig::kLayers); + if (checkpoint.path.empty()) { + HWY_ABORT( + "Loading --compressed_weights failed; we require a --weights argument. " + "Please see issue #11 on how to create this file.\n"); + } + FILE* fptr; fptr = fopen(checkpoint.path.c_str(), "rb"); if (fptr == nullptr) {