From 99c1b1b8745d591521e60147265a48c258917e84 Mon Sep 17 00:00:00 2001 From: James Raskind Date: Fri, 15 Nov 2024 12:02:35 +0300 Subject: [PATCH] . --- autointent/modules/prediction/adaptive.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/autointent/modules/prediction/adaptive.py b/autointent/modules/prediction/adaptive.py index ba4c08c..e40864a 100644 --- a/autointent/modules/prediction/adaptive.py +++ b/autointent/modules/prediction/adaptive.py @@ -20,7 +20,7 @@ class AdaptivePredictorDumpMetadata(TypedDict): r: float multilabel: bool - tags: list[Tag] + tags: list[Tag] | None class AdaptivePredictor(PredictionModule): @@ -79,7 +79,7 @@ def load(self, path: str) -> None: self._r = metadata["r"] self.multilabel = metadata["multilabel"] - self.tags = [Tag(**tag) for tag in metadata["tags"] if metadata["tags"] and isinstance(metadata["tags"], list)] # type: ignore[arg-type, union-attr] + self.tags = [Tag(**tag) for tag in metadata["tags"] if metadata["tags"] and isinstance(metadata["tags"], list)] # type: ignore[arg-type] self.metadata = metadata @@ -95,7 +95,8 @@ def multilabel_predict(scores: npt.NDArray[Any], r: float, tags: list[Tag] | Non return res -def multilabel_score(y_true: list[LabelType] | list[list[LabelType]], y_pred: list[LabelType] | list[list[LabelType]]) -> float: +def multilabel_score(y_true: list[LabelType], + y_pred: npt.NDArray[Any]) -> float: y_true_, y_pred_ = transform(y_true, y_pred)