Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Oct 22, 2024
1 parent eb596c9 commit e8edca2
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 56 deletions.
33 changes: 32 additions & 1 deletion crates/polars-io/src/cloud/credential_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,14 @@ mod python_impl {
use super::IntoCredentialProvider;

#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PythonCredentialProvider(pub(super) Arc<PythonFunction>);

impl From<PythonFunction> for PythonCredentialProvider {
fn from(value: PythonFunction) -> Self {
Self(Arc::new(value))
}
}

impl IntoCredentialProvider for PythonCredentialProvider {
#[cfg(feature = "aws")]
fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
Expand Down Expand Up @@ -665,6 +670,32 @@ mod python_impl {
state.write_usize(Arc::as_ptr(&self.0) as *const () as usize)
}
}

#[cfg(feature = "serde")]
mod _serde_impl {
use polars_utils::python_function::PySerializeWrap;

use super::PythonCredentialProvider;

impl serde::Serialize for PythonCredentialProvider {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
PySerializeWrap(self.0.as_ref()).serialize(serializer)
}
}

impl<'a> serde::Deserialize<'a> for PythonCredentialProvider {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'a>,
{
PySerializeWrap::<super::PythonFunction>::deserialize(deserializer)
.map(|x| x.0.into())
}
}
}
}

#[cfg(test)]
Expand Down
21 changes: 11 additions & 10 deletions crates/polars-plan/src/dsl/python_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub static mut CALL_DF_UDF_PYTHON: Option<
> = None;

pub use polars_utils::python_function::{
PythonFunction, PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON_VERSION_MINOR,
PythonFunction, PYTHON3_VERSION, PYTHON_SERDE_MAGIC_BYTE_MARK,
};

