diff --git a/flink-connectors/flink-connector-datagen-test/src/test/java/org/apache/flink/connector/datagen/source/DataGeneratorSourceITCase.java b/flink-connectors/flink-connector-datagen-test/src/test/java/org/apache/flink/connector/datagen/source/DataGeneratorSourceITCase.java
index 2aa1468caeaec..891b6b6c04d90 100644
--- a/flink-connectors/flink-connector-datagen-test/src/test/java/org/apache/flink/connector/datagen/source/DataGeneratorSourceITCase.java
+++ b/flink-connectors/flink-connector-datagen-test/src/test/java/org/apache/flink/connector/datagen/source/DataGeneratorSourceITCase.java
@@ -21,6 +21,7 @@
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.eventtime.WatermarkStrategy;
 import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.state.CheckpointListener;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeinfo.Types;
@@ -254,21 +255,42 @@ public TypeInformation getTypeInformation() {
         }
     }
 
+    /**
+     * A filter that only passes through elements received before the first checkpoint completes.
+     *
+     * <p>The filter stops collecting elements in {@link #notifyCheckpointComplete(long)} rather
+     * than in {@link #snapshotState(FunctionSnapshotContext)}, to avoid a race condition where the
+     * checkpoint barrier arrives at this operator before all upstream elements (emitted in the same
+     * checkpoint cycle) have been processed. Using {@code notifyCheckpointComplete} ensures that
+     * the checkpoint has fully propagated through the pipeline before we stop collecting.
+     */
     private static class FirstCheckpointFilter
-            implements FlatMapFunction, CheckpointedFunction {
+            implements FlatMapFunction, CheckpointedFunction, CheckpointListener {
 
-        private volatile boolean firstCheckpoint = true;
+        private volatile boolean firstCheckpointCompleted = false;
+        private long firstCheckpointId = Long.MIN_VALUE;
 
         @Override
         public void flatMap(Long value, Collector out) throws Exception {
-            if (firstCheckpoint) {
+            if (!firstCheckpointCompleted) {
                 out.collect(value);
             }
         }
 
         @Override
         public void snapshotState(FunctionSnapshotContext context) throws Exception {
-            firstCheckpoint = false;
+            // Record the ID of the first checkpoint so we can stop collecting when it completes.
+            if (firstCheckpointId == Long.MIN_VALUE) {
+                firstCheckpointId = context.getCheckpointId();
+            }
+        }
+
+        @Override
+        public void notifyCheckpointComplete(long checkpointId) throws Exception {
+            // Stop collecting once the recorded checkpoint, or any later one, has completed.
+            if (checkpointId >= firstCheckpointId && firstCheckpointId != Long.MIN_VALUE) {
+                firstCheckpointCompleted = true;
+            }
         }
 
         @Override