diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java index ddef704bfa..a5e3c323cd 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java @@ -140,7 +140,7 @@ private void checkForModelsUsingConnector(String connectorId, String tenantId, A SearchResponse searchResponse = SearchResponse.fromXContent(sr.parser()); SearchHit[] searchHits = searchResponse.getHits().getHits(); if (searchHits.length == 0) { - deleteConnector(connectorId, restoringListener); + deleteConnector(connectorId, tenantId, restoringListener); } else { handleModelsUsingConnector(searchHits, connectorId, restoringListener); } @@ -153,7 +153,7 @@ private void checkForModelsUsingConnector(String connectorId, String tenantId, A } } else { Exception cause = SdkClientUtils.unwrapAndConvertToException(st); - handleSearchFailure(connectorId, cause, restoringListener); + handleSearchFailure(connectorId, tenantId, cause, restoringListener); } }); } catch (Exception e) { @@ -179,21 +179,21 @@ private void handleModelsUsingConnector(SearchHit[] searchHits, String connector ); } - private void handleSearchFailure(String connectorId, Exception cause, ActionListener actionListener) { + private void handleSearchFailure(String connectorId, String tenantId, Exception cause, ActionListener actionListener) { if (cause instanceof IndexNotFoundException) { - deleteConnector(connectorId, actionListener); + deleteConnector(connectorId, tenantId, actionListener); return; } log.error("Failed to search for models using connector: {}", connectorId, cause); actionListener.onFailure(cause); } - private void deleteConnector(String connectorId, ActionListener actionListener) { + private void deleteConnector(String connectorId, String tenantId, ActionListener actionListener) { DeleteRequest deleteRequest = new DeleteRequest(ML_CONNECTOR_INDEX, connectorId); try { sdkClient .deleteDataObjectAsync( - DeleteDataObjectRequest.builder().index(deleteRequest.index()).id(deleteRequest.id()).build(), + DeleteDataObjectRequest.builder().index(deleteRequest.index()).id(deleteRequest.id()).tenantId(tenantId).build(), client.threadPool().executor(GENERAL_THREAD_POOL) ) .whenComplete((response, throwable) -> handleDeleteResponse(response, throwable, connectorId, actionListener)); diff --git a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java index 7bbb3b4993..dd81a88179 100644 --- a/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java +++ b/plugin/src/main/java/org/opensearch/ml/sdkclient/DDBOpenSearchClient.java @@ -80,6 +80,7 @@ public class DDBOpenSearchClient implements SdkClientDelegate { private static final Long DEFAULT_SEQUENCE_NUMBER = 0L; private static final Long DEFAULT_PRIMARY_TERM = 1L; private static final String RANGE_KEY = "_id"; + private static final String HASH_KEY = "_tenant_id"; private static final String SOURCE = "_source"; private static final String SEQ_NO_KEY = "_seq_no"; @@ -130,7 +131,7 @@ public CompletionStage putDataObjectAsync( sourceMap.put(TENANT_ID, AttributeValue.builder().s(tenantId).build()); } Map item = new HashMap<>(); - item.put(TENANT_ID, AttributeValue.builder().s(tenantId).build()); + item.put(HASH_KEY, AttributeValue.builder().s(tenantId).build()); item.put(RANGE_KEY, AttributeValue.builder().s(id).build()); item.put(SOURCE, AttributeValue.builder().m(sourceMap).build()); item.put(SEQ_NO_KEY, AttributeValue.builder().n(sequenceNumber.toString()).build()); @@ -232,10 +233,10 @@ public CompletionStage updateDataObjectAsync( String source = Strings.toString(MediaTypeRegistry.JSON, request.dataObject()); JsonNode jsonNode = OBJECT_MAPPER.readTree(source); Map updateItem = JsonTransformer.convertJsonObjectToDDBAttributeMap(jsonNode); - updateItem.remove(TENANT_ID); + updateItem.remove(HASH_KEY); updateItem.remove(RANGE_KEY); Map updateKey = new HashMap<>(); - updateKey.put(TENANT_ID, AttributeValue.builder().s(tenantId).build()); + updateKey.put(HASH_KEY, AttributeValue.builder().s(tenantId).build()); updateKey.put(RANGE_KEY, AttributeValue.builder().s(request.id()).build()); UpdateItemRequest.Builder updateItemRequestBuilder = UpdateItemRequest.builder().tableName(request.index()).key(updateKey); Map expressionAttributeNames = new HashMap<>(); @@ -305,7 +306,7 @@ public CompletionStage deleteDataObjectAsync( .key( Map .ofEntries( - Map.entry(TENANT_ID, AttributeValue.builder().s(tenantId).build()), + Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), Map.entry(RANGE_KEY, AttributeValue.builder().s(request.id()).build()) ) ) @@ -371,7 +372,7 @@ private GetItemRequest buildGetItemRequest(String requestTenantId, String docume .key( Map .ofEntries( - Map.entry(TENANT_ID, AttributeValue.builder().s(tenantId).build()), + Map.entry(HASH_KEY, AttributeValue.builder().s(tenantId).build()), Map.entry(RANGE_KEY, AttributeValue.builder().s(documentId).build()) ) ) diff --git a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java index 2288441b94..0b9e97f25f 100644 --- a/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java +++ b/plugin/src/test/java/org/opensearch/ml/sdkclient/DDBOpenSearchClientTests.java @@ -83,6 +83,7 @@ public class DDBOpenSearchClientTests extends OpenSearchTestCase { private static final String RANGE_KEY = "_id"; + private static final String HASH_KEY = "_tenant_id"; private static final String SEQ_NUM = "_seq_no"; private static final String TEST_ID = "123"; @@ -159,7 +160,7 @@ public void testPutDataObject_HappyCase() throws IOException { PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue(); Assert.assertEquals(TEST_INDEX, putItemRequest.tableName()); Assert.assertEquals(TEST_ID, putItemRequest.item().get(RANGE_KEY).s()); - Assert.assertEquals(TENANT_ID, putItemRequest.item().get(CommonValue.TENANT_ID).s()); + Assert.assertEquals(TENANT_ID, putItemRequest.item().get(HASH_KEY).s()); Assert.assertEquals("0", putItemRequest.item().get(SEQ_NUM).n()); Assert.assertEquals("foo", putItemRequest.item().get("_source").m().get("data").s()); } @@ -248,7 +249,7 @@ public void testPutDataObject_NullTenantId_SetsDefaultTenantId() throws IOExcept Mockito.verify(dynamoDbClient).putItem(putItemRequestArgumentCaptor.capture()); PutItemRequest putItemRequest = putItemRequestArgumentCaptor.getValue(); - Assert.assertEquals("DEFAULT_TENANT", putItemRequest.item().get(CommonValue.TENANT_ID).s()); + Assert.assertEquals("DEFAULT_TENANT", putItemRequest.item().get(HASH_KEY).s()); Assert.assertNull(putItemRequest.item().get("_source").m().get(CommonValue.TENANT_ID)); } @@ -307,7 +308,7 @@ public void testGetDataObject_HappyCase() throws IOException { Mockito.verify(dynamoDbClient).getItem(getItemRequestArgumentCaptor.capture()); GetItemRequest getItemRequest = getItemRequestArgumentCaptor.getValue(); Assert.assertEquals(TEST_INDEX, getItemRequest.tableName()); - Assert.assertEquals(TENANT_ID, getItemRequest.key().get(CommonValue.TENANT_ID).s()); + Assert.assertEquals(TENANT_ID, getItemRequest.key().get(HASH_KEY).s()); Assert.assertEquals(TEST_ID, getItemRequest.key().get(RANGE_KEY).s()); Assert.assertEquals(TEST_ID, response.id()); Assert.assertEquals("foo", response.source().get("data")); @@ -379,7 +380,7 @@ public void testGetDataObject_UseDefaultTenantIdIfNull() throws IOException { sdkClient.getDataObjectAsync(getRequest, testThreadPool.executor(GENERAL_THREAD_POOL)).toCompletableFuture().join(); Mockito.verify(dynamoDbClient).getItem(getItemRequestArgumentCaptor.capture()); GetItemRequest getItemRequest = getItemRequestArgumentCaptor.getValue(); - Assert.assertEquals("DEFAULT_TENANT", getItemRequest.key().get(CommonValue.TENANT_ID).s()); + Assert.assertEquals("DEFAULT_TENANT", getItemRequest.key().get(HASH_KEY).s()); } @Test @@ -405,7 +406,7 @@ public void testDeleteDataObject_HappyCase() throws IOException { .join(); DeleteItemRequest deleteItemRequest = deleteItemRequestArgumentCaptor.getValue(); Assert.assertEquals(TEST_INDEX, deleteItemRequest.tableName()); - Assert.assertEquals(TENANT_ID, deleteItemRequest.key().get(CommonValue.TENANT_ID).s()); + Assert.assertEquals(TENANT_ID, deleteItemRequest.key().get(HASH_KEY).s()); Assert.assertEquals(TEST_ID, deleteItemRequest.key().get(RANGE_KEY).s()); Assert.assertEquals(TEST_ID, deleteResponse.id()); @@ -425,7 +426,7 @@ public void testDeleteDataObject_NullTenantId_UsesDefaultTenantId() { Mockito.when(dynamoDbClient.deleteItem(deleteItemRequestArgumentCaptor.capture())).thenReturn(DeleteItemResponse.builder().build()); sdkClient.deleteDataObjectAsync(deleteRequest, testThreadPool.executor(GENERAL_THREAD_POOL)).toCompletableFuture().join(); DeleteItemRequest deleteItemRequest = deleteItemRequestArgumentCaptor.getValue(); - Assert.assertEquals("DEFAULT_TENANT", deleteItemRequest.key().get(CommonValue.TENANT_ID).s()); + Assert.assertEquals("DEFAULT_TENANT", deleteItemRequest.key().get(HASH_KEY).s()); } @Test @@ -447,7 +448,7 @@ public void updateDataObjectAsync_HappyCase() { assertEquals(TEST_ID, updateRequest.id()); assertEquals(TEST_INDEX, updateItemRequest.tableName()); assertEquals(TEST_ID, updateItemRequest.key().get(RANGE_KEY).s()); - assertEquals(TENANT_ID, updateItemRequest.key().get(CommonValue.TENANT_ID).s()); + assertEquals(TENANT_ID, updateItemRequest.key().get(HASH_KEY).s()); assertEquals("foo", updateItemRequest.expressionAttributeValues().get(":source").m().get("data").s()); } @@ -474,7 +475,7 @@ public void updateDataObjectAsync_HappyCaseWithMap() throws Exception { UpdateItemRequest updateItemRequest = updateItemRequestArgumentCaptor.getValue(); assertEquals(TEST_INDEX, updateItemRequest.tableName()); assertEquals(TEST_ID, updateItemRequest.key().get(RANGE_KEY).s()); - assertEquals(TENANT_ID, updateItemRequest.key().get(CommonValue.TENANT_ID).s()); + assertEquals(TENANT_ID, updateItemRequest.key().get(HASH_KEY).s()); assertTrue(updateItemRequest.expressionAttributeNames().containsKey("#seqNo")); assertTrue(updateItemRequest.expressionAttributeNames().containsKey("#source")); assertTrue(updateItemRequest.expressionAttributeValues().containsKey(":incr")); @@ -498,7 +499,7 @@ public void updateDataObjectAsync_NullTenantId_UsesDefaultTenantId() { Mockito.when(dynamoDbClient.updateItem(updateItemRequestArgumentCaptor.capture())).thenReturn(UpdateItemResponse.builder().build()); sdkClient.updateDataObjectAsync(updateRequest, testThreadPool.executor(GENERAL_THREAD_POOL)).toCompletableFuture().join(); UpdateItemRequest updateItemRequest = updateItemRequestArgumentCaptor.getValue(); - assertEquals(TENANT_ID, updateItemRequest.key().get(CommonValue.TENANT_ID).s()); + assertEquals(TENANT_ID, updateItemRequest.key().get(HASH_KEY).s()); } public void testUpdateDataObject_VersionCheck() throws IOException {