Skip to content

Commit f045d8d

Browse files
committed
Fix gRPC frontend race condition (#7110)
* Fix `state->complete_` race condition
* Add delay and error checking to `StreamInferResponseComplete`
* Add test for gRPC decoupled infer complete flag
1 parent 03efe23 commit f045d8d

File tree

4 files changed

+80
-9
lines changed

4 files changed

+80
-9
lines changed

qa/L0_grpc_state_cleanup/cleanup_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,15 @@ def test_decoupled_infer_with_params_shutdownserver(self):
554554
infer_helper_map=[False, True],
555555
)
556556

557+
def test_decoupled_infer_complete(self):
558+
# Test if the Process() thread could release the state object before
559+
# the StreamInferResponseComplete() thread is done accessing it.
560+
self._decoupled_infer(request_count=1, repeat_count=1, stream_timeout=16)
561+
# Check no error is printed to the log.
562+
with open(os.environ["SERVER_LOG"]) as f:
563+
server_log = f.read()
564+
self.assertNotIn("Should not print this", server_log)
565+
557566

558567
if __name__ == "__main__":
559568
CleanUpTest.SERVER_PID = os.environ.get("SERVER_PID", CleanUpTest.SERVER_PID)

qa/L0_grpc_state_cleanup/test.sh

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,40 @@ for i in test_simple_infer_shutdownserver \
184184
set -e
185185
done
186186

187+
TEST_NAME=test_decoupled_infer_complete
188+
export TRITONSERVER_DELAY_GRPC_COMPLETE=2000
189+
190+
SERVER_LOG="./inference_server.$TEST_NAME.log"
191+
SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=2"
192+
run_server
193+
if [ "$SERVER_PID" == "0" ]; then
194+
echo -e "\n***\n*** Failed to start $SERVER\n***"
195+
cat $SERVER_LOG
196+
exit 1
197+
fi
198+
199+
echo "Test: $TEST_NAME" >>$CLIENT_LOG
200+
201+
set +e
202+
203+
SERVER_LOG=$SERVER_LOG python $CLEANUP_TEST CleanUpTest.$TEST_NAME >>$CLIENT_LOG 2>&1
204+
if [ $? -ne 0 ]; then
205+
cat $CLIENT_LOG
206+
echo -e "\n***\n*** Test $TEST_NAME Failed\n***"
207+
RET=1
208+
fi
209+
210+
kill $SERVER_PID
211+
wait $SERVER_PID
212+
213+
check_state_release $SERVER_LOG
214+
if [ $? -ne 0 ]; then
215+
cat $SERVER_LOG
216+
echo -e "\n***\n*** State Verification Failed for $TEST_NAME\n***"
217+
RET=1
218+
fi
219+
220+
set -e
187221

188222
if [ $RET -eq 0 ]; then
189223
echo -e "\n***\n*** Test Passed\n***"

