From 45ea392dc6f6bd5295a954c4969abce6d7242a9c Mon Sep 17 00:00:00 2001 From: Andrei Denissov Date: Tue, 19 May 2020 14:08:24 -0700 Subject: [PATCH 1/9] First cut for implementing experiment payload signing. --- studio/encrypted_payload_builder.py | 77 ++++++++++++++++++++++++----- studio/runner.py | 20 ++++++-- 2 files changed, 82 insertions(+), 15 deletions(-) diff --git a/studio/encrypted_payload_builder.py b/studio/encrypted_payload_builder.py index a86364b5..04d0fd1f 100644 --- a/studio/encrypted_payload_builder.py +++ b/studio/encrypted_payload_builder.py @@ -15,24 +15,58 @@ class EncryptedPayloadBuilder(PayloadBuilder): Implementation for experiment payload builder using public key RSA encryption. """ - def __init__(self, name: str, keypath: str): + def __init__(self, name: str, + receiver_keypath: str, + sender_keypath: str = None): """ param: name - payload builder name - param: keypath - file path to .pem file with public key + param: receiver_keypath - file path to .pem file + with recipient public key + param: sender_keypath - file path to .pem file + with sender private key """ super(EncryptedPayloadBuilder, self).__init__(name) # XXX Set logger verbosity level here self.logger = logs.getLogger(self.__class__.__name__) - self.key_path = keypath + self.recipient_key_path = receiver_keypath self.recipient_key = None try: - self.recipient_key = RSA.import_key(open(self.key_path).read()) + self.recipient_key =\ + RSA.import_key(open(self.recipient_key_path).read()) except: - self.logger.error( - "FAILED to import recipient public key from: {0}".format(self.key_path)) - return + msg = "FAILED to import recipient public key from: {0}"\ + .format(self.recipient_key_path) + self.logger.error(msg) + raise ValueError(msg) + + self.sender_key_path = sender_keypath + self.sender_key = None + self.sender_fingerprint = None + if self.sender_key_path: + key_text = None + try: + with open(self.sender_key_path, 'r') as keyfile: + key_text = keyfile.read() + except: + msg = "FAILED to open/read sender private key file: {0}"\ + .format(self.sender_key_path) + self.logger.error(msg) + raise ValueError(msg) + + try: + self.sender_key = RSA.import_key(key_text) + except: + msg = "FAILED to import sender private key from: {0}"\ + .format(self.sender_key_path) + self.logger.error(msg) + raise ValueError(msg) + + self.sender_fingerprint =\ + SHA256.new(key_text.encode("utf-8")).digest() + self.sender_fingerprint = \ + base64.b64encode(self.sender_fingerprint).decode("utf-8") self.simple_builder =\ UnencryptedPayloadBuilder("simple-builder-for-encryptor") @@ -47,6 +81,13 @@ def _import_rsa_key(self, key_path: str): key = None return key + def _rsa_encrypt_data_to_base64(self, key, data): + # Encrypt byte data with RSA key + cipher_rsa = PKCS1_OAEP.new(key=key, hashAlgo=SHA256) + encrypted_data = cipher_rsa.encrypt(data) + encrypted_data_base64 = base64.b64encode(encrypted_data) + return encrypted_data_base64 + def _encrypt_str(self, workload: str): # Generate one-time symmetric session key: session_key = nacl.utils.random(32) @@ -58,12 +99,18 @@ def _encrypt_str(self, workload: str): encrypted_data_text = base64.b64encode(encrypted_data) # Encrypt the session key with the public RSA key - cipher_rsa = PKCS1_OAEP.new(key=self.recipient_key, hashAlgo=SHA256) - encrypted_session_key = cipher_rsa.encrypt(session_key) - encrypted_session_key_text = base64.b64encode(encrypted_session_key) + encrypted_session_key_text =\ + self._rsa_encrypt_data_to_base64(self.recipient_key, session_key) return encrypted_session_key_text, encrypted_data_text + def _get_signature_str(self, workload: str): + data_to_hash = workload.encode("utf-8") + data_hash = SHA256.new(data_to_hash).digest() + encrypted_data_hash =\ + self._rsa_encrypt_data_to_base64(self.sender_key, data_hash) + return encrypted_data_hash + def _decrypt_data(self, private_key_path, encrypted_key_text, encrypted_data_text): private_key = self._import_rsa_key(private_key_path) if private_key is None: @@ -92,6 +139,7 @@ def _decrypt_data(self, private_key_path, encrypted_key_text, encrypted_data_tex def construct(self, experiment, config, packages): unencrypted_payload =\ self.simple_builder.construct(experiment, config, packages) + unencrypted_payload_str = json.dumps(unencrypted_payload) # Construct payload template: encrypted_payload = { @@ -108,7 +156,7 @@ def construct(self, experiment, config, packages): } # Now fill it up with experiment properties: - enc_key, enc_payload = self._encrypt_str(json.dumps(unencrypted_payload)) + enc_key, enc_payload = self._encrypt_str(unencrypted_payload_str) encrypted_payload["message"]["experiment"]["status"] =\ experiment.status @@ -122,6 +170,13 @@ def construct(self, experiment, config, packages): experiment.resources_needed encrypted_payload["message"]["payload"] =\ "{0},{1}".format(enc_key.decode("utf-8"), enc_payload.decode("utf-8")) + if self.sender_key is not None: + # Generate sender/workload signature: + signature_str = self._get_signature_str(unencrypted_payload_str) + encrypted_payload["message"]["signature"] =\ + "{0}".format(signature_str.decode("utf-8")) + encrypted_payload["message"]["fingerprint"] =\ + "{0}".format(self.sender_fingerprint) return encrypted_payload diff --git a/studio/runner.py b/studio/runner.py index 65347487..2a543f36 100644 --- a/studio/runner.py +++ b/studio/runner.py @@ -595,14 +595,22 @@ def submit_experiments( num_experiments = len(experiments) verbose = model.parse_verbosity(config['verbose']) + print("===========================================") + print(repr(config)) + print("===========================================") + payload_builder = UnencryptedPayloadBuilder("simple-payload") # Are we using experiment payload encryption? - key_path = config.get('public_key_path') - if key_path is not None: - logger.info("Using RSA public key path: {0}".format(key_path)) + public_key_path = config.get('public_key_path', None) + if public_key_path is not None: + logger.info("Using RSA public key path: {0}".format(public_key_path)) + signing_key_path = config.get('signing_key_path', None) + if signing_key_path is not None: + logger.info("Using RSA signing key path: {0}".format(signing_key_path)) payload_builder = \ EncryptedPayloadBuilder( - "cs-rsa-encryptor [{0}]".format(key_path), key_path) + "cs-rsa-encryptor [{0}]".format(public_key_path), + public_key_path, signing_key_path) start_time = time.time() @@ -637,6 +645,10 @@ def submit_experiments( for experiment in experiments: payload = payload_builder.construct(experiment, config, python_pkg) + + print(json.dumps(payload, indent=4)) + exit(0) + queue.enqueue(json.dumps(payload)) logger.info("studio run: submitted experiment " + experiment.key) From fe926e52c104c02b9ba73e78efc38c3810d05204 Mon Sep 17 00:00:00 2001 From: Andrei Denissov Date: Wed, 24 Jun 2020 18:10:34 -0700 Subject: [PATCH 2/9] Work in progress. Use ed25519 signing keys. Updated requirements.txt for paramiko package. --- requirements.txt | 3 +- studio/encrypted_payload_builder.py | 47 ++++++++++++++++++++--------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/requirements.txt b/requirements.txt index 25f41fd2..a108a8dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,6 +17,7 @@ cma apscheduler pycryptodome PyNaCl==1.3.0 +paramiko requests requests_toolbelt python_jwt @@ -43,5 +44,5 @@ pika == 1.1.0 cachetools == 2.0.1 -tensorflow==1.15.2 +tensorflow==2.2.0 diff --git a/studio/encrypted_payload_builder.py b/studio/encrypted_payload_builder.py index 04d0fd1f..ca9c208d 100644 --- a/studio/encrypted_payload_builder.py +++ b/studio/encrypted_payload_builder.py @@ -3,6 +3,8 @@ from Crypto.Hash import SHA256 import nacl.secret import nacl.utils +import nacl.signing +import paramiko import base64 import json @@ -45,26 +47,29 @@ def __init__(self, name: str, self.sender_key = None self.sender_fingerprint = None if self.sender_key_path: - key_text = None + # We expect ed25519 signing key in "private key" format try: - with open(self.sender_key_path, 'r') as keyfile: - key_text = keyfile.read() + ed25519Key =\ + paramiko.Ed25519Key(filename=self.sender_key_path) + self.sender_key = ed25519Key._signing_key + + if self.sender_key is None: + self.logger.error("Failed to get signing key. ABORTING.") + raise ValueError() + + if not isinstance(self.sender_key, nacl.signing.SigningKey): + self.logger.error("Unexpected type {0} of signing key. ABORTING." + .format(str(type(self.sender_key)))) + raise ValueError() except: msg = "FAILED to open/read sender private key file: {0}"\ .format(self.sender_key_path) self.logger.error(msg) raise ValueError(msg) - try: - self.sender_key = RSA.import_key(key_text) - except: - msg = "FAILED to import sender private key from: {0}"\ - .format(self.sender_key_path) - self.logger.error(msg) - raise ValueError(msg) - - self.sender_fingerprint =\ - SHA256.new(key_text.encode("utf-8")).digest() + self.sender_fingerprint = "None" + # self.sender_fingerprint =\ + # SHA256.new(key_text.encode("utf-8")).digest() self.sender_fingerprint = \ base64.b64encode(self.sender_fingerprint).decode("utf-8") @@ -111,6 +116,15 @@ def _get_signature_str(self, workload: str): self._rsa_encrypt_data_to_base64(self.sender_key, data_hash) return encrypted_data_hash + def _sign_payload(self, encrypted_payload): + """ + encrypted_payload - base64 representation of the encrypted payload. + returns: base64-encoded signature + """ + signed = self.sender_key.sign(encrypted_payload) + signature = signed.signature + return base64.b64encode(signature) + def _decrypt_data(self, private_key_path, encrypted_key_text, encrypted_data_text): private_key = self._import_rsa_key(private_key_path) if private_key is None: @@ -172,11 +186,14 @@ def construct(self, experiment, config, packages): "{0},{1}".format(enc_key.decode("utf-8"), enc_payload.decode("utf-8")) if self.sender_key is not None: # Generate sender/workload signature: - signature_str = self._get_signature_str(unencrypted_payload_str) + payload_signature = self._sign_payload(enc_payload) encrypted_payload["message"]["signature"] =\ - "{0}".format(signature_str.decode("utf-8")) + "{0}".format(payload_signature.decode("utf-8")) encrypted_payload["message"]["fingerprint"] =\ "{0}".format(self.sender_fingerprint) + print("{0}".format(json.dumps(encrypted_payload, indent=4))) + exit(0) + return encrypted_payload From 9296aaf182a895bf12d984ecd29405208212b621 Mon Sep 17 00:00:00 2001 From: Andrei Denissov Date: Wed, 24 Jun 2020 18:26:33 -0700 Subject: [PATCH 3/9] Added fingerprint generation. --- studio/encrypted_payload_builder.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/studio/encrypted_payload_builder.py b/studio/encrypted_payload_builder.py index ca9c208d..363d8d8a 100644 --- a/studio/encrypted_payload_builder.py +++ b/studio/encrypted_payload_builder.py @@ -67,15 +67,28 @@ def __init__(self, name: str, self.logger.error(msg) raise ValueError(msg) - self.sender_fingerprint = "None" - # self.sender_fingerprint =\ - # SHA256.new(key_text.encode("utf-8")).digest() + self.sender_fingerprint = \ + self._get_fingerprint(self.sender_key_path) self.sender_fingerprint = \ base64.b64encode(self.sender_fingerprint).decode("utf-8") self.simple_builder =\ UnencryptedPayloadBuilder("simple-builder-for-encryptor") + def _get_fingerprint(self, key_file_path): + key_text = None + try: + with open(key_file_path, 'r') as keyfile: + key_text = keyfile.read() + except: + msg = "FAILED to open/read key file: {0}".format(key_file_path) + self.logger.error(msg) + raise ValueError(msg) + + fingerprint = \ + SHA256.new(key_text.encode("utf-8")).digest() + return fingerprint + def _import_rsa_key(self, key_path: str): key = None try: From 9f24d5591e27f121ee6d42a5cc17c1f1c967270e Mon Sep 17 00:00:00 2001 From: Andrei Denissov Date: Thu, 25 Jun 2020 17:45:26 -0700 Subject: [PATCH 4/9] Fixes. --- studio/encrypted_payload_builder.py | 10 +++++++--- studio/runner.py | 1 - 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/studio/encrypted_payload_builder.py b/studio/encrypted_payload_builder.py index 363d8d8a..ac45bab3 100644 --- a/studio/encrypted_payload_builder.py +++ b/studio/encrypted_payload_builder.py @@ -75,6 +75,11 @@ def __init__(self, name: str, self.simple_builder =\ UnencryptedPayloadBuilder("simple-builder-for-encryptor") + def _get_text_fingerprint(self, text: str): + fingerprint = \ + SHA256.new(text.encode("ascii")).digest() + return fingerprint + def _get_fingerprint(self, key_file_path): key_text = None try: @@ -86,7 +91,7 @@ def _get_fingerprint(self, key_file_path): raise ValueError(msg) fingerprint = \ - SHA256.new(key_text.encode("utf-8")).digest() + self._get_text_fingerprint(key_text) return fingerprint def _import_rsa_key(self, key_path: str): @@ -205,8 +210,7 @@ def construct(self, experiment, config, packages): encrypted_payload["message"]["fingerprint"] =\ "{0}".format(self.sender_fingerprint) - print("{0}".format(json.dumps(encrypted_payload, indent=4))) - exit(0) + #print("{0}".format(json.dumps(encrypted_payload, indent=4))) return encrypted_payload diff --git a/studio/runner.py b/studio/runner.py index 2a543f36..684a344e 100644 --- a/studio/runner.py +++ b/studio/runner.py @@ -647,7 +647,6 @@ def submit_experiments( payload = payload_builder.construct(experiment, config, python_pkg) print(json.dumps(payload, indent=4)) - exit(0) queue.enqueue(json.dumps(payload)) logger.info("studio run: submitted experiment " + experiment.key) From 6fe0c1eadd8cb8cc8db3e02b5f1fda8c9d1c6a4a Mon Sep 17 00:00:00 2001 From: Andrei Denissov Date: Fri, 26 Jun 2020 09:53:51 -0700 Subject: [PATCH 5/9] Fixes to generate fingerprint from private ed25519 key. Attempt to use paramiko "sign_ssh_data" method. --- requirements.txt | 1 + studio/encrypted_payload_builder.py | 79 ++++++++++++++--------------- 2 files changed, 38 insertions(+), 42 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6fa4c09d..4708067c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ cma apscheduler pycryptodome paramiko +sshpubkeys PyNaCl requests requests_toolbelt diff --git a/studio/encrypted_payload_builder.py b/studio/encrypted_payload_builder.py index ac45bab3..4feba372 100644 --- a/studio/encrypted_payload_builder.py +++ b/studio/encrypted_payload_builder.py @@ -7,6 +7,7 @@ import paramiko import base64 import json +from sshpubkeys import SSHKey from .payload_builder import PayloadBuilder from studio import logs @@ -46,53 +47,42 @@ def __init__(self, name: str, self.sender_key_path = sender_keypath self.sender_key = None self.sender_fingerprint = None - if self.sender_key_path: - # We expect ed25519 signing key in "private key" format - try: - ed25519Key =\ - paramiko.Ed25519Key(filename=self.sender_key_path) - self.sender_key = ed25519Key._signing_key - - if self.sender_key is None: - self.logger.error("Failed to get signing key. ABORTING.") - raise ValueError() - - if not isinstance(self.sender_key, nacl.signing.SigningKey): - self.logger.error("Unexpected type {0} of signing key. ABORTING." - .format(str(type(self.sender_key)))) - raise ValueError() - except: - msg = "FAILED to open/read sender private key file: {0}"\ - .format(self.sender_key_path) - self.logger.error(msg) - raise ValueError(msg) - - self.sender_fingerprint = \ - self._get_fingerprint(self.sender_key_path) - self.sender_fingerprint = \ - base64.b64encode(self.sender_fingerprint).decode("utf-8") + + if self.sender_key_path is None: + self.logger.error("Signing key path must be specified for encrypted payloads. ABORTING.") + raise ValueError() + + # We expect ed25519 signing key in "private key" format + try: + self.sender_key =\ + paramiko.Ed25519Key(filename=self.sender_key_path) + + if self.sender_key is None: + self.logger.error("Failed to import private signing key. ABORTING.") + raise ValueError() + except: + msg = "FAILED to open/read private signing key file: {0}"\ + .format(self.sender_key_path) + self.logger.error(msg) + raise ValueError(msg) + + self.sender_fingerprint = \ + self._get_fingerprint(self.sender_key) self.simple_builder =\ UnencryptedPayloadBuilder("simple-builder-for-encryptor") - def _get_text_fingerprint(self, text: str): - fingerprint = \ - SHA256.new(text.encode("ascii")).digest() - return fingerprint - - def _get_fingerprint(self, key_file_path): - key_text = None + def _get_fingerprint(self, signing_key): + ssh_key = SSHKey("ssh-ed25519 {0}" + .format(signing_key.get_base64())) try: - with open(key_file_path, 'r') as keyfile: - key_text = keyfile.read() + ssh_key.parse() except: - msg = "FAILED to open/read key file: {0}".format(key_file_path) + msg = "INVALID signing key type. ABORTING." self.logger.error(msg) raise ValueError(msg) - fingerprint = \ - self._get_text_fingerprint(key_text) - return fingerprint + return ssh_key.hash_sha256() # SHA256:xyz def _import_rsa_key(self, key_path: str): key = None @@ -139,9 +129,14 @@ def _sign_payload(self, encrypted_payload): encrypted_payload - base64 representation of the encrypted payload. returns: base64-encoded signature """ - signed = self.sender_key.sign(encrypted_payload) - signature = signed.signature - return base64.b64encode(signature) + sign_message = self.sender_key.sign_ssh_data(encrypted_payload) + + # Verify what we generated just in case: + signature = self.sender_key._signing_key.sign(encrypted_payload).signature + self.sender_key._signing_key.verify_key.verify(encrypted_payload, signature) + print("VERIFIED!") + + return base64.b64encode(sign_message.asbytes()) def _decrypt_data(self, private_key_path, encrypted_key_text, encrypted_data_text): private_key = self._import_rsa_key(private_key_path) @@ -210,7 +205,7 @@ def construct(self, experiment, config, packages): encrypted_payload["message"]["fingerprint"] =\ "{0}".format(self.sender_fingerprint) - #print("{0}".format(json.dumps(encrypted_payload, indent=4))) + print(json.dumps(encrypted_payload, indent=4)) return encrypted_payload From d6b22ecf3dbebb7248a65f492934cd8d891fea73 Mon Sep 17 00:00:00 2001 From: Andrei Denissov Date: Wed, 1 Jul 2020 16:53:46 -0700 Subject: [PATCH 6/9] Fixes + code cleanup. --- studio/encrypted_payload_builder.py | 49 +++++++++++++++++------------ studio/runner.py | 4 --- 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/studio/encrypted_payload_builder.py b/studio/encrypted_payload_builder.py index 4feba372..ed6035a0 100644 --- a/studio/encrypted_payload_builder.py +++ b/studio/encrypted_payload_builder.py @@ -58,13 +58,10 @@ def __init__(self, name: str, paramiko.Ed25519Key(filename=self.sender_key_path) if self.sender_key is None: - self.logger.error("Failed to import private signing key. ABORTING.") - raise ValueError() + self._raise_error("Failed to import private signing key. ABORTING.") except: - msg = "FAILED to open/read private signing key file: {0}"\ - .format(self.sender_key_path) - self.logger.error(msg) - raise ValueError(msg) + self._raise_error("FAILED to open/read private signing key file: {0}"\ + .format(self.sender_key_path)) self.sender_fingerprint = \ self._get_fingerprint(self.sender_key) @@ -72,15 +69,17 @@ def __init__(self, name: str, self.simple_builder =\ UnencryptedPayloadBuilder("simple-builder-for-encryptor") + def _raise_error(self, msg: str): + self.logger.error(msg) + raise ValueError(msg) + def _get_fingerprint(self, signing_key): ssh_key = SSHKey("ssh-ed25519 {0}" .format(signing_key.get_base64())) try: ssh_key.parse() except: - msg = "INVALID signing key type. ABORTING." - self.logger.error(msg) - raise ValueError(msg) + self._raise_error("INVALID signing key type. ABORTING.") return ssh_key.hash_sha256() # SHA256:xyz @@ -117,12 +116,16 @@ def _encrypt_str(self, workload: str): return encrypted_session_key_text, encrypted_data_text - def _get_signature_str(self, workload: str): - data_to_hash = workload.encode("utf-8") - data_hash = SHA256.new(data_to_hash).digest() - encrypted_data_hash =\ - self._rsa_encrypt_data_to_base64(self.sender_key, data_hash) - return encrypted_data_hash + def _verify_signature(self, data, msg): + if msg.get_text() != "ssh-ed25519": + return False + + try: + self.sender_key._signing_key.verify_key.verify(data, msg.get_binary()) + except: + return False + else: + return True def _sign_payload(self, encrypted_payload): """ @@ -132,11 +135,14 @@ def _sign_payload(self, encrypted_payload): sign_message = self.sender_key.sign_ssh_data(encrypted_payload) # Verify what we generated just in case: - signature = self.sender_key._signing_key.sign(encrypted_payload).signature - self.sender_key._signing_key.verify_key.verify(encrypted_payload, signature) - print("VERIFIED!") + verify_message = paramiko.Message(sign_message.asbytes()) + verify_res = self._verify_signature(encrypted_payload, verify_message) - return base64.b64encode(sign_message.asbytes()) + if not verify_res: + self._raise_error("FAILED to verify signed data. ABORTING.") + + result = base64.b64encode(sign_message.asbytes()) + return result def _decrypt_data(self, private_key_path, encrypted_key_text, encrypted_data_text): private_key = self._import_rsa_key(private_key_path) @@ -199,7 +205,8 @@ def construct(self, experiment, config, packages): "{0},{1}".format(enc_key.decode("utf-8"), enc_payload.decode("utf-8")) if self.sender_key is not None: # Generate sender/workload signature: - payload_signature = self._sign_payload(enc_payload) + final_payload = encrypted_payload["message"]["payload"] + payload_signature = self._sign_payload(final_payload.encode("utf-8")) encrypted_payload["message"]["signature"] =\ "{0}".format(payload_signature.decode("utf-8")) encrypted_payload["message"]["fingerprint"] =\ @@ -207,5 +214,7 @@ def construct(self, experiment, config, packages): print(json.dumps(encrypted_payload, indent=4)) + exit(0) + return encrypted_payload diff --git a/studio/runner.py b/studio/runner.py index 684a344e..c1528613 100644 --- a/studio/runner.py +++ b/studio/runner.py @@ -595,10 +595,6 @@ def submit_experiments( num_experiments = len(experiments) verbose = model.parse_verbosity(config['verbose']) - print("===========================================") - print(repr(config)) - print("===========================================") - payload_builder = UnencryptedPayloadBuilder("simple-payload") # Are we using experiment payload encryption? public_key_path = config.get('public_key_path', None) From a89e9e479c5bef63c69869def71aa222529156bf Mon Sep 17 00:00:00 2001 From: Andrei Denissov Date: Thu, 2 Jul 2020 17:24:47 -0700 Subject: [PATCH 7/9] Fixes and optimization for getting db_provider. --- studio/encrypted_payload_builder.py | 2 -- studio/model.py | 11 ++++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/studio/encrypted_payload_builder.py b/studio/encrypted_payload_builder.py index ed6035a0..5c8fd8db 100644 --- a/studio/encrypted_payload_builder.py +++ b/studio/encrypted_payload_builder.py @@ -214,7 +214,5 @@ def construct(self, experiment, config, packages): print(json.dumps(encrypted_payload, indent=4)) - exit(0) - return encrypted_payload diff --git a/studio/model.py b/studio/model.py index b4df9bdc..72d05b4b 100644 --- a/studio/model.py +++ b/studio/model.py @@ -18,7 +18,7 @@ from .local_db_provider import LocalDbProvider from .s3_provider import S3Provider from .gs_provider import GSProvider -from .model_setup import setup_model +import model_setup from . import logs def get_config(config_file=None): @@ -59,6 +59,11 @@ def replace_with_env(config): .format(config_paths)) def get_db_provider(config=None, blocking_auth=True): + + db_provider = model_setup.get_db_provider() + if not db_provider is None: + return db_provider + if not config: config = get_config() verbose = parse_verbosity(config.get('verbose')) @@ -115,10 +120,10 @@ def get_db_provider(config=None, blocking_auth=True): artifact_store = db_provider.get_artifact_store() else: - _model_setup = None + model_setup._model_setup = None raise ValueError('Unknown type of the database ' + db_config['type']) - setup_model(db_provider, artifact_store) + model_setup.setup_model(db_provider, artifact_store) return db_provider def parse_verbosity(verbosity=None): From 7c81627d3aa3468e167a1b87cc5dd2272b9f170e Mon Sep 17 00:00:00 2001 From: Andrei Denissov Date: Thu, 2 Jul 2020 17:35:53 -0700 Subject: [PATCH 8/9] Fixes. --- studio/model.py | 8 ++++---- studio/model_setup.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/studio/model.py b/studio/model.py index 72d05b4b..ff5600b0 100644 --- a/studio/model.py +++ b/studio/model.py @@ -18,7 +18,7 @@ from .local_db_provider import LocalDbProvider from .s3_provider import S3Provider from .gs_provider import GSProvider -import model_setup +from model_setup import setup_model, get_model_db_provider from . import logs def get_config(config_file=None): @@ -60,7 +60,7 @@ def replace_with_env(config): def get_db_provider(config=None, blocking_auth=True): - db_provider = model_setup.get_db_provider() + db_provider = get_model_db_provider() if not db_provider is None: return db_provider @@ -120,10 +120,10 @@ def get_db_provider(config=None, blocking_auth=True): artifact_store = db_provider.get_artifact_store() else: - model_setup._model_setup = None + _model_setup = None raise ValueError('Unknown type of the database ' + db_config['type']) - model_setup.setup_model(db_provider, artifact_store) + setup_model(db_provider, artifact_store) return db_provider def parse_verbosity(verbosity=None): diff --git a/studio/model_setup.py b/studio/model_setup.py index cf676a5d..05f0b945 100644 --- a/studio/model_setup.py +++ b/studio/model_setup.py @@ -9,12 +9,12 @@ def setup_model(db_provider, artifact_store): _model_setup = { DB_KEY: db_provider, STORE_KEY: artifact_store } -def get_db_provider(): +def get_model_db_provider(): if _model_setup is None: return None return _model_setup.get(DB_KEY, None) -def get_artifact_store(): +def get_model_artifact_store(): if _model_setup is None: return None return _model_setup.get(STORE_KEY, None) From 7b90449cb5055bcb9b9da75a0f549ae65c9bf429 Mon Sep 17 00:00:00 2001 From: Andrei Denissov Date: Thu, 2 Jul 2020 17:44:50 -0700 Subject: [PATCH 9/9] One more fix. --- studio/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/studio/model.py b/studio/model.py index ff5600b0..3f1cc8a5 100644 --- a/studio/model.py +++ b/studio/model.py @@ -18,7 +18,7 @@ from .local_db_provider import LocalDbProvider from .s3_provider import S3Provider from .gs_provider import GSProvider -from model_setup import setup_model, get_model_db_provider +from .model_setup import setup_model, get_model_db_provider from . import logs def get_config(config_file=None):