Skip to content
This repository was archived by the owner on Mar 26, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,116 @@ from google.api_core import retry as retries
from google.api_core import rest_helpers
from google.api_core import path_template
from google.api_core import gapic_v1

# TODO: Remove once my PR gets merged and released.
# Begin of ResponseIterator depedencies.
from collections import deque
import string
from typing import Deque
import requests

class ResponseIterator:
"""Iterator over REST API responses.

Args:
response (requests.Response): An API response object.
response_message_cls (Callable[proto.Message]): A proto
class expected to be returned from an API.
"""

def __init__(self, response: requests.Response, response_message_cls):
self._response = response
self._response_message_cls = response_message_cls
# Inner iterator over HTTP response's content.
self._response_itr = self._response.iter_content(decode_unicode=True)
# Contains a list of JSON responses ready to be sent to user.
self._ready_objs: Deque[str] = deque()
# Current JSON response being built.
self._obj = ""
# Keeps track of the nesting level within a JSON object.
self._level = 0
# Keeps track whether HTTP response is currently sending values
# inside of a string value.
self._in_string = False
# Whether an escape symbol "\" was encountered.
self._next_should_be_escaped = False

def cancel(self):
"""Cancel existing streaming operation.
"""
self._response.close()

def _process_chunk(self, chunk: str):
if self._level == 0:
if chunk[0] != "[":
raise ValueError(
"Can only parse array of JSON objects, instead got %s" % chunk
)
for char in chunk:
if char == "{":
if self._level == 1:
# Level 1 corresponds to the outermost JSON object
# (i.e. the one we care about).
self._obj = ""
if not self._in_string:
self._level += 1
self._obj += char
elif char == "}":
self._obj += char
if not self._in_string:
self._level -= 1
if not self._in_string and self._level == 1:
self._ready_objs.append(self._obj)
elif char == '"':
# Helps to deal with an escaped quotes inside of a string.
if not self._next_should_be_escaped:
self._in_string = not self._in_string
self._obj += char
elif char in string.whitespace:
if self._in_string:
self._obj += char
elif char == "[":
if self._level == 0:
self._level += 1
else:
self._obj += char
elif char == "]":
if self._level == 1:
self._level -= 1
else:
self._obj += char
else:
self._obj += char

if char == "\\":
# Escaping the "\".
if self._next_should_be_escaped:
self._next_should_be_escaped = False
else:
self._next_should_be_escaped = True
else:
self._next_should_be_escaped = False

def __next__(self):
while not self._ready_objs:
try:
chunk = next(self._response_itr)
self._process_chunk(chunk)
except StopIteration as e:
if self._level > 0:
raise ValueError("Unfinished stream: %s" % self._obj)
raise e
return self._grab()

def _grab(self):
# Add extra quotes to make json.loads happy.
return self._response_message_cls.from_json(self._ready_objs.popleft())

def __iter__(self):
return self

# End of ResponseIterator dependencies.

{% if service.has_lro %}
from google.api_core import operations_v1
from google.protobuf import json_format
Expand Down Expand Up @@ -179,7 +289,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
def __hash__(self):
return hash("{{method.name}}")

