diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 212dca6cd57b0..1def37601dd1b 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -724,6 +724,163 @@ pub fn is_volatile(expr: &Arc) -> bool { is_volatile } +/// A transparent wrapper that marks a [`PhysicalExpr`] as *optional* — i.e., +/// droppable without affecting query correctness. +/// +/// This is used for filters that are performance hints (e.g., dynamic join +/// filters) as opposed to mandatory predicates. The selectivity tracker can +/// detect this wrapper via `expr.as_any().downcast_ref::()` +/// and choose to drop the filter entirely when it is not cost-effective. +/// +/// All [`PhysicalExpr`] methods are delegated to the wrapped inner expression. +/// +/// Currently used by `HashJoinExec` for dynamic join filters. When the +/// selectivity tracker drops such a filter, the join still enforces +/// correctness independently — "dropped" simply means the filter is never +/// applied as a scan-time optimization. +#[derive(Debug)] +pub struct OptionalFilterPhysicalExpr { + inner: Arc, +} + +impl OptionalFilterPhysicalExpr { + /// Create a new optional filter wrapping the given expression. + pub fn new(inner: Arc) -> Self { + Self { inner } + } + + /// Returns a clone of the inner (unwrapped) expression. + pub fn inner(&self) -> Arc { + Arc::clone(&self.inner) + } +} + +impl Display for OptionalFilterPhysicalExpr { + /// Pass through to the inner expression. Surfacing the `Optional(..)` + /// wrapper in plan output would require updating dozens of sqllogictest + /// baselines for what is purely a runtime concept (the adaptive + /// scheduler's permission to drop this filter); plan readers don't need + /// to see it. + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.inner) + } +} + +impl PartialEq for OptionalFilterPhysicalExpr { + fn eq(&self, other: &Self) -> bool { + self.inner.as_ref() == other.inner.as_ref() + } +} + +impl Eq for OptionalFilterPhysicalExpr {} + +impl Hash for OptionalFilterPhysicalExpr { + fn hash(&self, state: &mut H) { + self.inner.as_ref().hash(state); + } +} + +impl PhysicalExpr for OptionalFilterPhysicalExpr { + fn data_type(&self, input_schema: &Schema) -> Result { + self.inner.data_type(input_schema) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + self.inner.nullable(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + self.inner.evaluate(batch) + } + + fn return_field(&self, input_schema: &Schema) -> Result { + self.inner.return_field(input_schema) + } + + fn evaluate_selection( + &self, + batch: &RecordBatch, + selection: &BooleanArray, + ) -> Result { + self.inner.evaluate_selection(batch, selection) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.inner] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert_eq_or_internal_err!( + children.len(), + 1, + "OptionalFilterPhysicalExpr: expected 1 child" + ); + Ok(Arc::new(OptionalFilterPhysicalExpr::new(Arc::clone( + &children[0], + )))) + } + + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + self.inner.evaluate_bounds(children) + } + + fn propagate_constraints( + &self, + interval: &Interval, + children: &[&Interval], + ) -> Result>> { + self.inner.propagate_constraints(interval, children) + } + + #[expect(deprecated)] + fn evaluate_statistics(&self, children: &[&Distribution]) -> Result { + self.inner.evaluate_statistics(children) + } + + #[expect(deprecated)] + fn propagate_statistics( + &self, + parent: &Distribution, + children: &[&Distribution], + ) -> Result>> { + self.inner.propagate_statistics(parent, children) + } + + fn get_properties(&self, children: &[ExprProperties]) -> Result { + self.inner.get_properties(children) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> fmt::Result { + self.inner.fmt_sql(f) + } + + fn snapshot(&self) -> Result>> { + // Always unwrap the Optional wrapper for snapshot consumers (e.g. PruningPredicate). + // If inner has a snapshot, use it; otherwise return the inner directly. + Ok(Some(match self.inner.snapshot()? { + Some(snap) => snap, + None => Arc::clone(&self.inner), + })) + } + + fn snapshot_generation(&self) -> u64 { + // The wrapper itself is not dynamic; tree-walking picks up + // inner's generation via children(). + 0 + } + + fn is_volatile_node(&self) -> bool { + self.inner.is_volatile_node() + } + + fn placement(&self) -> ExpressionPlacement { + self.inner.placement() + } +} + #[cfg(test)] mod test { use crate::physical_expr::PhysicalExpr; @@ -731,6 +888,7 @@ mod test { use arrow::datatypes::{DataType, Schema}; use datafusion_expr_common::columnar_value::ColumnarValue; use std::fmt::{Display, Formatter}; + use std::hash::{Hash, Hasher}; use std::sync::Arc; #[derive(Debug, PartialEq, Eq, Hash)] @@ -905,4 +1063,104 @@ mod test { &BooleanArray::from(vec![true; 5]), ); } + + #[test] + fn test_optional_filter_downcast() { + use super::OptionalFilterPhysicalExpr; + + let inner: Arc = Arc::new(TestExpr {}); + let optional = Arc::new(OptionalFilterPhysicalExpr::new(Arc::clone(&inner))); + + // Can downcast to detect the wrapper + let as_physical: Arc = optional; + assert!( + as_physical + .downcast_ref::() + .is_some() + ); + + // Inner expr is NOT detectable as optional + assert!(inner.downcast_ref::().is_none()); + } + + #[test] + fn test_optional_filter_delegates_evaluate() { + use super::OptionalFilterPhysicalExpr; + + let inner: Arc = Arc::new(TestExpr {}); + let optional = OptionalFilterPhysicalExpr::new(Arc::clone(&inner)); + + let batch = + unsafe { RecordBatch::new_unchecked(Arc::new(Schema::empty()), vec![], 5) }; + let result = optional.evaluate(&batch).unwrap(); + let array = result.to_array(5).unwrap(); + assert_eq!(array.len(), 5); + } + + #[test] + fn test_optional_filter_children_and_with_new_children() { + use super::OptionalFilterPhysicalExpr; + + let inner: Arc = Arc::new(TestExpr {}); + let optional = Arc::new(OptionalFilterPhysicalExpr::new(Arc::clone(&inner))); + + // children() returns the inner + let children = optional.children(); + assert_eq!(children.len(), 1); + + // with_new_children preserves the wrapper + let new_inner: Arc = Arc::new(TestExpr {}); + let rewrapped = Arc::clone(&optional) + .with_new_children(vec![new_inner]) + .unwrap(); + assert!( + rewrapped + .downcast_ref::() + .is_some() + ); + } + + #[test] + fn test_optional_filter_inner() { + use super::OptionalFilterPhysicalExpr; + + let inner: Arc = Arc::new(TestExpr {}); + let optional = OptionalFilterPhysicalExpr::new(Arc::clone(&inner)); + + // inner() returns a clone of the wrapped expression + let unwrapped = optional.inner(); + assert!(unwrapped.downcast_ref::().is_some()); + } + + #[test] + fn test_optional_filter_snapshot_generation_zero() { + use super::OptionalFilterPhysicalExpr; + + let inner: Arc = Arc::new(TestExpr {}); + let optional = OptionalFilterPhysicalExpr::new(inner); + + assert_eq!(optional.snapshot_generation(), 0); + } + + #[test] + fn test_optional_filter_eq_hash() { + use super::OptionalFilterPhysicalExpr; + use std::collections::hash_map::DefaultHasher; + + let inner1: Arc = Arc::new(TestExpr {}); + let inner2: Arc = Arc::new(TestExpr {}); + + let opt1 = OptionalFilterPhysicalExpr::new(inner1); + let opt2 = OptionalFilterPhysicalExpr::new(inner2); + + // Same inner type → equal + assert_eq!(opt1, opt2); + + // Same hash + let mut h1 = DefaultHasher::new(); + let mut h2 = DefaultHasher::new(); + opt1.hash(&mut h1); + opt2.hash(&mut h2); + assert_eq!(h1.finish(), h2.finish()); + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 865887d41e111..1edf45c550bf2 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -936,6 +936,8 @@ message PhysicalExprNode { PhysicalScalarSubqueryExprNode scalar_subquery = 22; PhysicalDynamicFilterNode dynamic_filter = 23; + + PhysicalOptionalFilterNode optional_filter = 24; } } @@ -947,6 +949,10 @@ message PhysicalDynamicFilterNode { bool is_complete = 5; } +message PhysicalOptionalFilterNode { + PhysicalExprNode inner = 1; +} + message PhysicalScalarUdfNode { string name = 1; repeated PhysicalExprNode args = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index b8639afd04a89..3a9c4d27a1321 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -16972,6 +16972,9 @@ impl serde::Serialize for PhysicalExprNode { physical_expr_node::ExprType::DynamicFilter(v) => { struct_ser.serialize_field("dynamicFilter", v)?; } + physical_expr_node::ExprType::OptionalFilter(v) => { + struct_ser.serialize_field("optionalFilter", v)?; + } } } struct_ser.end() @@ -17022,6 +17025,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "scalarSubquery", "dynamic_filter", "dynamicFilter", + "optional_filter", + "optionalFilter", ]; #[allow(clippy::enum_variant_names)] @@ -17048,6 +17053,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { HashExpr, ScalarSubquery, DynamicFilter, + OptionalFilter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17091,6 +17097,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr), "scalarSubquery" | "scalar_subquery" => Ok(GeneratedField::ScalarSubquery), "dynamicFilter" | "dynamic_filter" => Ok(GeneratedField::DynamicFilter), + "optionalFilter" | "optional_filter" => Ok(GeneratedField::OptionalFilter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17267,6 +17274,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { return Err(serde::de::Error::duplicate_field("dynamicFilter")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::DynamicFilter) +; + } + GeneratedField::OptionalFilter => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("optionalFilter")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::OptionalFilter) ; } } @@ -18380,6 +18394,97 @@ impl<'de> serde::Deserialize<'de> for PhysicalNot { deserializer.deserialize_struct("datafusion.PhysicalNot", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PhysicalOptionalFilterNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.inner.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalOptionalFilterNode", len)?; + if let Some(v) = self.inner.as_ref() { + struct_ser.serialize_field("inner", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PhysicalOptionalFilterNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "inner", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Inner, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "inner" => Ok(GeneratedField::Inner), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PhysicalOptionalFilterNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PhysicalOptionalFilterNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut inner__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Inner => { + if inner__.is_some() { + return Err(serde::de::Error::duplicate_field("inner")); + } + inner__ = map_.next_value()?; + } + } + } + Ok(PhysicalOptionalFilterNode { + inner: inner__, + }) + } + } + deserializer.deserialize_struct("datafusion.PhysicalOptionalFilterNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PhysicalPlanNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index b742e82ea24ec..c661872454afb 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1336,7 +1336,7 @@ pub struct PhysicalExprNode { pub expr_id: ::core::option::Option, #[prost( oneof = "physical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21, 22, 23" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21, 22, 23, 24" )] pub expr_type: ::core::option::Option, } @@ -1393,6 +1393,8 @@ pub mod physical_expr_node { ScalarSubquery(super::PhysicalScalarSubqueryExprNode), #[prost(message, tag = "23")] DynamicFilter(::prost::alloc::boxed::Box), + #[prost(message, tag = "24")] + OptionalFilter(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1409,6 +1411,11 @@ pub struct PhysicalDynamicFilterNode { pub is_complete: bool, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PhysicalOptionalFilterNode { + #[prost(message, optional, boxed, tag = "1")] + pub inner: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalScalarUdfNode { #[prost(string, tag = "1")] pub name: ::prost::alloc::string::String, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 43ebf0474320a..41807491bda79 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -63,6 +63,7 @@ use crate::{convert_required, protobuf}; use datafusion_physical_expr::expressions::{ DynamicFilterInner, DynamicFilterPhysicalExpr, }; +use datafusion_physical_expr_common::physical_expr::OptionalFilterPhysicalExpr; impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { @@ -561,6 +562,16 @@ pub fn parse_physical_expr_with_converter( )); base_filter } + ExprType::OptionalFilter(optional_filter) => { + let inner = parse_required_physical_expr( + optional_filter.inner.as_deref(), + ctx, + "inner", + input_schema, + proto_converter, + )?; + Arc::new(OptionalFilterPhysicalExpr::new(inner)) + } ExprType::Extension(extension) => { let inputs: Vec> = extension .inputs diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 83c11cfc6b299..84df5acec73bb 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -34,6 +34,7 @@ use datafusion_expr::WindowFrame; use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr; use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr}; +use datafusion_physical_expr_common::physical_expr::OptionalFilterPhysicalExpr; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, DynamicFilterPhysicalExpr, InListExpr, @@ -569,6 +570,17 @@ pub fn serialize_physical_expr_with_converter( }), )), }) + } else if let Some(opt) = expr.downcast_ref::() { + let inner_expr = + Box::new(proto_converter.physical_expr_to_proto(&opt.inner(), codec)?); + Ok(protobuf::PhysicalExprNode { + expr_id, + expr_type: Some(protobuf::physical_expr_node::ExprType::OptionalFilter( + Box::new(protobuf::PhysicalOptionalFilterNode { + inner: Some(inner_expr), + }), + )), + }) } else { let mut buf: Vec = vec![]; match codec.try_encode_expr(value, &mut buf) {