@@ -823,7 +823,7 @@ async def on_connect(self) -> None:
823
823
if str_if_bytes (await self .read_response ()) != "OK" :
824
824
raise ConnectionError ("Invalid Database" )
825
825
826
- async def disconnect (self ) -> None :
826
+ async def disconnect (self , nowait : bool = False ) -> None :
827
827
"""Disconnects from the Redis server"""
828
828
try :
829
829
async with async_timeout .timeout (self .socket_connect_timeout ):
@@ -833,8 +833,9 @@ async def disconnect(self) -> None:
833
833
try :
834
834
if os .getpid () == self .pid :
835
835
self ._writer .close () # type: ignore[union-attr]
836
- # py3.6 doesn't have this method
837
- if hasattr (self ._writer , "wait_closed" ):
836
+ # wait for close to finish, except when handling errors and
837
+ # forcecully disconnecting.
838
+ if not nowait :
838
839
await self ._writer .wait_closed () # type: ignore[union-attr]
839
840
except OSError :
840
841
pass
@@ -934,10 +935,10 @@ async def read_response(self, disable_decoding: bool = False):
934
935
disable_decoding = disable_decoding
935
936
)
936
937
except asyncio .TimeoutError :
937
- await self .disconnect ()
938
+ await self .disconnect (nowait = True )
938
939
raise TimeoutError (f"Timeout reading from { self .host } :{ self .port } " )
939
940
except OSError as e :
940
- await self .disconnect ()
941
+ await self .disconnect (nowait = True )
941
942
raise ConnectionError (
942
943
f"Error while reading from { self .host } :{ self .port } : { e .args } "
943
944
)
@@ -946,7 +947,7 @@ async def read_response(self, disable_decoding: bool = False):
946
947
# is subclass of Exception, not BaseException
947
948
raise
948
949
except Exception :
949
- await self .disconnect ()
950
+ await self .disconnect (nowait = True )
950
951
raise
951
952
952
953
if self .health_check_interval :
0 commit comments