Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEA add link for HTML representation #1051

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion skrub/_gap_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,18 @@

from . import _dataframe as sbd
from ._on_each_column import RejectColumn, SingleColumnTransformer
from ._repr import (
_HTMLDocumentationLinkMixin,
doc_link_module,
doc_link_template,
doc_link_url_param_generator,
)
from ._utils import unique_strings


class GapEncoder(TransformerMixin, SingleColumnTransformer):
class GapEncoder(
_HTMLDocumentationLinkMixin, TransformerMixin, SingleColumnTransformer
):
"""Constructs latent topics with continuous encoding.

This encoder can be understood as a continuous encoding on a set of latent
Expand Down Expand Up @@ -177,6 +185,10 @@ class GapEncoder(TransformerMixin, SingleColumnTransformer):
The higher the value, the bigger the correspondence with the topic.
"""

_doc_link_module = doc_link_module
_doc_link_template = doc_link_template
_doc_link_url_param_generator = doc_link_url_param_generator

def __init__(
self,
n_components=10,
Expand Down
103 changes: 103 additions & 0 deletions skrub/_repr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import itertools
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we name this module something more explicit like "_sklearn_html_repr_utils" or similar?


import sklearn
from sklearn.utils.fixes import parse_version

sklearn_version = parse_version(sklearn.__version__)

# TODO: remove when scikit-learn 1.6 is the minimum supported version
# TODO: subsequently, we should remove the inheritance from _HTMLDocumentationLinkMixin
# for each estimator then.
if sklearn_version > parse_version("1.6"):
from sklearn.utils._estimator_html_repr import _HTMLDocumentationLinkMixin

Check warning on line 12 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L12

Added line #L12 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC we will now have 2 of those in every skrub estimator's parents, one directly and one through the BaseEstimator?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. The one on the left in the inheritance will be the one used.

else:

class _HTMLDocumentationLinkMixin:
"""Mixin class allowing to generate a link to the API documentation."""

_doc_link_module = "sklearn"
_doc_link_url_param_generator = None

@property
def _doc_link_template(self):
sklearn_version = parse_version(sklearn.__version__)

Check warning on line 23 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L23

Added line #L23 was not covered by tests
if sklearn_version.dev is None:
version_url = f"{sklearn_version.major}.{sklearn_version.minor}"

Check warning on line 25 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L25

Added line #L25 was not covered by tests
else:
version_url = "dev"
return getattr(

Check warning on line 28 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L27-L28

Added lines #L27 - L28 were not covered by tests
self,
"__doc_link_template",
(
f"https://scikit-learn.org/{version_url}/modules/generated/"
"{estimator_module}.{estimator_name}.html"
),
)

@_doc_link_template.setter
def _doc_link_template(self, value):
setattr(self, "__doc_link_template", value)

Check warning on line 39 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L39

Added line #L39 was not covered by tests

def _get_doc_link(self):
"""Generates a link to the API documentation for a given estimator.

This method generates the link to the estimator's documentation page
by using the template defined by the attribute `_doc_link_template`.

Returns
-------
url : str
The URL to the API documentation for this estimator. If the estimator
does not belong to module `_doc_link_module`, the empty string (i.e.
`""`) is returned.
"""
if self.__class__.__module__.split(".")[0] != self._doc_link_module:
return ""

Check warning on line 55 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L55

Added line #L55 was not covered by tests

if self._doc_link_url_param_generator is None:
estimator_name = self.__class__.__name__

Check warning on line 58 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L58

Added line #L58 was not covered by tests
# Construct the estimator's module name, up to the first private
# submodule. This works because in scikit-learn all public estimators
# are exposed at that level, even if they actually live in a private
# sub-module.
estimator_module = ".".join(
itertools.takewhile(
lambda part: not part.startswith("_"),
self.__class__.__module__.split("."),
)
)
return self._doc_link_template.format(

Check warning on line 69 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L69

Added line #L69 was not covered by tests
estimator_module=estimator_module, estimator_name=estimator_name
)
return self._doc_link_template.format(

Check warning on line 72 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L72

Added line #L72 was not covered by tests
**self._doc_link_url_param_generator()
)


doc_link_template = (
"https://skrub-data.org/{version}/reference/generated/"
"{estimator_module}.{estimator_name}.html"
)
doc_link_module = "skrub"


def doc_link_url_param_generator(estimator):
from skrub import __version__

Check warning on line 85 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L85

Added line #L85 was not covered by tests

skrub_version = parse_version(__version__)

Check warning on line 87 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L87

Added line #L87 was not covered by tests
if skrub_version.dev is None:
version_url = f"{skrub_version.major}.{skrub_version.minor}"

Check warning on line 89 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L89

Added line #L89 was not covered by tests
else:
version_url = "dev"
estimator_name = estimator.__class__.__name__

Check warning on line 92 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L91-L92

Added lines #L91 - L92 were not covered by tests
estimator_module = ".".join(
itertools.takewhile(
lambda part: not part.startswith("_"),
estimator.__class__.__module__.split("."),
)
)
return {

Check warning on line 99 in skrub/_repr.py

View check run for this annotation

Codecov / codecov/patch

skrub/_repr.py#L99

Added line #L99 was not covered by tests
"version": version_url,
"estimator_module": estimator_module,
"estimator_name": estimator_name,
}
12 changes: 11 additions & 1 deletion skrub/_table_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
from ._datetime_encoder import DatetimeEncoder
from ._gap_encoder import GapEncoder
from ._on_each_column import SingleColumnTransformer
from ._repr import (
_HTMLDocumentationLinkMixin,
doc_link_module,
doc_link_template,
doc_link_url_param_generator,
)
from ._select_cols import Drop
from ._to_datetime import ToDatetime
from ._to_float32 import ToFloat32
Expand Down Expand Up @@ -110,7 +116,7 @@ def _check_transformer(transformer):
return clone(transformer)


class TableVectorizer(TransformerMixin, BaseEstimator):
class TableVectorizer(_HTMLDocumentationLinkMixin, TransformerMixin, BaseEstimator):
"""Transform a dataframe to a numerical (vectorized) representation.

Applies a different transformation to each of several kinds of columns:
Expand Down Expand Up @@ -405,6 +411,10 @@ class TableVectorizer(TransformerMixin, BaseEstimator):
ValueError: Column 'A' used twice in 'specific_transformers', at indices 0 and 1.
""" # noqa: E501

_doc_link_module = doc_link_module
_doc_link_template = doc_link_template
_doc_link_url_param_generator = doc_link_url_param_generator

def __init__(
self,
*,
Expand Down