diff --git a/fairseq2n/src/fairseq2n/data/image/png_decoder.cc b/fairseq2n/src/fairseq2n/data/image/png_decoder.cc index 99a6ba2b8..b716d2439 100644 --- a/fairseq2n/src/fairseq2n/data/image/png_decoder.cc +++ b/fairseq2n/src/fairseq2n/data/image/png_decoder.cc @@ -90,22 +90,24 @@ png_decoder::operator()(data &&d) const int channels = png_get_channels(png_ptr, info_ptr); at::ScalarType dtype = opts_.maybe_dtype().value_or(at::kByte); - at::Tensor image = at::empty({height, width, 4}, at::dtype(dtype).device(at::kCPU).pinned_memory(opts_.pin_memory())); + at::Tensor image = at::empty({height, width, channels}, at::dtype(dtype).device(at::kCPU).pinned_memory(opts_.pin_memory())); + size_t rowbytes = png_get_rowbytes(png_ptr, info_ptr); - auto t_ptr = image.accessor().data(); - - // Copy image data into tensor object + writable_memory_span image_bits = get_raw_mutable_storage(image); + png_bytep image_data = reinterpret_cast(image_bits.data()); + + // Read image data into tensor for (png_uint_32 i = 0; i < height; ++i) { - png_read_row(png_ptr, t_ptr, nullptr); - t_ptr += rowbytes; + png_read_row(png_ptr, image_data, nullptr); + image_data += rowbytes; } - t_ptr = image.accessor().data(); - + // Move tensor to specified device at::Device device = opts_.maybe_device().value_or(at::kCPU); if (device != at::kCPU) image = image.to(device); + // Pack png data and format as output. data_dict output{ {"bit_depth", static_cast(bit_depth)}, {"color_type", static_cast(color_type)},