Skip to content

Collect stats about rate limited requests #181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import dev.aikido.agent_api.ratelimiting.ShouldRateLimit;
import dev.aikido.agent_api.storage.ServiceConfigStore;
import dev.aikido.agent_api.storage.ServiceConfiguration;
import dev.aikido.agent_api.storage.routes.RoutesStore;
import dev.aikido.agent_api.storage.statistics.StatisticsStore;

public final class ShouldBlockRequest {
private ShouldBlockRequest() {
Expand Down Expand Up @@ -34,6 +36,14 @@ public static ShouldBlockRequestResult shouldBlockRequest() {
context.getRouteMetadata(), context.getUser(), context.getRemoteAddress()
);
if (rateLimitDecision.block()) {
// increment rate-limiting stats both globally and on the route :
StatisticsStore.incrementRateLimited();
// increment routes stats using method & route from the endpoint (store stats for wildcards, in wildcard route)
RoutesStore.addRouteRateLimitedCount(
rateLimitDecision.rateLimitedEndpoint().getMethod(),
rateLimitDecision.rateLimitedEndpoint().getRoute()
);

BlockedRequestResult blockedRequestResult = new BlockedRequestResult(
"ratelimited", rateLimitDecision.trigger(), context.getRemoteAddress()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ public static void report(int statusCode) {
return;
}

RoutesStore.addRouteHits(routeMetadata);
RoutesStore.addRouteHits(context.getMethod(), context.getRoute());

// check if we need to generate api spec
int hits = RoutesStore.getRouteHits(routeMetadata);
int hits = RoutesStore.getRouteHits(context.getMethod(), context.getRoute());
if (hits <= ANALYSIS_ON_FIRST_X_REQUESTS) {
APISpec apiSpec = getApiInfo(context);
RoutesStore.updateApiSpec(routeMetadata, apiSpec);
RoutesStore.updateApiSpec(context.getMethod(), context.getRoute(), apiSpec);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,20 @@

public final class ShouldRateLimit {
private ShouldRateLimit() {}
public record RateLimitDecision(boolean block, String trigger) {}

public record RateLimitDecision(
boolean block,
String trigger,
Endpoint rateLimitedEndpoint
) {
}

public static RateLimitDecision shouldRateLimit(RouteMetadata routeMetadata, User user, String remoteAddress) {
List<Endpoint> endpoints = ServiceConfigStore.getConfig().getEndpoints();
List<Endpoint> matches = matchEndpoints(routeMetadata, endpoints);
Endpoint rateLimitedEndpoint = getRateLimitedEndpoint(matches, routeMetadata.route());
if (rateLimitedEndpoint == null) {
return new RateLimitDecision(/*block*/false, null);
return new RateLimitDecision(/*block*/false, null, null);
}

long windowSizeInMS = rateLimitedEndpoint.getRateLimiting().windowSizeInMS();
Expand All @@ -29,17 +36,17 @@ public static RateLimitDecision shouldRateLimit(RouteMetadata routeMetadata, Use
boolean allowed = RateLimiterStore.isAllowed(key, windowSizeInMS, maxRequests);
if (allowed) {
// Do not continue to check based on IP if user is present:
return new RateLimitDecision(/*block*/false, null);
return new RateLimitDecision(/*block*/false, null, null);
}
return new RateLimitDecision(/*block*/ true, /*trigger*/ "user");
return new RateLimitDecision(/*block*/ true, /*trigger*/ "user", rateLimitedEndpoint);
}
if (remoteAddress != null && !remoteAddress.isEmpty()) {
String key = rateLimitedEndpoint.getMethod() + ":" + rateLimitedEndpoint.getRoute() + ":ip:" + remoteAddress;
boolean allowed = RateLimiterStore.isAllowed(key, windowSizeInMS, maxRequests);
if (!allowed) {
return new RateLimitDecision(/*block*/ true, /*trigger*/ "ip");
return new RateLimitDecision(/*block*/ true, /*trigger*/ "ip", rateLimitedEndpoint);
}
}
return new RateLimitDecision(/*block*/false, null);
return new RateLimitDecision(/*block*/false, null, null);
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
package dev.aikido.agent_api.storage.routes;

import com.google.gson.*;
import dev.aikido.agent_api.api_discovery.APISpec;
import dev.aikido.agent_api.context.RouteMetadata;

import java.lang.reflect.Type;

import static dev.aikido.agent_api.api_discovery.APISpecMerger.mergeAPISpecs;

public class RouteEntry {
final String method;
final String path;
private int hits;
private int rateLimitedCount;
private APISpec apispec;

public RouteEntry(String method, String path) {
Expand All @@ -32,6 +30,13 @@ public int getHits() {
return hits;
}

public void incrementRateLimitCount() {
rateLimitedCount++;
}

public int getRateLimitCount() {
return rateLimitedCount;
}
public void updateApiSpec(APISpec newApiSpec) {
this.apispec = mergeAPISpecs(newApiSpec, this.apispec);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package dev.aikido.agent_api.storage.routes;

import dev.aikido.agent_api.context.RouteMetadata;

public final class RouteToKeyHelper {
private RouteToKeyHelper() {}

public static String routeToKey(RouteMetadata routeMetadata) {
return routeMetadata.method() + ":" + routeMetadata.route();
public static String routeToKey(String method, String route) {
return method + ":" + route;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,32 @@ public Routes() {
this(1000); // Default max size
}

private void initializeRoute(RouteMetadata routeMetadata) {
private void ensureRoute(String method, String route) {
manageRoutesSize();
String key = routeToKey(routeMetadata);
routes.put(key, new RouteEntry(routeMetadata));
String key = routeToKey(method, route);
if(!routes.containsKey(key)) {
routes.put(key, new RouteEntry(method, route));
}
}

public void incrementRoute(RouteMetadata routeMetadata) {
String key = routeToKey(routeMetadata);
if (!routes.containsKey(key)) {
// if the route does not yet exist, create it.
initializeRoute(routeMetadata);
public void incrementRoute(String method, String route) {
ensureRoute(method, route);
RouteEntry routeEntry = this.get(method, route);
if (routeEntry != null) {
routeEntry.incrementHits();
}
RouteEntry route = routes.get(key);
if (route != null) {
route.incrementHits();
}

public void incrementRateLimitCount(String method, String route) {
ensureRoute(method, route);
RouteEntry routeEntry = this.get(method, route);
if (routeEntry != null) {
routeEntry.incrementRateLimitCount();
}
}

public RouteEntry get(RouteMetadata routeMetadata) {
String key = routeToKey(routeMetadata);
public RouteEntry get(String method, String route) {
String key = routeToKey(method, route);
return routes.get(key);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ private RoutesStore() {

}

public static int getRouteHits(RouteMetadata routeMetadata) {
public static int getRouteHits(String method, String route) {
mutex.lock();
try {
return routes.get(routeMetadata).getHits();
return routes.get(method, route).getHits();
} finally {
mutex.unlock();
}
Expand All @@ -34,12 +34,12 @@ public static RouteEntry[] getRoutesAsList() {
}
}

public static void updateApiSpec(RouteMetadata routeMetadata, APISpec apiSpec) {
public static void updateApiSpec(String method, String route, APISpec apiSpec) {
mutex.lock();
try {
RouteEntry route = routes.get(routeMetadata);
if (route != null) {
route.updateApiSpec(apiSpec);
RouteEntry routeEntry = routes.get(method, route);
if (routeEntry != null) {
routeEntry.updateApiSpec(apiSpec);
}
} catch (Throwable e) {
logger.debug("Error occurred updating api specs: %s", e.getMessage());
Expand All @@ -48,17 +48,28 @@ public static void updateApiSpec(RouteMetadata routeMetadata, APISpec apiSpec) {
}
}

public static void addRouteHits(RouteMetadata routeMetadata) {
public static void addRouteHits(String method, String route) {
mutex.lock();
try {
routes.incrementRoute(routeMetadata);
routes.incrementRoute(method, route);
} catch (Throwable e) {
logger.debug("Error occurred incrementing route hits: %s", e.getMessage());
} finally {
mutex.unlock();
}
}

public static void addRouteRateLimitedCount(String method, String route) {
mutex.lock();
try {
routes.incrementRateLimitCount(method, route);
} catch (Throwable e) {
logger.debug("Error occurred incrementing route rate limit count: %s", e.getMessage());
} finally {
mutex.unlock();
}
}

public static void clear() {
mutex.lock();
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@ public class Statistics {
private final Map<String, Integer> ipAddressMatches = new HashMap<>();
private final Map<String, Integer> userAgentMatches = new HashMap<>();
private int totalHits;
private final int aborted; // We don't use the "aborted" field right now
private int rateLimited;
private int attacksDetected;
private int attacksBlocked;
private long startedAt;

public Statistics(int totalHits, int attacksDetected, int attacksBlocked) {
this.totalHits = totalHits;
this.attacksDetected = attacksDetected;
this.attacksBlocked = attacksBlocked;
this.startedAt = UnixTimeMS.getUnixTimeMS();
}

public Statistics() {
this(0, 0, 0);
this.totalHits = 0;
this.rateLimited = 0;
this.aborted = 0;
this.attacksDetected = 0;
this.attacksBlocked = 0;
this.startedAt = UnixTimeMS.getUnixTimeMS();
}


Expand All @@ -35,6 +35,14 @@ public int getTotalHits() {
return totalHits;
}

public void incrementRateLimited() {
rateLimited += 1;
}

public int getRateLimited() {
return rateLimited;
}


// attack stats
public void incrementAttacksDetected(String operation) {
Expand Down Expand Up @@ -104,8 +112,7 @@ public void addMatchToUserAgents(String key) {
public StatsRecord getRecord() {
long endedAt = UnixTimeMS.getUnixTimeMS();
return new StatsRecord(this.startedAt, endedAt, new StatsRequestsRecord(
/* total */ totalHits,
/* aborted */ 0, // Unknown statistic, default to 0,
totalHits, aborted, rateLimited,
/* attacksDetected */ Map.of(
"total", attacksDetected,
"blocked", attacksBlocked
Expand All @@ -118,6 +125,7 @@ public StatsRecord getRecord() {

public void clear() {
this.totalHits = 0;
this.rateLimited = 0;
this.attacksBlocked = 0;
this.attacksDetected = 0;
this.startedAt = UnixTimeMS.getUnixTimeMS();
Expand All @@ -127,7 +135,8 @@ public void clear() {
}

// Stats records for sending out the heartbeat :
public record StatsRequestsRecord(long total, long aborted, Map<String, Integer> attacksDetected) {
public record StatsRequestsRecord(long total, long aborted, long rateLimited,
Map<String, Integer> attacksDetected) {
}

public record StatsRecord(long startedAt, long endedAt, StatsRequestsRecord requests,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ public static void incrementHits() {
}
}

public static void incrementRateLimited() {
mutex.lock();
try {
stats.incrementRateLimited();
} finally {
mutex.unlock();
}
}

public static void incrementAttacksDetected(String operation) {
mutex.lock();
try {
Expand Down
Loading
Loading