diff --git a/docsite/docs/guides/integration.ipynb b/docsite/docs/guides/integration.ipynb index c1c3554..6f9419d 100644 --- a/docsite/docs/guides/integration.ipynb +++ b/docsite/docs/guides/integration.ipynb @@ -140,7 +140,7 @@ } ], "source": [ - "from typing import Optional\n", + "from typing import Iterable, List, Optional\n", "\n", "from sklearn.base import BaseEstimator, TransformerMixin\n", "from sklearn.pipeline import Pipeline\n", @@ -158,7 +158,7 @@ "\n", " def fit(self, X, y = None):\n", " \"\"\"\n", - " Generate the initial model spec by which subsequent X's will be \n", + " Generate the initial model spec by which subsequent X's will be\n", " transformed.\n", " \"\"\"\n", " self.model_spec = self.formula.get_model_matrix(X).model_spec\n", @@ -174,6 +174,14 @@ " X_ = self.model_spec.get_model_matrix(X)\n", " return X_\n", "\n", + " def get_feature_names_out(self, input_features: Optional[Iterable[str]] = None) -> List[str]:\n", + " \"\"\"\n", + " Expose model spec column names to scikit learn to allow column transforms later in the pipeline.\n", + " \"\"\"\n", + " if self.model_spec is None:\n", + " raise RuntimeError(\"`FormulaicTransformer.fit()` must be called before columns can be assigned names.\")\n", + " return self.model_spec.column_names\n", + "\n", "\n", "pipe = Pipeline([\n", " (\"formula\", FormulaicTransformer(\"x1 + x2 + x3\")),\n",