diff --git a/Sources/SpeziLLMLocal/LLMLocalSchema+PromptFormatting.swift b/Sources/SpeziLLMLocal/LLMLocalSchema+PromptFormatting.swift
index d4861f9..4240514 100644
--- a/Sources/SpeziLLMLocal/LLMLocalSchema+PromptFormatting.swift
+++ b/Sources/SpeziLLMLocal/LLMLocalSchema+PromptFormatting.swift
@@ -12,6 +12,85 @@ import SpeziLLM
 extension LLMLocalSchema {
     /// Holds default prompt formatting strategies for [Llama2](https://ai.meta.com/llama/) as well as [Phi-2](https://www.microsoft.com/en-us/research/blog/phi-2-the-surprising-power-of-small-language-models/) models.
     public enum PromptFormattingDefaults {
+        /// Prompt formatting closure for the [Llama3](https://ai.meta.com/llama/) model
+        public static let llama3: (@Sendable (LLMContext) throws -> String) = { chat in // swiftlint:disable:this closure_body_length
+            let BEGINOFTEXT = "<|begin_of_text|>"
+            let SYSTEM = "system"
+            let USER = "user"
+            let ASSISTANT = "assistant"
+            let STARTHEADERID = "<|start_header_id|>"
+            let ENDHEADERID = "<|end_header_id|>"
+            let EOTID = "<|eot_id|>"
+
+            guard chat.first?.role == .system else {
+                throw LLMLocalError.illegalContext
+            }
+
+            var systemPrompts: [String] = []
+            var initialUserPrompt: String = ""
+
+            for contextEntity in chat {
+                if contextEntity.role != .system {
+                    if contextEntity.role == .user {
+                        initialUserPrompt = contextEntity.content
+                        break
+                    } else {
+                        throw LLMLocalError.illegalContext
+                    }
+                }
+
+                systemPrompts.append(contextEntity.content)
+            }
+
+            /// Build the initial Llama3 prompt structure
+            ///
+            /// A template of the prompt structure looks like:
+            /// """
+            /// <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+            /// {your_system_prompt}<|eot_id|>
+            ///
+            /// <|start_header_id|>user<|end_header_id|>
+            /// {user_message_1}<|eot_id|>
+            /// """
+            var prompt = """
+            \(BEGINOFTEXT)
+            \(STARTHEADERID)\(SYSTEM)\(ENDHEADERID)
+            \(systemPrompts.joined(separator: " "))
+            \(EOTID)
+
+            \(STARTHEADERID)\(USER)\(ENDHEADERID)
+            \(initialUserPrompt)
+            \(EOTID)
+
+            """ + " " // Add a spacer to the generated output from the model
+
+            for contextEntity in chat.dropFirst(2) {
+                if contextEntity.role == .assistant() {
+                    /// Append response from assistant to the Llama3 prompt structure
+                    ///
+                    /// A template for appending an assistant response to the overall prompt looks like:
+                    /// <|start_header_id|>assistant<|end_header_id|>{model_reply_1}<|eot_id|>
+                    prompt += """
+                    \(STARTHEADERID)\(ASSISTANT)\(ENDHEADERID)
+                    \(contextEntity.content)
+                    \(EOTID)
+                    """
+                } else if contextEntity.role == .user {
+                    /// Append a message from the user to the Llama3 prompt structure
+                    ///
+                    /// A template for appending a user message to the overall prompt looks like:
+                    /// <|start_header_id|>user<|end_header_id|>{user_message_2}<|eot_id|>
+                    prompt += """
+                    \(STARTHEADERID)\(USER)\(ENDHEADERID)
+                    \(contextEntity.content)
+                    \(EOTID)
+                    """ + " " // Add a spacer to the generated output from the model
+                }
+            }
+
+            return prompt
+        }
+
         /// Prompt formatting closure for the [Llama2](https://ai.meta.com/llama/) model
         public static let llama2: (@Sendable (LLMContext) throws -> String) = { chat in // swiftlint:disable:this closure_body_length
             /// BOS token of the LLM, used at the start of each prompt passage.
diff --git a/Tests/UITests/TestApp/LLMLocal/LLMLocalChatTestView.swift b/Tests/UITests/TestApp/LLMLocal/LLMLocalChatTestView.swift
index 3f54520..34b5797 100644
--- a/Tests/UITests/TestApp/LLMLocal/LLMLocalChatTestView.swift
+++ b/Tests/UITests/TestApp/LLMLocal/LLMLocalChatTestView.swift
@@ -27,7 +27,8 @@ struct LLMLocalChatTestView: View {
                 with: LLMLocalSchema(
                     modelPath: .cachesDirectory.appending(path: "llm.gguf"),
                     parameters: .init(maxOutputLength: 512),
-                    contextParameters: .init(contextWindowSize: 1024)
+                    contextParameters: .init(contextWindowSize: 1024),
+                    formatChat: LLMLocalSchema.PromptFormattingDefaults.llama3
                 )
             )
         }
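
Note (illustrative, not part of the diff): a minimal sketch of the prompt string the llama3 closure added above renders for a context of one system prompt, one user message, and one assistant reply. The special tokens mirror the constants defined in the closure; the placeholder texts and the exact blank-line/spacer placement are approximate.

    // Sketch only: approximate output of
    // LLMLocalSchema.PromptFormattingDefaults.llama3 for a short exchange.
    let examplePrompt = """
    <|begin_of_text|>
    <|start_header_id|>system<|end_header_id|>
    {your_system_prompt}
    <|eot_id|>

    <|start_header_id|>user<|end_header_id|>
    {user_message_1}
    <|eot_id|>

    <|start_header_id|>assistant<|end_header_id|>
    {model_reply_1}
    <|eot_id|>
    """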