pub struct PythonUdfExpression {
Expand Down Expand Up @@ -57,14 +57,14 @@ impl PythonUdfExpression {
// Handle pickle metadata
let use_cloudpickle = buf[0];
if use_cloudpickle != 0 {
let ser_py_version = buf[1];
let cur_py_version = *PYTHON_VERSION_MINOR;
let ser_py_version = &buf[..2];
let cur_py_version = *PYTHON3_VERSION;
polars_ensure!(
ser_py_version == cur_py_version,
InvalidOperation:
"current Python version (3.{}) does not match the Python version used to serialize the UDF (3.{})",
cur_py_version,
ser_py_version
"current Python version {:?} does not match the Python version used to serialize the UDF {:?}",
(3, cur_py_version[0], cur_py_version[1]),
(3, ser_py_version[0], ser_py_version[1] )
);
}
let buf = &buf[2..];
Expand Down Expand Up @@ -141,8 +141,8 @@ impl ColumnsUdf for PythonUdfExpression {
.getattr("dumps")
.unwrap();
let pickle_result = pickle.call1((self.python_function.clone_ref(py),));
let (dumped, use_cloudpickle, py_version) = match pickle_result {
Ok(dumped) => (dumped, false, 0),
let (dumped, use_cloudpickle) = match pickle_result {
Ok(dumped) => (dumped, false),
Err(_) => {
let cloudpickle = PyModule::import_bound(py, "cloudpickle")
.map_err(from_pyerr)?
Expand All @@ -151,12 +151,13 @@ impl ColumnsUdf for PythonUdfExpression {
let dumped = cloudpickle
.call1((self.python_function.clone_ref(py),))
.map_err(from_pyerr)?;
(dumped, true, *PYTHON_VERSION_MINOR)
(dumped, true)
},
};

// Write pickle metadata
buf.extend_from_slice(&[use_cloudpickle as u8, py_version]);
buf.push(use_cloudpickle as u8);
buf.extend_from_slice(&*PYTHON3_VERSION);

// Write UDF metadata
ciborium::ser::into_writer(
Expand Down
226 changes: 181 additions & 45 deletions crates/polars-utils/src/python_function.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use once_cell::sync::Lazy;
use polars_error::{polars_bail, PolarsError, PolarsResult};
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedBytes;
use pyo3::types::PyBytes;
#[cfg(feature = "serde")]
use serde::ser::Error;
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub use serde_wrap::{
PySerializeWrap, TrySerializeToBytes, PYTHON3_VERSION,
SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK,
};

#[cfg(feature = "serde")]
pub const PYTHON_SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes();
pub static PYTHON_VERSION_MINOR: Lazy<u8> = Lazy::new(get_python_minor_version);
use crate::flatten;

#[derive(Debug)]
pub struct PythonFunction(pub PyObject);
Expand Down Expand Up @@ -42,64 +41,201 @@ impl PartialEq for PythonFunction {
}

#[cfg(feature = "serde")]
impl Serialize for PythonFunction {
impl serde::Serialize for PythonFunction {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
S: serde::Serializer,
{
Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "cloudpickle")
.or_else(|_| PyModule::import_bound(py, "pickle"))
.expect("unable to import 'cloudpickle' or 'pickle'")
.getattr("dumps")
.unwrap();

let python_function = self.0.clone_ref(py);

let dumped = pickle
.call1((python_function,))
.map_err(|s| S::Error::custom(format!("cannot pickle {s}")))?;
let dumped = dumped.extract::<PyBackedBytes>().unwrap();

serializer.serialize_bytes(&dumped)
})
use serde::ser::Error;
serializer.serialize_bytes(
self.try_serialize_to_bytes()
.map_err(|e| S::Error::custom(e.to_string()))?
.as_slice(),
)
}
}

#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for PythonFunction {
impl<'a> serde::Deserialize<'a> for PythonFunction {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'a>,
D: serde::Deserializer<'a>,
{
use serde::de::Error;
let bytes = Vec::<u8>::deserialize(deserializer)?;
Self::try_deserialize_bytes(bytes.as_slice()).map_err(|e| D::Error::custom(e.to_string()))
}
}

Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "pickle")
.expect("unable to import 'pickle'")
.getattr("loads")
#[cfg(feature = "serde")]
impl TrySerializeToBytes for PythonFunction {
fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult<Vec<u8>> {
serialize_pyobject_with_cloudpickle_fallback(&self.0)
}

fn try_deserialize_bytes(bytes: &[u8]) -> polars_error::PolarsResult<Self> {
deserialize_pyobject_bytes_maybe_cloudpickle(bytes)
}
}

pub fn serialize_pyobject_with_cloudpickle_fallback(py_object: &PyObject) -> PolarsResult<Vec<u8>> {
Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "pickle")
.expect("unable to import 'pickle'")
.getattr("dumps")
.unwrap();

let dumped = pickle.call1((py_object.clone_ref(py),));

let (dumped, used_cloudpickle) = if let Ok(v) = dumped {
(v, false)
} else {
let cloudpickle = PyModule::import_bound(py, "cloudpickle")
.map_err(from_pyerr)?
.getattr("dumps")
.unwrap();
let arg = (PyBytes::new_bound(py, &bytes),);
let python_function = pickle
.call1(arg)
.map_err(|s| D::Error::custom(format!("cannot pickle {s}")))?;
let dumped = cloudpickle
.call1((py_object.clone_ref(py),))
.map_err(from_pyerr)?;
(dumped, true)
};

Ok(Self(python_function.into()))
})
let py_bytes = dumped.extract::<PyBackedBytes>().map_err(from_pyerr)?;

Ok(flatten(
&[&[used_cloudpickle as u8, b'C'][..], py_bytes.as_ref()],
None,
))
})
}

