Skip to content

Commit

Permalink
Use hashmap to track early stopping (#155)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ZachNagengast authored May 30, 2024
1 parent 3aa94e8 commit aa4bb90
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 9 deletions.
8 changes: 4 additions & 4 deletions Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -608,14 +608,14 @@
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_KEY_NSMicrophoneUsageDescription = "Required to record audio from the microphone for transcription.";
INFOPLIST_KEY_UISupportedInterfaceOrientations = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown";
INFOPLIST_KEY_WKCompanionAppBundleIdentifier = com.argmax.whisperkit.WhisperAX;
INFOPLIST_KEY_WKCompanionAppBundleIdentifier = "com.argmax.whisperkit.WhisperAX${DEVELOPMENT_TEAM}";
INFOPLIST_KEY_WKRunsIndependentlyOfCompanionApp = YES;
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
);
MARKETING_VERSION = 0.1.2;
PRODUCT_BUNDLE_IDENTIFIER = com.argmax.whisperkit.WhisperAX.watchapp;
PRODUCT_BUNDLE_IDENTIFIER = "com.argmax.whisperkit.WhisperAX${DEVELOPMENT_TEAM}.watchapp";
PRODUCT_NAME = "WhisperAX Watch App";
PROVISIONING_PROFILE_SPECIFIER = "";
SDKROOT = watchos;
Expand Down Expand Up @@ -893,7 +893,7 @@
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 14.0;
MARKETING_VERSION = 0.3.0;
MARKETING_VERSION = 0.3.1;
PRODUCT_BUNDLE_IDENTIFIER = "com.argmax.whisperkit.WhisperAX${DEVELOPMENT_TEAM}";
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
Expand Down Expand Up @@ -939,7 +939,7 @@
LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
"LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
MACOSX_DEPLOYMENT_TARGET = 14.0;
MARKETING_VERSION = 0.3.0;
MARKETING_VERSION = 0.3.1;
PRODUCT_BUNDLE_IDENTIFIER = com.argmax.whisperkit.WhisperAX;
PRODUCT_NAME = "$(TARGET_NAME)";
SDKROOT = auto;
Expand Down
4 changes: 4 additions & 0 deletions Examples/WhisperAX/WhisperAX/Views/ContentView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1308,10 +1308,12 @@ struct ContentView: View {
let checkTokens: [Int] = currentTokens.suffix(checkWindow)
let compressionRatio = compressionRatio(of: checkTokens)
if compressionRatio > options.compressionRatioThreshold! {
Logging.debug("Early stopping due to compression threshold")
return false
}
}
if progress.avgLogprob! < options.logProbThreshold! {
Logging.debug("Early stopping due to logprob threshold")
return false
}
return nil
Expand Down Expand Up @@ -1519,10 +1521,12 @@ struct ContentView: View {
let checkTokens: [Int] = currentTokens.suffix(checkWindow)
let compressionRatio = compressionRatio(of: checkTokens)
if compressionRatio > options.compressionRatioThreshold! {
Logging.debug("Early stopping due to compression threshold")
return false
}
}
if progress.avgLogprob! < options.logProbThreshold! {
Logging.debug("Early stopping due to logprob threshold")
return false
}

Expand Down
12 changes: 8 additions & 4 deletions Sources/WhisperKit/Core/TextDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
public var tokenizer: WhisperTokenizer?
public var prefillData: WhisperMLModel?
public var isModelMultilingual: Bool = false
public var shouldEarlyStop: Bool = false
public var shouldEarlyStop = [UUID: Bool]()
private var languageLogitsFilter: LanguageLogitsFilter?

public var supportsWordTimestamps: Bool {
Expand Down Expand Up @@ -588,7 +588,8 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
Logging.debug("Running main loop for a maximum of \(loopCount) iterations, starting at index \(prefilledIndex)")
var hasAlignment = false
var isFirstTokenLogProbTooLow = false
shouldEarlyStop = false
let windowUUID = UUID()
shouldEarlyStop[windowUUID] = false
for tokenIndex in prefilledIndex..<loopCount {
let loopStart = Date()

Expand Down Expand Up @@ -733,7 +734,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
let shouldContinue = callback(result)
if let shouldContinue = shouldContinue, !shouldContinue, !isPrefill {
Logging.debug("Early stopping")
self?.shouldEarlyStop = true
self?.shouldEarlyStop[windowUUID] = true
}
}
}
Expand All @@ -749,10 +750,13 @@ open class TextDecoder: TextDecoding, WhisperMLModel {
}

// Check if early stopping is triggered
if shouldEarlyStop {
if let shouldStop = shouldEarlyStop[windowUUID], shouldStop {
break
}
}

// Clean up the early stop flag after loop completion
shouldEarlyStop.removeValue(forKey: windowUUID)

let cache = DecodingCache(
keyCache: decoderInputs.keyCache,
Expand Down
2 changes: 1 addition & 1 deletion Sources/WhisperKit/Core/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public extension MLComputeUnits {

#if os(macOS)
// From: https://stackoverflow.com/a/71726663
extension Process {
public extension Process {
static func stringFromTerminal(command: String) -> String {
let task = Process()
let pipe = Pipe()
Expand Down

0 comments on commit aa4bb90

Please sign in to comment.