Skip to content

Commit

Permalink
alloc using PyMem functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ijl committed Nov 12, 2024
1 parent a53024a commit 797935b
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 77 deletions.
38 changes: 38 additions & 0 deletions src/alloc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use std::alloc::{GlobalAlloc, Layout};
use std::ffi::c_void;

/// Fieldless marker type on which the crate's `GlobalAlloc` implementation is defined.
struct PyMemAllocator {}

// Registers `PyMemAllocator` as the global allocator, so the `GlobalAlloc`
// methods implemented for it service the crate's heap allocations.
#[global_allocator]
static ALLOCATOR: PyMemAllocator = PyMemAllocator {};

// SAFETY: `PyMemAllocator` has no fields, so sharing a reference to it between
// threads exposes no state.
unsafe impl Sync for PyMemAllocator {}

/// Services Rust heap requests with CPython's `PyMem_Malloc`, `PyMem_Realloc`
/// and `PyMem_Free`.
///
/// Every method forwards only `layout.size()` (or `new_size`); a null pointer
/// coming back from the `PyMem_*` call is returned unchanged, which is how
/// `GlobalAlloc` signals allocation failure.
///
/// NOTE(review): `Layout::align()` is never consulted. This assumes the
/// alignment provided by `PyMem_Malloc` is sufficient for every layout the
/// crate requests — confirm that no over-aligned type is heap-allocated.
unsafe impl GlobalAlloc for PyMemAllocator {
    /// Allocates `layout.size()` bytes; the memory is not initialized.
    #[inline]
    unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
        unsafe { pyo3_ffi::PyMem_Malloc(layout.size()) as *mut u8 }
    }

    /// Releases a block previously returned by this allocator.
    #[inline]
    unsafe fn dealloc(&self, ptr: *mut u8, _layout: Layout) {
        unsafe { pyo3_ffi::PyMem_Free(ptr as *mut c_void) }
    }

    /// Allocates `layout.size()` bytes and fills them with zeros.
    ///
    /// Returns null, without touching memory, when the underlying allocation
    /// fails.
    #[inline]
    unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 {
        let len = layout.size();
        unsafe {
            let ptr = pyo3_ffi::PyMem_Malloc(len) as *mut u8;
            // A failed allocation yields a null pointer; zero-filling through
            // it would be undefined behaviour, so only initialize real memory.
            if !ptr.is_null() {
                core::ptr::write_bytes(ptr, 0, len);
            }
            ptr
        }
    }

    /// Resizes the block at `ptr` to `new_size` bytes.
    #[inline]
    unsafe fn realloc(&self, ptr: *mut u8, _layout: Layout, new_size: usize) -> *mut u8 {
        unsafe { pyo3_ffi::PyMem_Realloc(ptr as *mut c_void, new_size) as *mut u8 }
    }
}
67 changes: 39 additions & 28 deletions src/deserialize/backend/yyjson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::deserialize::pyobject::*;
use crate::deserialize::DeserializeError;
use crate::ffi::yyjson::*;
use crate::str::unicode_from_str;
use crate::typeref::{yyjson_init, YYJSON_ALLOC, YYJSON_BUFFER_SIZE};
use core::ffi::c_char;
use core::ptr::{null, null_mut, NonNull};
use std::borrow::Cow;
Expand Down Expand Up @@ -57,25 +56,52 @@ fn unsafe_yyjson_get_next_non_container(val: *mut yyjson_val) -> *mut yyjson_val
unsafe { ((val as *mut u8).add(YYJSON_VAL_SIZE)) as *mut yyjson_val }
}

/// Scratch storage whose reserved capacity `deserialize` hands to
/// `yyjson_alc_pool_init`; the bytes are never initialized on the Rust side.
type DeserializeBuffer = Vec<core::mem::MaybeUninit<u8>>;

/// Smallest capacity `deserialize` reserves for the pool: 4096 bytes less the
/// size of the `Vec` header itself.
/// NOTE(review): presumably sized so header plus buffer total one 4 KiB page — confirm.
const MINIMUM_BUFFER_CAPACITY: usize = 4096 - core::mem::size_of::<DeserializeBuffer>();

pub(crate) fn deserialize(
data: &'static str,
) -> Result<NonNull<pyo3_ffi::PyObject>, DeserializeError<'static>> {
let buffer_capacity = usize::max(
MINIMUM_BUFFER_CAPACITY,
yyjson_read_max_memory_usage(data.len()),
);
let mut buffer: DeserializeBuffer = Vec::with_capacity(buffer_capacity);
let mut alloc = crate::ffi::yyjson::yyjson_alc {
malloc: None,
realloc: None,
free: None,
ctx: null_mut(),
};
let alloc_ptr = core::ptr::addr_of_mut!(alloc);
unsafe {
crate::ffi::yyjson::yyjson_alc_pool_init(
alloc_ptr,
buffer.as_mut_ptr().cast::<core::ffi::c_void>(),
buffer_capacity,
);
}

