From 7d579b0842518b44675323448ea6adac56d21edf Mon Sep 17 00:00:00 2001 From: Jason Phan Date: Fri, 21 Apr 2023 21:57:12 -0500 Subject: [PATCH 1/4] attrs: Add tag container attribute --- src/attributes.zig | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/attributes.zig b/src/attributes.zig index 1e300b5f..35fcb9ae 100644 --- a/src/attributes.zig +++ b/src/attributes.zig @@ -29,6 +29,12 @@ pub const Case = enum { screaming_kebab, }; +pub const Tag = enum { + external, + internal, + untagged, +}; + /// Returns an attribute map type. pub fn Attributes(comptime T: type, comptime attributes: anytype) type { const type_name = @typeName(T); @@ -171,9 +177,8 @@ const ContainerAttributes = struct { // convention. //rename_all: ?Case = null, - // Use the internally tagged enum representation for this enum, with - // the given tag. - //tag: ?[]const u8 = null, + // Use the specified representation for this union. + tag: Tag = .external, // Deserialize this type by deserializing into the given type, then // converting fallibly. From 30cbd6a27334abba8f8c56fb514fb157ed580c69 Mon Sep 17 00:00:00 2001 From: Jason Phan Date: Fri, 21 Apr 2023 21:57:32 -0500 Subject: [PATCH 2/4] ser: Refactor union block for tags --- src/ser/blocks/union.zig | 137 ++++++++++++++++++++++++++++++--------- 1 file changed, 107 insertions(+), 30 deletions(-) diff --git a/src/ser/blocks/union.zig b/src/ser/blocks/union.zig index 5c2f1bb0..b5fa69b6 100644 --- a/src/ser/blocks/union.zig +++ b/src/ser/blocks/union.zig @@ -1,7 +1,9 @@ const std = @import("std"); const getAttributes = @import("../attributes.zig").getAttributes; +const getty_serialize = @import("../serialize.zig").serialize; const t = @import("../testing.zig"); +const Tag = @import("../../attributes.zig").Tag; /// Specifies all types that can be serialized by this block. pub fn is( @@ -20,11 +22,8 @@ pub fn serialize( /// A `getty.Serializer` interface value. serializer: anytype, ) @TypeOf(serializer).Error!@TypeOf(serializer).Ok { - _ = allocator; - const T = @TypeOf(value); const info = @typeInfo(T).Union; - const attributes = comptime getAttributes(T, @TypeOf(serializer)); if (info.tag_type == null) { @compileError(std.fmt.comptimePrint("untagged unions cannot be serialized: {s}", .{@typeName(T)})); @@ -38,44 +37,90 @@ pub fn serialize( const tag_matches = value == @field(T, field.name); if (tag_matches) { - const attrs = comptime blk: { - if (attributes) |attrs| { - if (@hasField(@TypeOf(attrs), field.name)) { - const a = @field(attrs, field.name); - const A = @TypeOf(a); - - break :blk @as(?A, a); - } - } + return try serializeVariant(allocator, value, serializer, field); + } + } - break :blk null; - }; + // UNREACHABLE: We've already checked that the union has a tag, meaning + // that the above for loop will always enter its top-level if block, which + // always returns from this function. + unreachable; +} + +fn serializeVariant( + allocator: ?std.mem.Allocator, + value: anytype, + serializer: anytype, + comptime field: std.builtin.Type.UnionField, +) @TypeOf(serializer).Error!@TypeOf(serializer).Ok { + const tag: Tag = comptime blk: { + const attributes = getAttributes(@TypeOf(value), @TypeOf(serializer)); - if (attrs) |a| { - const skipped = @hasField(@TypeOf(a), "skip") and a.skip; - if (skipped) return error.UnknownVariant; + if (attributes) |attrs| { + if (@hasField(@TypeOf(attrs), "Container")) { + if (@hasField(@TypeOf(attrs.Container), "tag")) { + break :blk attrs.Container.tag; + } } + } - var m = try serializer.serializeMap(1); - const map = m.map(); + break :blk .external; + }; - comptime var name = field.name; + return switch (tag) { + .external => try serializeExternallyTaggedVariant(value, serializer, field), + .untagged => try serializeUntaggedVariant(allocator, value, serializer, field), + .internal => @compileError("TODO: internally tagged representation"), + }; +} - if (attrs) |a| { - const renamed = @hasField(@TypeOf(a), "rename"); - if (renamed) name = a.rename; - } +fn serializeExternallyTaggedVariant( + value: anytype, + serializer: anytype, + comptime field: std.builtin.Type.UnionField, +) @TypeOf(serializer).Error!@TypeOf(serializer).Ok { + const attrs = comptime blk: { + const attributes = getAttributes(@TypeOf(value), @TypeOf(serializer)); - try map.serializeEntry(name, @field(value, field.name)); + if (attributes) |attrs| { + if (@hasField(@TypeOf(attrs), field.name)) { + const a = @field(attrs, field.name); + const A = @TypeOf(a); - return try map.end(); + break :blk @as(?A, a); + } } + + break :blk null; + }; + + if (attrs) |a| { + const skipped = @hasField(@TypeOf(a), "skip") and a.skip; + if (skipped) return error.UnknownVariant; } - // UNREACHABLE: We've already checked that the union has a tag, meaning - // that the above for loop will always enter its top-level if block, which - // always returns from this function. - unreachable; + var m = try serializer.serializeMap(1); + const map = m.map(); + + comptime var name = field.name; + + if (attrs) |a| { + const renamed = @hasField(@TypeOf(a), "rename"); + if (renamed) name = a.rename; + } + + try map.serializeEntry(name, @field(value, field.name)); + + return try map.end(); +} + +fn serializeUntaggedVariant( + allocator: ?std.mem.Allocator, + value: anytype, + serializer: anytype, + comptime field: std.builtin.Type.UnionField, +) @TypeOf(serializer).Error!@TypeOf(serializer).Ok { + return getty_serialize(allocator, @field(value, field.name), serializer); } test "serialize - union" { @@ -144,3 +189,35 @@ test "serialize - union, attributes (skip)" { .{ .MapEnd = {} }, }); } + +test "serialize - union, attributes (tag, untagged)" { + const T = union(enum) { + Int: i32, + Bool: bool, + Union: union(enum) { + Int: i32, + Bool: bool, + }, + + pub const @"getty.sb" = struct { + pub const attributes = .{ + .Container = .{ .tag = .untagged }, + }; + }; + }; + + try t.run(null, serialize, T{ .Int = 0 }, &.{.{ .I32 = 0 }}); + try t.run(null, serialize, T{ .Bool = true }, &.{.{ .Bool = true }}); + try t.run(null, serialize, T{ .Union = .{ .Int = 0 } }, &.{ + .{ .Map = .{ .len = 1 } }, + .{ .String = "Int" }, + .{ .I32 = 0 }, + .{ .MapEnd = {} }, + }); + try t.run(null, serialize, T{ .Union = .{ .Bool = true } }, &.{ + .{ .Map = .{ .len = 1 } }, + .{ .String = "Bool" }, + .{ .Bool = true }, + .{ .MapEnd = {} }, + }); +} From d20b78a529a3fcae247c05dbd58a4da4528c9134 Mon Sep 17 00:00:00 2001 From: Jason Phan Date: Fri, 21 Apr 2023 22:15:39 -0500 Subject: [PATCH 3/4] attrs: Document the Tag declaration --- src/attributes.zig | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/attributes.zig b/src/attributes.zig index 35fcb9ae..852dbf68 100644 --- a/src/attributes.zig +++ b/src/attributes.zig @@ -2,7 +2,7 @@ const std = @import("std"); const comptimePrint = std.fmt.comptimePrint; const Type = std.builtin.Type; -/// Case conventions for the `rename_all` attribute. +/// Case conventions. pub const Case = enum { // foobar lower, @@ -29,6 +29,7 @@ pub const Case = enum { screaming_kebab, }; +/// Tag representations for union variants. pub const Tag = enum { external, internal, From 8866e9839ef7e03498d587edeba7ff2eac0eb29e Mon Sep 17 00:00:00 2001 From: Jason Phan Date: Sat, 6 May 2023 10:40:29 -0500 Subject: [PATCH 4/4] de: Add support for unions with untagged attribute --- src/de/blocks/struct.zig | 6 +- src/de/blocks/union.zig | 720 ++++++++++++++++++++++++++++++++++++++- src/de/deserialize.zig | 20 +- src/de/testing.zig | 100 ++---- 4 files changed, 743 insertions(+), 103 deletions(-) diff --git a/src/de/blocks/struct.zig b/src/de/blocks/struct.zig index 4bc007d0..d7d1b85b 100644 --- a/src/de/blocks/struct.zig +++ b/src/de/blocks/struct.zig @@ -91,7 +91,7 @@ test "deserialize - struct" { inline for (tests) |t| { const Want = @TypeOf(t.want); - const got = try testing.deserialize(null, t.name, Self, Want, t.tokens); + const got = try testing.deserialize(std.testing.allocator, t.name, Self, Want, t.tokens); try testing.expectEqual(t.name, t.want, got); } } @@ -252,12 +252,12 @@ test "deserialize - struct, attributes" { try testing.expectError( t.name, t.want_err, - testing.deserializeErr(null, Self, Want, t.tokens), + testing.deserializeErr(std.testing.allocator, Self, Want, t.tokens), ); } else { const Want = @TypeOf(t.want); - const got = try testing.deserialize(null, t.name, Self, Want, t.tokens); + const got = try testing.deserialize(std.testing.allocator, t.name, Self, Want, t.tokens); try testing.expectEqual(t.name, t.want, got); } } diff --git a/src/de/blocks/union.zig b/src/de/blocks/union.zig index bf191d3f..25d23f1b 100644 --- a/src/de/blocks/union.zig +++ b/src/de/blocks/union.zig @@ -1,10 +1,18 @@ const std = @import("std"); +const DeserializerInterface = @import("../interfaces/deserializer.zig").Deserializer; +const getAttributes = @import("../attributes.zig").getAttributes; +const getty_deserialize = @import("../deserialize.zig").deserialize; +const getty_error = @import("../error.zig").Error; const getty_free = @import("../free.zig").free; -const UnionVisitor = @import("../impls/visitor/union.zig").Visitor; +const MapAccessInterface = @import("../interfaces/map_access.zig").MapAccess; +const SeqAccessInterface = @import("../interfaces/seq_access.zig").SeqAccess; +const Tag = @import("../../attributes.zig").Tag; const testing = @import("../testing.zig"); - -const Self = @This(); +const UnionAccessInterface = @import("../interfaces/union_access.zig").UnionAccess; +const UnionVisitor = @import("../impls/visitor/union.zig").Visitor; +const VariantAccessInterface = @import("../interfaces/variant_access.zig").VariantAccess; +const VisitorInterface = @import("../interfaces/visitor.zig").Visitor; /// Specifies all types that can be deserialized by this block. pub fn is( @@ -24,12 +32,560 @@ pub fn deserialize( deserializer: anytype, /// A `getty.de.Visitor` interface value. visitor: anytype, -) !@TypeOf(visitor).Value { - _ = T; +) @TypeOf(deserializer).Error!@TypeOf(visitor).Value { + const tag: Tag = comptime blk: { + const attributes = getAttributes(T, @TypeOf(deserializer)); + + if (attributes) |attrs| { + if (@hasField(@TypeOf(attrs), "Container")) { + if (@hasField(@TypeOf(attrs.Container), "tag")) { + break :blk attrs.Container.tag; + } + } + } + + break :blk .external; + }; + + return switch (tag) { + .external => try deserializeExternallyTaggedVariant(allocator, deserializer, visitor), + .untagged => try deserializeUntaggedVariant(allocator, T, deserializer, visitor), + .internal => @compileError("TODO: internally tagged representation"), + }; +} +fn deserializeExternallyTaggedVariant( + allocator: ?std.mem.Allocator, + deserializer: anytype, + visitor: anytype, +) @TypeOf(deserializer).Error!@TypeOf(visitor).Value { return try deserializer.deserializeUnion(allocator, visitor); } +// Untagged unions are only supported in self-describing formats. +fn deserializeUntaggedVariant( + allocator: ?std.mem.Allocator, + comptime T: type, + deserializer: anytype, + visitor: anytype, +) @TypeOf(deserializer).Error!@TypeOf(visitor).Value { + // Deserialize the input data into a Content value. + // + // This intermediate value allows us to repeatedly attempt deserialization + // for each variant of the untagged union, without further modifying the + // actual input data of the deserializer. + var content = try getty_deserialize(allocator, Content, deserializer); + defer switch (content) { + .Int, .Map, .Seq, .String, .Some => { + // If content was successfully deserialized, and we're here, then + // that means allocator must've not been null. + std.debug.assert(allocator != null); + getty_free(allocator.?, @TypeOf(deserializer), content); + }, + else => {}, + }; + + // Deserialize the Content value into a value of type T. + var cd = ContentDeserializer{ .content = content }; + const d = cd.deserializer(); + + inline for (std.meta.fields(T)) |field| { + if (getty_deserialize(allocator, field.type, d)) |value| { + return @unionInit(T, field.name, value); + } else |err| switch (err) { + error.DuplicateField, + error.InvalidLength, + error.InvalidType, + error.MissingField, + error.MissingVariant, + error.UnknownField, + error.UnknownVariant, + => {}, + else => return err, + } + } + + return error.MissingVariant; +} + +const ContentMap = struct { + key: Content, + value: Content, +}; + +const ContentDeserializerMap = struct { + key: ContentDeserializer, + value: ContentDeserializer, +}; + +const ContentMultiArrayList = std.MultiArrayList(ContentMap); +const ContentDeserializerMultiArrayList = std.MultiArrayList(ContentDeserializerMap); + +// Does not support compile-time known types. +const Content = union(enum) { + Bool: bool, + F16: f16, + F32: f32, + F64: f64, + F128: f128, + Int: std.math.big.int.Managed, + Map: ContentMultiArrayList, + Null, + Seq: std.ArrayList(Content), + Some: *Content, + String: []const u8, + Void, + + pub fn deinit(self: Content, allocator: std.mem.Allocator) void { + switch (self) { + .Int => |v| { + var mut = v; + mut.deinit(); + }, + .Seq => |v| { + for (v.items) |elem| elem.deinit(allocator); + v.deinit(); + }, + .Map => |v| { + for (v.items(.key), v.items(.value)) |key, value| { + key.deinit(allocator); + value.deinit(allocator); + } + var mut = v; + mut.deinit(allocator); + }, + .String => |v| allocator.free(v), + .Some => |v| { + v.deinit(allocator); + allocator.destroy(v); + }, + else => {}, + } + } + + pub const @"getty.db" = struct { + pub fn deserialize( + allocator: ?std.mem.Allocator, + comptime _: type, + deserializer: anytype, + visitor: anytype, + ) !@TypeOf(visitor).Value { + return try deserializer.deserializeAny(allocator, visitor); + } + + pub fn Visitor(comptime _: type) type { + return struct { + pub usingnamespace VisitorInterface( + @This(), + Content, + .{ + .visitBool = visitBool, + .visitFloat = visitFloat, + .visitInt = visitInt, + .visitMap = visitMap, + .visitNull = visitNull, + .visitSeq = visitSeq, + .visitSome = visitSome, + .visitString = visitString, + .visitUnion = visitUnion, + .visitVoid = visitVoid, + }, + ); + + fn visitBool(_: @This(), _: ?std.mem.Allocator, comptime Deserializer: type, input: bool) Deserializer.Error!Content { + return .{ .Bool = input }; + } + + fn visitFloat(_: @This(), _: ?std.mem.Allocator, comptime Deserializer: type, input: anytype) Deserializer.Error!Content { + return switch (@TypeOf(input)) { + f16 => .{ .F16 = input }, + f32 => .{ .F32 = input }, + f64 => .{ .F64 = input }, + f128 => .{ .F128 = input }, + comptime_float => @compileError("comptime_float is not supported"), + else => unreachable, // UNREACHABLE: The Visitor interface guarantees that input is a float. + }; + } + + fn visitInt(_: @This(), allocator: ?std.mem.Allocator, comptime Deserializer: type, input: anytype) Deserializer.Error!Content { + if (allocator == null) { + return error.MissingAllocator; + } + + return switch (@typeInfo(@TypeOf(input))) { + .Int => .{ .Int = try std.math.big.int.Managed.initSet(allocator.?, input) }, + .ComptimeInt => @compileError("comptime_int is not supported"), + else => unreachable, // UNREACHABLE: The Visitor interface guarantees that input is an integer. + }; + } + + fn visitMap(_: @This(), allocator: ?std.mem.Allocator, comptime Deserializer: type, mapAccess: anytype) Deserializer.Error!Content { + if (allocator == null) { + return error.MissingAllocator; + } + + var map = ContentMultiArrayList{}; + errdefer map.deinit(allocator.?); + + while (try mapAccess.nextKey(allocator.?, Content)) |key| { + errdefer if (mapAccess.isKeyAllocated(@TypeOf(key))) { + getty_free(allocator.?, Deserializer, key); + }; + + const value = try mapAccess.nextValue(allocator, Content); + errdefer getty_free(allocator.?, Deserializer, value); + + try map.append(allocator.?, .{ + .key = key, + .value = value, + }); + } + + return .{ .Map = map }; + } + + fn visitNull(_: @This(), _: ?std.mem.Allocator, comptime Deserializer: type) Deserializer.Error!Content { + return .{ .Null = {} }; + } + + fn visitSeq(_: @This(), allocator: ?std.mem.Allocator, comptime Deserializer: type, seqAccess: anytype) Deserializer.Error!Content { + if (allocator == null) { + return error.MissingAllocator; + } + + var list = std.ArrayList(Content).init(allocator.?); + errdefer list.deinit(); + + while (try seqAccess.nextElement(allocator.?, Content)) |elem| { + try list.append(elem); + } + + return .{ .Seq = list }; + } + + fn visitSome(_: @This(), allocator: ?std.mem.Allocator, deserializer: anytype) @TypeOf(deserializer).Error!Content { + return .{ .Some = try getty_deserialize(allocator, *Content, deserializer) }; + } + + fn visitString(_: @This(), allocator: ?std.mem.Allocator, comptime Deserializer: type, input: anytype) Deserializer.Error!Content { + const output = try allocator.?.alloc(u8, input.len); + std.mem.copy(u8, output, input); + + return .{ .String = output }; + } + + fn visitUnion(_: @This(), allocator: ?std.mem.Allocator, comptime Deserializer: type, ua: anytype, va: anytype) Deserializer.Error!Content { + if (allocator == null) { + return error.MissingAllocator; + } + + var variant = try ua.variant(allocator, Content); + errdefer if (ua.isVariantAllocated(@TypeOf(variant))) { + getty_free(allocator.?, Deserializer, variant); + }; + + var payload = try va.payload(allocator.?, Content); + errdefer getty_free(allocator.?, Deserializer, payload); + + var map = ContentMultiArrayList{}; + errdefer map.deinit(allocator.?); + + try map.append(allocator.?, .{ + .key = variant, + .value = payload, + }); + + return .{ .Map = map }; + } + + fn visitVoid(_: @This(), _: ?std.mem.Allocator, comptime Deserializer: type) Deserializer.Error!Content { + return .{ .Void = {} }; + } + }; + } + + pub fn free(allocator: std.mem.Allocator, comptime _: type, value: anytype) void { + switch (value) { + .Int, .Map, .Seq, .String, .Some => value.deinit(allocator), + else => {}, + } + } + }; +}; + +const ContentDeserializer = struct { + content: Content, + + const Self = @This(); + + pub usingnamespace DeserializerInterface( + Self, + getty_error, + null, + null, + .{ + .deserializeAny = deserializeAny, + .deserializeBool = deserializeBool, + .deserializeEnum = deserializeEnum, + .deserializeFloat = deserializeFloat, + .deserializeInt = deserializeInt, + .deserializeIgnored = deserializeIgnored, + .deserializeMap = deserializeMap, + .deserializeOptional = deserializeOptional, + .deserializeSeq = deserializeSeq, + .deserializeString = deserializeString, + .deserializeStruct = deserializeMap, + .deserializeUnion = deserializeUnion, + .deserializeVoid = deserializeVoid, + }, + ); + + const De = Self.@"getty.Deserializer"; + + fn deserializeAny(self: Self, allocator: ?std.mem.Allocator, visitor: anytype) getty_error!@TypeOf(visitor).Value { + return switch (self.content) { + .Bool => |v| try visitor.visitBool(allocator, De, v), + inline .F16, .F32, .F64, .F128 => |v| try visitor.visitFloat(allocator, De, v), + .Int => |v| blk: { + comptime var Value = @TypeOf(visitor).Value; + + if (@typeInfo(Value) == .Int) { + break :blk try visitor.visitInt(allocator, De, v.to(Value) catch unreachable); + } + + if (v.isPositive()) { + break :blk try visitor.visitInt(allocator, De, v.to(u128) catch return error.InvalidValue); + } else { + break :blk try visitor.visitInt(allocator, De, v.to(i128) catch return error.InvalidValue); + } + }, + .Map => |v| try visitContentMap(allocator, v, visitor), + .Null => try visitor.visitNull(allocator, De), + .Seq => |v| try visitContentSeq(allocator, v, visitor), + .Some => |v| blk: { + var cd = Self{ .content = v.* }; + break :blk try visitor.visitSome(allocator, cd.deserializer()); + }, + .String => |v| try visitor.visitString(allocator, De, v), + .Void => try visitor.visitVoid(allocator, De), + }; + } + + fn deserializeBool(self: Self, allocator: ?std.mem.Allocator, visitor: anytype) getty_error!@TypeOf(visitor).Value { + return switch (self.content) { + .Bool => |v| try visitor.visitBool(allocator, De, v), + else => error.InvalidType, + }; + } + + fn deserializeEnum(self: Self, allocator: ?std.mem.Allocator, visitor: anytype) getty_error!@TypeOf(visitor).Value { + return switch (self.content) { + .Int => |v| blk: { + const int = v.to(@TypeOf(visitor).Value) catch unreachable; + break :blk try visitor.visitInt(allocator, De, int); + }, + .String => |v| try visitor.visitString(allocator, De, v), + else => error.InvalidType, + }; + } + + fn deserializeFloat(self: Self, allocator: ?std.mem.Allocator, visitor: anytype) getty_error!@TypeOf(visitor).Value { + return switch (self.content) { + inline .F16, .F32, .F64, .F128 => |v| try visitor.visitFloat(allocator, De, v), + else => error.InvalidType, + }; + } + + fn deserializeIgnored(_: Self, allocator: ?std.mem.Allocator, visitor: anytype) getty_error!@TypeOf(visitor).Value { + return try visitor.visitVoid(allocator, De); + } + + fn deserializeInt(self: Self, allocator: ?std.mem.Allocator, visitor: anytype) getty_error!@TypeOf(visitor).Value { + return switch (self.content) { + .Int => |v| blk: { + const int = v.to(@TypeOf(visitor).Value) catch unreachable; + break :blk try visitor.visitInt(allocator, De, int); + }, + else => error.InvalidType, + }; + } + + fn deserializeMap(self: Self, allocator: ?std.mem.Allocator, visitor: anytype) getty_error!@TypeOf(visitor).Value { + return switch (self.content) { + .Map => |v| try visitContentMap(allocator, v, visitor), + else => error.InvalidType, + }; + } + + fn deserializeOptional(self: Self, allocator: ?std.mem.Allocator, visitor: anytype) getty_error!@TypeOf(visitor).Value { + return switch (self.content) { + .Null => try visitor.visitNull(allocator, De), + .Some => |v| blk: { + var cd = Self{ .content = v.* }; + break :blk try visitor.visitSome(allocator, cd.deserializer()); + }, + else => error.InvalidType, + }; + } + + fn deserializeSeq(self: Self, allocator: ?std.mem.Allocator, visitor: anytype) getty_error!@TypeOf(visitor).Value { + return switch (self.content) { + .Seq => |v| try visitContentSeq(allocator, v, visitor), + else => error.InvalidType, + }; + } + + fn deserializeString(self: Self, allocator: ?std.mem.Allocator, visitor: anytype) getty_error!@TypeOf(visitor).Value { + return switch (self.content) { + .String => |v| try visitor.visitString(allocator, De, v), + else => error.InvalidType, + }; + } + + fn deserializeUnion(self: Self, allocator: ?std.mem.Allocator, visitor: anytype) getty_error!@TypeOf(visitor).Value { + return switch (self.content) { + .Map => |mal| blk: { + const keys = mal.items(.key); + const values = mal.items(.value); + + if (mal.len != 1 or keys.len != 1 or values.len != 1) { + return error.InvalidValue; + } + + var uva = UnionVariantAccess{ .key = keys[0], .value = values[0] }; + const ua = uva.unionAccess(); + const va = uva.variantAccess(); + + break :blk try visitor.visitUnion(allocator, De, ua, va); + }, + else => error.InvalidType, + }; + } + + fn deserializeVoid(self: Self, allocator: ?std.mem.Allocator, visitor: anytype) getty_error!@TypeOf(visitor).Value { + return switch (self.content) { + .Void => try visitor.visitVoid(allocator, De), + else => error.InvalidType, + }; + } + + fn visitContentMap(allocator: ?std.mem.Allocator, content: ContentMultiArrayList, visitor: anytype) getty_error!@TypeOf(visitor).Value { + if (allocator == null) { + return error.MissingAllocator; + } + + var map = ContentDeserializerMultiArrayList{}; + try map.ensureTotalCapacity(allocator.?, content.len); + defer map.deinit(allocator.?); + + for (content.items(.key), content.items(.value)) |k, v| { + map.appendAssumeCapacity(.{ + .key = ContentDeserializer{ .content = k }, + .value = ContentDeserializer{ .content = v }, + }); + } + + var ma = MapAccess{ .deserializers = map }; + return try visitor.visitMap(allocator.?, De, ma.mapAccess()); + } + + fn visitContentSeq(allocator: ?std.mem.Allocator, content: std.ArrayList(Content), visitor: anytype) getty_error!@TypeOf(visitor).Value { + if (allocator == null) { + return error.MissingAllocator; + } + + var seq = try std.ArrayList(ContentDeserializer).initCapacity(allocator.?, content.items.len); + defer seq.deinit(); + + for (content.items) |c| { + seq.appendAssumeCapacity(ContentDeserializer{ .content = c }); + } + + var sa = SeqAccess{ .deserializers = seq }; + return try visitor.visitSeq(allocator.?, De, sa.seqAccess()); + } +}; + +const MapAccess = struct { + pos: u64 = 0, + deserializers: ContentDeserializerMultiArrayList, + + pub usingnamespace MapAccessInterface( + *@This(), + getty_error, + .{ + .nextKeySeed = nextKeySeed, + .nextValueSeed = nextValueSeed, + }, + ); + + fn nextKeySeed(self: *@This(), allocator: ?std.mem.Allocator, seed: anytype) getty_error!?@TypeOf(seed).Value { + if (self.pos >= self.deserializers.items(.key).len) { + return null; + } + + var d = self.deserializers.items(.key)[self.pos]; + return try seed.deserialize(allocator, d.deserializer()); + } + + fn nextValueSeed(self: *@This(), allocator: ?std.mem.Allocator, seed: anytype) getty_error!@TypeOf(seed).Value { + var d = self.deserializers.items(.value)[self.pos]; + self.pos += 1; + return try seed.deserialize(allocator, d.deserializer()); + } +}; + +const SeqAccess = struct { + pos: u64 = 0, + deserializers: std.ArrayList(ContentDeserializer), + + pub usingnamespace SeqAccessInterface( + *@This(), + getty_error, + .{ .nextElementSeed = nextElementSeed }, + ); + + fn nextElementSeed(self: *@This(), allocator: ?std.mem.Allocator, seed: anytype) getty_error!?@TypeOf(seed).Value { + if (self.pos >= self.deserializers.items.len) { + return null; + } + + var d = self.deserializers.items[self.pos]; + self.pos += 1; + + return try seed.deserialize(allocator, d.deserializer()); + } +}; + +const UnionVariantAccess = struct { + key: Content, + value: Content, + + const Self = @This(); + + pub usingnamespace UnionAccessInterface( + *Self, + getty_error, + .{ .variantSeed = variantSeed }, + ); + + pub usingnamespace VariantAccessInterface( + *Self, + getty_error, + .{ .payloadSeed = payloadSeed }, + ); + + fn variantSeed(self: *Self, allocator: ?std.mem.Allocator, seed: anytype) getty_error!@TypeOf(seed).Value { + var cd = ContentDeserializer{ .content = self.key }; + return try seed.deserialize(allocator, cd.deserializer()); + } + + fn payloadSeed(self: *Self, allocator: ?std.mem.Allocator, seed: anytype) getty_error!@TypeOf(seed).Value { + var cd = ContentDeserializer{ .content = self.value }; + return try seed.deserialize(allocator, cd.deserializer()); + } +}; + /// Returns a type that implements `getty.de.Visitor`. pub fn Visitor( /// The type being deserialized into. @@ -49,9 +605,9 @@ pub fn free( ) void { const info = @typeInfo(@TypeOf(value)).Union; - if (info.tag_type) |Tag| { + if (info.tag_type) |T| { inline for (info.fields) |field| { - if (value == @field(Tag, field.name)) { + if (value == @field(T, field.name)) { getty_free(allocator, Deserializer, @field(value, field.name)); break; } @@ -308,6 +864,152 @@ test "deserialize - union, attributes (skip)" { } } +test "deserialize - union, attributes (tag, untagged)" { + const WantTagged = union(enum) { + Bool: bool, + F32: f32, + I32: i32, + Optional: ?void, + Map: struct { A: i32, B: i32, C: i32 }, + Seq: [3]i32, + String: []const u8, + Union: union(enum) { foo: i32 }, + // NOTE: The variant in this union needs to be different than all the + // other variants in WantTagged. Otherwise, an earlier variant will be + // deserialized into. + UnionUntagged: union(enum) { + Bools: [2]bool, + + pub const @"getty.db" = struct { + pub const attributes = .{ .Container = .{ .tag = .untagged } }; + }; + }, + Void, + + pub const @"getty.db" = struct { + pub const attributes = .{ .Container = .{ .tag = .untagged } }; + }; + }; + + const tests = .{ + .{ + .name = "tagged, bool variant", + .tokens = &.{.{ .Bool = true }}, + .want = WantTagged{ .Bool = true }, + }, + .{ + .name = "tagged, float variant", + .tokens = &.{.{ .F32 = 3.14 }}, + .want = WantTagged{ .F32 = 3.14 }, + }, + .{ + .name = "tagged, int variant", + .tokens = &.{.{ .I32 = 123 }}, + .want = WantTagged{ .I32 = 123 }, + }, + .{ + .name = "tagged, map variant", + .tokens = &.{ + .{ .Map = .{ .len = 3 } }, + .{ .String = "A" }, + .{ .I32 = 1 }, + .{ .String = "B" }, + .{ .I32 = 2 }, + .{ .String = "C" }, + .{ .I32 = 3 }, + .{ .MapEnd = {} }, + }, + .want = WantTagged{ .Map = .{ .A = 1, .B = 2, .C = 3 } }, + }, + .{ + .name = "tagged, optional variant (null)", + .tokens = &.{.{ .Null = {} }}, + .want = WantTagged{ .Optional = null }, + }, + .{ + .name = "tagged, optional variant (some)", + .tokens = &.{ + .{ .Some = {} }, + .{ .Void = {} }, + }, + .want = WantTagged{ .Optional = {} }, + }, + .{ + .name = "tagged, sequence variant", + .tokens = &.{ + .{ .Seq = .{ .len = 3 } }, + .{ .I32 = 1 }, + .{ .I32 = 2 }, + .{ .I32 = 3 }, + .{ .SeqEnd = {} }, + }, + .want = WantTagged{ .Seq = [_]i32{ 1, 2, 3 } }, + }, + .{ + .name = "tagged, string variant", + .tokens = &.{.{ .String = "abcdef" }}, + .want = WantTagged{ .String = "abcdef" }, + }, + .{ + .name = "tagged, union variant", + .tokens = &.{ + .{ .Map = .{ .len = 1 } }, + .{ .String = "foo" }, + .{ .I32 = 1 }, + .{ .MapEnd = {} }, + }, + .want = WantTagged{ .Union = .{ .foo = 1 } }, + }, + .{ + .name = "tagged, union variant (untagged)", + .tokens = &.{ + .{ .Seq = .{ .len = 2 } }, + .{ .Bool = true }, + .{ .Bool = false }, + .{ .SeqEnd = {} }, + }, + .want = WantTagged{ .UnionUntagged = .{ .Bools = [_]bool{ true, false } } }, + }, + .{ + .name = "tagged, void variant", + .tokens = &.{.{ .Void = {} }}, + .want = WantTagged{ .Void = {} }, + }, + }; + + inline for (tests) |t| { + const Want = @TypeOf(t.want); + const Test = @TypeOf(t); + + if (@hasField(Test, "want_err")) { + try testing.expectError( + t.name, + t.want_err, + testing.deserializeErr(std.testing.allocator, @This(), Want, t.tokens), + ); + } else { + const got = try testing.deserialize(std.testing.allocator, t.name, @This(), Want, t.tokens); + + if (@typeInfo(@TypeOf(t.want)).Union.tag_type) |_| { + switch (t.want) { + .String => |want| { + defer std.testing.allocator.free(got.String); + try testing.expectEqualSlices(t.name, u8, want, got.String); + }, + else => |want| try testing.expectEqual(t.name, want, got), + } + } else { + if (comptime std.mem.eql(u8, t.tag, "String")) { + defer std.testing.allocator.free(got.String); + try testing.expectEqualSlices(t.name, u8, got.want, got.String); + } else { + try testing.expectEqual(t.name, t.want, @field(got, t.tag)); + } + } + } + } +} + fn runTest(t: anytype, comptime Want: type) !void { const Test = @TypeOf(t); @@ -315,10 +1017,10 @@ fn runTest(t: anytype, comptime Want: type) !void { try testing.expectError( t.name, t.want_err, - testing.deserializeErr(null, Self, Want, t.tokens), + testing.deserializeErr(std.testing.allocator, @This(), Want, t.tokens), ); } else { - const got = try testing.deserialize(null, t.name, Self, Want, t.tokens); + const got = try testing.deserialize(std.testing.allocator, t.name, @This(), Want, t.tokens); if (t.tagged) { try testing.expectEqual(t.name, t.want, got); diff --git a/src/de/deserialize.zig b/src/de/deserialize.zig index 128f5b42..c4543ed7 100644 --- a/src/de/deserialize.zig +++ b/src/de/deserialize.zig @@ -123,7 +123,7 @@ test "deserialize - success, normal" { .{ .StructEnd = {} }, }); - const got = deserialize(null, Point, d.deserializer()) catch return error.UnexpectedTestError; + const got = deserialize(std.testing.allocator, Point, d.deserializer()) catch return error.UnexpectedTestError; try expectEqual(expected, got); } @@ -136,7 +136,7 @@ test "deserialize - success, normal" { .{ .SeqEnd = {} }, }); - const got = deserialize(null, Point, d.deserializer()) catch return error.UnexpectedTestError; + const got = deserialize(std.testing.allocator, Point, d.deserializer()) catch return error.UnexpectedTestError; try expectEqual(expected, got); } @@ -149,7 +149,7 @@ test "deserialize - success, normal" { .{ .SeqEnd = {} }, }); - const got = deserialize(null, Point, d.deserializer()) catch return error.UnexpectedTestError; + const got = deserialize(std.testing.allocator, Point, d.deserializer()) catch return error.UnexpectedTestError; try expectEqual(expected, got); } @@ -162,7 +162,7 @@ test "deserialize - success, normal" { .{ .SeqEnd = {} }, }); - const got = deserialize(null, PointCustom, d.deserializer()) catch return error.UnexpectedTestError; + const got = deserialize(std.testing.allocator, PointCustom, d.deserializer()) catch return error.UnexpectedTestError; try expectEqual(expected_custom, got); } } @@ -208,7 +208,7 @@ test "deserialize - success, attributes" { .{ .StructEnd = {} }, }); - const got = deserialize(null, Point, d.deserializer()) catch return error.UnexpectedTestError; + const got = deserialize(std.testing.allocator, Point, d.deserializer()) catch return error.UnexpectedTestError; try expectEqual(expected, got); } @@ -223,7 +223,7 @@ test "deserialize - success, attributes" { .{ .StructEnd = {} }, }); - const got = deserialize(null, Point, d.deserializer()) catch return error.UnexpectedTestError; + const got = deserialize(std.testing.allocator, Point, d.deserializer()) catch return error.UnexpectedTestError; try expectEqual(expected, got); } @@ -238,7 +238,7 @@ test "deserialize - success, attributes" { .{ .StructEnd = {} }, }); - const got = deserialize(null, PointCustom, d.deserializer()) catch return error.UnexpectedTestError; + const got = deserialize(std.testing.allocator, PointCustom, d.deserializer()) catch return error.UnexpectedTestError; try expectEqual(expected_custom, got); } } @@ -294,7 +294,7 @@ test "deserialize - priority" { .{ .StructEnd = {} }, }); - const got = deserialize(null, Point, d.deserializer()) catch return error.UnexpectedTestError; + const got = deserialize(std.testing.allocator, Point, d.deserializer()) catch return error.UnexpectedTestError; try expectEqual(expected, got); } @@ -319,7 +319,7 @@ test "deserialize - priority" { .{ .StructEnd = {} }, }); - const got = deserialize(null, PointCustom, d.deserializer()) catch return error.UnexpectedTestError; + const got = deserialize(std.testing.allocator, PointCustom, d.deserializer()) catch return error.UnexpectedTestError; try expectEqual(expected, got); } @@ -344,7 +344,7 @@ test "deserialize - priority" { .{ .StructEnd = {} }, }); - const got = deserialize(null, PointInvalidCustom, d.deserializer()) catch return error.UnexpectedTestError; + const got = deserialize(std.testing.allocator, PointInvalidCustom, d.deserializer()) catch return error.UnexpectedTestError; try expectEqual(expected, got); } } diff --git a/src/de/testing.zig b/src/de/testing.zig index 5ef0be3a..7d27cca7 100644 --- a/src/de/testing.zig +++ b/src/de/testing.zig @@ -114,7 +114,7 @@ pub fn Deserializer(comptime user_dbt: anytype, comptime deserializer_dbt: anyty .deserializeFloat = deserializeAny, .deserializeInt = deserializeAny, .deserializeIgnored = deserializeIgnored, - .deserializeMap = deserializeMap, + .deserializeMap = deserializeAny, .deserializeOptional = deserializeAny, .deserializeSeq = deserializeAny, .deserializeString = deserializeAny, @@ -154,6 +154,15 @@ pub fn Deserializer(comptime user_dbt: anytype, comptime deserializer_dbt: anyty .String => |v| try visitor.visitString(allocator, De, v), else => |v| std.debug.panic("deserialization did not expect this token: {s}", .{@tagName(v)}), }, + .Map => |v| blk: { + var m = Map{ .de = self, .len = v.len, .end = .MapEnd }; + var value = try visitor.visitMap(allocator, De, m.mapAccess()); + + try expectEqual(@as(usize, 0), m.len.?); + try self.assertNextToken(.MapEnd); + + break :blk value; + }, .Null => try visitor.visitNull(allocator, De), .Some => try visitor.visitSome(allocator, self.deserializer()), .String => |v| try visitor.visitString(allocator, De, v), @@ -168,10 +177,10 @@ pub fn Deserializer(comptime user_dbt: anytype, comptime deserializer_dbt: anyty break :blk value; }, .Struct => |v| blk: { - var s = Struct{ .de = self, .len = v.len, .end = .StructEnd }; - var value = try visitor.visitMap(allocator, De, s.mapAccess()); + var m = Map{ .de = self, .len = v.len, .end = .StructEnd }; + var value = try visitor.visitMap(allocator, De, m.mapAccess()); - try expectEqual(@as(usize, 0), s.len.?); + try expectEqual(@as(usize, 0), m.len.?); try self.assertNextToken(.StructEnd); break :blk value; @@ -186,21 +195,6 @@ pub fn Deserializer(comptime user_dbt: anytype, comptime deserializer_dbt: anyty }; } - fn deserializeMap(self: *Self, allocator: ?std.mem.Allocator, visitor: anytype) Error!@TypeOf(visitor).Value { - return switch (self.nextToken()) { - .Map => |v| blk: { - var m = Map{ .de = self, .len = v.len, .end = .MapEnd }; - var value = try visitor.visitMap(allocator, De, m.mapAccess()); - - try expectEqual(@as(usize, 0), m.len.?); - try self.assertNextToken(.MapEnd); - - break :blk value; - }, - else => |v| std.debug.panic("deserialization did not expect this token: {s}", .{@tagName(v)}), - }; - } - fn deserializeIgnored(self: *Self, allocator: ?std.mem.Allocator, visitor: anytype) Error!@TypeOf(visitor).Value { _ = self.nextTokenOpt(); return try visitor.visitVoid(allocator, De); @@ -289,65 +283,13 @@ pub fn Deserializer(comptime user_dbt: anytype, comptime deserializer_dbt: anyty } }; - const Struct = struct { - de: *Self, - len: ?usize, - end: Token, - - pub usingnamespace MapAccessInterface( - *Struct, - Error, - .{ - .nextKeySeed = nextKeySeed, - .nextValueSeed = nextValueSeed, - .isKeyAllocated = isKeyAllocated, - }, - ); - - fn nextKeySeed(self: *Struct, _: ?std.mem.Allocator, seed: anytype) Error!?@TypeOf(seed).Value { - // All fields have been deserialized. - if (self.len.? == 0) { - return null; - } - - if (self.de.peekTokenOpt()) |token| { - if (std.meta.eql(token, self.end)) return null; - } else { - return null; - } - - if (self.de.nextTokenOpt()) |token| { - self.len.? -= @as(usize, 1); - - if (token != .String) { - return error.InvalidType; - } - - return token.String; - } else { - return null; - } - } - - fn nextValueSeed(self: *Struct, allocator: ?std.mem.Allocator, seed: anytype) Error!@TypeOf(seed).Value { - return try seed.deserialize(allocator, self.de.deserializer()); - } - - fn isKeyAllocated(_: *Struct, comptime _: type) bool { - return false; - } - }; - const Union = struct { de: *Self, pub usingnamespace UnionAccessInterface( *Union, Error, - .{ - .variantSeed = variantSeed, - .isVariantAllocated = isVariantAllocated, - }, + .{ .variantSeed = variantSeed }, ); pub usingnamespace VariantAccessInterface( @@ -356,11 +298,11 @@ pub fn Deserializer(comptime user_dbt: anytype, comptime deserializer_dbt: anyty .{ .payloadSeed = payloadSeed }, ); - fn variantSeed(self: *Union, _: ?std.mem.Allocator, seed: anytype) Error!@TypeOf(seed).Value { - const token = self.de.nextToken(); - - if (token == .String) { - return token.String; + fn variantSeed(self: *Union, allocator: ?std.mem.Allocator, seed: anytype) Error!@TypeOf(seed).Value { + if (self.de.peekTokenOpt()) |token| { + if (token == .String) { + return try seed.deserialize(allocator, self.de.deserializer()); + } } return error.InvalidType; @@ -375,10 +317,6 @@ pub fn Deserializer(comptime user_dbt: anytype, comptime deserializer_dbt: anyty return error.UnknownVariant; } } - - fn isVariantAllocated(_: *Union, comptime _: type) bool { - return false; - } }; }; }