diff --git a/kafka/sasl/oauth.py b/kafka/sasl/oauth.py index 4041a93bd..f1e959cb6 100644 --- a/kafka/sasl/oauth.py +++ b/kafka/sasl/oauth.py @@ -1,10 +1,14 @@ from __future__ import absolute_import import abc +import logging from kafka.sasl.abc import SaslMechanism +log = logging.getLogger(__name__) + + class SaslMechanismOAuth(SaslMechanism): def __init__(self, **config): @@ -12,17 +16,26 @@ def __init__(self, **config): assert isinstance(config['sasl_oauth_token_provider'], AbstractTokenProvider), \ 'sasl_oauth_token_provider must implement kafka.sasl.oauth.AbstractTokenProvider' self.token_provider = config['sasl_oauth_token_provider'] + self._error = None self._is_done = False self._is_authenticated = False def auth_bytes(self): + if self._error: + # Server should respond to this with SaslAuthenticate failure, which ends the auth process + return self._error token = self.token_provider.token() extensions = self._token_extensions() return "n,,\x01auth=Bearer {}{}\x01\x01".format(token, extensions).encode('utf-8') def receive(self, auth_bytes): - self._is_done = True - self._is_authenticated = auth_bytes == b'' + if auth_bytes != b'': + error = auth_bytes.decode('utf-8') + log.debug("Sending x01 response to server after receiving SASL OAuth error: %s", error) + self._error = b'\x01' + else: + self._is_done = True + self._is_authenticated = True def is_done(self): return self._is_done