Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor RateLimiter #534

Merged
merged 3 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.apache.polaris.service.context.CallContextResolver;
import org.apache.polaris.service.context.RealmContextResolver;
import org.apache.polaris.service.ratelimiter.RateLimiter;
import org.apache.polaris.service.ratelimiter.TokenBucketFactory;
import org.apache.polaris.service.storage.PolarisStorageIntegrationProviderImpl;
import org.apache.polaris.service.types.TokenType;
import org.glassfish.hk2.api.Factory;
Expand Down Expand Up @@ -90,6 +91,7 @@ public class PolarisApplicationConfig extends Configuration {
private String awsSecretKey;
private FileIOFactory fileIOFactory;
private RateLimiter rateLimiter;
private TokenBucketFactory tokenBucketFactory;
private TokenBrokerFactory tokenBrokerFactory;

private AccessToken gcpAccessToken;
Expand Down Expand Up @@ -144,6 +146,9 @@ protected void configure() {
bindFactory(SupplierFactory.create(serviceLocator, config::getRateLimiter))
.to(RateLimiter.class)
.ranked(OVERRIDE_BINDING_RANK);
bindFactory(SupplierFactory.create(serviceLocator, config::getTokenBucketFactory))
.to(TokenBucketFactory.class)
.ranked(OVERRIDE_BINDING_RANK);
}
};
}
Expand Down Expand Up @@ -332,6 +337,17 @@ public void setRateLimiter(@Nullable RateLimiter rateLimiter) {
this.rateLimiter = rateLimiter;
}

@JsonProperty("tokenBucketFactory")
private TokenBucketFactory getTokenBucketFactory() {
return tokenBucketFactory;
}

@JsonProperty("tokenBucketFactory")
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type")
public void setTokenBucketFactory(@Nullable TokenBucketFactory tokenBucketFactory) {
this.tokenBucketFactory = tokenBucketFactory;
}

Comment on lines +340 to +350

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to scope this config down to the RealmTokenBucketRateLimiter itself? Otherwise the tokenBucketFactory can be unused if you're using a different limiter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried but I'm afraid that it won't be achievable with Dropwizard configuration system + HK2 dependency injection.

BTW we already have the same situation: DefaultPolarisAuthenticator requires a TokenBrokerFactory, but
TestInlineBearerTokenPolarisAuthenticator doesn't – if you use the latter, the tokenBroker configuration section would be ignored.

