From 5b5b0314f50a595a3646e9d3c43f26ba69f16563 Mon Sep 17 00:00:00 2001 From: Manu Zhang Date: Sun, 22 Mar 2026 00:20:10 +0800 Subject: [PATCH] chore(deps): bump jni from 0.21.1 to 0.22.4 in /native Co-authored-by: Codex --- native/Cargo.lock | 195 ++++++------------ native/core/Cargo.toml | 4 +- .../src/execution/expressions/subquery.rs | 41 ++-- native/core/src/execution/jni_api.rs | 39 ++-- .../src/execution/memory_pools/fair_pool.rs | 12 +- native/core/src/execution/memory_pools/mod.rs | 4 +- .../execution/memory_pools/unified_pool.rs | 12 +- native/core/src/execution/metrics/utils.rs | 4 +- .../src/execution/operators/projection.rs | 4 +- native/core/src/execution/operators/scan.rs | 37 ++-- native/core/src/execution/planner.rs | 6 +- .../execution/planner/operator_registry.rs | 6 +- native/core/src/lib.rs | 2 +- native/core/src/parquet/encryption_support.rs | 32 +-- native/core/src/parquet/mod.rs | 33 +-- native/core/src/parquet/util/jni.rs | 15 +- native/jni-bridge/Cargo.toml | 4 +- native/jni-bridge/src/batch_iterator.rs | 31 ++- native/jni-bridge/src/comet_exec.rs | 73 +++++-- native/jni-bridge/src/comet_metric_node.rs | 28 ++- .../src/comet_task_memory_manager.rs | 19 +- native/jni-bridge/src/errors.rs | 94 ++++----- native/jni-bridge/src/lib.rs | 103 +++++---- 23 files changed, 414 insertions(+), 384 deletions(-) diff --git a/native/Cargo.lock b/native/Cargo.lock index 5f99c614b3..1f49572784 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -128,7 +128,7 @@ dependencies = [ "serde_json", "strum", "strum_macros", - "thiserror 2.0.18", + "thiserror", "uuid", "zstd", ] @@ -1214,12 +1214,6 @@ dependencies = [ "shlex", ] -[[package]] -name = "cesu8" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d43a04d8753f35258c91f8ec639f792891f748a1edbd759cf1dcea3382ad83c" - [[package]] name = "cexpr" version = "0.6.0" @@ -1321,7 +1315,7 @@ checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" dependencies = [ "glob", "libc", - "libloading 0.8.9", + "libloading", ] [[package]] @@ -1885,13 +1879,13 @@ dependencies = [ [[package]] name = "datafusion-comet-common" -version = "0.14.0" +version = "0.15.0" dependencies = [ "arrow", "datafusion", "serde", "serde_json", - "thiserror 2.0.18", + "thiserror", ] [[package]] @@ -1911,7 +1905,7 @@ dependencies = [ [[package]] name = "datafusion-comet-jni-bridge" -version = "0.14.0" +version = "0.15.0" dependencies = [ "arrow", "assertables", @@ -1924,7 +1918,7 @@ dependencies = [ "paste", "prost", "regex", - "thiserror 2.0.18", + "thiserror", ] [[package]] @@ -2728,7 +2722,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.61.2", + "windows-sys 0.52.0", ] [[package]] @@ -3606,7 +3600,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.61.2", + "windows-sys 0.52.0", ] [[package]] @@ -3654,7 +3648,7 @@ dependencies = [ "portable-atomic", "portable-atomic-util", "serde_core", - "windows-sys 0.61.2", + "windows-sys 0.52.0", ] [[package]] @@ -3685,27 +3679,54 @@ dependencies = [ [[package]] name = "jni" -version = "0.21.1" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" +checksum = "5efd9a482cf3a427f00d6b35f14332adc7902ce91efb778580e180ff90fa3498" dependencies = [ - "cesu8", "cfg-if", "combine", "java-locator", + "jni-macros", "jni-sys", - "libloading 0.7.4", + "libloading", "log", - "thiserror 1.0.69", + "simd_cesu8", + "thiserror", "walkdir", - "windows-sys 0.45.0", + "windows-link", +] + +[[package]] +name = "jni-macros" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a00109accc170f0bdb141fed3e393c565b6f5e072365c3bd58f5b062591560a3" +dependencies = [ + "proc-macro2", + "quote", + "rustc_version", + "simd_cesu8", + "syn 2.0.117", ] [[package]] name = "jni-sys" -version = "0.3.0" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" +dependencies = [ + "jni-sys-macros", +] + +[[package]] +name = "jni-sys-macros" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" +dependencies = [ + "quote", + "syn 2.0.117", +] [[package]] name = "jobserver" @@ -3841,16 +3862,6 @@ version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" -[[package]] -name = "libloading" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" -dependencies = [ - "cfg-if", - "winapi", -] - [[package]] name = "libloading" version = "0.8.9" @@ -3962,7 +3973,7 @@ dependencies = [ "serde-value", "serde_json", "serde_yaml", - "thiserror 2.0.18", + "thiserror", "thread-id", "typemap-ors", "unicode-segmentation", @@ -4258,7 +4269,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "thiserror 2.0.18", + "thiserror", "tokio", "tracing", "url", @@ -4734,7 +4745,7 @@ dependencies = [ "spin 0.10.0", "symbolic-demangle", "tempfile", - "thiserror 2.0.18", + "thiserror", ] [[package]] @@ -4889,7 +4900,7 @@ dependencies = [ "rustc-hash 2.1.1", "rustls", "socket2", - "thiserror 2.0.18", + "thiserror", "tokio", "tracing", "web-time", @@ -4910,7 +4921,7 @@ dependencies = [ "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.18", + "thiserror", "tinyvec", "tracing", "web-time", @@ -5300,7 +5311,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.12.1", - "windows-sys 0.61.2", + "windows-sys 0.52.0", ] [[package]] @@ -5669,6 +5680,16 @@ version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" +[[package]] +name = "simd_cesu8" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94f90157bb87cddf702797c5dadfa0be7d266cdf49e22da2fcaa32eff75b2c33" +dependencies = [ + "rustc_version", + "simdutf8", +] + [[package]] name = "simdutf8" version = "0.1.5" @@ -5683,7 +5704,7 @@ checksum = "0d585997b0ac10be3c5ee635f1bab02d512760d14b7c468801ac8a01d9ae5f1d" dependencies = [ "num-bigint", "num-traits", - "thiserror 2.0.18", + "thiserror", "time", ] @@ -5718,7 +5739,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys 0.61.2", + "windows-sys 0.60.2", ] [[package]] @@ -5893,16 +5914,7 @@ dependencies = [ "getrandom 0.4.2", "once_cell", "rustix 1.1.4", - "windows-sys 0.61.2", -] - -[[package]] -name = "thiserror" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" -dependencies = [ - "thiserror-impl 1.0.69", + "windows-sys 0.52.0", ] [[package]] @@ -5911,18 +5923,7 @@ version = "2.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" dependencies = [ - "thiserror-impl 2.0.18", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.117", + "thiserror-impl", ] [[package]] @@ -6581,7 +6582,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.52.0", ] [[package]] @@ -6649,15 +6650,6 @@ dependencies = [ "windows-link", ] -[[package]] -name = "windows-sys" -version = "0.45.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" -dependencies = [ - "windows-targets 0.42.2", -] - [[package]] name = "windows-sys" version = "0.52.0" @@ -6694,21 +6686,6 @@ dependencies = [ "windows-link", ] -[[package]] -name = "windows-targets" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" -dependencies = [ - "windows_aarch64_gnullvm 0.42.2", - "windows_aarch64_msvc 0.42.2", - "windows_i686_gnu 0.42.2", - "windows_i686_msvc 0.42.2", - "windows_x86_64_gnu 0.42.2", - "windows_x86_64_gnullvm 0.42.2", - "windows_x86_64_msvc 0.42.2", -] - [[package]] name = "windows-targets" version = "0.52.6" @@ -6742,12 +6719,6 @@ dependencies = [ "windows_x86_64_msvc 0.53.1", ] -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" - [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" @@ -6760,12 +6731,6 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" -[[package]] -name = "windows_aarch64_msvc" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" - [[package]] name = "windows_aarch64_msvc" version = "0.52.6" @@ -6778,12 +6743,6 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" -[[package]] -name = "windows_i686_gnu" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" - [[package]] name = "windows_i686_gnu" version = "0.52.6" @@ -6808,12 +6767,6 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" -[[package]] -name = "windows_i686_msvc" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" - [[package]] name = "windows_i686_msvc" version = "0.52.6" @@ -6826,12 +6779,6 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" -[[package]] -name = "windows_x86_64_gnu" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" - [[package]] name = "windows_x86_64_gnu" version = "0.52.6" @@ -6844,12 +6791,6 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" - [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" @@ -6862,12 +6803,6 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" -[[package]] -name = "windows_x86_64_msvc" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" - [[package]] name = "windows_x86_64_msvc" version = "0.52.6" diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 3f305a631d..3a2474d0a8 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -46,7 +46,7 @@ async-trait = { workspace = true } log = "0.4" log4rs = "1.4.0" prost = "0.14.3" -jni = "0.21" +jni = "0.22.4" snap = "1.1" # we disable default features in lz4_flex to force the use of the faster unsafe encoding and decoding implementation lz4_flex = { version = "0.13.0", default-features = false, features = ["frame"] } @@ -91,7 +91,7 @@ hdrs = { version = "0.3.2", features = ["vendored"] } [dev-dependencies] pprof = { version = "0.15", features = ["flamegraph"] } criterion = { version = "0.7", features = ["async", "async_tokio", "async_std"] } -jni = { version = "0.21", features = ["invocation"] } +jni = { version = "0.22.4", features = ["invocation"] } lazy_static = "1.4" assertables = "9" hex = "0.4.3" diff --git a/native/core/src/execution/expressions/subquery.rs b/native/core/src/execution/expressions/subquery.rs index ad4106c251..8d5b5d53c2 100644 --- a/native/core/src/execution/expressions/subquery.rs +++ b/native/core/src/execution/expressions/subquery.rs @@ -25,7 +25,7 @@ use datafusion::common::{internal_err, ScalarValue}; use datafusion::logical_expr::ColumnarValue; use datafusion::physical_expr::PhysicalExpr; use jni::{ - objects::JByteArray, + objects::{JByteArray, JString}, sys::{jboolean, jbyte, jint, jlong, jshort}, }; use std::{ @@ -81,13 +81,14 @@ impl PhysicalExpr for Subquery { fn evaluate(&self, _: &RecordBatch) -> datafusion::common::Result { let mut env = JVMClasses::get_env()?; + let env = env.borrow_env_mut(); unsafe { - let is_null = jni_static_call!(&mut env, + let is_null = jni_static_call!(env, comet_exec.is_null(self.exec_context_id, self.id) -> jboolean )?; - if is_null > 0 { + if is_null { return Ok(ColumnarValue::Scalar(ScalarValue::try_from( &self.data_type, )?)); @@ -95,53 +96,53 @@ impl PhysicalExpr for Subquery { match &self.data_type { DataType::Boolean => { - let r = jni_static_call!(&mut env, + let r = jni_static_call!(env, comet_exec.get_bool(self.exec_context_id, self.id) -> jboolean )?; - Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r > 0)))) + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r)))) } DataType::Int8 => { - let r = jni_static_call!(&mut env, + let r = jni_static_call!(env, comet_exec.get_byte(self.exec_context_id, self.id) -> jbyte )?; Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(r)))) } DataType::Int16 => { - let r = jni_static_call!(&mut env, + let r = jni_static_call!(env, comet_exec.get_short(self.exec_context_id, self.id) -> jshort )?; Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(r)))) } DataType::Int32 => { - let r = jni_static_call!(&mut env, + let r = jni_static_call!(env, comet_exec.get_int(self.exec_context_id, self.id) -> jint )?; Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(r)))) } DataType::Int64 => { - let r = jni_static_call!(&mut env, + let r = jni_static_call!(env, comet_exec.get_long(self.exec_context_id, self.id) -> jlong )?; Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(r)))) } DataType::Float32 => { - let r = jni_static_call!(&mut env, + let r = jni_static_call!(env, comet_exec.get_float(self.exec_context_id, self.id) -> f32 )?; Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(r)))) } DataType::Float64 => { - let r = jni_static_call!(&mut env, + let r = jni_static_call!(env, comet_exec.get_double(self.exec_context_id, self.id) -> f64 )?; Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(r)))) } DataType::Decimal128(p, s) => { - let bytes = jni_static_call!(&mut env, + let bytes = jni_static_call!(env, comet_exec.get_decimal(self.exec_context_id, self.id) -> BinaryWrapper )?; - let bytes: &JByteArray = bytes.get().into(); + let bytes = JByteArray::from_raw(env, bytes.get().as_raw()); let slice = env.convert_byte_array(bytes).unwrap(); Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( @@ -151,14 +152,14 @@ impl PhysicalExpr for Subquery { ))) } DataType::Date32 => { - let r = jni_static_call!(&mut env, + let r = jni_static_call!(env, comet_exec.get_int(self.exec_context_id, self.id) -> jint )?; Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(r)))) } DataType::Timestamp(TimeUnit::Microsecond, timezone) => { - let r = jni_static_call!(&mut env, + let r = jni_static_call!(env, comet_exec.get_long(self.exec_context_id, self.id) -> jlong )?; @@ -168,18 +169,20 @@ impl PhysicalExpr for Subquery { ))) } DataType::Utf8 => { - let string = jni_static_call!(&mut env, + let string = jni_static_call!(env, comet_exec.get_string(self.exec_context_id, self.id) -> StringWrapper )?; - let string = env.get_string(string.get()).unwrap().into(); + let string = unsafe { JString::from_raw(env, string.get().as_raw()) } + .try_to_string(env) + .unwrap(); Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))) } DataType::Binary => { - let bytes = jni_static_call!(&mut env, + let bytes = jni_static_call!(env, comet_exec.get_binary(self.exec_context_id, self.id) -> BinaryWrapper )?; - let bytes: &JByteArray = bytes.get().into(); + let bytes = JByteArray::from_raw(env, bytes.get().as_raw()); let slice = env.convert_byte_array(bytes).unwrap(); Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(slice)))) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 59ac674431..d5a71e5fba 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -70,7 +70,7 @@ use jni::{ ReleaseMode, }, sys::{jboolean, jdouble, jint, jlong}, - JNIEnv, + Env, JNIEnv, }; use std::collections::HashMap; use std::path::PathBuf; @@ -152,13 +152,13 @@ struct ExecutionContext { /// The input sources for the DataFusion plan pub scans: Vec, /// The global reference of input sources for the DataFusion plan - pub input_sources: Vec>, + pub input_sources: Vec>>>, /// The record batch stream to pull results from pub stream: Option, /// Receives batches from a spawned tokio task (async I/O path) pub batch_receiver: Option>>, /// Native metrics - pub metrics: Arc, + pub metrics: Arc>>, // The interval in milliseconds to update metrics pub metrics_update_interval: Option, // The last update time of metrics @@ -239,7 +239,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( let mut input_sources = vec![]; let num_inputs = env.get_array_length(&iterators)?; for i in 0..num_inputs { - let input_source = env.get_object_array_element(&iterators, i)?; + let input_source = env.get_object_array_element(&iterators, i as usize)?; let input_source = Arc::new(jni_new_global_ref!(env, input_source)?); input_sources.push(input_source); } @@ -268,7 +268,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( let num_local_dirs = env.get_array_length(&local_dirs)?; let mut local_dirs_vec = vec![]; for i in 0..num_local_dirs { - let local_dir: JString = env.get_object_array_element(&local_dirs, i)?.into(); + let local_dir = env.get_object_array_element(&local_dirs, i as usize)?; + let local_dir = unsafe { JString::from_raw(&*env, local_dir.into_raw()) }; let local_dir = env.get_string(&local_dir)?; local_dirs_vec.push(local_dir.into()); } @@ -296,7 +297,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( // Handle key unwrapper for encrypted files if !key_unwrapper_obj.is_null() { let encryption_factory = CometEncryptionFactory { - key_unwrapper: jni_new_global_ref!(env, key_unwrapper_obj)?, + key_unwrapper: Arc::new(jni_new_global_ref!(env, key_unwrapper_obj)?), }; session.runtime_env().register_parquet_encryption_factory( ENCRYPTION_FACTORY_ID, @@ -411,7 +412,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) { /// Prepares arrow arrays for output. fn prepare_output( - env: &mut JNIEnv, + env: &mut Env, array_addrs: JLongArray, schema_addrs: JLongArray, output_batch: RecordBatch, @@ -698,8 +699,7 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan( }) } -/// Updates the metrics of the query plan. -fn update_metrics(env: &mut JNIEnv, exec_context: &mut ExecutionContext) -> CometResult<()> { +fn update_metrics(env: &mut Env, exec_context: &mut ExecutionContext) -> CometResult<()> { if let Some(native_query) = &exec_context.root_op { let metrics = exec_context.metrics.as_obj(); update_comet_metric(env, metrics, native_query) @@ -724,15 +724,15 @@ fn log_plan_metrics(exec_context: &ExecutionContext, stage_id: jint, partition: } fn convert_datatype_arrays( - env: &'_ mut JNIEnv<'_>, + env: &mut Env, serialized_datatypes: JObjectArray, ) -> JNIResult> { let array_len = env.get_array_length(&serialized_datatypes)?; let mut res: Vec = Vec::new(); for i in 0..array_len { - let inner_array = env.get_object_array_element(&serialized_datatypes, i)?; - let inner_array: JByteArray = inner_array.into(); + let inner_array = env.get_object_array_element(&serialized_datatypes, i as usize)?; + let inner_array = unsafe { JByteArray::from_raw(&*env, inner_array.into_raw()) }; let bytes = env.convert_byte_array(inner_array)?; let data_type = serde::deserialize_data_type(bytes.as_slice()).unwrap(); let arrow_dt = to_arrow_datatype(&data_type); @@ -788,7 +788,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative let output_path: String = env.get_string(&file_path).unwrap().into(); - let checksum_enabled = checksum_enabled == 1; + let checksum_enabled = checksum_enabled; let current_checksum = if current_checksum == i64::MIN { // Initial checksum is not available. None @@ -886,7 +886,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock( let length = length as usize; let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; let batch = read_ipc_compressed(slice)?; - prepare_output(&mut env, array_addrs, schema_addrs, batch, false) + prepare_output(env, array_addrs, schema_addrs, batch, false) }) }) } @@ -957,7 +957,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_columnarToRowInit( ) -> jlong { try_unwrap_or_throw(&e, |mut env| { // Deserialize the schema - let schema = convert_datatype_arrays(&mut env, serialized_schema)?; + let schema = convert_datatype_arrays(env, serialized_schema)?; // Create the context let ctx = Box::new(ColumnarToRowContext::new(schema, batch_size as usize)); @@ -1030,17 +1030,18 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_columnarToRowConvert( let (buffer_ptr, offsets, lengths) = ctx.convert(&arrays, num_rows as usize)?; // Create Java int arrays for offsets and lengths - let offsets_array = env.new_int_array(offsets.len() as i32)?; + let offsets_array = env.new_int_array(offsets.len())?; env.set_int_array_region(&offsets_array, 0, offsets)?; - let lengths_array = env.new_int_array(lengths.len() as i32)?; + let lengths_array = env.new_int_array(lengths.len())?; env.set_int_array_region(&lengths_array, 0, lengths)?; // Create the NativeColumnarToRowInfo object - let info_class = env.find_class("org/apache/comet/NativeColumnarToRowInfo")?; + let info_class = + env.find_class(jni::jni_str!("org/apache/comet/NativeColumnarToRowInfo"))?; let info_obj = env.new_object( info_class, - "(J[I[I)V", + jni::jni_sig!("(J[I[I)V"), &[ jni::objects::JValue::Long(buffer_ptr as jlong), jni::objects::JValue::Object(&offsets_array), diff --git a/native/core/src/execution/memory_pools/fair_pool.rs b/native/core/src/execution/memory_pools/fair_pool.rs index 2c25fe9443..9c4decdfdb 100644 --- a/native/core/src/execution/memory_pools/fair_pool.rs +++ b/native/core/src/execution/memory_pools/fair_pool.rs @@ -20,7 +20,7 @@ use std::{ sync::Arc, }; -use jni::objects::GlobalRef; +use jni::objects::{GlobalRef, JObject}; use crate::{errors::CometResult, jvm_bridge::JVMClasses}; use datafusion::common::resources_err; @@ -34,7 +34,7 @@ use parking_lot::Mutex; /// A DataFusion fair `MemoryPool` implementation for Comet. Internally this is /// implemented via delegating calls to [`crate::jvm_bridge::CometTaskMemoryManager`]. pub struct CometFairMemoryPool { - task_memory_manager_handle: Arc, + task_memory_manager_handle: Arc>>, pool_size: usize, state: Mutex, } @@ -57,7 +57,7 @@ impl Debug for CometFairMemoryPool { impl CometFairMemoryPool { pub fn new( - task_memory_manager_handle: Arc, + task_memory_manager_handle: Arc>>, pool_size: usize, ) -> CometFairMemoryPool { Self { @@ -69,18 +69,20 @@ impl CometFairMemoryPool { fn acquire(&self, additional: usize) -> CometResult { let mut env = JVMClasses::get_env()?; + let env = env.borrow_env_mut(); let handle = self.task_memory_manager_handle.as_obj(); unsafe { - jni_call!(&mut env, + jni_call!(env, comet_task_memory_manager(handle).acquire_memory(additional as i64) -> i64) } } fn release(&self, size: usize) -> CometResult<()> { let mut env = JVMClasses::get_env()?; + let env = env.borrow_env_mut(); let handle = self.task_memory_manager_handle.as_obj(); unsafe { - jni_call!(&mut env, comet_task_memory_manager(handle).release_memory(size as i64) -> ()) + jni_call!(env, comet_task_memory_manager(handle).release_memory(size as i64) -> ()) } } } diff --git a/native/core/src/execution/memory_pools/mod.rs b/native/core/src/execution/memory_pools/mod.rs index d8b3473353..f206c0f0f0 100644 --- a/native/core/src/execution/memory_pools/mod.rs +++ b/native/core/src/execution/memory_pools/mod.rs @@ -25,7 +25,7 @@ use datafusion::execution::memory_pool::{ FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, UnboundedMemoryPool, }; use fair_pool::CometFairMemoryPool; -use jni::objects::GlobalRef; +use jni::objects::{GlobalRef, JObject}; use once_cell::sync::OnceCell; use std::num::NonZeroUsize; use std::sync::Arc; @@ -36,7 +36,7 @@ pub(crate) use task_shared::*; pub(crate) fn create_memory_pool( memory_pool_config: &MemoryPoolConfig, - comet_task_memory_manager: Arc, + comet_task_memory_manager: Arc>>, task_attempt_id: i64, ) -> Arc { const NUM_TRACKED_CONSUMERS: usize = 10; diff --git a/native/core/src/execution/memory_pools/unified_pool.rs b/native/core/src/execution/memory_pools/unified_pool.rs index 3233dd6d40..805304fba0 100644 --- a/native/core/src/execution/memory_pools/unified_pool.rs +++ b/native/core/src/execution/memory_pools/unified_pool.rs @@ -28,14 +28,14 @@ use datafusion::{ common::{resources_datafusion_err, DataFusionError}, execution::memory_pool::{MemoryPool, MemoryReservation}, }; -use jni::objects::GlobalRef; +use jni::objects::{GlobalRef, JObject}; use log::warn; /// A DataFusion `MemoryPool` implementation for Comet that delegates to /// Spark's off-heap executor memory pool via JNI by calling /// [`crate::jvm_bridge::CometTaskMemoryManager`]. pub struct CometUnifiedMemoryPool { - task_memory_manager_handle: Arc, + task_memory_manager_handle: Arc>>, used: AtomicUsize, task_attempt_id: i64, } @@ -50,7 +50,7 @@ impl Debug for CometUnifiedMemoryPool { impl CometUnifiedMemoryPool { pub fn new( - task_memory_manager_handle: Arc, + task_memory_manager_handle: Arc>>, task_attempt_id: i64, ) -> CometUnifiedMemoryPool { Self { @@ -63,9 +63,10 @@ impl CometUnifiedMemoryPool { /// Request memory from Spark's off-heap memory pool via JNI fn acquire_from_spark(&self, additional: usize) -> CometResult { let mut env = JVMClasses::get_env()?; + let env = env.borrow_env_mut(); let handle = self.task_memory_manager_handle.as_obj(); unsafe { - jni_call!(&mut env, + jni_call!(env, comet_task_memory_manager(handle).acquire_memory(additional as i64) -> i64) } } @@ -73,9 +74,10 @@ impl CometUnifiedMemoryPool { /// Release memory to Spark's off-heap memory pool via JNI fn release_to_spark(&self, size: usize) -> CometResult<()> { let mut env = JVMClasses::get_env()?; + let env = env.borrow_env_mut(); let handle = self.task_memory_manager_handle.as_obj(); unsafe { - jni_call!(&mut env, comet_task_memory_manager(handle).release_memory(size as i64) -> ()) + jni_call!(env, comet_task_memory_manager(handle).release_memory(size as i64) -> ()) } } } diff --git a/native/core/src/execution/metrics/utils.rs b/native/core/src/execution/metrics/utils.rs index 161c1f1cf9..eb7e10bfc9 100644 --- a/native/core/src/execution/metrics/utils.rs +++ b/native/core/src/execution/metrics/utils.rs @@ -19,7 +19,7 @@ use crate::errors::CometError; use crate::execution::spark_plan::SparkPlan; use datafusion::physical_plan::metrics::MetricValue; use datafusion_comet_proto::spark_metric::NativeMetricNode; -use jni::{objects::JObject, JNIEnv}; +use jni::{objects::JObject, Env}; use prost::Message; use std::collections::HashMap; use std::sync::Arc; @@ -28,7 +28,7 @@ use std::sync::Arc; /// update the metrics of all the children nodes. The metrics are pulled from the /// native execution plan and pushed to the Java side through JNI. pub(crate) fn update_comet_metric( - env: &mut JNIEnv, + env: &mut Env, metric_node: &JObject, spark_plan: &Arc, ) -> Result<(), CometError> { diff --git a/native/core/src/execution/operators/projection.rs b/native/core/src/execution/operators/projection.rs index 6ba1bb5d59..9e5119eeee 100644 --- a/native/core/src/execution/operators/projection.rs +++ b/native/core/src/execution/operators/projection.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use datafusion::physical_plan::projection::ProjectionExec; use datafusion_comet_proto::spark_operator::Operator; -use jni::objects::GlobalRef; +use jni::objects::{GlobalRef, JObject}; use crate::{ execution::{ @@ -39,7 +39,7 @@ impl OperatorBuilder for ProjectionBuilder { fn build( &self, spark_plan: &Operator, - inputs: &mut Vec>, + inputs: &mut Vec>>>, partition_count: usize, planner: &PhysicalPlanner, ) -> Result<(Vec, Arc), ExecutionError> { diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index 2394912e41..133fc6c53f 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -40,9 +40,7 @@ use datafusion::{ }; use futures::Stream; use itertools::Itertools; -use jni::objects::JValueGen; -use jni::objects::{GlobalRef, JObject}; -use jni::sys::jsize; +use jni::objects::{GlobalRef, JObject, JValue}; use std::rc::Rc; use std::{ any::Any, @@ -61,7 +59,7 @@ pub struct ScanExec { /// environment `JNIEnv` from the execution context. pub exec_context_id: i64, /// The input source of scan node. It is a global reference of JVM `CometBatchIterator` object. - pub input_source: Option>, + pub input_source: Option>>>, /// A description of the input source for informational purposes pub input_source_description: String, /// The data types of columns of the input batch. Converted from Spark schema. @@ -84,7 +82,7 @@ pub struct ScanExec { impl ScanExec { pub fn new( exec_context_id: i64, - input_source: Option>, + input_source: Option>>>, input_source_description: &str, data_types: Vec, arrow_ffi_safe: bool, @@ -176,9 +174,10 @@ impl ScanExec { } let mut env = JVMClasses::get_env()?; + let env = env.borrow_env_mut(); let num_rows: i32 = unsafe { - jni_call!(&mut env, + jni_call!(env, comet_batch_iterator(iter).has_next() -> i32)? }; @@ -190,11 +189,11 @@ impl ScanExec { // JVM via FFI // Selection vectors can be provided by, for instance, Iceberg to // remove rows that have been deleted. - let selection_indices_arrays = Self::get_selection_indices(&mut env, iter, num_cols)?; + let selection_indices_arrays = Self::get_selection_indices(env, iter, num_cols)?; // fetch batch data from JVM via FFI let (num_rows, array_addrs, schema_addrs) = - Self::allocate_and_fetch_batch(&mut env, iter, num_cols)?; + Self::allocate_and_fetch_batch(env, iter, num_cols)?; let mut inputs: Vec = Vec::with_capacity(num_cols); @@ -262,7 +261,7 @@ impl ScanExec { /// Allocates Arrow FFI structures and calls JNI to get the next batch data. /// Returns the number of rows and the allocated array/schema addresses. fn allocate_and_fetch_batch( - env: &mut jni::JNIEnv, + env: &mut jni::Env, iter: &JObject, num_cols: usize, ) -> Result<(i32, Vec, Vec), CometError> { @@ -282,8 +281,8 @@ impl ScanExec { } // Prepare the java array parameters - let long_array_addrs = env.new_long_array(num_cols as jsize)?; - let long_schema_addrs = env.new_long_array(num_cols as jsize)?; + let long_array_addrs = env.new_long_array(num_cols)?; + let long_schema_addrs = env.new_long_array(num_cols)?; env.set_long_array_region(&long_array_addrs, 0, &array_addrs)?; env.set_long_array_region(&long_schema_addrs, 0, &schema_addrs)?; @@ -291,8 +290,8 @@ impl ScanExec { let array_obj = JObject::from(long_array_addrs); let schema_obj = JObject::from(long_schema_addrs); - let array_obj = JValueGen::Object(array_obj.as_ref()); - let schema_obj = JValueGen::Object(schema_obj.as_ref()); + let array_obj = JValue::Object(array_obj.as_ref()); + let schema_obj = JValue::Object(schema_obj.as_ref()); let num_rows: i32 = unsafe { jni_call!(env, @@ -309,7 +308,7 @@ impl ScanExec { /// Checks for selection vectors and exports selection indices if needed. /// Returns selection arrays if they exist (applies to all columns). fn get_selection_indices( - env: &mut jni::JNIEnv, + env: &mut jni::Env, iter: &JObject, num_cols: usize, ) -> Result>, CometError> { @@ -318,7 +317,7 @@ impl ScanExec { jni_call!(env, comet_batch_iterator(iter).has_selection_vectors() -> jni::sys::jboolean)? }; - let has_selection_vectors = has_selection_vectors_result != 0; + let has_selection_vectors = has_selection_vectors_result; let selection_indices_arrays = if has_selection_vectors { // Allocate arrays for selection indices export (one per column) @@ -333,8 +332,8 @@ impl ScanExec { } // Prepare JNI arrays for the export call - let indices_array_obj = env.new_long_array(num_cols as jsize)?; - let indices_schema_obj = env.new_long_array(num_cols as jsize)?; + let indices_array_obj = env.new_long_array(num_cols)?; + let indices_schema_obj = env.new_long_array(num_cols)?; env.set_long_array_region(&indices_array_obj, 0, &indices_array_addrs)?; env.set_long_array_region(&indices_schema_obj, 0, &indices_schema_addrs)?; @@ -342,8 +341,8 @@ impl ScanExec { let _exported_count: i32 = unsafe { jni_call!(env, comet_batch_iterator(iter).export_selection_indices( - JValueGen::Object(JObject::from(indices_array_obj).as_ref()), - JValueGen::Object(JObject::from(indices_schema_obj).as_ref()) + JValue::Object(JObject::from(indices_array_obj).as_ref()), + JValue::Object(JObject::from(indices_schema_obj).as_ref()) ) -> i32)? }; diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index e730fd0c89..223fc81b74 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -127,7 +127,7 @@ use datafusion_comet_spark_expr::{ WideDecimalBinaryExpr, WideDecimalOp, }; use itertools::Itertools; -use jni::objects::GlobalRef; +use jni::objects::{GlobalRef, JObject}; use num::{BigInt, ToPrimitive}; use object_store::path::Path; use std::cmp::max; @@ -908,7 +908,7 @@ impl PhysicalPlanner { pub(crate) fn create_plan<'a>( &'a self, spark_plan: &'a Operator, - inputs: &mut Vec>, + inputs: &mut Vec>>>, partition_count: usize, ) -> Result<(Vec, Arc), ExecutionError> { // Try to use the modular registry first - this automatically handles any registered operator types @@ -1746,7 +1746,7 @@ impl PhysicalPlanner { #[allow(clippy::too_many_arguments)] fn parse_join_parameters( &self, - inputs: &mut Vec>, + inputs: &mut Vec>>>, children: &[Operator], left_join_keys: &[Expr], right_join_keys: &[Expr], diff --git a/native/core/src/execution/planner/operator_registry.rs b/native/core/src/execution/planner/operator_registry.rs index b34a80df95..a656ad5c77 100644 --- a/native/core/src/execution/planner/operator_registry.rs +++ b/native/core/src/execution/planner/operator_registry.rs @@ -23,7 +23,7 @@ use std::{ }; use datafusion_comet_proto::spark_operator::Operator; -use jni::objects::GlobalRef; +use jni::objects::{GlobalRef, JObject}; use super::PhysicalPlanner; use crate::execution::{ @@ -37,7 +37,7 @@ pub trait OperatorBuilder: Send + Sync { fn build( &self, spark_plan: &datafusion_comet_proto::spark_operator::Operator, - inputs: &mut Vec>, + inputs: &mut Vec>>>, partition_count: usize, planner: &PhysicalPlanner, ) -> Result<(Vec, Arc), ExecutionError>; @@ -97,7 +97,7 @@ impl OperatorRegistry { pub fn create_plan( &self, spark_operator: &Operator, - inputs: &mut Vec>, + inputs: &mut Vec>>>, partition_count: usize, planner: &PhysicalPlanner, ) -> Result<(Vec, Arc), ExecutionError> { diff --git a/native/core/src/lib.rs b/native/core/src/lib.rs index 1b87dc1dba..b3b3f1da32 100644 --- a/native/core/src/lib.rs +++ b/native/core/src/lib.rs @@ -150,7 +150,7 @@ pub extern "system" fn Java_org_apache_comet_NativeBase_isFeatureEnabled( _ => false, // Unknown features return false }; - Ok(enabled as u8) + Ok(enabled) }) } diff --git a/native/core/src/parquet/encryption_support.rs b/native/core/src/parquet/encryption_support.rs index 4540c217d5..ca28639f01 100644 --- a/native/core/src/parquet/encryption_support.rs +++ b/native/core/src/parquet/encryption_support.rs @@ -23,7 +23,7 @@ use datafusion::common::extensions_options; use datafusion::config::EncryptionFactoryOptions; use datafusion::error::DataFusionError; use datafusion::execution::parquet_encryption::EncryptionFactory; -use jni::objects::{GlobalRef, JMethodID}; +use jni::objects::{GlobalRef, JMethodID, JObject}; use object_store::path::Path; use parquet::encryption::decrypt::{FileDecryptionProperties, KeyRetriever}; use parquet::encryption::encrypt::FileEncryptionProperties; @@ -42,7 +42,7 @@ extensions_options! { #[derive(Debug)] pub struct CometEncryptionFactory { - pub(crate) key_unwrapper: GlobalRef, + pub(crate) key_unwrapper: Arc>>, } /// `EncryptionFactory` is a DataFusion trait for types that generate @@ -73,7 +73,7 @@ impl EncryptionFactory for CometEncryptionFactory { let config: CometEncryptionConfig = options.to_extension_options()?; let full_path: String = config.uri_base + file_path.as_ref(); - let key_retriever = CometKeyRetriever::new(&full_path, self.key_unwrapper.clone()) + let key_retriever = CometKeyRetriever::new(&full_path, Arc::clone(&self.key_unwrapper)) .map_err(|e| DataFusionError::External(Box::new(e)))?; let decryption_properties = FileDecryptionProperties::with_key_retriever(Arc::new(key_retriever)).build()?; @@ -83,22 +83,26 @@ impl EncryptionFactory for CometEncryptionFactory { pub struct CometKeyRetriever { file_path: String, - key_unwrapper: GlobalRef, + key_unwrapper: Arc>>, get_key_method_id: JMethodID, } impl CometKeyRetriever { - pub fn new(file_path: &str, key_unwrapper: GlobalRef) -> Result { + pub fn new( + file_path: &str, + key_unwrapper: Arc>>, + ) -> Result { let mut env = JVMClasses::get_env()?; + let env = env.borrow_env_mut(); Ok(CometKeyRetriever { file_path: file_path.to_string(), key_unwrapper, get_key_method_id: env .get_method_id( - "org/apache/comet/parquet/CometFileKeyUnwrapper", - "getKey", - "(Ljava/lang/String;[B)[B", + jni::jni_str!("org/apache/comet/parquet/CometFileKeyUnwrapper"), + jni::jni_str!("getKey"), + jni::jni_sig!("(Ljava/lang/String;[B)[B"), ) .map_err(|e| { ExecutionError::GeneralError(format!("Failed to get JNI method ID: {}", e)) @@ -110,15 +114,14 @@ impl CometKeyRetriever { impl KeyRetriever for CometKeyRetriever { /// Get a data encryption key using the metadata stored in the Parquet file. fn retrieve_key(&self, key_metadata: &[u8]) -> datafusion::parquet::errors::Result> { - use jni::{objects::JObject, signature::ReturnType}; + use jni::signature::ReturnType; // Get JNI environment let mut env = JVMClasses::get_env()?; + let env = env.borrow_env_mut(); // Get the key unwrapper instance from GlobalRef - let unwrapper_instance = self.key_unwrapper.as_obj(); - - let instance: JObject = unsafe { JObject::from_raw(unwrapper_instance.as_raw()) }; + let instance = self.key_unwrapper.as_obj(); // Convert file path to JString let file_path_jstring = env @@ -144,7 +147,7 @@ impl KeyRetriever for CometKeyRetriever { }; // Check for Java exceptions first, before processing the result - if let Some(exception) = check_exception(&mut env).map_err(|e| { + if let Some(exception) = check_exception(env).map_err(|e| { ParquetError::General(format!("Failed to check for Java exception: {}", e)) })? { return Err(ParquetError::General(format!( @@ -162,7 +165,8 @@ impl KeyRetriever for CometKeyRetriever { .map_err(|e| ParquetError::General(format!("Failed to extract result: {}", e)))?; // Convert JObject to JByteArray and then to Vec - let byte_array: jni::objects::JByteArray = result_array.into(); + let byte_array = + unsafe { jni::objects::JByteArray::from_raw(env, result_array.into_raw()) }; let result_vec = env .convert_byte_array(&byte_array) diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs index d24a6a503e..366a15d583 100644 --- a/native/core/src/parquet/mod.rs +++ b/native/core/src/parquet/mod.rs @@ -38,11 +38,10 @@ use crate::errors::{try_unwrap_or_throw, CometError}; use arrow::ffi::FFI_ArrowArray; -/// JNI exposed methods -use jni::JNIEnv; use jni::{ - objects::{GlobalRef, JByteBuffer, JClass}, + objects::{GlobalRef, JByteBuffer, JClass, JObject}, sys::{jboolean, jint, jlong}, + Env, JNIEnv, }; use self::util::jni::TypePromotionInfo; @@ -65,7 +64,7 @@ use datafusion::execution::SendableRecordBatchStream; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::{SessionConfig, SessionContext}; use futures::{poll, StreamExt}; -use jni::objects::{JByteArray, JLongArray, JMap, JObject, JObjectArray, JString, ReleaseMode}; +use jni::objects::{JByteArray, JLongArray, JMap, JObjectArray, JString, ReleaseMode}; use jni::sys::{jintArray, JNI_FALSE}; use object_store::path::Path; use read::ColumnReader; @@ -74,7 +73,7 @@ use util::jni::{convert_column_descriptor, convert_encoding, deserialize_schema} /// Parquet read context maintained across multiple JNI calls. struct Context { pub column_reader: ColumnReader, - last_data_page: Option, + last_data_page: Option>>, } #[no_mangle] @@ -128,8 +127,8 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_initColumnReader( desc, promotion_info, batch_size as usize, - use_decimal_128 != 0, - use_legacy_date_timestamp != 0, + use_decimal_128, + use_legacy_date_timestamp, ), last_data_page: None, }; @@ -315,7 +314,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_skipBatch( ) -> jint { try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; - Ok(reader.skip_batch(batch_size as usize, discard == 0) as jint) + Ok(reader.skip_batch(batch_size as usize, discard) as jint) }) } @@ -379,7 +378,7 @@ enum ParquetReaderState { /// Parquet read context maintained across multiple JNI calls. struct BatchContext { native_plan: Arc, - metrics_node: Arc, + metrics_node: Arc>>, batch_stream: Option, current_batch: Option, reader_state: ParquetReaderState, @@ -416,16 +415,20 @@ fn get_file_groups_single_file( } pub fn get_object_store_options( - env: &mut JNIEnv, + env: &mut Env, map_object: JObject, ) -> Result, CometError> { - let map = JMap::from_env(env, &map_object)?; + let map = JMap::from_env(env, map_object)?; // Convert to a HashMap let mut collected_map = HashMap::new(); map.iter(env).and_then(|mut iter| { - while let Some((key, value)) = iter.next(env)? { - let key_string: String = String::from(env.get_string(&JString::from(key))?); - let value_string: String = String::from(env.get_string(&JString::from(value))?); + while let Some(entry) = iter.next(env)? { + let key = entry.key(env)?; + let value = entry.value(env)?; + let key = unsafe { JString::from_raw(env, key.into_raw()) }; + let value = unsafe { JString::from_raw(env, value.into_raw()) }; + let key_string = key.try_to_string(env)?; + let value_string = value.try_to_string(env)?; collected_map.insert(key_string, value_string); } Ok(()) @@ -524,7 +527,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat // Handle key unwrapper for encrypted files let encryption_enabled = if !key_unwrapper_obj.is_null() { let encryption_factory = CometEncryptionFactory { - key_unwrapper: jni_new_global_ref!(env, key_unwrapper_obj)?, + key_unwrapper: Arc::new(jni_new_global_ref!(env, key_unwrapper_obj)?), }; session_ctx .runtime_env() diff --git a/native/core/src/parquet/util/jni.rs b/native/core/src/parquet/util/jni.rs index 2223f508f4..4bfc0adca8 100644 --- a/native/core/src/parquet/util/jni.rs +++ b/native/core/src/parquet/util/jni.rs @@ -21,7 +21,7 @@ use jni::{ errors::Result as JNIResult, objects::{JObjectArray, JString}, sys::{jboolean, jint}, - JNIEnv, + Env, }; use arrow::error::ArrowError; @@ -37,7 +37,7 @@ use url::{ParseError, Url}; /// Convert primitives from Spark side into a `ColumnDescriptor`. #[allow(clippy::too_many_arguments)] pub fn convert_column_descriptor( - env: &mut JNIEnv, + env: &mut Env, physical_type_id: jint, logical_type_id: jint, max_dl: jint, @@ -131,12 +131,13 @@ impl TypePromotionInfo { } } -fn convert_column_path(env: &mut JNIEnv, path_array: JObjectArray) -> JNIResult { +fn convert_column_path(env: &mut Env, path_array: JObjectArray) -> JNIResult { let array_len = env.get_array_length(&path_array)?; let mut res: Vec = Vec::new(); for i in 0..array_len { - let p: JString = env.get_object_array_element(&path_array, i)?.into(); - res.push(env.get_string(&p)?.into()); + let p = env.get_object_array_element(&path_array, i as usize)?; + let p: JString = unsafe { JString::from_raw(env, p.into_raw()) }; + res.push(p.try_to_string(env)?); } Ok(ColumnPath::new(res)) } @@ -167,13 +168,13 @@ fn convert_logical_type( match id { 0 => LogicalType::Integer { bit_width: bit_width as i8, - is_signed: is_signed != 0, + is_signed, }, 1 => LogicalType::String, 2 => LogicalType::Decimal { scale, precision }, 3 => LogicalType::Date, 4 => LogicalType::Timestamp { - is_adjusted_to_u_t_c: is_adjusted_utc != 0, + is_adjusted_to_u_t_c: is_adjusted_utc, unit: convert_time_unit(time_unit), }, 5 => LogicalType::Enum, diff --git a/native/jni-bridge/Cargo.toml b/native/jni-bridge/Cargo.toml index 0c50825667..a0ef4a73c8 100644 --- a/native/jni-bridge/Cargo.toml +++ b/native/jni-bridge/Cargo.toml @@ -32,7 +32,7 @@ publish = false arrow = { workspace = true } parquet = { workspace = true } datafusion = { workspace = true } -jni = "0.21" +jni = "0.22.4" thiserror = { workspace = true } regex = { workspace = true } lazy_static = "1.4.0" @@ -42,5 +42,5 @@ prost = "0.14.3" datafusion-comet-common = { workspace = true } [dev-dependencies] -jni = { version = "0.21", features = ["invocation"] } +jni = { version = "0.22.4", features = ["invocation"] } assertables = "9" diff --git a/native/jni-bridge/src/batch_iterator.rs b/native/jni-bridge/src/batch_iterator.rs index 2824bdbfc6..65ca7e7d11 100644 --- a/native/jni-bridge/src/batch_iterator.rs +++ b/native/jni-bridge/src/batch_iterator.rs @@ -20,7 +20,8 @@ use jni::{ errors::Result as JniResult, objects::{JClass, JMethodID}, signature::ReturnType, - JNIEnv, + strings::JNIString, + Env, }; /// A struct that holds all the JNI methods and fields for JVM `CometBatchIterator` class. @@ -40,25 +41,33 @@ pub struct CometBatchIterator<'a> { impl<'a> CometBatchIterator<'a> { pub const JVM_CLASS: &'static str = "org/apache/comet/CometBatchIterator"; - pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { - let class = env.find_class(Self::JVM_CLASS)?; + pub fn new(env: &mut Env<'a>) -> JniResult> { + let class = env.find_class(JNIString::new(Self::JVM_CLASS))?; Ok(CometBatchIterator { class, - method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext", "()I")?, + method_has_next: env.get_method_id( + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("hasNext"), + jni::jni_sig!("()I"), + )?, method_has_next_ret: ReturnType::Primitive(Primitive::Int), - method_next: env.get_method_id(Self::JVM_CLASS, "next", "([J[J)I")?, + method_next: env.get_method_id( + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("next"), + jni::jni_sig!("([J[J)I"), + )?, method_next_ret: ReturnType::Primitive(Primitive::Int), method_has_selection_vectors: env.get_method_id( - Self::JVM_CLASS, - "hasSelectionVectors", - "()Z", + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("hasSelectionVectors"), + jni::jni_sig!("()Z"), )?, method_has_selection_vectors_ret: ReturnType::Primitive(Primitive::Boolean), method_export_selection_indices: env.get_method_id( - Self::JVM_CLASS, - "exportSelectionIndices", - "([J[J)I", + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("exportSelectionIndices"), + jni::jni_sig!("([J[J)I"), )?, method_export_selection_indices_ret: ReturnType::Primitive(Primitive::Int), }) diff --git a/native/jni-bridge/src/comet_exec.rs b/native/jni-bridge/src/comet_exec.rs index 1bcbbc4ad2..a0b39d0eac 100644 --- a/native/jni-bridge/src/comet_exec.rs +++ b/native/jni-bridge/src/comet_exec.rs @@ -19,7 +19,8 @@ use jni::{ errors::Result as JniResult, objects::{JClass, JStaticMethodID}, signature::{Primitive, ReturnType}, - JNIEnv, + strings::JNIString, + Env, }; /// A struct that holds all the JNI methods and fields for JVM CometExec object. @@ -52,39 +53,75 @@ pub struct CometExec<'a> { impl<'a> CometExec<'a> { pub const JVM_CLASS: &'static str = "org/apache/spark/sql/comet/CometScalarSubquery"; - pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { - let class = env.find_class(Self::JVM_CLASS)?; + pub fn new(env: &mut Env<'a>) -> JniResult> { + let class = env.find_class(JNIString::new(Self::JVM_CLASS))?; Ok(CometExec { - method_get_bool: env.get_static_method_id(Self::JVM_CLASS, "getBoolean", "(JJ)Z")?, + method_get_bool: env.get_static_method_id( + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("getBoolean"), + jni::jni_sig!("(JJ)Z"), + )?, method_get_bool_ret: ReturnType::Primitive(Primitive::Boolean), - method_get_byte: env.get_static_method_id(Self::JVM_CLASS, "getByte", "(JJ)B")?, + method_get_byte: env.get_static_method_id( + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("getByte"), + jni::jni_sig!("(JJ)B"), + )?, method_get_byte_ret: ReturnType::Primitive(Primitive::Byte), - method_get_short: env.get_static_method_id(Self::JVM_CLASS, "getShort", "(JJ)S")?, + method_get_short: env.get_static_method_id( + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("getShort"), + jni::jni_sig!("(JJ)S"), + )?, method_get_short_ret: ReturnType::Primitive(Primitive::Short), - method_get_int: env.get_static_method_id(Self::JVM_CLASS, "getInt", "(JJ)I")?, + method_get_int: env.get_static_method_id( + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("getInt"), + jni::jni_sig!("(JJ)I"), + )?, method_get_int_ret: ReturnType::Primitive(Primitive::Int), - method_get_long: env.get_static_method_id(Self::JVM_CLASS, "getLong", "(JJ)J")?, + method_get_long: env.get_static_method_id( + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("getLong"), + jni::jni_sig!("(JJ)J"), + )?, method_get_long_ret: ReturnType::Primitive(Primitive::Long), - method_get_float: env.get_static_method_id(Self::JVM_CLASS, "getFloat", "(JJ)F")?, + method_get_float: env.get_static_method_id( + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("getFloat"), + jni::jni_sig!("(JJ)F"), + )?, method_get_float_ret: ReturnType::Primitive(Primitive::Float), - method_get_double: env.get_static_method_id(Self::JVM_CLASS, "getDouble", "(JJ)D")?, + method_get_double: env.get_static_method_id( + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("getDouble"), + jni::jni_sig!("(JJ)D"), + )?, method_get_double_ret: ReturnType::Primitive(Primitive::Double), method_get_decimal: env.get_static_method_id( - Self::JVM_CLASS, - "getDecimal", - "(JJ)[B", + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("getDecimal"), + jni::jni_sig!("(JJ)[B"), )?, method_get_decimal_ret: ReturnType::Array, method_get_string: env.get_static_method_id( - Self::JVM_CLASS, - "getString", - "(JJ)Ljava/lang/String;", + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("getString"), + jni::jni_sig!("(JJ)Ljava/lang/String;"), )?, method_get_string_ret: ReturnType::Object, - method_get_binary: env.get_static_method_id(Self::JVM_CLASS, "getBinary", "(JJ)[B")?, + method_get_binary: env.get_static_method_id( + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("getBinary"), + jni::jni_sig!("(JJ)[B"), + )?, method_get_binary_ret: ReturnType::Array, - method_is_null: env.get_static_method_id(Self::JVM_CLASS, "isNull", "(JJ)Z")?, + method_is_null: env.get_static_method_id( + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("isNull"), + jni::jni_sig!("(JJ)Z"), + )?, method_is_null_ret: ReturnType::Primitive(Primitive::Boolean), class, }) diff --git a/native/jni-bridge/src/comet_metric_node.rs b/native/jni-bridge/src/comet_metric_node.rs index f1f0255845..4cc8ae1631 100644 --- a/native/jni-bridge/src/comet_metric_node.rs +++ b/native/jni-bridge/src/comet_metric_node.rs @@ -18,8 +18,10 @@ use jni::{ errors::Result as JniResult, objects::{JClass, JMethodID}, + signature::RuntimeMethodSignature, signature::{Primitive, ReturnType}, - JNIEnv, + strings::JNIString, + Env, }; /// A struct that holds all the JNI methods and fields for JVM CometMetricNode class. @@ -37,22 +39,28 @@ pub struct CometMetricNode<'a> { impl<'a> CometMetricNode<'a> { pub const JVM_CLASS: &'static str = "org/apache/spark/sql/comet/CometMetricNode"; - pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { - let class = env.find_class(Self::JVM_CLASS)?; + pub fn new(env: &mut Env<'a>) -> JniResult> { + let class = env.find_class(JNIString::new(Self::JVM_CLASS))?; + let get_child_node_sig = + RuntimeMethodSignature::from_str(format!("(I)L{};", Self::JVM_CLASS))?; Ok(CometMetricNode { method_get_child_node: env.get_method_id( - Self::JVM_CLASS, - "getChildNode", - format!("(I)L{:};", Self::JVM_CLASS).as_str(), + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("getChildNode"), + get_child_node_sig.method_signature(), )?, method_get_child_node_ret: ReturnType::Object, - method_set: env.get_method_id(Self::JVM_CLASS, "set", "(Ljava/lang/String;J)V")?, + method_set: env.get_method_id( + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("set"), + jni::jni_sig!("(Ljava/lang/String;J)V"), + )?, method_set_ret: ReturnType::Primitive(Primitive::Void), method_set_all_from_bytes: env.get_method_id( - Self::JVM_CLASS, - "set_all_from_bytes", - "([B)V", + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("set_all_from_bytes"), + jni::jni_sig!("([B)V"), )?, method_set_all_from_bytes_ret: ReturnType::Primitive(Primitive::Void), class, diff --git a/native/jni-bridge/src/comet_task_memory_manager.rs b/native/jni-bridge/src/comet_task_memory_manager.rs index 22c3332c61..cec0b70511 100644 --- a/native/jni-bridge/src/comet_task_memory_manager.rs +++ b/native/jni-bridge/src/comet_task_memory_manager.rs @@ -19,7 +19,8 @@ use jni::{ errors::Result as JniResult, objects::{JClass, JMethodID}, signature::{Primitive, ReturnType}, - JNIEnv, + strings::JNIString, + Env, }; /// A wrapper which delegate acquire/release memory calls to the @@ -38,20 +39,20 @@ pub struct CometTaskMemoryManager<'a> { impl<'a> CometTaskMemoryManager<'a> { pub const JVM_CLASS: &'static str = "org/apache/spark/CometTaskMemoryManager"; - pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { - let class = env.find_class(Self::JVM_CLASS)?; + pub fn new(env: &mut Env<'a>) -> JniResult> { + let class = env.find_class(JNIString::new(Self::JVM_CLASS))?; let result = CometTaskMemoryManager { class, method_acquire_memory: env.get_method_id( - Self::JVM_CLASS, - "acquireMemory", - "(J)J".to_string(), + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("acquireMemory"), + jni::jni_sig!("(J)J"), )?, method_release_memory: env.get_method_id( - Self::JVM_CLASS, - "releaseMemory", - "(J)V".to_string(), + JNIString::new(Self::JVM_CLASS), + jni::jni_str!("releaseMemory"), + jni::jni_sig!("(J)V"), )?, method_acquire_memory_ret: ReturnType::Primitive(Primitive::Long), method_release_memory_ret: ReturnType::Primitive(Primitive::Void), diff --git a/native/jni-bridge/src/errors.rs b/native/jni-bridge/src/errors.rs index 640201f6f0..0b6f232d04 100644 --- a/native/jni-bridge/src/errors.rs +++ b/native/jni-bridge/src/errors.rs @@ -27,7 +27,7 @@ use std::{ any::Any, convert, fmt::Write, - panic::{catch_unwind, UnwindSafe}, + panic::UnwindSafe, result, str, str::Utf8Error, sync::{Arc, Mutex}, @@ -39,7 +39,7 @@ use std::{ use jni::sys::{jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject, jshort}; use jni::objects::{GlobalRef, JThrowable}; -use jni::JNIEnv; +use jni::{strings::JNIString, Env, JNIEnv, Outcome}; use lazy_static::lazy_static; use parquet::errors::ParquetError; use thiserror::Error; @@ -72,7 +72,7 @@ pub enum ExecutionError { JavaException { class: String, msg: String, - throwable: GlobalRef, + throwable: GlobalRef>, }, } @@ -167,7 +167,7 @@ pub enum CometError { JavaException { class: String, msg: String, - throwable: GlobalRef, + throwable: GlobalRef>, }, } @@ -388,7 +388,7 @@ pub trait JNIDefault { impl JNIDefault for jboolean { fn default() -> jboolean { - 0 + false } } @@ -449,7 +449,7 @@ impl JNIDefault for () { // `RuntimeException` back to the calling Java. Since a return result is required, use `JNIDefault` // to create a reasonable result. This returned default value will be ignored due to the exception. pub fn unwrap_or_throw_default( - env: &mut JNIEnv, + env: &mut Env, result: std::result::Result, ) -> T { match result { @@ -465,16 +465,16 @@ pub fn unwrap_or_throw_default( } } -fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option) { +fn throw_exception(env: &mut Env, error: &CometError, backtrace: Option) { // If there isn't already an exception? - if env.exception_check().is_ok() { + if !env.exception_check() { // ... then throw new exception match error { CometError::JavaException { class: _, msg: _, throwable, - } => env.throw(<&JThrowable>::from(throwable.as_obj())), + } => env.throw(throwable), CometError::Execution { source: ExecutionError::JavaException { @@ -482,7 +482,7 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option env.throw(<&JThrowable>::from(throwable.as_obj())), + } => env.throw(throwable), // Handle DataFusion errors containing SparkError or SparkErrorWithContext CometError::DataFusion { msg: _, @@ -491,14 +491,14 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option() { let json_message = spark_error_with_ctx.to_json(); env.throw_new( - "org/apache/comet/exceptions/CometQueryExecutionException", - json_message, + jni::jni_str!("org/apache/comet/exceptions/CometQueryExecutionException"), + JNIString::new(json_message), ) } else if let Some(spark_error) = e.downcast_ref::() { let json_message = spark_error.to_json(); env.throw_new( - "org/apache/comet/exceptions/CometQueryExecutionException", - json_message, + jni::jni_str!("org/apache/comet/exceptions/CometQueryExecutionException"), + JNIString::new(json_message), ) } else { // Check for file-not-found errors from object store @@ -513,10 +513,15 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option env.throw_new( - exception.class, - to_stacktrace_string(exception.msg, backtrace_string).unwrap(), + JNIString::new(exception.class), + JNIString::new( + to_stacktrace_string(exception.msg, backtrace_string).unwrap(), + ), + ), + _ => env.throw_new( + JNIString::new(exception.class), + JNIString::new(exception.msg), ), - _ => env.throw_new(exception.class, exception.msg), } } } @@ -537,10 +542,15 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option env.throw_new( - exception.class, - to_stacktrace_string(exception.msg, backtrace_string).unwrap(), + JNIString::new(exception.class), + JNIString::new( + to_stacktrace_string(exception.msg, backtrace_string).unwrap(), + ), + ), + _ => env.throw_new( + JNIString::new(exception.class), + JNIString::new(exception.msg), ), - _ => env.throw_new(exception.class, exception.msg), } } } @@ -550,17 +560,14 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option jni::errors::Result<()> { +fn throw_spark_error_as_json(env: &mut Env, spark_error: &SparkError) -> jni::errors::Result<()> { // Serialize error to JSON let json_message = spark_error.to_json(); // Throw CometQueryExecutionException with JSON message env.throw_new( - "org/apache/comet/exceptions/CometQueryExecutionException", - json_message, + jni::jni_str!("org/apache/comet/exceptions/CometQueryExecutionException"), + JNIString::new(json_message), ) } @@ -659,33 +666,26 @@ fn to_stacktrace_string(msg: String, backtrace_string: String) -> Result(result: Result, E>) -> Result { - result.and_then(convert::identity) -} - -// Implements "currying" from `FnOnce(T) -> R` to `FnOnce() -> R`, given -// an instance of T. Curring is not supported in Rust so we have to use this -// custom function to achieve something similar here. -fn curry<'a, T: 'a, F, R>(f: F, t: T) -> impl FnOnce() -> R + 'a -where - F: FnOnce(T) -> R + 'a, -{ - || f(t) -} - // It is currently undefined behavior to unwind from Rust code into foreign code, so we can wrap // our JNI functions and turn these panics into a `RuntimeException`. pub fn try_unwrap_or_throw(env: &JNIEnv, f: F) -> T where T: JNIDefault, - F: FnOnce(JNIEnv) -> Result + UnwindSafe, + F: FnOnce(&mut Env) -> Result + UnwindSafe, { - let mut env1 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() }; - let env2 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() }; - unwrap_or_throw_default( - &mut env1, - flatten(catch_unwind(curry(f, env2)).map_err(CometError::from)), - ) + let raw = env.as_raw(); + let mut env1 = unsafe { JNIEnv::from_raw(raw) }; + match env1.with_env(f).into_outcome() { + Outcome::Ok(value) => value, + Outcome::Err(err) => { + let mut guard = unsafe { jni::AttachGuard::from_unowned(raw) }; + unwrap_or_throw_default(guard.borrow_env_mut(), Err(err)) + } + Outcome::Panic(payload) => { + let mut guard = unsafe { jni::AttachGuard::from_unowned(raw) }; + unwrap_or_throw_default(guard.borrow_env_mut(), Err(CometError::from(payload))) + } + } } #[cfg(test)] diff --git a/native/jni-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs index 456fbdf688..f63cbbdb3e 100644 --- a/native/jni-bridge/src/lib.rs +++ b/native/jni-bridge/src/lib.rs @@ -24,9 +24,9 @@ use jni::objects::JClass; use jni::{ errors::Error, - objects::{JMethodID, JObject, JString, JThrowable, JValueGen, JValueOwned}, + objects::{JMethodID, JObject, JString, JThrowable, JValueOwned}, signature::ReturnType, - AttachGuard, JNIEnv, JavaVM, + AttachGuard, Env, JavaVM, }; use once_cell::sync::OnceCell; @@ -127,15 +127,15 @@ macro_rules! jni_new_global_ref { /// Wrapper for JString. Because we cannot implement `TryFrom` trait for `JString` as they /// are defined in different crates. pub struct StringWrapper<'a> { - value: JString<'a>, + value: JObject<'a>, } impl<'a> StringWrapper<'a> { - pub fn new(value: JString<'a>) -> StringWrapper<'a> { + pub fn new(value: JObject<'a>) -> StringWrapper<'a> { Self { value } } - pub fn get(&self) -> &JString<'_> { + pub fn get(&self) -> &JObject<'_> { &self.value } } @@ -159,7 +159,7 @@ impl<'a> TryFrom> for StringWrapper<'a> { fn try_from(value: JValueOwned<'a>) -> Result, Error> { match value { - JValueGen::Object(b) => Ok(StringWrapper::new(JString::from(b))), + JValueOwned::Object(b) => Ok(StringWrapper::new(b)), _ => Err(Error::WrongJValueType("object", value.type_name())), } } @@ -170,7 +170,7 @@ impl<'a> TryFrom> for BinaryWrapper<'a> { fn try_from(value: JValueOwned<'a>) -> Result, Error> { match value { - JValueGen::Object(b) => Ok(BinaryWrapper::new(b)), + JValueOwned::Object(b) => Ok(BinaryWrapper::new(b)), _ => Err(Error::WrongJValueType("object", value.type_name())), } } @@ -224,29 +224,47 @@ static JVM_CLASSES: OnceCell = OnceCell::new(); impl JVMClasses<'_> { /// Creates a new JVMClasses struct. - pub fn init(env: &mut JNIEnv) { + pub fn init(env: &mut Env) { JVM_CLASSES.get_or_init(|| { - // A hack to make the `JNIEnv` static. It is not safe but we don't really use the - // `JNIEnv` except for creating the global references of the classes. - let env = unsafe { std::mem::transmute::<&mut JNIEnv, &'static mut JNIEnv>(env) }; + // A hack to make the `Env` static. It is not safe but we don't really use the + // `Env` except for creating the global references of the classes. + let env = unsafe { std::mem::transmute::<&mut Env, &'static mut Env>(env) }; - let java_lang_object = env.find_class("java/lang/Object").unwrap(); + let java_lang_object = env.find_class(jni::jni_str!("java/lang/Object")).unwrap(); let object_get_class_method = env - .get_method_id(&java_lang_object, "getClass", "()Ljava/lang/Class;") + .get_method_id( + &java_lang_object, + jni::jni_str!("getClass"), + jni::jni_sig!("()Ljava/lang/Class;"), + ) .unwrap(); - let java_lang_class = env.find_class("java/lang/Class").unwrap(); + let java_lang_class = env.find_class(jni::jni_str!("java/lang/Class")).unwrap(); let class_get_name_method = env - .get_method_id(&java_lang_class, "getName", "()Ljava/lang/String;") + .get_method_id( + &java_lang_class, + jni::jni_str!("getName"), + jni::jni_sig!("()Ljava/lang/String;"), + ) .unwrap(); - let java_lang_throwable = env.find_class("java/lang/Throwable").unwrap(); + let java_lang_throwable = env + .find_class(jni::jni_str!("java/lang/Throwable")) + .unwrap(); let throwable_get_message_method = env - .get_method_id(&java_lang_throwable, "getMessage", "()Ljava/lang/String;") + .get_method_id( + &java_lang_throwable, + jni::jni_str!("getMessage"), + jni::jni_sig!("()Ljava/lang/String;"), + ) .unwrap(); let throwable_get_cause_method = env - .get_method_id(&java_lang_throwable, "getCause", "()Ljava/lang/Throwable;") + .get_method_id( + &java_lang_throwable, + jni::jni_str!("getCause"), + jni::jni_sig!("()Ljava/lang/Throwable;"), + ) .unwrap(); // SAFETY: According to the documentation for `JMethodID`, it is our @@ -284,19 +302,25 @@ impl JVMClasses<'_> { ); unsafe { let java_vm = JAVA_VM.get_unchecked(); - java_vm.attach_current_thread().map_err(|e| { - CometError::Internal(format!( - "JVMClasses::get_env() failed to attach current thread: {e}" - )) - }) + let mut scope = jni::ScopeToken::default(); + let guard = java_vm + .attach_current_thread_guard(Default::default, &mut scope) + .map_err(|e| { + CometError::Internal(format!( + "JVMClasses::get_env() failed to attach current thread: {e}" + )) + })?; + Ok(std::mem::transmute::, AttachGuard<'static>>(guard)) } } } -pub fn check_exception(env: &mut JNIEnv) -> CometResult> { - let result = if env.exception_check()? { - let exception = env.exception_occurred()?; - env.exception_clear()?; +pub fn check_exception(env: &mut Env) -> CometResult> { + let result = if env.exception_check() { + let exception = env + .exception_occurred() + .expect("exception_check returned true without an exception"); + env.exception_clear(); let exception_err = convert_exception(env, &exception)?; Some(exception_err) } else { @@ -310,7 +334,7 @@ pub fn check_exception(env: &mut JNIEnv) -> CometResult> { /// 1. get the `Class` object of the input `throwable` via `Object#getClass` method /// 2. get the exception class name via calling `Class#getName` on the above object fn get_throwable_class_name( - env: &mut JNIEnv, + env: &mut Env, jvm_classes: &JVMClasses, throwable: &JThrowable, ) -> CometResult { @@ -323,16 +347,17 @@ fn get_throwable_class_name( &[], )? .l()?; + let class_obj = unsafe { JClass::from_raw(env, class_obj.into_raw()) }; let class_name = env .call_method_unchecked( - class_obj, + &class_obj, jvm_classes.class_get_name_method, ReturnType::Object, &[], )? - .l()? - .into(); - let class_name_str = env.get_string(&class_name)?.into(); + .l()?; + let class_name = unsafe { JString::from_raw(env, class_name.into_raw()) }; + let class_name_str = class_name.try_to_string(env)?; Ok(class_name_str) } @@ -340,7 +365,7 @@ fn get_throwable_class_name( /// Get the exception message via calling `Throwable#getMessage` on the throwable object fn get_throwable_message( - env: &mut JNIEnv, + env: &mut Env, jvm_classes: &JVMClasses, throwable: &JThrowable, ) -> CometResult { @@ -352,10 +377,10 @@ fn get_throwable_message( ReturnType::Object, &[], )? - .l()? - .into(); + .l() + .map(|obj| unsafe { JString::from_raw(env, obj.into_raw()) })?; let message_str = if !message.is_null() { - env.get_string(&message)?.into() + message.try_to_string(env)? } else { String::from("null") }; @@ -367,8 +392,8 @@ fn get_throwable_message( ReturnType::Object, &[], )? - .l()? - .into(); + .l() + .map(|obj| unsafe { JThrowable::from_raw(env, obj.into_raw()) })?; if !cause.is_null() { let cause_class_name = get_throwable_class_name(env, jvm_classes, &cause)?; @@ -386,7 +411,7 @@ fn get_throwable_message( /// this converts it into a `CometError::JavaException` with the exception class name /// and exception message. This error can then be populated to the JVM side to let /// users know the cause of the native side error. -pub fn convert_exception(env: &mut JNIEnv, throwable: &JThrowable) -> CometResult { +pub fn convert_exception(env: &mut Env, throwable: &JThrowable) -> CometResult { let cache = JVMClasses::get(); let exception_class_name_str = get_throwable_class_name(env, cache, throwable)?; let message_str = get_throwable_message(env, cache, throwable)?;