Skip to content

Commit

Permalink
Notification security controls (#15272)
Browse files Browse the repository at this point in the history
  • Loading branch information
cicdw authored Sep 9, 2024
1 parent e340007 commit b249a3c
Show file tree
Hide file tree
Showing 6 changed files with 253 additions and 13 deletions.
21 changes: 21 additions & 0 deletions src/prefect/blocks/notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from prefect.types import SecretDict
from prefect.utilities.asyncutils import sync_compatible
from prefect.utilities.templating import apply_values, find_placeholders
from prefect.utilities.urls import validate_restricted_url

PREFECT_NOTIFY_TYPE_DEFAULT = "prefect_default"

Expand Down Expand Up @@ -80,6 +81,26 @@ class AppriseNotificationBlock(AbstractAppriseNotificationBlock, ABC):
description="Incoming webhook URL used to send notifications.",
examples=["https://hooks.example.com/XXX"],
)
allow_private_urls: bool = Field(
default=True,
description="Whether to allow notifications to private URLs. Defaults to True.",
)

@sync_compatible
async def notify(
self,
body: str,
subject: Optional[str] = None,
):
if not self.allow_private_urls:
try:
validate_restricted_url(self.url.get_secret_value())
except ValueError as exc:
if self._raise_on_failure:
raise NotificationError(str(exc))
raise

await super().notify(body, subject)


# TODO: Move to prefect-slack once collection block auto-registration is
Expand Down
8 changes: 8 additions & 0 deletions src/prefect/blocks/webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from prefect.blocks.core import Block
from prefect.types import SecretDict
from prefect.utilities.urls import validate_restricted_url

# Use a global HTTP transport to maintain a process-wide connection pool for
# interservice requests
Expand Down Expand Up @@ -39,6 +40,10 @@ class Webhook(Block):
title="Webhook Headers",
description="A dictionary of headers to send with the webhook request.",
)
allow_private_urls: bool = Field(
default=True,
description="Whether to allow notifications to private URLs. Defaults to True.",
)

def block_initialization(self):
self._client = AsyncClient(transport=_http_transport)
Expand All @@ -50,6 +55,9 @@ async def call(self, payload: Optional[dict] = None) -> Response:
Args:
payload: an optional payload to send when calling the webhook.
"""
if not self.allow_private_urls:
validate_restricted_url(self.url.get_secret_value())

async with self._client:
return await self._client.request(
method=self.method,
Expand Down
82 changes: 70 additions & 12 deletions src/prefect/utilities/urls.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import inspect
import ipaddress
import socket
import urllib.parse
from typing import Any, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from urllib.parse import urlparse
from uuid import UUID

from pydantic import BaseModel

from prefect import settings
from prefect.blocks.core import Block
from prefect.events.schemas.automations import Automation
from prefect.events.schemas.events import ReceivedEvent, Resource
from prefect.futures import PrefectFuture
from prefect.logging.loggers import get_logger
from prefect.variables import Variable

if TYPE_CHECKING:
from prefect.blocks.core import Block
from prefect.events.schemas.automations import Automation
from prefect.events.schemas.events import ReceivedEvent, Resource
from prefect.futures import PrefectFuture
from prefect.variables import Variable

logger = get_logger("utilities.urls")

Expand Down Expand Up @@ -58,6 +63,54 @@
RUN_TYPES = {"flow-run", "task-run"}


def validate_restricted_url(url: str):
"""
Validate that the provided URL is safe for outbound requests. This prevents
attacks like SSRF (Server Side Request Forgery), where an attacker can make
requests to internal services (like the GCP metadata service, localhost addresses,
or in-cluster Kubernetes services)
Args:
url: The URL to validate.
Raises:
ValueError: If the URL is a restricted URL.
"""

try:
parsed_url = urlparse(url)
except ValueError:
raise ValueError(f"{url!r} is not a valid URL.")

if parsed_url.scheme not in ("http", "https"):
raise ValueError(
f"{url!r} is not a valid URL. Only HTTP and HTTPS URLs are allowed."
)

hostname = parsed_url.hostname or ""

# Remove IPv6 brackets if present
if hostname.startswith("[") and hostname.endswith("]"):
hostname = hostname[1:-1]

if not hostname:
raise ValueError(f"{url!r} is not a valid URL.")

try:
ip_address = socket.gethostbyname(hostname)
ip = ipaddress.ip_address(ip_address)
except socket.gaierror:
try:
ip = ipaddress.ip_address(hostname)
except ValueError:
raise ValueError(f"{url!r} is not a valid URL. It could not be resolved.")

if ip.is_private:
raise ValueError(
f"{url!r} is not a valid URL. It resolves to the private address {ip}."
)


def convert_class_to_name(obj: Any) -> str:
"""
Convert CamelCase class name to dash-separated lowercase name
Expand All @@ -69,12 +122,12 @@ def convert_class_to_name(obj: Any) -> str:

