diff --git a/graphql_ws/gevent.py b/graphql_ws/gevent.py index 92a65ce..790d5b5 100644 --- a/graphql_ws/gevent.py +++ b/graphql_ws/gevent.py @@ -12,7 +12,8 @@ ) from .constants import ( GQL_CONNECTION_ACK, - GQL_CONNECTION_ERROR + GQL_CONNECTION_ERROR, + GQL_COMPLETE, ) @@ -81,34 +82,38 @@ def on_start(self, connection_context, op_id, params): connection_context.request_context, params) assert isinstance( execution_result, Observable), "A subscription must return an observable" - execution_result.subscribe(SubscriptionObserver( + disposable = execution_result.subscribe(SubscriptionObserver( connection_context, op_id, self.send_execution_result, self.send_error, - self.on_close + self.on_complete )) + connection_context.register_operation(op_id, disposable) except Exception as e: self.send_error(connection_context, op_id, str(e)) def on_stop(self, connection_context, op_id): self.unsubscribe(connection_context, op_id) + def on_complete(self, connection_context, op_id): + self.send_message(connection_context, op_id, GQL_COMPLETE) + class SubscriptionObserver(Observer): - def __init__(self, connection_context, op_id, send_execution_result, send_error, on_close): + def __init__(self, connection_context, op_id, send_execution_result, send_error, on_complete): self.connection_context = connection_context self.op_id = op_id self.send_execution_result = send_execution_result self.send_error = send_error - self.on_close = on_close + self.on_complete = on_complete def on_next(self, value): self.send_execution_result(self.connection_context, self.op_id, value) def on_completed(self): - self.on_close(self.connection_context) + self.on_complete(self.connection_context, self.op_id) def on_error(self, error): self.send_error(self.connection_context, self.op_id, error)