|
5 | 5 | #include <Common/logger_useful.h> |
6 | 6 | #include <Poco/String.h> |
7 | 7 | #include <openssl/bio.h> |
8 | | -#include <openssl/bn.h> |
9 | | -#include <openssl/ec.h> |
| 8 | +#include <openssl/core_names.h> |
10 | 9 | #include <openssl/evp.h> |
| 10 | +#include <openssl/param_build.h> |
11 | 11 | #include <openssl/pem.h> |
| 12 | +#include <cstring> |
12 | 13 |
|
13 | 14 | namespace DB { |
14 | 15 |
|
@@ -171,37 +172,54 @@ std::string create_public_key_from_ec_components(const std::string & x, const st |
171 | 172 | auto decoded_x = decode_base64url(x); |
172 | 173 | auto decoded_y = decode_base64url(y); |
173 | 174 |
|
174 | | - BIGNUM * raw_x_bn = BN_bin2bn(reinterpret_cast<const unsigned char *>(decoded_x.data()), static_cast<int>(decoded_x.size()), nullptr); |
175 | | - BIGNUM * raw_y_bn = BN_bin2bn(reinterpret_cast<const unsigned char *>(decoded_y.data()), static_cast<int>(decoded_y.size()), nullptr); |
176 | | - std::unique_ptr<BIGNUM, decltype(&BN_free)> x_bn(raw_x_bn, BN_free); |
177 | | - std::unique_ptr<BIGNUM, decltype(&BN_free)> y_bn(raw_y_bn, BN_free); |
178 | | - if (!x_bn || !y_bn) |
179 | | - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to parse EC key coordinates"); |
| 175 | + size_t coordinate_size = 0; |
| 176 | + if (curve_nid == NID_X9_62_prime256v1) |
| 177 | + coordinate_size = 32; |
| 178 | + else if (curve_nid == NID_secp384r1) |
| 179 | + coordinate_size = 48; |
| 180 | + else if (curve_nid == NID_secp521r1) |
| 181 | + coordinate_size = 66; |
| 182 | + else |
| 183 | + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: unsupported EC curve"); |
| 184 | + |
| 185 | + if (decoded_x.size() > coordinate_size || decoded_y.size() > coordinate_size) |
| 186 | + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: invalid EC key coordinates length"); |
| 187 | + |
| 188 | + std::vector<unsigned char> public_key_octets(1 + 2 * coordinate_size, 0); |
| 189 | + public_key_octets[0] = 0x04; // Uncompressed point format. |
| 190 | + std::memcpy(public_key_octets.data() + 1 + (coordinate_size - decoded_x.size()), decoded_x.data(), decoded_x.size()); |
| 191 | + std::memcpy(public_key_octets.data() + 1 + coordinate_size + (coordinate_size - decoded_y.size()), decoded_y.data(), decoded_y.size()); |
| 192 | + |
| 193 | + const char * group_name = OBJ_nid2sn(curve_nid); |
| 194 | + if (!group_name) |
| 195 | + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: unsupported EC curve"); |
180 | 196 |
|
181 | | - std::unique_ptr<EC_KEY, decltype(&EC_KEY_free)> ec_key(EC_KEY_new_by_curve_name(curve_nid), EC_KEY_free); |
182 | | - if (!ec_key) |
183 | | - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to construct EC key"); |
| 197 | + std::unique_ptr<OSSL_PARAM_BLD, decltype(&OSSL_PARAM_BLD_free)> params_bld(OSSL_PARAM_BLD_new(), OSSL_PARAM_BLD_free); |
| 198 | + if (!params_bld) |
| 199 | + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to allocate OpenSSL parameter builder"); |
184 | 200 |
|
185 | | - const EC_GROUP * group = EC_KEY_get0_group(ec_key.get()); |
186 | | - if (!group) |
187 | | - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to read EC group"); |
| 201 | + if (OSSL_PARAM_BLD_push_utf8_string(params_bld.get(), OSSL_PKEY_PARAM_GROUP_NAME, group_name, 0) != 1) |
| 202 | + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to set EC group parameter"); |
188 | 203 |
|
189 | | - std::unique_ptr<EC_POINT, decltype(&EC_POINT_free)> point(EC_POINT_new(group), EC_POINT_free); |
190 | | - if (!point) |
191 | | - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to create EC point"); |
| 204 | + if (OSSL_PARAM_BLD_push_octet_string(params_bld.get(), OSSL_PKEY_PARAM_PUB_KEY, public_key_octets.data(), public_key_octets.size()) != 1) |
| 205 | + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to set EC public key parameter"); |
192 | 206 |
|
193 | | - if (EC_POINT_set_affine_coordinates(group, point.get(), x_bn.get(), y_bn.get(), nullptr) != 1) |
194 | | - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: invalid EC public key point"); |
| 207 | + std::unique_ptr<OSSL_PARAM, decltype(&OSSL_PARAM_free)> params(OSSL_PARAM_BLD_to_param(params_bld.get()), OSSL_PARAM_free); |
| 208 | + if (!params) |
| 209 | + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to build OpenSSL parameters"); |
195 | 210 |
|
196 | | - if (EC_KEY_set_public_key(ec_key.get(), point.get()) != 1) |
197 | | - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to set EC public key"); |
| 211 | + std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)> key_ctx(EVP_PKEY_CTX_new_from_name(nullptr, "EC", nullptr), EVP_PKEY_CTX_free); |
| 212 | + if (!key_ctx) |
| 213 | + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to create EVP key context"); |
198 | 214 |
|
199 | | - std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> evp_key(EVP_PKEY_new(), EVP_PKEY_free); |
200 | | - if (!evp_key) |
201 | | - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to allocate EVP key"); |
| 215 | + if (EVP_PKEY_fromdata_init(key_ctx.get()) <= 0) |
| 216 | + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to initialize EVP key import"); |
202 | 217 |
|
203 | | - if (EVP_PKEY_assign_EC_KEY(evp_key.get(), ec_key.release()) != 1) |
204 | | - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to assign EC key"); |
| 218 | + EVP_PKEY * raw_evp_key = nullptr; |
| 219 | + if (EVP_PKEY_fromdata(key_ctx.get(), &raw_evp_key, EVP_PKEY_PUBLIC_KEY, params.get()) <= 0) |
| 220 | + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "JWT cannot be validated: failed to import EC public key"); |
| 221 | + |
| 222 | + std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> evp_key(raw_evp_key, EVP_PKEY_free); |
205 | 223 |
|
206 | 224 | std::unique_ptr<BIO, decltype(&BIO_free)> bio(BIO_new(BIO_s_mem()), BIO_free); |
207 | 225 | if (!bio) |
@@ -451,8 +469,12 @@ bool JwksJwtProcessor::resolveAndValidate(TokenCredentials & credentials) const |
451 | 469 |
|
452 | 470 | if (public_key.empty()) |
453 | 471 | { |
454 | | - if (jwk.has_jwk_claim("x") && jwk.has_jwk_claim("y")) |
| 472 | + const auto key_type = jwk.get_key_type(); |
| 473 | + if (key_type == "EC") |
455 | 474 | { |
| 475 | + if (!(jwk.has_jwk_claim("x") && jwk.has_jwk_claim("y"))) |
| 476 | + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK: missing 'x'/'y' claims for EC key type", processor_name); |
| 477 | + |
456 | 478 | int curve_nid = NID_undef; |
457 | 479 | std::optional<String> expected_crv; |
458 | 480 | if (algo == "es256") |
@@ -486,15 +508,17 @@ bool JwksJwtProcessor::resolveAndValidate(TokenCredentials & credentials) const |
486 | 508 | const auto y = jwk.get_jwk_claim("y").as_string(); |
487 | 509 | public_key = create_public_key_from_ec_components(x, y, curve_nid); |
488 | 510 | } |
489 | | - else |
| 511 | + else if (key_type == "RSA") |
490 | 512 | { |
491 | 513 | if (!(jwk.has_jwk_claim("n") && jwk.has_jwk_claim("e"))) |
492 | | - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK: neither 'x'/'y' nor 'n'/'e' found", processor_name); |
| 514 | + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK: missing 'n'/'e' claims for RSA key type", processor_name); |
493 | 515 | LOG_TRACE(getLogger("TokenAuthentication"), "{}: `issuer` or `x5c` not present, verifying {} with RSA components", processor_name, username); |
494 | 516 | const auto modulus = jwk.get_jwk_claim("n").as_string(); |
495 | 517 | const auto exponent = jwk.get_jwk_claim("e").as_string(); |
496 | 518 | public_key = jwt::helper::create_public_key_from_rsa_components(modulus, exponent); |
497 | 519 | } |
| 520 | + else |
| 521 | + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "{}: invalid JWK key type '{}'", processor_name, key_type); |
498 | 522 | } |
499 | 523 |
|
500 | 524 | if (jwk.has_algorithm() && Poco::toLower(jwk.get_algorithm()) != algo) |
|
0 commit comments