[SEP] Token in the QA model response #2607

Open
aruggero opened this issue May 17, 2023 · 13 comments
Labels
bug Something isn't working

Comments

@aruggero

Description

The [SEP] token used in the input to the DJL "distilbert" Question Answering model is returned as part of the extracted answer to the question.
Shouldn't the answer be extracted from the context only, rather than from the concatenation of question + context?

Expected Behavior

I would expect the answer to come only from the context, without any delimiter tokens.

Error Message

With an input like:
var question = "BBC Japan broadcasting";
var resourceDocument = "BBC Japan was a general entertainment Channel.\nWhich operated between December 2004 and April 2006.\nIt ceased operations after its Japanese distributor folded.";

The following answer is returned: "bbc japan broadcasting [SEP] bbc japan"

How to Reproduce?

public class QuestionAnsweringBug {
    public static void main(String[] args) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
        var question = "BBC Japan broadcasting";
        var resourceDocument = "BBC Japan was a general entertainment Channel.\n" +
                "Which operated between December 2004 and April 2006.\n" +
                "It ceased operations after its Japanese distributor folded.";

        QAInput input = new QAInput(question, resourceDocument);

        ZooModel<QAInput, String> modelPyTorch = loadPyTorchModel();
        Predictor<QAInput, String> predictorPyTorch = modelPyTorch.newPredictor();
        String answerPyTorch = predictorPyTorch.predict(input);
        System.out.println(answerPyTorch);
    }

    public static ZooModel<QAInput, String> loadPyTorchModel() throws ModelNotFoundException, MalformedModelException, IOException {
        Criteria<QAInput, String> criteria = Criteria.builder()
                .optApplication(Application.NLP.QUESTION_ANSWER)
                .setTypes(QAInput.class, String.class)
                .optFilter("modelType", "distilbert")
                .optEngine("PyTorch") // Use PyTorch engine
                .optProgress(new ProgressBar()).build();
        return criteria.loadModel();
    }
}

Answer: "bbc japan broadcasting [SEP] bbc japan"

What have you tried to solve it?

  1. In the processOutput method of the translator, ai.djl.pytorch.zoo.nlp.qa.PtBertQATranslator#processOutput, the tokens variable is used to extract the answer. This variable is a concatenation of the question + the context: ["[CLS]", "bbc", "japan", "broadcasting", "[SEP]", "bbc", "japan", "was", "a", "general", ...
  2. The returned index is applied to this token list, so if it is less than the [SEP] index it can extract part of the question as the answer (a sketch of one possible work-around follows this list).
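
As an illustration only, not a fix in the library, here is a minimal sketch of how a custom processOutput could clamp the predicted span to the context portion of the token array; tokens, startLogits and endLogits are assumed to look like the ones in the custom translator shown further down in this thread:

// Hypothetical work-around sketch: restrict the answer span to the context,
// i.e. to tokens after the first [SEP].
int firstSep = 0;
while (firstSep < tokens.length && !"[SEP]".equals(tokens[firstSep])) {
    firstSep++;
}
int contextStart = firstSep + 1; // first context token

int startIdx = (int) startLogits.argMax().getLong();
int endIdx = (int) endLogits.argMax().getLong();
startIdx = Math.max(startIdx, contextStart); // never start inside the question or on [SEP]
endIdx = Math.max(endIdx, startIdx);

String answer = String.join(" ", Arrays.copyOfRange(tokens, startIdx, endIdx + 1));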

Environment Info

java.vm.vendor: Homebrew
java.version: 17.0.6
os.arch: aarch64
DJL version: 0.23.0-SNAPSHOT
OS: macOS Ventura 13.3.1
Chip: Apple M2 Pro
Memory: 16 GB
aruggero added the bug label May 17, 2023
@frankfliu
Contributor

I would recommend using a Hugging Face model; it uses the Hugging Face tokenizer and a more consistent post-processing translator, ai.djl.huggingface.translator.QuestionAnsweringTranslator:

String question = "When did BBC Japan start broadcasting?";
String paragraph =
        "BBC Japan was a general entertainment Channel. "
                + "Which operated between December 2004 and April 2006. "
                + "It ceased operations after its Japanese distributor folded.";

Criteria<QAInput, String> criteria = Criteria.builder()
                .setTypes(QAInput.class, String.class)
                .optModelUrls("djl://ai.djl.huggingface.pytorch/deepset/minilm-uncased-squad2")
                .optEngine("PyTorch")
                .optTranslatorFactory(new QuestionAnsweringTranslatorFactory())
                .optProgress(new ProgressBar())
                .build();

try (ZooModel<QAInput, String> model = criteria.loadModel();
     Predictor<QAInput, String> predictor = model.newPredictor()) {
    QAInput input = new QAInput(question, paragraph);
    String res = predictor.predict(input);
    System.out.println("answer: " + res);
}

See example: https://github.com/deepjavalibrary/djl-demo/blob/master/huggingface/nlp/src/main/java/com/examples/QuestionAnswering.java

@aruggero
Author

Thanks @frankfliu
I've tried the Hugging Face model.
The answer is better, though I'm still getting separator tokens:
[SEP] bbc japan was a general entertainment channel. which operated between december 2004 and april 2006. it ceased operations after its japanese distributor folded. [SEP]

@aruggero
Author

aruggero commented May 19, 2023

Another question: I see that the implementation of ai.djl.huggingface.translator.QuestionAnsweringTranslator is a bit different from the one suggested here: https://docs.djl.ai/master/docs/demos/jupyter/pytorch/load_your_own_pytorch_bert.html

I have this custom translator implemented in order to use this Hugging Face model:
public class HuggingFaceBERTQATranslator implements NoBatchifyTranslator<QAInput, String> {

    private String[] tokens;
    private HuggingFaceTokenizer tokenizer;

    @Override
    public void prepare(TranslatorContext ctx) throws IOException {
        tokenizer = HuggingFaceTokenizer.newInstance("distilbert-base-uncased-distilled-squad");
    }

    @Override
    public NDList processInput(TranslatorContext ctx, QAInput input) {
        Encoding encoding =
                tokenizer.encode(
                        input.getQuestion().toLowerCase(),
                        input.getParagraph().toLowerCase());
        tokens = encoding.getTokens();
        NDManager manager = ctx.getNDManager();

        long[] indices = encoding.getIds();
        NDArray indicesArray = manager.create(indices);

        return new NDList(indicesArray);
    }

    @Override
    public String processOutput(TranslatorContext ctx, NDList list) {
        NDArray startLogits = list.get(0);
        NDArray endLogits = list.get(1);
        int startIdx = (int) startLogits.argMax().getLong();
        int endIdx = (int) endLogits.argMax().getLong();
        return Arrays.toString(Arrays.copyOfRange(tokens, startIdx, endIdx + 1));
    }

    @Override
    public Batchifier getBatchifier() {
        return Batchifier.STACK;
    }
}

This has the same problem reported in this issue: separator tokens are returned in the response.
Should I change the translator implementation and adapt it to the one in ai.djl.huggingface.translator.QuestionAnsweringTranslator?

From what I can see, the token problem seems related more to the model than to the translator, because the returned indices are what's wrong. I don't know if changing the translator would help.

@frankfliu
Contributor

ai.djl.huggingface.translator.QuestionAnsweringTranslator is preferred.

It seems your model didn't return the correct result. Did you try running inference with Python?

@aruggero
Author

aruggero commented May 19, 2023

Yes, in Python it returns the first part of the context as the answer: "BBC Japan".
This is fine since the question, "BBC Japan broadcasting", is ambiguous.
In any case it returns no separator tokens, so the indices are different.

@frankfliu
Contributor

A few things to check:

  1. Check that the tokenizer generates an identical encoding (a sketch for dumping the Java-side encoding follows this list).
  2. Check whether the model forward pass returns identical tensors in both Java and Python.
  3. Make sure your processOutput matches the Python implementation.
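
For step 1, a minimal sketch (not from the thread) of how the Java-side encoding could be dumped and compared with the output of the corresponding tokenizer in Python:

// Sketch: print the DJL-side tokens and IDs so they can be compared with the
// Python tokenizer output for the same question/context pair.
try (HuggingFaceTokenizer tokenizer =
        HuggingFaceTokenizer.newInstance("distilbert-base-uncased-distilled-squad")) {
    Encoding encoding = tokenizer.encode(
            "BBC Japan broadcasting",
            "BBC Japan was a general entertainment Channel.");
    System.out.println(Arrays.toString(encoding.getTokens()));
    System.out.println(Arrays.toString(encoding.getIds()));
}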

@aruggero
Author

aruggero commented May 25, 2023

Hi @frankfliu, sorry for the delay.
Is this analysis necessary?
We also have the problem with the model provided by DJL, the distilbert one:

public static ZooModel<QAInput, String> loadPyTorchModel() throws ModelNotFoundException, MalformedModelException, IOException {
    Criteria<QAInput, String> criteria = Criteria.builder()
            .optApplication(Application.NLP.QUESTION_ANSWER)
            .setTypes(QAInput.class, String.class)
            .optFilter("modelType", "distilbert")
            .optEngine("PyTorch")
            .optProgress(new ProgressBar()).build();
    return criteria.loadModel();
}

The points you mentioned are, in this case, managed internally; I have no control over them.
Therefore I suppose there is something in the library that is not working 100% correctly...

@frankfliu
Contributor

@aruggero
I'm not able to reproduce your issue with the above code. The example can be found here: https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/BertQaInference.java

We run a nightly unit test against this model and didn't see the issue you mentioned.

@aruggero
Author

Hi @frankfliu
Did you use the code I put at the top of the issue?
The question I used, "BBC Japan broadcasting", is a bit different from the original one.
Did you try that one?

I see that your example is a bit different.
Thank you

@frankfliu
Contributor

@aruggero

Will take a look.

@frankfliu
Contributor

@aruggero

You can use .optArgument("addSpecialTokens", "false") for your case:

Criteria<QAInput, String> criteria = Criteria.builder()
                .setTypes(QAInput.class, String.class)
                .optModelUrls("djl://ai.djl.huggingface.pytorch/deepset/minilm-uncased-squad2")
                .optEngine("PyTorch")
                .optTranslatorFactory(new QuestionAnsweringTranslatorFactory())
                .optArgument("addSpecialTokens", "false")
                .optProgress(new ProgressBar())
                .build();

@aruggero
Author

Hi @frankfliu
From what I know, it is important to pass the model an input that reflects what it saw during training (i.e. with special tokens).
What does removing them imply for the model's predictions?

@demq
Contributor

demq commented Sep 15, 2023

(quoting @aruggero's earlier comment with the custom HuggingFaceBERTQATranslator implementation)

You can use Encoding::getSpecialTokenMask() from ai.djl.huggingface.tokenizers to limit the prediction to the original "context" and exclude the special tokens in your custom translator's processOutput().
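
For illustration, here is a minimal sketch of that idea, with two assumptions: the Encoding produced in processInput is kept in a field named encoding, and the model returns one start/end logit per token (no batch dimension). It is not the library's implementation:

@Override
public String processOutput(TranslatorContext ctx, NDList list) {
    float[] start = list.get(0).toFloatArray();
    float[] end = list.get(1).toFloatArray();
    long[] specialMask = encoding.getSpecialTokenMask(); // 1 for [CLS]/[SEP], 0 for regular tokens
    String[] tokens = encoding.getTokens();

    // The context starts after the first [SEP]; index 0 is [CLS], so scan from 1.
    int contextStart = 1;
    while (contextStart < specialMask.length && specialMask[contextStart] == 0) {
        contextStart++; // skip question tokens
    }
    contextStart++; // skip the [SEP] itself

    // Take the argmax only over non-special context tokens.
    int startIdx = -1;
    int endIdx = -1;
    for (int i = contextStart; i < tokens.length; i++) {
        if (specialMask[i] == 1) {
            continue;
        }
        if (startIdx < 0 || start[i] > start[startIdx]) {
            startIdx = i;
        }
        if (endIdx < 0 || end[i] > end[endIdx]) {
            endIdx = i;
        }
    }
    endIdx = Math.max(endIdx, startIdx);
    return String.join(" ", Arrays.copyOfRange(tokens, startIdx, endIdx + 1));
}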

In general, I would recommend following @frankfliu's suggestion to run all the inference steps in Python with a traced version of the model of interest on a simple input, and compare the raw outputs/logits to those you get in DJL. Then you would know whether it is an issue with the model itself behaving differently under DJL (highly unlikely) or whether the pre/post-processing steps differ (you can look up the HF source code for the details). As to why the model predicts the special tokens to be part of the answer: I don't think any strict mask is applied to the output to prevent it, and the fine-tuning objective might not penalize it (enough).
