diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java index d06fa31661a1..09d5a9e221e8 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableMergeUnion.java @@ -22,10 +22,8 @@ import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.linq4j.tree.ParameterExpression; import org.apache.calcite.plan.RelOptCluster; -import org.apache.calcite.plan.RelTrait; import org.apache.calcite.plan.RelTraitSet; import org.apache.calcite.rel.RelCollation; -import org.apache.calcite.rel.RelCollationTraitDef; import org.apache.calcite.rel.RelNode; import org.apache.calcite.util.BuiltInMethod; import org.apache.calcite.util.Pair; @@ -49,13 +47,17 @@ protected EnumerableMergeUnion(RelOptCluster cluster, RelTraitSet traitSet, throw new IllegalArgumentException("EnumerableMergeUnion with no collation"); } for (RelNode input : inputs) { - final RelTrait inputCollationTrait = - input.getTraitSet().getTrait(RelCollationTraitDef.INSTANCE); + // Use getCollations() rather than getTrait() so that we handle the case + // where the input's collation slot holds a RelCompositeTrait (multiple + // collations). For each required collation, at least one of the input's + // collations must satisfy it. + final List inputCollations = input.getTraitSet().getCollations(); for (RelCollation collation : collations) { - if (inputCollationTrait == null || !inputCollationTrait.satisfies(collation)) { + boolean satisfied = inputCollations.stream().anyMatch(ic -> ic.satisfies(collation)); + if (!satisfied) { throw new IllegalArgumentException("EnumerableMergeUnion input does " + "not satisfy collation. EnumerableMergeUnion collation: " - + collation + ". Input collation: " + inputCollationTrait + ". Input: " + + collation + ". Input collations: " + inputCollations + ". Input: " + input); } } diff --git a/core/src/main/java/org/apache/calcite/plan/RelTraitSet.java b/core/src/main/java/org/apache/calcite/plan/RelTraitSet.java index 65747426062c..79b4ebbc411a 100644 --- a/core/src/main/java/org/apache/calcite/plan/RelTraitSet.java +++ b/core/src/main/java/org/apache/calcite/plan/RelTraitSet.java @@ -87,7 +87,13 @@ public static RelTraitSet createEmpty() { * {@link #size()} or less than 0. */ public RelTrait getTrait(int index) { - return traits[index]; + final RelTrait trait = traits[index]; + if (trait instanceof RelCompositeTrait) { + throw new IllegalStateException("Trait index " + index + + " has multiple values in this trait set; " + + "use getTraits(RelTraitDef) instead of getTrait(RelTraitDef)"); + } + return trait; } /** @@ -110,21 +116,29 @@ public List getTraits(int index) { } @Override public RelTrait get(int index) { - return getTrait(index); + return traits[index]; } /** * Returns whether a given kind of trait is enabled. */ public boolean isEnabled(RelTraitDef traitDef) { - return getTrait(traitDef) != null; + return findIndex(traitDef) >= 0; } /** * Retrieves a RelTrait of the given type from the set. * + *

If this trait def supports multiple values (i.e. its trait implements + * {@link RelMultipleTrait}), the underlying slot may contain a + * {@link RelCompositeTrait} when more than one value is present. In that + * case this method throws {@link IllegalStateException}; use + * {@link #getTraits(RelTraitDef)} instead. + * * @param traitDef the type of RelTrait to retrieve * @return the RelTrait, or null if not found + * @throws IllegalStateException if the slot holds a composite (multiple) + * trait; use {@link #getTraits(RelTraitDef)} in that case */ public @Nullable T getTrait(RelTraitDef traitDef) { int index = findIndex(traitDef); @@ -375,17 +389,44 @@ public RelTraitSet getDefaultSansConvention() { * {@link RelDistributionTraitDef#INSTANCE}, or null if the * {@link RelDistributionTraitDef#INSTANCE} is not registered * in this traitSet. + * + *

If this trait set contains multiple distributions (a composite trait), + * this method throws {@link IllegalStateException}. Use + * {@link #getDistributions()} to handle both the single and multi-distribution + * cases uniformly. */ @SuppressWarnings("unchecked") public @Nullable T getDistribution() { return (@Nullable T) getTrait(RelDistributionTraitDef.INSTANCE); } + /** + * Returns {@link RelDistribution} traits defined by + * {@link RelDistributionTraitDef#INSTANCE}. + * + *

Returns an empty list when the trait def is not registered, a + * singleton list for the common single-distribution case, and a list with + * more than one element when a {@link RelCompositeTrait} is present. + */ + @SuppressWarnings("unchecked") + public List getDistributions() { + int index = findIndex(RelDistributionTraitDef.INSTANCE); + if (index < 0) { + return ImmutableList.of(); + } + return (List) (List) getTraits(index); + } + /** * Returns {@link RelCollation} trait defined by * {@link RelCollationTraitDef#INSTANCE}, or null if the * {@link RelCollationTraitDef#INSTANCE} is not registered * in this traitSet. + * + *

If this trait set contains multiple collations (a composite trait), + * this method throws {@link IllegalStateException}. Use + * {@link #getCollations()} to handle both the single and multi-collation + * cases uniformly. */ @SuppressWarnings("unchecked") public @Nullable T getCollation() { @@ -395,17 +436,19 @@ public RelTraitSet getDefaultSansConvention() { /** * Returns {@link RelCollation} traits defined by * {@link RelCollationTraitDef#INSTANCE}. + * + *

Returns an empty list when the trait def is not registered, a + * singleton list for the common single-collation case, and a list with + * more than one element when a {@link RelCompositeTrait} is present. */ @SuppressWarnings("unchecked") public List getCollations() { - RelCollation trait = getTrait(RelCollationTraitDef.INSTANCE); - if (trait == null) { + int index = findIndex(RelCollationTraitDef.INSTANCE); + if (index < 0) { return ImmutableList.of(); } - if (trait instanceof RelCompositeTrait) { - return ((RelCompositeTrait) trait).traitList(); - } - return ImmutableList.of(trait); + // getTraits(int) already unwraps RelCompositeTrait transparently. + return (List) (List) getTraits(index); } /** @@ -577,8 +620,22 @@ public boolean contains(RelTrait trait) { */ public boolean containsIfApplicable(RelTrait trait) { // Note that '==' is sufficient, because trait should be canonized. - final RelTrait trait1 = getTrait(trait.getTraitDef()); - return trait1 == null || trait1 == trait; + int index = findIndex(trait.getTraitDef()); + if (index < 0) { + // TraitDef not registered in this set → treat as "not applicable" → true + return true; + } + final RelTrait stored = getTrait(index); + if (stored instanceof RelCompositeTrait) { + // Any member of the composite matching the sought trait counts. + for (Object t : ((RelCompositeTrait) stored).traitList()) { + if (t == trait) { + return true; + } + } + return false; + } + return stored == trait; } /** diff --git a/core/src/test/java/org/apache/calcite/plan/RelTraitTest.java b/core/src/test/java/org/apache/calcite/plan/RelTraitTest.java index 6a9ddbaa49a5..2760b1e157e7 100644 --- a/core/src/test/java/org/apache/calcite/plan/RelTraitTest.java +++ b/core/src/test/java/org/apache/calcite/plan/RelTraitTest.java @@ -20,6 +20,10 @@ import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelCollationTraitDef; import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.test.RelBuilderTest; +import org.apache.calcite.tools.FrameworkConfig; +import org.apache.calcite.tools.RelBuilder; import com.google.common.collect.ImmutableList; @@ -33,6 +37,7 @@ import static org.hamcrest.Matchers.hasSize; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static java.lang.Integer.toHexString; @@ -98,4 +103,34 @@ private void assertCanonical(String message, Supplier> collat RelTraitSet traits3 = traits2.replace(RelCollations.of(1)); assertFalse(traits3.equalsSansConvention(traits2)); } + + /** Test for + * [CALCITE-7431] + * RelTraitSet#getTrait seems to mishandle RelCompositeTrait. */ + @Test void testRelCompositeTrait() { + // Build: EMP -> Sort(MGR asc) -> Project(MGR, MGR as MGR2) + // The project maps both output columns 0 and 1 back to input column 3 + // (MGR), so the planner derives two collations: [0 ASC] and [1 ASC], which + // are stored as a RelCompositeTrait in the output trait set. + final FrameworkConfig config = RelBuilderTest.config().build(); + final RelBuilder b = RelBuilder.create(config); + final RelNode in = b + .scan("EMP") + .sort(3) // MGR asc + .project(b.field(3), b.alias(b.field(3), "MGR2")) // MGR, MGR as MGR2 + .build(); + + final RelTraitSet traitSet = in.getTraitSet(); + + final List collations = traitSet.getCollations(); + assertTrue(collations.size() >= 2, + "getCollations() should expose all composite collations"); + + assertThrows(IllegalStateException.class, traitSet::getCollation, + "getCollation() should throw when a RelCompositeTrait is present"); + + assertThrows(IllegalStateException.class, + () -> traitSet.getTrait(RelCollationTraitDef.INSTANCE), + "getTrait() should throw when a RelCompositeTrait is present"); + } }