Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auth: Support Azure Entra (Event Hub with Kafka Protocol) #530

Merged
merged 12 commits into from
Sep 18, 2024
Merged
6 changes: 6 additions & 0 deletions api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@
<version>2.1.0</version>
</dependency>

<dependency>
<groupId>com.azure</groupId>
Haarolean marked this conversation as resolved.
Show resolved Hide resolved
<artifactId>azure-identity</artifactId>
<version>1.13.0</version>
</dependency>

<dependency>
<groupId>org.apache.avro</groupId>
<artifactId>avro</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package io.kafbat.ui.sasl.azure.entra;

import static org.apache.kafka.clients.CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG;

import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenCredential;
import com.azure.core.credential.TokenRequestContext;
import com.azure.identity.DefaultAzureCredentialBuilder;
import java.net.URI;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.AppConfigurationEntry;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AzureEntraLoginCallbackHandler implements AuthenticateCallbackHandler {

private static final Logger LOGGER = LoggerFactory.getLogger(AzureEntraLoginCallbackHandler.class);
Haarolean marked this conversation as resolved.
Show resolved Hide resolved

private static final Duration ACCESS_TOKEN_REQUEST_BLOCK_TIME = Duration.ofSeconds(10);

private static final int ACCESS_TOKEN_REQUEST_MAX_RETRIES = 6;

private static final String TOKEN_AUDIENCE_FORMAT = "%s://%s/.default";

static TokenCredential tokenCredential = new DefaultAzureCredentialBuilder().build();

private TokenRequestContext tokenRequestContext;

@Override
public void configure(
Map<String, ?> configs, String mechanism, List<AppConfigurationEntry> jaasConfigEntries) {
tokenRequestContext = buildTokenRequestContext(configs);
}

private TokenRequestContext buildTokenRequestContext(Map<String, ?> configs) {
URI uri = buildEventHubsServerUri(configs);
String tokenAudience = buildTokenAudience(uri);

TokenRequestContext request = new TokenRequestContext();
request.addScopes(tokenAudience);
return request;
}

private URI buildEventHubsServerUri(Map<String, ?> configs) {
final List<String> bootstrapServers = (List<String>) configs.get(BOOTSTRAP_SERVERS_CONFIG);

if (null == bootstrapServers) {
final String message = BOOTSTRAP_SERVERS_CONFIG + " is missing from the Kafka configuration.";
LOGGER.error(message);
throw new IllegalArgumentException(message);
}

if (bootstrapServers.size() != 1) {
final String message =
Haarolean marked this conversation as resolved.
Show resolved Hide resolved
BOOTSTRAP_SERVERS_CONFIG
+ " contains multiple bootstrap servers. Only a single bootstrap server is supported.";
LOGGER.error(message);
throw new IllegalArgumentException(message);
}

return URI.create("https://" + bootstrapServers.get(0));
}

private String buildTokenAudience(URI uri) {
return String.format(TOKEN_AUDIENCE_FORMAT, uri.getScheme(), uri.getHost());
}

@Override
public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
for (Callback callback : callbacks) {
if (callback instanceof OAuthBearerTokenCallback oauthCallback) {
handleOAuthCallback(oauthCallback);
} else {
throw new UnsupportedCallbackException(callback);
}
}
}

private void handleOAuthCallback(OAuthBearerTokenCallback oauthCallback) {
try {
final OAuthBearerToken token = tokenCredential
.getToken(tokenRequestContext)
.map(AzureEntraOAuthBearerTokenImpl::new)
.timeout(ACCESS_TOKEN_REQUEST_BLOCK_TIME)
.doOnError(e -> LOGGER.warn("Failed to acquire Azure token for Event Hub Authentication. Retrying.", e))
.retry(ACCESS_TOKEN_REQUEST_MAX_RETRIES)
.block();

oauthCallback.token(token);
} catch (final RuntimeException e) {
final String message =
"Failed to acquire Azure token for Event Hub Authentication. "
+ "Please ensure valid Azure credentials are configured.";
LOGGER.error(message, e);
oauthCallback.error("invalid_grant", message, null);
}
}

public void close() {
// NOOP
}

