Skip to content

Commit

Permalink
Add tests and refine behaviour of registerSingleton (#8469)
Browse files Browse the repository at this point in the history
There are a number of inconsistencies with `registerSingleton`. Currently the following problems exist:

* Calling `registerSingleton` when the bean type and qualifier match an existing bean type and qualifier doesn't override the bean (it does in Micronaut 3.x but probably shouldn't)
* Calling `registerSingleton` twice consecutively overrides the previously defined bean if the type and qualifier match
* If the registered type and the type of the instance differ (say `Foo` as the bean type but `FooImpl` for the implementation) then `getBeansOfType(Foo.class)` correctly contains the bean and `getBeansOfType(FooImpl.class)` correctly doesn't contain the bean. However the bean can be incorrectly located by `findBean(FooImpl.class)`

This PR tries to address these issues in the following ways:

* Calling `registerSingleton` always adds a new bean and never overrides
* If you want to use `registerSingleton` for bean replacement then you must call `replaces(TypeToReplace)` using the `RuntimeBeanDefinition` API thus formalizing the way to replace a bean.
* Beans that are registered with a particular type (`Foo`) and a particular impl (`FooImpl`) cannot be located by either `getBeansOfType(FooImpl)` or any bean lookup methods like `getBean`, `findBean` etc.
* Also improves handling of runtime beans with generics



Co-authored-by: Sergio del Amo <[email protected]>
  • Loading branch information
