diff --git a/src/prefect/runner/storage.py b/src/prefect/runner/storage.py index 6de697d2462b..4f9cef4310b4 100644 --- a/src/prefect/runner/storage.py +++ b/src/prefect/runner/storage.py @@ -2,7 +2,15 @@ import subprocess from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional, Protocol, TypedDict, Union, runtime_checkable +from typing import ( + Any, + Dict, + Optional, + Protocol, + TypedDict, + Union, + runtime_checkable, +) from urllib.parse import urlparse, urlsplit, urlunparse from uuid import uuid4 @@ -147,11 +155,9 @@ def pull_interval(self) -> Optional[int]: return self._pull_interval @property - def _repository_url_with_credentials(self) -> str: + def _formatted_credentials(self) -> Optional[str]: if not self._credentials: - return self._url - - url_components = urlparse(self._url) + return None credentials = ( self._credentials.model_dump() @@ -165,18 +171,39 @@ def _repository_url_with_credentials(self) -> str: elif isinstance(v, SecretStr): credentials[k] = v.get_secret_value() - formatted_credentials = _format_token_from_credentials( - urlparse(self._url).netloc, credentials + return _format_token_from_credentials(urlparse(self._url).netloc, credentials) + + def _add_credentials_to_url(self, url: str) -> str: + """Add credentials to given url if possible.""" + components = urlparse(url) + credentials = self._formatted_credentials + + if components.scheme != "https" or not credentials: + return url + + return urlunparse( + components._replace(netloc=f"{credentials}@{components.netloc}") ) - if url_components.scheme == "https" and formatted_credentials is not None: - updated_components = url_components._replace( - netloc=f"{formatted_credentials}@{url_components.netloc}" - ) - repository_url = urlunparse(updated_components) - else: - repository_url = self._url - return repository_url + @property + def _repository_url_with_credentials(self) -> str: + return self._add_credentials_to_url(self._url) + + @property + def _git_config(self) -> list[str]: + """Build a git configuration to use when running git commands.""" + config = {} + + # Submodules can be private. The url in .gitmodules + # will not include the credentials, we need to + # propagate them down here if they exist. + if self._include_submodules and self._formatted_credentials: + base_url = urlparse(self._url)._replace(path="") + without_auth = urlunparse(base_url) + with_auth = self._add_credentials_to_url(without_auth) + config[f"url.{with_auth}.insteadOf"] = without_auth + + return ["-c", " ".join(f"{k}={v}" for k, v in config.items())] if config else [] async def pull_code(self): """ @@ -208,7 +235,11 @@ async def pull_code(self): self._logger.debug("Pulling latest changes from origin/%s", self._branch) # Update the existing repository - cmd = ["git", "pull", "origin"] + cmd = ["git"] + # Add the git configuration, must be given after `git` and before the command + cmd += self._git_config + # Add the pull command and parameters + cmd += ["pull", "origin"] if self._branch: cmd += [self._branch] if self._include_submodules: @@ -234,12 +265,12 @@ async def _clone_repo(self): self._logger.debug("Cloning repository %s", self._url) repository_url = self._repository_url_with_credentials + cmd = ["git"] + # Add the git configuration, must be given after `git` and before the command + cmd += self._git_config + # Add the clone command and its parameters + cmd += ["clone", repository_url] - cmd = [ - "git", - "clone", - repository_url, - ] if self._branch: cmd += ["--branch", self._branch] if self._include_submodules: diff --git a/tests/runner/test_storage.py b/tests/runner/test_storage.py index 333761c69319..bd3adf08d790 100644 --- a/tests/runner/test_storage.py +++ b/tests/runner/test_storage.py @@ -268,6 +268,50 @@ async def test_include_submodules_property( cwd=Path.cwd() / "repo", ) + async def test_include_submodules_with_credentials( + self, mock_run_process: AsyncMock, monkeypatch + ): + access_token = Secret(value="token") + await access_token.save("test-token") + + repo = GitRepository( + url="https://github.com/org/repo.git", + include_submodules=True, + credentials={"access_token": access_token}, + ) + await repo.pull_code() + mock_run_process.assert_awaited_with( + [ + "git", + "-c", + "url.https://token@github.com.insteadOf=https://github.com", + "clone", + "https://token@github.com/org/repo.git", + "--recurse-submodules", + "--depth", + "1", + str(Path.cwd() / "repo"), + ] + ) + + # pretend the repo already exists + monkeypatch.setattr("pathlib.Path.exists", lambda x: ".git" in str(x)) + + await repo.pull_code() + mock_run_process.assert_awaited_with( + [ + "git", + "-c", + "url.https://token@github.com.insteadOf=https://github.com", + "pull", + "origin", + "--recurse-submodules", + "--depth", + "1", + ], + cwd=Path.cwd() / "repo", + ) + async def test_git_clone_errors_obscure_access_token( self, monkeypatch, capsys, tmp_path: Path ):