Skip to content

Commit

Permalink
feat: skip rebuilding the weights image when weights are unchanged, to speed up builds (#1223)
Browse files Browse the repository at this point in the history
* feat: skip building the weights image to speed up

Signed-off-by: Hongchao Deng <[email protected]>

* Encode CRC32 checksums as hexadecimal instead of base64

Signed-off-by: Mattt Zmuda <[email protected]>

* Rename Hash to Manifest and refactor implementation

Signed-off-by: Mattt Zmuda <[email protected]>

* Remove comments for self-explanatory code

Signed-off-by: Mattt Zmuda <[email protected]>

---------

Signed-off-by: Hongchao Deng <[email protected]>
Signed-off-by: Mattt Zmuda <[email protected]>
Co-authored-by: Mattt Zmuda <[email protected]>
  • Loading branch information
hongchaodeng and mattt authored Jul 27, 2023
1 parent 22dfa5a commit 02c4a3e
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 5 deletions.
38 changes: 35 additions & 3 deletions pkg/dockerfile/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ type Generator struct {
relativeTmpDir string

fileWalker weights.FileWalker

modelDirs []string
modelFiles []string
}

func NewGenerator(config *config.Config, dir string) (*Generator, error) {
Expand Down Expand Up @@ -155,7 +158,7 @@ func (g *Generator) GenerateDockerfileWithoutSeparateWeights() (string, error) {
// - dockerignoreContents: A string that represents the .dockerignore content.
// - err: An error object if an error occurred during Dockerfile generation; otherwise nil.
func (g *Generator) Generate(imageName string) (weightsBase string, dockerfile string, dockerignoreContents string, err error) {
weightsBase, modelDirs, modelFiles, err := g.generateForWeights()
weightsBase, g.modelDirs, g.modelFiles, err = g.generateForWeights()
if err != nil {
return "", "", "", fmt.Errorf("Failed to generate Dockerfile for model weights files: %w", err)
}
Expand Down Expand Up @@ -201,7 +204,7 @@ func (g *Generator) Generate(imageName string) (weightsBase string, dockerfile s
runCommands,
}

for _, p := range append(modelDirs, modelFiles...) {
for _, p := range append(g.modelDirs, g.modelFiles...) {
base = append(base, "", fmt.Sprintf("COPY --from=%s --link %[2]s %[2]s", "weights", path.Join("/src", p)))
}

Expand All @@ -212,7 +215,7 @@ func (g *Generator) Generate(imageName string) (weightsBase string, dockerfile s
`COPY . /src`,
)

dockerignoreContents = makeDockerignoreForWeights(modelDirs, modelFiles)
dockerignoreContents = makeDockerignoreForWeights(g.modelDirs, g.modelFiles)
return weightsBase, strings.Join(filterEmpty(base), "\n"), dockerignoreContents, nil
}

Expand Down Expand Up @@ -410,3 +413,32 @@ func filterEmpty(list []string) []string {
}
return filtered
}

// GenerateWeightsManifest builds a weights manifest from the generator's
// recorded model directories and model files.
//
// Each directory in g.modelDirs is traversed with g.fileWalker and every
// non-directory entry is added to the manifest; each path in g.modelFiles is
// then added directly. Both lists are assigned by Generate, so they are empty
// until Generate has run.
//
// The first error from the walker or from Manifest.AddFile is returned
// unchanged, together with a nil manifest.
func (g *Generator) GenerateWeightsManifest() (*weights.Manifest, error) {
	manifest := weights.NewManifest()

	// visit is the walk callback: it propagates traversal errors, skips
	// directories, and checksums everything else into the manifest.
	visit := func(entry string, fi os.FileInfo, walkErr error) error {
		switch {
		case walkErr != nil:
			return walkErr
		case fi.IsDir():
			return nil
		default:
			return manifest.AddFile(entry)
		}
	}

	for i := range g.modelDirs {
		if err := g.fileWalker(g.modelDirs[i], visit); err != nil {
			return nil, err
		}
	}

	for i := range g.modelFiles {
		if err := manifest.AddFile(g.modelFiles[i]); err != nil {
			return nil, err
		}
	}

	return manifest, nil
}
24 changes: 22 additions & 2 deletions pkg/image/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ import (
"github.com/replicate/cog/pkg/dockerfile"
"github.com/replicate/cog/pkg/global"
"github.com/replicate/cog/pkg/util/console"
"github.com/replicate/cog/pkg/weights"
)

const (
	// dockerignoreBackupPath is the file name used for the .dockerignore
	// backup (presumably written by backupDockerignore — confirm against
	// that helper, which is not shown here).
	dockerignoreBackupPath = ".dockerignore.cog.bak"

	// weightsManifestPath is where Build caches the weights manifest: it is
	// loaded before deciding whether to rebuild the weights image and saved
	// after a successful weights image build.
	weightsManifestPath = ".cog/cache/weights_manifest.json"
)

// Build a Cog model from a config
//
Expand All @@ -40,8 +42,26 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache,
return fmt.Errorf("Failed to generate Dockerfile: %w", err)
}

if err := buildWeightsImage(dir, weightsDockerfile, imageName+"-weights", secrets, noCache, progressOutput); err != nil {
return fmt.Errorf("Failed to build model weights Docker image: %w", err)
if err := backupDockerignore(); err != nil {
return fmt.Errorf("Failed to backup .dockerignore file: %w", err)
}

weightsManifest, err := generator.GenerateWeightsManifest()
if err != nil {
	return fmt.Errorf("Failed to generate weights manifest: %w", err)
}

// The weights image must be rebuilt when there is no usable cached manifest
// (first build, or the cache file is missing or unreadable) or when the
// cached manifest differs from the current one. The load error is tested
// first so that Equal is never called with the nil manifest LoadManifest
// returns on failure.
cachedManifest, err := weights.LoadManifest(weightsManifestPath)
changed := err != nil || !weightsManifest.Equal(cachedManifest)
if changed {
	if err := buildWeightsImage(dir, weightsDockerfile, imageName+"-weights", secrets, noCache, progressOutput); err != nil {
		return fmt.Errorf("Failed to build model weights Docker image: %w", err)
	}
	// Cache the manifest only after the image built successfully, so a
	// failed build is retried on the next run.
	if err := weightsManifest.Save(weightsManifestPath); err != nil {
		return fmt.Errorf("Failed to save weights hash: %w", err)
	}
} else {
	console.Info("Weights unchanged, skip rebuilding and use cached image...")
}

if err := buildRunnerImage(dir, runnerDockerfile, dockerignore, imageName, secrets, noCache, progressOutput); err != nil {
Expand Down
107 changes: 107 additions & 0 deletions pkg/weights/manifest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package weights

import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"hash/crc32"
"io"
"os"
"path"
)

type (
	// Manifest records metadata for the weights files of a model. Files is
	// keyed by the file path that was passed to AddFile.
	Manifest struct {
		Files map[string]Metadata `json:"files"`
	}

	// Metadata describes a single weights file.
	Metadata struct {
		// CRC32 is the file's CRC32 checksum, encoded as a hexadecimal string.
		CRC32 string `json:"crc32"`
	}
)

// NewManifest returns a pointer to a new, empty Manifest. The Files map is
// left nil; AddFile allocates it when the first file is added.
func NewManifest() *Manifest {
	return new(Manifest)
}

// LoadManifest reads the file at filename and JSON-decodes it into a Manifest.
//
// It returns a nil manifest and a non-nil error when the file cannot be
// opened (for example because it does not exist) or when its contents cannot
// be decoded as JSON.
func LoadManifest(filename string) (*Manifest, error) {
	// os.Open already reports a missing or inaccessible file, so a separate
	// os.Stat beforehand would only add a syscall and a check-then-use race.
	file, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	m := &Manifest{}
	if err := json.NewDecoder(file).Decode(m); err != nil {
		return nil, err
	}
	return m, nil
}

// Save writes the manifest to filename as JSON, creating the parent
// directory (mode 0o755) if needed and truncating any existing file.
//
// It returns the first error from creating the directory, creating the file,
// encoding the manifest, or closing the file.
func (m *Manifest) Save(filename string) error {
	if err := os.MkdirAll(path.Dir(filename), 0o755); err != nil {
		return err
	}

	file, err := os.Create(filename)
	if err != nil {
		return err
	}

	if err := json.NewEncoder(file).Encode(m); err != nil {
		// The encode failure is the error worth reporting; closing here is
		// best-effort cleanup.
		_ = file.Close()
		return err
	}

	// Close explicitly instead of deferring: for a file that was just
	// written, a Close error can signal that the data never reached disk.
	return file.Close()
}

// Equal reports whether m and other list exactly the same files with exactly
// the same metadata.
//
// A nil manifest is equal only to another nil manifest, so passing the nil
// result of a failed LoadManifest returns false instead of panicking. A
// manifest with a nil Files map and one with an empty Files map are equal.
func (m *Manifest) Equal(other *Manifest) bool {
	if m == nil || other == nil {
		return m == other
	}

	if len(m.Files) != len(other.Files) {
		return false
	}

	// With equal lengths, it is enough to check that every entry of m has an
	// identical counterpart in other.
	for name, meta := range m.Files {
		if otherMeta, ok := other.Files[name]; !ok || otherMeta != meta {
			return false
		}
	}

	return true
}

// AddFile computes the CRC32 (IEEE) checksum of the file at path and records
// it in the manifest under that path, replacing any existing entry.
//
// The checksum is stored as the hexadecimal encoding of its four bytes in
// little-endian order. An error is returned if the file cannot be opened or
// read; in that case the manifest is left unmodified.
func (m *Manifest) AddFile(path string) error {
	src, err := os.Open(path)
	if err != nil {
		return fmt.Errorf("failed to open file %s: %w", path, err)
	}
	defer src.Close()

	// Stream the file through the hasher so large weights files are never
	// held in memory.
	hasher := crc32.NewIEEE()
	if _, err := io.Copy(hasher, src); err != nil {
		return fmt.Errorf("failed to generate checksum of file %s: %w", path, err)
	}

	var raw [4]byte
	binary.LittleEndian.PutUint32(raw[:], hasher.Sum32())

	// The map is allocated lazily so a zero-value Manifest is usable.
	if m.Files == nil {
		m.Files = map[string]Metadata{}
	}
	m.Files[path] = Metadata{CRC32: hex.EncodeToString(raw[:])}

	return nil
}

0 comments on commit 02c4a3e

Please sign in to comment.