diff --git a/context.go b/context.go index 6d0d5edd..34cdd642 100644 --- a/context.go +++ b/context.go @@ -79,7 +79,7 @@ type Context struct { Lookup func(root, path string) (dir string, found bool) // lookup external import evalCallFn func(interp *Interp, call *ssa.Call, res ...interface{}) // internal eval func for repl debugFunc func(*DebugInfo) // debug func - pkgs map[string]*sourcePackage // imports + pkgs map[string]*SourcePackage // imports override map[string]reflect.Value // override function evalInit map[string]bool // eval init check nestedMap map[*types.Named]int // nested named index @@ -107,22 +107,26 @@ func (ctx *Context) lookupPath(path string) (dir string, found bool) { return } -type sourcePackage struct { +type SourcePackage struct { Context *Context Package *types.Package Info *types.Info + Importer types.Importer Files []*ast.File Links []*load.LinkSym Dir string Register bool // register package } -func (sp *sourcePackage) Load() (err error) { +func (sp *SourcePackage) Load() (err error) { if sp.Info == nil { sp.Info = newTypesInfo() + if sp.Importer == nil { + sp.Importer = NewImporter(sp.Context) + } conf := &types.Config{ Sizes: sp.Context.sizes, - Importer: NewImporter(sp.Context), + Importer: sp.Importer, } if sp.Context.evalMode { conf.DisableUnusedImportCheck = true @@ -161,7 +165,7 @@ func NewContext(mode Mode) *Context { Mode: mode, BuilderMode: 0, //ssa.SanityCheckFunctions, BuildContext: build.Default, - pkgs: make(map[string]*sourcePackage), + pkgs: make(map[string]*SourcePackage), override: make(map[string]reflect.Value), nestedMap: make(map[*types.Named]int), callForPool: 64, @@ -245,7 +249,7 @@ func (ctx *Context) LoadDir(dir string, test bool) (pkg *ssa.Package, err error) return nil, err } bp.ImportPath = importPath - var sp *sourcePackage + var sp *SourcePackage if test { sp, err = ctx.loadTestPackage(bp, importPath, dir) } else { @@ -293,7 +297,11 @@ func (ctx *Context) AddImport(path string, dir string) (err error) { return } -func (ctx *Context) addImportFile(path string, filename string, src interface{}) (*sourcePackage, error) { +func (ctx *Context) SourcePackage(path string) *SourcePackage { + return ctx.pkgs[path] +} + +func (ctx *Context) addImportFile(path string, filename string, src interface{}) (*SourcePackage, error) { tp, err := ctx.loadPackageFile(path, filename, src) if err != nil { return nil, err @@ -302,7 +310,7 @@ func (ctx *Context) addImportFile(path string, filename string, src interface{}) return tp, nil } -func (ctx *Context) addImport(path string, dir string) (*sourcePackage, error) { +func (ctx *Context) addImport(path string, dir string) (*SourcePackage, error) { bp, err := ctx.BuildContext.ImportDir(dir, 0) if err != nil { return nil, err @@ -316,13 +324,13 @@ func (ctx *Context) addImport(path string, dir string) (*sourcePackage, error) { return tp, nil } -func (ctx *Context) loadPackageFile(path string, filename string, src interface{}) (*sourcePackage, error) { +func (ctx *Context) loadPackageFile(path string, filename string, src interface{}) (*SourcePackage, error) { file, err := ctx.ParseFile(filename, src) if err != nil { return nil, err } pkg := types.NewPackage(path, file.Name.Name) - tp := &sourcePackage{ + tp := &SourcePackage{ Context: ctx, Package: pkg, Files: []*ast.File{file}, @@ -331,7 +339,7 @@ func (ctx *Context) loadPackageFile(path string, filename string, src interface{ return tp, nil } -func (ctx *Context) loadPackage(bp *build.Package, path string, dir string) (*sourcePackage, error) { +func (ctx *Context) loadPackage(bp *build.Package, path string, dir string) (*SourcePackage, error) { files, err := ctx.parseGoFiles(dir, append(bp.GoFiles, bp.CgoFiles...)) if err != nil { return nil, err @@ -346,7 +354,7 @@ func (ctx *Context) loadPackage(bp *build.Package, path string, dir string) (*so if bp.Name == "main" { path = "main" } - tp := &sourcePackage{ + tp := &SourcePackage{ Package: types.NewPackage(path, bp.Name), Files: files, Dir: dir, @@ -356,7 +364,7 @@ func (ctx *Context) loadPackage(bp *build.Package, path string, dir string) (*so return tp, nil } -func (ctx *Context) loadTestPackage(bp *build.Package, path string, dir string) (*sourcePackage, error) { +func (ctx *Context) loadTestPackage(bp *build.Package, path string, dir string) (*SourcePackage, error) { if len(bp.TestGoFiles) == 0 && len(bp.XTestGoFiles) == 0 { return nil, ErrNoTestFiles } @@ -371,7 +379,7 @@ func (ctx *Context) loadTestPackage(bp *build.Package, path string, dir string) if embed != nil { files = append(files, embed) } - tp := &sourcePackage{ + tp := &SourcePackage{ Package: types.NewPackage(path, bp.Name), Files: files, Dir: dir, @@ -390,7 +398,7 @@ func (ctx *Context) loadTestPackage(bp *build.Package, path string, dir string) if embed != nil { files = append(files, embed) } - tp := &sourcePackage{ + tp := &SourcePackage{ Package: types.NewPackage(path+"_test", bp.Name+"_test"), Files: files, Dir: dir, @@ -406,7 +414,7 @@ func (ctx *Context) loadTestPackage(bp *build.Package, path string, dir string) if err != nil { return nil, err } - return &sourcePackage{ + return &SourcePackage{ Package: types.NewPackage(path+".test", "main"), Files: []*ast.File{f}, Dir: dir, @@ -485,7 +493,7 @@ func (ctx *Context) LoadAstFile(path string, file *ast.File) (*ssa.Package, erro if embed != nil { files = append(files, embed) } - sp := &sourcePackage{ + sp := &SourcePackage{ Context: ctx, Package: types.NewPackage(path, file.Name.Name), Files: files, @@ -507,7 +515,7 @@ func (ctx *Context) LoadAstPackage(path string, apkg *ast.Package) (*ssa.Package files = append([]*ast.File{f}, files...) } } - sp := &sourcePackage{ + sp := &SourcePackage{ Context: ctx, Package: types.NewPackage(path, apkg.Name), Files: files, @@ -668,7 +676,7 @@ func (ctx *Context) RunTest(dir string, args []string) error { return ctx.TestPkg(pkg, dir, args) } -func (ctx *Context) buildPackage(sp *sourcePackage) (pkg *ssa.Package, err error) { +func (ctx *Context) buildPackage(sp *SourcePackage) (pkg *ssa.Package, err error) { if ctx.Mode&DisableRecover == 0 { defer func() { if e := recover(); e != nil { diff --git a/gopbuild/build.go b/gopbuild/build.go index d4ac2a15..97649acb 100644 --- a/gopbuild/build.go +++ b/gopbuild/build.go @@ -187,12 +187,17 @@ func NewContext(ctx *igop.Context) *Context { ctx = igop.NewContext(0) } ctx.Mode |= igop.CheckGopOverloadFunc - return &Context{Context: ctx, Importer: igop.NewImporter(ctx), FileSet: token.NewFileSet(), + c := &Context{Context: ctx, Importer: igop.NewImporter(ctx), FileSet: token.NewFileSet(), Loader: igop.NewTypesLoader(ctx, 0), pkgs: make(map[string]*types.Package)} + return c } -func RegisterPackagePatch(ctx *igop.Context, path string, src string) error { - return ctx.AddImportFile(path+"@patch", "src.go", src) +func RegisterPackagePatch(ctx *igop.Context, path string, src interface{}) error { + err := ctx.AddImportFile(path+"@patch", "src.go", src) + if err != nil { + return err + } + return ctx.AddImportFile(path+"@patch.gop", "src.go", src) } func isGopPackage(path string) bool { @@ -204,27 +209,56 @@ func isGopPackage(path string) bool { return false } -func (c *Context) importPath(path string) (*types.Package, error) { +func (c *Context) importPath(path string) (gop bool, pkg *types.Package, err error) { if isGopPackage(path) { - return c.Loader.Import(path) + gop = true + pkg, err = c.Loader.Import(path) + } else { + pkg, err = c.Importer.Import(path) } - return c.Importer.Import(path) + return } func (c *Context) Import(path string) (*types.Package, error) { if pkg, ok := c.pkgs[path]; ok { return pkg, nil } - pkg, err := c.importPath(path) + gop, pkg, err := c.importPath(path) if err != nil { - return nil, err + return pkg, err } - if patch, err := c.Context.Loader.Import(path + "@patch"); err == nil { - for _, name := range patch.Scope().Names() { - pkg.Scope().Insert(patch.Scope().Lookup(name)) + c.pkgs[path] = pkg + if gop { + if sp := c.Context.SourcePackage(path + "@patch.gop"); sp != nil { + sp.Importer = c + err := sp.Load() + if err != nil { + return nil, err + } + patch := types.NewPackage(path+"@patch", pkg.Name()) + for _, name := range sp.Package.Scope().Names() { + obj := sp.Package.Scope().Lookup(name) + switch obj.(type) { + case *types.Func: + obj = types.NewFunc(obj.Pos(), patch, obj.Name(), obj.Type().(*types.Signature)) + case *types.TypeName: + named := obj.Type().(*types.Named) + var methods []*types.Func + if n := named.NumMethods(); n > 0 { + methods = make([]*types.Func, n) + for i := 0; i < n; i++ { + methods[i] = named.Method(i) + } + } + obj = types.NewTypeName(obj.Pos(), patch, obj.Name(), nil) + types.NewNamed(obj.(*types.TypeName), named.Underlying(), methods) + default: + continue + } + pkg.Scope().Insert(obj) + } } } - c.pkgs[path] = pkg return pkg, nil } diff --git a/gopbuild/build_test.go b/gopbuild/build_test.go index b835ada2..27a692bd 100644 --- a/gopbuild/build_test.go +++ b/gopbuild/build_test.go @@ -433,14 +433,39 @@ func TestPackagePatch(t *testing.T) { RegisterPackagePatch(ctx, "github.com/qiniu/x/gsh", `package gsh import "github.com/qiniu/x/gsh" +type Point struct { + X int + Y int +} + +func (p *Point) Info() { + println(p.X, p.Y) +} + +type Info interface { + Info() +} + +func Dump(i Info) { + i.Info() +} + func Gopt_App_Gopx_GetWidget[T any](app any, name string) { var _ gsh.App println(app, name) } `) + src := ` +getWidget(int,"info") +pt := &Point{100,200} +pt.Info() +println(pt.X) +dump(pt) +` expected := `package main import ( + "fmt" "github.com/qiniu/x/gsh" gsh1 "github.com/qiniu/x/gsh@patch" ) @@ -452,6 +477,14 @@ type App struct { func (this *App) MainEntry() { //line main.gsh:2:1 gsh1.Gopt_App_Gopx_GetWidget[int](this, "info") +//line main.gsh:3:1 + pt := &gsh1.Point{100, 200} +//line main.gsh:4:1 + pt.Info() +//line main.gsh:5:1 + fmt.Println(pt.X) +//line main.gsh:6:1 + gsh1.Dump(pt) } func (this *App) Main() { gsh.Gopt_App_Main(this) @@ -460,13 +493,11 @@ func main() { new(App).Main() } ` - data, err := BuildFile(ctx, "main.gsh", ` -getWidget(int,"info") -`) + data, err := BuildFile(ctx, "main.gsh", src) if err != nil { t.Fatal(err) } if string(data) != expected { - t.Fatal("build error", string(data)) + t.Fatal("build error:\n", string(data)) } }