Improve forbidden domain handling (#147)
Gallaecio authored Nov 28, 2023
1 parent 493a48c commit dc09ac3
Showing 8 changed files with 227 additions and 20 deletions.
9 changes: 6 additions & 3 deletions scrapy_zyte_api/__init__.py
@@ -5,6 +5,9 @@
 
 install_reactor("twisted.internet.asyncioreactor.AsyncioSelectorReactor")
 
-from ._downloader_middleware import ScrapyZyteAPIDownloaderMiddleware  # NOQA
-from ._request_fingerprinter import ScrapyZyteAPIRequestFingerprinter  # NOQA
-from .handler import ScrapyZyteAPIDownloadHandler  # NOQA
+from ._middlewares import (
+    ScrapyZyteAPIDownloaderMiddleware,
+    ScrapyZyteAPISpiderMiddleware,
+)
+from ._request_fingerprinter import ScrapyZyteAPIRequestFingerprinter
+from .handler import ScrapyZyteAPIDownloadHandler
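With both middleware classes re-exported from the package root, projects can reference them by these import paths in settings. A minimal sketch of the user-facing configuration, using the priorities from tests/__init__.py further down (project-specific values may differ):

# settings.py (sketch): enable the middlewares added in this commit.
DOWNLOADER_MIDDLEWARES = {
    "scrapy_zyte_api.ScrapyZyteAPIDownloaderMiddleware": 1000,
}
SPIDER_MIDDLEWARES = {
    "scrapy_zyte_api.ScrapyZyteAPISpiderMiddleware": 100,
}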
57 changes: 57 additions & 0 deletions scrapy_zyte_api/_downloader_middleware.py → scrapy_zyte_api/_middlewares.py
@@ -1,12 +1,16 @@
 import logging
 
 from scrapy.exceptions import IgnoreRequest
+from zyte_api.aio.errors import RequestError
 
 from ._params import _ParamParser
 
 logger = logging.getLogger(__name__)
 
 
+_start_requests_processed = object()
+
+
 class ScrapyZyteAPIDownloaderMiddleware:
     _slot_prefix = "zyte-api@"
 
@@ -15,6 +19,8 @@ def from_crawler(cls, crawler):
         return cls(crawler)
 
     def __init__(self, crawler) -> None:
+        self._forbidden_domain_start_request_count = 0
+        self._total_start_request_count = 0
         self._param_parser = _ParamParser(crawler, cookies_enabled=False)
         self._crawler = crawler
 
@@ -26,6 +32,14 @@ def __init__(self, crawler) -> None:
                 f"reached."
             )
 
+        crawler.signals.connect(
+            self._start_requests_processed, signal=_start_requests_processed
+        )
+
+    def _start_requests_processed(self, count):
+        self._total_start_request_count = count
+        self._maybe_close()
+
     def process_request(self, request, spider):
         if self._param_parser.parse(request) is None:
             return
@@ -59,3 +73,46 @@ def _max_requests_reached(self, downloader) -> bool:
         )
         total_requests = zapi_req_count + download_req_count
         return total_requests >= self._max_requests
+
+    def process_exception(self, request, exception, spider):
+        if (
+            not request.meta.get("is_start_request")
+            or not isinstance(exception, RequestError)
+            or exception.status != 451
+        ):
+            return
+
+        self._forbidden_domain_start_request_count += 1
+        self._maybe_close()
+
+    def _maybe_close(self):
+        if not self._total_start_request_count:
+            return
+        if self._forbidden_domain_start_request_count < self._total_start_request_count:
+            return
+        logger.error(
+            "Stopping the spider, all start requests failed because they "
+            "were pointing to a domain forbidden by Zyte API."
+        )
+        self._crawler.engine.close_spider(
+            self._crawler.spider, "failed_forbidden_domain"
+        )
+
+
+class ScrapyZyteAPISpiderMiddleware:
+    @classmethod
+    def from_crawler(cls, crawler):
+        return cls(crawler)
+
+    def __init__(self, crawler):
+        self._send_signal = crawler.signals.send_catch_log
+
+    def process_start_requests(self, start_requests, spider):
+        # Mark start requests and report to the downloader middleware the
+        # number of them once all have been processed.
+        count = 0
+        for request in start_requests:
+            request.meta["is_start_request"] = True
+            yield request
+            count += 1
+        self._send_signal(_start_requests_processed, count=count)
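The two middlewares coordinate through a plain object() used as a custom Scrapy signal: the spider middleware counts start requests as they are consumed and fires the signal once, and the downloader middleware records the count and re-checks whether every start request has failed. A minimal, self-contained sketch of that handshake (the handler name and count value are illustrative; SignalManager, connect, and send_catch_log are standard Scrapy APIs):

# Sketch of the sentinel-signal handshake used above. Any object can act as
# a Scrapy signal; send_catch_log delivers keyword arguments to connected
# handlers and logs handler exceptions instead of raising them.
from scrapy.signalmanager import SignalManager

_start_requests_processed = object()  # sentinel acting as a custom signal

def on_start_requests_processed(count):
    # In the downloader middleware, the handler stores this count and then
    # calls _maybe_close().
    print(f"start requests processed: {count}")

signals = SignalManager()
signals.connect(on_start_requests_processed, signal=_start_requests_processed)
signals.send_catch_log(_start_requests_processed, count=3)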
1 change: 1 addition & 0 deletions setup.cfg
@@ -5,3 +5,4 @@ max-complexity = 18
 select = B,C,E,F,W,T4
 per-file-ignores =
     tests/test_providers.py: E402
+    scrapy_zyte_api/__init__.py: F401
6 changes: 6 additions & 0 deletions tests/__init__.py
@@ -17,8 +17,14 @@
         "http": "scrapy_zyte_api.handler.ScrapyZyteAPIDownloadHandler",
         "https": "scrapy_zyte_api.handler.ScrapyZyteAPIDownloadHandler",
     },
+    "DOWNLOADER_MIDDLEWARES": {
+        "scrapy_zyte_api.ScrapyZyteAPIDownloaderMiddleware": 1000,
+    },
     "REQUEST_FINGERPRINTER_CLASS": "scrapy_zyte_api.ScrapyZyteAPIRequestFingerprinter",
     "REQUEST_FINGERPRINTER_IMPLEMENTATION": "2.7", # Silence deprecation warning
+    "SPIDER_MIDDLEWARES": {
+        "scrapy_zyte_api.ScrapyZyteAPISpiderMiddleware": 100,
+    },
     "ZYTE_API_KEY": _API_KEY,
     "TWISTED_REACTOR": "twisted.internet.asyncioreactor.AsyncioSelectorReactor",
 }
11 changes: 11 additions & 0 deletions tests/mockserver.py
@@ -74,6 +74,7 @@ def render_POST(self, request):
         )
 
         response_data: _API_RESPONSE = {}
+
        if "url" not in request_data:
            request.setResponseCode(400)
            return json.dumps(response_data).encode()
@@ -89,6 +90,16 @@ def render_POST(self, request):
                 "detail": "The authentication key is not valid or can't be matched.",
             }
             return json.dumps(response_data).encode()
+        if "forbidden" in domain:
+            request.setResponseCode(451)
+            response_data = {
+                "status": 451,
+                "type": "/download/domain-forbidden",
+                "title": "Domain Forbidden",
+                "detail": "Extraction for the domain is forbidden.",
+                "blockedDomain": domain,
+            }
+            return json.dumps(response_data).encode()
         if "suspended-account" in domain:
             request.setResponseCode(403)
             response_data = {
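For context, this is roughly how the mocked 451 response surfaces on the client side. A hedged sketch using python-zyte-api's async client (the client wiring below is an assumption for illustration; the middleware above relies only on RequestError and its status attribute):

# Sketch (assumed client usage): requesting a forbidden domain makes
# python-zyte-api raise RequestError, whose .status the downloader
# middleware compares against 451.
import asyncio

from zyte_api.aio.client import AsyncClient
from zyte_api.aio.errors import RequestError

async def main():
    client = AsyncClient(api_key="YOUR_API_KEY")  # placeholder key
    try:
        await client.request_raw(
            {"url": "https://forbidden.example", "httpResponseBody": True}
        )
    except RequestError as error:
        print(error.status)  # 451 for a forbidden domain

asyncio.run(main())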
127 changes: 127 additions & 0 deletions tests/test_downloader_middleware.py → tests/test_middlewares.py
@@ -132,3 +132,130 @@ def parse(self, response):
         )
         > 0
     )
+
+
+@ensureDeferred
+async def test_forbidden_domain_start_url():
+    class TestSpider(Spider):
+        name = "test"
+        start_urls = ["https://forbidden.example"]
+
+        def parse(self, response):
+            pass
+
+    settings = {
+        "ZYTE_API_TRANSPARENT_MODE": True,
+        **SETTINGS,
+    }
+
+    with MockServer() as server:
+        settings["ZYTE_API_URL"] = server.urljoin("/")
+        crawler = get_crawler(TestSpider, settings_dict=settings)
+        await crawler.crawl()
+
+    assert crawler.stats.get_value("finish_reason") == "failed_forbidden_domain"
+
+
+@ensureDeferred
+async def test_forbidden_domain_start_urls():
+    class TestSpider(Spider):
+        name = "test"
+        start_urls = [
+            "https://forbidden.example",
+            "https://also-forbidden.example",
+            "https://oh.definitely-forbidden.example",
+        ]
+
+        def parse(self, response):
+            pass
+
+    settings = {
+        "ZYTE_API_TRANSPARENT_MODE": True,
+        **SETTINGS,
+    }
+
+    with MockServer() as server:
+        settings["ZYTE_API_URL"] = server.urljoin("/")
+        crawler = get_crawler(TestSpider, settings_dict=settings)
+        await crawler.crawl()
+
+    assert crawler.stats.get_value("finish_reason") == "failed_forbidden_domain"
+
+
+@ensureDeferred
+async def test_some_forbidden_domain_start_url():
+    class TestSpider(Spider):
+        name = "test"
+        start_urls = [
+            "https://forbidden.example",
+            "https://allowed.example",
+        ]
+
+        def parse(self, response):
+            pass
+
+    settings = {
+        "ZYTE_API_TRANSPARENT_MODE": True,
+        **SETTINGS,
+    }
+
+    with MockServer() as server:
+        settings["ZYTE_API_URL"] = server.urljoin("/")
+        crawler = get_crawler(TestSpider, settings_dict=settings)
+        await crawler.crawl()
+
+    assert crawler.stats.get_value("finish_reason") == "finished"
+
+
+@ensureDeferred
+async def test_follow_up_forbidden_domain_url():
+    class TestSpider(Spider):
+        name = "test"
+        start_urls = [
+            "https://allowed.example",
+        ]
+
+        def parse(self, response):
+            yield response.follow("https://forbidden.example")
+
+    settings = {
+        "ZYTE_API_TRANSPARENT_MODE": True,
+        **SETTINGS,
+    }
+
+    with MockServer() as server:
+        settings["ZYTE_API_URL"] = server.urljoin("/")
+        crawler = get_crawler(TestSpider, settings_dict=settings)
+        await crawler.crawl()
+
+    assert crawler.stats.get_value("finish_reason") == "finished"
+
+
+@ensureDeferred
+async def test_forbidden_domain_with_partial_start_request_consumption():
+    """With concurrency lower than the number of start requests + 1, the code
+    path followed changes, because ``_total_start_request_count`` is not set
+    in the downloader middleware until *after* some start requests have been
+    processed."""
+
+    class TestSpider(Spider):
+        name = "test"
+        start_urls = [
+            "https://forbidden.example",
+        ]
+
+        def parse(self, response):
+            yield response.follow("https://forbidden.example")
+
+    settings = {
+        "CONCURRENT_REQUESTS": 1,
+        "ZYTE_API_TRANSPARENT_MODE": True,
+        **SETTINGS,
+    }
+
+    with MockServer() as server:
+        settings["ZYTE_API_URL"] = server.urljoin("/")
+        crawler = get_crawler(TestSpider, settings_dict=settings)
+        await crawler.crawl()
+
+    assert crawler.stats.get_value("finish_reason") == "failed_forbidden_domain"
35 changes: 18 additions & 17 deletions tests/test_providers.py
@@ -6,11 +6,8 @@
 from pytest_twisted import ensureDeferred
 from scrapy import Request, Spider
 from scrapy_poet import DummyResponse
-from scrapy_poet.utils.testing import (
-    HtmlResource,
-    crawl_single_item,
-    create_scrapy_settings,
-)
+from scrapy_poet.utils.testing import HtmlResource, crawl_single_item
+from scrapy_poet.utils.testing import create_scrapy_settings as _create_scrapy_settings
 from twisted.internet import reactor
 from twisted.web.client import Agent, readBody
 from web_poet import BrowserHtml, BrowserResponse, ItemPage, field, handle_urls
@@ -22,6 +19,16 @@
 from .mockserver import get_ephemeral_port
 
 
+def create_scrapy_settings():
+    settings = _create_scrapy_settings(None)
+    for setting, value in SETTINGS.items():
+        if setting.endswith("_MIDDLEWARES") and settings[setting]:
+            settings[setting].update(value)
+        else:
+            settings[setting] = value
+    return settings
+
+
 @attrs.define
 class ProductPage(BasePage):
     html: BrowserHtml
@@ -45,8 +52,7 @@ def parse_(self, response: DummyResponse, page: ProductPage):
 
 @ensureDeferred
 async def test_provider(mockserver):
-    settings = create_scrapy_settings(None)
-    settings.update(SETTINGS)
+    settings = create_scrapy_settings()
     settings["ZYTE_API_URL"] = mockserver.urljoin("/")
     settings["SCRAPY_POET_PROVIDERS"] = {ZyteApiProvider: 0}
     item, url, _ = await crawl_single_item(ZyteAPISpider, HtmlResource, settings)
@@ -93,8 +99,7 @@ def parse_( # type: ignore[override]
     port = get_ephemeral_port()
     handle_urls(f"{fresh_mockserver.host}:{port}")(MyPage)
 
-    settings = create_scrapy_settings(None)
-    settings.update(SETTINGS)
+    settings = create_scrapy_settings()
     settings["ZYTE_API_URL"] = fresh_mockserver.urljoin("/")
     settings["SCRAPY_POET_PROVIDERS"] = {ZyteApiProvider: 1100}
     item, url, _ = await crawl_single_item(
@@ -123,8 +128,7 @@ def parse_(self, response: DummyResponse, product: Product, my_item: MyItem): #
     port = get_ephemeral_port()
     handle_urls(f"{fresh_mockserver.host}:{port}")(MyPage)
 
-    settings = create_scrapy_settings(None)
-    settings.update(SETTINGS)
+    settings = create_scrapy_settings()
     settings["ZYTE_API_URL"] = fresh_mockserver.urljoin("/")
     settings["SCRAPY_POET_PROVIDERS"] = {ZyteApiProvider: 1100}
     item, url, _ = await crawl_single_item(
@@ -152,8 +156,7 @@ def parse_(self, response: DummyResponse, product: Product, browser_response: Br
     port = get_ephemeral_port()
     handle_urls(f"{fresh_mockserver.host}:{port}")(MyPage)
 
-    settings = create_scrapy_settings(None)
-    settings.update(SETTINGS)
+    settings = create_scrapy_settings()
     settings["ZYTE_API_URL"] = fresh_mockserver.urljoin("/")
     settings["SCRAPY_POET_PROVIDERS"] = {ZyteApiProvider: 1}
     item, url, _ = await crawl_single_item(
@@ -171,8 +174,7 @@ def parse_(self, response: DummyResponse, product: Product, browser_response: Br
 
 @ensureDeferred
 async def test_provider_params(mockserver):
-    settings = create_scrapy_settings(None)
-    settings.update(SETTINGS)
+    settings = create_scrapy_settings()
     settings["ZYTE_API_URL"] = mockserver.urljoin("/")
     settings["SCRAPY_POET_PROVIDERS"] = {ZyteApiProvider: 0}
     settings["ZYTE_API_PROVIDER_PARAMS"] = {"geolocation": "IE"}
@@ -183,8 +185,7 @@ async def test_provider_params(mockserver):
 
 @ensureDeferred
 async def test_provider_params_remove_unused_options(mockserver):
-    settings = create_scrapy_settings(None)
-    settings.update(SETTINGS)
+    settings = create_scrapy_settings()
     settings["ZYTE_API_URL"] = mockserver.urljoin("/")
     settings["SCRAPY_POET_PROVIDERS"] = {ZyteApiProvider: 0}
     settings["ZYTE_API_PROVIDER_PARAMS"] = {
1 change: 1 addition & 0 deletions tox.ini
@@ -8,6 +8,7 @@ deps =
     pytest-twisted
 commands =
     py.test \
+        --cov-report=term-missing \
         --cov-report=html:coverage-html \
         --cov-report=xml \
         --cov=scrapy_zyte_api \
