Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1492090 Snowpipe streaming file master key id rotation #786

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ src/main/resources/log4j.properties
src/test/resources/log4j.properties
testOutput/
.cache/
/dependency-reduced-pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.zip.CRC32;
import javax.crypto.BadPaddingException;
import javax.crypto.IllegalBlockSizeException;
Expand Down Expand Up @@ -67,7 +68,8 @@ static <T> Blob constructBlobAndMetadata(
String filePath,
List<List<ChannelData<T>>> blobData,
Constants.BdecVersion bdecVersion,
InternalParameterProvider internalParameterProvider)
InternalParameterProvider internalParameterProvider,
Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable)
throws IOException, NoSuchPaddingException, NoSuchAlgorithmException,
InvalidAlgorithmParameterException, InvalidKeyException, IllegalBlockSizeException,
BadPaddingException {
Expand All @@ -81,6 +83,13 @@ static <T> Blob constructBlobAndMetadata(
ChannelFlushContext firstChannelFlushContext =
channelsDataPerTable.get(0).getChannelContext();

EncryptionKey encryptionKey =
sfc-gh-bmikaili marked this conversation as resolved.
Show resolved Hide resolved
encryptionKeysPerTable.get(
new FullyQualifiedTableName(
firstChannelFlushContext.getDbName(),
firstChannelFlushContext.getSchemaName(),
firstChannelFlushContext.getTableName()));

Flusher<T> flusher = channelsDataPerTable.get(0).createFlusher();
Flusher.SerializationResult serializedChunk =
flusher.serialize(channelsDataPerTable, filePath, curDataSize);
Expand All @@ -102,9 +111,19 @@ static <T> Blob constructBlobAndMetadata(
// to align with decryption on the Snowflake query path.
// TODO: address alignment for the header SNOW-557866
long iv = curDataSize / Constants.ENCRYPTION_ALGORITHM_BLOCK_SIZE_BYTES;

if (encryptionKey == null)
encryptionKey =
new EncryptionKey(
firstChannelFlushContext.getDbName(),
firstChannelFlushContext.getSchemaName(),
firstChannelFlushContext.getTableName(),
firstChannelFlushContext.getEncryptionKey(),
firstChannelFlushContext.getEncryptionKeyId());

compressedChunkData =
Cryptor.encrypt(
paddedChunkData, firstChannelFlushContext.getEncryptionKey(), filePath, iv);
Cryptor.encrypt(paddedChunkData, encryptionKey.getEncryptionKey(), filePath, iv);

compressedChunkDataSize = compressedChunkData.length;
} else {
compressedChunkData = serializedChunk.chunkData.toByteArray();
Expand All @@ -129,7 +148,7 @@ static <T> Blob constructBlobAndMetadata(
.setUncompressedChunkLength((int) serializedChunk.chunkEstimatedUncompressedSize)
.setChannelList(serializedChunk.channelsMetadataList)
.setChunkMD5(md5)
.setEncryptionKeyId(firstChannelFlushContext.getEncryptionKeyId())
.setEncryptionKeyId(encryptionKey.getEncryptionKeyId())
.setEpInfo(
AbstractRowBuffer.buildEpInfoFromStats(
serializedChunk.rowCount,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package net.snowflake.ingest.streaming.internal;

import com.fasterxml.jackson.annotation.JsonProperty;

/**
 * Represents the encryption (table master) key for a single table, as returned by the server on
 * channel open and on blob registration.
 *
 * <p>Instances are shared across flush threads via the client's per-table key map, so all fields
 * are {@code private final} to make the class immutable and safe to publish without
 * synchronization.
 */
public class EncryptionKey {
  // Database name
  private final String databaseName;

  // Schema name
  private final String schemaName;

  // Table Name
  private final String tableName;

  // Table master key material, used as the encryption key for chunk data
  private final String blobTableMasterKey;

  // Server-assigned id of the master key, echoed back in chunk metadata
  private final long encryptionKeyId;

  /**
   * Creates a key descriptor; also the Jackson deserialization entry point for the
   * "encryption_keys" field of the register-blob response.
   *
   * @param databaseName database the key belongs to
   * @param schemaName schema the key belongs to
   * @param tableName table the key belongs to
   * @param blobTableMasterKey key material used to encrypt blob chunk data
   * @param encryptionKeyId server-side id of this key version
   */
  public EncryptionKey(
      @JsonProperty("database") String databaseName,
      @JsonProperty("schema") String schemaName,
      @JsonProperty("table") String tableName,
      @JsonProperty("encryption_key") String blobTableMasterKey,
      @JsonProperty("encryption_key_id") long encryptionKeyId) {
    this.databaseName = databaseName;
    this.schemaName = schemaName;
    this.tableName = tableName;
    this.blobTableMasterKey = blobTableMasterKey;
    this.encryptionKeyId = encryptionKeyId;
  }

  /**
   * Copy constructor, used to snapshot the client's key map before handing keys to an async flush
   * task.
   */
  public EncryptionKey(EncryptionKey encryptionKey) {
    this.databaseName = encryptionKey.databaseName;
    this.schemaName = encryptionKey.schemaName;
    this.tableName = encryptionKey.tableName;
    this.blobTableMasterKey = encryptionKey.blobTableMasterKey;
    this.encryptionKeyId = encryptionKey.encryptionKeyId;
  }

  /** Returns the {@code db.schema.table} name this key applies to. */
  public String getFullyQualifiedTableName() {
    return String.format("%s.%s.%s", databaseName, schemaName, tableName);
  }

  @JsonProperty("database")
  public String getDatabaseName() {
    return databaseName;
  }

  @JsonProperty("schema")
  public String getSchemaName() {
    return schemaName;
  }

  @JsonProperty("table")
  public String getTableName() {
    return tableName;
  }

  @JsonProperty("encryption_key")
  public String getEncryptionKey() {
    return blobTableMasterKey;
  }

  @JsonProperty("encryption_key_id")
  public long getEncryptionKeyId() {
    return encryptionKeyId;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -495,11 +495,19 @@ && shouldStopProcessing(
blobPath.fileName, this.owningClient.flushLatency.time());
}

// Copy encryptionKeysPerTable from owning client
Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable =
new ConcurrentHashMap<>();
this.owningClient
.getEncryptionKeysPerTable()
.forEach((k, v) -> encryptionKeysPerTable.put(k, new EncryptionKey(v)));

Supplier<BlobMetadata> supplier =
() -> {
try {
BlobMetadata blobMetadata =
buildAndUpload(blobPath, blobData, fullyQualifiedTableName);
buildAndUpload(
blobPath, blobData, fullyQualifiedTableName, encryptionKeysPerTable);
blobMetadata.getBlobStats().setFlushStartMs(flushStartMs);
return blobMetadata;
} catch (Throwable e) {
Expand Down Expand Up @@ -562,8 +570,6 @@ && shouldStopProcessing(
*
* <p>When the chunk size is larger than a certain threshold
*
* <p>When the encryption key ids are not the same
*
* <p>When the schemas are not the same
*/
private boolean shouldStopProcessing(
Expand Down Expand Up @@ -591,7 +597,10 @@ private boolean shouldStopProcessing(
* @return BlobMetadata for FlushService.upload
*/
BlobMetadata buildAndUpload(
BlobPath blobPath, List<List<ChannelData<T>>> blobData, String fullyQualifiedTableName)
BlobPath blobPath,
List<List<ChannelData<T>>> blobData,
String fullyQualifiedTableName,
Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable)
throws IOException, NoSuchAlgorithmException, InvalidAlgorithmParameterException,
NoSuchPaddingException, IllegalBlockSizeException, BadPaddingException,
InvalidKeyException {
Expand All @@ -601,7 +610,7 @@ BlobMetadata buildAndUpload(
// Construct the blob along with the metadata of the blob
BlobBuilder.Blob blob =
BlobBuilder.constructBlobAndMetadata(
blobPath.fileName, blobData, bdecVersion, paramProvider);
blobPath.fileName, blobData, bdecVersion, paramProvider, encryptionKeysPerTable);

blob.blobStats.setBuildDurationMs(buildContext);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package net.snowflake.ingest.streaming.internal;

import java.util.Objects;

/**
* FullyQualifiedTableName is a class that represents a fully qualified table name. It is used to
* store the fully qualified table name in the Snowflake format.
*/
public class FullyQualifiedTableName {
public FullyQualifiedTableName(String databaseName, String schemaName, String tableName) {
this.databaseName = databaseName;
this.schemaName = schemaName;
this.tableName = tableName;
}

// Database name
private final String databaseName;

// Schema name
private final String schemaName;

// Table Name
private final String tableName;

public String getTableName() {
return tableName;
}

public String getSchemaName() {
return schemaName;
}

public String getDatabaseName() {
return databaseName;
}

public String getFullyQualifiedName() {
return String.format("%s.%s.%s", databaseName, schemaName, tableName);
}

private int hashCode;

@Override
public int hashCode() {
int result = hashCode;
if (result == 0) {
result = 31 + ((databaseName == null) ? 0 : databaseName.hashCode());
result = 31 * result + ((schemaName == null) ? 0 : schemaName.hashCode());
result = 31 * result + ((tableName == null) ? 0 : tableName.hashCode());
hashCode = result;
}

return result;
}

@Override
public boolean equals(Object obj) {
if (this == obj) return true;

if (!(obj instanceof FullyQualifiedTableName)) return false;

FullyQualifiedTableName other = (FullyQualifiedTableName) obj;

if (!Objects.equals(databaseName, other.databaseName)) return false;
if (!Objects.equals(schemaName, other.schemaName)) return false;
return Objects.equals(tableName, other.tableName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class RegisterBlobResponse extends StreamingIngestResponse {
private Long statusCode;
private String message;
private List<BlobRegisterStatus> blobsStatus;
private List<EncryptionKey> encryptionKeys;

@JsonProperty("status_code")
void setStatusCode(Long statusCode) {
Expand Down Expand Up @@ -39,4 +40,13 @@ void setBlobsStatus(List<BlobRegisterStatus> blobsStatus) {
List<BlobRegisterStatus> getBlobsStatus() {
return this.blobsStatus;
}

@JsonProperty("encryption_keys")
void setEncryptionKeys(List<EncryptionKey> encryptionKeys) {
this.encryptionKeys = encryptionKeys;
}

List<EncryptionKey> getEncryptionKeys() {
return this.encryptionKeys;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
Expand Down Expand Up @@ -109,7 +110,7 @@ public class SnowflakeStreamingIngestClientInternal<T> implements SnowflakeStrea
private final FlushService<T> flushService;

// Reference to storage manager
private final IStorageManager storageManager;
private IStorageManager storageManager;

// Indicates whether the client has closed
private volatile boolean isClosed;
Expand All @@ -120,6 +121,9 @@ public class SnowflakeStreamingIngestClientInternal<T> implements SnowflakeStrea
// Indicates whether the client is under test mode
private final boolean isTestMode;

// Stores encryptionkey per table: FullyQualifiedTableName -> EncryptionKey
private final Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable;

// Performance testing related metrics
MetricRegistry metrics;
Histogram blobSizeHistogram; // Histogram for blob size after compression
Expand Down Expand Up @@ -176,6 +180,7 @@ public class SnowflakeStreamingIngestClientInternal<T> implements SnowflakeStrea
this.channelCache = new ChannelCache<>();
this.isClosed = false;
this.requestBuilder = requestBuilder;
this.encryptionKeysPerTable = new ConcurrentHashMap<>();

if (!isTestMode) {
// Setup request builder for communication with the server side
Expand Down Expand Up @@ -398,6 +403,17 @@ public SnowflakeStreamingIngestChannelInternal<?> openChannel(OpenChannelRequest
new TableRef(response.getDBName(), response.getSchemaName(), response.getTableName()),
response.getIcebergLocationInfo());

// Add encryption key to the client map for the table
this.encryptionKeysPerTable.put(
new FullyQualifiedTableName(
request.getDBName(), request.getSchemaName(), request.getTableName()),
new EncryptionKey(
response.getDBName(),
response.getSchemaName(),
response.getTableName(),
response.getEncryptionKey(),
response.getEncryptionKeyId()));

return channel;
}

Expand Down Expand Up @@ -594,6 +610,18 @@ void registerBlobs(List<BlobMetadata> blobs, final int executionCount) {
this.name,
executionCount);

// Update encryption keys for the table given the response
if (response.getEncryptionKeys() == null) {
this.encryptionKeysPerTable.clear();
} else {
for (EncryptionKey key : response.getEncryptionKeys()) {
this.encryptionKeysPerTable.put(
new FullyQualifiedTableName(
key.getDatabaseName(), key.getSchemaName(), key.getTableName()),
key);
}
}

// We will retry any blob chunks that were rejected because internal Snowflake queues are full
Set<ChunkRegisterStatus> queueFullChunks = new HashSet<>();
response
Expand Down Expand Up @@ -1062,4 +1090,13 @@ private void cleanUpResources() {
HttpUtil.shutdownHttpConnectionManagerDaemonThread();
}
}

public Map<FullyQualifiedTableName, EncryptionKey> getEncryptionKeysPerTable() {
return encryptionKeysPerTable;
}

// TESTING ONLY - inject the storage manager
public void setStorageManager(IStorageManager storageManager) {
this.storageManager = storageManager;
}
}
Loading
Loading