Skip to content

Commit

Permalink
Switch to tree-sitter for parsing
Browse files Browse the repository at this point in the history
Replaced Go's standard library parsing with tree-sitter to handle type and function definitions, aliases, and imports. This change is the start of handling any programming language instead of just golang
  • Loading branch information
spachava753 committed Oct 7, 2024
1 parent 0ce4def commit c4cc8f4
Showing 1 changed file with 218 additions and 50 deletions.
268 changes: 218 additions & 50 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@ package main

import (
"bufio"
"context"
_ "embed"
"encoding/json"
"flag"
"fmt"
sitter "github.com/smacker/go-tree-sitter"
"github.com/smacker/go-tree-sitter/golang"
"github.com/spachava753/cpe/codemap"
"github.com/spachava753/cpe/extract"
"github.com/spachava753/cpe/fileops"
"github.com/spachava753/cpe/llm"
"go/ast"
"go/parser"
"go/token"
"io"
"io/fs"
"os"
Expand Down Expand Up @@ -67,59 +67,149 @@ func performCodeMapAnalysis(provider llm.LLMProvider, genConfig llm.GenConfig, c
}

func resolveTypeAndFunctionFiles(selectedFiles []string, sourceFS fs.FS) (map[string]bool, error) {
fset := token.NewFileSet()
typeDefinitions := make(map[string]map[string]string) // package.type -> file
functionDefinitions := make(map[string]map[string]string) // package.function -> file
usages := make(map[string]bool)
imports := make(map[string]map[string]string) // file -> alias -> package

parser := sitter.NewParser()
parser.SetLanguage(golang.GetLanguage())

// Queries
typeDefQuery, _ := sitter.NewQuery([]byte(`
(type_declaration
(type_spec
name: (type_identifier) @type.definition))
(type_alias
name: (type_identifier) @type.alias.definition)`), golang.GetLanguage())
funcDefQuery, _ := sitter.NewQuery([]byte(`
(function_declaration
name: (identifier) @function.definition)
(method_declaration
name: (field_identifier) @method.definition)`), golang.GetLanguage())
importQuery, _ := sitter.NewQuery([]byte(`
(import_declaration
(import_spec_list
(import_spec
name: (_)? @import.name
path: (interpreted_string_literal) @import.path)))
(import_declaration
(import_spec
name: (_)? @import.name
path: (interpreted_string_literal) @import.path))`), golang.GetLanguage())
typeUsageQuery, _ := sitter.NewQuery([]byte(`
[
(type_identifier) @type.usage
(qualified_type
package: (package_identifier) @package
name: (type_identifier) @type)
(generic_type
type: [
(type_identifier) @type.usage
(qualified_type
package: (package_identifier) @package
name: (type_identifier) @type)
])
]`), golang.GetLanguage())
funcUsageQuery, _ := sitter.NewQuery([]byte(`
(call_expression
function: [
(identifier) @function.usage
(selector_expression
operand: [
(identifier) @package
(selector_expression)
]?
field: (field_identifier) @method.usage)])`), golang.GetLanguage())

// Parse all files in the source directory and collect type and function definitions
err := fs.WalkDir(sourceFS, ".", func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if !d.IsDir() && strings.HasSuffix(path, ".go") {
file, err := sourceFS.Open(path)
content, err := fs.ReadFile(sourceFS, path)
if err != nil {
return fmt.Errorf("error opening file %s: %w", path, err)
return fmt.Errorf("error reading file %s: %w", path, err)
}
defer file.Close()

astFile, err := parser.ParseFile(fset, path, file, parser.ParseComments)
tree, err := parser.ParseCtx(context.Background(), nil, content)
if err != nil {
return fmt.Errorf("error parsing file %s: %w", path, err)
}

pkgName := astFile.Name.Name
// Extract package name
pkgNameQuery, _ := sitter.NewQuery([]byte(`(package_clause (package_identifier) @package.name)`), golang.GetLanguage())
pkgNameCursor := sitter.NewQueryCursor()
pkgNameCursor.Exec(pkgNameQuery, tree.RootNode())
pkgName := ""
if match, ok := pkgNameCursor.NextMatch(); ok {
for _, capture := range match.Captures {
pkgName = capture.Node.Content(content)
break
}
}

if _, ok := typeDefinitions[pkgName]; !ok {
typeDefinitions[pkgName] = make(map[string]string)
}
if _, ok := functionDefinitions[pkgName]; !ok {
functionDefinitions[pkgName] = make(map[string]string)
}

ast.Inspect(astFile, func(n ast.Node) bool {
switch x := n.(type) {
case *ast.TypeSpec:
typeDefinitions[pkgName][x.Name.Name] = path
case *ast.FuncDecl:
functionDefinitions[pkgName][x.Name.Name] = path
case *ast.ImportSpec:
if imports[path] == nil {
imports[path] = make(map[string]string)
}
importPath := strings.Trim(x.Path.Value, "\"")
alias := ""
if x.Name != nil {
alias = x.Name.Name
} else {
parts := strings.Split(importPath, "/")
alias = parts[len(parts)-1]
// Process type definitions and aliases
typeCursor := sitter.NewQueryCursor()
typeCursor.Exec(typeDefQuery, tree.RootNode())
for {
match, ok := typeCursor.NextMatch()
if !ok {
break
}
for _, capture := range match.Captures {
typeName := capture.Node.Content(content)
typeDefinitions[pkgName][typeName] = path
}
}

// Process function definitions
funcCursor := sitter.NewQueryCursor()
funcCursor.Exec(funcDefQuery, tree.RootNode())
for {
match, ok := funcCursor.NextMatch()
if !ok {
break
}
for _, capture := range match.Captures {
functionDefinitions[pkgName][capture.Node.Content(content)] = path
}
}

// Process imports
importCursor := sitter.NewQueryCursor()
importCursor.Exec(importQuery, tree.RootNode())
for {
match, ok := importCursor.NextMatch()
if !ok {
break
}
var importName, importPath string
for _, capture := range match.Captures {
switch capture.Node.Type() {
case "identifier":
importName = capture.Node.Content(content)
case "interpreted_string_literal":
importPath = strings.Trim(capture.Node.Content(content), "\"")
}
imports[path][alias] = importPath
}
return true
})
if importName == "" {
parts := strings.Split(importPath, "/")
importName = parts[len(parts)-1]
}
if imports[path] == nil {
imports[path] = make(map[string]string)
}
imports[path][importName] = importPath
}
}
return nil
})
Expand All @@ -130,42 +220,120 @@ func resolveTypeAndFunctionFiles(selectedFiles []string, sourceFS fs.FS) (map[st

// Collect type and function usages for selected files
for _, file := range selectedFiles {
f, err := sourceFS.Open(file)
content, err := fs.ReadFile(sourceFS, file)
if err != nil {
return nil, fmt.Errorf("error opening file %s: %w", file, err)
return nil, fmt.Errorf("error reading file %s: %w", file, err)
}
defer f.Close()

astFile, err := parser.ParseFile(fset, file, f, parser.ParseComments)
tree, err := parser.ParseCtx(context.Background(), nil, content)
if err != nil {
return nil, fmt.Errorf("error parsing file %s: %w", file, err)
}

ast.Inspect(astFile, func(n ast.Node) bool {
switch x := n.(type) {
case *ast.SelectorExpr:
if ident, ok := x.X.(*ast.Ident); ok {
if importPath, ok := imports[file][ident.Name]; ok {
parts := strings.Split(importPath, "/")
pkgName := parts[len(parts)-1]
if defFile, ok := typeDefinitions[pkgName][x.Sel.Name]; ok {
usages[defFile] = true
}
if defFile, ok := functionDefinitions[pkgName][x.Sel.Name]; ok {
usages[defFile] = true
// Process type usages
typeUsageCursor := sitter.NewQueryCursor()
typeUsageCursor.Exec(typeUsageQuery, tree.RootNode())
for {
match, ok := typeUsageCursor.NextMatch()
if !ok {
break
}
var packageName, typeName string
for _, capture := range match.Captures {
switch capture.Node.Type() {
case "package_identifier":
packageName = capture.Node.Content(content)
case "type_identifier":
typeName = capture.Node.Content(content)
}
}

// Check if it's a local type
for pkg, types := range typeDefinitions {
if defFile, ok := types[typeName]; ok {
usages[defFile] = true
break
}
if packageName != "" && pkg == packageName {
if defFile, ok := types[typeName]; ok {
usages[defFile] = true
break
}
}
}

// If not found as a local type, it might be an imported type
if packageName != "" {
if importPath, ok := imports[file][packageName]; ok {
// Mark only the specific imported type as used
for pkgName, types := range typeDefinitions {
if strings.HasSuffix(importPath, pkgName) {
if defFile, ok := types[typeName]; ok {
usages[defFile] = true
break
}
}
}
}
case *ast.Ident:
if defFile, ok := typeDefinitions[astFile.Name.Name][x.Name]; ok {
}
}

// Process function usages
funcUsageCursor := sitter.NewQueryCursor()
funcUsageCursor.Exec(funcUsageQuery, tree.RootNode())
for {
match, ok := funcUsageCursor.NextMatch()
if !ok {
break
}
var packageName, funcName string
for _, capture := range match.Captures {
switch capture.Node.Type() {
case "identifier", "package":
packageName = capture.Node.Content(content)
case "field_identifier", "function.usage", "method.usage":
funcName = capture.Node.Content(content)
}
}

// Check if it's a local function
for pkg, funcs := range functionDefinitions {
if defFile, ok := funcs[funcName]; ok {
usages[defFile] = true
break
}
if defFile, ok := functionDefinitions[astFile.Name.Name][x.Name]; ok {
if packageName != "" && pkg == packageName {
if defFile, ok := funcs[funcName]; ok {
usages[defFile] = true
break
}
}
}

// If not found as a local function, it might be an imported function
if packageName != "" {
if importPath, ok := imports[file][packageName]; ok {
// Mark only the specific imported function as used
for pkgName, funcs := range functionDefinitions {
if strings.HasSuffix(importPath, pkgName) {
if defFile, ok := funcs[funcName]; ok {
usages[defFile] = true
break
}
}
}
}
}
}

// Add files containing function definitions used in this file
for _, funcs := range functionDefinitions {
for funcName, defFile := range funcs {
if strings.Contains(string(content), funcName) {
usages[defFile] = true
}
}
return true
})
}

// Add the selected file to the usages
usages[file] = true
Expand Down

0 comments on commit c4cc8f4

Please sign in to comment.