Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 4218473

Browse files
authored
Refactor the CAS handler in prep for using the abstracted SSO code. (#8958)
This makes the CAS handler look more like the SAML/OIDC handlers: * Render errors to users instead of throwing JSON errors. * Internal reorganization.
1 parent 56e00ca commit 4218473

File tree

4 files changed

+162
-69
lines changed

4 files changed

+162
-69
lines changed

changelog.d/8958.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Properly store the mapping of external ID to Matrix ID for CAS users.

docs/dev/cas.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ easy to run CAS implementation built on top of Django.
3131
You should now have a Django project configured to serve CAS authentication with
3232
a single user created.
3333

34-
## Configure Synapse (and Riot) to use CAS
34+
## Configure Synapse (and Element) to use CAS
3535

3636
1. Modify your `homeserver.yaml` to enable CAS and point it to your locally
3737
running Django test server:
@@ -51,9 +51,9 @@ and that the CAS server is on port 8000, both on localhost.
5151

5252
## Testing the configuration
5353

54-
Then in Riot:
54+
Then in Element:
5555

56-
1. Visit the login page with a Riot pointing at your homeserver.
56+
1. Visit the login page with a Element pointing at your homeserver.
5757
2. Click the Single Sign-On button.
5858
3. Login using the credentials created with `createsuperuser`.
5959
4. You should be logged in.

synapse/handlers/cas_handler.py

Lines changed: 151 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import logging
16-
import urllib
17-
from typing import TYPE_CHECKING, Dict, Optional, Tuple
16+
import urllib.parse
17+
from typing import TYPE_CHECKING, Dict, Optional
1818
from xml.etree import ElementTree as ET
1919

20+
import attr
21+
2022
from twisted.web.client import PartialDownloadError
2123

22-
from synapse.api.errors import Codes, LoginError
24+
from synapse.api.errors import HttpResponseException
2325
from synapse.http.site import SynapseRequest
2426
from synapse.types import UserID, map_username_to_mxid_localpart
2527

@@ -29,6 +31,26 @@
2931
logger = logging.getLogger(__name__)
3032

3133

34+
class CasError(Exception):
35+
"""Used to catch errors when validating the CAS ticket.
36+
"""
37+
38+
def __init__(self, error, error_description=None):
39+
self.error = error
40+
self.error_description = error_description
41+
42+
def __str__(self):
43+
if self.error_description:
44+
return "{}: {}".format(self.error, self.error_description)
45+
return self.error
46+
47+
48+
@attr.s(slots=True, frozen=True)
49+
class CasResponse:
50+
username = attr.ib(type=str)
51+
attributes = attr.ib(type=Dict[str, Optional[str]])
52+
53+
3254
class CasHandler:
3355
"""
3456
Utility class for to handle the response from a CAS SSO service.
@@ -50,6 +72,8 @@ def __init__(self, hs: "HomeServer"):
5072

5173
self._http_client = hs.get_proxied_http_client()
5274

75+
self._sso_handler = hs.get_sso_handler()
76+
5377
def _build_service_param(self, args: Dict[str, str]) -> str:
5478
"""
5579
Generates a value to use as the "service" parameter when redirecting or
@@ -69,14 +93,20 @@ def _build_service_param(self, args: Dict[str, str]) -> str:
6993

7094
async def _validate_ticket(
7195
self, ticket: str, service_args: Dict[str, str]
72-
) -> Tuple[str, Optional[str]]:
96+
) -> CasResponse:
7397
"""
74-
Validate a CAS ticket with the server, parse the response, and return the user and display name.
98+
Validate a CAS ticket with the server, and return the parsed the response.
7599
76100
Args:
77101
ticket: The CAS ticket from the client.
78102
service_args: Additional arguments to include in the service URL.
79103
Should be the same as those passed to `get_redirect_url`.
104+
105+
Raises:
106+
CasError: If there's an error parsing the CAS response.
107+
108+
Returns:
109+
The parsed CAS response.
80110
"""
81111
uri = self._cas_server_url + "/proxyValidate"
82112
args = {
@@ -89,66 +119,65 @@ async def _validate_ticket(
89119
# Twisted raises this error if the connection is closed,
90120
# even if that's being used old-http style to signal end-of-data
91121
body = pde.response
122+
except HttpResponseException as e:
123+
description = (
124+
(
125+
'Authorization server responded with a "{status}" error '
126+
"while exchanging the authorization code."
127+
).format(status=e.code),
128+
)
129+
raise CasError("server_error", description) from e
92130

93-
user, attributes = self._parse_cas_response(body)
94-
displayname = attributes.pop(self._cas_displayname_attribute, None)
95-
96-
for required_attribute, required_value in self._cas_required_attributes.items():
97-
# If required attribute was not in CAS Response - Forbidden
98-
if required_attribute not in attributes:
99-
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
100-
101-
# Also need to check value
102-
if required_value is not None:
103-
actual_value = attributes[required_attribute]
104-
# If required attribute value does not match expected - Forbidden
105-
if required_value != actual_value:
106-
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
107-
108-
return user, displayname
131+
return self._parse_cas_response(body)
109132

110-
def _parse_cas_response(
111-
self, cas_response_body: bytes
112-
) -> Tuple[str, Dict[str, Optional[str]]]:
133+
def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
113134
"""
114135
Retrieve the user and other parameters from the CAS response.
115136
116137
Args:
117138
cas_response_body: The response from the CAS query.
118139
140+
Raises:
141+
CasError: If there's an error parsing the CAS response.
142+
119143
Returns:
120-
A tuple of the user and a mapping of other attributes.
144+
The parsed CAS response.
121145
"""
146+
147+
# Ensure the response is valid.
148+
root = ET.fromstring(cas_response_body)
149+
if not root.tag.endswith("serviceResponse"):
150+
raise CasError(
151+
"missing_service_response",
152+
"root of CAS response is not serviceResponse",
153+
)
154+
155+
success = root[0].tag.endswith("authenticationSuccess")
156+
if not success:
157+
raise CasError("unsucessful_response", "Unsuccessful CAS response")
158+
159+
# Iterate through the nodes and pull out the user and any extra attributes.
122160
user = None
123161
attributes = {}
124-
try:
125-
root = ET.fromstring(cas_response_body)
126-
if not root.tag.endswith("serviceResponse"):
127-
raise Exception("root of CAS response is not serviceResponse")
128-
success = root[0].tag.endswith("authenticationSuccess")
129-
for child in root[0]:
130-
if child.tag.endswith("user"):
131-
user = child.text
132-
if child.tag.endswith("attributes"):
133-
for attribute in child:
134-
# ElementTree library expands the namespace in
135-
# attribute tags to the full URL of the namespace.
136-
# We don't care about namespace here and it will always
137-
# be encased in curly braces, so we remove them.
138-
tag = attribute.tag
139-
if "}" in tag:
140-
tag = tag.split("}")[1]
141-
attributes[tag] = attribute.text
142-
if user is None:
143-
raise Exception("CAS response does not contain user")
144-
except Exception:
145-
logger.exception("Error parsing CAS response")
146-
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
147-
if not success:
148-
raise LoginError(
149-
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
150-
)
151-
return user, attributes
162+
for child in root[0]:
163+
if child.tag.endswith("user"):
164+
user = child.text
165+
if child.tag.endswith("attributes"):
166+
for attribute in child:
167+
# ElementTree library expands the namespace in
168+
# attribute tags to the full URL of the namespace.
169+
# We don't care about namespace here and it will always
170+
# be encased in curly braces, so we remove them.
171+
tag = attribute.tag
172+
if "}" in tag:
173+
tag = tag.split("}")[1]
174+
attributes[tag] = attribute.text
175+
176+
# Ensure a user was found.
177+
if user is None:
178+
raise CasError("no_user", "CAS response does not contain user")
179+
180+
return CasResponse(user, attributes)
152181

153182
def get_redirect_url(self, service_args: Dict[str, str]) -> str:
154183
"""
@@ -201,15 +230,76 @@ async def handle_ticket(
201230
args["redirectUrl"] = client_redirect_url
202231
if session:
203232
args["session"] = session
204-
username, user_display_name = await self._validate_ticket(ticket, args)
233+
234+
try:
235+
cas_response = await self._validate_ticket(ticket, args)
236+
except CasError as e:
237+
logger.exception("Could not validate ticket")
238+
self._sso_handler.render_error(request, e.error, e.error_description, 401)
239+
return
240+
241+
await self._handle_cas_response(
242+
request, cas_response, client_redirect_url, session
243+
)
244+
245+
async def _handle_cas_response(
246+
self,
247+
request: SynapseRequest,
248+
cas_response: CasResponse,
249+
client_redirect_url: Optional[str],
250+
session: Optional[str],
251+
) -> None:
252+
"""Handle a CAS response to a ticket request.
253+
254+
Assumes that the response has been validated. Maps the user onto an MXID,
255+
registering them if necessary, and returns a response to the browser.
256+
257+
Args:
258+
request: the incoming request from the browser. We'll respond to it with an
259+
HTML page or a redirect
260+
261+
cas_response: The parsed CAS response.
262+
263+
client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
264+
This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
265+
266+
session: The session parameter from the `/cas/ticket` HTTP request, if given.
267+
This should be the UI Auth session id.
268+
"""
269+
270+
# Ensure that the attributes of the logged in user meet the required
271+
# attributes.
272+
for required_attribute, required_value in self._cas_required_attributes.items():
273+
# If required attribute was not in CAS Response - Forbidden
274+
if required_attribute not in cas_response.attributes:
275+
self._sso_handler.render_error(
276+
request,
277+
"unauthorised",
278+
"You are not authorised to log in here.",
279+
401,
280+
)
281+
return
282+
283+
# Also need to check value
284+
if required_value is not None:
285+
actual_value = cas_response.attributes[required_attribute]
286+
# If required attribute value does not match expected - Forbidden
287+
if required_value != actual_value:
288+
self._sso_handler.render_error(
289+
request,
290+
"unauthorised",
291+
"You are not authorised to log in here.",
292+
401,
293+
)
294+
return
205295

206296
# Pull out the user-agent and IP from the request.
207297
user_agent = request.get_user_agent("")
208298
ip_address = self.hs.get_ip_from_request(request)
209299

210300
# Get the matrix ID from the CAS username.
211301
user_id = await self._map_cas_user_to_matrix_user(
212-
username, user_display_name, user_agent, ip_address
302+
cas_response, user_agent, ip_address
213303
)
214304

215305
if session:
@@ -225,34 +315,31 @@ async def handle_ticket(
225315
)
226316

227317
async def _map_cas_user_to_matrix_user(
228-
self,
229-
remote_user_id: str,
230-
display_name: Optional[str],
231-
user_agent: str,
232-
ip_address: str,
318+
self, cas_response: CasResponse, user_agent: str, ip_address: str,
233319
) -> str:
234320
"""
235321
Given a CAS username, retrieve the user ID for it and possibly register the user.
236322
237323
Args:
238-
remote_user_id: The username from the CAS response.
239-
display_name: The display name from the CAS response.
324+
cas_response: The parsed CAS response.
240325
user_agent: The user agent of the client making the request.
241326
ip_address: The IP address of the client making the request.
242327
243328
Returns:
244329
The user ID associated with this response.
245330
"""
246331

247-
localpart = map_username_to_mxid_localpart(remote_user_id)
332+
localpart = map_username_to_mxid_localpart(cas_response.username)
248333
user_id = UserID(localpart, self._hostname).to_string()
249334
registered_user_id = await self._auth_handler.check_user_exists(user_id)
250335

336+
displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)
337+
251338
# If the user does not exist, register it.
252339
if not registered_user_id:
253340
registered_user_id = await self._registration_handler.register_user(
254341
localpart=localpart,
255-
default_display_name=display_name,
342+
default_display_name=displayname,
256343
user_agent_ips=[(user_agent, ip_address)],
257344
)
258345

synapse/handlers/sso.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,11 @@ def __init__(self, hs: "HomeServer"):
101101
self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
102102

103103
def render_error(
104-
self, request, error: str, error_description: Optional[str] = None
104+
self,
105+
request: Request,
106+
error: str,
107+
error_description: Optional[str] = None,
108+
code: int = 400,
105109
) -> None:
106110
"""Renders the error template and responds with it.
107111
@@ -113,11 +117,12 @@ def render_error(
113117
We'll respond with an HTML page describing the error.
114118
error: A technical identifier for this error.
115119
error_description: A human-readable description of the error.
120+
code: The integer error code (an HTTP response code)
116121
"""
117122
html = self._error_template.render(
118123
error=error, error_description=error_description
119124
)
120-
respond_with_html(request, 400, html)
125+
respond_with_html(request, code, html)
121126

122127
async def get_sso_user_by_remote_user_id(
123128
self, auth_provider_id: str, remote_user_id: str

0 commit comments

Comments
 (0)