Skip to content

[LANG-1772] Restrict size of cache to prevent overflow errors #1379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
14 changes: 12 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@
</distributionManagement>

<properties>
<argLine>-Xmx512m</argLine>
<heapSize>-Xmx512m</heapSize>
<extraArgs/>
<systemProperties/>
<argLine>${heapSize} ${extraArgs} ${systemProperties}</argLine>
<project.build.sourceEncoding>ISO-8859-1</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<!-- project.build.outputTimestamp is managed by Maven plugins, see https://maven.apache.org/guides/mini/guide-reproducible-builds.html -->
Expand Down Expand Up @@ -460,7 +463,7 @@
<properties>
<!-- LANG-1265: allow tests to access private fields/methods of java.base classes via reflection -->
<!-- LANG-1667: allow tests to access private fields/methods of java.base/java.util such as ArrayList via reflection -->
<argLine>-Xmx512m --add-opens java.base/java.lang.reflect=ALL-UNNAMED --add-opens java.base/java.lang=ALL-UNNAMED --add-opens java.base/java.util=ALL-UNNAMED --add-opens java.base/java.time=ALL-UNNAMED --add-opens java.base/java.time.chrono=ALL-UNNAMED</argLine>
<extraArgs>--add-opens java.base/java.lang.reflect=ALL-UNNAMED --add-opens java.base/java.lang=ALL-UNNAMED --add-opens java.base/java.util=ALL-UNNAMED --add-opens java.base/java.time=ALL-UNNAMED --add-opens java.base/java.time.chrono=ALL-UNNAMED</extraArgs>
</properties>
</profile>
<profile>
Expand Down Expand Up @@ -522,6 +525,13 @@
</plugins>
</build>
</profile>
<profile>
<id>largeheap</id>
<properties>
<heapSize>-Xmx1024m</heapSize>
<systemProperties>-Dtest.large.heap=true</systemProperties>
</properties>
</profile>
</profiles>
<developers>
<developer>
Expand Down
60 changes: 48 additions & 12 deletions src/main/java/org/apache/commons/lang3/CachedRandomBits.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ final class CachedRandomBits {
*/
private int bitIndex;

/**
* The maximum size of the cache.
*
* <p>
* This is to prevent the possibility of overflow in the {@code if (bitIndex >> 3 >= cache.length)} in the {@link #nextBits(int)} method.
* </p>
*/
private static final int MAX_CACHE_SIZE = Integer.MAX_VALUE >> 3;
/** Maximum number of bits that can be generated (size of an int) */
private static final int MAX_BITS = 32;
/** Mask to extract the bit offset within a byte (0-7) */
private static final int BIT_INDEX_MASK = 0x7;
/** Number of bits in a byte */
private static final int BITS_PER_BYTE = 8;
/**
* Creates a new instance.
*
Expand All @@ -62,7 +76,7 @@ final class CachedRandomBits {
if (cacheSize <= 0) {
throw new IllegalArgumentException("cacheSize must be positive");
}
this.cache = new byte[cacheSize];
this.cache = cacheSize <= MAX_CACHE_SIZE ? new byte[cacheSize] : new byte[MAX_CACHE_SIZE];
this.random = Objects.requireNonNull(random, "random");
this.random.nextBytes(this.cache);
this.bitIndex = 0;
Expand All @@ -71,28 +85,50 @@ final class CachedRandomBits {
/**
* Generates a random integer with the specified number of bits.
*
* @param bits number of bits to generate, MUST be between 1 and 32
* @return random integer with {@code bits} bits
* <p>This method efficiently generates random bits by using a byte cache and bit manipulation:
* <ul>
* <li>Uses a byte array cache to avoid frequent calls to the underlying random number generator</li>
* <li>Extracts bits from each byte using bit shifting and masking</li>
* <li>Handles partial bytes to avoid wasting random bits</li>
* <li>Accumulates bits until the requested number is reached</li>
* </ul>
* </p>
*
* @param bits number of bits to generate, MUST be between 1 and 32 (inclusive)
* @return random integer containing exactly the requested number of random bits
* @throws IllegalArgumentException if bits is not between 1 and 32
*/
public int nextBits(final int bits) {
if (bits > 32 || bits <= 0) {
throw new IllegalArgumentException("number of bits must be between 1 and 32");
if (bits > MAX_BITS || bits <= 0) {
throw new IllegalArgumentException("number of bits must be between 1 and " + MAX_BITS);
}
int result = 0;
int generatedBits = 0; // number of generated bits up to now
while (generatedBits < bits) {
// Check if we need to refill the cache
// Convert bitIndex to byte index by dividing by 8 (right shift by 3)
if (bitIndex >> 3 >= cache.length) {
// we exhausted the number of bits in the cache
// this should only happen if the bitIndex is exactly matching the cache length
assert bitIndex == cache.length * 8;
// We exhausted the number of bits in the cache
// This should only happen if the bitIndex is exactly matching the cache length
assert bitIndex == cache.length * BITS_PER_BYTE;
random.nextBytes(cache);
bitIndex = 0;
}
// generatedBitsInIteration is the number of bits that we will generate
// in this iteration of the while loop
final int generatedBitsInIteration = Math.min(8 - (bitIndex & 0x7), bits - generatedBits);
// Calculate how many bits we can extract from the current byte
// 1. Get current position within byte (0-7) using bitIndex & 0x7
// 2. Calculate remaining bits in byte: 8 - (position within byte)
// 3. Take minimum of remaining bits in byte and bits still needed
final int generatedBitsInIteration = Math.min(
BITS_PER_BYTE - (bitIndex & BIT_INDEX_MASK),
bits - generatedBits);
// Shift existing result left to make room for new bits
result = result << generatedBitsInIteration;
result |= cache[bitIndex >> 3] >> (bitIndex & 0x7) & (1 << generatedBitsInIteration) - 1;
// Extract and append new bits:
// 1. Get byte from cache (bitIndex >> 3 converts bit index to byte index)
// 2. Shift right by bit position within byte (bitIndex & 0x7)
// 3. Mask to keep only the bits we want ((1 << generatedBitsInIteration) - 1)
result |= cache[bitIndex >> 3] >> (bitIndex & BIT_INDEX_MASK) & ((1 << generatedBitsInIteration) - 1);
// Update counters
generatedBits += generatedBitsInIteration;
bitIndex += generatedBitsInIteration;
}
Expand Down
15 changes: 14 additions & 1 deletion src/main/java/org/apache/commons/lang3/RandomStringUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ public class RandomStringUtils {
private static final int ASCII_A = 'A';
private static final int ASCII_z = 'z';

private static final int CACHE_PADDING_BITS = 3;
private static final int BITS_TO_BYTES_DIVISOR = 5;
private static final int BASE_CACHE_SIZE_PADDING = 10;

/**
* Gets the singleton instance based on {@link ThreadLocalRandom#current()}; <b>which is not cryptographically
* secure</b>; use {@link #secure()} to use an algorithms/providers specified in the
Expand Down Expand Up @@ -329,7 +333,16 @@ public static String random(int count, int start, int end, final boolean letters
// Ideally the cache size depends on multiple factor, including the cost of generating x bytes
// of randomness as well as the probability of rejection. It is however not easy to know
// those values programmatically for the general case.
final CachedRandomBits arb = new CachedRandomBits((count * gapBits + 3) / 5 + 10, random);
// Calculate cache size:
// 1. Multiply count by bits needed per character (gapBits)
// 2. Add padding bits (3) to handle partial bytes
// 3. Divide by 5 to convert to bytes (normally this would be by 8, dividing by 5 allows for about 60% extra space)
// 4. Add base padding (10) to handle small counts efficiently
// 5. Ensure we don't exceed Integer.MAX_VALUE / 5 + 10 to provide a good balance between overflow prevention and
// making the cache extremely large
final long desiredCacheSize = ((long) count * gapBits + CACHE_PADDING_BITS) / BITS_TO_BYTES_DIVISOR + BASE_CACHE_SIZE_PADDING;
final int cacheSize = (int) Math.min(desiredCacheSize, Integer.MAX_VALUE / BITS_TO_BYTES_DIVISOR + BASE_CACHE_SIZE_PADDING);
final CachedRandomBits arb = new CachedRandomBits(cacheSize, random);

while (count-- != 0) {
// Generate a random value between start (included) and end (excluded)
Expand Down
13 changes: 13 additions & 0 deletions src/test/java/org/apache/commons/lang3/RandomStringUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,20 @@
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;

/**
* Tests {@link RandomStringUtils}.
*/
public class RandomStringUtilsTest extends AbstractLangTest {

private static final int LOOP_COUNT = 1_000;
/** Maximum safe value for count to avoid overflow: (21x + 3) / 5 + 10 < 0x0FFF_FFFF */
private static final int MAX_SAFE_COUNT = 63_913_201;


static Stream<RandomStringUtils> randomProvider() {
return Stream.of(RandomStringUtils.secure(), RandomStringUtils.secureStrong(), RandomStringUtils.insecure());
Expand Down Expand Up @@ -802,4 +807,12 @@ public void testRandomWithChars(final RandomStringUtils rsu) {
assertNotEquals(r1, r3);
assertNotEquals(r2, r3);
}

@ParameterizedTest
@ValueSource(ints = {MAX_SAFE_COUNT, MAX_SAFE_COUNT + 1})
@EnabledIfSystemProperty(named = "test.large.heap", matches = "true")
public void testHugeStrings(final int expectedLength) {
final String hugeString = RandomStringUtils.random(expectedLength);
assertEquals(expectedLength, hugeString.length(), "hugeString.length() == expectedLength");
}
}
Loading