Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Support Ranking Method #1820

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

yutyan0119
Copy link

@yutyan0119 yutyan0119 commented Nov 3, 2024

Description

Hello @abetlen! Thank you for your work on this library.
This PR introduces a new rank method in the Llama class, enabling users to rank documents based on their relevance to a given query. This functionality is useful for tasks such as document retrieval and relevance scoring within a list of documents. This addition corresponds to the feature introduced in llama.cpp PR #9510 and also addresses the request for re-ranking support mentioned in Issue #1794.

Changes Made:

1. New Method: rank

  • Added the rank method, which accepts a query string and a list of document strings.
  • Each document is embedded with the query using the embed method.
  • Returns a list of rank scores, representing the relevance of each document to the query.

2. Embed Method Enhancement

  • Added a special_tokenize parameter to the embed method. When set to True, the method uses a special tokenization strategy, which supports query-document embeddings for ranking purposes.

Usage Example

import llama_cpp

llm = llama_cpp.Llama("jina-reranker-v1-tiny-en-f16.gguf", embedding=True, pooling_type=llama_cpp.LLAMA_POOLING_TYPE_RANK)

query = "what is panda?"
docs = [
    "pandas are bears", 
    "pandas are cute", 
    "pandas are black and white", 
    "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China."
]

scores = llm.rank(query, docs)
print(scores)

output

