diff --git a/gopls/internal/golang/completion/literal.go b/gopls/internal/golang/completion/literal.go index 7427d559e94..82f123048e8 100644 --- a/gopls/internal/golang/completion/literal.go +++ b/gopls/internal/golang/completion/literal.go @@ -370,6 +370,8 @@ func (c *completer) functionLiteral(ctx context.Context, sig *types.Signature, m // conventionalAcronyms contains conventional acronyms for type names // in lower case. For example, "ctx" for "context" and "err" for "error". +// +// Keep this up to date with golang.conventionalVarNames. var conventionalAcronyms = map[string]string{ "context": "ctx", "error": "err", @@ -382,11 +384,6 @@ var conventionalAcronyms = map[string]string{ // non-identifier runes. For example, "[]int" becomes "i", and // "struct { i int }" becomes "s". func abbreviateTypeName(s string) string { - var ( - b strings.Builder - useNextUpper bool - ) - // Trim off leading non-letters. We trim everything between "[" and // "]" to handle array types like "[someConst]int". var inBracket bool @@ -407,27 +404,7 @@ func abbreviateTypeName(s string) string { return acr } - for i, r := range s { - // Stop if we encounter a non-identifier rune. - if !unicode.IsLetter(r) && !unicode.IsNumber(r) { - break - } - - if i == 0 { - b.WriteRune(unicode.ToLower(r)) - } - - if unicode.IsUpper(r) { - if useNextUpper { - b.WriteRune(unicode.ToLower(r)) - useNextUpper = false - } - } else { - useNextUpper = true - } - } - - return b.String() + return golang.AbbreviateVarName(s) } // compositeLiteral adds a composite literal completion item for the given typeName. diff --git a/gopls/internal/golang/extract.go b/gopls/internal/golang/extract.go index 510c6f6eba3..82ef6fd69ad 100644 --- a/gopls/internal/golang/extract.go +++ b/gopls/internal/golang/extract.go @@ -37,14 +37,14 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file // TODO: stricter rules for selectorExpr. case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: - lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", 0) + lhsName, _ := generateAvailableName(expr.Pos(), path, pkg, info, "x", 0) lhsNames = append(lhsNames, lhsName) case *ast.CallExpr: tup, ok := info.TypeOf(expr).(*types.Tuple) if !ok { // If the call expression only has one return value, we can treat it the // same as our standard extract variable case. - lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", 0) + lhsName, _ := generateAvailableName(expr.Pos(), path, pkg, info, "x", 0) lhsNames = append(lhsNames, lhsName) break } @@ -52,7 +52,7 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file for i := 0; i < tup.Len(); i++ { // Generate a unique variable for each return value. var lhsName string - lhsName, idx = generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", idx) + lhsName, idx = generateAvailableName(expr.Pos(), path, pkg, info, "x", idx) lhsNames = append(lhsNames, lhsName) } default: @@ -150,12 +150,12 @@ func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast. return string(content[lineOffset:stmtOffset]), nil } -// generateAvailableIdentifier adjusts the new function name until there are no collisions in scope. +// generateAvailableName adjusts the new function name until there are no collisions in scope. // Possible collisions include other function and variable names. Returns the next index to check for prefix. -func generateAvailableIdentifier(pos token.Pos, path []ast.Node, pkg *types.Package, info *types.Info, prefix string, idx int) (string, int) { +func generateAvailableName(pos token.Pos, path []ast.Node, pkg *types.Package, info *types.Info, prefix string, idx int) (string, int) { scopes := CollectScopes(info, path, pos) scopes = append(scopes, pkg.Scope()) - return generateIdentifier(idx, prefix, func(name string) bool { + return generateName(idx, prefix, func(name string) bool { for _, scope := range scopes { if scope != nil && scope.Lookup(name) != nil { return true @@ -165,7 +165,31 @@ func generateAvailableIdentifier(pos token.Pos, path []ast.Node, pkg *types.Pack }) } -func generateIdentifier(idx int, prefix string, hasCollision func(string) bool) (string, int) { +// generateNameOutsideOfRange is like generateAvailableName, but ignores names +// declared between start and end for the purposes of detecting conflicts. +// +// This is used for function extraction, where [start, end) will be extracted +// to a new scope. +func generateNameOutsideOfRange(start, end token.Pos, path []ast.Node, pkg *types.Package, info *types.Info, prefix string, idx int) (string, int) { + scopes := CollectScopes(info, path, start) + scopes = append(scopes, pkg.Scope()) + return generateName(idx, prefix, func(name string) bool { + for _, scope := range scopes { + if scope != nil { + if obj := scope.Lookup(name); obj != nil { + // Only report a collision if the object declaration was outside the + // extracted range. + if obj.Pos() < start || end <= obj.Pos() { + return true + } + } + } + } + return false + }) +} + +func generateName(idx int, prefix string, hasCollision func(string) bool) (string, int) { name := prefix if idx != 0 { name += fmt.Sprintf("%d", idx) @@ -182,7 +206,7 @@ func generateIdentifier(idx int, prefix string, hasCollision func(string) bool) type returnVariable struct { // name is the identifier that is used on the left-hand side of the call to // the extracted function. - name ast.Expr + name *ast.Ident // decl is the declaration of the variable. It is used in the type signature of the // extracted function and for variable declarations. decl *ast.Field @@ -517,7 +541,7 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte // statements in the selection. Update the type signature of the extracted // function and construct the if statement that will be inserted in the enclosing // function. - retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, start, hasNonNestedReturn) + retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, start, end, hasNonNestedReturn) if err != nil { return nil, nil, err } @@ -552,7 +576,7 @@ func extractFunctionMethod(fset *token.FileSet, start, end token.Pos, src []byte funName = name } else { name = "newFunction" - funName, _ = generateAvailableIdentifier(start, path, pkg, info, name, 0) + funName, _ = generateAvailableName(start, path, pkg, info, name, 0) } extractedFunCall := generateFuncCall(hasNonNestedReturn, hasReturnValues, params, append(returns, getNames(retVars)...), funName, sym, receiverName) @@ -1187,12 +1211,12 @@ func parseBlockStmt(fset *token.FileSet, src []byte) (*ast.BlockStmt, error) { // signature of the extracted function. We prepare names, signatures, and "zero values" that // represent the new variables. We also use this information to construct the if statement that // is inserted below the call to the extracted function. -func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, pos token.Pos, hasNonNestedReturns bool) ([]*returnVariable, *ast.IfStmt, error) { +func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, start, end token.Pos, hasNonNestedReturns bool) ([]*returnVariable, *ast.IfStmt, error) { var retVars []*returnVariable var cond *ast.Ident if !hasNonNestedReturns { // Generate information for the added bool value. - name, _ := generateAvailableIdentifier(pos, path, pkg, info, "shouldReturn", 0) + name, _ := generateNameOutsideOfRange(start, end, path, pkg, info, "shouldReturn", 0) cond = &ast.Ident{Name: name} retVars = append(retVars, &returnVariable{ name: cond, @@ -1202,7 +1226,7 @@ func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast. } // Generate information for the values in the return signature of the enclosing function. if enclosing.Results != nil { - idx := 0 + nameIdx := make(map[string]int) // last integral suffixes of generated names for _, field := range enclosing.Results.List { typ := info.TypeOf(field.Type) if typ == nil { @@ -1213,17 +1237,32 @@ func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast. if expr == nil { return nil, nil, fmt.Errorf("nil AST expression") } - var name string - name, idx = generateAvailableIdentifier(pos, path, pkg, info, "returnValue", idx) - z := analysisinternal.ZeroValue(file, pkg, typ) - if z == nil { - return nil, nil, fmt.Errorf("can't generate zero value for %T", typ) + names := []string{""} + if len(field.Names) > 0 { + names = nil + for _, n := range field.Names { + names = append(names, n.Name) + } + } + for _, name := range names { + bestName := "result" + if name != "" && name != "_" { + bestName = name + } else if n, ok := varNameForType(typ); ok { + bestName = n + } + retName, idx := generateNameOutsideOfRange(start, end, path, pkg, info, bestName, nameIdx[bestName]) + nameIdx[bestName] = idx + z := analysisinternal.ZeroValue(file, pkg, typ) + if z == nil { + return nil, nil, fmt.Errorf("can't generate zero value for %T", typ) + } + retVars = append(retVars, &returnVariable{ + name: ast.NewIdent(retName), + decl: &ast.Field{Type: expr}, + zeroVal: z, + }) } - retVars = append(retVars, &returnVariable{ - name: ast.NewIdent(name), - decl: &ast.Field{Type: expr}, - zeroVal: z, - }) } } var ifReturn *ast.IfStmt @@ -1240,6 +1279,48 @@ func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast. return retVars, ifReturn, nil } +type objKey struct{ pkg, name string } + +// conventionalVarNames specifies conventional names for variables with various +// standard library types. +// +// Keep this up to date with completion.conventionalAcronyms. +// +// TODO(rfindley): consider factoring out a "conventions" library. +var conventionalVarNames = map[objKey]string{ + {"", "error"}: "err", + {"context", "Context"}: "ctx", + {"sql", "Tx"}: "tx", + {"http", "ResponseWriter"}: "rw", // Note: same as [AbbreviateVarName]. +} + +// varNameForTypeName chooses a "good" name for a variable with the given type, +// if possible. Otherwise, it returns "", false. +// +// For special types, it uses known conventional names. +func varNameForType(t types.Type) (string, bool) { + var typeName string + if tn, ok := t.(interface{ Obj() *types.TypeName }); ok { + obj := tn.Obj() + k := objKey{name: obj.Name()} + if obj.Pkg() != nil { + k.pkg = obj.Pkg().Name() + } + if name, ok := conventionalVarNames[k]; ok { + return name, true + } + typeName = obj.Name() + } else if b, ok := t.(*types.Basic); ok { + typeName = b.Name() + } + + if typeName == "" { + return "", false + } + + return AbbreviateVarName(typeName), true +} + // adjustReturnStatements adds "zero values" of the given types to each return statement // in the given AST node. func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object]ast.Expr, file *ast.File, pkg *types.Package, extractedBlock *ast.BlockStmt) error { @@ -1346,9 +1427,8 @@ func initializeVars(uninitialized []types.Object, retVars []*returnVariable, see // Each variable added from a return statement in the selection // must be initialized. for i, retVar := range retVars { - n := retVar.name.(*ast.Ident) valSpec := &ast.ValueSpec{ - Names: []*ast.Ident{n}, + Names: []*ast.Ident{retVar.name}, Type: retVars[i].decl.Type, } genDecl := &ast.GenDecl{ diff --git a/gopls/internal/golang/util.go b/gopls/internal/golang/util.go index 18f72421a64..06239af17d6 100644 --- a/gopls/internal/golang/util.go +++ b/gopls/internal/golang/util.go @@ -12,6 +12,7 @@ import ( "go/types" "regexp" "strings" + "unicode" "golang.org/x/tools/gopls/internal/cache" "golang.org/x/tools/gopls/internal/cache/metadata" @@ -363,3 +364,36 @@ func btoi(b bool) int { return 0 } } + +// AbbreviateVarName returns an abbreviated var name based on the given full +// name (which may be a type name, for example). +// +// See the simple heuristics documented in line. +func AbbreviateVarName(s string) string { + var ( + b strings.Builder + useNextUpper bool + ) + for i, r := range s { + // Stop if we encounter a non-identifier rune. + if !unicode.IsLetter(r) && !unicode.IsNumber(r) { + break + } + + // Otherwise, take the first letter from word boundaries, assuming + // camelCase. + if i == 0 { + b.WriteRune(unicode.ToLower(r)) + } + + if unicode.IsUpper(r) { + if useNextUpper { + b.WriteRune(unicode.ToLower(r)) + useNextUpper = false + } + } else { + useNextUpper = true + } + } + return b.String() +} diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt index 1b9f487c49d..6ae0bc7177e 100644 --- a/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction.txt @@ -56,9 +56,9 @@ package extract func _() bool { x := 1 //@codeaction("if", ifend, "refactor.extract.function", return) - shouldReturn, returnValue := newFunction(x) + shouldReturn, b := newFunction(x) if shouldReturn { - return returnValue + return b } //@loc(ifend, "}") return false } @@ -124,9 +124,9 @@ func _() (int, string, error) { x := 1 y := "hello" //@codeaction("z", rcEnd, "refactor.extract.function", rc) - z, shouldReturn, returnValue, returnValue1, returnValue2 := newFunction(y, x) + z, shouldReturn, i, s, err := newFunction(y, x) if shouldReturn { - return returnValue, returnValue1, returnValue2 + return i, s, err } //@loc(rcEnd, "}") return x, z, nil } @@ -205,9 +205,9 @@ import "go/ast" func _() { ast.Inspect(ast.NewIdent("a"), func(n ast.Node) bool { //@codeaction("if", rflEnd, "refactor.extract.function", rfl) - shouldReturn, returnValue := newFunction(n) + shouldReturn, b := newFunction(n) if shouldReturn { - return returnValue + return b } //@loc(rflEnd, "}") return false }) @@ -272,9 +272,9 @@ package extract func _() string { x := 1 //@codeaction("if", riEnd, "refactor.extract.function", ri) - shouldReturn, returnValue := newFunction(x) + shouldReturn, s := newFunction(x) if shouldReturn { - return returnValue + return s } //@loc(riEnd, "}") x = 2 return "b" diff --git a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt index 65412ee91fa..c032c7797a6 100644 --- a/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt +++ b/gopls/internal/test/marker/testdata/codeaction/functionextraction_issue66289.txt @@ -8,19 +8,19 @@ import ( ) func F() error { - a, err := json.Marshal(0) //@codeaction("a", end, "refactor.extract.function", out) + a, err := json.Marshal(0) //@codeaction("a", endF, "refactor.extract.function", F) if err != nil { return fmt.Errorf("1: %w", err) } b, err := json.Marshal(0) if err != nil { return fmt.Errorf("2: %w", err) - } //@loc(end, "}") + } //@loc(endF, "}") fmt.Println(a, b) return nil } --- @out/a.go -- +-- @F/a.go -- package a import ( @@ -29,11 +29,11 @@ import ( ) func F() error { - //@codeaction("a", end, "refactor.extract.function", out) - a, b, shouldReturn, returnValue := newFunction() + //@codeaction("a", endF, "refactor.extract.function", F) + a, b, shouldReturn, err := newFunction() if shouldReturn { - return returnValue - } //@loc(end, "}") + return err + } //@loc(endF, "}") fmt.Println(a, b) return nil } @@ -50,3 +50,50 @@ func newFunction() ([]byte, []byte, bool, error) { return a, b, false, nil } +-- b.go -- +package a + +import ( + "fmt" + "math/rand" +) + +func G() (x, y int) { + v := rand.Int() //@codeaction("v", endG, "refactor.extract.function", G) + if v < 0 { + return 1, 2 + } + if v > 0 { + return 3, 4 + } //@loc(endG, "}") + fmt.Println(v) + return 5, 6 +} +-- @G/b.go -- +package a + +import ( + "fmt" + "math/rand" +) + +func G() (x, y int) { + //@codeaction("v", endG, "refactor.extract.function", G) + v, shouldReturn, x1, y1 := newFunction() + if shouldReturn { + return x1, y1 + } //@loc(endG, "}") + fmt.Println(v) + return 5, 6 +} + +func newFunction() (int, bool, int, int) { + v := rand.Int() + if v < 0 { + return 0, true, 1, 2 + } + if v > 0 { + return 0, true, 3, 4 + } + return v, false, 0, 0 +}