From fb3757a28350ce2cde29bfe1a419edf2a4d3d380 Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Thu, 19 Mar 2026 18:03:02 +0800 Subject: [PATCH 01/17] feat: introduce stack-allocated `PyBuffer` --- src/buffer.rs | 448 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 448 insertions(+) diff --git a/src/buffer.rs b/src/buffer.rs index c9f9e70d913..fe04759e658 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -784,6 +784,279 @@ impl_element!(isize, SignedInteger); impl_element!(f32, Float); impl_element!(f64, Float); +/// Stack-allocated typed buffer view. Not constructible directly — use +/// [`PyBufferView::with()`] or [`PyBufferView::with_flags()`]. +/// +/// This is a lightweight alternative to [`PyBuffer`] that avoids heap allocation +/// by placing the `Py_buffer` on the stack. The scoped closure API ensures the +/// buffer cannot be moved +#[repr(transparent)] +pub struct PyBufferView(PyUntypedBufferView, PhantomData<[T]>); + +/// Unlike [`PyUntypedBuffer`] which always requests `PyBUF_FULL_RO`, this type allows +/// arbitrary flags, so flag-dependent accessors like [`format()`](Self::format), +/// [`shape()`](Self::shape), and [`strides()`](Self::strides) return `Option`. +pub struct PyUntypedBufferView { + raw: ffi::Py_buffer, + flags: c_int, +} + +impl PyUntypedBufferView { + /// Acquire a buffer view on the stack, pass it to `f`, then release the buffer. + /// + /// The `flags` parameter controls which buffer fields are requested from the exporter. + /// Use constants like [`ffi::PyBUF_SIMPLE`], [`ffi::PyBUF_FULL_RO`], etc. + pub fn with( + obj: &Bound<'_, PyAny>, + flags: c_int, + f: impl FnOnce(&PyUntypedBufferView) -> R, + ) -> PyResult { + let mut view = mem::MaybeUninit::::uninit(); + let view_ptr = view.as_mut_ptr(); + + unsafe { + ptr::addr_of_mut!((*view_ptr).flags).write(flags); + } + + err::error_on_minusone(obj.py(), unsafe { + ffi::PyObject_GetBuffer(obj.as_ptr(), ptr::addr_of_mut!((*view_ptr).raw), flags) + })?; + + // TODO: needs a cleanup strategy — MaybeUninit never drops its contents, so PyBuffer_Release is not currently called. + Ok(f(unsafe { view.assume_init_ref() })) + } + + /// Gets the pointer to the start of the buffer memory. + #[inline] + pub fn buf_ptr(&self) -> *mut c_void { + self.raw.buf + } + + /// Gets whether the underlying buffer is read-only. + #[inline] + pub fn readonly(&self) -> bool { + self.raw.readonly != 0 + } + + /// Gets the size of a single element, in bytes. + #[inline] + pub fn item_size(&self) -> usize { + self.raw.itemsize as usize + } + + /// Gets the total number of items. + #[inline] + pub fn item_count(&self) -> usize { + (self.raw.len as usize) / (self.raw.itemsize as usize) + } + + /// `item_size() * item_count()`. + /// For contiguous arrays, this is the length of the underlying memory block. + #[inline] + pub fn len_bytes(&self) -> usize { + self.raw.len as usize + } + + /// Gets the number of dimensions. + /// + /// May be 0 to indicate a single scalar value. + #[inline] + pub fn dimensions(&self) -> usize { + self.raw.ndim as usize + } + + /// A string in struct module style syntax describing the contents of a single item. + /// + /// Returns `None` if `PyBUF_FORMAT` was not included in the flags. + #[inline] + pub fn format(&self) -> Option<&CStr> { + if self.flags & ffi::PyBUF_FORMAT != ffi::PyBUF_FORMAT { + return None; + } + if self.raw.format.is_null() { + Some(ffi::c_str!("B")) + } else { + Some(unsafe { CStr::from_ptr(self.raw.format) }) + } + } + + /// Returns the shape array. `shape[i]` is the length of dimension `i`. + /// + /// Returns `None` if `PyBUF_ND` was not included in the flags. + #[inline] + pub fn shape(&self) -> Option<&[usize]> { + if self.flags & ffi::PyBUF_ND != ffi::PyBUF_ND || self.raw.shape.is_null() { + return None; + } + + Some(unsafe { slice::from_raw_parts(self.raw.shape.cast(), self.raw.ndim as usize) }) + } + + /// Returns the strides array. + /// + /// Returns `None` if `PyBUF_STRIDES` was not included in the flags. + #[inline] + pub fn strides(&self) -> Option<&[isize]> { + if self.flags & ffi::PyBUF_STRIDES != ffi::PyBUF_STRIDES || self.raw.strides.is_null() { + return None; + } + + Some(unsafe { slice::from_raw_parts(self.raw.strides, self.raw.ndim as usize) }) + } + + /// Returns the suboffsets array. + /// + /// May return `None` even with `PyBUF_INDIRECT` if the exporter sets `suboffsets` to NULL. + #[inline] + pub fn suboffsets(&self) -> Option<&[isize]> { + if self.raw.suboffsets.is_null() { + None + } else { + Some(unsafe { slice::from_raw_parts(self.raw.suboffsets, self.raw.ndim as usize) }) + } + } + + /// Gets whether the buffer is contiguous in C-style order. + #[inline] + pub fn is_c_contiguous(&self) -> bool { + unsafe { ffi::PyBuffer_IsContiguous(&self.raw, b'C' as std::ffi::c_char) != 0 } + } + + /// Gets whether the buffer is contiguous in Fortran-style order. + #[inline] + pub fn is_fortran_contiguous(&self) -> bool { + unsafe { ffi::PyBuffer_IsContiguous(&self.raw, b'F' as std::ffi::c_char) != 0 } + } + + /// Attempt to interpret this untyped view as containing elements of type `T`. + /// + /// Requires that `PyBUF_FORMAT` was included in the flags. + pub fn as_typed(&self) -> PyResult<&PyBufferView> { + self.ensure_compatible_with::()?; + // SAFETY: PyBufferView is repr(transparent) around PyUntypedBufferView + Ok(unsafe { NonNull::from(self).cast::>().as_ref() }) + } + + fn ensure_compatible_with(&self) -> PyResult<()> { + let format = self.format().ok_or_else(|| { + PyBufferError::new_err( + "buffer format not available (PyBUF_FORMAT flag was not requested)", + ) + })?; + + if mem::size_of::() != self.item_size() || !T::is_compatible_format(format) { + Err(PyBufferError::new_err(format!( + "buffer contents are not compatible with {}", + std::any::type_name::() + ))) + } else if self.raw.buf.align_offset(mem::align_of::()) != 0 { + Err(PyBufferError::new_err(format!( + "buffer contents are insufficiently aligned for {}", + std::any::type_name::() + ))) + } else { + Ok(()) + } + } +} + +impl Debug for PyUntypedBufferView { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + debug_buffer_view("PyUntypedBufferView", self, f) + } +} + +impl PyBufferView { + /// Acquire a typed buffer view on the stack with `PyBUF_FULL_RO` flags, + /// validating that the buffer format is compatible with `T`. + pub fn with(obj: &Bound<'_, PyAny>, f: impl FnOnce(&PyBufferView) -> R) -> PyResult { + Self::with_flags(obj, ffi::PyBUF_FULL_RO, f) + } + + /// Acquire a typed buffer view on the stack with user-specified flags. + /// + /// The flags must include `PyBUF_FORMAT` for type validation to succeed. + pub fn with_flags( + obj: &Bound<'_, PyAny>, + flags: c_int, + f: impl FnOnce(&PyBufferView) -> R, + ) -> PyResult { + PyUntypedBufferView::with(obj, flags, |view| view.as_typed::().map(f))? + } + + /// Gets the buffer memory as a slice. + /// + /// Returns `None` if the buffer is not C-contiguous. + /// + /// The returned slice uses type [`ReadOnlyCell`] because it's theoretically possible + /// for any call into the Python runtime to modify the values in the slice. + pub fn as_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell]> { + if self.is_c_contiguous() { + unsafe { + Some(slice::from_raw_parts( + self.0.raw.buf.cast(), + self.item_count(), + )) + } + } else { + None + } + } + + /// Gets the buffer memory as a mutable slice. + /// + /// Returns `None` if the buffer is read-only or not C-contiguous. + /// + /// The returned slice uses type [`Cell`](cell::Cell) because it's theoretically possible + /// for any call into the Python runtime to modify the values in the slice. + pub fn as_mut_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell]> { + if !self.readonly() && self.is_c_contiguous() { + unsafe { + Some(slice::from_raw_parts( + self.0.raw.buf.cast(), + self.item_count(), + )) + } + } else { + None + } + } +} + +impl std::ops::Deref for PyBufferView { + type Target = PyUntypedBufferView; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Debug for PyBufferView { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + debug_buffer_view("PyBufferView", &self.0, f) + } +} + +fn debug_buffer_view( + name: &str, + view: &PyUntypedBufferView, + f: &mut std::fmt::Formatter<'_>, +) -> std::fmt::Result { + f.debug_struct(name) + .field("buf", &view.raw.buf) + .field("obj", &view.raw.obj) + .field("len", &view.raw.len) + .field("itemsize", &view.raw.itemsize) + .field("readonly", &view.raw.readonly) + .field("ndim", &view.raw.ndim) + .field("format", &view.format()) + .field("shape", &view.shape()) + .field("strides", &view.strides()) + .field("suboffsets", &view.suboffsets()) + .field("internal", &view.raw.internal) + .finish() +} + #[cfg(test)] mod tests { use super::*; @@ -1078,4 +1351,179 @@ mod tests { }); }); } + + #[test] + fn test_untyped_buffer_view_simple() { + Python::attach(|py| { + let bytes = PyBytes::new(py, b"abcde"); + PyUntypedBufferView::with(&bytes, ffi::PyBUF_SIMPLE, |view| { + assert!(!view.buf_ptr().is_null()); + assert_eq!(view.len_bytes(), 5); + assert_eq!(view.item_size(), 1); + assert_eq!(view.item_count(), 5); + assert!(view.readonly()); + assert_eq!(view.dimensions(), 1); + // PyBUF_SIMPLE doesn't include FORMAT, ND, or STRIDES + assert!(view.format().is_none()); + assert!(view.shape().is_none()); + assert!(view.strides().is_none()); + assert!(view.suboffsets().is_none()); + }) + .unwrap(); + }); + } + + #[test] + fn test_untyped_buffer_view_full() { + Python::attach(|py| { + let bytes = PyBytes::new(py, b"abcde"); + PyUntypedBufferView::with(&bytes, ffi::PyBUF_FULL_RO, |view| { + assert!(!view.buf_ptr().is_null()); + assert_eq!(view.len_bytes(), 5); + assert_eq!(view.item_size(), 1); + assert_eq!(view.item_count(), 5); + assert!(view.readonly()); + assert_eq!(view.dimensions(), 1); + assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); + assert_eq!(view.shape().unwrap(), [5]); + assert_eq!(view.strides().unwrap(), [1]); + assert!(view.suboffsets().is_none()); + assert!(view.is_c_contiguous()); + assert!(view.is_fortran_contiguous()); + }) + .unwrap(); + }); + } + + #[test] + fn test_typed_buffer_view() { + Python::attach(|py| { + let bytes = PyBytes::new(py, b"abcde"); + PyBufferView::::with(&bytes, |view| { + assert_eq!(view.dimensions(), 1); + assert_eq!(view.item_count(), 5); + assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); + assert_eq!(view.shape().unwrap(), [5]); + + let slice = view.as_slice(py).unwrap(); + assert_eq!(slice.len(), 5); + assert_eq!(slice[0].get(), b'a'); + assert_eq!(slice[4].get(), b'e'); + + // bytes are read-only + assert!(view.as_mut_slice(py).is_none()); + }) + .unwrap(); + }); + } + + #[test] + fn test_buffer_view_array() { + Python::attach(|py| { + let array = py + .import("array") + .unwrap() + .call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None) + .unwrap(); + PyBufferView::::with(&array, |view| { + assert_eq!(view.dimensions(), 1); + assert_eq!(view.item_count(), 4); + assert_eq!(view.format().unwrap().to_str().unwrap(), "f"); + assert_eq!(view.shape().unwrap(), [4]); + + let slice = view.as_slice(py).unwrap(); + assert_eq!(slice.len(), 4); + assert_eq!(slice[0].get(), 1.0); + assert_eq!(slice[3].get(), 2.5); + + // array.array is writable + let mut_slice = view.as_mut_slice(py).unwrap(); + assert_eq!(mut_slice[0].get(), 1.0); + mut_slice[3].set(2.75); + assert_eq!(slice[3].get(), 2.75); + }) + .unwrap(); + }); + } + + #[test] + fn test_buffer_view_option_accessors() { + Python::attach(|py| { + let bytes = PyBytes::new(py, b"abcde"); + + PyUntypedBufferView::with(&bytes, ffi::PyBUF_SIMPLE, |view| { + assert!(view.format().is_none()); + assert!(view.shape().is_none()); + assert!(view.strides().is_none()); + }) + .unwrap(); + + PyUntypedBufferView::with(&bytes, ffi::PyBUF_ND, |view| { + assert!(view.format().is_none()); + assert_eq!(view.shape().unwrap(), [5]); + assert!(view.strides().is_none()); + }) + .unwrap(); + + PyUntypedBufferView::with(&bytes, ffi::PyBUF_STRIDES, |view| { + assert!(view.format().is_none()); + assert_eq!(view.shape().unwrap(), [5]); + assert_eq!(view.strides().unwrap(), [1]); + }) + .unwrap(); + + PyUntypedBufferView::with(&bytes, ffi::PyBUF_FORMAT, |view| { + assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); + assert!(view.shape().is_none()); + assert!(view.strides().is_none()); + }) + .unwrap(); + }); + } + + #[test] + fn test_buffer_view_error() { + Python::attach(|py| { + let list = crate::types::PyList::empty(py); + let result = PyUntypedBufferView::with(&list, ffi::PyBUF_SIMPLE, |_view| {}); + assert!(result.is_err()); + }); + } + + #[test] + fn test_buffer_view_debug() { + Python::attach(|py| { + let bytes = PyBytes::new(py, b"abcde"); + + PyUntypedBufferView::with(&bytes, ffi::PyBUF_FULL_RO, |view| { + let expected = format!( + concat!( + "PyUntypedBufferView {{ buf: {:?}, obj: {:?}, ", + "len: 5, itemsize: 1, readonly: 1, ", + "ndim: 1, format: Some(\"B\"), shape: Some([5]), ", + "strides: Some([1]), suboffsets: None, internal: {:?} }}", + ), + view.raw.buf, view.raw.obj, view.raw.internal, + ); + let debug_repr = format!("{:?}", view); + assert_eq!(debug_repr, expected); + }) + .unwrap(); + + PyBufferView::::with(&bytes, |view| { + let expected = format!( + concat!( + "PyBufferView {{ buf: {:?}, obj: {:?}, ", + "len: 5, itemsize: 1, readonly: 1, ", + "ndim: 1, format: Some(\"B\"), shape: Some([5]), ", + "strides: Some([1]), suboffsets: None, internal: {:?} }}", + ), + view.0.raw.buf, view.0.raw.obj, view.0.raw.internal, + ); + let debug_repr = format!("{:?}", view); + assert_eq!(debug_repr, expected); + }) + .unwrap(); + }); + } } From a09d64e402d85a0ae12f1b4a9d24f63076f7de35 Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Thu, 19 Mar 2026 18:24:47 +0800 Subject: [PATCH 02/17] docs: update CHANGELOG --- guide/src/conversions/tables.md | 2 +- newsfragments/5894.added.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 newsfragments/5894.added.md diff --git a/guide/src/conversions/tables.md b/guide/src/conversions/tables.md index d62ae5ad00f..c3ec7277e48 100644 --- a/guide/src/conversions/tables.md +++ b/guide/src/conversions/tables.md @@ -29,7 +29,7 @@ The table below contains the Python type and the corresponding function argument | `slice` | - | `PySlice` | | `type` | - | `PyType` | | `module` | - | `PyModule` | -| `collections.abc.Buffer` | - | `PyBuffer` | +| `collections.abc.Buffer` | - | `PyBuffer`, `PyBufferView` | | `datetime.datetime` | `SystemTime`, `chrono::DateTime`[^7], `chrono::NaiveDateTime`[^7] | `PyDateTime` | | `datetime.date` | `chrono::NaiveDate`[^7] | `PyDate` | | `datetime.time` | `chrono::NaiveTime`[^7] | `PyTime` | diff --git a/newsfragments/5894.added.md b/newsfragments/5894.added.md new file mode 100644 index 00000000000..85c1695022f --- /dev/null +++ b/newsfragments/5894.added.md @@ -0,0 +1 @@ +Add `PyBufferView` and `PyUntypedBufferView`, stack-allocated alternatives to `PyBuffer` and `PyUntypedBuffer` with a scoped closure API that avoids heap allocation. From 2849d45d4fc4bc8536915e8aa67afc2d3fc72f2a Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Sat, 21 Mar 2026 00:43:55 +0800 Subject: [PATCH 03/17] refactor: apply suggestions --- src/buffer.rs | 144 +++++++++++++++++++++++++++----------------------- 1 file changed, 77 insertions(+), 67 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index fe04759e658..7ad84191a1c 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -797,8 +797,7 @@ pub struct PyBufferView(PyUntypedBufferView, PhantomData<[T]>); /// arbitrary flags, so flag-dependent accessors like [`format()`](Self::format), /// [`shape()`](Self::shape), and [`strides()`](Self::strides) return `Option`. pub struct PyUntypedBufferView { - raw: ffi::Py_buffer, - flags: c_int, + raw: mem::MaybeUninit, } impl PyUntypedBufferView { @@ -811,50 +810,60 @@ impl PyUntypedBufferView { flags: c_int, f: impl FnOnce(&PyUntypedBufferView) -> R, ) -> PyResult { - let mut view = mem::MaybeUninit::::uninit(); - let view_ptr = view.as_mut_ptr(); - - unsafe { - ptr::addr_of_mut!((*view_ptr).flags).write(flags); - } + let mut raw = mem::MaybeUninit::::uninit(); err::error_on_minusone(obj.py(), unsafe { - ffi::PyObject_GetBuffer(obj.as_ptr(), ptr::addr_of_mut!((*view_ptr).raw), flags) + ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags) })?; - // TODO: needs a cleanup strategy — MaybeUninit never drops its contents, so PyBuffer_Release is not currently called. - Ok(f(unsafe { view.assume_init_ref() })) + // Construct view only after successful GetBuffer, so Drop always + // runs on an initialized Py_buffer. + let mut view = PyUntypedBufferView { raw }; + + // When shape is NULL the consumer must assume itemsize == 1. + if view.raw().shape.is_null() { + unsafe { &mut *view.raw.as_mut_ptr() }.itemsize = 1; + } + + Ok(f(&view)) + } + + #[inline] + fn raw(&self) -> &ffi::Py_buffer { + // SAFETY: PyUntypedBufferView is only constructed after a successful + // PyObject_GetBuffer call, so raw is always initialized. + unsafe { self.raw.assume_init_ref() } } /// Gets the pointer to the start of the buffer memory. #[inline] pub fn buf_ptr(&self) -> *mut c_void { - self.raw.buf + self.raw().buf } /// Gets whether the underlying buffer is read-only. #[inline] pub fn readonly(&self) -> bool { - self.raw.readonly != 0 + self.raw().readonly != 0 } /// Gets the size of a single element, in bytes. #[inline] pub fn item_size(&self) -> usize { - self.raw.itemsize as usize + self.raw().itemsize as usize } /// Gets the total number of items. #[inline] pub fn item_count(&self) -> usize { - (self.raw.len as usize) / (self.raw.itemsize as usize) + (self.raw().len as usize) / (self.raw().itemsize as usize) } /// `item_size() * item_count()`. /// For contiguous arrays, this is the length of the underlying memory block. #[inline] pub fn len_bytes(&self) -> usize { - self.raw.len as usize + self.raw().len as usize } /// Gets the number of dimensions. @@ -862,46 +871,42 @@ impl PyUntypedBufferView { /// May be 0 to indicate a single scalar value. #[inline] pub fn dimensions(&self) -> usize { - self.raw.ndim as usize + self.raw().ndim as usize } - /// A string in struct module style syntax describing the contents of a single item. - /// - /// Returns `None` if `PyBUF_FORMAT` was not included in the flags. + /// A [struct module style](https://docs.python.org/3/c-api/buffer.html#c.Py_buffer.format) + /// string describing the contents of a single item. Defaults to `"B"` if NULL. #[inline] - pub fn format(&self) -> Option<&CStr> { - if self.flags & ffi::PyBUF_FORMAT != ffi::PyBUF_FORMAT { - return None; - } - if self.raw.format.is_null() { - Some(ffi::c_str!("B")) + pub fn format(&self) -> &CStr { + if self.raw().format.is_null() { + ffi::c_str!("B") } else { - Some(unsafe { CStr::from_ptr(self.raw.format) }) + unsafe { CStr::from_ptr(self.raw().format) } } } /// Returns the shape array. `shape[i]` is the length of dimension `i`. /// - /// Returns `None` if `PyBUF_ND` was not included in the flags. + /// Returns `None` if the exporter set `shape` to NULL (e.g. `PyBUF_SIMPLE` was requested). #[inline] pub fn shape(&self) -> Option<&[usize]> { - if self.flags & ffi::PyBUF_ND != ffi::PyBUF_ND || self.raw.shape.is_null() { + if self.raw().shape.is_null() { return None; } - Some(unsafe { slice::from_raw_parts(self.raw.shape.cast(), self.raw.ndim as usize) }) + Some(unsafe { slice::from_raw_parts(self.raw().shape.cast(), self.raw().ndim as usize) }) } /// Returns the strides array. /// - /// Returns `None` if `PyBUF_STRIDES` was not included in the flags. + /// Returns `None` if the exporter set `strides` to NULL (e.g. `PyBUF_SIMPLE` was requested). #[inline] pub fn strides(&self) -> Option<&[isize]> { - if self.flags & ffi::PyBUF_STRIDES != ffi::PyBUF_STRIDES || self.raw.strides.is_null() { + if self.raw().strides.is_null() { return None; } - Some(unsafe { slice::from_raw_parts(self.raw.strides, self.raw.ndim as usize) }) + Some(unsafe { slice::from_raw_parts(self.raw().strides, self.raw().ndim as usize) }) } /// Returns the suboffsets array. @@ -909,23 +914,23 @@ impl PyUntypedBufferView { /// May return `None` even with `PyBUF_INDIRECT` if the exporter sets `suboffsets` to NULL. #[inline] pub fn suboffsets(&self) -> Option<&[isize]> { - if self.raw.suboffsets.is_null() { + if self.raw().suboffsets.is_null() { None } else { - Some(unsafe { slice::from_raw_parts(self.raw.suboffsets, self.raw.ndim as usize) }) + Some(unsafe { slice::from_raw_parts(self.raw().suboffsets, self.raw().ndim as usize) }) } } /// Gets whether the buffer is contiguous in C-style order. #[inline] pub fn is_c_contiguous(&self) -> bool { - unsafe { ffi::PyBuffer_IsContiguous(&self.raw, b'C' as std::ffi::c_char) != 0 } + unsafe { ffi::PyBuffer_IsContiguous(self.raw(), b'C' as std::ffi::c_char) != 0 } } /// Gets whether the buffer is contiguous in Fortran-style order. #[inline] pub fn is_fortran_contiguous(&self) -> bool { - unsafe { ffi::PyBuffer_IsContiguous(&self.raw, b'F' as std::ffi::c_char) != 0 } + unsafe { ffi::PyBuffer_IsContiguous(self.raw(), b'F' as std::ffi::c_char) != 0 } } /// Attempt to interpret this untyped view as containing elements of type `T`. @@ -938,18 +943,12 @@ impl PyUntypedBufferView { } fn ensure_compatible_with(&self) -> PyResult<()> { - let format = self.format().ok_or_else(|| { - PyBufferError::new_err( - "buffer format not available (PyBUF_FORMAT flag was not requested)", - ) - })?; - - if mem::size_of::() != self.item_size() || !T::is_compatible_format(format) { + if mem::size_of::() != self.item_size() || !T::is_compatible_format(self.format()) { Err(PyBufferError::new_err(format!( "buffer contents are not compatible with {}", std::any::type_name::() ))) - } else if self.raw.buf.align_offset(mem::align_of::()) != 0 { + } else if self.raw().buf.align_offset(mem::align_of::()) != 0 { Err(PyBufferError::new_err(format!( "buffer contents are insufficiently aligned for {}", std::any::type_name::() @@ -960,6 +959,12 @@ impl PyUntypedBufferView { } } +impl Drop for PyUntypedBufferView { + fn drop(&mut self) { + unsafe { ffi::PyBuffer_Release(self.raw.as_mut_ptr()) } + } +} + impl Debug for PyUntypedBufferView { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { debug_buffer_view("PyUntypedBufferView", self, f) @@ -994,7 +999,7 @@ impl PyBufferView { if self.is_c_contiguous() { unsafe { Some(slice::from_raw_parts( - self.0.raw.buf.cast(), + self.0.raw().buf.cast(), self.item_count(), )) } @@ -1013,7 +1018,7 @@ impl PyBufferView { if !self.readonly() && self.is_c_contiguous() { unsafe { Some(slice::from_raw_parts( - self.0.raw.buf.cast(), + self.0.raw().buf.cast(), self.item_count(), )) } @@ -1042,18 +1047,19 @@ fn debug_buffer_view( view: &PyUntypedBufferView, f: &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result { + let raw = view.raw(); f.debug_struct(name) - .field("buf", &view.raw.buf) - .field("obj", &view.raw.obj) - .field("len", &view.raw.len) - .field("itemsize", &view.raw.itemsize) - .field("readonly", &view.raw.readonly) - .field("ndim", &view.raw.ndim) + .field("buf", &raw.buf) + .field("obj", &raw.obj) + .field("len", &raw.len) + .field("itemsize", &raw.itemsize) + .field("readonly", &raw.readonly) + .field("ndim", &raw.ndim) .field("format", &view.format()) .field("shape", &view.shape()) .field("strides", &view.strides()) .field("suboffsets", &view.suboffsets()) - .field("internal", &view.raw.internal) + .field("internal", &raw.internal) .finish() } @@ -1363,8 +1369,8 @@ mod tests { assert_eq!(view.item_count(), 5); assert!(view.readonly()); assert_eq!(view.dimensions(), 1); - // PyBUF_SIMPLE doesn't include FORMAT, ND, or STRIDES - assert!(view.format().is_none()); + // PyBUF_SIMPLE doesn't include ND or STRIDES; format defaults to "B" + assert_eq!(view.format().to_str().unwrap(), "B"); assert!(view.shape().is_none()); assert!(view.strides().is_none()); assert!(view.suboffsets().is_none()); @@ -1384,7 +1390,7 @@ mod tests { assert_eq!(view.item_count(), 5); assert!(view.readonly()); assert_eq!(view.dimensions(), 1); - assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); + assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape().unwrap(), [5]); assert_eq!(view.strides().unwrap(), [1]); assert!(view.suboffsets().is_none()); @@ -1402,7 +1408,7 @@ mod tests { PyBufferView::::with(&bytes, |view| { assert_eq!(view.dimensions(), 1); assert_eq!(view.item_count(), 5); - assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); + assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape().unwrap(), [5]); let slice = view.as_slice(py).unwrap(); @@ -1428,7 +1434,7 @@ mod tests { PyBufferView::::with(&array, |view| { assert_eq!(view.dimensions(), 1); assert_eq!(view.item_count(), 4); - assert_eq!(view.format().unwrap().to_str().unwrap(), "f"); + assert_eq!(view.format().to_str().unwrap(), "f"); assert_eq!(view.shape().unwrap(), [4]); let slice = view.as_slice(py).unwrap(); @@ -1452,28 +1458,28 @@ mod tests { let bytes = PyBytes::new(py, b"abcde"); PyUntypedBufferView::with(&bytes, ffi::PyBUF_SIMPLE, |view| { - assert!(view.format().is_none()); + assert_eq!(view.format().to_str().unwrap(), "B"); assert!(view.shape().is_none()); assert!(view.strides().is_none()); }) .unwrap(); PyUntypedBufferView::with(&bytes, ffi::PyBUF_ND, |view| { - assert!(view.format().is_none()); + assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape().unwrap(), [5]); assert!(view.strides().is_none()); }) .unwrap(); PyUntypedBufferView::with(&bytes, ffi::PyBUF_STRIDES, |view| { - assert!(view.format().is_none()); + assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape().unwrap(), [5]); assert_eq!(view.strides().unwrap(), [1]); }) .unwrap(); PyUntypedBufferView::with(&bytes, ffi::PyBUF_FORMAT, |view| { - assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); + assert_eq!(view.format().to_str().unwrap(), "B"); assert!(view.shape().is_none()); assert!(view.strides().is_none()); }) @@ -1500,10 +1506,12 @@ mod tests { concat!( "PyUntypedBufferView {{ buf: {:?}, obj: {:?}, ", "len: 5, itemsize: 1, readonly: 1, ", - "ndim: 1, format: Some(\"B\"), shape: Some([5]), ", + "ndim: 1, format: \"B\", shape: Some([5]), ", "strides: Some([1]), suboffsets: None, internal: {:?} }}", ), - view.raw.buf, view.raw.obj, view.raw.internal, + view.raw().buf, + view.raw().obj, + view.raw().internal, ); let debug_repr = format!("{:?}", view); assert_eq!(debug_repr, expected); @@ -1515,10 +1523,12 @@ mod tests { concat!( "PyBufferView {{ buf: {:?}, obj: {:?}, ", "len: 5, itemsize: 1, readonly: 1, ", - "ndim: 1, format: Some(\"B\"), shape: Some([5]), ", + "ndim: 1, format: \"B\", shape: Some([5]), ", "strides: Some([1]), suboffsets: None, internal: {:?} }}", ), - view.0.raw.buf, view.0.raw.obj, view.0.raw.internal, + view.0.raw().buf, + view.0.raw().obj, + view.0.raw().internal, ); let debug_repr = format!("{:?}", view); assert_eq!(debug_repr, expected); From 394ef330c8ef2b1756b9f1f3dba9bb24da28ddf1 Mon Sep 17 00:00:00 2001 From: "Winston H." Date: Sun, 22 Mar 2026 23:49:23 +0800 Subject: [PATCH 04/17] refactor: use `assume_init` Co-authored-by: David Hewitt --- src/buffer.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index 7ad84191a1c..63d33085558 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -816,9 +816,9 @@ impl PyUntypedBufferView { ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags) })?; - // Construct view only after successful GetBuffer, so Drop always + // SAEFTY: Construct view only after successful GetBuffer, so Drop always // runs on an initialized Py_buffer. - let mut view = PyUntypedBufferView { raw }; + let mut view = PyUntypedBufferView { raw: unsafe { raw.assume_init() } }; // When shape is NULL the consumer must assume itemsize == 1. if view.raw().shape.is_null() { From 3342322e25d21518c6d86ab713a460eb154d4bc1 Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Mon, 23 Mar 2026 00:18:41 +0800 Subject: [PATCH 05/17] refactor: apply some suggestions --- src/buffer.rs | 181 ++++++++++++++++++++++++++++---------------------- 1 file changed, 103 insertions(+), 78 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index 63d33085558..c9638650b99 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -797,15 +797,43 @@ pub struct PyBufferView(PyUntypedBufferView, PhantomData<[T]>); /// arbitrary flags, so flag-dependent accessors like [`format()`](Self::format), /// [`shape()`](Self::shape), and [`strides()`](Self::strides) return `Option`. pub struct PyUntypedBufferView { - raw: mem::MaybeUninit, + raw: ffi::Py_buffer, } impl PyUntypedBufferView { + /// Acquire a buffer view on the stack with [`ffi::PyBUF_SIMPLE`] flags, + /// pass it to `f`, then release the buffer. + /// + /// Format is patched to `"B"` and itemsize to `1`, as required by the + /// buffer protocol for `PyBUF_SIMPLE` requests. + pub fn with( + obj: &Bound<'_, PyAny>, + f: impl FnOnce(&PyUntypedBufferView) -> R, + ) -> PyResult { + let mut raw = mem::MaybeUninit::::uninit(); + + err::error_on_minusone(obj.py(), unsafe { + ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), ffi::PyBUF_SIMPLE) + })?; + + // SAFETY: Construct view only after successful GetBuffer, so Drop always + // runs on an initialized Py_buffer. + let mut view = PyUntypedBufferView { + raw: unsafe { raw.assume_init() }, + }; + + // For PyBUF_SIMPLE, the consumer must assume itemsize == 1 and format "B". + view.raw.itemsize = 1; + view.raw.format = ffi::c_str!("B").as_ptr() as *mut _; + + Ok(f(&view)) + } + /// Acquire a buffer view on the stack, pass it to `f`, then release the buffer. /// /// The `flags` parameter controls which buffer fields are requested from the exporter. /// Use constants like [`ffi::PyBUF_SIMPLE`], [`ffi::PyBUF_FULL_RO`], etc. - pub fn with( + pub fn with_flags( obj: &Bound<'_, PyAny>, flags: c_int, f: impl FnOnce(&PyUntypedBufferView) -> R, @@ -816,54 +844,44 @@ impl PyUntypedBufferView { ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags) })?; - // SAEFTY: Construct view only after successful GetBuffer, so Drop always + // SAFETY: Construct view only after successful GetBuffer, so Drop always // runs on an initialized Py_buffer. - let mut view = PyUntypedBufferView { raw: unsafe { raw.assume_init() } }; - - // When shape is NULL the consumer must assume itemsize == 1. - if view.raw().shape.is_null() { - unsafe { &mut *view.raw.as_mut_ptr() }.itemsize = 1; - } + let view = PyUntypedBufferView { + raw: unsafe { raw.assume_init() }, + }; Ok(f(&view)) } - #[inline] - fn raw(&self) -> &ffi::Py_buffer { - // SAFETY: PyUntypedBufferView is only constructed after a successful - // PyObject_GetBuffer call, so raw is always initialized. - unsafe { self.raw.assume_init_ref() } - } - /// Gets the pointer to the start of the buffer memory. #[inline] pub fn buf_ptr(&self) -> *mut c_void { - self.raw().buf + self.raw.buf } /// Gets whether the underlying buffer is read-only. #[inline] pub fn readonly(&self) -> bool { - self.raw().readonly != 0 + self.raw.readonly != 0 } /// Gets the size of a single element, in bytes. #[inline] pub fn item_size(&self) -> usize { - self.raw().itemsize as usize + self.raw.itemsize as usize } /// Gets the total number of items. #[inline] pub fn item_count(&self) -> usize { - (self.raw().len as usize) / (self.raw().itemsize as usize) + (self.raw.len as usize) / (self.raw.itemsize as usize) } /// `item_size() * item_count()`. /// For contiguous arrays, this is the length of the underlying memory block. #[inline] pub fn len_bytes(&self) -> usize { - self.raw().len as usize + self.raw.len as usize } /// Gets the number of dimensions. @@ -871,18 +889,21 @@ impl PyUntypedBufferView { /// May be 0 to indicate a single scalar value. #[inline] pub fn dimensions(&self) -> usize { - self.raw().ndim as usize + self.raw.ndim as usize } /// A [struct module style](https://docs.python.org/3/c-api/buffer.html#c.Py_buffer.format) - /// string describing the contents of a single item. Defaults to `"B"` if NULL. + /// string describing the contents of a single item. + /// + /// Returns `None` if `PyBUF_FORMAT` was not requested and the request was not + /// `PyBUF_SIMPLE` or `PyBUF_WRITABLE`. #[inline] - pub fn format(&self) -> &CStr { - if self.raw().format.is_null() { - ffi::c_str!("B") - } else { - unsafe { CStr::from_ptr(self.raw().format) } + pub fn format(&self) -> Option<&CStr> { + if self.raw.format.is_null() { + return None; } + + Some(unsafe { CStr::from_ptr(self.raw.format) }) } /// Returns the shape array. `shape[i]` is the length of dimension `i`. @@ -890,11 +911,11 @@ impl PyUntypedBufferView { /// Returns `None` if the exporter set `shape` to NULL (e.g. `PyBUF_SIMPLE` was requested). #[inline] pub fn shape(&self) -> Option<&[usize]> { - if self.raw().shape.is_null() { + if self.raw.shape.is_null() { return None; } - Some(unsafe { slice::from_raw_parts(self.raw().shape.cast(), self.raw().ndim as usize) }) + Some(unsafe { slice::from_raw_parts(self.raw.shape.cast(), self.raw.ndim as usize) }) } /// Returns the strides array. @@ -902,11 +923,11 @@ impl PyUntypedBufferView { /// Returns `None` if the exporter set `strides` to NULL (e.g. `PyBUF_SIMPLE` was requested). #[inline] pub fn strides(&self) -> Option<&[isize]> { - if self.raw().strides.is_null() { + if self.raw.strides.is_null() { return None; } - Some(unsafe { slice::from_raw_parts(self.raw().strides, self.raw().ndim as usize) }) + Some(unsafe { slice::from_raw_parts(self.raw.strides, self.raw.ndim as usize) }) } /// Returns the suboffsets array. @@ -914,23 +935,23 @@ impl PyUntypedBufferView { /// May return `None` even with `PyBUF_INDIRECT` if the exporter sets `suboffsets` to NULL. #[inline] pub fn suboffsets(&self) -> Option<&[isize]> { - if self.raw().suboffsets.is_null() { - None - } else { - Some(unsafe { slice::from_raw_parts(self.raw().suboffsets, self.raw().ndim as usize) }) + if self.raw.suboffsets.is_null() { + return None; } + + Some(unsafe { slice::from_raw_parts(self.raw.suboffsets, self.raw.ndim as usize) }) } /// Gets whether the buffer is contiguous in C-style order. #[inline] pub fn is_c_contiguous(&self) -> bool { - unsafe { ffi::PyBuffer_IsContiguous(self.raw(), b'C' as std::ffi::c_char) != 0 } + unsafe { ffi::PyBuffer_IsContiguous(&self.raw, b'C' as std::ffi::c_char) != 0 } } /// Gets whether the buffer is contiguous in Fortran-style order. #[inline] pub fn is_fortran_contiguous(&self) -> bool { - unsafe { ffi::PyBuffer_IsContiguous(self.raw(), b'F' as std::ffi::c_char) != 0 } + unsafe { ffi::PyBuffer_IsContiguous(&self.raw, b'F' as std::ffi::c_char) != 0 } } /// Attempt to interpret this untyped view as containing elements of type `T`. @@ -943,12 +964,18 @@ impl PyUntypedBufferView { } fn ensure_compatible_with(&self) -> PyResult<()> { - if mem::size_of::() != self.item_size() || !T::is_compatible_format(self.format()) { + let format = self.format().ok_or_else(|| { + PyBufferError::new_err( + "buffer format is not available (PyBUF_FORMAT was not requested)", + ) + })?; + + if mem::size_of::() != self.item_size() || !T::is_compatible_format(format) { Err(PyBufferError::new_err(format!( "buffer contents are not compatible with {}", std::any::type_name::() ))) - } else if self.raw().buf.align_offset(mem::align_of::()) != 0 { + } else if self.raw.buf.align_offset(mem::align_of::()) != 0 { Err(PyBufferError::new_err(format!( "buffer contents are insufficiently aligned for {}", std::any::type_name::() @@ -961,7 +988,7 @@ impl PyUntypedBufferView { impl Drop for PyUntypedBufferView { fn drop(&mut self) { - unsafe { ffi::PyBuffer_Release(self.raw.as_mut_ptr()) } + unsafe { ffi::PyBuffer_Release(&mut self.raw) } } } @@ -980,13 +1007,15 @@ impl PyBufferView { /// Acquire a typed buffer view on the stack with user-specified flags. /// - /// The flags must include `PyBUF_FORMAT` for type validation to succeed. + /// `PyBUF_FORMAT` is implicitly added to the flags for type validation. pub fn with_flags( obj: &Bound<'_, PyAny>, flags: c_int, f: impl FnOnce(&PyBufferView) -> R, ) -> PyResult { - PyUntypedBufferView::with(obj, flags, |view| view.as_typed::().map(f))? + PyUntypedBufferView::with_flags(obj, flags | ffi::PyBUF_FORMAT, |view| { + view.as_typed::().map(f) + })? } /// Gets the buffer memory as a slice. @@ -999,7 +1028,7 @@ impl PyBufferView { if self.is_c_contiguous() { unsafe { Some(slice::from_raw_parts( - self.0.raw().buf.cast(), + self.0.raw.buf.cast(), self.item_count(), )) } @@ -1018,7 +1047,7 @@ impl PyBufferView { if !self.readonly() && self.is_c_contiguous() { unsafe { Some(slice::from_raw_parts( - self.0.raw().buf.cast(), + self.0.raw.buf.cast(), self.item_count(), )) } @@ -1047,19 +1076,18 @@ fn debug_buffer_view( view: &PyUntypedBufferView, f: &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result { - let raw = view.raw(); f.debug_struct(name) - .field("buf", &raw.buf) - .field("obj", &raw.obj) - .field("len", &raw.len) - .field("itemsize", &raw.itemsize) - .field("readonly", &raw.readonly) - .field("ndim", &raw.ndim) + .field("buf", &view.raw.buf) + .field("obj", &view.raw.obj) + .field("len", &view.raw.len) + .field("itemsize", &view.raw.itemsize) + .field("readonly", &view.raw.readonly) + .field("ndim", &view.raw.ndim) .field("format", &view.format()) .field("shape", &view.shape()) .field("strides", &view.strides()) .field("suboffsets", &view.suboffsets()) - .field("internal", &raw.internal) + .field("internal", &view.raw.internal) .finish() } @@ -1362,15 +1390,15 @@ mod tests { fn test_untyped_buffer_view_simple() { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); - PyUntypedBufferView::with(&bytes, ffi::PyBUF_SIMPLE, |view| { + PyUntypedBufferView::with(&bytes, |view| { assert!(!view.buf_ptr().is_null()); assert_eq!(view.len_bytes(), 5); assert_eq!(view.item_size(), 1); assert_eq!(view.item_count(), 5); assert!(view.readonly()); assert_eq!(view.dimensions(), 1); - // PyBUF_SIMPLE doesn't include ND or STRIDES; format defaults to "B" - assert_eq!(view.format().to_str().unwrap(), "B"); + // with() uses PyBUF_SIMPLE and patches format to "B" + assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); assert!(view.shape().is_none()); assert!(view.strides().is_none()); assert!(view.suboffsets().is_none()); @@ -1383,14 +1411,14 @@ mod tests { fn test_untyped_buffer_view_full() { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); - PyUntypedBufferView::with(&bytes, ffi::PyBUF_FULL_RO, |view| { + PyUntypedBufferView::with_flags(&bytes, ffi::PyBUF_FULL_RO, |view| { assert!(!view.buf_ptr().is_null()); assert_eq!(view.len_bytes(), 5); assert_eq!(view.item_size(), 1); assert_eq!(view.item_count(), 5); assert!(view.readonly()); assert_eq!(view.dimensions(), 1); - assert_eq!(view.format().to_str().unwrap(), "B"); + assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); assert_eq!(view.shape().unwrap(), [5]); assert_eq!(view.strides().unwrap(), [1]); assert!(view.suboffsets().is_none()); @@ -1408,7 +1436,7 @@ mod tests { PyBufferView::::with(&bytes, |view| { assert_eq!(view.dimensions(), 1); assert_eq!(view.item_count(), 5); - assert_eq!(view.format().to_str().unwrap(), "B"); + assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); assert_eq!(view.shape().unwrap(), [5]); let slice = view.as_slice(py).unwrap(); @@ -1434,7 +1462,7 @@ mod tests { PyBufferView::::with(&array, |view| { assert_eq!(view.dimensions(), 1); assert_eq!(view.item_count(), 4); - assert_eq!(view.format().to_str().unwrap(), "f"); + assert_eq!(view.format().unwrap().to_str().unwrap(), "f"); assert_eq!(view.shape().unwrap(), [4]); let slice = view.as_slice(py).unwrap(); @@ -1457,29 +1485,30 @@ mod tests { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); - PyUntypedBufferView::with(&bytes, ffi::PyBUF_SIMPLE, |view| { - assert_eq!(view.format().to_str().unwrap(), "B"); + PyUntypedBufferView::with(&bytes, |view| { + assert_eq!(view.item_size(), 1); + assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); assert!(view.shape().is_none()); assert!(view.strides().is_none()); }) .unwrap(); - PyUntypedBufferView::with(&bytes, ffi::PyBUF_ND, |view| { - assert_eq!(view.format().to_str().unwrap(), "B"); + PyUntypedBufferView::with_flags(&bytes, ffi::PyBUF_ND, |view| { + assert!(view.format().is_none()); assert_eq!(view.shape().unwrap(), [5]); assert!(view.strides().is_none()); }) .unwrap(); - PyUntypedBufferView::with(&bytes, ffi::PyBUF_STRIDES, |view| { - assert_eq!(view.format().to_str().unwrap(), "B"); + PyUntypedBufferView::with_flags(&bytes, ffi::PyBUF_STRIDES, |view| { + assert!(view.format().is_none()); assert_eq!(view.shape().unwrap(), [5]); assert_eq!(view.strides().unwrap(), [1]); }) .unwrap(); - PyUntypedBufferView::with(&bytes, ffi::PyBUF_FORMAT, |view| { - assert_eq!(view.format().to_str().unwrap(), "B"); + PyUntypedBufferView::with_flags(&bytes, ffi::PyBUF_FORMAT, |view| { + assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); assert!(view.shape().is_none()); assert!(view.strides().is_none()); }) @@ -1491,7 +1520,7 @@ mod tests { fn test_buffer_view_error() { Python::attach(|py| { let list = crate::types::PyList::empty(py); - let result = PyUntypedBufferView::with(&list, ffi::PyBUF_SIMPLE, |_view| {}); + let result = PyUntypedBufferView::with(&list, |_view| {}); assert!(result.is_err()); }); } @@ -1501,17 +1530,15 @@ mod tests { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); - PyUntypedBufferView::with(&bytes, ffi::PyBUF_FULL_RO, |view| { + PyUntypedBufferView::with_flags(&bytes, ffi::PyBUF_FULL_RO, |view| { let expected = format!( concat!( "PyUntypedBufferView {{ buf: {:?}, obj: {:?}, ", "len: 5, itemsize: 1, readonly: 1, ", - "ndim: 1, format: \"B\", shape: Some([5]), ", + "ndim: 1, format: Some(\"B\"), shape: Some([5]), ", "strides: Some([1]), suboffsets: None, internal: {:?} }}", ), - view.raw().buf, - view.raw().obj, - view.raw().internal, + view.raw.buf, view.raw.obj, view.raw.internal, ); let debug_repr = format!("{:?}", view); assert_eq!(debug_repr, expected); @@ -1523,12 +1550,10 @@ mod tests { concat!( "PyBufferView {{ buf: {:?}, obj: {:?}, ", "len: 5, itemsize: 1, readonly: 1, ", - "ndim: 1, format: \"B\", shape: Some([5]), ", + "ndim: 1, format: Some(\"B\"), shape: Some([5]), ", "strides: Some([1]), suboffsets: None, internal: {:?} }}", ), - view.0.raw().buf, - view.0.raw().obj, - view.0.raw().internal, + view.0.raw.buf, view.0.raw.obj, view.0.raw.internal, ); let debug_repr = format!("{:?}", view); assert_eq!(debug_repr, expected); From 976a39b7aa429dcaea05bb75d43e6df743aa26a4 Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Mon, 23 Mar 2026 00:46:42 +0800 Subject: [PATCH 06/17] fix: handle `PyBUF_WRITABLE` --- src/buffer.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index c9638650b99..4ce31485df5 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -803,9 +803,6 @@ pub struct PyUntypedBufferView { impl PyUntypedBufferView { /// Acquire a buffer view on the stack with [`ffi::PyBUF_SIMPLE`] flags, /// pass it to `f`, then release the buffer. - /// - /// Format is patched to `"B"` and itemsize to `1`, as required by the - /// buffer protocol for `PyBUF_SIMPLE` requests. pub fn with( obj: &Bound<'_, PyAny>, f: impl FnOnce(&PyUntypedBufferView) -> R, @@ -846,10 +843,16 @@ impl PyUntypedBufferView { // SAFETY: Construct view only after successful GetBuffer, so Drop always // runs on an initialized Py_buffer. - let view = PyUntypedBufferView { + let mut view = PyUntypedBufferView { raw: unsafe { raw.assume_init() }, }; + // For PyBUF_WRITABLE, the consumer must assume itemsize == 1 and format "B". + if flags == ffi::PyBUF_WRITABLE { + view.raw.itemsize = 1; + view.raw.format = ffi::c_str!("B").as_ptr() as *mut _; + } + Ok(f(&view)) } From addd61bfca807718d5c5a8833669c2a7ff1e8f34 Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Mon, 23 Mar 2026 01:57:59 +0800 Subject: [PATCH 07/17] refactor: encode compile-time buffer field availability --- src/buffer.rs | 479 ++++++++++++++++++++++++++------------------------ 1 file changed, 248 insertions(+), 231 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index 4ce31485df5..0ad282c7752 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -784,78 +784,58 @@ impl_element!(isize, SignedInteger); impl_element!(f32, Float); impl_element!(f64, Float); -/// Stack-allocated typed buffer view. Not constructible directly — use -/// [`PyBufferView::with()`] or [`PyBufferView::with_flags()`]. -/// -/// This is a lightweight alternative to [`PyBuffer`] that avoids heap allocation -/// by placing the `Py_buffer` on the stack. The scoped closure API ensures the -/// buffer cannot be moved -#[repr(transparent)] -pub struct PyBufferView(PyUntypedBufferView, PhantomData<[T]>); - -/// Unlike [`PyUntypedBuffer`] which always requests `PyBUF_FULL_RO`, this type allows -/// arbitrary flags, so flag-dependent accessors like [`format()`](Self::format), -/// [`shape()`](Self::shape), and [`strides()`](Self::strides) return `Option`. -pub struct PyUntypedBufferView { - raw: ffi::Py_buffer, -} - -impl PyUntypedBufferView { - /// Acquire a buffer view on the stack with [`ffi::PyBUF_SIMPLE`] flags, - /// pass it to `f`, then release the buffer. - pub fn with( - obj: &Bound<'_, PyAny>, - f: impl FnOnce(&PyUntypedBufferView) -> R, - ) -> PyResult { - let mut raw = mem::MaybeUninit::::uninit(); - - err::error_on_minusone(obj.py(), unsafe { - ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), ffi::PyBUF_SIMPLE) - })?; +/// Sealed marker for buffer field availability. Either [`Known`] or [`Unknown`]. +mod buffer_info { + /// Whether a buffer field is guaranteed non-null. + pub trait FieldInfo: sealed::Sealed {} - // SAFETY: Construct view only after successful GetBuffer, so Drop always - // runs on an initialized Py_buffer. - let mut view = PyUntypedBufferView { - raw: unsafe { raw.assume_init() }, - }; - - // For PyBUF_SIMPLE, the consumer must assume itemsize == 1 and format "B". - view.raw.itemsize = 1; - view.raw.format = ffi::c_str!("B").as_ptr() as *mut _; - - Ok(f(&view)) + mod sealed { + pub trait Sealed {} + impl Sealed for super::Known {} + impl Sealed for super::Unknown {} } - /// Acquire a buffer view on the stack, pass it to `f`, then release the buffer. - /// - /// The `flags` parameter controls which buffer fields are requested from the exporter. - /// Use constants like [`ffi::PyBUF_SIMPLE`], [`ffi::PyBUF_FULL_RO`], etc. - pub fn with_flags( - obj: &Bound<'_, PyAny>, - flags: c_int, - f: impl FnOnce(&PyUntypedBufferView) -> R, - ) -> PyResult { - let mut raw = mem::MaybeUninit::::uninit(); - - err::error_on_minusone(obj.py(), unsafe { - ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags) - })?; - - // SAFETY: Construct view only after successful GetBuffer, so Drop always - // runs on an initialized Py_buffer. - let mut view = PyUntypedBufferView { - raw: unsafe { raw.assume_init() }, - }; + /// The field is guaranteed to be non-null. The accessor returns the value directly. + pub struct Known; + /// The field may be null. The accessor is not available. + pub struct Unknown; - // For PyBUF_WRITABLE, the consumer must assume itemsize == 1 and format "B". - if flags == ffi::PyBUF_WRITABLE { - view.raw.itemsize = 1; - view.raw.format = ffi::c_str!("B").as_ptr() as *mut _; - } + impl FieldInfo for Known {} + impl FieldInfo for Unknown {} +} +pub use buffer_info::{FieldInfo, Known, Unknown}; - Ok(f(&view)) - } +/// A typed form of [`PyUntypedBufferView`]. Not constructible directly — use +/// [`PyBufferView::with()`] or [`PyBufferView::with_flags()`]. +#[repr(transparent)] +pub struct PyBufferView< + T, + Format: FieldInfo = Known, + Shape: FieldInfo = Known, + Stride: FieldInfo = Known, +>(PyUntypedBufferView, PhantomData<[T]>); + +/// Stack-allocated untyped buffer view. +/// +/// Unlike [`PyUntypedBuffer`] which heap-allocates, this places the `Py_buffer` on the +/// stack. The scoped closure API ensures the buffer cannot be moved. +/// +/// [`with()`](Self::with) requests `PyBUF_FULL_RO` and provides [`format()`](Self::format), +/// [`shape()`](Self::shape), and [`strides()`](Self::strides) accessors. +/// [`with_flags()`](Self::with_flags) accepts arbitrary flags but does not provide those +/// accessors. +pub struct PyUntypedBufferView< + Format: FieldInfo = Unknown, + Shape: FieldInfo = Unknown, + Stride: FieldInfo = Unknown, +> { + raw: ffi::Py_buffer, + _marker: PhantomData<(Format, Shape, Stride)>, +} +impl + PyUntypedBufferView +{ /// Gets the pointer to the start of the buffer memory. #[inline] pub fn buf_ptr(&self) -> *mut c_void { @@ -895,44 +875,6 @@ impl PyUntypedBufferView { self.raw.ndim as usize } - /// A [struct module style](https://docs.python.org/3/c-api/buffer.html#c.Py_buffer.format) - /// string describing the contents of a single item. - /// - /// Returns `None` if `PyBUF_FORMAT` was not requested and the request was not - /// `PyBUF_SIMPLE` or `PyBUF_WRITABLE`. - #[inline] - pub fn format(&self) -> Option<&CStr> { - if self.raw.format.is_null() { - return None; - } - - Some(unsafe { CStr::from_ptr(self.raw.format) }) - } - - /// Returns the shape array. `shape[i]` is the length of dimension `i`. - /// - /// Returns `None` if the exporter set `shape` to NULL (e.g. `PyBUF_SIMPLE` was requested). - #[inline] - pub fn shape(&self) -> Option<&[usize]> { - if self.raw.shape.is_null() { - return None; - } - - Some(unsafe { slice::from_raw_parts(self.raw.shape.cast(), self.raw.ndim as usize) }) - } - - /// Returns the strides array. - /// - /// Returns `None` if the exporter set `strides` to NULL (e.g. `PyBUF_SIMPLE` was requested). - #[inline] - pub fn strides(&self) -> Option<&[isize]> { - if self.raw.strides.is_null() { - return None; - } - - Some(unsafe { slice::from_raw_parts(self.raw.strides, self.raw.ndim as usize) }) - } - /// Returns the suboffsets array. /// /// May return `None` even with `PyBUF_INDIRECT` if the exporter sets `suboffsets` to NULL. @@ -956,71 +898,202 @@ impl PyUntypedBufferView { pub fn is_fortran_contiguous(&self) -> bool { unsafe { ffi::PyBuffer_IsContiguous(&self.raw, b'F' as std::ffi::c_char) != 0 } } +} + +impl PyUntypedBufferView { + /// A [struct module style](https://docs.python.org/3/c-api/buffer.html#c.Py_buffer.format) + /// string describing the contents of a single item. + #[inline] + pub fn format(&self) -> &CStr { + debug_assert!(!self.raw.format.is_null()); + unsafe { CStr::from_ptr(self.raw.format) } + } /// Attempt to interpret this untyped view as containing elements of type `T`. - /// - /// Requires that `PyBUF_FORMAT` was included in the flags. - pub fn as_typed(&self) -> PyResult<&PyBufferView> { + pub fn as_typed(&self) -> PyResult<&PyBufferView> { self.ensure_compatible_with::()?; - // SAFETY: PyBufferView is repr(transparent) around PyUntypedBufferView - Ok(unsafe { NonNull::from(self).cast::>().as_ref() }) + // SAFETY: PyBufferView is repr(transparent) around PyUntypedBufferView + Ok(unsafe { + NonNull::from(self) + .cast::>() + .as_ref() + }) } fn ensure_compatible_with(&self) -> PyResult<()> { - let format = self.format().ok_or_else(|| { - PyBufferError::new_err( - "buffer format is not available (PyBUF_FORMAT was not requested)", - ) - })?; + let name = std::any::type_name::(); - if mem::size_of::() != self.item_size() || !T::is_compatible_format(format) { - Err(PyBufferError::new_err(format!( - "buffer contents are not compatible with {}", - std::any::type_name::() - ))) - } else if self.raw.buf.align_offset(mem::align_of::()) != 0 { - Err(PyBufferError::new_err(format!( - "buffer contents are insufficiently aligned for {}", - std::any::type_name::() - ))) - } else { - Ok(()) + if mem::size_of::() != self.item_size() || !T::is_compatible_format(self.format()) { + return Err(PyBufferError::new_err(format!( + "buffer contents are not compatible with {name}" + ))); + } + if self.raw.buf.align_offset(mem::align_of::()) != 0 { + return Err(PyBufferError::new_err(format!( + "buffer contents are insufficiently aligned for {name}" + ))); } + + Ok(()) + } +} + +impl PyUntypedBufferView { + /// Returns the shape array. `shape[i]` is the length of dimension `i`. + /// + /// Despite Python using an array of signed integers, the values are guaranteed to be + /// non-negative. However, dimensions of length 0 are possible and might need special + /// attention. + #[inline] + pub fn shape(&self) -> &[usize] { + debug_assert!(!self.raw.shape.is_null()); + unsafe { slice::from_raw_parts(self.raw.shape.cast(), self.raw.ndim as usize) } + } +} + +impl PyUntypedBufferView { + /// Returns the strides array. + /// + /// Stride values can be any integer. For regular arrays, strides are usually positive, + /// but a consumer MUST be able to handle the case `strides[n] <= 0`. + #[inline] + pub fn strides(&self) -> &[isize] { + debug_assert!(!self.raw.strides.is_null()); + unsafe { slice::from_raw_parts(self.raw.strides, self.raw.ndim as usize) } + } +} + +impl PyUntypedBufferView { + /// Acquire a buffer view with [`ffi::PyBUF_FULL_RO`] flags, + /// pass it to `f`, then release the buffer. + pub fn with( + obj: &Bound<'_, PyAny>, + f: impl FnOnce(&PyUntypedBufferView) -> R, + ) -> PyResult { + let mut raw = mem::MaybeUninit::::uninit(); + + err::error_on_minusone(obj.py(), unsafe { + ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), ffi::PyBUF_FULL_RO) + })?; + + let view = PyUntypedBufferView { + raw: unsafe { raw.assume_init() }, + _marker: PhantomData, + }; + + Ok(f(&view)) + } +} + +impl PyUntypedBufferView { + /// Acquire a buffer view with arbitrary flags, + /// pass it to `f`, then release the buffer. + /// + /// The `flags` parameter controls which buffer fields are requested from the exporter. + /// Use constants like [`ffi::PyBUF_SIMPLE`], [`ffi::PyBUF_ND`], [`ffi::PyBUF_STRIDES`], etc. + pub fn with_flags( + obj: &Bound<'_, PyAny>, + flags: c_int, + f: impl FnOnce(&PyUntypedBufferView) -> R, + ) -> PyResult { + let mut raw = mem::MaybeUninit::::uninit(); + + err::error_on_minusone(obj.py(), unsafe { + ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags) + })?; + + let view = PyUntypedBufferView { + raw: unsafe { raw.assume_init() }, + _marker: PhantomData, + }; + + Ok(f(&view)) } } -impl Drop for PyUntypedBufferView { +fn debug_buffer_view( + name: &str, + raw: &ffi::Py_buffer, + f: &mut std::fmt::Formatter<'_>, +) -> std::fmt::Result { + let ndim = raw.ndim as usize; + let format = NonNull::new(raw.format).map(|p| unsafe { CStr::from_ptr(p.as_ptr()) }); + let shape = NonNull::new(raw.shape) + .map(|p| unsafe { slice::from_raw_parts(p.as_ptr().cast::(), ndim) }); + let strides = + NonNull::new(raw.strides).map(|p| unsafe { slice::from_raw_parts(p.as_ptr(), ndim) }); + let suboffsets = + NonNull::new(raw.suboffsets).map(|p| unsafe { slice::from_raw_parts(p.as_ptr(), ndim) }); + + f.debug_struct(name) + .field("buf", &raw.buf) + .field("obj", &raw.obj) + .field("len", &raw.len) + .field("itemsize", &raw.itemsize) + .field("readonly", &raw.readonly) + .field("ndim", &raw.ndim) + .field("format", &format) + .field("shape", &shape) + .field("strides", &strides) + .field("suboffsets", &suboffsets) + .field("internal", &raw.internal) + .finish() +} + +impl Drop + for PyUntypedBufferView +{ fn drop(&mut self) { unsafe { ffi::PyBuffer_Release(&mut self.raw) } } } -impl Debug for PyUntypedBufferView { +impl Debug + for PyUntypedBufferView +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - debug_buffer_view("PyUntypedBufferView", self, f) + debug_buffer_view("PyUntypedBufferView", &self.raw, f) } } -impl PyBufferView { - /// Acquire a typed buffer view on the stack with `PyBUF_FULL_RO` flags, +impl PyBufferView { + /// Acquire a typed buffer view with `PyBUF_FULL_RO` flags, /// validating that the buffer format is compatible with `T`. - pub fn with(obj: &Bound<'_, PyAny>, f: impl FnOnce(&PyBufferView) -> R) -> PyResult { - Self::with_flags(obj, ffi::PyBUF_FULL_RO, f) + pub fn with( + obj: &Bound<'_, PyAny>, + f: impl FnOnce(&PyBufferView) -> R, + ) -> PyResult { + PyUntypedBufferView::with(obj, |view| view.as_typed::().map(f))? } +} - /// Acquire a typed buffer view on the stack with user-specified flags. +impl PyBufferView { + /// Acquire a typed buffer view with user-specified flags. /// /// `PyBUF_FORMAT` is implicitly added to the flags for type validation. pub fn with_flags( obj: &Bound<'_, PyAny>, flags: c_int, - f: impl FnOnce(&PyBufferView) -> R, + f: impl FnOnce(&PyBufferView) -> R, ) -> PyResult { - PyUntypedBufferView::with_flags(obj, flags | ffi::PyBUF_FORMAT, |view| { - view.as_typed::().map(f) - })? + let mut raw = mem::MaybeUninit::::uninit(); + + err::error_on_minusone(obj.py(), unsafe { + ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags | ffi::PyBUF_FORMAT) + })?; + + let view = PyUntypedBufferView:: { + raw: unsafe { raw.assume_init() }, + _marker: PhantomData, + }; + + view.as_typed::().map(f) } +} +impl + PyBufferView +{ /// Gets the buffer memory as a slice. /// /// Returns `None` if the buffer is not C-contiguous. @@ -1028,16 +1101,11 @@ impl PyBufferView { /// The returned slice uses type [`ReadOnlyCell`] because it's theoretically possible /// for any call into the Python runtime to modify the values in the slice. pub fn as_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [ReadOnlyCell]> { - if self.is_c_contiguous() { - unsafe { - Some(slice::from_raw_parts( - self.0.raw.buf.cast(), - self.item_count(), - )) - } - } else { - None + if !self.is_c_contiguous() { + return None; } + + Some(unsafe { slice::from_raw_parts(self.0.raw.buf.cast(), self.item_count()) }) } /// Gets the buffer memory as a mutable slice. @@ -1047,53 +1115,32 @@ impl PyBufferView { /// The returned slice uses type [`Cell`](cell::Cell) because it's theoretically possible /// for any call into the Python runtime to modify the values in the slice. pub fn as_mut_slice<'a>(&'a self, _py: Python<'a>) -> Option<&'a [cell::Cell]> { - if !self.readonly() && self.is_c_contiguous() { - unsafe { - Some(slice::from_raw_parts( - self.0.raw.buf.cast(), - self.item_count(), - )) - } - } else { - None + if self.readonly() || !self.is_c_contiguous() { + return None; } + + Some(unsafe { slice::from_raw_parts(self.0.raw.buf.cast(), self.item_count()) }) } } -impl std::ops::Deref for PyBufferView { - type Target = PyUntypedBufferView; +impl std::ops::Deref + for PyBufferView +{ + type Target = PyUntypedBufferView; fn deref(&self) -> &Self::Target { &self.0 } } -impl Debug for PyBufferView { +impl Debug + for PyBufferView +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - debug_buffer_view("PyBufferView", &self.0, f) + debug_buffer_view("PyBufferView", &self.0.raw, f) } } -fn debug_buffer_view( - name: &str, - view: &PyUntypedBufferView, - f: &mut std::fmt::Formatter<'_>, -) -> std::fmt::Result { - f.debug_struct(name) - .field("buf", &view.raw.buf) - .field("obj", &view.raw.obj) - .field("len", &view.raw.len) - .field("itemsize", &view.raw.itemsize) - .field("readonly", &view.raw.readonly) - .field("ndim", &view.raw.ndim) - .field("format", &view.format()) - .field("shape", &view.shape()) - .field("strides", &view.strides()) - .field("suboffsets", &view.suboffsets()) - .field("internal", &view.raw.internal) - .finish() -} - #[cfg(test)] mod tests { use super::*; @@ -1390,7 +1437,7 @@ mod tests { } #[test] - fn test_untyped_buffer_view_simple() { + fn test_untyped_buffer_view() { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); PyUntypedBufferView::with(&bytes, |view| { @@ -1400,30 +1447,10 @@ mod tests { assert_eq!(view.item_count(), 5); assert!(view.readonly()); assert_eq!(view.dimensions(), 1); - // with() uses PyBUF_SIMPLE and patches format to "B" - assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); - assert!(view.shape().is_none()); - assert!(view.strides().is_none()); - assert!(view.suboffsets().is_none()); - }) - .unwrap(); - }); - } - - #[test] - fn test_untyped_buffer_view_full() { - Python::attach(|py| { - let bytes = PyBytes::new(py, b"abcde"); - PyUntypedBufferView::with_flags(&bytes, ffi::PyBUF_FULL_RO, |view| { - assert!(!view.buf_ptr().is_null()); - assert_eq!(view.len_bytes(), 5); - assert_eq!(view.item_size(), 1); - assert_eq!(view.item_count(), 5); - assert!(view.readonly()); - assert_eq!(view.dimensions(), 1); - assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); - assert_eq!(view.shape().unwrap(), [5]); - assert_eq!(view.strides().unwrap(), [1]); + // with() uses PyBUF_FULL_RO — all Known, direct return types + assert_eq!(view.format().to_str().unwrap(), "B"); + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); assert!(view.suboffsets().is_none()); assert!(view.is_c_contiguous()); assert!(view.is_fortran_contiguous()); @@ -1439,8 +1466,9 @@ mod tests { PyBufferView::::with(&bytes, |view| { assert_eq!(view.dimensions(), 1); assert_eq!(view.item_count(), 5); - assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); - assert_eq!(view.shape().unwrap(), [5]); + // PyBufferView::with uses PyBUF_FULL_RO — all Known + assert_eq!(view.format().to_str().unwrap(), "B"); + assert_eq!(view.shape(), [5]); let slice = view.as_slice(py).unwrap(); assert_eq!(slice.len(), 5); @@ -1465,8 +1493,8 @@ mod tests { PyBufferView::::with(&array, |view| { assert_eq!(view.dimensions(), 1); assert_eq!(view.item_count(), 4); - assert_eq!(view.format().unwrap().to_str().unwrap(), "f"); - assert_eq!(view.shape().unwrap(), [4]); + assert_eq!(view.format().to_str().unwrap(), "f"); + assert_eq!(view.shape(), [4]); let slice = view.as_slice(py).unwrap(); assert_eq!(slice.len(), 4); @@ -1484,36 +1512,22 @@ mod tests { } #[test] - fn test_buffer_view_option_accessors() { + fn test_buffer_view_with_flags() { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); - PyUntypedBufferView::with(&bytes, |view| { - assert_eq!(view.item_size(), 1); - assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); - assert!(view.shape().is_none()); - assert!(view.strides().is_none()); - }) - .unwrap(); - + // with_flags gives all-Unknown — only always-available methods PyUntypedBufferView::with_flags(&bytes, ffi::PyBUF_ND, |view| { - assert!(view.format().is_none()); - assert_eq!(view.shape().unwrap(), [5]); - assert!(view.strides().is_none()); - }) - .unwrap(); - - PyUntypedBufferView::with_flags(&bytes, ffi::PyBUF_STRIDES, |view| { - assert!(view.format().is_none()); - assert_eq!(view.shape().unwrap(), [5]); - assert_eq!(view.strides().unwrap(), [1]); + assert_eq!(view.item_count(), 5); + assert_eq!(view.len_bytes(), 5); + assert!(view.readonly()); + assert!(view.suboffsets().is_none()); }) .unwrap(); PyUntypedBufferView::with_flags(&bytes, ffi::PyBUF_FORMAT, |view| { - assert_eq!(view.format().unwrap().to_str().unwrap(), "B"); - assert!(view.shape().is_none()); - assert!(view.strides().is_none()); + assert_eq!(view.item_count(), 5); + assert!(view.suboffsets().is_none()); }) .unwrap(); }); @@ -1533,6 +1547,7 @@ mod tests { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); + // Debug always uses raw_format/raw_shape/raw_strides (Option in output) PyUntypedBufferView::with_flags(&bytes, ffi::PyBUF_FULL_RO, |view| { let expected = format!( concat!( @@ -1543,6 +1558,7 @@ mod tests { ), view.raw.buf, view.raw.obj, view.raw.internal, ); + let debug_repr = format!("{:?}", view); assert_eq!(debug_repr, expected); }) @@ -1558,6 +1574,7 @@ mod tests { ), view.0.raw.buf, view.0.raw.obj, view.0.raw.internal, ); + let debug_repr = format!("{:?}", view); assert_eq!(debug_repr, expected); }) From a2b6119cc61f352896a425594dc8322d00a853a1 Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Mon, 23 Mar 2026 02:16:08 +0800 Subject: [PATCH 08/17] style: clean up --- src/buffer.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index 0ad282c7752..b3f757af4dc 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -912,12 +912,14 @@ impl PyUntypedBufferView(&self) -> PyResult<&PyBufferView> { self.ensure_compatible_with::()?; - // SAFETY: PyBufferView is repr(transparent) around PyUntypedBufferView - Ok(unsafe { + // SAFETY: PyBufferView is repr(transparent) around PyUntypedBufferView<..> + let typed = unsafe { NonNull::from(self) .cast::>() .as_ref() - }) + }; + + Ok(typed) } fn ensure_compatible_with(&self) -> PyResult<()> { @@ -928,6 +930,7 @@ impl PyUntypedBufferView()) != 0 { return Err(PyBufferError::new_err(format!( "buffer contents are insufficiently aligned for {name}" From dc7a939dba6abc4cc658a19fb35b232bf307c0b2 Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Wed, 8 Apr 2026 02:07:29 +0800 Subject: [PATCH 09/17] tests: extend coverage --- src/buffer.rs | 143 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 125 insertions(+), 18 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index b3f757af4dc..4cf78522868 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -1288,6 +1288,27 @@ mod tests { (c"=d", Float { bytes: 8 }), (c"=z", Unknown), (c"=0", Unknown), + // bare char (no prefix) goes to native_element_type_from_type_char + ( + c"b", + SignedInteger { + bytes: size_of::(), + }, + ), + ( + c"B", + UnsignedInteger { + bytes: size_of::(), + }, + ), + (c"?", Bool), + (c"f", Float { bytes: 4 }), + (c"d", Float { bytes: 8 }), + (c"z", Unknown), + // <, >, ! prefixes go to standard_element_type_from_type_char + (c"H", UnsignedInteger { bytes: 2 }), + (c"!q", SignedInteger { bytes: 8 }), // unknown prefix -> Unknown (c":b", Unknown), ] { @@ -1325,6 +1346,7 @@ mod tests { assert_eq!(slice.len(), 5); assert_eq!(slice[0].get(), b'a'); assert_eq!(slice[2].get(), b'c'); + assert_eq!(unsafe { *slice[0].as_ptr() }, b'a'); assert_eq!(unsafe { *(buffer.get_ptr(&[1]).cast::()) }, b'b'); @@ -1398,24 +1420,6 @@ mod tests { }); } - #[test] - fn test_untyped_buffer() { - Python::attach(|py| { - let bytes = PyBytes::new(py, b"abcde"); - let untyped = PyUntypedBuffer::get(&bytes).unwrap(); - assert_eq!(untyped.dimensions(), 1); - assert_eq!(untyped.item_count(), 5); - assert_eq!(untyped.format().to_str().unwrap(), "B"); - assert_eq!(untyped.shape(), [5]); - - let typed: &PyBuffer = untyped.as_typed().unwrap(); - assert_eq!(typed.dimensions(), 1); - assert_eq!(typed.item_count(), 5); - assert_eq!(typed.format().to_str().unwrap(), "B"); - assert_eq!(typed.shape(), [5]); - }); - } - #[test] fn test_obj_getter() { Python::attach(|py| { @@ -1439,6 +1443,67 @@ mod tests { }); } + #[test] + fn test_copy_to_fortran_slice() { + Python::attach(|py| { + let array = py + .import("array") + .unwrap() + .call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None) + .unwrap(); + let buffer = PyBuffer::get(&array).unwrap(); + + // wrong length + assert!(buffer.copy_to_fortran_slice(py, &mut [0.0f32]).is_err()); + // correct length + let mut arr = [0.0f32; 4]; + buffer.copy_to_fortran_slice(py, &mut arr).unwrap(); + assert_eq!(arr, [1.0, 1.5, 2.0, 2.5]); + }); + } + + #[test] + fn test_copy_from_slice_wrong_length() { + Python::attach(|py| { + let array = py + .import("array") + .unwrap() + .call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None) + .unwrap(); + let buffer = PyBuffer::get(&array).unwrap(); + // writable buffer, but wrong length + assert!(!buffer.readonly()); + assert!(buffer.copy_from_slice(py, &[0.0f32; 2]).is_err()); + assert!(buffer.copy_from_fortran_slice(py, &[0.0f32; 2]).is_err()); + }); + } + + #[test] + fn test_untyped_buffer() { + Python::attach(|py| { + let bytes = PyBytes::new(py, b"abcde"); + let buffer = PyUntypedBuffer::get(&bytes).unwrap(); + assert_eq!(buffer.dimensions(), 1); + assert_eq!(buffer.item_count(), 5); + assert_eq!(buffer.format().to_str().unwrap(), "B"); + assert_eq!(buffer.shape(), [5]); + assert!(!buffer.buf_ptr().is_null()); + assert_eq!(buffer.strides(), &[1]); + assert_eq!(buffer.len_bytes(), 5); + assert_eq!(buffer.item_size(), 1); + assert!(buffer.readonly()); + assert!(buffer.suboffsets().is_none()); + + assert!(format!("{:?}", buffer).starts_with("PyUntypedBuffer { buf: ")); + + let typed: &PyBuffer = buffer.as_typed().unwrap(); + assert_eq!(typed.dimensions(), 1); + assert_eq!(typed.item_count(), 5); + assert_eq!(typed.format().to_str().unwrap(), "B"); + assert_eq!(typed.shape(), [5]); + }); + } + #[test] fn test_untyped_buffer_view() { Python::attach(|py| { @@ -1536,6 +1601,48 @@ mod tests { }); } + #[test] + fn test_typed_buffer_view_with_flags() { + Python::attach(|py| { + let array = py + .import("array") + .unwrap() + .call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None) + .unwrap(); + + PyBufferView::::with_flags( + &array, + ffi::PyBUF_ND, + |view| { + assert_eq!(view.item_count(), 4); + assert_eq!(view.format().to_str().unwrap(), "f"); + + let slice = view.as_slice(py).unwrap(); + assert_eq!(slice[0].get(), 1.0); + assert_eq!(slice[3].get(), 2.5); + + let mut_slice = view.as_mut_slice(py).unwrap(); + mut_slice[0].set(9.0); + assert_eq!(slice[0].get(), 9.0); + }, + ) + .unwrap(); + }); + } + + #[test] + fn test_typed_buffer_view_with_flags_incompatible() { + Python::attach(|py| { + let bytes = PyBytes::new(py, b"abcde"); + let result = PyBufferView::::with_flags( + &bytes, + ffi::PyBUF_ND, + |_view| {}, + ); + assert!(result.is_err()); + }); + } + #[test] fn test_buffer_view_error() { Python::attach(|py| { From 4602739a6fbd89b80f1aa42bfe17c4f59fcaed73 Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Wed, 8 Apr 2026 02:11:35 +0800 Subject: [PATCH 10/17] chore: add `obj` API --- src/buffer.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/buffer.rs b/src/buffer.rs index 4cf78522868..78717afc737 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -842,6 +842,12 @@ impl self.raw.buf } + /// Returns the Python object that owns the buffer data. + #[inline] + pub fn obj<'py>(&self, py: Python<'py>) -> Option<&Bound<'py, PyAny>> { + unsafe { Bound::ref_from_ptr_or_opt(py, &self.raw.obj).as_ref() } + } + /// Gets whether the underlying buffer is read-only. #[inline] pub fn readonly(&self) -> bool { @@ -1522,6 +1528,7 @@ mod tests { assert!(view.suboffsets().is_none()); assert!(view.is_c_contiguous()); assert!(view.is_fortran_contiguous()); + assert!(view.obj(py).unwrap().is(&bytes)); }) .unwrap(); }); From cbe37292177ef373ec0b243757f2943eac98de9a Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Wed, 8 Apr 2026 04:27:23 +0800 Subject: [PATCH 11/17] refactor: use `PyBufferFlags` --- src/buffer.rs | 459 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 313 insertions(+), 146 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index 78717afc737..f276af48cdb 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -784,58 +784,130 @@ impl_element!(isize, SignedInteger); impl_element!(f32, Float); impl_element!(f64, Float); -/// Sealed marker for buffer field availability. Either [`Known`] or [`Unknown`]. -mod buffer_info { - /// Whether a buffer field is guaranteed non-null. - pub trait FieldInfo: sealed::Sealed {} +/// Type-safe buffer request flags. The const parameters encode which fields +/// the exporter is required to fill. +pub struct PyBufferFlags< + const FORMAT: bool = false, + const SHAPE: bool = false, + const STRIDE: bool = false, + const WRITABLE: bool = false, + const C_CONTIGUOUS: bool = false, + const F_CONTIGUOUS: bool = false, +>(c_int); + +mod py_buffer_flags_sealed { + pub trait Sealed {} + impl< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const WRITABLE: bool, + const C_CONTIGUOUS: bool, + const F_CONTIGUOUS: bool, + > Sealed for super::PyBufferFlags + { + } +} - mod sealed { - pub trait Sealed {} - impl Sealed for super::Known {} - impl Sealed for super::Unknown {} +/// Trait implemented by all [`PyBufferFlags`] instantiations. +pub trait PyBufferFlagsType: py_buffer_flags_sealed::Sealed {} + +impl< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const WRITABLE: bool, + const C_CONTIGUOUS: bool, + const F_CONTIGUOUS: bool, + > PyBufferFlagsType for PyBufferFlags +{ +} + +// PyBufferFlags::FORMAT | +impl + std::ops::BitOr> + for PyBufferFlags +{ + type Output = PyBufferFlags; + fn bitor(self, rhs: PyBufferFlags) -> Self::Output { + PyBufferFlags(self.0 | rhs.0) } +} - /// The field is guaranteed to be non-null. The accessor returns the value directly. - pub struct Known; - /// The field may be null. The accessor is not available. - pub struct Unknown; +// | PyBufferFlags::FORMAT +impl + std::ops::BitOr> + for PyBufferFlags +{ + type Output = PyBufferFlags; + fn bitor(self, rhs: PyBufferFlags) -> Self::Output { + PyBufferFlags(self.0 | rhs.0) + } +} - impl FieldInfo for Known {} - impl FieldInfo for Unknown {} +#[allow(non_upper_case_globals)] +impl PyBufferFlags { + /// Request a simple buffer with no shape, strides, or format information. + pub const SIMPLE: PyBufferFlags = PyBufferFlags(ffi::PyBUF_SIMPLE); + /// Request format information only. + pub const FORMAT: PyBufferFlags = PyBufferFlags(ffi::PyBUF_FORMAT); + /// Request shape information. + pub const ND: PyBufferFlags = PyBufferFlags(ffi::PyBUF_ND); + /// Request shape and strides. + pub const STRIDES: PyBufferFlags = PyBufferFlags(ffi::PyBUF_STRIDES); + /// Request C-contiguous buffer with shape and strides. + pub const C_CONTIGUOUS: PyBufferFlags = + PyBufferFlags(ffi::PyBUF_C_CONTIGUOUS); + /// Request Fortran-contiguous buffer with shape and strides. + pub const F_CONTIGUOUS: PyBufferFlags = + PyBufferFlags(ffi::PyBUF_F_CONTIGUOUS); + /// Request contiguous buffer (C or Fortran) with shape and strides. + pub const ANY_CONTIGUOUS: PyBufferFlags = + PyBufferFlags(ffi::PyBUF_ANY_CONTIGUOUS); + /// Request shape, strides, and suboffsets. + pub const INDIRECT: PyBufferFlags = PyBufferFlags(ffi::PyBUF_INDIRECT); + /// Request writable buffer with shape. + pub const CONTIG: PyBufferFlags = PyBufferFlags(ffi::PyBUF_CONTIG); + /// Request shape (read-only, equivalent to [`Self::ND`]). + pub const CONTIG_RO: PyBufferFlags = PyBufferFlags(ffi::PyBUF_CONTIG_RO); + /// Request writable buffer with shape and strides. + pub const STRIDED: PyBufferFlags = + PyBufferFlags(ffi::PyBUF_STRIDED); + /// Request shape and strides (read-only, equivalent to [`Self::STRIDES`]). + pub const STRIDED_RO: PyBufferFlags = + PyBufferFlags(ffi::PyBUF_STRIDED_RO); + /// Request writable buffer with shape, strides, and format. + pub const RECORDS: PyBufferFlags = + PyBufferFlags(ffi::PyBUF_RECORDS); + /// Request shape, strides, and format. + pub const RECORDS_RO: PyBufferFlags = PyBufferFlags(ffi::PyBUF_RECORDS_RO); + /// Request writable buffer with all information including suboffsets. + pub const FULL: PyBufferFlags = PyBufferFlags(ffi::PyBUF_FULL); + /// Request all buffer information including suboffsets. + pub const FULL_RO: PyBufferFlags = PyBufferFlags(ffi::PyBUF_FULL_RO); } -pub use buffer_info::{FieldInfo, Known, Unknown}; /// A typed form of [`PyUntypedBufferView`]. Not constructible directly — use /// [`PyBufferView::with()`] or [`PyBufferView::with_flags()`]. #[repr(transparent)] -pub struct PyBufferView< - T, - Format: FieldInfo = Known, - Shape: FieldInfo = Known, - Stride: FieldInfo = Known, ->(PyUntypedBufferView, PhantomData<[T]>); +pub struct PyBufferView>( + PyUntypedBufferView, + PhantomData<[T]>, +); /// Stack-allocated untyped buffer view. /// /// Unlike [`PyUntypedBuffer`] which heap-allocates, this places the `Py_buffer` on the /// stack. The scoped closure API ensures the buffer cannot be moved. /// -/// [`with()`](Self::with) requests `PyBUF_FULL_RO` and provides [`format()`](Self::format), -/// [`shape()`](Self::shape), and [`strides()`](Self::strides) accessors. -/// [`with_flags()`](Self::with_flags) accepts arbitrary flags but does not provide those -/// accessors. -pub struct PyUntypedBufferView< - Format: FieldInfo = Unknown, - Shape: FieldInfo = Unknown, - Stride: FieldInfo = Unknown, -> { +/// Use [`with_flags()`](Self::with_flags) with a [`PyBufferFlags`] constant to acquire a view. +/// The available accessors depend on the flags used. +pub struct PyUntypedBufferView { raw: ffi::Py_buffer, - _marker: PhantomData<(Format, Shape, Stride)>, + _flags: PhantomData, } -impl - PyUntypedBufferView -{ +impl PyUntypedBufferView { /// Gets the pointer to the start of the buffer memory. #[inline] pub fn buf_ptr(&self) -> *mut c_void { @@ -906,7 +978,9 @@ impl } } -impl PyUntypedBufferView { +impl + PyUntypedBufferView> +{ /// A [struct module style](https://docs.python.org/3/c-api/buffer.html#c.Py_buffer.format) /// string describing the contents of a single item. #[inline] @@ -916,38 +990,26 @@ impl PyUntypedBufferView(&self) -> PyResult<&PyBufferView> { + pub fn as_typed( + &self, + ) -> PyResult<&PyBufferView>> { self.ensure_compatible_with::()?; // SAFETY: PyBufferView is repr(transparent) around PyUntypedBufferView<..> - let typed = unsafe { + Ok(unsafe { NonNull::from(self) - .cast::>() + .cast::>>() .as_ref() - }; - - Ok(typed) + }) } fn ensure_compatible_with(&self) -> PyResult<()> { - let name = std::any::type_name::(); - - if mem::size_of::() != self.item_size() || !T::is_compatible_format(self.format()) { - return Err(PyBufferError::new_err(format!( - "buffer contents are not compatible with {name}" - ))); - } - - if self.raw.buf.align_offset(mem::align_of::()) != 0 { - return Err(PyBufferError::new_err(format!( - "buffer contents are insufficiently aligned for {name}" - ))); - } - - Ok(()) + check_buffer_compatibility::(self.raw.buf, self.item_size(), self.format()) } } -impl PyUntypedBufferView { +impl + PyUntypedBufferView> +{ /// Returns the shape array. `shape[i]` is the length of dimension `i`. /// /// Despite Python using an array of signed integers, the values are guaranteed to be @@ -960,7 +1022,9 @@ impl PyUntypedBufferView PyUntypedBufferView { +impl + PyUntypedBufferView> +{ /// Returns the strides array. /// /// Stride values can be any integer. For regular arrays, strides are usually positive, @@ -972,48 +1036,59 @@ impl PyUntypedBufferView { - /// Acquire a buffer view with [`ffi::PyBUF_FULL_RO`] flags, - /// pass it to `f`, then release the buffer. - pub fn with( - obj: &Bound<'_, PyAny>, - f: impl FnOnce(&PyUntypedBufferView) -> R, - ) -> PyResult { - let mut raw = mem::MaybeUninit::::uninit(); - - err::error_on_minusone(obj.py(), unsafe { - ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), ffi::PyBUF_FULL_RO) - })?; +/// Check that a buffer is compatible with element type `T`. +fn check_buffer_compatibility( + buf: *mut c_void, + itemsize: usize, + format: &CStr, +) -> PyResult<()> { + let name = std::any::type_name::(); - let view = PyUntypedBufferView { - raw: unsafe { raw.assume_init() }, - _marker: PhantomData, - }; + if mem::size_of::() != itemsize || !T::is_compatible_format(format) { + return Err(PyBufferError::new_err(format!( + "buffer contents are not compatible with {name}" + ))); + } - Ok(f(&view)) + if buf.align_offset(mem::align_of::()) != 0 { + return Err(PyBufferError::new_err(format!( + "buffer contents are insufficiently aligned for {name}" + ))); } + + Ok(()) } -impl PyUntypedBufferView { - /// Acquire a buffer view with arbitrary flags, +impl PyUntypedBufferView { + /// Acquire a buffer view with the given flags, /// pass it to `f`, then release the buffer. /// - /// The `flags` parameter controls which buffer fields are requested from the exporter. - /// Use constants like [`ffi::PyBUF_SIMPLE`], [`ffi::PyBUF_ND`], [`ffi::PyBUF_STRIDES`], etc. - pub fn with_flags( + /// Use predefined flag constants like [`PyBufferFlags::SIMPLE`], [`PyBufferFlags::ND`], + /// [`PyBufferFlags::STRIDES`], [`PyBufferFlags::FULL_RO`], etc. + pub fn with_flags< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const WRITABLE: bool, + const CCONTIGUOUS: bool, + const FCONTIGUOUS: bool, + R, + >( obj: &Bound<'_, PyAny>, - flags: c_int, - f: impl FnOnce(&PyUntypedBufferView) -> R, + flags: PyBufferFlags, + f: impl FnOnce( + &PyUntypedBufferView>, + ) -> R, ) -> PyResult { let mut raw = mem::MaybeUninit::::uninit(); err::error_on_minusone(obj.py(), unsafe { - ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags) + ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags.0) })?; let view = PyUntypedBufferView { raw: unsafe { raw.assume_init() }, - _marker: PhantomData, + _flags: PhantomData, }; Ok(f(&view)) @@ -1049,60 +1124,65 @@ fn debug_buffer_view( .finish() } -impl Drop - for PyUntypedBufferView -{ +impl Drop for PyUntypedBufferView { fn drop(&mut self) { unsafe { ffi::PyBuffer_Release(&mut self.raw) } } } -impl Debug - for PyUntypedBufferView -{ +impl Debug for PyUntypedBufferView { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { debug_buffer_view("PyUntypedBufferView", &self.raw, f) } } -impl PyBufferView { - /// Acquire a typed buffer view with `PyBUF_FULL_RO` flags, +impl PyBufferView { + /// Acquire a typed buffer view with `PyBufferFlags::FULL_RO` flags, /// validating that the buffer format is compatible with `T`. - pub fn with( - obj: &Bound<'_, PyAny>, - f: impl FnOnce(&PyBufferView) -> R, - ) -> PyResult { - PyUntypedBufferView::with(obj, |view| view.as_typed::().map(f))? + pub fn with(obj: &Bound<'_, PyAny>, f: impl FnOnce(&PyBufferView) -> R) -> PyResult { + PyUntypedBufferView::with_flags(obj, PyBufferFlags::FULL_RO, |view| { + view.as_typed::().map(f) + })? } -} -impl PyBufferView { - /// Acquire a typed buffer view with user-specified flags. + /// Acquire a typed buffer view with the given flags. /// - /// `PyBUF_FORMAT` is implicitly added to the flags for type validation. - pub fn with_flags( + /// [`ffi::PyBUF_FORMAT`] is implicitly added for type validation. + pub fn with_flags< + const SHAPE: bool, + const STRIDE: bool, + const WRITABLE: bool, + const CCONTIGUOUS: bool, + const FCONTIGUOUS: bool, + R, + >( obj: &Bound<'_, PyAny>, - flags: c_int, - f: impl FnOnce(&PyBufferView) -> R, + flags: PyBufferFlags, + f: impl FnOnce( + &PyBufferView>, + ) -> R, ) -> PyResult { let mut raw = mem::MaybeUninit::::uninit(); err::error_on_minusone(obj.py(), unsafe { - ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags | ffi::PyBUF_FORMAT) + ffi::PyObject_GetBuffer( + obj.as_ptr(), + raw.as_mut_ptr(), + flags.0 | ffi::PyBUF_FORMAT, + ) })?; - let view = PyUntypedBufferView:: { - raw: unsafe { raw.assume_init() }, - _marker: PhantomData, - }; + let view = + PyUntypedBufferView::> { + raw: unsafe { raw.assume_init() }, + _flags: PhantomData, + }; view.as_typed::().map(f) } } -impl - PyBufferView -{ +impl PyBufferView { /// Gets the buffer memory as a slice. /// /// Returns `None` if the buffer is not C-contiguous. @@ -1132,19 +1212,72 @@ impl } } -impl std::ops::Deref - for PyBufferView +// C-contiguous guaranteed — no contiguity check needed. +impl< + T: Element, + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const WRITABLE: bool, + const FCONTIGUOUS: bool, + > PyBufferView> { - type Target = PyUntypedBufferView; + /// Gets the buffer memory as a slice. The buffer is guaranteed C-contiguous. + pub fn as_contiguous_slice<'a>(&'a self, _py: Python<'a>) -> &'a [ReadOnlyCell] { + unsafe { slice::from_raw_parts(self.0.raw.buf.cast(), self.item_count()) } + } +} + +// C-contiguous + writable guaranteed — no checks needed. +impl + PyBufferView> +{ + /// Gets the buffer memory as a mutable slice. + /// The buffer is guaranteed C-contiguous and writable. + pub fn as_contiguous_mut_slice<'a>(&'a self, _py: Python<'a>) -> &'a [cell::Cell] { + unsafe { slice::from_raw_parts(self.0.raw.buf.cast(), self.item_count()) } + } +} + +// Fortran-contiguous guaranteed — no contiguity check needed. +impl< + T: Element, + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const WRITABLE: bool, + const CCONTIGUOUS: bool, + > PyBufferView> +{ + /// Gets the buffer memory as a slice. The buffer is guaranteed Fortran-contiguous. + pub fn as_fortran_contiguous_slice<'a>(&'a self, _py: Python<'a>) -> &'a [ReadOnlyCell] { + unsafe { slice::from_raw_parts(self.0.raw.buf.cast(), self.item_count()) } + } +} + +// Fortran-contiguous + writable guaranteed — no checks needed. +impl + PyBufferView> +{ + /// Gets the buffer memory as a mutable slice. + /// The buffer is guaranteed Fortran-contiguous and writable. + pub fn as_fortran_contiguous_mut_slice<'a>( + &'a self, + _py: Python<'a>, + ) -> &'a [cell::Cell] { + unsafe { slice::from_raw_parts(self.0.raw.buf.cast(), self.item_count()) } + } +} + +impl std::ops::Deref for PyBufferView { + type Target = PyUntypedBufferView; fn deref(&self) -> &Self::Target { &self.0 } } -impl Debug - for PyBufferView -{ +impl Debug for PyBufferView { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { debug_buffer_view("PyBufferView", &self.0.raw, f) } @@ -1514,14 +1647,14 @@ mod tests { fn test_untyped_buffer_view() { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); - PyUntypedBufferView::with(&bytes, |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::FULL_RO, |view| { assert!(!view.buf_ptr().is_null()); assert_eq!(view.len_bytes(), 5); assert_eq!(view.item_size(), 1); assert_eq!(view.item_count(), 5); assert!(view.readonly()); assert_eq!(view.dimensions(), 1); - // with() uses PyBUF_FULL_RO — all Known, direct return types + // with() uses PyBufferFlags::FULL_RO — all Known, direct return types assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); @@ -1541,7 +1674,7 @@ mod tests { PyBufferView::::with(&bytes, |view| { assert_eq!(view.dimensions(), 1); assert_eq!(view.item_count(), 5); - // PyBufferView::with uses PyBUF_FULL_RO — all Known + // PyBufferView::with uses PyBufferFlags::FULL_RO — all Known assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); @@ -1591,8 +1724,7 @@ mod tests { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); - // with_flags gives all-Unknown — only always-available methods - PyUntypedBufferView::with_flags(&bytes, ffi::PyBUF_ND, |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::SIMPLE, |view| { assert_eq!(view.item_count(), 5); assert_eq!(view.len_bytes(), 5); assert!(view.readonly()); @@ -1600,9 +1732,21 @@ mod tests { }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, ffi::PyBUF_FORMAT, |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::ND, |view| { assert_eq!(view.item_count(), 5); - assert!(view.suboffsets().is_none()); + assert_eq!(view.shape(), [5]); + }) + .unwrap(); + + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::STRIDES, |view| { + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + }) + .unwrap(); + + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::FORMAT, |view| { + assert_eq!(view.item_count(), 5); + assert_eq!(view.format().to_str().unwrap(), "B"); }) .unwrap(); }); @@ -1617,22 +1761,19 @@ mod tests { .call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None) .unwrap(); - PyBufferView::::with_flags( - &array, - ffi::PyBUF_ND, - |view| { - assert_eq!(view.item_count(), 4); - assert_eq!(view.format().to_str().unwrap(), "f"); + PyBufferView::::with_flags(&array, PyBufferFlags::ND, |view| { + assert_eq!(view.item_count(), 4); + assert_eq!(view.format().to_str().unwrap(), "f"); + assert_eq!(view.shape(), [4]); - let slice = view.as_slice(py).unwrap(); - assert_eq!(slice[0].get(), 1.0); - assert_eq!(slice[3].get(), 2.5); + let slice = view.as_slice(py).unwrap(); + assert_eq!(slice[0].get(), 1.0); + assert_eq!(slice[3].get(), 2.5); - let mut_slice = view.as_mut_slice(py).unwrap(); - mut_slice[0].set(9.0); - assert_eq!(slice[0].get(), 9.0); - }, - ) + let mut_slice = view.as_mut_slice(py).unwrap(); + mut_slice[0].set(9.0); + assert_eq!(slice[0].get(), 9.0); + }) .unwrap(); }); } @@ -1641,20 +1782,46 @@ mod tests { fn test_typed_buffer_view_with_flags_incompatible() { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); - let result = PyBufferView::::with_flags( - &bytes, - ffi::PyBUF_ND, - |_view| {}, - ); + let result = PyBufferView::::with_flags(&bytes, PyBufferFlags::ND, |_view| {}); assert!(result.is_err()); }); } + #[test] + fn test_c_contiguous_slice() { + Python::attach(|py| { + let array = py + .import("array") + .unwrap() + .call_method("array", ("f", (1.0, 1.5, 2.0)), None) + .unwrap(); + + // C_CONTIGUOUS: guaranteed contiguous readonly access (no Option) + PyBufferView::::with_flags(&array, PyBufferFlags::C_CONTIGUOUS, |view| { + let slice = view.as_contiguous_slice(py); + assert_eq!(slice.len(), 3); + assert_eq!(slice[0].get(), 1.0); + assert_eq!(slice[2].get(), 2.0); + }) + .unwrap(); + + // C_CONTIGUOUS | WRITABLE (via CONTIG combined with STRIDES-level): + // no predefined constant, but we can use PyBufferView::with on a writable array + // and the Option-based as_mut_slice still works + PyBufferView::::with(&array, |view| { + let mut_slice = view.as_mut_slice(py).unwrap(); + mut_slice[2].set(9.0); + assert_eq!(view.as_slice(py).unwrap()[2].get(), 9.0); + }) + .unwrap(); + }); + } + #[test] fn test_buffer_view_error() { Python::attach(|py| { let list = crate::types::PyList::empty(py); - let result = PyUntypedBufferView::with(&list, |_view| {}); + let result = PyUntypedBufferView::with_flags(&list, PyBufferFlags::FULL_RO, |_view| {}); assert!(result.is_err()); }); } @@ -1665,7 +1832,7 @@ mod tests { let bytes = PyBytes::new(py, b"abcde"); // Debug always uses raw_format/raw_shape/raw_strides (Option in output) - PyUntypedBufferView::with_flags(&bytes, ffi::PyBUF_FULL_RO, |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::FULL_RO, |view| { let expected = format!( concat!( "PyUntypedBufferView {{ buf: {:?}, obj: {:?}, ", From 58ec4b3138f9c564d7276cfa9500dcd76f1a1609 Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Wed, 8 Apr 2026 04:59:12 +0800 Subject: [PATCH 12/17] feat: add flag builder --- src/buffer.rs | 580 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 466 insertions(+), 114 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index f276af48cdb..d50051c85d2 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -804,7 +804,8 @@ mod py_buffer_flags_sealed { const WRITABLE: bool, const C_CONTIGUOUS: bool, const F_CONTIGUOUS: bool, - > Sealed for super::PyBufferFlags + > Sealed + for super::PyBufferFlags { } } @@ -819,72 +820,144 @@ impl< const WRITABLE: bool, const C_CONTIGUOUS: bool, const F_CONTIGUOUS: bool, - > PyBufferFlagsType for PyBufferFlags + > PyBufferFlagsType + for PyBufferFlags { } -// PyBufferFlags::FORMAT | -impl - std::ops::BitOr> - for PyBufferFlags +// Builder methods for composing flags. + +impl< + const SHAPE: bool, + const STRIDE: bool, + const WRITABLE: bool, + const C_CONTIGUOUS: bool, + const F_CONTIGUOUS: bool, + > PyBufferFlags { - type Output = PyBufferFlags; - fn bitor(self, rhs: PyBufferFlags) -> Self::Output { - PyBufferFlags(self.0 | rhs.0) + /// Request format information. + pub fn format( + self, + ) -> PyBufferFlags { + PyBufferFlags(self.0 | ffi::PyBUF_FORMAT) } } -// | PyBufferFlags::FORMAT -impl - std::ops::BitOr> - for PyBufferFlags +impl< + const FORMAT: bool, + const STRIDE: bool, + const WRITABLE: bool, + const C_CONTIGUOUS: bool, + const F_CONTIGUOUS: bool, + > PyBufferFlags { - type Output = PyBufferFlags; - fn bitor(self, rhs: PyBufferFlags) -> Self::Output { - PyBufferFlags(self.0 | rhs.0) + /// Request shape information. + pub fn nd(self) -> PyBufferFlags { + PyBufferFlags(self.0 | ffi::PyBUF_ND) } } -#[allow(non_upper_case_globals)] -impl PyBufferFlags { - /// Request a simple buffer with no shape, strides, or format information. - pub const SIMPLE: PyBufferFlags = PyBufferFlags(ffi::PyBUF_SIMPLE); - /// Request format information only. - pub const FORMAT: PyBufferFlags = PyBufferFlags(ffi::PyBUF_FORMAT); - /// Request shape information. - pub const ND: PyBufferFlags = PyBufferFlags(ffi::PyBUF_ND); - /// Request shape and strides. - pub const STRIDES: PyBufferFlags = PyBufferFlags(ffi::PyBUF_STRIDES); - /// Request C-contiguous buffer with shape and strides. - pub const C_CONTIGUOUS: PyBufferFlags = - PyBufferFlags(ffi::PyBUF_C_CONTIGUOUS); - /// Request Fortran-contiguous buffer with shape and strides. - pub const F_CONTIGUOUS: PyBufferFlags = - PyBufferFlags(ffi::PyBUF_F_CONTIGUOUS); - /// Request contiguous buffer (C or Fortran) with shape and strides. - pub const ANY_CONTIGUOUS: PyBufferFlags = - PyBufferFlags(ffi::PyBUF_ANY_CONTIGUOUS); - /// Request shape, strides, and suboffsets. - pub const INDIRECT: PyBufferFlags = PyBufferFlags(ffi::PyBUF_INDIRECT); - /// Request writable buffer with shape. - pub const CONTIG: PyBufferFlags = PyBufferFlags(ffi::PyBUF_CONTIG); - /// Request shape (read-only, equivalent to [`Self::ND`]). - pub const CONTIG_RO: PyBufferFlags = PyBufferFlags(ffi::PyBUF_CONTIG_RO); - /// Request writable buffer with shape and strides. - pub const STRIDED: PyBufferFlags = - PyBufferFlags(ffi::PyBUF_STRIDED); - /// Request shape and strides (read-only, equivalent to [`Self::STRIDES`]). - pub const STRIDED_RO: PyBufferFlags = - PyBufferFlags(ffi::PyBUF_STRIDED_RO); - /// Request writable buffer with shape, strides, and format. - pub const RECORDS: PyBufferFlags = - PyBufferFlags(ffi::PyBUF_RECORDS); - /// Request shape, strides, and format. - pub const RECORDS_RO: PyBufferFlags = PyBufferFlags(ffi::PyBUF_RECORDS_RO); - /// Request writable buffer with all information including suboffsets. - pub const FULL: PyBufferFlags = PyBufferFlags(ffi::PyBUF_FULL); +impl< + const FORMAT: bool, + const SHAPE: bool, + const WRITABLE: bool, + const C_CONTIGUOUS: bool, + const F_CONTIGUOUS: bool, + > PyBufferFlags +{ + /// Request strides information. Implies shape. + pub fn strides( + self, + ) -> PyBufferFlags { + PyBufferFlags(self.0 | ffi::PyBUF_STRIDES) + } + + /// Request suboffsets (indirect). Implies shape and strides. + pub fn indirect( + self, + ) -> PyBufferFlags { + PyBufferFlags(self.0 | ffi::PyBUF_INDIRECT) + } +} + +impl< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const C_CONTIGUOUS: bool, + const F_CONTIGUOUS: bool, + > PyBufferFlags +{ + /// Request a writable buffer. + pub fn writable( + self, + ) -> PyBufferFlags { + PyBufferFlags(self.0 | ffi::PyBUF_WRITABLE) + } +} + +impl< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const WRITABLE: bool, + const F_CONTIGUOUS: bool, + > PyBufferFlags +{ + /// Require C-contiguous layout. Implies shape and strides. + pub fn c_contiguous(self) -> PyBufferFlags { + PyBufferFlags(self.0 | ffi::PyBUF_C_CONTIGUOUS) + } +} + +impl< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const WRITABLE: bool, + const C_CONTIGUOUS: bool, + > PyBufferFlags +{ + /// Require Fortran-contiguous layout. Implies shape and strides. + pub fn f_contiguous(self) -> PyBufferFlags { + PyBufferFlags(self.0 | ffi::PyBUF_F_CONTIGUOUS) + } +} + +impl + PyBufferFlags +{ + /// Require contiguous layout (C or Fortran). Implies shape and strides. + /// + /// The specific contiguity order is not known at compile time, + /// so this does not unlock non-Option slice accessors. + pub fn any_contiguous(self) -> PyBufferFlags { + PyBufferFlags(self.0 | ffi::PyBUF_ANY_CONTIGUOUS) + } +} + +// Requires FORMAT=false, SHAPE=false, STRIDE=false +impl + PyBufferFlags +{ /// Request all buffer information including suboffsets. - pub const FULL_RO: PyBufferFlags = PyBufferFlags(ffi::PyBUF_FULL_RO); + /// Implies format, shape, and strides. Chain `.writable()` for write access. + pub fn full(self) -> PyBufferFlags { + PyBufferFlags(self.0 | ffi::PyBUF_FULL_RO) + } + + /// Request format, shape, and strides. + /// Chain `.writable()` for write access. + pub fn records(self) -> PyBufferFlags { + PyBufferFlags(self.0 | ffi::PyBUF_RECORDS_RO) + } +} + +impl PyBufferFlags { + /// Create a base buffer request. Chain builder methods to add flags. + pub fn simple() -> PyBufferFlags { + PyBufferFlags(ffi::PyBUF_SIMPLE) + } } /// A typed form of [`PyUntypedBufferView`]. Not constructible directly — use @@ -978,8 +1051,14 @@ impl PyUntypedBufferView { } } -impl - PyUntypedBufferView> +impl< + const SHAPE: bool, + const STRIDE: bool, + const WRITABLE: bool, + const C_CONTIGUOUS: bool, + const F_CONTIGUOUS: bool, + > + PyUntypedBufferView> { /// A [struct module style](https://docs.python.org/3/c-api/buffer.html#c.Py_buffer.format) /// string describing the contents of a single item. @@ -992,12 +1071,17 @@ impl( &self, - ) -> PyResult<&PyBufferView>> { + ) -> PyResult< + &PyBufferView>, + > { self.ensure_compatible_with::()?; // SAFETY: PyBufferView is repr(transparent) around PyUntypedBufferView<..> Ok(unsafe { NonNull::from(self) - .cast::>>() + .cast::, + >>() .as_ref() }) } @@ -1007,8 +1091,14 @@ impl - PyUntypedBufferView> +impl< + const FORMAT: bool, + const STRIDE: bool, + const WRITABLE: bool, + const C_CONTIGUOUS: bool, + const F_CONTIGUOUS: bool, + > + PyUntypedBufferView> { /// Returns the shape array. `shape[i]` is the length of dimension `i`. /// @@ -1022,8 +1112,14 @@ impl - PyUntypedBufferView> +impl< + const FORMAT: bool, + const SHAPE: bool, + const WRITABLE: bool, + const C_CONTIGUOUS: bool, + const F_CONTIGUOUS: bool, + > + PyUntypedBufferView> { /// Returns the strides array. /// @@ -1036,6 +1132,15 @@ impl { + /// Returns the format string for a simple buffer, which is always `"B"`. + #[inline] + pub fn format(&self) -> &CStr { + ffi::c_str!("B") + } +} + /// Check that a buffer is compatible with element type `T`. fn check_buffer_compatibility( buf: *mut c_void, @@ -1063,21 +1168,23 @@ impl PyUntypedBufferView { /// Acquire a buffer view with the given flags, /// pass it to `f`, then release the buffer. /// - /// Use predefined flag constants like [`PyBufferFlags::SIMPLE`], [`PyBufferFlags::ND`], - /// [`PyBufferFlags::STRIDES`], [`PyBufferFlags::FULL_RO`], etc. + /// Use [`PyBufferFlags::simple()`] to compose flags, e.g. + /// `PyBufferFlags::simple().strides()`, `PyBufferFlags::simple().full()`, etc. pub fn with_flags< const FORMAT: bool, const SHAPE: bool, const STRIDE: bool, const WRITABLE: bool, - const CCONTIGUOUS: bool, - const FCONTIGUOUS: bool, + const C_CONTIGUOUS: bool, + const F_CONTIGUOUS: bool, R, >( obj: &Bound<'_, PyAny>, - flags: PyBufferFlags, + flags: PyBufferFlags, f: impl FnOnce( - &PyUntypedBufferView>, + &PyUntypedBufferView< + PyBufferFlags, + >, ) -> R, ) -> PyResult { let mut raw = mem::MaybeUninit::::uninit(); @@ -1137,10 +1244,10 @@ impl Debug for PyUntypedBufferView { } impl PyBufferView { - /// Acquire a typed buffer view with `PyBufferFlags::FULL_RO` flags, + /// Acquire a typed buffer view with `PyBufferFlags::simple().full()` flags, /// validating that the buffer format is compatible with `T`. pub fn with(obj: &Bound<'_, PyAny>, f: impl FnOnce(&PyBufferView) -> R) -> PyResult { - PyUntypedBufferView::with_flags(obj, PyBufferFlags::FULL_RO, |view| { + PyUntypedBufferView::with_flags(obj, PyBufferFlags::simple().full(), |view| { view.as_typed::().map(f) })? } @@ -1152,31 +1259,31 @@ impl PyBufferView { const SHAPE: bool, const STRIDE: bool, const WRITABLE: bool, - const CCONTIGUOUS: bool, - const FCONTIGUOUS: bool, + const C_CONTIGUOUS: bool, + const F_CONTIGUOUS: bool, R, >( obj: &Bound<'_, PyAny>, - flags: PyBufferFlags, + flags: PyBufferFlags, f: impl FnOnce( - &PyBufferView>, + &PyBufferView< + T, + PyBufferFlags, + >, ) -> R, ) -> PyResult { let mut raw = mem::MaybeUninit::::uninit(); err::error_on_minusone(obj.py(), unsafe { - ffi::PyObject_GetBuffer( - obj.as_ptr(), - raw.as_mut_ptr(), - flags.0 | ffi::PyBUF_FORMAT, - ) + ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags.0 | ffi::PyBUF_FORMAT) })?; - let view = - PyUntypedBufferView::> { - raw: unsafe { raw.assume_init() }, - _flags: PhantomData, - }; + let view = PyUntypedBufferView::< + PyBufferFlags, + > { + raw: unsafe { raw.assume_init() }, + _flags: PhantomData, + }; view.as_typed::().map(f) } @@ -1219,8 +1326,8 @@ impl< const SHAPE: bool, const STRIDE: bool, const WRITABLE: bool, - const FCONTIGUOUS: bool, - > PyBufferView> + const F_CONTIGUOUS: bool, + > PyBufferView> { /// Gets the buffer memory as a slice. The buffer is guaranteed C-contiguous. pub fn as_contiguous_slice<'a>(&'a self, _py: Python<'a>) -> &'a [ReadOnlyCell] { @@ -1229,8 +1336,13 @@ impl< } // C-contiguous + writable guaranteed — no checks needed. -impl - PyBufferView> +impl< + T: Element, + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const F_CONTIGUOUS: bool, + > PyBufferView> { /// Gets the buffer memory as a mutable slice. /// The buffer is guaranteed C-contiguous and writable. @@ -1246,8 +1358,8 @@ impl< const SHAPE: bool, const STRIDE: bool, const WRITABLE: bool, - const CCONTIGUOUS: bool, - > PyBufferView> + const C_CONTIGUOUS: bool, + > PyBufferView> { /// Gets the buffer memory as a slice. The buffer is guaranteed Fortran-contiguous. pub fn as_fortran_contiguous_slice<'a>(&'a self, _py: Python<'a>) -> &'a [ReadOnlyCell] { @@ -1256,15 +1368,17 @@ impl< } // Fortran-contiguous + writable guaranteed — no checks needed. -impl - PyBufferView> +impl< + T: Element, + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const C_CONTIGUOUS: bool, + > PyBufferView> { /// Gets the buffer memory as a mutable slice. /// The buffer is guaranteed Fortran-contiguous and writable. - pub fn as_fortran_contiguous_mut_slice<'a>( - &'a self, - _py: Python<'a>, - ) -> &'a [cell::Cell] { + pub fn as_fortran_contiguous_mut_slice<'a>(&'a self, _py: Python<'a>) -> &'a [cell::Cell] { unsafe { slice::from_raw_parts(self.0.raw.buf.cast(), self.item_count()) } } } @@ -1647,14 +1761,14 @@ mod tests { fn test_untyped_buffer_view() { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::FULL_RO, |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().full(), |view| { assert!(!view.buf_ptr().is_null()); assert_eq!(view.len_bytes(), 5); assert_eq!(view.item_size(), 1); assert_eq!(view.item_count(), 5); assert!(view.readonly()); assert_eq!(view.dimensions(), 1); - // with() uses PyBufferFlags::FULL_RO — all Known, direct return types + // with() uses PyBufferFlags::simple().full() — all Known, direct return types assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); @@ -1674,7 +1788,7 @@ mod tests { PyBufferView::::with(&bytes, |view| { assert_eq!(view.dimensions(), 1); assert_eq!(view.item_count(), 5); - // PyBufferView::with uses PyBufferFlags::FULL_RO — all Known + // PyBufferView::with uses PyBufferFlags::simple().full() — all Known assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); @@ -1724,7 +1838,7 @@ mod tests { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::SIMPLE, |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple(), |view| { assert_eq!(view.item_count(), 5); assert_eq!(view.len_bytes(), 5); assert!(view.readonly()); @@ -1732,19 +1846,19 @@ mod tests { }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::ND, |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().nd(), |view| { assert_eq!(view.item_count(), 5); assert_eq!(view.shape(), [5]); }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::STRIDES, |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().strides(), |view| { assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::FORMAT, |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().format(), |view| { assert_eq!(view.item_count(), 5); assert_eq!(view.format().to_str().unwrap(), "B"); }) @@ -1761,7 +1875,7 @@ mod tests { .call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None) .unwrap(); - PyBufferView::::with_flags(&array, PyBufferFlags::ND, |view| { + PyBufferView::::with_flags(&array, PyBufferFlags::simple().nd(), |view| { assert_eq!(view.item_count(), 4); assert_eq!(view.format().to_str().unwrap(), "f"); assert_eq!(view.shape(), [4]); @@ -1782,7 +1896,8 @@ mod tests { fn test_typed_buffer_view_with_flags_incompatible() { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); - let result = PyBufferView::::with_flags(&bytes, PyBufferFlags::ND, |_view| {}); + let result = + PyBufferView::::with_flags(&bytes, PyBufferFlags::simple().nd(), |_view| {}); assert!(result.is_err()); }); } @@ -1797,12 +1912,16 @@ mod tests { .unwrap(); // C_CONTIGUOUS: guaranteed contiguous readonly access (no Option) - PyBufferView::::with_flags(&array, PyBufferFlags::C_CONTIGUOUS, |view| { - let slice = view.as_contiguous_slice(py); - assert_eq!(slice.len(), 3); - assert_eq!(slice[0].get(), 1.0); - assert_eq!(slice[2].get(), 2.0); - }) + PyBufferView::::with_flags( + &array, + PyBufferFlags::simple().c_contiguous(), + |view| { + let slice = view.as_contiguous_slice(py); + assert_eq!(slice.len(), 3); + assert_eq!(slice[0].get(), 1.0); + assert_eq!(slice[2].get(), 2.0); + }, + ) .unwrap(); // C_CONTIGUOUS | WRITABLE (via CONTIG combined with STRIDES-level): @@ -1821,18 +1940,251 @@ mod tests { fn test_buffer_view_error() { Python::attach(|py| { let list = crate::types::PyList::empty(py); - let result = PyUntypedBufferView::with_flags(&list, PyBufferFlags::FULL_RO, |_view| {}); + let result = + PyUntypedBufferView::with_flags(&list, PyBufferFlags::simple().full(), |_view| {}); assert!(result.is_err()); }); } + #[test] + fn test_flag_builders() { + Python::attach(|py| { + let bytes = PyBytes::new(py, b"abcde"); + let array = py + .import("array") + .unwrap() + .call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None) + .unwrap(); + + // Primitive builders + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().format(), |view| { + assert_eq!(view.format().to_str().unwrap(), "B"); + }) + .unwrap(); + + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().nd(), |view| { + assert_eq!(view.shape(), [5]); + }) + .unwrap(); + + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().strides(), |view| { + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + }) + .unwrap(); + + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().indirect(), |view| { + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + }) + .unwrap(); + + PyUntypedBufferView::with_flags( + &array, + PyBufferFlags::simple().writable().nd(), + |view| { + assert_eq!(view.shape(), [4]); + assert!(!view.readonly()); + }, + ) + .unwrap(); + + // Chained primitive builders + PyUntypedBufferView::with_flags( + &bytes, + PyBufferFlags::simple().nd().format(), + |view| { + assert_eq!(view.shape(), [5]); + assert_eq!(view.format().to_str().unwrap(), "B"); + }, + ) + .unwrap(); + + PyUntypedBufferView::with_flags( + &bytes, + PyBufferFlags::simple().strides().format(), + |view| { + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + assert_eq!(view.format().to_str().unwrap(), "B"); + }, + ) + .unwrap(); + + // Contiguity builders + PyUntypedBufferView::with_flags( + &bytes, + PyBufferFlags::simple().c_contiguous(), + |view| { + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + }, + ) + .unwrap(); + + PyUntypedBufferView::with_flags( + &bytes, + PyBufferFlags::simple().f_contiguous(), + |view| { + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + }, + ) + .unwrap(); + + PyUntypedBufferView::with_flags( + &bytes, + PyBufferFlags::simple().any_contiguous(), + |view| { + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + }, + ) + .unwrap(); + + // Compound builders (read-only) + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().full(), |view| { + assert_eq!(view.format().to_str().unwrap(), "B"); + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + }) + .unwrap(); + + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().records(), |view| { + assert_eq!(view.format().to_str().unwrap(), "B"); + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + }) + .unwrap(); + + // Compound builders + .writable() + PyUntypedBufferView::with_flags( + &array, + PyBufferFlags::simple().full().writable(), + |view| { + assert_eq!(view.format().to_str().unwrap(), "f"); + assert_eq!(view.shape(), [4]); + assert_eq!(view.strides(), [4]); + assert!(!view.readonly()); + }, + ) + .unwrap(); + + PyUntypedBufferView::with_flags( + &array, + PyBufferFlags::simple().records().writable(), + |view| { + assert_eq!(view.format().to_str().unwrap(), "f"); + assert_eq!(view.shape(), [4]); + assert!(!view.readonly()); + }, + ) + .unwrap(); + + PyUntypedBufferView::with_flags( + &array, + PyBufferFlags::simple().strides().writable(), + |view| { + assert_eq!(view.shape(), [4]); + assert_eq!(view.strides(), [4]); + assert!(!view.readonly()); + }, + ) + .unwrap(); + + PyUntypedBufferView::with_flags( + &array, + PyBufferFlags::simple().nd().writable(), + |view| { + assert_eq!(view.shape(), [4]); + assert!(!view.readonly()); + }, + ) + .unwrap(); + + // Compound + contiguity + PyUntypedBufferView::with_flags( + &bytes, + PyBufferFlags::simple().full().c_contiguous(), + |view| { + assert_eq!(view.format().to_str().unwrap(), "B"); + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + }, + ) + .unwrap(); + + PyUntypedBufferView::with_flags( + &array, + PyBufferFlags::simple().full().writable().c_contiguous(), + |view| { + assert_eq!(view.format().to_str().unwrap(), "f"); + assert!(!view.readonly()); + }, + ) + .unwrap(); + + PyUntypedBufferView::with_flags( + &bytes, + PyBufferFlags::simple().strides().format(), + |view| { + assert_eq!(view.format().to_str().unwrap(), "B"); + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + }, + ) + .unwrap(); + + PyUntypedBufferView::with_flags( + &bytes, + PyBufferFlags::simple().c_contiguous().format(), + |view| { + assert_eq!(view.format().to_str().unwrap(), "B"); + assert_eq!(view.shape(), [5]); + }, + ) + .unwrap(); + + // Contiguity builder on typed view + PyBufferView::::with_flags( + &array, + PyBufferFlags::simple().c_contiguous(), + |view| { + let slice = view.as_contiguous_slice(py); + assert_eq!(slice[0].get(), 1.0); + }, + ) + .unwrap(); + + // Writable + contiguity on typed view + PyBufferView::::with_flags( + &array, + PyBufferFlags::simple().c_contiguous().writable(), + |view| { + let slice = view.as_contiguous_slice(py); + assert_eq!(slice[0].get(), 1.0); + let mut_slice = view.as_contiguous_mut_slice(py); + mut_slice[0].set(9.0); + assert_eq!(slice[0].get(), 9.0); + }, + ) + .unwrap(); + + // SIMPLE format() returns "B" + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple(), |view| { + assert_eq!(view.format().to_str().unwrap(), "B"); + }) + .unwrap(); + }); + } + #[test] fn test_buffer_view_debug() { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); // Debug always uses raw_format/raw_shape/raw_strides (Option in output) - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::FULL_RO, |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().full(), |view| { let expected = format!( concat!( "PyUntypedBufferView {{ buf: {:?}, obj: {:?}, ", From 76df0d19832d3d0f99e50de4fea886100cba742c Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Sun, 12 Apr 2026 22:39:11 +0800 Subject: [PATCH 13/17] refactor: apply suggestions --- src/buffer.rs | 452 ++++++++++++++++++++++---------------------------- 1 file changed, 199 insertions(+), 253 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index d50051c85d2..3cc01c237a8 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -784,6 +784,19 @@ impl_element!(isize, SignedInteger); impl_element!(f32, Float); impl_element!(f64, Float); +#[repr(u8)] +enum PyBufferContiguity { + Undefined = 0, + C = 1, + F = 2, + Any = 3, +} + +const CONTIGUITY_UNDEFINED: u8 = PyBufferContiguity::Undefined as u8; +const CONTIGUITY_C: u8 = PyBufferContiguity::C as u8; +const CONTIGUITY_F: u8 = PyBufferContiguity::F as u8; +const CONTIGUITY_ANY: u8 = PyBufferContiguity::Any as u8; + /// Type-safe buffer request flags. The const parameters encode which fields /// the exporter is required to fill. pub struct PyBufferFlags< @@ -791,8 +804,7 @@ pub struct PyBufferFlags< const SHAPE: bool = false, const STRIDE: bool = false, const WRITABLE: bool = false, - const C_CONTIGUOUS: bool = false, - const F_CONTIGUOUS: bool = false, + const CONTIGUITY: u8 = CONTIGUITY_UNDEFINED, >(c_int); mod py_buffer_flags_sealed { @@ -802,161 +814,138 @@ mod py_buffer_flags_sealed { const SHAPE: bool, const STRIDE: bool, const WRITABLE: bool, - const C_CONTIGUOUS: bool, - const F_CONTIGUOUS: bool, - > Sealed - for super::PyBufferFlags + const CONTIGUITY: u8, + > Sealed for super::PyBufferFlags { } } /// Trait implemented by all [`PyBufferFlags`] instantiations. -pub trait PyBufferFlagsType: py_buffer_flags_sealed::Sealed {} +pub trait PyBufferFlagsType: py_buffer_flags_sealed::Sealed { + /// The contiguity requirement encoded by these flags. + const CONTIGUITY: u8; +} impl< const FORMAT: bool, const SHAPE: bool, const STRIDE: bool, const WRITABLE: bool, - const C_CONTIGUOUS: bool, - const F_CONTIGUOUS: bool, - > PyBufferFlagsType - for PyBufferFlags + const CONTIGUITY_REQ: u8, + > PyBufferFlagsType for PyBufferFlags { + const CONTIGUITY: u8 = CONTIGUITY_REQ; } -// Builder methods for composing flags. - -impl< - const SHAPE: bool, - const STRIDE: bool, - const WRITABLE: bool, - const C_CONTIGUOUS: bool, - const F_CONTIGUOUS: bool, - > PyBufferFlags +impl + PyBufferFlags { /// Request format information. - pub fn format( - self, - ) -> PyBufferFlags { + pub const fn format(self) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_FORMAT) } } -impl< - const FORMAT: bool, - const STRIDE: bool, - const WRITABLE: bool, - const C_CONTIGUOUS: bool, - const F_CONTIGUOUS: bool, - > PyBufferFlags +impl + PyBufferFlags { /// Request shape information. - pub fn nd(self) -> PyBufferFlags { + pub const fn nd(self) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_ND) } } -impl< - const FORMAT: bool, - const SHAPE: bool, - const WRITABLE: bool, - const C_CONTIGUOUS: bool, - const F_CONTIGUOUS: bool, - > PyBufferFlags +impl + PyBufferFlags { /// Request strides information. Implies shape. - pub fn strides( - self, - ) -> PyBufferFlags { + pub const fn strides(self) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_STRIDES) } /// Request suboffsets (indirect). Implies shape and strides. - pub fn indirect( - self, - ) -> PyBufferFlags { + pub const fn indirect(self) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_INDIRECT) } } -impl< - const FORMAT: bool, - const SHAPE: bool, - const STRIDE: bool, - const C_CONTIGUOUS: bool, - const F_CONTIGUOUS: bool, - > PyBufferFlags +impl + PyBufferFlags { /// Request a writable buffer. - pub fn writable( - self, - ) -> PyBufferFlags { + pub const fn writable(self) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_WRITABLE) } } -impl< - const FORMAT: bool, - const SHAPE: bool, - const STRIDE: bool, - const WRITABLE: bool, - const F_CONTIGUOUS: bool, - > PyBufferFlags +impl + PyBufferFlags { /// Require C-contiguous layout. Implies shape and strides. - pub fn c_contiguous(self) -> PyBufferFlags { + pub const fn c_contiguous(self) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_C_CONTIGUOUS) } -} -impl< - const FORMAT: bool, - const SHAPE: bool, - const STRIDE: bool, - const WRITABLE: bool, - const C_CONTIGUOUS: bool, - > PyBufferFlags -{ /// Require Fortran-contiguous layout. Implies shape and strides. - pub fn f_contiguous(self) -> PyBufferFlags { + pub const fn f_contiguous(self) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_F_CONTIGUOUS) } -} -impl - PyBufferFlags -{ /// Require contiguous layout (C or Fortran). Implies shape and strides. /// /// The specific contiguity order is not known at compile time, /// so this does not unlock non-Option slice accessors. - pub fn any_contiguous(self) -> PyBufferFlags { + pub const fn any_contiguous( + self, + ) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_ANY_CONTIGUOUS) } } -// Requires FORMAT=false, SHAPE=false, STRIDE=false -impl - PyBufferFlags -{ - /// Request all buffer information including suboffsets. - /// Implies format, shape, and strides. Chain `.writable()` for write access. - pub fn full(self) -> PyBufferFlags { - PyBufferFlags(self.0 | ffi::PyBUF_FULL_RO) +impl PyBufferFlags { + /// Create a base buffer request. Chain builder methods to add flags. + pub const fn simple() -> PyBufferFlags { + PyBufferFlags(ffi::PyBUF_SIMPLE) } - /// Request format, shape, and strides. - /// Chain `.writable()` for write access. - pub fn records(self) -> PyBufferFlags { - PyBufferFlags(self.0 | ffi::PyBUF_RECORDS_RO) + /// Create a writable request for all buffer information including suboffsets. + pub const fn full() -> PyBufferFlags { + PyBufferFlags(ffi::PyBUF_FULL) } -} -impl PyBufferFlags { - /// Create a base buffer request. Chain builder methods to add flags. - pub fn simple() -> PyBufferFlags { - PyBufferFlags(ffi::PyBUF_SIMPLE) + /// Create a read-only request for all buffer information including suboffsets. + pub const fn full_ro() -> PyBufferFlags { + PyBufferFlags(ffi::PyBUF_FULL_RO) + } + + /// Create a writable request for format, shape, and strides. + pub const fn records() -> PyBufferFlags { + PyBufferFlags(ffi::PyBUF_RECORDS) + } + + /// Create a read-only request for format, shape, and strides. + pub const fn records_ro() -> PyBufferFlags { + PyBufferFlags(ffi::PyBUF_RECORDS_RO) + } + + /// Create a writable request for shape and strides. + pub const fn strided() -> PyBufferFlags { + PyBufferFlags(ffi::PyBUF_STRIDED) + } + + /// Create a read-only request for shape and strides. + pub const fn strided_ro() -> PyBufferFlags { + PyBufferFlags(ffi::PyBUF_STRIDED_RO) + } + + /// Create a writable C-contiguous request. + pub const fn contig() -> PyBufferFlags { + PyBufferFlags(ffi::PyBUF_CONTIG) + } + + /// Create a read-only C-contiguous request. + pub const fn contig_ro() -> PyBufferFlags { + PyBufferFlags(ffi::PyBUF_CONTIG_RO) } } @@ -973,7 +962,7 @@ pub struct PyBufferView { raw: ffi::Py_buffer, @@ -1041,24 +1030,20 @@ impl PyUntypedBufferView { /// Gets whether the buffer is contiguous in C-style order. #[inline] pub fn is_c_contiguous(&self) -> bool { - unsafe { ffi::PyBuffer_IsContiguous(&self.raw, b'C' as std::ffi::c_char) != 0 } + Flags::CONTIGUITY == CONTIGUITY_C + || unsafe { ffi::PyBuffer_IsContiguous(&self.raw, b'C' as std::ffi::c_char) != 0 } } /// Gets whether the buffer is contiguous in Fortran-style order. #[inline] pub fn is_fortran_contiguous(&self) -> bool { - unsafe { ffi::PyBuffer_IsContiguous(&self.raw, b'F' as std::ffi::c_char) != 0 } + Flags::CONTIGUITY == CONTIGUITY_F + || unsafe { ffi::PyBuffer_IsContiguous(&self.raw, b'F' as std::ffi::c_char) != 0 } } } -impl< - const SHAPE: bool, - const STRIDE: bool, - const WRITABLE: bool, - const C_CONTIGUOUS: bool, - const F_CONTIGUOUS: bool, - > - PyUntypedBufferView> +impl + PyUntypedBufferView> { /// A [struct module style](https://docs.python.org/3/c-api/buffer.html#c.Py_buffer.format) /// string describing the contents of a single item. @@ -1071,17 +1056,12 @@ impl< /// Attempt to interpret this untyped view as containing elements of type `T`. pub fn as_typed( &self, - ) -> PyResult< - &PyBufferView>, - > { + ) -> PyResult<&PyBufferView>> { self.ensure_compatible_with::()?; // SAFETY: PyBufferView is repr(transparent) around PyUntypedBufferView<..> Ok(unsafe { NonNull::from(self) - .cast::, - >>() + .cast::>>() .as_ref() }) } @@ -1091,14 +1071,8 @@ impl< } } -impl< - const FORMAT: bool, - const STRIDE: bool, - const WRITABLE: bool, - const C_CONTIGUOUS: bool, - const F_CONTIGUOUS: bool, - > - PyUntypedBufferView> +impl + PyUntypedBufferView> { /// Returns the shape array. `shape[i]` is the length of dimension `i`. /// @@ -1112,14 +1086,8 @@ impl< } } -impl< - const FORMAT: bool, - const SHAPE: bool, - const WRITABLE: bool, - const C_CONTIGUOUS: bool, - const F_CONTIGUOUS: bool, - > - PyUntypedBufferView> +impl + PyUntypedBufferView> { /// Returns the strides array. /// @@ -1132,9 +1100,11 @@ impl< } } -// SIMPLE: format is guaranteed to be "B" (unsigned bytes) by the buffer protocol. -impl PyUntypedBufferView { - /// Returns the format string for a simple buffer, which is always `"B"`. +// SIMPLE and WRITABLE requests guarantee the implicit "B" format. +impl + PyUntypedBufferView> +{ + /// Returns the format string for a simple byte buffer, which is always `"B"`. #[inline] pub fn format(&self) -> &CStr { ffi::c_str!("B") @@ -1168,23 +1138,20 @@ impl PyUntypedBufferView { /// Acquire a buffer view with the given flags, /// pass it to `f`, then release the buffer. /// - /// Use [`PyBufferFlags::simple()`] to compose flags, e.g. - /// `PyBufferFlags::simple().strides()`, `PyBufferFlags::simple().full()`, etc. + /// Use [`PyBufferFlags::simple()`] or one of the compound-request constructors such as + /// [`PyBufferFlags::full_ro()`] to acquire a view. pub fn with_flags< const FORMAT: bool, const SHAPE: bool, const STRIDE: bool, const WRITABLE: bool, - const C_CONTIGUOUS: bool, - const F_CONTIGUOUS: bool, + const CONTIGUITY: u8, R, >( obj: &Bound<'_, PyAny>, - flags: PyBufferFlags, + flags: PyBufferFlags, f: impl FnOnce( - &PyUntypedBufferView< - PyBufferFlags, - >, + &PyUntypedBufferView>, ) -> R, ) -> PyResult { let mut raw = mem::MaybeUninit::::uninit(); @@ -1244,10 +1211,10 @@ impl Debug for PyUntypedBufferView { } impl PyBufferView { - /// Acquire a typed buffer view with `PyBufferFlags::simple().full()` flags, + /// Acquire a typed buffer view with `PyBufferFlags::full_ro()` flags, /// validating that the buffer format is compatible with `T`. pub fn with(obj: &Bound<'_, PyAny>, f: impl FnOnce(&PyBufferView) -> R) -> PyResult { - PyUntypedBufferView::with_flags(obj, PyBufferFlags::simple().full(), |view| { + PyUntypedBufferView::with_flags(obj, PyBufferFlags::full_ro(), |view| { view.as_typed::().map(f) })? } @@ -1256,21 +1223,16 @@ impl PyBufferView { /// /// [`ffi::PyBUF_FORMAT`] is implicitly added for type validation. pub fn with_flags< + const FORMAT: bool, const SHAPE: bool, const STRIDE: bool, const WRITABLE: bool, - const C_CONTIGUOUS: bool, - const F_CONTIGUOUS: bool, + const CONTIGUITY: u8, R, >( obj: &Bound<'_, PyAny>, - flags: PyBufferFlags, - f: impl FnOnce( - &PyBufferView< - T, - PyBufferFlags, - >, - ) -> R, + flags: PyBufferFlags, + f: impl FnOnce(&PyBufferView>) -> R, ) -> PyResult { let mut raw = mem::MaybeUninit::::uninit(); @@ -1278,9 +1240,7 @@ impl PyBufferView { ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags.0 | ffi::PyBUF_FORMAT) })?; - let view = PyUntypedBufferView::< - PyBufferFlags, - > { + let view = PyUntypedBufferView::> { raw: unsafe { raw.assume_init() }, _flags: PhantomData, }; @@ -1326,8 +1286,7 @@ impl< const SHAPE: bool, const STRIDE: bool, const WRITABLE: bool, - const F_CONTIGUOUS: bool, - > PyBufferView> + > PyBufferView> { /// Gets the buffer memory as a slice. The buffer is guaranteed C-contiguous. pub fn as_contiguous_slice<'a>(&'a self, _py: Python<'a>) -> &'a [ReadOnlyCell] { @@ -1336,13 +1295,8 @@ impl< } // C-contiguous + writable guaranteed — no checks needed. -impl< - T: Element, - const FORMAT: bool, - const SHAPE: bool, - const STRIDE: bool, - const F_CONTIGUOUS: bool, - > PyBufferView> +impl + PyBufferView> { /// Gets the buffer memory as a mutable slice. /// The buffer is guaranteed C-contiguous and writable. @@ -1358,8 +1312,7 @@ impl< const SHAPE: bool, const STRIDE: bool, const WRITABLE: bool, - const C_CONTIGUOUS: bool, - > PyBufferView> + > PyBufferView> { /// Gets the buffer memory as a slice. The buffer is guaranteed Fortran-contiguous. pub fn as_fortran_contiguous_slice<'a>(&'a self, _py: Python<'a>) -> &'a [ReadOnlyCell] { @@ -1368,13 +1321,8 @@ impl< } // Fortran-contiguous + writable guaranteed — no checks needed. -impl< - T: Element, - const FORMAT: bool, - const SHAPE: bool, - const STRIDE: bool, - const C_CONTIGUOUS: bool, - > PyBufferView> +impl + PyBufferView> { /// Gets the buffer memory as a mutable slice. /// The buffer is guaranteed Fortran-contiguous and writable. @@ -1761,14 +1709,14 @@ mod tests { fn test_untyped_buffer_view() { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().full(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::full_ro(), |view| { assert!(!view.buf_ptr().is_null()); assert_eq!(view.len_bytes(), 5); assert_eq!(view.item_size(), 1); assert_eq!(view.item_count(), 5); assert!(view.readonly()); assert_eq!(view.dimensions(), 1); - // with() uses PyBufferFlags::simple().full() — all Known, direct return types + // with_flags() uses PyBufferFlags::full_ro() — all Known, direct return types assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); @@ -1788,7 +1736,7 @@ mod tests { PyBufferView::::with(&bytes, |view| { assert_eq!(view.dimensions(), 1); assert_eq!(view.item_count(), 5); - // PyBufferView::with uses PyBufferFlags::simple().full() — all Known + // PyBufferView::with uses PyBufferFlags::full_ro() — all Known assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); @@ -1941,7 +1889,7 @@ mod tests { Python::attach(|py| { let list = crate::types::PyList::empty(py); let result = - PyUntypedBufferView::with_flags(&list, PyBufferFlags::simple().full(), |_view| {}); + PyUntypedBufferView::with_flags(&list, PyBufferFlags::full_ro(), |_view| {}); assert!(result.is_err()); }); } @@ -1957,6 +1905,11 @@ mod tests { .unwrap(); // Primitive builders + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple(), |view| { + assert_eq!(view.format().to_str().unwrap(), "B"); + }) + .unwrap(); + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().format(), |view| { assert_eq!(view.format().to_str().unwrap(), "B"); }) @@ -1979,6 +1932,12 @@ mod tests { }) .unwrap(); + PyUntypedBufferView::with_flags(&array, PyBufferFlags::simple().writable(), |view| { + assert_eq!(view.format().to_str().unwrap(), "B"); + assert!(!view.readonly()); + }) + .unwrap(); + PyUntypedBufferView::with_flags( &array, PyBufferFlags::simple().writable().nd(), @@ -2042,70 +2001,67 @@ mod tests { ) .unwrap(); - // Compound builders (read-only) - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().full(), |view| { + // Compound requests (read-only) + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::full_ro(), |view| { assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().records(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::records_ro(), |view| { assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); }) .unwrap(); - // Compound builders + .writable() - PyUntypedBufferView::with_flags( - &array, - PyBufferFlags::simple().full().writable(), - |view| { - assert_eq!(view.format().to_str().unwrap(), "f"); - assert_eq!(view.shape(), [4]); - assert_eq!(view.strides(), [4]); - assert!(!view.readonly()); - }, - ) + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::strided_ro(), |view| { + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + }) .unwrap(); - PyUntypedBufferView::with_flags( - &array, - PyBufferFlags::simple().records().writable(), - |view| { - assert_eq!(view.format().to_str().unwrap(), "f"); - assert_eq!(view.shape(), [4]); - assert!(!view.readonly()); - }, - ) + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::contig_ro(), |view| { + assert_eq!(view.shape(), [5]); + assert!(view.is_c_contiguous()); + }) .unwrap(); - PyUntypedBufferView::with_flags( - &array, - PyBufferFlags::simple().strides().writable(), - |view| { - assert_eq!(view.shape(), [4]); - assert_eq!(view.strides(), [4]); - assert!(!view.readonly()); - }, - ) + // Writable compound requests + PyUntypedBufferView::with_flags(&array, PyBufferFlags::full(), |view| { + assert_eq!(view.format().to_str().unwrap(), "f"); + assert_eq!(view.shape(), [4]); + assert_eq!(view.strides(), [4]); + assert!(!view.readonly()); + }) .unwrap(); - PyUntypedBufferView::with_flags( - &array, - PyBufferFlags::simple().nd().writable(), - |view| { - assert_eq!(view.shape(), [4]); - assert!(!view.readonly()); - }, - ) + PyUntypedBufferView::with_flags(&array, PyBufferFlags::records(), |view| { + assert_eq!(view.format().to_str().unwrap(), "f"); + assert_eq!(view.shape(), [4]); + assert!(!view.readonly()); + }) + .unwrap(); + + PyUntypedBufferView::with_flags(&array, PyBufferFlags::strided(), |view| { + assert_eq!(view.shape(), [4]); + assert_eq!(view.strides(), [4]); + assert!(!view.readonly()); + }) + .unwrap(); + + PyUntypedBufferView::with_flags(&array, PyBufferFlags::contig(), |view| { + assert_eq!(view.shape(), [4]); + assert!(!view.readonly()); + assert!(view.is_c_contiguous()); + }) .unwrap(); // Compound + contiguity PyUntypedBufferView::with_flags( &bytes, - PyBufferFlags::simple().full().c_contiguous(), + PyBufferFlags::full_ro().c_contiguous(), |view| { assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); @@ -2114,25 +2070,17 @@ mod tests { ) .unwrap(); - PyUntypedBufferView::with_flags( - &array, - PyBufferFlags::simple().full().writable().c_contiguous(), - |view| { - assert_eq!(view.format().to_str().unwrap(), "f"); - assert!(!view.readonly()); - }, - ) + PyUntypedBufferView::with_flags(&array, PyBufferFlags::full().c_contiguous(), |view| { + assert_eq!(view.format().to_str().unwrap(), "f"); + assert!(!view.readonly()); + }) .unwrap(); - PyUntypedBufferView::with_flags( - &bytes, - PyBufferFlags::simple().strides().format(), - |view| { - assert_eq!(view.format().to_str().unwrap(), "B"); - assert_eq!(view.shape(), [5]); - assert_eq!(view.strides(), [1]); - }, - ) + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::strided_ro().format(), |view| { + assert_eq!(view.format().to_str().unwrap(), "B"); + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + }) .unwrap(); PyUntypedBufferView::with_flags( @@ -2146,28 +2094,26 @@ mod tests { .unwrap(); // Contiguity builder on typed view - PyBufferView::::with_flags( - &array, - PyBufferFlags::simple().c_contiguous(), - |view| { - let slice = view.as_contiguous_slice(py); - assert_eq!(slice[0].get(), 1.0); - }, - ) + PyBufferView::::with_flags(&bytes, PyBufferFlags::simple().format(), |view| { + assert_eq!(view.format().to_str().unwrap(), "B"); + assert_eq!(view.item_count(), 5); + }) + .unwrap(); + + PyBufferView::::with_flags(&array, PyBufferFlags::contig(), |view| { + let slice = view.as_contiguous_slice(py); + assert_eq!(slice[0].get(), 1.0); + }) .unwrap(); // Writable + contiguity on typed view - PyBufferView::::with_flags( - &array, - PyBufferFlags::simple().c_contiguous().writable(), - |view| { - let slice = view.as_contiguous_slice(py); - assert_eq!(slice[0].get(), 1.0); - let mut_slice = view.as_contiguous_mut_slice(py); - mut_slice[0].set(9.0); - assert_eq!(slice[0].get(), 9.0); - }, - ) + PyBufferView::::with_flags(&array, PyBufferFlags::contig(), |view| { + let slice = view.as_contiguous_slice(py); + assert_eq!(slice[0].get(), 1.0); + let mut_slice = view.as_contiguous_mut_slice(py); + mut_slice[0].set(9.0); + assert_eq!(slice[0].get(), 9.0); + }) .unwrap(); // SIMPLE format() returns "B" @@ -2184,7 +2130,7 @@ mod tests { let bytes = PyBytes::new(py, b"abcde"); // Debug always uses raw_format/raw_shape/raw_strides (Option in output) - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().full(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::full_ro(), |view| { let expected = format!( concat!( "PyUntypedBufferView {{ buf: {:?}, obj: {:?}, ", From e2ef191e8f86eda95bc071b8cb492f6809a34528 Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Tue, 14 Apr 2026 01:00:18 +0800 Subject: [PATCH 14/17] refactor: hide suboffsets if guaranteed null --- src/buffer.rs | 367 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 251 insertions(+), 116 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index 3cc01c237a8..f0315c0be12 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -55,24 +55,20 @@ struct RawBuffer(ffi::Py_buffer, PhantomPinned); unsafe impl Send for PyUntypedBuffer {} unsafe impl Sync for PyUntypedBuffer {} -impl Debug for PyBuffer { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - debug_buffer("PyBuffer", &self.0, f) - } -} - -impl Debug for PyUntypedBuffer { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - debug_buffer("PyUntypedBuffer", self, f) - } -} - fn debug_buffer( name: &str, - b: &PyUntypedBuffer, + raw: &ffi::Py_buffer, f: &mut std::fmt::Formatter<'_>, ) -> std::fmt::Result { - let raw = b.raw(); + let ndim = raw.ndim as usize; + let format = NonNull::new(raw.format).map(|p| unsafe { CStr::from_ptr(p.as_ptr()) }); + let shape = NonNull::new(raw.shape) + .map(|p| unsafe { slice::from_raw_parts(p.as_ptr().cast::(), ndim) }); + let strides = + NonNull::new(raw.strides).map(|p| unsafe { slice::from_raw_parts(p.as_ptr(), ndim) }); + let suboffsets = + NonNull::new(raw.suboffsets).map(|p| unsafe { slice::from_raw_parts(p.as_ptr(), ndim) }); + f.debug_struct(name) .field("buf", &raw.buf) .field("obj", &raw.obj) @@ -80,14 +76,26 @@ fn debug_buffer( .field("itemsize", &raw.itemsize) .field("readonly", &raw.readonly) .field("ndim", &raw.ndim) - .field("format", &b.format()) - .field("shape", &b.shape()) - .field("strides", &b.strides()) - .field("suboffsets", &b.suboffsets()) + .field("format", &format) + .field("shape", &shape) + .field("strides", &strides) + .field("suboffsets", &suboffsets) .field("internal", &raw.internal) .finish() } +impl Debug for PyBuffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + debug_buffer("PyBuffer", self.raw(), f) + } +} + +impl Debug for PyUntypedBuffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + debug_buffer("PyUntypedBuffer", self.raw(), f) + } +} + /// Represents the type of a Python buffer element. #[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum ElementType { @@ -803,6 +811,7 @@ pub struct PyBufferFlags< const FORMAT: bool = false, const SHAPE: bool = false, const STRIDE: bool = false, + const INDIRECT: bool = false, const WRITABLE: bool = false, const CONTIGUITY: u8 = CONTIGUITY_UNDEFINED, >(c_int); @@ -813,9 +822,10 @@ mod py_buffer_flags_sealed { const FORMAT: bool, const SHAPE: bool, const STRIDE: bool, + const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > Sealed for super::PyBufferFlags + > Sealed for super::PyBufferFlags { } } @@ -824,70 +834,115 @@ mod py_buffer_flags_sealed { pub trait PyBufferFlagsType: py_buffer_flags_sealed::Sealed { /// The contiguity requirement encoded by these flags. const CONTIGUITY: u8; + + /// Whether these flags require a writable buffer. + const WRITABLE: bool; } impl< const FORMAT: bool, const SHAPE: bool, const STRIDE: bool, + const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY_REQ: u8, - > PyBufferFlagsType for PyBufferFlags + > PyBufferFlagsType + for PyBufferFlags { const CONTIGUITY: u8 = CONTIGUITY_REQ; + const WRITABLE: bool = WRITABLE; } -impl - PyBufferFlags +impl< + const SHAPE: bool, + const STRIDE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + > PyBufferFlags { /// Request format information. - pub const fn format(self) -> PyBufferFlags { + pub const fn format( + self, + ) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_FORMAT) } } -impl - PyBufferFlags +impl< + const FORMAT: bool, + const STRIDE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + > PyBufferFlags { /// Request shape information. - pub const fn nd(self) -> PyBufferFlags { + pub const fn nd(self) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_ND) } } -impl - PyBufferFlags +impl< + const FORMAT: bool, + const SHAPE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + > PyBufferFlags { /// Request strides information. Implies shape. - pub const fn strides(self) -> PyBufferFlags { + pub const fn strides( + self, + ) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_STRIDES) } +} +impl + PyBufferFlags +{ /// Request suboffsets (indirect). Implies shape and strides. - pub const fn indirect(self) -> PyBufferFlags { + pub const fn indirect(self) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_INDIRECT) } } -impl - PyBufferFlags +impl< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const INDIRECT: bool, + const CONTIGUITY: u8, + > PyBufferFlags { /// Request a writable buffer. - pub const fn writable(self) -> PyBufferFlags { + pub const fn writable( + self, + ) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_WRITABLE) } } -impl - PyBufferFlags +impl< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + > PyBufferFlags { /// Require C-contiguous layout. Implies shape and strides. - pub const fn c_contiguous(self) -> PyBufferFlags { + pub const fn c_contiguous( + self, + ) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_C_CONTIGUOUS) } /// Require Fortran-contiguous layout. Implies shape and strides. - pub const fn f_contiguous(self) -> PyBufferFlags { + pub const fn f_contiguous( + self, + ) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_F_CONTIGUOUS) } @@ -897,7 +952,7 @@ impl PyBufferFlags { + ) -> PyBufferFlags { PyBufferFlags(self.0 | ffi::PyBUF_ANY_CONTIGUOUS) } } @@ -909,42 +964,42 @@ impl PyBufferFlags { } /// Create a writable request for all buffer information including suboffsets. - pub const fn full() -> PyBufferFlags { + pub const fn full() -> PyBufferFlags { PyBufferFlags(ffi::PyBUF_FULL) } /// Create a read-only request for all buffer information including suboffsets. - pub const fn full_ro() -> PyBufferFlags { + pub const fn full_ro() -> PyBufferFlags { PyBufferFlags(ffi::PyBUF_FULL_RO) } /// Create a writable request for format, shape, and strides. - pub const fn records() -> PyBufferFlags { + pub const fn records() -> PyBufferFlags { PyBufferFlags(ffi::PyBUF_RECORDS) } /// Create a read-only request for format, shape, and strides. - pub const fn records_ro() -> PyBufferFlags { + pub const fn records_ro() -> PyBufferFlags { PyBufferFlags(ffi::PyBUF_RECORDS_RO) } /// Create a writable request for shape and strides. - pub const fn strided() -> PyBufferFlags { + pub const fn strided() -> PyBufferFlags { PyBufferFlags(ffi::PyBUF_STRIDED) } /// Create a read-only request for shape and strides. - pub const fn strided_ro() -> PyBufferFlags { + pub const fn strided_ro() -> PyBufferFlags { PyBufferFlags(ffi::PyBUF_STRIDED_RO) } /// Create a writable C-contiguous request. - pub const fn contig() -> PyBufferFlags { + pub const fn contig() -> PyBufferFlags { PyBufferFlags(ffi::PyBUF_CONTIG) } /// Create a read-only C-contiguous request. - pub const fn contig_ro() -> PyBufferFlags { + pub const fn contig_ro() -> PyBufferFlags { PyBufferFlags(ffi::PyBUF_CONTIG_RO) } } @@ -952,7 +1007,7 @@ impl PyBufferFlags { /// A typed form of [`PyUntypedBufferView`]. Not constructible directly — use /// [`PyBufferView::with()`] or [`PyBufferView::with_flags()`]. #[repr(transparent)] -pub struct PyBufferView>( +pub struct PyBufferView>( PyUntypedBufferView, PhantomData<[T]>, ); @@ -985,7 +1040,7 @@ impl PyUntypedBufferView { /// Gets whether the underlying buffer is read-only. #[inline] pub fn readonly(&self) -> bool { - self.raw.readonly != 0 + !Flags::WRITABLE && self.raw.readonly != 0 } /// Gets the size of a single element, in bytes. @@ -1015,18 +1070,6 @@ impl PyUntypedBufferView { self.raw.ndim as usize } - /// Returns the suboffsets array. - /// - /// May return `None` even with `PyBUF_INDIRECT` if the exporter sets `suboffsets` to NULL. - #[inline] - pub fn suboffsets(&self) -> Option<&[isize]> { - if self.raw.suboffsets.is_null() { - return None; - } - - Some(unsafe { slice::from_raw_parts(self.raw.suboffsets, self.raw.ndim as usize) }) - } - /// Gets whether the buffer is contiguous in C-style order. #[inline] pub fn is_c_contiguous(&self) -> bool { @@ -1042,8 +1085,13 @@ impl PyUntypedBufferView { } } -impl - PyUntypedBufferView> +impl< + const SHAPE: bool, + const STRIDE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + > PyUntypedBufferView> { /// A [struct module style](https://docs.python.org/3/c-api/buffer.html#c.Py_buffer.format) /// string describing the contents of a single item. @@ -1056,12 +1104,19 @@ impl( &self, - ) -> PyResult<&PyBufferView>> { + ) -> PyResult< + &PyBufferView>, + > { self.ensure_compatible_with::()?; // SAFETY: PyBufferView is repr(transparent) around PyUntypedBufferView<..> Ok(unsafe { NonNull::from(self) - .cast::>>() + .cast::< + PyBufferView< + T, + PyBufferFlags, + >, + >() .as_ref() }) } @@ -1071,8 +1126,13 @@ impl - PyUntypedBufferView> +impl< + const FORMAT: bool, + const STRIDE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + > PyUntypedBufferView> { /// Returns the shape array. `shape[i]` is the length of dimension `i`. /// @@ -1086,8 +1146,13 @@ impl - PyUntypedBufferView> +impl< + const FORMAT: bool, + const SHAPE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + > PyUntypedBufferView> { /// Returns the strides array. /// @@ -1100,9 +1165,31 @@ impl PyUntypedBufferView> +{ + /// Returns the suboffsets array. + /// + /// May return `None` even when suboffsets were requested if the exporter sets + /// `suboffsets` to `NULL`. + #[inline] + pub fn suboffsets(&self) -> Option<&[isize]> { + if self.raw.suboffsets.is_null() { + return None; + } + + Some(unsafe { slice::from_raw_parts(self.raw.suboffsets, self.raw.ndim as usize) }) + } +} + // SIMPLE and WRITABLE requests guarantee the implicit "B" format. impl - PyUntypedBufferView> + PyUntypedBufferView> { /// Returns the format string for a simple byte buffer, which is always `"B"`. #[inline] @@ -1144,14 +1231,17 @@ impl PyUntypedBufferView { const FORMAT: bool, const SHAPE: bool, const STRIDE: bool, + const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, R, >( obj: &Bound<'_, PyAny>, - flags: PyBufferFlags, + flags: PyBufferFlags, f: impl FnOnce( - &PyUntypedBufferView>, + &PyUntypedBufferView< + PyBufferFlags, + >, ) -> R, ) -> PyResult { let mut raw = mem::MaybeUninit::::uninit(); @@ -1169,35 +1259,6 @@ impl PyUntypedBufferView { } } -fn debug_buffer_view( - name: &str, - raw: &ffi::Py_buffer, - f: &mut std::fmt::Formatter<'_>, -) -> std::fmt::Result { - let ndim = raw.ndim as usize; - let format = NonNull::new(raw.format).map(|p| unsafe { CStr::from_ptr(p.as_ptr()) }); - let shape = NonNull::new(raw.shape) - .map(|p| unsafe { slice::from_raw_parts(p.as_ptr().cast::(), ndim) }); - let strides = - NonNull::new(raw.strides).map(|p| unsafe { slice::from_raw_parts(p.as_ptr(), ndim) }); - let suboffsets = - NonNull::new(raw.suboffsets).map(|p| unsafe { slice::from_raw_parts(p.as_ptr(), ndim) }); - - f.debug_struct(name) - .field("buf", &raw.buf) - .field("obj", &raw.obj) - .field("len", &raw.len) - .field("itemsize", &raw.itemsize) - .field("readonly", &raw.readonly) - .field("ndim", &raw.ndim) - .field("format", &format) - .field("shape", &shape) - .field("strides", &strides) - .field("suboffsets", &suboffsets) - .field("internal", &raw.internal) - .finish() -} - impl Drop for PyUntypedBufferView { fn drop(&mut self) { unsafe { ffi::PyBuffer_Release(&mut self.raw) } @@ -1206,7 +1267,7 @@ impl Drop for PyUntypedBufferView { impl Debug for PyUntypedBufferView { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - debug_buffer_view("PyUntypedBufferView", &self.raw, f) + debug_buffer("PyUntypedBufferView", &self.raw, f) } } @@ -1226,13 +1287,16 @@ impl PyBufferView { const FORMAT: bool, const SHAPE: bool, const STRIDE: bool, + const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, R, >( obj: &Bound<'_, PyAny>, - flags: PyBufferFlags, - f: impl FnOnce(&PyBufferView>) -> R, + flags: PyBufferFlags, + f: impl FnOnce( + &PyBufferView>, + ) -> R, ) -> PyResult { let mut raw = mem::MaybeUninit::::uninit(); @@ -1240,7 +1304,9 @@ impl PyBufferView { ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags.0 | ffi::PyBUF_FORMAT) })?; - let view = PyUntypedBufferView::> { + let view = PyUntypedBufferView::< + PyBufferFlags, + > { raw: unsafe { raw.assume_init() }, _flags: PhantomData, }; @@ -1285,8 +1351,9 @@ impl< const FORMAT: bool, const SHAPE: bool, const STRIDE: bool, + const INDIRECT: bool, const WRITABLE: bool, - > PyBufferView> + > PyBufferView> { /// Gets the buffer memory as a slice. The buffer is guaranteed C-contiguous. pub fn as_contiguous_slice<'a>(&'a self, _py: Python<'a>) -> &'a [ReadOnlyCell] { @@ -1295,8 +1362,13 @@ impl< } // C-contiguous + writable guaranteed — no checks needed. -impl - PyBufferView> +impl< + T: Element, + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const INDIRECT: bool, + > PyBufferView> { /// Gets the buffer memory as a mutable slice. /// The buffer is guaranteed C-contiguous and writable. @@ -1305,14 +1377,15 @@ impl } } -// Fortran-contiguous guaranteed — no contiguity check needed. +// Fortran-contiguous guaranteed. impl< T: Element, const FORMAT: bool, const SHAPE: bool, const STRIDE: bool, + const INDIRECT: bool, const WRITABLE: bool, - > PyBufferView> + > PyBufferView> { /// Gets the buffer memory as a slice. The buffer is guaranteed Fortran-contiguous. pub fn as_fortran_contiguous_slice<'a>(&'a self, _py: Python<'a>) -> &'a [ReadOnlyCell] { @@ -1320,9 +1393,14 @@ impl< } } -// Fortran-contiguous + writable guaranteed — no checks needed. -impl - PyBufferView> +// Fortran-contiguous + writable guaranteed. +impl< + T: Element, + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const INDIRECT: bool, + > PyBufferView> { /// Gets the buffer memory as a mutable slice. /// The buffer is guaranteed Fortran-contiguous and writable. @@ -1341,7 +1419,7 @@ impl std::ops::Deref for PyBufferView { impl Debug for PyBufferView { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - debug_buffer_view("PyBufferView", &self.0.raw, f) + debug_buffer("PyBufferView", &self.0.raw, f) } } @@ -1363,8 +1441,8 @@ mod tests { concat!( "PyBuffer {{ buf: {:?}, obj: {:?}, ", "len: 5, itemsize: 1, readonly: 1, ", - "ndim: 1, format: \"B\", shape: [5], ", - "strides: [1], suboffsets: None, internal: {:?} }}", + "ndim: 1, format: Some(\"B\"), shape: Some([5]), ", + "strides: Some([1]), suboffsets: None, internal: {:?} }}", ), buffer.raw().buf, buffer.raw().obj, @@ -1372,6 +1450,21 @@ mod tests { ); let debug_repr = format!("{:?}", buffer); assert_eq!(debug_repr, expected); + + let untyped = PyUntypedBuffer::get(&bytes).unwrap(); + let expected = format!( + concat!( + "PyUntypedBuffer {{ buf: {:?}, obj: {:?}, ", + "len: 5, itemsize: 1, readonly: 1, ", + "ndim: 1, format: Some(\"B\"), shape: Some([5]), ", + "strides: Some([1]), suboffsets: None, internal: {:?} }}", + ), + untyped.raw().buf, + untyped.raw().obj, + untyped.raw().internal + ); + let debug_repr = format!("{:?}", untyped); + assert_eq!(debug_repr, expected); }); } @@ -1739,6 +1832,7 @@ mod tests { // PyBufferView::with uses PyBufferFlags::full_ro() — all Known assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); + assert!(view.suboffsets().is_none()); let slice = view.as_slice(py).unwrap(); assert_eq!(slice.len(), 5); @@ -1790,7 +1884,6 @@ mod tests { assert_eq!(view.item_count(), 5); assert_eq!(view.len_bytes(), 5); assert!(view.readonly()); - assert!(view.suboffsets().is_none()); }) .unwrap(); @@ -1806,6 +1899,13 @@ mod tests { }) .unwrap(); + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().indirect(), |view| { + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + assert!(view.suboffsets().is_none()); + }) + .unwrap(); + PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().format(), |view| { assert_eq!(view.item_count(), 5); assert_eq!(view.format().to_str().unwrap(), "B"); @@ -1896,6 +1996,38 @@ mod tests { #[test] fn test_flag_builders() { + fn assert_direct< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + >( + _: PyBufferFlags, + ) { + } + + fn assert_indirect< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + >( + _: PyBufferFlags, + ) { + } + + assert_direct(PyBufferFlags::simple()); + assert_direct(PyBufferFlags::records_ro()); + assert_direct(PyBufferFlags::strided_ro()); + assert_direct(PyBufferFlags::contig_ro()); + assert_indirect(PyBufferFlags::simple().indirect()); + assert_indirect(PyBufferFlags::full_ro()); + assert_indirect(PyBufferFlags::full()); + assert_direct(PyBufferFlags::full_ro().c_contiguous()); + assert_direct(PyBufferFlags::full().c_contiguous()); + Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); let array = py @@ -1929,6 +2061,7 @@ mod tests { PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().indirect(), |view| { assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); + assert!(view.suboffsets().is_none()); }) .unwrap(); @@ -2006,6 +2139,7 @@ mod tests { assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); + assert!(view.suboffsets().is_none()); }) .unwrap(); @@ -2033,6 +2167,7 @@ mod tests { assert_eq!(view.format().to_str().unwrap(), "f"); assert_eq!(view.shape(), [4]); assert_eq!(view.strides(), [4]); + assert!(view.suboffsets().is_none()); assert!(!view.readonly()); }) .unwrap(); From 26d7d63e433dd5707c4bdfd24a14036ba0c9b4fc Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Tue, 14 Apr 2026 01:21:36 +0800 Subject: [PATCH 15/17] refactor: hide `PyBufferFlags` --- src/buffer.rs | 209 +++++++++++++++++++++++++++----------------------- 1 file changed, 115 insertions(+), 94 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index f0315c0be12..b19e919f585 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -805,18 +805,22 @@ const CONTIGUITY_C: u8 = PyBufferContiguity::C as u8; const CONTIGUITY_F: u8 = PyBufferContiguity::F as u8; const CONTIGUITY_ANY: u8 = PyBufferContiguity::Any as u8; -/// Type-safe buffer request flags. The const parameters encode which fields -/// the exporter is required to fill. +/// Type-safe buffer request flags. The state parameter is intentionally hidden +/// behind this wrapper so the internal encoding can evolve. pub struct PyBufferFlags< - const FORMAT: bool = false, - const SHAPE: bool = false, - const STRIDE: bool = false, - const INDIRECT: bool = false, - const WRITABLE: bool = false, - const CONTIGUITY: u8 = CONTIGUITY_UNDEFINED, ->(c_int); - -mod py_buffer_flags_sealed { + Flags: PyBufferFlagsType = FlagsImpl, +>(c_int, PhantomData); + +mod py_buffer_flags_impl { + pub struct PyBufferFlagsImpl< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + >; + pub trait Sealed {} impl< const FORMAT: bool, @@ -825,13 +829,15 @@ mod py_buffer_flags_sealed { const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > Sealed for super::PyBufferFlags + > Sealed for PyBufferFlagsImpl { } } -/// Trait implemented by all [`PyBufferFlags`] instantiations. -pub trait PyBufferFlagsType: py_buffer_flags_sealed::Sealed { +use self::py_buffer_flags_impl::PyBufferFlagsImpl as FlagsImpl; + +/// Trait implemented by all hidden [`PyBufferFlags`] states. +pub trait PyBufferFlagsType: py_buffer_flags_impl::Sealed { /// The contiguity requirement encoded by these flags. const CONTIGUITY: u8; @@ -846,8 +852,7 @@ impl< const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY_REQ: u8, - > PyBufferFlagsType - for PyBufferFlags + > PyBufferFlagsType for FlagsImpl { const CONTIGUITY: u8 = CONTIGUITY_REQ; const WRITABLE: bool = WRITABLE; @@ -859,13 +864,13 @@ impl< const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > PyBufferFlags + > PyBufferFlags> { /// Request format information. pub const fn format( self, - ) -> PyBufferFlags { - PyBufferFlags(self.0 | ffi::PyBUF_FORMAT) + ) -> PyBufferFlags> { + PyBufferFlags(self.0 | ffi::PyBUF_FORMAT, PhantomData) } } @@ -875,11 +880,13 @@ impl< const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > PyBufferFlags + > PyBufferFlags> { /// Request shape information. - pub const fn nd(self) -> PyBufferFlags { - PyBufferFlags(self.0 | ffi::PyBUF_ND) + pub const fn nd( + self, + ) -> PyBufferFlags> { + PyBufferFlags(self.0 | ffi::PyBUF_ND, PhantomData) } } @@ -889,22 +896,24 @@ impl< const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > PyBufferFlags + > PyBufferFlags> { /// Request strides information. Implies shape. pub const fn strides( self, - ) -> PyBufferFlags { - PyBufferFlags(self.0 | ffi::PyBUF_STRIDES) + ) -> PyBufferFlags> { + PyBufferFlags(self.0 | ffi::PyBUF_STRIDES, PhantomData) } } impl - PyBufferFlags + PyBufferFlags> { /// Request suboffsets (indirect). Implies shape and strides. - pub const fn indirect(self) -> PyBufferFlags { - PyBufferFlags(self.0 | ffi::PyBUF_INDIRECT) + pub const fn indirect( + self, + ) -> PyBufferFlags> { + PyBufferFlags(self.0 | ffi::PyBUF_INDIRECT, PhantomData) } } @@ -914,13 +923,13 @@ impl< const STRIDE: bool, const INDIRECT: bool, const CONTIGUITY: u8, - > PyBufferFlags + > PyBufferFlags> { /// Request a writable buffer. pub const fn writable( self, - ) -> PyBufferFlags { - PyBufferFlags(self.0 | ffi::PyBUF_WRITABLE) + ) -> PyBufferFlags> { + PyBufferFlags(self.0 | ffi::PyBUF_WRITABLE, PhantomData) } } @@ -930,20 +939,20 @@ impl< const STRIDE: bool, const INDIRECT: bool, const WRITABLE: bool, - > PyBufferFlags + > PyBufferFlags> { /// Require C-contiguous layout. Implies shape and strides. pub const fn c_contiguous( self, - ) -> PyBufferFlags { - PyBufferFlags(self.0 | ffi::PyBUF_C_CONTIGUOUS) + ) -> PyBufferFlags> { + PyBufferFlags(self.0 | ffi::PyBUF_C_CONTIGUOUS, PhantomData) } /// Require Fortran-contiguous layout. Implies shape and strides. pub const fn f_contiguous( self, - ) -> PyBufferFlags { - PyBufferFlags(self.0 | ffi::PyBUF_F_CONTIGUOUS) + ) -> PyBufferFlags> { + PyBufferFlags(self.0 | ffi::PyBUF_F_CONTIGUOUS, PhantomData) } /// Require contiguous layout (C or Fortran). Implies shape and strides. @@ -952,65 +961,81 @@ impl< /// so this does not unlock non-Option slice accessors. pub const fn any_contiguous( self, - ) -> PyBufferFlags { - PyBufferFlags(self.0 | ffi::PyBUF_ANY_CONTIGUOUS) + ) -> PyBufferFlags> { + PyBufferFlags(self.0 | ffi::PyBUF_ANY_CONTIGUOUS, PhantomData) } } -impl PyBufferFlags { +impl PyBufferFlags> { /// Create a base buffer request. Chain builder methods to add flags. - pub const fn simple() -> PyBufferFlags { - PyBufferFlags(ffi::PyBUF_SIMPLE) + pub const fn simple() -> Self { + PyBufferFlags(ffi::PyBUF_SIMPLE, PhantomData) } +} +impl PyBufferFlags> { /// Create a writable request for all buffer information including suboffsets. - pub const fn full() -> PyBufferFlags { - PyBufferFlags(ffi::PyBUF_FULL) + pub const fn full() -> Self { + PyBufferFlags(ffi::PyBUF_FULL, PhantomData) } +} +impl PyBufferFlags> { /// Create a read-only request for all buffer information including suboffsets. - pub const fn full_ro() -> PyBufferFlags { - PyBufferFlags(ffi::PyBUF_FULL_RO) + pub const fn full_ro() -> Self { + PyBufferFlags(ffi::PyBUF_FULL_RO, PhantomData) } +} +impl PyBufferFlags> { /// Create a writable request for format, shape, and strides. - pub const fn records() -> PyBufferFlags { - PyBufferFlags(ffi::PyBUF_RECORDS) + pub const fn records() -> Self { + PyBufferFlags(ffi::PyBUF_RECORDS, PhantomData) } +} +impl PyBufferFlags> { /// Create a read-only request for format, shape, and strides. - pub const fn records_ro() -> PyBufferFlags { - PyBufferFlags(ffi::PyBUF_RECORDS_RO) + pub const fn records_ro() -> Self { + PyBufferFlags(ffi::PyBUF_RECORDS_RO, PhantomData) } +} +impl PyBufferFlags> { /// Create a writable request for shape and strides. - pub const fn strided() -> PyBufferFlags { - PyBufferFlags(ffi::PyBUF_STRIDED) + pub const fn strided() -> Self { + PyBufferFlags(ffi::PyBUF_STRIDED, PhantomData) } +} +impl PyBufferFlags> { /// Create a read-only request for shape and strides. - pub const fn strided_ro() -> PyBufferFlags { - PyBufferFlags(ffi::PyBUF_STRIDED_RO) + pub const fn strided_ro() -> Self { + PyBufferFlags(ffi::PyBUF_STRIDED_RO, PhantomData) } +} +impl PyBufferFlags> { /// Create a writable C-contiguous request. - pub const fn contig() -> PyBufferFlags { - PyBufferFlags(ffi::PyBUF_CONTIG) + pub const fn contig() -> Self { + PyBufferFlags(ffi::PyBUF_CONTIG, PhantomData) } +} +impl PyBufferFlags> { /// Create a read-only C-contiguous request. - pub const fn contig_ro() -> PyBufferFlags { - PyBufferFlags(ffi::PyBUF_CONTIG_RO) + pub const fn contig_ro() -> Self { + PyBufferFlags(ffi::PyBUF_CONTIG_RO, PhantomData) } } /// A typed form of [`PyUntypedBufferView`]. Not constructible directly — use /// [`PyBufferView::with()`] or [`PyBufferView::with_flags()`]. #[repr(transparent)] -pub struct PyBufferView>( - PyUntypedBufferView, - PhantomData<[T]>, -); +pub struct PyBufferView< + T, + Flags: PyBufferFlagsType = FlagsImpl, +>(PyUntypedBufferView, PhantomData<[T]>); /// Stack-allocated untyped buffer view. /// @@ -1019,7 +1044,9 @@ pub struct PyBufferView { +pub struct PyUntypedBufferView< + Flags: PyBufferFlagsType = FlagsImpl, +> { raw: ffi::Py_buffer, _flags: PhantomData, } @@ -1091,7 +1118,7 @@ impl< const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > PyUntypedBufferView> + > PyUntypedBufferView> { /// A [struct module style](https://docs.python.org/3/c-api/buffer.html#c.Py_buffer.format) /// string describing the contents of a single item. @@ -1104,19 +1131,16 @@ impl< /// Attempt to interpret this untyped view as containing elements of type `T`. pub fn as_typed( &self, - ) -> PyResult< - &PyBufferView>, - > { + ) -> PyResult<&PyBufferView>> + { self.ensure_compatible_with::()?; // SAFETY: PyBufferView is repr(transparent) around PyUntypedBufferView<..> Ok(unsafe { NonNull::from(self) - .cast::< - PyBufferView< - T, - PyBufferFlags, - >, - >() + .cast::, + >>() .as_ref() }) } @@ -1132,7 +1156,7 @@ impl< const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > PyUntypedBufferView> + > PyUntypedBufferView> { /// Returns the shape array. `shape[i]` is the length of dimension `i`. /// @@ -1152,7 +1176,7 @@ impl< const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > PyUntypedBufferView> + > PyUntypedBufferView> { /// Returns the strides array. /// @@ -1171,7 +1195,7 @@ impl< const STRIDE: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > PyUntypedBufferView> + > PyUntypedBufferView> { /// Returns the suboffsets array. /// @@ -1189,7 +1213,7 @@ impl< // SIMPLE and WRITABLE requests guarantee the implicit "B" format. impl - PyUntypedBufferView> + PyUntypedBufferView> { /// Returns the format string for a simple byte buffer, which is always `"B"`. #[inline] @@ -1237,11 +1261,9 @@ impl PyUntypedBufferView { R, >( obj: &Bound<'_, PyAny>, - flags: PyBufferFlags, + flags: PyBufferFlags>, f: impl FnOnce( - &PyUntypedBufferView< - PyBufferFlags, - >, + &PyUntypedBufferView>, ) -> R, ) -> PyResult { let mut raw = mem::MaybeUninit::::uninit(); @@ -1293,9 +1315,9 @@ impl PyBufferView { R, >( obj: &Bound<'_, PyAny>, - flags: PyBufferFlags, + flags: PyBufferFlags>, f: impl FnOnce( - &PyBufferView>, + &PyBufferView>, ) -> R, ) -> PyResult { let mut raw = mem::MaybeUninit::::uninit(); @@ -1304,12 +1326,11 @@ impl PyBufferView { ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags.0 | ffi::PyBUF_FORMAT) })?; - let view = PyUntypedBufferView::< - PyBufferFlags, - > { - raw: unsafe { raw.assume_init() }, - _flags: PhantomData, - }; + let view = + PyUntypedBufferView::> { + raw: unsafe { raw.assume_init() }, + _flags: PhantomData, + }; view.as_typed::().map(f) } @@ -1353,7 +1374,7 @@ impl< const STRIDE: bool, const INDIRECT: bool, const WRITABLE: bool, - > PyBufferView> + > PyBufferView> { /// Gets the buffer memory as a slice. The buffer is guaranteed C-contiguous. pub fn as_contiguous_slice<'a>(&'a self, _py: Python<'a>) -> &'a [ReadOnlyCell] { @@ -1368,7 +1389,7 @@ impl< const SHAPE: bool, const STRIDE: bool, const INDIRECT: bool, - > PyBufferView> + > PyBufferView> { /// Gets the buffer memory as a mutable slice. /// The buffer is guaranteed C-contiguous and writable. @@ -1385,7 +1406,7 @@ impl< const STRIDE: bool, const INDIRECT: bool, const WRITABLE: bool, - > PyBufferView> + > PyBufferView> { /// Gets the buffer memory as a slice. The buffer is guaranteed Fortran-contiguous. pub fn as_fortran_contiguous_slice<'a>(&'a self, _py: Python<'a>) -> &'a [ReadOnlyCell] { @@ -1400,7 +1421,7 @@ impl< const SHAPE: bool, const STRIDE: bool, const INDIRECT: bool, - > PyBufferView> + > PyBufferView> { /// Gets the buffer memory as a mutable slice. /// The buffer is guaranteed Fortran-contiguous and writable. @@ -2003,7 +2024,7 @@ mod tests { const WRITABLE: bool, const CONTIGUITY: u8, >( - _: PyBufferFlags, + _: PyBufferFlags>, ) { } @@ -2014,7 +2035,7 @@ mod tests { const WRITABLE: bool, const CONTIGUITY: u8, >( - _: PyBufferFlags, + _: PyBufferFlags>, ) { } From 2c78ada5a5dfad87b082ed64b215c526e861cdce Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Tue, 14 Apr 2026 23:43:22 +0800 Subject: [PATCH 16/17] refactor: apply trivial suggestions --- src/buffer.rs | 351 +++++++++++++++++++++++++++----------------------- 1 file changed, 191 insertions(+), 160 deletions(-) diff --git a/src/buffer.rs b/src/buffer.rs index b19e919f585..e83527271b3 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -805,14 +805,25 @@ const CONTIGUITY_C: u8 = PyBufferContiguity::C as u8; const CONTIGUITY_F: u8 = PyBufferContiguity::F as u8; const CONTIGUITY_ANY: u8 = PyBufferContiguity::Any as u8; -/// Type-safe buffer request flags. The state parameter is intentionally hidden +/// Type-safe buffer request. The state parameter is intentionally hidden /// behind this wrapper so the internal encoding can evolve. -pub struct PyBufferFlags< - Flags: PyBufferFlagsType = FlagsImpl, +/// +/// The requested flags constrain what exporters are allowed to return. For example, +/// without shape information, only 1-dimensional buffers are permitted, and accessors +/// for unrequested metadata are unavailable on the typed view. +pub struct PyBufferRequest< + Flags: PyBufferRequestType = RequestFlags< + false, + false, + false, + false, + false, + CONTIGUITY_UNDEFINED, + >, >(c_int, PhantomData); -mod py_buffer_flags_impl { - pub struct PyBufferFlagsImpl< +mod py_buffer_flags { + pub struct PyBufferFlags< const FORMAT: bool, const SHAPE: bool, const STRIDE: bool, @@ -829,15 +840,15 @@ mod py_buffer_flags_impl { const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > Sealed for PyBufferFlagsImpl + > Sealed for PyBufferFlags { } } -use self::py_buffer_flags_impl::PyBufferFlagsImpl as FlagsImpl; +use self::py_buffer_flags::PyBufferFlags as RequestFlags; -/// Trait implemented by all hidden [`PyBufferFlags`] states. -pub trait PyBufferFlagsType: py_buffer_flags_impl::Sealed { +/// Trait implemented by all hidden [`PyBufferRequest`] states. +pub trait PyBufferRequestType: py_buffer_flags::Sealed { /// The contiguity requirement encoded by these flags. const CONTIGUITY: u8; @@ -852,7 +863,8 @@ impl< const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY_REQ: u8, - > PyBufferFlagsType for FlagsImpl + > PyBufferRequestType + for RequestFlags { const CONTIGUITY: u8 = CONTIGUITY_REQ; const WRITABLE: bool = WRITABLE; @@ -864,13 +876,13 @@ impl< const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > PyBufferFlags> + > PyBufferRequest> { /// Request format information. pub const fn format( self, - ) -> PyBufferFlags> { - PyBufferFlags(self.0 | ffi::PyBUF_FORMAT, PhantomData) + ) -> PyBufferRequest> { + PyBufferRequest(self.0 | ffi::PyBUF_FORMAT, PhantomData) } } @@ -880,13 +892,13 @@ impl< const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > PyBufferFlags> + > PyBufferRequest> { /// Request shape information. pub const fn nd( self, - ) -> PyBufferFlags> { - PyBufferFlags(self.0 | ffi::PyBUF_ND, PhantomData) + ) -> PyBufferRequest> { + PyBufferRequest(self.0 | ffi::PyBUF_ND, PhantomData) } } @@ -896,24 +908,25 @@ impl< const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > PyBufferFlags> + > PyBufferRequest> { /// Request strides information. Implies shape. pub const fn strides( self, - ) -> PyBufferFlags> { - PyBufferFlags(self.0 | ffi::PyBUF_STRIDES, PhantomData) + ) -> PyBufferRequest> { + PyBufferRequest(self.0 | ffi::PyBUF_STRIDES, PhantomData) } } impl - PyBufferFlags> + PyBufferRequest> { /// Request suboffsets (indirect). Implies shape and strides. pub const fn indirect( self, - ) -> PyBufferFlags> { - PyBufferFlags(self.0 | ffi::PyBUF_INDIRECT, PhantomData) + ) -> PyBufferRequest> + { + PyBufferRequest(self.0 | ffi::PyBUF_INDIRECT, PhantomData) } } @@ -923,13 +936,13 @@ impl< const STRIDE: bool, const INDIRECT: bool, const CONTIGUITY: u8, - > PyBufferFlags> + > PyBufferRequest> { /// Request a writable buffer. pub const fn writable( self, - ) -> PyBufferFlags> { - PyBufferFlags(self.0 | ffi::PyBUF_WRITABLE, PhantomData) + ) -> PyBufferRequest> { + PyBufferRequest(self.0 | ffi::PyBUF_WRITABLE, PhantomData) } } @@ -939,20 +952,21 @@ impl< const STRIDE: bool, const INDIRECT: bool, const WRITABLE: bool, - > PyBufferFlags> + > + PyBufferRequest> { /// Require C-contiguous layout. Implies shape and strides. pub const fn c_contiguous( self, - ) -> PyBufferFlags> { - PyBufferFlags(self.0 | ffi::PyBUF_C_CONTIGUOUS, PhantomData) + ) -> PyBufferRequest> { + PyBufferRequest(self.0 | ffi::PyBUF_C_CONTIGUOUS, PhantomData) } /// Require Fortran-contiguous layout. Implies shape and strides. pub const fn f_contiguous( self, - ) -> PyBufferFlags> { - PyBufferFlags(self.0 | ffi::PyBUF_F_CONTIGUOUS, PhantomData) + ) -> PyBufferRequest> { + PyBufferRequest(self.0 | ffi::PyBUF_F_CONTIGUOUS, PhantomData) } /// Require contiguous layout (C or Fortran). Implies shape and strides. @@ -961,71 +975,65 @@ impl< /// so this does not unlock non-Option slice accessors. pub const fn any_contiguous( self, - ) -> PyBufferFlags> { - PyBufferFlags(self.0 | ffi::PyBUF_ANY_CONTIGUOUS, PhantomData) + ) -> PyBufferRequest> { + PyBufferRequest(self.0 | ffi::PyBUF_ANY_CONTIGUOUS, PhantomData) } } -impl PyBufferFlags> { +impl PyBufferRequest { /// Create a base buffer request. Chain builder methods to add flags. - pub const fn simple() -> Self { - PyBufferFlags(ffi::PyBUF_SIMPLE, PhantomData) + pub const fn simple( + ) -> PyBufferRequest> + { + PyBufferRequest(ffi::PyBUF_SIMPLE, PhantomData) } -} -impl PyBufferFlags> { /// Create a writable request for all buffer information including suboffsets. - pub const fn full() -> Self { - PyBufferFlags(ffi::PyBUF_FULL, PhantomData) + pub const fn full( + ) -> PyBufferRequest> { + PyBufferRequest(ffi::PyBUF_FULL, PhantomData) } -} -impl PyBufferFlags> { /// Create a read-only request for all buffer information including suboffsets. - pub const fn full_ro() -> Self { - PyBufferFlags(ffi::PyBUF_FULL_RO, PhantomData) + pub const fn full_ro( + ) -> PyBufferRequest> { + PyBufferRequest(ffi::PyBUF_FULL_RO, PhantomData) } -} -impl PyBufferFlags> { /// Create a writable request for format, shape, and strides. - pub const fn records() -> Self { - PyBufferFlags(ffi::PyBUF_RECORDS, PhantomData) + pub const fn records( + ) -> PyBufferRequest> { + PyBufferRequest(ffi::PyBUF_RECORDS, PhantomData) } -} -impl PyBufferFlags> { /// Create a read-only request for format, shape, and strides. - pub const fn records_ro() -> Self { - PyBufferFlags(ffi::PyBUF_RECORDS_RO, PhantomData) + pub const fn records_ro( + ) -> PyBufferRequest> { + PyBufferRequest(ffi::PyBUF_RECORDS_RO, PhantomData) } -} -impl PyBufferFlags> { /// Create a writable request for shape and strides. - pub const fn strided() -> Self { - PyBufferFlags(ffi::PyBUF_STRIDED, PhantomData) + pub const fn strided( + ) -> PyBufferRequest> { + PyBufferRequest(ffi::PyBUF_STRIDED, PhantomData) } -} -impl PyBufferFlags> { /// Create a read-only request for shape and strides. - pub const fn strided_ro() -> Self { - PyBufferFlags(ffi::PyBUF_STRIDED_RO, PhantomData) + pub const fn strided_ro( + ) -> PyBufferRequest> { + PyBufferRequest(ffi::PyBUF_STRIDED_RO, PhantomData) } -} -impl PyBufferFlags> { /// Create a writable C-contiguous request. - pub const fn contig() -> Self { - PyBufferFlags(ffi::PyBUF_CONTIG, PhantomData) + pub const fn contig( + ) -> PyBufferRequest> { + PyBufferRequest(ffi::PyBUF_CONTIG, PhantomData) } -} -impl PyBufferFlags> { /// Create a read-only C-contiguous request. - pub const fn contig_ro() -> Self { - PyBufferFlags(ffi::PyBUF_CONTIG_RO, PhantomData) + pub const fn contig_ro( + ) -> PyBufferRequest> { + PyBufferRequest(ffi::PyBUF_CONTIG_RO, PhantomData) } } @@ -1034,7 +1042,7 @@ impl PyBufferFlags> { #[repr(transparent)] pub struct PyBufferView< T, - Flags: PyBufferFlagsType = FlagsImpl, + Flags: PyBufferRequestType = RequestFlags, >(PyUntypedBufferView, PhantomData<[T]>); /// Stack-allocated untyped buffer view. @@ -1042,16 +1050,23 @@ pub struct PyBufferView< /// Unlike [`PyUntypedBuffer`] which heap-allocates, this places the `Py_buffer` on the /// stack. The scoped closure API ensures the buffer cannot be moved. /// -/// Use [`with_flags()`](Self::with_flags) with a [`PyBufferFlags`] value to acquire a view. +/// Use [`with_flags()`](Self::with_flags) with a [`PyBufferRequest`] value to acquire a view. /// The available accessors depend on the flags used. pub struct PyUntypedBufferView< - Flags: PyBufferFlagsType = FlagsImpl, + Flags: PyBufferRequestType = RequestFlags< + false, + false, + false, + false, + false, + CONTIGUITY_UNDEFINED, + >, > { raw: ffi::Py_buffer, _flags: PhantomData, } -impl PyUntypedBufferView { +impl PyUntypedBufferView { /// Gets the pointer to the start of the buffer memory. #[inline] pub fn buf_ptr(&self) -> *mut c_void { @@ -1118,7 +1133,7 @@ impl< const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > PyUntypedBufferView> + > PyUntypedBufferView> { /// A [struct module style](https://docs.python.org/3/c-api/buffer.html#c.Py_buffer.format) /// string describing the contents of a single item. @@ -1131,7 +1146,7 @@ impl< /// Attempt to interpret this untyped view as containing elements of type `T`. pub fn as_typed( &self, - ) -> PyResult<&PyBufferView>> + ) -> PyResult<&PyBufferView>> { self.ensure_compatible_with::()?; // SAFETY: PyBufferView is repr(transparent) around PyUntypedBufferView<..> @@ -1139,7 +1154,7 @@ impl< NonNull::from(self) .cast::, + RequestFlags, >>() .as_ref() }) @@ -1156,7 +1171,7 @@ impl< const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > PyUntypedBufferView> + > PyUntypedBufferView> { /// Returns the shape array. `shape[i]` is the length of dimension `i`. /// @@ -1176,7 +1191,7 @@ impl< const INDIRECT: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > PyUntypedBufferView> + > PyUntypedBufferView> { /// Returns the strides array. /// @@ -1195,7 +1210,7 @@ impl< const STRIDE: bool, const WRITABLE: bool, const CONTIGUITY: u8, - > PyUntypedBufferView> + > PyUntypedBufferView> { /// Returns the suboffsets array. /// @@ -1213,7 +1228,7 @@ impl< // SIMPLE and WRITABLE requests guarantee the implicit "B" format. impl - PyUntypedBufferView> + PyUntypedBufferView> { /// Returns the format string for a simple byte buffer, which is always `"B"`. #[inline] @@ -1249,8 +1264,11 @@ impl PyUntypedBufferView { /// Acquire a buffer view with the given flags, /// pass it to `f`, then release the buffer. /// - /// Use [`PyBufferFlags::simple()`] or one of the compound-request constructors such as - /// [`PyBufferFlags::full_ro()`] to acquire a view. + /// Use [`PyBufferRequest::simple()`] or one of the compound-request constructors such as + /// [`PyBufferRequest::full_ro()`] to acquire a view. + /// + /// The requested flags constrain what exporters may return. For example, without shape + /// information only 1-dimensional buffers are permitted. pub fn with_flags< const FORMAT: bool, const SHAPE: bool, @@ -1261,9 +1279,11 @@ impl PyUntypedBufferView { R, >( obj: &Bound<'_, PyAny>, - flags: PyBufferFlags>, + flags: PyBufferRequest>, f: impl FnOnce( - &PyUntypedBufferView>, + &PyUntypedBufferView< + RequestFlags, + >, ) -> R, ) -> PyResult { let mut raw = mem::MaybeUninit::::uninit(); @@ -1281,30 +1301,32 @@ impl PyUntypedBufferView { } } -impl Drop for PyUntypedBufferView { +impl Drop for PyUntypedBufferView { fn drop(&mut self) { unsafe { ffi::PyBuffer_Release(&mut self.raw) } } } -impl Debug for PyUntypedBufferView { +impl Debug for PyUntypedBufferView { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { debug_buffer("PyUntypedBufferView", &self.raw, f) } } impl PyBufferView { - /// Acquire a typed buffer view with `PyBufferFlags::full_ro()` flags, + /// Acquire a typed buffer view with `PyBufferRequest::full_ro()` flags, /// validating that the buffer format is compatible with `T`. pub fn with(obj: &Bound<'_, PyAny>, f: impl FnOnce(&PyBufferView) -> R) -> PyResult { - PyUntypedBufferView::with_flags(obj, PyBufferFlags::full_ro(), |view| { + PyUntypedBufferView::with_flags(obj, PyBufferRequest::full_ro(), |view| { view.as_typed::().map(f) })? } /// Acquire a typed buffer view with the given flags. /// - /// [`ffi::PyBUF_FORMAT`] is implicitly added for type validation. + /// [`ffi::PyBUF_FORMAT`] is implicitly added for type validation. As with + /// [`PyUntypedBufferView::with_flags`], the requested flags also constrain what exporters + /// may return. pub fn with_flags< const FORMAT: bool, const SHAPE: bool, @@ -1315,9 +1337,9 @@ impl PyBufferView { R, >( obj: &Bound<'_, PyAny>, - flags: PyBufferFlags>, + flags: PyBufferRequest>, f: impl FnOnce( - &PyBufferView>, + &PyBufferView>, ) -> R, ) -> PyResult { let mut raw = mem::MaybeUninit::::uninit(); @@ -1326,17 +1348,18 @@ impl PyBufferView { ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags.0 | ffi::PyBUF_FORMAT) })?; - let view = - PyUntypedBufferView::> { - raw: unsafe { raw.assume_init() }, - _flags: PhantomData, - }; + let view = PyUntypedBufferView::< + RequestFlags, + > { + raw: unsafe { raw.assume_init() }, + _flags: PhantomData, + }; view.as_typed::().map(f) } } -impl PyBufferView { +impl PyBufferView { /// Gets the buffer memory as a slice. /// /// Returns `None` if the buffer is not C-contiguous. @@ -1374,7 +1397,7 @@ impl< const STRIDE: bool, const INDIRECT: bool, const WRITABLE: bool, - > PyBufferView> + > PyBufferView> { /// Gets the buffer memory as a slice. The buffer is guaranteed C-contiguous. pub fn as_contiguous_slice<'a>(&'a self, _py: Python<'a>) -> &'a [ReadOnlyCell] { @@ -1389,7 +1412,7 @@ impl< const SHAPE: bool, const STRIDE: bool, const INDIRECT: bool, - > PyBufferView> + > PyBufferView> { /// Gets the buffer memory as a mutable slice. /// The buffer is guaranteed C-contiguous and writable. @@ -1406,7 +1429,7 @@ impl< const STRIDE: bool, const INDIRECT: bool, const WRITABLE: bool, - > PyBufferView> + > PyBufferView> { /// Gets the buffer memory as a slice. The buffer is guaranteed Fortran-contiguous. pub fn as_fortran_contiguous_slice<'a>(&'a self, _py: Python<'a>) -> &'a [ReadOnlyCell] { @@ -1421,7 +1444,7 @@ impl< const SHAPE: bool, const STRIDE: bool, const INDIRECT: bool, - > PyBufferView> + > PyBufferView> { /// Gets the buffer memory as a mutable slice. /// The buffer is guaranteed Fortran-contiguous and writable. @@ -1430,7 +1453,7 @@ impl< } } -impl std::ops::Deref for PyBufferView { +impl std::ops::Deref for PyBufferView { type Target = PyUntypedBufferView; fn deref(&self) -> &Self::Target { @@ -1438,7 +1461,7 @@ impl std::ops::Deref for PyBufferView { } } -impl Debug for PyBufferView { +impl Debug for PyBufferView { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { debug_buffer("PyBufferView", &self.0.raw, f) } @@ -1823,14 +1846,14 @@ mod tests { fn test_untyped_buffer_view() { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::full_ro(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::full_ro(), |view| { assert!(!view.buf_ptr().is_null()); assert_eq!(view.len_bytes(), 5); assert_eq!(view.item_size(), 1); assert_eq!(view.item_count(), 5); assert!(view.readonly()); assert_eq!(view.dimensions(), 1); - // with_flags() uses PyBufferFlags::full_ro() — all Known, direct return types + // with_flags() uses PyBufferRequest::full_ro() — all Known, direct return types assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); @@ -1850,7 +1873,7 @@ mod tests { PyBufferView::::with(&bytes, |view| { assert_eq!(view.dimensions(), 1); assert_eq!(view.item_count(), 5); - // PyBufferView::with uses PyBufferFlags::full_ro() — all Known + // PyBufferView::with uses PyBufferRequest::full_ro() — all Known assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); assert!(view.suboffsets().is_none()); @@ -1901,33 +1924,33 @@ mod tests { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::simple(), |view| { assert_eq!(view.item_count(), 5); assert_eq!(view.len_bytes(), 5); assert!(view.readonly()); }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().nd(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::simple().nd(), |view| { assert_eq!(view.item_count(), 5); assert_eq!(view.shape(), [5]); }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().strides(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::simple().strides(), |view| { assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().indirect(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::simple().indirect(), |view| { assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); assert!(view.suboffsets().is_none()); }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().format(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::simple().format(), |view| { assert_eq!(view.item_count(), 5); assert_eq!(view.format().to_str().unwrap(), "B"); }) @@ -1944,7 +1967,7 @@ mod tests { .call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None) .unwrap(); - PyBufferView::::with_flags(&array, PyBufferFlags::simple().nd(), |view| { + PyBufferView::::with_flags(&array, PyBufferRequest::simple().nd(), |view| { assert_eq!(view.item_count(), 4); assert_eq!(view.format().to_str().unwrap(), "f"); assert_eq!(view.shape(), [4]); @@ -1966,7 +1989,7 @@ mod tests { Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); let result = - PyBufferView::::with_flags(&bytes, PyBufferFlags::simple().nd(), |_view| {}); + PyBufferView::::with_flags(&bytes, PyBufferRequest::simple().nd(), |_view| {}); assert!(result.is_err()); }); } @@ -1983,7 +2006,7 @@ mod tests { // C_CONTIGUOUS: guaranteed contiguous readonly access (no Option) PyBufferView::::with_flags( &array, - PyBufferFlags::simple().c_contiguous(), + PyBufferRequest::simple().c_contiguous(), |view| { let slice = view.as_contiguous_slice(py); assert_eq!(slice.len(), 3); @@ -2010,7 +2033,7 @@ mod tests { Python::attach(|py| { let list = crate::types::PyList::empty(py); let result = - PyUntypedBufferView::with_flags(&list, PyBufferFlags::full_ro(), |_view| {}); + PyUntypedBufferView::with_flags(&list, PyBufferRequest::full_ro(), |_view| {}); assert!(result.is_err()); }); } @@ -2024,7 +2047,7 @@ mod tests { const WRITABLE: bool, const CONTIGUITY: u8, >( - _: PyBufferFlags>, + _: PyBufferRequest>, ) { } @@ -2035,19 +2058,19 @@ mod tests { const WRITABLE: bool, const CONTIGUITY: u8, >( - _: PyBufferFlags>, + _: PyBufferRequest>, ) { } - assert_direct(PyBufferFlags::simple()); - assert_direct(PyBufferFlags::records_ro()); - assert_direct(PyBufferFlags::strided_ro()); - assert_direct(PyBufferFlags::contig_ro()); - assert_indirect(PyBufferFlags::simple().indirect()); - assert_indirect(PyBufferFlags::full_ro()); - assert_indirect(PyBufferFlags::full()); - assert_direct(PyBufferFlags::full_ro().c_contiguous()); - assert_direct(PyBufferFlags::full().c_contiguous()); + assert_direct(PyBufferRequest::simple()); + assert_direct(PyBufferRequest::records_ro()); + assert_direct(PyBufferRequest::strided_ro()); + assert_direct(PyBufferRequest::contig_ro()); + assert_indirect(PyBufferRequest::simple().indirect()); + assert_indirect(PyBufferRequest::full_ro()); + assert_indirect(PyBufferRequest::full()); + assert_direct(PyBufferRequest::full_ro().c_contiguous()); + assert_direct(PyBufferRequest::full().c_contiguous()); Python::attach(|py| { let bytes = PyBytes::new(py, b"abcde"); @@ -2058,35 +2081,35 @@ mod tests { .unwrap(); // Primitive builders - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::simple(), |view| { assert_eq!(view.format().to_str().unwrap(), "B"); }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().format(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::simple().format(), |view| { assert_eq!(view.format().to_str().unwrap(), "B"); }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().nd(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::simple().nd(), |view| { assert_eq!(view.shape(), [5]); }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().strides(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::simple().strides(), |view| { assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple().indirect(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::simple().indirect(), |view| { assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); assert!(view.suboffsets().is_none()); }) .unwrap(); - PyUntypedBufferView::with_flags(&array, PyBufferFlags::simple().writable(), |view| { + PyUntypedBufferView::with_flags(&array, PyBufferRequest::simple().writable(), |view| { assert_eq!(view.format().to_str().unwrap(), "B"); assert!(!view.readonly()); }) @@ -2094,7 +2117,7 @@ mod tests { PyUntypedBufferView::with_flags( &array, - PyBufferFlags::simple().writable().nd(), + PyBufferRequest::simple().writable().nd(), |view| { assert_eq!(view.shape(), [4]); assert!(!view.readonly()); @@ -2105,7 +2128,7 @@ mod tests { // Chained primitive builders PyUntypedBufferView::with_flags( &bytes, - PyBufferFlags::simple().nd().format(), + PyBufferRequest::simple().nd().format(), |view| { assert_eq!(view.shape(), [5]); assert_eq!(view.format().to_str().unwrap(), "B"); @@ -2115,7 +2138,7 @@ mod tests { PyUntypedBufferView::with_flags( &bytes, - PyBufferFlags::simple().strides().format(), + PyBufferRequest::simple().strides().format(), |view| { assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); @@ -2127,7 +2150,7 @@ mod tests { // Contiguity builders PyUntypedBufferView::with_flags( &bytes, - PyBufferFlags::simple().c_contiguous(), + PyBufferRequest::simple().c_contiguous(), |view| { assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); @@ -2137,7 +2160,7 @@ mod tests { PyUntypedBufferView::with_flags( &bytes, - PyBufferFlags::simple().f_contiguous(), + PyBufferRequest::simple().f_contiguous(), |view| { assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); @@ -2147,7 +2170,7 @@ mod tests { PyUntypedBufferView::with_flags( &bytes, - PyBufferFlags::simple().any_contiguous(), + PyBufferRequest::simple().any_contiguous(), |view| { assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); @@ -2156,7 +2179,7 @@ mod tests { .unwrap(); // Compound requests (read-only) - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::full_ro(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::full_ro(), |view| { assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); @@ -2164,27 +2187,27 @@ mod tests { }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::records_ro(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::records_ro(), |view| { assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::strided_ro(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::strided_ro(), |view| { assert_eq!(view.shape(), [5]); assert_eq!(view.strides(), [1]); }) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::contig_ro(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::contig_ro(), |view| { assert_eq!(view.shape(), [5]); assert!(view.is_c_contiguous()); }) .unwrap(); // Writable compound requests - PyUntypedBufferView::with_flags(&array, PyBufferFlags::full(), |view| { + PyUntypedBufferView::with_flags(&array, PyBufferRequest::full(), |view| { assert_eq!(view.format().to_str().unwrap(), "f"); assert_eq!(view.shape(), [4]); assert_eq!(view.strides(), [4]); @@ -2193,21 +2216,21 @@ mod tests { }) .unwrap(); - PyUntypedBufferView::with_flags(&array, PyBufferFlags::records(), |view| { + PyUntypedBufferView::with_flags(&array, PyBufferRequest::records(), |view| { assert_eq!(view.format().to_str().unwrap(), "f"); assert_eq!(view.shape(), [4]); assert!(!view.readonly()); }) .unwrap(); - PyUntypedBufferView::with_flags(&array, PyBufferFlags::strided(), |view| { + PyUntypedBufferView::with_flags(&array, PyBufferRequest::strided(), |view| { assert_eq!(view.shape(), [4]); assert_eq!(view.strides(), [4]); assert!(!view.readonly()); }) .unwrap(); - PyUntypedBufferView::with_flags(&array, PyBufferFlags::contig(), |view| { + PyUntypedBufferView::with_flags(&array, PyBufferRequest::contig(), |view| { assert_eq!(view.shape(), [4]); assert!(!view.readonly()); assert!(view.is_c_contiguous()); @@ -2217,7 +2240,7 @@ mod tests { // Compound + contiguity PyUntypedBufferView::with_flags( &bytes, - PyBufferFlags::full_ro().c_contiguous(), + PyBufferRequest::full_ro().c_contiguous(), |view| { assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); @@ -2226,22 +2249,30 @@ mod tests { ) .unwrap(); - PyUntypedBufferView::with_flags(&array, PyBufferFlags::full().c_contiguous(), |view| { - assert_eq!(view.format().to_str().unwrap(), "f"); - assert!(!view.readonly()); - }) + PyUntypedBufferView::with_flags( + &array, + PyBufferRequest::full().c_contiguous(), + |view| { + assert_eq!(view.format().to_str().unwrap(), "f"); + assert!(!view.readonly()); + }, + ) .unwrap(); - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::strided_ro().format(), |view| { - assert_eq!(view.format().to_str().unwrap(), "B"); - assert_eq!(view.shape(), [5]); - assert_eq!(view.strides(), [1]); - }) + PyUntypedBufferView::with_flags( + &bytes, + PyBufferRequest::strided_ro().format(), + |view| { + assert_eq!(view.format().to_str().unwrap(), "B"); + assert_eq!(view.shape(), [5]); + assert_eq!(view.strides(), [1]); + }, + ) .unwrap(); PyUntypedBufferView::with_flags( &bytes, - PyBufferFlags::simple().c_contiguous().format(), + PyBufferRequest::simple().c_contiguous().format(), |view| { assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.shape(), [5]); @@ -2250,20 +2281,20 @@ mod tests { .unwrap(); // Contiguity builder on typed view - PyBufferView::::with_flags(&bytes, PyBufferFlags::simple().format(), |view| { + PyBufferView::::with_flags(&bytes, PyBufferRequest::simple().format(), |view| { assert_eq!(view.format().to_str().unwrap(), "B"); assert_eq!(view.item_count(), 5); }) .unwrap(); - PyBufferView::::with_flags(&array, PyBufferFlags::contig(), |view| { + PyBufferView::::with_flags(&array, PyBufferRequest::contig(), |view| { let slice = view.as_contiguous_slice(py); assert_eq!(slice[0].get(), 1.0); }) .unwrap(); // Writable + contiguity on typed view - PyBufferView::::with_flags(&array, PyBufferFlags::contig(), |view| { + PyBufferView::::with_flags(&array, PyBufferRequest::contig(), |view| { let slice = view.as_contiguous_slice(py); assert_eq!(slice[0].get(), 1.0); let mut_slice = view.as_contiguous_mut_slice(py); @@ -2273,7 +2304,7 @@ mod tests { .unwrap(); // SIMPLE format() returns "B" - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::simple(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::simple(), |view| { assert_eq!(view.format().to_str().unwrap(), "B"); }) .unwrap(); @@ -2286,7 +2317,7 @@ mod tests { let bytes = PyBytes::new(py, b"abcde"); // Debug always uses raw_format/raw_shape/raw_strides (Option in output) - PyUntypedBufferView::with_flags(&bytes, PyBufferFlags::full_ro(), |view| { + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::full_ro(), |view| { let expected = format!( concat!( "PyUntypedBufferView {{ buf: {:?}, obj: {:?}, ", From 3a0674038cabd9d9fb01837b43c47b118febaf6b Mon Sep 17 00:00:00 2001 From: winstxnhdw Date: Fri, 17 Apr 2026 04:33:47 +0800 Subject: [PATCH 17/17] refactor: add diagnostic traits --- src/buffer.rs | 515 +++++++++++------- tests/test_compile_error.rs | 9 + tests/ui/invalid_buffer_flags.rs | 13 + tests/ui/invalid_buffer_flags.stderr | 17 + tests/ui/invalid_buffer_flags_contiguity.rs | 5 + .../ui/invalid_buffer_flags_contiguity.stderr | 19 + .../invalid_buffer_flags_duplicate_format.rs | 5 + ...valid_buffer_flags_duplicate_format.stderr | 19 + tests/ui/invalid_buffer_flags_indirect.rs | 5 + tests/ui/invalid_buffer_flags_indirect.stderr | 19 + 10 files changed, 444 insertions(+), 182 deletions(-) create mode 100644 tests/ui/invalid_buffer_flags.rs create mode 100644 tests/ui/invalid_buffer_flags.stderr create mode 100644 tests/ui/invalid_buffer_flags_contiguity.rs create mode 100644 tests/ui/invalid_buffer_flags_contiguity.stderr create mode 100644 tests/ui/invalid_buffer_flags_duplicate_format.rs create mode 100644 tests/ui/invalid_buffer_flags_duplicate_format.stderr create mode 100644 tests/ui/invalid_buffer_flags_indirect.rs create mode 100644 tests/ui/invalid_buffer_flags_indirect.stderr diff --git a/src/buffer.rs b/src/buffer.rs index e83527271b3..4814e14eb6d 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -832,6 +832,223 @@ mod py_buffer_flags { const CONTIGUITY: u8, >; + #[diagnostic::on_unimplemented( + message = "format information has already been requested for this buffer request", + note = "remove the extra `.format()` call" + )] + pub trait CanRequestFormat {} + #[diagnostic::do_not_recommend] + impl< + const SHAPE: bool, + const STRIDE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + > CanRequestFormat for PyBufferFlags + { + } + + #[diagnostic::on_unimplemented( + message = "shape information has already been requested for this buffer request", + note = "remove the extra `.nd()` call" + )] + pub trait CanRequestShape {} + #[diagnostic::do_not_recommend] + impl< + const FORMAT: bool, + const STRIDE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + > CanRequestShape for PyBufferFlags + { + } + + #[diagnostic::on_unimplemented( + message = "stride information has already been requested for this buffer request", + note = "remove the extra `.strides()` call" + )] + pub trait CanRequestStrides {} + #[diagnostic::do_not_recommend] + impl< + const FORMAT: bool, + const SHAPE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + > CanRequestStrides + for PyBufferFlags + { + } + + #[diagnostic::on_unimplemented( + message = "suboffsets can only be requested on a direct unconstrained buffer request", + note = "call `.indirect()` before any contiguity builder, and only once" + )] + pub trait CanRequestIndirect {} + #[diagnostic::do_not_recommend] + impl + CanRequestIndirect + for PyBufferFlags + { + } + + #[diagnostic::on_unimplemented( + message = "writability has already been requested for this buffer request", + note = "remove the extra `.writable()` call" + )] + pub trait CanRequestWritable {} + #[diagnostic::do_not_recommend] + impl< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const INDIRECT: bool, + const CONTIGUITY: u8, + > CanRequestWritable for PyBufferFlags + { + } + + #[diagnostic::on_unimplemented( + message = "contiguity has already been constrained for this buffer request", + note = "only one of `.c_contiguous()`, `.f_contiguous()`, or `.any_contiguous()` may be used" + )] + pub trait CanRequestContiguity {} + #[diagnostic::do_not_recommend] + impl< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + > CanRequestContiguity + for PyBufferFlags< + FORMAT, + SHAPE, + STRIDE, + INDIRECT, + WRITABLE, + { super::CONTIGUITY_UNDEFINED }, + > + { + } + + pub trait GuaranteesWritable {} + #[diagnostic::do_not_recommend] + impl< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const INDIRECT: bool, + const CONTIGUITY: u8, + > GuaranteesWritable for PyBufferFlags + { + } + + pub trait GuaranteesCContiguous {} + #[diagnostic::do_not_recommend] + impl< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + > GuaranteesCContiguous + for PyBufferFlags + { + } + + pub trait GuaranteesFContiguous {} + #[diagnostic::do_not_recommend] + impl< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + > GuaranteesFContiguous + for PyBufferFlags + { + } + + /// Marker trait for buffer flags which have requested format information. + #[diagnostic::on_unimplemented( + message = "format information is not available with the requested buffer flags", + note = "use `.format()` when building a buffer request to request format information", + note = "`PyBufferRequest::simple()` and `PyBufferRequest::simple().writable()` also imply u8 format" + )] + pub trait IncludesFormat { + const ASSUME_U8: bool; + } + + #[diagnostic::do_not_recommend] + impl< + const SHAPE: bool, + const STRIDE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + > IncludesFormat for PyBufferFlags + { + const ASSUME_U8: bool = false; + } + + // Simple (maybe writable) buffers also have an implied u8 format. + #[diagnostic::do_not_recommend] + impl IncludesFormat + for PyBufferFlags + { + const ASSUME_U8: bool = true; + } + + #[diagnostic::on_unimplemented( + message = "shape information is not available with the requested buffer flags", + note = "use `.nd()` when building a buffer request to request shape information" + )] + pub trait IncludesShape {} + #[diagnostic::do_not_recommend] + impl< + const FORMAT: bool, + const STRIDE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + > IncludesShape for PyBufferFlags + { + } + + #[diagnostic::on_unimplemented( + message = "strides information is not available with the requested buffer flags", + note = "use `.strides()` when building a buffer request to request stride information" + )] + pub trait IncludesStrides {} + #[diagnostic::do_not_recommend] + impl< + const FORMAT: bool, + const SHAPE: bool, + const INDIRECT: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + > IncludesStrides for PyBufferFlags + { + } + + #[diagnostic::on_unimplemented( + message = "suboffsets information is not available with the requested buffer flags", + note = "use `.indirect()` when building a buffer request to request suboffset information" + )] + pub trait IncludesSuboffsets {} + #[diagnostic::do_not_recommend] + impl< + const FORMAT: bool, + const SHAPE: bool, + const STRIDE: bool, + const WRITABLE: bool, + const CONTIGUITY: u8, + > IncludesSuboffsets for PyBufferFlags + { + } + pub trait Sealed {} impl< const FORMAT: bool, @@ -854,6 +1071,35 @@ pub trait PyBufferRequestType: py_buffer_flags::Sealed { /// Whether these flags require a writable buffer. const WRITABLE: bool; + + /// The state after requesting format information. + type WithFormat: PyBufferRequestType + py_buffer_flags::IncludesFormat; + + /// The state after requesting shape information. + type WithShape: PyBufferRequestType + py_buffer_flags::IncludesShape; + + /// The state after requesting strides information. + type WithStrides: PyBufferRequestType + + py_buffer_flags::IncludesShape + + py_buffer_flags::IncludesStrides; + + /// The state after requesting indirect / suboffset information. + type WithIndirect: PyBufferRequestType + + py_buffer_flags::IncludesShape + + py_buffer_flags::IncludesStrides + + py_buffer_flags::IncludesSuboffsets; + + /// The state after requesting writability. + type WithWritable: PyBufferRequestType; + + /// The state after requesting C contiguity. + type WithCContiguous: PyBufferRequestType; + + /// The state after requesting Fortran contiguity. + type WithFContiguous: PyBufferRequestType; + + /// The state after requesting either C or Fortran contiguity. + type WithAnyContiguous: PyBufferRequestType; } impl< @@ -868,104 +1114,78 @@ impl< { const CONTIGUITY: u8 = CONTIGUITY_REQ; const WRITABLE: bool = WRITABLE; + + type WithFormat = RequestFlags; + type WithShape = RequestFlags; + type WithStrides = RequestFlags; + type WithIndirect = RequestFlags; + type WithWritable = RequestFlags; + type WithCContiguous = RequestFlags; + type WithFContiguous = RequestFlags; + type WithAnyContiguous = RequestFlags; } -impl< - const SHAPE: bool, - const STRIDE: bool, - const INDIRECT: bool, - const WRITABLE: bool, - const CONTIGUITY: u8, - > PyBufferRequest> +impl PyBufferRequest +where + Flags: PyBufferRequestType + py_buffer_flags::CanRequestFormat, { /// Request format information. - pub const fn format( - self, - ) -> PyBufferRequest> { + pub const fn format(self) -> PyBufferRequest { PyBufferRequest(self.0 | ffi::PyBUF_FORMAT, PhantomData) } } -impl< - const FORMAT: bool, - const STRIDE: bool, - const INDIRECT: bool, - const WRITABLE: bool, - const CONTIGUITY: u8, - > PyBufferRequest> +impl PyBufferRequest +where + Flags: PyBufferRequestType + py_buffer_flags::CanRequestShape, { /// Request shape information. - pub const fn nd( - self, - ) -> PyBufferRequest> { + pub const fn nd(self) -> PyBufferRequest { PyBufferRequest(self.0 | ffi::PyBUF_ND, PhantomData) } } -impl< - const FORMAT: bool, - const SHAPE: bool, - const INDIRECT: bool, - const WRITABLE: bool, - const CONTIGUITY: u8, - > PyBufferRequest> +impl PyBufferRequest +where + Flags: PyBufferRequestType + py_buffer_flags::CanRequestStrides, { /// Request strides information. Implies shape. - pub const fn strides( - self, - ) -> PyBufferRequest> { + pub const fn strides(self) -> PyBufferRequest { PyBufferRequest(self.0 | ffi::PyBUF_STRIDES, PhantomData) } } -impl - PyBufferRequest> +impl PyBufferRequest +where + Flags: PyBufferRequestType + py_buffer_flags::CanRequestIndirect, { /// Request suboffsets (indirect). Implies shape and strides. - pub const fn indirect( - self, - ) -> PyBufferRequest> - { + pub const fn indirect(self) -> PyBufferRequest { PyBufferRequest(self.0 | ffi::PyBUF_INDIRECT, PhantomData) } } -impl< - const FORMAT: bool, - const SHAPE: bool, - const STRIDE: bool, - const INDIRECT: bool, - const CONTIGUITY: u8, - > PyBufferRequest> +impl PyBufferRequest +where + Flags: PyBufferRequestType + py_buffer_flags::CanRequestWritable, { /// Request a writable buffer. - pub const fn writable( - self, - ) -> PyBufferRequest> { + pub const fn writable(self) -> PyBufferRequest { PyBufferRequest(self.0 | ffi::PyBUF_WRITABLE, PhantomData) } } -impl< - const FORMAT: bool, - const SHAPE: bool, - const STRIDE: bool, - const INDIRECT: bool, - const WRITABLE: bool, - > - PyBufferRequest> +impl PyBufferRequest +where + Flags: PyBufferRequestType + py_buffer_flags::CanRequestContiguity, { /// Require C-contiguous layout. Implies shape and strides. - pub const fn c_contiguous( - self, - ) -> PyBufferRequest> { + pub const fn c_contiguous(self) -> PyBufferRequest { PyBufferRequest(self.0 | ffi::PyBUF_C_CONTIGUOUS, PhantomData) } /// Require Fortran-contiguous layout. Implies shape and strides. - pub const fn f_contiguous( - self, - ) -> PyBufferRequest> { + pub const fn f_contiguous(self) -> PyBufferRequest { PyBufferRequest(self.0 | ffi::PyBUF_F_CONTIGUOUS, PhantomData) } @@ -973,9 +1193,7 @@ impl< /// /// The specific contiguity order is not known at compile time, /// so this does not unlock non-Option slice accessors. - pub const fn any_contiguous( - self, - ) -> PyBufferRequest> { + pub const fn any_contiguous(self) -> PyBufferRequest { PyBufferRequest(self.0 | ffi::PyBUF_ANY_CONTIGUOUS, PhantomData) } } @@ -1127,97 +1345,79 @@ impl PyUntypedBufferView { } } -impl< - const SHAPE: bool, - const STRIDE: bool, - const INDIRECT: bool, - const WRITABLE: bool, - const CONTIGUITY: u8, - > PyUntypedBufferView> -{ +impl PyUntypedBufferView { /// A [struct module style](https://docs.python.org/3/c-api/buffer.html#c.Py_buffer.format) /// string describing the contents of a single item. #[inline] - pub fn format(&self) -> &CStr { + pub fn format(&self) -> &CStr + where + Flags: py_buffer_flags::IncludesFormat, + { + if Flags::ASSUME_U8 { + return ffi::c_str!("B"); + } + debug_assert!(!self.raw.format.is_null()); unsafe { CStr::from_ptr(self.raw.format) } } /// Attempt to interpret this untyped view as containing elements of type `T`. - pub fn as_typed( - &self, - ) -> PyResult<&PyBufferView>> + pub fn as_typed(&self) -> PyResult<&PyBufferView> + where + Flags: py_buffer_flags::IncludesFormat, { self.ensure_compatible_with::()?; // SAFETY: PyBufferView is repr(transparent) around PyUntypedBufferView<..> Ok(unsafe { NonNull::from(self) - .cast::, - >>() + .cast::>() .as_ref() }) } - fn ensure_compatible_with(&self) -> PyResult<()> { + fn ensure_compatible_with(&self) -> PyResult<()> + where + Flags: py_buffer_flags::IncludesFormat, + { check_buffer_compatibility::(self.raw.buf, self.item_size(), self.format()) } -} -impl< - const FORMAT: bool, - const STRIDE: bool, - const INDIRECT: bool, - const WRITABLE: bool, - const CONTIGUITY: u8, - > PyUntypedBufferView> -{ /// Returns the shape array. `shape[i]` is the length of dimension `i`. /// /// Despite Python using an array of signed integers, the values are guaranteed to be /// non-negative. However, dimensions of length 0 are possible and might need special /// attention. #[inline] - pub fn shape(&self) -> &[usize] { + pub fn shape(&self) -> &[usize] + where + Flags: py_buffer_flags::IncludesShape, + { debug_assert!(!self.raw.shape.is_null()); unsafe { slice::from_raw_parts(self.raw.shape.cast(), self.raw.ndim as usize) } } -} -impl< - const FORMAT: bool, - const SHAPE: bool, - const INDIRECT: bool, - const WRITABLE: bool, - const CONTIGUITY: u8, - > PyUntypedBufferView> -{ /// Returns the strides array. /// /// Stride values can be any integer. For regular arrays, strides are usually positive, /// but a consumer MUST be able to handle the case `strides[n] <= 0`. #[inline] - pub fn strides(&self) -> &[isize] { + pub fn strides(&self) -> &[isize] + where + Flags: py_buffer_flags::IncludesStrides, + { debug_assert!(!self.raw.strides.is_null()); unsafe { slice::from_raw_parts(self.raw.strides, self.raw.ndim as usize) } } -} -impl< - const FORMAT: bool, - const SHAPE: bool, - const STRIDE: bool, - const WRITABLE: bool, - const CONTIGUITY: u8, - > PyUntypedBufferView> -{ /// Returns the suboffsets array. /// /// May return `None` even when suboffsets were requested if the exporter sets /// `suboffsets` to `NULL`. #[inline] - pub fn suboffsets(&self) -> Option<&[isize]> { + pub fn suboffsets(&self) -> Option<&[isize]> + where + Flags: py_buffer_flags::IncludesSuboffsets, + { if self.raw.suboffsets.is_null() { return None; } @@ -1226,17 +1426,6 @@ impl< } } -// SIMPLE and WRITABLE requests guarantee the implicit "B" format. -impl - PyUntypedBufferView> -{ - /// Returns the format string for a simple byte buffer, which is always `"B"`. - #[inline] - pub fn format(&self) -> &CStr { - ffi::c_str!("B") - } -} - /// Check that a buffer is compatible with element type `T`. fn check_buffer_compatibility( buf: *mut c_void, @@ -1269,22 +1458,10 @@ impl PyUntypedBufferView { /// /// The requested flags constrain what exporters may return. For example, without shape /// information only 1-dimensional buffers are permitted. - pub fn with_flags< - const FORMAT: bool, - const SHAPE: bool, - const STRIDE: bool, - const INDIRECT: bool, - const WRITABLE: bool, - const CONTIGUITY: u8, - R, - >( + pub fn with_flags( obj: &Bound<'_, PyAny>, - flags: PyBufferRequest>, - f: impl FnOnce( - &PyUntypedBufferView< - RequestFlags, - >, - ) -> R, + flags: PyBufferRequest, + f: impl FnOnce(&PyUntypedBufferView) -> R, ) -> PyResult { let mut raw = mem::MaybeUninit::::uninit(); @@ -1327,20 +1504,10 @@ impl PyBufferView { /// [`ffi::PyBUF_FORMAT`] is implicitly added for type validation. As with /// [`PyUntypedBufferView::with_flags`], the requested flags also constrain what exporters /// may return. - pub fn with_flags< - const FORMAT: bool, - const SHAPE: bool, - const STRIDE: bool, - const INDIRECT: bool, - const WRITABLE: bool, - const CONTIGUITY: u8, - R, - >( + pub fn with_flags( obj: &Bound<'_, PyAny>, - flags: PyBufferRequest>, - f: impl FnOnce( - &PyBufferView>, - ) -> R, + flags: PyBufferRequest, + f: impl FnOnce(&PyBufferView) -> R, ) -> PyResult { let mut raw = mem::MaybeUninit::::uninit(); @@ -1348,9 +1515,7 @@ impl PyBufferView { ffi::PyObject_GetBuffer(obj.as_ptr(), raw.as_mut_ptr(), flags.0 | ffi::PyBUF_FORMAT) })?; - let view = PyUntypedBufferView::< - RequestFlags, - > { + let view = PyUntypedBufferView:: { raw: unsafe { raw.assume_init() }, _flags: PhantomData, }; @@ -1390,14 +1555,9 @@ impl PyBufferView { } // C-contiguous guaranteed — no contiguity check needed. -impl< - T: Element, - const FORMAT: bool, - const SHAPE: bool, - const STRIDE: bool, - const INDIRECT: bool, - const WRITABLE: bool, - > PyBufferView> +impl PyBufferView +where + Flags: PyBufferRequestType + py_buffer_flags::GuaranteesCContiguous, { /// Gets the buffer memory as a slice. The buffer is guaranteed C-contiguous. pub fn as_contiguous_slice<'a>(&'a self, _py: Python<'a>) -> &'a [ReadOnlyCell] { @@ -1406,13 +1566,11 @@ impl< } // C-contiguous + writable guaranteed — no checks needed. -impl< - T: Element, - const FORMAT: bool, - const SHAPE: bool, - const STRIDE: bool, - const INDIRECT: bool, - > PyBufferView> +impl PyBufferView +where + Flags: PyBufferRequestType + + py_buffer_flags::GuaranteesCContiguous + + py_buffer_flags::GuaranteesWritable, { /// Gets the buffer memory as a mutable slice. /// The buffer is guaranteed C-contiguous and writable. @@ -1422,14 +1580,9 @@ impl< } // Fortran-contiguous guaranteed. -impl< - T: Element, - const FORMAT: bool, - const SHAPE: bool, - const STRIDE: bool, - const INDIRECT: bool, - const WRITABLE: bool, - > PyBufferView> +impl PyBufferView +where + Flags: PyBufferRequestType + py_buffer_flags::GuaranteesFContiguous, { /// Gets the buffer memory as a slice. The buffer is guaranteed Fortran-contiguous. pub fn as_fortran_contiguous_slice<'a>(&'a self, _py: Python<'a>) -> &'a [ReadOnlyCell] { @@ -1438,13 +1591,11 @@ impl< } // Fortran-contiguous + writable guaranteed. -impl< - T: Element, - const FORMAT: bool, - const SHAPE: bool, - const STRIDE: bool, - const INDIRECT: bool, - > PyBufferView> +impl PyBufferView +where + Flags: PyBufferRequestType + + py_buffer_flags::GuaranteesFContiguous + + py_buffer_flags::GuaranteesWritable, { /// Gets the buffer memory as a mutable slice. /// The buffer is guaranteed Fortran-contiguous and writable. diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index e5b646239d7..572eca52dca 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -96,4 +96,13 @@ fn test_compile_errors() { t.pass("tests/ui/pyclass_probe.rs"); t.compile_fail("tests/ui/invalid_pyfunction_warn.rs"); t.compile_fail("tests/ui/invalid_pymethods_warn.rs"); + // `pyo3::buffer` is unavailable with abi3 before Python 3.11. + #[cfg(any(not(Py_LIMITED_API), Py_3_11))] + t.compile_fail("tests/ui/invalid_buffer_flags_contiguity.rs"); + #[cfg(any(not(Py_LIMITED_API), Py_3_11))] + t.compile_fail("tests/ui/invalid_buffer_flags_duplicate_format.rs"); + #[cfg(any(not(Py_LIMITED_API), Py_3_11))] + t.compile_fail("tests/ui/invalid_buffer_flags.rs"); + #[cfg(any(not(Py_LIMITED_API), Py_3_11))] + t.compile_fail("tests/ui/invalid_buffer_flags_indirect.rs"); } diff --git a/tests/ui/invalid_buffer_flags.rs b/tests/ui/invalid_buffer_flags.rs new file mode 100644 index 00000000000..6cae9e5c14a --- /dev/null +++ b/tests/ui/invalid_buffer_flags.rs @@ -0,0 +1,13 @@ +use pyo3::buffer::{PyBufferRequest, PyUntypedBufferView}; +use pyo3::prelude::*; +use pyo3::types::PyBytes; + +fn main() { + Python::attach(|py| { + let bytes = PyBytes::new(py, &[1, 2, 3]); + PyUntypedBufferView::with_flags(&bytes, PyBufferRequest::strided(), |view| { + view.format(); + }) + .unwrap(); + }); +} diff --git a/tests/ui/invalid_buffer_flags.stderr b/tests/ui/invalid_buffer_flags.stderr new file mode 100644 index 00000000000..d6dc840151a --- /dev/null +++ b/tests/ui/invalid_buffer_flags.stderr @@ -0,0 +1,17 @@ +error[E0277]: format information is not available with the requested buffer flags + --> tests/ui/invalid_buffer_flags.rs:9:18 + | +9 | view.format(); + | ^^^^^^ unsatisfied trait bound + | + = help: the trait `buffer::py_buffer_flags::IncludesFormat` is not implemented for `buffer::py_buffer_flags::PyBufferFlags` + = note: use `.format()` when building a buffer request to request format information + = note: `PyBufferRequest::simple()` and `PyBufferRequest::simple().writable()` also imply u8 format +note: required by a bound in `PyUntypedBufferView::::format` + --> src/buffer.rs + | + | pub fn format(&self) -> &CStr + | ------ required by a bound in this associated function + | where + | Flags: py_buffer_flags::IncludesFormat, + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `PyUntypedBufferView::::format` diff --git a/tests/ui/invalid_buffer_flags_contiguity.rs b/tests/ui/invalid_buffer_flags_contiguity.rs new file mode 100644 index 00000000000..6dd86f871b2 --- /dev/null +++ b/tests/ui/invalid_buffer_flags_contiguity.rs @@ -0,0 +1,5 @@ +use pyo3::buffer::PyBufferRequest; + +fn main() { + let _ = PyBufferRequest::simple().c_contiguous().f_contiguous(); +} diff --git a/tests/ui/invalid_buffer_flags_contiguity.stderr b/tests/ui/invalid_buffer_flags_contiguity.stderr new file mode 100644 index 00000000000..a620306f102 --- /dev/null +++ b/tests/ui/invalid_buffer_flags_contiguity.stderr @@ -0,0 +1,19 @@ +error[E0599]: the method `f_contiguous` exists for struct `PyBufferRequest>`, but its trait bounds were not satisfied + --> tests/ui/invalid_buffer_flags_contiguity.rs:4:54 + | +4 | let _ = PyBufferRequest::simple().c_contiguous().f_contiguous(); + | ^^^^^^^^^^^^ method cannot be called due to unsatisfied trait bounds + | + ::: src/buffer.rs + | + | / pub struct PyBufferFlags< + | | const FORMAT: bool, + | | const SHAPE: bool, + | | const STRIDE: bool, +... | + | | const CONTIGUITY: u8, + | | >; + | |_____- doesn't satisfy `_: CanRequestContiguity` + | + = note: the following trait bounds were not satisfied: + `buffer::py_buffer_flags::PyBufferFlags: buffer::py_buffer_flags::CanRequestContiguity` diff --git a/tests/ui/invalid_buffer_flags_duplicate_format.rs b/tests/ui/invalid_buffer_flags_duplicate_format.rs new file mode 100644 index 00000000000..d7fbbc063b4 --- /dev/null +++ b/tests/ui/invalid_buffer_flags_duplicate_format.rs @@ -0,0 +1,5 @@ +use pyo3::buffer::PyBufferRequest; + +fn main() { + let _ = PyBufferRequest::strided().format().format(); +} diff --git a/tests/ui/invalid_buffer_flags_duplicate_format.stderr b/tests/ui/invalid_buffer_flags_duplicate_format.stderr new file mode 100644 index 00000000000..23833a66b5a --- /dev/null +++ b/tests/ui/invalid_buffer_flags_duplicate_format.stderr @@ -0,0 +1,19 @@ +error[E0599]: the method `format` exists for struct `PyBufferRequest>`, but its trait bounds were not satisfied + --> tests/ui/invalid_buffer_flags_duplicate_format.rs:4:49 + | +4 | let _ = PyBufferRequest::strided().format().format(); + | ^^^^^^ method cannot be called due to unsatisfied trait bounds + | + ::: src/buffer.rs + | + | / pub struct PyBufferFlags< + | | const FORMAT: bool, + | | const SHAPE: bool, + | | const STRIDE: bool, +... | + | | const CONTIGUITY: u8, + | | >; + | |_____- doesn't satisfy `_: CanRequestFormat` + | + = note: the following trait bounds were not satisfied: + `buffer::py_buffer_flags::PyBufferFlags: buffer::py_buffer_flags::CanRequestFormat` diff --git a/tests/ui/invalid_buffer_flags_indirect.rs b/tests/ui/invalid_buffer_flags_indirect.rs new file mode 100644 index 00000000000..d3a03f62c34 --- /dev/null +++ b/tests/ui/invalid_buffer_flags_indirect.rs @@ -0,0 +1,5 @@ +use pyo3::buffer::PyBufferRequest; + +fn main() { + let _ = PyBufferRequest::simple().indirect().indirect(); +} diff --git a/tests/ui/invalid_buffer_flags_indirect.stderr b/tests/ui/invalid_buffer_flags_indirect.stderr new file mode 100644 index 00000000000..e38e6edf8e9 --- /dev/null +++ b/tests/ui/invalid_buffer_flags_indirect.stderr @@ -0,0 +1,19 @@ +error[E0599]: the method `indirect` exists for struct `PyBufferRequest>`, but its trait bounds were not satisfied + --> tests/ui/invalid_buffer_flags_indirect.rs:4:50 + | +4 | let _ = PyBufferRequest::simple().indirect().indirect(); + | ^^^^^^^^ method cannot be called due to unsatisfied trait bounds + | + ::: src/buffer.rs + | + | / pub struct PyBufferFlags< + | | const FORMAT: bool, + | | const SHAPE: bool, + | | const STRIDE: bool, +... | + | | const CONTIGUITY: u8, + | | >; + | |_____- doesn't satisfy `_: CanRequestIndirect` + | + = note: the following trait bounds were not satisfied: + `buffer::py_buffer_flags::PyBufferFlags: buffer::py_buffer_flags::CanRequestIndirect`