diff --git a/doc/sources/algorithms.rst b/doc/sources/algorithms.rst index 6a73ee2b96..d7907e02f7 100755 --- a/doc/sources/algorithms.rst +++ b/doc/sources/algorithms.rst @@ -159,6 +159,9 @@ Dimensionality Reduction - ``svd_solver`` not in [`'full'`, `'covariance_eigh'`] - Sparse data is not supported + * - `IncrementalPCA` + - All parameters are supported + - Sparse data is not supported * - `TSNE` - All parameters are supported except: diff --git a/doc/sources/conf.py b/doc/sources/conf.py index 810c6534d6..65c44f4c87 100755 --- a/doc/sources/conf.py +++ b/doc/sources/conf.py @@ -67,6 +67,7 @@ "notfound.extension", "sphinx_design", "sphinx_copybutton", + "sphinx.ext.napoleon", ] # Add any paths that contain templates here, relative to this directory. diff --git a/doc/sources/index.rst b/doc/sources/index.rst index 62055385c8..b4734ca257 100755 --- a/doc/sources/index.rst +++ b/doc/sources/index.rst @@ -105,6 +105,7 @@ Enable Intel(R) GPU optimizations algorithms.rst oneAPI and GPU support distributed-mode.rst + non-scikit-algorithms.rst verbose.rst deprecation.rst diff --git a/doc/sources/non-scikit-algorithms.rst b/doc/sources/non-scikit-algorithms.rst new file mode 100644 index 0000000000..620461843f --- /dev/null +++ b/doc/sources/non-scikit-algorithms.rst @@ -0,0 +1,44 @@ +.. ****************************************************************************** +.. * Copyright 2024 Intel Corporation +.. * +.. * Licensed under the Apache License, Version 2.0 (the "License"); +.. * you may not use this file except in compliance with the License. +.. * You may obtain a copy of the License at +.. * +.. * http://www.apache.org/licenses/LICENSE-2.0 +.. * +.. * Unless required by applicable law or agreed to in writing, software +.. * distributed under the License is distributed on an "AS IS" BASIS, +.. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +.. * See the License for the specific language governing permissions and +.. * limitations under the License. +.. *******************************************************************************/ + +Non-Scikit-Learn Algorithms +=========================== +Algorithms not presented in the original scikit-learn are described here. All algorithms are +available for both CPU and GPU (including distributed mode) + +BasicStatistics +--------------- +.. autoclass:: sklearnex.basic_statistics.BasicStatistics +.. automethod:: sklearnex.basic_statistics.BasicStatistics.fit + +IncrementalBasicStatistics +-------------------------- +.. autoclass:: sklearnex.basic_statistics.IncrementalBasicStatistics +.. automethod:: sklearnex.basic_statistics.IncrementalBasicStatistics.fit +.. automethod:: sklearnex.basic_statistics.IncrementalBasicStatistics.partial_fit + +IncrementalEmpiricalCovariance +------------------------------ +.. autoclass:: sklearnex.covariance.IncrementalEmpiricalCovariance +.. automethod:: sklearnex.covariance.IncrementalEmpiricalCovariance.fit +.. automethod:: sklearnex.covariance.IncrementalEmpiricalCovariance.partial_fit + +IncrementalLinearRegression +--------------------------- +.. autoclass:: sklearnex.linear_model.IncrementalLinearRegression +.. automethod:: sklearnex.linear_model.IncrementalLinearRegression.fit +.. automethod:: sklearnex.linear_model.IncrementalLinearRegression.partial_fit +.. automethod:: sklearnex.linear_model.IncrementalLinearRegression.predict \ No newline at end of file diff --git a/sklearnex/basic_statistics/basic_statistics.py b/sklearnex/basic_statistics/basic_statistics.py index 277934dd98..af936870d8 100644 --- a/sklearnex/basic_statistics/basic_statistics.py +++ b/sklearnex/basic_statistics/basic_statistics.py @@ -42,36 +42,63 @@ class BasicStatistics(BaseEstimator): """ Estimator for basic statistics. Allows to compute basic statistics for provided data. - Note, some results can exhibit small variations due to - floating point error accumulation and multithreading. Parameters ---------- result_options: string or list, default='all' - List of statistics to compute + Used to set statistics to calculate. Possible values are ``'min'``, ``'max'``, ``'sum'``, ``'mean'``, ``'variance'``, + ``'variation'``, ``sum_squares'``, ``sum_squares_centered'``, ``'standard_deviation'``, ``'second_order_raw_moment'`` + or a list containing any of these values. If set to ``'all'`` then all possible statistics will be + calculated. - Attributes (are existing only if corresponding result option exists) + Attributes ---------- - min : ndarray of shape (n_features,) + min_ : ndarray of shape (n_features,) Minimum of each feature over all samples. - max : ndarray of shape (n_features,) + max_ : ndarray of shape (n_features,) Maximum of each feature over all samples. - sum : ndarray of shape (n_features,) + sum_ : ndarray of shape (n_features,) Sum of each feature over all samples. - mean : ndarray of shape (n_features,) + mean_ : ndarray of shape (n_features,) Mean of each feature over all samples. - variance : ndarray of shape (n_features,) + variance_ : ndarray of shape (n_features,) Variance of each feature over all samples. - variation : ndarray of shape (n_features,) + variation_ : ndarray of shape (n_features,) Variation of each feature over all samples. - sum_squares : ndarray of shape (n_features,) + sum_squares_ : ndarray of shape (n_features,) Sum of squares for each feature over all samples. - standard_deviation : ndarray of shape (n_features,) + standard_deviation_ : ndarray of shape (n_features,) Standard deviation of each feature over all samples. - sum_squares_centered : ndarray of shape (n_features,) + sum_squares_centered_ : ndarray of shape (n_features,) Centered sum of squares for each feature over all samples. - second_order_raw_moment : ndarray of shape (n_features,) + second_order_raw_moment_ : ndarray of shape (n_features,) Second order moment of each feature over all samples. + + Note + ---- + Attribute exists only if corresponding result option has been provided. + + Note + ---- + Attributes' names without the trailing underscore are + supported currently but deprecated in 2025.1 and will be removed in 2026.0 + + Note + ---- + Some results can exhibit small variations due to + floating point error accumulation and multithreading. + + Examples + -------- + >>> import numpy as np + >>> from sklearnex.basic_statistics import BasicStatistics + >>> bs = BasicStatistics(result_options=['sum', 'min', 'max']) + >>> X = np.array([[1, 2], [3, 4]]) + >>> bs.fit(X) + >>> bs.sum_ + np.array([4., 6.]) + >>> bs.min_ + np.array([1., 2.]) """ def __init__(self, result_options="all"): @@ -176,14 +203,14 @@ def fit(self, X, y=None, *, sample_weight=None): Parameters ---------- X : array-like of shape (n_samples, n_features) - Data for compute, where `n_samples` is the number of samples and - `n_features` is the number of features. + Data for compute, where ``n_samples`` is the number of samples and + ``n_features`` is the number of features. y : Ignored Not used, present for API consistency by convention. sample_weight : array-like of shape (n_samples,), default=None - Weights for compute weighted statistics, where `n_samples` is the number of samples. + Weights for compute weighted statistics, where ``n_samples`` is the number of samples. Returns ------- diff --git a/sklearnex/basic_statistics/incremental_basic_statistics.py b/sklearnex/basic_statistics/incremental_basic_statistics.py index a9bb01637b..ae8db61ffd 100644 --- a/sklearnex/basic_statistics/incremental_basic_statistics.py +++ b/sklearnex/basic_statistics/incremental_basic_statistics.py @@ -43,8 +43,10 @@ @control_n_jobs(decorated_methods=["partial_fit", "_onedal_finalize_fit"]) class IncrementalBasicStatistics(BaseEstimator): """ - Incremental estimator for basic statistics. - Allows to compute basic statistics if data are splitted into batches. + Calculates basic statistics on the given data, allows for computation when the data are split into + batches. The user can use ``partial_fit`` method to provide a single batch of data or use the ``fit`` method to provide + the entire dataset. + Parameters ---------- result_options: string or list, default='all' @@ -53,40 +55,76 @@ class IncrementalBasicStatistics(BaseEstimator): batch_size : int, default=None The number of samples to use for each batch. Only used when calling ``fit``. If ``batch_size`` is ``None``, then ``batch_size`` - is inferred from the data and set to ``5 * n_features``, to provide a - balance between approximation accuracy and memory consumption. + is inferred from the data and set to ``5 * n_features``. - Attributes (are existing only if corresponding result option exists) + Attributes ---------- - min : ndarray of shape (n_features,) + min_ : ndarray of shape (n_features,) Minimum of each feature over all samples. - max : ndarray of shape (n_features,) + max_ : ndarray of shape (n_features,) Maximum of each feature over all samples. - sum : ndarray of shape (n_features,) + sum_ : ndarray of shape (n_features,) Sum of each feature over all samples. - mean : ndarray of shape (n_features,) + mean_ : ndarray of shape (n_features,) Mean of each feature over all samples. - variance : ndarray of shape (n_features,) + variance_ : ndarray of shape (n_features,) Variance of each feature over all samples. - variation : ndarray of shape (n_features,) + variation_ : ndarray of shape (n_features,) Variation of each feature over all samples. - sum_squares : ndarray of shape (n_features,) + sum_squares_ : ndarray of shape (n_features,) Sum of squares for each feature over all samples. - standard_deviation : ndarray of shape (n_features,) + standard_deviation_ : ndarray of shape (n_features,) Standard deviation of each feature over all samples. - sum_squares_centered : ndarray of shape (n_features,) + sum_squares_centered_ : ndarray of shape (n_features,) Centered sum of squares for each feature over all samples. - second_order_raw_moment : ndarray of shape (n_features,) + second_order_raw_moment_ : ndarray of shape (n_features,) Second order moment of each feature over all samples. + + n_samples_seen_ : int + The number of samples processed by the estimator. Will be reset on + new calls to ``fit``, but increments across ``partial_fit`` calls. + + batch_size_ : int + Inferred batch size from ``batch_size``. + + n_features_in_ : int + Number of features seen during ``fit`` or ``partial_fit``. + + Note + ---- + Attribute exists only if corresponding result option has been provided. + + Note + ---- + Attributes' names without the trailing underscore are + supported currently but deprecated in 2025.1 and will be removed in 2026.0 + + Examples + -------- + >>> import numpy as np + >>> from sklearnex.basic_statistics import IncrementalBasicStatistics + >>> incbs = IncrementalBasicStatistics(batch_size=1) + >>> X = np.array([[1, 2], [3, 4]]) + >>> incbs.partial_fit(X[:1]) + >>> incbs.partial_fit(X[1:]) + >>> incbs.sum_ + np.array([4., 6.]) + >>> incbs.min_ + np.array([1., 2.]) + >>> incbs.fit(X) + >>> incbs.sum_ + np.array([4., 6.]) + >>> incbs.max_ + np.array([3., 4.]) """ _onedal_incremental_basic_statistics = staticmethod(onedal_IncrementalBasicStatistics) @@ -244,17 +282,17 @@ def partial_fit(self, X, sample_weight=None, check_input=True): Parameters ---------- X : array-like of shape (n_samples, n_features) - Data for compute, where `n_samples` is the number of samples and - `n_features` is the number of features. + Data for compute, where ``n_samples`` is the number of samples and + ``n_features`` is the number of features. y : Ignored Not used, present for API consistency by convention. sample_weight : array-like of shape (n_samples,), default=None - Weights for compute weighted statistics, where `n_samples` is the number of samples. + Weights for compute weighted statistics, where ``n_samples`` is the number of samples. check_input : bool, default=True - Run check_array on X. + Run ``check_array`` on X. Returns ------- @@ -280,14 +318,14 @@ def fit(self, X, y=None, sample_weight=None): Parameters ---------- X : array-like of shape (n_samples, n_features) - Data for compute, where `n_samples` is the number of samples and - `n_features` is the number of features. + Data for compute, where ``n_samples`` is the number of samples and + ``n_features`` is the number of features. y : Ignored Not used, present for API consistency by convention. sample_weight : array-like of shape (n_samples,), default=None - Weights for compute weighted statistics, where `n_samples` is the number of samples. + Weights for compute weighted statistics, where ``n_samples`` is the number of samples. Returns ------- diff --git a/sklearnex/covariance/incremental_covariance.py b/sklearnex/covariance/incremental_covariance.py index 36cc936d91..f248a0b0f6 100644 --- a/sklearnex/covariance/incremental_covariance.py +++ b/sklearnex/covariance/incremental_covariance.py @@ -49,9 +49,9 @@ @control_n_jobs(decorated_methods=["partial_fit", "fit", "_onedal_finalize_fit"]) class IncrementalEmpiricalCovariance(BaseEstimator): """ - Incremental estimator for covariance. - Allows to compute empirical covariance estimated by maximum - likelihood method if data are splitted into batches. + Maximum likelihood covariance estimator that allows for the estimation when the data are split into + batches. The user can use the ``partial_fit`` method to provide a single batch of data or use the ``fit`` method to provide + the entire dataset. Parameters ---------- @@ -84,13 +84,31 @@ class IncrementalEmpiricalCovariance(BaseEstimator): n_samples_seen_ : int The number of samples processed by the estimator. Will be reset on - new calls to fit, but increments across ``partial_fit`` calls. + new calls to ``fit``, but increments across ``partial_fit`` calls. batch_size_ : int Inferred batch size from ``batch_size``. n_features_in_ : int - Number of features seen during :term:`fit` `partial_fit`. + Number of features seen during ``fit`` or ``partial_fit``. + + Examples + -------- + >>> import numpy as np + >>> from sklearnex.covariance import IncrementalEmpiricalCovariance + >>> inccov = IncrementalEmpiricalCovariance(batch_size=1) + >>> X = np.array([[1, 2], [3, 4]]) + >>> inccov.partial_fit(X[:1]) + >>> inccov.partial_fit(X[1:]) + >>> inccov.covariance_ + np.array([[1., 1.],[1., 1.]]) + >>> inccov.location_ + np.array([2., 3.]) + >>> inccov.fit(X) + >>> inccov.covariance_ + np.array([[1., 1.],[1., 1.]]) + >>> inccov.location_ + np.array([2., 3.]) """ _onedal_incremental_covariance = staticmethod(onedal_IncrementalEmpiricalCovariance) diff --git a/sklearnex/linear_model/incremental_linear.py b/sklearnex/linear_model/incremental_linear.py index 066f22b37f..5538a17c4b 100644 --- a/sklearnex/linear_model/incremental_linear.py +++ b/sklearnex/linear_model/incremental_linear.py @@ -54,8 +54,9 @@ ) class IncrementalLinearRegression(MultiOutputMixin, RegressorMixin, BaseEstimator): """ - Incremental estimator for linear regression. - Allows to train linear regression if data are splitted into batches. + Trains a linear regression model, allows for computation if the data are split into + batches. The user can use the ``partial_fit`` method to provide a single batch of data or use the ``fit`` method to provide + the entire dataset. Parameters ---------- @@ -73,8 +74,7 @@ class IncrementalLinearRegression(MultiOutputMixin, RegressorMixin, BaseEstimato batch_size : int, default=None The number of samples to use for each batch. Only used when calling ``fit``. If ``batch_size`` is ``None``, then ``batch_size`` - is inferred from the data and set to ``5 * n_features``, to provide a - balance between approximation accuracy and memory consumption. + is inferred from the data and set to ``5 * n_features``. Attributes ---------- @@ -88,12 +88,9 @@ class IncrementalLinearRegression(MultiOutputMixin, RegressorMixin, BaseEstimato Independent term in the linear model. Set to 0.0 if `fit_intercept = False`. - n_features_in_ : int - Number of features seen during :term:`fit`. - n_samples_seen_ : int The number of samples processed by the estimator. Will be reset on - new calls to fit, but increments across ``partial_fit`` calls. + new calls to ``fit``, but increments across ``partial_fit`` calls. It should be not less than `n_features_in_` if `fit_intercept` is False and not less than `n_features_in_` + 1 if `fit_intercept` is True to obtain regression coefficients. @@ -102,8 +99,26 @@ class IncrementalLinearRegression(MultiOutputMixin, RegressorMixin, BaseEstimato Inferred batch size from ``batch_size``. n_features_in_ : int - Number of features seen during :term:`fit` `partial_fit`. - + Number of features seen during ``fit`` or ``partial_fit``. + + Examples + -------- + >>> import numpy as np + >>> from sklearnex.linear_model import IncrementalLinearRegression + >>> inclr = IncrementalLinearRegression(batch_size=2) + >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 10]]) + >>> y = np.array([1.5, 3.5, 5.5, 8.5]) + >>> inclr.partial_fit(X[:2], y[:2]) + >>> inclr.partial_fit(X[2:], y[2:]) + >>> inclr.coef_ + np.array([0.5., 0.5.]) + >>> inclr.intercept_ + np.array(0.) + >>> inclr.fit(X) + >>> inclr.coef_ + np.array([0.5., 0.5.]) + >>> inclr.intercept_ + np.array(0.) """ _onedal_incremental_linear = staticmethod(onedal_IncrementalLinearRegression) @@ -274,7 +289,8 @@ def _onedal_fit(self, X, y, queue=None): self._onedal_finalize_fit(queue=queue) return self - def get_intercept_(self): + @property + def intercept_(self): if hasattr(self, "_onedal_estimator"): if self._need_to_finalize: self._onedal_finalize_fit() @@ -285,13 +301,15 @@ def get_intercept_(self): f"'{self.__class__.__name__}' object has no attribute 'intercept_'" ) - def set_intercept_(self, value): + @intercept_.setter + def intercept_(self, value): self.__dict__["intercept_"] = value if hasattr(self, "_onedal_estimator"): self._onedal_estimator.intercept_ = value del self._onedal_estimator._onedal_model - def get_coef_(self): + @property + def coef_(self): if hasattr(self, "_onedal_estimator"): if self._need_to_finalize: self._onedal_finalize_fit() @@ -302,15 +320,13 @@ def get_coef_(self): f"'{self.__class__.__name__}' object has no attribute 'coef_'" ) - def set_coef_(self, value): + @coef_.setter + def coef_(self, value): self.__dict__["coef_"] = value if hasattr(self, "_onedal_estimator"): self._onedal_estimator.coef_ = value del self._onedal_estimator._onedal_model - coef_ = property(get_coef_, set_coef_) - intercept_ = property(get_intercept_, set_intercept_) - def partial_fit(self, X, y, check_input=True): """ Incremental fit linear model with X and y. All of X and y is @@ -319,12 +335,12 @@ def partial_fit(self, X, y, check_input=True): Parameters ---------- X : array-like of shape (n_samples, n_features) - Training data, where `n_samples` is the number of samples and + Training data, where ``n_samples`` is the number of samples and `n_features` is the number of features. y : array-like of shape (n_samples,) or (n_samples, n_targets) - Target values, where `n_samples` is the number of samples and - `n_targets` is the number of targets. + Target values, where ``n_samples`` is the number of samples and + ``n_targets`` is the number of targets. Returns ------- @@ -347,20 +363,20 @@ def partial_fit(self, X, y, check_input=True): def fit(self, X, y): """ - Fit the model with X and y, using minibatches of size batch_size. + Fit the model with X and y, using minibatches of size ``batch_size``. Parameters ---------- X : array-like of shape (n_samples, n_features) - Training data, where `n_samples` is the number of samples and - `n_features` is the number of features. It is necessary for - `n_samples` to be not less than `n_features` if `fit_intercept` - is False and not less than `n_features` + 1 if `fit_intercept` + Training data, where ``n_samples`` is the number of samples and + ``n_features`` is the number of features. It is necessary for + ``n_samples`` to be not less than ``n_features`` if ``fit_intercept`` + is False and not less than ``n_features + 1`` if ``fit_intercept`` is True y : array-like of shape (n_samples,) or (n_samples, n_targets) - Target values, where `n_samples` is the number of samples and - `n_targets` is the number of targets. + Target values, where ``n_samples`` is the number of samples and + ``n_targets`` is the number of targets. Returns ------- @@ -384,10 +400,15 @@ def fit(self, X, y): def predict(self, X, y=None): """ Predict using the linear model. + Parameters ---------- X : array-like or sparse matrix, shape (n_samples, n_features) Samples. + + y : Ignored + Not used, present for API consistency by convention. + Returns ------- C : array, shape (n_samples, n_targets) diff --git a/sklearnex/linear_model/logistic_regression.py b/sklearnex/linear_model/logistic_regression.py index f0ce1c3913..df829a5c0e 100644 --- a/sklearnex/linear_model/logistic_regression.py +++ b/sklearnex/linear_model/logistic_regression.py @@ -66,7 +66,6 @@ def _save_attributes(self): ) class LogisticRegression(_sklearn_LogisticRegression, BaseLogisticRegression): __doc__ = _sklearn_LogisticRegression.__doc__ - intercept_, coef_, n_iter_ = None, None, None if sklearn_check_version("1.2"): _parameter_constraints: dict = { @@ -238,9 +237,6 @@ def _onedal_gpu_predict_supported(self, method_name, *data): ) n_samples = _num_samples(data[0]) - model_is_sparse = issparse(self.coef_) or ( - self.fit_intercept and issparse(self.intercept_) - ) dal_ready = patching_status.and_conditions( [ (n_samples > 0, "Number of samples is less than 1."), @@ -248,7 +244,6 @@ def _onedal_gpu_predict_supported(self, method_name, *data): (not any([issparse(i) for i in data])) or _sparsity_enabled, "Sparse input is not supported.", ), - (not model_is_sparse, "Sparse coefficients are not supported."), ( hasattr(self, "_onedal_estimator"), "oneDAL model was not trained.", diff --git a/sklearnex/tests/test_common.py b/sklearnex/tests/test_common.py index 41c65a2266..f427dfb982 100644 --- a/sklearnex/tests/test_common.py +++ b/sklearnex/tests/test_common.py @@ -146,6 +146,24 @@ def wrap(*args, **kwargs): return wrap +def test_class_trailing_underscore_ban(monkeypatch): + """Trailing underscores are defined for sklearn to be signatures of a fitted + estimator instance, sklearnex extends this to the classes as well""" + monkeypatch.setattr(pkgutil, "walk_packages", _sklearnex_walk(pkgutil.walk_packages)) + estimators = all_estimators() # list of tuples + for name, obj in estimators: + if "preview" not in obj.__module__ and "daal4py" not in obj.__module__: + # propeties also occur in sklearn, especially in deprecations and are expected + # to error if queried and the estimator is not fitted + assert all( + [ + isinstance(getattr(obj, attr), property) + or (attr.startswith("_") or not attr.endswith("_")) + for attr in dir(obj) + ] + ), f"{name} contains class attributes which have a trailing underscore but no leading one" + + def test_all_estimators_covered(monkeypatch): """Check that all estimators defined in sklearnex are available in either the patch map or covered in special testing via SPECIAL_INSTANCES. The estimator @@ -288,12 +306,12 @@ def estimator_trace(estimator, method, cache, capsys, monkeypatch): cache.set("key", key) cache.set( "text", - [ - re.findall(regex_func, text), - text, - [i.replace(os.sep, ".") for i in re.findall(regex_mod, text)], - [""] + re.findall(regex_callingline, text), - ], + { + "funcs": re.findall(regex_func, text), + "trace": text, + "modules": [i.replace(os.sep, ".") for i in re.findall(regex_mod, text)], + "callingline": [""] + re.findall(regex_callingline, text), + }, ) return cache.get("text", None) @@ -304,8 +322,8 @@ def call_validate_data(text, estimator, method): called once before offloading to oneDAL in sklearnex""" try: # get last to_table call showing end of oneDAL input portion of code - idx = len(text[0]) - 1 - text[0][::-1].index("to_table") - validfuncs = text[0][:idx] + idx = len(text["funcs"]) - 1 - text["funcs"][::-1].index("to_table") + validfuncs = text["funcs"][:idx] except ValueError: pytest.skip("onedal backend not used in this function") @@ -323,16 +341,17 @@ def n_jobs_check(text, estimator, method): """verify the n_jobs is being set if '_get_backend' or 'to_table' is called""" # remove the _get_backend function from sklearnex from considered _get_backend count = max( - text[0].count("to_table"), + text["funcs"].count("to_table"), len( [ i - for i in range(len(text[0])) - if text[0][i] == "_get_backend" and "sklearnex" not in text[2][i] + for i in range(len(text["funcs"])) + if text["funcs"][i] == "_get_backend" + and "sklearnex" not in text["modules"][i] ] ), ) - n_jobs_count = text[0].count("n_jobs_wrapper") + n_jobs_count = text["funcs"].count("n_jobs_wrapper") assert bool(count) == bool( n_jobs_count @@ -342,7 +361,7 @@ def n_jobs_check(text, estimator, method): def runtime_property_check(text, estimator, method): """use of Python's 'property' should not be used at runtime, only at class instantiation""" assert ( - len(re.findall(r"property\(", text[1])) == 0 + len(re.findall(r"property\(", text["trace"])) == 0 ), f"{estimator}.{method} should only use 'property' at instantiation"