/* SPDX-License-Identifier: BSL-1.0 OR BSD-3-Clause */ #ifndef MPT_CRYPTO_JWK_HPP #define MPT_CRYPTO_JWK_HPP #include "mpt/base/alloc.hpp" #include "mpt/base/memory.hpp" #include "mpt/base/namespace.hpp" #include "mpt/base/saturate_cast.hpp" #include "mpt/base/span.hpp" #include "mpt/binary/base64url.hpp" #include "mpt/crypto/exception.hpp" #include "mpt/crypto/hash.hpp" #include "mpt/detect/nlohmann_json.hpp" #include "mpt/json/json.hpp" #include "mpt/out_of_memory/out_of_memory.hpp" #include "mpt/string/types.hpp" #include "mpt/string/utility.hpp" #include "mpt/string_transcode/transcode.hpp" #include #include #include #include #include #include #include #include #if MPT_OS_WINDOWS #include // must be before wincrypt.h for clang-cl #include #include // must be before ncrypt.h #include #endif // MPT_OS_WINDOWS namespace mpt { inline namespace MPT_INLINE_NS { namespace crypto { #if MPT_OS_WINDOWS && MPT_DETECTED_NLOHMANN_JSON class keystore { public: enum class domain { system = 1, user = 2, }; private: NCRYPT_PROV_HANDLE hProv = NULL; domain ProvDomain = domain::user; private: void cleanup() { if (hProv) { NCryptFreeObject(hProv); hProv = NULL; } } public: keystore(domain d) : ProvDomain(d) { try { CheckSECURITY_STATUS(NCryptOpenStorageProvider(&hProv, MS_KEY_STORAGE_PROVIDER, 0), "NCryptOpenStorageProvider"); } catch (...) { cleanup(); throw; } } ~keystore() { return; } operator NCRYPT_PROV_HANDLE() { return hProv; } keystore::domain store_domain() const { return ProvDomain; } }; namespace asymmetric { class signature_verification_failed : public std::runtime_error { public: signature_verification_failed() : std::runtime_error("Signature Verification failed.") { return; } }; inline std::vector jws_get_keynames(const mpt::ustring & jws_) { std::vector result; nlohmann::json jws = nlohmann::json::parse(mpt::transcode(mpt::common_encoding::utf8, jws_)); for (const auto & s : jws["signatures"]) { result.push_back(s["header"]["kid"]); } return result; } struct RSASSA_PSS_SHA512_traits { using hash_type = mpt::crypto::hash::SHA512; static constexpr const char * jwk_alg = "PS512"; }; template class rsassa_pss { public: using hash_type = typename Traits::hash_type; static constexpr const char * jwk_alg = Traits::jwk_alg; struct public_key_data { mpt::ustring name; uint32 length = 0; std::vector public_exp; std::vector modulus; std::vector as_cng_blob() const { BCRYPT_RSAKEY_BLOB rsakey_blob{}; rsakey_blob.Magic = BCRYPT_RSAPUBLIC_MAGIC; rsakey_blob.BitLength = length; rsakey_blob.cbPublicExp = mpt::saturate_cast(public_exp.size()); rsakey_blob.cbModulus = mpt::saturate_cast(modulus.size()); std::vector result(sizeof(BCRYPT_RSAKEY_BLOB) + public_exp.size() + modulus.size()); std::memcpy(result.data(), &rsakey_blob, sizeof(BCRYPT_RSAKEY_BLOB)); std::memcpy(result.data() + sizeof(BCRYPT_RSAKEY_BLOB), public_exp.data(), public_exp.size()); std::memcpy(result.data() + sizeof(BCRYPT_RSAKEY_BLOB) + public_exp.size(), modulus.data(), modulus.size()); return result; } mpt::ustring as_jwk() const { nlohmann::json json = nlohmann::json::object(); json["kid"] = name; json["kty"] = "RSA"; json["alg"] = jwk_alg; json["use"] = "sig"; json["e"] = mpt::encode_base64url(mpt::as_span(public_exp)); json["n"] = mpt::encode_base64url(mpt::as_span(modulus)); return mpt::transcode(mpt::common_encoding::utf8, json.dump()); } static public_key_data from_jwk(const mpt::ustring & jwk) { public_key_data result; try { nlohmann::json json = nlohmann::json::parse(mpt::transcode(mpt::common_encoding::utf8, jwk)); if (json["kty"] != "RSA") { throw std::runtime_error("Cannot parse RSA public key JWK."); } if (json["alg"] != jwk_alg) { throw std::runtime_error("Cannot parse RSA public key JWK."); } if (json["use"] != "sig") { throw std::runtime_error("Cannot parse RSA public key JWK."); } result.name = json["kid"].get(); result.public_exp = mpt::decode_base64url(json["e"]); result.modulus = mpt::decode_base64url(json["n"]); result.length = mpt::saturate_cast(result.modulus.size() * 8); } catch (mpt::out_of_memory e) { mpt::rethrow_out_of_memory(e); } catch (...) { throw std::runtime_error("Cannot parse RSA public key JWK."); } return result; } static public_key_data from_cng_blob(const mpt::ustring & name, const std::vector & blob) { public_key_data result; BCRYPT_RSAKEY_BLOB rsakey_blob{}; if (blob.size() < sizeof(BCRYPT_RSAKEY_BLOB)) { throw std::runtime_error("Cannot parse RSA public key blob."); } std::memcpy(&rsakey_blob, blob.data(), sizeof(BCRYPT_RSAKEY_BLOB)); if (rsakey_blob.Magic != BCRYPT_RSAPUBLIC_MAGIC) { throw std::runtime_error("Cannot parse RSA public key blob."); } if (blob.size() != sizeof(BCRYPT_RSAKEY_BLOB) + rsakey_blob.cbPublicExp + rsakey_blob.cbModulus) { throw std::runtime_error("Cannot parse RSA public key blob."); } result.name = name; result.length = rsakey_blob.BitLength; result.public_exp = std::vector(blob.data() + sizeof(BCRYPT_RSAKEY_BLOB), blob.data() + sizeof(BCRYPT_RSAKEY_BLOB) + rsakey_blob.cbPublicExp); result.modulus = std::vector(blob.data() + sizeof(BCRYPT_RSAKEY_BLOB) + rsakey_blob.cbPublicExp, blob.data() + sizeof(BCRYPT_RSAKEY_BLOB) + rsakey_blob.cbPublicExp + rsakey_blob.cbModulus); return result; } }; static std::vector parse_jwk_set(const mpt::ustring & jwk_set_) { std::vector result; nlohmann::json jwk_set = nlohmann::json::parse(mpt::transcode(mpt::common_encoding::utf8, jwk_set_)); for (const auto & k : jwk_set["keys"]) { try { result.push_back(public_key_data::from_jwk(mpt::transcode(mpt::common_encoding::utf8, k.dump()))); } catch (...) { // nothing } } return result; } class public_key { private: mpt::ustring name; BCRYPT_ALG_HANDLE hSignAlg = NULL; BCRYPT_KEY_HANDLE hKey = NULL; private: void cleanup() { if (hKey) { BCryptDestroyKey(hKey); hKey = NULL; } if (hSignAlg) { BCryptCloseAlgorithmProvider(hSignAlg, 0); hSignAlg = NULL; } } public: public_key(const public_key_data & data) { try { name = data.name; CheckNTSTATUS(BCryptOpenAlgorithmProvider(&hSignAlg, BCRYPT_RSA_ALGORITHM, NULL, 0), "BCryptOpenAlgorithmProvider"); std::vector blob = data.as_cng_blob(); CheckNTSTATUS(BCryptImportKeyPair(hSignAlg, NULL, BCRYPT_RSAPUBLIC_BLOB, &hKey, mpt::byte_cast(blob.data()), mpt::saturate_cast(blob.size()), 0), "BCryptImportKeyPair"); } catch (...) { cleanup(); throw; } } public_key(const public_key & other) : public_key(other.get_public_key_data()) { return; } public_key & operator=(const public_key & other) { if (&other == this) { return *this; } public_key copy(other); { using std::swap; swap(copy.name, name); swap(copy.hSignAlg, hSignAlg); swap(copy.hKey, hKey); } return *this; } ~public_key() { cleanup(); } mpt::ustring get_name() const { return name; } public_key_data get_public_key_data() const { DWORD bytes = 0; CheckNTSTATUS(BCryptExportKey(hKey, NULL, BCRYPT_RSAPUBLIC_BLOB, NULL, 0, &bytes, 0), "BCryptExportKey"); std::vector blob(bytes); CheckNTSTATUS(BCryptExportKey(hKey, NULL, BCRYPT_RSAPUBLIC_BLOB, mpt::byte_cast(blob.data()), mpt::saturate_cast(blob.size()), &bytes, 0), "BCryptExportKey"); return public_key_data::from_cng_blob(name, blob); } void verify_hash(typename hash_type::result_type hash, std::vector signature) { BCRYPT_PSS_PADDING_INFO paddinginfo; paddinginfo.pszAlgId = hash_type::traits::bcrypt_name; paddinginfo.cbSalt = mpt::saturate_cast(hash_type::traits::output_bytes); NTSTATUS result = BCryptVerifySignature(hKey, &paddinginfo, mpt::byte_cast(hash.data()), mpt::saturate_cast(hash.size()), mpt::byte_cast(signature.data()), mpt::saturate_cast(signature.size()), BCRYPT_PAD_PSS); if (result == 0x00000000 /*STATUS_SUCCESS*/) { return; } if (result == 0xC000A000 /*STATUS_INVALID_SIGNATURE*/) { throw signature_verification_failed(); } CheckNTSTATUS(result, "BCryptVerifySignature"); throw signature_verification_failed(); } void verify(mpt::const_byte_span payload, const std::vector & signature) { verify_hash(hash_type().process(payload).result(), signature); } std::vector jws_verify(const mpt::ustring & jws_) { nlohmann::json jws = nlohmann::json::parse(mpt::transcode(mpt::common_encoding::utf8, jws_)); std::vector payload = mpt::decode_base64url(jws["payload"]); nlohmann::json jsignature = nlohmann::json::object(); bool sigfound = false; for (const auto & s : jws["signatures"]) { if (s["header"]["kid"] == mpt::transcode(mpt::common_encoding::utf8, name)) { jsignature = s; sigfound = true; } } if (!sigfound) { throw signature_verification_failed(); } std::vector protectedheaderraw = mpt::decode_base64url(jsignature["protected"]); std::vector signature = mpt::decode_base64url(jsignature["signature"]); nlohmann::json header = nlohmann::json::parse(mpt::buffer_cast(protectedheaderraw)); if (header["typ"] != "JWT") { throw signature_verification_failed(); } if (header["alg"] != jwk_alg) { throw signature_verification_failed(); } verify_hash(hash_type().process(mpt::byte_cast(mpt::as_span(mpt::transcode(mpt::common_encoding::utf8, mpt::encode_base64url(mpt::as_span(protectedheaderraw)) + MPT_USTRING(".") + mpt::encode_base64url(mpt::as_span(payload)))))).result(), signature); return payload; } std::vector jws_compact_verify(const mpt::ustring & jws) { std::vector parts = mpt::split(jws, MPT_USTRING(".")); if (parts.size() != 3) { throw signature_verification_failed(); } std::vector protectedheaderraw = mpt::decode_base64url(parts[0]); std::vector payload = mpt::decode_base64url(parts[1]); std::vector signature = mpt::decode_base64url(parts[2]); nlohmann::json header = nlohmann::json::parse(mpt::buffer_cast(protectedheaderraw)); if (header["typ"] != "JWT") { throw signature_verification_failed(); } if (header["alg"] != jwk_alg) { throw signature_verification_failed(); } verify_hash(hash_type().process(mpt::byte_cast(mpt::as_span(mpt::transcode(mpt::common_encoding::utf8, mpt::encode_base64url(mpt::as_span(protectedheaderraw)) + MPT_USTRING(".") + mpt::encode_base64url(mpt::as_span(payload)))))).result(), signature); return payload; } }; static inline void jws_verify_at_least_one(std::vector & keys, const std::vector & expectedPayload, const mpt::ustring & signature) { std::vector keynames = mpt::crypto::asymmetric::jws_get_keynames(signature); bool sigchecked = false; for (const auto & keyname : keynames) { for (auto & key : keys) { if (key.get_name() == keyname) { if (expectedPayload != key.jws_verify(signature)) { throw mpt::crypto::asymmetric::signature_verification_failed(); } sigchecked = true; } } } if (!sigchecked) { throw mpt::crypto::asymmetric::signature_verification_failed(); } } static inline std::vector jws_verify_at_least_one(std::vector & keys, const mpt::ustring & signature) { std::vector keynames = mpt::crypto::asymmetric::jws_get_keynames(signature); for (const auto & keyname : keynames) { for (auto & key : keys) { if (key.get_name() == keyname) { return key.jws_verify(signature); } } } throw mpt::crypto::asymmetric::signature_verification_failed(); } class managed_private_key { private: mpt::ustring name; NCRYPT_KEY_HANDLE hKey = NULL; private: void cleanup() { if (hKey) { NCryptFreeObject(hKey); hKey = NULL; } } public: managed_private_key() = delete; managed_private_key(const managed_private_key &) = delete; managed_private_key & operator=(const managed_private_key &) = delete; managed_private_key(keystore & keystore) { try { CheckSECURITY_STATUS(NCryptCreatePersistedKey(keystore, &hKey, BCRYPT_RSA_ALGORITHM, NULL, 0, 0), "NCryptCreatePersistedKey"); } catch (...) { cleanup(); throw; } } managed_private_key(keystore & keystore, const mpt::ustring & name_) : name(name_) { try { SECURITY_STATUS openKeyStatus = NCryptOpenKey(keystore, &hKey, mpt::transcode(name).c_str(), 0, (keystore.store_domain() == keystore::domain::system ? NCRYPT_MACHINE_KEY_FLAG : 0)); if (openKeyStatus == NTE_BAD_KEYSET) { CheckSECURITY_STATUS(NCryptCreatePersistedKey(keystore, &hKey, BCRYPT_RSA_ALGORITHM, mpt::transcode(name).c_str(), 0, (keystore.store_domain() == keystore::domain::system ? NCRYPT_MACHINE_KEY_FLAG : 0)), "NCryptCreatePersistedKey"); DWORD length = mpt::saturate_cast(keysize); CheckSECURITY_STATUS(NCryptSetProperty(hKey, NCRYPT_LENGTH_PROPERTY, (PBYTE)&length, mpt::saturate_cast(sizeof(DWORD)), 0), "NCryptSetProperty"); CheckSECURITY_STATUS(NCryptFinalizeKey(hKey, 0), "NCryptFinalizeKey"); } else { CheckSECURITY_STATUS(openKeyStatus, "NCryptOpenKey"); } } catch (...) { cleanup(); throw; } } ~managed_private_key() { cleanup(); } void destroy() { CheckSECURITY_STATUS(NCryptDeleteKey(hKey, 0), "NCryptDeleteKey"); name = mpt::ustring(); hKey = NULL; } public: public_key_data get_public_key_data() const { DWORD bytes = 0; CheckSECURITY_STATUS(NCryptExportKey(hKey, NULL, BCRYPT_RSAPUBLIC_BLOB, NULL, NULL, 0, &bytes, 0), "NCryptExportKey"); std::vector blob(bytes); CheckSECURITY_STATUS(NCryptExportKey(hKey, NULL, BCRYPT_RSAPUBLIC_BLOB, NULL, mpt::byte_cast(blob.data()), mpt::saturate_cast(blob.size()), &bytes, 0), "NCryptExportKey"); return public_key_data::from_cng_blob(name, blob); } std::vector sign_hash(typename hash_type::result_type hash) { BCRYPT_PSS_PADDING_INFO paddinginfo; paddinginfo.pszAlgId = hash_type::traits::bcrypt_name; paddinginfo.cbSalt = mpt::saturate_cast(hash_type::traits::output_bytes); DWORD bytes = 0; CheckSECURITY_STATUS(NCryptSignHash(hKey, &paddinginfo, mpt::byte_cast(hash.data()), mpt::saturate_cast(hash.size()), NULL, 0, &bytes, BCRYPT_PAD_PSS), "NCryptSignHash"); std::vector result(bytes); CheckSECURITY_STATUS(NCryptSignHash(hKey, &paddinginfo, mpt::byte_cast(hash.data()), mpt::saturate_cast(hash.size()), mpt::byte_cast(result.data()), mpt::saturate_cast(result.size()), &bytes, BCRYPT_PAD_PSS), "NCryptSignHash"); return result; } std::vector sign(mpt::const_byte_span payload) { return sign_hash(hash_type().process(payload).result()); } mpt::ustring jws_compact_sign(mpt::const_byte_span payload) { nlohmann::json protectedheader = nlohmann::json::object(); protectedheader["typ"] = "JWT"; protectedheader["alg"] = jwk_alg; std::string protectedheaderstring = protectedheader.dump(); std::vector signature = sign_hash(hash_type().process(mpt::byte_cast(mpt::as_span(mpt::transcode(mpt::common_encoding::utf8, mpt::encode_base64url(mpt::as_span(protectedheaderstring)) + MPT_USTRING(".") + mpt::encode_base64url(payload))))).result()); return mpt::encode_base64url(mpt::as_span(protectedheaderstring)) + MPT_USTRING(".") + mpt::encode_base64url(payload) + MPT_USTRING(".") + mpt::encode_base64url(mpt::as_span(signature)); } mpt::ustring jws_sign(mpt::const_byte_span payload) { nlohmann::json protectedheader = nlohmann::json::object(); protectedheader["typ"] = "JWT"; protectedheader["alg"] = jwk_alg; std::string protectedheaderstring = protectedheader.dump(); nlohmann::json header = nlohmann::json::object(); header["kid"] = name; std::vector signature = sign_hash(hash_type().process(mpt::byte_cast(mpt::as_span(mpt::transcode(mpt::common_encoding::utf8, mpt::encode_base64url(mpt::as_span(protectedheaderstring)) + MPT_USTRING(".") + mpt::encode_base64url(payload))))).result()); nlohmann::json jws = nlohmann::json::object(); jws["payload"] = mpt::encode_base64url(payload); jws["signatures"] = nlohmann::json::array(); nlohmann::json jsignature = nlohmann::json::object(); jsignature["header"] = header; jsignature["protected"] = mpt::encode_base64url(mpt::as_span(protectedheaderstring)); jsignature["signature"] = mpt::encode_base64url(mpt::as_span(signature)); jws["signatures"].push_back(jsignature); return mpt::transcode(mpt::common_encoding::utf8, jws.dump()); } }; }; // class rsassa_pss } // namespace asymmetric #endif // MPT_OS_WINDOWS && MPT_DETECTED_NLOHMANN_JSON } // namespace crypto } // namespace MPT_INLINE_NS } // namespace mpt #endif // MPT_CRYPTO_JWK_HPP