
Commit 44aa3b6

lzhangzz authored and lvhan028 committed

[Refactor] Support batch inference with shape clustering (#1733)

* refactor `NetModule`
* name
* fix sorting
* fix indices

(cherry picked from commit f5a05b5)

1 parent ef64224 commit 44aa3b6
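The change teaches `NetModule` to serve a mixed-shape request: `Forward` collects the per-sample input tensors, clusters the samples by their concatenated input shapes, assembles one batch per distinct shape (capped at the `batch_size` from the model meta, stored as `max_batch_size_`), and runs the backend once per batch. A minimal standalone sketch of the grouping idea, with plain std types standing in for mmdeploy's `Tensor`/`TensorShape` and a hypothetical `ClusterByShape` helper that is not part of the commit:

// Standalone sketch, assuming one precomputed shape key per sample.
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

using Shape = std::vector<int64_t>;  // stand-in for a concatenated shape key

// Group sample indices so that samples with an identical shape key land in
// the same batch, and no batch exceeds max_batch_size.
std::vector<std::vector<int>> ClusterByShape(const std::vector<Shape>& keys,
                                             size_t max_batch_size) {
  std::vector<int> order(keys.size());
  std::iota(order.begin(), order.end(), 0);  // 0, 1, ..., n-1
  // sort indices by key so that equal shapes become adjacent
  std::sort(order.begin(), order.end(),
            [&keys](int i, int j) { return keys[i] < keys[j]; });
  std::vector<std::vector<int>> batches;
  for (size_t k = 0; k < order.size(); ++k) {
    const bool new_shape = k == 0 || keys[order[k]] != keys[order[k - 1]];
    if (new_shape || batches.back().size() == max_batch_size) {
      batches.emplace_back();  // start a new batch on a new shape or when full
    }
    batches.back().push_back(order[k]);
  }
  return batches;
}

The committed code reaches the same grouping in two passes (sort + unique over sample indices to enumerate the distinct shapes, then one scan per shape) and degrades to single-sample batches in `SaveBatch` when tensors cannot be stacked along the first axis.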

File tree

1 file changed: +139 -46 lines changed

csrc/mmdeploy/net/net_module.cpp

Lines changed: 139 additions & 46 deletions
@@ -1,7 +1,9 @@
 // Copyright (c) OpenMMLab. All rights reserved.
 
-#include "net_module.h"
+#include "mmdeploy/net/net_module.h"
 
+#include <algorithm>
+#include <numeric>
 #include <thread>
 
 #include "mmdeploy/archive/value_archive.h"
@@ -31,6 +33,11 @@ struct NetModule::Impl {
       is_profiling_ = true;
     }
     auto model = context["model"].get<Model>();
+    for (const auto& meta : model.meta().models) {
+      if (meta.name == name) {
+        max_batch_size_ = meta.batch_size;
+      }
+    }
     OUTCOME_TRY(auto config, model.GetModelConfig(name));
     device_ = context.value("device", Device{"cpu"});
     stream_ = context.value("stream", Stream::GetDefault(device_));
@@ -78,112 +85,197 @@ struct NetModule::Impl {
     return success();
   }
 
-  Result<TensorShape> InferInputShape(const vector<Tensor>& input) {
+  static Result<TensorShape> InferBatchShape(const vector<Tensor>& input) {
     auto batch_size = input.size();
     auto& exemplar = input.front();
     auto shape = exemplar.shape();
     if (batch_size == 1) {
       return shape;
     }
     if (shape[0] != 1) {
-      MMDEPLOY_ERROR("unsupported shape for batch assemble: {}", shape);
+      MMDEPLOY_WARN("unsupported shape for batch assemble: {}", shape);
       return Status(eNotSupported);
     }
     for (int i = 1; i < input.size(); ++i) {
       auto& sample = input[i];
       if (sample.shape() != shape) {
-        MMDEPLOY_ERROR("shapes are not consistent across the batch");
+        MMDEPLOY_WARN("shapes are not consistent across the batch");
         return Status(eNotSupported);
       }
     }
     shape[0] = static_cast<int64_t>(batch_size);
     return shape;
   }
 
-  Result<vector<TensorShape> > InferInputShape(const vector<vector<Tensor> >& inputs) {
+  static Result<vector<TensorShape>> InferBatchShape(const vector<vector<Tensor>>& inputs) {
     vector<TensorShape> shapes;
     shapes.reserve(inputs.size());
     for (const auto& input : inputs) {
-      OUTCOME_TRY(auto shape, InferInputShape(input));
+      OUTCOME_TRY(auto shape, InferBatchShape(input));
       shapes.push_back(std::move(shape));
     }
     return shapes;
   }
 
-  Result<std::vector<Output> > Forward(const std::vector<Input>& input) {
-    // auto t0 = std::chrono::high_resolution_clock::now();
-    //
-    auto batch_size = static_cast<int>(input.size());
-
-    std::vector<std::vector<Tensor> > input_samples;
+  Result<vector<vector<Tensor>>> CollectInputTensors(const vector<Input>& inputs) {
+    vector<vector<Tensor>> input_samples;
     input_samples.reserve(inputs_.size());
     for (const auto& t : inputs_) {
       auto name = input_mapping_.at(t.name());
-      std::vector<Tensor> tmp;
-      tmp.reserve(input.size());
-      for (int i = 0; i < input.size(); ++i) {
-        auto& sample = input[i];
+      auto& tmp = input_samples.emplace_back();
+      for (const auto& sample : inputs) {
         if (auto it = sample.find(name); it != sample.end()) {
           tmp.push_back(it->second);
         } else {
-          MMDEPLOY_ERROR("sample {} missing key {}", i, name);
+          MMDEPLOY_ERROR("sample {} missing key {}", &sample - inputs.data(), name);
          return Status(eInvalidArgument);
        }
      }
-      input_samples.push_back(std::move(tmp));
+    }
+    return input_samples;
+  }
+
+  void SaveBatch(vector<vector<Tensor>> samples, vector<int> indices,
+                 vector<vector<vector<Tensor>>>& batch_tensors,
+                 vector<vector<TensorShape>>& batch_shapes,
+                 vector<vector<int>>& batch_sample_idxs) const {
+    if (auto maybe_batch_shape = InferBatchShape(samples)) {
+      batch_shapes.push_back(maybe_batch_shape.value());
+      batch_tensors.push_back(std::move(samples));
+      batch_sample_idxs.push_back(std::move(indices));
+    } else {
+      // cannot assemble batch, do it one by one
+      for (int k = 0; k < indices.size(); ++k) {
+        auto& shapes = batch_shapes.emplace_back();
+        auto& batch = batch_tensors.emplace_back(inputs_.size());
+        batch_sample_idxs.push_back({indices[k]});
+        for (int j = 0; j < inputs_.size(); ++j) {
+          shapes.push_back(samples[j][k].shape());
+          batch[j].push_back(std::move(samples[j][k]));
+        }
+      }
+    }
+  }
+
+  void SamplesToBatches(const vector<vector<Tensor>>& input_samples, size_t n_samples,
+                        vector<vector<vector<Tensor>>>& batch_tensors,
+                        vector<vector<TensorShape>>& batch_shapes,
+                        vector<vector<int>>& batch_sample_idxs) const {
+    // concat all shapes in samples to make comparison easier
+    vector<vector<int64_t>> concat_shapes;
+    concat_shapes.reserve(n_samples);
+    for (size_t i = 0; i < n_samples; ++i) {
+      auto& shape = concat_shapes.emplace_back();
+      for (const auto& input : input_samples) {
+        shape.insert(shape.end(), input[i].shape().begin(), input[i].shape().end());
+      }
+    }
+
+    // cluster samples by concatenated shapes
+    vector<int> shape_idxs(concat_shapes.size());
+    std::iota(shape_idxs.begin(), shape_idxs.end(), 0);
+    std::sort(shape_idxs.begin(), shape_idxs.end(),
+              [&concat_shapes](int i, int j) { return concat_shapes[i] < concat_shapes[j]; });
+    shape_idxs.erase(std::unique(shape_idxs.begin(), shape_idxs.end(),
+                                 [&concat_shapes](int i, int j) {
+                                   return concat_shapes[i] == concat_shapes[j];
+                                 }),
+                     shape_idxs.end());
+
+    // generate batches of samples with equal shapes, limit the batch size by max_batch_size_
+    for (const auto ref_shape_idx : shape_idxs) {
+      const auto& ref_shape = concat_shapes[ref_shape_idx];
+      vector<vector<Tensor>> samples(inputs_.size());
+      vector<int> indices;
+      for (size_t i = 0; i < concat_shapes.size(); ++i) {
+        if (concat_shapes[i] == ref_shape) {
+          for (size_t j = 0; j < inputs_.size(); ++j) {
+            samples[j].push_back(input_samples[j][i]);
+          }
+          indices.push_back(static_cast<int>(i));
+          if (indices.size() == max_batch_size_) {
+            SaveBatch(std::move(samples), std::move(indices), batch_tensors, batch_shapes,
+                      batch_sample_idxs);
+            samples = vector<vector<Tensor>>(inputs_.size());
+            indices = {};
+          }
+        }
+      }
+      if (!indices.empty()) {
+        SaveBatch(std::move(samples), std::move(indices), batch_tensors, batch_shapes,
+                  batch_sample_idxs);
+      }
+    }
+  }
+
+  Result<vector<Output>> Forward(const vector<Input>& inputs) {
+    OUTCOME_TRY(auto input_samples, CollectInputTensors(inputs));
+
+    vector<vector<vector<Tensor>>> batch_tensors;
+    vector<vector<TensorShape>> batch_shapes;
+    vector<vector<int>> batch_sample_indices;
+
+    SamplesToBatches(input_samples, inputs.size(), batch_tensors, batch_shapes,
+                     batch_sample_indices);
+
+    vector<Output> outputs(inputs.size());
+    for (size_t i = 0; i < batch_tensors.size(); ++i) {
+      OUTCOME_TRY(net_->Reshape(batch_shapes[i]));
+      OUTCOME_TRY(CopyInputTensors(batch_tensors[i], batch_shapes[i]));
+      OUTCOME_TRY(net_->Forward());
+      OUTCOME_TRY(CopyOutputTensors(batch_sample_indices[i], outputs));
+      if (i + 1 < batch_tensors.size()) {  // sync if not the last batch
+        OUTCOME_TRY(stream_.Wait());
+      }
     }
 
-    // 1. calculate input shape
-    OUTCOME_TRY(auto input_shapes, InferInputShape(input_samples));
+    if (is_profiling_) {
+      OUTCOME_TRY(stream_.Wait());
+    }
 
-    // 2. call backend's reshape
-    OUTCOME_TRY(net_->Reshape(input_shapes));
+    return outputs;
+  }
 
-    // 3. fill input tensor
+  Result<void> CopyInputTensors(const vector<vector<Tensor>>& batch,
+                                const vector<TensorShape>& shapes) const {
     for (int i = 0; i < inputs_.size(); ++i) {
-      auto& src = input_samples[i];
+      auto& src = batch[i];
       auto& dst = inputs_[i];
-      if (dst.shape() != input_shapes[i]) {
-        MMDEPLOY_ERROR("inconsistent input shape, expect {}, got {}", input_shapes[i], dst.shape());
+      if (dst.shape() != shapes[i]) {
+        MMDEPLOY_ERROR("inconsistent input shape, expect {}, got {}", shapes[i], dst.shape());
         return Status(eFail);
       }
       if (src.size() > 1) {
         for (int j = 0; j < src.size(); ++j) {
-          auto slice = dst.Slice(j);
-          OUTCOME_TRY(src[j].CopyTo(slice, stream_));
+          OUTCOME_TRY(dst.Slice(j).CopyFrom(src[j], stream_));
         }
       } else {
-        OUTCOME_TRY(src[0].CopyTo(dst, stream_));
+        OUTCOME_TRY(src.front().CopyTo(dst, stream_));
       }
     }
+    return success();
+  }
 
-    // 5. forward
-    OUTCOME_TRY(net_->Forward());
-
-    vector<Output> output(batch_size);
-    for (const auto& t : outputs_) {
-      auto name = output_mapping_.at(t.name());
-      auto desc = t.desc();
+  Result<void> CopyOutputTensors(const vector<int>& indices, vector<Output>& outputs) {
+    for (const auto& output : outputs_) {
+      auto name = output_mapping_.at(output.name());
+      auto desc = output.desc();
       desc.device = device_;
       Tensor tmp(desc);
       if (tmp.size()) {
-        OUTCOME_TRY(t.CopyTo(tmp, stream_));
+        OUTCOME_TRY(output.CopyTo(tmp, stream_));
       } else {
         MMDEPLOY_WARN("copy skipped due to zero sized tensor");
       }
-      if (output.size() > 1) {
-        for (int i = 0; i < output.size(); ++i) {
-          output[i].emplace(name, tmp.Slice(i));
+      if (indices.size() > 1) {
+        for (int i = 0; i < indices.size(); ++i) {
+          outputs[indices[i]].emplace(name, tmp.Slice(i));
         }
       } else {
-        output[0].emplace(name, std::move(tmp));
+        outputs[indices.front()].emplace(name, std::move(tmp));
       }
     }
-    if (is_profiling_) {
-      OUTCOME_TRY(stream_.Wait());
-    }
-
-    return output;
+    return success();
   }
 
   Device device_;
@@ -195,6 +287,7 @@ struct NetModule::Impl {
   std::map<std::string, std::string> input_mapping_;
   // outer scope to model output names
   std::map<std::string, std::string> output_mapping_;
+  int max_batch_size_{1};
   bool is_profiling_{false};
 };
 
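Worked example of the intended behaviour (hypothetical shapes, `max_batch_size_` = 2): a request of five samples shaped [1,3,224,224], [1,3,224,224], [1,3,320,320], [1,3,224,224], [1,3,320,320] splits into three batches, {0,1}, {3} and {2,4}; `Forward` reshapes the backend and runs it once per batch, and `CopyOutputTensors` scatters results back to the per-sample slots through `batch_sample_indices`. A small self-contained demo of that grouping, using a `std::map` in place of the commit's sort + unique pass:

// Demo of the expected grouping; assumed to mirror SamplesToBatches.
#include <cstdio>
#include <map>
#include <vector>

int main() {
  using Shape = std::vector<long>;
  // hypothetical per-sample shape keys
  std::vector<Shape> shapes = {{1, 3, 224, 224}, {1, 3, 224, 224},
                               {1, 3, 320, 320}, {1, 3, 224, 224},
                               {1, 3, 320, 320}};
  const size_t max_batch_size = 2;
  std::map<Shape, std::vector<int>> clusters;  // ordered keys, like sort + unique
  for (int i = 0; i < static_cast<int>(shapes.size()); ++i) {
    clusters[shapes[i]].push_back(i);
  }
  // emit each cluster in chunks of at most max_batch_size;
  // prints "batch: 0 1", "batch: 3", "batch: 2 4"
  for (const auto& [shape, idxs] : clusters) {
    for (size_t b = 0; b < idxs.size(); b += max_batch_size) {
      std::printf("batch:");
      for (size_t k = b; k < idxs.size() && k < b + max_batch_size; ++k) {
        std::printf(" %d", idxs[k]);
      }
      std::printf("\n");
    }
  }
  return 0;
}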