Skip to content

Commit

Permalink
Rewrite TapasForQuestionAnswering wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
aleksandr-mokrov committed Oct 31, 2023
1 parent 7513c58 commit c68f452
Showing 1 changed file with 11 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,17 @@
"execution_count": null,
"outputs": [],
"source": [
"class TapasForQuestionAnswering(model.__class__): # it is better to keep the class name to avoid warnings\n",
"from transformers import TapasConfig\n",
"\n",
"\n",
"# get config for pretrained model\n",
"config = TapasConfig.from_pretrained('google/tapas-large-finetuned-wtq')\n",
"\n",
"\n",
"\n",
"class TapasForQuestionAnswering(TapasForQuestionAnswering): # it is better to keep the class name to avoid warnings\n",
" def __init__(self, ov_model_path):\n",
" super().__init__(model.config) # pass config from the original model\n",
" super().__init__(config) # pass config from the pretrained model\n",
" self.tqa_model = core.compile_model(ov_model_path, device.value)\n",
" \n",
" def forward(self, input_ids, *, attention_mask, token_type_ids):\n",
Expand Down Expand Up @@ -325,10 +333,9 @@
"\n",
"\n",
"def highlight_answers(x, coordinates):\n",
" color = \"background-color: lightgreen\"\n",
" highlighted_table = pd.DataFrame('', index=x.index, columns=x.columns)\n",
" for coordinates_i in coordinates:\n",
" highlighted_table.iloc[coordinates_i[0], coordinates_i[1]] = color\n",
" highlighted_table.iloc[coordinates_i[0], coordinates_i[1]] = \"background-color: lightgreen\"\n",
" \n",
" return highlighted_table\n",
"\n",
Expand Down

0 comments on commit c68f452

Please sign in to comment.