let mut err = yyjson_read_err {
code: YYJSON_READ_SUCCESS,
msg: null(),
pos: 0,
};
let doc = if yyjson_read_max_memory_usage(data.len()) < YYJSON_BUFFER_SIZE {
read_doc_with_buffer(data, &mut err)
} else {
read_doc_default(data, &mut err)

let doc = unsafe {
yyjson_read_opts(
data.as_ptr() as *mut c_char,
data.len(),
alloc_ptr as *const crate::ffi::yyjson::yyjson_alc,
&mut err,
)
};
if unlikely!(doc.is_null()) {
let msg: Cow<str> = unsafe { core::ffi::CStr::from_ptr(err.msg).to_string_lossy() };
Err(DeserializeError::from_yyjson(msg, err.pos as i64, data))
} else {
let val = yyjson_doc_get_root(doc);

if unlikely!(!unsafe_yyjson_is_ctn(val)) {
let pyval = match ElementType::from_tag(val) {
ElementType::String => parse_yy_string(val),
Expand All @@ -85,8 +111,8 @@ pub(crate) fn deserialize(
ElementType::Null => parse_none(),
ElementType::True => parse_true(),
ElementType::False => parse_false(),
ElementType::Array => unreachable!(),
ElementType::Object => unreachable!(),
ElementType::Array => unreachable_unchecked!(),
ElementType::Object => unreachable_unchecked!(),
};
unsafe { yyjson_doc_free(doc) };
Ok(pyval)
Expand All @@ -110,21 +136,6 @@ pub(crate) fn deserialize(
}
}

/// Parses `data` with `yyjson_read_opts`, passing a null allocator pointer.
///
/// Parse diagnostics are written into `err`. The returned document pointer is
/// whatever `yyjson_read_opts` produced; the caller is expected to check it for
/// null.
/// NOTE(review): a null allocator presumably selects yyjson's built-in
/// allocator — confirm against the yyjson API.
fn read_doc_default(data: &'static str, err: &mut yyjson_read_err) -> *mut yyjson_doc {
    let input = data.as_ptr() as *mut c_char;
    let input_len = data.len();
    unsafe { yyjson_read_opts(input, input_len, null_mut(), err) }
}

/// Parses `data` with `yyjson_read_opts`, using the allocator stored in the
/// lazily initialized `YYJSON_ALLOC` global.
///
/// The global is created by `yyjson_init` on first use. Parse diagnostics are
/// written into `err`; the returned document pointer is whatever
/// `yyjson_read_opts` produced.
fn read_doc_with_buffer(data: &'static str, err: &mut yyjson_read_err) -> *mut yyjson_doc {
    let input = data.as_ptr() as *mut c_char;
    let input_len = data.len();
    unsafe {
        // Fetch (or build on first call) the shared allocator before parsing.
        let pool_alloc = &YYJSON_ALLOC.get_or_init(yyjson_init).alloc;
        yyjson_read_opts(input, input_len, pool_alloc, err)
    }
}

enum ElementType {
String,
Uint64,
Expand All @@ -149,7 +160,7 @@ impl ElementType {
TAG_FALSE => Self::False,
TAG_ARRAY => Self::Array,
TAG_OBJECT => Self::Object,
_ => unreachable!(),
_ => unreachable_unchecked!(),
}
}
}
Expand Down Expand Up @@ -221,8 +232,8 @@ fn populate_yy_array(list: *mut pyo3_ffi::PyObject, elem: *mut yyjson_val) {
ElementType::Null => parse_none(),
ElementType::True => parse_true(),
ElementType::False => parse_false(),
ElementType::Array => unreachable!(),
ElementType::Object => unreachable!(),
ElementType::Array => unreachable_unchecked!(),
ElementType::Object => unreachable_unchecked!(),
};
append_to_list!(dptr, pyval.as_ptr());
}
Expand Down Expand Up @@ -283,8 +294,8 @@ fn populate_yy_object(dict: *mut pyo3_ffi::PyObject, elem: *mut yyjson_val) {
ElementType::Null => parse_none(),
ElementType::True => parse_true(),
ElementType::False => parse_false(),
ElementType::Array => unreachable!(),
ElementType::Object => unreachable!(),
ElementType::Array => unreachable_unchecked!(),
ElementType::Object => unreachable_unchecked!(),
};
add_to_dict!(dict, pykey, pyval.as_ptr());
reverse_pydict_incref!(pykey);
Expand Down
2 changes: 1 addition & 1 deletion src/deserialize/pyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub fn get_unicode_key(key_str: &str) -> *mut pyo3_ffi::PyObject {
unsafe {
let entry = KEY_MAP
.get_mut()
.unwrap_or_else(|| unreachable!())
.unwrap_or_else(|| unreachable_unchecked!())
.entry(&hash)
.or_insert_with(
|| hash,
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ extern crate unwinding;
#[macro_use]
mod util;

mod alloc;
mod deserialize;
mod ffi;
mod opt;
Expand Down
48 changes: 0 additions & 48 deletions src/typeref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,9 @@

use crate::ffi::orjson_fragmenttype_new;
use core::ffi::c_char;
#[cfg(feature = "yyjson")]
use core::ffi::c_void;
#[cfg(feature = "yyjson")]
use core::mem::MaybeUninit;
use core::ptr::{null_mut, NonNull};
use once_cell::race::{OnceBool, OnceBox};
use pyo3_ffi::*;
#[cfg(feature = "yyjson")]
use std::cell::UnsafeCell;

pub struct NumpyTypes {
pub array: *mut PyTypeObject,
Expand Down Expand Up @@ -76,48 +70,6 @@ pub static mut DESCR_STR: *mut PyObject = null_mut();
pub static mut VALUE_STR: *mut PyObject = null_mut();
pub static mut INT_ATTR_STR: *mut PyObject = null_mut();

/// Size in bytes (8 MiB) of the buffer given to the yyjson pool allocator.
#[cfg(feature = "yyjson")]
pub const YYJSON_BUFFER_SIZE: usize = 1024 * 1024 * 8;

/// 64-byte-aligned, possibly uninitialized byte array of `YYJSON_BUFFER_SIZE`
/// bytes, wrapped in `UnsafeCell` so a mutable pointer to it can be obtained
/// from a shared reference.
#[cfg(feature = "yyjson")]
#[repr(align(64))]
struct YYJSONBuffer(UnsafeCell<MaybeUninit<[u8; YYJSON_BUFFER_SIZE]>>);

/// A yyjson allocator bundled with the heap buffer it was initialized over.
#[cfg(feature = "yyjson")]
pub struct YYJSONAlloc {
// Allocator descriptor that `yyjson_init` fills in via `yyjson_alc_pool_init`.
pub alloc: crate::ffi::yyjson::yyjson_alc,
// Owns the buffer whose address was passed to `yyjson_alc_pool_init`, keeping
// it alive for as long as this struct exists.
_buffer: Box<YYJSONBuffer>,
}

/// Process-wide allocator slot, created empty and filled on first use through
/// `get_or_init(yyjson_init)`.
#[cfg(feature = "yyjson")]
pub static mut YYJSON_ALLOC: OnceBox<YYJSONAlloc> = OnceBox::new();

/// Builds the yyjson pool allocator together with its backing buffer.
///
/// A `YYJSONBuffer` is allocated directly on the heap and its address and
/// `YYJSON_BUFFER_SIZE` are handed to `yyjson_alc_pool_init`. The returned
/// `YYJSONAlloc` owns both the initialized allocator descriptor and the
/// buffer.
///
/// If the heap allocation fails, control is passed to
/// `std::alloc::handle_alloc_error` instead of wrapping a null pointer in a
/// `Box`.
#[cfg(feature = "yyjson")]
pub fn yyjson_init() -> Box<YYJSONAlloc> {
    // Using unsafe to ensure allocation happens on the heap without going through the stack
    // so we don't stack overflow in debug mode. Once rust-lang/rust#63291 is stable (Box::new_uninit)
    // we can use that instead.
    let layout = std::alloc::Layout::new::<YYJSONBuffer>();
    let raw = unsafe { std::alloc::alloc(layout) };
    if raw.is_null() {
        // `Box::from_raw` on a null pointer is undefined behaviour; report the
        // failed allocation through the standard hook, which never returns.
        std::alloc::handle_alloc_error(layout);
    }
    // SAFETY: `raw` is non-null and was allocated with the layout of
    // `YYJSONBuffer`, whose contents are `MaybeUninit` and so need no
    // initialization.
    let buffer = unsafe { Box::from_raw(raw.cast::<YYJSONBuffer>()) };
    let mut alloc = crate::ffi::yyjson::yyjson_alc {
        malloc: None,
        realloc: None,
        free: None,
        ctx: null_mut(),
    };
    unsafe {
        crate::ffi::yyjson::yyjson_alc_pool_init(
            &mut alloc,
            buffer.0.get().cast::<c_void>(),
            YYJSON_BUFFER_SIZE,
        );
    }
    Box::new(YYJSONAlloc {
        alloc,
        _buffer: buffer,
    })
}

#[allow(non_upper_case_globals)]
pub static mut JsonEncodeError: *mut PyObject = null_mut();
#[allow(non_upper_case_globals)]
Expand Down
6 changes: 6 additions & 0 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,9 @@ macro_rules! popcnt {
core::mem::transmute::<u32, i32>($val).count_ones() as usize
};
}

/// Expands to `core::hint::unreachable_unchecked()` wrapped in its own `unsafe`
/// block, so call sites can use it as a plain expression.
///
/// The `unsafe` is hidden inside the expansion: every call site must be code
/// that can never execute, because reaching it is undefined behaviour, not
/// a panic.
macro_rules! unreachable_unchecked {
() => {
unsafe { core::hint::unreachable_unchecked() }
};
}

0 comments on commit 797935b

Please sign in to comment.