From c68f4522ea5c0f21f49a807dfd2b7a9309716ea0 Mon Sep 17 00:00:00 2001 From: Aleksandr Mokrov Date: Tue, 31 Oct 2023 14:06:45 +0100 Subject: [PATCH] Rewrite TapasForQuestionAnswering wrapper --- .../266-table-question-answering.ipynb | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/notebooks/266-table-question-answering/266-table-question-answering.ipynb b/notebooks/266-table-question-answering/266-table-question-answering.ipynb index 74b0c525d64..d2a3dca9715 100644 --- a/notebooks/266-table-question-answering/266-table-question-answering.ipynb +++ b/notebooks/266-table-question-answering/266-table-question-answering.ipynb @@ -272,9 +272,17 @@ "execution_count": null, "outputs": [], "source": [ - "class TapasForQuestionAnswering(model.__class__): # it is better to keep the class name to avoid warnings\n", + "from transformers import TapasConfig\n", + "\n", + "\n", + "# get config for pretrained model\n", + "config = TapasConfig.from_pretrained('google/tapas-large-finetuned-wtq')\n", + "\n", + "\n", + "\n", + "class TapasForQuestionAnswering(TapasForQuestionAnswering): # it is better to keep the class name to avoid warnings\n", " def __init__(self, ov_model_path):\n", - " super().__init__(model.config) # pass config from the original model\n", + " super().__init__(config) # pass config from the pretrained model\n", " self.tqa_model = core.compile_model(ov_model_path, device.value)\n", " \n", " def forward(self, input_ids, *, attention_mask, token_type_ids):\n", @@ -325,10 +333,9 @@ "\n", "\n", "def highlight_answers(x, coordinates):\n", - " color = \"background-color: lightgreen\"\n", " highlighted_table = pd.DataFrame('', index=x.index, columns=x.columns)\n", " for coordinates_i in coordinates:\n", - " highlighted_table.iloc[coordinates_i[0], coordinates_i[1]] = color\n", + " highlighted_table.iloc[coordinates_i[0], coordinates_i[1]] = \"background-color: lightgreen\"\n", " \n", " return highlighted_table\n", "\n",