diff --git a/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java b/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java index 0404fa7bc51..d6e46a07e41 100644 --- a/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java +++ b/core/src/main/java/org/apache/calcite/rel/rel2sql/RelToSqlConverter.java @@ -772,6 +772,7 @@ private Builder visitAggregate(Aggregate e, List groupKeyList, // "select a, b, sum(x) from ( ... ) group by a, b" final boolean ignoreClauses = e.getInput() instanceof Project; final Result x = visitInput(e, 0, isAnon(), ignoreClauses, clauseSet); + parseCorrelTable(e, x); final Builder builder = x.builder(e); final List selectList = new ArrayList<>(); final List groupByList = diff --git a/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java b/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java index dec24ca1f24..2dcc6c5ad45 100644 --- a/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java +++ b/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java @@ -1587,9 +1587,17 @@ public static SqlNode toSql(RexLiteral literal) { } protected Context getAliasContext(RexCorrelVariable variable) { - return requireNonNull( - correlTableMap.get(variable.id), - () -> "variable " + variable.id + " is not found"); + Context context = correlTableMap.get(variable.id); + if (context == null) { + if (correlTableMap.isEmpty()) { + context = + aliasContext(ImmutableMap.of(variable.id.getName(), variable.getType()), true); + } + if (context != null) { + correlTableMap.put(variable.id, context); + } + } + return requireNonNull(context, () -> "variable " + variable.id + " is not found"); } /** Simple implementation of {@link Context} that cannot handle sub-queries diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java index 41054312202..417a73776fe 100644 --- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java +++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java @@ -11887,6 +11887,61 @@ public Sql schema(CalciteAssert.SchemaSpec schemaSpec) { sql(sql).schema(CalciteAssert.SchemaSpec.JDBC_SCOTT).ok(expected); } + /** Test case for + * [CALCITE-7440] + * RelToSqlConverter throws NPE when correlation scope is missing after + * semi-join rewrites.. */ + @Test void testPostgresqlRoundTripCorrelatedProject() { + final String query = correlatedProjectQuery7440(); + final RuleSet rules = RuleSets.ofList(); + final String generated = sql(query).withPostgresql().optimize(rules, null).exec(); + try { + sql(generated).withPostgresql().exec(); + } catch (Exception e) { + throw new AssertionError( + "Generated SQL failed PostgreSQL round-trip validation:\n" + + generated, + e); + } + } + + @Test void testPostgresqlRoundTripCorrelatedProjectWithSemiJoinRules() { + final String query = correlatedProjectQuery7440(); + + final RuleSet rules = + RuleSets.ofList(CoreRules.FILTER_SUB_QUERY_TO_MARK_CORRELATE, + CoreRules.PROJECT_SUB_QUERY_TO_MARK_CORRELATE, + CoreRules.MARK_TO_SEMI_OR_ANTI_JOIN_RULE, + CoreRules.SEMI_JOIN_JOIN_TRANSPOSE); + + final String generated = sql(query).withPostgresql().optimize(rules, null).exec(); + try { + sql(generated).withPostgresql().exec(); + } catch (Exception e) { + throw new AssertionError( + "Generated SQL failed PostgreSQL round-trip validation:\n" + + generated, + e); + } + } + + private static String correlatedProjectQuery7440() { + return "WITH product_keys AS (\n" + + " SELECT p.\"product_id\",\n" + + " (SELECT MAX(p3.\"product_id\")\n" + + " FROM \"foodmart\".\"product\" p3\n" + + " WHERE p3.\"product_id\" = p.\"product_id\") AS \"mx\"\n" + + " FROM \"foodmart\".\"product\" p\n" + + ")\n" + + "SELECT DISTINCT pk.\"product_id\"\n" + + "FROM product_keys pk\n" + + "LEFT JOIN \"foodmart\".\"product\" p2 USING (\"product_id\")\n" + + "WHERE pk.\"product_id\" IN (\n" + + " SELECT p4.\"product_id\"\n" + + " FROM \"foodmart\".\"product\" p4\n" + + ")"; + } + @Test void testNotBetween() { Sql f = fixture().withConvertletTable(new SqlRexConvertletTable() { @Override public @Nullable SqlRexConvertlet get(SqlCall call) {