Skip to content

Commit

Permalink
use get_raw_mutable_storage instead
Browse files Browse the repository at this point in the history
  • Loading branch information
am831 committed Oct 20, 2023
1 parent 3853c64 commit 1fc8e45
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions fairseq2n/src/fairseq2n/data/image/png_decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,22 +90,24 @@ png_decoder::operator()(data &&d) const
int channels = png_get_channels(png_ptr, info_ptr);

at::ScalarType dtype = opts_.maybe_dtype().value_or(at::kByte);
at::Tensor image = at::empty({height, width, 4}, at::dtype(dtype).device(at::kCPU).pinned_memory(opts_.pin_memory()));
at::Tensor image = at::empty({height, width, channels}, at::dtype(dtype).device(at::kCPU).pinned_memory(opts_.pin_memory()));

size_t rowbytes = png_get_rowbytes(png_ptr, info_ptr);
auto t_ptr = image.accessor<uint8_t, 3>().data();

// Copy image data into tensor object
writable_memory_span image_bits = get_raw_mutable_storage(image);
png_bytep image_data = reinterpret_cast<png_bytep>(image_bits.data());

// Read image data into tensor
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, t_ptr, nullptr);
t_ptr += rowbytes;
png_read_row(png_ptr, image_data, nullptr);
image_data += rowbytes;
}
t_ptr = image.accessor<uint8_t, 3>().data();


// Move tensor to specified device
at::Device device = opts_.maybe_device().value_or(at::kCPU);
if (device != at::kCPU)
image = image.to(device);


// Pack png data and format as output.
data_dict output{
{"bit_depth", static_cast<float32>(bit_depth)}, {"color_type", static_cast<float32>(color_type)},
Expand Down

0 comments on commit 1fc8e45

Please sign in to comment.