Allow fetching token embeddings from a cross-encoding #407

Open
dan-octo opened this issue Sep 13, 2024 · 0 comments
Feature request

It would be nice to be able to fetch the token embeddings of a cross-encoding, which is necessary for implementing systems such as retrieval-augmented named entity recognition (RA-NER).

Ideally, it would be implemented as an endpoint akin to the existing /embed_all endpoint, but taking an additional argument that plays the role of the text_pair argument of Transformers tokenizers.
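For readers less familiar with cross-encodings, here is a minimal sketch of how a BERT-style tokenizer lays out a text/text_pair input and which token_type_ids it assigns. It assumes BERT's [CLS]/[SEP] convention and takes already-split tokens; `encode_pair` is a hypothetical helper, not part of any library, and other model families may use different special tokens or not return token_type_ids at all.

```python
from typing import List, Tuple


def encode_pair(
    text_tokens: List[str], pair_tokens: List[str]
) -> Tuple[List[str], List[int]]:
    """Hypothetical helper: lay out a sentence pair BERT-style.

    Assumes BERT's [CLS]/[SEP] special tokens; real tokenizers also
    split words into subwords, which is skipped here.
    """
    # Segment A: [CLS] text [SEP] -> token_type_id 0
    tokens = ["[CLS]"] + text_tokens + ["[SEP]"]
    token_type_ids = [0] * len(tokens)
    # Segment B: text_pair [SEP] -> token_type_id 1
    tokens += pair_tokens + ["[SEP]"]
    token_type_ids += [1] * (len(pair_tokens) + 1)
    return tokens, token_type_ids


tokens, type_ids = encode_pair(["this", "is", "a", "query", "."], ["a", "doc", "."])
print(list(zip(tokens, type_ids)))
```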

In addition to the token embeddings, this new endpoint would return token_type_ids, so as to be able to distinguish which token embeddings represent tokens from which sequence (text or text_pair, in the parlance of Transformers tokenizers).
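On the client side, the returned token_type_ids would make it straightforward to separate the two sequences. A minimal sketch, assuming the response format proposed here (one embedding per token, aligned index-by-index with token_type_ids, with 0 for text and 1 for text_pair); `split_by_segment` is a hypothetical helper.

```python
from typing import List, Tuple

Vector = List[float]


def split_by_segment(
    token_embeddings: List[Vector], token_type_ids: List[int]
) -> Tuple[List[Vector], List[Vector]]:
    """Hypothetical helper: split one item's embeddings into text and text_pair parts.

    Assumes embeddings and token_type_ids are aligned index-by-index,
    with 0 marking the first sequence and 1 the second.
    """
    if len(token_embeddings) != len(token_type_ids):
        raise ValueError("token_embeddings and token_type_ids must align")
    text_part = [e for e, t in zip(token_embeddings, token_type_ids) if t == 0]
    pair_part = [e for e, t in zip(token_embeddings, token_type_ids) if t == 1]
    return text_part, pair_part


# Toy 2-dimensional embeddings for a 5-token cross-encoding.
embeddings = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8], [0.9, 1.0]]
type_ids = [0, 0, 0, 1, 1]
text_emb, pair_emb = split_by_segment(embeddings, type_ids)
```

Special tokens carry the type id of their segment, so in a BERT-style layout the [CLS] and first [SEP] embeddings end up in the text part; a downstream NER head that should only see real tokens would need to drop them explicitly.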

Additionally, I believe this would help round out the API, as this functionality is available in the Transformers library but not here.

A minimal working example (MWE) of how I would like to call the endpoint:

```python
import asyncio

import aiohttp


async def main():
    # "/embed_all_cross_encoding" and "inputs_pair" are the endpoint and
    # field proposed in this issue.
    payload = {
        "inputs": ["This is a query.", "This is a second query."],
        "inputs_pair": ["This is a doc for query 1.", "This is a doc for query 2."],
    }
    # Use the session as a context manager so it is closed when done.
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://127.0.0.1:8080/embed_all_cross_encoding",
            headers={"Content-Type": "application/json"},
            json=payload,
        ) as response:
            response.raise_for_status()
            data = await response.json()

    token_embeddings = data["token_embeddings"]
    token_type_ids = data["token_type_ids"]


if __name__ == "__main__":
    asyncio.run(main())
```

where token_embeddings has shape (batch_size, sequence_length, n_dims) and token_type_ids has shape (batch_size, sequence_length).
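A client could sanity-check a response against these shapes before using it. A minimal sketch, assuming both JSON fields arrive as nested lists as in the MWE; `check_response_shapes` is a hypothetical helper. It checks sequence_length per item rather than requiring one length for the whole batch, so it accepts both padded and unpadded responses.

```python
from typing import List, Optional


def check_response_shapes(
    token_embeddings: List[List[List[float]]],
    token_type_ids: List[List[int]],
    batch_size: int,
) -> int:
    """Hypothetical helper: validate the proposed response and return n_dims.

    token_embeddings: batch_size x sequence_length x n_dims
    token_type_ids:   batch_size x sequence_length
    """
    if len(token_embeddings) != batch_size or len(token_type_ids) != batch_size:
        raise ValueError("batch dimension mismatch")
    n_dims: Optional[int] = None
    for emb, types in zip(token_embeddings, token_type_ids):
        # Each item's embeddings must line up one-to-one with its type ids.
        if len(emb) != len(types):
            raise ValueError("sequence_length mismatch between fields")
        for vector in emb:
            if n_dims is None:
                n_dims = len(vector)
            elif len(vector) != n_dims:
                raise ValueError("inconsistent n_dims")
    if n_dims is None:
        raise ValueError("empty response")
    return n_dims


# Toy response for a batch of two pairs with 2-dimensional embeddings.
data = {
    "token_embeddings": [
        [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
        [[0.7, 0.8], [0.9, 1.0]],
    ],
    "token_type_ids": [[0, 0, 1], [0, 1]],
}
n_dims = check_response_shapes(data["token_embeddings"], data["token_type_ids"], batch_size=2)
```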

Motivation

Fetching token embeddings from a cross-encoding serves two purposes:

(i) Enables implementation of systems such as RA-NER.

(ii) Helps round out the API by bringing over functionality that is available in the Transformers library but not yet available here.

Your contribution

I could contribute to examples and/or documentation. Thank you!
