Skip to content

Commit

Permalink
feat(dsp): implement GET /id transfer process endpoint (#3234)
Browse files Browse the repository at this point in the history
  • Loading branch information
ronjaquensel authored Jun 27, 2023
1 parent d762971 commit 20c366a
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

import java.time.Instant;
import java.util.ArrayList;
import java.util.Optional;

import static java.lang.String.format;
import static org.eclipse.edc.connector.contract.spi.ContractId.createContractId;
Expand Down Expand Up @@ -141,7 +142,16 @@ public Result<ContractAgreement> validateAgreement(ClaimToken token, ContractAgr
}
return success(agreement);
}


@Override
public @NotNull Result<Void> validateRequest(ClaimToken token, ContractAgreement agreement) {
var agent = agentService.createFor(token);
return Optional.ofNullable(agent.getIdentity())
.filter(id -> id.equals(agreement.getConsumerId()) || id.equals(agreement.getProviderId()))
.map(id -> Result.success())
.orElse(Result.failure("Invalid counter-party identity"));
}

@Override
@NotNull
public Result<Void> validateRequest(ClaimToken token, ContractNegotiation negotiation) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ void testNegotiation_initialOfferAccepted() {
var offer = getContractOffer();
when(validationService.validateInitialOffer(token, offer)).thenReturn(Result.success(new ValidatedConsumerOffer(CONSUMER_ID, offer)));
when(validationService.validateConfirmed(eq(token), any(ContractAgreement.class), any(ContractOffer.class))).thenReturn(Result.success());
when(validationService.validateRequest(eq(token), any())).thenReturn(Result.success());
when(validationService.validateRequest(eq(token), any(ContractNegotiation.class))).thenReturn(Result.success());

// Start provider and consumer negotiation managers
providerManager.start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,45 @@ void validateConfirmed_failsIfPoliciesAreNotEqual() {

verify(agentService).createFor(eq(token));
}

@Test
void validateRequest_shouldReturnSuccess_whenRequestingPartyProvider() {
var token = ClaimToken.Builder.newInstance().build();
var agreement = createContractAgreement().build();
var participantAgent = new ParticipantAgent(Map.of(), Map.of(PARTICIPANT_IDENTITY, PROVIDER_ID));

when(agentService.createFor(token)).thenReturn(participantAgent);

var result = validationService.validateRequest(token, agreement);

assertThat(result).isSucceeded();
}

@Test
void validateRequest_shouldReturnSuccess_whenRequestingPartyConsumer() {
var token = ClaimToken.Builder.newInstance().build();
var agreement = createContractAgreement().build();
var participantAgent = new ParticipantAgent(Map.of(), Map.of(PARTICIPANT_IDENTITY, CONSUMER_ID));

when(agentService.createFor(token)).thenReturn(participantAgent);

var result = validationService.validateRequest(token, agreement);

assertThat(result).isSucceeded();
}

@Test
void validateRequest_shouldReturnFailure_whenRequestingPartyUnauthorized() {
var token = ClaimToken.Builder.newInstance().build();
var agreement = createContractAgreement().build();
var participantAgent = new ParticipantAgent(Map.of(), Map.of(PARTICIPANT_IDENTITY, "invalid"));

when(agentService.createFor(token)).thenReturn(participantAgent);

var result = validationService.validateRequest(token, agreement);

assertThat(result).isFailed();
}

@Test
void validateConsumerRequest() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.eclipse.edc.spi.dataaddress.DataAddressValidator;
import org.eclipse.edc.spi.iam.ClaimToken;
import org.eclipse.edc.spi.monitor.Monitor;
import org.eclipse.edc.spi.result.Result;
import org.eclipse.edc.spi.telemetry.Telemetry;
import org.eclipse.edc.transaction.spi.TransactionContext;
import org.jetbrains.annotations.NotNull;
Expand Down Expand Up @@ -117,7 +118,17 @@ public ServiceResult<TransferProcess> notifyCompleted(TransferCompletionMessage
public ServiceResult<TransferProcess> notifyTerminated(TransferTerminationMessage message, ClaimToken claimToken) {
return onMessageDo(message, transferProcess -> terminatedAction(message, transferProcess));
}


@Override
@WithSpan
@NotNull
public ServiceResult<TransferProcess> findById(String id, ClaimToken claimToken) {
return transactionContext.execute(() -> Optional.ofNullable(transferProcessStore.findById(id))
.filter(tp -> validateCounterParty(claimToken, tp))
.map(ServiceResult::success)
.orElse(ServiceResult.notFound(format("No negotiation with id %s found", id))));
}

@NotNull
private ServiceResult<TransferProcess> requestedAction(TransferRequestMessage message) {
var contractId = ContractId.parse(message.getContractId());
Expand Down Expand Up @@ -200,6 +211,13 @@ private ServiceResult<TransferProcess> onMessageDo(TransferRemoteMessage message
.map(action)
.orElse(ServiceResult.notFound(format("TransferProcess with DataRequest id %s not found", message.getProcessId()))));
}

private boolean validateCounterParty(ClaimToken claimToken, TransferProcess transferProcess) {
return Optional.ofNullable(negotiationStore.findContractAgreement(transferProcess.getDataRequest().getContractId()))
.map(agreement -> contractValidationService.validateRequest(claimToken, agreement))
.filter(Result::succeeded)
.isPresent();
}

private void update(TransferProcess transferProcess) {
transferProcessStore.updateOrCreate(transferProcess);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ void notifyAgreed_shouldReturnBadRequest_whenValidationFails() {
void notifyVerified_shouldTransitionToVerified() {
var negotiation = contractNegotiationBuilder().id("negotiationId").type(PROVIDER).state(AGREED.code()).build();
when(store.findForCorrelationId("processId")).thenReturn(negotiation);
when(validationService.validateRequest(any(), any())).thenReturn(Result.success());
when(validationService.validateRequest(any(), any(ContractNegotiation.class))).thenReturn(Result.success());
var message = ContractAgreementVerificationMessage.Builder.newInstance()
.protocol("protocol")
.counterPartyAddress("http://any")
Expand All @@ -209,15 +209,15 @@ void notifyVerified_shouldTransitionToVerified() {
assertThat(result).isSucceeded();
verify(store).save(argThat(n -> n.getState() == VERIFIED.code()));
verify(listener).verified(negotiation);
verify(validationService).validateRequest(any(), any());
verify(validationService).validateRequest(any(), any(ContractNegotiation.class));
verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class));
}

@Test
void notifyVerified_shouldReturnBadRequest_whenValidationFails() {
var negotiation = contractNegotiationBuilder().id("negotiationId").type(PROVIDER).state(AGREED.code()).build();
when(store.findForCorrelationId("processId")).thenReturn(negotiation);
when(validationService.validateRequest(any(), any())).thenReturn(Result.failure("validation error"));
when(validationService.validateRequest(any(), any(ContractNegotiation.class))).thenReturn(Result.failure("validation error"));
var message = ContractAgreementVerificationMessage.Builder.newInstance()
.protocol("protocol")
.counterPartyAddress("http://any")
Expand All @@ -235,7 +235,7 @@ void notifyVerified_shouldReturnBadRequest_whenValidationFails() {
void notifyFinalized_shouldTransitionToFinalized() {
var negotiation = contractNegotiationBuilder().id("negotiationId").type(PROVIDER).state(VERIFIED.code()).build();
when(store.findForCorrelationId("processId")).thenReturn(negotiation);
when(validationService.validateRequest(any(), any())).thenReturn(Result.success());
when(validationService.validateRequest(any(), any(ContractNegotiation.class))).thenReturn(Result.success());
var message = ContractNegotiationEventMessage.Builder.newInstance()
.type(ContractNegotiationEventMessage.Type.FINALIZED)
.protocol("protocol")
Expand All @@ -249,15 +249,15 @@ void notifyFinalized_shouldTransitionToFinalized() {
assertThat(result).isSucceeded();
verify(store).save(argThat(n -> n.getState() == FINALIZED.code()));
verify(listener).finalized(negotiation);
verify(validationService).validateRequest(any(), any());
verify(validationService).validateRequest(any(), any(ContractNegotiation.class));
verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class));
}

@Test
void notifyFinalized_shouldReturnBadRequest_whenValidationFails() {
var negotiation = contractNegotiationBuilder().id("negotiationId").type(PROVIDER).state(VERIFIED.code()).build();
when(store.findForCorrelationId("processId")).thenReturn(negotiation);
when(validationService.validateRequest(any(), any())).thenReturn(Result.failure("validation error"));
when(validationService.validateRequest(any(), any(ContractNegotiation.class))).thenReturn(Result.failure("validation error"));
var message = ContractNegotiationEventMessage.Builder.newInstance()
.type(ContractNegotiationEventMessage.Type.FINALIZED)
.protocol("protocol")
Expand All @@ -277,7 +277,7 @@ void notifyFinalized_shouldReturnBadRequest_whenValidationFails() {
void notifyTerminated_shouldTransitionToTerminated() {
var negotiation = contractNegotiationBuilder().id("negotiationId").type(PROVIDER).state(VERIFIED.code()).build();
when(store.findForCorrelationId("processId")).thenReturn(negotiation);
when(validationService.validateRequest(any(), any())).thenReturn(Result.success());
when(validationService.validateRequest(any(), any(ContractNegotiation.class))).thenReturn(Result.success());
var message = ContractNegotiationTerminationMessage.Builder.newInstance()
.protocol("protocol")
.processId("processId")
Expand All @@ -291,15 +291,15 @@ void notifyTerminated_shouldTransitionToTerminated() {
assertThat(result).isSucceeded();
verify(store).save(argThat(n -> n.getState() == TERMINATED.code()));
verify(listener).terminated(negotiation);
verify(validationService).validateRequest(any(), any());
verify(validationService).validateRequest(any(), any(ContractNegotiation.class));
verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class));
}

@Test
void notifyTerminated_shouldReturnBadRequest_whenValidationFails() {
var negotiation = contractNegotiationBuilder().id("negotiationId").type(PROVIDER).state(VERIFIED.code()).build();
when(store.findForCorrelationId("processId")).thenReturn(negotiation);
when(validationService.validateRequest(any(), any())).thenReturn(Result.failure("validation error"));
when(validationService.validateRequest(any(), any(ContractNegotiation.class))).thenReturn(Result.failure("validation error"));
var message = ContractNegotiationTerminationMessage.Builder.newInstance()
.protocol("protocol")
.processId("processId")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.eclipse.edc.connector.transfer.spi.observe.TransferProcessListener;
import org.eclipse.edc.connector.transfer.spi.observe.TransferProcessStartedData;
import org.eclipse.edc.connector.transfer.spi.store.TransferProcessStore;
import org.eclipse.edc.connector.transfer.spi.types.DataRequest;
import org.eclipse.edc.connector.transfer.spi.types.TransferProcess;
import org.eclipse.edc.connector.transfer.spi.types.TransferProcessStates;
import org.eclipse.edc.connector.transfer.spi.types.protocol.TransferCompletionMessage;
Expand Down Expand Up @@ -64,6 +65,7 @@
import static org.eclipse.edc.junit.assertions.AbstractResultAssert.assertThat;
import static org.eclipse.edc.service.spi.result.ServiceFailure.Reason.BAD_REQUEST;
import static org.eclipse.edc.service.spi.result.ServiceFailure.Reason.CONFLICT;
import static org.eclipse.edc.service.spi.result.ServiceFailure.Reason.NOT_FOUND;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.atLeastOnce;
Expand Down Expand Up @@ -306,6 +308,55 @@ void notifyTerminated_shouldReturnConflict_whenStatusIsNotValid() {
verify(store, never()).updateOrCreate(any());
verifyNoInteractions(listener);
}

@Test
void findById_shouldReturnTransferProcess_whenValidCounterParty() {
var processId = "transferProcessId";
var transferProcess = transferProcess(INITIAL, processId);
var token = claimToken();
var agreement = contractAgreement();

when(store.findById(processId)).thenReturn(transferProcess);
when(negotiationStore.findContractAgreement(any())).thenReturn(agreement);
when(validationService.validateRequest(token, agreement)).thenReturn(Result.success());

var result = service.findById(processId, token);

assertThat(result)
.isSucceeded()
.isEqualTo(transferProcess);
}

@Test
void findById_shouldReturnNotFound_whenNegotiationNotFound() {
when(store.findById(any())).thenReturn(null);

var result = service.findById("invalidId", ClaimToken.Builder.newInstance().build());

assertThat(result)
.isFailed()
.extracting(ServiceFailure::getReason)
.isEqualTo(NOT_FOUND);
}

@Test
void findById_shouldReturnNotFound_whenCounterPartyUnauthorized() {
var processId = "transferProcessId";
var transferProcess = transferProcess(INITIAL, processId);
var token = claimToken();
var agreement = contractAgreement();

when(store.findById(processId)).thenReturn(transferProcess);
when(negotiationStore.findContractAgreement(any())).thenReturn(agreement);
when(validationService.validateRequest(token, agreement)).thenReturn(Result.failure("error"));

var result = service.findById(processId, token);

assertThat(result)
.isFailed()
.extracting(ServiceFailure::getReason)
.isEqualTo(NOT_FOUND);
}

@ParameterizedTest
@ArgumentsSource(NotFoundArguments.class)
Expand All @@ -323,6 +374,7 @@ private TransferProcess transferProcess(TransferProcessStates state, String id)
return TransferProcess.Builder.newInstance()
.state(state.code())
.id(id)
.dataRequest(dataRequest())
.build();
}

Expand All @@ -341,6 +393,13 @@ private ContractAgreement contractAgreement() {
.policy(Policy.Builder.newInstance().build())
.build();
}

private DataRequest dataRequest() {
return DataRequest.Builder.newInstance()
.contractId("contractId")
.destinationType("type")
.build();
}

@FunctionalInterface
private interface MethodCall<M extends RemoteMessage> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,15 @@ public DspTransferProcessApiController(Monitor monitor,
*/
@GET
@Path("/{id}")
public Response getTransferProcess(@PathParam("id") String id) {
return error().processId(id).notImplemented();
public Response getTransferProcess(@PathParam("id") String id, @HeaderParam(AUTHORIZATION) String token) {
var claimTokenResult = checkAuthToken(token);
if (claimTokenResult.failed()) {
return error().processId(id).unauthorized();
}

return protocolService.findById(id, claimTokenResult.getContent())
.map(this::createTransferProcessResponse)
.orElse(createErrorResponse(id));
}

/**
Expand All @@ -120,15 +127,8 @@ public Response initiateTransferProcess(JsonObject jsonObject, @HeaderParam(AUTH
if (transferProcessResult.failed()) {
return error().from(transferProcessResult.getFailure());
}

var transferProcess = transferProcessResult.getContent();
return registry.transform(transferProcess, JsonObject.class)
.map(transformedJson -> Response.ok().type(MediaType.APPLICATION_JSON).entity(transformedJson).build())
.orElse(failure -> {
var errorCode = UUID.randomUUID();
monitor.warning(String.format("Error transforming transfer process, error id %s: %s", errorCode, failure.getFailureDetail()));
return error().processId(transferProcess.getCorrelationId()).message(String.format("Error code %s", errorCode)).internalServerError();
});

return createTransferProcessResponse(transferProcessResult.getContent());
}

/**
Expand Down Expand Up @@ -270,6 +270,16 @@ private <M extends TransferRemoteMessage> Result<M> validateProcessId(MessageSpe
}
return Result.success(message);
}

private Response createTransferProcessResponse(TransferProcess transferProcess) {
return registry.transform(transferProcess, JsonObject.class)
.map(transformedJson -> Response.ok().type(MediaType.APPLICATION_JSON).entity(transformedJson).build())
.orElse(failure -> {
var errorCode = UUID.randomUUID();
monitor.warning(String.format("Error transforming transfer process, error id %s: %s", errorCode, failure.getFailureDetail()));
return error().processId(transferProcess.getCorrelationId()).message(String.format("Error code %s", errorCode)).internalServerError();
});
}

@NotNull
private static DspErrorResponse error() {
Expand Down
Loading

0 comments on commit 20c366a

Please sign in to comment.