Skip to content

Commit 8b65380

Browse files
committed
address review
1 parent b2f1e5a commit 8b65380

File tree

1 file changed

+53
-29
lines changed

1 file changed

+53
-29
lines changed

src/Access/TokenProcessorsJWT.cpp

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
#include <Common/logger_useful.h>
66
#include <Poco/String.h>
77
#include <openssl/bio.h>
8-
#include <openssl/bn.h>
9-
#include <openssl/ec.h>
8+
#include <openssl/core_names.h>
109
#include <openssl/evp.h>
10+
#include <openssl/param_build.h>
1111
#include <openssl/pem.h>
12+
#include <cstring>
1213

1314
namespace DB {
1415

@@ -171,37 +172,54 @@ std::string create_public_key_from_ec_components(const std::string & x, const st
171172
auto decoded_x = decode_base64url(x);
172173
auto decoded_y = decode_base64url(y);
173174

174-
BIGNUM * raw_x_bn = BN_bin2bn(reinterpret_cast<const unsigned char *>(decoded_x.data()), static_cast<int>(decoded_x.size()), nullptr);
175-
BIGNUM * raw_y_bn = BN_bin2bn(reinterpret_cast<const unsigned char *>(decoded_y.data()), static_cast<int>(decoded_y.size()), nullptr);
176-
std::unique_ptr<BIGNUM, decltype(&BN_free)> x_bn(raw_x_bn, BN_free);
177-
std::unique_ptr<BIGNUM, decltype(&BN_free)> y_bn(raw_y_bn, BN_free);
178-
if (!x_bn || !y_bn)
179-
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to parse EC key coordinates");
175+
size_t coordinate_size = 0;
176+
if (curve_nid == NID_X9_62_prime256v1)
177+
coordinate_size = 32;
178+
else if (curve_nid == NID_secp384r1)
179+
coordinate_size = 48;
180+
else if (curve_nid == NID_secp521r1)
181+
coordinate_size = 66;
182+
else
183+
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: unsupported EC curve");
184+
185+
if (decoded_x.size() > coordinate_size || decoded_y.size() > coordinate_size)
186+
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: invalid EC key coordinates length");
187+
188+
std::vector<unsigned char> public_key_octets(1 + 2 * coordinate_size, 0);
189+
public_key_octets[0] = 0x04; // Uncompressed point format.
190+
std::memcpy(public_key_octets.data() + 1 + (coordinate_size - decoded_x.size()), decoded_x.data(), decoded_x.size());
191+
std::memcpy(public_key_octets.data() + 1 + coordinate_size + (coordinate_size - decoded_y.size()), decoded_y.data(), decoded_y.size());
192+
193+
const char * group_name = OBJ_nid2sn(curve_nid);
194+
if (!group_name)
195+
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: unsupported EC curve");
180196

181-
std::unique_ptr<EC_KEY, decltype(&EC_KEY_free)> ec_key(EC_KEY_new_by_curve_name(curve_nid), EC_KEY_free);
182-
if (!ec_key)
183-
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to construct EC key");
197+
std::unique_ptr<OSSL_PARAM_BLD, decltype(&OSSL_PARAM_BLD_free)> params_bld(OSSL_PARAM_BLD_new(), OSSL_PARAM_BLD_free);
198+
if (!params_bld)
199+
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to allocate OpenSSL parameter builder");
184200

185-
const EC_GROUP * group = EC_KEY_get0_group(ec_key.get());
186-
if (!group)
187-
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to read EC group");
201+
if (OSSL_PARAM_BLD_push_utf8_string(params_bld.get(), OSSL_PKEY_PARAM_GROUP_NAME, group_name, 0) != 1)
202+
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to set EC group parameter");
188203

189-
std::unique_ptr<EC_POINT, decltype(&EC_POINT_free)> point(EC_POINT_new(group), EC_POINT_free);
190-
if (!point)
191-
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to create EC point");
204+
if (OSSL_PARAM_BLD_push_octet_string(params_bld.get(), OSSL_PKEY_PARAM_PUB_KEY, public_key_octets.data(), public_key_octets.size()) != 1)
205+
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to set EC public key parameter");
192206

193-
if (EC_POINT_set_affine_coordinates(group, point.get(), x_bn.get(), y_bn.get(), nullptr) != 1)
194-
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: invalid EC public key point");
207+
std::unique_ptr<OSSL_PARAM, decltype(&OSSL_PARAM_free)> params(OSSL_PARAM_BLD_to_param(params_bld.get()), OSSL_PARAM_free);
208+
if (!params)
209+
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to build OpenSSL parameters");
195210

196-
if (EC_KEY_set_public_key(ec_key.get(), point.get()) != 1)
197-
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to set EC public key");
211+
std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)> key_ctx(EVP_PKEY_CTX_new_from_name(nullptr, "EC", nullptr), EVP_PKEY_CTX_free);
212+
if (!key_ctx)
213+
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to create EVP key context");
198214

199-
std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> evp_key(EVP_PKEY_new(), EVP_PKEY_free);
200-
if (!evp_key)
201-
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to allocate EVP key");
215+
if (EVP_PKEY_fromdata_init(key_ctx.get()) <= 0)
216+
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to initialize EVP key import");
202217

203-
if (EVP_PKEY_assign_EC_KEY(evp_key.get(), ec_key.release()) != 1)
204-
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to assign EC key");
218+
EVP_PKEY * raw_evp_key = nullptr;
219+
if (EVP_PKEY_fromdata(key_ctx.get(), &raw_evp_key, EVP_PKEY_PUBLIC_KEY, params.get()) <= 0)
220+
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to import EC public key");
221+
222+
std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> evp_key(raw_evp_key, EVP_PKEY_free);
205223

206224
std::unique_ptr<BIO, decltype(&BIO_free)> bio(BIO_new(BIO_s_mem()), BIO_free);
207225
if (!bio)
@@ -451,8 +469,12 @@ bool JwksJwtProcessor::resolveAndValidate(TokenCredentials & credentials) const
451469

452470
if (public_key.empty())
453471
{
454-
if (jwk.has_jwk_claim("x") && jwk.has_jwk_claim("y"))
472+
const auto key_type = jwk.get_key_type();
473+
if (key_type == "EC")
455474
{
475+
if (!(jwk.has_jwk_claim("x") && jwk.has_jwk_claim("y")))
476+
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK: missing 'x'/'y' claims for EC key type", processor_name);
477+
456478
int curve_nid = NID_undef;
457479
std::optional<String> expected_crv;
458480
if (algo == "es256")
@@ -486,15 +508,17 @@ bool JwksJwtProcessor::resolveAndValidate(TokenCredentials & credentials) const
486508
const auto y = jwk.get_jwk_claim("y").as_string();
487509
public_key = create_public_key_from_ec_components(x, y, curve_nid);
488510
}
489-
else
511+
else if (key_type == "RSA")
490512
{
491513
if (!(jwk.has_jwk_claim("n") && jwk.has_jwk_claim("e")))
492-
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK: neither 'x'/'y' nor 'n'/'e' found", processor_name);
514+
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK: missing 'n'/'e' claims for RSA key type", processor_name);
493515
LOG_TRACE(getLogger("TokenAuthentication"), "{}: `issuer` or `x5c` not present, verifying {} with RSA components", processor_name, username);
494516
const auto modulus = jwk.get_jwk_claim("n").as_string();
495517
const auto exponent = jwk.get_jwk_claim("e").as_string();
496518
public_key = jwt::helper::create_public_key_from_rsa_components(modulus, exponent);
497519
}
520+
else
521+
throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK key type '{}'", processor_name, key_type);
498522
}
499523

500524
if (jwk.has_algorithm() && Poco::toLower(jwk.get_algorithm()) != algo)

0 commit comments

Comments
 (0)