From 26b1e07756556b1b329f3a9a3267a5f99c0c2e78 Mon Sep 17 00:00:00 2001 From: Leon Nissen <50104433+LeonNissen@users.noreply.github.com> Date: Wed, 18 Dec 2024 15:21:13 -0800 Subject: [PATCH] Fix Download and Mocking LLMLocalSession (#81) --- Sources/SpeziLLM/Mock/LLMMockSession.swift | 26 ++-- Sources/SpeziLLMLocal/LLMLocalPlatform.swift | 13 +- .../LLMLocalSession+Generate.swift | 28 +++++ .../SpeziLLMLocal/LLMLocalSession+Setup.swift | 22 ++++ .../Mock/LLMLocalMockSession.swift | 111 ------------------ .../Resources/Localizable.xcstrings | 11 ++ .../LLMLocalDownloadManager.swift | 50 ++++---- .../LLMLocalDownloadView.swift | 22 ++-- .../LLMLocalOnboardingDownloadView.swift | 2 +- 9 files changed, 118 insertions(+), 167 deletions(-) delete mode 100644 Sources/SpeziLLMLocal/Mock/LLMLocalMockSession.swift diff --git a/Sources/SpeziLLM/Mock/LLMMockSession.swift b/Sources/SpeziLLM/Mock/LLMMockSession.swift index 751acc02..2169110f 100644 --- a/Sources/SpeziLLM/Mock/LLMMockSession.swift +++ b/Sources/SpeziLLM/Mock/LLMMockSession.swift @@ -52,29 +52,19 @@ public final class LLMMockSession: LLMSession, @unchecked Sendable { return } - /// Generate mock messages await MainActor.run { self.state = .generating } - await injectAndYield("Mock ", on: continuation) - - try? await Task.sleep(for: .milliseconds(500)) - guard await !checkCancellation(on: continuation) else { - return - } - await injectAndYield("Message ", on: continuation) - - try? await Task.sleep(for: .milliseconds(500)) - guard await !checkCancellation(on: continuation) else { - return - } - await injectAndYield("from ", on: continuation) - try? await Task.sleep(for: .milliseconds(500)) - guard await !checkCancellation(on: continuation) else { - return + /// Generate mock messages + let tokens = ["Mock ", "Message ", "from ", "SpeziLLM!"] + for token in tokens { + try? await Task.sleep(for: .milliseconds(500)) + guard await !checkCancellation(on: continuation) else { + return + } + await injectAndYield(token, on: continuation) } - await injectAndYield("SpeziLLM!", on: continuation) continuation.finish() await MainActor.run { diff --git a/Sources/SpeziLLMLocal/LLMLocalPlatform.swift b/Sources/SpeziLLMLocal/LLMLocalPlatform.swift index 83a23dd1..d4ddf979 100644 --- a/Sources/SpeziLLMLocal/LLMLocalPlatform.swift +++ b/Sources/SpeziLLMLocal/LLMLocalPlatform.swift @@ -63,7 +63,12 @@ public actor LLMLocalPlatform: LLMPlatform, DefaultInitializable { Logger( subsystem: "Spezi", category: "LLMLocalPlatform" - ).warning("SpeziLLMLocal is only supported on physical devices. Use `LLMMockPlatform` instead.") + ).warning("SpeziLLMLocal is only supported on physical devices. A mock session will be used instead.") + + Logger( + subsystem: "Spezi", + category: "LLMLocalPlatform" + ).warning("\(String(localized: "LLM_MLX_NOT_SUPPORTED_WORKAROUND", bundle: .module))") #else if let cacheLimit = configuration.cacheLimit { MLX.GPU.set(cacheLimit: cacheLimit * 1024 * 1024) @@ -74,15 +79,9 @@ public actor LLMLocalPlatform: LLMPlatform, DefaultInitializable { #endif } -#if targetEnvironment(simulator) - public nonisolated func callAsFunction(with llmSchema: LLMLocalSchema) -> LLMLocalMockSession { - LLMLocalMockSession(self, schema: llmSchema) - } -#else public nonisolated func callAsFunction(with llmSchema: LLMLocalSchema) -> LLMLocalSession { LLMLocalSession(self, schema: llmSchema) } -#endif deinit { MLX.GPU.clearCache() diff --git a/Sources/SpeziLLMLocal/LLMLocalSession+Generate.swift b/Sources/SpeziLLMLocal/LLMLocalSession+Generate.swift index c7785d72..4182637e 100644 --- a/Sources/SpeziLLMLocal/LLMLocalSession+Generate.swift +++ b/Sources/SpeziLLMLocal/LLMLocalSession+Generate.swift @@ -18,6 +18,11 @@ import SpeziLLM extension LLMLocalSession { // swiftlint:disable:next identifier_name function_body_length internal func _generate(continuation: AsyncThrowingStream.Continuation) async { +#if targetEnvironment(simulator) + // swiftlint:disable:next return_value_from_void_function + return await _mockGenerate(continuation: continuation) +#endif + guard let modelContainer = await self.modelContainer else { Self.logger.error("SpeziLLMLocal: Failed to load `modelContainer`") await finishGenerationWithError(LLMLocalError.modelNotFound, on: continuation) @@ -119,4 +124,27 @@ extension LLMLocalSession { state = .ready } } + + private func _mockGenerate(continuation: AsyncThrowingStream.Continuation) async { + let tokens = [ + "Mock ", "Message ", "from ", "SpeziLLM! ", + "**Using SpeziLLMLocal only works on physical devices.**", + "\n\n", + String(localized: "LLM_MLX_NOT_SUPPORTED_WORKAROUND", bundle: .module) + ] + + for token in tokens { + try? await Task.sleep(for: .seconds(1)) + guard await !checkCancellation(on: continuation) else { + return + } + continuation.yield(token) + } + + continuation.finish() + await MainActor.run { + context.completeAssistantStreaming() + self.state = .ready + } + } } diff --git a/Sources/SpeziLLMLocal/LLMLocalSession+Setup.swift b/Sources/SpeziLLMLocal/LLMLocalSession+Setup.swift index 715ba732..91369905 100644 --- a/Sources/SpeziLLMLocal/LLMLocalSession+Setup.swift +++ b/Sources/SpeziLLMLocal/LLMLocalSession+Setup.swift @@ -27,6 +27,10 @@ extension LLMLocalSession { // swiftlint:disable:next identifier_name internal func _setup(continuation: AsyncThrowingStream.Continuation?) async -> Bool { +#if targetEnvironment(simulator) + return await _mockSetup(continuation: continuation) +#endif + Self.logger.debug("SpeziLLMLocal: Local LLM is being initialized") await MainActor.run { @@ -62,4 +66,22 @@ extension LLMLocalSession { Self.logger.debug("SpeziLLMLocal: Local LLM has finished initializing") return true } + + private func _mockSetup(continuation: AsyncThrowingStream.Continuation?) async -> Bool { + Self.logger.debug("SpeziLLMLocal: Local Mock LLM is being initialized") + + await MainActor.run { + self.state = .loading + } + + try? await Task.sleep(for: .seconds(1)) + + await MainActor.run { + self.state = .ready + } + + Self.logger.debug("SpeziLLMLocal: Local Mock LLM has finished initializing") + + return true + } } diff --git a/Sources/SpeziLLMLocal/Mock/LLMLocalMockSession.swift b/Sources/SpeziLLMLocal/Mock/LLMLocalMockSession.swift deleted file mode 100644 index 690c425b..00000000 --- a/Sources/SpeziLLMLocal/Mock/LLMLocalMockSession.swift +++ /dev/null @@ -1,111 +0,0 @@ -// -// This source file is part of the Stanford Spezi open source project -// -// SPDX-FileCopyrightText: 2024 Stanford University and the project authors (see CONTRIBUTORS.md) -// -// SPDX-License-Identifier: MIT -// - -import Foundation -import Observation -import SpeziLLM - - -/// A mock ``LLMLocalMockSession``, used for testing purposes. -/// -/// See `LLMMockSession` for more details -@Observable -public final class LLMLocalMockSession: LLMSession, @unchecked Sendable { - let platform: LLMLocalPlatform - let schema: LLMLocalSchema - - @ObservationIgnored private var task: Task<(), Never>? - - @MainActor public var state: LLMState = .uninitialized - @MainActor public var context: LLMContext = [] - - - /// Initializer for the ``LLMMockSession``. - /// - /// - Parameters: - /// - platform: The mock LLM platform. - /// - schema: The mock LLM schema. - init(_ platform: LLMLocalPlatform, schema: LLMLocalSchema) { - self.platform = platform - self.schema = schema - } - - - @discardableResult - public func generate() async throws -> AsyncThrowingStream { - let (stream, continuation) = AsyncThrowingStream.makeStream(of: String.self) - - // swiftlint:disable:next closure_body_length - task = Task { - await MainActor.run { - self.state = .loading - } - try? await Task.sleep(for: .seconds(1)) - guard await !checkCancellation(on: continuation) else { - return - } - - /// Generate mock messages - await MainActor.run { - self.state = .generating - } - await injectAndYield("Mock ", on: continuation) - - try? await Task.sleep(for: .milliseconds(500)) - guard await !checkCancellation(on: continuation) else { - return - } - await injectAndYield("Message ", on: continuation) - - try? await Task.sleep(for: .milliseconds(500)) - guard await !checkCancellation(on: continuation) else { - return - } - await injectAndYield("from ", on: continuation) - - try? await Task.sleep(for: .milliseconds(500)) - guard await !checkCancellation(on: continuation) else { - return - } - await injectAndYield("SpeziLLM! ", on: continuation) - - try? await Task.sleep(for: .milliseconds(500)) - guard await !checkCancellation(on: continuation) else { - return - } - await injectAndYield("Using SpeziLLMLocal only works on physical devices.", on: continuation) - - - continuation.finish() - await MainActor.run { - context.completeAssistantStreaming() - self.state = .ready - } - } - - return stream - } - - public func cancel() { - task?.cancel() - } - - private func injectAndYield(_ piece: String, on continuation: AsyncThrowingStream.Continuation) async { - continuation.yield(piece) - if schema.injectIntoContext { - await MainActor.run { - context.append(assistantOutput: piece) - } - } - } - - - deinit { - cancel() - } -} diff --git a/Sources/SpeziLLMLocal/Resources/Localizable.xcstrings b/Sources/SpeziLLMLocal/Resources/Localizable.xcstrings index b2290765..bd3f7599 100644 --- a/Sources/SpeziLLMLocal/Resources/Localizable.xcstrings +++ b/Sources/SpeziLLMLocal/Resources/Localizable.xcstrings @@ -91,6 +91,17 @@ } } }, + "LLM_MLX_NOT_SUPPORTED_WORKAROUND" : { + "extractionState" : "manual", + "localizations" : { + "en" : { + "stringUnit" : { + "state" : "translated", + "value" : "Here are two recommended workarounds:\n1. Add the Mac (Designed for iPad) destination to your target in Xcode.\n- SpeziLLMLocal requires MLX which requires Apple silicon, with `Mac (Designed for iPad)` you build an iPad application that will run on macOS.\n- The UI may present with differences to iOS, but this will allow you to build an iOS binary that runs with a fully featured Metal GPU.\n\n2. Make a multiplatform application that can run on macOS, iOS and iPadOS.\n- With SwiftUI it is possible to do most of your development in a macOS application and fine tune it for iOS by running it on an actual device.\n\nYou can also use the simulator for developing UI features but local LLM execution is not possible." + } + } + } + }, "LLM_MODEL_NOT_FOUND_ERROR_DESCRIPTION" : { "localizations" : { "en" : { diff --git a/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift b/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift index 1e1b2b89..a022e2b7 100644 --- a/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift +++ b/Sources/SpeziLLMLocalDownload/LLMLocalDownloadManager.swift @@ -8,11 +8,11 @@ import Foundation import Hub -import MLXLLM import Observation import SpeziLLMLocal import SpeziViews + /// Manages the download and storage of Large Language Models (LLM) to the local device. /// /// One configures the ``LLMLocalDownloadManager`` via the ``LLMLocalDownloadManager/init(llmDownloadUrl:llmStorageUrl:)`` initializer, @@ -28,7 +28,7 @@ public final class LLMLocalDownloadManager: NSObject { public enum DownloadState: Equatable { case idle case downloading(progress: Progress) - case downloaded(storageUrl: URL) + case downloaded case error(LocalizedError) @@ -47,10 +47,10 @@ public final class LLMLocalDownloadManager: NSObject { @ObservationIgnored private var downloadTask: Task<(), Never>? /// Indicates the current state of the ``LLMLocalDownloadManager``. @MainActor public var state: DownloadState = .idle - private let modelConfiguration: ModelConfiguration + private let model: LLMLocalModel - @ObservationIgnored public var modelExists: Bool { - LLMLocalDownloadManager.modelExsist(model: .custom(id: modelConfiguration.name)) + @ObservationIgnored public var modelExist: Bool { + LLMLocalDownloadManager.modelExist(model: model) } /// Initializes a ``LLMLocalDownloadManager`` instance to manage the download of Large Language Model (LLM) files from remote servers. @@ -58,14 +58,14 @@ public final class LLMLocalDownloadManager: NSObject { /// - Parameters: /// - modelID: The Huggingface model ID of the LLM that needs to be downloaded. public init(model: LLMLocalModel) { - self.modelConfiguration = .init(id: model.hubID) + self.model = model } /// Checks if a model is already downloaded to the local device. /// /// - Parameter model: The model to check for local existence. /// - Returns: A Boolean value indicating whether the model exists on the device. - public static func modelExsist(model: LLMLocalModel) -> Bool { + public static func modelExist(model: LLMLocalModel) -> Bool { let repo = Hub.Repo(id: model.hubID) let url = HubApi.shared.localRepoLocation(repo) let modelFileExtension = ".safetensors" @@ -79,28 +79,23 @@ public final class LLMLocalDownloadManager: NSObject { } /// Starts a `URLSessionDownloadTask` to download the specified model. - public func startDownload() { - if case let .directory(url) = modelConfiguration.id { + public func startDownload() async { + if modelExist { Task { @MainActor in - self.state = .downloaded(storageUrl: url) + self.state = .downloaded } return } - downloadTask?.cancel() + await cancelDownload() downloadTask = Task(priority: .userInitiated) { do { - _ = try await loadModelContainer(configuration: modelConfiguration) { progress in - Task { @MainActor in - self.state = .downloading(progress: progress) - } - } - - Task { @MainActor in - self.state = .downloaded(storageUrl: modelConfiguration.modelDirectory()) + try await downloadWithHub() + await MainActor.run { + self.state = .downloaded } } catch { - Task { @MainActor in + await MainActor.run { self.state = .error( AnyLocalizedError( error: error, @@ -113,7 +108,20 @@ public final class LLMLocalDownloadManager: NSObject { } /// Cancels the download of a specified model via a `URLSessionDownloadTask`. - public func cancelDownload() { + public func cancelDownload() async { downloadTask?.cancel() + await MainActor.run { + self.state = .idle + } + } + + @MainActor + private func downloadWithHub() async throws { + let repo = Hub.Repo(id: model.hubID) + let modelFiles = ["*.safetensors", "config.json"] + + try await HubApi.shared.snapshot(from: repo, matching: modelFiles) { progress in + self.state = .downloading(progress: progress) + } } } diff --git a/Sources/SpeziLLMLocalDownload/LLMLocalDownloadView.swift b/Sources/SpeziLLMLocalDownload/LLMLocalDownloadView.swift index f1eec999..bacf29db 100644 --- a/Sources/SpeziLLMLocalDownload/LLMLocalDownloadView.swift +++ b/Sources/SpeziLLMLocalDownload/LLMLocalDownloadView.swift @@ -65,7 +65,7 @@ public struct LLMLocalDownloadView: View { VStack { informationView - if !modelExists { + if !modelExist { downloadButton if isDownloading { @@ -89,12 +89,12 @@ public struct LLMLocalDownloadView: View { Spacer() } .transition(.opacity) - .animation(.easeInOut, value: isDownloading || modelExists) + .animation(.easeInOut, value: isDownloading || modelExist) }, actionView: { OnboardingActionsView(.init("LLM_DOWNLOAD_NEXT_BUTTON", bundle: .atURL(from: .module))) { try await self.action() } - .disabled(!modelExists) + .disabled(!modelExist) } ) .map(state: downloadManager.state, to: $viewState) @@ -120,14 +120,18 @@ public struct LLMLocalDownloadView: View { /// Button which starts the download of the model. @MainActor private var downloadButton: some View { - Button(action: downloadManager.startDownload) { + Button { + Task { + await downloadManager.startDownload() + } + } label: { Text("LLM_DOWNLOAD_BUTTON", bundle: .module) .padding(.horizontal) .padding(.vertical, 6) } - .buttonStyle(.borderedProminent) - .disabled(isDownloading) - .padding() + .buttonStyle(.borderedProminent) + .disabled(isDownloading) + .padding() } /// A progress view indicating the state of the download @@ -165,8 +169,8 @@ public struct LLMLocalDownloadView: View { } /// A `Bool` flag indicating if the model already exists on the device - private var modelExists: Bool { - self.downloadManager.modelExists + private var modelExist: Bool { + self.downloadManager.modelExist } diff --git a/Tests/UITests/TestApp/LLMLocal/Onboarding/LLMLocalOnboardingDownloadView.swift b/Tests/UITests/TestApp/LLMLocal/Onboarding/LLMLocalOnboardingDownloadView.swift index 37bbe71b..8e2cd6f6 100644 --- a/Tests/UITests/TestApp/LLMLocal/Onboarding/LLMLocalOnboardingDownloadView.swift +++ b/Tests/UITests/TestApp/LLMLocal/Onboarding/LLMLocalOnboardingDownloadView.swift @@ -19,7 +19,7 @@ struct LLMLocalOnboardingDownloadView: View { var body: some View { LLMLocalDownloadView( - model: .phi3_4bit, + model: .llama3_8B_4bit, downloadDescription: "LLM_DOWNLOAD_DESCRIPTION", action: onboardingNavigationPath.nextStep )