From 87817e313b7784238594b2822a2837b0cc040781 Mon Sep 17 00:00:00 2001 From: Patrick Smyth Date: Mon, 18 Nov 2024 12:16:54 -0600 Subject: [PATCH] first commit --- .github/conda/build.sh | 1 + .github/conda/meta.yaml | 46 + .github/workflows/ci_cd_pipeline.yml | 60 + .github/workflows/conda-release.yml | 36 + .github/workflows/pypi-release.yml | 35 + .gitignore | 132 + .pre-commit-config.yaml | 9 + LICENSE | 202 + MANIFEST.in | 1 + Makefile | 20 + README.md | 149 + docs/.gitkeep | 0 docs/CONTRIBUTING.md | 108 + setup.cfg | 16 + setup.py | 129 + src/biofit/__init__.py | 15 + src/biofit/__main__.py | 15 + src/biofit/auto/__init__.py | 11 + src/biofit/auto/auto_factory.py | 338 ++ src/biofit/auto/configuration_auto.py | 1269 ++++++ src/biofit/auto/modeling_auto.py | 61 + src/biofit/auto/plotting_auto.py | 143 + src/biofit/auto/processing_auto.py | 98 + src/biofit/cli/__init__.py | 0 src/biofit/cli/install_r.py | 95 + src/biofit/cli/main.py | 50 + src/biofit/config.py | 231 ++ src/biofit/eval.py | 296 ++ src/biofit/exceptions.py | 0 src/biofit/integration/R/__init__.py | 2 + src/biofit/integration/R/r_caller.py | 736 ++++ .../integration/R/scripts/analysis_utils.R | 245 ++ .../integration/R/scripts/dependencies.R | 10 + .../integration/R/scripts/plotting_utils.R | 1447 +++++++ src/biofit/integration/R/scripts/utils.R | 246 ++ src/biofit/integration/__init__.py | 2 + src/biofit/integration/biosets.py | 14 + src/biofit/integration/patcher.py | 880 ++++ src/biofit/metrics/__init__.py | 168 + src/biofit/metrics/metrics.py | 201 + src/biofit/models/__init__.py | 3 + src/biofit/models/ensemble/__init__.py | 2 + src/biofit/models/ensemble/ensemble.py | 0 src/biofit/models/lasso/__init__.py | 7 + src/biofit/models/lasso/lasso.py | 355 ++ src/biofit/models/lightgbm/__init__.py | 7 + src/biofit/models/lightgbm/lightgbm.py | 578 +++ .../models/logistic_regression/__init__.py | 5 + .../logistic_regression.py | 246 ++ src/biofit/models/models.py | 223 + src/biofit/models/random_forest/__init__.py | 7 + .../models/random_forest/random_forest.py | 428 ++ src/biofit/preprocessing/__init__.py | 10 + src/biofit/preprocessing/encoding/__init__.py | 3 + src/biofit/preprocessing/encoding/encoding.py | 14 + .../encoding/label_binarizing/__init__.py | 2 + .../label_binarizing/label_binarizing.py | 405 ++ .../encoding/label_encoding/__init__.py | 2 + .../encoding/label_encoding/label_encoding.py | 243 ++ .../feature_extraction/__init__.py | 3 + .../feature_extraction/feature_extraction.py | 25 + .../feature_extraction/pca/__init__.py | 5 + .../feature_extraction/pca/pca.py | 235 ++ .../feature_extraction/pca/plot_pca.py | 82 + .../feature_extraction/pcoa/__init__.py | 10 + .../feature_extraction/pcoa/pcoa.py | 413 ++ .../feature_extraction/pcoa/plot_pcoa.py | 91 + .../plot_feature_extraction.py | 19 + .../feature_selection/__init__.py | 2 + .../feature_selection/feature_selection.py | 105 + .../__init__.py | 20 + .../min_prevalence_feature_selector.py | 404 ++ .../plot_min_prevalence_feature_selector.py | 367 ++ .../plot_feature_selection.R | 57 + .../plot_feature_selection.py | 20 + .../feature_selection/rfe/plot_rfe.R | 25 + .../feature_selection/rfe/plot_rfe.py | 0 .../preprocessing/filtering/__init__.py | 4 + .../preprocessing/filtering/filtering.py | 78 + .../min_prevalence_sample_filter/__init__.py | 20 + .../min_prevalence_sample_filter.py | 370 ++ .../plot_min_prevalence_sample_filter.py | 304 ++ .../filtering/missing_labels/__init__.py | 2 + .../missing_labels/missing_labels.py | 221 + 
.../preprocessing/filtering/plot_filtering.R | 69 + .../preprocessing/filtering/plot_filtering.py | 20 + .../filtering/row_abundance/__init__.py | 11 + .../row_abundance/plot_row_abundance.py | 327 ++ .../filtering/row_abundance/row_abundance.py | 312 ++ .../preprocessing/imputation/__init__.py | 0 .../preprocessing/imputation/imputation.py | 14 + .../preprocessing/resampling/__init__.py | 2 + .../preprocessing/resampling/resampling.py | 52 + .../resampling/upsampling/__init__.py | 7 + .../resampling/upsampling/upsampling.py | 479 +++ src/biofit/preprocessing/scaling/__init__.py | 9 + .../preprocessing/scaling/clr/__init__.py | 2 + src/biofit/preprocessing/scaling/clr/clr.py | 187 + .../preprocessing/scaling/css/__init__.py | 19 + src/biofit/preprocessing/scaling/css/css.py | 348 ++ .../preprocessing/scaling/css/plot_css.py | 297 ++ .../preprocessing/scaling/plot_scaling.R | 56 + .../preprocessing/scaling/plot_scaling.py | 51 + .../scaling/relative_abundance/__init__.py | 11 + .../plot_relative_abundance.py | 274 ++ .../relative_abundance/relative_abundance.py | 188 + src/biofit/preprocessing/scaling/scaling.py | 24 + .../preprocessing/scaling/tmm/__init__.py | 7 + src/biofit/preprocessing/scaling/tmm/tmm.R | 89 + src/biofit/preprocessing/scaling/tmm/tmm.py | 281 ++ .../preprocessing/transformation/__init__.py | 2 + .../transformation/log/__init__.py | 2 + .../preprocessing/transformation/log/log.py | 201 + .../transformation/transformation.py | 17 + src/biofit/processing.py | 3696 +++++++++++++++++ src/biofit/stat/__init__.py | 9 + src/biofit/stat/col_mean/__init__.py | 6 + src/biofit/stat/col_mean/col_mean.py | 208 + src/biofit/stat/col_missingness/__init__.py | 8 + .../stat/col_missingness/col_missingness.py | 321 ++ src/biofit/stat/col_sum/__init__.py | 6 + src/biofit/stat/col_sum/col_sum.py | 277 ++ src/biofit/stat/correlation/__init__.py | 6 + src/biofit/stat/correlation/correlation.py | 235 ++ src/biofit/stat/distance/__init__.py | 9 + src/biofit/stat/distance/distance.py | 296 ++ src/biofit/stat/row_mean/__init__.py | 6 + src/biofit/stat/row_mean/row_mean.py | 178 + src/biofit/stat/row_missingness/__init__.py | 8 + .../stat/row_missingness/row_missingness.py | 296 ++ src/biofit/stat/row_sum/__init__.py | 6 + src/biofit/stat/row_sum/row_sum.py | 150 + src/biofit/stat/stat.py | 103 + src/biofit/stat/summary/__init__.py | 2 + src/biofit/stat/summary/summary.py | 0 src/biofit/train.py | 564 +++ src/biofit/train_eval_utils.py | 1406 +++++++ src/biofit/utils/__init__.py | 60 + src/biofit/utils/_dill.py | 469 +++ src/biofit/utils/doc.py | 1213 ++++++ src/biofit/utils/file_utils.py | 234 ++ src/biofit/utils/fingerprint.py | 400 ++ src/biofit/utils/generic.py | 536 +++ src/biofit/utils/gorilla.py | 1009 +++++ src/biofit/utils/logging.py | 447 ++ src/biofit/utils/py_util.py | 209 + src/biofit/utils/recorder.py | 156 + src/biofit/utils/table_util.py | 863 ++++ src/biofit/utils/types.py | 27 + src/biofit/utils/version.py | 115 + src/biofit/visualization/__init__.py | 22 + src/biofit/visualization/barplot.py | 198 + .../visualization/dimension_reduction.py | 435 ++ .../visualization/feature_importance.py | 355 ++ src/biofit/visualization/histogram.py | 430 ++ .../visualization/plot_feature_importance.R | 216 + .../visualization/plot_sample_metadata.R | 140 + src/biofit/visualization/plotting.py | 429 ++ src/biofit/visualization/plotting_utils.py | 1210 ++++++ src/biofit/visualization/report_generation.py | 35 + src/biofit/visualization/sample_metadata.py | 128 + 
src/biofit/visualization/scatterplot.py | 188 + src/biofit/visualization/violin.py | 140 + tests/__init__.py | 0 tests/conftest.py | 108 + tests/fixtures/__init__.py | 0 tests/fixtures/files.py | 1426 +++++++ tests/fixtures/fsspec.py | 120 + tests/preprocessing/__init__.py | 0 tests/preprocessing/outputs/histogram.pdf | Bin 0 -> 4407 bytes tests/preprocessing/outputs/histogram.png | Bin 0 -> 53930 bytes .../preprocessing/test_abundance_filtering.py | 66 + tests/preprocessing/test_auto_plotting.py | 58 + .../preprocessing/test_auto_preprocessing.py | 67 + tests/preprocessing/test_css.py | 95 + tests/preprocessing/test_label_binarizer.py | 86 + tests/preprocessing/test_log.py | 66 + .../test_min_prevalence_features.py | 91 + .../test_min_prevalence_samples.py | 68 + tests/preprocessing/test_missing_labels.py | 63 + tests/preprocessing/test_pca.py | 77 + tests/preprocessing/test_pcoa.py | 99 + tests/preprocessing/test_r_caller.py | 70 + tests/preprocessing/test_tmm.py | 90 + tests/preprocessing/test_upsampling.py | 70 + tests/stat/__init__.py | 0 tests/stat/test_correlation.py | 89 + tests/stat/test_distance.py | 100 + tests/stat/test_mean.py | 157 + tests/stat/test_missingness.py | 150 + tests/stat/test_sum.py | 157 + tests/test_eval.py | 692 +++ tests/test_patcher.py | 309 ++ tests/test_plot_feature_importances.py | 350 ++ tests/test_plot_sample_metadata.py | 8 + tests/test_plotting_utils.py | 697 ++++ tests/test_processing.py | 2266 ++++++++++ tests/utils.py | 133 + tools/generate_auto_maps.py | 414 ++ 199 files changed, 41999 insertions(+) create mode 100644 .github/conda/build.sh create mode 100644 .github/conda/meta.yaml create mode 100644 .github/workflows/ci_cd_pipeline.yml create mode 100644 .github/workflows/conda-release.yml create mode 100644 .github/workflows/pypi-release.yml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 LICENSE create mode 100644 MANIFEST.in create mode 100644 Makefile create mode 100644 README.md create mode 100644 docs/.gitkeep create mode 100644 docs/CONTRIBUTING.md create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 src/biofit/__init__.py create mode 100644 src/biofit/__main__.py create mode 100644 src/biofit/auto/__init__.py create mode 100644 src/biofit/auto/auto_factory.py create mode 100644 src/biofit/auto/configuration_auto.py create mode 100644 src/biofit/auto/modeling_auto.py create mode 100644 src/biofit/auto/plotting_auto.py create mode 100644 src/biofit/auto/processing_auto.py create mode 100644 src/biofit/cli/__init__.py create mode 100644 src/biofit/cli/install_r.py create mode 100644 src/biofit/cli/main.py create mode 100644 src/biofit/config.py create mode 100644 src/biofit/eval.py create mode 100644 src/biofit/exceptions.py create mode 100644 src/biofit/integration/R/__init__.py create mode 100644 src/biofit/integration/R/r_caller.py create mode 100644 src/biofit/integration/R/scripts/analysis_utils.R create mode 100644 src/biofit/integration/R/scripts/dependencies.R create mode 100644 src/biofit/integration/R/scripts/plotting_utils.R create mode 100644 src/biofit/integration/R/scripts/utils.R create mode 100644 src/biofit/integration/__init__.py create mode 100644 src/biofit/integration/biosets.py create mode 100644 src/biofit/integration/patcher.py create mode 100644 src/biofit/metrics/__init__.py create mode 100644 src/biofit/metrics/metrics.py create mode 100644 src/biofit/models/__init__.py create mode 100644 src/biofit/models/ensemble/__init__.py create mode 100644 
src/biofit/models/ensemble/ensemble.py create mode 100644 src/biofit/models/lasso/__init__.py create mode 100644 src/biofit/models/lasso/lasso.py create mode 100644 src/biofit/models/lightgbm/__init__.py create mode 100644 src/biofit/models/lightgbm/lightgbm.py create mode 100644 src/biofit/models/logistic_regression/__init__.py create mode 100644 src/biofit/models/logistic_regression/logistic_regression.py create mode 100644 src/biofit/models/models.py create mode 100644 src/biofit/models/random_forest/__init__.py create mode 100644 src/biofit/models/random_forest/random_forest.py create mode 100644 src/biofit/preprocessing/__init__.py create mode 100644 src/biofit/preprocessing/encoding/__init__.py create mode 100644 src/biofit/preprocessing/encoding/encoding.py create mode 100644 src/biofit/preprocessing/encoding/label_binarizing/__init__.py create mode 100644 src/biofit/preprocessing/encoding/label_binarizing/label_binarizing.py create mode 100644 src/biofit/preprocessing/encoding/label_encoding/__init__.py create mode 100644 src/biofit/preprocessing/encoding/label_encoding/label_encoding.py create mode 100644 src/biofit/preprocessing/feature_extraction/__init__.py create mode 100644 src/biofit/preprocessing/feature_extraction/feature_extraction.py create mode 100644 src/biofit/preprocessing/feature_extraction/pca/__init__.py create mode 100644 src/biofit/preprocessing/feature_extraction/pca/pca.py create mode 100644 src/biofit/preprocessing/feature_extraction/pca/plot_pca.py create mode 100644 src/biofit/preprocessing/feature_extraction/pcoa/__init__.py create mode 100644 src/biofit/preprocessing/feature_extraction/pcoa/pcoa.py create mode 100644 src/biofit/preprocessing/feature_extraction/pcoa/plot_pcoa.py create mode 100644 src/biofit/preprocessing/feature_extraction/plot_feature_extraction.py create mode 100644 src/biofit/preprocessing/feature_selection/__init__.py create mode 100644 src/biofit/preprocessing/feature_selection/feature_selection.py create mode 100644 src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/__init__.py create mode 100644 src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/min_prevalence_feature_selector.py create mode 100644 src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/plot_min_prevalence_feature_selector.py create mode 100644 src/biofit/preprocessing/feature_selection/plot_feature_selection.R create mode 100644 src/biofit/preprocessing/feature_selection/plot_feature_selection.py create mode 100644 src/biofit/preprocessing/feature_selection/rfe/plot_rfe.R create mode 100644 src/biofit/preprocessing/feature_selection/rfe/plot_rfe.py create mode 100644 src/biofit/preprocessing/filtering/__init__.py create mode 100644 src/biofit/preprocessing/filtering/filtering.py create mode 100644 src/biofit/preprocessing/filtering/min_prevalence_sample_filter/__init__.py create mode 100644 src/biofit/preprocessing/filtering/min_prevalence_sample_filter/min_prevalence_sample_filter.py create mode 100644 src/biofit/preprocessing/filtering/min_prevalence_sample_filter/plot_min_prevalence_sample_filter.py create mode 100644 src/biofit/preprocessing/filtering/missing_labels/__init__.py create mode 100644 src/biofit/preprocessing/filtering/missing_labels/missing_labels.py create mode 100644 src/biofit/preprocessing/filtering/plot_filtering.R create mode 100644 src/biofit/preprocessing/filtering/plot_filtering.py create mode 100644 src/biofit/preprocessing/filtering/row_abundance/__init__.py create mode 
100644 src/biofit/preprocessing/filtering/row_abundance/plot_row_abundance.py create mode 100644 src/biofit/preprocessing/filtering/row_abundance/row_abundance.py create mode 100644 src/biofit/preprocessing/imputation/__init__.py create mode 100644 src/biofit/preprocessing/imputation/imputation.py create mode 100644 src/biofit/preprocessing/resampling/__init__.py create mode 100644 src/biofit/preprocessing/resampling/resampling.py create mode 100644 src/biofit/preprocessing/resampling/upsampling/__init__.py create mode 100644 src/biofit/preprocessing/resampling/upsampling/upsampling.py create mode 100644 src/biofit/preprocessing/scaling/__init__.py create mode 100644 src/biofit/preprocessing/scaling/clr/__init__.py create mode 100644 src/biofit/preprocessing/scaling/clr/clr.py create mode 100644 src/biofit/preprocessing/scaling/css/__init__.py create mode 100644 src/biofit/preprocessing/scaling/css/css.py create mode 100644 src/biofit/preprocessing/scaling/css/plot_css.py create mode 100644 src/biofit/preprocessing/scaling/plot_scaling.R create mode 100644 src/biofit/preprocessing/scaling/plot_scaling.py create mode 100644 src/biofit/preprocessing/scaling/relative_abundance/__init__.py create mode 100644 src/biofit/preprocessing/scaling/relative_abundance/plot_relative_abundance.py create mode 100644 src/biofit/preprocessing/scaling/relative_abundance/relative_abundance.py create mode 100644 src/biofit/preprocessing/scaling/scaling.py create mode 100644 src/biofit/preprocessing/scaling/tmm/__init__.py create mode 100644 src/biofit/preprocessing/scaling/tmm/tmm.R create mode 100644 src/biofit/preprocessing/scaling/tmm/tmm.py create mode 100644 src/biofit/preprocessing/transformation/__init__.py create mode 100644 src/biofit/preprocessing/transformation/log/__init__.py create mode 100644 src/biofit/preprocessing/transformation/log/log.py create mode 100644 src/biofit/preprocessing/transformation/transformation.py create mode 100644 src/biofit/processing.py create mode 100644 src/biofit/stat/__init__.py create mode 100644 src/biofit/stat/col_mean/__init__.py create mode 100644 src/biofit/stat/col_mean/col_mean.py create mode 100644 src/biofit/stat/col_missingness/__init__.py create mode 100644 src/biofit/stat/col_missingness/col_missingness.py create mode 100644 src/biofit/stat/col_sum/__init__.py create mode 100644 src/biofit/stat/col_sum/col_sum.py create mode 100644 src/biofit/stat/correlation/__init__.py create mode 100644 src/biofit/stat/correlation/correlation.py create mode 100644 src/biofit/stat/distance/__init__.py create mode 100644 src/biofit/stat/distance/distance.py create mode 100644 src/biofit/stat/row_mean/__init__.py create mode 100644 src/biofit/stat/row_mean/row_mean.py create mode 100644 src/biofit/stat/row_missingness/__init__.py create mode 100644 src/biofit/stat/row_missingness/row_missingness.py create mode 100644 src/biofit/stat/row_sum/__init__.py create mode 100644 src/biofit/stat/row_sum/row_sum.py create mode 100644 src/biofit/stat/stat.py create mode 100644 src/biofit/stat/summary/__init__.py create mode 100644 src/biofit/stat/summary/summary.py create mode 100644 src/biofit/train.py create mode 100644 src/biofit/train_eval_utils.py create mode 100644 src/biofit/utils/__init__.py create mode 100644 src/biofit/utils/_dill.py create mode 100644 src/biofit/utils/doc.py create mode 100644 src/biofit/utils/file_utils.py create mode 100644 src/biofit/utils/fingerprint.py create mode 100644 src/biofit/utils/generic.py create mode 100644 src/biofit/utils/gorilla.py 
create mode 100644 src/biofit/utils/logging.py create mode 100644 src/biofit/utils/py_util.py create mode 100644 src/biofit/utils/recorder.py create mode 100644 src/biofit/utils/table_util.py create mode 100644 src/biofit/utils/types.py create mode 100644 src/biofit/utils/version.py create mode 100644 src/biofit/visualization/__init__.py create mode 100644 src/biofit/visualization/barplot.py create mode 100644 src/biofit/visualization/dimension_reduction.py create mode 100644 src/biofit/visualization/feature_importance.py create mode 100644 src/biofit/visualization/histogram.py create mode 100644 src/biofit/visualization/plot_feature_importance.R create mode 100644 src/biofit/visualization/plot_sample_metadata.R create mode 100644 src/biofit/visualization/plotting.py create mode 100644 src/biofit/visualization/plotting_utils.py create mode 100644 src/biofit/visualization/report_generation.py create mode 100644 src/biofit/visualization/sample_metadata.py create mode 100644 src/biofit/visualization/scatterplot.py create mode 100644 src/biofit/visualization/violin.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/fixtures/__init__.py create mode 100644 tests/fixtures/files.py create mode 100644 tests/fixtures/fsspec.py create mode 100644 tests/preprocessing/__init__.py create mode 100644 tests/preprocessing/outputs/histogram.pdf create mode 100644 tests/preprocessing/outputs/histogram.png create mode 100644 tests/preprocessing/test_abundance_filtering.py create mode 100644 tests/preprocessing/test_auto_plotting.py create mode 100644 tests/preprocessing/test_auto_preprocessing.py create mode 100644 tests/preprocessing/test_css.py create mode 100644 tests/preprocessing/test_label_binarizer.py create mode 100644 tests/preprocessing/test_log.py create mode 100644 tests/preprocessing/test_min_prevalence_features.py create mode 100644 tests/preprocessing/test_min_prevalence_samples.py create mode 100644 tests/preprocessing/test_missing_labels.py create mode 100644 tests/preprocessing/test_pca.py create mode 100644 tests/preprocessing/test_pcoa.py create mode 100644 tests/preprocessing/test_r_caller.py create mode 100644 tests/preprocessing/test_tmm.py create mode 100644 tests/preprocessing/test_upsampling.py create mode 100644 tests/stat/__init__.py create mode 100644 tests/stat/test_correlation.py create mode 100644 tests/stat/test_distance.py create mode 100644 tests/stat/test_mean.py create mode 100644 tests/stat/test_missingness.py create mode 100644 tests/stat/test_sum.py create mode 100644 tests/test_eval.py create mode 100644 tests/test_patcher.py create mode 100644 tests/test_plot_feature_importances.py create mode 100644 tests/test_plot_sample_metadata.py create mode 100644 tests/test_plotting_utils.py create mode 100644 tests/test_processing.py create mode 100644 tests/utils.py create mode 100644 tools/generate_auto_maps.py diff --git a/.github/conda/build.sh b/.github/conda/build.sh new file mode 100644 index 0000000..a660906 --- /dev/null +++ b/.github/conda/build.sh @@ -0,0 +1 @@ +$PYTHON setup.py install --single-version-externally-managed --record=record.txt diff --git a/.github/conda/meta.yaml b/.github/conda/meta.yaml new file mode 100644 index 0000000..5a9d230 --- /dev/null +++ b/.github/conda/meta.yaml @@ -0,0 +1,46 @@ +{% set name = "biofit" %} + +package: + name: "{{ name|lower }}" + version: "{{ BIOFIT_VERSION }}" + +source: + path: ../../ + +build: + noarch: python +requirements: + host: + - python + - pip + run: + - python 
+ - pip + - biocore + - filelock + - numpy >=1.17 + - pyarrow >=8.0.0 + - pyarrow-hotfix + - dill >=0.3.0,<0.3.8 + - pandas + - requests >=2.19.0 + - tqdm >=4.62.1 + - xxhash + - multiprocess + - fsspec[http] >=2023.1.0,<=2023.10.0 + - packaging + - pyyaml >=5.1 + - scikit-learn >=1.4.0 +test: + imports: + - biofit +about: + home: https://github.com/psmyth94/biofit + license: Apache-2.0 + license_file: LICENSE + summary: A Python package for machine learning on omics datasets. + keywords: + - omics + - machine learning + - bioinformatics + - datasets diff --git a/.github/workflows/ci_cd_pipeline.yml b/.github/workflows/ci_cd_pipeline.yml new file mode 100644 index 0000000..570ddd8 --- /dev/null +++ b/.github/workflows/ci_cd_pipeline.yml @@ -0,0 +1,60 @@ +name: CI/CD Pipeline +on: + push: + branches: + - main + - "ci-*" + pull_request: + branches: + - main +jobs: + check_code_quality: + name: Check Code Quality + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Set up Python 3.8 + uses: actions/setup-python@v5 + with: + python-version: "3.8" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[quality] + - name: Run Ruff code quality checks + run: | + ruff check tests src setup.py --output-format=github + ruff format --check tests src setup.py + tests: + needs: check_code_quality + strategy: + matrix: + test: [unit] + os: [ubuntu-latest, windows-latest] + python-version: ["3.8", "3.9", "3.10", "3.11"] + runs-on: ${{ matrix.os }} + continue-on-error: ${{ matrix.test == 'integration' }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install uv --upgrade + uv pip install --system "biofit[test] @ ." 
+ - name: Run Tests + run: | + mkdir -p reports/${{ matrix.test }}_tests + python3 -m pytest -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/ --junitxml=reports/${{ matrix.test }}_tests/results.xml + - name: Upload Unit Test Results + uses: actions/upload-artifact@v3 + with: + name: ${{ matrix.test }}-test-reports + path: reports/${{ matrix.test }}_tests/results.xml diff --git a/.github/workflows/conda-release.yml b/.github/workflows/conda-release.yml new file mode 100644 index 0000000..818bceb --- /dev/null +++ b/.github/workflows/conda-release.yml @@ -0,0 +1,36 @@ +name: Conda - Build +on: + push: + tags: + - "[0-9]+.[0-9]+.[0-9]+*" +env: + ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_API_TOKEN }} +jobs: + build_and_package: + runs-on: ubuntu-22.04 + defaults: + run: + shell: bash -l {0} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Install miniconda + uses: conda-incubator/setup-miniconda@v2 + with: + auto-update-conda: true + auto-activate-base: false + activate-environment: "build-biofit" + python-version: 3.8 + channels: patrico49 + - name: Setup conda env + run: | + conda install -c defaults anaconda-client conda-build + - name: Extract version + run: echo "BIOFIT_VERSION=`python setup.py --version`" >> $GITHUB_ENV + - name: Build conda packages + run: | + conda info + conda build .github/conda + - name: Upload to Anaconda + run: | + anaconda upload `conda build .github/conda --output -c conda-forge` --force diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml new file mode 100644 index 0000000..7155722 --- /dev/null +++ b/.github/workflows/pypi-release.yml @@ -0,0 +1,35 @@ +name: PyPI - Release +on: + push: + tags: + - "[0-9]+.[0-9]+.[0-9]+*" +env: + PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }} +jobs: + build_and_publish: + runs-on: ubuntu-22.04 + defaults: + run: + shell: bash -l {0} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: 3.8 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install setuptools wheel twine + - name: Extract version + run: echo "BIOFIT_VERSION=$(python setup.py --version)" >> $GITHUB_ENV + - name: Build package + run: | + python setup.py sdist bdist_wheel + - name: Publish to PyPI + run: | + python -m twine upload dist/* --non-interactive + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0c55a78 --- /dev/null +++ b/.gitignore @@ -0,0 +1,132 @@ +results/ +*.lock +_dev +# Ignore Nextflow-related files and directories +.nextflow* +work +_dev/in_progress* +.dev + +# Ignore data directory +tests/data + +# anything not ipynb related +tutorials/* +!tutorials/*.ipynb + + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +.conda + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +cache*.arrow +cache*.yaml + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + 
+# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +-checkpoint.ipynb + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to https://pipenv.pypa.io/en/latest/advanced/#pipfile-vs-pipfilelock +# PIPENV_VENV_IN_PROJECT=1 +Pipfile.lock + +# Poetry +# https://python-poetry.org/docs/basic-usage/#configuring-poetry +.poetry/ + +# virtualenv +venv/ +env/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# Idea project settings +.idea/ + +# PyCharm project settings +.idea/ + +# Visual Studio Code settings +.vscode/ +.vs/ + +# MyPy +.mypy_cache/ + +# pycodestyle +.pytest_cache/ + +#Docker +*Dockerfile +*dockerignore diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..c3751a6 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,9 @@ +repos: + - repo: https://github.com/charliermarsh/ruff-pre-commit # https://github.com/charliermarsh/ruff#usage + rev: 'v0.3.0' + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..3032302 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +recursive-include src/biofit *.R diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..695cf22 --- /dev/null +++ b/Makefile @@ -0,0 +1,20 @@ +.PHONY: quality style test + +check_dirs := tests src + +# Check that source code meets quality standards + +quality: + ruff check $(check_dirs) setup.py # linter + ruff format --check $(check_dirs) setup.py # formatter + +# Format source code automatically + +style: + ruff check --fix $(check_dirs) setup.py # linter + ruff format $(check_dirs) setup.py # formatter + +# Run tests for the library + +test: + python -m pytest -n auto --dist=loadfile -s -v ./tests/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..3873972 --- /dev/null +++ b/README.md @@ -0,0 +1,149 @@ +

+ $${\Huge{\textbf{\textsf{\color{#2E8B57}Bio\color{#4682B4}fit}}}}$$
+
+ [Badges: Build | GitHub | Documentation | GitHub release | Contributor Covenant | DOI]

+ +**Biofit** is a machine learning library designed for bioinformatics datasets. It +provides tools for transforming, extracting, training, and evaluating machine learning +models on biomedical data. It also provides automatic data preprocessing, visualization, +and configurable processing pipelines. Here are some of the main features of Biofit: + +- **Automatic Data Preprocessing:** Automatically preprocess biomedical datasets using + built-in preprocessing steps. +- **Automatic Visualization:** Automatically visualize data using built-in visualization + methods geared towards biomedical data. +- **Configurable Processing Pipelines:** Define and customize data processing pipelines. +- **Data Handling Flexibility:** Support for a wide range of data formats, including: + - [Pandas](https://github.com/pandas-dev/pandas) + - [Polars](https://github.com/pola-rs/polars) + - [NumPy](https://github.com/numpy/numpy) + - [CSR (SciPy)](https://github.com/scipy/scipy) + - [Arrow](https://github.com/apache/arrow) + - 🤗 [Datasets](https://github.com/huggingface/datasets) + - [Biosets](https://github.com/psmyth94/biosets) +- **Machine Learning Models:** Supports a wide range of machine learning models, including: + - [Scikit-learn](https://github.com/scikit-learn/scikit-learn) + - [Random Forest](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html) + - [Support Vector Machine](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html) + - [Logistic Regression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) + - [Lasso Regression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html) + - [LightGBM](https://github.com/microsoft/LightGBM) + - More to come! +- **Caching and Reuse:** Caches intermediate results using Apache Arrow for efficient reuse. +- **Batch Processing and Multiprocessing:** Utilize batch processing and multiprocessing for efficient handling of large-scale data. + +## Installation + +You can install Biofit via pip: + +```bash +pip install biofit +``` + +## Quick Start + +### Preprocessing Data + +Biofit provides preprocessing capabilities tailored for omics data. You can use +built-in classes to load preprocessing steps based on the experiment type or create +custom preprocessing pipelines. The preprocessing pipeline in Biofit uses a syntax +similar to sklearn and supports distributed processing. + +#### Using a Preprocessor + +Biofit allows you to fit and transform your data in a few lines, similar to sklearn. +For example, you can use the LogTransformer to apply a log transformation to your data: + +```python +from biofit.preprocessing import LogTransformer +import pandas as pd + +dataset = pd.DataFrame({"feature1": [1, 2, 3, 4, 5]}) +log_transformer = LogTransformer() +preprocessed_data = log_transformer.fit_transform(dataset) +# Applying log transformation: 100%|█████████████████████████████| 5/5 [00:00<00:00, 7656.63 examples/s] +print(preprocessed_data) +# feature1 +# 0 0.000000 +# 1 0.693147 +# 2 1.098612 +# 3 1.386294 +# 4 1.609438 +``` + +#### Auto Preprocessing + +You can automatically apply standard preprocessing steps by specifying the experiment +type. 
This allows you to load tailored preprocessing steps for the type of data you are +working with, such as "otu", "asv", "snp", or "maldi": + +```python +from biofit.preprocessing import AutoPreprocessor + +preprocessor = AutoPreprocessor.for_experiment("snp", [{"min_prevalence": 0.1}, None]) +print(preprocessor) +# [('min_prevalence_row', MinPrevalencFilter(min_prevalence=0.1)), +# ('min_prevalence', MinPrevalenceFeatureSelector(min_prevalence=0.01))] + +# Fit and transform the dataset using the preprocessor +preprocessed_data = preprocessor.fit_transform(dataset) +``` + +Biofit is made with Biosets in mind. You can pass the loaded dataset instead of a string +to load the preprocessors: + +```python +from biosets import load_dataset + +dataset = load_dataset("csv", data_files="my_file.csv", experiment_type="snp") + +preprocessor = AutoPreprocessor.for_experiment(dataset) +print(preprocessor) +# [('min_prevalence_row', MinPrevalencFilter(min_prevalence=0.01)), +# ('min_prevalence', MinPrevalenceFeatureSelector(min_prevalence=0.01))] +preprocessed_data = preprocessor.fit_transform(dataset) +``` + +#### Custom Preprocessing Pipeline + +Biofit allows you to create custom preprocessing pipelines using the +`PreprocessorPipeline` class. This allows chaining multiple preprocessing steps from +`sklearn` and Biofit in a single operation: + +```python +from biosets import load_dataset +from biofit.preprocessing import LogTransformer, PreprocessorPipeline +from sklearn.preprocessing import StandardScaler + +# Load the dataset +dataset = load_dataset("csv", data_files="my_file.csv") + +# Define a custom preprocessing pipeline +pipeline = PreprocessorPipeline( + [("scaler", StandardScaler()), ("log_transformer", LogTransformer())] +) + +# Fit and transform the dataset using the pipeline +preprocessed_data = pipeline.fit_transform(dataset.to_pandas()) +``` + +For further details, check the [advanced usage documentation](./docs/PREPROCESSING.md). + +# License + +Biofit is licensed under the Apache 2.0 License. See the [LICENSE](./LICENSE) file for +more information. + +# Contributing + +If you would like to contribute to Biofit, please read the +[CONTRIBUTING](./docs/CONTRIBUTING.md) guidelines. diff --git a/docs/.gitkeep b/docs/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md new file mode 100644 index 0000000..992fbfa --- /dev/null +++ b/docs/CONTRIBUTING.md @@ -0,0 +1,108 @@ +# Contributing to Biofit + +Biofit is an open source project, so all contributions and suggestions are welcome. + +You can contribute in many different ways: giving ideas, answering questions, reporting +bugs, proposing enhancements, improving the documentation, fixing bugs, and more. + +Many thanks in advance to every contributor. + +## How to Work on an Open Issue? + +The list of open issues is available at: +[Biofit Issues](https://github.com/psmyth94/biofit/issues) + +## How to Create a Pull Request? + +If you want to contribute to the codebase, follow these steps: + +1. **Clone the Repository:** + + Clone the `dev` branch of the repository to your local disk: + + ```bash + git clone git@github.com:psmyth94/biofit + cd biofit + ``` + +2. **Create a New Branch:** + + Create a new branch to hold your development changes: + + ```bash + git checkout -b a-descriptive-name-for-my-changes + ``` + + Do not work on the `main` branch directly. + +3.
**Set Up a Development Environment:** + + Set up a development environment by running the following commands: + + ```bash + mamba create -n biofit-local python=3.10 + mamba activate biofit-local + pip install -e ".[test]" + ``` + + (If Biofit was already installed in the virtual environment, remove it with + `pip uninstall biofit` before reinstalling it in editable mode with the `-e` flag.) + +4. **Develop the Features on Your Branch:** + + Make your changes to the code. + +5. **Format Your Code:** + + Run `ruff` to lint and format your newly added files with the + following command: + + ```bash + ruff check . --fix + ``` + +6. **(Optional) Use Pre-commit Hooks:** + + You can also use pre-commit to format your code automatically each time you run + `git commit`, instead of running `ruff` manually. To do this, install pre-commit via + `pip install pre-commit` and then run `pre-commit install` in the project's root + directory to set up the hooks. Note that if any files were formatted by pre-commit + hooks during committing, you have to run `git commit` again. + +7. **Commit Your Changes:** + + Once you're happy with your contribution, add your changed files and make a commit + to record your changes locally: + + ```bash + git add -u + git commit -m "Your commit message" + ``` + +8. **Sync with the Original Repository:** + + It is a good idea to sync your copy of the code with the original repository + regularly. This way you can quickly account for changes: + + ```bash + git fetch upstream + git rebase upstream/main + ``` + +9. **Push the Changes:** + + Once you are satisfied, push the changes to the remote repository using: + + ```bash + git push origin a-descriptive-name-for-my-changes + ``` + +10. **Create a Pull Request:** + + Go to the webpage of the repository on GitHub. Click on "Pull request" to send + your changes to the project maintainers for review. + +## Code of Conduct + +This project adheres to the Contributor Covenant code of conduct. By participating, you +are expected to abide by this code. diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..cde56b0 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,16 @@ +[metadata] +license_files = LICENSE + +[tool:ruff] +line-length = 88 +select = ["E", "F", "W", "C90"] +ignore = ["E501"] +exclude = ["*.ipynb"] + +[tool:pytest] +# Test fails if a FutureWarning is thrown by `huggingface_hub` +filterwarnings = + error::FutureWarning:huggingface_hub* +markers = + unit: unit test + integration: integration test diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..3dbe6be --- /dev/null +++ b/setup.py @@ -0,0 +1,129 @@ +from setuptools import find_packages, setup + +REQUIRED_PKGS = [ + # the library that biofit is built upon + "biocore>=1.1.0", + # For file locking + "filelock", + # We use numpy>=1.17 to have np.random.Generator (Dataset shuffling) + "numpy>=1.17", + # Backend and serialization.
+ # Minimum 8.0.0 to be able to use .to_reader() + "pyarrow>=8.0.0", + # As long as we allow pyarrow < 14.0.1, to fix vulnerability CVE-2023-47248 + "pyarrow-hotfix", + # For smart caching dataset processing + "dill>=0.3.0,<0.3.8", # tmp pin until dill has official support for determinism see https://github.com/uqfoundation/dill/issues/19 + # For performance gains with apache arrow + "pandas", + # for downloading datasets over HTTPS + "requests>=2.19.0", + # progress bars in download and scripts + "tqdm>=4.62.1", + # for fast hashing + "xxhash", + # for better multiprocessing + "multiprocess", + # to save datasets locally or on any filesystem + # minimum 2023.1.0 to support protocol=kwargs in fsspec's `open`, `get_fs_token_paths`, etc.: see https://github.com/fsspec/filesystem_spec/pull/1143 + "fsspec[http]>=2023.1.0,<=2023.10.0", + # Utilities from PyPA to e.g., compare versions + "packaging", + # To parse YAML metadata from dataset cards + "pyyaml>=5.1", + # for processing and transforming datasets + "scikit-learn>=1.4.0", +] + +QUALITY_REQUIRE = ["ruff>=0.1.5"] + +DOCS_REQUIRE = [ + # Might need to add doc-builder and some specific deps in the future + "s3fs", +] + +VISUALIZATION_REQUIRE = [ + "matplotlib", + "seaborn", + "psutil", # for the start time of the biofit run +] + + +ML_REQUIRE = [ + "polars>=0.20.5", + "timezones>=0.10.2", + "optuna", + "lightgbm", + # "xgboost", + # "catboost", + "imbalanced-learn", +] + +TESTS_REQUIRE = ["pytest", "pytest-timeout", "pytest-xdist"] + + +EXTRAS_REQUIRE = { + "polars": ["polars>=0.20.5", "timezones>=0.10.2"], + "rpy2": ["rpy2>=3.5.15", "rpy2-arrow>=0.0.8"], + "ml": ML_REQUIRE, + "apache-beam": ["apache-beam>=2.26.0,<2.44.0"], + "vcf": ["cyvcf2>=0.30.0", "sgkit>=0.0.1"], + "tensorflow": [ + "tensorflow>=2.2.0,!=2.6.0,!=2.6.1; sys_platform != 'darwin' or platform_machine != 'arm64'", + "tensorflow-macos; sys_platform == 'darwin' and platform_machine == 'arm64'", + ], + "tensorflow_gpu": ["tensorflow-gpu>=2.2.0,!=2.6.0,!=2.6.1"], + "torch": ["torch"], + "jax": ["jax>=0.3.14", "jaxlib>=0.3.14"], + "s3": ["s3fs"], + "viz": VISUALIZATION_REQUIRE, + "test": QUALITY_REQUIRE + + TESTS_REQUIRE + + DOCS_REQUIRE + + VISUALIZATION_REQUIRE + + ML_REQUIRE, + "all": VISUALIZATION_REQUIRE + ML_REQUIRE + QUALITY_REQUIRE + DOCS_REQUIRE, + "quality": QUALITY_REQUIRE, + "docs": DOCS_REQUIRE, +} + +setup( + name="biofit", + version="0.0.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + description="BioFit: Bioinformatics Machine Learning Framework", + long_description=open("README.md", encoding="utf-8").read(), + long_description_content_type="text/markdown", + author="Patrick Smyth", + author_email="psmyth1994@gmail.com", + url="https://github.com/psmyth94/biofit", + download_url="https://github.com/psmyth94/biofit/tags", + license="MIT", + package_dir={"": "src"}, + packages=find_packages("src"), + include_package_data=True, + python_requires=">=3.8.0,<3.12.0", + install_requires=REQUIRED_PKGS, + extras_require=EXTRAS_REQUIRE, + entry_points={ + "console_scripts": [ + "biofit = biofit.cli.main:main", + ], + }, + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming 
Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Bio-Informatics", + ], + keywords="omics machine learning bioinformatics metrics", + zip_safe=False, # Required for mypy to find the py.typed file +) diff --git a/src/biofit/__init__.py b/src/biofit/__init__.py new file mode 100644 index 0000000..078865d --- /dev/null +++ b/src/biofit/__init__.py @@ -0,0 +1,15 @@ +# ruff: noqa +from .auto import * +from .models import * +from .train import train +from .utils import ( + disable_progress_bar, + enable_progress_bar, + logging, + set_verbosity, + set_verbosity_debug, + set_verbosity_error, + set_verbosity_info, +) +from .utils.version import __version__ +from .visualization import * diff --git a/src/biofit/__main__.py b/src/biofit/__main__.py new file mode 100644 index 0000000..67a8cb8 --- /dev/null +++ b/src/biofit/__main__.py @@ -0,0 +1,15 @@ +import argparse + +from .__version__ import __version__ + + +def main(): + print("use biofit.cli instead. available commands: --version") + # print(f"{__file__}") + parser = argparse.ArgumentParser() + parser.add_argument("--version", action="version", version=f"{__version__}") + parser.parse_args() + + +if __name__ == "__main__": + main() diff --git a/src/biofit/auto/__init__.py b/src/biofit/auto/__init__.py new file mode 100644 index 0000000..cd528a4 --- /dev/null +++ b/src/biofit/auto/__init__.py @@ -0,0 +1,11 @@ +# ruff: noqa +from .configuration_auto import ( + AutoConfig, + AutoPreprocessorConfig, +) +from .modeling_auto import ( + AutoModel, + AutoModelForClassification, +) +from .processing_auto import AutoProcessor, AutoPreprocessor +from .plotting_auto import AutoPlotter, PlotterPipeline diff --git a/src/biofit/auto/auto_factory.py b/src/biofit/auto/auto_factory.py new file mode 100644 index 0000000..60917e8 --- /dev/null +++ b/src/biofit/auto/auto_factory.py @@ -0,0 +1,338 @@ +import importlib +from collections import OrderedDict +from typing import TYPE_CHECKING + +from biocore.utils.import_util import is_transformers_available + +from biofit.utils import logging + +from ..auto.configuration_auto import ( + PROCESSOR_CATEGORY_MAPPING_NAMES, + PROCESSOR_TYPE_MAPPING_NAMES, + AutoConfig, +) + +if TYPE_CHECKING: + from biofit.processing import BaseProcessor + +logger = logging.get_logger(__name__) + + +def _get_model_class(config, model_mapping: "_LazyAutoMapping"): + supported_models = model_mapping[type(config)] + if not isinstance(supported_models, (list, tuple)): + return supported_models + + name_to_model = {model.__name__: model for model in supported_models} + architectures = getattr(config, "architectures", []) + for arch in architectures: + if arch in name_to_model: + return name_to_model[arch] + elif f"TF{arch}" in name_to_model: + return name_to_model[f"TF{arch}"] + elif f"Flax{arch}" in name_to_model: + return name_to_model[f"Flax{arch}"] + + # If not architecture is set in the config or match the supported models, the first element of the tuple is the + # defaults. + return supported_models[0] + + +def _get_class(config, model_mapping): + supported_models = model_mapping[type(config)] + return supported_models + + +class _BaseAutoProcessorClass: + # Base class for auto preprocessors. + _processor_mapping = None + + def __init__(self, *args, **kwargs): + raise EnvironmentError( + f"{self.__class__.__name__} is designed to be instantiated " + f"`{self.__class__.__name__}.from_config(config)` methods." 
+ ) + + @classmethod + def for_processor(self, processor_name, **kwargs): + config = AutoConfig.for_processor(processor_name, **kwargs) + return self.from_config(config, **kwargs) + + @classmethod + def from_config(cls, config, **kwargs): + if type(config) in cls._processor_mapping.keys(): + preprocessor_class: "BaseProcessor" = _get_class( + config, cls._processor_mapping + ) + return preprocessor_class._from_config(config, **kwargs) + + raise ValueError( + f"Unrecognized configuration class {config.__class__} for this kind of AutoProcessor: {cls.__name__}.\n" + f"Processor type should be one of {', '.join(c.__name__ for c in cls._processor_mapping.keys())}." + ) + + @classmethod + def register(cls, _config_class, processor_class: "BaseProcessor", exist_ok=False): + """ + Register a new model for this class. + + Args: + _config_class ([`PretrainedConfig`]): + The configuration corresponding to the model to register. + model_class ([`PreTrainedModel`]): + The model to register. + """ + if hasattr(processor_class, "_config_class") and issubclass( + _config_class, processor_class._config_class + ): + raise ValueError( + "The model class you are passing has a `_config_class` attribute that is not consistent with the " + f"config class you passed (model has {processor_class._config_class} and you passed {_config_class}. Fix " + "one of those so they match!" + ) + cls._processor_mapping.register( + _config_class, processor_class, exist_ok=exist_ok + ) + + +class _BaseAutoModelClass: + # Base class for auto models. + _processor_mapping = None + + def __init__(self, *args, **kwargs): + raise EnvironmentError( + f"{self.__class__.__name__} is designed to be instantiated " + f"`{self.__class__.__name__}.from_config(config)` methods." + ) + + @classmethod + def for_model(cls, model_name, *model_args, **kwargs): + config = AutoConfig.for_processor(model_name, **kwargs) + return cls.from_config(config, *model_args, **kwargs) + + @classmethod + def from_config(cls, config, **kwargs): + if type(config) in cls._processor_mapping.keys(): + model_class: "BaseProcessor" = _get_model_class( + config, cls._processor_mapping + ) + return model_class._from_config(config, **kwargs) + + if is_transformers_available(): + logger.warning( + "Could not find a matching model for this configuration. " + "Searching for a model in the Transformers library instead." + ) + + from transformers.models.auto.auto_factory import ( + _BaseAutoModelClass as _HfBaseAutoModelClass, + ) + + return _HfBaseAutoModelClass.from_config(config, **kwargs) + + raise ValueError( + f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" + f"Model type should be one of {', '.join(c.__name__ for c in cls._processor_mapping.keys())}." + ) + + @classmethod + def from_pretrained(cls, pretrained_estimator_name_or_path, *model_args, **kwargs): + if is_transformers_available(): + from transformers.models.auto.auto_factory import ( + _BaseAutoModelClass as _HfBaseAutoModelClass, + ) + + return _HfBaseAutoModelClass.from_pretrained( + pretrained_estimator_name_or_path, *model_args, **kwargs + ) + else: + raise EnvironmentError( + f"Using {cls.__name__}.from_pretrained requires the transformers library to be installed. " + "You can install it with `pip install transformers`" + ) + + @classmethod + def register(cls, _config_class, processor_class: "BaseProcessor", exist_ok=False): + """ + Register a new model for this class. 
+ + Args: + _config_class ([`PretrainedConfig`]): + The configuration corresponding to the model to register. + model_class ([`PreTrainedModel`]): + The model to register. + """ + if hasattr(processor_class, "_config_class") and issubclass( + _config_class, processor_class._config_class + ): + raise ValueError( + "The model class you are passing has a `_config_class` attribute that is not consistent with the " + f"config class you passed (model has {processor_class._config_class} and you passed {_config_class}. Fix " + "one of those so they match!" + ) + cls._processor_mapping.register( + _config_class, processor_class, exist_ok=exist_ok + ) + + +def insert_head_doc(docstring, head_doc=""): + if len(head_doc) > 0: + return docstring.replace( + "one of the model classes of the library ", + f"one of the model classes of the library (with a {head_doc} head) ", + ) + return docstring.replace( + "one of the model classes of the library ", + "one of the base model classes of the library ", + ) + + +def get_values(model_mapping): + result = [] + for model in model_mapping.values(): + if isinstance(model, (list, tuple)): + result += list(model) + else: + result.append(model) + + return result + + +def getattribute_from_module(module, attr): + if attr is None: + return None + if isinstance(attr, tuple): + return tuple(getattribute_from_module(module, a) for a in attr) + if hasattr(module, attr): + return getattr(module, attr) + # Some of the mappings have entries estimator_type -> object of another model type. In that case we try to grab the + # object at the top level. + biofit_module = importlib.import_module("biofit") + + if module != biofit_module: + try: + return getattribute_from_module(biofit_module, attr) + except ValueError: + raise ValueError( + f"Could not find {attr} neither in {module} nor in {biofit_module}!" + ) + else: + raise ValueError(f"Could not find {attr} in {biofit_module}!") + + +class _LazyAutoMapping(OrderedDict): + """ + " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed. + + Args: + - config_mapping: The map model type to config class + - model_mapping: The map model type to model (or tokenizer) class + """ + + def __init__(self, config_mapping, processor_mapping): + self._config_mapping = config_mapping + self._reverse_config_mapping = {v: k for k, v in config_mapping.items()} + self._processor_mapping = processor_mapping + self._processor_mapping._processor_mapping = self + self._extra_content = {} + self._modules = {} + + def __len__(self): + common_keys = set(self._config_mapping.keys()).intersection( + self._processor_mapping.keys() + ) + return len(common_keys) + len(self._extra_content) + + def __getitem__(self, key): + if key in self._extra_content: + return self._extra_content[key] + processor_type = self._reverse_config_mapping[key.__name__] + if processor_type in self._processor_mapping: + processor_name = self._processor_mapping[processor_type] + return self._load_attr_from_module(processor_type, processor_name) + + # Maybe there was several model types associated with this config. 
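+ # A single config class name can map to several processor types, so fall
+ # back to scanning the whole config mapping and return the first matching
+ # type that also has a processor registered for it.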
+ estimator_types = [ + k for k, v in self._config_mapping.items() if v == key.__name__ + ] + for ptype in estimator_types: + if ptype in self._processor_mapping: + processor_name = self._processor_mapping[ptype] + return self._load_attr_from_module(ptype, processor_name) + raise KeyError(key) + + def _load_attr_from_module(self, module_name, attr): + if module_name not in self._modules: + processor_category = PROCESSOR_CATEGORY_MAPPING_NAMES.get( + module_name, "models" + ) + processor_type = PROCESSOR_TYPE_MAPPING_NAMES.get(module_name, None) + if processor_type is not None: + package_name = f"biofit.{processor_category}.{processor_type}" + else: + package_name = f"biofit.{processor_category}" + self._modules[module_name] = importlib.import_module( + f".{module_name}", package_name + ) + return getattribute_from_module(self._modules[module_name], attr) + + def keys(self): + mapping_keys = [ + self._load_attr_from_module(key, name) + for key, name in self._config_mapping.items() + if key in self._processor_mapping.keys() + ] + return mapping_keys + list(self._extra_content.keys()) + + def get(self, key, default): + try: + return self.__getitem__(key) + except KeyError: + return default + + def __bool__(self): + return bool(self.keys()) + + def values(self): + mapping_values = [ + self._load_attr_from_module(key, name) + for key, name in self._processor_mapping.items() + if key in self._config_mapping.keys() + ] + return mapping_values + list(self._extra_content.values()) + + def items(self): + mapping_items = [ + ( + self._load_attr_from_module(key, self._config_mapping[key]), + self._load_attr_from_module(key, self._processor_mapping[key]), + ) + for key in self._processor_mapping.keys() + if key in self._config_mapping.keys() + ] + return mapping_items + list(self._extra_content.items()) + + def __iter__(self): + return iter(self.keys()) + + def __contains__(self, item): + if item in self._extra_content: + return True + if ( + not hasattr(item, "__name__") + or item.__name__ not in self._reverse_config_mapping + ): + return False + estimator_type = self._reverse_config_mapping[item.__name__] + return estimator_type in self._processor_mapping + + def register(self, key, value, exist_ok=False): + """ + Register a new processor in this mapping. 
+ """ + if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping: + estimator_type = self._reverse_config_mapping[key.__name__] + if estimator_type in self._processor_mapping.keys() and not exist_ok: + raise ValueError(f"'{key}' is already used by a Transformers model.") + + self._extra_content[key] = value diff --git a/src/biofit/auto/configuration_auto.py b/src/biofit/auto/configuration_auto.py new file mode 100644 index 0000000..fb324ba --- /dev/null +++ b/src/biofit/auto/configuration_auto.py @@ -0,0 +1,1269 @@ +import importlib +import re +import warnings +from collections import OrderedDict +from typing import List, Union + +from biocore.utils.import_util import ( + is_biosets_available, + is_transformers_available, +) +from biocore.utils.inspect import get_kwargs + +from biofit.processing import ProcessorConfig +from biofit.visualization.plotting import PlotterConfig + +PROCESSOR_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "ColumnSumStat"), + ("row_sum", "RowSumStat"), + ("correlation", "CorrelationStat"), + ("distance", "DistanceStat"), + ("row_missingness", "RowMissingnessStat"), + ("row_mean", "RowMeanStat"), + ("col_missingness", "ColumnMissingnessStat"), + ("col_mean", "ColumnMeanStat"), + ("lightgbm", "LightGBMModel"), + ("lasso", "LassoModel"), + ("random_forest", "RandomForestModel"), + ("logistic_regression", "LogisticRegressionModel"), + ("pcoa", "PCoAFeatureExtractor"), + ("pca", "PCAFeatureExtractor"), + ("label_binarizing", "LabelBinarizer"), + ("label_encoding", "LabelEncoder"), + ("upsampling", "UpSampler"), + ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelector"), + ("log", "LogTransformer"), + ("relative_abundance", "RelativeAbundanceScaler"), + ("tmm", "TMMScaler"), + ("css", "CumulativeSumScaler"), + ("missing_labels", "MissingLabelsSampleFilter"), + ("min_prevalence_sample_filter", "MinPrevalenceSampleFilter"), + ("row_abundance", "AbundanceSampleFilter"), + ] +) + +PLOTTER_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "PCoAFeatureExtractorPlotter"), + ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorPlotter"), + ("relative_abundance", "RelativeAbundancePlotter"), + ("css", "CumulativeSumScalerPlotter"), + ("min_prevalence_sample_filter", "MinPrevalenceRowPlotter"), + ("row_abundance", "AbundanceSampleFilterPlotter"), + ] +) + +CONFIG_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "ColumnSumStatConfig"), + ("row_sum", "RowSumStatConfig"), + ("correlation", "CorrelationStatConfig"), + ("distance", "DistanceStatConfig"), + ("row_missingness", "RowMissingnessStatConfig"), + ("row_mean", "RowMeanStatConfig"), + ("col_missingness", "ColumnMissingnessStatConfig"), + ("col_mean", "ColumnMeanStatConfig"), + ("lightgbm", "LightGBMConfig"), + ("lasso", "LassoConfig"), + ("random_forest", "RandomForestConfig"), + ("logistic_regression", "LogisticRegressionConfig"), + ("pcoa", "PCoAFeatureExtractorConfig"), + ("pca", "PCAFeatureExtractorConfig"), + ("label_binarizing", "LabelBinarizerConfig"), + ("label_encoding", "LabelEncoderConfig"), + ("upsampling", "UpSamplerConfig"), + ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"), + ("log", "LogTransformerConfig"), + ("relative_abundance", "RelativeAbundanceScalerConfig"), + ("tmm", "TMMScalerConfig"), + ("css", "CumulativeSumScalerConfig"), + ("missing_labels", "MissingLabelsSampleFilterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"), + ("row_abundance", "AbundanceSampleFilterConfig"), + ] +) + 
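+# Usage sketch (comments only, nothing here is executed): the *_MAPPING_NAMES
+# tables in this module hold class names as strings and are wrapped in
+# `_LazyConfigMapping` objects further down, so a config class is only imported
+# the first time its key is looked up. Assuming the experiment-specific config
+# variants named in these tables are exported by their modules, resolution
+# looks roughly like this:
+#
+#     from biofit.auto import AutoConfig
+#
+#     # "log" -> biofit.preprocessing.transformation.log -> LogTransformerConfig()
+#     log_config = AutoConfig.for_processor("log")
+#
+#     # with an experiment type, the per-experiment table (here, OTU) is used
+#     css_config = AutoConfig.for_processor("css", experiment_type="otu")
+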
+PLOTTER_CONFIG_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "PCoAFeatureExtractorPlotterConfig"), + ("min_prevalence_feature_selector", "MinPrevalencePlotterConfig"), + ("relative_abundance", "RelativeAbundancePlotterConfig"), + ("css", "CumulativeSumScalerPlotterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfig"), + ("row_abundance", "AbundanceSampleFilterPlotterConfig"), + ] +) + +PROCESSOR_CATEGORY_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "stat"), + ("row_sum", "stat"), + ("correlation", "stat"), + ("distance", "stat"), + ("row_missingness", "stat"), + ("row_mean", "stat"), + ("col_missingness", "stat"), + ("col_mean", "stat"), + ("lightgbm", "models"), + ("lasso", "models"), + ("random_forest", "models"), + ("logistic_regression", "models"), + ("pcoa", "preprocessing"), + ("pca", "preprocessing"), + ("label_binarizing", "preprocessing"), + ("label_encoding", "preprocessing"), + ("upsampling", "preprocessing"), + ("min_prevalence_feature_selector", "preprocessing"), + ("log", "preprocessing"), + ("relative_abundance", "preprocessing"), + ("tmm", "preprocessing"), + ("css", "preprocessing"), + ("missing_labels", "preprocessing"), + ("min_prevalence_sample_filter", "preprocessing"), + ("row_abundance", "preprocessing"), + ] +) + +PROCESSOR_TYPE_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "feature_extraction"), + ("pca", "feature_extraction"), + ("label_binarizing", "encoding"), + ("label_encoding", "encoding"), + ("upsampling", "resampling"), + ("min_prevalence_feature_selector", "feature_selection"), + ("log", "transformation"), + ("relative_abundance", "scaling"), + ("tmm", "scaling"), + ("css", "scaling"), + ("missing_labels", "filtering"), + ("min_prevalence_sample_filter", "filtering"), + ("row_abundance", "filtering"), + ] +) + +METABOLOMICS_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "ColumnSumStatConfig"), + ("row_sum", "RowSumStatConfig"), + ("correlation", "CorrelationStatConfig"), + ("distance", "DistanceStatConfig"), + ("row_missingness", "RowMissingnessStatConfig"), + ("row_mean", "RowMeanStatConfig"), + ("col_missingness", "ColumnMissingnessStatConfig"), + ("col_mean", "ColumnMeanStatConfig"), + ("lightgbm", "LightGBMConfig"), + ("lasso", "LassoConfig"), + ("random_forest", "RandomForestConfig"), + ("logistic_regression", "LogisticRegressionConfig"), + ("pcoa", "PCoAFeatureExtractorConfig"), + ("pca", "PCAFeatureExtractorConfig"), + ("label_binarizing", "LabelBinarizerConfig"), + ("label_encoding", "LabelEncoderConfig"), + ("upsampling", "UpSamplerConfig"), + ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"), + ("log", "LogTransformerConfig"), + ("relative_abundance", "RelativeAbundanceScalerConfig"), + ("tmm", "TMMScalerConfig"), + ("css", "CLRScalerConfig"), + ("imputation", "ImputationConfig"), + ("missing_labels", "MissingLabelsSampleFilterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"), + ("row_abundance", "AbundanceSampleFilterConfig"), + ] +) + +KMER_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "ColumnSumStatConfig"), + ("row_sum", "RowSumStatConfig"), + ("correlation", "CorrelationStatConfig"), + ("distance", "DistanceStatConfig"), + ("row_missingness", "RowMissingnessStatConfig"), + ("row_mean", "RowMeanStatConfig"), + ("col_missingness", "ColumnMissingnessStatConfig"), + ("col_mean", "ColumnMeanStatConfig"), + ("lightgbm", "LightGBMConfig"), + ("lasso", "LassoConfig"), + ("random_forest", "RandomForestConfig"), + ("logistic_regression", "LogisticRegressionConfig"), + 
("pcoa", "PCoAFeatureExtractorConfig"), + ("pca", "PCAFeatureExtractorConfig"), + ("label_binarizing", "LabelBinarizerConfig"), + ("label_encoding", "LabelEncoderConfig"), + ("upsampling", "UpSamplerConfig"), + ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"), + ("log", "LogTransformerConfig"), + ("relative_abundance", "RelativeAbundanceScalerConfig"), + ("tmm", "TMMScalerConfig"), + ("css", "CLRScalerConfig"), + ("imputation", "ImputationConfig"), + ("missing_labels", "MissingLabelsSampleFilterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"), + ("row_abundance", "AbundanceSampleFilterConfig"), + ] +) + +MS1_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "ColumnSumStatConfig"), + ("row_sum", "RowSumStatConfig"), + ("correlation", "CorrelationStatConfig"), + ("distance", "DistanceStatConfig"), + ("row_missingness", "RowMissingnessStatConfig"), + ("row_mean", "RowMeanStatConfig"), + ("col_missingness", "ColumnMissingnessStatConfig"), + ("col_mean", "ColumnMeanStatConfig"), + ("lightgbm", "LightGBMConfig"), + ("lasso", "LassoConfig"), + ("random_forest", "RandomForestConfig"), + ("logistic_regression", "LogisticRegressionConfig"), + ("pcoa", "PCoAFeatureExtractorConfig"), + ("pca", "PCAFeatureExtractorConfig"), + ("label_binarizing", "LabelBinarizerConfig"), + ("label_encoding", "LabelEncoderConfig"), + ("upsampling", "UpSamplerConfig"), + ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"), + ("log", "LogTransformerConfig"), + ("relative_abundance", "RelativeAbundanceScalerConfig"), + ("tmm", "TMMScalerConfig"), + ("css", "CLRScalerConfig"), + ("imputation", "ImputationConfig"), + ("missing_labels", "MissingLabelsSampleFilterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"), + ("row_abundance", "AbundanceSampleFilterConfig"), + ] +) + +MALDI_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "ColumnSumStatConfig"), + ("row_sum", "RowSumStatConfig"), + ("correlation", "CorrelationStatConfig"), + ("distance", "DistanceStatConfig"), + ("row_missingness", "RowMissingnessStatConfig"), + ("row_mean", "RowMeanStatConfig"), + ("col_missingness", "ColumnMissingnessStatConfig"), + ("col_mean", "ColumnMeanStatConfig"), + ("lightgbm", "LightGBMConfig"), + ("lasso", "LassoConfig"), + ("random_forest", "RandomForestConfig"), + ("logistic_regression", "LogisticRegressionConfig"), + ("pcoa", "PCoAFeatureExtractorConfig"), + ("pca", "PCAFeatureExtractorConfig"), + ("label_binarizing", "LabelBinarizerConfig"), + ("label_encoding", "LabelEncoderConfig"), + ("upsampling", "UpSamplerConfig"), + ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"), + ("log", "LogTransformerConfig"), + ("relative_abundance", "RelativeAbundanceScalerConfig"), + ("tmm", "TMMScalerConfig"), + ("css", "CLRScalerConfig"), + ("imputation", "ImputationConfig"), + ("missing_labels", "MissingLabelsSampleFilterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfigForMaldi"), + ("row_abundance", "AbundanceSampleFilterConfig"), + ] +) + +BIODATA_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "ColumnSumStatConfig"), + ("row_sum", "RowSumStatConfig"), + ("correlation", "CorrelationStatConfig"), + ("distance", "DistanceStatConfig"), + ("row_missingness", "RowMissingnessStatConfig"), + ("row_mean", "RowMeanStatConfig"), + ("col_missingness", "ColumnMissingnessStatConfig"), + ("col_mean", "ColumnMeanStatConfig"), + ("lightgbm", "LightGBMConfig"), + ("lasso", "LassoConfig"), + ("random_forest", 
"RandomForestConfig"), + ("logistic_regression", "LogisticRegressionConfig"), + ("pcoa", "PCoAFeatureExtractorConfig"), + ("pca", "PCAFeatureExtractorConfig"), + ("label_binarizing", "LabelBinarizerConfig"), + ("label_encoding", "LabelEncoderConfig"), + ("upsampling", "UpSamplerConfig"), + ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"), + ("log", "LogTransformerConfig"), + ("relative_abundance", "RelativeAbundanceScalerConfig"), + ("tmm", "TMMScalerConfig"), + ("css", "CLRScalerConfig"), + ("imputation", "ImputationConfig"), + ("missing_labels", "MissingLabelsSampleFilterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"), + ("row_abundance", "AbundanceSampleFilterConfig"), + ] +) + +RNA_SEQ_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "ColumnSumStatConfig"), + ("row_sum", "RowSumStatConfig"), + ("correlation", "CorrelationStatConfig"), + ("distance", "DistanceStatConfig"), + ("row_missingness", "RowMissingnessStatConfig"), + ("row_mean", "RowMeanStatConfig"), + ("col_missingness", "ColumnMissingnessStatConfig"), + ("col_mean", "ColumnMeanStatConfig"), + ("lightgbm", "LightGBMConfig"), + ("lasso", "LassoConfig"), + ("random_forest", "RandomForestConfig"), + ("logistic_regression", "LogisticRegressionConfig"), + ("pcoa", "PCoAFeatureExtractorConfig"), + ("pca", "PCAFeatureExtractorConfig"), + ("label_binarizing", "LabelBinarizerConfig"), + ("label_encoding", "LabelEncoderConfig"), + ("upsampling", "UpSamplerConfig"), + ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"), + ("log", "LogTransformerConfig"), + ("relative_abundance", "RelativeAbundanceScalerConfig"), + ("tmm", "TMMScalerConfig"), + ("css", "CLRScalerConfig"), + ("imputation", "ImputationConfig"), + ("missing_labels", "MissingLabelsSampleFilterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"), + ("row_abundance", "AbundanceSampleFilterConfig"), + ] +) + +PROTEOMICS_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "ColumnSumStatConfig"), + ("row_sum", "RowSumStatConfig"), + ("correlation", "CorrelationStatConfig"), + ("distance", "DistanceStatConfig"), + ("row_missingness", "RowMissingnessStatConfig"), + ("row_mean", "RowMeanStatConfig"), + ("col_missingness", "ColumnMissingnessStatConfig"), + ("col_mean", "ColumnMeanStatConfig"), + ("lightgbm", "LightGBMConfig"), + ("lasso", "LassoConfig"), + ("random_forest", "RandomForestConfig"), + ("logistic_regression", "LogisticRegressionConfig"), + ("pcoa", "PCoAFeatureExtractorConfig"), + ("pca", "PCAFeatureExtractorConfig"), + ("label_binarizing", "LabelBinarizerConfig"), + ("label_encoding", "LabelEncoderConfig"), + ("upsampling", "UpSamplerConfig"), + ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"), + ("log", "LogTransformerConfig"), + ("relative_abundance", "RelativeAbundanceScalerConfig"), + ("tmm", "TMMScalerConfig"), + ("css", "CLRScalerConfig"), + ("imputation", "ImputationConfig"), + ("missing_labels", "MissingLabelsSampleFilterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"), + ("row_abundance", "AbundanceSampleFilterConfig"), + ] +) + +METAGENOMICS_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "ColumnSumStatConfig"), + ("row_sum", "RowSumStatConfig"), + ("correlation", "CorrelationStatConfig"), + ("distance", "DistanceStatConfigForMetagenomics"), + ("row_missingness", "RowMissingnessStatConfig"), + ("row_mean", "RowMeanStatConfig"), + ("col_missingness", "ColumnMissingnessStatConfigForMetagenomics"), + 
("col_mean", "ColumnMeanStatConfig"), + ("lightgbm", "LightGBMConfig"), + ("lasso", "LassoConfig"), + ("random_forest", "RandomForestConfig"), + ("logistic_regression", "LogisticRegressionConfig"), + ("pcoa", "PCoAFeatureExtractorConfig"), + ("pca", "PCAFeatureExtractorConfig"), + ("label_binarizing", "LabelBinarizerConfig"), + ("label_encoding", "LabelEncoderConfig"), + ("upsampling", "UpSamplerConfigForMetagenomics"), + ( + "min_prevalence_feature_selector", + "MinPrevalenceFeatureSelectorConfigForMetagenomics", + ), + ("log", "LogTransformerConfig"), + ("relative_abundance", "RelativeAbundanceScalerConfig"), + ("tmm", "TMMScalerConfigForMetagenomics"), + ("css", "CumulativeSumScalerConfigForMetagenomics"), + ("imputation", "ImputationConfig"), + ("missing_labels", "MissingLabelsSampleFilterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"), + ("row_abundance", "AbundanceSampleFilterConfig"), + ] +) + +MS2_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "ColumnSumStatConfig"), + ("row_sum", "RowSumStatConfig"), + ("correlation", "CorrelationStatConfig"), + ("distance", "DistanceStatConfig"), + ("row_missingness", "RowMissingnessStatConfig"), + ("row_mean", "RowMeanStatConfig"), + ("col_missingness", "ColumnMissingnessStatConfig"), + ("col_mean", "ColumnMeanStatConfig"), + ("lightgbm", "LightGBMConfig"), + ("lasso", "LassoConfig"), + ("random_forest", "RandomForestConfig"), + ("logistic_regression", "LogisticRegressionConfig"), + ("pcoa", "PCoAFeatureExtractorConfig"), + ("pca", "PCAFeatureExtractorConfig"), + ("label_binarizing", "LabelBinarizerConfig"), + ("label_encoding", "LabelEncoderConfig"), + ("upsampling", "UpSamplerConfig"), + ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"), + ("log", "LogTransformerConfig"), + ("relative_abundance", "RelativeAbundanceScalerConfig"), + ("tmm", "TMMScalerConfig"), + ("css", "CLRScalerConfig"), + ("imputation", "ImputationConfig"), + ("missing_labels", "MissingLabelsSampleFilterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"), + ("row_abundance", "AbundanceSampleFilterConfig"), + ] +) + +OTU_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "ColumnSumStatConfigForOTU"), + ("row_sum", "RowSumStatConfigForOTU"), + ("correlation", "CorrelationStatConfig"), + ("distance", "DistanceStatConfigForOTU"), + ("row_missingness", "RowMissingnessStatConfigForOTU"), + ("row_mean", "RowMeanStatConfigForOTU"), + ("col_missingness", "ColumnMissingnessStatConfigForOTU"), + ("col_mean", "ColumnMeanStatConfigForOTU"), + ("lightgbm", "LightGBMConfig"), + ("lasso", "LassoConfigForOTU"), + ("random_forest", "RandomForestConfig"), + ("logistic_regression", "LogisticRegressionConfig"), + ("pcoa", "PCoAFeatureExtractorConfigForOTU"), + ("pca", "PCAFeatureExtractorConfig"), + ("label_binarizing", "LabelBinarizerConfig"), + ("label_encoding", "LabelEncoderConfig"), + ("upsampling", "UpSamplerConfigForOTU"), + ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfigForOTU"), + ("log", "LogTransformerConfigForOTU"), + ("relative_abundance", "RelativeAbundanceScalerConfig"), + ("tmm", "TMMScalerConfigForOTU"), + ("css", "CumulativeSumScalerConfigForOTU"), + ("imputation", "ImputationConfig"), + ("missing_labels", "MissingLabelsSampleFilterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfigForOTU"), + ("row_abundance", "AbundanceSampleFilterConfigForOTU"), + ] +) + +GENOMICS_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "ColumnSumStatConfig"), + 
("row_sum", "RowSumStatConfig"), + ("correlation", "CorrelationStatConfig"), + ("distance", "DistanceStatConfig"), + ("row_missingness", "RowMissingnessStatConfig"), + ("row_mean", "RowMeanStatConfig"), + ("col_missingness", "ColumnMissingnessStatConfig"), + ("col_mean", "ColumnMeanStatConfig"), + ("lightgbm", "LightGBMConfig"), + ("lasso", "LassoConfig"), + ("random_forest", "RandomForestConfig"), + ("logistic_regression", "LogisticRegressionConfig"), + ("pcoa", "PCoAFeatureExtractorConfig"), + ("pca", "PCAFeatureExtractorConfig"), + ("label_binarizing", "LabelBinarizerConfig"), + ("label_encoding", "LabelEncoderConfig"), + ("upsampling", "UpSamplerConfig"), + ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"), + ("log", "LogTransformerConfig"), + ("relative_abundance", "RelativeAbundanceScalerConfig"), + ("tmm", "TMMScalerConfig"), + ("css", "CLRScalerConfig"), + ("imputation", "ImputationConfig"), + ("missing_labels", "MissingLabelsSampleFilterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"), + ("row_abundance", "AbundanceSampleFilterConfig"), + ] +) + +SNP_MAPPING_NAMES = OrderedDict( + [ + ("col_sum", "ColumnSumStatConfig"), + ("row_sum", "RowSumStatConfig"), + ("correlation", "CorrelationStatConfig"), + ("distance", "DistanceStatConfigForSNP"), + ("row_missingness", "RowMissingnessStatConfigForSNP"), + ("row_mean", "RowMeanStatConfig"), + ("col_missingness", "ColumnMissingnessStatConfigForSNP"), + ("col_mean", "ColumnMeanStatConfig"), + ("lightgbm", "LightGBMConfig"), + ("lasso", "LassoConfig"), + ("random_forest", "RandomForestConfig"), + ("logistic_regression", "LogisticRegressionConfig"), + ("pcoa", "PCoAFeatureExtractorConfig"), + ("pca", "PCAFeatureExtractorConfig"), + ("label_binarizing", "LabelBinarizerConfig"), + ("label_encoding", "LabelEncoderConfig"), + ("upsampling", "UpSamplerConfigForSNP"), + ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfigForSNP"), + ("log", "LogTransformerConfig"), + ("relative_abundance", "RelativeAbundanceScalerConfig"), + ("tmm", "TMMScalerConfigForSNP"), + ("css", "CLRScalerConfig"), + ("imputation", "ImputationConfig"), + ("missing_labels", "MissingLabelsSampleFilterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfigForSNP"), + ("row_abundance", "AbundanceSampleFilterConfig"), + ] +) + +METABOLOMICS_PLOTTER_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "PCoAFeatureExtractorPlotterConfig"), + ("min_prevalence_feature_selector", "MinPrevalencePlotterConfig"), + ("relative_abundance", "RelativeAbundancePlotterConfig"), + ("css", "CumulativeSumScalerPlotterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfig"), + ("row_abundance", "AbundanceSampleFilterPlotterConfig"), + ] +) + +KMER_PLOTTER_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "PCoAFeatureExtractorPlotterConfig"), + ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForGenomics"), + ("relative_abundance", "RelativeAbundancePlotterConfig"), + ("css", "CumulativeSumScalerPlotterConfigForGenomics"), + ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForGenomics"), + ("row_abundance", "AbundanceSampleFilterPlotterConfigForGenomics"), + ] +) + +MS1_PLOTTER_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "PCoAFeatureExtractorPlotterConfig"), + ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForProteomics"), + ("relative_abundance", "RelativeAbundancePlotterConfig"), + ("css", "CumulativeSumScalerPlotterConfig"), + 
("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForProteomics"), + ("row_abundance", "AbundanceSampleFilterPlotterConfigForProteomics"), + ] +) + +MALDI_PLOTTER_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "PCoAFeatureExtractorPlotterConfig"), + ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForMaldi"), + ("relative_abundance", "RelativeAbundancePlotterConfigForMaldi"), + ("css", "CumulativeSumScalerPlotterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForMaldi"), + ("row_abundance", "AbundanceSampleFilterPlotterConfigForProteomics"), + ] +) + +BIODATA_PLOTTER_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "PCoAFeatureExtractorPlotterConfig"), + ("min_prevalence_feature_selector", "MinPrevalencePlotterConfig"), + ("relative_abundance", "RelativeAbundancePlotterConfig"), + ("css", "CumulativeSumScalerPlotterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfig"), + ("row_abundance", "AbundanceSampleFilterPlotterConfig"), + ] +) + +RNA_SEQ_PLOTTER_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "PCoAFeatureExtractorPlotterConfig"), + ("min_prevalence_feature_selector", "MinPrevalencePlotterConfig"), + ("relative_abundance", "RelativeAbundancePlotterConfig"), + ("css", "CumulativeSumScalerPlotterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfig"), + ("row_abundance", "AbundanceSampleFilterPlotterConfig"), + ] +) + +PROTEOMICS_PLOTTER_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "PCoAFeatureExtractorPlotterConfig"), + ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForProteomics"), + ("relative_abundance", "RelativeAbundancePlotterConfig"), + ("css", "CumulativeSumScalerPlotterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForProteomics"), + ("row_abundance", "AbundanceSampleFilterPlotterConfigForProteomics"), + ] +) + +METAGENOMICS_PLOTTER_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "PCoAFeatureExtractorPlotterConfig"), + ( + "min_prevalence_feature_selector", + "MinPrevalencePlotterConfigForMetagenomics", + ), + ("relative_abundance", "RelativeAbundancePlotterConfigForMetagenomics"), + ("css", "CumulativeSumScalerPlotterConfigForMetagenomics"), + ( + "min_prevalence_sample_filter", + "MinPrevalenceRowPlotterConfigForMetagenomics", + ), + ("row_abundance", "AbundanceSampleFilterPlotterConfigForMetagenomics"), + ] +) + +MS2_PLOTTER_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "PCoAFeatureExtractorPlotterConfig"), + ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForProteomics"), + ("relative_abundance", "RelativeAbundancePlotterConfig"), + ("css", "CumulativeSumScalerPlotterConfig"), + ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForProteomics"), + ("row_abundance", "AbundanceSampleFilterPlotterConfigForProteomics"), + ] +) + +OTU_PLOTTER_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "PCoAFeatureExtractorPlotterConfig"), + ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForOTU"), + ("relative_abundance", "RelativeAbundancePlotterConfigForOTU"), + ("css", "CumulativeSumScalerPlotterConfigForOTU"), + ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForOTU"), + ("row_abundance", "AbundanceSampleFilterPlotterConfigForOTU"), + ] +) + +GENOMICS_PLOTTER_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "PCoAFeatureExtractorPlotterConfig"), + ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForGenomics"), + ("relative_abundance", "RelativeAbundancePlotterConfig"), + ("css", 
"CumulativeSumScalerPlotterConfigForGenomics"), + ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForGenomics"), + ("row_abundance", "AbundanceSampleFilterPlotterConfigForGenomics"), + ] +) + +SNP_PLOTTER_MAPPING_NAMES = OrderedDict( + [ + ("pcoa", "PCoAFeatureExtractorPlotterConfig"), + ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForSNP"), + ("relative_abundance", "RelativeAbundancePlotterConfigForSNP"), + ("css", "CumulativeSumScalerPlotterConfigForSNP"), + ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForSNP"), + ("row_abundance", "AbundanceSampleFilterPlotterConfigForSNP"), + ] +) + + +def config_class_to_model_type(config): + """Converts a config class name to the corresponding model type""" + for key, cls in CONFIG_MAPPING_NAMES.items(): + if cls == config: + return key + for key, cls in CONFIG_MAPPING._extra_content.items(): + if cls.__name__ == config: + return key + return None + + +class _LazyConfigMapping(OrderedDict): + """ + A dictionary that lazily load its values when they are requested. + """ + + def __init__(self, mapping): + self._mapping = mapping + self._extra_content = {} + self._modules = {} + + def __getitem__(self, key: str): + if key in self._extra_content: + return self._extra_content[key] + if key not in self._mapping: + raise KeyError(key) + value = self._mapping[key] + module_name = key.replace("-", "_") + processor_category = PROCESSOR_CATEGORY_MAPPING_NAMES.get(key, "models") + processor_type = PROCESSOR_TYPE_MAPPING_NAMES.get(key, None) + try: + if module_name not in self._modules: + package = ( + f"biofit.{processor_category}" + if not processor_type + else f"biofit.{processor_category}.{processor_type}" + ) + self._modules[module_name] = importlib.import_module( + f".{module_name}", package + ) + except ImportError: + if is_transformers_available() and processor_category == "models": + from transformers.models.auto.configuration_auto import ( + model_type_to_module_name, + ) + + module_name = model_type_to_module_name(key) + if module_name not in self._modules: + self._modules[module_name] = importlib.import_module( + f".{module_name}", "transformers.models" + ) + else: + raise + if hasattr(self._modules[module_name], value): + return getattr(self._modules[module_name], value) + try: + biofit_module = importlib.import_module("biofit") + return getattr(biofit_module, value) + except AttributeError: + if is_transformers_available(): + transformers_module = importlib.import_module("transformers") + return getattr(transformers_module, value) + raise + + def keys(self): + return list(self._mapping.keys()) + list(self._extra_content.keys()) + + def values(self): + return [self[k] for k in self._mapping.keys()] + list( + self._extra_content.values() + ) + + def items(self): + return [(k, self[k]) for k in self._mapping.keys()] + list( + self._extra_content.items() + ) + + def __iter__(self): + return iter(list(self._mapping.keys()) + list(self._extra_content.keys())) + + def __contains__(self, item): + return item in self._mapping or item in self._extra_content + + def register(self, key, value, exist_ok=False): + """ + Register a new configuration in this mapping. + """ + if key in self._mapping.keys() and (not exist_ok): + raise ValueError( + f"'{key}' is already used by a Transformers config, pick another name." 
+ ) + self._extra_content[key] = value + + +CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES) +PLOTTER_CONFIG_MAPPING = _LazyConfigMapping(PLOTTER_CONFIG_MAPPING_NAMES) +METABOLOMICS_MAPPING = _LazyConfigMapping(METABOLOMICS_MAPPING_NAMES) +KMER_MAPPING = _LazyConfigMapping(KMER_MAPPING_NAMES) +MS1_MAPPING = _LazyConfigMapping(MS1_MAPPING_NAMES) +MALDI_MAPPING = _LazyConfigMapping(MALDI_MAPPING_NAMES) +BIODATA_MAPPING = _LazyConfigMapping(BIODATA_MAPPING_NAMES) +RNA_SEQ_MAPPING = _LazyConfigMapping(RNA_SEQ_MAPPING_NAMES) +PROTEOMICS_MAPPING = _LazyConfigMapping(PROTEOMICS_MAPPING_NAMES) +METAGENOMICS_MAPPING = _LazyConfigMapping(METAGENOMICS_MAPPING_NAMES) +MS2_MAPPING = _LazyConfigMapping(MS2_MAPPING_NAMES) +OTU_MAPPING = _LazyConfigMapping(OTU_MAPPING_NAMES) +GENOMICS_MAPPING = _LazyConfigMapping(GENOMICS_MAPPING_NAMES) +SNP_MAPPING = _LazyConfigMapping(SNP_MAPPING_NAMES) +METABOLOMICS_PLOTTER_MAPPING = _LazyConfigMapping(METABOLOMICS_PLOTTER_MAPPING_NAMES) +KMER_PLOTTER_MAPPING = _LazyConfigMapping(KMER_PLOTTER_MAPPING_NAMES) +MS1_PLOTTER_MAPPING = _LazyConfigMapping(MS1_PLOTTER_MAPPING_NAMES) +MALDI_PLOTTER_MAPPING = _LazyConfigMapping(MALDI_PLOTTER_MAPPING_NAMES) +BIODATA_PLOTTER_MAPPING = _LazyConfigMapping(BIODATA_PLOTTER_MAPPING_NAMES) +RNA_SEQ_PLOTTER_MAPPING = _LazyConfigMapping(RNA_SEQ_PLOTTER_MAPPING_NAMES) +PROTEOMICS_PLOTTER_MAPPING = _LazyConfigMapping(PROTEOMICS_PLOTTER_MAPPING_NAMES) +METAGENOMICS_PLOTTER_MAPPING = _LazyConfigMapping(METAGENOMICS_PLOTTER_MAPPING_NAMES) +MS2_PLOTTER_MAPPING = _LazyConfigMapping(MS2_PLOTTER_MAPPING_NAMES) +OTU_PLOTTER_MAPPING = _LazyConfigMapping(OTU_PLOTTER_MAPPING_NAMES) +GENOMICS_PLOTTER_MAPPING = _LazyConfigMapping(GENOMICS_PLOTTER_MAPPING_NAMES) +SNP_PLOTTER_MAPPING = _LazyConfigMapping(SNP_PLOTTER_MAPPING_NAMES) + + +DATASET_TO_MAPPER = { + "metabolomics": METABOLOMICS_MAPPING, + "kmer": KMER_MAPPING, + "ms1": MS1_MAPPING, + "maldi": MALDI_MAPPING, + "biodata": BIODATA_MAPPING, + "rna-seq": RNA_SEQ_MAPPING, + "proteomics": PROTEOMICS_MAPPING, + "metagenomics": METAGENOMICS_MAPPING, + "ms2": MS2_MAPPING, + "otu": OTU_MAPPING, + "genomics": GENOMICS_MAPPING, + "snp": SNP_MAPPING, +} + +DATASET_TO_MAPPER_NAMES = { + "metabolomics": METABOLOMICS_MAPPING_NAMES, + "kmer": KMER_MAPPING_NAMES, + "ms1": MS1_MAPPING_NAMES, + "maldi": MALDI_MAPPING_NAMES, + "biodata": BIODATA_MAPPING_NAMES, + "rna-seq": RNA_SEQ_MAPPING_NAMES, + "proteomics": PROTEOMICS_MAPPING_NAMES, + "metagenomics": METAGENOMICS_MAPPING_NAMES, + "ms2": MS2_MAPPING_NAMES, + "otu": OTU_MAPPING_NAMES, + "genomics": GENOMICS_MAPPING_NAMES, + "snp": SNP_MAPPING_NAMES, +} + +DATASET_PLT_TO_MAPPER = { + "metabolomics": METABOLOMICS_PLOTTER_MAPPING, + "kmer": KMER_PLOTTER_MAPPING, + "ms1": MS1_PLOTTER_MAPPING, + "maldi": MALDI_PLOTTER_MAPPING, + "biodata": BIODATA_PLOTTER_MAPPING, + "rna-seq": RNA_SEQ_PLOTTER_MAPPING, + "proteomics": PROTEOMICS_PLOTTER_MAPPING, + "metagenomics": METAGENOMICS_PLOTTER_MAPPING, + "ms2": MS2_PLOTTER_MAPPING, + "otu": OTU_PLOTTER_MAPPING, + "genomics": GENOMICS_PLOTTER_MAPPING, + "snp": SNP_PLOTTER_MAPPING, +} + +DATASET_PLT_TO_MAPPER_NAMES = { + "metabolomics": METABOLOMICS_PLOTTER_MAPPING_NAMES, + "kmer": KMER_PLOTTER_MAPPING_NAMES, + "ms1": MS1_PLOTTER_MAPPING_NAMES, + "maldi": MALDI_PLOTTER_MAPPING_NAMES, + "biodata": BIODATA_PLOTTER_MAPPING_NAMES, + "rna-seq": RNA_SEQ_PLOTTER_MAPPING_NAMES, + "proteomics": PROTEOMICS_PLOTTER_MAPPING_NAMES, + "metagenomics": METAGENOMICS_PLOTTER_MAPPING_NAMES, + "ms2": MS2_PLOTTER_MAPPING_NAMES, + "otu": 
OTU_PLOTTER_MAPPING_NAMES, + "genomics": GENOMICS_PLOTTER_MAPPING_NAMES, + "snp": SNP_PLOTTER_MAPPING_NAMES, +} + + +DATASET_PREPROCESSOR_MAPPING_NAMES = OrderedDict( + [ + ( + "otu", + [ + "row_abundance", + "min_prevalence_feature_selector", + "css", + ], + ), + ( + "asv", + [ + "row_abundance", + "min_prevalence_feature_selector", + "css", + ], + ), + ( + "abundance", + [ + "row_abundance", + "min_prevalence_feature_selector", + "css", + ], + ), + ( + "metagenomics", + [ + "row_abundance", + "min_prevalence_feature_selector", + "css", + ], + ), + ( + "snp", + [ + "min_prevalence_sample_filter", + "min_prevalence_feature_selector", + ], + ), + ( + "genomics", + [ + "min_prevalence_sample_filter", + "min_prevalence_feature_selector", + ], + ), + ( + "maldi", + [ + "min_prevalence_sample_filter", + "relative_abundance", + "min_prevalence_feature_selector", + ], + ), + ( + "metabolomics", + [ + "min_prevalence_sample_filter", + "min_prevalence_feature_selector", + ], + ), + ( + "proteomics", + [ + "min_prevalence_sample_filter", + "min_prevalence_feature_selector", + ], + ), + ( + "transcriptomics", + [ + "min_prevalence_sample_filter", + "min_prevalence_feature_selector", + ], + ), + ] +) + + +class _LazyLoadAllMappings(OrderedDict): + """ + A mapping that will load all pairs of key values at the first access (either by indexing, requestions keys, values, + etc.) + + Args: + mapping: The mapping to load. + """ + + def __init__(self, mapping): + self._mapping = mapping + self._initialized = False + self._data = {} + + def _initialize(self): + if self._initialized: + return + warnings.warn( + "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.", + FutureWarning, + ) + for processor_name, map_name in self._mapping.items(): + processor_category = PROCESSOR_CATEGORY_MAPPING_NAMES.get( + processor_name, "models" + ) + processor_type = PROCESSOR_TYPE_MAPPING_NAMES.get(processor_name, None) + try: + module_name = processor_name.replace("-", "_") + package = ( + f"biofit.{processor_category}" + if not processor_type + else f"biofit.{processor_category}.{processor_type}" + ) + module = importlib.import_module(f".{module_name}", package) + mapping = getattr(module, map_name) + except ImportError: + if is_transformers_available() and processor_category == "models": + from transformers.models.auto.configuration_auto import ( + model_type_to_module_name, + ) + + module_name = model_type_to_module_name(processor_name) + module = importlib.import_module( + f".{module_name}", "transformers.models" + ) + mapping = getattr(module, map_name) + self._data.update(mapping) + self._initialized = True + + def __getitem__(self, key): + self._initialize() + return self._data[key] + + def keys(self): + self._initialize() + return self._data.keys() + + def values(self): + self._initialize() + return self._data.values() + + def items(self): + self._initialize() + return self._data.keys() + + def __iter__(self): + self._initialize() + return iter(self._data) + + def __contains__(self, item): + self._initialize() + return item in self._data + + +def _get_class_name(model_class: Union[str, List[str]]): + if isinstance(model_class, (list, tuple)): + return " or ".join([f"[`{c}`]" for c in model_class if c is not None]) + return f"[`{model_class}`]" + + +def _list_model_options(indent, config_to_class=None, use_model_types=True): + if config_to_class is None and (not use_model_types): + 
raise ValueError( + "Using `use_model_types=False` requires a `config_to_class` dictionary." + ) + if use_model_types: + if config_to_class is None: + model_type_to_name = { + model_type: f"[`{config}`]" + for model_type, config in CONFIG_MAPPING_NAMES.items() + } + else: + model_type_to_name = { + model_type: _get_class_name(model_class) + for model_type, model_class in config_to_class.items() + if model_type in PROCESSOR_MAPPING_NAMES + } + lines = [ + f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({PROCESSOR_MAPPING_NAMES[model_type]} model)" + for model_type in sorted(model_type_to_name.keys()) + ] + else: + config_to_name = { + CONFIG_MAPPING_NAMES[config]: _get_class_name(clas) + for config, clas in config_to_class.items() + if config in CONFIG_MAPPING_NAMES + } + config_to_model_name = { + config: PROCESSOR_MAPPING_NAMES[model_type] + for model_type, config in CONFIG_MAPPING_NAMES.items() + } + lines = [ + f"{indent}- [`{config_name}`] configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)" + for config_name in sorted(config_to_name.keys()) + ] + return "\n".join(lines) + + +def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True): + def docstring_decorator(fn): + docstrings = fn.__doc__ + if docstrings is None: + return fn + lines = docstrings.split("\n") + i = 0 + while ( + i < len(lines) and re.search("^(\\s*)List options\\s*$", lines[i]) is None + ): + i += 1 + if i < len(lines): + indent = re.search("^(\\s*)List options\\s*$", lines[i]).groups()[0] + if use_model_types: + indent = f"{indent} " + lines[i] = _list_model_options( + indent, config_to_class=config_to_class, use_model_types=use_model_types + ) + docstrings = "\n".join(lines) + else: + raise ValueError( + f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current docstring is:\n{docstrings}" + ) + fn.__doc__ = docstrings + return fn + + return docstring_decorator + + +class AutoConfig: + """ + This is a generic configuration class that will be instantiated as one of the configuration classes of the library + when created with the [`~AutoConfig.from_pretrained`] class method. + + This class cannot be instantiated directly using `__init__()` (throws an error). + """ + + def __init__(self): + raise EnvironmentError( + "AutoConfig is designed to be instantiated using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method." + ) + + @classmethod + def for_processor( + cls, model_type: str, *args, experiment_type: str = None, **kwargs + ): + if is_biosets_available(): + from biosets.packaged_modules import EXPERIMENT_TYPE_ALIAS + else: + EXPERIMENT_TYPE_ALIAS = {} + + experiment_type = EXPERIMENT_TYPE_ALIAS.get(experiment_type, experiment_type) + if experiment_type is not None: + if experiment_type in DATASET_TO_MAPPER: + mapper = DATASET_TO_MAPPER[experiment_type] + if model_type in mapper: + _config_class = mapper[model_type] + return _config_class(*args, **kwargs) + raise ValueError( + f"Unrecognized model identifier: {model_type}. 
Should contain one of {', '.join(mapper.keys())}" + ) + if model_type in CONFIG_MAPPING: + _config_class = CONFIG_MAPPING[model_type] + cls_kwargs = get_kwargs(kwargs, _config_class.__init__) + return _config_class(*args, **cls_kwargs) + elif ( + is_transformers_available() + and PROCESSOR_CATEGORY_MAPPING_NAMES.get(model_type, "models") == "models" + ): + from transformers.models.auto.configuration_auto import ( + AutoConfig as HfAutoConfig, + ) + + return HfAutoConfig.for_model(model_type, *args, **kwargs) + raise ValueError( + f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}" + ) + + @classmethod + @replace_list_option_in_docstrings() + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + if is_transformers_available(): + from transformers.models.auto.configuration_auto import ( + AutoConfig as HfAutoConfig, + ) + + return HfAutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + raise EnvironmentError( + "Using `AutoConfig.from_pretrained` requires the transformers library to be installed. You can install it with `pip install transformers`." + ) + + @staticmethod + def register(processor_type, config, exist_ok=False): + """ + Register a new configuration for this class. + + Args: + model_type (`str`): The model type like "bert" or "gpt". + config ([`PretrainedConfig`]): The config to register. + """ + types = ProcessorConfig + if is_transformers_available(): + from transformers import PretrainedConfig + + types = (PretrainedConfig, ProcessorConfig) + config_processor_type = getattr( + config, "processor_type", getattr(config, "model_type", None) + ) + if issubclass(config, types) and config_processor_type != processor_type: + raise ValueError( + f"The config you are passing has a `model_type` attribute that is not consistent with the model type you passed (config has {config_processor_type} and you passed {processor_type}. Fix one of those so they match!" + ) + CONFIG_MAPPING.register(processor_type, config, exist_ok=exist_ok) + + +class AutoPlotterConfig: + """ + This is a generic configuration class that will be instantiated as one of the configuration classes of the library + when created with the [`~AutoPlotterConfig.from_pretrained`] class method. + + This class cannot be instantiated directly using `__init__()` (throws an error). + """ + + @classmethod + def for_dataset(cls, dataset_name: str, *args, **kwargs): + if is_biosets_available(): + from biosets.packaged_modules import EXPERIMENT_TYPE_ALIAS + else: + EXPERIMENT_TYPE_ALIAS = {} + + dataset_name = EXPERIMENT_TYPE_ALIAS.get(dataset_name, dataset_name) + if dataset_name in DATASET_PREPROCESSOR_MAPPING_NAMES: + preprocessors = DATASET_PREPROCESSOR_MAPPING_NAMES[dataset_name] + mapper = DATASET_PLT_TO_MAPPER[dataset_name] + classes = [] + for preprocessor in preprocessors: + _config_class = mapper[preprocessor] + classes.append(_config_class(*args, **kwargs)) + return classes + + raise ValueError( + f"Unrecognized dataset identifier: {dataset_name}. 
Should contain one of {', '.join(DATASET_PREPROCESSOR_MAPPING_NAMES.keys())}" + ) + + @classmethod + def for_processor( + cls, processor_type: str, *args, dataset_name: str = None, **kwargs + ): + if dataset_name is not None: + if is_biosets_available(): + from biosets.packaged_modules import EXPERIMENT_TYPE_ALIAS + else: + EXPERIMENT_TYPE_ALIAS = {} + + dataset_name = EXPERIMENT_TYPE_ALIAS.get(dataset_name, dataset_name) + if dataset_name in DATASET_PLT_TO_MAPPER: + mapper = DATASET_PLT_TO_MAPPER[dataset_name] + if processor_type in mapper: + _config_class = mapper[processor_type] + return _config_class(*args, **kwargs) + raise ValueError( + f"Unrecognized model identifier: {processor_type}. Should contain one of {', '.join(mapper.keys())}" + ) + if processor_type in PLOTTER_CONFIG_MAPPING: + _config_class = PLOTTER_CONFIG_MAPPING[processor_type] + return _config_class(*args, **kwargs) + elif ( + is_transformers_available() + and PROCESSOR_CATEGORY_MAPPING_NAMES.get(processor_type, "models") + == "models" + ): + from transformers.models.auto.configuration_auto import ( + AutoConfig as HfAutoConfig, + ) + + return HfAutoConfig.for_model(processor_type, *args, **kwargs) + raise ValueError( + f"Unrecognized model identifier: {processor_type}. Should contain one of {', '.join(PLOTTER_CONFIG_MAPPING.keys())}" + ) + + @staticmethod + def register(processor_type, config, exist_ok=False): + """ + Register a new configuration for this class. + + Args: + model_type (`str`): The model type like "bert" or "gpt". + config ([`PretrainedConfig`]): The config to register. + """ + types = PlotterConfig + config_processor_type = getattr( + config, "processor_type", getattr(config, "model_type", None) + ) + if issubclass(config, types) and config_processor_type != processor_type: + raise ValueError( + f"The config you are passing has a `model_type` attribute that is not consistent with the model type you passed (config has {config_processor_type} and you passed {processor_type}. Fix one of those so they match!" + ) + PLOTTER_CONFIG_MAPPING.register(processor_type, config, exist_ok=exist_ok) + + +class AutoPreprocessorConfig(AutoConfig): + """ + This is a generic configuration class that will be instantiated as one of the configuration classes of the library + when created with the [`~AutoPreprocessorConfig.from_pretrained`] class method. + + This class cannot be instantiated directly using `__init__()` (throws an error). + """ + + @classmethod + def for_dataset(cls, dataset_name: str, *args, **kwargs): + if is_biosets_available(): + from biosets.packaged_modules import EXPERIMENT_TYPE_ALIAS + else: + EXPERIMENT_TYPE_ALIAS = {} + + dataset_name = EXPERIMENT_TYPE_ALIAS.get(dataset_name, dataset_name) + if dataset_name in DATASET_PREPROCESSOR_MAPPING_NAMES: + preprocessors = DATASET_PREPROCESSOR_MAPPING_NAMES[dataset_name] + mapper = DATASET_TO_MAPPER[dataset_name] + classes = [] + for preprocessor in preprocessors: + _config_class = mapper[preprocessor] + classes.append(_config_class(*args, **kwargs.get(preprocessor, {}))) + return classes + + raise ValueError( + f"Unrecognized dataset identifier: {dataset_name}. 
Should contain one of {', '.join(DATASET_PREPROCESSOR_MAPPING_NAMES.keys())}" + ) diff --git a/src/biofit/auto/modeling_auto.py b/src/biofit/auto/modeling_auto.py new file mode 100644 index 0000000..aeaa77f --- /dev/null +++ b/src/biofit/auto/modeling_auto.py @@ -0,0 +1,61 @@ +from collections import OrderedDict + +from biofit.auto.auto_factory import _BaseAutoModelClass, _LazyAutoMapping + +from .configuration_auto import CONFIG_MAPPING_NAMES + +MODEL_MAPPING_NAMES = OrderedDict( + [ + ("xgboost", "XGBoostModel"), + ("lightgbm", "LightGBMModel"), + ("catboost", "CatBoostModel"), + ("random_forest", "RandomForestModel"), + ("logistic_regression", "LogisticRegressionModel"), + ("lasso", "LassoModel"), + ("svm", "SVMModel"), + ("knn", "KNNModel"), + ] +) + +MODEL_FOR_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + ("xgboost", "XGBoostForClassification"), + ("lightgbm", "LightGBMForClassification"), + ("catboost", "CatBoostForClassification"), + ("random_forest", "RandomForestForClassification"), + ("logistic_regression", "LogisticRegressionForClassification"), + ("lasso", "LassoForClassification"), + ("svm", "SVMForClassification"), + ("knn", "KNNForClassification"), + ] +) + +MODEL_FOR_REGRESSION_MAPPING_NAMES = OrderedDict( + [ + ("xgboost", "XGBoostForRegression"), + ("lightgbm", "LightGBMForRegression"), + ("catboost", "CatBoostForRegression"), + ("random_forest", "RandomForestForRegression"), + ("logistic_regression", "LogisticRegressionForRegression"), + ("lasso", "LassoForRegression"), + ("svm", "SVMForRegression"), + ("knn", "KNNForRegression"), + ] +) + +MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) +MODEL_FOR_CLASSIFICATION_MAPPING_NAMES = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_CLASSIFICATION_MAPPING_NAMES +) + + +class AutoModel(_BaseAutoModelClass): + _processor_mapping = MODEL_MAPPING + + +class AutoModelForClassification(_BaseAutoModelClass): + _processor_mapping = MODEL_FOR_CLASSIFICATION_MAPPING_NAMES + + +class AutoModelForRegression(_BaseAutoModelClass): + _processor_mapping = MODEL_FOR_REGRESSION_MAPPING_NAMES diff --git a/src/biofit/auto/plotting_auto.py b/src/biofit/auto/plotting_auto.py new file mode 100644 index 0000000..491e1ac --- /dev/null +++ b/src/biofit/auto/plotting_auto.py @@ -0,0 +1,143 @@ +from typing import List, Union + +from biocore.utils.import_util import is_biosets_available +from biocore.utils.inspect import get_kwargs + +from biofit.auto.auto_factory import ( + _BaseAutoProcessorClass, + _LazyAutoMapping, +) +from biofit.processing import BaseProcessor +from biofit.visualization.plotting import BasePlotter +from biofit.visualization.plotting_utils import ( + display_image_carousel, + is_in_notebook, +) + +from .configuration_auto import ( + DATASET_PLT_TO_MAPPER_NAMES, + PLOTTER_CONFIG_MAPPING_NAMES, + PLOTTER_MAPPING_NAMES, + AutoPlotterConfig, +) +from .processing_auto import AutoPreprocessor, ProcessorPipeline + +PLOTTER_MAPPING = _LazyAutoMapping(PLOTTER_CONFIG_MAPPING_NAMES, PLOTTER_MAPPING_NAMES) + + +class PlotterPipeline: + def __init__(self, plotters: List[BasePlotter], processors: List[BaseProcessor]): + self.plotters = plotters + self.processors = processors + + def plot(self, X, *args, fit=True, **kwargs): + from datasets import Dataset, IterableDataset + + from biofit import Bioset + + if not isinstance(X, (Bioset, Dataset, IterableDataset)): + raise ValueError("X must be a Bioset or huggingface Dataset.") + pre_X = X + show = kwargs.pop("show", True) + images = [] + for plotter, processor in 
zip(self.plotters, self.processors): + fit_trans_kwargs = get_kwargs(kwargs, processor.fit_transform) + if fit: + after_X = processor.fit_transform(pre_X, *args, **fit_trans_kwargs) + else: + after_X = processor.transform(pre_X, *args, **fit_trans_kwargs) + if plotter.config._compare: + path = plotter.plot(pre_X, after_X, *args, show=False, **kwargs) + else: + path = plotter.plot(after_X, *args, show=False, **kwargs) + if isinstance(path, list): + images.extend(path) + else: + images.append(path) + pre_X = after_X + + if show and is_in_notebook(): + display_image_carousel(images) + + +class AutoPlotter(_BaseAutoProcessorClass): + _processor_mapping = PLOTTER_MAPPING + + @classmethod + def for_dataset(cls, dataset_name, **kwargs): + """Create a processor for a dataset. + + Args: + dataset (Bioset): The dataset to create a processor for. + + Returns: + Processor: The processor for the dataset. + """ + + if is_biosets_available(): + from biosets.packaged_modules import EXPERIMENT_TYPE_ALIAS + else: + EXPERIMENT_TYPE_ALIAS = {} + + dataset_name = EXPERIMENT_TYPE_ALIAS.get(dataset_name, dataset_name) + _plotter_mapping = _LazyAutoMapping( + DATASET_PLT_TO_MAPPER_NAMES.get(dataset_name), PLOTTER_MAPPING_NAMES + ) + configs = AutoPlotterConfig.for_dataset(dataset_name) + procs = [] + for config in configs: + config_kwargs = get_kwargs(kwargs, config.__class__.__init__) + procs.append( + _plotter_mapping[type(config)]._from_config(config, **config_kwargs) + ) + + processors = AutoPreprocessor.for_dataset(dataset_name) + + return PlotterPipeline(procs, processors) + + @classmethod + def from_processor( + cls, processor, dataset_name=None, **kwargs + ) -> Union[PlotterPipeline, BasePlotter]: + """Create a processor from another processor. + + Args: + processor (Processor): The processor to create a processor for. + + Returns: + Processor: The processor for the dataset. + """ + + def get_proc(proc, dataset_name=None, **kwargs): + dataset_name = dataset_name or proc.config.dataset_name + if dataset_name: + if is_biosets_available(): + from biosets.packaged_modules import EXPERIMENT_TYPE_ALIAS + else: + EXPERIMENT_TYPE_ALIAS = {} + dataset_name = EXPERIMENT_TYPE_ALIAS.get(dataset_name, dataset_name) + _plotter_mapping = _LazyAutoMapping( + DATASET_PLT_TO_MAPPER_NAMES.get(dataset_name), + PLOTTER_MAPPING_NAMES, + ) + config = AutoPlotterConfig.for_processor( + proc.config.processor_name, + dataset_name=dataset_name, + ) + else: + _plotter_mapping = PLOTTER_MAPPING + config = AutoPlotterConfig.for_processor(proc.config.processor_name) + config_kwargs = get_kwargs(kwargs, config.__class__.__init__) + return _plotter_mapping[type(config)]._from_config(config, **config_kwargs) + + if isinstance(processor, ProcessorPipeline): + plotters = [] + for proc in processor.steps: + plotters.append(get_proc(proc[1], dataset_name=dataset_name, **kwargs)) + return PlotterPipeline(plotters, processor.processors) + elif isinstance(processor, BaseProcessor): + return get_proc(processor, **kwargs) + else: + raise ValueError( + "processor must be a `biofit.ProcessorPipeline` or `biofit.BaseProcessor`." 
+ ) diff --git a/src/biofit/auto/processing_auto.py b/src/biofit/auto/processing_auto.py new file mode 100644 index 0000000..8258b88 --- /dev/null +++ b/src/biofit/auto/processing_auto.py @@ -0,0 +1,98 @@ +from typing import List + +from biocore.utils.import_util import is_biosets_available +from biocore.utils.inspect import get_kwargs +from sklearn.pipeline import Pipeline + +from biofit.auto.auto_factory import ( + _BaseAutoProcessorClass, + _get_class, + _LazyAutoMapping, +) +from biofit.processing import BaseProcessor + +from .configuration_auto import ( + CONFIG_MAPPING_NAMES, + DATASET_TO_MAPPER_NAMES, + PROCESSOR_MAPPING_NAMES, + AutoPreprocessorConfig, +) + +PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, PROCESSOR_MAPPING_NAMES) + + +class AutoProcessor(_BaseAutoProcessorClass): + _processor_mapping = PROCESSOR_MAPPING + + +class ProcessorPipeline(Pipeline): + """A pipeline of processors.""" + + def __init__(self, processors: List[BaseProcessor] = None, **kwargs): + if "steps" in kwargs: + return super().__init__(**kwargs) + + self.processors = processors + steps = [] + for i, processor in enumerate(processors): + if not isinstance(processor, tuple): + if hasattr(processor, "config"): + steps.append( + ( + f"{processor.config.processor_name}_{i + 1}", + processor, + ) + ) + else: + steps.append((f"{processor.__class__.__name__}_{i + 1}", processor)) + else: + steps.append(processor) + super().__init__(steps, **kwargs) + + def fit_transform(self, X, y=None, **kwargs): + for processor in self.processors: + if isinstance(processor, tuple): + p = processor[-1] + else: + p = processor + fit_transform_kwargs = get_kwargs(kwargs, p.fit_transform) + X = p.fit_transform(X, y, **fit_transform_kwargs) + return X + + def pop(self, index): + self.steps.pop(index) + return self.processors.pop(index) + + +class AutoPreprocessor(_BaseAutoProcessorClass): + _processor_mapping = PROCESSOR_MAPPING + + @classmethod + def for_dataset(cls, dataset_name, **kwargs): + """ + Create a preprocessor pipeline for a given dataset. + + Args: + dataset: str + The dataset name. + Returns: + ProcessorPipeline + The preprocessor pipeline for the dataset. 
+ """ + if is_biosets_available(): + from biosets.packaged_modules import EXPERIMENT_TYPE_ALIAS + else: + EXPERIMENT_TYPE_ALIAS = {} + + dataset_name = EXPERIMENT_TYPE_ALIAS.get(dataset_name, dataset_name) + _processor_mapping = _LazyAutoMapping( + DATASET_TO_MAPPER_NAMES.get(dataset_name), PROCESSOR_MAPPING_NAMES + ) + configs = AutoPreprocessorConfig.for_dataset(dataset_name, **kwargs) + procs = [ + _get_class(config, _processor_mapping)._from_config(config) + for config in configs + ] + + processors = ProcessorPipeline(procs) + return processors diff --git a/src/biofit/cli/__init__.py b/src/biofit/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/biofit/cli/install_r.py b/src/biofit/cli/install_r.py new file mode 100644 index 0000000..fd012fc --- /dev/null +++ b/src/biofit/cli/install_r.py @@ -0,0 +1,95 @@ +import os + +from biofit.integration.R.r_caller import ( + R_PLOTTING_DEPENDENCIES, + R_PREPROCESSING_DEPENDENCIES, + RCaller, +) +from biofit.utils import logging + +logger = logging.get_logger(__name__) + + +def add_to_parser(parser): + # Register the 'install' subcommand + install_r_parser = parser.add_parser( + "install", + help="Install R dependencies", + description="Install various dependencies including R and its packages.", + ) + + install_r_parser.add_argument( + "--all", + action="store_true", + help="Install all R dependencies for plotting and preprocessing", + ) + install_r_parser.add_argument( + "--plotting", + action="store_true", + help="Install R plotting dependencies", + ) + install_r_parser.add_argument( + "--preprocessing", + action="store_true", + help="Install R preprocessing dependencies", + ) + install_r_parser.add_argument( + "--cran", + nargs="+", + help="Install specific CRAN packages", + ) + install_r_parser.add_argument( + "--bioconductor", + nargs="+", + help="Install specific Bioconductor packages", + ) + install_r_parser.add_argument( + "--binary", + action="store_true", + help="Install with binaries only", + ) + + install_r_parser.add_argument( + "--r-home", + "-r", + type=str, + help="Path to the R installation directory which contains the library folder", + ) + return parser + + +def run(args): + if args.r_home: + os.environ["R_HOME"] = args.r_home + cran_deps = args.cran or [] + bioconductor_deps = args.bioconductor or [] + params = {} + if args.binary: + params["pkgType"] = "binary" + if args.all or args.plotting: + cran_deps += R_PLOTTING_DEPENDENCIES["cran"] + bioconductor_deps += R_PLOTTING_DEPENDENCIES["bioconductor"] + if args.all or args.preprocessing: + cran_deps += R_PREPROCESSING_DEPENDENCIES["cran"] + bioconductor_deps += R_PREPROCESSING_DEPENDENCIES["bioconductor"] + + if cran_deps: + logger.info(f"Checking for CRAN packages: {', '.join(cran_deps)}") + RCaller.verify_r_dependencies( + cran_dependencies=cran_deps, + bioconductor_dependencies=[], + install_missing=True, + ) + logger.info("CRAN dependencies installed successfully.") + + if bioconductor_deps: + logger.info(f"Checking Bioconductor packages: {', '.join(bioconductor_deps)}") + RCaller.verify_r_dependencies( + cran_dependencies=[], + bioconductor_dependencies=bioconductor_deps, + install_missing=True, + ) + logger.info("Bioconductor dependencies installed successfully.") + + if not cran_deps and not bioconductor_deps: + logger.info("No R dependencies to install.") diff --git a/src/biofit/cli/main.py b/src/biofit/cli/main.py new file mode 100644 index 0000000..ded0aa2 --- /dev/null +++ b/src/biofit/cli/main.py @@ -0,0 +1,50 @@ +import argparse +import 
sys + +import biofit.cli.install_r as install_r +from biofit.utils import logging + +logger = logging.get_logger(__name__) + + +def create_parser(): + parser = argparse.ArgumentParser( + description="BIOFIT: General-purpose Omics Machine Learning framework" + ) + + # add a quiet option + parser.add_argument( + "--quiet", + "-q", + action="store_true", + help="Suppress all logging output except for errors", + ) + subparsers = parser.add_subparsers(dest="subcommand") + subparsers.required = True + + return parser, subparsers + + +def main(): + parser, subparsers = create_parser() + subparsers = install_r.add_to_parser(subparsers) + args = parser.parse_args() + + if args.quiet: + logging.set_verbosity(logging.ERROR) + else: + logging.set_verbosity(logging.INFO) + + try: + if args.subcommand == "install": + install_r.run(args) + else: + parser.print_help() + sys.exit(1) + except Exception as e: + logger.error(str(e)) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/src/biofit/config.py b/src/biofit/config.py new file mode 100644 index 0000000..63c2c89 --- /dev/null +++ b/src/biofit/config.py @@ -0,0 +1,231 @@ +import importlib +import importlib.metadata +import importlib.util +import logging +import os +from pathlib import Path + +from packaging import version + +logger = logging.getLogger(__name__) + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_FALSE_VALUES = {"0", "OFF", "NO", "FALSE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) +ENV_VARS_FALSE_AND_AUTO_VALUES = ENV_VARS_FALSE_VALUES.union({"AUTO"}) + + +DEFAULT_XDG_CACHE_HOME = "~/.cache" +XDG_CACHE_HOME = os.getenv("XDG_CACHE_HOME", DEFAULT_XDG_CACHE_HOME) +DEFAULT_BIOFIT_CACHE_HOME = os.path.join(XDG_CACHE_HOME, "biofit") +BIOFIT_CACHE_HOME = os.path.expanduser( + os.getenv("BIOFIT_HOME", DEFAULT_BIOFIT_CACHE_HOME) +) + +DEFAULT_BIOFIT_DATASETS_CACHE = os.path.join(BIOFIT_CACHE_HOME, "datasets") +BIOFIT_DATASETS_CACHE = Path( + os.getenv("BIOFIT_DATASETS_CACHE", DEFAULT_BIOFIT_DATASETS_CACHE) +) +DEFAULT_BIOFIT_PREPROCESSORS_CACHE = os.path.join(BIOFIT_CACHE_HOME, "processors") +BIOFIT_PROCESSORS_CACHE = Path( + os.getenv("BIOFIT_PROCESSORS_CACHE", DEFAULT_BIOFIT_PREPROCESSORS_CACHE) +) + +DEFAULT_BIOFIT_METRICS_CACHE = os.path.join(BIOFIT_CACHE_HOME, "metrics") +BIOFIT_METRICS_CACHE = Path( + os.getenv("BIOFIT_METRICS_CACHE", DEFAULT_BIOFIT_METRICS_CACHE) +) + +DEFAULT_BIOFIT_MODULES_CACHE = os.path.join(BIOFIT_CACHE_HOME, "modules") +BIOFIT_MODULES_CACHE = Path( + os.getenv("BIOFIT_MODULES_CACHE", DEFAULT_BIOFIT_MODULES_CACHE) +) + +BIOFIT_DYNAMIC_MODULE_NAME = Path( + os.getenv("BIOFIT_DYNAMIC_MODULE_NAME", "datasets_modules") +) + +BIOFIT_CACHE_HOME = os.getenv( + "BIOFIT_CACHE_HOME", + os.path.join(os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache")), "biofit"), +) +DEFAULT_BIOFIT_PATCHES_CACHE = os.path.join(BIOFIT_CACHE_HOME, "patches") +BIOFIT_PATCHES_CACHE = Path( + os.getenv("BIOFIT_PATCHES_CACHE", DEFAULT_BIOFIT_PATCHES_CACHE) +) + +DOWNLOADED_PATCHES_DIR = "downloads" +DEFAULT_DOWNLOADED_PATCHES_PATH = os.path.join( + BIOFIT_PATCHES_CACHE, DOWNLOADED_PATCHES_DIR +) +DOWNLOADED_PATCHES_PATH = Path( + os.getenv("BIOFIT_PATCHES_DOWNLOADED_PATCHES_PATH", DEFAULT_DOWNLOADED_PATCHES_PATH) +) + +EXTRACTED_PATCHES_DIR = "extracted" +DEFAULT_EXTRACTED_PATCHES_PATH = os.path.join( + DEFAULT_DOWNLOADED_PATCHES_PATH, EXTRACTED_PATCHES_DIR +) +EXTRACTED_PATCHES_PATH = Path( + os.getenv("BIOFIT_PATCHES_EXTRACTED_PATCHES_PATH", DEFAULT_EXTRACTED_PATCHES_PATH) +) + 
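+# Usage sketch (illustrative): the cache locations above are resolved once, at
+# import time, from environment variables, so overrides must be set beforehand:
+#
+#   import os
+#   os.environ["BIOFIT_PATCHES_CACHE"] = "/scratch/biofit/patches"  # hypothetical path
+#   import biofit.config  # the override is picked up here
+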
+PATCHES_FILENAME = "patches.json" +NO_PATCHES_FILENAME = "no_patches.json" +IS_CONDA = os.getenv("CONDA_PREFIX") is not None + + +PYARROW_AVAILABLE = False +PYARROW_VERSION = "N/A" + +PYARROW_AVAILABLE = importlib.util.find_spec("pyarrow") is not None +if PYARROW_AVAILABLE: + try: + PYARROW_VERSION = version.parse(importlib.metadata.version("pyarrow")) + logger.info(f"pyarrow version {PYARROW_VERSION} available.") + except importlib.metadata.PackageNotFoundError: + pass + +RPY2_AVAILABLE = False +RPY2_VERSION = "N/A" + +RPY2_AVAILABLE = importlib.util.find_spec("rpy2") is not None +if RPY2_AVAILABLE: + try: + RPY2_VERSION = version.parse(importlib.metadata.version("rpy2")) + logger.info(f"rpy2 version {RPY2_VERSION} available.") + except importlib.metadata.PackageNotFoundError: + pass + +RPY2_ARROW_AVAILABLE = False +RPY2_ARROW_VERSION = "N/A" + +RPY2_ARROW_AVAILABLE = importlib.util.find_spec("rpy2_arrow") is not None +if RPY2_ARROW_AVAILABLE: + try: + RPY2_ARROW_VERSION = version.parse(importlib.metadata.version("rpy2-arrow")) + logger.info(f"rpy2 version {RPY2_ARROW_VERSION} available.") + except importlib.metadata.PackageNotFoundError: + pass + +POLARS_VERSION = "N/A" +POLARS_AVAILABLE = False + +POLARS_AVAILABLE = importlib.util.find_spec("polars") is not None +if POLARS_AVAILABLE: + try: + POLARS_VERSION = version.parse(importlib.metadata.version("polars")) + logger.info(f"Polars version {POLARS_VERSION} available.") + except importlib.metadata.PackageNotFoundError: + pass + +DASK_VERSION = "N/A" +DASK_AVAILABLE = False + +DASK_AVAILABLE = importlib.util.find_spec("dask") is not None +if DASK_AVAILABLE: + try: + DASK_VERSION = version.parse(importlib.metadata.version("dask")) + logger.info(f"Dask version {DASK_VERSION} available.") + except importlib.metadata.PackageNotFoundError: + pass + + +USE_TF = os.environ.get("USE_TF", "AUTO").upper() +USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() +USE_JAX = os.environ.get("USE_JAX", "AUTO").upper() + + +TORCH_VERSION = "N/A" +TORCH_AVAILABLE = False + +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + TORCH_AVAILABLE = importlib.util.find_spec("torch") is not None + if TORCH_AVAILABLE: + try: + TORCH_VERSION = version.parse(importlib.metadata.version("torch")) + logger.info(f"PyTorch version {TORCH_VERSION} available.") + except importlib.metadata.PackageNotFoundError: + pass +else: + logger.info("Disabling PyTorch because USE_TF is set") + +TF_VERSION = "N/A" +TF_AVAILABLE = False + +if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: + TF_AVAILABLE = importlib.util.find_spec("tensorflow") is not None + if TF_AVAILABLE: + # For the metadata, we have to look for both tensorflow and tensorflow-cpu + for package in [ + "tensorflow", + "tensorflow-cpu", + "tensorflow-gpu", + "tf-nightly", + "tf-nightly-cpu", + "tf-nightly-gpu", + "intel-tensorflow", + "tensorflow-rocm", + "tensorflow-macos", + ]: + try: + TF_VERSION = version.parse(importlib.metadata.version(package)) + except importlib.metadata.PackageNotFoundError: + continue + else: + break + else: + TF_AVAILABLE = False + if TF_AVAILABLE: + if TF_VERSION.major < 2: + logger.info( + f"TensorFlow found but with version {TF_VERSION}. `datasets` requires version 2 minimum." 
+ ) + TF_AVAILABLE = False + else: + logger.info(f"TensorFlow version {TF_VERSION} available.") +else: + logger.info("Disabling Tensorflow because USE_TORCH is set") + + +JAX_VERSION = "N/A" +JAX_AVAILABLE = False + +if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: + JAX_AVAILABLE = ( + importlib.util.find_spec("jax") is not None + and importlib.util.find_spec("jaxlib") is not None + ) + if JAX_AVAILABLE: + try: + JAX_VERSION = version.parse(importlib.metadata.version("jax")) + logger.info(f"JAX version {JAX_VERSION} available.") + except importlib.metadata.PackageNotFoundError: + pass +else: + logger.info("Disabling JAX because USE_JAX is set to False") + + +USE_BEAM = os.environ.get("USE_BEAM", "AUTO").upper() +BEAM_VERSION = "N/A" +BEAM_AVAILABLE = False +if USE_BEAM in ENV_VARS_TRUE_AND_AUTO_VALUES: + try: + BEAM_VERSION = version.parse(importlib.metadata.version("apache_beam")) + BEAM_AVAILABLE = True + logger.info(f"Apache Beam version {BEAM_VERSION} available.") + except importlib.metadata.PackageNotFoundError: + pass +else: + logger.info("Disabling Apache Beam because USE_BEAM is set to False") + + +R_SCRIPTS = Path(__file__).parent / "integration/R/scripts" +BIOFIT_SKIP_R_DEPENDENCIES = ( + os.getenv("BIOFIT_SKIP_R_DEPENDENCIES", "true").upper() not in ENV_VARS_FALSE_VALUES +) +RECORDER_ENABLED = os.getenv("BIOFIT_RECORDER_ENABLED", "true").lower() == "true" + +NO_IMBALANCE_ADJUSTMENT = os.getenv("BIOFIT_NO_IMBALANCE_ADJUSTMENT", "false") == "true" + +PBAR_REFRESH_TIME_INTERVAL = 0.05 diff --git a/src/biofit/eval.py b/src/biofit/eval.py new file mode 100644 index 0000000..f541662 --- /dev/null +++ b/src/biofit/eval.py @@ -0,0 +1,296 @@ +from pathlib import Path +from typing import List, Union + +import numpy as np +from biocore import DataHandler +from sklearn.pipeline import Pipeline + +from biofit.metrics import calculate_metrics, confusion_matrix, get_metrics +from biofit.processing import BaseProcessor +from biofit.train_eval_utils import ( + _get_data, + get_model_info, + preprocess, + save_confusion_matrix, + save_metrics, + save_predictions, +) + + +def _predict( + models, + x_eval, + y_eval=None, + preprocessors=None, + use_proba=True, + task=None, + label_names=None, + cache_dir=None, +): + # model can be a list of models for each y_eval, if multi-label + # classification/regression is used + if not isinstance(models, list): + models = [models] + + y_preds = [] + for i, model in enumerate(models): + preprocessor = preprocessors[i] if preprocessors is not None else None + if isinstance(model, Pipeline): + if len(model.steps) > 1 and preprocessor is None: + preprocessor = Pipeline([p for p in model.steps[:-1]]) + model = model.steps[-1][1] + if preprocessor: + x_eval, _, _, _ = preprocess( + preprocessor, x_eval, cache_dir=cache_dir, transform_only=True + ) + + def fn(model, use_proba, x_test): + extra_kwargs = {} + if isinstance(model, BaseProcessor): + extra_kwargs = {"cache_dir": cache_dir, "load_from_cache_file": False} + + if use_proba and hasattr(model, "predict_proba"): + return model.predict_proba(x_test, **extra_kwargs) + elif hasattr(model, "predict"): + return model.predict(x_test, **extra_kwargs) + else: + return model.transform(x_test, **extra_kwargs) + + if not hasattr(model, "predict") and not hasattr(model, "predict_proba"): + for m in model: + y_pred = fn(m, use_proba, x_eval) + elif hasattr(model, "steps"): + for step in model.steps: + y_pred = fn( + step[-1] if isinstance(step, tuple) else step, use_proba, x_eval + ) + else: + y_pred = fn(model, use_proba, 
x_eval) + + y_pred_dims = DataHandler.get_shape(y_pred) + x_test_dims = DataHandler.get_shape(x_eval) + if y_eval is not None: + if task == "multiclass_classification": + if label_names is None: + label_names = np.unique(y_eval).tolist() + + if len(y_pred_dims) == 1 and ( + (y_pred_dims[0] % len(label_names)) != 0 + or (y_pred_dims[0] // len(label_names)) != x_test_dims[0] + ): + y_pred_ohe = np.zeros((x_test_dims[0], len(label_names))) + y_pred_ohe[np.arange(x_test_dims[0]), y_pred] = 1 + y_pred = y_pred_ohe + elif y_pred_dims[1] != len(label_names): + y_pred_ohe = np.zeros((y_pred_dims[0], len(label_names))) + y_pred_ohe[np.arange(y_pred_dims[0]), y_pred] = 1 + y_pred = y_pred_ohe + + if task == "binary_classification": + if len(y_pred_dims) == 1 and ( + (y_pred_dims[0] % 2) != 0 or (y_pred_dims[0] // 2) != x_test_dims[0] + ): + y_pred = DataHandler.concat([1 - y_pred, y_pred], axis=1) + elif len(y_pred_dims) > 1 and y_pred_dims[1] != 2: + y_pred = DataHandler.concat([1 - y_pred, y_pred], axis=1) + + y_preds.append(y_pred) + + if len(y_preds) == 1: + y_preds = y_preds[0] + else: + y_preds = DataHandler.concat(y_preds, axis=1) + return y_preds + + +def _update_metrics(y_true, y_pred, labels, metrics, task, results): + calc_results = calculate_metrics(metrics, y_true, y_pred, task, labels) + results.update(calc_results) + + +def evaluate( + model, + data, + target=None, + label_names=None, + input_columns: Union[List[str], str] = "auto", + target_columns: Union[List[str], str] = "auto", + preprocessors=None, + task=None, + metrics=None, + use_proba=True, + output_dir=None, + cache_dir=None, + config=None, +): + """ + Evaluate the model or models on the test data with improved flexibility and error + handling. + + Args: + data (np.ndarray): Data to be used for testing the model. + target (np.ndarray, *optional*): Labels corresponding to the test data. + models (list): A list of models to evaluate. + preprocessor (object, *optional*): + Preprocessor to apply to the data before evaluation. + metrics (dict): + Metrics to calculate for evaluation, with each metric being a callable. if + None, all metrics for the `task` will be calculated. + use_proba (bool): + Whether to use the predict_proba method of the model for predictions. + save_indices (bool): Whether to save the indices of the test set to a file. + output_dir (str): Directory to save the predictions table + cache_dir (str): + Directory to save the cache files. Defaults to f"{output_dir}/cache", if + output_dir is provided. + + Returns: + list: A list of dictionaries with metric results for each model. 
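+
+    Example:
+        A minimal sketch; ``clf``, ``test_df`` and the column name "label" are
+        illustrative and assume a fitted classifier plus a tabular test set:
+
+        >>> preds, results = evaluate(
+        ...     clf,
+        ...     data=test_df,
+        ...     target_columns="label",
+        ...     task="binary_classification",
+        ...     output_dir="eval_out",
+        ... )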
+ """ + + x_eval, y_eval, _, _, _, _, input_columns, target_columns = _get_data( + data=data, + target=target, + valid_data=None, + valid_target=None, + input_columns=input_columns, + target_columns=target_columns, + format=None, + target_required=False, + ) + + model_info = get_model_info(model, task) + if task is None: + task = model_info.get("task", None) + + if ( + preprocessors is not None + and not isinstance(preprocessors, list) + and isinstance(model, list) + ): + preprocessors = [preprocessors] * len(model) + + if "classification" in task and label_names is None: + label_names = DataHandler.to_list(DataHandler.unique(y_eval)) + + if task and metrics is None and y_eval is not None: + if task == "classification": + if len(label_names) > 2: + task = "multiclass_classification" + else: + task = "binary_classification" + metrics = get_metrics(task) + + if ( + preprocessors is not None + and not isinstance(preprocessors, list) + and isinstance(model, list) + ): + preprocessors = [preprocessors] * len(model) + + y_preds = _predict( + x_eval=x_eval, + models=model, + preprocessors=preprocessors, + use_proba=use_proba, + ) + y_preds = DataHandler.to_pandas(y_preds) + + results = None + if y_eval is not None: + results = calculate_metrics( + metrics, DataHandler.to_pandas(y_eval), y_preds, task, label_names + ) + + y_preds = y_preds.sort_index() + if output_dir is not None: + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + y_preds = DataHandler.to_pandas(y_preds) + labs = model_info.get("class_names", None) or label_names + if results is not None: + save_metrics(output_dir, results) + + cm = confusion_matrix(y_eval, y_preds, labels=labs) + save_confusion_matrix(output_dir, cm, label_names) + save_predictions( + output_dir, + y_preds, + data, + y_eval, + label_names=label_names, + target_columns=target_columns, + ) + + return y_preds, results + + +def predict( + model, + data, + input_columns: Union[List[str], str] = "auto", + preprocessors=None, + use_proba=True, + output_dir=None, + cache_dir=None, + config=None, +): + """ + Evaluate the model or models on the test data with improved flexibility and error + handling. + + Args: + data (np.ndarray): Data to be used for testing the model. + target (np.ndarray, *optional*): Labels corresponding to the test data. + models (list): A list of models to evaluate. + preprocessor (object, *optional*): + Preprocessor to apply to the data before evaluation. + use_proba (bool): + Whether to use the predict_proba method of the model for predictions. + output_dir (str): Directory to save the predictions table + cache_dir (str): + Directory to save the cache files. Defaults to f"{output_dir}/cache", if + output_dir is provided. + config (dict, *optional*): + Placeholder for additional configuration options. Currently not used. + + Returns: + list: A list of dictionaries with metric results for each model. 
+ """ + + model_info = get_model_info(model, task=None) + x_eval, _, _, _, _, _, input_columns, _ = _get_data( + data=data, + input_columns=input_columns, + target_columns=None, + format=None, + target_required=False, + ) + + if ( + preprocessors is not None + and not isinstance(preprocessors, list) + and isinstance(model, list) + ): + preprocessors = [preprocessors] * len(model) + + y_preds = _predict( + x_eval=x_eval, + models=model, + preprocessors=preprocessors, + use_proba=use_proba, + ) + y_preds = DataHandler.to_pandas(y_preds) + + if output_dir is not None: + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + y_pred = DataHandler.to_pandas(y_preds) + save_predictions( + output_dir, + y_pred, + data, + label_names=model_info.get("class_names", None), + ) + + return y_preds diff --git a/src/biofit/exceptions.py b/src/biofit/exceptions.py new file mode 100644 index 0000000..e69de29 diff --git a/src/biofit/integration/R/__init__.py b/src/biofit/integration/R/__init__.py new file mode 100644 index 0000000..261c8f4 --- /dev/null +++ b/src/biofit/integration/R/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa +from .r_caller import RCaller diff --git a/src/biofit/integration/R/r_caller.py b/src/biofit/integration/R/r_caller.py new file mode 100644 index 0000000..f3b2b17 --- /dev/null +++ b/src/biofit/integration/R/r_caller.py @@ -0,0 +1,736 @@ +import gc +import inspect +import io +import os +import re +import shutil +import subprocess +import sys +from glob import glob +from pathlib import Path +from typing import TYPE_CHECKING, List, Optional, TextIO + +import numpy as np +from biocore.utils.import_util import ( + is_rpy2_arrow_available, + is_rpy2_available, + requires_backends, +) + +import biofit.config +from biofit import config +from biofit.utils import logging + +if TYPE_CHECKING: + from rpy2.robjects.conversion import Converter + + +class PackageNotInstalledError(ImportError): + """Error occuring because the R package to import is not installed.""" + + pass + + +logger = logging.get_logger(__name__) + +_is_auto_install = None + +R_PLOTTING_DEPENDENCIES = { + "cran": [ + "ggplot2", + "arrow", + "circlize", + "RColorBrewer", + "scales", + "forcats", + "patchwork", + "reshape2", + "dplyr", + "tools", + ], + "bioconductor": ["ComplexHeatmap"], +} +R_PREPROCESSING_DEPENDENCIES = { + "cran": [], + "bioconductor": ["edgeR"], +} + + +def is_auto_install(): + global _is_auto_install + if _is_auto_install is None: + _is_auto_install = not biofit.config.BIOFIT_SKIP_R_DEPENDENCIES + return _is_auto_install + + +def enable_auto_install(): + global _is_auto_install + _is_auto_install = True + + +def disable_auto_install(): + global _is_auto_install + _is_auto_install = False + + +class ROutputCapture: + def __init__(self, stdout: TextIO = io.StringIO(), stderr: TextIO = io.StringIO()): + requires_backends("ROutputCapture", "rpy2") + from rpy2.rinterface_lib import callbacks + + # Create StringIO buffers to capture output + self.stdout = stdout + self.stderr = stderr + # Save original R console write functions + self._original_consolewrite_print = callbacks.consolewrite_print + self._original_consolewrite_warnerror = callbacks.consolewrite_warnerror + + def __enter__(self): + from rpy2.rinterface_lib import callbacks + + # Define new console write functions + def custom_consolewrite_print(output): + self.stdout.write(output) + + def custom_consolewrite_warnerror(output): + self.stderr.write(output) + + # Replace R console write functions with custom functions + 
callbacks.consolewrite_print = custom_consolewrite_print + callbacks.consolewrite_warnerror = custom_consolewrite_warnerror + + return self + + def __exit__(self, exc_type, exc_value, traceback): + from rpy2.rinterface_lib import callbacks + + # Restore original R console write functions + callbacks.consolewrite_print = self._original_consolewrite_print + callbacks.consolewrite_warnerror = self._original_consolewrite_warnerror + + # Reset the position of StringIO buffers to the beginning + self.stdout.seek(0) + self.stderr.seek(0) + + def get_stdout(self): + return self.stdout.getvalue() + + def get_stderr(self): + return self.stderr.getvalue() + + +def get_linux_distro_codename(): + """ + Retrieves the codename of the Linux distribution by parsing /etc/os-release. + + Returns: + str: The codename of the distribution (e.g., 'focal', 'bullseye', 'centos7'). + Returns 'Unknown' if the codename cannot be determined. + """ + os_release_path = "/etc/os-release" + + if not os.path.isfile(os_release_path): + return "Unknown" + + os_info = {} + + try: + with open(os_release_path, "r") as file: + for line in file: + # Remove any leading/trailing whitespace and skip empty lines + line = line.strip() + if not line or "=" not in line: + continue + + key, value = line.split("=", 1) + # Remove surrounding quotes if present + value = value.strip('"').strip("'") + os_info[key] = value + except Exception as e: + print(f"Error reading {os_release_path}: {e}") + return None + + # Handle specific distributions + distro_id = os_info.get("ID", "").lower() + + if distro_id in ["ubuntu", "debian"]: + # For Ubuntu and Debian, extract codename from VERSION or VERSION_ID + if "VERSION_CODENAME" in os_info: + return os_info["VERSION_CODENAME"].lower() + elif "VERSION" in os_info: + # Example VERSION: "20.04.5 LTS (Focal Fossa)" + match = re.search(r"\(([^)]+)\)", os_info["VERSION"]) + if match: + return match.group(1).lower() + elif "VERSION_ID" in os_info: + return os_info["VERSION_ID"].lower() + + elif distro_id in ["centos", "rhel", "fedora"]: + # For CentOS, RHEL, Fedora, use VERSION_ID or similar + if "VERSION_ID" in os_info: + return f"{distro_id}{os_info['VERSION_ID'].lower()}" + elif "VERSION" in os_info: + # Example VERSION: "7 (Core)" + match = re.search(r"(\d+)", os_info["VERSION"]) + if match: + return f"{distro_id}{match.group(1)}" + + # Attempt to get VERSION_CODENAME directly + if "VERSION_CODENAME" in os_info: + return os_info["VERSION_CODENAME"].lower() + + # Fallback: Try to extract codename from PRETTY_NAME or other fields + for key in ["PRETTY_NAME", "NAME", "DESCRIPTION"]: + if key in os_info: + match = re.search(r"\b([a-z]+)\b", os_info[key], re.IGNORECASE) + if match: + return match.group(1).lower() + + return None + + +def get_cran_info(): + """ + Detects the R version using rpy2 and constructs the appropriate CRAN URL based on the operating system + and, for Linux, the distribution codename. + + Returns: + str: The constructed CRAN URL, or a default 'latest' URL if R version cannot be detected. 
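+
+    Example:
+        On an Ubuntu host whose codename resolves to "jammy", the constructed
+        URL would look like (illustrative, not a guaranteed value):
+        https://packagemanager.posit.co/cran/__linux__/jammy/latest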
+ """ + cran_url = None + + from rpy2.rinterface_lib.embedded import RRuntimeError + + try: + if sys.platform in ["win32", "darwin"]: + # For Windows, include R version in the CRAN URL + cran_url = "https://packagemanager.posit.co/cran/latest" + elif sys.platform.startswith("linux"): + # For Linux and macOS + dist = get_linux_distro_codename() + if dist: + cran_url = ( + f"https://packagemanager.posit.co/cran/__linux__/{dist}/latest" + ) + else: + logger.warning( + "Linux distribution codename could not be determined. Using default CRAN URL." + ) + return None + else: + logger.warning("Unknown operating system. Using default CRAN URL.") + return None + + except RRuntimeError as e: + logger.error(f"Error detecting R version: {e}. Using default CRAN URL.") + + return cran_url + + +def get_bioconductor_info(): + bioc_mirror = "https://packagemanager.posit.co/bioconductor/latest" + bioconductor_config_file = f"{bioc_mirror}/config.yaml" + logger.info(f"Using Bioconductor mirror: {bioc_mirror}") + logger.info(f"Using Bioconductor config file: {bioconductor_config_file}") + + return bioconductor_config_file, bioc_mirror + + +class RCaller: + r_code = None + r_source = None + _r_context = {} + _global_vars = {} + + def _get_converter(self, obj, use_arrow=True) -> "Converter": + """Check if rpy2 is available and retrieves the right converter. + + Args: + obj (obj): The object (or name) calling this method. + + Raises: + RuntimeError: If rpy2 is not available. + + Returns: + Converter: The converter to convert R objects to Python objects. + """ + if is_rpy2_arrow_available() and use_arrow: + from rpy2.robjects import default_converter, numpy2ri, pandas2ri + from rpy2_arrow.arrow import converter + + _converter = ( + default_converter + numpy2ri.converter + pandas2ri.converter + converter + ) + elif is_rpy2_available(): + from rpy2.robjects import conversion, numpy2ri, pandas2ri + + _converter = ( + ( + conversion.get_conversion() + if getattr(conversion, "get_conversion", None) + else conversion.converter + ) + + numpy2ri.converter + + pandas2ri.converter + ) + else: + # suggest installing rpy2_arrow if rpy2 is not available + requires_backends(obj, "rpy2_arrow") + _converter = None + + return _converter + + @staticmethod + def get_r_home(): + """ + Retrieve the R_HOME directory across multiple systems, prioritizing local R installations, + particularly those within an active Conda environment. + + Returns: + str or None: The path to R_HOME if found, otherwise None. 
+ """ + # Step 1: Check R_HOME environment variable + r_home_env = os.environ.get("R_HOME") + if r_home_env and os.path.exists(r_home_env): + return r_home_env + + # Step 2: Check for R in active Conda environment + conda_prefix = os.environ.get("CONDA_PREFIX") + if conda_prefix: + if sys.platform == "win32": + r_executable = os.path.join(conda_prefix, "Scripts", "R.exe") + else: + r_executable = os.path.join(conda_prefix, "bin", "R") + + if os.path.exists(r_executable): + try: + r_home = subprocess.check_output( + [r_executable, "RHOME"], universal_newlines=True + ).strip() + if os.path.exists(r_home): + return r_home + except subprocess.CalledProcessError: + pass # R executable found but failed to get R_HOME + + # Step 3: Check system PATH for R executable + r_executable = shutil.which("R") + if r_executable: + try: + r_home = subprocess.check_output( + [r_executable, "RHOME"], universal_newlines=True + ).strip() + if os.path.exists(r_home): + return r_home + except subprocess.CalledProcessError: + pass # R executable found but failed to get R_HOME + + # Step 4: Check common installation paths (additional step for thoroughness) + potential_paths = [] + if sys.platform == "win32": + # Windows common installation paths + potential_paths.extend(glob("C:/Program Files/R/R-*/")) + + # Attempt to read R_HOME from Windows Registry + try: + import winreg + + reg_path = r"SOFTWARE\R-core\R" + for root in [winreg.HKEY_LOCAL_MACHINE, winreg.HKEY_CURRENT_USER]: + try: + key = winreg.OpenKey(root, reg_path) + r_home, _ = winreg.QueryValueEx(key, "InstallPath") + if os.path.exists(r_home): + return r_home + except FileNotFoundError: + continue + except ImportError: + pass # winreg not available + else: + # macOS and Linux common installation paths + potential_paths.extend( + [ + "/usr/local/lib/R", + "/usr/lib/R", + "/Library/Frameworks/R.framework/Resources", # macOS + ] + ) + + for path in potential_paths: + if os.path.exists(path): + return path + + # R_HOME could not be determined + return None + + @staticmethod + def verify_r_dependencies( + cran_dependencies: List[str] = R_PLOTTING_DEPENDENCIES["cran"] + + R_PREPROCESSING_DEPENDENCIES["cran"], + bioconductor_dependencies: List[str] = R_PREPROCESSING_DEPENDENCIES[ + "bioconductor" + ] + + R_PLOTTING_DEPENDENCIES["bioconductor"], + install_missing: Optional[bool] = None, + **kwargs, + ): + if install_missing is None: + install_missing = is_auto_install() + + r_home = None + try: + import rpy2.situation + + r_home = rpy2.situation.get_r_home() + except Exception: + pass + r_home = r_home or RCaller.get_r_home() + if r_home is None: + raise RuntimeError("R_HOME could not be determined.") + + import rpy2.robjects.packages as rpackages + from rpy2.robjects import ListVector + from rpy2.robjects.vectors import StrVector + + utils = rpackages.importr("utils") + base = rpackages.importr("base") + bioc_missing, cran_missing = [], [] + cran_url = get_cran_info() + if cran_url is not None and "repos" not in kwargs: + kwargs["repos"] = ListVector([("CRAN", cran_url)]) + if cran_dependencies: + names_to_install = [ + x for x in cran_dependencies if not rpackages.isinstalled(x) + ] + if len(names_to_install) > 0 and install_missing: + logger.info(f"Using R_HOME: {r_home}") + logger.info(f"Installing missing CRAN dependencies: {names_to_install}") + if "BiocManager" in names_to_install: + names_to_install.remove("BiocManager") + utils.chooseCRANmirror(ind=1) + utils.install_packages("BiocManager") + + if cran_url is not None: + logger.info("Using CRAN mirror: 
%s", cran_url) + base.options(**kwargs) + utils.install_packages(StrVector(names_to_install)) + # verify again to check if all dependencies are installed + names_to_install = [ + x for x in cran_dependencies if not rpackages.isinstalled(x) + ] + if len(names_to_install) > 0: + # conda package names + conda_packages = ["r-" + name.lower() for name in names_to_install] + + conda_install_cmd = ( + f"conda install -y -c conda-forge {' '.join(conda_packages)}" + ) + logger.warning( + f"Failed to install the following CRAN dependencies: " + f"{names_to_install}. If you are using conda, you can try " + "installing the dependencies via:\n" + f"{conda_install_cmd}" + ) + + else: + cran_missing = names_to_install + if bioconductor_dependencies: + names_to_install = [ + x for x in bioconductor_dependencies if not rpackages.isinstalled(x) + ] + if len(names_to_install) > 0 and install_missing: + logger.info(f"Using R_HOME: {r_home}") + logger.info( + "Installing missing Bioconductor dependencies: " + f"{bioconductor_dependencies}" + ) + if not rpackages.isinstalled("BiocManager"): + logger.info("Installing BiocManager") + utils.chooseCRANmirror(ind=1) + utils.install_packages("BiocManager") + + biocmanager = rpackages.importr("BiocManager") + bioconductor_config_file, bioc_mirror = get_bioconductor_info() + if cran_url is not None: + logger.info("Using CRAN mirror: %s", cran_url) + if bioc_mirror is not None: + logger.info("Using Bioconductor mirror: %s", bioc_mirror) + if bioc_mirror is not None and "BioC_mirror" not in kwargs: + kwargs["BioC_mirror"] = bioc_mirror + if ( + bioconductor_config_file is not None + and "BIOCONDUCTOR_CONFIG_FILE" not in kwargs + ): + kwargs["BIOCONDUCTOR_CONFIG_FILE"] = bioconductor_config_file + base.options(**kwargs) + biocmanager.install( + StrVector(names_to_install), ask=False, update=False + ) + # verify again to check if all dependencies are installed + names_to_install = [ + x for x in bioconductor_dependencies if not rpackages.isinstalled(x) + ] + if len(names_to_install) > 0: + conda_packages = [ + "bioconductor-" + name.lower() for name in names_to_install + ] + conda_install_cmd = ( + "conda install -y -c conda-forge -c bioconda " + f"{' '.join(conda_packages)}" + ) + msg = ( + f"Failed to install the following Bioconductor dependencies: " + f"{names_to_install}. If you are using conda, you can try " + "installing the dependencies via:\n" + f"conda install -y -c conda-forge {' '.join(conda_packages)}" + ) + if sys.platform == "win32": + msg += ( + "\nNote: Many of the Bioconductor packages are not " + "available on Windows. Please submit an issue on the BIOFIT " + "GitLab repository if you need help with installation." + ) + logger.warning(msg) + else: + bioc_missing = names_to_install + if (cran_missing or bioc_missing) and not install_missing: + instructions = "Run `biofit.integration.R.RCaller.verify_r_dependencies(" + if cran_missing: + instructions += f"cran_deps={cran_missing}, " + if bioc_missing: + instructions += f"bioc_deps={bioc_missing}, " + instructions += "install_missing=True)` to install missing dependencies." + raise PackageNotInstalledError( + f"Missing R dependencies: {cran_missing + bioc_missing}. 
" + f"{instructions}" + ) + + @classmethod + def from_script(cls, r_code_or_path=None): + self = cls() + self._global_vars["R_SCRIPTS_PATH"] = ( + f"{config.R_SCRIPTS.resolve().as_posix()}/" + ) + self.r_code = "" + if r_code_or_path: + if os.path.exists(r_code_or_path): + self.r_source = r_code_or_path + self._global_vars["CURRENT_DIR"] = ( + f"{Path(self.r_source).resolve().parent.as_posix()}/" + ) + with open(r_code_or_path, "r") as f: + self.r_code += f.read() + else: + self.r_code += r_code_or_path + return self + + def _run_r( + self, + code_runner, + r_code: str, + env=None, + add_globals=False, + quiet=False, + ) -> object: + """ + Run R code using the provided code_runner function. + + Args: + code_runner (Callable): The function to run the R code. + r_code (str): The R code to run. + env (SexpEnvironment, optional): The R environment to add global variables + to. Defaults to globalenv. + add_globals (bool, optional): Whether to add global variables to the R + context. Defaults to False. + quiet (bool, optional): Whether to suppress output from R. Defaults to + False. + + Returns: + The output from the R code or the converted output if convert is True. + """ + from rpy2.rinterface_lib.embedded import RRuntimeError + + def run_code(env): + if env is None: + from rpy2.rinterface import globalenv + + env = globalenv + if add_globals: + for name, value in self._global_vars.items(): + env[name] = value.removeprefix('"').removesuffix('"') + results = code_runner(r_code) + last_line = r_code.strip().split("\n")[-1].strip() + # for some reason, arrow tables are not returned properly when letting + # rpy2 handle the conversion. So we grab it directly from the R context. + if last_line in env: + results = env[last_line] + return results + + if quiet: + with ROutputCapture() as output_capture: + try: + return run_code(env) + except RRuntimeError as e: + # Attempt to capture the traceback from R. 
+ r_stdout = output_capture.get_stdout() + r_stderr = output_capture.get_stderr() + logger.error(f"captured stdout from R: {r_stdout}") + logger.error(f"captured stderr from R: {r_stderr}") + try: + r_traceback = "\n".join(code_runner("unlist(traceback())")) + except Exception as traceback_exc: + r_traceback = f"Failed to capture R traceback: {traceback_exc}" + raise RuntimeError(f"Error in R code: {e}\n{r_traceback}") from e + else: + try: + return run_code(env) + except RRuntimeError as e: + try: + r_traceback = "\n".join(code_runner("unlist(traceback())")) + except Exception as traceback_exc: + r_traceback = f"Failed to capture R traceback: {traceback_exc}" + raise RuntimeError(f"Error in R code: {e}\n{r_traceback}") from e + + def _convert_output(self, obj, code_runner, converter): + # check if the result is a list + is_list = converter.rpy2py( + self._run_r(code_runner, "is.list", quiet=True)(obj) + )[0] + if not is_list: + obj = [obj] + + names = self._run_r(code_runner, "names")(obj) + has_no_names = converter.rpy2py(self._run_r(code_runner, "is.null")(names))[0] + if has_no_names: + names = [f"{i}" for i in range(len(obj))] + out = {} + + for n, result in zip(names, obj): + if converter.rpy2py(self._run_r(code_runner, "is.vector")(result))[0]: + out[n] = np.array([r for r in result]) + elif converter.rpy2py(self._run_r(code_runner, "is.list")(result))[0]: + out[n] = [r for r in result] + elif result.__class__ in converter.rpy2py_nc_map and hasattr( + result, "rclass" + ): + nc_map = converter.rpy2py_nc_map[result.__class__] + for rclass in result.rclass: + if rclass in nc_map: + out[n] = nc_map[rclass](result) + break + else: + out[n] = converter.rpy2py(result) + else: + out[n] = converter.rpy2py(result) + + if len(out) == 1: + out = out[n] + return out + + def _cleanup(self, code_runner): + code_runner("rm(list=ls())") + # garbage collection in Python doesn't always clean up after R automatically + code_runner("gc()") + gc.collect() + + def _init_vars(self, code_runner, env, converter, **kwargs): + import rpy2.robjects as ro + from rpy2.rinterface import Sexp + + def convert_to_r(arg): + if isinstance(arg, Sexp): + return arg + elif arg is None: + return ro.NULL + elif isinstance(arg, (list, tuple)): + return converter.py2rpy([convert_to_r(a) for a in arg]) + elif isinstance(arg, dict): + return ro.ListVector(arg) + else: + return converter.py2rpy(arg) + + for n, arg in kwargs.items(): + env[n] = convert_to_r(arg) + + def get_method(self, name, enter_code="", exit_code="", convert=True, quiet=True): + _converter = self._get_converter(name) + import rpy2.rinterface + import rpy2.robjects as ro + + with rpy2.rinterface.local_context() as r_context: + self._run_r(ro.r, self.r_code, add_globals=True, quiet=quiet) + func_args = list(_converter.rpy2py(r_context[name]).formals().names) + self._cleanup(ro.r) + + def func(*args, context_kwargs: dict = None, **kwargs): + if args: + for n, arg in zip(func_args[: len(args)], args): + kwargs[n] = arg + + run_code = self.r_code + "\n" + enter_code + "\n" + + arg_str = ", ".join([f"{k} = {k}" for k in kwargs.keys() if k in func_args]) + + # additional variables not in the function signature + if context_kwargs: + kwargs.update(context_kwargs) + kwargs.update(self._global_vars) + + run_code += f"results <- {name}({arg_str})" + + run_code += "\n" + exit_code + "\nresults" + + out = self.run(run_code, convert=convert, quiet=quiet, **kwargs) + + return out + + parameters = [ + inspect.Parameter(name=arg, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD) 
+ for arg in func_args + if arg != "..." + ] + parameters += ( + [inspect.Parameter(name="kwargs", kind=inspect.Parameter.VAR_KEYWORD)] + if "..." in func_args + else [] + ) + func.__signature__ = inspect.Signature(parameters=parameters) + return func + + def run(self, r_code=None, convert=True, quiet=False, **kwargs): + """A method to run the R code with the provided arguments. + + Args: + r_code (str, optional): The R code to run. Defaults to the provided R code + given in the `from_script` method. + convert (bool, optional): Whether to convert the R output to Python objects. + Defaults to True. + **kwargs: The variables to pass to the R code + """ + r_code = r_code or self.r_code + if r_code: + r_code = r_code.strip() + _converter = self._get_converter(self.run) + + import rpy2.rinterface + import rpy2.robjects as ro + + with rpy2.rinterface.local_context() as r_context: + self._init_vars(ro.r, r_context, _converter, **kwargs) + out = self._run_r(ro.r, r_code, env=r_context, quiet=quiet) + if out is None: + logger.warning("No output from R code") + if convert: + if out is None: + raise RuntimeError( + "Conversion requested but no output from R code. Either " + "provide an output by adding the variable name on the last " + "line of the R code or set `convert` to `False`." + ) + out = self._convert_output(out, ro.r, _converter) + self._cleanup(ro.r) + return out + else: + raise RuntimeError( + "R code not provided, please provide R code or the path to the R script" + " via `from_script` method." + ) diff --git a/src/biofit/integration/R/scripts/analysis_utils.R b/src/biofit/integration/R/scripts/analysis_utils.R new file mode 100644 index 0000000..4a2e837 --- /dev/null +++ b/src/biofit/integration/R/scripts/analysis_utils.R @@ -0,0 +1,245 @@ +# cumNorm <- function(x, p = cumNormStatFast(x)) { +# normFactors <- calcNormFactors(x, p = p) +# x$normFactors <- normFactors +# return(x) +# } +# +# calcNormFactors <- function(x, p = cumNormStatFast(x)) { +# xx <- x +# xx[x == 0] <- NA +# qs <- rowQuantiles(xx, probs = p, na.rm = TRUE) +# norm_factors <- apply(xx, 1, function(row, qs) { +# row <- (row - .Machine$double.eps) +# sum(row[row <= qs]) +# }, qs) +# names(norm_factors) <- rownames(x) +# as.data.frame(norm_factors) +# } +# +# cumNormMat <- function(x, p = cumNormStatFast(x), sl = 1000) { +# xx <- x +# xx[x == 0] <- NA +# +# qs <- rowQuantiles(xx, probs = p, na.rm = TRUE) +# +# newMat <- apply(xx, 1, function(row, qs) { +# row <- (row - .Machine$double.eps) +# sum(row[row <= qs]) +# }, qs) +# nmat <- sweep(x, 1, newMat / sl, "/") +# return(nmat) +# } +# +# cumNormStat <- function(obj, qFlag = TRUE, pFlag = FALSE, rel = .1, ...) 
{ +# mat <- returnAppropriateObj(obj, FALSE, FALSE) +# if (any(rowSums(mat) == 0)) stop("Warning: empty feature") +# +# smat <- apply(mat, 1, function(row) sort(row, decreasing = FALSE)) +# ref <- colMeans(smat) +# +# yy <- mat +# yy[yy == 0] = NA +# +# refS <- sort(ref) +# +# k <- which(refS > 0)[1] +# lo <- (length(refS) - k + 1) +# +# if (qFlag) { +# diffr <- apply(yy, 1, function(row) { +# refS[k:length(refS)] - quantile(row, probs = seq(0, 1, length.out = lo), na.rm = TRUE) +# }) +# } else { +# diffr <- apply(yy, 1, function(row) { +# refS[k:length(refS)] - approx(sort(row, decreasing = FALSE), n = lo)$y +# }) +# } +# +# diffr2 <- matrixStats::colMedians(abs(diffr), na.rm = TRUE) +# +# x <- which(abs(diff(diffr2)) / diffr2[-1] > rel)[1] / length(diffr2) +# if (x <= 0.50) { +# message("Default value being used.") +# x <- 0.50 +# } +# +# return(x) +# } +# +# cumNormStatFast <- function(mat, pFlag = FALSE, rel = .1, ...) { +# +# smat <- apply(mat, 1, function(row) { +# sort(row[which(row > 0)], decreasing = TRUE) +# }) +# +# leng <- max(sapply(smat, length)) +# if (any(sapply(smat, length) == 1)) stop("Warning: sample with one or zero features") +# +# smat2 <- array(NA, dim = c(leng, nrow(mat))) +# for (i in 1:nrow(mat)) { +# smat2[leng:(leng - length(smat[i]) + 1), i] = smat[i] +# } +# +# rmat2 <- apply(smat2, 1, function(row) { +# quantile(row, probs = seq(0, 1, length.out = ncol(smat2)), na.rm = TRUE) +# }) +# +# smat2[is.na(smat2)] = 0 +# ref1 <- colMeans(smat2) +# +# ncols <- ncol(rmat2) +# diffr <- apply(rmat2, 1, function(row) { +# ref1 - row +# }) +# diffr1 <- matrixStats::colMedians(abs(diffr)) +# +# x <- which(abs(diff(diffr1)) / diffr1[-1] > rel)[1] / length(diffr1) +# if (x <= 0.50) { +# message("Default value being used.") +# x <- 0.50 +# } +# +# return(x) +# } + + +############### +## setup the MRexperiment object +## phenodat: sample data.frame with rownames being the +## sampleID (crucial), first column also the sampleID here +## OTUdata: feaure data.frame with rownames being the featID +## (crucial). first column also the featID +## cntdat: matrix data.frame with rownames being the featID and +## colnames as sampleID +############### + +setMRobject <- function(cntdat, phenodat, featdat) { + suppressPackageStartupMessages(require("metagenomeSeq")) + phenodatDF <- as.data.frame(phenodat) + phenotypeData <- AnnotatedDataFrame(phenodatDF) + + featdatDF <- as.data.frame(featdat) + OTUdata <- AnnotatedDataFrame(featdatDF) + + cntdatDF <- as.data.frame(cntdat) + + obj <- newMRexperiment(cntdatDF, phenoData = phenotypeData, featureData = OTUdata) + return(obj) +} + +############### + ### Filter OTU/Taxa presence with a chosen read count + ### 'present' for the number of samples the feature is at least present in + ### 'fdepth' for feature count cutoff for presence, a taxa/OTU has to have + ### at least this number of reads to be deemed present within a sample. 
+ ### 'depth' for the total number reads a sample has to at least have +############### +filterMRobject <- function(obj, present = 1, fdepth = 1, depth = 1, norm = FALSE) { + mat <- returnAppropriateObj(obj, norm = norm, log = FALSE) > (fdepth - 1) + cols <- which(colSums(MRcounts(obj, norm = norm)) >= depth) + rows <- which(rowSums(mat[, cols]) >= present) + + # Apply filter + obj <- obj[rows, cols] + return(obj) +} + + +############### +# Normalize the data with percentile param +# When percentile is NULL, normalization factor will be calculated +############### +normMRobject <- function(obj, percentile = NULL) { + # Calculating normalization factor + if (is.null(percentile)) { + percentile <- cumNormStat(obj) + + } + # Apply normalization + obj <- cumNorm(obj, p = signif(percentile, 4)) + return(obj) +} + +############### +# plot histograms to get a sense of the counts and presence +############### +plotHistMeta <- function(x, xlab="log2(Sum of counts)", main="Histogram", breaks=50) { + hist(x, xlab = xlab, main = main, breaks = breaks) +} + +################## +# Boxplot of distributions: Before and after normalization in log2 form +################## +boxplotMeta <- function(obj, keyAnnot, cols = NULL) { + #color setup + if (is.null(cols)) { + suppressPackageStartupMessages(require("RColorBrewer")) + cols <- c(brewer.pal(8, "Accent"), rev(brewer.pal(8, "Dark2")[-8]), brewer.pal(8, "Paired")) + } + + cl <- factor(pData(obj)[, keyAnnot]) + clcol <- cols[as.integer(cl)] + + #plot + par(mfrow = c(2, 1)) + boxplot(log2(1 + MRcounts(obj, norm = F, log = F)), col = clcol, outcol = clcol, ylab = "log2(1+Abundance)") + boxplot(log2(1 + MRcounts(obj, norm = T, log = F)), col = clcol, outcol = clcol, ylab = "log2(1+Abundance)") +} + + + +################## +# PCoA - Bray-Curtis +################## +fitPCoA <- function(obj, method = "bray", norm = T) { + suppressPackageStartupMessages(library("vegan")) + suppressPackageStartupMessages(library("ape")) + + #### distance computation and dimension reduction + d <- vegdist(t(MRcounts(obj, norm = norm, log = F)), method = method) + pcodecomp <- pcoa(d) + + # if any eigenvalue is negative, leverage Cailliez correction + if (sum(pcodecomp$values$Relative_eig < 0) > 0) { + pcodecomp <- pcoa(d, correction = "cailliez") + } + return(pcodecomp) +} + +################## +# Plot PCoA - Bray-Curtis +################## +plotPCoA <- function(obj, pcodecomp, keyAnnot, keyAnnot2=NULL, dimn=2, fileNameAdd="", cols=NULL){ + suppressPackageStartupMessages(require(ape)) + + #color setup + if(is.null(cols)){ + suppressPackageStartupMessages(require("RColorBrewer")) + cols = c(brewer.pal(8, "Accent"),rev(brewer.pal(8, "Dark2")[-8]), brewer.pal(8,"Paired")) + } + cl=cl2=factor(pData(obj)[,keyAnnot]) + clcol=cols[as.integer(cl)] + + # if a secondary annotation were specified, pch is used + if(!is.null(keyAnnot2)){ + cl2 <- factor(pData(obj)[,keyAnnot2]) + } + pch2use=c(1: length(levels(cl2))) + pchInput=pch2use[cl2] + + # Compute the percentage of explained variance + PCOAaxes <- pcodecomp$vectors[,c(1:dimn)] + eignPERC<- pcodecomp$values$Rel_corr_eig[c(1:dimn)] + + # plot PCoA + pairs(PCOAaxes, main=paste("PCoA", fileNameAdd), col=clcol, pch=pchInput, + cex=1.1, cex.labels = 2, cex.axis=1.5, upper.panel=NULL, + labels=paste("Dim",1:dimn,"\n",round(eignPERC,3)*100,"%")) + if(is.null(keyAnnot2)){ + legend("topright", legend = levels(cl), col=cols, pch=pch2use, ncol=3, cex=1)#, inset = c(0.1, 0.1)) + }else{ + legend("top", legend = levels(cl), col=cols, pch = 1, ncol=3, 
cex=1, inset = c(0.1, 0.1)) + legend("topright", legend = levels(cl2), pch=pch2use, cex=1, inset = c(0.1, -0.1)) + } + +} diff --git a/src/biofit/integration/R/scripts/dependencies.R b/src/biofit/integration/R/scripts/dependencies.R new file mode 100644 index 0000000..12735c5 --- /dev/null +++ b/src/biofit/integration/R/scripts/dependencies.R @@ -0,0 +1,10 @@ +source(paste0(R_SCRIPTS_PATH, "/utils.R")) + +install_dependencies <- function() { + dependencies <- c( + "ggplot2", "arrow", "circlize", "RColorBrewer", "scales", + "forcats", "patchwork", "reshape2", "ComplexHeatmap", + "edgeR", "dplyr", "tools" + ) + ensure_packages(dependencies) +} diff --git a/src/biofit/integration/R/scripts/plotting_utils.R b/src/biofit/integration/R/scripts/plotting_utils.R new file mode 100644 index 0000000..4349f94 --- /dev/null +++ b/src/biofit/integration/R/scripts/plotting_utils.R @@ -0,0 +1,1447 @@ +source(file.path(R_SCRIPTS_PATH, "utils.R")) + +sci_format <- function(x) { + suppressPackageStartupMessages(require(scales)) + ifelse((abs(x) < 1 & abs(x) > 0 & floor(log10(abs(x))) <= -2) | abs(x) >= 10000, + formatC(x, format = "e", digits = 2), formatC(x, format = "f", digits = 2) + ) +} + +#' `prepare_data_for_hist()` calculates row sums and columns sums for 1 or 2 data frames. +#' +#' Note: If you leave x2 blank, it will be NULL and not calculate. +#' +#' @param x1 data.frame: 1st or only data frame to have sums be calculated +#' @param x2 data.frame: 2nd data frame to have sums be calculated. Default is NULL. +#' +#' @returns list: a list of all the sums; 1st two values in the list are for x1, the next two for x2. +#' +prepare_data_for_hist <- function(x1, x2 = NULL) { + # Calculate row sum and col sums of dataset 1 + data1_sample <- rowSums(x1) + data1_feature <- colSums(x1) + + list_sums <- list(data1_sample, data1_feature) + + # Calculates row and col sums for second dataset if it is given Allows for + # situations with only a single dataset + if (!is.null(x2)) { + data2_sample <- rowSums(x2) + data2_feature <- colSums(x2) + + list_sums <- append(list_sums, list(data2_sample, data2_feature)) + } + + return(list_sums) +} + +#' `non_zero_sums()` calculates row sums and columns sums for non-zero values in 1 or 2 data frames. +#' +#' Note: If you leave x2 blank, it will be NULL and not calculate. +#' +#' @param x1 (data.frame) 1st or only data frame to have non-zero sums be calculated. +#' @param x2 (data.frame) 2nd data frame to have non-zero sums be calculated. Default is NULL. +#' +#' @returns (list) of all the sums; 1st value in the list are for x1, the 2nd is for x2 if used. +#' +non_zero_sums <- function(x1, x2 = NULL) { + # Sum all non-zero values for a dataset + x1_non_zero_row <- rowSums(x1 != 0) + x1_non_zero_col <- colSums(x1 != 0) + sums <- list(x1_non_zero_row, x1_non_zero_col) + + # Allows for situations with only a single dataset + if (!is.null(x2)) { + x2_non_zero_row <- rowSums(x2 != 0) + x2_non_zero_col <- colSums(x2 != 0) + sums <- append(sums, list(x2_non_zero_row, x2_non_zero_col)) + } + + # Returns list of non_zero rowSums + return(sums) +} + +#' `prepare_axis_label()` adjusts axis label names to include log transformation. +#' +#' @param label chr: the axis label to be adjusted. +#' @param log_type chr: the log transformation being applied. +#' +#' @returns chr: the new label with the transformation added. 
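+#'
+#' @examples
+#' \dontrun{
+#' # illustrative label; shows how the transformation is appended
+#' prepare_axis_label("Abundance", "log10_1p") # "Abundance (log10(x+1))"
+#' }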
+#' +prepare_axis_label <- function(label, log_type) { + if (grepl("1p", log_type)) { + if (grepl("_1p", log_type)) { + label_log <- gsub("_1p", "", log_type) + label <- paste0(label, " (", label_log, "(x+1))") + } else { + label <- paste0(label, " (ln(x+1))") + } + } else if (log_type == "log") { + label <- paste0(label, " (ln)") + } else { + label <- paste0(label, " (", log_type, ")") + } + return(label) +} + +log_transformation <- function(x, log_type) { + if (grepl("1p", log_type)) { + if (grepl("_1p", log_type)) { + label_log <- gsub("_1p", "", log_type) + if (label_log == "log10") { + return(log10(1 + x)) + } else if (label_log == "log2") { + return(log2(1 + x)) + } + } else { + return(log(1 + x)) + } + } else if (log_type == "log") { + return(log(x)) + } else if (log_type == "log2") { + return(log2(x)) + } else if (log_type == "log10") { + return(log10(x)) + } + return(x) +} + +#' +#' Transformations functions for different transformation functions +#' They are called by scale_*_continuous when we enter the name (first argument in trans_new()) +#' They just need to be here, don't need to be called any other time. +#' + +#' `log1p()` is a transformation function for log with x + 1 as an input +log2_1p <- function(x) { + log2(1 + x) +} + +#' `log2_1p_trans()` is a transformation function for log base 2 with x + 1 as an input +log2_1p_trans <- function() { + suppressPackageStartupMessages(require(scales)) + + trans_new("log2_1p", + transform = log2_1p, inverse = function(x) { + 2^x - 1 + }, breaks = trans_breaks(log2_1p, function(x) 2^x - 1), + domain = c(0, Inf) + ) +} + +#' `log10_1p()` is a transformation function for log base 10 with x + 1 as an input +log10_1p <- function(x) { + log10(1 + x) +} + +#' `log10_1p_trans()` is a transformation function for log base 10 with x + 1 as an input +log10_1p_trans <- function() { + suppressPackageStartupMessages(require(scales)) + + trans_new("log10_1p", + transform = log10_1p, inverse = function(x) { + 10^x - 1 + }, breaks = trans_breaks(log10_1p, function(x) 10^x - 1), + domain = c(0, Inf) + ) +} + +#' +#' `color_select()` prepares a vector of colours to use in a ggplot +#' +#' Note: Requires RColorBrewer and circulize packages +#' +#' @details It will use RColorBrewer's sets of colours as the base for the vector +#' If the number of colours needed are <= the number of colors in the set, +#' then the colours are used directly from the set +#' Else, the function will generate colors in between using a colour ramp +#' +#' @param levels (int) number of colours that are to be returned +#' @param col_set (chr) the RColorBrewer set that is to be used, and only those sets +#' ('Set1', 'Set2', 'Set3', 'Pastel2', 'Pastel1', 'Paired', 'Dark2', 'Accent') +#' +#' @returns (vec) a vector of colours +#' +#' @examples +#' \dontrun{ +#' colors <- color_select(5, "Set3") +#' } +color_select <- function(levels, col_set = "Set1") { + col_set <- match.arg(col_set, c( + "Set1", "Set2", "Set3", "Pastel2", "Pastel1", + "Paired", "Dark2", "Accent" + )) + + if (col_set %in% c("Set1", "Pastel1")) { + num_col <- 9 + } else if (col_set %in% c("Set3", "Paired")) { + num_col <- 12 + } else { + num_col <- 8 + } + + if (levels > num_col) { + at <- seq(0, levels, length.out = num_col) + color_fn <- circlize::colorRamp2(at, RColorBrewer::brewer.pal(num_col, col_set)) + colors <- color_fn(1:levels) + } else { + color_fn <- RColorBrewer::brewer.pal(num_col, col_set) + colors <- color_fn[1:levels] + } + return(colors) +} + + +#' Generate a simple histogram for one variable. 
+#' +#' `generate_histogram()` creates a simple histogram for a single numerical variable. +#' +#' Note: requires packages ggplot2 and rlang. +#' +#' @param data (data.frame or vector) a data frame that the function is obtaining the data from or a vector of data +#' If you pass a vector, it will be made into a data.frame with the label as the name of the column. +#' @param column (chr) the column with the data that is to be used in the histogram. +#' Default is value so label doesn't need to be provided and the function works. +#' @param xlab (chr) name of the x-axis. Default is 'X'. +#' @param ylab (chr) name of the y-axis. Default is 'Frequency'. +#' @param title (chr) title of the figure. Default is 'Histogram'. +#' @param bins (int) number of bins for the histogram. Default is 30. +#' @param font_size (dbl) size of the font for the plot. Default is 8. +#' @param alpha (dbl) Opacity of bars. Default is 0.6. +#' @param col_fill (chr) primary colour of the bars. Default is 'grey40'. +#' @param col_outline (chr) colour of the outline of the bars. Default is 'black'. +#' @param xlog (chr) logarithmic transformation of the x-axis. +#' Options include: (default) NULL (none), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p' +#' @param ylog (chr) logarithmic transformation of the y-axis. Default is NULL (none). +#' Options include: (default) NULL (none), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p' +#' +#' @returns the ggplot build of the final histogram. +#' +#' @examples +#' \dontrun{ +#' generate_histogram(mtcars, wt) +#' +#' generate_histogram(mtcars, wt, +#' xlab = "Weights", ylab = "Frequency", +#' title = "Weights of Cars", bins = 30, +#' col_fill = "red", col_outline = "black", xlog = "log2", +#' ylog = NULL +#' ) +#' } +generate_histogram <- function( + data, column = NULL, xlab = "X", ylab = "Frequency", + title = "Histogram", bins = 30, font_size = 8, alpha = 0.6, + col_fill = "grey40", col_outline = "black", + xlog = NULL, ylog = NULL) { + data <- convert_to_dataframe(data) + + column <- get_default_columns(data, column) + if (is.null(column)) { + data <- reshape2::melt(data) + column <- "value" + } + data <- validate_data(data, column) + + if (!is.null(xlog)) { + xlog <- match.arg(xlog, c( + "log2", "log10", "log", "log2_1p", "log10_1p", + "log1p" + )) + xlab <- prepare_axis_label(xlab, xlog) + } + if (!is.null(ylog)) { + ylog <- match.arg(ylog, c( + "log2", "log10", "log", "log2_1p", "log10_1p", + "log1p" + )) + ylab <- prepare_axis_label(ylab, ylog) + } + + suppressPackageStartupMessages(require(ggplot2)) + + ggplot(data, aes(x = .data[[column]])) + + geom_histogram( + bins = bins, fill = col_fill, + color = col_outline, alpha = alpha + ) + + labs(x = xlab, y = ylab, title = title) + + theme_bw() + + theme(text = element_text(size = font_size), plot.title = element_text(hjust = 0.5)) + + { + if (!is.null(xlog)) { + scale_x_continuous(trans = xlog) + } + } + + { + if (!is.null(ylog)) { + scale_y_continuous(trans = ylog) + } + } +} + +#' Generate a histogram to compare two sets of values +#' +#' `generate_comparison_histogram()` creates a histogram that compares values from two groups. Typically a before and after some transformation or filtering, but could also just be two categories. +#' +#' Note: requires packages ggplot2, RColorBrewer. +#' +#' @param data1 (vector) data of the 1st group. +#' @param data2 (vector) data of the 2nd group. 
+#' @param column1 (chr) name of column containing first set of data
+#' @param column2 (chr) name of column containing second set of data
+#' @param xlab (chr) name of the x-axis.
+#' @param ylab (chr) name of the y-axis. Default is 'Count'.
+#' @param title (chr) title of the figure. Default is 'Comparison Histogram'.
+#' @param bins (int) number of bins for the histogram. Default is 30.
+#' @param alpha (double) opacity of the histograms. Value between 0 and 1. Default is 0.6.
+#' @param legend_title (chr) name of the legend. Default is 'Feature Selection'.
+#' @param legend_position (chr) placement of the legend in the figure.
+#' Options include: (default) 'top', 'bottom', 'left', 'right'.
+#' @param subplot_title1 (chr) name of the values from the 1st dataset. Default is 'Before'.
+#' @param subplot_title2 (chr) name of the values from the 2nd dataset. Default is 'After'.
+#' @param col_set (chr) name of RColorBrewer set to be used for colouring.
+#' Options include: (default) 'Set1', 'Set2', 'Set3', 'Pastel2', 'Pastel1', 'Paired', 'Dark2', 'Accent'
+#' @param cols (chr (vec)) vector of colours that you wish to use. Defaults to NULL, in which case colours are produced from the set in col_set.
+#' @param col_outline (chr) colour of the border for the bars. Default is 'black'.
+#' @param xlog (chr) logarithmic transformation of the x-axis.
+#' Options include: (default) NULL (none), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'.
+#' @param ylog (chr) logarithmic transformation of the y-axis. Default is NULL (none).
+#' Options include: (default) NULL (none), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'.
+#'
+#' @returns the ggplot build of the final histogram.
+#'
+#' @examples
+#' \dontrun{
+#' control <- PlantGrowth$weight[PlantGrowth$group == "ctrl"]
+#' treat1 <- PlantGrowth$weight[PlantGrowth$group == "trt1"]
+#' generate_comparison_histogram(control, treat1,
+#'   xlab = "Weights",
+#'   title = "Control VS. Treatment: Plant Weights"
+#' )
+#' }
+generate_comparison_histogram <- function(
+    data1, data2, column1 = NULL, column2 = NULL,
+    xlab = NULL, ylab = "Count", title = "Comparison Histogram", bins = 30, alpha = 0.6,
+    legend_title = "Feature Selection", legend_position = "top", subplot_title1 = "Before",
+    subplot_title2 = "After", col_set = "Set1", cols = NULL, col_outline = "black", xlog = NULL,
+    ylog = NULL, ...)
{ + suppressPackageStartupMessages(require(ggplot2)) + + data1 <- convert_to_dataframe(data1) + + data2 <- convert_to_dataframe(data2) + + column1 <- get_default_columns(data1, column1) + if (is.null(column1)) { + data1 <- reshape2::melt(data1) + column1 <- "value" + } + data1 <- validate_data(data1, column1) + + column2 <- get_default_columns(data2, column2) + if (is.null(column2)) { + data2 <- reshape2::melt(data2) + column2 <- "value" + } + data2 <- validate_data(data2, column1) + + if (is.null(xlab)) { + if (is.null(column1)) { + xlab <- "Values" + } else { + xlab <- column1 + } + } + + if (!is.null(xlog)) { + xlog <- match.arg(xlog, c( + "log2", "log10", "log", "log2_1p", "log10_1p", + "log1p" + )) + xlab <- prepare_axis_label(xlab, xlog) + } + if (!is.null(ylog)) { + ylog <- match.arg(ylog, c( + "log2", "log10", "log", "log2_1p", "log10_1p", + "log1p" + )) + ylab <- prepare_axis_label(ylab, ylog) + } + + data_frame <- rbind(data.frame(value = data1, category = subplot_title1), data.frame( + value = data2, + category = subplot_title2 + )) + data_frame$category <- factor(data_frame$category, levels = c(subplot_title1, subplot_title2)) + + # Default Colours + if (is.null(cols)) { + cols <- color_select(2, col_set = col_set) + names(cols) <- c(subplot_title1, subplot_title2) + } + + ggplot(data_frame, aes(x = .data[[column1]], fill = category)) + + geom_histogram( + bins = bins, + alpha = alpha, position = "identity", color = col_outline + ) + + scale_fill_manual(values = c( + cols[1], + cols[2] + ), name = legend_title) + + labs(x = xlab, title = title) + + theme_bw() + + theme( + legend.position = legend_position, text = element_text(size = 8), + axis.title = element_text(size = 6), legend.text = element_text(size = 6), + legend.key.size = unit(1, "line"), legend.title = element_text(size = 7) + ) + + { + if (!is.null(xlog)) { + scale_x_continuous(trans = xlog) + } + } + + { + if (!is.null(ylog)) { + scale_y_continuous(trans = ylog) + } + } +} + +#' Generate a simple density plot for one variable. +#' +#' `generate_density()` creates a simple density plot for a single numerical variable. +#' +#' Note: requires packages ggplot2. +#' +#' @param data (data.frame) a data frame that the function is obtaining the data from. +#' @param column (chr) the column with the data that is to be used in the density plot. +#' @param xlab (chr) name of the x-axis. Default is 'X'. +#' @param ylab (chr) name of the y-axis. Default is 'Density'. +#' @param title (chr) title of the figure. Default is 'Density Plot'. +#' @param col_fill (chr) primary colour of the bars. Default is 'grey40'. +#' @param col_outline (chr) colour of the outline of the bars. Default is 'black'. +#' @param adjust (dbl) adjust bandwidth, which determines how precise (lower values) or smooth (higher values) the density plot is. Default is 1. +#' @param alpha (dbl) opacity of the density plot. Default is 0.8. +#' @param xlog (chr) logarithmic transformation of the x-axis. +#' Options include: (default) NULL (none), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p' +#' +#' @returns the ggplot build of the final density plot. 
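+#'
+#' @details If `column` is NULL and `data` has more than one column, the data frame is
+#' melted with `reshape2::melt()` and all values are pooled into a single distribution
+#' under the column name 'value'.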
+#' +#' @examples +#' \dontrun{ +#' generate_density(mtcars, wt) +#' +#' generate_density(mtcars, wt, +#' xlab = "Weights", ylab = "Density", +#' title = "Weights of Cars", +#' col_fill = "red", col_outline = "black", +#' adjust = 0.5, alpha = 0.8, xlog = "log2" +#' ) +#' } +generate_density <- function( + data, column = NULL, xlab = "X", ylab = "Density", + title = "Density Plot", col_fill = "grey40", col_outline = "black", adjust = 1, + alpha = 0.6, xlog = NULL) { + suppressPackageStartupMessages(require(ggplot2)) + + data <- convert_to_dataframe(data) + + column <- get_default_columns(data, column) + if (is.null(column)) { + data <- reshape2::melt(data) + column <- "value" + } + data <- validate_data(data, column) + + if (!is.null(xlog)) { + xlog <- match.arg(xlog, c( + "log2", "log10", "log", "log2_1p", "log10_1p", + "log1p" + )) + xlab <- prepare_axis_label(xlab, xlog) + } + + if (!is.data.frame(data)) { + data <- convert_to_dataframe(data) + colnames(data) <- column + } + + ggplot(data, aes(x = .data[[column]])) + + geom_density( + fill = col_fill, color = col_outline, + adjust = adjust, alpha = alpha + ) + + labs(x = xlab, y = ylab, title = title) + + theme_bw() + + theme(plot.title = element_text(hjust = 0.5)) + + { + if (!is.null(xlog)) { + scale_x_continuous(trans = xlog) + } + } +} + +#' Generate a density plot to compare two sets of values. +#' +#' `generate_comparison_density()` creates a density plot that compares values from two groups. Typically a before and after some transformation or filtering, but could also just be two categories. +#' +#' Note: requires packages ggplot2, RColorBrewer. +#' +#' @param data1 (vector) data of the 1st group. +#' @param data2 (vector) data of the 2nd group. +#' @param column1 (chr) name of column containing first set of data +#' @param column2 (chr) name of column containing second set of data +#' @param xlab (chr) name of the x-axis. +#' @param ylab (chr) name of the y-axis. Default is 'Count'. +#' @param title (chr) title of the figure. Default is 'Comparison Density Plot' +#' @param legend_title (chr) name of the legend. Default is 'Feature Selection'. +#' @param legend_position (chr) placement of the legend in the figure. +#' Options include: (default) 'top', 'bottom', 'left', 'right'. +#' @param subplot_title1 (chr) name of the values from the 1st dataset. Default is 'Before'. +#' @param subplot_title2 (chr) name of the values from the 2nd dataset. Default is 'After'. +#' @param col_set (chr) name of RColorBrewer set to be used for colouring. +#' Options include: (default) 'Set1', 'Set2', 'Set3', 'Pastel2', 'Pastel1', 'Paired', 'Dark2', 'Accent' +#' @param cols (chr (vec)) vector of colors that you wish to use. Default to NULL and other colours are produced using the set in col_set. +#' @param col_outline (chr) colour of the border of the density shape. Default is 'black'. +#' @param adjust (dbl) adjust bandwidth, which determines how precise (lower values) or smooth (higher values) the density plot is. Default is 1. +#' @param alpha (dbl) opacity of the histograms. Value between 0 and 1. Default is 0.6. +#' @param xlog (chr) logarithmic transformation of the x-axis. +#' Options include: (default) NULL (none), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'. +#' +#' @returns the ggplot build of the final density plot. 
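+#'
+#' @details The two inputs are converted to data frames and stacked into one long data
+#' frame with a 'category' column (levels subplot_title1 and subplot_title2), which
+#' drives the fill colours and the legend.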
+#' +#' @examples +#' \dontrun{ +#' control <- PlantGrowth$weight[PlantGrowth$group == "ctrl"] +#' treat1 <- PlantGrowth$weight[PlantGrowth$group == "trt1"] +#' generate_comparison_density(control, treat1, +#' xlab = "Weights", +#' title = "Control VS. Treatment: Plant Weights" +#' ) +#' } +generate_comparison_density <- function( + data1, data2, column1 = NULL, column2 = NULL, + xlab = NULL, ylab = "Count", title = "Comparison Density Plot", legend_title = "Feature Selection", + legend_position = "top", subplot_title1 = "Before", subplot_title2 = "After", col_set = "Set1", + cols = NULL, col_outline = "black", adjust = 1, alpha = 0.6, xlog = NULL) { + suppressPackageStartupMessages(require(ggplot2)) + data1 <- convert_to_dataframe(data1) + + data2 <- convert_to_dataframe(data2) + + column1 <- get_default_columns(data1, column1) + if (is.null(column1)) { + data1 <- reshape2::melt(data1) + column1 <- "value" + } + data1 <- validate_data(data1, column1) + + column2 <- get_default_columns(data2, column2) + if (is.null(column2)) { + data2 <- reshape2::melt(data2) + column2 <- "value" + } + data2 <- validate_data(data2, column1) + + if (is.null(xlab)) { + if (is.null(column1)) { + xlab <- "Values" + } else { + xlab <- column1 + } + } + + if (!is.null(xlog)) { + xlog <- match.arg(xlog, c( + "log2", "log10", "log", "log2_1p", "log10_1p", + "log1p" + )) + xlab <- prepare_axis_label(xlab, xlog) + } + + data_frame <- rbind(data.frame(value = data1, category = subplot_title1), data.frame( + value = data2, + category = subplot_title2 + )) + data_frame$category <- factor(data_frame$category, levels = c(subplot_title1, subplot_title2)) + + if (is.null(cols)) { + cols <- color_select(2, col_set = col_set) + names(cols) <- c(subplot_title1, subplot_title2) + } + + ggplot(data_frame, aes(x = value, fill = category)) + + geom_density( + adjust = adjust, + alpha = alpha, color = col_outline + ) + + scale_fill_manual(values = c( + cols[1], + cols[2] + ), name = legend_title) + + labs(x = xlab, title = title) + + theme_bw() + + theme( + legend.position = legend_position, text = element_text(size = 8), + axis.title = element_text(size = 6), legend.text = element_text(size = 6), + legend.key.size = unit(1, "line"), legend.title = element_text(size = 7) + ) + + { + if (!is.null(xlog)) { + scale_x_continuous(trans = xlog) + } + } +} + +#' +#' `generate_barplot()` creates a bar plot using ggplot2 +#' +#' Note: Requires ggplot2 package +#' +#' @details Settings available for both normal or stacked (& proportional stacked) bar plots. +#' adding a groupby variable will make a stacked bar plot and setting prop = T will make proportional bar plot. +#' +#' @param data (data.frame/vector) data frame contained the data of interest or vector with levels (categorical variable) +#' @param y (data.frame or vector) contains the identity values/heights of the bars (ex. counts of the bars pre-calculated) +#' @param group +#' @param label_name (chr) the name of the categorical variable to be used for the bars. Default is 'Labels'. +#' @param value_name (chr) the name of the column containing values for y (pre-calculated counts) +#' @param groupby (chr) the name of the secondary categorical variable to group the bars. Default is NULL. +#' @param xlab (chr) x-axis label. Default is 'X'. +#' @param ylab (chr) y-axis label. Default is 'Count' +#' @param title (chr) plot name/title. Default is 'Bar Plot' +#' @param col_set (chr) name of RColorBrewer set to be used for colouring. 
+#' Options include: (default) 'Set1', 'Set2', 'Set3', 'Pastel2', 'Pastel1', 'Paired', 'Dark2', 'Accent' +#' @param cols (chr (vec)) vector of colors that you wish to use. Default to NULL and other colours are produced using the set in col_set. +#' @param col_outline (chr) colour of the outline of the bars. Default is 'grey'. +#' @param col_labels (chr) colour of the labels on the bars. Default is 'black'. +#' @param alpha (dbl) Opacity. Default is 0.6. +#' @param prop (logical) make a stacked barplot a proportional barplot. Default is FALSE. +#' @param add_count_lab (logical) add count labels to bars. Default is TRUE. +#' @param vars_as_entered (logical) leave the variable order as entered, not flipping the values so the variable with more levels makes the bars. Default is FALSE. +#' @param legend_position (chr) position of the legend. +#' Options Include: (default) 'top', 'bottom', 'left', 'right' +#' @param font_size (dbl) size of the font. Default is 3.25. +#' +#' @returns The ggplot object of the completed plot +#' +#' @examples +#' \dontrun{ +#' generate_barplot(mtcars, "gear") +#' +#' generate_barplot(mtcars, "gear", +#' groupby = "cyl", xlab = "gear", ylab = "Count", +#' title = "Gear Vs. Carb", col_set = "Set2", +#' prop = T, add_count_lab = T, vars_as_entered = F, +#' legend_position = "top", font_size = 4 +#' ) +#' } +generate_barplot <- function( + data, y = NULL, group = NULL, label_name = "labels", + value_name = "values", groupby = "group", xlab = NULL, ylab = NULL, title = "Bar Plot", + col_set = "Set1", cols = NULL, col_outline = "grey30", col_labels = "black", alpha = 0.6, + prop = F, add_count_lab = T, vars_as_entered = F, legend_position = "top", font_size = 3.25) { + suppressPackageStartupMessages(require(ggplot2)) + + # Angled text threshold + VERT_LEVELS_MIN <- 9 # minimum number of levels + VERT_WORD_MIN <- 7 # minimum length of largest label name + + data <- convert_to_dataframe(data) + label_name <- get_default_columns(data, label_name) + data <- validate_data(data, label_name) + + if (!is.null(y)) { + y <- convert_to_dataframe(y) + value_name <- get_default_columns(y, value_name) + # check if value_name is in y or add it if possible, otherwise error + y <- validate_data(y, value_name) + data <- concatenate_datasets(data, y, how = "horizontal") + } + + if (!is.null(group)) { + group <- convert_to_dataframe(group) + groupby <- get_default_columns(group, groupby) + group <- validate_data(group, groupby) + data <- concatenate_datasets(data, group, how = "horizontal") + } + + if (is.null(xlab)) { + xlab <- label_name + } + + # Check that the variable is a character + if (!is.character(data[[label_name]])) { + data[[label_name]] <- as.character(data[[label_name]]) + } + cat_lvls <- length(unique(data[[label_name]])) + + if (!is.null(value_name) && value_name %in% colnames(data)) { + data <- data[, c(label_name, value_name)] + if (is.null(ylab)) { + ylab <- value_name + } + sorted_inds <- order(data[[value_name]], decreasing = TRUE) + data <- data[sorted_inds, ] + max_length <- max(nchar(unique(na.omit(data[[label_name]])))) + data[[label_name]] <- factor(data[[label_name]], levels = data[[label_name]]) + levels <- 1 + if (is.null(cols)) { + cols <- color_select(levels, col_set) + } + the_plot <- ggplot(data, aes(x = .data[[label_name]], y = .data[[value_name]])) + + geom_bar(stat = "identity", color = col_outline, fill = cols, alpha = alpha) + + theme_bw() + + labs(x = xlab, y = ylab, title = title) + + theme( + legend.position = legend_position, + plot.title = 
element_text(hjust = 0.5)
+      ) +
+      {
+        if (cat_lvls >= VERT_LEVELS_MIN || max_length >= VERT_WORD_MIN) {
+          theme(axis.text.x = element_text(angle = 60, hjust = 1))
+        }
+      } +
+      {
+        if (add_count_lab) {
+          geom_text(aes(label = sci_format(.data[[value_name]])),
+            stat = "identity",
+            position = position_stack(vjust = 0.5), color = col_labels, size = font_size
+          )
+        }
+      }
+  } else {
+    if (is.null(ylab)) {
+      ylab <- "Count"
+    }
+    # If a groupby value has been passed, the function sets up for a stacked
+    # barplot
+    if (!is.null(groupby) && groupby %in% colnames(data)) {
+      stacked <- TRUE
+
+      # Check that the variable is a character
+      if (!is.character(data[[groupby]])) {
+        data[[groupby]] <- as.character(data[[groupby]])
+      }
+
+      # settings for plot type and labels: proportional or stacked
+      if (prop) {
+        position_set <- "fill"
+        position_func <- position_fill(vjust = 0.5)
+
+        # Won't change the name if a custom label is given
+        if (ylab == "Count") {
+          ylab <- "Proportion"
+        }
+      } else {
+        position_set <- "stack"
+        position_func <- position_stack(vjust = 0.5)
+      }
+
+      # Calculate Levels
+      groupby_lvls <- length(unique(data[[groupby]]))
+      max_length <- max(nchar(unique(na.omit(data[[label_name]]))))
+
+      # Function will make the variable with more levels the bars of the plot
+      # Unless it is specified to leave it as is
+      if (!vars_as_entered) {
+        # if groupby is larger, swap the values
+        if (groupby_lvls > cat_lvls) {
+          temp <- label_name
+          label_name <- groupby
+          groupby <- temp
+          xlab <- label_name
+
+          levels <- cat_lvls
+          cat_lvls <- groupby_lvls
+          max_length <- max(nchar(unique(na.omit(data[[label_name]]))))
+        } else {
+          levels <- groupby_lvls
+        }
+      } else {
+        levels <- groupby_lvls
+      }
+    } else {
+      stacked <- FALSE
+      cat_lvls <- length(unique(data[[label_name]]))
+      max_length <- max(nchar(unique(na.omit(data[[label_name]]))))
+      levels <- 1
+    }
+
+    if (is.null(cols)) {
+      cols <- color_select(levels, col_set)
+    }
+
+    if (stacked) {
+      the_plot <- ggplot(data, aes(
+        x = forcats::fct_infreq(.data[[label_name]]),
+        fill = forcats::fct_rev(forcats::fct_infreq(.data[[groupby]]))
+      )) +
+        geom_bar(position = position_set, stat = "count", color = col_outline, alpha = alpha) +
+        theme_bw() +
+        labs(
+          x = xlab, y = ylab, color = groupby, fill = groupby,
+          title = title
+        ) +
+        theme(legend.position = legend_position, plot.title = element_text(hjust = 0.5)) +
+        scale_fill_manual(values = cols) +
+        {
+          if (cat_lvls >= VERT_LEVELS_MIN || max_length >= VERT_WORD_MIN) {
+            theme(axis.text.x = element_text(angle = 60, hjust = 1))
+          }
+        } +
+        {
+          if (add_count_lab) {
+            geom_text(aes(label = after_stat(count)),
+              stat = "count", position = position_func,
+              color = col_labels, size = font_size
+            )
+          }
+        }
+    } else {
+      the_plot <- ggplot(data, aes(x = forcats::fct_infreq(.data[[label_name]]))) +
+        geom_bar(stat = "count", color = col_outline, alpha = alpha) +
+        theme_bw() +
+        labs(x = xlab, y = ylab, title = title) +
+        theme(
+          legend.position = legend_position,
+          plot.title = element_text(hjust = 0.5)
+        ) +
+        {
+          if (cat_lvls >= VERT_LEVELS_MIN || max_length >= VERT_WORD_MIN) {
+            theme(axis.text.x = element_text(angle = 60, hjust = 1))
+          }
+        } +
+        {
+          if (add_count_lab) {
+            geom_text(aes(label = after_stat(count)),
+              stat = "count", position = position_stack(vjust = 0.5),
+              color = col_labels, size = font_size
+            )
+          }
+        }
+    }
+  }
+  return(the_plot)
+}
+
+#'
+#' `generate_boxplot()` creates a box plot using ggplot2
+#'
+#' Note: Requires ggplot2
+#'
+#' @param data (data.frame or vector) data frame containing the variable or numerical data
+#' @param labels (data.frame or vector) vector (or data frame) containing the categorical data
+#' @param column (chr) name of numerical variable column.
+#' @param label_name (chr) name of categorical variable.
+#' @param xlab (chr) label for categorical axis.
+#' @param ylab (chr) label for numerical axis.
+#' Note: You don't need to change the x and y labels around if horizontal_plot is TRUE, the function will change them automatically.
+#' @param title (chr) title of the plot. Default is 'Boxplot'.
+#' @param legend_position (chr) position of the legend.
+#' Options Include: (default) 'top', 'bottom', 'left', 'right'
+#' @param horizontal_plot (logical) if you want the boxplots to be running horizontally, set to TRUE. Default is FALSE.
+#' @param order (logical) if you want the box plots ordered by median. Default is FALSE.
+#' @param col_set (chr) name of RColorBrewer set to be used for colouring.
+#' Options Include: (default) 'Set1', 'Set2', 'Set3', 'Pastel2', 'Pastel1', 'Paired', 'Dark2', 'Accent'
+#' @param cols (chr (vec)) vector of colours that you wish to use. Defaults to NULL, in which case colours are produced from the set in col_set.
+#' @param alpha (dbl) Opacity. Default is 0.6.
+#' @param log_num (chr) logarithmic transformation of the numerical axis (y).
+#' Options include: (default) NULL (none), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'.
+#'
+#' @returns The ggplot object of the completed plot
+#'
+#' @examples
+#' \dontrun{
+#' generate_boxplot(iris, column = "Petal.Length", label_name = "Species")
+#'
+#' generate_boxplot(iris,
+#'   column = "Petal.Length", label_name = "Species",
+#'   ylab = "Petal Length", xlab = "Species", horizontal_plot = T
+#' )
+#' }
+generate_boxplot <- function(
+    data, labels = NULL, column = NULL, label_name = "labels",
+    xlab = NULL, ylab = NULL, title = "Boxplot", legend_position = "top", horizontal_plot = F,
+    order = F, col_set = "Set1", cols = NULL, alpha = 0.6, log_num = NULL) {
+  suppressPackageStartupMessages(require(ggplot2))
+
+  # Angled text threshold
+  VERT_LEVELS_MIN <- 9 # minimum number of levels
+  VERT_WORD_MIN <- 7 # minimum length of largest label name
+
+  data <- convert_to_dataframe(data)
+
+  if (!is.null(labels)) {
+    labels <- convert_to_dataframe(labels)
+    label_name <- get_default_columns(labels, label_name)
+    labels <- validate_data(labels, label_name)
+  }
+
+  if (is.null(column) && ncol(data) > 2 && !is.null(label_name)) {
+    suppressPackageStartupMessages(require(reshape2))
+
+    if (!is.null(labels)) {
+      data <- concatenate_datasets(data, labels, how = "horizontal")
+    }
+    data <- reshape2::melt(data, id.vars = label_name)
+    column <- names(data)[3]
+  } else {
+    column <- get_default_columns(data, column)
+    data <- validate_data(data, column)
+    if (!is.null(labels)) {
+      data <- concatenate_datasets(data, labels, how = "horizontal")
+    }
+  }
+
+  if (is.null(xlab)) {
+    xlab <- label_name
+  }
+
+  if (is.null(ylab)) {
+    ylab <- column
+  }
+
+  if (!is.null(log_num)) {
+    log_num <- match.arg(log_num, c(
+      "log2", "log10", "log", "log2_1p", "log10_1p",
+      "log1p"
+    ))
+    ylab <- prepare_axis_label(ylab, log_num)
+  }
+
+  if (!is.null(label_name) && label_name %in% colnames(data)) {
+    # Make sure label_name is categorical
+    if (!is.character(data[[label_name]])) {
+      data[[label_name]] <- as.character(data[[label_name]])
+    }
+
+    levels <- length(unique(data[[label_name]]))
+    max_length <-
max(nchar(unique(na.omit(data[[label_name]])))) + + if (is.null(cols)) { + cols <- color_select(levels, col_set) + } + + if (order) { + the_plot <- ggplot(data, aes(x = forcats::fct_reorder(.data[[label_name]], + .data[[column]], + .fun = median + ), y = .data[[column]], fill = .data[[label_name]])) + + geom_boxplot(alpha = alpha) + + theme_bw() + + scale_fill_manual(values = cols) + + labs(x = xlab, y = ylab, title = title) + + theme( + legend.position = legend_position, + plot.title = element_text(hjust = 0.5) + ) + + { + if (!is.null(log_num)) { + scale_y_continuous(trans = log_num) + } + } + + { + if (horizontal_plot) { + coord_flip() + } + } + + { + if ((levels >= VERT_LEVELS_MIN || max_length >= VERT_WORD_MIN)) { + theme(axis.text.x = element_text(angle = 60, hjust = 1)) + } + } + } else { + the_plot <- ggplot(data, aes( + x = .data[[label_name]], y = .data[[column]], + fill = .data[[label_name]] + )) + + geom_boxplot(alpha = alpha) + + theme_bw() + + scale_fill_manual(values = cols) + + labs(x = xlab, y = ylab, title = title) + + theme( + legend.position = legend_position, + plot.title = element_text(hjust = 0.5) + ) + + { + if (!is.null(log_num)) { + scale_y_continuous(trans = log_num) + } + } + + { + if (horizontal_plot) { + coord_flip() + } + } + + { + if ((levels >= VERT_LEVELS_MIN || max_length >= VERT_WORD_MIN)) { + theme(axis.text.x = element_text(angle = 60, hjust = 1)) + } + } + } + } else { + the_plot <- ggplot(data, aes(x = 1, y = .data[[column]])) + + geom_boxplot(alpha = alpha) + + theme_bw() + + labs(x = xlab, y = ylab, title = title) + + theme( + legend.position = legend_position, + plot.title = element_text(hjust = 0.5) + ) + + { + if (!is.null(log_num)) { + scale_y_continuous(trans = log_num) + } + } + + { + if (horizontal_plot) { + coord_flip() + } + } + } + return(the_plot) +} + +#' +#' `generate_violin()` creates a violin plot using ggplot2 +#' +#' Note: Requires ggplot2 +#' +#' @param data (data.frame or vector) data frame containing the variable or numerical data +#' @param labels (data.frame or vector) vector (or data frame) containing the categorical data +#' @param column (chr) name of numerical variable column. +#' @param label_name (chr) name of categorical variable. +#' @param xlab (chr) label for categorical axis. +#' @param ylab (chr) label for numerical axis. +#' Note: You don't need to change the x and y labels around if horizontal_plot is TRUE, the function will change them automatically. +#' @param title (chr) title of the plot. Default is 'Violin Plot' +#' @param legend_position (chr) position of the legend. +#' Options Include: (default) 'top', 'bottom', 'left', 'right' +#' @param add_box (logical) Add boxplots on top of violin plots. Default is TRUE +#' @param horizontal_plot (logical) if you want the boxplots to be running horizontally, set to TRUE. Default is FALSE. +#' @param order (logical) if you want the violin plots ordered by median. Default is FALSE. +#' @param col_set (chr) name of RColorBrewer set to be used for colouring. +#' Options Include: (default) 'Set1', 'Set2', 'Set3', 'Pastel2', 'Pastel1', 'Paired', 'Dark2', 'Accent' +#' @param cols (chr (vec)) vector of colors that you wish to use. Default to NULL and other colours are produced using the set in col_set. +#' @param alpha (dbl) Opacity. Default is 0.6. +#' @param log_num (chr) logarithmic transformation of the numerical axis (y). +#' Options include: (default) NULL (none), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'. 
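+#' @param ylim (vec) lower and upper limits for the numerical axis, e.g. c(0, 10). Default is NULL (no limits).
+#' @param show_outliers (logical) whether to draw outlier points on the overlaid boxplots. Default is TRUE.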
+#' +#' @returns The ggplot object of the completed plot +#' +#' @examples +#' \dontrun{ +#' generate_violin(iris, column = "Petal.Length", label_name = "Species") +#' +#' generate_violin(iris, +#' column = "Petal.Length", label_name = "Species", +#' ylab = "Petal Length", xlab = "Species", horizontal_plot = T +#' ) +#' } +generate_violin <- function( + data, labels = NULL, column = NULL, label_name = "labels", + ylab = NULL, xlab = NULL, ylim = NULL, title = "Violin Plot", + legend_position = "top", add_box = TRUE, horizontal_plot = FALSE, + order = FALSE, col_set = "Set1", cols = NULL, alpha = 0.6, + log_num = NULL, show_outliers = TRUE) { + suppressPackageStartupMessages(require(ggplot2)) + + # Angled text threshold + VERT_LEVELS_MIN <- 9 # minimum number of levels + VERT_WORD_MIN <- 7 # minimum length of largest label name + + data <- convert_to_dataframe(data) + + if (!is.null(labels)) { + labels <- convert_to_dataframe(labels) + label_name <- get_default_columns(labels, label_name) + labels <- validate_data(labels, label_name) + } + + if (is.null(column) && ncol(data) > 2 && !is.null(label_name)) { + suppressPackageStartupMessages(require(reshape2)) + # melt the data + if (!is.null(labels)) { + data <- concatenate_datasets(data, labels, how = "horizontal") + } + data <- reshape2::melt(data, id.vars = label_name) + column <- names(data)[3] + } else { + column <- get_default_columns(data, column) + data <- validate_data(data, column) + if (!is.null(labels)) { + data <- concatenate_datasets(data, labels, how = "horizontal") + } + } + + + if (is.null(xlab) && !is.null(label_name) && label_name %in% colnames(data)) { + xlab <- label_name + } + + if (is.null(ylab)) { + ylab <- column + } + + if (!is.null(log_num)) { + log_num <- match.arg(log_num, c( + "log2", "log10", "log", "log2_1p", "log10_1p", + "log1p" + )) + # ylab <- prepare_axis_label(ylab, log_num) + data[[column]] <- log_transformation(data[[column]], log_num) + } + + if (!is.null(label_name) && label_name %in% colnames(data)) { + # Make sure label_column is categorical + if (!is.character(data[[label_name]])) { + data[[label_name]] <- as.character(data[[label_name]]) + } + + levels <- length(unique(data[[label_name]])) + max_length <- max(nchar(unique(na.omit(data[[label_name]])))) + + if (is.null(cols)) { + cols <- color_select(levels, col_set) + } + + if (order) { + the_plot <- ggplot(data, aes(x = forcats::fct_reorder(.data[[label_name]], + .data[[column]], + .fun = median + ), y = .data[[column]], fill = .data[[label_name]])) + + geom_violin(alpha = alpha) + + { + if (add_box) { + geom_boxplot( + width = 0.2, color = "black", outlier.shape = 1, + fill = NA, outlier.fill = "black", alpha = alpha, + outliers = show_outliers + ) + } + } + + theme_bw() + + scale_fill_manual(values = cols) + + labs( + x = xlab, y = ylab, + title = title + ) + + theme(legend.position = legend_position, plot.title = element_text(hjust = 0.5)) + + # { + # if (!is.null(log_num)) { + # scale_y_continuous(trans = log_num) + # } + # } + + { + if (!is.null(ylim)) { + ylim(ylim[[1]], ylim[[2]]) + } + } + + { + if (horizontal_plot) { + coord_flip() + } + } + + { + if ((levels >= VERT_LEVELS_MIN || max_length >= VERT_WORD_MIN)) { + theme(axis.text.x = element_text(angle = 60, hjust = 1)) + } + } + } else { + the_plot <- ggplot(data, aes( + x = .data[[label_name]], y = .data[[column]], + fill = .data[[label_name]] + )) + + geom_violin(alpha = alpha) + + { + if (add_box) { + geom_boxplot( + width = 0.7, color = "black", outlier.shape = 1, outlier.size = 0.5, 
+ fill = NA, outlier.fill = "black", alpha = alpha, outliers = show_outliers + ) + } + } + + theme_bw() + + scale_fill_manual(values = cols) + + labs( + x = xlab, y = ylab, + title = title + ) + + theme(legend.position = legend_position, plot.title = element_text(hjust = 0.5)) + + # { + # if (!is.null(log_num)) { + # scale_y_continuous(trans = log_num) + # } + # } + + { + if (!is.null(ylim)) { + ylim(ylim[[1]], ylim[[2]]) + } + } + + { + if (horizontal_plot) { + coord_flip() + } + } + + { + if ((levels >= VERT_LEVELS_MIN || max_length >= VERT_WORD_MIN)) { + theme(axis.text.x = element_text(angle = 60, hjust = 1)) + } + } + } + } else { + # generate a single violin plot + the_plot <- ggplot(data, aes(x = 1, y = .data[[column]])) + + geom_violin(fill = "grey40", alpha = alpha) + + { + if (add_box) { + geom_boxplot( + width = 0.2, color = "black", outlier.shape = 1, + fill = NA, outlier.fill = "black", alpha = alpha + ) + } + } + + theme_bw() + + labs(x = xlab, y = ylab, title = title) + + theme( + legend.position = legend_position, + plot.title = element_text(hjust = 0.5) + ) + + { + if (!is.null(log_num)) { + scale_y_continuous(trans = log_num) + } + } + + { + if (horizontal_plot) { + coord_flip() + } + } + } + return(the_plot) +} + +#' +#' `generate_scatterplot()` creates a scatterplot using ggplot2 +#' +#' Note: Requires ggplot2 package +#' +#' @param data (data.frame or vector) data frame containing the variables or data for the x axis. +#' @param y (data.frame or vector) data for the y axis. +#' @param group (data.frame or vector) label data for grouping the x and y by colour. +#' @param xdata (chr) name of column with data for the x-axis. +#' @param ydata (chr) name of column with data for the y-axis. +#' @param groupby (chr) name of categorical variable to group the points. Default is NULL. +#' @param xlab (chr) x-axis label. Default is 'Var 1'. +#' @param ylab (chr) y-axis label. Default is 'Var 2'. +#' @param title (chr) title for the plot. Default is 'Scatterplot'. +#' @param alpha (dbl) opacity of points. Range: 0 to 1. Default is 1. +#' @param col_set (chr) name of RColorBrewer set to be used for colouring. +#' Options Include: (default) 'Set1', 'Set2', 'Set3', 'Pastel2', 'Pastel1', 'Paired', 'Dark2', 'Accent' +#' @param xlog (chr) log transformation of the x-axis. +#' Options Include: (default) NULL (None), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'. +#' @param ylog (chr) log transformation of the y-axis. +#' Options Include: (default) NULL (None), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'. +#' +#' @returns The ggplot object of the completed plot +#' +#' @examples +#' \dontrun{ +#' generate_scatterplot(mtcars, xdata = "wt", ydata = "qsec") +#' +#' generate_scatterplot(iris, +#' xdata = "Petal.Length", ydata = "Petal.Width", +#' groupby = "Species", xlab = "Petal Length", +#' ylab = "Petal Width", title = "Petals (Length VS. 
Width)", +#' alpha = 0.75, col_set = "Set2" +#' ) +#' } +generate_scatterplot <- function( + data, y = NULL, group = NULL, xdata = "x", ydata = "y", + groupby = "group", xlab = NULL, ylab = NULL, title = "Scatterplot", alpha = 1, + col_set = "Set1", cols = NULL, xlog = NULL, ylog = NULL) { + suppressPackageStartupMessages(require(ggplot2)) + data <- convert_to_dataframe(data) + xdata <- get_default_columns(data, xdata) + data <- validate_data(data, xdata) + + if (!is.null(y)) { + y <- convert_to_dataframe(y) + value_name <- get_default_columns(y, ydata) + y <- validate_data(y, ydata) + data <- concatenate_datasets(data, y, how = "horizontal") + } else { + data <- validate_data(data, ydata) + } + + if (!is.null(group)) { + group <- convert_to_dataframe(group) + groupby <- get_default_columns(group, groupby) + group <- validate_data(group, groupby) + data <- concatenate_datasets(data, group, how = "horizontal") + } + + if (is.null(xlab)) { + xlab <- xdata + } + + if (is.null(ylab)) { + ylab <- ydata + } + + if (!is.null(xlog)) { + xlog <- match.arg(xlog, c( + "log2", "log10", "log", "log2_1p", "log10_1p", + "log1p" + )) + xlab <- prepare_axis_label(xlab, xlog) + } + if (!is.null(ylog)) { + ylog <- match.arg(ylog, c( + "log2", "log10", "log", "log2_1p", "log10_1p", + "log1p" + )) + ylab <- prepare_axis_label(ylab, ylog) + } + + if (!is.null(groupby) && groupby %in% colnames(data)) { + grouped <- TRUE + + if (!is.character(data[[groupby]])) { + data[[groupby]] <- as.character(data[[groupby]]) + } + + if (is.null(cols)) { + levels <- length(unique(data[[groupby]])) + cols <- color_select(levels, col_set) + } + } else { + grouped <- FALSE + } + + if (grouped) { + ggplot(data, aes(x = .data[[xdata]], y = .data[[ydata]], col = .data[[groupby]])) + + geom_point(alpha = alpha) + + theme_bw() + + theme(plot.title = element_text(hjust = 0.5)) + + scale_color_manual(values = cols) + + labs(x = xlab, y = ylab, title = title) + + { + if (!is.null(xlog)) { + scale_x_continuous(trans = xlog) + } + } + + { + if (!is.null(ylog)) { + scale_y_continuous(trans = ylog) + } + } + } else { + ggplot(data, aes(x = .data[[xdata]], y = .data[[ydata]])) + + geom_point(alpha = alpha) + + theme_bw() + + theme(plot.title = element_text(hjust = 0.5)) + + labs( + x = xlab, + y = ylab, title = title + ) + + { + if (!is.null(xlog)) { + scale_x_continuous(trans = xlog) + } + } + + { + if (!is.null(ylog)) { + scale_y_continuous(trans = ylog) + } + } + } +} + +#' calls all the plots for a single categorical variable +#' @param data (data.frame) data frame containing the data. +#' @param label_name (chr) name of categorical variable. +plot_cat_metadata <- function(data, label_name, title = NULL, ...) { + if (is.null(title)) { + title <- paste0("Distribution of ", label_name) + } + p <- generate_barplot(data, + label_name = label_name, xlab = label_name, title = title, + ... + ) + return(p) +} + +#' calls all the plots for a single numerical variable +#' @param data (data.frame) data frame containing the data. +#' @param num_var (chr) name of numerical variable. +plot_num_metadata <- function(data, num_var, title = NULL, ...) { + if (is.null(title)) { + title <- paste0("Distribution of ", num_var) + } + p <- generate_histogram(data, num_var, xlab = num_var, title = title, ...) + return(p) +} + +#' Categorical Variable vs Categorical Variable +#' @param data (data.frame) data frame containing the data. +#' @param var1 (chr) name of 1st categorical variable. +#' @param var2 (chr) name of 2nd categorical variable. 
+plot_cat_vs_cat_metadata <- function(data, var1, var2, title = NULL, ...) { + suppressPackageStartupMessages(require(patchwork)) + # Categorical vs Categorical + if (is.null(title)) { + title <- paste0(var1, " VS. ", var2) + } + # Stacked Barplot + p1 <- generate_barplot(data, + label_name = var1, groupby = var2, xlab = var1, + title = NULL, add_count_lab = F, ... + ) + # Proportional Barplot + p2 <- generate_barplot(data, + label_name = var1, xlab = var1, groupby = var2, + prop = T, title = NULL, ... + ) + # Grid + p <- (p1 + p2 + plot_layout(guides = "collect") + plot_annotation( + title = title, + theme = theme(plot.title = element_text(hjust = 0.5)) + ) & theme(legend.position = "top")) + return(p) +} + +#' Categorical Variable vs Numerical Variable +#' @param data (data.frame) data frame containing the data. +#' @param cat (chr) name of categorical variable. +#' @param num (chr) name of numerical variable. +plot_cat_vs_num_metadata <- function(data, cat, num, title = NULL, ...) { + suppressPackageStartupMessages(require(patchwork)) + # Categorical vs Numeric + if (is.null(title)) { + title <- paste0(cat, " VS. ", num) + } + # Violin Plot + p <- generate_violin(data, column = num, label_name = cat, title = title, ...) + return(p) +} + +#' Numerical Variable vs Numerical Variable +#' @param data (data.frame) data frame containing the data. +#' @param var1 (chr) name of 1st numerical variable. +#' @param var2 (chr) name of 2nd numerical variable. +plot_num_vs_num_metadata <- function(data, var1, var2, title = NULL, ...) { + # Numeric vs Numeric + if (is.null(title)) { + title <- paste0(var1, " VS. ", var2) + } + # Scatterplot + p <- generate_scatterplot(data, + xdata = var1, ydata = var2, xlab = var1, ylab = var2, + title = title, ... + ) + return(p) +} diff --git a/src/biofit/integration/R/scripts/utils.R b/src/biofit/integration/R/scripts/utils.R new file mode 100644 index 0000000..98225e1 --- /dev/null +++ b/src/biofit/integration/R/scripts/utils.R @@ -0,0 +1,246 @@ +BIOCONDUCTOR <- c( + "limma", "metagenomeSeq", "phyloseq", + "DESeq2", "biomaRt", "GenomicRanges", + "Biostrings", "BSgenome", "GenomicFeatures", + "GenomicAlignments", "VariantAnnotation", + "S4Vectors", "IRanges", "AnnotationDbi", + "ComplexHeatmap" +) + +save_plots <- function( + filename, plot = last_plot(), device = NULL, path = NULL, + scale = 1, width = NA, height = NA, units = c( + "in", "cm", + "mm", "px" + ), dpi = 300, limitsize = TRUE, bg = NULL, + create.dir = TRUE, ...) { + suppressPackageStartupMessages(c( + require(ggplot2) + )) + + + ext <- NULL + if (!is.null(filename)) { + ext <- tolower(tools::file_ext(filename)) + } else if (!is.null(path)) { + ext <- tolower(tools::file_ext(path)) + } + + # always save a png if ext is not png + if (ext != "png") { + fn <- paste0(tools::file_path_sans_ext(filename), ".png") + + ggplot2::ggsave( + fn, + plot = plot, device = "png", + path = path, scale = scale, width = width, + height = height, units = units, dpi = dpi, + limitsize = limitsize, bg = bg, create.dir = create.dir, ... + ) + } + ggplot2::ggsave( + filename, + plot = plot, device = device, + path = path, scale = scale, width = width, + height = height, units = units, dpi = dpi, + limitsize = limitsize, bg = bg, create.dir = create.dir, ... 
+ ) +} + +get_info_from_arrow_table <- function(arrow_table) { + # Get the schema + info <- arrow_table$schema$metadata$huggingface + if (is.null(info)) { + return(NULL) + } + jsonlite::fromJSON(info)$info$features +} + +#' @title arrow_to_factor +#' @description Convert label encodings to factors. +#' Uses the arrow metadata to get the label information. +#' @param y The arrow table +#' @return if y is not an arrow table or y is not a class label, +#' it returns y as is. +encodings_to_factor <- function(y) { + # verify that y is an arrow table + if (!inherits(y, "Table")) { + return(y) + } + label_info <- get_info_from_arrow_table(y)$features$labels + if (is_class(label_info)) { + return(y) + } + # Add a label for missing values + label_names <- c("", label_info$names) + y <- as.data.frame(y)[, 1] + # labels ranges from -1 to n-1, + # where n is the number of classes and a -1 is a missing value + factor(y + 2, + levels = seq_along(label_names), + labels = label_names + ) +} + +is_class <- function(label_info) { + if (label_info$`_type` != "ClassLabel") { + return(TRUE) + } + FALSE +} + +#' detect using class what each variable type is +#' @param data data frame containing the data +#' @param col name of the column that is being checked +detect_var_type <- function(data, col) { + cat_types <- c("character", "logical") + num_types <- c("numeric", "integer", "double") + + if (class(data[[col]]) %in% cat_types) { + col_type <- "categorical" + + if (is_other_var(data, col)) { + col_type <- "other" + } + } else if (class(data[[col]]) %in% num_types) { + col_type <- "numerical" + + if (!is_other_var(data, col)) { + col_type <- "categorical" + } + if (length(unique(data[[col]])) <= 1) { + col_type <- "other" + } + } + return(col_type) +} + +#' For detecting our classification of other variable +#' @param data data frame containing the data +#' @param col name of variable being checked +#' @param threshold max number of levels +is_other_var <- function(data, col, threshold = 30) { + other_var <- FALSE + + num_levels <- length(unique(data[[col]])) + # if there are more than 30 levels we consider this an other variable + if (num_levels > threshold || num_levels <= 1) { + other_var <- TRUE + } + + return(other_var) +} + +get_feature_metadata <- function(data) { + # Create a metadata table for the features + feature_metadata <- get_info_from_arrow_table(data) + + extract_metadata <- function(item) { + return(item$metadata) + } + # Extract and combine the metadata from each item into a data frame + metadata_list <- lapply(feature_metadata, extract_metadata) + metadata_df <- dplyr::bind_rows(metadata_list) + # add the feature names to the metadata + metadata_df$feature <- colnames(data) + return(metadata_df) +} + +start_device <- function(path, ...) { + # Start the device + if (is.null(path)) { + path <- tempfile() + } + ext <- tolower(tools::file_ext(path)) + if (ext == "pdf") { + args <- list(...) + pdf(path, width = args$width, height = args$height) + } else if (ext == "png") { + png(path, ...) + } else if (ext == "jpeg" || ext == "jpg") { + jpeg(path, ...) + } else if (ext == "bmp") { + bmp(path, ...) + } else if (ext == "tiff") { + tiff(path, ...) 
+ } else { + stop("Unsupported file format") + } +} + + +is_dataframe <- function(data) { + is(data, "Table") || + is(data, "RecordBatch") || + is.data.frame(data) +} + + +is_vector <- function(data) { + is(data, "Array") || + is(data, "ChunkedArray") || + is.vector(data) +} + + +convert_to_dataframe <- function(value) { + if (is_dataframe(value)) { + as.data.frame(value) + } else if (is_vector(value)) { + `colnames<-`(as.data.frame(value), NULL) + } else { + stop(paste0("Unsupported object type: ", class(data))) + } +} + +get_default_columns <- function(data, default = NULL, max_cols = 1) { + if ( + is_dataframe(data) && + ( + is.null(max_cols) || + ncol(data) == max_cols + ) && + !is.null(colnames(data)) + ) { + colnames(data) + } else { + default + } +} + +validate_data <- function(data, names, required = FALSE, replace = TRUE) { + if (is_dataframe(data)) { + if (!is.null(names)) { + if (!all(names %in% colnames(data)) && !replace) { + stop(paste0("Column name ", names, " is not in the data")) + } + } else if (required) { + stop("A column name was not provided") + } + if (ncol(data) == length(names) && replace) { + colnames(data) <- names + } + data + } else { + stop("Data must be a data frame") + } +} + + +concatenate_datasets <- function(data1, data2, how = "horizontal") { + if (is_dataframe(data1) || is_dataframe(data2)) { + if (how == "vertical") { + rbind(data1, data2) + } else { + cbind(data1, data2) + } + } else if (is_vector(data1) && is_vector(data2)) { + if (how == "vertical") { + c(data1, data2) + } else { + cbind(data1, data2) + } + } else { + stop("Data types must match") + } +} diff --git a/src/biofit/integration/__init__.py b/src/biofit/integration/__init__.py new file mode 100644 index 0000000..2737d0b --- /dev/null +++ b/src/biofit/integration/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa +from .R import RCaller diff --git a/src/biofit/integration/biosets.py b/src/biofit/integration/biosets.py new file mode 100644 index 0000000..9c56126 --- /dev/null +++ b/src/biofit/integration/biosets.py @@ -0,0 +1,14 @@ +import importlib + +from biocore.utils.import_util import is_biosets_available + + +class NotAvailable: + def __init__(self, *args, **kwargs): + pass + + +def get_feature(val): + if is_biosets_available(): + return getattr(importlib.import_module("biosets.features"), val) + return None diff --git a/src/biofit/integration/patcher.py b/src/biofit/integration/patcher.py new file mode 100644 index 0000000..151f844 --- /dev/null +++ b/src/biofit/integration/patcher.py @@ -0,0 +1,880 @@ +import ast +import importlib +import importlib.util +import inspect +import json +import os +import posixpath +import shutil +import threading +from collections import defaultdict +from contextlib import contextmanager +from pathlib import Path +from types import ModuleType +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeAlias, Union + +import filelock + +import biofit.config + +from ..utils import gorilla, logging +from ..utils.file_utils import is_remote_url +from ..utils.fingerprint import Hasher, is_caching_enabled +from ..utils.gorilla import ( + SameSourceAndDestinationError, + _get_members, + _module_iterator, +) + +logger = logging.get_logger(__name__) + +_EXCLUDED_MEMBERS = [ + "__builtins__", + "__cached__", + "__doc__", + "__file__", + "__loader__", + "__name__", + "__package__", + "__spec__", + "__annotations__", +] + + +def dynamic_import(attr_path: str): + """ + Dynamically imports an attribute (class, function, variable) from a given path. 
+ + :param attr_path: The full path to the attribute + :return: The attribute object. + """ + module_path, attr_name = attr_path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, attr_name) + + +def parse_module_for_patching( + module: Union[ModuleType, str], + hash_module: bool = True, + extra=[], + filter: Optional[Callable] = None, +) -> Dict[str, Tuple[Any, str]]: + if isinstance(module, str): + module = importlib.import_module(module) + + members = _get_members( + module, filter=create_patch_filter(module) if filter is None else filter + ) + _patches = {} + for member_name, member in members: + if hash_module: + _patches[member_name] = tuple( + [member, module.__name__, Hasher.hash(member)] + extra + ) + else: + _patches[member_name] = tuple([member, module.__name__] + extra) + return _patches + + +def create_patch_filter(module: ModuleType): + excluded_imports = set() + for node in ast.parse(inspect.getsource(module)).body: + if isinstance(node, ast.ImportFrom) or isinstance(node, ast.Import): + excluded_imports.update( + [ + name.name if name.asname is None else name.asname + for name in node.names + ] + ) + + def filter(name: str, obj: Any): + if ( + name in _EXCLUDED_MEMBERS + or name in excluded_imports + or inspect.ismodule(obj) + ): + return False + return True + + return filter + + +def get_hashed_patches(entity_paths=[], module_paths=[]): + """ + Defines patches with their module paths and dynamically loads and hashes their values. + + :return: A list of patches with their values hashed. + """ + + patches = [] + for path in entity_paths: + value = dynamic_import(path) + module_name = path.rsplit(".", 1)[0] + hashed_value = Hasher.hash(value) + patches.append( + ( + path.rsplit(".", maxsplit=1)[-1], + (value, module_name, hashed_value), + ) + ) + + for path in module_paths: + patches.extend(parse_module_for_patching(path).items()) + + return patches + + +def create_lock_path(root, rel_path): + lock_path = posixpath.join( + root, Path(rel_path).as_posix().replace("/", "_") + ".lock" + ) + return lock_path + + +class PatcherConfig: + """Base class for the wrapper config. Handles patching and caching. + + Args: + patches (Dict[str, Any]): Dictionary of patches to apply. The keys are the names of the members to patch and the values are the patches. + package (ModuleType): The target package that will be within the `patch_targets`. All members within patches will + patch_targets (Union[List[ModuleType], ModuleType]): The packages to patch. + cache_dir (Optional[Union[Path, str]], optional): The cache directory. Defaults to None. 
+ """ + + patches: Dict[str, Any] = None + root: ModuleType = None + patch_targets: Union[List[ModuleType], ModuleType] = None + settings = gorilla.Settings(allow_hit=True, store_hit=True) + cache_dir: Optional[Union[Path, str]] = None + + def __init__( + self, + patches: Dict[str, Any] = None, + root: ModuleType = None, + patch_targets: Union[List[ModuleType], ModuleType] = None, + cache_dir: Optional[Union[Path, str]] = None, + **kwargs, + ): + self._cache_enabled = is_caching_enabled() + + self.patches = self.patches if self.patches is not None else patches + self.root = self.root if self.root is not None else root + self.patch_targets = ( + self.patch_targets if self.patch_targets is not None else patch_targets + ) + self.cache_dir = self.cache_dir if self.cache_dir is not None else cache_dir + + if isinstance(self.patches, dict): + self.patches = list(self.patches.items()) + + if self.patch_targets is None: + self.patch_targets = [self.root] + elif not isinstance(self.patch_targets, list): + self.patch_targets = [self.patch_targets] + + self.cache_files = None + + if self.cache_dir is None: + self._cache_dir_root = biofit.config.BIOFIT_PATCHES_CACHE.as_posix() + else: + if isinstance(self.cache_dir, str) or isinstance(self.cache_dir, Path): + self._cache_dir_root = self.cache_dir + elif isinstance(self.cache_dir, dict): + self._cache_dir_root = self.cache_dir.get( + "root", biofit.config.BIOFIT_PATCHES_CACHE + ) + # specifies the cache files for each patch target. + for patch_target in self.patch_targets: + if patch_target.__name__ in self.cache_dir: + for name, path in self.cache_dir[patch_target.__name__].items(): + if isinstance(path, Path): + path = path.as_posix() + if os.path.exists(path): + if self.cache_files is None: + self.cache_files = defaultdict(dict) + self.cache_files[patch_target.__name__][name] = path + + self._cache_dir_root = ( + self._cache_dir_root + if is_remote_url(self._cache_dir_root) + else os.path.expanduser(self._cache_dir_root) + ) + + self._relative_ref_pkg_cache_dir = self._build_relative_cache_dir(self.root) + self._relative_patch_targets_cache_dirs = { + p.__name__: self._build_relative_cache_dir(p) + for p in self.patch_targets + if hasattr(p, "__name__") + } + + self.patches = self._sort_patches(self._prepare_patches(self.patches)) + self._output_dir = self._build_output_dir() + # NOTE: this is causing issues with the cache, so we will disable it for now + # self.cache_files = self._get_cache_files() + self.cache_files = None + + self._attemps = kwargs.get("attemps", 1) + + def _generate_modules(self): + return { + package.__name__: ( + module for module in _module_iterator(package, recursive=True) + ) + for package in self.patch_targets + if hasattr(package, "__name__") + } + + def _build_output_dir(self): + _output_dir = Hasher().hash(self.patches) + os.makedirs(os.path.join(self._cache_dir_root, _output_dir), exist_ok=True) + return _output_dir + + def _build_relative_cache_dir(self, package): + """Return the data directory for the current version.""" + version = package.__version__ if hasattr(package, "__version__") else "0.0.0" + package_name = package.__name__ if hasattr(package, "__name__") else "unknown" + relative_patches_dir = posixpath.join(package_name, version) + return relative_patches_dir + + def _prepare_patches(self, patches: Dict[str, Any]): + if patches is None: + raise ValueError("patches cannot be None.") + _patches = [] + if not is_remote_url(self._cache_dir_root): + source_cache_dir = posixpath.join( + self._cache_dir_root, 
self._relative_ref_pkg_cache_dir + ) + os.makedirs(source_cache_dir, exist_ok=True) + for name, patch in patches: + if isinstance(patch, Tuple): + if len(patch) == 2: + patch_, source_ = patch + hash_ = Hasher.hash(patch_) + if isinstance(source_, ModuleType): + source_ = source_.__name__ + # check if source exists + if not isinstance(source_, str) or not importlib.util.find_spec( + source_ + ): + raise ValueError( + f"Invalid source: {source_}, source must be a module." + ) + _patches.append((name, (patch_, source_, hash_, None))) + elif len(patch) == 3: + patch_, source_, hash_ = patch + if isinstance(source_, ModuleType): + source_ = source_.__name__ + if not isinstance(source_, str) or not importlib.util.find_spec( + source_ + ): + raise ValueError( + f"Invalid source: {source_}, source must be a module." + ) + if hash_ is None: + hash_ = Hasher.hash(patch) + _patches.append((name, (patch_, source_, hash_, None))) + elif len(patch) == 4: + patch_, source_, hash_, destination_ = patch + if isinstance(source_, ModuleType): + source_ = source_.__name__ + if not isinstance(source_, str) or not importlib.util.find_spec( + source_ + ): + raise ValueError( + f"Invalid source: {source_}, source must be a module." + ) + if hash_ is None: + hash_ = Hasher.hash(patch) + if isinstance(destination_, str): + destination_ = importlib.import_module(destination_) + if not isinstance(destination_, ModuleType): + raise ValueError( + f"Invalid destination: {destination_}, destination must be a module." + ) + _patches.append((name, (patch_, source_, hash_, destination_))) + else: + hash_ = Hasher.hash(patch) + origin = inspect.getmodule(patch) + source_ = origin.__name__ if hasattr(origin, "__name__") else None + _patches.append((name, (patch_, source_, hash_, None))) + return _patches + + def _get_cache_files(self): + cache_files = None + if self._cache_enabled and not os.path.exists( + posixpath.join( + self._cache_dir_root, self._output_dir, biofit.config.PATCHES_FILENAME + ) + ): + for name, (_, _, hash, _) in self.patches: + for ( + patch_target_name, + relative_target_cache_dir, + ) in self._relative_patch_targets_cache_dirs.items(): + cache_file = posixpath.join( + self._cache_dir_root, + self._relative_ref_pkg_cache_dir, + relative_target_cache_dir, + hash, + ) + os.makedirs(cache_file, exist_ok=True) + patches_file = f"{cache_file}/{biofit.config.PATCHES_FILENAME}" + if os.path.exists(patches_file): + if cache_files is None: + cache_files = defaultdict(dict) + cache_files[patch_target_name][name] = patches_file + no_patches_file = ( + f"{cache_file}/{biofit.config.NO_PATCHES_FILENAME}" + ) + if os.path.exists(no_patches_file): + if cache_files is None: + cache_files = defaultdict(dict) + cache_files[patch_target_name][name] = no_patches_file + return cache_files + + @classmethod + def clear_cache(cls, cache_dir: Optional[Union[Path, str]] = None): + """Clear a directory in the root cache directory for patches. + + Args: + cache_dir (Optional[Union[Path, str]], optional): The cache directory to delete. + If None, the entire cache directory will be deleted. Defaults to None. 
+ """ + if cache_dir is None: + cache_dir = ( + cls._cache_dir_root + if hasattr(cls, "_cache_dir_root") + else biofit.config.BIOFIT_PATCHES_CACHE.as_posix() + ) + if hasattr(cls, "_output_dir"): + cache_dir = posixpath.join(cache_dir, cls._output_dir) + + if os.path.exists(cache_dir): + shutil.rmtree(cache_dir) + + def _clear_cache(self): + self.cache_files = None + for _, (_, _, hash, _) in self.patches: + for ( + relative_target_cache_dir + ) in self._relative_patch_targets_cache_dirs.values(): + rel_member_dir = posixpath.join( + self._relative_ref_pkg_cache_dir, relative_target_cache_dir, hash + ) + member_dir = posixpath.join(self._cache_dir_root, rel_member_dir) + # lock the directory + if os.path.exists(member_dir): + for file in os.listdir(member_dir): + rel_fp = posixpath.join(rel_member_dir, file) + fp = os.path.join(self._cache_dir_root, rel_fp) + lock_path = create_lock_path(self._cache_dir_root, rel_fp) + with filelock.FileLock(lock_path): + os.remove(fp) + with filelock.FileLock( + create_lock_path(self._cache_dir_root, rel_member_dir) + ): + os.rmdir(member_dir) + output_dir = posixpath.join(self._cache_dir_root, self._output_dir) + fp = posixpath.join(output_dir, biofit.config.PATCHES_FILENAME) + if os.path.exists(fp): + with filelock.FileLock( + create_lock_path( + self._cache_dir_root, + posixpath.join( + self._output_dir, biofit.config.PATCHES_FILENAME + ), + ) + ): + os.remove(fp) + if os.path.exists(output_dir): + with filelock.FileLock( + create_lock_path(self._cache_dir_root, self._output_dir) + ): + shutil.rmtree(output_dir) + + self._cleanup_empty_dir() + + def _cleanup_empty_dir(self): + # use os.walk to find empty directories in self._cache_dir_root + for root, dirs, files in os.walk(self._cache_dir_root, topdown=False): + if len(dirs) == 0 and len(files) == 0: + os.rmdir(root) + + def _sort_patches(self, patches: list): + """Sorts the patches for consistent hashing""" + unique_keys = [str(p) for p in patches] + inds = sorted(range(len(unique_keys)), key=lambda index: unique_keys[index]) + return [patches[i] for i in inds] + + +class Patcher: + """ + Patcher class for applying patches to target packages/modules. + """ + + config: PatcherConfig = None + _lock: Path = None + + def __init__(self, config: PatcherConfig = None, **kwargs): + """ + Initialize the Patcher object. + + Args: + configs (Union[PatcherConfig, List[PatcherConfig]]): Configuration object or list of configuration objects. + filter (Optional[Union[str, List[str]]], optional): + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. 
+ """ + self.config = self.config if self.config is not None else config + self._cache_enabled = is_caching_enabled() + self.patches = self._sort_patches(self.load_and_prepare_patches()) + self._fingerprint = Hasher.hash(str(self.patches)) + self._validate_cache() + + def _sort_patches(self, patches: list): + """Sorts the patches for consistent hashing""" + return self.config._sort_patches(patches) + + @contextmanager + def patch(self, patches: List[gorilla.Patch]): + """Context manager for applying additional patches to the target packages/modules.""" + for patch in patches: + gorilla.apply(patch) + + self._apply_patches() + + yield + + self._revert_patches() + + for patch in patches: + gorilla.revert(patch) + + def _apply_patches(self): + if self._lock: + # means that it is trying to run in nested context of the same instance + # specify this by setting _cleanup to None + self._cleanup = None + return + + thread_id = threading.get_ident() + file_path = os.path.normpath( + posixpath.join(self.config._output_dir, f"{self._fingerprint}_{thread_id}") + ) + self._lock = Path(create_lock_path(self.config._cache_dir_root, file_path)) + self._cleanup = False + + if not self._lock.exists(): + self._cleanup = True + for patch in self.patches: + gorilla.apply(patch) + self._lock.parent.mkdir(parents=True, exist_ok=True) + self._lock.touch() + else: + self._lock = None + + def _revert_patches(self): + if self._cleanup is None: + # same instance is already running, we put it back to true + self._cleanup = True + return + if self._cleanup: + for patch in self.patches: + gorilla.revert(patch) + # remove the lock file + if self._lock and self._lock.exists(): + self._lock.unlink() + self._lock = None + + def __enter__(self): + self._apply_patches() + + def __exit__(self, exc_type, exc_value, traceback): + self._revert_patches() + + def load_and_prepare_patches(self): + output_file = posixpath.join( + self.config._output_dir, biofit.config.PATCHES_FILENAME + ) + fp = posixpath.join(self.config._cache_dir_root, output_file) + if self._cache_enabled and os.path.exists(fp): + try: + return self._load_patches_from_cache(fp) + except Exception as e: + # something went wrong with loading the patches from cache, so we will clear the cache and recompile + self.config._clear_cache() + logger.warning(f"Failed to load patches from cache: {e}") + + patches: List[gorilla.Patch] = [] + _uncached_patches = defaultdict(dict) + for name, (patch, source, hash, destination) in self.config.patches: + if destination is not None: + try: + patches.append( + gorilla.Patch( + destination, + name=name, + obj=patch, + source=source, + settings=self.config.settings, + ) + ) + except SameSourceAndDestinationError: + # a patch with the same source and destination was given, raise an error + raise SameSourceAndDestinationError( + f'Patch with same source and destination: ("{name}", ({name}, "{source}", "{hash}", "{destination}"))' + ) + for patch_target in self.config.patch_targets: + _uncached_patches[patch_target.__name__][name] = (patch, source, hash) + + patches += self.create_patches(_uncached_patches) + try: + self._save_patches_to_cache(patches, fp) + except Exception as e: + logger.warning(f"Failed to save patches to cache: {e}") + + if os.path.exists(fp): + with filelock.FileLock( + create_lock_path(self.config._cache_dir_root, output_file) + ): + os.remove(fp) + with filelock.FileLock( + create_lock_path( + self.config._cache_dir_root, self.config._output_dir + ) + ): + os.rmdir( + os.path.join( + self.config._cache_dir_root, 
self.config._output_dir + ) + ) + return patches + + def create_patches(self, patches: Dict[str, Any] = None): + results = [] + if len(patches) > 0: + for patch_target in self.config.patch_targets: + _patches: List[Tuple[gorilla.Patch, str]] = self._create_patches( + patches, + patch_target, + return_hash=True, + ) + results += [p for p, _ in _patches] + + # NOTE: this is causing issues with the cache, so we will disable it for now + # if self._cache_enabled: + # self._save_patches_to_patch_cache(_patches, patch_target.__name__) + + return results + + def _create_patches( + self, + patches: Dict[str, Any], + patch_target: ModuleType, + return_hash=False, + ) -> Union[List[gorilla.Patch], List[Tuple[gorilla.Patch, str]]]: + """Create patches for within `~self.config.package`. + + Args: + patches (`Dict[str, Any]`): + Dictionary of patches to apply. The keys are the names of the members. + Values can be either the patch (i.e. the object that will replace the member) or a tuple of + (patch, source module), (patch, hash), or (patch, source module, hash). If tuple is length 2, + the second element is inferred based on if it is an instance of `ModuleType` or if str is importable. + Otherwise, it will be taken as the hash. + patch_targets (`Union[ModuleType, List[ModuleType]]`): + The target packages to patch. If recursive is True, all submodules will be patched as well. + recursive (`bool`, Defaults to True): + Whether to recursively search for modules to patch. + return_hash (`bool`, Defaults to False): + Whether to return the hash of the patch in final output. If True, the output will be a tuple of (patch, hash). + + Returns: + `List[gorilla.Patch]` or `List[Tuple[gorilla.Patch, str]]`: + List of patches to apply. If return_hash is True, the output will be a tuple of (patch, hash). + """ + + _patches = [] + + try: + for module in self.config._generate_modules()[patch_target.__name__]: + for asname, name, value in self._find_patches( + patches[patch_target.__name__], module + ): + if isinstance(value, tuple): + if len(value) == 3: + patch, source, hash = value + elif len(value) == 2: + patch, source = value + if importlib.util.find_spec(source): + hash = Hasher.hash(patch) if return_hash else None + else: + hash = source + origin = inspect.getmodule(value) + source = ( + origin.__name__ + if hasattr(origin, "__name__") + else None + ) + else: + patch = value + origin = inspect.getmodule(value) + source = ( + origin.__name__ if hasattr(origin, "__name__") else None + ) + hash = Hasher.hash(patch) if return_hash else None + try: + if return_hash: + _patches.append( + ( + gorilla.Patch( + module, + name=asname, + obj=patch, + source=source, + source_name=name, + settings=self.config.settings, + ), + hash, + ) + ) + else: + _patches.append( + gorilla.Patch( + module, + name=asname, + obj=patch, + source=source, + source_name=name, + settings=self.config.settings, + ) + ) + except SameSourceAndDestinationError: + pass + + except Exception as e: + logger.warning( + f"Failed to create patches: {e} for patch_target: {patch_target.__name__}" + ) + raise e + return _patches + + def _get_absolute_module_name( + self, source_module: Union[str, ModuleType], node: ast.ImportFrom + ) -> str: + """ + Reconstructs the absolute module name from a relative import. + + Args: + - source_module (str): The name of the module that contains the relative import. + - node (ast.ImportFrom): The AST node representing the relative import. + + Returns: + - str: The absolute module name. 
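+
+        Example (illustrative): for "from ..utils import helper" inside the module
+            "pkg.sub.mod", the node has level=2 and module="utils", so the source
+            parts are trimmed to ["pkg"] and the returned name is "pkg.utils".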
+ """ + if not isinstance(node, ast.ImportFrom): + raise ValueError("Node must be of type ast.ImportFrom") + + relative_level = node.level + source_module_parts = source_module.split(".") + if relative_level > len(source_module_parts): + raise ValueError("Relative level is too high for the current module") + + source_module_parts = source_module_parts[:-relative_level] + imported_module_parts = node.module.split(".") if node.module else [] + + absolute_module_parts = source_module_parts + imported_module_parts + return ".".join(absolute_module_parts) + + def _find_patches( + self, patches: Dict[str, Any], module: ModuleType + ) -> List[Tuple[str, Any]]: + """ + Find the patches to apply in the module. + + Args: + patches (Dict[str, Any]): Dictionary of patches to apply. + module (ModuleType): Module to search for patches. + + Returns: + OrderedDict[str, Any]: Ordered dictionary of patches to apply. + """ + try: + module_source = inspect.getsource(module) + except OSError: + with open(module.__file__, "r") as file: + module_source = file.read() + + exclude_imports = set() + asname_map = {} + + package_name = self.config.root.__name__ + + for node in ast.parse(module_source).body: + if isinstance(node, ast.ImportFrom): + module_name = ( + node.module + if node.level == 0 + else self._get_absolute_module_name(module.__name__, node) + ) + if package_name not in module_name: + exclude_imports.update([name.name for name in node.names]) + else: + for name in node.names: + if name.asname is not None: + asname_map[name.asname] = name.name + + def is_from_package(member: TypeAlias, module: ModuleType, package_name): + """Check if the member is from the same package as the module.""" + origin = inspect.getmodule(member) + return origin is None or origin.__name__.startswith(package_name) + + module_globals = inspect.getmembers( + module, lambda a: is_from_package(a, module, package_name) + ) + + return [ + ( + (name, name, patches[name]) + if name not in asname_map + else (name, asname_map[name], patches[asname_map[name]]) + ) + for (name, value) in module_globals + if ( + name not in exclude_imports + and name not in _EXCLUDED_MEMBERS + and not inspect.ismodule(value) + and ( + name in patches + or (name in asname_map and asname_map[name] in patches) + ) + ) + ] + + def _save_patches_to_cache( + self, patches: List[Tuple[gorilla.Patch, str]], file_path: str + ): + if not is_remote_url(self.config._cache_dir_root): + _relative_cache_dir = ( + Path(file_path).relative_to(self.config._cache_dir_root).as_posix() + ) + with filelock.FileLock( + create_lock_path(self.config._cache_dir_root, _relative_cache_dir) + ): + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "w") as f: + json.dump([p.to_dict() for p in patches], f) + + def _save_patches_to_patch_cache( + self, patches: List[Tuple[gorilla.Patch, str]], patch_target: str + ): + """ + Save the patches to the cache. + + Args: + config (PatcherConfig): Patcher configuration object. + patches (List[Tuple[gorilla.Patch, str]]): List of patches to save. + patch_targets_name_or_file_path (str, optional): Name of the patch target or the file path to save the patches to. 
+ """ + + if not is_remote_url(self.config._cache_dir_root): + gorilla_patches: Dict[str, List[gorilla.Patch]] = defaultdict(list) + for gorilla_patch, hash in patches: + gorilla_patches[hash].append(gorilla_patch) + + for _, (_, _, hash, _) in self.config.patches: + if hash in gorilla_patches: + continue + gorilla_patches[hash] = [] + + for hash, gorilla_patch in gorilla_patches.items(): + _relative_cache_dir = posixpath.join( + self.config._relative_ref_pkg_cache_dir, + self.config._relative_patch_targets_cache_dirs[patch_target], + hash, + ) + if gorilla_patch: + file_path = os.path.join( + self.config._cache_dir_root, + _relative_cache_dir, + biofit.config.PATCHES_FILENAME, + ) + else: + file_path = os.path.join( + self.config._cache_dir_root, + _relative_cache_dir, + biofit.config.NO_PATCHES_FILENAME, + ) + # lock the directory + with filelock.FileLock( + create_lock_path( + self.config._cache_dir_root, _relative_cache_dir + ) + ): + try: + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "w") as f: + json.dump([p.to_dict() for p in gorilla_patch], f) + + except Exception as e: + logger.warning(f"Failed to save patch to cache: {e}") + if os.path.exists(file_path): + os.remove(file_path) + os.rmdir(os.path.dirname(file_path)) + + def _load_patches_from_cache(self, file_path: str) -> List[gorilla.Patch]: + """ + Load the patches from the cache. + + Args: + config (PatcherConfig): Patcher configuration object. + file_path (str): Path to the cache file. + + Returns: + List[gorilla.Patch]: List of patches loaded from the cache. + """ + relative_file_path = ( + Path(file_path).relative_to(self.config._cache_dir_root).as_posix() + ) + + # lock the directory + with filelock.FileLock( + create_lock_path(self.config._cache_dir_root, relative_file_path) + ): + with open(file_path, "r") as f: + content = json.load(f) + if content is None: + raise ValueError(f"Failed to load patch from cache: {file_path}") + if isinstance(content, dict): + return [gorilla.Patch.from_dict(content)] + elif isinstance(content, list): + return [gorilla.Patch.from_dict(c) for c in content] + + def _validate_cache(self): + if not is_remote_url(self.config._cache_dir_root): + output_dir = posixpath.join( + self.config._cache_dir_root, self.config._output_dir + ) + probe_file = os.path.normpath( + posixpath.join(output_dir, f"{self._fingerprint}.txt") + ) + # look for a .cache file in the output directory + for root, dirs, files in os.walk(output_dir, topdown=False): + for file in files: + if file.endswith(".txt"): + fp = os.path.normpath(os.path.join(root, file)) + if probe_file != fp: + logger.debug( + "Cache file is outdated or corrupted, clearing the cache and re-building the patches..." 
+ ) + self.config.clear_cache(self.config._output_dir) + self.__init__(self.config) + return + if not os.path.exists(probe_file): + with filelock.FileLock( + create_lock_path( + self.config._cache_dir_root, self.config._output_dir + ) + ): + with open(probe_file, "w") as f: + f.write("") diff --git a/src/biofit/metrics/__init__.py b/src/biofit/metrics/__init__.py new file mode 100644 index 0000000..7015280 --- /dev/null +++ b/src/biofit/metrics/__init__.py @@ -0,0 +1,168 @@ +from functools import partial + +from sklearn.metrics import ( + accuracy_score, + balanced_accuracy_score, + f1_score, + log_loss, + mean_absolute_error, + mean_squared_error, + mean_squared_log_error, + precision_score, + r2_score, + recall_score, + roc_auc_score, +) + +from .metrics import ( + calculate_metrics, # noqa: F401 + confusion_matrix, # noqa: F401 + log_loss_weighted, + specificity, +) + +_OPTUNA_METRICS = { + "binary_classification": "f1", + "multiclass_classification": "f1_weighted", + "regression": "rmse", + "multi_regression": "rmse", +} + +_METRICS = { + "binary_classification": { + "logloss": (log_loss, "minimize"), + "logloss_weighted": ( + partial(log_loss_weighted, class_weights="balanced"), + "minimize", + ), + "auc": (roc_auc_score, "maximize"), + "f1": (f1_score, "maximize"), + "accuracy": (accuracy_score, "maximize"), + "balanced_accuracy": (balanced_accuracy_score, "maximize"), + "precision": (precision_score, "maximize"), + "recall": (recall_score, "maximize"), + "specificity": (specificity, "maximize"), + }, + "multiclass_classification": { + "mlogloss": ( + lambda y_true, y_pred, labels: log_loss(y_true, y_pred, labels=labels), + "minimize", + ), + "mlogloss_weighted": ( + lambda y_true, y_pred, labels: log_loss_weighted( + y_true, y_pred, labels=labels, class_weights="balanced" + ), + "minimize", + ), + "accuracy": ( + lambda y_true, y_pred, labels: accuracy_score(y_true, y_pred), + "maximize", + ), + "balanced_accuracy": ( + lambda y_true, y_pred, labels: balanced_accuracy_score(y_true, y_pred), + "maximize", + ), + "f1_macro": ( + lambda y_true, y_pred, labels: f1_score( + y_true, y_pred, average="macro", labels=labels + ), + "maximize", + ), + "f1_micro": ( + lambda y_true, y_pred, labels: f1_score( + y_true, y_pred, average="micro", labels=labels + ), + "maximize", + ), + "f1_weighted": ( + lambda y_true, y_pred, labels: f1_score( + y_true, y_pred, average="weighted", labels=labels + ), + "maximize", + ), + "precision_macro": ( + lambda y_true, y_pred, labels: precision_score( + y_true, y_pred, average="macro", labels=labels + ), + "maximize", + ), + "precision_micro": ( + lambda y_true, y_pred, labels: precision_score( + y_true, y_pred, average="micro", labels=labels + ), + "maximize", + ), + "precision_weighted": ( + lambda y_true, y_pred, labels: precision_score( + y_true, y_pred, average="weighted", labels=labels + ), + "maximize", + ), + "recall_macro": ( + lambda y_true, y_pred, labels: recall_score( + y_true, y_pred, average="macro", labels=labels + ), + "maximize", + ), + "recall_micro": ( + lambda y_true, y_pred, labels: recall_score( + y_true, y_pred, average="micro", labels=labels + ), + "maximize", + ), + "recall_weighted": ( + lambda y_true, y_pred, labels: recall_score( + y_true, y_pred, average="weighted", labels=labels + ), + "maximize", + ), + "specificity_macro": ( + lambda y_true, y_pred, labels: specificity( + y_true, y_pred, average="macro", labels=labels + ), + "maximize", + ), + # "specificity_micro": ( + # lambda y_true, y_pred, labels: specificity( + # y_true, 
y_pred, average="micro", labels=labels + # ), + # "maximize", + # ), + "specificity_weighted": ( + lambda y_true, y_pred, labels: specificity( + y_true, y_pred, average="weighted", labels=labels + ), + "maximize", + ), + }, + "regression": { + "rmse": (partial(mean_squared_error, squared=False), "minimize"), + "rmsle": (partial(mean_squared_log_error, squared=False), "minimize"), + "r2": (r2_score, "maximize"), + "mse": (mean_squared_error, "minimize"), + "mae": (mean_absolute_error, "minimize"), + }, + "multi_regression": { + "rmse": (partial(mean_squared_error, squared=False), "minimize"), + "rmsle": (partial(mean_squared_log_error, squared=False), "minimize"), + "r2": (r2_score, "maximize"), + "mse": (mean_squared_error, "minimize"), + "mae": (mean_absolute_error, "minimize"), + }, + "multilabel_classification": { + "logloss": (log_loss, "minimize"), + }, +} + + +def get_metrics(task=None, metric=None): + if task is None: + return _METRICS + if metric is None: + return _METRICS[task] + else: + return _METRICS[task][metric] + + +def get_optuna_metric(task): + return _OPTUNA_METRICS[task] diff --git a/src/biofit/metrics/metrics.py b/src/biofit/metrics/metrics.py new file mode 100644 index 0000000..00ddeb9 --- /dev/null +++ b/src/biofit/metrics/metrics.py @@ -0,0 +1,201 @@ +import copy +import inspect +from typing import Union + +import numpy as np +import pandas as pd +from biocore import DataHandler +from sklearn.metrics import ( + confusion_matrix as sk_confusion_matrix, +) +from sklearn.metrics import ( + log_loss, + multilabel_confusion_matrix, +) +from sklearn.metrics._classification import ( + _check_set_wise_labels, + _check_zero_division, + _nanaverage, + _prf_divide, +) +from sklearn.utils.class_weight import compute_class_weight + + +def confusion_matrix( + y_true, y_pred, labels=None, sample_weight=None, samplewise=False, normalize=None +): + y_true = DataHandler.to_numpy(y_true) + y_pred = DataHandler.to_numpy(y_pred) + + if y_true.ndim == 1 or y_true.shape[1] == 1: + y_true = y_true.flatten() + if y_pred.ndim == 1 or y_pred.shape[1] == 1: + y_pred = (y_pred.flatten() > 0.5).astype(int) + else: + y_pred = np.argmax(y_pred, axis=1) + + if len(labels) < 3: + mat = sk_confusion_matrix( + y_true, + y_pred, + labels=labels + if labels is not None and not isinstance(labels[0], str) + else None, + sample_weight=sample_weight, + normalize=normalize, + ) + + df = pd.DataFrame( + mat, + columns=[f"Predicted {label}" for label in labels], + index=[f"Actual {label}" for label in labels], + ) + return df + else: + # create a len(labels) x len(labels) matrix + mat = np.zeros((len(labels), len(labels)), dtype=int) + for i, label in enumerate(labels): + for j, pred in enumerate(labels): + mat[i, j] = np.sum((y_true == i) & (y_pred == j)) + + df = pd.DataFrame( + mat, + columns=[f"Predicted {label}" for label in labels], + index=[f"Actual {label}" for label in labels], + ) + return df + + +def specificity( + y_true, + y_pred, + *, + labels=None, + pos_label=1, + average=None, + sample_weight=None, + zero_division="warn", +): + _check_zero_division(zero_division) + labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label) + + # Calculate tp_sum, pred_sum, true_sum ### + samplewise = average == "samples" + MCM = multilabel_confusion_matrix( + y_true, + y_pred, + sample_weight=sample_weight, + labels=labels, + samplewise=samplewise, + ) + tp_sum = MCM[:, 1, 1] + tn_sum = MCM[:, 0, 0] + fp_sum = MCM[:, 0, 1] + fn_sum = MCM[:, 1, 0] + true_sum = tp_sum + fn_sum + false_sum = 
tn_sum + fp_sum + + if average == "micro": + false_sum = np.array([false_sum.sum()]) + + # Divide, and on zero-division, set scores and/or warn according to + # zero_division: + specificity = _prf_divide( + tn_sum, false_sum, "specificity", "false", average, "specificity", zero_division + ) + + # Average the results + if average == "weighted": + weights = true_sum + elif average == "samples": + weights = sample_weight + else: + weights = None + + if average is not None: + assert average != "binary" or len(specificity) == 1 + specificity = _nanaverage(specificity, weights=weights) + if isinstance(specificity, (list, tuple, np.ndarray)) and len(specificity) > 1: + specificity = specificity[-1] + return specificity + + +def log_loss_weighted(y_true, y_pred, labels=None, class_weights=None): + ytrue = DataHandler.to_numpy(y_true).flatten() + if labels is None: + labels = DataHandler.unique(DataHandler.to_numpy(ytrue).flatten()) + else: + labels = DataHandler.to_numpy(labels).flatten() + if set(labels) - set(ytrue): + labels = DataHandler.unique(DataHandler.to_numpy(ytrue).flatten()) + + if class_weights is None: + class_weights = np.ones(len(labels)) + else: + class_weights = compute_class_weight(class_weights, classes=labels, y=ytrue) + + ypred = DataHandler.to_numpy(y_pred) + if ypred.ndim == 2 and ypred.shape[1] == 1: + ypred = ypred.flatten() + class_weights = dict(zip(labels, class_weights)) + sample_weights = np.array([class_weights[label] for label in ytrue]) + return log_loss(ytrue, ypred, sample_weight=sample_weights) + + +def calculate_metrics( + metrics: Union[dict, callable], y_true, y_pred, sub_task, labels=None +): + results = {} + + if labels is not None: + labels = list(range(len(labels))) + if callable(metrics): + params = inspect.signature(metrics).parameters + eval_kwargs = {} + if "labels" in params: + eval_kwargs["labels"] = labels + + results["custom_metric"].append(metrics(y_true, y_pred, **eval_kwargs)) + + else: + for metric_name, (metric_func, _) in metrics.items(): + if sub_task == "binary_classification": + if metric_name in ["logloss", "logloss_weighted", "auc"]: + results[metric_name] = metric_func( + y_true, DataHandler.select_column(y_pred, 1) + ) + else: + results[metric_name] = metric_func( + y_true, + DataHandler.ge(DataHandler.select_column(y_pred, 1), 0.5), + ) + elif sub_task == "multiclass_classification": + if metric_name in ( + "accuracy", + "balanced_accuracy", + "f1_macro", + "f1_micro", + "f1_weighted", + "precision_macro", + "precision_micro", + "precision_weighted", + "recall_macro", + "recall_micro", + "recall_weighted", + "specificity_macro", + "specificity_micro", + "specificity_weighted", + ): + results[metric_name] = metric_func( + y_true, DataHandler.argmax(y_pred, axis=1), labels + ) + else: + results[metric_name] = metric_func(y_true, y_pred, labels) + else: + if metric_name == "rmsle": + temp_pred = copy.deepcopy(DataHandler.to_numpy(y_pred)) + temp_pred = np.clip(temp_pred, 0, None) + results[metric_name] = metric_func(y_true, temp_pred) + else: + results[metric_name] = metric_func(y_true, y_pred) + return results diff --git a/src/biofit/models/__init__.py b/src/biofit/models/__init__.py new file mode 100644 index 0000000..6b809ad --- /dev/null +++ b/src/biofit/models/__init__.py @@ -0,0 +1,3 @@ +# ruff: noqa +from .random_forest import * +from .lightgbm import * diff --git a/src/biofit/models/ensemble/__init__.py b/src/biofit/models/ensemble/__init__.py new file mode 100644 index 0000000..6e3f30a --- /dev/null +++ 
b/src/biofit/models/ensemble/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa +from .random_forest import * diff --git a/src/biofit/models/ensemble/ensemble.py b/src/biofit/models/ensemble/ensemble.py new file mode 100644 index 0000000..e69de29 diff --git a/src/biofit/models/lasso/__init__.py b/src/biofit/models/lasso/__init__.py new file mode 100644 index 0000000..bf46f72 --- /dev/null +++ b/src/biofit/models/lasso/__init__.py @@ -0,0 +1,7 @@ +# ruff: noqa +from .lasso import ( + LassoModel, + LassoConfig, + LassoForClassification, + LassoForRegression, +) diff --git a/src/biofit/models/lasso/lasso.py b/src/biofit/models/lasso/lasso.py new file mode 100644 index 0000000..2d2ba88 --- /dev/null +++ b/src/biofit/models/lasso/lasso.py @@ -0,0 +1,355 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, Union + +import numpy as np +from sklearn.linear_model import Lasso + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.utils import logging + +from ..logistic_regression import LogisticRegressionModel +from ..models import Model, ModelConfig + +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + pass + + +@dataclass +class LassoConfig(ModelConfig): + _fit_input_feature_types: List[Union[None, type]] = field( + default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")], + init=False, + repr=False, + ) + _fit_unused_feature_types: List[Union[None, type]] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _fit_process_desc: str = field( + default="Fitting the lasso model", init=False, repr=False + ) + predict_process_desc: str = field( + default=None, + init=False, + repr=False, + ) + predict_proba_process_desc: str = field( + default=None, + init=False, + repr=False, + ) + processor_name: str = field(default="lasso", init=False, repr=False) + + input_columns: SelectedColumnTypes = None + target_column: SelectedColumnTypes = None + alpha = 1.0 + fit_intercept: bool = True + precompute: Union[bool, np.ndarray] = False + copy_X: bool = True + max_iter: int = 1000 + tol: float = 1e-4 + warm_start: bool = False + positive: bool = False + random_state: Optional[int] = 42 + class_weight: Union[dict, str] = None + selection: str = "cyclic" # "cyclic" or "random" + solver: str = "saga" + + use_predict_proba: bool = False + task: str = None + estimator: Lasso = field(default=None, init=False, repr=False) + n_classes: int = None + + def __post_init__(self): + if self.task is None: + self.task = "classification" if self.use_predict_proba else "regression" + self._fit_process_desc = f"Fitting the lasso {self.task} model" + self.predict_process_desc = f"Predicting with the lasso {self.task} model" + self.predict_proba_process_desc = ( + f"Predicting probabilities with the lasso {self.task} model" + ) + + +class LassoConfigForOTU(LassoConfig): + dataset_name: str = field(default="otu", init=False, repr=False) + log_transform: Union[str, bool] = field(default="log2_1p", init=False, repr=False) + + +class LassoModel(Model): + config_class = LassoConfig + config: LassoConfig + lasso: Lasso + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + self.lasso = Lasso( + alpha=self.config.alpha, + 
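+            # only parameters accepted by sklearn.linear_model.Lasso are forwarded
+            # here; config fields such as class_weight and solver are used by
+            # LassoForClassification further below instead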
fit_intercept=self.config.fit_intercept, + precompute=self.config.precompute, + copy_X=self.config.copy_X, + max_iter=self.config.max_iter, + tol=self.config.tol, + warm_start=self.config.warm_start, + positive=self.config.positive, + random_state=self.config.random_state, + selection=self.config.selection, + ) + return self + + @property + def feature_importances_(self): + return self.config.estimator.coef_.flatten() + + @property + def feature_names_in_(self): + return self.config.feature_names_in_ + + def fit( + self, + X, + y=None, + input_columns: SelectedColumnTypes = None, + target_column: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "LassoModel": + self.config._input_columns = self._set_input_columns_and_arity( + input_columns, target_column + ) + return self._process_fit( + X, + y, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def predict( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def predict_proba( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _fit_sklearn(self, X, y): + if self.config.transform_log == "log2_1p": + X = np.log2(X + 1) + self.config.estimator = self.lasso.fit(X, y) + return self + + def _predict_sklearn(self, 
X): + return self.config.estimator.predict(X) + + def _predict_proba_sklearn(self, X): + logger.warning_once( + "predict_proba is not supported for Lasso. Returning using predict instead." + ) + return self._predict_sklearn(X) + + @staticmethod + def suggest_params(data): + return { + "alpha": ( + "suggest_float", + ( + 1e-8, + 1e3, + ), + {"log": True}, + ), + "fit_intercept": ("suggest_categorical", ([True, False],), {}), + "max_iter": ( + "suggest_int", + ( + 1000, + 10000, + ), + {}, + ), + } + + @staticmethod + def suggest_first_trial(): + return { + "alpha": 1.0, + "fit_intercept": True, + "max_iter": 1000, + } + + +# Lasso for Classification is a logistic regression model with L1 penalty, except +# that with the additional layer of converting the scores for classes to the "winning" +# class output label. +class LassoForClassification(LogisticRegressionModel): + class_config = LassoConfig + + def __init__( + self, + alpha=1.0, + fit_intercept: bool = True, + copy_X: bool = True, + max_iter: int = 1000, + tol: float = 1e-4, + warm_start: bool = False, + class_weight: str = None, + random_state: int = 42, + solver: str = "saga", + config: LassoConfig = None, + **kwargs, + ): + if "C" in kwargs: + alpha = 1.0 / kwargs.pop("C") + super().__init__( + C=1.0 / alpha, # C is the inverse of alpha + fit_intercept=fit_intercept, + max_iter=max_iter, + tol=tol, + penalty="l1", + solver=solver, + warm_start=warm_start, + class_weight=class_weight, + random_state=random_state, + config=config, + ) + + @sync_backup_config + def set_params(self, **kwargs): + if "alpha" in kwargs: + kwargs["C"] = 1.0 / kwargs.pop("alpha") + super().set_params(**kwargs) + return self + + @property + def classes_(self): + return self.config.estimator.classes_ + + @staticmethod + def suggest_params(data): + params = { + "C": ( + "suggest_float", + ( + 1e-4, + 1e2, + ), + {"log": True}, + ), + "tol": ( + "suggest_float", + ( + 1e-5, + 1e-2, + ), + {"log": True}, + ), + "fit_intercept": ("suggest_categorical", ([True, False],), {}), + "max_iter": ("suggest_categorical", ([100, 200, 500, 1000],), {}), + "class_weight": ("suggest_categorical", (["balanced", "None"],), {}), + } + return params + + @staticmethod + def suggest_first_trial(): + return { + "alpha": 1.0, + "tol": 1e-4, + "fit_intercept": True, + "max_iter": 1000, + "class_weight": "None", + } + + +class LassoForRegression(LassoModel): + pass diff --git a/src/biofit/models/lightgbm/__init__.py b/src/biofit/models/lightgbm/__init__.py new file mode 100644 index 0000000..5c69cd2 --- /dev/null +++ b/src/biofit/models/lightgbm/__init__.py @@ -0,0 +1,7 @@ +# ruff: noqa +from .lightgbm import ( + LightGBMModel, + LightGBMConfig, + LightGBMForClassification, + LightGBMForRegression, +) diff --git a/src/biofit/models/lightgbm/lightgbm.py b/src/biofit/models/lightgbm/lightgbm.py new file mode 100644 index 0000000..754b120 --- /dev/null +++ b/src/biofit/models/lightgbm/lightgbm.py @@ -0,0 +1,578 @@ +from dataclasses import dataclass, field, fields +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Type, Union + +import numpy as np +from biocore import DataHandler +from biocore.utils.import_util import is_lightgbm_available, requires_backends + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.utils import logging + +from ..models import Model, ModelConfig + +logger = logging.get_logger(__name__) + +if TYPE_CHECKING: + from lightgbm import LGBMClassifier, LGBMRegressor + 
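+    # imported here only for static type checking; the matching runtime import
+    # below is guarded by is_lightgbm_available() so lightgbm stays optional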
from lightgbm.sklearn import _LGBM_ScikitCustomObjectiveFunction + +if is_lightgbm_available(): + from lightgbm import LGBMClassifier, LGBMRegressor + from lightgbm.sklearn import _LGBM_ScikitCustomObjectiveFunction + + +@dataclass +class LightGBMConfig(ModelConfig): + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [ + None, + get_feature("TARGET_FEATURE_TYPES"), + None, + None, + ], + init=False, + repr=False, + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("METADATA_FEATURE_TYPES"), + None, + None, + None, + ], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + processor_name: str = field(default="lightgbm", init=False, repr=False) + boosting_type: str = "gbdt" + num_leaves: int = 31 + max_depth: int = -1 + learning_rate: float = 0.1 + n_estimators: int = 100 + subsample_for_bin: int = 200000 + objective: Optional[Union[str, "_LGBM_ScikitCustomObjectiveFunction"]] = None + eval_metric: str = None + class_weight: Optional[Union[Dict, str]] = None + min_split_gain: float = 0.0 + min_child_weight: float = 1e-3 + min_child_samples: int = 20 + subsample: float = 1.0 + subsample_freq: int = 0 + colsample_bytree: float = 1.0 + reg_alpha: float = 0.0 + reg_lambda: float = 0.0 + random_state: Optional[Union[int, np.random.RandomState]] = 42 + n_jobs: Optional[int] = None + importance_type: str = "split" + early_stopping_rounds: int = None + callbacks: Optional[List[Callable]] = None + verbosity: int = -1 + additional_kwargs: dict = field(default_factory=dict) + + n_classes: int = None + task: str = None + estimator: Union["LGBMClassifier", "LGBMRegressor"] = field( + default=None, init=False, repr=False + ) + + def __post_init__(self): + requires_backends(self.__class__, "lightgbm") + self._fit_process_desc = f"Fitting the LGBM {self.task} model" + self.predict_process_desc = f"Predicting with the LGBM {self.task} model" + self.predict_proba_process_desc = ( + f"Predicting probabilities with the LGBM {self.task} model" + ) + + +class LightGBMModel(Model): + config_class = LightGBMConfig + config: LightGBMConfig + lightgbm: Union["LGBMClassifier", "LGBMRegressor"] + + def __init__( + self, + boosting_type: str = "gbdt", + num_leaves: int = 31, + max_depth: int = -1, + learning_rate: float = 0.1, + n_estimators: int = 100, + subsample_for_bin: int = 200000, + objective: Optional[Union[str, "_LGBM_ScikitCustomObjectiveFunction"]] = None, + eval_metric: str = None, + class_weight: Optional[Union[Dict, str]] = None, + min_split_gain: float = 0.0, + min_child_weight: float = 1e-3, + min_child_samples: int = 20, + subsample: float = 1.0, + subsample_freq: int = 0, + colsample_bytree: float = 1.0, + reg_alpha: float = 0.0, + reg_lambda: float = 0.0, + random_state: Optional[Union[int, np.random.RandomState]] = None, + n_jobs: Optional[int] = None, + importance_type: str = "split", + callbacks: Optional[List[Callable]] = None, + verbosity: int = -1, + n_classes: int = None, + early_stopping_rounds: int = None, + task: str = None, + version: str = None, + config: Optional[LightGBMConfig] = None, + **kwargs, + ): + m_params = [f.name for f in fields(LightGBMConfig) if f.init] + biofit_params = {k: v for k, v in kwargs.items() if k in m_params} + lightgbm_params = {k: v for k, v in kwargs.items() if k not in m_params} + if "additional_kwargs" in biofit_params: + 
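+            # kwargs not recognized by LightGBMConfig, plus any explicitly passed
+            # additional_kwargs, are merged and forwarded verbatim to the underlying
+            # LGBMClassifier/LGBMRegressor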
lightgbm_params.update(biofit_params.pop("additional_kwargs")) + super().__init__( + config=config, + boosting_type=boosting_type, + num_leaves=num_leaves, + max_depth=max_depth, + learning_rate=learning_rate, + n_estimators=n_estimators, + subsample_for_bin=subsample_for_bin, + objective=objective, + eval_metric=eval_metric, + class_weight=class_weight, + min_split_gain=min_split_gain, + min_child_weight=min_child_weight, + min_child_samples=min_child_samples, + subsample=subsample, + subsample_freq=subsample_freq, + colsample_bytree=colsample_bytree, + reg_alpha=reg_alpha, + reg_lambda=reg_lambda, + random_state=random_state, + n_jobs=n_jobs, + importance_type=importance_type, + early_stopping_rounds=early_stopping_rounds, + callbacks=callbacks, + verbosity=verbosity, + n_classes=n_classes, + task=task, + version=version, + additional_kwargs=lightgbm_params, + **biofit_params, + ) + if self.config.task is None: + raise ValueError( + "Task is not set. Please set the task before setting the parameters." + ) + + if "classification" in self.config.task: + self.lightgbm = LGBMClassifier( + boosting_type=self.config.boosting_type, + num_leaves=self.config.num_leaves, + max_depth=self.config.max_depth if self.config.max_depth > 0 else -1, + learning_rate=self.config.learning_rate, + n_estimators=self.config.n_estimators, + subsample_for_bin=self.config.subsample_for_bin, + objective=self.config.objective, + class_weight=self.config.class_weight + if self.config.class_weight != "None" + else None, + min_split_gain=self.config.min_split_gain, + min_child_weight=self.config.min_child_weight, + min_child_samples=self.config.min_child_samples, + subsample=self.config.subsample, + subsample_freq=self.config.subsample_freq, + colsample_bytree=self.config.colsample_bytree, + reg_alpha=self.config.reg_alpha, + reg_lambda=self.config.reg_lambda, + random_state=self.config.random_state, + n_jobs=self.config.n_jobs, + importance_type=self.config.importance_type, + verbosity=self.config.verbosity, + **self.config.additional_kwargs, + ) + elif "regression" in self.config.task: + self.lightgbm = LGBMRegressor( + boosting_type=self.config.boosting_type, + num_leaves=self.config.num_leaves, + max_depth=self.config.max_depth if self.config.max_depth > 0 else -1, + learning_rate=self.config.learning_rate, + n_estimators=self.config.n_estimators, + subsample_for_bin=self.config.subsample_for_bin, + objective=self.config.objective, + min_split_gain=self.config.min_split_gain, + min_child_weight=self.config.min_child_weight, + min_child_samples=self.config.min_child_samples, + subsample=self.config.subsample, + subsample_freq=self.config.subsample_freq, + colsample_bytree=self.config.colsample_bytree, + reg_alpha=self.config.reg_alpha, + reg_lambda=self.config.reg_lambda, + random_state=self.config.random_state, + n_jobs=self.config.n_jobs, + importance_type=self.config.importance_type, + verbosity=self.config.verbosity, + **self.config.additional_kwargs, + ) + else: + raise ValueError(f"Invalid task: {self.config.task}") + + @sync_backup_config + def set_params(self, **params): + self.config = self.config.replace_defaults(**params) + if self.config.task is None: + raise ValueError( + "Task is not set. Please set the task before setting the parameters." 
+ ) + + if "classification" in self.config.task: + self.lightgbm = LGBMClassifier( + boosting_type=self.config.boosting_type, + num_leaves=self.config.num_leaves, + max_depth=self.config.max_depth if self.config.max_depth > 0 else -1, + learning_rate=self.config.learning_rate, + n_estimators=self.config.n_estimators, + subsample_for_bin=self.config.subsample_for_bin, + objective=self.config.objective, + class_weight=self.config.class_weight + if self.config.class_weight != "None" + else None, + min_split_gain=self.config.min_split_gain, + min_child_weight=self.config.min_child_weight, + min_child_samples=self.config.min_child_samples, + subsample=self.config.subsample, + subsample_freq=self.config.subsample_freq, + colsample_bytree=self.config.colsample_bytree, + reg_alpha=self.config.reg_alpha, + reg_lambda=self.config.reg_lambda, + random_state=self.config.random_state, + n_jobs=self.config.n_jobs, + importance_type=self.config.importance_type, + verbosity=self.config.verbosity, + ) + elif "regression" in self.config.task: + self.lightgbm = LGBMRegressor( + boosting_type=self.config.boosting_type, + num_leaves=self.config.num_leaves, + max_depth=self.config.max_depth if self.config.max_depth > 0 else -1, + learning_rate=self.config.learning_rate, + n_estimators=self.config.n_estimators, + subsample_for_bin=self.config.subsample_for_bin, + objective=self.config.objective, + min_split_gain=self.config.min_split_gain, + min_child_weight=self.config.min_child_weight, + min_child_samples=self.config.min_child_samples, + subsample=self.config.subsample, + subsample_freq=self.config.subsample_freq, + colsample_bytree=self.config.colsample_bytree, + reg_alpha=self.config.reg_alpha, + reg_lambda=self.config.reg_lambda, + random_state=self.config.random_state, + n_jobs=self.config.n_jobs, + importance_type=self.config.importance_type, + verbosity=self.config.verbosity, + ) + else: + raise ValueError(f"Invalid task: {self.config.task}") + return self + + def set_objective(self, task: str): + if task == "regression": + if self.config.eval_metric == "rmse": + self.config.objective = "regression" + elif self.config.eval_metric == "mae": + self.config.objective = "regression_l1" + elif task == "binary_classification": + self.config.objective = "binary" + if self.config.eval_metric == "accuracy": + self.config.eval_metric = "binary_error" + elif task == "multiclass_classification": + self.config.objective = "multiclass" + self.lightgbm = self.lightgbm.set_params(objective=self.config.objective) + return self + + @property + def feature_importances_(self): + return self.config.estimator.feature_importances_ + + @property + def feature_names_in_(self): + return self.config.feature_names_in_ + + def fit( + self, + X, + y=None, + eval_set=None, + input_columns: SelectedColumnTypes = None, + target_column: SelectedColumnTypes = None, + eval_input_columns: SelectedColumnTypes = None, + eval_target_column: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + early_stopping_rounds: int = None, + ) -> "LightGBMModel": + self.config._input_columns = self._set_input_columns_and_arity( + input_columns, target_column, eval_input_columns, eval_target_column + ) + if eval_set is not None: + if isinstance(eval_set, tuple): 
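+                # a bare (eval_X, eval_y) pair; it is passed through positionally
+                # and ultimately reaches _fit_sklearn as eval_x/eval_y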
+ extras = eval_set + elif ( + isinstance(eval_set, list) + and len(eval_set) == 1 + and isinstance(eval_set[0], tuple) + ): + extras = eval_set[0] + else: + raise ValueError( + "eval_set must be a tuple or a list containing a single tuple" + ) + else: + extras = (None, None) + + return self._process_fit( + X, + y, + *extras, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def predict( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._method_prefix = "_predict" + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def predict_proba( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._method_prefix = "_predict_proba" + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _fit_sklearn(self, X, y, eval_x=None, eval_y=None): + callbacks = self.config.callbacks + if self.config.early_stopping_rounds: + callbacks = callbacks or [] + from lightgbm.callback import early_stopping + + callbacks.append(early_stopping(self.config.early_stopping_rounds)) + + if eval_x is not None and eval_y is not None: + self.config.estimator = self.lightgbm.fit( + X, + y, + eval_set=[(eval_x, eval_y)], + eval_metric=self.config.eval_metric, + callbacks=callbacks, + ) + else: + self.config.estimator = self.lightgbm.fit( + X, y, eval_metric=self.config.eval_metric, callbacks=callbacks + ) + if self.config.early_stopping_rounds: + self.config.n_estimators = self.config.estimator.best_iteration_ + self.config.estimator.set_params(n_estimators=self.config.n_estimators) + self.config.early_stopping_rounds = None + return self + + def _predict_sklearn(self, X): + return 
self.config.estimator.predict(X) + + def _predict_proba_sklearn(self, X): + return self.config.estimator.predict_proba(X) + + def _get_features_out( + self, X, selected_indices=None, unselected_indices=None, **kwargs + ): + if self._method_prefix == "_predict_proba": + self.config._n_features_out = self.config.n_classes + self.output_dtype = "float64" + else: + self.config._n_features_out = 1 + return super()._get_features_out( + X, + selected_indices=selected_indices, + unselected_indices=unselected_indices, + one_to_one_features=False, + n_features_out=self.config._n_features_out, + keep_unused_columns=False, + ) + + @staticmethod + def suggest_params(data): + params = { + "verbosity": -1, + "n_estimators": ("suggest_int", [50, 1000], {"step": 50}), + "max_depth": ("suggest_int", [0, 11], {}), + "learning_rate": ("suggest_float", [1e-3, 0.1], {"log": True}), + "num_leaves": ("suggest_int", [31, 255], {}), + "subsample": ("suggest_float", [0.5, 1.0], {"step": 0.1}), + "reg_alpha": ("suggest_float", [1e-9, 1.0], {"log": True}), + "reg_lambda": ("suggest_float", [1e-9, 1.0], {"log": True}), + "min_child_samples": ("suggest_int", [10, 100], {}), + "min_split_gain": ("suggest_float", [0.0, 1.0], {}), + "num_threads": -1, + } + + if data is not None: + data_dim = DataHandler.get_shape(data) + + if data_dim[1] > data_dim[0]: + msg = "The number of features is greater than the number of samples" + max_feat_frac = max(0.5, data_dim[0] / data_dim[1]) + max_feat_frac = round(max_feat_frac, 1) + params["colsample_bytree"] = ( + "suggest_float", + [0.1, max_feat_frac], + {"step": 0.1}, + ) + if data_dim[1] > 1000: + msg += " and exceeds 1000" + params["n_estimators"] = ("suggest_int", [50, 500], {"step": 50}) + params["max_depth"] = ("suggest_int", [3, 8], {}) + params["num_leaves"] = ("suggest_int", [31, 63], {}) + params["min_child_samples"] = ("suggest_int", [10, 50], {}) + params["max_bin"] = ("suggest_int", [63, 127], {}) + msg += ". Adjusting the hyperparameters accordingly for faster training. 
" + logger.warning(msg) + + return params + + @staticmethod + def suggest_first_trial(): + return { + "n_estimators": 100, + "max_depth": 8, + "learning_rate": 0.1, + "num_leaves": 31, + "subsample": 1.0, + "reg_alpha": 1e-9, + "reg_lambda": 1e-9, + "min_child_samples": 20, + "colsample_bytree": 0.9, + "min_split_gain": 0.0, + } + + +class LightGBMForClassification(LightGBMModel): + def __init__(self, **kwargs): + kwargs["task"] = kwargs.get("task", "classification") + super().__init__(**kwargs) + + @property + def classes_(self): + return self.config.estimator.classes_ + + @sync_backup_config + def set_params(self, **params): + self.config.task = self.config.task or "classification" + return super().set_params(**params) + + def _process_fit_input(self, X, **kwargs): + X, kwargs = super()._process_fit_input(X, **kwargs) + self.set_objective(self.config.task) + return X, kwargs + + @staticmethod + def suggest_params(data): + params = LightGBMModel.suggest_params(data) + params["class_weight"] = ("suggest_categorical", [["balanced", "None"]], {}) + return params + + @staticmethod + def suggest_first_trial(): + return { + **LightGBMModel.suggest_first_trial(), + "class_weight": "balanced", + } + + +class LightGBMForRegression(LightGBMModel): + @sync_backup_config + def set_params(self, **params): + self.config.task = self.config.task or "regression" + return super().set_params(**params) diff --git a/src/biofit/models/logistic_regression/__init__.py b/src/biofit/models/logistic_regression/__init__.py new file mode 100644 index 0000000..d909e73 --- /dev/null +++ b/src/biofit/models/logistic_regression/__init__.py @@ -0,0 +1,5 @@ +# ruff: noqa +from .logistic_regression import ( + LogisticRegressionModel, + LogisticRegressionConfig, +) diff --git a/src/biofit/models/logistic_regression/logistic_regression.py b/src/biofit/models/logistic_regression/logistic_regression.py new file mode 100644 index 0000000..2cb15b2 --- /dev/null +++ b/src/biofit/models/logistic_regression/logistic_regression.py @@ -0,0 +1,246 @@ +from dataclasses import dataclass, field +from typing import List, Optional, Union + +from sklearn.linear_model import LogisticRegression + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.utils import logging + +from ..models import Model, ModelConfig + +logger = logging.get_logger(__name__) + + +@dataclass +class LogisticRegressionConfig(ModelConfig): + _fit_input_feature_types: List[Union[None, type]] = field( + default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")], + init=False, + repr=False, + ) + _fit_unused_feature_types: List[Union[None, type]] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _fit_process_desc: str = field( + default="Fitting the lasso model", init=False, repr=False + ) + predict_process_desc: str = field( + default=None, + init=False, + repr=False, + ) + predict_proba_process_desc: str = field( + default=None, + init=False, + repr=False, + ) + processor_name: str = field(default="logistic_regression", init=False, repr=False) + + input_columns: SelectedColumnTypes = None + target_column: SelectedColumnTypes = None + + penalty: str = "l2" + dual: bool = False + tol: float = 1e-4 + C: float = 1.0 + fit_intercept: bool = True + intercept_scaling: 
float = 1 + class_weight: Union[dict, str] = None + random_state: Optional[int] = 42 + solver: str = "saga" + max_iter: int = 100 + multi_class: str = "auto" + verbose: int = 0 + warm_start: bool = False + n_jobs: Optional[int] = None + l1_ratio: Optional[float] = None + + use_predict_proba: bool = False + task: str = field(default="classification", init=False, repr=False) + estimator: LogisticRegression = field(default=None, init=False, repr=False) + n_classes: int = None + + def __post_init__(self): + self._fit_process_desc = "Fitting the logistic regression model" + self.predict_process_desc = "Predicting with the logistic regression model" + self.predict_proba_process_desc = ( + "Predicting probabilities with the logistic regression model" + ) + + +class LogisticRegressionModel(Model): + config_class = LogisticRegressionConfig + config: LogisticRegressionConfig + logistic_regression: LogisticRegression + + def __init__( + self, + penalty: str = "l2", + dual: bool = False, + tol: float = 1e-4, + C: float = 1.0, + fit_intercept: bool = True, + intercept_scaling: float = 1, + class_weight: Union[dict, str] = None, + random_state: int = 42, + solver: str = "lbfgs", + max_iter: int = 100, + multi_class: str = "auto", + verbose: int = 0, + warm_start: bool = False, + n_jobs: Optional[int] = None, + l1_ratio: Optional[float] = None, + config: LogisticRegressionConfig = None, + **kwargs, + ): + super().__init__( + penalty=penalty, + dual=dual, + tol=tol, + C=C, + fit_intercept=fit_intercept, + intercept_scaling=intercept_scaling, + class_weight=class_weight, + random_state=random_state, + solver=solver, + max_iter=max_iter, + multi_class=multi_class if multi_class != "None" else None, + verbose=verbose, + warm_start=warm_start, + n_jobs=n_jobs, + l1_ratio=l1_ratio, + config=config, + **kwargs, + ) + self.logistic_regression = LogisticRegression( + penalty=self.config.penalty, + dual=self.config.dual, + tol=self.config.tol, + C=self.config.C, + fit_intercept=self.config.fit_intercept, + intercept_scaling=self.config.intercept_scaling, + class_weight=self.config.class_weight + if self.config.class_weight != "None" + else None, + random_state=self.config.random_state, + solver=self.config.solver, + max_iter=self.config.max_iter, + multi_class=self.config.multi_class, + verbose=self.config.verbose, + warm_start=self.config.warm_start, + n_jobs=self.config.n_jobs, + l1_ratio=self.config.l1_ratio, + ) + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + self.logistic_regression = LogisticRegression( + penalty=self.config.penalty, + dual=self.config.dual, + tol=self.config.tol, + C=self.config.C, + fit_intercept=self.config.fit_intercept, + intercept_scaling=self.config.intercept_scaling, + class_weight=self.config.class_weight + if self.config.class_weight != "None" + else None, + random_state=self.config.random_state, + solver=self.config.solver, + max_iter=self.config.max_iter, + multi_class=self.config.multi_class, + verbose=self.config.verbose, + warm_start=self.config.warm_start, + n_jobs=self.config.n_jobs, + l1_ratio=self.config.l1_ratio, + ) + return self + + @property + def feature_importances_(self): + return self.config.estimator.coef_.flatten() + + @property + def feature_names_in_(self): + return self.config.feature_names_in_ + + def fit( + self, + X, + y=None, + input_columns: SelectedColumnTypes = None, + target_column: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + 
load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "LogisticRegressionModel": + self.config._input_columns = self._set_input_columns_and_arity( + input_columns, target_column + ) + return self._process_fit( + X, + y, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def _fit_sklearn(self, X, y): + self.config.estimator = self.logistic_regression.fit(X, y) + return self + + def _predict_sklearn(self, X): + return self.config.estimator.predict(X) + + def _predict_proba_sklearn(self, X): + return self.config.estimator.predict_proba(X) + + @staticmethod + def suggest_params(data): + params = { + "C": ( + "suggest_float", + ( + 1e-4, + 1e2, + ), + {"log": True}, + ), + "tol": ( + "suggest_float", + ( + 1e-5, + 1e-2, + ), + {"log": True}, + ), + "max_iter": ("suggest_categorical", ([100, 200, 500, 1000],), {}), + "penalty": "l1", + "n_jobs": -1, + "class_weight": ("suggest_categorical", (["balanced", "None"],), {}), + } + return params diff --git a/src/biofit/models/models.py b/src/biofit/models/models.py new file mode 100644 index 0000000..52a58d8 --- /dev/null +++ b/src/biofit/models/models.py @@ -0,0 +1,223 @@ +from dataclasses import dataclass, field +from functools import wraps + +import pyarrow as pa +from biocore import DataHandler +from biocore.utils.import_util import is_datasets_available + +from biofit.processing import BaseProcessor, ProcessorConfig, SelectedColumnTypes +from biofit.utils import logging + +logger = logging.get_logger(__name__) + + +@dataclass +class ModelConfig(ProcessorConfig): + """Base class for feature extraction processor configurations.""" + + _fit_process_desc: str = field(default="Fitting the model", init=False, repr=False) + predict_process_desc: str = field( + default="Predicting target output", init=False, repr=False + ) + predict_proba_process_desc: str = field( + default="Predicting target output probabilities", init=False, repr=False + ) + processor_type: str = field(default="models", init=False, repr=False) + _missing_val: float = field(default=0, init=False, repr=False) + _missing_val_pa_type: pa.DataType = field( + default=pa.float64(), init=False, repr=False + ) + class_names: list = field(default=None, init=False, repr=False) + + task: str = None + + +class Model(BaseProcessor): + """Base class for models.""" + + @wraps(BaseProcessor._process_transform) + def predict( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + """Predict the model.""" + self._method_prefix = "_predict" + self.output_dtype = "int64" + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + 
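# --- Illustrative sketch, not part of the patch above ---
# The `suggest_params` methods in this module return entries that are either a
# fixed value (e.g. "n_jobs": -1) or a ("suggest_*", args, kwargs) triple. The
# tuner that consumes these triples is not shown in this excerpt, so the helper
# below (`sample_from_space` is a made-up name) is only one plausible reading:
# each triple is forwarded to the matching `optuna.trial.Trial` method.
import optuna


def sample_from_space(trial: optuna.trial.Trial, space: dict) -> dict:
    """Expand ("suggest_float", (1e-4, 1e2), {"log": True}) style entries."""
    params = {}
    for name, spec in space.items():
        if isinstance(spec, tuple) and len(spec) == 3:
            method, args, kwargs = spec
            # e.g. trial.suggest_float("C", 1e-4, 1e2, log=True)
            params[name] = getattr(trial, method)(name, *args, **kwargs)
        else:
            # fixed value such as "penalty": "l1" or "n_jobs": -1
            params[name] = spec
    return params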
raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def predict_proba( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._method_prefix = "_predict_proba" + self.output_dtype = "float64" + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + @staticmethod + def suggest_params() -> dict: + """Get hyperparameters for an optuna trial.""" + raise NotImplementedError + + def _process_fit_input(self, X, **kwargs): + target_col = kwargs["fn_kwargs"]["extra_indices"][0] + if target_col and "class" in self.config.task.lower(): + if isinstance(target_col, list): + target_col = target_col[0] + + target_col = DataHandler.get_column_names(X, generate_cols=True)[target_col] + + if self.config.n_classes is None: + self.config.n_classes = DataHandler.nunique(X, target_col) + + if is_datasets_available(): + from datasets import Dataset + + if isinstance(X, Dataset): + feat = X._info.features[target_col] + self.config.class_names = feat.names + + if self.config.n_classes > 2: + self.config.task = "multiclass_classification" + else: + self.config.task = "binary_classification" + return super()._process_fit_input(X, **kwargs) + + def _process_fit_output(self, input, out): + if hasattr(self, "classes_"): + if self.config.class_names is not None and isinstance( + self.classes_[0], int + ): + if sorted(self.classes_.tolist()) == list(range(self.config.n_classes)): + self.config.class_names = [ + self.config.class_names[i] for i in self.classes_ + ] + else: + self.config.class_names = [str(i) for i in self.classes_] + else: + self.config.class_names = [str(i) for i in self.classes_] + return super()._process_fit_output(input, out) + + def _process_transform_input(self, X, **kwargs): + if self._method_prefix == "_predict_proba": + self.config._n_features_out = self.config.n_classes + self.config._feature_names_out = self.config.class_names + self.output_dtype = "float64" + else: + if self.config.extra_names_in_ is not None: + self.config._feature_names_out = self.config.extra_names_in_[0] + self.config._n_features_out = 1 + return super()._process_transform_input(X, **kwargs) + + def _process_transform_batch_input(self, X, *fn_args, **fn_kwargs): + X, args, kwargs = super()._process_transform_batch_input( + X, *fn_args, **fn_kwargs + ) + if DataHandler.supports_named_columns(X): + input_cols = DataHandler.get_column_names(X) + missing_cols = set(self.config.feature_names_in_) - 
set(input_cols) + intersecting_cols = set(input_cols) & set(self.config.feature_names_in_) + + X_dims = DataHandler.get_shape(X) + num_cols = None + if len(X_dims) > 1: + num_cols = X_dims[1] + + num_non_existing = num_cols - len(intersecting_cols) + if num_non_existing: + logger.warning_once( + f"Dataset has {num_non_existing} out of {num_cols} columns that were " + "not in the training data. Dropping these columns." + ) + X = DataHandler.select_columns(X, list(intersecting_cols)) + + if missing_cols: + if self.config._missing_val is None: + self.config._missing_val_str = "`None`" + elif self.config._missing_val == 0: + self.config._missing_val_str = "zeroes" + else: + self.config._missing_val_str = f"{self.config._missing_val}" + logger.warning_once( + f"Dataset is missing {len(missing_cols)} out of " + f"{len(self.config.feature_names_in_)} columns that were in the " + "training data. Adding these columns as " + f"{self.config._missing_val_str}." + ) + num_rows = DataHandler.get_shape(X)[0] + zeros_mat = pa.table( + { + col: pa.array( + [self.config._missing_val] * num_rows, + type=self.config._missing_val_pa_type, + ) + for col in missing_cols + } + ) + X = DataHandler.concat( + [X, DataHandler.to_format(zeros_mat, DataHandler.get_format(X))], + axis=1, + ) + # Reorder columns to match the training data + X = DataHandler.select_columns(X, self.config.feature_names_in_) + + return X, args, kwargs diff --git a/src/biofit/models/random_forest/__init__.py b/src/biofit/models/random_forest/__init__.py new file mode 100644 index 0000000..5aaf26a --- /dev/null +++ b/src/biofit/models/random_forest/__init__.py @@ -0,0 +1,7 @@ +# ruff: noqa +from .random_forest import ( + RandomForestModel, + RandomForestConfig, + RandomForestForClassification, + RandomForestForRegression, +) diff --git a/src/biofit/models/random_forest/random_forest.py b/src/biofit/models/random_forest/random_forest.py new file mode 100644 index 0000000..f98269e --- /dev/null +++ b/src/biofit/models/random_forest/random_forest.py @@ -0,0 +1,428 @@ +from dataclasses import dataclass, field +from typing import List, Union + +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes, sync_backup_config + +from ..models import Model, ModelConfig + + +@dataclass +class RandomForestConfig(ModelConfig): + _fit_input_feature_types: List[Union[None, type]] = field( + default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")], + init=False, + repr=False, + ) + _fit_unused_feature_types: List[Union[None, type]] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _fit_process_desc: str = field( + default="Fitting the random forest model", init=False, repr=False + ) + predict_proba_process_desc: str = field( + default=None, + init=False, + repr=False, + ) + processor_name: str = field(default="random_forest", init=False, repr=False) + + n_estimators: int = 100 + max_depth: int = None + min_samples_split: int = 2 + min_samples_leaf: int = 1 + min_weight_fraction_leaf: float = 0.0 + max_leaf_nodes: int = None + min_impurity_decrease: float = 0.0 + bootstrap: bool = False + oob_score: bool = False + n_jobs: int = None + random_state: int = 42 + verbose: int = 0 + warm_start: bool = False + ccp_alpha: 
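# --- Illustrative sketch, not part of the patch above ---
# Mirrors the column-alignment step in `Model._process_transform_batch_input`:
# columns unseen at fit time are dropped, columns present at fit time but
# missing at predict time are re-added with a constant fill value (zero by
# default), and the result is reordered to match the training schema. Plain
# pyarrow only; the column names here are made up for the example.
import pyarrow as pa

train_cols = ["otu_a", "otu_b", "otu_c"]                      # feature_names_in_
batch = pa.table({"otu_a": [1.0, 2.0], "otu_d": [9.0, 9.0]})  # prediction batch

keep = [c for c in batch.column_names if c in train_cols]     # drops "otu_d"
aligned = batch.select(keep)

missing = [c for c in train_cols if c not in aligned.column_names]
fill = pa.table(
    {c: pa.array([0.0] * aligned.num_rows, type=pa.float64()) for c in missing}
)
for name in fill.column_names:                    # append zero-filled columns
    aligned = aligned.append_column(name, fill.column(name))

aligned = aligned.select(train_cols)              # match training column order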
float = 0.0 + max_samples: int = None + monotonic_cst: dict = None + criterion: str = None # gini for classification, squared_error for regression + max_features: Union[str, int, float] = "sqrt" + + class_weight: Union[str, dict, List[dict]] = None + n_classes: int = None + + use_predict_proba: bool = False + task: str = None + + estimator: Union[RandomForestClassifier, RandomForestRegressor] = field( + default=None, init=False, repr=False + ) + + +class RandomForestModel(Model): + config_class = RandomForestConfig + config: RandomForestConfig + random_forest: Union[RandomForestClassifier, RandomForestRegressor] + + def __init__( + self, + n_estimators: int = 100, + max_depth: int = None, + min_samples_split: int = 2, + min_samples_leaf: int = 1, + min_weight_fraction_leaf: float = 0.0, + max_leaf_nodes: int = None, + min_impurity_decrease: float = 0.0, + bootstrap: bool = True, + oob_score: bool = False, + n_jobs: int = None, + random_state: int = None, + verbose: int = 0, + warm_start: bool = False, + ccp_alpha: float = 0.0, + max_samples: int = None, + monotonic_cst: dict = None, + criterion: str = None, + max_features: Union[str, int, float] = "sqrt", + class_weight: Union[str, dict, List[dict]] = None, + task: str = None, + config: RandomForestConfig = None, + **kwargs, + ): + super().__init__( + config=config, + n_estimators=n_estimators, + max_depth=max_depth, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, + min_weight_fraction_leaf=min_weight_fraction_leaf, + max_leaf_nodes=max_leaf_nodes, + min_impurity_decrease=min_impurity_decrease, + bootstrap=bootstrap, + oob_score=oob_score, + n_jobs=n_jobs, + random_state=random_state, + verbose=verbose, + warm_start=warm_start, + ccp_alpha=ccp_alpha, + max_samples=max_samples, + monotonic_cst=monotonic_cst, + criterion=criterion, + max_features=max_features, + task=task, + class_weight=class_weight, + **kwargs, + ) + + if self.config.task == "classification": + self.random_forest = RandomForestClassifier( + n_estimators=self.config.n_estimators, + criterion=self.config.criterion if self.config.criterion else "gini", + max_depth=self.config.max_depth if self.config.max_depth else None, + min_samples_split=self.config.min_samples_split, + min_samples_leaf=self.config.min_samples_leaf, + min_weight_fraction_leaf=self.config.min_weight_fraction_leaf, + max_features=self.config.max_features, + max_leaf_nodes=self.config.max_leaf_nodes, + min_impurity_decrease=self.config.min_impurity_decrease, + bootstrap=self.config.bootstrap, + oob_score=self.config.oob_score, + n_jobs=self.config.n_jobs, + random_state=self.config.random_state, + verbose=self.config.verbose, + warm_start=self.config.warm_start, + class_weight=self.config.class_weight + if self.config.class_weight and self.config.class_weight != "None" + else None, + ccp_alpha=self.config.ccp_alpha, + max_samples=self.config.max_samples, + monotonic_cst=self.config.monotonic_cst, + ) + else: + self.random_forest = RandomForestRegressor( + n_estimators=self.config.n_estimators, + criterion=self.config.criterion + if self.config.criterion + else "squared_error", + max_depth=self.config.max_depth if self.config.max_depth > 0 else None, + min_samples_split=self.config.min_samples_split, + min_samples_leaf=self.config.min_samples_leaf, + min_weight_fraction_leaf=self.config.min_weight_fraction_leaf, + max_features=self.config.max_features, + max_leaf_nodes=self.config.max_leaf_nodes, + min_impurity_decrease=self.config.min_impurity_decrease, + 
bootstrap=self.config.bootstrap, + oob_score=self.config.oob_score, + n_jobs=self.config.n_jobs, + random_state=self.config.random_state, + verbose=self.config.verbose, + warm_start=self.config.warm_start, + ccp_alpha=self.config.ccp_alpha, + max_samples=self.config.max_samples, + monotonic_cst=self.config.monotonic_cst, + ) + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + if self.config.task == "classification": + self.random_forest = RandomForestClassifier( + n_estimators=self.config.n_estimators, + criterion=self.config.criterion if self.config.criterion else "gini", + max_depth=self.config.max_depth if self.config.max_depth else None, + min_samples_split=self.config.min_samples_split, + min_samples_leaf=self.config.min_samples_leaf, + min_weight_fraction_leaf=self.config.min_weight_fraction_leaf, + max_features=self.config.max_features, + max_leaf_nodes=self.config.max_leaf_nodes, + min_impurity_decrease=self.config.min_impurity_decrease, + bootstrap=self.config.bootstrap, + oob_score=self.config.oob_score, + n_jobs=self.config.n_jobs, + random_state=self.config.random_state, + verbose=self.config.verbose, + warm_start=self.config.warm_start, + class_weight=self.config.class_weight + if self.config.class_weight and self.config.class_weight != "None" + else None, + ccp_alpha=self.config.ccp_alpha, + max_samples=self.config.max_samples, + monotonic_cst=self.config.monotonic_cst, + ) + else: + self.random_forest = RandomForestRegressor( + n_estimators=self.config.n_estimators, + criterion=self.config.criterion + if self.config.criterion + else "squared_error", + max_depth=self.config.max_depth if self.config.max_depth > 0 else None, + min_samples_split=self.config.min_samples_split, + min_samples_leaf=self.config.min_samples_leaf, + min_weight_fraction_leaf=self.config.min_weight_fraction_leaf, + max_features=self.config.max_features, + max_leaf_nodes=self.config.max_leaf_nodes, + min_impurity_decrease=self.config.min_impurity_decrease, + bootstrap=self.config.bootstrap, + oob_score=self.config.oob_score, + n_jobs=self.config.n_jobs, + random_state=self.config.random_state, + verbose=self.config.verbose, + warm_start=self.config.warm_start, + ccp_alpha=self.config.ccp_alpha, + max_samples=self.config.max_samples, + monotonic_cst=self.config.monotonic_cst, + ) + if self.config.use_predict_proba: + self.output_dtype = "float32" + return self + + @property + def feature_importances_(self): + return self.config.estimator.feature_importances_ + + @property + def feature_names_in_(self): + return self.config.feature_names_in_ + + def fit( + self, + X, + y=None, + input_columns: SelectedColumnTypes = None, + target_column: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "RandomForestModel": + self.config._input_columns = self._set_input_columns_and_arity( + input_columns, target_column + ) + return self._process_fit( + X, + y, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + 
fingerprint=fingerprint, + ) + + def predict( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._method_prefix = "_predict" + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def predict_proba( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._method_prefix = "_predict_proba" + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _fit_sklearn(self, X, y): + self.config.estimator = self.random_forest.fit(X, y) + return self + + def _predict_sklearn(self, X): + return self.config.estimator.predict(X) + + def _predict_proba_sklearn(self, X): + return self.config.estimator.predict_proba(X) + + @staticmethod + def suggest_params(data): + params = { + "n_estimators": ("suggest_int", (100, 2000), {}), + "max_depth": ("suggest_int", (0, 11), {}), + "min_samples_split": ("suggest_int", (2, 20), {}), + "min_samples_leaf": ("suggest_int", (1, 20), {}), + # "bootstrap": ("suggest_categorical", ([True, False],), {}), + # "max_features": ("suggest_categorical", (["sqrt", "log2"],), {}), + "n_jobs": -1, + } + + return params + + @staticmethod + def suggest_first_trial(): + return { + "n_estimators": 100, + "max_depth": 0, + "min_samples_split": 2, + "min_samples_leaf": 1, + "n_jobs": -1, + } + + +class RandomForestForClassification(RandomForestModel): + """Random Forest model for classification tasks""" + + @property + def classes_(self): + return self.config.estimator.classes_ + + @sync_backup_config + def set_params(self, **kwargs): + kwargs.pop("task", None) + self.config.task = "classification" + return super().set_params(**kwargs) + + @staticmethod + def suggest_params(data): + params = RandomForestModel.suggest_params(data) + params["criterion"] = ("suggest_categorical", (["gini", "entropy"],), {}) + params["class_weight"] = ( + "suggest_categorical", + (["balanced", "balanced_subsample", "None"],), + {}, + ) + return 
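# --- Illustrative usage sketch, not part of the patch above ---
# A minimal end-to-end call of the random forest wrapper defined above. The
# column names and values are made up, and passing a pandas DataFrame assumes
# that format is accepted by the DataHandler-based processing layer.
import pandas as pd
from biofit.models.random_forest import RandomForestForClassification

frame = pd.DataFrame(
    {
        "otu_a": [3, 0, 5, 1, 0, 2],
        "otu_b": [0, 1, 0, 4, 2, 0],
        "label": [0, 1, 0, 1, 1, 0],
    }
)

model = RandomForestForClassification(n_estimators=50, random_state=42)
model.fit(frame, input_columns=["otu_a", "otu_b"], target_column="label")
preds = model.predict(frame, input_columns=["otu_a", "otu_b"])
proba = model.predict_proba(frame, input_columns=["otu_a", "otu_b"])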
params + + @staticmethod + def suggest_first_trial(): + return { + **RandomForestModel.suggest_first_trial(), + "criterion": "gini", + "class_weight": "balanced", + } + + +class RandomForestForRegression(RandomForestModel): + @sync_backup_config + def set_params(self, **kwargs): + kwargs.pop("task", None) + self.config.task = "regression" + return super().set_params(**kwargs) + + @staticmethod + def suggest_params(data): + params = RandomForestModel.suggest_params(data) + params["criterion"] = ( + "suggest_categorical", + (["squared_error", "absolute_error", "poisson"],), + {}, + ) + return params diff --git a/src/biofit/preprocessing/__init__.py b/src/biofit/preprocessing/__init__.py new file mode 100644 index 0000000..fe2129d --- /dev/null +++ b/src/biofit/preprocessing/__init__.py @@ -0,0 +1,10 @@ +# ruff: noqa +# from .auto import * +from .feature_selection import * +from .filtering import * +from .scaling import * +from .feature_extraction import * +from .imputation import * +from .encoding import * +from .resampling import * +from .transformation import * diff --git a/src/biofit/preprocessing/encoding/__init__.py b/src/biofit/preprocessing/encoding/__init__.py new file mode 100644 index 0000000..de1415a --- /dev/null +++ b/src/biofit/preprocessing/encoding/__init__.py @@ -0,0 +1,3 @@ +# ruff: noqa +from .label_binarizing import * +from .label_encoding import * diff --git a/src/biofit/preprocessing/encoding/encoding.py b/src/biofit/preprocessing/encoding/encoding.py new file mode 100644 index 0000000..9f5e44e --- /dev/null +++ b/src/biofit/preprocessing/encoding/encoding.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass, field + +from biofit.processing import BaseProcessor, ProcessorConfig + + +@dataclass +class EncoderConfig(ProcessorConfig): + """Base class for feature extraction processor configurations.""" + + processor_type: str = field(default="encoding", init=False, repr=False) + + +class Encoder(BaseProcessor): + """Base class for feature extraction processors.""" diff --git a/src/biofit/preprocessing/encoding/label_binarizing/__init__.py b/src/biofit/preprocessing/encoding/label_binarizing/__init__.py new file mode 100644 index 0000000..c9ebf6a --- /dev/null +++ b/src/biofit/preprocessing/encoding/label_binarizing/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa +from .label_binarizing import LabelBinarizer, LabelBinarizerConfig diff --git a/src/biofit/preprocessing/encoding/label_binarizing/label_binarizing.py b/src/biofit/preprocessing/encoding/label_binarizing/label_binarizing.py new file mode 100644 index 0000000..362c926 --- /dev/null +++ b/src/biofit/preprocessing/encoding/label_binarizing/label_binarizing.py @@ -0,0 +1,405 @@ +from dataclasses import dataclass, field +from typing import List, Type, Union + +from biocore.utils.import_util import is_biosets_available +import numpy as np +import pyarrow as pa +from biocore import DataHandler + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.utils import logging + +from ..encoding import Encoder, EncoderConfig + +logger = logging.get_logger(__name__) + + +@dataclass +class LabelBinarizerConfig(EncoderConfig): + """Configuration class for label binarizer.""" + + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("ClassLabel")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("ClassLabel")], init=False, repr=False + ) + + 
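# --- Illustrative sketch, not part of the patch above ---
# One plausible use of the `suggest_first_trial` dictionaries defined on the
# model classes above: seed an Optuna study so the first evaluated
# configuration is the model's known-good default before the sampler takes
# over. The objective below is a stand-in; a real search would pass the
# sampled values to the model via `set_params()`.
import optuna
from biofit.models.random_forest import RandomForestForClassification


def objective(trial: optuna.trial.Trial) -> float:
    trial.suggest_int("n_estimators", 100, 2000)
    trial.suggest_int("max_depth", 0, 11)
    return 0.0  # placeholder score


study = optuna.create_study(direction="maximize")
first = RandomForestForClassification.suggest_first_trial()
study.enqueue_trial({k: first[k] for k in ("n_estimators", "max_depth")})
study.optimize(objective, n_trials=1)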
processor_name: str = field(default="label_binarizing", init=False, repr=False) + _transform_process_desc: str = field(default=None, init=False, repr=False) + positive_labels: list = field(default_factory=list) + negative_labels: list = field(default_factory=list) + as_one_hot: bool = False + names: list = field(default_factory=list, init=False, repr=False) + + # auto attributes + label_mapping: dict = field(default_factory=dict, init=False, repr=False) + + def __post_init__(self): + if self.as_one_hot: + self._transform_process_desc = "One-hot encoding labels" + self._fit_process_desc = "Determining unique labels" + if self.names: + self._transform_process_desc = ( + f"One-hot encoding labels with {self.names}" + ) + self.label_mapping = { + str(label): i for i, label in enumerate(self.names) + } + else: + if not isinstance(self.positive_labels, list): + self.positive_labels = [self.positive_labels] + if not isinstance(self.negative_labels, list): + self.negative_labels = [self.negative_labels] + if len(self.positive_labels) < 1 and len(self.negative_labels) < 1: + raise ValueError( + "At least one of positive_labels or negative_labels must be provided" + ) + + self._transform_process_desc = f"Binarizing labels with positive_labels={self.positive_labels} and negative_labels={self.negative_labels}" + + if self.positive_labels: + self._transform_process_desc = ( + f"Binarizing labels with positive_labels={self.positive_labels}" + ) + + if self.negative_labels: + self._transform_process_desc += ( + f" and negative_labels={self.negative_labels}" + ) + else: + self._transform_process_desc = ( + f"Binarizing labels with negative_labels={self.negative_labels}" + ) + + if not isinstance(self.positive_labels, list): + self.positive_labels = [self.positive_labels] + + if not isinstance(self.negative_labels, list): + self.negative_labels = [self.negative_labels] + + self.label_mapping = {str(label): 1 for label in self.positive_labels} + self.label_mapping.update({str(label): 0 for label in self.negative_labels}) + + +class LabelBinarizer(Encoder): + config_class = LabelBinarizerConfig + config: LabelBinarizerConfig + + def __init__( + self, + positive_labels: list = [], + negative_labels: list = [], + names: list = [], + as_one_hot: bool = False, + config: LabelBinarizerConfig = None, + **kwargs, + ): + """ + Args: + positive_labels (list, *optional*): + The labels to be considered as positive. Default is `[]`. + negative_labels (list, *optional*): + The labels to be considered as negative. Default is `[]`. + as_one_hot (bool, *optional*): + Whether to encode the labels as one-hot vectors. Default is `False`. + names (list, *optional*): + The names of the classes. This is required when `as_one_hot` is `True`. + config (LabelBinarizerConfig, *optional*): + **kwargs: + Arguments that are passed to ProcessorConfig. 
+ """ + super().__init__( + config=config, + positive_labels=positive_labels, + negative_labels=negative_labels, + as_one_hot=as_one_hot, + names=names, + **kwargs, + ) + self.output_feature_type = get_feature("BinClassLabel") + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + self.config.__post_init__() + return self + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "LabelBinarizer": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _get_features_out(self, X, **kwargs): + kwargs.pop("one_to_one_features", None) + features = super()._get_features_out(X, 
one_to_one_features=True, **kwargs) + + input_names = DataHandler.get_column_names(X, generate_cols=True) + input_names = [input_names[i] for i in self._selected_indices] + if self.config.as_one_hot and self.config.names: + # incase that the user provided the pre-encoded labels of get_feature("ClassLabel") + for name in input_names: + label_feature = features.pop(name, None) + if isinstance(label_feature, get_feature("ClassLabel")): + label_names = label_feature.names + if self.config.names and isinstance(self.config.names[0], str): + self.config.label_mapping = { + str(label_names.index(label)): i + for i, label in enumerate(self.config.names) + } + features.update( + { + name: get_feature("ClassLabel")(names=["negative", "positive"]) + for name in self.config.names + } + ) + return features + + if is_biosets_available(): + for name in input_names: + # incase that the user provided the pre-encoded labels of get_feature("ClassLabel") + label_feature = features.get(name) + if isinstance(label_feature, get_feature("ClassLabel")): + label_names = label_feature.names + if self.config.positive_labels and isinstance( + self.config.positive_labels[0], str + ): + self.config.label_mapping = { + str(label_names.index(label)): 1 + for label in self.config.positive_labels + } + if self.config.negative_labels and isinstance( + self.config.negative_labels[0], str + ): + self.config.label_mapping.update( + { + str(label_names.index(label)): 0 + for label in self.config.negative_labels + } + ) + label_feature = get_feature("BinClassLabel")( + names=self.config.names or ["negative", "positive"], + positive_labels=self.config.positive_labels, + negative_labels=self.config.negative_labels, + ) + features[name] = label_feature + + return features + + def _process_fit_input(self, input, **kwargs): + if not self.config.as_one_hot or self.config.names: + kwargs["fn_kwargs"]["fn"] = None + if self.config.names: + self.config.label_mapping = { + str(label): i for i, label in enumerate(self.config.names) + } + return super()._process_fit_input(input, **kwargs) + + def _check_data(self, X): + X_dims = DataHandler.get_shape(X) + if len(X_dims) > 1: + if X_dims[1] == 1: + out = DataHandler.select_column(X, 0) + else: + raise ValueError( + f"Expected input to have 1 column, got {X_dims[1]} columns" + ) + else: + out = X + return DataHandler.to_format(out, "list") + + def _fit_array(self, X): + out = self._check_data(X) + self.config.names = DataHandler.to_format(DataHandler.unique(out), "list") + self.config._transform_process_desc = ( + f"Encoding labels with: {self.config.label_mapping}" + ) + return self + + def _partial_fit_array(self, X): + out = self._check_data(X) + labs = DataHandler.to_format(DataHandler.unique(out), "list") + self.config.names = np.unique(self.config.names + labs).tolist() + return self + + def _pool_fit(self, fitted_processors): + names = [] + for processor in fitted_processors: + if isinstance(processor, LabelBinarizer): + names += processor.config.names + self.config.names = np.unique(names).tolist() + self.config.label_mapping = { + str(label): i for i, label in enumerate(self.config.names) + } + self.config._transform_process_desc = ( + f"Encoding labels with: {self.config.label_mapping}" + ) + return self + + def _process_fit_output(self, input, out): + if self.config.as_one_hot: + self.config._n_features_out = len(self.config.names) + self.output_dtype = "int64" + else: + self.config._n_features_out = None + self.config.label_mapping = { + str(k): v for k, v in 
self.config.label_mapping.items() + } + return super()._process_fit_output(input, out) + + def _one_hot_transform(self, X): + col = self._check_data(X) + + labs = np.zeros((len(col), len(self.config.names) + 1), dtype=np.int64) + + inds = [self.config.label_mapping.get(str(val), -1) for val in col] + labs[np.arange(len(col)), inds] = 1 + return labs[:, :-1] + + def _binarize_transform(self, X: Union[pa.Table, pa.Array]): + col = self._check_data(X) + + if self.config.positive_labels and self.config.negative_labels: + labs = [-1] * len(col) + elif self.config.positive_labels: + labs = [0] * len(col) + else: + labs = [1] * len(col) + + if self.config.positive_labels and self.config.negative_labels: + labs = [ + self.config.label_mapping.get(str(val), -1) + if labs[i] == -1 + else labs[i] + for i, val in enumerate(col) + ] + elif self.config.positive_labels: + labs = [ + self.config.label_mapping.get(str(val), 0) if labs[i] == 0 else labs[i] + for i, val in enumerate(col) + ] + else: + labs = [ + self.config.label_mapping.get(str(val), 1) if labs[i] == 1 else labs[i] + for i, val in enumerate(col) + ] + + return pa.array(labs) + + def _transform_array(self, X): + if not self.config.as_one_hot: + return self._binarize_transform(X) + else: + return self._one_hot_transform(X) diff --git a/src/biofit/preprocessing/encoding/label_encoding/__init__.py b/src/biofit/preprocessing/encoding/label_encoding/__init__.py new file mode 100644 index 0000000..cbb513d --- /dev/null +++ b/src/biofit/preprocessing/encoding/label_encoding/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa +from .label_encoding import LabelEncoder, LabelEncoderConfig diff --git a/src/biofit/preprocessing/encoding/label_encoding/label_encoding.py b/src/biofit/preprocessing/encoding/label_encoding/label_encoding.py new file mode 100644 index 0000000..dce470a --- /dev/null +++ b/src/biofit/preprocessing/encoding/label_encoding/label_encoding.py @@ -0,0 +1,243 @@ +from dataclasses import dataclass, field +from typing import List, Type + +import numpy as np +from biocore import DataHandler + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.utils import logging + +from ..encoding import Encoder, EncoderConfig + +logger = logging.get_logger(__name__) + + +@dataclass +class LabelEncoderConfig(EncoderConfig): + """Configuration class for label encoding.""" + + _fit_process_desc: str = field( + default="Determining unique labels", init=False, repr=False + ) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("ClassLabel")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("ClassLabel")], init=False, repr=False + ) + + processor_name: str = field(default="label_encoding", init=False, repr=False) + _transform_process_desc: str = field(default=None, init=False, repr=False) + + names: list = field(default_factory=list, init=False, repr=False) + + # auto attributes + label_mapping: dict = field(default_factory=dict, init=False, repr=False) + + def __post_init__(self): + if self.names: + self._transform_process_desc = f"Encoding labels with {self.names}" + self.label_mapping = {label: i for i, label in enumerate(self.names)} + + +class LabelEncoder(Encoder): + config_class = LabelEncoderConfig + config: LabelEncoderConfig + output_feature_type = get_feature("ClassLabel") + + @sync_backup_config + def set_params(self, **kwargs): + self.config = 
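# --- Illustrative usage sketch, not part of the patch above ---
# How the binarizer above maps labels: values listed in `positive_labels`
# become 1, values in `negative_labels` become 0, with the fallbacks handled in
# `_binarize_transform`. Column name and data are made up; a pandas DataFrame
# is assumed to be an accepted input format.
import pandas as pd
from biofit.preprocessing import LabelBinarizer

frame = pd.DataFrame({"diagnosis": ["healthy", "ibd", "healthy", "ibd", "ibd"]})

binarizer = LabelBinarizer(positive_labels=["ibd"], negative_labels=["healthy"])
encoded = binarizer.fit_transform(frame, input_columns="diagnosis")
# Expected mapping, per LabelBinarizerConfig.__post_init__: "ibd" -> 1,
# "healthy" -> 0.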
self.config.replace_defaults(**kwargs) + self.config.__post_init__() + return self + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "LabelEncoder": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _get_features_out(self, X, **kwargs): + features = super()._get_features_out(X, **kwargs) + if features: + input_names = self.config.feature_names_in_ or self.config.feature_idx_in_ + for name in input_names: + label_feature = features.get(name) + if isinstance(label_feature, get_feature("ClassLabel")): + label_names = 
label_feature.names + if self.config.names and isinstance(self.config.names[0], str): + self.config.label_mapping = { + label_names.index(label): i + for i, label in enumerate(self.config.names) + } + + label_feature = get_feature("ClassLabel")(names=self.config.names) + features[name] = label_feature + + return features + + def _check_data(self, X): + X_dims = DataHandler.get_shape(X) + if len(X_dims) > 1: + if X_dims[1] == 1: + out = DataHandler.select_column(X, 0) + else: + raise ValueError( + f"Expected input to have 1 column, got {X_dims[1]} columns" + ) + else: + out = X + return out + + def _process_fit_input(self, input, **kwargs): + if self.config.names: + kwargs["fn_kwargs"]["fn"] = None + return super()._process_fit_input(input, **kwargs) + + def _fit_array(self, X): + out = self._check_data(X) + self.config.names = DataHandler.to_format(DataHandler.unique(out), "list") + self.config._transform_process_desc = ( + f"Encoding labels with: {self.config.label_mapping}" + ) + return self + + def _partial_fit_array(self, X): + out = self._check_data(X) + labs = DataHandler.to_format(DataHandler.unique(out), "list") + self.config.names = np.unique(self.config.names + labs).tolist() + return self + + def _pool_fit(self, fitted_processors): + names = [] + for processor in fitted_processors: + if isinstance(processor, LabelEncoder): + names += processor.config.names + self.config.names = np.unique(names).tolist() + self.config.label_mapping = { + label: i for i, label in enumerate(self.config.names) + } + self.config._transform_process_desc = ( + f"Encoding labels with: {self.config.label_mapping}" + ) + return self + + def _transform_array(self, X): + labs = self._check_data(X) + labs = DataHandler.to_format(labs, "list") + return [self.config.label_mapping.get(lab, -1) for lab in labs] diff --git a/src/biofit/preprocessing/feature_extraction/__init__.py b/src/biofit/preprocessing/feature_extraction/__init__.py new file mode 100644 index 0000000..54f0582 --- /dev/null +++ b/src/biofit/preprocessing/feature_extraction/__init__.py @@ -0,0 +1,3 @@ +# ruff: noqa +from .pcoa import * +from .pca import * diff --git a/src/biofit/preprocessing/feature_extraction/feature_extraction.py b/src/biofit/preprocessing/feature_extraction/feature_extraction.py new file mode 100644 index 0000000..556aa27 --- /dev/null +++ b/src/biofit/preprocessing/feature_extraction/feature_extraction.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass, field + +from biofit.integration.biosets import get_feature +from biofit.processing import BaseProcessor, ProcessorConfig, SelectedFeatureTypes + + +@dataclass +class FeatureExtractorConfig(ProcessorConfig): + """Base class for feature extraction processor configurations.""" + + processor_type: str = field(default="feature_extraction", init=False, repr=False) + _fit_unused_feature_types: SelectedFeatureTypes = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: SelectedFeatureTypes = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + + +class FeatureExtractor(BaseProcessor): + """Base class for feature extraction processors.""" diff --git a/src/biofit/preprocessing/feature_extraction/pca/__init__.py b/src/biofit/preprocessing/feature_extraction/pca/__init__.py new file mode 100644 index 0000000..a3c48f4 --- /dev/null +++ b/src/biofit/preprocessing/feature_extraction/pca/__init__.py @@ -0,0 +1,5 @@ +# ruff: noqa +from .pca 
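# --- Illustrative sketch, not part of the patch above ---
# Mirrors the core of LabelEncoder above: the unique labels seen at fit time
# define `label_mapping` (np.unique gives a sorted order, as in `_pool_fit`),
# and labels unseen at transform time fall back to -1 (see `_transform_array`).
# Pure Python/numpy; no biofit imports needed.
import numpy as np

train_labels = ["gut", "skin", "gut", "oral"]
names = np.unique(train_labels).tolist()            # ['gut', 'oral', 'skin']
label_mapping = {label: i for i, label in enumerate(names)}

test_labels = ["skin", "soil"]                      # "soil" was never seen
encoded = [label_mapping.get(lab, -1) for lab in test_labels]  # [2, -1]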
import ( + PCAFeatureExtractor, + PCAFeatureExtractorConfig, +) diff --git a/src/biofit/preprocessing/feature_extraction/pca/pca.py b/src/biofit/preprocessing/feature_extraction/pca/pca.py new file mode 100644 index 0000000..b2a9b95 --- /dev/null +++ b/src/biofit/preprocessing/feature_extraction/pca/pca.py @@ -0,0 +1,235 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Union + +import numpy as np +import pandas as pd +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.utils import logging +from sklearn.decomposition import PCA + +from ..feature_extraction import FeatureExtractor, FeatureExtractorConfig + +if TYPE_CHECKING: + import polars as pl + +logger = logging.get_logger(__name__) + + +@dataclass +class PCAFeatureExtractorConfig(FeatureExtractorConfig): + _fit_process_desc: str = field( + default="Determining the principal components", init=False, repr=False + ) + _transform_process_desc: str = field( + default="Transforming the input data to principal components", + init=False, + repr=False, + ) + processor_name: str = field(default="pca", init=False, repr=False) + + input_columns: List[str] = None + n_components: int = None + copy: bool = True + whiten: bool = False + svd_solver: str = "auto" + tol: float = 0.0 + iterated_power: str = "auto" + n_oversamples: int = 10 + power_iteration_normalizer: str = "auto" + random_state: int = None + + +class PCAFeatureExtractor(FeatureExtractor): + output_dtype = "float64" + config_class = PCAFeatureExtractorConfig + config: PCAFeatureExtractorConfig + + def __init__( + self, + input_columns: List[str] = None, + n_components: int = None, + copy: bool = True, + whiten: bool = False, + svd_solver: str = "auto", + tol: float = 0.0, + iterated_power: str = "auto", + n_oversamples: int = 10, + power_iteration_normalizer: str = "auto", + random_state: int = None, + config: PCAFeatureExtractorConfig = None, + **kwargs, + ): + super().__init__( + config=config, + input_columns=input_columns, + n_components=n_components, + copy=copy, + whiten=whiten, + svd_solver=svd_solver, + tol=tol, + iterated_power=iterated_power, + n_oversamples=n_oversamples, + power_iteration_normalizer=power_iteration_normalizer, + random_state=random_state, + **kwargs, + ) + self.pca = PCA( + n_components=self.config.n_components, + copy=self.config.copy, + whiten=self.config.whiten, + svd_solver=self.config.svd_solver, + tol=self.config.tol, + iterated_power=self.config.iterated_power, + n_oversamples=self.config.n_oversamples, + power_iteration_normalizer=self.config.power_iteration_normalizer, + random_state=self.config.random_state, + ) + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + self.pca = PCA( + n_components=self.config.n_components, + copy=self.config.copy, + whiten=self.config.whiten, + svd_solver=self.config.svd_solver, + tol=self.config.tol, + iterated_power=self.config.iterated_power, + n_oversamples=self.config.n_oversamples, + power_iteration_normalizer=self.config.power_iteration_normalizer, + random_state=self.config.random_state, + ) + return self + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: 
str = None, + ) -> "PCAFeatureExtractor": + self.config._input_columns = self._set_input_columns_and_arity( + input_columns or self.config.input_columns + ) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity( + input_columns or self.config.input_columns + ) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _fit_sklearn(self, X: Union[pd.DataFrame, "pl.DataFrame", np.ndarray]): + self.config.estimator = self.pca.fit(X) + return self + + def _process_fit_output(self, input, out): + self.config._n_features_out = self.config.estimator.n_components_ + return super()._process_fit_output(input, out) + + def _transform_sklearn(self, X: Union[pd.DataFrame, "pl.DataFrame", np.ndarray]): + return self.config.estimator.transform(X) diff --git a/src/biofit/preprocessing/feature_extraction/pca/plot_pca.py b/src/biofit/preprocessing/feature_extraction/pca/plot_pca.py new file mode 100644 index 0000000..504938e --- /dev/null +++ b/src/biofit/preprocessing/feature_extraction/pca/plot_pca.py @@ -0,0 +1,82 @@ +from dataclasses import dataclass, field +from 
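# --- Illustrative usage sketch, not part of the patch above ---
# Minimal use of the PCA wrapper defined above. Column names and values are
# made up; passing a pandas DataFrame assumes that format is accepted by the
# processing layer.
import pandas as pd
from biofit.preprocessing import PCAFeatureExtractor

frame = pd.DataFrame(
    {
        "otu_a": [1.0, 2.0, 3.0, 4.0],
        "otu_b": [2.0, 1.0, 4.0, 3.0],
        "otu_c": [0.5, 0.7, 0.2, 0.9],
    }
)

pca = PCAFeatureExtractor(n_components=2)
components = pca.fit_transform(frame, input_columns=["otu_a", "otu_b", "otu_c"])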
typing import Optional + +from biofit.utils.types import Unset +from biofit.visualization.plotting_utils import plot_dimension_reduction + +from ..plot_feature_extraction import ( + FeatureExtractorPlotter, + FeatureExtractorPlotterConfig, +) + + +@dataclass +class PCAFeatureExtractorPlotterConfig(FeatureExtractorPlotterConfig): + processor_name: str = field(default="pca", init=False, repr=False) + title: str = "PCA Plot" + n_components: int = 3 + label_column: str = None + group_column: str = None + + +class PCAFeatureExtractorPlotter(FeatureExtractorPlotter): + config_class = PCAFeatureExtractorPlotterConfig + config: PCAFeatureExtractorPlotterConfig + + def __init__( + self, + n_components: int = 3, + label_column: str = None, + group_column: str = None, + config: Optional[PCAFeatureExtractorPlotterConfig] = None, + **kwargs, + ): + super().__init__( + config=config, + n_components=n_components, + label_column=label_column, + group_column=group_column, + **kwargs, + ) + + def plot( + self, + X, + labels=None, + group=None, + input_columns=None, + label_column=None, + group_column=None, + precomputed: bool = Unset("False"), + pca_kwargs: dict = Unset("{}"), + title: str = Unset("PCA Plot"), + n_components: int = Unset(3), + path=None, + **kwargs, + ): + """Plot the PCA plot. + + Args: + X: The input data. + labels: The labels for the data. + group: The group for the data. + input_columns: The input columns. + label_column: The label column. + group_column: The group column. + precomputed: Whether the input data is precomputed. + pca_kwargs: The additional kwargs for computing PCA. Only used if precomputed is False. + **kwargs: The additional kwargs. + + """ + return plot_dimension_reduction( + X, + labels=labels, + group=group, + input_columns=input_columns, + label_column=label_column, + group_column=group_column, + method="pca" if not precomputed else None, + method_kwargs=pca_kwargs, + output_dir=path, + **kwargs, + ) diff --git a/src/biofit/preprocessing/feature_extraction/pcoa/__init__.py b/src/biofit/preprocessing/feature_extraction/pcoa/__init__.py new file mode 100644 index 0000000..d1cbef5 --- /dev/null +++ b/src/biofit/preprocessing/feature_extraction/pcoa/__init__.py @@ -0,0 +1,10 @@ +# ruff: noqa +from .pcoa import ( + PCoAFeatureExtractor, + PCoAFeatureExtractorConfig, + PCoAFeatureExtractorConfigForOTU, +) +from .plot_pcoa import ( + PCoAFeatureExtractorPlotter, + PCoAFeatureExtractorPlotterConfig, +) diff --git a/src/biofit/preprocessing/feature_extraction/pcoa/pcoa.py b/src/biofit/preprocessing/feature_extraction/pcoa/pcoa.py new file mode 100644 index 0000000..aedd511 --- /dev/null +++ b/src/biofit/preprocessing/feature_extraction/pcoa/pcoa.py @@ -0,0 +1,413 @@ +# Portions of this module are derived from the Pyckmeans project available at: +# https://github.com/TankredO/pyckmeans + +# Pyckmeans is licensed under the MIT License. 
The following is a copy of the original license under which the Pyckmeans software is distributed: + +# MIT License + +# Copyright (c) 2021 Tankred Ott + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +"""Principal Coordinate Analysis (PCoA) feature extraction module.""" + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, Type + +import numpy as np + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.stat import DistanceStat, DistanceStatConfig +from biofit.utils import logging + +from ..feature_extraction import FeatureExtractor, FeatureExtractorConfig + +if TYPE_CHECKING: + pass + +logger = logging.get_logger(__name__) + + +def _center_mat(dmat: np.ndarray) -> np.ndarray: + """_center_mat + + Center n*n matrix. + + Parameters + ---------- + dmat : np.ndarray + n*n matrix. + + Returns + ------- + np.ndarray + Centered matrix. + """ + + n = dmat.shape[0] + mat = np.full((n, n), -1 / n) + mat[np.diag_indices(n)] += 1 + + return mat.dot(dmat).dot(mat) + + +class InvalidCorrectionTypeError(Exception): + """InvalidCorrectionTypeError""" + + +class NegativeEigenvaluesCorrectionError(Exception): + """FailedCorrectionError + + Error, signalling that the correction of negative eigenvalues failed. + """ + + +class NegativeEigenvaluesWarning(Warning): + """NegativeEigenvaluesWarning + + Warning, signalling that negative eigenvalues were encountered. + """ + + +def pcoa( + x: np.ndarray, + correction: Optional[str] = None, + eps: float = 1e-8, +): + """pcoa + + Principle Coordinate Analysis. + + Parameters + ---------- + dist : Union[np.ndarray, pyckmeans.distance.DistanceMatrix] + n*n distance matrix either as np ndarray or as pyckmeans DistanceMatrix. + correction: Optional[str] + Correction for negative eigenvalues, by default None. + Available corrections are: + - None: negative eigenvalues are set to 0 + - lingoes: Lingoes correction + - cailliez: Cailliet correction + eps : float, optional + Eigenvalues smaller than eps will be dropped. By default 0.0001 + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + Eigenvalues and eigenvectors. + + Raises + ------ + InvalidCorrectionTypeError + Raised if an unknown correction type is passed. + NegativeEigenvaluesCorrectionError + Raised if correction parameter is set and correction of negative + eigenvalues is not successful. 
+ """ + + if correction is not None and correction not in ["lingoes", "cailliez"]: + msg = ( + f'Unknown correction type "{correction}". ' + + 'Available correction types are: "lingoes", "cailliez"' + ) + raise InvalidCorrectionTypeError(msg) + + # center matrix + dmat_centered = _center_mat((x * x) / -2) + + # eigen decomposition + eigvals, eigvecs = np.linalg.eigh(dmat_centered, "U") + + # order descending + ord_idcs = np.argsort(eigvals)[::-1] + eigvals = eigvals[ord_idcs] + eigvecs = eigvecs[:, ord_idcs] + + # get min eigenvalue + min_eigval = np.min(eigvals) + + # set small eigenvalues to 0 + zero_eigval_idcs = np.nonzero(np.abs(eigvals) < eps)[0] + eigvals[zero_eigval_idcs] = 0 + + # no negative eigenvalues + if min_eigval > -eps: + fze_idx = len(np.nonzero(eigvals > eps)[0]) # index of first zero in eigvals + vectors = eigvecs[:, :fze_idx] * np.sqrt(eigvals[:fze_idx]) + + return eigvals, vectors + + # negative eigenvalues + else: + fze_idx = len(np.nonzero(eigvals > eps)[0]) # index of first zero in eigvals + vectors = eigvecs[:, :fze_idx] * np.sqrt(eigvals[:fze_idx]) + + # negative eigenvalues, no correction + if not correction: + logger.warn( + "Negative eigenvalues encountered but no correction applied. " + "Negative eigenvalues will be treated as 0." + ) + + return eigvals, vectors + + # negative eigenvalues, correction + + # -- correct distance matrix + # lingoes correction + if correction == "lingoes": + corr_1 = -min_eigval + + # corrected distance matrix + x_ncorr = -0.5 * ((x * x) + 2 * corr_1) + elif correction == "cailliez": + dmat_centered_2 = _center_mat(-0.5 * x) + + # prepare matrix for correction + upper = np.c_[np.zeros((x.shape[0], x.shape[0])), 2 * dmat_centered] + lower = np.c_[np.diag(np.full(x.shape[0], -1)), -4 * dmat_centered_2] + sp_mat = np.r_[upper, lower] + + corr_2 = np.max(np.real(np.linalg.eigvals(sp_mat))) + + # corrected distance matrix + x_ncorr = -0.5 * (x + corr_2) ** 2 + + # -- apply PCoA to corrected distance matrix + x_ncorr[np.diag_indices(x_ncorr.shape[0])] = 0 + x_ncorr = _center_mat(x_ncorr) + + eigvals_ncorr, eigvecs_ncorr = np.linalg.eigh(x_ncorr, "U") + + # order descending + ord_idcs_ncorr = np.argsort(eigvals_ncorr)[::-1] + eigvals_ncorr = eigvals_ncorr[ord_idcs_ncorr] + eigvecs_ncorr = eigvecs_ncorr[:, ord_idcs_ncorr] + + # get min eigenvalue + min_eigval_ncorr = np.min(eigvals_ncorr) + + # set small eigenvalues to 0 + zero_eigval_idcs_ncorr = np.nonzero(np.abs(eigvals_ncorr) < eps)[0] + eigvals_ncorr[zero_eigval_idcs_ncorr] = 0 + + if min_eigval_ncorr < -eps: + msg = ( + "Correction failed. There are still negative eigenvalues after applying " + + f"{correction.capitalize()} correction." 
+ ) + raise NegativeEigenvaluesCorrectionError(msg) + + fze_idx_ncorr = len( + np.nonzero(eigvals_ncorr > eps)[0] + ) # index of first zero in eigvals + vectors_ncorr = eigvecs_ncorr[:, :fze_idx_ncorr] * np.sqrt( + eigvals_ncorr[:fze_idx_ncorr] + ) + + return eigvals_ncorr, vectors_ncorr + + +@dataclass +class PCoAFeatureExtractorConfig(FeatureExtractorConfig): + processor_name: str = field(default="pcoa", init=False, repr=False) + output_template_name: str = field(default="Dim{i+1}", init=False, repr=False) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _fit_process_desc: str = field( + default="", + init=False, + repr=False, + ) + _transform_process_desc: str = field( + default="Applying Principal Coordinate Analysis (PCoA) to the input data", + init=False, + repr=False, + ) + n_components: int = None + correction: str = None + eps: float = 1e-3 + + metric: str = "braycurtis" + p: float = 2 + w: float = None + V: float = None + VI: float = None + squareform: bool = True + + # fitted attributes + vectors: np.ndarray = None + eigvals: np.ndarray = None + + def __post_init__(self): + self._fit_process_desc = f"Calculating PCoA using {self.metric} distance" + + +@dataclass +class PCoAFeatureExtractorConfigForOTU(PCoAFeatureExtractorConfig): + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + dataset_name: str = field(default="otu", init=False, repr=False) + correction: str = "cailliez" + + +class PCoAFeatureExtractor(FeatureExtractor): + output_dtype = "float64" + + config_class = PCoAFeatureExtractorConfig + config: PCoAFeatureExtractorConfig + + def __init__( + self, + n_components: int = None, + correction: str = None, + eps: float = 1e-3, + metric: str = "braycurtis", + p: float = 2, + w: float = None, + V: float = None, + VI: float = None, + squareform: bool = True, + config: Optional[PCoAFeatureExtractorConfig] = None, + **kwargs, + ): + super().__init__( + config=config, + n_components=n_components, + correction=correction, + eps=eps, + metric=metric, + p=p, + w=w, + V=V, + VI=VI, + squareform=squareform, + **kwargs, + ) + distance_config = DistanceStatConfig.from_config(self.config) + self.distance = DistanceStat(distance_config) + self.config._n_features_out = self.config.n_components + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + distance_config = DistanceStatConfig.from_config(self.config) + self.distance = DistanceStat(distance_config) + self.config._n_features_out = self.config.n_components + return self + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "PCoAFeatureExtractor": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + 
raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _fit_numpy(self, X: "np.ndarray"): + self.config.eigvals, self.config.vectors = pcoa( + self.distance._transform_numpy(X), + correction=self.config.correction, + eps=self.config.eps, + ) + + return self + + def _process_fit_output(self, input, output): + if self.config.n_components is None: + self.config.n_components = self.config.vectors.shape[1] + self.config._n_features_out = self.config.n_components + return super()._process_fit_output(input, output) + + def _transform_numpy(self, X: "np.ndarray"): + return self.config.vectors[:, : self.config.n_components] diff --git a/src/biofit/preprocessing/feature_extraction/pcoa/plot_pcoa.py b/src/biofit/preprocessing/feature_extraction/pcoa/plot_pcoa.py new file mode 100644 index 0000000..eb2b329 --- /dev/null +++ b/src/biofit/preprocessing/feature_extraction/pcoa/plot_pcoa.py @@ -0,0 +1,91 @@ +from dataclasses import dataclass, field +from typing import Optional + +from biofit.processing import SelectedColumnTypes +from biofit.utils.types import Unset +from biofit.visualization.plotting_utils import plot_dimension_reduction + +from ..plot_feature_extraction import ( + FeatureExtractorPlotter, + FeatureExtractorPlotterConfig, +) + + +@dataclass +class PCoAFeatureExtractorPlotterConfig(FeatureExtractorPlotterConfig): + processor_name: str = field(default="pcoa", init=False, repr=False) + title: str = "PCoA Plot" + n_components: int = 3 + label_column: str = None + group_column: str = None + + +class PCoAFeatureExtractorPlotter(FeatureExtractorPlotter): + config_class = PCoAFeatureExtractorPlotterConfig + config: PCoAFeatureExtractorPlotterConfig + + def __init__( + self, + n_components: int = 3, + label_column: str = None, + group_column: str = None, + config: Optional[PCoAFeatureExtractorPlotterConfig] = None, + **kwargs, + ): + super().__init__( + config=config, + n_components=n_components, + label_column=label_column, + group_column=group_column, + **kwargs, + ) + + def plot( + self, + X, + labels=None, + group=None, + input_columns: SelectedColumnTypes = None, + label_column: SelectedColumnTypes = None, + group_column: SelectedColumnTypes = None, + precomputed: bool = Unset("False"), + pca_kwargs: dict = Unset("{}"), + title: str = Unset("PCA Plot"), + n_components: int = 
Unset(3), + path=None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show: bool = True, + ): + """Plot the PCoA plot. + + Args: + X: The input data. + labels: The labels for the data. + group: The group for the data. + input_columns: The input columns. + label_column: The label column. + group_column: The group column. + precomputed: Whether the input data is precomputed. + pcoa_kwargs: The additional kwargs for computing PCoA. Only used if precomputed is False. + **kwargs: The additional kwargs. + + """ + return plot_dimension_reduction( + X, + labels=labels, + group=group, + input_columns=input_columns, + label_column=label_column, + group_column=group_column, + method="pcoa" if not precomputed else None, + method_kwargs=pca_kwargs, + output_dir=path, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + show=show, + ) diff --git a/src/biofit/preprocessing/feature_extraction/plot_feature_extraction.py b/src/biofit/preprocessing/feature_extraction/plot_feature_extraction.py new file mode 100644 index 0000000..ae81b18 --- /dev/null +++ b/src/biofit/preprocessing/feature_extraction/plot_feature_extraction.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from biofit.visualization.plotting import BasePlotter, PlotterConfig + +if TYPE_CHECKING: + pass + + +@dataclass +class FeatureExtractorPlotterConfig(PlotterConfig): + processor_type: str = field(default="feature_extractor", init=False, repr=False) + + +class FeatureExtractorPlotter(BasePlotter): + """Base class for feature extraction processors.""" + + config_class = FeatureExtractorPlotterConfig + config: FeatureExtractorPlotterConfig diff --git a/src/biofit/preprocessing/feature_selection/__init__.py b/src/biofit/preprocessing/feature_selection/__init__.py new file mode 100644 index 0000000..6959d8e --- /dev/null +++ b/src/biofit/preprocessing/feature_selection/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa +from .min_prevalence_feature_selector import * diff --git a/src/biofit/preprocessing/feature_selection/feature_selection.py b/src/biofit/preprocessing/feature_selection/feature_selection.py new file mode 100644 index 0000000..68bfbd7 --- /dev/null +++ b/src/biofit/preprocessing/feature_selection/feature_selection.py @@ -0,0 +1,105 @@ +from dataclasses import dataclass, field + +from biocore import DataHandler +from biocore.utils import get_kwargs + +from biofit.processing import BaseProcessor, ProcessorConfig +from biofit.utils import logging + +logger = logging.get_logger(__name__) + + +@dataclass +class FeatureSelectorConfig(ProcessorConfig): + processor_type: str = field(default="feature_selection", init=False, repr=False) + + +class FeatureSelector(BaseProcessor): + """Base class for feature selection processors.""" + + config_class = FeatureSelectorConfig + config: FeatureSelectorConfig + + def run(self, X, runner=None, fn_kwargs: dict = {}, **map_kwargs): + fn_kwargs = self._prepare_runner(X, **fn_kwargs) + if "func_type" in fn_kwargs and fn_kwargs["func_type"] != "_fit": + input_format = fn_kwargs["in_format_kwargs"]["target_format"] + fn_kwargs["out_format_kwargs"]["target_format"] = input_format + runner = None + return super().run(X, runner=runner, fn_kwargs=fn_kwargs, **map_kwargs) + + def _transform_any(self, X, selected_indices=None): + if selected_indices is None: + return X + return DataHandler.select_columns(X, selected_indices) + 
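+    # Subclasses populate ``config._feature_idx_out`` / ``config._feature_names_out``
+    # during fit; transform then reduces to the column selection above. A rough usage
+    # sketch (illustrative only, using the MinPrevalenceFeatureSelector subclass
+    # defined in this package):
+    #
+    #     selector = MinPrevalenceFeatureSelector(min_prevalence=0.5)
+    #     X_kept = selector.fit_transform(X)  # keeps columns present in >= 50% of samples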
+ def _process_transform_batch_input(self, X, *fn_args, **fn_kwargs): + func = fn_kwargs.get("fn", None) + in_format_kwargs = fn_kwargs.get("in_format_kwargs", {}) + in_format_kwargs["input_columns"] = None + input = DataHandler.to_format(X, **in_format_kwargs) + _fn_kwargs = get_kwargs(fn_kwargs, func) + + selected_indices = fn_kwargs.get("selected_indices", None) + unused_indices = fn_kwargs.get("unused_indices", None) + keep_unused_columns = fn_kwargs.get("keep_unused_columns", None) + if self.config._feature_names_out and DataHandler.supports_named_columns(input): + feature_idx_out = DataHandler.get_column_indices( + input, self.config._feature_names_out, raise_if_missing=False + ) + elif self.config._feature_idx_out is not None: + feature_idx_out = [ + selected_indices[i] for i in self.config._feature_idx_out + ] + else: + raise ValueError( + "FeatureSelectorConfig requires either _feature_idx_out, as well as " + "_feature_idx_out when input format supports named columns." + ) + if keep_unused_columns: + feature_idx_out = list(sorted(feature_idx_out + unused_indices)) + _fn_kwargs["selected_indices"] = feature_idx_out + return input, fn_args, _fn_kwargs + + def _process_fit_output(self, input, out): + idx_out = self.config._feature_idx_out + names_in = self.config.feature_names_in_ + names_out = self.config._feature_names_out + if idx_out is not None: + if names_in is not None and ( + names_out is None or len(idx_out) < len(names_out) + ): + names_out = [names_in[i] for i in idx_out] + self.config._feature_names_out = names_out + + return super()._process_fit_output(input, out) + + def _process_transform_batch_output(self, input, out, **fn_kwargs): + # do nothing + return out + + def _process_transform_output(self, output, input, *args, **kwargs): + unused_indices = kwargs.get("unused_indices", None) + new_fingerprint = kwargs.get("fingerprint", None) + selected_indices = self.config._feature_idx_out + if DataHandler.supports_named_columns(input) and self.config._feature_names_out: + input_cols = DataHandler.get_column_names(input) + col_dict = dict(zip(input_cols, range(len(input_cols)))) + selected_indices = { + col_dict[name] + for name in self.config._feature_names_out + if name in col_dict + } + + before_filtering = len(kwargs.get("selected_indices", [])) + logger.info( + f'"{self.__class__.__name__}": Selected {len(selected_indices)} out of ' + f"{before_filtering} input features" + ) + + if unused_indices: + selected_indices = set(unused_indices).union(selected_indices) + + return DataHandler.select_columns( + input, sorted(list(selected_indices)), new_fingerprint=new_fingerprint + ) diff --git a/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/__init__.py b/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/__init__.py new file mode 100644 index 0000000..762618c --- /dev/null +++ b/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/__init__.py @@ -0,0 +1,20 @@ +# ruff: noqa +from .min_prevalence_feature_selector import ( + MinPrevalenceFeatureSelector, + MinPrevalenceFeatureSelectorConfig, + MinPrevalenceFeatureSelectorConfigForMetagenomics, + MinPrevalenceFeatureSelectorConfigForOTU, + MinPrevalenceFeatureSelectorConfigForSNP, +) +from .plot_min_prevalence_feature_selector import ( + MinPrevalenceFeatureSelectorPlotter, + MinPrevalencePlotterConfig, + MinPrevalencePlotterConfigForASV, + MinPrevalencePlotterConfigForGenomics, + MinPrevalencePlotterConfigForMetagenomics, + 
MinPrevalencePlotterConfigForOTU,
+    MinPrevalencePlotterConfigForProteomics,
+    MinPrevalencePlotterConfigForMaldi,
+    MinPrevalencePlotterConfigForReadCount,
+    MinPrevalencePlotterConfigForSNP,
+)
diff --git a/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/min_prevalence_feature_selector.py b/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/min_prevalence_feature_selector.py
new file mode 100644
index 0000000..ab76a9a
--- /dev/null
+++ b/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/min_prevalence_feature_selector.py
@@ -0,0 +1,404 @@
+"""
+Feature selector that filters features based on the number of samples they are present in.
+"""
+
+from dataclasses import dataclass, field
+from typing import List, Optional, Tuple, Type
+
+import numpy as np
+from biocore import DataHandler
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.stat import ColumnMissingnessStat, ColumnMissingnessStatConfig
+from biofit.utils import logging
+
+from ..feature_selection import FeatureSelector, FeatureSelectorConfig
+
+logger = logging.get_logger(__name__)
+
+FILTER_FEATURES_DOCSTRING = """
+Filter features based on the number of samples they are present in.
+
+Args:
+    X: Input data
+    min_prevalence: Minimum number of samples a feature must be present in to be kept.
+
+Returns:
+    Filtered data.
+"""
+
+
+def _filter_features(nrows: int, total_missing, min_prevalence, cols=None):
+    """
+    Filter features based on how many samples they are present in.
+
+    Args:
+        nrows (int): The number of samples (rows) in the dataset.
+        total_missing: The per-column counts of missing (absent) values.
+        min_prevalence (float): The minimum required prevalence. Columns missing
+            from more than ``nrows * (1 - min_prevalence)`` samples are dropped.
+        cols (optional): Column names to return instead of indices.
+
+    Returns:
+        List: The column indices (or names, when ``cols`` is provided) that pass
+        the filtering condition.
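+
+    Example:
+        A worked sketch: with 4 samples, per-column missing counts of
+        ``[0, 3, 1, 4]`` and ``min_prevalence=0.5``, the cutoff is
+        ``4 * (1 - 0.5) = 2`` missing values, so columns 0 and 2 are kept.
+
+        >>> import numpy as np
+        >>> _filter_features(4, np.array([0, 3, 1, 4]), 0.5)
+        [0, 2]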
+ """ + # keep features that are present in at least min_prevalence of the rows + col_to_keep = total_missing <= (nrows * (1 - min_prevalence)) + if cols and len(cols) == len(col_to_keep): + return [c for c, b in zip(cols, col_to_keep) if b] + else: + return [i for i, b in enumerate(col_to_keep) if b] + + +@dataclass +class MinPrevalenceFeatureSelectorConfig(FeatureSelectorConfig): + processor_name: str = field( + default="min_prevalence_feature_selector", init=False, repr=False + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + + min_prevalence: float = 0.5 + depth: int = None + + def __post_init__(self): + if self.depth is None: + if isinstance(self.min_prevalence, float) and self.min_prevalence < 1: + self._fit_process_desc = ( + "Removing features that are present in less than " + f"{self.min_prevalence * 100:.0f}% of samples" + ) + elif ( + isinstance(self.min_prevalence, (int, float)) + and self.min_prevalence >= 1 + ): + self._fit_process_desc = f"Removing features that are present in less than {self.min_prevalence} samples" + + elif isinstance(self.min_prevalence, float) and self.min_prevalence < 1: + self._fit_process_desc = ( + f"Removing features with <{self.min_prevalence * 100:.0f}% " + f"samples above {self.depth} counts" + ) + elif isinstance(self.min_prevalence, (int, float)) and self.min_prevalence >= 1: + self._fit_process_desc = ( + f"Removing features with <{self.min_prevalence} " + f"samples above {self.depth} counts" + ) + + +@dataclass +class MinPrevalenceFeatureSelectorConfigForMetagenomics( + MinPrevalenceFeatureSelectorConfig +): + dataset_name: str = field(default="metagenomics", init=False, repr=False) + _fit_input_feature_types: List[Tuple[Type, Type]] = field( + default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Tuple[Type, Type]] = field( + default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))], + init=False, + repr=False, + ) + depth: int = 100 + min_prevalence: float = 0.01 + + +@dataclass +class MinPrevalenceFeatureSelectorConfigForOTU(MinPrevalenceFeatureSelectorConfig): + dataset_name: str = field(default="otu", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + depth: int = 0 + min_prevalence: float = 0.01 + + def __post_init__(self): + if self.depth is None: + if isinstance(self.min_prevalence, float) and self.min_prevalence < 1: + self._fit_process_desc = ( + "Removing OTUs that are present in less than " + f"{self.min_prevalence * 100:.0f}% of samples" + ) + elif ( + isinstance(self.min_prevalence, (int, float)) + and self.min_prevalence >= 1 + ): + self._fit_process_desc = f"Removing OTUs that are present in less than {self.min_prevalence} samples" + + elif isinstance(self.min_prevalence, float) and self.min_prevalence < 1: + self._fit_process_desc = ( + f"Removing OTUs with <{self.min_prevalence * 100:.0f}% " + f"samples above {self.depth} counts" + ) + elif isinstance(self.min_prevalence, (int, float)) and self.min_prevalence >= 1: + 
self._fit_process_desc = ( + f"Removing OTUs with <{self.min_prevalence} " + f"samples above {self.depth} counts" + ) + + +@dataclass +class MinPrevalenceFeatureSelectorConfigForSNP(MinPrevalenceFeatureSelectorConfig): + dataset_name: str = field(default="snp", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False + ) + depth: int = 0 + min_prevalence: float = 0.012 + + +class MinPrevalenceFeatureSelector(FeatureSelector): + """ + Feature selector that filters features based on the number of samples they are present in. + + - config: + - min_prevalence (Union[str, float], optional): + The minimum prevalence of features to keep a sample. + The threshold to which we determine the minimum prevalence percentage or count of present values to maintain a sample or feature. + Any value passed >= 1 will be the minimum count while any value < 1 will be a percentage of the total values. + If "auto", the minimum prevalence is calculated as the first quartile minus 1.5 times the interquartile range. Defaults to "auto". + - depth (Union[int, Unset], optional): + The minimum value that we consider something to be present. Defaults to None. + """ + + config_class = MinPrevalenceFeatureSelectorConfig + config: MinPrevalenceFeatureSelectorConfig + + def __init__( + self, + min_prevalence: float = 0.5, + depth: int = None, + config: Optional[MinPrevalenceFeatureSelectorConfig] = None, + **kwargs, + ): + super().__init__( + config=config, depth=depth, min_prevalence=min_prevalence, **kwargs + ) + self = self.set_params() + + @sync_backup_config + def set_params(self, **kwargs): + if len(kwargs) > 0: + self.config = self.config.replace_defaults(**kwargs) + col_missingness_config = ColumnMissingnessStatConfig.from_config(self.config) + col_missingness_config = col_missingness_config.replace_defaults(**kwargs) + self.missingness = ColumnMissingnessStat(config=col_missingness_config) + self.total_missing = None + self._input_columns = None + return self + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "MinPrevalenceFeatureSelector": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = 
self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _process_fit_input(self, input, **kwargs): + self.height = DataHandler.get_shape(input)[0] + return input, kwargs + + def _fit_numpy(self, X): + self.total_missing = self.missingness._transform_numpy(X).flatten() + return self + + def _partial_fit_numpy(self, X): + if self.total_missing is None: + self.total_missing = self.missingness._transform_numpy(X).flatten() + else: + self.total_missing += self.missingness._transform_numpy(X).flatten() + return self + + def _fit_pandas(self, X): + self.total_missing = self.missingness._transform_pandas(X).values.flatten() + return self + + def _partial_fit_pandas(self, X): + if self.total_missing is None: + self.total_missing = self.missingness._transform_pandas(X).values.flatten() + else: + self.total_missing += self.missingness._transform_pandas(X).values.flatten() + return self + + def _fit_polars(self, X): + self.total_missing = self.missingness._transform_polars(X).to_numpy().flatten() + return self + + def _partial_fit_polars(self, X): + if self.total_missing is None: + self.total_missing = ( + self.missingness._transform_polars(X).to_numpy().flatten() + ) + else: + self.total_missing += ( + self.missingness._transform_polars(X).to_numpy().flatten() + ) + return self + + def _fit_arrow(self, X): + self.total_missing = np.array( + [v[0] for _, v in self.missingness._transform_arrow(X).to_pydict().items()] + ) + return self + + def _partial_fit_arrow(self, X): + if self.total_missing is None: + total_missing = self.missingness._transform_arrow(X) + self.total_missing = np.array( + [v[0] for _, v in total_missing.to_pydict().items()] + ) + else: + other_missing = self.missingness._transform_arrow(X) + other_missing = np.array( + [v[0] for _, v in other_missing.to_pydict().items()] + ) + 
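+            # Accumulate this batch's per-column missing counts into the running total.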
self.total_missing += other_missing + return self + + def _pool_fit(self, out: List["MinPrevalenceFeatureSelector"]): + new_self = out[0] + if len(out) > 1: + logger.info("Pooling results") + new_self.total_missing = np.sum( + np.vstack([x.total_missing for x in out]), axis=0 + ) + logger.info( + f"Total missing values: {new_self.total_missing.sum()} " + f"({new_self.total_missing.mean():.2f} per feature)" + ) + return new_self + + def _process_fit_output(self, input, out): + self.config._feature_idx_out = _filter_features( + self.height, + self.total_missing, + self.config.min_prevalence, + ) + return super()._process_fit_output(input, out) diff --git a/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/plot_min_prevalence_feature_selector.py b/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/plot_min_prevalence_feature_selector.py new file mode 100644 index 0000000..df84702 --- /dev/null +++ b/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/plot_min_prevalence_feature_selector.py @@ -0,0 +1,367 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Type + +import pyarrow as pa +from biocore.utils.import_util import is_datasets_available + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes +from biofit.utils.types import Unset + +from ..plot_feature_selection import ( + FeatureSelectorPlotter, + FeatureSelectorPlotterConfig, +) + +if TYPE_CHECKING: + pass + + +@dataclass +class MinPrevalencePlotterConfig(FeatureSelectorPlotterConfig): + plot_process_desc: str = field( + default="Plotting presence feature selection", init=False, repr=False + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("METADATA_FEATURE_TYPES"), + get_feature("METADATA_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("METADATA_FEATURE_TYPES"), + get_feature("METADATA_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _compare: bool = field(default=True, init=False, repr=False) + _add_labels: bool = field(default=False, init=False, repr=False) + processor_name: str = field( + default="min_prevalence_feature_selector", init=False, repr=False + ) + + sample_xlab: str = "Sum of Counts" + sample_main: str = "Sample Distribution" + feature_xlab: str = "Sum of Counts" + feature_main: str = "Feature Distribution" + legend_position: str = "top" + legend_title: str = "Max Missing Feature Selection" + before_name: str = "Before" + after_name: str = "After" + xlog: str = None + ylog: str = None + ncol: int = 2 + include_non_zero_sum: bool = False + non_zero_samp_xlab: str = None + non_zero_samp_main: str = None + non_zero_feat_xlab: str = None + non_zero_feat_main: str = None + + +@dataclass +class MinPrevalencePlotterConfigForMetagenomics(MinPrevalencePlotterConfig): + dataset_name: str = field(default="metagenomics", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")], + init=False, + repr=False, + ) + + feature_main = "Taxa Total Abundance Distribution" + non_zero_samp_xlab: str = "Species Richness" + non_zero_samp_main: str = "Richness Distribution" + 
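+    # Feature-level (prevalence) panel labels; the two fields above label the
+    # per-sample richness panel.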
non_zero_feat_xlab: str = "Species Prevalence" + non_zero_feat_main: str = "Prevalence Across Samples" + xlog = "log2_1p" + include_non_zero_sum: bool = True + + +@dataclass +class MinPrevalencePlotterConfigForOTU(MinPrevalencePlotterConfigForMetagenomics): + dataset_name: str = field(default="otu", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")], + init=False, + repr=False, + ) + feature_main: str = "OTU Total Abundance Distribution" + include_non_zero_sum: bool = True + + +@dataclass +class MinPrevalencePlotterConfigForASV(MinPrevalencePlotterConfigForMetagenomics): + dataset_name: str = field(default="asv", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")], + init=False, + repr=False, + ) + include_non_zero_sum: bool = True + + +@dataclass +class MinPrevalencePlotterConfigForGenomics(MinPrevalencePlotterConfig): + dataset_name: str = field(default="genomics", init=False, repr=False) + _input_feature_types: List[Type] = field( + default_factory=lambda: [ + (get_feature("ReadCount"), get_feature("GenomicVariant")), + (get_feature("ReadCount"), get_feature("GenomicVariant")), + ], + init=False, + repr=False, + ) + + +@dataclass +class MinPrevalencePlotterConfigForSNP(MinPrevalencePlotterConfig): + dataset_name: str = field(default="snp", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("GenomicVariant"), + get_feature("GenomicVariant"), + ], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("GenomicVariant"), + get_feature("GenomicVariant"), + ], + init=False, + repr=False, + ) + + +@dataclass +class MinPrevalencePlotterConfigForReadCount(MinPrevalencePlotterConfig): + dataset_name: str = field(default="read_count", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("ReadCount"), get_feature("ReadCount")], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("ReadCount"), get_feature("ReadCount")], + init=False, + repr=False, + ) + + +@dataclass +class MinPrevalencePlotterConfigForProteomics(MinPrevalencePlotterConfig): + dataset_name: str = field(default="proteomics", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Expression"), get_feature("Expression")], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Expression"), get_feature("Expression")], + init=False, + repr=False, + ) + + +@dataclass +class MinPrevalencePlotterConfigForMaldi(MinPrevalencePlotterConfig): + dataset_name: str = field(default="maldi", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("PeakIntensity"), + get_feature("PeakIntensity"), + ], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = 
field( + default_factory=lambda: [ + get_feature("PeakIntensity"), + get_feature("PeakIntensity"), + ], + init=False, + repr=False, + ) + + +class MinPrevalenceFeatureSelectorPlotter(FeatureSelectorPlotter): + config_class = MinPrevalencePlotterConfig + config: MinPrevalencePlotterConfig + + def __init__( + self, + sample_xlab: str = "Sum of Counts", + sample_main: str = "Sample Distribution", + feature_xlab: str = "Sum of Counts", + feature_main: str = "Feature Distribution", + legend_position: str = "top", + legend_title: str = "Max Missing Feature Selection", + before_name: str = "Before", + after_name: str = "After", + xlog: str = None, + ylog: str = None, + ncol: int = 2, + include_non_zero_sum: bool = False, + non_zero_samp_xlab: str = None, + non_zero_samp_main: str = None, + non_zero_feat_xlab: str = None, + non_zero_feat_main: str = None, + config: MinPrevalencePlotterConfig = None, + **kwargs, + ): + super().__init__( + config=config, + sample_xlab=sample_xlab, + sample_main=sample_main, + feature_xlab=feature_xlab, + feature_main=feature_main, + legend_position=legend_position, + legend_title=legend_title, + before_name=before_name, + after_name=after_name, + xlog=xlog, + ylog=ylog, + ncol=ncol, + include_non_zero_sum=include_non_zero_sum, + non_zero_samp_xlab=non_zero_samp_xlab, + non_zero_samp_main=non_zero_samp_main, + non_zero_feat_xlab=non_zero_feat_xlab, + non_zero_feat_main=non_zero_feat_main, + **kwargs, + ) + + def plot_pandas(self, x1, x2): + row_sums = [pa.array(x1.sum(axis=1)), pa.array(x2.sum(axis=1))] + col_sums = [pa.array(x1.sum(axis=0)), pa.array(x2.sum(axis=0))] + input = [row_sums, col_sums] + mains = [self.config.sample_main, self.config.feature_main] + xlabs = [self.config.sample_xlab, self.config.feature_xlab] + if self.config.include_non_zero_sum: + non_zero_sample_sums = [ + pa.array(x1[x1 > 0].count(axis=1)), + pa.array(x2[x2 > 0].count(axis=1)), + ] + non_zero_feature_sums = [ + pa.array(x1[x1 > 0].count(axis=0)), + pa.array(x2[x2 > 0].count(axis=0)), + ] + mains.extend( + [ + self.config.non_zero_samp_main, + self.config.non_zero_feat_main, + ] + ) + xlabs.extend( + [ + self.config.non_zero_samp_xlab, + self.config.non_zero_feat_xlab, + ] + ) + input.extend([non_zero_sample_sums, non_zero_feature_sums]) + + self.plotter( + list_of_sums=input, + path=self.config.path, + xlabs=xlabs, + mains=mains, + ncol=self.config.ncol, + legend_position=self.config.legend_position, + legend_title=self.config.legend_title, + before_name=self.config.before_name, + after_name=self.config.after_name, + xlog=self.config.xlog, + ylog=self.config.ylog, + ) + + def plot( + self, + x1, + x2=None, + input_columns1: SelectedColumnTypes = None, + input_columns2: SelectedColumnTypes = None, + sample_xlab: str = Unset('"Sum of Counts"'), + sample_main: str = Unset('"Sample Distribution"'), + feature_xlab: str = Unset('"Sum of Counts"'), + feature_main: str = Unset('"Feature Distribution"'), + legend_position: str = Unset('"top"'), + legend_title: str = Unset('"Max Missing Feature Selection"'), + before_name: str = Unset('"Before"'), + after_name: str = Unset('"After"'), + xlog: str = Unset("None"), + ylog: str = Unset("None"), + ncol: int = Unset("2"), + include_non_zero_sum: bool = Unset("False"), + non_zero_samp_xlab: str = Unset("None"), + non_zero_samp_main: str = Unset("None"), + non_zero_feat_xlab: str = Unset("None"), + non_zero_feat_main: str = Unset("None"), + path: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, 
+ raise_if_missing: bool = True, + show: bool = True, + ): + if x2 is None: + if is_datasets_available(): + from biosets import Bioset + from datasets import Dataset as HfDataset + + if isinstance(x1, (Bioset, HfDataset)): + _, _, prev_replays = self._get_replays(x1) + if prev_replays: + x2 = Bioset.from_replays(prev_replays) + return self._plot(x2, x1) + else: + raise ValueError( + "Must provide the before and after feature selection datasets." + ) + + self.config._input_columns = self._set_input_columns_and_arity( + input_columns1, input_columns2 + ) + return self._plot( + x1, + x2, + sample_xlab=sample_xlab, + sample_main=sample_main, + feature_xlab=feature_xlab, + feature_main=feature_main, + legend_position=legend_position, + legend_title=legend_title, + before_name=before_name, + after_name=after_name, + xlog=xlog, + ylog=ylog, + ncol=ncol, + include_non_zero_sum=include_non_zero_sum, + non_zero_samp_xlab=non_zero_samp_xlab, + non_zero_samp_main=non_zero_samp_main, + non_zero_feat_xlab=non_zero_feat_xlab, + non_zero_feat_main=non_zero_feat_main, + path=path, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + show=show, + ) diff --git a/src/biofit/preprocessing/feature_selection/plot_feature_selection.R b/src/biofit/preprocessing/feature_selection/plot_feature_selection.R new file mode 100644 index 0000000..77220f3 --- /dev/null +++ b/src/biofit/preprocessing/feature_selection/plot_feature_selection.R @@ -0,0 +1,57 @@ +## + +source(file.path(R_SCRIPTS_PATH, "plotting_utils.R")) + + +plot_feature_selector <- function( + list_of_sums, path, + xlabs, + mains, + legend_position = "top", + legend_title = "Feature Selection", + before_name = "Before", + after_name = "After", + xlog = NULL, + ylog = NULL, + ...) 
{
+  suppressPackageStartupMessages(require(ggplot2))
+  suppressPackageStartupMessages(require(RColorBrewer))
+  suppressPackageStartupMessages(require(circlize))
+  suppressPackageStartupMessages(require(patchwork))
+
+  # `ncol` is forwarded through `...` by the Python caller; fall back to 2 columns
+  # when it is not supplied.
+  dots <- list(...)
+  ncol <- if (!is.null(dots$ncol)) dots$ncol else 2
+
+  if (!is.list(list_of_sums) && !is.vector(list_of_sums)) {
+    list_of_sums <- list(list_of_sums)
+  }
+
+  if (!is.list(xlabs) && !is.vector(xlabs)) {
+    xlabs <- list(xlabs)
+  }
+
+  if (!is.list(mains) && !is.vector(mains)) {
+    mains <- list(mains)
+  }
+
+  if (length(xlabs) != length(list_of_sums) && length(xlabs) == 1) {
+    xlabs <- rep(xlabs, length(list_of_sums))
+  }
+
+  if (length(mains) != length(list_of_sums) && length(mains) == 1) {
+    mains <- rep(mains, length(list_of_sums))
+  }
+  plots <- NULL
+  for (i in 1:length(list_of_sums)) {
+    x1 <- as.vector(list_of_sums[[i]][[1]])
+    x2 <- as.vector(list_of_sums[[i]][[2]])
+    if (is.null(plots)) {
+      plots <- generate_comparison_histogram(x1, x2, xlab = xlabs[[i]], title = mains[[i]], xlog = xlog, ylog = ylog)
+    } else {
+      plots <- plots + generate_comparison_histogram(x1, x2, xlab = xlabs[[i]], title = mains[[i]], xlog = xlog, ylog = ylog)
+    }
+  }
+
+  grid <- plots + plot_layout(guides = "collect", ncol = ncol) &
+    theme(text = element_text(size = 8), legend.position = legend_position)
+  save_plots(path, plot = grid, width = 6, height = 6.5, dpi = 600)
+}
+
diff --git a/src/biofit/preprocessing/feature_selection/plot_feature_selection.py b/src/biofit/preprocessing/feature_selection/plot_feature_selection.py
new file mode 100644
index 0000000..4254ace
--- /dev/null
+++ b/src/biofit/preprocessing/feature_selection/plot_feature_selection.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+
+from biofit.visualization.plotting import BasePlotter, PlotterConfig
+
+
+@dataclass
+class FeatureSelectorPlotterConfig(PlotterConfig):
+    processor_type: str = field(default="feature_selection", init=False, repr=False)
+    r_source: str = field(
+        default=(Path(__file__).parent / "plot_feature_selection.R").as_posix(),
+        init=False,
+        repr=False,
+    )
+    main_method: str = field(default="plot_feature_selector", init=False, repr=False)
+
+
+class FeatureSelectorPlotter(BasePlotter):
+    config_class = FeatureSelectorPlotterConfig
+    config: FeatureSelectorPlotterConfig
diff --git a/src/biofit/preprocessing/feature_selection/rfe/plot_rfe.R b/src/biofit/preprocessing/feature_selection/rfe/plot_rfe.R
new file mode 100644
index 0000000..774cd2c
--- /dev/null
+++ b/src/biofit/preprocessing/feature_selection/rfe/plot_rfe.R
@@ -0,0 +1,25 @@
+source()
+
+plot_rfe_feature_selector_for_genomics <- function(
+    x1, x2,
+    legend_title = "Feature Selection by Recursive Feature Elimination",
+    ...) {
+  args <- list(...)
+  args$x1 <- x1
+  args$x2 <- x2
+  args$legend_title <- if ("legend_title" %in% names(args)) args$legend_title else legend_title
+  do.call(plot_feature_selector_for_genomics, args)
+}
+
+
+plot_rfe_feature_selector_snp <- function(
+    x1, x2,
+    feature_main = "By GenomicVariants",
+    ...) {
+  args <- list(...)
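+  # Attach the data and the default panel title, then delegate to the genomics variant.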
+ args$x1 <- x1 + args$x2 <- x2 + args$feature_main <- if ("feature_main" %in% names(args)) args$feature_main else feature_main + do.call(plot_rfe_feature_selector_for_genomics, args) +} + diff --git a/src/biofit/preprocessing/feature_selection/rfe/plot_rfe.py b/src/biofit/preprocessing/feature_selection/rfe/plot_rfe.py new file mode 100644 index 0000000..e69de29 diff --git a/src/biofit/preprocessing/filtering/__init__.py b/src/biofit/preprocessing/filtering/__init__.py new file mode 100644 index 0000000..6949be0 --- /dev/null +++ b/src/biofit/preprocessing/filtering/__init__.py @@ -0,0 +1,4 @@ +# ruff: noqa +from .min_prevalence_sample_filter import * +from .missing_labels import * +from .row_abundance import * diff --git a/src/biofit/preprocessing/filtering/filtering.py b/src/biofit/preprocessing/filtering/filtering.py new file mode 100644 index 0000000..38ca3da --- /dev/null +++ b/src/biofit/preprocessing/filtering/filtering.py @@ -0,0 +1,78 @@ +from dataclasses import dataclass, field + +import pandas as pd +import pyarrow as pa +from biocore import DataHandler +from biocore.utils.import_util import is_datasets_available +from biocore.utils.inspect import get_kwargs +from biocore.utils.py_util import is_bioset, is_dataset, is_iterable_dataset + +from biofit.processing import BaseProcessor, ProcessorConfig +from biofit.utils import logging +from biofit.utils.table_util import string_to_arrow + +logger = logging.get_logger(__name__) + + +@dataclass +class SampleFilterConfig(ProcessorConfig): + processor_type: str = field(default="filtering", init=False, repr=False) + + +class SampleFilter(BaseProcessor): + """Base class for filtering processors. + + NOTE: All transformation functions must return bools. + """ + + _feature_dependent = ( + False # SampleFiltering does not depend on features when transforming + ) + + def run(self, X, runner=None, fn_kwargs: dict = {}, **map_kwargs): + fn_kwargs = self._prepare_runner(X, **fn_kwargs) + if fn_kwargs["func_type"] != "_fit": + if is_datasets_available(): + if ( + is_bioset(X) or is_dataset(X, iterable=False) + ) and "out_type" not in fn_kwargs: + from datasets import Dataset + + runner = Dataset.map + map_kwargs = get_kwargs(map_kwargs, runner) + elif is_iterable_dataset(X): + from datasets import IterableDataset + + runner = IterableDataset.map + map_kwargs = get_kwargs(map_kwargs, runner) + + return super().run(X, runner=runner, fn_kwargs=fn_kwargs, **map_kwargs) + + def _process_transform_batch_output(self, input, out, **fn_kwargs): + bools = DataHandler.to_numpy(out).flatten().tolist() if len(out) > 0 else out[0] + inds = [i for i, x in enumerate(bools) if x] + if len(inds): + return DataHandler.select_rows(input, inds) + else: + cols = DataHandler.get_column_names(input, generate_cols=True) + schema = pa.schema( + [ + pa.field(k, string_to_arrow(v)) + for k, v in DataHandler.get_dtypes(input).items() + ] + ) + return pa.Table.from_pandas( + pd.DataFrame(columns=cols), preserve_index=False, schema=schema + ) + + def _process_transform_output(self, output, *args, **kwargs): + init_num_rows = DataHandler.get_shape(args[0])[0] + final_num_rows = DataHandler.get_shape(output)[0] + if final_num_rows != init_num_rows: + logger.info( + f'"{self.__class__.__name__}": Selected {final_num_rows} out ' + f"of {init_num_rows} samples" + ) + else: + logger.info(f'"{self.__class__.__name__}": no samples were removed') + return output diff --git a/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/__init__.py 
b/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/__init__.py new file mode 100644 index 0000000..0180563 --- /dev/null +++ b/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/__init__.py @@ -0,0 +1,20 @@ +# ruff: noqa +from .min_prevalence_sample_filter import ( + MinPrevalenceSampleFilter, + MinPrevalenceRowSampleFilterConfig, + MinPrevalenceRowSampleFilterConfigForOTU, + MinPrevalenceRowSampleFilterConfigForSNP, + MinPrevalenceRowSampleFilterConfigForMaldi, +) +from .plot_min_prevalence_sample_filter import ( + MinPrevalenceRowPlotterConfig, + MinPrevalenceRowPlotterConfigForMetagenomics, + MinPrevalenceRowPlotterConfigForOTU, + MinPrevalenceRowPlotterConfigForSNP, + MinPrevalenceRowPlotterConfigForASV, + MinPrevalenceRowPlotterConfigForMaldi, + MinPrevalenceRowPlotterConfigForGenomics, + MinPrevalenceRowPlotterConfigForReadCount, + MinPrevalenceRowPlotterConfigForProteomics, + MinPrevalenceRowPlotter, +) diff --git a/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/min_prevalence_sample_filter.py b/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/min_prevalence_sample_filter.py new file mode 100644 index 0000000..174e31e --- /dev/null +++ b/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/min_prevalence_sample_filter.py @@ -0,0 +1,370 @@ +import sys +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, Type, Union + +import numpy as np +import pandas as pd +from biocore.utils.import_util import is_polars_available + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.stat import RowMissingnessStat, RowMissingnessStatConfig +from biofit.utils import Unset, logging + +from ..filtering import SampleFilter, SampleFilterConfig + +if TYPE_CHECKING: + import polars as pl + +logger = logging.get_logger(__name__) + + +@dataclass +class MinPrevalenceRowSampleFilterConfig(SampleFilterConfig): + # process description + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + processor_name: str = field( + default="min_prevalence_sample_filter", init=False, repr=False + ) + + # default values + min_prevalence: float = "auto" + depth: int = None + + +@dataclass +class MinPrevalenceRowSampleFilterConfigForOTU(MinPrevalenceRowSampleFilterConfig): + # dataset description + dataset_name: str = field(default="otu", init=False, repr=False) + + # override default values + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + depth: int = 0 + min_prevalence: float = "auto" + + def __post_init__(self): + if self.depth is None: + self._transform_process_desc = ( + f"Removing samples with <{self.min_prevalence * 100:.0f}% present OTUs" + ) + elif isinstance(self.min_prevalence, float) and self.min_prevalence < 1: + self._transform_process_desc = f"Removing samples with <{self.min_prevalence * 100:.0f}% OTUs over {self.depth} counts" + elif isinstance(self.min_prevalence, (int, float)) and self.min_prevalence >= 1: + self._transform_process_desc = f"Removing samples with 
<{self.min_prevalence} OTUs over {self.depth} counts" + + +@dataclass +class MinPrevalenceRowSampleFilterConfigForSNP(MinPrevalenceRowSampleFilterConfig): + # dataset description + dataset_name: str = field(default="snp", init=False, repr=False) + + # override default values + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False + ) + depth: int = 0 + min_prevalence: float = 0.012 + + +@dataclass +class MinPrevalenceRowSampleFilterConfigForMaldi(MinPrevalenceRowSampleFilterConfig): + # dataset description + dataset_name: str = field(default="maldi", init=False, repr=False) + + # override default values + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("PeakIntensity")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("PeakIntensity")], init=False, repr=False + ) + depth: int = 0 + min_prevalence: float = 0.4 + + +class MinPrevalenceSampleFilter(SampleFilter): + # main config class + config_class = MinPrevalenceRowSampleFilterConfig + config: MinPrevalenceRowSampleFilterConfig + + IQRs = None + + def __init__( + self, + config: Optional[MinPrevalenceRowSampleFilterConfig] = None, + *, + min_prevalence: Union[str, float] = "auto", + depth: Union[int, Unset] = None, + **kwargs, + ): + """SampleFilter samples based on the minimum prevalence of features. + + Args: + config (MinPrevalenceRowSampleFilterConfig, optional): + The configuration for the filter. Defaults to None. If a configuration is provided, the + below arguments are ignored. To override the configuration, use the set_params method. + min_prevalence (Union[str, float], optional): + The minimum prevalence of features to keep a sample. + The threshold to which we determine the minimum prevalence percentage or count of present values to maintain a sample or feature. + Any value passed >= 1 will be the minimum count while any value < 1 will be a percentage of the total values. + If "auto", the minimum prevalence is calculated as the first quartile minus 1.5 times the interquartile range. Defaults to "auto". + depth (Union[int, Unset], optional): + The minimum value that we consider something to be present. Defaults to None. + **kwargs: + Additional keyword arguments to pass to the filter. See the ProcessorConfig class for more details. 
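+
+        Example:
+            A rough usage sketch (illustrative; ``X`` stands for any table-like
+            input supported by the processing backend)::
+
+                flt = MinPrevalenceSampleFilter(min_prevalence=0.5, depth=0)
+                X_kept = flt.fit_transform(X)  # keeps samples with >50% of features present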
+ """ + super().__init__( + config=config, min_prevalence=min_prevalence, depth=depth, **kwargs + ) + row_missingness_config = RowMissingnessStatConfig.from_config(self.config) + self.missingness = RowMissingnessStat(row_missingness_config) + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + row_missingness_config = RowMissingnessStatConfig.from_config(self.config) + self.missingness = RowMissingnessStat(row_missingness_config) + return self + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "MinPrevalenceSampleFilter": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + 
output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _process_fit_input(self, input, **kwargs): + if self.config.min_prevalence != "auto": + kwargs["fn_kwargs"]["fn"] = None + return super()._process_fit_input(input, **kwargs) + + def _fit_polars(self, X: "pl.DataFrame"): + row_missingness = self.missingness._transform_polars(X) + row_missingness = row_missingness.get_column(row_missingness.columns[0]) + row_present = X.shape[1] - row_missingness + iqr = [row_present.quantile(0.25), row_present.quantile(0.75)] + self.config.min_prevalence = iqr[0] - 1.5 * (iqr[1] - iqr[0]) + return self + + def _partial_fit_polars(self, X: "pl.DataFrame"): + row_missingness = self.missingness._transform_polars(X) + row_missingness = row_missingness.get_column(row_missingness.columns[0]) + row_present = X.shape[1] - row_missingness + iqr = [row_present.quantile(0.25), row_present.quantile(0.75)] + if self.IQRs is None: + self.IQRs = [[iqr[0]], [iqr[1]]] + else: + self.IQRs[0].append(iqr[0]) + self.IQRs[1].append(iqr[1]) + return self + + def _fit_pandas(self, X: "pd.DataFrame"): + row_missingness = self.missingness._transform_pandas(X).iloc[:, 0] + row_present = X.shape[1] - row_missingness + iqr = [row_present.quantile(0.25), row_present.quantile(0.75)] + self.config.min_prevalence = iqr[0] - 1.5 * (iqr[1] - iqr[0]) + return self + + def _partial_fit_pandas(self, X: "pd.DataFrame"): + row_missingness = self.missingness._transform_pandas(X).iloc[:, 0] + row_present = X.shape[1] - row_missingness + iqr = [row_present.quantile(0.25), row_present.quantile(0.75)] + if self.IQRs is None: + self.IQRs = [[iqr[0]], [iqr[1]]] + else: + self.IQRs[0].append(iqr[0]) + self.IQRs[1].append(iqr[1]) + return self + + def _fit_numpy(self, X: np.ndarray): + row_missingness = self.missingness._transform_numpy(X)[:, 0] + row_present = X.shape[1] - row_missingness + iqr = [np.quantile(row_present, 0.25), np.quantile(row_present, 0.75)] + self.config.min_prevalence = iqr[0] - 1.5 * (iqr[1] - iqr[0]) + return self + + def _partial_fit_numpy(self, X: np.ndarray): + row_missingness = self.missingness._transform_numpy(X)[:, 0] + row_present = X.shape[1] - row_missingness + iqr = [np.quantile(row_present, 0.25), np.quantile(row_present, 0.75)] + if self.IQRs is None: + self.IQRs = [[iqr[0]], [iqr[1]]] + else: + self.IQRs[0].append(iqr[0]) + self.IQRs[1].append(iqr[1]) + return self + + def _pool_fit_any(self, partial_results: List["MinPrevalenceSampleFilter"]): + IQRs = [[], []] + for result in partial_results: + IQRs[0].extend(result.IQRs[0]) + IQRs[1].extend(result.IQRs[1]) + IQRs[0] = np.mean(IQRs[0]) + IQRs[1] = np.mean(IQRs[1]) + self.config.min_prevalence = IQRs[0] - 1.5 * (IQRs[1] - IQRs[0]) + return self + + def _process_fit_output(self, input, out): + logger.info(f"Minimum prevalence set to {out.config.min_prevalence * 100:.0f}%") + return super()._process_fit_output(input, out) + + def _transform_polars(self, X: "pl.DataFrame"): + if is_polars_available() and "polars" in sys.modules: + import polars as pl + total_present: pl.DataFrame = self.missingness._transform_polars( + X + ).with_columns((X.shape[1] - pl.col("*")).alias("sum")) + + if self.config.min_prevalence < 1: + total_present = total_present.with_columns(pl.col("sum") / X.shape[1]) + tests = total_present.with_columns(pl.col("sum") > self.config.min_prevalence) + if len(tests) == 1: + return tests[0] + return tests.to_numpy()[:, 0].tolist() + + def _transform_pandas(self, X: pd.DataFrame): + 
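+        # Mirrors _transform_polars/_transform_numpy: count the present values
+        # per sample, compare against a fraction of the feature count when
+        # min_prevalence < 1 (otherwise against the raw count), and return one
+        # boolean per sample, where True means the sample is kept.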
total_present = X.shape[1] - self.missingness._transform_pandas(X) + if self.config.min_prevalence < 1: + total_present = total_present / X.shape[1] + tests = total_present > self.config.min_prevalence + if len(tests) == 1: + return tests[0] + return tests.values[:, 0].tolist() + + def _transform_numpy(self, X: np.ndarray): + total_present = X.shape[1] - self.missingness._transform_numpy(X) + if self.config.min_prevalence < 1: + total_present = total_present / X.shape[1] + tests = total_present > self.config.min_prevalence + if len(tests) == 1: + return tests[0] + return tests[:, 0].tolist() diff --git a/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/plot_min_prevalence_sample_filter.py b/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/plot_min_prevalence_sample_filter.py new file mode 100644 index 0000000..9cd8930 --- /dev/null +++ b/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/plot_min_prevalence_sample_filter.py @@ -0,0 +1,304 @@ +from dataclasses import dataclass, field +from typing import List, Type + +import pyarrow as pa + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes +from biofit.utils.types import Unset + +from ..plot_filtering import ( + SampleFilterPlotter, + SampleFilterPlotterConfig, +) + + +@dataclass +class MinPrevalenceRowPlotterConfig(SampleFilterPlotterConfig): + plot_process_desc: str = field( + default="Plotting presence feature selection", init=False, repr=False + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("METADATA_FEATURE_TYPES"), + get_feature("METADATA_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _compare: bool = field(default=True, init=False, repr=False) + processor_name: str = field( + default="min_prevalence_sample_filter", init=False, repr=False + ) + + sample_xlab: str = "Sum of Counts" + sample_main: str = "Sample Distribution" + feature_xlab: str = "Sum of Counts" + feature_main: str = "Feature Distribution" + legend_position: str = "top" + legend_title: str = "Min Prevalence Sample SampleFiltering" + before_name: str = "Before" + after_name: str = "After" + xlog: str = None + ylog: str = None + ncol: int = 2 + include_non_zero_sum: bool = False + non_zero_samp_xlab: str = None + non_zero_samp_main: str = None + non_zero_feat_xlab: str = None + non_zero_feat_main: str = None + + +@dataclass +class MinPrevalenceRowPlotterConfigForMetagenomics(MinPrevalenceRowPlotterConfig): + dataset_name: str = field(default="metagenomics", init=False, repr=False) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + [get_feature("Abundance"), get_feature("ReadCount")], + [get_feature("Abundance"), get_feature("ReadCount")], + ], + init=False, + repr=False, + ) + + feature_main = "Taxa Total Abundance Distribution" + non_zero_samp_xlab: str = "Species Richness" + non_zero_samp_main: str = "Richness Distribution" + non_zero_feat_xlab: str = "Species Prevalence" + non_zero_feat_main: str = "Prevalence Across Samples" + xlog = "log2_1p" + include_non_zero_sum: bool = True + + +@dataclass +class MinPrevalenceRowPlotterConfigForOTU(MinPrevalenceRowPlotterConfigForMetagenomics): + dataset_name: str = field(default="otu", init=False, repr=False) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")], + init=False, + repr=False, + ) + feature_main: str = "OTU Total Abundance Distribution" + 
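+    # With include_non_zero_sum enabled, the before/after abundance histograms
+    # are accompanied by per-sample richness and per-OTU prevalence panels.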
include_non_zero_sum: bool = True + + +@dataclass +class MinPrevalenceRowPlotterConfigForASV(MinPrevalenceRowPlotterConfigForMetagenomics): + dataset_name: str = field(default="asv", init=False, repr=False) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")], + init=False, + repr=False, + ) + include_non_zero_sum: bool = True + + +@dataclass +class MinPrevalenceRowPlotterConfigForGenomics(MinPrevalenceRowPlotterConfig): + dataset_name: str = field(default="genomics", init=False, repr=False) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + (get_feature("ReadCount"), get_feature("GenomicVariant")), + (get_feature("ReadCount"), get_feature("GenomicVariant")), + ], + init=False, + repr=False, + ) + + +@dataclass +class MinPrevalenceRowPlotterConfigForSNP(MinPrevalenceRowPlotterConfig): + dataset_name: str = field(default="snp", init=False, repr=False) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("GenomicVariant"), + get_feature("GenomicVariant"), + ], + init=False, + repr=False, + ) + + +@dataclass +class MinPrevalenceRowPlotterConfigForMaldi(MinPrevalenceRowPlotterConfig): + dataset_name: str = field(default="maldi", init=False, repr=False) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("PeakIntensity"), + get_feature("PeakIntensity"), + ], + init=False, + repr=False, + ) + feature_main: str = "Peak Total Intensity Distribution" + non_zero_samp_xlab: str = "Peak Richness" + non_zero_samp_main: str = "Richness Distribution" + non_zero_feat_xlab: str = "Peak Prevalence" + non_zero_feat_main: str = "Prevalence Across Samples" + xlog = "log2_1p" + include_non_zero_sum: bool = True + + +@dataclass +class MinPrevalenceRowPlotterConfigForReadCount(MinPrevalenceRowPlotterConfig): + dataset_name: str = field(default="read_count", init=False, repr=False) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("ReadCount"), get_feature("ReadCount")], + init=False, + repr=False, + ) + + +@dataclass +class MinPrevalenceRowPlotterConfigForProteomics(MinPrevalenceRowPlotterConfig): + dataset_name: str = field(default="proteomics", init=False, repr=False) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("PeakIntensity"), + get_feature("PeakIntensity"), + ], + init=False, + repr=False, + ) + + +class MinPrevalenceRowPlotter(SampleFilterPlotter): + config_class = MinPrevalenceRowPlotterConfig + config: MinPrevalenceRowPlotterConfig + + def __init__( + self, + sample_xlab: str = Unset('"Sum of Counts"'), + sample_main: str = Unset('"Sample Distribution"'), + feature_xlab: str = Unset('"Sum of Counts"'), + feature_main: str = Unset('"Feature Distribution"'), + legend_position: str = Unset('"top"'), + legend_title: str = Unset('"Min Prevalence Sample SampleFiltering"'), + before_name: str = Unset('"Before"'), + after_name: str = Unset('"After"'), + xlog: str = Unset("None"), + ylog: str = Unset("None"), + ncol: int = Unset("2"), + include_non_zero_sum: bool = Unset("False"), + non_zero_samp_xlab: str = Unset("None"), + non_zero_samp_main: str = Unset("None"), + non_zero_feat_xlab: str = Unset("None"), + non_zero_feat_main: str = Unset("None"), + config: MinPrevalenceRowPlotterConfig = None, + **kwargs, + ): + super().__init__( + config=config, + sample_xlab=sample_xlab, + sample_main=sample_main, + 
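+            # Arguments still holding their Unset sentinel mark values the
+            # caller did not supply explicitly (see biofit.utils.types.Unset).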
feature_xlab=feature_xlab, + feature_main=feature_main, + legend_position=legend_position, + legend_title=legend_title, + before_name=before_name, + after_name=after_name, + xlog=xlog, + ylog=ylog, + ncol=ncol, + include_non_zero_sum=include_non_zero_sum, + non_zero_samp_xlab=non_zero_samp_xlab, + non_zero_samp_main=non_zero_samp_main, + non_zero_feat_xlab=non_zero_feat_xlab, + non_zero_feat_main=non_zero_feat_main, + **kwargs, + ) + + def plot_pandas(self, x1, x2): + row_sums = [pa.array(x1.sum(axis=1)), pa.array(x2.sum(axis=1))] + col_sums = [pa.array(x1.sum(axis=0)), pa.array(x2.sum(axis=0))] + input = [row_sums, col_sums] + mains = [self.config.sample_main, self.config.feature_main] + xlabs = [self.config.sample_xlab, self.config.feature_xlab] + if self.config.include_non_zero_sum: + non_zero_sample_sums = [ + pa.array(x1[x1 > 0].count(axis=1)), + pa.array(x2[x2 > 0].count(axis=1)), + ] + non_zero_feature_sums = [ + pa.array(x1[x1 > 0].count(axis=0)), + pa.array(x2[x2 > 0].count(axis=0)), + ] + mains.extend( + [ + self.config.non_zero_samp_main, + self.config.non_zero_feat_main, + ] + ) + xlabs.extend( + [ + self.config.non_zero_samp_xlab, + self.config.non_zero_feat_xlab, + ] + ) + input.extend([non_zero_sample_sums, non_zero_feature_sums]) + + self.plotter( + list_of_sums=input, + xlabs=xlabs, + mains=mains, + **self.config.get_params(), + ) + + def plot( + self, + x1, + x2=None, + input_columns1: SelectedColumnTypes = None, + input_columns2: SelectedColumnTypes = None, + sample_xlab: str = Unset('"Sum of Counts"'), + sample_main: str = Unset('"Sample Distribution"'), + feature_xlab: str = Unset('"Sum of Counts"'), + feature_main: str = Unset('"Feature Distribution"'), + legend_position: str = Unset('"top"'), + legend_title: str = Unset('"Min Prevalence Sample SampleFiltering"'), + before_name: str = Unset('"Before"'), + after_name: str = Unset('"After"'), + xlog: str = Unset("None"), + ylog: str = Unset("None"), + ncol: int = Unset("2"), + include_non_zero_sum: bool = Unset("False"), + non_zero_samp_xlab: str = Unset("None"), + non_zero_samp_main: str = Unset("None"), + non_zero_feat_xlab: str = Unset("None"), + non_zero_feat_main: str = Unset("None"), + path: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show: bool = True, + ): + self.config._input_columns = self._set_input_columns_and_arity( + input_columns1, input_columns2 + ) + return self._plot( + x1, + x2, + sample_xlab=sample_xlab, + sample_main=sample_main, + feature_xlab=feature_xlab, + feature_main=feature_main, + legend_position=legend_position, + legend_title=legend_title, + before_name=before_name, + after_name=after_name, + xlog=xlog, + ylog=ylog, + ncol=ncol, + include_non_zero_sum=include_non_zero_sum, + non_zero_samp_xlab=non_zero_samp_xlab, + non_zero_samp_main=non_zero_samp_main, + non_zero_feat_xlab=non_zero_feat_xlab, + non_zero_feat_main=non_zero_feat_main, + path=path, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + show=show, + ) diff --git a/src/biofit/preprocessing/filtering/missing_labels/__init__.py b/src/biofit/preprocessing/filtering/missing_labels/__init__.py new file mode 100644 index 0000000..2e9f104 --- /dev/null +++ b/src/biofit/preprocessing/filtering/missing_labels/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa +from .missing_labels import MissingLabelsSampleFilter, MissingLabelsSampleFilterConfig diff --git 
a/src/biofit/preprocessing/filtering/missing_labels/missing_labels.py b/src/biofit/preprocessing/filtering/missing_labels/missing_labels.py
new file mode 100644
index 0000000..d8d3bd6
--- /dev/null
+++ b/src/biofit/preprocessing/filtering/missing_labels/missing_labels.py
@@ -0,0 +1,221 @@
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Optional, Type, Union
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pyarrow.compute as pc
+from biocore import DataHandler
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biofit.utils import logging
+
+from ..filtering import SampleFilter, SampleFilterConfig
+
+if TYPE_CHECKING:
+    import polars as pl
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class MissingLabelsSampleFilterConfig(SampleFilterConfig):
+    # process description
+    _transform_process_desc: str = field(
+        default="Filter out rows with missing labels", init=False, repr=False
+    )
+    processor_name: str = field(default="missing_labels", init=False, repr=False)
+    _fit_input_feature_types: List[Type] = field(
+        default_factory=lambda: [get_feature("TARGET_FEATURE_TYPES")],
+        init=False,
+        repr=False,
+    )
+    _transform_input_feature_types: List[Type] = field(
+        default_factory=lambda: [get_feature("TARGET_FEATURE_TYPES")],
+        init=False,
+        repr=False,
+    )
+
+    missing_label: Optional[Union[str, int]] = "auto"
+
+    def __post_init__(self):
+        self._transform_process_desc = (
+            f"Filtering out rows with labels equal to {self.missing_label}"
+        )
+
+
+class MissingLabelsSampleFilter(SampleFilter):
+    """Remove samples that are not labeled.
+
+    - config:
+        - missing_label: The value treated as missing; samples with this label are removed. Defaults to "auto".
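+
+    Example:
+        A minimal pandas sketch of the boolean keep-mask that the transform
+        builds (the labels below are made up; with ``missing_label="auto"``,
+        categorical targets fall back to ``-1`` as the missing code):
+
+            import pandas as pd
+
+            labels = pd.DataFrame({"target": [0, 1, -1, 2]})
+            keep = labels.ne(-1).values.tolist()  # [[True], [True], [False], [True]]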
+ """ + + # main config class + config_class = MissingLabelsSampleFilterConfig + config: MissingLabelsSampleFilterConfig + + def __init__( + self, + config: Optional[MissingLabelsSampleFilterConfig] = None, + missing_label: Optional[Union[str, int]] = "auto", + **kwargs, + ): + super().__init__(config=config, missing_label=missing_label, **kwargs) + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "MissingLabelsSampleFilter": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _process_transform_input(self, X, **kwargs): + selected_indices = 
kwargs["fn_kwargs"].get("selected_indices", None) or [0] + if self.config.missing_label == "auto": + if DataHandler.is_categorical(X, selected_indices[0], threshold=30): + self.config.missing_label = -1 + kwargs["desc"] = self._transform_process_desc = ( + "SampleFiltering out rows with labels equaling to -1" + ) + else: + self.config.missing_label = None + kwargs["desc"] = self._transform_process_desc = ( + "SampleFiltering out rows with labels equaling to None" + ) + return super()._process_transform_input(X, **kwargs) + + def _transform_arrow(self, X: Union[pa.Table, pa.Array]): + if isinstance(X, pa.Table): + return ( + pc.not_equal(X.column(0), self.config.missing_label).to_numpy().tolist() + ) + return pc.not_equal(X, self.config.missing_label).to_numpy().tolist() + + def _transform_pandas(self, X: pd.DataFrame): + return X.ne(self.config.missing_label).values.tolist() + + def _transform_polars(self, X: "pl.DataFrame"): + import polars as pl + + if isinstance(X, pl.Series): + return X.ne(self.config.missing_label).to_numpy().tolist() + return ( + X.get_column(X.columns[0]).ne(self.config.missing_label).to_numpy().tolist() + ) + + def _transform_numpy(self, X: np.ndarray): + return X != self.config.missing_label diff --git a/src/biofit/preprocessing/filtering/plot_filtering.R b/src/biofit/preprocessing/filtering/plot_filtering.R new file mode 100644 index 0000000..23e8a9b --- /dev/null +++ b/src/biofit/preprocessing/filtering/plot_filtering.R @@ -0,0 +1,69 @@ +source(file.path(R_SCRIPTS_PATH, "plotting_utils.R")) + + +plot_filter <- function( + list_of_sums, path, + xlabs, + mains, + legend_position = "top", + legend_title = "Filtering", + before_name = "Before", + after_name = "After", + xlog = NULL, + ylog = NULL, + ...) { + + + suppressPackageStartupMessages(require(ggplot2)) + suppressPackageStartupMessages(require(RColorBrewer)) + suppressPackageStartupMessages(require(circlize)) + suppressPackageStartupMessages(require(patchwork)) + + if (!is.list(list_of_sums) && !is.vector(list_of_sums)) { + list_of_sums <- list(list_of_sums) + } + + if (!is.list(xlabs) && !is.vector(xlabs)) { + xlabs <- list(xlabs) + } + + if (!is.list(mains) && !is.vector(mains)) { + mains <- list(mains) + } + + if (length(xlabs) != length(list_of_sums) && length(xlabs) == 1) { + xlabs <- rep(xlabs, length(list_of_sums)) + } + + if (length(mains) != length(list_of_sums) && length(mains) == 1) { + mains <- rep(mains, length(list_of_sums)) + } + plots <- NULL + for (i in seq_along(list_of_sums)) { + x1 <- as.vector(list_of_sums[[i]][[1]]) + x2 <- as.vector(list_of_sums[[i]][[2]]) + if (is.null(plots)) { + plots <- generate_comparison_histogram( + x1, x2, + xlab = xlabs[[i]], title = mains[[i]], xlog = xlog, ylog = ylog, + legend_title = legend_title, subplot_title1 = before_name, + subplot_title2 = after_name + ) + } else { + plots <- plots + generate_comparison_histogram( + x1, x2, + xlab = xlabs[[i]], title = mains[[i]], xlog = xlog, ylog = ylog, + legend_title = legend_title, subplot_title1 = before_name, + subplot_title2 = after_name + ) + } + } + + grid <- plots + patchwork::plot_layout(guides = "collect", ncol = ncol) & + ggplot2::theme( + text = ggplot2::element_text(size = 8), + legend.position = legend_position, + ) + + save_plots(path, plot = grid, width = 6, height = 6.5, dpi = 600) +} diff --git a/src/biofit/preprocessing/filtering/plot_filtering.py b/src/biofit/preprocessing/filtering/plot_filtering.py new file mode 100644 index 0000000..2bf2dff --- /dev/null +++ 
b/src/biofit/preprocessing/filtering/plot_filtering.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass, field +from pathlib import Path + +from biofit.visualization.plotting import BasePlotter, PlotterConfig + + +@dataclass +class SampleFilterPlotterConfig(PlotterConfig): + processor_type: str = field(default="filtering", init=False, repr=False) + r_source: str = field( + default=(Path(__file__).parent / "plot_filtering.R").as_posix(), + init=False, + repr=False, + ) + main_method: str = field(default="plot_filter", init=False, repr=False) + + +class SampleFilterPlotter(BasePlotter): + config_class = SampleFilterPlotterConfig + config: SampleFilterPlotterConfig diff --git a/src/biofit/preprocessing/filtering/row_abundance/__init__.py b/src/biofit/preprocessing/filtering/row_abundance/__init__.py new file mode 100644 index 0000000..5a7c699 --- /dev/null +++ b/src/biofit/preprocessing/filtering/row_abundance/__init__.py @@ -0,0 +1,11 @@ +# ruff: noqa +from .row_abundance import ( + AbundanceSampleFilterConfigForOTU, + AbundanceSampleFilterConfig, + AbundanceSampleFilter, +) +from .plot_row_abundance import ( + AbundanceSampleFilterPlotterConfigForOTU, + AbundanceSampleFilterPlotterConfig, + AbundanceSampleFilterPlotter, +) diff --git a/src/biofit/preprocessing/filtering/row_abundance/plot_row_abundance.py b/src/biofit/preprocessing/filtering/row_abundance/plot_row_abundance.py new file mode 100644 index 0000000..655622b --- /dev/null +++ b/src/biofit/preprocessing/filtering/row_abundance/plot_row_abundance.py @@ -0,0 +1,327 @@ +from dataclasses import dataclass, field +from typing import List, Optional, Type + +import pyarrow as pa + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes +from biofit.utils.types import Unset + +from ..plot_filtering import ( + SampleFilterPlotter, + SampleFilterPlotterConfig, +) + + +@dataclass +class AbundanceSampleFilterPlotterConfig(SampleFilterPlotterConfig): + plot_process_desc: str = field( + default="Plotting presence feature selection", init=False, repr=False + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("METADATA_FEATURE_TYPES"), + get_feature("METADATA_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("METADATA_FEATURE_TYPES"), + get_feature("METADATA_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _compare: bool = field(default=True, init=False, repr=False) + processor_name: str = field(default="row_abundance", init=False, repr=False) + + sample_xlab: str = "Sum of Counts" + sample_main: str = "Sample Distribution" + feature_xlab: str = "Sum of Counts" + feature_main: str = "Feature Distribution" + legend_position: str = "top" + legend_title: str = "Abundance SampleFiltering" + before_name: str = "Before" + after_name: str = "After" + xlog: str = None + ylog: str = None + ncol: int = 2 + include_non_zero_sum: bool = False + non_zero_samp_xlab: str = None + non_zero_samp_main: str = None + non_zero_feat_xlab: str = None + non_zero_feat_main: str = None + + +@dataclass +class AbundanceSampleFilterPlotterConfigForMetagenomics( + AbundanceSampleFilterPlotterConfig +): + dataset_name: str = field(default="metagenomics", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = 
field( + default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")], + init=False, + repr=False, + ) + + feature_main = "Taxa Total Abundance Distribution" + non_zero_samp_xlab: str = "Species Richness" + non_zero_samp_main: str = "Richness Distribution" + non_zero_feat_xlab: str = "Species Prevalence" + non_zero_feat_main: str = "Prevalence Across Samples" + xlog = "log2_1p" + include_non_zero_sum: bool = True + + +@dataclass +class AbundanceSampleFilterPlotterConfigForOTU( + AbundanceSampleFilterPlotterConfigForMetagenomics +): + dataset_name: str = field(default="otu", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")], + init=False, + repr=False, + ) + feature_main: str = "OTU Total Abundance Distribution" + include_non_zero_sum: bool = True + + +@dataclass +class AbundanceSampleFilterPlotterConfigForASV( + AbundanceSampleFilterPlotterConfigForMetagenomics +): + dataset_name: str = field(default="asv", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")], + init=False, + repr=False, + ) + include_non_zero_sum: bool = True + + +@dataclass +class AbundanceSampleFilterPlotterConfigForGenomics(AbundanceSampleFilterPlotterConfig): + dataset_name: str = field(default="genomics", init=False, repr=False) + _input_feature_types: List[Type] = field( + default_factory=lambda: [ + (get_feature("ReadCount"), get_feature("GenomicVariant")), + (get_feature("ReadCount"), get_feature("GenomicVariant")), + ], + init=False, + repr=False, + ) + + +@dataclass +class AbundanceSampleFilterPlotterConfigForSNP(AbundanceSampleFilterPlotterConfig): + dataset_name: str = field(default="snp", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("GenomicVariant"), + get_feature("GenomicVariant"), + ], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("GenomicVariant"), + get_feature("GenomicVariant"), + ], + init=False, + repr=False, + ) + + +@dataclass +class AbundanceSampleFilterPlotterConfigForReadCount( + AbundanceSampleFilterPlotterConfig +): + dataset_name: str = field(default="read_count", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("ReadCount"), get_feature("ReadCount")], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("ReadCount"), get_feature("ReadCount")], + init=False, + repr=False, + ) + + +@dataclass +class AbundanceSampleFilterPlotterConfigForProteomics( + AbundanceSampleFilterPlotterConfig +): + dataset_name: str = field(default="proteomics", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Expression"), get_feature("Expression")], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Expression"), get_feature("Expression")], + init=False, + 
repr=False, + ) + + +class AbundanceSampleFilterPlotter(SampleFilterPlotter): + config_class = AbundanceSampleFilterPlotterConfig + config: AbundanceSampleFilterPlotterConfig + + def __init__( + self, + sample_xlab: str = Unset('"Sum of Counts"'), + sample_main: str = Unset('"Sample Distribution"'), + feature_xlab: str = Unset('"Sum of Counts"'), + feature_main: str = Unset('"Feature Distribution"'), + legend_position: str = Unset('"top"'), + legend_title: str = Unset('"Sample Abundance SampleFiltering"'), + before_name: str = Unset('"Before"'), + after_name: str = Unset('"After"'), + xlog: str = Unset("None"), + ylog: str = Unset("None"), + ncol: int = Unset("2"), + include_non_zero_sum: bool = Unset("False"), + non_zero_samp_xlab: str = Unset("None"), + non_zero_samp_main: str = Unset("None"), + non_zero_feat_xlab: str = Unset("None"), + non_zero_feat_main: str = Unset("None"), + config: Optional[AbundanceSampleFilterPlotterConfig] = None, + **kwargs, + ): + super().__init__( + config=config, + sample_xlab=sample_xlab, + sample_main=sample_main, + feature_xlab=feature_xlab, + feature_main=feature_main, + legend_position=legend_position, + legend_title=legend_title, + before_name=before_name, + after_name=after_name, + xlog=xlog, + ylog=ylog, + ncol=ncol, + include_non_zero_sum=include_non_zero_sum, + non_zero_samp_xlab=non_zero_samp_xlab, + non_zero_samp_main=non_zero_samp_main, + non_zero_feat_xlab=non_zero_feat_xlab, + non_zero_feat_main=non_zero_feat_main, + **kwargs, + ) + + def plot_pandas(self, x1, x2): + row_sums = [pa.array(x1.sum(axis=1)), pa.array(x2.sum(axis=1))] + col_sums = [pa.array(x1.sum(axis=0)), pa.array(x2.sum(axis=0))] + input = [row_sums, col_sums] + mains = [self.config.sample_main, self.config.feature_main] + xlabs = [self.config.sample_xlab, self.config.feature_xlab] + if self.config.include_non_zero_sum: + non_zero_sample_sums = [ + pa.array(x1[x1 > 0].count(axis=1)), + pa.array(x2[x2 > 0].count(axis=1)), + ] + non_zero_feature_sums = [ + pa.array(x1[x1 > 0].count(axis=0)), + pa.array(x2[x2 > 0].count(axis=0)), + ] + mains.extend( + [ + self.config.non_zero_samp_main, + self.config.non_zero_feat_main, + ] + ) + xlabs.extend( + [ + self.config.non_zero_samp_xlab, + self.config.non_zero_feat_xlab, + ] + ) + input.extend([non_zero_sample_sums, non_zero_feature_sums]) + + self.plotter( + list_of_sums=input, + xlabs=xlabs, + mains=mains, + **self.config.get_params(), + ) + + def plot( + self, + x1, + x2=None, + input_columns1: SelectedColumnTypes = None, + input_columns2: SelectedColumnTypes = None, + sample_xlab: str = Unset('"Sum of Counts"'), + sample_main: str = Unset('"Sample Distribution"'), + feature_xlab: str = Unset('"Sum of Counts"'), + feature_main: str = Unset('"Feature Distribution"'), + legend_position: str = Unset('"top"'), + legend_title: str = Unset('"Sample Abundance SampleFiltering"'), + before_name: str = Unset('"Before"'), + after_name: str = Unset('"After"'), + xlog: str = Unset("None"), + ylog: str = Unset("None"), + ncol: int = Unset("2"), + include_non_zero_sum: bool = Unset("False"), + non_zero_samp_xlab: str = Unset("None"), + non_zero_samp_main: str = Unset("None"), + non_zero_feat_xlab: str = Unset("None"), + non_zero_feat_main: str = Unset("None"), + path: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show: bool = True, + ): + self.config._input_columns = self._set_input_columns_and_arity( + input_columns1, input_columns2 + ) + return 
self._plot( + x1, + x2, + sample_xlab=sample_xlab, + sample_main=sample_main, + feature_xlab=feature_xlab, + feature_main=feature_main, + legend_position=legend_position, + legend_title=legend_title, + before_name=before_name, + after_name=after_name, + xlog=xlog, + ylog=ylog, + ncol=ncol, + include_non_zero_sum=include_non_zero_sum, + non_zero_samp_xlab=non_zero_samp_xlab, + non_zero_samp_main=non_zero_samp_main, + non_zero_feat_xlab=non_zero_feat_xlab, + non_zero_feat_main=non_zero_feat_main, + path=path, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + show=show, + ) diff --git a/src/biofit/preprocessing/filtering/row_abundance/row_abundance.py b/src/biofit/preprocessing/filtering/row_abundance/row_abundance.py new file mode 100644 index 0000000..635f3ae --- /dev/null +++ b/src/biofit/preprocessing/filtering/row_abundance/row_abundance.py @@ -0,0 +1,312 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Type, Union + +import numpy as np +import pandas as pd + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes +from biofit.utils import logging + +from ..filtering import SampleFilter, SampleFilterConfig + +if TYPE_CHECKING: + import polars as pl + +logger = logging.get_logger(__name__) + + +@dataclass +class AbundanceSampleFilterConfig(SampleFilterConfig): + lower_threshold: Union[int, str] = "auto" + upper_threshold: Union[int, str] = None + processor_name: str = field(default="row_abundance", init=False, repr=False) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + + +@dataclass +class AbundanceSampleFilterConfigForOTU(AbundanceSampleFilterConfig): + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + dataset_name: str = field(default="otu", init=False, repr=False) + lower_threshold: Union[int, str] = "auto" + upper_threshold: Union[int, str] = "auto" + + def __post_init__(self): + self._transform_process_desc = f"SampleFiltering out samples with less than {self.lower_threshold} total otu abundance" + + +class AbundanceSampleFilter(SampleFilter): + config_class = AbundanceSampleFilterConfig + config: AbundanceSampleFilterConfig + + IQRs = None + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "AbundanceSampleFilter": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + 
fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _process_fit_input(self, input, **kwargs): + if ( + self.config.lower_threshold != "auto" + and self.config.upper_threshold != "auto" + ): + kwargs["fn_kwargs"]["fn"] = None + return super()._process_fit_input(input, **kwargs) + + def _fit_polars(self, X: "pl.DataFrame"): + iqr = [X.sum_horizontal().quantile(0.25), X.sum_horizontal().quantile(0.75)] + if self.config.lower_threshold == "auto": + self.config.lower_threshold = iqr[0] - 1.5 * (iqr[1] - iqr[0]) + if self.config.upper_threshold == "auto": + self.config.upper_threshold = iqr[1] + 1.5 * (iqr[1] - iqr[0]) + return self + + def _partial_fit_polars(self, X: "pl.DataFrame"): + iqr = [X.sum_horizontal().quantile(0.25), X.sum_horizontal().quantile(0.75)] + if self.IQRs is None: + self.IQRs = [[iqr[0]], [iqr[1]]] + else: + self.IQRs[0].append(iqr[0]) + self.IQRs[1].append(iqr[1]) + return self + + def _fit_pandas(self, X: "pd.DataFrame"): + iqr = [X.sum(axis=1).quantile(0.25), X.sum(axis=1).quantile(0.75)] + if self.config.lower_threshold == "auto": + self.config.lower_threshold = iqr[0] - 1.5 * (iqr[1] - iqr[0]) + if self.config.upper_threshold == "auto": + self.config.upper_threshold = iqr[1] + 1.5 * (iqr[1] - iqr[0]) + return self + + def _partial_fit_pandas(self, X: 
"pd.DataFrame"): + iqr = [X.sum(axis=1).quantile(0.25), X.sum(axis=1).quantile(0.75)] + if self.IQRs is None: + self.IQRs = [[iqr[0]], [iqr[1]]] + else: + self.IQRs[0].append(iqr[0]) + self.IQRs[1].append(iqr[1]) + return self + + def _fit_numpy(self, X: np.ndarray): + iqr = [np.quantile(X.sum(axis=1), 0.25), np.quantile(X.sum(axis=1), 0.75)] + if self.config.lower_threshold == "auto": + self.config.lower_threshold = iqr[0] - 1.5 * (iqr[1] - iqr[0]) + if self.config.upper_threshold == "auto": + self.config.upper_threshold = iqr[1] + 1.5 * (iqr[1] - iqr[0]) + return self + + def _partial_fit_numpy(self, X: np.ndarray): + iqr = [np.quantile(X.sum(axis=1), 0.25), np.quantile(X.sum(axis=1), 0.75)] + if self.IQRs is None: + self.IQRs = [[iqr[0]], [iqr[1]]] + else: + self.IQRs[0].append(iqr[0]) + self.IQRs[1].append(iqr[1]) + return self + + def _pool_fit_any(self, partial_results: List["AbundanceSampleFilter"]): + IQRs = [[], []] + for result in partial_results: + IQRs[0].extend(result.IQRs[0]) + IQRs[1].extend(result.IQRs[1]) + IQRs[0] = np.mean(IQRs[0]) + IQRs[1] = np.mean(IQRs[1]) + if self.config.lower_threshold == "auto": + self.config.lower_threshold = IQRs[0] - 1.5 * (IQRs[1] - IQRs[0]) + if self.config.upper_threshold == "auto": + self.config.upper_threshold = IQRs[1] + 1.5 * (IQRs[1] - IQRs[0]) + return self + + def _process_fit_output(self, input, out): + if out.config.lower_threshold and out.config.upper_threshold: + logger.info( + f"Using {out.config.lower_threshold} as lower threshold and " + f"{out.config.upper_threshold} as upper threshold for filtering samples" + ) + elif out.config.lower_threshold and out.config.upper_threshold: + logger.info( + f"Using {out.config.lower_threshold} as lower threshold for filtering samples" + ) + elif out.config.upper_threshold: + logger.info( + f"Using {out.config.upper_threshold} as upper threshold for filtering samples" + ) + return super()._process_fit_output(input, out) + + def _process_transform_input(self, input, **kwargs): + if self.config.lower_threshold and self.config.upper_threshold: + kwargs["desc"] = ( + f"SampleFiltering out examples with less than {self.config.lower_threshold:.2f} and more than {self.config.upper_threshold:.2f} total abundance" + ) + elif self.config.lower_threshold: + kwargs["desc"] = ( + f"SampleFiltering out examples with less than {self.config.lower_threshold:.2f} total abundance" + ) + elif self.config.upper_threshold: + kwargs["desc"] = ( + f"SampleFiltering out examples with more than {self.config.upper_threshold:.2f} total abundance" + ) + + return super()._process_transform_input(input, **kwargs) + + def _transform_polars(self, X: "pl.DataFrame"): + row_sum = X.sum_horizontal() + if self.config.upper_threshold and self.config.lower_threshold: + return (row_sum > self.config.lower_threshold) & ( + row_sum < self.config.upper_threshold + ) + elif self.config.upper_threshold: + return row_sum < self.config.upper_threshold + elif self.config.lower_threshold: + return row_sum > self.config.lower_threshold + + def _transform_pandas(self, X: "pd.DataFrame"): + row_sum = X.sum(axis=1) + if self.config.upper_threshold and self.config.lower_threshold: + return (row_sum > self.config.lower_threshold) & ( + row_sum < self.config.upper_threshold + ) + elif self.config.upper_threshold: + return row_sum < self.config.upper_threshold + elif self.config.lower_threshold: + return row_sum > self.config.lower_threshold + + def _transform_numpy(self, X: np.ndarray): + row_sum = X.sum(axis=1) + if self.config.upper_threshold 
and self.config.lower_threshold: + return (row_sum > self.config.lower_threshold) & ( + row_sum < self.config.upper_threshold + ) + elif self.config.upper_threshold: + return row_sum < self.config.upper_threshold + elif self.config.lower_threshold: + return row_sum > self.config.lower_threshold diff --git a/src/biofit/preprocessing/imputation/__init__.py b/src/biofit/preprocessing/imputation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/biofit/preprocessing/imputation/imputation.py b/src/biofit/preprocessing/imputation/imputation.py new file mode 100644 index 0000000..46ad684 --- /dev/null +++ b/src/biofit/preprocessing/imputation/imputation.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass, field + +from biofit.processing import BaseProcessor, ProcessorConfig + + +@dataclass +class ImputationConfig(ProcessorConfig): + """Base class for imputation processor configurations.""" + + processor_name: str = field(default="imputation", init=False, repr=False) + + +class Imputation(BaseProcessor): + """Base class for imputation processors.""" diff --git a/src/biofit/preprocessing/resampling/__init__.py b/src/biofit/preprocessing/resampling/__init__.py new file mode 100644 index 0000000..72aaaf1 --- /dev/null +++ b/src/biofit/preprocessing/resampling/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa +from .upsampling import * diff --git a/src/biofit/preprocessing/resampling/resampling.py b/src/biofit/preprocessing/resampling/resampling.py new file mode 100644 index 0000000..a54da1a --- /dev/null +++ b/src/biofit/preprocessing/resampling/resampling.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass, field + +import pandas as pd +import pyarrow as pa +from biocore import DataHandler + +from biofit.processing import BaseProcessor, ProcessorConfig +from biofit.utils import logging +from biofit.utils.table_util import string_to_arrow + +logger = logging.get_logger(__name__) + + +@dataclass +class ResamplerConfig(ProcessorConfig): + processor_type: str = field(default="resampling", init=False, repr=False) + + +class Resampler(BaseProcessor): + _feature_dependent = ( + False # SampleFiltering does not depend on features when transforming + ) + + def _process_transform_batch_output(self, input, out, **fn_kwargs): + inds = out["indices"] + if len(inds): + return DataHandler.select_rows(input, inds) + else: + cols = DataHandler.get_column_names(input, generate_cols=True) + schema = pa.schema( + [ + pa.field(k, string_to_arrow(v)) + for k, v in DataHandler.get_dtypes(input).items() + ] + ) + return pa.Table.from_pandas( + pd.DataFrame(columns=cols), preserve_index=False, schema=schema + ) + + def _process_transform_output(self, output, *args, **kwargs): + init_num_rows = DataHandler.get_shape(args[0])[0] + final_num_rows = DataHandler.get_shape(output)[0] + if final_num_rows != init_num_rows: + logger.info( + f'"{self.__class__.__name__}": Resampled to {final_num_rows} from ' + f"{init_num_rows} samples" + ) + else: + logger.info(f'"{self.__class__.__name__}": No resampling performed') + return output + + pass diff --git a/src/biofit/preprocessing/resampling/upsampling/__init__.py b/src/biofit/preprocessing/resampling/upsampling/__init__.py new file mode 100644 index 0000000..f61112e --- /dev/null +++ b/src/biofit/preprocessing/resampling/upsampling/__init__.py @@ -0,0 +1,7 @@ +# ruff: noqa +from .upsampling import ( + UpSamplerConfig, + UpSampler, + UpSamplerConfigForOTU, + UpSamplerConfigForSNP, +) diff --git a/src/biofit/preprocessing/resampling/upsampling/upsampling.py 
b/src/biofit/preprocessing/resampling/upsampling/upsampling.py new file mode 100644 index 0000000..cf06d0f --- /dev/null +++ b/src/biofit/preprocessing/resampling/upsampling/upsampling.py @@ -0,0 +1,479 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Type, Union + +import numpy as np +from biocore import DataHandler +from biocore.utils.import_util import is_imblearn_available, requires_backends + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.utils import logging + +from ..resampling import Resampler, ResamplerConfig + +if TYPE_CHECKING: + from sklearn.base import BaseEstimator + + from biofit.processing import BaseProcessor + +logger = logging.get_logger(__name__) + +OVER_SAMPLING_METHODS = {} + +if is_imblearn_available(): + from imblearn.over_sampling import ( + ADASYN, + SMOTE, + SMOTEN, + SMOTENC, + SVMSMOTE, + KMeansSMOTE, + RandomOverSampler, + ) + + OVER_SAMPLING_METHODS = { + "random": RandomOverSampler, + "smote": SMOTE, + "smoten": SMOTEN, + "smotenc": SMOTENC, + "svmsmote": SVMSMOTE, + "kmeanssmote": KMeansSMOTE, + "adasyn": ADASYN, + } + + +@dataclass +class UpSamplerConfig(ResamplerConfig): + # process descriptions + processor_name: str = field(default="upsampling", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")], + init=False, + repr=False, + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + + input_columns: List[str] = field(default=None, init=False, repr=False) + target_column: str = field(default=None, init=False, repr=False) + method: str = "random" + + sampling_strategy: str = "auto" + random_state: int = None + + # RandomOverSampler specific attributes + shrinkage: dict = None + + # BaseSMOTE attributes + k_neighbors = 5 + n_jobs = None + + # SMOTEN specific attributes + categorical_encoder: Union["BaseEstimator", "BaseProcessor"] = None + + # SMOTENC specific attributes + categorical_features: list = None + + # SVMSMOTE specific attributes + m_neighbors: int = 10 + svm_estimator: Union["BaseEstimator", "BaseProcessor"] = None + out_step: float = 0.5 + + # KMeansSMOTE specific attributes + kmeans_estimator: Union[int, "BaseEstimator", "BaseProcessor"] = None + density_exponent: Union[float, str] = "auto" + cluster_balance_threshold: Union[float, str] = "auto" + + # ADASYN specific attributes + n_neighbors: int = 5 + + def __post_init__(self): + requires_backends("upsampling", "imblearn") + if self.random_state is None: + self.random_state = np.random.randint(0, np.iinfo(np.int32).max) + if self.method not in OVER_SAMPLING_METHODS: + raise ValueError( + f"Invalid method {self.method}. 
Valid methods are {list(OVER_SAMPLING_METHODS.keys())}" + ) + + self.sampler_kwargs_ = {} + self.sampler_kwargs_["sampling_strategy"] = self.sampling_strategy + self.sampler_kwargs_["random_state"] = self.random_state + + if self.method == "random": + self.sampler_kwargs_["sampling_strategy"] = self.sampling_strategy + + if self.method in [ + "smote", + "smoten", + "smotenc", + "svmsmote", + "kmeanssmote", + "adasyn", + ]: + self.sampler_kwargs_["n_jobs"] = self.n_jobs + + if self.method in ["smote", "smoten", "smotenc", "svmsmote", "kmeanssmote"]: + self.sampler_kwargs_["k_neighbors"] = self.k_neighbors + + if self.method in ["smoten", "smotenc"]: + self.sampler_kwargs_["categorical_encoder"] = self.categorical_encoder + + if self.method == "smotenc": + self.sampler_kwargs_["categorical_features"] = self.categorical_features + + if self.method == "svmsmote": + self.sampler_kwargs_["m_neighbors"] = self.m_neighbors + self.sampler_kwargs_["svm_estimator"] = self.svm_estimator + self.sampler_kwargs_["out_step"] = self.out_step + + if self.method == "kmeanssmote": + self.sampler_kwargs_["kmeans_estimator"] = self.kmeans_estimator + self.sampler_kwargs_["density_exponent"] = self.density_exponent + self.sampler_kwargs_["cluster_balance_threshold"] = ( + self.cluster_balance_threshold + ) + + if self.method == "adasyn": + self.sampler_kwargs_["n_neighbors"] = self.n_neighbors + + +@dataclass +class UpSamplerConfigForMetagenomics(UpSamplerConfig): + # dataset specific attributes + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [ + (get_feature("Abundance"), get_feature("ReadCount")), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))], + init=False, + repr=False, + ) + dataset_name: str = field(default="metagenomics", init=False, repr=False) + + +@dataclass +class UpSamplerConfigForOTU(UpSamplerConfig): + # dataset specific attributes + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("Abundance"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + dataset_name: str = field(default="otu", init=False, repr=False) + + +@dataclass +class UpSamplerConfigForSNP(UpSamplerConfig): + # dataset specific attributes + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("GenomicVariant"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False + ) + dataset_name: str = field(default="snp", init=False, repr=False) + + +class UpSampler(Resampler): + output_dtype = "float32" + + # config class + config_class = UpSamplerConfig + config: UpSamplerConfig + + def __init__( + self, + method: str = "random", + sampling_strategy: str = "auto", + random_state: int = None, + shrinkage: dict = None, + k_neighbors=5, + n_jobs=None, + categorical_encoder: Union["BaseEstimator", "BaseProcessor"] = None, + categorical_features: list = None, + m_neighbors: int = 10, + svm_estimator: Union["BaseEstimator", "BaseProcessor"] = None, + out_step: float = 0.5, + kmeans_estimator: Union[int, "BaseEstimator", "BaseProcessor"] = None, + density_exponent: Union[float, 
str] = "auto", + cluster_balance_threshold: Union[float, str] = "auto", + n_neighbors: int = 5, + config: Union[ + UpSamplerConfig, + UpSamplerConfigForOTU, + UpSamplerConfigForSNP, + UpSamplerConfigForMetagenomics, + ] = None, + **kwargs, + ): + """ + Upsample the minority class(es) in the dataset. + + Args: + method (str, 'random'): + The method to use for oversampling. Possible options are: + - 'random': Random Over Sampler + - 'smote': Synthetic Minority Over-sampling Technique + - 'smoten': SMOTE for numerical features only + - 'smotenc': SMOTE for numerical and categorical features + - 'svmsmote': SVM-SMOTE + - 'kmeanssmote': KMeans-SMOTE + - 'adasyn': Adaptive Synthetic Sampling Approach for Imbalanced Learning + + sampling_strategy (str, 'auto'): + The sampling strategy to use for oversampling. Possible options are: + - 'auto': Automatically resample the minority class(es) to the majority class(es) size. + - 'all': Resample all classes to the same size. + + random_state (int, *optional*): + The random state to use for sampling. + + shrinkage (dict, *optional*): + The shrinkage parameter for RandomOverSampler. + + k_neighbors (int, 5): + The number of nearest neighbors to use for SMOTE, SMOTEN, SMOTENC, SVMSMOTE, and KMeansSMOTE. + + n_jobs (int, *optional*): + The number of jobs to use for SMOTE, SMOTEN, SMOTENC, SVMSMOTE, and KMeansSMOTE. + + categorical_encoder (Union[BaseEstimator, BaseProcessor], *optional*): + The encoder to use for SMOTEN. + + categorical_features (list, *optional*): + The list of categorical features to use for SMOTENC. + + m_neighbors (int, 10): + The number of nearest neighbors to use for SVMSMOTE. + + svm_estimator (Union[BaseEstimator, BaseProcessor], *optional*): + The estimator to use for SVMSMOTE. + + out_step (float, 0.5): + The outlier step to use for SVMSMOTE. + + kmeans_estimator (Union[int, BaseEstimator, BaseProcessor], *optional*): + The estimator to use for KMeansSMOTE. + + density_exponent (Union[float, str], 'auto'): + The exponent to use for KMeansSMOTE. + + cluster_balance_threshold (Union[float, str], 'auto'): + The balance threshold to use for KMeansSMOTE. + + n_neighbors (int, 5): + The number of nearest neighbors to use for ADASYN. + + config (Union[UpSamplerConfig, UpSamplerConfigForOTU, UpSamplerConfigForSNP, UpSamplerConfigForMetagenomics], *optional*): + The configuration to use for the upsampling process. If provided, the other arguments are ignored. + Use set_params to update the configuration after initialization. 
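+                For example (a minimal sketch; ``X`` and ``y`` stand in for a
+                feature table and its imbalanced target column):
+
+                    up = UpSampler(method="random", random_state=0)
+                    up.set_params(sampling_strategy="all")
+                    balanced = up.fit_transform(X, y)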
+ **kwargs: + Additional keyword arguments to be passed to ProcessorConfig + """ + super().__init__( + config=config, + method=method, + sampling_strategy=sampling_strategy, + random_state=random_state, + shrinkage=shrinkage, + k_neighbors=k_neighbors, + n_jobs=n_jobs, + categorical_encoder=categorical_encoder, + categorical_features=categorical_features, + m_neighbors=m_neighbors, + svm_estimator=svm_estimator, + out_step=out_step, + kmeans_estimator=kmeans_estimator, + density_exponent=density_exponent, + cluster_balance_threshold=cluster_balance_threshold, + n_neighbors=n_neighbors, + **kwargs, + ) + + self.over_sampler = OVER_SAMPLING_METHODS[self.config.method]( + **self.config.sampler_kwargs_ + ) + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + self.over_sampler = OVER_SAMPLING_METHODS[self.config.method]( + **self.config.sampler_kwargs_ + ) + return self + + def fit( + self, + X, + y=None, + input_columns: SelectedColumnTypes = None, + target_column: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "UpSampler": + self.config._input_columns = self._set_input_columns_and_arity( + input_columns, target_column + ) + return self._process_fit( + X, + y, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + y=None, + input_columns: SelectedColumnTypes = None, + target_column: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = True, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + y, + input_columns=input_columns, + target_column=target_column, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + 
load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _fit_pandas(self, X, y): + self.over_sampler.fit_resample(X, DataHandler.select_column(y, 0)) + self.config.sample_indices = self.over_sampler.sample_indices_ + return self + + def _fit_numpy(self, X, y): + self.over_sampler.fit_resample(X, DataHandler.select_column(y, 0)) + self.config.sample_indices = self.over_sampler.sample_indices_ + return self + + def _process_transform_input(self, X, **kwargs): + batch_size = kwargs.get("batch_size", None) + if batch_size is not None and batch_size < DataHandler.get_shape(X)[0]: + logger.warning( + "Upsampling does not support batch processing. Ignoring batched and batch_size parameters." + ) + kwargs["batch_size"] = None + self.config.fingerprint = kwargs.get("new_fingerprint", None) + kwargs["features"] = None + return X, kwargs + + def _transform_any(self, X): + return {"indices": self.config.sample_indices.flatten()} diff --git a/src/biofit/preprocessing/scaling/__init__.py b/src/biofit/preprocessing/scaling/__init__.py new file mode 100644 index 0000000..f5808d2 --- /dev/null +++ b/src/biofit/preprocessing/scaling/__init__.py @@ -0,0 +1,9 @@ +# ruff: noqa +from .tmm import * +from .css import * +from .clr import * +from .relative_abundance import * +from .plot_scaling import ( + ScalerPlotter, + ScalerPlotterConfig, +) diff --git a/src/biofit/preprocessing/scaling/clr/__init__.py b/src/biofit/preprocessing/scaling/clr/__init__.py new file mode 100644 index 0000000..11c1e56 --- /dev/null +++ b/src/biofit/preprocessing/scaling/clr/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa +from .clr import CLRScaler, CLRScalerConfig diff --git a/src/biofit/preprocessing/scaling/clr/clr.py b/src/biofit/preprocessing/scaling/clr/clr.py new file mode 100644 index 0000000..bd39977 --- /dev/null +++ b/src/biofit/preprocessing/scaling/clr/clr.py @@ -0,0 +1,187 @@ +from dataclasses import dataclass, field +from typing import List, Type + +import numpy as np + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes +from biofit.utils import logging + +from ..scaling import Scaler, ScalerConfig + +logger = logging.get_logger(__name__) + + +@dataclass +class CLRScalerConfig(ScalerConfig): + """ + Configuration for CLRScaler.
+ """ + + processor_name: str = field(default="clr", init=False, repr=False) + _fit_process_desc: str = field( + default="Preparing centered log-ratio (CLR) scaling", init=False, repr=False + ) + _transform_process_desc: str = field( + default="Applying centered log-ratio (CLR) scaling", init=False, repr=False + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + + pseudocount: float = 1e-6 + input_columns: List[str] = None + + +class CLRScaler(Scaler): + """ + CLRScaler applies the centered log-ratio (CLR) transformation to the input data. + """ + + output_dtype = "float64" + config_class = CLRScalerConfig + config: CLRScalerConfig + + def __init__( + self, pseudocount: float = 1e-6, config: CLRScalerConfig = None, **kwargs + ): + super().__init__(config=config, pseudocount=pseudocount, **kwargs) + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "CLRScaler": + self.config._input_columns = self._set_input_columns_and_arity( + input_columns or self.config.input_columns + ) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity( + input_columns or self.config.input_columns + ) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, +
load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _transform_numpy(self, X: np.ndarray): + X = X + self.config.pseudocount + # Calculating the geometric mean of each sample + from scipy.stats import gmean + + r_geo_mean = gmean(X, axis=1)[:, np.newaxis] + + return np.log(X / r_geo_mean) diff --git a/src/biofit/preprocessing/scaling/css/__init__.py b/src/biofit/preprocessing/scaling/css/__init__.py new file mode 100644 index 0000000..1a32968 --- /dev/null +++ b/src/biofit/preprocessing/scaling/css/__init__.py @@ -0,0 +1,19 @@ +# ruff: noqa +from .css import ( + CumulativeSumScaler, + CumulativeSumScalerConfig, + CumulativeSumScalerConfigForOTU, + CumulativeSumScalerConfigForMetagenomics, + # CumulativeSumScalerConfigForGenomics, + # CumulativeSumScalerConfigForSNP, +) +from .plot_css import ( + CumulativeSumScalerPlotter, + CumulativeSumScalerPlotterConfig, + CumulativeSumScalerPlotterConfigForMetagenomics, + CumulativeSumScalerPlotterConfigForOTU, + CumulativeSumScalerPlotterConfigForASV, + CumulativeSumScalerPlotterConfigForGenomics, + CumulativeSumScalerPlotterConfigForSNP, + CumulativeSumScalerPlotterConfigForReadCount, +) diff --git a/src/biofit/preprocessing/scaling/css/css.py b/src/biofit/preprocessing/scaling/css/css.py new file mode 100644 index 0000000..19abdea --- /dev/null +++ b/src/biofit/preprocessing/scaling/css/css.py @@ -0,0 +1,348 @@ +from dataclasses import dataclass, field +from typing import List, Type + +import numpy as np +from scipy.interpolate import interp1d + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.utils import logging + +from ..scaling import Scaler, ScalerConfig + +logger = logging.get_logger(__name__) + + +def css_percentile_fast(mat, relative_threshold=0.1): + """Calculates the percentile for which to sum counts up to and scale by. 
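The CLR step in _transform_numpy of clr.py above reduces to dividing each sample by its geometric mean and taking the log. A small numpy check of that identity (the pseudocount and data are illustrative), in which every transformed row sums to roughly zero:

import numpy as np
from scipy.stats import gmean

counts = np.array([[10.0, 1.0, 0.0], [5.0, 5.0, 20.0]])
pseudocount = 1e-6

shifted = counts + pseudocount                      # avoid log(0), as in the config default
clr = np.log(shifted / gmean(shifted, axis=1)[:, np.newaxis])

print(np.allclose(clr.sum(axis=1), 0.0))            # True: CLR rows are centered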
+ + Args: + + """ + # Check for sample with one or zero features + if np.any(np.sum(mat > 0, axis=1) <= 1): + raise ValueError("Warning: sample with one or zero features") + + mat = mat.astype(float) + # Sort each row with elements greater than zero + smat = np.sort(mat * (mat > 0), axis=1)[:, ::-1] + smat[smat == 0] = np.nan + + # Pad with NAs to make all rows the same length + max_length = np.max(np.nansum(smat > 0, axis=1)) + padded_smat = np.full((smat.shape[0], max_length), np.nan) + for i in range(smat.shape[0]): + valid_values = smat[i, smat[i] > 0] + padded_smat[i, : len(valid_values)] = valid_values + + # Calculate quantiles for each row + quantiles = np.nanquantile( + padded_smat, np.linspace(0, 1, padded_smat.shape[1]), axis=1 + ) + + padded_smat[np.isnan(padded_smat)] = 0 + + # Compute column means, ignoring NaNs + ref1 = np.nanmean(padded_smat, axis=0)[::-1] + + # Calculate differences + diffr = ref1[:, np.newaxis] - quantiles + + # Calculate median of absolute differences + diffr1 = np.nanmedian(np.abs(diffr), axis=1) + + diffr1[diffr1 == 0] = np.nan + + # Determine threshold + rel_diff = np.abs(np.diff(diffr1)) / diffr1[:-1] + + x = (np.where(rel_diff > relative_threshold)[0][0] + 1) / len(diffr1) + if x <= 0.5: + logger.info( + f"Percentile calculated at {x}, which is less than 0.5. Using default value instead." + ) + x = 0.5 + return x + + +def css_percentile(mat: np.ndarray, approx=False, relative_threshold=0.1): + if np.any(np.sum(mat, axis=1) == 0): + raise ValueError("Warning: empty feature", mat) + + # Sorting each row + smat = np.sort(mat, axis=1) + ref = np.mean(smat, axis=0) + orig_dtype = mat.dtype + mat = mat.astype(float) if not np.issubdtype(orig_dtype, np.floating) else mat + + if not mat.flags["WRITEABLE"]: + # copy the array if it is not writeable + mat = np.array(mat) + mat[mat == 0] = np.nan + + refS = np.sort(ref) + + k = np.where(refS > 0)[0][0] + lo = len(refS) - k + + if not approx: + diffr = np.apply_along_axis( + lambda row: refS[k:] - np.nanquantile(row, np.linspace(0, 1, lo), axis=0), + axis=1, + arr=mat, + ) + else: + + def f(row): + srow = np.sort(row) + rrow = (srow[0], np.nanmax(srow)) + y = np.arange(len(row)) + return refS[k:] - interp1d(row, y)(np.linspace(*rrow, lo)) + + diffr = np.apply_along_axis(f, axis=1, arr=mat) + + mat[np.isnan(mat)] = 0 + mat = mat.astype(orig_dtype) if not np.issubdtype(orig_dtype, np.floating) else mat + + diffr2 = np.nanmedian(np.abs(diffr), axis=0) + + rel_diff = np.abs(np.diff(diffr2)) / diffr2[:-1] + if len(rel_diff) == 0: + return 0.5 + x = (np.where(rel_diff > relative_threshold)[0][0] + 1) / len(diffr2) + if x <= 0.5: + logger.info( + f"Percentile calculated at {x}, which is less than 0.5. Using default value instead." + ) + x = 0.5 + + return x + + +@dataclass +class CumulativeSumScalerConfig(ScalerConfig): + """ + Configuration for CumulativeSumScaler. 
+ """ + + processor_name: str = field(default="css", init=False, repr=False) + _fit_process_desc: str = field( + default="Calculating cumulative sum scaling percentile", init=False, repr=False + ) + _transform_process_desc: str = field( + default="Applying cumulative sum scaling", init=False, repr=False + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + + scale: int = 1000 + relative_threshold: float = 0.5 + approx: bool = False + percentile: float = None + + +@dataclass +class CumulativeSumScalerConfigForMetagenomics(CumulativeSumScalerConfig): + """ + CumulativeSumScaler specifically designed for metagenomics data. + """ + + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [(get_feature("ReadCount"), get_feature("Abundance"))], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [(get_feature("ReadCount"), get_feature("Abundance"))], + init=False, + repr=False, + ) + dataset_name: str = field(default="metagenomics", init=False, repr=False) + + +@dataclass +class CumulativeSumScalerConfigForOTU(CumulativeSumScalerConfig): + """ + CumulativeSumScaler specifically designed for otu abundance. + + Args: + """ + + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + dataset_name: str = field(default="otu", init=False, repr=False) + + +class CumulativeSumScaler(Scaler): + """ + BaseCumSumScaler applies cumulative sum scaling to the input data. + """ + + output_dtype = "float64" + config_class = CumulativeSumScalerConfig + config: CumulativeSumScalerConfig + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + if self.config.percentile is not None: + # there is overhead when passing datasets to the fit method + # (e.g. converting one format to another, preparing mapping jobs, etc.) + # which is unnecessary if the percentile is already known + # Therefore, we remove the fit method to avoid it being called + # TODO: This is a temporary solution. 
We should find a better way to handle this + delattr(self, "fit_numpy") + return self + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "CumulativeSumScaler": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _fit_numpy(self, X: np.ndarray): + # The fast implementation of CSS requires counts of all samples to at least + # have two non zero features. 
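For intuition, a compact numpy restatement of the CSS normalization applied in _transform_numpy further below (the epsilon guard is dropped, and the percentile/scale values here are illustrative rather than fitted):

import numpy as np

def css_normalize(X, percentile=0.5, scale=1000):
    # Per sample: take the quantile of the non-zero counts, sum the counts at
    # or below it, and divide the row by that sum rescaled by `scale`.
    xx = np.where(X == 0, np.nan, X.astype(float))
    qs = np.nanquantile(xx, percentile, axis=1, keepdims=True)
    norm = np.nansum(np.where(xx <= qs, xx, 0.0), axis=1) / scale
    return X / norm[:, np.newaxis]

counts = np.array([[10.0, 0.0, 5.0, 1.0], [100.0, 50.0, 0.0, 2.0]])
print(css_normalize(counts))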
+ if np.any(np.sum(X > 0, axis=1) <= 1): + self.config.percentile = css_percentile(X) + else: + self.config.percentile = css_percentile_fast(X) + return self + + def _transform_numpy(self, X: np.ndarray): + xx = np.where(X == 0, np.nan, X) + + # Compute row quantiles + qs = np.nanquantile(xx, self.config.percentile, axis=1, keepdims=True) + + xx_adj = xx - np.finfo(float).eps + norm_factors = np.nansum(np.where(xx_adj <= qs, xx_adj, 0), axis=1) + norm_factors[norm_factors == 0] = np.nan + norm_factors = norm_factors / self.config.scale + + return X / norm_factors[:, np.newaxis] diff --git a/src/biofit/preprocessing/scaling/css/plot_css.py b/src/biofit/preprocessing/scaling/css/plot_css.py new file mode 100644 index 0000000..49e3013 --- /dev/null +++ b/src/biofit/preprocessing/scaling/css/plot_css.py @@ -0,0 +1,297 @@ +from dataclasses import dataclass, field +from typing import List, Optional, Type + +from biocore import DataHandler + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes +from biocore.utils.py_util import is_bioset +from biofit.utils.types import Unset + +from ..plot_scaling import ScalerPlotter, ScalerPlotterConfig + + +@dataclass +class CumulativeSumScalerPlotterConfig(ScalerPlotterConfig): + processor_name: str = field(default="css", init=False, repr=False) + _compare: bool = field(default=True, init=False, repr=False) + _transoform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + None, + None, + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("METADATA_FEATURE_TYPES"), + get_feature("METADATA_FEATURE_TYPES"), + None, + None, + ], + init=False, + repr=False, + ) + + input_columns1: Optional[str] = None + input_columns2: Optional[str] = None + label_name1: Optional[str] = None + label_name2: Optional[str] = None + ylab: Optional[str] = None + xlab: Optional[str] = None + ylim: Optional[list] = None + before_title: str = "Before CSS" + after_title: str = "After CSS" + legend_position: str = "top" + add_box: bool = True + horizontal_plot: bool = False + order: bool = False + col_set: str = "Set1" + cols: Optional[str] = None + log_num: Optional[str] = None + show_outliers: bool = True + + +@dataclass +class CumulativeSumScalerPlotterConfigForMetagenomics(CumulativeSumScalerPlotterConfig): + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + (get_feature("Abundance"), get_feature("ReadCount")), + (get_feature("Abundance"), get_feature("ReadCount")), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + dataset_name: str = field(default="metagenomics", init=False, repr=False) + log_num: Optional[str] = "log2_1p" + + +@dataclass +class CumulativeSumScalerPlotterConfigForOTU(CumulativeSumScalerPlotterConfig): + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("Abundance"), + get_feature("Abundance"), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + dataset_name: str = field(default="otu", init=False, repr=False) + ylab: Optional[str] = "OTU Abundance" + log_num: Optional[str] = "log10_1p" + + +@dataclass +class CumulativeSumScalerPlotterConfigForASV(CumulativeSumScalerPlotterConfig): + _transform_input_feature_types: List[Type] = field( + 
default_factory=lambda: [ + get_feature("Abundance"), + get_feature("Abundance"), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + dataset_name: str = field(default="asv", init=False, repr=False) + + +@dataclass +class CumulativeSumScalerPlotterConfigForGenomics(CumulativeSumScalerPlotterConfig): + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("GenomicVariant"), + get_feature("GenomicVariant"), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("GenomicVariant"), + get_feature("GenomicVariant"), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + dataset_name: str = field(default="genomics", init=False, repr=False) + + +@dataclass +class CumulativeSumScalerPlotterConfigForSNP(CumulativeSumScalerPlotterConfig): + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("GenomicVariant"), + get_feature("GenomicVariant"), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("GenomicVariant"), + get_feature("GenomicVariant"), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + dataset_name: str = field(default="snp", init=False, repr=False) + + +@dataclass +class CumulativeSumScalerPlotterConfigForReadCount(CumulativeSumScalerPlotterConfig): + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("ReadCount"), + get_feature("ReadCount"), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("ReadCount"), + get_feature("ReadCount"), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + dataset_name: str = field(default="read_count", init=False, repr=False) + + +class CumulativeSumScalerPlotter(ScalerPlotter): + config_class = CumulativeSumScalerPlotterConfig + config: CumulativeSumScalerPlotterConfig + + def __init__( + self, + ylab: Optional[str] = Unset("None"), + xlab: Optional[str] = Unset("None"), + ylim: Optional[list] = Unset("None"), + title: str = Unset('"Violin Plot"'), + legend_position: str = Unset('"top"'), + add_box: bool = Unset("True"), + horizontal_plot: bool = Unset("False"), + order: bool = Unset("False"), + col_set: str = Unset('"Set1"'), + cols: Optional[str] = Unset("None"), + log_num: Optional[int] = Unset("None"), + show_outliers: bool = Unset("True"), + config: CumulativeSumScalerPlotterConfig = None, + **kwargs, + ): + super().__init__( + config=config, + ylab=ylab, + xlab=xlab, + ylim=ylim, + title=title, + legend_position=legend_position, + add_box=add_box, + horizontal_plot=horizontal_plot, + order=order, + col_set=col_set, + cols=cols, + log_num=log_num, + show_outliers=show_outliers, + **kwargs, + ) + + def plot( + self, + x1, + x2=None, + y1=None, + y2=None, + input_columns1: SelectedColumnTypes = None, + input_columns2: SelectedColumnTypes = None, + label_name1: SelectedColumnTypes = None, + label_name2: SelectedColumnTypes = None, + 
ylab: Optional[str] = Unset("None"), + xlab: Optional[str] = Unset("None"), + ylim: Optional[list] = Unset("None"), + title: str = Unset('"Violin Plot"'), + legend_position: str = Unset('"top"'), + add_box: bool = Unset("True"), + horizontal_plot: bool = Unset("False"), + order: bool = Unset("False"), + col_set: str = Unset('"Set1"'), + cols: Optional[str] = Unset("None"), + log_num: Optional[int] = Unset("None"), + show_outliers: bool = Unset("True"), + path: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show: bool = True, + ): + if is_bioset(x2): + from biosets import get_target + + if y2 is None: + y2 = get_target(x2) + if x1 is None or x2 is None: + raise ValueError("Must provide the before and after normalization.") + if y1 is not None and y2 is None: + raise ValueError("Must provide the target for the after normalization.") + if y1 is None: + if label_name1 is None: + raise ValueError( + "Must provide the target for the before normalization." + ) + y1 = DataHandler.select_columns(x1, label_name1) + x1 = DataHandler.drop_columns(x1, label_name1) + if y2 is None: + if label_name2 is None: + raise ValueError("Must provide the target for the after normalization.") + y2 = DataHandler.select_columns(x2, label_name2) + x2 = DataHandler.drop_columns(x2, label_name2) + + self.config._input_columns = self._set_input_columns_and_arity( + input_columns1, input_columns2, label_name1, label_name2 + ) + return self._plot( + x1, + x2, + y1, + y2, + ylab=ylab, + xlab=xlab, + ylim=ylim, + title=title, + legend_position=legend_position, + add_box=add_box, + horizontal_plot=horizontal_plot, + order=order, + col_set=col_set, + cols=cols, + log_num=log_num, + show_outliers=show_outliers, + path=path, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + show=show, + ) diff --git a/src/biofit/preprocessing/scaling/plot_scaling.R b/src/biofit/preprocessing/scaling/plot_scaling.R new file mode 100644 index 0000000..568d166 --- /dev/null +++ b/src/biofit/preprocessing/scaling/plot_scaling.R @@ -0,0 +1,56 @@ +source(file.path(R_SCRIPTS_PATH, "plotting_utils.R")) + + +plot_scaler <- function( + x1, + x2 = NULL, + y1 = NULL, + y2 = NULL, + path = NULL, + input_columns1 = NULL, + input_columns2 = NULL, + label_name1 = NULL, + label_name2 = NULL, + ylab = NULL, + xlab = NULL, + ylim = NULL, + before_title = NULL, + after_title = NULL, + legend_position = NULL, + add_box = TRUE, + horizontal_plot = FALSE, + order = FALSE, + col_set = "Set1", + cols = NULL, + log_num = NULL, + show_outliers = TRUE) { + + suppressPackageStartupMessages(require(patchwork)) + + if (is.null(y2)) { + y2 <- y1 + } + p1 <- generate_violin( + x1, y1, + column = input_columns1, label_name = label_name1, ylim = ylim, + xlab = xlab, ylab = ylab, title = before_title, legend_position = legend_position, + add_box = add_box, horizontal_plot = horizontal_plot, order = order, + col_set = col_set, cols = cols, log_num = log_num, show_outliers = show_outliers + ) + if (!is.null(x2)) { + p2 <- generate_violin( + x2, y2, + column = input_columns2, label_name = label_name2, ylim = ylim, + xlab = xlab, ylab = ylab, title = after_title, legend_position = legend_position, + add_box = add_box, horizontal_plot = horizontal_plot, order = order, + col_set = col_set, cols = cols, log_num = log_num, show_outliers = show_outliers + ) + the_plot <- p1 / p2 + } else { + the_plot <- p1 + } + if 
(!is.null(path)) { + save_plots(path, plot = the_plot, width = 6, height = 6, dpi = 600) + } + return(path) +} diff --git a/src/biofit/preprocessing/scaling/plot_scaling.py b/src/biofit/preprocessing/scaling/plot_scaling.py new file mode 100644 index 0000000..a8c1c66 --- /dev/null +++ b/src/biofit/preprocessing/scaling/plot_scaling.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass, field +from pathlib import Path + +from biocore import DataHandler +from biocore.utils.import_util import is_biosets_available + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedFeatureTypes +from biofit.visualization.plotting import BasePlotter, PlotterConfig + + +@dataclass +class ScalerPlotterConfig(PlotterConfig): + processor_type: str = field(default="scaling", init=False, repr=False) + r_source: str = field( + default=(Path(__file__).parent / "plot_scaling.R").as_posix(), + init=False, + repr=False, + ) + main_method: str = field(default="plot_scaler", init=False, repr=False) + _fit_unused_feature_types: SelectedFeatureTypes = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: SelectedFeatureTypes = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + + +class ScalerPlotter(BasePlotter): + config_class = ScalerPlotterConfig + config: ScalerPlotterConfig + + def plot_dataset(self, x1, x2, y1, y2): + if is_biosets_available(): + from biosets import decode + + y1 = decode(y1) if y1 is not None else None + y2 = decode(y2) if y2 is not None else None + return self.plot_arrow( + DataHandler.to_arrow(x1), + DataHandler.to_arrow(x2), + DataHandler.to_arrow(y1) if y1 is not None else None, + DataHandler.to_arrow(y2) if y2 is not None else None, + ) + + def plot_arrow(self, x1, x2, y1, y2): + return self.plotter(x1, x2, y1, y2, **self.config.get_params()) diff --git a/src/biofit/preprocessing/scaling/relative_abundance/__init__.py b/src/biofit/preprocessing/scaling/relative_abundance/__init__.py new file mode 100644 index 0000000..136618e --- /dev/null +++ b/src/biofit/preprocessing/scaling/relative_abundance/__init__.py @@ -0,0 +1,11 @@ +# ruff: noqa +from .plot_relative_abundance import ( + RelativeAbundancePlotter, + RelativeAbundancePlotterConfig, + RelativeAbundancePlotterConfigForASV, + RelativeAbundancePlotterConfigForMaldi, + RelativeAbundancePlotterConfigForMetagenomics, + RelativeAbundancePlotterConfigForOTU, + RelativeAbundancePlotterConfigForSNP, +) +from .relative_abundance import RelativeAbundanceScaler, RelativeAbundanceScalerConfig diff --git a/src/biofit/preprocessing/scaling/relative_abundance/plot_relative_abundance.py b/src/biofit/preprocessing/scaling/relative_abundance/plot_relative_abundance.py new file mode 100644 index 0000000..a09dbc1 --- /dev/null +++ b/src/biofit/preprocessing/scaling/relative_abundance/plot_relative_abundance.py @@ -0,0 +1,274 @@ +from dataclasses import dataclass, field +from typing import List, Optional, Type + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes +from biocore.utils.py_util import is_bioset +from biofit.utils.types import Unset + +from ..plot_scaling import ScalerPlotter, ScalerPlotterConfig + + +@dataclass +class RelativeAbundancePlotterConfig(ScalerPlotterConfig): + processor_name: str = field(default="relative_abundance", init=False, repr=False) + _compare: bool = field(default=True, init=False, repr=False) + 
_transoform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + None, + None, + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("METADATA_FEATURE_TYPES"), + get_feature("METADATA_FEATURE_TYPES"), + None, + None, + ], + init=False, + repr=False, + ) + + input_columns1: Optional[str] = None + input_columns2: Optional[str] = None + label_name1: Optional[str] = None + label_name2: Optional[str] = None + ylab: Optional[str] = None + xlab: Optional[str] = None + ylim: Optional[list] = None + before_title: str = "Before Relative Abundance" + after_title: str = "After Relative Abundance" + legend_position: str = "top" + add_box: bool = True + horizontal_plot: bool = False + order: bool = False + col_set: str = "Set1" + cols: Optional[str] = None + log_num: Optional[str] = None + show_outliers: bool = True + + +@dataclass +class RelativeAbundancePlotterConfigForMetagenomics(RelativeAbundancePlotterConfig): + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + (get_feature("Abundance"), get_feature("ReadCount")), + (get_feature("Abundance"), get_feature("ReadCount")), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + dataset_name: str = field(default="metagenomics", init=False, repr=False) + log_num: Optional[str] = "log2_1p" + + +@dataclass +class RelativeAbundancePlotterConfigForOTU(RelativeAbundancePlotterConfig): + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("Abundance"), + get_feature("Abundance"), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + dataset_name: str = field(default="otu", init=False, repr=False) + ylab: Optional[str] = "OTU Abundance" + log_num: Optional[str] = "log10_1p" + + +@dataclass +class RelativeAbundancePlotterConfigForASV(RelativeAbundancePlotterConfig): + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("Abundance"), + get_feature("Abundance"), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + dataset_name: str = field(default="asv", init=False, repr=False) + + +@dataclass +class RelativeAbundancePlotterConfigForSNP(RelativeAbundancePlotterConfig): + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("GenomicVariant"), + get_feature("GenomicVariant"), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("GenomicVariant"), + get_feature("GenomicVariant"), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + dataset_name: str = field(default="snp", init=False, repr=False) + + +@dataclass +class RelativeAbundancePlotterConfigForMaldi(RelativeAbundancePlotterConfig): + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("PeakIntensity"), + get_feature("PeakIntensity"), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + 
get_feature("PeakIntensity"), + get_feature("PeakIntensity"), + get_feature("TARGET_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + dataset_name: str = field(default="maldi", init=False, repr=False) + ylab: Optional[str] = "Peak Intensity" + + +class RelativeAbundancePlotter(ScalerPlotter): + config_class = RelativeAbundancePlotterConfig + config: RelativeAbundancePlotterConfig + + def __init__( + self, + input_columns1: SelectedColumnTypes = None, + input_columns2: SelectedColumnTypes = None, + label_name1: Optional[str] = None, + label_name2: Optional[str] = None, + ylab: Optional[str] = Unset("None"), + xlab: Optional[str] = Unset("None"), + ylim: Optional[list] = Unset("None"), + title: str = Unset('"Violin Plot"'), + legend_position: str = Unset('"top"'), + add_box: bool = Unset("True"), + horizontal_plot: bool = Unset("False"), + order: bool = Unset("False"), + col_set: str = Unset('"Set1"'), + cols: Optional[str] = Unset("None"), + log_num: Optional[int] = Unset("None"), + show_outliers: bool = Unset("True"), + config: RelativeAbundancePlotterConfig = None, + **kwargs, + ): + super().__init__( + config=config, + input_columns1=input_columns1, + input_columns2=input_columns2, + label_name1=label_name1, + label_name2=label_name2, + ylab=ylab, + xlab=xlab, + ylim=ylim, + title=title, + legend_position=legend_position, + add_box=add_box, + horizontal_plot=horizontal_plot, + order=order, + col_set=col_set, + cols=cols, + log_num=log_num, + show_outliers=show_outliers, + **kwargs, + ) + + def plot( + self, + x1, + x2=None, + y1=None, + y2=None, + input_columns1: SelectedColumnTypes = None, + input_columns2: SelectedColumnTypes = None, + label_name1: SelectedColumnTypes = None, + label_name2: SelectedColumnTypes = None, + ylab: Optional[str] = Unset("None"), + xlab: Optional[str] = Unset("None"), + ylim: Optional[list] = Unset("None"), + title: str = Unset('"Violin Plot"'), + legend_position: str = Unset('"top"'), + add_box: bool = Unset("True"), + horizontal_plot: bool = Unset("False"), + order: bool = Unset("False"), + col_set: str = Unset('"Set1"'), + cols: Optional[str] = Unset("None"), + log_num: Optional[int] = Unset("None"), + show_outliers: bool = Unset("True"), + path: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show: bool = True, + ): + if is_bioset(x1): + from biosets import get_target + + if y1 is None: + y1 = get_target(x1) + + if is_bioset(x2): + from biosets import get_target + + if y2 is None: + y2 = get_target(x2) + if x1 is None or x2 is None or y1 is None or y2 is None: + raise ValueError("Must provide the before and after normalization.") + self.config._input_columns = self._set_input_columns_and_arity( + input_columns1, input_columns2, label_name1, label_name2 + ) + return self._plot( + x1, + x2, + y1, + y2, + input_columns1=input_columns1, + input_columns2=input_columns2, + label_name1=label_name1, + label_name2=label_name2, + ylab=ylab, + xlab=xlab, + ylim=ylim, + title=title, + legend_position=legend_position, + add_box=add_box, + horizontal_plot=horizontal_plot, + order=order, + col_set=col_set, + cols=cols, + log_num=log_num, + show_outliers=show_outliers, + path=path, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + show=show, + ) diff --git a/src/biofit/preprocessing/scaling/relative_abundance/relative_abundance.py 
b/src/biofit/preprocessing/scaling/relative_abundance/relative_abundance.py new file mode 100644 index 0000000..1f76867 --- /dev/null +++ b/src/biofit/preprocessing/scaling/relative_abundance/relative_abundance.py @@ -0,0 +1,188 @@ +from dataclasses import dataclass, field +from typing import List, Type + +import numpy as np + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes +from biofit.utils import logging + +from ..scaling import Scaler, ScalerConfig + +logger = logging.get_logger(__name__) + + +@dataclass +class RelativeAbundanceScalerConfig(ScalerConfig): + """ + Configuration for RelativeAbundanceScaler. + """ + + processor_name: str = field(default="relative_abundance", init=False, repr=False) + _fit_process_desc: str = field( + default="Calculating relative abundance scaling", init=False, repr=False + ) + _transform_process_desc: str = field( + default="Applying relative abundance scaling", init=False, repr=False + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + + pseudocount: float = 1e-6 + input_columns: List[str] = None + + +class RelativeAbundanceScaler(Scaler): + """ + RelativeAbundanceScaler applies relative abundance scaling to the input data. + """ + + output_dtype = "float64" + config_class = RelativeAbundanceScalerConfig + config: RelativeAbundanceScalerConfig + + def __init__( + self, + pseudocount: float = 1e-6, + config: RelativeAbundanceScalerConfig = None, + **kwargs, + ): + super().__init__(config=config, pseudocount=pseudocount, **kwargs) + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "RelativeAbundanceScaler": + self.config._input_columns = self._set_input_columns_and_arity( + input_columns or self.config.input_columns + ) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity( + input_columns or self.config.input_columns + ) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + 
batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _transform_numpy(self, X: np.ndarray): + # Calculating the sum of each sample + row_sums = np.sum(X, axis=1)[:, np.newaxis] + if np.any(row_sums == 0): + row_sums += self.config.pseudocount + return X / row_sums diff --git a/src/biofit/preprocessing/scaling/scaling.py b/src/biofit/preprocessing/scaling/scaling.py new file mode 100644 index 0000000..3dda512 --- /dev/null +++ b/src/biofit/preprocessing/scaling/scaling.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass, field +from typing import List, Type + +from biofit.integration.biosets import get_feature +from biofit.processing import BaseProcessor, ProcessorConfig + + +@dataclass +class ScalerConfig(ProcessorConfig): + processor_type: str = field(default="scaling", init=False, repr=False) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + + +class Scaler(BaseProcessor): + pass diff --git a/src/biofit/preprocessing/scaling/tmm/__init__.py b/src/biofit/preprocessing/scaling/tmm/__init__.py new file mode 100644 index 0000000..31f9c51 --- /dev/null +++ b/src/biofit/preprocessing/scaling/tmm/__init__.py @@ -0,0 +1,7 @@ +# ruff: noqa +from .tmm import ( + TMMScaler, + TMMScalerConfig, + TMMScalerConfigForOTU, + TMMScalerConfigForMetagenomics, +) diff --git a/src/biofit/preprocessing/scaling/tmm/tmm.R b/src/biofit/preprocessing/scaling/tmm/tmm.R new file mode 100644 index 0000000..41bedb6 --- /dev/null +++ b/src/biofit/preprocessing/scaling/tmm/tmm.R @@ -0,0 +1,89 @@ +# Full path: src/biofit/preprocessing/scaling/tmm/tmm.R +# adapted from ruppinlab/sklearn-extensions and edgeR::calcNormFactors source code + +source(file.path(R_SCRIPTS_PATH, "utils.R")) + +edger_tmm_ref_column <- function(counts, lib_size=colSums(counts), p=0.75) { + y <- t(t(counts) / lib_size) + f <- apply(y, 2, function(x) quantile(x, p=p)) + ref_column <- which.min(abs(f - mean(f))) 
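For readers not running R, a numpy restatement of the reference-sample rule in edger_tmm_ref_column above (counts are features x samples, as edgeR expects; the matrix below is made up):

import numpy as np

def tmm_ref_column(counts, p=0.75):
    lib_size = counts.sum(axis=0)                  # per-sample library sizes
    y = counts / lib_size                          # scale each sample by its library size
    f = np.quantile(y, p, axis=0)                  # upper-quartile value per sample
    return int(np.argmin(np.abs(f - f.mean())))    # sample whose quartile is closest to the mean

counts = np.array([[10.0, 200.0, 30.0],
                   [5.0, 80.0, 10.0],
                   [0.0, 40.0, 6.0]])
print(tmm_ref_column(counts))                      # 0-based index of the reference sample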
+} + +edger_tmm_fit <- function(X) { + suppressPackageStartupMessages(require(edgeR)) + suppressPackageStartupMessages(require(arrow)) + X <- as.matrix(as.data.frame(X)) + counts <- t(X) + ref_sample <- counts[, edger_tmm_ref_column(counts)] + return(ref_sample) +} + +edger_tmm_cpm_transform <- function(X, ref_samples, log=TRUE, prior_count=2) { + suppressPackageStartupMessages(require(edgeR)) + suppressPackageStartupMessages(require(arrow)) + X <- as.matrix(convert_to_dataframe(X)) + counts <- t(X) + ref_sample_mask <- apply(counts, 2, function(c) all(c == ref_samples)) + if (any(ref_sample_mask)) { + dge <- edgeR::DGEList(counts=counts) + dge <- edgeR::calcNormFactors( + dge, method="TMM", refColumn=min(which(ref_sample_mask)) + ) + cpms <- edgeR::cpm(dge, log=log, prior.count=prior_count) + } else { + counts <- cbind(counts, ref_samples) + colnames(counts) <- NULL + dge <- edgeR::DGEList(counts=counts) + dge <- edgeR::calcNormFactors(dge, method="TMM", refColumn=ncol(dge)) + cpms <- edgeR::cpm(dge, log=log, prior.count=prior_count) + cpms <- cpms[, -ncol(cpms)] + } + return(arrow::as_arrow_table(as.data.frame(t(cpms)))) +} + +edger_tmm_tpm_transform <- function( + X, feature_meta, ref_samples, log=TRUE, prior_count=2, meta_col="Length" +) { + if (is.null(feature_meta)) stop("feature_meta cannot be NULL") + suppressPackageStartupMessages(require(edgeR)) + suppressPackageStartupMessages(require(arrow)) + X <- as.matrix(convert_to_dataframe(X)) + counts <- t(X) + ref_sample_mask <- apply(counts, 2, function(c) all(c == ref_samples)) + if (any(ref_sample_mask)) { + dge <- edgeR::DGEList(counts=counts, genes=feature_meta) + dge <- edgeR::calcNormFactors( + dge, method="TMM", refColumn=min(which(ref_sample_mask)) + ) + } else { + counts <- cbind(counts, ref_samples) + colnames(counts) <- NULL + dge <- edgeR::DGEList(counts=counts, genes=feature_meta) + dge <- edgeR::calcNormFactors(dge, method="TMM", refColumn=ncol(dge)) + } + if (log) { + # XXX: edgeR doesn't have built-in support for logTPM w/ prior.count + # so do API internal logic manually + # TODO: use effectiveLibSizes() in newer edgeR versions + lib_size <- dge$samples$lib.size * dge$samples$norm.factors + scaled_prior_count <- args$prior_count * lib_size / mean(lib_size) + adj_lib_size <- lib_size + 2 * scaled_prior_count + fpkms <- t( + (t(dge$counts) + scaled_prior_count) / adj_lib_size + ) * 1e6 / dge$genes[[meta_col]] * 1e3 + tpms <- log2(t(t(fpkms) / colSums(fpkms)) * 1e6) + } else { + fpkms <- edgeR::rpkm( + dge, gene.length=meta_col, log=log, prior.count=prior_count + ) + tpms <- t(t(fpkms) / colSums(fpkms)) * 1e6 + } + if (!any(ref_sample_mask)) tpms <- tpms[, -ncol(tpms)] + return(arrow::as_arrow_table(t(tpms))) +} + +edger_cpm_transform <- function(X, log=TRUE, prior_count=2) { + suppressPackageStartupMessages(require(edgeR)) + suppressPackageStartupMessages(require(arrow)) + return(arrow::as_arrow_table(t(edgeR::cpm(t(X), log=log, prior.count=prior_count)))) +} diff --git a/src/biofit/preprocessing/scaling/tmm/tmm.py b/src/biofit/preprocessing/scaling/tmm/tmm.py new file mode 100644 index 0000000..c78511d --- /dev/null +++ b/src/biofit/preprocessing/scaling/tmm/tmm.py @@ -0,0 +1,281 @@ +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Optional, Type + +import pyarrow as pa + +from biofit.integration.biosets import get_feature +from biofit.integration.R import RCaller +from biofit.integration.R.r_caller import PackageNotInstalledError +from biofit.processing import 
SelectedColumnTypes, sync_backup_config +from biofit.utils import logging + +from ..scaling import Scaler, ScalerConfig + +logger = logging.get_logger(__name__) + + +@dataclass +class TMMScalerConfig(ScalerConfig): + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None], + init=False, + repr=False, + ) + # process descriptions + _fit_process_desc: str = field( + default="Determining reference samples for TMM normalization", + init=False, + repr=False, + ) + _transform_process_desc: str = field( + default="Applying TMM normalization", init=False, repr=False + ) + processor_name: str = field(default="tmm", init=False, repr=False) + + # attributes + r_source: str = field( + default=Path(__file__).with_suffix(".R").as_posix(), init=False, repr=False + ) + fit_func_name: str = field(default="edger_tmm_fit", init=False, repr=False) + cpm_transform_func_name: str = field( + default="edger_tmm_cpm_transform", init=False, repr=False + ) + tpm_transform_func_name: str = field( + default="edger_tmm_tpm_transform", init=False, repr=False + ) + + log: bool = True + prior_count: int = 2 + meta_col: str = "Length" + gene_col: str = "genes" + + # estimated attributes + ref_samples: Optional[pa.Table] = field(default=None, init=False, repr=False) + + +class TMMScalerConfigForMetagenomics(TMMScalerConfig): + # dataset specific attributes + _input_feature_types: List[Type] = ( + get_feature("Abundance"), + get_feature("ReadCount"), + ) + dataset_name = "metagenomics" + + +class TMMScalerConfigForOTU(TMMScalerConfig): + # dataset specific attributes + _input_feature_types: List[Type] = get_feature("Abundance") + dataset_name = "otu" + + +class TMMScalerConfigForSNP(TMMScalerConfig): + # dataset specific attributes + _input_feature_types: List[Type] = get_feature("GenomicVariant") + dataset_name = "snp" + + +class TMMScaler(Scaler): + output_dtype = "float64" + + # config class + config_class = TMMScalerConfig + config: TMMScalerConfig + + def __init__( + self, + log: bool = True, + prior_count: int = 2, + meta_col: str = "Length", + gene_col: str = "genes", + config: Optional[TMMScalerConfig] = None, + **kwargs, + ): + super().__init__( + config=config, + log=log, + prior_count=prior_count, + meta_col=meta_col, + gene_col=gene_col, + **kwargs, + ) + r_caller = RCaller.from_script(self.config.r_source) + install_missing = kwargs.get("install_missing") + try: + r_caller.verify_r_dependencies( + cran_dependencies=["BiocManager"], + bioconductor_dependencies=["edgeR"], + install_missing=install_missing, + ) + except PackageNotInstalledError: + raise PackageNotInstalledError( + "TMMScale requires the following R package: edgeR. To " + "install, initialize the TMMScaler with install_missing=True or run " + "R -e 'BiocManager::install(\"edgeR\")' in your terminal." 
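A hypothetical end-to-end sketch of the scaler. The data frame below is made up, whether install_missing is accepted directly as a constructor keyword is inferred from the kwargs.get call above, and the conversion of a pandas frame to Arrow is assumed to be handled by the processing framework (as the _fit_arrow/_transform_arrow pair suggests); R plus Bioconductor's edgeR must be available:

import pandas as pd
from biofit.preprocessing.scaling import TMMScaler

# samples x features count table (illustrative values)
counts = pd.DataFrame(
    {"geneA": [10, 200, 30], "geneB": [5, 80, 10], "geneC": [0, 40, 6]}
)

scaler = TMMScaler(log=True, prior_count=2, install_missing=True)
logcpm = scaler.fit_transform(counts)   # log-CPM values via edgeR's TMM factors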
+ ) + self.fit_func = r_caller.get_method(self.config.fit_func_name) + self.cpm_transform = r_caller.get_method(self.config.cpm_transform_func_name) + self.tpm_transform = r_caller.get_method(self.config.tpm_transform_func_name) + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + genes=None, + input_columns: SelectedColumnTypes = None, + gene_col: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns, gene_col) + return self._process_transform( + X, + genes, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + genes=None, + input_columns: SelectedColumnTypes = None, + gene_col: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + genes=genes, + gene_col=gene_col, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + @sync_backup_config + def set_params(self, **kwargs): + self.config = 
self.config.replace_defaults(**kwargs) + + # load file from src/biofit/preprocessing/scaling/tmm/tmm.R + r_caller = RCaller.from_script(self.config.r_source) + self.fit_func = r_caller.get_method(self.config.fit_func_name) + self.cpm_transform = r_caller.get_method(self.config.cpm_transform_func_name) + self.tpm_transform = r_caller.get_method(self.config.tpm_transform_func_name) + return self + + def _fit_arrow(self, X: pa.Table): + self.config.ref_samples = self.fit_func(X) + return self + + def _transform_arrow(self, X: pa.Table, genes=None): + if genes: + return self.tpm_transform( + X=X, + feature_meta=genes, + ref_samples=self.config.ref_samples, + log=self.config.log, + prior_count=self.config.prior_count, + meta_col=self.config.meta_col, + ) + else: + return self.cpm_transform( + X=X, + ref_samples=self.config.ref_samples, + log=self.config.log, + prior_count=self.config.prior_count, + ) diff --git a/src/biofit/preprocessing/transformation/__init__.py b/src/biofit/preprocessing/transformation/__init__.py new file mode 100644 index 0000000..d5035f9 --- /dev/null +++ b/src/biofit/preprocessing/transformation/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa +from .log import * diff --git a/src/biofit/preprocessing/transformation/log/__init__.py b/src/biofit/preprocessing/transformation/log/__init__.py new file mode 100644 index 0000000..d81b7a9 --- /dev/null +++ b/src/biofit/preprocessing/transformation/log/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa +from .log import LogTransformer, LogTransformerConfig, LogTransformerConfigForOTU diff --git a/src/biofit/preprocessing/transformation/log/log.py b/src/biofit/preprocessing/transformation/log/log.py new file mode 100644 index 0000000..d07cdf9 --- /dev/null +++ b/src/biofit/preprocessing/transformation/log/log.py @@ -0,0 +1,201 @@ +from dataclasses import dataclass, field +from typing import List, Type + +import numpy as np + +from biofit.integration.biosets import get_feature +from biofit.processing import ( + SelectedColumnTypes, + SelectedFeatureTypes, + sync_backup_config, +) +from biofit.utils import logging + +from ..transformation import Transformer, TransformerConfig + +logger = logging.get_logger(__name__) + + +@dataclass +class LogTransformerConfig(TransformerConfig): + processor_name: str = field(default="log", init=False, repr=False) + _transform_process_desc: str = field( + default="Applying log transformation", init=False, repr=False + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + + base = np.e + shift = 0 + estimator = None + + +@dataclass +class LogTransformerConfigForOTU(LogTransformerConfig): + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + dataset_name: str = field(default="otu", init=False, repr=False) + base = 2 + shift = 1 + + +class LogTransformer(Transformer): + output_dtype = "float64" + config_class = LogTransformerConfig + config: LogTransformerConfig + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + func = None + if self.config.base == np.e: + if self.config.shift == 0: + func = np.log + 
elif self.config.shift == 1: + func = np.log1p + elif self.config.base == 10: + func = np.log10 + elif self.config.base == 2: + func = np.log2 + + if func is None: + + def func(x): + return np.log(x + self.config.shift) / np.log(self.config.base) + + self.log_transform = func + return self + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "LogTransformer": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedFeatureTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _transform_sklearn(self, X): + return self.log_transform(X) diff --git a/src/biofit/preprocessing/transformation/transformation.py 
b/src/biofit/preprocessing/transformation/transformation.py new file mode 100644 index 0000000..61f144b --- /dev/null +++ b/src/biofit/preprocessing/transformation/transformation.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass, field + +from biofit.processing import BaseProcessor, ProcessorConfig + + +@dataclass +class TransformerConfig(ProcessorConfig): + """Base class for Transformer processor configuration.""" + + processor_type: str = field(default="transformation", init=False, repr=False) + + +class Transformer(BaseProcessor): + """Base class for Transformer processors.""" + + config_class = TransformerConfig + config: TransformerConfig diff --git a/src/biofit/processing.py b/src/biofit/processing.py new file mode 100644 index 0000000..7ade0eb --- /dev/null +++ b/src/biofit/processing.py @@ -0,0 +1,3696 @@ +import copy +import importlib +import inspect +import json +import os +import re +import sys +import tempfile +import time +from collections.abc import Callable +from dataclasses import dataclass, field, fields +from functools import wraps +from multiprocessing import Pool +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +import joblib +import numpy as np +import pandas as pd +import pandas.api.types as pdt +import pyarrow as pa +from biocore import DataHandler, get_data_format +from biocore.utils.import_util import ( + is_polars_available, + is_ray_available, + is_biosets_available, + is_datasets_available, +) +from biocore.utils.inspect import get_kwargs, get_required_args +from biocore.utils.naming import camelcase_to_snakecase +from sklearn.utils.validation import ( + NotFittedError, +) + +import biofit.config +from biofit.utils import ( + Unset, + determine_upcast, + fingerprint_from_data, + fingerprint_from_kwargs, + generate_cache_dir, + get_cache_file_name, + init_arrow_buffer_and_writer, + logging, + move_temp_file, + update_fingerprint, + version, +) +from biofit.utils.file_utils import expand_path, is_remote_url +from biofit.utils.fingerprint import Hasher, is_caching_enabled +from biofit.utils.py_util import iflatmap_unordered +from biocore.utils.py_util import ( + is_bioset, + is_dataset, + is_dataset_dict, + is_iterable_dataset, +) + +from biofit.utils.table_util import string_to_arrow + +if TYPE_CHECKING: + from datasets.features.features import Features + +logger = logging.get_logger(__name__) + +T_CONFIG = TypeVar("T_CONFIG", bound="BaseConfig") +T_PROC = TypeVar("T_PROC", bound="BaseProcessor") + +SelectedFeatureTypes = Union[Type, Tuple[Type], List[Union[Type, Tuple[Type]]]] + +SelectedColumnTypes = Union[ + str, + int, + List[str], + List[int], + List[Union[List[str], List[int]]], + SelectedFeatureTypes, +] + + +class NonExistentCacheError(Exception): + """Used when we expect the existence of a cache""" + + pass + + +# based on conversion speed, we will use the following order of preference +_ORDERED_FORMATS = [ + "arrow", + "pyarrow", + "torch", + "pt", + "tf", + "tensorflow", + "pandas", + "pd", + "numpy", + "np", + "dicts", + "dict", + "list", +] + +_DATAFRAME_FORMATS = [ + "pandas", + "pd", +] + +_SKLEARN_FORMATS = [ + "numpy", + "np", + "pandas", + "pd", + "series", +] + +_SPECIAL_FORMATS = ["any", "all", "array", "dataframe", "sklearn"] + +_ARROW_WRITEABLE_FORMATS = ["arrow", "pyarrow"] +# when https://github.com/huggingface/datasets/pull/6762 is merged, we ca add +# "polars", "pl" to the list of writeable formats + +if is_polars_available(): + 
pandas_pos = _ORDERED_FORMATS.index("pandas") + _ORDERED_FORMATS.insert(pandas_pos, "polars") + _ORDERED_FORMATS.insert(pandas_pos + 1, "pl") + + _DATAFRAME_FORMATS.insert(0, "polars") + _DATAFRAME_FORMATS.insert(1, "pl") + + _SKLEARN_FORMATS.insert(0, "polars") + _SKLEARN_FORMATS.insert(1, "pl") + + +def sync_backup_config(func): + @wraps(func) + def wrapper(self, **kwargs): + if hasattr(self, "config_"): + self._fingerprint = None + self.config_.replace_defaults(**kwargs) + return func(self, **kwargs) + + return wrapper + + +def _generate_get_feature_names_out(estimator, n_features_out): + """ + Modified _generate_get_feature_names_out from sklearn to convert estimator name to snakecase + """ + + def name_formatter(template, index): + # Check if there is any placeholder {i...} in the template + if re.search(r"\{i(\+([0-9]+))?\}", template) is None: + # No placeholders, use default formatting {name}{i} + template = f"{template}{index}" + + # Regex to find {i} or {i+n} where n is an integer + pattern = r"\{i(\+([0-9]+))?\}" + matches = re.finditer(pattern, template) + + for match in matches: + full_match = match.group(0) + increment = match.group(2) + if increment: + value = index + int(increment) + else: + value = index + + template = template.replace(full_match, str(value)) + + return template + + estimator_name = None + if hasattr(estimator, "config"): + if getattr(estimator.config, "_feature_names_out", None) is not None: + return np.asarray(estimator.config._feature_names_out) + elif getattr(estimator.config, "output_template_name", None) is not None: + estimator_name = estimator.config.output_template_name + elif getattr(estimator.config, "processor_name", None): + estimator_name = estimator.config.processor_name + if n_features_out > 1: + estimator_name += "_" + if estimator_name is None: + estimator_name = camelcase_to_snakecase(estimator.__class__.__name__) + estimator_name = estimator_name.split("_for_")[0] + estimator_name = "_".join(estimator_name.split("_")[:-1]) + if n_features_out > 1: + estimator_name += "_" + + if n_features_out > 1: + out = [name_formatter(estimator_name, i) for i in range(n_features_out)] + else: + out = [estimator_name] + + return np.asarray(out, dtype=object) + + +REPLAY_ACTIONS = [ + ( + "from_config_transform", + {"path": "path", "fingerprint": "fingerprint", "ext": ".json"}, + ), +] + + +def post_recording(*args, **kwargs): + out, _, func_args, func_kwargs = args + if not hasattr(out, "_fingerprint") or not hasattr(out, "replays"): + return out + + info = kwargs.get("info", None) + + if len(func_args) > 0: + self: "BaseProcessor" = func_args[0] + func_args = func_args[1:] + else: + self = func_kwargs["self"] + del func_kwargs["self"] + + if len(func_args) > 0: + ds = func_args[0] + func_args = func_args[1:] + else: + ds = func_kwargs["X"] + del func_kwargs["X"] + + out_fingerprint = None + if info: + out_fingerprint = info.get("new_fingerprint", None) + + if not out_fingerprint: + out_fingerprint = getattr(out, "_fingerprint", None) + + fingerprint = self.fingerprint + + cache_file_names = [] + if info: + cache_file_name = info.get("cache_file_name", None) + if not cache_file_name: + cache_dir = info.get("cache_dir", None) + if cache_dir and os.path.exists(cache_dir) and fingerprint: + file_list = os.listdir(cache_dir) + cache_file_names = [ + os.path.join(cache_dir, f) + for f in file_list + if fingerprint in f or (out_fingerprint and out_fingerprint in f) + ] + else: + if cache_file_name and not isinstance(cache_file_name, list): + 
cache_file_names = [cache_file_name] + + out._fingerprint = out_fingerprint + replays = ds.replays or out.replays or [] + out.replays = replays.copy() + entity_path = f"{self.__module__}.{self.__class__.__name__}" + + if cache_file_names: + for func_name, replay_info in REPLAY_ACTIONS: + path_arg = replay_info["path"] + fingerprint_arg = replay_info["fingerprint"] + ext = replay_info["ext"] + paths = [] + for file in cache_file_names: + if file.endswith(ext): + paths.append(file) + + if len(paths) == 1: + new_kwargs = {path_arg: paths[0]} + if fingerprint_arg: + new_kwargs[fingerprint_arg] = fingerprint + out.replays.append( + ( + out._fingerprint, + entity_path, + func_name, + (), + {**new_kwargs}, + ) + ) + break + + return out + + +def keep_dataset_fingerprint(func): + @wraps(func) + def wrapper(*args, **kwargs): + if len(args) > 0: + self = args[0] + args = args[1:] + else: + self = kwargs["self"] + del kwargs["self"] + + if len(args) > 0: + ds = args[0] + args = args[1:] + else: + ds = kwargs["X"] + del kwargs["X"] + + old_fingerprint = getattr(ds, "_fingerprint", None) + out = func(self, ds, *args, **kwargs) + if old_fingerprint: + ds._fingerprint = old_fingerprint + return out + + return wrapper + + +@dataclass +class BaseConfig: + @classmethod + def from_config_file( + cls, path: str, ignore_none=False, add_new_attr=True + ) -> T_CONFIG: + """ + Load configuration from a JSON file. + + Args: + path (str or os.PathLike): Path to the configuration file. + ignore_none (bool): If True, ignore None values during loading. + add_new_attr (bool): + If True, allow adding new attributes that are not explicitly defined. + + Returns: + T_CONFIG: + An instance of the configuration class with attributes loaded from the + file. + """ + if isinstance(path, os.PathLike): + path = str(path) + + with open(path, "r") as f: + states = json.load(f) + return cls.from_dict(states, ignore_none=ignore_none, add_new_attr=add_new_attr) + + @classmethod + def from_config( + cls, + config_or_path: Union[str, os.PathLike, dict, "BaseConfig"], + ignore_none=False, + add_new_attr=False, + ) -> T_CONFIG: + """ + Load configuration from a file, dictionary, or another BaseConfig instance. + + Args: + config_or_path (Union[str, os.PathLike, dict, BaseConfig]): + Configuration data or path. + ignore_none (bool): If True, ignore None values during loading. + add_new_attr (bool): + If True, allow adding new attributes that are not explicitly defined. + + Returns: + T_CONFIG: An instance of the configuration class. + """ + if isinstance(config_or_path, (str, os.PathLike)): + return cls.from_config_file( + config_or_path, ignore_none=ignore_none, add_new_attr=add_new_attr + ) + elif isinstance(config_or_path, dict): + return cls.from_dict( + config_or_path, ignore_none=ignore_none, add_new_attr=add_new_attr + ) + elif isinstance(config_or_path, BaseConfig): + return cls.from_dict( + config_or_path.to_dict(deep=False), + ignore_none=ignore_none, + add_new_attr=add_new_attr, + ) + else: + raise ValueError(f"Unsupported config type {type(config_or_path)}") + + @classmethod + def from_dict(cls, states, ignore_none=False, add_new_attr=False) -> T_CONFIG: + """ + Load configuration from a dictionary. + + Args: + states (dict): Dictionary containing configuration states. + ignore_none (bool): If True, ignore None values during the assignment. + add_new_attr (bool): + If True, allow adding new attributes that are not in the init. 
+ + Returns: + T_CONFIG: + An instance of the configuration class with attributes set according to + `states`. + """ + + def _from_dict(obj): + if isinstance(obj, dict): + if len(obj) == 2 and "path" in obj and "format" in obj: + if obj["format"] == "joblib": + with open(obj["path"], "rb") as f: + return joblib.load(f) + + return DataHandler.to_format( + obj["path"], target_format=obj["format"] + ) + elif "__module__" in obj and "__class__" in obj and "__dict__" in obj: + module = importlib.import_module(obj.get("__module__")) + _cls = getattr(module, obj["__class__"]) + if not hasattr(_cls, "from_dict"): + raise ValueError( + f"from_dict method not found for class {_cls.__name__}" + ) + return _cls.from_dict(obj["__dict__"]) + else: + return {k: _from_dict(v) for k, v in obj.items()} + else: + return obj + + _states = copy.deepcopy(states) + cls_kwargs = get_kwargs(_states, cls.__init__) + self = cls( + **{ + k: _from_dict(v) + for k, v in cls_kwargs.items() + if not isinstance(v, Unset) + } + ) + + for k in cls_kwargs: + _states.pop(k, None) + + attributes = {k: _from_dict(v) for k, v in _states.items()} + + self = self.replace_defaults( + ignore_none=ignore_none, add_new_attr=add_new_attr, **attributes + ) + return self + + def to_dict( + self, + deep=True, + save_nested_complex_obj=False, + path=None, + fingerprint=None, + ): + """ + Convert the configuration to a dictionary. + + Args: + deep (bool): If True, recursively convert all attributes to dictionaries. + save_nested_complex_obj (bool): + If True, save complex nested objects to disk. + path (str): Base path for saving complex nested objects. + fingerprint (str): + Optional fingerprint string to distinguish file paths. + + Returns: + dict: A dictionary representation of the configuration. + """ + base_dir = None + if save_nested_complex_obj: + if path: + base_dir = os.path.dirname(path) + else: + base_dir = tempfile.mkdtemp() + path = tempfile.NamedTemporaryFile("w", dir=base_dir, delete=False).name + + def convert_to_dict(obj, name=None): + """ + Convert an object's attributes to a dictionary, recursively processing + nested objects. Handles complex objects like arrays, datasets, and models + by potentially saving them to disk and replaces them with a reference in + the dictionary. + + Args: + obj (Any): The object to be converted into a dictionary format. + name (str, optional): + Optional name to be used for generating file paths when saving + complex objects. Helps in creating more readable and traceable file + names. + + Returns: + dict: + A dictionary representing the object. For simple types, returns the + object itself. For complex objects, returns a dictionary with + metadata such as path and format or fully serialized object states. + + Raises: + Exception: + If there is a problem with writing the data to disk or any other + operational issue during the conversion process, it raises an + Exception to indicate failure. + + Notes: + This method deals with different types of objects including simple data + types, complex machine learning models, and data structures. Depending + on the 'save_nested_complex_obj' flag in the `to_dict` method, complex + objects like machine learning models may either be saved to a file and + represented by a path, or fully serialized into the dictionary. + + For saving data structures like arrays or datasets, it may use a + temporary directory or a specified path to store intermediate files. + This method handles the creation and cleanup of these files as needed. 
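+
+            Example:
+                An illustrative sketch of the shapes this helper can return;
+                the paths and class names below are hypothetical, not actual
+                output:
+
+                    # array-like inputs become a file reference
+                    {"path": ".../cache-<fingerprint>-X.arrow", "format": "pandas"}
+
+                    # nested objects exposing ``to_dict`` are serialized in place
+                    {"__module__": "...", "__class__": "...", "__dict__": {...}}
+
+                    # sklearn/lightgbm/xgboost/catboost estimators are dumped with joblib
+                    {"path": ".../cache-<fingerprint>-model.joblib", "format": "joblib"}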
+ """ + if DataHandler.is_array_like(obj): + fp = None + _format = get_data_format(obj) + arrow_writer_kwargs = {} + if base_dir and save_nested_complex_obj: + if is_datasets_available(): + from datasets import Features + + if is_bioset(obj) or is_dataset(obj): + arrow_writer_kwargs["features"] = obj._info.features + obj = obj.data + else: + obj = DataHandler.to_format(obj, target_format="arrow") + arrow_writer_kwargs["features"] = ( + Features.from_arrow_schema(obj.schema) + ) + else: + obj = DataHandler.to_format(obj, target_format="arrow") + arrow_writer_kwargs["features"] = obj.schema + file_suffix = "" + if name: + file_suffix = f"-{name}" + new_fingerprint = update_fingerprint( + fingerprint or "", file_suffix + obj.__class__.__name__ + ) + fp = os.path.join( + base_dir, f"cache-{new_fingerprint}{file_suffix}.arrow" + ) + arrow_writer_kwargs["fingerprint"] = new_fingerprint + + out = { + "path": fp, + "format": _format, + } + if base_dir and save_nested_complex_obj: + _, writer, tmp_file = init_arrow_buffer_and_writer( + cache_file_name=fp, **arrow_writer_kwargs + ) + try: + if writer is not None: + writer.write_table(obj) + writer.finalize() + if tmp_file is not None: + move_temp_file(tmp_file, fp) + + except (Exception, KeyboardInterrupt): + if writer is not None: + writer.finalize() + if tmp_file is not None: + tmp_file.close() + if os.path.exists(tmp_file.name): + os.remove(tmp_file.name) + raise + return out + if hasattr(obj, "to_dict") and callable(obj.to_dict): + states = ( + obj.to_dict( + save_nested_complex_obj=save_nested_complex_obj, + path=path, + fingerprint=fingerprint, + ) + if isinstance(obj, BaseConfig) + else obj.to_dict() + ) + return { + "__module__": obj.__class__.__module__, + "__class__": obj.__class__.__name__, + "__dict__": states, + } + if ( + "sklearn" in obj.__class__.__module__ + or "lightgbm" in obj.__class__.__module__ + or "xgboost" in obj.__class__.__module__ + or "catboost" in obj.__class__.__module__ + ): + fp = None + _format = "joblib" + if base_dir and save_nested_complex_obj: + file_suffix = "" + if name: + file_suffix = f"-{name}" + new_fingerprint = update_fingerprint( + fingerprint or "", file_suffix + obj.__class__.__name__ + ) + fp = os.path.join( + base_dir, f"cache-{new_fingerprint}{file_suffix}.joblib" + ) + + with open(fp, "wb") as f: + joblib.dump(obj, f) + + return { + "path": fp, + "format": _format, + } + + if pdt.is_dict_like(obj): + return {k: convert_to_dict(v, k) for k, v in obj.items()} + if pdt.is_list_like(obj): + return [convert_to_dict(v) for v in obj] + if isinstance(obj, (str, bool, type(None))): + return obj + if pdt.is_integer(obj): + return int(obj) + if pdt.is_number(obj): + return float(obj) + + if deep: + if hasattr(self, "__getstate__"): # python>=3.11 + states = copy.deepcopy( + convert_to_dict( + { + k: v + for k, v in self.__getstate__().items() + if not isinstance(v, type) + } + ) + ) + else: + states = copy.deepcopy( + convert_to_dict( + { + k: v + for k, v in self.__dict__.items() + if not isinstance(v, type) + } + ) + ) + else: + states = copy.deepcopy(self.__dict__) + + states["config_name"] = self.__class__.__name__ + return states + + def save_to_cache(self, path, fingerprint=None): + """ + Save the configuration to a cache file in JSON format. + + Args: + path (str): Path where the configuration will be saved. + fingerprint (str): Optional fingerprint string to distinguish file paths. 
+ """ + base_dir = os.path.dirname(path) + os.makedirs(base_dir, exist_ok=True) + files = os.listdir(base_dir) + try: + states = self.to_dict( + path=path, save_nested_complex_obj=True, fingerprint=fingerprint + ) + except (Exception, KeyboardInterrupt): + new_files = set(os.listdir(base_dir)) - set(files) + for f in new_files: + if os.path.exists(f): + os.remove(f) + raise + + with open(path, "w") as f: + json.dump(states, f, check_circular=False) + + def replace_defaults( + self, + ignore_none=False, + add_new_attr=False, + return_unused_kwargs=False, + **states, + ): + """ + Replace default values of the config instance with provided values. + + Args: + ignore_none (bool): If True, ignore None values during the replacement. + add_new_attr (bool): + If True, allow adding new attributes that are not explicitly defined. + return_unused_kwargs (bool): + If True, return unused keyword arguments. + + Returns: + self or (self, dict): + The configuration instance or a tuple of the instance and unused + kwargs. + """ + unused_keys = [] + for k, v in states.items(): + if isinstance(v, Unset): + continue + if add_new_attr or hasattr(self, k): + if not ignore_none or v is not None: + setattr(self, k, v) + else: + unused_keys.append(k) + if return_unused_kwargs: + return self, {k: states[k] for k in unused_keys} + return self + + def _repr_mimebundle_(self, **kwargs): + """ + Return the MIME bundle for the representation of the estimator. + + This function is utilized by Jupyter environments to display the estimator. + + Args: + **kwargs: Arbitrary keyword arguments. + + Returns: + dict: A MIME type bundle representing the estimator. + """ + from sklearn._config import get_config + from sklearn.utils._estimator_html_repr import estimator_html_repr + + output = {"text/plain": repr(self)} + if get_config()["display"] == "diagram": + output["text/html"] = estimator_html_repr(self) + return output + + def get_params(self, deep=True, show_init_only=True, show_repr_only=True): + """ + Get parameters for the estimator. + + Args: + deep (bool): If True, return parameters of nested objects. + + Returns: + dict: Dictionary of parameters. 
+ """ + params = {} + args = [ + f.name + for f in fields(self) + if (not show_init_only or f.init) and (not show_repr_only or f.repr) + ] + for param in args: + obj = getattr(self, param, "not_found") + if isinstance(obj, str) and obj == "not_found": + # check in dataclass fields for default values + for f in fields(self): + if f.name == param: + if f.default_factory is not None: + obj = f.default_factory() + else: + obj = f.default + break + if hasattr(obj, "get_params") and deep: + params[param] = obj.get_params(deep=deep) + elif not pdt.is_complex(obj) or deep: + params[param] = obj + return params + + +def get_processor_from_config_name( + config_name, processor_type=None, processor_name=None +): + try: + package = "biofit" + module_name = "models" + if processor_type: + package = f"{package}.{module_name}" + module_name = processor_type + if processor_name: + package = f"{package}.{module_name}" + module_name = processor_name + package = package.replace("-", "_") + module_name = module_name.replace("-", "_") + module = importlib.import_module(f".{module_name}", package=package) + config_cls = getattr(module, config_name) + except (ModuleNotFoundError, AttributeError): + return None + + return config_cls + + +@dataclass +class FitTransformConfig(BaseConfig): + # common attributes + + map_kwargs: dict = field(default_factory=lambda: {"fn_kwargs": {}}) + version: str = "0.0.0" + + # only here to transmit the values to the next config + load_from_cache_file = is_caching_enabled() + + def populate_map_kwargs(self): + if "keep_in_memory" not in self.map_kwargs: + self.map_kwargs["keep_in_memory"] = not self.cache_output + + if "cache_file_name" not in self.map_kwargs: + self.map_kwargs["cache_file_name"] = self.cache_file_name + + if "num_proc" not in self.map_kwargs: + self.map_kwargs["num_proc"] = self.num_proc + + return self + + @classmethod + def prepare_config( + cls, + **kwargs, + ): + self = cls() + self = self.replace_defaults(**kwargs) + self = self.populate_map_kwargs() + return self + + +@dataclass +class FitConfig(FitTransformConfig): + # to be used for ray data only + concurrency: Union[Tuple[int, int], int] = None + + +@dataclass +class TransformConfig(FitConfig): + @classmethod + def from_config(cls, config: FitTransformConfig): + if isinstance(config, FitTransformConfig): + states = copy.deepcopy(config.__dict__) + elif isinstance(config, dict): + states = copy.deepcopy(config) + else: + return super().from_config(copy.deepcopy(config)) + return cls.prepare_config(**states) + + +@dataclass +class ProcessorConfig(BaseConfig): + f""" + Stores the parameters for all processors. + + The config should contain all the parameters for transforming the data without the + need to repeat fitting. + + Args: + output_format (str, *optional*): + The output format of the transformed data. The format will be the same as + the input data if not provided. Possible values are {_ORDERED_FORMATS}. + input_columns (_SelectedColumnTypes, *optional*): + The input columns to be used for fitting the processor and transforming the + data. If more than one table or array is provided, a list of lists will + correspond to each input argument. A single list will be applied to only + the first input argument. Only `datasets.Bioset`, + `datasets.IterableDataset`, or `biofit.Bioset` input data support column + name selection via `datasets.Features` object. + unused_columns (_SelectedColumnTypes, *optional*): + The columns that are not used for fitting the processor. 
This is ignored if + `input_columns` is provided. A single list or item will be applied to all + input arguments, while a list of lists will correspond to each input + argument. + keep_unused_columns (bool): + Whether to keep the unused columns in the final output. Default is `True`. + raise_if_missing (bool): + Whether to raise an error if the input columns provided are missing during + fitting. Default is `True`. If `False`, the processor will use the columns + that are present in the input data and ignore the missing columns. Use this + when the pipeline contains feature selection before the processor. + enable_caching (bool, *optional*): + Whether to disable or enable writing and reading from cache. Default is + `True`. + cache_dir (str, *optional*): + The directory to store the cached data. Default is `None`. + version (str): + The version of the processor. Default is `"0.0.0"`. This is used for + caching purposes. Set this to a new version when the processor is updated + and the parameters are the same as a previous version. + + Attributes: + _fit_process_desc (str): + The description next to the progress bar during fitting. + _transform_process_desc (str): + The description next to the progress bar during transformation. + _input_feature_types (_SelectedFeatureTypes, *optional*): + The input feature types that will be applied by the processor. Only used + when input has a `features` attribute containing `datasets.Features`, such + as a `datasets.Bioset`, `datasets.IterableDataset`, or `biofit.Bioset`. + When `input_columns` is provided, this attribute is ignored. + _unused_feature_types (_SelectedFeatureTypes, *optional*): + The feature types that are not used for fitting the processor. This is + ignored if `input_columns` is provided. + features_out_suffix (str, *optional*): + The suffix to be added to the output feature names. + features_out_prefix (str): The prefix to be added to the output feature names. + processor_type (str, *optional*): + The type of the processor (e.g. feature_selection, scaling, imputation, + etc.). Used for auto class instantiation. Must be the same as the name of + the parent module where the processor is defined. A `None` value implies + that the processor has no parent module. + processor_name (str): + The name of the processor (e.g. select_k_best, min_max_scaler, + simple_imputer, etc.). Used for auto class instantiation. Must be the same + as the module name where the processor is defined. + dataset_name (str, *optional*): + The name of the dataset the processor is applied to. Used for auto class + instantiation based on the type of the dataset. A `None` value implies that + the processor is not dataset-specific. + output_template_name (str, *optional*): + The name of the output template. This is used to generate the output + feature names. + is_fitted (bool): Whether the processor is fitted to the input data. + _batch_method_prefix (str): + The prefix to be added to the batch method name. This is used to call the + batch method during fitting. + _input_columns (List[_SelectedColumnTypes], *optional*): + The parsed input columns, with number of lists equaling to the number of + input arguments. This is generated from + `TransformationMixin._set_input_columns_and_arity`. + n_features_in_ (int): The number of input features. + _n_features_out (int, *optional*): + The number of output features. Anything other than `None` implies that the + processor is not one-to-one. 
Set this during fit or before transformation + only if the transformation results in new features *and* the number of + output features is not equal to the number of input features. For example, + feature extraction, feature generation, etc. Preprocessing steps like + feature selection does not result in *new* features and should be kept as + `None`. + _features_out (Features, *optional*): + The `datasets.Features` object for the output features. This is inferred + before transformation. Used for caching the output data as an arrow ipc + stream. + feature_idx_in_ (List[int], *optional*): + The column indices of the input that were used for fitting. This is + automatically set. + feature_names_in_ (List[str], *optional*): + The column names of the input features. This is automatically set during + fitting. Only set if the input data during fit supports column names. + Transformations will use this attribute to select the input columns, if + supported. Otherwise, `feature_idx_in_` will be used. + target_idx_in_ (List[int], *optional*): + The column indices of the target that were used for fitting. This is + automatically set. + target_name_in_ (List[str], *optional*): + The column names of the target features. This is automatically set during + fitting. Only set if the target data during fit supports column names. + Transformations will use this attribute to select the target columns, if + supported. Otherwise, `target_idx_in_` will be used. + one_to_one_features (bool): + Whether the transformation results in a one-to-one mapping of input to + output features. This will be `True` if `_n_features_out` is `None` and + `False` otherwise. + _returns_tuple (bool): + Whether the transform method returns a tuple of data. See + `MinPrevalenceRowSampleFilter` for an example. + _data_fingerprint (str): + The fingerprint of the input data. This is used to recognize the input data + during transformation. If the input is the same as the data used for + fitting, information from the fit process is reused to process the input + data for transformation (e.g. selecting the same columns, etc.). 
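+
+    Example:
+        An illustrative sketch, assuming a hypothetical ``MyScalerConfig``
+        subclass defined elsewhere; it is not part of this module:
+
+            config = MyScalerConfig(output_format="pandas", keep_unused_columns=False)
+            config.save_to_cache("path/to/config.json")
+            restored = MyScalerConfig.from_config("path/to/config.json")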
+ """ + + output_format: str = field(default=None, kw_only=True, init=True, repr=False) + input_columns: SelectedColumnTypes = field( + default=None, kw_only=True, init=True, repr=False + ) + unused_columns: SelectedColumnTypes = field( + default=None, kw_only=True, init=True, repr=False + ) + keep_unused_columns: bool = field(default=True, kw_only=True, init=True, repr=False) + raise_if_missing: bool = field(default=True, kw_only=True, init=True, repr=False) + enable_caching: bool = field(default=True, kw_only=True, init=True, repr=False) + cache_output: bool = field(default=True, kw_only=True, init=True, repr=False) + load_from_cache_file: bool = field( + default=True, kw_only=True, init=True, repr=False + ) + cache_dir: str = field(default=None, kw_only=True, init=True, repr=False) + version: str = field( + default=version.__version__, kw_only=True, init=True, repr=True + ) + + _fit_process_desc: str = field( + default="Fitting the processor to the input data", init=False, repr=False + ) + _transform_process_desc: str = field( + default="Transforming the input data", init=False, repr=False + ) + + _fit_input_feature_types: SelectedFeatureTypes = field( + default=None, init=False, repr=False + ) + _transform_input_feature_types: SelectedFeatureTypes = field( + default=None, init=False, repr=False + ) + _fit_unused_feature_types: SelectedFeatureTypes = field( + default=None, init=False, repr=False + ) + _transform_unused_feature_types: SelectedFeatureTypes = field( + default=None, init=False, repr=False + ) + + _input_columns: SelectedColumnTypes = field( + default=None, kw_only=True, init=False, repr=False + ) + features_out_suffix: str = field(default=None, init=False, repr=False) + features_out_prefix: str = field(default=None, init=False, repr=False) + processor_type: str = field(default="", init=False, repr=False) + processor_name: str = field(default="", init=False, repr=False) + dataset_name: str = field(default="", init=False, repr=False) + + output_template_name: str = field(default=None, init=False, repr=False) + + # automatically generated attributes + _batch_method_prefix: str = field(default="_partial", init=False, repr=False) + is_fitted: bool = field(default=False, init=False, repr=False) + n_features_in_: int = field(default=None, init=False, repr=False) + _n_features_out: int = field(default=None, init=False, repr=False) + _features_out: list = field(default=None, init=False, repr=False) + feature_idx_in_: List[int] = field(default=None, init=False, repr=False) + feature_names_in_: List[str] = field(default=None, init=False, repr=False) + extra_idx_in_: List[List[int]] = field(default=None, init=False, repr=False) + extra_names_in_: List[List[str]] = field(default=None, init=False, repr=False) + _feature_names_out: List[str] = field(default=None, init=False, repr=False) + _feature_idx_out: List[int] = field(default=None, init=False, repr=False) + _returns_tuple: bool = field(default=False, init=False, repr=False) + _data_fingerprint: str = field(default=None, init=False, repr=False) + + @property + def one_to_one_features(self): + return self._n_features_out is None + + def to_dict(self, *args, deep=True, **kwargs): + states = super().to_dict(*args, deep=deep, **kwargs) + states["_fit_process_desc"] = self._fit_process_desc + states["_transform_process_desc"] = self._transform_process_desc + if deep: + + def set_nested_feature_type(ft_name): + ft_obj = getattr(self, ft_name) + if isinstance(ft_obj, type): + states[ft_name] = [ft_obj.__name__] + elif isinstance(ft_obj, 
tuple): + states[ft_name] = [tuple(ft.__name__ for ft in ft_obj)] + elif isinstance(ft_obj, list): + states[ft_name] = [] + for ft in ft_obj: + if isinstance(ft, type): + states[ft_name].append(ft.__name__) + elif isinstance(ft, tuple): + states[ft_name].append(tuple(f.__name__ for f in ft)) + elif ft is None: + states[ft_name].append(None) + elif ft_obj is None: + states[ft_name] = None + + set_nested_feature_type("_fit_input_feature_types") + set_nested_feature_type("_transform_input_feature_types") + set_nested_feature_type("_fit_unused_feature_types") + set_nested_feature_type("_transform_unused_feature_types") + + states["features_out_suffix"] = self.features_out_suffix + states["features_out_prefix"] = self.features_out_prefix + states["processor_name"] = self.processor_name + states["processor_type"] = self.processor_type + states["dataset_name"] = self.dataset_name + return states + + @classmethod + def from_dict( + cls, states: dict, ignore_none: bool = False, add_new_attr: bool = False + ) -> T_CONFIG: + self = super().from_dict( + states, ignore_none=ignore_none, add_new_attr=add_new_attr + ) + + if is_datasets_available(): + from datasets.features.features import _FEATURE_TYPES + else: + _FEATURE_TYPES = {} + + def get_nested_feature_type(ft): + if isinstance(ft, str): + if not is_datasets_available(): + raise ValueError( + "Trying to load cache using datasets.Feature without datasets " + "installed. Please install datasets to load the cache." + ) + return _FEATURE_TYPES.get(ft) + elif isinstance(ft, (tuple, list)): + return tuple(get_nested_feature_type(f) for f in ft) + return ft + + if self._fit_input_feature_types: + if not isinstance(self._fit_input_feature_types, list): + self._fit_input_feature_types = [self._fit_input_feature_types] + self._fit_input_feature_types = [ + get_nested_feature_type(ft) for ft in self._fit_input_feature_types + ] + if self._fit_unused_feature_types: + if not isinstance(self._fit_unused_feature_types, list): + self._fit_unused_feature_types = [self._fit_unused_feature_types] + self._fit_unused_feature_types = [ + get_nested_feature_type(ft) for ft in self._fit_unused_feature_types + ] + if self._transform_input_feature_types: + if not isinstance(self._transform_input_feature_types, list): + self._transform_input_feature_types = [ + self._transform_input_feature_types + ] + self._transform_input_feature_types = [ + get_nested_feature_type(ft) + for ft in self._transform_input_feature_types + ] + if self._transform_unused_feature_types: + if not isinstance(self._transform_unused_feature_types, list): + self._transform_unused_feature_types = [ + self._transform_unused_feature_types + ] + self._transform_unused_feature_types = [ + get_nested_feature_type(ft) + for ft in self._transform_unused_feature_types + ] + return self + + def replace_defaults(self, ignore_none=False, add_new_attr=False, **states): + for k, v in states.items(): + if isinstance(v, Unset): + continue + if (add_new_attr or hasattr(self, k)) and ( + not ignore_none or v is not None + ): + setattr(self, k, v) + return self + + +class TransformationMixin: + def _get_method(self, formats, func_type, prefix=None): + """ + Retrieves processing methods based on the function type and target format. + + Args: + format (str): The target format. + func_type (str): The type of processing method. + prefix (str, *optional*): + The prefix to be added to the method name (e.g "_partial" for batch + processing methods). Default is None. 
+ + Returns: + A list of processing methods based on the target format. + """ + funcs = [] + + if isinstance(formats, str): + formats = [formats] + for format in formats + _SPECIAL_FORMATS: + if prefix: + func = getattr(self, f"{prefix}{func_type}_{format}", None) + if func is not None: + funcs.append(func) + else: + func = getattr(self, f"{func_type}_{format}", None) + if func is not None: + funcs.append(func) + return funcs + + def _has_method(self, formats, func_type, prefix=None): + """ + Checks if at least one method exists for the given [prefix]_[func_type]_[format] + or [func_type]_[format] if prefix is not given. + + Args: + formats (str or List[str]): The target format(s). + func_type (str): The type of processing method. + prefix (str, *optional*): + The prefix to be added to the method name (e.g. "_partial" for batch + processing methods). Default is None. + + Returns: + True if a method exists for any of the given formats, False + otherwise. + """ + + if isinstance(formats, str): + formats = [formats] + for format in formats + _SPECIAL_FORMATS: + if prefix: + func = getattr(self, f"{prefix}{func_type}_{format}", None) + if func is not None: + return True + else: + func = getattr(self, f"{func_type}_{format}", None) + if func is not None: + return True + return False + + def _get_target_func( + self, + funcs, + source_format, + target_formats=None, + accepted_formats=_ORDERED_FORMATS, + ): + formats = [] + new_funcs = [] + for f in accepted_formats: + for fun in funcs: + if fun.__name__.endswith(f"_{f}"): + formats.append(f) + new_funcs.append(fun) + + if formats: + # this class has a method for the target format + to_format = formats[0] + if source_format in formats: + to_format = source_format + if target_formats is not None: + if not isinstance(target_formats, list): + target_formats = [target_formats] + for target_format in target_formats: + if target_format in formats: + to_format = target_format + break + else: + logger.warning( + f"Using {self.__class__.__name__} with `{target_formats}` is " + f"not supported. 
Formatting input to `{to_format}` instead" + ) + + return new_funcs[formats.index(to_format)], to_format + + any_funcs = [ + f + for f in funcs + if f.__name__.endswith("_any") or f.__name__.endswith("_all") + ] + if any_funcs: + return any_funcs[0], source_format + + tbl_funcs = [f for f in funcs if f.__name__.endswith("_array")] + if tbl_funcs: + target_formats = _ORDERED_FORMATS + funcs = tbl_funcs + else: + df_funcs = [f for f in funcs if f.__name__.endswith("_dataframe")] + if df_funcs: + target_formats = _DATAFRAME_FORMATS + funcs = df_funcs + else: + arr_funcs = [f for f in funcs if f.__name__.endswith("_sklearn")] + if arr_funcs: + target_formats = _SKLEARN_FORMATS + funcs = arr_funcs + + if funcs and len(funcs): + func = funcs[0] + to_format = None + if target_formats is not None and len(funcs): + # we assume that the first function handles all formats within fn_trans_format + # e.g sklearn functions with [polars, numpy, pandas] formats + if isinstance(target_formats, list): + # prioritize input format over output format if both are supported + if source_format in target_formats: + to_format = source_format + else: + to_format = target_formats[0] + + return func, to_format + + return None, target_formats + + def _parse_column_selection(self, *args, from_config=False): + if from_config: + return self._parse_column_selection_from_config(*args) + return self._parse_column_selection_from_self(*args) + + def _parse_column_selection_from_config(self, *args, from_config=False): + out = {} + if len(args): + if self.config.feature_names_in_: + out[args[0]] = self.config.feature_names_in_ + else: + out[args[0]] = [f"{i}" for i in self.config.feature_idx_in_] + if len(args) > 1: + for i, arg in enumerate(args): + if ( + self.config.extra_names_in_ + and len(self.config.extra_names_in_) > i + and self.config.extra_names_in_[i] + ): + out[arg] = self.config.extra_names_in_[i] + elif ( + self.config.extra_idx_in_ + and len(self.config.extra_idx_in_) > i + and self.config.extra_idx_in_[i] + ): + out[arg] = [f"{j}" for j in self.extra_idx_in_[i]] + return out + + def _parse_column_selection_from_self(self, *args, from_config=False): + out = {} + if len(args): + if self.feature_names_in_: + out[args[0]] = self.feature_names_in_ + else: + out[args[0]] = [f"{i}" for i in self.feature_idx_in_] + if len(args) > 1: + for i, arg in enumerate(args): + if ( + self.extra_names_in_ + and len(self.extra_names_in_) > i + and self.extra_names_in_[i] + ): + out[arg] = self.extra_names_in_[i] + elif ( + self.extra_idx_in_ + and len(self.extra_idx_in_) > i + and self.extra_idx_in_[i] + ): + out[arg] = [f"{j}" for j in self.extra_idx_in_[i]] + return out + + def _set_input_columns_and_arity(self, *args): + input_columns = None + if len(args) > 1: + input_columns = [None] * len(args) + for i, arg in enumerate(args): + if arg is not None: + if isinstance(arg, (str, int)): + input_columns[i] = [arg] + elif isinstance(arg, list): + input_columns[i] = arg + else: + input_columns = args[0] or None + if isinstance(input_columns, (str, int)): + input_columns = [input_columns] + input_columns = [input_columns] + return input_columns + + def _reinsert_columns( + self, input, out, indices, unused_indices, one_to_one_features=False + ): + out_dims = DataHandler.get_shape(out) + x_dims = DataHandler.get_shape(input) + + if unused_indices and x_dims[0] == out_dims[0]: + other_col_names = DataHandler.get_column_names(input, generate_cols=True) + other_col_names = [other_col_names[i] for i in unused_indices] + other_cols = 
DataHandler.select_columns(input, other_col_names) + other_dims = DataHandler.get_shape(other_cols) + if len(other_dims) == 1: + other_dims = (other_dims[0], 1) + other_cols = DataHandler.to_frame(other_cols, "__input__") + if len(out_dims) == 1: + out_dims = (out_dims[0], 1) + out = DataHandler.to_frame(out, "__output__") + if one_to_one_features: + other_inds = unused_indices + out_inds = indices + else: + other_inds = list(range(other_dims[1])) + out_inds = list(range(other_dims[1], other_dims[1] + out_dims[1])) + + if other_dims[1] > out_dims[1]: + out = DataHandler.concat([other_cols, out], axis=1) + inds = list(np.argsort(other_inds + out_inds)) + else: + out = DataHandler.concat([out, other_cols], axis=1) + inds = list(np.argsort(out_inds + other_inds)) + + out = DataHandler.select_columns(out, inds) + return out + + def _make_columns_exclusive(self, columns): + new_set = columns.copy() + for i in reversed(range(0, len(columns) - 1)): + if columns[i] is not None and columns[i + 1] is not None: + new_set[i] = list(set(columns[i]) - set(columns[i + 1])) + return new_set + + def _get_columns( + self, + X, + *args, + input_columns=None, + input_feature_types=None, + unused_columns=None, + unused_feature_types=None, + raise_if_missing=True, + ): + assert X is not None, "Input data is None" + first_arg_row_num = DataHandler.get_shape(X)[0] + assert first_arg_row_num > 0, "Input data has no rows" + assert input_columns is None or isinstance(input_columns, list), ( + f"input_columns must be a list of column names or indices, " + f"but got {type(input_columns)}" + ) + + def get_columns( + X, + input_columns=None, + unused_columns=None, + generate_cols=False, + raise_if_missing=True, + ): + if X is None: + return None, None, None + col_names = DataHandler.get_column_names(X, generate_cols=True) + col_names_set = set(col_names) + if input_columns: + if isinstance(input_columns, tuple) or isinstance(input_columns, type): + feature_type = input_columns + if isinstance(feature_type, type): + feature_type = (feature_type,) + + if is_datasets_available(): + from datasets.features.features import _FEATURE_TYPES + + if all(f in _FEATURE_TYPES.values() for f in feature_type): + try: + input_columns = ( + DataHandler.get_column_names_by_feature_type( + X, feature_type=feature_type + ) + ) + except ValueError: + if generate_cols: + input_columns = DataHandler.get_column_names( + X, generate_cols=True + ) + else: + input_columns = None + + elif input_columns: + if isinstance(input_columns, (str, int)): + input_columns = [input_columns] + if not isinstance(input_columns[0], int): + missing_columns = set(input_columns) - col_names_set + if missing_columns and raise_if_missing: + raise ValueError( + f"Columns {missing_columns} not found in input dataset" + ) + else: + input_columns = [ + c for c in input_columns if c in col_names_set + ] + elif unused_columns: + if isinstance(unused_columns, (str, int)): + unused_columns = [unused_columns] + if isinstance(unused_columns, tuple) or isinstance( + unused_columns, type + ): + feature_type = unused_columns + if isinstance(feature_type, type): + feature_type = (feature_type,) + if is_datasets_available(): + from datasets.features.features import _FEATURE_TYPES + else: + _FEATURE_TYPES = {} + if all(f in _FEATURE_TYPES.values() for f in feature_type): + try: + unused_columns = ( + DataHandler.get_column_names_by_feature_type( + X, feature_type=feature_type + ) + or [] + ) + except ValueError: + if generate_cols: + unused_columns = [] + else: + unused_columns = 
None + if unused_columns is not None: + unused_columns = set(unused_columns) + input_columns = [c for c in col_names if c not in unused_columns] + elif generate_cols: + input_columns = DataHandler.get_column_names(X, generate_cols=True) + + if input_columns: + if isinstance(input_columns[0], str): + input_indices = DataHandler.get_column_indices(X, input_columns) + else: + input_indices = input_columns + if DataHandler.supports_named_columns(get_data_format(X)): + cols = DataHandler.get_column_names(X, generate_cols=True) + input_columns = [cols[idx] for idx in input_indices] + else: + input_columns = None + unused_indices = list( + sorted(set(range(len(col_names))) - set(input_indices)) + ) + else: + return None, None, None + + return input_columns, input_indices, unused_indices + + arity = 1 + if input_columns and isinstance(input_columns, list): + arity = len(input_columns) + else: + return None, None, None, None, None, None, None + + def parse_inputs(i): + _input_columns = input_columns[i] + + _input_feature_types = None + if input_feature_types is not None: + _input_feature_types = input_feature_types[i] + + _unused_columns = None + if unused_columns is not None: + _unused_columns = unused_columns[i] + + _unused_feature_types = None + if unused_feature_types is not None: + _unused_feature_types = unused_feature_types[i] + + return ( + _input_columns or None, + _input_feature_types or None, + _unused_columns or None, + _unused_feature_types or None, + ) + + _input_columns, _input_feature_types, _unused_columns, _unused_feature_types = ( + parse_inputs(0) + ) + + feature_names_in, feature_idx_in, unused_idx_in = get_columns( + X, + input_columns=_input_columns or _input_feature_types, + unused_columns=_unused_columns or _unused_feature_types, + generate_cols=True, + raise_if_missing=raise_if_missing, + ) + if _input_columns is not None and feature_idx_in is None: + raise ValueError( + f"Columns {_input_columns} not found in {DataHandler.get_column_names(X)}" + ) + extra_names_in = None + extra_idx_in = None + unused_extra_idx_in = None + offsets = None + assert ( + arity == len(args) + 1 + ), f"Number of column sets ({arity}) must match the arity ({len(args) + 1})" + if arity > 1 or len(args): + extra_names_in = [] + extra_idx_in = [] + offsets = [] + if len(args) and not all(arg is None for arg in args): + unused_extra_idx_in = [] + x_dims = DataHandler.get_shape(X) + if len(x_dims) == 1: + offset = 1 + else: + offset = x_dims[1] + for i, arg in enumerate(args): + ( + _input_columns, + _input_feature_types, + _unused_columns, + _unused_feature_types, + ) = parse_inputs(i + 1) + + if not is_bioset(X) and not is_dataset(X): + _unused_feature_types = None + _input_feature_types = None + if arg is None: + # look into the first input + _input_columns = _input_columns or _input_feature_types + _unused_columns = _unused_columns or _unused_feature_types + if _input_columns is None and _unused_columns is None: + _extra_names_in, _extra_idx_in, _unused_extra_idx_in = ( + None, + None, + None, + ) + else: + _extra_names_in, _extra_idx_in, _unused_extra_idx_in = ( + get_columns( + X, + input_columns=_input_columns + or _input_feature_types, + unused_columns=_unused_columns + or _unused_feature_types, + generate_cols=True, + raise_if_missing=raise_if_missing, + ) + ) + offsets.append(0) + else: + arg_dim = DataHandler.get_shape(arg) + _extra_names_in, _extra_idx_in, _unused_extra_idx_in = ( + get_columns( + arg, + input_columns=_input_columns or _input_feature_types, + 
unused_columns=_unused_columns or _unused_feature_types, + generate_cols=True, + raise_if_missing=raise_if_missing, + ) + ) + arg_dim = DataHandler.get_shape(arg) + + # only offset when the tables can be combined + if first_arg_row_num == arg_dim[0]: + offsets.append(offset) + if len(arg_dim) == 1: + offset += 1 + else: + offset += arg_dim[1] + extra_names_in.append(_extra_names_in) + extra_idx_in.append(_extra_idx_in) + unused_extra_idx_in.append(_unused_extra_idx_in) + + else: + for i in range(1, arity): + ( + _input_columns, + _input_feature_types, + _unused_columns, + _unused_feature_types, + ) = parse_inputs(i) + _extra_names_in, _extra_idx_in, _unused_extra_idx_in = get_columns( + X, + input_columns=_input_columns or _input_feature_types, + unused_columns=_unused_columns or _unused_feature_types, + generate_cols=False, + raise_if_missing=raise_if_missing, + ) + extra_names_in.append(_extra_names_in) + extra_idx_in.append(_extra_idx_in) + _extra_idx_in = set(_extra_idx_in or []) + unused_idx_in = [ + idx for idx in unused_idx_in if idx not in _extra_idx_in + ] + offsets.append(0) + return ( + feature_names_in, + feature_idx_in, + unused_idx_in, + extra_names_in, + extra_idx_in, + unused_extra_idx_in, + offsets, + ) + + def generate_fingerprint(self, fingerprint, config: BaseConfig): + hash = Hasher() + hash.update(fingerprint) + hash_str = f"{self.__module__}.{self.__class__.__name__}" + if hasattr(config, "version"): + hash_str += f"@{config.version}" + hash.update(hash_str) + fingerprint = hash.hexdigest() + fingerprint = fingerprint_from_kwargs(fingerprint, config.get_params()) + + return fingerprint + + +class BaseProcessor(TransformationMixin): + """ + Configures and manages data processing operations, supporting transformations and + handling of various configurations and states, primarily designed for batch + processing of data with options for multiprocessing and caching. + + Attributes: + update_fingerprint (bool): + Flag to determine if the fingerprint should be updated after processing. + output_dtype (type): Data type for the output features. + config (ProcessorConfig): Configuration object specifying processor settings. + cache_files (list): List of cache file paths used during processing. + + Raises: + ValueError: If an unsupported configuration type is provided." 
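+
+    Example:
+        An illustrative sketch using a hypothetical ``MyProcessor`` subclass
+        (not defined in this module); column names are placeholders:
+
+            processor = MyProcessor()
+            processor = processor.fit(X, input_columns=["feature_1", "feature_2"])
+            transformed = processor.transform(X, output_format="pandas")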
+ """ + + # process attributes + output_dtype = None + + # config attributes + config_class = ProcessorConfig + config: ProcessorConfig = None + + # internal attributes for transformation + _feature_dependent = True + _input_columns = None + _method_prefix: str = "_transform" + _fingerprint = None + _is_multiprocessing = False + extra_names_in_ = [] + _selected_indices = None + _unused_indices = None + _extra_indices = None + _unused_extra_indices = None + + # automatically generated attributes + cache_files = None + + def __init__(self, config: Optional[ProcessorConfig] = None, **kwargs): + add_new_attr = kwargs.pop("add_new_attr", False) + ignore_none = kwargs.pop("ignore_none", False) + + if config is None: + if hasattr(self, "config_class"): + self.config = self.config_class.from_dict( + kwargs, ignore_none=ignore_none, add_new_attr=add_new_attr + ) + elif isinstance(config, ProcessorConfig): + self.config = config + elif isinstance(config, dict): + self.config = self.config_class.from_dict( + config, ignore_none=ignore_none, add_new_attr=add_new_attr + ) + else: + raise ValueError(f"Unsupported config type {type(config)}") + if config is None: + self = self.set_params(**kwargs) + if kwargs.get("_function", None): + self._function = kwargs["_function"] + + @classmethod + def _from_config(cls, config: ProcessorConfig, **kwargs): + """Instantiates the processor from a configuration object.""" + return cls(config=config, **kwargs) + + def __call__(self, batch: Union[pd.DataFrame, Dict[str, np.ndarray]], **kwargs): + self = self._process_batches(batch, **kwargs) + return batch + + @sync_backup_config + def set_params(self, **kwargs): + """Sets the parameters of the processor""" + self.config = self.config.replace_defaults(**kwargs) + return self + + @property + def is_fitted(self): + """bool: Whether the processor has been fitted.""" + return self.config.is_fitted + + @property + def has_fit(self): + """bool: Whether a fit function is found for the processor""" + return self._get_method(_ORDERED_FORMATS, func_type="_fit") or self._get_method( + _ORDERED_FORMATS, func_type="_fit", prefix=self._batch_method_prefix + ) + + @property + def fingerprint(self): + """str: The fingerprint of the processor.""" + return self._parse_fingerprint(self._fingerprint) + + def _parse_fingerprint(self, fingerprint): + """Parses the fingerprint and returns the base fingerprint and the processor suffix.""" + base_fingerprint = fingerprint + processor_suffix = f"-{self.config.processor_name}" + if self.config.processor_type: + processor_suffix += f"-{self.config.processor_type}" + if self.config.dataset_name: + processor_suffix += f"-{self.config.dataset_name}" + return f"{base_fingerprint}{processor_suffix}" + + def _reset(self, config: ProcessorConfig): + """Resets the processor to its initial state.""" + # reinstatiate the processor + self._fingerprint = None + self.__init__(config=config) + + @classmethod + def from_config( + cls, path_or_config: Union[str, os.PathLike, dict, "BaseConfig"], **kwargs + ) -> T_PROC: + """Instantiates the processor from a configuration file or object.""" + config = cls.config_class.from_config(path_or_config, add_new_attr=True) + return cls(config=config, **kwargs) + + @classmethod + def from_config_transform(cls, X, path: str, **kwargs): + """Transforms the input data using the processor configuration.""" + self = cls.from_config(path, **kwargs) + self.config.is_fitted = True + return self.transform(X) + + def populate_map_kwargs(self, map_kwargs, cache_output, 
cache_file_name, num_proc): + """ + Populates the map_kwargs with default values if not provided. + + Args: + map_kwargs (dict): The keyword arguments for the map function. + cache_output (bool): Whether to keep the processed data in memory. + cache_file_name (str): The name of the cache file. + num_proc (int): The number of processes to use for multiprocessing. + + Returns: + dict: The updated map_kwargs. + """ + if "keep_in_memory" not in map_kwargs: + map_kwargs["keep_in_memory"] = not cache_output + + if "cache_file_name" not in map_kwargs: + map_kwargs["cache_file_name"] = cache_file_name + + if "num_proc" not in map_kwargs: + map_kwargs["num_proc"] = num_proc + + return map_kwargs + + def _validate_fit_params(self, arity): + if self.config._input_columns is not None: + if self.config._fit_input_feature_types is not None and len( + self.config._input_columns + ) != len(self.config._fit_input_feature_types): + example_arg = ", ".join(["None" for _ in range(arity)]) + assert False, ( + "`_fit_input_feature_types` is defined in " + f"{self.config.__class__.__name__} but does not match the arity of " + f"the fit function in {self.__class__.__name__} (i.e. len(" + "self.config._fit_input_feature_types) != " + "len(self.config._input_columns) -> " + f"{len(self.config._fit_input_feature_types)} != " + f"{len(self.config._input_columns)}).\n" + "This can be corrected by doing, for example:\n" + f"_fit_input_feature_types = field(\n" + f" default_factory=lambda: [{example_arg}], init=False, " + "repr=False\n" + ")" + ) + if self.config._fit_unused_feature_types is not None and len( + self.config._input_columns + ) != len(self.config._fit_unused_feature_types): + example_arg = ", ".join(["None" for _ in range(arity)]) + assert False, ( + "`_fit_unused_feature_types` is defined in " + f"{self.config.__class__.__name__} but does not match the arity of " + f"the fit function in {self.__class__.__name__} (i.e. 
len(" + "self.config._fit_unused_feature_types) != " + "len(self.config._input_columns) -> " + f"{len(self.config._fit_unused_feature_types)} != " + f"{len(self.config._input_columns)}).\n" + "This can be corrected by doing, for example:\n" + f"_fit_unused_feature_types = field(\n" + f" default_factory=lambda: [{example_arg}], init=False, " + "repr=False\n" + ")" + ) + + def fit( + self, + X, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + """Must be implemented by subclasses if the processor is trainable.""" + # only use this fit method if no concrete fit method is found + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def _process_fit( + self, + X, + *args, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> T_PROC: + """ + Fits the processor to the input data, preparing it for transformation. This process + may involve learning parameters from the data, validating input formats, and setting + up caching mechanisms. + + Args: + X (Union[np.ndarray, pd.DataFrame, Bioset, IterableDataset, DatasetDict, IterableDatasetDict]): + The input data to fit the processor on. Can be a variety of types including + numpy arrays, pandas DataFrames, or Hugging Face's `datasets` objects. + *args (Any, optional): + Additional input data to fit the processor on, such as target data for + supervised learning tasks. Defaults to None. + batched (bool, optional): + Whether to process the data in batches. This can be beneficial for large datasets. Defaults to None, + which will use the processor's default behavior. + batch_transform (bool, optional): + Specifies if transformation should be applied in batches. Defaults to None. + batch_fit (bool, optional): + Specifies if fitting should be applied in batches. Defaults to None. + batch_size (int, optional): + Size of each batch when `batched` is True. Defaults to 1000. + map_kwargs (dict, optional): + Additional keyword arguments to pass to the map function when processing datasets. Defaults to None. + cache_output (bool, optional): + Whether to keep the processed data in memory. Useful for avoiding disk IO. Defaults to None. + batch_format (str, optional): + The format to convert the input data to before processing. Defaults to None. + batch_format_kwargs (dict, optional): + Additional keyword arguments for the input format conversion. Defaults to None. + fn_output_format_kwargs (dict, optional): + Additional keyword arguments for the output format conversion. Defaults to None. + split (str, optional): + If the input is a DatasetDict or IterableDatasetDict, this specifies the split to process. Defaults to 'train'. + cache_dir (str, Path, optional): + Directory where processed datasets should be cached. 
Defaults to None, which uses the processor's default cache directory. + fingerprint (str, optional): + A unique identifier for the processing operation, used for caching. Defaults to None. + load_from_cache_file (bool, optional): + Whether to load the fitted processor from a cache file if available. Defaults to None, which follows the processor's default behavior. + update_fingerprint (bool, optional): + Whether to update the fingerprint after processing. Useful for ensuring uniqueness in caching. Defaults to None. + keep_unused_columns (bool, optional): + Whether to retain features in the dataset that are not processed by this processor. Defaults to True. + + Returns: + self: The fitted processor instance. + + Raises: + ValueError: If `input_columns` are specified but do not exist in the dataset. + """ + + funcs = self._get_method(_ORDERED_FORMATS, func_type="_fit") + assert self.config._input_columns is not None or len(funcs) == 0, ( + f"The `fit` method of `{self.__class__.__name__}` must call:\n" + "```\n" + "self.config._input_columns = self._set_input_columns_and_arity(*args)" + "\n```\n" + "Where `*args` are the columns for each input dataset." + ) + + if not hasattr(self, "config_"): + self.config_ = copy.deepcopy(self.config) + else: + self._reset(copy.deepcopy(self.config_)) + + if is_dataset_dict(X): + raise ValueError( + "Please provide the dataset directly instead of the dictionary: processor.fit(dataset['train'])" + ) + + if cache_output is None: + cache_output = self.config.enable_caching and self.config.cache_output + + if cache_output: + self.config._data_fingerprint = getattr( + X, "_fingerprint", None + ) or fingerprint_from_data(X) + else: + self.config._data_fingerprint = None + + if cache_output: + self._fingerprint = fingerprint + + if not fingerprint: + self._fingerprint = fingerprint_from_kwargs( + self.config._data_fingerprint, + self.config.get_params(), + ) + + if cache_dir is not None: + cache_dir = expand_path(str(cache_dir)) + cache_dir = os.path.join(cache_dir, "processors") + + cache_dir = generate_cache_dir( + self, + self.config._data_fingerprint, + root_dir=cache_dir or biofit.config.BIOFIT_PROCESSORS_CACHE, + ) + + if cache_dir: + if cache_file_name: + if is_remote_url(cache_file_name): + raise ValueError( + "`cache_file_name` is a remote URL. Please provide the " + "file name only. You can specify the directory using " + "`cache_dir`." + ) + elif os.path.isabs(cache_file_name): + raise ValueError( + "`cache_file_name` is an absolute path. Please provide the " + "file name only. You can specify the directory using " + "`cache_dir`." 
+ ) + self.cache_files = [ + { + "filename": get_cache_file_name( + cache_dir, self.fingerprint, cache_file_name + ) + } + ] + + self._validate_fit_params(len(args) + 1) + return self._fit( + X, + *args, + funcs=funcs, + cache_file_name=cache_file_name, + input_columns=self.config._input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + cache_dir=cache_dir, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + ) + + @keep_dataset_fingerprint + def _fit( + self, + X, + *args, + funcs: List[Callable] = None, + cache_file_name: str = None, + input_columns: List[str] = None, + raise_if_missing: bool = None, + cache_output: bool = False, + load_from_cache_file: bool = None, + cache_dir: str = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + ): + """ + Fits the processor to the input data. + + Args: + X (Any): The input data. + y (Any, optional): The target data. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + Any: The fitted processor. + """ + + # upadate fit_config with dataset and processor info + + if cache_output is None: + cache_output = self.config.enable_caching and self.config.cache_output + + if load_from_cache_file is None: + load_from_cache_file = ( + self.config.enable_caching and self.config.load_from_cache_file + ) + try: + if load_from_cache_file and self.cache_files: + cache_file_name = self.cache_files[0]["filename"] + self = self.load_processed_estimator_from_cache(cache_file_name) + logger.info(f"Loading cached processed estimator at {cache_file_name}") + else: + raise NonExistentCacheError + except NonExistentCacheError: + ( + selected_columns, + selected_indices, + unused_indices, + extra_columns, + extra_indices, + unused_extra_indices, + offsets, + ) = self._get_columns( + X, + *args, + input_columns=input_columns, + input_feature_types=self.config._fit_input_feature_types, + unused_columns=self.config.unused_columns, + unused_feature_types=self.config._fit_unused_feature_types, + raise_if_missing=raise_if_missing + if raise_if_missing is not None + else True, + ) + + self.config.feature_idx_in_ = selected_indices + self.config.feature_names_in_ = selected_columns + self.config.extra_names_in_ = extra_columns + self.config.extra_idx_in_ = extra_indices + + extra = [] + extra_to_pass = [] + # two options: either all args are the same number of rows as X and we can + # combine them or they are not and we need to pass them separately. + # Advantage of combining is that we can apply multiprocessing or batching + # by using the same indices. 
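+            # For example, when a subclass calls _process_fit(X, y) with X of shape
+            # (n, p) and y of shape (n,), the two are concatenated column-wise into
+            # one (n, p + 1) table, so a single map() pass sees both inputs and the
+            # stored indices for y's columns are simply offset by p.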
Also interops with HF's Bioset.map + if len(args) and not all(arg is None for arg in args): + main_dims = DataHandler.get_shape(X) + combine_all = True + for arg in args: + if arg is not None: + arg_dims = DataHandler.get_shape(arg) + if arg_dims[0] != main_dims[0]: + combine_all = False + break + for i, arg in enumerate(args): + if DataHandler.supports_named_columns(arg) and combine_all: + cols = DataHandler.get_column_names(arg, generate_cols=True) + cols = [f"{c}_{i}" for c in cols] + extra.append(DataHandler.set_column_names(arg, cols)) + else: + extra.append(arg) + if combine_all: + input = DataHandler.concat( + [X] + [ext for ext in extra if ext is not None], axis=1 + ) + extra_to_pass = None + else: + extra_to_pass = extra + extra = None + else: + input = X + + fit_map_kwargs, pooler = self._prepare_fit_kwargs( + funcs, + input, + X, + extra_inputs=extra, + extra_untouched_inputs=extra_to_pass, + selected_indices=selected_indices, + unused_indices=unused_indices, + extra_indices=extra_indices, + unused_extra_indices=unused_extra_indices, + offsets=offsets, + map_kwargs=map_kwargs, + num_proc=num_proc, + batch_format=batch_format, + batched=batched, + batch_size=batch_size, + ) + input, fit_map_kwargs = self._process_fit_input(input, **fit_map_kwargs) + runner = None + out = self + if fit_map_kwargs["fn_kwargs"]["fn"]: + if is_ray_available() and "ray" in sys.modules: + import ray.data + + if isinstance(X, ray.data.Bioset): + fit_ray_kwargs = self._convert_map_kwargs_to_ray_kwargs( + fit_map_kwargs, batch_format=batch_format, is_fit=True + ) + return X, fit_ray_kwargs + + pooler = pooler if pooler is not None else self._pool_fit + + @wraps(BaseProcessor.map) + def runner(*args, **map_kwargs): + out = self.map(*args, **map_kwargs) + if len(out) == 1: + return out[0] + return pooler(out) + + if is_iterable_dataset(input): + from datasets import IterableDataset + + @wraps(IterableDataset.map) + def runner(*args, **map_kwargs): + if len(args) > 0: + ds = args[0] + args = args[1:] + else: + ds = map_kwargs.pop("self") + return ds.map(*args, **map_kwargs) + + out = self.run(input, runner=runner, **fit_map_kwargs) + + self = self._process_fit_output(input, out) + + self.config.n_features_in_ = self.config.n_features_in_ + if ( + self.config.feature_idx_in_ is not None + and self.config.feature_names_in_ is None + ): + self.config.n_features_in_ = len(self.config.feature_idx_in_) + + self.config._n_features_out = ( + self.config._n_features_out + or self.config._n_features_out + or getattr(self, "n_features_out", None) + ) + + temp_file = None + if cache_output and self.cache_files: + cache_file_name = ( + Path(self.cache_files[0]["filename"]).resolve().as_posix() + ) + cache_dir = os.path.dirname(cache_file_name) + temp_file = tempfile.NamedTemporaryFile( + "wb", dir=cache_dir, delete=False + ) + try: + self.config.save_to_cache( + temp_file.name, fingerprint=self.fingerprint + ) + except (Exception, KeyboardInterrupt): + temp_file.close() + if os.path.exists(temp_file.name): + os.remove(temp_file.name) + raise + + if temp_file and self.cache_files: + move_temp_file(temp_file, cache_file_name) + self.config.is_fitted = True + return self + + def _validate_transform_params(self, arity): + """Validates the input arguments for the transform function. + This is used to ensure that the input columns match the expected arity of the + transform function. Arity is defined by the length of self._input_columns. 
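+
+        For example (sketch, assuming a subclass whose transform takes a second
+        input such as sample metadata), calling
+
+            self._input_columns = self._set_input_columns_and_arity(
+                input_columns, metadata_columns
+            )
+
+        gives an arity of 2, so a config that declares
+        `_transform_input_feature_types` must also list exactly two entries:
+
+            _transform_input_feature_types: List[Type] = field(
+                default_factory=lambda: [None, None], init=False, repr=False
+            )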
+ """ + if self._input_columns is not None: + if self.config._transform_input_feature_types is not None and len( + self._input_columns + ) != len(self.config._transform_input_feature_types): + example_arg = ", ".join(["None" for _ in range(arity)]) + assert False, ( + "`_transform_input_feature_types` is defined in " + f"{self.config.__class__.__name__} but does not match the arity of " + f"the transform function in {self.__class__.__name__} (i.e. len(" + "self.config._transform_input_feature_types) != " + "len(self._input_columns) -> " + f"{len(self.config._transform_input_feature_types)} != " + f"{len(self._input_columns)}).\n" + "This can be corrected by doing, for example:\n" + f"_transform_input_feature_types = field(\n" + f" default_factory=lambda: [{example_arg}], init=False, " + "repr=False\n" + ")" + ) + if self.config._transform_unused_feature_types is not None and len( + self._input_columns + ) != len(self.config._transform_unused_feature_types): + example_arg = ", ".join(["None" for _ in range(arity)]) + assert False, ( + "`_transform_unused_feature_types` is defined in " + f"{self.config.__class__.__name__} but does not match the arity of " + f"the transform function in {self.__class__.__name__} (i.e. len(" + "self.config._transform_unused_feature_types) != " + "len(self._input_columns) -> " + f"{len(self.config._transform_unused_feature_types)} != " + f"{len(self._input_columns)}).\n" + "This can be corrected by doing, for example:\n" + f"_transform_unused_feature_types = field(\n" + f" default_factory=lambda: [{example_arg}], init=False, " + "repr=False\n" + ")" + ) + + def _process_transform( + self, + X, + *args, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + method_type = self._method_prefix[1:] + + assert self._input_columns is not None, ( + f"The `{method_type}` method of `{self.__class__.__name__}` must call:\n" + "```\n" + "self._input_columns = self._set_input_columns_and_arity(*args)" + "\n```\n" + "Where `*args` are the columns for each input dataset." + ) + + if is_dataset_dict(X): + raise ValueError( + "Please provide the dataset directly instead of the dictionary: processor.transform(dataset['train'])" + ) + data_fingerprint = getattr(X, "_fingerprint", None) or fingerprint_from_data(X) + cache_output = ( + cache_output is not None + and cache_output + or (self.config.enable_caching and self.config.cache_output) + ) + + if cache_output: + if not fingerprint: + hash = Hasher() + hash.update(self._fingerprint) + hash.update(data_fingerprint) + fingerprint = fingerprint_from_kwargs( + hash.hexdigest(), + { + "input_columns": self._input_columns, + "unused_columns": self.config.unused_columns, + }, + ) + + if cache_dir is not None: + cache_dir = expand_path(str(cache_dir)) + cache_dir = os.path.join(cache_dir, "datasets") + cache_dir = cache_dir or biofit.config.BIOFIT_DATASETS_CACHE + cache_dir = generate_cache_dir( + X, + data_fingerprint, + root_dir=cache_dir, + ) + + if cache_file_name: + if is_remote_url(cache_file_name): + raise ValueError( + "`cache_file_name` is a remote URL. Please provide the " + "file name only. You can specify the directory using " + "`cache_dir`." 
+ ) + elif os.path.isabs(cache_file_name): + raise ValueError( + "`cache_file_name` is an absolute path. Please provide the " + "file name only. You can specify the directory using " + "`cache_dir`." + ) + + keep_unused_columns = ( + keep_unused_columns + if keep_unused_columns is not None + else self.config.keep_unused_columns + ) + self.keep_unused_columns = keep_unused_columns + if ( + self._input_columns is None + and self._feature_dependent + and data_fingerprint == self.config._data_fingerprint + ): + if DataHandler.supports_named_columns(X): + cols = self.config.feature_names_in_ or self.config.feature_idx_in_ + else: + cols = self.config.feature_idx_in_ + if ( + cols + and self.config.extra_idx_in_ + and len(args) > 0 + and len(args) == len(self.config.extra_idx_in_) + ): + cols = [cols] + for i in range(len(self.config.extra_idx_in_)): + if self.config.extra_names_in_[ + i + ] and DataHandler.supports_named_columns(args[i]): + cols.append(self.config.extra_names_in_[i]) + else: + cols.append(self.config.extra_idx_in_[i]) + + self._input_columns = cols + + self._validate_transform_params(len(args) + 1) + return self._transform( + X, + *args, + input_columns=self._input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_file_name=cache_file_name, + cache_dir=cache_dir, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + @keep_dataset_fingerprint + def _transform( + self, + X, + *args, + input_columns: List[str] = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_file_name: str = None, + cache_dir: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + """ + Transforms the input data. + + Args: + X (Any): The input data. + y (Any, optional): The target data. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + Any: The computed processor. 
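+
+        Example (sketch; this is normally reached through the public `transform`
+        of a fitted subclass, and `"pandas"` is just one of the formats handled
+        by `DataHandler`):
+
+            out = processor.transform(
+                X,
+                output_format="pandas",
+                keep_unused_columns=False,
+            )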
+ """ + if not self.is_fitted: + raise NotFittedError + + ( + _, + self._selected_indices, + self._unused_indices, + _, + self._extra_indices, + self._unused_extra_indices, + offsets, + ) = self._get_columns( + X, + *args, + input_columns=input_columns, + input_feature_types=self.config._transform_input_feature_types, + unused_columns=self.config.unused_columns, + unused_feature_types=self.config._transform_unused_feature_types, + raise_if_missing=raise_if_missing if raise_if_missing is not None else True, + ) + + if len(args) and not all(arg is None for arg in args): + extra = [] + for i, arg in enumerate(args): + if arg is not None: + if DataHandler.supports_named_columns(arg): + cols = DataHandler.get_column_names(arg, generate_cols=True) + cols = [f"{c}_{i}" for c in cols] + extra.append(DataHandler.set_column_names(arg, cols)) + else: + extra.append(arg) + input = DataHandler.concat([X] + extra, axis=1) + else: + input = X + + if load_from_cache_file is None: + load_from_cache_file = ( + self.config.enable_caching and self.config.load_from_cache_file + ) + + trans_map_kwargs = self._prepare_transform_kwargs( + input, + X, + *args, + selected_indices=self._selected_indices, + unused_indices=self._unused_indices, + extra_indices=self._extra_indices, + unused_extra_indices=self._unused_extra_indices, + offsets=offsets, + cache_dir=cache_dir, + new_fingerprint=fingerprint, + map_kwargs=map_kwargs or {"fn_kwargs": {}}, + batch_format=batch_format, + batched=batched, + batch_size=batch_size, + cache_output=cache_output, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + num_proc=num_proc, + keep_unused_columns=keep_unused_columns, + ) + + out = None + in_memory_table = None + if trans_map_kwargs["fn_kwargs"]["fn"]: + + @wraps(BaseProcessor.map) + def runner(*args, **map_kwargs): + return DataHandler.concat(self.map(*args, **map_kwargs)) + + if is_bioset(X) or is_dataset(X, iterable=False): + if is_biosets_available(): + from biosets import Bioset + + wrap_map = Bioset.map + else: + from datasets import Dataset + + wrap_map = Dataset.map + + @wraps(wrap_map) + def runner(*args, **map_kwargs): + if len(args) > 0: + ds = args[0] + args = args[1:] + else: + ds = map_kwargs.pop("self") + return ds.map(*args, **map_kwargs) + + elif is_iterable_dataset(X): + from datasets import IterableDataset + + @wraps(IterableDataset.map) + def runner(*args, **map_kwargs): + if len(args) > 0: + ds = args[0] + args = args[1:] + else: + ds = map_kwargs.pop("self") + return ds.map(*args, **map_kwargs) + + out = self.run(input, runner=runner, **trans_map_kwargs) + + out = self._process_transform_output( + out, + X, + *args, + output_format=output_format, + keep_unused_columns=keep_unused_columns, + selected_indices=self._selected_indices, + unused_indices=self._unused_indices, + fingerprint=trans_map_kwargs["new_fingerprint"], + ) + + if in_memory_table: + X._data = in_memory_table + + return out + + def _process_extra_inds( + self, orig_input, extra_inputs, extra_indices, unused_extra_indices + ): + """ + Processes extra indices for additional inputs. + + This method adjusts the indices for extra inputs by adding an offset based on the + dimensions of the original input and the extra inputs. It ensures that the indices + are correctly aligned with the combined input dimensions. + + Args: + orig_input: The original input data. + extra_inputs: A list of additional input data. + extra_indices: A list of indices corresponding to the extra inputs. 
+ unused_extra_indices: A list of unused indices corresponding to the extra inputs. + + Returns: + A tuple containing: + - extra_inds: A list of adjusted indices for the extra inputs. + - unused_extra_inds: A list of adjusted unused indices for the extra inputs. + """ + assert not extra_inputs or extra_indices is not None, ( + "`extra_indices` was returned as `None` from " + f"`{self.__class__.__name__}`. " + f"Was `{self.__class__.__name__}._input_columns` or " + f"`{self.__class__.__name__}.config._input_columns` set correctly?" + ) + extra_inds = copy.deepcopy(extra_indices) + unused_extra_inds = copy.deepcopy(unused_extra_indices) + if extra_inputs and not all(arg is None for arg in extra_inputs): + x_dims = DataHandler.get_shape(orig_input) + if len(x_dims) == 1: + x_dims = (x_dims[0], 1) + offset = x_dims[1] + extra_inds = [] + unused_extra_inds = [] + if unused_extra_indices is None: + unused_extra_indices = [None] * len(extra_indices) + for inds, un_inds, arg in zip( + extra_indices, unused_extra_indices, extra_inputs + ): + if arg is not None: + if inds is not None and len(inds) > 0: + extra_inds.append([i + offset for i in inds]) + else: + extra_inds.append(None) + if un_inds is not None and len(un_inds) > 0: + unused_extra_inds.append([i + offset for i in un_inds]) + else: + unused_extra_inds.append(None) + + arg_dims = DataHandler.get_shape(arg) + if len(arg_dims) == 1: + arg_dims = (arg_dims[0], 1) + offset += arg_dims[1] + else: + extra_inds.append(None) + unused_extra_inds.append(None) + return extra_inds, unused_extra_inds + + def _prepare_fit_kwargs( + self, + funcs, + combined_inputs, + orig_input, + extra_inputs, + extra_untouched_inputs, + selected_indices=None, + unused_indices=None, + extra_indices=None, + unused_extra_indices=None, + offsets=None, + map_kwargs={"fn_kwargs": {}}, + batch_format=None, + batched=None, + batch_size=None, + num_proc=None, + ): + original_format = get_data_format(combined_inputs) + + poolers = None + if ( + batch_size is not None + and batch_size < DataHandler.get_shape(combined_inputs)[0] + and funcs + ): + batchable_funcs = self._get_method( + _ORDERED_FORMATS, + func_type="_fit", + prefix=self.config._batch_method_prefix, + ) + if not batchable_funcs: + batch_size = DataHandler.get_shape(combined_inputs)[0] + if batched is not None: + logger.warning_once( + f"There are no batched fit functions available for {self.__class__.__name__}. " + "Using non-batched fit functions." 
+ ) + batched = True + batch_size = None + batched = True + poolers = None + else: + poolers = self._get_method(_ORDERED_FORMATS, func_type="_pool_fit") + funcs = batchable_funcs + if original_format == "ray" and batch_format is None: + batch_format = ["pandas", "pd", "numpy", "np"] + + func, batch_format = self._get_target_func(funcs, original_format, batch_format) + pooler = None + if poolers: + pooler, _ = self._get_target_func(poolers, original_format, batch_format) + + func_args = inspect.getfullargspec(func).args if func else [] + + with_indices = ( + "indices" in func_args + or "indexes" in func_args + or "index" in func_args + or "ind" in func_args + or "inds" in func_args + or "idx" in func_args + or "i" in func_args + ) + + with_rank = "rank" in func_args or "rnk" in func_args or "r" in func_args + map_kwargs = map_kwargs or {} + map_kwargs = copy.deepcopy(map_kwargs) + map_kwargs["with_indices"] = with_indices + map_kwargs["with_rank"] = with_rank + + if "num_proc" not in map_kwargs: + map_kwargs["num_proc"] = num_proc + + if ( + func + and "num_proc" in map_kwargs + and not func.__name__.startswith(self.config._batch_method_prefix) + ): + # cannot use num_proc with non-batched fit functions + map_kwargs.pop("num_proc") + + map_kwargs["desc"] = self.config._fit_process_desc + if batched is None or batched: + map_kwargs["batched"] = True + map_kwargs["batch_size"] = batch_size + else: + map_kwargs["batched"] = False + map_kwargs["batch_size"] = 1 + + extra_inds, unused_extra_inds = self._process_extra_inds( + orig_input=orig_input, + extra_inputs=extra_inputs, + extra_indices=extra_indices, + unused_extra_indices=unused_extra_indices, + ) + + fn_kwargs = { + "fn": func, + "func_type": "_fit", + "extra_untouched_inputs": extra_untouched_inputs, + "selected_indices": selected_indices, + "unused_indices": unused_indices, + "extra_indices": extra_inds, + "unused_extra_indices": unused_extra_inds, + "with_metadata": "metadata" in func_args, + "in_format_kwargs": { + "target_format": batch_format, + }, + "out_format_kwargs": { + "target_format": None, + }, + } + + if "fn_kwargs" in map_kwargs: + map_kwargs["fn_kwargs"].update(fn_kwargs) + else: + map_kwargs["fn_kwargs"] = fn_kwargs + + map_kwargs["new_fingerprint"] = self.fingerprint + + return map_kwargs, pooler + + def _prepare_transform_kwargs( + self, + combined_inputs, + orig_input, + *extra_inputs, + selected_indices, + unused_indices, + extra_indices, + unused_extra_indices, + offsets=None, + batch_format=None, + map_kwargs={"fn_kwargs": {}}, + batched=None, + batch_size=1000, + cache_output=True, + cache_file_name=None, + cache_dir=None, + load_from_cache_file=True, + new_fingerprint=None, + num_proc=None, + keep_unused_columns=None, + ): + original_format = get_data_format(combined_inputs) + input_format = batch_format + + funcs = self._get_method(_ORDERED_FORMATS, func_type=self._method_prefix) + func, batch_format = self._get_target_func(funcs, original_format, input_format) + + map_kwargs = map_kwargs.copy() + func_args = inspect.getfullargspec(func).args if func else [] + indices_args = [ + "indices", + "indexes", + "index", + "ind", + "inds", + "idx", + "i", + ] + rank_args = ["rank", "rnk", "r"] + with_indices = any(arg in func_args for arg in indices_args) + with_rank = any(arg in func_args for arg in rank_args) + map_kwargs = copy.deepcopy(map_kwargs) + map_kwargs["with_indices"] = with_indices + map_kwargs["with_rank"] = with_rank + + map_kwargs["desc"] = getattr( + self.config, + self._method_prefix + 
"_process_desc", + "Transforming data", + ) + + if "keep_in_memory" not in map_kwargs: + map_kwargs["keep_in_memory"] = not cache_output + + if "cache_file_name" not in map_kwargs: + map_kwargs["cache_file_name"] = cache_file_name + + if "num_proc" not in map_kwargs: + map_kwargs["num_proc"] = num_proc + + if batched is None or batched: + map_kwargs["batched"] = True + map_kwargs["batch_size"] = batch_size + elif batched: + map_kwargs["batched"] = True + map_kwargs["batch_size"] = batch_size + else: + map_kwargs["batched"] = False + + if "load_from_cache_file" not in map_kwargs: + map_kwargs["load_from_cache_file"] = load_from_cache_file + + output_format = original_format + if batch_format in _ARROW_WRITEABLE_FORMATS: + output_format = batch_format + elif original_format not in _ARROW_WRITEABLE_FORMATS: + output_format = "arrow" + + map_kwargs["new_fingerprint"] = new_fingerprint + + features_out = None + + extra_inds, unused_extra_inds = self._process_extra_inds( + orig_input=orig_input, + extra_inputs=extra_inputs, + extra_indices=extra_indices, + unused_extra_indices=unused_extra_indices, + ) + + fn_kwargs = { + "fn": func, + "func_type": self._method_prefix, + "with_metadata": "metadata" in func_args, + "selected_indices": selected_indices, + "unused_indices": unused_indices, + "extra_indices": extra_inds, + "unused_extra_indices": unused_extra_inds, + "keep_unused_columns": keep_unused_columns, + "in_format_kwargs": { + "target_format": batch_format, + }, + "out_format_kwargs": { + "target_format": output_format, + }, + } + + if "fn_kwargs" in map_kwargs: + map_kwargs["fn_kwargs"].update(fn_kwargs) + else: + map_kwargs["fn_kwargs"] = fn_kwargs + + combined_inputs, map_kwargs = self._process_transform_input( + combined_inputs, **map_kwargs + ) + + if hasattr(orig_input, "cache_files") and orig_input.cache_files: + if cache_file_name is None: + cache_files = orig_input.cache_files[0]["filename"] + cache_dir = os.path.dirname(cache_files) + cache_file_name = f"cache-{new_fingerprint}.arrow" + cache_file_name = os.path.join(cache_dir, cache_file_name) + + if not load_from_cache_file or not ( + cache_file_name and os.path.exists(cache_file_name) + ): + if "features" not in map_kwargs or map_kwargs["features"] is None: + unsel_inds = unused_indices if unused_indices else [] + if extra_inds: + unsel_inds += [ + i + for sub in [ + inds if inds is not None else [] for inds in extra_inds + ] + for i in sub + ] + if unused_extra_inds: + unsel_inds += [ + i + for sub in [ + inds if inds is not None else [] + for inds in unused_extra_inds + ] + for i in sub + ] + unsel_inds = sorted(unsel_inds) + features_out = self._get_features_out( + combined_inputs, + selected_indices=copy.deepcopy(selected_indices), + unselected_indices=unsel_inds, + one_to_one_features=self.config.one_to_one_features, + n_features_out=self.config._n_features_out, + keep_unused_columns=keep_unused_columns, + ) + map_kwargs["features"] = features_out + else: + features_out = map_kwargs["features"] + + map_kwargs["fn_kwargs"]["feature_names"] = list(features_out.keys()) + map_kwargs["features"] = features_out + return map_kwargs + + def fit_transform( + self, + X, + *args, + **kwargs, + ): + output_format = kwargs.pop("output_format", None) + return self.fit(X, *args, **kwargs).transform( + X, output_format=output_format, **kwargs + ) + + def _process_transform_input(self, X, **kwargs): + return X, kwargs + + def _process_transform_output(self, output, input, *args, **kwargs): + output_format = kwargs.get("output_format", 
None) or get_data_format(input) + if output_format: + if output is None: + raise ValueError( + f"The output format is specified as `{output_format}` but the " + f"output from the {self._method_prefix.removeprefix('_')} method " + "is `None`." + ) + output = DataHandler.to_format(output, output_format) + if DataHandler.get_shape(output)[0] != DataHandler.get_shape(input)[0]: + return output + return output + + def get_params(self, deep=True, show_init_only=True, show_repr_only=True): + """Get the parameters of the processor.""" + return self.config.get_params( + deep=deep, show_init_only=show_init_only, show_repr_only=show_repr_only + ) + + def load_processed_estimator_from_cache( + self, cache_file_name: Optional[Union[Path, str]] = None, **kwargs + ): + """Load a processed estimator from cache if it exists, otherwise throw an error.""" + # Check if we've already cached this computation (indexed by a hash) + if cache_file_name and os.path.exists(cache_file_name): + if isinstance(cache_file_name, Path): + cache_file_name = cache_file_name.resolve().as_posix() + self.config = self.config_class.from_config_file(cache_file_name) + return self + else: + raise NonExistentCacheError + + def _get_features_out( + self, + X, + selected_indices=None, + unselected_indices=None, + one_to_one_features=True, + n_features_out=None, + keep_unused_columns=False, + ) -> "Features": + features_out = None + if unselected_indices is not None: + unsel_inds = set(unselected_indices) + else: + unsel_inds = set() + cols = DataHandler.get_column_names(X, generate_cols=True) + assert cols is not None, "Could not generate column names from input data" + if one_to_one_features: + if selected_indices is not None: + sel_inds = set(selected_indices) + else: + sel_inds = set(range(len(cols))) + else: + sel_inds = set() + + if is_bioset(X) or is_dataset(X): + # get the output features, as well as the features that need to be reinserted + features = X._info.features.copy() + elif is_datasets_available(): + from datasets.features import Value + + features = {k: Value(dtype=v) for k, v in DataHandler.get_dtypes(X).items()} + else: + features = {k: None for k in cols} + + if keep_unused_columns: + sel_inds = sel_inds.union(unsel_inds) + + out_cols = self._get_feature_names_out( + n_features_out=n_features_out, + input_features=cols, + useful_feature_inds=list(sorted(sel_inds)), + one_to_one_features=one_to_one_features, + ) + features_out = {} + if one_to_one_features: + pa_type = None + if self.output_dtype: + pa_type = string_to_arrow(self.output_dtype) + pos = 0 + for i in range(len(cols)): + if i in unsel_inds: + if not keep_unused_columns: + continue + features_out[out_cols[pos]] = features[cols[i]] + pos += 1 + else: + k = out_cols[pos] + pos += 1 + v = features[cols[i]] + if pa_type and hasattr(v, "dtype") and hasattr(v, "pa_type"): + setattr(v, "dtype", self.output_dtype) + setattr(v, "pa_type", pa_type) + features_out[k] = v + else: + if keep_unused_columns: + features_out.update({cols[i]: features[cols[i]] for i in unsel_inds}) + # the transformed features are always appended to the end + out_cols = out_cols[len(unsel_inds) :] + if is_datasets_available(): + from datasets.features import Value + + if self.output_dtype: + features_out.update( + {k: Value(dtype=self.output_dtype) for k in out_cols} + ) + else: + dtypes = list( + set([features[cols[i]].dtype for i in selected_indices]) + ) + dtype = determine_upcast(dtypes) + features_out.update({k: Value(dtype=dtype) for k in out_cols}) + else: + features_out.update({k: 
None for k in out_cols}) + + if is_datasets_available(): + from datasets.features import Features + + features_out = Features(features_out) + + return features_out + + def _prepare_runner(self, X, **fn_kwargs): + if fn_kwargs.get("with_metadata", False): + if is_bioset(X) or is_dataset(X) and is_biosets_available(): + from biosets import get_feature_metadata + + feat_metadata = get_feature_metadata(X) + feat_arrow_tbl = pa.Table.from_pylist(list(feat_metadata.values())) + try: + feat_arrow_tbl = feat_arrow_tbl.add_column( + 0, "features", pa.array(list(feat_metadata.keys())) + ) + fn_kwargs["metadata"] = feat_arrow_tbl + except Exception: + pass + + return fn_kwargs + + def run( + self, + X, + runner: Optional[Callable] = None, + fn_kwargs: dict = {}, + **map_kwargs, + ): + fn_kwargs = self._prepare_runner(X, **fn_kwargs) + if runner: + if is_bioset(X) or is_dataset(X): + format = "arrow" + if ( + "in_format_kwargs" in fn_kwargs + and "target_format" in fn_kwargs["in_format_kwargs"] + ): + format = fn_kwargs["in_format_kwargs"]["target_format"] + if format not in _ARROW_WRITEABLE_FORMATS: + format = "arrow" + with X.formatted_as(format): + kwargs = get_kwargs({"fn_kwargs": fn_kwargs, **map_kwargs}, runner) + return runner(X, self._process_batches, **kwargs) + else: + kwargs = get_kwargs({"fn_kwargs": fn_kwargs, **map_kwargs}, runner) + return runner(X, self._process_batches, **kwargs) + else: + return self._process_batches(X, **fn_kwargs) + + def _get_feature_names_out( + self, + input_features, + n_features_out=0, + useful_feature_inds=None, + one_to_one_features=True, + ): + """ + Retrieves the feature names based on the input and output data. + + Args: + input (Any): The input data. + output (Any): The output data. + + Returns: + list: The list of feature names. 
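+
+        Example (illustrative values only; the actual names depend on the config):
+
+            # one_to_one_features=True, features_out_suffix="_scaled",
+            # input_features=["otu1", "otu2", "label"], useful_feature_inds=[0, 1]
+            # -> ["otu1_scaled", "otu2_scaled"]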
+ """ + out_features = [input_features[i] for i in useful_feature_inds] + if one_to_one_features: + if input_features is None: + raise ValueError( + "Input features must be provided to generate output features" + ) + if useful_feature_inds is None: + useful_feature_inds = range(len(input_features)) + + if self.config.features_out_suffix or self.config.features_out_prefix: + if self.config.features_out_prefix: + out_features = [ + f"{self.config.features_out_prefix}{i}" for i in out_features + ] + + if self.config.features_out_suffix: + out_features = [ + f"{i}{self.config.features_out_suffix}" for i in out_features + ] + return out_features + else: + _out_features = _generate_get_feature_names_out(self, n_features_out) + if self.config.features_out_suffix or self.config.features_out_prefix: + if self.config.features_out_prefix: + _out_features = [ + f"{self.config.features_out_prefix}{col}" + for col in _out_features + ] + if self.config.features_out_suffix: + _out_features = [ + f"{col}{self.config.features_out_suffix}" + for col in enumerate(_out_features) + ] + out_features.extend(_out_features) + return out_features + out_features.extend(_out_features) + return out_features + + ## Functions for Datasets and IterableDatasets type ## + + def _convert_map_kwargs_to_ray_kwargs( + self, + map_kwargs: dict, + batch_format=None, + is_fit=True, + ): + ray_kwargs = None + if is_fit: + ray_kwargs = { + "fn": self.__class__, + "num_cpus": map_kwargs.get("num_proc", None), + "num_gpus": map_kwargs.get("num_gpus", None), + "batch_format": batch_format + if batch_format in ["numpy", "pandas"] + else None, + "batch_size": map_kwargs.get("batch_size", None), + "concurrency": 1, + "fn_kwargs": map_kwargs["fn_kwargs"], + "fn_constructor_args": (self.config,), + "zero_copy_batch": True, + } + else: + ray_kwargs = { + "fn": self._process_batches, + "num_cpus": map_kwargs.get("num_proc", None), + "num_gpus": map_kwargs.get("num_gpus", None), + "batch_format": batch_format + if batch_format in ["numpy", "pandas"] + else None, + "batch_size": map_kwargs.get("batch_size", None), + "fn_kwargs": map_kwargs["fn_kwargs"], + "fn_constructor_args": (self.config,), + "zero_copy_batch": False, + } + + return ray_kwargs + + def _dummy_func(self, X, *args, **kwargs): + """for creating fingerprint and returning the input data as is when no function is provided.""" + return X + + def _process_batches(self, X, *fn_args, **fn_kwargs): + func = fn_kwargs.get("fn", None) + if fn_kwargs["func_type"] == "_fit": + input, _fn_args, _fn_kwargs = self._process_fit_batch_input( + X, *fn_args, **fn_kwargs + ) + else: + input, _fn_args, _fn_kwargs = self._process_transform_batch_input( + X, *fn_args, **fn_kwargs + ) + + out = func(input, *_fn_args, **_fn_kwargs) + + if isinstance(out, BaseProcessor): + return self._process_fit_batch_output(out) + return self._process_transform_batch_output(X, out, **fn_kwargs) + + def _validate_fit_func_args(self, func, *fn_args): + required_args = get_required_args(func) + noun_plural = "argument" if len(required_args) == 1 else "arguments" + verb_tense = "was" if len(fn_args) == 0 else "were" + assert len(required_args) == (len(fn_args) + 1), ( + f"`{self.__class__.__name__}.fit` requires {len(required_args)} " + f"{noun_plural} {tuple(required_args)}, " + f"but only {len(fn_args) + 1} {verb_tense} " + "provided. Either provide the missing arguments or provide the input " + f"columns found in '{required_args[0]}', if applicable." 
+ ) + + def _process_fit_batch_input(self, X, *fn_args, **fn_kwargs): + func = fn_kwargs.get("fn", None) + assert X is not None, "No input data provided." + assert func is not None, "No function provided for processing the data." + in_format_kwargs = fn_kwargs.get("in_format_kwargs", {}) + selected_indices = fn_kwargs.get("selected_indices", []) + in_format_kwargs["input_columns"] = selected_indices + input = DataHandler.to_format(X, **in_format_kwargs) + _fn_kwargs = get_kwargs(fn_kwargs, func) + extra_indices = fn_kwargs.get("extra_indices", []) + extra_untouched_inputs = fn_kwargs.get("extra_untouched_inputs", []) + fn_args = list(fn_args) + if extra_untouched_inputs: + for arg in extra_untouched_inputs: + in_format_kwargs["input_columns"] = None + arg = DataHandler.to_format(arg, **in_format_kwargs) + fn_args.append(arg) + elif extra_indices is not None: + for cols in extra_indices: + if cols is not None and len(cols): + in_format_kwargs["input_columns"] = cols + arg = DataHandler.to_format(X, **in_format_kwargs) + fn_args.append(arg) + + self._validate_fit_func_args(func, *fn_args) + return input, tuple(fn_args), _fn_kwargs + + def _process_fit_batch_output(self, out): + return out + + def _process_transform_batch_input(self, X, *fn_args, **fn_kwargs): + assert X is not None, "No input data was provided for processing." + func = fn_kwargs.get("fn", None) + in_format_kwargs = fn_kwargs.get("in_format_kwargs", {}) + selected_indices = fn_kwargs.get("selected_indices", []) + in_format_kwargs["input_columns"] = selected_indices + input = DataHandler.to_format(X, **in_format_kwargs) + _fn_kwargs = get_kwargs(fn_kwargs, func) + extra_indices = fn_kwargs.get("extra_indices", []) + if extra_indices: + fn_args = list(fn_args) + for i, cols in enumerate(extra_indices): + if cols is not None and len(cols): + in_format_kwargs["input_columns"] = cols + extra_arg = DataHandler.to_format(X, **in_format_kwargs) + fn_args.insert(i, extra_arg) + return input, fn_args, _fn_kwargs + + def _process_transform_batch_output(self, input, output, **fn_kwargs): + selected_indices = fn_kwargs.get("selected_indices", None) + unused_indices = fn_kwargs.get("unused_indices", None) + keep_unused_columns = fn_kwargs.get("keep_unused_columns", None) + feature_names = fn_kwargs.get("feature_names", None) + out_dims = DataHandler.get_shape(output) + if len(out_dims) == 1: + out_dims = (out_dims[0], 1) + output = DataHandler.to_frame(output, "__output__") + + out_format_kwargs = fn_kwargs.get("out_format_kwargs", {}) + output = DataHandler.to_format(output, **out_format_kwargs) + if keep_unused_columns: + output = self._reinsert_columns( + input, + output, + selected_indices, + unused_indices, + one_to_one_features=self.config.one_to_one_features, + ) + if feature_names is not None and len(feature_names) > 0: + output = DataHandler.set_column_names(output, feature_names) + return output + + def _process_fit_input(self, input, **kwargs): + return input, kwargs + + def _process_fit_output(self, input, out): + return out + + def __repr__(self): + from sklearn.utils._pprint import _EstimatorPrettyPrinter + + N_MAX_ELEMENTS_TO_SHOW = 30 + pp = _EstimatorPrettyPrinter( + compact=True, + indent=1, + indent_at_name=True, + n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW, + ) + repr_ = self.__class__.__name__ + pp.pformat(self.config).removeprefix( + self.config.__class__.__name__ + ) + return repr_ + + def _repr_mimebundle_(self, **kwargs): + """Mime bundle used by jupyter kernels to display estimator""" + from sklearn._config 
import get_config + from sklearn.utils._estimator_html_repr import estimator_html_repr + + output = {"text/plain": repr(self)} + if get_config()["display"] == "diagram": + output["text/html"] = estimator_html_repr(self) + return output + + def shard(self, X, num_shards, index, contiguous): + if not 0 <= index < num_shards: + raise ValueError("index should be in [0, num_shards-1]") + num_rows = DataHandler.get_shape(X)[0] + if contiguous: + div = num_rows // num_shards + mod = num_rows % num_shards + start = div * index + min(index, mod) + end = start + div + (1 if index < mod else 0) + indices = range(start, end) + else: + indices = np.arange(index, num_rows, num_shards) + + return DataHandler.select_rows(X, indices) + + def _pool_fit(self, fitted_processors): + """Pool the results of the map function.""" + out = fitted_processors[0] + if len(fitted_processors) > 1: + raise NotImplementedError( + f"Pooling results from multiple processes is not supported for {self.__class__.__name__}" + ) + return out + + def map( + self, + X, + function: Optional[Callable] = None, + with_indices: bool = False, + with_rank: bool = False, + batched: bool = False, + batch_size: Optional[int] = 1000, + drop_last_batch: bool = False, + cache_output: bool = True, + input_columns: Optional[list] = None, + fn_kwargs: Optional[dict] = None, + num_proc: Optional[int] = None, + desc: Optional[str] = None, + ): + if batch_size is None: + batch_size = DataHandler.get_shape(X)[0] + num_shards = num_proc if num_proc is not None else 1 + if batched and drop_last_batch: + pbar_total = ( + DataHandler.get_shape(X)[0] + // num_shards + // batch_size + * num_shards + * batch_size + ) + else: + pbar_total = DataHandler.get_shape(X)[0] + + if num_proc: + if is_bioset(X) or is_dataset(X): + shards = [ + X.shard( + num_shards=num_shards, + index=rank, + contiguous=True, + keep_in_memory=not cache_output, + ) + for rank in range(num_proc) + ] + else: + shards = [ + self.shard( + X, + num_shards=num_proc, + index=rank, + contiguous=True, + ) + for rank in range(num_proc) + ] + else: + shards = [X] + + dataset_kwargs = { + "shard": X, + "function": function, + "with_indices": with_indices, + "with_rank": with_rank, + "batched": batched, + "batch_size": batch_size, + "drop_last_batch": drop_last_batch, + "input_columns": input_columns, + "fn_kwargs": fn_kwargs, + } + + kwargs_per_job = [ + { + **dataset_kwargs, + "shard": shards[rank], + "rank": rank, + "offset": sum(len(s) for s in shards[:rank]), + } + for rank in range(num_shards) + ] + + if len(kwargs_per_job) < num_shards: + logger.info( + f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache." + ) + + processed_data = [None] * num_shards + shards_done = 0 + if num_proc is not None and num_proc > 1: + with Pool(len(kwargs_per_job)) as pool: + logger.info(f"Spawning {num_proc} processes") + if is_datasets_available(): + from datasets.utils import tqdm + else: + from tqdm.auto import tqdm + with tqdm( + unit=" examples", + total=pbar_total, + desc=(desc or "Map") + f" (num_proc={num_proc})", + ) as pbar: + for rank, done, content in iflatmap_unordered( + pool, + BaseProcessor._map_single, + kwargs_iterable=kwargs_per_job, + ): + if done: + shards_done += 1 + logger.debug( + f"Finished processing shard number {rank} of {num_shards}." 
+ ) + processed_data[rank] = content + else: + pbar.update(content) + else: + processed_data = None + if is_datasets_available(): + from datasets.utils import tqdm + else: + from tqdm.auto import tqdm + + with tqdm( + unit=" examples", + total=pbar_total, + desc=desc or "Map", + ) as pbar: + for rank, done, content in BaseProcessor._map_single(**dataset_kwargs): + if done: + shards_done += 1 + logger.debug( + f"Finished processing shard number {rank} of {num_shards}." + ) + processed_data = content + else: + pbar.update(content) + assert processed_data is not None, "Failed to retrieve the result from map" + + return [processed_data] + for kwargs in kwargs_per_job: + del kwargs["shard"] + return processed_data + + @staticmethod + def _map_single( + shard: Any, + function: Optional[Callable] = None, + with_indices: bool = False, + with_rank: bool = False, + input_columns: Optional[List[str]] = None, + batched: bool = False, + batch_size: Optional[int] = 1000, + drop_last_batch: bool = False, + fn_kwargs: Optional[dict] = None, + rank: Optional[int] = None, + offset: int = 0, + ) -> Iterable[Tuple[int, bool, Union[int, Any]]]: + """Apply a function to all the elements in the table (individually or in batches) + and update the table (if function does update examples). + + Args: + shard (`datasets.Bioset`): Bioset to map the transform on. + function (`Callable`): with one of the following signature: + - `function(example: Dict[str, Any]) -> Dict[str, Any]` if `batched=False` and `with_indices=False` and `with_rank=False` + - `function(example: Dict[str, Any], *extra_args) -> Dict[str, Any]` if `batched=False` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) + - `function(batch: Dict[str, List]) -> Dict[str, List]` if `batched=True` and `with_indices=False` and `with_rank=False` + - `function(batch: Dict[str, List], *extra_args) -> Dict[str, List]` if `batched=True` and `with_indices=True` and/or `with_rank=True` (one extra arg for each) + + For advanced usage, the function can also return a `pyarrow.Table`. + Moreover if your function returns nothing (`None`), then `map` will run your function and return the dataset unchanged. + If no function is provided, default to identity function: lambda x: x + with_indices (`bool`, defaults to `False`): Provide example indices to `function`. Note that in this case the signature of `function` should be `def function(example, idx[, rank]): ...`. + with_rank (`bool`, default `False`): Provide process rank to `function`. Note that in this case the signature of `function` should be `def function(example[, idx], rank): ...`. + input_columns (`Optional[List[str]]`, defaults to `None`): The columns to be passed into `function` as + positional arguments. If `None`, a dict mapping to all formatted columns is passed as one argument. + batched (`bool`, defaults to `False`): Provide batch of examples to `function` + batch_size (`int`, optional, defaults to `1000`): Number of examples per batch provided to `function` if `batched=True` + `batch_size <= 0` or `batch_size == None`: Provide the full dataset as a single batch to `function` + drop_last_batch (`bool`, default: `False`): Whether a last batch smaller than the batch_size should be + dropped instead of being processed by the function. 
+ fn_kwargs (`Dict`, optional, defaults to `None`): Keyword arguments to be passed to `function` + rank: (`int`, optional, defaults to `None`): If specified, this is the process rank when doing multiprocessing + offset: (`int`, defaults to 0): If specified, this is an offset applied to the indices passed to `function` if `with_indices=True`. + """ + if fn_kwargs is None: + fn_kwargs = {} + + # If we do batch computation but no batch size is provided, default to the full dataset + if batched and (batch_size is None or batch_size <= 0): + batch_size = DataHandler.get_shape(shard)[0] + + # We set this variable to True after processing the first example/batch in + # `apply_function_on_filtered_inputs` if the map function returns a dict. + # If set to False, no new arrow table will be created + + update_data = None + input_formatter = None + if is_bioset(shard) or is_dataset(shard): + from datasets.arrow_dataset import get_formatter + + format_kwargs = shard._format_kwargs.copy() + # Lazy formatting is only available for the default format (None/python) + if not input_columns and shard._format_type is None: + format_kwargs["lazy"] = True + input_formatter = get_formatter( + shard._format_type, + features=shard._info.features, + **format_kwargs, + ) + + def apply_function_on_filtered_inputs(data, indices, offset=0): + """Utility to apply the function on a selection of columns.""" + nonlocal update_data + if ( + isinstance(data, pa.Table) + and input_formatter + and is_datasets_available() + ): + from datasets.arrow_dataset import format_table + + inputs = format_table( + data, + 0 if not batched else range(data.num_rows), + format_columns=input_columns, + formatter=input_formatter, + ) + else: + inputs = data + fn_args = ( + [inputs] + if input_columns is None + else [DataHandler.select_column(inputs, col) for col in input_columns] + ) + if offset == 0: + effective_indices = indices + else: + effective_indices = ( + [i + offset for i in indices] + if isinstance(indices, list) + else indices + offset + ) + additional_args = () + if with_indices: + additional_args += (effective_indices,) + if with_rank: + additional_args += (rank,) + processed_inputs = function(*fn_args, *additional_args, **fn_kwargs) + return processed_inputs + + num_examples_progress_update = 0 + # Optionally initialize the writer as a context manager + try: + if is_bioset(shard) or is_dataset(shard, iterable=False): + arrow_formatted_shard = shard.with_format("arrow") + + if not batched: + shard_iterable = enumerate(arrow_formatted_shard) + else: + num_rows = ( + DataHandler.get_shape(shard)[0] + if not drop_last_batch + else DataHandler.get_shape(shard)[0] // batch_size * batch_size + ) + shard_iterable = zip( + range(0, num_rows, batch_size), + arrow_formatted_shard.iter( + batch_size, drop_last_batch=drop_last_batch + ), + ) + else: + if not batched: + shard_iterable = enumerate(shard) + else: + num_rows = ( + DataHandler.get_shape(shard)[0] + if not drop_last_batch + else DataHandler.get_shape(shard)[0] // batch_size * batch_size + ) + shard_iterable = zip( + range(0, num_rows, batch_size), + DataHandler.iter( + shard, batch_size, drop_last_batch=drop_last_batch + ), + ) + processors = None + if not batched: + _time = time.time() + for i, example in shard_iterable: + processor = apply_function_on_filtered_inputs( + example, i, offset=offset + ) + if isinstance(processor, BaseProcessor): + processors = processor + else: + processors = processors or [] + processors.append(processor) + + num_examples_progress_update += 1 + if 
time.time() > _time + biofit.config.PBAR_REFRESH_TIME_INTERVAL: + _time = time.time() + yield rank, False, num_examples_progress_update + num_examples_progress_update = 0 + else: + _time = time.time() + for i, batch in shard_iterable: + num_examples_in_batch = DataHandler.get_shape(batch)[0] + indices = list( + range( + *( + slice(i, i + batch_size).indices( + DataHandler.get_shape(shard)[0] + ) + ) + ) + ) # Something simpler? + processor = apply_function_on_filtered_inputs( + batch, + indices, + offset=offset, + ) + if isinstance(processor, BaseProcessor): + processors = processor + else: + processors = processors or [] + processors.append(processor) + num_examples_progress_update += num_examples_in_batch + if time.time() > _time + biofit.config.PBAR_REFRESH_TIME_INTERVAL: + _time = time.time() + yield rank, False, num_examples_progress_update + num_examples_progress_update = 0 + + except (Exception, KeyboardInterrupt): + yield rank, False, num_examples_progress_update + raise + + yield rank, False, num_examples_progress_update + + if isinstance(processors, BaseProcessor): + yield rank, True, processors + elif isinstance(processors, list): + if len(processors) > 0: + yield rank, True, DataHandler.concat(processors) + else: + yield rank, True, None + else: + yield rank, True, processors + + def cleanup_cache_files(self, X=None, cache_dir=None, cache_file_name=None) -> int: + """Clean up cache files generated by the processor.""" + count = 0 + cache_files = [] + if not self.cache_files and X is not None: + data_fingerprint = getattr( + X, "_fingerprint", None + ) or fingerprint_from_data(X) + + if cache_dir is not None: + cache_dir = os.path.join(expand_path(str(cache_dir)), "processors") + cache_dir = generate_cache_dir(X, data_fingerprint, cache_dir=cache_dir) + if cache_dir: + cache_files = [ + { + "filename": get_cache_file_name( + cache_dir, self.fingerprint, cache_file_name + ) + } + ] + else: + cache_files = self.cache_files + + for cache_file in cache_files: + if os.path.exists(cache_file["filename"]): + os.remove(cache_file["filename"]) + count += 1 + return count diff --git a/src/biofit/stat/__init__.py b/src/biofit/stat/__init__.py new file mode 100644 index 0000000..e71b198 --- /dev/null +++ b/src/biofit/stat/__init__.py @@ -0,0 +1,9 @@ +# ruff: noqa +from .col_sum import * +from .col_mean import * +from .col_missingness import * +from .row_sum import * +from .row_mean import * +from .row_missingness import * +from .distance import * +from .correlation import * diff --git a/src/biofit/stat/col_mean/__init__.py b/src/biofit/stat/col_mean/__init__.py new file mode 100644 index 0000000..7e79fb9 --- /dev/null +++ b/src/biofit/stat/col_mean/__init__.py @@ -0,0 +1,6 @@ +# ruff: noqa +from .col_mean import ( + ColumnMeanStat, + ColumnMeanStatConfig, + ColumnMeanStatConfigForOTU, +) diff --git a/src/biofit/stat/col_mean/col_mean.py b/src/biofit/stat/col_mean/col_mean.py new file mode 100644 index 0000000..bca2964 --- /dev/null +++ b/src/biofit/stat/col_mean/col_mean.py @@ -0,0 +1,208 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Type + +import numpy as np + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes +from biofit.utils import logging + +from ..stat import Stat, StatConfig + +if TYPE_CHECKING: + import polars as pl + +logger = logging.get_logger(__name__) + + +@dataclass +class ColumnMeanStatConfig(StatConfig): + # process description + _transform_process_desc: str = field( + 
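# --- Illustrative sketch (not biofit code; names are hypothetical) of the batched-map
# pattern used in the shard-processing generator above: a worker yields
# (rank, done, payload) tuples -- progress counts while running, then the pooled result
# once the shard is exhausted.
import time


def map_shard(shard, function, batch_size=2, rank=0, refresh_interval=0.05):
    results = []
    progress, last_update = 0, time.time()
    for start in range(0, len(shard), batch_size):
        batch = shard[start:start + batch_size]
        results.append(function(batch))
        progress += len(batch)
        if time.time() > last_update + refresh_interval:
            yield rank, False, progress      # intermediate progress update
            progress, last_update = 0, time.time()
    yield rank, False, progress              # flush any remaining progress
    yield rank, True, sum(results, [])       # done flag plus pooled output


if __name__ == "__main__":
    data = list(range(10))
    for rank, done, payload in map_shard(data, lambda b: [x * 2 for x in b]):
        print("final:" if done else "progress:", payload)
# --- end of sketch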
default="Calculating column means", init=False, repr=False + ) + processor_name: str = field(default="col_mean", init=False, repr=False) + + +@dataclass +class ColumnMeanStatConfigForOTU(ColumnMeanStatConfig): + """Computes the sum of each row in the OTUAbundance feature. + This class is the same as TotalOTUAbundanceStat. + It is provided for autostat + """ + + # dataset specific attributes + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + dataset_name: str = field(default="otu", init=False, repr=False) + + +class ColumnMeanStat(Stat): + # config attributes + _config_class = ColumnMeanStatConfig + config: ColumnMeanStatConfig + + sums_ = None + counts_ = 0 + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "ColumnMeanStat": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + 
).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _fit_numpy(self, X: np.ndarray, y=None): + self.config.means = np.mean(X, axis=0) + return self + + def _partial_fit_numpy(self, X: np.ndarray, y=None): + self.counts_ += X.shape[0] + if self.sums_ is None: + self.sums_ = np.sum(X, axis=0) + else: + self.sums_ += np.sum(X, axis=0) + + # self.config.means = self.config.sums_ / self.config.counts_ + return self + + def _fit_polars(self, X: "pl.DataFrame"): + self.config.means = X.mean() + return self + + def _partial_fit_polars(self, X: "pl.DataFrame"): + self.counts_ += X.shape[0] + if self.sums_ is None: + self.sums_ = X.sum() + else: + self.sums_ += X.sum() + + # self.config.means = self.sums_ / self.counts_ + return self + + def _pool_fit(self, fitted_processors: List["ColumnMeanStat"]): + self.sums_ = sum([p.sums_ for p in fitted_processors]) + self.counts_ = sum([p.counts_ for p in fitted_processors]) + self.config.means = self.sums_ / self.counts_ + return self + + def _process_transform_output(self, output, input, *args, **kwargs): + return super()._process_transform_output( + self.config.means, input, *args, **kwargs + ) diff --git a/src/biofit/stat/col_missingness/__init__.py b/src/biofit/stat/col_missingness/__init__.py new file mode 100644 index 0000000..d6fad16 --- /dev/null +++ b/src/biofit/stat/col_missingness/__init__.py @@ -0,0 +1,8 @@ +# ruff: noqa +from .col_missingness import ( + ColumnMissingnessStat, + ColumnMissingnessStatConfig, + ColumnMissingnessStatConfigForOTU, + ColumnMissingnessStatConfigForSNP, + ColumnMissingnessStatConfigForMetagenomics, +) diff --git a/src/biofit/stat/col_missingness/col_missingness.py b/src/biofit/stat/col_missingness/col_missingness.py new file mode 100644 index 0000000..0af8cbc --- /dev/null +++ b/src/biofit/stat/col_missingness/col_missingness.py @@ -0,0 +1,321 @@ +import os +import sys +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, Type, Union + +import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.compute as pc +from biocore import DataHandler + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.utils import logging + +from ..stat import Stat, StatConfig + +if TYPE_CHECKING: + import polars as pl + +logger = logging.get_logger(__name__) + + +def is_multiprocess_mode(): + return os.getppid() > 1 + + +def _filter_features_pandas(X: pd.DataFrame, depth: Optional[float] = None): + """ + SampleFilter features in a pandas DataFrame based on their presence in the dataset. + + Args: + X (pd.DataFrame): The input DataFrame containing the features. + threshold (float, optional): The minimum required presence of a feature in the dataset. Defaults to 0.5. + depth (float, optional): The depth threshold for numeric columns. If specified, counts values less than depth. Defaults to None. + + Returns: + List[int]: A list of column indices that pass the filtering condition. 
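# --- Illustrative sketch (not biofit's API) of the incremental column-mean strategy used by
# the _partial_fit_* and _pool_fit methods above: each worker keeps running sums and row
# counts, and the pooled mean is computed once at the end.
import numpy as np


class PartialMean:
    def __init__(self):
        self.sums_ = None
        self.counts_ = 0

    def partial_fit(self, X):
        self.counts_ += X.shape[0]
        batch_sum = X.sum(axis=0)
        self.sums_ = batch_sum if self.sums_ is None else self.sums_ + batch_sum
        return self


def pool(workers):
    sums = sum(w.sums_ for w in workers)
    counts = sum(w.counts_ for w in workers)
    return sums / counts  # pooled column means


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.random((100, 3))
    workers = [PartialMean().partial_fit(chunk) for chunk in np.array_split(X, 4)]
    assert np.allclose(pool(workers), X.mean(axis=0))
# --- end of sketch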
+ """ + total_missing = X.isnull().sum(axis=0) + if depth is not None: + total_missing += (X <= depth).sum(axis=0) + return total_missing.to_frame().transpose() + + +def _filter_features_numpy(X: np.ndarray, depth: Optional[float] = None): + total_missing = np.sum(np.isnan(X) | (X is None), axis=0) + if depth is not None: + total_missing += np.sum(X <= depth, axis=0) + return total_missing[np.newaxis, :] + + +def _filter_features_polars(X: "pl.DataFrame", depth: Optional[float] = None): + if "polars" not in sys.modules: + import polars + else: + polars = sys.modules["polars"] + total_missing = X.null_count().sum() + if depth is not None: + less_than_depth = X.with_columns(polars.col("*") <= depth).sum() + total_missing += less_than_depth + + return total_missing + + +@dataclass +class ColumnMissingnessStatConfig(StatConfig): + # process description + transform_process_name: str = field( + default="Calculating the number of missing values", init=False, repr=False + ) + processor_name: str = field(default="col_missingness", init=False, repr=False) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + + # config attributes + depth: Optional[Union[float, int]] = None + + +@dataclass +class ColumnMissingnessStatConfigForOTU(ColumnMissingnessStatConfig): + """Computes the sum of each row in the OTUAbundance feature. + This class is the same as RowSumStatForOTU. + """ + + transform_process_name: str = field( + default="Calculating species richness", init=False, repr=False + ) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + dataset_name: str = field(default="otu", init=False, repr=False) + + depth: Optional[Union[float, int]] = 0 + + +@dataclass +class ColumnMissingnessStatConfigForSNP(ColumnMissingnessStatConfig): + """Computes the sum of each row in the OTUAbundance feature. + This class is the same as RowSumStatForOTU. + """ + + transform_process_name: str = field( + default="Calculating species richness", init=False, repr=False + ) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + dataset_name: str = field(default="snp", init=False, repr=False) + + depth: Optional[Union[float, int]] = 0 + + +@dataclass +class ColumnMissingnessStatConfigForMetagenomics(ColumnMissingnessStatConfig): + """Computes the sum of each row in the OTUAbundance feature. + This class is the same as RowSumStatForOTU. 
+ """ + + transform_process_name: str = field( + default="Calculating species richness", init=False, repr=False + ) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))], + init=False, + repr=False, + ) + dataset_name: str = field(default="metagenomics", init=False, repr=False) + + depth: Optional[Union[float, int]] = 0 + + +class ColumnMissingnessStat(Stat): + # config attributes + _config_class = ColumnMissingnessStatConfig + config: ColumnMissingnessStatConfig + output_dtype = "int64" + + def __init__( + self, + depth: Optional[Union[float, int]] = None, + config: Optional[ColumnMissingnessStatConfig] = None, + **kwargs, + ): + super().__init__(config=config, depth=depth, **kwargs) + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + return self + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "ColumnMissingnessStat": + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + 
map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _process_transform_input(self, X, **kwargs): + kwargs["batch_size"] = DataHandler.get_shape(X)[0] + return X, kwargs + + def _transform_numpy(self, X: np.ndarray) -> np.ndarray: + return _filter_features_numpy(X, self.config.depth) + + def _transform_pandas(self, X: pd.DataFrame) -> pd.DataFrame: + return _filter_features_pandas(X, self.config.depth) + + def _transform_polars(self, X: "pl.DataFrame") -> "pl.DataFrame": + return _filter_features_polars(X, self.config.depth) + + def _transform_arrow(self, X: pa.Table) -> pa.Table: + if self.config.depth is not None: + depth = self.config.depth + return pa.table( + { + k: [ + pc.filter( + v, pc.or_(pc.less_equal(v, depth), pc.is_null(v)) + ).length() + ] + for k, v in zip(X.column_names, X.columns) + } + ) + else: + return pa.table( + { + k: [pc.filter(v, pc.is_null(v)).length()] + for k, v in zip(X.column_names, X.columns) + } + ) diff --git a/src/biofit/stat/col_sum/__init__.py b/src/biofit/stat/col_sum/__init__.py new file mode 100644 index 0000000..75606c4 --- /dev/null +++ b/src/biofit/stat/col_sum/__init__.py @@ -0,0 +1,6 @@ +# ruff: noqa +from .col_sum import ( + ColumnSumStat, + ColumnSumStatConfig, + ColumnSumStatConfigForOTU, +) diff --git a/src/biofit/stat/col_sum/col_sum.py b/src/biofit/stat/col_sum/col_sum.py new file mode 100644 index 0000000..f83e6b9 --- /dev/null +++ b/src/biofit/stat/col_sum/col_sum.py @@ -0,0 +1,277 @@ +""" +This module provides classes for computing the sum of rows and columns in input data. + +Classes: +- RowSumStat: Computes the sum of each row in the input data. +- ColumnSumStat: Computes the sum of each column in the input data. +- ColumnStatForOTU: Computes the sum of each column in the OTUAbundance feature. +- RowSumStatForOTU: Computes the sum of each row in the OTUAbundance feature. +- TotalOTUAbundanceStat: Alias for RowSumStatForOTU. + +These classes provide methods for transforming and fitting the input data using different libraries such as numpy, pandas, and polars. + +Note: The ColumnStatForOTU and RowSumStatForOTU classes are provided for autostat and are equivalent to ColumnSumStat and RowSumStat respectively. 
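# --- Illustrative sketch (hypothetical helper, not biofit code) of the Arrow-native version
# of the same statistic as _transform_arrow above: per column, count values that are null or
# at or below `depth`, using only pyarrow.compute kernels.
import pyarrow as pa
import pyarrow.compute as pc


def arrow_column_missingness(table: pa.Table, depth=None) -> dict:
    counts = {}
    for name, col in zip(table.column_names, table.columns):
        mask = pc.is_null(col)
        if depth is not None:
            # or_kleene treats (true OR null) as true, so null entries stay counted
            mask = pc.or_kleene(mask, pc.less_equal(col, depth))
        counts[name] = pc.sum(mask).as_py() or 0
    return counts


if __name__ == "__main__":
    t = pa.table({"a": [0, 3, None], "b": [1.0, None, 2.0]})
    print(arrow_column_missingness(t, depth=0))  # {'a': 2, 'b': 1}
# --- end of sketch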
+""" + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Callable, List, Optional, Type + +import numpy as np +from biocore import DataHandler + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes +from biofit.utils import logging + +from ..stat import Stat, StatConfig + +if TYPE_CHECKING: + import polars as pl + +logger = logging.get_logger(__name__) + + +@dataclass +class ColumnSumStatConfig(StatConfig): + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _fit_process_desc: str = field( + default="Calculating column sums", init=False, repr=False + ) + _transform_process_desc: str = field( + default="Calculating column sums", init=False, repr=False + ) + processor_name: str = field(default="col_sum", init=False, repr=False) + + +class ColumnSumStat(Stat): + """ + Computes the sum of each column in the input data. + """ + + # config attributes + _config_class = ColumnSumStatConfig + config: ColumnSumStatConfig + output_dtype = "float64" + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "ColumnSumStat": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + 
fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _fit_numpy(self, X: np.ndarray): + """ + Fits the model to the input data using numpy. + + Args: + X: Input data as a numpy array. + + Returns: + Fitted model. + """ + self.config.sums_ = np.sum(X, axis=0) + return self + + def _partial_fit_numpy(self, X: np.ndarray): + """ + Updates the fitted model with additional input data using numpy. + + Args: + X: Additional input data as a numpy array. + + Returns: + Updated fitted model. + """ + if self.config.sums_ is None: + self.config.sums_ = np.sum(X, axis=0) + else: + self.config.sums_ += np.sum(X, axis=0) + return self + + def _fit_polars(self, X: "pl.DataFrame"): + """ + Fits the model to the input data using polars. + + Args: + X: Input data as a polars DataFrame. + + Returns: + Fitted model. + """ + self.config.sums_ = X.sum() + return self + + def _partial_fit_polars(self, X: "pl.DataFrame"): + """ + Updates the column sums with additional input data using polars. + + Args: + X: Additional input data as a polars DataFrame. + + Returns: + Updated fitted model. + """ + if self.config.sums_ is None: + self.config.sums_ = X.sum() + else: + self.config.sums_ += X.sum() + return self + + def _pool_fit(self, fitted_processors: List["ColumnSumStat"]): + """ + Pools the fitted models from different batches. + + Args: + fitted_processors: List of fitted models. + + Returns: + Pooled fitted model. 
+ """ + self.config.sums_ = sum( + [processor.config.sums_ for processor in fitted_processors] + ) + return self + + def _process_transform_output(self, output, input, *args, **kwargs): + return super()._process_transform_output( + self.config.sums_, input, *args, **kwargs + ) + + +@dataclass +class ColumnSumStatConfigForOTU(ColumnSumStatConfig): + # process description + _transform_process_desc: str = field( + default="Calculating total OTU abundance for each sample", + init=False, + repr=False, + ) + dataset_name: str = field(default="otu", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) diff --git a/src/biofit/stat/correlation/__init__.py b/src/biofit/stat/correlation/__init__.py new file mode 100644 index 0000000..d0cbfdf --- /dev/null +++ b/src/biofit/stat/correlation/__init__.py @@ -0,0 +1,6 @@ +# ruff: noqa + +from .correlation import ( + CorrelationStat, + CorrelationStatConfig, +) diff --git a/src/biofit/stat/correlation/correlation.py b/src/biofit/stat/correlation/correlation.py new file mode 100644 index 0000000..4e15820 --- /dev/null +++ b/src/biofit/stat/correlation/correlation.py @@ -0,0 +1,235 @@ +""" +Correlation calculation +""" + +from dataclasses import dataclass, field + +from biocore import DataHandler + +from biofit.integration.biosets import get_feature +from biofit.processing import ( + SelectedColumnTypes, + SelectedFeatureTypes, + sync_backup_config, +) +from biofit.utils import logging + +from ..stat import Stat, StatConfig + +logger = logging.get_logger(__name__) + +CORRELATION_STAT_DOCSTRING = """ +Stat features based on the correlation with the target variable. + +Args: + X: Input data + y: Target variable + input_columns: Columns to filter + target_column: Target column + **kwargs: Additional keyword arguments + +Returns: + SampleFiltered data. +""" + + +@dataclass +class CorrelationStatConfig(StatConfig): + _transform_input_feature_types: SelectedFeatureTypes = field( + default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: SelectedFeatureTypes = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None], + init=False, + repr=False, + ) + processor_name: str = field(default="correlation", init=False, repr=False) + + input_columns: str = None + target_column: str = None + method: str = "auto" + + def __post_init__(self): + if self.method not in [ + "auto", + "pearsonr", + "spearmanr", + "kendalltau", + "pointbiserialr", + ]: + raise ValueError( + f"Method {self.method} not supported. Supported methods are: pearsonr, spearmanr, kendalltau, pointbiserialr" + ) + if self.method != "auto": + self._transform_process_desc = ( + f"Calculating {self.method} correlation with target variable" + ) + + +class CorrelationStat(Stat): + """ + Correlation calculation based on the correlation with the target variable. 
+ """ + + config_class = CorrelationStatConfig + config: CorrelationStatConfig + output_dtype = "float64" + func = None + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + import scipy.stats + + if self.config.method != "auto": + self.func = getattr(scipy.stats, self.config.method) + return self + + def transform( + self, + X, + y, + input_columns: SelectedColumnTypes = None, + target_column: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity( + input_columns, target_column + ) + return self._process_transform( + X, + y, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + y=None, + input_columns: SelectedColumnTypes = None, + target_column: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + y, + input_columns=input_columns, + target_column=target_column, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _check_data(self, X, y): + if self.config.method == "pointbiserialr": + dtypes = next(iter(DataHandler.get_dtypes(y).values())) + if not any(dtype in dtypes for dtype in ["int", "bool", "str"]): + raise ValueError( + f"Point biserial correlation can only be used with binary target variables. Data type: {dtypes}" + ) + n_unique = DataHandler.nunique( + y, DataHandler.get_column_names(y, generate_cols=True)[0] + ) + if n_unique != 2: + raise ValueError( + f"Point biserial correlation can only be used with binary target variables. 
Number of unique values: {n_unique}" + ) + return y, X # the first argument must be the binary target variable + return X, y + + def _process_transform_input(self, X, **kwargs): + if self.config.method == "auto": + import scipy.stats + + inds = kwargs["fn_kwargs"]["extra_indices"][0] + is_categorical = DataHandler.is_categorical(X, inds, threshold=0.05) + if is_categorical: + n_classes = DataHandler.nunique(X, inds) + if n_classes == 2: + self.config.method = "pointbiserialr" + kwargs["desc"] = ( + "Calculating point biserial correlation with target variable" + ) + else: + self.config.method = "pearsonr" + kwargs["desc"] = ( + "Calculating pearson correlation with target variable" + ) + else: + self.config.method = "pearsonr" + kwargs["desc"] = "Calculating pearson correlation with target variable" + self.func = getattr(scipy.stats, self.config.method) + return X, kwargs + + def _transform_sklearn(self, X, y): + cols = DataHandler.get_column_names(X, generate_cols=True) + corrs = {} + for col in cols: + corr, _ = self.func( + *self._check_data( + DataHandler.select_column(X, col), DataHandler.select_column(y, 0) + ) + ) + corrs[str(col)] = [corr] + return corrs diff --git a/src/biofit/stat/distance/__init__.py b/src/biofit/stat/distance/__init__.py new file mode 100644 index 0000000..8632b03 --- /dev/null +++ b/src/biofit/stat/distance/__init__.py @@ -0,0 +1,9 @@ +# ruff: noqa +from .distance import ( + DistanceStat, + DistanceStatConfig, + DistanceStatConfigForASV, + DistanceStatConfigForOTU, + DistanceStatConfigForMetagenomics, + DistanceStatConfigForReadCount, +) diff --git a/src/biofit/stat/distance/distance.py b/src/biofit/stat/distance/distance.py new file mode 100644 index 0000000..0c62e3b --- /dev/null +++ b/src/biofit/stat/distance/distance.py @@ -0,0 +1,296 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, Type + +import numpy as np +from biocore import DataHandler +from sklearn.metrics import DistanceMetric + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes, sync_backup_config + +from ..stat import Stat, StatConfig + +if TYPE_CHECKING: + pass + + +@dataclass +class DistanceStatConfig(StatConfig): + """ + A base class for distance metrics. + + Inherits all attributes and methods from StatConfig. + + This class is tailored to handle distance metrics, including Minkowski, + weighted and unweighted, seuclidean, and mahalanobis. + """ + + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + processor_name: str = field(default="distance", init=False, repr=False) + metric: str = "euclidean" + + +@dataclass +class DistanceStatConfigForMetagenomics(DistanceStatConfig): + """ + Calculates distance metrics specifically for metagenomics data. + + Inherits all configs from DistanceStatConfig. + + This class is tailored to handle metagenomics data, including OTU abundance, + ASV abundance, and read counts, using the 'braycurtis' metric by default. 
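# --- Illustrative sketch (hypothetical helper) of the per-feature correlation computed
# above: use scipy.stats.pointbiserialr when the target is binary, otherwise pearsonr, and
# return one coefficient per feature column.
import numpy as np
import pandas as pd
import scipy.stats


def feature_target_correlation(X: pd.DataFrame, y: pd.Series) -> dict:
    binary_target = y.nunique() == 2
    corrs = {}
    for col in X.columns:
        if binary_target:
            # point-biserial expects the dichotomous variable as the first argument
            corr, _ = scipy.stats.pointbiserialr(y, X[col])
        else:
            corr, _ = scipy.stats.pearsonr(X[col], y)
        corrs[col] = corr
    return corrs


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.random((50, 3)), columns=["f1", "f2", "f3"])
    y = pd.Series(rng.integers(0, 2, 50))
    print(feature_target_correlation(X, y))
# --- end of sketch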
+ """ + + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))], + init=False, + repr=False, + ) + dataset_name: str = field(default="metagenomics", init=False, repr=False) + metric: str = "braycurtis" + + +@dataclass +class DistanceStatConfigForOTU(DistanceStatConfig): + """ + A subclass for calculating distance metrics on OTU abundance data. + + Inherits all configs from DistanceStatConfig, but is specifically + tailored for OTU abundance data, defaulting to the 'braycurtis' metric. + """ + + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + dataset_name: str = field(default="otu", init=False, repr=False) + metric: str = "braycurtis" + + +@dataclass +class DistanceStatConfigForASV(DistanceStatConfig): + """ + A subclass for calculating distance metrics on ASV abundance data. + + Inherits all attributes and methods from DistanceStat, but is specifically + tailored for ASV abundance data, defaulting to the 'braycurtis' metric. + """ + + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + dataset_name: str = field(default="asv", init=False, repr=False) + metric: str = "braycurtis" + + +class DistanceStatConfigForReadCount(DistanceStatConfig): + """ + A subclass for calculating distance metrics on read count data. + + Inherits all attributes and methods from DistanceStat, but is specifically + tailored for read count data, defaulting to the 'braycurtis' metric. + """ + + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("ReadCount")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("ReadCount")], init=False, repr=False + ) + dataset_name: str = field(default="read_count", init=False, repr=False) + metric: str = "braycurtis" + + +class DistanceStatConfigForSNP(DistanceStatConfig): + """ + A subclass for calculating distance metrics on read count data. + + Inherits all attributes and methods from DistanceStat, but is specifically + tailored for read count data, defaulting to the 'braycurtis' metric. + """ + + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False + ) + dataset_name: str = field(default="snp", init=False, repr=False) + + metric: str = "jaccard" + + +class DistanceStat(Stat): + """ + A base class for calculating distance metrics on genomic data. + + Attributes: + metric (str): The name of the distance metric to use. Defaults to 'euclidean'. + p (float): The p-norm to apply for Minkowski, weighted and unweighted. Default is 2. + w (Union[None, np.ndarray], optional): The weight vector for weighted Minkowski. Default is None. + V (Union[None, np.ndarray], optional): The variance vector for seuclidean. 
Default is None. + VI (Union[None, np.ndarray], optional): The inverse of the covariance matrix for mahalanobis. Default is None. + + Methods: + fit_sklearn(X: Union[np.ndarray, pd.DataFrame, "pl.DataFrame"]): Validates input data and calculates the pairwise distances between samples in X. + """ + + _config_class = DistanceStatConfig + config: DistanceStatConfig + output_dtype = "float64" + + def __init__( + self, + config: Optional[DistanceStatConfig] = None, + metric: Optional[str] = "euclidean", + **kwargs, + ): + super().__init__(config=config, metric=metric, **kwargs) + self.pdist = DistanceMetric.get_metric(self.config.metric).pairwise + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + self.pdist = DistanceMetric.get_metric(self.config.metric).pairwise + return self + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "DistanceStat": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self.config._n_features_out = DataHandler.get_shape(X)[0] + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + 
fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _transform_numpy(self, X: np.ndarray): + return self.pdist(X, X) diff --git a/src/biofit/stat/row_mean/__init__.py b/src/biofit/stat/row_mean/__init__.py new file mode 100644 index 0000000..a0ec42e --- /dev/null +++ b/src/biofit/stat/row_mean/__init__.py @@ -0,0 +1,6 @@ +# ruff: noqa +from .row_mean import ( + RowMeanStat, + RowMeanStatConfig, + RowMeanStatConfigForOTU, +) diff --git a/src/biofit/stat/row_mean/row_mean.py b/src/biofit/stat/row_mean/row_mean.py new file mode 100644 index 0000000..a6eecc2 --- /dev/null +++ b/src/biofit/stat/row_mean/row_mean.py @@ -0,0 +1,178 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Type + +import numpy as np + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes +from biofit.utils import logging + +from ..stat import Stat, StatConfig + +if TYPE_CHECKING: + import polars as pl + +logger = logging.get_logger(__name__) + + +@dataclass +class RowMeanStatConfig(StatConfig): + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + # process description + _transform_process_desc: str = field( + default="Calculating row means", init=False, repr=False + ) + processor_name: str = field(default="row_mean", init=False, repr=False) + _n_features_out: int = field(default=1, init=False, repr=False) + + +@dataclass +class RowMeanStatConfigForOTU(RowMeanStatConfig): + """Computes the sum of each row in the OTUAbundance feature. + This class is the same as TotalOTUAbundanceStat. 
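# --- Illustrative sketch of the pairwise sample-distance computation wrapped by
# DistanceStat above: sklearn's DistanceMetric builds the metric by name (e.g. "braycurtis"
# for abundance data, "euclidean" by default) and .pairwise returns the full
# sample-by-sample matrix.
import numpy as np
from sklearn.metrics import DistanceMetric


def pairwise_sample_distances(X: np.ndarray, metric: str = "braycurtis") -> np.ndarray:
    return DistanceMetric.get_metric(metric).pairwise(X, X)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    abundances = rng.integers(0, 100, size=(5, 8)).astype(float)
    D = pairwise_sample_distances(abundances)
    print(D.shape)                        # (5, 5)
    print(np.allclose(np.diag(D), 0.0))   # True: zero distance to self
# --- end of sketch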
+ It is provided for autostat + """ + + # dataset specific attributes + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + dataset_name: str = field(default="otu", init=False, repr=False) + + +class RowMeanStat(Stat): + # config attributes + _config_class = RowMeanStatConfig + config: RowMeanStatConfig + output_dtype = "float64" + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "RowMeanStat": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def 
_transform_numpy(self, X: np.ndarray, y=None): + if len(X.shape) == 1: + return np.mean(X)[:, None] + return np.mean(X, axis=1)[:, None] + + def _transform_polars(self, X: "pl.DataFrame", y=None): + return X.mean_horizontal().to_frame() diff --git a/src/biofit/stat/row_missingness/__init__.py b/src/biofit/stat/row_missingness/__init__.py new file mode 100644 index 0000000..ce074fd --- /dev/null +++ b/src/biofit/stat/row_missingness/__init__.py @@ -0,0 +1,8 @@ +# ruff: noqa +from .row_missingness import ( + RowMissingnessStat, + RowMissingnessStatConfig, + RowMissingnessStatConfigForOTU, + RowMissingnessStatConfigForSNP, + RowMissingnessStatConfigForMetagenomics, +) diff --git a/src/biofit/stat/row_missingness/row_missingness.py b/src/biofit/stat/row_missingness/row_missingness.py new file mode 100644 index 0000000..8d93818 --- /dev/null +++ b/src/biofit/stat/row_missingness/row_missingness.py @@ -0,0 +1,296 @@ +import os +import sys +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, Type + +import numpy as np +import pandas as pd + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes +from biofit.utils import logging + +from ..stat import Stat, StatConfig + +if TYPE_CHECKING: + import polars as pl + +logger = logging.get_logger(__name__) + + +def is_multiprocess_mode(): + return os.getppid() > 1 + + +def _filter_samples_pandas(X: pd.DataFrame, depth: Optional[float] = None): + """ + SampleFilter samples in a pandas DataFrame based on their presence in the dataset. + + Args: + X (pd.DataFrame): The input DataFrame containing the samples. + min_sample_presence (float, optional): The minimum required presence of a sample in the dataset. Defaults to 0.5. + depth (float, optional): The depth threshold for numeric columns. If specified, counts values less than depth. Defaults to None. + + Returns: + List[int]: A list of row indices that pass the filtering condition. 
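# --- Illustrative sketch of the row-mean transform above: one value per sample, returned as
# a single column of shape (n_samples, 1) so it can be attached back onto the dataset.
# Using np.atleast_2d (an assumption of this sketch, treating a 1-D input as a single row)
# avoids trying to index the scalar that np.mean returns for 1-D input.
import numpy as np


def row_means(X: np.ndarray) -> np.ndarray:
    X = np.atleast_2d(X)
    return X.mean(axis=1)[:, None]


if __name__ == "__main__":
    X = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    print(row_means(X))  # [[2.], [5.]]
# --- end of sketch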
+ """ + total_missing = X.isnull().sum(axis=1) + if depth is not None: + numeric_rows = X.select_dtypes(include=np.number).index + total_missing[numeric_rows] += (X.loc[numeric_rows] <= depth).sum(axis=1) + return total_missing.to_frame() + + +def _filter_samples_numpy(X: np.ndarray, depth: Optional[float] = None): + total_missing = np.sum(np.isnan(X) | (X is None), axis=1) + if depth is not None: + total_missing += np.sum(X <= depth, axis=1) + return total_missing[:, np.newaxis] + + +def _filter_samples_polars(X: "pl.DataFrame", depth: Optional[float] = None): + if "polars" not in sys.modules: + import polars + else: + polars = sys.modules["polars"] + total_missing = X.with_columns( + polars.col("*").is_null() | polars.col("*").is_nan() + ).sum_horizontal() + if depth is not None: + total_missing += X.with_columns(polars.col("*") <= depth).sum_horizontal() + return total_missing.to_frame() + + +@dataclass +class RowMissingnessStatConfig(StatConfig): + # process description + _transform_process_desc: str = field( + default="Calculating the number of missing values for each sample", + init=False, + repr=False, + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + processor_name: str = field(default="row_missingness", init=False, repr=False) + _n_features_out: int = field(default=1, init=False, repr=False) + + # config attributes + depth: Optional[float] = None + + +@dataclass +class RowMissingnessStatConfigForOTU(RowMissingnessStatConfig): + """Computes the sum of each row in the OTUAbundance feature. + This class is the same as TotalOTUAbundanceStat. + It is provided for autostat + """ + + # dataset specific attributes + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + transform_process_name: str = field( + default="Calculating sample richness", init=False, repr=False + ) + dataset_name: str = field(default="otu", init=False, repr=False) + + # config attributes + depth = 100 + + +@dataclass +class RowMissingnessStatConfigForSNP(RowMissingnessStatConfig): + """Computes the sum of each row in the OTUAbundance feature. + This class is the same as TotalOTUAbundanceStat. + It is provided for autostat + """ + + # dataset specific attributes + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False + ) + transform_process_name: str = field( + default="Calculating sample richness", init=False, repr=False + ) + dataset_name: str = field(default="snp", init=False, repr=False) + + # config attributes + depth = 100 + + +@dataclass +class RowMissingnessStatConfigForMetagenomics(RowMissingnessStatConfig): + """Computes the sum of each row in the OTUAbundance feature. + This class is the same as TotalOTUAbundanceStat. 
+ It is provided for autostat + """ + + # dataset specific attributes + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))], + init=False, + repr=False, + ) + transform_process_name: str = field( + default="Calculating sample richness", init=False, repr=False + ) + + # config attributes + depth: int = 100 + + +class RowMissingnessStat(Stat): + # feature attributes + one_to_one_features = False + n_features_out = 1 + + # config attributes + _config_class = RowMissingnessStatConfig + config: RowMissingnessStatConfig + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "RowMissingnessStat": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + 
raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _transform_numpy(self, X: np.ndarray) -> np.ndarray: + return _filter_samples_numpy(X, self.config.depth) + + def _transform_pandas(self, X: pd.DataFrame) -> pd.DataFrame: + return _filter_samples_pandas(X, self.config.depth) + + def _transform_polars(self, X: "pl.DataFrame") -> "pl.DataFrame": + return _filter_samples_polars(X, self.config.depth) diff --git a/src/biofit/stat/row_sum/__init__.py b/src/biofit/stat/row_sum/__init__.py new file mode 100644 index 0000000..d9e93c0 --- /dev/null +++ b/src/biofit/stat/row_sum/__init__.py @@ -0,0 +1,6 @@ +# ruff: noqa +from .row_sum import ( + RowSumStat, + RowSumStatConfig, + RowSumStatConfigForOTU, +) diff --git a/src/biofit/stat/row_sum/row_sum.py b/src/biofit/stat/row_sum/row_sum.py new file mode 100644 index 0000000..2d55eda --- /dev/null +++ b/src/biofit/stat/row_sum/row_sum.py @@ -0,0 +1,150 @@ +""" +This module provides classes for computing the sum of rows and columns in input data. + +Classes: +- RowSumStat: Computes the sum of each row in the input data. +- ColumnSumStat: Computes the sum of each column in the input data. +- ColumnStatForOTU: Computes the sum of each column in the OTUAbundance feature. +- RowSumStatForOTU: Computes the sum of each row in the OTUAbundance feature. +- TotalOTUAbundanceStat: Alias for RowSumStatForOTU. + +These classes provide methods for transforming and fitting the input data using different libraries such as numpy, pandas, and polars. + +Note: The ColumnStatForOTU and RowSumStatForOTU classes are provided for autostat and are equivalent to ColumnSumStat and RowSumStat respectively. 
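# --- Illustrative sketch of the per-sample missingness statistic above: for each row, count
# the null entries plus (optionally) the numeric entries at or below a depth threshold.
import numpy as np
import pandas as pd


def row_missingness(df: pd.DataFrame, depth=None) -> pd.Series:
    missing = df.isnull().sum(axis=1)
    if depth is not None:
        missing = missing + (df.select_dtypes(include=np.number) <= depth).sum(axis=1)
    return missing


if __name__ == "__main__":
    df = pd.DataFrame({"otu1": [0.0, 5.0, np.nan], "otu2": [np.nan, 0.0, 7.0]})
    print(row_missingness(df, depth=0).tolist())  # [2, 1, 1]
# --- end of sketch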
+""" + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Type + +import numpy as np +import pandas as pd +from biocore.utils.import_util import requires_backends + +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes +from biofit.utils import logging + +from ..stat import Stat, StatConfig + +if TYPE_CHECKING: + import polars as pl + +logger = logging.get_logger(__name__) + + +@dataclass +class RowSumStatConfig(StatConfig): + # process description + _transform_process_desc: str = field( + default="Calculating row sums", init=False, repr=False + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + processor_name: str = field(default="row_sum", init=False, repr=False) + _n_features_out: int = field(default=1, init=False, repr=False) + + +@dataclass +class RowSumStatConfigForOTU(RowSumStatConfig): + """Computes the sum of each row in the OTUAbundance feature.""" + + # dataset specific attributes + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("Abundance")], init=False, repr=False + ) + dataset_name: str = field(default="otu", init=False, repr=False) + + +class RowSumStat(Stat): + """ + Computes the sum of each row in the input data. + """ + + # config attributes + _config_class = RowSumStatConfig + config: RowSumStatConfig + output_dtype = "float64" + + def fit( + self, + X, + input_columns: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "RowSumStat": + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_fit( + X, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ) + + def _transform_numpy(self, X: np.ndarray): + """ + Transforms the input data using numpy. + + Args: + X: Input data as a numpy array. + + Returns: + Transformed data as a numpy array. + """ + if len(X.shape) == 1: + return np.sum(X)[:, None] + return np.sum(X, axis=1)[:, None] + + def _transform_pandas(self, X: pd.DataFrame): + """ + Transforms the input data using pandas. + + Args: + X: Input data as a pandas DataFrame. + + Returns: + Transformed data as a pandas DataFrame. + """ + return X.astype("float64").sum(axis=1).to_frame() + + def _transform_polars(self, X: "pl.DataFrame"): + """ + Transforms the input data using polars. + + Args: + X: Input data as a polars DataFrame. + + Returns: + Transformed data as a polars DataFrame. 
+ """ + requires_backends(self._transform_polars, "polars") + import polars as pl + + return X.cast(pl.Float64).sum_horizontal().to_frame() diff --git a/src/biofit/stat/stat.py b/src/biofit/stat/stat.py new file mode 100644 index 0000000..929d388 --- /dev/null +++ b/src/biofit/stat/stat.py @@ -0,0 +1,103 @@ +from dataclasses import dataclass, field + +from biofit.processing import BaseProcessor, ProcessorConfig, SelectedColumnTypes + + +@dataclass +class StatConfig(ProcessorConfig): + processor_type: str = field(default="stat", init=False, repr=False) + + +class Stat(BaseProcessor): + """Base class for statistical processors.""" + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + *args, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = True, + raise_if_missing: bool = False, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = {"fn_kwargs": {}}, + num_proc: int = None, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + return self.fit( + X, + input_columns=input_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _process_transform(self, *args, **kwargs): + kwargs["keep_unused_columns"] = False + return super()._process_transform(*args, **kwargs) diff --git a/src/biofit/stat/summary/__init__.py b/src/biofit/stat/summary/__init__.py new file mode 100644 index 0000000..7457204 --- /dev/null +++ b/src/biofit/stat/summary/__init__.py @@ -0,0 +1,2 @@ +# ruff: noqa +from .summary import SummaryStatForOTU diff --git a/src/biofit/stat/summary/summary.py b/src/biofit/stat/summary/summary.py new file mode 100644 index 0000000..e69de29 diff --git a/src/biofit/train.py b/src/biofit/train.py new file mode 100644 index 0000000..fad4ab5 --- /dev/null +++ 
b/src/biofit/train.py @@ -0,0 +1,564 @@ +import copy +import os +from pathlib import Path +from typing import TYPE_CHECKING, Callable, List, Union + +import numpy as np +import pandas as pd +from biocore import DataHandler +from sklearn.base import BaseEstimator +from sklearn.model_selection import ( + BaseCrossValidator, +) +from sklearn.pipeline import Pipeline + +from biofit.auto.processing_auto import ProcessorPipeline +from biofit.processing import BaseProcessor +from biofit.train_eval_utils import ( + _flatten_pipeline, + _get_data, + split, +) +from biofit.utils import ( + enable_full_determinism, + logging, +) + +if TYPE_CHECKING: + import polars as pl + from datasets import Dataset + +logger = logging.get_logger(__name__) + + +def preprocess( + preprocessor, + x_train, + y_train=None, + x_valid=None, + y_valid=None, + cache_dir=None, + transform_only=False, + raise_error=False, +): + try: + if isinstance(preprocessor, (Pipeline, ProcessorPipeline)): + for proc in preprocessor.steps: + p = proc[-1] if isinstance(proc, tuple) else proc + if isinstance(p, BaseProcessor): + extra_kwargs = { + "cache_dir": cache_dir, + # "load_from_cache_file": False, + } + # p.config.enable_caching = False + else: + extra_kwargs = {} + if transform_only: + x_train = p.transform(x_train, **extra_kwargs) + else: + x_train = p.fit_transform(x_train, **extra_kwargs) + if x_valid is not None: + x_valid = p.transform(x_valid, **extra_kwargs) + else: + if isinstance(preprocessor, BaseProcessor): + extra_kwargs = { + "cache_dir": cache_dir, + # "load_from_cache_file": False, + } + # p.config.enable_caching = False + else: + extra_kwargs = {} + if transform_only: + x_train = preprocessor.transform(x_train, **extra_kwargs) + else: + x_train = preprocessor.fit_transform(x_train, **extra_kwargs) + if x_valid is not None: + x_valid = preprocessor.transform(x_valid, **extra_kwargs) + except ValueError: + if raise_error: + raise + else: + logger.info("Preprocessing failed, using nan_to_num") + _train_cols = x_train.columns.tolist() + x_train = np.nan_to_num(x_train) + # convert back to dataframe + x_train = pd.DataFrame(x_train, columns=_train_cols) + if x_valid is not None: + _valid_cols = x_valid.columns.tolist() + x_valid = np.nan_to_num(x_valid) + x_valid = pd.DataFrame(x_valid, columns=_valid_cols) + return preprocess( + preprocessor, + x_train, + y_train, + x_valid, + y_valid, + cache_dir, + transform_only, + True, + ) + + return x_train, y_train, x_valid, y_valid + + +def train( + model: Union[BaseEstimator, BaseProcessor], + data: Union[pd.DataFrame, "pl.DataFrame", "Dataset"], + target: Union[ + pd.Series, "pl.Series", pd.DataFrame, "pl.DataFrame", "Dataset" + ] = None, + valid_data: Union[pd.DataFrame, "pl.DataFrame", "Dataset"] = None, + valid_target: Union[ + pd.Series, "pl.Series", pd.DataFrame, "pl.DataFrame", "Dataset" + ] = None, + groups: Union[pd.Series, "pl.Series"] = None, + input_columns: Union[List[str], str] = "auto", + target_columns: Union[List[str], str] = "auto", + group_name: str = None, + preprocessor: Union[BaseEstimator, BaseProcessor] = None, + eval_metric: Union[str, Callable] = None, + task: str = None, + cv: BaseCrossValidator = None, + random_state: Union[List[int], int] = 42, + save_indices: bool = False, + output_dir: str = None, + cache_dir: str = None, +): + """ + Train a model or processor on the provided data, with optional preprocessing, + validation data, and cross-validation. 
+ + This function supports training models with various data formats (pandas DataFrame, + polars DataFrame, Dataset), optional preprocessing pipelines, cross-validation + strategies, and random seeds for reproducibility. + + Parameters: + model (Union[BaseEstimator, BaseProcessor]): + The model or processor to train. This can be a scikit-learn estimator, a + pipeline, or a custom estimator that follows scikit-learn's API. + data (Union[pd.DataFrame, pl.DataFrame, "Dataset"]): + The training data. It can be a pandas DataFrame, polars DataFrame, or a + Dataset object. + target (Union[pd.Series, pl.Series, pd.DataFrame, pl.DataFrame, "Dataset"], optional): + The target values for supervised learning. If None, the target columns must + be specified in `data` using `target_columns`. + valid_data (Union[pd.DataFrame, pl.DataFrame, "Dataset"], optional): + Optional validation data. If not provided, validation can be done using + cross-validation. + valid_target (Union[pd.Series, pl.Series, pd.DataFrame, pl.DataFrame, "Dataset"], optional): + Target values for the validation data. + groups (Union[pd.Series, pl.Series], optional): + Group labels for the samples used while splitting the dataset into + train/test set. + input_columns (Union[List[str], str], optional): + Names of the input feature columns. If 'auto', the function will attempt to + infer the input columns from the data. Defaults to "auto". + target_columns (Union[List[str], str], optional): + Names of the target columns in `data`. If 'auto', the function will attempt + to infer the target columns. Defaults to "auto". + group_name (str, optional): + Name of the group column in `data` if groups are specified within the data. + preprocessor (Union[BaseEstimator, BaseProcessor], optional): + Preprocessing pipeline or estimator to apply before training the model. + eval_metric (Union[str, Callable], optional): + Evaluation metric or a callable function to evaluate the model's + performance. + task (str, optional): + Type of task to perform (e.g., 'classification', 'regression', + 'multilabel_classification', 'multi_regression'). + cv (type of BaseCrossValidator, optional): + Cross-validation splitting strategy. If provided, the model will be trained + using cross-validation. + random_state (Union[List[int], int], optional): + Random seed(s) for reproducibility. Can be a single integer or a list of + integers. Defaults to 42. + save_indices (bool, optional): + Whether to save the indices of the train and validation splits. Defaults to + False. + output_dir (str, optional): + Directory where outputs will be saved. Defaults to None. + cache_dir (str, optional): + Directory where cache files will be stored. Defaults to None. + + Returns: + Trained model or pipeline. The return type depends on the inputs: + + - If `random_state` is an integer and `cv` is None, returns a single trained model or pipeline. + - If `random_state` is a list of integers, returns a list of trained models or pipelines, one for each random seed. + - If `cv` is specified, returns a list of trained models or pipelines, one for each cross-validation fold. + + Examples: + 1. 
Basic usage with pandas DataFrame: + + ```python + from biofit.trainer import train + from biofit.models import LogisticRegressionForClassification + import pandas as pd + + # Create sample data + data = pd.DataFrame({ + 'age': [25, 32, 47, 51, 62], + 'salary': [50000, 60000, 70000, 80000, 90000], + 'purchased': [0, 1, 0, 1, 0] + }) + + # Train the model + model = train( + model=LogisticRegressionForClassification(), + data=data, + target_columns='purchased', + input_columns=['age', 'salary'] + ) + ``` + + 2. Using validation data: + + ```python + from biofit.train import train + from biofit.models import RandomForestForClassification + import pandas as pd + + # Training data + train_data = pd.DataFrame({ + 'feature': [1,2,3,4,5], + 'target': [2,4,6,8,10] + }) + + # Validation data + valid_data = pd.DataFrame({ + 'feature': [6,7], + 'target': [12,14] + }) + + # Train the model with validation data + model = train( + model=RandomForestForClassification(), + data=train_data, + valid_data=valid_data, + target_columns='target', + input_columns='feature' + ) + ``` + + 3. Cross-validation with random seed list: + + ```python + from biofit.train import train + from biofit.models import RandomForestForClassifier + from sklearn.model_selection import StratifiedKFold + import pandas as pd + + # Sample data + data = pd.DataFrame({ + 'feature1': [1,2,3,4,5,6], + 'feature2': [6,5,4,3,2,1], + 'target': [0,1,0,1,0,1] + }) + + cv = StratifiedKFold(n_splits=3) + random_seeds = [42, 7, 21] + + # Train the model with cross-validation and multiple random seeds + models = train( + model=RandomForestForClassifier(), + data=data, + target_columns='target', + input_columns=['feature1', 'feature2'], + cv=cv, + random_state=random_seeds + ) + + # models is a list of models trained with different random seeds + ``` + + 4. Using a preprocessor pipeline: + + ```python + from biofit.train import train + from biofit.auto.processing_auto import ProcessorPipeline + from biofit.models import LogisticRegressionForClassification + from sklearn.compose import ColumnTransformer + import pandas as pd + + # Sample data with categorical feature + data = pd.DataFrame({ + 'age': [25, 32, 47, 51, 62], + 'gender': ['M', 'F', 'M', 'F', 'M'], + 'purchased': [0, 1, 0, 1, 0] + }) + + # Define preprocessing + numeric_features = ['age'] + numeric_transformer = StandardScaler() + + categorical_features = ['gender'] + categorical_transformer = OneHotEncoder() + + preprocessor = ColumnTransformer( + transformers=[ + ('num', numeric_transformer, numeric_features), + ('cat', categorical_transformer, categorical_features) + ]) + + # Create a pipeline that first preprocesses the data, then applies the model + pipeline = ProcessorPipeline(steps=[ + ('preprocessor', preprocessor), + ('classifier', LogisticRegressionForClassification()) + ]) + + # Train the model with preprocessor + model = train( + model=pipeline, + data=data, + target_columns='purchased', + input_columns=['age', 'gender'] + ) + ``` + + 5. Training with a Dataset object: + + ```python + from biofit.train import train + from datasets import load_dataset + + # Load a dataset from the datasets library + dataset = load_dataset('csv', data_files='data.csv') + + # Train the model + model = train( + model=SomeModel(), + data=dataset['train'], + target_columns='label', + input_columns='text' + ) + ``` + + Notes: + - The `train` function can handle various data formats and will attempt to + process the data accordingly. 
+ - If `input_columns` or `target_columns` are set to 'auto', the function will + try to infer the columns automatically based on the data. + - The `preprocessor` can be any transformer that follows the scikit-learn API, + including pipelines. + - When using cross-validation (`cv` is not None), the function will train models + for each fold and return a list of models. + - If multiple random seeds are provided (`random_state` is a list), the function + will train multiple models for feature importance analysis, but only return + the models trained with the first random seed. + """ + x_train, y_train, x_valid, y_valid, groups, _, input_columns, target_columns = ( + _get_data( + data=data, + target=target, + valid_data=valid_data, + valid_target=valid_target, + groups=groups, + group_name=group_name, + input_columns=input_columns, + target_columns=target_columns, + format="pandas", + target_required=True, + ) + ) + if cache_dir is None and output_dir is not None: + cache_dir = (Path(output_dir) / ".cache").resolve().as_posix() + + def fit( + x_train, + y_train, + x_valid, + y_valid, + model, + preprocessor, + random_state=None, + cache_dir=None, + ): + if isinstance(model, Pipeline): + if preprocessor is None and len(model.steps) > 1: + preprocessor = _flatten_pipeline( + [p[-1] if isinstance(p, tuple) else p for p in model.steps[:-1]] + ) + preprocessor = Pipeline(preprocessor) + model = ( + model.steps[-1][1] + if isinstance(model.steps[-1], tuple) + else model.steps[-1] + ) + if preprocessor: + x_train, y_train, x_valid, y_valid = preprocess( + preprocessor, x_train, y_train, x_valid, y_valid, cache_dir + ) + + m_params = ( + model.get_params() if hasattr(model, "get_params") else model.__dict__ + ) + if hasattr(model, "set_params"): + if "random_state" in m_params: + model.set_params(random_state=random_state) + elif "seed" in m_params: + model.set_params(seed=random_state) + else: + if "random_state" in m_params: + model.random_state = random_state + elif "seed" in m_params: + model.seed = random_state + models = [] + if task in ("multilabel_classification", "multi_regression"): + # also multi_column_regression + models = [model] * y_train.shape[1] + for idx, _m in enumerate(models): + if _m.__module__.startswith("xgboost"): + _m.fit( + x_train, + y_train[:, idx], + model__eval_set=[(x_valid, y_valid[:, idx])] + if x_valid is not None + else None, + model__verbose=False, + ) + elif _m.__module__.startswith("biofit"): + old_val = _m.config.enable_caching + _m.config.enable_caching = False + if ( + "early_stopping_rounds" in m_params + and m_params["early_stopping_rounds"] is not None + and m_params["early_stopping_rounds"] > 0 + ): + _m.fit( + x_train, + y_train[:, idx], + eval_set=[(x_valid, y_valid[:, idx])] + if x_valid is not None + else None, + # load_from_cache_file=False, + ) + else: + _m.fit( + x_train, + y_train[:, idx], + load_from_cache_file=False, + ) + _m.config.enable_caching = old_val + else: + _m.fit( + x_train, + y_train[:, idx], + ) + + else: + if model.__module__.startswith("xgboost"): + model.fit( + x_train, + y_train, + model__eval_set=[(x_valid, y_valid)], + model__verbose=False, + ) + elif model.__module__.startswith("biofit"): + old_val = model.config.enable_caching + model.config.enable_caching = False + model.fit( + x_train, + y_train, + load_from_cache_file=False, + ) + model.config.enable_caching = old_val + else: + model.fit(x_train, y_train) + + models = [model] + return models, preprocessor + + count = 0 + count += 1 + + def training_loop(random_state): + 
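+        # When `cv` is None: deep-copy the model and preprocessor, fit them once on
+        # the full training data, and return the fitted pair.
+        # When `cv` is given: fit one deep-copied model/preprocessor per fold from
+        # `split(cv, ...)`, caching each fold under `cache_dir/fold_{fold}`.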
enable_full_determinism(random_state) + if cv is None: + _m = copy.deepcopy(model) + _p = copy.deepcopy(preprocessor) + + fitted_model_, preprocessor_ = fit( + x_train=x_train, + y_train=y_train, + x_valid=x_valid, + y_valid=y_valid, + model=_m, + preprocessor=_p, + random_state=random_state, + cache_dir=cache_dir, + ) + + if isinstance(fitted_model_, list) and len(fitted_model_) == 1: + fitted_model_ = fitted_model_[0] + return fitted_model_, preprocessor_ + else: + fitted_models = [] + fitted_preprocessors = [] + if hasattr(cv, "__dict__") and "random_state" in cv.__dict__: + cv.random_state = random_state + for fold, (train_idx, valid_idx) in enumerate( + split(cv, X=x_train, y=y_train, groups=groups) + ): + _cache_dir = os.path.join(cache_dir, f"fold_{fold}") + _m = copy.deepcopy(model) + _p = copy.deepcopy(preprocessor) + + m_params = _m.get_params() + if "random_state" in m_params: + _m.set_params(random_state=random_state) + elif "seed" in m_params: + _m.set_params(seed=random_state) + xtrain_fold, xvalid_fold = ( + DataHandler.select_rows(x_train, train_idx), + DataHandler.select_rows(x_train, valid_idx), + ) + ytrain_fold, yvalid_fold = ( + DataHandler.select_rows(y_train, train_idx), + DataHandler.select_rows(y_train, valid_idx), + ) + fitted_model_fold, fitted_preprocessor_fold = fit( + xtrain_fold, + ytrain_fold, + xvalid_fold, + yvalid_fold, + _m, + _p, + random_state=random_state, + cache_dir=_cache_dir, + ) + if isinstance(fitted_model_fold, list) and len(fitted_model_fold) == 1: + fitted_model_fold = fitted_model_fold[0] + fitted_models.append(fitted_model_fold) + fitted_preprocessors.append(fitted_preprocessor_fold) + + return fitted_models, fitted_preprocessors + + def create_pipeline(model, preprocessor): + pipeline = [] + if isinstance(model, list): + for m, p in zip(model, preprocessor): + if isinstance(p, (Pipeline, ProcessorPipeline)): + p = ProcessorPipeline(_flatten_pipeline(p)) + pipeline.append( + Pipeline([("preprocessor", p), ("model", m)]) if p else m + ) + else: + if isinstance(preprocessor, (Pipeline, ProcessorPipeline)): + preprocessor = ProcessorPipeline(_flatten_pipeline(preprocessor)) + pipeline = ( + Pipeline([("preprocessor", preprocessor), ("model", model)]) + if preprocessor + else model + ) + return pipeline + + if isinstance(random_state, list): + pipeline_list = [] + for rs in random_state: + m, p = training_loop(rs) + pipeline_list.append(create_pipeline(m, p)) + + return pipeline_list + else: + fitted_model, fitted_preprocessor = training_loop(random_state) + return create_pipeline(fitted_model, fitted_preprocessor) diff --git a/src/biofit/train_eval_utils.py b/src/biofit/train_eval_utils.py new file mode 100644 index 0000000..2b7aa24 --- /dev/null +++ b/src/biofit/train_eval_utils.py @@ -0,0 +1,1406 @@ +import copy +import os +import sys +from dataclasses import asdict +from typing import TYPE_CHECKING, Callable, Generator, List, Union + +import joblib +import numpy as np +import pandas as pd +import pyarrow as pa +import yaml +from biocore.data_handling import DataHandler +from biocore.utils.import_util import ( + is_catboost_available, + is_lightgbm_available, + is_plotly_available, + is_xgboost_available, +) +from biocore.utils.inspect import get_kwargs +from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor +from sklearn.model_selection import ( + GroupKFold, + KFold, + LeaveOneGroupOut, + StratifiedGroupKFold, + StratifiedKFold, + train_test_split, +) +from sklearn.pipeline import Pipeline +from sklearn.svm 
import SVC, SVR + +from biofit.auto.processing_auto import ProcessorPipeline +from biofit.models import ( + LightGBMForClassification, + LightGBMForRegression, + RandomForestForClassification, + RandomForestForRegression, +) +from biofit.models.lasso.lasso import LassoForClassification, LassoForRegression +from biofit.models.models import Model +from biofit.processing import BaseProcessor +from biofit.utils import logging +from biocore.utils.py_util import is_bioset, is_dataset +from biofit.visualization.plotting_utils import plot_feature_importance + +if TYPE_CHECKING: + import optuna + +logger = logging.get_logger(__name__) + + +CLASSIFICATION_TASKS = ["binary_classification", "multiclass_classification"] +REGRESSION_TASKS = ["regression", "multi_regression"] + +_MODELS = { + "binary_classification": { + "random_forest": RandomForestForClassification, + "lightgbm": LightGBMForClassification, + "svm": SVC, + "gradient_boosting": GradientBoostingClassifier, + "lasso": LassoForClassification, + }, + "multiclass_classification": { + "random_forest": RandomForestForClassification, + "lightgbm": LightGBMForClassification, + "svm": SVC, + "gradient_boosting": GradientBoostingClassifier, + "lasso": LassoForClassification, + }, + "regression": { + "random_forest": RandomForestForRegression, + "lightgbm": LightGBMForRegression, + "svm": SVR, + "gradient_boosting": GradientBoostingRegressor, + "lasso": LassoForRegression, + }, + "multi_regression": { + "random_forest": RandomForestForRegression, + "lightgbm": LightGBMForRegression, + "svm": SVR, + "gradient_boosting": GradientBoostingRegressor, + "lasso": LassoForRegression, + }, +} + +if is_xgboost_available(): + from xgboost import XGBClassifier, XGBRegressor + + _MODELS["binary_classification"]["xgboost"] = XGBClassifier + _MODELS["multiclass_classification"]["xgboost"] = XGBClassifier + _MODELS["regression"]["xgboost"] = XGBRegressor + _MODELS["multi_regression"]["xgboost"] = XGBRegressor + +if is_lightgbm_available(): + _MODELS["binary_classification"]["lightgbm"] = LightGBMForClassification + _MODELS["multiclass_classification"]["lightgbm"] = LightGBMForClassification + _MODELS["regression"]["lightgbm"] = LightGBMForRegression + _MODELS["multi_regression"]["lightgbm"] = LightGBMForRegression + +if is_catboost_available(): + from catboost import CatBoostClassifier, CatBoostRegressor + + _MODELS["binary_classification"]["catboost"] = CatBoostClassifier + _MODELS["multiclass_classification"]["catboost"] = CatBoostClassifier + _MODELS["regression"]["catboost"] = CatBoostRegressor + _MODELS["multi_regression"]["catboost"] = CatBoostRegressor + + +_CV = { + "binary_classification": { + "stratified_kfold": StratifiedKFold, + "kfold": KFold, + "group_kfold": GroupKFold, + "stratified_group_kfold": StratifiedGroupKFold, + "logo": LeaveOneGroupOut, + }, + "multiclass_classification": { + "stratified_kfold": StratifiedKFold, + "kfold": KFold, + "group_kfold": GroupKFold, + "stratified_group_kfold": StratifiedGroupKFold, + "logo": LeaveOneGroupOut, + }, + "regression": { + "stratified_kfold": StratifiedKFold, + "kfold": KFold, + }, + "multi_regression": { + "stratified_kfold": StratifiedKFold, + "kfold": KFold, + }, + "multilabel_classification": { + "stratified_kfold": StratifiedKFold, + "kfold": KFold, + }, +} + + +def _flatten_pipeline(p): + pipe = [] + if isinstance(p, list): + for step in p: + pipe.extend(_flatten_pipeline(step)) + elif isinstance(p, (Pipeline, ProcessorPipeline)): + for step in p.steps: + pipe.extend( + _flatten_pipeline(step[-1] 
if isinstance(step, tuple) else step) + ) + else: + pipe.append(p) + return pipe + + +def _get_name_and_params(obj): + obj_params = None + obj_name = None + if obj: + if isinstance(obj, (ProcessorPipeline, Pipeline, list)): + if isinstance(obj, (Pipeline, ProcessorPipeline)): + obj = [step[-1] if isinstance(step, tuple) else step for step in obj] + obj_params = [ + p.get_params() if isinstance(p, BaseProcessor) else p.__dict__ + for p in obj + ] + obj_name = [ + p.config.processor_name + if isinstance(p, BaseProcessor) + and getattr(p.config, "processor_name", None) + else p.__class__.__name__ + for p in obj + ] + else: + obj_params = obj.get_params() + obj_name = ( + obj.config.processor_name + if getattr(obj.config, "processor_name", None) + else obj.__class__.__name__ + ) + return obj_name, obj_params + + +def _get_processor_info( + models, + preprocessors=None, +): + if isinstance(models, list) and len(models) == 1: + models = models[0] + + preprocessor_name, preprocessor_params = None, None + new_preprocessors = preprocessors + if isinstance(models, Pipeline) and preprocessors is None: + if len(models.steps) > 1: + new_preprocessors = _flatten_pipeline( + [p[-1] if isinstance(p, tuple) else p for p in models.steps[:-1]] + ) + if isinstance(new_preprocessors, list): + if len(new_preprocessors) == 1: + new_preprocessors = new_preprocessors[0] + else: + new_preprocessors = ProcessorPipeline(new_preprocessors) + + new_models = ( + models.steps[-1][-1] + if isinstance(models.steps[-1], tuple) + else models.steps[-1] + ) + model_name, model_params = _get_name_and_params(new_models) + if new_preprocessors is not None: + preprocessor_name, preprocessor_params = _get_name_and_params( + new_preprocessors + ) + elif isinstance(models, list): + new_models, model_name, model_params = [], [], [] + new_preprocessors, preprocessor_name, preprocessor_params = [], [], [] + for i, m in enumerate(models): + _preprocessors = None + if isinstance(preprocessors, list): + _preprocessors = preprocessors[i] + if _preprocessors is None: + if isinstance(m, Pipeline): + if len(m.steps) > 1: + _preprocessors = _flatten_pipeline( + [p[-1] if isinstance(p, tuple) else p for p in m.steps[:-1]] + ) + m = ( + m.steps[-1][-1] + if isinstance(m.steps[-1], tuple) + else m.steps[-1] + ) + if isinstance(_preprocessors, list): + if len(_preprocessors) == 1: + _preprocessors = _preprocessors[0] + else: + _preprocessors = ProcessorPipeline(_preprocessors) + + _name, _params = _get_name_and_params(m) + new_models.append(m) + model_name.append(_name) + model_params.append(_params) + if _preprocessors: + _name, _params = _get_name_and_params(_preprocessors) + preprocessor_name.append(_name) + preprocessor_params.append(_params) + new_preprocessors.append(_preprocessors) + if len(preprocessor_name) == 0 or len(preprocessor_params) == 0: + preprocessor_name, preprocessor_params = None, None + else: + model_name, model_params = _get_name_and_params(models) + new_models = models + if preprocessors is not None: + new_preprocessors = preprocessors + preprocessor_name, preprocessor_params = _get_name_and_params(preprocessors) + + return ( + new_models, + model_name, + model_params, + new_preprocessors, + preprocessor_name, + preprocessor_params, + ) + + +def save_model_and_preprocessor( + output_dir, + models, + preprocessors=None, +): + if models is not None: + joblib.dump(models, os.path.join(output_dir, "model.joblib")) + + if preprocessors is not None: + if isinstance(preprocessors, list): + preprocessors = 
ProcessorPipeline(preprocessors) + joblib.dump(preprocessors, os.path.join(output_dir, "preprocessor.joblib")) + + +def save_feature_importances( + output_dir, + feat_importances, + preprocessor, + data, + target, + target_columns, + save_plot=True, + label_names=None, + **kwargs, +): + feat_importances = feat_importances.reset_index(names=["features"]) + if save_plot: + transformed_dataset = ( + preprocessor.fit_transform(data, load_from_cache_file=False) + if preprocessor + else data + ) + if kwargs.get("input_columns") is not None: + trans_columns = set( + DataHandler.get_column_names(transformed_dataset, generate_cols=True) + ) + kwargs["input_columns"] = [ + col for col in kwargs["input_columns"] if col in trans_columns + ] + + if is_bioset(data) or is_dataset(data): + y = DataHandler.select_columns(target, columns=target_columns) + else: + transformed_dataset = DataHandler.to_pandas(transformed_dataset) + y = DataHandler.to_pandas(target, columns=target_columns) + + plot_dir = os.path.join(output_dir, "plots") + params = dict( + y=y, + path=plot_dir, + target_columns=target_columns, + ) + params.update(kwargs) + + if feat_importances.iloc[:, 1:].sum().sum() == 0: + logger.warning( + "All feature importances are zero. Please check the model and data. " + "Skipping feature importance plot." + ) + else: + path = params.pop("path", None) + params["output_dir"] = path + plot_feature_importance( + X=transformed_dataset, + feature_importances=feat_importances, + label_names=label_names, + **params, + ) + + feature_metadata = None + if "feature_metadata" in kwargs: + feature_metadata = kwargs["feature_metadata"] + elif is_bioset(data): + from biosets import get_feature_metadata + + feature_metadata = get_feature_metadata(data) + + if feature_metadata is not None: + if isinstance(feature_metadata, dict): + feature_metadata = pd.DataFrame( + list(feature_metadata.values()), index=list(feature_metadata.keys()) + ) + feature_metadata = feature_metadata.reset_index(names=["features"]) + if feat_importances.shape[1] > 2: + feat_importances["median"] = feat_importances.iloc[:, 1:].median(axis=1) + feat_importances = feature_metadata.merge( + feat_importances, how="inner", on="features" + ) + feat_importances.to_csv( + os.path.join(output_dir, "feature_importances.csv"), index=False + ) + + +def save_study( + output_dir, + study, +): + from optuna.visualization import ( + plot_param_importances, + plot_slice, + ) + + joblib.dump(study, os.path.join(output_dir, "study.joblib")) + trials_df = study.trials_dataframe() + best_trial_col = [""] * len(trials_df) + best_trial_col[study.best_trial.number] = "Best Trial" + trials_df["Best Trial"] = best_trial_col + trials_df.to_csv(os.path.join(output_dir, "trials.csv"), index=False) + num_trials = len(study.trials) + + plot_dir = os.path.join(output_dir, "plots") + param_importances_path = os.path.join(plot_dir, "param_importances_plot.pdf") + slice_path = os.path.join(plot_dir, "slice_plot.pdf") + + os.makedirs(plot_dir, exist_ok=True) + + if is_plotly_available(): + try: + if num_trials is not None and num_trials > 1: + plot_param_importances(study).write_image( + param_importances_path, format="pdf", engine="kaleido" + ) + plot_slice(study).write_image(slice_path, format="pdf", engine="kaleido") + except (ValueError, RuntimeError) as e: + logger.warning(f"Failed to save optimization plots: {e}") + else: + logger.warning_once( + "Plotly is not installed. Optimization plots will not be saved." 
+ ) + + +def save_metrics( + output_dir, + metrics, +): + if isinstance(metrics, list) and len(metrics) == 1: + metrics = pd.DataFrame({k: [v] for k, v in metrics[0].items()}) + elif isinstance(metrics, dict): + metrics = pd.DataFrame({k: [v] for k, v in metrics.items()}) + metrics.to_csv(os.path.join(output_dir, "metrics.csv"), index=False) + + +def save_confusion_matrix(output_dir, confusion_matrices, label_names=None): + if not isinstance(confusion_matrices, list) or len(confusion_matrices) == 1: + if isinstance(confusion_matrices, list): + confusion_matrices = confusion_matrices[0] + confusion_matrices.to_csv(os.path.join(output_dir, "confusion_matrix.csv")) + else: + if label_names is None: + label_names = [f"label_{i}" for i in range(len(confusion_matrices))] + for i, cm in enumerate(confusion_matrices): + cm.to_csv( + os.path.join(output_dir, f"confusion_matrices_{label_names[i]}.csv") + ) + + +def save_predictions( + output_dir, + preds, + data, + target=None, + label_names=None, + target_columns=None, + index=None, +): + # save as csv + if index is None: + index = range(len(data)) + + if label_names is None: + if is_bioset(data) or is_dataset(data): + if target_columns is not None: + if isinstance(target_columns, list): + target_columns = target_columns[0] + if target_columns in data._info.features: + label_names = data._info.features[target_columns].names + if label_names is None and (is_bioset(target) or is_dataset(target)): + target_columns = target_columns or DataHandler.get_column_names(target)[0] + if target_columns is not None: + if isinstance(target_columns, list): + target_columns = target_columns[0] + if target_columns in target._info.features: + label_names = target._info.features[target_columns].names + + if is_bioset(data) or is_dataset(data): + from biosets.features import Sample + + names = [ + name + for name, feat in data._info.features.items() + if isinstance(feat, Sample) + ] + if len(names) == 1: + names = names[0] + sample_ids = DataHandler.to_numpy(data, names).flatten() + + if sample_ids is not None: + if index is not None and isinstance(index, (list, np.ndarray)): + sample_ids = np.array(sample_ids)[np.array(index)] + if isinstance(preds.index[0], int): + preds.index = [sample_ids[i] for i in preds.index] + else: + preds.index = sample_ids + preds.index.name = names + + if label_names is not None: + preds["predicted"] = DataHandler.argmax( + preds.iloc[:, : len(label_names)], axis=1 + ) + if target is not None: + preds["actual"] = DataHandler.to_numpy(target).flatten()[index].astype(int) + preds["actual"] = preds["actual"].map( + {i: name for i, name in enumerate(label_names)} + ) + col_names = label_names + ["predicted", "actual"] + + if is_bioset(target) or is_dataset(target): + from biosets.features import BinClassLabel + + feat = target._info.features[target_columns] + if isinstance(feat, BinClassLabel) and ( + feat.positive_labels is not None or feat.negative_labels is not None + ): + if feat.id in DataHandler.get_column_names(data): + preds["actual_original"] = DataHandler.to_pandas( + data, feat.id + ).values.flatten()[index] + col_names.append("actual_original") + else: + col_names = label_names + ["predicted"] + preds["predicted"] = preds["predicted"].map( + {i: name for i, name in enumerate(label_names)} + ) + preds.columns = col_names + else: + preds["predicted"] = DataHandler.argmax(preds, axis=1) + if target is not None: + preds["actual"] = ( + DataHandler.to_pandas(target).iloc[index, 0].values.astype(int) + ) + + if preds.index.name is not 
None: + preds.to_csv(os.path.join(output_dir, "predictions.csv"), index=True) + else: + preds.to_csv(os.path.join(output_dir, "predictions.csv"), index=False) + + +def _save_train_results( + data, + orig_target, + target, + output_dir, + feature_importances=None, + confusion_matrices=None, + study: "optuna.Study" = None, + num_trials=None, + time_limit=None, + models=None, + metrics=None, + preds=None, + input_columns: List[str] = None, + target_columns: List[str] = None, + label_names: List[str] = None, + cv: Union[KFold, StratifiedKFold, GroupKFold, StratifiedGroupKFold] = None, + group_name: str = None, + outer_cv: Union[KFold, StratifiedKFold, GroupKFold, StratifiedGroupKFold] = None, + outer_group_name: str = None, + task: str = None, + eval_metric: Union[str, Callable] = None, + random_state: Union[int, List[int]] = None, + use_suggested_hyperparameters: bool = True, + unused_columns: List[str] = None, + valid_split: str = None, + index: List[int] = None, + feature_importance_plot_params=None, + **kwargs, +): + orig_data = data + data, target, _, _, _, _, _, target_columns = _get_data( + data=data, + target=target, + valid_data=None, + valid_target=None, + groups=None, + group_name=None, + input_columns=input_columns, + target_columns=target_columns, + format=None, + target_required=True, + ) + + if isinstance(preds, list) and len(preds) == 1: + preds = preds[0] + + os.makedirs(output_dir, exist_ok=True) + if study: + save_study(output_dir, study) + + preprocessors = None + if models is not None: + ( + models, + model_name, + model_params, + preprocessors, + preprocessor_name, + preprocessor_params, + ) = _get_processor_info(models) + save_model_and_preprocessor(output_dir, models, preprocessors) + + if metrics is not None: + save_metrics(output_dir, metrics) + + if preds is not None: + save_predictions( + output_dir, + preds, + orig_data, + target if orig_target is None else orig_target, + label_names, + target_columns, + index, + ) + + if confusion_matrices is not None: + save_confusion_matrix(output_dir, confusion_matrices) + + if feature_importances is not None: + feature_importance_plot_params = feature_importance_plot_params or {} + save_feature_importances( + output_dir, + feature_importances, + preprocessors, + orig_data, + target, + target_columns, + input_columns=input_columns, + label_names=label_names, + **feature_importance_plot_params, + ) + + params = { + "preprocessing": [], + "optimization": {}, + "training": {}, + "cross_validation": {}, + "model": {}, + "misc": {}, + } + + # General Information + if task is not None: + params["training"]["task"] = task + + if label_names is not None: + if not isinstance(target_columns, list): + tc = [target_columns] + else: + tc = target_columns + + if not isinstance(label_names, (np.ndarray, list)) or not isinstance( + label_names[0], (np.ndarray, list) + ): + label_names = [label_names] + + if "classification" in task: + target_type = "labels" + else: + target_type = "targets" + + params["training"][target_type] = {} + for i, lab_name in enumerate(tc): + if lab_name is not None: + class_feat = None + if ( + is_bioset(target) or is_dataset(target) + ) and lab_name in target._info.features: + class_feat = target._info.features[lab_name] + elif is_bioset(data) or is_dataset(data): + class_feat = data._info.features[lab_name] + if class_feat is not None: + if "biosets" in sys.modules and isinstance( + class_feat, sys.modules["biosets"].BinClassLabel + ): + _names = class_feat.names + positive_labels = class_feat.positive_labels + 
negative_labels = class_feat.negative_labels + if positive_labels is not None or negative_labels is not None: + params["training"][target_type][lab_name] = { + _names[0]: negative_labels, + _names[1]: positive_labels, + } + else: + params["training"][target_type][lab_name] = _names + else: + lab_name = lab_name or f"{i + 1}" + params["training"][target_type][lab_name] = class_feat.names + else: + lab_name = lab_name or f"{i + 1}" + params["training"][target_type][lab_name] = DataHandler.to_list( + label_names[i] + ) + + if len(tc) == 1: + if tc[0] is None or tc[0] not in params["training"][target_type]: + p = params["training"][target_type] + if len(p): + params["training"][target_type] = params["training"][target_type][ + next(iter(p)) + ] + else: + params["training"][target_type] = params["training"][target_type][tc[0]] + + if target_columns is not None: + params["training"]["target_columns"] = target_columns + if eval_metric is not None: + params["training"]["eval_metric"] = ( + eval_metric.__name__ if callable(eval_metric) else eval_metric + ) + if random_state is not None: + params["training"]["random_state"] = random_state + + if valid_split is not None: + params["training"]["valid_split"] = valid_split + + # Hyperparameter Tuning Information + if num_trials is not None: + params["optimization"]["num_trials"] = num_trials + if time_limit is not None: + params["optimization"]["time_limit"] = time_limit + # list tuned hyperparameters and their ranges + if study is not None: + search_space = study.get_trials()[0].distributions + params["optimization"]["search_space"] = {} + for param_name, dist in search_space.items(): + dist_pars = dist.__dict__ + dist_pars.pop("step", None) + params["optimization"]["search_space"][param_name] = { + "type": (dist.__class__.__name__), + **dist_pars, + } + + # Cross-Validation Information + if cv is not None: + params["cross_validation"]["cv"] = { + "type": cv.__class__.__name__, + "params": cv.__dict__, + } + + if outer_cv is not None: + params["cross_validation"]["outer_cv"] = { + "type": outer_cv.__class__.__name__, + "params": outer_cv.__dict__, + } + + # Model Information + if model_name is not None: + params["model"]["name"] = model_name + params["model"]["params"] = model_params if model_params is not None else {} + + # Preprocessor Information + if preprocessor_name is not None and preprocessor_params is not None: + if isinstance(preprocessor_name, list): + params["preprocessing"] = [ + {"name": name, "params": params} + for name, params in zip(preprocessor_name, preprocessor_params) + ] + else: + params["preprocessing"] = { + "name": preprocessor_name, + "params": preprocessor_params, + } + + if unused_columns is not None: + params["misc"]["output_dir"] = output_dir + + if unused_columns is not None: + params["misc"]["unused_columns"] = unused_columns + + if outer_group_name is not None: + params["misc"]["outer_group_name"] = outer_group_name + + if group_name is not None: + params["misc"]["group_name"] = group_name + + with open(os.path.join(output_dir, "training_params.yaml"), "w") as f: + # uncomment for debugging + # print(yaml.safe_dump(params, sort_keys=False)) + yaml.safe_dump(params, f, sort_keys=False) + + +def _init_cv( + cv, groups=None, cv_params=None, sub_task="classification", shuffle=True, seed=42 +): + n_splits = 5 + if isinstance(cv, int): + n_splits = cv + if groups is not None: + cv = LeaveOneGroupOut + else: + cv = next(iter(_CV[sub_task].values())) + + if isinstance(cv, str): + cv = _CV[sub_task][cv] + + if isinstance(cv, 
type): + if cv_params is None: + if cv.__module__.split(".")[0] in ["biofit", "sklearn"]: + cv_params = {} + cv_params["shuffle"] = shuffle + cv_params["random_state"] = seed if shuffle else None + cv_params["n_splits"] = n_splits + else: + raise ValueError( + "Please provide cv_params for custom cross validation." + ) + cv_params = get_kwargs(cv_params, cv) + cv = cv(**cv_params) + return cv + + +def get_task(models): + if isinstance(models, list) and len(models): + model = models[0] + else: + model = models + if isinstance(model, (ProcessorPipeline, Pipeline)): + for m in model.steps: + val = m + if isinstance(val, tuple): + val = val[-1] + if hasattr(val, "config") and hasattr(val.config, "class_names"): + return val.config.class_names + + elif hasattr(model, "config") and hasattr(model.config, "class_names"): + return model.config.class_names + return None + + +def set_class_names(models, class_names): + if class_names is None: + return models + if isinstance(models, list) and len(models): + return [set_class_names(m, class_names) for m in models] + if isinstance(models, (ProcessorPipeline, Pipeline)): + for val in models.steps: + if isinstance(val, tuple) and hasattr(val[-1], "config"): + val[-1].config.class_names = class_names + if hasattr(val, "config"): + val.config.class_names = class_names + elif hasattr(models, "config"): + models.config.class_names = class_names + return models + + +def infer_task(data=None, target=None, target_columns=None, task=None): + sub_task = None + if target is None and data is not None: + _, target, _, _, _, _, _, target_columns = _get_data( + data=data, + target=target, + valid_data=None, + valid_target=None, + groups=None, + group_name=None, + input_columns=None, + target_columns=target_columns, + format=None, + target_required=False, + ) + if target is None and task is None: + raise ValueError("Target is required to infer task.") + + target_dims = DataHandler.get_shape(target) if target is not None else [] + class_names = None + if task is None and (is_bioset(target) or is_dataset(target)): + from biosets.features import BinClassLabel, ClassLabel, RegressionTarget + + if isinstance(target_columns, list): + target_columns = target_columns[0] + if target_columns and target_columns in target._info.features: + feat = target._info.features[target_columns] + else: + feat = [ + feat + for feat in target._info.features.values() + if isinstance(feat, (BinClassLabel, ClassLabel, RegressionTarget)) + ][0] + if isinstance(feat, (BinClassLabel, ClassLabel)): + class_names = target._info.features[target_columns].names + task = "classification" + elif isinstance(feat, RegressionTarget): + task = "regression" + if task == "classification": + if len(target_dims) == 1 or target_dims[1] == 1: + n_classes = None + if target_columns and (is_bioset(target) or is_dataset(target)): + if isinstance(target_columns, list): + target_columns = target_columns[0] + class_names = target._info.features[target_columns].names + n_classes = len(class_names) + else: + n_classes = DataHandler.nunique(target) + + if n_classes > 2: + sub_task = "multiclass_classification" + else: + sub_task = "binary_classification" + else: + sub_task = "multilabel_classification" + elif task == "regression": + if len(target_dims) > 1 and target_dims[1] > 1: + sub_task = "multiregression" + else: + sub_task = "regression" + elif task in CLASSIFICATION_TASKS or task in REGRESSION_TASKS: + sub_task = task + else: + raise ValueError( + "Invalid task. 
Please specify either a task, such as 'classification' or " + f"'regression', or a sub task, such as {CLASSIFICATION_TASKS + REGRESSION_TASKS}." + ) + return sub_task, class_names + + +def _iter_preprocessor(preprocessor, condition=None): + if isinstance(preprocessor, (Pipeline, ProcessorPipeline)): + start = False + for n, p in preprocessor.steps: + if isinstance(p, BaseProcessor): + if condition == "sample independent": + if p.has_fit: + yield p.config.processor_name, p + else: + break + elif condition == "sample dependent": + if not p.has_fit: + start = True + yield p + elif start: + yield p + else: + yield p + else: + yield p + elif isinstance(preprocessor, list): + for p in preprocessor: + yield p + else: + yield preprocessor + + +def preprocess( + preprocessor, + x_train, + y_train=None, + x_valid=None, + y_valid=None, + cache_dir=None, + transform_only=False, + raise_error=False, +): + try: + if isinstance(preprocessor, (Pipeline, ProcessorPipeline)): + for proc in preprocessor.steps: + p = proc[-1] if isinstance(proc, tuple) else proc + if isinstance(p, BaseProcessor): + extra_kwargs = { + "cache_dir": cache_dir, + "load_from_cache_file": False, + } + else: + extra_kwargs = {} + if transform_only: + x_train = p.transform(x_train, **extra_kwargs) + else: + x_train = p.fit_transform(x_train, **extra_kwargs) + if x_valid is not None: + x_valid = p.transform(x_valid, **extra_kwargs) + else: + if isinstance(preprocessor, BaseProcessor): + extra_kwargs = {"cache_dir": cache_dir, "load_from_cache_file": False} + else: + extra_kwargs = {} + if transform_only: + x_train = preprocessor.transform(x_train, **extra_kwargs) + else: + x_train = preprocessor.fit_transform(x_train, **extra_kwargs) + if x_valid is not None: + x_valid = preprocessor.transform(x_valid, **extra_kwargs) + except ValueError: + if raise_error: + raise + else: + logger.info("Preprocessing failed, using nan_to_num") + _train_cols = x_train.columns.tolist() + x_train = np.nan_to_num(x_train) + # convert back to dataframe + x_train = pd.DataFrame(x_train, columns=_train_cols) + if x_valid is not None: + _valid_cols = x_valid.columns.tolist() + x_valid = np.nan_to_num(x_valid) + x_valid = pd.DataFrame(x_valid, columns=_valid_cols) + return preprocess( + preprocessor, + x_train, + y_train, + x_valid, + y_valid, + cache_dir, + transform_only, + True, + ) + + return x_train, y_train, x_valid, y_valid + + +def split(cv, X, y=None, groups=None, indices=None): + if cv is None: + return [(None, None)] + # check if cv is a generator + if isinstance(cv, Generator): + return cv + if indices is not None: + return cv.split( + DataHandler.select_rows(X, indices), + y=DataHandler.select_column(DataHandler.select_rows(y, indices), 0) + if y is not None + else None, + groups=DataHandler.select_column( + DataHandler.select_rows(groups, indices), 0 + ) + if groups is not None + else None, + ) + return cv.split( + X, + y=DataHandler.select_column(y, 0) if y is not None else None, + groups=DataHandler.select_column(groups, 0) if groups is not None else None, + ) + + +def get_model(models_or_name, task=None): + if isinstance(models_or_name, list) and len(models_or_name): + mon = models_or_name[0] + else: + mon = models_or_name + + if mon is None: + return None + + model = None + model_cls = None + if isinstance(mon, str): + model_name = mon + model_cls = _MODELS[task][model_name] + elif isinstance(mon, type): + model_cls = mon + model_name = mon.__module__.split(".")[-1] + elif isinstance(mon, Model): + model_name = mon.config.processor_name + 
model_cls = model.__class__ + model = mon + elif isinstance(mon, (ProcessorPipeline, Pipeline)): + for m in _flatten_pipeline(mon): + out = get_model(m, task) + if out is not None: + return out + return None + elif hasattr(mon, "predict"): + model_cls = mon.__class__ + model_name = mon.__class__.__module__.split(".")[-1] + model = mon + else: + return None + return model, model_cls, model_name + + +def get_model_info(model, task=None): + out = get_model(model, task=None) + if out is not None: + model, _, _ = out + if hasattr(model, "config"): + return asdict(model.config) + elif hasattr(model, "get_params"): + return model.get_params() + else: + return model.__dict__ + + return {} + + +def _split_data( + data, target, valid_split, seed, shuffle, stratify_by_column, keep_in_memory=True +): + def _convert_to_in_memory(data, indices=None): + from datasets import InMemoryTable, MemoryMappedTable + + try: + if data._indices or indices is not None: + if indices is None: + indices = data._indices.column("indices") + + if isinstance(data._data, MemoryMappedTable): + table = MemoryMappedTable._apply_replays( + data._data.table, data._data.replays + ) + else: + table = data._data.table + table = table.take(indices) + data._indices = None + data._data = InMemoryTable(table) + + return data + except pa.ArrowInvalid: + logger.error( + "Table is too large for row selection. Please set keep_in_memory=False or " + "shuffle=False until https://github.com/apache/arrow/issues/25822 has been resolved" + ) + raise + + if is_bioset(data) or is_dataset(data): + if keep_in_memory: + _convert_to_in_memory(data) + + train_data, valid_data = data.train_test_split( + test_size=valid_split, + seed=seed if shuffle else None, + shuffle=shuffle, + stratify_by_column=stratify_by_column if shuffle else None, + ).values() + if not shuffle and not stratify_by_column: + target, valid_target = target.train_test_split( + test_size=valid_split, + seed=seed, + shuffle=shuffle, + stratify_by_column=stratify_by_column, + ) + else: + train_indices = train_data._indices.column("indices") + if valid_data._indices: + valid_indices = valid_data._indices.column("indices") + else: + valid_indices = pa.array( + list( + sorted( + set(range(train_data.num_rows)) + - set(DataHandler.to_numpy(train_indices).tolist()) + ) + ) + ) + + if keep_in_memory: + train_data = _convert_to_in_memory(train_data, train_indices) + valid_data = _convert_to_in_memory(valid_data, valid_indices) + target = _convert_to_in_memory(target, train_indices) + valid_target = _convert_to_in_memory(valid_target, valid_indices) + else: + valid_target = copy.deepcopy(target) + valid_target = valid_target.select(valid_indices) + target = target.select(train_indices) + else: + if stratify_by_column is not None: + stratify = DataHandler.to_numpy(data, stratify_by_column) + else: + stratify = target + train_data, valid_data, target, valid_target = train_test_split( + data, + target, + test_size=valid_split, + random_state=seed, + shuffle=shuffle, + stratify=stratify, + ) + return train_data, valid_data, target, valid_target + + +def _get_label_names(data, target, target_columns, u_labels=None): + if isinstance(target_columns, list): + tc = target_columns[0] + else: + tc = target_columns + if (is_bioset(data) or is_dataset(data)) and tc in data._info.features: + labels = data._info.features[tc].names + elif (is_bioset(target) or is_dataset(target)) and tc in target._info.features: + labels = target._info.features[tc].names + else: + labels = DataHandler.to_list(u_labels) + 
return labels + + +def _get_data( + data, + target=None, + valid_data=None, + valid_target=None, + valid_split=None, + groups=None, + outer_groups=None, + group_name=None, + outer_group_name=None, + input_columns="auto", + target_columns="auto", + valid_input_columns="auto", + valid_target_columns="auto", + target_required=True, + shuffle=True, + seed=42, + keep_in_memory=True, + format="pandas", +): + if DataHandler.supports_named_columns(data): + use_auto_target_columns = target_columns == "auto" + else: + input_columns = None if input_columns == "auto" else input_columns + target_columns = None if target_columns == "auto" else target_columns + + if valid_data is not None and not DataHandler.supports_named_columns(valid_data): + valid_input_columns = ( + None if valid_input_columns == "auto" else valid_input_columns + ) + valid_target_columns = ( + None if valid_target_columns == "auto" else valid_target_columns + ) + + def get_target_if_none(data, target_columns): + target = None + if target_columns is not None: + sel_kwargs = {} + if is_bioset(data) or is_dataset(data): + sel_kwargs = {"keep_old_fingerprint": False} + if isinstance(target_columns, str): + target = DataHandler.select_columns( + data, [target_columns], **sel_kwargs + ) + else: + target = DataHandler.select_columns(data, target_columns, **sel_kwargs) + return target + + def handle_target_columns(data, target_columns, required=True): + if target_columns == "auto": + if is_bioset(data): + from biosets import get_target_col_names + + target_columns = get_target_col_names(data) + elif target_required: + ValueError( + "`data` must be a `Dataset` to automatically infer target columns. " + "Please provide `target_columns` or `target`." + ) + if isinstance(target_columns, (type, tuple)): + if is_bioset(data) or is_dataset(data): + target_columns = [ + k + for k, v in data._info.features.items() + if isinstance(v, target_columns) + ] + if target_columns == "auto": + target_columns = None + if target_columns is None and target_required: + raise ValueError("Target columns must be provided if target is None.") + + if isinstance(target_columns, list) and len(target_columns) == 1: + target_columns = target_columns[0] + return target_columns + + def handle_input_columns(data, input_columns): + if is_bioset(data) or is_dataset(data): + if input_columns == "auto" and is_bioset(data): + from biosets import get_data_col_names + + input_columns = get_data_col_names(data) + elif isinstance(input_columns, (type, tuple)): + input_columns = [ + k + for k, v in data._info.features.items() + if isinstance(v, input_columns) + ] + if input_columns == "auto": + input_columns = None + + return input_columns + + x_train = None + y_train = None + y_valid = None + x_valid = None + + if isinstance(data, dict): + if isinstance(valid_split, str): + valid_data = data.get(valid_split) + else: + valid_data = data.get( + "valid", data.get("test"), data.get("validation"), None + ) + if len(data) < 3: + train_split = [ + k + for k in data.keys() + if k not in ["valid", "test", "validation", valid_split] + ] + if not train_split: + raise ValueError("Train split not found in dataset.") + data = data.get(train_split) + else: + data = data.get("train", data.get("training")) + + if target is None: + target_columns = handle_target_columns( + data, target_columns, required=target_required + ) + y_train = get_target_if_none(data, target_columns) + + else: + y_train = target + if DataHandler.supports_named_columns(target): + target_columns = 
DataHandler.get_column_names(target)[0] + + if valid_target is None and valid_data is not None: + if valid_target_columns is None or valid_target_columns == "auto": + valid_target_columns = handle_target_columns( + valid_data, + "auto" if use_auto_target_columns else target_columns, + required=False, + ) + if valid_target_columns is not None: + y_valid = get_target_if_none(valid_data, valid_target_columns) + else: + y_valid = get_target_if_none(valid_data, target_columns) + if valid_target_columns is not None and target_columns is not None: + if valid_target_columns != target_columns: + valid_target = DataHandler.set_column_names( + valid_target, target_columns + ) + elif valid_target_columns is not None: + target_columns = valid_target_columns + + elif valid_target is not None: + y_valid = valid_target + + if group_name is not None and groups is None: + groups = DataHandler.select_column(data, group_name) + if outer_group_name is not None and outer_groups is None: + outer_groups = DataHandler.select_column(data, outer_group_name) + if ( + group_name is None + and groups is not None + and DataHandler.supports_named_columns(groups) + ): + group_name = DataHandler.get_column_names(groups)[0] + if ( + outer_group_name is None + and outer_groups is not None + and DataHandler.supports_named_columns(outer_groups) + ): + outer_group_name = DataHandler.get_column_names(outer_groups)[0] + + if valid_data is None and valid_split is not None: + data, valid_data, y_train, y_valid = _split_data( + data, + y_train, + valid_split, + seed=seed, + shuffle=shuffle, + stratify_by_column=group_name, + keep_in_memory=keep_in_memory, + ) + + input_columns = handle_input_columns(data, input_columns) + if valid_data is not None: + valid_input_columns = handle_input_columns(valid_data, valid_input_columns) + if input_columns is not None: + sel_kwargs = {} + if is_bioset(data) or is_dataset(data): + sel_kwargs = {"keep_old_fingerprint": False} + x_train = DataHandler.select_columns(data, input_columns, **sel_kwargs) + if valid_data is not None: + sel_kwargs = {} + if is_bioset(valid_data) or is_dataset(valid_data): + sel_kwargs = {"keep_old_fingerprint": False} + valid_input_columns = valid_input_columns or input_columns + x_valid = DataHandler.select_columns( + valid_data, valid_input_columns, **sel_kwargs + ) + elif ( + target is not None + and DataHandler.supports_named_columns(data) + and DataHandler.supports_named_columns(target) + ): + _input_columns = DataHandler.get_column_names(data) + _target_columns = DataHandler.get_column_names(target) + if group_name is not None: + _target_columns.append(group_name) + if outer_group_name is not None: + _target_columns.append(outer_group_name) + _intersecting_cols = set(_input_columns) & set(_target_columns) + if _intersecting_cols: + x_train = DataHandler.drop_columns(data, list(_intersecting_cols)) + if ( + valid_data is not None + and DataHandler.supports_named_columns(valid_data) + and DataHandler.supports_named_columns(x_train) + ): + sel_kwargs = {} + if is_bioset(valid_data) or is_dataset(valid_data): + sel_kwargs = {"keep_old_fingerprint": False} + x_valid = DataHandler.select_columns( + valid_data, DataHandler.get_column_names(x_train), **sel_kwargs + ) + else: + x_train = data + if valid_data is not None: + x_valid = valid_data + input_columns = DataHandler.get_column_names(x_train) + else: + x_train = data + if valid_data is not None: + x_valid = valid_data + input_columns = DataHandler.get_column_names(x_train) + + if format is not None: + x_train = 
DataHandler.to_format(x_train, format) + if y_train is not None: + y_train = DataHandler.to_format(y_train, format) + if groups is not None: + groups = DataHandler.to_format(groups, format) + if outer_groups is not None: + outer_groups = DataHandler.to_format(outer_groups, format) + if x_valid is not None: + x_valid = DataHandler.to_format(x_valid, format) + if y_valid is not None: + y_valid = DataHandler.to_format(y_valid, format) + + return ( + x_train, + y_train, + x_valid, + y_valid, + groups, + outer_groups, + input_columns, + target_columns, + ) diff --git a/src/biofit/utils/__init__.py b/src/biofit/utils/__init__.py new file mode 100644 index 0000000..89b7a6e --- /dev/null +++ b/src/biofit/utils/__init__.py @@ -0,0 +1,60 @@ +# ruff: noqa +from .file_utils import ( + add_end_docstrings, + add_start_docstrings, + estimate_dataset_size, + has_ext, + has_separator, + hash_url_to_filename, + is_file_name, + is_local_path, + is_relative_path, + is_remote_url, + move_temp_file, + url_or_path_join, + url_or_path_parent, +) +from .fingerprint import ( + Hasher, + _build_cache_dir, + disable_caching, + enable_caching, + fingerprint_from_data, + fingerprint_from_kwargs, + generate_cache_dir, + get_cache_file_name, + is_caching_enabled, + update_fingerprint, +) +from .logging import ( + disable_progress_bar, + enable_progress_bar, + set_verbosity, + set_verbosity_debug, + set_verbosity_error, + set_verbosity_info, + set_verbosity_warning, + silence, + unsilence, +) +from .py_util import ( + as_py, + enable_full_determinism, + set_seed, +) +from .recorder import ( + load_module_or_class, + record_step, +) +from .table_util import ( + concat_blocks, + determine_upcast, + init_arrow_buffer_and_writer, + is_binary_like, + is_fixed_width, + is_large_binary_like, + read_arrow_table, + upcast_tables, + write_arrow_table, +) +from .types import Unset diff --git a/src/biofit/utils/_dill.py b/src/biofit/utils/_dill.py new file mode 100644 index 0000000..8ef2698 --- /dev/null +++ b/src/biofit/utils/_dill.py @@ -0,0 +1,469 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Modified by: Patrick Smyth +# Date: 2024 +# Summary of changes: +# - removed `config` to track the dill version. Version is parsed directly from the dill +# package. 
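+#
+# Illustrative usage (an assumption added for documentation, not part of the
+# original module): the `Pickler`/`dumps` defined below are intended for
+# fingerprinting rather than round-tripping, so structurally equal objects
+# should serialize to identical bytes regardless of insertion order:
+#
+#     from biofit.utils._dill import dumps
+#
+#     assert dumps({"b": 1, "a": 2}) == dumps({"a": 2, "b": 1})  # dict key order ignored
+#     assert dumps({3, 1, 2}) == dumps({1, 2, 3})                # set order ignored
+#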
+"""Extends `dill` to support pickling more types and produce more consistent dumps.""" + +import os +import sys +from io import BytesIO +from types import CodeType, FunctionType + +import dill +from packaging import version + + +class Pickler(dill.Pickler): + dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy()) + _legacy_no_dict_keys_sorting = False + + def save(self, obj, save_persistent_id=True): + obj_type = type(obj) + if obj_type not in self.dispatch: + if "regex" in sys.modules: + import regex # type: ignore + + if obj_type is regex.Pattern: + pklregister(obj_type)(_save_regexPattern) + if "spacy" in sys.modules: + import spacy # type: ignore + + if issubclass(obj_type, spacy.Language): + pklregister(obj_type)(_save_spacyLanguage) + if "tiktoken" in sys.modules: + import tiktoken # type: ignore + + if obj_type is tiktoken.Encoding: + pklregister(obj_type)(_save_tiktokenEncoding) + if "torch" in sys.modules: + import torch # type: ignore + + if issubclass(obj_type, torch.Tensor): + pklregister(obj_type)(_save_torchTensor) + + if obj_type is torch.Generator: + pklregister(obj_type)(_save_torchGenerator) + + # Unwrap `torch.compile`-ed modules + if issubclass(obj_type, torch.nn.Module): + obj = getattr(obj, "_orig_mod", obj) + if "transformers" in sys.modules: + import transformers # type: ignore + + if issubclass(obj_type, transformers.PreTrainedTokenizerBase): + pklregister(obj_type)(_save_transformersPreTrainedTokenizerBase) + + # Unwrap `torch.compile`-ed functions + if obj_type is FunctionType: + obj = getattr(obj, "_torchdynamo_orig_callable", obj) + dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id) + + def _batch_setitems(self, items): + if self._legacy_no_dict_keys_sorting: + return super()._batch_setitems(items) + # Ignore the order of keys in a dict + try: + # Faster, but fails for unorderable elements + items = sorted(items) + except Exception: # TypeError, decimal.InvalidOperation, etc. + from datasets.fingerprint import Hasher + + items = sorted(items, key=lambda x: Hasher.hash(x[0])) + dill.Pickler._batch_setitems(self, items) + + def memoize(self, obj): + # Don't memoize strings since two identical strings can have different Python ids + if type(obj) is not str: # noqa: E721 + dill.Pickler.memoize(self, obj) + + +def pklregister(t): + """Register a custom reducer for the type.""" + + def proxy(func): + Pickler.dispatch[t] = func + return func + + return proxy + + +def dump(obj, file): + """Pickle an object to a file.""" + Pickler(file, recurse=True).dump(obj) + + +def dumps(obj): + """Pickle an object to a string.""" + file = BytesIO() + dump(obj, file) + return file.getvalue() + + +if version.parse(dill.__version__) < version.parse("0.3.6"): + + def log(pickler, msg): + dill._dill.log.info(msg) + +elif version.parse(dill.__version__) in [ + version.parse("0.3.6").release, + version.parse("0.3.7").release, + version.parse("0.3.8").release, +]: + + def log(pickler, msg): + dill._dill.logger.trace(pickler, msg) + + +@pklregister(set) +def _save_set(pickler, obj): + log(pickler, f"Se: {obj}") + try: + # Faster, but fails for unorderable elements + args = (sorted(obj),) + except Exception: # TypeError, decimal.InvalidOperation, etc. 
+ from datasets.fingerprint import Hasher + + args = (sorted(obj, key=Hasher.hash),) + + pickler.save_reduce(set, args, obj=obj) + log(pickler, "# Se") + + +def _save_regexPattern(pickler, obj): + import regex # type: ignore + + log(pickler, f"Re: {obj}") + args = (obj.pattern, obj.flags) + pickler.save_reduce(regex.compile, args, obj=obj) + log(pickler, "# Re") + + +def _save_tiktokenEncoding(pickler, obj): + import tiktoken # type: ignore + + log(pickler, f"Enc: {obj}") + args = (obj.name, obj._pat_str, obj._mergeable_ranks, obj._special_tokens) + pickler.save_reduce(tiktoken.Encoding, args, obj=obj) + log(pickler, "# Enc") + + +def _save_torchTensor(pickler, obj): + import torch # type: ignore + + # `torch.from_numpy` is not picklable in `torch>=1.11.0` + def create_torchTensor(np_array, dtype=None): + tensor = torch.from_numpy(np_array) + if dtype: + tensor = tensor.type(dtype) + return tensor + + log(pickler, f"To: {obj}") + if obj.dtype == torch.bfloat16: + args = (obj.detach().to(torch.float).cpu().numpy(), torch.bfloat16) + else: + args = (obj.detach().cpu().numpy(),) + pickler.save_reduce(create_torchTensor, args, obj=obj) + log(pickler, "# To") + + +def _save_torchGenerator(pickler, obj): + import torch # type: ignore + + def create_torchGenerator(state): + generator = torch.Generator() + generator.set_state(state) + return generator + + log(pickler, f"Ge: {obj}") + args = (obj.get_state(),) + pickler.save_reduce(create_torchGenerator, args, obj=obj) + log(pickler, "# Ge") + + +def _save_spacyLanguage(pickler, obj): + import spacy # type: ignore + + def create_spacyLanguage(config, bytes): + lang_cls = spacy.util.get_lang_class(config["nlp"]["lang"]) + lang_inst = lang_cls.from_config(config) + return lang_inst.from_bytes(bytes) + + log(pickler, f"Sp: {obj}") + args = (obj.config, obj.to_bytes()) + pickler.save_reduce(create_spacyLanguage, args, obj=obj) + log(pickler, "# Sp") + + +def _save_transformersPreTrainedTokenizerBase(pickler, obj): + log(pickler, f"Tok: {obj}") + # Ignore the `cache` attribute + state = obj.__dict__ + if "cache" in state and isinstance(state["cache"], dict): + state["cache"] = {} + pickler.save_reduce(type(obj), (), state=state, obj=obj) + log(pickler, "# Tok") + + +if version.parse(dill.__version__) < version.parse("0.3.6"): + + @pklregister(CodeType) + def _save_code(pickler, obj): + """ + From dill._dill.save_code + This is a modified version that removes the origin (filename + line no.) + of functions created in notebooks or shells for example. + """ + dill._dill.log.info(f"Co: {obj}") + # The filename of a function is the .py file where it is defined. + # Filenames of functions created in notebooks or shells start with '<' + # ex: for ipython, and for shell + # Filenames of functions created in ipykernel the filename + # look like f"{tempdir}/ipykernel_{id1}/{id2}.py" + # Moreover lambda functions have a special name: '' + # ex: (lambda x: x).__code__.co_name == "" # True + # + # For the hashing mechanism we ignore where the function has been defined + # More specifically: + # - we ignore the filename of special functions (filename starts with '<') + # - we always ignore the line number + # - we only use the base name of the file instead of the whole path, + # to be robust in case a script is moved for example. 
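+        # Illustrative examples of the rules above:
+        #   "<ipython-input-3-1a2b3c>"       -> ""           (interactive/special filename)
+        #   "/tmp/ipykernel_1234/567.py"     -> ""           (ipykernel temp file)
+        #   "/home/user/project/module.py"   -> "module.py"  (only the base name is kept)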
+ # + # Only those two lines are different from the original implementation: + co_filename = ( + "" + if obj.co_filename.startswith("<") + or ( + len(obj.co_filename.split(os.path.sep)) > 1 + and obj.co_filename.split(os.path.sep)[-2].startswith("ipykernel_") + ) + or obj.co_name == "" + else os.path.basename(obj.co_filename) + ) + co_firstlineno = 1 + # The rest is the same as in the original dill implementation + if dill._dill.PY3: + if hasattr(obj, "co_posonlyargcount"): + args = ( + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, + obj.co_name, + co_firstlineno, + obj.co_lnotab, + obj.co_freevars, + obj.co_cellvars, + ) + else: + args = ( + obj.co_argcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, + obj.co_name, + co_firstlineno, + obj.co_lnotab, + obj.co_freevars, + obj.co_cellvars, + ) + else: + args = ( + obj.co_argcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, + obj.co_name, + co_firstlineno, + obj.co_lnotab, + obj.co_freevars, + obj.co_cellvars, + ) + pickler.save_reduce(CodeType, args, obj=obj) + dill._dill.log.info("# Co") + return + +elif version.parse(dill.__version__).release[:3] in [ + version.parse("0.3.6").release, + version.parse("0.3.7").release, + version.parse("0.3.8").release, +]: + # From: https://github.com/uqfoundation/dill/blob/dill-0.3.6/dill/_dill.py#L1104 + @pklregister(CodeType) + def save_code(pickler, obj): + dill._dill.logger.trace(pickler, "Co: %s", obj) + + ############################################################################################################ + # Modification here for huggingface/datasets + # The filename of a function is the .py file where it is defined. + # Filenames of functions created in notebooks or shells start with '<' + # ex: for ipython, and for shell + # Filenames of functions created in ipykernel the filename + # look like f"{tempdir}/ipykernel_{id1}/{id2}.py" + # Moreover lambda functions have a special name: '' + # ex: (lambda x: x).__code__.co_name == "" # True + # + # For the hashing mechanism we ignore where the function has been defined + # More specifically: + # - we ignore the filename of special functions (filename starts with '<') + # - we always ignore the line number + # - we only use the base name of the file instead of the whole path, + # to be robust in case a script is moved for example. 
+ # + # Only those two lines are different from the original implementation: + co_filename = ( + "" + if obj.co_filename.startswith("<") + or ( + len(obj.co_filename.split(os.path.sep)) > 1 + and obj.co_filename.split(os.path.sep)[-2].startswith("ipykernel_") + ) + or obj.co_name == "" + else os.path.basename(obj.co_filename) + ) + co_firstlineno = 1 + # The rest is the same as in the original dill implementation, except for the replacements: + # - obj.co_filename => co_filename + # - obj.co_firstlineno => co_firstlineno + ############################################################################################################ + + if hasattr(obj, "co_endlinetable"): # python 3.11a (20 args) + args = ( + obj.co_lnotab, # for < python 3.10 [not counted in args] + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, # Modification for huggingface/datasets ############################################ + obj.co_name, + obj.co_qualname, + co_firstlineno, # Modification for huggingface/datasets ######################################### + obj.co_linetable, + obj.co_endlinetable, + obj.co_columntable, + obj.co_exceptiontable, + obj.co_freevars, + obj.co_cellvars, + ) + elif hasattr(obj, "co_exceptiontable"): # python 3.11 (18 args) + args = ( + obj.co_lnotab, # for < python 3.10 [not counted in args] + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, # Modification for huggingface/datasets ############################################ + obj.co_name, + obj.co_qualname, + co_firstlineno, # Modification for huggingface/datasets ######################################### + obj.co_linetable, + obj.co_exceptiontable, + obj.co_freevars, + obj.co_cellvars, + ) + elif hasattr(obj, "co_linetable"): # python 3.10 (16 args) + args = ( + obj.co_lnotab, # for < python 3.10 [not counted in args] + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, # Modification for huggingface/datasets ############################################ + obj.co_name, + co_firstlineno, # Modification for huggingface/datasets ######################################### + obj.co_linetable, + obj.co_freevars, + obj.co_cellvars, + ) + elif hasattr(obj, "co_posonlyargcount"): # python 3.8 (16 args) + args = ( + obj.co_argcount, + obj.co_posonlyargcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, # Modification for huggingface/datasets ############################################ + obj.co_name, + co_firstlineno, # Modification for huggingface/datasets ######################################### + obj.co_lnotab, + obj.co_freevars, + obj.co_cellvars, + ) + else: # python 3.7 (15 args) + args = ( + obj.co_argcount, + obj.co_kwonlyargcount, + obj.co_nlocals, + obj.co_stacksize, + obj.co_flags, + obj.co_code, + obj.co_consts, + obj.co_names, + obj.co_varnames, + co_filename, # Modification for huggingface/datasets ############################################ + obj.co_name, + co_firstlineno, # Modification for huggingface/datasets ######################################### + 
obj.co_lnotab, + obj.co_freevars, + obj.co_cellvars, + ) + + pickler.save_reduce(dill._dill._create_code, args, obj=obj) + dill._dill.logger.trace(pickler, "# Co") + return diff --git a/src/biofit/utils/doc.py b/src/biofit/utils/doc.py new file mode 100644 index 0000000..8bcd3b7 --- /dev/null +++ b/src/biofit/utils/doc.py @@ -0,0 +1,1213 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Doc utilities: Utilities related to documentation +""" + +import functools +import re +import types + + +def add_start_docstrings(*docstr): + def docstring_decorator(fn): + fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") + return fn + + return docstring_decorator + + +def add_start_docstrings_to_model_forward(*docstr): + def docstring_decorator(fn): + docstring = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") + class_name = f"[`{fn.__qualname__.split('.')[0]}`]" + intro = f" The {class_name} forward method, overrides the `__call__` special method." + note = r""" + + + + Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`] + instance afterwards instead of this since the former takes care of running the pre and post processing steps while + the latter silently ignores them. + + +""" + + fn.__doc__ = intro + note + docstring + return fn + + return docstring_decorator + + +def add_end_docstrings(*docstr): + def docstring_decorator(fn): + fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr) + return fn + + return docstring_decorator + + +PT_RETURN_INTRODUCTION = r""" + Returns: + [`{full_output_type}`] or `tuple(torch.FloatTensor)`: A [`{full_output_type}`] or a tuple of + `torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various + elements depending on the configuration ([`{_config_class}`]) and inputs. + +""" + + +TF_RETURN_INTRODUCTION = r""" + Returns: + [`{full_output_type}`] or `tuple(tf.Tensor)`: A [`{full_output_type}`] or a tuple of `tf.Tensor` (if + `return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the + configuration ([`{_config_class}`]) and inputs. + +""" + + +def _get_indent(t): + """Returns the indentation in the first line of t""" + search = re.search(r"^(\s*)\S", t) + return "" if search is None else search.groups()[0] + + +def _convert_output_args_doc(output_args_doc): + """Convert output_args_doc to display properly.""" + # Split output_arg_doc in blocks argument/description + indent = _get_indent(output_args_doc) + blocks = [] + current_block = "" + for line in output_args_doc.split("\n"): + # If the indent is the same as the beginning, the line is the name of new arg. + if _get_indent(line) == indent: + if len(current_block) > 0: + blocks.append(current_block[:-1]) + current_block = f"{line}\n" + else: + # Otherwise it's part of the description of the current arg. 
+ # We need to remove 2 spaces to the indentation. + current_block += f"{line[2:]}\n" + blocks.append(current_block[:-1]) + + # Format each block for proper rendering + for i in range(len(blocks)): + blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i]) + blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i]) + + return "\n".join(blocks) + + +def _prepare_output_docstrings(output_type, _config_class, min_indent=None): + """ + Prepares the return part of the docstring using `output_type`. + """ + output_docstring = output_type.__doc__ + + # Remove the head of the docstring to keep the list of args only + lines = output_docstring.split("\n") + i = 0 + while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None: + i += 1 + if i < len(lines): + params_docstring = "\n".join(lines[(i + 1) :]) + params_docstring = _convert_output_args_doc(params_docstring) + else: + raise ValueError( + f"No `Args` or `Parameters` section is found in the docstring of `{output_type.__name__}`. Make sure it has " + "docstring and contain either `Args` or `Parameters`." + ) + + # Add the return introduction + full_output_type = f"{output_type.__module__}.{output_type.__name__}" + intro = ( + TF_RETURN_INTRODUCTION + if output_type.__name__.startswith("TF") + else PT_RETURN_INTRODUCTION + ) + intro = intro.format(full_output_type=full_output_type, _config_class=_config_class) + result = intro + params_docstring + + # Apply minimum indent if necessary + if min_indent is not None: + lines = result.split("\n") + # Find the indent of the first nonempty line + i = 0 + while len(lines[i]) == 0: + i += 1 + indent = len(_get_indent(lines[i])) + # If too small, add indentation to all nonempty lines + if indent < min_indent: + to_add = " " * (min_indent - indent) + lines = [(f"{to_add}{line}" if len(line) > 0 else line) for line in lines] + result = "\n".join(lines) + + return result + + +FAKE_MODEL_DISCLAIMER = """ + + + This example uses a random model as the real ones are all very big. To get proper results, you should use + {real_checkpoint} instead of {fake_checkpoint}. If you get out-of-memory when loading that checkpoint, you can try + adding `device_map="auto"` in the `from_pretrained` call. + + +""" + + +PT_TOKEN_CLASSIFICATION_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer( + ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt" + ... ) + + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> predicted_token_class_ids = logits.argmax(-1) + + >>> # Note that tokens are classified rather then input words which means that + >>> # there might be more predicted token classes than words. 
+ >>> # Multiple token classes might account for the same word + >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]] + >>> predicted_tokens_classes + {expected_output} + + >>> labels = predicted_token_class_ids + >>> loss = model(**inputs, labels=labels).loss + >>> round(loss.item(), 2) + {expected_loss} + ``` +""" + +PT_QUESTION_ANSWERING_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + + >>> inputs = tokenizer(question, text, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> answer_start_index = outputs.start_logits.argmax() + >>> answer_end_index = outputs.end_logits.argmax() + + >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1] + >>> tokenizer.decode(predict_answer_tokens, skip_special_tokens=True) + {expected_output} + + >>> # target is "nice puppet" + >>> target_start_index = torch.tensor([{qa_target_start_index}]) + >>> target_end_index = torch.tensor([{qa_target_end_index}]) + + >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index) + >>> loss = outputs.loss + >>> round(loss.item(), 2) + {expected_loss} + ``` +""" + +PT_SEQUENCE_CLASSIFICATION_SAMPLE = r""" + Example of single-label classification: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, {model_class} + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> predicted_class_id = logits.argmax().item() + >>> model.config.id2label[predicted_class_id] + {expected_output} + + >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` + >>> num_labels = len(model.config.id2label) + >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels) + + >>> labels = torch.tensor([1]) + >>> loss = model(**inputs, labels=labels).loss + >>> round(loss.item(), 2) + {expected_loss} + ``` + + Example of multi-label classification: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, {model_class} + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}", problem_type="multi_label_classification") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(dim=0) > 0.5] + + >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` + >>> num_labels = len(model.config.id2label) + >>> model = {model_class}.from_pretrained( + ... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification" + ... ) + + >>> labels = torch.sum( + ... torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), dim=1 + ... 
).to(torch.float) + >>> loss = model(**inputs, labels=labels).loss + ``` +""" + +PT_MASKED_LM_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt") + + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> # retrieve index of {mask} + >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0] + + >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1) + >>> tokenizer.decode(predicted_token_id) + {expected_output} + + >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"] + >>> # mask labels of non-{mask} tokens + >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100) + + >>> outputs = model(**inputs, labels=labels) + >>> round(outputs.loss.item(), 2) + {expected_loss} + ``` +""" + +PT_BASE_MODEL_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + +PT_MULTIPLE_CHOICE_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> choice0 = "It is eaten with a fork and a knife." + >>> choice1 = "It is eaten while held in the hand." 
+ >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1 + + >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True) + >>> outputs = model(**{{k: v.unsqueeze(0) for k, v in encoding.items()}}, labels=labels) # batch size is 1 + + >>> # the linear classifier still needs to be trained + >>> loss = outputs.loss + >>> logits = outputs.logits + ``` +""" + +PT_CAUSAL_LM_SAMPLE = r""" + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, {model_class} + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs, labels=inputs["input_ids"]) + >>> loss = outputs.loss + >>> logits = outputs.logits + ``` +""" + +PT_SPEECH_BASE_MODEL_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoProcessor, {model_class} + >>> import torch + >>> from biofit import load_dataset + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> processor = AutoProcessor.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> # audio file is decoded on the fly + >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + {expected_output} + ``` +""" + +PT_SPEECH_CTC_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoProcessor, {model_class} + >>> from biofit import load_dataset + >>> import torch + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> processor = AutoProcessor.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> # audio file is decoded on the fly + >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + >>> with torch.no_grad(): + ... 
logits = model(**inputs).logits + >>> predicted_ids = torch.argmax(logits, dim=-1) + + >>> # transcribe speech + >>> transcription = processor.batch_decode(predicted_ids) + >>> transcription[0] + {expected_output} + + >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids + + >>> # compute loss + >>> loss = model(**inputs).loss + >>> round(loss.item(), 2) + {expected_loss} + ``` +""" + +PT_SPEECH_SEQ_CLASS_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoFeatureExtractor, {model_class} + >>> from biofit import load_dataset + >>> import torch + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> # audio file is decoded on the fly + >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> predicted_class_ids = torch.argmax(logits, dim=-1).item() + >>> predicted_label = model.config.id2label[predicted_class_ids] + >>> predicted_label + {expected_output} + + >>> # compute loss - target_label is e.g. "down" + >>> target_label = model.config.id2label[0] + >>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]]) + >>> loss = model(**inputs).loss + >>> round(loss.item(), 2) + {expected_loss} + ``` +""" + + +PT_SPEECH_FRAME_CLASS_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoFeatureExtractor, {model_class} + >>> from biofit import load_dataset + >>> import torch + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> # audio file is decoded on the fly + >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate) + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> probabilities = torch.sigmoid(logits[0]) + >>> # labels is a one-hot array of shape (num_frames, num_speakers) + >>> labels = (probabilities > 0.5).long() + >>> labels[0].tolist() + {expected_output} + ``` +""" + + +PT_SPEECH_XVECTOR_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoFeatureExtractor, {model_class} + >>> from biofit import load_dataset + >>> import torch + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> # audio file is decoded on the fly + >>> inputs = feature_extractor( + ... [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True + ... ) + >>> with torch.no_grad(): + ... 
embeddings = model(**inputs).embeddings + + >>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu() + + >>> # the resulting embeddings can be used for cosine similarity-based retrieval + >>> cosine_sim = torch.nn.CosineSimilarity(dim=-1) + >>> similarity = cosine_sim(embeddings[0], embeddings[1]) + >>> threshold = 0.7 # the optimal threshold is dataset-dependent + >>> if similarity < threshold: + ... print("Speakers are not the same!") + >>> round(similarity.item(), 2) + {expected_output} + ``` +""" + +PT_VISION_BASE_MODEL_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoImageProcessor, {model_class} + >>> import torch + >>> from biofit import load_dataset + + >>> dataset = load_dataset("huggingface/cats-image") + >>> image = dataset["test"]["image"][0] + + >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = image_processor(image, return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + {expected_output} + ``` +""" + +PT_VISION_SEQ_CLASS_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoImageProcessor, {model_class} + >>> import torch + >>> from biofit import load_dataset + + >>> dataset = load_dataset("huggingface/cats-image") + >>> image = dataset["test"]["image"][0] + + >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = image_processor(image, return_tensors="pt") + + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_label = logits.argmax(-1).item() + >>> print(model.config.id2label[predicted_label]) + {expected_output} + ``` +""" + + +PT_SAMPLE_DOCSTRINGS = { + "SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE, + "QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE, + "TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE, + "MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE, + "MaskedLM": PT_MASKED_LM_SAMPLE, + "LMHead": PT_CAUSAL_LM_SAMPLE, + "BaseModel": PT_BASE_MODEL_SAMPLE, + "SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE, + "CTC": PT_SPEECH_CTC_SAMPLE, + "AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE, + "AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE, + "AudioXVector": PT_SPEECH_XVECTOR_SAMPLE, + "VisionBaseModel": PT_VISION_BASE_MODEL_SAMPLE, + "ImageClassification": PT_VISION_SEQ_CLASS_SAMPLE, +} + + +TF_TOKEN_CLASSIFICATION_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + >>> import tensorflow as tf + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer( + ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="tf" + ... ) + + >>> logits = model(**inputs).logits + >>> predicted_token_class_ids = tf.math.argmax(logits, axis=-1) + + >>> # Note that tokens are classified rather then input words which means that + >>> # there might be more predicted token classes than words. 
+ >>> # Multiple token classes might account for the same word + >>> predicted_tokens_classes = [model.config.id2label[t] for t in predicted_token_class_ids[0].numpy().tolist()] + >>> predicted_tokens_classes + {expected_output} + ``` + + ```python + >>> labels = predicted_token_class_ids + >>> loss = tf.math.reduce_mean(model(**inputs, labels=labels).loss) + >>> round(float(loss), 2) + {expected_loss} + ``` +""" + +TF_QUESTION_ANSWERING_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + >>> import tensorflow as tf + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + + >>> inputs = tokenizer(question, text, return_tensors="tf") + >>> outputs = model(**inputs) + + >>> answer_start_index = int(tf.math.argmax(outputs.start_logits, axis=-1)[0]) + >>> answer_end_index = int(tf.math.argmax(outputs.end_logits, axis=-1)[0]) + + >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1] + >>> tokenizer.decode(predict_answer_tokens) + {expected_output} + ``` + + ```python + >>> # target is "nice puppet" + >>> target_start_index = tf.constant([{qa_target_start_index}]) + >>> target_end_index = tf.constant([{qa_target_end_index}]) + + >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index) + >>> loss = tf.math.reduce_mean(outputs.loss) + >>> round(float(loss), 2) + {expected_loss} + ``` +""" + +TF_SEQUENCE_CLASSIFICATION_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + >>> import tensorflow as tf + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") + + >>> logits = model(**inputs).logits + + >>> predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0]) + >>> model.config.id2label[predicted_class_id] + {expected_output} + ``` + + ```python + >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)` + >>> num_labels = len(model.config.id2label) + >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels) + + >>> labels = tf.constant(1) + >>> loss = model(**inputs, labels=labels).loss + >>> round(float(loss), 2) + {expected_loss} + ``` +""" + +TF_MASKED_LM_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + >>> import tensorflow as tf + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="tf") + >>> logits = model(**inputs).logits + + >>> # retrieve index of {mask} + >>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0]) + >>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index) + + >>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1) + >>> tokenizer.decode(predicted_token_id) + {expected_output} + ``` + + ```python + >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"] + >>> # mask labels of non-{mask} tokens + >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100) + + >>> outputs = model(**inputs, labels=labels) + >>> round(float(outputs.loss), 
2) + {expected_loss} + ``` +""" + +TF_BASE_MODEL_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + >>> import tensorflow as tf + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") + >>> outputs = model(inputs) + + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + +TF_MULTIPLE_CHOICE_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + >>> import tensorflow as tf + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> choice0 = "It is eaten with a fork and a knife." + >>> choice1 = "It is eaten while held in the hand." + + >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="tf", padding=True) + >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}} + >>> outputs = model(inputs) # batch size is 1 + + >>> # the linear classifier still needs to be trained + >>> logits = outputs.logits + ``` +""" + +TF_CAUSAL_LM_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + >>> import tensorflow as tf + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") + >>> outputs = model(inputs) + >>> logits = outputs.logits + ``` +""" + +TF_SPEECH_BASE_MODEL_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoProcessor, {model_class} + >>> from biofit import load_dataset + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> processor = AutoProcessor.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> # audio file is decoded on the fly + >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + {expected_output} + ``` +""" + +TF_SPEECH_CTC_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoProcessor, {model_class} + >>> from biofit import load_dataset + >>> import tensorflow as tf + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> processor = AutoProcessor.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> # audio file is decoded on the fly + >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf") + >>> logits = model(**inputs).logits + >>> predicted_ids = tf.math.argmax(logits, axis=-1) + + >>> # transcribe speech + >>> transcription = processor.batch_decode(predicted_ids) + >>> transcription[0] + {expected_output} + ``` + + ```python + >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="tf").input_ids + + >>> # compute loss + >>> loss = model(**inputs).loss + >>> 
round(float(loss), 2) + {expected_loss} + ``` +""" + +TF_VISION_BASE_MODEL_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoImageProcessor, {model_class} + >>> from biofit import load_dataset + + >>> dataset = load_dataset("huggingface/cats-image") + >>> image = dataset["test"]["image"][0] + + >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = image_processor(image, return_tensors="tf") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + {expected_output} + ``` +""" + +TF_VISION_SEQ_CLASS_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoImageProcessor, {model_class} + >>> import tensorflow as tf + >>> from biofit import load_dataset + + >>> dataset = load_dataset("huggingface/cats-image") + >>> image = dataset["test"]["image"][0] + + >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = image_processor(image, return_tensors="tf") + >>> logits = model(**inputs).logits + + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_label = int(tf.math.argmax(logits, axis=-1)) + >>> print(model.config.id2label[predicted_label]) + {expected_output} + ``` +""" + +TF_SAMPLE_DOCSTRINGS = { + "SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE, + "QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE, + "TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE, + "MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE, + "MaskedLM": TF_MASKED_LM_SAMPLE, + "LMHead": TF_CAUSAL_LM_SAMPLE, + "BaseModel": TF_BASE_MODEL_SAMPLE, + "SpeechBaseModel": TF_SPEECH_BASE_MODEL_SAMPLE, + "CTC": TF_SPEECH_CTC_SAMPLE, + "VisionBaseModel": TF_VISION_BASE_MODEL_SAMPLE, + "ImageClassification": TF_VISION_SEQ_CLASS_SAMPLE, +} + + +FLAX_TOKEN_CLASSIFICATION_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") + + >>> outputs = model(**inputs) + >>> logits = outputs.logits + ``` +""" + +FLAX_QUESTION_ANSWERING_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + >>> inputs = tokenizer(question, text, return_tensors="jax") + + >>> outputs = model(**inputs) + >>> start_scores = outputs.start_logits + >>> end_scores = outputs.end_logits + ``` +""" + +FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") + + >>> outputs = model(**inputs) + >>> logits = outputs.logits + ``` +""" + +FLAX_MASKED_LM_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("The capital 
of France is {mask}.", return_tensors="jax") + + >>> outputs = model(**inputs) + >>> logits = outputs.logits + ``` +""" + +FLAX_BASE_MODEL_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax") + >>> outputs = model(**inputs) + + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + +FLAX_MULTIPLE_CHOICE_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + >>> choice0 = "It is eaten with a fork and a knife." + >>> choice1 = "It is eaten while held in the hand." + + >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="jax", padding=True) + >>> outputs = model(**{{k: v[None, :] for k, v in encoding.items()}}) + + >>> logits = outputs.logits + ``` +""" + +FLAX_CAUSAL_LM_SAMPLE = r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, {model_class} + + >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np") + >>> outputs = model(**inputs) + + >>> # retrieve logts for next token + >>> next_token_logits = outputs.logits[:, -1] + ``` +""" + +FLAX_SAMPLE_DOCSTRINGS = { + "SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE, + "QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE, + "TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE, + "MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE, + "MaskedLM": FLAX_MASKED_LM_SAMPLE, + "BaseModel": FLAX_BASE_MODEL_SAMPLE, + "LMHead": FLAX_CAUSAL_LM_SAMPLE, +} + + +def filter_outputs_from_example(docstring, **kwargs): + """ + Removes the lines testing an output with the doctest syntax in a code sample when it's set to `None`. + """ + for key, value in kwargs.items(): + if value is not None: + continue + + doc_key = "{" + key + "}" + docstring = re.sub(rf"\n([^\n]+)\n\s+{doc_key}\n", "\n", docstring) + + return docstring + + +def add_code_sample_docstrings( + *docstr, + processor_class=None, + checkpoint=None, + output_type=None, + _config_class=None, + mask="[MASK]", + qa_target_start_index=14, + qa_target_end_index=15, + model_cls=None, + modality=None, + expected_output=None, + expected_loss=None, + real_checkpoint=None, + revision=None, +): + def docstring_decorator(fn): + # model_class defaults to function's class if not specified otherwise + model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls + + if model_class[:2] == "TF": + sample_docstrings = TF_SAMPLE_DOCSTRINGS + elif model_class[:4] == "Flax": + sample_docstrings = FLAX_SAMPLE_DOCSTRINGS + else: + sample_docstrings = PT_SAMPLE_DOCSTRINGS + + # putting all kwargs for docstrings in a dict to be used + # with the `.format(**doc_kwargs)`. Note that string might + # be formatted with non-existing keys, which is fine. 
+ doc_kwargs = { + "model_class": model_class, + "processor_class": processor_class, + "checkpoint": checkpoint, + "mask": mask, + "qa_target_start_index": qa_target_start_index, + "qa_target_end_index": qa_target_end_index, + "expected_output": expected_output, + "expected_loss": expected_loss, + "real_checkpoint": real_checkpoint, + "fake_checkpoint": checkpoint, + "true": "{true}", # For syntax that conflicts with formatting. + } + + if ( + "SequenceClassification" in model_class + or "AudioClassification" in model_class + ) and modality == "audio": + code_sample = sample_docstrings["AudioClassification"] + elif "SequenceClassification" in model_class: + code_sample = sample_docstrings["SequenceClassification"] + elif "QuestionAnswering" in model_class: + code_sample = sample_docstrings["QuestionAnswering"] + elif "TokenClassification" in model_class: + code_sample = sample_docstrings["TokenClassification"] + elif "MultipleChoice" in model_class: + code_sample = sample_docstrings["MultipleChoice"] + elif "MaskedLM" in model_class or model_class in [ + "FlaubertWithLMHeadModel", + "XLMWithLMHeadModel", + ]: + code_sample = sample_docstrings["MaskedLM"] + elif "LMHead" in model_class or "CausalLM" in model_class: + code_sample = sample_docstrings["LMHead"] + elif "CTC" in model_class: + code_sample = sample_docstrings["CTC"] + elif "AudioFrameClassification" in model_class: + code_sample = sample_docstrings["AudioFrameClassification"] + elif "XVector" in model_class and modality == "audio": + code_sample = sample_docstrings["AudioXVector"] + elif "Model" in model_class and modality == "audio": + code_sample = sample_docstrings["SpeechBaseModel"] + elif "Model" in model_class and modality == "vision": + code_sample = sample_docstrings["VisionBaseModel"] + elif "Model" in model_class or "Encoder" in model_class: + code_sample = sample_docstrings["BaseModel"] + elif "ImageClassification" in model_class: + code_sample = sample_docstrings["ImageClassification"] + else: + raise ValueError(f"Docstring can't be built for model {model_class}") + + code_sample = filter_outputs_from_example( + code_sample, expected_output=expected_output, expected_loss=expected_loss + ) + if real_checkpoint is not None: + code_sample = FAKE_MODEL_DISCLAIMER + code_sample + func_doc = (fn.__doc__ or "") + "".join(docstr) + output_doc = ( + "" + if output_type is None + else _prepare_output_docstrings(output_type, _config_class) + ) + built_doc = code_sample.format(**doc_kwargs) + if revision is not None: + if re.match(r"^refs/pr/\\d+", revision): + raise ValueError( + f"The provided revision '{revision}' is incorrect. 
It should point to" + " a pull request reference on the hub like 'refs/pr/6'" + ) + built_doc = built_doc.replace( + f'from_pretrained("{checkpoint}")', + f'from_pretrained("{checkpoint}", revision="{revision}")', + ) + fn.__doc__ = func_doc + output_doc + built_doc + return fn + + return docstring_decorator + + +def replace_return_docstrings(output_type=None, _config_class=None): + def docstring_decorator(fn): + func_doc = fn.__doc__ + lines = func_doc.split("\n") + i = 0 + while i < len(lines) and re.search(r"^\s*Returns?:\s*$", lines[i]) is None: + i += 1 + if i < len(lines): + indent = len(_get_indent(lines[i])) + lines[i] = _prepare_output_docstrings( + output_type, _config_class, min_indent=indent + ) + func_doc = "\n".join(lines) + else: + raise ValueError( + f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, " + f"current docstring is:\n{func_doc}" + ) + fn.__doc__ = func_doc + return fn + + return docstring_decorator + + +def copy_func(f): + """Returns a copy of a function f.""" + # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard) + g = types.FunctionType( + f.__code__, + f.__globals__, + name=f.__name__, + argdefs=f.__defaults__, + closure=f.__closure__, + ) + g = functools.update_wrapper(g, f) + g.__kwdefaults__ = f.__kwdefaults__ + return g diff --git a/src/biofit/utils/file_utils.py b/src/biofit/utils/file_utils.py new file mode 100644 index 0000000..0d780ac --- /dev/null +++ b/src/biofit/utils/file_utils.py @@ -0,0 +1,234 @@ +""" +This file is adapted from the datasets library, which in turn is adapted from the AllenNLP library. + +datasets +~~~~~~~~ +Utilities for working with the local dataset cache. +This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. +""" + +import io +import os +import posixpath +import shutil +import sys +import tempfile +from pathlib import Path +from typing import Optional, TypeVar, Union +from urllib.parse import unquote, urlparse + +from huggingface_hub.utils import insecure_hashlib + +from . import logging + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +INCOMPLETE_SUFFIX = ".incomplete" + +PathLike = Union[str, Path] +T = TypeVar("T", str, Path) + + +def is_remote_url(url_or_filename: Union[str, Path]) -> bool: + return urlparse(str(url_or_filename)).scheme != "" and not os.path.ismount( + urlparse(str(url_or_filename)).scheme + ":/" + ) + + +def is_local_path(url_or_filename: Union[str, Path]) -> bool: + # On unix the scheme of a local path is empty (for both absolute and relative), + # while on windows the scheme is the drive name (ex: "c") for absolute paths. + # for details on the windows behavior, see https://bugs.python.org/issue42215 + url_or_filename = Path(url_or_filename).resolve().as_posix() + + return urlparse(url_or_filename).scheme == "" or os.path.ismount( + urlparse(url_or_filename).scheme + ":/" + ) + + +def is_relative_path(url_or_filename: Union[str, Path]) -> bool: + return urlparse(str(url_or_filename)).scheme == "" and not os.path.isabs( + str(url_or_filename) + ) + + +def expand_path(path): + """ + Check if a path is relative and expand it if necessary. + Handles file paths and URLs, including user home directory expansion. 
+ + :param path: str representing the path or URL to check and expand + :return: str with the expanded absolute path or full URL + """ + # Parse the path as a URL + parsed = urlparse(path) + + # If it's a URL (not a local file path) + if parsed.scheme and parsed.netloc: + return path # Return the original URL + + # If it's a file URL, convert it to a local path + if parsed.scheme == "file": + path = unquote(parsed.path) + # On Windows, remove leading slash + if sys.platform == "win32" and path.startswith("/"): + path = path[1:] + else: + # It's a regular path, use the full original string + path = unquote(path) + + # Convert to Path object + path = Path(path) + + # Expand user's home directory if present + path = path.expanduser() + + # Check if the path is absolute + if path.is_absolute(): + return str(path.resolve()) + + # If path is relative, make it absolute + return os.path.normpath(str((Path.cwd() / path).resolve())) + + +def relative_to_absolute_path(path: T) -> T: + """Convert relative path to absolute path.""" + abs_path_str = os.path.abspath(os.path.expanduser(os.path.expandvars(str(path)))) + return Path(abs_path_str) if isinstance(path, Path) else abs_path_str + + +def is_file_name(url_or_path_or_file: T) -> bool: + if is_local_path(url_or_path_or_file): + if is_relative_path(url_or_path_or_file): + if "/" not in Path(url_or_path_or_file).as_posix(): + return True + return False + + +def has_ext(url_or_path: Union[str, Path], ext: Optional[str] = None) -> bool: + if ext is None: + return Path(url_or_path).suffix != "" + return Path(url_or_path).suffix == ext + + +def has_separator(url_or_path: Union[str, Path]) -> bool: + return "/" in str(url_or_path) or "\\" in str(url_or_path) + + +def url_or_path_join(base_name: str, *pathnames: str) -> str: + if is_remote_url(base_name): + return posixpath.join( + base_name, + *(str(pathname).replace(os.sep, "/").lstrip("/") for pathname in pathnames), + ) + else: + return Path(base_name, *pathnames).as_posix() + + +def url_or_path_parent(url_or_path: str) -> str: + if is_remote_url(url_or_path): + return url_or_path[: url_or_path.rindex("/")] + else: + return os.path.dirname(url_or_path) + + +def hash_url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the url's, delimited + by a period. + If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name + so that TF 2.0 can identify it as a HDF5 file + (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380) + """ + url_bytes = url.encode("utf-8") + url_hash = insecure_hashlib.sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode("utf-8") + etag_hash = insecure_hashlib.sha256(etag_bytes) + filename += "." 
+ etag_hash.hexdigest() + + if url.endswith(".py"): + filename += ".py" + + return filename + + +def add_start_docstrings(*docstr): + def docstring_decorator(fn): + fn.__doc__ = ( + "".join(docstr) + "\n\n" + (fn.__doc__ if fn.__doc__ is not None else "") + ) + return fn + + return docstring_decorator + + +def add_end_docstrings(*docstr): + def docstring_decorator(fn): + fn.__doc__ = ( + (fn.__doc__ if fn.__doc__ is not None else "") + "\n\n" + "".join(docstr) + ) + return fn + + return docstring_decorator + + +def estimate_dataset_size(paths): + return sum(path.stat().st_size for path in paths) + + +def readline(f: io.RawIOBase): + # From: https://github.com/python/cpython/blob/d27e2f4d118e7a9909b6a3e5da06c5ff95806a85/Lib/_pyio.py#L525 + res = bytearray() + while True: + b = f.read(1) + if not b: + break + res += b + if res.endswith(b"\n"): + break + return bytes(res) + + +def move_temp_file( + temp_file: Union[tempfile._TemporaryFileWrapper, str, Path], final_file: str +): + if isinstance(temp_file, tempfile._TemporaryFileWrapper): + temp_file.close() + temp_file_name = Path(temp_file.name) + elif not isinstance(temp_file, Path): + temp_file_name = Path(temp_file) + + if not isinstance(final_file, Path): + final_file = Path(final_file) + + if temp_file_name.exists(): + if final_file.exists(): + final_file.unlink() + # is source a windows path? + if temp_file_name.resolve().drive: + if len(str(temp_file_name)) > 255: + src_file = "\\\\?\\" + str(temp_file_name.resolve()) + else: + src_file = temp_file_name.as_posix() + else: + src_file = temp_file_name.as_posix() + + if final_file.resolve().drive: + if len(str(final_file)) > 255: + dst_file = "\\\\?\\" + str(final_file.resolve()) + else: + dst_file = final_file.as_posix() + else: + dst_file = final_file.as_posix() + shutil.move(src_file, dst_file) + umask = os.umask(0o666) + os.umask(umask) + os.chmod(dst_file, 0o666 & ~umask) + return dst_file + else: + raise FileNotFoundError(f"Temporary file {temp_file_name.name} not found.") diff --git a/src/biofit/utils/fingerprint.py b/src/biofit/utils/fingerprint.py new file mode 100644 index 0000000..38c3976 --- /dev/null +++ b/src/biofit/utils/fingerprint.py @@ -0,0 +1,400 @@ +""" +NOTE: The contents of this file have been inlined from the fingerprint and _dill module in the datasets package's source code +https://github.com/huggingface/datasets/blob/c47cc141c9e6e0edafffdcfde55b171612f1de76/src/datasets/fingerprint.py + +This module has fixes / adaptations for biofit use cases that make it different from the original +datasets library + +The following modifications have been made: + - Added python source code parsing logic to the `Hasher` class. This is used to hash the contents of a python file + without comments since we only care about the code that is actually executed. + - Added the _hash_python_lines function to this file without importing + huggingface_hub + +datasets +~~~~~~~~ +Copyright 2023 The HuggingFace Team. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import functools +import hashlib +import importlib +import inspect +import os +import posixpath +import random +import re +from pathlib import Path +import sys +from types import FunctionType, ModuleType +from typing import Any, Dict, List, Union + +from biocore.utils.import_util import is_datasets_available +import xxhash +from .version import Version +from ._dill import dumps + +import biofit.config as config +import biofit.utils.logging as logging +from biofit.utils.version import __version__ + +from .file_utils import is_file_name, is_remote_url + +_CACHING_ENABLED = True + +logger = logging.get_logger(__name__) + +fingerprint_rng = random.Random() + + +def _hash_python_lines(lines: List[str]) -> str: + filtered_lines = [] + for line in lines: + line = re.sub(r"#.*", "", line) # remove comments + if line: + filtered_lines.append(line) + full_str = "\n".join(filtered_lines) + + # Make a hash from all this code + full_bytes = full_str.encode("utf-8") + _kwargs = {"usedforsecurity": False} if sys.version_info >= (3, 9) else {} + sha256 = functools.partial(hashlib.sha256, **_kwargs) + return sha256(full_bytes).hexdigest() + + +def generate_random_fingerprint(nbits: int = 64) -> str: + return f"{fingerprint_rng.getrandbits(nbits):0{nbits//4}x}" + + +class Hasher: + """Hasher that accepts python objects as inputs.""" + + dispatch: Dict = {} + + def __init__(self): + self.m = xxhash.xxh64() + + @classmethod + def hash_bytes(cls, value: Union[bytes, List[bytes]]) -> str: + value = [value] if isinstance(value, bytes) else value + m = xxhash.xxh64() + for x in value: + m.update(x) + return m.hexdigest() + + @classmethod + def hash(cls, value: Any) -> str: + return cls.hash_bytes(dumps(value)) + + @staticmethod + def _hash_python_lines(self, module: Union[ModuleType, FunctionType, type]) -> str: + filtered_lines = [] + lines = inspect.getsource(module).splitlines() + for line in lines: + line = re.sub(r"#.*", "", line) # remove comments + if line: + filtered_lines.append(line) + return "\n".join(filtered_lines) + + def update(self, value: Any) -> None: + header_for_update = f"=={type(value)}==" + value_for_update = self.hash(value) + self.m.update(header_for_update.encode("utf8")) + self.m.update(value_for_update.encode("utf-8")) + + def hexdigest(self) -> str: + return self.m.hexdigest() + + +def enable_caching(): + """ + When applying transforms on a dataset, the data are stored in cache files. + The caching mechanism allows to reload an existing cache file if it's already been computed. + + Reloading a dataset is possible since the cache files are named using the dataset fingerprint, which is updated + after each transform. + + If disabled, the library will no longer reload cached datasets files when applying transforms to the datasets. + More precisely, if the caching is disabled: + - cache files are always recreated + - cache files are written to a temporary directory that is deleted when session closes + - cache files are named using a random hash instead of the dataset fingerprint + - use [`~datasets.Dataset.save_to_disk`] to save a transformed dataset or it will be deleted when session closes + - caching doesn't affect [`~datasets.load_dataset`]. If you want to regenerate a dataset from scratch you should use + the `download_mode` parameter in [`~datasets.load_dataset`]. + """ + global _CACHING_ENABLED + _CACHING_ENABLED = True + + +def disable_caching(): + """ + When applying transforms on a dataset, the data are stored in cache files. 
+ The caching mechanism allows to reload an existing cache file if it's already been computed. + + Reloading a dataset is possible since the cache files are named using the dataset fingerprint, which is updated + after each transform. + + If disabled, the library will no longer reload cached datasets files when applying transforms to the datasets. + More precisely, if the caching is disabled: + - cache files are always recreated + - cache files are written to a temporary directory that is deleted when session closes + - cache files are named using a random hash instead of the dataset fingerprint + - use [`~datasets.Dataset.save_to_disk`] to save a transformed dataset or it will be deleted when session closes + - caching doesn't affect [`~datasets.load_dataset`]. If you want to regenerate a dataset from scratch you should use + the `download_mode` parameter in [`~datasets.load_dataset`]. + """ + global _CACHING_ENABLED + _CACHING_ENABLED = False + + +def is_caching_enabled() -> bool: + """ + When applying transforms on a dataset, the data are stored in cache files. + The caching mechanism allows to reload an existing cache file if it's already been computed. + + Reloading a dataset is possible since the cache files are named using the dataset fingerprint, which is updated + after each transform. + + If disabled, the library will no longer reload cached datasets files when applying transforms to the datasets. + More precisely, if the caching is disabled: + - cache files are always recreated + - cache files are written to a temporary directory that is deleted when session closes + - cache files are named using a random hash instead of the dataset fingerprint + """ + global _CACHING_ENABLED + return bool(_CACHING_ENABLED) + + +def fingerprint_from_kwargs(fingerprint, kwargs): + hash = Hasher() + if fingerprint: + hash.update(fingerprint) + + for key, value in kwargs.items(): + if isinstance(key, str) and "fingerprint" in key: + continue + hash.update(key) + if isinstance(value, dict): + hash.update(fingerprint_from_kwargs(fingerprint, value)) + else: + hash.update(str(value)) + + return hash.hexdigest() + + +def fingerprint_from_data(data): + from biocore import DataHandler, get_data_format + + if hasattr(data, "_fingerprint"): + return data._fingerprint + if hasattr(data, "fingerprint"): + return data.fingerprint + hasher = Hasher() + original_format = get_data_format(data) + + if original_format is not None: + features = DataHandler.get_column_names(data, generate_cols=True) + n_samples = DataHandler.get_shape(data)[0] + first_row = DataHandler.select_row(data, 0) + hasher.update(original_format) + hasher.update(features) + hasher.update(n_samples) + hasher.update(first_row) + elif hasattr(data, "__dict__"): + state = data.__dict__ + for key in sorted(state): + hasher.update(key) + try: + hasher.update(state[key]) + except Exception: + hasher.update(str(state[key])) + else: + raise ValueError("Data object is not hashable") + return hasher.hexdigest() + + +def generate_cache_dir( + X, fingerprint, cache_dir=None, root_dir=config.BIOFIT_DATASETS_CACHE +): + if cache_dir is None: + if isinstance(root_dir, Path): + root_dir = root_dir.as_posix() + if ( + root_dir == config.BIOFIT_DATASETS_CACHE.as_posix() + and hasattr(X, "cache_files") + and X.cache_files + ): + cache_dir = os.path.dirname(X.cache_files[0]["filename"]) + else: + cache_dir = _build_cache_dir( + X, + fingerprint, + cache_dir_root=root_dir, + ) + + if cache_dir: + Path(cache_dir).mkdir(parents=True, exist_ok=True) + return 
Path(cache_dir).resolve().as_posix() + return None + + +def get_cache_file_name(cache_dir, fingerprint, cache_file_name=None): + if cache_file_name: + if is_file_name(cache_file_name): + cache_file_name = Path(cache_dir) / Path(cache_file_name).with_suffix( + ".json" + ) + else: + cache_file_name = Path(cache_file_name).with_suffix(".json") + else: + cache_file_name = Path(cache_dir) / f"cache-{fingerprint}.json" + return cache_file_name.resolve().as_posix() + + +def _relative_data_dir( + fingerprint, builder_name, dataset_name, version=None, with_hash=True +) -> str: + """ + Constructs a relative directory path for a dataset based on its properties. + + Args: + dataset (Dataset): The dataset for which to construct the path. + with_version (bool, optional): Include version information in the path. + with_hash (bool, optional): Include hash information in the path. + + Returns: + str: Relative path for the dataset. + """ + builder_data_dir = posixpath.join(builder_name, f"{dataset_name}-{fingerprint}") + if version: + version = str(version) if isinstance(version, str) else __version__ + builder_data_dir = posixpath.join(builder_data_dir, version) + if with_hash and is_datasets_available(): + hash = _hash_python_lines( + inspect.getsource( + getattr(importlib.import_module("datasets.table"), "InMemoryTable") + ) + ) + builder_data_dir = posixpath.join(builder_data_dir, hash) + return builder_data_dir + + +def _build_cache_dir( + obj, + fingerprint: str, + cache_dir_root: str = config.BIOFIT_DATASETS_CACHE, +) -> str: + """ + Builds the cache directory path for storing processed dataset versions. + + Args: + dataset (Union[Dataset, IterableDataset]): The dataset to cache. + cache_dir_root (str, optional): Root directory for caching datasets. + + Returns: + str: The path to the dataset's cache directory. 
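+
+    Example (an illustrative sketch; the exact layout depends on the input
+    object's attributes and on whether the optional `datasets` package is
+    installed, which appends a hash component):
+
+        fingerprint = fingerprint_from_data(X)
+        cache_dir = _build_cache_dir(X, fingerprint)
+        # e.g. <BIOFIT_DATASETS_CACHE>/in_memory/default-<fingerprint>/<version>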
+ """ + if ( + hasattr(obj, "version") + and hasattr(obj, "config_name") + and hasattr(obj, "builder_name") + ): + version = str(obj.version) if isinstance(obj.version, str) else __version__ + dataset_name = obj.config_name or "default" + builder_name = obj.builder_name or "in_memory" + elif ( + hasattr(obj, "config") + and hasattr(obj.config, "version") + and hasattr(obj.config, "processor_name") + and hasattr(obj.config, "processor_type") + ): + version = ( + str(obj.config.version) + if obj.config.version and not str(obj.config.version) == "0.0.0" + else __version__ + ) + dataset_name = obj.config.processor_name or "default" + builder_name = obj.config.processor_type or "in" + + else: + version = __version__ + dataset_name = "default" + builder_name = "in_memory" + + builder_data_dir = posixpath.join( + cache_dir_root, + _relative_data_dir( + fingerprint=fingerprint, + builder_name=builder_name, + dataset_name=dataset_name, + ), + ) + version_data_dir = posixpath.join( + cache_dir_root, + _relative_data_dir( + fingerprint=fingerprint, + builder_name=builder_name, + dataset_name=dataset_name, + version=version, + ), + ) + + def _other_versions_on_disk(): + """Returns previous versions on disk.""" + if not os.path.exists(builder_data_dir): + return [] + + version_dirnames = [] + for dir_name in os.listdir(builder_data_dir): + try: + version_dirnames.append((Version(dir_name), dir_name)) + except ValueError: # Invalid version (ex: incomplete data dir) + pass + version_dirnames.sort(reverse=True) + return version_dirnames + # Check and warn if other versions exist + + if not is_remote_url(builder_data_dir): + version_dirs = _other_versions_on_disk() + if version_dirs: + other_version = version_dirs[0][0] + if other_version != version: + warn_msg = ( + f"Found a different version {str(other_version)} of dataset {dataset_name} in " + f"cache_dir {cache_dir_root}. Using currently defined version " + f"{str(version)}." + ) + logger.warning(warn_msg) + if not os.path.exists(version_data_dir): + os.makedirs(version_data_dir, exist_ok=True) + + return version_data_dir + + +def update_fingerprint(fingerprint, value, key=None): + if key is None and value is None: + return fingerprint + hasher = Hasher() + if fingerprint: + hasher.update(fingerprint) + if key: + hasher.update(key) + try: + hasher.update(value) + except Exception: + return generate_random_fingerprint() + else: + return hasher.hexdigest() diff --git a/src/biofit/utils/generic.py b/src/biofit/utils/generic.py new file mode 100644 index 0000000..42365e1 --- /dev/null +++ b/src/biofit/utils/generic.py @@ -0,0 +1,536 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Generic utilities +""" + +import inspect +import tempfile +import contextlib +import shutil +import os +import stat +from pathlib import Path +from collections import OrderedDict, UserDict +from collections.abc import MutableMapping +from contextlib import ExitStack, contextmanager +from dataclasses import fields, is_dataclass +from enum import Enum +from typing import Any, ContextManager, List, Tuple, Union, Optional, Generator + +import numpy as np + +# optional imports +try: + import polars as pl +except ImportError: + pl = None + + +class cached_property(property): + """ + Descriptor that mimics @property but caches output in member variable. + + From tensorflow_datasets + + Built-in in functools from Python 3.8. + """ + + def __get__(self, obj, objtype=None): + # See docs.python.org/3/howto/descriptor.html#properties + if obj is None: + return self + if self.fget is None: + raise AttributeError("unreadable attribute") + attr = "__cached_" + self.fget.__name__ + cached = getattr(obj, attr, None) + if cached is None: + cached = self.fget(obj) + setattr(obj, attr, cached) + return cached + + +def _set_write_permission_and_retry(func, path, excinfo): + os.chmod(path, stat.S_IWRITE) + func(path) + + +@contextlib.contextmanager +def SoftTemporaryDirectory( + suffix: Optional[str] = None, + prefix: Optional[str] = None, + dir: Optional[Union[Path, str]] = None, + **kwargs, +) -> Generator[str, None, None]: + """ + Context manager to create a temporary directory and safely delete it. + + If tmp directory cannot be deleted normally, we set the WRITE permission and retry. + If cleanup still fails, we give up but don't raise an exception. This is equivalent + to `tempfile.TemporaryDirectory(..., ignore_cleanup_errors=True)` introduced in + Python 3.10. + + See https://www.scivision.dev/python-tempfile-permission-error-windows/. + """ + tmpdir = tempfile.TemporaryDirectory( + prefix=prefix, suffix=suffix, dir=dir, **kwargs + ) + yield tmpdir.name + + try: + # First once with normal cleanup + shutil.rmtree(tmpdir.name) + except Exception: + # If failed, try to set write permission and retry + try: + shutil.rmtree(tmpdir.name, onerror=_set_write_permission_and_retry) + except Exception: + pass + + # And finally, cleanup the tmpdir. + # If it fails again, give up but do not throw error + try: + tmpdir.cleanup() + except Exception: + pass + + +# vendored from distutils.util +def strtobool(val): + """Convert a string representation of truth to true (1) or false (0). + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'. + Raises ValueError if 'val' is anything else. + """ + val = val.lower() + if val in {"y", "yes", "t", "true", "on", "1"}: + return 1 + if val in {"n", "no", "f", "false", "off", "0"}: + return 0 + raise ValueError(f"invalid truth value {val!r}") + + +def infer_framework_from_repr(x): + """ + Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the + frameworks in a smart order, without the need to import the frameworks). + """ + representation = str(type(x)) + if representation.startswith(" + + You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple + before. + + + """ + + def __init_subclass__(cls) -> None: + """Register subclasses as pytree nodes. 
+ + This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with + `static_graph=True` with modules that output `ModelOutput` subclasses. + """ + pass + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Subclasses of ModelOutput must use the @dataclass decorator + # This check is done in __init__ because the @dataclass decorator operates after __init_subclass__ + # issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed + # Just need to check that the current class is not ModelOutput + is_modeloutput_subclass = self.__class__ != ModelOutput + + if is_modeloutput_subclass and not is_dataclass(self): + raise TypeError( + f"{self.__module__}.{self.__class__.__name__} is not a dataclasss." + " This is a subclass of ModelOutput and so must use the @dataclass decorator." + ) + + def __post_init__(self): + """Check the ModelOutput dataclass. + + Only occurs if @dataclass decorator has been used. + """ + class_fields = fields(self) + + # Safety and consistency checks + if not len(class_fields): + raise ValueError(f"{self.__class__.__name__} has no fields.") + if not all(field.default is None for field in class_fields[1:]): + raise ValueError( + f"{self.__class__.__name__} should not have more than one required field." + ) + + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all( + getattr(self, field.name) is None for field in class_fields[1:] + ) + + if other_fields_are_none and not is_tensor(first_field): + if isinstance(first_field, dict): + iterator = first_field.items() + first_field_iterator = True + else: + try: + iterator = iter(first_field) + first_field_iterator = True + except TypeError: + first_field_iterator = False + + # if we provided an iterator as first field and the iterator is a (key, value) iterator + # set the associated fields + if first_field_iterator: + for idx, element in enumerate(iterator): + if ( + not isinstance(element, (list, tuple)) + or not len(element) == 2 + or not isinstance(element[0], str) + ): + if idx == 0: + # If we do not have an iterator of key/values, set it as attribute + self[class_fields[0].name] = first_field + else: + # If we have a mixed iterator, raise an error + raise ValueError( + f"Cannot set key/value for {element}. It needs to be a tuple (key, value)." + ) + break + setattr(self, element[0], element[1]) + if element[1] is not None: + self[element[0]] = element[1] + elif first_field is not None: + self[class_fields[0].name] = first_field + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise Exception( + f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." + ) + + def setdefault(self, *args, **kwargs): + raise Exception( + f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." + ) + + def pop(self, *args, **kwargs): + raise Exception( + f"You cannot use ``pop`` on a {self.__class__.__name__} instance." + ) + + def update(self, *args, **kwargs): + raise Exception( + f"You cannot use ``update`` on a {self.__class__.__name__} instance." 
+ ) + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = dict(self.items()) + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def __reduce__(self): + if not is_dataclass(self): + return super().__reduce__() + callable, _args, *remaining = super().__reduce__() + args = tuple(getattr(self, field.name) for field in fields(self)) + return callable, args, *remaining + + def to_tuple(self) -> Tuple[Any]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. + """ + return tuple(self[k] for k in self.keys()) + + +class ExplicitEnum(str, Enum): + """ + Enum with more explicit error message for missing values. + """ + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}" + ) + + +class ContextManagers: + """ + Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers` + in the `fastcore` library. + """ + + def __init__(self, context_managers: List[ContextManager]): + self.context_managers = context_managers + self.stack = ExitStack() + + def __enter__(self): + for context_manager in self.context_managers: + self.stack.enter_context(context_manager) + + def __exit__(self, *args, **kwargs): + self.stack.__exit__(*args, **kwargs) + + +def can_return_loss(model_class): + """ + Check if a given model can return loss. + + Args: + model_class (`type`): The class of the model. + """ + framework = infer_framework(model_class) + if framework == "tf": + signature = inspect.signature(model_class.call) # TensorFlow models + elif framework == "pt": + signature = inspect.signature(model_class.forward) # PyTorch models + else: + signature = inspect.signature(model_class.__call__) # Flax models + + for p in signature.parameters: + if p == "return_loss" and signature.parameters[p].default is True: + return True + + return False + + +def find_labels(model_class): + """ + Find the labels used by a given model. + + Args: + model_class (`type`): The class of the model. 
+ """ + model_name = model_class.__name__ + infer_framework(model_class) + signature = inspect.signature(model_class.__call__) # Flax models + + if "QuestionAnswering" in model_name: + return [ + p + for p in signature.parameters + if "label" in p or p in ("start_positions", "end_positions") + ] + else: + return [p for p in signature.parameters if "label" in p] + + +def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."): + """Flatten a nested dict into a single level dict.""" + + def _flatten_dict(d, parent_key="", delimiter="."): + for k, v in d.items(): + key = str(parent_key) + delimiter + str(k) if parent_key else k + if v and isinstance(v, MutableMapping): + yield from flatten_dict(v, key, delimiter=delimiter).items() + else: + yield key, v + + return dict(_flatten_dict(d, parent_key, delimiter)) + + +@contextmanager +def working_or_temp_dir(working_dir, use_temp_dir: bool = False): + if use_temp_dir: + with tempfile.TemporaryDirectory() as tmp_dir: + yield tmp_dir + else: + yield working_dir + + +def add_model_info_to_auto_map(auto_map, repo_id): + """ + Adds the information of the repo_id to a given auto map. + """ + for key, value in auto_map.items(): + if isinstance(value, (tuple, list)): + auto_map[key] = [ + f"{repo_id}--{v}" if (v is not None and "--" not in v) else v + for v in value + ] + elif value is not None and "--" not in value: + auto_map[key] = f"{repo_id}--{value}" + + return auto_map + + +def infer_framework(model_class): + """ + Infers the framework of a given model without using isinstance(), because we cannot guarantee that the relevant + classes are imported or available. + """ + for base_class in inspect.getmro(model_class): + module = base_class.__module__ + name = base_class.__name__ + if ( + module.startswith("tensorflow") + or module.startswith("keras") + or name == "TFPreTrainedModel" + ): + return "tf" + elif module.startswith("torch") or name == "PreTrainedModel": + return "pt" + elif ( + module.startswith("flax") + or module.startswith("jax") + or name == "FlaxPreTrainedModel" + ): + return "flax" + else: + raise TypeError(f"Could not infer framework from class {model_class}.") diff --git a/src/biofit/utils/gorilla.py b/src/biofit/utils/gorilla.py new file mode 100644 index 0000000..7462939 --- /dev/null +++ b/src/biofit/utils/gorilla.py @@ -0,0 +1,1009 @@ +# __ __ __ +# .-----.-----.----|__| | .---.-. +# | _ | _ | _| | | | _ | +# |___ |_____|__| |__|__|__|___._| +# |_____| +# + +""" +NOTE: The contents of this file have been inlined from the gorilla package's source code +https://github.com/christophercrouzet/gorilla/blob/v0.3.0/gorilla.py + +This module has fixes / adaptations for MLflow use cases that make it different from the original +gorilla library + +The following modifications have been made: + - Modify `get_original_attribute` logic, search from children classes to parent classes, + and for each class check "_gorilla_original_{attr_name}" attribute first. + first. This will ensure get the correct original attribute in any cases, e.g., + the case some classes in the hierarchy haven't been patched, but some others are + patched, this case the previous code is risky to get wrong original attribute. + - Make `get_original_attribute` support bypassing descriptor protocol. + - remove `get_attribute` method, use `get_original_attribute` with + `bypass_descriptor_protocol=True` instead of calling it. + - After reverting patch, there will be no side-effect, restore object to be exactly the + original status. 
+ - Remove `create_patches` and `patches` methods. + +gorilla +~~~~~~~ + +Convenient approach to monkey patching. + +:copyright: Copyright 2014-2017 by Christopher Crouzet. +:license: MIT, see LICENSE for details. +""" + +import base64 +import collections +import copy +import importlib +import importlib.util +import inspect +import json +import pkgutil +import sys +from types import ModuleType +from typing import Optional, Union + +import dill +from biofit.utils import logging + +__version__ = "0.3.0" +logger = logging.get_logger(__name__) + + +class PackageNotFoundError(Exception): + """Package not found error.""" + + +class AttributeNotFoundError(Exception): + """Attribute not found error.""" + + +class SameSourceAndDestinationError(Exception): + """Same source and destination error.""" + + +def _iteritems(d, **kwargs): + return iter(d.items(**kwargs)) + + +def _load_module(name): + return importlib.import_module(name) + + +# Pattern for each internal attribute name. +_PATTERN = "_gorilla_%s" + +# Pattern for the name of the overidden attributes to be stored. +_ORIGINAL_NAME = _PATTERN % ("original_%s",) + +# Pattern for the name of the patch attributes to be stored. +_ACTIVE_PATCH = "_gorilla_active_patch_%s" + +# Attribute for the decorator data. +_DECORATOR_DATA = _PATTERN % ("decorator_data",) + + +def default_filter(name, obj): + """Attribute filter. + + It filters out module attributes, and also methods starting with an + underscore `_`. + + This is used as the default filter for the :func:`create_patches` function + and the :func:`patches` decorator. + + Parameters + ---------- + name : str + Attribute name. + obj : object + Attribute value. + + Returns + ------- + bool + Whether the attribute should be returned. + """ + return not (isinstance(obj, ModuleType) or name.startswith("_")) + + +class DecoratorData: + """Decorator data. + + Attributes + ---------- + patches : list of gorilla.Patch + Patches created through the decorators. + override : dict + Any overriding value defined by the :func:`destination`, :func:`name`, + and :func:`settings` decorators. + filter : bool or None + Value defined by the :func:`filter` decorator, if any, or `None` + otherwise. + """ + + def __init__(self): + """Constructor.""" + self.patches = [] + self.override = {} + self.filter = None + + +class Settings: + """Define the patching behaviour. + + Attributes + ---------- + allow_hit : bool + A hit occurs when an attribute at the destination already exists with + the name given by the patch. If `False`, the patch process won't + allow setting a new value for the attribute by raising an exception. + Defaults to `False`. + store_hit : bool + If `True` and :attr:`allow_hit` is also set to `True`, then any + attribute at the destination that is hit is stored under a different + name before being overwritten by the patch. Defaults to `True`. + """ + + def __init__(self, **kwargs): + """Constructor. + + Parameters + ---------- + kwargs + Keyword arguments, see the attributes. 
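+
+        Examples
+        --------
+        An illustrative sketch::
+
+            # Allow patches to overwrite an existing attribute and keep a copy of
+            # the original so the patch can later be reverted.
+            settings = Settings(allow_hit=True, store_hit=True)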
+ """ + self.allow_hit = False + self.store_hit = True + self._update(**kwargs) + + def __repr__(self): + values = ", ".join( + [f"{key}={value!r}" for key, value in sorted(_iteritems(self.__dict__))] + ) + return f"{type(self).__name__}({values})" + + def __eq__(self, other): + if isinstance(other, type(self)): + return self.__dict__ == other.__dict__ + + return NotImplemented + + def __ne__(self, other): + is_equal = self.__eq__(other) + return is_equal if is_equal is NotImplemented else not is_equal + + def _update(self, **kwargs): + """Update some settings. + + Parameters + ---------- + kwargs + Settings to update. + """ + self.__dict__.update(**kwargs) + + def to_dict(self): + return self.__dict__ + + @classmethod + def from_dict(cls, json): + return cls(**json) + + def to_json(self, filename=None): + if filename: + with open(filename, "w") as f: + return json.dump(self.to_dict(), f) + else: + return json.dumps(self.to_dict()) + + @classmethod + def from_json(cls, json_str): + return cls.from_dict(json.loads(json_str)) + + @classmethod + def from_json_file(cls, filename): + with open(filename, "r") as f: + return cls.from_json(f.read()) + + +class Patch: + """Describe all the information required to apply a patch. + + Attributes + ---------- + destination(`object`): + Patch destination. + name(`str`): + Name of the attribute at the destination. + obj(`object`): + Attribute value. + source_name(`str`, *optional*): + Name of the attribute at the source. If `None`, then the name of the attribute is the same as `name`. + source(`Union[str, ModuleType]`, *optional*): + The fully qualified name of the object. If `None`, then the object will not be serializable. + For example, if object `obj` is defined in module `package1.submodule1`, then `source` + should be `package1.submodule1.obj`. If object is a module, then `source` should be the module name. + settings(`gorilla.Settings`, *optional*): + Settings. If `None`, the default settings are used. + + Warning + ------- + It is highly recommended to use the output of the function + :func:`get_attribute` for setting the attribute :attr:`obj`. This will + ensure that the descriptor protocol is bypassed instead of possibly + retrieving attributes invalid for patching, such as bound methods. + """ + + def __init__( + self, + destination: object, + name: str, + obj: object, + source: Optional[Union[str, ModuleType]] = None, + source_name: Optional[str] = None, + settings: Optional[Settings] = None, + ): + """Constructor. + + Parameters + ---------- + destination : object + See the :attr:`~Patch.destination` attribute. + source : str + See the :attr:`~Patch.source` attribute. + name : str + See the :attr:`~Patch.name` attribute. + obj : object + See the :attr:`~Patch.obj` attribute. + asname : str + See the :attr:`~Patch.asname` attribute. + settings : gorilla.Settings + See the :attr:`~Patch.settings` attribute. 
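+
+        Examples
+        --------
+        A minimal sketch that copies ``operator.neg`` onto the ``math`` module
+        under a new, purely illustrative name and then reverts it::
+
+            import math
+            import operator
+
+            patch = Patch(
+                destination=math,
+                name="negate",
+                obj=operator.neg,
+                source=operator,
+                source_name="neg",
+                settings=Settings(allow_hit=True, store_hit=True),
+            )
+            apply(patch)
+            assert math.negate(5) == -5
+            revert(patch)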
+ """ + self.destination = destination + self.name = name + self.obj = obj + self.source_name = source_name or name + if source: + if isinstance(source, ModuleType): + self.source = source.__name__ + if not hasattr(source, self.source_name): + raise AttributeNotFoundError( + f"Cannot find attribute {self.source_name} in module {self.source}" + ) + else: + self.source = source + if not importlib.util.find_spec(self.source): + raise PackageNotFoundError(f"Cannot find module {self.source}") + if not hasattr(importlib.import_module(self.source), self.source_name): + raise AttributeNotFoundError( + f"Cannot find attribute {self.source_name} in module {self.source}" + ) + + if self.source == destination.__name__: # don't patch the same module + raise SameSourceAndDestinationError( + f"Cannot patch {self.source_name} in its own module {self.source}" + ) + + self.settings = settings + self.is_inplace_patch = None + + def __repr__(self): + return "{}(destination={!r}, name={!r}, source={!r}, settings={!r})".format( + type(self).__name__, + self.destination.__name__, + self.name, + self.source, + self.settings, + ) + + def __eq__(self, other): + if isinstance(other, type(self)): + return self.__dict__ == other.__dict__ + + return NotImplemented + + def __ne__(self, other): + is_equal = self.__eq__(other) + return is_equal if is_equal is NotImplemented else not is_equal + + def __hash__(self): # pylint: disable=useless-super-delegation + return super().__hash__() + + def _update(self, **kwargs): + """Update some attributes. + + If a 'settings' attribute is passed as a dict, then it will update the + content of the settings, if any, instead of completely overwriting it. + + Parameters + ---------- + kwargs + Attributes to update. + + Raises + ------ + ValueError + The setting doesn't exist. + """ + for key, value in _iteritems(kwargs): + if key == "settings": + if isinstance(value, dict): + if self.settings is None: + self.settings = Settings(**value) + else: + self.settings._update(**value) + else: + self.settings = copy.deepcopy(value) + else: + setattr(self, key, value) + + def to_dict(self, pickle=False): + if self.source is None and not pickle: + raise ValueError( + "Cannot serialize Patch object without `obj` source. Please specify source when creating Patch object or set pickle=True to pickle the object." 
+ ) + destination_name = getattr(self.destination, "__name__", None) + destination_module_name = getattr( + inspect.getmodule(self.destination), "__name__", None + ) + destination_value = None + if not pickle: + if not type(self.destination).__name__ == "module": + destination_value = destination_name + elif pickle and not type(self.destination).__name__ == "module": + destination_value = base64.b64encode(dill.dumps(self.destination)).decode( + "utf-8" + ) + + destination = { + "type": type(self.destination).__name__ if not pickle else "pickle", + "path": destination_module_name, + "value": destination_value, + "version": find_package_version(destination_module_name), + } + + obj = { + "type": type(self.obj).__name__ if not pickle else "pickle", + "path": self.source, + "value": self.source_name + if not pickle + else base64.b64encode(dill.dumps(self.obj)).decode("utf-8"), + "version": find_package_version(self.source), + } + + return { + "destination": destination, + "obj": obj, + "name": self.name, + "source": self.source, + "source_name": self.source_name, + "settings": self.settings.__dict__ if self.settings else None, + } + + @classmethod + def from_dict(cls, json): + def reconstruct_object(obj_details): + if obj_details["type"] == "module": + # The object is a module + obj = importlib.import_module(obj_details["path"]) + elif obj_details["type"] == "pickle": + # The object is a pickled object + obj = dill.loads(base64.b64decode(obj_details["value"].encode("utf-8"))) + return obj + elif obj_details["value"]: + # Object within a module + module = importlib.import_module(obj_details["path"]) + obj = getattr(module, obj_details["value"]) + else: + raise ValueError(f"Path must be specified for `{obj_details['type']}`") + + # Check version + current_version = find_package_version(obj_details["path"]) + if obj_details["version"] and current_version != obj_details["version"]: + logger.debug( + f"Version mismatch for {obj_details['path']}: " + f"Expected {obj_details['version']}, got {current_version}" + ) + raise ValueError( + f"Version mismatch for {obj_details['path']}: " + f"Expected {obj_details['version']}, got {current_version}" + ) + + return obj + + destination = reconstruct_object(json["destination"]) + obj = reconstruct_object(json["obj"]) + settings = Settings(**json["settings"]) if json["settings"] else None + + return cls( + destination=destination, + name=json["name"], + obj=obj, + source=json["source"], + source_name=json["source_name"], + settings=settings, + ) + + def to_json(self, filename=None): + if filename: + with open(filename, "w") as f: + return json.dump(self.to_dict(), f) + else: + return json.dumps(self.to_dict()) + + def to_json_file(self, filename): + with open(filename, "w") as f: + return json.dump(self.to_dict(), f) + + @classmethod + def from_json(cls, json_str): + return cls.from_dict(json.loads(json_str)) + + @classmethod + def from_json_file(cls, filename): + with open(filename, "r") as f: + return cls.from_json(f.read()) + + +def apply(patch): + """Apply a patch. + + The patch's :attr:`~Patch.obj` attribute is injected into the patch's + :attr:`~Patch.destination` under the patch's :attr:`~Patch.name`. + + This is a wrapper around calling + `setattr(patch.destination, patch.name, patch.obj)`. + + Parameters + ---------- + patch : gorilla.Patch + Patch. + + Raises + ------ + RuntimeError + Overwriting an existing attribute is not allowed when the setting + :attr:`Settings.allow_hit` is set to `True`. 
+ + Note + ---- + If both the attributes :attr:`Settings.allow_hit` and + :attr:`Settings.store_hit` are `True` but that the target attribute seems + to have already been stored, then it won't be stored again to avoid losing + the original attribute that was stored the first time around. + """ + # is_inplace_patch = True represents the patch object will overwrite the original + # attribute + patch.is_inplace_patch = patch.name in patch.destination.__dict__ + settings = Settings() if patch.settings is None else patch.settings + + curr_active_patch = _ACTIVE_PATCH % (patch.name,) + if curr_active_patch in patch.destination.__dict__: + logger.debug( + f"Patch {patch.name} on {destination.__name__} already existed. Overwrite old patch." + ) + + # When a hit occurs due to an attribute at the destination already existing + # with the patch's name, the existing attribute is referred to as 'target'. + try: + target = get_original_attribute( + patch.destination, patch.name, bypass_descriptor_protocol=True + ) + except AttributeError: + pass + else: + if not settings.allow_hit: + raise RuntimeError( + "An attribute named '%s' already exists at the destination " # noqa: UP031 + "'%s'. Set a different name through the patch object to avoid " + "a name clash or set the setting 'allow_hit' to True to " + "overwrite the attribute. In the latter case, it is " + "recommended to also set the 'store_hit' setting to True in " + "order to store the original attribute under a different " + "name so it can still be accessed." + % (patch.name, patch.destination.__name__) + ) + + if settings.store_hit: + original_name = _ORIGINAL_NAME % (patch.name,) + setattr(patch.destination, original_name, target) + + setattr(patch.destination, patch.name, patch.obj) + setattr(patch.destination, curr_active_patch, patch) + + +def revert(patch): + """Revert a patch. + Parameters + ---------- + patch : gorilla.Patch + Patch. + Note + ---- + This is only possible if the attribute :attr:`Settings.store_hit` was set + to `True` when applying the patch and overriding an existing attribute. + + Notice: + This method is taken from + https://github.com/christophercrouzet/gorilla/blob/v0.4.0/gorilla.py#L318-L351 + with modifictions for autologging disablement purposes. + """ + # If an curr_active_patch has not been set on destination class for the current patch, + # then the patch has not been applied and we do not need to revert anything. + curr_active_patch = _ACTIVE_PATCH % (patch.name,) + if curr_active_patch not in patch.destination.__dict__: + # already reverted. + return + + original_name = _ORIGINAL_NAME % (patch.name,) + + if patch.is_inplace_patch: + # check whether original_name is in destination. We cannot use hasattr because it will + # try to get attribute from parent classes if attribute not found in destination class. + if original_name not in patch.destination.__dict__: + raise RuntimeError( + "Cannot revert the attribute named '%s' since the setting " # noqa: UP031 + "'store_hit' was not set to True when applying the patch." 
+ % (patch.destination.__name__,) + ) + # restore original method + # during reverting patch, we need restore the raw attribute to the patch point + # so get original attribute bypassing descriptor protocal + original = object.__getattribute__(patch.destination, original_name) + setattr(patch.destination, patch.name, original) + else: + # delete patched method + delattr(patch.destination, patch.name) + + if original_name in patch.destination.__dict__: + delattr(patch.destination, original_name) + delattr(patch.destination, curr_active_patch) + + +def patch(destination, name=None, settings=None): + """Decorator to create a patch. + + The object being decorated becomes the :attr:`~Patch.obj` attribute of the + patch. + + Parameters + ---------- + destination : object + Patch destination. + name : str + Name of the attribute at the destination. + settings : gorilla.Settings + Settings. + + Returns + ------- + object + The decorated object. + + See Also + -------- + :class:`Patch`. + """ + + def decorator(wrapped): + base = _get_base(wrapped) + name_ = base.__name__ if name is None else name + settings_ = copy.deepcopy(settings) + patch = Patch(destination, name_, wrapped, settings=settings_) + data = get_decorator_data(base, set_default=True) + data.patches.append(patch) + return wrapped + + return decorator + + +def destination(value): + """Modifier decorator to update a patch's destination. + + This only modifies the behaviour of the :func:`create_patches` function + and the :func:`patches` decorator, given that their parameter + `use_decorators` is set to `True`. + + Parameters + ---------- + value : object + Patch destination. + + Returns + ------- + object + The decorated object. + """ + + def decorator(wrapped): + data = get_decorator_data(_get_base(wrapped), set_default=True) + data.override["destination"] = value + return wrapped + + return decorator + + +def name(value): + """Modifier decorator to update a patch's name. + + This only modifies the behaviour of the :func:`create_patches` function + and the :func:`patches` decorator, given that their parameter + `use_decorators` is set to `True`. + + Parameters + ---------- + value : object + Patch name. + + Returns + ------- + object + The decorated object. + """ + + def decorator(wrapped): + data = get_decorator_data(_get_base(wrapped), set_default=True) + data.override["name"] = value + return wrapped + + return decorator + + +def settings(**kwargs): + """Modifier decorator to update a patch's settings. + + This only modifies the behaviour of the :func:`create_patches` function + and the :func:`patches` decorator, given that their parameter + `use_decorators` is set to `True`. + + Parameters + ---------- + kwargs + Settings to update. See :class:`Settings` for the list. + + Returns + ------- + object + The decorated object. + """ + + def decorator(wrapped): + data = get_decorator_data(_get_base(wrapped), set_default=True) + data.override.setdefault("settings", {}).update(kwargs) + return wrapped + + return decorator + + +def filter(value): # pylint: disable=redefined-builtin + """Modifier decorator to force the inclusion or exclusion of an attribute. + + This only modifies the behaviour of the :func:`create_patches` function + and the :func:`patches` decorator, given that their parameter + `use_decorators` is set to `True`. + + Parameters + ---------- + value : bool + `True` to force inclusion, `False` to force exclusion, and `None` + to inherit from the behaviour defined by :func:`create_patches` or + :func:`patches`. 
+ + Returns + ------- + object + The decorated object. + """ + + def decorator(wrapped): + data = get_decorator_data(_get_base(wrapped), set_default=True) + data.filter = value + return wrapped + + return decorator + + +def find_patches(modules, recursive=True): + """Find all the patches created through decorators. + + Parameters + ---------- + modules : list of module + Modules and/or packages to search the patches in. + recursive : bool + `True` to search recursively in subpackages. + + Returns + ------- + list of gorilla.Patch + Patches found. + + Raises + ------ + TypeError + The input is not a valid package or module. + + See Also + -------- + :func:`patch`, :func:`patches`. + """ + out = [] + modules = ( + module + for package in modules + for module in _module_iterator(package, recursive=recursive) + ) + for module in modules: + members = _get_members(module, filter=None) + for _, value in members: + base = _get_base(value) + decorator_data = get_decorator_data(base) + if decorator_data is None: + continue + + out.extend(decorator_data.patches) + + return out + + +def get_original_attribute(obj, name, bypass_descriptor_protocol=False): + """Retrieve an overridden attribute that has been stored. + + Parameters + ---------- + obj : object + Object to search the attribute in. + name : str + Name of the attribute. + bypass_descriptor_protocol: boolean + bypassing descriptor protocol if true. When storing original method during patching or + restoring original method during reverting patch, we need set bypass_descriptor_protocol + to be True to ensure get the raw attribute object. + + Returns + ------- + object + The attribute found. + + Raises + ------ + AttributeError + The attribute couldn't be found. + + Note + ---- + if setting store_hit=False, then after patch applied, this methods may return patched + attribute instead of original attribute in specific cases. + + See Also + -------- + :attr:`Settings.allow_hit`. + """ + + original_name = _ORIGINAL_NAME % (name,) + curr_active_patch = _ACTIVE_PATCH % (name,) + + def _get_attr(obj_, name_): + if bypass_descriptor_protocol: + return object.__getattribute__(obj_, name_) + else: + return getattr(obj_, name_) + + no_original_stored_err = ( + "Original attribute %s was not stored when patching, set " + "store_hit=True will address this." + ) + + if inspect.isclass(obj): + # Search from children classes to parent classes, and check "original_name" attribute + # first. This will ensure get the correct original attribute in any cases, e.g., + # the case some classes in the hierarchy haven't been patched, but some others are + # patched, this case the previous code is risky to get wrong original attribute. + for obj_ in inspect.getmro(obj): + if original_name in obj_.__dict__: + return _get_attr(obj_, original_name) + elif name in obj_.__dict__: + if curr_active_patch in obj_.__dict__: + patch = getattr(obj, curr_active_patch) + if patch.is_inplace_patch: + raise RuntimeError( + no_original_stored_err % (f"{obj_.__name__}.{name}",) + ) + else: + # non inplace patch, we can get original methods in parent classes. 
+ # so go on checking parent classes + continue + return _get_attr(obj_, name) + else: + # go on checking parent classes + continue + raise AttributeError(f"'{type(obj)}' object has no attribute '{name}'") + else: + try: + return _get_attr(obj, original_name) + except AttributeError: + if curr_active_patch in obj.__dict__: + raise RuntimeError( + no_original_stored_err % (f"{type(obj).__name__}.{name}",) + ) + return _get_attr(obj, name) + + +def get_decorator_data(obj, set_default=False): + """Retrieve any decorator data from an object. + + Parameters + ---------- + obj : object + Object. + set_default : bool + If no data is found, a default one is set on the object and returned, + otherwise `None` is returned. + + Returns + ------- + gorilla.DecoratorData + The decorator data or `None`. + """ + if inspect.isclass(obj): + datas = getattr(obj, _DECORATOR_DATA, {}) + data = datas.setdefault(obj, None) + if data is None and set_default: + data = DecoratorData() + datas[obj] = data + setattr(obj, _DECORATOR_DATA, datas) + else: + data = getattr(obj, _DECORATOR_DATA, None) + if data is None and set_default: + data = DecoratorData() + setattr(obj, _DECORATOR_DATA, data) + + return data + + +def _get_base(obj): + """Unwrap decorators to retrieve the base object. + + Parameters + ---------- + obj : object + Object. + + Returns + ------- + object + The base object found or the input object otherwise. + """ + if hasattr(obj, "__func__"): + obj = obj.__func__ + elif isinstance(obj, property): + obj = obj.fget + elif isinstance(obj, (classmethod, staticmethod)): + # Fallback for Python < 2.7 back when no `__func__` attribute + # was defined for those descriptors. + obj = obj.__get__(None, object) + else: + return obj + + return _get_base(obj) + + +def _get_members(obj, traverse_bases=True, filter=default_filter, recursive=True): + """Retrieve the member attributes of a module or a class. + + The descriptor protocol is bypassed. + + Parameters + ---------- + obj : module or class + Object. + traverse_bases : bool + If the object is a class, the base classes are also traversed. + filter : function + Attributes for which the function returns `False` are skipped. The + function needs to define two parameters: `name`, the attribute name, + and `obj`, the attribute value. If `None`, no attribute is skipped. + recursive : bool + `True` to search recursively through subclasses. + + Returns + ------ + list of (name, value) + A list of tuples each containing the name and the value of the + attribute. + """ + if filter is None: + filter = _true + + out = [] + stack = collections.deque((obj,)) + while stack: + obj = stack.popleft() + if traverse_bases and inspect.isclass(obj): + roots = [base for base in inspect.getmro(obj) if base not in (type, object)] + else: + roots = [obj] + + members = [] + seen = set() + for root in roots: + for name, value in _iteritems(getattr(root, "__dict__", {})): + if name not in seen and filter(name, value): + members.append((name, value)) + + seen.add(name) + + members = sorted(members) + for _, value in members: + if recursive and inspect.isclass(value): + stack.append(value) + + out.extend(members) + + return out + + +def _module_iterator(root, recursive=True): + """Iterate over modules. + + Parameters + ---------- + root : module + Root module or package to iterate from. + recursive : bool + `True` to iterate within subpackages. + + Yields + ------ + module + The modules found. 
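+
+    Examples
+    --------
+    An illustrative sketch::
+
+        import biofit.utils
+
+        for module in _module_iterator(biofit.utils, recursive=True):
+            print(module.__name__)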
+ """ + yield root + + stack = collections.deque((root,)) + while stack: + package = stack.popleft() + # The '__path__' attribute of a package might return a list of paths if + # the package is referenced as a namespace. + paths = getattr(package, "__path__", []) + for path in paths: + modules = pkgutil.iter_modules([path]) + for _, name, is_package in modules: + module_name = f"{package.__name__}.{name}" + module = sys.modules.get(module_name, None) + if module is None: + # Import the module through the finder to support package + # namespaces. + try: + module = _load_module(module_name) + except ImportError: + # Missing modules means that they are optional. Therefore, skip them. + continue + if is_package: + if recursive: + stack.append(module) + yield module + else: + yield module + + +def find_package_version(module_name: str): + if not module_name: + return None + mods = module_name.split(".") + for i in range(len(mods), 0, -1): + m = ".".join(mods[:i]) + mod = importlib.import_module(m) + if hasattr(mod, "__version__"): + return mod.__version__ + return None + + +def _true(*args, **kwargs): # pylint: disable=unused-argument + """Return `True`.""" + return True diff --git a/src/biofit/utils/logging.py b/src/biofit/utils/logging.py new file mode 100644 index 0000000..2cf8c55 --- /dev/null +++ b/src/biofit/utils/logging.py @@ -0,0 +1,447 @@ +# coding=utf-8 +# Copyright 2020 Optuna, Hugging Face +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Logging utilities.""" + +import functools +import logging +import os +import sys +import threading +import warnings +from logging import ( + CRITICAL, # NOQA + DEBUG, # NOQA + ERROR, # NOQA + FATAL, # NOQA + INFO, # NOQA + NOTSET, # NOQA + WARN, # NOQA + WARNING, # NOQA +) +from logging import captureWarnings as _captureWarnings +from typing import Optional + +from biocore.utils.import_util import is_datasets_available +from tqdm import auto as tqdm_lib + +_lock = threading.Lock() +_default_handler: Optional[logging.Handler] = None + +log_levels = { + "detail": logging.DEBUG, # will also print filename and line number + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +_default_log_level = logging.WARNING + +_tqdm_active = True + + +def _get_default_logging_level(): + """ + If BIOFIT_VERBOSITY env var is set to one of the valid choices return that as the new default level. 
If it is + not - fall back to `_default_log_level` + """ + env_level_str = os.getenv("BIOFIT_VERBOSITY", None) + if env_level_str: + if env_level_str in log_levels: + return log_levels[env_level_str] + else: + logging.getLogger().warning( + f"Unknown option BIOFIT_VERBOSITY={env_level_str}, " + f"has to be one of: {', '.join(log_levels.keys())}" + ) + + return _default_log_level + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. + # set defaults based on https://github.com/pyinstaller/pyinstaller/issues/7334#issuecomment-1357447176 + if sys.stderr is None: + sys.stderr = open(os.devnull, "w") + + _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_get_default_logging_level()) + # if logging level is debug, we add pathname and lineno to formatter for easy debugging + if os.getenv("BIOFIT_VERBOSITY", None) == "detail": + formatter = logging.Formatter( + "[%(levelname)s|%(pathname)s:%(lineno)s] %(asctime)s >> %(message)s" + ) + _default_handler.setFormatter(formatter) + + library_root_logger.propagate = False + + +def _reset_library_root_logger() -> None: + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def get_log_levels_dict(): + return log_levels + + +def captureWarnings(capture): + """ + Calls the `captureWarnings` method from the logging library to enable management of the warnings emitted by the + `warnings` library. + + Read more about this method here: + https://docs.python.org/3/library/logging.html#integration-with-the-warnings-module + + All warnings will be logged through the `py.warnings` logger. + + Careful: this method also adds a handler to this logger if it does not already have one, and updates the logging + level of that logger to the library's root logger. + """ + logger = get_logger("py.warnings") + + if not logger.handlers: + logger.addHandler(_default_handler) + + logger.setLevel(_get_library_root_logger().level) + + _captureWarnings(capture) + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Return a logger with the specified name. + + This function is not supposed to be directly accessed unless you are writing a custom biofit module. + """ + + if name is None: + name = _get_library_name() + + _configure_library_root_logger() + return logging.getLogger(name) + + +def get_verbosity() -> int: + """ + Return the current level for the biofit's root logger as an int. + + Returns: + `int`: The logging level. 
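Verbosity is resolved once, when the library root logger is first configured, from the `BIOFIT_VERBOSITY` environment variable; `"detail"` maps to `DEBUG` and additionally switches the formatter to include path and line number. A minimal usage sketch, assuming the module is importable as `biofit.utils.logging`:

```python
import os

# The env var is read when the root logger is first configured, so set it
# before importing / first using the library.
os.environ["BIOFIT_VERBOSITY"] = "info"

from biofit.utils import logging  # assumed import path

logger = logging.get_logger("biofit.demo")  # hypothetical child logger name
logger.info("visible at INFO")
logger.debug("suppressed unless BIOFIT_VERBOSITY is 'debug' or 'detail'")
```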
+ + + + biofit has following logging levels: + + - 50: `biofit.logging.CRITICAL` or `biofit.logging.FATAL` + - 40: `biofit.logging.ERROR` + - 30: `biofit.logging.WARNING` or `biofit.logging.WARN` + - 20: `biofit.logging.INFO` + - 10: `biofit.logging.DEBUG` + + """ + + _configure_library_root_logger() + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """ + Set the verbosity level for the 🤗 biofit's root logger. + + Args: + verbosity (`int`): + Logging level, e.g., one of: + + - `biofit.logging.CRITICAL` or `biofit.logging.FATAL` + - `biofit.logging.ERROR` + - `biofit.logging.WARNING` or `biofit.logging.WARN` + - `biofit.logging.INFO` + - `biofit.logging.DEBUG` + """ + + _configure_library_root_logger() + _get_library_root_logger().setLevel(verbosity) + + +def set_verbosity_info(): + """Set the verbosity to the `INFO` level.""" + return set_verbosity(INFO) + + +def set_verbosity_warning(): + """Set the verbosity to the `WARNING` level.""" + return set_verbosity(WARNING) + + +def set_verbosity_debug(): + """Set the verbosity to the `DEBUG` level.""" + return set_verbosity(DEBUG) + + +def set_verbosity_error(): + """Set the verbosity to the `ERROR` level.""" + return set_verbosity(ERROR) + + +def disable_default_handler() -> None: + """Disable the default handler of the biofit's root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().removeHandler(_default_handler) + + +def enable_default_handler() -> None: + """Enable the default handler of the biofit's root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().addHandler(_default_handler) + + +def set_default_handler(handler: logging.Handler) -> None: + """Set the default handler of the biofit's root logger.""" + + _reset_library_root_logger() + + assert handler is not None + _default_handler = handler + + +def add_handler(handler: logging.Handler) -> None: + """adds a handler to the biofit's root logger.""" + + _configure_library_root_logger() + + assert handler is not None + _get_library_root_logger().addHandler(handler) + + +def set_formatter(formatter: logging.Formatter) -> None: + """adds a formatter to the biofit's default handler""" + global _default_handler + + _configure_library_root_logger() + + assert formatter is not None + _default_handler.setFormatter(formatter) + + +def remove_handler(handler: logging.Handler) -> None: + """removes given handler from the biofit's root logger.""" + + _configure_library_root_logger() + + assert handler is not None and handler not in _get_library_root_logger().handlers + # log optuna study to + _get_library_root_logger().removeHandler(handler) + + +def disable_propagation() -> None: + """ + Disable propagation of the library log outputs. Note that log propagation is disabled by default. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """ + Enable propagation of the library log outputs. Please disable the biofit's default handler to + prevent double logging if the root logger has been configured. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = True + + +def enable_explicit_format() -> None: + """ + Enable explicit formatting for every biofit's logger. 
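The helpers above expose the usual knobs: `set_verbosity_*` for the level, `add_handler`/`remove_handler` and `set_formatter` for output routing. Two caveats are visible in the code as written: `set_default_handler` assigns `_default_handler` without a `global` declaration, so the module-level default is never actually replaced, and `remove_handler` asserts that the handler is not among the root logger's handlers, which trips precisely when removing a handler that was previously added (a membership check, `in`, is presumably what was intended). A hedged usage sketch of the working pieces, with the same assumed import path:

```python
import logging as py_logging

from biofit.utils import logging  # assumed import path

logging.set_verbosity_info()  # equivalent to logging.set_verbosity(logging.INFO)

# Route library output to a file alongside the default stderr handler.
file_handler = py_logging.FileHandler("biofit.log")
file_handler.setFormatter(
    py_logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s")
)
logging.add_handler(file_handler)

logging.get_logger("biofit.demo").info("written to stderr and biofit.log")
```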
The explicit formatter is as follows: + ``` + [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE + ``` + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + formatter = logging.Formatter( + "[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s" + ) + handler.setFormatter(formatter) + + +def reset_format() -> None: + """ + Resets the formatting for biofit's loggers. + + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + handler.setFormatter(None) + + +def warning_advice(self, *args, **kwargs): + """ + This method is identical to `logger.warning()`, but if env var BIOFIT_NO_ADVISORY_WARNINGS=1 is set, this + warning will not be printed + """ + no_advisory_warnings = os.getenv("BIOFIT_NO_ADVISORY_WARNINGS", False) + if no_advisory_warnings: + return + self.warning(*args, **kwargs) + + +logging.Logger.warning_advice = warning_advice + + +@functools.lru_cache(None) +def warning_once(self, *args, **kwargs): + """ + This method is identical to `logger.warning()`, but will emit the warning with the same message only once + + Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache. + The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to + another type of cache that includes the caller frame information in the hashing function. + """ + self.warning(*args, **kwargs) + + +logging.Logger.warning_once = warning_once + + +class EmptyTqdm: + """Dummy tqdm which doesn't do anything.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + self._iterator = args[0] if args else None + + def __iter__(self): + return iter(self._iterator) + + def __getattr__(self, _): + """Return empty function.""" + + def empty_fn(*args, **kwargs): # pylint: disable=unused-argument + return + + return empty_fn + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + return + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +tqdm = _tqdm_cls() + + +def is_progress_bar_enabled() -> bool: + """Return a boolean indicating whether tqdm progress bars are enabled.""" + global _tqdm_active + return bool(_tqdm_active) + + +def enable_progress_bar(): + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = True + + +def disable_progress_bar(): + """Disable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = False + + +def silence(): + warnings.filterwarnings("ignore") + verb = get_verbosity() + is_pb_enabled = is_progress_bar_enabled() + set_verbosity(0) + disable_progress_bar() + dataset_verb = None + is_ds_pb_enabled = None + if is_datasets_available(): + import datasets + + dataset_verb = datasets.logging.get_verbosity() + is_ds_pb_enabled = datasets.is_progress_bar_enabled() + datasets.disable_progress_bars() + return verb, dataset_verb, is_pb_enabled, is_ds_pb_enabled + + +def unsilence(verbosity, dataset_verbosity, is_pb_enabled, is_ds_pb_enabled): + warnings.filterwarnings("default") + 
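`silence()` snapshots the current verbosity and progress-bar state (including the `datasets` library's, when it is installed), mutes everything, and returns the snapshot so that `unsilence(...)`, whose body continues below, can restore it. A minimal sketch of the intended round trip, import path assumed as above:

```python
from biofit.utils import logging  # assumed import path

def run_noisy_step():
    # Stand-in for any chatty biofit call.
    logging.get_logger("biofit.demo").info("routine progress chatter")

state = logging.silence()      # (verbosity, datasets_verbosity, pbar, datasets_pbar)
try:
    run_noisy_step()           # runs with verbosity zeroed, warnings filtered, bars off
finally:
    logging.unsilence(*state)  # restores whatever was captured

# Progress bars can also be toggled directly; the tqdm wrapper above degrades
# to the no-op EmptyTqdm while they are disabled.
logging.disable_progress_bar()
assert not logging.is_progress_bar_enabled()
logging.enable_progress_bar()
```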
set_verbosity(verbosity) + if is_pb_enabled: + enable_progress_bar() + if is_ds_pb_enabled is not None and is_datasets_available(): + import datasets + + datasets.logging.set_verbosity(dataset_verbosity) + datasets.enable_progress_bars() diff --git a/src/biofit/utils/py_util.py b/src/biofit/utils/py_util.py new file mode 100644 index 0000000..37c33ce --- /dev/null +++ b/src/biofit/utils/py_util.py @@ -0,0 +1,209 @@ +# Copyright 2024 Patrick Smyth and Hugging Face authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +import os +import queue +import random +import sys +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from queue import Empty +from typing import Callable, Iterable, Set, TypeVar, Union + +import multiprocess +import multiprocess.pool +import numpy as np +from biocore.utils.import_util import ( + is_datasets_available, + is_polars_available, + is_tf_available, + is_torch_available, +) + +Y = TypeVar("Y") + + +def is_temporal(val): + return isinstance(val, (datetime, date, time, timedelta)) + + +def is_decimal(val): + return isinstance(val, Decimal) + + +def as_py(val): + # return as int64 + if isinstance(val, datetime): + return val.timestamp() + elif isinstance(val, date): + return val.toordinal() + elif isinstance(val, time): + return val.hour * 3600 + val.minute * 60 + val.second + val.microsecond / 1e6 + elif isinstance(val, timedelta): + return val.total_seconds() + elif isinstance(val, Decimal): + return float(val) + elif isinstance(val, np.float16): + return float(val) + elif isinstance(val, bytes): + return f"base64:{val.hex()}" + elif isinstance(val, dict): + return {k: as_py(v) for k, v in val.items()} + return val + + +def enable_full_determinism(seed: int, warn_only: bool = False): + """ + Helper function for reproducible behavior during distributed training. See + - https://pytorch.org/docs/stable/notes/randomness.html for pytorch + - https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism for tensorflow + """ + # set seed first + set_seed(seed) + + if is_torch_available(): + import torch + + # Enable PyTorch deterministic mode. This potentially requires either the environment + # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, + # depending on the CUDA version, so we set them both here + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True, warn_only=warn_only) + + # Enable CUDNN deterministic mode + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + if is_tf_available(): + import tensorflow as tf + + tf.config.experimental.enable_op_determinism() + + +def set_seed(seed: int): + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed). + + Args: + seed (`int`): The seed to set. 
+ """ + random.seed(seed) + np.random.seed(seed) + + if is_polars_available(): + from polars import set_random_seed + + set_random_seed(seed) + + def is_torch_npu_available(): + try: + import torch + + return torch.npu.is_available() + except ImportError: + return False + + if is_torch_available(): + import torch + + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + if is_tf_available(): + import tensorflow as tf + + tf.random.set_seed(seed) + + +# This function is a copy of the one in datasets.utils.py_util.py, Copyright 2024 +# Hugging Face authors. Licensed under the Apache 2.0 license. See the license file for +# details at https://www.apache.org/licenses/LICENSE-2.0 +def _get_pool_pid( + pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool], +) -> Set[int]: + return {f.pid for f in pool._pool} + + +# This function is a copy of the one in datasets.utils.py_util.py, Copyright 2024 +# Hugging Face authors. Licensed under the Apache 2.0 license. See the license file for +# details at https://www.apache.org/licenses/LICENSE-2.0 +def _write_generator_to_queue( + queue: queue.Queue, func: Callable[..., Iterable[Y]], kwargs: dict +) -> int: + for i, result in enumerate(func(**kwargs)): + queue.put(result) + return i + + +# This function is a copy of the one in datasets.utils.py_util.py, Copyright 2024 +# Hugging Face authors. Licensed under the Apache 2.0 license. See the license file for +# details at https://www.apache.org/licenses/LICENSE-2.0 +def iflatmap_unordered( + pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool], + func: Callable[..., Iterable[Y]], + *, + kwargs_iterable: Iterable[dict], +) -> Iterable[Y]: + """ + Similar to `itertools.chain.from_iterable(map(func, iterable))` but with a potentially more efficient implementation + that doesn't require storing all the intermediate results in memory. + + Args: + func (Callable): The function to apply to each element of the input iterable. + iterable (Iterable): The input iterable. + + Returns: + Iterator: The flattened iterable. + """ + if is_datasets_available(): + from datasets.utils.py_utils import iflatmap_unordered + + return iflatmap_unordered(pool, func, kwargs_iterable=kwargs_iterable) + else: + initial_pool_pid = _get_pool_pid(pool) + pool_changed = False + manager_cls = ( + multiprocessing.Manager + if isinstance(pool, multiprocessing.pool.Pool) + else multiprocess.Manager + ) + with manager_cls() as manager: + queue = manager.Queue() + async_results = [ + pool.apply_async(_write_generator_to_queue, (queue, func, kwargs)) + for kwargs in kwargs_iterable + ] + try: + while True: + try: + yield queue.get(timeout=0.05) + except Empty: + if ( + all(async_result.ready() for async_result in async_results) + and queue.empty() + ): + break + if _get_pool_pid(pool) != initial_pool_pid: + pool_changed = True + # One of the subprocesses has died. We should not wait forever. + raise RuntimeError( + "One of the subprocesses has abruptly died during map operation." + "To debug the error, disable multiprocessing." 
+ ) + finally: + if not pool_changed: + # we get the result in case there's an error to raise + [async_result.get(timeout=0.05) for async_result in async_results] diff --git a/src/biofit/utils/recorder.py b/src/biofit/utils/recorder.py new file mode 100644 index 0000000..9bb1aff --- /dev/null +++ b/src/biofit/utils/recorder.py @@ -0,0 +1,156 @@ +from typing import List, Optional, Union, Callable +import sys +import inspect +import importlib +from functools import wraps +from pathlib import Path + + +from .. import config +from .file_utils import is_file_name, PathLike +from . import logging + +logger = logging.get_logger(__name__) + + +def load_module_or_class(full_path): + if full_path is None: + return None + try: + module_path, entity_name = full_path.rsplit(".", 1) + module = ( + importlib.import_module(module_path) + if module_path not in sys.modules + else sys.modules[module_path] + ) + except ImportError: + # Handle the case where the module can't be imported + logger.error(f"Failed to import module: {module_path}") + return None + + try: + entity = ( + getattr(module, entity_name) + if entity_name not in sys.modules + else sys.modules[entity_name] + ) + except AttributeError: + # Handle the case where the entity isn't found in the module + logger.error(f"{entity_name} not found in module: {module_path}") + return None + + return entity + + +def load_method(full_path, method_name): + entity = load_module_or_class(full_path) + if entity: + return getattr(entity, method_name) + return None + + +def _get_cache_dir( + cache_files: Optional[List[dict]] = None, cache_file_name: Optional[PathLike] = None +) -> Optional[Path]: + # check if cache_file_name is a path + cache_dir = None + if cache_file_name is not None: + if isinstance(cache_file_name, PathLike): + cache_file_name = Path(cache_file_name) + if "/" in cache_file_name.as_posix(): + cache_dir = cache_file_name.resolve().parent + + if not cache_dir and cache_files: + cache_dir = Path(cache_files[0]["filename"]).resolve().parent + + return cache_dir + + +def _get_cache_info( + cache_files: Optional[List[dict]] = None, + cache_dir: Optional[Path] = None, + cache_file_name: Optional[Union[str, Path]] = None, + file_ext=".arrow", +): + if cache_file_name: + if is_file_name(cache_file_name): + cache_dir = cache_dir or Path.cwd() + cache_file_name = Path(cache_dir) / Path(cache_file_name).with_suffix( + file_ext + ) + else: + cache_file_name = Path(cache_file_name).with_suffix(file_ext) + + cache_dir = _get_cache_dir(cache_files, cache_file_name) + return cache_file_name, cache_dir + + +UNRECORDED_METHODS = ["train_test_split"] +_RECORDER_ACTIVE = False + + +def pre_recording(func, *args, **kwargs): + new_fingerprint = kwargs.get("new_fingerprint", kwargs.get("fingerprint", None)) + + signature = inspect.signature(func) + cache_file_name = kwargs.get("cache_file_name", None) + if not cache_file_name and "cache_file_name" in signature.parameters: + arg_pos = list(signature.parameters).index("cache_file_name") + if len(args) > arg_pos: + cache_file_name = args[arg_pos] + cache_dir = None + self = args[0] if args else kwargs.get("self", None) + if getattr(self, "cache_files", None): + cache_file_name, cache_dir = _get_cache_info( + self.cache_files, cache_dir, cache_file_name + ) + if isinstance(cache_dir, Path): + cache_dir = cache_dir.resolve().as_posix() + if isinstance(cache_file_name, Path): + cache_file_name = cache_file_name.resolve().as_posix() + + return { + "cache_file_name": cache_file_name, + "cache_dir": cache_dir, + 
"new_fingerprint": new_fingerprint, + } + + +def toggle_recorder(): + global _RECORDER_ACTIVE + _RECORDER_ACTIVE = not _RECORDER_ACTIVE + + +def record_step( + replay_func: Optional[str] = None, + pre_recording: Optional[Callable] = pre_recording, + post_recording: Optional[Callable] = None, +): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if not config.RECORDER_ENABLED or _RECORDER_ACTIVE: + return func(*args, **kwargs) + + toggle_recorder() + extra_info = {} + if pre_recording: + extra_info = pre_recording(func, *args, **kwargs) + out = func(*args, **kwargs) + if post_recording: + out = post_recording( + out, + func.__name__, + args, + kwargs, + replay_func=replay_func, + info=extra_info, + ) + + toggle_recorder() + + return out + + return wrapper + + return decorator diff --git a/src/biofit/utils/table_util.py b/src/biofit/utils/table_util.py new file mode 100644 index 0000000..2af9312 --- /dev/null +++ b/src/biofit/utils/table_util.py @@ -0,0 +1,863 @@ +import copy +import inspect +import os +import re +import tempfile +from collections import defaultdict +from functools import wraps +from pathlib import Path +from typing import List, Union + +import pyarrow as pa +from biocore.utils.import_util import ( + is_datasets_available, + is_rpy2_arrow_available, + is_rpy2_available, + requires_backends, +) + +from biofit import config + +from . import logging +from .file_utils import move_temp_file + +logger = logging.get_logger(__name__) + + +# Function adapted from the Hugging Face Team, Copyright 2024 The Hugging Face Team +# Licensed under the Apache License, Version 2.0. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Modified by: Patrick Smyth +# Date: 2024 +# Summary of changes: +# - Changed names from pyarrow to pa for consistency +def _arrow_to_datasets_dtype(arrow_type: pa.DataType) -> str: + """ + _arrow_to_datasets_dtype takes a pa.DataType and converts it to a datasets string dtype. 
+ In effect, `dt == string_to_arrow(_arrow_to_datasets_dtype(dt))` + """ + if pa.types.is_null(arrow_type): + return "null" + elif pa.types.is_boolean(arrow_type): + return "bool" + elif pa.types.is_int8(arrow_type): + return "int8" + elif pa.types.is_int16(arrow_type): + return "int16" + elif pa.types.is_int32(arrow_type): + return "int32" + elif pa.types.is_int64(arrow_type): + return "int64" + elif pa.types.is_uint8(arrow_type): + return "uint8" + elif pa.types.is_uint16(arrow_type): + return "uint16" + elif pa.types.is_uint32(arrow_type): + return "uint32" + elif pa.types.is_uint64(arrow_type): + return "uint64" + elif pa.types.is_float16(arrow_type): + return "float16" # pa dtype is "halffloat" + elif pa.types.is_float32(arrow_type): + return "float32" # pa dtype is "float" + elif pa.types.is_float64(arrow_type): + return "float64" # pa dtype is "double" + elif pa.types.is_time32(arrow_type): + return f"time32[{pa.type_for_alias(str(arrow_type)).unit}]" + elif pa.types.is_time64(arrow_type): + return f"time64[{pa.type_for_alias(str(arrow_type)).unit}]" + elif pa.types.is_timestamp(arrow_type): + if arrow_type.tz is None: + return f"timestamp[{arrow_type.unit}]" + elif arrow_type.tz: + return f"timestamp[{arrow_type.unit}, tz={arrow_type.tz}]" + else: + raise ValueError(f"Unexpected timestamp object {arrow_type}.") + elif pa.types.is_date32(arrow_type): + return "date32" # pa dtype is "date32[day]" + elif pa.types.is_date64(arrow_type): + return "date64" # pa dtype is "date64[ms]" + elif pa.types.is_duration(arrow_type): + return f"duration[{arrow_type.unit}]" + elif pa.types.is_decimal128(arrow_type): + return f"decimal128({arrow_type.precision}, {arrow_type.scale})" + elif pa.types.is_decimal256(arrow_type): + return f"decimal256({arrow_type.precision}, {arrow_type.scale})" + elif pa.types.is_binary(arrow_type): + return "binary" + elif pa.types.is_large_binary(arrow_type): + return "large_binary" + elif pa.types.is_string(arrow_type): + return "string" + elif pa.types.is_large_string(arrow_type): + return "large_string" + elif pa.types.is_dictionary(arrow_type): + return _arrow_to_datasets_dtype(arrow_type.value_type) + else: + raise ValueError( + f"Arrow type {arrow_type} does not have a datasets dtype equivalent." + ) + + +# Function adapted from the Hugging Face Team, Copyright 2024 The Hugging Face Team +# Licensed under the Apache License, Version 2.0. You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +def string_to_arrow(datasets_dtype: str) -> pa.DataType: + """ + string_to_arrow takes a datasets string dtype and converts it to a pyarrow.DataType. + + In effect, `dt == string_to_arrow(_arrow_to_datasets_dtype(dt))` + + This is necessary because the datasets.Value() primitive type is constructed using a string dtype + + Value(dtype=str) + + But Features.type (via `get_nested_type()` expects to resolve Features into a pyarrow Schema, + which means that each Value() must be able to resolve into a corresponding pyarrow.DataType, which is the + purpose of this function. + """ + + def _dtype_error_msg(dtype, pa_dtype, examples=None, urls=None): + msg = f"{dtype} is not a validly formatted string representation of the pyarrow {pa_dtype} type." + if examples: + examples = ( + ", ".join(examples[:-1]) + " or " + examples[-1] + if len(examples) > 1 + else examples[0] + ) + msg += f"\nValid examples include: {examples}." 
+ if urls: + urls = ( + ", ".join(urls[:-1]) + " and " + urls[-1] if len(urls) > 1 else urls[0] + ) + msg += f"\nFor more insformation, see: {urls}." + return msg + + if datasets_dtype in pa.__dict__: + return pa.__dict__[datasets_dtype]() + + if (datasets_dtype + "_") in pa.__dict__: + return pa.__dict__[datasets_dtype + "_"]() + + timestamp_matches = re.search(r"^timestamp\[(.*)\]$", datasets_dtype) + if timestamp_matches: + timestamp_internals = timestamp_matches.group(1) + internals_matches = re.search( + r"^(s|ms|us|ns),\s*tz=([a-zA-Z0-9/_+\-:]*)$", timestamp_internals + ) + if timestamp_internals in ["s", "ms", "us", "ns"]: + return pa.timestamp(timestamp_internals) + elif internals_matches: + return pa.timestamp(internals_matches.group(1), internals_matches.group(2)) + else: + raise ValueError( + _dtype_error_msg( + datasets_dtype, + "timestamp", + examples=["timestamp[us]", "timestamp[us, tz=America/New_York"], + urls=[ + "https://arrow.apache.org/docs/python/generated/pyarrow.timestamp.html" + ], + ) + ) + + duration_matches = re.search(r"^duration\[(.*)\]$", datasets_dtype) + if duration_matches: + duration_internals = duration_matches.group(1) + if duration_internals in ["s", "ms", "us", "ns"]: + return pa.duration(duration_internals) + else: + raise ValueError( + _dtype_error_msg( + datasets_dtype, + "duration", + examples=["duration[s]", "duration[us]"], + urls=[ + "https://arrow.apache.org/docs/python/generated/pyarrow.duration.html" + ], + ) + ) + + time_matches = re.search(r"^time(.*)\[(.*)\]$", datasets_dtype) + if time_matches: + time_internals_bits = time_matches.group(1) + if time_internals_bits == "32": + time_internals_unit = time_matches.group(2) + if time_internals_unit in ["s", "ms"]: + return pa.time32(time_internals_unit) + else: + raise ValueError( + f"{time_internals_unit} is not a valid unit for the pyarrow time32 type. Supported units: s (second) and ms (millisecond)." + ) + elif time_internals_bits == "64": + time_internals_unit = time_matches.group(2) + if time_internals_unit in ["us", "ns"]: + return pa.time64(time_internals_unit) + else: + raise ValueError( + f"{time_internals_unit} is not a valid unit for the pyarrow time64 type. Supported units: us (microsecond) and ns (nanosecond)." 
+ ) + else: + raise ValueError( + _dtype_error_msg( + datasets_dtype, + "time", + examples=["time32[s]", "time64[us]"], + urls=[ + "https://arrow.apache.org/docs/python/generated/pyarrow.time32.html", + "https://arrow.apache.org/docs/python/generated/pyarrow.time64.html", + ], + ) + ) + + decimal_matches = re.search(r"^decimal(.*)\((.*)\)$", datasets_dtype) + if decimal_matches: + decimal_internals_bits = decimal_matches.group(1) + if decimal_internals_bits == "128": + decimal_internals_precision_and_scale = re.search( + r"^(\d+),\s*(-?\d+)$", decimal_matches.group(2) + ) + if decimal_internals_precision_and_scale: + precision = decimal_internals_precision_and_scale.group(1) + scale = decimal_internals_precision_and_scale.group(2) + return pa.decimal128(int(precision), int(scale)) + else: + raise ValueError( + _dtype_error_msg( + datasets_dtype, + "decimal128", + examples=["decimal128(10, 2)", "decimal128(4, -2)"], + urls=[ + "https://arrow.apache.org/docs/python/generated/pyarrow.decimal128.html" + ], + ) + ) + elif decimal_internals_bits == "256": + decimal_internals_precision_and_scale = re.search( + r"^(\d+),\s*(-?\d+)$", decimal_matches.group(2) + ) + if decimal_internals_precision_and_scale: + precision = decimal_internals_precision_and_scale.group(1) + scale = decimal_internals_precision_and_scale.group(2) + return pa.decimal256(int(precision), int(scale)) + else: + raise ValueError( + _dtype_error_msg( + datasets_dtype, + "decimal256", + examples=["decimal256(30, 2)", "decimal256(38, -4)"], + urls=[ + "https://arrow.apache.org/docs/python/generated/pyarrow.decimal256.html" + ], + ) + ) + else: + raise ValueError( + _dtype_error_msg( + datasets_dtype, + "decimal", + examples=["decimal128(12, 3)", "decimal256(40, 6)"], + urls=[ + "https://arrow.apache.org/docs/python/generated/pyarrow.decimal128.html", + "https://arrow.apache.org/docs/python/generated/pyarrow.decimal256.html", + ], + ) + ) + + raise ValueError( + f"Neither {datasets_dtype} nor {datasets_dtype + '_'} seems to be a pyarrow data type. 
" + f"Please make sure to use a correct data type, see: " + f"https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions" + ) + + +def is_binary_like(data_type: pa.DataType) -> bool: + return pa.types.is_binary(data_type) or pa.types.is_unicode(data_type) + + +def is_large_binary_like(data_type: pa.DataType) -> bool: + return pa.types.is_large_binary(data_type) or pa.types.is_large_unicode(data_type) + + +def is_fixed_width(data_type: pa.DataType) -> bool: + return ( + pa.types.is_primitive(data_type) + or pa.types.is_dictionary(data_type) + or pa.types.is_large_list(data_type) + ) + + +def upcast_tables(tables: List[pa.Table]): + cols = None + cols_diff_dtypes = defaultdict(set) + for table in tables: + if cols is None: + cols = {col.name: col.type for col in table.schema} + else: + cols.update( + {col.name: col.type for col in table.schema if col.name not in cols} + ) + for col in table.schema: + if cols[col.name] != col.type: + cols_diff_dtypes[col.name].update({col.type, cols[col.name]}) + cols_to_cast = {} + for col, types in cols_diff_dtypes.items(): + types = [_arrow_to_datasets_dtype(t) for t in types] + cols_to_cast[col] = string_to_arrow(determine_upcast(types)) + + new_tables = [] + for table in tables: + cols = table.schema.names + casts = pa.schema( + [ + (col.name, cols_to_cast[col.name]) if col.name in cols_to_cast else col + for col in table.schema + ] + ) + new_tables.append(table.cast(casts)) + return new_tables + + +def determine_upcast(dtype_list): + """ + Determines the upcasted data type from a list of data types based on a predefined hierarchy. + + Args: + dtype_list (list): List of data types to be considered for upcasting. + + Returns: + str: The upcasted data type. + """ + + # Define the hierarchy of dtypes with their priority levels + dtype_hierarchy = { + "null": 0, + "bool": 1, + "int8": 2, + "int16": 3, + "int32": 4, + "int": 5, + "int64": 5, + "uint8": 6, + "uint16": 7, + "uint32": 8, + "uint64": 9, + "float16": 10, + "float": 11, + "float32": 11, + "float64": 12, + "time32[s]": 13, + "time32[ms]": 14, + "time64[us]": 15, + "time64[ns]": 16, + "timestamp[s]": 17, + "timestamp[ms]": 18, + "timestamp[us]": 19, + "timestamp[ns]": 20, + "date32": 21, + "date64": 22, + "duration[s]": 23, + "duration[ms]": 24, + "duration[us]": 25, + "duration[ns]": 26, + "decimal128": 27, + "decimal256": 28, + "binary": 29, + "large_binary": 30, + "string": 31, + "large_string": 32, + } + + highest_priority = 0 + upcast_dtype = "null" # the lowest in hierarchy + + for dtype in dtype_list: + priority = dtype_hierarchy.get(dtype, None) + if not priority: + raise ValueError(f"Invalid dtype found {dtype}") + if priority > highest_priority: + highest_priority = priority + upcast_dtype = dtype + + return upcast_dtype + + +def concat_blocks(pa_tables: List[pa.Table], axis: int = 0, append=True) -> pa.Table: + if axis == 0: + # we set promote=True to fill missing columns with null values + if config.PYARROW_VERSION.major < 14: + return pa.concat_tables(pa_tables, promote=True) + else: + return pa.concat_tables(pa_tables, promote_options="permissive") + elif axis == 1: + for i, table in enumerate(pa_tables): + if i == 0: + pa_table = table + else: + for name, col in zip(table.column_names, table.columns): + if append: + pa_table = pa_table.append_column(name, col) + else: + pa_table = pa_table.add_column(0, name, col) + return pa_table + else: + raise ValueError("'axis' must be either 0 or 1") + + +def _arrow_join( + left_table: pa.Table, + right_table: pa.Table, + keys: 
Union[str, List[str]], + right_keys: Union[str, List[str]] = None, + join_type="left outer", + left_suffix=None, + right_suffix=None, + coalesce_keys=True, + use_threads=True, +) -> pa.Table: + """Extends arrow's join to support joining on struct columns at any nested level. + + Args: + right_table (`Table`): + The table to join to the current one, acting as the right table in the join operation. + + keys (`Union[str, List[str]]`): + The columns from current table that should be used as keys of the join operation left side. + + right_keys (`Union[str, List[str]]`, *optional*): + The columns from the right_table that should be used as keys on the join operation right side. When None use the same key names as the left table. + + join_type (`str`, Defaults to 'left outer'): + The kind of join that should be performed, one of (“left semi”, “right semi”, “left anti”, “right anti”, “inner”, “left outer”, “right outer”, “full outer”) + + left_suffix (`str`, *optional*): + Which suffix to add to left column names. This prevents confusion when the columns in left and right tables have colliding names. + + right_suffix (`str`, *optional*): + Which suffix to add to the right column names. This prevents confusion when the columns in left and right tables have colliding names. + + coalesce_keys (`bool`, Defaults to True): + If the duplicated keys should be omitted from one of the sides in the join result. + + use_threads (`bool`, Defaults to True): + Whether to use multithreading or not. + + + """ + + if isinstance(keys, str): + keys = [keys] + + if right_keys is None: + right_keys = keys + else: + if isinstance(right_keys, str): + right_keys = [right_keys] + + left_cast = [] + right_cast = [] + + def _get_struct_columns_and_prepare_cast( + left_schema: pa.Schema, right_schema: pa.Schema + ) -> dict: + struct_columns = [] + keys = set() + + def process_nested_dict(key, nested_schema): + if key in keys: + return + keys.add(key) + # Check if the value is a list of dictionaries + if pa.types.is_struct(nested_schema): + full_keys = [] + original_keys = [] + for nested_field in nested_schema: + # Recursively process the nested dictionary + original_keys.append(nested_field.name) + full_key = f"{key}.{nested_field.name}" + full_keys.append(full_key) + process_nested_dict(full_key, nested_field.type) + struct_columns.append((key, original_keys, full_keys)) + + for field in left_schema: + if pa.types.is_struct(field.type): + left_cast.append(pa.field(field.name, field.type)) + process_nested_dict(field.name, field.type) + elif ( + pa.types.is_list(field.type) + or pa.types.is_large_list(field.type) + or pa.types.is_fixed_size_list(field.type) + ): + raise pa.ArrowNotImplementedError( + "Joining on lists is not supported. Please load them as a struct before joining. 
For example, change the column with value [1,2,3] to {'a': 1, 'b': 2, 'c': 3}" + ) + elif pa.types.is_null(field.type): + left_cast.append(pa.field(field.name, pa.string())) + elif ( + not pa.types.is_fixed_size_list + and not is_fixed_width(field.type) + and not is_binary_like(field.type) + and not is_large_binary_like(field.type) + ): + left_cast.append(pa.field(field.name, pa.string())) + else: + left_cast.append(pa.field(field.name, field.type)) + + for field in right_schema: + if pa.types.is_struct(field.type): + right_cast.append(pa.field(field.name, field.type)) + process_nested_dict(field.name, field.type) + elif ( + pa.types.is_list(field.type) + or pa.types.is_large_list(field.type) + or pa.types.is_fixed_size_list(field.type) + ): + raise pa.ArrowNotImplementedError( + "Joining on lists is not supported. Please load them as a struct before joining. For example, change the column with value [1,2,3] to {'a': 1, 'b': 2, 'c': 3}" + ) + elif pa.types.is_null(field.type): + right_cast.append(pa.field(field.name, pa.string())) + elif ( + not pa.types.is_fixed_size_list + and not is_fixed_width(field.type) + and not is_binary_like(field.type) + and not is_large_binary_like(field.type) + ): + right_cast.append(pa.field(field.name, pa.string())) + else: + right_cast.append(pa.field(field.name, field.type)) + + return struct_columns + + def reconstruct_table(joined_table, struct_cols): + """ + Reconstruct struct columns from flattened columns based on original schema. + """ + for column, orig_names, nested_columns in struct_cols: + nested_data = { + sub_col: joined_table[col].combine_chunks() + for col, sub_col in zip(nested_columns, orig_names) + } + reconstructed_nested_col = pa.StructArray.from_arrays( + nested_data.values(), nested_data.keys() + ) + index = joined_table.schema.get_field_index(nested_columns[0]) + joined_table = joined_table.drop(nested_columns).add_column( + index, column, reconstructed_nested_col + ) + return joined_table + + def get_nested_level(schema, level=0): + """ + Get the maximum level of nesting in a struct schema. + """ + max_level = level + for field in schema: + if pa.types.is_struct(field.type): + max_level = max( + max_level, get_nested_level(field.type, level=level + 1) + ) + return max_level + + def flatten_table(table, max_level, current_level=0): + """ + Recursively flatten a table. 
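The flatten/reconstruct machinery above exists because Arrow's native `Table.join` cannot operate on struct columns directly: `_arrow_join` flattens structs into child columns named like `parent.child`, joins the flattened tables, and then rebuilds the structs via `reconstruct_table`. A condensed sketch of that cycle using plain pyarrow, independent of the biofit wrappers:

```python
import pyarrow as pa

left = pa.table({
    "sample": ["s1", "s2"],
    "meta": [{"site": "gut", "day": 1}, {"site": "skin", "day": 3}],
})
right = pa.table({"sample": ["s1", "s2"], "reads": [1200, 800]})

# 1. Flatten the struct column: "meta" becomes "meta.site" and "meta.day".
flat = left.flatten()

# 2. Join on the shared key; plain columns join without issue.
joined = flat.join(right, keys="sample")

# 3. Rebuild the struct from the flattened children, mirroring reconstruct_table.
site = pa.concat_arrays(joined["meta.site"].chunks)
day = pa.concat_arrays(joined["meta.day"].chunks)
meta = pa.StructArray.from_arrays([site, day], ["site", "day"])
joined = joined.drop(["meta.site", "meta.day"]).append_column("meta", meta)
print(joined.schema)
```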
+ """ + if current_level == max_level: + return table + return flatten_table( + table.flatten(), max_level, current_level=current_level + 1 + ) + + left_nested_level = get_nested_level(left_table.schema) + right_nested_level = get_nested_level(right_table.schema) + struct_columns = _get_struct_columns_and_prepare_cast( + left_table.schema, right_table.schema + ) + + colliding_names = list( + (set(right_table.column_names) & set(left_table.column_names)) + - set(keys) + - set(right_keys) + ) + + for left_key, right_key in zip(keys, right_keys): + if left_table[left_key].type != right_table[right_key].type: + index = right_table.schema.get_field_index(right_key) + right_cast[index] = pa.field(right_key, left_table[left_key].type) + + left_table = left_table.cast(pa.schema(left_cast)) if left_cast else left_table + right_table = right_table.cast(pa.schema(right_cast)) if right_cast else right_table + + return reconstruct_table( + flatten_table(left_table, left_nested_level).join( + flatten_table(right_table.drop(colliding_names), right_nested_level), + keys=keys, + right_keys=right_keys, + join_type=join_type, + left_suffix=left_suffix, + right_suffix=right_suffix, + coalesce_keys=coalesce_keys, + use_threads=use_threads, + ), + struct_columns, + ) + + +def init_arrow_buffer_and_writer( + cache_file_name, + fingerprint=None, + features=None, + writer_batch_size=None, + keep_in_memory=False, + disable_nullable=False, +): + if isinstance(cache_file_name, Path): + cache_file_name = cache_file_name.resolve().as_posix() + if is_datasets_available(): + from datasets import arrow_writer + + # Prepare output buffer and batched writer in memory or on file if we update the table + if keep_in_memory or cache_file_name is None: + buf_writer = pa.BufferOutputStream() + tmp_file = None + writer = arrow_writer.ArrowWriter( + features=features, + stream=buf_writer, + writer_batch_size=writer_batch_size, + update_features=False, + fingerprint=fingerprint, + disable_nullable=disable_nullable, + ) + else: + buf_writer = None + tmp_file = tempfile.NamedTemporaryFile( + "wb", dir=os.path.dirname(cache_file_name), delete=False + ) + writer = arrow_writer.ArrowWriter( + features=features, + path=tmp_file.name, + writer_batch_size=writer_batch_size, + update_features=False, + fingerprint=fingerprint, + disable_nullable=disable_nullable, + ) + else: + buf_writer = None + tmp_file = tempfile.NamedTemporaryFile( + "wb", dir=os.path.dirname(cache_file_name), delete=False + ) + if features is not None and isinstance(features, dict): + features = pa.schema([(k, string_to_arrow(v)) for k, v in features.items()]) + if features is not None and not isinstance(features, pa.Schema): + raise ValueError( + "features must be a dict or a pa.Schema when datasets is not installed" + ) + writer = pa.ipc.RecordBatchStreamWriter(tmp_file, schema=features) + + def finalize(): + writer.close() + + writer.finalize = finalize + return buf_writer, writer, tmp_file + + +def write_arrow_table( + table: pa.Table, + cache_file_name: Union[str, Path], + fingerprint=None, + features=None, + writer_batch_size=None, + disable_nullable=False, +): + # requires_backends("write_arrow_table", "datasets") + from datasets import Features + + if features is None: + if isinstance(table, (pa.ChunkedArray, pa.Array)): + data_type = table.type + features = _arrow_to_datasets_dtype(data_type) + features = Features.from_arrow_schema(table.schema) + tmp_file = None + + try: + _, writer, tmp_file = init_arrow_buffer_and_writer( + cache_file_name, + 
fingerprint=fingerprint, + features=features, + writer_batch_size=writer_batch_size, + disable_nullable=disable_nullable, + ) + writer.write_table(table) + writer.finalize() + except: + if tmp_file and os.path.exists(tmp_file.name): + tmp_file.close() + os.remove(tmp_file.name) + raise + return move_temp_file(tmp_file, cache_file_name) + + +def python_to_r(value): + """Converts a Python primitive type into its R equivalent.""" + + if value is None: + return "NULL" + + elif isinstance(value, bool): + return "TRUE" if value else "FALSE" + + elif isinstance(value, (int, float)): + return str(value) + + elif isinstance(value, complex): + return f"complex(real = {value.real}, imaginary = {value.imag})" + + elif isinstance(value, str): + return f'"{value}"' # Add quotes around strings for R + + elif isinstance(value, list) or isinstance(value, tuple): + # Recursively convert elements of the list/tuple to R format + converted_elements = ", ".join(python_to_r(elem) for elem in value) + return f"c({converted_elements})" + + elif isinstance(value, dict): + # Convert dict to named list in R + converted_items = [ + f"{python_to_r(k)} = {python_to_r(v)}" for k, v in value.items() + ] + return f"list({', '.join(converted_items)})" + + elif isinstance(value, set): + # Convert set to unique vector in R + converted_elements = ", ".join(python_to_r(elem) for elem in sorted(value)) + return f"unique(c({converted_elements}))" + + elif isinstance(value, range): + # Convert range to sequence in R + return f"seq({value.start}, {value.stop - 1}, by = {value.step})" + + elif isinstance(value, bytes): + # Handle bytes conversion to raw in R (if needed, otherwise leave as unsupported) + return f"as.raw(c({', '.join(hex(b) for b in value)}))" + + else: + raise TypeError(f"Type {type(value)} is not supported.") + + +def debug_r_script(path): + path = Path(path).expanduser().resolve() + path.mkdir(parents=True, exist_ok=True) + path = path.as_posix() + + def decorator(func): + @wraps(func) + def wrapper(*args, **func_kwargs): + name = func.__name__ + if is_rpy2_arrow_available(): + from rpy2.robjects import ListVector, default_converter + from rpy2_arrow.arrow import converter + + _converter = default_converter + converter + elif is_rpy2_available(): + from rpy2.robjects import ListVector, conversion + + _converter = ( + conversion.get_conversion() + if getattr(conversion, "get_conversion", None) + else conversion.converter + ) + else: + # suggest installing rpy2_arrow if rpy2 is not available + requires_backends(name, "rpy2_arrow") + + import rpy2.rinterface + import rpy2.robjects as ro + from rpy2.rinterface import Sexp + + self = args[0] + if hasattr(self, "plotter"): + method_args = [ + p.name + for p in inspect.signature(self.plotter).parameters.values() + if p != p.VAR_KEYWORD + ] + else: + method_args = [ + p.name + for p in inspect.signature(func).parameters.values() + if p != p.VAR_KEYWORD + ][1:] + + kwargs = copy.deepcopy(func_kwargs) + for i, arg in enumerate(args[1:]): + kwargs[method_args[i]] = arg + + kwargs.update(self.config.get_params()) + + def convert_to_r(arg): + if isinstance(arg, Sexp): + return arg + elif arg is None: + return ro.NULL + elif isinstance(arg, (list, tuple)): + return _converter.py2rpy([convert_to_r(a) for a in arg]) + elif isinstance(arg, dict): + return ListVector(arg) + else: + return _converter.py2rpy(arg) + + debug_script = "options(error = traceback)\n" + debug_script += ( + f'R_SCRIPTS_PATH <- "{config.R_SCRIPTS.resolve().as_posix()}"\n' + ) + debug_script += 
'source(file.path(R_SCRIPTS_PATH, "plotting_utils.R"))\n' + debug_script += 'require("arrow")\n' + tmp_dir = tempfile.gettempdir() + with rpy2.rinterface.local_context() as r_context: + save_vars = [] + for k, v in kwargs.items(): + if isinstance(v, (pa.Array, pa.Table, pa.ChunkedArray)): + cache_file_name = f"{tmp_dir}/{k}.arrow" + write_arrow_table(v, cache_file_name) + debug_script += ( + f"{k} <- as.data.frame(arrow::read_ipc_stream(" + f"'{cache_file_name}', as_data_frame = FALSE))\n" + ) + else: + try: + debug_script += f"{k} <- {python_to_r(v)}\n" + except TypeError: + save_vars.append(k) + r_context[k] = convert_to_r(arg) + if save_vars: + code = ", ".join(r_context.keys()) + code = f"save({code}, file = '{tmp_dir}/data.RData')" + ro.r(code) + debug_script = debug_script + f"load('{tmp_dir}/data.RData')\n" + script = self.r_caller.r_code.split("\n") + first_start = False + second_start = False + for line in script: + if first_start and second_start: + debug_script += f"{line[2:]}\n" + if line.startswith(self.config.main_method): + first_start = True + if first_start and line.endswith("{"): + second_start = True + elif line.startswith("}") and first_start and second_start: + break + with open(f"{path}/debug.R", "w") as f: + f.write(debug_script) + return func(*args, **func_kwargs) + + return wrapper + + return decorator + + +def read_arrow_table( + cache_file_name: Union[str, Path], +): + return pa.ipc.open_stream(cache_file_name).read_all() diff --git a/src/biofit/utils/types.py b/src/biofit/utils/types.py new file mode 100644 index 0000000..8ba7f9b --- /dev/null +++ b/src/biofit/utils/types.py @@ -0,0 +1,27 @@ +class Unset: + """A class to represent an unset value. + + This is used to differentiate between a value that is not set and a value that is set to None. + Optionally, a description can be provided to indicate what the actual default is when Unset is used. + """ + + def __init__(self, description: str = ""): + self.description = description + + def __repr__(self) -> str: + """Return the string representation of the class, including any default description if provided. + + Returns: + The string representation of the class. + """ + if self.description: + return f"Unset (default: {self.description})" + return "Unset" + + def __bool__(self) -> bool: + """Return False when the class is used in a boolean context. + + Returns: + False + """ + return False diff --git a/src/biofit/utils/version.py b/src/biofit/utils/version.py new file mode 100644 index 0000000..fbd2e38 --- /dev/null +++ b/src/biofit/utils/version.py @@ -0,0 +1,115 @@ +# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
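`python_to_r`, defined a little above, renders Python literals as R source text; this is how scalar options are spliced into the generated `debug.R` script. The conversions are mechanical, so a few concrete examples, verifiable against the branches shown (import path assumed):

```python
from biofit.utils.table_util import python_to_r  # assumed import path

python_to_r(None)            # 'NULL'
python_to_r(True)            # 'TRUE'
python_to_r(3.5)             # '3.5'
python_to_r("Set1")          # '"Set1"'
python_to_r([1, 2, 3])       # 'c(1, 2, 3)'
python_to_r({"prop": False, "font_size": 3.25})
#   'list("prop" = FALSE, "font_size" = 3.25)'
python_to_r(range(1, 10, 2)) # 'seq(1, 9, by = 2)'
```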
+#
+# Modified by: Patrick Smyth
+# Date: 2024
+# Summary of changes:
+# - Added __version__ variable
+"""Version utils."""
+
+import dataclasses
+import re
+from dataclasses import dataclass
+from functools import total_ordering
+from typing import Optional, Union
+
+_VERSION_REG = re.compile(r"^(?P<major>\d+)" r"\.(?P<minor>\d+)" r"\.(?P<patch>\d+)$")
+
+
+@total_ordering
+@dataclass
+class Version:
+    """Dataset version `MAJOR.MINOR.PATCH`.
+
+    Args:
+        version_str (`str`):
+            The dataset version.
+        description (`str`):
+            A description of what is new in this version.
+        major (`str`):
+        minor (`str`):
+        patch (`str`):
+
+    Example:
+
+    ```py
+    >>> VERSION = datasets.Version("1.0.0")
+    ```
+    """
+
+    version_str: str
+    description: Optional[str] = None
+    major: Optional[Union[str, int]] = None
+    minor: Optional[Union[str, int]] = None
+    patch: Optional[Union[str, int]] = None
+
+    def __post_init__(self):
+        self.major, self.minor, self.patch = _str_to_version_tuple(self.version_str)
+
+    def __repr__(self):
+        return f"{self.tuple[0]}.{self.tuple[1]}.{self.tuple[2]}"
+
+    @property
+    def tuple(self):
+        return self.major, self.minor, self.patch
+
+    def _validate_operand(self, other):
+        if isinstance(other, str):
+            return Version(other)
+        elif isinstance(other, Version):
+            return other
+        raise TypeError(f"{other} (type {type(other)}) cannot be compared to version.")
+
+    def __eq__(self, other):
+        try:
+            other = self._validate_operand(other)
+        except (TypeError, ValueError):
+            return False
+        else:
+            return self.tuple == other.tuple
+
+    def __lt__(self, other):
+        other = self._validate_operand(other)
+        return self.tuple < other.tuple
+
+    def __hash__(self):
+        return hash(_version_tuple_to_str(self.tuple))
+
+    @classmethod
+    def from_dict(cls, dic):
+        field_names = {f.name for f in dataclasses.fields(cls)}
+        return cls(**{k: v for k, v in dic.items() if k in field_names})
+
+    def _to_yaml_string(self) -> str:
+        return self.version_str
+
+
+def _str_to_version_tuple(version_str):
+    """Return the tuple (major, minor, patch) version extracted from the str."""
+    res = _VERSION_REG.match(version_str)
+    if not res:
+        raise ValueError(
+            f"Invalid version '{version_str}'. Format should be x.y.z with {{x,y,z}} being digits."
+ ) + return tuple( + int(v) for v in [res.group("major"), res.group("minor"), res.group("patch")] + ) + + +def _version_tuple_to_str(version_tuple): + """Return the str version from the version tuple (major, minor, patch).""" + return ".".join(str(v) for v in version_tuple) + + +__version__ = "0.0.0" diff --git a/src/biofit/visualization/__init__.py b/src/biofit/visualization/__init__.py new file mode 100644 index 0000000..3248c7a --- /dev/null +++ b/src/biofit/visualization/__init__.py @@ -0,0 +1,22 @@ +# ruff: noqa +from .feature_importance import ( + FeatureImportancePlotter, + FeatureImportancePlotterConfig, + FeatureImportancePlotterConfigForOTU, +) +from .sample_metadata import SampleMetadataPlotter, SampleMetadataPlotterConfig +from .plotting_utils import ( + generate_violin, + generate_barplot, + generate_scatterplot, + generate_histogram, + generate_comparison_histogram, + plot_correlation, + plot_feature_distribution, + compare_feature_distributions, + plot_sample_distribution, + compare_sample_distributions, + plot_dimension_reduction, + plot_feature_importance, + plot_sample_metadata, +) diff --git a/src/biofit/visualization/barplot.py b/src/biofit/visualization/barplot.py new file mode 100644 index 0000000..0b651c5 --- /dev/null +++ b/src/biofit/visualization/barplot.py @@ -0,0 +1,198 @@ +import textwrap +from dataclasses import dataclass, field +from typing import List, Optional, Type + +from biocore import DataHandler + +import biofit.config as config +from biofit.integration.biosets import get_feature +from biofit.integration.R import RCaller +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.utils.types import Unset +from biofit.visualization.plotting import BasePlotter, PlotterConfig + + +@dataclass +class BarPlotConfig(PlotterConfig): + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES"), None], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES"), None], + init=False, + repr=False, + ) + + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None, None], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None, None], + init=False, + repr=False, + ) + + processor_type: str = field(default="scaling", init=False, repr=False) + r_source: str = field( + default=(config.R_SCRIPTS / "plotting_utils.R").as_posix(), + init=False, + repr=False, + ) + main_method: str = field(default="generate_barplot", init=False, repr=False) + + label_name: str = None + value_name: Optional[str] = None + groupby: Optional[str] = None + xlab: Optional[str] = None + ylab: Optional[str] = None + title: str = "Bar Plot" + col_set: str = "Set1" + col_outline: str = "grey30" + col_labels: str = "black" + cols: Optional[str] = None + prop: bool = False + add_count_lab: bool = True + vars_as_entered: bool = False + legend_position: str = "top" + font_size: float = 3.25 + + +class BarPlotter(BasePlotter): + _config_class = BarPlotConfig + config: BarPlotConfig + + def __init__( + self, + xlab: Optional[str] = None, + ylab: Optional[str] = None, + title: str = Unset('"Bar Plot"'), + col_set: str = Unset('"Set1"'), + col_labels: str = Unset('"black"'), + col_outline: str = Unset('"grey30"'), + cols: Optional[List[str]] = Unset("None"), + prop: bool = 
Unset("False"), + add_count_lab: bool = Unset("True"), + vars_as_entered: bool = Unset("False"), + legend_position: str = Unset('"top"'), + font_size: float = Unset("3.25"), + path=None, + config: Optional[BarPlotConfig] = None, + ): + super().__init__( + config=config, + xlab=xlab, + ylab=ylab, + title=title, + col_set=col_set, + col_labels=col_labels, + cols=cols, + prop=prop, + add_count_lab=add_count_lab, + vars_as_entered=vars_as_entered, + legend_position=legend_position, + font_size=font_size, + path=path, + ) + + self = self.__post_init__() + + def __post_init__(self): + if self.config.r_source and self.config.main_method: + self.r_caller = RCaller.from_script(self.config.r_source) + self.r_caller.verify_r_dependencies() + enter_code = textwrap.dedent( + """ + suppressPackageStartupMessages(require(ggplot2)) + """ + ) + exit_code = textwrap.dedent( + """ + ggplot2::ggsave(path, results) + """ + ) + self.plotter = self.r_caller.get_method( + self.config.main_method, enter_code, exit_code + ) + + return self + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + self = self.__post_init__() + return self + + def plot( + self, + x, + y, + group, + label_name: str = None, + value_name: Optional[str] = None, + groupby: Optional[str] = None, + xlab: Optional[str] = Unset("None"), + ylab: Optional[str] = Unset("None"), + title: str = Unset('"Bar Plot"'), + col_set: str = Unset('"Set1"'), + col_labels: str = Unset('"black"'), + col_outline: str = Unset('"grey30"'), + cols: Optional[List[str]] = Unset("None"), + prop: bool = Unset("False"), + add_count_lab: bool = Unset("True"), + vars_as_entered: bool = Unset("False"), + legend_position: str = Unset('"top"'), + font_size: float = Unset("3.25"), + path: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show: bool = True, + ): + self.config._input_columns = self._set_input_columns_and_arity( + value_name, label_name, groupby + ) + return self._plot( + x, + y, + group, + xlab=xlab, + ylab=ylab, + title=title, + col_set=col_set, + col_labels=col_labels, + col_outline=col_outline, + cols=cols, + prop=prop, + add_count_lab=add_count_lab, + vars_as_entered=vars_as_entered, + legend_position=legend_position, + font_size=font_size, + path=path, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + show=show, + ) + + def plot_dataset(self, x, y, group): + from datasets import Dataset as HfDataset + + if isinstance(x, (Dataset, HfDataset)): + x = decode(x) + if isinstance(group, (Dataset, HfDataset)): + group = decode(group) + return self.plot_arrow( + DataHandler.to_arrow(x), + DataHandler.to_arrow(y) if y is not None else None, + DataHandler.to_arrow(group) if group is not None else None, + ) + + def plot_arrow(self, x, y=None, group=None): + kwargs = self.config.get_params() + self.plotter(x, y, group, **kwargs) diff --git a/src/biofit/visualization/dimension_reduction.py b/src/biofit/visualization/dimension_reduction.py new file mode 100644 index 0000000..4ff8f02 --- /dev/null +++ b/src/biofit/visualization/dimension_reduction.py @@ -0,0 +1,435 @@ +import json +import warnings +from dataclasses import dataclass, field + +import numpy as np +import pandas as pd +import pyarrow as pa +from biocore import DataHandler +from biocore.utils.import_util import requires_backends + +from biofit.integration.biosets import get_feature +from 
biofit.processing import SelectedColumnTypes, SelectedFeatureTypes +from biofit.utils.types import Unset +from biofit.visualization.plotting import BasePlotter, PlotterConfig +from biofit.visualization.plotting_utils import get_distinct_colors + + +@dataclass +class DimensionReducerPlotterConfig(PlotterConfig): + processor_type: str = field(default="feature_extractor", init=False, repr=False) + _fit_input_feature_types: SelectedFeatureTypes = field( + default_factory=lambda: [ + None, + get_feature("TARGET_FEATURE_TYPES"), + get_feature("METADATA_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _transform_input_feature_types: SelectedFeatureTypes = field( + default_factory=lambda: [ + None, + get_feature("TARGET_FEATURE_TYPES"), + get_feature("METADATA_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _fit_unused_feature_types: SelectedFeatureTypes = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None, None], + init=False, + repr=False, + ) + _transform_unused_feature_types: SelectedFeatureTypes = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None, None], + init=False, + repr=False, + ) + + title: str = "Dimension Reduction Plot" + colormap: str = "nipy_spectral" + n_components: int = 3 + label_column: str = None + group_column: str = None + + +class DimensionReductionPlotter(BasePlotter): + """Base class for feature extraction processors.""" + + _config_class = DimensionReducerPlotterConfig + config: DimensionReducerPlotterConfig + + def __init__( + self, + label_column: str = None, + group_column: str = None, + title: str = Unset('"Dimension Reduction Plot"'), + colormap: str = Unset('"nipy_spectral"'), + n_components: int = Unset("3"), + config: DimensionReducerPlotterConfig = None, + **kwargs, + ): + super().__init__( + config=config, + n_components=n_components, + label_column=label_column, + group_column=group_column, + title=title, + colormap=colormap, + **kwargs, + ) + + def plot( + self, + X, + labels=None, + group=None, + input_columns: SelectedColumnTypes = None, + label_column: SelectedColumnTypes = None, + group_column: SelectedColumnTypes = None, + title: str = Unset('"Dimension Reduction Plot"'), + colormap: str = Unset('"nipy_spectral"'), + n_components: int = Unset("3"), + dimension_reducer=None, + path: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show=True, + ): + """Plot the PCoA plot. + + Args: + X (array_like): + The input data. Can be a Dataset, polars/pandas DataFrame, numpy array, or Arrow table. + labels (array_like, *optional*): + The labels for the data. Dictates the color of the points in the plot. + Must be an array of values with the same length as the number of rows in X. + If not provided, the points will be colored by group. + group (array_like, *optional*): + The group for the data. Dictates the shape of the points in the plot. + Must be an array of values with the same length as the number of rows in X. + pcoa (array_like, *optional*): + The fitted PCoAFeatureExtractor object. Used to extract the eigvals for the explained variance ratio. + If not provided, the explained variance ratio will not be displayed. + **kwargs: + Additional keyword arguments to pass to the plot: + n_components (int, 2): + The number of components to plot. Defaults to 2. + label_name (str, *optional*): + The name of the labels. Used for the legend and retrieving the values from X if labels is not provided. 
+ group_name (str, *optional*): + The name of the group. Used for the legend and retrieving the values from X if group is not provided. + + """ + self.config._input_columns = self._set_input_columns_and_arity( + input_columns, label_column, group_column + ) + return self._plot( + X, + labels, + group, + n_components=n_components, + dimension_reducer=dimension_reducer, + path=path, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + show=show, + ) + + def plot_dataset(self, X, labels=None, group=None, dimension_reducer=None): + from biosets import decode + + if labels is not None: + labels = decode(labels) + + return self.plot_arrow( + DataHandler.to_arrow(X), + DataHandler.to_arrow(labels) if labels is not None else None, + DataHandler.to_arrow(group) if group is not None else None, + dimension_reducer, + ) + + def plot_arrow(self, X, labels=None, group=None, dimension_reducer=None): + def get_int2str_converter(x) -> np.ndarray: + if not isinstance(x, pa.Table): + return lambda x: x + + def get_features(x: pa.Table): + metadata = x.schema.metadata + if metadata: + if b"huggingface" in metadata: + metadata = json.loads(metadata[b"huggingface"].decode("utf-8")) + + if "info" in metadata and "features" in metadata["info"]: + return metadata["info"]["features"] + + return {} + + metadata = get_features(x) + lab_col = x.column_names[0] + if metadata and lab_col in metadata: + label_metadata = metadata[lab_col] + if label_metadata["_type"] == "ClassLabel": + label_metadata.pop("_type") + cls_label = get_feature("ClassLabel")(**label_metadata) + return cls_label.int2str + return lambda x: x + + def pairplot( + tbl: pd.DataFrame, + n_components, + ex_var_ratio=None, + label_name=None, + group_name=None, + marker_dict=None, + color_dict=None, + ): + requires_backends(pairplot, "seaborn") + requires_backends(pairplot, "matplotlib") + import matplotlib.pyplot as plt + import seaborn as sns + from matplotlib.lines import Line2D + + features = tbl.columns[:n_components] + fig, axes = plt.subplots(n_components, n_components, figsize=(15, 15)) + + for i, f1 in enumerate(features): + for j, f2 in enumerate(features): + ax = axes[i, j] + if i != j: + if label_name is not None and group_name is not None: + for (lab, grp), df_group in tbl.groupby( + [ + label_name, + group_name, + ] + ): + sns.scatterplot( + x=f2, + y=f1, + data=df_group, + ax=ax, + color=color_dict[lab], + marker=marker_dict[grp], + ) + elif label_name is not None: + for lab, df_lab in tbl.groupby(label_name): + sns.scatterplot( + x=f2, + y=f1, + data=df_lab, + ax=ax, + color=color_dict[lab], + ) + elif group_name is not None: + for grp, df_group in tbl.groupby(group_name): + sns.scatterplot( + x=f2, + y=f1, + data=df_group, + ax=ax, + marker=marker_dict[grp], + ) + else: + sns.scatterplot(x=f2, y=f1, data=tbl, ax=ax) + + else: + if label_name is not None: + for lab, df_lab in tbl.groupby(label_name): + sns.kdeplot( + df_lab[f1], + ax=ax, + color=color_dict[lab], + fill=True, + ) + else: + sns.kdeplot(tbl[f1], ax=ax, fill=True) + if ex_var_ratio is not None: + plt.text( + 0.1, + 0.9, + f"Dim{i + 1}: {float(ex_var_ratio[i]) * 100:.2f}%", + transform=ax.transAxes, + ) + # Set only left and bottom borders + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + # Enable ticks on left and bottom only + ax.xaxis.set_tick_params(which="both", bottom=True) + ax.yaxis.set_tick_params(which="both", left=True) + + # Set axis labels + if i == n_components - 1: + 
ax.set_xlabel(f2) + else: + ax.set_xlabel("") + ax.set_xticklabels([]) + + if j == 0: + ax.set_ylabel(f1) + else: + ax.set_ylabel("") + ax.set_yticklabels([]) + + if label_name: + label_lines = [ + Line2D( + [0], + [0], + color=color, + marker="o", + linestyle="", + markersize=10, + ) + for color in color_dict.values() + ] + u_labels = list(color_dict.keys()) + n_labels = len(u_labels) + label_height = 0.5 + 0.216 / 15 * (n_labels + 1) / 2 + 0.01 + label_legend = fig.legend( + label_lines, + u_labels, + title=label_name, + loc="center left", + bbox_to_anchor=(0.90, label_height), + ) + + if group_name: + group_lines = [ + Line2D( + [0], + [0], + color="gray", + marker=marker, + linestyle="", + markersize=10, + ) + for marker in marker_dict.values() + ] + u_group = list(marker_dict.keys()) + n_group = len(u_group) + + group_height = 0.5 - 0.216 / 15 * (n_group + 1) / 2 - 0.01 + fig.legend( + group_lines, + list(marker_dict.keys()), + title=group_name, + loc="center left", + bbox_to_anchor=(0.90, group_height), + ) + if label_name: + fig.add_artist(label_legend) + + if self.config.title: + fig.suptitle(self.config.title, fontsize=16) + + return fig + + warnings.filterwarnings("ignore") + n_components = self.config.n_components + tbl = DataHandler.to_format(X, "pandas").iloc[:, :n_components] + labs = None + label_name = self.config.label_column or "labels" + u_labels = None + grp = None + grp_name = self.config.group_column or None + u_group = None + if labels is not None: + labs = DataHandler.to_format(labels, "pandas") + if isinstance(labs, pd.DataFrame): + label_name = labs.columns[0] if label_name is None else label_name + labs = labs.iloc[:, 0] + elif isinstance(labs, pd.Series): + label_name = labs.name if label_name is None else label_name + labs = labs.fillna("None") + u_labels = np.unique(labs) + if group is not None: + grp = DataHandler.to_format(group, "pandas") + if isinstance(grp, pd.DataFrame): + grp_name = grp.columns[0] if grp_name is None else grp_name + grp = grp.iloc[:, 0] + elif isinstance(grp, pd.Series): + grp_name = grp.name if grp_name is None else grp_name + u_group = np.unique(grp) + + if labs is not None: + converter = get_int2str_converter(labels) + dtype = DataHandler.get_dtypes(labels) + if "int" in next(iter(dtype.values())): + encodings = labs.to_numpy() + encodings_dims = encodings.shape + if len(encodings_dims) == 1: + labs = converter(encodings) + elif encodings_dims[1] == 1: + labs = converter(encodings[:, 0]) + + if labs is not None: + u_labels = np.unique(labs) + + ex_var_ratio = None + if dimension_reducer is not None: + if hasattr(dimension_reducer.config, "eigvals"): + eigvals = dimension_reducer.config.eigvals + ex_var_ratio = eigvals / eigvals.sum() + elif hasattr(dimension_reducer.config, "estimator"): + ex_var_ratio = ( + dimension_reducer.config.estimator.explained_variance_ratio_ + ) + + tbl.columns = [f"Dim{i + 1}" for i in range(len(tbl.columns))] + + marker_styles = ["o", "v", "^", "<", ">", "s", "p", "*", "H", "X"] + if u_group is not None: + marker_dict = { + label: marker_styles[i % len(marker_styles)] + for i, label in enumerate(u_group) + } + else: + marker_dict = None + + if labs is not None and grp is not None: + tbl = tbl.assign(**{label_name: labs, grp_name: grp}) + elif labs is not None: + tbl = tbl.assign(**{label_name: labs}) + elif grp is not None: + tbl = tbl.assign(**{grp_name: grp}) + + if labs is not None: + color_dict = { + label: color + for label, color in zip( + u_labels, get_distinct_colors(len(u_labels), 
self.config.colormap) + ) + } + elif grp is not None: + color_dict = { + label: color + for label, color in zip( + u_group, get_distinct_colors(len(u_group), self.config.colormap) + ) + } + else: + color_dict = None + + fig = pairplot( + tbl, + n_components, + ex_var_ratio, + label_name, + grp_name, + marker_dict, + color_dict, + ) + + warnings.filterwarnings("default") + + # save the figure to the path + if self.config.path: + fig.savefig(self.config.path) + return fig diff --git a/src/biofit/visualization/feature_importance.py b/src/biofit/visualization/feature_importance.py new file mode 100644 index 0000000..ef3bf78 --- /dev/null +++ b/src/biofit/visualization/feature_importance.py @@ -0,0 +1,355 @@ +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, List, Optional, Type + +import numpy as np +import pandas as pd +from biocore import DataHandler + +from biofit.integration.biosets import get_feature +from biofit.integration.R.r_caller import RCaller +from biofit.processing import SelectedColumnTypes +from biofit.utils.types import Unset +from biofit.visualization.plotting import ( + BasePlotter, + PlotterConfig, +) + +if TYPE_CHECKING: + from biosets import Dataset + + +@dataclass +class FeatureImportancePlotterConfig(PlotterConfig): + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [ + None, + None, + get_feature("TARGET_FEATURE_TYPES"), + get_feature("METADATA_FEATURE_TYPES_NOT_TARGET"), + ], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + None, + None, + get_feature("TARGET_FEATURE_TYPES"), + get_feature("METADATA_FEATURE_TYPES_NOT_TARGET"), + ], + init=False, + repr=False, + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("METADATA_FEATURE_TYPES"), + None, + None, + None, + ], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("METADATA_FEATURE_TYPES"), + None, + None, + None, + ], + init=False, + repr=False, + ) + r_source: str = field( + default=(Path(__file__).parent / "plot_feature_importance.R") + .resolve() + .as_posix(), + init=False, + repr=False, + ) + main_method: str = field(default="plot_feature_importance", init=False, repr=False) + + input_columns: SelectedColumnTypes = None + target_columns: SelectedColumnTypes = None + sample_metadata_columns: SelectedColumnTypes = None + sample_column: str = None + plot_top: int = 15 + feature_meta_name: str = None + feature_column: str = None + cols: List[str] = None + colHeat: List[str] = None + dat_log: str = field(default=None, init=True, repr=False) + show_column_names: bool = False + scale_legend_title: str = "Value" + column_title: str = "Samples" + row_title: str = "Features" + plot_title: str = "Values" + + def __post_init__(self): + if self.dat_log == ["log2_1p", "log2"]: + self.plot_title = "log2\n" + self.plot_title + elif self.dat_log == ["log10_1p", "log10"]: + self.plot_title = "log10\n" + self.plot_title + + +@dataclass +class FeatureImportancePlotterConfigForMetagenomics(FeatureImportancePlotterConfig): + dat_log: str = "log2_1p" + plot_title: str = "Abundance" + row_title: str = "Taxa" + + +@dataclass +class FeatureImportancePlotterConfigForOTU( + FeatureImportancePlotterConfigForMetagenomics +): + row_title: str = "OTUs" + + +class FeatureImportancePlotter(BasePlotter): + _config_class = FeatureImportancePlotterConfig + config: 
FeatureImportancePlotterConfig + + def __init__( + self, + input_columns: SelectedColumnTypes = None, + target_columns: SelectedColumnTypes = None, + sample_metadata_columns: SelectedColumnTypes = None, + sample_column: str = None, + plot_top: int = Unset("15"), + feature_meta_name: str = Unset("None"), + feature_column: str = Unset("None"), + cols: List[str] = Unset("None"), + colHeat: List[str] = Unset("None"), + dat_log: str = Unset("field(default=None, init=True, repr=False)"), + show_column_names: bool = Unset("False"), + scale_legend_title: str = Unset('"Value"'), + column_title: str = Unset('"Samples"'), + row_title: str = Unset('"Features"'), + plot_title: str = Unset('"Values"'), + path: Optional[str] = None, + install_missing: bool = None, + config: Optional[FeatureImportancePlotterConfig] = None, + ): + super().__init__( + plot_top=plot_top, + feature_meta_name=feature_meta_name, + feature_column=feature_column, + cols=cols, + colHeat=colHeat, + dat_log=dat_log, + show_column_names=show_column_names, + scale_legend_title=scale_legend_title, + column_title=column_title, + row_title=row_title, + plot_title=plot_title, + config=config, + install_missing=install_missing, + ) + self.r_caller = RCaller.from_script(self.config.r_source) + self.r_caller.verify_r_dependencies( + bioconductor_dependencies=["ComplexHeatmap"], + install_missing=install_missing, + ) + self.plotter = self.r_caller.get_method(self.config.main_method) + + def plot_dataset( + self, + X: "Dataset", + feature_importances: "Dataset", + y: "Dataset", + sample_metadata: "Dataset" = None, + feature_metadata: dict = None, + ): + from biosets import decode, get_feature_metadata, get_sample_col_name + + if feature_metadata is None: + feature_metadata = get_feature_metadata(X) + + if self.config.sample_column is None: + self.config.sample_column = ( + get_sample_col_name(sample_metadata) + if sample_metadata is not None + else get_sample_col_name(X) + ) + + self.plot_pandas( + X=DataHandler.to_pandas(X), + feature_importances=DataHandler.to_pandas(feature_importances), + y=DataHandler.to_pandas(decode(y)) if y is not None else None, + sample_metadata=DataHandler.to_pandas(sample_metadata) + if sample_metadata is not None + else None, + feature_metadata=feature_metadata, + ) + + def plot_pandas( + self, + X: pd.DataFrame, + feature_importances: pd.DataFrame, + y: pd.DataFrame, + sample_metadata: pd.DataFrame = None, + feature_metadata: dict = None, + ): + if self.config.dat_log == "log2_1p": + X = np.log2(X + 1) + elif self.config.dat_log == "log2": + X = np.log2(X) + elif self.config.dat_log == "log10_1p": + X = np.log10(X + 1) + elif self.config.dat_log == "log10": + X = np.log10(X) + + if feature_importances is None: + raise ValueError("Please provide feature importances.") + + self.config.feature_column = ( + self.config.feature_column or feature_importances.columns[0] + ) + if self.config.feature_column not in feature_importances.columns: + raise ValueError( + f"Feature column '{self.config.feature_column}' not found in feature " + "importances. Please provide the column name found in both " + "feature importances and feature metadata (if provided)." 
+ ) + + feat_import_cols = [ + c for c in feature_importances.columns if c != self.config.feature_column + ] + + if ( + isinstance(feature_metadata, dict) + and feature_metadata is not None + and len(feature_metadata) > 0 + ): + feature_metadata = pd.DataFrame( + list(feature_metadata.values()), index=list(feature_metadata.keys()) + ) + feature_metadata = feature_metadata.reset_index( + names=[self.config.feature_column] + ) + + if len(feat_import_cols) > 1: + medians = feature_importances.loc[:, feat_import_cols].median(axis=1) + sorted_inds = np.argsort(np.abs(medians))[::-1] + feature_importances = feature_importances.iloc[sorted_inds, :].head( + self.config.plot_top + ) + X = DataHandler.to_arrow(X, preserve_index=False) + feature_importances = DataHandler.to_arrow( + feature_importances, preserve_index=False + ) + if y is not None: + y = DataHandler.to_arrow(y, preserve_index=False) + if feature_metadata is not None and len(feature_metadata) > 0: + feature_metadata = DataHandler.to_arrow( + feature_metadata, preserve_index=False + ) + if sample_metadata is not None: + sample_metadata = DataHandler.to_arrow( + sample_metadata, preserve_index=False + ) + + self._plotter( + X, + y=y, + feature_importances=feature_importances, + sample_metadata=sample_metadata, + feature_metadata=feature_metadata, + ) + + def _plotter( + self, + X, + y, + feature_importances, + sample_metadata, + feature_metadata, + ): + params = self.config.get_params() + if feature_metadata is not None: + feature_metadata_columns = DataHandler.get_column_names(feature_metadata) + feature_meta_name = params.get("feature_meta_name") + if feature_meta_name is not None: + if isinstance(feature_meta_name, (str, int)): + feature_meta_name = [feature_meta_name] + missing_cols = set(feature_meta_name) - set(feature_metadata_columns) + if missing_cols: + raise ValueError( + f"Feature metadata columns {list(missing_cols)} not found in " + "feature metadata." 
+ ) + + self.plotter( + X, + y=y, + feature_importances=feature_importances, + sample_metadata=sample_metadata, + feature_metadata=feature_metadata, + **params, + ) + + def plot( + self, + X, + feature_importances, + y=None, + sample_metadata=None, + feature_metadata: dict = None, + input_columns: SelectedColumnTypes = None, + target_columns: SelectedColumnTypes = None, + sample_metadata_columns: SelectedColumnTypes = None, + sample_column: str = None, + plot_top: int = Unset("15"), + feature_meta_name: str = Unset("None"), + feature_column: str = Unset("None"), + cols: List[str] = Unset("None"), + colHeat: List[str] = Unset("None"), + dat_log: str = Unset("field(default=None, init=True, repr=False)"), + show_column_names: bool = Unset("False"), + scale_legend_title: str = Unset('"Value"'), + column_title: str = Unset('"Samples"'), + row_title: str = Unset('"Features"'), + plot_title: str = Unset('"Values"'), + path: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show=True, + ): + # feature_importances are not selected by columns so its None + self.config._input_columns = self._set_input_columns_and_arity( + input_columns, None, target_columns, sample_metadata_columns + ) + if ( + feature_importances is not None + and DataHandler.get_shape(feature_importances)[0] == 0 + ): + raise ValueError("Feature importances is empty.") + self._plot( + X, + feature_importances, + y, + sample_metadata, + feature_metadata=feature_metadata, + plot_top=plot_top, + feature_meta_name=feature_meta_name, + sample_column=sample_column, + feature_column=feature_column, + cols=cols, + colHeat=colHeat, + dat_log=dat_log, + show_column_names=show_column_names, + scale_legend_title=scale_legend_title, + column_title=column_title, + row_title=row_title, + plot_title=plot_title, + show=show, + path=path, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + ) diff --git a/src/biofit/visualization/histogram.py b/src/biofit/visualization/histogram.py new file mode 100644 index 0000000..1ee7cad --- /dev/null +++ b/src/biofit/visualization/histogram.py @@ -0,0 +1,430 @@ +import textwrap +from dataclasses import dataclass, field +from typing import List, Optional, Type + +import numpy as np + +import biofit.config as config +from biofit.integration.biosets import get_feature +from biofit.integration.R import RCaller +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.utils.types import Unset +from biofit.visualization.plotting import BasePlotter, PlotterConfig + + +def prepare_data_for_hist(x1, x2=None): + # Calculate row sums and column sums for dataset 1 + data1_sample = x1.sum(axis=1) + data1_feature = x1.sum(axis=0) + + list_sums = [data1_sample, data1_feature] + + # If dataset 2 is provided, calculate its row and column sums + if x2 is not None: + data2_sample = x2.sum(axis=1) + data2_feature = x2.sum(axis=0) + list_sums.extend([data2_sample, data2_feature]) + + return list_sums + + +def non_zero_sums(x1, x2=None): + # Sum all non-zero values for dataset 1 + x1_non_zero_row = (x1 != 0).sum(axis=1) + x1_non_zero_col = (x1 != 0).sum(axis=0) + sums = [x1_non_zero_row, x1_non_zero_col] + + # If dataset 2 is provided, calculate its non-zero sums + if x2 is not None: + x2_non_zero_row = (x2 != 0).sum(axis=1) + x2_non_zero_col = (x2 != 0).sum(axis=0) + sums.extend([x2_non_zero_row, x2_non_zero_col]) + + return sums + + +def 
prepare_axis_label(label, log_type):
+    if "1p" in log_type:
+        if "_1p" in log_type:
+            label_log = log_type.replace("_1p", "")
+            label = f"{label} ({label_log}(x+1))"
+        else:
+            label = f"{label} (ln(x+1))"
+    elif log_type == "log":
+        label = f"{label} (ln)"
+    else:
+        label = f"{label} ({log_type})"
+    return label
+
+
+def log_transformation(x, log_type):
+    if "1p" in log_type:
+        if "_1p" in log_type:
+            label_log = log_type.replace("_1p", "")
+            if label_log == "log10":
+                return np.log10(1 + x)
+            elif label_log == "log2":
+                return np.log2(1 + x)
+        else:
+            return np.log1p(x)
+    elif log_type == "log":
+        return np.log(x)
+    elif log_type == "log2":
+        return np.log2(x)
+    elif log_type == "log10":
+        return np.log10(x)
+    return x
+
+
+@dataclass
+class HistogramConfig(PlotterConfig):
+    processor_type: str = field(default="scaling", init=False, repr=False)
+    _unused_feature_types: List[Type] = field(
+        default=get_feature("METADATA_FEATURE_TYPES"),
+        init=False,
+        repr=False,
+    )
+    r_source: str = field(
+        default=(config.R_SCRIPTS / "plotting_utils.R").as_posix(),
+        init=False,
+        repr=False,
+    )
+    main_method: str = field(default="generate_histogram", init=False, repr=False)
+
+    xlab: str = "X"
+    ylab: str = "Frequency"
+    title: str = "Histogram"
+    bins: int = 30
+    font_size: int = 8
+    col_fill: str = "grey40"
+    col_outline: str = "white"
+    x1_name: str = "Before"
+    x2_name: str = "After"
+    xlog: Optional[str] = None
+    ylog: Optional[str] = None
+
+    def __post_init__(self):
+        if self.xlog not in [
+            None,
+            "log2",
+            "log10",
+            "log",
+            "log2_1p",
+            "log10_1p",
+            "log1p",
+        ]:
+            raise ValueError(
+                f"Invalid value for xlog: {self.xlog}. Must be one of: None, 'log2', 'log10', 'log', 'log2_1p', 'log10_1p', 'log1p'"
+            )
+        if self.ylog not in [
+            None,
+            "log2",
+            "log10",
+            "log",
+            "log2_1p",
+            "log10_1p",
+            "log1p",
+        ]:
+            raise ValueError(
+                f"Invalid value for ylog: {self.ylog}. 
Must be one of: None, 'log2', 'log10', 'log', 'log2_1p', 'log10_1p', 'log1p'" + ) + + +class HistogramPlotter(BasePlotter): + _config_class = HistogramConfig + config: HistogramConfig + + def __init__( + self, + xlab: str = Unset('"X"'), + ylab: str = Unset('"Frequency"'), + title: str = Unset('"Histogram"'), + bins: int = Unset("30"), + font_size: int = Unset("8"), + col_fill: str = Unset('"grey40"'), + col_outline: str = Unset('"white"'), + xlog: Optional[str] = Unset("None"), + ylog: Optional[str] = Unset("None"), + install_missing: bool = None, + config: Optional[HistogramConfig] = None, + ): + super().__init__( + xlab=xlab, + ylab=ylab, + title=title, + bins=bins, + font_size=font_size, + col_fill=col_fill, + col_outline=col_outline, + xlog=xlog, + ylog=ylog, + config=config, + install_missing=install_missing, + ) + if self.config.r_source and self.config.main_method: + self.r_caller = RCaller.from_script(self.config.r_source) + enter_code = textwrap.dedent( + """ + suppressPackageStartupMessages(require(ggplot2)) + """ + ) + exit_code = textwrap.dedent( + """ + ggplot2::ggsave(path, results) + """ + ) + self.plotter = self.r_caller.get_method( + self.config.main_method, enter_code, exit_code + ) + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + if self.config.r_source and self.config.main_method: + self.r_caller = RCaller.from_script(self.config.r_source) + enter_code = textwrap.dedent( + """ + suppressPackageStartupMessages(require(ggplot2)) + """ + ) + exit_code = textwrap.dedent( + """ + ggplot2::ggsave(path, results) + """ + ) + self.plotter = self.r_caller.get_method( + self.config.main_method, enter_code, exit_code + ) + + return self + + def plot( + self, + x, + input_columns: SelectedColumnTypes = None, + xlab: str = Unset('"X"'), + ylab: str = Unset('"Frequency"'), + title: str = Unset('"Histogram"'), + bins: int = Unset("30"), + font_size: int = Unset("8"), + col_fill: str = Unset('"grey40"'), + col_outline: str = Unset('"white"'), + xlog: Optional[str] = Unset("None"), + ylog: Optional[str] = Unset("None"), + path: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show: bool = True, + ): + self.config._input_columns = self._set_input_columns_and_arity(input_columns) + return self._plot( + x, + xlab=xlab, + ylab=ylab, + title=title, + bins=bins, + font_size=font_size, + col_fill=col_fill, + col_outline=col_outline, + xlog=xlog, + ylog=ylog, + path=path, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + show=show, + ) + + def plot_arrow(self, x): + kwargs = self.config.get_params() + context_kwargs = { + "path": kwargs.pop("path", None), + } + self.plotter(x, context_kwargs=context_kwargs, **kwargs) + + +@dataclass +class ComparisonHistogramConfig(PlotterConfig): + processor_type: str = field(default="scaling", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [None, None], init=False, repr=False + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [None, None], init=False, repr=False + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("METADATA_FEATURE_TYPES"), + get_feature("METADATA_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [ + 
get_feature("METADATA_FEATURE_TYPES"), + get_feature("METADATA_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + r_source: str = field( + default=(config.R_SCRIPTS / "plotting_utils.R").as_posix(), + init=False, + repr=False, + ) + main_method: str = field( + default="generate_comparison_histogram", init=False, repr=False + ) + + xlab: Optional[str] = None + ylab: str = "Count" + title: str = "Comparison Histogram" + bins: int = 30 + alpha: float = 0.6 + legend_title: str = "Legend" + legend_position: str = "top" + subplot_title1: str = "Before" + subplot_title2: str = "After" + col_set: str = "Set1" + cols: Optional[List[str]] = None + xlog: Optional[bool] = None + ylog: Optional[bool] = None + + +class ComparisonHistogramPlotter(BasePlotter): + _config_class = ComparisonHistogramConfig + config: ComparisonHistogramConfig + + def __init__( + self, + xlab: Optional[str] = Unset("None"), + ylab: str = Unset("None"), + title: str = Unset("None"), + bins: int = Unset("None"), + alpha: float = Unset("None"), + legend_title: str = Unset("None"), + legend_position: str = Unset("None"), + subplot_title1: str = Unset("None"), + subplot_title2: str = Unset("None"), + col_set: str = Unset("None"), + cols: Optional[List[str]] = Unset("None"), + xlog: Optional[bool] = Unset("None"), + ylog: Optional[bool] = Unset("None"), + install_missing: bool = None, + config: Optional[ComparisonHistogramConfig] = None, + ): + super().__init__( + config=config, + xlab=xlab, + ylab=ylab, + title=title, + bins=bins, + alpha=alpha, + legend_title=legend_title, + legend_position=legend_position, + subplot_title1=subplot_title1, + subplot_title2=subplot_title2, + col_set=col_set, + cols=cols, + xlog=xlog, + ylog=ylog, + install_missing=install_missing, + ) + self.plotter = None + if self.config.r_source and self.config.main_method: + self.r_caller = RCaller.from_script(self.config.r_source) + enter_code = textwrap.dedent( + """ + suppressPackageStartupMessages(require(ggplot2)) + """ + ) + exit_code = textwrap.dedent( + """ + ggplot2::ggsave(path, results) + """ + ) + self.plotter = self.r_caller.get_method( + self.config.main_method, enter_code, exit_code + ) + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + if self.config.r_source and self.config.main_method: + self.r_caller = RCaller.from_script(self.config.r_source) + enter_code = textwrap.dedent( + """ + suppressPackageStartupMessages(require(ggplot2)) + """ + ) + exit_code = textwrap.dedent( + """ + ggplot2::ggsave(path, results) + """ + ) + self.plotter = self.r_caller.get_method( + self.config.main_method, enter_code, exit_code + ) + + return self + + def plot( + self, + x1, + x2=None, + column1: SelectedColumnTypes = None, + column2: SelectedColumnTypes = None, + xlab: Optional[str] = Unset("None"), + ylab: str = Unset("None"), + title: str = Unset("None"), + bins: int = Unset("None"), + alpha: float = Unset("None"), + legend_title: str = Unset("None"), + legend_position: str = Unset("None"), + subplot_title1: str = Unset("None"), + subplot_title2: str = Unset("None"), + col_set: str = Unset("None"), + cols: Optional[List[str]] = Unset("None"), + xlog: Optional[bool] = Unset("None"), + ylog: Optional[bool] = Unset("None"), + path: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show: bool = True, + ): + self.config._input_columns = self._set_input_columns_and_arity(column1, column2) + return 
self._plot( + x1, + x2, + xlab=xlab, + ylab=ylab, + title=title, + bins=bins, + alpha=alpha, + legend_title=legend_title, + legend_position=legend_position, + col_set=col_set, + cols=cols, + subplot_title1=subplot_title1, + subplot_title2=subplot_title2, + xlog=xlog, + ylog=ylog, + path=path, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + show=show, + ) + + def plot_arrow(self, x1, x2): + kwargs = self.config.get_params() + self.plotter(x1, x2, **kwargs) diff --git a/src/biofit/visualization/plot_feature_importance.R b/src/biofit/visualization/plot_feature_importance.R new file mode 100644 index 0000000..d14f0bf --- /dev/null +++ b/src/biofit/visualization/plot_feature_importance.R @@ -0,0 +1,216 @@ +source(file.path(R_SCRIPTS_PATH, "plotting_utils.R")) +source(file.path(R_SCRIPTS_PATH, "utils.R")) + +#' +#' Feature importance plot with feature data in the samples. +#' @param X Arrow data table that is logged or in non-logged form depending on the omics data +#' @param y Arrow labels +#' @param sample_metadata Arrow sample metadata +#' @param feature_importances Arrow feature importance table: first column is the feature ID, "feature", second column (and on wards) is/are the importance per final model seed(s) +#' @param path the path for the pdf file to save +#' @param X, y, sample_metadata, feature_importances arrow objects from the biofit framework +#' @param plot_top the number of top features to plot. Default is 15. +#' @param feature_meta_name can be NULL, prints the feature ID; or character such as "species", prints this column, or vector of more than 1 characters c("genus", "species"), then prints the collapsed text with a space in between. Default is NULL. +#' @param plot_title name of the data type, eg "Presence", or "log2 Abundance". Default is NULL. +#' @param column_title eg. "Sample" or "Isolate". Default is NULL. +#' @param row_title eg. "OTU" or "SNP". Default is NULL. 
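+#' @param feature_metadata optional table of per-feature annotations (e.g. genus, species), keyed by the feature ID column; used to build the row labels when feature_meta_name is given. Default is NULL.
+#' @param feature_column name of the feature ID column shared by feature_importances and feature_metadata. Default is NULL.
+#' @param cols vector of colours for the class annotation drawn above the heatmap; defaults to the palette returned by color_select() in plotting_utils.R.
+#' @param col_heat colour palette for the heatmap values; defaults to the RColorBrewer "YlOrRd" palette.
+#' @param show_column_names logical, whether sample names are printed below the heatmap columns. Default is FALSE.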
+#' @examples +#' \dontrun{ +#' # usage for OTU data: +#' plot_feature_importance(X, y, sample_metadata, feature_importances, plot_top = 30, feature_meta_name = c("feature", "genus", "species"), plot_title = "log2\nAbundance", column_title = "Sample", row_title = "OTU") +#' # usage for metagenomics data: +#' plot_feature_importance(X, y, sample_metadata, feature_importances, plot_top = 30, feature_meta_name = c("genus", "species"), plot_title = "log2\nAbundance", column_title = "Sample", row_title = "Taxonomy") +#' # usage for transcriptomics/proteomics (non-MALDI) data: eg +#' plot_feature_importance(X, y, sample_metadata, feature_importances, plot_top = 30, feature_meta_name = c("gene"), plot_title = "log2\nExpression", column_title = "Sample", row_title = "Gene") +#' # usage for genomics data: +#' plot_feature_importance(X, y, sample_metadata, feature_importances, plot_top = 30, feature_meta_name = NULL, plot_title = "Presence", column_title = "Isolate", row_title = "SNP") +#' # usage for MALDI-TOF data: +#' plot_feature_importance(X, y, sample_metadata, feature_importances, plot_top = 30, feature_meta_name = "DaRange", plot_title = "log10\nAbundance", column_title = "Isolate", row_title = "Da range") +#' } +#' +plot_feature_importance <- function( + X, + feature_importances, + path, + y = NULL, + sample_metadata = NULL, + feature_metadata = NULL, + input_columns = NULL, + target_columns = NULL, + sample_metadata_columns = NULL, + feature_column = NULL, + feature_meta_name = NULL, + plot_top = 15, + cols = NULL, + col_heat = NULL, + show_column_names = FALSE, + plot_title = "Relative Abundance", + column_title = "Samples", + row_title = "Taxonomic Relative Abundance") { + width <- 13 + height <- 7 + + X <- convert_to_dataframe(X) + input_columns <- get_default_columns(X, input_columns, NULL) + X <- validate_data(X, input_columns) + + if (!is.null(y)) { + y <- convert_to_dataframe(y) + target_columns <- get_default_columns(y, target_columns) + y <- validate_data(y, target_columns) + } + + if (!is.null(sample_metadata)) { + sample_metadata <- convert_to_dataframe(sample_metadata) + sample_metadata_columns <- get_default_columns( + sample_metadata, sample_metadata_columns, NULL + ) + sample_metadata <- validate_data(sample_metadata, sample_metadata_columns) + } + + if (is.null(sample_column)) { + sample_column <- sample_metadata_columns[1] + } + + + if (!is.null(feature_metadata)) { + feature_metadata <- convert_to_dataframe(feature_metadata) + if (is.list(feature_meta_name)) { + feature_meta_name <- as.vector(unlist(feature_meta_name)) + } + feature_meta_name <- get_default_columns( + feature_metadata, feature_meta_name + ) + feature_metadata <- validate_data(feature_metadata, feature_meta_name) + rownames(feature_metadata) <- feature_metadata[, feature_column] + } + + suppressPackageStartupMessages(require(ComplexHeatmap)) + suppressPackageStartupMessages(require(grid)) + suppressPackageStartupMessages(require(circlize)) + suppressPackageStartupMessages(require(RColorBrewer)) + + ## get label and feature information + ## convert arrow data format to data frame + if (!is.null(feature_metadata)) { + feature_metadata <- convert_to_dataframe(feature_metadata) + } + + + feature_importances <- convert_to_dataframe(feature_importances) + top_feat_id <- feature_importances[, 1] + rownames(feature_importances) <- feature_importances[, feature_column] + # drop feature column + not_feature_column <- setdiff(colnames(feature_importances), feature_column) + feature_importances <- 
feature_importances[, not_feature_column] + + + y <- factor(as.vector(convert_to_dataframe(y)[, 1])) + + rownames(X) <- sample_metadata[, sample_column] + + ### make sure plot_top is numeric + if (!is.numeric(plot_top)) { + stop("plot_top, the number of top features to plot, must be numeric.") + } + + ### color selected from plotting_utils.R to be consistent + if (is.null(cols)) { + cols <- color_select(length(levels(y))) + names(cols) <- levels(y) + } + + ### if col_heat not specified + if (is.null(col_heat)) { + col_heat <- brewer.pal(9, "YlOrRd") + } + + ## Define the feature/text annotation with row names + ## plot feature ID if not specified + if (is.null(feature_meta_name) || is.null(feature_metadata)) { + feat_on_plot <- top_feat_id + } else { + ## for more than one feature meta columns to display + if (length(feature_meta_name) > 1) { + if (!(feature_column %in% feature_meta_name)) { + feature_meta_name <- c(feature_column, feature_meta_name) + } + feat_on_plot <- apply( + feature_metadata[top_feat_id, feature_meta_name], + 1, + paste, + collapse = "\n" + ) + } else { ## for one feature meta columns to display + feat_on_plot <- feature_metadata[top_feat_id, feature_meta_name] + } + } + + ################## + ## generate feature importance plot + ################## + ### the feature boxplot/barplot + # remove UNINTEGRATED from top_feat_id + hdat2plot <- t(data.matrix(X[, top_feat_id])) + + print(hdat2plot) + print(col_heat) + if (ncol(feature_importances) == 1) { + ha2 <- ComplexHeatmap::rowAnnotation( + ` ` = row_anno_points( + feature_importances[top_feat_id, ], + axis = TRUE, outline = FALSE + ), + width = unit(3, "cm") + ) + } else { + ha2 <- ComplexHeatmap::rowAnnotation( + ` ` = row_anno_boxplot( + data.matrix(feature_importances[top_feat_id, ]), + axis = TRUE, + outline = FALSE + ), + width = unit(3, "cm") + ) + } + + ### the sample annotation in the column + ### (Added a name hack here to be Importance + ### as the label is above feature importance box/dot plot) + ha1 <- ComplexHeatmap::HeatmapAnnotation( + df = data.frame(Importance = y), + show_legend = TRUE, + col = list(Importance = cols), + annotation_legend_param = list( + title = target_columns, + legend_direction = "vertical" + ) + ) + + ### add feature names to rows + text_annotation <- rowAnnotation( + text = anno_text( + feat_on_plot, + just = "left", + gp = gpar(fontsize = 10) + ) + ) + + ### plot histogram and add the 3 components from above + h2plot <- ComplexHeatmap::Heatmap(hdat2plot, + cluster_rows = FALSE, col = col_heat, top_annotation = ha1, + heatmap_legend_param = list( + title = plot_title, color_bar = "continuous", + legend_direction = "vertical" + ), + show_column_names = show_column_names, + row_title = row_title, # Title for the rows + column_title = column_title, # Title for the columns + row_title_side = "left", # Position of the row title + column_title_side = "top" + ) + + ha2 + text_annotation # , top_annotation_height = unit(1, "cm") + start_device(path, width = width, height = height, units = "in", res = 300) + draw(h2plot) + dev.off() +} diff --git a/src/biofit/visualization/plot_sample_metadata.R b/src/biofit/visualization/plot_sample_metadata.R new file mode 100644 index 0000000..efe3a3d --- /dev/null +++ b/src/biofit/visualization/plot_sample_metadata.R @@ -0,0 +1,140 @@ +source(file.path(R_SCRIPTS_PATH, "plotting_utils.R")) +source(file.path(R_SCRIPTS_PATH, "utils.R")) + +plot_sample_metadata <- function( + data, outcome = NULL, + sample_metadata_columns = NULL, + outcome_column = NULL, 
+ path, + device = "pdf") { + + suppressPackageStartupMessages(require(ggplot2)) + suppressPackageStartupMessages(require(RColorBrewer)) + suppressPackageStartupMessages(require(circlize)) + suppressPackageStartupMessages(require(patchwork)) + width <- 7 + height <- 7 + if (device[1] != ".") { + device <- paste0(".", device) + } + data <- convert_to_dataframe(data) + sample_metadata_columns <- get_default_columns( + data, sample_metadata_columns, + max_cols = NULL + ) + data <- validate_data(data, sample_metadata_columns) + + if (!is.null(outcome)) { + outcome <- convert_to_dataframe(outcome) + outcome_column <- get_default_columns(outcome, outcome_column, max_cols = 1) + outcome <- validate_data(outcome, outcome_column) + data <- concatenate_datasets( + data[, sample_metadata_columns], outcome, + how = "horizontal" + ) + } else if (!is.null(outcome_column)) { + outcome_column <- get_default_columns(data, outcome_column, max_cols = 1) + } else { + stop("Please provide either an outcome or an outcome column") + } + + # get filename from path + file_name <- gsub("\\.[^.]*$", "", basename(path)) + path <- dirname(path) + # Get Outcome Type + outcome_type <- detect_var_type(data, outcome_column) + + metadata_dir <- path + if (!dir.exists(metadata_dir)) { + dir.create(metadata_dir) + } + + # Comparison Visualizations + if (outcome_type == "categorical") { + suppressPackageStartupMessages(require(forcats)) + # Outcome variable visualization + cat_metadata <- plot_cat_metadata(data, outcome_column) + cat_metadata_fn <- paste0(file_name, "_dist_", outcome_column, device) + save_plots(cat_metadata_fn, + plot = cat_metadata, + path = metadata_dir, width = width, height = height + ) + + + plot_num <- 0 + # Loop through all the columns + for (col in names(data)) { + plot_num <- plot_num + 1 + # If the col is the same as the outcome + if (col == outcome_column) next + + # Get the data type of the current column + col_type <- detect_var_type(data, col) + + # If it is other, we skip for now + if (col_type == "other") { + print(paste0(col, " was not plotted")) + next + } + + # categorical and numerical comparisons + if (col_type == "categorical") { + comp_metadata <- plot_cat_vs_cat_metadata(data, outcome_column, col) + } else if (col_type == "numerical") { + comp_metadata <- plot_cat_vs_num_metadata(data, outcome_column, col) + } + + # Save Comparison + comp_metadata_fn <- paste0( + file_name, plot_num, "_", + outcome_column, "_vs_", col, device + ) + save_plots(comp_metadata_fn, + plot = comp_metadata, + path = metadata_dir, width = width * 2, height = height + ) + } + } else if (outcome_type == "numerical") { + # Outcome variable visualization + num_metadata <- plot_num_metadata(data, outcome_column, font_size = 11) + num_metadata_fn <- paste0(file_name, "_dist_", outcome_column, device) + save_plots(num_metadata_fn, + plot = num_metadata, path = metadata_dir, + width = width, height = height + ) + plot_num <- 0 + # Loop through all of the columns + for (col in names(data)) { + plot_num <- plot_num + 1 + # If the col is the same as the outcome + if (col == outcome_column) next + + # Get the data type of the current column + col_type <- detect_var_type(data, col) + + # If it is other, we skip for now + if (col_type == "other") { + print(paste0(col, " was not plotted")) + next + } + + # categorical and numerical comparisons + if (col_type == "categorical") { + comp_metadata <- plot_cat_vs_num_metadata(data, col, outcome_column) + } else if (col_type == "numerical") { + comp_metadata <- 
plot_num_vs_num_metadata(data, outcome_column, col) + } + + comp_metadata_fn <- paste0( + file_name, plot_num, "_", outcome_column, "_vs_", col, device + ) + save_plots( + comp_metadata_fn, + plot = comp_metadata, + path = metadata_dir, width = width * 2, height = height + ) + } + } else if (outcome_type == "other") { + print("Outcome variable has too many levels") + } +} diff --git a/src/biofit/visualization/plotting.py b/src/biofit/visualization/plotting.py new file mode 100644 index 0000000..e1a0502 --- /dev/null +++ b/src/biofit/visualization/plotting.py @@ -0,0 +1,429 @@ +import inspect +import os +import sys +import tempfile +import warnings +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Optional + +from biocore import DataHandler +from biocore.utils.import_util import is_ipywidgets_available, is_matplotlib_available +from biocore.utils.naming import camelcase_to_snakecase + +from biofit.integration.R.r_caller import RCaller +from biofit.processing import ( + _ORDERED_FORMATS, + TransformationMixin, + sync_backup_config, +) +from biofit.utils import ( + _build_cache_dir, + fingerprint_from_data, + has_ext, + has_separator, + logging, + move_temp_file, +) +from biocore.utils.py_util import is_dataset_dict +from biofit.visualization.plotting_utils import ( + display_image_carousel, + is_in_notebook, +) + +from ..processing import BaseConfig, SelectedColumnTypes, SelectedFeatureTypes + +logger = logging.get_logger(__name__) + + +ORDERED_PLOTTER_FORMATS = _ORDERED_FORMATS + ["dataset", "ds"] + + +def _processor_info_from_fingerprint(fingerprint: str): + if fingerprint is None: + return "", "", "" + processor_info = fingerprint.split("-") + ds_name = "" + if len(processor_info) == 5: + _, processor_name, processor_type, ds_name, _ = processor_info + elif len(processor_info) == 4: + _, processor_name, processor_type, ds_name = processor_info + elif len(processor_info) == 3: + _, processor_name, processor_type = processor_info + else: + return "", "", "" + return processor_name, processor_type, ds_name + + +@dataclass +class PlotterConfig(BaseConfig): + path: str = field(default=None, kw_only=True, init=True, repr=True) + device: str = field(default="pdf", kw_only=True, init=True, repr=False) + fingerprint: str = field(default=None, kw_only=True, init=True, repr=False) + unused_columns: SelectedColumnTypes = field( + default=None, kw_only=True, init=True, repr=False + ) + raise_if_missing: bool = field(default=True, kw_only=True, init=True, repr=False) + cache_dir: str = field(default=None, kw_only=True, init=True, repr=False) + version: str = field(default="0.0.0", kw_only=True, init=True, repr=False) + + _input_columns: SelectedColumnTypes = field(default=None, init=False, repr=False) + _compare: bool = field(default=False, init=False, repr=False) + _fit_input_feature_types: SelectedFeatureTypes = field( + default=None, init=False, repr=False + ) + _fit_unused_feature_types: SelectedFeatureTypes = field( + default=None, init=False, repr=False + ) + _transform_input_feature_types: SelectedFeatureTypes = field( + default=None, init=False, repr=False + ) + _transform_unused_feature_types: SelectedFeatureTypes = field( + default=None, init=False, repr=False + ) + r_code: str = field(default=None, init=False, repr=False) + r_source: str = field(default=None, init=False, repr=False) + main_method: str = field(default=None, init=False, repr=False) + + processor_type: str = field(default="", init=False, repr=False) + processor_name: str = 
field(default="", init=False, repr=False) + dataset_name: str = field(default="", init=False, repr=False) + + # automatically generated attributes + + feature_idx_in_: List[int] = field(default=None, init=False, repr=False) + feature_names_in_: List[str] = field(default=None, init=False, repr=False) + extra_idx_in_: List[List[int]] = field(default=None, init=False, repr=False) + extra_names_in_: List[List[str]] = field(default=None, init=False, repr=False) + + +class BasePlotter(TransformationMixin): + """_summary_ + + Attributes: + r_code (_type_): _description_ + r_source (_type_): _description_ + feature_type (_type_): _description_ + dtype (str): specifies the type of the plotter. Can be 'plotter' or 'plotter_for' + + Raises: + NotImplementedError: _description_ + ValueError: _description_ + ValueError: _description_ + ValueError: _description_ + ValueError: _description_ + ValueError: _description_ + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + r_caller: RCaller = None + _config_class = PlotterConfig + plotter: Optional[str] = None + config: PlotterConfig + + def __init__(self, config: Optional[PlotterConfig] = None, **kwargs): + add_new_attr = kwargs.pop("add_new_attr", False) + ignore_none = kwargs.pop("ignore_none", False) + + if config is None: + if hasattr(self, "_config_class"): + self.config = self._config_class.from_dict( + kwargs, ignore_none=ignore_none, add_new_attr=add_new_attr + ) + elif isinstance(config, PlotterConfig): + self.config = config + elif isinstance(config, dict): + self.config = self._config_class.from_dict( + config, ignore_none=ignore_none, add_new_attr=add_new_attr + ) + else: + raise ValueError(f"Unsupported config type {type(config)}") + if config is None: + self = self.set_params(**kwargs) + if self.config.r_source and self.config.main_method and self.plotter is None: + self.r_caller = RCaller.from_script(self.config.r_source) + self.r_caller.verify_r_dependencies( + install_missing=kwargs.get("install_missing") + ) + self.plotter = self.r_caller.get_method(self.config.main_method) + if kwargs.get("_function", None): + self._function = kwargs["_function"] + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + if self.config.r_source and self.config.main_method: + self.r_caller = RCaller.from_script(self.config.r_source) + self.plotter = self.r_caller.get_method(self.config.main_method) + + return self + + @classmethod + def _from_config(cls, config: PlotterConfig, **kwargs): + return cls(config=config, **kwargs) + + def plot(self, X, *args, **kwargs): + if is_dataset_dict(X): + raise ValueError( + "Cannot plot a DatasetDict or IterableDatasetDict. Please provide a Dataset or IterableDataset." + ) + return self._plot(X, *args, **kwargs) + + def _plot(self, X, *args, **kwargs): + """ + Transforms the input data. + + Args: + X (Any): The input data. + *args: Additional arguments. These are additional table or array-like data to be used in the plotting. + **kwargs: Additional keyword arguments. These are objects or parameters to be used in the plotting. + + Returns: + Any: The computed processor. 
+ """ + show = kwargs.pop("show", True) + output_dir = kwargs.pop("output_dir", None) + self.config, kwargs = self.config.replace_defaults( + ignore_none=True, return_unused_kwargs=True, **kwargs + ) + + args = list(args) + + plot_funcs = self._get_method(ORDERED_PLOTTER_FORMATS, "plot") + plot_func, target_format = self._get_target_func( + plot_funcs, + source_format=DataHandler.get_format(X), + target_formats=kwargs.get("input_format", None), + accepted_formats=ORDERED_PLOTTER_FORMATS, + ) + + fingerprint = kwargs.pop("fingerprint", None) or self.config.fingerprint + + image_paths = None + if plot_func: + ( + self.feature_names_in_, + self.feature_idx_in_, + _, + self.extra_names_in_, + self.extra_idx_in_, + _, + _, + ) = self._get_columns( + X, + *args, + input_columns=self.config._input_columns, + unused_columns=self.config.unused_columns, + input_feature_types=self.config._transform_input_feature_types, + unused_feature_types=self.config._transform_unused_feature_types, + raise_if_missing=self.config.raise_if_missing, + ) + + path_to_save = self.config.path + + # required args are ones without defaults + required_args = [ + p.default == p.empty + for p in inspect.signature(plot_func).parameters.values() + ][1:] + if ( + not self.config._compare + and self.extra_idx_in_ is not None + and len(self.extra_idx_in_) > 0 + and DataHandler.supports_named_columns(X) + ): + new_cols = self._make_columns_exclusive( + [self.feature_names_in_] + self.extra_names_in_ + ) + + self.feature_names_in_ = new_cols[0] + self.feature_idx_in_ = sorted( + DataHandler.get_column_indices(X, self.feature_names_in_) + ) + + self.extra_names_in_ = new_cols[1:] + extra_idx_in_ = [] + for i, names in enumerate(self.extra_names_in_): + if len(args) and args[i] is not None: + extra_idx_in_.append( + sorted( + DataHandler.get_column_indices( + args[i], names, raise_if_missing=False + ) + ) + ) + elif names is not None: + extra_idx_in_.append( + sorted( + DataHandler.get_column_indices( + X, names, raise_if_missing=True + ) + ) + ) + else: + if required_args[i]: + raise ValueError( + f"Missing required argument {i} for plot function {plot_func}" + ) + extra_idx_in_.append(self.extra_idx_in_[i]) + self.extra_idx_in_ = extra_idx_in_ + + data_fingerprint = fingerprint_from_data(X) + if fingerprint is None: + fingerprint = self.generate_fingerprint(data_fingerprint, self.config) + + input, args, kwargs = self._process_plot_input(X, *args, **kwargs) + args = list(args) + if len(args) == 0: + if self.extra_idx_in_: + for i, inds in enumerate(self.extra_idx_in_): + if inds is not None: + args.append( + DataHandler.to_format( + X, target_format, input_columns=inds + ) + ) + else: + args.append(None) + else: + for i, (inds, arg) in enumerate(zip(self.extra_idx_in_, args)): + if arg is not None: + args[i] = DataHandler.to_format( + arg, target_format, input_columns=inds + ) + elif inds is not None: + args[i] = DataHandler.to_format( + X, target_format, input_columns=inds + ) + + input = DataHandler.to_format( + input, target_format, input_columns=self.feature_idx_in_ + ) + temp_dir = tempfile.mkdtemp() + temp_file = os.path.join(temp_dir, fingerprint) + temp_dir = Path(temp_dir) + temp_dir.mkdir(parents=True, exist_ok=True) + temp_dir = temp_dir.resolve().as_posix() + self.config.path = temp_file + f".{self.config.device}" + output_dir = None if output_dir is None else Path(output_dir) + file_name = None + if path_to_save: + if has_separator(path_to_save): + if has_ext(path_to_save): + output_dir = ( + Path(path_to_save).parent 
+ if output_dir is None + else output_dir + ) + file_name = Path(path_to_save).name + else: + output_dir = ( + Path(path_to_save) if output_dir is None else output_dir + ) + file_name = None + elif has_ext(path_to_save): + file_name = str(path_to_save) + + if output_dir is None: + if kwargs.get("processor", None) and kwargs["processor"].cache_files: + output_dir = ( + Path(kwargs["processor"].cache_files[0]["filename"]).parent + / "plots" + ) + elif hasattr(X, "cache_files") and X.cache_files: + output_dir = Path(X.cache_files[0]["filename"]).parent / "plots" + else: + output_dir = Path(_build_cache_dir(X, data_fingerprint)) / "plots" + if file_name is None: + cls_name = camelcase_to_snakecase( + self.__class__.__name__.replace("_plotter", "") + ) + file_name = f"{fingerprint}_{cls_name}.{self.config.device}" + + fig = plot_func(input, *args, **kwargs) + if "matplotlib" in sys.modules: + import matplotlib.pyplot as plt + + if output_dir: + + def get_or_move_images(images, output_dir, move_images=False): + image_paths = [] + if len(images) > 1: + for fn in images: + old_name = fn.resolve().as_posix() + new_name = f"{output_dir}/{fn.name}" + if move_images: + image_paths.append(new_name) + move_temp_file(old_name, new_name) + else: + image_paths.append(old_name) + if move_images: + logger.info(f"Saved {len(images)} plots to {output_dir}") + elif len(images) == 1: + old_name = images[0].resolve().as_posix() + if move_images: + image_paths = f"{output_dir}/{file_name}" + move_temp_file(old_name, image_paths) + logger.info(f"Saved plot to {image_paths}") + else: + image_paths = old_name + + return image_paths + + output_dir.mkdir(parents=True, exist_ok=True) + output_dir = output_dir.resolve().as_posix() + # move all files within temp_dir to output_dir + images = [fn for fn in Path(temp_dir).glob(f"*.{self.config.device}")] + image_paths = get_or_move_images(images, output_dir, move_images=True) + + _image_paths = [] + if self.config.device != "png": + _images = [fn for fn in Path(temp_dir).glob("*.png")] + _image_paths = get_or_move_images( + _images, output_dir, move_images=False + ) + if len(_image_paths) > 0: + image_paths = _image_paths + images = _images + + if len(image_paths) > 1: + if is_in_notebook() and show and is_ipywidgets_available(): + display_image_carousel(image_paths, "png") + elif len(image_paths) == 1: + if is_in_notebook() and show: + if ( + is_matplotlib_available() + and "matplotlib" in sys.modules + and isinstance(fig, plt.Figure) + ): + # ignore warning about non-interactive backend + warnings.filterwarnings("ignore") + fig.show() + warnings.filterwarnings("default") + else: + from IPython.display import Image, display + + display( + Image( + image_paths, + embed=True, + format="png", + width=720, + ) + ) + elif ( + is_matplotlib_available() + and "matplotlib" in sys.modules + and isinstance(fig, plt.Figure) + ): + plt.close(fig) + else: + logger.warning("No plots were generated") + return image_paths + + def _process_plot_input(self, input, *args, **kwargs): + return input, args, kwargs diff --git a/src/biofit/visualization/plotting_utils.py b/src/biofit/visualization/plotting_utils.py new file mode 100644 index 0000000..ff6a63e --- /dev/null +++ b/src/biofit/visualization/plotting_utils.py @@ -0,0 +1,1210 @@ +import sys +from os import PathLike +from types import NoneType +from typing import List, Optional, Union + +import numpy as np +import pandas as pd +from biocore import DataHandler +from biocore.utils import requires_backends +from sklearn.pipeline import 
Pipeline + +from biofit.processing import SelectedColumnTypes +from biofit.utils.types import Unset + + +def is_in_notebook(): + try: + from IPython import get_ipython + + if "IPKernelApp" in get_ipython().config: + return True + except Exception: + return False + return False + + +def display_image_carousel(image_paths, format="png"): + """ + Displays an image carousel with left and right arrow buttons in a Jupyter notebook. + + Parameters: + image_paths (list): List of paths to images. + output_dir (str): Directory to indicate in the log if limit is exceeded. + max_images (int): Maximum number of images to include in the carousel. + """ + requires_backends(display_image_carousel, "ipywidgets") + import ipywidgets as widgets + from IPython.display import Image, clear_output, display + from ipywidgets import Button, HBox, Layout, VBox + + image_index = 0 + + button_layout = Layout(width="100px") + left_button = Button(description="Previous", layout=button_layout) + right_button = Button(description="Next", layout=button_layout) + image_box = widgets.Output() + + def show_image(): + """Helper function to show the image.""" + with image_box: + clear_output(wait=True) + display( + Image(image_paths[image_index], embed=True, format=format, width=720) + ) + + def on_left_button_clicked(b): + nonlocal image_index + if image_index > 0: + image_index -= 1 + show_image() + + def on_right_button_clicked(b): + nonlocal image_index + if image_index < len(image_paths) - 1: + image_index += 1 + show_image() + + left_button.on_click(on_left_button_clicked) + right_button.on_click(on_right_button_clicked) + + show_image() # Show the first image initially + + navigation_box = HBox([left_button, right_button]) + display(VBox([navigation_box, image_box])) + + +def get_distinct_colors(n, colormap="nipy_spectral"): + """ + Generate a list of n distinct colors using a specified colormap. + Switched to 'nipy_spectral' for higher contrast and brightness differentiation. + + :param n: Number of distinct colors to generate. + :param colormap: Name of the matplotlib colormap to use. + :return: List of RGBA colors. + """ + requires_backends(get_distinct_colors, "matplotlib") + from matplotlib import colormaps + + cmap = colormaps.get_cmap(colormap) + return [cmap(i / n) for i in range(n)] + + +def generate_violin( + x, + y=None, + column: SelectedColumnTypes = None, + label_name: SelectedColumnTypes = None, + xlab: str = Unset('"Labels"'), + ylab: str = Unset('"Value"'), + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + output_dir: Union[PathLike, str] = None, +): + """Generates a violin plot of the input data. + + Args: + x: Input data. + other_x: Other input data. + y: Target data. + other_x: Other target data. + **kwargs: Additional arguments. + Returns: + Plotter object. 
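+
+    Example:
+        Illustrative sketch only (not a tested doctest). It assumes a pandas
+        DataFrame with a numeric "value" column and a categorical "label"
+        column, that the R/ggplot2 backend used by ViolinPlotter is installed,
+        and that "plots" is a hypothetical, writable output directory.
+
+            import pandas as pd
+
+            df = pd.DataFrame(
+                {"value": [1.2, 3.4, 2.2, 0.5], "label": ["a", "a", "b", "b"]}
+            )
+            generate_violin(
+                df, column="value", label_name="label", output_dir="plots"
+            )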
+ """ + from biofit.visualization.violin import ViolinPlotter + + return ViolinPlotter().plot( + x, + y, + column=column, + label_name=label_name, + xlab=xlab, + ylab=ylab, + path=output_dir, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + ) + + +def generate_scatterplot( + x, + y=None, + group=None, + xdata: SelectedColumnTypes = None, + ydata: SelectedColumnTypes = None, + groupby: SelectedColumnTypes = None, + xlab: str = Unset("None"), + ylab: str = Unset("None"), + title: str = Unset('"Scatterplot"'), + alpha: str = Unset("1"), + col_set: str = Unset('"Set1"'), + cols: List[str] = Unset("None"), + xlog: str = Unset("None"), + ylog: str = Unset("None"), + output_dir: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, +): + from biofit.visualization.scatterplot import ScatterPlotter + + plotter = ScatterPlotter() + return plotter.plot( + x, + y, + group, + xdata=xdata, + ydata=ydata, + groupby=groupby, + xlab=xlab, + ylab=ylab, + title=title, + alpha=alpha, + col_set=col_set, + cols=cols, + xlog=xlog, + ylog=ylog, + path=output_dir, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + ) + + +def generate_histogram( + x, + input_columns: SelectedColumnTypes = None, + xlab: str = Unset('"X"'), + ylab: str = Unset('"Frequency"'), + title: str = Unset('"Histogram"'), + bins: int = Unset("30"), + font_size: int = Unset("8"), + col_fill: str = Unset('"grey40"'), + col_outline: str = Unset('"white"'), + xlog: Optional[str] = Unset("None"), + ylog: Optional[str] = Unset("None"), + output_dir: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, +): + """Generates a histogram of the input data. + + Args: + x: Input data. + other_x: Other input data. + y: Target data. + other_x: Other target data. + **kwargs: Additional arguments. + Returns: + Plotter object. + """ + from biofit.visualization.histogram import HistogramPlotter + + plotter = HistogramPlotter() + + return plotter.plot( + x, + input_columns=input_columns, + xlab=xlab, + ylab=ylab, + title=title, + bins=bins, + font_size=font_size, + col_fill=col_fill, + col_outline=col_outline, + xlog=xlog, + ylog=ylog, + path=output_dir, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + ) + + +def generate_barplot( + x, + y=None, + group=None, + label_name: SelectedColumnTypes = None, + value_name=None, + groupby: Optional[str] = None, + xlab: Optional[str] = Unset("None"), + ylab: Optional[str] = Unset("None"), + title: str = Unset('"Bar Plot"'), + col_set: str = Unset('"Set1"'), + col_labels: str = Unset('"black"'), + col_outline: str = Unset('"grey30"'), + cols: Optional[List[str]] = Unset("None"), + prop: bool = Unset("False"), + add_count_lab: bool = Unset("True"), + vars_as_entered: bool = Unset("False"), + legend_position: str = Unset('"top"'), + font_size: float = Unset("3.25"), + output_dir: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, +): + """Generates a bar plot of the input data. + + Args: + x: Input data. + other_x: Other input data. + y: Target data. + other_x: Other target data. + **kwargs: Additional arguments. + Returns: + Plotter object. 
+ """ + from biofit.visualization.barplot import BarPlotter + + plotter = BarPlotter() + + return plotter.plot( + x, + y, + group, + label_name=label_name, + value_name=value_name, + groupby=groupby, + xlab=xlab, + ylab=ylab, + title=title, + col_set=col_set, + col_labels=col_labels, + col_outline=col_outline, + cols=cols, + prop=prop, + add_count_lab=add_count_lab, + vars_as_entered=vars_as_entered, + legend_position=legend_position, + font_size=font_size, + path=output_dir, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + ) + + +def generate_comparison_histogram( + x1, + x2=None, + column1: str = None, + column2: str = None, + xlab: Optional[str] = Unset("None"), + ylab: str = Unset("None"), + title: str = Unset("None"), + bins: int = Unset("None"), + alpha: float = Unset("None"), + legend_title: str = Unset("None"), + legend_position: str = Unset("None"), + subplot_title1: str = Unset("None"), + subplot_title2: str = Unset("None"), + col_set: str = Unset("None"), + cols: Optional[List[str]] = Unset("None"), + xlog: Optional[bool] = Unset("None"), + ylog: Optional[bool] = Unset("None"), + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + output_dir=None, +): + """Generates a comparison histogram of the input data. + + Args: + x: Input data. + other_x: Other input data. + y: Target data. + other_x: Other target data. + **kwargs: Additional arguments. + Returns: + Plotter object. + """ + from biofit.visualization.histogram import ComparisonHistogramPlotter + + plotter = ComparisonHistogramPlotter() + + return plotter.plot( + x1, + x2=x2, + column1=column1, + column2=column2, + xlab=xlab, + ylab=ylab, + title=title, + bins=bins, + alpha=alpha, + legend_title=legend_title, + legend_position=legend_position, + col_set=col_set, + cols=cols, + subplot_title1=subplot_title1, + subplot_title2=subplot_title2, + xlog=xlog, + ylog=ylog, + path=output_dir, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + ) + + +def plot_correlation( + X, + y=None, + group=None, + input_columns=None, + target_column=None, + groupby: Optional[str] = None, + precomputed=False, + method="auto", + label_name: str = Unset('"None"'), + value_name: Optional[str] = Unset("None"), + top_k: int = Unset("30"), + xlab: Optional[str] = Unset("None"), + ylab: Optional[str] = Unset("None"), + title: str = Unset('"Bar Plot"'), + col_set: str = Unset('"Set1"'), + cols: Optional[List[str]] = Unset("None"), + prop: bool = Unset("False"), + add_count_lab: bool = Unset("True"), + vars_as_entered: bool = Unset("False"), + legend_position: str = Unset('"top"'), + font_size: float = Unset("3.25"), + output_dir=None, + file_name=None, +): + """Plot the correlation matrix of the input data. + + Args: + X: Input data + y: Target variable + input_columns: Columns to filter + target_column: Target column + top_k: Number of top features to plot + label_name: Name of the labels + value_name: Name of the values + groupby: Groupby column + xlab: X-axis label + ylab: Y-axis label + title: Plot title + col_set: Color set + cols: Columns to plot + prop: Whether to plot proportions + add_count_lab: Whether to add count labels + vars_as_entered: Whether the variables are entered as is + legend_position: Position of the legend + font_size: Font size + Returns: + SampleFiltered data. 
+ """ + if not precomputed: + from biofit.stat import CorrelationStat + + corr_stat = CorrelationStat(method=method) + corrs = corr_stat.fit_transform( + X, y, input_columns, target_column, output_format="pandas" + ) + name_map = { + "pearsonr": "Pearson Correlation", + "spearmanr": "Spearman Correlation", + "kendalltau": "Kendall Correlation", + "pointbiserialr": "Point Biserial Correlation", + } + value_name = name_map.get(corr_stat.config.method, "Correlation") + if isinstance(groupby, (Unset, NoneType)): + groupby = "Features" + corrs = corrs.melt(var_name=groupby, value_name=value_name) + # drop na + corrs = corrs.dropna() + sorted_inds = np.argsort(np.abs(corrs[value_name]))[::-1] + if isinstance(top_k, (Unset, NoneType)): + top_k = 30 + corrs = corrs.iloc[sorted_inds[:top_k]] + return generate_barplot( + corrs, + None, + None, + label_name=label_name, + value_name=value_name, + groupby=groupby, + xlab=xlab, + ylab=ylab, + title=title, + col_set=col_set, + cols=cols, + prop=prop, + add_count_lab=add_count_lab, + vars_as_entered=vars_as_entered, + legend_position=legend_position, + font_size=font_size, + output_dir=output_dir, + ) + + return generate_barplot( + X, + y, + None, + label_name=label_name, + value_name=value_name, + groupby=groupby, + xlab=xlab, + ylab=ylab, + title=title, + col_set=col_set, + cols=cols, + prop=prop, + add_count_lab=add_count_lab, + vars_as_entered=vars_as_entered, + legend_position=legend_position, + font_size=font_size, + output_dir=output_dir, + ) + + +def plot_feature_distribution( + X, + input_columns: SelectedColumnTypes = None, + value_name=None, + aggregate=None, + aggregate_kwargs={}, + xlab: str = Unset('"X"'), + ylab: str = Unset('"Frequency"'), + title: str = Unset('"Histogram"'), + bins: int = Unset("30"), + font_size: int = Unset("8"), + col_fill: str = Unset('"grey40"'), + col_outline: str = Unset('"white"'), + xlog: Optional[str] = Unset("None"), + ylog: Optional[str] = Unset("None"), + output_dir: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, +): + """Plot the feature distribution of the input data. + + Args: + X: Input data + columns: Columns to plot + aggregate: Aggregate each feature by: 'mean', 'median', 'sum', 'std', 'var', 'min', 'max', or 'presence' + **kwargs: Additional keyword arguments + + Returns: + Plot object. 
+ """ + x_dims = DataHandler.get_shape(X) + precomputed = aggregate is None and ( + value_name is not None or len(x_dims) == 1 or x_dims[1] == 1 + ) + value_name = value_name or "Presence" + + if aggregate == "presence": + from biofit.stat import ColumnMissingnessStat + + num_rows = DataHandler.get_shape(X)[0] + missingness = ColumnMissingnessStat( + input_columns=input_columns, **aggregate_kwargs + ).fit_transform(X, output_format="pandas") + data = num_rows - missingness + elif aggregate == "sum": + from biofit.stat import ColumnSumStat + + data = ColumnSumStat( + input_columns=input_columns, **aggregate_kwargs + ).fit_transform(X, output_format="pandas") + elif aggregate == "mean": + from biofit.stat import ColumnMeanStat + + data = ColumnMeanStat( + input_columns=input_columns, **aggregate_kwargs + ).fit_transform(X, output_format="pandas") + + if ( + input_columns is None + and "biosets" in sys.modules + and isinstance(X, getattr(sys.modules["biosets"], "Bioset")) + ): + from biosets import get_data + + data = get_data(X) + else: + data = X + + data = DataHandler.to_pandas(data) + + if input_columns: + data = data[input_columns] + + if not precomputed: + data = data.melt(value_name=value_name) + + return generate_histogram( + data, + input_columns=input_columns, + xlab=xlab, + ylab=ylab, + title=title, + bins=bins, + font_size=font_size, + col_fill=col_fill, + col_outline=col_outline, + xlog=xlog, + ylog=ylog, + output_dir=output_dir, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + ) + + +def compare_feature_distributions( + X1, + X2, + columns1: SelectedColumnTypes = None, + columns2: SelectedColumnTypes = None, + value_name=None, + aggregate=None, + aggregate_kwargs={}, + xlab: Optional[str] = Unset("None"), + ylab: str = Unset("None"), + title: str = Unset("None"), + bins: int = Unset("None"), + alpha: float = Unset("None"), + legend_title: str = Unset("None"), + legend_position: str = Unset("None"), + subplot_title1: str = Unset("None"), + subplot_title2: str = Unset("None"), + col_set: str = Unset("None"), + cols: Optional[List[str]] = Unset("None"), + xlog: Optional[bool] = Unset("None"), + ylog: Optional[bool] = Unset("None"), + output_dir: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, +): + """Compare the feature distribution of the input data. + + Args: + X1: Input data 1 + X2: Input data 2 + columns1: Columns to plot for data 1 + columns2: Columns to plot for data 2. If not provided, columns1 will be used for data 2. + aggregate: Aggregate each feature by: 'mean', 'median', 'sum', 'std', 'var', 'min', 'max', or 'presence' + **kwargs: Additional keyword arguments + Returns: + Plot object. 
+ """ + x1_dims = DataHandler.get_shape(X1) + x2_dims = DataHandler.get_shape(X2) + precomputed = ( + aggregate is None + and (value_name is not None or len(x1_dims) == 1 or x1_dims[1] == 1) + and (value_name is not None or len(x2_dims) == 1 or x2_dims[1] == 1) + ) + + if columns1 and not columns2: + columns2 = columns1 + + if aggregate == "presence": + from biofit.stat import ColumnMissingnessStat + + value_name = value_name or "Presence" + num_rows1 = DataHandler.get_shape(X1)[0] + num_rows2 = DataHandler.get_shape(X2)[0] + + missingness1 = ColumnMissingnessStat( + input_columns=columns1, **aggregate_kwargs + ).fit_transform(X1, output_format="pandas") + missingness2 = ColumnMissingnessStat( + input_columns=columns2, **aggregate_kwargs + ).fit_transform(X2, output_format="pandas") + + data1 = num_rows1 - missingness1 + data2 = num_rows2 - missingness2 + elif aggregate == "sum": + from biofit.stat import ColumnSumStat + + value_name = value_name or "Sum" + data1 = ColumnSumStat(input_columns=columns1, **aggregate_kwargs).fit_transform( + X1, output_format="pandas" + ) + data2 = ColumnSumStat(input_columns=columns2, **aggregate_kwargs).fit_transform( + X2, output_format="pandas" + ) + elif aggregate == "mean": + from biofit.stat import ColumnMeanStat + + value_name = value_name or "Mean" + data1 = ColumnMeanStat( + input_columns=columns1, **aggregate_kwargs + ).fit_transform(X1, output_format="pandas") + data2 = ColumnMeanStat( + input_columns=columns2, **aggregate_kwargs + ).fit_transform(X2, output_format="pandas") + else: + data1 = X1 + data2 = X2 + + value_name = value_name or "Value" + + data1 = DataHandler.to_pandas(data1) + data2 = DataHandler.to_pandas(data2) + + if columns1: + data1 = data1[columns1] + + if columns2: + data2 = data2[columns2] + + if not precomputed: + if data1.shape[1] > 2: + data1 = data1.melt(value_name=value_name) + + if data2.shape[1] > 2: + data2 = data2.melt(value_name=value_name) + + return generate_comparison_histogram( + data1, + data2, + column1=value_name, + column2=value_name, + xlab=xlab, + ylab=ylab, + title=title, + bins=bins, + alpha=alpha, + legend_title=legend_title, + legend_position=legend_position, + col_set=col_set, + cols=cols, + subplot_title1=subplot_title1, + subplot_title2=subplot_title2, + xlog=xlog, + ylog=ylog, + output_dir=output_dir, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + ) + + +def plot_sample_distribution( + X, + input_columns: SelectedColumnTypes = None, + value_name=None, + aggregate=None, + aggregate_kwargs={}, + xlab: str = Unset('"X"'), + ylab: str = Unset('"Frequency"'), + title: str = Unset('"Histogram"'), + bins: int = Unset("30"), + font_size: int = Unset("8"), + col_fill: str = Unset('"grey40"'), + col_outline: str = Unset('"white"'), + xlog: Optional[str] = Unset("None"), + ylog: Optional[str] = Unset("None"), + output_dir: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, +): + """Plot the feature distribution of the input data. + + Args: + X: Input data + columns: Columns to plot + aggregate: Aggregate each feature by: 'mean', 'median', 'sum', 'std', 'var', 'min', 'max', or 'presence' + **kwargs: Additional keyword arguments + + Returns: + Plot object. 
+ """ + x_dims = DataHandler.get_shape(X) + precomputed = aggregate is None and ( + value_name is not None or len(x_dims) == 1 or x_dims[1] == 1 + ) + + if aggregate == "presence": + from biofit.stat import RowMissingnessStat + + value_name = value_name or "Presence" + num_rows = DataHandler.get_shape(X)[0] + missingness = RowMissingnessStat( + input_columns=input_columns, **aggregate_kwargs + ).fit_transform(X, output_format="pandas") + data = num_rows - missingness + elif aggregate == "sum": + from biofit.stat import RowSumStat + + value_name = value_name or "Sum" + data = RowSumStat( + input_columns=input_columns, **aggregate_kwargs + ).fit_transform(X, output_format="pandas") + elif aggregate == "mean": + from biofit.stat import RowMeanStat + + value_name = value_name or "Mean" + data = RowMeanStat( + input_columns=input_columns, **aggregate_kwargs + ).fit_transform(X, output_format="pandas") + + value_name = value_name or "Value" + + if ( + input_columns is None + and "biosets" in sys.modules + and isinstance(X, getattr(sys.modules["biosets"], "Bioset")) + ): + from biosets import get_data + + data = get_data(X) + else: + data = X + + data = DataHandler.to_pandas(data) + + if input_columns: + data = data[input_columns] + + if not precomputed: + if data.shape[1] > 2: + data = data.melt(value_name=value_name) + + return generate_histogram( + data, + input_columns=value_name, + xlab=xlab, + ylab=ylab, + title=title, + bins=bins, + font_size=font_size, + col_fill=col_fill, + col_outline=col_outline, + xlog=xlog, + ylog=ylog, + output_dir=output_dir, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + ) + + +def compare_sample_distributions( + X1, + X2, + columns1: SelectedColumnTypes = None, + columns2: SelectedColumnTypes = None, + value_name=None, + aggregate=None, + aggregate_kwargs={}, + xlab: Optional[str] = Unset("None"), + ylab: str = Unset("None"), + title: str = Unset("None"), + bins: int = Unset("None"), + alpha: float = Unset("None"), + legend_title: str = Unset("None"), + legend_position: str = Unset("None"), + subplot_title1: str = Unset("None"), + subplot_title2: str = Unset("None"), + col_set: str = Unset("None"), + cols: Optional[List[str]] = Unset("None"), + xlog: Optional[bool] = Unset("None"), + ylog: Optional[bool] = Unset("None"), + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + output_dir=None, +): + """Compare the feature distribution of the input data. + + Args: + X1: Input data 1 + X2: Input data 2 + columns1: Columns to plot for data 1 + columns2: Columns to plot for data 2. If not provided, columns1 will be used for data 2. + aggregate: Aggregate each feature by: 'mean', 'median', 'sum', 'std', 'var', 'min', 'max', or 'presence' + **kwargs: Additional keyword arguments + + Returns: + Plot object. 
+ """ + x1_dims = DataHandler.get_shape(X1) + x2_dims = DataHandler.get_shape(X2) + precomputed = ( + aggregate is None + and (value_name is not None or len(x1_dims) == 1 or x1_dims[1] == 1) + and (value_name is not None or len(x2_dims) == 1 or x2_dims[1] == 1) + ) + + if columns1 and not columns2: + columns2 = columns1 + + if aggregate == "presence": + from biofit.stat import RowMissingnessStat + + value_name = value_name or "Presence" + num_cols1 = x1_dims[1] if len(x1_dims) == 2 else 1 + num_cols2 = x2_dims[1] if len(x2_dims) == 2 else 1 + + missingness1 = RowMissingnessStat( + input_columns=columns1, keep_unused_columns=False, **aggregate_kwargs + ).fit_transform(X1, output_format="pandas") + missingness2 = RowMissingnessStat( + input_columns=columns2, keep_unused_columns=False, **aggregate_kwargs + ).fit_transform(X2, output_format="pandas") + + data1 = num_cols1 - missingness1 + data2 = num_cols2 - missingness2 + data1.columns = [value_name] + data2.columns = [value_name] + elif aggregate == "sum": + from biofit.stat import RowSumStat + + value_name = value_name or "Sum" + data1 = RowSumStat( + input_columns=columns1, keep_unused_columns=False, **aggregate_kwargs + ).fit_transform(X1, output_format="pandas") + data2 = RowSumStat( + input_columns=columns2, keep_unused_columns=False, **aggregate_kwargs + ).fit_transform(X2, output_format="pandas") + data1.columns = [value_name] + data2.columns = [value_name] + elif aggregate == "mean": + from biofit.stat import RowMeanStat + + value_name = value_name or "Mean" + data1 = RowMeanStat( + input_columns=columns1, keep_unused_columns=False, **aggregate_kwargs + ).fit_transform(X1, output_format="pandas") + data2 = RowMeanStat( + input_columns=columns2, keep_unused_columns=False, **aggregate_kwargs + ).fit_transform(X2, output_format="pandas") + data1.columns = [value_name] + data2.columns = [value_name] + + value_name = value_name or "Value" + + data1 = DataHandler.to_pandas(data1) + data2 = DataHandler.to_pandas(data2) + + if columns1: + data1 = data1[columns1] + + if columns2: + data2 = data2[columns2] + + if not precomputed: + if data1.shape[1] > 1: + data1 = data1.melt(value_name=value_name) + else: + data1.columns = [value_name] + + if data2.shape[1] > 1: + data2 = data2.melt(value_name=value_name) + else: + data1.columns = [value_name] + + if value_name not in data1.columns: + raise ValueError( + f"Column '{value_name}' not found in data1. Found columns: {data1.columns}" + ) + if value_name not in data2.columns: + raise ValueError( + f"Column '{value_name}' not found in data2. 
Found columns: {data2.columns}" + ) + + return generate_comparison_histogram( + data1, + data2, + column1=value_name, + column2=value_name, + xlab=xlab, + ylab=ylab, + title=title, + bins=bins, + alpha=alpha, + legend_title=legend_title, + legend_position=legend_position, + col_set=col_set, + cols=cols, + subplot_title1=subplot_title1, + subplot_title2=subplot_title2, + xlog=xlog, + ylog=ylog, + output_dir=output_dir, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + ) + + +def plot_dimension_reduction( + X, + labels=None, + group=None, + input_columns: SelectedColumnTypes = None, + label_column: SelectedColumnTypes = None, + group_column: SelectedColumnTypes = None, + method=None, + method_kwargs={}, + title: str = Unset('"Dimension Reduction Plot"'), + colormap: str = Unset('"nipy_spectral"'), + n_components: int = Unset("3"), + dimension_reducer=None, + output_dir: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show=True, +): + """Plot the dimension reduction plot. + + Args: + X: Input data + labels: Labels for the data + group: Group for the data + dimension_reducer: Dimension reducer object + **kwargs: Additional keyword arguments + + Returns: + Plot object. + """ + requires_backends(plot_sample_metadata, ["matplotlib", "seaborn"]) + if labels is None and label_column is None: + if "biosets" in sys.modules and isinstance( + X, getattr(sys.modules["biosets"], "Bioset") + ): + from biosets import get_target + + labels = get_target(X) + elif label_column: + labels = DataHandler.select_columns(X, label_column) + if "biosets" in sys.modules and isinstance( + X, getattr(sys.modules["biosets"], "Bioset") + ): + from biosets import decode + + labels = decode(labels) + + if input_columns is not None: + X = DataHandler.select_columns(X, input_columns) + input_columns = None + if method is None: + from biofit.preprocessing import PCAFeatureExtractor + + method = PCAFeatureExtractor + if isinstance(method, type): + method = method() + if isinstance(method, str): + from biofit.auto.processing_auto import AutoPreprocessor + + if method_kwargs is None or not isinstance(method_kwargs, dict): + method_kwargs = {} + + method = AutoPreprocessor.for_processor(method, **method_kwargs) + if not method.is_fitted: + data = method.fit_transform(X, load_from_cache_file=False) + else: + data = method.transform(X) + + from biofit.visualization.dimension_reduction import DimensionReductionPlotter + + return DimensionReductionPlotter().plot( + data, + labels, + group, + input_columns=input_columns, + label_column=label_column, + group_column=group_column, + n_components=n_components, + dimension_reducer=method, + show=show, + path=output_dir, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + ) + + +def get_feature_importances(models, label_names=None): + if not isinstance(models, list): + models = [models] + tables = [] + for fold, model in enumerate(models): + if isinstance(model, Pipeline): + _m = model[-1] + else: + _m = model + features = _m.config.feature_names_in_ + feature_importances = _m.feature_importances_ + classes = label_names + if classes is None: + if hasattr(_m, "config") and hasattr(_m.config, "class_names"): + classes = _m.config.class_names + if classes is None: + classes = getattr(_m, "classes_", None) + + if classes is not None and ( + (len(feature_importances) // 
len(classes)) == len(features) + ): + # this means its a one vs all classifier + _tables = [] + for i, c in enumerate(classes): + _tables.append( + pd.Series( + feature_importances[ + i * len(features) : (i + 1) * len(features) + ], + index=features, + name=f"importances_{fold + 1}_{c}_vs_all", + ) + ) + tables.extend(_tables) + else: + tables.append( + pd.DataFrame( + { + f"importances_{fold + 1}": feature_importances, + }, + index=features, + ) + ) + + if len(tables) == 1: + return tables[0] + + # Concatenate tables horizontally + return pd.concat(tables, axis=1, ignore_index=False, copy=False) + + +def plot_feature_importance( + feature_importances, + models=None, + X=None, + y=None, + sample_metadata=None, + feature_metadata: dict = None, + input_columns: SelectedColumnTypes = None, + target_columns: SelectedColumnTypes = None, + sample_metadata_columns: SelectedColumnTypes = None, + sample_column: str = None, + label_names: List[str] = None, + plot_top: int = Unset("15"), + feature_meta_name: str = Unset("None"), + feature_column: str = Unset("None"), + cols: List[str] = Unset("None"), + colHeat: List[str] = Unset("None"), + dat_log: str = Unset("field(default=None, init=True, repr=False)"), + show_column_names: bool = Unset("False"), + scale_legend_title: str = Unset('"Value"'), + column_title: str = Unset('"Samples"'), + row_title: str = Unset('"Features"'), + plot_title: str = Unset('"Values"'), + output_dir: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show=True, +): + """Plot the feature importance of the input data. + + Args: + X: Input data + models: List of models + y: Target variable + sample_metadata: Sample metadata + feature_metadata: Feature metadata + Returns: + Plot object. 
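+
+    Example:
+        Illustrative sketch only. trained_models stands for a hypothetical
+        list of fitted models (or sklearn Pipelines ending in one) that expose
+        feature_importances_; when feature_importances is None the importances
+        are derived from them via get_feature_importances. X and y are the
+        data used for the accompanying heatmap, and the R plotting backend is
+        assumed to be installed.
+
+            plot_feature_importance(
+                None,
+                models=trained_models,
+                X=X,
+                y=y,
+                plot_top=10,
+                output_dir="plots",
+            )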
+ """ + + from biofit.visualization.feature_importance import FeatureImportancePlotter + + if feature_importances is None and models is not None: + feature_importances = get_feature_importances(models, label_names) + feature_importances = feature_importances.reset_index(names=["features"]) + elif feature_importances is None: + raise ValueError("feature_importances or models must be provided") + + FeatureImportancePlotter().plot( + X, + y=y, + sample_metadata=sample_metadata, + input_columns=input_columns, + target_columns=target_columns, + sample_metadata_columns=sample_metadata_columns, + feature_importances=feature_importances, + feature_metadata=feature_metadata, + plot_top=plot_top, + feature_meta_name=feature_meta_name, + sample_column=sample_column, + feature_column=feature_column, + cols=cols, + colHeat=colHeat, + dat_log=dat_log, + show_column_names=show_column_names, + scale_legend_title=scale_legend_title, + column_title=column_title, + row_title=row_title, + plot_title=plot_title, + show=show, + path=output_dir, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + ) + + +def plot_sample_metadata( + sample_metadata, + outcome_data=None, + sample_metadata_columns: Optional[SelectedColumnTypes] = None, + outcome_column: Optional[SelectedColumnTypes] = None, + output_dir: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, +): + requires_backends(plot_sample_metadata, ["rpy2"]) + from biofit.visualization.sample_metadata import SampleMetadataPlotter + + SampleMetadataPlotter().plot( + sample_metadata, + outcome_data, + sample_metadata_columns=sample_metadata_columns, + outcome_column=outcome_column, + path=output_dir, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + ) diff --git a/src/biofit/visualization/report_generation.py b/src/biofit/visualization/report_generation.py new file mode 100644 index 0000000..8d9be83 --- /dev/null +++ b/src/biofit/visualization/report_generation.py @@ -0,0 +1,35 @@ +import os +import sys +from typing import TYPE_CHECKING, Optional + +import joblib +from biocore.utils.import_util import is_optuna_available + +if TYPE_CHECKING: + import optuna + + +def report_generation(data, study: Optional["optuna.Study"] = None, cache_dir=None): + if cache_dir is None and ( + ( + "biosets" in sys.modules + and isinstance(data, getattr(sys.modules["biosets"], "Bioset")) + ) + or ( + "datasets" in sys.modules + and isinstance(data, getattr(sys.modules["datasets"], "Dataset")) + ) + ): + cache_dir = os.path.dirname(data.cache_files[0]["filename"]) + + if study is None: + if cache_dir: + with open(os.path.join(cache_dir, "study.joblib"), "rb") as f: + study = joblib.load(f) + + if is_optuna_available() and study: + import optuna.visualization as ov + + if study is None: + raise ValueError("Study object is required for optuna report generation") + ov.plot_optimization_history(study) diff --git a/src/biofit/visualization/sample_metadata.py b/src/biofit/visualization/sample_metadata.py new file mode 100644 index 0000000..1794958 --- /dev/null +++ b/src/biofit/visualization/sample_metadata.py @@ -0,0 +1,128 @@ +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, List, Optional, Type + +from biocore import DataHandler + +from biofit.integration import RCaller +from biofit.integration.biosets import get_feature 
+from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.visualization.plotting import BasePlotter, PlotterConfig + +if TYPE_CHECKING: + from datasets import Dataset + + +@dataclass +class SampleMetadataPlotterConfig(PlotterConfig): + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("METADATA_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [ + get_feature("METADATA_FEATURE_TYPES"), + get_feature("TARGET_FEATURE_TYPES"), + ], + init=False, + repr=False, + ) + r_source: str = field( + default=(Path(__file__).parent / "plot_sample_metadata.R").as_posix(), + init=False, + repr=False, + ) + main_method: str = field(default="plot_sample_metadata", init=False, repr=False) + sample_metadata_columns: Optional[SelectedColumnTypes] = None + outcome_column: Optional[SelectedColumnTypes] = None + + +class SampleMetadataPlotter(BasePlotter): + _config_class = SampleMetadataPlotterConfig + config: SampleMetadataPlotterConfig + + def __init__( + self, + config: Optional[SampleMetadataPlotterConfig] = None, + sample_metadata_columns: Optional[SelectedColumnTypes] = None, + outcome_column: Optional[SelectedColumnTypes] = None, + install_missing: bool = None, + **kwargs, + ): + super().__init__( + config=config, + sample_metadata_columns=sample_metadata_columns, + outcome_column=outcome_column, + install_missing=install_missing, + **kwargs, + ) + r_source = (Path(__file__).parent / "plot_sample_metadata.R").as_posix() + r_caller = RCaller.from_script(r_source) + self.plotter = r_caller.get_method(self.config.main_method) + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + r_source = (Path(__file__).parent / "plot_sample_metadata.R").as_posix() + r_caller = RCaller.from_script(r_source) + self.plotter = r_caller.get_method(self.config.main_method) + return self + + def plot_dataset(self, X: "Dataset", y: "Dataset" = None): + from biosets import decode + + if y is not None and isinstance( + next(iter(y._info.features.values())), get_feature("ClassLabel") + ): + y = decode(y) + current_name = next(iter(y._info.features.keys())) + original_name = next(iter(y._info.features.values())).id or current_name + if current_name != original_name: + y = DataHandler.rename_column(y, current_name, original_name) + if original_name in X.column_names: + X = DataHandler.drop_column(X, original_name) + self.plot_arrow( + DataHandler.to_arrow(X), DataHandler.to_arrow(y) if y is not None else None + ) + + def plot_arrow(self, X, y): + if self.config.outcome_column is None and y is None: + raise ValueError( + "No outcome column provided. Please provide the outcome data " + "or specify the outcome column." 
+ ) + + if self.config.outcome_column is None: + self.config.outcome_column = list(y.column_names)[0] + self.plotter(X, outcome=y, **self.config.get_params()) + + def plot( + self, + sample_metadata, + outcome_data: SelectedColumnTypes = None, + sample_metadata_columns: Optional[SelectedColumnTypes] = None, + outcome_column: Optional[SelectedColumnTypes] = None, + path: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show: bool = True, + ): + self.config._input_columns = self._set_input_columns_and_arity( + sample_metadata_columns, outcome_column + ) + return self._plot( + sample_metadata, + outcome_data, + path=path, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + show=show, + ) diff --git a/src/biofit/visualization/scatterplot.py b/src/biofit/visualization/scatterplot.py new file mode 100644 index 0000000..e874e32 --- /dev/null +++ b/src/biofit/visualization/scatterplot.py @@ -0,0 +1,188 @@ +import textwrap +from dataclasses import dataclass, field +from typing import List, Optional, Type + +import biofit.config as config +from biofit.integration.biosets import get_feature +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.utils.types import Unset +from biofit.visualization.plotting import BasePlotter, PlotterConfig + + +@dataclass +class ScatterPlotConfig(PlotterConfig): + processor_type: str = field(default="scaling", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES"), None], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES"), None], + init=False, + repr=False, + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None, None], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None, None], + init=False, + repr=False, + ) + r_source: str = field( + default=(config.R_SCRIPTS / "plotting_utils.R").as_posix(), + init=False, + repr=False, + ) + main_method: str = field(default="generate_scatterplot", init=False, repr=False) + + groupby: str = None + xdata: str = None + ydata: str = None + xlab: str = None + ylab: str = None + title: str = "Scatterplot" + alpha: str = 1 + col_set: str = "Set1" + cols: List[str] = None + xlog: str = None + ylog: str = None + + +class ScatterPlotter(BasePlotter): + _config_class = ScatterPlotConfig + config: ScatterPlotConfig + + def __init__( + self, + groupby: str = None, + xdata: str = None, + ydata: str = None, + xlab: str = Unset("None"), + ylab: str = Unset("None"), + title: str = Unset('"Scatterplot"'), + alpha: str = Unset("1"), + col_set: str = Unset('"Set1"'), + cols: List[str] = Unset("None"), + xlog: str = Unset("None"), + ylog: str = Unset("None"), + install_missing: bool = None, + config: Optional[ScatterPlotConfig] = None, + ): + super().__init__( + config=config, + xlab=xlab, + ylab=ylab, + title=title, + alpha=alpha, + col_set=col_set, + cols=cols, + xlog=xlog, + ylog=ylog, + install_missing=install_missing, + ) + self.plotter = None + if self.config.r_source and self.config.main_method: + self.r_caller = RCaller.from_script(self.config.r_source) + enter_code = textwrap.dedent( + """ + 
suppressPackageStartupMessages(require(ggplot2)) + """ + ) + exit_code = textwrap.dedent( + """ + ggplot2::ggsave(path, results) + """ + ) + self.plotter = self.r_caller.get_method( + self.config.main_method, enter_code, exit_code + ) + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + if self.config.r_source and self.config.main_method: + self.r_caller = RCaller.from_script(self.config.r_source) + enter_code = textwrap.dedent( + """ + suppressPackageStartupMessages(require(ggplot2)) + """ + ) + exit_code = textwrap.dedent( + """ + ggplot2::ggsave(path, results) + """ + ) + self.plotter = self.r_caller.get_method( + self.config.main_method, enter_code, exit_code + ) + + return self + + def plot( + self, + x, + y=None, + group=None, + xdata: SelectedColumnTypes = None, + ydata: SelectedColumnTypes = None, + groupby: SelectedColumnTypes = None, + xlab: str = Unset("None"), + ylab: str = Unset("None"), + title: str = Unset('"Scatterplot"'), + alpha: str = Unset("1"), + col_set: str = Unset('"Set1"'), + cols: List[str] = Unset("None"), + xlog: str = Unset("None"), + ylog: str = Unset("None"), + path: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show: bool = True, + ): + self.config._input_columns = self._set_input_columns_and_arity( + xdata, ydata, groupby + ) + return self._plot( + x, + y, + group, + xlab=xlab, + ylab=ylab, + title=title, + alpha=alpha, + col_set=col_set, + cols=cols, + xlog=xlog, + ylog=ylog, + path=path, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + show=show, + ) + + def plot_dataset(self, x, y=None, group=None): + from biosets import Bioset, decode + + if isinstance(group, Bioset): + group = decode(group) + + return self.plot_arrow( + x._data.table, + y._data.table if y else None, + group._data.table if group else None, + ) + + def plot_arrow(self, x, y=None, group=None): + kwargs = self.config.get_params() + context_kwargs = { + "path": kwargs.pop("path", None), + } + self.plotter(x, y, group, context_kwargs=context_kwargs, **kwargs) diff --git a/src/biofit/visualization/violin.py b/src/biofit/visualization/violin.py new file mode 100644 index 0000000..786415f --- /dev/null +++ b/src/biofit/visualization/violin.py @@ -0,0 +1,140 @@ +import textwrap +from dataclasses import dataclass, field +from typing import List, Type + +import biofit.config as config +from biofit.integration.biosets import get_feature +from biofit.integration.R import RCaller +from biofit.processing import SelectedColumnTypes, sync_backup_config +from biofit.utils.types import Unset +from biofit.visualization.plotting import BasePlotter, PlotterConfig + + +@dataclass +class ViolinConfig(PlotterConfig): + processor_type: str = field(default="scaling", init=False, repr=False) + _fit_input_feature_types: List[Type] = field( + default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")], + init=False, + repr=False, + ) + _transform_input_feature_types: List[Type] = field( + default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")], + init=False, + repr=False, + ) + _fit_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None], + init=False, + repr=False, + ) + _transform_unused_feature_types: List[Type] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None], + init=False, + 
repr=False, + ) + + r_source: str = field( + default=(config.R_SCRIPTS / "plotting_utils.R").as_posix(), + init=False, + repr=False, + ) + main_method: str = field(default="generate_violin", init=False, repr=False) + + column: str = None + label_name: str = "labels" + xlab: str = "Labels" + ylab: str = "Value" + + +class ViolinPlotter(BasePlotter): + _config_class = ViolinConfig + config: ViolinConfig + + def __init__( + self, + column: str = None, + label_name: str = None, + xlab: str = Unset('"Labels"'), + ylab: str = Unset('"Value"'), + install_missing: bool = None, + config=None, + ): + super().__init__( + config=config, xlab=xlab, ylab=ylab, install_missing=install_missing + ) + self.plotter = None + if self.config.r_source and self.config.main_method: + self.r_caller = RCaller.from_script(self.config.r_source) + enter_code = textwrap.dedent( + """ + suppressPackageStartupMessages(require(ggplot2)) + """ + ) + exit_code = textwrap.dedent( + """ + ggplot2::ggsave(path, results) + """ + ) + self.plotter = self.r_caller.get_method( + self.config.main_method, enter_code, exit_code + ) + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + if self.config.r_source and self.config.main_method: + self.r_caller = RCaller.from_script(self.config.r_source) + enter_code = textwrap.dedent( + """ + suppressPackageStartupMessages(require(ggplot2)) + """ + ) + exit_code = textwrap.dedent( + """ + ggplot2::ggsave(path, results) + """ + ) + self.plotter = self.r_caller.get_method( + self.config.main_method, enter_code, exit_code + ) + + return self + + def plot( + self, + x, + y=None, + column: SelectedColumnTypes = None, + label_name: SelectedColumnTypes = None, + xlab: str = Unset('"Labels"'), + ylab: str = Unset('"Value"'), + path: str = None, + device: str = "pdf", + fingerprint: str = None, + unused_columns: SelectedColumnTypes = None, + raise_if_missing: bool = True, + show: bool = True, + ): + self.config._input_columns = self._set_input_columns_and_arity( + column, label_name + ) + return self._plot( + x, + y, + xlab=xlab, + ylab=ylab, + path=path, + device=device, + fingerprint=fingerprint, + unused_columns=unused_columns, + raise_if_missing=raise_if_missing, + show=show, + ) + + def plot_arrow(self, x, y=None): + self.plotter( + x, + y, + **self.config.get_params(), + ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..7b503c4 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,108 @@ +from pathlib import Path + +import pytest +from biocore.utils.import_util import is_biosets_available, is_datasets_available +from biofit.utils.logging import silence + +# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), './mock_packages'))) + +pytest_plugins = ["tests.fixtures.files", "tests.fixtures.fsspec"] + + +def pytest_collection_modifyitems(config, items): + # Mark tests as "unit" by default if not marked as "integration" (or already marked as "unit") + for item in items: + if any(marker in item.keywords for marker in ["integration", "unit"]): + continue + item.add_marker(pytest.mark.unit) + + +def pytest_configure(config): + config.addinivalue_line( + "markers", "torchaudio_latest: mark test to run with torchaudio>=0.12" + ) + + +@pytest.fixture(autouse=True) +def set_test_cache_config(tmp_path_factory, monkeypatch): + test_cache_home = tmp_path_factory.getbasetemp() / "cache" + 
test_patches_cache = test_cache_home / "patches" + test_datasets_cache = test_cache_home / "datasets" + test_processors_cache = test_cache_home / "processors" + test_modules_cache = test_cache_home / "modules" + monkeypatch.setattr("biofit.config.BIOFIT_CACHE_HOME", Path(test_cache_home)) + monkeypatch.setattr("biofit.config.BIOFIT_PATCHES_CACHE", Path(test_patches_cache)) + test_downloaded_datasets_path = test_datasets_cache / "downloads" + test_extracted_datasets_path = test_datasets_cache / "downloads" / "extracted" + if is_biosets_available(): + monkeypatch.setattr( + "biosets.config.BIOSETS_DATASETS_CACHE", Path(test_datasets_cache) + ) + + monkeypatch.setattr( + "biosets.config.DOWNLOADED_BIOSETS_PATH", + str(test_downloaded_datasets_path), + ) + + monkeypatch.setattr( + "biosets.config.EXTRACTED_BIOSETS_PATH", + str(test_extracted_datasets_path), + ) + + if is_datasets_available(): + monkeypatch.setattr( + "datasets.config.HF_DATASETS_CACHE", str(test_datasets_cache) + ) + monkeypatch.setattr("datasets.config.HF_MODULES_CACHE", str(test_modules_cache)) + monkeypatch.setattr( + "datasets.config.DOWNLOADED_DATASETS_PATH", + str(test_downloaded_datasets_path), + ) + monkeypatch.setattr( + "datasets.config.EXTRACTED_DATASETS_PATH", str(test_extracted_datasets_path) + ) + + monkeypatch.setattr( + "biofit.config.BIOFIT_PROCESSORS_CACHE", Path(test_processors_cache) + ) + monkeypatch.setattr("biofit.config.BIOFIT_MODULES_CACHE", Path(test_modules_cache)) + + +# @pytest.fixture(autouse=True, scope="session") +# def disable_tqdm_output(): +# disable_progress_bar() + + +# @pytest.fixture(autouse=True, scope="session") +# def set_info_verbosity(): +# set_verbosity_info() + + +@pytest.fixture(autouse=True, scope="session") +def silence_ouput(): + silence() + + +@pytest.fixture(autouse=True) +def set_update_download_counts_to_false(monkeypatch): + # don't take tests into account when counting downloads + if is_datasets_available(): + monkeypatch.setattr("datasets.config.HF_UPDATE_DOWNLOAD_COUNTS", False) + + +@pytest.fixture +def set_sqlalchemy_silence_uber_warning(monkeypatch): + # Required to suppress RemovedIn20Warning when feature(s) are not compatible with SQLAlchemy 2.0 + # To be removed once SQLAlchemy 2.0 supported + try: + monkeypatch.setattr("sqlalchemy.util.deprecations.SILENCE_UBER_WARNING", True) + except AttributeError: + pass + + +@pytest.fixture(autouse=True, scope="session") +def zero_time_out_for_remote_code(): + if is_datasets_available(): + import datasets.config + + datasets.config.TIME_OUT_REMOTE_CODE = 0 diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py new file mode 100644 index 0000000..e3022a8 --- /dev/null +++ b/tests/fixtures/files.py @@ -0,0 +1,1426 @@ +import contextlib +import csv +import json +import os +import sqlite3 +import tarfile +import textwrap +import zipfile +from decimal import Decimal + +import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +import pytest +from biocore.utils.import_util import ( + is_biosets_available, + requires_backends, +) +from biofit.integration.biosets import get_feature +from biofit.utils import enable_full_determinism +from biofit.utils.py_util import set_seed +from sklearn.datasets import make_classification + +# Constants +ALPHANUMERIC = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +SEED = 42 +enable_full_determinism(SEED) +PA_DATA = { + "null": 
lambda num_rows: np.array([None] * num_rows), + "bool": lambda num_rows: np.random.choice([True, False], size=num_rows), + "one_hot": lambda num_rows: np.random.choice([0, 1], size=num_rows), + "multi_bins": lambda num_rows: np.random.choice([0, 1], size=num_rows), + "int8": lambda num_rows: np.random.normal(loc=0, scale=127, size=num_rows).astype( + np.int8 + ), + "int16": lambda num_rows: np.random.normal( + loc=0, scale=32767, size=num_rows + ).astype(np.int16), + "int32": lambda num_rows: np.random.normal( + loc=0, scale=2147483647, size=num_rows + ).astype(np.int32), + "int64": lambda num_rows: np.random.normal( + loc=0, scale=9223372036854775807, size=num_rows + ).astype(np.int64), + "uint8": lambda num_rows: np.abs( + np.random.normal(loc=127.5, scale=127.5, size=num_rows) + ).astype(np.uint8), + "uint16": lambda num_rows: np.abs( + np.random.normal(loc=32767.5, scale=32767.5, size=num_rows) + ).astype(np.uint16), + "uint32": lambda num_rows: np.abs( + np.random.normal(loc=2147483647.5, scale=2147483647.5, size=num_rows) + ).astype(np.uint32), + "uint64": lambda num_rows: np.abs( + np.random.normal( + loc=9223372036854775807.5, scale=9223372036854775807.5, size=num_rows + ) + ).astype(np.uint64), + "float16": lambda num_rows: np.random.normal(size=num_rows).astype(np.float16), + "float32": lambda num_rows: np.random.normal(size=num_rows).astype(np.float32), + "float64": lambda num_rows: np.random.normal(size=num_rows), + "string": lambda num_rows: np.array( + [ + "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), size=5)) + for _ in range(num_rows) + ] + ), + "binary": lambda num_rows: np.array( + [ + bytes( + "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), size=5)), + "utf-8", + ) + for _ in range(num_rows) + ] + ), + "large_string": lambda num_rows: np.array( + [ + "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), size=5)) + for _ in range(num_rows) + ] + ), + "large_binary": lambda num_rows: np.array( + [ + bytes( + "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), size=5)), + "utf-8", + ) + for _ in range(num_rows) + ] + ), + "date32": lambda num_rows: np.array( + [np.random.randint(1, 2147483647, size=num_rows).tolist()] + ), + "date64": lambda num_rows: np.array( + [np.random.randint(1, 2147483647, size=num_rows).tolist()] + ), + "time32": lambda num_rows: np.random.normal( + loc=0, scale=2147483647, size=num_rows + ).astype(np.int32), + "time64": lambda num_rows: np.random.normal( + loc=0, scale=9223372036854775807, size=num_rows + ).astype(np.int64), + "timestamp": lambda num_rows: np.array( + [ + pd.Timestamp(np.random.choice(pd.date_range("2020-01-01", periods=365))) + for _ in range(num_rows) + ] + ), + "duration": lambda num_rows: np.random.normal( + loc=500, scale=250, size=num_rows + ).astype(int), + "decimal128": lambda num_rows: np.array( + [Decimal(np.random.randint(1, 1000)) for _ in range(num_rows)] + ), + "struct": lambda num_rows: [ + {"a": np.random.randint(1, 1000)} for _ in range(num_rows) + ], +} + + +PA_FIELDS = { + "null": pa.field("null", pa.null()), + "bool": pa.field("bool", pa.bool_()), + "int8": pa.field("int8", pa.int8()), + "int16": pa.field("int16", pa.int16()), + "int32": pa.field("int32", pa.int32()), + "int64": pa.field("int64", pa.int64()), + "uint8": pa.field("uint8", pa.uint8()), + "uint16": pa.field("uint16", pa.uint16()), + "uint32": pa.field("uint32", pa.uint32()), + "uint64": pa.field("uint64", pa.uint64()), + "float16": pa.field("float16", pa.float16()), + "float32": pa.field("float32", 
pa.float32()), + "float64": pa.field("float64", pa.float64()), + "string": pa.field("string", pa.string()), + "binary": pa.field("binary", pa.binary()), + "large_string": pa.field("large_string", pa.large_string()), + "large_binary": pa.field("large_binary", pa.large_binary()), + "date32": pa.field("date32", pa.date32()), + "date64": pa.field("date64", pa.date64()), + "time32": pa.field("time32", pa.time32("s")), + "time64": pa.field("time64", pa.time64("ns")), + "timestamp": pa.field("timestamp", pa.timestamp("s")), + "duration": pa.field("duration", pa.duration("s")), + "decimal128": pa.field("decimal128", pa.decimal128(38, 9)), + "struct": pa.field("struct", pa.struct([pa.field("a", pa.int32())])), +} + + +def create_directory(path): + """ + Create a directory if it doesn't exist. + """ + try: + os.makedirs(path, exist_ok=True) + except OSError as e: + print(f"Error creating directory {path}: {e}") + raise + + +def _create_all_arrow_types_dataframe(num_rows=100, feature_type=None): + from datasets.features import Value + + if feature_type is None: + feature_type = Value + data = {k: v(num_rows) for k, v in PA_DATA.items()} + + features = { + k: feature_type(k) if k != "struct" else {"a": feature_type("int32")} + for k in PA_FIELDS.keys() + } + + return data, features + + +def create_omic_dataset( + num_rows=100, + num_cols=None, + dtype="all", + sample="sample_id", + batch="batch", + label="label", + multi_class=False, + task="classification", + label_type="int", + metadata=True, + input_feature=None, + sparse=False, + missing_labels=False, +): + """ + Create a sample dataframe with predefined structure. + """ + from datasets import Value + + if input_feature is None: + input_feature = Value + data = {} + features = {} + enable_full_determinism(SEED) + + if sample: + data[sample] = [str(i) for i in range(num_rows)] + features[sample] = get_feature("Sample")("string") + if batch: + data[batch] = [str(i) for i in range(num_rows)] + features[batch] = get_feature("Batch")("string") + metadata_value_options = { + "multi_classification_int": [i % 3 for i in range(num_rows)], + "multi_classification_str": [ + ALPHANUMERIC[i % len(ALPHANUMERIC)] for i in range(num_rows) + ], + "bin_classification_bool": [ + True if i % 2 == 0 else False for i in range(num_rows) + ], + "bin_classification_int": [i % 2 for i in range(num_rows)], + "bin_classification_str": [ + "positive" if i > num_rows // 2 else "negative" for i in range(num_rows) + ], + "regression": np.random.randn(num_rows), + } + metadata_feature_options = { + "multi_classification_int": "int8", + "multi_classification_str": "string", + "bin_classification_bool": "bool", + "bin_classification_int": "int8", + "bin_classification_str": "string", + "regression": "float32", + } + if label: + if task == "classification": + if multi_class: + label_name = "multi" + else: + label_name = "bin" + label_name += f"_classification_{label_type}" + else: + label_name = "regression" + + data[label] = metadata_value_options.pop(label_name) + if missing_labels: + # Randomly set 10% of labels to -1 if classification, else set to None + if task == "classification": + indices_to_replace = np.random.choice( + num_rows, int(num_rows * 0.1), replace=False + ) + data[label] = [ + -1 if i in indices_to_replace else lab + for i, lab in enumerate(data[label]) + ] + else: + indices_to_replace = np.random.choice( + num_rows, int(num_rows * 0.1), replace=False + ) + data[label] = [ + None if i in indices_to_replace else lab + for i, lab in enumerate(data[label]) + ] + 
label_dtype = metadata_feature_options.pop(label_name) + if label_name == "regression": + features[label] = get_feature("RegressionTarget")(label_dtype) + else: + names = list(set(data[label])) + if not isinstance(names[0], str): + names = [str(n) for n in names] + else: + name_map = {n: i for i, n in enumerate(names)} + data[label] = [name_map[n] for n in data[label]] + + num_classes = len(names) + + features[label] = get_feature("ClassLabel")( + num_classes=num_classes, names=names + ) + if metadata: + if isinstance(metadata, str): + data.update(metadata_value_options) + features.update( + { + k: get_feature("Metadata")(dtype=v) + for k, v in metadata_feature_options.items() + } + ) + else: + for label, v in metadata_value_options.items(): + data[label] = v + features[label] = Value(metadata_feature_options[label]) + + if dtype == "all": + ext_data, ext_features = _create_all_arrow_types_dataframe( + num_rows=num_rows, feature_type=input_feature + ) + else: + ext_data = {} + ext_features = {} + if sparse and isinstance(sparse, bool): + sparse = 0.8 + if num_cols is None: + num_cols = 1 + + dtype_to_pa = { + "multi_bins": "int32", + "one_hot": "int32", + } + + for i in range(num_cols): + arr: np.ndarray = PA_DATA[dtype](num_rows) + ext_data[f"{dtype}_{i}"] = arr.tolist() + + ext_features[f"{dtype}_{i}"] = input_feature( + dtype=dtype_to_pa.get(dtype, dtype), + metadata={ + "my_metadata_str": ALPHANUMERIC[ + np.random.randint(0, len(ALPHANUMERIC)) + ], + "my_metadata_int": np.random.randint(0, 100), + }, + ) + if sparse: + mat = np.array([ext_data[f"{dtype}_{i}"] for i in range(num_cols)]).T + for i in range(num_cols): + arr = np.array(ext_data[f"{dtype}_{i}"]) + total_values = arr.size + if isinstance(sparse, list): + _sparse = sparse[i] + else: + _sparse = sparse + if isinstance(_sparse, bool): + _sparse = np.random.uniform(0.1, 0.9) + + num_to_replace = max( + min(int(total_values * (_sparse)), total_values - 1), 0 + ) + indices_to_replace = np.random.choice( + total_values, num_to_replace, replace=False + ) + # check if replacing with 0 would make a row all 0s + for idx in indices_to_replace: + if dtype in [ + "one_hot", + "multi_bins", + "uint8", + "uint16", + "uint32", + "uint64", + ]: + if np.sum(mat[:idx] > 0) + np.sum(mat[idx + 1 :] > 0) == 0: + indices_to_replace = np.delete( + indices_to_replace, np.where(indices_to_replace == idx) + ) + else: + arr[idx] = 0 + else: + if all(v is None for v in mat[:idx]) and all( + v is None for v in mat[idx + 1 :] + ): + indices_to_replace = np.delete( + indices_to_replace, np.where(indices_to_replace == idx) + ) + else: + arr[idx] = 0 + ext_data[f"{dtype}_{i}"] = arr.tolist() + + data.update(ext_data) + features.update(ext_features) + if is_biosets_available(): + import biosets + import datasets + + return biosets.Bioset.from_dict(data, features=datasets.Features(features)) + return pd.DataFrame(data) + + +def create_feature_dataframe(num_cols=100, feature_id="feature"): + """ + Create a feature dataframe with predefined structure. + """ + enable_full_determinism(SEED) + data = { + feature_id: [str(i) for i in range(num_cols)], + } + + fields = [ + pa.field(feature_id, pa.string()), + ] + + ext_data, ext_fields = _create_all_arrow_types_dataframe(num_rows=num_cols) + data.update(ext_data) + fields.extend(ext_fields) + + return pa.table(data, schema=pa.schema(fields)) + + +def directory_exists_with_files(path, expected_files): + """ + Check if a directory exists with the expected files. 
+ """ + if not os.path.exists(path): + return False + if not all(os.path.exists(os.path.join(path, file)) for file in expected_files): + return False + return True + + +# def save_dataframes(dfs, data_dir, filenames): +# """ +# Save a list of dataframes to CSV in the specified directory. +# """ +# for df, filename in zip(dfs, filenames): +# file_ext = filename.split(".")[-1] +# if file_ext in ["parquet"]: +# tbl = pa.Table.from_pandas(df) if isinstance(df, pd.DataFrame) else df +# if "float16" in tbl.schema.names: +# tbl = tbl.drop(["float16"]) # not supported by parquet +# writer = ParquetWriter( +# path=os.path.join(data_dir, filename), schema=tbl.schema +# ) +# writer.write_table(tbl) +# elif file_ext in ["arrow"]: +# tbl = pa.Table.from_pandas(df) if isinstance(df, pd.DataFrame) else df +# writer = ArrowWriter( +# path=os.path.join(data_dir, filename), schema=tbl.schema +# ) +# writer.write_table(tbl) +# elif file_ext in ["csv"]: +# df.to_csv(os.path.join(data_dir, filename), index=False) +# elif file_ext in ["tsv", "txt"]: +# df.to_csv(os.path.join(data_dir, filename), sep="\t", index=False) + + +# def create_fake_data_dir(data, base_dir, overwrite=False): +# for name, filenames, dfs, _ in data: +# data_dir = f"{base_dir}/{name}" +# os.makedirs(data_dir, exist_ok=True) +# if not directory_exists_with_files(data_dir, filenames) or overwrite: +# save_dataframes(dfs, data_dir, filenames) + + +def create_dataset_with_sklearn( + path, + experiment_type, + n_features=20, + n_samples=50, + multi_class=False, + lab_as_str=True, + dataset_type="snp", + label_column="label", + add_missing_labels=False, +): + if multi_class: + num_classes = 3 + else: + num_classes = 2 + + samples, labs, names = create_data( + n_samples, + n_features, + num_classes, + experiment_type, + lab_as_str=lab_as_str, + _add_missing_labels=add_missing_labels, + ) + data = create_sample_metadata(n_samples) + features = None + if is_biosets_available(): + if experiment_type == "snp": + features = create_features( + n_features, get_feature("GenomicVariant"), "int8", num_classes, names + ) + elif experiment_type in ["otu", "asv"]: + features = create_features( + n_features, get_feature("Abundance"), "int32", num_classes, names + ) + elif experiment_type == "maldi": + features = create_features( + n_features, get_feature("PeakIntensity"), "int32", num_classes, names + ) + + data = pd.concat( + [ + data, + pd.DataFrame(samples, columns=[f"int32_{i}" for i in range(n_features)]), + pd.DataFrame(labs.reshape(-1, 1), columns=[label_column]), + ], + axis=1, + ) + if is_biosets_available(): + import biosets + import datasets + + ds = biosets.Dataset.from_pandas(data, features=datasets.Features(features)) + + ds.info.builder_name = dataset_type + os.makedirs(path, exist_ok=True) + ds.save_to_disk(path) + else: + data.to_csv(path, index=False) + + +def _add_missing_labels(labs, num_classes=2, lab_as_str=True, n_samples=50): + if isinstance(labs, np.ndarray): + new_labs = labs.tolist() + else: + new_labs = labs + # make 10% of labels missing for each class + for i in range(num_classes): + if lab_as_str: + indices_to_replace = np.random.choice( + np.where(labs == ALPHANUMERIC[i % len(ALPHANUMERIC)])[0], + int(n_samples * 0.1), + replace=False, + ) + for idx in indices_to_replace: + new_labs[idx] = None + else: + indices_to_replace = np.random.choice( + np.where(labs == i)[0], int(n_samples * 0.1), replace=False + ) + for idx in indices_to_replace: + new_labs[idx] = -1 + return np.array(new_labs) + + +def create_data( + n_samples, + 
n_features, + num_classes, + experiment_type, + lab_as_str=False, + add_missing_labels=False, +): + samples, labs = make_classification( + n_samples=n_samples, + n_features=n_features, + n_classes=num_classes, + n_informative=int(n_features * 0.8), + n_redundant=int(n_features * 0.1), + random_state=SEED, + ) + + if experiment_type in ["snp"]: + median = np.median(samples.flatten()) + samples[samples < median] = 0 + samples[samples >= median] = 1 + samples = samples.astype(np.int32) + elif experiment_type in ["otu", "asv", "maldi"]: + samples = np.abs(np.round(np.quantile(samples.flatten(), 0.25))) + samples + samples[samples < 0] = 0 + samples = samples.astype(np.int32) + if lab_as_str: + labs = np.array( + [ALPHANUMERIC[i % len(ALPHANUMERIC)] for i in sorted(labs.tolist())] + ) + names = [ALPHANUMERIC[i % len(ALPHANUMERIC)] for i in range(num_classes)] + else: + names = [str(i) for i in range(num_classes)] + labs = labs.astype(np.int32) + + if add_missing_labels: + labs = _add_missing_labels( + labs, num_classes=num_classes, lab_as_str=lab_as_str, n_samples=n_samples + ) + + samples = pd.DataFrame(samples, columns=[f"feature_{i}" for i in range(n_features)]) + return samples, labs, names + + +def create_features( + n_features, + FeatureType=None, + dtype=None, + num_classes=None, + names=None, + with_metadata=True, + as_dict=False, +): + metadata_feature_options = { + "multi_int": "int32", + "multi_str": "string", + "bin_bool": "bool", + "bin_int": "int32", + "bin_str": "string", + "floating": "float64", + } + if not as_dict: + requires_backends("create_features_for_sample_metadata", "biosets") + + features = {} + if with_metadata: + features.update( + { + "sample_id": get_feature("Sample")("string"), + } + ) + features.update( + { + k: get_feature("Metadata")(dtype=v) + for k, v in metadata_feature_options.items() + } + ) + + for i in range(n_features): + features[f"feature_{i}"] = FeatureType( + dtype="int32", + metadata={ + "my_metadata_str": ALPHANUMERIC[ + np.random.randint(0, len(ALPHANUMERIC)) + ], + "my_metadata_int": np.random.randint(0, 100), + }, + ) + if num_classes is not None: + if num_classes > 2: + features["label"] = get_feature("ClassLabel")( + num_classes=num_classes, names=names + ) + else: + features["label"] = get_feature("BinClassLabel")( + num_classes=num_classes, names=names + ) + + else: + features = {} + if with_metadata: + features = { + "sample_id": {}, + } + features.update({k: {} for k in metadata_feature_options.keys()}) + + for i in range(n_features): + features[f"feature_{i}"] = { + "my_metadata_str": ALPHANUMERIC[ + np.random.randint(0, len(ALPHANUMERIC)) + ], + "my_metadata_int": np.random.randint(0, 100), + } + if num_classes is not None: + features["label"] = { + "num_classes": num_classes, + "names": names, + } + + return features + + +def create_sample_metadata(n_samples=100): + return pd.DataFrame( + { + "sample_id": [f"s{i}" for i in range(n_samples)], + "multi_int": [i % 3 for i in range(n_samples)], + "multi_str": [ + ALPHANUMERIC[i % len(ALPHANUMERIC)] for i in range(n_samples) + ], + "bin_bool": [True if i % 2 == 0 else False for i in range(n_samples)], + "bin_int": [i % 2 for i in range(n_samples)], + "bin_str": [ + "positive" if i > n_samples // 2 else "negative" + for i in range(n_samples) + ], + "floating": np.random.randn(n_samples), + } + ) + + +@pytest.fixture(scope="session") +def count_data(): + data, y, _ = create_data(20, 5, 2, "otu") + y = pd.DataFrame(y[:, None], columns=["label"]) + return data, y + + 
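+# Example usage (illustrative sketch only): how the helpers above compose into
+# a small labelled table, mirroring the count_data fixture. The variable names
+# below are arbitrary; create_data returns a feature DataFrame, a label array,
+# and the class names.
+#
+# samples, labs, names = create_data(
+#     n_samples=20, n_features=5, num_classes=2, experiment_type="otu"
+# )
+# metadata = create_sample_metadata(n_samples=20)
+# table = pd.concat(
+#     [metadata, samples, pd.DataFrame(labs, columns=["label"])], axis=1
+# )
+
+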
+@pytest.fixture(scope="session") +def count_data_multi_class(): + data, y, _ = create_data(20, 5, 3, "otu") + y = pd.DataFrame(y[:, None], columns=["label"]) + return data, y + + +@pytest.fixture(scope="session") +def count_data_missing_labels(): + data, y, _ = create_data(20, 5, 2, "otu", add_missing_labels=True) + y = pd.DataFrame(y[:, None], columns=["label"]) + return data, y + + +@pytest.fixture(scope="session") +def binary_data(): + data, y, _ = create_data(20, 5, 2, "snp") + y = pd.DataFrame(y[:, None], columns=["label"]) + return data, y + + +@pytest.fixture(scope="session") +def float_data(): + data, y, _ = create_data(20, 5, 2, None) + y = pd.DataFrame(y[:, None], columns=["label"]) + return data, y + + +@pytest.fixture(scope="session") +def classification_data(): + data, y, _ = create_data(100, 5, 2, None) + y = pd.DataFrame(y[:, None], columns=["label"]) + return data, y + + +@pytest.fixture(scope="session") +def classification_data_multi_class(): + data, y, _ = create_data(100, 5, 3, None) + y = pd.DataFrame(y[:, None], columns=["label"]) + return data, y + + +@pytest.fixture(scope="session") +def feature_metadata(): + return create_features(5, as_dict=True, with_metadata=False) + + +@pytest.fixture(scope="session") +def sample_metadata(): + return create_sample_metadata(20) + + +@pytest.fixture(scope="session") +def biodataset(): + return create_omic_dataset(10, num_cols=3, dtype="float32", metadata="metadata") + + +@pytest.fixture(scope="session") +def snp_dataset_path(tmp_path_factory): + set_seed(SEED) + path = str(tmp_path_factory.mktemp("data") / "SNP") + create_dataset_with_sklearn(path, "snp") + return path + + +@pytest.fixture(scope="session") +def otu_dataset_path(tmp_path_factory): + set_seed(SEED) + path = str(tmp_path_factory.mktemp("data") / "OTU") + create_dataset_with_sklearn(path) + return path + + +@pytest.fixture(scope="session") +def otu_dataset_missing_labels_path(tmp_path_factory): + set_seed(SEED) + path = str(tmp_path_factory.mktemp("data") / "OTU") + create_dataset_with_sklearn(path, "otu", lab_as_str=False, add_missing_labels=True) + return path + + +@pytest.fixture(scope="session") +def otu_dataset_multi_class_path(tmp_path_factory): + set_seed(SEED) + path = str(tmp_path_factory.mktemp("data") / "OTU") + create_dataset_with_sklearn(path, "otu", multi_class=True) + return path + + +@pytest.fixture(scope="session") +def maldi_dataset_path(tmp_path_factory): + set_seed(SEED) + path = str(tmp_path_factory.mktemp("data") / "MALDI") + create_dataset_with_sklearn(path, "maldi") + return path + + +@pytest.fixture(scope="session") +def snp_dataset(): + ds = create_omic_dataset( + num_rows=10, + num_cols=3, + dtype="multi_bins", + metadata="metadata", + input_feature=get_feature("GenomicVariant"), + sparse=0.8, + ) + ds.info.builder_name = "snp" + return ds + + +@pytest.fixture(scope="session") +def maldi_dataset(): + ds = create_omic_dataset( + num_rows=10, + num_cols=3, + dtype="multi_bins", + metadata="metadata", + input_feature=get_feature("PeakIntensity"), + sparse=0.8, + ) + ds.info.builder_name = "maldi" + return ds + + +# @pytest.fixture(scope="session") +# def camda_dataset(): +# camda_dir = "./tests/data/CAMDA" +# camda_metadata_files = os.path.join(camda_dir, "camda.pheno.csv") +# camda_feature_metadata_files = os.path.join(camda_dir, "camda.feature.csv") +# ds = load_dataset( +# "otu", +# data_dir=camda_dir, +# sample_metadata_files=camda_metadata_files, +# feature_metadata_files=camda_feature_metadata_files, +# label_column="City2", +# 
cache_dir="./.cache", +# ) +# ds.cleanup_cache_files() +# return ds +# + +# @pytest.fixture(scope="session") +# def camda_dataset_files_only(): +# camda_dir = "./tests/data/CAMDA" +# data_files = os.path.join(camda_dir, "*matrix*.csv") +# return load_dataset( +# dataset_type="otu", +# name="camda", +# data_files=data_files, +# label_column="City2", +# ) + + +# @pytest.fixture(scope="session") +# def camda_dataset_no_polars(): +# camda_dir = "./tests/data/CAMDA" +# camda_metadata_files = os.path.join(camda_dir, "camda.pheno.csv") +# camda_feature_metadata_files = os.path.join(camda_dir, "camda.feature.csv") +# return load_dataset( +# "otu", +# data_dir=camda_dir, +# sample_metadata_files=camda_metadata_files, +# feature_metadata_files=camda_feature_metadata_files, +# label_column="City2", +# cache_dir="./.cache", +# use_polars=False, +# ) +# + +# @pytest.fixture(scope="session") +# def tb_dataset(): +# tb_dir = "./tests/data/genomics_TB" +# dataset = load_dataset( +# "snp", +# "TB", +# data_dir=tb_dir, +# label_column="Isoniazid", +# keep_in_memory=False, +# cache_dir="./.cache", +# ) +# dataset.cleanup_cache_files() +# return dataset +# + + +@pytest.fixture(scope="session") +def arrow_file(tmp_path_factory, dataset): + filename = str(tmp_path_factory.mktemp("data") / "file.arrow") + dataset.map(cache_file_name=filename) + return filename + + +# FILE_CONTENT + files + + +FILE_CONTENT = """\ + Text data. + Second line of data.""" + + +@pytest.fixture(scope="session") +def text_file(tmp_path_factory): + filename = tmp_path_factory.mktemp("data") / "file.txt" + data = FILE_CONTENT + with open(filename, "w") as f: + f.write(data) + return filename + + +@pytest.fixture(scope="session") +def bz2_file(tmp_path_factory): + import bz2 + + path = tmp_path_factory.mktemp("data") / "file.txt.bz2" + data = bytes(FILE_CONTENT, "utf-8") + with bz2.open(path, "wb") as f: + f.write(data) + return path + + +@pytest.fixture(scope="session") +def gz_file(tmp_path_factory): + import gzip + + path = str(tmp_path_factory.mktemp("data") / "file.txt.gz") + data = bytes(FILE_CONTENT, "utf-8") + with gzip.open(path, "wb") as f: + f.write(data) + return path + + +@pytest.fixture(scope="session") +def lz4_file(tmp_path_factory): + try: + import lz4.frame + + path = tmp_path_factory.mktemp("data") / "file.txt.lz4" + data = bytes(FILE_CONTENT, "utf-8") + with lz4.frame.open(path, "wb") as f: + f.write(data) + return path + except ImportError: + pytest.skip("lz4 not available") + + +@pytest.fixture(scope="session") +def seven_zip_file(tmp_path_factory, text_file): + try: + import py7zr + + path = tmp_path_factory.mktemp("data") / "file.txt.7z" + with py7zr.SevenZipFile(path, "w") as archive: + archive.write(text_file, arcname=os.path.basename(text_file)) + return path + except ImportError: + pytest.skip("py7zr not available") + + +@pytest.fixture(scope="session") +def tar_file(tmp_path_factory, text_file): + import tarfile + + path = tmp_path_factory.mktemp("data") / "file.txt.tar" + with tarfile.TarFile(path, "w") as f: + f.add(text_file, arcname=os.path.basename(text_file)) + return path + + +@pytest.fixture(scope="session") +def xz_file(tmp_path_factory): + import lzma + + path = tmp_path_factory.mktemp("data") / "file.txt.xz" + data = bytes(FILE_CONTENT, "utf-8") + with lzma.open(path, "wb") as f: + f.write(data) + return path + + +@pytest.fixture(scope="session") +def zip_file(tmp_path_factory, text_file): + import zipfile + + path = tmp_path_factory.mktemp("data") / "file.txt.zip" + with zipfile.ZipFile(path, "w") 
as f: + f.write(text_file, arcname=os.path.basename(text_file)) + return path + + +@pytest.fixture(scope="session") +def zstd_file(tmp_path_factory): + try: + import zstandard as zstd + + path = tmp_path_factory.mktemp("data") / "file.txt.zst" + data = bytes(FILE_CONTENT, "utf-8") + with zstd.open(path, "wb") as f: + f.write(data) + return path + except ImportError: + pytest.skip("zstandard not available") + + +# xml_file + + +@pytest.fixture(scope="session") +def xml_file(tmp_path_factory): + filename = tmp_path_factory.mktemp("data") / "file.xml" + data = textwrap.dedent( + """\ + + +
+    <?xml version="1.0" encoding="UTF-8" ?>
+    <tmx version="1.4">
+      <header segtype="sentence" srclang="ca" />
+      <body>
+        <tu>
+          <tuv xml:lang="ca"><seg>Contingut 1</seg></tuv>
+          <tuv xml:lang="en"><seg>Content 1</seg></tuv>
+        </tu>
+        <tu>
+          <tuv xml:lang="ca"><seg>Contingut 2</seg></tuv>
+          <tuv xml:lang="en"><seg>Content 2</seg></tuv>
+        </tu>
+        <tu>
+          <tuv xml:lang="ca"><seg>Contingut 3</seg></tuv>
+          <tuv xml:lang="en"><seg>Content 3</seg></tuv>
+        </tu>
+        <tu>
+          <tuv xml:lang="ca"><seg>Contingut 4</seg></tuv>
+          <tuv xml:lang="en"><seg>Content 4</seg></tuv>
+        </tu>
+        <tu>
+          <tuv xml:lang="ca"><seg>Contingut 5</seg></tuv>
+          <tuv xml:lang="en"><seg>Content 5</seg></tuv>
+        </tu>
+      </body>
+    </tmx>
+    """
+    )
+    with open(filename, "w") as f:
+        f.write(data)
+    return filename
+
+
+DATA = [
+    {"col_1": "0", "col_2": 0, "col_3": 0.0},
+    {"col_1": "1", "col_2": 1, "col_3": 1.0},
+    {"col_1": "2", "col_2": 2, "col_3": 2.0},
+    {"col_1": "3", "col_2": 3, "col_3": 3.0},
+]
+DATA2 = [
+    {"col_1": "4", "col_2": 4, "col_3": 4.0},
+    {"col_1": "5", "col_2": 5, "col_3": 5.0},
+]
+DATA_DICT_OF_LISTS = {
+    "col_1": ["0", "1", "2", "3"],
+    "col_2": [0, 1, 2, 3],
+    "col_3": [0.0, 1.0, 2.0, 3.0],
+}
+
+DATA_312 = [
+    {"col_3": 0.0, "col_1": "0", "col_2": 0},
+    {"col_3": 1.0, "col_1": "1", "col_2": 1},
+]
+
+DATA_STR = [
+    {"col_1": "s0", "col_2": 0, "col_3": 0.0},
+    {"col_1": "s1", "col_2": 1, "col_3": 1.0},
+    {"col_1": "s2", "col_2": 2, "col_3": 2.0},
+    {"col_1": "s3", "col_2": 3, "col_3": 3.0},
+]
+
+
+@pytest.fixture(scope="session")
+def dataset_dict():
+    return DATA_DICT_OF_LISTS
+
+
+@pytest.fixture(scope="session")
+def arrow_path(tmp_path_factory):
+    import datasets
+
+    dataset = datasets.Dataset.from_dict(DATA_DICT_OF_LISTS)
+    path = str(tmp_path_factory.mktemp("data") / "dataset.arrow")
+    dataset.map(cache_file_name=path)
+    return path
+
+
+@pytest.fixture(scope="session")
+def sqlite_path(tmp_path_factory):
+    path = str(tmp_path_factory.mktemp("data") / "dataset.sqlite")
+    with contextlib.closing(sqlite3.connect(path)) as con:
+        cur = con.cursor()
+        cur.execute("CREATE TABLE dataset(col_1 text, col_2 int, col_3 real)")
+        for item in DATA:
+            cur.execute(
+                "INSERT INTO dataset(col_1, col_2, col_3) VALUES (?, ?, ?)",
+                tuple(item.values()),
+            )
+        con.commit()
+    return path
+
+
+@pytest.fixture(scope="session")
+def csv_path(tmp_path_factory):
+    path = str(tmp_path_factory.mktemp("data") / "dataset.csv")
+    with open(path, "w", newline="") as f:
+        writer = csv.DictWriter(f, fieldnames=["col_1", "col_2", "col_3"])
+        writer.writeheader()
+        for item in DATA:
+            writer.writerow(item)
+    return path
+
+
+@pytest.fixture(scope="session")
+def csv2_path(tmp_path_factory):
+    path = str(tmp_path_factory.mktemp("data") / "dataset2.csv")
+    with open(path, "w", newline="") as f:
+        writer = csv.DictWriter(f, fieldnames=["col_1", "col_2", "col_3"])
+        writer.writeheader()
+        for item in DATA:
+            writer.writerow(item)
+    return path
+
+
+@pytest.fixture(scope="session")
+def bz2_csv_path(csv_path, tmp_path_factory):
+    import bz2
+
+    path = tmp_path_factory.mktemp("data") / "dataset.csv.bz2"
+    with open(csv_path, "rb") as f:
+        data = f.read()
+    # data = bytes(FILE_CONTENT, "utf-8")
+    with bz2.open(path, "wb") as f:
+        f.write(data)
+    return path
+
+
+@pytest.fixture(scope="session")
+def zip_csv_path(csv_path, csv2_path, tmp_path_factory):
+    path = tmp_path_factory.mktemp("zip_csv_path") / "csv-dataset.zip"
+    with zipfile.ZipFile(path, "w") as f:
+        f.write(csv_path, arcname=os.path.basename(csv_path))
+        f.write(csv2_path, arcname=os.path.basename(csv2_path))
+    return path
+
+
+@pytest.fixture(scope="session")
+def zip_uppercase_csv_path(csv_path, csv2_path, tmp_path_factory):
+    path = tmp_path_factory.mktemp("data") / "dataset.csv.zip"
+    with zipfile.ZipFile(path, "w") as f:
+        f.write(csv_path, arcname=os.path.basename(csv_path.replace(".csv", ".CSV")))
+        f.write(csv2_path, arcname=os.path.basename(csv2_path.replace(".csv", ".CSV")))
+    return path
+
+
+@pytest.fixture(scope="session")
+def zip_csv_with_dir_path(csv_path, csv2_path, tmp_path_factory):
+    path = 
tmp_path_factory.mktemp("data") / "dataset_with_dir.csv.zip" + with zipfile.ZipFile(path, "w") as f: + f.write(csv_path, arcname=os.path.join("main_dir", os.path.basename(csv_path))) + f.write( + csv2_path, arcname=os.path.join("main_dir", os.path.basename(csv2_path)) + ) + return path + + +@pytest.fixture(scope="session") +def parquet_path(tmp_path_factory): + path = str(tmp_path_factory.mktemp("data") / "dataset.parquet") + schema = pa.schema( + { + "col_1": pa.string(), + "col_2": pa.int64(), + "col_3": pa.float64(), + } + ) + with open(path, "wb") as f: + writer = pq.ParquetWriter(f, schema=schema) + pa_table = pa.Table.from_pydict( + {k: [DATA[i][k] for i in range(len(DATA))] for k in DATA[0]}, schema=schema + ) + writer.write_table(pa_table) + writer.close() + return path + + +@pytest.fixture(scope="session") +def json_list_of_dicts_path(tmp_path_factory): + path = str(tmp_path_factory.mktemp("data") / "dataset.json") + data = {"data": DATA} + with open(path, "w") as f: + json.dump(data, f) + return path + + +@pytest.fixture(scope="session") +def json_dict_of_lists_path(tmp_path_factory): + path = str(tmp_path_factory.mktemp("data") / "dataset.json") + data = {"data": DATA_DICT_OF_LISTS} + with open(path, "w") as f: + json.dump(data, f) + return path + + +@pytest.fixture(scope="session") +def jsonl_path(tmp_path_factory): + path = str(tmp_path_factory.mktemp("data") / "dataset.jsonl") + with open(path, "w") as f: + for item in DATA: + f.write(json.dumps(item) + "\n") + return path + + +@pytest.fixture(scope="session") +def jsonl2_path(tmp_path_factory): + path = str(tmp_path_factory.mktemp("data") / "dataset2.jsonl") + with open(path, "w") as f: + for item in DATA: + f.write(json.dumps(item) + "\n") + return path + + +@pytest.fixture(scope="session") +def jsonl_312_path(tmp_path_factory): + path = str(tmp_path_factory.mktemp("data") / "dataset_312.jsonl") + with open(path, "w") as f: + for item in DATA_312: + f.write(json.dumps(item) + "\n") + return path + + +@pytest.fixture(scope="session") +def jsonl_str_path(tmp_path_factory): + path = str(tmp_path_factory.mktemp("data") / "dataset-str.jsonl") + with open(path, "w") as f: + for item in DATA_STR: + f.write(json.dumps(item) + "\n") + return path + + +@pytest.fixture(scope="session") +def text_gz_path(tmp_path_factory, text_path): + import gzip + + path = str(tmp_path_factory.mktemp("data") / "dataset.txt.gz") + with open(text_path, "rb") as orig_file: + with gzip.open(path, "wb") as zipped_file: + zipped_file.writelines(orig_file) + return path + + +@pytest.fixture(scope="session") +def jsonl_gz_path(tmp_path_factory, jsonl_path): + import gzip + + path = str(tmp_path_factory.mktemp("data") / "dataset.jsonl.gz") + with open(jsonl_path, "rb") as orig_file: + with gzip.open(path, "wb") as zipped_file: + zipped_file.writelines(orig_file) + return path + + +@pytest.fixture(scope="session") +def zip_jsonl_path(jsonl_path, jsonl2_path, tmp_path_factory): + path = tmp_path_factory.mktemp("data") / "dataset.jsonl.zip" + with zipfile.ZipFile(path, "w") as f: + f.write(jsonl_path, arcname=os.path.basename(jsonl_path)) + f.write(jsonl2_path, arcname=os.path.basename(jsonl2_path)) + return path + + +@pytest.fixture(scope="session") +def zip_nested_jsonl_path(zip_jsonl_path, jsonl_path, jsonl2_path, tmp_path_factory): + path = tmp_path_factory.mktemp("data") / "dataset_nested.jsonl.zip" + with zipfile.ZipFile(path, "w") as f: + f.write( + zip_jsonl_path, + arcname=os.path.join("nested", os.path.basename(zip_jsonl_path)), + ) + return path + + 
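+# Example usage (illustrative sketch only): how a test might consume one of the
+# path fixtures above. The test name is hypothetical; it simply round-trips the
+# gzipped JSON-lines file back into DATA.
+#
+# def test_jsonl_gz_roundtrip(jsonl_gz_path):
+#     import gzip
+#     import json
+#
+#     with gzip.open(jsonl_gz_path, "rt") as f:
+#         rows = [json.loads(line) for line in f]
+#     assert rows == DATA
+
+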
+@pytest.fixture(scope="session") +def zip_jsonl_with_dir_path(jsonl_path, jsonl2_path, tmp_path_factory): + path = tmp_path_factory.mktemp("data") / "dataset_with_dir.jsonl.zip" + with zipfile.ZipFile(path, "w") as f: + f.write( + jsonl_path, arcname=os.path.join("main_dir", os.path.basename(jsonl_path)) + ) + f.write( + jsonl2_path, arcname=os.path.join("main_dir", os.path.basename(jsonl2_path)) + ) + return path + + +@pytest.fixture(scope="session") +def tar_jsonl_path(jsonl_path, jsonl2_path, tmp_path_factory): + path = tmp_path_factory.mktemp("data") / "dataset.jsonl.tar" + with tarfile.TarFile(path, "w") as f: + f.add(jsonl_path, arcname=os.path.basename(jsonl_path)) + f.add(jsonl2_path, arcname=os.path.basename(jsonl2_path)) + return path + + +@pytest.fixture(scope="session") +def tar_nested_jsonl_path(tar_jsonl_path, jsonl_path, jsonl2_path, tmp_path_factory): + path = tmp_path_factory.mktemp("data") / "dataset_nested.jsonl.tar" + with tarfile.TarFile(path, "w") as f: + f.add( + tar_jsonl_path, + arcname=os.path.join("nested", os.path.basename(tar_jsonl_path)), + ) + return path + + +@pytest.fixture(scope="session") +def text_path(tmp_path_factory): + data = ["0", "1", "2", "3"] + path = str(tmp_path_factory.mktemp("data") / "dataset.txt") + with open(path, "w") as f: + for item in data: + f.write(item + "\n") + return path + + +@pytest.fixture(scope="session") +def text2_path(tmp_path_factory): + data = ["0", "1", "2", "3"] + path = str(tmp_path_factory.mktemp("data") / "dataset2.txt") + with open(path, "w") as f: + for item in data: + f.write(item + "\n") + return path + + +@pytest.fixture(scope="session") +def text_dir(tmp_path_factory): + data = ["0", "1", "2", "3"] + path = tmp_path_factory.mktemp("data_text_dir") / "dataset.txt" + with open(path, "w") as f: + for item in data: + f.write(item + "\n") + return path.parent + + +@pytest.fixture(scope="session") +def text_dir_with_unsupported_extension(tmp_path_factory): + data = ["0", "1", "2", "3"] + path = tmp_path_factory.mktemp("data") / "dataset.abc" + with open(path, "w") as f: + for item in data: + f.write(item + "\n") + return path + + +@pytest.fixture(scope="session") +def zip_text_path(text_path, text2_path, tmp_path_factory): + path = tmp_path_factory.mktemp("data") / "dataset.text.zip" + with zipfile.ZipFile(path, "w") as f: + f.write(text_path, arcname=os.path.basename(text_path)) + f.write(text2_path, arcname=os.path.basename(text2_path)) + return path + + +@pytest.fixture(scope="session") +def zip_text_with_dir_path(text_path, text2_path, tmp_path_factory): + path = tmp_path_factory.mktemp("data") / "dataset_with_dir.text.zip" + with zipfile.ZipFile(path, "w") as f: + f.write( + text_path, arcname=os.path.join("main_dir", os.path.basename(text_path)) + ) + f.write( + text2_path, arcname=os.path.join("main_dir", os.path.basename(text2_path)) + ) + return path + + +@pytest.fixture(scope="session") +def zip_unsupported_ext_path(text_path, text2_path, tmp_path_factory): + path = tmp_path_factory.mktemp("data") / "dataset.ext.zip" + with zipfile.ZipFile(path, "w") as f: + f.write(text_path, arcname=os.path.basename("unsupported.ext")) + f.write(text2_path, arcname=os.path.basename("unsupported_2.ext")) + return path + + +@pytest.fixture(scope="session") +def text_path_with_unicode_new_lines(tmp_path_factory): + text = "\n".join(["First", "Second\u2029with Unicode new line", "Third"]) + path = str(tmp_path_factory.mktemp("data") / "dataset_with_unicode_new_lines.txt") + with open(path, "w", encoding="utf-8") as f: + 
f.write(text) + return path + + +@pytest.fixture(scope="session") +def image_file(): + return os.path.join("tests", "features", "data", "test_image_rgb.jpg") + + +@pytest.fixture(scope="session") +def audio_file(): + return os.path.join("tests", "features", "data", "test_audio_44100.wav") + + +@pytest.fixture(scope="session") +def bio_dir(): + return os.path.join("tests", "features", "data", "CAMDA") + + +@pytest.fixture(scope="session") +def metadata_bio_files(): + return os.path.join("tests", "features", "data", "CAMDA", "camda.pheno.csv") + + +@pytest.fixture(scope="session") +def anno_bio_files(): + return os.path.join("tests", "features", "data", "CAMDA", "camda.feature.csv") + + +@pytest.fixture(scope="session") +def zip_image_path(image_file, tmp_path_factory): + path = tmp_path_factory.mktemp("data") / "dataset.img.zip" + with zipfile.ZipFile(path, "w") as f: + f.write(image_file, arcname=os.path.basename(image_file)) + f.write( + image_file, arcname=os.path.basename(image_file).replace(".jpg", "2.jpg") + ) + return path + + +@pytest.fixture(scope="session") +def data_dir_with_hidden_files(tmp_path_factory): + data_dir = tmp_path_factory.mktemp("data_dir") + + (data_dir / "subdir").mkdir() + with open(data_dir / "subdir" / "train.txt", "w") as f: + f.write("foo\n" * 10) + with open(data_dir / "subdir" / "test.txt", "w") as f: + f.write("bar\n" * 10) + # hidden file + with open(data_dir / "subdir" / ".test.txt", "w") as f: + f.write("bar\n" * 10) + + # hidden directory + (data_dir / ".subdir").mkdir() + with open(data_dir / ".subdir" / "train.txt", "w") as f: + f.write("foo\n" * 10) + with open(data_dir / ".subdir" / "test.txt", "w") as f: + f.write("bar\n" * 10) + + return data_dir diff --git a/tests/fixtures/fsspec.py b/tests/fixtures/fsspec.py new file mode 100644 index 0000000..b7bd9a9 --- /dev/null +++ b/tests/fixtures/fsspec.py @@ -0,0 +1,120 @@ +import posixpath +from pathlib import Path +from unittest.mock import patch + +import pytest +from fsspec.implementations.local import ( + AbstractFileSystem, + LocalFileSystem, + stringify_path, +) +from fsspec.registry import _registry as _fsspec_registry + + +class MockFileSystem(AbstractFileSystem): + protocol = "mock" + + def __init__(self, *args, local_root_dir, **kwargs): + super().__init__() + self._fs = LocalFileSystem(*args, **kwargs) + self.local_root_dir = Path(local_root_dir).resolve().as_posix() + "/" + + def mkdir(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.mkdir(path, *args, **kwargs) + + def makedirs(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.makedirs(path, *args, **kwargs) + + def rmdir(self, path): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.rmdir(path) + + def ls(self, path, detail=True, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + out = self._fs.ls(path, detail=detail, *args, **kwargs) + if detail: + return [ + {**info, "name": info["name"][len(self.local_root_dir) :]} + for info in out + ] + else: + return [name[len(self.local_root_dir) :] for name in out] + + def info(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + out = dict(self._fs.info(path, *args, **kwargs)) + out["name"] = out["name"][len(self.local_root_dir) :] + return out + + def cp_file(self, path1, path2, *args, **kwargs): + path1 = 
posixpath.join(self.local_root_dir, self._strip_protocol(path1)) + path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2)) + return self._fs.cp_file(path1, path2, *args, **kwargs) + + def rm_file(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.rm_file(path, *args, **kwargs) + + def rm(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.rm(path, *args, **kwargs) + + def _open(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs._open(path, *args, **kwargs) + + def created(self, path): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.created(path) + + def modified(self, path): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.modified(path) + + @classmethod + def _strip_protocol(cls, path): + path = stringify_path(path) + if path.startswith("mock://"): + path = path[7:] + return path + + +class TmpDirFileSystem(MockFileSystem): + protocol = "tmp" + tmp_dir = None + + def __init__(self, *args, **kwargs): + assert self.tmp_dir is not None, "TmpDirFileSystem.tmp_dir is not set" + super().__init__(*args, **kwargs, local_root_dir=self.tmp_dir, auto_mkdir=True) + + @classmethod + def _strip_protocol(cls, path): + path = stringify_path(path) + if path.startswith("tmp://"): + path = path[6:] + return path + + +@pytest.fixture +def mock_fsspec(): + _fsspec_registry["mock"] = MockFileSystem + _fsspec_registry["tmp"] = TmpDirFileSystem + yield + del _fsspec_registry["mock"] + del _fsspec_registry["tmp"] + + +@pytest.fixture +def mockfs(tmp_path_factory, mock_fsspec): + local_fs_dir = tmp_path_factory.mktemp("mockfs") + return MockFileSystem(local_root_dir=local_fs_dir, auto_mkdir=True) + + +@pytest.fixture +def tmpfs(tmp_path_factory, mock_fsspec): + tmp_fs_dir = tmp_path_factory.mktemp("tmpfs") + with patch.object(TmpDirFileSystem, "tmp_dir", tmp_fs_dir): + yield TmpDirFileSystem() + TmpDirFileSystem.clear_instance_cache() diff --git a/tests/preprocessing/__init__.py b/tests/preprocessing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/preprocessing/outputs/histogram.pdf b/tests/preprocessing/outputs/histogram.pdf new file mode 100644 index 0000000000000000000000000000000000000000..895cb3780be563eb0998dfd62305451f1678d45e GIT binary patch literal 4407 zcmZ`-XH=6(+olLeR9Yw^Adf6nN(wayf^?;Kq)7+?f+2|{ks5jvke6Zs=^bHNL_|7@ zG?5azA}9*3ARwUhqJD|Hy8G^a@0|1em}~AbbD!s$b6+E2p=TfuQ&0j)M9fCaMNlI~ zo&7*C01DvUeLxx-Ac(FX8cD+7ae7D+8j#jgheP2?P-UnpOc|yIgUWy)rg)G4xn_eQ zVQF4gfT16f;Eh2M0dq2mKqk@LEd1~uWE9$uzN)04r~s#7-*e&VkPPVTSr8iMLEi=Z z=WYOqO&|dcKy;BLBo^-lKrE16Xd<9UGq3{wa!~!31Ko~p?T5yZ{xpZ^;&CL}lnAJO zcZXP^iFmRfiWX1lAMxl0KgwZ(#(9yv0To3!05QN|Nwfn%46wA?=%G<~4>Wz-L=vrJ zzMz2YD!V=$FN~{YUv@7;c%eZw)CJ6FC7uGvUeQKGoBG8+6+6uR$UL}v^3!RxUZ%L3 zho#Nazxhqo9xBr&FweY->)|9Am;%u;+UbU(+Ne&WwK>q)_HL0^l#t+~QeJ~7WNfI$ znLs=~>M>+{S6IVYowIDv2XX(JMhkI$xid2J?F6Rx_zq(j*-JhGrM<#)!j)&t-1w!W zk~3JT&;m@DvgoToCa~Mqe`0MuG`OXR2cKBx851yRvB^MQVxl<8We5l#x(in7(C6f< zsunQxY!Q!-w!@iuHxR{3Do9s6x<`Zhvs(?{wADq6|m!K4d?(mFrNpW>wg( zfLyt=nX?&JQy{uJA>^dw>HAxXSmpKT#i|d&&ducu>uTBXgyrb6t65Pou;QHj`(iiw zbEosUGhe7EauhvepY2%^OXqu%BEn)F80mZez=A(hA@m7Y#bwGMxYKhhBev-UX}3TM zTMH>MWCG?C;T|o<9OtX*A@PnUp%>yZBv1PAj27!^b@Fz> z6DsOazjl<%6cz~iB#_FDOf7|*3R0KOi>!2rwjHIO?WxSMURkn73NFcL1);j!$j3^4 
z>%_;gr@i9a6zg_&BOhV=%4MgzvgvbYJ5Vf>s-cL{Dm{$q2d8hNeD182JiEd6enxFu zbXu0Bt1YW7cej`MoBOv#7q0TT&wY*Vrr^4{XMK%zY}P9cO@wX#{`yAIZjqk&t4q|3 zuCKzpMh+oW-*(^3_7%vpajdA`7RcHhe~b9lY=3;&vPO2jqS7qVJ&k#Ms@}&nta;|( z;Q4F8vn#4W_ZT;5toic~{F6U)di?SG{@~9aw1F67Jcxi3K&Oip;6kGno=l?$41k!T zJupZed;s7CrL8GLmFS<#-xx+`$PY%)h>WsElK>|O?JA*VKnIXqKoA2Oy8sOhhyfgc z(Sv;FCqzGhWN1yI-y|9u^l-G4SiGM#0g3vKK!{d`I)L&I^S>BPhv@m7Am?>;b&y1~ z2cSZ?a`|yb&}To+c#h6W96bp{4}BGbLeqmn%xJua{K;OpGD7A5r7wL?Kk1uu%iqR& zl0S~xs@iIi#+7hm=(4vJlhLtOqf~{{Fm|pBk5d@LRP<~JdX}lBhvCw1!i#iMYs3;_ z7%gj-4hiawT_Nxa-8%14B{xo`&Y>Y?O7?`mPZwR%uI{Z5ZOqom}FOsmC;2nY1CJtbC|MaL-d z#b68;8AcP5<$1$xyq?XxB7&3iK*Dwz&KFAR!ZjD*{z6Fby)rt^ti{ZJ0o>%Yvx zUStYh6n^W+;7L8W$Q$S@cwg+y3pIoq)K}iOO?oWMwJ18xw8g!^q|4$762O(CvUowIUYvX#@*&c=*+mt`Oe zt9yAIVY-(>s8(&SR)0qZtf#A`kzl&6syNS}QoYWngR_L)qwn+Lc0}2}l27t;ABeZb zLvHBn^_7|cCb#OGo`U09GIe6Lck=G+aCUqxo@~5z?d1;l&@Bh1l2a$NLr*eDpXLVG zgFR%2Rw+li4ViZb5{EN0a`vf_Un%wtCSRZOi-#qn4E5j*JeGsYcMcqxlmG^OZRSO7 zp#^5}<3kT0b2l(>zF}rMCmnz2iW}@rByS3%9w5LQUYCFTMr4(nR5>d-pIa$Pg>ak{ zfqJ8)#-euFqj(N((XB9^Rbe{^}k8wy5ddJM7%WWZiZ=OBAAhuti?4|9F_qeevwQlBWhm)fP z+I!dFYN85}^^b27m_xW$qRl&Qzu|TOX+@NAhjVeO7bZfv`JWi*qy~tybDyeA;VEXq zXunK3SuEO^V3b<+)QOi>Ums>+a^uJ|QNEbj!dnG)1u8`qmTVR|3E`!LV_kmt+_A8SpkS_Go?wm(G0H^e%e*U-~27nW#~ur%@%SCNSjKjnMsqg0OAE3sA}z1P5l(@*1~$!i^)WZtP6 zsYjOiryWiol3wGrlO2;Lop!biD2=yQve~e4{xzf8p&&r-iqWULJg&V)*Qz=#9ZPJA zNH!m0M;^d!Ck+rb2;IP1zr@chzZ>$(b;Zx>zO_t{-k0{0M%hGHn4_&5$H-#_ix(Hw zj_E3vvJ6>aN*>_7qaMicB_!lrcyJ;2!u4lY_+&&1goc?HK5W4EkbC64Gq&=!=C*B% zlH)4lvK7_i@5=+rKaF1)dpT|}_WG{k7-ej9tp2Y5-I-BsOa>1;;QB8_L-9eT?F+m@Lwl~=K{F~<2XBWq*z$lT?nA1vTMziX6~ERh)9$WgV3(suw*+ z-24W3QP1F}im!ZwCXcGrsv5&8VM6e@0g--|_ss)w@g77&Dpt)={T%DHbzw^e-#JBX?hGD^n+5-F{gF6hPgR|) zANPMoeT|QqkqLVcmhw&In-HTT<8=7@@VnYQk8eg$&K2jgoEtyK895NC->!LzVC436 zE2Vk;*!+=s|Cq?3@S%?d+fR}@ZpBbUli4|sRI`U2M}jT6h98ADx%EcgVyPBu73&jQ z)yvd7Ul>%xR|GE1D!SAY&~u~52vy~7-|d<(8V~Ly_n_vA=CWqX1oq^n8wnciXA9Kc z?upIm2PP`WO*mw??dtU6a4IiYUlt9o!fRFR3#wYrw83tblXb@ND?B z)n|}-bMZ#Wn^Kw=uA08i)ySO9wMzYL^O)b}NN(9)*+fa3*%8yJye%0`$yQC~+o>5> zY>c=|+4tdfitF&W+Wy*v+TyBPQClC@H=z2tkc&e zRn@nRa{Y3lc(=yM&!&A+pDF8)Tq7LYU51_Vo$ft!BbCfQ+p`_2D5(sxU#eDc)NJd< z4@A^O*I_P)Nk34KgM>6Hc)FOU+6~_+tl11*a!TY&(C<=JX~33G#lMt(q1J#|DXlOS zH4cyvFs?F|y)c7+k89hxIbeVD66Iyu1N6QA0xZ?8-|p?i($sXlh_t-|6(Txy@_s9k z4Q1GyeEZDp&bvj)KmUE@+Fox@_u>FS`5SLr3V_&<_#u%#?xT zfdTnD^5`l;Wik1)+mrUbnSij-SKZi3d|CbNlT2~qaVNvFc0O))oQn(o&%NXa6t5i@ zU$(lUV(_B#YsC1?5s8nO5$+Nn zhvv_*?Tf0jim;xE6OLzVcB2j=+|1Ni=HDUkGwos_jml~HvK=bUPruv#&d(;`!bd8i2`rAfm z%fhK$WLNZ{-uf?7hCPxLDTtQqPQbzWiS*2}E-l$_&inoECjB1;9KfM`XOzD>ZybzP zXI(0P^y1sd3VN0FV&jXwV6~^~?H0>0;v4xT-*3$s{$Rz%)tXgVt&Grx-QyddR?3ph z^Fo=oqu2NU)oLStwxOUurK&!TCJQk*npUN^t$)qzF}QO$BIXAGg6Ls9J<&9MN*AC` zfRZxA7lR`c0hhlC;qQ&~k1qEQCH%e3hFFu`N#BJr9i;W|W=LOJ^Gj>JzbE~}#t4nQ ziY8%DND$;lQh>hQzYFI8KeQ(Z3cz3>=+6PDC@U)}1D?Q73{HCocVChp5`#ti z(Ox$Ih&3h%O*=0DIgj(i)1(i*9uO-$o&?Ze9lqCwNJ9FNzE?&G21S4*B=pSr84 CmWeR{ literal 0 HcmV?d00001 diff --git a/tests/preprocessing/outputs/histogram.png b/tests/preprocessing/outputs/histogram.png new file mode 100644 index 0000000000000000000000000000000000000000..13a4674e6ffb1ccd4ac49145493bd872c3351e91 GIT binary patch literal 53930 zcmeHw30PCt)^-#G93mV-}(b@{i z5Gg8F)ToGP83IH=L>4EoGsbJ9%Xe!*2t>ZYU~DlM zFO01%#uj;JU=Cy;|0DmXzfixtY!4s@!xs6T!LVf@@3pS#myfG!jH{22s}J&i+cly?q=m4Up2$ZhJI)UOzynA<)v#Xf0GG03xg zkVE}d4Dy#&9D{uM963OSErxNxmT|y~ae$$2U;W4oMh0@!GU|{6yzS$9JI3dBTFmX| zO}CrDj02dA1GX6quMFfm^@Yf*Hlq$Xzzhg1j&UuH@hMJ2j+o-K+r`aI#m!(b1lBTa zYZ+d(8OXt?MQ*4=j@sHfuoeQF(tMiIVw#>K$L*%(;-+S>2?9GiJHZSu@Ocba#{lax 
zz`9zn`8L>G3^q4`PzDGUgHSC9A)moWd}|0oWB=ou-L}C0D=*A%%tJ$Uqs@TU2hA8$|lVWRo!wc5XJ z-jsYu>&tQ9JD$FEp+Zl)#Bs~TdQjOJeA4;>7|<2(*H->gl*{9i?nnJ>yG&Gm?gv31 zQ868ZX`Bje@<$rZ({e_j7cl+24hAFq$80QeW|IvL`98%L0cMyppMqS(D5ylqr{14s`k*jp{pHy{W>gxD^d^Mk9uy$R<+IU|=#FCHh=g#TXuDLjHVo7a8%DnxX z*n9V{saSY+;KZKFg|o5qq7^etO;4}bGG`j; z!skC;bDchLqTp9o8=RwI-{N)aw2%Bcc*5{VLFJT(x5!olJ?TGjpJH&Heqk+mcCaUC z7f~vrSnQW0R9sM{gX+b9v2{uGY9R|sgT8i9+rrOH-fV2Y-TS-)&f%kFv#@gRXtQ;c z&~IMRNA|l9C@qKkNGOPgIfDOE>3L~~%bnPGbMkq5cg3lR!=&f8puN4n3` z^ZDiZk)}&-ex~0!jc9if@k6Prws|NRkU2RDhW>;Q3I=sn5(UHmDi}(fInG?N2!aDX zv3o#$PX!5n!)t@MsY6NK!pc8v=prJ{Q%?0I#LZJ@iDzSnjNql-6#lB&*rE{wX;}oj za5&?7`l?}mBvt0UY$(Y*%mCZ*D}};TBdEC(7V5HZvSmwu9!eblX--S3=3tqs?s%G# zU$~&>7h>>3xa_;3MD#EVP4Xv%w9Dhm!+j5_77wijU?dUYr)Em#ORu3^_@Ayq?|x^_ zEWmGX;xwu()l^gW-c$p2Qk=jnn7`1@yR?8j4J z>AL%aYY&FOH1KT7JjfKwQWdQgN_U^%qk-?*&YU`EJ!b>4qC=M)4`vx?;;t{jPBo#R zIvnoeUkk{kofDHor@&u7H`c`2HcG%J?&@4|I-aO#yurSIsF&&%r-3N>mDk|S&q|3) z6-u(QX@3B0ErK&O3C7QOER&|FW<$a>EW4Md3r8D7YhXi_>de;I-J?f{I>D&Vj?RdX z^W* z%eT*HPjq(_l$CD(Cfgi|>JX+HZipOQPr2X9!$8S5p^Ue-^j^^OIAWD<)tYTf<_xZR zyjb{yYgc6OwFU396MlJCH?4cNi@}};$A^|qq79>@M!|)O6RJQ^m5CY&)GVSN$YAja z>T{u9ChAwCGYNF=gw7VxiTU3(6ye=gDcrUDKDd03;V>rUnU3qck}^hi^vc~R&{zLa zM0{-ePXQhU<4hK9<1S;+r#AeQQ8(YT6uP=}qV($xPXcWhM`bQ}ZyUF;U~k9J8k{8u z4(}zt*D~j~wch!n;j(Z0?yGvyH#mdkr4Qpie{^E_!TC!rPadpNeZKecLJ7ydVw#U< z`L|6+#`{mc8Fhc#%AtsFbj{Uw`e?%_sZntK$BEOjtt&1Mtpz}p$&hCLzfzwYO0T|y z@i~k~6G)aWrB3NqpT){|&_f%Fp58-GqC`NEV7wg_hnXl5)GL2bB%nkX*jEQd0*V9_ z34ew(P!ow-hyKOZsA5DFBdQot#fZie(4Yz$SwbUd|E3692@#hI=Zo-icOVTuLxNg( zIWiIiQnUF)T%o#D>G`X#xv$|tmeDSN88Ra6 z=;n}CQZ?M?>!8{xfrA$IZ}GaJ^gg%Fruvuf`5k8t22xh^r^Gs(Inf@C=Ik(`@1hV; z1S?bc?AQGTMEegKO9E%L|NiPryPT!w`ejAxMVe+HMfTw?n|SqJ zz_wSdn#NY7QWYff%0+ZIS?8}uTjR`;5(6MhxXOqa$#*BiCnAjCjt%gk!ZesynmcjQ zV`7&F9@$N|HdrF> zO5yX{SLecbsI7`bly|c1V~HPRAK46JcztB|{yom6lU!zTo`hu|Fljh#u3uQDy}?0A zkyYwb72-?!t!8We2s^irz6M$ZAL@z(+?(F;t_Ac_eguN|Mr7{)^tj>b#c<;%_)WDl zi^g?HD^Jzg5oU?DZd;O=YKK%~ZgMN#hg4bWA=QvYm0=|k<%hf@2wvr1F1cKO$pm|B z22!orOv&4?3_aes0^w=M48slW-!OFKP7_rLI)1aUaX+vbli#-?pkg7p*w@gv91ioC zOra061ww;)^}dAN4?%~OzVQgGb0mefiJ!sT`bZWI)<}#b&OoYt@Ne=Z|`6G|tRrSO}n@E)Jp zTAF|^u){(8G(nDZ8K9~qD&^i>s00qYec`6M*mfj%8bS%#HdAO#SZ7Wr+!dr&_y098 zrk@~$h>V>%_9IbhV1fnUH`2>Y+R?do%3?vze$W7N;h2xQ4*r&7&CGejHPor$AZB|c zQll93NM7uZyI_AayV1Q$%v4>W3vyZy$V!N}Cq7UY*N$vJ^q8n82iRairU^8+1y56M z_2QpU*Ts*d!DHxCRIn!EXy$i$7Tg^LD?d?Tuhnx!YO6O=nh@Qo-lqM(B~FRXCRiR6 zO$;c5n~K--ri^TZVT*%QMm}=WdFC3Z<2)P7_OD`^9?_s|vy~J*sOih|^i@qGO-rIP zhZibOH|77ro})a&^Wf*r9B5nb} zr-@Zz+);M%7PCsYWyZyG`!vwYuG`G9p*I$%K{gVdRrLnE07=&dwpjcZcehcW?j%m{ zWkM@};1dhph zi9$^Zh4!S%nrR;_DXq~AiXY}+0coo@m-gkVpN?w6)!h`n37e?i`(Pv|f52l^38O{V zTbXHQgX{ir!8h_-Zziuafh0rYhy^yQ`WE7O|Tg0~Qd$On)0=7Y)knF474T z4OIFUx=DfJO5t4|T&N4T>NCH(x=BOpiNy#L;Wx3sdX8(iCYZEaAxLLSYq=*C-qo=2 zE9O-Vrh;LEbAZgjs0)Bfd@&LpwLi6~>TWmTDSt9_}6s^TiHR<8rE7l-Ofz`++F~xJq$}@B>?Ki#xnEvXou5!k;a4oZY@Sf-Y zyQ->uALJ6Qil4^IDrER;GWxPiGxe{bi~?{9&5d)I{aQ!|j<3~peieh>r!*idZp0#9 z71V1^^Djsl?Ysm^dG@Uq;rwtaanBNtCJ%=}#QA9Awfq5t4#>EGZ2Hkj>R+!Jg}aXw*;{wE zcokX2<255TX_W8OFAzDgKpdh$>zmtGYZ4lP1MV&NBE%>8DZ1e1$W`^zM|wxCQ%W&*W^%ScOt4coIzAb&52Us25dEK^bd`V&B5ve*H=TT zf>|)XRTCRDPumd16RwgVo8;^oY3)UbV&l-B>YE}RcW6>b3U>=KaS9?H-DgJ{ z6J+#MS@I-(cEwz_!q@&U#DfbxjDmR$^}zPab6OM0^>b53zs8{@qNrtZa*m z>_Zly>Lm++BQ5kDV%E&$F^fyNm0@X5D{2CAapsrZ8r}D?RWVy*CGfmgt_t^a6=A=$ zwo~%n?ELV99XSRGo=w4dsxnCp(UYjKq8a#&Z3r%2N6CB9>qlrjWN~~&5^G)b4TJGm zOPaVZ`lYag82E9EdGcEL>{Q(ShpO4wCs!vX#(bSF}_MWEEI8l5vK9H>(wZ6a1Oc~?umz>@cvV6yq zIga6Jn#MH+x6@n$;UX&<4{7RrgpNDfbqyFosg43>P)y%#p^Y7n-JaaudE(8j!dHJt zb?^()+~j%D6}u4~@P`vU+G*)FEb`s3OnSC? 
zg;UJ&gud(plx+p8-`q1~Cgma0_dM0?p-dcsN;DQ#JHvc2RDq+W5VhB+=YzUusH59I zYe2nmb?5~(i>O)rTbjkOw>G(qm}?vE#VkL2@bkI$bH_>7)>)01Yj6K?%;T}L({891O(2^{5oduh=9PuDO&%JbaL@Ha^uii0FHWJZc4S!%M+7Uf1P;NHE^D?TJhDO2Tvg|(Y}JJ?t$oy= zVp7MNhwNWlBps&3PiT-=>u1Y3-VyX%z5ZCwklDLpg3CQ&GJ#f{3u-w!?Th$q@GmE0 zQ}w@4y{d+~3!~e4t3sdg?M3)(xB_qS+d$R|aS4Yk+cH~%9Smma<5exC+@w&6WjClx z1bfr{AxgJ0*vs0xO z@P||~O;W>_mEc2Ebbp$X%M|Y74x|y)fw3WC1-?)>(H+I`d)5xU>14hV)^(E(%7O_5 zUtq&fpz+&{Dd!XVR1wL1@hQ5JEDq=q%>|jCu3H#>^@UXW$a&Jes|UJg6#sH<-5lFa zd50s1gjZN$Q<=K*azO%3QYB^gzJ%9JIPu`Zyr=9``@Ozz2!1S^d)muCKHttI8RFc? z;|{XhdO$|_vow}h00e~-t$OLdZ+m{>J+I2d>d7H>4AlKJ?nFzMBy|GT`U+nbznBXI zwVa)N`(%d9alM)ZAs5L5{!!9D2rBVn#sRu6d*iVMqnVUC9-GQK6{3o?*`<2f(Ki}k zNna2`EPpos9&X*L;ch*wHE(`TLo+-Fe~=Pc5jLFvW_nm&&`GKJv`B}R!?@q)eWrW8 zX~=SheA1EMRqqOPL!@i0KS!pEJ~ox>syb^l&W$MSkQ|=lMmB%iOU?da6q6AQ<4wHi z@%V1BLU(0i?qa`IuA6w2oGgr_S4B=c@6hTSq)d;;n}&+!7z)Zl1%9L9gj}khRPT3b zZaGcT94Sc3Ioe8cBXp>fyTeh}Mn@Vo37v_gx@C)5JaKvT&0ImpKq^_CbQ}vI=G~{!9U7ibLVeu9#Q89!{h4GRrmjBA%nfNv zl#+j6%lBD0^6zS%(M}>sHFTOiN_C?uXqd%|(230jI_sB4=%^xr+bbULq-SB~kwfCzZzMJBP!PgYu4rOZmF8$;&+rEbAl^d(mUhYm5Z!in1XnZ9l zS!bCzYTXZ0VN;UtxnId&*%9l8NzT6^VEUZ2M+bs(1ceqVBdFFw)h}vCP=koND5$T6 zI*q6skIqN_Z4;}*bUM4ZB(KY}>?pG2x8hB9)?@Yb@%5Rc7u&fJ^fSxjca>;3gWR@Y z_IG!m=XE($M#;q5=mhVxDR*K3{?*W00Cf6|#u3nf$lo(W)%d$(_z#A+8LJwFYe+M2 znAdU2a}zd`j~jmzb2DQRz2g3Uk9GYgBXq81KIPt;+qX9M;%mx|_`Q1A4UfIVPqZeJ z)=)4(@#t+RNrY=ql%RrxDg{(eqNW2ii>Rl7I!~x$i2BOt^Z}hwp|eGFYJ|oV(0B-C$|Q>v-no_{Zl={PXxxdf2b2e#}j)KSo=a{#N;DMWNklimV}v z>lxI3ILey%+lnK`_T6C|gQ>W%hBk+k)31$ot=pV(@W|OKc;r_UPPXS*(?nH=sxDVH zPqwj}nmp#ik?v(*t`DvB`eOOjw>RTD`fD*9=)!O#E*tD2hDa$ekj&HzxN3Dvw_uk)KQ2biD{t)4=t5v!aXVHXg=|?8< z5q?8JOBw0zUq1H$vbetdr#ul4K(_v+irZ5E=~I>MOuqnL$k?Wf*2IjzCc$i+1u;m- zPxof4!mR^Dmk$3W9gv!oRI&rS?px8;iHvM{Nyfp_4@Y%pC;eD~5wDOM&$yoZ;j5o; zu7|h0Y^e}7`!jPy&gE=-sLca_*Yq9 z-DR_t-C+lc#|am<*)s!HY*Rm{%D((E6OSu*Xcksk)>Zh5+dN2HwJ{i>jd!U54ui>C z;9W{QND25*W#CNpELHDZGs9}8LG;;~=c3lGJ!iP@VOup}iF#|iO;f|I9^CpdKW^K1 z&nQ+42eu7LnXj{N@8Frry)O4U+X>kh)!WmpomBMf*MXBMg>iQIq=*oKgy#_O!OYPl zJJ4Lces!RU{_EC7TRpvpY`eKmUA^#m`@P!i)?P`9ZYpWrJKnKCaoY`biO83Gmo6OY zsCm);lN4D`mOnF0CcGkdEOu_l{&{pC)U`%j+w8MrQ^R&UfWqhT?yr#T^#%vD+hxg@ zdmqibguwLYaSig1`ksiKWa)=~!M|wDx;}dJRZg6a^vRhG!E)zR>eT@7#w7jhz)hn^ zg$Rddh@js6N@F4`M{!A1A&W#_~Eb#NE8Vv0H_g`V$U|quJTpQe*;d4F>ZW;DDlN<*&AN-sv z5d#~Wp&k8G4p54YL_UhqQ5x?*RhV)`SH8tyrp*0j&Fb&|iH$*3)qhd>Yhf^YNZA4= zY>MN*^W%kjCk`wN_eVC-Z_k8;{7i>o9_J}3;w5pR!mjEycj3VI%Zju!yRCc6x11|6 z6--Mz7A#S*q3*1*;{AdO;U$(`aBD37P^GjXwC}Bd?q!-0ktTh(b(yC@Ms8o%eeZMQ zed?(#iK#O$IX5i$zE#r2W(OKAUB8~*EJhQQ?zc`r>;t7s-R+6xLC>yIXhht@<~^1rt6PC zr8@k4RO0WQz3@eGPN3j-E%@z-b_#w53H;gF5I#vTeY!hd)h+enf5SPHyLUv)fb<$ z%0GPbqJPp}7n8u)Z9mtgSEaZ0nY|$C$A9P1x9{Omip`sz9CtWGag7tJvIb_}lIpf? z5~@6;L{ZO8*7zJzQt0mBqPWHxbS z&-x9KJSA0Imgk)=fTgi4U(o_{rV`+K^z}cIUWQC{8l)-?H)Renq)ZiqTF3WHBh z#qEeG$xY@rZ?0t>%JRD+>Yjg1x$COFSl|Z)JoU@2epYavc>iTw;|<+E?puJ|w_;gY z(SBZqu($4eDy{LwDt4cDTUly=q~R00-tawa;abIpmHL*@P5A1zT4iRa;y2K7eUY}M zeZ}*1Ylr+eyIzY2*8{lS5cNh~T-Pj|PZRV>f2S3^%wK--Lp6L7=>sehCvGtUN@%Vx zI{lkNjw??wPV8Hq>LKxCAztYzR``Aw-*&#UXaK*9CyKiyrMqkUMHTJSQYuinVoR#;HDxV}zDviw1IJRUkxRinIu5Osky!396? 
zBnlTnuSlOD#$76N;-qaR62a?@Qi}#v;}O~>X}SmV$o&WH-o)BBD!YIpZKH;{gh}Vs z^0zqnj&~~SbV$ZUMNCOUGQVl!TT5cOUtm025!R$DWiuo>r;g+mCwe@28cHHBQmT}x z`s?eJfkI>;YaBTVR)H-_C55c1R%48mnm0{gk6eethSepSyZI$ zuZPvM>8dHq+}1GJiLFLlh^$LkscQ89USGnX<_X$+JAZNI&NSE=5YDWspw8NTIVCAn zFn<|g_N7oKjwwv^mv6aA4)yH#;H=P5NE;9DJ;(?-HD=%GiO|6Gq^(~KEnfAZro*XG zknJxOAYi>bWh zD64#9L8WT?zTW3q&wfux3Z*P_;#k$#5Z)jEkjSFFKP>=SPkH(J`7s+*D)C`Q&w8OX z@w;ZRM;RHq&XhU|T6!mTyhNyIP4sT{10+p+=8E#T#?w*??QB05c0txP?_4T;Tux5L zvS{G54J_cAc>B*vVON{{kC+LtIK!P?jc7cM1y25Y-!Wn2 zwpz}vS+zsvNVb)o*MHzMMW75U?WrBGD72x;wVnotO zP^08Lc%(=>@_mKyi+xS-yvvftI6DB0?&&^5_vr^$k4ni_9yOzNh)Z9P!TdtA81>}K zjS77%bUv30v)ue)Dy^@FWS$uJy}fMUp*OwRp~}RK=f<0~MbgcT4W#-H9_bdhG+KZA zNG)fGiH$o`fx%!L^{~i_#)yCM$GqiAGF$Q1g9)YYOI`aq^hFS#t@P>T?yCMl4e0cz7KJBh*3aAeWmGFacT-I*nJS z;xO4`9JPGSp^u$RML-n0Pm)Ue6rTx5RuWXmF9xb)cp6{Y(9-9CYzWc`b`jxA!GOx0 zYyq?_=0T8^xeZ4b)ph;T3v+*1q(h?wj?5m{Adgn5B9n=zC zFUphP<~lY_X`S2C!-h(;jBF-i@a$G^NL&agTRPd|4pCvKvKY^Py_VwZ=BqkHRH#1j z0+dVeNO$~M;^=2O*Crvu)q#a2nHD)`wJe=D*5wVtXcqj!G?{NVHl&F}hTZXCd^!<& z8AuVj`HBx&3&iEbz&ZM~fSw)RuzgtST7*1#i)QUqTp{L!a_ZpXLJtSgN5Sk!I^f^q z??X%w#zYgL*Y)x54QA)m`geoBxKm~P997S<>pIu;kmko)XQ@6}LVX=fwz4liQTdgA zc0()ECbT5apY1DRg32s*Go7Br7Rf~!ASd2BL=m6L`az~D{MDc2D?DULwCRbWCN&6q zMN{G4?w8gum?kYS@))}o;SS{UwyPdyz#rVHMxW9ugo@6wxqj3#{2984u|C21LLIe8 zR=%AgBsH>YO92|V)({9(?oB1;Dr=+JA)UPmn4~GmW8U%fdB3ZQUI^Y##1j2`F1y36 z2z~l41M+Fir=dh|pnb_{nVFFMsZ9%;Nqr(0xIJiYV{j**4C(+1jGV?UKzcGmR!oEn za{go=*DB346m zyd_n+$pBN1gbq9XED0fPU9SG|3LApzy%zOp-~$nNx{Ua~!xITYz42)f8jGx{d&>$_ zO>;$!fNtY>1T~9A47j5$dV;`_QwZN=I>_K&SjlkWc$);2*7NLxALt?{#=mz@?`SOS zG$7Nmp6SasufH|{VU8g_R7wPyWEqba)M`Q4d8aXJdQV$&e#nXZjzSBEOTs+lK0jJP zhCEQeoi(oUAks+$qDdsHfkH{myUC^XLWio&7GG0!Ja%`cs>|I{HKS^e0GYMPnPZI7k#!N*jgY89S z(Y(M;A@T!OvVGHXdjcpzz>SkXUd5oihi$Puc?|%$N>^Ep`3{lux-8iC$M~FJLnQKd z!`E=W1@PgOoaAq}Q{-)VQbhAt~IXZiSvwAd0(#CL1|tZpw@zUg$e5_X2$ z?jl3!`xX8`OjsL^Pb!WQ@wRoa7<=@A*YLJR&P7qvr2SQ?L@tas`NUe zELquAOwx2URLdmbvyN)}P;otrtyowXizfxlz%?ku9ckmBk0O6+0EpH1rs}C)b;Sbh z?IfnTtlOv@J`bFJmj?E^CInvD;|cX}X+2aoL2GBbu9onbT8>N1YbV?R4=?0INkS-X9%}10fF5zxX zCi1Hxg53BXFG_$sW0JK}N-Bq5Ig{DoDmFA%DRLv3qolOjgStYoxdd-4A zMf*~0LKE|gS%?aUTOi9jK{9PdPAvoa)S*X(Rf1jX!U0t?sXn{bQe|>+%q^sd9fO&N jSPje+$G`bUnD`Xyko)q2@jRr 0) + + # Check that the plot files were created and are not empty + for img_path in image_paths: + self.assertTrue(os.path.exists(img_path)) + self.assertGreater(os.path.getsize(img_path), 0) + + def test_plot_with_valid_params(self): + self.plotter.set_params( + plot_top=10, + dat_log="log2_1p", + show_column_names=True, + scale_legend_title="Abundance", + column_title="Samples", + row_title="Features", + plot_title="Feature Importances", + feature_meta_name=["features", "my_metadata_str"], + ) + + output_path = biofit.config.BIOFIT_CACHE_HOME + + self.plotter.plot( + X=self.X, + y=self.y, + sample_metadata=self.sample_metadata, + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=False, + ) + + self._assert_plot_output(output_path) + + def test_plot_minimal_params(self): + output_path = biofit.config.BIOFIT_CACHE_HOME + + self.plotter.plot( + self.data, + feature_importances=self.feature_importances, + path=output_path, + show=False, + ) + + self._assert_plot_output(output_path) + + def test_plot_missing_feature_meta_name(self): + output_path = biofit.config.BIOFIT_CACHE_HOME + + self.plotter.set_params( + feature_meta_name="non_existent_column", + ) + + with self.assertRaises(ValueError) as context: + self.plotter.plot( + X=self.X, + y=self.y, + 
sample_metadata=self.sample_metadata, + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=False, + ) + self.assertIn( + str(context.exception), + "Feature metadata columns ['non_existent_column'] not found in " + "feature metadata.", + ) + + def test_plot_with_invalid_dat_log(self): + self.plotter.set_params(dat_log="invalid_log") + + output_path = biofit.config.BIOFIT_CACHE_HOME + + # Assuming that invalid dat_log values are handled without raising an error + self.plotter.plot( + X=self.X, + y=self.y, + sample_metadata=self.sample_metadata, + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=False, + ) + + self._assert_plot_output(output_path) + + def test_plot_with_plot_top_equal_to_num_features(self): + num_features = len(self.column_names) + self.plotter.set_params(plot_top=num_features) + + output_path = biofit.config.BIOFIT_CACHE_HOME + + self.plotter.plot( + X=self.X, + y=self.y, + sample_metadata=self.sample_metadata, + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=False, + ) + + self._assert_plot_output(output_path) + + def test_plot_without_dat_log(self): + self.plotter.set_params(dat_log=None) + + output_path = biofit.config.BIOFIT_CACHE_HOME + + self.plotter.plot( + X=self.X, + y=self.y, + sample_metadata=self.sample_metadata, + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=False, + ) + + self._assert_plot_output(output_path) + + def test_plot_with_show_column_names(self): + self.plotter.set_params(show_column_names=True) + + output_path = biofit.config.BIOFIT_CACHE_HOME + + self.plotter.plot( + X=self.X, + y=self.y, + sample_metadata=self.sample_metadata, + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=False, + ) + + self._assert_plot_output(output_path) + + def test_plot_with_custom_scale_legend_title(self): + self.plotter.set_params(scale_legend_title="Custom Legend Title") + + output_path = biofit.config.BIOFIT_CACHE_HOME + + self.plotter.plot( + X=self.X, + y=self.y, + sample_metadata=self.sample_metadata, + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=False, + ) + + self._assert_plot_output(output_path) + + def test_plot_with_input_columns(self): + # Select a subset of columns + selected_columns = self.column_names[:5] + self.plotter.set_params(input_columns=selected_columns) + + output_path = biofit.config.BIOFIT_CACHE_HOME + + self.plotter.plot( + X=self.X, + y=self.y, + sample_metadata=self.sample_metadata, + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=False, + ) + + self._assert_plot_output(output_path) + + def test_plot_with_invalid_input_columns(self): + invalid_columns = ["non_existent_column"] + + output_path = biofit.config.BIOFIT_CACHE_HOME + + with self.assertRaises(ValueError) as context: + self.plotter.plot( + X=self.X, + y=self.y, + input_columns=invalid_columns, + sample_metadata=self.sample_metadata, + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=False, + ) + self.assertIn( + "Columns {'non_existent_column'} not found in input dataset", + str(context.exception), + ) + + def test_plot_with_custom_feature_column(self): + 
# Rename the 'features' column in feature_importances + self.feature_importances.rename( + columns={"features": "feature_names"}, inplace=True + ) + self.plotter.set_params(feature_column="feature_names") + + output_path = biofit.config.BIOFIT_CACHE_HOME + + self.plotter.plot( + X=self.X, + y=self.y, + sample_metadata=self.sample_metadata, + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=False, + ) + + self._assert_plot_output(output_path) + + def test_plot_with_invalid_feature_column(self): + output_path = biofit.config.BIOFIT_CACHE_HOME + + with self.assertRaises(ValueError) as context: + self.plotter.plot( + X=self.X, + y=self.y, + feature_column="invalid_column", + sample_metadata=self.sample_metadata, + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=False, + ) + self.assertIn( + str(context.exception), + "Feature column 'invalid_column' not found in feature " + "importances. Please provide the column name found in both " + "feature importances and feature metadata (if provided).", + ) + + def test_plot_with_no_sample_metadata(self): + self.plotter.set_params(sample_metadata_columns=None) + + output_path = biofit.config.BIOFIT_CACHE_HOME + + self.plotter.plot( + X=self.X, + y=self.y, + sample_metadata=None, # sample_metadata is None + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=False, + ) + + self._assert_plot_output(output_path) + + def test_plot_with_empty_X(self): + empty_X = pd.DataFrame() + + output_path = biofit.config.BIOFIT_CACHE_HOME + + with self.assertRaises(AssertionError) as context: + self.plotter.plot( + X=empty_X, + y=self.y, + sample_metadata=self.sample_metadata, + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=False, + ) + self.assertIn("Input data has no rows", str(context.exception)) + + def test_plot_output_directory(self): + output_path = biofit.config.BIOFIT_CACHE_HOME + + self.plotter.plot( + X=self.X, + y=self.y, + sample_metadata=self.sample_metadata, + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=False, + ) + + self._assert_plot_output(output_path) + + def test_plot_show_option(self): + output_path = biofit.config.BIOFIT_CACHE_HOME + + # Since we cannot check the display in a unit test, we ensure it runs without error + self.plotter.plot( + X=self.X, + y=self.y, + sample_metadata=self.sample_metadata, + feature_importances=self.feature_importances, + feature_metadata=self.feature_metadata, + path=output_path, + show=True, + ) + + self._assert_plot_output(output_path) diff --git a/tests/test_plot_sample_metadata.py b/tests/test_plot_sample_metadata.py new file mode 100644 index 0000000..3b3069b --- /dev/null +++ b/tests/test_plot_sample_metadata.py @@ -0,0 +1,8 @@ +# from biofit import SampleMetadataPlotter +# +# +# def test_plot_sample_metadata(camda_dataset): +# plotter = SampleMetadataPlotter() +# plotter.plot(camda_dataset["train"]) +# plotter = SampleMetadataPlotter() +# plotter.plot(camda_dataset["train"]) diff --git a/tests/test_plotting_utils.py b/tests/test_plotting_utils.py new file mode 100644 index 0000000..2cf2bbb --- /dev/null +++ b/tests/test_plotting_utils.py @@ -0,0 +1,697 @@ +import shutil +import unittest +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +from biocore 
import DataHandler +from biocore.utils.import_util import is_biosets_available, is_polars_available + +import biofit.config +from biofit.utils.py_util import set_seed +from biofit.visualization import ( + generate_comparison_histogram, + plot_correlation, +) +from biofit.visualization.plotting_utils import ( + generate_violin, + plot_dimension_reduction, + plot_feature_importance, + plot_sample_metadata, +) +from tests.utils import create_bioset, require_biosets, require_polars, require_rpy2 + +SUPPORTED_MODELS = ["lightgbm", "lasso", "random_forest"] # , "svm"] + +FORMATS = [ + "pandas", + "polars", + "numpy", + "arrow", + "dataset", +] + + +# Really should be an integration test but we haven't implemented the unit tests +# for each of the plotting classes +pytestmark = pytest.mark.unit + + +class TestPlottingUtils(unittest.TestCase): + @pytest.fixture(autouse=True) + def inject_fixtures(self, count_data, sample_metadata): + self.sample_metadata = sample_metadata + self.metadata_columns = list(sample_metadata.columns) + self.X, self.y = count_data + self.input_columns = list(self.X.columns) + self.target_column = list(self.y.columns)[0] + self.data = pd.concat([self.sample_metadata, self.X, self.y], axis=1) + + self.feature_importances = pd.DataFrame( + { + "features": self.input_columns, + "importances_1": np.random.rand(len(self.input_columns)), + "importances_2": np.random.rand(len(self.input_columns)), + } + ) + # Convert the dataset to various formats + if is_biosets_available(): + from biosets.features import Abundance, BinClassLabel + + self.dataset_all = create_bioset( + X=self.X, + y=self.y, + sample_metadata=self.sample_metadata, + with_feature_metadata=True, + feature_type=Abundance, + target_type=BinClassLabel, + ) + + self.pandas_all = self.data + if is_polars_available(): + self.polars_all = DataHandler.to_polars(self.data) + self.arrow_all = DataHandler.to_arrow(self.data) + + # Extract data and target in various formats + self.numpy_data = DataHandler.to_numpy(self.X) + self.numpy_target = DataHandler.to_numpy(self.y) + + self.pandas_data = self.X + self.pandas_target = self.y + + self.polars_data = DataHandler.to_polars(self.X) + self.polars_target = DataHandler.to_polars(self.y) + + self.arrow_data = DataHandler.to_arrow(self.X) + self.arrow_target = DataHandler.to_arrow(self.y) + + self.pandas_sample_metadata = self.sample_metadata + self.polars_sample_metadata = DataHandler.to_polars(self.sample_metadata) + self.arrow_sample_metadata = DataHandler.to_arrow(self.sample_metadata) + + def setUp(self): + set_seed(42) + self.output_dir = Path(biofit.config.BIOFIT_CACHE_HOME) + self.output_dir.mkdir(exist_ok=True, parents=True) + + def tearDown(self): + if self.output_dir.exists(): + shutil.rmtree(self.output_dir) + + def _assert_plot_outputs(self): + assert self.output_dir.is_dir() + assert len([f for f in self.output_dir.iterdir() if f.is_file()]) > 0 + + # For dataset format + @require_rpy2 + def test_feature_importance_dataset(self): + plot_feature_importance( + self.feature_importances, + X=self.data, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_feature_importance_pandas(self): + plot_feature_importance( + self.feature_importances, + X=self.pandas_all, + input_columns=self.input_columns, + target_columns=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_feature_importance_pandas_separate(self): + plot_feature_importance( + 
self.feature_importances, + X=self.pandas_data, + y=self.pandas_target, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_feature_importance_polars(self): + plot_feature_importance( + self.feature_importances, + X=self.polars_all, + input_columns=self.input_columns, + target_columns=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_feature_importance_polars_separate(self): + plot_feature_importance( + self.feature_importances, + X=self.polars_data, + y=self.polars_target, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_feature_importance_arrow(self): + plot_feature_importance( + self.feature_importances, + X=self.arrow_all, + input_columns=self.input_columns, + target_columns=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_feature_importance_arrow_separate(self): + plot_feature_importance( + self.feature_importances, + X=self.arrow_data, + y=self.arrow_target, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_feature_importance_numpy(self): + feature_importances = pd.DataFrame( + { + "features": [f"col_{i}" for i in range(self.numpy_data.shape[1])], + "importances_1": np.random.rand(len(self.input_columns)), + "importances_2": np.random.rand(len(self.input_columns)), + } + ) + + plot_feature_importance( + feature_importances, + X=self.numpy_data, + y=self.numpy_target, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_sample_metadata_pandas(self): + input_columns = self.metadata_columns + plot_sample_metadata( + self.pandas_all, + sample_metadata_columns=input_columns, + outcome_column=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_sample_metadata_polars(self): + input_columns = self.metadata_columns + plot_sample_metadata( + self.polars_all, + sample_metadata_columns=input_columns, + outcome_column=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_sample_metadata_arrow(self): + input_columns = self.metadata_columns + plot_sample_metadata( + self.arrow_all, + sample_metadata_columns=input_columns, + outcome_column=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_sample_metadata_dataset(self): + plot_sample_metadata( + self.data, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + # Test methods for violin plot + @require_rpy2 + def test_violin_plot_numpy(self): + generate_violin( + self.numpy_data, + self.numpy_target, + xlab="test", + ylab="test", + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_violin_plot_pandas(self): + generate_violin( + self.pandas_all, + xlab="test", + ylab="test", + column=self.input_columns[0], + label_name=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_violin_plot_polars(self): + generate_violin( + self.polars_all, + xlab="test", + ylab="test", + column=self.input_columns[0], + label_name=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def 
test_violin_plot_arrow(self): + generate_violin( + self.arrow_all, + xlab="test", + ylab="test", + column=self.input_columns[0], + label_name=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_violin_plot_dataset(self): + generate_violin( + self.data, + xlab="test", + ylab="test", + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + # Test methods for correlation plot + @require_rpy2 + def test_correlation_plot_numpy(self): + plot_correlation( + self.numpy_data, + self.numpy_target, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_correlation_plot_pandas(self): + plot_correlation( + self.pandas_all, + input_columns=self.input_columns, + target_column=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_correlation_plot_polars(self): + plot_correlation( + self.polars_all, + input_columns=self.input_columns, + target_column=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_correlation_plot_arrow(self): + plot_correlation( + self.arrow_all, + input_columns=self.input_columns, + target_column=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_correlation_plot_dataset(self): + plot_correlation( + self.data, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + # Test methods for comparison histogram + @require_rpy2 + def test_comparison_histogram_numpy(self): + generate_comparison_histogram( + self.numpy_data[:, 0], + self.numpy_data[:, 1], + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_comparison_histogram_pandas(self): + input_columns = self.input_columns + generate_comparison_histogram( + self.pandas_all, + column1=input_columns[0], + column2=input_columns[1], + subplot_title1=input_columns[0], + subplot_title2=input_columns[1], + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_comparison_histogram_polars(self): + input_columns = self.input_columns + generate_comparison_histogram( + self.polars_all, + column1=input_columns[0], + column2=input_columns[1], + subplot_title1=input_columns[0], + subplot_title2=input_columns[1], + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_comparison_histogram_arrow(self): + input_columns = self.input_columns + generate_comparison_histogram( + self.arrow_all, + column1=input_columns[0], + column2=input_columns[1], + subplot_title1=input_columns[0], + subplot_title2=input_columns[1], + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_comparison_histogram_dataset(self): + input_columns = self.input_columns + generate_comparison_histogram( + self.data, + column1=input_columns[0], + column2=input_columns[1], + subplot_title1=input_columns[0], + subplot_title2=input_columns[1], + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + # Test methods for PCoA plot + def test_pcoa_plot_numpy(self): + plot_dimension_reduction( + self.numpy_data, + labels=self.numpy_target, + method="pcoa", + method_kwargs={"correction": "cailliez"}, + n_components=3, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + def 
test_pcoa_plot_pandas(self): + plot_dimension_reduction( + self.pandas_all, + method="pcoa", + method_kwargs={"correction": "cailliez"}, + n_components=3, + input_columns=self.input_columns, + label_column=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_polars + def test_pcoa_plot_polars(self): + plot_dimension_reduction( + self.polars_all, + method="pcoa", + method_kwargs={"correction": "cailliez"}, + n_components=3, + input_columns=self.input_columns, + label_column=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + def test_pcoa_plot_arrow(self): + plot_dimension_reduction( + self.arrow_all, + method="pcoa", + method_kwargs={"correction": "cailliez"}, + n_components=3, + input_columns=self.input_columns, + label_column=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_rpy2 + def test_pcoa_plot_dataset(self): + plot_dimension_reduction( + self.dataset_all, + method="pcoa", + method_kwargs={"correction": "cailliez"}, + n_components=3, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + # Test methods for PCA plot + def test_pca_plot_numpy(self): + plot_dimension_reduction( + self.numpy_data, + labels=self.numpy_target, + method="pca", + n_components=3, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + def test_pca_plot_pandas(self): + plot_dimension_reduction( + self.pandas_all, + method="pca", + n_components=3, + input_columns=self.input_columns, + label_column=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + def test_pca_plot_polars(self): + plot_dimension_reduction( + self.polars_all, + method="pca", + n_components=3, + input_columns=self.input_columns, + label_column=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + def test_pca_plot_arrow(self): + plot_dimension_reduction( + self.arrow_all, + method="pca", + n_components=3, + input_columns=self.input_columns, + label_column=self.target_column, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + @require_biosets + def test_pca_plot_dataset(self): + plot_dimension_reduction( + self.dataset_all, + method="pca", + n_components=3, + output_dir=self.output_dir.as_posix(), + ) + self._assert_plot_outputs() + + # # Test methods for feature distribution plot + # def test_feature_distribution_plot_numpy(self): + # plot_feature_distribution( + # self.numpy_data, + # self.numpy_target, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_feature_distribution_plot_pandas(self): + # plot_feature_distribution( + # self.pandas_all, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_feature_distribution_plot_polars(self): + # plot_feature_distribution( + # self.polars_all, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_feature_distribution_plot_arrow(self): + # plot_feature_distribution( + # self.arrow_all, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_feature_distribution_plot_dataset(self): + # plot_feature_distribution( + # self.otu_dataset, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # # Test methods for compare feature distributions plot + # def 
test_compare_feature_distribution_plot_numpy(self): + # compare_feature_distributions( + # self.numpy_data, + # self.numpy_data, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_compare_feature_distribution_plot_pandas(self): + # compare_feature_distributions( + # self.pandas_all, + # self.pandas_all, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_compare_feature_distribution_plot_polars(self): + # compare_feature_distributions( + # self.polars_all, + # self.polars_all, + # columns1=self.input_columns, + # columns2=self.input_columns, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_compare_feature_distribution_plot_arrow(self): + # compare_feature_distributions( + # self.arrow_all, + # self.arrow_all, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_compare_feature_distribution_plot_dataset(self): + # compare_feature_distributions( + # self.otu_dataset, + # self.otu_dataset, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # # Test methods for barplot + # def test_barplot_numpy(self): + # generate_barplot( + # self.numpy_data[:, 0], + # self.numpy_target, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_barplot_pandas(self): + # generate_barplot( + # self.pandas_all, + # value_name=self.input_columns[0], + # label_name=self.target_column, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_barplot_polars(self): + # generate_barplot( + # self.polars_all, + # value_name=self.input_columns[0], + # label_name=self.target_column, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_barplot_arrow(self): + # generate_barplot( + # self.arrow_all, + # value_name=self.input_columns[0], + # label_name=self.target_column, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_barplot_dataset(self): + # generate_barplot( + # self.otu_dataset, + # value_name=self.input_columns[0], + # label_name=self.target_column, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # # Test methods for scatterplot + # def test_scatterplot_numpy(self): + # generate_scatterplot( + # x=self.numpy_data[:, 0], + # y=self.numpy_data[:, 1], + # group=self.numpy_target, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_scatterplot_pandas(self): + # input_columns = self.input_columns + # generate_scatterplot( + # x=self.pandas_all, + # xdata=input_columns[0], + # ydata=input_columns[1], + # groupby=self.target_column, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_scatterplot_polars(self): + # input_columns = self.input_columns + # generate_scatterplot( + # x=self.polars_all, + # xdata=input_columns[0], + # ydata=input_columns[1], + # groupby=self.target_column, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_scatterplot_arrow(self): + # input_columns = self.input_columns + # generate_scatterplot( + # x=self.arrow_all, + # xdata=input_columns[0], + # ydata=input_columns[1], + # groupby=self.target_column, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() + # + # def test_scatterplot_dataset(self): + # 
input_columns = self.input_columns + # generate_scatterplot( + # x=self.otu_dataset, + # xdata=input_columns[0], + # ydata=input_columns[1], + # groupby=self.target_column, + # output_dir=self.output_dir.as_posix(), + # ) + # self._assert_plot_outputs() diff --git a/tests/test_processing.py b/tests/test_processing.py new file mode 100644 index 0000000..1e30cf5 --- /dev/null +++ b/tests/test_processing.py @@ -0,0 +1,2266 @@ +import copy +import inspect +import shutil +import unittest +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Optional, Type, Union +from unittest.mock import patch + +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest +from biocore import DataHandler + +import biofit.config +from biofit.integration.biosets import get_feature +from biofit.processing import ( + BaseProcessor, + NonExistentCacheError, + ProcessorConfig, + SelectedColumnTypes, + sync_backup_config, +) +from biofit.utils import version +from biofit.utils.fingerprint import generate_cache_dir +from tests.utils import create_bioset, require_biosets, require_datasets + +# Mock feature types for testing purposes +FEATURE_TYPES = Union[Type, tuple] + +pytestmark = pytest.mark.unit + + +@dataclass +class MockProcessorConfig(ProcessorConfig): + """ + Configuration for the MockModel. Inherits from ProcessorConfig and specifies + feature types for automatic column selection. + """ + + # Specifies the input feature types for the fit method (arity of 2: X and y) + _fit_input_feature_types: List[FEATURE_TYPES] = field( + default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")], + init=False, + repr=False, + ) + # Specifies the unused feature types during fitting + _fit_unused_feature_types: List[FEATURE_TYPES] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None], + init=False, + repr=False, + ) + # Specifies the unused feature types during transformation + _transform_unused_feature_types: List[FEATURE_TYPES] = field( + default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")], + init=False, + repr=False, + ) + _fit_process_desc: str = field(default="Fitting test", init=False, repr=False) + _predict_process_desc: str = field(default="Predict test", init=False, repr=False) + _transform_process_desc: str = field( + default="Transforming test", init=False, repr=False + ) + # Name of the processor + processor_type: str = field(default="mock", init=False, repr=False) + processor_name: str = field(default="processor", init=False, repr=False) + # Additional parameters for testing + processor_param_int: int = None + processor_param_float: float = None + processor_param_str: str = None + processor_param_bool: bool = None + processor_param_list: List = None + processor_param_dict: dict = None + processor_param_tuple: tuple = None + + +class MockModel(BaseProcessor): + """ + A mock model class for testing ProcessorConfig, TransformationMixin, and BaseProcessor. + Implements the necessary _fit_* and _predict_* methods. 
+ """ + + config_class = MockProcessorConfig + config: MockProcessorConfig + + def __init__(self, config: Optional[MockProcessorConfig] = None, **kwargs): + # Initialize the BaseProcessor with the given configuration + super().__init__(config=config or self.config_class(**kwargs)) + self.post_fit = lambda x: x + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + self.post_fit = lambda x: x + + def fit( + self, + X, + y=None, + input_columns: SelectedColumnTypes = None, + target_column: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "MockModel": + """ + Mock fit method that processes input data and fits the model. + """ + # Prepare input columns (arity of 2: X and y) + self.config._input_columns = self._set_input_columns_and_arity( + input_columns, target_column + ) + return self._process_fit( + X, + y, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def predict( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + """ + Mock predict method that generates predictions. + """ + self._method_prefix = "_predict" + self.config._n_features_out = 1 + self._input_columns = self._set_input_columns_and_arity(input_columns) + self.output_dtype = "float64" + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def predict_proba( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + """ + Mock predict_proba method that generates predictions. 
+ """ + self._method_prefix = "_predict_proba" + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _fit_sklearn(self, X, y, some_param=None): + """ + Internal fit method for sklearn-compatible data formats. + """ + # Mock fitting logic (e.g., training a simple model) + self.config.processor_param_numpy = np.array([1, 2, 3]) + self.config.processor_param_arrow = pa.table({"a": [1, 2, 3]}) + self.config.processor_param_polars = pa.table({"a": [1, 2, 3]}) + self.config.processor_param_pandas = pa.table({"a": [1, 2, 3]}) + self.config.estimator = {"mean": np.mean(DataHandler.to_numpy(y).flatten())} + return self + + def _predict_sklearn(self, X): + """ + Internal predict method for sklearn-compatible data formats. + """ + if not self.config.is_fitted: + raise ValueError("Model is not fitted yet.") + # Mock prediction logic (e.g., returning the mean value) + predictions = np.full(len(X), self.config.estimator["mean"]) + return predictions + + def _predict_proba_sklearn(self, X): + """ + Internal predict_proba method for sklearn-compatible data formats. + """ + if not self.config.is_fitted: + raise ValueError("Model is not fitted yet.") + # Mock prediction logic (e.g., returning the mean value) + predictions = np.full( + (len(X), self.config.n_classes), self.config.estimator["mean"] + ) + return predictions + + +class MockPreprocessor(BaseProcessor): + """ + A mock model class for testing ProcessorConfig, TransformationMixin, and BaseProcessor. + Implements the necessary _fit_* and _predict_* methods. + """ + + config_class = MockProcessorConfig + config: MockProcessorConfig + + def __init__(self, config: Optional[MockProcessorConfig] = None, **kwargs): + # Initialize the BaseProcessor with the given configuration + super().__init__(config=config or self.config_class(**kwargs)) + self.post_fit = lambda x: x + + @sync_backup_config + def set_params(self, **kwargs): + self.config = self.config.replace_defaults(**kwargs) + self.post_fit = lambda x: x + + def fit( + self, + X, + y=None, + input_columns: SelectedColumnTypes = None, + target_column: SelectedColumnTypes = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ) -> "MockModel": + """ + Mock fit method that processes input data and fits the model. 
+ """ + # Prepare input columns (arity of 2: X and y) + self.config._input_columns = self._set_input_columns_and_arity( + input_columns, target_column + ) + return self._process_fit( + X, + y, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def transform( + self, + X, + input_columns: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + cache_dir: str = None, + cache_file_name: str = None, + load_from_cache_file: bool = None, + batched: bool = None, + batch_size: int = None, + batch_format: str = None, + output_format: str = None, + map_kwargs: dict = None, + num_proc: int = None, + fingerprint: str = None, + ): + """ + Mock predict method that generates predictions. + """ + self._method_prefix = "_transform" + self._input_columns = self._set_input_columns_and_arity(input_columns) + return self._process_transform( + X, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def fit_transform( + self, + X, + y=None, + input_columns: SelectedColumnTypes = None, + target_column: SelectedColumnTypes = None, + keep_unused_columns: bool = None, + raise_if_missing: bool = None, + cache_output: bool = None, + load_from_cache_file: bool = None, + batched: bool = True, + batch_size: int = 1000, + output_format: str = None, + batch_format: str = None, + num_proc: int = None, + map_kwargs: dict = {"fn_kwargs": {}}, + cache_dir: str = None, + cache_file_name: str = None, + fingerprint: str = None, + ): + """ + Mock fit_transform method that processes input data and fits the model. + """ + return self.fit( + X, + y, + input_columns=input_columns, + target_column=target_column, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + num_proc=num_proc, + map_kwargs=map_kwargs, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + fingerprint=fingerprint, + ).transform( + X, + input_columns=input_columns, + keep_unused_columns=keep_unused_columns, + raise_if_missing=raise_if_missing, + cache_output=cache_output, + cache_dir=cache_dir, + cache_file_name=cache_file_name, + load_from_cache_file=load_from_cache_file, + batched=batched, + batch_size=batch_size, + batch_format=batch_format, + output_format=output_format, + map_kwargs=map_kwargs, + num_proc=num_proc, + fingerprint=fingerprint, + ) + + def _fit_pandas(self, X, y): + """ + Internal fit method for sklearn-compatible data formats. + """ + return self + + def _fit_arrow(self, X, y): + """ + Internal fit method for sklearn-compatible data formats. + """ + return self + + def _fit_polars(self, X, y): + """ + Internal fit method for sklearn-compatible data formats. + """ + return self + + def _fit_numpy(self, X, y): + """ + Internal fit method for sklearn-compatible data formats. 
+ """ + return self + + def _transform_pandas(self, X): + """ + Internal predict method for sklearn-compatible data formats. + """ + return X + + def _transform_arrow(self, X): + """ + Internal predict method for sklearn-compatible data formats. + """ + return X + + def _transform_polars(self, X): + """ + Internal predict method for sklearn-compatible data formats. + """ + return X + + def _transform_numpy(self, X): + """ + Internal predict method for sklearn-compatible data formats. + """ + return X + + +class TestMockModel(unittest.TestCase): + @pytest.fixture(autouse=True) + def inject_fixtures(self, count_data, sample_metadata): + self.sample_metadata = sample_metadata + self.metadata_columns = list(sample_metadata.columns) + self.X, self.y = count_data + self.column_names = list(self.X.columns) + self.target_column = self.y.columns[0] + self.data = pd.concat([self.sample_metadata, self.X, self.y], axis=1) + + def setUp(self): + self.params = { + "version": version.__version__, + "processor_param_int": 1, + "processor_param_float": 1.0, + "processor_param_str": "test", + "processor_param_bool": True, + "processor_param_list": [1, 2, 3], + "processor_param_dict": {"a": 1, "b": 2}, + "processor_param_tuple": (1, 2), + } + self.config = MockProcessorConfig(**self.params) + self.model = MockModel(config=self.config) + self.preprocessor = MockPreprocessor(config=self.config) + self.funcs = [ + self.preprocessor._fit_pandas, + self.preprocessor._fit_numpy, + self.preprocessor._fit_arrow, + self.preprocessor._fit_polars, + ] + self.accepted_formats = ["pandas", "numpy", "arrow", "polars"] + + def tearDown(self): + # clean up the cache directory + cache_dir = biofit.config.BIOFIT_CACHE_HOME + if cache_dir.exists(): + shutil.rmtree(cache_dir) + + def test_init(self): + model_params = self.model.config.get_params() + preprocessor_params = self.preprocessor.config.get_params() + self.assertEqual(model_params, preprocessor_params) + self.assertEqual(model_params, self.params) + self.assertEqual(preprocessor_params, self.params) + + def test_get_method(self): + formats = ["numpy", "pandas", "arrow", "polars"] + func_type = "_fit" + methods = self.preprocessor._get_method(formats, func_type) + method_names = [method.__name__ for method in methods] + expected_method_names = [ + "_fit_numpy", + "_fit_pandas", + "_fit_arrow", + "_fit_polars", + ] + self.assertEqual(method_names, expected_method_names) + + def test_has_method(self): + formats = ["numpy", "pandas", "arrow", "polars"] + func_type = "_fit" + has_method = self.preprocessor._has_method(formats, func_type) + self.assertTrue(has_method) + formats = ["nonexistent_format"] + has_method = self.preprocessor._has_method(formats, func_type) + self.assertFalse(has_method) + + def test_get_target_func_match_source_format(self): + # Testing _get_target_func method + # Priority should be source format + func, to_format = self.preprocessor._get_target_func( + funcs=self.funcs, + source_format="numpy", + accepted_formats=self.accepted_formats, + ) + + self.assertEqual(func.__name__, "_fit_numpy") + self.assertEqual(to_format, "numpy") + + def test_get_target_func_priority(self): + # Priority should be the first format found in funcs in the order of + # accepted_formats + func, to_format = self.preprocessor._get_target_func( + funcs=self.funcs[-2:], + source_format="not_a_format", + accepted_formats=self.accepted_formats, + ) + + self.assertEqual(func.__name__, f"_fit_{self.accepted_formats[-2]}") + self.assertEqual(to_format, self.accepted_formats[-2]) + 
+ def test_get_target_func_priority_order(self): + # Priority is the first format found in funcs in the order of + # target_formats + func, to_format = self.preprocessor._get_target_func( + funcs=self.funcs, + source_format="pandas", + target_formats=["polars", "arrow"], + ) + self.assertEqual(func.__name__, "_fit_polars") + self.assertEqual(to_format, "polars") + + def test_set_input_columns_and_arity(self): + result = self.preprocessor._set_input_columns_and_arity( + "col1", ["col2", "col3"] + ) + expected = [["col1"], ["col2", "col3"]] + self.assertEqual(result, expected) + + def test_reinsert_columns_valid(self): + """Test reinsert_columns with valid inputs and unused indices.""" + # Input data (input) + input_data = pd.DataFrame( + { + "col1": np.arange(10), + "col2": np.arange(10, 20), + "col3": np.arange(20, 30), + "col4": np.arange(30, 40), + } + ) + # Output data (out) + output_data = pd.DataFrame({"col1_transformed": np.arange(100, 110)}) + # Indices of unused columns (unused_indices) + indices = [0, 3] # 'col1' and 'col4' transformed into 'col1_transformed' + unused_indices = [1, 2] # 'col2' and 'col3' were unused + + self.preprocessor.config._n_features_out = 1 + result = self.preprocessor._reinsert_columns( + input=input_data, + out=output_data, + indices=indices, + unused_indices=unused_indices, + one_to_one_features=self.preprocessor.config.one_to_one_features, + ) + self.preprocessor.config._n_features_out = None + + # When one_to_one_features=False, the output columns should be + # appended to the end of the input data + # Expected result: 'col2', 'col3', and 'col1_transformed' + other_cols = input_data.iloc[:, unused_indices] + concatenated = pd.concat([output_data, other_cols], axis=1) + expected_output = concatenated[["col2", "col3", "col1_transformed"]] + + # Assert that the result matches the expected output + pd.testing.assert_frame_equal(result, expected_output) + + def test_reinsert_columns_one_to_one_features(self): + """Test reinsert_columns with one_to_one_features=True.""" + # Input data + input_data = pd.DataFrame( + { + "col0": np.arange(10), + "col1": np.arange(10, 20), + "col2": np.arange(20, 30), + "col3": np.arange(30, 40), + } + ) + # Output data + output_data = pd.DataFrame( + { + "col0_transformed": np.arange(100, 110), + "col3_transformed": np.arange(200, 210), + } + ) + # Indices used + indices = [0, 3] # 'col0' and 'col3' transformed + # Unused indices + unused_indices = [1, 2] # 'col1' and 'col2' unused + + self.preprocessor.config._n_features_out = None + + # when one_to_one_features=True, the output columns should be follow + # the same order as the input columns + result = self.preprocessor._reinsert_columns( + input=input_data, + out=output_data, + indices=indices, + unused_indices=unused_indices, + one_to_one_features=self.preprocessor.config.one_to_one_features, + ) + + # Expected result + other_cols = input_data.iloc[:, unused_indices] + expected_output = pd.concat([other_cols, output_data], axis=1) + expected_output = expected_output[ + ["col0_transformed", "col1", "col2", "col3_transformed"] + ] + + # Assert + pd.testing.assert_frame_equal(result, expected_output) + + def test_reinsert_columns_no_unused_columns(self): + """Test when there are no unused columns.""" + input_data = pd.DataFrame({"col1": np.arange(10)}) + output_data = pd.DataFrame({"col1_transformed": np.arange(100, 110)}) + indices = [0] + unused_indices = [] + + result = self.preprocessor._reinsert_columns( + input_data, output_data, indices, unused_indices + ) + + # 
Expected result is the same as output_data + pd.testing.assert_frame_equal(result, output_data) + + def test_reinsert_columns_mismatched_row_counts(self): + """Test when input and output have different number of rows.""" + input_data = pd.DataFrame({"col1": np.arange(10)}) + output_data = pd.DataFrame( + { + "col1_transformed": np.arange(5) # Different number of rows + } + ) + indices = [0] + unused_indices = [] + + result = self.preprocessor._reinsert_columns( + input_data, output_data, indices, unused_indices + ) + + # Since row counts differ, should return output as is + pd.testing.assert_frame_equal(result, output_data) + + def test_reinsert_columns_invalid_indices(self): + """Test with invalid indices (out of bounds).""" + input_data = pd.DataFrame({"col1": np.arange(10)}) + output_data = pd.DataFrame({"col1_transformed": np.arange(10)}) + indices = [0] + unused_indices = [1] # Invalid index + + with self.assertRaises(IndexError): + self.preprocessor._reinsert_columns( + input_data, output_data, indices, unused_indices + ) + + def test_reinsert_columns_empty_input(self): + """Test with empty input data.""" + input_data = pd.DataFrame() + output_data = pd.DataFrame({"col_transformed": []}) + indices = [] + unused_indices = [] + + result = self.preprocessor._reinsert_columns( + input_data, output_data, indices, unused_indices + ) + + # Expected result is the same as output_data + pd.testing.assert_frame_equal(result, output_data) + + def test_reinsert_columns_empty_output(self): + """Test with empty output data and unused columns.""" + input_data = pd.DataFrame({"col1": [], "col2": []}) + output_data = pd.DataFrame() + indices = [] + unused_indices = [0, 1] + + result = self.preprocessor._reinsert_columns( + input_data, output_data, indices, unused_indices + ) + + # Expected result is input_data with unused columns + pd.testing.assert_frame_equal(result, input_data) + + def test_reinsert_columns_different_column_types(self): + """Test with different data types in columns.""" + input_data = pd.DataFrame( + { + "int_col": np.arange(10), + "float_col": np.random.rand(10), + "str_col": ["text"] * 10, + } + ) + output_data = pd.DataFrame({"int_col_transformed": np.arange(100, 110)}) + indices = [0] + unused_indices = [1, 2] + + result = self.preprocessor._reinsert_columns( + input_data, output_data, indices, unused_indices, one_to_one_features=True + ) + + other_cols = input_data[["float_col", "str_col"]] + expected_output = pd.concat([output_data, other_cols], axis=1) + expected_output = expected_output[ + ["int_col_transformed", "float_col", "str_col"] + ] + + pd.testing.assert_frame_equal(result, expected_output) + + def test_reinsert_columns_null_values(self): + """Test with null values in input data.""" + input_data = pd.DataFrame( + { + "col1": [1, np.nan, 3, np.nan, 5], + "col2": [np.nan, 2, np.nan, 4, np.nan], + } + ) + output_data = pd.DataFrame({"col1_transformed": [10, 20, 30, 40, 50]}) + indices = [0] + unused_indices = [1] + + result = self.preprocessor._reinsert_columns( + input_data, output_data, indices, unused_indices, one_to_one_features=True + ) + + other_cols = input_data[["col2"]] + expected_output = pd.concat([output_data, other_cols], axis=1) + expected_output = expected_output[["col1_transformed", "col2"]] + + pd.testing.assert_frame_equal(result, expected_output) + + def test_reinsert_columns_non_dataframe_input(self): + """Test with input data that is not a DataFrame.""" + input_data = pd.DataFrame( + { + "col1": np.arange(10), + "col2": np.arange(10, 20), + "col3": 
np.arange(20, 30), + } + ) + output_data = np.arange(10) + indices = [0] + unused_indices = [1, 2] + + output_data = self.preprocessor._reinsert_columns( + input_data, output_data, indices, unused_indices + ) + # format type should match the input data + self.assertIsInstance(output_data, pd.DataFrame) + + def test_make_columns_exclusive(self): + columns = [["col1", "col2"], ["col2", "col3"], ["col3", "col4"]] + result = self.preprocessor._make_columns_exclusive(columns) + expected = [["col1"], ["col2"], ["col3", "col4"]] + self.assertEqual(result, expected) + + def test_get_columns_valid_input_columns(self): + """Test _get_columns with valid input_columns""" + input_columns = self.column_names + result = self.preprocessor._get_columns( + self.X, input_columns=[input_columns], raise_if_missing=True + ) + # Unpack results + ( + feature_names_in, + feature_idx_in, + unused_idx_in, + extra_names_in, + extra_idx_in, + unused_extra_idx_in, + offsets, + ) = result + + # Assertions + self.assertEqual(feature_names_in, input_columns) + expected_indices = [self.column_names.index(col) for col in input_columns] + self.assertEqual(feature_idx_in, expected_indices) + expected_unused_indices = sorted( + set(range(len(self.column_names))) - set(expected_indices) + ) + self.assertEqual(unused_idx_in, expected_unused_indices) + self.assertIsNone(extra_names_in) + self.assertIsNone(extra_idx_in) + self.assertIsNone(unused_extra_idx_in) + self.assertIsNone(offsets) + + def test_get_columns_invalid_input_columns_raise(self): + """Test _get_columns with invalid input_columns and raise_if_missing=True""" + input_columns = self.column_names + ["non_existent_column"] + with self.assertRaises(ValueError) as context: + self.preprocessor._get_columns( + self.X, input_columns=[input_columns], raise_if_missing=True + ) + self.assertIn( + str(context.exception), + "Columns {'non_existent_column'} not found in input dataset", + ) + + def test_get_columns_invalid_input_columns_no_raise(self): + """Test _get_columns with invalid input_columns and raise_if_missing=False""" + input_columns = self.column_names + ["non_existent_column"] + result = self.preprocessor._get_columns( + self.X, input_columns=[input_columns], raise_if_missing=False + ) + ( + feature_names_in, + feature_idx_in, + unused_idx_in, + extra_names_in, + extra_idx_in, + unused_extra_idx_in, + offsets, + ) = result + + self.assertEqual(feature_names_in, self.column_names) + expected_indices = list(range(len(self.column_names))) + self.assertEqual(feature_idx_in, expected_indices) + expected_unused_indices = sorted( + set(range(len(self.column_names))) - set(expected_indices) + ) + self.assertEqual(unused_idx_in, expected_unused_indices) + self.assertIsNone(extra_names_in) + self.assertIsNone(extra_idx_in) + self.assertIsNone(unused_extra_idx_in) + self.assertIsNone(offsets) + + def test_get_columns_unused_columns(self): + """Test _get_columns with unused_columns specified""" + unused_columns = self.column_names[:2] + result = self.preprocessor._get_columns( + self.X, + input_columns=[None], + unused_columns=[unused_columns], + raise_if_missing=True, + ) + # Unpack results + ( + feature_names_in, + feature_idx_in, + unused_idx_in, + extra_names_in, + extra_idx_in, + unused_extra_idx_in, + offsets, + ) = result + + expected_input_columns = [ + col for col in self.column_names if col not in unused_columns + ] + self.assertEqual(feature_names_in, expected_input_columns) + expected_indices = [ + self.column_names.index(col) for col in expected_input_columns + ] 
+ self.assertEqual(feature_idx_in, expected_indices) + expected_unused_indices = sorted( + [self.column_names.index(col) for col in unused_columns] + ) + self.assertEqual(unused_idx_in, expected_unused_indices) + self.assertIsNone(extra_names_in) + self.assertIsNone(extra_idx_in) + self.assertIsNone(unused_extra_idx_in) + self.assertIsNone(offsets) + + def test_get_columns_with_args(self): + """Test _get_columns with additional args (extra inputs)""" + # Extra input data + extra_X = { + "extra_feature_1": np.random.rand(50), + "extra_feature_2": np.random.rand(50), + } + input_columns = [self.column_names, ["extra_feature_1"]] + result = self.preprocessor._get_columns( + self.X, extra_X, input_columns=input_columns, raise_if_missing=True + ) + # Unpack results + ( + feature_names_in, + feature_idx_in, + unused_idx_in, + extra_names_in, + extra_idx_in, + unused_extra_idx_in, + offsets, + ) = result + + # Assertions for main input + self.assertEqual(feature_names_in, self.column_names) + expected_indices = list(range(len(self.column_names))) + self.assertEqual(feature_idx_in, expected_indices) + expected_unused_indices = sorted( + set(range(len(self.column_names))) - set(expected_indices) + ) + self.assertEqual(unused_idx_in, expected_unused_indices) + + # Assertions for extra input + self.assertEqual(extra_names_in, [["extra_feature_1"]]) + self.assertEqual(extra_idx_in, [[0]]) + self.assertEqual(unused_extra_idx_in, [[1]]) # Only 'extra_feature_2' is unused + self.assertEqual(offsets, []) # No offsets since extra input not concatenated + + def test_get_columns_empty_input_columns(self): + """Test _get_columns with empty input_columns""" + result = self.preprocessor._get_columns(self.X, input_columns=[]) + self.assertEqual(result, (None, None, None, None, None, None, None)) + + def test_get_columns_input_feature_types(self): + """Test _get_columns with input_feature_types specified""" + input_feature_types = get_feature("Abundance") + result = self.preprocessor._get_columns( + self.X, + input_columns=[None], + input_feature_types=[input_feature_types], + raise_if_missing=True, + ) + # Unpack results + ( + feature_names_in, + feature_idx_in, + unused_idx_in, + extra_names_in, + extra_idx_in, + unused_extra_idx_in, + offsets, + ) = result + + # Assertions + self.assertEqual(feature_names_in, self.column_names) + expected_indices = list(range(len(self.column_names))) + self.assertEqual(feature_idx_in, expected_indices) + expected_unused_indices = sorted( + set(range(len(self.column_names))) - set(expected_indices) + ) + self.assertEqual(unused_idx_in, expected_unused_indices) + self.assertIsNone(extra_names_in) + self.assertIsNone(extra_idx_in) + self.assertIsNone(unused_extra_idx_in) + self.assertIsNone(offsets) + + @require_biosets + def test_get_columns_invalid_input_feature_types(self): + """Test _get_columns with invalid input_feature_types""" + + X = create_bioset( + self.X, + feature_type=get_feature("Abundance"), + ) + + # Dataset uses Abundance feature type, but GenomicVariant is specified + result = self.preprocessor._get_columns( + X, + input_columns=[None], + input_feature_types=[get_feature("GenomicVariant")], + raise_if_missing=False, + ) + # Should return None values due to invalid feature types + self.assertEqual(result, (None, None, None, None, None, None, None)) + + def test_get_columns_mismatched_arity(self): + """Test _get_columns with mismatched arity between input_columns and args""" + # No args provided, but input_columns has two lists + input_columns = 
[self.column_names, ["extra_feature_1"]] + with self.assertRaises(AssertionError) as context: + self.preprocessor._get_columns( + self.X, input_columns=input_columns, raise_if_missing=True + ) + self.assertIn( + str(context.exception), + "Number of column sets (2) must match the arity (1)", + ) + + def test_get_columns_with_none_X(self): + """Test _get_columns with X as None""" + with self.assertRaises(AssertionError) as context: + self.preprocessor._get_columns( + None, input_columns=[self.column_names], raise_if_missing=True + ) + self.assertIn("Input data is None", str(context.exception)) + + def test_get_columns_with_non_list_input_columns(self): + """Test _get_columns with input_columns not wrapped in a list""" + input_columns = "feature_1" + with self.assertRaises(AssertionError) as context: + self.preprocessor._get_columns( + self.X, input_columns=input_columns, raise_if_missing=True + ) + # Should return None values since input_columns is not properly wrapped + self.assertIn( + str(context.exception), + f"input_columns must be a list of column names or indices, " + f"but got {type(input_columns)}", + ) + + @require_biosets + def test_get_columns_with_unused_feature_types(self): + """Test _get_columns with unused_feature_types specified""" + X = create_bioset( + self.X, + self.y, + self.sample_metadata, + feature_type=get_feature("Abundance"), + target_type=get_feature("BinClassLabel"), + ) + unused_feature_types = get_feature("METADATA_FEATURE_TYPES") + result = self.preprocessor._get_columns( + X, + input_columns=[None], + unused_feature_types=[unused_feature_types], + raise_if_missing=True, + ) + # Unpack results + ( + feature_names_in, + feature_idx_in, + unused_idx_in, + _, + _, + _, + _, + ) = result + + self.assertEqual(feature_names_in, self.column_names) + expected_indices = [X.column_names.index(col) for col in self.column_names] + self.assertEqual(feature_idx_in, expected_indices) + expected_unused_indices = sorted( + [ + X.column_names.index(col) + for col in self.metadata_columns + [self.target_column] + ] + ) + self.assertEqual(unused_idx_in, expected_unused_indices) + + def test_get_columns_with_all_none_parameters(self): + """Test _get_columns with all parameters as None""" + result = self.preprocessor._get_columns( + self.X, + input_columns=None, + input_feature_types=None, + unused_columns=None, + unused_feature_types=None, + raise_if_missing=False, + ) + # Should return None values + self.assertEqual(result, (None, None, None, None, None, None, None)) + + def test_get_columns_with_args_none(self): + """Test _get_columns with args as None""" + result = self.preprocessor._get_columns( + self.data, + None, + input_columns=[ + self.metadata_columns, + self.column_names, + ], + raise_if_missing=True, + ) + # Unpack results + ( + feature_names_in, + feature_idx_in, + _, + extra_names_in, + extra_idx_in, + unused_extra_idx_in, + offsets, + ) = result + # Assertions for main input + self.assertEqual(feature_names_in, self.metadata_columns) + cols = list(self.data.columns) + expected_indices = [cols.index(col) for col in self.metadata_columns] + self.assertEqual(feature_idx_in, expected_indices) + + # Assertions for extra input (args) + self.assertEqual(extra_names_in, [self.column_names]) + self.assertEqual( + extra_idx_in, + [[cols.index(col) for col in self.column_names]], + ) + self.assertEqual(unused_extra_idx_in, None) + self.assertEqual(offsets, [0]) + + def test_get_columns_with_args_mismatched_rows(self): + """Test _get_columns with args having mismatched number 
of rows""" + # Extra input data with different number of rows + extra_X = { + "extra_feature_1": np.random.rand(self.data.shape[0] + 1), + } + input_columns = [self.column_names, ["extra_feature_1"]] + result = self.preprocessor._get_columns( + self.X, extra_X, input_columns=input_columns, raise_if_missing=True + ) + # Since row numbers don't match, offsets should not be incremented + (_, _, _, _, _, _, offsets) = result + self.assertEqual(offsets, []) # No offsets should be added + + def test_generate_fingerprint(self): + fingerprint = "initial_fingerprint" + generated_fp = self.preprocessor.generate_fingerprint(fingerprint, self.config) + self.assertIsNotNone(generated_fp) + self.assertIsInstance(generated_fp, str) + + def test_from_config(self): + new_processor = MockPreprocessor._from_config(self.config) + self.assertIsInstance(new_processor, MockPreprocessor) + self.assertEqual(new_processor.config, self.config) + + def test_call(self): + # for ray-tune compatibility + batch = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) + kwargs = { + "fn": self.preprocessor._transform_pandas, + "func_type": "_transform", + "selected_indices": [0, 1], + "unused_indices": [2, 3], + "keep_unused_columns": False, + "in_format_kwargs": {"target_format": "pandas"}, + "out_format_kwargs": {"target_format": "pandas"}, + } + result = self.preprocessor(batch, **kwargs) + pd.testing.assert_frame_equal(result, batch) + + def test_set_params(self): + self.preprocessor.set_params( + processor_param_int=99, + processor_param_float=99.0, + processor_param_str="new", + processor_param_bool=False, + processor_param_list=[99, 99, 99], + processor_param_dict={"a": 99, "b": 99}, + processor_param_tuple=(99, 99), + ) + self.assertEqual(self.preprocessor.config.processor_param_int, 99) + self.assertEqual(self.preprocessor.config.processor_param_float, 99.0) + self.assertEqual(self.preprocessor.config.processor_param_str, "new") + self.assertEqual(self.preprocessor.config.processor_param_bool, False) + self.assertEqual(self.preprocessor.config.processor_param_list, [99, 99, 99]) + self.assertEqual( + self.preprocessor.config.processor_param_dict, {"a": 99, "b": 99} + ) + self.assertEqual(self.preprocessor.config.processor_param_tuple, (99, 99)) + + def test_is_fitted(self): + self.assertFalse(self.preprocessor.is_fitted) + self.preprocessor.fit(self.X, self.y) + self.assertTrue(self.preprocessor.is_fitted) + + def test_has_fit(self): + self.assertTrue(self.preprocessor.has_fit) + + def test_process_fit_without_input_columns(self): + # Test _process_fit without input_columns + self.assertFalse(self.model.is_fitted) + with self.assertRaises(AssertionError) as context: + self.preprocessor._process_fit(self.X, self.y) + self.assertIn( + str(context.exception), + "The `fit` method of `MockPreprocessor` must call:\n" + "```\n" + "self.config._input_columns = self._set_input_columns_and_arity(*args)" + "\n```\n" + "Where `*args` are the columns for each input dataset.", + ) + with self.assertRaises(AssertionError) as context: + self.model._process_fit(self.X, self.y) + self.assertIn( + str(context.exception), + "The `fit` method of `MockModel` must call:\n" + "```\n" + "self.config._input_columns = self._set_input_columns_and_arity(*args)" + "\n```\n" + "Where `*args` are the columns for each input dataset.", + ) + + def test_process_transform_without_input_columns(self): + # Test _process_fit without input_columns + with self.assertRaises(AssertionError) as context: + self.preprocessor._process_transform(self.X) + 
self.assertIn( + str(context.exception), + "The `transform` method of `MockPreprocessor` must call:\n" + "```\n" + "self._input_columns = self._set_input_columns_and_arity(*args)" + "\n```\n" + "Where `*args` are the columns for each input dataset.", + ) + + with self.assertRaises(AssertionError) as context: + self.model._method_prefix = "_predict" + self.model._process_transform(self.X) + self.assertIn( + str(context.exception), + "The `predict` method of `MockModel` must call:\n" + "```\n" + "self._input_columns = self._set_input_columns_and_arity(*args)" + "\n```\n" + "Where `*args` are the columns for each input dataset.", + ) + + with self.assertRaises(AssertionError) as context: + self.model._method_prefix = "_predict_proba" + self.model._process_transform(self.X) + self.assertIn( + str(context.exception), + "The `predict_proba` method of `MockModel` must call:\n" + "```\n" + "self._input_columns = self._set_input_columns_and_arity(*args)" + "\n```\n" + "Where `*args` are the columns for each input dataset.", + ) + + def test_process_fit_with_absolute_path_cache_file_name(self): + # make cache_file_name an absolute path + cache_file_name = ( + biofit.config.BIOFIT_PROCESSORS_CACHE / "cache.json" + ).as_posix() + + self.assertFalse(self.model.is_fitted) + with self.assertRaises(ValueError) as context: + self.model.config._input_columns = self.model._set_input_columns_and_arity( + None, None + ) + self.model._process_fit(self.X, self.y, cache_file_name=cache_file_name) + delattr(self.model.config, "_input_columns") + + self.assertIn( + str(context.exception), + "`cache_file_name` is an absolute path. Please provide the " + "file name only. You can specify the directory using " + "`cache_dir`.", + ) + + def test_process_fit_with_remote_cache_file_name(self): + # make cache_file_name an absolute path + cache_file_name = "s3://bucket/cache.json" + + self.assertFalse(self.model.is_fitted) + with self.assertRaises(ValueError) as context: + self.model.config._input_columns = self.model._set_input_columns_and_arity( + None, None + ) + self.model._process_fit(self.X, self.y, cache_file_name=cache_file_name) + delattr(self.model.config, "_input_columns") + + self.assertIn( + str(context.exception), + "`cache_file_name` is a remote URL. Please provide the " + "file name only. You can specify the directory using " + "`cache_dir`.", + ) + + def test_process_tranform_with_absolute_path_cache_file_name(self): + # make cache_file_name an absolute path + cache_file_name = ( + biofit.config.BIOFIT_PROCESSORS_CACHE / "cache.json" + ).as_posix() + + with self.assertRaises(ValueError) as context: + self.model._input_columns = self.model._set_input_columns_and_arity( + None, None + ) + self.model._process_transform(self.X, cache_file_name=cache_file_name) + delattr(self.model, "_input_columns") + + self.assertIn( + str(context.exception), + "`cache_file_name` is an absolute path. Please provide the " + "file name only. You can specify the directory using " + "`cache_dir`.", + ) + + def test_process_transform_with_remote_cache_file_name(self): + # make cache_file_name an absolute path + cache_file_name = "s3://bucket/cache.json" + + with self.assertRaises(ValueError) as context: + self.model._input_columns = self.model._set_input_columns_and_arity( + None, None + ) + self.model._process_transform(self.X, cache_file_name=cache_file_name) + delattr(self.model, "_input_columns") + + self.assertIn( + str(context.exception), + "`cache_file_name` is a remote URL. Please provide the " + "file name only. 
You can specify the directory using " + "`cache_dir`.", + ) + + def test_load_processed_estimator_from_cache_invalid_file(self): + # Use a non-existent cache file path + cache_file_name = "/non/existent/cache/file" + + # Expect NonExistentCacheError to be raised + with self.assertRaises(NonExistentCacheError): + self.model.load_processed_estimator_from_cache(cache_file_name) + + def test_process_fit_without_loading_from_cache(self): + # Set up the processor with caching enabled + self.model.config.enable_caching = True + self.model.config.cache_output = True + self.model.config.load_from_cache_file = False + self.assertFalse(self.model.is_fitted) + + # Mock the load_processed_estimator_from_cache method to raise NonExistentCacheError + with patch.object( + self.model, + "load_processed_estimator_from_cache", + side_effect=NonExistentCacheError, + ) as mock_load_cache: + with patch.object( + self.model.config, + "save_to_cache", + wraps=self.model.config.save_to_cache, + ) as mock_save_cache: + # Mock the _fit method to confirm it is called + # Call _process_fit + self.model.config._input_columns = ( + self.model._set_input_columns_and_arity(None, None) + ) + self.model._process_fit( + self.X, + self.y, + cache_file_name="cache.json", + fingerprint="test", + ) + + mock_load_cache.assert_not_called() + mock_save_cache.assert_called_once() + + cache_dir = generate_cache_dir( + self.model, + self.model.config._data_fingerprint, + root_dir=biofit.config.BIOFIT_PROCESSORS_CACHE, + ) + cache_path = Path(cache_dir) / "cache.json" + self.assertTrue(cache_path.exists()) + + # Check that the model is marked as fitted + self.assertTrue(self.model.is_fitted) + + def test_process_fit_without_saving_output_to_cache(self): + self.assertFalse(self.model.is_fitted) + + # Run it once to save the cache + self.model.config._input_columns = self.model._set_input_columns_and_arity( + None, None + ) + self.model._process_fit( + self.X, + self.y, + cache_file_name="cache.json", + fingerprint="test", + ) + + # Set up the processor with only loading from cache enabled + self.model.config.enable_caching = True + self.model.config.cache_output = False + self.model.config.load_from_cache_file = True + + # Also check if fingerprint is consistent + old_data_fingerprint = self.model.config._data_fingerprint + old_processor_fingerprint = self.model.fingerprint + + with patch.object( + self.model, + "load_processed_estimator_from_cache", + side_effect=NonExistentCacheError, + ) as mock_load_cache: + with patch.object( + self.model.config, + "save_to_cache", + wraps=self.model.config.save_to_cache, + ) as mock_save_cache: + # Run it again to load from cache + self.model.config._input_columns = ( + self.model._set_input_columns_and_arity(None, None) + ) + self.model._process_fit( + self.X, + self.y, + cache_file_name="cache.json", + fingerprint="test", + ) + + # Check to see if fingerprint is consistent + self.assertEqual( + old_data_fingerprint, self.model.config._data_fingerprint + ) + self.assertEqual(old_processor_fingerprint, self.model.fingerprint) + + cache_dir = generate_cache_dir( + self.model, + self.model.config._data_fingerprint, + root_dir=biofit.config.BIOFIT_PROCESSORS_CACHE, + ) + cache_path = Path(cache_dir) / "cache.json" + + # Load should be called but save should not + mock_load_cache.assert_called_once_with(cache_path.as_posix()) + mock_save_cache.assert_not_called() + + self.assertTrue(cache_path.exists()) + + # Check that the model is marked as fitted + self.assertTrue(self.model.is_fitted) + + 
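+ # The next test exercises the full caching round trip on Dataset inputs: the
+ # first fit runs the (mocked) _fit_sklearn and writes cache files, the second
+ # fit should load the fitted state from cache without calling _fit_sklearn
+ # again, and the predictions and dataset fingerprints from both runs are
+ # expected to match, while predicting on a different input should produce a
+ # different fingerprint.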
@require_datasets + def test_transform_output_consistency_with_and_without_cache(self): + # Fit the model first without caching + X = DataHandler.to_dataset(self.X) + y = DataHandler.to_dataset(self.y) + original_fingerprint = X._fingerprint + parameters = inspect.signature(self.model._fit_sklearn).parameters + with patch.object( + self.model, "_fit_sklearn", wraps=self.model._fit_sklearn + ) as mock_fit, patch("biofit.processing.BaseProcessor._validate_fit_func_args"): + self.model._fit_sklearn.__name__ = "_fit_sklearn" + # change signature of fit method + self.model._fit_sklearn.__signature__ = inspect.Signature( + parameters=list(parameters.values()) + ) + self.model.fit(X, y) + + mock_fit.assert_called_once() + output_no_cache = self.model.predict(X) + + # Calculate the fingerprint of the output dataset without cache + fingerprint_no_cache = output_no_cache._fingerprint + + with patch.object( + self.model, "_fit_sklearn", wraps=self.model._fit_sklearn + ) as mock_fit: + # Fit the model again (should create cache files) + self.model._fit_sklearn.__name__ = "_fit_sklearn" + self.model.fit(X, y) + + # Since loaded from cache, fit should not be called + mock_fit.assert_not_called() + + # Transform the data with cache enabled + output_with_cache = self.model.predict(X) + + # Calculate the fingerprint of the output dataset with cache + fingerprint_with_cache = output_with_cache._fingerprint + + # Transform a separate data with cache enabled + other_output_with_cache = self.model.predict(y) + + other_fingerprint_with_cache = other_output_with_cache._fingerprint + + # Compare the outputs + self.assertEqual(len(output_no_cache), len(output_with_cache)) + output_no_cache = DataHandler.to_pandas(output_no_cache) + output_with_cache = DataHandler.to_pandas(output_with_cache) + pd.testing.assert_frame_equal(output_no_cache, output_with_cache) + # Compare the fingerprints + self.assertEqual(original_fingerprint, X._fingerprint) + self.assertEqual(fingerprint_no_cache, fingerprint_with_cache) + self.assertEqual(fingerprint_no_cache, fingerprint_with_cache) + self.assertNotEqual(fingerprint_no_cache, other_fingerprint_with_cache) + + def test_parse_fingerprint(self): + fingerprint = "base_fingerprint" + parsed_fp = self.preprocessor._parse_fingerprint(fingerprint) + self.assertIn(fingerprint, parsed_fp) + self.assertIn(self.config.processor_name, parsed_fp) + + def test_reset(self): + # Test if the reset method resets the processor to the state before the first + # fit was called + preprocessor = copy.deepcopy(self.preprocessor.fit(self.X, self.y)) + + # Change the parameters of the processor + preprocessor.config.processor_param_int += 10 + preprocessor.config.processor_param_float = 10.0 + + # Reset the processor + preprocessor._reset(preprocessor.config_) + + # Should be the same as the original processor + self.assertEqual( + preprocessor.config.get_params(), self.preprocessor.config.get_params() + ) + + def test_from_config_classmethod(self): + new_processor = MockPreprocessor.from_config(self.config) + self.assertIsInstance(new_processor, MockPreprocessor) + self.assertEqual(new_processor.config, self.config) + + def test_validate_fit_params_input_raise(self): + self.preprocessor.config._fit_input_feature_types = [get_feature("Abundance")] + self.preprocessor.config._fit_unused_feature_types = None + self.preprocessor.config._transform_input_feature_types = None + self.preprocessor.config._transform_unused_feature_types = None + self.preprocessor.config._input_columns = [None, None] + with 
self.assertRaises(AssertionError) as context: + self.preprocessor._validate_fit_params(2) + self.assertIn( + str(context.exception), + "`_fit_input_feature_types` is defined in " + "MockProcessorConfig but does not match the arity of " + "the fit function in MockPreprocessor (i.e. len(" + "self.config._fit_input_feature_types) != " + "len(self.config._input_columns) -> " + "1 != 2).\n" + "This can be corrected by doing, for example:\n" + "_fit_input_feature_types = field(\n" + " default_factory=lambda: [None, None], init=False, " + "repr=False\n" + ")", + ) + + def test_validate_fit_params_unused_raise(self): + self.preprocessor.config._fit_input_feature_types = None + self.preprocessor.config._fit_unused_feature_types = [get_feature("Abundance")] + self.preprocessor.config._transform_input_feature_types = None + self.preprocessor.config._transform_unused_feature_types = None + self.preprocessor.config._input_columns = [None, None] + with self.assertRaises(AssertionError) as context: + self.preprocessor._validate_fit_params(2) + self.assertIn( + str(context.exception), + "`_fit_unused_feature_types` is defined in " + "MockProcessorConfig but does not match the arity of " + "the fit function in MockPreprocessor (i.e. len(" + "self.config._fit_unused_feature_types) != " + "len(self.config._input_columns) -> " + "1 != 2).\n" + "This can be corrected by doing, for example:\n" + "_fit_unused_feature_types = field(\n" + " default_factory=lambda: [None, None], init=False, " + "repr=False\n" + ")", + ) + + def test_validate_transform_params_input_raise(self): + self.preprocessor.config._fit_input_feature_types = None + self.preprocessor.config._fit_unused_feature_types = None + self.preprocessor.config._transform_input_feature_types = [ + get_feature("Abundance") + ] + self.preprocessor.config._transform_unused_feature_types = None + self.preprocessor._input_columns = [None, None] + with self.assertRaises(AssertionError) as context: + self.preprocessor._validate_transform_params(2) + self.assertIn( + str(context.exception), + "`_transform_input_feature_types` is defined in " + "MockProcessorConfig but does not match the arity of " + "the transform function in MockPreprocessor (i.e. len(" + "self.config._transform_input_feature_types) != " + "len(self._input_columns) -> " + "1 != 2).\n" + "This can be corrected by doing, for example:\n" + "_transform_input_feature_types = field(\n" + " default_factory=lambda: [None, None], init=False, " + "repr=False\n" + ")", + ) + + def test_validate_transform_params_unused_raise(self): + self.preprocessor.config._fit_input_feature_types = None + self.preprocessor.config._fit_unused_feature_types = None + self.preprocessor.config._transform_input_feature_types = None + self.preprocessor.config._transform_unused_feature_types = [ + get_feature("Abundance") + ] + self.preprocessor._input_columns = [None, None] + with self.assertRaises(AssertionError) as context: + self.preprocessor._validate_transform_params(2) + self.assertIn( + str(context.exception), + "`_transform_unused_feature_types` is defined in " + "MockProcessorConfig but does not match the arity of " + "the transform function in MockPreprocessor (i.e. 
len(" + "self.config._transform_unused_feature_types) != " + "len(self._input_columns) -> " + "1 != 2).\n" + "This can be corrected by doing, for example:\n" + "_transform_unused_feature_types = field(\n" + " default_factory=lambda: [None, None], init=False, " + "repr=False\n" + ")", + ) + + def test_fit_missing_input_cols(self): + with self.assertRaises(AssertionError) as context: + self.preprocessor._fit(self.X, self.y, funcs=[self.preprocessor._fit_arrow]) + + self.assertIn( + "`extra_indices` was returned as `None` from " + "`MockPreprocessor`. " + "Was `MockPreprocessor._input_columns` or " + "`MockPreprocessor.config._input_columns` set correctly?", + str(context.exception), + ) + + def test_fit_transform(self): + self.preprocessor.fit_transform(self.X, self.y) + + def test_fit_transform_with_columns(self): + self.preprocessor.fit_transform( + self.data, input_columns=self.column_names, target_column=self.target_column + ) + + def test_fit_transform_without_automatic_column_selection(self): + # First, fit the processor + with self.assertRaises(AssertionError) as context: + self.preprocessor.config._fit_input_feature_types = [None, None] + self.preprocessor.fit_transform(self.data) + self.assertIn( + str(context.exception), + "`MockPreprocessor.fit` requires 2 arguments ('X', 'y'), but only 1 was " + "provided. Either provide the missing arguments or provide the input " + "columns found in 'X', if applicable.", + ) + + def test_fit_transform_with_missing_col_in_X(self): + # First, fit the processor + with self.assertRaises(AssertionError) as context: + self.preprocessor.fit_transform( + self.X + ) # does not any contain TARGET_FEATURE_TYPES + self.assertIn( + str(context.exception), + "`MockPreprocessor.fit` requires 2 arguments ('X', 'y'), but only 1 was " + "provided. Either provide the missing arguments or provide the input " + "columns found in 'X', if applicable.", + ) + + def test_fit_transform_output_format(self): + # First, fit the processor + self.preprocessor = self.preprocessor.fit(self.X, self.y) + out1 = self.preprocessor.transform(self.X, output_format="numpy") + out2 = self.preprocessor.fit_transform(self.X, self.y, output_format="numpy") + self.assertIs(type(out1), np.ndarray) + self.assertTrue(np.array_equal(out1, out2)) + + def test_process_extra_inds(self): + extra_indices = [[1, 2]] + unused_extra_indices = [[3]] + extra_inputs = [None] + orig_input = None + result = self.preprocessor._process_extra_inds( + orig_input, extra_inputs, extra_indices, unused_extra_indices + ) + self.assertEqual(result, (extra_indices, unused_extra_indices)) + + def test_prepare_fit_kwargs(self): + combined_inputs = self.data + orig_input = self.sample_metadata + extra_inputs = [self.X, self.y] + selected_indices = list(range(self.data.shape[1])) + extra_indices = [ + list(range(self.X.shape[1])), + list(range(self.y.shape[1])), + ] + map_kwargs, pooler = self.preprocessor._prepare_fit_kwargs( + funcs=[ + self.preprocessor._fit_arrow, + self.preprocessor._fit_pandas, + self.preprocessor._fit_numpy, + ], + combined_inputs=combined_inputs, + orig_input=orig_input, + selected_indices=selected_indices, + extra_inputs=extra_inputs, + extra_indices=extra_indices, + extra_untouched_inputs=None, + map_kwargs={"fn_kwargs": {}}, + num_proc=1, + ) + + # The expected extra indices should be adjusted to match the index + # within the combined_inputs (i.e. 
offset by the number of columns + # in the orig_input) + expected_extra_indices = [ + list( + range( + self.sample_metadata.shape[1], + self.sample_metadata.shape[1] + self.X.shape[1], + ) + ), + list( + range( + self.sample_metadata.shape[1] + self.X.shape[1], + self.data.shape[1], + ) + ), + ] + expected_map_kwargs = { + "fn_kwargs": { + "fn": self.preprocessor._fit_pandas, # should match data type of combined_inputs + "func_type": "_fit", + "extra_untouched_inputs": None, + "selected_indices": selected_indices, + "unused_indices": None, + "extra_indices": expected_extra_indices, # should be adjusted to match the data type of extra_inputs + "unused_extra_indices": [None, None], + "with_metadata": False, + "in_format_kwargs": {"target_format": "pandas"}, + "out_format_kwargs": {"target_format": None}, + }, + "with_indices": False, + "with_rank": False, + "desc": "Fitting test", + "batched": True, + "batch_size": None, + "new_fingerprint": "None-processor-mock", + } + for key, value in expected_map_kwargs.items(): + if key == "fn_kwargs": + for k, v in value.items(): + self.assertEqual( + map_kwargs[key][k], v, f"{k}: {map_kwargs[key][k]} != {v}" + ) + else: + self.assertEqual(map_kwargs[key], value) + self.assertIsNone(pooler) + + def test_prepare_transform_kwargs(self): + combined_inputs = self.data + orig_input = self.sample_metadata + extra_inputs = [self.X, self.y] + selected_indices = list(range(self.data.shape[1])) + extra_indices = [ + list(range(self.X.shape[1])), + list(range(self.y.shape[1])), + ] + map_kwargs = self.preprocessor._prepare_transform_kwargs( + combined_inputs, + orig_input, + self.X, + self.y, + selected_indices=selected_indices, + extra_indices=extra_indices, + unused_indices=None, + unused_extra_indices=None, + map_kwargs={"fn_kwargs": {}}, + num_proc=1, + ) + + # The expected extra indices should be adjusted to match the index + # within the combined_inputs (i.e. 
offset by the number of columns + # in the orig_input) + expected_extra_indices = [ + list( + range( + self.sample_metadata.shape[1], + self.sample_metadata.shape[1] + self.X.shape[1], + ) + ), + list( + range( + self.sample_metadata.shape[1] + self.X.shape[1], + self.data.shape[1], + ) + ), + ] + expected_map_kwargs = { + "fn_kwargs": { + "fn": self.preprocessor._transform_pandas, # should match data type of combined_inputs + "func_type": "_transform", + "with_metadata": False, + "selected_indices": selected_indices, + "unused_indices": None, + "extra_indices": expected_extra_indices, + "unused_extra_indices": [None, None], + "keep_unused_columns": None, + "in_format_kwargs": {"target_format": "pandas"}, + "out_format_kwargs": {"target_format": "arrow"}, + "feature_names": list(self.sample_metadata.columns), + }, + "with_indices": False, + "with_rank": False, + "desc": "Transforming test", + "keep_in_memory": False, + "cache_file_name": None, + "num_proc": 1, + "batched": True, + "batch_size": 1000, + "load_from_cache_file": True, + "new_fingerprint": None, + } + + for key, value in expected_map_kwargs.items(): + if key == "fn_kwargs": + for k, v in value.items(): + self.assertEqual( + map_kwargs[key][k], v, f"{k}: {map_kwargs[key][k]} != {v}" + ) + else: + self.assertEqual(map_kwargs[key], value) + + def test_process_transform_input(self): + X = "input_data" + kwargs = {"key": "value"} + result_X, result_kwargs = self.preprocessor._process_transform_input( + X, **kwargs + ) + self.assertEqual(result_X, X) + self.assertEqual(result_kwargs, kwargs) + + def test_process_transform_output(self): + result = self.preprocessor._process_transform_output(self.X, self.X) + pd.testing.assert_frame_equal(result, self.X) + + def test_get_params(self): + params = self.preprocessor.get_params() + self.assertIsInstance(params, dict) + self.assertEqual(params["processor_param_int"], 1) + + def test_load_processed_estimator_from_cache(self): + with self.assertRaises(Exception): + self.preprocessor.load_processed_estimator_from_cache( + "nonexistent_cache_file" + ) + + # TODO: Make a better test for this + # def test_get_features_out(self): + # X = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) + # features_out = self.preprocessor._get_features_out(X, selected_indices=[0]) + # self.assertIsInstance(features_out, dict) + + def test_prepare_runner(self): + fn_kwargs = {} + result_kwargs = self.preprocessor._prepare_runner(None, **fn_kwargs) + self.assertEqual(result_kwargs, fn_kwargs) + + # def test_run(self): + # def runner(X, **kwargs): + # return "processed" + # + # result = self.preprocessor.run("input_data", runner=runner) + # self.assertEqual(result, "processed") + + def test_get_feature_names_out(self): + # Testing _get_feature_names_out method + input_features = ["feature1", "feature2"] + n_features_out = 2 + useful_feature_inds = [0, 1] + out_features = self.preprocessor._get_feature_names_out( + input_features=input_features, + n_features_out=n_features_out, + useful_feature_inds=useful_feature_inds, + one_to_one_features=True, + ) + self.assertEqual(out_features, ["feature1", "feature2"]) + + # Test with prefix and suffix + self.preprocessor.config.features_out_prefix = "prefix_" + self.preprocessor.config.features_out_suffix = "_suffix" + out_features = self.preprocessor._get_feature_names_out( + input_features=input_features, + n_features_out=n_features_out, + useful_feature_inds=useful_feature_inds, + one_to_one_features=True, + ) + self.assertEqual( + out_features, ["prefix_feature1_suffix", 
"prefix_feature2_suffix"] + ) + + # def test_transform_with_unused_columns(self): + # # test transforming data while keeping unused columns + # x = self.x.copy() + # self.preprocessor.fit(x[["feature1", "feature2"]]) + # + # transformed_x = self.preprocessor.transform(x, keep_unused_columns=true) + # + # # verify that unused columns are present + # self.assertin("metadata", transformed_x.columns) + # self.assertin("target", transformed_x.columns) + # # verify that 'feature1' and 'feature2' are centered + # expected_x = x[["feature1", "feature2"]] - x[["feature1", "feature2"]].mean() + # pd.testing.assert_frame_equal( + # transformed_x[["feature1", "feature2"]], expected_x + # ) + # # verify that unused columns are unchanged + # pd.testing.assert_series_equal(transformed_x["metadata"], x["metadata"]) + # pd.testing.assert_series_equal(transformed_x["target"], x["target"]) + + def test_process_fit_batch_input_valid(self): + """ + Test _process_fit_batch_input with valid arguments. + """ + input_processed, fn_args, fn_kwargs = self.model._process_fit_batch_input( + DataHandler.to_arrow(self.data), + fn=self.model._fit_sklearn, + selected_indices=[0, 1, 2], + extra_indices=[[-1]], + unused_indices=[], + extra_untouched_inputs=None, + with_metadata=False, + in_format_kwargs={"target_format": "pandas"}, + ) + + # Assertions + self.assertIsNotNone(input_processed) + self.assertIsInstance(fn_args, tuple) + self.assertEqual(len(fn_args), 1) + format = DataHandler.get_format(input_processed) + self.assertEqual(format, "pandas") + + def test_process_fit_batch_unmatched_arity(self): + """ + Test _process_fit_batch_input with empty input. + """ + # pretty much the same as test_fit_transform_without_automatic_column_selection + # but directly testing the method + with self.assertRaises(AssertionError) as context: + self.model._process_fit_batch_input( + DataHandler.to_arrow(self.data), + fn=self.model._fit_sklearn, + selected_indices=[0, 1, 2], + extra_indices=[], + unused_indices=[], + extra_untouched_inputs=None, + with_metadata=False, + in_format_kwargs={"target_format": "pandas"}, + ) + self.assertIn( + str(context.exception), + "`MockModel.fit` requires 2 arguments ('X', 'y'), but only 1 was " + "provided. Either provide the missing arguments or provide the input " + "columns found in 'X', if applicable.", + ) + + def test_process_fit_batch_output_invalid(self): + """ + Test _process_fit_batch_output with invalid output. + """ + # This test doesn't really do anything, but it's here for completeness + output = None + + processed_output = self.model._process_fit_batch_output(output) + + # Assertions + self.assertIsNone(processed_output) + + def test_process_transform_batch_input_valid(self): + """ + Test _process_transform_batch_input with valid arguments. + """ + input_data = self.X + + input_processed, fn_args, _ = self.preprocessor._process_transform_batch_input( + input_data, + fn=self.preprocessor._transform_pandas, + selected_indices=[0, 1, 2], + unused_indices=[3], + keep_unused_columns=True, + with_metadata=False, + in_format_kwargs={}, + out_format_kwargs={}, + ) + + # Assertions + self.assertIsNotNone(input_processed) + self.assertIsInstance(fn_args, tuple) + self.assertEqual(len(fn_args), 0) + + def test_process_transform_batch_input_empty(self): + """ + Test _process_transform_batch_input with empty input. 
+ """ + with self.assertRaises(AssertionError) as context: + input_data = None + + self.preprocessor._process_transform_batch_input( + input_data, + fn=self.preprocessor._transform_pandas, + selected_indices=[0, 1, 2], + unused_indices=[3], + keep_unused_columns=True, + with_metadata=False, + in_format_kwargs={}, + out_format_kwargs={}, + ) + + self.assertIn( + str(context.exception), + "No input data was provided for processing.", + ) + + def test_process_transform_batch_input_invalid_columns(self): + """ + Test _process_transform_batch_input with invalid/missing columns. + """ + with self.assertRaises(IndexError): + input_data = self.X + + self.preprocessor._process_transform_batch_input( + input_data, + fn=self.preprocessor._transform_pandas, + selected_indices=[self.X.shape[1] + 1], + unused_indices=[3], + keep_unused_columns=True, + with_metadata=False, + in_format_kwargs={}, + out_format_kwargs={}, + ) + + def test_process_transform_batch_output_valid(self): + """ + Test _process_transform_batch_output with valid output. + """ + input = DataHandler.to_polars(self.X, [0, 1, 2, 3]) + output = DataHandler.to_polars(self.X, [0, 1, 3]) + + processed_output = self.preprocessor._process_transform_batch_output( + input, + output, + selected_indices=[0, 1, 3], + unused_indices=[2], + keep_unused_columns=True, + feature_names=["test1", "test2", "int32_3", "test4"], + out_format_kwargs={"target_format": "arrow"}, + one_to_one_features=True, + ) + + # Assertions + self.assertEqual( + processed_output.column_names, ["test1", "test2", "int32_3", "test4"] + ) + self.assertIsInstance(processed_output, pa.Table) + + # def test_process_transform_batch_output_keep_unused_columns(self): + # """ + # Test _process_transform_batch_output with keeping unused columns. + # """ + # transformed_X = self.preprocessor._transform_pandas(self.X) + # + # processed_output = self.preprocessor._process_transform_batch_output( + # self.X, + # transformed_X, + # fn_kwargs={ + # "selected_indices": [0, 1, 2], + # "unused_indices": [3], + # "keep_unused_columns": True, + # "feature_names": ["sample_id", "multi_int", "multi_str", "labels"], + # "out_format_kwargs": {"target_format": "pandas"}, + # }, + # ) + # + # # Assertions + # self.assertIn("labels", processed_output.column_names) + # self.assertEqual(len(processed_output.columns), 4) + + # def test_process_transform_batch_output_discard_unused_columns(self): + # """ + # Test _process_transform_batch_output with discarding unused columns. + # """ + # transformed_X = self.preprocessor._transform_pandas(self.X) + # + # processed_output = self.preprocessor._process_transform_batch_output( + # self.X, + # transformed_X, + # fn_kwargs={ + # "selected_indices": [0, 1, 2], + # "unused_indices": [3], + # "keep_unused_columns": False, + # "feature_names": ["sample_id", "multi_int", "multi_str"], + # "out_format_kwargs": {"target_format": "pandas"}, + # }, + # ) + # + # # Assertions + # self.assertNotIn("labels", processed_output.columns) + # self.assertEqual(len(processed_output.columns), 3) + + # def test_process_transform_batch_output_invalid_output_format(self): + # """ + # Test _process_transform_batch_output with invalid output format. 
+ # """ + # transformed_X = self.preprocessor._transform_pandas(self.X) + # + # with self.assertRaises(ValueError): + # self.preprocessor._process_transform_batch_output( + # self.X, + # transformed_X, + # fn_kwargs={ + # "selected_indices": [0, 1, 2], + # "unused_indices": [3], + # "keep_unused_columns": False, + # "feature_names": ["sample_id", "multi_int", "multi_str"], + # "out_format_kwargs": {"target_format": "invalid_format"}, + # }, + # ) + + # def test_process_fit_batch_input_invalid_fn_kwargs(self): + # """ + # Test _process_fit_batch_input with invalid fn_kwargs. + # """ + # with self.assertRaises(AttributeError): + # input_data = self.X + # y = self.y + # + # # Pass incorrect fn_kwargs (missing 'fn') + # self.model._process_fit_batch_input( + # input_data, + # y, + # fn_kwargs={ + # "selected_indices": [0, 1, 2], + # "extra_indices": [], + # "unused_indices": [], + # "extra_untouched_inputs": [], + # "with_metadata": False, + # "indicate_last_batch": False, + # "in_format_kwargs": {}, + # "out_format_kwargs": {}, + # "with_target": True, + # }, + # ) + + # def test_process_transform_batch_input_with_extra_inputs(self): + # """ + # Test _process_transform_batch_input with extra inputs. + # """ + # with patch.object( + # self.preprocessor, "_transform_pandas", return_value=self.X + # ) as mock_transform: + # input_data = self.X + # extra_input = pa.table({"extra_col": [4, 5, 6]}) + # + # input_processed, fn_args, fn_kwargs = ( + # self.preprocessor._process_transform_batch_input( + # input_data, + # extra_input, + # fn_kwargs={ + # "fn": self.preprocessor._transform_pandas, + # "selected_indices": [0, 1, 2], + # "unused_indices": [3], + # "keep_unused_columns": True, + # "with_metadata": False, + # "in_format_kwargs": {}, + # "out_format_kwargs": {}, + # }, + # ) + # ) + # + # # Assertions + # self.assertIsNotNone(input_processed) + # self.assertIsInstance(fn_args, tuple) + # self.assertEqual(len(fn_args), 1) + # self.assertIsInstance(fn_args[0], pa.Table) + # self.assertEqual(fn_kwargs["fn"], self.preprocessor._transform_pandas) + # mock_transform.assert_called_once() + + def test_process_fit_input(self): + input = "input_data" + kwargs = {"key": "value"} + result_input, result_kwargs = self.preprocessor._process_fit_input( + input, **kwargs + ) + self.assertEqual(result_input, input) + self.assertEqual(result_kwargs, kwargs) + + def test_process_fit_output(self): + input = "input_data" + out = "output_data" + result = self.preprocessor._process_fit_output(input, out) + self.assertEqual(result, out) + + # def test_repr(self): + # repr_str = repr(self.preprocessor) + # self.assertIsInstance(repr_str, str) + # self.assertIn("MockPreprocessor", repr_str) + # + # def test_repr_mimebundle(self): + # mimebundle = self.preprocessor._repr_mimebundle_() + # self.assertIsInstance(mimebundle, dict) + # self.assertIn("text/plain", mimebundle) + # + # def test_shard(self): + # # Testing shard method + # num_shards = 5 + # for index in range(num_shards): + # shard = self.preprocessor.shard( + # self.X, num_shards=num_shards, index=index, contiguous=True + # ) + # # Verify that the shards cover the entire dataset when combined + # if index == 0: + # combined_shard = shard + # else: + # combined_shard = pd.concat([combined_shard, shard], ignore_index=True) + # pd.testing.assert_frame_equal( + # combined_shard.reset_index(drop=True), self.X.reset_index(drop=True) + # ) + + # def test_pool_fit(self): + # # Testing _pool_fit method + # # Since pooling is not implemented for multiple processors, 
expect NotImplementedError + # + # # Create multiple fitted processors + # fitted_processors = [ + # self.preprocessor.fit(self.X[["feature1", "feature2"]]) for _ in range(2) + # ] + # + # with self.assertRaises(NotImplementedError): + # self.preprocessor._pool_fit(fitted_processors) + # + # # Test with a single processor + # pooled_processor = self.preprocessor._pool_fit([self.preprocessor]) + # self.assertEqual(pooled_processor, self.preprocessor) + # + # def test_map(self): + # X = pd.DataFrame({"col1": [1, 2, 3]}) + # + # def func(batch): + # return batch * 2 + # + # result = self.preprocessor.map(X, function=func, batched=True) + # expected = X * 2 + # pd.testing.assert_frame_equal(result[0], expected) + # + # def test_map_single(self): + # shard = pd.DataFrame({"col1": [1, 2, 3]}) + # + # def func(batch): + # return batch * 2 + # + # gen = BaseProcessor._map_single(shard, function=func, batched=True) + # results = list(gen) + # for rank, done, content in results: + # if done: + # processed_data = content + # expected = shard * 2 + # pd.testing.assert_frame_equal(processed_data[0], expected) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..9eac24a --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,133 @@ +import unittest + +import pandas as pd +from biocore.data_handling import DataHandler +from biocore.utils.import_util import ( + is_biosets_available, + is_datasets_available, + is_polars_available, + is_rpy2_available, +) + + +def require_polars(test_case): + """ + Decorator marking a test that requires Polars. + + These tests are skipped when Polars isn't installed. + + """ + if not is_polars_available(): + test_case = unittest.skip("test requires Polars")(test_case) + return test_case + + +def require_biosets(test_case): + """ + Decorator marking a test that requires biosets. + + These tests are skipped when biosets isn't installed. + + """ + if not is_biosets_available(): + test_case = unittest.skip("test requires biosets")(test_case) + return test_case + + +def require_datasets(test_case): + """ + Decorator marking a test that requires datasets. + + These tests are skipped when datasets isn't installed. + + """ + if not is_datasets_available(): + test_case = unittest.skip("test requires datasets")(test_case) + return test_case + + +def require_rpy2(test_case): + """ + Decorator marking a test that requires rpy2. + + These tests are skipped when rpy2 isn't installed. + + """ + if not is_rpy2_available(): + test_case = unittest.skip("test requires rpy2")(test_case) + return test_case + + +def create_bioset( + X, + y=None, + sample_metadata=None, + with_feature_metadata=True, + feature_type=None, + target_type=None, +): + """ + Create a bioset from data. + + Args: + X (np.ndarray or pd.DataFrame): The data matrix. + y (np.ndarray or pd.Series, optional): The target vector. + sample_metadata (pd.DataFrame, optional): The sample metadata. + with_feature_metadata (bool, optional): Whether to attach feature metadata to the data columns. + feature_type (type, optional): The feature class used for the data columns. + target_type (type, optional): The feature class used for the target column (e.g. a ClassLabel subclass). + + Returns: + bioset (Bioset): The bioset. 
+ + """ + from biosets.features import Metadata, Sample, Batch + from datasets import ClassLabel, Features + + data = [X] + if y is not None: + data.append(y) + if sample_metadata is not None: + data = [sample_metadata] + data + + data = pd.concat(data, axis=1) + dtypes = DataHandler.get_dtypes(X) + + features = {} + if sample_metadata is not None: + sample_metadata_dtypes = DataHandler.get_dtypes(sample_metadata) + for k, v in sample_metadata_dtypes.items(): + if "sample" in k.lower(): + features[k] = Sample(v) + elif "batch" in k.lower(): + features[k] = Batch(v) + else: + features[k] = Metadata(v) + + if with_feature_metadata: + metadata = { + "my_metadata_str": "str", + "my_metadata_int": 0, + "my_metadata_float": 0.0, + "my_metadata_bool": False, + "my_metadata_list": ["a", "b", "c"], + "my_metadata_dict": {"a": 1, "b": 2, "c": 3}, + "my_metadata_none": None, + } + features.update( + {k: feature_type(dtype=v, metadata=metadata) for k, v in dtypes.items()} + ) + else: + features.update({k: feature_type(dtype=v) for k, v in dtypes.items()}) + if y is not None and target_type is not None: + if issubclass(target_type, ClassLabel): + names = list(set(y.values.flatten().tolist())) + + if isinstance(names[0], int): + names = ["abcdefghiklmnopqrstuvwxyz"[i] for i in names if i >= 0] + + features[y.columns[0]] = target_type(num_classes=len(names), names=names) + else: + features[y.columns[0]] = target_type() + + ds = DataHandler.to_bioset(data) + ds._info.features = Features(features) + return ds diff --git a/tools/generate_auto_maps.py b/tools/generate_auto_maps.py new file mode 100644 index 0000000..27b46ee --- /dev/null +++ b/tools/generate_auto_maps.py @@ -0,0 +1,414 @@ +import ast +import os +import re +from collections import defaultdict +from pathlib import Path + +from biofit import DATASET_NAME_TO_OMIC_TYPE + +OMIC_TYPE_TO_DATASET_NAMES = defaultdict(list) +ALL_DATASET_NAMES = set() +for dataset_name, omic_type in DATASET_NAME_TO_OMIC_TYPE.items(): + OMIC_TYPE_TO_DATASET_NAMES[omic_type].append(dataset_name) + ALL_DATASET_NAMES.add(dataset_name) + + +class ClassFinder(ast.NodeVisitor): + def __init__(self, known_classes, module_prefix, target_names): + self.found_classes = {} + self.known_classes = known_classes + self.module_prefix = module_prefix + self.target_names = target_names + + def visit_ClassDef(self, node): + # Check if any base of the current class is a known ProcessorConfig descendant + for base in node.bases: + if isinstance(base, ast.Name) and base.id in self.known_classes: + target_vals = [None] * len(self.target_names) + # Process each node in the class body + for body_node in node.body: + if isinstance(body_node, (ast.AnnAssign, ast.Assign)): + target_name = self.get_target_var_name(body_node) + if target_name in self.target_names: + index = self.target_names.index(target_name) + target_vals[index] = self.extract_name_value( + body_node.value + ) + self.found_classes[node.name] = ( + self.module_prefix, + base.id, + *target_vals, + ) + + # Assume this class is also a config now + self.known_classes.append(node.name) + self.generic_visit(node) + + def get_target_var_name(self, node): + if isinstance(node, ast.AnnAssign): + return node.target.id + elif isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name): + return node.targets[0].id + return None + + def extract_name_value(self, node): + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Constant): + return node.value + elif isinstance(node, ast.Call): + # Check if the first 
positional argument or a keyword argument 'default' is set + if node.args: + return self.safe_literal_eval(node.args[0]) + for kw in node.keywords: + if kw.arg == "default": + return self.safe_literal_eval(kw.value) + return None + + def safe_literal_eval(self, node): + if isinstance(node, ast.Constant): + return node.value + return None + + +def finalize_configs(classes): + out = {} + for key, value in classes.items(): + module_prefix, base_id, processor_type, processor_name, dataset_name = value + if base_id in classes: + _base_id = base_id + while not processor_type: + try: + _, _base_id, processor_type, _, _ = classes[_base_id] + except KeyError: + break + + _base_id = base_id + while not processor_name: + try: + _, _base_id, _, processor_name, _ = classes[_base_id] + except KeyError: + break + out[key] = (module_prefix, processor_type, processor_name, dataset_name) + return out + + +def finalize_processors(classes, config_classes): + name2proc = {} + name2config = {} + name2category = {} + name2type = {} + + for key, value in classes.items(): + module_prefix, base_id, config_class_name = value + config_class = ( + config_classes.get(config_class_name, None) if config_class_name else None + ) + if config_class: + processor_type, processor_name, dataset_name = config_class[1:] + if processor_name: + name2proc[processor_name] = key + name2config[processor_name] = config_class_name + processor_category = module_prefix.split(".")[1] + if processor_type == processor_category: + processor_type = None + if processor_type: + name2type[processor_name] = processor_type + name2category[processor_name] = processor_category + + dataset2config = defaultdict(dict) + proc2config = defaultdict(list) + for key, value in config_classes.items(): + module_prefix, processor_type, processor_name, dataset_name = value + if processor_name: + proc2config[processor_name].append( + (module_prefix, key, processor_type, dataset_name) + ) + for key, values in proc2config.items(): + _dataset2config = {value[3]: value for value in values if value[3]} + generic_config = [value for value in values if not value[3]][0] + for dataset_name in ALL_DATASET_NAMES: + if dataset_name in _dataset2config: + _, config_class, _, _ = _dataset2config[dataset_name] + dataset2config[dataset_name][key] = config_class + else: + omic_type = DATASET_NAME_TO_OMIC_TYPE.get(dataset_name, None) + if omic_type in _dataset2config: + _, config_class, _, _ = _dataset2config[omic_type] + dataset2config[dataset_name][key] = config_class + else: + _, config_class, _, _ = generic_config + dataset2config[dataset_name][key] = config_class + + return name2proc, name2config, name2category, name2type, dataset2config + + +def get_all_processor_configs( + source_folder: str, module_name="biofit", known_classes="ProcessorConfig" +): + if not isinstance(known_classes, list): + known_classes = [known_classes] + _processors = {} + package_folder = Path(source_folder) / module_name.replace(".", "/") + for root, dirs, files in os.walk(package_folder.as_posix()): + relative_root = os.path.relpath(root, start=source_folder) + module_prefix = relative_root.replace(os.sep, ".") + for file in files: + if file.endswith(".py"): + # Construct the module path relative to the root + module_path = os.path.join(root, file) + with open(module_path, "r", encoding="utf-8") as file: + try: + # Parse the file content into an AST + tree = ast.parse(file.read(), filename=module_path) + # Initialize the finder and visit the AST + finder = ClassFinder( + known_classes, + module_prefix, 
+ ["processor_type", "processor_name", "dataset_name"], + ) + finder.visit(tree) + # Collect found processors + _processors.update(finder.found_classes) + except SyntaxError: + print(f"Syntax error in {module_path}, skipping.") + + return finalize_configs(_processors) + + +def get_all_processors( + source_folder: str, + module_name="biofit", + known_classes="BaseProcessor", + config_classes=None, +): + if not isinstance(known_classes, list): + known_classes = [known_classes] + _processors = {} + package_folder = Path(source_folder) / module_name.replace(".", "/") + for root, dirs, files in os.walk(package_folder.as_posix()): + relative_root = os.path.relpath(root, start=source_folder) + module_prefix = relative_root.replace(os.sep, ".") + for file in files: + if file.endswith(".py"): + # Construct the module path relative to the root + module_path = os.path.join(root, file) + with open(module_path, "r", encoding="utf-8") as file: + try: + # Parse the file content into an AST + tree = ast.parse(file.read(), filename=module_path) + # Initialize the finder and visit the AST + finder = ClassFinder( + known_classes, module_prefix, ["_config_class"] + ) + finder.visit(tree) + # Collect found processors + _processors.update(finder.found_classes) + except SyntaxError: + print(f"Syntax error in {module_path}, skipping.") + return finalize_processors(_processors, config_classes) + + +def create_config_mapping_constants( + name2proc, + name2config, + name2category, + name2type, + dataset2config, + name2plotter, + name2pltconfig, + dataset2pltconfig, +): + processing_mapping_names_str = "PROCESSOR_MAPPING_NAMES = OrderedDict(\n [\n" + for key, value in name2proc.items(): + processing_mapping_names_str += f' ("{key}", "{value}"),\n' + processing_mapping_names_str += " ]\n)" + + plotter_mapping_names_str = "PLOTTER_MAPPING_NAMES = OrderedDict(\n [\n" + for key, value in name2plotter.items(): + plotter_mapping_names_str += f' ("{key}", "{value}"),\n' + plotter_mapping_names_str += " ]\n)" + + config_mapping_names_str = "CONFIG_MAPPING_NAMES = OrderedDict(\n [\n" + for key, value in name2config.items(): + config_mapping_names_str += f' ("{key}", "{value}"),\n' + config_mapping_names_str += " ]\n)" + + pltconfig_mapping_names_str = "PLOTTER_CONFIG_MAPPING_NAMES = OrderedDict(\n [\n" + for key, value in name2pltconfig.items(): + pltconfig_mapping_names_str += f' ("{key}", "{value}"),\n' + pltconfig_mapping_names_str += " ]\n)" + + processor_category_mapping_names_str = ( + "PROCESSOR_CATEGORY_MAPPING_NAMES = OrderedDict(\n [\n" + ) + for key, value in name2category.items(): + processor_category_mapping_names_str += f' ("{key}", "{value}"),\n' + processor_category_mapping_names_str += " ]\n)" + + processor_type_mapping_names_str = ( + "PROCESSOR_TYPE_MAPPING_NAMES = OrderedDict(\n [\n" + ) + for key, value in name2type.items(): + processor_type_mapping_names_str += f' ("{key}", "{value}"),\n' + processor_type_mapping_names_str += " ]\n)" + + head_str = processing_mapping_names_str + head_str += "\n\n" + head_str += plotter_mapping_names_str + head_str += "\n\n" + head_str += config_mapping_names_str + head_str += "\n\n" + head_str += pltconfig_mapping_names_str + head_str += "\n\n" + head_str += processor_category_mapping_names_str + head_str += "\n\n" + head_str += processor_type_mapping_names_str + head_str += "\n\n" + + regex = re.compile(r"\W") + + for key, value in dataset2config.items(): + mapping_name = f"{regex.sub('_', key).upper()}_MAPPING_NAMES" + head_str += f"{mapping_name} = OrderedDict(\n [\n" 
+ for k, v in value.items(): + head_str += f' ("{k}", "{v}"),\n' + head_str += " ]\n)\n\n" + + for key, value in dataset2pltconfig.items(): + mapping_name = f"{regex.sub('_', key).upper()}_PLOTTER_MAPPING_NAMES" + head_str += f"{mapping_name} = OrderedDict(\n [\n" + for k, v in value.items(): + head_str += f' ("{k}", "{v}"),\n' + head_str += " ]\n)" + head_str += "\n\n" + + mapping_str = "CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)\n" + mapping_str += ( + "PLOTTER_CONFIG_MAPPING = _LazyConfigMapping(PLOTTER_CONFIG_MAPPING_NAMES)\n" + ) + + for key, value in dataset2config.items(): + mapping_name = f"{regex.sub('_', key).upper()}_MAPPING_NAMES" + mapping_val = f"{regex.sub('_', key).upper()}_MAPPING" + mapping_str += f"{mapping_val} = _LazyConfigMapping({mapping_name})\n" + + for key, value in dataset2pltconfig.items(): + mapping_name = f"{regex.sub('_', key).upper()}_PLOTTER_MAPPING_NAMES" + mapping_val = f"{regex.sub('_', key).upper()}_PLOTTER_MAPPING" + mapping_str += f"{mapping_val} = _LazyConfigMapping({mapping_name})\n" + + ds2mapper = "DATASET_TO_MAPPER = {" + for key, value in dataset2config.items(): + mapping_val = f"{regex.sub('_', key).upper()}_MAPPING" + ds2mapper += f'"{key}": {mapping_val},' + ds2mapper += "}" + + ds2mapper_names = "DATASET_TO_MAPPER_NAMES = {" + for key, value in dataset2config.items(): + mapping_name = f"{regex.sub('_', key).upper()}_MAPPING_NAMES" + ds2mapper_names += f'"{key}": {mapping_name},' + ds2mapper_names += "}" + + dsplt2mapper = "DATASET_PLT_TO_MAPPER = {" + for key, value in dataset2pltconfig.items(): + mapping_val = f"{regex.sub('_', key).upper()}_PLOTTER_MAPPING" + dsplt2mapper += f'"{key}": {mapping_val},' + dsplt2mapper += "}" + + dsplt2mapper_names = "DATASET_PLT_TO_MAPPER_NAMES = {" + for key, value in dataset2pltconfig.items(): + mapping_name = f"{regex.sub('_', key).upper()}_PLOTTER_MAPPING_NAMES" + dsplt2mapper_names += f'"{key}": {mapping_name},' + dsplt2mapper_names += "}" + + mapping_str += "\n\n" + mapping_str += ds2mapper + mapping_str += "\n\n" + mapping_str += ds2mapper_names + mapping_str += "\n\n" + mapping_str += dsplt2mapper + mapping_str += "\n\n" + mapping_str += dsplt2mapper_names + mapping_str += "\n\n" + + return head_str, mapping_str + + +def replace_mapping(head_str, mapping_str, file_path): + with open(file_path, "r") as file: + source_code = file.read() + + tree = ast.parse(source_code) + + head_start_point = None + head_end_point = None + mapping_start_point = None + mapping_end_point = None + + _head_insert_point = 0 + # Find where to end head_str + for node in ast.walk(tree): + if not head_start_point and ( + isinstance(node, ast.ImportFrom) or isinstance(node, ast.Import) + ): + _head_insert_point = node.lineno + 1 + + if ( + not head_start_point + and not isinstance(node, ast.ImportFrom) + and not isinstance(node, ast.Import) + ): + head_start_point = _head_insert_point + + if ( + isinstance(node, ast.FunctionDef) + and node.name == "config_class_to_model_type" + ): + head_end_point = node.lineno - 1 + + if isinstance(node, ast.ClassDef) and node.name == "_LazyConfigMapping": + mapping_start_point = node.end_lineno + 1 + + if isinstance(node, ast.Assign): + if ( + hasattr(node.targets[0], "id") + and node.targets[0].id == "DATASET_PREPROCESSOR_MAPPING_NAMES" + ): + mapping_end_point = node.lineno - 1 + + splitted_source_code = re.split(r"[\r\n]", source_code) + new_source_code = ( + splitted_source_code[:head_start_point] + + [f"\n{head_str}"] + + 
splitted_source_code[head_end_point:mapping_start_point] + + [f"\n{mapping_str}"] + + splitted_source_code[mapping_end_point:] + ) + + new_source_code = "\n".join(new_source_code) + + with open(file_path, "w") as file: + file.write(new_source_code) + + file_path = Path(file_path).resolve().as_posix() + # use ruff to format the code + os.system(f"ruff format {file_path}") + + +if __name__ == "__main__": + config_classes = get_all_processor_configs("../src") + classes = get_all_processors("../src", config_classes=config_classes) + plotter_config_classes = get_all_processor_configs( + "../src", known_classes="PlotterConfig" + ) + plotter_classes = get_all_processors( + "../src", config_classes=plotter_config_classes, known_classes="BasePlotter" + ) + name2proc, name2config, _, _, dataset2config = plotter_classes + head_str, mapping_str = create_config_mapping_constants( + *classes, name2proc, name2config, dataset2config + ) + + file_path = "../src/biofit/auto/configuration_auto.py" + replace_mapping(head_str, mapping_str, file_path)
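
# Usage sketch for tools/generate_auto_maps.py (an assumption based on the
# relative "../src" paths hard-coded in the __main__ block above, not a
# documented workflow): run the script from the repository's tools/ directory,
# e.g.
#
#     cd tools && python generate_auto_maps.py
#
# It regenerates the mapping constants in src/biofit/auto/configuration_auto.py
# in place and then formats the result with `ruff format`.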