@@ -355,10 +355,50 @@ def setup_run_circuit_with_result_(client, result):
355355
356356
357357@mock .patch ('cirq_google.engine.engine_client.EngineClient' , autospec = True )
358- def test_run_circuit (client ):
358+ def test_run_circuit_with_unary_rpcs (client ):
359359 setup_run_circuit_with_result_ (client , _A_RESULT )
360360
361- engine = cg .Engine (project_id = 'proj' , service_args = {'client_info' : 1 })
361+ engine = cg .Engine (
362+ project_id = 'proj' ,
363+ context = EngineContext (service_args = {'client_info' : 1 }, enable_streaming = False ),
364+ )
365+ result = engine .run (
366+ program = _CIRCUIT , program_id = 'prog' , job_id = 'job-id' , processor_ids = ['mysim' ]
367+ )
368+
369+ assert result .repetitions == 1
370+ assert result .params .param_dict == {'a' : 1 }
371+ assert result .measurements == {'q' : np .array ([[0 ]], dtype = 'uint8' )}
372+ client .assert_called_with (service_args = {'client_info' : 1 }, verbose = None )
373+ client ().create_program_async .assert_called_once ()
374+ client ().create_job_async .assert_called_once_with (
375+ project_id = 'proj' ,
376+ program_id = 'prog' ,
377+ job_id = 'job-id' ,
378+ processor_ids = ['mysim' ],
379+ run_context = util .pack_any (
380+ v2 .run_context_pb2 .RunContext (
381+ parameter_sweeps = [v2 .run_context_pb2 .ParameterSweep (repetitions = 1 )]
382+ )
383+ ),
384+ description = None ,
385+ labels = None ,
386+ processor_id = '' ,
387+ run_name = '' ,
388+ device_config_name = '' ,
389+ )
390+ client ().get_job_async .assert_called_once_with ('proj' , 'prog' , 'job-id' , False )
391+ client ().get_job_results_async .assert_called_once_with ('proj' , 'prog' , 'job-id' )
392+
393+
394+ @mock .patch ('cirq_google.engine.engine_client.EngineClient' , autospec = True )
395+ def test_run_circuit_with_stream_rpcs (client ):
396+ setup_run_circuit_with_result_ (client , _A_RESULT )
397+
398+ engine = cg .Engine (
399+ project_id = 'proj' ,
400+ context = EngineContext (service_args = {'client_info' : 1 }, enable_streaming = True ),
401+ )
362402 result = engine .run (
363403 program = _CIRCUIT , program_id = 'prog' , job_id = 'job-id' , processor_ids = ['mysim' ]
364404 )
@@ -399,7 +439,37 @@ def test_unsupported_program_type():
399439
400440
401441@mock .patch ('cirq_google.engine.engine_client.EngineClient' , autospec = True )
402- def test_run_circuit_failed (client ):
442+ def test_run_circuit_failed_with_unary_rpcs (client ):
443+ client ().create_program_async .return_value = (
444+ 'prog' ,
445+ quantum .QuantumProgram (name = 'projects/proj/programs/prog' ),
446+ )
447+ client ().create_job_async .return_value = (
448+ 'job-id' ,
449+ quantum .QuantumJob (
450+ name = 'projects/proj/programs/prog/jobs/job-id' , execution_status = {'state' : 'READY' }
451+ ),
452+ )
453+ client ().get_job_async .return_value = quantum .QuantumJob (
454+ name = 'projects/proj/programs/prog/jobs/job-id' ,
455+ execution_status = {
456+ 'state' : 'FAILURE' ,
457+ 'processor_name' : 'myqc' ,
458+ 'failure' : {'error_code' : 'SYSTEM_ERROR' , 'error_message' : 'Not good' },
459+ },
460+ )
461+
462+ engine = cg .Engine (project_id = 'proj' , context = EngineContext (enable_streaming = False ))
463+ with pytest .raises (
464+ RuntimeError ,
465+ match = 'Job projects/proj/programs/prog/jobs/job-id on processor'
466+ ' myqc failed. SYSTEM_ERROR: Not good' ,
467+ ):
468+ engine .run (program = _CIRCUIT )
469+
470+
471+ @mock .patch ('cirq_google.engine.engine_client.EngineClient' , autospec = True )
472+ def test_run_circuit_failed_with_stream_rpcs (client ):
403473 failed_job = quantum .QuantumJob (
404474 name = 'projects/proj/programs/prog/jobs/job-id' ,
405475 execution_status = {
@@ -412,7 +482,7 @@ def test_run_circuit_failed(client):
412482 stream_future .try_set_result (failed_job )
413483 client ().run_job_over_stream .return_value = stream_future
414484
415- engine = cg .Engine (project_id = 'proj' )
485+ engine = cg .Engine (project_id = 'proj' , context = EngineContext ( enable_streaming = True ) )
416486 with pytest .raises (
417487 RuntimeError ,
418488 match = 'Job projects/proj/programs/prog/jobs/job-id on processor'
@@ -422,7 +492,36 @@ def test_run_circuit_failed(client):
422492
423493
424494@mock .patch ('cirq_google.engine.engine_client.EngineClient' , autospec = True )
425- def test_run_circuit_failed_missing_processor_name (client ):
495+ def test_run_circuit_failed_missing_processor_name_with_unary_rpcs (client ):
496+ client ().create_program_async .return_value = (
497+ 'prog' ,
498+ quantum .QuantumProgram (name = 'projects/proj/programs/prog' ),
499+ )
500+ client ().create_job_async .return_value = (
501+ 'job-id' ,
502+ quantum .QuantumJob (
503+ name = 'projects/proj/programs/prog/jobs/job-id' , execution_status = {'state' : 'READY' }
504+ ),
505+ )
506+ client ().get_job_async .return_value = quantum .QuantumJob (
507+ name = 'projects/proj/programs/prog/jobs/job-id' ,
508+ execution_status = {
509+ 'state' : 'FAILURE' ,
510+ 'failure' : {'error_code' : 'SYSTEM_ERROR' , 'error_message' : 'Not good' },
511+ },
512+ )
513+
514+ engine = cg .Engine (project_id = 'proj' , context = EngineContext (enable_streaming = False ))
515+ with pytest .raises (
516+ RuntimeError ,
517+ match = 'Job projects/proj/programs/prog/jobs/job-id on processor'
518+ ' UNKNOWN failed. SYSTEM_ERROR: Not good' ,
519+ ):
520+ engine .run (program = _CIRCUIT )
521+
522+
523+ @mock .patch ('cirq_google.engine.engine_client.EngineClient' , autospec = True )
524+ def test_run_circuit_failed_missing_processor_name_with_stream_rpcs (client ):
426525 failed_job = quantum .QuantumJob (
427526 name = 'projects/proj/programs/prog/jobs/job-id' ,
428527 execution_status = {
@@ -434,7 +533,7 @@ def test_run_circuit_failed_missing_processor_name(client):
434533 stream_future .try_set_result (failed_job )
435534 client ().run_job_over_stream .return_value = stream_future
436535
437- engine = cg .Engine (project_id = 'proj' )
536+ engine = cg .Engine (project_id = 'proj' , context = EngineContext ( enable_streaming = True ) )
438537 with pytest .raises (
439538 RuntimeError ,
440539 match = 'Job projects/proj/programs/prog/jobs/job-id on processor'
@@ -444,26 +543,78 @@ def test_run_circuit_failed_missing_processor_name(client):
444543
445544
446545@mock .patch ('cirq_google.engine.engine_client.EngineClient' , autospec = True )
447- def test_run_circuit_cancelled (client ):
546+ def test_run_circuit_cancelled_with_unary_rpcs (client ):
547+ client ().create_program_async .return_value = (
548+ 'prog' ,
549+ quantum .QuantumProgram (name = 'projects/proj/programs/prog' ),
550+ )
551+ client ().create_job_async .return_value = (
552+ 'job-id' ,
553+ quantum .QuantumJob (
554+ name = 'projects/proj/programs/prog/jobs/job-id' , execution_status = {'state' : 'READY' }
555+ ),
556+ )
557+ client ().get_job_async .return_value = quantum .QuantumJob (
558+ name = 'projects/proj/programs/prog/jobs/job-id' , execution_status = {'state' : 'CANCELLED' }
559+ )
560+
561+ engine = cg .Engine (project_id = 'proj' , context = EngineContext (enable_streaming = False ))
562+ with pytest .raises (
563+ RuntimeError , match = 'Job projects/proj/programs/prog/jobs/job-id failed in state CANCELLED.'
564+ ):
565+ engine .run (program = _CIRCUIT )
566+
567+
568+ @mock .patch ('cirq_google.engine.engine_client.EngineClient' , autospec = True )
569+ def test_run_circuit_cancelled_with_stream_rpcs (client ):
448570 canceled_job = quantum .QuantumJob (
449571 name = 'projects/proj/programs/prog/jobs/job-id' , execution_status = {'state' : 'CANCELLED' }
450572 )
451573 stream_future = duet .AwaitableFuture ()
452574 stream_future .try_set_result (canceled_job )
453575 client ().run_job_over_stream .return_value = stream_future
454576
455- engine = cg .Engine (project_id = 'proj' )
577+ engine = cg .Engine (project_id = 'proj' , context = EngineContext ( enable_streaming = True ) )
456578 with pytest .raises (
457579 RuntimeError , match = 'Job projects/proj/programs/prog/jobs/job-id failed in state CANCELLED.'
458580 ):
459581 engine .run (program = _CIRCUIT )
460582
461583
462584@mock .patch ('cirq_google.engine.engine_client.EngineClient' , autospec = True )
463- def test_run_sweep_params (client ):
585+ def test_run_sweep_params_with_unary_rpcs (client ):
586+ setup_run_circuit_with_result_ (client , _RESULTS )
587+
588+ engine = cg .Engine (project_id = 'proj' , context = EngineContext (enable_streaming = False ))
589+ job = engine .run_sweep (
590+ program = _CIRCUIT , params = [cirq .ParamResolver ({'a' : 1 }), cirq .ParamResolver ({'a' : 2 })]
591+ )
592+ results = job .results ()
593+ assert len (results ) == 2
594+ for i , v in enumerate ([1 , 2 ]):
595+ assert results [i ].repetitions == 1
596+ assert results [i ].params .param_dict == {'a' : v }
597+ assert results [i ].measurements == {'q' : np .array ([[0 ]], dtype = 'uint8' )}
598+
599+ client ().create_program_async .assert_called_once ()
600+ client ().create_job_async .assert_called_once ()
601+
602+ run_context = v2 .run_context_pb2 .RunContext ()
603+ client ().create_job_async .call_args [1 ]['run_context' ].Unpack (run_context )
604+ sweeps = run_context .parameter_sweeps
605+ assert len (sweeps ) == 2
606+ for i , v in enumerate ([1.0 , 2.0 ]):
607+ assert sweeps [i ].repetitions == 1
608+ assert sweeps [i ].sweep .sweep_function .sweeps [0 ].single_sweep .points .points == [v ]
609+ client ().get_job_async .assert_called_once ()
610+ client ().get_job_results_async .assert_called_once ()
611+
612+
613+ @mock .patch ('cirq_google.engine.engine_client.EngineClient' , autospec = True )
614+ def test_run_sweep_params_with_stream_rpcs (client ):
464615 setup_run_circuit_with_result_ (client , _RESULTS )
465616
466- engine = cg .Engine (project_id = 'proj' )
617+ engine = cg .Engine (project_id = 'proj' , context = EngineContext ( enable_streaming = True ) )
467618 job = engine .run_sweep (
468619 program = _CIRCUIT , params = [cirq .ParamResolver ({'a' : 1 }), cirq .ParamResolver ({'a' : 2 })]
469620 )
@@ -486,7 +637,12 @@ def test_run_sweep_params(client):
486637
487638
488639def test_run_sweep_with_multiple_processor_ids ():
489- engine = cg .Engine (project_id = 'proj' , proto_version = cg .engine .engine .ProtoVersion .V2 )
640+ engine = cg .Engine (
641+ project_id = 'proj' ,
642+ context = EngineContext (
643+ proto_version = cg .engine .engine .ProtoVersion .V2 , enable_streaming = True
644+ ),
645+ )
490646 with pytest .raises (ValueError , match = 'multiple processors is no longer supported' ):
491647 _ = engine .run_sweep (
492648 program = _CIRCUIT ,
@@ -527,10 +683,44 @@ def test_run_multiple_times(client):
527683
528684
529685@mock .patch ('cirq_google.engine.engine_client.EngineClient' , autospec = True )
530- def test_run_sweep_v2 (client ):
686+ def test_run_sweep_v2_with_unary_rpcs (client ):
531687 setup_run_circuit_with_result_ (client , _RESULTS_V2 )
532688
533- engine = cg .Engine (project_id = 'proj' , proto_version = cg .engine .engine .ProtoVersion .V2 )
689+ engine = cg .Engine (
690+ project_id = 'proj' ,
691+ context = EngineContext (
692+ proto_version = cg .engine .engine .ProtoVersion .V2 , enable_streaming = False
693+ ),
694+ )
695+ job = engine .run_sweep (program = _CIRCUIT , job_id = 'job-id' , params = cirq .Points ('a' , [1 , 2 ]))
696+ results = job .results ()
697+ assert len (results ) == 2
698+ for i , v in enumerate ([1 , 2 ]):
699+ assert results [i ].repetitions == 1
700+ assert results [i ].params .param_dict == {'a' : v }
701+ assert results [i ].measurements == {'q' : np .array ([[0 ]], dtype = 'uint8' )}
702+ client ().create_program_async .assert_called_once ()
703+ client ().create_job_async .assert_called_once ()
704+ run_context = v2 .run_context_pb2 .RunContext ()
705+ client ().create_job_async .call_args [1 ]['run_context' ].Unpack (run_context )
706+ sweeps = run_context .parameter_sweeps
707+ assert len (sweeps ) == 1
708+ assert sweeps [0 ].repetitions == 1
709+ assert sweeps [0 ].sweep .single_sweep .points .points == [1 , 2 ]
710+ client ().get_job_async .assert_called_once ()
711+ client ().get_job_results_async .assert_called_once ()
712+
713+
714+ @mock .patch ('cirq_google.engine.engine_client.EngineClient' , autospec = True )
715+ def test_run_sweep_v2_with_stream_rpcs (client ):
716+ setup_run_circuit_with_result_ (client , _RESULTS_V2 )
717+
718+ engine = cg .Engine (
719+ project_id = 'proj' ,
720+ context = EngineContext (
721+ proto_version = cg .engine .engine .ProtoVersion .V2 , enable_streaming = True
722+ ),
723+ )
534724 job = engine .run_sweep (program = _CIRCUIT , job_id = 'job-id' , params = cirq .Points ('a' , [1 , 2 ]))
535725 results = job .results ()
536726 assert len (results ) == 2
@@ -772,10 +962,33 @@ def test_get_processor():
772962
773963
774964@mock .patch ('cirq_google.engine.engine_client.EngineClient' , autospec = True )
775- def test_sampler (client ):
965+ def test_sampler_with_unary_rpcs (client ):
966+ setup_run_circuit_with_result_ (client , _RESULTS )
967+
968+ engine = cg .Engine (project_id = 'proj' , context = EngineContext (enable_streaming = False ))
969+ sampler = engine .get_sampler (processor_id = 'tmp' )
970+ results = sampler .run_sweep (
971+ program = _CIRCUIT , params = [cirq .ParamResolver ({'a' : 1 }), cirq .ParamResolver ({'a' : 2 })]
972+ )
973+ assert len (results ) == 2
974+ for i , v in enumerate ([1 , 2 ]):
975+ assert results [i ].repetitions == 1
976+ assert results [i ].params .param_dict == {'a' : v }
977+ assert results [i ].measurements == {'q' : np .array ([[0 ]], dtype = 'uint8' )}
978+ assert client ().create_program_async .call_args [0 ][0 ] == 'proj'
979+
980+ with cirq .testing .assert_deprecated ('sampler' , deadline = '1.0' ):
981+ _ = engine .sampler (processor_id = 'tmp' )
982+
983+ with pytest .raises (ValueError , match = 'list of processors' ):
984+ _ = engine .get_sampler (['test1' , 'test2' ])
985+
986+
987+ @mock .patch ('cirq_google.engine.engine_client.EngineClient' , autospec = True )
988+ def test_sampler_with_stream_rpcs (client ):
776989 setup_run_circuit_with_result_ (client , _RESULTS )
777990
778- engine = cg .Engine (project_id = 'proj' )
991+ engine = cg .Engine (project_id = 'proj' , context = EngineContext ( enable_streaming = True ) )
779992 sampler = engine .get_sampler (processor_id = 'tmp' )
780993 results = sampler .run_sweep (
781994 program = _CIRCUIT , params = [cirq .ParamResolver ({'a' : 1 }), cirq .ParamResolver ({'a' : 2 })]
0 commit comments