Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Introduce an optional OAuth2WebClient behavior where only a fir… #2117

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ public class OAuth2WebClientOptionsConverter {
public static void fromJson(Iterable<java.util.Map.Entry<String, Object>> json, OAuth2WebClientOptions obj) {
for (java.util.Map.Entry<String, Object> member : json) {
switch (member.getKey()) {
case "failFast":
if (member.getValue() instanceof Boolean) {
obj.setFailFast((Boolean)member.getValue());
}
break;
case "leeway":
if (member.getValue() instanceof Number) {
obj.setLeeway(((Number)member.getValue()).intValue());
Expand All @@ -39,6 +44,7 @@ public static void toJson(OAuth2WebClientOptions obj, JsonObject json) {
}

public static void toJson(OAuth2WebClientOptions obj, java.util.Map<String, Object> json) {
json.put("failFast", obj.getFailFast());
json.put("leeway", obj.getLeeway());
json.put("renewTokenOnForbidden", obj.isRenewTokenOnForbidden());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,14 @@ public class OAuth2WebClientOptions {
*/
public static final int DEFAULT_LEEWAY = 0;

/**
* The default fail fast of requests they happen while refreshing/loading token.
*/
public static final boolean DEFAULT_FAIL_FAST = true;

private boolean renewTokenOnForbidden = DEFAULT_RENEW_TOKEN_ON_FORBIDDEN;
private int leeway = DEFAULT_LEEWAY;
private boolean failFast = DEFAULT_FAIL_FAST;

public OAuth2WebClientOptions() {
}
Expand Down Expand Up @@ -111,4 +117,25 @@ public OAuth2WebClientOptions setLeeway(int leeway) {
this.leeway = leeway;
return this;
}

/**
* Should all requests that happen while this object is refreshing/loading token fail as soon as possible, or queue
* the incoming requests until the in flight operation completes.
*
* @return default value is {@link #DEFAULT_FAIL_FAST}
*/
public boolean getFailFast() {
return failFast;
}

/**
* Set if all requests that happen while this object is refreshing/loading token should fail as soon as possible.
*
* @param failFast requests while refreshing/loading token
* @return fluent self
*/
public OAuth2WebClientOptions setFailFast(boolean failFast) {
this.failFast = failFast;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.Promise;
import io.vertx.ext.auth.User;

import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;

import static io.vertx.core.http.HttpHeaders.AUTHORIZATION;

Expand All @@ -26,9 +28,12 @@ public class OAuth2AwareInterceptor implements Handler<HttpContext<?>> {

private final Set<HttpContext<?>> dejaVu = new HashSet<>();
private final Oauth2WebClientAware parentClient;
private final AtomicReference<Future<User>> pendingAuth = new AtomicReference<>();
private final boolean failFast;

public OAuth2AwareInterceptor(Oauth2WebClientAware webClientOauth2Aware) {
public OAuth2AwareInterceptor(Oauth2WebClientAware webClientOauth2Aware, boolean failFast) {
this.parentClient = webClientOauth2Aware;
this.failFast = failFast;
}

@Override
Expand Down Expand Up @@ -58,9 +63,7 @@ private void processResponse(HttpContext<?> context) {
} else {
// we need some stop condition so we don't go into an infinite loop
dejaVu.add(context);
parentClient
.oauth2Auth()
.authenticate(parentClient.getCredentials())
authenticate()
.onSuccess(userResult -> {
// update the user
parentClient.setUser(userResult);
Expand All @@ -87,19 +90,15 @@ private Future<Void> createRequest(HttpContext<?> context) {
if (parentClient.getUser() != null) {
if (parentClient.getUser().expired(parentClient.getLeeway())) {
//Token has expired we need to invalidate the session
parentClient
.oauth2Auth()
.refresh(parentClient.getUser())
refreshToken()
.onSuccess(userResult -> {
parentClient.setUser(userResult);
context.requestOptions().putHeader(AUTHORIZATION, "Bearer " + userResult.principal().getString("access_token"));
promise.complete();
})
.onFailure(error -> {
// Refresh token failed, we can try standard authentication
parentClient
.oauth2Auth()
.authenticate(parentClient.getCredentials())
authenticate()
.onSuccess(userResult -> {
parentClient.setUser(userResult);
context.requestOptions().putHeader(AUTHORIZATION, "Bearer " + userResult.principal().getString("access_token"));
Expand All @@ -117,9 +116,7 @@ private Future<Void> createRequest(HttpContext<?> context) {
promise.complete();
}
} else {
parentClient
.oauth2Auth()
.authenticate(parentClient.getCredentials())
authenticate()
.onSuccess(userResult -> {
parentClient.setUser(userResult);
context.requestOptions().putHeader(AUTHORIZATION, "Bearer " + userResult.principal().getString("access_token"));
Expand All @@ -133,4 +130,38 @@ private Future<Void> createRequest(HttpContext<?> context) {

return promise.future();
}

private Future<User> authenticate() {
final Future<User> pendingAuthFuture = pendingAuth.get();
if (pendingAuthFuture != null) {
if (failFast) {
return Future.failedFuture("OAuth2 web client authentication in progress and client is configured to fail fast");
} else {
return pendingAuthFuture;
}
}
final Future<User> newAuthFuture = parentClient
.oauth2Auth()
.authenticate(parentClient.getCredentials());
pendingAuth.set(newAuthFuture);
newAuthFuture.onComplete(userAsyncResult -> pendingAuth.set(null));
return newAuthFuture;
}

private Future<User> refreshToken() {
final Future<User> pendingAuthFuture = pendingAuth.get();
if (pendingAuthFuture != null) {
if (failFast) {
return Future.failedFuture("OAuth2 web client token refresh in progress and client is configured to fail fast");
} else {
return pendingAuthFuture;
}
}
final Future<User> newRefreshTokenFuture = parentClient
.oauth2Auth()
.refresh(parentClient.getUser());
pendingAuth.set(newRefreshTokenFuture);
newRefreshTokenFuture.onComplete(userAsyncResult -> pendingAuth.set(null));
return newRefreshTokenFuture;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public Oauth2WebClientAware(WebClient client, OAuth2Auth oauth2Auth, OAuth2WebCl
}
this.option = options;

addInterceptor(new OAuth2AwareInterceptor(this));
addInterceptor(new OAuth2AwareInterceptor(this, options.getFailFast()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package io.vertx.ext.web.client;

import io.vertx.core.Future;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.http.HttpServerRequest;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.auth.oauth2.OAuth2Auth;
import io.vertx.ext.auth.oauth2.OAuth2FlowType;
Expand All @@ -14,6 +16,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import static io.vertx.core.Future.failedFuture;
Expand Down Expand Up @@ -119,6 +122,126 @@ public void testWithAuthentication() throws Exception {
awaitLatch(latchClient);
}

@Test
public void testFastFailWhilePendingAuth() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
final CountDownLatch latchClient = new CountDownLatch(2);

OAuth2Auth oauth2 = OAuth2Auth.create(vertx, new OAuth2Options()
.setFlow(OAuth2FlowType.CLIENT)
.setClientId("client-id")
.setClientSecret("client-secret")
.setSite("http://localhost:8080"));

OAuth2WebClient oauth2WebClient =
OAuth2WebClient.create(WebClientSession.create(webClient), oauth2);

server = vertx.createHttpServer().requestHandler(req -> {
if (req.method() == HttpMethod.POST && "/oauth/token".equals(req.path())) {
assertEquals("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=", req.getHeader("Authorization"));
// Send second auth request while first is still in-flight
oauth2WebClient
.withCredentials(oauthConfig)
.get(8080, "localhost", "/protected/path")
.send(result -> {
if (result.succeeded()) {
fail("Second auth request should fail when fail fast is enabled");
} else {
latchClient.countDown();
// Respond to first request
req.response().putHeader("Content-Type", "application/json").end(fixture.encode());
}
});
} else if (req.method() == HttpMethod.GET && "/protected/path".equals(req.path())) {
assertEquals("Bearer " + fixture.getString("access_token"), req.getHeader("Authorization"));
req.response().end();
} else {
req.response().setStatusCode(400).end();
}
}).listen(8080, ready -> {
if (ready.failed()) {
throw new RuntimeException(ready.cause());
}
// ready
latch.countDown();
});

awaitLatch(latch);

oauth2WebClient
.withCredentials(oauthConfig)
.get(8080, "localhost", "/protected/path")
.send(result -> {
if (result.failed()) {
fail(result.cause());
} else {
assertEquals(200, result.result().statusCode());
latchClient.countDown();
}
});

awaitLatch(latchClient);
}

@Test
public void testOnlyOneAuthReqReachesServerWhenFailFastDisabled() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
final AtomicInteger serverReachedAuthRequestCount = new AtomicInteger();

server = vertx.createHttpServer().requestHandler(req -> {
if (req.method() == HttpMethod.POST && "/oauth/token".equals(req.path())) {
assertEquals("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=", req.getHeader("Authorization"));
req.response().putHeader("Content-Type", "application/json").end(fixture.encode());
serverReachedAuthRequestCount.incrementAndGet();
} else if (req.method() == HttpMethod.GET && "/protected/path".equals(req.path())) {
assertEquals("Bearer " + fixture.getString("access_token"), req.getHeader("Authorization"));
req.response().end();
} else {
req.response().setStatusCode(400).end();
}
}).listen(8080, ready -> {
if (ready.failed()) {
throw new RuntimeException(ready.cause());
}
// ready
latch.countDown();
});

awaitLatch(latch);


OAuth2Auth oauth2 = OAuth2Auth.create(vertx, new OAuth2Options()
.setFlow(OAuth2FlowType.CLIENT)
.setClientId("client-id")
.setClientSecret("client-secret")
.setSite("http://localhost:8080"));

OAuth2WebClient oauth2WebClient =
OAuth2WebClient.create(WebClientSession.create(webClient), oauth2, new OAuth2WebClientOptions().setFailFast(false));

final int requestCount = 100;

final CountDownLatch latchClient = new CountDownLatch(requestCount);

for (int i = 0; i < requestCount; i++) {
oauth2WebClient
.withCredentials(oauthConfig)
.get(8080, "localhost", "/protected/path")
.send(result -> {
if (result.failed()) {
fail(result.cause());
} else {
assertEquals(200, result.result().statusCode());
latchClient.countDown();
}
});
}

awaitLatch(latchClient);

assertEquals(1, serverReachedAuthRequestCount.get());
}

@Test
public void testWithAuthenticationWithoutSession() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
Expand Down