Skip to content

Commit

Permalink
[#1608] feat: Introduce ExpiringCloseableSupplier and refactor Shuffle…
Browse files Browse the repository at this point in the history
…ManagerClient creation (#1838)

### What changes were proposed in this pull request?
1. Introduce StatefulCloseable and ExpiringCloseableSupplier.
2. Refactor ShuffleManagerClient creation to leverage ExpiringCloseableSupplier.

### Why are the changes needed?
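For better code quality: call sites obtain the ShuffleManagerClient from a shared supplier instead of creating it themselves.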

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing UTs and new UTs.
  • Loading branch information
xumanbu authored Jul 26, 2024
1 parent fa87381 commit 457c865
Show file tree
Hide file tree
Showing 23 changed files with 545 additions and 220 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

package org.apache.spark.shuffle;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;

import scala.Option;
import scala.reflect.ClassTag;
Expand All @@ -43,21 +43,18 @@
import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
import org.apache.uniffle.common.util.Constants;

import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
import static org.apache.uniffle.common.util.Constants.DRIVER_HOST;

public class RssSparkShuffleUtils {

Expand Down Expand Up @@ -346,6 +343,7 @@ public static boolean isStageResubmitSupported() {
}

public static RssException reportRssFetchFailedException(
Supplier<ShuffleManagerClient> managerClientSupplier,
RssFetchFailedException rssFetchFailedException,
SparkConf sparkConf,
String appId,
Expand All @@ -355,32 +353,24 @@ public static RssException reportRssFetchFailedException(
RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED)
&& RssSparkShuffleUtils.isStageResubmitSupported()) {
String driver = rssConf.getString(DRIVER_HOST, "");
int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
try (ShuffleManagerClient client =
ShuffleManagerClientFactory.getInstance()
.createShuffleManagerClient(ClientType.GRPC, driver, port)) {
// todo: Create a new rpc interface to report failures in batch.
for (int partitionId : failedPartitions) {
RssReportShuffleFetchFailureRequest req =
new RssReportShuffleFetchFailureRequest(
appId,
shuffleId,
stageAttemptId,
partitionId,
rssFetchFailedException.getMessage());
RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req);
if (response.getReSubmitWholeStage()) {
// since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1
// is provided.
FetchFailedException ffe =
RssSparkShuffleUtils.createFetchFailedException(
shuffleId, -1, partitionId, rssFetchFailedException);
return new RssException(ffe);
}
for (int partitionId : failedPartitions) {
RssReportShuffleFetchFailureRequest req =
new RssReportShuffleFetchFailureRequest(
appId,
shuffleId,
stageAttemptId,
partitionId,
rssFetchFailedException.getMessage());
RssReportShuffleFetchFailureResponse response =
managerClientSupplier.get().reportShuffleFetchFailure(req);
if (response.getReSubmitWholeStage()) {
// since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1
// is provided.
FetchFailedException ffe =
RssSparkShuffleUtils.createFetchFailedException(
shuffleId, -1, partitionId, rssFetchFailedException);
return new RssException(ffe);
}
} catch (IOException ioe) {
LOG.info("Error closing shuffle manager client with error:", ioe);
}
}
return rssFetchFailedException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

package org.apache.spark.shuffle.reader;

import java.io.IOException;
import java.util.Objects;
import java.util.function.Supplier;

import scala.Product2;
import scala.collection.AbstractIterator;
Expand All @@ -30,10 +30,8 @@
import org.slf4j.LoggerFactory;

import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;

Expand All @@ -52,8 +50,7 @@ public static class Builder {
private int shuffleId;
private int partitionId;
private int stageAttemptId;
private String reportServerHost;
private int reportServerPort;
private Supplier<ShuffleManagerClient> managerClientSupplier;

private Builder() {}

Expand All @@ -77,19 +74,13 @@ Builder stageAttemptId(int stageAttemptId) {
return this;
}

Builder reportServerHost(String host) {
this.reportServerHost = host;
return this;
}

Builder port(int port) {
this.reportServerPort = port;
Builder managerClientSupplier(Supplier<ShuffleManagerClient> managerClientSupplier) {
this.managerClientSupplier = managerClientSupplier;
return this;
}

<K, C> RssFetchFailedIterator<K, C> build(Iterator<Product2<K, C>> iter) {
Objects.requireNonNull(this.appId);
Objects.requireNonNull(this.reportServerHost);
return new RssFetchFailedIterator<>(this, iter);
}
}
Expand All @@ -98,37 +89,23 @@ static Builder newBuilder() {
return new Builder();
}

private static ShuffleManagerClient createShuffleManagerClient(String host, int port)
throws IOException {
ClientType grpc = ClientType.GRPC;
// host is passed from spark.driver.bindAddress, which would be set when SparkContext is
// constructed.
return ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(grpc, host, port);
}

private RssException generateFetchFailedIfNecessary(RssFetchFailedException e) {
String driver = builder.reportServerHost;
int port = builder.reportServerPort;
// todo: reuse this manager client if this is a bottleneck.
try (ShuffleManagerClient client = createShuffleManagerClient(driver, port)) {
RssReportShuffleFetchFailureRequest req =
new RssReportShuffleFetchFailureRequest(
builder.appId,
builder.shuffleId,
builder.stageAttemptId,
builder.partitionId,
e.getMessage());
RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req);
if (response.getReSubmitWholeStage()) {
// since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is
// provided.
FetchFailedException ffe =
RssSparkShuffleUtils.createFetchFailedException(
builder.shuffleId, -1, builder.partitionId, e);
return new RssException(ffe);
}
} catch (IOException ioe) {
LOG.info("Error closing shuffle manager client with error:", ioe);
ShuffleManagerClient client = builder.managerClientSupplier.get();
RssReportShuffleFetchFailureRequest req =
new RssReportShuffleFetchFailureRequest(
builder.appId,
builder.shuffleId,
builder.stageAttemptId,
builder.partitionId,
e.getMessage());
RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req);
if (response.getReSubmitWholeStage()) {
// since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is
// provided.
FetchFailedException ffe =
RssSparkShuffleUtils.createFetchFailedException(
builder.shuffleId, -1, builder.partitionId, e);
return new RssException(ffe);
}
return e;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import org.roaringbitmap.longlong.Roaring64NavigableMap;
Expand All @@ -41,16 +42,16 @@
* driver side.
*/
public class BlockIdSelfManagedShuffleWriteClient extends ShuffleWriteClientImpl {
private ShuffleManagerClient shuffleManagerClient;
private Supplier<ShuffleManagerClient> managerClientSupplier;

public BlockIdSelfManagedShuffleWriteClient(
RssShuffleClientFactory.ExtendWriteClientBuilder builder) {
super(builder);

if (builder.getShuffleManagerClient() == null) {
if (builder.getManagerClientSupplier() == null) {
throw new RssException("Illegal empty shuffleManagerClient. This should not happen");
}
this.shuffleManagerClient = builder.getShuffleManagerClient();
this.managerClientSupplier = builder.getManagerClientSupplier();
}

@Override
Expand All @@ -73,7 +74,7 @@ public void reportShuffleResult(
RssReportShuffleResultRequest request =
new RssReportShuffleResultRequest(
appId, shuffleId, taskAttemptId, partitionToBlockIds, bitmapNum);
shuffleManagerClient.reportShuffleResult(request);
managerClientSupplier.get().reportShuffleResult(request);
}

@Override
Expand All @@ -85,7 +86,7 @@ public Roaring64NavigableMap getShuffleResult(
int partitionId) {
RssGetShuffleResultRequest request =
new RssGetShuffleResultRequest(appId, shuffleId, partitionId, BlockIdLayout.DEFAULT);
return shuffleManagerClient.getShuffleResult(request).getBlockIdBitmap();
return managerClientSupplier.get().getShuffleResult(request).getBlockIdBitmap();
}

@Override
Expand All @@ -101,6 +102,6 @@ public Roaring64NavigableMap getShuffleResultForMultiPart(
RssGetShuffleResultForMultiPartRequest request =
new RssGetShuffleResultForMultiPartRequest(
appId, shuffleId, partitionIds, BlockIdLayout.DEFAULT);
return shuffleManagerClient.getShuffleResultForMultiPart(request).getBlockIdBitmap();
return managerClientSupplier.get().getShuffleResultForMultiPart(request).getBlockIdBitmap();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.uniffle.shuffle;

import java.util.function.Supplier;

import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
Expand All @@ -41,18 +43,18 @@ public static ExtendWriteClientBuilder<?> newWriteBuilder() {
public static class ExtendWriteClientBuilder<T extends ExtendWriteClientBuilder<T>>
extends WriteClientBuilder<T> {
private boolean blockIdSelfManagedEnabled;
private ShuffleManagerClient shuffleManagerClient;
private Supplier<ShuffleManagerClient> managerClientSupplier;

public boolean isBlockIdSelfManagedEnabled() {
return blockIdSelfManagedEnabled;
}

public ShuffleManagerClient getShuffleManagerClient() {
return shuffleManagerClient;
public Supplier<ShuffleManagerClient> getManagerClientSupplier() {
return managerClientSupplier;
}

public T shuffleManagerClient(ShuffleManagerClient client) {
this.shuffleManagerClient = client;
public T managerClientSupplier(Supplier<ShuffleManagerClient> managerClientSupplier) {
this.managerClientSupplier = managerClientSupplier;
return self();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import com.google.common.annotations.VisibleForTesting;
Expand Down Expand Up @@ -78,10 +79,12 @@
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.ConfigOption;
import org.apache.uniffle.common.config.RssBaseConf;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.shuffle.BlockIdManager;

Expand All @@ -104,7 +107,7 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac
protected String clientType;

protected SparkConf sparkConf;
protected ShuffleManagerClient shuffleManagerClient;
protected Supplier<ShuffleManagerClient> managerClientSupplier;
protected boolean rssStageRetryEnabled;
protected boolean rssStageRetryForWriteFailureEnabled;
protected boolean rssStageRetryForFetchFailureEnabled;
Expand Down Expand Up @@ -588,7 +591,8 @@ protected synchronized StageAttemptShuffleHandleInfo getRemoteShuffleHandleInfoW
RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
new RssPartitionToShuffleServerRequest(shuffleId);
RssReassignOnStageRetryResponse rpcPartitionToShufflerServer =
getOrCreateShuffleManagerClient()
getOrCreateShuffleManagerClientSupplier()
.get()
.getPartitionToShufflerServerWithStageRetry(rssPartitionToShuffleServerRequest);
StageAttemptShuffleHandleInfo shuffleHandleInfo =
StageAttemptShuffleHandleInfo.fromProto(
Expand All @@ -607,25 +611,27 @@ protected synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfoWithBl
RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
new RssPartitionToShuffleServerRequest(shuffleId);
RssReassignOnBlockSendFailureResponse rpcPartitionToShufflerServer =
getOrCreateShuffleManagerClient()
getOrCreateShuffleManagerClientSupplier()
.get()
.getPartitionToShufflerServerWithBlockRetry(rssPartitionToShuffleServerRequest);
MutableShuffleHandleInfo shuffleHandleInfo =
MutableShuffleHandleInfo.fromProto(rpcPartitionToShufflerServer.getHandle());
return shuffleHandleInfo;
}

// todo: automatically close the client when it is idle to avoid too many connections to the
// spark driver.
protected ShuffleManagerClient getOrCreateShuffleManagerClient() {
if (shuffleManagerClient == null) {
protected synchronized Supplier<ShuffleManagerClient> getOrCreateShuffleManagerClientSupplier() {
if (managerClientSupplier == null) {
RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
String driver = rssConf.getString("driver.host", "");
int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
this.shuffleManagerClient =
ShuffleManagerClientFactory.getInstance()
.createShuffleManagerClient(ClientType.GRPC, driver, port);
long rpcTimeout = rssConf.getLong(RssBaseConf.RSS_CLIENT_TYPE_GRPC_TIMEOUT_MS);
this.managerClientSupplier =
ExpiringCloseableSupplier.of(
() ->
ShuffleManagerClientFactory.getInstance()
.createShuffleManagerClient(ClientType.GRPC, driver, port, rpcTimeout));
}
return shuffleManagerClient;
return managerClientSupplier;
}

@Override
Expand Down Expand Up @@ -808,6 +814,14 @@ public MutableShuffleHandleInfo reassignOnBlockSendFailure(
}
}

/**
 * Closes the shuffle manager client supplier when it is an {@link ExpiringCloseableSupplier}.
 * Does nothing if no supplier has been created or if it is some other {@code Supplier}.
 */
@Override
public void stop() {
  // instanceof evaluates to false for null, so an unset supplier is covered by this guard.
  if (!(managerClientSupplier instanceof ExpiringCloseableSupplier)) {
    return;
  }
  ExpiringCloseableSupplier<ShuffleManagerClient> closeableSupplier =
      (ExpiringCloseableSupplier<ShuffleManagerClient>) managerClientSupplier;
  closeableSupplier.close();
}

/**
* Creating the shuffleAssignmentInfo from the servers and partitionIds
*
Expand Down
Loading

0 comments on commit 457c865

Please sign in to comment.