Skip to content

Preserve header casing #103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
20 changes: 10 additions & 10 deletions h11/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
# - If someone says Connection: close, we will close
# - If someone uses HTTP/1.0, we will close.
def _keep_alive(event):
connection = get_comma_header(event.headers, b"connection")
connection = get_comma_header(event.headers, b"Connection")
if b"close" in connection:
return False
if getattr(event, "http_version", b"1.1") < b"1.1":
Expand Down Expand Up @@ -85,13 +85,13 @@ def _body_framing(request_method, event):
assert event.status_code >= 200

# Step 2: check for Transfer-Encoding (T-E beats C-L):
transfer_encodings = get_comma_header(event.headers, b"transfer-encoding")
transfer_encodings = get_comma_header(event.headers, b"Transfer-Encoding")
if transfer_encodings:
assert transfer_encodings == [b"chunked"]
return ("chunked", ())

# Step 3: check for Content-Length
content_lengths = get_comma_header(event.headers, b"content-length")
content_lengths = get_comma_header(event.headers, b"Content-Length")
if content_lengths:
return ("content-length", (int(content_lengths[0]),))

Expand Down Expand Up @@ -234,7 +234,7 @@ def _process_event(self, role, event):
if role is CLIENT and type(event) is Request:
if event.method == b"CONNECT":
self._cstate.process_client_switch_proposal(_SWITCH_CONNECT)
if get_comma_header(event.headers, b"upgrade"):
if get_comma_header(event.headers, b"Upgrade"):
self._cstate.process_client_switch_proposal(_SWITCH_UPGRADE)
server_switch_event = None
if role is SERVER:
Expand Down Expand Up @@ -534,7 +534,7 @@ def send_failed(self):
def _clean_up_response_headers_for_sending(self, response):
assert type(response) is Response

headers = list(response.headers)
headers = response.headers
need_close = False

# HEAD requests need some special handling: they always act like they
Expand All @@ -560,27 +560,27 @@ def _clean_up_response_headers_for_sending(self, response):
# but the HTTP spec says that if our peer does this then we have
# to fix it instead of erroring out, so we'll accord the user the
# same respect).
set_comma_header(headers, b"content-length", [])
headers = set_comma_header(headers, b"Content-Length", [])
if self.their_http_version is None or self.their_http_version < b"1.1":
# Either we never got a valid request and are sending back an
# error (their_http_version is None), so we assume the worst;
# or else we did get a valid HTTP/1.0 request, so we know that
# they don't understand chunked encoding.
set_comma_header(headers, b"transfer-encoding", [])
headers = set_comma_header(headers, b"Transfer-Encoding", [])
# This is actually redundant ATM, since currently we
# unconditionally disable keep-alive when talking to HTTP/1.0
# peers. But let's be defensive just in case we add
# Connection: keep-alive support later:
if self._request_method != b"HEAD":
need_close = True
else:
set_comma_header(headers, b"transfer-encoding", ["chunked"])
headers = set_comma_header(headers, b"Transfer-Encoding", ["chunked"])

if not self._cstate.keep_alive or need_close:
# Make sure Connection: close is set
connection = set(get_comma_header(headers, b"connection"))
connection = set(get_comma_header(headers, b"Connection"))
connection.discard(b"keep-alive")
connection.add(b"close")
set_comma_header(headers, b"connection", sorted(connection))
headers = set_comma_header(headers, b"Connection", sorted(connection))

response.headers = headers
48 changes: 39 additions & 9 deletions h11/_headers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
try:
from collections.abc import Sequence
except ImportError:
# Python 2.7 support
from collections import Sequence

import re

from ._abnf import field_name, field_value
Expand Down Expand Up @@ -62,6 +68,28 @@
_field_value_re = re.compile(field_value.encode("ascii"))


class Headers(Sequence):
def __init__(self, items):
self._items = items

def __getitem__(self, item):
_, _, value = self._items[item]
return value

def __len__(self):
return len(self._items)

def __iter__(self):
for name, _, value in self._items:
yield name, value

def __eq__(self, other):
return list(self) == other

def raw(self):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious about this. It seems it would return a list of the triplet where elsewhere raw implies the name that came over the wire. This should be clearer about what it is returning

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also wonder about the wisdom of having a headers sequence without structured single-header data

Copy link
Contributor Author

@tomchristie tomchristie Aug 29, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also wonder about the wisdom of having a headers sequence without structured single-header data

I'm not quite sure what you mean here. Are you saying "If you're returning a three-tuple from this interface then let's have it use a named-tuple" or something else?

Perhaps a marginally different interface for us to expose here would not be .raw() -> (<lowercase name>, <raw name>, <value>), but instead expose just .raw_items() -> (<raw name>, <value>)