public void setTaskHandler(TaskHandlerConfiguration taskHandler) {
this.taskHandler = taskHandler;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ contract={org.apache.polaris.service.ratelimiter.RateLimiter}
name=no-op
qualifier={io.smallrye.common.annotation.Identifier}

[org.apache.polaris.service.ratelimiter.TokenBucketRateLimiter]S
contract={org.apache.polaris.service.ratelimiter.RateLimiter}
name=token-bucket
qualifier={io.smallrye.common.annotation.Identifier}

[org.apache.polaris.service.ratelimiter.RealmTokenBucketRateLimiter]S
contract={org.apache.polaris.service.ratelimiter.RateLimiter}
name=realm-token-bucket
qualifier={io.smallrye.common.annotation.Identifier}

[org.apache.polaris.service.ratelimiter.DefaultTokenBucketFactory]S
contract={org.apache.polaris.service.ratelimiter.TokenBucketFactory}
name=default
qualifier={io.smallrye.common.annotation.Identifier}

Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,20 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.smallrye.common.annotation.Identifier;
import java.time.Clock;
import java.time.Instant;
import java.time.ZoneOffset;
import org.apache.polaris.service.ratelimiter.RealmTokenBucketRateLimiter;
import org.apache.polaris.service.ratelimiter.DefaultTokenBucketFactory;
import org.threeten.extra.MutableClock;

/** RealmTokenBucketRateLimiter with a mock clock */
@Identifier("mock-realm-token-bucket")
public class MockRealmTokenBucketRateLimiter extends RealmTokenBucketRateLimiter {
/** TokenBucketFactory with a mock clock */
@Identifier("mock")
public class MockTokenBucketFactory extends DefaultTokenBucketFactory {
public static MutableClock CLOCK = MutableClock.of(Instant.now(), ZoneOffset.UTC);

@JsonCreator
public MockRealmTokenBucketRateLimiter(
@JsonProperty("requestsPerSecond") final long requestsPerSecond,
@JsonProperty("windowSeconds") final long windowSeconds) {
super(requestsPerSecond, windowSeconds);
}

@Override
protected Clock getClock() {
return CLOCK;
public MockTokenBucketFactory(
@JsonProperty("requestsPerSecond") long requestsPerSecond,
@JsonProperty("windowSeconds") long windowSeconds) {
super(requestsPerSecond, windowSeconds, CLOCK);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,15 @@ public class RateLimiterFilterTest {
"server.applicationConnectors[0].port",
"0"), // Bind to random port to support parallelism
ConfigOverride.config("server.adminConnectors[0].port", "0"),
ConfigOverride.config("rateLimiter.type", "mock-realm-token-bucket"),
ConfigOverride.config("tokenBucketFactory.type", "mock"),
ConfigOverride.config(
"rateLimiter.requestsPerSecond", String.valueOf(REQUESTS_PER_SECOND)),
ConfigOverride.config("rateLimiter.windowSeconds", String.valueOf(WINDOW_SECONDS)));
"tokenBucketFactory.requestsPerSecond", String.valueOf(REQUESTS_PER_SECOND)),
ConfigOverride.config(
"tokenBucketFactory.windowSeconds", String.valueOf(WINDOW_SECONDS)));

private static String userToken;
private static String realm;
private static MutableClock clock = MockRealmTokenBucketRateLimiter.CLOCK;
private static MutableClock clock = MockTokenBucketFactory.CLOCK;

@BeforeAll
public static void setup(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,41 +18,49 @@
*/
package org.apache.polaris.service.dropwizard.ratelimiter;

import static org.apache.polaris.service.dropwizard.ratelimiter.MockTokenBucketFactory.CLOCK;

import java.time.Duration;
import org.apache.polaris.core.context.CallContext;
import org.apache.polaris.service.ratelimiter.RateLimiter;
import org.apache.polaris.service.ratelimiter.DefaultTokenBucketFactory;
import org.apache.polaris.service.ratelimiter.RealmTokenBucketRateLimiter;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.threeten.extra.MutableClock;

/** Main unit test class for TokenBucketRateLimiter */
public class RealmTokenBucketRateLimiterTest {
@Test
void testDifferentBucketsDontTouch() {
RateLimiter rateLimiter = new MockRealmTokenBucketRateLimiter(10, 10);
RateLimitResultAsserter asserter = new RateLimitResultAsserter(rateLimiter);
MutableClock clock = MockRealmTokenBucketRateLimiter.CLOCK;
RealmTokenBucketRateLimiter rateLimiter = new RealmTokenBucketRateLimiter();
rateLimiter.setTokenBucketFactory(new DefaultTokenBucketFactory(10, 10, CLOCK));

for (int i = 0; i < 202; i++) {
String realm = (i % 2 == 0) ? "realm1" : "realm2";
CallContext.setCurrentContext(CallContext.of(() -> realm, null));

if (i < 200) {
asserter.canAcquire(1);
Assertions.assertTrue(rateLimiter.canProceed());
} else {
asserter.cantAcquire();
assertCannotProceed(rateLimiter);
}
}

clock.add(Duration.ofSeconds(1));
CLOCK.add(Duration.ofSeconds(1));
for (int i = 0; i < 22; i++) {
String realm = (i % 2 == 0) ? "realm1" : "realm2";
CallContext.setCurrentContext(CallContext.of(() -> realm, null));

if (i < 20) {
asserter.canAcquire(1);
Assertions.assertTrue(rateLimiter.canProceed());
} else {
asserter.cantAcquire();
assertCannotProceed(rateLimiter);
}
}
}

private void assertCannotProceed(RealmTokenBucketRateLimiter rateLimiter) {
for (int i = 0; i < 5; i++) {
Assertions.assertFalse(rateLimiter.canProceed());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.polaris.service.ratelimiter.TokenBucketRateLimiter;
import org.apache.polaris.service.ratelimiter.TokenBucket;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.threeten.extra.MutableClock;
Expand All @@ -38,19 +38,18 @@ void testBasic() {
MutableClock clock = MutableClock.of(Instant.now(), ZoneOffset.UTC);
clock.add(Duration.ofSeconds(5));

RateLimitResultAsserter asserter =
new RateLimitResultAsserter(new TokenBucketRateLimiter(10, 100, clock));
TokenBucket tokenBucket = new TokenBucket(10, 100, clock);

asserter.canAcquire(100);
asserter.cantAcquire();
assertCanAcquire(tokenBucket, 100);
assertCannotAcquire(tokenBucket);

clock.add(Duration.ofSeconds(1));
asserter.canAcquire(10);
asserter.cantAcquire();
assertCanAcquire(tokenBucket, 10);
assertCannotAcquire(tokenBucket);

clock.add(Duration.ofSeconds(10));
asserter.canAcquire(100);
asserter.cantAcquire();
assertCanAcquire(tokenBucket, 100);
assertCannotAcquire(tokenBucket);
}

/**
Expand All @@ -63,9 +62,8 @@ void testConcurrent() throws InterruptedException {
int numTasks = 50000;
int tokensPerSecond = 10; // Can be anything above 0

TokenBucketRateLimiter rl =
new TokenBucketRateLimiter(
tokensPerSecond, maxTokens, Clock.fixed(Instant.now(), ZoneOffset.UTC));
TokenBucket rl =
new TokenBucket(tokensPerSecond, maxTokens, Clock.fixed(Instant.now(), ZoneOffset.UTC));
AtomicInteger numAcquired = new AtomicInteger();
CountDownLatch startLatch = new CountDownLatch(numTasks);
CountDownLatch endLatch = new CountDownLatch(numTasks);
Expand Down Expand Up @@ -95,4 +93,16 @@ void testConcurrent() throws InterruptedException {
endLatch.await();
Assertions.assertEquals(maxTokens, numAcquired.get());
}

private void assertCanAcquire(TokenBucket tokenBucket, int times) {
for (int i = 0; i < times; i++) {
Assertions.assertTrue(tokenBucket.tryAcquire());
}
}

private void assertCannotAcquire(TokenBucket tokenBucket) {
for (int i = 0; i < 5; i++) {
Assertions.assertFalse(tokenBucket.tryAcquire());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ contract={org.apache.polaris.service.catalog.io.FileIOFactory}
name=test
qualifier={io.smallrye.common.annotation.Identifier}

[org.apache.polaris.service.dropwizard.ratelimiter.MockRealmTokenBucketRateLimiter]S
contract={org.apache.polaris.service.ratelimiter.RateLimiter}
name=mock-realm-token-bucket
[org.apache.polaris.service.dropwizard.ratelimiter.MockTokenBucketFactory]S
contract={org.apache.polaris.service.ratelimiter.TokenBucketFactory}
name=mock
qualifier={io.smallrye.common.annotation.Identifier}
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,12 @@ logging:
# Limits the size of request bodies sent to Polaris. -1 means no limit.
maxRequestBodyBytes: 1000000

# Limits the request rate per realm
# Limits the request rate per realm.
rateLimiter:
type: realm-token-bucket

# The token bucket factory to use when using the realm-token-bucket rate limiter.
tokenBucketFactory:
type: default
requestsPerSecond: 9999
windowSeconds: 10
8 changes: 8 additions & 0 deletions polaris-server.yml
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,11 @@ maxRequestBodyBytes: -1
# Optional, not specifying a "rateLimiter" section also means no rate limiter
rateLimiter:
type: no-op
# Uncomment to use the realm-token-bucket rate limiter
# type: realm-token-bucket

# The token bucket factory to use when using the realm-token-bucket rate limiter.
tokenBucketFactory:
type: default
requestsPerSecond: 9999
windowSeconds: 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.polaris.service.ratelimiter;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.smallrye.common.annotation.Identifier;
import java.time.Clock;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.polaris.core.context.RealmContext;

@Identifier("default")
public class DefaultTokenBucketFactory implements TokenBucketFactory {

private final long requestsPerSecond;
private final long windowSeconds;
private final Clock clock;
private final Map<String, TokenBucket> perRealmBuckets = new ConcurrentHashMap<>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not new here (so not blocking this PR), but this can be abused to cause an OOM in Polaris if an attacker issues requests with random realm-IDs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


@JsonCreator
public DefaultTokenBucketFactory(
@JsonProperty("requestsPerSecond") long requestsPerSecond,
@JsonProperty("windowSeconds") long windowSeconds) {
this(requestsPerSecond, windowSeconds, Clock.systemUTC());
}

public DefaultTokenBucketFactory(long requestsPerSecond, long windowSeconds, Clock clock) {
this.requestsPerSecond = requestsPerSecond;
this.windowSeconds = windowSeconds;
this.clock = clock;
}

@Override
public TokenBucket getOrCreateTokenBucket(RealmContext realmContext) {
String realmId = realmContext.getRealmIdentifier();
return perRealmBuckets.computeIfAbsent(
realmId,
k ->
new TokenBucket(
requestsPerSecond, Math.multiplyExact(requestsPerSecond, windowSeconds), clock));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
@Identifier("no-op")
public class NoOpRateLimiter implements RateLimiter {
@Override
public boolean tryAcquire() {
public boolean canProceed() {
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ public interface RateLimiter {
*
* @return Whether the request is allowed to proceed by the rate limiter
*/
boolean tryAcquire();
boolean canProceed();
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public RateLimiterFilter(RateLimiter rateLimiter) {
/** Returns a 429 if the rate limiter says so. Otherwise, forwards the request along. */
@Override
public void filter(ContainerRequestContext ctx) throws IOException {
if (!rateLimiter.tryAcquire()) {
if (!rateLimiter.canProceed()) {
ctx.abortWith(Response.status(Response.Status.TOO_MANY_REQUESTS).build());
LOGGER.atDebug().log("Rate limiting request");
}
Expand Down
Loading
Loading