diff --git a/.gitignore b/.gitignore
index 7cda7192a..8fd73e414 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,4 @@ src/main/resources/log4j.properties
 src/test/resources/log4j.properties
 testOutput/
 .cache/
+/dependency-reduced-pom.xml
diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java b/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java
index 30abfac38..18b28a037 100644
--- a/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/BlobBuilder.java
@@ -23,6 +23,7 @@ import java.security.NoSuchAlgorithmException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 import java.util.zip.CRC32;
 import javax.crypto.BadPaddingException;
 import javax.crypto.IllegalBlockSizeException;
@@ -67,7 +68,8 @@ static <T> Blob constructBlobAndMetadata(
       String filePath,
       List<List<ChannelData<T>>> blobData,
       Constants.BdecVersion bdecVersion,
-      InternalParameterProvider internalParameterProvider)
+      InternalParameterProvider internalParameterProvider,
+      Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable)
       throws IOException, NoSuchPaddingException, NoSuchAlgorithmException,
           InvalidAlgorithmParameterException, InvalidKeyException, IllegalBlockSizeException,
           BadPaddingException {
@@ -81,6 +83,13 @@ static <T> Blob constructBlobAndMetadata(
       ChannelFlushContext firstChannelFlushContext =
           channelsDataPerTable.get(0).getChannelContext();
 
+      EncryptionKey encryptionKey =
+          encryptionKeysPerTable.get(
+              new FullyQualifiedTableName(
+                  firstChannelFlushContext.getDbName(),
+                  firstChannelFlushContext.getSchemaName(),
+                  firstChannelFlushContext.getTableName()));
+
       Flusher<T> flusher = channelsDataPerTable.get(0).createFlusher();
       Flusher.SerializationResult serializedChunk =
           flusher.serialize(channelsDataPerTable, filePath, curDataSize);
@@ -102,9 +111,19 @@ static <T> Blob constructBlobAndMetadata(
         // to align with decryption on the Snowflake query path.
         // TODO: address alignment for the header SNOW-557866
         long iv = curDataSize / Constants.ENCRYPTION_ALGORITHM_BLOCK_SIZE_BYTES;
+
+        if (encryptionKey == null)
+          encryptionKey =
+              new EncryptionKey(
+                  firstChannelFlushContext.getDbName(),
+                  firstChannelFlushContext.getSchemaName(),
+                  firstChannelFlushContext.getTableName(),
+                  firstChannelFlushContext.getEncryptionKey(),
+                  firstChannelFlushContext.getEncryptionKeyId());
+
         compressedChunkData =
-            Cryptor.encrypt(
-                paddedChunkData, firstChannelFlushContext.getEncryptionKey(), filePath, iv);
+            Cryptor.encrypt(paddedChunkData, encryptionKey.getEncryptionKey(), filePath, iv);
+
         compressedChunkDataSize = compressedChunkData.length;
       } else {
         compressedChunkData = serializedChunk.chunkData.toByteArray();
@@ -129,7 +148,7 @@ static <T> Blob constructBlobAndMetadata(
               .setUncompressedChunkLength((int) serializedChunk.chunkEstimatedUncompressedSize)
               .setChannelList(serializedChunk.channelsMetadataList)
               .setChunkMD5(md5)
-              .setEncryptionKeyId(firstChannelFlushContext.getEncryptionKeyId())
+              .setEncryptionKeyId(encryptionKey.getEncryptionKeyId())
               .setEpInfo(
                   AbstractRowBuffer.buildEpInfoFromStats(
                       serializedChunk.rowCount,
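
Note on the hunks above: `constructBlobAndMetadata` now looks the key up in the per-table map first and only falls back to the key captured in the channel's flush context when the map has no entry for the table. A minimal sketch of that resolution order, using the types added in this patch (the `resolveKey` helper itself is illustrative, not part of the change):

```java
// Illustrative helper mirroring the lookup-then-fallback logic in constructBlobAndMetadata.
static EncryptionKey resolveKey(
    Map<FullyQualifiedTableName, EncryptionKey> keysPerTable, ChannelFlushContext ctx) {
  EncryptionKey key =
      keysPerTable.get(
          new FullyQualifiedTableName(ctx.getDbName(), ctx.getSchemaName(), ctx.getTableName()));
  // Fall back to the key the channel was opened with.
  return key != null
      ? key
      : new EncryptionKey(
          ctx.getDbName(),
          ctx.getSchemaName(),
          ctx.getTableName(),
          ctx.getEncryptionKey(),
          ctx.getEncryptionKeyId());
}
```
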
diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/EncryptionKey.java b/src/main/java/net/snowflake/ingest/streaming/internal/EncryptionKey.java
new file mode 100644
index 000000000..7022dc321
--- /dev/null
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/EncryptionKey.java
@@ -0,0 +1,69 @@
+package net.snowflake.ingest.streaming.internal;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+/** Represents an encryption key for a table */
+public class EncryptionKey {
+  // Database name
+  private final String databaseName;
+
+  // Schema name
+  private final String schemaName;
+
+  // Table Name
+  private final String tableName;
+
+  String blobTableMasterKey;
+
+  long encryptionKeyId;
+
+  public EncryptionKey(
+      @JsonProperty("database") String databaseName,
+      @JsonProperty("schema") String schemaName,
+      @JsonProperty("table") String tableName,
+      @JsonProperty("encryption_key") String blobTableMasterKey,
+      @JsonProperty("encryption_key_id") long encryptionKeyId) {
+    this.databaseName = databaseName;
+    this.schemaName = schemaName;
+    this.tableName = tableName;
+    this.blobTableMasterKey = blobTableMasterKey;
+    this.encryptionKeyId = encryptionKeyId;
+  }
+
+  public EncryptionKey(EncryptionKey encryptionKey) {
+    this.databaseName = encryptionKey.databaseName;
+    this.schemaName = encryptionKey.schemaName;
+    this.tableName = encryptionKey.tableName;
+    this.blobTableMasterKey = encryptionKey.blobTableMasterKey;
+    this.encryptionKeyId = encryptionKey.encryptionKeyId;
+  }
+
+  public String getFullyQualifiedTableName() {
+    return String.format("%s.%s.%s", databaseName, schemaName, tableName);
+  }
+
+  @JsonProperty("database")
+  public String getDatabaseName() {
+    return databaseName;
+  }
+
+  @JsonProperty("schema")
+  public String getSchemaName() {
+    return schemaName;
+  }
+
+  @JsonProperty("table")
+  public String getTableName() {
+    return tableName;
+  }
+
+  @JsonProperty("encryption_key")
+  public String getEncryptionKey() {
+    return blobTableMasterKey;
+  }
+
+  @JsonProperty("encryption_key_id")
+  public long getEncryptionKeyId() {
+    return encryptionKeyId;
+  }
+}
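
The `@JsonProperty` names above match both the key fields of the open-channel response and the entries of the `encryption_keys` array in the register-blob response, so Jackson can bind this class on either path. A quick serialization check of the wire shape (illustrative snippet, assuming Jackson's `ObjectMapper`, which the SDK already uses; field order in the output may vary):

```java
ObjectMapper mapper = new ObjectMapper();
EncryptionKey key = new EncryptionKey("DB", "PUBLIC", "T", "master-key", 1234L);
String json = mapper.writeValueAsString(key);
// e.g. {"database":"DB","schema":"PUBLIC","table":"T",
//       "encryption_key":"master-key","encryption_key_id":1234}
```
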
diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java
index 2e64f77b8..9bd3bcdca 100644
--- a/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/FlushService.java
@@ -495,11 +495,19 @@ && shouldStopProcessing(
             blobPath.fileName, this.owningClient.flushLatency.time());
       }
 
+      // Copy encryptionKeysPerTable from owning client
+      Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable =
+          new ConcurrentHashMap<>();
+      this.owningClient
+          .getEncryptionKeysPerTable()
+          .forEach((k, v) -> encryptionKeysPerTable.put(k, new EncryptionKey(v)));
+
       Supplier<BlobMetadata> supplier =
           () -> {
             try {
               BlobMetadata blobMetadata =
-                  buildAndUpload(blobPath, blobData, fullyQualifiedTableName);
+                  buildAndUpload(
+                      blobPath, blobData, fullyQualifiedTableName, encryptionKeysPerTable);
               blobMetadata.getBlobStats().setFlushStartMs(flushStartMs);
               return blobMetadata;
             } catch (Throwable e) {
@@ -562,8 +570,6 @@ && shouldStopProcessing(
    *
    * <li>When the chunk size is larger than a certain threshold
    *
-   * <li>When the encryption key ids are not the same
-   *
    * <li>When the schemas are not the same
    */
   private boolean shouldStopProcessing(
@@ -591,7 +597,10 @@ private boolean shouldStopProcessing(
    * @return BlobMetadata for FlushService.upload
    */
   BlobMetadata buildAndUpload(
-      BlobPath blobPath, List<List<ChannelData<T>>> blobData, String fullyQualifiedTableName)
+      BlobPath blobPath,
+      List<List<ChannelData<T>>> blobData,
+      String fullyQualifiedTableName,
+      Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable)
       throws IOException, NoSuchAlgorithmException, InvalidAlgorithmParameterException,
           NoSuchPaddingException, IllegalBlockSizeException, BadPaddingException,
           InvalidKeyException {
@@ -601,7 +610,7 @@ BlobMetadata buildAndUpload(
     // Construct the blob along with the metadata of the blob
     BlobBuilder.Blob blob =
         BlobBuilder.constructBlobAndMetadata(
-            blobPath.fileName, blobData, bdecVersion, paramProvider);
+            blobPath.fileName, blobData, bdecVersion, paramProvider, encryptionKeysPerTable);
 
     blob.blobStats.setBuildDurationMs(buildContext);
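
Design note on the copy above: the flush path snapshots the client's key map before handing it to the asynchronous upload task, going through `EncryptionKey`'s copy constructor entry by entry. Since `EncryptionKey`'s key fields are not final, sharing instances with the live client map could let a concurrent rotation change a key while a blob is being built; the per-entry copy pins the keys for the lifetime of this blob. A sketch of the semantics (illustrative; `client` stands in for the owning client):

```java
Map<FullyQualifiedTableName, EncryptionKey> live = client.getEncryptionKeysPerTable();
Map<FullyQualifiedTableName, EncryptionKey> snapshot = new ConcurrentHashMap<>();
live.forEach((k, v) -> snapshot.put(k, new EncryptionKey(v))); // deep copy per entry
// Rotations applied to `live` after this point do not affect `snapshot`.
```
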
diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/FullyQualifiedTableName.java b/src/main/java/net/snowflake/ingest/streaming/internal/FullyQualifiedTableName.java
new file mode 100644
index 000000000..8ab639123
--- /dev/null
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/FullyQualifiedTableName.java
@@ -0,0 +1,68 @@
+package net.snowflake.ingest.streaming.internal;
+
+import java.util.Objects;
+
+/**
+ * FullyQualifiedTableName is a class that represents a fully qualified table name. It is used to
+ * store the fully qualified table name in the Snowflake format.
+ */
+public class FullyQualifiedTableName {
+  public FullyQualifiedTableName(String databaseName, String schemaName, String tableName) {
+    this.databaseName = databaseName;
+    this.schemaName = schemaName;
+    this.tableName = tableName;
+  }
+
+  // Database name
+  private final String databaseName;
+
+  // Schema name
+  private final String schemaName;
+
+  // Table Name
+  private final String tableName;
+
+  public String getTableName() {
+    return tableName;
+  }
+
+  public String getSchemaName() {
+    return schemaName;
+  }
+
+  public String getDatabaseName() {
+    return databaseName;
+  }
+
+  public String getFullyQualifiedName() {
+    return String.format("%s.%s.%s", databaseName, schemaName, tableName);
+  }
+
+  private int hashCode;
+
+  @Override
+  public int hashCode() {
+    int result = hashCode;
+    if (result == 0) {
+      result = 31 + ((databaseName == null) ? 0 : databaseName.hashCode());
+      result = 31 * result + ((schemaName == null) ? 0 : schemaName.hashCode());
+      result = 31 * result + ((tableName == null) ? 0 : tableName.hashCode());
+      hashCode = result;
+    }
+
+    return result;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) return true;
+
+    if (!(obj instanceof FullyQualifiedTableName)) return false;
+
+    FullyQualifiedTableName other = (FullyQualifiedTableName) obj;
+
+    if (!Objects.equals(databaseName, other.databaseName)) return false;
+    if (!Objects.equals(schemaName, other.schemaName)) return false;
+    return Objects.equals(tableName, other.tableName);
+  }
+}
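
Because `equals`/`hashCode` are value-based, two independently constructed instances address the same map entry; this is what lets `openChannel`, `registerBlobs`, and `BlobBuilder` exchange keys through a shared `ConcurrentHashMap`. The lazily cached `hashCode` is a benign data race: the computed value is deterministic and `int` writes are atomic, so concurrent readers at worst recompute it. For example (illustrative):

```java
Map<FullyQualifiedTableName, EncryptionKey> keys = new ConcurrentHashMap<>();
keys.put(
    new FullyQualifiedTableName("DB", "PUBLIC", "T"),
    new EncryptionKey("DB", "PUBLIC", "T", "key", 1L));
// A fresh instance with the same (case-sensitive) parts hits the same entry.
assert keys.containsKey(new FullyQualifiedTableName("DB", "PUBLIC", "T"));
```

Note that the comparison is case-sensitive, so the names stored at `openChannel` time must match the casing the flush path's channel context uses for lookups to hit.
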
diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/RegisterBlobResponse.java b/src/main/java/net/snowflake/ingest/streaming/internal/RegisterBlobResponse.java
index bcfae3355..c0d67b6cd 100644
--- a/src/main/java/net/snowflake/ingest/streaming/internal/RegisterBlobResponse.java
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/RegisterBlobResponse.java
@@ -12,6 +12,7 @@ class RegisterBlobResponse extends StreamingIngestResponse {
   private Long statusCode;
   private String message;
   private List<BlobRegisterStatus> blobsStatus;
+  private List<EncryptionKey> encryptionKeys;
 
   @JsonProperty("status_code")
   void setStatusCode(Long statusCode) {
@@ -39,4 +40,13 @@ void setBlobsStatus(List<BlobRegisterStatus> blobsStatus) {
   List<BlobRegisterStatus> getBlobsStatus() {
     return this.blobsStatus;
   }
+
+  @JsonProperty("encryption_keys")
+  void setEncryptionKeys(List<EncryptionKey> encryptionKeys) {
+    this.encryptionKeys = encryptionKeys;
+  }
+
+  List<EncryptionKey> getEncryptionKeys() {
+    return this.encryptionKeys;
+  }
 }
diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java
index fd1c0e38a..843b8975a 100644
--- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java
+++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java
@@ -43,6 +43,7 @@ import java.util.Properties;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
@@ -109,7 +110,7 @@ public class SnowflakeStreamingIngestClientInternal<T> implements SnowflakeStrea
   private final FlushService<T> flushService;
 
   // Reference to storage manager
-  private final IStorageManager storageManager;
+  private IStorageManager storageManager;
 
   // Indicates whether the client has closed
   private volatile boolean isClosed;
@@ -120,6 +121,9 @@ public class SnowflakeStreamingIngestClientInternal<T> implements SnowflakeStrea
   // Indicates whether the client is under test mode
   private final boolean isTestMode;
 
+  // Stores encryption key per table: FullyQualifiedTableName -> EncryptionKey
+  private final Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable;
+
   // Performance testing related metrics
   MetricRegistry metrics;
   Histogram blobSizeHistogram; // Histogram for blob size after compression
@@ -176,6 +180,7 @@ public class SnowflakeStreamingIngestClientInternal<T> implements SnowflakeStrea
     this.channelCache = new ChannelCache<>();
     this.isClosed = false;
     this.requestBuilder = requestBuilder;
+    this.encryptionKeysPerTable = new ConcurrentHashMap<>();
 
     if (!isTestMode) {
       // Setup request builder for communication with the server side
@@ -398,6 +403,17 @@ public SnowflakeStreamingIngestChannelInternal<T> openChannel(OpenChannelRequest
         new TableRef(response.getDBName(), response.getSchemaName(), response.getTableName()),
         response.getIcebergLocationInfo());
 
+    // Add encryption key to the client map for the table
+    this.encryptionKeysPerTable.put(
+        new FullyQualifiedTableName(
+            request.getDBName(), request.getSchemaName(), request.getTableName()),
+        new EncryptionKey(
+            response.getDBName(),
+            response.getSchemaName(),
+            response.getTableName(),
+            response.getEncryptionKey(),
+            response.getEncryptionKeyId()));
+
     return channel;
   }
@@ -594,6 +610,18 @@ void registerBlobs(List<BlobMetadata> blobs, final int executionCount) {
             this.name,
             executionCount);
 
+    // Update encryption keys for the table given the response
+    if (response.getEncryptionKeys() == null) {
+      this.encryptionKeysPerTable.clear();
+    } else {
+      for (EncryptionKey key : response.getEncryptionKeys()) {
+        this.encryptionKeysPerTable.put(
+            new FullyQualifiedTableName(
+                key.getDatabaseName(), key.getSchemaName(), key.getTableName()),
+            key);
+      }
+    }
+
     // We will retry any blob chunks that were rejected because internal Snowflake queues are full
     Set<ChunkRegisterStatus> queueFullChunks = new HashSet<>();
     response
@@ -1062,4 +1090,13 @@ private void cleanUpResources() {
       HttpUtil.shutdownHttpConnectionManagerDaemonThread();
     }
   }
+
+  public Map<FullyQualifiedTableName, EncryptionKey> getEncryptionKeysPerTable() {
+    return encryptionKeysPerTable;
+  }
+
+  // TESTING ONLY - inject the storage manager
+  public void setStorageManager(IStorageManager storageManager) {
+    this.storageManager = storageManager;
+  }
 }
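
Note the rotation semantics in `registerBlobs` above: a response without an `encryption_keys` field (deserialized as `null`) clears every cached key, so the next flush falls back to the channel-context keys in `BlobBuilder`; an empty array leaves the cache untouched (the loop body never runs); a populated array upserts the rotated keys so subsequent blobs encrypt with them. Summarized below (illustrative snippet; `rotatedKeys` is a placeholder list, and `setEncryptionKeys` is the package-private setter the tests below also use):

```java
RegisterBlobResponse resp = new RegisterBlobResponse();
resp.setEncryptionKeys(null);              // field absent: cache cleared, BlobBuilder falls back
resp.setEncryptionKeys(new ArrayList<>()); // empty list: cache left as-is
resp.setEncryptionKeys(rotatedKeys);       // populated: entries upserted per table
```
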
diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java
index 185fa5ded..8a6190a71 100644
--- a/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java
+++ b/src/test/java/net/snowflake/ingest/streaming/internal/BlobBuilderTest.java
@@ -9,7 +9,9 @@
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
 import net.snowflake.ingest.utils.Constants;
 import net.snowflake.ingest.utils.ErrorCode;
 import net.snowflake.ingest.utils.Pair;
@@ -33,12 +35,18 @@ public static Object[] isIceberg() {
 
   @Test
   public void testSerializationErrors() throws Exception {
+    Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable = new ConcurrentHashMap<>();
+    encryptionKeysPerTable.put(
+        new FullyQualifiedTableName("DB", "SCHEMA", "TABLE"),
+        new EncryptionKey("DB", "SCHEMA", "TABLE", "KEY", 1234L));
+
     // Construction succeeds if both data and metadata contain 1 row
     BlobBuilder.constructBlobAndMetadata(
         "a.bdec",
         Collections.singletonList(createChannelDataPerTable(1)),
         Constants.BdecVersion.THREE,
-        new InternalParameterProvider(isIceberg));
+        new InternalParameterProvider(isIceberg),
+        encryptionKeysPerTable);
 
     // Construction fails if metadata contains 0 rows and data 1 row
     try {
@@ -46,7 +54,8 @@ public void testSerializationErrors() throws Exception {
           "a.bdec",
           Collections.singletonList(createChannelDataPerTable(0)),
           Constants.BdecVersion.THREE,
-          new InternalParameterProvider(isIceberg));
+          new InternalParameterProvider(isIceberg),
+          encryptionKeysPerTable);
     } catch (SFException e) {
       Assert.assertEquals(ErrorCode.INTERNAL_ERROR.getMessageCode(), e.getVendorCode());
       Assert.assertTrue(e.getMessage().contains("parquetTotalRowsInFooter=1"));
diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java
index b5ed0ba96..f05f34009 100644
--- a/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java
+++ b/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java
@@ -40,6 +40,7 @@
 import java.util.Map;
 import java.util.TimeZone;
 import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.IntStream;
@@ -99,6 +100,7 @@ private abstract static class TestContext<T> implements AutoCloseable {
     InternalStage storage;
     ExternalVolume extVolume;
     ParameterProvider parameterProvider;
+    Map<FullyQualifiedTableName, EncryptionKey> encryptionKeysPerTable;
     RegisterService registerService;
 
     final List<ChannelData<T>> channelData = new ArrayList<>();
@@ -124,6 +126,25 @@ private abstract static class TestContext<T> implements AutoCloseable {
       Mockito.when(storageManager.getClientPrefix()).thenReturn("client_prefix");
       Mockito.when(client.getParameterProvider())
           .thenAnswer((Answer<ParameterProvider>) (i) -> parameterProvider);
+
+      encryptionKeysPerTable = new ConcurrentHashMap<>();
+      if (isIcebergMode) {
+        encryptionKeysPerTable.put(
+            new FullyQualifiedTableName("db1", "schema1", "table1"),
+            new EncryptionKey("db1", "schema1", "table1", "key1", 1234L));
+        encryptionKeysPerTable.put(
+            new FullyQualifiedTableName("db2", "schema1", "table2"),
+            new EncryptionKey("db2", "schema1", "table2", "key1", 1234L));
+
+        for (int i = 0; i <= 9999; i++) {
+          encryptionKeysPerTable.put(
+              new FullyQualifiedTableName("db1", "PUBLIC", String.format("table%d", i)),
+              new EncryptionKey("db1", "PUBLIC", String.format("table%d", i), "key1", 1234L));
+        }
+
+        Mockito.when(client.getEncryptionKeysPerTable()).thenReturn(encryptionKeysPerTable);
+      }
+
       channelCache = new ChannelCache<>();
       Mockito.when(client.getChannelCache()).thenReturn(channelCache);
       registerService = Mockito.spy(new RegisterService(client, client.isTestMode()));
@@ -147,7 +168,8 @@ BlobMetadata buildAndUpload() throws Exception {
       return flushService.buildAndUpload(
           BlobPath.fileNameWithoutToken("file_name.bdec"),
           blobData,
-          blobData.get(0).get(0).getChannelContext().getFullyQualifiedTableName());
+          blobData.get(0).get(0).getChannelContext().getFullyQualifiedTableName(),
+          encryptionKeysPerTable);
     }
 
     abstract SnowflakeStreamingIngestChannelInternal<T> createChannel(
@@ -633,7 +655,7 @@ public void testBlobCreation() throws Exception {
     if (!isIcebergMode) {
       flushService.flush(true).get();
       Mockito.verify(flushService, Mockito.atLeast(2))
-          .buildAndUpload(Mockito.any(), Mockito.any(), Mockito.any());
+          .buildAndUpload(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any());
     }
   }
@@ -686,7 +708,7 @@ public void testBlobSplitDueToDifferentSchema() throws Exception {
       // Force = true flushes
       flushService.flush(true).get();
       Mockito.verify(flushService, Mockito.atLeast(2))
-          .buildAndUpload(Mockito.any(), Mockito.any(), Mockito.any());
+          .buildAndUpload(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any());
     }
   }
@@ -724,7 +746,7 @@ public void testBlobSplitDueToChunkSizeLimit() throws Exception {
       // Force = true flushes
       flushService.flush(true).get();
       Mockito.verify(flushService, Mockito.times(2))
-          .buildAndUpload(Mockito.any(), Mockito.any(), Mockito.any());
+          .buildAndUpload(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any());
     }
   }
@@ -772,7 +794,7 @@ public void runTestBlobSplitDueToNumberOfChunks(int numberOfRows) throws Excepti
     ArgumentCaptor<List<List<ChannelData<List<List<Object>>>>>> blobDataCaptor =
         ArgumentCaptor.forClass(List.class);
     Mockito.verify(flushService, Mockito.times(expectedBlobs))
-        .buildAndUpload(Mockito.any(), blobDataCaptor.capture(), Mockito.any());
+        .buildAndUpload(Mockito.any(), blobDataCaptor.capture(), Mockito.any(), Mockito.any());
 
     // 1. list => blobs; 2. list => chunks; 3. list => channels; 4. list => rows, 5. list => columns
     List<List<List<List<List<Object>>>>> allUploadedBlobs =
@@ -820,7 +842,7 @@ public void testBlobSplitDueToNumberOfChunksWithLeftoverChannels() throws Except
     ArgumentCaptor<List<List<ChannelData<List<List<Object>>>>>> blobDataCaptor =
         ArgumentCaptor.forClass(List.class);
     Mockito.verify(flushService, Mockito.atLeast(2))
-        .buildAndUpload(Mockito.any(), blobDataCaptor.capture(), Mockito.any());
+        .buildAndUpload(Mockito.any(), blobDataCaptor.capture(), Mockito.any(), Mockito.any());
 
     // 1. list => blobs; 2. list => chunks; 3. list => channels; 4. list => rows, 5. list => columns
     List<List<List<List<List<Object>>>>> allUploadedBlobs =
diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java
index b3e2d23d6..b689d162a 100644
--- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java
+++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java
@@ -913,4 +913,46 @@ public void testGetLatestCommittedOffsetToken() {
       Assert.assertEquals(ErrorCode.CHANNEL_STATUS_INVALID.getMessageCode(), e.getVendorCode());
     }
   }
+
+  @Test
+  public void testOpenChannelWithEncryptionKey() throws Exception {
+    // TODO: SNOW-1490151 Iceberg testing gaps
+    if (isIcebergMode) {
+      return;
+    }
+
+    String response =
+        "{\n"
+            + "  \"status_code\" : 0,\n"
+            + "  \"message\" : \"Success\",\n"
+            + "  \"encryption_key\" : \"key\",\n"
+            + "  \"encryption_key_id\" : 1,\n"
+            + "  \"database\" : \"db\",\n"
+            + "  \"schema\" : \"schema\",\n"
+            + "  \"table\" : \"table\",\n"
+            + "  \"channel\" : \"channel\",\n"
+            + "  \"row_sequencer\" : 0,\n"
+            + "  \"table_columns\" : [],\n"
+            + "  \"client_sequencer\" : 0\n"
+            + "}";
+
+    apiOverride.addSerializedJsonOverride(
+        OPEN_CHANNEL_ENDPOINT, request -> Pair.of(HttpStatus.SC_OK, response));
+
+    OpenChannelRequest request =
+        OpenChannelRequest.builder("channel")
+            .setDBName("db")
+            .setSchemaName("schema")
+            .setTableName("table")
+            .setOnErrorOption(OpenChannelRequest.OnErrorOption.CONTINUE)
+            .build();
+    client.openChannel(request);
+
+    FullyQualifiedTableName fqn = new FullyQualifiedTableName("db", "schema", "table");
+    Map<FullyQualifiedTableName, EncryptionKey> keys = client.getEncryptionKeysPerTable();
+    Assert.assertEquals(1, keys.size());
+    Assert.assertTrue(keys.containsKey(fqn));
+    Assert.assertEquals("key", keys.get(fqn).getEncryptionKey());
+    Assert.assertEquals(1, keys.get(fqn).getEncryptionKeyId());
+  }
 }
diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java
index 0dbeeebee..fb9cfb253 100644
--- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java
+++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java
@@ -5,16 +5,7 @@
 package net.snowflake.ingest.streaming.internal;
 
 import static java.time.ZoneOffset.UTC;
-import static net.snowflake.ingest.utils.Constants.ACCOUNT_URL;
-import static net.snowflake.ingest.utils.Constants.CHANNEL_STATUS_ENDPOINT;
-import static net.snowflake.ingest.utils.Constants.DROP_CHANNEL_ENDPOINT;
-import static net.snowflake.ingest.utils.Constants.MAX_STREAMING_INGEST_API_CHANNEL_RETRY;
-import static net.snowflake.ingest.utils.Constants.PRIVATE_KEY;
-import static net.snowflake.ingest.utils.Constants.REGISTER_BLOB_ENDPOINT;
-import static net.snowflake.ingest.utils.Constants.RESPONSE_ERR_ENQUEUE_TABLE_CHUNK_QUEUE_FULL;
-import static net.snowflake.ingest.utils.Constants.RESPONSE_SUCCESS;
-import static net.snowflake.ingest.utils.Constants.ROLE;
-import static net.snowflake.ingest.utils.Constants.USER;
+import static net.snowflake.ingest.utils.Constants.*;
 import static net.snowflake.ingest.utils.ParameterProvider.ENABLE_SNOWPIPE_STREAMING_METRICS;
 import static org.junit.Assert.assertEquals;
 import static org.mockito.Mockito.when;
@@ -719,6 +710,7 @@ public void testRegisterBlobSuccessResponse() throws Exception {
         "{\n"
             + "  \"status_code\" : 0,\n"
             + "  \"message\" : \"Success\",\n"
+            + "  \"encryption_keys\": [],\n"
             + "  \"blobs\" : [ {\n"
             + "    \"chunks\" : [ {\n"
             + "      \"database\" : \"DB_STREAMINGINGEST\",\n"
@@ -773,10 +765,12 @@ public void testRegisterBlobsRetries() throws Exception {
     RegisterBlobResponse initialResponse = new RegisterBlobResponse();
     initialResponse.setMessage("successish");
     initialResponse.setStatusCode(RESPONSE_SUCCESS);
+    initialResponse.setEncryptionKeys(new ArrayList<>());
 
     RegisterBlobResponse retryResponse = new RegisterBlobResponse();
     retryResponse.setMessage("successish");
     retryResponse.setStatusCode(RESPONSE_SUCCESS);
+    retryResponse.setEncryptionKeys(new ArrayList<>());
 
     List<BlobRegisterStatus> blobRegisterStatuses = new ArrayList<>();
     BlobRegisterStatus blobRegisterStatus1 = new BlobRegisterStatus();
@@ -941,10 +935,12 @@ public void testRegisterBlobsRetriesSucceeds() throws Exception {
     RegisterBlobResponse initialResponse = new RegisterBlobResponse();
     initialResponse.setMessage("successish");
     initialResponse.setStatusCode(RESPONSE_SUCCESS);
+    initialResponse.setEncryptionKeys(new ArrayList<>());
 
     RegisterBlobResponse retryResponse = new RegisterBlobResponse();
     retryResponse.setMessage("successish");
     retryResponse.setStatusCode(RESPONSE_SUCCESS);
+    retryResponse.setEncryptionKeys(new ArrayList<>());
 
     List<BlobRegisterStatus> blobRegisterStatuses = new ArrayList<>();
     BlobRegisterStatus blobRegisterStatus1 = new BlobRegisterStatus();
@@ -1018,6 +1014,7 @@ public void testRegisterBlobResponseWithInvalidChannel() throws Exception {
         "{\n"
             + "  \"status_code\" : 0,\n"
             + "  \"message\" : \"Success\",\n"
+            + "  \"encryption_keys\": [],\n"
             + "  \"blobs\" : [ {\n"
             + "    \"chunks\" : [ {\n"
             + "      \"database\" : \"%s\",\n"
@@ -1088,6 +1085,58 @@ public void testRegisterBlobResponseWithInvalidChannel() throws Exception {
     Assert.assertFalse(channel2.isValid());
   }
 
+  @Test
+  public void testRegisterBlobSuccessResponseWithEncryptionKeys() throws Exception {
+    String response =
+        "{\n"
+            + "  \"status_code\" : 0,\n"
+            + "  \"message\" : \"Success\",\n"
+            + "  \"encryption_keys\": [\n"
+            + "    {\n"
+            + "      \"database\" : \"DB_STREAMINGINGEST\",\n"
+            + "      \"schema\" : \"PUBLIC\",\n"
+            + "      \"table\" : \"T_STREAMINGINGEST\",\n"
+            + "      \"encryption_key\" : \"key\",\n"
+            + "      \"encryption_key_id\" : 1234\n"
+            + "    }\n"
+            + "  ],\n"
+            + "  \"blobs\" : [ {\n"
+            + "    \"chunks\" : [ {\n"
+            + "      \"database\" : \"DB_STREAMINGINGEST\",\n"
+            + "      \"schema\" : \"PUBLIC\",\n"
+            + "      \"table\" : \"T_STREAMINGINGEST\",\n"
+            + "      \"channels\" : [ {\n"
+            + "        \"status_code\" : 0,\n"
+            + "        \"channel\" : \"CHANNEL\",\n"
+            + "        \"client_sequencer\" : 0\n"
+            + "      }, {\n"
+            + "        \"status_code\" : 0,\n"
+            + "        \"channel\" : \"CHANNEL1\",\n"
+            + "        \"client_sequencer\" : 0\n"
+            + "      } ]\n"
+            + "    } ]\n"
+            + "  } ]\n"
+            + "}";
+
+    apiOverride.addSerializedJsonOverride(
+        REGISTER_BLOB_ENDPOINT, request -> Pair.of(HttpStatus.SC_OK, response));
+
+    List<BlobMetadata> blobs =
+        Collections.singletonList(new BlobMetadata("path", "md5", new ArrayList<>(), null));
+
+    FullyQualifiedTableName fqn =
+        new FullyQualifiedTableName("DB_STREAMINGINGEST", "PUBLIC", "T_STREAMINGINGEST");
+    client.registerBlobs(blobs);
+    Assert.assertEquals(1, client.getEncryptionKeysPerTable().size());
+    Assert.assertEquals(
+        "DB_STREAMINGINGEST", client.getEncryptionKeysPerTable().get(fqn).getDatabaseName());
+    Assert.assertEquals("PUBLIC", client.getEncryptionKeysPerTable().get(fqn).getSchemaName());
+    Assert.assertEquals(
+        "T_STREAMINGINGEST", client.getEncryptionKeysPerTable().get(fqn).getTableName());
+    Assert.assertEquals("key", client.getEncryptionKeysPerTable().get(fqn).getEncryptionKey());
+    Assert.assertEquals(1234, client.getEncryptionKeysPerTable().get(fqn).getEncryptionKeyId());
+  }
+
   @Test
   public void testFlush() throws Exception {
     client.flush(false).get();