def url_for(
obj: Union[
PrefectFuture,
Block,
Variable,
Automation,
Resource,
ReceivedEvent,
"PrefectFuture",
"Block",
"Variable",
"Automation",
"Resource",
"ReceivedEvent",
BaseModel,
str,
],
Expand Down Expand Up @@ -105,6 +158,11 @@ def url_for(
url_for(obj=my_flow_run)
url_for("flow-run", obj_id="123e4567-e89b-12d3-a456-426614174000")
"""
from prefect.blocks.core import Block
from prefect.events.schemas.automations import Automation
from prefect.events.schemas.events import ReceivedEvent, Resource
from prefect.futures import PrefectFuture

if isinstance(obj, PrefectFuture):
name = "task-run"
elif isinstance(obj, Block):
Expand Down
66 changes: 66 additions & 0 deletions tests/blocks/test_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
import respx

from prefect.blocks.abstract import NotificationError
from prefect.blocks.notifications import (
PREFECT_NOTIFY_TYPE_DEFAULT,
AppriseNotificationBlock,
Expand All @@ -31,6 +32,41 @@
key=lambda cls: cls.__name__,
)

RESTRICTED_URLS = [
("", ""),
(" ", ""),
("[]", ""),
("not a url", ""),
("http://", ""),
("https://", ""),
("ftp://example.com", "HTTP and HTTPS"),
("gopher://example.com", "HTTP and HTTPS"),
("https://localhost", "private address"),
("https://127.0.0.1", "private address"),
("https://[::1]", "private address"),
("https://[fc00:1234:5678:9abc::10]", "private address"),
("https://[fd12:3456:789a:1::1]", "private address"),
("https://[fe80::1234:5678:9abc]", "private address"),
("https://10.0.0.1", "private address"),
("https://10.255.255.255", "private address"),
("https://172.16.0.1", "private address"),
("https://172.31.255.255", "private address"),
("https://192.168.1.1", "private address"),
("https://192.168.1.255", "private address"),
("https://169.254.0.1", "private address"),
("https://169.254.169.254", "private address"),
("https://169.254.254.255", "private address"),
# These will resolve to a private address in production, but not in tests,
# so we'll use "resolve" as the reason to catch both cases
("https://metadata.google.internal", "resolve"),
("https://anything.privatecloud", "resolve"),
("https://anything.privatecloud.svc", "resolve"),
("https://anything.privatecloud.svc.cluster.local", "resolve"),
("https://cluster-internal", "resolve"),
("https://network-internal.cloud.svc", "resolve"),
("https://private-internal.cloud.svc.cluster.local", "resolve"),
]


@pytest.mark.parametrize("block_class", notification_classes)
class TestAppriseNotificationBlock:
Expand Down Expand Up @@ -81,6 +117,36 @@ def test_is_picklable(self, block_class: Type[AppriseNotificationBlock]):
unpickled = cloudpickle.loads(pickled)
assert isinstance(unpickled, block_class)

@pytest.mark.parametrize("value, reason", RESTRICTED_URLS)
async def test_notification_can_prevent_restricted_urls(
self, block_class, value: str, reason: str
):
notification = block_class(url=value, allow_private_urls=False)

with pytest.raises(ValueError, match=f"is not a valid URL.*{reason}"):
await notification.notify(subject="example", body="example")

async def test_raises_on_url_validation_failure(self, block_class):
"""
When within a raise_on_failure block, we want URL validation errors to be
wrapped and captured as NotificationErrors for reporting back to users.
"""
block = block_class(url="https://127.0.0.1/foo/bar", allow_private_urls=False)

# outside of a raise_on_failure block, we get a ValueError directly
with pytest.raises(ValueError, match="not a valid URL") as captured:
await block.notify(subject="Test", body="Test")

# inside of a raise_on_failure block, we get a NotificationError
with block.raise_on_failure():
with pytest.raises(NotificationError) as captured:
await block.notify(subject="Test", body="Test")

assert captured.value.log == (
"'https://127.0.0.1/foo/bar' is not a valid URL. It resolves to the "
"private address 127.0.0.1."
)


class TestMattermostWebhook:
async def test_notify_async(self):
Expand Down
45 changes: 45 additions & 0 deletions tests/blocks/test_webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,42 @@
from prefect.blocks.webhook import Webhook
from prefect.testing.utilities import AsyncMock

RESTRICTED_URLS = [
("", ""),
(" ", ""),
("[]", ""),
("not a url", ""),
("http://", ""),
("https://", ""),
("http://[]/foo/bar", ""),
("ftp://example.com", "HTTP and HTTPS"),
("gopher://example.com", "HTTP and HTTPS"),
("https://localhost", "private address"),
("https://127.0.0.1", "private address"),
("https://[::1]", "private address"),
("https://[fc00:1234:5678:9abc::10]", "private address"),
("https://[fd12:3456:789a:1::1]", "private address"),
("https://[fe80::1234:5678:9abc]", "private address"),
("https://10.0.0.1", "private address"),
("https://10.255.255.255", "private address"),
("https://172.16.0.1", "private address"),
("https://172.31.255.255", "private address"),
("https://192.168.1.1", "private address"),
("https://192.168.1.255", "private address"),
("https://169.254.0.1", "private address"),
("https://169.254.169.254", "private address"),
("https://169.254.254.255", "private address"),
# These will resolve to a private address in production, but not in tests,
# so we'll use "resolve" as the reason to catch both cases
("https://metadata.google.internal", "resolve"),
("https://anything.privatecloud", "resolve"),
("https://anything.privatecloud.svc", "resolve"),
("https://anything.privatecloud.svc.cluster.local", "resolve"),
("https://cluster-internal", "resolve"),
("https://network-internal.cloud.svc", "resolve"),
("https://private-internal.cloud.svc.cluster.local", "resolve"),
]


class TestWebhook:
def test_webhook_raises_error_on_bad_request_method(self):
Expand All @@ -13,6 +49,15 @@ def test_webhook_raises_error_on_bad_request_method(self):
with pytest.raises(ValueError):
Webhook(method=bad_method, url="http://google.com")

@pytest.mark.parametrize("value, reason", RESTRICTED_URLS)
async def test_webhook_must_not_point_to_restricted_urls(
self, value: str, reason: str
):
webhook = Webhook(url=value, allow_private_urls=False)

with pytest.raises(ValueError, match=f"is not a valid URL.*{reason}"):
await webhook.call(payload="some payload")

async def test_webhook_sends(self, monkeypatch):
send_mock = AsyncMock()
monkeypatch.setattr("httpx.AsyncClient.request", send_mock)
Expand Down
44 changes: 43 additions & 1 deletion tests/utilities/test_urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,48 @@
from prefect.server.schemas.core import FlowRun, TaskRun
from prefect.server.schemas.states import State
from prefect.settings import PREFECT_API_URL, PREFECT_UI_URL, temporary_settings
from prefect.utilities.urls import url_for
from prefect.utilities.urls import url_for, validate_restricted_url
from prefect.variables import Variable

MOCK_PREFECT_UI_URL = "https://ui.prefect.io"
MOCK_PREFECT_API_URL = "https://api.prefect.io"

RESTRICTED_URLS = [
("", ""),
(" ", ""),
("[]", ""),
("not a url", ""),
("http://", ""),
("https://", ""),
("http://[]/foo/bar", ""),
("ftp://example.com", "HTTP and HTTPS"),
("gopher://example.com", "HTTP and HTTPS"),
("https://localhost", "private address"),
("https://127.0.0.1", "private address"),
("https://[::1]", "private address"),
("https://[fc00:1234:5678:9abc::10]", "private address"),
("https://[fd12:3456:789a:1::1]", "private address"),
("https://[fe80::1234:5678:9abc]", "private address"),
("https://10.0.0.1", "private address"),
("https://10.255.255.255", "private address"),
("https://172.16.0.1", "private address"),
("https://172.31.255.255", "private address"),
("https://192.168.1.1", "private address"),
("https://192.168.1.255", "private address"),
("https://169.254.0.1", "private address"),
("https://169.254.169.254", "private address"),
("https://169.254.254.255", "private address"),
# These will resolve to a private address in production, but not in tests,
# so we'll use "resolve" as the reason to catch both cases
("https://metadata.google.internal", "resolve"),
("https://anything.privatecloud", "resolve"),
("https://anything.privatecloud.svc", "resolve"),
("https://anything.privatecloud.svc.cluster.local", "resolve"),
("https://cluster-internal", "resolve"),
("https://network-internal.cloud.svc", "resolve"),
("https://private-internal.cloud.svc.cluster.local", "resolve"),
]


@pytest.fixture
async def variable():
Expand Down Expand Up @@ -105,6 +141,12 @@ def resource():
return Resource({"prefect.resource.id": f"prefect.flow-run.{uuid.uuid4()}"})


@pytest.mark.parametrize("value, reason", RESTRICTED_URLS)
def test_validate_restricted_url_validates(value: str, reason: str):
with pytest.raises(ValueError, match=f"is not a valid URL.*{reason}"):
validate_restricted_url(url=value)


@pytest.mark.parametrize("url_type", ["ui", "api"])
def test_url_for_flow_run(flow_run, url_type: Literal["ui", "api"]):
expected_url = (
Expand Down

0 comments on commit b249a3c

Please sign in to comment.