pub fn deserialize_pyobject_bytes_maybe_cloudpickle<T: for<'a> From<PyObject>>(
bytes: &[u8],
) -> PolarsResult<T> {
// TODO: Actually deserialize with cloudpickle if it's set.
let [_used_cloudpickle @ 0 | _used_cloudpickle @ 1, b'C', rem @ ..] = bytes else {
polars_bail!(ComputeError: "deserialize_pyobject_bytes_maybe_cloudpickle: invalid start bytes")
};

let bytes = rem;

Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "pickle")
.expect("unable to import 'pickle'")
.getattr("loads")
.unwrap();
let arg = (PyBytes::new_bound(py, bytes),);
let pyany_bound = pickle.call1(arg).map_err(from_pyerr)?;
Ok(PyObject::from(pyany_bound).into())
})
}

#[cfg(feature = "serde")]
mod serde_wrap {
use once_cell::sync::Lazy;
use polars_error::PolarsResult;

use crate::flatten;

pub const SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYFN".as_bytes();
/// [minor, micro]
pub static PYTHON3_VERSION: Lazy<[u8; 2]> = Lazy::new(super::get_python3_version);

/// Serializes a Python object without additional system metadata. This is intended to be used
/// together with `PySerializeWrap`, which attaches e.g. Python version metadata.
pub trait TrySerializeToBytes: Sized {
fn try_serialize_to_bytes(&self) -> PolarsResult<Vec<u8>>;
fn try_deserialize_bytes(bytes: &[u8]) -> PolarsResult<Self>;
}

/// Serialization wrapper for T: TrySerializeToBytes that attaches Python
/// version metadata.
pub struct PySerializeWrap<T>(pub T);

impl<T: TrySerializeToBytes> serde::Serialize for PySerializeWrap<&T> {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::Error;
let dumped = self
.0
.try_serialize_to_bytes()
.map_err(|e| S::Error::custom(e.to_string()))?;

serializer.serialize_bytes(
flatten(
&[SERDE_MAGIC_BYTE_MARK, &*PYTHON3_VERSION, dumped.as_slice()],
None,
)
.as_slice(),
)
}
}

impl<'a, T: TrySerializeToBytes> serde::Deserialize<'a> for PySerializeWrap<T> {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'a>,
{
use serde::de::Error;
let bytes = Vec::<u8>::deserialize(deserializer)?;

let Some((magic, rem)) = bytes.split_at_checked(SERDE_MAGIC_BYTE_MARK.len()) else {
return Err(D::Error::custom(
"unexpected EOF when reading serialized pyobject version",
));
};

if magic != SERDE_MAGIC_BYTE_MARK {
return Err(D::Error::custom(
"serialized pyobject did not begin with magic byte mark",
));
}

let bytes = rem;

let [a, b, rem @ ..] = bytes else {
return Err(D::Error::custom(
"unexpected EOF when reading serialized pyobject metadata",
));
};

let py3_version = [*a, *b];

if py3_version != *PYTHON3_VERSION {
return Err(D::Error::custom(format!(
"python version that pyobject was serialized with {:?} \
differs from system python version {:?}",
(3, py3_version[0], py3_version[1]),
(3, PYTHON3_VERSION[0], PYTHON3_VERSION[1]),
)));
}

let bytes = rem;

T::try_deserialize_bytes(bytes)
.map(Self)
.map_err(|e| D::Error::custom(e.to_string()))
}
}
}

/// Get the minor Python version from the `sys` module.
fn get_python_minor_version() -> u8 {
/// Get the [minor, micro] Python3 version from the `sys` module.
fn get_python3_version() -> [u8; 2] {
Python::with_gil(|py| {
PyModule::import_bound(py, "sys")
let version_info = PyModule::import_bound(py, "sys")
.unwrap()
.getattr("version_info")
.unwrap()
.getattr("minor")
.unwrap()
.extract()
.unwrap()
.unwrap();

[
version_info.getattr("minor").unwrap().extract().unwrap(),
version_info.getattr("micro").unwrap().extract().unwrap(),
]
})
}

fn from_pyerr(e: PyErr) -> PolarsError {
PolarsError::ComputeError(format!("error raised in python: {e}").into())
}
Loading

0 comments on commit e8edca2

Please sign in to comment.