Skip to content

Preserve header casing #103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
18 changes: 9 additions & 9 deletions h11/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
# - If someone says Connection: close, we will close
# - If someone uses HTTP/1.0, we will close.
def _keep_alive(event):
connection = get_comma_header(event.headers, b"connection")
connection = get_comma_header(event.headers, b"Connection")
if b"close" in connection:
return False
if getattr(event, "http_version", b"1.1") < b"1.1":
Expand Down Expand Up @@ -85,13 +85,13 @@ def _body_framing(request_method, event):
assert event.status_code >= 200

# Step 2: check for Transfer-Encoding (T-E beats C-L):
transfer_encodings = get_comma_header(event.headers, b"transfer-encoding")
transfer_encodings = get_comma_header(event.headers, b"Transfer-Encoding")
if transfer_encodings:
assert transfer_encodings == [b"chunked"]
return ("chunked", ())

# Step 3: check for Content-Length
content_lengths = get_comma_header(event.headers, b"content-length")
content_lengths = get_comma_header(event.headers, b"Content-Length")
if content_lengths:
return ("content-length", (int(content_lengths[0]),))

Expand Down Expand Up @@ -234,7 +234,7 @@ def _process_event(self, role, event):
if role is CLIENT and type(event) is Request:
if event.method == b"CONNECT":
self._cstate.process_client_switch_proposal(_SWITCH_CONNECT)
if get_comma_header(event.headers, b"upgrade"):
if get_comma_header(event.headers, b"Upgrade"):
self._cstate.process_client_switch_proposal(_SWITCH_UPGRADE)
server_switch_event = None
if role is SERVER:
Expand Down Expand Up @@ -560,27 +560,27 @@ def _clean_up_response_headers_for_sending(self, response):
# but the HTTP spec says that if our peer does this then we have
# to fix it instead of erroring out, so we'll accord the user the
# same respect).
set_comma_header(headers, b"content-length", [])
set_comma_header(headers, b"Content-Length", [])
if self.their_http_version is None or self.their_http_version < b"1.1":
# Either we never got a valid request and are sending back an
# error (their_http_version is None), so we assume the worst;
# or else we did get a valid HTTP/1.0 request, so we know that
# they don't understand chunked encoding.
set_comma_header(headers, b"transfer-encoding", [])
set_comma_header(headers, b"Transfer-Encoding", [])
# This is actually redundant ATM, since currently we
# unconditionally disable keep-alive when talking to HTTP/1.0
# peers. But let's be defensive just in case we add
# Connection: keep-alive support later:
if self._request_method != b"HEAD":
need_close = True
else:
set_comma_header(headers, b"transfer-encoding", ["chunked"])
set_comma_header(headers, b"Transfer-Encoding", ["chunked"])

if not self._cstate.keep_alive or need_close:
# Make sure Connection: close is set
connection = set(get_comma_header(headers, b"connection"))
connection = set(get_comma_header(headers, b"Connection"))
connection.discard(b"keep-alive")
connection.add(b"close")
set_comma_header(headers, b"connection", sorted(connection))
set_comma_header(headers, b"Connection", sorted(connection))

response.headers = headers
41 changes: 28 additions & 13 deletions h11/_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,20 @@
_field_value_re = re.compile(field_value.encode("ascii"))


class Header:
def __init__(self, name, value):
self.raw_name = name
self.name = name.lower()
self.value = value

def __iter__(self):
yield self.name
yield self.value

def __eq__(self, other):
return (self.name, self.value) == other


def normalize_and_validate(headers, _parsed=False):
new_headers = []
saw_content_length = False
Expand All @@ -75,13 +89,13 @@ def normalize_and_validate(headers, _parsed=False):
value = bytesify(value)
validate(_field_name_re, name, "Illegal header name {!r}", name)
validate(_field_value_re, value, "Illegal header value {!r}", value)
name = name.lower()
if name == b"content-length":
header = Header(name, value)
if header.name == b"content-length":
if saw_content_length:
raise LocalProtocolError("multiple Content-Length headers")
validate(_content_length_re, value, "bad Content-Length")
saw_content_length = True
if name == b"transfer-encoding":
if header.name == b"transfer-encoding":
# "A server that receives a request message with a transfer coding
# it does not understand SHOULD respond with 501 (Not
# Implemented)."
Expand All @@ -92,23 +106,21 @@ def normalize_and_validate(headers, _parsed=False):
)
# "All transfer-coding names are case-insensitive"
# -- https://tools.ietf.org/html/rfc7230#section-4
value = value.lower()
if value != b"chunked":
header.value = header.value.lower()
if header.value != b"chunked":
raise LocalProtocolError(
"Only Transfer-Encoding: chunked is supported",
error_status_hint=501,
)
saw_transfer_encoding = True
new_headers.append((name, value))
new_headers.append(header)
return new_headers


def get_comma_header(headers, name):
# Should only be used for headers whose value is a list of
# comma-separated, case-insensitive values.
#
# The header name `name` is expected to be lower-case bytes.
#
# Connection: meets these criteria (including cast insensitivity).
#
# Content-Length: technically is just a single value (1*DIGIT), but the
Expand Down Expand Up @@ -139,6 +151,7 @@ def get_comma_header(headers, name):
# Expect: the only legal value is the literal string
# "100-continue". Splitting on commas is harmless. Case insensitive.
#
name = name.lower()
out = []
for found_name, found_raw_value in headers:
if found_name == name:
Expand All @@ -151,13 +164,15 @@ def get_comma_header(headers, name):


def set_comma_header(headers, name, new_values):
# The header name `name` is expected to be lower-case bytes.
raw_name = name
name = name.lower()