Perhaps that'd address the naming/intent slightly better?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking of having a class that encapsulates a Header and then having the collection of Headers use that. A namedtuple could work fine as well. All that said, I get that tuples may be a smidge faster and that these are internal implementation details. Speaking from having worked on header collection objects in urllib3 in the past, these tuples can drive maintainers to pull out their hair (as well as future folks trying to update/extend the behaviour).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, for something that remains compatible with the existing API I think the options here are...

  • A custom Headers sequence, that exposes the extra information in a .raw_items() interface or similar, that returns a two-tuple of (case-sensitive-name, value) for usages that require the raw casing info.
  • A custom Headers sequence, the returns Header instances, that can iterate as two-tuples, but also expose .name, .case_sensitive_name and .value attributes, which are available for usages that require the raw casing info.

Or some variation on those. (Eg. this PR which currently has .raw() returning the three-tuple of info.)
Personally I'm fairly agnostic, as both the above options seem reasonable enough. The Header case has the most extra overhead, since it creates and accesses a per-header instance rather than the plain tuple, while I wouldn't expect the .raw_items() approach to introduce anything really noticeable, but I could run through some timings on each of the options to help better inform our options.

return list(self._items)


def normalize_and_validate(headers, _parsed=False):
new_headers = []
saw_content_length = False
Expand All @@ -75,6 +103,7 @@ def normalize_and_validate(headers, _parsed=False):
value = bytesify(value)
validate(_field_name_re, name, "Illegal header name {!r}", name)
validate(_field_value_re, value, "Illegal header value {!r}", value)
raw_name = name
name = name.lower()
if name == b"content-length":
if saw_content_length:
Expand All @@ -99,16 +128,14 @@ def normalize_and_validate(headers, _parsed=False):
error_status_hint=501,
)
saw_transfer_encoding = True
new_headers.append((name, value))
return new_headers
new_headers.append((name, raw_name, value))
return Headers(new_headers)


def get_comma_header(headers, name):
# Should only be used for headers whose value is a list of
# comma-separated, case-insensitive values.
#
# The header name `name` is expected to be lower-case bytes.
#
# Connection: meets these criteria (including cast insensitivity).
#
# Content-Length: technically is just a single value (1*DIGIT), but the
Expand Down Expand Up @@ -139,6 +166,7 @@ def get_comma_header(headers, name):
# Expect: the only legal value is the literal string
# "100-continue". Splitting on commas is harmless. Case insensitive.
#
name = name.lower()
out = []
for found_name, found_raw_value in headers:
if found_name == name:
Expand All @@ -151,14 +179,16 @@ def get_comma_header(headers, name):


def set_comma_header(headers, name, new_values):
# The header name `name` is expected to be lower-case bytes.
raw_name = name
name = name.lower()

new_headers = []
for found_name, found_raw_value in headers:
for found_name, found_raw_name, found_raw_value in headers.raw():
if found_name != name:
new_headers.append((found_name, found_raw_value))
new_headers.append((found_raw_name, found_raw_value))
for new_value in new_values:
new_headers.append((name, new_value))
headers[:] = normalize_and_validate(new_headers)
new_headers.append((raw_name, new_value))
return normalize_and_validate(new_headers)


def has_expect_100_continue(request):
Expand Down
10 changes: 6 additions & 4 deletions h11/_writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ def write_headers(headers, write):
# "Since the Host field-value is critical information for handling a
# request, a user agent SHOULD generate Host as the first header field
# following the request-line." - RFC 7230
for name, value in headers:
raw_headers = headers.raw()

for name, raw_name, value in raw_headers:
if name == b"host":
write(bytesmod(b"%s: %s\r\n", (name, value)))
for name, value in headers:
write(bytesmod(b"%s: %s\r\n", (raw_name, value)))
for name, raw_name, value in raw_headers:
if name != b"host":
write(bytesmod(b"%s: %s\r\n", (name, value)))
write(bytesmod(b"%s: %s\r\n", (raw_name, value)))
write(b"\r\n")


Expand Down
16 changes: 8 additions & 8 deletions h11/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_Connection_basics_and_content_length():
),
)
assert data == (
b"GET / HTTP/1.1\r\n" b"host: example.com\r\n" b"content-length: 10\r\n\r\n"
b"GET / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 10\r\n\r\n"
)

for conn in p.conns:
Expand All @@ -113,7 +113,7 @@ def test_Connection_basics_and_content_length():
assert data == b"HTTP/1.1 100 \r\n\r\n"

data = p.send(SERVER, Response(status_code=200, headers=[("Content-Length", "11")]))
assert data == b"HTTP/1.1 200 \r\ncontent-length: 11\r\n\r\n"
assert data == b"HTTP/1.1 200 \r\nContent-Length: 11\r\n\r\n"

for conn in p.conns:
assert conn.states == {CLIENT: SEND_BODY, SERVER: SEND_BODY}
Expand Down Expand Up @@ -243,7 +243,7 @@ def test_server_talking_to_http10_client():
# We automatically Connection: close back at them
assert (
c.send(Response(status_code=200, headers=[]))
== b"HTTP/1.1 200 \r\nconnection: close\r\n\r\n"
== b"HTTP/1.1 200 \r\nConnection: close\r\n\r\n"
)

