Skip to content

Commit

Permalink
[FLINK-34063][runtime] Fix OperatorState repartitioning when compression is enabled.
Browse files Browse the repository at this point in the history
We should only write compression headers once, at the end of the "value" part of the serialized stream, to make sure we can always seek to a split point.
  • Loading branch information
dmvk committed Jan 19, 2024
1 parent d536b5a commit b7c314b
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -170,35 +170,26 @@ public SnapshotResultSupplier<OperatorStateHandle> asyncSnapshot(
final Map<String, OperatorStateHandle.StateMetaInfo> writtenStatesMetaData =
CollectionUtil.newHashMapWithExpectedSize(initialMapCapacity);

for (Map.Entry<String, PartitionableListState<?>> entry :
registeredOperatorStatesDeepCopies.entrySet()) {
try (final CompressibleFSDataOutputStream compressedLocalOut =
new CompressibleFSDataOutputStream(
localOut,
compressionDecorator)) { // closes only the outer compression stream
for (Map.Entry<String, PartitionableListState<?>> entry :
registeredOperatorStatesDeepCopies.entrySet()) {

PartitionableListState<?> value = entry.getValue();
// create the compressed stream for each state to have the compression header for
// each
try (final CompressibleFSDataOutputStream compressedLocalOut =
new CompressibleFSDataOutputStream(
localOut,
compressionDecorator)) { // closes only the outer compression stream
PartitionableListState<?> value = entry.getValue();
long[] partitionOffsets = value.write(compressedLocalOut);
OperatorStateHandle.Mode mode = value.getStateMetaInfo().getAssignmentMode();
writtenStatesMetaData.put(
entry.getKey(),
new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode));
}
}

// ... and the broadcast states themselves ...
for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry :
registeredBroadcastStatesDeepCopies.entrySet()) {
// ... and the broadcast states themselves ...
for (Map.Entry<String, BackendWritableBroadcastState<?, ?>> entry :
registeredBroadcastStatesDeepCopies.entrySet()) {

BackendWritableBroadcastState<?, ?> value = entry.getValue();
// create the compressed stream for each state to have the compression header for
// each
try (final CompressibleFSDataOutputStream compressedLocalOut =
new CompressibleFSDataOutputStream(
localOut,
compressionDecorator)) { // closes only the outer compression stream
BackendWritableBroadcastState<?, ?> value = entry.getValue();
long[] partitionOffsets = {value.write(compressedLocalOut)};
OperatorStateHandle.Mode mode = value.getStateMetaInfo().getAssignmentMode();
writtenStatesMetaData.put(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,24 +177,20 @@ public Void restore() throws Exception {
restoredBroadcastMetaInfoSnapshots.forEach(
stateName -> toRestore.add(stateName.getName()));

for (String stateName : toRestore) {

final OperatorStateHandle.StateMetaInfo offsets =
stateHandle.getStateNameToPartitionOffsets().get(stateName);

PartitionableListState<?> listStateForName =
registeredOperatorStates.get(stateName);
final StreamCompressionDecorator compressionDecorator =
backendSerializationProxy.isUsingStateCompression()
? SnappyStreamCompressionDecorator.INSTANCE
: UncompressedStreamCompressionDecorator.INSTANCE;
// create the compressed stream for each state to have the compression header
// for each
try (final CompressibleFSDataInputStream compressedIn =
new CompressibleFSDataInputStream(
in,
compressionDecorator)) { // closes only the outer compression
// stream
final StreamCompressionDecorator compressionDecorator =
backendSerializationProxy.isUsingStateCompression()
? SnappyStreamCompressionDecorator.INSTANCE
: UncompressedStreamCompressionDecorator.INSTANCE;

try (final CompressibleFSDataInputStream compressedIn =
new CompressibleFSDataInputStream(
in,
compressionDecorator)) { // closes only the outer compression stream
for (String stateName : toRestore) {
final OperatorStateHandle.StateMetaInfo offsets =
stateHandle.getStateNameToPartitionOffsets().get(stateName);
PartitionableListState<?> listStateForName =
registeredOperatorStates.get(stateName);
if (listStateForName == null) {
BackendWritableBroadcastState<?, ?> broadcastStateForName =
registeredBroadcastStates.get(stateName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;

import org.junit.jupiter.params.ParameterizedTest;
Expand All @@ -38,6 +39,8 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down Expand Up @@ -157,8 +160,7 @@ void testRestoringMixedOperatorState(boolean snapshotCompressionEnabled) throws

@ParameterizedTest
@ValueSource(booleans = {true, false})
void testRestoreAndRescalePartitionedOperatorState(boolean snapshotCompressionEnabled)
throws Exception {
void testMergeOperatorState(boolean snapshotCompressionEnabled) throws Exception {
final ExecutionConfig cfg = new ExecutionConfig();
cfg.setUseSnapshotCompression(snapshotCompressionEnabled);
final ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend>
Expand Down Expand Up @@ -213,14 +215,75 @@ void testEmptyPartitionedOperatorState(boolean snapshotCompressionEnabled) throw
listStates.put("bufferState", Collections.emptyList());
listStates.put("offsetState", Collections.singletonList("foo"));

final Map<String, Map<String, String>> broadcastStates = new HashMap<>();
broadcastStates.put("whateverState", Collections.emptyMap());

final OperatorStateHandle stateHandle =
createOperatorStateHandle(
operatorStateBackendFactory, listStates, Collections.emptyMap());
createOperatorStateHandle(operatorStateBackendFactory, listStates, broadcastStates);

verifyOperatorStateHandle(
operatorStateBackendFactory,
Collections.singletonList(stateHandle),
listStates,
Collections.emptyMap());
broadcastStates);
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
void testRepartitionOperatorState(boolean snapshotCompressionEnabled) throws Exception {
    // Build a backend factory whose snapshot compression follows the test parameter.
    final ExecutionConfig executionConfig = new ExecutionConfig();
    executionConfig.setUseSnapshotCompression(snapshotCompressionEnabled);
    final ThrowingFunction<Collection<OperatorStateHandle>, OperatorStateBackend> backendFactory =
            createOperatorStateBackendFactory(
                    executionConfig, new CloseableRegistry(), this.getClass().getClassLoader());

    // Two list states with ten elements each ("foo0".."foo9" and "bar0".."bar9").
    final List<String> bufferElements =
            IntStream.range(0, 10).mapToObj(i -> "foo" + i).collect(Collectors.toList());
    final List<String> offsetElements =
            IntStream.range(0, 10).mapToObj(i -> "bar" + i).collect(Collectors.toList());
    final Map<String, List<String>> listStates = new HashMap<>();
    listStates.put("bufferState", bufferElements);
    listStates.put("offsetState", offsetElements);

    // Snapshot the list states (no broadcast state) into a single handle.
    final OperatorStateHandle snapshotHandle =
            createOperatorStateHandle(backendFactory, listStates, Collections.emptyMap());

    // Every target parallelism divides the state size (10) evenly, which is what
    // getExpectedSplit relies on when computing the expected per-subtask slice.
    for (final int targetParallelism : new int[] {1, 2, 5, 10}) {
        final RoundRobinOperatorStateRepartitioner repartitioner =
                new RoundRobinOperatorStateRepartitioner();
        final List<List<OperatorStateHandle>> handlesPerSubtask =
                repartitioner.repartitionState(
                        Collections.singletonList(Collections.singletonList(snapshotHandle)),
                        1,
                        targetParallelism);

        // Each new subtask must restore exactly its slice of every list state.
        for (int subtaskIndex = 0; subtaskIndex < targetParallelism; subtaskIndex++) {
            final Map<String, List<String>> expectedListStates =
                    getExpectedSplit(listStates, targetParallelism, subtaskIndex);
            verifyOperatorStateHandle(
                    backendFactory,
                    handlesPerSubtask.get(subtaskIndex),
                    expectedListStates,
                    Collections.emptyMap());
        }
    }
}

/**
 * Computes the slice of every list state that the subtask with index {@code idx} is expected
 * to hold after the states are split across {@code newParallelism} subtasks.
 *
 * <p>This is a simplified version of what the round-robin partitioner does, so it only works
 * in case there is no remainder, i.e. each state's size is a multiple of {@code
 * newParallelism}.
 *
 * @param states the complete list states, keyed by state name
 * @param newParallelism the number of subtasks the states are split across
 * @param idx the zero-based index of the subtask whose slice is requested
 * @return a new map from state name to the contiguous sub-list belonging to subtask {@code idx}
 */
private static Map<String, List<String>> getExpectedSplit(
        Map<String, List<String>> states, int newParallelism, int idx) {
    final Map<String, List<String>> expectedSplit = new HashMap<>();
    for (Map.Entry<String, List<String>> state : states.entrySet()) {
        final List<String> elements = state.getValue();
        final int elementCount = elements.size();
        // Contiguous range [fromIndex, toIndex) owned by subtask idx.
        final int fromIndex = idx * elementCount / newParallelism;
        final int toIndex = (idx + 1) * elementCount / newParallelism;
        expectedSplit.put(state.getKey(), elements.subList(fromIndex, toIndex));
    }
    return expectedSplit;
}
}

0 comments on commit b7c314b

Please sign in to comment.