Skip to content

Commit

Permalink
Add column name extraction to sklearn example.
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewwardrop committed Mar 9, 2024
1 parent 34c667e commit ffcf9c8
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions docsite/docs/guides/integration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@
}
],
"source": [
"from typing import Optional\n",
"from typing import Iterable, List, Optional\n",
"\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.pipeline import Pipeline\n",
Expand All @@ -158,7 +158,7 @@
"\n",
" def fit(self, X, y = None):\n",
" \"\"\"\n",
" Generate the initial model spec by which subsequent X's will be \n",
" Generate the initial model spec by which subsequent X's will be\n",
" transformed.\n",
" \"\"\"\n",
" self.model_spec = self.formula.get_model_matrix(X).model_spec\n",
Expand All @@ -174,6 +174,14 @@
" X_ = self.model_spec.get_model_matrix(X)\n",
" return X_\n",
"\n",
" def get_feature_names_out(self, input_features: Optional[Iterable[str]] = None) -> List[str]:\n",
" \"\"\"\n",
" Expose model spec column names to scikit learn to allow column transforms later in the pipeline.\n",
" \"\"\"\n",
" if self.model_spec is None:\n",
" raise RuntimeError(\"`FormulaicTransformer.fit()` must be called before columns can be assigned names.\")\n",
" return self.model_spec.column_names\n",
"\n",
"\n",
"pipe = Pipeline([\n",
" (\"formula\", FormulaicTransformer(\"x1 + x2 + x3\")),\n",
Expand Down

0 comments on commit ffcf9c8

Please sign in to comment.