Skip to content

Commit

Permalink
KNOX-3037 - Client credentials flow accepts essential parameters in t…
Browse files Browse the repository at this point in the history
…he request body only
  • Loading branch information
smolnar82 committed May 8, 2024
1 parent 6c26ec6 commit d2c67ea
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,10 @@ public interface JWTMessages {
@Message( level = MessageLevel.WARN, text = "Unable to derive authentication provider URL: {0}" )
void failedToDeriveAuthenticationProviderUrl(@StackTrace( level = MessageLevel.ERROR) Exception e);

@Message( level = MessageLevel.ERROR,
text = "The configuration value ({0}) for maximum token verification cache is invalid; Using the default value." )
@Message( level = MessageLevel.ERROR, text = "The configuration value ({0}) for maximum token verification cache is invalid; Using the default value." )
void invalidVerificationCacheMaxConfiguration(String value);

@Message( level = MessageLevel.ERROR,
text = "Missing token passcode." )
@Message( level = MessageLevel.ERROR, text = "Missing token passcode." )
void missingTokenPasscode();

@Message( level = MessageLevel.INFO, text = "Initialized token signature verification cache for the {0} topology." )
Expand All @@ -114,4 +112,10 @@ public interface JWTMessages {
@Message( level = MessageLevel.INFO, text = "Idle timeout has been configured to {0} seconds in {1}" )
void configuredIdleTimeout(long idleTimeout, String topology);

@Message(level = MessageLevel.WARN, text = "Client secret passed as a query parameter and exposed in the logs.")
void clientSecretExposed();

@Message(level = MessageLevel.ERROR, text = "Error while fetching grant type and client secret from the request: {0}")
void errorFetchingClientSecret(String errorMessage, @StackTrace(level = MessageLevel.DEBUG) Exception e);

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.apache.knox.gateway.util.AuthFilterUtils.DEFAULT_AUTH_UNAUTHENTICATED_PATHS_PARAM;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.util.Base64;
import java.util.HashSet;
Expand Down Expand Up @@ -50,6 +54,7 @@
import org.apache.knox.gateway.util.AuthFilterUtils;
import org.apache.knox.gateway.util.CertificateUtils;
import org.apache.knox.gateway.util.CookieUtils;
import org.apache.knox.gateway.util.RequestUtils;

import com.nimbusds.jose.JOSEObjectType;

Expand Down Expand Up @@ -224,7 +229,7 @@ private String decodeBase64(String toBeDecoded) {
return new String(Base64.getDecoder().decode(toBeDecoded.getBytes(UTF_8)), UTF_8);
}

public Pair<TokenType, String> getWireToken(final ServletRequest request) {
public Pair<TokenType, String> getWireToken(final ServletRequest request) throws IOException {
Pair<TokenType, String> parsed = null;
String token = null;
final String header = ((HttpServletRequest)request).getHeader("Authorization");
Expand Down Expand Up @@ -253,12 +258,9 @@ public Pair<TokenType, String> getWireToken(final ServletRequest request) {
}

return parsed;
}

private Pair<TokenType, String> parseFromClientCredentialsFlow(ServletRequest request) {
Pair<TokenType, String> parsed = null;
String token = null;
}

private Pair<TokenType, String> parseFromClientCredentialsFlow(ServletRequest request) throws IOException {
/*
POST /{tenant}/oauth2/v2.0/token HTTP/1.1
Host: login.microsoftonline.com:443
Expand All @@ -270,15 +272,41 @@ private Pair<TokenType, String> parseFromClientCredentialsFlow(ServletRequest re
&grant_type=client_credentials
*/

String grantType = request.getParameter(GRANT_TYPE);
if (CLIENT_CREDENTIALS.equals(grantType)) {
// this is indeed a client credentials flow client_id and
// client_secret are expected now the client_id will be in
// the token as the token_id so we will get that later
token = request.getParameter(CLIENT_SECRET);
parsed = Pair.of(TokenType.Passcode, token);
if (request.getParameter(CLIENT_SECRET) != null) {
log.clientSecretExposed();
}
return parsed;
return getClientCredentialsFromRequestBody(request);
}

private Pair<TokenType, String> getClientCredentialsFromRequestBody(ServletRequest request) throws IOException {
try {
final String requestBodyString = getRequestBodyString(request);
final String grantType = RequestUtils.getRequestBodyParameter(requestBodyString, GRANT_TYPE);
if (CLIENT_CREDENTIALS.equals(grantType)) {
// this is indeed a client credentials flow client_id and
// client_secret are expected now the client_id will be in
// the token as the token_id so we will get that later
final String clientSecret = RequestUtils.getRequestBodyParameter(requestBodyString, CLIENT_SECRET);
return Pair.of(TokenType.Passcode, clientSecret);
}
} catch (IOException e) {
log.errorFetchingClientSecret(e.getMessage(), e);
throw e;
}
return null;
}

private String getRequestBodyString(ServletRequest request) throws IOException {
if (request.getInputStream() != null) {
final BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(request.getInputStream()));
final StringBuilder requestBodyBuilder = new StringBuilder();
String line;
while ((line = bufferedReader.readLine()) != null) {
requestBodyBuilder.append(line);
}
return URLDecoder.decode(requestBodyBuilder.toString(), StandardCharsets.UTF_8.name());
}
return null;
}

private Pair<TokenType, String> parseFromHTTPBasicCredentials(final String header) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,28 @@
import static org.apache.knox.gateway.provider.federation.jwt.filter.AbstractJWTFilter.JWT_DEFAULT_ISSUER;
import static org.apache.knox.gateway.provider.federation.jwt.filter.SSOCookieFederationFilter.DEFAULT_SSO_COOKIE_NAME;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Date;
import java.util.Properties;

import javax.servlet.FilterConfig;
import javax.servlet.ServletInputStream;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.knox.gateway.provider.federation.jwt.filter.AbstractJWTFilter;
import org.apache.knox.gateway.provider.federation.jwt.filter.JWTFederationFilter;
import org.apache.knox.gateway.provider.federation.jwt.filter.SignatureVerificationCache;
import org.apache.knox.gateway.provider.federation.jwt.filter.JWTFederationFilter.TokenType;
import org.apache.knox.gateway.services.security.token.TokenMetadata;
import org.apache.knox.gateway.services.security.token.TokenStateService;
import org.apache.knox.test.mock.MockServletInputStream;
import org.easymock.EasyMock;
import org.junit.Assert;
import org.junit.Before;
Expand All @@ -43,6 +51,7 @@

@SuppressWarnings("PMD.TestClassWithoutTestCases")
public class JWTFederationFilterTest extends AbstractJWTFilterTest {

@Before
public void setUp() {
handler = new TestJWTFederationFilter();
Expand Down Expand Up @@ -114,6 +123,30 @@ public void testVerifyPasscodeTokensTssDisabled() throws Exception {
testVerifyPasscodeTokens(false);
}

@Test
public void testGetWireTokenUsingClientCredentialsFlow() throws Exception {
final String clientSecret = "sup3r5ecreT!";
final HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class);
EasyMock.expect(request.getInputStream()).andAnswer(() -> produceServletInputStream(clientSecret)).atLeastOnce();
EasyMock.replay(request);

handler.init(new TestFilterConfig(getProperties()));
final Pair<TokenType, String> wireToken = ((TestJWTFederationFilter) handler).getWireToken(request);

EasyMock.verify(request);

assertNotNull(wireToken);
assertEquals(TokenType.Passcode, wireToken.getLeft());
assertEquals(clientSecret, wireToken.getRight());
}

private ServletInputStream produceServletInputStream(String clientSecret) {
final String requestBody = JWTFederationFilter.GRANT_TYPE + "=" + JWTFederationFilter.CLIENT_CREDENTIALS + "&" + JWTFederationFilter.CLIENT_SECRET + "="
+ clientSecret;
final InputStream inputStream = IOUtils.toInputStream(requestBody, StandardCharsets.UTF_8);
return new MockServletInputStream(inputStream);
}

private void testVerifyPasscodeTokens(boolean tssEnabled) throws Exception {
final String topologyName = "jwt-topology";
final String tokenId = "4e0c548b-6568-4061-a3dc-62908087650a";
Expand Down
5 changes: 5 additions & 0 deletions gateway-util-common/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,10 @@
<artifactId>gateway-test-utils</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with this
* work for additional information regarding copyright ownership. The ASF
* licenses this file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package org.apache.knox.gateway.util;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;

import javax.servlet.ServletRequest;

public class RequestUtils {

public static String getRequestBodyParameter(ServletRequest request, String parameter) throws IOException {
return getRequestBodyParameter(request, parameter, false);
}

public static String getRequestBodyParameter(ServletRequest request, String parameter, boolean decode) throws IOException {
return getRequestBodyParameter(request.getInputStream(), parameter, decode);
}

public static String getRequestBodyParameter(InputStream inputStream, String parameter) throws IOException {
return getRequestBodyParameter(inputStream, parameter, false);
}

public static String getRequestBodyParameter(InputStream inputStream, String parameter, boolean decode) throws IOException {
return getRequestBodyParameter(new InputStreamReader(inputStream, StandardCharsets.UTF_8), parameter, decode);
}

public static String getRequestBodyParameter(Reader reader, String parameter) throws IOException {
return getRequestBodyParameter(reader, parameter, false);
}

public static String getRequestBodyParameter(Reader reader, String parameter, boolean decode) throws IOException {
final BufferedReader bufferedReader = new BufferedReader(reader);
final StringBuilder requestBodyBuilder = new StringBuilder();
String line;
while ((line = bufferedReader.readLine()) != null) {
requestBodyBuilder.append(line);
}

final String requestBodyString = decode ? URLDecoder.decode(requestBodyBuilder.toString(), StandardCharsets.UTF_8.name()) : requestBodyBuilder.toString();
return getRequestBodyParameter(requestBodyString, parameter);
}

public static String getRequestBodyParameter(String requestBodyString, String parameter) {
if (requestBodyString != null) {
final String[] requestBodyParams = requestBodyString.split("&");
for (String requestBodyParam : requestBodyParams) {
String[] keyValue = requestBodyParam.split("=", 2);
if (parameter.equals(keyValue[0])) {
return keyValue[1];
}
}
}
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with this
* work for additional information regarding copyright ownership. The ASF
* licenses this file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package org.apache.knox.gateway.util;

import static org.junit.Assert.assertEquals;

import java.io.InputStream;
import java.nio.charset.StandardCharsets;

import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;

import org.apache.commons.io.IOUtils;
import org.apache.knox.test.mock.MockServletInputStream;
import org.easymock.EasyMock;
import org.junit.Test;

public class RequestUtilsTest {

private static final String REQUEST_BODY_PARAM_NAME = "myParam";
private static final String REQUEST_BODY_PARAM_VALUE_RAW = "This-is_my sample text!";
private static final String REQUEST_BODY_PARAM_VALUE_ENCODED = "This-is_my%20sample%20text%21";

@Test
public void testGetRequestBodyParameterEncoded() throws Exception {
testGetRequestBodyParameter(true);
}

@Test
public void testGetRequestBodyParameterRaw() throws Exception {
testGetRequestBodyParameter(false);
}

private void testGetRequestBodyParameter(boolean decode) throws Exception {
final ServletRequest request = EasyMock.createNiceMock(ServletRequest.class);
EasyMock.expect(request.getInputStream()).andReturn(produceServletInputStream(decode)).anyTimes();

EasyMock.replay(request);

final String requestBodyParam = RequestUtils.getRequestBodyParameter(request, REQUEST_BODY_PARAM_NAME, decode);
assertEquals(REQUEST_BODY_PARAM_VALUE_RAW, requestBodyParam);
}

private ServletInputStream produceServletInputStream(boolean encode) {
final String requestBody = REQUEST_BODY_PARAM_NAME + "=" + (encode ? REQUEST_BODY_PARAM_VALUE_ENCODED : REQUEST_BODY_PARAM_VALUE_RAW);
final InputStream inputStream = IOUtils.toInputStream(requestBody, StandardCharsets.UTF_8);
return new MockServletInputStream(inputStream);
}

}

0 comments on commit d2c67ea

Please sign in to comment.