Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
JakeRaskind committed Nov 15, 2024
1 parent 521c05d commit 99c1b1b
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions autointent/modules/prediction/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
class AdaptivePredictorDumpMetadata(TypedDict):
r: float
multilabel: bool
tags: list[Tag]
tags: list[Tag] | None


class AdaptivePredictor(PredictionModule):
Expand Down Expand Up @@ -79,7 +79,7 @@ def load(self, path: str) -> None:

self._r = metadata["r"]
self.multilabel = metadata["multilabel"]
self.tags = [Tag(**tag) for tag in metadata["tags"] if metadata["tags"] and isinstance(metadata["tags"], list)] # type: ignore[arg-type, union-attr]
self.tags = [Tag(**tag) for tag in metadata["tags"] if metadata["tags"] and isinstance(metadata["tags"], list)] # type: ignore[arg-type]
self.metadata = metadata


Expand All @@ -95,7 +95,8 @@ def multilabel_predict(scores: npt.NDArray[Any], r: float, tags: list[Tag] | Non
return res


def multilabel_score(y_true: list[LabelType] | list[list[LabelType]], y_pred: list[LabelType] | list[list[LabelType]]) -> float:
def multilabel_score(y_true: list[LabelType],
y_pred: npt.NDArray[Any]) -> float:

y_true_, y_pred_ = transform(y_true, y_pred)

Expand Down

0 comments on commit 99c1b1b

Please sign in to comment.