Skip to content

Commit 4e4f82c

Browse files
authored
Support ZMQ Curve for transport encryption (#1515)
This PR adds support for ZMQ Curve cryptography to encrypt kernel communications over TCP, addressing the current limitation that all code and outputs are transmitted in plaintext.

### What's changed

- Curve key support in IPKernelApp: Loads `curve_publickey` / `curve_secretkey` from connection files and applies them to all ZMQ sockets (`iopub`, `shell`, `control`, `stdin`, `heartbeat`) via new `_apply_curve_server_options()` / `_apply_curve_client_options()` helpers.
- Kernelspec metadata: Advertises `"supported_encryption": "curve"` so frontends can discover and negotiate encryption capability.
- Unencrypted TCP warning: Emits a warning when TCP transport is used without encryption, guiding users toward IPC or CurveZMQ.
- Heartbeat thread: Extended to accept and apply curve keys to its ROUTER socket.

### References

- Depends on jupyter/jupyter_client#1110
- Enables jupyter/jupyter_client#808
- PoC for jupyter/enhancement-proposals#75
- Toggle on server-level: jupyter-server/jupyter_server#1638

### Code changes

- [x] tests
- [x] additions
2 parents 6fc8b58 + b64f597 commit 4e4f82c

9 files changed

Lines changed: 375 additions & 8 deletions

File tree

hatch_build.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ def initialize(self, version, build_data):
2020

2121
# When building a standard wheel, the executable specified in the kernelspec is simply 'python'.
2222
if version == "standard":
23-
overrides["metadata"] = dict(debugger=True)
23+
overrides["metadata"] = {
24+
"debugger": True,
25+
"supported_encryption": "curve",
26+
}
2427
argv = make_ipkernel_cmd(executable="python")
2528

2629
# When installing an editable wheel, the full `sys.executable` can be used.

ipykernel/heartbeat.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,29 @@
2727
class Heartbeat(Thread):
2828
"""A simple ping-pong style heartbeat that runs in a thread."""
2929

30-
def __init__(self, context, addr=None):
31-
"""Initialize the heartbeat thread."""
30+
def __init__(self, context, addr=None, *, curve_publickey=None, curve_secretkey=None):
31+
"""Initialize the heartbeat thread.
32+
33+
Parameters
34+
----------
35+
context : zmq.Context
36+
addr : tuple, optional
37+
(transport, ip, port)
38+
curve_publickey : bytes, optional
39+
CurveZMQ public key (Z85). When provided together with
40+
*curve_secretkey*, the heartbeat socket will operate as a
41+
CurveZMQ server so that only authenticated clients can connect.
42+
curve_secretkey : bytes, optional
43+
CurveZMQ secret key (Z85, paired with *curve_publickey*).
44+
"""
3245
if addr is None:
3346
addr = ("tcp", localhost(), 0)
3447
Thread.__init__(self, name="Heartbeat")
3548
self.context = context
3649
self.transport, self.ip, self.port = addr
3750
self.original_port = self.port
51+
self._curve_publickey = curve_publickey
52+
self._curve_secretkey = curve_secretkey
3853
if self.original_port == 0:
3954
self.pick_port()
4055
self.addr = (self.ip, self.port)
@@ -94,6 +109,10 @@ def run(self):
94109
self.name = "Heartbeat"
95110
self.socket = self.context.socket(zmq.ROUTER)
96111
self.socket.linger = 1000
112+
if self._curve_secretkey is not None:
113+
self.socket.curve_secretkey = self._curve_secretkey
114+
self.socket.curve_publickey = self._curve_publickey
115+
self.socket.curve_server = True
97116
try:
98117
self._bind_socket()
99118
except Exception:
@@ -122,3 +141,6 @@ def run(self):
122141
raise
123142
else:
124143
break
144+
145+
_curve_publickey: bytes | None
146+
_curve_secretkey: bytes | None

ipykernel/inprocess/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from jupyter_client.client import KernelClient
1919
from jupyter_client.clientabc import KernelClientABC
20+
from jupyter_client.connect import KernelConnectionInfo
2021
from jupyter_core.utils import run_sync
2122

2223
# IPython imports
@@ -59,10 +60,10 @@ def _default_blocking_class(self):
5960

6061
return BlockingInProcessKernelClient
6162

62-
def get_connection_info(self, session: bool = False) -> dict[str, int | str | bytes]:
63+
def get_connection_info(self, session: bool = False) -> KernelConnectionInfo:
6364
"""Get the connection info for the client."""
6465
d = super().get_connection_info(session=session)
65-
d["kernel"] = self.kernel # type:ignore[assignment]
66+
d["kernel"] = self.kernel # type: ignore[typeddict-unknown-key]
6667
return d
6768

6869
def start_channels(self, *args, **kwargs):

ipykernel/kernelapp.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from traitlets.traitlets import (
3434
Any,
3535
Bool,
36+
Bytes,
3637
Dict,
3738
DottedObjectName,
3839
Instance,
@@ -158,6 +159,11 @@ class IPKernelApp(BaseIPythonApplication, InteractiveShellApp, ConnectionFileMix
158159
# connection info:
159160
connection_dir = Unicode()
160161

162+
# Optional CurveZMQ keys loaded from the connection file (Z85-encoded bytes).
163+
# None when the kernel was not started with CurveZMQ enabled.
164+
curve_publickey: Bytes | None = Bytes(allow_none=True, default_value=None)
165+
curve_secretkey: Bytes | None = Bytes(allow_none=True, default_value=None)
166+
161167
@default("connection_dir")
162168
def _default_connection_dir(self):
163169
return jupyter_runtime_dir()
@@ -211,6 +217,25 @@ def excepthook(self, etype, evalue, tb):
211217
# write uncaught traceback to 'real' stderr, not zmq-forwarder
212218
traceback.print_exception(etype, evalue, tb, file=sys.__stderr__)
213219

220+
def _apply_curve_server_options(self, socket: zmq.Socket[t.Any]) -> None:
221+
"""Set CurveZMQ server-side options on *socket* before it is bound.
222+
223+
This is a no-op when Curve keys are not available yet, so it is safe
224+
to call unconditionally.
225+
"""
226+
if self.curve_secretkey is not None:
227+
socket.curve_secretkey = self.curve_secretkey
228+
socket.curve_publickey = self.curve_publickey
229+
socket.curve_server = True
230+
231+
def _apply_curve_client_options(self, socket: zmq.Socket[t.Any]) -> None:
232+
"""Set CurveZMQ client-side options on *socket* before it connects."""
233+
if self.curve_secretkey is not None:
234+
socket.curve_serverkey = self.curve_publickey
235+
# Reuse manager-provisioned keypair for the in-kernel client socket.
236+
socket.curve_secretkey = self.curve_secretkey
237+
socket.curve_publickey = self.curve_publickey
238+
214239
def init_poller(self):
215240
"""Initialize the poller."""
216241
if sys.platform == "win32":
@@ -274,6 +299,9 @@ def write_connection_file(self, **kwargs: Any) -> None:
274299
iopub_port=self.iopub_port,
275300
control_port=self.control_port,
276301
)
302+
if self.curve_publickey is not None:
303+
connection_info["curve_publickey"] = self.curve_publickey
304+
connection_info["curve_secretkey"] = self.curve_secretkey
277305
if Path(cf).exists():
278306
# If the file exists, merge our info into it. For example, if the
279307
# original file had port number 0, we update with the actual port
@@ -328,13 +356,26 @@ def init_sockets(self):
328356
self.context = context = zmq.Context()
329357
atexit.register(self.close)
330358

359+
if self.curve_secretkey is not None:
360+
self.log.info("Detected CurveZMQ secret key; using transport encryption")
361+
elif self.transport == "tcp":
362+
self.log.warning(
363+
"Kernel is running over TCP without encryption."
364+
" All communication (including code and outputs) is sent in plain text"
365+
" and is susceptible to eavesdropping."
366+
" Use IPC transport or launch with kernel manager-provisioned"
367+
" CurveZMQ keys to enable transport encryption."
368+
)
369+
331370
self.shell_socket = context.socket(zmq.ROUTER)
332371
self.shell_socket.linger = 1000
372+
self._apply_curve_server_options(self.shell_socket)
333373
self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
334374
self.log.debug("shell ROUTER Channel on port: %i", self.shell_port)
335375

336376
self.stdin_socket = context.socket(zmq.ROUTER)
337377
self.stdin_socket.linger = 1000
378+
self._apply_curve_server_options(self.stdin_socket)
338379
self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
339380
self.log.debug("stdin ROUTER Channel on port: %i", self.stdin_port)
340381

@@ -351,6 +392,7 @@ def init_control(self, context):
351392
"""Initialize the control channel."""
352393
self.control_socket = context.socket(zmq.ROUTER)
353394
self.control_socket.linger = 1000
395+
self._apply_curve_server_options(self.control_socket)
354396
self.control_port = self._bind_socket(self.control_socket, self.control_port)
355397
self.log.debug("control ROUTER Channel on port: %i", self.control_port)
356398

@@ -359,6 +401,7 @@ def init_control(self, context):
359401

360402
self.debug_shell_socket = context.socket(zmq.DEALER)
361403
self.debug_shell_socket.linger = 1000
404+
self._apply_curve_client_options(self.debug_shell_socket)
362405
if self.shell_socket.getsockopt(zmq.LAST_ENDPOINT):
363406
self.debug_shell_socket.connect(self.shell_socket.getsockopt(zmq.LAST_ENDPOINT))
364407

@@ -379,6 +422,7 @@ def init_iopub(self, context):
379422
"""Initialize the iopub channel."""
380423
self.iopub_socket = context.socket(zmq.XPUB)
381424
self.iopub_socket.linger = 1000
425+
self._apply_curve_server_options(self.iopub_socket)
382426
self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
383427
self.log.debug("iopub PUB Channel on port: %i", self.iopub_port)
384428
self.configure_tornado_logger()
@@ -392,7 +436,12 @@ def init_heartbeat(self):
392436
# heartbeat doesn't share context, because it mustn't be blocked
393437
# by the GIL, which is accessed by libzmq when freeing zero-copy messages
394438
hb_ctx = zmq.Context()
395-
self.heartbeat = Heartbeat(hb_ctx, (self.transport, self.ip, self.hb_port))
439+
self.heartbeat = Heartbeat(
440+
hb_ctx,
441+
(self.transport, self.ip, self.hb_port),
442+
curve_publickey=self.curve_publickey,
443+
curve_secretkey=self.curve_secretkey,
444+
)
396445
self.hb_port = self.heartbeat.port
397446
self.log.debug("Heartbeat REP Channel on port: %i", self.hb_port)
398447
self.heartbeat.start()

ipykernel/kernelspec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def get_kernel_dict(
6666
),
6767
"display_name": "Python %i (ipykernel)" % sys.version_info[0],
6868
"language": "python",
69-
"metadata": {"debugger": True},
69+
"metadata": {"debugger": True, "supported_encryption": "curve"},
7070
"kernel_protocol_version": "5.5",
7171
}
7272

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ dependencies = [
2323
"ipython>=7.23.1",
2424
"comm>=0.1.1",
2525
"traitlets>=5.4.0",
26-
"jupyter_client>=8.8.0",
26+
"jupyter_client>=8.9.0",
2727
"jupyter_core>=5.1,!=6.0.*",
2828
# For tk event loop support only.
2929
"nest_asyncio2>=1.7.0",

0 commit comments

Comments
 (0)