Skip to content

Commit 98a8e28

Browse files
authored
Add server tests (#142)
* add enable jax profiler to run_server * add run_server.py unit tests * add ring buffer test
1 parent 476593b commit 98a8e28

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed

tests/test_run_server.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
from unittest.mock import patch, MagicMock
17+
from absl import app
18+
from absl.testing import flagsaver
19+
from parameterized import parameterized, param
20+
21+
22+
class MockServer(MagicMock):
23+
24+
def run(self, **kwargs):
25+
return self
26+
27+
def wait_for_termination(self):
28+
raise SystemExit("Successfully exited test.")
29+
30+
31+
def mock_engine(**kwargs):
32+
return kwargs
33+
34+
35+
class ServerRunTest(unittest.TestCase):
36+
37+
def reset_flags(self):
38+
flagsaver.restore_flag_values(self.original)
39+
40+
def setup(self):
41+
from run_server import flags
42+
43+
FLAGS = flags.FLAGS
44+
self.original = flagsaver.save_flag_values()
45+
return FLAGS
46+
47+
@parameterized.expand(
48+
[
49+
param(["test1", "--model_name", "llama-3"], "llama-3"),
50+
param(["test2", "--model_name", "llama-2"], "llama-2"),
51+
param(["test3", "--model_name", "mixtral"], "mixtral"),
52+
param(["test4", "--model_name", "gemma"], "gemma"),
53+
]
54+
)
55+
@patch("jetstream_pt.engine.create_pytorch_engine", mock_engine)
56+
@patch("jetstream.core.server_lib.run", MockServer().run)
57+
def test_no_change_from_defaults(self, args, expected):
58+
"""test defaults remain unchanged when launching a server for different models.
59+
60+
Args:
61+
args (List): List to simulate sys.argv with dummy first entry at index 0.
62+
expected (str): model_name flag value to inspect
63+
"""
64+
from run_server import main
65+
66+
FLAGS = self.setup()
67+
with self.assertRaisesRegex(SystemExit, "Successfully exited test."):
68+
app.run(main, args)
69+
70+
# run_server
71+
self.assertEqual(FLAGS.port, 9000)
72+
self.assertEqual(FLAGS.threads, 64)
73+
self.assertEqual(FLAGS.config, "InterleavedCPUTestServer")
74+
self.assertEqual(FLAGS.prometheus_port, 0)
75+
self.assertEqual(FLAGS.enable_jax_profiler, False)
76+
self.assertEqual(FLAGS.jax_profiler_port, 9999)
77+
78+
# quantization configs
79+
self.assertEqual(FLAGS.quantize_weights, False)
80+
self.assertEqual(FLAGS.quantize_activation, False)
81+
self.assertEqual(FLAGS.quantize_type, "int8_per_channel")
82+
self.assertEqual(FLAGS.quantize_kv_cache, False)
83+
84+
# engine configs
85+
self.assertEqual(FLAGS.tokenizer_path, None)
86+
self.assertEqual(FLAGS.checkpoint_path, None)
87+
self.assertEqual(FLAGS.bf16_enable, True)
88+
self.assertEqual(FLAGS.context_length, 1024)
89+
self.assertEqual(FLAGS.batch_size, 32)
90+
self.assertEqual(FLAGS.size, "tiny")
91+
self.assertEqual(FLAGS.max_cache_length, 1024)
92+
self.assertEqual(FLAGS.shard_on_batch, False)
93+
self.assertEqual(FLAGS.sharding_config, "")
94+
self.assertEqual(FLAGS.ragged_mha, False)
95+
self.assertEqual(FLAGS.starting_position, 512)
96+
self.assertEqual(FLAGS.temperature, 1.0)
97+
self.assertEqual(FLAGS.sampling_algorithm, "greedy")
98+
self.assertEqual(FLAGS.nucleus_topp, 0.0)
99+
self.assertEqual(FLAGS.topk, 0)
100+
self.assertEqual(FLAGS.ring_buffer, True)
101+
102+
# profiling configs
103+
self.assertEqual(FLAGS.profiling_prefill, False)
104+
self.assertEqual(FLAGS.profiling_output, "")
105+
106+
# model_name flag updates
107+
self.assertEqual(FLAGS.model_name, expected)
108+
109+
# reset back to original flags
110+
self.reset_flags()
111+
112+
@parameterized.expand([param(["test1", "--model_name", "llama3"])])
113+
@patch("jetstream_pt.engine.create_pytorch_engine", mock_engine)
114+
def test_call_server_object(self, args):
115+
"""tests whether running the main script from absl.app.run launches a server and waits for termination
116+
117+
Args:
118+
args (List): List to simulate sys.argv with dummy first entry at index 0.
119+
"""
120+
with patch(
121+
"jetstream.core.server_lib.run", autospec=MockServer().run
122+
) as mock_server:
123+
from run_server import main
124+
125+
FLAGS = self.setup()
126+
with self.assertRaises(SystemExit):
127+
app.run(main, args)
128+
self.assertEqual(mock_server.call_count, 1)
129+
self.assertEqual(
130+
mock_server.return_value.wait_for_termination.call_count, 1
131+
)
132+
133+
# reset back to original flags
134+
self.reset_flags()
135+
136+
137+
if __name__ == "__main__":
138+
unittest.main()

0 commit comments

Comments
 (0)