Skip to content

Commit

Permalink
[Commoncrawl pipeline] Add component download_commoncrawl_segments (#273
Browse files Browse the repository at this point in the history
)

This PR adds the second component of the Commoncrawl pipeline. The
component downloads the WARC segment files and extracts the webpage urls
and html code to be returned as a dask dataframe.
  • Loading branch information
shayorshay authored Aug 10, 2023
1 parent ecf9e62 commit ed73f1b
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
FROM --platform=linux/amd64 python:3.8-slim

## System dependencies
RUN apt-get update && \
apt-get upgrade -y && \
apt-get install git -y && \
apt-get install -y gcc

# RUN apt-get update -y && apt-get install -y gcc

# install requirements
COPY requirements.txt /
RUN pip3 install --no-cache-dir -r requirements.txt

# Install Fondant
# This is split from other requirements to leverage caching
ARG FONDANT_VERSION=main
RUN pip3 install git+https://github.com/ml6team/fondant@${FONDANT_VERSION}#egg=fondant[aws]

# Set the working directory to the component folder
WORKDIR /component/src

# Copy over src-files
COPY src/ .

ENTRYPOINT ["python", "main.py"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# download_commoncrawl_segments

### Description
This component downloads commoncrawl segment files based on WARC file paths. Download can be done through the CommonCrawl API or from S3.

### **Inputs/Outputs**

See [`fondant_component.yaml`](fondant_component.yaml) for a more detailed description on all the input/output parameters.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
name: Download commoncrawl segment files
description: Component that downloads commoncrawl segment files based on WARC paths
image: ghcr.io/ml6team/download_commoncrawl_segments:latest

consumes:
segment:
fields:
path:
type: string

produces:
webpage:
fields:
url:
type: string
html:
type: string

args:
use_s3:
description: Whether to use S3 to download the commoncrawl segment file. Set to True if you are running this component on an AWS cluster.
type: bool
default: 'False'
get_plain_text:
description: Whether to extract plain text from the HTML.
type: bool
default: 'False'
n_records_to_download:
description: Number of records to download from the commoncrawl segment file. Useful for small-scale testing.
type: int
default: None
retries:
description: Number of retries when downloading the commoncrawl segment file. Only used when use_s3 is set to False.
type: int
default: 3
backoff_factor:
description: Backoff factor when retrying to download the commoncrawl segment file. Only used when use_s3 is set to False.
type: float
default: 5
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
boto3==1.26.158
graphviz==0.20.1
html_text==0.5.2
requests==2.31.0
s3fs==2023.6.0
warcio==1.7.4
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import os
import logging
from typing import List, Optional

import dask.dataframe as dd
import dask.delayed as delayed
import pandas as pd

import gzip
from warcio.archiveiterator import ArchiveIterator

from fondant.component import DaskTransformComponent
from fondant.executor import DaskTransformExecutor

from utils.text_utils import convert_to_plain_text
from utils.download_utils import get_warc_file_using_boto3, get_warc_file_using_requests

logger = logging.getLogger(__name__)


def get_records(file, get_plain_text, n_records_to_download) -> List[List[str]]:
"""Extracts records from a WARC file, optionally converting HTML to plain text.
Args:
file: The WARC file.
get_plain_text: Whether to convert HTML to plain text.
n_records_to_download: The number of records to download.
Returns:
A list of webpage records, where each record is a url and content.
"""
records = []
counter = 0

for record in ArchiveIterator(file, arc2warc=True):
if record.rec_type == "response":
url = record.rec_headers.get_header("WARC-Target-URI")
content = record.content_stream().read().decode("utf-8", "replace")
if get_plain_text:
content = convert_to_plain_text(content)
records.append([url, content])
counter += 1

if n_records_to_download and counter >= n_records_to_download:
break

return records


def get_records_from_warc_file(
warc_file: str,
use_s3: Optional[bool] = False,
get_plain_text: Optional[bool] = False,
n_records_to_download: Optional[int] = None,
retries: Optional[int] = None,
backoff_factor: Optional[int] = None,
) -> List[List[str]]:
"""Downloads a WARC file and extracts the webpages.
Args:
warc_file: The path to the WARC file.
use_s3: Whether to download the WARC file from S3 or from the Commoncrawl API.
get_plain_text: Whether to convert the HTML content to plain text.
n_records_to_download: The number of webpages to download from the WARC file.
Returns:
A list of webpages.
"""
logger.info(f"Processing WARC file from segment path: {warc_file}...")

if use_s3:
response = get_warc_file_using_boto3(warc_file)
with gzip.GzipFile(fileobj=response, mode="rb") as file:
return get_records(file, get_plain_text, n_records_to_download)
else:
response = get_warc_file_using_requests(warc_file, retries, backoff_factor)
return get_records(response.raw, get_plain_text, n_records_to_download)


class DownloadCommoncrawlSegments(DaskTransformComponent):
def __init__(
self,
*_,
use_s3: Optional[bool] = False,
get_plain_text: Optional[bool] = False,
n_records_to_download: Optional[int] = None,
retries: Optional[int] = None,
backoff_factor: Optional[float] = None,
):
"""Downloads Commoncrawl segments based on a list of WARC paths.
Args:
use_s3: Whether to download the WARC files from S3 or from the Commoncrawl API.
get_plain_text: Whether to convert the HTML content to plain text.
n_records_to_download: The number of webpages to download from each segment.
"""
self.use_s3 = use_s3
self.get_plain_text = get_plain_text
self.n_records_to_download = n_records_to_download
self.retries = retries
self.backoff_factor = backoff_factor

def transform(
self,
dataframe: dd.DataFrame,
) -> dd.DataFrame:
"""Downloads Commoncrawl segments based on a list of WARC paths.
Args:
dataframe: A Dask DataFrame containing a column of WARC paths.
Returns:
A Dask DataFrame containing the downloaded webpages.
"""

dataframe = (
dataframe.apply(
lambda row: get_records_from_warc_file(
row["segment_path"],
self.use_s3,
self.get_plain_text,
self.n_records_to_download,
self.retries,
self.backoff_factor,
),
axis=1,
meta=("object"),
)
.explode()
.apply(pd.Series, meta={0: "object", 1: "object"})
)

dataframe.columns = [
"webpage_url",
"webpage_html",
]

dataframe = dataframe.reset_index(drop=True)

return dataframe


if __name__ == "__main__":
executor = DaskTransformExecutor.from_args()
executor.execute(DownloadCommoncrawlSegments)
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import time
import logging

import boto3

import requests
from requests import Session
from requests.adapters import HTTPAdapter
from requests import RequestException, ConnectionError
from urllib3.util import Retry


logger = logging.getLogger(__name__)

S3_COMMONCRAWL_BUCKET = "commoncrawl"
COMMONCRAWL_BASE_URL = "https://data.commoncrawl.org/"


def get_warc_file_using_boto3(s3_key: str) -> bytes:
"""Downloads a WARC file using boto3.
Args:
warc_file: The path to the WARC file.
Returns:
The WARC file as a bytes object.
"""
s3 = boto3.client("s3")
response = s3.get_object(Bucket=S3_COMMONCRAWL_BUCKET, Key=s3_key)
return response["Body"]


def get_warc_file_using_requests(
warc_file: str, retries: int = 3, backoff_factor: int = 5
) -> requests.Response:
session = Session()
retry_strategy = Retry(
total=retries,
backoff_factor=backoff_factor,
status_forcelist=[502, 503, 504],
allowed_methods={"POST", "GET"},
)
session.mount("https://", HTTPAdapter(max_retries=retry_strategy))

try:
response = session.get(COMMONCRAWL_BASE_URL + warc_file, stream=True)
response.raise_for_status()
return response
except requests.exceptions.RequestException as e:
logger.error(f"Error downloading WARC file: {e}")
raise
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import logging
import html_text

logger = logging.getLogger(__name__)


def convert_to_plain_text(html: str) -> str:
try:
return html_text.extract_text(html)
except Exception as e:
logger.error(f"Error converting HTML to plain text: {e}")
return None

0 comments on commit ed73f1b

Please sign in to comment.