assert c.send(Data(data=b"12345")) == b"12345"
Expand Down Expand Up @@ -303,7 +303,7 @@ def test_automatic_transfer_encoding_in_response():
receive_and_get(c, b"GET / HTTP/1.0\r\n\r\n")
assert (
c.send(Response(status_code=200, headers=user_headers))
== b"HTTP/1.1 200 \r\nconnection: close\r\n\r\n"
== b"HTTP/1.1 200 \r\nConnection: close\r\n\r\n"
)
assert c.send(Data(data=b"12345")) == b"12345"

Expand Down Expand Up @@ -876,7 +876,7 @@ def test_errors():
if role is SERVER:
assert (
c.send(Response(status_code=400, headers=[]))
== b"HTTP/1.1 400 \r\nconnection: close\r\n\r\n"
== b"HTTP/1.1 400 \r\nConnection: close\r\n\r\n"
)

# After an error sending, you can no longer send
Expand Down Expand Up @@ -988,14 +988,14 @@ def setup(method, http_version):
c = setup(method, b"1.1")
assert (
c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n"
b"transfer-encoding: chunked\r\n\r\n"
b"Transfer-Encoding: chunked\r\n\r\n"
)

# No Content-Length, HTTP/1.0 peer, frame with connection: close
c = setup(method, b"1.0")
assert (
c.send(Response(status_code=200, headers=[])) == b"HTTP/1.1 200 \r\n"
b"connection: close\r\n\r\n"
b"Connection: close\r\n\r\n"
)

# Content-Length + Transfer-Encoding, TE wins
Expand All @@ -1011,7 +1011,7 @@ def setup(method, http_version):
)
)
== b"HTTP/1.1 200 \r\n"
b"transfer-encoding: chunked\r\n\r\n"
b"Transfer-Encoding: chunked\r\n\r\n"
)


Expand Down
4 changes: 2 additions & 2 deletions h11/tests/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_get_set_comma_header():

assert get_comma_header(headers, b"connection") == [b"close", b"foo", b"bar"]

set_comma_header(headers, b"newthing", ["a", "b"])
headers = set_comma_header(headers, b"newthing", ["a", "b"])

with pytest.raises(LocalProtocolError):
set_comma_header(headers, b"newthing", [" a", "b"])
Expand All @@ -96,7 +96,7 @@ def test_get_set_comma_header():
(b"newthing", b"b"),
]

set_comma_header(headers, b"whatever", ["different thing"])
headers = set_comma_header(headers, b"whatever", ["different thing"])

assert headers == [
(b"connection", b"close"),
Expand Down
12 changes: 6 additions & 6 deletions h11/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
target="/a",
headers=[("Host", "foo"), ("Connection", "close")],
),
b"GET /a HTTP/1.1\r\nhost: foo\r\nconnection: close\r\n\r\n",
b"GET /a HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n",
),
(
(SERVER, SEND_RESPONSE),
Response(status_code=200, headers=[("Connection", "close")], reason=b"OK"),
b"HTTP/1.1 200 OK\r\nconnection: close\r\n\r\n",
b"HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n",
),
(
(SERVER, SEND_RESPONSE),
Expand All @@ -48,7 +48,7 @@
InformationalResponse(
status_code=101, headers=[("Upgrade", "websocket")], reason=b"Upgrade"
),
b"HTTP/1.1 101 Upgrade\r\nupgrade: websocket\r\n\r\n",
b"HTTP/1.1 101 Upgrade\r\nUpgrade: websocket\r\n\r\n",
),
(
(SERVER, SEND_RESPONSE),
Expand Down Expand Up @@ -121,7 +121,7 @@ def test_writers_unusual():
normalize_and_validate([("foo", "bar"), ("baz", "quux")]),
b"foo: bar\r\nbaz: quux\r\n\r\n",
)
tw(write_headers, [], b"\r\n")
tw(write_headers, normalize_and_validate([]), b"\r\n")

# We understand HTTP/1.0, but we don't speak it
with pytest.raises(LocalProtocolError):
Expand Down Expand Up @@ -435,7 +435,7 @@ def test_ChunkedWriter():

assert (
dowrite(w, EndOfMessage(headers=[("Etag", "asdf"), ("a", "b")]))
== b"0\r\netag: asdf\r\na: b\r\n\r\n"
== b"0\r\nEtag: asdf\r\na: b\r\n\r\n"
)


Expand Down Expand Up @@ -503,5 +503,5 @@ def test_host_comes_first():
tw(
write_headers,
normalize_and_validate([("foo", "bar"), ("Host", "example.com")]),
b"host: example.com\r\nfoo: bar\r\n\r\n",
b"Host: example.com\r\nfoo: bar\r\n\r\n",
)