Skip to content

Commit

Permalink
Add support for OAuth 2 audience use
Browse files Browse the repository at this point in the history
  • Loading branch information
andythsu authored Jan 24, 2024
1 parent 3554ac6 commit 0ad2dbf
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ public class OAuthConfiguration
private List<String> scopes;
private String redirectUrl;
private String userIdField;
private List<String> audiences;

public OAuthConfiguration(String issuer, String clientId, String clientSecret, String tokenEndpoint, String authorizationEndpoint, String jwkEndpoint, List<String> scopes, String redirectUrl, String userIdField)
public OAuthConfiguration(String issuer, String clientId, String clientSecret, String tokenEndpoint, String authorizationEndpoint, String jwkEndpoint, List<String> scopes, String redirectUrl, String userIdField, List<String> audiences)
{
this.issuer = issuer;
this.clientId = clientId;
Expand All @@ -38,6 +39,7 @@ public OAuthConfiguration(String issuer, String clientId, String clientSecret, S
this.scopes = scopes;
this.redirectUrl = redirectUrl;
this.userIdField = userIdField;
this.audiences = audiences;
}

public OAuthConfiguration() {}
Expand Down Expand Up @@ -131,4 +133,14 @@ public void setUserIdField(String userIdField)
{
this.userIdField = userIdField;
}

public List<String> getAudiences()
{
return this.audiences;
}

public void setAudiences(List<String> audiences)
{
this.audiences = audiences;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public Optional<Map<String, Claim>> getClaimsFromIdToken(String idToken)
try {
DecodedJWT jwt = JWT.decode(idToken);

if (LbTokenUtil.validateToken(idToken, lbKeyProvider.getRsaPublicKey(), jwt.getIssuer())) {
if (LbTokenUtil.validateToken(idToken, lbKeyProvider.getRsaPublicKey(), jwt.getIssuer(), Optional.empty())) {
return Optional.of(jwt.getClaims());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public Optional<Map<String, Claim>> getClaimsFromIdToken(String idToken)
Jwk jwk = provider.get(jwt.getKeyId());
RSAPublicKey publicKey = (RSAPublicKey) jwk.getPublicKey();

if (LbTokenUtil.validateToken(idToken, publicKey, jwt.getIssuer())) {
if (LbTokenUtil.validateToken(idToken, publicKey, jwt.getIssuer(), Optional.of(oauthConfig.getAudiences()))) {
return Optional.of(jwt.getClaims());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.Verification;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.security.interfaces.RSAPublicKey;
import java.util.List;
import java.util.Optional;

public final class LbTokenUtil
{
Expand All @@ -33,7 +36,7 @@ private LbTokenUtil()
{
}

public static boolean validateToken(String idToken, RSAPublicKey publicKey, String issuer)
public static boolean validateToken(String idToken, RSAPublicKey publicKey, String issuer, Optional<List<String>> audiences)
{
try {
if (log.isDebugEnabled()) {
Expand All @@ -43,11 +46,13 @@ public static boolean validateToken(String idToken, RSAPublicKey publicKey, Stri
}

Algorithm algorithm = Algorithm.RSA256(publicKey, null);
JWT.require(algorithm)
.withIssuer(issuer)
.acceptLeeway(60 * 60) // Expired tokens are valid for an hour
.build()
.verify(idToken);
Verification verification =
JWT.require(algorithm)
.withIssuer(issuer);

audiences.ifPresent(auds -> verification.withAnyOfAudience(auds.toArray(new String[0])));

verification.build().verify(idToken);
}
catch (Exception exc) {
log.error("Could not validate token.", exc);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.gateway.ha.security;

import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.interfaces.DecodedJWT;
import io.trino.gateway.ha.config.SelfSignKeyPairConfiguration;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestInstance.Lifecycle;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;

@TestInstance(Lifecycle.PER_CLASS)
public class TestLbTokenUtil
{
private String idToken;
private final String rsaPrivateKey = "auth/test_private_key.pem";
private final String rsaPublicKey = "auth/test_public_key.pem";
private LbKeyProvider lbKeyProvider;
private DecodedJWT jwt;

@BeforeAll
public void setup()
{
lbKeyProvider = new LbKeyProvider(new SelfSignKeyPairConfiguration(
requireNonNull(getClass().getClassLoader().getResource(rsaPrivateKey)).getFile(),
requireNonNull(getClass().getClassLoader().getResource(rsaPublicKey)).getFile()));
Map<String, Object> headers = java.util.Map.of("alg", "RS256");
Algorithm algorithm = Algorithm.RSA256(lbKeyProvider.getRsaPublicKey(),
lbKeyProvider.getRsaPrivateKey());
idToken = JWT.create()
.withHeader(headers)
.withIssuer(SessionCookie.SELF_ISSUER_ID)
.withSubject("test")
.withAudience("test.com")
.sign(algorithm);
jwt = JWT.decode(idToken);
}

@Test
public void testAudiencesShouldPassIfNoAudiencesAreRequired()
{
assertThat(LbTokenUtil.validateToken(idToken, lbKeyProvider.getRsaPublicKey(), jwt.getIssuer(), Optional.empty())).isTrue();
}

@Test
public void testAudiencesShouldPassIfAnAudienceIsRequired()
{
assertThat(LbTokenUtil.validateToken(idToken, lbKeyProvider.getRsaPublicKey(), jwt.getIssuer(), Optional.of(List.of("test.com")))).isTrue();
}

@Test
public void testAudiencesShouldFailIfAudienceDoesNotMatch()
{
assertThat(LbTokenUtil.validateToken(idToken, lbKeyProvider.getRsaPublicKey(), jwt.getIssuer(), Optional.of(List.of("no_match.com")))).isFalse();
}

@Test
public void testAudiencesShouldPassIfAnyAudienceIsMatched()
{
assertThat(LbTokenUtil.validateToken(idToken, lbKeyProvider.getRsaPublicKey(), jwt.getIssuer(), Optional.of(List.of("test.com", "test1.com")))).isTrue();
}
}

0 comments on commit 0ad2dbf

Please sign in to comment.