diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 682cc885cd6be..c6626bf75674c 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -722,6 +722,7 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Field Arc::new(Field::new(name, common_type, is_nullable)) } +/// coerce two types if they are Maps by coercing their inner 'entries' fields' types +/// using struct coercion +fn map_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Map(lhs_field, lhs_ordered), Map(rhs_field, rhs_ordered)) => { + struct_coercion(lhs_field.data_type(), rhs_field.data_type()).map( + |key_value_type| { + Map( + Arc::new((**lhs_field).clone().with_data_type(key_value_type)), + *lhs_ordered && *rhs_ordered, + ) + }, + ) + } + _ => None, + } +} + /// Returns the output type of applying mathematics operations such as /// `+` to arguments of `lhs_type` and `rhs_type`. fn mathematics_numerical_coercion( @@ -2303,4 +2323,49 @@ mod tests { ); Ok(()) } + + #[test] + fn test_map_coercion() -> Result<()> { + let lhs = Field::new_map( + "lhs", + "entries", + Arc::new(Field::new("keys", DataType::Utf8, false)), + Arc::new(Field::new("values", DataType::LargeUtf8, false)), + true, + false, + ); + let rhs = Field::new_map( + "rhs", + "kvp", + Arc::new(Field::new("k", DataType::Utf8, false)), + Arc::new(Field::new("v", DataType::Utf8, true)), + false, + true, + ); + + let expected = Field::new_map( + "expected", + "entries", // struct coercion takes lhs name + Arc::new(Field::new( + "keys", // struct coercion takes lhs name + DataType::Utf8, + false, + )), + Arc::new(Field::new( + "values", // struct coercion takes lhs name + DataType::LargeUtf8, // lhs is large string + true, // rhs is nullable + )), + false, // both sides must be sorted + true, // rhs is nullable + ); + + test_coercion_binary_rule!( + lhs.data_type(), + rhs.data_type(), + Operator::Eq, + expected.data_type().clone() + ); + Ok(()) + } }