Skip to content

Commit e88d6fe

Browse files
ivanbuper authored and copybara-github committed
Fix truncation error accumulation for Sonic's resampling algorithm
Sonic would accumulate truncation errors on float-to-int conversions, which caused the final output sample count to drift noticeably: by hundreds of samples on streams a few minutes long. The fix keeps track of the truncation error and compensates for it. Other small fixes include eliminating lossy operations (e.g. int division) and using doubles instead of floats for resampling where helpful. This CL also introduces `SonicParameterizedTest`, which helps test resampling on an arbitrary number of randomly generated parameters, with random sample data. `SonicParameterizedTest` uses `BigDecimal`s for calculating sample count values, so as to avoid precision issues with large sample counts. PiperOrigin-RevId: 673852768
1 parent 3caebbf commit e88d6fe

81 files changed

Lines changed: 5254 additions & 4934 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

RELEASENOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
([#1659](https://github.com/google/ExoPlayer/issues/1659)).
1919
* DataSource:
2020
* Audio:
21+
* Fix truncation error accumulation for `Sonic`'s resampling algorithm to
22+
prevent drift on number of output samples.
2123
* Video:
2224
* Text:
2325
* Metadata:

libraries/common/src/main/java/androidx/media3/common/audio/Sonic.java

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
private final int channelCount;
3939
private final float speed;
4040
private final float pitch;
41-
private final float rate;
41+
private final double rate;
4242
private final int minPeriod;
4343
private final int maxPeriod;
4444
private final int maxRequiredFrameCount;
@@ -57,6 +57,7 @@
5757
private int prevMinDiff;
5858
private int minDiff;
5959
private int maxDiff;
60+
private double accumulatedInterpolationError;
6061

6162
/**
6263
* Creates a new Sonic audio stream processor.
@@ -73,7 +74,7 @@ public Sonic(
7374
this.channelCount = channelCount;
7475
this.speed = speed;
7576
this.pitch = pitch;
76-
rate = (float) inputSampleRateHz / outputSampleRateHz;
77+
rate = (double) inputSampleRateHz / outputSampleRateHz;
7778
minPeriod = inputSampleRateHz / MAXIMUM_PITCH;
7879
maxPeriod = inputSampleRateHz / MINIMUM_PITCH;
7980
maxRequiredFrameCount = 2 * maxPeriod;
@@ -130,10 +131,20 @@ public void getOutput(ShortBuffer buffer) {
130131
*/
131132
public void queueEndOfStream() {
132133
int remainingFrameCount = inputFrameCount;
134+
133135
float s = speed / pitch;
134-
float r = rate * pitch;
136+
double r = rate * pitch;
137+
138+
// Math.round(double) returns a long, but we can safely cast from long to int because the only
139+
// double (accumulatedInterpolationError) always has a value between (-0.5 ; 0.5).
135140
int expectedOutputFrames =
136-
outputFrameCount + (int) ((remainingFrameCount / s + pitchFrameCount) / r + 0.5f);
141+
outputFrameCount
142+
+ (int)
143+
Math.round(
144+
((remainingFrameCount / s + pitchFrameCount) / r
145+
+ accumulatedInterpolationError));
146+
147+
accumulatedInterpolationError = 0;
137148

138149
// Add enough silence to flush both input and pitch buffers.
139150
inputBuffer =
@@ -144,10 +155,12 @@ public void queueEndOfStream() {
144155
}
145156
inputFrameCount += 2 * maxRequiredFrameCount;
146157
processStreamInput();
158+
147159
// Throw away any extra frames we generated due to the silence we added.
148160
if (outputFrameCount > expectedOutputFrames) {
149161
outputFrameCount = expectedOutputFrames;
150162
}
163+
151164
// Empty input and pitch buffers.
152165
inputFrameCount = 0;
153166
remainingInputToCopyFrameCount = 0;
@@ -166,6 +179,7 @@ public void flush() {
166179
prevMinDiff = 0;
167180
minDiff = 0;
168181
maxDiff = 0;
182+
accumulatedInterpolationError = 0;
169183
}
170184

171185
/** Returns the size of output that can be read with {@link #getOutput(ShortBuffer)}, in bytes. */
@@ -366,20 +380,35 @@ private short interpolate(short[] in, int inPos, int oldSampleRate, int newSampl
366380
return (short) ((ratio * left + (width - ratio) * right) / width);
367381
}
368382

369-
private void adjustRate(float rate, int originalOutputFrameCount) {
383+
private void adjustRate(double rate, int originalOutputFrameCount) {
384+
// If no new samples were added to the output buffer, there is nothing to resample.
370385
if (outputFrameCount == originalOutputFrameCount) {
371386
return;
372387
}
373-
int newSampleRate = (int) (inputSampleRateHz / rate);
374-
int oldSampleRate = inputSampleRateHz;
375-
// Set these values to help with the integer math.
376-
while (newSampleRate > (1 << 14) || oldSampleRate > (1 << 14)) {
388+
389+
// Move samples to pitch buffer first to calculate the block size.
390+
moveNewSamplesToPitchBuffer(originalOutputFrameCount);
391+
// Leave at least one pitch sample in the buffer.
392+
int blockSize = pitchFrameCount - 1;
393+
double expectedFrameCount = blockSize / rate + accumulatedInterpolationError;
394+
// We can safely cast from long to int because accumulatedInterpolationError is always between
395+
// (-0.5 ; 0.5), blockSize should always receive a reasonable buffer size (e.g. 1024 frames),
396+
// and we can assume that rate will not involve infinitesimally small values under normal
397+
// operation.
398+
int newSampleRate = (int) Math.round(expectedFrameCount);
399+
accumulatedInterpolationError = expectedFrameCount - newSampleRate;
400+
int oldSampleRate = blockSize;
401+
402+
// Simplify ratio for interpolation.
403+
while (newSampleRate != 0
404+
&& oldSampleRate != 0
405+
&& newSampleRate % 2 == 0
406+
&& oldSampleRate % 2 == 0) {
377407
newSampleRate /= 2;
378408
oldSampleRate /= 2;
379409
}
380-
moveNewSamplesToPitchBuffer(originalOutputFrameCount);
381-
// Leave at least one pitch sample in the buffer.
382-
for (int position = 0; position < pitchFrameCount - 1; position++) {
410+
411+
for (int position = 0; position < blockSize; position++) {
383412
while ((oldRatePosition + 1) * newSampleRate > newRatePosition * oldSampleRate) {
384413
outputBuffer =
385414
ensureSpaceForAdditionalFrames(
@@ -398,7 +427,7 @@ private void adjustRate(float rate, int originalOutputFrameCount) {
398427
newRatePosition = 0;
399428
}
400429
}
401-
removePitchFrames(pitchFrameCount - 1);
430+
removePitchFrames(blockSize);
402431
}
403432

404433
private int skipPitchPeriod(short[] samples, int position, float speed, int period) {
@@ -479,14 +508,14 @@ private void processStreamInput() {
479508
// Resample as many pitch periods as we have buffered on the input.
480509
int originalOutputFrameCount = outputFrameCount;
481510
float s = speed / pitch;
482-
float r = rate * pitch;
511+
double r = rate * pitch;
483512
if (s > 1.00001 || s < 0.99999) {
484513
changeSpeed(s);
485514
} else {
486515
copyToOutput(inputBuffer, 0, inputFrameCount);
487516
inputFrameCount = 0;
488517
}
489-
if (r != 1.0f) {
518+
if (r != 1.0) {
490519
adjustRate(r, originalOutputFrameCount);
491520
}
492521
}
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Copyright (C) 2024 The Android Open Source Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package androidx.media3.common.audio;
17+
18+
import static com.google.common.truth.Truth.assertThat;
19+
20+
import com.google.common.collect.ImmutableList;
21+
import com.google.common.collect.ImmutableSet;
22+
import com.google.common.collect.Range;
23+
import java.math.BigDecimal;
24+
import java.math.RoundingMode;
25+
import java.nio.ByteBuffer;
26+
import java.nio.ShortBuffer;
27+
import java.util.Random;
28+
import org.junit.Test;
29+
import org.junit.runner.RunWith;
30+
import org.robolectric.ParameterizedRobolectricTestRunner;
31+
import org.robolectric.ParameterizedRobolectricTestRunner.Parameter;
32+
import org.robolectric.ParameterizedRobolectricTestRunner.Parameters;
33+
34+
/** Parameterized robolectric test for testing {@link Sonic}. */
35+
@RunWith(ParameterizedRobolectricTestRunner.class)
36+
public final class RandomParameterizedSonicTest {
37+
38+
private static final int BLOCK_SIZE = 4096;
39+
private static final int BYTES_PER_SAMPLE = 2;
40+
private static final int SAMPLE_RATE = 48000;
41+
// Max 10 min streams.
42+
private static final long MAX_LENGTH_SAMPLES = 10 * 60 * SAMPLE_RATE;
43+
// How many instances per parameter to generate.
44+
private static final int PARAM_COUNT = 5;
45+
private static final int SPEED_DECIMAL_PRECISION = 2;
46+
private static final ImmutableList<Range<Float>> SPEED_RANGES =
47+
ImmutableList.of(
48+
Range.closedOpen(0f, 1f), Range.closedOpen(1f, 2f), Range.closedOpen(2f, 20f));
49+
50+
private static final Random random = new Random(/* seed */ 0);
51+
52+
private static final ImmutableList<Object[]> sParams = initParams();
53+
54+
@Parameters(name = "speed={0}, streamLength={1}")
55+
public static ImmutableList<Object[]> params() {
56+
// params() is called multiple times, so return cached parameters to avoid regenerating
57+
// different random parameter values.
58+
return sParams;
59+
}
60+
61+
private static ImmutableList<Object[]> initParams() {
62+
ImmutableSet.Builder<Object[]> paramsBuilder = new ImmutableSet.Builder<>();
63+
ImmutableSet.Builder<Float> speedsBuilder = new ImmutableSet.Builder<>();
64+
65+
for (int i = 0; i < PARAM_COUNT; i++) {
66+
Range<Float> r = SPEED_RANGES.get(i % SPEED_RANGES.size());
67+
speedsBuilder.add(round(generateFloatInRange(r)));
68+
}
69+
ImmutableSet<Float> speeds = speedsBuilder.build();
70+
71+
ImmutableSet<Long> lengths =
72+
new ImmutableSet.Builder<Long>()
73+
.addAll(
74+
random
75+
.longs(/* min */ 0, MAX_LENGTH_SAMPLES)
76+
.distinct()
77+
.limit(PARAM_COUNT)
78+
.iterator())
79+
.build();
80+
for (long length : lengths) {
81+
for (float speed : speeds) {
82+
paramsBuilder.add(new Object[] {speed, length});
83+
}
84+
}
85+
return paramsBuilder.build().asList();
86+
}
87+
88+
@Parameter(0)
89+
public float speed;
90+
91+
@Parameter(1)
92+
public long streamLength;
93+
94+
@Test
95+
public void resampling_returnsExpectedNumberOfSamples() {
96+
byte[] buf = new byte[BLOCK_SIZE * BYTES_PER_SAMPLE];
97+
ShortBuffer outBuffer = ShortBuffer.allocate(BLOCK_SIZE);
98+
// Use same speed and pitch values for Sonic to resample stream.
99+
Sonic sonic =
100+
new Sonic(
101+
/* inputSampleRateHz= */ SAMPLE_RATE,
102+
/* channelCount= */ 1,
103+
/* speed= */ speed,
104+
/* pitch= */ speed,
105+
/* outputSampleRateHz= */ SAMPLE_RATE);
106+
long readSampleCount = 0;
107+
108+
for (long samplesLeft = streamLength; samplesLeft > 0; samplesLeft -= BLOCK_SIZE) {
109+
random.nextBytes(buf);
110+
if (samplesLeft >= BLOCK_SIZE) {
111+
sonic.queueInput(ByteBuffer.wrap(buf).asShortBuffer());
112+
} else {
113+
sonic.queueInput(
114+
ByteBuffer.wrap(buf, 0, (int) (samplesLeft * BYTES_PER_SAMPLE)).asShortBuffer());
115+
sonic.queueEndOfStream();
116+
}
117+
while (sonic.getOutputSize() > 0) {
118+
sonic.getOutput(outBuffer);
119+
readSampleCount += outBuffer.position();
120+
outBuffer.clear();
121+
}
122+
}
123+
sonic.flush();
124+
125+
BigDecimal bigSpeed = new BigDecimal(String.valueOf(speed));
126+
BigDecimal bigLength = new BigDecimal(String.valueOf(streamLength));
127+
// BigDecimal.divide(divisor, roundingMode) keeps the scale of the dividend. bigLength has scale
128+
// zero, so expectedSize is always an integer.
129+
BigDecimal expectedSize = bigLength.divide(bigSpeed, RoundingMode.HALF_EVEN);
130+
assertThat(readSampleCount).isWithin(1).of(expectedSize.longValueExact());
131+
}
132+
133+
private static float round(float num) {
134+
BigDecimal bigDecimal = new BigDecimal(Float.toString(num));
135+
return bigDecimal.setScale(SPEED_DECIMAL_PRECISION, RoundingMode.HALF_EVEN).floatValue();
136+
}
137+
138+
private static float generateFloatInRange(Range<Float> r) {
139+
return r.lowerEndpoint() + random.nextFloat() * (r.upperEndpoint() - r.lowerEndpoint());
140+
}
141+
}

0 commit comments

Comments
 (0)