Merge pull request #21 from petro-rudenko/spark-3.0-support_2
[CORE] Spark-3.0 reducer protocol.
yosefe authored Mar 30, 2020
2 parents ab24e1d + 161c1ae commit 43f8a4d
Showing 12 changed files with 435 additions and 30 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/sparkucx-ci.yml
@@ -10,6 +10,9 @@ on:

jobs:
build-sparkucx:
strategy:
matrix:
spark_version: [2.4, 3.0]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
@@ -21,6 +24,6 @@ jobs:
run: mvn -B package -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn
--file pom.xml
- name: Run Sonar code analysis
run: mvn -B sonar:sonar -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -Dsonar.projectKey=openucx:spark-ucx -Dsonar.organization=openucx -Dsonar.host.url=https://sonarcloud.io -Dsonar.login=97f4df88ff4fa04e2d5b061acf07315717f1f08b
run: mvn -B sonar:sonar -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -Dsonar.projectKey=openucx:spark-ucx -Dsonar.organization=openucx -Dsonar.host.url=https://sonarcloud.io -Dsonar.login=97f4df88ff4fa04e2d5b061acf07315717f1f08b -Pspark-${{ matrix.spark_version }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
1 change: 1 addition & 0 deletions buildlib/test.sh
@@ -119,6 +119,7 @@ setup_configuration() {

cat <<-EOF > ${SPARK_CONF_DIR}/spark-defaults.conf
spark.shuffle.manager org.apache.spark.shuffle.UcxShuffleManager
spark.shuffle.sort.io.plugin.class org.apache.spark.shuffle.compat.spark_3_0.UcxLocalDiskShuffleDataIO
spark.shuffle.readHostLocalDisk.enabled false
spark.driver.extraClassPath ${SPARK_UCX_JAR}:${UCX_LIB}
spark.executor.extraClassPath ${SPARK_UCX_JAR}:${UCX_LIB}
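
As a side note, the same settings can also be supplied programmatically through SparkConf instead of spark-defaults.conf. A minimal sketch, assuming only the property names and class names shown above (everything else is illustrative):

    import org.apache.spark.SparkConf;

    public class UcxShuffleConfExample {
      public static void main(String[] args) {
        SparkConf conf = new SparkConf()
            .set("spark.shuffle.manager", "org.apache.spark.shuffle.UcxShuffleManager")
            // Spark 3.0 only: route shuffle IO through the UCX-aware local-disk plugin.
            .set("spark.shuffle.sort.io.plugin.class",
                 "org.apache.spark.shuffle.compat.spark_3_0.UcxLocalDiskShuffleDataIO")
            // Disabled in the patch's test configuration above.
            .set("spark.shuffle.readHostLocalDisk.enabled", "false");
        System.out.println(conf.toDebugString());
      }
    }
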
4 changes: 3 additions & 1 deletion pom.xml
@@ -42,6 +42,7 @@ See file LICENSE for terms.
<properties>
<spark.version>2.4.0</spark.version>
<project.excludes>**/spark_3_0/**</project.excludes>
<sonar.exclusions>**/spark_3_0/**</sonar.exclusions>
<scala.version>2.11.12</scala.version>
<scala.compat.version>2.11</scala.compat.version>
</properties>
@@ -53,6 +54,7 @@ See file LICENSE for terms.
<scala.version>2.12.10</scala.version>
<scala.compat.version>2.12</scala.compat.version>
<project.excludes>**/spark_2_4/**</project.excludes>
<sonar.exclusions>**/spark_2_4/**</sonar.exclusions>
</properties>
</profile>
</profiles>
@@ -68,7 +70,7 @@ See file LICENSE for terms.
<dependency>
<groupId>org.openucx</groupId>
<artifactId>jucx</artifactId>
<version>1.8.0-SNAPSHOT</version>
<version>1.9.0-SNAPSHOT</version>
</dependency>
</dependencies>

3 changes: 3 additions & 0 deletions src/main/java/org/apache/spark/shuffle/ucx/UnsafeUtils.java
@@ -26,6 +26,9 @@ public class UnsafeUtils {

private static final Constructor<?> directBufferConstructor;

public static final int LONG_SIZE = 8;
public static final int INT_SIZE = 4;

static {
try {
mmap = FileChannelImpl.class.getDeclaredMethod("map0", int.class, long.class, long.class);
@@ -5,6 +5,7 @@
package org.apache.spark.shuffle.ucx.reducer.compat.spark_2_4;

import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.shuffle.ucx.UnsafeUtils;
import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
import org.apache.spark.shuffle.ucx.reducer.OnBlocksFetchCallback;
import org.apache.spark.shuffle.ucx.reducer.ReducerCallback;
@@ -39,16 +40,18 @@ public void onSuccess(UcpRequest request) {
ByteBuffer resultOffset = offsetMemory.getBuffer();
long totalSize = 0;
int[] sizes = new int[blockIds.length];
int offsetSize = UnsafeUtils.LONG_SIZE;
for (int i = 0; i < blockIds.length; i++) {
long blockOffset = resultOffset.getLong(i * 16);
long blockLength = resultOffset.getLong(i * 16 + 8) - blockOffset;
assert (blockLength > 0) && (blockLength < Integer.MAX_VALUE);
// Blocks in the metadata buffer are laid out as | blockOffsetStart | blockOffsetEnd | pairs
long blockOffset = resultOffset.getLong(i * 2 * offsetSize);
long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset;
assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
sizes[i] = (int) blockLength;
totalSize += blockLength;
dataAddresses[i] += blockOffset;
}

assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
mempool.put(offsetMemory);
RegisteredMemory blocksMemory = mempool.get((int) totalSize);
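
To make the indexing above concrete: each block contributes two longs to the offset buffer, | blockOffsetStart | blockOffsetEnd |, so block i's length is the difference of the pair found at byte position i * 2 * LONG_SIZE. A standalone sketch of the same arithmetic (names other than the constant are hypothetical):

    import java.nio.ByteBuffer;

    // Illustrative decoder for the metadata layout used above: one
    // | blockOffsetStart | blockOffsetEnd | pair of longs per block.
    public class OffsetPairDecoder {
      static final int LONG_SIZE = 8; // mirrors UnsafeUtils.LONG_SIZE

      public static long blockLength(ByteBuffer resultOffset, int i) {
        long start = resultOffset.getLong(i * 2 * LONG_SIZE);
        long end = resultOffset.getLong(i * 2 * LONG_SIZE + LONG_SIZE);
        return end - start; // expected to be > 0 and to fit in an int
      }
    }
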

@@ -10,6 +10,7 @@
import org.apache.spark.network.shuffle.DownloadFileManager;
import org.apache.spark.network.shuffle.ShuffleClient;
import org.apache.spark.shuffle.*;
import org.apache.spark.shuffle.ucx.UnsafeUtils;
import org.apache.spark.shuffle.ucx.memory.MemoryPool;
import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
import org.apache.spark.storage.BlockId;
@@ -61,10 +62,10 @@ private void submitFetchOffsets(UcpEndpoint endpoint, ShuffleBlockId[] blockIds,
endpoint.unpackRemoteKey(driverMetadata.dataRkey(blockId.mapId())));

endpoint.getNonBlockingImplicit(
offsetAddress + blockId.reduceId() * UcxWorkerWrapper.LONG_SIZE(),
offsetAddress + blockId.reduceId() * UnsafeUtils.LONG_SIZE,
offsetRkeysCache.get(blockId.mapId()),
UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UcxWorkerWrapper.LONG_SIZE()),
2L * UcxWorkerWrapper.LONG_SIZE());
UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UnsafeUtils.LONG_SIZE),
2L * UnsafeUtils.LONG_SIZE);
}
}

@@ -85,7 +86,7 @@ public void fetchBlocks(String host, int port, String execId,
long[] dataAddresses = new long[blockIds.length];

// Need to fetch 2 long offsets (the start of the current block and the start of the next one) to calculate the exact block size.
RegisteredMemory offsetMemory = mempool.get(2 * UcxWorkerWrapper.LONG_SIZE() * blockIds.length);
RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length);

ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds)
.map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new);
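
The reason two longs are enough (see the comment above) is the layout of Spark's shuffle index file: it stores numPartitions + 1 cumulative offsets, so block reduceId spans [offsets[reduceId], offsets[reduceId + 1]). A local-file sketch of the lookup the remote get performs, with a hypothetical file name and helper:

    import java.io.DataInputStream;
    import java.io.FileInputStream;
    import java.io.IOException;

    // Illustrative local equivalent of the remote offset fetch above: read the
    // two consecutive longs that bound one reduce partition in an index file.
    public class IndexBoundsSketch {
      static final int LONG_SIZE = 8; // mirrors UnsafeUtils.LONG_SIZE

      public static long[] blockBounds(String indexFile, int reduceId) throws IOException {
        try (DataInputStream in = new DataInputStream(new FileInputStream(indexFile))) {
          in.skipBytes(reduceId * LONG_SIZE); // seek to offsets[reduceId]
          long start = in.readLong();         // start of this block
          long end = in.readLong();           // start of the next block == end of this one
          return new long[]{start, end};
        }
      }
    }
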
@@ -0,0 +1,93 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx.reducer.compat.spark_3_0;

import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.shuffle.UcxWorkerWrapper;
import org.apache.spark.shuffle.ucx.UnsafeUtils;
import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
import org.apache.spark.shuffle.ucx.reducer.ReducerCallback;
import org.apache.spark.shuffle.ucx.reducer.OnBlocksFetchCallback;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.ShuffleBlockBatchId;
import org.apache.spark.storage.ShuffleBlockId;
import org.openucx.jucx.UcxUtils;
import org.openucx.jucx.ucp.UcpEndpoint;
import org.openucx.jucx.ucp.UcpRemoteKey;
import org.openucx.jucx.ucp.UcpRequest;

import java.nio.ByteBuffer;
import java.util.Map;

/**
* Callback invoked once all block offsets have been fetched.
*/
public class OnOffsetsFetchCallback extends ReducerCallback {
private final RegisteredMemory offsetMemory;
private final long[] dataAddresses;
private Map<Integer, UcpRemoteKey> dataRkeysCache;
private final Map<Long, Integer> mapId2PartitionId;

public OnOffsetsFetchCallback(BlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener,
RegisteredMemory offsetMemory, long[] dataAddresses,
Map<Integer, UcpRemoteKey> dataRkeysCache,
Map<Long, Integer> mapId2PartitionId) {
super(blockIds, endpoint, listener);
this.offsetMemory = offsetMemory;
this.dataAddresses = dataAddresses;
this.dataRkeysCache = dataRkeysCache;
this.mapId2PartitionId = mapId2PartitionId;
}

@Override
public void onSuccess(UcpRequest request) {
ByteBuffer resultOffset = offsetMemory.getBuffer();
long totalSize = 0;
int[] sizes = new int[blockIds.length];
int offset = 0;
long blockOffset;
long blockLength;
int offsetSize = UnsafeUtils.LONG_SIZE;
for (int i = 0; i < blockIds.length; i++) {
// Blocks in the metadata buffer are laid out as | blockOffsetStart | blockOffsetEnd | pairs
if (blockIds[i] instanceof ShuffleBlockBatchId) {
ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId) blockIds[i];
int blocksInBatch = blockBatchId.endReduceId() - blockBatchId.startReduceId();
blockOffset = resultOffset.getLong(offset * 2 * offsetSize);
blockLength = resultOffset.getLong(offset * 2 * offsetSize + offsetSize * blocksInBatch)
- blockOffset;
offset += blocksInBatch;
} else {
blockOffset = resultOffset.getLong(offset * 16);
blockLength = resultOffset.getLong(offset * 16 + 8) - blockOffset;
offset++;
}

assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
sizes[i] = (int) blockLength;
totalSize += blockLength;
dataAddresses[i] += blockOffset;
}

assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
mempool.put(offsetMemory);
RegisteredMemory blocksMemory = mempool.get((int) totalSize);

offset = 0;
// Submits N block-fetch requests
for (int i = 0; i < blockIds.length; i++) {
int mapPartitionId = (blockIds[i] instanceof ShuffleBlockId) ?
mapId2PartitionId.get(((ShuffleBlockId)blockIds[i]).mapId()) :
mapId2PartitionId.get(((ShuffleBlockBatchId)blockIds[i]).mapId());
endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(mapPartitionId),
UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]);
offset += sizes[i];
}

// Process blocks once they have all been fetched.
// Flush guarantees that the callback is invoked only after all fetch requests have completed.
endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes));
}
}
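
The only subtle part above is the ShuffleBlockBatchId branch: a batch covering reduce partitions [startReduceId, endReduceId) is fetched as one contiguous slice of the map output's offset array, so its end offset sits blocksInBatch longs after its start offset in the buffer. A standalone sketch of that arithmetic (class and method names here are illustrative):

    import java.nio.ByteBuffer;

    // Illustrative helper mirroring the batch branch of onSuccess() above.
    // `offset` plays the same role as the loop variable there: it advances by one
    // per single block and by blocksInBatch per batch block.
    public class BatchLengthSketch {
      static final int LONG_SIZE = 8; // mirrors UnsafeUtils.LONG_SIZE

      public static long batchLength(ByteBuffer resultOffset, int offset, int blocksInBatch) {
        long batchStart = resultOffset.getLong(offset * 2 * LONG_SIZE);
        // blocksInBatch longs further on lies the end offset of the last partition in the batch.
        long batchEnd = resultOffset.getLong(offset * 2 * LONG_SIZE + LONG_SIZE * blocksInBatch);
        return batchEnd - batchStart;
      }
    }
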
@@ -4,21 +4,133 @@
*/
package org.apache.spark.shuffle.ucx.reducer.compat.spark_3_0;

import org.apache.spark.SparkEnv;
import org.apache.spark.executor.TempShuffleReadMetrics;
import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.BlockStoreClient;
import org.apache.spark.network.shuffle.DownloadFileManager;
import org.apache.spark.shuffle.DriverMetadata;
import org.apache.spark.shuffle.UcxShuffleManager;
import org.apache.spark.shuffle.UcxWorkerWrapper;
import org.apache.spark.shuffle.ucx.UnsafeUtils;
import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
import org.apache.spark.storage.*;
import org.openucx.jucx.UcxUtils;
import org.openucx.jucx.ucp.UcpEndpoint;
import org.openucx.jucx.ucp.UcpRemoteKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;


import java.util.HashMap;
import java.util.Map;

public class UcxShuffleClient extends BlockStoreClient {
private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class);
private final UcxWorkerWrapper workerWrapper;
private final Map<Long, Integer> mapId2PartitionId;
private final TempShuffleReadMetrics shuffleReadMetrics;
private final int shuffleId;
final HashMap<Integer, UcpRemoteKey> offsetRkeysCache = new HashMap<>();
final HashMap<Integer, UcpRemoteKey> dataRkeysCache = new HashMap<>();

@Override
public void close() {
throw new UnsupportedOperationException("TODO");

public UcxShuffleClient(int shuffleId, UcxWorkerWrapper workerWrapper,
Map<Long, Integer> mapId2PartitionId, TempShuffleReadMetrics shuffleReadMetrics) {
this.workerWrapper = workerWrapper;
this.shuffleId = shuffleId;
this.mapId2PartitionId = mapId2PartitionId;
this.shuffleReadMetrics = shuffleReadMetrics;
}

/**
* Submits n non-blocking offset-fetch requests to retrieve the offsets needed for n blocks.
*/
private void submitFetchOffsets(UcpEndpoint endpoint, BlockId[] blockIds,
RegisteredMemory offsetMemory,
long[] dataAddresses) {
DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(shuffleId);
long offset = 0;
int startReduceId;
long size;

for (int i = 0; i < blockIds.length; i++) {
BlockId blockId = blockIds[i];
int mapIdpartition;

if (blockId instanceof ShuffleBlockId) {
ShuffleBlockId shuffleBlockId = (ShuffleBlockId) blockId;
mapIdpartition = mapId2PartitionId.get(shuffleBlockId.mapId());
size = 2L * UnsafeUtils.LONG_SIZE;
startReduceId = shuffleBlockId.reduceId();
} else {
ShuffleBlockBatchId shuffleBlockBatchId = (ShuffleBlockBatchId) blockId;
mapIdpartition = mapId2PartitionId.get(shuffleBlockBatchId.mapId());
size = (shuffleBlockBatchId.endReduceId() - shuffleBlockBatchId.startReduceId())
* 2L * UnsafeUtils.LONG_SIZE;
startReduceId = shuffleBlockBatchId.startReduceId();
}

long offsetAddress = driverMetadata.offsetAddress(mapIdpartition);
dataAddresses[i] = driverMetadata.dataAddress(mapIdpartition);

offsetRkeysCache.computeIfAbsent(mapIdpartition, mapId ->
endpoint.unpackRemoteKey(driverMetadata.offsetRkey(mapIdpartition)));

dataRkeysCache.computeIfAbsent(mapIdpartition, mapId ->
endpoint.unpackRemoteKey(driverMetadata.dataRkey(mapIdpartition)));

endpoint.getNonBlockingImplicit(
offsetAddress + startReduceId * UnsafeUtils.LONG_SIZE,
offsetRkeysCache.get(mapIdpartition),
UcxUtils.getAddress(offsetMemory.getBuffer()) + offset,
size);

offset += size;
}
}

@Override
public void fetchBlocks(String host, int port, String execId, String[] blockIds, BlockFetchingListener listener,
DownloadFileManager downloadFileManager) {
throw new UnsupportedOperationException("TODO");
long startTime = System.currentTimeMillis();
BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty());
UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId);
long[] dataAddresses = new long[blockIds.length];
int totalBlocks = 0;

BlockId[] blocks = new BlockId[blockIds.length];

for (int i = 0; i < blockIds.length; i++) {
blocks[i] = BlockId.apply(blockIds[i]);
if (blocks[i] instanceof ShuffleBlockId) {
totalBlocks += 1;
} else {
ShuffleBlockBatchId blockBatchId = (ShuffleBlockBatchId)blocks[i];
totalBlocks += (blockBatchId.endReduceId() - blockBatchId.startReduceId());
}
}

RegisteredMemory offsetMemory = ((UcxShuffleManager)SparkEnv.get().shuffleManager())
.ucxNode().getMemoryPool().get(totalBlocks * 2 * UnsafeUtils.LONG_SIZE);
// Submits N implicit get requests without callback
submitFetchOffsets(endpoint, blocks, offsetMemory, dataAddresses);

// Flush guarantees that all these requests have completed by the time the callback is called.
// TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush.
workerWrapper.worker().flushNonBlocking(
new OnOffsetsFetchCallback(blocks, endpoint, listener, offsetMemory,
dataAddresses, dataRkeysCache, mapId2PartitionId));

shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime);
}

@Override
public void close() {
offsetRkeysCache.values().forEach(UcpRemoteKey::close);
dataRkeysCache.values().forEach(UcpRemoteKey::close);
logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime());
}

}
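
One Spark 3.0-specific detail worth calling out: ShuffleBlockId.mapId is now the map task attempt id (a long) rather than the map partition index, while the driver metadata published by the UCX shuffle manager is indexed by partition. The mapId2PartitionId map passed into this client bridges the two. A hedged sketch of that translation, with population of the table left to the caller:

    import java.util.HashMap;
    import java.util.Map;

    // Hypothetical helper illustrating the mapId -> map partition translation the
    // client above relies on; filling the table (e.g. from shuffle status on the
    // driver) is outside the scope of this sketch.
    public class MapIdTranslation {
      private final Map<Long, Integer> mapId2PartitionId = new HashMap<>();

      public void register(long mapTaskId, int mapPartitionId) {
        mapId2PartitionId.put(mapTaskId, mapPartitionId);
      }

      public int partitionFor(long mapTaskId) {
        Integer partition = mapId2PartitionId.get(mapTaskId);
        if (partition == null) {
          throw new IllegalStateException("Unknown map task id: " + mapTaskId);
        }
        return partition;
      }
    }
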
@@ -4,7 +4,7 @@
*/
package org.apache.spark.shuffle

import java.io.{Closeable, File, RandomAccessFile}
import java.io.{File, RandomAccessFile}
import java.util.concurrent.{ConcurrentHashMap, CopyOnWriteArrayList}

import scala.collection.JavaConverters._
@@ -50,7 +50,7 @@ abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffle
}
val dataMemory = ucxShuffleManager.ucxNode.getContext.memoryMap(memMapParams)
fileMappings(shuffleId).add(dataMemory)
assume(indexBackFile.length() == UcxWorkerWrapper.LONG_SIZE * (lengths.length + 1))
assume(indexBackFile.length() == UnsafeUtils.LONG_SIZE * (lengths.length + 1))

val offsetAddress = UnsafeUtils.mmap(indexFileChannel, 0, indexBackFile.length())
memMapParams.setAddress(offsetAddress).setLength(indexBackFile.length())
@@ -70,8 +70,8 @@ abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffle
val metadataBuffer = metadataRegisteredMemory.getBuffer.slice()

if (metadataBuffer.remaining() > ucxShuffleManager.ucxShuffleConf.metadataBlockSize) {
throw new SparkException(s"Metadata block size ${metadataBuffer.remaining()} " +
s"is greater then configured 2 * ${ucxShuffleManager.ucxShuffleConf.RKEY_SIZE.key}" +
throw new SparkException(s"Metadata block size ${metadataBuffer.remaining() / 2} " +
s"is greater then configured ${ucxShuffleManager.ucxShuffleConf.RKEY_SIZE.key}" +
s"(${ucxShuffleManager.ucxShuffleConf.metadataBlockSize}).")
}
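
For context on the reworked error message: each map partition's metadata slot holds two packed remote keys (one for the offset region, one for the data region), so an individual rkey may occupy at most half of the configured metadata block size. A minimal sketch of that capacity check, assuming only the 2x relationship implied by the message above (all names here are hypothetical):

    // Illustrative capacity check mirroring the guard above: two packed rkeys per
    // metadata slot means each one may use at most metadataBlockSize / 2 bytes.
    public final class MetadataSlotCheck {
      public static void ensureFits(int packedRkeyBytes, long metadataBlockSize) {
        if (packedRkeyBytes > metadataBlockSize / 2) {
          throw new IllegalArgumentException("Packed rkey of " + packedRkeyBytes
              + " bytes exceeds the per-key budget of " + (metadataBlockSize / 2) + " bytes.");
        }
      }
    }
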
