Skip to content

Commit e075842

Browse files
Drop some superfluos tests
Signed-off-by: Anuradha Karuppiah <[email protected]>
1 parent 711ae67 commit e075842

File tree

1 file changed

+59
-76
lines changed

1 file changed

+59
-76
lines changed

tests/nat/mcp/test_mcp_auth_provider.py

Lines changed: 59 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from unittest.mock import AsyncMock
1717
from unittest.mock import patch
1818

19-
import httpx
2019
import pytest
2120
from pydantic import SecretStr
2221

@@ -322,12 +321,13 @@ async def test_authenticate_uses_config_auth_request(self, mock_config_with_auth
322321

323322
# Mock the discovery and registration process
324323
with patch.object(provider._discoverer, 'discover') as mock_discover:
325-
mock_discover.return_value = (OAuth2Endpoints(
326-
authorization_url="https://auth.example.com/authorize", # type: ignore
327-
token_url="https://auth.example.com/token", # type: ignore
328-
registration_url="https://auth.example.com/register", # type: ignore
329-
),
330-
True)
324+
mock_discover.return_value = (
325+
OAuth2Endpoints(
326+
authorization_url="https://auth.example.com/authorize", # type: ignore
327+
token_url="https://auth.example.com/token", # type: ignore
328+
registration_url="https://auth.example.com/register", # type: ignore
329+
),
330+
True)
331331

332332
with patch.object(provider._registrar, 'register') as mock_register:
333333
mock_register.return_value = OAuth2Credentials(client_id="test_client_id",
@@ -494,12 +494,12 @@ async def test_fetch_pr_issuer_success(self, mock_config):
494494
with patch("httpx.AsyncClient") as mock_client:
495495
mock_resp = AsyncMock()
496496
mock_resp.raise_for_status.return_value = None
497-
mock_resp.aread.return_value = b'{"authorization_servers": ["https://auth.example.com"]}'
497+
mock_resp.aread.return_value = b'{"resource": "https://example.com/api", "authorization_servers": ["https://auth.example.com"]}'
498498
mock_client.return_value.__aenter__.return_value.get.return_value = mock_resp
499499

500500
issuer = await discoverer._fetch_pr_issuer("https://example.com/.well-known/oauth-protected-resource")
501501

502-
assert issuer == "https://auth.example.com"
502+
assert issuer == "https://auth.example.com/"
503503

504504
async def test_fetch_pr_issuer_invalid_json(self, mock_config):
505505
"""Test protected resource issuer fetching with invalid JSON."""
@@ -522,37 +522,24 @@ async def test_fetch_pr_issuer_no_authorization_servers(self, mock_config):
522522
with patch("httpx.AsyncClient") as mock_client:
523523
mock_resp = AsyncMock()
524524
mock_resp.raise_for_status.return_value = None
525-
mock_resp.aread.return_value = b'{"other_field": "value"}'
525+
mock_resp.aread.return_value = b'{"resource": "https://example.com/api", "other_field": "value"}'
526526
mock_client.return_value.__aenter__.return_value.get.return_value = mock_resp
527527

528528
issuer = await discoverer._fetch_pr_issuer("https://example.com/.well-known/oauth-protected-resource")
529529

530530
assert issuer is None
531531

532-
async def test_fetch_pr_issuer_http_error(self, mock_config):
533-
"""Test protected resource issuer fetching with HTTP error."""
534-
discoverer = DiscoverOAuth2Endpoints(mock_config)
535-
536-
with patch("httpx.AsyncClient") as mock_client:
537-
mock_resp = AsyncMock()
538-
mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError("Not Found", request=None, response=None) # type: ignore
539-
mock_client.return_value.__aenter__.return_value.get.return_value = mock_resp
540-
541-
with pytest.raises(httpx.HTTPStatusError):
542-
await discoverer._fetch_pr_issuer("https://example.com/.well-known/oauth-protected-resource")
543-
544532
async def test_discover_via_issuer_or_base_success(self, mock_config):
545533
"""Test successful discovery via issuer or base URL."""
546534
discoverer = DiscoverOAuth2Endpoints(mock_config)
547535

548536
with patch("httpx.AsyncClient") as mock_client:
549537
mock_resp = AsyncMock()
550538
mock_resp.status_code = 200
551-
mock_resp.aread.return_value = (
552-
b'{"authorization_endpoint": "https://auth.example.com/authorize", '
553-
b'"token_endpoint": "https://auth.example.com/token", '
554-
b'"registration_endpoint": "https://auth.example.com/register"}'
555-
)
539+
mock_resp.aread.return_value = (b'{"issuer": "https://auth.example.com", '
540+
b'"authorization_endpoint": "https://auth.example.com/authorize", '
541+
b'"token_endpoint": "https://auth.example.com/token", '
542+
b'"registration_endpoint": "https://auth.example.com/register"}')
556543
mock_client.return_value.__aenter__.return_value.get.return_value = mock_resp
557544

558545
endpoints = await discoverer._discover_via_issuer_or_base("https://auth.example.com")
@@ -639,18 +626,6 @@ def test_scopes_supported(self, mock_config):
639626
discoverer._last_oauth_scopes = ["read", "write"]
640627
assert discoverer.scopes_supported() == ["read", "write"]
641628

642-
async def test_register_http_error(self, mock_config, mock_endpoints):
643-
"""Test registration with HTTP error."""
644-
registrar = DynamicClientRegistration(mock_config)
645-
646-
with patch("httpx.AsyncClient") as mock_client:
647-
mock_resp = AsyncMock()
648-
mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError("Bad Request", request=None, response=None) # type: ignore
649-
mock_client.return_value.__aenter__.return_value.post.return_value = mock_resp
650-
651-
with pytest.raises(httpx.HTTPStatusError):
652-
await registrar.register(mock_endpoints, None)
653-
654629
async def test_register_network_error(self, mock_config, mock_endpoints):
655630
"""Test registration with network error."""
656631
registrar = DynamicClientRegistration(mock_config)
@@ -668,7 +643,9 @@ async def test_register_with_scopes(self, mock_config, mock_endpoints):
668643
with patch("httpx.AsyncClient") as mock_client:
669644
mock_resp = AsyncMock()
670645
mock_resp.raise_for_status.return_value = None
671-
mock_resp.aread.return_value = b'{"client_id": "test_client_id", "client_secret": "test_secret"}'
646+
mock_resp.aread.return_value = b'{"client_id": "test_client_id", "client_secret": "test_secret",\
647+
"redirect_uris": ["https://example.com/callback"]}'
648+
672649
mock_client.return_value.__aenter__.return_value.post.return_value = mock_resp
673650

674651
credentials = await registrar.register(mock_endpoints, ["read", "write"])
@@ -684,7 +661,9 @@ async def test_register_with_token_endpoint_auth_method(self, mock_config, mock_
684661
with patch("httpx.AsyncClient") as mock_client:
685662
mock_resp = AsyncMock()
686663
mock_resp.raise_for_status.return_value = None
687-
mock_resp.aread.return_value = b'{"client_id": "test_client_id", "client_secret": "test_secret"}'
664+
mock_resp.aread.return_value = b'{"client_id": "test_client_id", "client_secret": "test_secret",\
665+
"redirect_uris": ["https://example.com/callback"]}'
666+
688667
mock_client.return_value.__aenter__.return_value.post.return_value = mock_resp
689668

690669
credentials = await registrar.register(mock_endpoints, None)
@@ -701,7 +680,8 @@ async def test_build_oauth2_delegate_success(self, mock_config, mock_endpoints,
701680
provider._cached_endpoints = mock_endpoints
702681
provider._cached_credentials = mock_credentials
703682

704-
with patch('nat.authentication.oauth2.oauth2_auth_code_flow_provider.OAuth2AuthCodeFlowProvider') as mock_oauth2_provider: # noqa: E501
683+
with patch('nat.authentication.oauth2.oauth2_auth_code_flow_provider.OAuth2AuthCodeFlowProvider'
684+
) as mock_oauth2_provider: # noqa: E501
705685
mock_oauth2_provider.return_value = AsyncMock()
706686

707687
await provider._build_oauth2_delegate()
@@ -716,7 +696,8 @@ async def test_build_oauth2_delegate_with_pkce(self, mock_config, mock_endpoints
716696
provider._cached_endpoints = mock_endpoints
717697
provider._cached_credentials = mock_credentials
718698

719-
with patch('nat.authentication.oauth2.oauth2_auth_code_flow_provider.OAuth2AuthCodeFlowProvider') as mock_oauth2_provider: # noqa: E501
699+
with patch('nat.authentication.oauth2.oauth2_auth_code_flow_provider.OAuth2AuthCodeFlowProvider'
700+
) as mock_oauth2_provider: # noqa: E501
720701
mock_oauth2_provider.return_value = AsyncMock()
721702

722703
await provider._build_oauth2_delegate()
@@ -732,7 +713,8 @@ async def test_build_oauth2_delegate_with_custom_auth_method(self, mock_config,
732713
provider._cached_endpoints = mock_endpoints
733714
provider._cached_credentials = mock_credentials
734715

735-
with patch('nat.authentication.oauth2.oauth2_auth_code_flow_provider.OAuth2AuthCodeFlowProvider') as mock_oauth2_provider: # noqa: E501
716+
with patch('nat.authentication.oauth2.oauth2_auth_code_flow_provider.OAuth2AuthCodeFlowProvider'
717+
) as mock_oauth2_provider: # noqa: E501
736718
mock_oauth2_provider.return_value = AsyncMock()
737719

738720
await provider._build_oauth2_delegate()
@@ -747,17 +729,17 @@ async def test_discover_and_register_with_endpoints_changed(self, mock_config):
747729

748730
# Mock discovery returning changed endpoints
749731
with patch.object(provider._discoverer, 'discover') as mock_discover:
750-
mock_discover.return_value = (OAuth2Endpoints(
751-
authorization_url="https://auth.example.com/authorize", # type: ignore
752-
token_url="https://auth.example.com/token", # type: ignore
753-
registration_url="https://auth.example.com/register", # type: ignore
754-
), True)
732+
mock_discover.return_value = (
733+
OAuth2Endpoints(
734+
authorization_url="https://auth.example.com/authorize", # type: ignore
735+
token_url="https://auth.example.com/token", # type: ignore
736+
registration_url="https://auth.example.com/register", # type: ignore
737+
),
738+
True)
755739

756740
with patch.object(provider._registrar, 'register') as mock_register:
757-
mock_register.return_value = OAuth2Credentials(
758-
client_id="test_client_id",
759-
client_secret="test_client_secret"
760-
)
741+
mock_register.return_value = OAuth2Credentials(client_id="test_client_id",
742+
client_secret="test_client_secret")
761743

762744
auth_request = AuthRequest(reason=AuthReason.RETRY_AFTER_401)
763745
await provider._discover_and_register(auth_request)
@@ -768,16 +750,17 @@ async def test_discover_and_register_with_endpoints_changed(self, mock_config):
768750
async def test_discover_and_register_with_manual_credentials(self, mock_config):
769751
"""Test discover and register with manual credentials."""
770752
config = mock_config.model_copy(update={
771-
'client_id': 'manual_client_id',
772-
'client_secret': 'manual_client_secret'
753+
'client_id': 'manual_client_id', 'client_secret': 'manual_client_secret'
773754
})
774755
provider = MCPOAuth2Provider(config)
775756

776757
with patch.object(provider._discoverer, 'discover') as mock_discover:
777-
mock_discover.return_value = (OAuth2Endpoints(
778-
authorization_url="https://auth.example.com/authorize", # type: ignore
779-
token_url="https://auth.example.com/token", # type: ignore
780-
), True)
758+
mock_discover.return_value = (
759+
OAuth2Endpoints(
760+
authorization_url="https://auth.example.com/authorize", # type: ignore
761+
token_url="https://auth.example.com/token", # type: ignore
762+
),
763+
True)
781764

782765
auth_request = AuthRequest(reason=AuthReason.RETRY_AFTER_401)
783766
await provider._discover_and_register(auth_request)
@@ -792,17 +775,17 @@ async def test_discover_and_register_without_registration_endpoint(self, mock_co
792775
provider = MCPOAuth2Provider(mock_config)
793776

794777
with patch.object(provider._discoverer, 'discover') as mock_discover:
795-
mock_discover.return_value = (OAuth2Endpoints(
796-
authorization_url="https://auth.example.com/authorize", # type: ignore
797-
token_url="https://auth.example.com/token", # type: ignore
798-
registration_url=None, # No registration endpoint
799-
), True)
778+
mock_discover.return_value = (
779+
OAuth2Endpoints(
780+
authorization_url="https://auth.example.com/authorize", # type: ignore
781+
token_url="https://auth.example.com/token", # type: ignore
782+
registration_url=None, # No registration endpoint
783+
),
784+
True)
800785

801786
with patch.object(provider._registrar, 'register') as mock_register:
802-
mock_register.return_value = OAuth2Credentials(
803-
client_id="test_client_id",
804-
client_secret="test_client_secret"
805-
)
787+
mock_register.return_value = OAuth2Credentials(client_id="test_client_id",
788+
client_secret="test_client_secret")
806789

807790
auth_request = AuthRequest(reason=AuthReason.RETRY_AFTER_401)
808791
await provider._discover_and_register(auth_request)
@@ -812,9 +795,9 @@ async def test_discover_and_register_without_registration_endpoint(self, mock_co
812795

813796
async def test_authenticate_with_user_id_propagation(self, mock_config_with_credentials, mock_endpoints):
814797
"""Test that user_id is properly propagated in auth request."""
815-
config = mock_config_with_credentials.model_copy(
816-
update={'auth_request': AuthRequest(reason=AuthReason.RETRY_AFTER_401, www_authenticate="Bearer realm=api")}
817-
)
798+
config = mock_config_with_credentials.model_copy(update={
799+
'auth_request': AuthRequest(reason=AuthReason.RETRY_AFTER_401, www_authenticate="Bearer realm=api")
800+
})
818801
provider = MCPOAuth2Provider(config)
819802

820803
with patch.object(provider._discoverer, 'discover') as mock_discover:
@@ -837,8 +820,7 @@ async def test_authenticate_with_user_id_propagation(self, mock_config_with_cred
837820
async def test_authenticate_without_user_id_in_request(self, mock_config_with_credentials, mock_endpoints):
838821
"""Test authentication when user_id is not in the original request."""
839822
config = mock_config_with_credentials.model_copy(
840-
update={'auth_request': AuthRequest(reason=AuthReason.RETRY_AFTER_401)}
841-
)
823+
update={'auth_request': AuthRequest(reason=AuthReason.RETRY_AFTER_401)})
842824
provider = MCPOAuth2Provider(config)
843825

844826
with patch.object(provider._discoverer, 'discover') as mock_discover:
@@ -858,11 +840,12 @@ async def test_authenticate_without_user_id_in_request(self, mock_config_with_cr
858840
if auth_request:
859841
assert auth_request.user_id == "test_user"
860842

861-
async def test_authenticate_retry_after_401_clears_auth_code_provider(self, mock_config_with_credentials, mock_endpoints): # noqa: E501
843+
async def test_authenticate_retry_after_401_clears_auth_code_provider(self,
844+
mock_config_with_credentials,
845+
mock_endpoints): # noqa: E501
862846
"""Test that RETRY_AFTER_401 clears the auth code provider."""
863847
config = mock_config_with_credentials.model_copy(
864-
update={'auth_request': AuthRequest(reason=AuthReason.RETRY_AFTER_401)}
865-
)
848+
update={'auth_request': AuthRequest(reason=AuthReason.RETRY_AFTER_401)})
866849
provider = MCPOAuth2Provider(config)
867850

868851
# Set up a mock auth code provider

0 commit comments

Comments
 (0)