Skip to content

Commit

Permalink
Add type hints for get_model_output
Browse files Browse the repository at this point in the history
Signed-off-by: Vibhu Jawa <[email protected]>
  • Loading branch information
VibhuJawa committed Oct 29, 2024
1 parent a268062 commit 3ac47cd
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions crossfit/backend/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from enum import Enum
from typing import Any, List, Union

import cudf
import cupy as cp
import torch

from crossfit.backend.cudf.series import (
create_list_series_from_1d_or_2d_ar,
Expand Down Expand Up @@ -82,7 +82,13 @@ def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:
def max_seq_length(self) -> int:
raise NotImplementedError()

def get_model_output(self, all_outputs_ls, index, loader, pred_output_col) -> cudf.DataFrame:
def get_model_output(
self,
all_outputs_ls: List[Union[dict, torch.Tensor]],
index: Union[cudf.Index],
loader: Any,
pred_output_col: str,
) -> cudf.DataFrame:
# importing here to avoid cyclic import error
from crossfit.backend.torch.loader import SortedSeqLoader

Expand Down

0 comments on commit 3ac47cd

Please sign in to comment.