Skip to content

Commit

Permalink
Add @nospecializeinfer around worst offenders (#527)
Browse files Browse the repository at this point in the history
* @nospecializeinfer macro

* add compat

* more of less inference

* fixes a broken test

* less inlining and more despecialization

* undo type assert

* typo

* another typo

* typo
  • Loading branch information
JonasIsensee authored Jan 11, 2024
1 parent f8a9dd3 commit bda2a07
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 59 deletions.
2 changes: 1 addition & 1 deletion src/compression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ function deflate_data(f::JLDFile, data::Array{T}, odr::S, wsession::JLDWriteSess
end


@inline chunked_storage_message_size(ndims::Int) =
chunked_storage_message_size(ndims::Int) =
jlsizeof(HeaderMessage) + 5 + (ndims+1)*jlsizeof(Length) + 1 + jlsizeof(Length) + 4 + jlsizeof(RelOffset)


Expand Down
6 changes: 3 additions & 3 deletions src/data/custom_serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,20 @@ CustomSerialization(::Type{WrittenAs}, ::Type{ReadAs}, odr) where {WrittenAs,Rea
odr_sizeof(::Type{CustomSerialization{T,ODR}}) where {T,ODR} = odr_sizeof(ODR)

# Usually we want to convert the object and then write it.
@inline h5convert!(out::Pointers, ::Type{CustomSerialization{T,ODR}}, f::JLDFile,
h5convert!(out::Pointers, ::Type{CustomSerialization{T,ODR}}, f::JLDFile,
x, wsession::JLDWriteSession) where {T,ODR} =
h5convert!(out, ODR, f, wconvert(T, x)::T, wsession)

# When writing as a reference, we don't want to convert the object first. That
# should happen automatically after write_dataset is called so that the written
# object gets the right written_type attribute.
@inline h5convert!(out::Pointers, odr::Type{CustomSerialization{T,RelOffset}},
h5convert!(out::Pointers, odr::Type{CustomSerialization{T,RelOffset}},
f::JLDFile, x, wsession::JLDWriteSession) where {T} =
h5convert!(out, RelOffset, f, x, wsession)

# When writing as a reference to something that's being custom-serialized as an
# array, we have to convert the object first.
@inline h5convert!(out::Pointers, odr::Type{CustomSerialization{T,RelOffset}},
h5convert!(out::Pointers, odr::Type{CustomSerialization{T,RelOffset}},
f::JLDFile, x, wsession::JLDWriteSession) where {T<:Array} =
h5convert!(out, RelOffset, f, wconvert(T, x)::T, wsession)

Expand Down
3 changes: 2 additions & 1 deletion src/data/specialcased_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ h5fieldtype(::JLDFile, ::Type{T}, ::Type{T}, ::Initialized) where {T<:Array} =
ReferenceDatatype()
fieldodr(::Type{T}, ::Bool) where {T<:Array} = RelOffset

@inline function odr(::Type{Array{T,N}}) where {T,N}
function odr(A::Type{<:Array})
    # Work from the array's element type and the type that element is
    # written as; the serialization wraps both together with the field
    # representation of the written type.
    elty = eltype(A)
    stored = writeas(elty)
    return CustomSerialization(stored, elty, fieldodr(stored, false))
end
Expand Down
40 changes: 21 additions & 19 deletions src/data/writing_datatypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ end
odr_sizeof(::ReadRepresentation{T,S}) where {T,S} = odr_sizeof(S)

# Determines whether a specific field type should be saved in the file
@noinline function hasfielddata(@nospecialize(T), encounteredtypes=DataType[])
function hasfielddata(@nospecialize(T), encounteredtypes=DataType[])::Bool
T === Union{} && return false
!isconcretetype(T) && return true
T = T::DataType
Expand All @@ -35,7 +35,7 @@ odr_sizeof(::ReadRepresentation{T,S}) where {T,S} = odr_sizeof(S)
end

# Determines whether a specific type has fields that should be saved in the file
function hasdata(T::DataType, encounteredtypes=DataType[])
@nospecializeinfer function hasdata(@nospecialize(T::DataType), encounteredtypes=DataType[])::Bool
isempty(T.types) && sizeof(T) != 0 && return true
for ty in T.types
hasfielddata(writeas(ty), copy(encounteredtypes)) && return true
Expand All @@ -50,7 +50,7 @@ function odr_sizeof(::OnDiskRepresentation{Offsets,JLTypes,H5Types,Size}) where
end

# Determines whether a type will have the same layout on disk as in memory
function samelayout(T::DataType)
function samelayout(@nospecialize(T::DataType))::Bool
isempty(T.types) && return true
offset = 0
for i = 1:length(T.types)
Expand All @@ -64,7 +64,7 @@ function samelayout(T::DataType)
end
samelayout(::Type) = false

fieldnames(x::Type{T}) where {T<:Tuple} = [Symbol(i) for i = 1:length(x.types)]
# Tuple types: one name per entry of `x.types`, built from the entry's
# position, i.e. Symbol("1"), Symbol("2"), ...
function fieldnames(@nospecialize(x::Type{<:Tuple}))
    ntypes = length(x.types)
    return [Symbol(idx) for idx in 1:ntypes]
end

# Anything else: defer to Base.fieldnames and collect the result into a Vector.
function fieldnames(@nospecialize(x))
    names = Base.fieldnames(x)
    return collect(names)
end

const MAX_INLINE_SIZE = 2^10
Expand All @@ -89,8 +89,9 @@ end

# h5fieldtype is fieldodr's HDF5 companion. It should give the HDF5
# datatype reflecting the on-disk representation.
function h5fieldtype(f::JLDFile, writeas::Type{T}, readas::Type,
initialized::Initialized) where T
@nospecializeinfer function h5fieldtype(f::JLDFile, @nospecialize(writeas), @nospecialize(readas::Type),
initialized::Initialized)::Union{CommittedDatatype, H5Datatype, Nothing}
T = writeas
if isconcretetype(T)
if !hasfielddata(T)
return nothing
Expand All @@ -113,7 +114,7 @@ end
# almost always the on-disk representation of the type. The only
# exception is strings, where the length is encoded in the datatype in
# HDF5, but in the object in Julia.
@inline function objodr(x)
@nospecializeinfer function objodr(@nospecialize(x))
    # Look up the runtime type once, map it to the type it is written as,
    # and let _odr combine both with the written type's representation.
    xtype = typeof(x)
    stored = writeas(xtype)
    return _odr(stored, xtype, odr(stored))
end
Expand All @@ -124,7 +125,7 @@ _odr(writtenas::DataType, readas::DataType, odr) =
# reflecting the on-disk representation
#
# Performance note: this should be inferable.
function h5type(f::JLDFile, writtenas, x)
@nospecializeinfer function h5type(f::JLDFile, @nospecialize(writtenas), @nospecialize(x))
check_writtenas_type(writtenas)
T = typeof(x)
@lookup_committed f T
Expand All @@ -138,11 +139,11 @@ function h5type(f::JLDFile, writtenas, x)
end
# Guard applied by h5type to its `writtenas` argument: a DataType passes
# silently, anything else is rejected with an ArgumentError.
function check_writtenas_type(::DataType)
    return nothing
end
function check_writtenas_type(::Any)
    throw(ArgumentError("writeas(leaftype) must return a leaf type"))
end
h5type(f::JLDFile, x) = h5type(f, writeas(typeof(x)), x)
h5type(f::JLDFile, @nospecialize(x)) = h5type(f, writeas(typeof(x)), x)

# Make a compound datatype from a set of names and types
function commit_compound(f::JLDFile, names::AbstractVector{Symbol},
writtenas::DataType, readas::Type)
@nospecializeinfer function commit_compound(f::JLDFile, names::AbstractVector{Symbol},
@nospecialize(writtenas::DataType), @nospecialize(readas::Type))
types = writtenas.types
offsets = Int[]
h5names = Symbol[]
Expand Down Expand Up @@ -186,7 +187,7 @@ function commit_compound(f::JLDFile, names::AbstractVector{Symbol},
end

# Write an HDF5 datatype to the file
function commit(f::JLDFile,
@nospecializeinfer function commit(f::JLDFile,
@nospecialize(dtype),#::H5Datatype,
@nospecialize(writeas::DataType),
@nospecialize(readas::DataType),
Expand Down Expand Up @@ -272,7 +273,7 @@ jlconvert_canbeuninitialized(::Any) = false

# jlconvert converts data from a pointer into a Julia object. This method
# handles types where this is just a simple load
@inline jlconvert(::ReadRepresentation{T,T}, ::JLDFile, ptr::Ptr,
jlconvert(::ReadRepresentation{T,T}, ::JLDFile, ptr::Ptr,
::RelOffset) where {T} =
jlunsafe_load(pconvert(Ptr{T}, ptr))

Expand All @@ -292,7 +293,7 @@ Base.showerror(io::IO, x::UndefinedFieldException) =
h5type(::JLDFile, ::Type{RelOffset}, ::RelOffset) = ReferenceDatatype()
odr(::Type{RelOffset}) = RelOffset

@inline function h5convert!(out::Pointers, odr::Type{RelOffset}, f::JLDFile, x::Any,
function h5convert!(out::Pointers, odr::Type{RelOffset}, f::JLDFile, x::Any,
wsession::JLDWriteSession)
ref = write_ref(f, x, wsession)
jlunsafe_store!(pconvert(Ptr{RelOffset}, out), ref)
Expand All @@ -308,7 +309,7 @@ jlconvert(::ReadRepresentation{RelOffset,RelOffset}, f::JLDFile, ptr::Ptr,
jlconvert_canbeuninitialized(::ReadRepresentation{RelOffset,RelOffset}) = false

# Reading references as other types
@inline function jlconvert(::ReadRepresentation{T,RelOffset}, f::JLDFile, ptr::Ptr,
function jlconvert(::ReadRepresentation{T,RelOffset}, f::JLDFile, ptr::Ptr,
::RelOffset) where T
x = load_dataset(f, jlunsafe_load(pconvert(Ptr{RelOffset}, ptr)))
(isa(x, T) ? x : rconvert(T, x))::T
Expand All @@ -321,7 +322,7 @@ jlconvert_isinitialized(::ReadRepresentation{T,RelOffset}, ptr::Ptr) where {T} =
## Routines for variable-length datatypes

# Write variable-length data and store the offset and length to out pointer
@inline function store_vlen!(out::Pointers, odr, f::JLDFile, x::AbstractVector,
function store_vlen!(out::Pointers, odr, f::JLDFile, x::AbstractVector,
wsession::JLDWriteSession)
jlunsafe_store!(pconvert(Ptr{UInt32}, out), length(x))
obj = write_heap_object(f, odr, x, wsession)
Expand Down Expand Up @@ -564,8 +565,9 @@ end


# jlconvert for empty objects
function jlconvert(::ReadRepresentation{T,nothing}, f::JLDFile, ptr::Ptr,
header_offset::RelOffset) where T
function jlconvert(@nospecialize(rr::ReadRepresentation{T,nothing} where T), f::JLDFile, ptr::Ptr,
header_offset::RelOffset)::eltype(rr)
T = eltype(rr)
sizeof(T) == 0 && return newstruct(T)::T

# In this case, T is a non-empty object, but the written data was empty
Expand Down Expand Up @@ -594,7 +596,7 @@ end
# odr gives the on-disk representation of a given type, similar to
# fieldodr, but actually encoding the data for things that odr stores
# as references
function odr(::Type{T}) where T
@nospecializeinfer function odr(@nospecialize(T::Type))
if !hasdata(T)
# A pointer singleton or ghost. We need to write something, but we'll
# just write a single byte.
Expand Down
12 changes: 6 additions & 6 deletions src/dataio.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ function read_compressed_array! end
# Cutoff for using ordinary IO instead of copying into mmapped region
const MMAP_CUTOFF = 1048576

@inline function read_scalar(f::JLDFile{MmapIO}, rr, header_offset::RelOffset)
function read_scalar(f::JLDFile{MmapIO}, @nospecialize(rr), header_offset::RelOffset)::Any
    # Convert the object located at the IO's current pointer ...
    mmio = f.io
    start = mmio.curptr
    value = jlconvert(rr, f, start, header_offset)
    # ... then move the pointer past it: the new position is always the
    # starting pointer plus odr_sizeof(rr).
    mmio.curptr = start + odr_sizeof(rr)
    return value
end

@inline function read_array!(v::Array{T}, f::JLDFile{MmapIO},
function read_array!(v::Array{T}, f::JLDFile{MmapIO},
rr::ReadRepresentation{T,T}) where T
io = f.io
inptr = io.curptr
Expand All @@ -60,7 +60,7 @@ end
v
end

@inline function read_array!(v::Array{T}, f::JLDFile{MmapIO},
function read_array!(v::Array{T}, f::JLDFile{MmapIO},
rr::ReadRepresentation{T,RR}) where {T,RR}
io = f.io
inptr = io.curptr
Expand Down Expand Up @@ -175,7 +175,7 @@ end
# IOStream/BufferedWriter
#

@inline function read_scalar(f::JLDFile{IOStream}, rr, header_offset::RelOffset)
function read_scalar(f::JLDFile{IOStream}, rr, header_offset::RelOffset)::Any
r = Vector{UInt8}(undef, odr_sizeof(rr))
@GC.preserve r begin
unsafe_read(f.io, pointer(r), odr_sizeof(rr))
Expand All @@ -184,13 +184,13 @@ end
end


@inline function read_array!(v::Array{T}, f::JLDFile{IOStream},
function read_array!(v::Array{T}, f::JLDFile{IOStream},
                     rr::ReadRepresentation{T,T}) where T
    # Both parameters of the ReadRepresentation are T, so the bytes are read
    # directly into v's buffer: odr_sizeof(T) bytes for each element of v.
    nbytes = odr_sizeof(T) * length(v)
    unsafe_read(f.io, pointer(v), nbytes)
    return v
end

@inline function read_array!(v::Array{T}, f::JLDFile{IOStream},
function read_array!(v::Array{T}, f::JLDFile{IOStream},
rr::ReadRepresentation{T,RR}) where {T,RR}
n = length(v)
nb = odr_sizeof(RR)*n
Expand Down
33 changes: 18 additions & 15 deletions src/datasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,8 @@ end

# Types with no payload can only be null dataspace
function read_data(f::JLDFile,
rr::Union{ReadRepresentation{T,nothing} where T,
ReadRepresentation{T,CustomSerialization{S,nothing}} where {S,T}},
@nospecialize(rr::Union{ReadRepresentation{T,nothing} where T,
ReadRepresentation{T,CustomSerialization{S,nothing}} where {S,T}}),
read_dataspace::Tuple{ReadDataspace,RelOffset,DataLayout,FilterPipeline},
attributes::Vector{ReadAttribute})
dataspace, header_offset, layout, filters = read_dataspace
Expand Down Expand Up @@ -391,9 +391,10 @@ function construct_array(io::IO, ::Type{T}, ::Val{N})::Array{T,N} where {T,N}
end

function read_array(f::JLDFile, dataspace::ReadDataspace,
rr::ReadRepresentation{T,RR}, layout::DataLayout,
@nospecialize(rr::ReadRepresentation), layout::DataLayout,
filters::FilterPipeline, header_offset::RelOffset,
attributes::Union{Vector{ReadAttribute},Nothing}) where {T,RR}
attributes::Union{Vector{ReadAttribute},Nothing})
T = eltype(rr)
io = f.io
data_offset = layout.data_offset
if !ischunked(layout) || (layout.chunk_indexing_type == 1)
Expand Down Expand Up @@ -421,7 +422,8 @@ function read_array(f::JLDFile, dataspace::ReadDataspace,
chunks = read_v1btree_dataset_chunks(f, h5offset(f, layout.data_offset), layout.dimensionality)
vchunk = Array{T, Int(ndims)}(undef, reverse(layout.chunk_dimensions)...)
for chunk in chunks
idx = reverse(chunk.idx[1:end-1])
cidx = chunk.idx::NTuple{Int(ndims+1), Int}
idx = reverse(cidx[1:end-1])
seek(io, fileoffset(f, chunk.offset))
indexview = (:).(idx .+1, min.(idx .+ reverse(layout.chunk_dimensions), size(v)))
indexview2 = (:).(1, length.(indexview))
Expand Down Expand Up @@ -468,15 +470,16 @@ function payload_size_without_storage_message(dataspace::WriteDataspace, datatyp
end


function write_dataset(
@nospecializeinfer function write_dataset(
f::JLDFile,
dataspace::WriteDataspace,
datatype::H5Datatype,
odr::S,
data::Array{T},
@nospecialize(odr),
@nospecialize(data::Array),
wsession::JLDWriteSession,
compress = f.compress,
) where {T,S}
@nospecialize(compress = f.compress),
)
T = eltype(data)
io = f.io
datasz = odr_sizeof(odr)::Int * numel(dataspace)::Int
#layout_class
Expand Down Expand Up @@ -540,7 +543,7 @@ function write_dataset(
h5offset(f, header_offset)
end

function write_dataset(f::JLDFile, dataspace::WriteDataspace, datatype::H5Datatype, odr::S, data, wsession::JLDWriteSession) where S
@nospecializeinfer function write_dataset(f::JLDFile, dataspace::WriteDataspace, datatype::H5Datatype, @nospecialize(odr), @nospecialize(data), wsession::JLDWriteSession)
io = f.io
datasz = (odr_sizeof(odr)::Int * numel(dataspace))
psz = payload_size_without_storage_message(dataspace, datatype)
Expand Down Expand Up @@ -621,7 +624,7 @@ struct CompactStorageMessage
data_size::UInt16
end
define_packed(CompactStorageMessage)
@inline CompactStorageMessage(datasz::Int) =
CompactStorageMessage(datasz::Int) =
CompactStorageMessage(
HeaderMessage(HM_DATA_LAYOUT, jlsizeof(CompactStorageMessage) - jlsizeof(HeaderMessage) + datasz, 0),
4, LC_COMPACT_STORAGE, datasz
Expand All @@ -635,13 +638,13 @@ struct ContiguousStorageMessage
data_size::Length
end
define_packed(ContiguousStorageMessage)
@inline ContiguousStorageMessage(datasz::Int, offset::RelOffset) =
function ContiguousStorageMessage(datasz::Int, offset::RelOffset)
    # Size stored in the header message: size of this message type minus
    # the size of HeaderMessage itself.
    body_size = jlsizeof(ContiguousStorageMessage) - jlsizeof(HeaderMessage)
    header = HeaderMessage(HM_DATA_LAYOUT, body_size, 0)
    return ContiguousStorageMessage(header, 4, LC_CONTIGUOUS_STORAGE, offset, datasz)
end

function write_dataset(f::JLDFile, x, wsession::JLDWriteSession)
@nospecializeinfer function write_dataset(f::JLDFile, @nospecialize(x), wsession::JLDWriteSession)::RelOffset
if ismutabletype(typeof(x)) && !isa(wsession, JLDWriteSession{Union{}})
offset = get(wsession.h5offset, objectid(x), UNDEFINED_ADDRESS)
offset != UNDEFINED_ADDRESS && return offset
Expand All @@ -650,7 +653,7 @@ function write_dataset(f::JLDFile, x, wsession::JLDWriteSession)
write_dataset(f, WriteDataspace(f, x, odr), h5type(f, x), odr, x, wsession)::RelOffset
end

write_ref(f::JLDFile, x, wsession::JLDWriteSession) = write_dataset(f, x, wsession)::RelOffset
write_ref(f::JLDFile, @nospecialize(x), wsession::JLDWriteSession) = write_dataset(f, x, wsession)::RelOffset
write_ref(f::JLDFile, x::RelOffset, wsession::JLDWriteSession) = x

Base.delete!(f::JLDFile, x::AbstractString) = delete!(f.root_group, x)
Expand Down
3 changes: 2 additions & 1 deletion src/inlineunion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ function writeasbits(T::Union)
length(types) == 2 && isbitstype(types[1]) && isbitstype(types[2])
end

function write_dataset(f::JLDFile, x::Array{T}, wsession::JLDWriteSession, compress=f.compress) where {T}
@nospecializeinfer function write_dataset(f::JLDFile, @nospecialize(x::Array), wsession::JLDWriteSession, @nospecialize(compress=f.compress))
T = eltype(x)
if !isa(wsession, JLDWriteSession{Union{}})
offset = get(wsession.h5offset, objectid(x), UNDEFINED_ADDRESS)
offset != UNDEFINED_ADDRESS && return offset
Expand Down
8 changes: 8 additions & 0 deletions src/julia_compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,12 @@ else
function ninitialized(@nospecialize(T::Type))::Int
fieldcount(T) - T.name.n_uninitialized
end
end

# `Base.@nospecializeinfer` is available from Julia 1.10 on. For older
# versions define a stand-in that returns the annotated expression
# unchanged, so that definitions using the macro still compile.
@static if VERSION >= v"1.10.0"
    using Base: @nospecializeinfer
else
    macro nospecializeinfer(ex)
        return esc(ex)
    end
end
Loading

0 comments on commit bda2a07

Please sign in to comment.