diff --git a/Sources/ApodiniContext/Context.swift b/Sources/ApodiniContext/Context.swift index 5871275..ff45013 100644 --- a/Sources/ApodiniContext/Context.swift +++ b/Sources/ApodiniContext/Context.swift @@ -11,9 +11,12 @@ import OrderedCollections private class ContextBox { var entries: [ObjectIdentifier: StoredContextValue] + /// Mapping from ``CodableContextKey/identifier`` to base64 encoded data + var decodedEntries: [String: String] - init(_ entries: [ObjectIdentifier: StoredContextValue]) { + init(_ entries: [ObjectIdentifier: StoredContextValue], _ decodedEntries: [String: String]) { self.entries = entries + self.decodedEntries = decodedEntries } } @@ -26,15 +29,17 @@ struct StoredContextValue { /// A `Context` holds a collection of values for predefined `ContextKey`s or `OptionalContextKey`s. public struct Context: ContextKeyRetrievable { private var boxedEntries: ContextBox + private var entries: [ObjectIdentifier: StoredContextValue] { boxedEntries.entries } - /// Mapping from ``CodableContextKey/identifier`` to base64 encoded data - private let decodedEntries: [String: String] + + private var decodedEntries: [String: String] { + boxedEntries.decodedEntries + } init(_ entries: [ObjectIdentifier: StoredContextValue] = [:], _ decodedEntries: [String: String] = [:]) { - self.boxedEntries = ContextBox(entries) - self.decodedEntries = decodedEntries + self.boxedEntries = ContextBox(entries, decodedEntries) } /// Create a new empty ``Context``. @@ -97,14 +102,21 @@ public struct Context: ContextKeyRetrievable { precondition(entries[key] == nil || allowOverwrite, "Cannot overwrite existing ContextKey entry with `unsafeAdd`: \(C.self): \(value)") if let codableContextKey = contextKey as? AnyCodableContextKey.Type { // we need to prevent this. as Otherwise we would need to handle merging this stuff which get really complex - precondition(decodedEntries[codableContextKey.identifier] == nil, "Cannot overwrite existing CodableContextKey entry with `unsafeAdd`: \(C.self): \(value)") + precondition( + decodedEntries[codableContextKey.identifier] == nil || allowOverwrite, + "Cannot overwrite existing CodableContextKey entry with `unsafeAdd`: \(C.self): \(value)" + ) + + // if we reach this point, either the key doesn't exist or `allowOverwrite` was turned on + // and we need to remove the existing entry in order to properly overwrite everything + boxedEntries.decodedEntries.removeValue(forKey: codableContextKey.identifier) } boxedEntries.entries[key] = StoredContextValue(key: contextKey, value: value) } private func checkForDecodedEntries(for key: Key.Type = Key.self) -> Key.Value? { - guard let dataValue = decodedEntries[Key.identifier] else { + guard let dataValue = boxedEntries.decodedEntries.removeValue(forKey: Key.identifier) else { return nil } @@ -140,15 +152,13 @@ extension Context: Codable { public init(from decoder: Decoder) throws { let container = try decoder.container(keyedBy: StringContextKey.self) - self.boxedEntries = ContextBox([:]) - var decodedEntries: [String: String] = [:] for key in container.allKeys { decodedEntries[key.stringValue] = try container.decode(String.self, forKey: key) } - self.decodedEntries = decodedEntries + self.boxedEntries = ContextBox([:], decodedEntries) } public func encode(to encoder: Encoder) throws { diff --git a/Tests/ApodiniContextTests/ContextKeyTests.swift b/Tests/ApodiniContextTests/ContextKeyTests.swift index 1e16a20..9eb64b4 100644 --- a/Tests/ApodiniContextTests/ContextKeyTests.swift +++ b/Tests/ApodiniContextTests/ContextKeyTests.swift @@ -118,7 +118,7 @@ class ContextKeyTests: XCTestCase { XCTAssertEqual(decodedContext.get(valueFor: CodableArrayStringContextKey.self), ["Hello Sun"]) } - func testUnsafeAddAllowingOverwrite() { + func testUnsafeAddAllowingOverwrite() throws { struct CodableStringContextKey: CodableContextKey { typealias Value = String } @@ -129,5 +129,20 @@ class ContextKeyTests: XCTestCase { XCTAssertEqual(context.get(valueFor: CodableStringContextKey.self), "Hello World") context.unsafeAdd(CodableStringContextKey.self, value: "Hello Mars", allowOverwrite: true) XCTAssertEqual(context.get(valueFor: CodableStringContextKey.self), "Hello Mars") + + let encoder = FineJSONEncoder() + encoder.jsonSerializeOptions = .init(isPrettyPrint: false) + let decoder = FineJSONDecoder() + + let encodedContext = try encoder.encode(context) + XCTAssertEqual( + String(data: encodedContext, encoding: .utf8), + "{\"CodableStringContextKey\":\"IkhlbGxvIE1hcnMi\"}" + ) + + let decodedContext = try decoder.decode(Context.self, from: encodedContext) + + decodedContext.unsafeAdd(CodableStringContextKey.self, value: "Hello Saturn", allowOverwrite: true) + XCTAssertEqual(decodedContext.get(valueFor: CodableStringContextKey.self), "Hello Saturn") } }