diff --git a/tests/trestlebot/entrypoints/test_entrypoint_base.py b/tests/trestlebot/entrypoints/test_entrypoint_base.py index e47f14a4..8aebaa3a 100644 --- a/tests/trestlebot/entrypoints/test_entrypoint_base.py +++ b/tests/trestlebot/entrypoints/test_entrypoint_base.py @@ -23,16 +23,16 @@ @patch.dict("os.environ", {"GITHUB_ACTIONS": "true"}) def test_set_git_provider_with_github() -> None: """Test set_git_provider function in Entrypoint Base for GitHub Actions""" - provider: Optional[GitProvider] - fake_token = StringIO("fake_token") - args = argparse.Namespace( - target_branch="main", - with_token=fake_token, - git_provider_type="", - git_server_url="", - ) - provider = EntrypointBase.set_git_provider(args=args) - assert isinstance(provider, GitHub) + with patch("sys.stdin", return_value=StringIO("fake_token")): + provider: Optional[GitProvider] + args = argparse.Namespace( + target_branch="main", + with_token=True, + git_provider_type="", + git_server_url="", + ) + provider = EntrypointBase.set_git_provider(args=args) + assert isinstance(provider, GitHub) @patch.dict( @@ -46,100 +46,100 @@ def test_set_git_provider_with_github() -> None: ) def test_set_git_provider_with_gitlab() -> None: """Test set_git_provider function in Entrypoint Base for GitLab CI""" - provider: Optional[GitProvider] - fake_token = StringIO("fake_token") - args = argparse.Namespace( - target_branch="main", - with_token=fake_token, - git_provider_type="", - git_server_url="", - ) - provider = EntrypointBase.set_git_provider(args=args) - assert isinstance(provider, GitLab) + with patch("sys.stdin", return_value=StringIO("fake_token")): + provider: Optional[GitProvider] + args = argparse.Namespace( + target_branch="main", + with_token=True, + git_provider_type="", + git_server_url="", + ) + provider = EntrypointBase.set_git_provider(args=args) + assert isinstance(provider, GitLab) @patch.dict("os.environ", {"GITHUB_ACTIONS": "false", "GITLAB_CI": "true"}) def test_set_git_provider_with_gitlab_with_failure() -> None: """Trigger error with GitLab provider with insufficient environment variables""" - fake_token = StringIO("fake_token") - args = argparse.Namespace( - target_branch="main", - with_token=fake_token, - git_provider_type="", - git_server_url="", - ) - with pytest.raises( - GitProviderException, - match="Set CI_SERVER_PROTOCOL and CI SERVER HOST environment variables", - ): - EntrypointBase.set_git_provider(args=args) + with patch("sys.stdin", return_value=StringIO("fake_token")): + args = argparse.Namespace( + target_branch="main", + with_token=True, + git_provider_type="", + git_server_url="", + ) + with pytest.raises( + GitProviderException, + match="Set CI_SERVER_PROTOCOL and CI SERVER HOST environment variables", + ): + EntrypointBase.set_git_provider(args=args) @patch.dict("os.environ", {"GITHUB_ACTIONS": "false"}) def test_set_git_provider_with_none() -> None: """Test set_git_provider function when no git provider is set""" - fake_token = StringIO("fake_token") - provider: Optional[GitProvider] - args = argparse.Namespace( - target_branch="main", - with_token=fake_token, - git_provider_type="", - git_server_url="", - ) - - with pytest.raises( - EntrypointInvalidArgException, - match="Invalid args --target-branch, --git-provider-type: " - "Could not detect Git provider from environment or inputs", - ): - EntrypointBase.set_git_provider(args=args) - - # Now test with no target branch which is a valid case - args = argparse.Namespace(target_branch=None, with_token=None) - provider = EntrypointBase.set_git_provider(args=args) - assert provider is None + with patch("sys.stdin", return_value=StringIO("fake_token")): + provider: Optional[GitProvider] + args = argparse.Namespace( + target_branch="main", + with_token=True, + git_provider_type="", + git_server_url="", + ) + + with pytest.raises( + EntrypointInvalidArgException, + match="Invalid args --target-branch, --git-provider-type: " + "Could not detect Git provider from environment or inputs", + ): + EntrypointBase.set_git_provider(args=args) + + # Now test with no target branch which is a valid case + args = argparse.Namespace(target_branch=None) + provider = EntrypointBase.set_git_provider(args=args) + assert provider is None def test_set_provider_with_no_token() -> None: """Test set_git_provider function with no token""" - args = argparse.Namespace(target_branch="main", with_token=None) + args = argparse.Namespace(target_branch="main", with_token=False) with pytest.raises( EntrypointInvalidArgException, - match="Invalid args --with-token: " - "with-token flag must be set when using target-branch", + match="Invalid args --with-token: with-token flag must be set to read from " + "standard input when using target-branch", ): EntrypointBase.set_git_provider(args=args) def test_set_provider_with_input() -> None: """Test set_git_provider function with type and server url input.""" - provider: Optional[GitProvider] - fake_token = StringIO("fake_token") - args = argparse.Namespace( - target_branch="main", - with_token=fake_token, - git_provider_type="github", - git_server_url="", - ) - provider = EntrypointBase.set_git_provider(args=args) - assert isinstance(provider, GitHub) - args = argparse.Namespace( - target_branch="main", - with_token=fake_token, - git_provider_type="gitlab", - git_server_url="", - ) - provider = EntrypointBase.set_git_provider(args=args) - assert isinstance(provider, GitLab) - - args = argparse.Namespace( - target_branch="main", - with_token=fake_token, - git_provider_type="github", - git_server_url="https://notgithub.com", - ) - with pytest.raises( - EntrypointInvalidArgException, - match="Invalid args --server-url: GitHub provider does not support custom server URLs", - ): - EntrypointBase.set_git_provider(args=args) + with patch("sys.stdin", return_value=StringIO("fake_token")): + provider: Optional[GitProvider] + args = argparse.Namespace( + target_branch="main", + with_token=True, + git_provider_type="github", + git_server_url="", + ) + provider = EntrypointBase.set_git_provider(args=args) + assert isinstance(provider, GitHub) + args = argparse.Namespace( + target_branch="main", + with_token=True, + git_provider_type="gitlab", + git_server_url="", + ) + provider = EntrypointBase.set_git_provider(args=args) + assert isinstance(provider, GitLab) + + args = argparse.Namespace( + target_branch="main", + with_token=True, + git_provider_type="github", + git_server_url="https://notgithub.com", + ) + with pytest.raises( + EntrypointInvalidArgException, + match="Invalid args --server-url: GitHub provider does not support custom server URLs", + ): + EntrypointBase.set_git_provider(args=args) diff --git a/trestlebot/entrypoints/entrypoint_base.py b/trestlebot/entrypoints/entrypoint_base.py index 09e00012..93bfbd26 100644 --- a/trestlebot/entrypoints/entrypoint_base.py +++ b/trestlebot/entrypoints/entrypoint_base.py @@ -133,10 +133,9 @@ def _set_git_provider_args(self) -> None: ) git_provider_arg_group.add_argument( "--with-token", - nargs="?", - type=argparse.FileType("r"), required=False, - default=sys.stdin, + default=False, + action="store_true", help="Read token from standard input for authenticated requests with \ Git provider (e.g. create pull requests)", ) @@ -151,7 +150,7 @@ def _set_git_provider_args(self) -> None: "--git-provider-type", required=False, choices=[const.GITHUB, const.GITLAB], - help="Optional supported Git provider identify. " + help="Optional supported Git provider to identify. " "Defaults to auto detection based on pre-defined CI environment variables.", ) git_provider_arg_group.add_argument( @@ -166,13 +165,15 @@ def _set_git_provider_args(self) -> None: def set_git_provider(args: argparse.Namespace) -> Optional[GitProvider]: """Get the git provider based on the environment and args.""" git_provider: Optional[GitProvider] = None - if args.target_branch: + if args.target_branch is not None: if not args.with_token: raise EntrypointInvalidArgException( "--with-token", - "with-token flag must be set when using target-branch", + "with-token flag must be set to read from standard input when " + "using target-branch", ) - access_token = args.with_token.read().strip() + else: + access_token = sys.stdin.read().strip() try: git_provider_type = args.git_provider_type git_server_url = args.git_server_url