Skip to content

Commit

Permalink
feat: adds counter party validation on all DSP messages (#3402)
Browse files Browse the repository at this point in the history
* feat: add counter party validation

* feat: add lease break on validation failure

* feat: add lease break on action failure
  • Loading branch information
wolf4ood authored Sep 1, 2023
1 parent 2e144b4 commit 28bcd41
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,37 +107,36 @@ public ServiceResult<TransferProcess> notifyRequested(TransferRequestMessage mes
@WithSpan
@NotNull
public ServiceResult<TransferProcess> notifyStarted(TransferStartMessage message, ClaimToken claimToken) {
return onMessageDo(message, transferProcess -> startedAction(message, transferProcess));
return onMessageDo(message, claimToken, transferProcess -> startedAction(message, transferProcess));
}

@Override
@WithSpan
@NotNull
public ServiceResult<TransferProcess> notifyCompleted(TransferCompletionMessage message, ClaimToken claimToken) {
return onMessageDo(message, transferProcess -> completedAction(message, transferProcess));
return onMessageDo(message, claimToken, transferProcess -> completedAction(message, transferProcess));
}

@Override
@WithSpan
@NotNull
public ServiceResult<TransferProcess> notifyTerminated(TransferTerminationMessage message, ClaimToken claimToken) {
return onMessageDo(message, transferProcess -> terminatedAction(message, transferProcess));
return onMessageDo(message, claimToken, transferProcess -> terminatedAction(message, transferProcess));
}

@Override
@WithSpan
@NotNull
public ServiceResult<TransferProcess> findById(String id, ClaimToken claimToken) {
return transactionContext.execute(() -> Optional.ofNullable(transferProcessStore.findById(id))
.filter(tp -> validateCounterParty(claimToken, tp))
.map(ServiceResult::success)
.orElse(ServiceResult.notFound(format("No negotiation with id %s found", id))));
.map(tp -> validateCounterParty(claimToken, tp))
.orElse(notFound(id)));
}

@NotNull
private ServiceResult<TransferProcess> requestedAction(TransferRequestMessage message, ContractId contractId) {
var assetId = contractId.assetIdPart();

var destination = message.getDataDestination() != null ? message.getDataDestination() :
DataAddress.Builder.newInstance().type(HTTP_PROXY).build();
var dataRequest = DataRequest.Builder.newInstance()
Expand Down Expand Up @@ -210,18 +209,30 @@ private ServiceResult<TransferProcess> terminatedAction(TransferTerminationMessa
}
}

private ServiceResult<TransferProcess> onMessageDo(TransferRemoteMessage message, Function<TransferProcess, ServiceResult<TransferProcess>> action) {
private ServiceResult<TransferProcess> onMessageDo(TransferRemoteMessage message, ClaimToken claimToken, Function<TransferProcess, ServiceResult<TransferProcess>> action) {
return transactionContext.execute(() -> transferProcessStore
.findByCorrelationIdAndLease(message.getProcessId())
.flatMap(ServiceResult::from)
.compose(action));
.compose(transferProcess -> validateCounterParty(claimToken, transferProcess)
.compose(action)
.onFailure(f -> breakLease(transferProcess))));
}
private boolean validateCounterParty(ClaimToken claimToken, TransferProcess transferProcess) {

private ServiceResult<TransferProcess> validateCounterParty(ClaimToken claimToken, TransferProcess transferProcess) {
return Optional.ofNullable(negotiationStore.findContractAgreement(transferProcess.getContractId()))
.map(agreement -> contractValidationService.validateRequest(claimToken, agreement))
.filter(Result::succeeded)
.isPresent();
.map(e -> ServiceResult.success(transferProcess))
.orElse(notFound(transferProcess.getId()));

}

private ServiceResult<TransferProcess> notFound(String transferProcessId) {
return ServiceResult.notFound(format("No transfer process with id %s found", transferProcessId));
}

private void breakLease(TransferProcess process) {
transferProcessStore.save(process);
}

private void update(TransferProcess transferProcess) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.eclipse.edc.connector.service.transferprocess;

import org.eclipse.edc.connector.contract.spi.negotiation.store.ContractNegotiationStore;
import org.eclipse.edc.connector.contract.spi.types.agreement.ContractAgreement;
import org.eclipse.edc.connector.core.event.EventExecutorServiceContainer;
import org.eclipse.edc.connector.dataplane.selector.spi.store.DataPlaneInstanceStore;
import org.eclipse.edc.connector.policy.spi.store.PolicyArchive;
Expand All @@ -38,6 +40,8 @@
import org.eclipse.edc.junit.extensions.EdcExtension;
import org.eclipse.edc.policy.model.Policy;
import org.eclipse.edc.spi.EdcException;
import org.eclipse.edc.spi.agent.ParticipantAgent;
import org.eclipse.edc.spi.agent.ParticipantAgentService;
import org.eclipse.edc.spi.event.EventEnvelope;
import org.eclipse.edc.spi.event.EventRouter;
import org.eclipse.edc.spi.event.EventSubscriber;
Expand Down Expand Up @@ -100,6 +104,8 @@ void setUp(EdcExtension extension) {
extension.registerServiceMock(ProtocolWebhook.class, () -> "http://dummy");
extension.registerServiceMock(DataPlaneInstanceStore.class, mock(DataPlaneInstanceStore.class));
extension.registerServiceMock(PolicyArchive.class, mock(PolicyArchive.class));
extension.registerServiceMock(ContractNegotiationStore.class, mock(ContractNegotiationStore.class));
extension.registerServiceMock(ParticipantAgentService.class, mock(ParticipantAgentService.class));
}

@Test
Expand All @@ -108,9 +114,23 @@ void shouldDispatchEventsOnTransferProcessStateChanges(TransferProcessService se
EventRouter eventRouter,
RemoteMessageDispatcherRegistry dispatcherRegistry,
StatusCheckerRegistry statusCheckerRegistry,
PolicyArchive policyArchive) {
PolicyArchive policyArchive,
ContractNegotiationStore negotiationStore,
ParticipantAgentService agentService) {

var token = ClaimToken.Builder.newInstance().build();
var agent = mock(ParticipantAgent.class);
var agreement = mock(ContractAgreement.class);
var providerId = "ProviderId";

when(agreement.getProviderId()).thenReturn(providerId);
when(agent.getIdentity()).thenReturn(providerId);


dispatcherRegistry.register(getTestDispatcher());
when(policyArchive.findPolicyForContract(matches("contractId"))).thenReturn(mock(Policy.class));
when(negotiationStore.findContractAgreement("contractId")).thenReturn(agreement);
when(agentService.createFor(token)).thenReturn(agent);
eventRouter.register(TransferProcessEvent.class, eventSubscriber);
var statusCheck = mock(StatusChecker.class);

Expand All @@ -134,7 +154,7 @@ void shouldDispatchEventsOnTransferProcessStateChanges(TransferProcessService se
.dataAddress(dataAddress)
.build();

protocolService.notifyStarted(startMessage, ClaimToken.Builder.newInstance().build());
protocolService.notifyStarted(startMessage, token);

await().atMost(TIMEOUT).untilAsserted(() -> {
ArgumentCaptor<EventEnvelope<TransferProcessStarted>> captor = ArgumentCaptor.forClass(EventEnvelope.class);
Expand Down
Loading

0 comments on commit 28bcd41

Please sign in to comment.