diff --git a/src/lib.rs b/src/lib.rs index 2e408eb7b34..5ff3db0cf0a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -351,6 +351,7 @@ pub use crate::pyclass::{PyClass, PyClassGuard, PyClassGuardMut}; pub use crate::pyclass_init::PyClassInitializer; pub use crate::type_object::{PyTypeCheck, PyTypeInfo}; pub use crate::types::PyAny; +pub use crate::unpack::Unpackable; pub use crate::version::PythonVersionInfo; pub(crate) mod ffi_ptr_ext; @@ -424,6 +425,7 @@ mod instance; mod interpreter_lifecycle; pub mod marker; pub mod marshal; +mod unpack; #[macro_use] pub mod sync; pub(crate) mod byteswriter; diff --git a/src/unpack.rs b/src/unpack.rs new file mode 100644 index 00000000000..ebdfe6524c3 --- /dev/null +++ b/src/unpack.rs @@ -0,0 +1,66 @@ +use crate::{ + exceptions::PyValueError, + types::{PyAnyMethods, PyIterator}, + Borrowed, Bound, FromPyObject, PyAny, PyResult, +}; + +/// TODO +pub trait Unpackable<'py>: Sized { + /// TODO + fn unpack(obj: Borrowed<'_, 'py, PyAny>) -> PyResult; +} + +fn get_value<'py, T>(iter: &mut Bound<'py, PyIterator>, expected: usize) -> PyResult +where + T: for<'a> FromPyObject<'a, 'py>, +{ + let Some(item) = iter.next() else { + return Err(PyValueError::new_err(format!( + "not enough values to unpack (expected {expected})", + ))); + }; + match item?.extract::() { + Ok(v) => Ok(v), + Err(e) => return Err(e.into()), + } +} + +fn one() -> usize { + 1 +} + +macro_rules! tuple_impls { + ($T:ident $num:literal) => { + tuple_impls!(@impl $T $num); + }; + ($T:ident $num:literal $( $U:ident $unum:literal )+) => { + tuple_impls!($( $U $unum )+); + tuple_impls!(@impl $T $num $( $U $unum )+); + }; + (@impl $( $T:ident $num:literal )+) => { + impl<'py, $($T,)+> Unpackable<'py> for ($($T,)+) + where + $($T: for<'a> FromPyObject<'a, 'py>),+ + { + fn unpack(obj: Borrowed<'_, 'py, PyAny>) -> PyResult { + let total = $(one::<$T>() +)+ 0; + let mut iter = obj.try_iter()?; + let out = ($( + get_value::<$T>(&mut iter, total)?, + )+); + + if iter.next().is_some() { + return Err(PyValueError::new_err(format!( + "too many values to unpack (expected {total})" + ))); + } + + Ok(out) + } + } + }; +} + +tuple_impls! { + T11 11 T10 10 T9 9 T8 8 T7 7 T6 6 T5 5 T4 4 T3 3 T2 2 T1 1 T0 0 +} diff --git a/tests/test_unpack.rs b/tests/test_unpack.rs new file mode 100644 index 00000000000..3730e291ae6 --- /dev/null +++ b/tests/test_unpack.rs @@ -0,0 +1,48 @@ +#![cfg(feature = "macros")] + +use pyo3::prelude::*; +use pyo3::Unpackable; + +mod test_utils; + +#[test] +fn test_unpack_3() { + Python::attach(|py| { + let tuple = (0, 1, 2); + let py_tuple = tuple.into_pyobject(py).unwrap(); + let unpacked: (i32, i32, i32) = + Unpackable::unpack(py_tuple.as_any().as_borrowed()).unwrap(); + + assert_eq!(tuple, unpacked); + }); +} + +#[test] +fn test_unpack_not_enough() { + Python::attach(|py| { + let tuple = (0, 1); + let py_tuple = tuple.into_pyobject(py).unwrap(); + let try_unpack = + <(i32, i32, i32) as Unpackable>::unpack(py_tuple.as_any().as_borrowed()).unwrap_err(); + + assert_eq!( + try_unpack.value(py).to_string(), + "not enough values to unpack (expected 3)" + ); + }); +} + +#[test] +fn test_unpack_too_many() { + Python::attach(|py| { + let tuple = (0, 1, 2); + let py_tuple = tuple.into_pyobject(py).unwrap(); + let try_unpack = + <(i32, i32) as Unpackable>::unpack(py_tuple.as_any().as_borrowed()).unwrap_err(); + + assert_eq!( + try_unpack.value(py).to_string(), + "too many values to unpack (expected 2)" + ); + }); +}