Skip to content

Commit

Permalink
Fix DDB client tenant id handling
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Sep 7, 2024
1 parent 1096e47 commit 0ef79ec
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ private void checkForModelsUsingConnector(String connectorId, String tenantId, A
SearchResponse searchResponse = SearchResponse.fromXContent(sr.parser());
SearchHit[] searchHits = searchResponse.getHits().getHits();
if (searchHits.length == 0) {
deleteConnector(connectorId, restoringListener);
deleteConnector(connectorId, tenantId, restoringListener);
} else {
handleModelsUsingConnector(searchHits, connectorId, restoringListener);
}
Expand All @@ -153,7 +153,7 @@ private void checkForModelsUsingConnector(String connectorId, String tenantId, A
}
} else {
Exception cause = SdkClientUtils.unwrapAndConvertToException(st);
handleSearchFailure(connectorId, cause, restoringListener);
handleSearchFailure(connectorId, tenantId, cause, restoringListener);
}
});
} catch (Exception e) {
Expand All @@ -179,21 +179,21 @@ private void handleModelsUsingConnector(SearchHit[] searchHits, String connector
);
}

private void handleSearchFailure(String connectorId, Exception cause, ActionListener<DeleteResponse> actionListener) {
private void handleSearchFailure(String connectorId, String tenantId, Exception cause, ActionListener<DeleteResponse> actionListener) {
if (cause instanceof IndexNotFoundException) {
deleteConnector(connectorId, actionListener);
deleteConnector(connectorId, tenantId, actionListener);
return;
}
log.error("Failed to search for models using connector: {}", connectorId, cause);
actionListener.onFailure(cause);
}

private void deleteConnector(String connectorId, ActionListener<DeleteResponse> actionListener) {
private void deleteConnector(String connectorId, String tenantId, ActionListener<DeleteResponse> actionListener) {
DeleteRequest deleteRequest = new DeleteRequest(ML_CONNECTOR_INDEX, connectorId);
try {
sdkClient
.deleteDataObjectAsync(
DeleteDataObjectRequest.builder().index(deleteRequest.index()).id(deleteRequest.id()).build(),
DeleteDataObjectRequest.builder().index(deleteRequest.index()).id(deleteRequest.id()).tenantId(tenantId).build(),
client.threadPool().executor(GENERAL_THREAD_POOL)
)
.whenComplete((response, throwable) -> handleDeleteResponse(response, throwable, connectorId, actionListener));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ public class DDBOpenSearchClient implements SdkClientDelegate {
private static final Long DEFAULT_SEQUENCE_NUMBER = 0L;
private static final Long DEFAULT_PRIMARY_TERM = 1L;
private static final String RANGE_KEY = "_id";
private static final String HASH_KEY = "_tenant_id";

private static final String SOURCE = "_source";
private static final String SEQ_NO_KEY = "_seq_no";
Expand Down Expand Up @@ -130,7 +131,7 @@ public CompletionStage<PutDataObjectResponse> putDataObjectAsync(
sourceMap.put(TENANT_ID, AttributeValue.builder().s(tenantId).build());
}
Map<String, AttributeValue> item = new HashMap<>();
item.put(TENANT_ID, AttributeValue.builder().s(tenantId).build());
item.put(HASH_KEY, AttributeValue.builder().s(tenantId).build());
item.put(RANGE_KEY, AttributeValue.builder().s(id).build());
item.put(SOURCE, AttributeValue.builder().m(sourceMap).build());
item.put(SEQ_NO_KEY, AttributeValue.builder().n(sequenceNumber.toString()).build());
Expand Down Expand Up @@ -232,10 +233,10 @@ public CompletionStage<UpdateDataObjectResponse> updateDataObjectAsync(
String source = Strings.toString(MediaTypeRegistry.JSON, request.dataObject());
JsonNode jsonNode = OBJECT_MAPPER.readTree(source);
Map<String, AttributeValue> updateItem = JsonTransformer.convertJsonObjectToDDBAttributeMap(jsonNode);
updateItem.remove(TENANT_ID);
updateItem.remove(HASH_KEY);
updateItem.remove(RANGE_KEY);
Map<String, AttributeValue> updateKey = new HashMap<>();
updateKey.put(TENANT_ID, AttributeValue.builder().s(tenantId).build());
updateKey.put(HASH_KEY, AttributeValue.builder().s(tenantId).build());
updateKey.put(RANGE_KEY, AttributeValue.builder().s(request.id()).build());
UpdateItemRequest.Builder updateItemRequestBuilder = UpdateItemRequest.builder().tableName(request.index()).key(updateKey);
Map<String, String> expressionAttributeNames = new HashMap<>();
Expand Down Expand Up @@ -305,7 +306,7 @@ public CompletionStage<DeleteDataObjectResponse> deleteDataObjectAsync(
.key(
Map
.ofEntries(
Map.entry(TENANT_ID, AttributeValue.builder().s(tenantId).build()),
Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()),
Map.entry(RANGE_KEY, AttributeValue.builder().s(request.id()).build())
)
)
Expand Down Expand Up @@ -371,7 +372,7 @@ private GetItemRequest buildGetItemRequest(String requestTenantId, String docume
.key(
Map
.ofEntries(
Map.entry(TENANT_ID, AttributeValue.builder().s(tenantId).build()),
Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()),
Map.entry(RANGE_KEY, AttributeValue.builder().s(documentId).build())
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
public class DDBOpenSearchClientTests extends OpenSearchTestCase {

private static final String RANGE_KEY = "_id";
private static final String HASH_KEY = "_tenant_id";
private static final String SEQ_NUM = "_seq_no";

private static final String TEST_ID = "123";
Expand Down Expand Up @@ -159,7 +160,7 @@ public void testPutDataObject_HappyCase() throws IOException {
PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue();
Assert.assertEquals(TEST_INDEX, putItemRequest.tableName());
Assert.assertEquals(TEST_ID, putItemRequest.item().get(RANGE_KEY).s());
Assert.assertEquals(TENANT_ID, putItemRequest.item().get(CommonValue.TENANT_ID).s());
Assert.assertEquals(TENANT_ID, putItemRequest.item().get(HASH_KEY).s());
Assert.assertEquals("0", putItemRequest.item().get(SEQ_NUM).n());
Assert.assertEquals("foo", putItemRequest.item().get("_source").m().get("data").s());
}
Expand Down Expand Up @@ -248,7 +249,7 @@ public void testPutDataObject_NullTenantId_SetsDefaultTenantId() throws IOExcept
Mockito.verify(dynamoDbClient).putItem(putItemRequestArgumentCaptor.capture());

PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue();
Assert.assertEquals("DEFAULT_TENANT", putItemRequest.item().get(CommonValue.TENANT_ID).s());
Assert.assertEquals("DEFAULT_TENANT", putItemRequest.item().get(HASH_KEY).s());
Assert.assertNull(putItemRequest.item().get("_source").m().get(CommonValue.TENANT_ID));
}

Expand Down Expand Up @@ -307,7 +308,7 @@ public void testGetDataObject_HappyCase() throws IOException {
Mockito.verify(dynamoDbClient).getItem(getItemRequestArgumentCaptor.capture());
GetItemRequest getItemRequest = getItemRequestArgumentCaptor.getValue();
Assert.assertEquals(TEST_INDEX, getItemRequest.tableName());
Assert.assertEquals(TENANT_ID, getItemRequest.key().get(CommonValue.TENANT_ID).s());
Assert.assertEquals(TENANT_ID, getItemRequest.key().get(HASH_KEY).s());
Assert.assertEquals(TEST_ID, getItemRequest.key().get(RANGE_KEY).s());
Assert.assertEquals(TEST_ID, response.id());
Assert.assertEquals("foo", response.source().get("data"));
Expand Down Expand Up @@ -379,7 +380,7 @@ public void testGetDataObject_UseDefaultTenantIdIfNull() throws IOException {
sdkClient.getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)).toCompletableFuture().join();
Mockito.verify(dynamoDbClient).getItem(getItemRequestArgumentCaptor.capture());
GetItemRequest getItemRequest = getItemRequestArgumentCaptor.getValue();
Assert.assertEquals("DEFAULT_TENANT", getItemRequest.key().get(CommonValue.TENANT_ID).s());
Assert.assertEquals("DEFAULT_TENANT", getItemRequest.key().get(HASH_KEY).s());
}

@Test
Expand All @@ -405,7 +406,7 @@ public void testDeleteDataObject_HappyCase() throws IOException {
.join();
DeleteItemRequest deleteItemRequest = deleteItemRequestArgumentCaptor.getValue();
Assert.assertEquals(TEST_INDEX, deleteItemRequest.tableName());
Assert.assertEquals(TENANT_ID, deleteItemRequest.key().get(CommonValue.TENANT_ID).s());
Assert.assertEquals(TENANT_ID, deleteItemRequest.key().get(HASH_KEY).s());
Assert.assertEquals(TEST_ID, deleteItemRequest.key().get(RANGE_KEY).s());
Assert.assertEquals(TEST_ID, deleteResponse.id());

Expand All @@ -425,7 +426,7 @@ public void testDeleteDataObject_NullTenantId_UsesDefaultTenantId() {
Mockito.when(dynamoDbClient.deleteItem(deleteItemRequestArgumentCaptor.capture())).thenReturn(DeleteItemResponse.builder().build());
sdkClient.deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)).toCompletableFuture().join();
DeleteItemRequest deleteItemRequest = deleteItemRequestArgumentCaptor.getValue();
Assert.assertEquals("DEFAULT_TENANT", deleteItemRequest.key().get(CommonValue.TENANT_ID).s());
Assert.assertEquals("DEFAULT_TENANT", deleteItemRequest.key().get(HASH_KEY).s());
}

@Test
Expand All @@ -447,7 +448,7 @@ public void updateDataObjectAsync_HappyCase() {
assertEquals(TEST_ID, updateRequest.id());
assertEquals(TEST_INDEX, updateItemRequest.tableName());
assertEquals(TEST_ID, updateItemRequest.key().get(RANGE_KEY).s());
assertEquals(TENANT_ID, updateItemRequest.key().get(CommonValue.TENANT_ID).s());
assertEquals(TENANT_ID, updateItemRequest.key().get(HASH_KEY).s());
assertEquals("foo", updateItemRequest.expressionAttributeValues().get(":source").m().get("data").s());

}
Expand All @@ -474,7 +475,7 @@ public void updateDataObjectAsync_HappyCaseWithMap() throws Exception {
UpdateItemRequest updateItemRequest = updateItemRequestArgumentCaptor.getValue();
assertEquals(TEST_INDEX, updateItemRequest.tableName());
assertEquals(TEST_ID, updateItemRequest.key().get(RANGE_KEY).s());
assertEquals(TENANT_ID, updateItemRequest.key().get(CommonValue.TENANT_ID).s());
assertEquals(TENANT_ID, updateItemRequest.key().get(HASH_KEY).s());
assertTrue(updateItemRequest.expressionAttributeNames().containsKey("#seqNo"));
assertTrue(updateItemRequest.expressionAttributeNames().containsKey("#source"));
assertTrue(updateItemRequest.expressionAttributeValues().containsKey(":incr"));
Expand All @@ -498,7 +499,7 @@ public void updateDataObjectAsync_NullTenantId_UsesDefaultTenantId() {
Mockito.when(dynamoDbClient.updateItem(updateItemRequestArgumentCaptor.capture())).thenReturn(UpdateItemResponse.builder().build());
sdkClient.updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)).toCompletableFuture().join();
UpdateItemRequest updateItemRequest = updateItemRequestArgumentCaptor.getValue();
assertEquals(TENANT_ID, updateItemRequest.key().get(CommonValue.TENANT_ID).s());
assertEquals(TENANT_ID, updateItemRequest.key().get(HASH_KEY).s());
}

public void testUpdateDataObject_VersionCheck() throws IOException {
Expand Down

0 comments on commit 0ef79ec

Please sign in to comment.