Skip to content
This repository was archived by the owner on Aug 7, 2025. It is now read-only.
Merged
3 changes: 3 additions & 0 deletions docs/FAQs.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,6 @@ You can use both s3 v2 and v4 signature URLs.
Note: For v4 type replace `&` characters in model url with its URL encoding character in the curl command i.e.`%26`.

Relevant issues: [[#669](https://github.com/pytorch/serve/issues/669)]

### How to set a model's batch size on SageMaker? Key parameters for TorchServe performance tuning.
[TorchServe performance tuning example](https://github.com/lxning/torchserve_perf/blob/master/torchserve_perf.ipynb)
50 changes: 50 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,56 @@ By default, TorchServe uses all available GPUs for inference. Use `number_of_gpu
* `metrics_format` : Use this to specify the metric report format. At present, the only supported and default value for this is `prometheus`
This is used in conjunction with `enable_metrics_api` option above.

### Config model
* `models`: Use this to set the configuration of each model. The value is presented in JSON format.
```
{
"modelName": {
"version": {
"parameterName1": parameterValue1,
"parameterName2": parameterValue2,
"parameterNameN": parameterValueN,
}
}
}
```
A model's parameters are defined in [model source code](https://github.com/pytorch/serve/blob/master/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java#L24)
```
minWorkers: the minimum number of workers of a model
maxWorkers: the maximum number of workers of a model
batchSize: the batch size of a model
maxBatchDelay: the maximum delay in msec of a batch of a model
responseTimeout: the timeout in msec of a model's response
defaultVersion: the default version of a model
marName: the mar file name of a model
```
A model's configuration example
```properties
models={\
"noop": {\
"1.0": {\
"defaultVersion": true,\
"marName": "noop.mar",\
"minWorkers": 1,\
"maxWorkers": 1,\
"batchSize": 4,\
"maxBatchDelay": 100,\
"responseTimeout": 120\
}\
},\
"vgg16": {\
"1.0": {\
"defaultVersion": true,\
"marName": "vgg16.mar",\
"minWorkers": 1,\
"maxWorkers": 4,\
"batchSize": 8,\
"maxBatchDelay": 100,\
"responseTimeout": 120\
}\
}\
}
```

### Other properties

Expand Down
24 changes: 20 additions & 4 deletions frontend/server/src/main/java/org/pytorch/serve/ModelServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,16 @@ private void initModelStore() throws InvalidSnapshotException, IOException {
modelManager.updateModel(
archive.getModelName(),
archive.getModelVersion(),
workers,
workers,
configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.MIN_WORKERS,
workers),
configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.MAX_WORKERS,
workers),
true,
false);
startupModels.add(archive.getModelName());
Expand Down Expand Up @@ -246,8 +254,16 @@ private void initModelStore() throws InvalidSnapshotException, IOException {
modelManager.updateModel(
archive.getModelName(),
archive.getModelVersion(),
workers,
workers,
configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.MIN_WORKERS,
workers),
configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.MAX_WORKERS,
workers),
true,
false);
startupModels.add(archive.getModelName());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package org.pytorch.serve.util;

import com.google.gson.JsonObject;
import com.google.gson.reflect.TypeToken;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Type;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
Expand All @@ -29,6 +32,7 @@
import java.util.HashMap;
import java.util.InvalidPropertiesFormatException;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Matcher;
Expand Down Expand Up @@ -95,6 +99,7 @@ public final class ConfigManager {
private static final String METRIC_TIME_INTERVAL = "metric_time_interval";
private static final String ENABLE_ENVVARS_CONFIG = "enable_envvars_config";
private static final String MODEL_SNAPSHOT = "model_snapshot";
private static final String MODEL_CONFIG = "models";
private static final String VERSION = "version";

// Variables which are local
Expand All @@ -119,6 +124,7 @@ public final class ConfigManager {

private static ConfigManager instance;
private String hostName;
private Map<String, Map<String, JsonObject>> modelConfig = new HashMap<>();

private ConfigManager(Arguments args) throws IOException {
prop = new Properties();
Expand Down Expand Up @@ -215,6 +221,8 @@ private ConfigManager(Arguments args) throws IOException {
// Environment variables have higher precedence over the config file variables
setSystemVars();
}

setModelConfig();
}

public static String readFile(String path) throws IOException {
Expand Down Expand Up @@ -607,7 +615,9 @@ public String dumpConfigurations() {
+ "\nEnable metrics API: "
+ prop.getProperty(TS_ENABLE_METRICS_API, "true")
+ "\nWorkflow Store: "
+ (getWorkflowStore() == null ? "N/A" : getWorkflowStore());
+ (getWorkflowStore() == null ? "N/A" : getWorkflowStore())
+ "\nModel config: "
+ prop.getProperty(MODEL_CONFIG, "N/A");
}

public boolean useNativeIo() {
Expand Down Expand Up @@ -768,6 +778,43 @@ public void setInitialWorkerPort(int initialPort) {
prop.setProperty(TS_INITIAL_WORKER_PORT, String.valueOf(initialPort));
}

/**
 * Parses the "models" property (a JSON object mapping model name -> version ->
 * per-model parameters) into {@link #modelConfig}. Leaves the existing (empty)
 * map untouched when the property is not set.
 */
private void setModelConfig() {
    Type configType = new TypeToken<Map<String, Map<String, JsonObject>>>() {}.getType();
    String rawConfig = prop.getProperty(MODEL_CONFIG, null);
    if (rawConfig == null) {
        return;
    }
    this.modelConfig = JsonUtils.GSON.fromJson(rawConfig, configType);
}

/**
 * Looks up an integer parameter for a specific model version in the parsed
 * "models" configuration.
 *
 * @param modelName name of the model in the "models" JSON config
 * @param version model version key under the model entry
 * @param element parameter name, e.g. "minWorkers" or "batchSize"
 * @param defaultVal value returned when the model/version/element is missing,
 *     non-positive, or not a valid integer
 * @return the configured positive integer value, otherwise {@code defaultVal}
 */
public int getJsonIntValue(String modelName, String version, String element, int defaultVal) {
    Map<String, JsonObject> versionModel = this.modelConfig.get(modelName);
    if (versionModel == null) {
        return defaultVal;
    }
    JsonObject jsonObject = versionModel.getOrDefault(version, null);
    if (jsonObject == null || jsonObject.get(element) == null) {
        return defaultVal;
    }
    try {
        int value = jsonObject.get(element).getAsInt();
        // Non-positive values (e.g. 0 workers) are treated as misconfiguration
        // and fall back to the caller-supplied default.
        if (value <= 0) {
            return defaultVal;
        }
        return value;
    } catch (ClassCastException
            | IllegalStateException
            | NumberFormatException
            | UnsupportedOperationException e) {
        // getAsInt() throws NumberFormatException for non-numeric strings and
        // UnsupportedOperationException for JSON objects/arrays; the original
        // code only caught ClassCastException/IllegalStateException, so a
        // malformed config value crashed startup instead of using the default.
        Logger.getRootLogger()
                .error(
                        "Invalid value for model: "
                                + modelName
                                + ":"
                                + version
                                + ", parameter: "
                                + element);
        return defaultVal;
    }
}

public static final class Arguments {

private String tsConfigFile;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,25 @@ private Model createModel(
int responseTimeout,
boolean isWorkflowModel) {
Model model = new Model(archive, configManager.getJobQueueSize());
model.setBatchSize(batchSize);
model.setMaxBatchDelay(maxBatchDelay);
model.setResponseTimeout(responseTimeout);

model.setBatchSize(
configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.BATCH_SIZE,
batchSize));
model.setMaxBatchDelay(
configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.MAX_BATCH_DELAY,
maxBatchDelay));
model.setResponseTimeout(
configManager.getJsonIntValue(
archive.getModelName(),
archive.getModelVersion(),
Model.RESPONSE_TIMEOUT,
responseTimeout));
model.setWorkflowModel(isWorkflowModel);

return model;
Expand Down Expand Up @@ -383,6 +399,7 @@ public CompletableFuture<Integer> updateModel(
throw new ModelVersionNotFoundException(
"Model version: " + versionId + " does not exist for model: " + modelName);
}

model.setMinWorkers(minWorkers);
model.setMaxWorkers(maxWorkers);
logger.debug("updateModel: {}, count: {}", modelName, minWorkers);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.ConnectorType;
import org.pytorch.serve.util.JsonUtils;
import org.pytorch.serve.wlm.Model;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeSuite;
Expand Down Expand Up @@ -200,7 +201,10 @@ public void testInitialWorkers() throws InterruptedException {
DescribeModelResponse[] resp =
JsonUtils.GSON.fromJson(TestUtils.getResult(), DescribeModelResponse[].class);
Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
Assert.assertEquals(resp[0].getMinWorkers(), configManager.getDefaultWorkers());
Assert.assertEquals(
resp[0].getMinWorkers(),
configManager.getJsonIntValue(
"noop", "1.11", Model.MIN_WORKERS, configManager.getDefaultWorkers()));
}

@Test(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public class SnapshotTest {
public void beforeSuite()
throws InterruptedException, IOException, GeneralSecurityException,
InvalidSnapshotException {
System.setProperty("tsConfigFile", "src/test/resources/config.properties");
System.setProperty("tsConfigFile", "src/test/resources/config_snapshot.properties");
FileUtils.cleanDirectory(new File(System.getProperty("LOG_LOCATION"), "config"));

ConfigManager.init(new ConfigManager.Arguments());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ public void testNoEnvVars() throws ReflectiveOperationException, IOException {
ConfigManager configManager = ConfigManager.getInstance();
Assert.assertEquals("false", configManager.getEnableEnvVarsConfig());
Assert.assertEquals(120, configManager.getDefaultResponseTimeout());
Assert.assertEquals(4, configManager.getJsonIntValue("noop", "1.0", "batchSize", 1));
Assert.assertEquals(4, configManager.getJsonIntValue("vgg16", "1.0", "maxWorkers", 1));
modifyEnv("TS_DEFAULT_RESPONSE_TIMEOUT", "120");
}
}
13 changes: 13 additions & 0 deletions frontend/server/src/test/resources/config.properties
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@ max_request_size=2047093252
# blacklist_env_vars=.*USERNAME.*|.*PASSWORD.*
# decode_input_request=true
enable_envvars_config=true
models={\
"noop": {\
"1.11": {\
"defaultVersion": true,\
"marName": "noop.mar",\
"minWorkers": 1,\
"maxWorkers": 1,\
"batchSize": 4,\
"maxBatchDelay": 100,\
"responseTimeout": 120\
}\
}\
}
# default_service_handler=/path/to/service.py:handle
# install_py_dep_per_model=false
# enable_metrics_api=false
Expand Down
34 changes: 34 additions & 0 deletions frontend/server/src/test/resources/config_snapshot.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# debug=true
# vmargs=-Xmx128m -XX:-UseLargePages -XX:+UseG1GC -XX:MaxMetaspaceSize=32M -XX:MaxDirectMemorySize=10m -XX:+ExitOnOutOfMemoryError
inference_address=https://127.0.0.1:8443
management_address=https://127.0.0.1:8444
metrics_address=https://127.0.0.1:8445
# model_server_home=../..
model_store=../archive/src/test/resources/models
load_models=noop.mar
# number_of_netty_threads=0
# netty_client_threads=0
# default_workers_per_model=0
# job_queue_size=100
# plugins_path=/tmp/plugins
async_logging=true
default_response_timeout=120
unregister_model_timeout=120
# number_of_gpu=1
# cors_allowed_origin
# cors_allowed_methods
# cors_allowed_headers
# keystore=src/test/resources/keystore.p12
# keystore_pass=changeit
# keystore_type=PKCS12
private_key_file=src/test/resources/key.pem
certificate_file=src/test/resources/certs.pem
max_response_size=2047093252
max_request_size=2047093252
# blacklist_env_vars=.*USERNAME.*|.*PASSWORD.*
# decode_input_request=true
enable_envvars_config=true
# default_service_handler=/path/to/service.py:handle
# install_py_dep_per_model=false
# enable_metrics_api=false
workflow_store=../archive/src/test/resources/workflows
24 changes: 24 additions & 0 deletions frontend/server/src/test/resources/config_test_env.properties
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,27 @@ max_request_size=10485760
# enable_envvars_config=false
# decode_input_request=true
workflow_store=../archive/src/test/resources/workflows
models={\
"noop": {\
"1.0": {\
"defaultVersion": true,\
"marName": "noop.mar",\
"minWorkers": 1,\
"maxWorkers": 1,\
"batchSize": 4,\
"maxBatchDelay": 100,\
"responseTimeout": 120\
}\
},\
"vgg16": {\
"1.0": {\
"defaultVersion": true,\
"marName": "vgg16.mar",\
"minWorkers": 1,\
"maxWorkers": 4,\
"batchSize": 8,\
"maxBatchDelay": 100,\
"responseTimeout": 120\
}\
}\
}
1 change: 1 addition & 0 deletions ts/model_service_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def load_model(load_model_request):
batch_size = None
if "batchSize" in load_model_request:
batch_size = int(load_model_request["batchSize"])
logging.info("model_name: %s, batchSize: %d", model_name, batch_size)

gpu = None
if "gpu" in load_model_request:
Expand Down