@@ -468,9 +468,8 @@ def keyslot(self, key: EncodableT) -> int:
468
468
return key_slot (k )
469
469
470
470
async def _determine_nodes (
471
- self , * args : Any , node_flag : Optional [str ] = None
471
+ self , command : str , * args : Any , node_flag : Optional [str ] = None
472
472
) -> List ["ClusterNode" ]:
473
- command = args [0 ]
474
473
if not node_flag :
475
474
# get the nodes group for this command if it was predefined
476
475
node_flag = self .command_flags .get (command )
@@ -495,16 +494,15 @@ async def _determine_nodes(
495
494
# get the node that holds the key's slot
496
495
return [
497
496
self .nodes_manager .get_node_from_slot (
498
- await self ._determine_slot (* args ),
497
+ await self ._determine_slot (command , * args ),
499
498
self .read_from_replicas and command in READ_COMMANDS ,
500
499
)
501
500
]
502
501
503
- async def _determine_slot (self , * args : Any ) -> int :
504
- command = args [0 ]
502
+ async def _determine_slot (self , command : str , * args : Any ) -> int :
505
503
if self .command_flags .get (command ) == SLOT_ID :
506
504
# The command contains the slot ID
507
- return int (args [1 ])
505
+ return int (args [0 ])
508
506
509
507
# Get the keys in the command
510
508
@@ -516,19 +514,17 @@ async def _determine_slot(self, *args: Any) -> int:
516
514
# - fix: https://github.com/redis/redis/pull/9733
517
515
if command in ("EVAL" , "EVALSHA" ):
518
516
# command syntax: EVAL "script body" num_keys ...
519
- if len (args ) <= 2 :
520
- raise RedisClusterException (f"Invalid args in command: { args } " )
521
- num_actual_keys = args [2 ]
522
- eval_keys = args [3 : 3 + num_actual_keys ]
517
+ if len (args ) < 2 :
518
+ raise RedisClusterException (
519
+ f"Invalid args in command: { command , * args } "
520
+ )
521
+ keys = args [2 : 2 + args [1 ]]
523
522
# if there are 0 keys, that means the script can be run on any node
524
523
# so we can just return a random slot
525
- if not eval_keys :
524
+ if not keys :
526
525
return random .randrange (0 , REDIS_CLUSTER_HASH_SLOTS )
527
- keys = eval_keys
528
526
else :
529
- keys = await self .commands_parser .get_keys (
530
- self .nodes_manager .default_node , * args
531
- )
527
+ keys = await self .commands_parser .get_keys (command , * args )
532
528
if not keys :
533
529
# FCALL can call a function with 0 keys, that means the function
534
530
# can be run on any node so we can just return a random slot
@@ -848,13 +844,13 @@ def acquire_connection(self) -> Connection:
848
844
self ._free .append (connection )
849
845
850
846
return self ._free .popleft ()
851
- else :
852
- if len (self ._connections ) < self .max_connections :
853
- connection = self .connection_class (** self .connection_kwargs )
854
- self ._connections .append (connection )
855
- return connection
856
- else :
857
- raise ConnectionError ("Too many connections" )
847
+
848
+ if len (self ._connections ) < self .max_connections :
849
+ connection = self .connection_class (** self .connection_kwargs )
850
+ self ._connections .append (connection )
851
+ return connection
852
+
853
+ raise ConnectionError ("Too many connections" )
858
854
859
855
async def parse_response (
860
856
self , connection : Connection , command : str , ** kwargs : Any
@@ -872,10 +868,10 @@ async def parse_response(
872
868
raise
873
869
874
870
# Return response
875
- try :
871
+ if command in self . response_callbacks :
876
872
return self .response_callbacks [command ](response , ** kwargs )
877
- except KeyError :
878
- return response
873
+
874
+ return response
879
875
880
876
async def execute_command (self , * args : Any , ** kwargs : Any ) -> Any :
881
877
# Acquire connection
@@ -891,7 +887,7 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
891
887
# Release connection
892
888
self ._free .append (connection )
893
889
894
- async def execute_pipeline (self ) -> None :
890
+ async def execute_pipeline (self ) -> bool :
895
891
# Acquire connection
896
892
connection = self .acquire_connection ()
897
893
@@ -901,17 +897,20 @@ async def execute_pipeline(self) -> None:
901
897
)
902
898
903
899
# Read responses
904
- try :
905
- for cmd in self ._command_stack :
906
- try :
907
- cmd .result = await self .parse_response (
908
- connection , cmd .args [0 ], ** cmd .kwargs
909
- )
910
- except Exception as e :
911
- cmd .result = e
912
- finally :
913
- # Release connection
914
- self ._free .append (connection )
900
+ ret = False
901
+ for cmd in self ._command_stack :
902
+ try :
903
+ cmd .result = await self .parse_response (
904
+ connection , cmd .args [0 ], ** cmd .kwargs
905
+ )
906
+ except Exception as e :
907
+ cmd .result = e
908
+ ret = True
909
+
910
+ # Release connection
911
+ self ._free .append (connection )
912
+
913
+ return ret
915
914
916
915
917
916
class NodesManager :
@@ -1257,6 +1256,13 @@ async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> N
1257
1256
def __await__ (self ) -> Generator [Any , None , "ClusterPipeline" ]:
1258
1257
return self .initialize ().__await__ ()
1259
1258
1259
+ def __enter__ (self ) -> "ClusterPipeline" :
1260
+ self ._command_stack = []
1261
+ return self
1262
+
1263
+ def __exit__ (self , exc_type : None , exc_value : None , traceback : None ) -> None :
1264
+ self ._command_stack = []
1265
+
1260
1266
def __bool__ (self ) -> bool :
1261
1267
return bool (self ._command_stack )
1262
1268
@@ -1310,6 +1316,7 @@ async def execute(
1310
1316
1311
1317
try :
1312
1318
return await self ._execute (
1319
+ self ._client ,
1313
1320
self ._command_stack ,
1314
1321
raise_on_error = raise_on_error ,
1315
1322
allow_redirections = allow_redirections ,
@@ -1331,60 +1338,60 @@ async def execute(
1331
1338
1332
1339
async def _execute (
1333
1340
self ,
1341
+ client : "RedisCluster" ,
1334
1342
stack : List ["PipelineCommand" ],
1335
1343
raise_on_error : bool = True ,
1336
1344
allow_redirections : bool = True ,
1337
1345
) -> List [Any ]:
1338
- client = self ._client
1346
+ todo = [
1347
+ cmd for cmd in stack if not cmd .result or isinstance (cmd .result , Exception )
1348
+ ]
1349
+
1339
1350
nodes = {}
1340
- for cmd in stack :
1341
- if not cmd .result or isinstance (cmd .result , Exception ):
1342
- target_nodes = await client ._determine_nodes (* cmd .args )
1343
- if not target_nodes :
1344
- raise RedisClusterException (
1345
- f"No targets were found to execute { cmd .args } command on"
1346
- )
1347
- if len (target_nodes ) > 1 :
1348
- raise RedisClusterException (
1349
- f"Too many targets for command { cmd .args } "
1350
- )
1351
+ for cmd in todo :
1352
+ target_nodes = await client ._determine_nodes (* cmd .args )
1353
+ if not target_nodes :
1354
+ raise RedisClusterException (
1355
+ f"No targets were found to execute { cmd .args } command on"
1356
+ )
1357
+ if len (target_nodes ) > 1 :
1358
+ raise RedisClusterException (f"Too many targets for command { cmd .args } " )
1351
1359
1352
- node = target_nodes [0 ]
1353
- if node .name not in nodes :
1354
- nodes [node .name ] = node
1355
- node ._command_stack = []
1356
- node ._command_stack .append (cmd )
1360
+ node = target_nodes [0 ]
1361
+ if node .name not in nodes :
1362
+ nodes [node .name ] = node
1363
+ node ._command_stack = []
1364
+ node ._command_stack .append (cmd )
1357
1365
1358
- await asyncio .gather (
1366
+ errors = await asyncio .gather (
1359
1367
* (asyncio .ensure_future (node .execute_pipeline ()) for node in nodes .values ())
1360
1368
)
1361
1369
1362
- if allow_redirections :
1363
- # send each errored command individually
1364
- for cmd in stack :
1365
- if isinstance (cmd .result , (TryAgainError , MovedError , AskError )):
1366
- try :
1367
- cmd .result = await client .execute_command (
1368
- * cmd .args , ** cmd .kwargs
1370
+ if any (errors ):
1371
+ if allow_redirections :
1372
+ # send each errored command individually
1373
+ for cmd in todo :
1374
+ if isinstance (cmd .result , (TryAgainError , MovedError , AskError )):
1375
+ try :
1376
+ cmd .result = await client .execute_command (
1377
+ * cmd .args , ** cmd .kwargs
1378
+ )
1379
+ except Exception as e :
1380
+ cmd .result = e
1381
+
1382
+ if raise_on_error :
1383
+ for cmd in todo :
1384
+ result = cmd .result
1385
+ if isinstance (result , Exception ):
1386
+ command = " " .join (map (safe_str , cmd .args ))
1387
+ msg = (
1388
+ f"Command # { cmd .position + 1 } ({ command } ) of pipeline "
1389
+ f"caused error: { result .args } "
1369
1390
)
1370
- except Exception as e :
1371
- cmd .result = e
1372
-
1373
- responses = [cmd .result for cmd in stack ]
1374
-
1375
- if raise_on_error :
1376
- for cmd in stack :
1377
- result = cmd .result
1378
- if isinstance (result , Exception ):
1379
- command = " " .join (map (safe_str , cmd .args ))
1380
- msg = (
1381
- f"Command # { cmd .position + 1 } ({ command } ) of pipeline "
1382
- f"caused error: { result .args } "
1383
- )
1384
- result .args = (msg ,) + result .args [1 :]
1385
- raise result
1391
+ result .args = (msg ,) + result .args [1 :]
1392
+ raise result
1386
1393
1387
- return responses
1394
+ return [ cmd . result for cmd in stack ]
1388
1395
1389
1396
def _split_command_across_slots (
1390
1397
self , command : str , * keys : KeyT
0 commit comments