Skip to content

Commit

Permalink
feat(cbor): add protocol resolution priority system
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhe committed Aug 16, 2024
1 parent 4c692ae commit e920c7e
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -136,7 +136,9 @@ private ProtocolGenerator resolveProtocolGenerator(
TypeScriptSettings settings
) {
// Collect all of the supported protocol generators.
Map<ShapeId, ProtocolGenerator> generators = new HashMap<>();
// Preserve insertion order as default priority order.
Map<ShapeId, ProtocolGenerator> generators = new LinkedHashMap<>();

for (TypeScriptIntegration integration : integrations) {
for (ProtocolGenerator generator : integration.getProtocolGenerators()) {
generators.put(generator.getProtocol(), generator);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package software.amazon.smithy.typescript.codegen;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand All @@ -36,6 +37,7 @@
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.traits.DefaultTrait;
import software.amazon.smithy.model.traits.RequiredTrait;
import software.amazon.smithy.typescript.codegen.protocols.ProtocolPriority;
import software.amazon.smithy.utils.SmithyUnstableApi;

/**
Expand Down Expand Up @@ -450,8 +452,13 @@ public ShapeId resolveServiceProtocol(Model model, ServiceShape service, Set<Sha
+ "generate in smithy-build.json to generate this service.");
}

return resolvedProtocols.stream()
.filter(supportedProtocols::contains)
List<ShapeId> protocolPriority = ProtocolPriority.getProtocolPriority(service.toShapeId());
List<ShapeId> protocolPriorityList = protocolPriority != null && !protocolPriority.isEmpty()
? protocolPriority
: new ArrayList<>(supportedProtocols);

return protocolPriorityList.stream()
.filter(resolvedProtocols::contains)
.findFirst()
.orElseThrow(() -> new UnresolvableProtocolException(String.format(
"The %s service supports the following unsupported protocols %s. The following protocol "
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package software.amazon.smithy.typescript.codegen.protocols;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import software.amazon.smithy.model.shapes.ShapeId;


/**
* Allows customization of protocol selection for specific services or a global default ordering.
*/
public final class ProtocolPriority {
private static final Map<ShapeId, List<ShapeId>> SERVICE_PROTOCOL_PRIORITY_CUSTOMIZATIONS = new HashMap<>();
private static List<ShapeId> customDefaultPriority = null;

private ProtocolPriority() {}

/**
* @param serviceShapeId - service scope.
* @param protocolPriorityOrder - priority order of protocols.
*/
public static void setProtocolPriority(ShapeId serviceShapeId, List<ShapeId> protocolPriorityOrder) {
SERVICE_PROTOCOL_PRIORITY_CUSTOMIZATIONS.put(serviceShapeId, protocolPriorityOrder);
}

/**
* @param defaultProtocolPriorityOrder - use for all services that don't have a more specific priority order.
*/
public static void setCustomDefaultProtocolPriority(List<ShapeId> defaultProtocolPriorityOrder) {
customDefaultPriority = new ArrayList<>(defaultProtocolPriorityOrder);
}

/**
* @param serviceShapeId - service scope.
* @return priority order of protocols or null if no override exists.
*/
public static List<ShapeId> getProtocolPriority(ShapeId serviceShapeId) {
return SERVICE_PROTOCOL_PRIORITY_CUSTOMIZATIONS.getOrDefault(
serviceShapeId,
customDefaultPriority != null ? new ArrayList<>(customDefaultPriority) : null
);
}

/**
* @param serviceShapeId - to unset.
*/
public static void deleteProtocolPriority(ShapeId serviceShapeId) {
SERVICE_PROTOCOL_PRIORITY_CUSTOMIZATIONS.remove(serviceShapeId);
}

/**
* Unset the custom default priority order.
*/
public static void deleteCustomDefaultProtocolPriority() {
customDefaultPriority = null;
}
}
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
package software.amazon.smithy.typescript.codegen;

import java.util.LinkedHashSet;
import java.util.List;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.knowledge.ServiceIndex;
import software.amazon.smithy.model.node.Node;
import software.amazon.smithy.model.node.ObjectNode;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.ShapeId;

import java.util.stream.Stream;
import software.amazon.smithy.typescript.codegen.protocols.ProtocolPriority;
import software.amazon.smithy.utils.MapUtils;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class TypeScriptSettingsTest {

@Test
Expand Down Expand Up @@ -87,6 +98,121 @@ private static Stream<Arguments> providePackageDescriptionTestCases() {
);
}

@Test
public void resolveServiceProtocol(@Mock Model model,
@Mock ServiceShape service,
@Mock ServiceIndex serviceIndex) {
TypeScriptSettings subject = new TypeScriptSettings();

// note: these are mock protocol names.
ShapeId rpcv2Cbor = ShapeId.from("namespace#rpcv2Cbor");
ShapeId json1_0 = ShapeId.from("namespace#json1_0");
ShapeId json1_1 = ShapeId.from("namespace#json1_1");
ShapeId restJson1 = ShapeId.from("namespace#restJson1");
ShapeId restXml = ShapeId.from("namespace#restXml");
ShapeId query = ShapeId.from("namespace#query");
ShapeId serviceQuery = ShapeId.from("namespace#serviceQuery");

when(model.getKnowledge(any(), any())).thenReturn(serviceIndex);
ShapeId serviceShapeId = ShapeId.from("namespace#Service");
when(service.toShapeId()).thenReturn(serviceShapeId);

LinkedHashSet<ShapeId> protocolShapeIds = new LinkedHashSet<>(
List.of(
json1_0, json1_1, restJson1, rpcv2Cbor, restXml, query, serviceQuery
)
);

{
// spec case 1.
when(serviceIndex.getProtocols(service)).thenReturn(MapUtils.of(
rpcv2Cbor, null,
json1_0, null
));
ShapeId protocol = subject.resolveServiceProtocol(model, service, protocolShapeIds);
// Note: JS customization JSON higher default priority than CBOR.
assertEquals(json1_0, protocol);
}

{
// spec case 2.
when(serviceIndex.getProtocols(service)).thenReturn(MapUtils.of(
rpcv2Cbor, null
));
ShapeId protocol = subject.resolveServiceProtocol(model, service, protocolShapeIds);
assertEquals(rpcv2Cbor, protocol);
}

{
// spec case 3.
when(serviceIndex.getProtocols(service)).thenReturn(MapUtils.of(
rpcv2Cbor, null,
json1_0, null,
query, null
));
ShapeId protocol = subject.resolveServiceProtocol(model, service, protocolShapeIds);
// Note: JS customization JSON higher default priority than CBOR.
assertEquals(json1_0, protocol);
}

{
// spec case 4.
when(serviceIndex.getProtocols(service)).thenReturn(MapUtils.of(
json1_0, null,
query, null
));
ShapeId protocol = subject.resolveServiceProtocol(model, service, protocolShapeIds);
assertEquals(json1_0, protocol);
}

{
// spec case 5.
when(serviceIndex.getProtocols(service)).thenReturn(MapUtils.of(
query, null
));
ShapeId protocol = subject.resolveServiceProtocol(model, service, protocolShapeIds);
assertEquals(query, protocol);
}

{
// service override, non-spec
when(serviceIndex.getProtocols(service)).thenReturn(MapUtils.of(
json1_0, null,
json1_1, null,
restJson1, null,
rpcv2Cbor, null,
restXml, null,
query, null,
serviceQuery, null
));
ProtocolPriority.setProtocolPriority(serviceShapeId, List.of(
serviceQuery, rpcv2Cbor, json1_1, restJson1, restXml, query
));
ShapeId protocol = subject.resolveServiceProtocol(model, service, protocolShapeIds);
ProtocolPriority.deleteProtocolPriority(serviceShapeId);
assertEquals(serviceQuery, protocol);
}

{
// global default override
when(serviceIndex.getProtocols(service)).thenReturn(MapUtils.of(
json1_0, null,
json1_1, null,
restJson1, null,
rpcv2Cbor, null,
restXml, null,
query, null,
serviceQuery, null
));
ProtocolPriority.setCustomDefaultProtocolPriority(List.of(
rpcv2Cbor, json1_1, restJson1, restXml, query
));
ShapeId protocol = subject.resolveServiceProtocol(model, service, protocolShapeIds);
ProtocolPriority.deleteCustomDefaultProtocolPriority();
assertEquals(rpcv2Cbor, protocol);
}
}

@Test
public void resolvesSupportProtocols() {
// TODO
Expand Down

0 comments on commit e920c7e

Please sign in to comment.