Skip to content

Commit

Permalink
Merge pull request #431 from andreidenissov-cog/feature/424
Browse files Browse the repository at this point in the history
Feature/424
  • Loading branch information
andreidenissov-cog authored Jul 3, 2020
2 parents ccec83c + 7b90449 commit 7135868
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 18 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ cma

apscheduler
pycryptodome
paramiko
sshpubkeys
PyNaCl
requests
requests_toolbelt
Expand Down
113 changes: 102 additions & 11 deletions studio/encrypted_payload_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from Crypto.Hash import SHA256
import nacl.secret
import nacl.utils
import nacl.signing
import paramiko
import base64
import json
from sshpubkeys import SSHKey

from .payload_builder import PayloadBuilder
from studio import logs
Expand All @@ -15,28 +18,71 @@ class EncryptedPayloadBuilder(PayloadBuilder):
Implementation for experiment payload builder
using public key RSA encryption.
"""
def __init__(self, name: str, keypath: str):
def __init__(self, name: str,
receiver_keypath: str,
sender_keypath: str = None):
"""
param: name - payload builder name
param: keypath - file path to .pem file with public key
param: receiver_keypath - file path to .pem file
with recipient public key
param: sender_keypath - file path to .pem file
with sender private key
"""
super(EncryptedPayloadBuilder, self).__init__(name)

# XXX Set logger verbosity level here
self.logger = logs.getLogger(self.__class__.__name__)

self.key_path = keypath
self.recipient_key_path = receiver_keypath
self.recipient_key = None
try:
self.recipient_key = RSA.import_key(open(self.key_path).read())
self.recipient_key =\
RSA.import_key(open(self.recipient_key_path).read())
except:
self.logger.error(
"FAILED to import recipient public key from: {0}".format(self.key_path))
return
msg = "FAILED to import recipient public key from: {0}"\
.format(self.recipient_key_path)
self.logger.error(msg)
raise ValueError(msg)

self.sender_key_path = sender_keypath
self.sender_key = None
self.sender_fingerprint = None

if self.sender_key_path is None:
self.logger.error("Signing key path must be specified for encrypted payloads. ABORTING.")
raise ValueError()

# We expect ed25519 signing key in "private key" format
try:
self.sender_key =\
paramiko.Ed25519Key(filename=self.sender_key_path)

if self.sender_key is None:
self._raise_error("Failed to import private signing key. ABORTING.")
except:
self._raise_error("FAILED to open/read private signing key file: {0}"\
.format(self.sender_key_path))

self.sender_fingerprint = \
self._get_fingerprint(self.sender_key)

self.simple_builder =\
UnencryptedPayloadBuilder("simple-builder-for-encryptor")

def _raise_error(self, msg: str):
self.logger.error(msg)
raise ValueError(msg)

def _get_fingerprint(self, signing_key):
ssh_key = SSHKey("ssh-ed25519 {0}"
.format(signing_key.get_base64()))
try:
ssh_key.parse()
except:
self._raise_error("INVALID signing key type. ABORTING.")

return ssh_key.hash_sha256() # SHA256:xyz

def _import_rsa_key(self, key_path: str):
key = None
try:
Expand All @@ -47,6 +93,13 @@ def _import_rsa_key(self, key_path: str):
key = None
return key

def _rsa_encrypt_data_to_base64(self, key, data):
# Encrypt byte data with RSA key
cipher_rsa = PKCS1_OAEP.new(key=key, hashAlgo=SHA256)
encrypted_data = cipher_rsa.encrypt(data)
encrypted_data_base64 = base64.b64encode(encrypted_data)
return encrypted_data_base64

def _encrypt_str(self, workload: str):
# Generate one-time symmetric session key:
session_key = nacl.utils.random(32)
Expand All @@ -58,12 +111,39 @@ def _encrypt_str(self, workload: str):
encrypted_data_text = base64.b64encode(encrypted_data)

# Encrypt the session key with the public RSA key
cipher_rsa = PKCS1_OAEP.new(key=self.recipient_key, hashAlgo=SHA256)
encrypted_session_key = cipher_rsa.encrypt(session_key)
encrypted_session_key_text = base64.b64encode(encrypted_session_key)
encrypted_session_key_text =\
self._rsa_encrypt_data_to_base64(self.recipient_key, session_key)

return encrypted_session_key_text, encrypted_data_text

def _verify_signature(self, data, msg):
if msg.get_text() != "ssh-ed25519":
return False

try:
self.sender_key._signing_key.verify_key.verify(data, msg.get_binary())
except:
return False
else:
return True

def _sign_payload(self, encrypted_payload):
"""
encrypted_payload - base64 representation of the encrypted payload.
returns: base64-encoded signature
"""
sign_message = self.sender_key.sign_ssh_data(encrypted_payload)

# Verify what we generated just in case:
verify_message = paramiko.Message(sign_message.asbytes())
verify_res = self._verify_signature(encrypted_payload, verify_message)

if not verify_res:
self._raise_error("FAILED to verify signed data. ABORTING.")

result = base64.b64encode(sign_message.asbytes())
return result

def _decrypt_data(self, private_key_path, encrypted_key_text, encrypted_data_text):
private_key = self._import_rsa_key(private_key_path)
if private_key is None:
Expand Down Expand Up @@ -92,6 +172,7 @@ def _decrypt_data(self, private_key_path, encrypted_key_text, encrypted_data_tex
def construct(self, experiment, config, packages):
unencrypted_payload =\
self.simple_builder.construct(experiment, config, packages)
unencrypted_payload_str = json.dumps(unencrypted_payload)

# Construct payload template:
encrypted_payload = {
Expand All @@ -108,7 +189,7 @@ def construct(self, experiment, config, packages):
}

# Now fill it up with experiment properties:
enc_key, enc_payload = self._encrypt_str(json.dumps(unencrypted_payload))
enc_key, enc_payload = self._encrypt_str(unencrypted_payload_str)

encrypted_payload["message"]["experiment"]["status"] =\
experiment.status
Expand All @@ -122,6 +203,16 @@ def construct(self, experiment, config, packages):
experiment.resources_needed
encrypted_payload["message"]["payload"] =\
"{0},{1}".format(enc_key.decode("utf-8"), enc_payload.decode("utf-8"))
if self.sender_key is not None:
# Generate sender/workload signature:
final_payload = encrypted_payload["message"]["payload"]
payload_signature = self._sign_payload(final_payload.encode("utf-8"))
encrypted_payload["message"]["signature"] =\
"{0}".format(payload_signature.decode("utf-8"))
encrypted_payload["message"]["fingerprint"] =\
"{0}".format(self.sender_fingerprint)

print(json.dumps(encrypted_payload, indent=4))

return encrypted_payload

7 changes: 6 additions & 1 deletion studio/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .local_db_provider import LocalDbProvider
from .s3_provider import S3Provider
from .gs_provider import GSProvider
from .model_setup import setup_model
from .model_setup import setup_model, get_model_db_provider
from . import logs

def get_config(config_file=None):
Expand Down Expand Up @@ -59,6 +59,11 @@ def replace_with_env(config):
.format(config_paths))

def get_db_provider(config=None, blocking_auth=True):

db_provider = get_model_db_provider()
if not db_provider is None:
return db_provider

if not config:
config = get_config()
verbose = parse_verbosity(config.get('verbose'))
Expand Down
4 changes: 2 additions & 2 deletions studio/model_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
def setup_model(db_provider, artifact_store):
_model_setup = { DB_KEY: db_provider, STORE_KEY: artifact_store }

def get_db_provider():
def get_model_db_provider():
if _model_setup is None:
return None
return _model_setup.get(DB_KEY, None)

def get_artifact_store():
def get_model_artifact_store():
if _model_setup is None:
return None
return _model_setup.get(STORE_KEY, None)
15 changes: 11 additions & 4 deletions studio/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,12 +597,16 @@ def submit_experiments(

payload_builder = UnencryptedPayloadBuilder("simple-payload")
# Are we using experiment payload encryption?
key_path = config.get('public_key_path')
if key_path is not None:
logger.info("Using RSA public key path: {0}".format(key_path))
public_key_path = config.get('public_key_path', None)
if public_key_path is not None:
logger.info("Using RSA public key path: {0}".format(public_key_path))
signing_key_path = config.get('signing_key_path', None)
if signing_key_path is not None:
logger.info("Using RSA signing key path: {0}".format(signing_key_path))
payload_builder = \
EncryptedPayloadBuilder(
"cs-rsa-encryptor [{0}]".format(key_path), key_path)
"cs-rsa-encryptor [{0}]".format(public_key_path),
public_key_path, signing_key_path)

start_time = time.time()

Expand Down Expand Up @@ -637,6 +641,9 @@ def submit_experiments(

for experiment in experiments:
payload = payload_builder.construct(experiment, config, python_pkg)

print(json.dumps(payload, indent=4))

queue.enqueue(json.dumps(payload))
logger.info("studio run: submitted experiment " + experiment.key)

Expand Down

0 comments on commit 7135868

Please sign in to comment.