Skip to content
This repository was archived by the owner on Aug 7, 2025. It is now read-only.

Commit 7b1c600

Browse files
authored
Merge pull request #1122 from pytorch/issue_1107
Issue 1107
2 parents 3196b43 + 23b70eb commit 7b1c600

File tree

12 files changed

+221
-10
lines changed

12 files changed

+221
-10
lines changed

docs/FAQs.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,6 @@ You can use both s3 v2 and v4 signature URLs.
146146
Note: For v4 type, replace `&` characters in the model URL with their URL-encoded equivalent in the curl command, i.e. `%26`.
147147

148148
Relevant issues: [[#669](https://github.com/pytorch/serve/issues/669)]
149+
150+
### How to set a model's batch size on SageMaker? Key parameters for TorchServe performance tuning.
151+
[TorchServe performance tuning example](https://github.com/lxning/torchserve_perf/blob/master/torchserve_perf.ipynb)

docs/configuration.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,56 @@ By default, TorchServe uses all available GPUs for inference. Use `number_of_gpu
213213
* `metrics_format`: Use this to specify the metric report format. At present, the only supported and default value for this is `prometheus`
214214
This is used in conjunction with `enable_metrics_api` option above.
215215

216+
### Config model
217+
* `models`: Use this to set the configuration of each model. The value is presented in JSON format.
218+
```
219+
{
220+
"modelName": {
221+
"version": {
222+
"parameterName1": parameterValue1,
223+
"parameterName2": parameterValue2,
224+
"parameterNameN": parameterValueN,
225+
}
226+
}
227+
}
228+
```
229+
A model's parameters are defined in [model source code](https://github.com/pytorch/serve/blob/master/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java#L24)
230+
```
231+
minWorkers: the minimum number of workers of a model
232+
maxWorkers: the maximum number of workers of a model
233+
batchSize: the batch size of a model
234+
maxBatchDelay: the maximum delay in msec of a batch of a model
235+
responseTimeout: the timeout in msec of a model's response
236+
defaultVersion: the default version of a model
237+
marName: the mar file name of a model
238+
```
239+
A model's configuration example
240+
```properties
241+
models={\
242+
"noop": {\
243+
"1.0": {\
244+
"defaultVersion": true,\
245+
"marName": "noop.mar",\
246+
"minWorkers": 1,\
247+
"maxWorkers": 1,\
248+
"batchSize": 4,\
249+
"maxBatchDelay": 100,\
250+
"responseTimeout": 120\
251+
}\
252+
},\
253+
"vgg16": {\
254+
"1.0": {\
255+
"defaultVersion": true,\
256+
"marName": "vgg16.mar",\
257+
"minWorkers": 1,\
258+
"maxWorkers": 4,\
259+
"batchSize": 8,\
260+
"maxBatchDelay": 100,\
261+
"responseTimeout": 120\
262+
}\
263+
}\
264+
}
265+
```
216266

217267
### Other properties
218268

frontend/server/src/main/java/org/pytorch/serve/ModelServer.java

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,16 @@ private void initModelStore() throws InvalidSnapshotException, IOException {
195195
modelManager.updateModel(
196196
archive.getModelName(),
197197
archive.getModelVersion(),
198-
workers,
199-
workers,
198+
configManager.getJsonIntValue(
199+
archive.getModelName(),
200+
archive.getModelVersion(),
201+
Model.MIN_WORKERS,
202+
workers),
203+
configManager.getJsonIntValue(
204+
archive.getModelName(),
205+
archive.getModelVersion(),
206+
Model.MAX_WORKERS,
207+
workers),
200208
true,
201209
false);
202210
startupModels.add(archive.getModelName());
@@ -246,8 +254,16 @@ private void initModelStore() throws InvalidSnapshotException, IOException {
246254
modelManager.updateModel(
247255
archive.getModelName(),
248256
archive.getModelVersion(),
249-
workers,
250-
workers,
257+
configManager.getJsonIntValue(
258+
archive.getModelName(),
259+
archive.getModelVersion(),
260+
Model.MIN_WORKERS,
261+
workers),
262+
configManager.getJsonIntValue(
263+
archive.getModelName(),
264+
archive.getModelVersion(),
265+
Model.MAX_WORKERS,
266+
workers),
251267
true,
252268
false);
253269
startupModels.add(archive.getModelName());

frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
package org.pytorch.serve.util;
22

3+
import com.google.gson.JsonObject;
4+
import com.google.gson.reflect.TypeToken;
35
import io.netty.handler.ssl.SslContext;
46
import io.netty.handler.ssl.SslContextBuilder;
57
import io.netty.handler.ssl.util.SelfSignedCertificate;
68
import java.io.File;
79
import java.io.IOException;
810
import java.io.InputStream;
911
import java.lang.reflect.Field;
12+
import java.lang.reflect.Type;
1013
import java.net.InetAddress;
1114
import java.net.UnknownHostException;
1215
import java.nio.charset.StandardCharsets;
@@ -29,6 +32,7 @@
2932
import java.util.HashMap;
3033
import java.util.InvalidPropertiesFormatException;
3134
import java.util.List;
35+
import java.util.Map;
3236
import java.util.Properties;
3337
import java.util.Set;
3438
import java.util.regex.Matcher;
@@ -95,6 +99,7 @@ public final class ConfigManager {
9599
private static final String METRIC_TIME_INTERVAL = "metric_time_interval";
96100
private static final String ENABLE_ENVVARS_CONFIG = "enable_envvars_config";
97101
private static final String MODEL_SNAPSHOT = "model_snapshot";
102+
private static final String MODEL_CONFIG = "models";
98103
private static final String VERSION = "version";
99104

100105
// Variables which are local
@@ -119,6 +124,7 @@ public final class ConfigManager {
119124

120125
private static ConfigManager instance;
121126
private String hostName;
127+
private Map<String, Map<String, JsonObject>> modelConfig = new HashMap<>();
122128

123129
private ConfigManager(Arguments args) throws IOException {
124130
prop = new Properties();
@@ -215,6 +221,8 @@ private ConfigManager(Arguments args) throws IOException {
215221
// Environment variables have higher precedence over the config file variables
216222
setSystemVars();
217223
}
224+
225+
setModelConfig();
218226
}
219227

220228
public static String readFile(String path) throws IOException {
@@ -607,7 +615,9 @@ public String dumpConfigurations() {
607615
+ "\nEnable metrics API: "
608616
+ prop.getProperty(TS_ENABLE_METRICS_API, "true")
609617
+ "\nWorkflow Store: "
610-
+ (getWorkflowStore() == null ? "N/A" : getWorkflowStore());
618+
+ (getWorkflowStore() == null ? "N/A" : getWorkflowStore())
619+
+ "\nModel config: "
620+
+ prop.getProperty(MODEL_CONFIG, "N/A");
611621
}
612622

613623
public boolean useNativeIo() {
@@ -768,6 +778,43 @@ public void setInitialWorkerPort(int initialPort) {
768778
prop.setProperty(TS_INITIAL_WORKER_PORT, String.valueOf(initialPort));
769779
}
770780

781+
private void setModelConfig() {
782+
String modelConfigStr = prop.getProperty(MODEL_CONFIG, null);
783+
Type type = new TypeToken<Map<String, Map<String, JsonObject>>>() {}.getType();
784+
785+
if (modelConfigStr != null) {
786+
this.modelConfig = JsonUtils.GSON.fromJson(modelConfigStr, type);
787+
}
788+
}
789+
790+
public int getJsonIntValue(String modelName, String version, String element, int defaultVal) {
791+
int value = defaultVal;
792+
if (this.modelConfig.containsKey(modelName)) {
793+
Map<String, JsonObject> versionModel = this.modelConfig.get(modelName);
794+
JsonObject jsonObject = versionModel.getOrDefault(version, null);
795+
796+
if (jsonObject != null && jsonObject.get(element) != null) {
797+
try {
798+
value = jsonObject.get(element).getAsInt();
799+
if (value <= 0) {
800+
value = defaultVal;
801+
}
802+
} catch (ClassCastException | IllegalStateException e) {
803+
Logger.getRootLogger()
804+
.error(
805+
"Invalid value for model: "
806+
+ modelName
807+
+ ":"
808+
+ version
809+
+ ", parameter: "
810+
+ element);
811+
return defaultVal;
812+
}
813+
}
814+
}
815+
return value;
816+
}
817+
771818
public static final class Arguments {
772819

773820
private String tsConfigFile;

frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,25 @@ private Model createModel(
260260
int responseTimeout,
261261
boolean isWorkflowModel) {
262262
Model model = new Model(archive, configManager.getJobQueueSize());
263-
model.setBatchSize(batchSize);
264-
model.setMaxBatchDelay(maxBatchDelay);
265-
model.setResponseTimeout(responseTimeout);
263+
264+
model.setBatchSize(
265+
configManager.getJsonIntValue(
266+
archive.getModelName(),
267+
archive.getModelVersion(),
268+
Model.BATCH_SIZE,
269+
batchSize));
270+
model.setMaxBatchDelay(
271+
configManager.getJsonIntValue(
272+
archive.getModelName(),
273+
archive.getModelVersion(),
274+
Model.MAX_BATCH_DELAY,
275+
maxBatchDelay));
276+
model.setResponseTimeout(
277+
configManager.getJsonIntValue(
278+
archive.getModelName(),
279+
archive.getModelVersion(),
280+
Model.RESPONSE_TIMEOUT,
281+
responseTimeout));
266282
model.setWorkflowModel(isWorkflowModel);
267283

268284
return model;
@@ -383,6 +399,7 @@ public CompletableFuture<Integer> updateModel(
383399
throw new ModelVersionNotFoundException(
384400
"Model version: " + versionId + " does not exist for model: " + modelName);
385401
}
402+
386403
model.setMinWorkers(minWorkers);
387404
model.setMaxWorkers(maxWorkers);
388405
logger.debug("updateModel: {}, count: {}", modelName, minWorkers);

frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.pytorch.serve.util.ConfigManager;
4545
import org.pytorch.serve.util.ConnectorType;
4646
import org.pytorch.serve.util.JsonUtils;
47+
import org.pytorch.serve.wlm.Model;
4748
import org.testng.Assert;
4849
import org.testng.annotations.AfterClass;
4950
import org.testng.annotations.BeforeSuite;
@@ -200,7 +201,10 @@ public void testInitialWorkers() throws InterruptedException {
200201
DescribeModelResponse[] resp =
201202
JsonUtils.GSON.fromJson(TestUtils.getResult(), DescribeModelResponse[].class);
202203
Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
203-
Assert.assertEquals(resp[0].getMinWorkers(), configManager.getDefaultWorkers());
204+
Assert.assertEquals(
205+
resp[0].getMinWorkers(),
206+
configManager.getJsonIntValue(
207+
"noop", "1.11", Model.MIN_WORKERS, configManager.getDefaultWorkers()));
204208
}
205209

206210
@Test(

frontend/server/src/test/java/org/pytorch/serve/SnapshotTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public class SnapshotTest {
5353
public void beforeSuite()
5454
throws InterruptedException, IOException, GeneralSecurityException,
5555
InvalidSnapshotException {
56-
System.setProperty("tsConfigFile", "src/test/resources/config.properties");
56+
System.setProperty("tsConfigFile", "src/test/resources/config_snapshot.properties");
5757
FileUtils.cleanDirectory(new File(System.getProperty("LOG_LOCATION"), "config"));
5858

5959
ConfigManager.init(new ConfigManager.Arguments());

frontend/server/src/test/java/org/pytorch/serve/util/ConfigManagerTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ public void testNoEnvVars() throws ReflectiveOperationException, IOException {
9797
ConfigManager configManager = ConfigManager.getInstance();
9898
Assert.assertEquals("false", configManager.getEnableEnvVarsConfig());
9999
Assert.assertEquals(120, configManager.getDefaultResponseTimeout());
100+
Assert.assertEquals(4, configManager.getJsonIntValue("noop", "1.0", "batchSize", 1));
101+
Assert.assertEquals(4, configManager.getJsonIntValue("vgg16", "1.0", "maxWorkers", 1));
100102
modifyEnv("TS_DEFAULT_RESPONSE_TIMEOUT", "120");
101103
}
102104
}

frontend/server/src/test/resources/config.properties

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,19 @@ max_request_size=2047093252
2828
# blacklist_env_vars=.*USERNAME.*|.*PASSWORD.*
2929
# decode_input_request=true
3030
enable_envvars_config=true
31+
models={\
32+
"noop": {\
33+
"1.11": {\
34+
"defaultVersion": true,\
35+
"marName": "noop.mar",\
36+
"minWorkers": 1,\
37+
"maxWorkers": 1,\
38+
"batchSize": 4,\
39+
"maxBatchDelay": 100,\
40+
"responseTimeout": 120\
41+
}\
42+
}\
43+
}
3144
# default_service_handler=/path/to/service.py:handle
3245
# install_py_dep_per_model=false
3346
# enable_metrics_api=false
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# debug=true
2+
# vmargs=-Xmx128m -XX:-UseLargePages -XX:+UseG1GC -XX:MaxMetaspaceSize=32M -XX:MaxDirectMemorySize=10m -XX:+ExitOnOutOfMemoryError
3+
inference_address=https://127.0.0.1:8443
4+
management_address=https://127.0.0.1:8444
5+
metrics_address=https://127.0.0.1:8445
6+
# model_server_home=../..
7+
model_store=../archive/src/test/resources/models
8+
load_models=noop.mar
9+
# number_of_netty_threads=0
10+
# netty_client_threads=0
11+
# default_workers_per_model=0
12+
# job_queue_size=100
13+
# plugins_path=/tmp/plugins
14+
async_logging=true
15+
default_response_timeout=120
16+
unregister_model_timeout=120
17+
# number_of_gpu=1
18+
# cors_allowed_origin
19+
# cors_allowed_methods
20+
# cors_allowed_headers
21+
# keystore=src/test/resources/keystore.p12
22+
# keystore_pass=changeit
23+
# keystore_type=PKCS12
24+
private_key_file=src/test/resources/key.pem
25+
certificate_file=src/test/resources/certs.pem
26+
max_response_size=2047093252
27+
max_request_size=2047093252
28+
# blacklist_env_vars=.*USERNAME.*|.*PASSWORD.*
29+
# decode_input_request=true
30+
enable_envvars_config=true
31+
# default_service_handler=/path/to/service.py:handle
32+
# install_py_dep_per_model=false
33+
# enable_metrics_api=false
34+
workflow_store=../archive/src/test/resources/workflows

0 commit comments

Comments
 (0)