From 1b43efdda54d0612c28f674512b056d7a334a870 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 25 Oct 2024 20:57:53 +0200 Subject: [PATCH] input validation --- src/geoarrow.cpp | 5 +++++ tests/test_geoarrow.py | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/src/geoarrow.cpp b/src/geoarrow.cpp index 1fc8729..aebb733 100644 --- a/src/geoarrow.cpp +++ b/src/geoarrow.cpp @@ -14,6 +14,11 @@ py::array_t from_geoarrow(py::object input, bool planar, float tessellate_tolerance, py::object geometry_encoding) { + if (!py::hasattr(input, "__arrow_c_array__")) { + throw std::invalid_argument( + "input should be an Arrow-compatible array object (i.e. has an '__arrow_c_array__' " + "method)"); + } py::tuple capsules = input.attr("__arrow_c_array__")(); py::capsule schema_capsule = capsules[0]; py::capsule array_capsule = capsules[1]; diff --git a/tests/test_geoarrow.py b/tests/test_geoarrow.py index d9a83b8..02780d5 100644 --- a/tests/test_geoarrow.py +++ b/tests/test_geoarrow.py @@ -1,5 +1,6 @@ from packaging.version import Version +import numpy as np import pyarrow as pa import geoarrow.pyarrow as ga @@ -102,3 +103,8 @@ def test_from_geoarrow_invalid_encoding(): with pytest.raises(ValueError, match="'geometry_encoding' should be one"): spherely.from_geoarrow(arr, geometry_encoding="point") + + +def test_from_geoarrow_no_arrow_object(): + with pytest.raises(ValueError, match="input should be an Arrow-compatible array"): + spherely.from_geoarrow(np.array(["POINT (1 1)"], dtype=object))