Skip to content

Commit 8186325

Browse files
committed
remove the original python dict token storage
Signed-off-by: Yuchen Zhang <[email protected]>
1 parent 05039c2 commit 8186325

File tree

2 files changed

+11
-15
lines changed

2 files changed

+11
-15
lines changed

packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ async def _get_auth_headers(self,
146146
token = auth_result.credentials[0].token.get_secret_value()
147147
return {"Authorization": f"Bearer {token}"}
148148
else:
149+
logger.info("Auth provider did not return BearerTokenCred")
149150
return {}
150151
except Exception as e:
151152
logger.warning("Failed to get auth token: %s", e)

src/nat/authentication/oauth2/oauth2_auth_code_flow_provider.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,13 @@ class OAuth2AuthCodeFlowProvider(AuthProviderBase[OAuth2AuthCodeFlowProviderConf
3838

3939
def __init__(self, config: OAuth2AuthCodeFlowProviderConfig, token_storage=None):
4040
super().__init__(config)
41-
self._authenticated_tokens: dict[str, AuthResult] = {}
4241
self._auth_callback = None
43-
self._token_storage = token_storage
42+
# Always use token storage - defaults to in-memory if not provided
43+
if token_storage is None:
44+
from nat.plugins.mcp.auth.token_storage import InMemoryTokenStorage
45+
self._token_storage = InMemoryTokenStorage()
46+
else:
47+
self._token_storage = token_storage
4448

4549
async def _attempt_token_refresh(self, user_id: str, auth_result: AuthResult) -> AuthResult | None:
4650
refresh_token = auth_result.raw.get("refresh_token")
@@ -63,10 +67,7 @@ async def _attempt_token_refresh(self, user_id: str, auth_result: AuthResult) ->
6367
raw=new_token_data,
6468
)
6569

66-
if self._token_storage:
67-
await self._token_storage.store(user_id, new_auth_result)
68-
else:
69-
self._authenticated_tokens[user_id] = new_auth_result
70+
await self._token_storage.store(user_id, new_auth_result)
7071
except httpx.HTTPStatusError:
7172
return None
7273
except httpx.RequestError:
@@ -93,11 +94,8 @@ async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult
9394
user_id = session_id
9495

9596
if user_id:
96-
# Try to retrieve from token storage or fallback to dict
97-
if self._token_storage:
98-
auth_result = await self._token_storage.retrieve(user_id)
99-
else:
100-
auth_result = self._authenticated_tokens.get(user_id)
97+
# Try to retrieve from token storage
98+
auth_result = await self._token_storage.retrieve(user_id)
10199

102100
if auth_result:
103101
if not auth_result.is_expired():
@@ -137,9 +135,6 @@ async def authenticate(self, user_id: str | None = None, **kwargs) -> AuthResult
137135
)
138136

139137
if user_id:
140-
if self._token_storage:
141-
await self._token_storage.store(user_id, auth_result)
142-
else:
143-
self._authenticated_tokens[user_id] = auth_result
138+
await self._token_storage.store(user_id, auth_result)
144139

145140
return auth_result

0 commit comments

Comments
 (0)