Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.ParameterExpression;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollationTraitDef;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.util.BuiltInMethod;
import org.apache.calcite.util.Pair;
Expand All @@ -49,13 +47,17 @@ protected EnumerableMergeUnion(RelOptCluster cluster, RelTraitSet traitSet,
throw new IllegalArgumentException("EnumerableMergeUnion with no collation");
}
for (RelNode input : inputs) {
final RelTrait inputCollationTrait =
input.getTraitSet().getTrait(RelCollationTraitDef.INSTANCE);
// Use getCollations() rather than getTrait() so that we handle the case
// where the input's collation slot holds a RelCompositeTrait (multiple
// collations). For each required collation, at least one of the input's
// collations must satisfy it.
final List<RelCollation> inputCollations = input.getTraitSet().getCollations();
for (RelCollation collation : collations) {
if (inputCollationTrait == null || !inputCollationTrait.satisfies(collation)) {
boolean satisfied = inputCollations.stream().anyMatch(ic -> ic.satisfies(collation));
if (!satisfied) {
throw new IllegalArgumentException("EnumerableMergeUnion input does "
+ "not satisfy collation. EnumerableMergeUnion collation: "
+ collation + ". Input collation: " + inputCollationTrait + ". Input: "
+ collation + ". Input collations: " + inputCollations + ". Input: "
+ input);
}
}
Expand Down
79 changes: 68 additions & 11 deletions core/src/main/java/org/apache/calcite/plan/RelTraitSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,13 @@ public static RelTraitSet createEmpty() {
* {@link #size()} or less than 0.
*/
public RelTrait getTrait(int index) {
return traits[index];
final RelTrait trait = traits[index];
if (trait instanceof RelCompositeTrait) {
throw new IllegalStateException("Trait index " + index
+ " has multiple values in this trait set; "
+ "use getTraits(RelTraitDef) instead of getTrait(RelTraitDef)");
}
return trait;
}

/**
Expand All @@ -110,21 +116,29 @@ public <E extends RelMultipleTrait> List<E> getTraits(int index) {
}

@Override public RelTrait get(int index) {
return getTrait(index);
return traits[index];
}

/**
* Returns whether a given kind of trait is enabled.
*/
public <T extends RelTrait> boolean isEnabled(RelTraitDef<T> traitDef) {
return getTrait(traitDef) != null;
return findIndex(traitDef) >= 0;
}

/**
* Retrieves a RelTrait of the given type from the set.
*
* <p>If this trait def supports multiple values (i.e. its trait implements
* {@link RelMultipleTrait}), the underlying slot may contain a
* {@link RelCompositeTrait} when more than one value is present. In that
* case this method throws {@link IllegalStateException}; use
* {@link #getTraits(RelTraitDef)} instead.
*
* @param traitDef the type of RelTrait to retrieve
* @return the RelTrait, or null if not found
* @throws IllegalStateException if the slot holds a composite (multiple)
* trait; use {@link #getTraits(RelTraitDef)} in that case
*/
public <T extends RelTrait> @Nullable T getTrait(RelTraitDef<T> traitDef) {
int index = findIndex(traitDef);
Expand Down Expand Up @@ -375,17 +389,44 @@ public RelTraitSet getDefaultSansConvention() {
* {@link RelDistributionTraitDef#INSTANCE}, or null if the
* {@link RelDistributionTraitDef#INSTANCE} is not registered
* in this traitSet.
*
* <p>If this trait set contains multiple distributions (a composite trait),
* this method throws {@link IllegalStateException}. Use
* {@link #getDistributions()} to handle both the single and multi-distribution
* cases uniformly.
*/
@SuppressWarnings("unchecked")
public <T extends RelDistribution> @Nullable T getDistribution() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, getTrait and getCollation will throw an error, while getDistribution returns the first element. These are two different strategies. Can you please explain this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was indeed an overlooked point, and I've already made corrections based on the comments. I'm not sure if this PR will be useful, as it doesn't address the root cause.

return (@Nullable T) getTrait(RelDistributionTraitDef.INSTANCE);
}

/**
* Returns {@link RelDistribution} traits defined by
* {@link RelDistributionTraitDef#INSTANCE}.
*
* <p>Returns an empty list when the trait def is not registered, a
* singleton list for the common single-distribution case, and a list with
* more than one element when a {@link RelCompositeTrait} is present.
*/
@SuppressWarnings("unchecked")
public List<RelDistribution> getDistributions() {
int index = findIndex(RelDistributionTraitDef.INSTANCE);
if (index < 0) {
return ImmutableList.of();
}
return (List<RelDistribution>) (List<?>) getTraits(index);
}

/**
* Returns {@link RelCollation} trait defined by
* {@link RelCollationTraitDef#INSTANCE}, or null if the
* {@link RelCollationTraitDef#INSTANCE} is not registered
* in this traitSet.
*
* <p>If this trait set contains multiple collations (a composite trait),
* this method throws {@link IllegalStateException}. Use
* {@link #getCollations()} to handle both the single and multi-collation
* cases uniformly.
*/
@SuppressWarnings("unchecked")
public <T extends RelCollation> @Nullable T getCollation() {
Expand All @@ -395,17 +436,19 @@ public RelTraitSet getDefaultSansConvention() {
/**
* Returns {@link RelCollation} traits defined by
* {@link RelCollationTraitDef#INSTANCE}.
*
* <p>Returns an empty list when the trait def is not registered, a
* singleton list for the common single-collation case, and a list with
* more than one element when a {@link RelCompositeTrait} is present.
*/
@SuppressWarnings("unchecked")
public List<RelCollation> getCollations() {
RelCollation trait = getTrait(RelCollationTraitDef.INSTANCE);
if (trait == null) {
int index = findIndex(RelCollationTraitDef.INSTANCE);
if (index < 0) {
return ImmutableList.of();
}
if (trait instanceof RelCompositeTrait) {
return ((RelCompositeTrait<RelCollation>) trait).traitList();
}
return ImmutableList.of(trait);
// getTraits(int) already unwraps RelCompositeTrait transparently.
return (List<RelCollation>) (List<?>) getTraits(index);
}

/**
Expand Down Expand Up @@ -577,8 +620,22 @@ public boolean contains(RelTrait trait) {
*/
public boolean containsIfApplicable(RelTrait trait) {
// Note that '==' is sufficient, because trait should be canonized.
final RelTrait trait1 = getTrait(trait.getTraitDef());
return trait1 == null || trait1 == trait;
int index = findIndex(trait.getTraitDef());
if (index < 0) {
// TraitDef not registered in this set → treat as "not applicable" → true
return true;
}
final RelTrait stored = getTrait(index);
if (stored instanceof RelCompositeTrait) {
// Any member of the composite matching the sought trait counts.
for (Object t : ((RelCompositeTrait<?>) stored).traitList()) {
if (t == trait) {
return true;
}
}
return false;
}
return stored == trait;
}

/**
Expand Down
35 changes: 35 additions & 0 deletions core/src/test/java/org/apache/calcite/plan/RelTraitTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollationTraitDef;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.test.RelBuilderTest;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.RelBuilder;

import com.google.common.collect.ImmutableList;

Expand All @@ -33,6 +37,7 @@
import static org.hamcrest.Matchers.hasSize;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import static java.lang.Integer.toHexString;
Expand Down Expand Up @@ -98,4 +103,34 @@ private void assertCanonical(String message, Supplier<List<RelCollation>> collat
RelTraitSet traits3 = traits2.replace(RelCollations.of(1));
assertFalse(traits3.equalsSansConvention(traits2));
}

/** Test for
* <a href="https://issues.apache.org/jira/browse/CALCITE-7431">[CALCITE-7431]
* RelTraitSet#getTrait seems to mishandle RelCompositeTrait</a>. */
@Test void testRelCompositeTrait() {
// Build: EMP -> Sort(MGR asc) -> Project(MGR, MGR as MGR2)
// The project maps both output columns 0 and 1 back to input column 3
// (MGR), so the planner derives two collations: [0 ASC] and [1 ASC], which
// are stored as a RelCompositeTrait in the output trait set.
final FrameworkConfig config = RelBuilderTest.config().build();
final RelBuilder b = RelBuilder.create(config);
final RelNode in = b
.scan("EMP")
.sort(3) // MGR asc
.project(b.field(3), b.alias(b.field(3), "MGR2")) // MGR, MGR as MGR2
.build();

final RelTraitSet traitSet = in.getTraitSet();

final List<RelCollation> collations = traitSet.getCollations();
assertTrue(collations.size() >= 2,
"getCollations() should expose all composite collations");

assertThrows(IllegalStateException.class, traitSet::getCollation,
"getCollation() should throw when a RelCompositeTrait is present");

assertThrows(IllegalStateException.class,
() -> traitSet.getTrait(RelCollationTraitDef.INSTANCE),
"getTrait() should throw when a RelCompositeTrait is present");
}
}
Loading