Skip to content

Commit

Permalink
[GSProcessing] Enforce re-order for node label processing during clas…
Browse files Browse the repository at this point in the history
…sification
  • Loading branch information
thvasilo committed Jan 17, 2025
1 parent 67219ec commit 4d0afcd
Show file tree
Hide file tree
Showing 12 changed files with 596 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import logging
from typing import Any, Dict, Optional

from graphstorm_processing.constants import VALID_TASK_TYPES


class LabelConfig(abc.ABC):
"""Basic class for label config"""
Expand Down Expand Up @@ -55,6 +57,9 @@ def __init__(self, config_dict: Dict[str, Any]):
self._mask_field_names = None

def _sanity_check(self):
assert (
self._task_type in VALID_TASK_TYPES
), f"Invalid task type {self._task_type}, must be one of {VALID_TASK_TYPES}"
if self._label_column == "":
assert self._task_type == "link_prediction", (
"When no label column is specified, the task type must be link_prediction, "
Expand Down Expand Up @@ -83,6 +88,23 @@ def _sanity_check(self):
assert all(isinstance(x, str) for x in self._mask_field_names)
assert len(self._mask_field_names) == 3

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(label_column={self._label_column!r}, "
f"task_type={self._task_type!r}, separator={self._separator!r}, "
f"multilabel={self._multilabel!r}, split={self._split!r}, "
f"custom_split_filenames={self._custom_split_filenames!r}, "
f"mask_field_names={self._mask_field_names!r})"
)

def __str__(self) -> str:
task_desc = f"{self._task_type} task"
if self._label_column:
task_desc += f" on column '{self._label_column}'"
if self._multilabel:
task_desc += f" (multilabel with separator '{self._separator}')"
return task_desc

@property
def label_column(self) -> str:
"""The name of the column storing the target label property value."""
Expand Down
10 changes: 10 additions & 0 deletions graphstorm-processing/graphstorm_processing/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@
NODE_MAPPING_STR = "orig"
NODE_MAPPING_INT = "new"

################# Reserved columns ################
DATA_SPLIT_SET_MASK_COL = "GSP-SAMPLE-SET-MASK"

################# Supported task types ##############
VALID_TASK_TYPES = {
"classification",
"regression",
"link_prediction",
}


################# Supported execution envs ##############
class ExecutionEnv(Enum):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

from dataclasses import dataclass
from math import fsum
from typing import Optional

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import FloatType
from pyspark.sql.types import NumericType

from graphstorm_processing.config.label_config_base import LabelConfig
from graphstorm_processing.data_transformations.dist_transformations import (
Expand Down Expand Up @@ -100,11 +101,33 @@ class DistLabelLoader:
The SparkSession to use for processing.
"""

def __init__(self, label_config: LabelConfig, spark: SparkSession) -> None:
def __init__(
self, label_config: LabelConfig, spark: SparkSession, order_col: Optional[str] = None
) -> None:
self.label_config = label_config
self.label_column = label_config.label_column
self.spark = spark
self.label_map: dict[str, int] = {}
self.order_col = order_col

def __str__(self) -> str:
"""String representation for informal display"""
return (
f"DistLabelLoader(label_column='{self.label_column}', "
f"task_type='{self.label_config.task_type}', "
f"multilabel={self.label_config.multilabel}, "
f"order_col={self.order_col!r})"
)

def __repr__(self) -> str:
"""Detailed string representation for debugging"""
return (
f"DistLabelLoader("
f"label_config={self.label_config!r}, "
f"spark={self.spark!r}, "
f"order_col={self.order_col!r}, "
f"label_map={self.label_map!r})"
)

def process_label(self, input_df: DataFrame) -> DataFrame:
"""Transforms the label column in the input DataFrame to conform to GraphStorm expectations.
Expand Down Expand Up @@ -134,21 +157,30 @@ def process_label(self, input_df: DataFrame) -> DataFrame:
label_type = input_df.schema[self.label_column].dataType

if self.label_config.task_type == "classification":
assert self.order_col, f"{self.order_col} must be provided for classification tasks"
if self.label_config.multilabel:
assert self.label_config.separator
label_transformer = DistMultiLabelTransformation(
[self.label_config.label_column], self.label_config.separator
)
else:
label_transformer = DistSingleLabelTransformation(
[self.label_config.label_column], self.spark
[self.label_config.label_column],
self.spark,
)

transformed_label = label_transformer.apply(input_df).select(self.label_column)
transformed_label = label_transformer.apply(input_df)
if self.order_col:
assert self.order_col in transformed_label.columns, (
f"{self.order_col=} needs to be part of transformed "
f"label DF, got {transformed_label.columns=}"
)
transformed_label = transformed_label.sort(self.order_col).cache()

self.label_map = label_transformer.value_map
return transformed_label
return transformed_label # .select(self.label_column)
elif self.label_config.task_type == "regression":
if not isinstance(label_type, FloatType):
if not isinstance(label_type, NumericType):
raise RuntimeError(
"Data type for regression should be FloatType, "
f"got {label_type} for {self.label_column}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,12 @@ def __init__(self, cols: Sequence[str], separator: str) -> None:
def get_transformation_name() -> str:
return "DistMultiCategoryTransformation"

def apply(self, input_df: DataFrame) -> DataFrame:
def apply(self, input_df: DataFrame, return_all_cols: bool = False) -> DataFrame:
col_datatype = input_df.schema[self.multi_column].dataType
if return_all_cols:
original_cols = {*input_df.columns} - {self.multi_column}
else:
original_cols = {}
is_array_col = False
if col_datatype.typeName() == "array":
assert isinstance(col_datatype, ArrayType)
Expand All @@ -326,13 +330,19 @@ def apply(self, input_df: DataFrame) -> DataFrame:

is_array_col = True

# Parquet input might come with arrays already, CSV will need splitting
if is_array_col:
list_df = input_df.select(self.multi_column).alias(self.multi_column)
multi_column = F.col(self.multi_column)
else:
list_df = input_df.select(
F.split(F.col(self.multi_column), self.separator).alias(self.multi_column)
multi_column = F.split(F.col(self.multi_column), self.separator).alias(
self.multi_column
)

list_df = input_df.select(
multi_column,
*original_cols,
)

distinct_category_counts = (
list_df.withColumn(SINGLE_CATEGORY_COL, F.explode(F.col(self.multi_column)))
.groupBy(SINGLE_CATEGORY_COL)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(self, cols: Sequence[str], spark: SparkSession) -> None:

def apply(self, input_df: DataFrame) -> DataFrame:
assert self.spark
original_cols = {*input_df.columns} - {self.label_column}
processed_col_name = self.label_column + "_processed"

str_indexer = StringIndexer(
Expand All @@ -63,13 +64,15 @@ def apply(self, input_df: DataFrame) -> DataFrame:

# Labels that were missing and were assigned the value numLabels by the StringIndexer
# are converted to None
long_class_label = indexed_df.select(F.col(self.label_column).cast("long")).select(
long_class_label = indexed_df.select(
F.when(
F.col(self.label_column) == len(str_indexer_model.labelsArray[0]), # type: ignore
F.lit(None),
)
.otherwise(F.col(self.label_column))
.alias(self.label_column)
.cast("long")
.alias(self.label_column),
*original_cols,
)

# Get a mapping from original label to encoded value
Expand Down Expand Up @@ -112,7 +115,7 @@ def __init__(self, cols: Sequence[str], separator: str) -> None:
super().__init__(cols, separator)
self.label_column = cols[0]

def apply(self, input_df: DataFrame) -> DataFrame:
multi_cat_df = super().apply(input_df)
def apply(self, input_df: DataFrame, return_all_cols=True) -> DataFrame:
multi_cat_df = super().apply(input_df, return_all_cols=return_all_cols)

return multi_cat_df
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ def __init__(
output_prefix=self.output_prefix,
precomputed_transformations=self.precomputed_transformations,
)
print(f"{self.output_prefix}")
print(f"{loader_config.output_prefix}")
self.loader = DistHeterogeneousGraphLoader(
self.spark,
loader_config,
Expand All @@ -287,17 +289,18 @@ def _upload_output_files(self, loader: DistHeterogeneousGraphLoader, force=False
bucket, s3_prefix = s3_utils.extract_bucket_and_key(self.output_prefix)
s3 = boto3.resource("s3")

output_files = os.listdir(loader.output_path)
output_files = os.listdir(loader.local_meta_output_path)
for output_file in output_files:
s3.meta.client.upload_file(
f"{os.path.join(loader.output_path, output_file)}",
f"{os.path.join(loader.local_meta_output_path, output_file)}",
bucket,
f"{s3_prefix}/{output_file}",
)

def run(self) -> None:
"""
Executes the Spark processing job.
Executes the Spark processing job, optional repartition job, and uploads any metadata files
if needed.
"""
logging.info("Performing data processing with PySpark...")

Expand Down Expand Up @@ -355,7 +358,7 @@ def run(self) -> None:
# If any of the metadata modification took place, write an updated metadata file
if updated_metadata:
updated_meta_path = os.path.join(
self.loader.output_path, "updated_row_counts_metadata.json"
self.loader.local_meta_output_path, "updated_row_counts_metadata.json"
)
with open(
updated_meta_path,
Expand Down Expand Up @@ -614,6 +617,7 @@ def main():
level=gsprocessing_args.log_level,
format="[GSPROCESSING] %(asctime)s %(levelname)-8s %(message)s",
)
print(f"{gsprocessing_args.output_prefix=}")

# Determine execution environment
if os.path.exists("/opt/ml/config/processingjobconfig.json"):
Expand Down Expand Up @@ -715,6 +719,8 @@ def main():
do_repartition=gsprocessing_args.do_repartition,
)

print(f"{executor_configuration.output_prefix=}")

dist_executor = DistributedExecutor(executor_configuration)

dist_executor.run()
Expand Down
Loading

0 comments on commit 4d0afcd

Please sign in to comment.