@@ -95,6 +95,16 @@ def _get_entities(include_entities):
95
95
return entities
96
96
97
97
98
+ def make_mock_client (response ):
99
+ import mock
100
+ from google .cloud .language .connection import Connection
101
+ from google .cloud .language .client import Client
102
+
103
+ connection = mock .Mock (spec = Connection )
104
+ connection .api_request .return_value = response
105
+ return mock .Mock (_connection = connection , spec = Client )
106
+
107
+
98
108
class TestDocument (unittest .TestCase ):
99
109
100
110
@staticmethod
@@ -187,7 +197,36 @@ def _verify_entity(self, entity, name, entity_type, wiki_url, salience):
187
197
self .assertEqual (entity .salience , salience )
188
198
self .assertEqual (entity .mentions , [name ])
189
199
200
+ @staticmethod
201
+ def _expected_data (content , encoding_type = None ,
202
+ extract_sentiment = False ,
203
+ extract_entities = False ,
204
+ extract_syntax = False ):
205
+ from google .cloud .language .document import DEFAULT_LANGUAGE
206
+ from google .cloud .language .document import Document
207
+
208
+ expected = {
209
+ 'document' : {
210
+ 'language' : DEFAULT_LANGUAGE ,
211
+ 'type' : Document .PLAIN_TEXT ,
212
+ 'content' : content ,
213
+ },
214
+ }
215
+ if encoding_type is not None :
216
+ expected ['encodingType' ] = encoding_type
217
+ if extract_sentiment :
218
+ features = expected .setdefault ('features' , {})
219
+ features ['extractDocumentSentiment' ] = True
220
+ if extract_entities :
221
+ features = expected .setdefault ('features' , {})
222
+ features ['extractEntities' ] = True
223
+ if extract_syntax :
224
+ features = expected .setdefault ('features' , {})
225
+ features ['extractSyntax' ] = True
226
+ return expected
227
+
190
228
def test_analyze_entities (self ):
229
+ from google .cloud .language .document import Encoding
191
230
from google .cloud .language .entity import EntityType
192
231
193
232
name1 = 'R-O-C-K'
@@ -229,8 +268,7 @@ def test_analyze_entities(self):
229
268
],
230
269
'language' : 'en-US' ,
231
270
}
232
- connection = _Connection (response )
233
- client = _Client (connection = connection )
271
+ client = make_mock_client (response )
234
272
document = self ._make_one (client , content )
235
273
236
274
entities = document .analyze_entities ()
@@ -243,10 +281,10 @@ def test_analyze_entities(self):
243
281
wiki2 , salience2 )
244
282
245
283
# Verify the request.
246
- self .assertEqual ( len ( connection . _requested ), 1 )
247
- req = connection . _requested [ 0 ]
248
- self . assertEqual ( req [ 'path' ], 'analyzeEntities' )
249
- self . assertEqual ( req [ 'method' ], 'POST' )
284
+ expected = self ._expected_data (
285
+ content , encoding_type = Encoding . UTF8 )
286
+ client . _connection . api_request . assert_called_once_with (
287
+ path = 'analyzeEntities' , method = 'POST' , data = expected )
250
288
251
289
def _verify_sentiment (self , sentiment , polarity , magnitude ):
252
290
from google .cloud .language .sentiment import Sentiment
@@ -266,18 +304,16 @@ def test_analyze_sentiment(self):
266
304
},
267
305
'language' : 'en-US' ,
268
306
}
269
- connection = _Connection (response )
270
- client = _Client (connection = connection )
307
+ client = make_mock_client (response )
271
308
document = self ._make_one (client , content )
272
309
273
310
sentiment = document .analyze_sentiment ()
274
311
self ._verify_sentiment (sentiment , polarity , magnitude )
275
312
276
313
# Verify the request.
277
- self .assertEqual (len (connection ._requested ), 1 )
278
- req = connection ._requested [0 ]
279
- self .assertEqual (req ['path' ], 'analyzeSentiment' )
280
- self .assertEqual (req ['method' ], 'POST' )
314
+ expected = self ._expected_data (content )
315
+ client ._connection .api_request .assert_called_once_with (
316
+ path = 'analyzeSentiment' , method = 'POST' , data = expected )
281
317
282
318
def _verify_sentences (self , include_syntax , annotations ):
283
319
from google .cloud .language .syntax import Sentence
@@ -307,6 +343,7 @@ def _verify_tokens(self, annotations, token_info):
307
343
def _annotate_text_helper (self , include_sentiment ,
308
344
include_entities , include_syntax ):
309
345
from google .cloud .language .document import Annotations
346
+ from google .cloud .language .document import Encoding
310
347
from google .cloud .language .entity import EntityType
311
348
312
349
token_info , sentences = _get_token_and_sentences (include_syntax )
@@ -324,8 +361,7 @@ def _annotate_text_helper(self, include_sentiment,
324
361
'magnitude' : ANNOTATE_MAGNITUDE ,
325
362
}
326
363
327
- connection = _Connection (response )
328
- client = _Client (connection = connection )
364
+ client = make_mock_client (response )
329
365
document = self ._make_one (client , ANNOTATE_CONTENT )
330
366
331
367
annotations = document .annotate_text (
@@ -352,16 +388,13 @@ def _annotate_text_helper(self, include_sentiment,
352
388
self .assertEqual (annotations .entities , [])
353
389
354
390
# Verify the request.
355
- self .assertEqual (len (connection ._requested ), 1 )
356
- req = connection ._requested [0 ]
357
- self .assertEqual (req ['path' ], 'annotateText' )
358
- self .assertEqual (req ['method' ], 'POST' )
359
- features = req ['data' ]['features' ]
360
- self .assertEqual (features .get ('extractDocumentSentiment' , False ),
361
- include_sentiment )
362
- self .assertEqual (features .get ('extractEntities' , False ),
363
- include_entities )
364
- self .assertEqual (features .get ('extractSyntax' , False ), include_syntax )
391
+ expected = self ._expected_data (
392
+ ANNOTATE_CONTENT , encoding_type = Encoding .UTF8 ,
393
+ extract_sentiment = include_sentiment ,
394
+ extract_entities = include_entities ,
395
+ extract_syntax = include_syntax )
396
+ client ._connection .api_request .assert_called_once_with (
397
+ path = 'annotateText' , method = 'POST' , data = expected )
365
398
366
399
def test_annotate_text (self ):
367
400
self ._annotate_text_helper (True , True , True )
@@ -374,20 +407,3 @@ def test_annotate_text_entities_only(self):
374
407
375
408
def test_annotate_text_syntax_only (self ):
376
409
self ._annotate_text_helper (False , False , True )
377
-
378
-
379
- class _Connection (object ):
380
-
381
- def __init__ (self , response ):
382
- self ._response = response
383
- self ._requested = []
384
-
385
- def api_request (self , ** kwargs ):
386
- self ._requested .append (kwargs )
387
- return self ._response
388
-
389
-
390
- class _Client (object ):
391
-
392
- def __init__ (self , connection = None ):
393
- self ._connection = connection
0 commit comments