From f96b4145fc2c5b7c9a2677a1bfdd7e884a7967af Mon Sep 17 00:00:00 2001 From: bowenli86 Date: Sun, 29 Mar 2026 17:49:15 -0700 Subject: [PATCH] [FLINK-39226][python] Fix embedded PyIterator class cast after recovery --- ...neInputEmbeddedPythonFunctionOperator.java | 16 ++- ...woInputEmbeddedPythonFunctionOperator.java | 31 +++-- .../embedded/EmbeddedPythonIterator.java | 73 ++++++++++++ .../EmbeddedPythonKeyedCoProcessOperator.java | 17 ++- .../EmbeddedPythonKeyedProcessOperator.java | 17 ++- .../EmbeddedPythonWindowOperator.java | 17 ++- .../EmbeddedPythonTableFunctionOperator.java | 35 +++--- .../embedded/EmbeddedPythonIteratorTest.java | 109 ++++++++++++++++++ .../embedded/ForeignClassLoaderIterator.java | 47 ++++++++ 9 files changed, 288 insertions(+), 74 deletions(-) create mode 100644 flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonIterator.java create mode 100644 flink-python/src/test/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonIteratorTest.java create mode 100644 flink-python/src/test/java/org/apache/flink/streaming/api/operators/python/embedded/ForeignClassLoaderIterator.java diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java index 6b06d2e5ebf46..4db2b4bbaf1aa 100644 --- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java +++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractOneInputEmbeddedPythonFunctionOperator.java @@ -31,7 +31,6 @@ import org.apache.flink.util.Preconditions; import com.google.protobuf.AbstractMessageLite; -import pemja.core.object.PyIterator; import java.util.List; import java.util.stream.Collectors; @@ -151,18 +150,17 @@ public void processElement(StreamRecord element) throws Exception { timestamp = element.getTimestamp(); IN value = element.getValue(); - PyIterator results = - (PyIterator) + try (EmbeddedPythonIterator results = + EmbeddedPythonIterator.from( interpreter.invokeMethod( "operation", "process_element", - inputDataConverter.toExternal(value)); - - while (results.hasNext()) { - OUT result = outputDataConverter.toInternal(results.next()); - collector.collect(result); + inputDataConverter.toExternal(value)))) { + while (results.hasNext()) { + OUT result = outputDataConverter.toInternal(results.next()); + collector.collect(result); + } } - results.close(); } TypeInformation getInputTypeInfo() { diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractTwoInputEmbeddedPythonFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractTwoInputEmbeddedPythonFunctionOperator.java index 6de7e6b821361..dce1000709dc4 100644 --- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractTwoInputEmbeddedPythonFunctionOperator.java +++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/AbstractTwoInputEmbeddedPythonFunctionOperator.java @@ -31,7 +31,6 @@ import org.apache.flink.util.Preconditions; import com.google.protobuf.AbstractMessageLite; -import pemja.core.object.PyIterator; import java.util.List; import java.util.stream.Collectors; @@ -178,18 +177,17 @@ public void processElement1(StreamRecord element) throws Exception { timestamp = element.getTimestamp(); IN1 value = element.getValue(); - PyIterator results = - (PyIterator) + try (EmbeddedPythonIterator results = + EmbeddedPythonIterator.from( interpreter.invokeMethod( "operation", "process_element1", - inputDataConverter1.toExternal(value)); - - while (results.hasNext()) { - OUT result = outputDataConverter.toInternal(results.next()); - collector.collect(result); + inputDataConverter1.toExternal(value)))) { + while (results.hasNext()) { + OUT result = outputDataConverter.toInternal(results.next()); + collector.collect(result); + } } - results.close(); } @Override @@ -198,18 +196,17 @@ public void processElement2(StreamRecord element) throws Exception { timestamp = element.getTimestamp(); IN2 value = element.getValue(); - PyIterator results = - (PyIterator) + try (EmbeddedPythonIterator results = + EmbeddedPythonIterator.from( interpreter.invokeMethod( "operation", "process_element2", - inputDataConverter2.toExternal(value)); - - while (results.hasNext()) { - OUT result = outputDataConverter.toInternal(results.next()); - collector.collect(result); + inputDataConverter2.toExternal(value)))) { + while (results.hasNext()) { + OUT result = outputDataConverter.toInternal(results.next()); + collector.collect(result); + } } - results.close(); } TypeInformation getInputTypeInfo1() { diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonIterator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonIterator.java new file mode 100644 index 0000000000000..061f759dd181f --- /dev/null +++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonIterator.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators.python.embedded; + +import org.apache.flink.annotation.Internal; + +import java.lang.reflect.Method; +import java.util.Objects; + +/** + * Reflective adapter for embedded Python iterators. + * + *

PEMJA iterator objects may come back from a different user-code classloader after recovery, so + * callers should not hard-cast them to {@code pemja.core.object.PyIterator}. + */ +@Internal +public final class EmbeddedPythonIterator implements AutoCloseable { + + private final Object iterator; + private final Method hasNextMethod; + private final Method nextMethod; + private final Method closeMethod; + + private EmbeddedPythonIterator(Object iterator) { + this.iterator = Objects.requireNonNull(iterator, "iterator must not be null"); + + try { + Class iteratorClass = iterator.getClass(); + this.hasNextMethod = iteratorClass.getMethod("hasNext"); + this.nextMethod = iteratorClass.getMethod("next"); + this.closeMethod = iteratorClass.getMethod("close"); + } catch (ReflectiveOperationException e) { + throw new IllegalStateException( + String.format( + "Failed to adapt embedded Python iterator of type %s.", + iterator.getClass().getName()), + e); + } + } + + public static EmbeddedPythonIterator from(Object iterator) { + return new EmbeddedPythonIterator(iterator); + } + + public boolean hasNext() throws Exception { + return (boolean) hasNextMethod.invoke(iterator); + } + + public Object next() throws Exception { + return nextMethod.invoke(iterator); + } + + @Override + public void close() throws Exception { + closeMethod.invoke(iterator); + } +} diff --git a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonKeyedCoProcessOperator.java b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonKeyedCoProcessOperator.java index edeee270208db..5751716433652 100644 --- a/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonKeyedCoProcessOperator.java +++ b/flink-python/src/main/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonKeyedCoProcessOperator.java @@ -36,8 +36,6 @@ import org.apache.flink.streaming.api.utils.PythonTypeUtils; import org.apache.flink.types.Row; -import pemja.core.object.PyIterator; - import java.util.List; import static org.apache.flink.python.PythonOptions.MAP_STATE_READ_CACHE_SIZE; @@ -149,15 +147,14 @@ private void invokeUserFunction(TimeDomain timeDomain, InternalTimer DataStreamPythonFunctionOperator copy( private void invokeUserFunction(InternalTimer timer) throws Exception { windowTimerContext.timer = timer; - PyIterator results = - (PyIterator) - interpreter.invokeMethod("operation", "on_timer", timer.getTimestamp()); - - while (results.hasNext()) { - OUT result = outputDataConverter.toInternal(results.next()); - collector.collect(result); + try (EmbeddedPythonIterator results = + EmbeddedPythonIterator.from( + interpreter.invokeMethod("operation", "on_timer", timer.getTimestamp()))) { + while (results.hasNext()) { + OUT result = outputDataConverter.toInternal(results.next()); + collector.collect(result); + } } - results.close(); windowTimerContext.timer = null; } diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/table/EmbeddedPythonTableFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/table/EmbeddedPythonTableFunctionOperator.java index 77962daa67c64..d23b2e3580adf 100644 --- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/table/EmbeddedPythonTableFunctionOperator.java +++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/table/EmbeddedPythonTableFunctionOperator.java @@ -22,6 +22,7 @@ import org.apache.flink.configuration.Configuration; import org.apache.flink.fnexecution.v1.FlinkFnApi; import org.apache.flink.python.util.ProtoUtils; +import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonIterator; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.table.data.GenericRowData; import org.apache.flink.table.data.RowData; @@ -33,8 +34,6 @@ import org.apache.flink.table.types.logical.RowType; import org.apache.flink.util.Preconditions; -import pemja.core.object.PyIterator; - import static org.apache.flink.python.PythonOptions.PYTHON_METRIC_ENABLED; import static org.apache.flink.python.PythonOptions.PYTHON_PROFILE_ENABLED; import static org.apache.flink.python.util.ProtoUtils.createFlattenRowTypeCoderInfoDescriptorProto; @@ -147,25 +146,25 @@ public void processElement(StreamRecord element) throws Exception { userDefinedFunctionInputConverters[i].toExternal(value, udfInputOffsets[i]); } - PyIterator udtfResults = - (PyIterator) + try (EmbeddedPythonIterator udtfResults = + EmbeddedPythonIterator.from( interpreter.invokeMethod( "table_operation", "process_element", - (Object) (userDefinedFunctionInputArgs)); - - if (udtfResults.hasNext()) { - do { - Object[] udtfResult = (Object[]) udtfResults.next(); - for (int i = 0; i < udtfResult.length; i++) { - reuseResultRowData.setField( - i, userDefinedFunctionOutputConverters[i].toInternal(udtfResult[i])); - } - rowDataWrapper.collect(reuseJoinedRow.replace(value, reuseResultRowData)); - } while (udtfResults.hasNext()); - } else if (joinType == FlinkJoinType.LEFT) { - rowDataWrapper.collect(reuseJoinedRow.replace(value, reuseNullResultRowData)); + (Object) (userDefinedFunctionInputArgs)))) { + if (udtfResults.hasNext()) { + do { + Object[] udtfResult = (Object[]) udtfResults.next(); + for (int i = 0; i < udtfResult.length; i++) { + reuseResultRowData.setField( + i, + userDefinedFunctionOutputConverters[i].toInternal(udtfResult[i])); + } + rowDataWrapper.collect(reuseJoinedRow.replace(value, reuseResultRowData)); + } while (udtfResults.hasNext()); + } else if (joinType == FlinkJoinType.LEFT) { + rowDataWrapper.collect(reuseJoinedRow.replace(value, reuseNullResultRowData)); + } } - udtfResults.close(); } } diff --git a/flink-python/src/test/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonIteratorTest.java b/flink-python/src/test/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonIteratorTest.java new file mode 100644 index 0000000000000..2f3c0a2a381e2 --- /dev/null +++ b/flink-python/src/test/java/org/apache/flink/streaming/api/operators/python/embedded/EmbeddedPythonIteratorTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators.python.embedded; + +import org.junit.jupiter.api.Test; +import pemja.core.object.PyIterator; + +import java.lang.reflect.Constructor; +import java.net.URL; +import java.net.URLClassLoader; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link EmbeddedPythonIterator}. */ +class EmbeddedPythonIteratorTest { + + @Test + void testReadsIteratorLoadedByDifferentClassLoader() throws Exception { + URL classesUrl = + EmbeddedPythonIteratorTest.class + .getProtectionDomain() + .getCodeSource() + .getLocation(); + + try (URLClassLoader classLoader = new URLClassLoader(new URL[] {classesUrl}, null)) { + Class iteratorClass = + Class.forName( + "org.apache.flink.streaming.api.operators.python.embedded." + + "ForeignClassLoaderIterator", + true, + classLoader); + Object iterator = + iteratorClass + .getConstructor(Object[].class) + .newInstance((Object) new Object[] {"first", "second"}); + + assertThat(iterator.getClass().getClassLoader()) + .isNotSameAs(EmbeddedPythonIteratorTest.class.getClassLoader()); + + try (EmbeddedPythonIterator embeddedPythonIterator = + EmbeddedPythonIterator.from(iterator)) { + assertThat(embeddedPythonIterator.hasNext()).isTrue(); + assertThat(embeddedPythonIterator.next()).isEqualTo("first"); + assertThat(embeddedPythonIterator.hasNext()).isTrue(); + assertThat(embeddedPythonIterator.next()).isEqualTo("second"); + assertThat(embeddedPythonIterator.hasNext()).isFalse(); + } + + assertThat(iteratorClass.getMethod("isClosed").invoke(iterator)).isEqualTo(true); + } + } + + @Test + void testReproducesDirectPemjaCastFailureAcrossClassLoaders() throws Exception { + Object iterator = createForeignPemjaIterator(); + + assertThat(iterator.getClass().getName()).isEqualTo(PyIterator.class.getName()); + assertThat(iterator.getClass()).isNotEqualTo(PyIterator.class); + assertThat(iterator.getClass().getClassLoader()) + .isNotSameAs(PyIterator.class.getClassLoader()); + + assertThatThrownBy(() -> castToLocalPemjaIterator(iterator)) + .isInstanceOf(ClassCastException.class) + .hasMessageContaining("pemja.core.object.PyIterator") + .hasMessageContaining("cannot be cast"); + } + + @Test + void testWrapsPemjaIteratorLoadedByDifferentClassLoaderWithoutCastFailure() throws Exception { + Object iterator = createForeignPemjaIterator(); + + assertThatCode(() -> EmbeddedPythonIterator.from(iterator)).doesNotThrowAnyException(); + } + + private static Object createForeignPemjaIterator() throws Exception { + URL pemjaJarUrl = PyIterator.class.getProtectionDomain().getCodeSource().getLocation(); + + try (URLClassLoader classLoader = new URLClassLoader(new URL[] {pemjaJarUrl}, null)) { + Class iteratorClass = + Class.forName("pemja.core.object.PyIterator", true, classLoader); + Constructor constructor = + iteratorClass.getDeclaredConstructor(long.class, long.class); + constructor.setAccessible(true); + return constructor.newInstance(0L, 0L); + } + } + + private static PyIterator castToLocalPemjaIterator(Object iterator) { + return (PyIterator) iterator; + } +} diff --git a/flink-python/src/test/java/org/apache/flink/streaming/api/operators/python/embedded/ForeignClassLoaderIterator.java b/flink-python/src/test/java/org/apache/flink/streaming/api/operators/python/embedded/ForeignClassLoaderIterator.java new file mode 100644 index 0000000000000..5fadded27027c --- /dev/null +++ b/flink-python/src/test/java/org/apache/flink/streaming/api/operators/python/embedded/ForeignClassLoaderIterator.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators.python.embedded; + +/** Test-only iterator that can be loaded through an isolated classloader. */ +public class ForeignClassLoaderIterator { + + private final Object[] values; + private int index; + private boolean closed; + + public ForeignClassLoaderIterator(Object... values) { + this.values = values; + } + + public boolean hasNext() { + return index < values.length; + } + + public Object next() { + return values[index++]; + } + + public void close() { + closed = true; + } + + public boolean isClosed() { + return closed; + } +}