grammar support for LLaMA
guinmoon committed Sep 26, 2023
1 parent 4159cf3 commit f68ba38
Showing 4 changed files with 118 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Sources/llmfarm_core/AI.swift
@@ -178,7 +178,7 @@ public struct ModelContextParams {
public var embedding = false // embedding mode only
public var processorsConunt = Int32(ProcessInfo.processInfo.processorCount)
public var use_metal = false
-public var grammar_path = ""
+public var grammar_path:String? = nil

public var warm_prompt = "\n\n\n"

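The change above makes grammar_path an optional String, so grammar-constrained decoding is opt-in: the nil default skips grammar loading entirely. A minimal caller-side sketch (the .gbnf path is a hypothetical example):

    var contextParams = ModelContextParams.default
    // Point at a GBNF grammar file to constrain sampling; leave nil to disable.
    contextParams.grammar_path = "grammars/json.gbnf"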
53 changes: 48 additions & 5 deletions Sources/llmfarm_core/LLMBase.swift
@@ -21,6 +21,25 @@ public enum ModelLoadError: Error {
// case unexpected(code: Int)
}


//func bridge(_ obj : T) -> UnsafeMutableRawPointer {
// return UnsafeMutableRawPointer(Unmanaged.passUnretained(obj).toOpaque())
//}
//
//func bridge(_ ptr : UnsafeMutableRawPointer) -> T? {
// return Unmanaged.fromOpaque(ptr).takeUnretainedValue()
//}

//func bridge<T : AnyObject>(obj : T) -> UnsafeRawPointer {
// return UnsafeRawPointer(Unmanaged.passUnretained(obj).toOpaque())
//}
//
//func bridge<T : AnyObject>(ptr : UnsafeRawPointer) -> T {
// return Unmanaged<T>.fromOpaque(ptr).takeUnretainedValue()
//}



public class LLMBase: Model {

// Used to keep old context until it needs to be rotated or purge out for new tokens
@@ -63,8 +82,8 @@ public class LLMBase: Model {

print("%s: seed = %d\n", params.seed);

-if contextParams.grammar_path != ""{
-try? self.load_grammar(contextParams.grammar_path)
+if contextParams.grammar_path != nil && contextParams.grammar_path! != ""{
+try? self.load_grammar(contextParams.grammar_path!)
}

print(String(cString: print_system_info()))
@@ -77,12 +96,21 @@ public class LLMBase: Model {
print("Logits inited.")
}

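// Placeholder instance method; the commented-out C-callback experiment in
// llm_sample below invokes it from llama.cpp to prove the Swift round trip.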
func TestMethod(){

}

deinit {

}

public func load_grammar(_ path:String) throws -> Void{
-self.grammar = llama_load_grammar(path)
+let exception = tryBlock {
+self.grammar = llama_load_grammar(path)
+}
+if exception != nil{
+throw ModelLoadError.grammarLoadError
+}
}
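load_grammar now routes the C call through tryBlock, an Objective-C exception trap, so a malformed grammar file surfaces as a thrown Swift error instead of crashing the process. A hedged usage sketch (the model variable and file path are assumptions):

    do {
        try model.load_grammar("grammars/json.gbnf")
    } catch {
        // ModelLoadError.grammarLoadError: fall back to unconstrained sampling.
        print("grammar load failed: \(error)")
    }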

public override func llm_load_model(path: String = "", contextParams: ModelContextParams = .default, params:gpt_context_params ) throws -> Bool{
@@ -114,6 +142,7 @@ public class LLMBase: Model {
return gpt_base_n_ctx(ctx)
}


// Simple topK, topP, temp sampling, with repeat penalty
func llm_sample(ctx: OpaquePointer!,
last_n_tokens: inout [ModelToken],
@@ -164,10 +193,24 @@ public class LLMBase: Model {
logits[nl_token] = nl_logit
}

-if (self.grammar != nil) {
-llama_sample_grammar(ctx, &candidates_p, self.grammar);
+// let swiftTokenCallback : (@convention(c) (Int32 ) -> String?) = {
+// in_token -> String? in
+// return self.llm_token_to_str(outputToken:in_token)
+// }
+if (self.grammar != nil ) {
+llama_sample_grammar(ctx,&candidates_p, self.grammar)
}

// if (self.grammar != nil) {
// llama_sample_grammar(ctx,&candidates_p, self.grammar, self.llm_token_eos(),bridge(self),
// {(observer) -> Void in
// // Extract pointer to `self` from void pointer:
// let mySelf = Unmanaged.fromOpaque(observer!).takeUnretainedValue()
// // Call instance method:
// mySelf.TestMethod();
// });
// }

var res_token:Int32 = 0

if(temp <= 0) {
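When a grammar is active, llm_sample calls llama_sample_grammar before the usual top-K/top-P/temperature machinery; as the C++ below shows, that masks every grammar-rejected candidate by forcing its logit to -INFINITY, so the sampler can only pick conforming tokens. Upstream llama.cpp additionally expects llama_grammar_accept_token after each pick to advance the grammar state; whether this wrapper does that outside the excerpt shown is an assumption in this sketch:

    if self.grammar != nil {
        llama_sample_grammar(ctx, &candidates_p, self.grammar)
    }
    let token = llama_sample_token(ctx, &candidates_p)
    if self.grammar != nil {
        // Advance the grammar's parse stacks past the token we just accepted.
        llama_grammar_accept_token(ctx, self.grammar, token)
    }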
68 changes: 66 additions & 2 deletions Sources/llmfarm_core_cpp/llama/llama.cpp
@@ -4958,7 +4958,31 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
}
}

void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {

typedef struct Callbacks
{
void * classPtr;
void (*callback)(void *);
} Callbacks;

static Callbacks * callbacks = new Callbacks();

void CallSwiftMemberFromC(void * classPtr, void(*callback)(void *))
{
callbacks->classPtr = classPtr;
callbacks->callback = callback;

std::function<void()> actualCallback = [&](){
callbacks->callback(callbacks->classPtr);
};
actualCallback();
}
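CallSwiftMemberFromC is a trampoline for the standard opaque-pointer callback pattern: Swift passes a raw pointer to an instance plus a capture-free @convention(c) function, and C invokes the function with the pointer so the Swift side can recover the instance. A sketch of the Swift caller, mirroring the commented-out code in LLMBase.swift (exposing this symbol through the bridging header is an assumption):

    let observer = Unmanaged.passUnretained(self).toOpaque()
    CallSwiftMemberFromC(observer) { ptr in
        // A capture-free closure converts to a C function pointer; `self`
        // travels through the opaque pointer instead of being captured.
        let mySelf = Unmanaged<LLMBase>.fromOpaque(ptr!).takeUnretainedValue()
        mySelf.TestMethod()
    }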

//char* (* _Nonnull token_to_str)(llama_token)

void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar, llama_token token_eos, void * classPtr, void (*callback)(void *)) {
GGML_ASSERT(ctx);
const int64_t t_start_sample_us = ggml_time_us();

Expand All @@ -4970,8 +4994,48 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
}
}

-const llama_token eos = llama_token_eos(ctx);
+const llama_token eos = token_eos;

std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar;

for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string piece = llama_token_to_str(ctx, id);
if (id == eos) {
if (!allow_eos) {
candidates->data[i].logit = -INFINITY;
}
} else if (piece.empty() || piece[0] == 0) {
candidates->data[i].logit = -INFINITY;
} else {
candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8));
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
}
}

const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
for (const auto & reject : rejects) {
candidates->data[reject.index].logit = -INFINITY;
}

ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}


void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar ) {
GGML_ASSERT(ctx);
const int64_t t_start_sample_us = ggml_time_us();

bool allow_eos = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
allow_eos = true;
break;
}
}

const llama_token eos = llama_token_eos(ctx);
std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar;

4 changes: 3 additions & 1 deletion Sources/llmfarm_core_cpp/spm-headers/llama.h
@@ -453,7 +453,9 @@ extern "C" {
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);

/// @details Apply constraints from grammar
-LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
+///
+LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
+// LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar );

/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
