1616from unittest .mock import AsyncMock
1717from unittest .mock import patch
1818
19- import httpx
2019import pytest
2120from pydantic import SecretStr
2221
@@ -322,12 +321,13 @@ async def test_authenticate_uses_config_auth_request(self, mock_config_with_auth
322321
323322 # Mock the discovery and registration process
324323 with patch .object (provider ._discoverer , 'discover' ) as mock_discover :
325- mock_discover .return_value = (OAuth2Endpoints (
326- authorization_url = "https://auth.example.com/authorize" , # type: ignore
327- token_url = "https://auth.example.com/token" , # type: ignore
328- registration_url = "https://auth.example.com/register" , # type: ignore
329- ),
330- True )
324+ mock_discover .return_value = (
325+ OAuth2Endpoints (
326+ authorization_url = "https://auth.example.com/authorize" , # type: ignore
327+ token_url = "https://auth.example.com/token" , # type: ignore
328+ registration_url = "https://auth.example.com/register" , # type: ignore
329+ ),
330+ True )
331331
332332 with patch .object (provider ._registrar , 'register' ) as mock_register :
333333 mock_register .return_value = OAuth2Credentials (client_id = "test_client_id" ,
@@ -494,12 +494,12 @@ async def test_fetch_pr_issuer_success(self, mock_config):
494494 with patch ("httpx.AsyncClient" ) as mock_client :
495495 mock_resp = AsyncMock ()
496496 mock_resp .raise_for_status .return_value = None
497- mock_resp .aread .return_value = b'{"authorization_servers": ["https://auth.example.com"]}'
497+ mock_resp .aread .return_value = b'{"resource": "https://example.com/api", " authorization_servers": ["https://auth.example.com"]}'
498498 mock_client .return_value .__aenter__ .return_value .get .return_value = mock_resp
499499
500500 issuer = await discoverer ._fetch_pr_issuer ("https://example.com/.well-known/oauth-protected-resource" )
501501
502- assert issuer == "https://auth.example.com"
502+ assert issuer == "https://auth.example.com/ "
503503
504504 async def test_fetch_pr_issuer_invalid_json (self , mock_config ):
505505 """Test protected resource issuer fetching with invalid JSON."""
@@ -522,37 +522,24 @@ async def test_fetch_pr_issuer_no_authorization_servers(self, mock_config):
522522 with patch ("httpx.AsyncClient" ) as mock_client :
523523 mock_resp = AsyncMock ()
524524 mock_resp .raise_for_status .return_value = None
525- mock_resp .aread .return_value = b'{"other_field": "value"}'
525+ mock_resp .aread .return_value = b'{"resource": "https://example.com/api", " other_field": "value"}'
526526 mock_client .return_value .__aenter__ .return_value .get .return_value = mock_resp
527527
528528 issuer = await discoverer ._fetch_pr_issuer ("https://example.com/.well-known/oauth-protected-resource" )
529529
530530 assert issuer is None
531531
532- async def test_fetch_pr_issuer_http_error (self , mock_config ):
533- """Test protected resource issuer fetching with HTTP error."""
534- discoverer = DiscoverOAuth2Endpoints (mock_config )
535-
536- with patch ("httpx.AsyncClient" ) as mock_client :
537- mock_resp = AsyncMock ()
538- mock_resp .raise_for_status .side_effect = httpx .HTTPStatusError ("Not Found" , request = None , response = None ) # type: ignore
539- mock_client .return_value .__aenter__ .return_value .get .return_value = mock_resp
540-
541- with pytest .raises (httpx .HTTPStatusError ):
542- await discoverer ._fetch_pr_issuer ("https://example.com/.well-known/oauth-protected-resource" )
543-
544532 async def test_discover_via_issuer_or_base_success (self , mock_config ):
545533 """Test successful discovery via issuer or base URL."""
546534 discoverer = DiscoverOAuth2Endpoints (mock_config )
547535
548536 with patch ("httpx.AsyncClient" ) as mock_client :
549537 mock_resp = AsyncMock ()
550538 mock_resp .status_code = 200
551- mock_resp .aread .return_value = (
552- b'{"authorization_endpoint": "https://auth.example.com/authorize", '
553- b'"token_endpoint": "https://auth.example.com/token", '
554- b'"registration_endpoint": "https://auth.example.com/register"}'
555- )
539+ mock_resp .aread .return_value = (b'{"issuer": "https://auth.example.com", '
540+ b'"authorization_endpoint": "https://auth.example.com/authorize", '
541+ b'"token_endpoint": "https://auth.example.com/token", '
542+ b'"registration_endpoint": "https://auth.example.com/register"}' )
556543 mock_client .return_value .__aenter__ .return_value .get .return_value = mock_resp
557544
558545 endpoints = await discoverer ._discover_via_issuer_or_base ("https://auth.example.com" )
@@ -639,18 +626,6 @@ def test_scopes_supported(self, mock_config):
639626 discoverer ._last_oauth_scopes = ["read" , "write" ]
640627 assert discoverer .scopes_supported () == ["read" , "write" ]
641628
642- async def test_register_http_error (self , mock_config , mock_endpoints ):
643- """Test registration with HTTP error."""
644- registrar = DynamicClientRegistration (mock_config )
645-
646- with patch ("httpx.AsyncClient" ) as mock_client :
647- mock_resp = AsyncMock ()
648- mock_resp .raise_for_status .side_effect = httpx .HTTPStatusError ("Bad Request" , request = None , response = None ) # type: ignore
649- mock_client .return_value .__aenter__ .return_value .post .return_value = mock_resp
650-
651- with pytest .raises (httpx .HTTPStatusError ):
652- await registrar .register (mock_endpoints , None )
653-
654629 async def test_register_network_error (self , mock_config , mock_endpoints ):
655630 """Test registration with network error."""
656631 registrar = DynamicClientRegistration (mock_config )
@@ -668,7 +643,9 @@ async def test_register_with_scopes(self, mock_config, mock_endpoints):
668643 with patch ("httpx.AsyncClient" ) as mock_client :
669644 mock_resp = AsyncMock ()
670645 mock_resp .raise_for_status .return_value = None
671- mock_resp .aread .return_value = b'{"client_id": "test_client_id", "client_secret": "test_secret"}'
646+ mock_resp .aread .return_value = b'{"client_id": "test_client_id", "client_secret": "test_secret",\
647+ "redirect_uris": ["https://example.com/callback"]}'
648+
672649 mock_client .return_value .__aenter__ .return_value .post .return_value = mock_resp
673650
674651 credentials = await registrar .register (mock_endpoints , ["read" , "write" ])
@@ -684,7 +661,9 @@ async def test_register_with_token_endpoint_auth_method(self, mock_config, mock_
684661 with patch ("httpx.AsyncClient" ) as mock_client :
685662 mock_resp = AsyncMock ()
686663 mock_resp .raise_for_status .return_value = None
687- mock_resp .aread .return_value = b'{"client_id": "test_client_id", "client_secret": "test_secret"}'
664+ mock_resp .aread .return_value = b'{"client_id": "test_client_id", "client_secret": "test_secret",\
665+ "redirect_uris": ["https://example.com/callback"]}'
666+
688667 mock_client .return_value .__aenter__ .return_value .post .return_value = mock_resp
689668
690669 credentials = await registrar .register (mock_endpoints , None )
@@ -701,7 +680,8 @@ async def test_build_oauth2_delegate_success(self, mock_config, mock_endpoints,
701680 provider ._cached_endpoints = mock_endpoints
702681 provider ._cached_credentials = mock_credentials
703682
704- with patch ('nat.authentication.oauth2.oauth2_auth_code_flow_provider.OAuth2AuthCodeFlowProvider' ) as mock_oauth2_provider : # noqa: E501
683+ with patch ('nat.authentication.oauth2.oauth2_auth_code_flow_provider.OAuth2AuthCodeFlowProvider'
684+ ) as mock_oauth2_provider : # noqa: E501
705685 mock_oauth2_provider .return_value = AsyncMock ()
706686
707687 await provider ._build_oauth2_delegate ()
@@ -716,7 +696,8 @@ async def test_build_oauth2_delegate_with_pkce(self, mock_config, mock_endpoints
716696 provider ._cached_endpoints = mock_endpoints
717697 provider ._cached_credentials = mock_credentials
718698
719- with patch ('nat.authentication.oauth2.oauth2_auth_code_flow_provider.OAuth2AuthCodeFlowProvider' ) as mock_oauth2_provider : # noqa: E501
699+ with patch ('nat.authentication.oauth2.oauth2_auth_code_flow_provider.OAuth2AuthCodeFlowProvider'
700+ ) as mock_oauth2_provider : # noqa: E501
720701 mock_oauth2_provider .return_value = AsyncMock ()
721702
722703 await provider ._build_oauth2_delegate ()
@@ -732,7 +713,8 @@ async def test_build_oauth2_delegate_with_custom_auth_method(self, mock_config,
732713 provider ._cached_endpoints = mock_endpoints
733714 provider ._cached_credentials = mock_credentials
734715
735- with patch ('nat.authentication.oauth2.oauth2_auth_code_flow_provider.OAuth2AuthCodeFlowProvider' ) as mock_oauth2_provider : # noqa: E501
716+ with patch ('nat.authentication.oauth2.oauth2_auth_code_flow_provider.OAuth2AuthCodeFlowProvider'
717+ ) as mock_oauth2_provider : # noqa: E501
736718 mock_oauth2_provider .return_value = AsyncMock ()
737719
738720 await provider ._build_oauth2_delegate ()
@@ -747,17 +729,17 @@ async def test_discover_and_register_with_endpoints_changed(self, mock_config):
747729
748730 # Mock discovery returning changed endpoints
749731 with patch .object (provider ._discoverer , 'discover' ) as mock_discover :
750- mock_discover .return_value = (OAuth2Endpoints (
751- authorization_url = "https://auth.example.com/authorize" , # type: ignore
752- token_url = "https://auth.example.com/token" , # type: ignore
753- registration_url = "https://auth.example.com/register" , # type: ignore
754- ), True )
732+ mock_discover .return_value = (
733+ OAuth2Endpoints (
734+ authorization_url = "https://auth.example.com/authorize" , # type: ignore
735+ token_url = "https://auth.example.com/token" , # type: ignore
736+ registration_url = "https://auth.example.com/register" , # type: ignore
737+ ),
738+ True )
755739
756740 with patch .object (provider ._registrar , 'register' ) as mock_register :
757- mock_register .return_value = OAuth2Credentials (
758- client_id = "test_client_id" ,
759- client_secret = "test_client_secret"
760- )
741+ mock_register .return_value = OAuth2Credentials (client_id = "test_client_id" ,
742+ client_secret = "test_client_secret" )
761743
762744 auth_request = AuthRequest (reason = AuthReason .RETRY_AFTER_401 )
763745 await provider ._discover_and_register (auth_request )
@@ -768,16 +750,17 @@ async def test_discover_and_register_with_endpoints_changed(self, mock_config):
768750 async def test_discover_and_register_with_manual_credentials (self , mock_config ):
769751 """Test discover and register with manual credentials."""
770752 config = mock_config .model_copy (update = {
771- 'client_id' : 'manual_client_id' ,
772- 'client_secret' : 'manual_client_secret'
753+ 'client_id' : 'manual_client_id' , 'client_secret' : 'manual_client_secret'
773754 })
774755 provider = MCPOAuth2Provider (config )
775756
776757 with patch .object (provider ._discoverer , 'discover' ) as mock_discover :
777- mock_discover .return_value = (OAuth2Endpoints (
778- authorization_url = "https://auth.example.com/authorize" , # type: ignore
779- token_url = "https://auth.example.com/token" , # type: ignore
780- ), True )
758+ mock_discover .return_value = (
759+ OAuth2Endpoints (
760+ authorization_url = "https://auth.example.com/authorize" , # type: ignore
761+ token_url = "https://auth.example.com/token" , # type: ignore
762+ ),
763+ True )
781764
782765 auth_request = AuthRequest (reason = AuthReason .RETRY_AFTER_401 )
783766 await provider ._discover_and_register (auth_request )
@@ -792,17 +775,17 @@ async def test_discover_and_register_without_registration_endpoint(self, mock_co
792775 provider = MCPOAuth2Provider (mock_config )
793776
794777 with patch .object (provider ._discoverer , 'discover' ) as mock_discover :
795- mock_discover .return_value = (OAuth2Endpoints (
796- authorization_url = "https://auth.example.com/authorize" , # type: ignore
797- token_url = "https://auth.example.com/token" , # type: ignore
798- registration_url = None , # No registration endpoint
799- ), True )
778+ mock_discover .return_value = (
779+ OAuth2Endpoints (
780+ authorization_url = "https://auth.example.com/authorize" , # type: ignore
781+ token_url = "https://auth.example.com/token" , # type: ignore
782+ registration_url = None , # No registration endpoint
783+ ),
784+ True )
800785
801786 with patch .object (provider ._registrar , 'register' ) as mock_register :
802- mock_register .return_value = OAuth2Credentials (
803- client_id = "test_client_id" ,
804- client_secret = "test_client_secret"
805- )
787+ mock_register .return_value = OAuth2Credentials (client_id = "test_client_id" ,
788+ client_secret = "test_client_secret" )
806789
807790 auth_request = AuthRequest (reason = AuthReason .RETRY_AFTER_401 )
808791 await provider ._discover_and_register (auth_request )
@@ -812,9 +795,9 @@ async def test_discover_and_register_without_registration_endpoint(self, mock_co
812795
813796 async def test_authenticate_with_user_id_propagation (self , mock_config_with_credentials , mock_endpoints ):
814797 """Test that user_id is properly propagated in auth request."""
815- config = mock_config_with_credentials .model_copy (
816- update = { 'auth_request' : AuthRequest (reason = AuthReason .RETRY_AFTER_401 , www_authenticate = "Bearer realm=api" )}
817- )
798+ config = mock_config_with_credentials .model_copy (update = {
799+ 'auth_request' : AuthRequest (reason = AuthReason .RETRY_AFTER_401 , www_authenticate = "Bearer realm=api" )
800+ } )
818801 provider = MCPOAuth2Provider (config )
819802
820803 with patch .object (provider ._discoverer , 'discover' ) as mock_discover :
@@ -837,8 +820,7 @@ async def test_authenticate_with_user_id_propagation(self, mock_config_with_cred
837820 async def test_authenticate_without_user_id_in_request (self , mock_config_with_credentials , mock_endpoints ):
838821 """Test authentication when user_id is not in the original request."""
839822 config = mock_config_with_credentials .model_copy (
840- update = {'auth_request' : AuthRequest (reason = AuthReason .RETRY_AFTER_401 )}
841- )
823+ update = {'auth_request' : AuthRequest (reason = AuthReason .RETRY_AFTER_401 )})
842824 provider = MCPOAuth2Provider (config )
843825
844826 with patch .object (provider ._discoverer , 'discover' ) as mock_discover :
@@ -858,11 +840,12 @@ async def test_authenticate_without_user_id_in_request(self, mock_config_with_cr
858840 if auth_request :
859841 assert auth_request .user_id == "test_user"
860842
861- async def test_authenticate_retry_after_401_clears_auth_code_provider (self , mock_config_with_credentials , mock_endpoints ): # noqa: E501
843+ async def test_authenticate_retry_after_401_clears_auth_code_provider (self ,
844+ mock_config_with_credentials ,
845+ mock_endpoints ): # noqa: E501
862846 """Test that RETRY_AFTER_401 clears the auth code provider."""
863847 config = mock_config_with_credentials .model_copy (
864- update = {'auth_request' : AuthRequest (reason = AuthReason .RETRY_AFTER_401 )}
865- )
848+ update = {'auth_request' : AuthRequest (reason = AuthReason .RETRY_AFTER_401 )})
866849 provider = MCPOAuth2Provider (config )
867850
868851 # Set up a mock auth code provider
0 commit comments