Skip to content

Commit

Permalink
[uv] use torch cpu integration instead of env var
Browse files Browse the repository at this point in the history
  • Loading branch information
rmehri01 committed Jan 6, 2025
1 parent 8f4d06f commit 14902cf
Show file tree
Hide file tree
Showing 2 changed files with 221 additions and 32 deletions.
176 changes: 144 additions & 32 deletions internal/backends/python/python.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ package python
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"regexp"
"runtime"
"slices"
"strings"

"github.com/BurntSushi/toml"
Expand All @@ -22,22 +23,7 @@ import (

var normalizationPattern = regexp.MustCompile(`[-_.]+`)

type extraIndex struct {
// url is the location of the index
url string
// os is the operating system to override the index for, leave empty
// to override on any operating system
os string
}

var torchCpu = extraIndex{
url: "https://download.pytorch.org/whl/cpu",
os: "linux",
}

var extraIndexMap = map[string][]extraIndex{
"torch": {torchCpu},
}
var torchOverrides = []string{"torch", "torchvision"}

// this generates a mapping of pypi packages <-> modules
// moduleToPypiPackage pypiPackageToModules are provided
Expand Down Expand Up @@ -73,6 +59,12 @@ type pyprojectTOMLGroup struct {
Dependencies map[string]interface{} `json:"dependencies"`
}

type pyprojectUVIndex struct {
Name string `json:"name"`
Url string `json:"url"`
Explicit bool `json:"explicit"`
}

// pyprojectTOML represents the relevant parts of a pyproject.toml
// file.
type pyprojectTOML struct {
Expand All @@ -95,6 +87,7 @@ type pyprojectTOML struct {
} `toml:"poetry"`
Uv *struct {
Sources map[string]interface{} `toml:"sources"`
Index []pyprojectUVIndex `toml:"index"`
} `toml:"uv"`
} `toml:"tool"`
}
Expand Down Expand Up @@ -769,20 +762,123 @@ func makePythonUvBackend() api.LanguageBackend {

return pkgs
}
addExtraIndexes := func(pkgName string) {
extraIndexes, ok := extraIndexMap[pkgName]
if ok {
uvIndex := os.Getenv("UV_INDEX")

for _, index := range extraIndexes {
if strings.HasPrefix(runtime.GOOS, index.os) {
uvIndex = index.url + " " + uvIndex
addTorchOverride := func() error {
if !util.TomlEditorIsAvailable() {
return errors.New(
"toml-editor is not on the PATH, please install it with " +
"`nix profile install github:replit/toml-editor` or " +
"`cargo install --git https://github.com/replit/toml-editor` and ensure it's on the PATH.",
)
}

// check if the torch cpu index is already present
getOp := util.TomlEditorOp{
Op: "get",
Path: "tool/uv/index",
}
response, err := util.ExecTomlEditor("pyproject.toml", []util.TomlEditorOp{getOp})
if err != nil {
return err
}
if len(response.Results) != 1 {
return fmt.Errorf("expected one result")
}

hasIndex := false
result := response.Results[0]
if result != nil {
if arr, ok := result.([]interface{}); ok {
for _, value := range arr {
index, ok := value.(map[string]interface{})
if !ok {
continue
}
name, ok := index["name"]
if !ok {
continue
}
hasIndex = name == "pytorch-cpu"
if hasIndex {
break
}
}
}
}
if !hasIndex {
value := map[string]interface{}{"name": "pytorch-cpu", "url": "https://download.pytorch.org/whl/cpu", "explicit": true}
valueBytes, err := json.Marshal(value)
if err != nil {
return err
}
addOp := util.TomlEditorOp{
Op: "add",
TableHeaderPath: "tool/uv/index/[[]]",
Value: string(valueBytes),
}
_, err = util.ExecTomlEditor("pyproject.toml", []util.TomlEditorOp{addOp})
if err != nil {
return err
}
}

for _, name := range torchOverrides {
// check if the source is already present
getOp := util.TomlEditorOp{
Op: "get",
Path: fmt.Sprintf("tool/uv/sources/%s", name),
}
response, err := util.ExecTomlEditor("pyproject.toml", []util.TomlEditorOp{getOp})
if err != nil {
util.Log(fmt.Sprintf("error while checking override '%s': %s", name, err))
continue
}
if len(response.Results) != 1 {
util.Log(fmt.Sprintf("error while checking override '%s': expected one result", name))
continue
}

os.Setenv("UV_INDEX", uvIndex)
hasSource := false
var sources []interface{}
result := response.Results[0]
if result != nil {
if arr, ok := result.([]interface{}); ok {
for _, value := range arr {
source, ok := value.(map[string]interface{})
if !ok {
continue
}
name, ok := source["index"]
if !ok {
continue
}
hasSource = name == "pytorch-cpu"
if hasSource {
break
}
}
sources = append(sources, arr...)
}
}
if !hasSource {
sources = append(sources, map[string]interface{}{"index": "pytorch-cpu", "marker": "platform_system == 'Linux'"})
valueBytes, err := json.Marshal(sources)
if err != nil {
return err
}
addOp := util.TomlEditorOp{
Op: "add",
TableHeaderPath: "tool/uv/sources",
Path: name,
Value: string(valueBytes),
}
_, err = util.ExecTomlEditor("pyproject.toml", []util.TomlEditorOp{addOp})
if err != nil {
return err
}
}
}
os.Setenv("UV_INDEX_STRATEGY", "unsafe-best-match")

return nil
}

b := api.LanguageBackend{
Expand Down Expand Up @@ -857,6 +953,7 @@ func makePythonUvBackend() api.LanguageBackend {
}
}

hasTorch := false
cmd := []string{"uv", "add"}
for name, coords := range pkgs {
if found, ok := moduleToPypiPackageAliases[string(name)]; ok {
Expand All @@ -867,12 +964,17 @@ func makePythonUvBackend() api.LanguageBackend {
}

cmd = append(cmd, pep440Join(coords))
addExtraIndexes(string(name))

if slices.Contains(torchOverrides, string(name)) {
hasTorch = true
}
}

specPkgs := listUvSpecfile()
for pkg := range specPkgs {
addExtraIndexes(string(pkg))
if hasTorch {
err := addTorchOverride()
if err != nil {
util.DieSubprocess("%s", err)
}
}

util.RunCmd(cmd)
Expand All @@ -899,9 +1001,19 @@ func makePythonUvBackend() api.LanguageBackend {
span, ctx := tracer.StartSpanFromContext(ctx, "uv install")
defer span.Finish()

hasTorch := false
pkgs := listUvSpecfile()
for pkg := range pkgs {
addExtraIndexes(string(pkg))
if slices.Contains(torchOverrides, string(pkg)) {
hasTorch = true
}
}

if hasTorch {
err := addTorchOverride()
if err != nil {
util.DieSubprocess("%s", err)
}
}

util.RunCmd([]string{"uv", "sync"})
Expand Down
77 changes: 77 additions & 0 deletions internal/util/toml-editor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package util

import (
"encoding/json"
"errors"
"fmt"
"io"
"os/exec"
)

// TomlEditorOp is the format of the JSON sent to toml-editor
type TomlEditorOp struct {
Op string `json:"op"`
Path string `json:"path,omitempty"`
TableHeaderPath string `json:"table_header_path,omitempty"`
Value string `json:"value,omitempty"`
}

// TomlEditorResponse is the format of the JSON sent from toml-editor
type TomlEditorResponse struct {
Status string `json:"status"`
Message string `json:"message"`
Results []interface{} `json:"results"`
}

func TomlEditorIsAvailable() bool {
_, err := exec.LookPath("toml-editor")
return err == nil
}

func ExecTomlEditor(tomlPath string, ops []TomlEditorOp) (*TomlEditorResponse, error) {
cmd := exec.Command("toml-editor", "--path", tomlPath)
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}
err = cmd.Start()
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}

encoder := json.NewEncoder(stdin)
err = encoder.Encode(ops)
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}
decoder := json.NewDecoder(stdout)
var tomlEditorResponse TomlEditorResponse
err = decoder.Decode(&tomlEditorResponse)
if err != nil {
if !errors.Is(err, io.EOF) {
return nil, fmt.Errorf("unexpected toml-editor output: %s", err)
}
}
if tomlEditorResponse.Status != "success" {
input, _ := json.Marshal(ops)
return nil, fmt.Errorf("toml-editor error with input %s: %s", input, tomlEditorResponse.Message)
}

stdout.Close()
err = stdin.Close()
if err != nil {
return nil, fmt.Errorf("toml-editor error: %s", err)
}

err = cmd.Wait()
if err != nil {
input, _ := json.Marshal(ops)
return nil, fmt.Errorf("toml-editor error with input %s: %s", input, err)
}

return &tomlEditorResponse, nil
}

0 comments on commit 14902cf

Please sign in to comment.