Skip to content

Commit

Permalink
numpy2 support (#1985)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgsavage authored May 12, 2024
1 parent cbdd79e commit cb0ec94
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]
numpy: [null, "numpy>=1.23,<2.0.0"]
numpy: [null, "numpy>=1.23,<2.0.0", "numpy>=2.0.0rc1"]
uncertainties: [null, "uncertainties==3.1.6", "uncertainties>=3.1.6,<4.0.0"]
extras: [null]
include:
Expand Down
7 changes: 4 additions & 3 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ Pint Changelog
0.24 (unreleased)
-----------------

- NumPy 2.0 support
(PR #1985, #1971)
- Implement numpy roll (Related to issue #981)
- Add `dim_sort` function to _formatter_helpers.
- Add `dim_order` and `default_sort_func` properties to FullFormatter.
(PR #1926, fixes Issue #1841)
- Fix LaTeX siuntix formatting when using non_int_type=decimal.Decimal.

- Fix converting to offset units of higher dimension e.g. gauge pressure (#1949).
-
- Fix converting to offset units of higher dimension e.g. gauge pressure
(PR #1949)

0.23 (2023-12-08)
-----------------
Expand Down
2 changes: 1 addition & 1 deletion pint/facets/numpy/numpy_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def _base_unit_if_needed(a):
raise OffsetUnitCalculusError(a.units)


# Can remove trapz wrapping when we only support numpy>=2
# NP2 Can remove trapz wrapping when we only support numpy>=2
@implements("trapz", "function")
@implements("trapezoid", "function")
def _trapz(y, x=None, dx=1.0, **kwargs):
Expand Down
15 changes: 14 additions & 1 deletion pint/testsuite/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,13 +438,23 @@ def test_cross(self):
np.cross(a, b), [[-15, -2, 39]] * self.ureg.kPa * self.ureg.m**2
)

# NP2: Remove this when we only support np>=2.0
@helpers.requires_array_function_protocol()
def test_trapz(self):
helpers.assert_quantity_equal(
np.trapz([1.0, 2.0, 3.0, 4.0] * self.ureg.J, dx=1 * self.ureg.m),
7.5 * self.ureg.J * self.ureg.m,
)

@helpers.requires_array_function_protocol()
def test_trapezoid(self):
# NP2: Remove this when we only support np>=2.0
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
helpers.assert_quantity_equal(
np.trapezoid([1.0, 2.0, 3.0, 4.0] * self.ureg.J, dx=1 * self.ureg.m),
7.5 * self.ureg.J * self.ureg.m,
)

@helpers.requires_array_function_protocol()
def test_dot(self):
helpers.assert_quantity_equal(
Expand Down Expand Up @@ -758,9 +768,12 @@ def test_minimum(self):
np.minimum(self.q, self.Q_([0, 5], "m")), self.Q_([[0, 2], [0, 4]], "m")
)

# NP2: Can remove Q_(arr).ptp test when we only support numpy>=2
def test_ptp(self):
assert self.q.ptp() == 3 * self.ureg.m
if not np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
assert self.q.ptp() == 3 * self.ureg.m

# NP2: Keep this test for numpy>=2, it's only arr.ptp() that is deprecated
@helpers.requires_array_function_protocol()
def test_ptp_numpy_func(self):
helpers.assert_quantity_equal(np.ptp(self.q, axis=0), [2, 2] * self.ureg.m)
Expand Down

0 comments on commit cb0ec94

Please sign in to comment.