From d673068e96dda85d3c3940e424ca2ee1209628a7 Mon Sep 17 00:00:00 2001 From: antonydellavecchia Date: Thu, 22 Aug 2024 16:28:38 +0200 Subject: [PATCH] Serialize Types with Attributes (#4020) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * first attempts at writing the serializers for the new types PhylogeneticModel and GroupBasedPhylogeneticModel * include serialization functions from AlgebraicStatistics * first steps to serializing attributes * saving and loading attributes working * added some documentation on with_attrs * updated docs, + removed algebraic stats commits, and fixed comment * Update experimental/AlgebraicStatistics/src/AlgebraicStatistics.jl * Apply suggestions from code review Co-authored-by: Lars Göttgens * includes load_attrs and save_attrs in import all function * Apply suggestions from code review Co-authored-by: Lars Göttgens * changed type_attr_map to a function * type_attr_map -> type_attr_dict * uses encode type to setup type_attr_map * still issues with the test * More `Type` -> `String` * forgot to with_attrs to save in setup_tests * toric tests passing * moved save_attrs to appropriate line * moved attrs_list to appropriate line * Update src/Serialization/main.jl Co-authored-by: Lars Göttgens * Update src/Serialization/main.jl Co-authored-by: Lars Göttgens * Update src/Serialization/serializers.jl Co-authored-by: Lars Göttgens * Update src/Serialization/main.jl Co-authored-by: Lars Göttgens * small fix still need to move the check * added check inside save_attrs * offline review * fixed save docs --------- Co-authored-by: Marius Co-authored-by: Lars Göttgens --- src/Serialization/ToricGeometry.jl | 23 ++++++- src/Serialization/main.jl | 94 +++++++++++++++++++++++------ src/Serialization/serializers.jl | 34 ++++++++--- test/Serialization/ToricGeometry.jl | 15 ++++- test/Serialization/setup_tests.jl | 14 +++-- 5 files changed, 143 insertions(+), 37 deletions(-) diff --git a/src/Serialization/ToricGeometry.jl b/src/Serialization/ToricGeometry.jl index 98a9706c084c..6d58b9b5ccb1 100644 --- a/src/Serialization/ToricGeometry.jl +++ b/src/Serialization/ToricGeometry.jl @@ -1,13 +1,30 @@ ################################################################################ # Toric varieties @register_serialization_type AffineNormalToricVariety uses_id -@register_serialization_type NormalToricVariety uses_id -function save_object(s::SerializerState, ntv::NormalToricVarietyType) - save_object(s, ntv.polymakeNTV) +@register_serialization_type NormalToricVariety uses_id [:cox_ring] + + +function save_object(s::SerializerState, ntv::T) where T <: NormalToricVarietyType + attrs = attrs_list(s, T) + + if !isempty(attrs) && any([has_attribute(ntv, attr) for attr in attrs]) + save_data_dict(s) do + save_attrs(s, ntv) + save_object(s, ntv.polymakeNTV, :pm_data) + end + else + save_object(s, ntv.polymakeNTV) + end end function load_object(s::DeserializerState, ::Type{T}) where {T <: Union{NormalToricVariety, AffineNormalToricVariety}} + if haskey(s, :pm_data) + ntv = T(load_object(s, Polymake.BigObject, :pm_data)) + load_attrs(s, ntv) + + return ntv + end return T(load_object(s, Polymake.BigObject)) end diff --git a/src/Serialization/main.jl b/src/Serialization/main.jl index 10d469fd1d6d..2de3cf9a271c 100644 --- a/src/Serialization/main.jl +++ b/src/Serialization/main.jl @@ -67,6 +67,10 @@ function get_oscar_serialization_version() return oscar_serialization_version[] = result end +################################################################################ +# Type attribute map +const type_attr_map = Dict{String, Vector{Symbol}}() + ################################################################################ # (De|En)coding types @@ -189,6 +193,16 @@ function save_type_params(s::SerializerState, obj::Any, key::Symbol) save_type_params(s, obj) end +function save_attrs(s::SerializerState, obj::T) where T + if any(attr -> has_attribute(obj, attr), attrs_list(s, T)) + save_data_dict(s, :attrs) do + for attr in attrs_list(s, T) + has_attribute(obj, attr) && save_typed_object(s, get_attribute(obj, attr), attr) + end + end + end +end + # The load mechanism first checks if the type needs to load necessary # parameters before loading it's data, if so a type tree is traversed function load_typed_object(s::DeserializerState, key::Symbol; override_params::Any = nothing) @@ -239,6 +253,14 @@ function load_object(s::DeserializerState, T::Type, params::Any, key::Union{Symb end end +function load_attrs(s::DeserializerState, obj::T) where T + haskey(s, :attrs) && load_node(s, :attrs) do d + for attr in keys(d) + set_attribute!(obj, attr, load_typed_object(s, attr)) + end + end +end + ################################################################################ # Default generic save_internal, load_internal function save_object_generic(s::SerializerState, obj::T) where T @@ -283,6 +305,13 @@ function register_serialization_type(@nospecialize(T::Type), str::String) reverse_type_map[str] = T end +function register_attr_list(@nospecialize(T::Type), + attrs::Union{Vector{Symbol}, Nothing}) + if !isnothing(attrs) + Oscar.type_attr_map[encode_type(T)] = attrs + end +end + import Serialization.serialize import Serialization.deserialize import Serialization.serialize_type @@ -295,7 +324,8 @@ serialize_with_id(obj::Any) = false serialize_with_params(::Type) = false -function register_serialization_type(ex::Any, str::String, uses_id::Bool, uses_params::Bool) +function register_serialization_type(ex::Any, str::String, uses_id::Bool, + uses_params::Bool, attrs::Any) return esc( quote Oscar.register_serialization_type($ex, $str) @@ -313,6 +343,9 @@ function register_serialization_type(ex::Any, str::String, uses_id::Bool, uses_p # Types like ZZ, QQ, and ZZ/nZZ do not require ids since there is no syntactic # ambiguities in their encodings. + # add list of possible attributes to save for a given type to a global dict + Oscar.register_attr_list($ex, $attrs) + Oscar.serialize_with_id(obj::T) where T <: $ex = $uses_id Oscar.serialize_with_id(T::Type{<:$ex}) = $uses_id Oscar.serialize_with_params(T::Type{<:$ex}) = $uses_params @@ -327,31 +360,36 @@ function register_serialization_type(ex::Any, str::String, uses_id::Bool, uses_p Oscar.load(s.io; serializer_type=Oscar.IPCSerializer) end end - end) end """ - @register_serialization_type NewType "String Representation of type" uses_id uses_params + @register_serialization_type NewType "String Representation of type" uses_id uses_params [:attr1, :attr2] `@register_serialization_type` is a macro to ensure that the string we generate matches exactly the expression passed as first argument, and does not change in unexpected ways when import/export statements are adjusted. -The last three arguments are optional and can arise in any order. Passing a string argument will override how the type is stored as a string. -The last two are boolean flags. When setting `uses_id` the object will be - stored as a reference and will be referred to throughout the serialization -sessions using a `UUID`. This should typically only be used for types that - do not have a fixed normal form for example `PolyRing` and `MPolyRing`. + +When setting `uses_id` the object will be stored as a reference and +will be referred to throughout the serialization sessions using a `UUID`. +This should typically only be used for types that do not have a fixed +normal form for example `PolyRing` and `MPolyRing`. + Using the `uses_params` flag will serialize the object with a more structured type description which will make the serialization more efficient see the discussion on `save_type_params` / `load_type_params` below. + +Passing a vector of symbols that correspond to attributes of type +indicates which attributes will be serialized when using save with `with_attrs=true`. + """ macro register_serialization_type(ex::Any, args...) uses_id = false uses_params = false str = nothing + attrs = nothing for el in args if el isa String str = el @@ -359,13 +397,15 @@ macro register_serialization_type(ex::Any, args...) uses_id = true elseif el == :uses_params uses_params = true + else + attrs = el end end if str === nothing str = string(ex) end - return register_serialization_type(ex, str, uses_id, uses_params) + return register_serialization_type(ex, str, uses_id, uses_params, attrs) end @@ -394,11 +434,13 @@ macro import_all_serialization_functions() encode_type, haskey, load_array_node, + load_attrs, load_node, load_params_node, load_ref, load_typed_object, save_as_ref, + save_attrs, save_data_array, save_data_basic, save_data_dict, @@ -435,11 +477,14 @@ include("Upgrades/main.jl") # Interacting with IO streams and files """ - save(io::IO, obj::Any; metadata::MetaData=nothing) - save(filename::String, obj::Any, metadata::MetaData=nothing) + save(io::IO, obj::Any; metadata::MetaData=nothing, with_attrs::Bool=true) + save(filename::String, obj::Any, metadata::MetaData=nothing, with_attrs::Bool=true) Save an object `obj` to the given io stream -respectively to the file `filename`. +respectively to the file `filename`. When used with `with_attrs=true` then the object will +save it's attributes along with all the attributes of the types used in the object's struct. +The attributes that will be saved are defined during type registration see +[`@register_serialization_type`](@ref) See [`load`](@ref). @@ -463,8 +508,11 @@ julia> load("/tmp/fourtitwo.mrdi") ``` """ function save(io::IO, obj::T; metadata::Union{MetaData, Nothing}=nothing, + with_attrs::Bool=true, serializer_type::Type{<: OscarSerializer} = JSONSerializer) where T - s = state(serializer_open(io, serializer_type)) + + s = state(serializer_open(io, serializer_type, + with_attrs ? type_attr_map : Dict{String, Vector{Symbol}}())) save_data_dict(s) do # write out the namespace first save_header(s, get_oscar_serialization_version(), :_ns) @@ -502,20 +550,21 @@ function save(io::IO, obj::T; metadata::Union{MetaData, Nothing}=nothing, return nothing end -function save(filename::String, obj::Any; metadata::Union{MetaData, Nothing}=nothing) +function save(filename::String, obj::Any; metadata::Union{MetaData, Nothing}=nothing, + with_attrs::Bool=true) dir_name = dirname(filename) # julia dirname does not return "." for plain filenames without any slashes temp_file = tempname(isempty(dir_name) ? pwd() : dir_name) open(temp_file, "w") do file - save(file, obj; metadata=metadata) + save(file, obj; metadata=metadata, with_attrs=with_attrs) end Base.Filesystem.rename(temp_file, filename) # atomic "multi process safe" return nothing end """ - load(io::IO; params::Any = nothing, type::Any = nothing) - load(filename::String; params::Any = nothing, type::Any = nothing) + load(io::IO; params::Any = nothing, type::Any = nothing, with_attrs::Bool=true) + load(filename::String; params::Any = nothing, type::Any = nothing, with_attrs::Bool=true) Load the object stored in the given io stream respectively in the file `filename`. @@ -529,6 +578,9 @@ results in setting its parent, or in the case of a container of ring types such If a type `T` is given then attempt to load the root object of the data being loaded with this type; if this fails, an error is thrown. +If `with_attrs=true` the object will be loaded with attributes available from +the file (or serialized data). + See [`save`](@ref). # Examples @@ -568,8 +620,9 @@ true ``` """ function load(io::IO; params::Any = nothing, type::Any = nothing, - serializer_type=JSONSerializer) - s = state(deserializer_open(io, serializer_type)) + serializer_type=JSONSerializer, with_attrs::Bool=true) + s = state(deserializer_open(io, serializer_type, + with_attrs ? type_attr_map : Dict{String, Vector{Symbol}}())) if haskey(s.obj, :id) id = s.obj[:id] if haskey(global_serializer_state.id_to_obj, UUID(id)) @@ -658,7 +711,8 @@ function load(io::IO; params::Any = nothing, type::Any = nothing, end end -function load(filename::String; params::Any = nothing, type::Any = nothing) +function load(filename::String; params::Any = nothing, + type::Any = nothing, with_attrs::Bool=true) open(filename) do file return load(file; params=params, type=type) end diff --git a/src/Serialization/serializers.jl b/src/Serialization/serializers.jl index db9d4ccc2437..2466707c0711 100644 --- a/src/Serialization/serializers.jl +++ b/src/Serialization/serializers.jl @@ -31,6 +31,7 @@ mutable struct SerializerState refs::Vector{UUID} io::IO key::Union{Symbol, Nothing} + type_attr_map::Dict{String, Vector{Symbol}} end function begin_node(s::SerializerState) @@ -142,6 +143,7 @@ mutable struct DeserializerState obj::Union{Dict{Symbol, Any}, Vector, JSON3.Object, JSON3.Array, BasicTypeUnion} key::Union{Symbol, Int, Nothing} refs::Union{Dict{Symbol, Any}, JSON3.Object, Nothing} + type_attr_map::Dict{String, Vector{Symbol}} end # general loading of a reference @@ -212,26 +214,44 @@ end state(s::OscarSerializer) = s.state -function serializer_open(io::IO, T::Type{<: OscarSerializer}) +function serializer_open( + io::IO, + T::Type{<: OscarSerializer}, + type_attr_map::S) where S <: Union{Dict{String, Vector{Symbol}}, Nothing} + # some level of handling should be done here at a later date - return T(SerializerState(true, UUID[], io, nothing)) + return T(SerializerState(true, UUID[], io, nothing, type_attr_map)) end -function deserializer_open(io::IO, T::Type{JSONSerializer}) +function deserializer_open( + io::IO, + T::Type{JSONSerializer}, + type_attr_map::S) where S <:Union{Dict{String, Vector{Symbol}}, Nothing} + obj = JSON3.read(io) refs = nothing if haskey(obj, refs_key) refs = obj[refs_key] end - - return T(DeserializerState(obj, nothing, refs)) + + return T(DeserializerState(obj, nothing, refs, type_attr_map)) end -function deserializer_open(io::IO, T::Type{IPCSerializer}) +function deserializer_open( + io::IO, + T::Type{IPCSerializer}, + type_attr_map::S) where S <:Union{Dict{String, Vector{Symbol}}, Nothing} # Using a JSON3.Object from JSON3 version 1.13.2 causes # @everywhere using Oscar # to hang. So we use a Dict here for now. obj = JSON.parse(io, dicttype=Dict{Symbol, Any}) - return T(DeserializerState(obj, nothing, nothing)) + return T(DeserializerState(obj, nothing, nothing, type_attr_map)) +end + +const state_types = Union{SerializerState, DeserializerState} + +function attrs_list(s::U, T::Type) where U <: state_types + return get(s.type_attr_map, encode_type(T), Symbol[]) end + diff --git a/test/Serialization/ToricGeometry.jl b/test/Serialization/ToricGeometry.jl index 353e7d82bb3f..05c965bb572e 100644 --- a/test/Serialization/ToricGeometry.jl +++ b/test/Serialization/ToricGeometry.jl @@ -2,7 +2,20 @@ mktempdir() do path @testset "NormalToricVariety" begin pp = projective_space(NormalToricVariety, 2) - test_save_load_roundtrip(path, pp) do loaded + R = cox_ring(pp) + check(x) = has_attribute(x, :cox_ring) + + test_save_load_roundtrip(path, pp; with_attrs=false, check_func=!check) do loaded + @test rays(pp) == rays(loaded) + @test ray_indices(maximal_cones(pp)) == ray_indices(maximal_cones(loaded)) + end + + test_save_load_roundtrip(path, pp; with_attrs=true, check_func=check) do loaded + @test rays(pp) == rays(loaded) + @test ray_indices(maximal_cones(pp)) == ray_indices(maximal_cones(loaded)) + end + + test_save_load_roundtrip(path, pp; check_func=check) do loaded @test rays(pp) == rays(loaded) @test ray_indices(maximal_cones(pp)) == ray_indices(maximal_cones(loaded)) end diff --git a/test/Serialization/setup_tests.jl b/test/Serialization/setup_tests.jl index 2b9a77878e18..2c2da30ee2e4 100644 --- a/test/Serialization/setup_tests.jl +++ b/test/Serialization/setup_tests.jl @@ -7,10 +7,11 @@ using JSONSchema, Oscar.JSON if !isdefined(Main, :test_save_load_roundtrip) mrdi_schema = Schema(JSON.parsefile(joinpath(Oscar.oscardir, "data", "schema.json"))) - function test_save_load_roundtrip(func, path, original::T; params=nothing) where {T} + function test_save_load_roundtrip(func, path, original::T; + params=nothing, check_func=nothing, kw...) where {T} # save and load from a file filename = joinpath(path, "original.json") - save(filename, original) + save(filename, original; kw...) loaded = load(filename; params=params) @test loaded isa T @@ -18,7 +19,7 @@ if !isdefined(Main, :test_save_load_roundtrip) # save and load from an IO buffer io = IOBuffer() - save(io, original) + save(io, original; kw...) seekstart(io) loaded = load(io; params=params) @@ -27,7 +28,7 @@ if !isdefined(Main, :test_save_load_roundtrip) # save and load from an IO buffer, with prescribed type io = IOBuffer() - save(io, original) + save(io, original; kw...) seekstart(io) loaded = load(io; type=T, params=params) @@ -35,7 +36,7 @@ if !isdefined(Main, :test_save_load_roundtrip) func(loaded) # test loading on a empty state - save(filename, original) + save(filename, original; kw...) Oscar.reset_global_serializer_state() loaded = load(filename; params=params) @test loaded isa T @@ -43,6 +44,7 @@ if !isdefined(Main, :test_save_load_roundtrip) # test schema jsondict = JSON.parsefile(filename) @test validate(mrdi_schema, jsondict) == nothing - end + isnothing(check_func) || @test check_func(loaded) + end end