From e8edca2553a693bf1dc244c9c2eadfca67557dcd Mon Sep 17 00:00:00 2001 From: Simon Lin Date: Tue, 22 Oct 2024 20:47:52 +1100 Subject: [PATCH] c --- .../src/cloud/credential_provider.rs | 33 ++- crates/polars-plan/src/dsl/python_udf.rs | 21 +- crates/polars-utils/src/python_function.rs | 226 ++++++++++++++---- py-polars/tests/unit/io/cloud/test_cloud.py | 49 ++++ 4 files changed, 273 insertions(+), 56 deletions(-) diff --git a/crates/polars-io/src/cloud/credential_provider.rs b/crates/polars-io/src/cloud/credential_provider.rs index ddd23568e0268..409a8bd71c3e6 100644 --- a/crates/polars-io/src/cloud/credential_provider.rs +++ b/crates/polars-io/src/cloud/credential_provider.rs @@ -457,9 +457,14 @@ mod python_impl { use super::IntoCredentialProvider; #[derive(Clone, Debug)] - #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct PythonCredentialProvider(pub(super) Arc); + impl From for PythonCredentialProvider { + fn from(value: PythonFunction) -> Self { + Self(Arc::new(value)) + } + } + impl IntoCredentialProvider for PythonCredentialProvider { #[cfg(feature = "aws")] fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider { @@ -665,6 +670,32 @@ mod python_impl { state.write_usize(Arc::as_ptr(&self.0) as *const () as usize) } } + + #[cfg(feature = "serde")] + mod _serde_impl { + use polars_utils::python_function::PySerializeWrap; + + use super::PythonCredentialProvider; + + impl serde::Serialize for PythonCredentialProvider { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + PySerializeWrap(self.0.as_ref()).serialize(serializer) + } + } + + impl<'a> serde::Deserialize<'a> for PythonCredentialProvider { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'a>, + { + PySerializeWrap::::deserialize(deserializer) + .map(|x| x.0.into()) + } + } + } } #[cfg(test)] diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs index 1d8080855e3c8..521e3a7809e88 100644 --- a/crates/polars-plan/src/dsl/python_udf.rs +++ b/crates/polars-plan/src/dsl/python_udf.rs @@ -23,7 +23,7 @@ pub static mut CALL_DF_UDF_PYTHON: Option< > = None; pub use polars_utils::python_function::{ - PythonFunction, PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON_VERSION_MINOR, + PythonFunction, PYTHON3_VERSION, PYTHON_SERDE_MAGIC_BYTE_MARK, }; pub struct PythonUdfExpression { @@ -57,14 +57,14 @@ impl PythonUdfExpression { // Handle pickle metadata let use_cloudpickle = buf[0]; if use_cloudpickle != 0 { - let ser_py_version = buf[1]; - let cur_py_version = *PYTHON_VERSION_MINOR; + let ser_py_version = &buf[..2]; + let cur_py_version = *PYTHON3_VERSION; polars_ensure!( ser_py_version == cur_py_version, InvalidOperation: - "current Python version (3.{}) does not match the Python version used to serialize the UDF (3.{})", - cur_py_version, - ser_py_version + "current Python version {:?} does not match the Python version used to serialize the UDF {:?}", + (3, cur_py_version[0], cur_py_version[1]), + (3, ser_py_version[0], ser_py_version[1] ) ); } let buf = &buf[2..]; @@ -141,8 +141,8 @@ impl ColumnsUdf for PythonUdfExpression { .getattr("dumps") .unwrap(); let pickle_result = pickle.call1((self.python_function.clone_ref(py),)); - let (dumped, use_cloudpickle, py_version) = match pickle_result { - Ok(dumped) => (dumped, false, 0), + let (dumped, use_cloudpickle) = match pickle_result { + Ok(dumped) => (dumped, false), Err(_) => { let cloudpickle = PyModule::import_bound(py, "cloudpickle") .map_err(from_pyerr)? @@ -151,12 +151,13 @@ impl ColumnsUdf for PythonUdfExpression { let dumped = cloudpickle .call1((self.python_function.clone_ref(py),)) .map_err(from_pyerr)?; - (dumped, true, *PYTHON_VERSION_MINOR) + (dumped, true) }, }; // Write pickle metadata - buf.extend_from_slice(&[use_cloudpickle as u8, py_version]); + buf.push(use_cloudpickle as u8); + buf.extend_from_slice(&*PYTHON3_VERSION); // Write UDF metadata ciborium::ser::into_writer( diff --git a/crates/polars-utils/src/python_function.rs b/crates/polars-utils/src/python_function.rs index 551d7cda70ae9..f3b8f2d5fa8c0 100644 --- a/crates/polars-utils/src/python_function.rs +++ b/crates/polars-utils/src/python_function.rs @@ -1,15 +1,14 @@ -use once_cell::sync::Lazy; +use polars_error::{polars_bail, PolarsError, PolarsResult}; use pyo3::prelude::*; use pyo3::pybacked::PyBackedBytes; use pyo3::types::PyBytes; #[cfg(feature = "serde")] -use serde::ser::Error; -#[cfg(feature = "serde")] -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +pub use serde_wrap::{ + PySerializeWrap, TrySerializeToBytes, PYTHON3_VERSION, + SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK, +}; -#[cfg(feature = "serde")] -pub const PYTHON_SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes(); -pub static PYTHON_VERSION_MINOR: Lazy = Lazy::new(get_python_minor_version); +use crate::flatten; #[derive(Debug)] pub struct PythonFunction(pub PyObject); @@ -42,64 +41,201 @@ impl PartialEq for PythonFunction { } #[cfg(feature = "serde")] -impl Serialize for PythonFunction { +impl serde::Serialize for PythonFunction { fn serialize(&self, serializer: S) -> std::result::Result where - S: Serializer, + S: serde::Serializer, { - Python::with_gil(|py| { - let pickle = PyModule::import_bound(py, "cloudpickle") - .or_else(|_| PyModule::import_bound(py, "pickle")) - .expect("unable to import 'cloudpickle' or 'pickle'") - .getattr("dumps") - .unwrap(); - - let python_function = self.0.clone_ref(py); - - let dumped = pickle - .call1((python_function,)) - .map_err(|s| S::Error::custom(format!("cannot pickle {s}")))?; - let dumped = dumped.extract::().unwrap(); - - serializer.serialize_bytes(&dumped) - }) + use serde::ser::Error; + serializer.serialize_bytes( + self.try_serialize_to_bytes() + .map_err(|e| S::Error::custom(e.to_string()))? + .as_slice(), + ) } } #[cfg(feature = "serde")] -impl<'a> Deserialize<'a> for PythonFunction { +impl<'a> serde::Deserialize<'a> for PythonFunction { fn deserialize(deserializer: D) -> std::result::Result where - D: Deserializer<'a>, + D: serde::Deserializer<'a>, { use serde::de::Error; let bytes = Vec::::deserialize(deserializer)?; + Self::try_deserialize_bytes(bytes.as_slice()).map_err(|e| D::Error::custom(e.to_string())) + } +} - Python::with_gil(|py| { - let pickle = PyModule::import_bound(py, "pickle") - .expect("unable to import 'pickle'") - .getattr("loads") +#[cfg(feature = "serde")] +impl TrySerializeToBytes for PythonFunction { + fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult> { + serialize_pyobject_with_cloudpickle_fallback(&self.0) + } + + fn try_deserialize_bytes(bytes: &[u8]) -> polars_error::PolarsResult { + deserialize_pyobject_bytes_maybe_cloudpickle(bytes) + } +} + +pub fn serialize_pyobject_with_cloudpickle_fallback(py_object: &PyObject) -> PolarsResult> { + Python::with_gil(|py| { + let pickle = PyModule::import_bound(py, "pickle") + .expect("unable to import 'pickle'") + .getattr("dumps") + .unwrap(); + + let dumped = pickle.call1((py_object.clone_ref(py),)); + + let (dumped, used_cloudpickle) = if let Ok(v) = dumped { + (v, false) + } else { + let cloudpickle = PyModule::import_bound(py, "cloudpickle") + .map_err(from_pyerr)? + .getattr("dumps") .unwrap(); - let arg = (PyBytes::new_bound(py, &bytes),); - let python_function = pickle - .call1(arg) - .map_err(|s| D::Error::custom(format!("cannot pickle {s}")))?; + let dumped = cloudpickle + .call1((py_object.clone_ref(py),)) + .map_err(from_pyerr)?; + (dumped, true) + }; - Ok(Self(python_function.into())) - }) + let py_bytes = dumped.extract::().map_err(from_pyerr)?; + + Ok(flatten( + &[&[used_cloudpickle as u8, b'C'][..], py_bytes.as_ref()], + None, + )) + }) +} + +pub fn deserialize_pyobject_bytes_maybe_cloudpickle From>( + bytes: &[u8], +) -> PolarsResult { + // TODO: Actually deserialize with cloudpickle if it's set. + let [_used_cloudpickle @ 0 | _used_cloudpickle @ 1, b'C', rem @ ..] = bytes else { + polars_bail!(ComputeError: "deserialize_pyobject_bytes_maybe_cloudpickle: invalid start bytes") + }; + + let bytes = rem; + + Python::with_gil(|py| { + let pickle = PyModule::import_bound(py, "pickle") + .expect("unable to import 'pickle'") + .getattr("loads") + .unwrap(); + let arg = (PyBytes::new_bound(py, bytes),); + let pyany_bound = pickle.call1(arg).map_err(from_pyerr)?; + Ok(PyObject::from(pyany_bound).into()) + }) +} + +#[cfg(feature = "serde")] +mod serde_wrap { + use once_cell::sync::Lazy; + use polars_error::PolarsResult; + + use crate::flatten; + + pub const SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYFN".as_bytes(); + /// [minor, micro] + pub static PYTHON3_VERSION: Lazy<[u8; 2]> = Lazy::new(super::get_python3_version); + + /// Serializes a Python object without additional system metadata. This is intended to be used + /// together with `PySerializeWrap`, which attaches e.g. Python version metadata. + pub trait TrySerializeToBytes: Sized { + fn try_serialize_to_bytes(&self) -> PolarsResult>; + fn try_deserialize_bytes(bytes: &[u8]) -> PolarsResult; + } + + /// Serialization wrapper for T: TrySerializeToBytes that attaches Python + /// version metadata. + pub struct PySerializeWrap(pub T); + + impl serde::Serialize for PySerializeWrap<&T> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::Error; + let dumped = self + .0 + .try_serialize_to_bytes() + .map_err(|e| S::Error::custom(e.to_string()))?; + + serializer.serialize_bytes( + flatten( + &[SERDE_MAGIC_BYTE_MARK, &*PYTHON3_VERSION, dumped.as_slice()], + None, + ) + .as_slice(), + ) + } + } + + impl<'a, T: TrySerializeToBytes> serde::Deserialize<'a> for PySerializeWrap { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'a>, + { + use serde::de::Error; + let bytes = Vec::::deserialize(deserializer)?; + + let Some((magic, rem)) = bytes.split_at_checked(SERDE_MAGIC_BYTE_MARK.len()) else { + return Err(D::Error::custom( + "unexpected EOF when reading serialized pyobject version", + )); + }; + + if magic != SERDE_MAGIC_BYTE_MARK { + return Err(D::Error::custom( + "serialized pyobject did not begin with magic byte mark", + )); + } + + let bytes = rem; + + let [a, b, rem @ ..] = bytes else { + return Err(D::Error::custom( + "unexpected EOF when reading serialized pyobject metadata", + )); + }; + + let py3_version = [*a, *b]; + + if py3_version != *PYTHON3_VERSION { + return Err(D::Error::custom(format!( + "python version that pyobject was serialized with {:?} \ + differs from system python version {:?}", + (3, py3_version[0], py3_version[1]), + (3, PYTHON3_VERSION[0], PYTHON3_VERSION[1]), + ))); + } + + let bytes = rem; + + T::try_deserialize_bytes(bytes) + .map(Self) + .map_err(|e| D::Error::custom(e.to_string())) + } } } -/// Get the minor Python version from the `sys` module. -fn get_python_minor_version() -> u8 { +/// Get the [minor, micro] Python3 version from the `sys` module. +fn get_python3_version() -> [u8; 2] { Python::with_gil(|py| { - PyModule::import_bound(py, "sys") + let version_info = PyModule::import_bound(py, "sys") .unwrap() .getattr("version_info") - .unwrap() - .getattr("minor") - .unwrap() - .extract() - .unwrap() + .unwrap(); + + [ + version_info.getattr("minor").unwrap().extract().unwrap(), + version_info.getattr("micro").unwrap().extract().unwrap(), + ] }) } + +fn from_pyerr(e: PyErr) -> PolarsError { + PolarsError::ComputeError(format!("error raised in python: {e}").into()) +} diff --git a/py-polars/tests/unit/io/cloud/test_cloud.py b/py-polars/tests/unit/io/cloud/test_cloud.py index 322d13c44a014..17854f2f66d22 100644 --- a/py-polars/tests/unit/io/cloud/test_cloud.py +++ b/py-polars/tests/unit/io/cloud/test_cloud.py @@ -1,3 +1,6 @@ +import io +import sys + import pytest import polars as pl @@ -52,3 +55,49 @@ def raises_2() -> pl.CredentialProviderFunctionReturn: # from Rust. with pytest.raises(ComputeError, match=err_magic): pl.scan_parquet("s3://bucket/path", credential_provider=raises_2).collect() + + +def test_scan_credential_provider_serialization( + monkeypatch: pytest.MonkeyPatch, +) -> None: + err_magic = "err_magic_3" + + class ErrCredentialProvider(pl.CredentialProvider): + def __call__(self) -> pl.CredentialProviderFunctionReturn: + raise AssertionError(err_magic) + + lf = pl.scan_parquet( + "s3://bucket/path", credential_provider=ErrCredentialProvider() + ) + + serialized = lf.serialize() + + lf = pl.LazyFrame.deserialize(io.BytesIO(serialized)) + + with pytest.raises(AssertionError, match=err_magic): + lf.collect() + + +def test_scan_credential_provider_serialization_pyversion() -> None: + lf = pl.scan_parquet( + "s3://bucket/path", credential_provider=pl.CredentialProviderAWS() + ) + + serialized = lf.serialize() + serialized = bytearray(serialized) + + # We can't monkeypatch sys.python_version so we just mutate the output + # instead. + + v = b"PLPYFN" + i = serialized.index(v) + len(v) + a, b, *_ = serialized[i:] + serialized_pyver = (a, b) + assert serialized_pyver == (sys.version_info.minor, sys.version_info.micro) + serialized[i] = 255 + serialized[i + 1] = 255 + + with pytest.raises( + ComputeError, match="python version .* (3, 255, 255) .* differs .*" + ): + lf = pl.LazyFrame.deserialize(io.BytesIO(serialized))