diff --git a/h11/_connection.py b/h11/_connection.py index fc6289a..74eca39 100644 --- a/h11/_connection.py +++ b/h11/_connection.py @@ -45,7 +45,7 @@ # - If someone says Connection: close, we will close # - If someone uses HTTP/1.0, we will close. def _keep_alive(event): - connection = get_comma_header(event.headers, b"connection") + connection = get_comma_header(event.headers, b"Connection") if b"close" in connection: return False if getattr(event, "http_version", b"1.1") < b"1.1": @@ -85,13 +85,13 @@ def _body_framing(request_method, event): assert event.status_code >= 200 # Step 2: check for Transfer-Encoding (T-E beats C-L): - transfer_encodings = get_comma_header(event.headers, b"transfer-encoding") + transfer_encodings = get_comma_header(event.headers, b"Transfer-Encoding") if transfer_encodings: assert transfer_encodings == [b"chunked"] return ("chunked", ()) # Step 3: check for Content-Length - content_lengths = get_comma_header(event.headers, b"content-length") + content_lengths = get_comma_header(event.headers, b"Content-Length") if content_lengths: return ("content-length", (int(content_lengths[0]),)) @@ -234,7 +234,7 @@ def _process_event(self, role, event): if role is CLIENT and type(event) is Request: if event.method == b"CONNECT": self._cstate.process_client_switch_proposal(_SWITCH_CONNECT) - if get_comma_header(event.headers, b"upgrade"): + if get_comma_header(event.headers, b"Upgrade"): self._cstate.process_client_switch_proposal(_SWITCH_UPGRADE) server_switch_event = None if role is SERVER: @@ -560,13 +560,13 @@ def _clean_up_response_headers_for_sending(self, response): # but the HTTP spec says that if our peer does this then we have # to fix it instead of erroring out, so we'll accord the user the # same respect). - set_comma_header(headers, b"content-length", []) + set_comma_header(headers, b"Content-Length", []) if self.their_http_version is None or self.their_http_version < b"1.1": # Either we never got a valid request and are sending back an # error (their_http_version is None), so we assume the worst; # or else we did get a valid HTTP/1.0 request, so we know that # they don't understand chunked encoding. - set_comma_header(headers, b"transfer-encoding", []) + set_comma_header(headers, b"Transfer-Encoding", []) # This is actually redundant ATM, since currently we # unconditionally disable keep-alive when talking to HTTP/1.0 # peers. But let's be defensive just in case we add @@ -574,13 +574,13 @@ def _clean_up_response_headers_for_sending(self, response): if self._request_method != b"HEAD": need_close = True else: - set_comma_header(headers, b"transfer-encoding", ["chunked"]) + set_comma_header(headers, b"Transfer-Encoding", ["chunked"]) if not self._cstate.keep_alive or need_close: # Make sure Connection: close is set - connection = set(get_comma_header(headers, b"connection")) + connection = set(get_comma_header(headers, b"Connection")) connection.discard(b"keep-alive") connection.add(b"close") - set_comma_header(headers, b"connection", sorted(connection)) + set_comma_header(headers, b"Connection", sorted(connection)) response.headers = headers diff --git a/h11/_headers.py b/h11/_headers.py index 878f63c..ba83008 100644 --- a/h11/_headers.py +++ b/h11/_headers.py @@ -62,6 +62,20 @@ _field_value_re = re.compile(field_value.encode("ascii")) +class Header: + def __init__(self, name, value): + self.raw_name = name + self.name = name.lower() + self.value = value + + def __iter__(self): + yield self.name + yield self.value + + def __eq__(self, other): + return (self.name, self.value) == other + + def normalize_and_validate(headers, _parsed=False): new_headers = [] saw_content_length = False @@ -75,13 +89,13 @@ def normalize_and_validate(headers, _parsed=False): value = bytesify(value) validate(_field_name_re, name, "Illegal header name {!r}", name) validate(_field_value_re, value, "Illegal header value {!r}", value) - name = name.lower() - if name == b"content-length": + header = Header(name, value) + if header.name == b"content-length": if saw_content_length: raise LocalProtocolError("multiple Content-Length headers") validate(_content_length_re, value, "bad Content-Length") saw_content_length = True - if name == b"transfer-encoding": + if header.name == b"transfer-encoding": # "A server that receives a request message with a transfer coding # it does not understand SHOULD respond with 501 (Not # Implemented)." @@ -92,14 +106,14 @@ def normalize_and_validate(headers, _parsed=False): ) # "All transfer-coding names are case-insensitive" # -- https://tools.ietf.org/html/rfc7230#section-4 - value = value.lower() - if value != b"chunked": + header.value = header.value.lower() + if header.value != b"chunked": raise LocalProtocolError( "Only Transfer-Encoding: chunked is supported", error_status_hint=501, ) saw_transfer_encoding = True - new_headers.append((name, value)) + new_headers.append(header) return new_headers @@ -107,8 +121,6 @@ def get_comma_header(headers, name): # Should only be used for headers whose value is a list of # comma-separated, case-insensitive values. # - # The header name `name` is expected to be lower-case bytes. - # # Connection: meets these criteria (including cast insensitivity). # # Content-Length: technically is just a single value (1*DIGIT), but the @@ -139,6 +151,7 @@ def get_comma_header(headers, name): # Expect: the only legal value is the literal string # "100-continue". Splitting on commas is harmless. Case insensitive. # + name = name.lower() out = [] for found_name, found_raw_value in headers: if found_name == name: @@ -151,13 +164,15 @@ def get_comma_header(headers, name): def set_comma_header(headers, name, new_values): - # The header name `name` is expected to be lower-case bytes. + raw_name = name + name = name.lower() + new_headers = [] - for found_name, found_raw_value in headers: - if found_name != name: - new_headers.append((found_name, found_raw_value)) + for header in headers: + if header.name != name: + new_headers.append((header.raw_name, header.value)) for new_value in new_values: - new_headers.append((name, new_value)) + new_headers.append((raw_name, new_value)) headers[:] = normalize_and_validate(new_headers) diff --git a/h11/_writers.py b/h11/_writers.py index 6a41100..2c5b79b 100644 --- a/h11/_writers.py +++ b/h11/_writers.py @@ -38,12 +38,12 @@ def write_headers(headers, write): # "Since the Host field-value is critical information for handling a # request, a user agent SHOULD generate Host as the first header field # following the request-line." - RFC 7230 - for name, value in headers: - if name == b"host": - write(bytesmod(b"%s: %s\r\n", (name, value))) - for name, value in headers: - if name != b"host": - write(bytesmod(b"%s: %s\r\n", (name, value))) + for header in headers: + if header.name == b"host": + write(bytesmod(b"%s: %s\r\n", (header.raw_name, header.value))) + for header in headers: + if header.name != b"host": + write(bytesmod(b"%s: %s\r\n", (header.raw_name, header.value))) write(b"\r\n") diff --git a/h11/tests/test_connection.py b/h11/tests/test_connection.py index 13e6e2d..a43113e 100644 --- a/h11/tests/test_connection.py +++ b/h11/tests/test_connection.py @@ -96,7 +96,7 @@ def test_Connection_basics_and_content_length(): ), ) assert data == ( - b"GET / HTTP/1.1\r\n" b"host: example.com\r\n" b"content-length: 10\r\n\r\n" + b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 10\r\n\r\n" ) for conn in p.conns: @@ -113,7 +113,7 @@ def test_Connection_basics_and_content_length(): assert data == b"HTTP/1.1 100 \r\n\r\n" data = p.send(SERVER, Response(status_code=200, headers=[("Content-Length", "11")])) - assert data == b"HTTP/1.1 200 \r\ncontent-length: 11\r\n\r\n" + assert data == b"HTTP/1.1 200 \r\nContent-Length: 11\r\n\r\n" for conn in p.conns: assert conn.states == {CLIENT: SEND_BODY, SERVER: SEND_BODY} @@ -243,7 +243,7 @@ def test_server_talking_to_http10_client(): # We automatically Connection: close back at them assert ( c.send(Response(status_code=200, headers=[])) - == b"HTTP/1.1 200 \r\nconnection: close\r\n\r\n" + == b"HTTP/1.1 200 \r\nConnection: close\r\n\r\n" ) assert c.send(Data(data=b"12345")) == b"12345" @@ -303,7 +303,7 @@ def test_automatic_transfer_encoding_in_response(): receive_and_get(c, b"GET / HTTP/1.0\r\n\r\n") assert ( c.send(Response(status_code=200, headers=user_headers)) - == b"HTTP/1.1 200 \r\nconnection: close\r\n\r\n" + == b"HTTP/1.1 200 \r\nConnection: close\r\n\r\n" ) assert c.send(Data(data=b"12345")) == b"12345" @@ -876,7 +876,7 @@ def test_errors(): if role is SERVER: assert ( c.send(Response(status_code=400, headers=[])) - == b"HTTP/1.1 400 \r\nconnection: close\r\n\r\n" + == b"HTTP/1.1 400 \r\nConnection: close\r\n\r\n" ) # After an error sending, you can no longer send @@ -988,14 +988,14 @@ def setup(method, http_version): c = setup(method, b"1.1") assert ( c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n" - b"transfer-encoding: chunked\r\n\r\n" + b"Transfer-Encoding: chunked\r\n\r\n" ) # No Content-Length, HTTP/1.0 peer, frame with connection: close c = setup(method, b"1.0") assert ( c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n" - b"connection: close\r\n\r\n" + b"Connection: close\r\n\r\n" ) # Content-Length + Transfer-Encoding, TE wins @@ -1011,7 +1011,7 @@ def setup(method, http_version): ) ) == b"HTTP/1.1 200 \r\n" - b"transfer-encoding: chunked\r\n\r\n" + b"Transfer-Encoding: chunked\r\n\r\n" ) diff --git a/h11/tests/test_io.py b/h11/tests/test_io.py index ef5e31b..e72d0b2 100644 --- a/h11/tests/test_io.py +++ b/h11/tests/test_io.py @@ -31,12 +31,12 @@ target="/a", headers=[("Host", "foo"), ("Connection", "close")], ), - b"GET /a HTTP/1.1\r\nhost: foo\r\nconnection: close\r\n\r\n", + b"GET /a HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", ), ( (SERVER, SEND_RESPONSE), Response(status_code=200, headers=[("Connection", "close")], reason=b"OK"), - b"HTTP/1.1 200 OK\r\nconnection: close\r\n\r\n", + b"HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n", ), ( (SERVER, SEND_RESPONSE), @@ -48,7 +48,7 @@ InformationalResponse( status_code=101, headers=[("Upgrade", "websocket")], reason=b"Upgrade" ), - b"HTTP/1.1 101 Upgrade\r\nupgrade: websocket\r\n\r\n", + b"HTTP/1.1 101 Upgrade\r\nUpgrade: websocket\r\n\r\n", ), ( (SERVER, SEND_RESPONSE), @@ -435,7 +435,7 @@ def test_ChunkedWriter(): assert ( dowrite(w, EndOfMessage(headers=[("Etag", "asdf"), ("a", "b")])) - == b"0\r\netag: asdf\r\na: b\r\n\r\n" + == b"0\r\nEtag: asdf\r\na: b\r\n\r\n" ) @@ -503,5 +503,5 @@ def test_host_comes_first(): tw( write_headers, normalize_and_validate([("foo", "bar"), ("Host", "example.com")]), - b"host: example.com\r\nfoo: bar\r\n\r\n", + b"Host: example.com\r\nfoo: bar\r\n\r\n", )