Skip to content

Commit

Permalink
Add support for setting module specific task locals (#84)
Browse files Browse the repository at this point in the history
* Add support for setting module specific task locals

* Patch "com.amazonaws.s3#ChecksumAlgorithm" with "CRC64NVME"
  • Loading branch information
adam-fowler authored Dec 18, 2024
1 parent ee2d403 commit 702b964
Show file tree
Hide file tree
Showing 4 changed files with 420 additions and 350 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
.DS_Store
/.build
/.index-build
/Packages
/*.xcodeproj
xcuserdata/
Expand Down
110 changes: 77 additions & 33 deletions Sources/SotoCodeGeneratorLib/AwsService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ struct AwsService {

// separate by non-alphanumeric character, then capitalize the first letter of each component
// and join back together
let serviceName = sdkId
let serviceName =
sdkId
.components(separatedBy: CharacterSet.alphanumerics.inverted)
.map { $0.prefix(1).capitalized + $0.dropFirst() }
.joined()
Expand Down Expand Up @@ -153,7 +154,7 @@ struct AwsService {
/// filter operations list
mutating func filterOperations(_ filter: [String]) {
self.operations = self.operations.filter { key, _ in
return filter.contains(key.shapeName.toSwiftVariableCase())
filter.contains(key.shapeName.toSwiftVariableCase())
}
}

Expand Down Expand Up @@ -237,9 +238,10 @@ struct AwsService {
hostPrefix: endpointTrait?.hostPrefix,
deprecated: deprecatedTrait?.message,
streaming: streaming ? "ByteBuffer" : nil,
documentationUrl: nil, // added to comment
documentationUrl: nil, // added to comment
endpointRequired: requireEndpointDiscovery.map { OperationContext.DiscoverableEndpoint(required: $0) },
initParameters: initParamContext
initParameters: initParamContext,
taskLocals: generateTaskLocals(operation: operation)
)
}

Expand Down Expand Up @@ -276,6 +278,27 @@ struct AwsService {
}
}

func generateTaskLocals(
operation: OperationShape
) -> TaskLocalParameters? {
guard let staticParamsTrait = operation.trait(type: StaticContextParamsTrait.self) else { return nil }
let name: String
let possibleParameters: [String]
switch self.serviceEndpointPrefix {
case "s3":
name = "S3Middleware.$executionContext"
possibleParameters = ["UseS3ExpressControlEndpoint"]
default:
return nil
}
let parameters = staticParamsTrait.value
.filter { possibleParameters.contains($0.key) }
.compactMap { param in
param.value.dictionary?["value"].map { TaskLocalParameters.Parameter(key: param.key.toSwiftLabelCase(), value: $0) }
}
return !parameters.isEmpty ? .init(taskLocalName: name, taskLocalParams: parameters) : nil
}

static func getTrait<T: StaticTrait>(from shape: SotoSmithy.Shape, trait: T.Type, id: ShapeId) throws -> T {
guard let trait = shape.trait(type: T.self) else {
throw Error(reason: "\(id) does not have a \(T.staticName) trait")
Expand Down Expand Up @@ -353,7 +376,8 @@ struct AwsService {
if self.outputHTMLComments {
docs = documentation?.split(separator: "\n") ?? []
} else {
docs = documentation?
docs =
documentation?
.tagStriped()
.replacingOccurrences(of: "\n +", with: " ", options: .regularExpression, range: nil)
.split(separator: "\n")
Expand All @@ -375,7 +399,8 @@ struct AwsService {
if let recommendation = shape.trait(type: RecommendedTrait.self)?.reason {
documentation += "\n\(recommendation)"
}
return documentation
return
documentation
.tagStriped()
.replacingOccurrences(of: "\n +", with: " ", options: .regularExpression, range: nil)
.split(separator: "\n")
Expand All @@ -385,7 +410,7 @@ struct AwsService {

/// process documentation string
func processDocs(_ documentation: String?) -> [String.SubSequence] {
return documentation?
documentation?
.tagStriped()
.replacingOccurrences(of: "\n +", with: " ", options: .regularExpression, range: nil)
.split(separator: "\n")
Expand All @@ -399,11 +424,11 @@ struct AwsService {
return "AWSEditHeadersMiddleware(.add(name: \"accept\", value: \"application/json\"))"
case "Glacier":
return """
AWSMiddlewareStack {
AWSEditHeadersMiddleware(.add(name: \"x-amz-glacier-version\", value: \"\(service.version ?? "2012-06-01")\"))
TreeHashMiddleware(header: \"x-amz-sha256-tree-hash\")
}
"""
AWSMiddlewareStack {
AWSEditHeadersMiddleware(.add(name: \"x-amz-glacier-version\", value: \"\(service.version ?? "2012-06-01")\"))
TreeHashMiddleware(header: \"x-amz-sha256-tree-hash\")
}
"""
case "S3":
return "S3Middleware()"
default:
Expand All @@ -420,7 +445,7 @@ struct AwsService {
}

func encodingName(_ name: String) -> String {
return "_\(name)Encoding"
"_\(name)Encoding"
}

/// return payload member of structure
Expand Down Expand Up @@ -521,9 +546,10 @@ struct AwsService {
/// The JSON decoder requires an array to exist, even if it is empty so we have to make
/// all arrays in output shapes optional
func removeRequiredTraitFromOutputCollections(_ model: Model) {
guard self.serviceProtocolTrait is AwsProtocolsAwsJson1_0Trait ||
self.serviceProtocolTrait is AwsProtocolsAwsJson1_1Trait ||
self.serviceProtocolTrait is AwsProtocolsRestJson1Trait else { return }
guard
self.serviceProtocolTrait is AwsProtocolsAwsJson1_0Trait || self.serviceProtocolTrait is AwsProtocolsAwsJson1_1Trait
|| self.serviceProtocolTrait is AwsProtocolsRestJson1Trait
else { return }

for shape in model.shapes {
guard shape.value.hasTrait(type: SotoOutputShapeTrait.self) else { continue }
Expand Down Expand Up @@ -556,8 +582,9 @@ struct AwsService {
}
// if output token is member of an optional struct add ? suffix
if let member = structure.members?[String(split[0])] {
let required = member.hasTrait(type: RequiredTrait.self) ||
(member.hasTrait(type: HttpPayloadTrait.self) && structure.hasTrait(type: SotoOutputShapeTrait.self))
let required =
member.hasTrait(type: RequiredTrait.self)
|| (member.hasTrait(type: HttpPayloadTrait.self) && structure.hasTrait(type: SotoOutputShapeTrait.self))
if !required, split.count > 1 {
split[0] += "?"
}
Expand Down Expand Up @@ -603,13 +630,17 @@ struct AwsService {
guard let service = $0.services[self.serviceEndpointPrefix] else { return }
guard let partitionEndpoint = service.partitionEndpoint else { return }
guard let endpoint = service.endpoints[partitionEndpoint] else {
self.logger.error("Partition endpoint \(partitionEndpoint) for service \(self.serviceEndpointPrefix) in \($0.partitionName) does not exist")
self.logger.error(
"Partition endpoint \(partitionEndpoint) for service \(self.serviceEndpointPrefix) in \($0.partitionName) does not exist"
)
return
}
guard let region = endpoint.credentialScope?.region else {
// services with SigV4 authentication require an endpoint
if self.service.trait(type: AwsAuthSigV4Trait.self) != nil {
self.logger.error("Partition endpoint \(partitionEndpoint) for service \(self.serviceEndpointPrefix) in \($0.partitionName) has no credential scope region")
self.logger.error(
"Partition endpoint \(partitionEndpoint) for service \(self.serviceEndpointPrefix) in \($0.partitionName) has no credential scope region"
)
}
return
}
Expand Down Expand Up @@ -637,17 +668,24 @@ struct AwsService {
.sorted()
.joined(separator: ", ")
// get dnsSuffix for this variant
guard let dnsSuffix = getDefaultValue(partition: partition, service: service, getValue: { defaults in
return defaults.variants?.first(where: { $0.tags == variant.tags })?.dnsSuffix
}) else {
guard
let dnsSuffix = getDefaultValue(
partition: partition,
service: service,
getValue: { defaults in
defaults.variants?.first(where: { $0.tags == variant.tags })?.dnsSuffix
}
)
else {
continue
}
if variantEndpoints[variantString] == nil {
variantEndpoints[variantString] = .init()
}
if let hostname = variant.hostname {
// get hostname and replace any variables (wrapped in {}) in hostname
let finalHostname = hostname
let finalHostname =
hostname
.replacingOccurrences(of: "{region}", with: endpoint.key)
.replacingOccurrences(of: "{dnsSuffix}", with: dnsSuffix)
.replacingOccurrences(of: "{service}", with: self.serviceEndpointPrefix)
Expand All @@ -659,7 +697,7 @@ struct AwsService {
}
// return variants with endpoints sorted by region name
return variantEndpoints.mapValues {
return .init(defaultEndpoint: $0.defaultEndpoint, endpoints: $0.endpoints.sorted { $0.region < $1.region })
.init(defaultEndpoint: $0.defaultEndpoint, endpoints: $0.endpoints.sorted { $0.region < $1.region })
}
}

Expand Down Expand Up @@ -695,12 +733,9 @@ struct AwsService {
}

func isMemberInBody(_ member: MemberShape, isOutputShape: Bool) -> Bool {
return !(member.hasTrait(type: HttpHeaderTrait.self) ||
member.hasTrait(type: HttpPrefixHeadersTrait.self) ||
(member.hasTrait(type: HttpQueryTrait.self) && !isOutputShape) ||
member.hasTrait(type: HttpQueryParamsTrait.self) ||
member.hasTrait(type: HttpLabelTrait.self) ||
member.hasTrait(type: HttpResponseCodeTrait.self))
!(member.hasTrait(type: HttpHeaderTrait.self) || member.hasTrait(type: HttpPrefixHeadersTrait.self)
|| (member.hasTrait(type: HttpQueryTrait.self) && !isOutputShape) || member.hasTrait(type: HttpQueryParamsTrait.self)
|| member.hasTrait(type: HttpLabelTrait.self) || member.hasTrait(type: HttpResponseCodeTrait.self))
}
}

Expand All @@ -711,11 +746,19 @@ extension AwsService {
let reason: String
}

struct TaskLocalParameters {
struct Parameter {
let key: String
let value: Any
}
let taskLocalName: String
let taskLocalParams: [Parameter]
}

struct OperationContext {
struct DiscoverableEndpoint {
let required: Bool
}

let comment: [String.SubSequence]
let funcName: String
let inputShape: String?
Expand All @@ -729,6 +772,7 @@ extension AwsService {
let documentationUrl: String?
let endpointRequired: DiscoverableEndpoint?
var initParameters: [OperationInitParamContext]
let taskLocals: TaskLocalParameters?
}

struct OperationInitParamContext {
Expand Down Expand Up @@ -955,6 +999,6 @@ extension AwsService {
case jmesAllPath(path: String, expected: String)
case error(String)
case errorStatus(Int)
case success(Int) // Success requires a dummy associated value, so a mustache context is created for the `MatcherContext`
case success(Int) // Success requires a dummy associated value, so a mustache context is created for the `MatcherContext`
}
}
Loading

0 comments on commit 702b964

Please sign in to comment.