diff --git a/jupyter_server/base/zmqhandlers.py b/jupyter_server/base/zmqhandlers.py index 1e1c9cf9a5..6109db5d1d 100644 --- a/jupyter_server/base/zmqhandlers.py +++ b/jupyter_server/base/zmqhandlers.py @@ -82,6 +82,38 @@ def deserialize_binary_message(bmsg): return msg +def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None): + if pack: + msg_list = [ + pack(msg_or_list["header"]), + pack(msg_or_list["parent_header"]), + pack(msg_or_list["metadata"]), + pack(msg_or_list["content"]), + ] + else: + msg_list = msg_or_list + channel = channel.encode("utf-8") + offsets = [] + offsets.append(8 * (1 + 1 + len(msg_list) + 1)) + offsets.append(len(channel) + offsets[-1]) + for msg in msg_list: + offsets.append(len(msg) + offsets[-1]) + offset_number = len(offsets).to_bytes(8, byteorder="little") + offsets = [offset.to_bytes(8, byteorder="little") for offset in offsets] + bin_msg = b"".join([offset_number] + offsets + [channel] + msg_list) + return bin_msg + + +def deserialize_msg_from_ws_v1(ws_msg): + offset_number = int.from_bytes(ws_msg[:8], "little") + offsets = [ + int.from_bytes(ws_msg[8 * (i + 1) : 8 * (i + 2)], "little") for i in range(offset_number) + ] + channel = ws_msg[offsets[0] : offsets[1]].decode("utf-8") + msg_list = [ws_msg[offsets[i] : offsets[i + 1]] for i in range(1, offset_number - 1)] + return channel, msg_list + + # ping interval for keeping websockets alive (30 seconds) WS_PING_INTERVAL = 30000 @@ -239,6 +271,16 @@ def _reserialize_reply(self, msg_or_list, channel=None): smsg = json.dumps(msg, default=json_default) return cast_unicode(smsg) + def select_subprotocol(self, subprotocols): + preferred_protocol = self.settings.get("kernel_ws_protocol") + if preferred_protocol is None: + preferred_protocol = "v1.kernel.websocket.jupyter.org" + elif preferred_protocol == "": + preferred_protocol = None + selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None + # None is the default, "legacy" protocol + return 
selected_subprotocol + def _on_zmq_reply(self, stream, msg_list): # Sometimes this gets triggered when the on_close method is scheduled in the # eventloop but hasn't been called. @@ -247,12 +289,16 @@ def _on_zmq_reply(self, stream, msg_list): self.close() return channel = getattr(stream, "channel", None) - try: - msg = self._reserialize_reply(msg_list, channel=channel) - except Exception: - self.log.critical("Malformed message: %r" % msg_list, exc_info=True) + if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org": + bin_msg = serialize_msg_to_ws_v1(msg_list, channel) + self.write_message(bin_msg, binary=True) else: - self.write_message(msg, binary=isinstance(msg, bytes)) + try: + msg = self._reserialize_reply(msg_list, channel=channel) + except Exception: + self.log.critical("Malformed message: %r" % msg_list, exc_info=True) + else: + self.write_message(msg, binary=isinstance(msg, bytes)) class AuthenticatedZMQStreamHandler(ZMQStreamHandler, JupyterHandler): diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 51ebd23f3e..428345f750 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -314,7 +314,10 @@ def init_settings( "no_cache_paths": [url_path_join(base_url, "static", "custom")], }, version_hash=version_hash, + # kernel message protocol over websocket + kernel_ws_protocol=jupyter_app.kernel_ws_protocol, # rate limits + limit_rate=jupyter_app.limit_rate, iopub_msg_rate_limit=jupyter_app.iopub_msg_rate_limit, iopub_data_rate_limit=jupyter_app.iopub_data_rate_limit, rate_limit_window=jupyter_app.rate_limit_window, @@ -1612,6 +1615,29 @@ def _update_server_extensions(self, change): help=_i18n("Reraise exceptions encountered loading server extensions?"), ) + kernel_ws_protocol = Unicode( + None, + allow_none=True, + config=True, + help=_i18n( + "Preferred kernel message protocol over websocket to use (default: None). " + "If an empty string is passed, select the legacy protocol. 
If None, " + "the selected protocol will depend on what the front-end supports " + "(usually the most recent protocol supported by the back-end and the " + "front-end)." + ), + ) + + limit_rate = Bool( + True, + config=True, + help=_i18n( + "Whether to limit the rate of IOPub messages (default: True). " + "If True, use iopub_msg_rate_limit, iopub_data_rate_limit and/or rate_limit_window " + "to tune the rate." + ), + ) + iopub_msg_rate_limit = Float( 1000, config=True, diff --git a/jupyter_server/services/kernels/handlers.py b/jupyter_server/services/kernels/handlers.py index 46b54b9372..88ac9df2d8 100644 --- a/jupyter_server/services/kernels/handlers.py +++ b/jupyter_server/services/kernels/handlers.py @@ -22,7 +22,11 @@ from ...base.handlers import APIHandler from ...base.zmqhandlers import AuthenticatedZMQStreamHandler -from ...base.zmqhandlers import deserialize_binary_message +from ...base.zmqhandlers import ( + deserialize_binary_message, + serialize_msg_to_ws_v1, + deserialize_msg_from_ws_v1, +) from jupyter_server.utils import ensure_async from jupyter_server.utils import url_escape from jupyter_server.utils import url_path_join @@ -105,6 +109,10 @@ def kernel_info_timeout(self): km_default = self.kernel_manager.kernel_info_timeout return self.settings.get("kernel_info_timeout", km_default) + @property + def limit_rate(self): + return self.settings.get("limit_rate", True) + @property def iopub_msg_rate_limit(self): return self.settings.get("iopub_msg_rate_limit", 0) @@ -449,16 +457,25 @@ def subscribe(value): return connected - def on_message(self, msg): + def on_message(self, ws_msg): if not self.channels: # already closed, ignore the message - self.log.debug("Received message on closed websocket %r", msg) + self.log.debug("Received message on closed websocket %r", ws_msg) return - if isinstance(msg, bytes): - msg = deserialize_binary_message(msg) + + if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org": + channel, msg_list = 
deserialize_msg_from_ws_v1(ws_msg) + msg = { + "header": None, + } else: - msg = json.loads(msg) - channel = msg.pop("channel", None) + if isinstance(ws_msg, bytes): + msg = deserialize_binary_message(ws_msg) + else: + msg = json.loads(ws_msg) + msg_list = [] + channel = msg.pop("channel", None) + if channel is None: self.log.warning("No channel specified, assuming shell: %s", msg) channel = "shell" @@ -466,47 +483,86 @@ def on_message(self, msg): self.log.warning("No such channel: %r", channel) return am = self.kernel_manager.allowed_message_types - mt = msg["header"]["msg_type"] - if am and mt not in am: - self.log.warning('Received message of type "%s", which is not allowed. Ignoring.' % mt) - else: + ignore_msg = False + if am: + msg["header"] = self.get_part("header", msg["header"], msg_list) + if msg["header"]["msg_type"] not in am: + self.log.warning( + 'Received message of type "%s", which is not allowed. Ignoring.' + % msg["header"]["msg_type"] + ) + ignore_msg = True + if not ignore_msg: stream = self.channels[channel] - self.session.send(stream, msg) + if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org": + self.session.send_raw(stream, msg_list) + else: + self.session.send(stream, msg) + + def get_part(self, field, value, msg_list): + if value is None: + field2idx = { + "header": 0, + "parent_header": 1, + "content": 3, + } + value = self.session.unpack(msg_list[field2idx[field]]) + return value def _on_zmq_reply(self, stream, msg_list): idents, fed_msg_list = self.session.feed_identities(msg_list) - msg = self.session.deserialize(fed_msg_list) - parent = msg["parent_header"] - - def write_stderr(error_message): - self.log.warning(error_message) - msg = self.session.msg( - "stream", content={"text": error_message + "\n", "name": "stderr"}, parent=parent - ) - msg["channel"] = "iopub" - self.write_message(json.dumps(msg, default=json_default)) + if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org": + msg = {"header": None, 
"parent_header": None, "content": None} + else: + msg = self.session.deserialize(fed_msg_list) channel = getattr(stream, "channel", None) - msg_type = msg["header"]["msg_type"] + parts = fed_msg_list[1:] - if channel == "iopub" and msg_type == "error": - self._on_error(msg) + self._on_error(channel, msg, parts) - if ( - channel == "iopub" - and msg_type == "status" - and msg["content"].get("execution_state") == "idle" - ): - # reset rate limit counter on status=idle, - # to avoid 'Run All' hitting limits prematurely. - self._iopub_window_byte_queue = [] - self._iopub_window_msg_count = 0 - self._iopub_window_byte_count = 0 - self._iopub_msgs_exceeded = False - self._iopub_data_exceeded = False + if self._limit_rate(channel, msg, parts): + return - if channel == "iopub" and msg_type not in {"status", "comm_open", "execute_input"}: + if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org": + super(ZMQChannelsHandler, self)._on_zmq_reply(stream, parts) + else: + super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg) + + def write_stderr(self, error_message, parent_header): + self.log.warning(error_message) + err_msg = self.session.msg( + "stream", + content={"text": error_message + "\n", "name": "stderr"}, + parent=parent_header, + ) + if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org": + bin_msg = serialize_msg_to_ws_v1(err_msg, "iopub", self.session.pack) + self.write_message(bin_msg, binary=True) + else: + err_msg["channel"] = "iopub" + self.write_message(json.dumps(err_msg, default=json_default)) + + def _limit_rate(self, channel, msg, msg_list): + if not (self.limit_rate and channel == "iopub"): + return False + + msg["header"] = self.get_part("header", msg["header"], msg_list) + + msg_type = msg["header"]["msg_type"] + if msg_type == "status": + msg["content"] = self.get_part("content", msg["content"], msg_list) + if msg["content"].get("execution_state") == "idle": + # reset rate limit counter on status=idle, + # to avoid 'Run All' 
hitting limits prematurely. + self._iopub_window_byte_queue = [] + self._iopub_window_msg_count = 0 + self._iopub_window_byte_count = 0 + self._iopub_msgs_exceeded = False + self._iopub_data_exceeded = False + + if msg_type not in {"status", "comm_open", "execute_input"}: # Remove the counts queued for removal. now = IOLoop.current().time() @@ -542,7 +598,10 @@ def write_stderr(error_message): if self.iopub_msg_rate_limit > 0 and msg_rate > self.iopub_msg_rate_limit: if not self._iopub_msgs_exceeded: self._iopub_msgs_exceeded = True - write_stderr( + msg["parent_header"] = self.get_part( + "parent_header", msg["parent_header"], msg_list + ) + self.write_stderr( dedent( """\ IOPub message rate exceeded. @@ -557,7 +616,8 @@ def write_stderr(error_message): """.format( self.iopub_msg_rate_limit, self.rate_limit_window ) - ) + ), + msg["parent_header"], ) else: # resume once we've got some headroom below the limit @@ -570,7 +630,10 @@ def write_stderr(error_message): if self.iopub_data_rate_limit > 0 and data_rate > self.iopub_data_rate_limit: if not self._iopub_data_exceeded: self._iopub_data_exceeded = True - write_stderr( + msg["parent_header"] = self.get_part( + "parent_header", msg["parent_header"], msg_list + ) + self.write_stderr( dedent( """\ IOPub data rate exceeded. 
@@ -585,7 +648,8 @@ def write_stderr(error_message): """.format( self.iopub_data_rate_limit, self.rate_limit_window ) - ) + ), + msg["parent_header"], ) else: # resume once we've got some headroom below the limit @@ -600,8 +664,9 @@ def write_stderr(error_message): self._iopub_window_msg_count -= 1 self._iopub_window_byte_count -= byte_count self._iopub_window_byte_queue.pop(-1) - return - super(ZMQChannelsHandler, self)._on_zmq_reply(stream, msg) + return True + + return False def close(self): super(ZMQChannelsHandler, self).close() @@ -651,8 +716,12 @@ def _send_status_message(self, status): # that all messages from the stopped kernel have been delivered iopub.flush() msg = self.session.msg("status", {"execution_state": status}) - msg["channel"] = "iopub" - self.write_message(json.dumps(msg, default=json_default)) + if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org": + bin_msg = serialize_msg_to_ws_v1(msg, "iopub", self.session.pack) + self.write_message(bin_msg, binary=True) + else: + msg["channel"] = "iopub" + self.write_message(json.dumps(msg, default=json_default)) def on_kernel_restarted(self): self.log.warning("kernel %s restarted", self.kernel_id) @@ -662,12 +731,19 @@ def on_restart_failed(self): self.log.error("kernel %s restarted failed!", self.kernel_id) self._send_status_message("dead") - def _on_error(self, msg): + def _on_error(self, channel, msg, msg_list): if self.kernel_manager.allow_tracebacks: return - msg["content"]["ename"] = "ExecutionError" - msg["content"]["evalue"] = "Execution error" - msg["content"]["traceback"] = [self.kernel_manager.traceback_replacement_message] + + if channel == "iopub": + msg["header"] = self.get_part("header", msg["header"], msg_list) + if msg["header"]["msg_type"] == "error": + msg["content"] = self.get_part("content", msg["content"], msg_list) + msg["content"]["ename"] = "ExecutionError" + msg["content"]["evalue"] = "Execution error" + msg["content"]["traceback"] = 
[self.kernel_manager.traceback_replacement_message] + if self.selected_subprotocol == "v1.kernel.websocket.jupyter.org": + msg_list[3] = self.session.pack(msg["content"]) # -----------------------------------------------------------------------------