void setTokenCredential(final TokenCredential tokenCredential) {
AzureEntraLoginCallbackHandler.tokenCredential = tokenCredential;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package io.kafbat.ui.sasl.azure.entra;
Haarolean marked this conversation as resolved.
Show resolved Hide resolved

import com.azure.core.credential.AccessToken;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.JWTParser;
import java.text.ParseException;
import java.util.Arrays;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;

public class AzureEntraOAuthBearerTokenImpl implements OAuthBearerToken {
Haarolean marked this conversation as resolved.
Show resolved Hide resolved

private final AccessToken accessToken;

private final JWTClaimsSet claims;

public AzureEntraOAuthBearerTokenImpl(AccessToken accessToken) {
this.accessToken = accessToken;

try {
claims = JWTParser.parse(accessToken.getToken()).getJWTClaimsSet();
} catch (ParseException exception) {
throw new SaslAuthenticationException("Unable to parse the access token", exception);
}
}

@Override
public String value() {
return accessToken.getToken();
}

@Override
public Long startTimeMs() {
return claims.getIssueTime().getTime();
}

@Override
public long lifetimeMs() {
return claims.getExpirationTime().getTime();
}

@Override
public Set<String> scope() {
// Referring to
// https://docs.microsoft.com/azure/active-directory/develop/access-tokens#payload-claims, the
// scp
// claim is a String which is presented as a space separated list.
return Optional.ofNullable(claims.getClaim("scp"))
Haarolean marked this conversation as resolved.
Show resolved Hide resolved
.map(s -> Arrays.stream(((String) s).split(" ")).collect(Collectors.toSet()))
.orElse(null);
}

@Override
public String principalName() {
return (String) claims.getClaim("upn");
}

public boolean isExpired() {
return accessToken.isExpired();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package io.kafbat.ui.sasl.azure.entra;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenCredential;
import com.azure.core.credential.TokenRequestContext;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Mono;

@ExtendWith(MockitoExtension.class)
public class AzureEntraLoginCallbackHandlerTest {

// These are not real tokens. It was generated using fake values with an invalid signature,
// so it is safe to store here.
private static final String VALID_SAMPLE_TOKEN =
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6IjlHbW55RlBraGMzaE91UjIybXZTdmduTG83WSIsImtpZCI6IjlHbW55"
+ "RlBraGMzaE91UjIybXZTdmduTG83WSJ9.eyJhdWQiOiJodHRwczovL3NhbXBsZS5zZXJ2aWNlYnVzLndpbmRvd3MubmV0IiwiaX"
+ "NzIjoiaHR0cHM6Ly9zdHMud2luZG93cy5uZXQvc2FtcGxlLyIsImlhdCI6MTY5ODQxNTkxMiwibmJmIjoxNjk4NDE1OTEzLCJleH"
+ "AiOjE2OTg0MTU5MTQsImFjciI6IjEiLCJhaW8iOiJzYW1wbGUtYWlvIiwiYW1yIjpbXSwiYXBwaWQiOiJzYW1wbGUtYXBwLWlkIi"
+ "wiYXBwaWRhY3IiOiIwIiwiZmFtaWx5X25hbWUiOiJTYW1wbGUiLCJnaXZlbl9uYW1lIjoiU2FtcGxlIiwiZ3JvdXBzIjpbXSwiaX"
+ "BhZGRyIjoiMTI3LjAuMC4xIiwibmFtZSI6IlNhbXBsZSBOYW1lIiwib2lkIjoic2FtcGxlLW9pZCIsIm9ucHJlbV9zaWQiOiJzYW"
+ "1wbGUtb25wcmVtX3NpZCIsInB1aWQiOiJzYW1wbGUtcHVpZCIsInJoIjoic2FtcGxlLXJoIiwic2NwIjoiZXZlbnRfaHViIHN0b3"
+ "JhZ2VfYWNjb3VudCIsInN1YiI6IlNhbXBsZSBTdWJqZWN0IiwidGlkIjoic2FtcGxlLXRpZCIsInVuaXF1ZV9uYW1lIjoic2FtcG"
+ "xlQG1pY3Jvc29mdC5jb20iLCJ1cG4iOiJzYW1wbGVAbWljcm9zb2Z0LmNvbSIsInV0aSI6InNhbXBsZS11dGkiLCJ2ZXIiOiIxLj"
+ "AiLCJ3aWRzIjpbXX0.DC_guYOsDlRc5GsXE39dn_zlBX54_Y8_mDTLXLgienl9dPMX5RE2X1QXGXA9ukZtptMzP_0wcoqDDjNrys"
+ "GrNhztyeOr0YSeMMFq2NQ5vMBzLapwONwsnv55Hn0jOje9cqnMf43z1LHI6q6-rIIRz-SiTuoYUgOTxzFftpt-7FSqLjQpYEH7bL"
+ "p-0yIU_aJUSb5HQTJbtYYOb54hsZ6VXpaiZ013qGtKODbHTG37kdoIw2MPn66CxanLZKeZM31IVxC-duAqxDgK4O2Ne6xRZRIPW1"
+ "yt61QnZutWTJ4bAyhmplym3OWZ369cyiSJek0uyS5tibXeCYG4Kk8UQSFcsyfwgOsD0xvvcXcLexcUcEekoNBj6ixDhWssFzhC8T"
+ "Npy8-QKNe_Tp6qHzJdI6OV71jpDkGvcmseLHC9GOxBWB0IdYbePTFK-rz2dkN3uMUiFwQJvEbORsq1IaQXj2esT0F7sMfqzWQF9h"
+ "koVy4mJg_auvrZlnQkNPdLHfCacU33ZPwtuSS6b-0XolbxZ5DlJ4p1OJPeHl2xsi61qiHuCBsmnkLNtHmyxNTXGs7xc4dEQokaCK"
+ "-FB_lzC3D4mkJMxKWopQGXnQtizaZjyclGpiUFs3mEauxC7RpsbanitxPFs7FK3mY0MQJk9JNVi1oM-8qfEp8nYT2DwFBhLcIp2z"
+ "Q";

@Mock
private OAuthBearerTokenCallback oauthBearerTokenCallBack;

@Mock
private OAuthBearerToken oauthBearerToken;

@Mock
private TokenCredential tokenCredential;

@Mock
private AccessToken accessToken;

private AzureEntraLoginCallbackHandler azureEntraLoginCallbackHandler;

@BeforeEach
public void beforeEach() {
azureEntraLoginCallbackHandler = new AzureEntraLoginCallbackHandler();
azureEntraLoginCallbackHandler.setTokenCredential(tokenCredential);
}

@Test
public void shouldProvideTokenToCallbackWithSuccessfulTokenRequest()
throws UnsupportedCallbackException {
final Map<String, Object> configs = new HashMap<>();
configs.put(
"bootstrap.servers",
List.of("test-eh.servicebus.windows.net:9093"));

when(tokenCredential.getToken(any(TokenRequestContext.class))).thenReturn(Mono.just(accessToken));
when(accessToken.getToken()).thenReturn(VALID_SAMPLE_TOKEN);

azureEntraLoginCallbackHandler.configure(configs, null, null);
azureEntraLoginCallbackHandler.handle(new Callback[] {oauthBearerTokenCallBack});

final ArgumentCaptor<TokenRequestContext> contextCaptor =
ArgumentCaptor.forClass(TokenRequestContext.class);
final ArgumentCaptor<OAuthBearerToken> tokenCaptor =
ArgumentCaptor.forClass(OAuthBearerToken.class);

verify(tokenCredential, times(1)).getToken(contextCaptor.capture());
verify(oauthBearerTokenCallBack, times(0)).error(anyString(), anyString(), anyString());
verify(oauthBearerTokenCallBack, times(1)).token(tokenCaptor.capture());

final TokenRequestContext tokenRequestContext = contextCaptor.getValue();
assertThat(tokenRequestContext, is(notNullValue()));
assertThat(
tokenRequestContext.getScopes(),
is(List.of("https://test-eh.servicebus.windows.net/.default")));
assertThat(tokenRequestContext.getClaims(), is(nullValue()));
assertThat(tokenRequestContext.getTenantId(), is(nullValue()));
assertFalse(tokenRequestContext.isCaeEnabled());

assertThat(tokenCaptor.getValue(), is(notNullValue()));
assertEquals(VALID_SAMPLE_TOKEN, tokenCaptor.getValue().value());
}

@Test
public void shouldProvideErrorToCallbackWithTokenError() throws UnsupportedCallbackException {
final Map<String, Object> configs = new HashMap<>();
configs.put(
"bootstrap.servers",
List.of("test-eh.servicebus.windows.net:9093"));

when(tokenCredential.getToken(any(TokenRequestContext.class)))
.thenThrow(new RuntimeException("failed to acquire token"));

azureEntraLoginCallbackHandler.configure(configs, null, null);
azureEntraLoginCallbackHandler.handle(new Callback[] {oauthBearerTokenCallBack});

verify(oauthBearerTokenCallBack, times(1))
.error(
"invalid_grant",
"Failed to acquire Azure token for Event Hub Authentication. "
+ "Please ensure valid Azure credentials are configured.",
null);
verify(oauthBearerTokenCallBack, times(0)).token(any());
}

@Test
public void shouldThrowExceptionWithNullBootstrapServers() {
final Map<String, Object> configs = new HashMap<>();

assertThrows(IllegalArgumentException.class, () -> azureEntraLoginCallbackHandler.configure(
configs, null, null));
}

@Test
public void shouldThrowExceptionWithMultipleBootstrapServers() {
final Map<String, Object> configs = new HashMap<>();
configs.put("bootstrap.servers", List.of("server1", "server2"));

assertThrows(IllegalArgumentException.class, () -> azureEntraLoginCallbackHandler.configure(
configs, null, null));
}

@Test
public void shouldThrowExceptionWithUnsupportedCallback() {
assertThrows(UnsupportedCallbackException.class, () -> azureEntraLoginCallbackHandler.handle(
new Callback[] {mock(Callback.class)}));
}

@Test
public void shouldDoNothingOnClose() {
azureEntraLoginCallbackHandler.close();
}

@Test
public void shouldSupportDefaultConstructor() {
new AzureEntraLoginCallbackHandler();
}
}
Loading
Loading