Skip to content

Commit bd9700c

Browse files
jpadillaclaudepre-commit-ci[bot]
authored
Use PyJWK algorithm when encoding without explicit algorithm (#1148)
* fix: use PyJWK key algorithm when encoding without explicit algorithm (#1147) When a PyJWK object is passed to jwt.encode() without specifying an algorithm, the key's embedded algorithm is now used instead of defaulting to HS256. This is achieved by using a sentinel default value so the code can distinguish "no algorithm specified" from an explicit algorithm parameter. https://claude.ai/code/session_016Ekc2jQzpuiDpBvnMAMnUB * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update changelog --------- Co-authored-by: Claude <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: José Padilla <[email protected]>
1 parent 051ea34 commit bd9700c

5 files changed

Lines changed: 50 additions & 4 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Fixed
1717
- Close ``HTTPError`` response to prevent ``ResourceWarning`` on Python 3.14 by @veeceey in `#1133 <https://github.com/jpadilla/pyjwt/pull/1133>`__
1818
- Do not keep ``algorithms`` dict in PyJWK instances by @akx in `#1143 <https://github.com/jpadilla/pyjwt/pull/1143>`__
1919
- Validate the crit (Critical) Header Parameter defined in RFC 7515 §4.1.11. by @dmbs335 in `GHSA-752w-5fwx-jx9f <https://github.com/jpadilla/pyjwt/security/advisories/GHSA-752w-5fwx-jx9f>`__
20+
- Use PyJWK algorithm when encoding without explicit algorithm in `#1148 <https://github.com/jpadilla/pyjwt/pull/1148>`__
2021

2122
Added
2223
~~~~~

jwt/api_jws.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
2828
from .types import SigOptions
2929

30+
_ALGORITHM_UNSET = object()
31+
3032

3133
class PyJWS:
3234
header_typ = "JWT"
@@ -119,7 +121,7 @@ def encode(
119121
self,
120122
payload: bytes,
121123
key: AllowedPrivateKeys | PyJWK | str | bytes,
122-
algorithm: str | None = "HS256",
124+
algorithm: str | None = _ALGORITHM_UNSET, # type: ignore[assignment]
123125
headers: dict[str, Any] | None = None,
124126
json_encoder: type[json.JSONEncoder] | None = None,
125127
is_payload_detached: bool = False,
@@ -128,7 +130,12 @@ def encode(
128130
segments: list[bytes] = []
129131

130132
# declare a new var to narrow the type for type checkers
131-
if algorithm is None:
133+
if algorithm is _ALGORITHM_UNSET:
134+
if isinstance(key, PyJWK):
135+
algorithm_ = key.algorithm_name
136+
else:
137+
algorithm_ = "HS256"
138+
elif algorithm is None:
132139
if isinstance(key, PyJWK):
133140
algorithm_ = key.algorithm_name
134141
else:

jwt/api_jwt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from datetime import datetime, timedelta, timezone
99
from typing import TYPE_CHECKING, Any, Union, cast
1010

11-
from .api_jws import PyJWS, _jws_global_obj
11+
from .api_jws import PyJWS, _ALGORITHM_UNSET, _jws_global_obj
1212
from .exceptions import (
1313
DecodeError,
1414
ExpiredSignatureError,
@@ -91,7 +91,7 @@ def encode(
9191
self,
9292
payload: dict[str, Any],
9393
key: AllowedPrivateKeyTypes,
94-
algorithm: str | None = "HS256",
94+
algorithm: str | None = _ALGORITHM_UNSET, # type: ignore[assignment]
9595
headers: dict[str, Any] | None = None,
9696
json_encoder: type[json.JSONEncoder] | None = None,
9797
sort_headers: bool = True,

tests/test_api_jws.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,27 @@ def test_encode_with_jwk(self, jws: PyJWS, payload: bytes) -> None:
261261
),
262262
}
263263

264+
def test_encode_with_jwk_uses_key_algorithm(
265+
self, jws: PyJWS, payload: bytes
266+
) -> None:
267+
"""Test that encoding with a PyJWK key uses the key's algorithm
268+
when no algorithm is explicitly specified. Regression test for #1147."""
269+
jwk = PyJWK(
270+
{
271+
"kty": "oct",
272+
"alg": "HS384",
273+
"k": "c2VjcmV0", # "secret"
274+
}
275+
)
276+
# Should use HS384 from the key, not default to HS256
277+
msg = jws.encode(payload, key=jwk)
278+
header = jws.get_unverified_header(msg)
279+
assert header["alg"] == "HS384"
280+
281+
# Should also be decodable with the same key
282+
decoded = jws.decode(msg, key=jwk)
283+
assert decoded == payload
284+
264285
def test_decode_algorithm_param_should_be_case_sensitive(self, jws: PyJWS) -> None:
265286
example_jws = (
266287
"eyJhbGciOiJoczI1NiIsInR5cCI6IkpXVCJ9" # alg = hs256

tests/test_api_jwt.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88

99
from jwt.types import Options
10+
from jwt.api_jwk import PyJWK
1011
from jwt.api_jwt import PyJWT
1112
from jwt.exceptions import (
1213
DecodeError,
@@ -45,6 +46,22 @@ def test_jwt_with_options(self) -> None:
4546
# assert that verify_signature is respected unless verify_exp is overridden
4647
assert jwt.options["verify_exp"] is False
4748

49+
def test_encode_with_jwk_uses_key_algorithm(self, jwt: PyJWT) -> None:
50+
"""Test that encoding with a PyJWK key uses the key's algorithm
51+
when no algorithm is explicitly specified. Regression test for #1147."""
52+
jwk = PyJWK(
53+
{
54+
"kty": "oct",
55+
"alg": "HS384",
56+
"k": "c2VjcmV0", # "secret"
57+
}
58+
)
59+
payload = {"hello": "world"}
60+
# Should use HS384 from the key, not default to HS256
61+
token = jwt.encode(payload, jwk)
62+
header = jwt.decode_complete(token, jwk, algorithms=["HS384"])["header"]
63+
assert header["alg"] == "HS384"
64+
4865
def test_decodes_valid_jwt(self, jwt: PyJWT) -> None:
4966
example_payload = {"hello": "world"}
5067
example_secret = "secret"

0 commit comments

Comments
 (0)