Skip to content
This repository has been archived by the owner on Oct 17, 2024. It is now read-only.

Commit

Permalink
Add MethodNotAllowedFilter to block servlets that allow unsecure XSS …
Browse files Browse the repository at this point in the history
…vulnerabilities.
  • Loading branch information
chriseldredge committed Jul 28, 2015
1 parent a55b035 commit 0a984ad
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/main/java/com/fool/servlet/MethodNotAllowedFilter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.fool.servlet;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

public class MethodNotAllowedFilter implements Filter {
private Set<String> allowedMethods = new HashSet<String>();

@Override
public void init(FilterConfig config) throws ServletException {
String allowedMethodsParam = config.getInitParameter("allowedMethods");

if (allowedMethodsParam == null) {
throw new ServletException(getClass().getSimpleName() + " requires init-param 'allowedMethods'.");
}

String[] allowedMethods = allowedMethodsParam.toLowerCase().split("\\s*,\\s*");
this.allowedMethods.addAll(Arrays.asList(allowedMethods));
}

@Override
public void destroy() {
}

@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException {
HttpServletRequest httpRequest = (HttpServletRequest) request;
if (isMethodAllowed(httpRequest.getMethod())) {
filterChain.doFilter(request, response);
} else {
HttpServletResponse httpResponse = (HttpServletResponse) response;
httpResponse.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
}
}

public boolean isMethodAllowed(String method) {
if (method == null) return false;

return allowedMethods.contains(method.toLowerCase());
}

public void addAllowedMethod(String method) {
this.allowedMethods.add(method.toLowerCase());
}
}
76 changes: 76 additions & 0 deletions src/test/java/com/fool/servlet/MethodNotAllowedFilterTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package com.fool.servlet;

import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import junit.framework.TestCase;

import org.mockito.Mockito;
import org.mockito.internal.verification.VerificationModeFactory;

public class MethodNotAllowedFilterTests extends TestCase {
HttpServletRequest request;
HttpServletResponse response;
FilterChain chain;
FilterConfig filterConfig;
MethodNotAllowedFilter filter;

@Override
protected void setUp() throws Exception {
super.setUp();
request = Mockito.mock(HttpServletRequest.class);
response = Mockito.mock(HttpServletResponse.class);
chain = Mockito.mock(FilterChain.class);
filterConfig = Mockito.mock(FilterConfig.class);
filter = new MethodNotAllowedFilter();
}

public void testInitParsesAllowedMethods() throws Exception {
Mockito.when(filterConfig.getInitParameter("allowedMethods")).thenReturn("POST, options");

filter.init(filterConfig);

assertTrue("isMethodAllowed(\"post\")", filter.isMethodAllowed("post"));
assertTrue("isMethodAllowed(\"OPTIONS\")", filter.isMethodAllowed("OPTIONS"));
}

public void testInitThrowsOnMissingParameter() throws Exception {
Mockito.when(filterConfig.getInitParameter("allowedMethods")).thenReturn(null);

try {
filter.init(filterConfig);
fail("Expected ServletException");
} catch (ServletException ex) {
}
}

public void testInvokesChain() throws Exception {
Mockito.when(request.getMethod()).thenReturn("get");
filter.addAllowedMethod("get");

filter.doFilter(request, response, chain);

Mockito.verify(chain).doFilter(request, response);
}

public void testMethodNotAllowedDoesNotInvokeChain() throws Exception {
Mockito.when(request.getMethod()).thenReturn("get");
filter.addAllowedMethod("post");

filter.doFilter(request, response, chain);

Mockito.verify(chain, VerificationModeFactory.times(0)).doFilter(request, response);
}

public void testMethodNotAllowedSetsResponse() throws Exception {
Mockito.when(request.getMethod()).thenReturn("get");
filter.addAllowedMethod("post");

filter.doFilter(request, response, chain);

Mockito.verify(response).sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
}
}

0 comments on commit 0a984ad

Please sign in to comment.