Separate KV cache from GemmaImpl #81
This is an interesting idea. I'm not at my desk today, but here are a few pointers to be aware of when considering the design:
I won't have a chance to take a closer look before tonight, but these may be of interest. You can also find us on Discord if it's useful to chat: https://discord.gg/H5jCBAWxAe
Reading and writing the KV cache from disk is useful, but in many use cases we may want to serialize/deserialize to another medium, such as a remote DB. So if I use gemma.cpp as a library, I would prefer serialize/deserialize methods. 😊
Agreed, this would provide better flexibility and apply more easily to more scenarios. As I mentioned in the PR, it's currently not possible for a single Gemma instance to handle multiple ongoing multi-turn conversations. To handle multiple ongoing conversations, I have to load multiple copies of the weights, which consumes a lot of memory. In this use case, the serialization/deserialization/save/load methods are not needed for now; we only need to temporarily store the KV cache for each conversation in memory.
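To make the serialize/deserialize suggestion above concrete, here is a minimal sketch of a stream-based interface, which would let a caller target a file, a socket, or a blob destined for a remote DB. Everything in it is an assumption for illustration: `ToyKVCache`, `Serialize`, `Deserialize`, and the length-prefixed layout are hypothetical and are not part of gemma.cpp. The format uses native endianness and does not validate the length it reads, so it is only suitable for trusted data on the same platform.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <istream>
#include <ostream>
#include <sstream>
#include <vector>

// Hypothetical stand-in for a KV cache; not the gemma.cpp KVCache type.
struct ToyKVCache {
  std::vector<float> keys;
  std::vector<float> values;
};

// Writes a 64-bit element count followed by the raw floats.
void WriteVector(std::ostream& out, const std::vector<float>& v) {
  const std::uint64_t n = v.size();
  out.write(reinterpret_cast<const char*>(&n), sizeof(n));
  if (n != 0) {
    out.write(reinterpret_cast<const char*>(v.data()),
              static_cast<std::streamsize>(n * sizeof(float)));
  }
}

// Reads a vector written by WriteVector; returns false on a short read.
bool ReadVector(std::istream& in, std::vector<float>& v) {
  std::uint64_t n = 0;
  if (!in.read(reinterpret_cast<char*>(&n), sizeof(n))) return false;
  v.assign(static_cast<std::size_t>(n), 0.0f);
  if (n == 0) return true;
  in.read(reinterpret_cast<char*>(v.data()),
          static_cast<std::streamsize>(n * sizeof(float)));
  return static_cast<bool>(in);
}

// The caller picks the medium by choosing the stream type.
bool Serialize(const ToyKVCache& cache, std::ostream& out) {
  WriteVector(out, cache.keys);
  WriteVector(out, cache.values);
  return static_cast<bool>(out);
}

bool Deserialize(std::istream& in, ToyKVCache& cache) {
  return ReadVector(in, cache.keys) && ReadVector(in, cache.values);
}

// Usage: round-trip a cache through an in-memory buffer.
ToyKVCache RoundTrip(const ToyKVCache& cache) {
  std::stringstream buffer;
  const bool wrote = Serialize(cache, buffer);
  ToyKVCache restored;
  const bool read = wrote && Deserialize(buffer, restored);
  assert(read);
  (void)read;
  return restored;
}
```

Taking streams rather than file paths keeps the storage decision with the library user, which is the flexibility the comment asks for.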
Looks good overall.
The main issue I have is that I don't understand the reasoning for making CreateKVCache a method. Keeping it free seems more in line with the overall goal of composability between how the KV cache is used and its relationship to a model. Maybe there's a reason for this I'm not seeing?
A minor point to be aware of (no change needed): this change does increase the parameter count for some functions. In the past the codebase transitioned away from something similar, moving from more flexible, decoupled, longer signatures to a more consolidated signature (with the Gemma class consolidating state).
My personal preference is to keep state as decoupled as possible even if it means slightly more verbosity, so I'm fine with this direction. Just be aware this PR is partly reverting tradeoffs (reducing verbosity at the expense of coupling state) from earlier iterations.
gemma.cc (outdated)
  PROFILER_ZONE("Startup.tokenizer");

  HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
}

template <class Config>
KVCache GemmaImpl<Config>::CreateKVCache() const {
What's the rationale behind making this a method vs. leaving it as a free function? It seems like it would be more flexible / decoupled keeping it as a free function.
In the original version, the KV cache is created in the constructor of class GemmaImpl, and the size of the cache depends on the structure of the model defined in Config. To process multiple ongoing conversations, each conversation requires a separate KV cache, so we need to create a KV cache before each conversation starts, and its size must match the model structure in use. In my opinion, the intuitive thing to do is to create it via the Gemma instance. However, the Config used by the instance is not yet exposed to users of the high-level API, and they don't need to know the details of the model structure either, so I need to use the dynamic dispatch of GemmaInterface to implement it. That's why I made it a method instead of a free function.
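The reasoning above can be sketched as follows. This is a simplified illustration, not the PR's code: `ToyInterface`, `ToyImpl`, `ToyGemma`, `ToyModelType`, and the config constants are hypothetical stand-ins for GemmaInterface, GemmaImpl, Gemma, and Config. The point is that the caller never names a Config; the virtual call selects the correctly sized cache.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical stand-in for the KV cache.
struct ToyKVCache {
  std::vector<float> data;
};

// Hypothetical configs; the constants are illustrative, not real model sizes.
struct ConfigSmall {
  static constexpr std::size_t kSeqLen = 8;
  static constexpr std::size_t kLayers = 2;
  static constexpr std::size_t kKVDim = 4;
};
struct ConfigLarge {
  static constexpr std::size_t kSeqLen = 16;
  static constexpr std::size_t kLayers = 4;
  static constexpr std::size_t kKVDim = 8;
};

// Type-erased interface, playing the role of GemmaInterface.
struct ToyInterface {
  virtual ~ToyInterface() = default;
  virtual ToyKVCache CreateKVCache() const = 0;
};

// Config-specific implementation, playing the role of GemmaImpl<Config>.
template <class Config>
struct ToyImpl : ToyInterface {
  ToyKVCache CreateKVCache() const override {
    // One key and one value vector per position per layer.
    const std::size_t n =
        Config::kSeqLen * Config::kLayers * Config::kKVDim * 2;
    return ToyKVCache{std::vector<float>(n, 0.0f)};
  }
};

enum class ToyModelType { kSmall, kLarge };

// High-level wrapper: users see only the model type, never the Config.
class ToyGemma {
 public:
  explicit ToyGemma(ToyModelType type) {
    if (type == ToyModelType::kSmall) {
      impl_ = std::make_unique<ToyImpl<ConfigSmall>>();
    } else {
      impl_ = std::make_unique<ToyImpl<ConfigLarge>>();
    }
  }
  // Dynamic dispatch picks the cache size that matches the loaded model.
  ToyKVCache CreateKVCache() const { return impl_->CreateKVCache(); }

 private:
  std::unique_ptr<ToyInterface> impl_;
};
```

The cost of this approach is that cache creation becomes part of the model interface, which is the coupling the reviewer questions.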
That's a fair assessment given the current state of the implementation. Although it's probably not obvious, the Config structs act more like a configuration for the entire runtime: they consist mostly of model parameters, but also fold in kSeqLen and kTopK, which are not really model parameters.
This was an okay design for the initial "vanilla" interactive use case. As we flesh out the library usage, it's easy to imagine an application where the same model is interacting with different KV caches of different sizes and with multiple samplers that do different things (eg imagine a system with different scratchpads for code / reasoning / RAG with different properties, or a multi-agent system with a shared set of weights etc).
Somewhat orthogonal to this, we want to simplify the management of KV cache limits: having all three of kSeqLen, max_len, and max_generation_length is more complicated than I'd like. One solution is to make the sizing more dynamic, potentially allocating more as needed up to some limit. Some pointer arithmetic would need to be updated to do this.
So the general direction is probably to break up the coupling between runtime and model to allow for more flexible composability in the library, at least for the lower level *Impl layers. We maintain some of the higher-level / less-verbose interfaces for common usages but hoist them to sit above the core implementation layer as factory functions or possibly convenience wrapper types - sitting at the boundary between the "frontends" and "models" layers (rather than having them baked into the core model code).
You're right that CreateKVCache is somewhat coupled to the model parameters today. If it's possible without adding too much complexity, I'd like to proactively make changes in this direction of decoupling, rather than having to peel things back later. If you're up for making the change, that would be helpful, but I'd also be okay with merging this as is for now and refining it as we refactor the model/library interfaces.
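As a rough illustration of the dynamic sizing idea mentioned above (allocating more as needed up to some limit), the sketch below grows a cache by doubling and refuses positions beyond a hard cap. `GrowableKVCache` and its methods are hypothetical, not gemma.cpp code. Rows are addressed by position index and the pointer is recomputed on each access, because growing the buffer can move it and invalidate any pointer held across the call; that is one way to read the remark about pointer arithmetic needing updates.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical growable cache: `dim` floats per position, up to a hard limit.
class GrowableKVCache {
 public:
  GrowableKVCache(std::size_t dim, std::size_t initial_positions,
                  std::size_t max_positions)
      : dim_(dim),
        max_positions_(max_positions),
        capacity_(std::min(initial_positions, max_positions)),
        data_(capacity_ * dim, 0.0f) {}

  // Makes position `pos` addressable. Returns false if it exceeds the limit.
  bool EnsurePosition(std::size_t pos) {
    if (pos >= max_positions_) return false;
    if (pos < capacity_) return true;
    std::size_t new_capacity = std::max<std::size_t>(capacity_, 1);
    while (new_capacity <= pos) new_capacity *= 2;  // geometric growth
    new_capacity = std::min(new_capacity, max_positions_);
    data_.resize(new_capacity * dim_, 0.0f);  // may move the buffer
    capacity_ = new_capacity;
    return true;
  }

  // Recomputed on every call; do not keep the result across EnsurePosition.
  float* Row(std::size_t pos) {
    assert(pos < capacity_);
    return data_.data() + pos * dim_;
  }

  std::size_t capacity() const { return capacity_; }

 private:
  std::size_t dim_;
  std::size_t max_positions_;
  std::size_t capacity_;
  std::vector<float> data_;
};
```

With this shape a single limit (`max_positions`) replaces a compile-time sequence length for sizing purposes, at the cost of occasional reallocation and copying.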
Thanks for the detailed response. I think I understand what you have in mind.
How about reverting to the original CreateKVCache function but exposing it in gemma.h? We may need to add an auxiliary template function that calls it, like the original constructor of GemmaImpl did, and determines which Config to use based on the ModelType method of LoaderArgs. Then GemmaInterface will not change, which not only solves the current problem but also helps future refinement.
Of course, this is only a workaround solution. If it's acceptable, I'll do it tomorrow.
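A minimal sketch of that proposal, under stated assumptions: `ToyModelType`, `Config2B`, `Config7B`, and the sizes are hypothetical placeholders, and the real function would take whatever model type LoaderArgs reports. A templated helper sizes the cache from a Config, and a non-template overload selects the Config from a runtime value, so no virtual method is needed on the model interface.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the KV cache.
struct ToyKVCache {
  std::vector<float> data;
};

// Hypothetical configs with illustrative constants.
struct Config2B {
  static constexpr std::size_t kSeqLen = 8;
  static constexpr std::size_t kLayers = 2;
  static constexpr std::size_t kKVDim = 4;
};
struct Config7B {
  static constexpr std::size_t kSeqLen = 8;
  static constexpr std::size_t kLayers = 4;
  static constexpr std::size_t kKVDim = 16;
};

// Hypothetical runtime tag, standing in for the loader's model type.
enum class ToyModelType { k2B, k7B };

// Auxiliary template: sizes the cache from compile-time Config constants.
template <class Config>
ToyKVCache CreateKVCache() {
  const std::size_t n = Config::kSeqLen * Config::kLayers * Config::kKVDim * 2;
  return ToyKVCache{std::vector<float>(n, 0.0f)};
}

// Free function a header could expose: maps the runtime type to a Config.
ToyKVCache CreateKVCache(ToyModelType type) {
  switch (type) {
    case ToyModelType::k2B:
      return CreateKVCache<Config2B>();
    case ToyModelType::k7B:
      return CreateKVCache<Config7B>();
  }
  return ToyKVCache{};  // unreachable for valid enum values
}
```

A caller would create one cache per conversation with `CreateKVCache(type)` and pass it to generation calls, leaving the model interface untouched.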
Yes, we can try exposing it in gemma.h.
The direction I'm taking is to gradually move to factory functions that decouple core types from how their state is populated. Decoupling GemmaImpl from disk loading concerns should make it easier for different frontends with different needs (e.g. reading from disk vs. streaming) to do whatever they need to without running into roadblocks.
I've made some updates to the library interfaces in PR #82 (in the examples branch). It's not finished and still needs some polishing, especially at the top-level Gemma constructor, but hopefully it gives a more concrete idea of the general direction.
Might consider targeting this PR to that examples/ branch if it doesn't add too much complexity.
The modifications I mentioned yesterday have been done and I've targeted this PR to the |
This PR separates the KV cache from GemmaImpl so that we can continue specific chat conversations using the stored KV cache. It allows us to use a single Gemma instance to handle multiple ongoing multi-turn conversations.
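The usage pattern this enables can be sketched roughly as below. It is a toy model of the idea, not the PR's API: `ToyKVCache`, `ToyModel`, `ConversationStore`, and `Turn` are hypothetical, and the "cache" just records token ids in place of real key/value tensors. What it shows is one shared, read-only model object serving several conversations, each of which carries its own cache between turns.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Toy cache: stored token ids stand in for per-position key/value data.
struct ToyKVCache {
  std::vector<int> tokens;
};

// Toy model: holds no per-conversation state, so one instance can be shared.
class ToyModel {
 public:
  // Extends the caller's cache with the prompt and returns the context length
  // after this turn. A real model would also run inference here.
  std::size_t Generate(const std::vector<int>& prompt,
                       ToyKVCache& cache) const {
    cache.tokens.insert(cache.tokens.end(), prompt.begin(), prompt.end());
    return cache.tokens.size();
  }
};

// Keeps one cache per conversation id, in memory only.
class ConversationStore {
 public:
  std::size_t Turn(const ToyModel& model, const std::string& id,
                   const std::vector<int>& prompt) {
    // operator[] creates an empty cache the first time an id is seen.
    return model.Generate(prompt, caches_[id]);
  }

 private:
  std::unordered_map<std::string, ToyKVCache> caches_;
};
```

Interleaving turns from two conversations leaves each context intact, because the model only touches the cache it is handed.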