-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
189 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build" | |
|
||
[project] | ||
name = "nypl_py_utils" | ||
version = "1.3.0" | ||
version = "1.4.0" | ||
authors = [ | ||
{ name="Aaron Friedman", email="[email protected]" }, | ||
] | ||
|
@@ -60,6 +60,9 @@ secrets-manager-client = [ | |
"boto3>=1.26.5", | ||
"botocore>=1.29.5" | ||
] | ||
sftp-client = [ | ||
"paramiko>=3.4.1" | ||
] | ||
config-helper = [ | ||
"nypl_py_utils[kms-client]", | ||
"PyYAML>=6.0" | ||
|
@@ -71,7 +74,7 @@ research-catalog-identifier-helper = [ | |
"requests>=2.28.1" | ||
] | ||
development = [ | ||
"nypl_py_utils[avro-client,kinesis-client,kms-client,mysql-client,oauth2-api-client,postgresql-client,postgresql-pool-client,redshift-client,s3-client,secrets-manager-client,config-helper,obfuscation-helper,research-catalog-identifier-helper]", | ||
"nypl_py_utils[avro-client,kinesis-client,kms-client,mysql-client,oauth2-api-client,postgresql-client,postgresql-pool-client,redshift-client,s3-client,secrets-manager-client,sftp-client,config-helper,obfuscation-helper,research-catalog-identifier-helper]", | ||
"flake8>=6.0.0", | ||
"freezegun>=1.2.2", | ||
"mock>=4.0.3", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from base64 import b64decode | ||
from io import StringIO | ||
from nypl_py_utils.functions.log_helper import create_log | ||
from paramiko import PKey, RSAKey, SSHClient | ||
from paramiko.ssh_exception import SSHException | ||
|
||
|
||
class SftpClient: | ||
"""Client for interacting with a remote SSH server via SFTP""" | ||
|
||
def __init__(self, host, user, password=None, private_key_str=None): | ||
self.logger = create_log("sftp_client") | ||
self.host = host | ||
self.user = user | ||
self.password = password | ||
self.private_key_str = private_key_str | ||
self.ssh_client = SSHClient() | ||
|
||
def add_host_key(self, key_type, public_key): | ||
try: | ||
public_key = PKey.from_type_string(key_type, b64decode(public_key)) | ||
self.ssh_client.get_host_keys().add( | ||
hostname=self.host, keytype=key_type, key=public_key | ||
) | ||
except Exception as e: | ||
self.logger.warning(f"Failed to load host key: {e}") | ||
|
||
def connect(self): | ||
"""Connects to a remote server using SSH""" | ||
self.logger.info("Connecting to {}".format(self.host)) | ||
pkey = None | ||
try: | ||
if self.private_key_str: | ||
pkey = RSAKey.from_private_key(StringIO(self.private_key_str)) | ||
self.ssh_client.connect(self.host, username=self.user, | ||
password=self.password, pkey=pkey) | ||
self.sftp_conn = self.ssh_client.open_sftp() | ||
except SSHException as e: | ||
self.logger.error( | ||
"Error connecting to {host}: {error}".format( | ||
host=self.host, error=e) | ||
) | ||
raise SftpClientError( | ||
"Error connecting to {host}: {error}".format( | ||
host=self.host, error=e) | ||
) from None | ||
|
||
def download(self, remote_path, local_path): | ||
"""Downloads a file on the remote server to the local machine""" | ||
self.logger.info( | ||
"Downloading {remote} file as {local}".format( | ||
remote=remote_path, local=local_path | ||
) | ||
) | ||
try: | ||
self.sftp_conn.get(remote_path, local_path) | ||
except Exception as e: | ||
self.logger.error("Error downloading file: {}".format(e)) | ||
self.close_connection() | ||
raise SftpClientError( | ||
"Error downloading file: {}".format(e)) from None | ||
|
||
def close_connection(self): | ||
"""Closes the connection""" | ||
self.logger.debug("Closing connection to {}".format(self.host)) | ||
self.sftp_conn.close() | ||
self.ssh_client.close() | ||
|
||
|
||
class SftpClientError(Exception): | ||
def __init__(self, message=None): | ||
self.message = message |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import pytest | ||
|
||
from nypl_py_utils.classes.sftp_client import SftpClient, SftpClientError | ||
|
||
_TEST_PUBLIC_KEY = ( | ||
'AAAAB3NzaC1yc2EAAAADAQABAAAAgQCHc5r1z7bCxJ+dwR4r65CKB4KBF6mB+VZNYPc/1kmyT' | ||
'vRh+P89asNvGDwATw7FZkz+g/0Z/Arak2ae454AHW7gBRO+TJ6YoAIrH2H5O3vQ4GGOepcTz3' | ||
'0ckuLoXtoaRMYzDTM1juvnITFq9fE5RMeFIM+Qc7BhOub/nDPLQI7/sw==' | ||
) | ||
|
||
_TEST_PRIVATE_KEY = ( | ||
'-----BEGIN RSA PRIVATE KEY-----\nMIICWwIBAAKBgQCItzqS6yQYBq+923wf4pQ6M2u0' | ||
'pNMknrO4itBBQiDO6uDktZn2\nONnF1L9bYCtsucBGmRes6gdn+qFGTFRa+mWBHBO5CtOhbxA' | ||
'bH9K4MWi9B6fF6Riw\nUkhOIsXHQFPtPg23kF+0MV953CrhZMMdWmYh4EVaRFfRmQchsjJkP0' | ||
'eqBQIDAQAB\nAoGAEC+ZOLGsGUgZYGHu5Rt/LxDNbJqjAM/lOTD+DOvWVIkMTSeO7c63Qau5a' | ||
'AkP\nuxSWxgTz/53JeK78jwUUa5z/jUbD+4D0NbfjmFOXGlnVxs/kbx4z4tPwwArN6gMS\n7T' | ||
'fuEDgx4RF4a5kl5hOwDV1RUUCJ2TBO9wbm533ca7TvcCECQQDy3pKOB1ae9HM/\nYgtR6z1k0' | ||
'd734ujmDXpViESfvJpm+fd/o0MEh193cO9qGFDWiOU23axF/n5fIaaf\nhHt/8C/dAkEAkBtw' | ||
'bdQGDN9eZKH4XX1pRvB2PzUmrpgzZl3Zst8svKPDjeD9nm0Z\n+pGFLcVCIFT8ddUH1LSbt96' | ||
'a4wn5/dPUSQJAUs2fmdzWo4skX8/FnEBfxifnpQwv\n639c3hx/iRZ8be97eoDnMHwXCFnwxn' | ||
'NT3FEAFRyux45k93o5nNlGYfA54QJAKIwP\n7lch/K082gPY5jVLUfKG0vIZmDaq/7qYboPtC' | ||
'obplxofQlxgWuhnGKHQIVjIUD9I\nnMjUp7+yxP8hoBHiQQJAZsNUg/q1JNCEoa4Gqb89yygr' | ||
'x2fFOC/6eNp0ruWMRr5P\n8x1L+ugdXeUfI5vH7qI9wU+A7oADke63JBEHavv0UQ==\n-----' | ||
'END RSA PRIVATE KEY-----' | ||
) | ||
|
||
|
||
class TestSftpClient: | ||
|
||
@pytest.fixture | ||
def test_instance(self, mocker): | ||
mocker.patch('paramiko.SSHClient.connect') | ||
mocker.patch('paramiko.SSHClient.open_sftp') | ||
return SftpClient('test_host', 'test_user') | ||
|
||
def test_add_host_key(self, test_instance): | ||
assert len(test_instance.ssh_client.get_host_keys().keys()) == 0 | ||
|
||
test_instance.add_host_key('ssh-rsa', _TEST_PUBLIC_KEY) | ||
|
||
assert len(test_instance.ssh_client.get_host_keys().keys()) == 1 | ||
assert test_instance.ssh_client.get_host_keys().lookup( | ||
'test_host') is not None | ||
|
||
def test_connect_password(self, test_instance): | ||
test_instance.password = 'test_password' | ||
|
||
test_instance.connect() | ||
|
||
test_instance.ssh_client.connect.assert_called_once_with( | ||
'test_host', username='test_user', password='test_password', | ||
pkey=None) | ||
test_instance.ssh_client.open_sftp.assert_called_once() | ||
assert test_instance.sftp_conn is not None | ||
|
||
def test_connect_pkey(self, test_instance, mocker): | ||
mock_rsa_key = mocker.MagicMock() | ||
mock_pkey_method = mocker.patch('paramiko.RSAKey.from_private_key', | ||
return_value=mock_rsa_key) | ||
test_instance.private_key_str = _TEST_PRIVATE_KEY | ||
|
||
test_instance.connect() | ||
|
||
assert mock_pkey_method.call_args[0][0].read() == _TEST_PRIVATE_KEY | ||
test_instance.ssh_client.connect.assert_called_once_with( | ||
'test_host', username='test_user', password=None, | ||
pkey=mock_rsa_key) | ||
test_instance.ssh_client.open_sftp.assert_called_once() | ||
assert test_instance.sftp_conn is not None | ||
|
||
def test_download(self, test_instance, mocker): | ||
test_instance.sftp_conn = mocker.MagicMock() | ||
|
||
test_instance.download('remote/path', 'local/path') | ||
|
||
test_instance.sftp_conn.get.assert_called_once_with( | ||
'remote/path', 'local/path') | ||
|
||
def test_download_error(self, test_instance, mocker): | ||
test_instance.ssh_client = mocker.MagicMock() | ||
test_instance.sftp_conn = mocker.MagicMock() | ||
test_instance.sftp_conn.get.side_effect = IOError('test error') | ||
|
||
with pytest.raises(SftpClientError): | ||
test_instance.download('remote/path', 'local/path') | ||
|
||
test_instance.sftp_conn.get.assert_called_once_with( | ||
'remote/path', 'local/path') | ||
test_instance.sftp_conn.close.assert_called_once() | ||
test_instance.ssh_client.close.assert_called_once() | ||
|
||
def test_close_connection(self, test_instance, mocker): | ||
test_instance.sftp_conn = mocker.MagicMock() | ||
test_instance.ssh_client = mocker.MagicMock() | ||
|
||
test_instance.close_connection() | ||
|
||
test_instance.sftp_conn.close.assert_called_once() | ||
test_instance.ssh_client.close.assert_called_once() |