Skip to content

Commit

Permalink
Refactor to use fs.FS for file system operations
Browse files Browse the repository at this point in the history
Replaced direct os operations with the fs.FS interface to allow greater flexibility and easier mocking in tests. Updated main function and unit tests to accommodate these changes.
  • Loading branch information
spachava753 committed Oct 17, 2024
1 parent 5e70000 commit 4a7aac3
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 11 deletions.
13 changes: 5 additions & 8 deletions codemapanalysis/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"encoding/json"
"fmt"
"github.com/spachava753/cpe/llm"
"os"
"io/fs"
)

//go:embed select_files_for_analysis.json
Expand All @@ -15,7 +15,7 @@ var selectFilesForAnalysisToolDef json.RawMessage
var codeMapAnalysisPrompt string

// PerformAnalysis performs code map analysis and returns selected files
func PerformAnalysis(provider llm.LLMProvider, genConfig llm.GenConfig, codeMapOutput string, userQuery string) ([]string, error) {
func PerformAnalysis(provider llm.LLMProvider, genConfig llm.GenConfig, codeMapOutput string, userQuery string, fsys fs.FS) ([]string, error) {
conversation := llm.Conversation{
SystemPrompt: codeMapAnalysisPrompt,
Messages: []llm.Message{
Expand Down Expand Up @@ -78,7 +78,7 @@ func PerformAnalysis(provider llm.LLMProvider, genConfig llm.GenConfig, codeMapO
}

// Validate selected files
if selectedFilesErr := validateSelectedFiles(result.SelectedFiles); selectedFilesErr != nil {
if selectedFilesErr := validateSelectedFiles(result.SelectedFiles, fsys); selectedFilesErr != nil {
if attempt < maxAttempts {
errorMsg := fmt.Sprintf("Error validating selected files: %v", selectedFilesErr)
conversation.Messages = append(conversation.Messages, llm.Message{
Expand All @@ -105,19 +105,16 @@ func PerformAnalysis(provider llm.LLMProvider, genConfig llm.GenConfig, codeMapO
return nil, fmt.Errorf("no valid files selected for analysis after %d attempts", maxAttempts+1)
}

func validateSelectedFiles(selectedFiles []string) error {
var validFiles []string
func validateSelectedFiles(selectedFiles []string, fsys fs.FS) error {
for _, file := range selectedFiles {
fileInfo, err := os.Stat(file)
fileInfo, err := fs.Stat(fsys, file)
if err != nil {
return fmt.Errorf("error checking file %s: %w", file, err)
}

if fileInfo.IsDir() {
return fmt.Errorf("%s is a directory, expect a file", file)
}

validFiles = append(validFiles, file)
}

return nil
Expand Down
25 changes: 23 additions & 2 deletions codemapanalysis/analyze_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/spachava753/cpe/llm"
"github.com/stretchr/testify/assert"
"testing"
"testing/fstest"
)

// CustomMockLLMProvider is a custom mock implementation of the LLMProvider interface
Expand Down Expand Up @@ -121,8 +122,18 @@ func TestTokenUsage(t *testing.T)
</code_map>`
userQuery := "What testing packages am I using?"

// Create a mock file system
mockFS := fstest.MapFS{
"main.go": &fstest.MapFile{
Data: []byte("package main\n\nimport (\n\t\"github.com/stretchr/testify/assert\"\n\t\"testing\"\n)\n\nfunc TestMain(t *testing.T)"),
},
"llm/types.go": &fstest.MapFile{
Data: []byte("package llm\n\nimport \"testing\"\n\nfunc TestTokenUsage(t *testing.T)"),
},
}

// Call the function under test
selectedFiles, err := PerformAnalysis(mockProvider, genConfig, codeMapOutput, userQuery)
selectedFiles, err := PerformAnalysis(mockProvider, genConfig, codeMapOutput, userQuery, mockFS)

// Assertions
assert.NoError(t, err)
Expand Down Expand Up @@ -205,8 +216,18 @@ func TestTokenUsage(t *testing.T)
</code_map>`
userQuery := "What testing packages am I using?"

// Create a mock file system
mockFS := fstest.MapFS{
"main.go": &fstest.MapFile{
Data: []byte("package main\n\nimport (\n\t\"github.com/stretchr/testify/assert\"\n\t\"testing\"\n)\n\nfunc TestMain(t *testing.T)"),
},
"llm/types.go": &fstest.MapFile{
Data: []byte("package llm\n\nimport \"testing\"\n\nfunc TestTokenUsage(t *testing.T)"),
},
}

// Call the function under test
selectedFiles, err := PerformAnalysis(mockProvider, genConfig, codeMapOutput, userQuery)
selectedFiles, err := PerformAnalysis(mockProvider, genConfig, codeMapOutput, userQuery, mockFS)

// Assertions
assert.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func main() {

// Perform code map analysis and select files
analysisStart := time.Now()
selectedFiles, err := codemapanalysis.PerformAnalysis(provider, genConfig, codeMapOutput, content)
selectedFiles, err := codemapanalysis.PerformAnalysis(provider, genConfig, codeMapOutput, content, os.DirFS("."))
logTimeElapsed(analysisStart, "performCodeMapAnalysis")
if err != nil {
fmt.Printf("Error performing code map analysis: %v\n", err)
Expand Down

0 comments on commit 4a7aac3

Please sign in to comment.