src/grpc/infer_handler.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1013,12 +1013,18 @@ class InferHandlerState {
10131013
const std::shared_ptr<Context>& context, Steps start_step = Steps::START)
10141014
: tritonserver_(tritonserver), async_notify_state_(false)
10151015
{
1016-
// For debugging and testing,
1016+
// For debugging and testing
10171017
const char* dstr = getenv("TRITONSERVER_DELAY_GRPC_RESPONSE");
10181018
delay_response_ms_ = 0;
10191019
if (dstr != nullptr) {
10201020
delay_response_ms_ = atoi(dstr);
10211021
}
1022+
const char* cstr = getenv("TRITONSERVER_DELAY_GRPC_COMPLETE");
1023+
delay_complete_ms_ = 0;
1024+
if (cstr != nullptr) {
1025+
delay_complete_ms_ = atoi(cstr);
1026+
}
1027+
10221028
response_queue_.reset(new ResponseQueue<ResponseType>());
10231029
Reset(context, start_step);
10241030
}
@@ -1113,6 +1119,7 @@ class InferHandlerState {
11131119

11141120
// For testing and debugging
11151121
int delay_response_ms_;
1122+
int delay_complete_ms_;
11161123

11171124
// For inference requests the allocator payload, unused for other
11181125
// requests.

src/grpc/stream_infer_handler.cc

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -574,9 +574,10 @@ ModelStreamInferHandler::StreamInferResponseComplete(
574574
#endif // TRITON_ENABLE_TRACING
575575

576576
// Log appropriate errors
577-
state->complete_ = ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) != 0);
577+
bool is_complete =
578+
state->complete_ || (flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) != 0;
578579
if (!state->is_decoupled_) {
579-
if (!state->complete_) {
580+
if (!is_complete) {
580581
LOG_ERROR << "[INTERNAL] ModelStreamInfer received a response without "
581582
"FINAL flag for a model with one-to-one transaction";
582583
}
@@ -591,7 +592,7 @@ ModelStreamInferHandler::StreamInferResponseComplete(
591592
// Also make sure that if this state was sent to gRPC async notification
592593
// mechanism then the state is not removed as it would be needed for handling
593594
// the cancellation if detected.
594-
if (state->complete_ && (!state->IsAsyncNotifyState())) {
595+
if (is_complete && (!state->IsAsyncNotifyState())) {
595596
state->context_->EraseInflightState(state);
596597
}
597598

@@ -610,11 +611,12 @@ ModelStreamInferHandler::StreamInferResponseComplete(
610611
// If this was the final callback for the state
611612
// then cycle through the completion queue so
612613
// that state object can be released.
613-
if (state->complete_) {
614+
if (is_complete) {
614615
state->step_ = Steps::CANCELLED;
615616
state->context_->PutTaskBackToQueue(state);
616617
}
617618

619+
state->complete_ = is_complete;
618620
return;
619621
}
620622

@@ -661,8 +663,7 @@ ModelStreamInferHandler::StreamInferResponseComplete(
661663
// "empty" responses are not sent back to the client. Clients can
662664
// opt-in to receiving these empty responses via request parameters.
663665
// NOTE: The complete flag is the only flag used for this case at this time.
664-
const bool empty_final =
665-
(!iresponse && state->is_decoupled_ && state->complete_);
666+
const bool empty_final = !iresponse && state->is_decoupled_ && is_complete;
666667
const bool enable_empty_final =
667668
state->parameters_.enable_empty_final_response_;
668669

@@ -690,7 +691,24 @@ ModelStreamInferHandler::StreamInferResponseComplete(
690691
infer_response.set_model_version(state->request_.model_version());
691692
}
692693
auto& params = *(infer_response.mutable_parameters());
693-
params["triton_final_response"].set_bool_param(state->complete_);
694+
params["triton_final_response"].set_bool_param(is_complete);
695+
}
696+
697+
if (state->delay_complete_ms_ != 0) {
698+
// Delay updating the state. This is useful for testing race condition with
699+
// the thread that runs Process().
700+
LOG_INFO << "Delaying the completion of reporting response / flag by "
701+
<< state->delay_complete_ms_ << " ms...";
702+
void* context_ptr_before_delay = (void*)state->context_.get();
703+
std::this_thread::sleep_for(
704+
std::chrono::milliseconds(state->delay_complete_ms_));
705+
void* context_ptr_after_delay = (void*)state->context_.get();
706+
if (context_ptr_before_delay != context_ptr_after_delay) {
707+
LOG_ERROR << "Should not print this! The state context object has "
708+
"changed after delay, pointer before: "
709+
<< context_ptr_before_delay
710+
<< ", pointer after: " << context_ptr_after_delay;
711+
}
694712
}
695713

696714
// Update states to signal that response/error is ready to write to stream
@@ -708,11 +726,12 @@ ModelStreamInferHandler::StreamInferResponseComplete(
708726
// If this was the final callback for the state
709727
// then cycle through the completion queue so
710728
// that state object can be released.
711-
if (state->complete_) {
729+
if (is_complete) {
712730
state->step_ = Steps::CANCELLED;
713731
state->context_->PutTaskBackToQueue(state);
714732
}
715733

734+
state->complete_ = is_complete;
716735
return;
717736
}
718737

@@ -728,6 +747,8 @@ ModelStreamInferHandler::StreamInferResponseComplete(
728747
state->step_ = Steps::WRITEREADY;
729748
state->context_->WriteResponseIfReady(state);
730749
}
750+
751+
state->complete_ = is_complete;
731752
}
732753
}
733754

0 commit comments

Comments
 (0)