195 changes: 65 additions & 130 deletions native/Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions native/core/Cargo.toml
@@ -46,7 +46,7 @@ async-trait = { workspace = true }
 log = "0.4"
 log4rs = "1.4.0"
 prost = "0.14.3"
-jni = "0.21"
+jni = "0.22.4"
 snap = "1.1"
 # we disable default features in lz4_flex to force the use of the faster unsafe encoding and decoding implementation
 lz4_flex = { version = "0.13.0", default-features = false, features = ["frame"] }
@@ -91,7 +91,7 @@ hdrs = { version = "0.3.2", features = ["vendored"] }
 [dev-dependencies]
 pprof = { version = "0.15", features = ["flamegraph"] }
 criterion = { version = "0.7", features = ["async", "async_tokio", "async_std"] }
-jni = { version = "0.21", features = ["invocation"] }
+jni = { version = "0.22.4", features = ["invocation"] }
 lazy_static = "1.4"
 assertables = "9"
 hex = "0.4.3"

41 changes: 22 additions & 19 deletions native/core/src/execution/expressions/subquery.rs
@@ -25,7 +25,7 @@ use datafusion::common::{internal_err, ScalarValue};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion::physical_expr::PhysicalExpr;
 use jni::{
-    objects::JByteArray,
+    objects::{JByteArray, JString},
     sys::{jboolean, jbyte, jint, jlong, jshort},
 };
 use std::{
@@ -81,67 +81,68 @@ impl PhysicalExpr for Subquery {
 
     fn evaluate(&self, _: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
         let mut env = JVMClasses::get_env()?;
+        let env = env.borrow_env_mut();
 
         unsafe {
-            let is_null = jni_static_call!(&mut env,
+            let is_null = jni_static_call!(env,
                 comet_exec.is_null(self.exec_context_id, self.id) -> jboolean
             )?;
 
-            if is_null > 0 {
+            if is_null {
                 return Ok(ColumnarValue::Scalar(ScalarValue::try_from(
                     &self.data_type,
                 )?));
             }
 
             match &self.data_type {
                 DataType::Boolean => {
-                    let r = jni_static_call!(&mut env,
+                    let r = jni_static_call!(env,
                         comet_exec.get_bool(self.exec_context_id, self.id) -> jboolean
                     )?;
-                    Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r > 0))))
+                    Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r))))
                 }
                 DataType::Int8 => {
-                    let r = jni_static_call!(&mut env,
+                    let r = jni_static_call!(env,
                         comet_exec.get_byte(self.exec_context_id, self.id) -> jbyte
                     )?;
                     Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(r))))
                 }
                 DataType::Int16 => {
-                    let r = jni_static_call!(&mut env,
+                    let r = jni_static_call!(env,
                         comet_exec.get_short(self.exec_context_id, self.id) -> jshort
                     )?;
                     Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(r))))
                 }
                 DataType::Int32 => {
-                    let r = jni_static_call!(&mut env,
+                    let r = jni_static_call!(env,
                         comet_exec.get_int(self.exec_context_id, self.id) -> jint
                     )?;
                     Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(r))))
                 }
                 DataType::Int64 => {
-                    let r = jni_static_call!(&mut env,
+                    let r = jni_static_call!(env,
                         comet_exec.get_long(self.exec_context_id, self.id) -> jlong
                     )?;
                     Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(r))))
                 }
                 DataType::Float32 => {
-                    let r = jni_static_call!(&mut env,
+                    let r = jni_static_call!(env,
                         comet_exec.get_float(self.exec_context_id, self.id) -> f32
                     )?;
                     Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(r))))
                 }
                 DataType::Float64 => {
-                    let r = jni_static_call!(&mut env,
+                    let r = jni_static_call!(env,
                         comet_exec.get_double(self.exec_context_id, self.id) -> f64
                     )?;
 
                     Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(r))))
                 }
                 DataType::Decimal128(p, s) => {
-                    let bytes = jni_static_call!(&mut env,
+                    let bytes = jni_static_call!(env,
                         comet_exec.get_decimal(self.exec_context_id, self.id) -> BinaryWrapper
                     )?;
-                    let bytes: &JByteArray = bytes.get().into();
+                    let bytes = JByteArray::from_raw(env, bytes.get().as_raw());
                     let slice = env.convert_byte_array(bytes).unwrap();
 
                     Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
@@ -151,14 +152,14 @@ impl PhysicalExpr for Subquery {
                     )))
                 }
                 DataType::Date32 => {
-                    let r = jni_static_call!(&mut env,
+                    let r = jni_static_call!(env,
                         comet_exec.get_int(self.exec_context_id, self.id) -> jint
                     )?;
 
                     Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(r))))
                 }
                 DataType::Timestamp(TimeUnit::Microsecond, timezone) => {
-                    let r = jni_static_call!(&mut env,
+                    let r = jni_static_call!(env,
                         comet_exec.get_long(self.exec_context_id, self.id) -> jlong
                     )?;
 
@@ -168,18 +169,20 @@ impl PhysicalExpr for Subquery {
                     )))
                 }
                 DataType::Utf8 => {
-                    let string = jni_static_call!(&mut env,
+                    let string = jni_static_call!(env,
                         comet_exec.get_string(self.exec_context_id, self.id) -> StringWrapper
                     )?;
 
-                    let string = env.get_string(string.get()).unwrap().into();
+                    let string = unsafe { JString::from_raw(env, string.get().as_raw()) }
+                        .try_to_string(env)
+                        .unwrap();
                     Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))))
                 }
                 DataType::Binary => {
-                    let bytes = jni_static_call!(&mut env,
+                    let bytes = jni_static_call!(env,
                         comet_exec.get_binary(self.exec_context_id, self.id) -> BinaryWrapper
                     )?;
-                    let bytes: &JByteArray = bytes.get().into();
+                    let bytes = JByteArray::from_raw(env, bytes.get().as_raw());
                     let slice = env.convert_byte_array(bytes).unwrap();
 
                     Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(slice))))
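
Note: the edits in this file are one migration applied arm by arm. With the jni upgrade, calls that return jboolean now yield a Rust bool instead of a u8 (in jni 0.21, jni::sys::jboolean is an alias for u8), so the `> 0` truthiness checks disappear; and the StringWrapper/BinaryWrapper payloads are rebuilt into typed JString/JByteArray handles via from_raw before conversion. A minimal, self-contained sketch of the jboolean half of the change; the functions below are stand-ins for jni_static_call! results, not the real jni API:

// Stand-in alias: in jni 0.21, `jni::sys::jboolean` is `u8`.
type JBoolean021 = u8;

// Hypothetical callers; real code gets these values from jni_static_call!.
fn is_null_old() -> JBoolean021 {
    1 // JNI_TRUE
}

fn is_null_new() -> bool {
    true // 0.22-style result, as inferred from this diff
}

fn main() {
    // 0.21 style: the raw u8 must be compared against zero.
    if is_null_old() > 0 {
        println!("null (0.21 style)");
    }

    // 0.22 style: the call already yields a bool, so `> 0` goes away.
    if is_null_new() {
        println!("null (0.22 style)");
    }
}

The same reasoning removes the conversion in the DataType::Boolean arm: ScalarValue::Boolean(Some(r)) can take the macro's result directly.
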
39 changes: 20 additions & 19 deletions native/core/src/execution/jni_api.rs
@@ -70,7 +70,7 @@ use jni::{
         ReleaseMode,
     },
     sys::{jboolean, jdouble, jint, jlong},
-    JNIEnv,
+    Env, JNIEnv,
 };
 use std::collections::HashMap;
 use std::path::PathBuf;
@@ -152,13 +152,13 @@ struct ExecutionContext {
     /// The input sources for the DataFusion plan
     pub scans: Vec<ScanExec>,
     /// The global reference of input sources for the DataFusion plan
-    pub input_sources: Vec<Arc<GlobalRef>>,
+    pub input_sources: Vec<Arc<GlobalRef<JObject<'static>>>>,
     /// The record batch stream to pull results from
     pub stream: Option<SendableRecordBatchStream>,
     /// Receives batches from a spawned tokio task (async I/O path)
     pub batch_receiver: Option<mpsc::Receiver<DataFusionResult<RecordBatch>>>,
     /// Native metrics
-    pub metrics: Arc<GlobalRef>,
+    pub metrics: Arc<GlobalRef<JObject<'static>>>,
     // The interval in milliseconds to update metrics
     pub metrics_update_interval: Option<Duration>,
     // The last update time of metrics
@@ -239,7 +239,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         let mut input_sources = vec![];
         let num_inputs = env.get_array_length(&iterators)?;
         for i in 0..num_inputs {
-            let input_source = env.get_object_array_element(&iterators, i)?;
+            let input_source = env.get_object_array_element(&iterators, i as usize)?;
             let input_source = Arc::new(jni_new_global_ref!(env, input_source)?);
             input_sources.push(input_source);
         }
@@ -268,7 +268,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         let num_local_dirs = env.get_array_length(&local_dirs)?;
         let mut local_dirs_vec = vec![];
         for i in 0..num_local_dirs {
-            let local_dir: JString = env.get_object_array_element(&local_dirs, i)?.into();
+            let local_dir = env.get_object_array_element(&local_dirs, i as usize)?;
+            let local_dir = unsafe { JString::from_raw(&*env, local_dir.into_raw()) };
             let local_dir = env.get_string(&local_dir)?;
             local_dirs_vec.push(local_dir.into());
         }
@@ -296,7 +297,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         // Handle key unwrapper for encrypted files
         if !key_unwrapper_obj.is_null() {
             let encryption_factory = CometEncryptionFactory {
-                key_unwrapper: jni_new_global_ref!(env, key_unwrapper_obj)?,
+                key_unwrapper: Arc::new(jni_new_global_ref!(env, key_unwrapper_obj)?),
             };
             session.runtime_env().register_parquet_encryption_factory(
                 ENCRYPTION_FACTORY_ID,
@@ -411,7 +412,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
 
 /// Prepares arrow arrays for output.
 fn prepare_output(
-    env: &mut JNIEnv,
+    env: &mut Env,
     array_addrs: JLongArray,
     schema_addrs: JLongArray,
     output_batch: RecordBatch,
@@ -698,8 +699,7 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
     })
 }
 
-/// Updates the metrics of the query plan.
-fn update_metrics(env: &mut JNIEnv, exec_context: &mut ExecutionContext) -> CometResult<()> {
+fn update_metrics(env: &mut Env, exec_context: &mut ExecutionContext) -> CometResult<()> {
     if let Some(native_query) = &exec_context.root_op {
         let metrics = exec_context.metrics.as_obj();
         update_comet_metric(env, metrics, native_query)
@@ -724,15 +724,15 @@ fn log_plan_metrics(exec_context: &ExecutionContext, stage_id: jint, partition:
 }
 
 fn convert_datatype_arrays(
-    env: &'_ mut JNIEnv<'_>,
+    env: &mut Env,
     serialized_datatypes: JObjectArray,
 ) -> JNIResult<Vec<ArrowDataType>> {
     let array_len = env.get_array_length(&serialized_datatypes)?;
     let mut res: Vec<ArrowDataType> = Vec::new();
 
     for i in 0..array_len {
-        let inner_array = env.get_object_array_element(&serialized_datatypes, i)?;
-        let inner_array: JByteArray = inner_array.into();
+        let inner_array = env.get_object_array_element(&serialized_datatypes, i as usize)?;
+        let inner_array = unsafe { JByteArray::from_raw(&*env, inner_array.into_raw()) };
         let bytes = env.convert_byte_array(inner_array)?;
         let data_type = serde::deserialize_data_type(bytes.as_slice()).unwrap();
         let arrow_dt = to_arrow_datatype(&data_type);
@@ -788,7 +788,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative
 
         let output_path: String = env.get_string(&file_path).unwrap().into();
 
-        let checksum_enabled = checksum_enabled == 1;
+        let checksum_enabled = checksum_enabled;
         let current_checksum = if current_checksum == i64::MIN {
             // Initial checksum is not available.
             None
@@ -886,7 +886,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_decodeShuffleBlock(
             let length = length as usize;
             let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) };
             let batch = read_ipc_compressed(slice)?;
-            prepare_output(&mut env, array_addrs, schema_addrs, batch, false)
+            prepare_output(env, array_addrs, schema_addrs, batch, false)
         })
     })
 }
@@ -957,7 +957,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_columnarToRowInit(
 ) -> jlong {
     try_unwrap_or_throw(&e, |mut env| {
         // Deserialize the schema
-        let schema = convert_datatype_arrays(&mut env, serialized_schema)?;
+        let schema = convert_datatype_arrays(env, serialized_schema)?;
 
         // Create the context
         let ctx = Box::new(ColumnarToRowContext::new(schema, batch_size as usize));
@@ -1030,17 +1030,18 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_columnarToRowConvert(
         let (buffer_ptr, offsets, lengths) = ctx.convert(&arrays, num_rows as usize)?;
 
         // Create Java int arrays for offsets and lengths
-        let offsets_array = env.new_int_array(offsets.len() as i32)?;
+        let offsets_array = env.new_int_array(offsets.len())?;
         env.set_int_array_region(&offsets_array, 0, offsets)?;
 
-        let lengths_array = env.new_int_array(lengths.len() as i32)?;
+        let lengths_array = env.new_int_array(lengths.len())?;
         env.set_int_array_region(&lengths_array, 0, lengths)?;
 
         // Create the NativeColumnarToRowInfo object
-        let info_class = env.find_class("org/apache/comet/NativeColumnarToRowInfo")?;
+        let info_class =
+            env.find_class(jni::jni_str!("org/apache/comet/NativeColumnarToRowInfo"))?;
         let info_obj = env.new_object(
             info_class,
-            "(J[I[I)V",
+            jni::jni_sig!("(J[I[I)V"),
             &[
                 jni::objects::JValue::Long(buffer_ptr as jlong),
                 jni::objects::JValue::Object(&offsets_array),
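
Note: two recurring patterns in this file are worth calling out. GlobalRef is now generic over the referenced object type, so handles that were previously untyped (Arc<GlobalRef>) become Arc<GlobalRef<JObject<'static>>>, and array accessors now take usize indices, hence the `i as usize` conversions. A self-contained sketch of why the typed form helps; the types below are stand-ins, not the real jni API:

use std::marker::PhantomData;
use std::sync::Arc;

// Stand-in for jni's JObject; carries only a lifetime.
struct JObject<'a>(PhantomData<&'a ()>);

// Stand-in generic global reference, modeled on the signatures in this PR.
struct GlobalRef<T> {
    obj: T,
}

impl<T> GlobalRef<T> {
    fn as_obj(&self) -> &T {
        &self.obj
    }
}

// Mirrors ExecutionContext's fields after the change: the referent type is
// explicit instead of an untyped reference.
struct ExecutionContext {
    input_sources: Vec<Arc<GlobalRef<JObject<'static>>>>,
    metrics: Arc<GlobalRef<JObject<'static>>>,
}

fn main() {
    let ctx = ExecutionContext {
        input_sources: vec![Arc::new(GlobalRef { obj: JObject(PhantomData) })],
        metrics: Arc::new(GlobalRef { obj: JObject(PhantomData) }),
    };
    let _handle: &JObject<'static> = ctx.metrics.as_obj(); // typed access
    assert_eq!(ctx.input_sources.len(), 1);
}

The payoff is that the referent type travels with the handle through ExecutionContext, create_memory_pool, and CometFairMemoryPool without casts.
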
12 changes: 7 additions & 5 deletions native/core/src/execution/memory_pools/fair_pool.rs
@@ -20,7 +20,7 @@ use std::{
     sync::Arc,
 };
 
-use jni::objects::GlobalRef;
+use jni::objects::{GlobalRef, JObject};
 
 use crate::{errors::CometResult, jvm_bridge::JVMClasses};
 use datafusion::common::resources_err;
@@ -34,7 +34,7 @@ use parking_lot::Mutex;
 /// A DataFusion fair `MemoryPool` implementation for Comet. Internally this is
 /// implemented via delegating calls to [`crate::jvm_bridge::CometTaskMemoryManager`].
 pub struct CometFairMemoryPool {
-    task_memory_manager_handle: Arc<GlobalRef>,
+    task_memory_manager_handle: Arc<GlobalRef<JObject<'static>>>,
     pool_size: usize,
     state: Mutex<CometFairPoolState>,
 }
@@ -57,7 +57,7 @@ impl Debug for CometFairMemoryPool {
 
 impl CometFairMemoryPool {
     pub fn new(
-        task_memory_manager_handle: Arc<GlobalRef>,
+        task_memory_manager_handle: Arc<GlobalRef<JObject<'static>>>,
         pool_size: usize,
     ) -> CometFairMemoryPool {
         Self {
@@ -69,18 +69,20 @@ impl CometFairMemoryPool {
 
     fn acquire(&self, additional: usize) -> CometResult<i64> {
         let mut env = JVMClasses::get_env()?;
+        let env = env.borrow_env_mut();
         let handle = self.task_memory_manager_handle.as_obj();
         unsafe {
-            jni_call!(&mut env,
+            jni_call!(env,
                 comet_task_memory_manager(handle).acquire_memory(additional as i64) -> i64)
         }
     }
 
     fn release(&self, size: usize) -> CometResult<()> {
         let mut env = JVMClasses::get_env()?;
+        let env = env.borrow_env_mut();
         let handle = self.task_memory_manager_handle.as_obj();
         unsafe {
-            jni_call!(&mut env, comet_task_memory_manager(handle).release_memory(size as i64) -> ())
+            jni_call!(env, comet_task_memory_manager(handle).release_memory(size as i64) -> ())
         }
     }
 }
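
Note: acquire and release show the other recurring pattern in this PR: JVMClasses::get_env() now hands back a guard rather than a bare environment, and the usable Env is borrowed out of it with borrow_env_mut() before the jni_call!. A self-contained sketch of that guard shape, with stand-in types (the real guard lives in Comet's jvm_bridge and its exact shape may differ):

// Stand-in for jni's environment handle.
struct Env;

impl Env {
    fn acquire_memory(&mut self, additional: i64) -> i64 {
        additional // pretend the JVM granted the full request
    }
}

// A guard that keeps the thread attached, like the value get_env() returns.
struct EnvGuard {
    env: Env,
}

impl EnvGuard {
    // The borrow is tied to the guard's lifetime, so the attachment
    // outlives every call made through it.
    fn borrow_env_mut(&mut self) -> &mut Env {
        &mut self.env
    }
}

fn get_env() -> EnvGuard {
    EnvGuard { env: Env }
}

fn main() {
    let mut env = get_env(); // hold the guard...
    let env = env.borrow_env_mut(); // ...then borrow the usable Env, as in acquire()
    let granted = env.acquire_memory(1024);
    assert_eq!(granted, 1024);
}

Keeping the guard alive in its own binding matters: the borrow returned by borrow_env_mut() must not outlive it, which the shadowed `let env = env.borrow_env_mut()` pattern guarantees.
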
4 changes: 2 additions & 2 deletions native/core/src/execution/memory_pools/mod.rs
@@ -25,7 +25,7 @@ use datafusion::execution::memory_pool::{
     FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool, UnboundedMemoryPool,
 };
 use fair_pool::CometFairMemoryPool;
-use jni::objects::GlobalRef;
+use jni::objects::{GlobalRef, JObject};
 use once_cell::sync::OnceCell;
 use std::num::NonZeroUsize;
 use std::sync::Arc;
@@ -36,7 +36,7 @@ pub(crate) use task_shared::*;
 
 pub(crate) fn create_memory_pool(
     memory_pool_config: &MemoryPoolConfig,
-    comet_task_memory_manager: Arc<GlobalRef>,
+    comet_task_memory_manager: Arc<GlobalRef<JObject<'static>>>,
     task_attempt_id: i64,
 ) -> Arc<dyn MemoryPool> {
     const NUM_TRACKED_CONSUMERS: usize = 10;
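
Note: this change is the tail end of the typed-GlobalRef ripple: the same Arc<GlobalRef<JObject<'static>>> handle created in jni_api.rs is what the pool factory accepts, so the signature here and the struct in fair_pool.rs have to move in lockstep. A compressed sketch of that ripple, again with stand-in types rather than the real jni/DataFusion API:

use std::marker::PhantomData;
use std::sync::Arc;

// Stand-ins for jni's JObject and the generic GlobalRef in this PR.
struct JObject<'a>(PhantomData<&'a ()>);
struct GlobalRef<T>(T);

// Stand-in for DataFusion's MemoryPool trait.
trait MemoryPool {}

struct CometFairMemoryPool {
    _handle: Arc<GlobalRef<JObject<'static>>>,
}
impl MemoryPool for CometFairMemoryPool {}

// Mirrors mod.rs: the factory accepts the typed handle and moves it into the pool.
fn create_memory_pool(
    handle: Arc<GlobalRef<JObject<'static>>>,
) -> Arc<dyn MemoryPool> {
    Arc::new(CometFairMemoryPool { _handle: handle })
}

fn main() {
    let handle = Arc::new(GlobalRef(JObject(PhantomData)));
    let _pool = create_memory_pool(handle);
}

With the untyped 0.21-style GlobalRef this chain compiled regardless of what the handle pointed at; the type parameter turns a mismatch into a compile error.
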