Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions authlib/integrations/base_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .errors import InvalidTokenError
from .errors import MismatchingStateError
from .errors import MissingCodeError
from .errors import MissingRequestTokenError
from .errors import MissingTokenError
from .errors import OAuthError
Expand All @@ -22,6 +23,7 @@
"OAuthError",
"MissingRequestTokenError",
"MissingTokenError",
"MissingCodeError",
"TokenExpiredError",
"InvalidTokenError",
"UnsupportedTokenTypeError",
Expand Down
5 changes: 5 additions & 0 deletions authlib/integrations/base_client/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ class MissingTokenError(OAuthError):
error = "missing_token"


class MissingCodeError(OAuthError):
error = "missing_code"
description = "The authorization code is missing from the callback request."


class TokenExpiredError(OAuthError):
error = "token_expired"

Expand Down
4 changes: 4 additions & 0 deletions authlib/integrations/django_client/apps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from django.http import HttpResponseRedirect

from ..base_client import BaseApp
from ..base_client import MissingCodeError
from ..base_client import OAuth1Mixin
from ..base_client import OAuth2Mixin
from ..base_client import OAuthError
Expand Down Expand Up @@ -78,6 +79,9 @@ def authorize_access_token(self, request, **kwargs):
"state": request.POST.get("state"),
}

if not params["code"]:
raise MissingCodeError()

state_data = self.framework.get_state_data(request.session, params.get("state"))
self.framework.clear_state_data(request.session, params.get("state"))
params = self._format_state_params(state_data, params)
Expand Down
4 changes: 4 additions & 0 deletions authlib/integrations/flask_client/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from flask import session

from ..base_client import BaseApp
from ..base_client import MissingCodeError
from ..base_client import OAuth1Mixin
from ..base_client import OAuth2Mixin
from ..base_client import OAuthError
Expand Down Expand Up @@ -100,6 +101,9 @@ def authorize_access_token(self, **kwargs):
"state": request.form.get("state"),
}

if not params["code"]:
raise MissingCodeError()

state_data = self.framework.get_state_data(session, params.get("state"))
self.framework.clear_state_data(session, params.get("state"))
params = self._format_state_params(state_data, params)
Expand Down
4 changes: 4 additions & 0 deletions authlib/integrations/starlette_client/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from starlette.responses import RedirectResponse

from ..base_client import BaseApp
from ..base_client import MissingCodeError
from ..base_client import OAuthError
from ..base_client.async_app import AsyncOAuth1Mixin
from ..base_client.async_app import AsyncOAuth2Mixin
Expand Down Expand Up @@ -73,6 +74,9 @@ async def authorize_access_token(self, request, **kwargs):
"state": request.query_params.get("state"),
}

if not params["code"]:
raise MissingCodeError()

if self.framework.cache:
session = None
else:
Expand Down
6 changes: 3 additions & 3 deletions tests/clients/test_django/test_oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_oauth2_authorize(self):

with mock.patch("requests.sessions.Session.send") as send:
send.return_value = mock_send_value(get_bearer_token())
request2 = self.factory.get(f"/authorize?state={state}")
request2 = self.factory.get(f"/authorize?state={state}&code=foo")
request2.session = request.session

token = client.authorize_access_token(request2)
Expand Down Expand Up @@ -162,7 +162,7 @@ def fake_send(sess, req, **kwargs):
return mock_send_value(get_bearer_token())

with mock.patch("requests.sessions.Session.send", fake_send):
request2 = self.factory.get(f"/authorize?state={state}")
request2 = self.factory.get(f"/authorize?state={state}&code=foo")
request2.session = request.session
token = client.authorize_access_token(request2)
assert token["access_token"] == "a"
Expand Down Expand Up @@ -193,7 +193,7 @@ def test_oauth2_authorize_code_verifier(self):
with mock.patch("requests.sessions.Session.send") as send:
send.return_value = mock_send_value(get_bearer_token())

request2 = self.factory.get(f"/authorize?state={state}")
request2 = self.factory.get(f"/authorize?state={state}&code=foo")
request2.session = request.session

token = client.authorize_access_token(request2)
Expand Down
29 changes: 29 additions & 0 deletions tests/clients/test_flask/test_oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from authlib.common.urls import url_decode
from authlib.common.urls import urlparse
from authlib.integrations.base_client.errors import MissingCodeError
from authlib.integrations.flask_client import FlaskOAuth2App
from authlib.integrations.flask_client import OAuth
from authlib.integrations.flask_client import OAuthError
Expand Down Expand Up @@ -525,3 +526,31 @@ def fake_send(sess, req, **kwargs):
assert resp.text == "hi"
with pytest.raises(OAuthError):
client.get("https://i.b/api/user")

def test_oauth2_authorize_missing_code(self):
app = Flask(__name__)
app.secret_key = "!"
oauth = OAuth(app)
client = oauth.register(
"dev",
client_id="dev",
client_secret="dev",
api_base_url="https://i.b/api",
access_token_url="https://i.b/token",
authorize_url="https://i.b/authorize",
)

with app.test_request_context():
resp = client.authorize_redirect("https://b.com/bar")
state = dict(url_decode(urlparse.urlparse(resp.headers["Location"]).query))[
"state"
]
session_data = session[f"_state_dev_{state}"]

# Test missing code parameter
with app.test_request_context(path=f"/?state={state}"):
session[f"_state_dev_{state}"] = session_data
with pytest.raises(MissingCodeError) as exc_info:
client.authorize_access_token()
assert exc_info.value.error == "missing_code"
assert "authorization code is missing" in exc_info.value.description