diff --git a/pkg/chunked/storage_linux.go b/pkg/chunked/storage_linux.go index 403d7d5aa3..c6259636b2 100644 --- a/pkg/chunked/storage_linux.go +++ b/pkg/chunked/storage_linux.go @@ -26,6 +26,7 @@ import ( "github.com/containers/storage/pkg/chunked/toc" "github.com/containers/storage/pkg/fsverity" "github.com/containers/storage/pkg/idtools" + "github.com/containers/storage/pkg/ioutils" "github.com/containers/storage/pkg/system" jsoniter "github.com/json-iterator/go" "github.com/klauspost/compress/zstd" @@ -107,7 +108,7 @@ type chunkedLayerData struct { Format graphdriver.DifferOutputFormat `json:"format"` } -func (c *chunkedDiffer) convertTarToZstdChunked(destDirectory string, payload *os.File) (int64, *seekableFile, digest.Digest, map[string]string, error) { +func (c *chunkedDiffer) convertTarToZstdChunked(destDirectory string, payload io.Reader) (int64, *seekableFile, digest.Digest, map[string]string, error) { diff, err := archive.DecompressStream(payload) if err != nil { return 0, nil, "", nil, err @@ -1076,7 +1077,7 @@ func makeEntriesFlat(mergedEntries []fileMetadata) ([]fileMetadata, error) { return new, nil } -func (c *chunkedDiffer) copyAllBlobToFile(destination *os.File) (digest.Digest, error) { +func (c *chunkedDiffer) requestWholeBlob() (io.ReadCloser, digest.Digester, error) { var payload io.ReadCloser var streams chan io.ReadCloser var errs chan error @@ -1091,26 +1092,27 @@ func (c *chunkedDiffer) copyAllBlobToFile(destination *os.File) (digest.Digest, streams, errs, err = c.stream.GetBlobAt(chunksToRequest) if err != nil { - return "", err + return nil, nil, err } select { case p := <-streams: payload = p case err := <-errs: - return "", err + return nil, nil, err } if payload == nil { - return "", errors.New("invalid stream returned") + return nil, nil, errors.New("invalid stream returned") } - originalRawDigester := digest.Canonical.Digester() + digester := digest.Canonical.Digester() - r := io.TeeReader(payload, originalRawDigester.Hash()) + r := io.TeeReader(payload, digester.Hash()) - // copy the entire tarball and compute its digest - _, err = io.CopyBuffer(destination, r, c.copyBuffer) + rc := ioutils.NewReadCloserWrapper(r, func() error { + return payload.Close() + }) - return originalRawDigester.Digest(), err + return rc, digester, nil } func (c *chunkedDiffer) ApplyDiff(dest string, options *archive.TarOptions, differOpts *graphdriver.DifferOptions) (graphdriver.DriverWithDifferOutput, error) { @@ -1131,32 +1133,17 @@ func (c *chunkedDiffer) ApplyDiff(dest string, options *archive.TarOptions, diff var convertedBlobSize int64 if c.convertToZstdChunked { - fd, err := unix.Open(dest, unix.O_TMPFILE|unix.O_RDWR|unix.O_CLOEXEC, 0o600) + payload, digester, err := c.requestWholeBlob() if err != nil { - return graphdriver.DriverWithDifferOutput{}, &fs.PathError{Op: "open", Path: dest, Err: err} + return graphdriver.DriverWithDifferOutput{}, err } - blobFile := os.NewFile(uintptr(fd), "blob-file") defer func() { - if blobFile != nil { - blobFile.Close() + if payload != nil { + payload.Close() } }() - // calculate the checksum before accessing the file. - compressedDigest, err = c.copyAllBlobToFile(blobFile) - if err != nil { - return graphdriver.DriverWithDifferOutput{}, err - } - - if compressedDigest != c.blobDigest { - return graphdriver.DriverWithDifferOutput{}, fmt.Errorf("invalid digest to convert: expected %q, got %q", c.blobDigest, compressedDigest) - } - - if _, err := blobFile.Seek(0, io.SeekStart); err != nil { - return graphdriver.DriverWithDifferOutput{}, err - } - - tarSize, fileSource, diffID, annotations, err := c.convertTarToZstdChunked(dest, blobFile) + tarSize, fileSource, diffID, annotations, err := c.convertTarToZstdChunked(dest, payload) if err != nil { return graphdriver.DriverWithDifferOutput{}, err } @@ -1165,9 +1152,15 @@ func (c *chunkedDiffer) ApplyDiff(dest string, options *archive.TarOptions, diff // need to keep it open until the entire file is processed. defer fileSource.Close() - // Close the file so that the file descriptor is released and the file is deleted. - blobFile.Close() - blobFile = nil + // Make sure the entire payload is consumed. + _, _ = io.Copy(io.Discard, payload) + payload.Close() + payload = nil + + compressedDigest = digester.Digest() + if compressedDigest != c.blobDigest { + return graphdriver.DriverWithDifferOutput{}, fmt.Errorf("invalid digest to convert: expected %q, got %q", c.blobDigest, compressedDigest) + } tocDigest, err := toc.GetTOCDigest(annotations) if err != nil {