diff --git a/ext/openssl/ossl_pkey.c b/ext/openssl/ossl_pkey.c index 6af2245f3..0b14d78fe 100644 --- a/ext/openssl/ossl_pkey.c +++ b/ext/openssl/ossl_pkey.c @@ -446,6 +446,186 @@ pkey_generate(int argc, VALUE *argv, VALUE self, int genparam) return ossl_pkey_new(gen_arg.pkey); } +#if OSSL_OPENSSL_PREREQ(3, 0, 0) +#include +#include + +struct pkey_from_data_alias { + char alias[10]; + char param_name[20]; +}; + +static const struct pkey_from_data_alias rsa_aliases[] = { + { "p", OSSL_PKEY_PARAM_RSA_FACTOR1 }, + { "q", OSSL_PKEY_PARAM_RSA_FACTOR2 }, + { "dmp1", OSSL_PKEY_PARAM_RSA_EXPONENT1 }, + { "dmq1", OSSL_PKEY_PARAM_RSA_EXPONENT2 }, + { "iqmp", OSSL_PKEY_PARAM_RSA_COEFFICIENT1 }, + { "", "" } +}; + +static const struct pkey_from_data_alias fcc_aliases[] = { + { "pub_key", OSSL_PKEY_PARAM_PUB_KEY }, + { "priv_key", OSSL_PKEY_PARAM_PRIV_KEY }, + { "", "" } +}; + +struct pkey_from_data_arg { + VALUE options; + OSSL_PARAM_BLD *param_bld; + const OSSL_PARAM *settable_params; + const struct pkey_from_data_alias *aliases; +}; + +static int +add_data_to_builder(VALUE key, VALUE value, VALUE arg) +{ + if(NIL_P(value)) + return ST_CONTINUE; + + if (SYMBOL_P(key)) + key = rb_sym2str(key); + + const char *key_ptr = StringValueCStr(key); + const struct pkey_from_data_arg *params = (const struct pkey_from_data_arg *) arg; + + for(int i = 0; strlen(params->aliases[i].alias) > 0; i++) { + if(strcmp(params->aliases[i].alias, key_ptr) == 0) { + key_ptr = params->aliases[i].param_name; + break; + } + } + + for (const OSSL_PARAM *settable_params = params->settable_params; settable_params->key != NULL; settable_params++) { + if(strcmp(settable_params->key, key_ptr) == 0) { + switch (settable_params->data_type) { + case OSSL_PARAM_INTEGER: + case OSSL_PARAM_UNSIGNED_INTEGER: + if(!OSSL_PARAM_BLD_push_BN(params->param_bld, key_ptr, GetBNPtr(value))) { + ossl_raise(ePKeyError, "OSSL_PARAM_BLD_push_BN"); + } + break; + case OSSL_PARAM_UTF8_STRING: + StringValue(value); + if(!OSSL_PARAM_BLD_push_utf8_string(params->param_bld, key_ptr, RSTRING_PTR(value), RSTRING_LENINT(value))) { + ossl_raise(ePKeyError, "OSSL_PARAM_BLD_push_utf8_string"); + } + break; + + case OSSL_PARAM_OCTET_STRING: + StringValue(value); + if(!OSSL_PARAM_BLD_push_octet_string(params->param_bld, key_ptr, RSTRING_PTR(value), RSTRING_LENINT(value))) { + ossl_raise(ePKeyError, "OSSL_PARAM_BLD_push_octet_string"); + } + break; + case OSSL_PARAM_UTF8_PTR: + case OSSL_PARAM_OCTET_PTR: + ossl_raise(ePKeyError, "Unsupported parameter \"%s\", handling of OSSL_PARAM_UTF8_PTR and OSSL_PARAM_OCTET_PTR not implemented", key_ptr); + break; + default: + ossl_raise(ePKeyError, "Unsupported parameter \"%s\"", key_ptr); + break; + } + + return ST_CONTINUE; + } + } + + VALUE supported_parameters = rb_ary_new(); + + for (const OSSL_PARAM *settable_params = params->settable_params; settable_params->key != NULL; settable_params++) { + rb_ary_push(supported_parameters, rb_str_new_cstr(settable_params->key)); + } + + for(int i = 0; strlen(params->aliases[i].alias) > 0; i++) { + rb_ary_push(supported_parameters, rb_str_new_cstr(params->aliases[i].alias)); + } + + ossl_raise(ePKeyError, "Invalid parameter \"%s\". Supported parameters: %"PRIsVALUE, key_ptr, rb_ary_join(supported_parameters, rb_str_new2(", "))); +} + +static VALUE +iterate_from_data_options_cb(VALUE value) +{ + struct pkey_from_data_arg *args = (void *)value; + + rb_hash_foreach(args->options, &add_data_to_builder, (VALUE) args); + + return Qnil; +} + +static VALUE +pkey_from_data(int argc, VALUE *argv, VALUE self) +{ + VALUE alg, options; + rb_scan_args(argc, argv, "11", &alg, &options); + + const char* algorithm = StringValueCStr(alg); + + EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new_from_name(NULL, algorithm, NULL); + + if (ctx == NULL) + ossl_raise(ePKeyError, "EVP_PKEY_CTX_new_from_name"); + + struct pkey_from_data_arg from_data_args = { 0 }; + + from_data_args.param_bld = OSSL_PARAM_BLD_new(); + from_data_args.options = options; + + if (from_data_args.param_bld == NULL) { + EVP_PKEY_CTX_free(ctx); + ossl_raise(ePKeyError, "OSSL_PARAM_BLD_new"); + } + + from_data_args.settable_params = EVP_PKEY_fromdata_settable(ctx, EVP_PKEY_KEYPAIR); + + if (from_data_args.settable_params == NULL) { + EVP_PKEY_CTX_free(ctx); + OSSL_PARAM_BLD_free(from_data_args.param_bld); + ossl_raise(ePKeyError, "EVP_PKEY_fromdata_settable"); + } + + if (strcmp("RSA", algorithm) == 0) + from_data_args.aliases = rsa_aliases; + else + from_data_args.aliases = fcc_aliases; + + int state; + rb_protect(iterate_from_data_options_cb, (VALUE) &from_data_args, &state); + + if(state) { + EVP_PKEY_CTX_free(ctx); + OSSL_PARAM_BLD_free(from_data_args.param_bld); + rb_jump_tag(state); + } + + OSSL_PARAM *params = OSSL_PARAM_BLD_to_param(from_data_args.param_bld); + OSSL_PARAM_BLD_free(from_data_args.param_bld); + + if (params == NULL) { + EVP_PKEY_CTX_free(ctx); + ossl_raise(ePKeyError, "OSSL_PARAM_BLD_to_param"); + } + + EVP_PKEY *pkey = NULL; + + if (EVP_PKEY_fromdata_init(ctx) <= 0) { + EVP_PKEY_CTX_free(ctx); + ossl_raise(ePKeyError, "EVP_PKEY_fromdata_init"); + } + + if (EVP_PKEY_fromdata(ctx, &pkey, EVP_PKEY_KEYPAIR, params) <= 0) { + EVP_PKEY_CTX_free(ctx); + ossl_raise(ePKeyError, "EVP_PKEY_fromdata"); + } + + EVP_PKEY_CTX_free(ctx); + + return ossl_pkey_new(pkey); +} + +#endif + /* * call-seq: * OpenSSL::PKey.generate_parameters(algo_name [, options]) -> pkey @@ -498,6 +678,30 @@ ossl_pkey_s_generate_key(int argc, VALUE *argv, VALUE self) return pkey_generate(argc, argv, self, 0); } +/* + * call-seq: + * OpenSSL::PKey.from_data(algo_name, parameters) -> pkey + * + * Generates a new key based on given key parameters. + * NOTE: Requires OpenSSL 3.0 or later. + * + * The first parameter is the type of the key to create, given as a String, for example RSA, DSA, EC etc. + * Second parameter is the parameters to be used for the key. + * + * For details algorithms and parameters see https://www.openssl.org/docs/man3.0/man3/EVP_PKEY_fromdata.html + * + * == Example + * pkey = OpenSSL::PKey.from_data("RSA", n: 3161751493, e: 65537, d: 2064855961) + * pkey.private? #=> true + * pkey.n #=> # + */ +#if OSSL_OPENSSL_PREREQ(3, 0, 0) +static VALUE +ossl_pkey_s_from_data(int argc, VALUE *argv, VALUE self) +{ + return pkey_from_data(argc, argv, self); +} +#endif /* * TODO: There is no convenient way to check the presence of public key * components on OpenSSL 3.0. But since keys are immutable on 3.0, pkeys without @@ -1751,6 +1955,9 @@ Init_ossl_pkey(void) rb_define_module_function(mPKey, "read", ossl_pkey_new_from_data, -1); rb_define_module_function(mPKey, "generate_parameters", ossl_pkey_s_generate_parameters, -1); rb_define_module_function(mPKey, "generate_key", ossl_pkey_s_generate_key, -1); +#if OSSL_OPENSSL_PREREQ(3, 0, 0) + rb_define_module_function(mPKey, "from_data", ossl_pkey_s_from_data, -1); +#endif #ifdef HAVE_EVP_PKEY_NEW_RAW_PRIVATE_KEY rb_define_module_function(mPKey, "new_raw_private_key", ossl_pkey_new_raw_private_key, 2); rb_define_module_function(mPKey, "new_raw_public_key", ossl_pkey_new_raw_public_key, 2); diff --git a/test/openssl/test_pkey.rb b/test/openssl/test_pkey.rb index 811d5103d..244a2ebdc 100644 --- a/test/openssl/test_pkey.rb +++ b/test/openssl/test_pkey.rb @@ -224,4 +224,207 @@ def test_to_text rsa = Fixtures.pkey("rsa1024") assert_include rsa.to_text, "publicExponent" end + + if openssl?(3, 0, 0) + def test_from_data_with_invalid_alg + assert_raise_with_message(OpenSSL::PKey::PKeyError, /^EVP_PKEY_CTX_new_from_name: unsupported/) do + OpenSSL::PKey.from_data("ASR", {}) + end + end + + def test_s_from_data_rsa_with_n_e_and_d_given_as_integers + new_key = OpenSSL::PKey.from_data("RSA", "n" => 3161751493, + "e" => 65537, + "d" => 2064855961) + + assert_instance_of OpenSSL::PKey::RSA, new_key + assert_equal true, new_key.private? + assert_equal OpenSSL::BN.new(3161751493), new_key.n + assert_equal OpenSSL::BN.new(65537), new_key.e + assert_equal OpenSSL::BN.new(2064855961), new_key.d + end + + def test_s_from_data_rsa_with_n_e_and_d_given_as_symbols + new_key = OpenSSL::PKey.from_data("RSA", n: OpenSSL::BN.new(3161751493), + e: OpenSSL::BN.new(65537), + d: OpenSSL::BN.new(2064855961)) + + assert_instance_of OpenSSL::PKey::RSA, new_key + assert_equal true, new_key.private? + assert_equal OpenSSL::BN.new(3161751493), new_key.n + assert_equal OpenSSL::BN.new(65537), new_key.e + assert_equal OpenSSL::BN.new(2064855961), new_key.d + end + + def test_s_from_data_rsa_with_n_and_e_given + new_key = OpenSSL::PKey.from_data("RSA", "n" => OpenSSL::BN.new(3161751493), + "e" => OpenSSL::BN.new(65537)) + + assert_instance_of OpenSSL::PKey::RSA, new_key + assert_equal false, new_key.private? + assert_equal OpenSSL::BN.new(3161751493), new_key.n + assert_equal OpenSSL::BN.new(65537), new_key.e + assert_equal nil, new_key.d + end + + def test_s_from_data_rsa_with_openssl_internal_names + source = Fixtures.pkey("rsa2048") + new_key = OpenSSL::PKey.from_data("RSA", "n" => source.n, + "e" => source.e, + "d" => source.d, + "rsa-factor1" => source.p, + "rsa-factor2" => source.q, + "rsa-exponent1" => source.dmp1, + "rsa-exponent2" => source.dmq1, + "rsa-coefficient1" => source.iqmp) + + assert_equal source.n, new_key.n + assert_equal source.e, new_key.e + assert_equal source.d, new_key.d + assert_equal source.p, new_key.p + assert_equal source.q, new_key.q + assert_equal source.dmp1, new_key.dmp1 + assert_equal source.dmq1, new_key.dmq1 + assert_equal source.iqmp, new_key.iqmp + assert_equal source.to_pem, new_key.to_pem + end + + def test_s_from_data_rsa_with_simple_names + source = Fixtures.pkey("rsa2048") + new_key = OpenSSL::PKey.from_data("RSA", "n" => source.n, + "e" => source.e, + "d" => source.d, + "p" => source.p, + "q" => source.q, + "dmp1" => source.dmp1, + "dmq1" => source.dmq1, + "iqmp" => source.iqmp) + + assert_equal source.n, new_key.n + assert_equal source.e, new_key.e + assert_equal source.d, new_key.d + assert_equal source.p, new_key.p + assert_equal source.q, new_key.q + assert_equal source.dmp1, new_key.dmp1 + assert_equal source.dmq1, new_key.dmq1 + assert_equal source.iqmp, new_key.iqmp + assert_equal source.to_pem, new_key.to_pem + end + + def test_s_from_data_rsa_with_invalid_parameter + assert_raise_with_message(OpenSSL::PKey::PKeyError, /Invalid parameter "invalid"/) do + OpenSSL::PKey.from_data("RSA", "invalid" => 1234) + end + end + + def test_s_from_data_ec_pub_given_as_string + source = OpenSSL::PKey::EC.generate("prime256v1") + new_key = OpenSSL::PKey.from_data("EC", "group" => source.group.curve_name, + "pub" => source.public_key.to_bn.to_s(2)) + assert_instance_of OpenSSL::PKey::EC, new_key + assert_equal source.group.curve_name, new_key.group.curve_name + assert_equal source.public_key, new_key.public_key + assert_equal nil, new_key.private_key + end + + def test_s_from_data_ec_priv_given_as_bn + source = OpenSSL::PKey::EC.generate("prime256v1") + new_key = OpenSSL::PKey.from_data("EC", "group" => source.group.curve_name, + "priv" => source.private_key.to_bn) + assert_instance_of OpenSSL::PKey::EC, new_key + assert_equal source.group.curve_name, new_key.group.curve_name + assert_equal source.private_key, new_key.private_key + assert_equal nil, new_key.public_key + end + + def test_s_from_data_ec_priv_given_as_integer + source = OpenSSL::PKey::EC.generate("prime256v1") + new_key = OpenSSL::PKey.from_data("EC", "group" => source.group.curve_name, + "priv" => source.private_key.to_i) + assert_instance_of OpenSSL::PKey::EC, new_key + assert_equal source.group.curve_name, new_key.group.curve_name + assert_equal source.private_key, new_key.private_key + assert_equal nil, new_key.public_key + end + + def test_s_from_data_ec_priv_and_pub_given_for_different_curves + [OpenSSL::PKey::EC.generate("prime256v1"), + OpenSSL::PKey::EC.generate("secp384r1"), + OpenSSL::PKey::EC.generate("secp521r1")].each do |source| + new_key = OpenSSL::PKey.from_data("EC", "group" => source.group.curve_name, + "pub" => source.public_key.to_bn.to_s(2), + "priv" => source.private_key.to_i) + assert_instance_of OpenSSL::PKey::EC, new_key + assert_equal source.group.curve_name, new_key.group.curve_name + assert_equal source.private_key, new_key.private_key + assert_equal source.public_key, new_key.public_key + end + end + + def test_s_from_data_ec_pub_given_as_integer + assert_raise_with_message(TypeError, "no implicit conversion of Integer into String") do + OpenSSL::PKey.from_data("EC", { "group" => "prime256v1", "pub" => 12345 }) + end + end + + def test_s_from_data_ec_with_invalid_parameter + assert_raise_with_message(OpenSSL::PKey::PKeyError, /Invalid parameter "invalid"/) do + OpenSSL::PKey.from_data("EC", "invalid" => 1234) + end + end + + def test_s_from_data_dsa_with_all_supported_parameters + source = Fixtures.pkey("dsa1024") + new_key = OpenSSL::PKey.from_data("DSA", "pub" => source.params["pub_key"], + "priv" => source.params["priv_key"], + "p" => source.params["p"], + "q" => source.params["q"], + "g" => source.params["g"]) + + assert_instance_of OpenSSL::PKey::DSA, new_key + assert_equal source.params, new_key.params + end + + def test_s_from_data_dsa_with_gem_specific_keys + source = Fixtures.pkey("dsa2048") + new_key = OpenSSL::PKey.from_data("DSA", source.params) + + assert_equal source.params, new_key.params + end + + def test_s_from_data_dsa_with_invalid_parameter + assert_raise_with_message(OpenSSL::PKey::PKeyError, /Invalid parameter "invalid". Supported parameters: p, q, g, j/) do + OpenSSL::PKey.from_data("DSA", "invalid" => 1234) + end + end + + def test_s_from_data_dh_with_all_supported_parameters + source = Fixtures.pkey("dh2048_ffdhe2048") + new_key = OpenSSL::PKey.from_data("DH", source.params) + + assert_instance_of OpenSSL::PKey::DH, new_key + assert_equal source.params, new_key.params + end + + def test_s_from_data_dh_with_invalid_parameter + assert_raise_with_message(OpenSSL::PKey::PKeyError, /Invalid parameter "invalid"/) do + OpenSSL::PKey.from_data("DH", "invalid" => 1234) + end + end + + def test_s_from_data_ed25519 + # Ed25519 is not FIPS-approved. + omit_on_fips + + pub_pem = <<~EOF + -----BEGIN PUBLIC KEY----- + MCowBQYDK2VwAyEA0I6olrZGYml7JGusuKJW9G7D0DZ9UormSady9kR7V4Q= + -----END PUBLIC KEY----- + EOF + + key = OpenSSL::PKey.from_data("ED25519", "pub" => "\xD0\x8E\xA8\x96\xB6Fbi{$k\xAC\xB8\xA2V\xF4n\xC3\xD06}R\x8A\xE6I\xA7r\xF6D{W\x84") + assert_instance_of OpenSSL::PKey::PKey, key + assert_equal pub_pem, key.public_to_pem + end + end end