Skip to content

Commit d0989ad

Browse files
Google AI Edge Gallery and copybara-github
authored and committed
Add audio support.
- Add a new task "audio scribe". - Allow users to record audio clips or pick wav files to interact with model. - Add support for importing models with audio capability. - Fix a typo in Settings dialog (Thanks https://github.com/rhnvrm!) PiperOrigin-RevId: 774832681
1 parent 33c3ee6 commit d0989ad

27 files changed

+1369
-288
lines changed

Android/src/app/src/main/AndroidManifest.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
<uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC"/>
3030
<uses-permission android:name="android.permission.INTERNET" />
3131
<uses-permission android:name="android.permission.POST_NOTIFICATIONS" />
32+
<uses-permission android:name="android.permission.RECORD_AUDIO" />
3233
<uses-permission android:name="android.permission.WAKE_LOCK"/>
3334

3435
<uses-feature

Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Types.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,5 @@ interface LatencyProvider {
2525
data class Classification(val label: String, val score: Float, val color: Color)
2626

2727
data class JsonObjAndTextContent<T>(val jsonObj: T, val textContent: String)
28+
29+
/** Holds an audio clip's raw audio bytes together with the sample rate (in Hz) they were captured or decoded at. */
class AudioClip(val audioData: ByteArray, val sampleRate: Int)

Android/src/app/src/main/java/com/google/ai/edge/gallery/common/Utils.kt

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,17 @@
1717
package com.google.ai.edge.gallery.common
1818

1919
import android.content.Context
20+
import android.net.Uri
2021
import android.util.Log
22+
import com.google.ai.edge.gallery.data.SAMPLE_RATE
2123
import com.google.gson.Gson
2224
import com.google.gson.reflect.TypeToken
2325
import java.io.File
2426
import java.net.HttpURLConnection
2527
import java.net.URL
28+
import java.nio.ByteBuffer
29+
import java.nio.ByteOrder
30+
import kotlin.math.floor
2631

2732
data class LaunchInfo(val ts: Long)
2833

@@ -112,3 +117,135 @@ inline fun <reified T> getJsonResponse(url: String): JsonObjAndTextContent<T>? {
112117

113118
return null
114119
}
120+
121+
/**
 * Loads a WAV file from [stereoUri], normalizes it to 16-bit mono PCM at [SAMPLE_RATE] Hz
 * (upsampling only when the source rate is lower), and trims the result to at most
 * [maxSeconds] seconds.
 *
 * NOTE(review): parsing assumes a canonical 44-byte WAV header with the data chunk starting at
 * offset 44. Files with extra chunks (e.g. LIST/INFO) would be mis-parsed — confirm inputs are
 * always canonical WAV.
 *
 * Fixes vs. previous version:
 * - Stereo is mixed down to mono BEFORE resampling. Previously interleaved stereo data was
 *   passed to [resample] first, whose non-mono path produced an all-zero buffer, silencing any
 *   stereo clip recorded below [SAMPLE_RATE].
 * - The input stream is closed via [use] even when reading throws.
 *
 * @param context used to resolve [stereoUri] through the content resolver.
 * @param stereoUri the WAV file to load (mono or stereo).
 * @param maxSeconds maximum clip duration to keep, in seconds.
 * @return the converted [AudioClip], or null if the file cannot be opened or is not a valid WAV.
 */
fun convertWavToMonoWithMaxSeconds(
  context: Context,
  stereoUri: Uri,
  maxSeconds: Int = 30,
): AudioClip? {
  Log.d(TAG, "Start to convert wav file to mono channel")

  try {
    val inputStream = context.contentResolver.openInputStream(stereoUri) ?: return null
    // `use` guarantees the stream is closed even if readBytes() throws.
    val originalBytes = inputStream.use { it.readBytes() }

    // Read WAV header.
    if (originalBytes.size < 44) {
      // Not a valid WAV file.
      Log.e(TAG, "Not a valid wav file")
      return null
    }

    val headerBuffer = ByteBuffer.wrap(originalBytes, 0, 44).order(ByteOrder.LITTLE_ENDIAN)
    val channels = headerBuffer.getShort(22)
    var sampleRate = headerBuffer.getInt(24)
    val bitDepth = headerBuffer.getShort(34)
    Log.d(TAG, "File metadata: channels: $channels, sampleRate: $sampleRate, bitDepth: $bitDepth")

    // Normalize audio to 16-bit.
    val audioDataBytes = originalBytes.copyOfRange(fromIndex = 44, toIndex = originalBytes.size)
    val sixteenBitBytes: ByteArray =
      if (bitDepth.toInt() == 8) {
        Log.d(TAG, "Converting 8-bit audio to 16-bit.")
        convert8BitTo16Bit(audioDataBytes)
      } else {
        // Assume 16-bit or other format that can be handled directly.
        audioDataBytes
      }

    // Convert byte array to short array for processing.
    val shortBuffer =
      ByteBuffer.wrap(sixteenBitBytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer()
    val pcmSamples = ShortArray(shortBuffer.remaining())
    shortBuffer.get(pcmSamples)

    // Convert stereo to mono FIRST: the linear-interpolation resampler operates on a single
    // channel, so interleaved stereo must be mixed down before resampling (resampling
    // interleaved data would blend left/right samples).
    var monoSamples =
      if (channels.toInt() == 2) {
        Log.d(TAG, "Converting stereo to mono.")
        val mono = ShortArray(pcmSamples.size / 2)
        for (i in mono.indices) {
          val left = pcmSamples[i * 2]
          val right = pcmSamples[i * 2 + 1]
          // Average the channels; summing in Int avoids 16-bit overflow.
          mono[i] = ((left + right) / 2).toShort()
        }
        mono
      } else {
        Log.d(TAG, "Audio is already mono. No channel conversion needed.")
        pcmSamples
      }

    // Resample if sample rate is less than 16000 Hz.
    if (sampleRate < SAMPLE_RATE) {
      Log.d(TAG, "Resampling from $sampleRate Hz to $SAMPLE_RATE Hz.")
      monoSamples = resample(monoSamples, sampleRate, SAMPLE_RATE, 1)
      sampleRate = SAMPLE_RATE
      Log.d(TAG, "Resampling complete. New sample count: ${monoSamples.size}")
    }

    // Trim the audio to maxSeconds.
    val maxSamples = maxSeconds * sampleRate
    if (monoSamples.size > maxSamples) {
      Log.d(TAG, "Trimming clip from ${monoSamples.size} samples to $maxSamples samples.")
      monoSamples = monoSamples.copyOfRange(0, maxSamples)
    }

    // Re-pack the 16-bit samples as little-endian bytes for the returned clip.
    val monoByteBuffer = ByteBuffer.allocate(monoSamples.size * 2).order(ByteOrder.LITTLE_ENDIAN)
    monoByteBuffer.asShortBuffer().put(monoSamples)
    return AudioClip(audioData = monoByteBuffer.array(), sampleRate = sampleRate)
  } catch (e: Exception) {
    Log.e(TAG, "Failed to convert wav to mono", e)
    return null
  }
}
202+
203+
/** Converts 8-bit unsigned PCM audio data to 16-bit signed little-endian PCM. */
private fun convert8BitTo16Bit(eightBitData: ByteArray): ByteArray {
  // Each 8-bit sample expands to one 16-bit sample, doubling the byte count.
  val output = ByteBuffer.allocate(eightBitData.size * 2).order(ByteOrder.LITTLE_ENDIAN)
  for (sample in eightBitData) {
    // Unsigned 8-bit PCM (0..255) is centered at 128. Mask to get the unsigned
    // value, re-center around zero (-128..127), then scale by 256 to span the
    // signed 16-bit range (-32768..32512).
    val centered = (sample.toInt() and 0xFF) - 128
    output.putShort((centered * 256).toShort())
  }
  return output.array()
}
220+
221+
/**
 * Resamples 16-bit PCM audio from [originalSampleRate] to [targetSampleRate] using linear
 * interpolation.
 *
 * Mono ([channels] == 1) data is interpolated directly. Interleaved stereo ([channels] == 2)
 * data is interpolated per channel so left/right samples are never blended. (The previous
 * implementation only filled the output for mono input and silently returned an all-zero
 * buffer for any other channel count.) Unsupported channel counts still fall back to a
 * silent buffer of the scaled length, matching the old behavior.
 *
 * @param inputSamples the source samples (interleaved when [channels] > 1).
 * @param originalSampleRate the source rate in Hz.
 * @param targetSampleRate the desired rate in Hz.
 * @param channels number of interleaved channels in [inputSamples].
 * @return the resampled samples; [inputSamples] itself when no rate change is needed.
 */
private fun resample(
  inputSamples: ShortArray,
  originalSampleRate: Int,
  targetSampleRate: Int,
  channels: Int,
): ShortArray {
  if (originalSampleRate == targetSampleRate) {
    return inputSamples
  }

  val ratio = targetSampleRate.toDouble() / originalSampleRate

  if (channels == 1) { // Mono: interpolate sample-by-sample.
    val resampledData = ShortArray((inputSamples.size * ratio).toInt())
    for (i in resampledData.indices) {
      val position = i / ratio
      val index1 = floor(position).toInt()
      val index2 = index1 + 1
      val fraction = position - index1

      // Out-of-range neighbors contribute silence (0.0) at the clip edge.
      val sample1 = if (index1 < inputSamples.size) inputSamples[index1].toDouble() else 0.0
      val sample2 = if (index2 < inputSamples.size) inputSamples[index2].toDouble() else 0.0

      resampledData[i] = (sample1 * (1 - fraction) + sample2 * fraction).toInt().toShort()
    }
    return resampledData
  }

  if (channels == 2) { // Interleaved stereo: resample each channel independently per frame.
    val inputFrames = inputSamples.size / 2
    val outputFrames = (inputFrames * ratio).toInt()
    val resampledData = ShortArray(outputFrames * 2)
    for (frame in 0 until outputFrames) {
      val position = frame / ratio
      val index1 = floor(position).toInt()
      val index2 = index1 + 1
      val fraction = position - index1
      for (ch in 0 until 2) {
        val s1 = if (index1 < inputFrames) inputSamples[index1 * 2 + ch].toDouble() else 0.0
        val s2 = if (index2 < inputFrames) inputSamples[index2 * 2 + ch].toDouble() else 0.0
        resampledData[frame * 2 + ch] = (s1 * (1 - fraction) + s2 * fraction).toInt().toShort()
      }
    }
    return resampledData
  }

  // Unsupported channel count: return silence of the scaled length (previous fallback).
  return ShortArray((inputSamples.size * ratio).toInt())
}

Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Config.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ enum class ConfigKey(val label: String) {
5050
DEFAULT_TOPP("Default TopP"),
5151
DEFAULT_TEMPERATURE("Default temperature"),
5252
SUPPORT_IMAGE("Support image"),
53+
SUPPORT_AUDIO("Support audio"),
5354
MAX_RESULT_COUNT("Max result count"),
5455
USE_GPU("Use GPU"),
5556
ACCELERATOR("Choose accelerator"),

Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Consts.kt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,12 @@ val DEFAULT_ACCELERATORS = listOf(Accelerator.GPU)
4444

4545
// Max number of images allowed in a "ask image" session.
4646
const val MAX_IMAGE_COUNT = 10
47+
48+
// Max number of audio clip in an "ask audio" session.
49+
const val MAX_AUDIO_CLIP_COUNT = 10
50+
51+
// Max audio clip duration in seconds.
52+
const val MAX_AUDIO_CLIP_DURATION_SEC = 30
53+
54+
// Audio-recording related consts.
55+
const val SAMPLE_RATE = 16000

Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Model.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ data class Model(
8787
/** Whether the LLM model supports image input. */
8888
val llmSupportImage: Boolean = false,
8989

90+
/** Whether the LLM model supports audio input. */
91+
val llmSupportAudio: Boolean = false,
92+
9093
/** Whether the model is imported or not. */
9194
val imported: Boolean = false,
9295

Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ data class AllowedModel(
3838
val taskTypes: List<String>,
3939
val disabled: Boolean? = null,
4040
val llmSupportImage: Boolean? = null,
41+
val llmSupportAudio: Boolean? = null,
4142
val estimatedPeakMemoryInBytes: Long? = null,
4243
) {
4344
fun toModel(): Model {
@@ -96,6 +97,7 @@ data class AllowedModel(
9697
showRunAgainButton = showRunAgainButton,
9798
learnMoreUrl = "https://huggingface.co/${modelId}",
9899
llmSupportImage = llmSupportImage == true,
100+
llmSupportAudio = llmSupportAudio == true,
99101
)
100102
}
101103

Android/src/app/src/main/java/com/google/ai/edge/gallery/data/Tasks.kt

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,22 @@
1717
package com.google.ai.edge.gallery.data
1818

1919
import androidx.annotation.StringRes
20+
import androidx.compose.material.icons.Icons
21+
import androidx.compose.material.icons.outlined.Forum
22+
import androidx.compose.material.icons.outlined.Mic
23+
import androidx.compose.material.icons.outlined.Mms
24+
import androidx.compose.material.icons.outlined.Widgets
2025
import androidx.compose.runtime.MutableState
2126
import androidx.compose.runtime.mutableLongStateOf
2227
import androidx.compose.ui.graphics.vector.ImageVector
2328
import com.google.ai.edge.gallery.R
24-
import com.google.ai.edge.gallery.ui.icon.Forum
25-
import com.google.ai.edge.gallery.ui.icon.Mms
26-
import com.google.ai.edge.gallery.ui.icon.Widgets
2729

2830
/** Type of task. */
2931
enum class TaskType(val label: String, val id: String) {
3032
LLM_CHAT(label = "AI Chat", id = "llm_chat"),
3133
LLM_PROMPT_LAB(label = "Prompt Lab", id = "llm_prompt_lab"),
3234
LLM_ASK_IMAGE(label = "Ask Image", id = "llm_ask_image"),
35+
LLM_ASK_AUDIO(label = "Audio Scribe", id = "llm_ask_audio"),
3336
TEST_TASK_1(label = "Test task 1", id = "test_task_1"),
3437
TEST_TASK_2(label = "Test task 2", id = "test_task_2"),
3538
}
@@ -71,7 +74,7 @@ data class Task(
7174
val TASK_LLM_CHAT =
7275
Task(
7376
type = TaskType.LLM_CHAT,
74-
icon = Forum,
77+
icon = Icons.Outlined.Forum,
7578
models = mutableListOf(),
7679
description = "Chat with on-device large language models",
7780
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@@ -83,7 +86,7 @@ val TASK_LLM_CHAT =
8386
val TASK_LLM_PROMPT_LAB =
8487
Task(
8588
type = TaskType.LLM_PROMPT_LAB,
86-
icon = Widgets,
89+
icon = Icons.Outlined.Widgets,
8790
models = mutableListOf(),
8891
description = "Single turn use cases with on-device large language model",
8992
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@@ -95,7 +98,7 @@ val TASK_LLM_PROMPT_LAB =
9598
val TASK_LLM_ASK_IMAGE =
9699
Task(
97100
type = TaskType.LLM_ASK_IMAGE,
98-
icon = Mms,
101+
icon = Icons.Outlined.Mms,
99102
models = mutableListOf(),
100103
description = "Ask questions about images with on-device large language models",
101104
docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
@@ -104,8 +107,23 @@ val TASK_LLM_ASK_IMAGE =
104107
textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
105108
)
106109

110+
/** Task for transcribing and/or translating audio clips with on-device LLMs. */
val TASK_LLM_ASK_AUDIO =
  Task(
    type = TaskType.LLM_ASK_AUDIO,
    icon = Icons.Outlined.Mic,
    models = mutableListOf(),
    description =
      "Instantly transcribe and/or translate audio clips using on-device large language models",
    docUrl = "https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference/android",
    sourceCodeUrl =
      "https://github.com/google-ai-edge/gallery/blob/main/Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/llmchat/LlmChatModelHelper.kt",
    textInputPlaceHolderRes = R.string.text_input_placeholder_llm_chat,
  )
123+
107124
/** All tasks. */
108-
val TASKS: List<Task> = listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
125+
val TASKS: List<Task> =
126+
listOf(TASK_LLM_ASK_IMAGE, TASK_LLM_ASK_AUDIO, TASK_LLM_PROMPT_LAB, TASK_LLM_CHAT)
109127

110128
fun getModelByName(name: String): Model? {
111129
for (task in TASKS) {

Android/src/app/src/main/java/com/google/ai/edge/gallery/ui/ViewModelProvider.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import androidx.lifecycle.viewmodel.CreationExtras
2121
import androidx.lifecycle.viewmodel.initializer
2222
import androidx.lifecycle.viewmodel.viewModelFactory
2323
import com.google.ai.edge.gallery.GalleryApplication
24+
import com.google.ai.edge.gallery.ui.llmchat.LlmAskAudioViewModel
2425
import com.google.ai.edge.gallery.ui.llmchat.LlmAskImageViewModel
2526
import com.google.ai.edge.gallery.ui.llmchat.LlmChatViewModel
2627
import com.google.ai.edge.gallery.ui.llmsingleturn.LlmSingleTurnViewModel
@@ -49,6 +50,9 @@ object ViewModelProvider {
4950

5051
// Initializer for LlmAskImageViewModel.
5152
initializer { LlmAskImageViewModel() }
53+
54+
// Initializer for LlmAskAudioViewModel.
55+
initializer { LlmAskAudioViewModel() }
5256
}
5357
}
5458

0 commit comments

Comments
 (0)