diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfiguration.java index 9a2fd125c89..55e69940375 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfiguration.java @@ -16,9 +16,14 @@ package org.springframework.security.config.annotation.web.configuration; +import java.io.IOException; import java.util.List; import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; import jakarta.servlet.http.HttpServletRequest; import org.springframework.beans.BeanMetadataElement; @@ -26,7 +31,6 @@ import org.springframework.beans.factory.FactoryBean; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; -import org.springframework.beans.factory.config.RuntimeBeanReference; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; @@ -42,6 +46,8 @@ import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.access.HandlerMappingIntrospectorRequestTransformer; import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer; +import org.springframework.security.web.firewall.HttpFirewall; +import org.springframework.security.web.firewall.RequestRejectedHandler; import org.springframework.security.web.method.annotation.AuthenticationPrincipalArgumentResolver; import org.springframework.security.web.method.annotation.CsrfTokenArgumentResolver; import org.springframework.security.web.method.annotation.CurrentSecurityContextArgumentResolver; @@ -135,11 +141,8 @@ public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) t registry.registerBeanDefinition(HANDLER_MAPPING_INTROSPECTOR_BEAN_NAME + "RequestTransformer", hmiRequestTransformer); - String filterChainProxyBeanName = AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME - + "Proxy"; BeanDefinition filterChainProxy = registry .getBeanDefinition(AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME); - registry.registerBeanDefinition(filterChainProxyBeanName, filterChainProxy); BeanDefinitionBuilder hmiCacheFilterBldr = BeanDefinitionBuilder .rootBeanDefinition(HandlerMappingIntrospectorCachFilterFactoryBean.class) @@ -147,9 +150,9 @@ public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) t ManagedList filters = new ManagedList<>(); filters.add(hmiCacheFilterBldr.getBeanDefinition()); - filters.add(new RuntimeBeanReference(filterChainProxyBeanName)); + filters.add(filterChainProxy); BeanDefinitionBuilder compositeSpringSecurityFilterChainBldr = BeanDefinitionBuilder - .rootBeanDefinition(SpringSecurityFilterCompositeFilter.class) + .rootBeanDefinition(CompositeFilterChainProxy.class) .addConstructorArgValue(filters); registry.removeBeanDefinition(AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME); @@ -188,21 +191,73 @@ public Class getObjectType() { } /** - * Extension to {@link CompositeFilter} to expose private methods used by Spring - * Security's test support + * Extends {@link FilterChainProxy} to provide as much passivity as possible but + * delegates to {@link CompositeFilter} for + * {@link #doFilter(ServletRequest, ServletResponse, FilterChain)}. */ - static class SpringSecurityFilterCompositeFilter extends CompositeFilter { + static class CompositeFilterChainProxy extends FilterChainProxy { - private FilterChainProxy springSecurityFilterChain; + /** + * Used for {@link #doFilter(ServletRequest, ServletResponse, FilterChain)} + */ + private final Filter doFilterDelegate; - SpringSecurityFilterCompositeFilter(List filters) { - setFilters(filters); // for the parent + private final FilterChainProxy springSecurityFilterChain; + + /** + * Creates a new instance + * @param filters the Filters to delegate to. One of which must be + * FilterChainProxy. + */ + CompositeFilterChainProxy(List filters) { + this.doFilterDelegate = createDoFilterDelegate(filters); + this.springSecurityFilterChain = findFilterChainProxy(filters); } @Override - public void setFilters(List filters) { - super.setFilters(filters); - this.springSecurityFilterChain = findFilterChainProxy(filters); + public void afterPropertiesSet() { + this.springSecurityFilterChain.afterPropertiesSet(); + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { + this.doFilterDelegate.doFilter(request, response, chain); + } + + @Override + public List getFilters(String url) { + return this.springSecurityFilterChain.getFilters(url); + } + + @Override + public List getFilterChains() { + return this.springSecurityFilterChain.getFilterChains(); + } + + @Override + public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) { + this.springSecurityFilterChain.setSecurityContextHolderStrategy(securityContextHolderStrategy); + } + + @Override + public void setFilterChainValidator(FilterChainValidator filterChainValidator) { + this.springSecurityFilterChain.setFilterChainValidator(filterChainValidator); + } + + @Override + public void setFilterChainDecorator(FilterChainDecorator filterChainDecorator) { + this.springSecurityFilterChain.setFilterChainDecorator(filterChainDecorator); + } + + @Override + public void setFirewall(HttpFirewall firewall) { + this.springSecurityFilterChain.setFirewall(firewall); + } + + @Override + public void setRequestRejectedHandler(RequestRejectedHandler requestRejectedHandler) { + this.springSecurityFilterChain.setRequestRejectedHandler(requestRejectedHandler); } /** @@ -212,7 +267,7 @@ public void setFilters(List filters) { * @return */ private List getFilters(HttpServletRequest request) { - List filterChains = getFilterChainProxy().getFilterChains(); + List filterChains = this.springSecurityFilterChain.getFilterChains(); for (SecurityFilterChain chain : filterChains) { if (chain.matches(request)) { return chain.getFilters(); @@ -222,11 +277,14 @@ private List getFilters(HttpServletRequest request) { } /** - * Used by Spring Security's Test support to find the FilterChainProxy - * @return + * Creates the Filter to delegate to for doFilter + * @param filters the Filters to delegate to. + * @return the Filter for doFilter */ - private FilterChainProxy getFilterChainProxy() { - return this.springSecurityFilterChain; + private static Filter createDoFilterDelegate(List filters) { + CompositeFilter delegate = new CompositeFilter(); + delegate.setFilters(filters); + return delegate; } /**