Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Extension handling #3

Merged
merged 1 commit into from
Apr 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,16 +1,31 @@
package io.conduktor.kafka.security.oauthbearer.azure;

import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.auth.SaslExtensions;
import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JaasOptionsUtils;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse;
import org.apache.kafka.common.security.oauthbearer.internals.secured.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.AppConfigurationEntry;
import javax.security.sasl.SaslException;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

public class AzureManagedIdentityCallbackHandler extends OAuthBearerLoginCallbackHandler {
public class AzureManagedIdentityCallbackHandler implements AuthenticateCallbackHandler {

private static final Logger log = LoggerFactory.getLogger(AzureManagedIdentityCallbackHandler.class);

public static final String TENANT_ID_CONFIG = "tenantId";
public static final String CLIENT_ID_CONFIG = OAuthBearerLoginCallbackHandler.CLIENT_ID_CONFIG;
Expand All @@ -23,13 +38,111 @@ public class AzureManagedIdentityCallbackHandler extends OAuthBearerLoginCallbac
public static final String CLIENT_CERTIFICATE_PASSWORD_DOC = "The passphrase for certificate";
public static final String SCOPE_DOC = OAuthBearerLoginCallbackHandler.SCOPE_DOC;

private static final String EXTENSION_PREFIX = "extension_";

private Map<String, Object> moduleOptions;

private AccessTokenRetriever accessTokenRetriever;

private AccessTokenValidator accessTokenValidator;

private boolean isInitialized = false;


@Override
public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
this.moduleOptions = JaasOptionsUtils.getOptions(saslMechanism, jaasConfigEntries);
AccessTokenRetriever accessTokenRetriever = AzureIdentityAccessTokenRetriever.create(this.moduleOptions);
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs, saslMechanism);
this.init(accessTokenRetriever, accessTokenValidator);
}

public void init(AccessTokenRetriever accessTokenRetriever, AccessTokenValidator accessTokenValidator) {
this.accessTokenRetriever = accessTokenRetriever;
this.accessTokenValidator = accessTokenValidator;

try {
this.accessTokenRetriever.init();
} catch (IOException var4) {
throw new KafkaException("The OAuth login configuration encountered an error when initializing the AccessTokenRetriever", var4);
}

this.isInitialized = true;
}

@Override
public void close() {
if (accessTokenRetriever != null) {
try {
this.accessTokenRetriever.close();
} catch (IOException e) {
log.warn("The OAuth login configuration encountered an error when closing the AccessTokenRetriever", e);
}
}
}

@Override
public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
checkInitialized();
for (Callback callback : callbacks) {
if (callback instanceof OAuthBearerTokenCallback) {
handleTokenCallback((OAuthBearerTokenCallback) callback);
} else if (callback instanceof SaslExtensionsCallback) {
handleExtensionsCallback((SaslExtensionsCallback) callback);
} else {
throw new UnsupportedCallbackException(callback);
}
}
}

private void handleTokenCallback(OAuthBearerTokenCallback callback) throws IOException {
checkInitialized();
String accessToken = accessTokenRetriever.retrieve();

try {
OAuthBearerToken token = accessTokenValidator.validate(accessToken);
callback.token(token);
} catch (ValidateException e) {
log.warn(e.getMessage(), e);
callback.error("invalid_token", e.getMessage(), null);
}
}

private void handleExtensionsCallback(SaslExtensionsCallback callback) {
checkInitialized();

Map<String, String> extensions = new HashMap<>();

for (Map.Entry<String, Object> configEntry : this.moduleOptions.entrySet()) {
String key = configEntry.getKey();

if (!key.startsWith(EXTENSION_PREFIX))
continue;

Object valueRaw = configEntry.getValue();
String value;

if (valueRaw instanceof String)
value = (String) valueRaw;
else
value = String.valueOf(valueRaw);

extensions.put(key.substring(EXTENSION_PREFIX.length()), value);
}

SaslExtensions saslExtensions = new SaslExtensions(extensions);

try {
OAuthBearerClientInitialResponse.validateExtensions(saslExtensions);
} catch (SaslException e) {
throw new ConfigException(e.getMessage());
}

callback.extensions(saslExtensions);
}

private void checkInitialized() {
if (!isInitialized)
throw new IllegalStateException(String.format("To use %s, first call the configure or init method", getClass().getSimpleName()));
}
}