Skip to content

Immutable QueryParams #1600

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def _merge_queryparams(
"""
if params or self.params:
merged_queryparams = QueryParams(self.params)
merged_queryparams.update(params)
merged_queryparams = merged_queryparams.merge(params)
return merged_queryparams
return params

Expand Down
172 changes: 144 additions & 28 deletions httpx/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import urllib.request
from collections.abc import MutableMapping
from http.cookiejar import Cookie, CookieJar
from urllib.parse import parse_qsl, quote, unquote, urlencode
from urllib.parse import parse_qs, quote, unquote, urlencode

import idna
import rfc3986
Expand Down Expand Up @@ -48,7 +48,6 @@
URLTypes,
)
from ._utils import (
flatten_queryparams,
guess_json_utf,
is_known_encoding,
normalize_header_key,
Expand Down Expand Up @@ -148,8 +147,7 @@ def __init__(
# Add any query parameters, merging with any in the URL if needed.
if params:
if self._uri_reference.query:
url_params = QueryParams(self._uri_reference.query)
url_params.update(params)
url_params = QueryParams(self._uri_reference.query).merge(params)
query_string = str(url_params)
else:
query_string = str(QueryParams(params))
Expand Down Expand Up @@ -450,7 +448,7 @@ def join(self, url: URLTypes) -> "URL":

url = httpx.URL("https://www.example.com/test")
url = url.join("/new/path")
assert url == "https://www.example.com/test/new/path"
assert url == "https://www.example.com/new/path"
"""
if self.is_relative_url:
# Workaround to handle relative URLs, which otherwise raise
Expand Down Expand Up @@ -504,38 +502,79 @@ def __init__(self, *args: QueryParamTypes, **kwargs: typing.Any) -> None:
items: typing.Sequence[typing.Tuple[str, PrimitiveData]]
if value is None or isinstance(value, (str, bytes)):
value = value.decode("ascii") if isinstance(value, bytes) else value
items = parse_qsl(value)
self._dict = parse_qs(value)
elif isinstance(value, QueryParams):
items = value.multi_items()
elif isinstance(value, (list, tuple)):
items = value
self._dict = {k: list(v) for k, v in value._dict.items()}
else:
items = flatten_queryparams(value)

self._dict: typing.Dict[str, typing.List[str]] = {}
for item in items:
k, v = item
if str(k) not in self._dict:
self._dict[str(k)] = [primitive_value_to_str(v)]
dict_value: typing.Dict[typing.Any, typing.List[typing.Any]] = {}
if isinstance(value, (list, tuple)):
# Convert list inputs like:
# [("a", "123"), ("a", "456"), ("b", "789")]
# To a dict representation, like:
# {"a": ["123", "456"], "b": ["789"]}
for item in value:
dict_value.setdefault(item[0], []).append(item[1])
else:
self._dict[str(k)].append(primitive_value_to_str(v))
# Convert dict inputs like:
# {"a": "123", "b": ["456", "789"]}
# To dict inputs where values are always lists, like:
# {"a": ["123"], "b": ["456", "789"]}
dict_value = {
k: list(v) if isinstance(v, (list, tuple)) else [v]
for k, v in value.items()
}

# Ensure that keys and values are neatly coerced to strings.
# We coerce values `True` and `False` to JSON-like "true" and "false"
# representations, and coerce `None` values to the empty string.
self._dict = {
str(k): [primitive_value_to_str(item) for item in v]
for k, v in dict_value.items()
}

def keys(self) -> typing.KeysView:
"""
Return all the keys in the query params.

Usage:

q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.keys()) == ["a", "b"]
"""
return self._dict.keys()

def values(self) -> typing.ValuesView:
"""
Return all the values in the query params. If a key occurs more than once
only the first item for that key is returned.

Usage:

q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.values()) == ["123", "789"]
"""
return {k: v[0] for k, v in self._dict.items()}.values()

def items(self) -> typing.ItemsView:
"""
Return all items in the query params. If a key occurs more than once
only the first item for that key is returned.

Usage:

q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.items()) == [("a", "123"), ("b", "789")]
"""
return {k: v[0] for k, v in self._dict.items()}.items()

def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
"""
Return all items in the query params. Allow duplicate keys to occur.

Usage:

q = httpx.QueryParams("a=123&a=456&b=789")
assert list(q.multi_items()) == [("a", "123"), ("a", "456"), ("b", "789")]
"""
multi_items: typing.List[typing.Tuple[str, str]] = []
for k, v in self._dict.items():
Expand All @@ -546,31 +585,93 @@ def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
"""
Get a value from the query param for a given key. If the key occurs
more than once, then only the first value is returned.

Usage:

q = httpx.QueryParams("a=123&a=456&b=789")
assert q.get("a") == "123"
"""
if key in self._dict:
return self._dict[key][0]
return self._dict[str(key)][0]
return default

def get_list(self, key: typing.Any) -> typing.List[str]:
"""
Get all values from the query param for a given key.

Usage:

q = httpx.QueryParams("a=123&a=456&b=789")
assert q.get_list("a") == ["123", "456"]
"""
return list(self._dict.get(key, []))
return list(self._dict.get(str(key), []))

def update(self, params: QueryParamTypes = None) -> None:
if not params:
return
def set(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
"""
Return a new QueryParams instance, setting the value of a key.

Usage:

q = httpx.QueryParams("a=123")
q = q.set("a", "456")
assert q == httpx.QueryParams("a=456")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict[str(key)] = [primitive_value_to_str(value)]
return q

def add(self, key: typing.Any, value: typing.Any = None) -> "QueryParams":
"""
Return a new QueryParams instance, setting or appending the value of a key.

params = QueryParams(params)
for k in params.keys():
self._dict[k] = params.get_list(k)
Usage:

q = httpx.QueryParams("a=123")
q = q.add("a", "456")
assert q == httpx.QueryParams("a=123&a=456")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict[str(key)] = q.get_list(key) + [primitive_value_to_str(value)]
return q

def remove(self, key: typing.Any) -> "QueryParams":
"""
Return a new QueryParams instance, removing the value of a key.

Usage:

q = httpx.QueryParams("a=123")
q = q.remove("a")
assert q == httpx.QueryParams("")
"""
q = QueryParams()
q._dict = dict(self._dict)
q._dict.pop(str(key), None)
return q

def merge(self, params: QueryParamTypes = None) -> "QueryParams":
"""
Return a new QueryParams instance, updated with.

Usage:

q = httpx.QueryParams("a=123")
q = q.merge({"b": "456"})
assert q == httpx.QueryParams("a=123&b=456")

q = httpx.QueryParams("a=123")
q = q.merge({"a": "456", "b": "789"})
assert q == httpx.QueryParams("a=456&b=789")
"""
q = QueryParams(params)
q._dict = {**self._dict, **q._dict}
return q

def __getitem__(self, key: typing.Any) -> str:
return self._dict[key][0]

def __setitem__(self, key: str, value: str) -> None:
self._dict[key] = [value]

def __contains__(self, key: typing.Any) -> bool:
return key in self._dict

Expand All @@ -580,6 +681,9 @@ def __iter__(self) -> typing.Iterator[typing.Any]:
def __len__(self) -> int:
return len(self._dict)

def __hash__(self) -> int:
return hash(str(self))

def __eq__(self, other: typing.Any) -> bool:
if not isinstance(other, self.__class__):
return False
Expand All @@ -593,6 +697,18 @@ def __repr__(self) -> str:
query_string = str(self)
return f"{class_name}({query_string!r})"

def update(self, params: QueryParamTypes = None) -> None:
raise RuntimeError(
"QueryParams are immutable since 0.18.0. "
"Use `q = q.merge(...)` to create an updated copy."
)

def __setitem__(self, key: str, value: str) -> None:
raise RuntimeError(
"QueryParams are immutable since 0.18.0. "
"Use `q = q.set(key, value)` to create an updated copy."
)


class Headers(typing.MutableMapping[str, str]):
"""
Expand Down
26 changes: 0 additions & 26 deletions httpx/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import codecs
import collections
import logging
import mimetypes
import netrc
Expand Down Expand Up @@ -369,31 +368,6 @@ def peek_filelike_length(stream: typing.IO) -> int:
return os.fstat(fd).st_size


def flatten_queryparams(
queryparams: typing.Mapping[
str, typing.Union["PrimitiveData", typing.Sequence["PrimitiveData"]]
]
) -> typing.List[typing.Tuple[str, "PrimitiveData"]]:
"""
Convert a mapping of query params into a flat list of two-tuples
representing each item.

Example:
>>> flatten_queryparams_values({"q": "httpx", "tag": ["python", "dev"]})
[("q", "httpx), ("tag", "python"), ("tag", "dev")]
"""
items = []

for k, v in queryparams.items():
if isinstance(v, collections.abc.Sequence) and not isinstance(v, (str, bytes)):
for u in v:
items.append((k, u))
else:
items.append((k, typing.cast("PrimitiveData", v)))

return items


class Timer:
async def _get_time(self) -> float:
library = sniffio.current_async_library()
Expand Down
55 changes: 43 additions & 12 deletions tests/models/test_queryparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,50 @@ def test_queryparam_types():
assert str(q) == "a=1&a=2"


def test_queryparam_setters():
q = httpx.QueryParams({"a": 1})
q.update([])
def test_queryparam_update_is_hard_deprecated():
q = httpx.QueryParams("a=123")
with pytest.raises(RuntimeError):
q.update({"a": "456"})

assert str(q) == "a=1"

q = httpx.QueryParams([("a", 1), ("a", 2)])
q["a"] = "3"
assert str(q) == "a=3"
def test_queryparam_setter_is_hard_deprecated():
q = httpx.QueryParams("a=123")
with pytest.raises(RuntimeError):
q["a"] = "456"

q = httpx.QueryParams([("a", 1), ("b", 1)])
u = httpx.QueryParams([("b", 2), ("b", 3)])
q.update(u)

assert str(q) == "a=1&b=2&b=3"
assert q["b"] == u["b"]
def test_queryparam_set():
q = httpx.QueryParams("a=123")
q = q.set("a", "456")
assert q == httpx.QueryParams("a=456")


def test_queryparam_add():
q = httpx.QueryParams("a=123")
q = q.add("a", "456")
assert q == httpx.QueryParams("a=123&a=456")


def test_queryparam_remove():
q = httpx.QueryParams("a=123")
q = q.remove("a")
assert q == httpx.QueryParams("")


def test_queryparam_merge():
q = httpx.QueryParams("a=123")
q = q.merge({"b": "456"})
assert q == httpx.QueryParams("a=123&b=456")
q = q.merge({"a": "000", "c": "789"})
assert q == httpx.QueryParams("a=000&b=456&c=789")


def test_queryparams_are_hashable():
params = (
httpx.QueryParams("a=123"),
httpx.QueryParams({"a": 123}),
httpx.QueryParams("b=456"),
httpx.QueryParams({"b": 456}),
)

assert len(set(params)) == 2