Skip to content

Commit 70ae26c

Browse files
committed
Fix #1065 by having lock for async Socket Mode client reconnection
1 parent 8bbb5d4 commit 70ae26c

File tree

4 files changed

+42
-13
lines changed

4 files changed

+42
-13
lines changed

integration_tests/samples/socket_mode/aiohttp_example.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ async def process(client: SocketModeClient, req: SocketModeRequest):
2323
if req.type == "events_api":
2424
response = SocketModeResponse(envelope_id=req.envelope_id)
2525
await client.send_socket_mode_response(response)
26-
27-
await client.web_client.reactions_add(
28-
name="eyes",
29-
channel=req.payload["event"]["channel"],
30-
timestamp=req.payload["event"]["ts"],
31-
)
26+
if req.payload["event"]["type"] == "message":
27+
await client.web_client.reactions_add(
28+
name="eyes",
29+
channel=req.payload["event"]["channel"],
30+
timestamp=req.payload["event"]["ts"],
31+
)
3232

3333
client.socket_mode_request_listeners.append(process)
3434
await client.connect()

slack_sdk/socket_mode/aiohttp/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"""
88
import asyncio
99
import logging
10-
from asyncio import Future
10+
from asyncio import Future, Lock
1111
from asyncio import Queue
1212
from logging import Logger
1313
from typing import Union, Optional, List, Callable, Awaitable
@@ -58,6 +58,7 @@ class SocketModeClient(AsyncBaseSocketModeClient):
5858
auto_reconnect_enabled: bool
5959
default_auto_reconnect_enabled: bool
6060
closed: bool
61+
connect_operation_lock: Lock
6162

6263
on_message_listeners: List[Callable[[WSMessage], Awaitable[None]]]
6364
on_error_listeners: List[Callable[[WSMessage], Awaitable[None]]]
@@ -92,6 +93,7 @@ def __init__(
9293
self.logger = logger or logging.getLogger(__name__)
9394
self.web_client = web_client or AsyncWebClient()
9495
self.closed = False
96+
self.connect_operation_lock = Lock()
9597
self.proxy = proxy
9698
if self.proxy is None or len(self.proxy.strip()) == 0:
9799
env_variable = load_http_proxy_from_env(self.logger)
@@ -185,6 +187,13 @@ async def receive_messages(self) -> None:
185187
else:
186188
await asyncio.sleep(consecutive_error_count)
187189

190+
async def is_connected(self) -> bool:
191+
return (
192+
not self.closed
193+
and self.current_session is not None
194+
and not self.current_session.closed
195+
)
196+
188197
async def connect(self):
189198
old_session = None if self.current_session is None else self.current_session
190199
if self.wss_uri is None:

slack_sdk/socket_mode/async_client.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import json
33
import logging
4-
from asyncio import Queue
4+
from asyncio import Queue, Lock
55
from asyncio.futures import Future
66
from logging import Logger
77
from typing import Dict, Union, Any, Optional, List, Callable, Awaitable
@@ -23,6 +23,8 @@ class AsyncBaseSocketModeClient:
2323
wss_uri: str
2424
auto_reconnect_enabled: bool
2525
closed: bool
26+
connect_operation_lock: Lock
27+
2628
message_queue: Queue
2729
message_listeners: List[
2830
Union[
@@ -58,15 +60,24 @@ async def issue_new_wss_url(self) -> str:
5860
self.logger.error(f"Failed to retrieve WSS URL: {e}")
5961
raise e
6062

63+
async def is_connected(self) -> bool:
64+
return False
65+
6166
async def connect(self):
6267
raise NotImplementedError()
6368

6469
async def disconnect(self):
6570
raise NotImplementedError()
6671

67-
async def connect_to_new_endpoint(self):
68-
self.wss_uri = await self.issue_new_wss_url()
69-
await self.connect()
72+
async def connect_to_new_endpoint(self, force: bool = False):
73+
try:
74+
await self.connect_operation_lock.acquire()
75+
if force or not await self.is_connected():
76+
self.wss_uri = await self.issue_new_wss_url()
77+
await self.connect()
78+
finally:
79+
if self.connect_operation_lock.locked() is True:
80+
self.connect_operation_lock.release()
7081

7182
async def close(self):
7283
self.closed = True
@@ -116,7 +127,7 @@ async def run_message_listeners(self, message: dict, raw_message: str) -> None:
116127
)
117128
try:
118129
if message.get("type") == "disconnect":
119-
await self.connect_to_new_endpoint()
130+
await self.connect_to_new_endpoint(force=True)
120131
return
121132

122133
for listener in self.message_listeners:

slack_sdk/socket_mode/websockets/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"""
88
import asyncio
99
import logging
10-
from asyncio import Future
10+
from asyncio import Future, Lock
1111
from logging import Logger
1212
from asyncio import Queue
1313
from typing import Union, Optional, List, Callable, Awaitable
@@ -56,6 +56,7 @@ class SocketModeClient(AsyncBaseSocketModeClient):
5656
auto_reconnect_enabled: bool
5757
default_auto_reconnect_enabled: bool
5858
closed: bool
59+
connect_operation_lock: Lock
5960

6061
def __init__(
6162
self,
@@ -78,6 +79,7 @@ def __init__(
7879
self.logger = logger or logging.getLogger(__name__)
7980
self.web_client = web_client or AsyncWebClient()
8081
self.closed = False
82+
self.connect_operation_lock = Lock()
8183
self.default_auto_reconnect_enabled = auto_reconnect_enabled
8284
self.auto_reconnect_enabled = self.default_auto_reconnect_enabled
8385
self.ping_interval = ping_interval
@@ -130,6 +132,13 @@ async def receive_messages(self) -> None:
130132
else:
131133
await asyncio.sleep(consecutive_error_count)
132134

135+
async def is_connected(self) -> bool:
136+
return (
137+
not self.closed
138+
and self.current_session is not None
139+
and not self.current_session.closed
140+
)
141+
133142
async def connect(self):
134143
if self.wss_uri is None:
135144
self.wss_uri = await self.issue_new_wss_url()

0 commit comments

Comments
 (0)