@@ -826,7 +826,7 @@ async def on_connect(self) -> None:
826
826
if str_if_bytes (await self .read_response ()) != "OK" :
827
827
raise ConnectionError ("Invalid Database" )
828
828
829
- async def disconnect (self ) -> None :
829
+ async def disconnect (self , nowait : bool = False ) -> None :
830
830
"""Disconnects from the Redis server"""
831
831
try :
832
832
async with async_timeout .timeout (self .socket_connect_timeout ):
@@ -836,8 +836,9 @@ async def disconnect(self) -> None:
836
836
try :
837
837
if os .getpid () == self .pid :
838
838
self ._writer .close () # type: ignore[union-attr]
839
- # py3.6 doesn't have this method
840
- if hasattr (self ._writer , "wait_closed" ):
839
+ # wait for close to finish, except when handling errors and
840
+ # forcecully disconnecting.
841
+ if not nowait :
841
842
await self ._writer .wait_closed () # type: ignore[union-attr]
842
843
except OSError :
843
844
pass
@@ -938,10 +939,10 @@ async def read_response(self, disable_decoding: bool = False):
938
939
disable_decoding = disable_decoding
939
940
)
940
941
except asyncio .TimeoutError :
941
- await self .disconnect ()
942
+ await self .disconnect (nowait = True )
942
943
raise TimeoutError (f"Timeout reading from { self .host } :{ self .port } " )
943
944
except OSError as e :
944
- await self .disconnect ()
945
+ await self .disconnect (nowait = True )
945
946
raise ConnectionError (
946
947
f"Error while reading from { self .host } :{ self .port } : { e .args } "
947
948
)
@@ -950,7 +951,7 @@ async def read_response(self, disable_decoding: bool = False):
950
951
# is subclass of Exception, not BaseException
951
952
raise
952
953
except Exception :
953
- await self .disconnect ()
954
+ await self .disconnect (nowait = True )
954
955
raise
955
956
956
957
if self .health_check_interval :
0 commit comments