From 09685a216cc1e510d50c220572675fe9c91c6048 Mon Sep 17 00:00:00 2001 From: Jerome Haltom Date: Thu, 7 May 2026 13:34:23 -0500 Subject: [PATCH] [CALCITE-7510] EnumerableTableModify UPDATE support --- .../enumerable/EnumerableTableModify.java | 130 ++++++++++++++++-- .../apache/calcite/util/BuiltInMethod.java | 2 + .../calcite/linq4j/DefaultEnumerable.java | 9 ++ .../calcite/linq4j/EnumerableDefaults.java | 35 +++++ .../calcite/linq4j/ExtendedEnumerable.java | 26 ++++ .../org/apache/calcite/test/ServerTest.java | 37 +++++ 6 files changed, 230 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableTableModify.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableTableModify.java index 186ec6903f91..f6fb8ae791bb 100644 --- a/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableTableModify.java +++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/EnumerableTableModify.java @@ -17,6 +17,7 @@ package org.apache.calcite.adapter.enumerable; import org.apache.calcite.adapter.java.JavaTypeFactory; +import org.apache.calcite.linq4j.function.Function1; import org.apache.calcite.linq4j.tree.BlockBuilder; import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; @@ -28,6 +29,7 @@ import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.TableModify; +import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.ModifiableTable; import org.apache.calcite.util.BuiltInMethod; @@ -37,6 +39,7 @@ import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.List; @@ -100,6 +103,92 @@ public EnumerableTableModify(RelOptCluster cluster, RelTraitSet traits, expression, 
BuiltInMethod.MODIFIABLE_TABLE_GET_MODIFIABLE_COLLECTION .method))); + + final PhysType physType = + PhysTypeImpl.of( + implementor.getTypeFactory(), + getRowType(), + pref == Prefer.ARRAY + ? JavaRowFormat.ARRAY : JavaRowFormat.SCALAR); + + if (getOperation() == Operation.UPDATE) { + // For UPDATE, the child produces, for each row matched by the WHERE + // clause, a row of tableFieldCount + M fields: + // [originalField_0, ..., originalField_N-1, newValue_0, ..., newValue_M-1] + // The first N fields are the *entire* original table row (all columns, + // not just those being updated); the trailing M = updateColumnList.size() + // fields are the new values, one per column named in the SET clause. + // Filtering by WHERE has already been applied upstream, so every source + // row corresponds to an existing row in the modifiable collection and + // can be located by full-row content equality. + final List updateCols = requireNonNull(getUpdateColumnList()); + final List tableFields = table.getRowType().getFieldList(); + final int tableFieldCount = tableFields.size(); + final int[] updateColumnIndices = new int[updateCols.size()]; + for (int i = 0; i < updateCols.size(); i++) { + final String colName = updateCols.get(i); + int found = -1; + for (int j = 0; j < tableFields.size(); j++) { + if (tableFields.get(j).getName().equals(colName)) { + found = j; + break; + } + } + if (found < 0) { + throw new AssertionError("column '" + colName + "' not found in table"); + } + updateColumnIndices[i] = found; + } + + // Build the three lambdas required by ExtendedEnumerable.update: + // sinkKeySelector: row -> Arrays.asList(row) + // sourceKeySelector: row -> Arrays.asList(Arrays.copyOf(row, N)) + // sourceTransform: row -> applyUpdate(row, N, updateColumnIndices) + final ParameterExpression sinkRow = + Expressions.parameter(Object[].class, "sinkRow"); + final Expression sinkKeySelector = + Expressions.lambda(Function1.class, + Expressions.call(Arrays.class, "asList", sinkRow), + 
sinkRow); + + final ParameterExpression srcKeyRow = + Expressions.parameter(Object[].class, "row"); + final Expression sourceKeySelector = + Expressions.lambda(Function1.class, + Expressions.call(Arrays.class, "asList", + Expressions.call(Arrays.class, "copyOf", + srcKeyRow, Expressions.constant(tableFieldCount))), + srcKeyRow); + + final ParameterExpression srcXformRow = + Expressions.parameter(Object[].class, "row"); + final Expression sourceTransform = + Expressions.lambda(Function1.class, + Expressions.call(EnumerableTableModify.class, "applyUpdate", + srcXformRow, + Expressions.constant(tableFieldCount), + Expressions.constant(updateColumnIndices)), + srcXformRow); + + final Expression updateCountExp = + builder.append( + "updateCount", + Expressions.call( + childExp, + BuiltInMethod.UPDATE.method, + Expressions.convert_(collectionParameter, List.class), + sinkKeySelector, + sourceKeySelector, + sourceTransform)); + builder.add( + Expressions.return_( + null, + Expressions.call( + BuiltInMethod.SINGLETON_ENUMERABLE.method, + Expressions.convert_(updateCountExp, long.class)))); + return implementor.result(physType, builder.toBlock()); + } + final Expression countParameter = builder.append( "count", @@ -110,7 +199,7 @@ public EnumerableTableModify(RelOptCluster cluster, RelTraitSet traits, final JavaTypeFactory typeFactory = (JavaTypeFactory) getCluster().getTypeFactory(); final JavaRowFormat format = EnumerableTableScan.deduceFormat(table); - PhysType physType = + PhysType tablePhysType = PhysTypeImpl.of(typeFactory, table.getRowType(), format); List expressionList = new ArrayList<>(); final PhysType childPhysType = result.physType; @@ -120,7 +209,7 @@ public EnumerableTableModify(RelOptCluster cluster, RelTraitSet traits, childPhysType.getRowType().getFieldCount(); for (int i = 0; i < fieldCount; i++) { expressionList.add( - childPhysType.fieldReference(o_, i, physType.getJavaFieldType(i))); + childPhysType.fieldReference(o_, i, 
tablePhysType.getJavaFieldType(i))); } convertedChildExp = builder.append( @@ -129,7 +218,7 @@ public EnumerableTableModify(RelOptCluster cluster, RelTraitSet traits, childExp, BuiltInMethod.SELECT.method, Expressions.lambda( - physType.record(expressionList), o_))); + tablePhysType.record(expressionList), o_))); } else { convertedChildExp = childExp; } @@ -167,13 +256,36 @@ public EnumerableTableModify(RelOptCluster cluster, RelTraitSet traits, Expressions.subtract( countParameter, updatedCountParameter)), long.class)))); - final PhysType physType = - PhysTypeImpl.of( - implementor.getTypeFactory(), - getRowType(), - pref == Prefer.ARRAY - ? JavaRowFormat.ARRAY : JavaRowFormat.SCALAR); return implementor.result(physType, builder.toBlock()); } + /** + * Builds the replacement row for an UPDATE source row. + * + *

The source row layout is: + * {@code [originalField_0, ..., originalField_N-1, newValue_0, ..., newValue_M-1]} + * where {@code N = tableFieldCount} is the full width of the table row + * (i.e. all original columns, not just those being updated) and + * {@code M = updateColumnIndices.length}. The result is a copy of the first + * {@code N} fields with the trailing new values substituted at the indicated + * column positions; columns not named in the SET clause therefore retain + * their original values. + * + * @param row Source row (full original row followed by new + * values for the SET columns) + * @param tableFieldCount Number of fields in the original table row + * @param updateColumnIndices 0-based indices of the columns being updated + * @return The replacement row + */ + public static Object[] applyUpdate( + Object[] row, + int tableFieldCount, + int[] updateColumnIndices) { + final Object[] newRow = Arrays.copyOf(row, tableFieldCount); + for (int i = 0; i < updateColumnIndices.length; i++) { + newRow[updateColumnIndices[i]] = row[tableFieldCount + i]; + } + return newRow; + } + } diff --git a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java index 98b67a02bbf3..98c334799506 100644 --- a/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java +++ b/core/src/main/java/org/apache/calcite/util/BuiltInMethod.java @@ -181,6 +181,8 @@ public enum BuiltInMethod { Integer.class, int.class, int.class, BigDecimal.class, RoundingMode.class), INTO(ExtendedEnumerable.class, "into", Collection.class), REMOVE_ALL(ExtendedEnumerable.class, "removeAll", Collection.class), + UPDATE(ExtendedEnumerable.class, "update", List.class, Function1.class, + Function1.class, Function1.class), SCHEMA_GET_SUB_SCHEMA(Schema.class, "getSubSchema", String.class), SCHEMA_GET_TABLE(Schema.class, "getTable", String.class), SCHEMA_PLUS_ADD_TABLE(SchemaPlus.class, "add", String.class, Table.class), diff --git 
a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java index d859519b4178..54dcc20f41e7 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java @@ -377,6 +377,15 @@ protected OrderedQueryable asOrderedQueryable() { return EnumerableDefaults.remove(getThis(), sink); } + @Override public long update( + List sink, + Function1 sinkKeySelector, + Function1 sourceKeySelector, + Function1 transform) { + return EnumerableDefaults.update(getThis(), sink, sinkKeySelector, + sourceKeySelector, transform); + } + @Override public Enumerable hashJoin( Enumerable inner, Function1 outerKeySelector, Function1 innerKeySelector, diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java index 988450df3643..ae26a602132d 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java @@ -61,6 +61,7 @@ import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; +import java.util.ListIterator; import java.util.Map; import java.util.NoSuchElementException; import java.util.Objects; @@ -4399,6 +4400,40 @@ public static > C remove( return sink; } + /** + * Default implementation of + * {@link ExtendedEnumerable#update(List, Function1, Function1, Function1)}. + * + *

Builds a map from source-row keys to replacement rows in a single pass + * over the source, then performs a single pass over the sink, replacing + * matched rows in place. + */ + public static long update( + Enumerable source, + List sink, + Function1 sinkKeySelector, + Function1 sourceKeySelector, + Function1 sourceTransform) { + final Map updateMap = new HashMap<>(); + try (Enumerator e = source.enumerator()) { + while (e.moveNext()) { + final T row = e.current(); + updateMap.put(sourceKeySelector.apply(row), sourceTransform.apply(row)); + } + } + long updateCount = 0; + final ListIterator it = sink.listIterator(); + while (it.hasNext()) { + final T current = it.next(); + final T newRow = updateMap.get(sinkKeySelector.apply(current)); + if (newRow != null) { + it.set(newRow); + updateCount++; + } + } + return updateCount; + } + /** * Hash table with null-safe key set. * diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java index 160f2afa0b1e..ff28c3fe822c 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java @@ -547,6 +547,32 @@ Enumerable intersect(Enumerable enumerable1, */ > C removeAll(C sink); + /** + * Updates rows of {@code sink} based on the contents of this sequence. + * + *

For each element {@code x} of this sequence, {@code sourceKeySelector} + * computes a key, and {@code transform} computes a replacement row. + * Then for each element {@code y} of {@code sink}, {@code sinkKeySelector} + * computes a key; if it matches a key produced from this sequence, {@code y} + * is replaced (in place) with the corresponding replacement row. + * + *

The sink is a {@link List} so that elements can be replaced + * in place while preserving order. + * + * @param sink List to be updated in place + * @param sinkKeySelector Function that extracts a key from a sink row + * @param sourceKeySelector Function that extracts a key from a source row + * @param transform Function that produces the replacement row from a + * source row + * @param Key type + * @return Number of rows replaced + */ + long update( + List sink, + Function1 sinkKeySelector, + Function1 sourceKeySelector, + Function1 transform); + /** * Correlates the elements of two sequences based on * matching keys. The default equality comparer is used to compare diff --git a/server/src/test/java/org/apache/calcite/test/ServerTest.java b/server/src/test/java/org/apache/calcite/test/ServerTest.java index 40e1430c8799..39a9be18c67d 100644 --- a/server/src/test/java/org/apache/calcite/test/ServerTest.java +++ b/server/src/test/java/org/apache/calcite/test/ServerTest.java @@ -108,6 +108,43 @@ static Connection connect() throws SQLException { executor.execute((SqlTruncateTable) o, context); } + @Test void testUpdate() throws Exception { + try (Connection c = connect(); + Statement s = c.createStatement()) { + s.execute("create table t (i int not null, j int not null)"); + s.executeUpdate("insert into t values (1, 10)"); + s.executeUpdate("insert into t values (2, 20)"); + s.executeUpdate("insert into t values (3, 30)"); + + // Update one row + int count = s.executeUpdate("update t set j = 99 where i = 2"); + assertThat(count, is(1)); + + try (ResultSet r = s.executeQuery("select i, j from t order by i")) { + assertThat(r.next(), is(true)); + assertThat(r.getInt(1), is(1)); + assertThat(r.getInt(2), is(10)); + assertThat(r.next(), is(true)); + assertThat(r.getInt(1), is(2)); + assertThat(r.getInt(2), is(99)); + assertThat(r.next(), is(true)); + assertThat(r.getInt(1), is(3)); + assertThat(r.getInt(2), is(30)); + assertThat(r.next(), is(false)); + } + + // Update 
multiple rows + count = s.executeUpdate("update t set j = 0 where i > 1"); + assertThat(count, is(2)); + + try (ResultSet r = s.executeQuery("select sum(j) from t")) { + assertThat(r.next(), is(true)); + assertThat(r.getInt(1), is(10)); + assertThat(r.next(), is(false)); + } + } + } + @Test void testStatement() throws Exception { try (Connection c = connect(); Statement s = c.createStatement();