diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index a3437b18..f3e18a3b 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -895,7 +895,7 @@ def agg( 2. Group-based UDF function input: Instead of individual rows, the function receives a list all rows within each group defined by `partition_by`. - Example: + Examples: ```py chain = chain.agg( total=lambda category, amount: [sum(amount)], @@ -904,6 +904,26 @@ def agg( ) chain.save("new_dataset") ``` + + An alternative syntax, when you need to specify a more complex function: + + ```py + # It automatically resolves which columns to pass to the function + # by looking at the function signature. + def agg_sum( + file: list[File], amount: list[float] + ) -> Iterator[tuple[File, float]]: + yield file[0], sum(amount) + + chain = chain.agg( + agg_sum, + output={"file": File, "total": float}, + # Alternative syntax is to use `C` (short for Column) to specify + # a column name or a nested column, e.g. C("file.path"). + partition_by=C("category"), + ) + chain.save("new_dataset") + ``` """ udf_obj = self._udf_to_obj(Aggregator, func, params, output, signal_map) return self._evolve(