From b9350e54daa34f1653cd8fa68100df34217611fe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 17:39:29 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- elk/extraction/extraction.py | 1 + elk/plotting/visualize.py | 6 +++--- elk/promptsource/templates.py | 8 +++++--- elk/training/platt_scaling.py | 3 +-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 9c082c52f..1cf3bf446 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -1,4 +1,5 @@ """Functions for extracting the hidden states of a model.""" + import os from collections import defaultdict from dataclasses import InitVar, dataclass, replace diff --git a/elk/plotting/visualize.py b/elk/plotting/visualize.py index 85eedd43d..93fbd650f 100644 --- a/elk/plotting/visualize.py +++ b/elk/plotting/visualize.py @@ -78,9 +78,9 @@ def render( y=dataset_data["auroc_estimate"], mode="lines", name=ensemble, - showlegend=False - if dataset_name != unique_datasets[0] - else True, + showlegend=( + False if dataset_name != unique_datasets[0] else True + ), line=dict(color=color_map[ensemble]), ), row=row, diff --git a/elk/promptsource/templates.py b/elk/promptsource/templates.py index 7d4c0b845..8d93a40c0 100644 --- a/elk/promptsource/templates.py +++ b/elk/promptsource/templates.py @@ -215,9 +215,11 @@ def _escape_pipe(cls, example): # Replaces any occurrences of the "|||" separator in the example, which # which will be replaced back after splitting protected_example = { - key: value.replace("|||", cls.pipe_protector) - if isinstance(value, str) - else value + key: ( + value.replace("|||", cls.pipe_protector) + if isinstance(value, str) + else value + ) for key, value in example.items() } return protected_example diff --git a/elk/training/platt_scaling.py b/elk/training/platt_scaling.py index 278d8d95d..70dd87c30 100644 --- a/elk/training/platt_scaling.py +++ b/elk/training/platt_scaling.py @@ -12,8 +12,7 @@ class PlattMixin(ABC): scale: nn.Parameter @abstractmethod - def __call__(self, *args: Any, **kwds: Any) -> Any: - ... + def __call__(self, *args: Any, **kwds: Any) -> Any: ... def platt_scale(self, labels: Tensor, hiddens: Tensor, max_iter: int = 100): """Fit the scale and bias terms to data with LBFGS.