From 14902cf7b75c2c217140533ba6530d804d40da61 Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Tue, 31 Dec 2024 19:30:06 +0000 Subject: [PATCH] [uv] use torch cpu integration instead of env var --- internal/backends/python/python.go | 176 +++++++++++++++++++++++------ internal/util/toml-editor.go | 77 +++++++++++++ 2 files changed, 221 insertions(+), 32 deletions(-) create mode 100644 internal/util/toml-editor.go diff --git a/internal/backends/python/python.go b/internal/backends/python/python.go index bc353370..aa8334e5 100644 --- a/internal/backends/python/python.go +++ b/internal/backends/python/python.go @@ -4,12 +4,13 @@ package python import ( "context" "encoding/json" + "errors" "fmt" "io" "os" "os/exec" "regexp" - "runtime" + "slices" "strings" "github.com/BurntSushi/toml" @@ -22,22 +23,7 @@ import ( var normalizationPattern = regexp.MustCompile(`[-_.]+`) -type extraIndex struct { - // url is the location of the index - url string - // os is the operating system to override the index for, leave empty - // to override on any operating system - os string -} - -var torchCpu = extraIndex{ - url: "https://download.pytorch.org/whl/cpu", - os: "linux", -} - -var extraIndexMap = map[string][]extraIndex{ - "torch": {torchCpu}, -} +var torchOverrides = []string{"torch", "torchvision"} // this generates a mapping of pypi packages <-> modules // moduleToPypiPackage pypiPackageToModules are provided @@ -73,6 +59,12 @@ type pyprojectTOMLGroup struct { Dependencies map[string]interface{} `json:"dependencies"` } +type pyprojectUVIndex struct { + Name string `json:"name"` + Url string `json:"url"` + Explicit bool `json:"explicit"` +} + // pyprojectTOML represents the relevant parts of a pyproject.toml // file. type pyprojectTOML struct { @@ -95,6 +87,7 @@ type pyprojectTOML struct { } `toml:"poetry"` Uv *struct { Sources map[string]interface{} `toml:"sources"` + Index []pyprojectUVIndex `toml:"index"` } `toml:"uv"` } `toml:"tool"` } @@ -769,20 +762,123 @@ func makePythonUvBackend() api.LanguageBackend { return pkgs } - addExtraIndexes := func(pkgName string) { - extraIndexes, ok := extraIndexMap[pkgName] - if ok { - uvIndex := os.Getenv("UV_INDEX") - - for _, index := range extraIndexes { - if strings.HasPrefix(runtime.GOOS, index.os) { - uvIndex = index.url + " " + uvIndex + addTorchOverride := func() error { + if !util.TomlEditorIsAvailable() { + return errors.New( + "toml-editor is not on the PATH, please install it with " + + "`nix profile install github:replit/toml-editor` or " + + "`cargo install --git https://github.com/replit/toml-editor` and ensure it's on the PATH.", + ) + } + + // check if the torch cpu index is already present + getOp := util.TomlEditorOp{ + Op: "get", + Path: "tool/uv/index", + } + response, err := util.ExecTomlEditor("pyproject.toml", []util.TomlEditorOp{getOp}) + if err != nil { + return err + } + if len(response.Results) != 1 { + return fmt.Errorf("expected one result") + } + + hasIndex := false + result := response.Results[0] + if result != nil { + if arr, ok := result.([]interface{}); ok { + for _, value := range arr { + index, ok := value.(map[string]interface{}) + if !ok { + continue + } + name, ok := index["name"] + if !ok { + continue + } + hasIndex = name == "pytorch-cpu" + if hasIndex { + break + } } } + } + if !hasIndex { + value := map[string]interface{}{"name": "pytorch-cpu", "url": "https://download.pytorch.org/whl/cpu", "explicit": true} + valueBytes, err := json.Marshal(value) + if err != nil { + return err + } + addOp := util.TomlEditorOp{ + Op: "add", + TableHeaderPath: "tool/uv/index/[[]]", + Value: string(valueBytes), + } + _, err = util.ExecTomlEditor("pyproject.toml", []util.TomlEditorOp{addOp}) + if err != nil { + return err + } + } + + for _, name := range torchOverrides { + // check if the source is already present + getOp := util.TomlEditorOp{ + Op: "get", + Path: fmt.Sprintf("tool/uv/sources/%s", name), + } + response, err := util.ExecTomlEditor("pyproject.toml", []util.TomlEditorOp{getOp}) + if err != nil { + util.Log(fmt.Sprintf("error while checking override '%s': %s", name, err)) + continue + } + if len(response.Results) != 1 { + util.Log(fmt.Sprintf("error while checking override '%s': expected one result", name)) + continue + } - os.Setenv("UV_INDEX", uvIndex) + hasSource := false + var sources []interface{} + result := response.Results[0] + if result != nil { + if arr, ok := result.([]interface{}); ok { + for _, value := range arr { + source, ok := value.(map[string]interface{}) + if !ok { + continue + } + name, ok := source["index"] + if !ok { + continue + } + hasSource = name == "pytorch-cpu" + if hasSource { + break + } + } + sources = append(sources, arr...) + } + } + if !hasSource { + sources = append(sources, map[string]interface{}{"index": "pytorch-cpu", "marker": "platform_system == 'Linux'"}) + valueBytes, err := json.Marshal(sources) + if err != nil { + return err + } + addOp := util.TomlEditorOp{ + Op: "add", + TableHeaderPath: "tool/uv/sources", + Path: name, + Value: string(valueBytes), + } + _, err = util.ExecTomlEditor("pyproject.toml", []util.TomlEditorOp{addOp}) + if err != nil { + return err + } + } } - os.Setenv("UV_INDEX_STRATEGY", "unsafe-best-match") + + return nil } b := api.LanguageBackend{ @@ -857,6 +953,7 @@ func makePythonUvBackend() api.LanguageBackend { } } + hasTorch := false cmd := []string{"uv", "add"} for name, coords := range pkgs { if found, ok := moduleToPypiPackageAliases[string(name)]; ok { @@ -867,12 +964,17 @@ func makePythonUvBackend() api.LanguageBackend { } cmd = append(cmd, pep440Join(coords)) - addExtraIndexes(string(name)) + + if slices.Contains(torchOverrides, string(name)) { + hasTorch = true + } } - specPkgs := listUvSpecfile() - for pkg := range specPkgs { - addExtraIndexes(string(pkg)) + if hasTorch { + err := addTorchOverride() + if err != nil { + util.DieSubprocess("%s", err) + } } util.RunCmd(cmd) @@ -899,9 +1001,19 @@ func makePythonUvBackend() api.LanguageBackend { span, ctx := tracer.StartSpanFromContext(ctx, "uv install") defer span.Finish() + hasTorch := false pkgs := listUvSpecfile() for pkg := range pkgs { - addExtraIndexes(string(pkg)) + if slices.Contains(torchOverrides, string(pkg)) { + hasTorch = true + } + } + + if hasTorch { + err := addTorchOverride() + if err != nil { + util.DieSubprocess("%s", err) + } } util.RunCmd([]string{"uv", "sync"}) diff --git a/internal/util/toml-editor.go b/internal/util/toml-editor.go new file mode 100644 index 00000000..7ea84308 --- /dev/null +++ b/internal/util/toml-editor.go @@ -0,0 +1,77 @@ +package util + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "os/exec" +) + +// TomlEditorOp is the format of the JSON sent to toml-editor +type TomlEditorOp struct { + Op string `json:"op"` + Path string `json:"path,omitempty"` + TableHeaderPath string `json:"table_header_path,omitempty"` + Value string `json:"value,omitempty"` +} + +// TomlEditorResponse is the format of the JSON sent from toml-editor +type TomlEditorResponse struct { + Status string `json:"status"` + Message string `json:"message"` + Results []interface{} `json:"results"` +} + +func TomlEditorIsAvailable() bool { + _, err := exec.LookPath("toml-editor") + return err == nil +} + +func ExecTomlEditor(tomlPath string, ops []TomlEditorOp) (*TomlEditorResponse, error) { + cmd := exec.Command("toml-editor", "--path", tomlPath) + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("toml-editor error: %s", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("toml-editor error: %s", err) + } + err = cmd.Start() + if err != nil { + return nil, fmt.Errorf("toml-editor error: %s", err) + } + + encoder := json.NewEncoder(stdin) + err = encoder.Encode(ops) + if err != nil { + return nil, fmt.Errorf("toml-editor error: %s", err) + } + decoder := json.NewDecoder(stdout) + var tomlEditorResponse TomlEditorResponse + err = decoder.Decode(&tomlEditorResponse) + if err != nil { + if !errors.Is(err, io.EOF) { + return nil, fmt.Errorf("unexpected toml-editor output: %s", err) + } + } + if tomlEditorResponse.Status != "success" { + input, _ := json.Marshal(ops) + return nil, fmt.Errorf("toml-editor error with input %s: %s", input, tomlEditorResponse.Message) + } + + stdout.Close() + err = stdin.Close() + if err != nil { + return nil, fmt.Errorf("toml-editor error: %s", err) + } + + err = cmd.Wait() + if err != nil { + input, _ := json.Marshal(ops) + return nil, fmt.Errorf("toml-editor error with input %s: %s", input, err) + } + + return &tomlEditorResponse, nil +}