From 3ac47cd65c850d61190cc1fec31879b597091021 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Mon, 28 Oct 2024 19:26:36 -0700 Subject: [PATCH] Add type hints for get_model_output Signed-off-by: Vibhu Jawa --- crossfit/backend/torch/model.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/crossfit/backend/torch/model.py b/crossfit/backend/torch/model.py index 61f9d57..184a308 100644 --- a/crossfit/backend/torch/model.py +++ b/crossfit/backend/torch/model.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. - from enum import Enum from typing import Any, List, Union import cudf import cupy as cp +import torch from crossfit.backend.cudf.series import ( create_list_series_from_1d_or_2d_ar, @@ -82,7 +82,13 @@ def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int: def max_seq_length(self) -> int: raise NotImplementedError() - def get_model_output(self, all_outputs_ls, index, loader, pred_output_col) -> cudf.DataFrame: + def get_model_output( + self, + all_outputs_ls: List[Union[dict, torch.Tensor]], + index: Union[cudf.Index], + loader: Any, + pred_output_col: str, + ) -> cudf.DataFrame: # importing here to avoid cyclic import error from crossfit.backend.torch.loader import SortedSeqLoader