diff --git a/tests/trestlebot/test_gitlab.py b/tests/trestlebot/test_gitlab.py index 669dca31..a2b6b85e 100644 --- a/tests/trestlebot/test_gitlab.py +++ b/tests/trestlebot/test_gitlab.py @@ -94,7 +94,7 @@ def test_parse_repository_with_incorrect_name() -> None: gl = GitLab("fake") with pytest.raises( GitProviderException, - match="https://notgitlab.com/owner/repo.git is an invalid repo URL", + match="https://notgitlab.com/owner/repo.git is an invalid Gitlab repo URL", ): gl.parse_repository("https://notgitlab.com/owner/repo.git") diff --git a/trestlebot/github.py b/trestlebot/github.py index 5fa7ffe3..a12fcf8e 100644 --- a/trestlebot/github.py +++ b/trestlebot/github.py @@ -31,7 +31,15 @@ def __init__(self, access_token: str): session.login(token=access_token) self._session = session - self.pattern = r"^(?:https?://)?github\.com/([^/]+)/([^/.]+)" + + # For repo URL input validation + pattern = r"^(?:https?://)?github\.com/([^/]+)/([^/.]+)" + self._pattern = re.compile(pattern) + + @property + def provider_pattern(self) -> re.Pattern[str]: + """Regex pattern to validate repository URLs""" + return self._pattern def parse_repository(self, repo_url: str) -> Tuple[str, str]: """ @@ -43,11 +51,12 @@ def parse_repository(self, repo_url: str) -> Tuple[str, str]: Returns: Owner and repo name in a tuple, respectively """ - - match = re.match(self.pattern, repo_url) + match: Optional[re.Match[str]] + stripped_url: str + match, stripped_url = self.match_url(repo_url) if not match: - raise GitProviderException(f"{repo_url} is an invalid GitHub repo URL") + raise GitProviderException(f"{stripped_url} is an invalid GitHub repo URL") owner = match.group(1) repo = match.group(2) diff --git a/trestlebot/gitlab.py b/trestlebot/gitlab.py index 5d8c58b9..5d26098e 100644 --- a/trestlebot/gitlab.py +++ b/trestlebot/gitlab.py @@ -7,7 +7,8 @@ import os import re import time -from typing import Tuple +from typing import Optional, Tuple +from urllib.parse import ParseResult, urlparse import gitlab @@ -21,10 +22,16 @@ def __init__(self, api_token: str, server_url: str = "https://gitlab.com"): self._gitlab_client = gitlab.Gitlab(server_url, private_token=api_token) - stripped_url = re.sub(r"^(https?://)?", "", server_url) - self.pattern = r"^(?:https?://)?{0}(/.+)/([^/.]+)(\.git)?$".format( - re.escape(stripped_url) - ) + # For repo URL input validation + parsed_url: ParseResult = urlparse(server_url) + stripped_url = f"{parsed_url.netloc}{parsed_url.path}" + pattern = rf"^(?:https?://)?{re.escape(stripped_url)}(/.+)/([^/.]+)(\.git)?$" + self._pattern = re.compile(pattern) + + @property + def provider_pattern(self) -> re.Pattern[str]: + """Regex pattern to validate repository URLs""" + return self._pattern def parse_repository(self, repo_url: str) -> Tuple[str, str]: """ @@ -37,13 +44,12 @@ def parse_repository(self, repo_url: str) -> Tuple[str, str]: Owner and project name in a tuple, respectively """ - # Strip out any basic auth - stripped_url = re.sub(r"https?://.*?@", "https://", repo_url) - - match = re.match(self.pattern, stripped_url) + match: Optional[re.Match[str]] + stripped_url: str + match, stripped_url = self.match_url(repo_url) if not match: - raise GitProviderException(f"{stripped_url} is an invalid repo URL") + raise GitProviderException(f"{stripped_url} is an invalid Gitlab repo URL") owner = match.group(1)[1:] # Removing the leading slash repo = match.group(2) diff --git a/trestlebot/provider.py b/trestlebot/provider.py index 7fdef153..7a67b6c4 100644 --- a/trestlebot/provider.py +++ b/trestlebot/provider.py @@ -4,8 +4,10 @@ """Base Git Provider class for the Trestle Bot.""" +import re from abc import ABC, abstractmethod -from typing import Tuple +from typing import Optional, Tuple +from urllib.parse import ParseResult, urlparse class GitProviderException(Exception): @@ -17,6 +19,25 @@ class GitProvider(ABC): Abstract base class for Git provider types """ + @property + @abstractmethod + def provider_pattern(self) -> re.Pattern[str]: + """Regex pattern to validate repository URLs""" + + def match_url(self, repo_url: str) -> Tuple[Optional[re.Match[str]], str]: + """Match a repository URL with the pattern""" + parsed_url: ParseResult = urlparse(repo_url) + scheme = parsed_url.scheme + host = parsed_url.hostname + path = parsed_url.path + + stripped_url = path + if host: + stripped_url = f"{host}{path}" + if scheme: + stripped_url = f"{scheme}://{stripped_url}" + return self.provider_pattern.match(stripped_url), stripped_url + @abstractmethod def parse_repository(self, repository_url: str) -> Tuple[str, str]: """Parse repository information into namespace and repo, respectively"""