Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fx.Self: a parameter to fx.As for providing a type as itself #1201

Merged
merged 5 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 61 additions & 10 deletions annotated.go
Original file line number Diff line number Diff line change
Expand Up @@ -1097,7 +1097,19 @@ func OnStop(onStop interface{}) Annotation {

type asAnnotation struct {
targets []interface{}
types []reflect.Type
types []asType
}

type asType struct {
self bool
typ reflect.Type // May be nil if self is true.
}

func (a asType) String() string {
if a.self {
return "self"
}
return a.typ.String()
}

func isOut(t reflect.Type) bool {
Expand All @@ -1119,7 +1131,7 @@ var _ Annotation = (*asAnnotation)(nil)
// bytes.NewBuffer (bytes.Buffer) should be provided as io.Writer type:
//
// fx.Provide(
// fx.Annotate(bytes.NewBuffer(...), fx.As(new(io.Writer)))
// fx.Annotate(bytes.NewBuffer, fx.As(new(io.Writer)))
// )
//
// In other words, the code above is equivalent to:
Expand Down Expand Up @@ -1157,15 +1169,50 @@ func As(interfaces ...interface{}) Annotation {
return &asAnnotation{targets: interfaces}
}

// Self returns a special value that can be passed to [As] to indicate
// that a type should be provided as its original type, in addition to whatever other
// types it gets provided as via other [As] annotations.
//
// For example,
//
// fx.Provide(
// fx.Annotate(
// bytes.NewBuffer,
// fx.As(new(io.Writer)),
// fx.As(fx.Self()),
// )
// )
//
// Is equivalent to,
//
// fx.Provide(
// bytes.NewBuffer,
// func(b *bytes.Buffer) io.Writer {
// return b
// },
// )
//
// in that it provides the same *bytes.Buffer instance
// as both a *bytes.Buffer and an io.Writer.
func Self() any {
return &self{}
}

type self struct{}

func (at *asAnnotation) apply(ann *annotated) error {
at.types = make([]reflect.Type, len(at.targets))
at.types = make([]asType, len(at.targets))
for i, typ := range at.targets {
if _, ok := typ.(*self); ok {
at.types[i] = asType{self: true}
continue
}
t := reflect.TypeOf(typ)
if t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Interface {
return fmt.Errorf("fx.As: argument must be a pointer to an interface: got %v", t)
}
t = t.Elem()
at.types[i] = t
at.types[i] = asType{typ: t}
}

ann.As = append(ann.As, at.types)
Expand Down Expand Up @@ -1209,12 +1256,16 @@ func (at *asAnnotation) results(ann *annotated) (
Type: t,
Tag: f.Tag,
}
if i < len(at.types) {
if !t.Implements(at.types[i]) {
return nil, nil, fmt.Errorf("invalid fx.As: %v does not implement %v", t, at.types[i])
}
field.Type = at.types[i]

if i >= len(at.types) || at.types[i].self {
fields = append(fields, field)
continue
}

if !t.Implements(at.types[i].typ) {
return nil, nil, fmt.Errorf("invalid fx.As: %v does not implement %v", t, at.types[i])
}
field.Type = at.types[i].typ
fields = append(fields, field)
}
resType := reflect.StructOf(fields)
Expand Down Expand Up @@ -1475,7 +1526,7 @@ type annotated struct {
Annotations []Annotation
ParamTags []string
ResultTags []string
As [][]reflect.Type
As [][]asType
From []reflect.Type
FuncPtr uintptr
Hooks []*lifecycleHookAnnotation
Expand Down
100 changes: 100 additions & 0 deletions annotated_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,12 @@ func TestAnnotatedAs(t *testing.T) {

S fmt.Stringer `name:"goodStringer"`
}
type inSelf struct {
fx.In

S1 fmt.Stringer `name:"goodStringer"`
S2 *asStringer `name:"goodStringer"`
}
type myStringer interface {
String() string
}
Expand Down Expand Up @@ -699,6 +705,100 @@ func TestAnnotatedAs(t *testing.T) {
},
startApp: true,
},
{
desc: "self w other As annotations",
provide: fx.Provide(
fx.Annotate(
func() *asStringer {
return &asStringer{name: "stringer"}
},
fx.As(fx.Self()),
fx.As(new(fmt.Stringer)),
),
),
invoke: func(s fmt.Stringer, as *asStringer) {
assert.Equal(t, "stringer", s.String())
assert.Equal(t, "stringer", as.String())
},
},
{
desc: "self as one As target",
provide: fx.Provide(
fx.Annotate(
func() (*asStringer, *bytes.Buffer) {
s := &asStringer{name: "stringer"}
b := &bytes.Buffer{}
return s, b
},
fx.As(fx.Self(), new(io.Writer)),
JacobOaks marked this conversation as resolved.
Show resolved Hide resolved
),
),
invoke: func(s *asStringer, w io.Writer) {
assert.Equal(t, "stringer", s.String())
_, err := w.Write([]byte("."))
assert.NoError(t, err)
},
},
{
desc: "two as, two self, four types",
provide: fx.Provide(
fx.Annotate(
func() (*asStringer, *bytes.Buffer) {
s := &asStringer{name: "stringer"}
b := &bytes.Buffer{}
return s, b
},
fx.As(fx.Self(), new(io.Writer)),
fx.As(new(fmt.Stringer)),
),
),
invoke: func(s1 *asStringer, s2 fmt.Stringer, b *bytes.Buffer, w io.Writer) {
assert.Equal(t, "stringer", s1.String())
assert.Equal(t, "stringer", s2.String())
_, err := w.Write([]byte("."))
assert.NoError(t, err)
_, err = b.Write([]byte("."))
assert.NoError(t, err)
},
},
{
desc: "self with lifecycle hook",
provide: fx.Provide(
fx.Annotate(
func() *asStringer {
return &asStringer{name: "stringer"}
},
fx.As(fx.Self()),
fx.As(new(fmt.Stringer)),
fx.OnStart(func(s fmt.Stringer, as *asStringer) {
assert.Equal(t, "stringer", s.String())
assert.Equal(t, "stringer", as.String())
}),
),
),
invoke: func(s fmt.Stringer, as *asStringer) {
assert.Equal(t, "stringer", s.String())
assert.Equal(t, "stringer", as.String())
},
startApp: true,
},
{
desc: "self with result tags",
provide: fx.Provide(
fx.Annotate(
func() *asStringer {
return &asStringer{name: "stringer"}
},
fx.As(fx.Self()),
fx.As(new(fmt.Stringer)),
fx.ResultTags(`name:"goodStringer"`),
),
),
invoke: func(i inSelf) {
assert.Equal(t, "stringer", i.S1.String())
assert.Equal(t, "stringer", i.S2.String())
},
},
}

for _, tt := range tests {
Expand Down