diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java index 22be1f71f7ade..aedeea683c8dd 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java @@ -51,6 +51,7 @@ import org.apache.flink.api.java.tuple.Tuple; import org.apache.flink.api.java.typeutils.InputTypeConfigurable; import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.configuration.MemorySize; import org.apache.flink.configuration.RpcOptions; import org.apache.flink.core.execution.JobClient; import org.apache.flink.core.fs.FileSystem.WriteMode; @@ -109,6 +110,7 @@ import org.apache.flink.util.OutputTag; import org.apache.flink.util.Preconditions; +import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.UUID; @@ -1439,8 +1441,13 @@ public void collectAsync(Collector collector) { String accumulatorName = "dataStreamCollect_" + UUID.randomUUID().toString(); StreamExecutionEnvironment env = getExecutionEnvironment(); + /* Look up the collect sink's max batch size and socket timeout in the environment configuration so both can be handed to the factory constructor below. */ MemorySize maxBatchSize = + env.getConfiguration().get(CollectSinkOperatorFactory.MAX_BATCH_SIZE); + Duration socketTimeout = + env.getConfiguration().get(CollectSinkOperatorFactory.SOCKET_TIMEOUT); CollectSinkOperatorFactory factory = - new CollectSinkOperatorFactory<>(serializer, accumulatorName); + new CollectSinkOperatorFactory<>( + serializer, accumulatorName, maxBatchSize, socketTimeout); CollectSinkOperator operator = (CollectSinkOperator) factory.getOperator(); long resultFetchTimeout = env.getConfiguration().get(RpcOptions.ASK_TIMEOUT_DURATION).toMillis(); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/collect/CollectSinkFunction.java 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/collect/CollectSinkFunction.java index a71430fd3b702..91fc44ca44dc8 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/collect/CollectSinkFunction.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/collect/CollectSinkFunction.java @@ -167,6 +167,10 @@ public CollectSinkFunction( this.accumulatorName = accumulatorName; } + /** Returns the {@code maxBytesPerBatch} value held by this sink function. */ public long getMaxBytesPerBatch() { + return maxBytesPerBatch; + } + private void initBuffer() { if (buffer != null) { return; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/collect/CollectSinkOperatorFactory.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/collect/CollectSinkOperatorFactory.java index aa0d4c2148530..a91a5591a3c49 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/collect/CollectSinkOperatorFactory.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/collect/CollectSinkOperatorFactory.java @@ -58,6 +58,10 @@ public CollectSinkOperatorFactory( this.socketTimeoutMillis = (int) socketTimeout.toMillis(); } + /** Returns the socket timeout in milliseconds, i.e. the constructor's {@code socketTimeout} duration converted via {@code toMillis()} and narrowed to {@code int}. */ public int getSocketTimeoutMillis() { + return socketTimeoutMillis; + } + @Override @SuppressWarnings("unchecked") public > T createStreamOperator( diff --git a/flink-tests/src/test/java/org/apache/flink/api/datastream/DataStreamCollectTestITCase.java b/flink-tests/src/test/java/org/apache/flink/api/datastream/DataStreamCollectTestITCase.java index a0c87eb29b86e..5d5929a1ec645 100644 --- a/flink-tests/src/test/java/org/apache/flink/api/datastream/DataStreamCollectTestITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/api/datastream/DataStreamCollectTestITCase.java @@ -18,10 +18,16 @@ package org.apache.flink.api.datastream; import org.apache.flink.api.common.RuntimeExecutionMode; +import 
org.apache.flink.api.dag.Transformation; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.ExecutionOptions; +import org.apache.flink.configuration.MemorySize; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.collect.CollectSinkFunction; +import org.apache.flink.streaming.api.operators.collect.CollectSinkOperator; +import org.apache.flink.streaming.api.operators.collect.CollectSinkOperatorFactory; +import org.apache.flink.streaming.api.transformations.LegacySinkTransformation; import org.apache.flink.util.CloseableIterator; import org.apache.flink.util.CollectionUtil; import org.apache.flink.util.TestLogger; @@ -30,6 +36,7 @@ import org.junit.Assert; import org.junit.Test; +import java.time.Duration; import java.util.List; import java.util.function.Consumer; @@ -111,6 +118,31 @@ public void testBoundedCollectAndLimit() throws Exception { results.size()); } + /** Checks that {@code collectAsync()} builds its collect sink from the environment configuration: SOCKET_TIMEOUT and MAX_BATCH_SIZE are set to non-default-looking values (2 ms and 3) and then read back from the created operator factory and sink function. The job itself is never executed. */ @Test + public void testAsyncCollectWithSinkConfigs() { + Configuration configuration = new Configuration(); + configuration.set(CollectSinkOperatorFactory.SOCKET_TIMEOUT, Duration.ofMillis(2)); + configuration.set(CollectSinkOperatorFactory.MAX_BATCH_SIZE, new MemorySize(3)); + final StreamExecutionEnvironment env = + StreamExecutionEnvironment.getExecutionEnvironment(configuration); + + final DataStream stream = env.fromData(1, 2, 3, 4, 5); + stream.collectAsync(); + + List> transformations = env.getTransformations(); + /* Exactly one transformation is expected to be registered; the casts below rely on it being the collect sink added by collectAsync(). */ Assert.assertEquals(1, transformations.size()); + LegacySinkTransformation transformation = + (LegacySinkTransformation) transformations.get(transformations.size() - 1); + CollectSinkOperatorFactory collectSinkOperatorFactory = + (CollectSinkOperatorFactory) transformation.getOperatorFactory(); + CollectSinkFunction collectSinkFunction = + ((CollectSinkFunction) + ((CollectSinkOperator) 
collectSinkOperatorFactory.getOperator()) + .getUserFunction()); + /* Both values must match what was put into the configuration above. */ Assert.assertEquals(2, collectSinkOperatorFactory.getSocketTimeoutMillis()); + Assert.assertEquals(3, collectSinkFunction.getMaxBytesPerBatch()); + } + @Test public void testAsyncCollect() throws Exception { final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();