diff --git a/archiver.go b/archiver.go index 68c53d2a..4f973411 100644 --- a/archiver.go +++ b/archiver.go @@ -232,3 +232,35 @@ func folderNameFromFileName(filename string) string { } return base } + +// makeBaseDir returns the base directory to use for storing files in an +// archive. topLevelFolder should be the name of the top-level folder of +// the archive (if there is one), and sourceInfo is the file info obtained +// by calling os.Stat on the source file or directory to include in the +// archive. +func makeBaseDir(topLevelFolder string, sourceInfo os.FileInfo) string { + var baseDir string + if topLevelFolder != "" { + baseDir = topLevelFolder + } + if sourceInfo.IsDir() { + baseDir = path.Join(baseDir, sourceInfo.Name()) + } + return baseDir +} + +// makeNameInArchive returns the filename for the file given by fpath to be used within +// the archive. sourceInfo is the info obtained by calling os.Stat on source, and baseDir +// is the base directory obtained by calling makeBaseDir. fpath should be the unaltered +// file path of the file given to a filepath.WalkFunc. +func makeNameInArchive(sourceInfo os.FileInfo, source, baseDir, fpath string) (string, error) { + name := fpath + if sourceInfo.IsDir() { + var err error + name, err = filepath.Rel(source, fpath) + if err != nil { + return "", err + } + } + return path.Join(baseDir, filepath.ToSlash(name)), nil +} diff --git a/archiver_test.go b/archiver_test.go index dddb0bf9..9485e608 100644 --- a/archiver_test.go +++ b/archiver_test.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" "testing" + "time" ) func TestWithin(t *testing.T) { @@ -138,6 +139,94 @@ func TestMultipleTopLevels(t *testing.T) { } } +func TestMakeBaseDir(t *testing.T) { + for i, tc := range []struct { + topLevelFolder string + sourceInfo fakeFileInfo + expect string + }{ + { + topLevelFolder: "", + sourceInfo: fakeFileInfo{isDir: false}, + expect: "", + }, + { + topLevelFolder: "foo", + sourceInfo: fakeFileInfo{isDir: false}, + expect: "foo", + }, + { + topLevelFolder: "", + sourceInfo: fakeFileInfo{isDir: true, name: "bar"}, + expect: "bar", + }, + { + topLevelFolder: "foo", + sourceInfo: fakeFileInfo{isDir: true, name: "bar"}, + expect: "foo/bar", + }, + } { + actual := makeBaseDir(tc.topLevelFolder, tc.sourceInfo) + if actual != tc.expect { + t.Errorf("Test %d: Expected '%s' but got '%s'", i, tc.expect, actual) + } + } +} + +func TestMakeNameInArchive(t *testing.T) { + for i, tc := range []struct { + sourceInfo fakeFileInfo + source string + baseDir string + fpath string + expect string + }{ + { + sourceInfo: fakeFileInfo{isDir: false}, + source: "foo.txt", + baseDir: "", + fpath: "foo.txt", + expect: "foo.txt", + }, + { + sourceInfo: fakeFileInfo{isDir: false}, + source: "foo.txt", + baseDir: "base", + fpath: "foo.txt", + expect: "base/foo.txt", + }, + { + sourceInfo: fakeFileInfo{isDir: false}, + source: "foo/bar.txt", + baseDir: "", + fpath: "foo/bar.txt", + expect: "foo/bar.txt", + }, + { + sourceInfo: fakeFileInfo{isDir: false}, + source: "foo/bar.txt", + baseDir: "base", + fpath: "foo/bar.txt", + expect: "base/foo/bar.txt", + }, + { + sourceInfo: fakeFileInfo{isDir: true}, + source: "foo/bar", + baseDir: "bar", + fpath: "foo/bar", + expect: "bar", + }, + } { + actual, err := makeNameInArchive(tc.sourceInfo, tc.source, tc.baseDir, tc.fpath) + if err != nil { + t.Errorf("Test %d: Got error: %v", i, err) + } + if actual != tc.expect { + t.Errorf("Test %d: Expected '%s' but got '%s'", i, tc.expect, actual) + } + } +} + func TestArchiveUnarchive(t *testing.T) { for _, af := range archiveFormats { au, ok := af.(archiverUnarchiver) @@ -288,3 +377,19 @@ type archiverUnarchiver interface { Archiver Unarchiver } + +type fakeFileInfo struct { + name string + size int64 + mode os.FileMode + modTime time.Time + isDir bool + sys interface{} +} + +func (ffi fakeFileInfo) Name() string { return ffi.name } +func (ffi fakeFileInfo) Size() int64 { return ffi.size } +func (ffi fakeFileInfo) Mode() os.FileMode { return ffi.mode } +func (ffi fakeFileInfo) ModTime() time.Time { return ffi.modTime } +func (ffi fakeFileInfo) IsDir() bool { return ffi.isDir } +func (ffi fakeFileInfo) Sys() interface{} { return ffi.sys } diff --git a/tar.go b/tar.go index dd9cf0d2..6fbdf5f0 100644 --- a/tar.go +++ b/tar.go @@ -249,14 +249,7 @@ func (t *Tar) writeWalk(source, topLevelFolder, destination string) error { if err != nil { return fmt.Errorf("%s: getting absolute path of destination %s: %v", source, destination, err) } - - var baseDir string - if topLevelFolder != "" { - baseDir = topLevelFolder - } - if sourceInfo.IsDir() { - baseDir = path.Join(baseDir, sourceInfo.Name()) - } + baseDir := makeBaseDir(topLevelFolder, sourceInfo) return filepath.Walk(source, func(fpath string, info os.FileInfo, err error) error { handleErr := func(err error) error { @@ -282,12 +275,11 @@ func (t *Tar) writeWalk(source, topLevelFolder, destination string) error { return nil } - // build the name to be used in the archive - name, err := filepath.Rel(source, fpath) + // build the name to be used within the archive + nameInArchive, err := makeNameInArchive(sourceInfo, source, baseDir, fpath) if err != nil { return handleErr(err) } - nameInArchive := path.Join(baseDir, filepath.ToSlash(name)) file, err := os.Open(fpath) if err != nil { @@ -344,7 +336,6 @@ func (t *Tar) Write(f File) error { if f.ReadCloser == nil { return fmt.Errorf("%s: no way to read file contents", f.Name()) } - hdr, err := tar.FileInfoHeader(f, f.Name()) if err != nil { return fmt.Errorf("%s: making header: %v", f.Name(), err) diff --git a/zip.go b/zip.go index 9828c630..1f10696a 100644 --- a/zip.go +++ b/zip.go @@ -207,14 +207,7 @@ func (z *Zip) writeWalk(source, topLevelFolder, destination string) error { if err != nil { return fmt.Errorf("%s: getting absolute path of destination %s: %v", source, destination, err) } - - var baseDir string - if topLevelFolder != "" { - baseDir = topLevelFolder - } - if sourceInfo.IsDir() { - baseDir = path.Join(baseDir, sourceInfo.Name()) - } + baseDir := makeBaseDir(topLevelFolder, sourceInfo) return filepath.Walk(source, func(fpath string, info os.FileInfo, err error) error { handleErr := func(err error) error { @@ -242,11 +235,10 @@ func (z *Zip) writeWalk(source, topLevelFolder, destination string) error { } // build the name to be used within the archive - name, err := filepath.Rel(source, fpath) + nameInArchive, err := makeNameInArchive(sourceInfo, source, baseDir, fpath) if err != nil { return handleErr(err) } - nameInArchive := path.Join(baseDir, filepath.ToSlash(name)) file, err := os.Open(fpath) if err != nil {