graemerocher and sdelamo authored Dec 9, 2022
1 parent 5573f6a commit 629db38
Show file tree
Hide file tree
Showing 14 changed files with 256 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ class ConstructorFactorySpec extends Specification {

void "test injection with constructor supplied by a provider"() {
given:
BeanContext context = new DefaultBeanContext()
context.start()
BeanContext context = BeanContext.run()

when:"A bean is obtained which has a constructor that depends on a bean provided by a provider"
B b = context.getBean(B)
Expand All @@ -44,7 +43,10 @@ class ConstructorFactorySpec extends Specification {
b.a.c != null
b.a.c2 != null
b.a.d != null
b.a.is(context.getBean(AImpl))
b.a.is(context.getBean(A))

cleanup:
context.close()
}

static interface A {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ class FieldArrayFactorySpec extends Specification {

void "test injection with field supplied by a provider"() {
given:
BeanContext context = new DefaultBeanContext()
context.start()
BeanContext context = BeanContext.run()

when:"A bean is obtained which has a field that depends on a bean provided by a provider"
B b = context.getBean(B)
Expand All @@ -42,7 +41,10 @@ class FieldArrayFactorySpec extends Specification {
b.all[0] instanceof AImpl
b.all[0].c != null
b.all[0].c2 != null
b.all[0].is(context.getBean(AImpl))
b.all[0].is(context.getBean(A))

cleanup:
context.close()
}

static interface A {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ class FieldFactorySpec extends Specification {

void "test injection with field supplied by a provider"() {
given:
BeanContext context = new DefaultBeanContext()
context.start()
BeanContext context = BeanContext.run()

when:"A bean is obtained which has a field that depends on a bean provided by a provider"
B b = context.getBean(B)
Expand All @@ -42,7 +41,10 @@ class FieldFactorySpec extends Specification {
b.a instanceof AImpl
b.a.c != null
b.a.c2 != null
b.a.is(context.getBean(AImpl))
b.a.is(context.getBean(A))

cleanup:
context.close()
}

static interface A {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ class ConstructorFactorySpec extends Specification {

void "test injection with constructor supplied by a provider"() {
given:
BeanContext context = new DefaultBeanContext()
context.start()
BeanContext context = BeanContext.run()

when:"A bean is obtained which has a constructor that depends on a bean provided by a provider"
B b = context.getBean(B)
Expand All @@ -35,6 +34,9 @@ class ConstructorFactorySpec extends Specification {
b.a.c != null
b.a.c2 != null
b.a.d != null
b.a.is(context.getBean(AImpl))
b.a.is(context.getBean(A))

cleanup:
context.close()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,77 @@ package io.micronaut.inject.context

import io.micronaut.context.BeanContext
import io.micronaut.context.DefaultBeanContext
import io.micronaut.context.RuntimeBeanDefinition
import io.micronaut.context.annotation.Bean
import io.micronaut.context.annotation.Type
import io.micronaut.core.type.Argument
import io.micronaut.inject.qualifiers.Qualifiers
import jakarta.inject.Named
import jakarta.inject.Singleton
import spock.lang.Issue
import spock.lang.Specification

import java.lang.reflect.Proxy

class RegisterSingletonSpec extends Specification {

void "test register singleton with generic types"() {
given:
BeanContext context = BeanContext.run()

when:
context.registerSingleton(new TestReporter())

then:
context.containsBean(Argument.of(Reporter, Span))

cleanup:
context.close()
}

void "test register singleton and exposed type"() {
given:
BeanContext context = BeanContext.run()

when:
context.registerBeanDefinition(
RuntimeBeanDefinition.builder(Codec, ()-> new OverridingCodec())
.singleton(true)
.qualifier(Qualifiers.byName("foo"))
.replaces(ToBeReplacedCodec)
.build()
) // replaces ToBeReplacedCodec
context.registerSingleton(Codec, { } as Codec) // adds a new codec
context.registerSingleton(Codec, new FooCodec()) // adds another codec
context.registerSingleton(new BarCodec()) // should be registered with bean type BarCodec
context.registerSingleton(Codec, new BazCodec(), Qualifiers.byName("baz"))

then:
def codecs = context.getBeansOfType(Codec)
codecs.size() == 7
codecs.find { it in FooCodec }
codecs.find { it in BarCodec }
codecs.find { it in BazCodec }
!codecs.find { it in ToBeReplacedCodec }
codecs.find { it in OverridingCodec }
codecs.find { it in OtherCodec }
codecs.find { it in StuffCodec }
codecs.find { it in Proxy }
codecs == context.getBeansOfType(Codec) // second resolve returns the same result
context.getBeansOfType(FooCodec).size() == 0 // not an exposed type
context.getBeansOfType(BarCodec).size() == 1 // BarCodec type is exposed
context.findBean(FooCodec).isEmpty() // not an exposed type
context.findBean(StuffCodec).isEmpty() // not an exposed type
context.findBean(OtherCodec).isPresent() // an exposed type

cleanup:
context.close()
}


void "test register singleton method"() {
given:
BeanContext context = new DefaultBeanContext().start()
BeanContext context = BeanContext.run()
def b = new B()

when:
Expand Down Expand Up @@ -83,4 +144,26 @@ class RegisterSingletonSpec extends Specification {
this.type = type
}
}

static interface Codec {

}

static class OverridingCodec implements Codec {}
static class FooCodec implements Codec {}
static class BarCodec implements Codec {}
static class BazCodec implements Codec {}
@Singleton
@Bean(typed = Codec)
static class StuffCodec implements Codec {}
@Singleton
static class OtherCodec implements Codec {}

@Singleton
@Named("foo")
static class ToBeReplacedCodec implements Codec {}

static interface Reporter<B> {}
static class Span {}
static class TestReporter implements Reporter<Span> {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class FieldArrayFactorySpec extends Specification {
b.all[0] instanceof AImpl
((AImpl)b.all[0]).c != null
((AImpl)b.all[0]).c2 != null
b.all[0].is(context.getBean(AImpl))
b.all[0].is(context.getBean(A))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ class FieldFactorySpec extends Specification {

void "test injection with field supplied by a provider"() {
given:
BeanContext context = new DefaultBeanContext()
context.start()
BeanContext context = BeanContext.run()

when:"A bean is obtained which has a field that depends on a bean provided by a provider"
B b = context.getBean(B)
Expand All @@ -36,7 +35,10 @@ class FieldFactorySpec extends Specification {
b.a instanceof AImpl
b.a.c != null
b.a.c2 != null
b.a.is(context.getBean(AImpl))
b.a.is(context.getBean(A))

cleanup:
context.close()
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,13 @@ protected void startEnvironment() {
.qualifier(PrimaryQualifier.INSTANCE);

//noinspection resource
registerBeanDefinition(definition.build());

RuntimeBeanDefinition<? extends Environment> beanDefinition = definition.build();
BeanDefinition<? extends Environment> existing = findBeanDefinition(beanDefinition.getBeanType()).orElse(null);
if (existing instanceof RuntimeBeanDefinition<?> runtimeBeanDefinition) {
removeBeanDefinition(runtimeBeanDefinition);
}
registerBeanDefinition(beanDefinition);
}

@Override
Expand Down
47 changes: 26 additions & 21 deletions inject/src/main/java/io/micronaut/context/DefaultBeanContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ public <T> BeanContext registerSingleton(@NonNull Class<T> type, @NonNull T sing
BeanDefinition<T> beanDefinition;
if (inject && running.get()) {
// Bean cannot be injected before the start of the context
beanDefinition = findBeanDefinition(type, qualifier).orElse(null);
beanDefinition = findConcreteCandidate(null, Argument.of(type), qualifier, false).orElse(null);
if (beanDefinition == null) {
// Purge cache miss
purgeCacheForBeanInstance(singleton);
Expand Down Expand Up @@ -758,14 +758,6 @@ public <T> BeanContext registerSingleton(@NonNull Class<T> type, @NonNull T sing
);
singletonScope.registerSingletonBean(registration, qualifier);
registerBeanDefinition(runtimeBeanDefinition);

for (Class<?> indexedType : indexedTypes) {
if (indexedType == type || indexedType.isAssignableFrom(type)) {
final Collection<BeanDefinitionReference> indexed = resolveTypeIndex(indexedType);
indexed.add(runtimeBeanDefinition);
break;
}
}
}
return this;
}
Expand Down Expand Up @@ -1663,21 +1655,15 @@ public Collection<BeanDefinitionReference<?>> getBeanDefinitionReferences() {
@NonNull
public <B> BeanContext registerBeanDefinition(@NonNull RuntimeBeanDefinition<B> definition) {
Objects.requireNonNull(definition, "Bean definition cannot be null");
BeanDefinition<B> existing = findBeanDefinition(definition.getGenericBeanType(), definition.getDeclaredQualifier()).orElse(null);
if (existing instanceof RuntimeBeanDefinition<B> runtimeBeanDefinition) {
this.beanDefinitionsClasses.remove(runtimeBeanDefinition);
}
Class<B> beanType = definition.getBeanType();
this.beanDefinitionsClasses.add(definition);
for (Class<?> indexedType : indexedTypes) {
if (definition.isCandidateBean(Argument.of(indexedType))) {
Collection<BeanDefinitionReference> index = resolveTypeIndex(indexedType);
if (existing instanceof RuntimeBeanDefinition<B> runtimeBeanDefinition) {
index.remove(runtimeBeanDefinition);
}
index.add(definition);
if (indexedType == beanType || indexedType.isAssignableFrom(beanType)) {
final Collection<BeanDefinitionReference> indexed = resolveTypeIndex(indexedType);
indexed.add(definition);
break;
}
}
this.beanDefinitionsClasses.add(definition);
Class<B> beanType = definition.getBeanType();
purgeCacheForBeanType(beanType);
return this;
}
Expand All @@ -1689,6 +1675,25 @@ private <B> void purgeCacheForBeanType(Class<B> beanType) {
containsBeanCache.entrySet().removeIf(entry -> entry.getKey().beanType.isAssignableFrom(beanType));
}

/**
* The definition to remove.
* @param definition The definition to remove
* @param <B> The bean type
*/
@Internal
<B> void removeBeanDefinition(RuntimeBeanDefinition<B> definition) {
Class<B> beanType = definition.getBeanType();
for (Class<?> indexedType : indexedTypes) {
if (indexedType == beanType || indexedType.isAssignableFrom(beanType)) {
final Collection<BeanDefinitionReference> indexed = resolveTypeIndex(indexedType);
indexed.remove(definition);
break;
}
}
this.beanDefinitionsClasses.remove(definition);
purgeCacheForBeanType(definition.getBeanType());
}

/**
* Get a bean of the given type.
*
Expand Down
Loading

0 comments on commit 629db38

Please sign in to comment.