llama_model_loader: loaded meta data with 32 key-value pairs and 70 tensors from jina-reranker-v1-tiny-en-f16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = jina-bert-v2
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Jina Bert Implementation
llama_model_loader: - kv   3:                       general.organization str              = Jinaai
llama_model_loader: - kv   4:                         general.size_label str              = 33M
llama_model_loader: - kv   5:                            general.license str              = apache-2.0
llama_model_loader: - kv   6:                               general.tags arr[str,4]       = ["reranker", "cross-encoder", "transf...
llama_model_loader: - kv   7:                          general.languages arr[str,1]       = ["en"]
llama_model_loader: - kv   8:                   jina-bert-v2.block_count u32              = 4
llama_model_loader: - kv   9:                jina-bert-v2.context_length u32              = 8192
llama_model_loader: - kv  10:              jina-bert-v2.embedding_length u32              = 384
llama_model_loader: - kv  11:           jina-bert-v2.feed_forward_length u32              = 1536
llama_model_loader: - kv  12:          jina-bert-v2.attention.head_count u32              = 12
llama_model_loader: - kv  13:  jina-bert-v2.attention.layer_norm_epsilon f32              = 0.000000
llama_model_loader: - kv  14:                          general.file_type u32              = 1
llama_model_loader: - kv  15:              jina-bert-v2.attention.causal bool             = false
llama_model_loader: - kv  16:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  17:                         tokenizer.ggml.pre str              = jina-v1-en
llama_model_loader: - kv  18:                      tokenizer.ggml.tokens arr[str,61056]   = ["<s>", "<pad>", "</s>", "<unk>", "<m...
llama_model_loader: - kv  19:                  tokenizer.ggml.token_type arr[i32,61056]   = [3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  20:                      tokenizer.ggml.merges arr[str,39382]   = ["t h", "i n", "a n", "e r", "th e", ...
llama_model_loader: - kv  21:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  22:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  23:            tokenizer.ggml.unknown_token_id u32              = 3
llama_model_loader: - kv  24:          tokenizer.ggml.seperator_token_id u32              = 2
llama_model_loader: - kv  25:            tokenizer.ggml.padding_token_id u32              = 1
llama_model_loader: - kv  26:                tokenizer.ggml.cls_token_id u32              = 0
llama_model_loader: - kv  27:               tokenizer.ggml.mask_token_id u32              = 4
llama_model_loader: - kv  28:            tokenizer.ggml.token_type_count u32              = 2
llama_model_loader: - kv  29:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  30:               tokenizer.ggml.add_eos_token bool             = true
llama_model_loader: - kv  31:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   41 tensors
llama_model_loader: - type  f16:   29 tensors
llm_load_vocab: empty token at index 5
llm_load_vocab: model vocab missing newline token, using special_pad_id instead
llm_load_vocab: control token:      2 '</s>' is not marked as EOG
llm_load_vocab: control token:      4 '<mask>' is not marked as EOG
llm_load_vocab: control token:      1 '<pad>' is not marked as EOG
llm_load_vocab: control token:      0 '<s>' is not marked as EOG
llm_load_vocab: control token:      3 '<unk>' is not marked as EOG
llm_load_vocab: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
llm_load_vocab: special tokens cache size = 5
llm_load_vocab: token to piece cache size = 1.5060 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = jina-bert-v2
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 61056
llm_load_print_meta: n_merges         = 39382
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 8192
llm_load_print_meta: n_embd           = 384
llm_load_print_meta: n_layer          = 4
llm_load_print_meta: n_head           = 12
llm_load_print_meta: n_head_kv        = 12
llm_load_print_meta: n_rot            = 32
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 32
llm_load_print_meta: n_embd_head_v    = 32
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 384
llm_load_print_meta: n_embd_v_gqa     = 384
llm_load_print_meta: f_norm_eps       = 1.0e-12
llm_load_print_meta: f_norm_rms_eps   = 0.0e+00
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 8.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 1536
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 0
llm_load_print_meta: pooling type     = -1
llm_load_print_meta: rope type        = -1
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 8192
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: ssm_dt_b_c_rms   = 0
llm_load_print_meta: model type       = 33M
llm_load_print_meta: model ftype      = F16
llm_load_print_meta: model params     = 32.90 M
llm_load_print_meta: model size       = 62.78 MiB (16.01 BPW) 
llm_load_print_meta: general.name     = Jina Bert Implementation
llm_load_print_meta: BOS token        = 0 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 3 '<unk>'
llm_load_print_meta: SEP token        = 2 '</s>'
llm_load_print_meta: PAD token        = 1 '<pad>'
llm_load_print_meta: CLS token        = 0 '<s>'
llm_load_print_meta: MASK token       = 4 '<mask>'
llm_load_print_meta: EOG token        = 2 '</s>'
llm_load_print_meta: max token length = 45
llm_load_tensors: CPU_Mapped model buffer size =    62.78 MiB
......................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: n_batch    = 512
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:        CPU KV buffer size =     3.00 MiB
llama_new_context_with_model: KV self size  =    3.00 MiB, K (f16):    1.50 MiB, V (f16):    1.50 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.00 MiB
llama_new_context_with_model:        CPU compute buffer size =    16.00 MiB
llama_new_context_with_model: graph nodes  = 154
llama_new_context_with_model: graph splits = 1
AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 0 | AMX_INT8 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | RISCV_VECT = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
Model metadata: {'tokenizer.ggml.add_bos_token': 'true', 'tokenizer.ggml.add_eos_token': 'true', 'tokenizer.ggml.token_type_count': '2', 'tokenizer.ggml.cls_token_id': '0', 'tokenizer.ggml.padding_token_id': '1', 'tokenizer.ggml.seperator_token_id': '2', 'tokenizer.ggml.unknown_token_id': '3', 'tokenizer.ggml.eos_token_id': '2', 'general.quantization_version': '2', 'tokenizer.ggml.model': 'gpt2', 'general.architecture': 'jina-bert-v2', 'tokenizer.ggml.pre': 'jina-v1-en', 'general.name': 'Jina Bert Implementation', 'jina-bert-v2.attention.causal': 'false', 'jina-bert-v2.block_count': '4', 'general.organization': 'Jinaai', 'general.type': 'model', 'general.size_label': '33M', 'general.license': 'apache-2.0', 'tokenizer.ggml.mask_token_id': '4', 'jina-bert-v2.context_length': '8192', 'jina-bert-v2.embedding_length': '384', 'tokenizer.ggml.bos_token_id': '0', 'jina-bert-v2.attention.head_count': '12', 'jina-bert-v2.feed_forward_length': '1536', 'general.file_type': '1', 'jina-bert-v2.attention.layer_norm_epsilon': '0.000000'}
Using fallback chat format: llama-2
llama_perf_context_print:        load time =      17.98 ms
llama_perf_context_print: prompt eval time =       0.00 ms /    78 tokens (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =      18.03 ms /    79 tokens
[0.022738507017493248, 0.01924673095345497, 0.027259426191449165, 0.134610116481781]

@Joshua-Usi
Copy link

Wait this actually looks fire, why hasn't this been merged

@SubatomicPlanets
Copy link

Yeah, this looks really cool! The code for this is short and sweet, would love to see it merged!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants