Skip to content

Commit 2a95e58

Browse files
Google AI Edge Gallerycopybara-github
authored andcommitted
Add a simple local test for allowlisted model.
PiperOrigin-RevId: 775265777
1 parent d0989ad commit 2a95e58

File tree

2 files changed

+111
-4
lines changed

2 files changed

+111
-4
lines changed

Android/src/app/src/main/java/com/google/ai/edge/gallery/data/ModelAllowlist.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ data class AllowedModel(
5050
taskTypes.contains(TASK_LLM_CHAT.type.id) || taskTypes.contains(TASK_LLM_PROMPT_LAB.type.id)
5151
var configs: List<Config> = listOf()
5252
if (isLlmModel) {
53-
var defaultTopK: Int = defaultConfig.topK ?: DEFAULT_TOPK
54-
var defaultTopP: Float = defaultConfig.topP ?: DEFAULT_TOPP
55-
var defaultTemperature: Float = defaultConfig.temperature ?: DEFAULT_TEMPERATURE
56-
var defaultMaxToken = defaultConfig.maxTokens ?: 1024
53+
val defaultTopK: Int = defaultConfig.topK ?: DEFAULT_TOPK
54+
val defaultTopP: Float = defaultConfig.topP ?: DEFAULT_TOPP
55+
val defaultTemperature: Float = defaultConfig.temperature ?: DEFAULT_TEMPERATURE
56+
val defaultMaxToken = defaultConfig.maxTokens ?: 1024
5757
var accelerators: List<Accelerator> = DEFAULT_ACCELERATORS
5858
if (defaultConfig.accelerators != null) {
5959
val items = defaultConfig.accelerators.split(",")
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.ai.edge.gallery.data
18+
19+
import org.junit.Assert.assertEquals
20+
import org.junit.Assert.assertFalse
21+
import org.junit.Assert.assertTrue
22+
import org.junit.Test
23+
import org.junit.runner.RunWith
24+
import org.junit.runners.JUnit4
25+
26+
@RunWith(JUnit4::class)
27+
class ModelAllowlistTest {
28+
@Test
29+
fun toModel_success() {
30+
val modelName = "test_model"
31+
val modelId = "test_model_id"
32+
val modelFile = "test_model_file"
33+
val description = "test description"
34+
val sizeInBytes = 100L
35+
val version = "20250623"
36+
val topK = 10
37+
val topP = 0.5f
38+
val temperature = 0.1f
39+
val maxTokens = 1000
40+
val accelerators = "gpu,cpu"
41+
val taskTypes = listOf("llm_chat", "ask_image")
42+
val estimatedPeakMemoryInBytes = 300L
43+
44+
val allowedModel =
45+
AllowedModel(
46+
name = modelName,
47+
modelId = modelId,
48+
modelFile = modelFile,
49+
description = description,
50+
sizeInBytes = sizeInBytes,
51+
version = version,
52+
defaultConfig =
53+
DefaultConfig(
54+
topK = topK,
55+
topP = topP,
56+
temperature = temperature,
57+
maxTokens = maxTokens,
58+
accelerators = accelerators,
59+
),
60+
taskTypes = taskTypes,
61+
llmSupportImage = true,
62+
llmSupportAudio = true,
63+
estimatedPeakMemoryInBytes = estimatedPeakMemoryInBytes,
64+
)
65+
val model = allowedModel.toModel()
66+
67+
// Check that basic fields are set correctly.
68+
assertEquals(model.name, modelName)
69+
assertEquals(model.version, version)
70+
assertEquals(model.info, description)
71+
assertEquals(
72+
model.url,
73+
"https://huggingface.co/test_model_id/resolve/main/test_model_file?download=true",
74+
)
75+
assertEquals(model.sizeInBytes, sizeInBytes)
76+
assertEquals(model.estimatedPeakMemoryInBytes, estimatedPeakMemoryInBytes)
77+
assertEquals(model.downloadFileName, modelFile)
78+
assertFalse(model.showBenchmarkButton)
79+
assertFalse(model.showRunAgainButton)
80+
assertTrue(model.llmSupportImage)
81+
assertTrue(model.llmSupportAudio)
82+
83+
// Check that configs are set correctly.
84+
assertEquals(model.configs.size, 5)
85+
86+
// A label for showing max tokens (non-changeable).
87+
assertTrue(model.configs[0] is LabelConfig)
88+
assertEquals((model.configs[0] as LabelConfig).defaultValue, "$maxTokens")
89+
90+
// A slider for topK.
91+
assertTrue(model.configs[1] is NumberSliderConfig)
92+
assertEquals((model.configs[1] as NumberSliderConfig).defaultValue, topK.toFloat())
93+
94+
// A slider for topP.
95+
assertTrue(model.configs[2] is NumberSliderConfig)
96+
assertEquals((model.configs[2] as NumberSliderConfig).defaultValue, topP)
97+
98+
// A slider for temperature.
99+
assertTrue(model.configs[3] is NumberSliderConfig)
100+
assertEquals((model.configs[3] as NumberSliderConfig).defaultValue, temperature)
101+
102+
// A segmented button for accelerators.
103+
assertTrue(model.configs[4] is SegmentedButtonConfig)
104+
assertEquals((model.configs[4] as SegmentedButtonConfig).defaultValue, "GPU")
105+
assertEquals((model.configs[4] as SegmentedButtonConfig).options, listOf("GPU", "CPU"))
106+
}
107+
}

0 commit comments

Comments
 (0)