{% if not (method.server_streaming or method.client_streaming) %}
{% if not method.client_streaming %}
{% if method.input.required_fields %}
__REQUIRED_FIELDS_DEFAULT_VALUES = {
{% for req_field in method.input.required_fields if req_field.is_primitive %}
Expand All @@ -200,7 +310,7 @@ class {{service.name}}RestTransport({{service.name}}Transport):
timeout: float=None,
metadata: Sequence[Tuple[str, str]]=(),
) -> {{method.output.ident}}:
{% if method.http_options and not (method.server_streaming or method.client_streaming) %}
{% if method.http_options and not method.client_streaming %}
r"""Call the {{- ' ' -}}
{{ (method.name|snake_case).replace('_',' ')|wrap(
width=70, offset=45, indent=8) }}
Expand Down Expand Up @@ -291,6 +401,8 @@ class {{service.name}}RestTransport({{service.name}}Transport):
return_op = operations_pb2.Operation()
json_format.Parse(response.content, return_op, ignore_unknown_fields=True)
return return_op
{% elif method.server_streaming %}
return ResponseIterator(response, {{method.output.ident}})
{% else %}
return {{method.output.ident}}.from_json(
response.content,
Expand All @@ -299,14 +411,14 @@ class {{service.name}}RestTransport({{service.name}}Transport):

{% endif %}{# method.lro #}
{% endif %}{# method.void #}
{% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #}
{% else %}{# method.http_options and not method.client_streaming #}
{% if not method.http_options %}
raise RuntimeError(
"Cannot define a method without a valid 'google.api.http' annotation.")

{% elif method.server_streaming or method.client_streaming %}
{% elif method.client_streaming %}
raise NotImplementedError(
"Streaming over REST is not yet defined for python client")
"Client streaming over REST is not yet defined for python client")

{% else %}
raise NotImplementedError()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import mock
import grpc
from grpc.experimental import aio
{% if "rest" in opts.transport %}
from collections.abc import Iterable
import json
{% endif %}
import math
Expand Down Expand Up @@ -823,8 +824,8 @@ def test_{{ method_name }}_raw_page_lro():
{% endfor %} {# method in methods for grpc #}

{% for method in service.methods.values() if 'rest' in opts.transport %}{% with method_name = method.name|snake_case + "_unary" if method.operation_service else method.name|snake_case %}{% if method.http_options %}
{# TODO(kbandes): remove this if condition when streaming are supported. #}
{% if not (method.server_streaming or method.client_streaming) %}
{# TODO(kbandes): remove this if condition when client streaming are supported. #}
{% if not method.client_streaming %}
@pytest.mark.parametrize("request_type", [
{{ method.input.ident }},
dict,
Expand All @@ -846,8 +847,6 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.server_streaming %}
return_value = iter([{{ method.output.ident }}()])
{% else %}
return_value = {{ method.output.ident }}(
{% for field in method.output.fields.values() | rejectattr('message')%}
Expand All @@ -867,6 +866,8 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
req.return_value.request = PreparedRequest()
{% if method.void %}
json_return_value = ''
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% endif %}
Expand All @@ -876,6 +877,10 @@ def test_{{ method_name }}_rest(request_type, transport: str = 'rest'):
# the request over the wire, so an empty request is fine.
{% if method.client_streaming %}
client.{{ method_name }}(iter([requests]))
{% elif method.server_streaming %}
with mock.patch.object(response_value, 'iter_content') as iter_content:
iter_content.return_value = iter(json_return_value)
response = client.{{ method_name }}(request)
{% else %}
client.{{ method_name }}(request)
{% endif %}
Expand Down Expand Up @@ -1038,8 +1043,6 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.server_streaming %}
return_value = iter([{{ method.output.ident }}()])
{% else %}
return_value = {{ method.output.ident }}()
{% endif %}
Expand Down Expand Up @@ -1067,6 +1070,8 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% endif %}
Expand All @@ -1075,6 +1080,10 @@ def test_{{ method_name }}_rest_required_fields(request_type={{ method.input.ide

{% if method.client_streaming %}
response = client.{{ method.name|snake_case }}(iter(requests))
{% elif method.server_streaming %}
with mock.patch.object(response_value, 'iter_content') as iter_content:
iter_content.return_value = iter(json_return_value)
response = client.{{ method_name }}(request)
{% else %}
response = client.{{ method_name }}(request)
{% endif %}
Expand Down Expand Up @@ -1145,8 +1154,6 @@ def test_{{ method.name|snake_case }}_rest_flattened():
return_value = None
{% elif method.lro %}
return_value = operations_pb2.Operation(name='operations/spam')
{% elif method.server_streaming %}
return_value = iter([{{ method.output.ident }}()])
{% else %}
return_value = {{ method.output.ident }}()
{% endif %}
Expand All @@ -1158,6 +1165,8 @@ def test_{{ method.name|snake_case }}_rest_flattened():
json_return_value = ''
{% elif method.lro %}
json_return_value = json_format.MessageToJson(return_value)
{% elif method.server_streaming %}
json_return_value = "[{}]".format({{ method.output.ident }}.to_json(return_value))
{% else %}
json_return_value = {{ method.output.ident }}.to_json(return_value)
{% endif %}
Expand All @@ -1178,7 +1187,14 @@ def test_{{ method.name|snake_case }}_rest_flattened():
{% endfor %}
)
mock_args.update(sample_request)

{% if method.server_streaming %}
with mock.patch.object(response_value, 'iter_content') as iter_content:
iter_content.return_value = iter(json_return_value)
client.{{ method_name }}(**mock_args)
{% else %}
client.{{ method_name }}(**mock_args)
{% endif %}

# Establish that the underlying call was made with the expected
# request object values.
Expand Down Expand Up @@ -1282,6 +1298,9 @@ def test_{{ method_name }}_rest_pager(transport: str = 'rest'):
response = tuple({{ method.output.ident }}.to_json(x) for x in response)
return_values = tuple(Response() for i in response)
for return_val, response_val in zip(return_values, response):
{% if method.server_streaming %}
response_val = "[{}]".format({{ method.output.ident }}.to_json(response_val))
{% endif %}
return_val._content = response_val.encode('UTF-8')
return_val.status_code = 200
req.side_effect = return_values
Expand Down
Loading