Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle exponents when converting str to float #891

Merged
merged 8 commits into from
Nov 25, 2024
Merged
46 changes: 46 additions & 0 deletions src/inspect_ai/_util/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,49 @@ def truncate_string_to_bytes(input: str, max_bytes: int) -> TruncatedOutput | No
except Exception as ex:
logger.warning(f"Unexpected error occurred truncating string: {ex}")
return None


def str_to_float(s: str) -> float:
    """Convert a str to float, including handling exponent characters.

    The Python isnumeric() function returns True for strings that include
    superscript exponents (e.g. 5²), however the float() function doesn't
    handle them. This function splits the input at the first superscript
    digit, parses everything before it with float() as the base, and raises
    that base to the power spelled out by the superscript digits.

    Args:
        s (str): String to convert to float

    Returns:
        float: Converted value. A string without superscript digits is parsed
            by float() alone; a string made only of superscript digits uses
            a base of 1.0.

    Raises:
        ValueError: If the string is empty, if the base is not accepted by
            float(), or if anything other than superscript digits follows
            the first superscript digit (e.g. '2²3').
        OverflowError: If the power is too large to represent as a float.
    """
    # handle empty input
    if not s:
        raise ValueError("Input string is empty.")

    superscript_digits = "⁰¹²³⁴⁵⁶⁷⁸⁹"
    superscript_map = str.maketrans(superscript_digits, "0123456789")

    # Everything before the first superscript digit is the base; the rest is
    # the exponent. With no superscript digit the whole string is the base.
    split_at = next(
        (idx for idx, char in enumerate(s) if char in superscript_digits),
        len(s),
    )
    base_part, exponent_part = s[:split_at], s[split_at:]

    # handle empty base (e.g., '²')
    base = float(base_part) if base_part else 1.0

    if not exponent_part:
        return base

    # Trailing whitespace is tolerated, as float() tolerates it for a plain
    # number. Any other non-superscript character is rejected explicitly:
    # translating and handing the text to int() would silently merge ordinary
    # digits or underscores into the exponent ('2²3' would become 2**23).
    exponent_digits = exponent_part.rstrip()
    if any(char not in superscript_digits for char in exponent_digits):
        raise ValueError(f"Invalid exponent in numeric string: {s!r}")

    exponent = int(exponent_digits.translate(superscript_map))
    return base**exponent
8 changes: 6 additions & 2 deletions src/inspect_ai/scorer/_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Callable, Literal

from inspect_ai._util.text import strip_numeric_punctuation, strip_punctuation
from inspect_ai._util.text import (
str_to_float,
strip_numeric_punctuation,
strip_punctuation,
)
from inspect_ai.solver._task_state import TaskState

from ._metric import CORRECT, INCORRECT, Score
Expand Down Expand Up @@ -96,7 +100,7 @@ def first_number_normalized(words: list[str]) -> str:

def normalize_number(number: str, precision: int = 5) -> str:
    """Format a numeric string to a fixed number of significant digits.

    Args:
        number (str): Candidate numeric text. It is treated as numeric when,
            ignoring '.' characters, str.isnumeric() is True.
        precision (int): Significant digits used for the 'g' format.

    Returns:
        str: The value formatted with f".{precision}g" when the text can be
            converted by str_to_float(); otherwise the input unchanged.
    """
    if number.replace(".", "").isnumeric():
        try:
            # str_to_float (not float) so superscript exponents such as '5²'
            # are evaluated rather than rejected.
            num = str_to_float(number)
        except (ValueError, OverflowError):
            # isnumeric() also accepts text that cannot be converted (e.g.
            # '½' or '1.2.3'); treat it like any other non-numeric input.
            return number
        return format(num, f".{precision}g")
    else:
        return number
64 changes: 64 additions & 0 deletions tests/util/test_str_to_float.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest

from inspect_ai._util.text import str_to_float


def test_str_to_float_basic():
    """Integer bases with a single superscript digit, plus a plain integer."""
    expected_by_input = {
        "1²": 1.0,
        "2³": 8.0,
        "5⁴": 625.0,
        "10⁰": 1.0,
        "3": 3.0,
    }
    for text, expected in expected_by_input.items():
        assert str_to_float(text) == expected


def test_str_to_float_decimal_base():
    """A decimal base is raised to the superscript power."""
    cases = (
        ("2.5²", 2.5, 2),
        ("0.1³", 0.1, 3),
    )
    for text, base, power in cases:
        assert str_to_float(text) == base**power


def test_str_to_float_negative_base():
    """The sign belongs to the base, so the power applies to the negative value."""
    for text, power in (("-2²", 2), ("-2³", 3)):
        assert str_to_float(text) == (-2) ** power


def test_str_to_float_multi_digit_exponent():
    """Consecutive superscript digits form one multi-digit exponent."""
    large_power = str_to_float("2⁴⁵")
    assert large_power == 2**45

    # Exponent is 0: leading superscript zeros collapse to a zero exponent.
    padded_zero = str_to_float("3⁰⁰⁰")
    assert padded_zero == 3**0


def test_str_to_float_no_exponent():
    """Strings without a superscript convert like a plain float."""
    for text, expected in (("7", 7.0), ("0", 0.0)):
        assert str_to_float(text) == expected


def test_str_to_float_no_base():
    """A superscript with nothing before it uses a base of 1.0."""
    # When the base is missing, default to 1.0
    default_base = 1.0
    for text, power in (("⁵", 5), ("⁰", 0)):
        assert str_to_float(text) == default_base**power


def test_str_to_float_zero_exponent():
    """A zero exponent yields 1.0, including for a zero base."""
    nonzero_base = str_to_float("5⁰")
    assert nonzero_base == 1.0

    # 0^0 is considered 1 in this context
    zero_base = str_to_float("0⁰")
    assert zero_base == 1.0


def test_str_to_float_invalid_input():
    """Non-numeric, empty, caret-style and unsupported-superscript input is rejected."""
    rejected_inputs = (
        "abc",
        "",
        "2^3",
        "⁺²",  # Unsupported superscript characters
    )
    for text in rejected_inputs:
        with pytest.raises(ValueError):
            str_to_float(text)


def test_str_to_float_edge_cases():
    """Unsupported characters are rejected; a decimal base with a superscript works."""
    rejected_inputs = (
        "2⁻³",  # Exponent with unsupported characters
        "a²",  # Base with unsupported characters
    )
    for text in rejected_inputs:
        with pytest.raises(ValueError):
            str_to_float(text)

    # Superscript after decimal point
    decimal_result = str_to_float("2.5⁴")
    assert decimal_result == 2.5**4
Loading