Skip to content

Commit

Permalink
KNOX-3073 - Token verification fallback to Knox keys behavior should …
Browse files Browse the repository at this point in the history
…configurable (#949)
  • Loading branch information
pzampino authored Nov 8, 2024
1 parent dceb495 commit 7dd8b43
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.knox.gateway.hadoopauth.filter;

import static org.apache.knox.gateway.provider.federation.jwt.filter.AbstractJWTFilter.JWT_INSTANCE_KEY_FALLBACK;
import static org.easymock.EasyMock.anyString;
import static org.easymock.EasyMock.capture;
import static org.easymock.EasyMock.captureInt;
Expand Down Expand Up @@ -577,6 +578,7 @@ private HadoopAuthFilter testIfJwtSupported(String supportJwt) throws Exception
expect(filterConfig.getInitParameter("support.jwt")).andReturn(supportJwt).anyTimes();
expect(filterConfig.getInitParameter("hadoop.auth.unauthenticated.path.list")).andReturn(null).anyTimes();
expect(filterConfig.getInitParameter("clusterName")).andReturn("topology1").anyTimes();
expect(filterConfig.getInitParameter(JWT_INSTANCE_KEY_FALLBACK)).andReturn("false").anyTimes();
final boolean isJwtSupported = Boolean.parseBoolean(supportJwt);
if (isJwtSupported) {
expect(filterConfig.getInitParameter(JWTFederationFilter.KNOX_TOKEN_AUDIENCES)).andReturn(null).anyTimes();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.Arrays;
import java.util.Date;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
Expand Down Expand Up @@ -102,6 +103,9 @@ public abstract class AbstractJWTFilter implements Filter {
public static final String JWT_EXPECTED_SIGALG = "jwt.expected.sigalg";
public static final String JWT_DEFAULT_SIGALG = "RS256";

public static final String JWT_INSTANCE_KEY_FALLBACK = "jwt.instance.key.fallback";
public static final boolean JWT_INSTANCE_KEY_FALLBACK_DEFAULT = false;

static JWTMessages log = MessagesFactory.get( JWTMessages.class );

private static AuditService auditService = AuditServiceFactory.getAuditService();
Expand All @@ -116,13 +120,14 @@ public abstract class AbstractJWTFilter implements Filter {
private String expectedIssuer;
private String expectedSigAlg;
protected String expectedPrincipalClaim;
protected Set<URI> expectedJWKSUrls = new HashSet();
protected Set<URI> expectedJWKSUrls = new LinkedHashSet();
protected Set<JOSEObjectType> allowedJwsTypes;

private TokenStateService tokenStateService;
private TokenMAC tokenMAC;
protected long idleTimeoutSeconds = -1;
protected String topologyName;
protected boolean isJwtInstanceKeyFallback = JWT_INSTANCE_KEY_FALLBACK_DEFAULT;

@Override
public abstract void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
Expand Down Expand Up @@ -158,6 +163,9 @@ public void init( FilterConfig filterConfig ) throws ServletException {
// Setup the verified tokens cache
topologyName = context != null ? (String) context.getAttribute(GatewayServices.GATEWAY_CLUSTER_ATTRIBUTE) : null;
signatureVerificationCache = SignatureVerificationCache.getInstance(topologyName, filterConfig);

String fallbackConfig = filterConfig.getInitParameter(JWT_INSTANCE_KEY_FALLBACK);
isJwtInstanceKeyFallback = fallbackConfig != null ? Boolean.parseBoolean(fallbackConfig) : JWT_INSTANCE_KEY_FALLBACK_DEFAULT;
}

protected void configureExpectedParameters(FilterConfig filterConfig) {
Expand Down Expand Up @@ -512,17 +520,22 @@ protected boolean verifyTokenSignature(final JWT token) {
// If it has not yet been verified, then perform the verification now
if (!verified) {
try {
boolean attemptedPEMVerification = false;
boolean attemptedJWKSVerification = false;

if (publicKey != null) {
attemptedPEMVerification = true;
verified = authority.verifyToken(token, publicKey);
log.pemVerificationResultMessage(verified);
}

if (!verified && expectedJWKSUrls != null && !expectedJWKSUrls.isEmpty()) {
attemptedJWKSVerification = true;
verified = authority.verifyToken(token, expectedJWKSUrls, expectedSigAlg, allowedJwsTypes);
log.jwksVerificationResultMessage(verified);
}

if(!verified) {
if(!verified && ((!attemptedPEMVerification && !attemptedJWKSVerification) || isJwtInstanceKeyFallback)) {
verified = authority.verifyToken(token);
log.signingKeyVerificationResultMessage(verified);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@
import java.util.UUID;
import java.util.concurrent.TimeUnit;

import static org.apache.knox.gateway.provider.federation.jwt.filter.AbstractJWTFilter.JWT_INSTANCE_KEY_FALLBACK;
import static org.apache.knox.gateway.provider.federation.jwt.filter.JWTFederationFilter.JWKS_URL;
import static org.junit.Assert.fail;

public abstract class AbstractJWTFilterTest {
Expand Down Expand Up @@ -627,6 +629,9 @@ public void testSignatureVerificationChain() throws Exception {
/* Add a failing PEM */
props.put(getVerificationPemProperty(), failingPem);

/* Turn fallback to signing key on */
props.put(JWT_INSTANCE_KEY_FALLBACK, "true");

/* This handler is setup with a publicKey, corresponding privateKey is used to sign the JWT below */
handler.init(new TestFilterConfig(props));

Expand Down Expand Up @@ -660,14 +665,15 @@ public void testSignatureVerificationChain() throws Exception {
* This will test the signature verification chain.
* Specifically the flow when provided PEM is not invalid and
* knox signing key is valid.
*
* AND JWT_INSTANCE_KEY_FALLBACK is true
* NOTE: here valid means can validate JWT.
* @throws Exception
*/
@Test
public void testSignatureVerificationChainWithPEMandSignature() throws Exception {
try {
Properties props = getProperties();
props.put(JWT_INSTANCE_KEY_FALLBACK, "true");
KeyPairGenerator kpg = KeyPairGenerator.getInstance("RSA");
kpg.initialize(2048);

Expand Down Expand Up @@ -709,6 +715,128 @@ public void testSignatureVerificationChainWithPEMandSignature() throws Exception
}
}

@Test
public void testNoPEMOrJwksWithoutFallback() throws Exception {
// Test fallback disabled, but not PEM configured.
// You can't disable key fallback without specifying an explicit verification method.
boolean verified = doTestSignatureVerificationChain(null, null, false);
Assert.assertTrue("Token should have been verified.", verified);
}

@Test
public void testNoPEMOrJwksWithFallback() throws Exception {
boolean verified = doTestSignatureVerificationChain(null, null, true);
Assert.assertTrue("Token should have been verified by falling back to keys.", verified);
}

@Test
public void testInvalidPEMNoJwksWithFallback() throws Exception {
boolean verified = doTestSignatureVerificationChain(pem, null, true);
Assert.assertTrue("Token should have been verified by falling back to keys.", verified);
}

@Test
public void testInvalidPEMNoJwksWithoutFallback() throws Exception {
String invalidPEM = generateInvalidPEM();
boolean verified = doTestSignatureVerificationChain(invalidPEM, null, false);
Assert.assertFalse("Token should NOT have been verified.", verified);
}

@Test
public void testNoPEMInvalidJwksWithoutFallback() throws Exception {
boolean verified = doTestSignatureVerificationChain(null, "https://localhost/nonesense", false);
Assert.assertFalse("Token should have NOT been verified.", verified);
}

@Test
public void testNoPEMInvalidJwksWithFallback() throws Exception {
boolean verified = doTestSignatureVerificationChain(null, "https://localhost/nonesense", true);
Assert.assertTrue("Token should have been verified by falling back to keys.", verified);
}

@Test
public void testInvalidPEMInvalidJwksWithoutFallback() throws Exception {
String invalidPEM = generateInvalidPEM();
boolean verified = doTestSignatureVerificationChain(invalidPEM, "https://localhost/nonesense", false);
Assert.assertFalse("Token should NOT have been verified.", verified);
}

@Test
public void testInvalidPEMInvalidJwksWithFallback() throws Exception {
String invalidPEM = generateInvalidPEM();
boolean verified = doTestSignatureVerificationChain(invalidPEM, "https://localhost/nonesense", true);
Assert.assertTrue("Token should have been verified by falling back to keys.", verified);
}

protected String generateInvalidPEM() throws Exception {
KeyPairGenerator kpg = KeyPairGenerator.getInstance("RSA");
kpg.initialize(2048);

KeyPair KPair = kpg.generateKeyPair();
String dn = buildDistinguishedName(InetAddress.getLocalHost().getHostName());
Certificate cert = X509CertificateUtil.generateCertificate(dn, KPair, 365, "SHA1withRSA");
byte[] data = cert.getEncoded();
Base64 encoder = new Base64( 76, "\n".getBytes( StandardCharsets.US_ASCII ) );
return new String(encoder.encodeToString( data ).getBytes( StandardCharsets.US_ASCII ), StandardCharsets.US_ASCII).trim();
}

/**
* This will test the signature verification chain in the following order
* 1. PEM - check if PEM is configured and signature is validated
* 2. JWKS - check if endpoint id configured if not skip
* 3. Knox signing key - if the above two fail try to validate using knox signing cert
* @throws Exception
*/
public boolean doTestSignatureVerificationChain(final String testPEM,
final String testJwks,
final boolean fallbackToKeys) throws Exception {
boolean isVerified = false;

try {
Properties props = getProperties();
props.put(getAudienceProperty(), "bar");

if (testPEM != null) {
// Add a test PEM
props.put(getVerificationPemProperty(), testPEM);
}

if (testJwks != null) {
// Add the test JWKS URL
props.put(JWKS_URL, testJwks);
}

// Configure fallback to signing key on
props.put(JWT_INSTANCE_KEY_FALLBACK, String.valueOf(fallbackToKeys));

// This handler is setup with a publicKey, corresponding privateKey is used to sign the JWT below
handler.init(new TestFilterConfig(props));

SignedJWT jwt = getJWT(AbstractJWTFilter.JWT_DEFAULT_ISSUER, "alice",
new Date(new Date().getTime() + TimeUnit.MINUTES.toMillis(10)), privateKey);

HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class);
setTokenOnRequest(request, jwt);

EasyMock.expect(request.getRequestURL()).andReturn(new StringBuffer(SERVICE_URL)).anyTimes();
EasyMock.expect(request.getPathInfo()).andReturn("resource").anyTimes();
EasyMock.expect(request.getQueryString()).andReturn(null);
HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class);
EasyMock.expect(response.encodeRedirectURL(SERVICE_URL)).andReturn(SERVICE_URL);
EasyMock.expect(response.getOutputStream()).andAnswer(DummyServletOutputStream::new).anyTimes();
EasyMock.replay(request, response);

TestFilterChain chain = new TestFilterChain();
handler.doFilter(request, response, chain);
isVerified = chain.doFilterCalled;

} catch (ServletException se) {
fail("Should NOT have thrown a ServletException.");
}

return isVerified;
}

@Test
public void testInvalidIssuer() throws Exception {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,39 @@ public void testAlternativeCaseUsername() throws Exception {
}
}

@Override
@Test
public void testNoPEMInvalidJwksWithoutFallback() throws Exception {
// No-op: This filter does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Override
@Test
public void testNoPEMInvalidJwksWithFallback() throws Exception {
// No-op: This filter does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Override
@Test
public void testInvalidPEMNoJwksWithoutFallback() throws Exception {
// No-op: This filter does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Override
@Test
public void testInvalidPEMInvalidJwksWithoutFallback() throws Exception {
// No-op: This filter does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Override
@Test
public void testInvalidPEMInvalidJwksWithFallback() throws Exception {
// No-op: This filter does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,34 @@ public void testProxiedDefaultAuthenticationProviderURLWithoutMismatchInXForward
Assert.assertEquals(loginURL, "https://remotehost/notgateway/knoxsso/api/v1/websso?originalUrl=" + "https://remotehost/resource");
}

@Override
@Test
public void testNoPEMInvalidJwksWithoutFallback() throws Exception {
// No-op: The SSOCookieProvider does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Override
@Test
public void testNoPEMInvalidJwksWithFallback() throws Exception {
// No-op: The SSOCookieProvider does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Override
@Test
public void testInvalidPEMInvalidJwksWithoutFallback() throws Exception {
// No-op: The SSOCookieProvider does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Override
@Test
public void testInvalidPEMInvalidJwksWithFallback() throws Exception {
// No-op: The SSOCookieProvider does not appear to support the JWKS URL(s) config like the
// JWTFederationFilter does, so this test does not apply
}

@Test
public void testIdleTimoutExceeded() throws Exception {
final TokenStateService tokenStateService = EasyMock.createNiceMock(TokenStateService.class);
Expand Down

0 comments on commit 7dd8b43

Please sign in to comment.