Skip to content

Commit

Permalink
Consider fallback beans when evaluating ConditionalOnSingleCandidate
Browse files Browse the repository at this point in the history
Closes gh-41580
  • Loading branch information
wilkinsona committed Jul 23, 2024
1 parent 3561ab8 commit 12ec18f
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.Predicate;

import org.springframework.aop.scope.ScopedProxyUtils;
import org.springframework.beans.factory.BeanFactory;
Expand Down Expand Up @@ -113,61 +115,90 @@ private ConditionOutcome getOutcome(Set<String> requiredBeanTypes, Class<? exten

@Override
public ConditionOutcome getMatchOutcome(ConditionContext context, AnnotatedTypeMetadata metadata) {
ConditionMessage matchMessage = ConditionMessage.empty();
ConditionOutcome matchOutcome = ConditionOutcome.match();
MergedAnnotations annotations = metadata.getAnnotations();
if (annotations.isPresent(ConditionalOnBean.class)) {
Spec<ConditionalOnBean> spec = new Spec<>(context, metadata, annotations, ConditionalOnBean.class);
MatchResult matchResult = getMatchingBeans(context, spec);
if (!matchResult.isAllMatched()) {
String reason = createOnBeanNoMatchReason(matchResult);
return ConditionOutcome.noMatch(spec.message().because(reason));
matchOutcome = evaluateConditionalOnBean(spec, matchOutcome.getConditionMessage());
if (!matchOutcome.isMatch()) {
return matchOutcome;
}
matchMessage = spec.message(matchMessage)
.found("bean", "beans")
.items(Style.QUOTE, matchResult.getNamesOfAllMatches());
}
if (metadata.isAnnotated(ConditionalOnSingleCandidate.class.getName())) {
Spec<ConditionalOnSingleCandidate> spec = new SingleCandidateSpec(context, metadata, annotations);
MatchResult matchResult = getMatchingBeans(context, spec);
if (!matchResult.isAllMatched()) {
return ConditionOutcome.noMatch(spec.message().didNotFind("any beans").atAll());
}
Set<String> allBeans = matchResult.getNamesOfAllMatches();
if (allBeans.size() == 1) {
matchMessage = spec.message(matchMessage).found("a single bean").items(Style.QUOTE, allBeans);
}
else {
List<String> primaryBeans = getPrimaryBeans(context.getBeanFactory(), allBeans,
spec.getStrategy() == SearchStrategy.ALL);
if (primaryBeans.isEmpty()) {
return ConditionOutcome
.noMatch(spec.message().didNotFind("a primary bean from beans").items(Style.QUOTE, allBeans));
}
if (primaryBeans.size() > 1) {
return ConditionOutcome
.noMatch(spec.message().found("multiple primary beans").items(Style.QUOTE, primaryBeans));
}
matchMessage = spec.message(matchMessage)
.found("a single primary bean '" + primaryBeans.get(0) + "' from beans")
.items(Style.QUOTE, allBeans);
Spec<ConditionalOnSingleCandidate> spec = new SingleCandidateSpec(context, metadata,
metadata.getAnnotations());
matchOutcome = evaluateConditionalOnSingleCandidate(spec, matchOutcome.getConditionMessage());
if (!matchOutcome.isMatch()) {
return matchOutcome;
}
}
if (metadata.isAnnotated(ConditionalOnMissingBean.class.getName())) {
Spec<ConditionalOnMissingBean> spec = new Spec<>(context, metadata, annotations,
ConditionalOnMissingBean.class);
MatchResult matchResult = getMatchingBeans(context, spec);
if (matchResult.isAnyMatched()) {
String reason = createOnMissingBeanNoMatchReason(matchResult);
return ConditionOutcome.noMatch(spec.message().because(reason));
matchOutcome = evaluateConditionalOnMissingBean(spec, matchOutcome.getConditionMessage());
if (!matchOutcome.isMatch()) {
return matchOutcome;
}
matchMessage = spec.message(matchMessage).didNotFind("any beans").atAll();
}
return ConditionOutcome.match(matchMessage);
return matchOutcome;
}

private ConditionOutcome evaluateConditionalOnBean(Spec<ConditionalOnBean> spec, ConditionMessage matchMessage) {
MatchResult matchResult = getMatchingBeans(spec);
if (!matchResult.isAllMatched()) {
String reason = createOnBeanNoMatchReason(matchResult);
return ConditionOutcome.noMatch(spec.message().because(reason));
}
return ConditionOutcome.match(spec.message(matchMessage)
.found("bean", "beans")
.items(Style.QUOTE, matchResult.getNamesOfAllMatches()));
}

private ConditionOutcome evaluateConditionalOnSingleCandidate(Spec<ConditionalOnSingleCandidate> spec,
ConditionMessage matchMessage) {
MatchResult matchResult = getMatchingBeans(spec);
if (!matchResult.isAllMatched()) {
return ConditionOutcome.noMatch(spec.message().didNotFind("any beans").atAll());
}
Set<String> allBeans = matchResult.getNamesOfAllMatches();
if (allBeans.size() == 1) {
return ConditionOutcome
.match(spec.message(matchMessage).found("a single bean").items(Style.QUOTE, allBeans));
}
Map<String, BeanDefinition> beanDefinitions = getBeanDefinitions(spec.context.getBeanFactory(), allBeans,
spec.getStrategy() == SearchStrategy.ALL);
List<String> primaryBeans = getPrimaryBeans(beanDefinitions);
if (primaryBeans.size() == 1) {
return ConditionOutcome.match(spec.message(matchMessage)
.found("a single primary bean '" + primaryBeans.get(0) + "' from beans")
.items(Style.QUOTE, allBeans));
}
if (primaryBeans.size() > 1) {
return ConditionOutcome
.noMatch(spec.message().found("multiple primary beans").items(Style.QUOTE, primaryBeans));
}
List<String> nonFallbackBeans = getNonFallbackBeans(beanDefinitions);
if (nonFallbackBeans.size() == 1) {
return ConditionOutcome.match(spec.message(matchMessage)
.found("a single non-fallback bean '" + nonFallbackBeans.get(0) + "' from beans")
.items(Style.QUOTE, allBeans));
}
return ConditionOutcome.noMatch(spec.message().found("multiple beans").items(Style.QUOTE, allBeans));
}

protected final MatchResult getMatchingBeans(ConditionContext context, Spec<?> spec) {
ClassLoader classLoader = context.getClassLoader();
ConfigurableListableBeanFactory beanFactory = context.getBeanFactory();
private ConditionOutcome evaluateConditionalOnMissingBean(Spec<ConditionalOnMissingBean> spec,
ConditionMessage matchMessage) {
MatchResult matchResult = getMatchingBeans(spec);
if (matchResult.isAnyMatched()) {
String reason = createOnMissingBeanNoMatchReason(matchResult);
return ConditionOutcome.noMatch(spec.message().because(reason));
}
return ConditionOutcome.match(spec.message(matchMessage).didNotFind("any beans").atAll());
}

protected final MatchResult getMatchingBeans(Spec<?> spec) {
ClassLoader classLoader = spec.getContext().getClassLoader();
ConfigurableListableBeanFactory beanFactory = spec.getContext().getBeanFactory();
boolean considerHierarchy = spec.getStrategy() != SearchStrategy.CURRENT;
Set<Class<?>> parameterizedContainers = spec.getParameterizedContainers();
if (spec.getStrategy() == SearchStrategy.ANCESTORS) {
Expand Down Expand Up @@ -373,16 +404,32 @@ private void appendMessageForMatches(StringBuilder reason, Map<String, Collectio
}
}

private List<String> getPrimaryBeans(ConfigurableListableBeanFactory beanFactory, Set<String> beanNames,
boolean considerHierarchy) {
List<String> primaryBeans = new ArrayList<>();
private Map<String, BeanDefinition> getBeanDefinitions(ConfigurableListableBeanFactory beanFactory,
Set<String> beanNames, boolean considerHierarchy) {
Map<String, BeanDefinition> definitions = new HashMap<>(beanNames.size());
for (String beanName : beanNames) {
BeanDefinition beanDefinition = findBeanDefinition(beanFactory, beanName, considerHierarchy);
if (beanDefinition != null && beanDefinition.isPrimary()) {
primaryBeans.add(beanName);
definitions.put(beanName, beanDefinition);
}
return definitions;
}

private List<String> getPrimaryBeans(Map<String, BeanDefinition> beanDefinitions) {
return getMatchingBeans(beanDefinitions, BeanDefinition::isPrimary);
}

private List<String> getNonFallbackBeans(Map<String, BeanDefinition> beanDefinitions) {
return getMatchingBeans(beanDefinitions, Predicate.not(BeanDefinition::isFallback));
}

private List<String> getMatchingBeans(Map<String, BeanDefinition> beanDefinitions, Predicate<BeanDefinition> test) {
List<String> matches = new ArrayList<>();
for (Entry<String, BeanDefinition> namedBeanDefinition : beanDefinitions.entrySet()) {
if (test.test(namedBeanDefinition.getValue())) {
matches.add(namedBeanDefinition.getKey());
}
}
return primaryBeans;
return matches;
}

private BeanDefinition findBeanDefinition(ConfigurableListableBeanFactory beanFactory, String beanName,
Expand Down Expand Up @@ -420,7 +467,7 @@ private static Set<String> addAll(Set<String> result, String[] additional) {
*/
private static class Spec<A extends Annotation> {

private final ClassLoader classLoader;
private final ConditionContext context;

private final Class<? extends Annotation> annotationType;

Expand All @@ -442,7 +489,7 @@ private static class Spec<A extends Annotation> {
.filter(MergedAnnotationPredicates.unique(MergedAnnotation::getMetaTypes))
.collect(MergedAnnotationCollectors.toMultiValueMap(Adapt.CLASS_TO_STRING));
MergedAnnotation<A> annotation = annotations.get(annotationType);
this.classLoader = context.getClassLoader();
this.context = context;
this.annotationType = annotationType;
this.names = extract(attributes, "name");
this.annotations = extract(attributes, "annotation");
Expand Down Expand Up @@ -497,7 +544,7 @@ private Set<Class<?>> resolveWhenPossible(Set<String> classNames) {
Set<Class<?>> resolved = new LinkedHashSet<>(classNames.size());
for (String className : classNames) {
try {
resolved.add(resolve(className, this.classLoader));
resolved.add(resolve(className, this.context.getClassLoader()));
}
catch (ClassNotFoundException | NoClassDefFoundError ex) {
// Ignore
Expand Down Expand Up @@ -596,31 +643,35 @@ private SearchStrategy getStrategy() {
return (this.strategy != null) ? this.strategy : SearchStrategy.ALL;
}

Set<String> getNames() {
private ConditionContext getContext() {
return this.context;
}

private Set<String> getNames() {
return this.names;
}

Set<String> getTypes() {
protected Set<String> getTypes() {
return this.types;
}

Set<String> getAnnotations() {
private Set<String> getAnnotations() {
return this.annotations;
}

Set<String> getIgnoredTypes() {
private Set<String> getIgnoredTypes() {
return this.ignoredTypes;
}

Set<Class<?>> getParameterizedContainers() {
private Set<Class<?>> getParameterizedContainers() {
return this.parameterizedContainers;
}

ConditionMessage.Builder message() {
private ConditionMessage.Builder message() {
return ConditionMessage.forCondition(this.annotationType, this);
}

ConditionMessage.Builder message(ConditionMessage message) {
private ConditionMessage.Builder message(ConditionMessage message) {
return message.andCondition(this.annotationType, this);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2012-2023 the original author or authors.
* Copyright 2012-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,6 +21,7 @@
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Fallback;
import org.springframework.context.annotation.Primary;
import org.springframework.context.annotation.Scope;
import org.springframework.context.annotation.ScopedProxyMode;
Expand Down Expand Up @@ -114,6 +115,17 @@ void singleCandidateMultipleCandidatesOnePrimary() {
});
}

@Test
void singleCandidateTwoCandidatesOneNormalOneFallback() {
this.contextRunner
.withUserConfiguration(AlphaFallbackConfiguration.class, BravoConfiguration.class,
OnBeanSingleCandidateConfiguration.class)
.run((context) -> {
assertThat(context).hasBean("consumer");
assertThat(context.getBean("consumer")).isEqualTo("bravo");
});
}

@Test
void singleCandidateMultipleCandidatesMultiplePrimary() {
this.contextRunner
Expand All @@ -122,6 +134,14 @@ void singleCandidateMultipleCandidatesMultiplePrimary() {
.run((context) -> assertThat(context).doesNotHaveBean("consumer"));
}

@Test
void singleCandidateMultipleCandidatesAllFallback() {
this.contextRunner
.withUserConfiguration(AlphaFallbackConfiguration.class, BravoFallbackConfiguration.class,
OnBeanSingleCandidateConfiguration.class)
.run((context) -> assertThat(context).doesNotHaveBean("consumer"));
}

@Test
void invalidAnnotationTwoTypes() {
this.contextRunner.withUserConfiguration(OnBeanSingleCandidateTwoTypesConfiguration.class).run((context) -> {
Expand Down Expand Up @@ -208,6 +228,17 @@ String alpha() {

}

@Configuration(proxyBeanMethods = false)
static class AlphaFallbackConfiguration {

@Bean
@Fallback
String alpha() {
return "alpha";
}

}

@Configuration(proxyBeanMethods = false)
static class AlphaScopedProxyConfiguration {

Expand Down Expand Up @@ -240,4 +271,15 @@ String bravo() {

}

@Configuration(proxyBeanMethods = false)
static class BravoFallbackConfiguration {

@Bean
@Fallback
String bravo() {
return "bravo";
}

}

}

0 comments on commit 12ec18f

Please sign in to comment.