Skip to content

Commit ee3c933

Browse files
authored
feat(ai-logic): add hybrid on-device inference sample (#2773)
1 parent 7da89ea commit ee3c933

9 files changed

Lines changed: 403 additions & 4 deletions

File tree

firebase-ai/app/build.gradle.kts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ android {
1212

1313
defaultConfig {
1414
applicationId = "com.google.firebase.quickstart.ai"
15-
minSdk = 23
15+
minSdk = 26
1616
targetSdk = 36
1717
versionCode = 1
1818
versionName = "1.0"
@@ -73,6 +73,7 @@ dependencies {
7373
// Firebase
7474
implementation(platform(libs.firebase.bom))
7575
implementation(libs.firebase.ai)
76+
implementation(libs.firebase.ai.ondevice)
7677

7778
// Image loading
7879
implementation(libs.coil.compose)

firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/MainActivity.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import androidx.navigation.compose.NavHost
2727
import androidx.navigation.compose.composable
2828
import androidx.navigation.compose.rememberNavController
2929
import com.google.firebase.quickstart.ai.feature.live.BidiViewModel
30+
import com.google.firebase.quickstart.ai.feature.hybrid.HybridInferenceViewModel
3031
import com.google.firebase.quickstart.ai.feature.media.imagen.ImagenViewModel
3132
import com.google.firebase.quickstart.ai.feature.text.ChatViewModel
3233
import com.google.firebase.quickstart.ai.feature.text.ServerPromptTemplateViewModel
@@ -36,6 +37,7 @@ import com.google.firebase.quickstart.ai.ui.ImagenScreen
3637
import com.google.firebase.quickstart.ai.ui.ServerPromptScreen
3738
import com.google.firebase.quickstart.ai.ui.StreamRealtimeScreen
3839
import com.google.firebase.quickstart.ai.ui.StreamRealtimeVideoScreen
40+
import com.google.firebase.quickstart.ai.ui.HybridInferenceScreen
3941
import com.google.firebase.quickstart.ai.ui.SvgScreen
4042
import com.google.firebase.quickstart.ai.ui.navigation.FIREBASE_AI_SAMPLES
4143
import com.google.firebase.quickstart.ai.ui.navigation.MainMenuScreen
@@ -123,6 +125,10 @@ class MainActivity : ComponentActivity() {
123125
StreamRealtimeVideoScreen(it)
124126
}
125127
}
128+
129+
ScreenType.HYBRID -> {
130+
(vm as? HybridInferenceViewModel)?.let { HybridInferenceScreen(it) }
131+
}
126132
}
127133
}
128134
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package com.google.firebase.quickstart.ai.feature.hybrid
2+
3+
import kotlinx.serialization.Serializable
4+
5+
@Serializable
6+
data class Expense(
7+
val name: String,
8+
val price: Double,
9+
val inferenceMode: String = ""
10+
)
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package com.google.firebase.quickstart.ai.feature.hybrid
2+
3+
import android.graphics.Bitmap
4+
import android.util.Log
5+
import androidx.lifecycle.ViewModel
6+
import androidx.lifecycle.viewModelScope
7+
import com.google.firebase.Firebase
8+
import com.google.firebase.ai.InferenceMode
9+
import com.google.firebase.ai.InferenceSource
10+
import com.google.firebase.ai.OnDeviceConfig
11+
import com.google.firebase.ai.ai
12+
import com.google.firebase.ai.ondevice.DownloadStatus
13+
import com.google.firebase.ai.ondevice.FirebaseAIOnDevice
14+
import com.google.firebase.ai.ondevice.OnDeviceModelStatus
15+
import com.google.firebase.ai.type.GenerativeBackend
16+
import com.google.firebase.ai.type.PublicPreviewAPI
17+
import com.google.firebase.ai.type.content
18+
import com.google.firebase.quickstart.ai.ui.HybridInferenceUiState
19+
import kotlinx.coroutines.flow.MutableStateFlow
20+
import kotlinx.coroutines.flow.StateFlow
21+
import kotlinx.coroutines.flow.asStateFlow
22+
import kotlinx.coroutines.flow.update
23+
import kotlinx.coroutines.launch
24+
import kotlinx.serialization.Serializable
25+
import kotlinx.serialization.json.Json
26+
import java.util.UUID
27+
28+
@Serializable
29+
object HybridInferenceRoute
30+
31+
@OptIn(PublicPreviewAPI::class)
32+
class HybridInferenceViewModel : ViewModel() {
33+
private val _uiState = MutableStateFlow(
34+
HybridInferenceUiState(
35+
expenses = listOf(
36+
Expense("Lunch", 15.50, "Example data"),
37+
Expense("Coffee", 4.75, "Example data")
38+
)
39+
)
40+
)
41+
val uiState: StateFlow<HybridInferenceUiState> = _uiState.asStateFlow()
42+
43+
private val model = Firebase.ai(backend = GenerativeBackend.googleAI()).generativeModel(
44+
modelName = "gemini-3.1-flash-lite-preview",
45+
onDeviceConfig = OnDeviceConfig(mode = InferenceMode.PREFER_ON_DEVICE)
46+
)
47+
48+
init {
49+
checkAndDownloadModel()
50+
}
51+
52+
private fun checkAndDownloadModel() {
53+
viewModelScope.launch {
54+
try {
55+
val status = FirebaseAIOnDevice.checkStatus()
56+
updateStatus(status)
57+
58+
if (status == OnDeviceModelStatus.DOWNLOADABLE) {
59+
FirebaseAIOnDevice.download().collect { downloadStatus ->
60+
when (downloadStatus) {
61+
is DownloadStatus.DownloadStarted -> {
62+
_uiState.update { it.copy(modelStatus = "Downloading model...") }
63+
}
64+
65+
is DownloadStatus.DownloadInProgress -> {
66+
val progress = downloadStatus.totalBytesDownloaded
67+
_uiState.update { it.copy(modelStatus = "Downloading: $progress bytes downloaded") }
68+
}
69+
70+
is DownloadStatus.DownloadCompleted -> {
71+
_uiState.update { it.copy(modelStatus = "Model ready") }
72+
}
73+
74+
is DownloadStatus.DownloadFailed -> {
75+
_uiState.update {
76+
it.copy(
77+
modelStatus = "Download failed", errorMessage = "Model download failed"
78+
)
79+
}
80+
}
81+
}
82+
}
83+
}
84+
} catch (e: Exception) {
85+
_uiState.update { it.copy(modelStatus = "Error checking status", errorMessage = e.message) }
86+
}
87+
}
88+
}
89+
90+
private fun updateStatus(status: OnDeviceModelStatus) {
91+
val statusText = when (status) {
92+
OnDeviceModelStatus.AVAILABLE -> "Model available"
93+
OnDeviceModelStatus.DOWNLOADABLE -> "Model downloadable"
94+
OnDeviceModelStatus.DOWNLOADING -> "Model downloading..."
95+
OnDeviceModelStatus.UNAVAILABLE -> "On-device model unavailable"
96+
else -> "Unknown"
97+
}
98+
_uiState.update { it.copy(modelStatus = statusText) }
99+
}
100+
101+
fun scanReceipt(bitmap: Bitmap) {
102+
viewModelScope.launch {
103+
_uiState.update { it.copy(isScanning = true, errorMessage = null) }
104+
try {
105+
val prompt = content {
106+
image(bitmap)
107+
text(
108+
"""
109+
Extract the store name and the total price from this receipt.
110+
Output only in JSON format containg 2 fields '{name,price}'.
111+
Do not include any currency signs or backticks or any text around it.
112+
Use dots for decimals.
113+
Examples:
114+
- {"name": "FakeStore", "price": "2.0"}
115+
- {"name": "SomeMarket", "price": "3.5"}
116+
""".trimIndent()
117+
)
118+
}
119+
120+
val response = model.generateContent(prompt)
121+
val text = response.text
122+
val inferenceMode = if (response.inferenceSource == InferenceSource.ON_DEVICE) {
123+
"On-device"
124+
} else {
125+
"Cloud"
126+
}
127+
Log.d("HybridVM", "$inferenceMode response: $text")
128+
if (text != null) {
129+
parseAndAddExpense(text, inferenceMode)
130+
} else {
131+
_uiState.update { it.copy(errorMessage = "Could not extract data") }
132+
}
133+
} catch (e: Exception) {
134+
_uiState.update { it.copy(errorMessage = "Error: ${e.message}") }
135+
} finally {
136+
_uiState.update { it.copy(isScanning = false) }
137+
}
138+
}
139+
}
140+
141+
private fun parseAndAddExpense(text: String, inferenceMode: String) {
142+
val json = text
143+
// The on-device model sometimes outputs backticks, so we remove those
144+
.replace("```json", "")
145+
.replace("```", "")
146+
try {
147+
val newExpense = Json.decodeFromString<Expense>(json).copy(inferenceMode = inferenceMode)
148+
_uiState.update { it.copy(expenses = it.expenses + newExpense) }
149+
} catch (e: Exception) {
150+
_uiState.update { it.copy(errorMessage = e.localizedMessage) }
151+
}
152+
}
153+
}

0 commit comments

Comments
 (0)