diff --git a/lib/jwe.rb b/lib/jwe.rb index 08f999d..94fcaf0 100644 --- a/lib/jwe.rb +++ b/lib/jwe.rb @@ -22,19 +22,18 @@ class InvalidData < RuntimeError; end class << self def encrypt(payload, key, alg: 'RSA-OAEP', enc: 'A128GCM', **more_headers) - header = { alg: alg, enc: enc }.merge(more_headers) - header.delete(:zip) if header[:zip] == '' + header = generate_header(alg, enc, more_headers) check_params(header, key) - cipher = Enc.for(enc).new - cipher.cek = key if alg == 'dir' + payload = apply_zip(header, payload, :compress) - payload = Zip.for(header[:zip]).new.compress(payload) if header[:zip] + cipher = Enc.for(enc) + cipher.cek = key if alg == 'dir' - ciphertext = cipher.encrypt(payload, Base64.jwe_encode(header.to_json)) - encrypted_cek = Alg.for(alg).new(key).encrypt(cipher.cek) + json_hdr = header.to_json + ciphertext = cipher.encrypt(payload, Base64.jwe_encode(json_hdr)) - Serialization::Compact.encode(header.to_json, encrypted_cek, cipher.iv, ciphertext, cipher.tag) + generate_serialization(json_hdr, Alg.encrypt_cek(alg, key, cipher.cek), ciphertext, cipher) end def decrypt(payload, key) @@ -42,17 +41,12 @@ def decrypt(payload, key) header = JSON.parse(header) check_params(header, key) - cek = Alg.for(header['alg']).new(key).decrypt(enc_key) - cipher = Enc.for(header['enc']).new(cek, iv) - cipher.tag = tag + cek = Alg.decrypt_cek(header['alg'], key, enc_key) + cipher = Enc.for(header['enc'], cek, iv, tag) plaintext = cipher.decrypt(ciphertext, payload.split('.').first) - if header['zip'] - Zip.for(header['zip']).new.decompress(plaintext) - else - plaintext - end + apply_zip(header, plaintext, :decompress) end def check_params(header, key) @@ -82,5 +76,24 @@ def param_to_class_name(param) klass = param.gsub(/[-\+]/, '_').downcase.sub(/^[a-z\d]*/) { $&.capitalize } klass.gsub(/_([a-z\d]*)/i) { Regexp.last_match(1).capitalize } end + + def apply_zip(header, data, direction) + zip = header[:zip] || header['zip'] + if zip + Zip.for(zip).new.send(direction, data) + else + data + end + end + + def generate_header(alg, enc, more) + header = { alg: alg, enc: enc }.merge(more) + header.delete(:zip) if header[:zip] == '' + header + end + + def generate_serialization(hdr, cek, content, cipher) + Serialization::Compact.encode(hdr, cek, cipher.iv, content, cipher.tag) + end end end diff --git a/lib/jwe/alg.rb b/lib/jwe/alg.rb index 528e58c..245d9c9 100644 --- a/lib/jwe/alg.rb +++ b/lib/jwe/alg.rb @@ -13,5 +13,13 @@ def self.for(alg) rescue NameError raise NotImplementedError.new("Unsupported alg type: #{alg}") end + + def self.encrypt_cek(alg, key, cek) + self.for(alg).new(key).encrypt(cek) + end + + def self.decrypt_cek(alg, key, encrypted_cek) + self.for(alg).new(key).decrypt(encrypted_cek) + end end end diff --git a/lib/jwe/enc.rb b/lib/jwe/enc.rb index ae6ace4..9d6a6b7 100644 --- a/lib/jwe/enc.rb +++ b/lib/jwe/enc.rb @@ -8,8 +8,11 @@ module JWE # Content encryption algorithms namespace module Enc - def self.for(enc) - const_get(JWE.param_to_class_name(enc)) + def self.for(enc, cek = nil, iv = nil, tag = nil) + klass = const_get(JWE.param_to_class_name(enc)) + inst = klass.new(cek, iv) + inst.tag = tag if tag + inst rescue NameError raise NotImplementedError.new("Unsupported enc type: #{enc}") end diff --git a/spec/jwe/enc_spec.rb b/spec/jwe/enc_spec.rb index 73bfd2d..0cd9e03 100644 --- a/spec/jwe/enc_spec.rb +++ b/spec/jwe/enc_spec.rb @@ -7,8 +7,8 @@ describe JWE::Enc do describe '.for' do - it 'returns a class for the specified enc' do - expect(JWE::Enc.for('A128GCM')).to eq JWE::Enc::A128gcm + it 'returns an instance for the specified enc' do + expect(JWE::Enc.for('A128GCM')).to be_a JWE::Enc::A128gcm end it 'raises an error for a not-implemented enc' do