Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1492090 Snowpipe streaming file master key id rotation #786

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,175 changes: 1,175 additions & 0 deletions dependency-reduced-pom.xml

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.zip.CRC32;
import javax.crypto.BadPaddingException;
import javax.crypto.IllegalBlockSizeException;
Expand Down Expand Up @@ -68,6 +69,7 @@ static <T> Blob constructBlobAndMetadata(
String filePath,
List<List<ChannelData<T>>> blobData,
Constants.BdecVersion bdecVersion,
Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable,
boolean encrypt)
throws IOException, NoSuchPaddingException, NoSuchAlgorithmException,
InvalidAlgorithmParameterException, InvalidKeyException, IllegalBlockSizeException,
Expand All @@ -82,6 +84,13 @@ static <T> Blob constructBlobAndMetadata(
ChannelFlushContext firstChannelFlushContext =
channelsDataPerTable.get(0).getChannelContext();

EncryptionKey encryptionKey =
sfc-gh-bmikaili marked this conversation as resolved.
Show resolved Hide resolved
encryptionKeysPerTable.get(
new FullyQualifiedTableName(
firstChannelFlushContext.getDbName(),
firstChannelFlushContext.getSchemaName(),
firstChannelFlushContext.getTableName()));

Flusher<T> flusher = channelsDataPerTable.get(0).createFlusher();
Flusher.SerializationResult serializedChunk =
flusher.serialize(channelsDataPerTable, filePath);
Expand All @@ -104,8 +113,7 @@ static <T> Blob constructBlobAndMetadata(
// TODO: address alignment for the header SNOW-557866
long iv = curDataSize / Constants.ENCRYPTION_ALGORITHM_BLOCK_SIZE_BYTES;
compressedChunkData =
Cryptor.encrypt(
paddedChunkData, firstChannelFlushContext.getEncryptionKey(), filePath, iv);
Cryptor.encrypt(paddedChunkData, encryptionKey.getEncryptionKey(), filePath, iv);
compressedChunkDataSize = compressedChunkData.length;
} else {
compressedChunkData = serializedChunk.chunkData.toByteArray();
Expand All @@ -130,7 +138,7 @@ static <T> Blob constructBlobAndMetadata(
.setUncompressedChunkLength((int) serializedChunk.chunkEstimatedUncompressedSize)
.setChannelList(serializedChunk.channelsMetadataList)
.setChunkMD5(md5)
.setEncryptionKeyId(firstChannelFlushContext.getEncryptionKeyId())
.setEncryptionKeyId(encryptionKey.getEncryptionKeyId())
.setEpInfo(
AbstractRowBuffer.buildEpInfoFromStats(
serializedChunk.rowCount, serializedChunk.columnEpStatsMapCombined))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,8 @@ class ChannelFlushContext {
// connection to a channel at server side will be seen as a connection from a new client
private final Long channelSequencer;

// Data encryption key
private final String encryptionKey;

// Data encryption key id
private final Long encryptionKeyId;

ChannelFlushContext(
String name,
String dbName,
String schemaName,
String tableName,
Long channelSequencer,
String encryptionKey,
Long encryptionKeyId) {
String name, String dbName, String schemaName, String tableName, Long channelSequencer) {
this.name = name;
this.fullyQualifiedName =
Utils.getFullyQualifiedChannelName(dbName, schemaName, tableName, name);
Expand All @@ -45,8 +33,6 @@ class ChannelFlushContext {
this.tableName = tableName;
this.fullyQualifiedTableName = Utils.getFullyQualifiedTableName(dbName, schemaName, tableName);
this.channelSequencer = channelSequencer;
this.encryptionKey = encryptionKey;
this.encryptionKeyId = encryptionKeyId;
}

@Override
Expand All @@ -72,11 +58,6 @@ public String toString() {
+ '\''
+ ", channelSequencer="
+ getChannelSequencer()
+ ", encryptionKey='"
+ getEncryptionKey()
+ '\''
+ ", encryptionKeyId="
+ getEncryptionKeyId()
+ '}';
}

Expand Down Expand Up @@ -107,12 +88,4 @@ String getFullyQualifiedTableName() {
Long getChannelSequencer() {
return channelSequencer;
}

String getEncryptionKey() {
return encryptionKey;
}

Long getEncryptionKeyId() {
return encryptionKeyId;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package net.snowflake.ingest.streaming.internal;

import com.fasterxml.jackson.annotation.JsonProperty;

/** Represents an encryption key for a table */
public class EncryptionKey {
// Database name
private final String databaseName;

// Schema name
private final String schemaName;

// Table Name
private final String tableName;

String blobTableMasterKey;

long encryptionKeyId;

public EncryptionKey(
@JsonProperty("database") String databaseName,
@JsonProperty("schema") String schemaName,
@JsonProperty("table") String tableName,
@JsonProperty("encryption_key") String blobTableMasterKey,
@JsonProperty("encryption_key_id") long encryptionKeyId) {
this.databaseName = databaseName;
this.schemaName = schemaName;
this.tableName = tableName;
this.blobTableMasterKey = blobTableMasterKey;
this.encryptionKeyId = encryptionKeyId;
}

public EncryptionKey(EncryptionKey encryptionKey) {
this.databaseName = encryptionKey.databaseName;
this.schemaName = encryptionKey.schemaName;
this.tableName = encryptionKey.tableName;
this.blobTableMasterKey = encryptionKey.blobTableMasterKey;
this.encryptionKeyId = encryptionKey.encryptionKeyId;
}

public String getFullyQualifiedTableName() {
return String.format("%s.%s.%s", databaseName, schemaName, tableName);
}

@JsonProperty("database")
public String getDatabaseName() {
return databaseName;
}

@JsonProperty("schema")
public String getSchemaName() {
return schemaName;
}

@JsonProperty("table")
public String getTableName() {
return tableName;
}

@JsonProperty("encryption_key")
public String getEncryptionKey() {
return blobTableMasterKey;
}

@JsonProperty("encryption_key_id")
public long getEncryptionKeyId() {
return encryptionKeyId;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
Expand Down Expand Up @@ -455,15 +454,13 @@ && shouldStopProcessing(
logger.logInfo(
"Creation of another blob is needed because of blob/chunk size limit or"
+ " different encryption ids or different schema, client={}, table={},"
+ " blobSize={}, chunkSize={}, nextChannelSize={}, encryptionId1={},"
+ " encryptionId2={}, schema1={}, schema2={}",
+ " blobSize={}, chunkSize={}, nextChannelSize={},"
+ " schema1={}, schema2={}",
this.owningClient.getName(),
channelData.getChannelContext().getTableName(),
totalBufferSizeInBytes,
totalBufferSizePerTableInBytes,
channelData.getBufferSize(),
channelData.getChannelContext().getEncryptionKeyId(),
channelsDataPerTable.get(idx - 1).getChannelContext().getEncryptionKeyId(),
channelData.getColumnEps().keySet(),
channelsDataPerTable.get(idx - 1).getColumnEps().keySet());
break;
Expand All @@ -490,6 +487,14 @@ && shouldStopProcessing(
if (this.owningClient.flushLatency != null) {
latencyTimerContextMap.putIfAbsent(blobPath, this.owningClient.flushLatency.time());
}

// Copy encryptionKeysPerTable from owning client
Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable =
new ConcurrentHashMap<>();
this.owningClient
.getEncryptionKeysPerTable()
.forEach((k, v) -> encryptionKeysPerTable.put(k, new EncryptionKey(v)));

blobs.add(
new Pair<>(
new BlobData<>(blobPath, blobData),
Expand All @@ -502,7 +507,11 @@ && shouldStopProcessing(
String fullyQualifiedTableName =
blobData.get(0).get(0).getChannelContext().getFullyQualifiedTableName();
BlobMetadata blobMetadata =
buildAndUpload(blobPath, blobData, fullyQualifiedTableName);
buildAndUpload(
blobPath,
blobData,
fullyQualifiedTableName,
encryptionKeysPerTable);
blobMetadata.getBlobStats().setFlushStartMs(flushStartMs);
return blobMetadata;
} catch (Throwable e) {
Expand Down Expand Up @@ -561,8 +570,6 @@ && shouldStopProcessing(
*
* <p>When the chunk size is larger than a certain threshold
*
* <p>When the encryption key ids are not the same
*
* <p>When the schemas are not the same
*/
private boolean shouldStopProcessing(
Expand All @@ -573,9 +580,6 @@ private boolean shouldStopProcessing(
return totalBufferSizeInBytes + current.getBufferSize() > MAX_BLOB_SIZE_IN_BYTES
|| totalBufferSizePerTableInBytes + current.getBufferSize()
> this.owningClient.getParameterProvider().getMaxChunkSizeInBytes()
|| !Objects.equals(
current.getChannelContext().getEncryptionKeyId(),
prev.getChannelContext().getEncryptionKeyId())
|| !current.getColumnEps().keySet().equals(prev.getColumnEps().keySet());
}

Expand All @@ -590,7 +594,10 @@ private boolean shouldStopProcessing(
* @return BlobMetadata for FlushService.upload
*/
BlobMetadata buildAndUpload(
String blobPath, List<List<ChannelData<T>>> blobData, String fullyQualifiedTableName)
String blobPath,
List<List<ChannelData<T>>> blobData,
String fullyQualifiedTableName,
Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable)
throws IOException, NoSuchAlgorithmException, InvalidAlgorithmParameterException,
NoSuchPaddingException, IllegalBlockSizeException, BadPaddingException,
InvalidKeyException {
Expand All @@ -602,6 +609,7 @@ BlobMetadata buildAndUpload(
blobPath,
blobData,
bdecVersion,
encryptionKeysPerTable,
this.owningClient.getInternalParameterProvider().getEnableChunkEncryption());

blob.blobStats.setBuildDurationMs(buildContext);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package net.snowflake.ingest.streaming.internal;

import java.util.Objects;

/**
* FullyQualifiedTableName is a class that represents a fully qualified table name. It is used to
* store the fully qualified table name in the Snowflake format.
*/
public class FullyQualifiedTableName {
public FullyQualifiedTableName(String databaseName, String schemaName, String tableName) {
this.databaseName = databaseName;
this.schemaName = schemaName;
this.tableName = tableName;
}

// Database name
private final String databaseName;

// Schema name
private final String schemaName;

// Table Name
private final String tableName;

public String getTableName() {
return tableName;
}

public String getSchemaName() {
return schemaName;
}

public String getDatabaseName() {
return databaseName;
}

public String getFullyQualifiedName() {
return String.format("%s.%s.%s", databaseName, schemaName, tableName);
}

private int hashCode;

@Override
public int hashCode() {
int result = hashCode;
if (result == 0) {
result = 31 + ((databaseName == null) ? 0 : databaseName.hashCode());
result = 31 * result + ((schemaName == null) ? 0 : schemaName.hashCode());
result = 31 * result + ((tableName == null) ? 0 : tableName.hashCode());
hashCode = result;
}

return result;
}

@Override
public boolean equals(Object obj) {
if (this == obj) return true;

if (!(obj instanceof FullyQualifiedTableName)) return false;

FullyQualifiedTableName other = (FullyQualifiedTableName) obj;

if (!Objects.equals(databaseName, other.databaseName)) return false;
if (!Objects.equals(schemaName, other.schemaName)) return false;
return Objects.equals(tableName, other.tableName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class RegisterBlobResponse extends StreamingIngestResponse {
private Long statusCode;
private String message;
private List<BlobRegisterStatus> blobsStatus;
private List<EncryptionKey> encryptionKeys;

@JsonProperty("status_code")
void setStatusCode(Long statusCode) {
Expand Down Expand Up @@ -39,4 +40,13 @@ void setBlobsStatus(List<BlobRegisterStatus> blobsStatus) {
List<BlobRegisterStatus> getBlobsStatus() {
return this.blobsStatus;
}

@JsonProperty("encryption_keys")
void setEncryptionKeys(List<EncryptionKey> encryptionKeys) {
this.encryptionKeys = encryptionKeys;
}

List<EncryptionKey> getEncryptionKeys() {
return this.encryptionKeys;
}
}
Loading
Loading