1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515import logging
16- import urllib
17- from typing import TYPE_CHECKING , Dict , Optional , Tuple
16+ import urllib . parse
17+ from typing import TYPE_CHECKING , Dict , Optional
1818from xml .etree import ElementTree as ET
1919
20+ import attr
21+
2022from twisted .web .client import PartialDownloadError
2123
22- from synapse .api .errors import Codes , LoginError
24+ from synapse .api .errors import HttpResponseException
2325from synapse .http .site import SynapseRequest
2426from synapse .types import UserID , map_username_to_mxid_localpart
2527
2931logger = logging .getLogger (__name__ )
3032
3133
34+ class CasError (Exception ):
35+ """Used to catch errors when validating the CAS ticket.
36+ """
37+
38+ def __init__ (self , error , error_description = None ):
39+ self .error = error
40+ self .error_description = error_description
41+
42+ def __str__ (self ):
43+ if self .error_description :
44+ return "{}: {}" .format (self .error , self .error_description )
45+ return self .error
46+
47+
48+ @attr .s (slots = True , frozen = True )
49+ class CasResponse :
50+ username = attr .ib (type = str )
51+ attributes = attr .ib (type = Dict [str , Optional [str ]])
52+
53+
3254class CasHandler :
3355 """
3456 Utility class for to handle the response from a CAS SSO service.
@@ -50,6 +72,8 @@ def __init__(self, hs: "HomeServer"):
5072
5173 self ._http_client = hs .get_proxied_http_client ()
5274
75+ self ._sso_handler = hs .get_sso_handler ()
76+
5377 def _build_service_param (self , args : Dict [str , str ]) -> str :
5478 """
5579 Generates a value to use as the "service" parameter when redirecting or
@@ -69,14 +93,20 @@ def _build_service_param(self, args: Dict[str, str]) -> str:
6993
7094 async def _validate_ticket (
7195 self , ticket : str , service_args : Dict [str , str ]
72- ) -> Tuple [ str , Optional [ str ]] :
96+ ) -> CasResponse :
7397 """
74- Validate a CAS ticket with the server, parse the response, and return the user and display name .
98+ Validate a CAS ticket with the server, and return the parsed the response .
7599
76100 Args:
77101 ticket: The CAS ticket from the client.
78102 service_args: Additional arguments to include in the service URL.
79103 Should be the same as those passed to `get_redirect_url`.
104+
105+ Raises:
106+ CasError: If there's an error parsing the CAS response.
107+
108+ Returns:
109+ The parsed CAS response.
80110 """
81111 uri = self ._cas_server_url + "/proxyValidate"
82112 args = {
@@ -89,66 +119,65 @@ async def _validate_ticket(
89119 # Twisted raises this error if the connection is closed,
90120 # even if that's being used old-http style to signal end-of-data
91121 body = pde .response
122+ except HttpResponseException as e :
123+ description = (
124+ (
125+ 'Authorization server responded with a "{status}" error '
126+ "while exchanging the authorization code."
127+ ).format (status = e .code ),
128+ )
129+ raise CasError ("server_error" , description ) from e
92130
93- user , attributes = self ._parse_cas_response (body )
94- displayname = attributes .pop (self ._cas_displayname_attribute , None )
95-
96- for required_attribute , required_value in self ._cas_required_attributes .items ():
97- # If required attribute was not in CAS Response - Forbidden
98- if required_attribute not in attributes :
99- raise LoginError (401 , "Unauthorized" , errcode = Codes .UNAUTHORIZED )
100-
101- # Also need to check value
102- if required_value is not None :
103- actual_value = attributes [required_attribute ]
104- # If required attribute value does not match expected - Forbidden
105- if required_value != actual_value :
106- raise LoginError (401 , "Unauthorized" , errcode = Codes .UNAUTHORIZED )
107-
108- return user , displayname
131+ return self ._parse_cas_response (body )
109132
110- def _parse_cas_response (
111- self , cas_response_body : bytes
112- ) -> Tuple [str , Dict [str , Optional [str ]]]:
133+ def _parse_cas_response (self , cas_response_body : bytes ) -> CasResponse :
113134 """
114135 Retrieve the user and other parameters from the CAS response.
115136
116137 Args:
117138 cas_response_body: The response from the CAS query.
118139
140+ Raises:
141+ CasError: If there's an error parsing the CAS response.
142+
119143 Returns:
120- A tuple of the user and a mapping of other attributes .
144+ The parsed CAS response .
121145 """
146+
147+ # Ensure the response is valid.
148+ root = ET .fromstring (cas_response_body )
149+ if not root .tag .endswith ("serviceResponse" ):
150+ raise CasError (
151+ "missing_service_response" ,
152+ "root of CAS response is not serviceResponse" ,
153+ )
154+
155+ success = root [0 ].tag .endswith ("authenticationSuccess" )
156+ if not success :
157+ raise CasError ("unsucessful_response" , "Unsuccessful CAS response" )
158+
159+ # Iterate through the nodes and pull out the user and any extra attributes.
122160 user = None
123161 attributes = {}
124- try :
125- root = ET .fromstring (cas_response_body )
126- if not root .tag .endswith ("serviceResponse" ):
127- raise Exception ("root of CAS response is not serviceResponse" )
128- success = root [0 ].tag .endswith ("authenticationSuccess" )
129- for child in root [0 ]:
130- if child .tag .endswith ("user" ):
131- user = child .text
132- if child .tag .endswith ("attributes" ):
133- for attribute in child :
134- # ElementTree library expands the namespace in
135- # attribute tags to the full URL of the namespace.
136- # We don't care about namespace here and it will always
137- # be encased in curly braces, so we remove them.
138- tag = attribute .tag
139- if "}" in tag :
140- tag = tag .split ("}" )[1 ]
141- attributes [tag ] = attribute .text
142- if user is None :
143- raise Exception ("CAS response does not contain user" )
144- except Exception :
145- logger .exception ("Error parsing CAS response" )
146- raise LoginError (401 , "Invalid CAS response" , errcode = Codes .UNAUTHORIZED )
147- if not success :
148- raise LoginError (
149- 401 , "Unsuccessful CAS response" , errcode = Codes .UNAUTHORIZED
150- )
151- return user , attributes
162+ for child in root [0 ]:
163+ if child .tag .endswith ("user" ):
164+ user = child .text
165+ if child .tag .endswith ("attributes" ):
166+ for attribute in child :
167+ # ElementTree library expands the namespace in
168+ # attribute tags to the full URL of the namespace.
169+ # We don't care about namespace here and it will always
170+ # be encased in curly braces, so we remove them.
171+ tag = attribute .tag
172+ if "}" in tag :
173+ tag = tag .split ("}" )[1 ]
174+ attributes [tag ] = attribute .text
175+
176+ # Ensure a user was found.
177+ if user is None :
178+ raise CasError ("no_user" , "CAS response does not contain user" )
179+
180+ return CasResponse (user , attributes )
152181
153182 def get_redirect_url (self , service_args : Dict [str , str ]) -> str :
154183 """
@@ -201,15 +230,76 @@ async def handle_ticket(
201230 args ["redirectUrl" ] = client_redirect_url
202231 if session :
203232 args ["session" ] = session
204- username , user_display_name = await self ._validate_ticket (ticket , args )
233+
234+ try :
235+ cas_response = await self ._validate_ticket (ticket , args )
236+ except CasError as e :
237+ logger .exception ("Could not validate ticket" )
238+ self ._sso_handler .render_error (request , e .error , e .error_description , 401 )
239+ return
240+
241+ await self ._handle_cas_response (
242+ request , cas_response , client_redirect_url , session
243+ )
244+
245+ async def _handle_cas_response (
246+ self ,
247+ request : SynapseRequest ,
248+ cas_response : CasResponse ,
249+ client_redirect_url : Optional [str ],
250+ session : Optional [str ],
251+ ) -> None :
252+ """Handle a CAS response to a ticket request.
253+
254+ Assumes that the response has been validated. Maps the user onto an MXID,
255+ registering them if necessary, and returns a response to the browser.
256+
257+ Args:
258+ request: the incoming request from the browser. We'll respond to it with an
259+ HTML page or a redirect
260+
261+ cas_response: The parsed CAS response.
262+
263+ client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
264+ This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
265+
266+ session: The session parameter from the `/cas/ticket` HTTP request, if given.
267+ This should be the UI Auth session id.
268+ """
269+
270+ # Ensure that the attributes of the logged in user meet the required
271+ # attributes.
272+ for required_attribute , required_value in self ._cas_required_attributes .items ():
273+ # If required attribute was not in CAS Response - Forbidden
274+ if required_attribute not in cas_response .attributes :
275+ self ._sso_handler .render_error (
276+ request ,
277+ "unauthorised" ,
278+ "You are not authorised to log in here." ,
279+ 401 ,
280+ )
281+ return
282+
283+ # Also need to check value
284+ if required_value is not None :
285+ actual_value = cas_response .attributes [required_attribute ]
286+ # If required attribute value does not match expected - Forbidden
287+ if required_value != actual_value :
288+ self ._sso_handler .render_error (
289+ request ,
290+ "unauthorised" ,
291+ "You are not authorised to log in here." ,
292+ 401 ,
293+ )
294+ return
205295
206296 # Pull out the user-agent and IP from the request.
207297 user_agent = request .get_user_agent ("" )
208298 ip_address = self .hs .get_ip_from_request (request )
209299
210300 # Get the matrix ID from the CAS username.
211301 user_id = await self ._map_cas_user_to_matrix_user (
212- username , user_display_name , user_agent , ip_address
302+ cas_response , user_agent , ip_address
213303 )
214304
215305 if session :
@@ -225,34 +315,31 @@ async def handle_ticket(
225315 )
226316
227317 async def _map_cas_user_to_matrix_user (
228- self ,
229- remote_user_id : str ,
230- display_name : Optional [str ],
231- user_agent : str ,
232- ip_address : str ,
318+ self , cas_response : CasResponse , user_agent : str , ip_address : str ,
233319 ) -> str :
234320 """
235321 Given a CAS username, retrieve the user ID for it and possibly register the user.
236322
237323 Args:
238- remote_user_id: The username from the CAS response.
239- display_name: The display name from the CAS response.
324+ cas_response: The parsed CAS response.
240325 user_agent: The user agent of the client making the request.
241326 ip_address: The IP address of the client making the request.
242327
243328 Returns:
244329 The user ID associated with this response.
245330 """
246331
247- localpart = map_username_to_mxid_localpart (remote_user_id )
332+ localpart = map_username_to_mxid_localpart (cas_response . username )
248333 user_id = UserID (localpart , self ._hostname ).to_string ()
249334 registered_user_id = await self ._auth_handler .check_user_exists (user_id )
250335
336+ displayname = cas_response .attributes .get (self ._cas_displayname_attribute , None )
337+
251338 # If the user does not exist, register it.
252339 if not registered_user_id :
253340 registered_user_id = await self ._registration_handler .register_user (
254341 localpart = localpart ,
255- default_display_name = display_name ,
342+ default_display_name = displayname ,
256343 user_agent_ips = [(user_agent , ip_address )],
257344 )
258345
0 commit comments