new_headers = []
for found_name, found_raw_value in headers:
if found_name != name:
new_headers.append((found_name, found_raw_value))
for header in headers:
if header.name != name:
new_headers.append((header.raw_name, header.value))
for new_value in new_values:
new_headers.append((name, new_value))
new_headers.append((raw_name, new_value))
headers[:] = normalize_and_validate(new_headers)


Expand Down
12 changes: 6 additions & 6 deletions h11/_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def write_headers(headers, write):
# "Since the Host field-value is critical information for handling a
# request, a user agent SHOULD generate Host as the first header field
# following the request-line." - RFC 7230
for name, value in headers:
if name == b"host":
write(bytesmod(b"%s: %s\r\n", (name, value)))
for name, value in headers:
if name != b"host":
write(bytesmod(b"%s: %s\r\n", (name, value)))
for header in headers:
if header.name == b"host":
write(bytesmod(b"%s: %s\r\n", (header.raw_name, header.value)))
for header in headers:
if header.name != b"host":
write(bytesmod(b"%s: %s\r\n", (header.raw_name, header.value)))
write(b"\r\n")


Expand Down
16 changes: 8 additions & 8 deletions h11/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_Connection_basics_and_content_length():
),
)
assert data == (
b"GET / HTTP/1.1\r\n" b"host: example.com\r\n" b"content-length: 10\r\n\r\n"
b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 10\r\n\r\n"
)

for conn in p.conns:
Expand All @@ -113,7 +113,7 @@ def test_Connection_basics_and_content_length():
assert data == b"HTTP/1.1 100 \r\n\r\n"

data = p.send(SERVER, Response(status_code=200, headers=[("Content-Length", "11")]))
assert data == b"HTTP/1.1 200 \r\ncontent-length: 11\r\n\r\n"
assert data == b"HTTP/1.1 200 \r\nContent-Length: 11\r\n\r\n"

for conn in p.conns:
assert conn.states == {CLIENT: SEND_BODY, SERVER: SEND_BODY}
Expand Down Expand Up @@ -243,7 +243,7 @@ def test_server_talking_to_http10_client():
# We automatically Connection: close back at them
assert (
c.send(Response(status_code=200, headers=[]))
== b"HTTP/1.1 200 \r\nconnection: close\r\n\r\n"
== b"HTTP/1.1 200 \r\nConnection: close\r\n\r\n"
)

assert c.send(Data(data=b"12345")) == b"12345"
Expand Down Expand Up @@ -303,7 +303,7 @@ def test_automatic_transfer_encoding_in_response():
receive_and_get(c, b"GET / HTTP/1.0\r\n\r\n")
assert (
c.send(Response(status_code=200, headers=user_headers))
== b"HTTP/1.1 200 \r\nconnection: close\r\n\r\n"
== b"HTTP/1.1 200 \r\nConnection: close\r\n\r\n"
)
assert c.send(Data(data=b"12345")) == b"12345"

Expand Down Expand Up @@ -876,7 +876,7 @@ def test_errors():
if role is SERVER:
assert (
c.send(Response(status_code=400, headers=[]))
== b"HTTP/1.1 400 \r\nconnection: close\r\n\r\n"
== b"HTTP/1.1 400 \r\nConnection: close\r\n\r\n"
)

# After an error sending, you can no longer send
Expand Down Expand Up @@ -988,14 +988,14 @@ def setup(method, http_version):
c = setup(method, b"1.1")
assert (
c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n"
b"transfer-encoding: chunked\r\n\r\n"
b"Transfer-Encoding: chunked\r\n\r\n"
)

# No Content-Length, HTTP/1.0 peer, frame with connection: close
c = setup(method, b"1.0")
assert (
c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n"
b"connection: close\r\n\r\n"
b"Connection: close\r\n\r\n"
)

# Content-Length + Transfer-Encoding, TE wins
Expand All @@ -1011,7 +1011,7 @@ def setup(method, http_version):
)
)
== b"HTTP/1.1 200 \r\n"
b"transfer-encoding: chunked\r\n\r\n"
b"Transfer-Encoding: chunked\r\n\r\n"
)


Expand Down
10 changes: 5 additions & 5 deletions h11/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
target="/a",
headers=[("Host", "foo"), ("Connection", "close")],
),
b"GET /a HTTP/1.1\r\nhost: foo\r\nconnection: close\r\n\r\n",
b"GET /a HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n",
),
(
(SERVER, SEND_RESPONSE),
Response(status_code=200, headers=[("Connection", "close")], reason=b"OK"),
b"HTTP/1.1 200 OK\r\nconnection: close\r\n\r\n",
b"HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n",
),
(
(SERVER, SEND_RESPONSE),
Expand All @@ -48,7 +48,7 @@
InformationalResponse(
status_code=101, headers=[("Upgrade", "websocket")], reason=b"Upgrade"
),
b"HTTP/1.1 101 Upgrade\r\nupgrade: websocket\r\n\r\n",
b"HTTP/1.1 101 Upgrade\r\nUpgrade: websocket\r\n\r\n",
),
(
(SERVER, SEND_RESPONSE),
Expand Down Expand Up @@ -435,7 +435,7 @@ def test_ChunkedWriter():

assert (
dowrite(w, EndOfMessage(headers=[("Etag", "asdf"), ("a", "b")]))
== b"0\r\netag: asdf\r\na: b\r\n\r\n"
== b"0\r\nEtag: asdf\r\na: b\r\n\r\n"
)


Expand Down Expand Up @@ -503,5 +503,5 @@ def test_host_comes_first():
tw(
write_headers,
normalize_and_validate([("foo", "bar"), ("Host", "example.com")]),
b"host: example.com\r\nfoo: bar\r\n\r\n",
b"Host: example.com\r\nfoo: bar\r\n\r\n",
)