Skip to content

Commit

Permalink
Adds customizable header and adapts to async build
Browse files Browse the repository at this point in the history
  • Loading branch information
spencergibb committed Dec 4, 2024
1 parent 7db16a7 commit d68ac81
Showing 1 changed file with 59 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.function.Supplier;

import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.BucketConfiguration;
Expand All @@ -33,9 +35,15 @@
import org.springframework.cloud.gateway.route.RouteDefinitionRouteLocator;
import org.springframework.cloud.gateway.support.ConfigurationService;
import org.springframework.core.style.ToStringCreator;
import org.springframework.util.Assert;

public class Bucket4jRateLimiter extends AbstractRateLimiter<Bucket4jRateLimiter.Config> {

/**
* Default Header Name.
*/
public static final String DEFAULT_HEADER_NAME = "X-RateLimit-Remaining";

/**
* Redis Rate Limiter property name.
*/
Expand All @@ -56,9 +64,7 @@ public Bucket4jRateLimiter(AsyncProxyManager<String> proxyManager, Configuration
public Mono<Response> isAllowed(String routeId, String id) {
Config routeConfig = loadRouteConfiguration(routeId);

BucketConfiguration bucketConfiguration = getBucketConfiguration(routeConfig);

AsyncBucketProxy bucket = proxyManager.builder().build(id, bucketConfiguration);
AsyncBucketProxy bucket = proxyManager.builder().build(id, routeConfig.getConfigurationSupplier());
CompletableFuture<ConsumptionProbe> bucketFuture = bucket
.tryConsumeAndReturnRemaining(routeConfig.getRequestedTokens());
return Mono.fromFuture(bucketFuture).onErrorResume(throwable -> {
Expand All @@ -78,12 +84,6 @@ public Mono<Response> isAllowed(String routeId, String id) {
});
}

protected static BucketConfiguration getBucketConfiguration(Config routeConfig) {
return BucketConfiguration.builder()
.addLimit(Bandwidth.simple(routeConfig.getCapacity(), routeConfig.getPeriod()))
.build();
}

protected Config loadRouteConfiguration(String routeId) {
Config routeConfig = getConfig().getOrDefault(routeId, defaultConfig);

Expand All @@ -99,19 +99,33 @@ protected Config loadRouteConfiguration(String routeId) {

public Map<String, String> getHeaders(Config config, Long tokensLeft) {
Map<String, String> headers = new HashMap<>();
// TODO: configurable isIncludeHeaders?
// if (isIncludeHeaders()) {
// TODO: configurable headers ala RedisRateLimiter
headers.put("X-RateLimit-Remaining", tokensLeft.toString());
headers.put(config.getHeaderName(), tokensLeft.toString());
// }
return headers;
}

public static class Config {

// TODO: create simple and classic w/Refill
// TODO: create simple and classic w/Refill (see builder)

private static final Function<Config, BucketConfiguration> DEFAULT_CONFIGURATION_BUILDER = config -> BucketConfiguration
.builder()
.addLimit(Bandwidth.builder()
.capacity(config.getCapacity())
.refillGreedy(config.getCapacity(), config.getPeriod())
.build())
.build();

long capacity;

Function<Config, BucketConfiguration> configurationBuilder = DEFAULT_CONFIGURATION_BUILDER;

Supplier<CompletableFuture<BucketConfiguration>> configurationSupplier;

String headerName = DEFAULT_HEADER_NAME;

Duration period;

private long requestedTokens = 1;
Expand All @@ -125,6 +139,37 @@ public Config setCapacity(long capacity) {
return this;
}

public Function<Config, BucketConfiguration> getConfigurationBuilder() {
return configurationBuilder;
}

public void setConfigurationBuilder(Function<Config, BucketConfiguration> configurationBuilder) {
Assert.notNull(configurationBuilder, "configurationBuilder may not be null");
this.configurationBuilder = configurationBuilder;
}

public Supplier<CompletableFuture<BucketConfiguration>> getConfigurationSupplier() {
if (configurationSupplier != null) {
return configurationSupplier;
}
return () -> CompletableFuture.completedFuture(getConfigurationBuilder().apply(this));
}

public void setConfigurationSupplier(Function<Config, BucketConfiguration> configurationBuilder) {
Assert.notNull(configurationBuilder, "configurationBuilder may not be null");
this.configurationBuilder = configurationBuilder;
}

public String getHeaderName() {
return headerName;
}

public Config setHeaderName(String headerName) {
Assert.notNull(headerName, "headerName may not be null");
this.headerName = headerName;
return this;
}

public Duration getPeriod() {
return period;
}
Expand All @@ -145,8 +190,9 @@ public Config setRequestedTokens(long requestedTokens) {

public String toString() {
return new ToStringCreator(this).append("capacity", capacity)
.append("requestedTokens", requestedTokens)
.append("headerName", headerName)
.append("period", period)
.append("requestedTokens", requestedTokens)
.toString();
}

Expand Down

0 comments on commit d68ac81

Please sign in to comment.