diff --git a/crmsh/bootstrap.py b/crmsh/bootstrap.py index ff00e44f6..60001b902 100644 --- a/crmsh/bootstrap.py +++ b/crmsh/bootstrap.py @@ -1523,19 +1523,21 @@ def _setup_passwordless_ssh_for_qnetd(cluster_node_list: typing.List[str]): 'root', )).add(qnetd_addr, qnetd_user, key) else: - if utils.check_ssh_passwd_need(local_user, qnetd_user, qnetd_addr): - if 0 != utils.ssh_copy_id_no_raise(local_user, qnetd_user, qnetd_addr): - msg = f"Failed to login to {qnetd_user}@{qnetd_addr}. Please check the credentials." - sudoer = userdir.get_sudoer() - if sudoer and qnetd_user != sudoer: - args = ['sudo crm'] - args += [x for x in sys.argv[1:]] - for i, arg in enumerate(args): - if arg == '--qnetd-hostname' and i + 1 < len(args): - if '@' not in args[i + 1]: - args[i + 1] = f'{sudoer}@{qnetd_addr}' - msg += '\nOr, run "{}".'.format(' '.join(args)) - raise ValueError(msg) + if 0 != utils.ssh_copy_id_no_raise( + local_user, qnetd_user, qnetd_addr, + sh.LocalShell(additional_environ={'SSH_AUTH_SOCK': ''}), + ): + msg = f"Failed to login to {qnetd_user}@{qnetd_addr}. Please check the credentials." + sudoer = userdir.get_sudoer() + if sudoer and qnetd_user != sudoer: + args = ['sudo crm'] + args += [x for x in sys.argv[1:]] + for i, arg in enumerate(args): + if arg == '--qnetd-hostname' and i + 1 < len(args): + if '@' not in args[i + 1]: + args[i + 1] = f'{sudoer}@{qnetd_addr}' + msg += '\nOr, run "{}".'.format(' '.join(args)) + raise ValueError(msg) cluster_shell = sh.cluster_shell() # Add other nodes' public keys to qnetd's authorized_keys @@ -1609,9 +1611,9 @@ def join_ssh_impl(local_user, seed_host, seed_user, ssh_public_keys: typing.List local_shell = sh.LocalShell(additional_environ={'SSH_AUTH_SOCK': os.environ.get('SSH_AUTH_SOCK')}) join_ssh_with_ssh_agent(local_shell, local_user, seed_host, seed_user, ssh_public_keys) else: - local_shell = sh.LocalShell() + local_shell = sh.LocalShell(additional_environ={'SSH_AUTH_SOCK': ''}) configure_ssh_key(local_user) - if 0 != utils.ssh_copy_id_no_raise(local_user, seed_user, seed_host): + if 0 != utils.ssh_copy_id_no_raise(local_user, seed_user, seed_host, local_shell): msg = f"Failed to login to {seed_user}@{seed_host}. Please check the credentials." sudoer = userdir.get_sudoer() if sudoer and seed_user != sudoer: @@ -2633,7 +2635,10 @@ def bootstrap_join_geo(context): join_ssh_with_ssh_agent(local_shell, local_user, node, remote_user, keys) else: configure_ssh_key(local_user) - if 0 != utils.ssh_copy_id_no_raise(local_user, remote_user, node): + if 0 != utils.ssh_copy_id_no_raise( + local_user, remote_user, node, + sh.LocalShell(additional_environ={'SSH_AUTH_SOCK': ''}), + ): raise ValueError(f"Failed to login to {remote_user}@{node}. Please check the credentials.") swap_public_ssh_key(node, local_user, remote_user, local_user, remote_user, add=True) user_by_host = utils.HostUserConfig() @@ -2669,7 +2674,10 @@ def bootstrap_arbitrator(context): join_ssh_with_ssh_agent(local_shell, local_user, node, remote_user, keys) else: configure_ssh_key(local_user) - if 0 != utils.ssh_copy_id_no_raise(local_user, remote_user, node): + if 0 != utils.ssh_copy_id_no_raise( + local_user, remote_user, node, + sh.LocalShell(additional_environ={'SSH_AUTH_SOCK': ''}), + ): raise ValueError(f"Failed to login to {remote_user}@{node}. Please check the credentials.") swap_public_ssh_key(node, local_user, remote_user, local_user, remote_user, add=True) user_by_host.add(local_user, utils.this_node()) diff --git a/crmsh/utils.py b/crmsh/utils.py index 0b19e10b2..350bb38b7 100644 --- a/crmsh/utils.py +++ b/crmsh/utils.py @@ -132,11 +132,13 @@ def user_pair_for_ssh(host): raise ValueError('Can not create ssh session from {} to {}.'.format(this_node(), host)) -def ssh_copy_id_no_raise(local_user, remote_user, remote_node): - if check_ssh_passwd_need(local_user, remote_user, remote_node): +def ssh_copy_id_no_raise(local_user, remote_user, remote_node, shell: sh.LocalShell = None): + if shell is None: + shell = sh.LocalShell() + if check_ssh_passwd_need(local_user, remote_user, remote_node, shell): logger.info("Configuring SSH passwordless with {}@{}".format(remote_user, remote_node)) cmd = "ssh-copy-id -i ~/.ssh/id_rsa.pub '{}@{}' &> /dev/null".format(remote_user, remote_node) - result = sh.LocalShell().su_subprocess_run(local_user, cmd, tty=True) + result = shell.su_subprocess_run(local_user, cmd, tty=True) return result.returncode else: return 0 @@ -2095,23 +2097,18 @@ def debug_timestamp(): return datetime.datetime.now().strftime('%Y/%m/%d %H:%M:%S') -def check_ssh_passwd_need(local_user, remote_user, host): +def check_ssh_passwd_need(local_user, remote_user, host, shell: sh.LocalShell = None): """ Check whether access to host need password """ ssh_options = "-o StrictHostKeyChecking=no -o EscapeChar=none -o ConnectTimeout=15" - ssh_cmd = "{} ssh {} -T -o Batchmode=yes {}@{} true".format(get_ssh_agent_str(), ssh_options, remote_user, host) - rc, _ = sh.LocalShell().get_rc_and_error(local_user, ssh_cmd) + ssh_cmd = "ssh {} -T -o Batchmode=yes {}@{} true".format(ssh_options, remote_user, host) + if shell is None: + shell = sh.LocalShell() + rc, _ = shell.get_rc_and_error(local_user, ssh_cmd) return rc != 0 -def get_ssh_agent_str(): - ssh_agent_str = "" - if crmsh.user_of_host.instance().use_ssh_agent(): - ssh_agent_str = f"SSH_AUTH_SOCK={os.environ.get('SSH_AUTH_SOCK')}" - return ssh_agent_str - - def check_port_open(ip, port): import socket diff --git a/test/unittests/test_bootstrap.py b/test/unittests/test_bootstrap.py index 2ad3ec762..416c1a295 100644 --- a/test/unittests/test_bootstrap.py +++ b/test/unittests/test_bootstrap.py @@ -585,10 +585,11 @@ def test_join_ssh_no_seed_host(self, mock_error): @mock.patch('crmsh.bootstrap.swap_public_ssh_key') @mock.patch('crmsh.utils.ssh_copy_id_no_raise') @mock.patch('crmsh.bootstrap.configure_ssh_key') + @mock.patch('crmsh.sh.LocalShell') @mock.patch('crmsh.service_manager.ServiceManager.start_service') def test_join_ssh( self, - mock_start_service, mock_config_ssh, mock_ssh_copy_id, mock_swap, + mock_start_service, mock_local_shell, mock_config_ssh, mock_ssh_copy_id, mock_swap, mock_ssh_shell, mock_change, mock_swap_2, mock_get_node_cononical_hostname, @@ -604,11 +605,15 @@ def test_join_ssh( bootstrap.join_ssh("node1", "alice") mock_start_service.assert_called_once_with("sshd.service", enable=True) + mock_local_shell: mock.MagicMock + mock_local_shell.assert_has_calls([ + mock.call(additional_environ={'SSH_AUTH_SOCK': ''}), + ]) mock_config_ssh.assert_has_calls([ mock.call("bob"), mock.call("hacluster"), ]) - mock_ssh_copy_id.assert_called_once_with("bob", "alice", "node1") + mock_ssh_copy_id.assert_called_once_with("bob", "alice", "node1", mock_local_shell.return_value) mock_subprocess_run_without_input.assert_called_once_with( 'node1', 'alice', 'sudo true', stdout=subprocess.DEVNULL, @@ -659,8 +664,9 @@ def test_swap_public_ssh_key_for_secondary_user( @mock.patch('crmsh.bootstrap.swap_public_ssh_key') @mock.patch('crmsh.utils.ssh_copy_id_no_raise') @mock.patch('crmsh.bootstrap.configure_ssh_key') + @mock.patch('crmsh.sh.LocalShell') @mock.patch('crmsh.service_manager.ServiceManager.start_service') - def test_join_ssh_bad_credential(self, mock_start_service, mock_config_ssh, mock_ssh_copy_id, mock_swap, mock_invoke, mock_change): + def test_join_ssh_bad_credential(self, mock_start_service, mock_local_shell, mock_config_ssh, mock_ssh_copy_id, mock_swap, mock_invoke, mock_change): bootstrap._context = mock.Mock(current_user="bob", default_nic="eth1", use_ssh_agent=False) mock_invoke.return_value = '' mock_swap.return_value = None @@ -670,10 +676,13 @@ def test_join_ssh_bad_credential(self, mock_start_service, mock_config_ssh, mock bootstrap.join_ssh("node1", "alice") mock_start_service.assert_called_once_with("sshd.service", enable=True) + mock_local_shell.assert_has_calls([ + mock.call(additional_environ={'SSH_AUTH_SOCK': ''}), + ]) mock_config_ssh.assert_has_calls([ mock.call("bob"), ]) - mock_ssh_copy_id.assert_called_once_with("bob", "alice", "node1") + mock_ssh_copy_id.assert_called_once_with("bob", "alice", "node1", mock_local_shell.return_value) mock_swap.assert_not_called() mock_invoke.assert_not_called() @@ -1046,10 +1055,11 @@ def test_init_qdevice_no_config(self, mock_status, mock_disable): @mock.patch('crmsh.corosync.is_qdevice_configured') @mock.patch('crmsh.bootstrap.configure_ssh_key') @mock.patch('crmsh.utils.check_ssh_passwd_need') + @mock.patch('crmsh.sh.LocalShell') @mock.patch('logging.Logger.info') def test_init_qdevice_already_configured( self, - mock_status, mock_ssh, mock_configure_ssh_key, + mock_status, mock_local_shell, mock_ssh, mock_configure_ssh_key, mock_qdevice_configured, mock_confirm, mock_list_nodes, mock_user_of_host, mock_host_user_config_class, mock_select_user_pair_for_ssh, @@ -1067,7 +1077,11 @@ def test_init_qdevice_already_configured( bootstrap.init_qdevice() mock_status.assert_called_once_with("Configure Qdevice/Qnetd:") - mock_ssh.assert_called_once_with("bob", "bob", "qnetd-node") + mock_local_shell.assert_has_calls([ + mock.call(additional_environ={'SSH_AUTH_SOCK': ''}), + mock.call(), + ]) + mock_ssh.assert_called_once_with("bob", "bob", "qnetd-node", mock_local_shell.return_value) mock_configure_ssh_key.assert_not_called() mock_host_user_config_class.return_value.save_remote.assert_called_once_with(mock_list_nodes.return_value) mock_qdevice_configured.assert_called_once_with() @@ -1084,8 +1098,9 @@ def test_init_qdevice_already_configured( @mock.patch('crmsh.corosync.is_qdevice_configured') @mock.patch('crmsh.bootstrap.configure_ssh_key') @mock.patch('crmsh.utils.check_ssh_passwd_need') + @mock.patch('crmsh.sh.LocalShell') @mock.patch('logging.Logger.info') - def test_init_qdevice(self, mock_info, mock_ssh, mock_configure_ssh_key, mock_qdevice_configured, + def test_init_qdevice(self, mock_info, mock_local_shell, mock_ssh, mock_configure_ssh_key, mock_qdevice_configured, mock_this_node, mock_list_nodes, mock_adjust_priority, mock_adjust_fence_delay, mock_user_of_host, mock_host_user_config_class, mock_select_user_pair_for_ssh): bootstrap._context = mock.Mock(qdevice_inst=self.qdevice_with_ip, current_user="bob") @@ -1103,7 +1118,11 @@ def test_init_qdevice(self, mock_info, mock_ssh, mock_configure_ssh_key, mock_qd bootstrap.init_qdevice() mock_info.assert_called_once_with("Configure Qdevice/Qnetd:") - mock_ssh.assert_called_once_with("bob", "bob", "qnetd-node") + mock_local_shell.assert_has_calls([ + mock.call(additional_environ={'SSH_AUTH_SOCK': ''}), + mock.call(), + ]) + mock_ssh.assert_called_once_with("bob", "bob", "qnetd-node", mock_local_shell.return_value) mock_host_user_config_class.return_value.add.assert_has_calls([ mock.call('bob', '192.0.2.100'), mock.call('bob', 'qnetd-node'), diff --git a/test/unittests/test_utils.py b/test/unittests/test_utils.py index c4f60b579..65db18f19 100644 --- a/test/unittests/test_utils.py +++ b/test/unittests/test_utils.py @@ -72,7 +72,7 @@ def test_check_ssh_passwd_need(mock_run): assert res is True mock_run.assert_called_once_with( "bob", - " ssh -o StrictHostKeyChecking=no -o EscapeChar=none -o ConnectTimeout=15 -T -o Batchmode=yes alice@node1 true", + "ssh -o StrictHostKeyChecking=no -o EscapeChar=none -o ConnectTimeout=15 -T -o Batchmode=yes alice@node1 true", )