Skip to content

Commit f08a818

Browse files
committed
[FLINK-39018][network] Buffer migration from RecoveredInputChannel to physical channels
1 parent 8e8dc32 commit f08a818

14 files changed

Lines changed: 242 additions & 29 deletions

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ public LocalInputChannel(
9898
int maxBackoff,
9999
Counter numBytesIn,
100100
Counter numBuffersIn,
101-
ChannelStateWriter stateWriter) {
101+
ChannelStateWriter stateWriter,
102+
ArrayDeque<Buffer> initialRecoveredBuffers) {
102103

103104
super(
104105
inputGate,
@@ -113,6 +114,31 @@ public LocalInputChannel(
113114
this.partitionManager = checkNotNull(partitionManager);
114115
this.taskEventPublisher = checkNotNull(taskEventPublisher);
115116
this.channelStatePersister = new ChannelStatePersister(stateWriter, getChannelInfo());
117+
118+
// Migrate recovered buffers from RecoveredInputChannel if provided.
119+
// These buffers have been filtered but not yet consumed by the Task.
120+
if (!initialRecoveredBuffers.isEmpty()) {
121+
final int expectedCount = initialRecoveredBuffers.size();
122+
// Sequence number starts at Integer.MIN_VALUE, consistent with RecoveredInputChannel.
123+
int seqNum = Integer.MIN_VALUE;
124+
while (!initialRecoveredBuffers.isEmpty()) {
125+
Buffer buffer = initialRecoveredBuffers.poll();
126+
// Determine next data type based on the next buffer in the queue
127+
Buffer.DataType nextDataType =
128+
initialRecoveredBuffers.isEmpty()
129+
? Buffer.DataType.NONE
130+
: initialRecoveredBuffers.peek().getDataType();
131+
// buffersInBacklog is set to 0 as these are recovered buffers
132+
BufferAndBacklog bufferAndBacklog =
133+
new BufferAndBacklog(buffer, 0, nextDataType, seqNum++);
134+
toBeConsumedBuffers.add(bufferAndBacklog);
135+
}
136+
checkState(
137+
toBeConsumedBuffers.size() == expectedCount,
138+
"Buffer migration failed: expected %s buffers but got %s",
139+
expectedCount,
140+
toBeConsumedBuffers.size());
141+
}
116142
}
117143

118144
// ------------------------------------------------------------------------

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,14 @@
1919
package org.apache.flink.runtime.io.network.partition.consumer;
2020

2121
import org.apache.flink.runtime.io.network.TaskEventPublisher;
22+
import org.apache.flink.runtime.io.network.buffer.Buffer;
2223
import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics;
2324
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
2425
import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
2526
import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
2627

28+
import java.util.ArrayDeque;
29+
2730
import static org.apache.flink.util.Preconditions.checkNotNull;
2831

2932
/**
@@ -61,7 +64,7 @@ public class LocalRecoveredInputChannel extends RecoveredInputChannel {
6164
}
6265

6366
@Override
64-
protected InputChannel toInputChannelInternal() {
67+
protected InputChannel toInputChannelInternal(ArrayDeque<Buffer> remainingBuffers) {
6568
return new LocalInputChannel(
6669
inputGate,
6770
getChannelIndex(),
@@ -73,6 +76,7 @@ protected InputChannel toInputChannelInternal() {
7376
maxBackoff,
7477
numBytesIn,
7578
numBuffersIn,
76-
channelStateWriter);
79+
channelStateWriter,
80+
remainingBuffers);
7781
}
7882
}

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,16 @@ public void setChannelStateWriter(ChannelStateWriter channelStateWriter) {
111111
public final InputChannel toInputChannel() throws IOException {
112112
Preconditions.checkState(
113113
stateConsumedFuture.isDone(), "recovered state is not fully consumed");
114-
final InputChannel inputChannel = toInputChannelInternal();
114+
115+
// Extract remaining buffers before conversion.
116+
// These buffers have been filtered but not yet consumed by the Task.
117+
final ArrayDeque<Buffer> remainingBuffers;
118+
synchronized (receivedBuffers) {
119+
remainingBuffers = new ArrayDeque<>(receivedBuffers);
120+
receivedBuffers.clear();
121+
}
122+
123+
final InputChannel inputChannel = toInputChannelInternal(remainingBuffers);
115124
inputChannel.checkpointStopped(lastStoppedCheckpointId);
116125
return inputChannel;
117126
}
@@ -121,7 +130,15 @@ public void checkpointStopped(long checkpointId) {
121130
this.lastStoppedCheckpointId = checkpointId;
122131
}
123132

124-
protected abstract InputChannel toInputChannelInternal() throws IOException;
133+
/**
134+
* Creates the physical InputChannel from this recovered channel.
135+
*
136+
* @param remainingBuffers buffers that have been filtered but not yet consumed by the Task.
137+
* These buffers will be migrated to the new physical channel.
138+
* @return the physical InputChannel (LocalInputChannel or RemoteInputChannel)
139+
*/
140+
protected abstract InputChannel toInputChannelInternal(ArrayDeque<Buffer> remainingBuffers)
141+
throws IOException;
125142

126143
CompletableFuture<?> getStateConsumedFuture() {
127144
return stateConsumedFuture;

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ public RemoteInputChannel(
138138
int networkBuffersPerChannel,
139139
Counter numBytesIn,
140140
Counter numBuffersIn,
141-
ChannelStateWriter stateWriter) {
141+
ChannelStateWriter stateWriter,
142+
ArrayDeque<Buffer> initialRecoveredBuffers) {
142143

143144
super(
144145
inputGate,
@@ -157,6 +158,29 @@ public RemoteInputChannel(
157158
this.connectionManager = checkNotNull(connectionManager);
158159
this.bufferManager = new BufferManager(inputGate.getMemorySegmentProvider(), this, 0);
159160
this.channelStatePersister = new ChannelStatePersister(stateWriter, getChannelInfo());
161+
162+
// Migrate recovered buffers from RecoveredInputChannel if provided.
163+
// These buffers have been filtered but not yet consumed by the Task.
164+
if (!initialRecoveredBuffers.isEmpty()) {
165+
final int expectedCount = initialRecoveredBuffers.size();
166+
// Sequence number starts at Integer.MIN_VALUE, consistent with RecoveredInputChannel.
167+
int seqNum = Integer.MIN_VALUE;
168+
for (Buffer buffer : initialRecoveredBuffers) {
169+
// subpartitionId is set to 0 for recovered buffers. This is correct because:
170+
// 1) For single-subpartition channels, the only valid subpartition is 0.
171+
// 2) For multi-subpartition channels (consumedSubpartitionIndexSet.size() > 1),
172+
// RecoveryMetadata events embedded in the recovered buffer sequence track
173+
// the actual subpartition context for proper routing.
174+
SequenceBuffer sequenceBuffer = new SequenceBuffer(buffer, seqNum++, 0);
175+
receivedBuffers.add(sequenceBuffer);
176+
totalQueueSizeInBytes += buffer.getSize();
177+
}
178+
checkState(
179+
receivedBuffers.size() == expectedCount,
180+
"Buffer migration failed: expected %s buffers but got %s",
181+
expectedCount,
182+
receivedBuffers.size());
183+
}
160184
}
161185

162186
@VisibleForTesting
@@ -239,9 +263,9 @@ protected boolean increaseBackoff() {
239263

240264
@Override
241265
protected int peekNextBufferSubpartitionIdInternal() throws IOException {
242-
checkPartitionRequestQueueInitialized();
243-
244266
synchronized (receivedBuffers) {
267+
checkReadability();
268+
245269
final SequenceBuffer next = receivedBuffers.peek();
246270

247271
if (next != null) {
@@ -254,12 +278,12 @@ protected int peekNextBufferSubpartitionIdInternal() throws IOException {
254278

255279
@Override
256280
public Optional<BufferAndAvailability> getNextBuffer() throws IOException {
257-
checkPartitionRequestQueueInitialized();
258-
259281
final SequenceBuffer next;
260282
final DataType nextDataType;
261283

262284
synchronized (receivedBuffers) {
285+
checkReadability();
286+
263287
next = receivedBuffers.poll();
264288

265289
if (next != null) {
@@ -879,6 +903,20 @@ public void onError(Throwable cause) {
879903
setError(cause);
880904
}
881905

906+
/**
907+
* When receivedBuffers contains migrated buffers from RecoveredInputChannel, they can be read
908+
* before requestSubpartitions(). In that case only check for errors. Once migrated buffers are
909+
* drained, require full client initialization check.
910+
*/
911+
private void checkReadability() throws IOException {
912+
assert Thread.holdsLock(receivedBuffers);
913+
if (receivedBuffers.isEmpty()) {
914+
checkPartitionRequestQueueInitialized();
915+
} else {
916+
checkError();
917+
}
918+
}
919+
882920
private void checkPartitionRequestQueueInitialized() throws IOException {
883921
checkError();
884922
checkState(

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020

2121
import org.apache.flink.runtime.io.network.ConnectionID;
2222
import org.apache.flink.runtime.io.network.ConnectionManager;
23+
import org.apache.flink.runtime.io.network.buffer.Buffer;
2324
import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics;
2425
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
2526
import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
2627

2728
import java.io.IOException;
29+
import java.util.ArrayDeque;
2830

2931
import static org.apache.flink.util.Preconditions.checkNotNull;
3032

@@ -66,7 +68,8 @@ public class RemoteRecoveredInputChannel extends RecoveredInputChannel {
6668
}
6769

6870
@Override
69-
protected InputChannel toInputChannelInternal() throws IOException {
71+
protected InputChannel toInputChannelInternal(ArrayDeque<Buffer> remainingBuffers)
72+
throws IOException {
7073
RemoteInputChannel remoteInputChannel =
7174
new RemoteInputChannel(
7275
inputGate,
@@ -81,7 +84,8 @@ protected InputChannel toInputChannelInternal() throws IOException {
8184
networkBuffersPerChannel,
8285
numBytesIn,
8386
numBuffersIn,
84-
channelStateWriter);
87+
channelStateWriter,
88+
remainingBuffers);
8589
remoteInputChannel.setup();
8690
return remoteInputChannel;
8791
}

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,10 @@ public void requestPartitions() {
375375
}
376376
}
377377

378+
/**
379+
* Converts all {@link RecoveredInputChannel}s to their real channel types ({@link
380+
* LocalInputChannel} or {@link RemoteInputChannel}).
381+
*/
378382
@VisibleForTesting
379383
public void convertRecoveredInputChannels() {
380384
LOG.debug("Converting recovered input channels ({} channels)", getNumberOfInputChannels());
@@ -384,19 +388,40 @@ public void convertRecoveredInputChannels() {
384388
new HashSet<>(inputChannelsForCurrentPartition.keySet());
385389
for (InputChannelInfo inputChannelInfo : oldInputChannelInfos) {
386390
InputChannel inputChannel = inputChannelsForCurrentPartition.get(inputChannelInfo);
387-
if (inputChannel instanceof RecoveredInputChannel) {
388-
try {
389-
InputChannel realInputChannel =
390-
((RecoveredInputChannel) inputChannel).toInputChannel();
391-
inputChannel.releaseAllResources();
391+
if (!(inputChannel instanceof RecoveredInputChannel)) {
392+
continue;
393+
}
394+
try {
395+
// Phase 1: Convert channel and release resources outside the lock.
396+
// These calls may acquire the receivedBuffers lock internally, so they
397+
// run outside inputChannelsWithData lock to maintain a consistent lock
398+
// order with onRecoveredStateBuffer() which acquires receivedBuffers
399+
// first and then inputChannelsWithData.
400+
InputChannel realInputChannel =
401+
((RecoveredInputChannel) inputChannel).toInputChannel();
402+
inputChannel.releaseAllResources();
403+
int buffersInUseCount = realInputChannel.getBuffersInUseCount();
404+
405+
// Phase 2: Atomically update data structures under the lock.
406+
synchronized (inputChannelsWithData) {
407+
if (inputChannelsWithData.contains(inputChannel)) {
408+
inputChannelsWithData.getAndRemove(ch -> ch == inputChannel);
409+
}
410+
enqueuedInputChannelsWithData.clear(inputChannel.getChannelIndex());
411+
392412
inputChannelsForCurrentPartition.remove(inputChannelInfo);
393413
inputChannelsForCurrentPartition.put(
394414
realInputChannel.getChannelInfo(), realInputChannel);
395415
channels[inputChannel.getChannelIndex()] = realInputChannel;
396-
} catch (Throwable t) {
397-
inputChannel.setError(t);
398-
return;
416+
417+
if (buffersInUseCount > 0) {
418+
inputChannelsWithData.add(realInputChannel);
419+
enqueuedInputChannelsWithData.set(realInputChannel.getChannelIndex());
420+
}
399421
}
422+
} catch (Throwable t) {
423+
inputChannel.setError(t);
424+
return;
400425
}
401426
}
402427
}

flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import javax.annotation.Nullable;
3636

3737
import java.io.IOException;
38+
import java.util.ArrayDeque;
3839
import java.util.Optional;
3940

4041
import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY;
@@ -183,7 +184,8 @@ public RemoteInputChannel toRemoteInputChannel(
183184
networkBuffersPerChannel,
184185
metrics.getNumBytesInRemoteCounter(),
185186
metrics.getNumBuffersInRemoteCounter(),
186-
channelStateWriter == null ? ChannelStateWriter.NO_OP : channelStateWriter);
187+
channelStateWriter == null ? ChannelStateWriter.NO_OP : channelStateWriter,
188+
new ArrayDeque<>());
187189
}
188190

189191
public LocalInputChannel toLocalInputChannel(ResultPartitionID resultPartitionID) {
@@ -198,7 +200,8 @@ public LocalInputChannel toLocalInputChannel(ResultPartitionID resultPartitionID
198200
maxBackoff,
199201
metrics.getNumBytesInLocalCounter(),
200202
metrics.getNumBuffersInLocalCounter(),
201-
channelStateWriter == null ? ChannelStateWriter.NO_OP : channelStateWriter);
203+
channelStateWriter == null ? ChannelStateWriter.NO_OP : channelStateWriter,
204+
new ArrayDeque<>());
202205
}
203206

204207
@Override

flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
import org.junit.jupiter.params.provider.MethodSource;
6868

6969
import java.io.IOException;
70+
import java.util.ArrayDeque;
7071
import java.net.InetSocketAddress;
7172
import java.util.stream.Stream;
7273

@@ -951,7 +952,8 @@ private static class TestRemoteInputChannelForError extends RemoteInputChannel {
951952
2,
952953
new SimpleCounter(),
953954
new SimpleCounter(),
954-
ChannelStateWriter.NO_OP);
955+
ChannelStateWriter.NO_OP,
956+
new ArrayDeque<>());
955957
this.expectedMessage = expectedMessage;
956958
}
957959

flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
import org.junit.jupiter.api.Test;
4646

47+
import java.util.ArrayDeque;
4748
import java.util.Optional;
4849
import java.util.concurrent.CountDownLatch;
4950
import java.util.concurrent.TimeUnit;
@@ -248,7 +249,8 @@ private static class TestRemoteInputChannelForPartitionNotFound extends RemoteIn
248249
2,
249250
new SimpleCounter(),
250251
new SimpleCounter(),
251-
ChannelStateWriter.NO_OP);
252+
ChannelStateWriter.NO_OP,
253+
new ArrayDeque<>());
252254
this.latch = latch;
253255
}
254256

flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
3535

3636
import java.net.InetSocketAddress;
37+
import java.util.ArrayDeque;
3738

3839
import static org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateTest.TestingResultPartitionManager;
3940

@@ -164,7 +165,8 @@ public LocalInputChannel buildLocalChannel(SingleInputGate inputGate) {
164165
maxBackoff,
165166
metrics.getNumBytesInLocalCounter(),
166167
metrics.getNumBuffersInLocalCounter(),
167-
stateWriter);
168+
stateWriter,
169+
new ArrayDeque<>());
168170
}
169171

170172
public RemoteInputChannel buildRemoteChannel(SingleInputGate inputGate) {
@@ -181,7 +183,8 @@ public RemoteInputChannel buildRemoteChannel(SingleInputGate inputGate) {
181183
networkBuffersPerChannel,
182184
metrics.getNumBytesInRemoteCounter(),
183185
metrics.getNumBuffersInRemoteCounter(),
184-
stateWriter);
186+
stateWriter,
187+
new ArrayDeque<>());
185188
}
186189

187190
public LocalRecoveredInputChannel buildLocalRecoveredChannel(SingleInputGate inputGate) {

0 commit comments

Comments
 (0)