diff --git a/.github/conda/build.sh b/.github/conda/build.sh
new file mode 100644
index 0000000..a660906
--- /dev/null
+++ b/.github/conda/build.sh
@@ -0,0 +1 @@
+$PYTHON setup.py install --single-version-externally-managed --record=record.txt
diff --git a/.github/conda/meta.yaml b/.github/conda/meta.yaml
new file mode 100644
index 0000000..5a9d230
--- /dev/null
+++ b/.github/conda/meta.yaml
@@ -0,0 +1,46 @@
+{% set name = "biofit" %}
+
+package:
+ name: "{{ name|lower }}"
+ version: "{{ BIOFIT_VERSION }}"
+
+source:
+ path: ../../
+
+build:
+ noarch: python
+requirements:
+ host:
+ - python
+ - pip
+ run:
+ - python
+ - pip
+ - biocore
+ - filelock
+ - numpy >=1.17
+ - pyarrow >=8.0.0
+ - pyarrow-hotfix
+ - dill >=0.3.0,<0.3.8
+ - pandas
+ - requests >=2.19.0
+ - tqdm >=4.62.1
+ - xxhash
+ - multiprocess
+ - fsspec[http] >=2023.1.0,<=2023.10.0
+ - packaging
+ - pyyaml >=5.1
+ - scikit-learn >=1.4.0
+test:
+ imports:
+ - biofit
+about:
+ home: https://github.com/psmyth94/biofit
+ license: Apache-2.0
+ license_file: LICENSE
+ summary: A Python package for machine learning on omics datasets.
+ keywords:
+ - omics
+ - machine learning
+ - bioinformatics
+ - datasets
diff --git a/.github/workflows/ci_cd_pipeline.yml b/.github/workflows/ci_cd_pipeline.yml
new file mode 100644
index 0000000..570ddd8
--- /dev/null
+++ b/.github/workflows/ci_cd_pipeline.yml
@@ -0,0 +1,60 @@
+name: CI/CD Pipeline
+on:
+ push:
+ branches:
+ - main
+ - "ci-*"
+ pull_request:
+ branches:
+ - main
+jobs:
+ check_code_quality:
+ name: Check Code Quality
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ - name: Set up Python 3.8
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.8"
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install .[quality]
+ - name: Run Ruff code quality checks
+ run: |
+ ruff check tests src setup.py --output-format=github
+ ruff format --check tests src setup.py
+ tests:
+ needs: check_code_quality
+ strategy:
+ matrix:
+ test: [unit]
+ os: [ubuntu-latest, windows-latest]
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
+ runs-on: ${{ matrix.os }}
+ continue-on-error: ${{ matrix.test == 'integration' }}
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install uv --upgrade
+ uv pip install --system "biofit[test] @ ."
+ - name: Run Tests
+ run: |
+ mkdir -p reports/${{ matrix.test }}_tests
+ python3 -m pytest -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/ --junitxml=reports/${{ matrix.test }}_tests/results.xml
+ - name: Upload Unit Test Results
+ uses: actions/upload-artifact@v3
+ with:
+ name: ${{ matrix.test }}-test-reports
+ path: reports/${{ matrix.test }}_tests/results.xml
diff --git a/.github/workflows/conda-release.yml b/.github/workflows/conda-release.yml
new file mode 100644
index 0000000..818bceb
--- /dev/null
+++ b/.github/workflows/conda-release.yml
@@ -0,0 +1,36 @@
+name: Conda - Build
+on:
+ push:
+ tags:
+ - "[0-9]+.[0-9]+.[0-9]+*"
+env:
+ ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_API_TOKEN }}
+jobs:
+ build_and_package:
+ runs-on: ubuntu-22.04
+ defaults:
+ run:
+ shell: bash -l {0}
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ - name: Install miniconda
+ uses: conda-incubator/setup-miniconda@v2
+ with:
+ auto-update-conda: true
+ auto-activate-base: false
+ activate-environment: "build-biofit"
+ python-version: 3.8
+ channels: patrico49
+ - name: Setup conda env
+ run: |
+ conda install -c defaults anaconda-client conda-build
+ - name: Extract version
+ run: echo "BIOFIT_VERSION=`python setup.py --version`" >> $GITHUB_ENV
+ - name: Build conda packages
+ run: |
+ conda info
+ conda build .github/conda
+ - name: Upload to Anaconda
+ run: |
+ anaconda upload `conda build .github/conda --output -c conda-forge` --force
diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml
new file mode 100644
index 0000000..7155722
--- /dev/null
+++ b/.github/workflows/pypi-release.yml
@@ -0,0 +1,35 @@
+name: PyPI - Release
+on:
+ push:
+ tags:
+ - "[0-9]+.[0-9]+.[0-9]+*"
+env:
+ PYPI_API_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
+jobs:
+ build_and_publish:
+ runs-on: ubuntu-22.04
+ defaults:
+ run:
+ shell: bash -l {0}
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: 3.8
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install setuptools wheel twine
+ - name: Extract version
+ run: echo "BIOFIT_VERSION=$(python setup.py --version)" >> $GITHUB_ENV
+ - name: Build package
+ run: |
+ python setup.py sdist bdist_wheel
+ - name: Publish to PyPI
+ run: |
+ python -m twine upload dist/* --non-interactive
+ env:
+ TWINE_USERNAME: __token__
+ TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..0c55a78
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,132 @@
+results/
+*.lock
+_dev
+# Ignore Nextflow-related files and directories
+.nextflow*
+work
+_dev/in_progress*
+.dev
+
+# Ignore data directory
+tests/data
+
+# anything not ipynb related
+tutorials/*
+!tutorials/*.ipynb
+
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+.conda
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+cache*.arrow
+cache*.yaml
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+*-checkpoint.ipynb
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to https://pipenv.pypa.io/en/latest/advanced/#pipfile-vs-pipfilelock
+# PIPENV_VENV_IN_PROJECT=1
+Pipfile.lock
+
+# Poetry
+# https://python-poetry.org/docs/basic-usage/#configuring-poetry
+.poetry/
+
+# virtualenv
+venv/
+env/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# JetBrains IDE (IDEA/PyCharm) project settings
+.idea/
+
+# Visual Studio Code settings
+.vscode/
+.vs/
+
+# MyPy
+.mypy_cache/
+
+# pytest
+.pytest_cache/
+
+# Docker
+*Dockerfile
+*dockerignore
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..c3751a6
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,9 @@
+repos:
+ - repo: https://github.com/charliermarsh/ruff-pre-commit # https://github.com/charliermarsh/ruff#usage
+ rev: 'v0.3.0'
+ hooks:
+ # Run the linter.
+ - id: ruff
+ args: [ --fix ]
+ # Run the formatter.
+ - id: ruff-format
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..3032302
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+recursive-include src/biofit *.R
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..695cf22
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,20 @@
+.PHONY: quality style test
+
+check_dirs := tests src
+
+# Check that source code meets quality standards
+
+quality:
+ ruff check $(check_dirs) setup.py # linter
+ ruff format --check $(check_dirs) setup.py # formatter
+
+# Format source code automatically
+
+style:
+ ruff check --fix $(check_dirs) setup.py # linter
+ ruff format $(check_dirs) setup.py # formatter
+
+# Run tests for the library
+
+test:
+ python -m pytest -n auto --dist=loadfile -s -v ./tests/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..3873972
--- /dev/null
+++ b/README.md
@@ -0,0 +1,149 @@
+
+ $${\Huge{\textbf{\textsf{\color{#2E8B57}Bio\color{#4682B4}fit}}}}$$
+
+**Biofit** is a machine learning library designed for bioinformatics datasets. It
+provides tools for data transformation, feature extraction, and the training and
+evaluation of machine learning models on biomedical data. It also provides automatic
+data preprocessing, visualization, and configurable processing pipelines. Here are
+some of the main features of Biofit:
+
+- **Automatic Data Preprocessing:** Automatically preprocess biomedical datasets using
+ built-in preprocessing steps.
+- **Automatic Visualization:** Automatically visualize data using built-in visualization
+ methods geared towards biomedical data.
+- **Configurable Processing Pipelines:** Define and customize data processing pipelines.
+- **Data Handling Flexibility:** Support for a wide range of data formats, including:
+ - [Pandas](https://github.com/pandas-dev/pandas)
+ - [Polars](https://github.com/pola-rs/polars)
+ - [NumPy](https://github.com/numpy/numpy)
+ - [CSR (SciPy)](https://github.com/scipy/scipy)
+ - [Arrow](https://github.com/apache/arrow)
+ - 🤗 [Datasets](https://github.com/huggingface/datasets)
+ - [Biosets](https://github.com/psmyth94/biosets)
+- **Machine Learning Models:** Supports a wide range of machine learning models, including:
+ - [Scikit-learn](https://github.com/scikit-learn/scikit-learn)
+ - [Random Forest](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html)
+ - [Support Vector Machine](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html)
+ - [Logistic Regression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html)
+ - [Lasso Regression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Lasso.html)
+ - [LightGBM](https://github.com/microsoft/LightGBM)
+ - More to come!
+- **Caching and Reuse:** Caches intermediate results using Apache Arrow for efficient reuse.
+- **Batch Processing and Multiprocessing:** Utilize batch processing and multiprocessing for efficient handling of large-scale data.
+
+## Installation
+
+You can install Biofit via pip:
+
+```bash
+pip install biofit
+```
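+
+A conda package is also built and uploaded to the `patrico49` Anaconda channel by
+this repository's release workflow, so, assuming a release has been published there,
+the following should work as well:
+
+```bash
+conda install -c patrico49 biofit
+```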
+
+## Quick Start
+
+### Preprocessing Data
+
+Biofit provides preprocessing capabilities tailored for omics data. You can use
+built-in classes to load preprocessing steps based on the experiment type or create
+custom preprocessing pipelines. The preprocessing pipeline in Biofit uses a syntax
+similar to sklearn and supports distributed processing.
+
+#### Using a Preprocessor
+
+Biofit allows you to fit and transform your data in a few lines, similar to sklearn.
+For example, you can use the LogTransformer to apply a log transformation to your data:
+
+```python
+from biofit.preprocessing import LogTransformer
+import pandas as pd
+
+dataset = pd.DataFrame({"feature1": [1, 2, 3, 4, 5]})
+log_transformer = LogTransformer()
+preprocessed_data = log_transformer.fit_transform(dataset)
+# Applying log transformation: 100%|█████████████████████████████| 5/5 [00:00<00:00, 7656.63 examples/s]
+print(preprocessed_data)
+# feature1
+# 0 0.000000
+# 1 0.693147
+# 2 1.098612
+# 3 1.386294
+# 4 1.609438
+```
+
+#### Auto Preprocessing
+
+You can automatically apply standard preprocessing steps by specifying the experiment
+type. This allows you to load tailored preprocessing steps for the type of data you are
+working with, such as "otu", "asv", "snp", or "maldi":
+
+```python
+from biofit.preprocessing import AutoPreprocessor
+
+preprocessor = AutoPreprocessor.for_experiment("snp", [{"min_prevalence": 0.1}, None])
+print(preprocessor)
+# [('min_prevalence_row', MinPrevalenceFilter(min_prevalence=0.1)),
+# ('min_prevalence', MinPrevalenceFeatureSelector(min_prevalence=0.01))]
+
+# Fit and transform the dataset using the preprocessor
+preprocessed_data = preprocessor.fit_transform(dataset)
+```
+
+Biofit is made with Biosets in mind. You can pass the loaded dataset instead of a string
+to load the preprocessors:
+
+```python
+from biosets import load_dataset
+
+dataset = load_dataset("csv", data_files="my_file.csv", experiment_type="snp")
+
+preprocessor = AutoPreprocessor.for_experiment(dataset)
+print(preprocessor)
+# [('min_prevalence_row', MinPrevalenceFilter(min_prevalence=0.01)),
+# ('min_prevalence', MinPrevalenceFeatureSelector(min_prevalence=0.01))]
+preprocessed_data = preprocessor.fit_transform(dataset)
+```
+
+#### Custom Preprocessing Pipeline
+
+Biofit allows you to create custom preprocessing pipelines using the
+`PreprocessorPipeline` class. This allows chaining multiple preprocessing steps from
+`sklearn` and Biofit in a single operation:
+
+```python
+from biosets import load_dataset
+from biofit.preprocessing import LogTransformer, PreprocessorPipeline
+from sklearn.preprocessing import StandardScaler
+
+# Load the dataset
+dataset = load_dataset("csv", data_files="my_file.csv")
+
+# Define a custom preprocessing pipeline
+pipeline = PreprocessorPipeline(
+ [("scaler", StandardScaler()), ("log_transformer", LogTransformer())]
+)
+
+# Fit and transform the dataset using the pipeline
+preprocessed_data = pipeline.fit_transform(dataset.to_pandas())
+```
+
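+#### Controlling Progress Bars and Logging
+
+Fit and transform steps report progress as they run (see the `LogTransformer`
+example above). The progress bars and the library's log verbosity can be toggled
+globally through package-level utilities exported by `biofit`. A minimal sketch:
+
+```python
+import biofit
+
+# Silence progress bars, e.g. in scripts or CI
+biofit.disable_progress_bar()
+
+# ... run fit/transform steps quietly ...
+
+# Re-enable progress bars and raise the logging verbosity
+biofit.enable_progress_bar()
+biofit.set_verbosity_info()
+```
+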
+For further details, check the [advanced usage documentation](./docs/PREPROCESSING.md).
+
+# License
+
+Biofit is licensed under the Apache 2.0 License. See the [LICENSE](./LICENSE) file for
+more information.
+
+# Contributing
+
+If you would like to contribute to Biofit, please read the
+[CONTRIBUTING](./CONTRIBUTING.md) guidelines.
diff --git a/docs/.gitkeep b/docs/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md
new file mode 100644
index 0000000..992fbfa
--- /dev/null
+++ b/docs/CONTRIBUTING.md
@@ -0,0 +1,108 @@
+# Contributing to Biofit
+
+Biofit is an open source project, so all contributions and suggestions are welcome.
+
+You can contribute in many different ways: giving ideas, answering questions, reporting
+bugs, proposing enhancements, improving the documentation, fixing bugs, and more.
+
+Many thanks in advance to every contributor.
+
+## How to Work on an Open Issue?
+
+The list of open issues can be found at:
+[Biofit Issues](https://github.com/psmyth94/biofit/issues)
+
+## How to Create a Pull Request?
+
+If you want to contribute to the codebase, follow these steps:
+
+1. **Clone the Repository:**
+
+ Clone the `dev` branch of the repository to your local disk:
+
+ ```bash
+   git clone -b dev git@github.com:psmyth94/biofit.git
+ cd biofit
+ ```
+
+2. **Create a New Branch:**
+
+ Create a new branch to hold your development changes:
+
+ ```bash
+ git checkout -b a-descriptive-name-for-my-changes
+ ```
+
+ Do not work on the `main` branch directly.
+
+3. **Set Up a Development Environment:**
+
+ Set up a development environment by running the following command:
+
+ ```bash
+   mamba create -n biofit-local python=3.10
+ mamba activate biofit-local
+ pip install -e ".[test]"
+ ```
+
+ (If Biofit was already installed in the virtual environment, remove it with
+ `pip uninstall biofit` before reinstalling it in editable mode with the `-e` flag.)
+
+4. **Develop the Features on Your Branch:**
+
+ Make your changes to the code.
+
+5. **Format Your Code:**
+
+ Format your code. Run `ruff` so that your newly added files look nice with the
+ following command:
+
+ ```bash
+ ruff check . --fix
+ ```
+
+6. **(Optional) Use Pre-commit Hooks:**
+
+ You can also use pre-commit to format your code automatically each time you run
+ `git commit`, instead of running `ruff` manually. To do this, install pre-commit via
+ `pip install pre-commit` and then run `pre-commit install` in the project's root
+ directory to set up the hooks. Note that if any files were formatted by pre-commit
+ hooks during committing, you have to run `git commit` again.
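+
+   The same setup as a pair of commands:
+
+   ```bash
+   pip install pre-commit
+   pre-commit install
+   ```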
+
+7. **Commit Your Changes:**
+
+ Once you're happy with your contribution, add your changed files and make a commit
+ to record your changes locally:
+
+ ```bash
+ git add -u
+ git commit -m "Your commit message"
+ ```
+
+8. **Sync with the Original Repository:**
+
+ It is a good idea to sync your copy of the code with the original repository
+   regularly. This way you can quickly incorporate upstream changes:
+
+ ```bash
+   git fetch origin
+   git rebase origin/main
+ ```
+
+9. **Push the Changes:**
+
+ Once you are satisfied, push the changes to the remote repository using:
+
+ ```bash
+ git push origin a-descriptive-name-for-my-changes
+ ```
+
+10. **Create a Pull Request:**
+
+ Go to the webpage of the repository on GitHub. Click on "Pull request" to send
+ your changes to the project maintainers for review.
+
+## Code of Conduct
+
+This project adheres to the Contributor Covenant code of conduct. By participating, you
+are expected to abide by this code.
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..cde56b0
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,16 @@
+[metadata]
+license_files = LICENSE
+
+[tool:ruff]
+line-length = 88
+select = ["E", "F", "W", "C90"]
+ignore = ["E501"]
+exclude = ["*.ipynb"]
+
+[tool:pytest]
+# Test fails if a FutureWarning is thrown by `huggingface_hub`
+filterwarnings =
+ error::FutureWarning:huggingface_hub*
+markers =
+ unit: unit test
+ integration: integration test
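+# Example: a test opts into a marker with `@pytest.mark.unit`; CI then selects
+# marked tests with `pytest -m unit` (see .github/workflows/ci_cd_pipeline.yml).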
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..3dbe6be
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,129 @@
+from setuptools import find_packages, setup
+
+REQUIRED_PKGS = [
+ # the library that biofit is built upon
+ "biocore>=1.1.0",
+ # For file locking
+ "filelock",
+ # We use numpy>=1.17 to have np.random.Generator (Dataset shuffling)
+ "numpy>=1.17",
+ # Backend and serialization.
+ # Minimum 8.0.0 to be able to use .to_reader()
+ "pyarrow>=8.0.0",
+ # As long as we allow pyarrow < 14.0.1, to fix vulnerability CVE-2023-47248
+ "pyarrow-hotfix",
+ # For smart caching dataset processing
+ "dill>=0.3.0,<0.3.8", # tmp pin until dill has official support for determinism see https://github.com/uqfoundation/dill/issues/19
+ # For performance gains with apache arrow
+ "pandas",
+ # for downloading datasets over HTTPS
+ "requests>=2.19.0",
+ # progress bars in download and scripts
+ "tqdm>=4.62.1",
+ # for fast hashing
+ "xxhash",
+ # for better multiprocessing
+ "multiprocess",
+ # to save datasets locally or on any filesystem
+ # minimum 2023.1.0 to support protocol=kwargs in fsspec's `open`, `get_fs_token_paths`, etc.: see https://github.com/fsspec/filesystem_spec/pull/1143
+ "fsspec[http]>=2023.1.0,<=2023.10.0",
+ # Utilities from PyPA to e.g., compare versions
+ "packaging",
+ # To parse YAML metadata from dataset cards
+ "pyyaml>=5.1",
+ # for processing and transforming datasets
+ "scikit-learn>=1.4.0",
+]
+
+QUALITY_REQUIRE = ["ruff>=0.1.5"]
+
+DOCS_REQUIRE = [
+ # Might need to add doc-builder and some specific deps in the future
+ "s3fs",
+]
+
+VISUALIZATION_REQUIRE = [
+ "matplotlib",
+ "seaborn",
+ "psutil", # for the start time of the biofit run
+]
+
+
+ML_REQUIRE = [
+ "polars>=0.20.5",
+ "timezones>=0.10.2",
+ "optuna",
+ "lightgbm",
+ # "xgboost",
+ # "catboost",
+ "imbalanced-learn",
+]
+
+TESTS_REQUIRE = ["pytest", "pytest-timeout", "pytest-xdist"]
+
+
+EXTRAS_REQUIRE = {
+ "polars": ["polars>=0.20.5", "timezones>=0.10.2"],
+ "rpy2": ["rpy2>=3.5.15", "rpy2-arrow>=0.0.8"],
+ "ml": ML_REQUIRE,
+ "apache-beam": ["apache-beam>=2.26.0,<2.44.0"],
+ "vcf": ["cyvcf2>=0.30.0", "sgkit>=0.0.1"],
+ "tensorflow": [
+ "tensorflow>=2.2.0,!=2.6.0,!=2.6.1; sys_platform != 'darwin' or platform_machine != 'arm64'",
+ "tensorflow-macos; sys_platform == 'darwin' and platform_machine == 'arm64'",
+ ],
+ "tensorflow_gpu": ["tensorflow-gpu>=2.2.0,!=2.6.0,!=2.6.1"],
+ "torch": ["torch"],
+ "jax": ["jax>=0.3.14", "jaxlib>=0.3.14"],
+ "s3": ["s3fs"],
+ "viz": VISUALIZATION_REQUIRE,
+ "test": QUALITY_REQUIRE
+ + TESTS_REQUIRE
+ + DOCS_REQUIRE
+ + VISUALIZATION_REQUIRE
+ + ML_REQUIRE,
+ "all": VISUALIZATION_REQUIRE + ML_REQUIRE + QUALITY_REQUIRE + DOCS_REQUIRE,
+ "quality": QUALITY_REQUIRE,
+ "docs": DOCS_REQUIRE,
+}
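+# For example, `pip install "biofit[ml,viz]"` installs biofit together with the
+# ML and visualization extras defined above.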
+
+setup(
+ name="biofit",
+ version="0.0.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ description="BioFit: Bioinformatics Machine Learning Framework",
+ long_description=open("README.md", encoding="utf-8").read(),
+ long_description_content_type="text/markdown",
+ author="Patrick Smyth",
+ author_email="psmyth1994@gmail.com",
+ url="https://github.com/psmyth94/biofit",
+ download_url="https://github.com/psmyth94/biofit/tags",
+ license="MIT",
+ package_dir={"": "src"},
+ packages=find_packages("src"),
+ include_package_data=True,
+ python_requires=">=3.8.0,<3.12.0",
+ install_requires=REQUIRED_PKGS,
+ extras_require=EXTRAS_REQUIRE,
+ entry_points={
+ "console_scripts": [
+ "biofit = biofit.cli.main:main",
+ ],
+ },
+ classifiers=[
+ "Development Status :: 5 - Production/Stable",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Education",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: MIT License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Scientific/Engineering :: Bio-Informatics",
+ ],
+ keywords="omics machine learning bioinformatics metrics",
+ zip_safe=False, # Required for mypy to find the py.typed file
+)
diff --git a/src/biofit/__init__.py b/src/biofit/__init__.py
new file mode 100644
index 0000000..078865d
--- /dev/null
+++ b/src/biofit/__init__.py
@@ -0,0 +1,15 @@
+# ruff: noqa
+from .auto import *
+from .models import *
+from .train import train
+from .utils import (
+ disable_progress_bar,
+ enable_progress_bar,
+ logging,
+ set_verbosity,
+ set_verbosity_debug,
+ set_verbosity_error,
+ set_verbosity_info,
+)
+from .utils.version import __version__
+from .visualization import *
diff --git a/src/biofit/__main__.py b/src/biofit/__main__.py
new file mode 100644
index 0000000..67a8cb8
--- /dev/null
+++ b/src/biofit/__main__.py
@@ -0,0 +1,15 @@
+import argparse
+
+from .utils.version import __version__
+
+
+def main():
+ print("use biofit.cli instead. available commands: --version")
+ # print(f"{__file__}")
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--version", action="version", version=f"{__version__}")
+ parser.parse_args()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/biofit/auto/__init__.py b/src/biofit/auto/__init__.py
new file mode 100644
index 0000000..cd528a4
--- /dev/null
+++ b/src/biofit/auto/__init__.py
@@ -0,0 +1,11 @@
+# ruff: noqa
+from .configuration_auto import (
+ AutoConfig,
+ AutoPreprocessorConfig,
+)
+from .modeling_auto import (
+ AutoModel,
+ AutoModelForClassification,
+)
+from .processing_auto import AutoProcessor, AutoPreprocessor
+from .plotting_auto import AutoPlotter, PlotterPipeline
diff --git a/src/biofit/auto/auto_factory.py b/src/biofit/auto/auto_factory.py
new file mode 100644
index 0000000..60917e8
--- /dev/null
+++ b/src/biofit/auto/auto_factory.py
@@ -0,0 +1,338 @@
+import importlib
+from collections import OrderedDict
+from typing import TYPE_CHECKING
+
+from biocore.utils.import_util import is_transformers_available
+
+from biofit.utils import logging
+
+from ..auto.configuration_auto import (
+ PROCESSOR_CATEGORY_MAPPING_NAMES,
+ PROCESSOR_TYPE_MAPPING_NAMES,
+ AutoConfig,
+)
+
+if TYPE_CHECKING:
+ from biofit.processing import BaseProcessor
+
+logger = logging.get_logger(__name__)
+
+
+def _get_model_class(config, model_mapping: "_LazyAutoMapping"):
+ supported_models = model_mapping[type(config)]
+ if not isinstance(supported_models, (list, tuple)):
+ return supported_models
+
+ name_to_model = {model.__name__: model for model in supported_models}
+ architectures = getattr(config, "architectures", [])
+ for arch in architectures:
+ if arch in name_to_model:
+ return name_to_model[arch]
+ elif f"TF{arch}" in name_to_model:
+ return name_to_model[f"TF{arch}"]
+ elif f"Flax{arch}" in name_to_model:
+ return name_to_model[f"Flax{arch}"]
+
+    # If no architecture is set in the config, or none matches the supported models,
+    # the first element of the tuple is the default.
+ return supported_models[0]
+
+
+def _get_class(config, model_mapping):
+ supported_models = model_mapping[type(config)]
+ return supported_models
+
+
+class _BaseAutoProcessorClass:
+ # Base class for auto preprocessors.
+ _processor_mapping = None
+
+ def __init__(self, *args, **kwargs):
+ raise EnvironmentError(
+ f"{self.__class__.__name__} is designed to be instantiated "
+ f"`{self.__class__.__name__}.from_config(config)` methods."
+ )
+
+ @classmethod
+    def for_processor(cls, processor_name, **kwargs):
+        config = AutoConfig.for_processor(processor_name, **kwargs)
+        return cls.from_config(config, **kwargs)
+
+ @classmethod
+ def from_config(cls, config, **kwargs):
+ if type(config) in cls._processor_mapping.keys():
+ preprocessor_class: "BaseProcessor" = _get_class(
+ config, cls._processor_mapping
+ )
+ return preprocessor_class._from_config(config, **kwargs)
+
+ raise ValueError(
+ f"Unrecognized configuration class {config.__class__} for this kind of AutoProcessor: {cls.__name__}.\n"
+ f"Processor type should be one of {', '.join(c.__name__ for c in cls._processor_mapping.keys())}."
+ )
+
+ @classmethod
+ def register(cls, _config_class, processor_class: "BaseProcessor", exist_ok=False):
+ """
+        Register a new processor for this class.
+
+        Args:
+            _config_class ([`ProcessorConfig`]):
+                The configuration corresponding to the processor to register.
+            processor_class ([`BaseProcessor`]):
+                The processor to register.
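+
+        Example (an illustrative sketch; `MyConfig` and `MyProcessor` stand in for
+        user-defined `ProcessorConfig` / `BaseProcessor` subclasses):
+
+        ```python
+        AutoPreprocessor.register(MyConfig, MyProcessor)
+        processor = AutoPreprocessor.from_config(MyConfig())
+        ```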
+ """
+        if hasattr(processor_class, "_config_class") and not issubclass(
+            _config_class, processor_class._config_class
+        ):
+            raise ValueError(
+                "The processor class you are passing has a `_config_class` attribute that is not consistent with "
+                f"the config class you passed (processor has {processor_class._config_class} and you passed "
+                f"{_config_class}). Fix one of those so they match!"
+            )
+ cls._processor_mapping.register(
+ _config_class, processor_class, exist_ok=exist_ok
+ )
+
+
+class _BaseAutoModelClass:
+ # Base class for auto models.
+ _processor_mapping = None
+
+ def __init__(self, *args, **kwargs):
+ raise EnvironmentError(
+ f"{self.__class__.__name__} is designed to be instantiated "
+ f"`{self.__class__.__name__}.from_config(config)` methods."
+ )
+
+ @classmethod
+ def for_model(cls, model_name, *model_args, **kwargs):
+ config = AutoConfig.for_processor(model_name, **kwargs)
+ return cls.from_config(config, *model_args, **kwargs)
+
+ @classmethod
+ def from_config(cls, config, **kwargs):
+ if type(config) in cls._processor_mapping.keys():
+ model_class: "BaseProcessor" = _get_model_class(
+ config, cls._processor_mapping
+ )
+ return model_class._from_config(config, **kwargs)
+
+ if is_transformers_available():
+ logger.warning(
+ "Could not find a matching model for this configuration. "
+ "Searching for a model in the Transformers library instead."
+ )
+
+ from transformers.models.auto.auto_factory import (
+ _BaseAutoModelClass as _HfBaseAutoModelClass,
+ )
+
+ return _HfBaseAutoModelClass.from_config(config, **kwargs)
+
+ raise ValueError(
+ f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
+ f"Model type should be one of {', '.join(c.__name__ for c in cls._processor_mapping.keys())}."
+ )
+
+ @classmethod
+ def from_pretrained(cls, pretrained_estimator_name_or_path, *model_args, **kwargs):
+ if is_transformers_available():
+ from transformers.models.auto.auto_factory import (
+ _BaseAutoModelClass as _HfBaseAutoModelClass,
+ )
+
+ return _HfBaseAutoModelClass.from_pretrained(
+ pretrained_estimator_name_or_path, *model_args, **kwargs
+ )
+ else:
+ raise EnvironmentError(
+ f"Using {cls.__name__}.from_pretrained requires the transformers library to be installed. "
+ "You can install it with `pip install transformers`"
+ )
+
+ @classmethod
+ def register(cls, _config_class, processor_class: "BaseProcessor", exist_ok=False):
+ """
+ Register a new model for this class.
+
+ Args:
+            _config_class ([`ProcessorConfig`]):
+                The configuration corresponding to the model to register.
+            processor_class ([`BaseProcessor`]):
+                The model to register.
+ """
+        if hasattr(processor_class, "_config_class") and not issubclass(
+            _config_class, processor_class._config_class
+        ):
+            raise ValueError(
+                "The model class you are passing has a `_config_class` attribute that is not consistent with the "
+                f"config class you passed (model has {processor_class._config_class} and you passed "
+                f"{_config_class}). Fix one of those so they match!"
+            )
+ cls._processor_mapping.register(
+ _config_class, processor_class, exist_ok=exist_ok
+ )
+
+
+def insert_head_doc(docstring, head_doc=""):
+ if len(head_doc) > 0:
+ return docstring.replace(
+ "one of the model classes of the library ",
+ f"one of the model classes of the library (with a {head_doc} head) ",
+ )
+ return docstring.replace(
+ "one of the model classes of the library ",
+ "one of the base model classes of the library ",
+ )
+
+
+def get_values(model_mapping):
+ result = []
+ for model in model_mapping.values():
+ if isinstance(model, (list, tuple)):
+ result += list(model)
+ else:
+ result.append(model)
+
+ return result
+
+
+def getattribute_from_module(module, attr):
+ if attr is None:
+ return None
+ if isinstance(attr, tuple):
+ return tuple(getattribute_from_module(module, a) for a in attr)
+ if hasattr(module, attr):
+ return getattr(module, attr)
+ # Some of the mappings have entries estimator_type -> object of another model type. In that case we try to grab the
+ # object at the top level.
+ biofit_module = importlib.import_module("biofit")
+
+ if module != biofit_module:
+ try:
+ return getattribute_from_module(biofit_module, attr)
+ except ValueError:
+ raise ValueError(
+ f"Could not find {attr} neither in {module} nor in {biofit_module}!"
+ )
+ else:
+ raise ValueError(f"Could not find {attr} in {biofit_module}!")
+
+
+class _LazyAutoMapping(OrderedDict):
+ """
+ " A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
+
+ Args:
+ - config_mapping: The map model type to config class
+ - model_mapping: The map model type to model (or tokenizer) class
+ """
+
+ def __init__(self, config_mapping, processor_mapping):
+ self._config_mapping = config_mapping
+ self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
+ self._processor_mapping = processor_mapping
+ self._processor_mapping._processor_mapping = self
+ self._extra_content = {}
+ self._modules = {}
+
+ def __len__(self):
+ common_keys = set(self._config_mapping.keys()).intersection(
+ self._processor_mapping.keys()
+ )
+ return len(common_keys) + len(self._extra_content)
+
+ def __getitem__(self, key):
+ if key in self._extra_content:
+ return self._extra_content[key]
+ processor_type = self._reverse_config_mapping[key.__name__]
+ if processor_type in self._processor_mapping:
+ processor_name = self._processor_mapping[processor_type]
+ return self._load_attr_from_module(processor_type, processor_name)
+
+        # Maybe there were several model types associated with this config.
+ estimator_types = [
+ k for k, v in self._config_mapping.items() if v == key.__name__
+ ]
+ for ptype in estimator_types:
+ if ptype in self._processor_mapping:
+ processor_name = self._processor_mapping[ptype]
+ return self._load_attr_from_module(ptype, processor_name)
+ raise KeyError(key)
+
+ def _load_attr_from_module(self, module_name, attr):
+ if module_name not in self._modules:
+ processor_category = PROCESSOR_CATEGORY_MAPPING_NAMES.get(
+ module_name, "models"
+ )
+ processor_type = PROCESSOR_TYPE_MAPPING_NAMES.get(module_name, None)
+ if processor_type is not None:
+ package_name = f"biofit.{processor_category}.{processor_type}"
+ else:
+ package_name = f"biofit.{processor_category}"
+ self._modules[module_name] = importlib.import_module(
+ f".{module_name}", package_name
+ )
+ return getattribute_from_module(self._modules[module_name], attr)
+
+ def keys(self):
+ mapping_keys = [
+ self._load_attr_from_module(key, name)
+ for key, name in self._config_mapping.items()
+ if key in self._processor_mapping.keys()
+ ]
+ return mapping_keys + list(self._extra_content.keys())
+
+ def get(self, key, default):
+ try:
+ return self.__getitem__(key)
+ except KeyError:
+ return default
+
+ def __bool__(self):
+ return bool(self.keys())
+
+ def values(self):
+ mapping_values = [
+ self._load_attr_from_module(key, name)
+ for key, name in self._processor_mapping.items()
+ if key in self._config_mapping.keys()
+ ]
+ return mapping_values + list(self._extra_content.values())
+
+ def items(self):
+ mapping_items = [
+ (
+ self._load_attr_from_module(key, self._config_mapping[key]),
+ self._load_attr_from_module(key, self._processor_mapping[key]),
+ )
+ for key in self._processor_mapping.keys()
+ if key in self._config_mapping.keys()
+ ]
+ return mapping_items + list(self._extra_content.items())
+
+ def __iter__(self):
+ return iter(self.keys())
+
+ def __contains__(self, item):
+ if item in self._extra_content:
+ return True
+ if (
+ not hasattr(item, "__name__")
+ or item.__name__ not in self._reverse_config_mapping
+ ):
+ return False
+ estimator_type = self._reverse_config_mapping[item.__name__]
+ return estimator_type in self._processor_mapping
+
+ def register(self, key, value, exist_ok=False):
+ """
+ Register a new processor in this mapping.
+ """
+ if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
+ estimator_type = self._reverse_config_mapping[key.__name__]
+ if estimator_type in self._processor_mapping.keys() and not exist_ok:
+ raise ValueError(f"'{key}' is already used by a Transformers model.")
+
+ self._extra_content[key] = value
diff --git a/src/biofit/auto/configuration_auto.py b/src/biofit/auto/configuration_auto.py
new file mode 100644
index 0000000..fb324ba
--- /dev/null
+++ b/src/biofit/auto/configuration_auto.py
@@ -0,0 +1,1269 @@
+import importlib
+import re
+import warnings
+from collections import OrderedDict
+from typing import List, Union
+
+from biocore.utils.import_util import (
+ is_biosets_available,
+ is_transformers_available,
+)
+from biocore.utils.inspect import get_kwargs
+
+from biofit.processing import ProcessorConfig
+from biofit.visualization.plotting import PlotterConfig
+
+PROCESSOR_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "ColumnSumStat"),
+ ("row_sum", "RowSumStat"),
+ ("correlation", "CorrelationStat"),
+ ("distance", "DistanceStat"),
+ ("row_missingness", "RowMissingnessStat"),
+ ("row_mean", "RowMeanStat"),
+ ("col_missingness", "ColumnMissingnessStat"),
+ ("col_mean", "ColumnMeanStat"),
+ ("lightgbm", "LightGBMModel"),
+ ("lasso", "LassoModel"),
+ ("random_forest", "RandomForestModel"),
+ ("logistic_regression", "LogisticRegressionModel"),
+ ("pcoa", "PCoAFeatureExtractor"),
+ ("pca", "PCAFeatureExtractor"),
+ ("label_binarizing", "LabelBinarizer"),
+ ("label_encoding", "LabelEncoder"),
+ ("upsampling", "UpSampler"),
+ ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelector"),
+ ("log", "LogTransformer"),
+ ("relative_abundance", "RelativeAbundanceScaler"),
+ ("tmm", "TMMScaler"),
+ ("css", "CumulativeSumScaler"),
+ ("missing_labels", "MissingLabelsSampleFilter"),
+ ("min_prevalence_sample_filter", "MinPrevalenceSampleFilter"),
+ ("row_abundance", "AbundanceSampleFilter"),
+ ]
+)
+
+PLOTTER_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "PCoAFeatureExtractorPlotter"),
+ ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorPlotter"),
+ ("relative_abundance", "RelativeAbundancePlotter"),
+ ("css", "CumulativeSumScalerPlotter"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowPlotter"),
+ ("row_abundance", "AbundanceSampleFilterPlotter"),
+ ]
+)
+
+CONFIG_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "ColumnSumStatConfig"),
+ ("row_sum", "RowSumStatConfig"),
+ ("correlation", "CorrelationStatConfig"),
+ ("distance", "DistanceStatConfig"),
+ ("row_missingness", "RowMissingnessStatConfig"),
+ ("row_mean", "RowMeanStatConfig"),
+ ("col_missingness", "ColumnMissingnessStatConfig"),
+ ("col_mean", "ColumnMeanStatConfig"),
+ ("lightgbm", "LightGBMConfig"),
+ ("lasso", "LassoConfig"),
+ ("random_forest", "RandomForestConfig"),
+ ("logistic_regression", "LogisticRegressionConfig"),
+ ("pcoa", "PCoAFeatureExtractorConfig"),
+ ("pca", "PCAFeatureExtractorConfig"),
+ ("label_binarizing", "LabelBinarizerConfig"),
+ ("label_encoding", "LabelEncoderConfig"),
+ ("upsampling", "UpSamplerConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"),
+ ("log", "LogTransformerConfig"),
+ ("relative_abundance", "RelativeAbundanceScalerConfig"),
+ ("tmm", "TMMScalerConfig"),
+ ("css", "CumulativeSumScalerConfig"),
+ ("missing_labels", "MissingLabelsSampleFilterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"),
+ ("row_abundance", "AbundanceSampleFilterConfig"),
+ ]
+)
+
+PLOTTER_CONFIG_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "PCoAFeatureExtractorPlotterConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalencePlotterConfig"),
+ ("relative_abundance", "RelativeAbundancePlotterConfig"),
+ ("css", "CumulativeSumScalerPlotterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfig"),
+ ("row_abundance", "AbundanceSampleFilterPlotterConfig"),
+ ]
+)
+
+PROCESSOR_CATEGORY_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "stat"),
+ ("row_sum", "stat"),
+ ("correlation", "stat"),
+ ("distance", "stat"),
+ ("row_missingness", "stat"),
+ ("row_mean", "stat"),
+ ("col_missingness", "stat"),
+ ("col_mean", "stat"),
+ ("lightgbm", "models"),
+ ("lasso", "models"),
+ ("random_forest", "models"),
+ ("logistic_regression", "models"),
+ ("pcoa", "preprocessing"),
+ ("pca", "preprocessing"),
+ ("label_binarizing", "preprocessing"),
+ ("label_encoding", "preprocessing"),
+ ("upsampling", "preprocessing"),
+ ("min_prevalence_feature_selector", "preprocessing"),
+ ("log", "preprocessing"),
+ ("relative_abundance", "preprocessing"),
+ ("tmm", "preprocessing"),
+ ("css", "preprocessing"),
+ ("missing_labels", "preprocessing"),
+ ("min_prevalence_sample_filter", "preprocessing"),
+ ("row_abundance", "preprocessing"),
+ ]
+)
+
+PROCESSOR_TYPE_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "feature_extraction"),
+ ("pca", "feature_extraction"),
+ ("label_binarizing", "encoding"),
+ ("label_encoding", "encoding"),
+ ("upsampling", "resampling"),
+ ("min_prevalence_feature_selector", "feature_selection"),
+ ("log", "transformation"),
+ ("relative_abundance", "scaling"),
+ ("tmm", "scaling"),
+ ("css", "scaling"),
+ ("missing_labels", "filtering"),
+ ("min_prevalence_sample_filter", "filtering"),
+ ("row_abundance", "filtering"),
+ ]
+)
+
+METABOLOMICS_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "ColumnSumStatConfig"),
+ ("row_sum", "RowSumStatConfig"),
+ ("correlation", "CorrelationStatConfig"),
+ ("distance", "DistanceStatConfig"),
+ ("row_missingness", "RowMissingnessStatConfig"),
+ ("row_mean", "RowMeanStatConfig"),
+ ("col_missingness", "ColumnMissingnessStatConfig"),
+ ("col_mean", "ColumnMeanStatConfig"),
+ ("lightgbm", "LightGBMConfig"),
+ ("lasso", "LassoConfig"),
+ ("random_forest", "RandomForestConfig"),
+ ("logistic_regression", "LogisticRegressionConfig"),
+ ("pcoa", "PCoAFeatureExtractorConfig"),
+ ("pca", "PCAFeatureExtractorConfig"),
+ ("label_binarizing", "LabelBinarizerConfig"),
+ ("label_encoding", "LabelEncoderConfig"),
+ ("upsampling", "UpSamplerConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"),
+ ("log", "LogTransformerConfig"),
+ ("relative_abundance", "RelativeAbundanceScalerConfig"),
+ ("tmm", "TMMScalerConfig"),
+ ("css", "CLRScalerConfig"),
+ ("imputation", "ImputationConfig"),
+ ("missing_labels", "MissingLabelsSampleFilterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"),
+ ("row_abundance", "AbundanceSampleFilterConfig"),
+ ]
+)
+
+KMER_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "ColumnSumStatConfig"),
+ ("row_sum", "RowSumStatConfig"),
+ ("correlation", "CorrelationStatConfig"),
+ ("distance", "DistanceStatConfig"),
+ ("row_missingness", "RowMissingnessStatConfig"),
+ ("row_mean", "RowMeanStatConfig"),
+ ("col_missingness", "ColumnMissingnessStatConfig"),
+ ("col_mean", "ColumnMeanStatConfig"),
+ ("lightgbm", "LightGBMConfig"),
+ ("lasso", "LassoConfig"),
+ ("random_forest", "RandomForestConfig"),
+ ("logistic_regression", "LogisticRegressionConfig"),
+ ("pcoa", "PCoAFeatureExtractorConfig"),
+ ("pca", "PCAFeatureExtractorConfig"),
+ ("label_binarizing", "LabelBinarizerConfig"),
+ ("label_encoding", "LabelEncoderConfig"),
+ ("upsampling", "UpSamplerConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"),
+ ("log", "LogTransformerConfig"),
+ ("relative_abundance", "RelativeAbundanceScalerConfig"),
+ ("tmm", "TMMScalerConfig"),
+ ("css", "CLRScalerConfig"),
+ ("imputation", "ImputationConfig"),
+ ("missing_labels", "MissingLabelsSampleFilterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"),
+ ("row_abundance", "AbundanceSampleFilterConfig"),
+ ]
+)
+
+MS1_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "ColumnSumStatConfig"),
+ ("row_sum", "RowSumStatConfig"),
+ ("correlation", "CorrelationStatConfig"),
+ ("distance", "DistanceStatConfig"),
+ ("row_missingness", "RowMissingnessStatConfig"),
+ ("row_mean", "RowMeanStatConfig"),
+ ("col_missingness", "ColumnMissingnessStatConfig"),
+ ("col_mean", "ColumnMeanStatConfig"),
+ ("lightgbm", "LightGBMConfig"),
+ ("lasso", "LassoConfig"),
+ ("random_forest", "RandomForestConfig"),
+ ("logistic_regression", "LogisticRegressionConfig"),
+ ("pcoa", "PCoAFeatureExtractorConfig"),
+ ("pca", "PCAFeatureExtractorConfig"),
+ ("label_binarizing", "LabelBinarizerConfig"),
+ ("label_encoding", "LabelEncoderConfig"),
+ ("upsampling", "UpSamplerConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"),
+ ("log", "LogTransformerConfig"),
+ ("relative_abundance", "RelativeAbundanceScalerConfig"),
+ ("tmm", "TMMScalerConfig"),
+ ("css", "CLRScalerConfig"),
+ ("imputation", "ImputationConfig"),
+ ("missing_labels", "MissingLabelsSampleFilterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"),
+ ("row_abundance", "AbundanceSampleFilterConfig"),
+ ]
+)
+
+MALDI_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "ColumnSumStatConfig"),
+ ("row_sum", "RowSumStatConfig"),
+ ("correlation", "CorrelationStatConfig"),
+ ("distance", "DistanceStatConfig"),
+ ("row_missingness", "RowMissingnessStatConfig"),
+ ("row_mean", "RowMeanStatConfig"),
+ ("col_missingness", "ColumnMissingnessStatConfig"),
+ ("col_mean", "ColumnMeanStatConfig"),
+ ("lightgbm", "LightGBMConfig"),
+ ("lasso", "LassoConfig"),
+ ("random_forest", "RandomForestConfig"),
+ ("logistic_regression", "LogisticRegressionConfig"),
+ ("pcoa", "PCoAFeatureExtractorConfig"),
+ ("pca", "PCAFeatureExtractorConfig"),
+ ("label_binarizing", "LabelBinarizerConfig"),
+ ("label_encoding", "LabelEncoderConfig"),
+ ("upsampling", "UpSamplerConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"),
+ ("log", "LogTransformerConfig"),
+ ("relative_abundance", "RelativeAbundanceScalerConfig"),
+ ("tmm", "TMMScalerConfig"),
+ ("css", "CLRScalerConfig"),
+ ("imputation", "ImputationConfig"),
+ ("missing_labels", "MissingLabelsSampleFilterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfigForMaldi"),
+ ("row_abundance", "AbundanceSampleFilterConfig"),
+ ]
+)
+
+BIODATA_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "ColumnSumStatConfig"),
+ ("row_sum", "RowSumStatConfig"),
+ ("correlation", "CorrelationStatConfig"),
+ ("distance", "DistanceStatConfig"),
+ ("row_missingness", "RowMissingnessStatConfig"),
+ ("row_mean", "RowMeanStatConfig"),
+ ("col_missingness", "ColumnMissingnessStatConfig"),
+ ("col_mean", "ColumnMeanStatConfig"),
+ ("lightgbm", "LightGBMConfig"),
+ ("lasso", "LassoConfig"),
+ ("random_forest", "RandomForestConfig"),
+ ("logistic_regression", "LogisticRegressionConfig"),
+ ("pcoa", "PCoAFeatureExtractorConfig"),
+ ("pca", "PCAFeatureExtractorConfig"),
+ ("label_binarizing", "LabelBinarizerConfig"),
+ ("label_encoding", "LabelEncoderConfig"),
+ ("upsampling", "UpSamplerConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"),
+ ("log", "LogTransformerConfig"),
+ ("relative_abundance", "RelativeAbundanceScalerConfig"),
+ ("tmm", "TMMScalerConfig"),
+ ("css", "CLRScalerConfig"),
+ ("imputation", "ImputationConfig"),
+ ("missing_labels", "MissingLabelsSampleFilterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"),
+ ("row_abundance", "AbundanceSampleFilterConfig"),
+ ]
+)
+
+RNA_SEQ_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "ColumnSumStatConfig"),
+ ("row_sum", "RowSumStatConfig"),
+ ("correlation", "CorrelationStatConfig"),
+ ("distance", "DistanceStatConfig"),
+ ("row_missingness", "RowMissingnessStatConfig"),
+ ("row_mean", "RowMeanStatConfig"),
+ ("col_missingness", "ColumnMissingnessStatConfig"),
+ ("col_mean", "ColumnMeanStatConfig"),
+ ("lightgbm", "LightGBMConfig"),
+ ("lasso", "LassoConfig"),
+ ("random_forest", "RandomForestConfig"),
+ ("logistic_regression", "LogisticRegressionConfig"),
+ ("pcoa", "PCoAFeatureExtractorConfig"),
+ ("pca", "PCAFeatureExtractorConfig"),
+ ("label_binarizing", "LabelBinarizerConfig"),
+ ("label_encoding", "LabelEncoderConfig"),
+ ("upsampling", "UpSamplerConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"),
+ ("log", "LogTransformerConfig"),
+ ("relative_abundance", "RelativeAbundanceScalerConfig"),
+ ("tmm", "TMMScalerConfig"),
+ ("css", "CLRScalerConfig"),
+ ("imputation", "ImputationConfig"),
+ ("missing_labels", "MissingLabelsSampleFilterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"),
+ ("row_abundance", "AbundanceSampleFilterConfig"),
+ ]
+)
+
+PROTEOMICS_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "ColumnSumStatConfig"),
+ ("row_sum", "RowSumStatConfig"),
+ ("correlation", "CorrelationStatConfig"),
+ ("distance", "DistanceStatConfig"),
+ ("row_missingness", "RowMissingnessStatConfig"),
+ ("row_mean", "RowMeanStatConfig"),
+ ("col_missingness", "ColumnMissingnessStatConfig"),
+ ("col_mean", "ColumnMeanStatConfig"),
+ ("lightgbm", "LightGBMConfig"),
+ ("lasso", "LassoConfig"),
+ ("random_forest", "RandomForestConfig"),
+ ("logistic_regression", "LogisticRegressionConfig"),
+ ("pcoa", "PCoAFeatureExtractorConfig"),
+ ("pca", "PCAFeatureExtractorConfig"),
+ ("label_binarizing", "LabelBinarizerConfig"),
+ ("label_encoding", "LabelEncoderConfig"),
+ ("upsampling", "UpSamplerConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"),
+ ("log", "LogTransformerConfig"),
+ ("relative_abundance", "RelativeAbundanceScalerConfig"),
+ ("tmm", "TMMScalerConfig"),
+ ("css", "CLRScalerConfig"),
+ ("imputation", "ImputationConfig"),
+ ("missing_labels", "MissingLabelsSampleFilterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"),
+ ("row_abundance", "AbundanceSampleFilterConfig"),
+ ]
+)
+
+METAGENOMICS_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "ColumnSumStatConfig"),
+ ("row_sum", "RowSumStatConfig"),
+ ("correlation", "CorrelationStatConfig"),
+ ("distance", "DistanceStatConfigForMetagenomics"),
+ ("row_missingness", "RowMissingnessStatConfig"),
+ ("row_mean", "RowMeanStatConfig"),
+ ("col_missingness", "ColumnMissingnessStatConfigForMetagenomics"),
+ ("col_mean", "ColumnMeanStatConfig"),
+ ("lightgbm", "LightGBMConfig"),
+ ("lasso", "LassoConfig"),
+ ("random_forest", "RandomForestConfig"),
+ ("logistic_regression", "LogisticRegressionConfig"),
+ ("pcoa", "PCoAFeatureExtractorConfig"),
+ ("pca", "PCAFeatureExtractorConfig"),
+ ("label_binarizing", "LabelBinarizerConfig"),
+ ("label_encoding", "LabelEncoderConfig"),
+ ("upsampling", "UpSamplerConfigForMetagenomics"),
+ (
+ "min_prevalence_feature_selector",
+ "MinPrevalenceFeatureSelectorConfigForMetagenomics",
+ ),
+ ("log", "LogTransformerConfig"),
+ ("relative_abundance", "RelativeAbundanceScalerConfig"),
+ ("tmm", "TMMScalerConfigForMetagenomics"),
+ ("css", "CumulativeSumScalerConfigForMetagenomics"),
+ ("imputation", "ImputationConfig"),
+ ("missing_labels", "MissingLabelsSampleFilterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"),
+ ("row_abundance", "AbundanceSampleFilterConfig"),
+ ]
+)
+
+MS2_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "ColumnSumStatConfig"),
+ ("row_sum", "RowSumStatConfig"),
+ ("correlation", "CorrelationStatConfig"),
+ ("distance", "DistanceStatConfig"),
+ ("row_missingness", "RowMissingnessStatConfig"),
+ ("row_mean", "RowMeanStatConfig"),
+ ("col_missingness", "ColumnMissingnessStatConfig"),
+ ("col_mean", "ColumnMeanStatConfig"),
+ ("lightgbm", "LightGBMConfig"),
+ ("lasso", "LassoConfig"),
+ ("random_forest", "RandomForestConfig"),
+ ("logistic_regression", "LogisticRegressionConfig"),
+ ("pcoa", "PCoAFeatureExtractorConfig"),
+ ("pca", "PCAFeatureExtractorConfig"),
+ ("label_binarizing", "LabelBinarizerConfig"),
+ ("label_encoding", "LabelEncoderConfig"),
+ ("upsampling", "UpSamplerConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"),
+ ("log", "LogTransformerConfig"),
+ ("relative_abundance", "RelativeAbundanceScalerConfig"),
+ ("tmm", "TMMScalerConfig"),
+ ("css", "CLRScalerConfig"),
+ ("imputation", "ImputationConfig"),
+ ("missing_labels", "MissingLabelsSampleFilterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"),
+ ("row_abundance", "AbundanceSampleFilterConfig"),
+ ]
+)
+
+OTU_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "ColumnSumStatConfigForOTU"),
+ ("row_sum", "RowSumStatConfigForOTU"),
+ ("correlation", "CorrelationStatConfig"),
+ ("distance", "DistanceStatConfigForOTU"),
+ ("row_missingness", "RowMissingnessStatConfigForOTU"),
+ ("row_mean", "RowMeanStatConfigForOTU"),
+ ("col_missingness", "ColumnMissingnessStatConfigForOTU"),
+ ("col_mean", "ColumnMeanStatConfigForOTU"),
+ ("lightgbm", "LightGBMConfig"),
+ ("lasso", "LassoConfigForOTU"),
+ ("random_forest", "RandomForestConfig"),
+ ("logistic_regression", "LogisticRegressionConfig"),
+ ("pcoa", "PCoAFeatureExtractorConfigForOTU"),
+ ("pca", "PCAFeatureExtractorConfig"),
+ ("label_binarizing", "LabelBinarizerConfig"),
+ ("label_encoding", "LabelEncoderConfig"),
+ ("upsampling", "UpSamplerConfigForOTU"),
+ ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfigForOTU"),
+ ("log", "LogTransformerConfigForOTU"),
+ ("relative_abundance", "RelativeAbundanceScalerConfig"),
+ ("tmm", "TMMScalerConfigForOTU"),
+ ("css", "CumulativeSumScalerConfigForOTU"),
+ ("imputation", "ImputationConfig"),
+ ("missing_labels", "MissingLabelsSampleFilterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfigForOTU"),
+ ("row_abundance", "AbundanceSampleFilterConfigForOTU"),
+ ]
+)
+
+GENOMICS_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "ColumnSumStatConfig"),
+ ("row_sum", "RowSumStatConfig"),
+ ("correlation", "CorrelationStatConfig"),
+ ("distance", "DistanceStatConfig"),
+ ("row_missingness", "RowMissingnessStatConfig"),
+ ("row_mean", "RowMeanStatConfig"),
+ ("col_missingness", "ColumnMissingnessStatConfig"),
+ ("col_mean", "ColumnMeanStatConfig"),
+ ("lightgbm", "LightGBMConfig"),
+ ("lasso", "LassoConfig"),
+ ("random_forest", "RandomForestConfig"),
+ ("logistic_regression", "LogisticRegressionConfig"),
+ ("pcoa", "PCoAFeatureExtractorConfig"),
+ ("pca", "PCAFeatureExtractorConfig"),
+ ("label_binarizing", "LabelBinarizerConfig"),
+ ("label_encoding", "LabelEncoderConfig"),
+ ("upsampling", "UpSamplerConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfig"),
+ ("log", "LogTransformerConfig"),
+ ("relative_abundance", "RelativeAbundanceScalerConfig"),
+ ("tmm", "TMMScalerConfig"),
+ ("css", "CLRScalerConfig"),
+ ("imputation", "ImputationConfig"),
+ ("missing_labels", "MissingLabelsSampleFilterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfig"),
+ ("row_abundance", "AbundanceSampleFilterConfig"),
+ ]
+)
+
+SNP_MAPPING_NAMES = OrderedDict(
+ [
+ ("col_sum", "ColumnSumStatConfig"),
+ ("row_sum", "RowSumStatConfig"),
+ ("correlation", "CorrelationStatConfig"),
+ ("distance", "DistanceStatConfigForSNP"),
+ ("row_missingness", "RowMissingnessStatConfigForSNP"),
+ ("row_mean", "RowMeanStatConfig"),
+ ("col_missingness", "ColumnMissingnessStatConfigForSNP"),
+ ("col_mean", "ColumnMeanStatConfig"),
+ ("lightgbm", "LightGBMConfig"),
+ ("lasso", "LassoConfig"),
+ ("random_forest", "RandomForestConfig"),
+ ("logistic_regression", "LogisticRegressionConfig"),
+ ("pcoa", "PCoAFeatureExtractorConfig"),
+ ("pca", "PCAFeatureExtractorConfig"),
+ ("label_binarizing", "LabelBinarizerConfig"),
+ ("label_encoding", "LabelEncoderConfig"),
+ ("upsampling", "UpSamplerConfigForSNP"),
+ ("min_prevalence_feature_selector", "MinPrevalenceFeatureSelectorConfigForSNP"),
+ ("log", "LogTransformerConfig"),
+ ("relative_abundance", "RelativeAbundanceScalerConfig"),
+ ("tmm", "TMMScalerConfigForSNP"),
+ ("css", "CLRScalerConfig"),
+ ("imputation", "ImputationConfig"),
+ ("missing_labels", "MissingLabelsSampleFilterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowSampleFilterConfigForSNP"),
+ ("row_abundance", "AbundanceSampleFilterConfig"),
+ ]
+)
+
+METABOLOMICS_PLOTTER_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "PCoAFeatureExtractorPlotterConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalencePlotterConfig"),
+ ("relative_abundance", "RelativeAbundancePlotterConfig"),
+ ("css", "CumulativeSumScalerPlotterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfig"),
+ ("row_abundance", "AbundanceSampleFilterPlotterConfig"),
+ ]
+)
+
+KMER_PLOTTER_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "PCoAFeatureExtractorPlotterConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForGenomics"),
+ ("relative_abundance", "RelativeAbundancePlotterConfig"),
+ ("css", "CumulativeSumScalerPlotterConfigForGenomics"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForGenomics"),
+ ("row_abundance", "AbundanceSampleFilterPlotterConfigForGenomics"),
+ ]
+)
+
+MS1_PLOTTER_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "PCoAFeatureExtractorPlotterConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForProteomics"),
+ ("relative_abundance", "RelativeAbundancePlotterConfig"),
+ ("css", "CumulativeSumScalerPlotterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForProteomics"),
+ ("row_abundance", "AbundanceSampleFilterPlotterConfigForProteomics"),
+ ]
+)
+
+MALDI_PLOTTER_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "PCoAFeatureExtractorPlotterConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForMaldi"),
+ ("relative_abundance", "RelativeAbundancePlotterConfigForMaldi"),
+ ("css", "CumulativeSumScalerPlotterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForMaldi"),
+ ("row_abundance", "AbundanceSampleFilterPlotterConfigForProteomics"),
+ ]
+)
+
+BIODATA_PLOTTER_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "PCoAFeatureExtractorPlotterConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalencePlotterConfig"),
+ ("relative_abundance", "RelativeAbundancePlotterConfig"),
+ ("css", "CumulativeSumScalerPlotterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfig"),
+ ("row_abundance", "AbundanceSampleFilterPlotterConfig"),
+ ]
+)
+
+RNA_SEQ_PLOTTER_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "PCoAFeatureExtractorPlotterConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalencePlotterConfig"),
+ ("relative_abundance", "RelativeAbundancePlotterConfig"),
+ ("css", "CumulativeSumScalerPlotterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfig"),
+ ("row_abundance", "AbundanceSampleFilterPlotterConfig"),
+ ]
+)
+
+PROTEOMICS_PLOTTER_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "PCoAFeatureExtractorPlotterConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForProteomics"),
+ ("relative_abundance", "RelativeAbundancePlotterConfig"),
+ ("css", "CumulativeSumScalerPlotterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForProteomics"),
+ ("row_abundance", "AbundanceSampleFilterPlotterConfigForProteomics"),
+ ]
+)
+
+METAGENOMICS_PLOTTER_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "PCoAFeatureExtractorPlotterConfig"),
+ (
+ "min_prevalence_feature_selector",
+ "MinPrevalencePlotterConfigForMetagenomics",
+ ),
+ ("relative_abundance", "RelativeAbundancePlotterConfigForMetagenomics"),
+ ("css", "CumulativeSumScalerPlotterConfigForMetagenomics"),
+ (
+ "min_prevalence_sample_filter",
+ "MinPrevalenceRowPlotterConfigForMetagenomics",
+ ),
+ ("row_abundance", "AbundanceSampleFilterPlotterConfigForMetagenomics"),
+ ]
+)
+
+MS2_PLOTTER_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "PCoAFeatureExtractorPlotterConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForProteomics"),
+ ("relative_abundance", "RelativeAbundancePlotterConfig"),
+ ("css", "CumulativeSumScalerPlotterConfig"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForProteomics"),
+ ("row_abundance", "AbundanceSampleFilterPlotterConfigForProteomics"),
+ ]
+)
+
+OTU_PLOTTER_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "PCoAFeatureExtractorPlotterConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForOTU"),
+ ("relative_abundance", "RelativeAbundancePlotterConfigForOTU"),
+ ("css", "CumulativeSumScalerPlotterConfigForOTU"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForOTU"),
+ ("row_abundance", "AbundanceSampleFilterPlotterConfigForOTU"),
+ ]
+)
+
+GENOMICS_PLOTTER_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "PCoAFeatureExtractorPlotterConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForGenomics"),
+ ("relative_abundance", "RelativeAbundancePlotterConfig"),
+ ("css", "CumulativeSumScalerPlotterConfigForGenomics"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForGenomics"),
+ ("row_abundance", "AbundanceSampleFilterPlotterConfigForGenomics"),
+ ]
+)
+
+SNP_PLOTTER_MAPPING_NAMES = OrderedDict(
+ [
+ ("pcoa", "PCoAFeatureExtractorPlotterConfig"),
+ ("min_prevalence_feature_selector", "MinPrevalencePlotterConfigForSNP"),
+ ("relative_abundance", "RelativeAbundancePlotterConfigForSNP"),
+ ("css", "CumulativeSumScalerPlotterConfigForSNP"),
+ ("min_prevalence_sample_filter", "MinPrevalenceRowPlotterConfigForSNP"),
+ ("row_abundance", "AbundanceSampleFilterPlotterConfigForSNP"),
+ ]
+)
+
+
+def config_class_to_model_type(config):
+ """Converts a config class name to the corresponding model type"""
+ for key, cls in CONFIG_MAPPING_NAMES.items():
+ if cls == config:
+ return key
+ for key, cls in CONFIG_MAPPING._extra_content.items():
+ if cls.__name__ == config:
+ return key
+ return None
+
+
+class _LazyConfigMapping(OrderedDict):
+ """
+ A dictionary that lazily loads its values when they are requested.
+ """
+
+ def __init__(self, mapping):
+ self._mapping = mapping
+ self._extra_content = {}
+ self._modules = {}
+
+ def __getitem__(self, key: str):
+ if key in self._extra_content:
+ return self._extra_content[key]
+ if key not in self._mapping:
+ raise KeyError(key)
+ value = self._mapping[key]
+ module_name = key.replace("-", "_")
+ processor_category = PROCESSOR_CATEGORY_MAPPING_NAMES.get(key, "models")
+ processor_type = PROCESSOR_TYPE_MAPPING_NAMES.get(key, None)
+ try:
+ if module_name not in self._modules:
+ package = (
+ f"biofit.{processor_category}"
+ if not processor_type
+ else f"biofit.{processor_category}.{processor_type}"
+ )
+ self._modules[module_name] = importlib.import_module(
+ f".{module_name}", package
+ )
+ except ImportError:
+ if is_transformers_available() and processor_category == "models":
+ from transformers.models.auto.configuration_auto import (
+ model_type_to_module_name,
+ )
+
+ module_name = model_type_to_module_name(key)
+ if module_name not in self._modules:
+ self._modules[module_name] = importlib.import_module(
+ f".{module_name}", "transformers.models"
+ )
+ else:
+ raise
+ if hasattr(self._modules[module_name], value):
+ return getattr(self._modules[module_name], value)
+ try:
+ biofit_module = importlib.import_module("biofit")
+ return getattr(biofit_module, value)
+ except AttributeError:
+ if is_transformers_available():
+ transformers_module = importlib.import_module("transformers")
+ return getattr(transformers_module, value)
+ raise
+
+ def keys(self):
+ return list(self._mapping.keys()) + list(self._extra_content.keys())
+
+ def values(self):
+ return [self[k] for k in self._mapping.keys()] + list(
+ self._extra_content.values()
+ )
+
+ def items(self):
+ return [(k, self[k]) for k in self._mapping.keys()] + list(
+ self._extra_content.items()
+ )
+
+ def __iter__(self):
+ return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))
+
+ def __contains__(self, item):
+ return item in self._mapping or item in self._extra_content
+
+ def register(self, key, value, exist_ok=False):
+ """
+ Register a new configuration in this mapping.
+ """
+ if key in self._mapping.keys() and (not exist_ok):
+ raise ValueError(
+ f"'{key}' is already used by a Transformers config, pick another name."
+ )
+ self._extra_content[key] = value
+
+
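+# Illustrative behaviour (a sketch of `_LazyConfigMapping.__getitem__` above): the
+# first access, e.g. CONFIG_MAPPING[some_key], imports the matching module under the
+# biofit package (category/type resolved via the PROCESSOR_* mappings) and returns
+# the named config class; later accesses reuse the cached module.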
+CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)
+PLOTTER_CONFIG_MAPPING = _LazyConfigMapping(PLOTTER_CONFIG_MAPPING_NAMES)
+METABOLOMICS_MAPPING = _LazyConfigMapping(METABOLOMICS_MAPPING_NAMES)
+KMER_MAPPING = _LazyConfigMapping(KMER_MAPPING_NAMES)
+MS1_MAPPING = _LazyConfigMapping(MS1_MAPPING_NAMES)
+MALDI_MAPPING = _LazyConfigMapping(MALDI_MAPPING_NAMES)
+BIODATA_MAPPING = _LazyConfigMapping(BIODATA_MAPPING_NAMES)
+RNA_SEQ_MAPPING = _LazyConfigMapping(RNA_SEQ_MAPPING_NAMES)
+PROTEOMICS_MAPPING = _LazyConfigMapping(PROTEOMICS_MAPPING_NAMES)
+METAGENOMICS_MAPPING = _LazyConfigMapping(METAGENOMICS_MAPPING_NAMES)
+MS2_MAPPING = _LazyConfigMapping(MS2_MAPPING_NAMES)
+OTU_MAPPING = _LazyConfigMapping(OTU_MAPPING_NAMES)
+GENOMICS_MAPPING = _LazyConfigMapping(GENOMICS_MAPPING_NAMES)
+SNP_MAPPING = _LazyConfigMapping(SNP_MAPPING_NAMES)
+METABOLOMICS_PLOTTER_MAPPING = _LazyConfigMapping(METABOLOMICS_PLOTTER_MAPPING_NAMES)
+KMER_PLOTTER_MAPPING = _LazyConfigMapping(KMER_PLOTTER_MAPPING_NAMES)
+MS1_PLOTTER_MAPPING = _LazyConfigMapping(MS1_PLOTTER_MAPPING_NAMES)
+MALDI_PLOTTER_MAPPING = _LazyConfigMapping(MALDI_PLOTTER_MAPPING_NAMES)
+BIODATA_PLOTTER_MAPPING = _LazyConfigMapping(BIODATA_PLOTTER_MAPPING_NAMES)
+RNA_SEQ_PLOTTER_MAPPING = _LazyConfigMapping(RNA_SEQ_PLOTTER_MAPPING_NAMES)
+PROTEOMICS_PLOTTER_MAPPING = _LazyConfigMapping(PROTEOMICS_PLOTTER_MAPPING_NAMES)
+METAGENOMICS_PLOTTER_MAPPING = _LazyConfigMapping(METAGENOMICS_PLOTTER_MAPPING_NAMES)
+MS2_PLOTTER_MAPPING = _LazyConfigMapping(MS2_PLOTTER_MAPPING_NAMES)
+OTU_PLOTTER_MAPPING = _LazyConfigMapping(OTU_PLOTTER_MAPPING_NAMES)
+GENOMICS_PLOTTER_MAPPING = _LazyConfigMapping(GENOMICS_PLOTTER_MAPPING_NAMES)
+SNP_PLOTTER_MAPPING = _LazyConfigMapping(SNP_PLOTTER_MAPPING_NAMES)
+
+
+DATASET_TO_MAPPER = {
+ "metabolomics": METABOLOMICS_MAPPING,
+ "kmer": KMER_MAPPING,
+ "ms1": MS1_MAPPING,
+ "maldi": MALDI_MAPPING,
+ "biodata": BIODATA_MAPPING,
+ "rna-seq": RNA_SEQ_MAPPING,
+ "proteomics": PROTEOMICS_MAPPING,
+ "metagenomics": METAGENOMICS_MAPPING,
+ "ms2": MS2_MAPPING,
+ "otu": OTU_MAPPING,
+ "genomics": GENOMICS_MAPPING,
+ "snp": SNP_MAPPING,
+}
+
+DATASET_TO_MAPPER_NAMES = {
+ "metabolomics": METABOLOMICS_MAPPING_NAMES,
+ "kmer": KMER_MAPPING_NAMES,
+ "ms1": MS1_MAPPING_NAMES,
+ "maldi": MALDI_MAPPING_NAMES,
+ "biodata": BIODATA_MAPPING_NAMES,
+ "rna-seq": RNA_SEQ_MAPPING_NAMES,
+ "proteomics": PROTEOMICS_MAPPING_NAMES,
+ "metagenomics": METAGENOMICS_MAPPING_NAMES,
+ "ms2": MS2_MAPPING_NAMES,
+ "otu": OTU_MAPPING_NAMES,
+ "genomics": GENOMICS_MAPPING_NAMES,
+ "snp": SNP_MAPPING_NAMES,
+}
+
+DATASET_PLT_TO_MAPPER = {
+ "metabolomics": METABOLOMICS_PLOTTER_MAPPING,
+ "kmer": KMER_PLOTTER_MAPPING,
+ "ms1": MS1_PLOTTER_MAPPING,
+ "maldi": MALDI_PLOTTER_MAPPING,
+ "biodata": BIODATA_PLOTTER_MAPPING,
+ "rna-seq": RNA_SEQ_PLOTTER_MAPPING,
+ "proteomics": PROTEOMICS_PLOTTER_MAPPING,
+ "metagenomics": METAGENOMICS_PLOTTER_MAPPING,
+ "ms2": MS2_PLOTTER_MAPPING,
+ "otu": OTU_PLOTTER_MAPPING,
+ "genomics": GENOMICS_PLOTTER_MAPPING,
+ "snp": SNP_PLOTTER_MAPPING,
+}
+
+DATASET_PLT_TO_MAPPER_NAMES = {
+ "metabolomics": METABOLOMICS_PLOTTER_MAPPING_NAMES,
+ "kmer": KMER_PLOTTER_MAPPING_NAMES,
+ "ms1": MS1_PLOTTER_MAPPING_NAMES,
+ "maldi": MALDI_PLOTTER_MAPPING_NAMES,
+ "biodata": BIODATA_PLOTTER_MAPPING_NAMES,
+ "rna-seq": RNA_SEQ_PLOTTER_MAPPING_NAMES,
+ "proteomics": PROTEOMICS_PLOTTER_MAPPING_NAMES,
+ "metagenomics": METAGENOMICS_PLOTTER_MAPPING_NAMES,
+ "ms2": MS2_PLOTTER_MAPPING_NAMES,
+ "otu": OTU_PLOTTER_MAPPING_NAMES,
+ "genomics": GENOMICS_PLOTTER_MAPPING_NAMES,
+ "snp": SNP_PLOTTER_MAPPING_NAMES,
+}
+
+
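+# Default preprocessing steps per experiment type, listed in the order they are
+# applied by `AutoPreprocessorConfig.for_dataset` and `AutoPlotterConfig.for_dataset`.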
+DATASET_PREPROCESSOR_MAPPING_NAMES = OrderedDict(
+ [
+ (
+ "otu",
+ [
+ "row_abundance",
+ "min_prevalence_feature_selector",
+ "css",
+ ],
+ ),
+ (
+ "asv",
+ [
+ "row_abundance",
+ "min_prevalence_feature_selector",
+ "css",
+ ],
+ ),
+ (
+ "abundance",
+ [
+ "row_abundance",
+ "min_prevalence_feature_selector",
+ "css",
+ ],
+ ),
+ (
+ "metagenomics",
+ [
+ "row_abundance",
+ "min_prevalence_feature_selector",
+ "css",
+ ],
+ ),
+ (
+ "snp",
+ [
+ "min_prevalence_sample_filter",
+ "min_prevalence_feature_selector",
+ ],
+ ),
+ (
+ "genomics",
+ [
+ "min_prevalence_sample_filter",
+ "min_prevalence_feature_selector",
+ ],
+ ),
+ (
+ "maldi",
+ [
+ "min_prevalence_sample_filter",
+ "relative_abundance",
+ "min_prevalence_feature_selector",
+ ],
+ ),
+ (
+ "metabolomics",
+ [
+ "min_prevalence_sample_filter",
+ "min_prevalence_feature_selector",
+ ],
+ ),
+ (
+ "proteomics",
+ [
+ "min_prevalence_sample_filter",
+ "min_prevalence_feature_selector",
+ ],
+ ),
+ (
+ "transcriptomics",
+ [
+ "min_prevalence_sample_filter",
+ "min_prevalence_feature_selector",
+ ],
+ ),
+ ]
+)
+
+
+class _LazyLoadAllMappings(OrderedDict):
+ """
+ A mapping that loads all of its key-value pairs on first access (whether by indexing, requesting keys, values,
+ etc.).
+
+ Args:
+ mapping: The mapping to load.
+ """
+
+ def __init__(self, mapping):
+ self._mapping = mapping
+ self._initialized = False
+ self._data = {}
+
+ def _initialize(self):
+ if self._initialized:
+ return
+ warnings.warn(
+ "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.",
+ FutureWarning,
+ )
+ for processor_name, map_name in self._mapping.items():
+ processor_category = PROCESSOR_CATEGORY_MAPPING_NAMES.get(
+ processor_name, "models"
+ )
+ processor_type = PROCESSOR_TYPE_MAPPING_NAMES.get(processor_name, None)
+ try:
+ module_name = processor_name.replace("-", "_")
+ package = (
+ f"biofit.{processor_category}"
+ if not processor_type
+ else f"biofit.{processor_category}.{processor_type}"
+ )
+ module = importlib.import_module(f".{module_name}", package)
+ mapping = getattr(module, map_name)
+ except ImportError:
+ if is_transformers_available() and processor_category == "models":
+ from transformers.models.auto.configuration_auto import (
+ model_type_to_module_name,
+ )
+
+ module_name = model_type_to_module_name(processor_name)
+ module = importlib.import_module(
+ f".{module_name}", "transformers.models"
+ )
+ mapping = getattr(module, map_name)
+ self._data.update(mapping)
+ self._initialized = True
+
+ def __getitem__(self, key):
+ self._initialize()
+ return self._data[key]
+
+ def keys(self):
+ self._initialize()
+ return self._data.keys()
+
+ def values(self):
+ self._initialize()
+ return self._data.values()
+
+ def items(self):
+ self._initialize()
+ return self._data.items()
+
+ def __iter__(self):
+ self._initialize()
+ return iter(self._data)
+
+ def __contains__(self, item):
+ self._initialize()
+ return item in self._data
+
+
+def _get_class_name(model_class: Union[str, List[str]]):
+ if isinstance(model_class, (list, tuple)):
+ return " or ".join([f"[`{c}`]" for c in model_class if c is not None])
+ return f"[`{model_class}`]"
+
+
+def _list_model_options(indent, config_to_class=None, use_model_types=True):
+ if config_to_class is None and (not use_model_types):
+ raise ValueError(
+ "Using `use_model_types=False` requires a `config_to_class` dictionary."
+ )
+ if use_model_types:
+ if config_to_class is None:
+ model_type_to_name = {
+ model_type: f"[`{config}`]"
+ for model_type, config in CONFIG_MAPPING_NAMES.items()
+ }
+ else:
+ model_type_to_name = {
+ model_type: _get_class_name(model_class)
+ for model_type, model_class in config_to_class.items()
+ if model_type in PROCESSOR_MAPPING_NAMES
+ }
+ lines = [
+ f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({PROCESSOR_MAPPING_NAMES[model_type]} model)"
+ for model_type in sorted(model_type_to_name.keys())
+ ]
+ else:
+ config_to_name = {
+ CONFIG_MAPPING_NAMES[config]: _get_class_name(clas)
+ for config, clas in config_to_class.items()
+ if config in CONFIG_MAPPING_NAMES
+ }
+ config_to_model_name = {
+ config: PROCESSOR_MAPPING_NAMES[model_type]
+ for model_type, config in CONFIG_MAPPING_NAMES.items()
+ }
+ lines = [
+ f"{indent}- [`{config_name}`] configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
+ for config_name in sorted(config_to_name.keys())
+ ]
+ return "\n".join(lines)
+
+
+def replace_list_option_in_docstrings(config_to_class=None, use_model_types=True):
+ def docstring_decorator(fn):
+ docstrings = fn.__doc__
+ if docstrings is None:
+ return fn
+ lines = docstrings.split("\n")
+ i = 0
+ while (
+ i < len(lines) and re.search("^(\\s*)List options\\s*$", lines[i]) is None
+ ):
+ i += 1
+ if i < len(lines):
+ indent = re.search("^(\\s*)List options\\s*$", lines[i]).groups()[0]
+ if use_model_types:
+ indent = f"{indent} "
+ lines[i] = _list_model_options(
+ indent, config_to_class=config_to_class, use_model_types=use_model_types
+ )
+ docstrings = "\n".join(lines)
+ else:
+ raise ValueError(
+ f"The function {fn} should have an empty 'List options' in its docstring as placeholder, current docstring is:\n{docstrings}"
+ )
+ fn.__doc__ = docstrings
+ return fn
+
+ return docstring_decorator
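+
+# `replace_list_option_in_docstrings` swaps a bare "List options" placeholder line in
+# a decorated function's docstring for the listing built by `_list_model_options`.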
+
+
+class AutoConfig:
+ """
+ This is a generic configuration class that will be instantiated as one of the configuration classes of the library
+ when created with the [`~AutoConfig.from_pretrained`] class method.
+
+ This class cannot be instantiated directly using `__init__()` (throws an error).
+ """
+
+ def __init__(self):
+ raise EnvironmentError(
+ "AutoConfig is designed to be instantiated using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
+ )
+
+ @classmethod
+ def for_processor(
+ cls, model_type: str, *args, experiment_type: str = None, **kwargs
+ ):
+ if is_biosets_available():
+ from biosets.packaged_modules import EXPERIMENT_TYPE_ALIAS
+ else:
+ EXPERIMENT_TYPE_ALIAS = {}
+
+ experiment_type = EXPERIMENT_TYPE_ALIAS.get(experiment_type, experiment_type)
+ if experiment_type is not None:
+ if experiment_type in DATASET_TO_MAPPER:
+ mapper = DATASET_TO_MAPPER[experiment_type]
+ if model_type in mapper:
+ _config_class = mapper[model_type]
+ return _config_class(*args, **kwargs)
+ raise ValueError(
+ f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(mapper.keys())}"
+ )
+ if model_type in CONFIG_MAPPING:
+ _config_class = CONFIG_MAPPING[model_type]
+ cls_kwargs = get_kwargs(kwargs, _config_class.__init__)
+ return _config_class(*args, **cls_kwargs)
+ elif (
+ is_transformers_available()
+ and PROCESSOR_CATEGORY_MAPPING_NAMES.get(model_type, "models") == "models"
+ ):
+ from transformers.models.auto.configuration_auto import (
+ AutoConfig as HfAutoConfig,
+ )
+
+ return HfAutoConfig.for_model(model_type, *args, **kwargs)
+ raise ValueError(
+ f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
+ )
+
+ @classmethod
+ @replace_list_option_in_docstrings()
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ if is_transformers_available():
+ from transformers.models.auto.configuration_auto import (
+ AutoConfig as HfAutoConfig,
+ )
+
+ return HfAutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+ raise EnvironmentError(
+ "Using `AutoConfig.from_pretrained` requires the transformers library to be installed. You can install it with `pip install transformers`."
+ )
+
+ @staticmethod
+ def register(processor_type, config, exist_ok=False):
+ """
+ Register a new configuration for this class.
+
+ Args:
+ processor_type (`str`): The processor type, e.g. "pca" or "lightgbm".
+ config ([`ProcessorConfig`] or [`PretrainedConfig`]): The config to register.
+ """
+ types = ProcessorConfig
+ if is_transformers_available():
+ from transformers import PretrainedConfig
+
+ types = (PretrainedConfig, ProcessorConfig)
+ config_processor_type = getattr(
+ config, "processor_type", getattr(config, "model_type", None)
+ )
+ if issubclass(config, types) and config_processor_type != processor_type:
+ raise ValueError(
+ f"The config you are passing has a `model_type` attribute that is not consistent with the model type you passed (config has {config_processor_type} and you passed {processor_type}. Fix one of those so they match!"
+ )
+ CONFIG_MAPPING.register(processor_type, config, exist_ok=exist_ok)
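+
+# Illustrative usage (a sketch; assumes "lightgbm" is a key in CONFIG_MAPPING_NAMES
+# and `MyProcessorConfig` is a user-defined ProcessorConfig subclass):
+#   config = AutoConfig.for_processor("lightgbm")
+#   AutoConfig.register("my_processor", MyProcessorConfig)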
+
+
+class AutoPlotterConfig:
+ """
+ This is a generic configuration class that will be instantiated as one of the plotter configuration classes of
+ the library when created with the [`~AutoPlotterConfig.for_processor`] or [`~AutoPlotterConfig.for_dataset`]
+ class methods.
+
+ This class is not meant to be instantiated directly; use the class methods instead.
+ """
+
+ @classmethod
+ def for_dataset(cls, dataset_name: str, *args, **kwargs):
+ if is_biosets_available():
+ from biosets.packaged_modules import EXPERIMENT_TYPE_ALIAS
+ else:
+ EXPERIMENT_TYPE_ALIAS = {}
+
+ dataset_name = EXPERIMENT_TYPE_ALIAS.get(dataset_name, dataset_name)
+ if dataset_name in DATASET_PREPROCESSOR_MAPPING_NAMES:
+ preprocessors = DATASET_PREPROCESSOR_MAPPING_NAMES[dataset_name]
+ mapper = DATASET_PLT_TO_MAPPER[dataset_name]
+ classes = []
+ for preprocessor in preprocessors:
+ _config_class = mapper[preprocessor]
+ classes.append(_config_class(*args, **kwargs))
+ return classes
+
+ raise ValueError(
+ f"Unrecognized dataset identifier: {dataset_name}. Should contain one of {', '.join(DATASET_PREPROCESSOR_MAPPING_NAMES.keys())}"
+ )
+
+ @classmethod
+ def for_processor(
+ cls, processor_type: str, *args, dataset_name: str = None, **kwargs
+ ):
+ if dataset_name is not None:
+ if is_biosets_available():
+ from biosets.packaged_modules import EXPERIMENT_TYPE_ALIAS
+ else:
+ EXPERIMENT_TYPE_ALIAS = {}
+
+ dataset_name = EXPERIMENT_TYPE_ALIAS.get(dataset_name, dataset_name)
+ if dataset_name in DATASET_PLT_TO_MAPPER:
+ mapper = DATASET_PLT_TO_MAPPER[dataset_name]
+ if processor_type in mapper:
+ _config_class = mapper[processor_type]
+ return _config_class(*args, **kwargs)
+ raise ValueError(
+ f"Unrecognized model identifier: {processor_type}. Should contain one of {', '.join(mapper.keys())}"
+ )
+ if processor_type in PLOTTER_CONFIG_MAPPING:
+ _config_class = PLOTTER_CONFIG_MAPPING[processor_type]
+ return _config_class(*args, **kwargs)
+ elif (
+ is_transformers_available()
+ and PROCESSOR_CATEGORY_MAPPING_NAMES.get(processor_type, "models")
+ == "models"
+ ):
+ from transformers.models.auto.configuration_auto import (
+ AutoConfig as HfAutoConfig,
+ )
+
+ return HfAutoConfig.for_model(processor_type, *args, **kwargs)
+ raise ValueError(
+ f"Unrecognized model identifier: {processor_type}. Should contain one of {', '.join(PLOTTER_CONFIG_MAPPING.keys())}"
+ )
+
+ @staticmethod
+ def register(processor_type, config, exist_ok=False):
+ """
+ Register a new configuration for this class.
+
+ Args:
+ processor_type (`str`): The processor type, e.g. "pca" or "lightgbm".
+ config ([`PlotterConfig`]): The config to register.
+ """
+ types = PlotterConfig
+ config_processor_type = getattr(
+ config, "processor_type", getattr(config, "model_type", None)
+ )
+ if issubclass(config, types) and config_processor_type != processor_type:
+ raise ValueError(
+ f"The config you are passing has a `model_type` attribute that is not consistent with the model type you passed (config has {config_processor_type} and you passed {processor_type}. Fix one of those so they match!"
+ )
+ PLOTTER_CONFIG_MAPPING.register(processor_type, config, exist_ok=exist_ok)
+
+
+class AutoPreprocessorConfig(AutoConfig):
+ """
+ This is a generic configuration class that will be instantiated as one of the preprocessor configuration classes
+ of the library when created with the [`~AutoPreprocessorConfig.for_dataset`] class method.
+
+ This class cannot be instantiated directly using `__init__()` (throws an error).
+ """
+
+ @classmethod
+ def for_dataset(cls, dataset_name: str, *args, **kwargs):
+ if is_biosets_available():
+ from biosets.packaged_modules import EXPERIMENT_TYPE_ALIAS
+ else:
+ EXPERIMENT_TYPE_ALIAS = {}
+
+ dataset_name = EXPERIMENT_TYPE_ALIAS.get(dataset_name, dataset_name)
+ if dataset_name in DATASET_PREPROCESSOR_MAPPING_NAMES:
+ preprocessors = DATASET_PREPROCESSOR_MAPPING_NAMES[dataset_name]
+ mapper = DATASET_TO_MAPPER[dataset_name]
+ classes = []
+ for preprocessor in preprocessors:
+ _config_class = mapper[preprocessor]
+ classes.append(_config_class(*args, **kwargs.get(preprocessor, {})))
+ return classes
+
+ raise ValueError(
+ f"Unrecognized dataset identifier: {dataset_name}. Should contain one of {', '.join(DATASET_PREPROCESSOR_MAPPING_NAMES.keys())}"
+ )
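+
+
+# Illustrative usage (a sketch; "snp" is one of the experiment types mapped above):
+#   configs = AutoPreprocessorConfig.for_dataset("snp")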
diff --git a/src/biofit/auto/modeling_auto.py b/src/biofit/auto/modeling_auto.py
new file mode 100644
index 0000000..aeaa77f
--- /dev/null
+++ b/src/biofit/auto/modeling_auto.py
@@ -0,0 +1,61 @@
+from collections import OrderedDict
+
+from biofit.auto.auto_factory import _BaseAutoModelClass, _LazyAutoMapping
+
+from .configuration_auto import CONFIG_MAPPING_NAMES
+
+MODEL_MAPPING_NAMES = OrderedDict(
+ [
+ ("xgboost", "XGBoostModel"),
+ ("lightgbm", "LightGBMModel"),
+ ("catboost", "CatBoostModel"),
+ ("random_forest", "RandomForestModel"),
+ ("logistic_regression", "LogisticRegressionModel"),
+ ("lasso", "LassoModel"),
+ ("svm", "SVMModel"),
+ ("knn", "KNNModel"),
+ ]
+)
+
+MODEL_FOR_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
+ [
+ ("xgboost", "XGBoostForClassification"),
+ ("lightgbm", "LightGBMForClassification"),
+ ("catboost", "CatBoostForClassification"),
+ ("random_forest", "RandomForestForClassification"),
+ ("logistic_regression", "LogisticRegressionForClassification"),
+ ("lasso", "LassoForClassification"),
+ ("svm", "SVMForClassification"),
+ ("knn", "KNNForClassification"),
+ ]
+)
+
+MODEL_FOR_REGRESSION_MAPPING_NAMES = OrderedDict(
+ [
+ ("xgboost", "XGBoostForRegression"),
+ ("lightgbm", "LightGBMForRegression"),
+ ("catboost", "CatBoostForRegression"),
+ ("random_forest", "RandomForestForRegression"),
+ ("logistic_regression", "LogisticRegressionForRegression"),
+ ("lasso", "LassoForRegression"),
+ ("svm", "SVMForRegression"),
+ ("knn", "KNNForRegression"),
+ ]
+)
+
+MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
+MODEL_FOR_CLASSIFICATION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_CLASSIFICATION_MAPPING_NAMES
+)
+MODEL_FOR_REGRESSION_MAPPING = _LazyAutoMapping(
+ CONFIG_MAPPING_NAMES, MODEL_FOR_REGRESSION_MAPPING_NAMES
+)
+
+
+class AutoModel(_BaseAutoModelClass):
+ _processor_mapping = MODEL_MAPPING
+
+
+class AutoModelForClassification(_BaseAutoModelClass):
+ _processor_mapping = MODEL_FOR_CLASSIFICATION_MAPPING
+
+
+class AutoModelForRegression(_BaseAutoModelClass):
+ _processor_mapping = MODEL_FOR_REGRESSION_MAPPING
diff --git a/src/biofit/auto/plotting_auto.py b/src/biofit/auto/plotting_auto.py
new file mode 100644
index 0000000..491e1ac
--- /dev/null
+++ b/src/biofit/auto/plotting_auto.py
@@ -0,0 +1,143 @@
+from typing import List, Union
+
+from biocore.utils.import_util import is_biosets_available
+from biocore.utils.inspect import get_kwargs
+
+from biofit.auto.auto_factory import (
+ _BaseAutoProcessorClass,
+ _LazyAutoMapping,
+)
+from biofit.processing import BaseProcessor
+from biofit.visualization.plotting import BasePlotter
+from biofit.visualization.plotting_utils import (
+ display_image_carousel,
+ is_in_notebook,
+)
+
+from .configuration_auto import (
+ DATASET_PLT_TO_MAPPER_NAMES,
+ PLOTTER_CONFIG_MAPPING_NAMES,
+ PLOTTER_MAPPING_NAMES,
+ AutoPlotterConfig,
+)
+from .processing_auto import AutoPreprocessor, ProcessorPipeline
+
+PLOTTER_MAPPING = _LazyAutoMapping(PLOTTER_CONFIG_MAPPING_NAMES, PLOTTER_MAPPING_NAMES)
+
+
+class PlotterPipeline:
+ def __init__(self, plotters: List[BasePlotter], processors: List[BaseProcessor]):
+ self.plotters = plotters
+ self.processors = processors
+
+ def plot(self, X, *args, fit=True, **kwargs):
+ from datasets import Dataset, IterableDataset
+
+ from biofit import Bioset
+
+ if not isinstance(X, (Bioset, Dataset, IterableDataset)):
+ raise ValueError("X must be a Bioset or huggingface Dataset.")
+ pre_X = X
+ show = kwargs.pop("show", True)
+ images = []
+ for plotter, processor in zip(self.plotters, self.processors):
+ fit_trans_kwargs = get_kwargs(kwargs, processor.fit_transform)
+ if fit:
+ after_X = processor.fit_transform(pre_X, *args, **fit_trans_kwargs)
+ else:
+ after_X = processor.transform(pre_X, *args, **fit_trans_kwargs)
+ if plotter.config._compare:
+ path = plotter.plot(pre_X, after_X, *args, show=False, **kwargs)
+ else:
+ path = plotter.plot(after_X, *args, show=False, **kwargs)
+ if isinstance(path, list):
+ images.extend(path)
+ else:
+ images.append(path)
+ pre_X = after_X
+
+ if show and is_in_notebook():
+ display_image_carousel(images)
+
+
+class AutoPlotter(_BaseAutoProcessorClass):
+ _processor_mapping = PLOTTER_MAPPING
+
+ @classmethod
+ def for_dataset(cls, dataset_name, **kwargs):
+ """Create a processor for a dataset.
+
+ Args:
+ dataset (Bioset): The dataset to create a processor for.
+
+ Returns:
+ Processor: The processor for the dataset.
+ """
+
+ if is_biosets_available():
+ from biosets.packaged_modules import EXPERIMENT_TYPE_ALIAS
+ else:
+ EXPERIMENT_TYPE_ALIAS = {}
+
+ dataset_name = EXPERIMENT_TYPE_ALIAS.get(dataset_name, dataset_name)
+ _plotter_mapping = _LazyAutoMapping(
+ DATASET_PLT_TO_MAPPER_NAMES.get(dataset_name), PLOTTER_MAPPING_NAMES
+ )
+ configs = AutoPlotterConfig.for_dataset(dataset_name)
+ procs = []
+ for config in configs:
+ config_kwargs = get_kwargs(kwargs, config.__class__.__init__)
+ procs.append(
+ _plotter_mapping[type(config)]._from_config(config, **config_kwargs)
+ )
+
+ processors = AutoPreprocessor.for_dataset(dataset_name)
+
+ return PlotterPipeline(procs, processors)
+
+ @classmethod
+ def from_processor(
+ cls, processor, dataset_name=None, **kwargs
+ ) -> Union[PlotterPipeline, BasePlotter]:
+ """Create a processor from another processor.
+
+ Args:
+ processor (Processor): The processor to create a processor for.
+
+ Returns:
+ Processor: The processor for the dataset.
+ """
+
+ def get_proc(proc, dataset_name=None, **kwargs):
+ dataset_name = dataset_name or proc.config.dataset_name
+ if dataset_name:
+ if is_biosets_available():
+ from biosets.packaged_modules import EXPERIMENT_TYPE_ALIAS
+ else:
+ EXPERIMENT_TYPE_ALIAS = {}
+ dataset_name = EXPERIMENT_TYPE_ALIAS.get(dataset_name, dataset_name)
+ _plotter_mapping = _LazyAutoMapping(
+ DATASET_PLT_TO_MAPPER_NAMES.get(dataset_name),
+ PLOTTER_MAPPING_NAMES,
+ )
+ config = AutoPlotterConfig.for_processor(
+ proc.config.processor_name,
+ dataset_name=dataset_name,
+ )
+ else:
+ _plotter_mapping = PLOTTER_MAPPING
+ config = AutoPlotterConfig.for_processor(proc.config.processor_name)
+ config_kwargs = get_kwargs(kwargs, config.__class__.__init__)
+ return _plotter_mapping[type(config)]._from_config(config, **config_kwargs)
+
+ if isinstance(processor, ProcessorPipeline):
+ plotters = []
+ for proc in processor.steps:
+ plotters.append(get_proc(proc[1], dataset_name=dataset_name, **kwargs))
+ return PlotterPipeline(plotters, processor.processors)
+ elif isinstance(processor, BaseProcessor):
+ return get_proc(processor, dataset_name=dataset_name, **kwargs)
+ else:
+ raise ValueError(
+ "processor must be a `biofit.ProcessorPipeline` or `biofit.BaseProcessor`."
+ )
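+
+
+# Illustrative usage (a sketch; `pipeline` is a fitted ProcessorPipeline and `ds` an
+# in-memory dataset accepted by `PlotterPipeline.plot`):
+#   plotter_pipeline = AutoPlotter.from_processor(pipeline, dataset_name="otu")
+#   plotter_pipeline.plot(ds)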
diff --git a/src/biofit/auto/processing_auto.py b/src/biofit/auto/processing_auto.py
new file mode 100644
index 0000000..8258b88
--- /dev/null
+++ b/src/biofit/auto/processing_auto.py
@@ -0,0 +1,98 @@
+from typing import List
+
+from biocore.utils.import_util import is_biosets_available
+from biocore.utils.inspect import get_kwargs
+from sklearn.pipeline import Pipeline
+
+from biofit.auto.auto_factory import (
+ _BaseAutoProcessorClass,
+ _get_class,
+ _LazyAutoMapping,
+)
+from biofit.processing import BaseProcessor
+
+from .configuration_auto import (
+ CONFIG_MAPPING_NAMES,
+ DATASET_TO_MAPPER_NAMES,
+ PROCESSOR_MAPPING_NAMES,
+ AutoPreprocessorConfig,
+)
+
+PROCESSOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, PROCESSOR_MAPPING_NAMES)
+
+
+class AutoProcessor(_BaseAutoProcessorClass):
+ _processor_mapping = PROCESSOR_MAPPING
+
+
+class ProcessorPipeline(Pipeline):
+ """A pipeline of processors."""
+
+ def __init__(self, processors: List[BaseProcessor] = None, **kwargs):
+ if "steps" in kwargs:
+ return super().__init__(**kwargs)
+
+ self.processors = processors
+ steps = []
+ for i, processor in enumerate(processors):
+ if not isinstance(processor, tuple):
+ if hasattr(processor, "config"):
+ steps.append(
+ (
+ f"{processor.config.processor_name}_{i + 1}",
+ processor,
+ )
+ )
+ else:
+ steps.append((f"{processor.__class__.__name__}_{i + 1}", processor))
+ else:
+ steps.append(processor)
+ super().__init__(steps, **kwargs)
+
+ def fit_transform(self, X, y=None, **kwargs):
+ for processor in self.processors:
+ if isinstance(processor, tuple):
+ p = processor[-1]
+ else:
+ p = processor
+ fit_transform_kwargs = get_kwargs(kwargs, p.fit_transform)
+ X = p.fit_transform(X, y, **fit_transform_kwargs)
+ return X
+
+ def pop(self, index):
+ self.steps.pop(index)
+ return self.processors.pop(index)
+
+
+class AutoPreprocessor(_BaseAutoProcessorClass):
+ _processor_mapping = PROCESSOR_MAPPING
+
+ @classmethod
+ def for_dataset(cls, dataset_name, **kwargs):
+ """
+ Create a preprocessor pipeline for a given dataset.
+
+ Args:
+ dataset_name: str
+ The dataset (experiment) name, e.g. "otu" or "snp".
+ Returns:
+ ProcessorPipeline
+ The preprocessor pipeline for the dataset.
+ """
+ if is_biosets_available():
+ from biosets.packaged_modules import EXPERIMENT_TYPE_ALIAS
+ else:
+ EXPERIMENT_TYPE_ALIAS = {}
+
+ dataset_name = EXPERIMENT_TYPE_ALIAS.get(dataset_name, dataset_name)
+ _processor_mapping = _LazyAutoMapping(
+ DATASET_TO_MAPPER_NAMES.get(dataset_name), PROCESSOR_MAPPING_NAMES
+ )
+ configs = AutoPreprocessorConfig.for_dataset(dataset_name, **kwargs)
+ procs = [
+ _get_class(config, _processor_mapping)._from_config(config)
+ for config in configs
+ ]
+
+ processors = ProcessorPipeline(procs)
+ return processors
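+
+
+# Illustrative usage (a sketch; `ds` is a loaded omics dataset):
+#   pipeline = AutoPreprocessor.for_dataset("metagenomics")
+#   ds_processed = pipeline.fit_transform(ds)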
diff --git a/src/biofit/cli/__init__.py b/src/biofit/cli/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/biofit/cli/install_r.py b/src/biofit/cli/install_r.py
new file mode 100644
index 0000000..fd012fc
--- /dev/null
+++ b/src/biofit/cli/install_r.py
@@ -0,0 +1,95 @@
+import os
+
+from biofit.integration.R.r_caller import (
+ R_PLOTTING_DEPENDENCIES,
+ R_PREPROCESSING_DEPENDENCIES,
+ RCaller,
+)
+from biofit.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def add_to_parser(parser):
+ # Register the 'install' subcommand
+ install_r_parser = parser.add_parser(
+ "install",
+ help="Install R dependencies",
+ description="Install various dependencies including R and its packages.",
+ )
+
+ install_r_parser.add_argument(
+ "--all",
+ action="store_true",
+ help="Install all R dependencies for plotting and preprocessing",
+ )
+ install_r_parser.add_argument(
+ "--plotting",
+ action="store_true",
+ help="Install R plotting dependencies",
+ )
+ install_r_parser.add_argument(
+ "--preprocessing",
+ action="store_true",
+ help="Install R preprocessing dependencies",
+ )
+ install_r_parser.add_argument(
+ "--cran",
+ nargs="+",
+ help="Install specific CRAN packages",
+ )
+ install_r_parser.add_argument(
+ "--bioconductor",
+ nargs="+",
+ help="Install specific Bioconductor packages",
+ )
+ install_r_parser.add_argument(
+ "--binary",
+ action="store_true",
+ help="Install with binaries only",
+ )
+
+ install_r_parser.add_argument(
+ "--r-home",
+ "-r",
+ type=str,
+ help="Path to the R installation directory which contains the library folder",
+ )
+ return parser
+
+
+def run(args):
+ if args.r_home:
+ os.environ["R_HOME"] = args.r_home
+ cran_deps = args.cran or []
+ bioconductor_deps = args.bioconductor or []
+ params = {}
+ if args.binary:
+ params["pkgType"] = "binary"
+ if args.all or args.plotting:
+ cran_deps += R_PLOTTING_DEPENDENCIES["cran"]
+ bioconductor_deps += R_PLOTTING_DEPENDENCIES["bioconductor"]
+ if args.all or args.preprocessing:
+ cran_deps += R_PREPROCESSING_DEPENDENCIES["cran"]
+ bioconductor_deps += R_PREPROCESSING_DEPENDENCIES["bioconductor"]
+
+ if cran_deps:
+ logger.info(f"Checking for CRAN packages: {', '.join(cran_deps)}")
+ RCaller.verify_r_dependencies(
+ cran_dependencies=cran_deps,
+ bioconductor_dependencies=[],
+ install_missing=True,
+ )
+ logger.info("CRAN dependencies installed successfully.")
+
+ if bioconductor_deps:
+ logger.info(f"Checking Bioconductor packages: {', '.join(bioconductor_deps)}")
+ RCaller.verify_r_dependencies(
+ cran_dependencies=[],
+ bioconductor_dependencies=bioconductor_deps,
+ install_missing=True,
+ )
+ logger.info("Bioconductor dependencies installed successfully.")
+
+ if not cran_deps and not bioconductor_deps:
+ logger.info("No R dependencies to install.")
diff --git a/src/biofit/cli/main.py b/src/biofit/cli/main.py
new file mode 100644
index 0000000..ded0aa2
--- /dev/null
+++ b/src/biofit/cli/main.py
@@ -0,0 +1,50 @@
+import argparse
+import sys
+
+import biofit.cli.install_r as install_r
+from biofit.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def create_parser():
+ parser = argparse.ArgumentParser(
+ description="BIOFIT: General-purpose Omics Machine Learning framework"
+ )
+
+ # add a quiet option
+ parser.add_argument(
+ "--quiet",
+ "-q",
+ action="store_true",
+ help="Suppress all logging output except for errors",
+ )
+ subparsers = parser.add_subparsers(dest="subcommand")
+ subparsers.required = True
+
+ return parser, subparsers
+
+
+def main():
+ parser, subparsers = create_parser()
+ subparsers = install_r.add_to_parser(subparsers)
+ args = parser.parse_args()
+
+ if args.quiet:
+ logging.set_verbosity(logging.ERROR)
+ else:
+ logging.set_verbosity(logging.INFO)
+
+ try:
+ if args.subcommand == "install":
+ install_r.run(args)
+ else:
+ parser.print_help()
+ sys.exit(1)
+ except Exception as e:
+ logger.error(str(e))
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/biofit/config.py b/src/biofit/config.py
new file mode 100644
index 0000000..63c2c89
--- /dev/null
+++ b/src/biofit/config.py
@@ -0,0 +1,231 @@
+import importlib
+import importlib.metadata
+import importlib.util
+import logging
+import os
+from pathlib import Path
+
+from packaging import version
+
+logger = logging.getLogger(__name__)
+
+ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
+ENV_VARS_FALSE_VALUES = {"0", "OFF", "NO", "FALSE"}
+ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
+ENV_VARS_FALSE_AND_AUTO_VALUES = ENV_VARS_FALSE_VALUES.union({"AUTO"})
+
+
+DEFAULT_XDG_CACHE_HOME = "~/.cache"
+XDG_CACHE_HOME = os.getenv("XDG_CACHE_HOME", DEFAULT_XDG_CACHE_HOME)
+DEFAULT_BIOFIT_CACHE_HOME = os.path.join(XDG_CACHE_HOME, "biofit")
+BIOFIT_CACHE_HOME = os.path.expanduser(
+ os.getenv("BIOFIT_HOME", DEFAULT_BIOFIT_CACHE_HOME)
+)
+
+DEFAULT_BIOFIT_DATASETS_CACHE = os.path.join(BIOFIT_CACHE_HOME, "datasets")
+BIOFIT_DATASETS_CACHE = Path(
+ os.getenv("BIOFIT_DATASETS_CACHE", DEFAULT_BIOFIT_DATASETS_CACHE)
+)
+DEFAULT_BIOFIT_PREPROCESSORS_CACHE = os.path.join(BIOFIT_CACHE_HOME, "processors")
+BIOFIT_PROCESSORS_CACHE = Path(
+ os.getenv("BIOFIT_PROCESSORS_CACHE", DEFAULT_BIOFIT_PREPROCESSORS_CACHE)
+)
+
+DEFAULT_BIOFIT_METRICS_CACHE = os.path.join(BIOFIT_CACHE_HOME, "metrics")
+BIOFIT_METRICS_CACHE = Path(
+ os.getenv("BIOFIT_METRICS_CACHE", DEFAULT_BIOFIT_METRICS_CACHE)
+)
+
+DEFAULT_BIOFIT_MODULES_CACHE = os.path.join(BIOFIT_CACHE_HOME, "modules")
+BIOFIT_MODULES_CACHE = Path(
+ os.getenv("BIOFIT_MODULES_CACHE", DEFAULT_BIOFIT_MODULES_CACHE)
+)
+
+BIOFIT_DYNAMIC_MODULE_NAME = Path(
+ os.getenv("BIOFIT_DYNAMIC_MODULE_NAME", "datasets_modules")
+)
+
+BIOFIT_CACHE_HOME = os.getenv(
+ "BIOFIT_CACHE_HOME",
+ os.path.join(os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache")), "biofit"),
+)
+DEFAULT_BIOFIT_PATCHES_CACHE = os.path.join(BIOFIT_CACHE_HOME, "patches")
+BIOFIT_PATCHES_CACHE = Path(
+ os.getenv("BIOFIT_PATCHES_CACHE", DEFAULT_BIOFIT_PATCHES_CACHE)
+)
+
+DOWNLOADED_PATCHES_DIR = "downloads"
+DEFAULT_DOWNLOADED_PATCHES_PATH = os.path.join(
+ BIOFIT_PATCHES_CACHE, DOWNLOADED_PATCHES_DIR
+)
+DOWNLOADED_PATCHES_PATH = Path(
+ os.getenv("BIOFIT_PATCHES_DOWNLOADED_PATCHES_PATH", DEFAULT_DOWNLOADED_PATCHES_PATH)
+)
+
+EXTRACTED_PATCHES_DIR = "extracted"
+DEFAULT_EXTRACTED_PATCHES_PATH = os.path.join(
+ DEFAULT_DOWNLOADED_PATCHES_PATH, EXTRACTED_PATCHES_DIR
+)
+EXTRACTED_PATCHES_PATH = Path(
+ os.getenv("BIOFIT_PATCHES_EXTRACTED_PATCHES_PATH", DEFAULT_EXTRACTED_PATCHES_PATH)
+)
+
+PATCHES_FILENAME = "patches.json"
+NO_PATCHES_FILENAME = "no_patches.json"
+IS_CONDA = os.getenv("CONDA_PREFIX") is not None
+
+
+PYARROW_AVAILABLE = False
+PYARROW_VERSION = "N/A"
+
+PYARROW_AVAILABLE = importlib.util.find_spec("pyarrow") is not None
+if PYARROW_AVAILABLE:
+ try:
+ PYARROW_VERSION = version.parse(importlib.metadata.version("pyarrow"))
+ logger.info(f"pyarrow version {PYARROW_VERSION} available.")
+ except importlib.metadata.PackageNotFoundError:
+ pass
+
+RPY2_AVAILABLE = False
+RPY2_VERSION = "N/A"
+
+RPY2_AVAILABLE = importlib.util.find_spec("rpy2") is not None
+if RPY2_AVAILABLE:
+ try:
+ RPY2_VERSION = version.parse(importlib.metadata.version("rpy2"))
+ logger.info(f"rpy2 version {RPY2_VERSION} available.")
+ except importlib.metadata.PackageNotFoundError:
+ pass
+
+RPY2_ARROW_AVAILABLE = False
+RPY2_ARROW_VERSION = "N/A"
+
+RPY2_ARROW_AVAILABLE = importlib.util.find_spec("rpy2_arrow") is not None
+if RPY2_ARROW_AVAILABLE:
+ try:
+ RPY2_ARROW_VERSION = version.parse(importlib.metadata.version("rpy2-arrow"))
+ logger.info(f"rpy2 version {RPY2_ARROW_VERSION} available.")
+ except importlib.metadata.PackageNotFoundError:
+ pass
+
+POLARS_VERSION = "N/A"
+POLARS_AVAILABLE = False
+
+POLARS_AVAILABLE = importlib.util.find_spec("polars") is not None
+if POLARS_AVAILABLE:
+ try:
+ POLARS_VERSION = version.parse(importlib.metadata.version("polars"))
+ logger.info(f"Polars version {POLARS_VERSION} available.")
+ except importlib.metadata.PackageNotFoundError:
+ pass
+
+DASK_VERSION = "N/A"
+DASK_AVAILABLE = False
+
+DASK_AVAILABLE = importlib.util.find_spec("dask") is not None
+if DASK_AVAILABLE:
+ try:
+ DASK_VERSION = version.parse(importlib.metadata.version("dask"))
+ logger.info(f"Dask version {DASK_VERSION} available.")
+ except importlib.metadata.PackageNotFoundError:
+ pass
+
+
+USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+USE_JAX = os.environ.get("USE_JAX", "AUTO").upper()
+
+
+TORCH_VERSION = "N/A"
+TORCH_AVAILABLE = False
+
+if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
+ TORCH_AVAILABLE = importlib.util.find_spec("torch") is not None
+ if TORCH_AVAILABLE:
+ try:
+ TORCH_VERSION = version.parse(importlib.metadata.version("torch"))
+ logger.info(f"PyTorch version {TORCH_VERSION} available.")
+ except importlib.metadata.PackageNotFoundError:
+ pass
+else:
+ logger.info("Disabling PyTorch because USE_TF is set")
+
+TF_VERSION = "N/A"
+TF_AVAILABLE = False
+
+if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
+ TF_AVAILABLE = importlib.util.find_spec("tensorflow") is not None
+ if TF_AVAILABLE:
+ # For the metadata, we have to look for both tensorflow and tensorflow-cpu
+ for package in [
+ "tensorflow",
+ "tensorflow-cpu",
+ "tensorflow-gpu",
+ "tf-nightly",
+ "tf-nightly-cpu",
+ "tf-nightly-gpu",
+ "intel-tensorflow",
+ "tensorflow-rocm",
+ "tensorflow-macos",
+ ]:
+ try:
+ TF_VERSION = version.parse(importlib.metadata.version(package))
+ except importlib.metadata.PackageNotFoundError:
+ continue
+ else:
+ break
+ else:
+ TF_AVAILABLE = False
+ if TF_AVAILABLE:
+ if TF_VERSION.major < 2:
+ logger.info(
+ f"TensorFlow found but with version {TF_VERSION}. `datasets` requires version 2 minimum."
+ )
+ TF_AVAILABLE = False
+ else:
+ logger.info(f"TensorFlow version {TF_VERSION} available.")
+else:
+ logger.info("Disabling Tensorflow because USE_TORCH is set")
+
+
+JAX_VERSION = "N/A"
+JAX_AVAILABLE = False
+
+if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
+ JAX_AVAILABLE = (
+ importlib.util.find_spec("jax") is not None
+ and importlib.util.find_spec("jaxlib") is not None
+ )
+ if JAX_AVAILABLE:
+ try:
+ JAX_VERSION = version.parse(importlib.metadata.version("jax"))
+ logger.info(f"JAX version {JAX_VERSION} available.")
+ except importlib.metadata.PackageNotFoundError:
+ pass
+else:
+ logger.info("Disabling JAX because USE_JAX is set to False")
+
+
+USE_BEAM = os.environ.get("USE_BEAM", "AUTO").upper()
+BEAM_VERSION = "N/A"
+BEAM_AVAILABLE = False
+if USE_BEAM in ENV_VARS_TRUE_AND_AUTO_VALUES:
+ try:
+ BEAM_VERSION = version.parse(importlib.metadata.version("apache_beam"))
+ BEAM_AVAILABLE = True
+ logger.info(f"Apache Beam version {BEAM_VERSION} available.")
+ except importlib.metadata.PackageNotFoundError:
+ pass
+else:
+ logger.info("Disabling Apache Beam because USE_BEAM is set to False")
+
+
+R_SCRIPTS = Path(__file__).parent / "integration/R/scripts"
+BIOFIT_SKIP_R_DEPENDENCIES = (
+ os.getenv("BIOFIT_SKIP_R_DEPENDENCIES", "true").upper() not in ENV_VARS_FALSE_VALUES
+)
+RECORDER_ENABLED = os.getenv("BIOFIT_RECORDER_ENABLED", "true").lower() == "true"
+
+NO_IMBALANCE_ADJUSTMENT = os.getenv("BIOFIT_NO_IMBALANCE_ADJUSTMENT", "false") == "true"
+
+PBAR_REFRESH_TIME_INTERVAL = 0.05
diff --git a/src/biofit/eval.py b/src/biofit/eval.py
new file mode 100644
index 0000000..f541662
--- /dev/null
+++ b/src/biofit/eval.py
@@ -0,0 +1,296 @@
+from pathlib import Path
+from typing import List, Union
+
+import numpy as np
+from biocore import DataHandler
+from sklearn.pipeline import Pipeline
+
+from biofit.metrics import calculate_metrics, confusion_matrix, get_metrics
+from biofit.processing import BaseProcessor
+from biofit.train_eval_utils import (
+ _get_data,
+ get_model_info,
+ preprocess,
+ save_confusion_matrix,
+ save_metrics,
+ save_predictions,
+)
+
+
+def _predict(
+ models,
+ x_eval,
+ y_eval=None,
+ preprocessors=None,
+ use_proba=True,
+ task=None,
+ label_names=None,
+ cache_dir=None,
+):
+ # model can be a list of models for each y_eval, if multi-label
+ # classification/regression is used
+ if not isinstance(models, list):
+ models = [models]
+
+ y_preds = []
+ for i, model in enumerate(models):
+ preprocessor = preprocessors[i] if preprocessors is not None else None
+ if isinstance(model, Pipeline):
+ if len(model.steps) > 1 and preprocessor is None:
+ preprocessor = Pipeline([p for p in model.steps[:-1]])
+ model = model.steps[-1][1]
+ if preprocessor:
+ x_eval, _, _, _ = preprocess(
+ preprocessor, x_eval, cache_dir=cache_dir, transform_only=True
+ )
+
+ def fn(model, use_proba, x_test):
+ extra_kwargs = {}
+ if isinstance(model, BaseProcessor):
+ extra_kwargs = {"cache_dir": cache_dir, "load_from_cache_file": False}
+
+ if use_proba and hasattr(model, "predict_proba"):
+ return model.predict_proba(x_test, **extra_kwargs)
+ elif hasattr(model, "predict"):
+ return model.predict(x_test, **extra_kwargs)
+ else:
+ return model.transform(x_test, **extra_kwargs)
+
+ if not hasattr(model, "predict") and not hasattr(model, "predict_proba"):
+ for m in model:
+ y_pred = fn(m, use_proba, x_eval)
+ elif hasattr(model, "steps"):
+ for step in model.steps:
+ y_pred = fn(
+ step[-1] if isinstance(step, tuple) else step, use_proba, x_eval
+ )
+ else:
+ y_pred = fn(model, use_proba, x_eval)
+
+ y_pred_dims = DataHandler.get_shape(y_pred)
+ x_test_dims = DataHandler.get_shape(x_eval)
+ if y_eval is not None:
+ if task == "multiclass_classification":
+ if label_names is None:
+ label_names = np.unique(y_eval).tolist()
+
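+ # If the model returned class indices instead of per-class scores, expand
+ # them into a one-hot matrix so downstream metrics see a consistent 2-D shape.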
+ if len(y_pred_dims) == 1 and (
+ (y_pred_dims[0] % len(label_names)) != 0
+ or (y_pred_dims[0] // len(label_names)) != x_test_dims[0]
+ ):
+ y_pred_ohe = np.zeros((x_test_dims[0], len(label_names)))
+ y_pred_ohe[np.arange(x_test_dims[0]), y_pred] = 1
+ y_pred = y_pred_ohe
+ elif y_pred_dims[1] != len(label_names):
+ y_pred_ohe = np.zeros((y_pred_dims[0], len(label_names)))
+ y_pred_ohe[np.arange(y_pred_dims[0]), y_pred] = 1
+ y_pred = y_pred_ohe
+
+ if task == "binary_classification":
+ if len(y_pred_dims) == 1 and (
+ (y_pred_dims[0] % 2) != 0 or (y_pred_dims[0] // 2) != x_test_dims[0]
+ ):
+ y_pred = DataHandler.concat([1 - y_pred, y_pred], axis=1)
+ elif len(y_pred_dims) > 1 and y_pred_dims[1] != 2:
+ y_pred = DataHandler.concat([1 - y_pred, y_pred], axis=1)
+
+ y_preds.append(y_pred)
+
+ if len(y_preds) == 1:
+ y_preds = y_preds[0]
+ else:
+ y_preds = DataHandler.concat(y_preds, axis=1)
+ return y_preds
+
+
+def _update_metrics(y_true, y_pred, labels, metrics, task, results):
+ calc_results = calculate_metrics(metrics, y_true, y_pred, task, labels)
+ results.update(calc_results)
+
+
+def evaluate(
+ model,
+ data,
+ target=None,
+ label_names=None,
+ input_columns: Union[List[str], str] = "auto",
+ target_columns: Union[List[str], str] = "auto",
+ preprocessors=None,
+ task=None,
+ metrics=None,
+ use_proba=True,
+ output_dir=None,
+ cache_dir=None,
+ config=None,
+):
+ """
+    Evaluate a model, a list of models, or a pipeline on the test data.
+
+    Args:
+        model: The model, list of models, or pipeline to evaluate.
+        data (np.ndarray): Data to be used for testing the model.
+        target (np.ndarray, *optional*): Labels corresponding to the test data.
+        label_names (list, *optional*):
+            Names of the class labels. Inferred from the target when not
+            provided.
+        input_columns (Union[List[str], str]):
+            Columns of `data` to use as model inputs. Defaults to "auto".
+        target_columns (Union[List[str], str]):
+            Columns holding the labels. Defaults to "auto".
+        preprocessors (object or list, *optional*):
+            Preprocessor(s) to apply to the data before evaluation.
+        task (str, *optional*):
+            The task type, e.g. "classification". Inferred from the model when
+            not provided.
+        metrics (dict, *optional*):
+            Metrics to calculate for evaluation, with each metric being a
+            callable. If None, all metrics for the `task` will be calculated.
+        use_proba (bool):
+            Whether to use the predict_proba method of the model for
+            predictions.
+        output_dir (str, *optional*): Directory to save the predictions table.
+        cache_dir (str, *optional*):
+            Directory to save the cache files. Defaults to f"{output_dir}/cache",
+            if output_dir is provided.
+        config (dict, *optional*):
+            Placeholder for additional configuration options. Currently not
+            used.
+
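+    Example (a minimal sketch; assumes `clf` is a fitted scikit-learn
+    classifier and `test_df` is a pandas DataFrame with a "label" column):
+
+        >>> preds, results = evaluate(
+        ...     clf, test_df, target_columns="label", task="classification"
+        ... )
+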
+ Returns:
+        tuple: The predictions table (as a pandas DataFrame) and a dictionary
+        of metric results (None when no target is available).
+ """
+
+ x_eval, y_eval, _, _, _, _, input_columns, target_columns = _get_data(
+ data=data,
+ target=target,
+ valid_data=None,
+ valid_target=None,
+ input_columns=input_columns,
+ target_columns=target_columns,
+ format=None,
+ target_required=False,
+ )
+
+ model_info = get_model_info(model, task)
+ if task is None:
+ task = model_info.get("task", None)
+
+ if (
+ preprocessors is not None
+ and not isinstance(preprocessors, list)
+ and isinstance(model, list)
+ ):
+ preprocessors = [preprocessors] * len(model)
+
+    if (
+        task
+        and "classification" in task
+        and label_names is None
+        and y_eval is not None
+    ):
+ label_names = DataHandler.to_list(DataHandler.unique(y_eval))
+
+ if task and metrics is None and y_eval is not None:
+ if task == "classification":
+ if len(label_names) > 2:
+ task = "multiclass_classification"
+ else:
+ task = "binary_classification"
+ metrics = get_metrics(task)
+
+    y_preds = _predict(
+        x_eval=x_eval,
+        models=model,
+        y_eval=y_eval,
+        preprocessors=preprocessors,
+        use_proba=use_proba,
+        task=task,
+        label_names=label_names,
+        cache_dir=cache_dir,
+    )
+ y_preds = DataHandler.to_pandas(y_preds)
+
+ results = None
+ if y_eval is not None:
+ results = calculate_metrics(
+ metrics, DataHandler.to_pandas(y_eval), y_preds, task, label_names
+ )
+
+ y_preds = y_preds.sort_index()
+ if output_dir is not None:
+ output_dir = Path(output_dir)
+ output_dir.mkdir(exist_ok=True, parents=True)
+ labs = model_info.get("class_names", None) or label_names
+ if results is not None:
+ save_metrics(output_dir, results)
+
+ cm = confusion_matrix(y_eval, y_preds, labels=labs)
+ save_confusion_matrix(output_dir, cm, label_names)
+ save_predictions(
+ output_dir,
+ y_preds,
+ data,
+ y_eval,
+ label_names=label_names,
+ target_columns=target_columns,
+ )
+
+ return y_preds, results
+
+
+def predict(
+ model,
+ data,
+ input_columns: Union[List[str], str] = "auto",
+ preprocessors=None,
+ use_proba=True,
+ output_dir=None,
+ cache_dir=None,
+ config=None,
+):
+ """
+    Generate predictions from a model, a list of models, or a pipeline on the
+    given data.
+
+    Args:
+        model: The model, list of models, or pipeline to use for prediction.
+        data (np.ndarray): Data to generate predictions for.
+        input_columns (Union[List[str], str]):
+            Columns of `data` to use as model inputs. Defaults to "auto".
+        preprocessors (object or list, *optional*):
+            Preprocessor(s) to apply to the data before prediction.
+        use_proba (bool):
+            Whether to use the predict_proba method of the model for
+            predictions.
+        output_dir (str, *optional*): Directory to save the predictions table.
+        cache_dir (str, *optional*):
+            Directory to save the cache files. Defaults to f"{output_dir}/cache",
+            if output_dir is provided.
+        config (dict, *optional*):
+            Placeholder for additional configuration options. Currently not
+            used.
+
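+    Example (a minimal sketch; assumes `model` is a fitted classifier and
+    `test_df` is a pandas DataFrame of features):
+
+        >>> probs = predict(model, test_df, use_proba=True)
+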
+ Returns:
+        The predictions table as a pandas DataFrame.
+ """
+
+ model_info = get_model_info(model, task=None)
+ x_eval, _, _, _, _, _, input_columns, _ = _get_data(
+ data=data,
+ input_columns=input_columns,
+ target_columns=None,
+ format=None,
+ target_required=False,
+ )
+
+ if (
+ preprocessors is not None
+ and not isinstance(preprocessors, list)
+ and isinstance(model, list)
+ ):
+ preprocessors = [preprocessors] * len(model)
+
+    y_preds = _predict(
+        x_eval=x_eval,
+        models=model,
+        preprocessors=preprocessors,
+        use_proba=use_proba,
+        cache_dir=cache_dir,
+    )
+ y_preds = DataHandler.to_pandas(y_preds)
+
+    if output_dir is not None:
+        output_dir = Path(output_dir)
+        output_dir.mkdir(exist_ok=True, parents=True)
+        save_predictions(
+            output_dir,
+            y_preds,
+            data,
+            label_names=model_info.get("class_names", None),
+        )
+
+ return y_preds
diff --git a/src/biofit/exceptions.py b/src/biofit/exceptions.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/biofit/integration/R/__init__.py b/src/biofit/integration/R/__init__.py
new file mode 100644
index 0000000..261c8f4
--- /dev/null
+++ b/src/biofit/integration/R/__init__.py
@@ -0,0 +1,2 @@
+# ruff: noqa
+from .r_caller import RCaller
diff --git a/src/biofit/integration/R/r_caller.py b/src/biofit/integration/R/r_caller.py
new file mode 100644
index 0000000..f3b2b17
--- /dev/null
+++ b/src/biofit/integration/R/r_caller.py
@@ -0,0 +1,736 @@
+import gc
+import inspect
+import io
+import os
+import re
+import shutil
+import subprocess
+import sys
+from glob import glob
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional, TextIO
+
+import numpy as np
+from biocore.utils.import_util import (
+ is_rpy2_arrow_available,
+ is_rpy2_available,
+ requires_backends,
+)
+
+import biofit.config
+from biofit import config
+from biofit.utils import logging
+
+if TYPE_CHECKING:
+ from rpy2.robjects.conversion import Converter
+
+
+class PackageNotInstalledError(ImportError):
+    """Error occurring because the R package to import is not installed."""
+
+ pass
+
+
+logger = logging.get_logger(__name__)
+
+_is_auto_install = None
+
+R_PLOTTING_DEPENDENCIES = {
+ "cran": [
+ "ggplot2",
+ "arrow",
+ "circlize",
+ "RColorBrewer",
+ "scales",
+ "forcats",
+ "patchwork",
+ "reshape2",
+ "dplyr",
+ "tools",
+ ],
+ "bioconductor": ["ComplexHeatmap"],
+}
+R_PREPROCESSING_DEPENDENCIES = {
+ "cran": [],
+ "bioconductor": ["edgeR"],
+}
+
+
+def is_auto_install():
+ global _is_auto_install
+ if _is_auto_install is None:
+ _is_auto_install = not biofit.config.BIOFIT_SKIP_R_DEPENDENCIES
+ return _is_auto_install
+
+
+def enable_auto_install():
+ global _is_auto_install
+ _is_auto_install = True
+
+
+def disable_auto_install():
+ global _is_auto_install
+ _is_auto_install = False
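+
+
+# Example (a minimal sketch): opt out of automatic installation of missing R
+# packages for the current process:
+#
+#     from biofit.integration.R.r_caller import disable_auto_install
+#     disable_auto_install()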
+
+
+class ROutputCapture:
+    def __init__(
+        self, stdout: Optional[TextIO] = None, stderr: Optional[TextIO] = None
+    ):
+        requires_backends("ROutputCapture", "rpy2")
+        from rpy2.rinterface_lib import callbacks
+
+        # Create StringIO buffers to capture output; created per instance
+        # because mutable default arguments would be shared across instances
+        self.stdout = stdout if stdout is not None else io.StringIO()
+        self.stderr = stderr if stderr is not None else io.StringIO()
+ # Save original R console write functions
+ self._original_consolewrite_print = callbacks.consolewrite_print
+ self._original_consolewrite_warnerror = callbacks.consolewrite_warnerror
+
+ def __enter__(self):
+ from rpy2.rinterface_lib import callbacks
+
+ # Define new console write functions
+ def custom_consolewrite_print(output):
+ self.stdout.write(output)
+
+ def custom_consolewrite_warnerror(output):
+ self.stderr.write(output)
+
+ # Replace R console write functions with custom functions
+ callbacks.consolewrite_print = custom_consolewrite_print
+ callbacks.consolewrite_warnerror = custom_consolewrite_warnerror
+
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ from rpy2.rinterface_lib import callbacks
+
+ # Restore original R console write functions
+ callbacks.consolewrite_print = self._original_consolewrite_print
+ callbacks.consolewrite_warnerror = self._original_consolewrite_warnerror
+
+ # Reset the position of StringIO buffers to the beginning
+ self.stdout.seek(0)
+ self.stderr.seek(0)
+
+ def get_stdout(self):
+ return self.stdout.getvalue()
+
+ def get_stderr(self):
+ return self.stderr.getvalue()
+
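+
+# Example usage of ROutputCapture (a minimal sketch; assumes rpy2 is
+# installed):
+#
+#     import rpy2.robjects as ro
+#     with ROutputCapture() as cap:
+#         ro.r('message("hello from R")')
+#     print(cap.get_stderr())  # R messages/warnings land in the stderr buffer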
+
+def get_linux_distro_codename():
+ """
+ Retrieves the codename of the Linux distribution by parsing /etc/os-release.
+
+ Returns:
+        str: The codename of the distribution (e.g., 'focal', 'bullseye',
+        'centos7'), or None if the codename cannot be determined.
+ """
+ os_release_path = "/etc/os-release"
+
+    if not os.path.isfile(os_release_path):
+        return None
+
+ os_info = {}
+
+ try:
+ with open(os_release_path, "r") as file:
+ for line in file:
+ # Remove any leading/trailing whitespace and skip empty lines
+ line = line.strip()
+ if not line or "=" not in line:
+ continue
+
+ key, value = line.split("=", 1)
+ # Remove surrounding quotes if present
+ value = value.strip('"').strip("'")
+ os_info[key] = value
+    except Exception as e:
+        logger.warning(f"Error reading {os_release_path}: {e}")
+ return None
+
+ # Handle specific distributions
+ distro_id = os_info.get("ID", "").lower()
+
+ if distro_id in ["ubuntu", "debian"]:
+ # For Ubuntu and Debian, extract codename from VERSION or VERSION_ID
+ if "VERSION_CODENAME" in os_info:
+ return os_info["VERSION_CODENAME"].lower()
+ elif "VERSION" in os_info:
+ # Example VERSION: "20.04.5 LTS (Focal Fossa)"
+ match = re.search(r"\(([^)]+)\)", os_info["VERSION"])
+ if match:
+ return match.group(1).lower()
+ elif "VERSION_ID" in os_info:
+ return os_info["VERSION_ID"].lower()
+
+ elif distro_id in ["centos", "rhel", "fedora"]:
+ # For CentOS, RHEL, Fedora, use VERSION_ID or similar
+ if "VERSION_ID" in os_info:
+ return f"{distro_id}{os_info['VERSION_ID'].lower()}"
+ elif "VERSION" in os_info:
+ # Example VERSION: "7 (Core)"
+ match = re.search(r"(\d+)", os_info["VERSION"])
+ if match:
+ return f"{distro_id}{match.group(1)}"
+
+ # Attempt to get VERSION_CODENAME directly
+ if "VERSION_CODENAME" in os_info:
+ return os_info["VERSION_CODENAME"].lower()
+
+ # Fallback: Try to extract codename from PRETTY_NAME or other fields
+ for key in ["PRETTY_NAME", "NAME", "DESCRIPTION"]:
+ if key in os_info:
+ match = re.search(r"\b([a-z]+)\b", os_info[key], re.IGNORECASE)
+ if match:
+ return match.group(1).lower()
+
+ return None
+
+
+def get_cran_info():
+ """
+    Constructs the appropriate CRAN (Posit Package Manager) URL based on the
+    operating system and, for Linux, the distribution codename.
+
+    Returns:
+        str: The constructed CRAN URL, or None if it cannot be determined (the
+        default CRAN mirror is then used).
+ """
+ cran_url = None
+
+ from rpy2.rinterface_lib.embedded import RRuntimeError
+
+ try:
+        if sys.platform in ["win32", "darwin"]:
+            # For Windows and macOS, use the generic 'latest' CRAN URL
+            cran_url = "https://packagemanager.posit.co/cran/latest"
+        elif sys.platform.startswith("linux"):
+            # For Linux, include the distribution codename in the URL
+            dist = get_linux_distro_codename()
+ dist = get_linux_distro_codename()
+ if dist:
+ cran_url = (
+ f"https://packagemanager.posit.co/cran/__linux__/{dist}/latest"
+ )
+ else:
+ logger.warning(
+ "Linux distribution codename could not be determined. Using default CRAN URL."
+ )
+ return None
+ else:
+ logger.warning("Unknown operating system. Using default CRAN URL.")
+ return None
+
+    except RRuntimeError as e:
+        logger.error(f"Error determining CRAN URL: {e}. Using default CRAN mirror.")
+
+ return cran_url
+
+
+def get_bioconductor_info():
+ bioc_mirror = "https://packagemanager.posit.co/bioconductor/latest"
+ bioconductor_config_file = f"{bioc_mirror}/config.yaml"
+ logger.info(f"Using Bioconductor mirror: {bioc_mirror}")
+ logger.info(f"Using Bioconductor config file: {bioconductor_config_file}")
+
+ return bioconductor_config_file, bioc_mirror
+
+
+class RCaller:
+ r_code = None
+ r_source = None
+ _r_context = {}
+ _global_vars = {}
+
+ def _get_converter(self, obj, use_arrow=True) -> "Converter":
+ """Check if rpy2 is available and retrieves the right converter.
+
+ Args:
+ obj (obj): The object (or name) calling this method.
+
+ Raises:
+ RuntimeError: If rpy2 is not available.
+
+ Returns:
+ Converter: The converter to convert R objects to Python objects.
+ """
+ if is_rpy2_arrow_available() and use_arrow:
+ from rpy2.robjects import default_converter, numpy2ri, pandas2ri
+ from rpy2_arrow.arrow import converter
+
+ _converter = (
+ default_converter + numpy2ri.converter + pandas2ri.converter + converter
+ )
+ elif is_rpy2_available():
+ from rpy2.robjects import conversion, numpy2ri, pandas2ri
+
+ _converter = (
+ (
+ conversion.get_conversion()
+ if getattr(conversion, "get_conversion", None)
+ else conversion.converter
+ )
+ + numpy2ri.converter
+ + pandas2ri.converter
+ )
+ else:
+ # suggest installing rpy2_arrow if rpy2 is not available
+ requires_backends(obj, "rpy2_arrow")
+ _converter = None
+
+ return _converter
+
+ @staticmethod
+ def get_r_home():
+ """
+ Retrieve the R_HOME directory across multiple systems, prioritizing local R installations,
+ particularly those within an active Conda environment.
+
+ Returns:
+ str or None: The path to R_HOME if found, otherwise None.
+ """
+ # Step 1: Check R_HOME environment variable
+ r_home_env = os.environ.get("R_HOME")
+ if r_home_env and os.path.exists(r_home_env):
+ return r_home_env
+
+ # Step 2: Check for R in active Conda environment
+ conda_prefix = os.environ.get("CONDA_PREFIX")
+ if conda_prefix:
+ if sys.platform == "win32":
+ r_executable = os.path.join(conda_prefix, "Scripts", "R.exe")
+ else:
+ r_executable = os.path.join(conda_prefix, "bin", "R")
+
+ if os.path.exists(r_executable):
+ try:
+ r_home = subprocess.check_output(
+ [r_executable, "RHOME"], universal_newlines=True
+ ).strip()
+ if os.path.exists(r_home):
+ return r_home
+ except subprocess.CalledProcessError:
+ pass # R executable found but failed to get R_HOME
+
+ # Step 3: Check system PATH for R executable
+ r_executable = shutil.which("R")
+ if r_executable:
+ try:
+ r_home = subprocess.check_output(
+ [r_executable, "RHOME"], universal_newlines=True
+ ).strip()
+ if os.path.exists(r_home):
+ return r_home
+ except subprocess.CalledProcessError:
+ pass # R executable found but failed to get R_HOME
+
+ # Step 4: Check common installation paths (additional step for thoroughness)
+ potential_paths = []
+ if sys.platform == "win32":
+ # Windows common installation paths
+ potential_paths.extend(glob("C:/Program Files/R/R-*/"))
+
+ # Attempt to read R_HOME from Windows Registry
+ try:
+ import winreg
+
+ reg_path = r"SOFTWARE\R-core\R"
+ for root in [winreg.HKEY_LOCAL_MACHINE, winreg.HKEY_CURRENT_USER]:
+ try:
+ key = winreg.OpenKey(root, reg_path)
+ r_home, _ = winreg.QueryValueEx(key, "InstallPath")
+ if os.path.exists(r_home):
+ return r_home
+ except FileNotFoundError:
+ continue
+ except ImportError:
+ pass # winreg not available
+ else:
+ # macOS and Linux common installation paths
+ potential_paths.extend(
+ [
+ "/usr/local/lib/R",
+ "/usr/lib/R",
+ "/Library/Frameworks/R.framework/Resources", # macOS
+ ]
+ )
+
+ for path in potential_paths:
+ if os.path.exists(path):
+ return path
+
+ # R_HOME could not be determined
+ return None
+
+ @staticmethod
+ def verify_r_dependencies(
+ cran_dependencies: List[str] = R_PLOTTING_DEPENDENCIES["cran"]
+ + R_PREPROCESSING_DEPENDENCIES["cran"],
+ bioconductor_dependencies: List[str] = R_PREPROCESSING_DEPENDENCIES[
+ "bioconductor"
+ ]
+ + R_PLOTTING_DEPENDENCIES["bioconductor"],
+ install_missing: Optional[bool] = None,
+ **kwargs,
+ ):
+ if install_missing is None:
+ install_missing = is_auto_install()
+
+ r_home = None
+ try:
+ import rpy2.situation
+
+ r_home = rpy2.situation.get_r_home()
+ except Exception:
+ pass
+ r_home = r_home or RCaller.get_r_home()
+ if r_home is None:
+ raise RuntimeError("R_HOME could not be determined.")
+
+ import rpy2.robjects.packages as rpackages
+ from rpy2.robjects import ListVector
+ from rpy2.robjects.vectors import StrVector
+
+ utils = rpackages.importr("utils")
+ base = rpackages.importr("base")
+ bioc_missing, cran_missing = [], []
+ cran_url = get_cran_info()
+ if cran_url is not None and "repos" not in kwargs:
+ kwargs["repos"] = ListVector([("CRAN", cran_url)])
+ if cran_dependencies:
+ names_to_install = [
+ x for x in cran_dependencies if not rpackages.isinstalled(x)
+ ]
+ if len(names_to_install) > 0 and install_missing:
+ logger.info(f"Using R_HOME: {r_home}")
+ logger.info(f"Installing missing CRAN dependencies: {names_to_install}")
+ if "BiocManager" in names_to_install:
+ names_to_install.remove("BiocManager")
+ utils.chooseCRANmirror(ind=1)
+ utils.install_packages("BiocManager")
+
+ if cran_url is not None:
+ logger.info("Using CRAN mirror: %s", cran_url)
+ base.options(**kwargs)
+ utils.install_packages(StrVector(names_to_install))
+ # verify again to check if all dependencies are installed
+ names_to_install = [
+ x for x in cran_dependencies if not rpackages.isinstalled(x)
+ ]
+ if len(names_to_install) > 0:
+ # conda package names
+ conda_packages = ["r-" + name.lower() for name in names_to_install]
+
+ conda_install_cmd = (
+ f"conda install -y -c conda-forge {' '.join(conda_packages)}"
+ )
+ logger.warning(
+ f"Failed to install the following CRAN dependencies: "
+ f"{names_to_install}. If you are using conda, you can try "
+ "installing the dependencies via:\n"
+ f"{conda_install_cmd}"
+ )
+
+ else:
+ cran_missing = names_to_install
+ if bioconductor_dependencies:
+ names_to_install = [
+ x for x in bioconductor_dependencies if not rpackages.isinstalled(x)
+ ]
+ if len(names_to_install) > 0 and install_missing:
+ logger.info(f"Using R_HOME: {r_home}")
+ logger.info(
+ "Installing missing Bioconductor dependencies: "
+ f"{bioconductor_dependencies}"
+ )
+ if not rpackages.isinstalled("BiocManager"):
+ logger.info("Installing BiocManager")
+ utils.chooseCRANmirror(ind=1)
+ utils.install_packages("BiocManager")
+
+ biocmanager = rpackages.importr("BiocManager")
+ bioconductor_config_file, bioc_mirror = get_bioconductor_info()
+ if cran_url is not None:
+ logger.info("Using CRAN mirror: %s", cran_url)
+ if bioc_mirror is not None:
+ logger.info("Using Bioconductor mirror: %s", bioc_mirror)
+ if bioc_mirror is not None and "BioC_mirror" not in kwargs:
+ kwargs["BioC_mirror"] = bioc_mirror
+ if (
+ bioconductor_config_file is not None
+ and "BIOCONDUCTOR_CONFIG_FILE" not in kwargs
+ ):
+ kwargs["BIOCONDUCTOR_CONFIG_FILE"] = bioconductor_config_file
+ base.options(**kwargs)
+ biocmanager.install(
+ StrVector(names_to_install), ask=False, update=False
+ )
+ # verify again to check if all dependencies are installed
+ names_to_install = [
+ x for x in bioconductor_dependencies if not rpackages.isinstalled(x)
+ ]
+ if len(names_to_install) > 0:
+ conda_packages = [
+ "bioconductor-" + name.lower() for name in names_to_install
+ ]
+ conda_install_cmd = (
+ "conda install -y -c conda-forge -c bioconda "
+ f"{' '.join(conda_packages)}"
+ )
+ msg = (
+ f"Failed to install the following Bioconductor dependencies: "
+ f"{names_to_install}. If you are using conda, you can try "
+ "installing the dependencies via:\n"
+                    f"{conda_install_cmd}"
+ )
+ if sys.platform == "win32":
+ msg += (
+ "\nNote: Many of the Bioconductor packages are not "
+                        "available on Windows. Please submit an issue on the biofit "
+                        "GitHub repository if you need help with installation."
+ )
+ logger.warning(msg)
+ else:
+ bioc_missing = names_to_install
+ if (cran_missing or bioc_missing) and not install_missing:
+ instructions = "Run `biofit.integration.R.RCaller.verify_r_dependencies("
+            if cran_missing:
+                instructions += f"cran_dependencies={cran_missing}, "
+            if bioc_missing:
+                instructions += f"bioconductor_dependencies={bioc_missing}, "
+ instructions += "install_missing=True)` to install missing dependencies."
+ raise PackageNotInstalledError(
+ f"Missing R dependencies: {cran_missing + bioc_missing}. "
+ f"{instructions}"
+ )
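+
+    # Example (a minimal sketch; assumes an R installation is discoverable):
+    #
+    #     RCaller.verify_r_dependencies(
+    #         cran_dependencies=["ggplot2"],
+    #         bioconductor_dependencies=["edgeR"],
+    #         install_missing=True,
+    #     )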
+
+ @classmethod
+ def from_script(cls, r_code_or_path=None):
+ self = cls()
+ self._global_vars["R_SCRIPTS_PATH"] = (
+ f"{config.R_SCRIPTS.resolve().as_posix()}/"
+ )
+ self.r_code = ""
+ if r_code_or_path:
+ if os.path.exists(r_code_or_path):
+ self.r_source = r_code_or_path
+ self._global_vars["CURRENT_DIR"] = (
+ f"{Path(self.r_source).resolve().parent.as_posix()}/"
+ )
+ with open(r_code_or_path, "r") as f:
+ self.r_code += f.read()
+ else:
+ self.r_code += r_code_or_path
+ return self
+
+ def _run_r(
+ self,
+ code_runner,
+ r_code: str,
+ env=None,
+ add_globals=False,
+ quiet=False,
+ ) -> object:
+ """
+ Run R code using the provided code_runner function.
+
+ Args:
+ code_runner (Callable): The function to run the R code.
+ r_code (str): The R code to run.
+ env (SexpEnvironment, optional): The R environment to add global variables
+ to. Defaults to globalenv.
+ add_globals (bool, optional): Whether to add global variables to the R
+ context. Defaults to False.
+ quiet (bool, optional): Whether to suppress output from R. Defaults to
+ False.
+
+ Returns:
+ The output from the R code or the converted output if convert is True.
+ """
+ from rpy2.rinterface_lib.embedded import RRuntimeError
+
+ def run_code(env):
+ if env is None:
+ from rpy2.rinterface import globalenv
+
+ env = globalenv
+ if add_globals:
+                for name, value in self._global_vars.items():
+                    # str.removeprefix/removesuffix are Python 3.9+; strip the
+                    # surrounding quotes instead for 3.8 compatibility
+                    env[name] = value.strip('"') if isinstance(value, str) else value
+ results = code_runner(r_code)
+ last_line = r_code.strip().split("\n")[-1].strip()
+ # for some reason, arrow tables are not returned properly when letting
+ # rpy2 handle the conversion. So we grab it directly from the R context.
+ if last_line in env:
+ results = env[last_line]
+ return results
+
+ if quiet:
+ with ROutputCapture() as output_capture:
+ try:
+ return run_code(env)
+ except RRuntimeError as e:
+ # Attempt to capture the traceback from R.
+ r_stdout = output_capture.get_stdout()
+ r_stderr = output_capture.get_stderr()
+ logger.error(f"captured stdout from R: {r_stdout}")
+ logger.error(f"captured stderr from R: {r_stderr}")
+ try:
+ r_traceback = "\n".join(code_runner("unlist(traceback())"))
+ except Exception as traceback_exc:
+ r_traceback = f"Failed to capture R traceback: {traceback_exc}"
+ raise RuntimeError(f"Error in R code: {e}\n{r_traceback}") from e
+ else:
+ try:
+ return run_code(env)
+ except RRuntimeError as e:
+ try:
+ r_traceback = "\n".join(code_runner("unlist(traceback())"))
+ except Exception as traceback_exc:
+ r_traceback = f"Failed to capture R traceback: {traceback_exc}"
+ raise RuntimeError(f"Error in R code: {e}\n{r_traceback}") from e
+
+ def _convert_output(self, obj, code_runner, converter):
+ # check if the result is a list
+ is_list = converter.rpy2py(
+ self._run_r(code_runner, "is.list", quiet=True)(obj)
+ )[0]
+ if not is_list:
+ obj = [obj]
+
+ names = self._run_r(code_runner, "names")(obj)
+ has_no_names = converter.rpy2py(self._run_r(code_runner, "is.null")(names))[0]
+ if has_no_names:
+ names = [f"{i}" for i in range(len(obj))]
+ out = {}
+
+ for n, result in zip(names, obj):
+ if converter.rpy2py(self._run_r(code_runner, "is.vector")(result))[0]:
+ out[n] = np.array([r for r in result])
+ elif converter.rpy2py(self._run_r(code_runner, "is.list")(result))[0]:
+ out[n] = [r for r in result]
+ elif result.__class__ in converter.rpy2py_nc_map and hasattr(
+ result, "rclass"
+ ):
+ nc_map = converter.rpy2py_nc_map[result.__class__]
+ for rclass in result.rclass:
+ if rclass in nc_map:
+ out[n] = nc_map[rclass](result)
+ break
+ else:
+ out[n] = converter.rpy2py(result)
+ else:
+ out[n] = converter.rpy2py(result)
+
+        if len(out) == 1:
+            # unwrap single-element results instead of relying on the loop
+            # variable
+            out = next(iter(out.values()))
+ return out
+
+ def _cleanup(self, code_runner):
+ code_runner("rm(list=ls())")
+ # garbage collection in Python doesn't always clean up after R automatically
+ code_runner("gc()")
+ gc.collect()
+
+ def _init_vars(self, code_runner, env, converter, **kwargs):
+ import rpy2.robjects as ro
+ from rpy2.rinterface import Sexp
+
+ def convert_to_r(arg):
+ if isinstance(arg, Sexp):
+ return arg
+ elif arg is None:
+ return ro.NULL
+ elif isinstance(arg, (list, tuple)):
+ return converter.py2rpy([convert_to_r(a) for a in arg])
+ elif isinstance(arg, dict):
+ return ro.ListVector(arg)
+ else:
+ return converter.py2rpy(arg)
+
+ for n, arg in kwargs.items():
+ env[n] = convert_to_r(arg)
+
+ def get_method(self, name, enter_code="", exit_code="", convert=True, quiet=True):
+ _converter = self._get_converter(name)
+ import rpy2.rinterface
+ import rpy2.robjects as ro
+
+ with rpy2.rinterface.local_context() as r_context:
+ self._run_r(ro.r, self.r_code, add_globals=True, quiet=quiet)
+ func_args = list(_converter.rpy2py(r_context[name]).formals().names)
+ self._cleanup(ro.r)
+
+ def func(*args, context_kwargs: dict = None, **kwargs):
+ if args:
+ for n, arg in zip(func_args[: len(args)], args):
+ kwargs[n] = arg
+
+ run_code = self.r_code + "\n" + enter_code + "\n"
+
+ arg_str = ", ".join([f"{k} = {k}" for k in kwargs.keys() if k in func_args])
+
+ # additional variables not in the function signature
+ if context_kwargs:
+ kwargs.update(context_kwargs)
+ kwargs.update(self._global_vars)
+
+ run_code += f"results <- {name}({arg_str})"
+
+ run_code += "\n" + exit_code + "\nresults"
+
+ out = self.run(run_code, convert=convert, quiet=quiet, **kwargs)
+
+ return out
+
+ parameters = [
+ inspect.Parameter(name=arg, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ for arg in func_args
+ if arg != "..."
+ ]
+ parameters += (
+ [inspect.Parameter(name="kwargs", kind=inspect.Parameter.VAR_KEYWORD)]
+ if "..." in func_args
+ else []
+ )
+ func.__signature__ = inspect.Signature(parameters=parameters)
+ return func
+
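+    # Example for `get_method` (a minimal sketch; assumes rpy2 is installed):
+    #
+    #     caller = RCaller.from_script("add <- function(a, b) a + b")
+    #     add = caller.get_method("add")
+    #     three = add(1, 2)
+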
+ def run(self, r_code=None, convert=True, quiet=False, **kwargs):
+ """A method to run the R code with the provided arguments.
+
+ Args:
+ r_code (str, optional): The R code to run. Defaults to the provided R code
+ given in the `from_script` method.
+ convert (bool, optional): Whether to convert the R output to Python objects.
+ Defaults to True.
+ **kwargs: The variables to pass to the R code
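+
+        Example (a minimal sketch; assumes rpy2 is installed)::
+
+            caller = RCaller.from_script("res <- a + b; res")
+            caller.run(a=1, b=2)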
+ """
+ r_code = r_code or self.r_code
+ if r_code:
+ r_code = r_code.strip()
+ _converter = self._get_converter(self.run)
+
+ import rpy2.rinterface
+ import rpy2.robjects as ro
+
+ with rpy2.rinterface.local_context() as r_context:
+ self._init_vars(ro.r, r_context, _converter, **kwargs)
+ out = self._run_r(ro.r, r_code, env=r_context, quiet=quiet)
+ if out is None:
+ logger.warning("No output from R code")
+ if convert:
+ if out is None:
+ raise RuntimeError(
+ "Conversion requested but no output from R code. Either "
+ "provide an output by adding the variable name on the last "
+ "line of the R code or set `convert` to `False`."
+ )
+ out = self._convert_output(out, ro.r, _converter)
+ self._cleanup(ro.r)
+ return out
+ else:
+ raise RuntimeError(
+ "R code not provided, please provide R code or the path to the R script"
+ " via `from_script` method."
+ )
diff --git a/src/biofit/integration/R/scripts/analysis_utils.R b/src/biofit/integration/R/scripts/analysis_utils.R
new file mode 100644
index 0000000..4a2e837
--- /dev/null
+++ b/src/biofit/integration/R/scripts/analysis_utils.R
@@ -0,0 +1,245 @@
+# cumNorm <- function(x, p = cumNormStatFast(x)) {
+# normFactors <- calcNormFactors(x, p = p)
+# x$normFactors <- normFactors
+# return(x)
+# }
+#
+# calcNormFactors <- function(x, p = cumNormStatFast(x)) {
+# xx <- x
+# xx[x == 0] <- NA
+# qs <- rowQuantiles(xx, probs = p, na.rm = TRUE)
+# norm_factors <- apply(xx, 1, function(row, qs) {
+# row <- (row - .Machine$double.eps)
+# sum(row[row <= qs])
+# }, qs)
+# names(norm_factors) <- rownames(x)
+# as.data.frame(norm_factors)
+# }
+#
+# cumNormMat <- function(x, p = cumNormStatFast(x), sl = 1000) {
+# xx <- x
+# xx[x == 0] <- NA
+#
+# qs <- rowQuantiles(xx, probs = p, na.rm = TRUE)
+#
+# newMat <- apply(xx, 1, function(row, qs) {
+# row <- (row - .Machine$double.eps)
+# sum(row[row <= qs])
+# }, qs)
+# nmat <- sweep(x, 1, newMat / sl, "/")
+# return(nmat)
+# }
+#
+# cumNormStat <- function(obj, qFlag = TRUE, pFlag = FALSE, rel = .1, ...) {
+# mat <- returnAppropriateObj(obj, FALSE, FALSE)
+# if (any(rowSums(mat) == 0)) stop("Warning: empty feature")
+#
+# smat <- apply(mat, 1, function(row) sort(row, decreasing = FALSE))
+# ref <- colMeans(smat)
+#
+# yy <- mat
+# yy[yy == 0] = NA
+#
+# refS <- sort(ref)
+#
+# k <- which(refS > 0)[1]
+# lo <- (length(refS) - k + 1)
+#
+# if (qFlag) {
+# diffr <- apply(yy, 1, function(row) {
+# refS[k:length(refS)] - quantile(row, probs = seq(0, 1, length.out = lo), na.rm = TRUE)
+# })
+# } else {
+# diffr <- apply(yy, 1, function(row) {
+# refS[k:length(refS)] - approx(sort(row, decreasing = FALSE), n = lo)$y
+# })
+# }
+#
+# diffr2 <- matrixStats::colMedians(abs(diffr), na.rm = TRUE)
+#
+# x <- which(abs(diff(diffr2)) / diffr2[-1] > rel)[1] / length(diffr2)
+# if (x <= 0.50) {
+# message("Default value being used.")
+# x <- 0.50
+# }
+#
+# return(x)
+# }
+#
+# cumNormStatFast <- function(mat, pFlag = FALSE, rel = .1, ...) {
+#
+# smat <- apply(mat, 1, function(row) {
+# sort(row[which(row > 0)], decreasing = TRUE)
+# })
+#
+# leng <- max(sapply(smat, length))
+# if (any(sapply(smat, length) == 1)) stop("Warning: sample with one or zero features")
+#
+# smat2 <- array(NA, dim = c(leng, nrow(mat)))
+# for (i in 1:nrow(mat)) {
+# smat2[leng:(leng - length(smat[i]) + 1), i] = smat[i]
+# }
+#
+# rmat2 <- apply(smat2, 1, function(row) {
+# quantile(row, probs = seq(0, 1, length.out = ncol(smat2)), na.rm = TRUE)
+# })
+#
+# smat2[is.na(smat2)] = 0
+# ref1 <- colMeans(smat2)
+#
+# ncols <- ncol(rmat2)
+# diffr <- apply(rmat2, 1, function(row) {
+# ref1 - row
+# })
+# diffr1 <- matrixStats::colMedians(abs(diffr))
+#
+# x <- which(abs(diff(diffr1)) / diffr1[-1] > rel)[1] / length(diffr1)
+# if (x <= 0.50) {
+# message("Default value being used.")
+# x <- 0.50
+# }
+#
+# return(x)
+# }
+
+
+###############
+## setup the MRexperiment object
+## phenodat: sample data.frame with rownames being the
+## sampleID (crucial), first column also the sampleID here
+## featdat: feature data.frame with rownames being the featID
+## (crucial), first column also the featID
+## cntdat: matrix data.frame with rownames being the featID and
+## colnames as sampleID
+###############
+
+setMRobject <- function(cntdat, phenodat, featdat) {
+ suppressPackageStartupMessages(require("metagenomeSeq"))
+ phenodatDF <- as.data.frame(phenodat)
+ phenotypeData <- AnnotatedDataFrame(phenodatDF)
+
+ featdatDF <- as.data.frame(featdat)
+ OTUdata <- AnnotatedDataFrame(featdatDF)
+
+ cntdatDF <- as.data.frame(cntdat)
+
+ obj <- newMRexperiment(cntdatDF, phenoData = phenotypeData, featureData = OTUdata)
+ return(obj)
+}
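+
+# Example (a minimal sketch; assumes `cnts` is a feature-by-sample count
+# data.frame and `pheno`/`feats` are annotation data.frames keyed by sample
+# and feature IDs, respectively):
+#   obj <- setMRobject(cnts, pheno, feats)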
+
+###############
+### Filter OTU/Taxa presence with a chosen read count
+### 'present': the minimum number of samples a feature must be present in
+### 'fdepth': feature count cutoff for presence; a taxa/OTU has to have
+### at least this number of reads to be deemed present within a sample
+### 'depth': the minimum total number of reads a sample has to have
+###############
+filterMRobject <- function(obj, present = 1, fdepth = 1, depth = 1, norm = FALSE) {
+ mat <- returnAppropriateObj(obj, norm = norm, log = FALSE) > (fdepth - 1)
+ cols <- which(colSums(MRcounts(obj, norm = norm)) >= depth)
+ rows <- which(rowSums(mat[, cols]) >= present)
+
+ # Apply filter
+ obj <- obj[rows, cols]
+ return(obj)
+}
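+
+# Example: keep features with at least 2 reads in at least 3 samples, and
+# samples with at least 1000 total reads:
+#   obj <- filterMRobject(obj, present = 3, fdepth = 2, depth = 1000)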
+
+
+###############
+# Normalize the data with percentile param
+# When percentile is NULL, normalization factor will be calculated
+###############
+normMRobject <- function(obj, percentile = NULL) {
+ # Calculating normalization factor
+ if (is.null(percentile)) {
+ percentile <- cumNormStat(obj)
+
+ }
+ # Apply normalization
+ obj <- cumNorm(obj, p = signif(percentile, 4))
+ return(obj)
+}
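+
+# Example: cumulative-sum normalize, estimating the percentile automatically:
+#   obj <- normMRobject(obj)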
+
+###############
+# plot histograms to get a sense of the counts and presence
+###############
+plotHistMeta <- function(x, xlab="log2(Sum of counts)", main="Histogram", breaks=50) {
+ hist(x, xlab = xlab, main = main, breaks = breaks)
+}
+
+##################
+# Boxplot of distributions: Before and after normalization in log2 form
+##################
+boxplotMeta <- function(obj, keyAnnot, cols = NULL) {
+ #color setup
+ if (is.null(cols)) {
+ suppressPackageStartupMessages(require("RColorBrewer"))
+ cols <- c(brewer.pal(8, "Accent"), rev(brewer.pal(8, "Dark2")[-8]), brewer.pal(8, "Paired"))
+ }
+
+ cl <- factor(pData(obj)[, keyAnnot])
+ clcol <- cols[as.integer(cl)]
+
+ #plot
+ par(mfrow = c(2, 1))
+ boxplot(log2(1 + MRcounts(obj, norm = F, log = F)), col = clcol, outcol = clcol, ylab = "log2(1+Abundance)")
+ boxplot(log2(1 + MRcounts(obj, norm = T, log = F)), col = clcol, outcol = clcol, ylab = "log2(1+Abundance)")
+}
+
+
+
+##################
+# PCoA - Bray-Curtis
+##################
+fitPCoA <- function(obj, method = "bray", norm = T) {
+ suppressPackageStartupMessages(library("vegan"))
+ suppressPackageStartupMessages(library("ape"))
+
+ #### distance computation and dimension reduction
+ d <- vegdist(t(MRcounts(obj, norm = norm, log = F)), method = method)
+ pcodecomp <- pcoa(d)
+
+ # if any eigenvalue is negative, leverage Cailliez correction
+ if (sum(pcodecomp$values$Relative_eig < 0) > 0) {
+ pcodecomp <- pcoa(d, correction = "cailliez")
+ }
+ return(pcodecomp)
+}
+
+##################
+# Plot PCoA - Bray-Curtis
+##################
+plotPCoA <- function(obj, pcodecomp, keyAnnot, keyAnnot2 = NULL, dimn = 2, fileNameAdd = "", cols = NULL) {
+  suppressPackageStartupMessages(require(ape))
+
+  # color setup
+  if (is.null(cols)) {
+    suppressPackageStartupMessages(require("RColorBrewer"))
+    cols <- c(brewer.pal(8, "Accent"), rev(brewer.pal(8, "Dark2")[-8]), brewer.pal(8, "Paired"))
+  }
+  cl <- cl2 <- factor(pData(obj)[, keyAnnot])
+  clcol <- cols[as.integer(cl)]
+
+  # if a secondary annotation was specified, pch is used to mark it
+  if (!is.null(keyAnnot2)) {
+    cl2 <- factor(pData(obj)[, keyAnnot2])
+  }
+  pch2use <- c(1:length(levels(cl2)))
+  pchInput <- pch2use[cl2]
+
+  # Compute the percentage of explained variance; 'Rel_corr_eig' only exists
+  # when a correction was applied in fitPCoA, so fall back to 'Relative_eig'
+  PCOAaxes <- pcodecomp$vectors[, c(1:dimn)]
+  eignPERC <- pcodecomp$values$Rel_corr_eig
+  if (is.null(eignPERC)) {
+    eignPERC <- pcodecomp$values$Relative_eig
+  }
+  eignPERC <- eignPERC[c(1:dimn)]
+
+  # plot PCoA
+  pairs(PCOAaxes,
+    main = paste("PCoA", fileNameAdd), col = clcol, pch = pchInput,
+    cex = 1.1, cex.labels = 2, cex.axis = 1.5, upper.panel = NULL,
+    labels = paste("Dim", 1:dimn, "\n", round(eignPERC, 3) * 100, "%")
+  )
+  if (is.null(keyAnnot2)) {
+    legend("topright", legend = levels(cl), col = cols, pch = pch2use, ncol = 3, cex = 1)
+  } else {
+    legend("top", legend = levels(cl), col = cols, pch = 1, ncol = 3, cex = 1, inset = c(0.1, 0.1))
+    legend("topright", legend = levels(cl2), pch = pch2use, cex = 1, inset = c(0.1, -0.1))
+  }
+
+}
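+
+# Example (a minimal sketch; assumes `obj` is a normalized MRexperiment with a
+# "Group" column in its phenotype data):
+#   pcodecomp <- fitPCoA(obj)
+#   plotPCoA(obj, pcodecomp, keyAnnot = "Group", dimn = 3)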
diff --git a/src/biofit/integration/R/scripts/dependencies.R b/src/biofit/integration/R/scripts/dependencies.R
new file mode 100644
index 0000000..12735c5
--- /dev/null
+++ b/src/biofit/integration/R/scripts/dependencies.R
@@ -0,0 +1,10 @@
+source(paste0(R_SCRIPTS_PATH, "/utils.R"))
+
+install_dependencies <- function() {
+ dependencies <- c(
+ "ggplot2", "arrow", "circlize", "RColorBrewer", "scales",
+ "forcats", "patchwork", "reshape2", "ComplexHeatmap",
+ "edgeR", "dplyr", "tools"
+ )
+ ensure_packages(dependencies)
+}
diff --git a/src/biofit/integration/R/scripts/plotting_utils.R b/src/biofit/integration/R/scripts/plotting_utils.R
new file mode 100644
index 0000000..4349f94
--- /dev/null
+++ b/src/biofit/integration/R/scripts/plotting_utils.R
@@ -0,0 +1,1447 @@
+source(file.path(R_SCRIPTS_PATH, "utils.R"))
+
+sci_format <- function(x) {
+ suppressPackageStartupMessages(require(scales))
+ ifelse((abs(x) < 1 & abs(x) > 0 & floor(log10(abs(x))) <= -2) | abs(x) >= 10000,
+ formatC(x, format = "e", digits = 2), formatC(x, format = "f", digits = 2)
+ )
+}
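+
+# Example: sci_format(c(0.004, 3.14159, 123456)) gives
+# c("4.00e-03", "3.14", "1.23e+05")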
+
+#' `prepare_data_for_hist()` calculates row sums and column sums for 1 or 2 data frames.
+#'
+#' Note: If you leave x2 blank, it will be NULL and not calculate.
+#'
+#' @param x1 data.frame: 1st or only data frame to have sums be calculated
+#' @param x2 data.frame: 2nd data frame to have sums be calculated. Default is NULL.
+#'
+#' @returns list: a list of all the sums; 1st two values in the list are for x1, the next two for x2.
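+#'
+#' @examples
+#' \dontrun{
+#' sums <- prepare_data_for_hist(mtcars, mtcars * 2)
+#' }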
+#'
+prepare_data_for_hist <- function(x1, x2 = NULL) {
+ # Calculate row sum and col sums of dataset 1
+ data1_sample <- rowSums(x1)
+ data1_feature <- colSums(x1)
+
+ list_sums <- list(data1_sample, data1_feature)
+
+ # Calculates row and col sums for second dataset if it is given Allows for
+ # situations with only a single dataset
+ if (!is.null(x2)) {
+ data2_sample <- rowSums(x2)
+ data2_feature <- colSums(x2)
+
+ list_sums <- append(list_sums, list(data2_sample, data2_feature))
+ }
+
+ return(list_sums)
+}
+
+#' `non_zero_sums()` counts the non-zero values in each row and each column of 1 or 2 data frames.
+#'
+#' Note: If you leave x2 blank, it will be NULL and not calculate.
+#'
+#' @param x1 (data.frame) 1st or only data frame to have non-zero sums be calculated.
+#' @param x2 (data.frame) 2nd data frame to have non-zero sums be calculated. Default is NULL.
+#'
+#' @returns (list) of all the counts; the 1st two values are for x1, the next two for x2 if used.
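+#'
+#' @examples
+#' \dontrun{
+#' nz <- non_zero_sums(mtcars) # list: per-row and per-column non-zero counts
+#' }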
+#'
+non_zero_sums <- function(x1, x2 = NULL) {
+ # Sum all non-zero values for a dataset
+ x1_non_zero_row <- rowSums(x1 != 0)
+ x1_non_zero_col <- colSums(x1 != 0)
+ sums <- list(x1_non_zero_row, x1_non_zero_col)
+
+ # Allows for situations with only a single dataset
+ if (!is.null(x2)) {
+ x2_non_zero_row <- rowSums(x2 != 0)
+ x2_non_zero_col <- colSums(x2 != 0)
+ sums <- append(sums, list(x2_non_zero_row, x2_non_zero_col))
+ }
+
+ # Returns list of non_zero rowSums
+ return(sums)
+}
+
+#' `prepare_axis_label()` adjusts axis label names to include log transformation.
+#'
+#' @param label chr: the axis label to be adjusted.
+#' @param log_type chr: the log transformation being applied.
+#'
+#' @returns chr: the new label with the transformation added.
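+#'
+#' @examples
+#' \dontrun{
+#' prepare_axis_label("Counts", "log10_1p") # "Counts (log10(x+1))"
+#' }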
+#'
+prepare_axis_label <- function(label, log_type) {
+ if (grepl("1p", log_type)) {
+ if (grepl("_1p", log_type)) {
+ label_log <- gsub("_1p", "", log_type)
+ label <- paste0(label, " (", label_log, "(x+1))")
+ } else {
+ label <- paste0(label, " (ln(x+1))")
+ }
+ } else if (log_type == "log") {
+ label <- paste0(label, " (ln)")
+ } else {
+ label <- paste0(label, " (", log_type, ")")
+ }
+ return(label)
+}
+
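+#' `log_transformation()` applies the transformation named by `log_type` to `x`.
+#' Accepts 'log', 'log2', 'log10' and their x+1 variants ('log1p', 'log2_1p',
+#' 'log10_1p'); any other value returns `x` unchanged.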
+log_transformation <- function(x, log_type) {
+ if (grepl("1p", log_type)) {
+ if (grepl("_1p", log_type)) {
+ label_log <- gsub("_1p", "", log_type)
+ if (label_log == "log10") {
+ return(log10(1 + x))
+ } else if (label_log == "log2") {
+ return(log2(1 + x))
+ }
+ } else {
+ return(log(1 + x))
+ }
+ } else if (log_type == "log") {
+ return(log(x))
+ } else if (log_type == "log2") {
+ return(log2(x))
+ } else if (log_type == "log10") {
+ return(log10(x))
+ }
+ return(x)
+}
+
+#'
+#' Transformation functions for the custom axis transformations.
+#' They are looked up by scale_*_continuous() via the name given as the first
+#' argument of trans_new(); they never need to be called directly.
+#'
+
+#' `log2_1p()` is a transformation function for log base 2 with x + 1 as an input
+log2_1p <- function(x) {
+ log2(1 + x)
+}
+
+#' `log2_1p_trans()` is a transformation function for log base 2 with x + 1 as an input
+log2_1p_trans <- function() {
+ suppressPackageStartupMessages(require(scales))
+
+ trans_new("log2_1p",
+ transform = log2_1p, inverse = function(x) {
+ 2^x - 1
+ }, breaks = trans_breaks(log2_1p, function(x) 2^x - 1),
+ domain = c(0, Inf)
+ )
+}
+
+#' `log10_1p()` is a transformation function for log base 10 with x + 1 as an input
+log10_1p <- function(x) {
+ log10(1 + x)
+}
+
+#' `log10_1p_trans()` is a transformation function for log base 10 with x + 1 as an input
+log10_1p_trans <- function() {
+ suppressPackageStartupMessages(require(scales))
+
+ trans_new("log10_1p",
+ transform = log10_1p, inverse = function(x) {
+ 10^x - 1
+ }, breaks = trans_breaks(log10_1p, function(x) 10^x - 1),
+ domain = c(0, Inf)
+ )
+}
+
+#'
+#' `color_select()` prepares a vector of colours to use in a ggplot
+#'
+#' Note: Requires RColorBrewer and circlize packages
+#'
+#' @details It will use RColorBrewer's sets of colours as the base for the vector
+#' If the number of colours needed are <= the number of colors in the set,
+#' then the colours are used directly from the set
+#' Else, the function will generate colors in between using a colour ramp
+#'
+#' @param levels (int) number of colours that are to be returned
+#' @param col_set (chr) the RColorBrewer set that is to be used, and only those sets
+#' ('Set1', 'Set2', 'Set3', 'Pastel2', 'Pastel1', 'Paired', 'Dark2', 'Accent')
+#'
+#' @returns (vec) a vector of colours
+#'
+#' @examples
+#' \dontrun{
+#' colors <- color_select(5, "Set3")
+#' }
+color_select <- function(levels, col_set = "Set1") {
+ col_set <- match.arg(col_set, c(
+ "Set1", "Set2", "Set3", "Pastel2", "Pastel1",
+ "Paired", "Dark2", "Accent"
+ ))
+
+ if (col_set %in% c("Set1", "Pastel1")) {
+ num_col <- 9
+ } else if (col_set %in% c("Set3", "Paired")) {
+ num_col <- 12
+ } else {
+ num_col <- 8
+ }
+
+ if (levels > num_col) {
+ at <- seq(0, levels, length.out = num_col)
+ color_fn <- circlize::colorRamp2(at, RColorBrewer::brewer.pal(num_col, col_set))
+ colors <- color_fn(1:levels)
+ } else {
+ color_fn <- RColorBrewer::brewer.pal(num_col, col_set)
+ colors <- color_fn[1:levels]
+ }
+ return(colors)
+}
+
+
+#' Generate a simple histogram for one variable.
+#'
+#' `generate_histogram()` creates a simple histogram for a single numerical variable.
+#'
+#' Note: requires packages ggplot2 and rlang.
+#'
+#' @param data (data.frame or vector) a data frame that the function is obtaining the data from or a vector of data
+#' If you pass a vector, it will be made into a data.frame with the label as the name of the column.
+#' @param column (chr) the column with the data that is to be used in the histogram.
+#'   Default is NULL; the data is then melted and its 'value' column is used.
+#' @param xlab (chr) name of the x-axis. Default is 'X'.
+#' @param ylab (chr) name of the y-axis. Default is 'Frequency'.
+#' @param title (chr) title of the figure. Default is 'Histogram'.
+#' @param bins (int) number of bins for the histogram. Default is 30.
+#' @param font_size (dbl) size of the font for the plot. Default is 8.
+#' @param alpha (dbl) Opacity of bars. Default is 0.6.
+#' @param col_fill (chr) primary colour of the bars. Default is 'grey40'.
+#' @param col_outline (chr) colour of the outline of the bars. Default is 'black'.
+#' @param xlog (chr) logarithmic transformation of the x-axis.
+#' Options include: (default) NULL (none), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'
+#' @param ylog (chr) logarithmic transformation of the y-axis. Default is NULL (none).
+#' Options include: (default) NULL (none), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'
+#'
+#' @returns the ggplot build of the final histogram.
+#'
+#' @examples
+#' \dontrun{
+#' generate_histogram(mtcars, wt)
+#'
+#' generate_histogram(mtcars, wt,
+#' xlab = "Weights", ylab = "Frequency",
+#' title = "Weights of Cars", bins = 30,
+#' col_fill = "red", col_outline = "black", xlog = "log2",
+#' ylog = NULL
+#' )
+#' }
+generate_histogram <- function(
+ data, column = NULL, xlab = "X", ylab = "Frequency",
+ title = "Histogram", bins = 30, font_size = 8, alpha = 0.6,
+ col_fill = "grey40", col_outline = "black",
+ xlog = NULL, ylog = NULL) {
+ data <- convert_to_dataframe(data)
+
+ column <- get_default_columns(data, column)
+ if (is.null(column)) {
+ data <- reshape2::melt(data)
+ column <- "value"
+ }
+ data <- validate_data(data, column)
+
+ if (!is.null(xlog)) {
+ xlog <- match.arg(xlog, c(
+ "log2", "log10", "log", "log2_1p", "log10_1p",
+ "log1p"
+ ))
+ xlab <- prepare_axis_label(xlab, xlog)
+ }
+ if (!is.null(ylog)) {
+ ylog <- match.arg(ylog, c(
+ "log2", "log10", "log", "log2_1p", "log10_1p",
+ "log1p"
+ ))
+ ylab <- prepare_axis_label(ylab, ylog)
+ }
+
+ suppressPackageStartupMessages(require(ggplot2))
+
+ ggplot(data, aes(x = .data[[column]])) +
+ geom_histogram(
+ bins = bins, fill = col_fill,
+ color = col_outline, alpha = alpha
+ ) +
+ labs(x = xlab, y = ylab, title = title) +
+ theme_bw() +
+ theme(text = element_text(size = font_size), plot.title = element_text(hjust = 0.5)) +
+ {
+ if (!is.null(xlog)) {
+ scale_x_continuous(trans = xlog)
+ }
+ } +
+ {
+ if (!is.null(ylog)) {
+ scale_y_continuous(trans = ylog)
+ }
+ }
+}
+
+#' Generate a histogram to compare two sets of values
+#'
+#' `generate_comparison_histogram()` creates a histogram that compares values from two groups: typically before and after some transformation or filtering step, but it can be any two categories.
+#'
+#' Note: requires packages ggplot2, RColorBrewer.
+#'
+#' @param data1 (vector) data of the 1st group.
+#' @param data2 (vector) data of the 2nd group.
+#' @param column1 (chr) name of column containing first set of data
+#' @param column2 (chr) name of column containing second set of data
+#' @param xlab (chr) name of the x-axis.
+#' @param ylab (chr) name of the y-axis. Default is 'Count'.
+#' @param title (chr) title of the figure. Default is 'Comparison Histogram'
+#' @param bins (int) number of bins for the histogram. Default is 30.
+#' @param alpha (double) opacity of the histograms. Value between 0 and 1. Default is 0.6.
+#' @param legend_title (chr) name of the legend. Default is 'Feature Selection'.
+#' @param legend_position (chr) placement of the legend in the figure.
+#' Options include: (default) 'top', 'bottom', 'left', 'right'.
+#' @param subplot_title1 (chr) name of the values from the 1st dataset. Default is 'Before'.
+#' @param subplot_title2 (chr) name of the values from the 2nd dataset. Default is 'After'.
+#' @param col_set (chr) name of RColorBrewer set to be used for colouring.
+#' Options include: (default) 'Set1', 'Set2', 'Set3', 'Pastel2', 'Pastel1', 'Paired', 'Dark2', 'Accent'
+#' @param cols (chr (vec)) vector of colors that you wish to use. Default to NULL and other colours are produced using the set in col_set.
+#' @param col_outline (chr) colour of the border for the bars. Default is 'black'.
+#' @param xlog (chr) logarithmic transformation of the x-axis.
+#' Options include: (default) NULL (none), 'log', 'log2', 'log10'.
+#' @param ylog (chr) logarithmic transformation of the y-axis. Default is NULL (none).
+#' Options include: (default) NULL (none), 'log', 'log2', 'log10'.
+#'
+#' @returns the ggplot build of the final histogram.
+#'
+#' @examples
+#' \dontrun{
+#' control <- PlantGrowth$weight[PlantGrowth$group == "ctrl"]
+#' treat1 <- PlantGrowth$weight[PlantGrowth$group == "trt1"]
+#' generate_comparison_histogram(control, treat1,
+#' xlab = "Weights",
+#' title = "Control VS. Treatment: Plant Weights"
+#' )
+#' }
+generate_comparison_histogram <- function(
+ data1, data2, column1 = NULL, column2 = NULL,
+ xlab = NULL, ylab = "Count", title = "Comparison Histogram", bins = 30, alpha = 0.6,
+ legend_title = "Feature Selection", legend_position = "top", subplot_title1 = "Before",
+ subplot_title2 = "After", col_set = "Set1", cols = NULL, col_outline = "black", xlog = NULL,
+ ylog = NULL, ...) {
+ suppressPackageStartupMessages(require(ggplot2))
+
+ data1 <- convert_to_dataframe(data1)
+
+ data2 <- convert_to_dataframe(data2)
+
+ column1 <- get_default_columns(data1, column1)
+ if (is.null(column1)) {
+ data1 <- reshape2::melt(data1)
+ column1 <- "value"
+ }
+ data1 <- validate_data(data1, column1)
+
+ column2 <- get_default_columns(data2, column2)
+ if (is.null(column2)) {
+ data2 <- reshape2::melt(data2)
+ column2 <- "value"
+ }
+  data2 <- validate_data(data2, column2)
+
+ if (is.null(xlab)) {
+ if (is.null(column1)) {
+ xlab <- "Values"
+ } else {
+ xlab <- column1
+ }
+ }
+
+ if (!is.null(xlog)) {
+ xlog <- match.arg(xlog, c(
+ "log2", "log10", "log", "log2_1p", "log10_1p",
+ "log1p"
+ ))
+ xlab <- prepare_axis_label(xlab, xlog)
+ }
+ if (!is.null(ylog)) {
+ ylog <- match.arg(ylog, c(
+ "log2", "log10", "log", "log2_1p", "log10_1p",
+ "log1p"
+ ))
+ ylab <- prepare_axis_label(ylab, ylog)
+ }
+
+ data_frame <- rbind(data.frame(value = data1, category = subplot_title1), data.frame(
+ value = data2,
+ category = subplot_title2
+ ))
+ data_frame$category <- factor(data_frame$category, levels = c(subplot_title1, subplot_title2))
+
+ # Default Colours
+ if (is.null(cols)) {
+ cols <- color_select(2, col_set = col_set)
+ names(cols) <- c(subplot_title1, subplot_title2)
+ }
+
+ ggplot(data_frame, aes(x = .data[[column1]], fill = category)) +
+ geom_histogram(
+ bins = bins,
+ alpha = alpha, position = "identity", color = col_outline
+ ) +
+ scale_fill_manual(values = c(
+ cols[1],
+ cols[2]
+ ), name = legend_title) +
+ labs(x = xlab, title = title) +
+ theme_bw() +
+ theme(
+ legend.position = legend_position, text = element_text(size = 8),
+ axis.title = element_text(size = 6), legend.text = element_text(size = 6),
+ legend.key.size = unit(1, "line"), legend.title = element_text(size = 7)
+ ) +
+ {
+ if (!is.null(xlog)) {
+ scale_x_continuous(trans = xlog)
+ }
+ } +
+ {
+ if (!is.null(ylog)) {
+ scale_y_continuous(trans = ylog)
+ }
+ }
+}
+
+#' Generate a simple density plot for one variable.
+#'
+#' `generate_density()` creates a simple density plot for a single numerical variable.
+#'
+#' Note: requires packages ggplot2.
+#'
+#' @param data (data.frame) a data frame that the function is obtaining the data from.
+#' @param column (chr) the column with the data that is to be used in the density plot.
+#' @param xlab (chr) name of the x-axis. Default is 'X'.
+#' @param ylab (chr) name of the y-axis. Default is 'Density'.
+#' @param title (chr) title of the figure. Default is 'Density Plot'.
+#' @param col_fill (chr) primary colour of the bars. Default is 'grey40'.
+#' @param col_outline (chr) colour of the outline of the bars. Default is 'black'.
+#' @param adjust (dbl) adjust bandwidth, which determines how precise (lower values) or smooth (higher values) the density plot is. Default is 1.
+#' @param alpha (dbl) opacity of the density plot. Default is 0.8.
+#' @param xlog (chr) logarithmic transformation of the x-axis.
+#' Options include: (default) NULL (none), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'
+#'
+#' @returns the ggplot build of the final density plot.
+#'
+#' @examples
+#' \dontrun{
+#' generate_density(mtcars, wt)
+#'
+#' generate_density(mtcars, wt,
+#' xlab = "Weights", ylab = "Density",
+#' title = "Weights of Cars",
+#' col_fill = "red", col_outline = "black",
+#' adjust = 0.5, alpha = 0.8, xlog = "log2"
+#' )
+#' }
+generate_density <- function(
+ data, column = NULL, xlab = "X", ylab = "Density",
+ title = "Density Plot", col_fill = "grey40", col_outline = "black", adjust = 1,
+ alpha = 0.6, xlog = NULL) {
+ suppressPackageStartupMessages(require(ggplot2))
+
+ data <- convert_to_dataframe(data)
+
+ column <- get_default_columns(data, column)
+ if (is.null(column)) {
+ data <- reshape2::melt(data)
+ column <- "value"
+ }
+ data <- validate_data(data, column)
+
+ if (!is.null(xlog)) {
+ xlog <- match.arg(xlog, c(
+ "log2", "log10", "log", "log2_1p", "log10_1p",
+ "log1p"
+ ))
+ xlab <- prepare_axis_label(xlab, xlog)
+ }
+
+
+ ggplot(data, aes(x = .data[[column]])) +
+ geom_density(
+ fill = col_fill, color = col_outline,
+ adjust = adjust, alpha = alpha
+ ) +
+ labs(x = xlab, y = ylab, title = title) +
+ theme_bw() +
+ theme(plot.title = element_text(hjust = 0.5)) +
+ {
+ if (!is.null(xlog)) {
+ scale_x_continuous(trans = xlog)
+ }
+ }
+}
+
+#' Generate a density plot to compare two sets of values.
+#'
+#' `generate_comparison_density()` creates a density plot that compares values from two groups: typically before and after some transformation or filtering step, but it can be any two categories.
+#'
+#' Note: requires packages ggplot2, RColorBrewer.
+#'
+#' @param data1 (vector) data of the 1st group.
+#' @param data2 (vector) data of the 2nd group.
+#' @param column1 (chr) name of column containing first set of data
+#' @param column2 (chr) name of column containing second set of data
+#' @param xlab (chr) name of the x-axis.
+#' @param ylab (chr) name of the y-axis. Default is 'Count'.
+#' @param title (chr) title of the figure. Default is 'Comparison Density Plot'
+#' @param legend_title (chr) name of the legend. Default is 'Feature Selection'.
+#' @param legend_position (chr) placement of the legend in the figure.
+#' Options include: (default) 'top', 'bottom', 'left', 'right'.
+#' @param subplot_title1 (chr) name of the values from the 1st dataset. Default is 'Before'.
+#' @param subplot_title2 (chr) name of the values from the 2nd dataset. Default is 'After'.
+#' @param col_set (chr) name of RColorBrewer set to be used for colouring.
+#' Options include: (default) 'Set1', 'Set2', 'Set3', 'Pastel2', 'Pastel1', 'Paired', 'Dark2', 'Accent'
+#' @param cols (chr (vec)) vector of colors that you wish to use. Default to NULL and other colours are produced using the set in col_set.
+#' @param col_outline (chr) colour of the border of the density shape. Default is 'black'.
+#' @param adjust (dbl) adjust bandwidth, which determines how precise (lower values) or smooth (higher values) the density plot is. Default is 1.
+#' @param alpha (dbl) opacity of the histograms. Value between 0 and 1. Default is 0.6.
+#' @param xlog (chr) logarithmic transformation of the x-axis.
+#' Options include: (default) NULL (none), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'.
+#'
+#' @returns the ggplot build of the final density plot.
+#'
+#' @examples
+#' \dontrun{
+#' control <- PlantGrowth$weight[PlantGrowth$group == "ctrl"]
+#' treat1 <- PlantGrowth$weight[PlantGrowth$group == "trt1"]
+#' generate_comparison_density(control, treat1,
+#' xlab = "Weights",
+#' title = "Control VS. Treatment: Plant Weights"
+#' )
+#' }
+generate_comparison_density <- function(
+ data1, data2, column1 = NULL, column2 = NULL,
+ xlab = NULL, ylab = "Count", title = "Comparison Density Plot", legend_title = "Feature Selection",
+ legend_position = "top", subplot_title1 = "Before", subplot_title2 = "After", col_set = "Set1",
+ cols = NULL, col_outline = "black", adjust = 1, alpha = 0.6, xlog = NULL) {
+ suppressPackageStartupMessages(require(ggplot2))
+ data1 <- convert_to_dataframe(data1)
+
+ data2 <- convert_to_dataframe(data2)
+
+ column1 <- get_default_columns(data1, column1)
+ if (is.null(column1)) {
+ data1 <- reshape2::melt(data1)
+ column1 <- "value"
+ }
+ data1 <- validate_data(data1, column1)
+
+ column2 <- get_default_columns(data2, column2)
+ if (is.null(column2)) {
+ data2 <- reshape2::melt(data2)
+ column2 <- "value"
+ }
+  data2 <- validate_data(data2, column2)
+
+ if (is.null(xlab)) {
+ if (is.null(column1)) {
+ xlab <- "Values"
+ } else {
+ xlab <- column1
+ }
+ }
+
+ if (!is.null(xlog)) {
+ xlog <- match.arg(xlog, c(
+ "log2", "log10", "log", "log2_1p", "log10_1p",
+ "log1p"
+ ))
+ xlab <- prepare_axis_label(xlab, xlog)
+ }
+
+ data_frame <- rbind(data.frame(value = data1, category = subplot_title1), data.frame(
+ value = data2,
+ category = subplot_title2
+ ))
+ data_frame$category <- factor(data_frame$category, levels = c(subplot_title1, subplot_title2))
+
+ if (is.null(cols)) {
+ cols <- color_select(2, col_set = col_set)
+ names(cols) <- c(subplot_title1, subplot_title2)
+ }
+
+  ggplot(data_frame, aes(x = .data[[column1]], fill = category)) +
+ geom_density(
+ adjust = adjust,
+ alpha = alpha, color = col_outline
+ ) +
+ scale_fill_manual(values = c(
+ cols[1],
+ cols[2]
+ ), name = legend_title) +
+ labs(x = xlab, title = title) +
+ theme_bw() +
+ theme(
+ legend.position = legend_position, text = element_text(size = 8),
+ axis.title = element_text(size = 6), legend.text = element_text(size = 6),
+ legend.key.size = unit(1, "line"), legend.title = element_text(size = 7)
+ ) +
+ {
+ if (!is.null(xlog)) {
+ scale_x_continuous(trans = xlog)
+ }
+ }
+}
+
+#'
+#' `generate_barplot()` creates a bar plot using ggplot2
+#'
+#' Note: Requires ggplot2 package
+#'
+#' @details Settings available for both normal and stacked (& proportional stacked) bar plots.
+#'   Adding a groupby variable makes a stacked bar plot, and setting prop = TRUE makes it proportional.
+#'
+#' @param data (data.frame/vector) data frame containing the data of interest, or a vector with levels (categorical variable)
+#' @param y (data.frame or vector) contains the identity values/heights of the bars (ex. counts of the bars pre-calculated)
+#' @param group (data.frame or vector) secondary categorical variable used to group/stack the bars. Default is NULL.
+#' @param label_name (chr) the name of the categorical variable to be used for the bars. Default is 'labels'.
+#' @param value_name (chr) the name of the column containing values for y (pre-calculated counts). Default is 'values'.
+#' @param groupby (chr) the name of the secondary categorical variable to group the bars. Default is 'group'.
+#' @param xlab (chr) x-axis label. Default is NULL (label_name is used).
+#' @param ylab (chr) y-axis label. Default is NULL (value_name is used).
+#' @param title (chr) plot name/title. Default is 'Bar Plot'
+#' @param col_set (chr) name of RColorBrewer set to be used for colouring.
+#' Options include: (default) 'Set1', 'Set2', 'Set3', 'Pastel2', 'Pastel1', 'Paired', 'Dark2', 'Accent'
+#' @param cols (chr (vec)) vector of colors that you wish to use. Defaults to NULL, in which case colours are produced from the set in col_set.
+#' @param col_outline (chr) colour of the outline of the bars. Default is 'grey30'.
+#' @param col_labels (chr) colour of the labels on the bars. Default is 'black'.
+#' @param alpha (dbl) Opacity. Default is 0.6.
+#' @param prop (logical) make a stacked barplot a proportional barplot. Default is FALSE.
+#' @param add_count_lab (logical) add count labels to bars. Default is TRUE.
+#' @param vars_as_entered (logical) leave the variable order as entered, not flipping the values so the variable with more levels makes the bars. Default is FALSE.
+#' @param legend_position (chr) position of the legend.
+#' Options Include: (default) 'top', 'bottom', 'left', 'right'
+#' @param font_size (dbl) size of the font. Default is 3.25.
+#'
+#' @returns The ggplot object of the completed plot
+#'
+#' @examples
+#' \dontrun{
+#' generate_barplot(mtcars, "gear")
+#'
+#' generate_barplot(mtcars, "gear",
+#' groupby = "cyl", xlab = "gear", ylab = "Count",
+#' title = "Gear Vs. Carb", col_set = "Set2",
+#' prop = T, add_count_lab = T, vars_as_entered = F,
+#' legend_position = "top", font_size = 4
+#' )
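+#'
+#' # A hypothetical run with pre-computed bar heights passed via `y`
+#' # (assumes the helper functions defined earlier in this file):
+#' generate_barplot(data.frame(labels = c("A", "B", "C")),
+#'   y = data.frame(values = c(10, 5, 2)),
+#'   xlab = "Group", ylab = "Total"
+#' )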
+#' }
+generate_barplot <- function(
+ data, y = NULL, group = NULL, label_name = "labels",
+ value_name = "values", groupby = "group", xlab = NULL, ylab = NULL, title = "Bar Plot",
+ col_set = "Set1", cols = NULL, col_outline = "grey30", col_labels = "black", alpha = 0.6,
+ prop = F, add_count_lab = T, vars_as_entered = F, legend_position = "top", font_size = 3.25) {
+ suppressPackageStartupMessages(require(ggplot2))
+
+ # Angled text threshold
+ VERT_LEVELS_MIN <- 9 # minimum number of levels
+ VERT_WORD_MIN <- 7 # minimum length of largest label name
+
+ data <- convert_to_dataframe(data)
+ label_name <- get_default_columns(data, label_name)
+ data <- validate_data(data, label_name)
+
+ if (!is.null(y)) {
+ y <- convert_to_dataframe(y)
+ value_name <- get_default_columns(y, value_name)
+ # check if value_name is in y or add it if possible, otherwise error
+ y <- validate_data(y, value_name)
+ data <- concatenate_datasets(data, y, how = "horizontal")
+ }
+
+ if (!is.null(group)) {
+ group <- convert_to_dataframe(group)
+ groupby <- get_default_columns(group, groupby)
+ group <- validate_data(group, groupby)
+ data <- concatenate_datasets(data, group, how = "horizontal")
+ }
+
+ if (is.null(xlab)) {
+ xlab <- label_name
+ }
+
+ # Check that the variable is a character
+ if (!is.character(data[[label_name]])) {
+ data[[label_name]] <- as.character(data[[label_name]])
+ }
+ cat_lvls <- length(unique(data[[label_name]]))
+
+ if (!is.null(value_name) && value_name %in% colnames(data)) {
+ data <- data[, c(label_name, value_name)]
+ if (is.null(ylab)) {
+ ylab <- value_name
+ }
+ sorted_inds <- order(data[[value_name]], decreasing = TRUE)
+ data <- data[sorted_inds, ]
+ max_length <- max(nchar(unique(na.omit(data[[label_name]]))))
+ data[[label_name]] <- factor(data[[label_name]], levels = data[[label_name]])
+ levels <- 1
+ if (is.null(cols)) {
+ cols <- color_select(levels, col_set)
+ }
+ the_plot <- ggplot(data, aes(x = .data[[label_name]], y = .data[[value_name]])) +
+ geom_bar(stat = "identity", color = col_outline, fill = cols, alpha = alpha) +
+ theme_bw() +
+ labs(x = xlab, y = ylab, title = title) +
+ theme(
+ legend.position = legend_position,
+ plot.title = element_text(hjust = 0.5)
+ ) +
+ {
+ if (cat_lvls >= VERT_LEVELS_MIN || max_length >= VERT_WORD_MIN) {
+ theme(axis.text.x = element_text(angle = 60, hjust = 1))
+ }
+ } +
+ {
+ if (add_count_lab) {
+ geom_text(aes(label = sci_format(.data[[value_name]])),
+ stat = "identity",
+ position = position_stack(vjust = 0.5), color = col_labels, size = font_size
+ )
+ }
+ }
+ } else {
+ if (is.null(ylab)) {
+ ylab <- "Count"
+ }
+ # If a groupby value has been passed, the function sets up for a stacked
+ # barplot
+ if (!is.null(groupby) && groupby %in% colnames(data)) {
+ stacked <- TRUE
+
+ # Check that the variable is a character
+ if (!is.character(data[[groupby]])) {
+ data[[groupby]] <- as.character(data[[groupby]])
+ }
+
+ # settings for plot type and labels Proportional or Stack
+ if (prop) {
+ position_set <- "fill"
+ position_func <- position_fill(vjust = 0.5)
+
+ # Won't change the name if a custom label is given
+ if (ylab == "Count") {
+ ylab <- "Proportion"
+ }
+ } else {
+ position_set <- "stack"
+ position_func <- position_stack(vjust = 0.5)
+ }
+
+ # Calculate Levels
+ groupby_lvls <- length(unique(data[[groupby]]))
+ max_length <- max(nchar(unique(na.omit(data[[label_name]]))))
+
+ # Function will make the variable with more levels the bars of the plot
+ # Unless it is specified to leave it as is
+ if (!vars_as_entered) {
+      # if groupby has more levels, swap the values
+ if (groupby_lvls > cat_lvls) {
+ temp <- label_name
+ label_name <- groupby
+ groupby <- temp
+ xlab <- label_name
+
+ levels <- cat_lvls
+ cat_lvls <- groupby_lvls
+ max_length <- max(nchar(unique(na.omit(data[[label_name]]))))
+ } else {
+ levels <- groupby_lvls
+ }
+ } else {
+ levels <- groupby_lvls
+ }
+ } else {
+ stacked <- FALSE
+ cat_lvls <- length(unique(data[[label_name]]))
+ max_length <- max(nchar(unique(na.omit(data[[label_name]]))))
+ levels <- 1
+ }
+
+ if (is.null(cols)) {
+ cols <- color_select(levels, col_set)
+ }
+
+ if (stacked) {
+ the_plot <- ggplot(data, aes(
+ x = forcats::fct_infreq(.data[[label_name]]),
+ fill = forcats::fct_rev(forcats::fct_infreq(.data[[groupby]]))
+ )) +
+ geom_bar(position = position_set, stat = "count", color = col_outline, alpha = alpha) +
+ theme_bw() +
+ labs(
+ x = xlab, y = ylab, color = groupby, fill = groupby,
+ title = title
+ ) +
+ theme(legend.position = legend_position, plot.title = element_text(hjust = 0.5)) +
+ scale_fill_manual(values = cols) +
+ {
+ if (cat_lvls >= VERT_LEVELS_MIN || max_length >= VERT_WORD_MIN) {
+ theme(axis.text.x = element_text(angle = 60, hjust = 1))
+ }
+ } +
+ {
+ if (add_count_lab) {
+ geom_text(aes(label = after_stat(count)),
+ stat = "count", position = position_func,
+ color = col_labels, size = font_size
+ )
+ }
+ }
+ } else {
+ the_plot <- ggplot(data, aes(x = forcats::fct_infreq(.data[[label_name]]))) +
+ geom_bar(stat = "count", color = col_outline, alpha = alpha) +
+ theme_bw() +
+ labs(x = xlab, y = ylab, title = title) +
+ theme(
+ legend.position = legend_position,
+ plot.title = element_text(hjust = 0.5)
+ ) +
+ {
+ if (cat_lvls >= VERT_LEVELS_MIN || max_length >= VERT_WORD_MIN) {
+ theme(axis.text.x = element_text(angle = 60, hjust = 1))
+ }
+ } +
+ {
+ if (add_count_lab) {
+ geom_text(aes(label = after_stat(count)),
+ stat = "count", position = position_stack(vjust = 0.5),
+ color = col_labels, size = font_size
+ )
+ }
+ }
+ }
+ }
+ return(the_plot)
+}
+
+#'
+#' `generate_boxplot()` creates a box plot using ggplot2
+#'
+#' Note: Requires ggplot2
+#'
+#' @param data (data.frame or vector) data frame containing the variable or numerical data
+#' @param labels (data.frame or vector) vector (or data frame) containing the categorical data
+#' @param column (chr) name of numerical variable column.
+#' @param label_name (chr) name of categorical variable.
+#' @param xlab (chr) label for categorical axis.
+#' @param ylab (chr) label for numerical axis.
+#' Note: You don't need to change the x and y labels around if horizontal_plot is TRUE, the function will change them automatically.
+#' @param title (chr) title of the plot. Default is 'Boxplot'
+#' @param legend_position (chr) position of the legend.
+#' Options Include: (default) 'top', 'bottom', 'left', 'right'
+#' @param horizontal_plot (logical) if you want the boxplots to be running horizontally, set to TRUE. Default is FALSE.
+#' @param order (logical) if you want the boxplots ordered by median. Default is FALSE.
+#' @param col_set (chr) name of RColorBrewer set to be used for colouring.
+#' Options Include: (default) 'Set1', 'Set2', 'Set3', 'Pastel2', 'Pastel1', 'Paired', 'Dark2', 'Accent'
+#' @param cols (chr (vec)) vector of colors that you wish to use. Defaults to NULL, in which case colours are produced from the set in col_set.
+#' @param alpha (dbl) Opacity. Default is 0.6.
+#' @param log_num (chr) logarithmic transformation of the numerical axis (y).
+#' Options include: (default) NULL (none), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'.
+#'
+#' @returns The ggplot object of the completed plot
+#'
+#' @examples
+#' \dontrun{
+#' generate_boxplot(iris, column = "Petal.Length", label_name = "Species")
+#'
+#' generate_boxplot(iris,
+#' column = "Petal.Length", label_name = "Species",
+#' ylab = "Petal Length", xlab = "Species", horizontal_plot = T
+#' )
+#' }
+generate_boxplot <- function(
+ data, labels = NULL, column = NULL, label_name = "labels",
+ xlab = NULL, ylab = NULL, title = "Boxplot", legend_position = "top", horizontal_plot = F,
+ order = F, col_set = "Set1", cols = NULL, alpha = 0.6, log_num = NULL) {
+ suppressPackageStartupMessages(require(ggplot2))
+
+ # Angled text threshold
+ VERT_LEVELS_MIN <- 9 # minimum number of levels
+ VERT_WORD_MIN <- 7 # minimum length of largest label name
+
+ data <- convert_to_dataframe(data)
+
+ if (!is.null(labels)) {
+ labels <- convert_to_dataframe(labels)
+ label_name <- get_default_columns(labels, label_name)
+ labels <- validate_data(labels, label_name)
+ }
+
+ if (is.null(column) && ncol(data) > 2 && !is.null(label_name)) {
+ suppressPackageStartupMessages(require(reshape2))
+
+ if (!is.null(labels)) {
+ data <- concatenate_datasets(data, labels, how = "horizontal")
+ }
+ data <- reshape2::melt(data, id.vars = label_name)
+ column <- names(data)[3]
+ } else {
+ column <- get_default_columns(data, column)
+ data <- validate_data(data, column)
+ if (!is.null(labels)) {
+ data <- concatenate_datasets(data, labels, how = "horizontal")
+ }
+ }
+
+ if (is.null(xlab)) {
+ xlab <- label_name
+ }
+
+ if (is.null(ylab)) {
+ ylab <- column
+ }
+
+ if (!is.null(log_num)) {
+ log_num <- match.arg(log_num, c(
+ "log2", "log10", "log", "log2_1p", "log10_1p",
+ "log1p"
+ ))
+ ylab <- prepare_axis_label(ylab, log_num)
+ }
+
+ if (!is.null(label_name) && label_name %in% colnames(data)) {
+ # Make sure label_name is categorical
+ if (!is.character(data[[label_name]])) {
+ data[[label_name]] <- as.character(data[[label_name]])
+ }
+
+ levels <- length(unique(data[[label_name]]))
+ max_length <- max(nchar(unique(na.omit(data[[label_name]]))))
+
+ if (is.null(cols)) {
+ cols <- color_select(levels, col_set)
+ }
+
+ if (order) {
+ the_plot <- ggplot(data, aes(x = forcats::fct_reorder(.data[[label_name]],
+ .data[[column]],
+ .fun = median
+ ), y = .data[[column]], fill = .data[[label_name]])) +
+ geom_boxplot(alpha = alpha) +
+ theme_bw() +
+ scale_fill_manual(values = cols) +
+ labs(x = xlab, y = ylab, title = title) +
+ theme(
+ legend.position = legend_position,
+ plot.title = element_text(hjust = 0.5)
+ ) +
+ {
+ if (!is.null(log_num)) {
+ scale_y_continuous(trans = log_num)
+ }
+ } +
+ {
+ if (horizontal_plot) {
+ coord_flip()
+ }
+ } +
+ {
+ if ((levels >= VERT_LEVELS_MIN || max_length >= VERT_WORD_MIN)) {
+ theme(axis.text.x = element_text(angle = 60, hjust = 1))
+ }
+ }
+ } else {
+ the_plot <- ggplot(data, aes(
+ x = .data[[label_name]], y = .data[[column]],
+ fill = .data[[label_name]]
+ )) +
+ geom_boxplot(alpha = alpha) +
+ theme_bw() +
+ scale_fill_manual(values = cols) +
+ labs(x = xlab, y = ylab, title = title) +
+ theme(
+ legend.position = legend_position,
+ plot.title = element_text(hjust = 0.5)
+ ) +
+ {
+ if (!is.null(log_num)) {
+ scale_y_continuous(trans = log_num)
+ }
+ } +
+ {
+ if (horizontal_plot) {
+ coord_flip()
+ }
+ } +
+ {
+ if ((levels >= VERT_LEVELS_MIN || max_length >= VERT_WORD_MIN)) {
+ theme(axis.text.x = element_text(angle = 60, hjust = 1))
+ }
+ }
+ }
+ } else {
+ the_plot <- ggplot(data, aes(x = 1, y = .data[[column]])) +
+ geom_boxplot(alpha = alpha) +
+ theme_bw() +
+ labs(x = xlab, y = ylab, title = title) +
+ theme(
+ legend.position = legend_position,
+ plot.title = element_text(hjust = 0.5)
+ ) +
+ {
+ if (!is.null(log_num)) {
+ scale_y_continuous(trans = log_num)
+ }
+ } +
+ {
+ if (horizontal_plot) {
+ coord_flip()
+ }
+ }
+ }
+ return(the_plot)
+}
+
+#'
+#' `generate_violin()` creates a violin plot using ggplot2
+#'
+#' Note: Requires ggplot2
+#'
+#' @param data (data.frame or vector) data frame containing the variable or numerical data
+#' @param labels (data.frame or vector) vector (or data frame) containing the categorical data
+#' @param column (chr) name of numerical variable column.
+#' @param label_name (chr) name of categorical variable.
+#' @param xlab (chr) label for categorical axis.
+#' @param ylab (chr) label for numerical axis.
+#' Note: You don't need to change the x and y labels around if horizontal_plot is TRUE, the function will change them automatically.
+#' @param title (chr) title of the plot. Default is 'Violin Plot'
+#' @param legend_position (chr) position of the legend.
+#' Options Include: (default) 'top', 'bottom', 'left', 'right'
+#' @param add_box (logical) Add boxplots on top of violin plots. Default is TRUE
+#' @param horizontal_plot (logical) if you want the boxplots to be running horizontally, set to TRUE. Default is FALSE.
+#' @param order (logical) if you want the violin plots ordered by median. Default is FALSE.
+#' @param col_set (chr) name of RColorBrewer set to be used for colouring.
+#' Options Include: (default) 'Set1', 'Set2', 'Set3', 'Pastel2', 'Pastel1', 'Paired', 'Dark2', 'Accent'
+#' @param cols (chr (vec)) vector of colors that you wish to use. Defaults to NULL, in which case colours are produced from the set in col_set.
+#' @param alpha (dbl) Opacity. Default is 0.6.
+#' @param log_num (chr) logarithmic transformation of the numerical axis (y).
+#' Options include: (default) NULL (none), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'.
+#' @param ylim (numeric (vec)) length-two vector giving the lower and upper limits of the numerical axis. Default is NULL (no limits).
+#' @param show_outliers (logical) show outlier points on the overlaid boxplots. Default is TRUE.
+#'
+#' @returns The ggplot object of the completed plot
+#'
+#' @examples
+#' \dontrun{
+#' generate_violin(iris, column = "Petal.Length", label_name = "Species")
+#'
+#' generate_violin(iris,
+#' column = "Petal.Length", label_name = "Species",
+#' ylab = "Petal Length", xlab = "Species", horizontal_plot = T
+#' )
+#' }
+generate_violin <- function(
+ data, labels = NULL, column = NULL, label_name = "labels",
+ ylab = NULL, xlab = NULL, ylim = NULL, title = "Violin Plot",
+ legend_position = "top", add_box = TRUE, horizontal_plot = FALSE,
+ order = FALSE, col_set = "Set1", cols = NULL, alpha = 0.6,
+ log_num = NULL, show_outliers = TRUE) {
+ suppressPackageStartupMessages(require(ggplot2))
+
+ # Angled text threshold
+ VERT_LEVELS_MIN <- 9 # minimum number of levels
+ VERT_WORD_MIN <- 7 # minimum length of largest label name
+
+ data <- convert_to_dataframe(data)
+
+ if (!is.null(labels)) {
+ labels <- convert_to_dataframe(labels)
+ label_name <- get_default_columns(labels, label_name)
+ labels <- validate_data(labels, label_name)
+ }
+
+ if (is.null(column) && ncol(data) > 2 && !is.null(label_name)) {
+ suppressPackageStartupMessages(require(reshape2))
+ # melt the data
+ if (!is.null(labels)) {
+ data <- concatenate_datasets(data, labels, how = "horizontal")
+ }
+ data <- reshape2::melt(data, id.vars = label_name)
+ column <- names(data)[3]
+ } else {
+ column <- get_default_columns(data, column)
+ data <- validate_data(data, column)
+ if (!is.null(labels)) {
+ data <- concatenate_datasets(data, labels, how = "horizontal")
+ }
+ }
+
+
+ if (is.null(xlab) && !is.null(label_name) && label_name %in% colnames(data)) {
+ xlab <- label_name
+ }
+
+ if (is.null(ylab)) {
+ ylab <- column
+ }
+
+ if (!is.null(log_num)) {
+ log_num <- match.arg(log_num, c(
+ "log2", "log10", "log", "log2_1p", "log10_1p",
+ "log1p"
+ ))
+ # ylab <- prepare_axis_label(ylab, log_num)
+ data[[column]] <- log_transformation(data[[column]], log_num)
+ }
+
+ if (!is.null(label_name) && label_name %in% colnames(data)) {
+ # Make sure label_column is categorical
+ if (!is.character(data[[label_name]])) {
+ data[[label_name]] <- as.character(data[[label_name]])
+ }
+
+ levels <- length(unique(data[[label_name]]))
+ max_length <- max(nchar(unique(na.omit(data[[label_name]]))))
+
+ if (is.null(cols)) {
+ cols <- color_select(levels, col_set)
+ }
+
+ if (order) {
+ the_plot <- ggplot(data, aes(x = forcats::fct_reorder(.data[[label_name]],
+ .data[[column]],
+ .fun = median
+ ), y = .data[[column]], fill = .data[[label_name]])) +
+ geom_violin(alpha = alpha) +
+ {
+ if (add_box) {
+ geom_boxplot(
+ width = 0.2, color = "black", outlier.shape = 1,
+ fill = NA, outlier.fill = "black", alpha = alpha,
+ outliers = show_outliers
+ )
+ }
+ } +
+ theme_bw() +
+ scale_fill_manual(values = cols) +
+ labs(
+ x = xlab, y = ylab,
+ title = title
+ ) +
+ theme(legend.position = legend_position, plot.title = element_text(hjust = 0.5)) +
+ # {
+ # if (!is.null(log_num)) {
+ # scale_y_continuous(trans = log_num)
+ # }
+ # } +
+ {
+ if (!is.null(ylim)) {
+ ylim(ylim[[1]], ylim[[2]])
+ }
+ } +
+ {
+ if (horizontal_plot) {
+ coord_flip()
+ }
+ } +
+ {
+ if ((levels >= VERT_LEVELS_MIN || max_length >= VERT_WORD_MIN)) {
+ theme(axis.text.x = element_text(angle = 60, hjust = 1))
+ }
+ }
+ } else {
+ the_plot <- ggplot(data, aes(
+ x = .data[[label_name]], y = .data[[column]],
+ fill = .data[[label_name]]
+ )) +
+ geom_violin(alpha = alpha) +
+ {
+ if (add_box) {
+ geom_boxplot(
+ width = 0.7, color = "black", outlier.shape = 1, outlier.size = 0.5,
+ fill = NA, outlier.fill = "black", alpha = alpha, outliers = show_outliers
+ )
+ }
+ } +
+ theme_bw() +
+ scale_fill_manual(values = cols) +
+ labs(
+ x = xlab, y = ylab,
+ title = title
+ ) +
+ theme(legend.position = legend_position, plot.title = element_text(hjust = 0.5)) +
+ # {
+ # if (!is.null(log_num)) {
+ # scale_y_continuous(trans = log_num)
+ # }
+ # } +
+ {
+ if (!is.null(ylim)) {
+ ylim(ylim[[1]], ylim[[2]])
+ }
+ } +
+ {
+ if (horizontal_plot) {
+ coord_flip()
+ }
+ } +
+ {
+ if ((levels >= VERT_LEVELS_MIN || max_length >= VERT_WORD_MIN)) {
+ theme(axis.text.x = element_text(angle = 60, hjust = 1))
+ }
+ }
+ }
+ } else {
+ # generate a single violin plot
+ the_plot <- ggplot(data, aes(x = 1, y = .data[[column]])) +
+ geom_violin(fill = "grey40", alpha = alpha) +
+ {
+ if (add_box) {
+ geom_boxplot(
+ width = 0.2, color = "black", outlier.shape = 1,
+ fill = NA, outlier.fill = "black", alpha = alpha
+ )
+ }
+ } +
+ theme_bw() +
+ labs(x = xlab, y = ylab, title = title) +
+ theme(
+ legend.position = legend_position,
+ plot.title = element_text(hjust = 0.5)
+ ) +
+ {
+ if (!is.null(log_num)) {
+ scale_y_continuous(trans = log_num)
+ }
+ } +
+ {
+ if (horizontal_plot) {
+ coord_flip()
+ }
+ }
+ }
+ return(the_plot)
+}
+
+#'
+#' `generate_scatterplot()` creates a scatterplot using ggplot2
+#'
+#' Note: Requires ggplot2 package
+#'
+#' @param data (data.frame or vector) data frame containing the variables or data for the x axis.
+#' @param y (data.frame or vector) data for the y axis.
+#' @param group (data.frame or vector) label data for grouping the x and y by colour.
+#' @param xdata (chr) name of column with data for the x-axis.
+#' @param ydata (chr) name of column with data for the y-axis.
+#' @param groupby (chr) name of categorical variable to group the points. Default is NULL.
+#' @param xlab (chr) x-axis label. Default is 'Var 1'.
+#' @param ylab (chr) y-axis label. Default is 'Var 2'.
+#' @param title (chr) title for the plot. Default is 'Scatterplot'.
+#' @param alpha (dbl) opacity of points. Range: 0 to 1. Default is 1.
+#' @param col_set (chr) name of RColorBrewer set to be used for colouring.
+#' Options Include: (default) 'Set1', 'Set2', 'Set3', 'Pastel2', 'Pastel1', 'Paired', 'Dark2', 'Accent'
+#' @param cols (chr (vec)) vector of colors that you wish to use. Defaults to NULL, in which case colours are produced from the set in col_set.
+#' @param xlog (chr) log transformation of the x-axis.
+#' Options Include: (default) NULL (None), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'.
+#' @param ylog (chr) log transformation of the y-axis.
+#' Options Include: (default) NULL (None), 'log', 'log2', 'log10', 'log1p', 'log2_1p', 'log10_1p'.
+#'
+#' @returns The ggplot object of the completed plot
+#'
+#' @examples
+#' \dontrun{
+#' generate_scatterplot(mtcars, xdata = "wt", ydata = "qsec")
+#'
+#' generate_scatterplot(iris,
+#' xdata = "Petal.Length", ydata = "Petal.Width",
+#' groupby = "Species", xlab = "Petal Length",
+#' ylab = "Petal Width", title = "Petals (Length VS. Width)",
+#' alpha = 0.75, col_set = "Set2"
+#' )
+#' }
+generate_scatterplot <- function(
+ data, y = NULL, group = NULL, xdata = "x", ydata = "y",
+ groupby = "group", xlab = NULL, ylab = NULL, title = "Scatterplot", alpha = 1,
+ col_set = "Set1", cols = NULL, xlog = NULL, ylog = NULL) {
+ suppressPackageStartupMessages(require(ggplot2))
+ data <- convert_to_dataframe(data)
+ xdata <- get_default_columns(data, xdata)
+ data <- validate_data(data, xdata)
+
+ if (!is.null(y)) {
+ y <- convert_to_dataframe(y)
+    ydata <- get_default_columns(y, ydata)
+ y <- validate_data(y, ydata)
+ data <- concatenate_datasets(data, y, how = "horizontal")
+ } else {
+ data <- validate_data(data, ydata)
+ }
+
+ if (!is.null(group)) {
+ group <- convert_to_dataframe(group)
+ groupby <- get_default_columns(group, groupby)
+ group <- validate_data(group, groupby)
+ data <- concatenate_datasets(data, group, how = "horizontal")
+ }
+
+ if (is.null(xlab)) {
+ xlab <- xdata
+ }
+
+ if (is.null(ylab)) {
+ ylab <- ydata
+ }
+
+ if (!is.null(xlog)) {
+ xlog <- match.arg(xlog, c(
+ "log2", "log10", "log", "log2_1p", "log10_1p",
+ "log1p"
+ ))
+ xlab <- prepare_axis_label(xlab, xlog)
+ }
+ if (!is.null(ylog)) {
+ ylog <- match.arg(ylog, c(
+ "log2", "log10", "log", "log2_1p", "log10_1p",
+ "log1p"
+ ))
+ ylab <- prepare_axis_label(ylab, ylog)
+ }
+
+ if (!is.null(groupby) && groupby %in% colnames(data)) {
+ grouped <- TRUE
+
+ if (!is.character(data[[groupby]])) {
+ data[[groupby]] <- as.character(data[[groupby]])
+ }
+
+ if (is.null(cols)) {
+ levels <- length(unique(data[[groupby]]))
+ cols <- color_select(levels, col_set)
+ }
+ } else {
+ grouped <- FALSE
+ }
+
+ if (grouped) {
+ ggplot(data, aes(x = .data[[xdata]], y = .data[[ydata]], col = .data[[groupby]])) +
+ geom_point(alpha = alpha) +
+ theme_bw() +
+ theme(plot.title = element_text(hjust = 0.5)) +
+ scale_color_manual(values = cols) +
+ labs(x = xlab, y = ylab, title = title) +
+ {
+ if (!is.null(xlog)) {
+ scale_x_continuous(trans = xlog)
+ }
+ } +
+ {
+ if (!is.null(ylog)) {
+ scale_y_continuous(trans = ylog)
+ }
+ }
+ } else {
+ ggplot(data, aes(x = .data[[xdata]], y = .data[[ydata]])) +
+ geom_point(alpha = alpha) +
+ theme_bw() +
+ theme(plot.title = element_text(hjust = 0.5)) +
+ labs(
+ x = xlab,
+ y = ylab, title = title
+ ) +
+ {
+ if (!is.null(xlog)) {
+ scale_x_continuous(trans = xlog)
+ }
+ } +
+ {
+ if (!is.null(ylog)) {
+ scale_y_continuous(trans = ylog)
+ }
+ }
+ }
+}
+
+#' calls all the plots for a single categorical variable
+#' @param data (data.frame) data frame containing the data.
+#' @param label_name (chr) name of categorical variable.
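+#'
+#' @examples
+#' \dontrun{
+#' # Illustrative call using the built-in mtcars data:
+#' plot_cat_metadata(mtcars, "cyl")
+#' }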
+plot_cat_metadata <- function(data, label_name, title = NULL, ...) {
+ if (is.null(title)) {
+ title <- paste0("Distribution of ", label_name)
+ }
+ p <- generate_barplot(data,
+ label_name = label_name, xlab = label_name, title = title,
+ ...
+ )
+ return(p)
+}
+
+#' calls all the plots for a single numerical variable
+#' @param data (data.frame) data frame containing the data.
+#' @param num_var (chr) name of numerical variable.
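+#'
+#' @examples
+#' \dontrun{
+#' # Illustrative call using the built-in iris data
+#' # (assumes the generate_histogram() helper defined earlier in this file):
+#' plot_num_metadata(iris, "Sepal.Length")
+#' }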
+plot_num_metadata <- function(data, num_var, title = NULL, ...) {
+ if (is.null(title)) {
+ title <- paste0("Distribution of ", num_var)
+ }
+ p <- generate_histogram(data, num_var, xlab = num_var, title = title, ...)
+ return(p)
+}
+
+#' Categorical Variable vs Categorical Variable
+#' @param data (data.frame) data frame containing the data.
+#' @param var1 (chr) name of 1st categorical variable.
+#' @param var2 (chr) name of 2nd categorical variable.
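+#'
+#' @examples
+#' \dontrun{
+#' # Illustrative call using the built-in mtcars data:
+#' plot_cat_vs_cat_metadata(mtcars, "cyl", "gear")
+#' }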
+plot_cat_vs_cat_metadata <- function(data, var1, var2, title = NULL, ...) {
+ suppressPackageStartupMessages(require(patchwork))
+ # Categorical vs Categorical
+ if (is.null(title)) {
+ title <- paste0(var1, " VS. ", var2)
+ }
+ # Stacked Barplot
+ p1 <- generate_barplot(data,
+ label_name = var1, groupby = var2, xlab = var1,
+ title = NULL, add_count_lab = F, ...
+ )
+ # Proportional Barplot
+ p2 <- generate_barplot(data,
+ label_name = var1, xlab = var1, groupby = var2,
+ prop = T, title = NULL, ...
+ )
+ # Grid
+ p <- (p1 + p2 + plot_layout(guides = "collect") + plot_annotation(
+ title = title,
+ theme = theme(plot.title = element_text(hjust = 0.5))
+ ) & theme(legend.position = "top"))
+ return(p)
+}
+
+#' Categorical Variable vs Numerical Variable
+#' @param data (data.frame) data frame containing the data.
+#' @param cat (chr) name of categorical variable.
+#' @param num (chr) name of numerical variable.
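+#'
+#' @examples
+#' \dontrun{
+#' # Illustrative call using the built-in iris data:
+#' plot_cat_vs_num_metadata(iris, "Species", "Petal.Length")
+#' }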
+plot_cat_vs_num_metadata <- function(data, cat, num, title = NULL, ...) {
+ suppressPackageStartupMessages(require(patchwork))
+ # Categorical vs Numeric
+ if (is.null(title)) {
+ title <- paste0(cat, " VS. ", num)
+ }
+ # Violin Plot
+ p <- generate_violin(data, column = num, label_name = cat, title = title, ...)
+ return(p)
+}
+
+#' Numerical Variable vs Numerical Variable
+#' @param data (data.frame) data frame containing the data.
+#' @param var1 (chr) name of 1st numerical variable.
+#' @param var2 (chr) name of 2nd numerical variable.
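+#'
+#' @examples
+#' \dontrun{
+#' # Illustrative call using the built-in mtcars data:
+#' plot_num_vs_num_metadata(mtcars, "wt", "mpg")
+#' }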
+plot_num_vs_num_metadata <- function(data, var1, var2, title = NULL, ...) {
+ # Numeric vs Numeric
+ if (is.null(title)) {
+ title <- paste0(var1, " VS. ", var2)
+ }
+ # Scatterplot
+ p <- generate_scatterplot(data,
+ xdata = var1, ydata = var2, xlab = var1, ylab = var2,
+ title = title, ...
+ )
+ return(p)
+}
diff --git a/src/biofit/integration/R/scripts/utils.R b/src/biofit/integration/R/scripts/utils.R
new file mode 100644
index 0000000..98225e1
--- /dev/null
+++ b/src/biofit/integration/R/scripts/utils.R
@@ -0,0 +1,246 @@
+BIOCONDUCTOR <- c(
+ "limma", "metagenomeSeq", "phyloseq",
+ "DESeq2", "biomaRt", "GenomicRanges",
+ "Biostrings", "BSgenome", "GenomicFeatures",
+ "GenomicAlignments", "VariantAnnotation",
+ "S4Vectors", "IRanges", "AnnotationDbi",
+ "ComplexHeatmap"
+)
+
+save_plots <- function(
+ filename, plot = last_plot(), device = NULL, path = NULL,
+ scale = 1, width = NA, height = NA, units = c(
+ "in", "cm",
+ "mm", "px"
+ ), dpi = 300, limitsize = TRUE, bg = NULL,
+ create.dir = TRUE, ...) {
+  suppressPackageStartupMessages(require(ggplot2))
+
+
+ ext <- NULL
+ if (!is.null(filename)) {
+ ext <- tolower(tools::file_ext(filename))
+ } else if (!is.null(path)) {
+ ext <- tolower(tools::file_ext(path))
+ }
+
+  # always save a png copy when the requested format is not png
+  if (!is.null(ext) && ext != "png") {
+ fn <- paste0(tools::file_path_sans_ext(filename), ".png")
+
+ ggplot2::ggsave(
+ fn,
+ plot = plot, device = "png",
+ path = path, scale = scale, width = width,
+ height = height, units = units, dpi = dpi,
+ limitsize = limitsize, bg = bg, create.dir = create.dir, ...
+ )
+ }
+ ggplot2::ggsave(
+ filename,
+ plot = plot, device = device,
+ path = path, scale = scale, width = width,
+ height = height, units = units, dpi = dpi,
+ limitsize = limitsize, bg = bg, create.dir = create.dir, ...
+ )
+}
+
+get_info_from_arrow_table <- function(arrow_table) {
+ # Get the schema
+ info <- arrow_table$schema$metadata$huggingface
+ if (is.null(info)) {
+ return(NULL)
+ }
+ jsonlite::fromJSON(info)$info$features
+}
+
+#' @title encodings_to_factor
+#' @description Convert label encodings to factors.
+#' Uses the arrow metadata to get the label information.
+#' @param y The arrow table
+#' @return if y is not an arrow table or y is not a class label,
+#' it returns y as is.
+encodings_to_factor <- function(y) {
+ # verify that y is an arrow table
+ if (!inherits(y, "Table")) {
+ return(y)
+ }
+ label_info <- get_info_from_arrow_table(y)$features$labels
+  if (!is_class_label(label_info)) {
+ return(y)
+ }
+ # Add a label for missing values
+ label_names <- c("", label_info$names)
+ y <- as.data.frame(y)[, 1]
+ # labels ranges from -1 to n-1,
+ # where n is the number of classes and a -1 is a missing value
+ factor(y + 2,
+ levels = seq_along(label_names),
+ labels = label_names
+ )
+}
+
+is_class_label <- function(label_info) {
+  if (is.null(label_info) || is.null(label_info$`_type`)) {
+    return(FALSE)
+  }
+  identical(label_info$`_type`, "ClassLabel")
+}
+
+#' detect using class what each variable type is
+#' @param data data frame containing the data
+#' @param col name of the column that is being checked
+detect_var_type <- function(data, col) {
+ cat_types <- c("character", "logical")
+ num_types <- c("numeric", "integer", "double")
+
+ if (class(data[[col]]) %in% cat_types) {
+ col_type <- "categorical"
+
+ if (is_other_var(data, col)) {
+ col_type <- "other"
+ }
+ } else if (class(data[[col]]) %in% num_types) {
+ col_type <- "numerical"
+
+ if (!is_other_var(data, col)) {
+ col_type <- "categorical"
+ }
+ if (length(unique(data[[col]])) <= 1) {
+ col_type <- "other"
+ }
+  } else {
+    # unexpected classes (e.g. factor, Date) fall back to "other"
+    col_type <- "other"
+  }
+ return(col_type)
+}
+
+#' For detecting our classification of other variable
+#' @param data data frame containing the data
+#' @param col name of variable being checked
+#' @param threshold max number of levels
+is_other_var <- function(data, col, threshold = 30) {
+ other_var <- FALSE
+
+ num_levels <- length(unique(data[[col]]))
+  # more than `threshold` levels (or a single level) is treated as an "other" variable
+ if (num_levels > threshold || num_levels <= 1) {
+ other_var <- TRUE
+ }
+
+ return(other_var)
+}
+
+get_feature_metadata <- function(data) {
+ # Create a metadata table for the features
+ feature_metadata <- get_info_from_arrow_table(data)
+
+ extract_metadata <- function(item) {
+ return(item$metadata)
+ }
+ # Extract and combine the metadata from each item into a data frame
+ metadata_list <- lapply(feature_metadata, extract_metadata)
+ metadata_df <- dplyr::bind_rows(metadata_list)
+ # add the feature names to the metadata
+ metadata_df$feature <- colnames(data)
+ return(metadata_df)
+}
+
+start_device <- function(path, ...) {
+ # Start the device
+ if (is.null(path)) {
+ path <- tempfile()
+ }
+ ext <- tolower(tools::file_ext(path))
+ if (ext == "pdf") {
+ args <- list(...)
+ pdf(path, width = args$width, height = args$height)
+ } else if (ext == "png") {
+ png(path, ...)
+ } else if (ext == "jpeg" || ext == "jpg") {
+ jpeg(path, ...)
+ } else if (ext == "bmp") {
+ bmp(path, ...)
+ } else if (ext == "tiff") {
+ tiff(path, ...)
+ } else {
+ stop("Unsupported file format")
+ }
+}
+
+
+is_dataframe <- function(data) {
+ is(data, "Table") ||
+ is(data, "RecordBatch") ||
+ is.data.frame(data)
+}
+
+
+is_vector <- function(data) {
+ is(data, "Array") ||
+ is(data, "ChunkedArray") ||
+ is.vector(data)
+}
+
+
+convert_to_dataframe <- function(value) {
+ if (is_dataframe(value)) {
+ as.data.frame(value)
+ } else if (is_vector(value)) {
+ `colnames<-`(as.data.frame(value), NULL)
+ } else {
+    stop(paste0("Unsupported object type: ", class(value)))
+ }
+}
+
+get_default_columns <- function(data, default = NULL, max_cols = 1) {
+ if (
+ is_dataframe(data) &&
+ (
+ is.null(max_cols) ||
+ ncol(data) == max_cols
+ ) &&
+ !is.null(colnames(data))
+ ) {
+ colnames(data)
+ } else {
+ default
+ }
+}
+
+validate_data <- function(data, names, required = FALSE, replace = TRUE) {
+ if (is_dataframe(data)) {
+ if (!is.null(names)) {
+ if (!all(names %in% colnames(data)) && !replace) {
+        stop(paste0("Column name(s) ", paste(names, collapse = ", "), " not found in the data"))
+ }
+ } else if (required) {
+ stop("A column name was not provided")
+ }
+ if (ncol(data) == length(names) && replace) {
+ colnames(data) <- names
+ }
+ data
+ } else {
+ stop("Data must be a data frame")
+ }
+}
+
+
+concatenate_datasets <- function(data1, data2, how = "horizontal") {
+ if (is_dataframe(data1) || is_dataframe(data2)) {
+ if (how == "vertical") {
+ rbind(data1, data2)
+ } else {
+ cbind(data1, data2)
+ }
+ } else if (is_vector(data1) && is_vector(data2)) {
+ if (how == "vertical") {
+ c(data1, data2)
+ } else {
+ cbind(data1, data2)
+ }
+ } else {
+ stop("Data types must match")
+ }
+}
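+
+# Illustrative usage of the helpers above (a sketch, not run as part of the package):
+# df <- convert_to_dataframe(1:3)                       # one-column data frame
+# df <- validate_data(df, "x")                          # names the column "x"
+# both <- concatenate_datasets(df, data.frame(y = 4:6)) # column-binds the two frames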
diff --git a/src/biofit/integration/__init__.py b/src/biofit/integration/__init__.py
new file mode 100644
index 0000000..2737d0b
--- /dev/null
+++ b/src/biofit/integration/__init__.py
@@ -0,0 +1,2 @@
+# ruff: noqa
+from .R import RCaller
diff --git a/src/biofit/integration/biosets.py b/src/biofit/integration/biosets.py
new file mode 100644
index 0000000..9c56126
--- /dev/null
+++ b/src/biofit/integration/biosets.py
@@ -0,0 +1,14 @@
+import importlib
+
+from biocore.utils.import_util import is_biosets_available
+
+
+class NotAvailable:
+ def __init__(self, *args, **kwargs):
+ pass
+
+
+def get_feature(val):
+ if is_biosets_available():
+ return getattr(importlib.import_module("biosets.features"), val)
+ return None
diff --git a/src/biofit/integration/patcher.py b/src/biofit/integration/patcher.py
new file mode 100644
index 0000000..151f844
--- /dev/null
+++ b/src/biofit/integration/patcher.py
@@ -0,0 +1,880 @@
+import ast
+import importlib
+import importlib.util
+import inspect
+import json
+import os
+import posixpath
+import shutil
+import threading
+from collections import defaultdict
+from contextlib import contextmanager
+from pathlib import Path
+from types import ModuleType
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import filelock
+
+import biofit.config
+
+from ..utils import gorilla, logging
+from ..utils.file_utils import is_remote_url
+from ..utils.fingerprint import Hasher, is_caching_enabled
+from ..utils.gorilla import (
+ SameSourceAndDestinationError,
+ _get_members,
+ _module_iterator,
+)
+
+logger = logging.get_logger(__name__)
+
+_EXCLUDED_MEMBERS = [
+ "__builtins__",
+ "__cached__",
+ "__doc__",
+ "__file__",
+ "__loader__",
+ "__name__",
+ "__package__",
+ "__spec__",
+ "__annotations__",
+]
+
+
+def dynamic_import(attr_path: str):
+ """
+ Dynamically imports an attribute (class, function, variable) from a given path.
+
+ :param attr_path: The full path to the attribute
+ :return: The attribute object.
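+
+    Example (illustrative; any importable dotted path works)::
+
+        >>> join = dynamic_import("posixpath.join")
+        >>> join("a", "b")
+        'a/b'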
+ """
+ module_path, attr_name = attr_path.rsplit(".", 1)
+ module = importlib.import_module(module_path)
+ return getattr(module, attr_name)
+
+
+def parse_module_for_patching(
+ module: Union[ModuleType, str],
+ hash_module: bool = True,
+ extra=[],
+ filter: Optional[Callable] = None,
+) -> Dict[str, Tuple[Any, str]]:
+ if isinstance(module, str):
+ module = importlib.import_module(module)
+
+ members = _get_members(
+ module, filter=create_patch_filter(module) if filter is None else filter
+ )
+ _patches = {}
+ for member_name, member in members:
+ if hash_module:
+ _patches[member_name] = tuple(
+ [member, module.__name__, Hasher.hash(member)] + extra
+ )
+ else:
+ _patches[member_name] = tuple([member, module.__name__] + extra)
+ return _patches
+
+
+def create_patch_filter(module: ModuleType):
+ excluded_imports = set()
+ for node in ast.parse(inspect.getsource(module)).body:
+ if isinstance(node, ast.ImportFrom) or isinstance(node, ast.Import):
+ excluded_imports.update(
+ [
+ name.name if name.asname is None else name.asname
+ for name in node.names
+ ]
+ )
+
+ def filter(name: str, obj: Any):
+ if (
+ name in _EXCLUDED_MEMBERS
+ or name in excluded_imports
+ or inspect.ismodule(obj)
+ ):
+ return False
+ return True
+
+ return filter
+
+
+def get_hashed_patches(entity_paths=[], module_paths=[]):
+ """
+ Defines patches with their module paths and dynamically loads and hashes their values.
+
+ :return: A list of patches with their values hashed.
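+
+    Example (illustrative sketch; the hash value depends on the local environment)::
+
+        >>> patches = get_hashed_patches(entity_paths=["posixpath.join"])
+        >>> patches[0][0], patches[0][1][1]
+        ('join', 'posixpath')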
+ """
+
+ patches = []
+ for path in entity_paths:
+ value = dynamic_import(path)
+ module_name = path.rsplit(".", 1)[0]
+ hashed_value = Hasher.hash(value)
+ patches.append(
+ (
+ path.rsplit(".", maxsplit=1)[-1],
+ (value, module_name, hashed_value),
+ )
+ )
+
+ for path in module_paths:
+ patches.extend(parse_module_for_patching(path).items())
+
+ return patches
+
+
+def create_lock_path(root, rel_path):
+ lock_path = posixpath.join(
+ root, Path(rel_path).as_posix().replace("/", "_") + ".lock"
+ )
+ return lock_path
+
+
+class PatcherConfig:
+ """Base class for the wrapper config. Handles patching and caching.
+
+ Args:
+ patches (Dict[str, Any]): Dictionary of patches to apply. The keys are the names of the members to patch and the values are the patches.
+        root (ModuleType): The reference package the patches belong to. Members named in `patches`
+            are resolved against this package when patching the `patch_targets`.
+ patch_targets (Union[List[ModuleType], ModuleType]): The packages to patch.
+ cache_dir (Optional[Union[Path, str]], optional): The cache directory. Defaults to None.
+ """
+
+ patches: Dict[str, Any] = None
+ root: ModuleType = None
+ patch_targets: Union[List[ModuleType], ModuleType] = None
+ settings = gorilla.Settings(allow_hit=True, store_hit=True)
+ cache_dir: Optional[Union[Path, str]] = None
+
+ def __init__(
+ self,
+ patches: Dict[str, Any] = None,
+ root: ModuleType = None,
+ patch_targets: Union[List[ModuleType], ModuleType] = None,
+ cache_dir: Optional[Union[Path, str]] = None,
+ **kwargs,
+ ):
+ self._cache_enabled = is_caching_enabled()
+
+ self.patches = self.patches if self.patches is not None else patches
+ self.root = self.root if self.root is not None else root
+ self.patch_targets = (
+ self.patch_targets if self.patch_targets is not None else patch_targets
+ )
+ self.cache_dir = self.cache_dir if self.cache_dir is not None else cache_dir
+
+ if isinstance(self.patches, dict):
+ self.patches = list(self.patches.items())
+
+ if self.patch_targets is None:
+ self.patch_targets = [self.root]
+ elif not isinstance(self.patch_targets, list):
+ self.patch_targets = [self.patch_targets]
+
+ self.cache_files = None
+
+ if self.cache_dir is None:
+ self._cache_dir_root = biofit.config.BIOFIT_PATCHES_CACHE.as_posix()
+ else:
+ if isinstance(self.cache_dir, str) or isinstance(self.cache_dir, Path):
+ self._cache_dir_root = self.cache_dir
+ elif isinstance(self.cache_dir, dict):
+ self._cache_dir_root = self.cache_dir.get(
+ "root", biofit.config.BIOFIT_PATCHES_CACHE
+ )
+ # specifies the cache files for each patch target.
+ for patch_target in self.patch_targets:
+ if patch_target.__name__ in self.cache_dir:
+ for name, path in self.cache_dir[patch_target.__name__].items():
+ if isinstance(path, Path):
+ path = path.as_posix()
+ if os.path.exists(path):
+ if self.cache_files is None:
+ self.cache_files = defaultdict(dict)
+ self.cache_files[patch_target.__name__][name] = path
+
+ self._cache_dir_root = (
+ self._cache_dir_root
+ if is_remote_url(self._cache_dir_root)
+ else os.path.expanduser(self._cache_dir_root)
+ )
+
+ self._relative_ref_pkg_cache_dir = self._build_relative_cache_dir(self.root)
+ self._relative_patch_targets_cache_dirs = {
+ p.__name__: self._build_relative_cache_dir(p)
+ for p in self.patch_targets
+ if hasattr(p, "__name__")
+ }
+
+ self.patches = self._sort_patches(self._prepare_patches(self.patches))
+ self._output_dir = self._build_output_dir()
+ # NOTE: this is causing issues with the cache, so we will disable it for now
+ # self.cache_files = self._get_cache_files()
+ self.cache_files = None
+
+        self._attempts = kwargs.get("attempts", 1)
+
+ def _generate_modules(self):
+ return {
+ package.__name__: (
+ module for module in _module_iterator(package, recursive=True)
+ )
+ for package in self.patch_targets
+ if hasattr(package, "__name__")
+ }
+
+ def _build_output_dir(self):
+ _output_dir = Hasher().hash(self.patches)
+ os.makedirs(os.path.join(self._cache_dir_root, _output_dir), exist_ok=True)
+ return _output_dir
+
+ def _build_relative_cache_dir(self, package):
+ """Return the data directory for the current version."""
+ version = package.__version__ if hasattr(package, "__version__") else "0.0.0"
+ package_name = package.__name__ if hasattr(package, "__name__") else "unknown"
+ relative_patches_dir = posixpath.join(package_name, version)
+ return relative_patches_dir
+
+ def _prepare_patches(self, patches: Dict[str, Any]):
+ if patches is None:
+ raise ValueError("patches cannot be None.")
+ _patches = []
+ if not is_remote_url(self._cache_dir_root):
+ source_cache_dir = posixpath.join(
+ self._cache_dir_root, self._relative_ref_pkg_cache_dir
+ )
+ os.makedirs(source_cache_dir, exist_ok=True)
+ for name, patch in patches:
+ if isinstance(patch, Tuple):
+ if len(patch) == 2:
+ patch_, source_ = patch
+ hash_ = Hasher.hash(patch_)
+ if isinstance(source_, ModuleType):
+ source_ = source_.__name__
+ # check if source exists
+ if not isinstance(source_, str) or not importlib.util.find_spec(
+ source_
+ ):
+ raise ValueError(
+ f"Invalid source: {source_}, source must be a module."
+ )
+ _patches.append((name, (patch_, source_, hash_, None)))
+ elif len(patch) == 3:
+ patch_, source_, hash_ = patch
+ if isinstance(source_, ModuleType):
+ source_ = source_.__name__
+ if not isinstance(source_, str) or not importlib.util.find_spec(
+ source_
+ ):
+ raise ValueError(
+ f"Invalid source: {source_}, source must be a module."
+ )
+ if hash_ is None:
+ hash_ = Hasher.hash(patch)
+ _patches.append((name, (patch_, source_, hash_, None)))
+ elif len(patch) == 4:
+ patch_, source_, hash_, destination_ = patch
+ if isinstance(source_, ModuleType):
+ source_ = source_.__name__
+ if not isinstance(source_, str) or not importlib.util.find_spec(
+ source_
+ ):
+ raise ValueError(
+ f"Invalid source: {source_}, source must be a module."
+ )
+ if hash_ is None:
+ hash_ = Hasher.hash(patch)
+ if isinstance(destination_, str):
+ destination_ = importlib.import_module(destination_)
+ if not isinstance(destination_, ModuleType):
+ raise ValueError(
+ f"Invalid destination: {destination_}, destination must be a module."
+ )
+ _patches.append((name, (patch_, source_, hash_, destination_)))
+ else:
+ hash_ = Hasher.hash(patch)
+ origin = inspect.getmodule(patch)
+ source_ = origin.__name__ if hasattr(origin, "__name__") else None
+ _patches.append((name, (patch_, source_, hash_, None)))
+ return _patches
+
+ def _get_cache_files(self):
+ cache_files = None
+ if self._cache_enabled and not os.path.exists(
+ posixpath.join(
+ self._cache_dir_root, self._output_dir, biofit.config.PATCHES_FILENAME
+ )
+ ):
+ for name, (_, _, hash, _) in self.patches:
+ for (
+ patch_target_name,
+ relative_target_cache_dir,
+ ) in self._relative_patch_targets_cache_dirs.items():
+ cache_file = posixpath.join(
+ self._cache_dir_root,
+ self._relative_ref_pkg_cache_dir,
+ relative_target_cache_dir,
+ hash,
+ )
+ os.makedirs(cache_file, exist_ok=True)
+ patches_file = f"{cache_file}/{biofit.config.PATCHES_FILENAME}"
+ if os.path.exists(patches_file):
+ if cache_files is None:
+ cache_files = defaultdict(dict)
+ cache_files[patch_target_name][name] = patches_file
+ no_patches_file = (
+ f"{cache_file}/{biofit.config.NO_PATCHES_FILENAME}"
+ )
+ if os.path.exists(no_patches_file):
+ if cache_files is None:
+ cache_files = defaultdict(dict)
+ cache_files[patch_target_name][name] = no_patches_file
+ return cache_files
+
+ @classmethod
+ def clear_cache(cls, cache_dir: Optional[Union[Path, str]] = None):
+ """Clear a directory in the root cache directory for patches.
+
+ Args:
+ cache_dir (Optional[Union[Path, str]], optional): The cache directory to delete.
+ If None, the entire cache directory will be deleted. Defaults to None.
+ """
+ if cache_dir is None:
+ cache_dir = (
+ cls._cache_dir_root
+ if hasattr(cls, "_cache_dir_root")
+ else biofit.config.BIOFIT_PATCHES_CACHE.as_posix()
+ )
+ if hasattr(cls, "_output_dir"):
+ cache_dir = posixpath.join(cache_dir, cls._output_dir)
+
+ if os.path.exists(cache_dir):
+ shutil.rmtree(cache_dir)
+
+ def _clear_cache(self):
+ self.cache_files = None
+ for _, (_, _, hash, _) in self.patches:
+ for (
+ relative_target_cache_dir
+ ) in self._relative_patch_targets_cache_dirs.values():
+ rel_member_dir = posixpath.join(
+ self._relative_ref_pkg_cache_dir, relative_target_cache_dir, hash
+ )
+ member_dir = posixpath.join(self._cache_dir_root, rel_member_dir)
+ # lock the directory
+ if os.path.exists(member_dir):
+ for file in os.listdir(member_dir):
+ rel_fp = posixpath.join(rel_member_dir, file)
+ fp = os.path.join(self._cache_dir_root, rel_fp)
+ lock_path = create_lock_path(self._cache_dir_root, rel_fp)
+ with filelock.FileLock(lock_path):
+ os.remove(fp)
+ with filelock.FileLock(
+ create_lock_path(self._cache_dir_root, rel_member_dir)
+ ):
+ os.rmdir(member_dir)
+ output_dir = posixpath.join(self._cache_dir_root, self._output_dir)
+ fp = posixpath.join(output_dir, biofit.config.PATCHES_FILENAME)
+ if os.path.exists(fp):
+ with filelock.FileLock(
+ create_lock_path(
+ self._cache_dir_root,
+ posixpath.join(
+ self._output_dir, biofit.config.PATCHES_FILENAME
+ ),
+ )
+ ):
+ os.remove(fp)
+ if os.path.exists(output_dir):
+ with filelock.FileLock(
+ create_lock_path(self._cache_dir_root, self._output_dir)
+ ):
+ shutil.rmtree(output_dir)
+
+ self._cleanup_empty_dir()
+
+ def _cleanup_empty_dir(self):
+ # use os.walk to find empty directories in self._cache_dir_root
+ for root, dirs, files in os.walk(self._cache_dir_root, topdown=False):
+ if len(dirs) == 0 and len(files) == 0:
+ os.rmdir(root)
+
+ def _sort_patches(self, patches: list):
+ """Sorts the patches for consistent hashing"""
+ unique_keys = [str(p) for p in patches]
+ inds = sorted(range(len(unique_keys)), key=lambda index: unique_keys[index])
+ return [patches[i] for i in inds]
+
+
+class Patcher:
+ """
+ Patcher class for applying patches to target packages/modules.
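+
+    A minimal usage sketch (assumes a prepared ``PatcherConfig``); patches are applied
+    on entry to the context and reverted on exit::
+
+        patcher = Patcher(config)
+        with patcher:
+            ...  # patched members are active here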
+ """
+
+ config: PatcherConfig = None
+ _lock: Path = None
+
+ def __init__(self, config: PatcherConfig = None, **kwargs):
+ """
+ Initialize the Patcher object.
+
+        Args:
+            config (PatcherConfig): Configuration object describing the patches to apply.
+            **kwargs: Additional keyword arguments.
+ """
+ self.config = self.config if self.config is not None else config
+ self._cache_enabled = is_caching_enabled()
+ self.patches = self._sort_patches(self.load_and_prepare_patches())
+ self._fingerprint = Hasher.hash(str(self.patches))
+ self._validate_cache()
+
+ def _sort_patches(self, patches: list):
+ """Sorts the patches for consistent hashing"""
+ return self.config._sort_patches(patches)
+
+ @contextmanager
+ def patch(self, patches: List[gorilla.Patch]):
+ """Context manager for applying additional patches to the target packages/modules."""
+ for patch in patches:
+ gorilla.apply(patch)
+
+ self._apply_patches()
+
+ yield
+
+ self._revert_patches()
+
+ for patch in patches:
+ gorilla.revert(patch)
+
+ def _apply_patches(self):
+ if self._lock:
+ # means that it is trying to run in nested context of the same instance
+ # specify this by setting _cleanup to None
+ self._cleanup = None
+ return
+
+ thread_id = threading.get_ident()
+ file_path = os.path.normpath(
+ posixpath.join(self.config._output_dir, f"{self._fingerprint}_{thread_id}")
+ )
+ self._lock = Path(create_lock_path(self.config._cache_dir_root, file_path))
+ self._cleanup = False
+
+ if not self._lock.exists():
+ self._cleanup = True
+ for patch in self.patches:
+ gorilla.apply(patch)
+ self._lock.parent.mkdir(parents=True, exist_ok=True)
+ self._lock.touch()
+ else:
+ self._lock = None
+
+ def _revert_patches(self):
+ if self._cleanup is None:
+ # same instance is already running, we put it back to true
+ self._cleanup = True
+ return
+ if self._cleanup:
+ for patch in self.patches:
+ gorilla.revert(patch)
+ # remove the lock file
+ if self._lock and self._lock.exists():
+ self._lock.unlink()
+ self._lock = None
+
+ def __enter__(self):
+ self._apply_patches()
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self._revert_patches()
+
+ def load_and_prepare_patches(self):
+ output_file = posixpath.join(
+ self.config._output_dir, biofit.config.PATCHES_FILENAME
+ )
+ fp = posixpath.join(self.config._cache_dir_root, output_file)
+ if self._cache_enabled and os.path.exists(fp):
+ try:
+ return self._load_patches_from_cache(fp)
+ except Exception as e:
+ # something went wrong with loading the patches from cache, so we will clear the cache and recompile
+ self.config._clear_cache()
+ logger.warning(f"Failed to load patches from cache: {e}")
+
+ patches: List[gorilla.Patch] = []
+ _uncached_patches = defaultdict(dict)
+ for name, (patch, source, hash, destination) in self.config.patches:
+ if destination is not None:
+ try:
+ patches.append(
+ gorilla.Patch(
+ destination,
+ name=name,
+ obj=patch,
+ source=source,
+ settings=self.config.settings,
+ )
+ )
+ except SameSourceAndDestinationError:
+ # a patch with the same source and destination was given, raise an error
+ raise SameSourceAndDestinationError(
+ f'Patch with same source and destination: ("{name}", ({name}, "{source}", "{hash}", "{destination}"))'
+ )
+ for patch_target in self.config.patch_targets:
+ _uncached_patches[patch_target.__name__][name] = (patch, source, hash)
+
+ patches += self.create_patches(_uncached_patches)
+ try:
+ self._save_patches_to_cache(patches, fp)
+ except Exception as e:
+ logger.warning(f"Failed to save patches to cache: {e}")
+
+ if os.path.exists(fp):
+ with filelock.FileLock(
+ create_lock_path(self.config._cache_dir_root, output_file)
+ ):
+ os.remove(fp)
+ with filelock.FileLock(
+ create_lock_path(
+ self.config._cache_dir_root, self.config._output_dir
+ )
+ ):
+ os.rmdir(
+ os.path.join(
+ self.config._cache_dir_root, self.config._output_dir
+ )
+ )
+ return patches
+
+ def create_patches(self, patches: Dict[str, Any] = None):
+ results = []
+ if len(patches) > 0:
+ for patch_target in self.config.patch_targets:
+ _patches: List[Tuple[gorilla.Patch, str]] = self._create_patches(
+ patches,
+ patch_target,
+ return_hash=True,
+ )
+ results += [p for p, _ in _patches]
+
+ # NOTE: this is causing issues with the cache, so we will disable it for now
+ # if self._cache_enabled:
+ # self._save_patches_to_patch_cache(_patches, patch_target.__name__)
+
+ return results
+
+ def _create_patches(
+ self,
+ patches: Dict[str, Any],
+ patch_target: ModuleType,
+ return_hash=False,
+ ) -> Union[List[gorilla.Patch], List[Tuple[gorilla.Patch, str]]]:
+ """Create patches for within `~self.config.package`.
+
+ Args:
+ patches (`Dict[str, Any]`):
+ Dictionary of patches to apply. The keys are the names of the members.
+ Values can be either the patch (i.e. the object that will replace the member) or a tuple of
+ (patch, source module), (patch, hash), or (patch, source module, hash). If tuple is length 2,
+ the second element is inferred based on if it is an instance of `ModuleType` or if str is importable.
+ Otherwise, it will be taken as the hash.
+ patch_targets (`Union[ModuleType, List[ModuleType]]`):
+ The target packages to patch. If recursive is True, all submodules will be patched as well.
+ recursive (`bool`, Defaults to True):
+ Whether to recursively search for modules to patch.
+ return_hash (`bool`, Defaults to False):
+ Whether to return the hash of the patch in final output. If True, the output will be a tuple of (patch, hash).
+
+ Returns:
+ `List[gorilla.Patch]` or `List[Tuple[gorilla.Patch, str]]`:
+ List of patches to apply. If return_hash is True, the output will be a tuple of (patch, hash).
+ """
+
+ _patches = []
+
+ try:
+ for module in self.config._generate_modules()[patch_target.__name__]:
+ for asname, name, value in self._find_patches(
+ patches[patch_target.__name__], module
+ ):
+ if isinstance(value, tuple):
+ if len(value) == 3:
+ patch, source, hash = value
+ elif len(value) == 2:
+ patch, source = value
+ if importlib.util.find_spec(source):
+ hash = Hasher.hash(patch) if return_hash else None
+ else:
+ hash = source
+                            origin = inspect.getmodule(patch)
+ source = (
+ origin.__name__
+ if hasattr(origin, "__name__")
+ else None
+ )
+ else:
+ patch = value
+ origin = inspect.getmodule(value)
+ source = (
+ origin.__name__ if hasattr(origin, "__name__") else None
+ )
+ hash = Hasher.hash(patch) if return_hash else None
+ try:
+ if return_hash:
+ _patches.append(
+ (
+ gorilla.Patch(
+ module,
+ name=asname,
+ obj=patch,
+ source=source,
+ source_name=name,
+ settings=self.config.settings,
+ ),
+ hash,
+ )
+ )
+ else:
+ _patches.append(
+ gorilla.Patch(
+ module,
+ name=asname,
+ obj=patch,
+ source=source,
+ source_name=name,
+ settings=self.config.settings,
+ )
+ )
+ except SameSourceAndDestinationError:
+ pass
+
+ except Exception as e:
+ logger.warning(
+ f"Failed to create patches: {e} for patch_target: {patch_target.__name__}"
+ )
+ raise e
+ return _patches
+
+ def _get_absolute_module_name(
+ self, source_module: Union[str, ModuleType], node: ast.ImportFrom
+ ) -> str:
+ """
+ Reconstructs the absolute module name from a relative import.
+
+ Args:
+ - source_module (str): The name of the module that contains the relative import.
+ - node (ast.ImportFrom): The AST node representing the relative import.
+
+ Returns:
+ - str: The absolute module name.
+ """
+ if not isinstance(node, ast.ImportFrom):
+ raise ValueError("Node must be of type ast.ImportFrom")
+
+ relative_level = node.level
+ source_module_parts = source_module.split(".")
+ if relative_level > len(source_module_parts):
+ raise ValueError("Relative level is too high for the current module")
+
+ source_module_parts = source_module_parts[:-relative_level]
+ imported_module_parts = node.module.split(".") if node.module else []
+
+ absolute_module_parts = source_module_parts + imported_module_parts
+ return ".".join(absolute_module_parts)
+
+ def _find_patches(
+ self, patches: Dict[str, Any], module: ModuleType
+    ) -> List[Tuple[str, str, Any]]:
+ """
+ Find the patches to apply in the module.
+
+ Args:
+ patches (Dict[str, Any]): Dictionary of patches to apply.
+ module (ModuleType): Module to search for patches.
+
+        Returns:
+            List[Tuple[str, str, Any]]: Tuples of (name in module, original name, patch) to apply.
+ """
+ try:
+ module_source = inspect.getsource(module)
+ except OSError:
+ with open(module.__file__, "r") as file:
+ module_source = file.read()
+
+ exclude_imports = set()
+ asname_map = {}
+
+ package_name = self.config.root.__name__
+
+ for node in ast.parse(module_source).body:
+ if isinstance(node, ast.ImportFrom):
+ module_name = (
+ node.module
+ if node.level == 0
+ else self._get_absolute_module_name(module.__name__, node)
+ )
+ if package_name not in module_name:
+ exclude_imports.update([name.name for name in node.names])
+ else:
+ for name in node.names:
+ if name.asname is not None:
+ asname_map[name.asname] = name.name
+
+        def is_from_package(member: Any, module: ModuleType, package_name):
+ """Check if the member is from the same package as the module."""
+ origin = inspect.getmodule(member)
+ return origin is None or origin.__name__.startswith(package_name)
+
+ module_globals = inspect.getmembers(
+ module, lambda a: is_from_package(a, module, package_name)
+ )
+
+ return [
+ (
+ (name, name, patches[name])
+ if name not in asname_map
+ else (name, asname_map[name], patches[asname_map[name]])
+ )
+ for (name, value) in module_globals
+ if (
+ name not in exclude_imports
+ and name not in _EXCLUDED_MEMBERS
+ and not inspect.ismodule(value)
+ and (
+ name in patches
+ or (name in asname_map and asname_map[name] in patches)
+ )
+ )
+ ]
+
+ def _save_patches_to_cache(
+ self, patches: List[Tuple[gorilla.Patch, str]], file_path: str
+ ):
+ if not is_remote_url(self.config._cache_dir_root):
+ _relative_cache_dir = (
+ Path(file_path).relative_to(self.config._cache_dir_root).as_posix()
+ )
+ with filelock.FileLock(
+ create_lock_path(self.config._cache_dir_root, _relative_cache_dir)
+ ):
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ with open(file_path, "w") as f:
+ json.dump([p.to_dict() for p in patches], f)
+
+ def _save_patches_to_patch_cache(
+ self, patches: List[Tuple[gorilla.Patch, str]], patch_target: str
+ ):
+ """
+ Save the patches to the cache.
+
+ Args:
+            patches (List[Tuple[gorilla.Patch, str]]): List of (patch, hash) tuples to save.
+            patch_target (str): Name of the patch target whose cache directory the patches are written to.
+ """
+
+ if not is_remote_url(self.config._cache_dir_root):
+ gorilla_patches: Dict[str, List[gorilla.Patch]] = defaultdict(list)
+ for gorilla_patch, hash in patches:
+ gorilla_patches[hash].append(gorilla_patch)
+
+ for _, (_, _, hash, _) in self.config.patches:
+ if hash in gorilla_patches:
+ continue
+ gorilla_patches[hash] = []
+
+ for hash, gorilla_patch in gorilla_patches.items():
+ _relative_cache_dir = posixpath.join(
+ self.config._relative_ref_pkg_cache_dir,
+ self.config._relative_patch_targets_cache_dirs[patch_target],
+ hash,
+ )
+ if gorilla_patch:
+ file_path = os.path.join(
+ self.config._cache_dir_root,
+ _relative_cache_dir,
+ biofit.config.PATCHES_FILENAME,
+ )
+ else:
+ file_path = os.path.join(
+ self.config._cache_dir_root,
+ _relative_cache_dir,
+ biofit.config.NO_PATCHES_FILENAME,
+ )
+ # lock the directory
+ with filelock.FileLock(
+ create_lock_path(
+ self.config._cache_dir_root, _relative_cache_dir
+ )
+ ):
+ try:
+ os.makedirs(os.path.dirname(file_path), exist_ok=True)
+ with open(file_path, "w") as f:
+                            json.dump([p.to_dict() for p in patch_list], f)
+
+ except Exception as e:
+ logger.warning(f"Failed to save patch to cache: {e}")
+ if os.path.exists(file_path):
+ os.remove(file_path)
+ os.rmdir(os.path.dirname(file_path))
+
+ def _load_patches_from_cache(self, file_path: str) -> List[gorilla.Patch]:
+ """
+ Load the patches from the cache.
+
+ Args:
+ config (PatcherConfig): Patcher configuration object.
+ file_path (str): Path to the cache file.
+
+ Returns:
+ List[gorilla.Patch]: List of patches loaded from the cache.
+ """
+ relative_file_path = (
+ Path(file_path).relative_to(self.config._cache_dir_root).as_posix()
+ )
+
+ # lock the directory
+ with filelock.FileLock(
+ create_lock_path(self.config._cache_dir_root, relative_file_path)
+ ):
+ with open(file_path, "r") as f:
+ content = json.load(f)
+            if content is None:
+                raise ValueError(f"Failed to load patch from cache: {file_path}")
+            if isinstance(content, dict):
+                return [gorilla.Patch.from_dict(content)]
+            elif isinstance(content, list):
+                return [gorilla.Patch.from_dict(c) for c in content]
+            raise ValueError(
+                f"Unexpected patch cache format in {file_path}: {type(content).__name__}"
+            )
+
+ def _validate_cache(self):
+ if not is_remote_url(self.config._cache_dir_root):
+ output_dir = posixpath.join(
+ self.config._cache_dir_root, self.config._output_dir
+ )
+ probe_file = os.path.normpath(
+ posixpath.join(output_dir, f"{self._fingerprint}.txt")
+ )
+            # look for a stale fingerprint probe file (.txt) in the output directory
+ for root, dirs, files in os.walk(output_dir, topdown=False):
+ for file in files:
+ if file.endswith(".txt"):
+ fp = os.path.normpath(os.path.join(root, file))
+ if probe_file != fp:
+ logger.debug(
+ "Cache file is outdated or corrupted, clearing the cache and re-building the patches..."
+ )
+ self.config.clear_cache(self.config._output_dir)
+ self.__init__(self.config)
+ return
+ if not os.path.exists(probe_file):
+ with filelock.FileLock(
+ create_lock_path(
+ self.config._cache_dir_root, self.config._output_dir
+ )
+ ):
+ with open(probe_file, "w") as f:
+ f.write("")
diff --git a/src/biofit/metrics/__init__.py b/src/biofit/metrics/__init__.py
new file mode 100644
index 0000000..7015280
--- /dev/null
+++ b/src/biofit/metrics/__init__.py
@@ -0,0 +1,168 @@
+from functools import partial
+
+from sklearn.metrics import (
+ accuracy_score,
+ balanced_accuracy_score,
+ f1_score,
+ log_loss,
+ mean_absolute_error,
+ mean_squared_error,
+ mean_squared_log_error,
+ precision_score,
+ r2_score,
+ recall_score,
+ roc_auc_score,
+)
+
+from .metrics import (
+ calculate_metrics, # noqa: F401
+ confusion_matrix, # noqa: F401
+ log_loss_weighted,
+ specificity,
+)
+
+_OPTUNA_METRICS = {
+ "binary_classification": "f1",
+ "multiclass_classification": "f1_weighted",
+ "regression": "rmse",
+ "multi_regression": "rmse",
+}
+
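+# Registry mapping task -> {metric name: (metric callable, optimization direction)}.
+# Multiclass callables share the signature (y_true, y_pred, labels).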
+_METRICS = {
+ "binary_classification": {
+ "logloss": (log_loss, "minimize"),
+ "logloss_weighted": (
+ partial(log_loss_weighted, class_weights="balanced"),
+ "minimize",
+ ),
+ "auc": (roc_auc_score, "maximize"),
+ "f1": (f1_score, "maximize"),
+ "accuracy": (accuracy_score, "maximize"),
+ "balanced_accuracy": (balanced_accuracy_score, "maximize"),
+ "precision": (precision_score, "maximize"),
+ "recall": (recall_score, "maximize"),
+ "specificity": (specificity, "maximize"),
+ },
+ "multiclass_classification": {
+ "mlogloss": (
+ lambda y_true, y_pred, labels: log_loss(y_true, y_pred, labels=labels),
+ "minimize",
+ ),
+ "mlogloss_weighted": (
+ lambda y_true, y_pred, labels: log_loss_weighted(
+ y_true, y_pred, labels=labels, class_weights="balanced"
+ ),
+ "minimize",
+ ),
+ "accuracy": (
+ lambda y_true, y_pred, labels: accuracy_score(y_true, y_pred),
+ "maximize",
+ ),
+ "balanced_accuracy": (
+ lambda y_true, y_pred, labels: balanced_accuracy_score(y_true, y_pred),
+ "maximize",
+ ),
+ "f1_macro": (
+ lambda y_true, y_pred, labels: f1_score(
+ y_true, y_pred, average="macro", labels=labels
+ ),
+ "maximize",
+ ),
+ "f1_micro": (
+ lambda y_true, y_pred, labels: f1_score(
+ y_true, y_pred, average="micro", labels=labels
+ ),
+ "maximize",
+ ),
+ "f1_weighted": (
+ lambda y_true, y_pred, labels: f1_score(
+ y_true, y_pred, average="weighted", labels=labels
+ ),
+ "maximize",
+ ),
+ "precision_macro": (
+ lambda y_true, y_pred, labels: precision_score(
+ y_true, y_pred, average="macro", labels=labels
+ ),
+ "maximize",
+ ),
+ "precision_micro": (
+ lambda y_true, y_pred, labels: precision_score(
+ y_true, y_pred, average="micro", labels=labels
+ ),
+ "maximize",
+ ),
+ "precision_weighted": (
+ lambda y_true, y_pred, labels: precision_score(
+ y_true, y_pred, average="weighted", labels=labels
+ ),
+ "maximize",
+ ),
+ "recall_macro": (
+ lambda y_true, y_pred, labels: recall_score(
+ y_true, y_pred, average="macro", labels=labels
+ ),
+ "maximize",
+ ),
+ "recall_micro": (
+ lambda y_true, y_pred, labels: recall_score(
+ y_true, y_pred, average="micro", labels=labels
+ ),
+ "maximize",
+ ),
+ "recall_weighted": (
+ lambda y_true, y_pred, labels: recall_score(
+ y_true, y_pred, average="weighted", labels=labels
+ ),
+ "maximize",
+ ),
+ "specificity_macro": (
+ lambda y_true, y_pred, labels: specificity(
+ y_true, y_pred, average="macro", labels=labels
+ ),
+ "maximize",
+ ),
+ # "specificity_micro": (
+ # lambda y_true, y_pred, labels: specificity(
+ # y_true, y_pred, average="micro", labels=labels
+ # ),
+ # "maximize",
+ # ),
+ "specificity_weighted": (
+ lambda y_true, y_pred, labels: specificity(
+ y_true, y_pred, average="weighted", labels=labels
+ ),
+ "maximize",
+ ),
+ },
+ "regression": {
+ "rmse": (partial(mean_squared_error, squared=False), "minimize"),
+ "rmsle": (partial(mean_squared_log_error, squared=False), "minimize"),
+ "r2": (r2_score, "maximize"),
+ "mse": (mean_squared_error, "minimize"),
+ "mae": (mean_absolute_error, "minimize"),
+ },
+ "multi_regression": {
+ "rmse": (partial(mean_squared_error, squared=False), "minimize"),
+ "rmsle": (partial(mean_squared_log_error, squared=False), "minimize"),
+ "r2": (r2_score, "maximize"),
+ "mse": (mean_squared_error, "minimize"),
+ "mae": (mean_absolute_error, "minimize"),
+ },
+ "multilabel_classification": {
+ "logloss": (log_loss, "minimize"),
+ },
+}
+
+
+def get_metrics(task=None, metric=None):
+    """Return the full metric registry, the metrics for a task, or a single
+    (metric callable, optimization direction) pair."""
+    if task is None:
+        return _METRICS
+    if metric is None:
+        return _METRICS[task]
+    return _METRICS[task][metric]
+
+
+def get_optuna_metric(task):
+    """Return the default metric name used when tuning the given task with Optuna."""
+    return _OPTUNA_METRICS[task]
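+
+
+# Illustrative usage:
+#   metric_fn, direction = get_metrics("regression", "rmse")
+#   score = metric_fn(y_true, y_pred)  # direction == "minimize"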
diff --git a/src/biofit/metrics/metrics.py b/src/biofit/metrics/metrics.py
new file mode 100644
index 0000000..00ddeb9
--- /dev/null
+++ b/src/biofit/metrics/metrics.py
@@ -0,0 +1,201 @@
+import copy
+import inspect
+from typing import Callable, Union
+
+import numpy as np
+import pandas as pd
+from biocore import DataHandler
+from sklearn.metrics import (
+ confusion_matrix as sk_confusion_matrix,
+)
+from sklearn.metrics import (
+ log_loss,
+ multilabel_confusion_matrix,
+)
+from sklearn.metrics._classification import (
+ _check_set_wise_labels,
+ _check_zero_division,
+ _nanaverage,
+ _prf_divide,
+)
+from sklearn.utils.class_weight import compute_class_weight
+
+
+def confusion_matrix(
+ y_true, y_pred, labels=None, sample_weight=None, samplewise=False, normalize=None
+):
+ y_true = DataHandler.to_numpy(y_true)
+ y_pred = DataHandler.to_numpy(y_pred)
+
+    if y_true.ndim == 1 or y_true.shape[1] == 1:
+        y_true = y_true.flatten()
+    if y_pred.ndim == 1 or y_pred.shape[1] == 1:
+        # binary scores/probabilities are thresholded at 0.5
+        y_pred = (y_pred.flatten() > 0.5).astype(int)
+    else:
+        y_pred = np.argmax(y_pred, axis=1)
+
+    if labels is None:
+        # fall back to the labels observed in the data
+        labels = np.unique(np.concatenate([y_true, y_pred]))
+
+    if len(labels) < 3:
+ mat = sk_confusion_matrix(
+ y_true,
+ y_pred,
+ labels=labels
+ if labels is not None and not isinstance(labels[0], str)
+ else None,
+ sample_weight=sample_weight,
+ normalize=normalize,
+ )
+
+ df = pd.DataFrame(
+ mat,
+ columns=[f"Predicted {label}" for label in labels],
+ index=[f"Actual {label}" for label in labels],
+ )
+ return df
+    else:
+        # build a len(labels) x len(labels) matrix; y_true/y_pred are assumed to
+        # be integer-encoded as indices into `labels`
+        mat = np.zeros((len(labels), len(labels)), dtype=int)
+        for i in range(len(labels)):
+            for j in range(len(labels)):
+                mat[i, j] = np.sum((y_true == i) & (y_pred == j))
+
+ df = pd.DataFrame(
+ mat,
+ columns=[f"Predicted {label}" for label in labels],
+ index=[f"Actual {label}" for label in labels],
+ )
+ return df
+
+
+def specificity(
+ y_true,
+ y_pred,
+ *,
+ labels=None,
+ pos_label=1,
+ average=None,
+ sample_weight=None,
+ zero_division="warn",
+):
+ _check_zero_division(zero_division)
+ labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label)
+
+ # Calculate tp_sum, pred_sum, true_sum ###
+ samplewise = average == "samples"
+ MCM = multilabel_confusion_matrix(
+ y_true,
+ y_pred,
+ sample_weight=sample_weight,
+ labels=labels,
+ samplewise=samplewise,
+ )
+ tp_sum = MCM[:, 1, 1]
+ tn_sum = MCM[:, 0, 0]
+ fp_sum = MCM[:, 0, 1]
+ fn_sum = MCM[:, 1, 0]
+ true_sum = tp_sum + fn_sum
+ false_sum = tn_sum + fp_sum
+
+ if average == "micro":
+ false_sum = np.array([false_sum.sum()])
+
+ # Divide, and on zero-division, set scores and/or warn according to
+ # zero_division:
+ specificity = _prf_divide(
+ tn_sum, false_sum, "specificity", "false", average, "specificity", zero_division
+ )
+
+ # Average the results
+ if average == "weighted":
+ weights = true_sum
+ elif average == "samples":
+ weights = sample_weight
+ else:
+ weights = None
+
+    if average is not None:
+        assert average != "binary" or len(specificity) == 1
+        specificity = _nanaverage(specificity, weights=weights)
+    if isinstance(specificity, (list, tuple, np.ndarray)) and len(specificity) > 1:
+        # with average=None, report only the last (positive) label's score
+        specificity = specificity[-1]
+    return specificity
+
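+# Illustrative check: specificity([0, 1, 1, 0], [0, 1, 0, 0]) scores the positive
+# class 1 as TN / (TN + FP) = 2 / (2 + 0) = 1.0.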
+
+def log_loss_weighted(y_true, y_pred, labels=None, class_weights=None):
+ ytrue = DataHandler.to_numpy(y_true).flatten()
+ if labels is None:
+ labels = DataHandler.unique(DataHandler.to_numpy(ytrue).flatten())
+ else:
+ labels = DataHandler.to_numpy(labels).flatten()
+ if set(labels) - set(ytrue):
+ labels = DataHandler.unique(DataHandler.to_numpy(ytrue).flatten())
+
+ if class_weights is None:
+ class_weights = np.ones(len(labels))
+ else:
+ class_weights = compute_class_weight(class_weights, classes=labels, y=ytrue)
+
+ ypred = DataHandler.to_numpy(y_pred)
+ if ypred.ndim == 2 and ypred.shape[1] == 1:
+ ypred = ypred.flatten()
+ class_weights = dict(zip(labels, class_weights))
+ sample_weights = np.array([class_weights[label] for label in ytrue])
+ return log_loss(ytrue, ypred, sample_weight=sample_weights)
+
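+# Illustrative: with y_true = [0, 0, 0, 1] and class_weights="balanced", sklearn
+# assigns per-class weights n_samples / (n_classes * count), i.e. {0: 2/3, 1: 2},
+# so errors on the minority class dominate the weighted log loss.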
+
+def calculate_metrics(
+    metrics: Union[dict, Callable], y_true, y_pred, sub_task, labels=None
+):
+ results = {}
+
+ if labels is not None:
+ labels = list(range(len(labels)))
+ if callable(metrics):
+ params = inspect.signature(metrics).parameters
+ eval_kwargs = {}
+ if "labels" in params:
+ eval_kwargs["labels"] = labels
+
+ results["custom_metric"].append(metrics(y_true, y_pred, **eval_kwargs))
+
+ else:
+ for metric_name, (metric_func, _) in metrics.items():
+ if sub_task == "binary_classification":
+ if metric_name in ["logloss", "logloss_weighted", "auc"]:
+ results[metric_name] = metric_func(
+ y_true, DataHandler.select_column(y_pred, 1)
+ )
+ else:
+ results[metric_name] = metric_func(
+ y_true,
+ DataHandler.ge(DataHandler.select_column(y_pred, 1), 0.5),
+ )
+ elif sub_task == "multiclass_classification":
+ if metric_name in (
+ "accuracy",
+ "balanced_accuracy",
+ "f1_macro",
+ "f1_micro",
+ "f1_weighted",
+ "precision_macro",
+ "precision_micro",
+ "precision_weighted",
+ "recall_macro",
+ "recall_micro",
+ "recall_weighted",
+ "specificity_macro",
+ "specificity_micro",
+ "specificity_weighted",
+ ):
+ results[metric_name] = metric_func(
+ y_true, DataHandler.argmax(y_pred, axis=1), labels
+ )
+ else:
+ results[metric_name] = metric_func(y_true, y_pred, labels)
+ else:
+ if metric_name == "rmsle":
+ temp_pred = copy.deepcopy(DataHandler.to_numpy(y_pred))
+ temp_pred = np.clip(temp_pred, 0, None)
+ results[metric_name] = metric_func(y_true, temp_pred)
+ else:
+ results[metric_name] = metric_func(y_true, y_pred)
+ return results
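+
+
+# Illustrative usage, assuming `y_proba` holds one probability column per class:
+#   from biofit.metrics import get_metrics
+#   metrics = get_metrics("binary_classification")
+#   scores = calculate_metrics(metrics, y_true, y_proba, "binary_classification")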
diff --git a/src/biofit/models/__init__.py b/src/biofit/models/__init__.py
new file mode 100644
index 0000000..6b809ad
--- /dev/null
+++ b/src/biofit/models/__init__.py
@@ -0,0 +1,3 @@
+# ruff: noqa
+from .random_forest import *
+from .lightgbm import *
diff --git a/src/biofit/models/ensemble/__init__.py b/src/biofit/models/ensemble/__init__.py
new file mode 100644
index 0000000..6e3f30a
--- /dev/null
+++ b/src/biofit/models/ensemble/__init__.py
@@ -0,0 +1,2 @@
+# ruff: noqa
+from .random_forest import *
diff --git a/src/biofit/models/ensemble/ensemble.py b/src/biofit/models/ensemble/ensemble.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/biofit/models/lasso/__init__.py b/src/biofit/models/lasso/__init__.py
new file mode 100644
index 0000000..bf46f72
--- /dev/null
+++ b/src/biofit/models/lasso/__init__.py
@@ -0,0 +1,7 @@
+# ruff: noqa
+from .lasso import (
+ LassoModel,
+ LassoConfig,
+ LassoForClassification,
+ LassoForRegression,
+)
diff --git a/src/biofit/models/lasso/lasso.py b/src/biofit/models/lasso/lasso.py
new file mode 100644
index 0000000..2d2ba88
--- /dev/null
+++ b/src/biofit/models/lasso/lasso.py
@@ -0,0 +1,355 @@
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import numpy as np
+from sklearn.linear_model import Lasso
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.utils import logging
+
+from ..logistic_regression import LogisticRegressionModel
+from ..models import Model, ModelConfig
+
+logger = logging.get_logger(__name__)
+
+if TYPE_CHECKING:
+ pass
+
+
+@dataclass
+class LassoConfig(ModelConfig):
+ _fit_input_feature_types: List[Union[None, type]] = field(
+ default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _fit_unused_feature_types: List[Union[None, type]] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _fit_process_desc: str = field(
+ default="Fitting the lasso model", init=False, repr=False
+ )
+ predict_process_desc: str = field(
+ default=None,
+ init=False,
+ repr=False,
+ )
+ predict_proba_process_desc: str = field(
+ default=None,
+ init=False,
+ repr=False,
+ )
+ processor_name: str = field(default="lasso", init=False, repr=False)
+
+ input_columns: SelectedColumnTypes = None
+ target_column: SelectedColumnTypes = None
+    alpha: float = 1.0
+ fit_intercept: bool = True
+ precompute: Union[bool, np.ndarray] = False
+ copy_X: bool = True
+ max_iter: int = 1000
+ tol: float = 1e-4
+ warm_start: bool = False
+ positive: bool = False
+ random_state: Optional[int] = 42
+ class_weight: Union[dict, str] = None
+ selection: str = "cyclic" # "cyclic" or "random"
+ solver: str = "saga"
+
+ use_predict_proba: bool = False
+ task: str = None
+ estimator: Lasso = field(default=None, init=False, repr=False)
+ n_classes: int = None
+
+ def __post_init__(self):
+ if self.task is None:
+ self.task = "classification" if self.use_predict_proba else "regression"
+ self._fit_process_desc = f"Fitting the lasso {self.task} model"
+ self.predict_process_desc = f"Predicting with the lasso {self.task} model"
+ self.predict_proba_process_desc = (
+ f"Predicting probabilities with the lasso {self.task} model"
+ )
+
+
+@dataclass
+class LassoConfigForOTU(LassoConfig):
+ dataset_name: str = field(default="otu", init=False, repr=False)
+ log_transform: Union[str, bool] = field(default="log2_1p", init=False, repr=False)
+
+
+class LassoModel(Model):
+ config_class = LassoConfig
+ config: LassoConfig
+ lasso: Lasso
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ self.lasso = Lasso(
+ alpha=self.config.alpha,
+ fit_intercept=self.config.fit_intercept,
+ precompute=self.config.precompute,
+ copy_X=self.config.copy_X,
+ max_iter=self.config.max_iter,
+ tol=self.config.tol,
+ warm_start=self.config.warm_start,
+ positive=self.config.positive,
+ random_state=self.config.random_state,
+ selection=self.config.selection,
+ )
+ return self
+
+ @property
+ def feature_importances_(self):
+ return self.config.estimator.coef_.flatten()
+
+ @property
+ def feature_names_in_(self):
+ return self.config.feature_names_in_
+
+ def fit(
+ self,
+ X,
+ y=None,
+ input_columns: SelectedColumnTypes = None,
+ target_column: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "LassoModel":
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns, target_column
+ )
+ return self._process_fit(
+ X,
+ y,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def predict(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def predict_proba(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+    def _fit_sklearn(self, X, y):
+        # `log_transform` is only defined on OTU configs; default to no transform
+        if getattr(self.config, "log_transform", None) == "log2_1p":
+            X = np.log2(X + 1)
+ self.config.estimator = self.lasso.fit(X, y)
+ return self
+
+ def _predict_sklearn(self, X):
+ return self.config.estimator.predict(X)
+
+ def _predict_proba_sklearn(self, X):
+ logger.warning_once(
+ "predict_proba is not supported for Lasso. Returning using predict instead."
+ )
+ return self._predict_sklearn(X)
+
+ @staticmethod
+ def suggest_params(data):
+ return {
+ "alpha": (
+ "suggest_float",
+ (
+ 1e-8,
+ 1e3,
+ ),
+ {"log": True},
+ ),
+ "fit_intercept": ("suggest_categorical", ([True, False],), {}),
+ "max_iter": (
+ "suggest_int",
+ (
+ 1000,
+ 10000,
+ ),
+ {},
+ ),
+ }
+
+ @staticmethod
+ def suggest_first_trial():
+ return {
+ "alpha": 1.0,
+ "fit_intercept": True,
+ "max_iter": 1000,
+ }
+
+
+# Lasso for classification is a logistic regression model with an L1 penalty, plus
+# an additional layer that converts the per-class scores into the "winning" class
+# label.
+class LassoForClassification(LogisticRegressionModel):
+    config_class = LassoConfig
+
+ def __init__(
+ self,
+ alpha=1.0,
+ fit_intercept: bool = True,
+ copy_X: bool = True,
+ max_iter: int = 1000,
+ tol: float = 1e-4,
+ warm_start: bool = False,
+ class_weight: str = None,
+ random_state: int = 42,
+ solver: str = "saga",
+ config: LassoConfig = None,
+ **kwargs,
+ ):
+ if "C" in kwargs:
+ alpha = 1.0 / kwargs.pop("C")
+ super().__init__(
+ C=1.0 / alpha, # C is the inverse of alpha
+ fit_intercept=fit_intercept,
+ max_iter=max_iter,
+ tol=tol,
+ penalty="l1",
+ solver=solver,
+ warm_start=warm_start,
+ class_weight=class_weight,
+ random_state=random_state,
+ config=config,
+ )
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ if "alpha" in kwargs:
+ kwargs["C"] = 1.0 / kwargs.pop("alpha")
+ super().set_params(**kwargs)
+ return self
+
+ @property
+ def classes_(self):
+ return self.config.estimator.classes_
+
+ @staticmethod
+ def suggest_params(data):
+ params = {
+ "C": (
+ "suggest_float",
+ (
+ 1e-4,
+ 1e2,
+ ),
+ {"log": True},
+ ),
+ "tol": (
+ "suggest_float",
+ (
+ 1e-5,
+ 1e-2,
+ ),
+ {"log": True},
+ ),
+ "fit_intercept": ("suggest_categorical", ([True, False],), {}),
+ "max_iter": ("suggest_categorical", ([100, 200, 500, 1000],), {}),
+ "class_weight": ("suggest_categorical", (["balanced", "None"],), {}),
+ }
+ return params
+
+ @staticmethod
+ def suggest_first_trial():
+ return {
+ "alpha": 1.0,
+ "tol": 1e-4,
+ "fit_intercept": True,
+ "max_iter": 1000,
+ "class_weight": "None",
+ }
+
+
+class LassoForRegression(LassoModel):
+ pass
diff --git a/src/biofit/models/lightgbm/__init__.py b/src/biofit/models/lightgbm/__init__.py
new file mode 100644
index 0000000..5c69cd2
--- /dev/null
+++ b/src/biofit/models/lightgbm/__init__.py
@@ -0,0 +1,7 @@
+# ruff: noqa
+from .lightgbm import (
+ LightGBMModel,
+ LightGBMConfig,
+ LightGBMForClassification,
+ LightGBMForRegression,
+)
diff --git a/src/biofit/models/lightgbm/lightgbm.py b/src/biofit/models/lightgbm/lightgbm.py
new file mode 100644
index 0000000..754b120
--- /dev/null
+++ b/src/biofit/models/lightgbm/lightgbm.py
@@ -0,0 +1,578 @@
+from dataclasses import dataclass, field, fields
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Type, Union
+
+import numpy as np
+from biocore import DataHandler
+from biocore.utils.import_util import is_lightgbm_available, requires_backends
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.utils import logging
+
+from ..models import Model, ModelConfig
+
+logger = logging.get_logger(__name__)
+
+if TYPE_CHECKING:
+ from lightgbm import LGBMClassifier, LGBMRegressor
+ from lightgbm.sklearn import _LGBM_ScikitCustomObjectiveFunction
+
+if is_lightgbm_available():
+ from lightgbm import LGBMClassifier, LGBMRegressor
+ from lightgbm.sklearn import _LGBM_ScikitCustomObjectiveFunction
+
+
+@dataclass
+class LightGBMConfig(ModelConfig):
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ None,
+ get_feature("TARGET_FEATURE_TYPES"),
+ None,
+ None,
+ ],
+ init=False,
+ repr=False,
+ )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("METADATA_FEATURE_TYPES"),
+ None,
+ None,
+ None,
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ processor_name: str = field(default="lightgbm", init=False, repr=False)
+ boosting_type: str = "gbdt"
+ num_leaves: int = 31
+ max_depth: int = -1
+ learning_rate: float = 0.1
+ n_estimators: int = 100
+ subsample_for_bin: int = 200000
+ objective: Optional[Union[str, "_LGBM_ScikitCustomObjectiveFunction"]] = None
+ eval_metric: str = None
+ class_weight: Optional[Union[Dict, str]] = None
+ min_split_gain: float = 0.0
+ min_child_weight: float = 1e-3
+ min_child_samples: int = 20
+ subsample: float = 1.0
+ subsample_freq: int = 0
+ colsample_bytree: float = 1.0
+ reg_alpha: float = 0.0
+ reg_lambda: float = 0.0
+ random_state: Optional[Union[int, np.random.RandomState]] = 42
+ n_jobs: Optional[int] = None
+ importance_type: str = "split"
+ early_stopping_rounds: int = None
+ callbacks: Optional[List[Callable]] = None
+ verbosity: int = -1
+ additional_kwargs: dict = field(default_factory=dict)
+
+ n_classes: int = None
+ task: str = None
+ estimator: Union["LGBMClassifier", "LGBMRegressor"] = field(
+ default=None, init=False, repr=False
+ )
+
+ def __post_init__(self):
+ requires_backends(self.__class__, "lightgbm")
+ self._fit_process_desc = f"Fitting the LGBM {self.task} model"
+ self.predict_process_desc = f"Predicting with the LGBM {self.task} model"
+ self.predict_proba_process_desc = (
+ f"Predicting probabilities with the LGBM {self.task} model"
+ )
+
+
+class LightGBMModel(Model):
+ config_class = LightGBMConfig
+ config: LightGBMConfig
+ lightgbm: Union["LGBMClassifier", "LGBMRegressor"]
+
+ def __init__(
+ self,
+ boosting_type: str = "gbdt",
+ num_leaves: int = 31,
+ max_depth: int = -1,
+ learning_rate: float = 0.1,
+ n_estimators: int = 100,
+ subsample_for_bin: int = 200000,
+ objective: Optional[Union[str, "_LGBM_ScikitCustomObjectiveFunction"]] = None,
+ eval_metric: str = None,
+ class_weight: Optional[Union[Dict, str]] = None,
+ min_split_gain: float = 0.0,
+ min_child_weight: float = 1e-3,
+ min_child_samples: int = 20,
+ subsample: float = 1.0,
+ subsample_freq: int = 0,
+ colsample_bytree: float = 1.0,
+ reg_alpha: float = 0.0,
+ reg_lambda: float = 0.0,
+ random_state: Optional[Union[int, np.random.RandomState]] = None,
+ n_jobs: Optional[int] = None,
+ importance_type: str = "split",
+ callbacks: Optional[List[Callable]] = None,
+ verbosity: int = -1,
+ n_classes: int = None,
+ early_stopping_rounds: int = None,
+ task: str = None,
+ version: str = None,
+ config: Optional[LightGBMConfig] = None,
+ **kwargs,
+ ):
+ m_params = [f.name for f in fields(LightGBMConfig) if f.init]
+ biofit_params = {k: v for k, v in kwargs.items() if k in m_params}
+ lightgbm_params = {k: v for k, v in kwargs.items() if k not in m_params}
+ if "additional_kwargs" in biofit_params:
+ lightgbm_params.update(biofit_params.pop("additional_kwargs"))
+ super().__init__(
+ config=config,
+ boosting_type=boosting_type,
+ num_leaves=num_leaves,
+ max_depth=max_depth,
+ learning_rate=learning_rate,
+ n_estimators=n_estimators,
+ subsample_for_bin=subsample_for_bin,
+ objective=objective,
+ eval_metric=eval_metric,
+ class_weight=class_weight,
+ min_split_gain=min_split_gain,
+ min_child_weight=min_child_weight,
+ min_child_samples=min_child_samples,
+ subsample=subsample,
+ subsample_freq=subsample_freq,
+ colsample_bytree=colsample_bytree,
+ reg_alpha=reg_alpha,
+ reg_lambda=reg_lambda,
+ random_state=random_state,
+ n_jobs=n_jobs,
+ importance_type=importance_type,
+ early_stopping_rounds=early_stopping_rounds,
+ callbacks=callbacks,
+ verbosity=verbosity,
+ n_classes=n_classes,
+ task=task,
+ version=version,
+ additional_kwargs=lightgbm_params,
+ **biofit_params,
+ )
+ if self.config.task is None:
+ raise ValueError(
+ "Task is not set. Please set the task before setting the parameters."
+ )
+
+ if "classification" in self.config.task:
+ self.lightgbm = LGBMClassifier(
+ boosting_type=self.config.boosting_type,
+ num_leaves=self.config.num_leaves,
+ max_depth=self.config.max_depth if self.config.max_depth > 0 else -1,
+ learning_rate=self.config.learning_rate,
+ n_estimators=self.config.n_estimators,
+ subsample_for_bin=self.config.subsample_for_bin,
+ objective=self.config.objective,
+ class_weight=self.config.class_weight
+ if self.config.class_weight != "None"
+ else None,
+ min_split_gain=self.config.min_split_gain,
+ min_child_weight=self.config.min_child_weight,
+ min_child_samples=self.config.min_child_samples,
+ subsample=self.config.subsample,
+ subsample_freq=self.config.subsample_freq,
+ colsample_bytree=self.config.colsample_bytree,
+ reg_alpha=self.config.reg_alpha,
+ reg_lambda=self.config.reg_lambda,
+ random_state=self.config.random_state,
+ n_jobs=self.config.n_jobs,
+ importance_type=self.config.importance_type,
+ verbosity=self.config.verbosity,
+ **self.config.additional_kwargs,
+ )
+ elif "regression" in self.config.task:
+ self.lightgbm = LGBMRegressor(
+ boosting_type=self.config.boosting_type,
+ num_leaves=self.config.num_leaves,
+ max_depth=self.config.max_depth if self.config.max_depth > 0 else -1,
+ learning_rate=self.config.learning_rate,
+ n_estimators=self.config.n_estimators,
+ subsample_for_bin=self.config.subsample_for_bin,
+ objective=self.config.objective,
+ min_split_gain=self.config.min_split_gain,
+ min_child_weight=self.config.min_child_weight,
+ min_child_samples=self.config.min_child_samples,
+ subsample=self.config.subsample,
+ subsample_freq=self.config.subsample_freq,
+ colsample_bytree=self.config.colsample_bytree,
+ reg_alpha=self.config.reg_alpha,
+ reg_lambda=self.config.reg_lambda,
+ random_state=self.config.random_state,
+ n_jobs=self.config.n_jobs,
+ importance_type=self.config.importance_type,
+ verbosity=self.config.verbosity,
+ **self.config.additional_kwargs,
+ )
+ else:
+ raise ValueError(f"Invalid task: {self.config.task}")
+
+ @sync_backup_config
+ def set_params(self, **params):
+ self.config = self.config.replace_defaults(**params)
+ if self.config.task is None:
+ raise ValueError(
+ "Task is not set. Please set the task before setting the parameters."
+ )
+
+ if "classification" in self.config.task:
+ self.lightgbm = LGBMClassifier(
+ boosting_type=self.config.boosting_type,
+ num_leaves=self.config.num_leaves,
+ max_depth=self.config.max_depth if self.config.max_depth > 0 else -1,
+ learning_rate=self.config.learning_rate,
+ n_estimators=self.config.n_estimators,
+ subsample_for_bin=self.config.subsample_for_bin,
+ objective=self.config.objective,
+ class_weight=self.config.class_weight
+ if self.config.class_weight != "None"
+ else None,
+ min_split_gain=self.config.min_split_gain,
+ min_child_weight=self.config.min_child_weight,
+ min_child_samples=self.config.min_child_samples,
+ subsample=self.config.subsample,
+ subsample_freq=self.config.subsample_freq,
+ colsample_bytree=self.config.colsample_bytree,
+ reg_alpha=self.config.reg_alpha,
+ reg_lambda=self.config.reg_lambda,
+ random_state=self.config.random_state,
+ n_jobs=self.config.n_jobs,
+ importance_type=self.config.importance_type,
+ verbosity=self.config.verbosity,
+ )
+ elif "regression" in self.config.task:
+ self.lightgbm = LGBMRegressor(
+ boosting_type=self.config.boosting_type,
+ num_leaves=self.config.num_leaves,
+ max_depth=self.config.max_depth if self.config.max_depth > 0 else -1,
+ learning_rate=self.config.learning_rate,
+ n_estimators=self.config.n_estimators,
+ subsample_for_bin=self.config.subsample_for_bin,
+ objective=self.config.objective,
+ min_split_gain=self.config.min_split_gain,
+ min_child_weight=self.config.min_child_weight,
+ min_child_samples=self.config.min_child_samples,
+ subsample=self.config.subsample,
+ subsample_freq=self.config.subsample_freq,
+ colsample_bytree=self.config.colsample_bytree,
+ reg_alpha=self.config.reg_alpha,
+ reg_lambda=self.config.reg_lambda,
+ random_state=self.config.random_state,
+ n_jobs=self.config.n_jobs,
+ importance_type=self.config.importance_type,
+ verbosity=self.config.verbosity,
+ )
+ else:
+ raise ValueError(f"Invalid task: {self.config.task}")
+ return self
+
+ def set_objective(self, task: str):
+ if task == "regression":
+ if self.config.eval_metric == "rmse":
+ self.config.objective = "regression"
+ elif self.config.eval_metric == "mae":
+ self.config.objective = "regression_l1"
+ elif task == "binary_classification":
+ self.config.objective = "binary"
+ if self.config.eval_metric == "accuracy":
+ self.config.eval_metric = "binary_error"
+ elif task == "multiclass_classification":
+ self.config.objective = "multiclass"
+ self.lightgbm = self.lightgbm.set_params(objective=self.config.objective)
+ return self
+
+ @property
+ def feature_importances_(self):
+ return self.config.estimator.feature_importances_
+
+ @property
+ def feature_names_in_(self):
+ return self.config.feature_names_in_
+
+ def fit(
+ self,
+ X,
+ y=None,
+ eval_set=None,
+ input_columns: SelectedColumnTypes = None,
+ target_column: SelectedColumnTypes = None,
+ eval_input_columns: SelectedColumnTypes = None,
+ eval_target_column: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ early_stopping_rounds: int = None,
+ ) -> "LightGBMModel":
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns, target_column, eval_input_columns, eval_target_column
+ )
+ if eval_set is not None:
+ if isinstance(eval_set, tuple):
+ extras = eval_set
+ elif (
+ isinstance(eval_set, list)
+ and len(eval_set) == 1
+ and isinstance(eval_set[0], tuple)
+ ):
+ extras = eval_set[0]
+ else:
+ raise ValueError(
+ "eval_set must be a tuple or a list containing a single tuple"
+ )
+ else:
+ extras = (None, None)
+
+ return self._process_fit(
+ X,
+ y,
+ *extras,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def predict(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._method_prefix = "_predict"
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def predict_proba(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._method_prefix = "_predict_proba"
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _fit_sklearn(self, X, y, eval_x=None, eval_y=None):
+ callbacks = self.config.callbacks
+ if self.config.early_stopping_rounds:
+ callbacks = callbacks or []
+ from lightgbm.callback import early_stopping
+
+ callbacks.append(early_stopping(self.config.early_stopping_rounds))
+
+ if eval_x is not None and eval_y is not None:
+ self.config.estimator = self.lightgbm.fit(
+ X,
+ y,
+ eval_set=[(eval_x, eval_y)],
+ eval_metric=self.config.eval_metric,
+ callbacks=callbacks,
+ )
+ else:
+ self.config.estimator = self.lightgbm.fit(
+ X, y, eval_metric=self.config.eval_metric, callbacks=callbacks
+ )
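+        # Lock in the iteration count found by early stopping and disable it,
+        # so subsequent refits reuse the tuned n_estimators as-is.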
+ if self.config.early_stopping_rounds:
+ self.config.n_estimators = self.config.estimator.best_iteration_
+ self.config.estimator.set_params(n_estimators=self.config.n_estimators)
+ self.config.early_stopping_rounds = None
+ return self
+
+ def _predict_sklearn(self, X):
+ return self.config.estimator.predict(X)
+
+ def _predict_proba_sklearn(self, X):
+ return self.config.estimator.predict_proba(X)
+
+ def _get_features_out(
+ self, X, selected_indices=None, unselected_indices=None, **kwargs
+ ):
+ if self._method_prefix == "_predict_proba":
+ self.config._n_features_out = self.config.n_classes
+ self.output_dtype = "float64"
+ else:
+ self.config._n_features_out = 1
+ return super()._get_features_out(
+ X,
+ selected_indices=selected_indices,
+ unselected_indices=unselected_indices,
+ one_to_one_features=False,
+ n_features_out=self.config._n_features_out,
+ keep_unused_columns=False,
+ )
+
+ @staticmethod
+ def suggest_params(data):
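+        # Each value is either a fixed parameter or a tuple of
+        # (suggest method name, positional args, keyword args), presumably
+        # consumed by the tuner as trial.suggest_<kind>(name, *args, **kwargs).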
+ params = {
+ "verbosity": -1,
+ "n_estimators": ("suggest_int", [50, 1000], {"step": 50}),
+ "max_depth": ("suggest_int", [0, 11], {}),
+ "learning_rate": ("suggest_float", [1e-3, 0.1], {"log": True}),
+ "num_leaves": ("suggest_int", [31, 255], {}),
+ "subsample": ("suggest_float", [0.5, 1.0], {"step": 0.1}),
+ "reg_alpha": ("suggest_float", [1e-9, 1.0], {"log": True}),
+ "reg_lambda": ("suggest_float", [1e-9, 1.0], {"log": True}),
+ "min_child_samples": ("suggest_int", [10, 100], {}),
+ "min_split_gain": ("suggest_float", [0.0, 1.0], {}),
+ "num_threads": -1,
+ }
+
+ if data is not None:
+ data_dim = DataHandler.get_shape(data)
+
+ if data_dim[1] > data_dim[0]:
+ msg = "The number of features is greater than the number of samples"
+ max_feat_frac = max(0.5, data_dim[0] / data_dim[1])
+ max_feat_frac = round(max_feat_frac, 1)
+ params["colsample_bytree"] = (
+ "suggest_float",
+ [0.1, max_feat_frac],
+ {"step": 0.1},
+ )
+ if data_dim[1] > 1000:
+ msg += " and exceeds 1000"
+ params["n_estimators"] = ("suggest_int", [50, 500], {"step": 50})
+ params["max_depth"] = ("suggest_int", [3, 8], {})
+ params["num_leaves"] = ("suggest_int", [31, 63], {})
+ params["min_child_samples"] = ("suggest_int", [10, 50], {})
+ params["max_bin"] = ("suggest_int", [63, 127], {})
+ msg += ". Adjusting the hyperparameters accordingly for faster training. "
+ logger.warning(msg)
+
+ return params
+
+ @staticmethod
+ def suggest_first_trial():
+ return {
+ "n_estimators": 100,
+ "max_depth": 8,
+ "learning_rate": 0.1,
+ "num_leaves": 31,
+ "subsample": 1.0,
+ "reg_alpha": 1e-9,
+ "reg_lambda": 1e-9,
+ "min_child_samples": 20,
+ "colsample_bytree": 0.9,
+ "min_split_gain": 0.0,
+ }
+
+
+class LightGBMForClassification(LightGBMModel):
+ def __init__(self, **kwargs):
+ kwargs["task"] = kwargs.get("task", "classification")
+ super().__init__(**kwargs)
+
+ @property
+ def classes_(self):
+ return self.config.estimator.classes_
+
+ @sync_backup_config
+ def set_params(self, **params):
+ self.config.task = self.config.task or "classification"
+ return super().set_params(**params)
+
+ def _process_fit_input(self, X, **kwargs):
+ X, kwargs = super()._process_fit_input(X, **kwargs)
+ self.set_objective(self.config.task)
+ return X, kwargs
+
+ @staticmethod
+ def suggest_params(data):
+ params = LightGBMModel.suggest_params(data)
+ params["class_weight"] = ("suggest_categorical", [["balanced", "None"]], {})
+ return params
+
+ @staticmethod
+ def suggest_first_trial():
+ return {
+ **LightGBMModel.suggest_first_trial(),
+ "class_weight": "balanced",
+ }
+
+
+class LightGBMForRegression(LightGBMModel):
+ @sync_backup_config
+ def set_params(self, **params):
+ self.config.task = self.config.task or "regression"
+ return super().set_params(**params)
diff --git a/src/biofit/models/logistic_regression/__init__.py b/src/biofit/models/logistic_regression/__init__.py
new file mode 100644
index 0000000..d909e73
--- /dev/null
+++ b/src/biofit/models/logistic_regression/__init__.py
@@ -0,0 +1,5 @@
+# ruff: noqa
+from .logistic_regression import (
+ LogisticRegressionModel,
+ LogisticRegressionConfig,
+)
diff --git a/src/biofit/models/logistic_regression/logistic_regression.py b/src/biofit/models/logistic_regression/logistic_regression.py
new file mode 100644
index 0000000..2cb15b2
--- /dev/null
+++ b/src/biofit/models/logistic_regression/logistic_regression.py
@@ -0,0 +1,246 @@
+from dataclasses import dataclass, field
+from typing import List, Optional, Union
+
+from sklearn.linear_model import LogisticRegression
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.utils import logging
+
+from ..models import Model, ModelConfig
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class LogisticRegressionConfig(ModelConfig):
+ _fit_input_feature_types: List[Union[None, type]] = field(
+ default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _fit_unused_feature_types: List[Union[None, type]] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+    _fit_process_desc: str = field(
+        default="Fitting the logistic regression model", init=False, repr=False
+    )
+ predict_process_desc: str = field(
+ default=None,
+ init=False,
+ repr=False,
+ )
+ predict_proba_process_desc: str = field(
+ default=None,
+ init=False,
+ repr=False,
+ )
+ processor_name: str = field(default="logistic_regression", init=False, repr=False)
+
+ input_columns: SelectedColumnTypes = None
+ target_column: SelectedColumnTypes = None
+
+ penalty: str = "l2"
+ dual: bool = False
+ tol: float = 1e-4
+ C: float = 1.0
+ fit_intercept: bool = True
+ intercept_scaling: float = 1
+ class_weight: Union[dict, str] = None
+ random_state: Optional[int] = 42
+ solver: str = "saga"
+ max_iter: int = 100
+ multi_class: str = "auto"
+ verbose: int = 0
+ warm_start: bool = False
+ n_jobs: Optional[int] = None
+ l1_ratio: Optional[float] = None
+
+ use_predict_proba: bool = False
+ task: str = field(default="classification", init=False, repr=False)
+ estimator: LogisticRegression = field(default=None, init=False, repr=False)
+ n_classes: int = None
+
+ def __post_init__(self):
+ self._fit_process_desc = "Fitting the logistic regression model"
+ self.predict_process_desc = "Predicting with the logistic regression model"
+ self.predict_proba_process_desc = (
+ "Predicting probabilities with the logistic regression model"
+ )
+
+
+class LogisticRegressionModel(Model):
+ config_class = LogisticRegressionConfig
+ config: LogisticRegressionConfig
+ logistic_regression: LogisticRegression
+
+ def __init__(
+ self,
+ penalty: str = "l2",
+ dual: bool = False,
+ tol: float = 1e-4,
+ C: float = 1.0,
+ fit_intercept: bool = True,
+ intercept_scaling: float = 1,
+ class_weight: Union[dict, str] = None,
+ random_state: int = 42,
+ solver: str = "lbfgs",
+ max_iter: int = 100,
+ multi_class: str = "auto",
+ verbose: int = 0,
+ warm_start: bool = False,
+ n_jobs: Optional[int] = None,
+ l1_ratio: Optional[float] = None,
+ config: LogisticRegressionConfig = None,
+ **kwargs,
+ ):
+ super().__init__(
+ penalty=penalty,
+ dual=dual,
+ tol=tol,
+ C=C,
+ fit_intercept=fit_intercept,
+ intercept_scaling=intercept_scaling,
+ class_weight=class_weight,
+ random_state=random_state,
+ solver=solver,
+ max_iter=max_iter,
+ multi_class=multi_class if multi_class != "None" else None,
+ verbose=verbose,
+ warm_start=warm_start,
+ n_jobs=n_jobs,
+ l1_ratio=l1_ratio,
+ config=config,
+ **kwargs,
+ )
+ self.logistic_regression = LogisticRegression(
+ penalty=self.config.penalty,
+ dual=self.config.dual,
+ tol=self.config.tol,
+ C=self.config.C,
+ fit_intercept=self.config.fit_intercept,
+ intercept_scaling=self.config.intercept_scaling,
+ class_weight=self.config.class_weight
+ if self.config.class_weight != "None"
+ else None,
+ random_state=self.config.random_state,
+ solver=self.config.solver,
+ max_iter=self.config.max_iter,
+ multi_class=self.config.multi_class,
+ verbose=self.config.verbose,
+ warm_start=self.config.warm_start,
+ n_jobs=self.config.n_jobs,
+ l1_ratio=self.config.l1_ratio,
+ )
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ self.logistic_regression = LogisticRegression(
+ penalty=self.config.penalty,
+ dual=self.config.dual,
+ tol=self.config.tol,
+ C=self.config.C,
+ fit_intercept=self.config.fit_intercept,
+ intercept_scaling=self.config.intercept_scaling,
+ class_weight=self.config.class_weight
+ if self.config.class_weight != "None"
+ else None,
+ random_state=self.config.random_state,
+ solver=self.config.solver,
+ max_iter=self.config.max_iter,
+ multi_class=self.config.multi_class,
+ verbose=self.config.verbose,
+ warm_start=self.config.warm_start,
+ n_jobs=self.config.n_jobs,
+ l1_ratio=self.config.l1_ratio,
+ )
+ return self
+
+ @property
+ def feature_importances_(self):
+ return self.config.estimator.coef_.flatten()
+
+ @property
+ def feature_names_in_(self):
+ return self.config.feature_names_in_
+
+ def fit(
+ self,
+ X,
+ y=None,
+ input_columns: SelectedColumnTypes = None,
+ target_column: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "LogisticRegressionModel":
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns, target_column
+ )
+ return self._process_fit(
+ X,
+ y,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def _fit_sklearn(self, X, y):
+ self.config.estimator = self.logistic_regression.fit(X, y)
+ return self
+
+ def _predict_sklearn(self, X):
+ return self.config.estimator.predict(X)
+
+ def _predict_proba_sklearn(self, X):
+ return self.config.estimator.predict_proba(X)
+
+ @staticmethod
+ def suggest_params(data):
+ params = {
+ "C": (
+ "suggest_float",
+ (
+ 1e-4,
+ 1e2,
+ ),
+ {"log": True},
+ ),
+ "tol": (
+ "suggest_float",
+ (
+ 1e-5,
+ 1e-2,
+ ),
+ {"log": True},
+ ),
+ "max_iter": ("suggest_categorical", ([100, 200, 500, 1000],), {}),
+ "penalty": "l1",
+ "n_jobs": -1,
+ "class_weight": ("suggest_categorical", (["balanced", "None"],), {}),
+ }
+ return params
diff --git a/src/biofit/models/models.py b/src/biofit/models/models.py
new file mode 100644
index 0000000..52a58d8
--- /dev/null
+++ b/src/biofit/models/models.py
@@ -0,0 +1,223 @@
+from dataclasses import dataclass, field
+from functools import wraps
+
+import pyarrow as pa
+from biocore import DataHandler
+from biocore.utils.import_util import is_datasets_available
+
+from biofit.processing import BaseProcessor, ProcessorConfig, SelectedColumnTypes
+from biofit.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class ModelConfig(ProcessorConfig):
+ """Base class for feature extraction processor configurations."""
+
+ _fit_process_desc: str = field(default="Fitting the model", init=False, repr=False)
+ predict_process_desc: str = field(
+ default="Predicting target output", init=False, repr=False
+ )
+ predict_proba_process_desc: str = field(
+ default="Predicting target output probabilities", init=False, repr=False
+ )
+ processor_type: str = field(default="models", init=False, repr=False)
+ _missing_val: float = field(default=0, init=False, repr=False)
+ _missing_val_pa_type: pa.DataType = field(
+ default=pa.float64(), init=False, repr=False
+ )
+ class_names: list = field(default=None, init=False, repr=False)
+
+ task: str = None
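+    # `task` starts as "classification"/"regression" and is refined to
+    # "binary_classification" or "multiclass_classification" in
+    # Model._process_fit_input once the number of classes is known.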
+
+
+class Model(BaseProcessor):
+ """Base class for models."""
+
+ @wraps(BaseProcessor._process_transform)
+ def predict(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ """Predict the model."""
+ self._method_prefix = "_predict"
+ self.output_dtype = "int64"
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def predict_proba(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._method_prefix = "_predict_proba"
+ self.output_dtype = "float64"
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ @staticmethod
+ def suggest_params() -> dict:
+ """Get hyperparameters for an optuna trial."""
+ raise NotImplementedError
+
+ def _process_fit_input(self, X, **kwargs):
+ target_col = kwargs["fn_kwargs"]["extra_indices"][0]
+        if target_col is not None and "class" in (self.config.task or "").lower():
+ if isinstance(target_col, list):
+ target_col = target_col[0]
+
+ target_col = DataHandler.get_column_names(X, generate_cols=True)[target_col]
+
+ if self.config.n_classes is None:
+ self.config.n_classes = DataHandler.nunique(X, target_col)
+
+ if is_datasets_available():
+ from datasets import Dataset
+
+ if isinstance(X, Dataset):
+ feat = X._info.features[target_col]
+ self.config.class_names = feat.names
+
+ if self.config.n_classes > 2:
+ self.config.task = "multiclass_classification"
+ else:
+ self.config.task = "binary_classification"
+ return super()._process_fit_input(X, **kwargs)
+
+ def _process_fit_output(self, input, out):
+ if hasattr(self, "classes_"):
+ if self.config.class_names is not None and isinstance(
+ self.classes_[0], int
+ ):
+ if sorted(self.classes_.tolist()) == list(range(self.config.n_classes)):
+ self.config.class_names = [
+ self.config.class_names[i] for i in self.classes_
+ ]
+ else:
+ self.config.class_names = [str(i) for i in self.classes_]
+ else:
+ self.config.class_names = [str(i) for i in self.classes_]
+ return super()._process_fit_output(input, out)
+
+ def _process_transform_input(self, X, **kwargs):
+ if self._method_prefix == "_predict_proba":
+ self.config._n_features_out = self.config.n_classes
+ self.config._feature_names_out = self.config.class_names
+ self.output_dtype = "float64"
+ else:
+ if self.config.extra_names_in_ is not None:
+ self.config._feature_names_out = self.config.extra_names_in_[0]
+ self.config._n_features_out = 1
+ return super()._process_transform_input(X, **kwargs)
+
+ def _process_transform_batch_input(self, X, *fn_args, **fn_kwargs):
+ X, args, kwargs = super()._process_transform_batch_input(
+ X, *fn_args, **fn_kwargs
+ )
+ if DataHandler.supports_named_columns(X):
+ input_cols = DataHandler.get_column_names(X)
+ missing_cols = set(self.config.feature_names_in_) - set(input_cols)
+ intersecting_cols = set(input_cols) & set(self.config.feature_names_in_)
+
+            X_dims = DataHandler.get_shape(X)
+            num_cols = X_dims[1] if len(X_dims) > 1 else None
+
+            if num_cols is not None:
+                num_non_existing = num_cols - len(intersecting_cols)
+                if num_non_existing:
+                    logger.warning_once(
+                        f"Dataset has {num_non_existing} out of {num_cols} columns "
+                        "that were not in the training data. Dropping these columns."
+                    )
+                    X = DataHandler.select_columns(X, list(intersecting_cols))
+
+ if missing_cols:
+ if self.config._missing_val is None:
+ self.config._missing_val_str = "`None`"
+ elif self.config._missing_val == 0:
+ self.config._missing_val_str = "zeroes"
+ else:
+ self.config._missing_val_str = f"{self.config._missing_val}"
+ logger.warning_once(
+ f"Dataset is missing {len(missing_cols)} out of "
+ f"{len(self.config.feature_names_in_)} columns that were in the "
+ "training data. Adding these columns as "
+ f"{self.config._missing_val_str}."
+ )
+ num_rows = DataHandler.get_shape(X)[0]
+ zeros_mat = pa.table(
+ {
+ col: pa.array(
+ [self.config._missing_val] * num_rows,
+ type=self.config._missing_val_pa_type,
+ )
+ for col in missing_cols
+ }
+ )
+ X = DataHandler.concat(
+ [X, DataHandler.to_format(zeros_mat, DataHandler.get_format(X))],
+ axis=1,
+ )
+ # Reorder columns to match the training data
+ X = DataHandler.select_columns(X, self.config.feature_names_in_)
+
+ return X, args, kwargs
diff --git a/src/biofit/models/random_forest/__init__.py b/src/biofit/models/random_forest/__init__.py
new file mode 100644
index 0000000..5aaf26a
--- /dev/null
+++ b/src/biofit/models/random_forest/__init__.py
@@ -0,0 +1,7 @@
+# ruff: noqa
+from .random_forest import (
+ RandomForestModel,
+ RandomForestConfig,
+ RandomForestForClassification,
+ RandomForestForRegression,
+)
diff --git a/src/biofit/models/random_forest/random_forest.py b/src/biofit/models/random_forest/random_forest.py
new file mode 100644
index 0000000..f98269e
--- /dev/null
+++ b/src/biofit/models/random_forest/random_forest.py
@@ -0,0 +1,428 @@
+from dataclasses import dataclass, field
+from typing import List, Union
+
+from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+
+from ..models import Model, ModelConfig
+
+
+@dataclass
+class RandomForestConfig(ModelConfig):
+ _fit_input_feature_types: List[Union[None, type]] = field(
+ default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _fit_unused_feature_types: List[Union[None, type]] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _fit_process_desc: str = field(
+ default="Fitting the random forest model", init=False, repr=False
+ )
+ predict_proba_process_desc: str = field(
+ default=None,
+ init=False,
+ repr=False,
+ )
+ processor_name: str = field(default="random_forest", init=False, repr=False)
+
+ n_estimators: int = 100
+ max_depth: int = None
+ min_samples_split: int = 2
+ min_samples_leaf: int = 1
+ min_weight_fraction_leaf: float = 0.0
+ max_leaf_nodes: int = None
+ min_impurity_decrease: float = 0.0
+ bootstrap: bool = False
+ oob_score: bool = False
+ n_jobs: int = None
+ random_state: int = 42
+ verbose: int = 0
+ warm_start: bool = False
+ ccp_alpha: float = 0.0
+ max_samples: int = None
+ monotonic_cst: dict = None
+ criterion: str = None # gini for classification, squared_error for regression
+ max_features: Union[str, int, float] = "sqrt"
+
+ class_weight: Union[str, dict, List[dict]] = None
+ n_classes: int = None
+
+ use_predict_proba: bool = False
+ task: str = None
+
+ estimator: Union[RandomForestClassifier, RandomForestRegressor] = field(
+ default=None, init=False, repr=False
+ )
+
+
+class RandomForestModel(Model):
+ config_class = RandomForestConfig
+ config: RandomForestConfig
+ random_forest: Union[RandomForestClassifier, RandomForestRegressor]
+
+ def __init__(
+ self,
+ n_estimators: int = 100,
+ max_depth: int = None,
+ min_samples_split: int = 2,
+ min_samples_leaf: int = 1,
+ min_weight_fraction_leaf: float = 0.0,
+ max_leaf_nodes: int = None,
+ min_impurity_decrease: float = 0.0,
+ bootstrap: bool = True,
+ oob_score: bool = False,
+ n_jobs: int = None,
+ random_state: int = None,
+ verbose: int = 0,
+ warm_start: bool = False,
+ ccp_alpha: float = 0.0,
+ max_samples: int = None,
+ monotonic_cst: dict = None,
+ criterion: str = None,
+ max_features: Union[str, int, float] = "sqrt",
+ class_weight: Union[str, dict, List[dict]] = None,
+ task: str = None,
+ config: RandomForestConfig = None,
+ **kwargs,
+ ):
+ super().__init__(
+ config=config,
+ n_estimators=n_estimators,
+ max_depth=max_depth,
+ min_samples_split=min_samples_split,
+ min_samples_leaf=min_samples_leaf,
+ min_weight_fraction_leaf=min_weight_fraction_leaf,
+ max_leaf_nodes=max_leaf_nodes,
+ min_impurity_decrease=min_impurity_decrease,
+ bootstrap=bootstrap,
+ oob_score=oob_score,
+ n_jobs=n_jobs,
+ random_state=random_state,
+ verbose=verbose,
+ warm_start=warm_start,
+ ccp_alpha=ccp_alpha,
+ max_samples=max_samples,
+ monotonic_cst=monotonic_cst,
+ criterion=criterion,
+ max_features=max_features,
+ task=task,
+ class_weight=class_weight,
+ **kwargs,
+ )
+
+ if self.config.task == "classification":
+ self.random_forest = RandomForestClassifier(
+ n_estimators=self.config.n_estimators,
+ criterion=self.config.criterion if self.config.criterion else "gini",
+ max_depth=self.config.max_depth if self.config.max_depth else None,
+ min_samples_split=self.config.min_samples_split,
+ min_samples_leaf=self.config.min_samples_leaf,
+ min_weight_fraction_leaf=self.config.min_weight_fraction_leaf,
+ max_features=self.config.max_features,
+ max_leaf_nodes=self.config.max_leaf_nodes,
+ min_impurity_decrease=self.config.min_impurity_decrease,
+ bootstrap=self.config.bootstrap,
+ oob_score=self.config.oob_score,
+ n_jobs=self.config.n_jobs,
+ random_state=self.config.random_state,
+ verbose=self.config.verbose,
+ warm_start=self.config.warm_start,
+ class_weight=self.config.class_weight
+ if self.config.class_weight and self.config.class_weight != "None"
+ else None,
+ ccp_alpha=self.config.ccp_alpha,
+ max_samples=self.config.max_samples,
+ monotonic_cst=self.config.monotonic_cst,
+ )
+ else:
+ self.random_forest = RandomForestRegressor(
+ n_estimators=self.config.n_estimators,
+ criterion=self.config.criterion
+ if self.config.criterion
+ else "squared_error",
+                max_depth=self.config.max_depth if self.config.max_depth else None,
+ min_samples_split=self.config.min_samples_split,
+ min_samples_leaf=self.config.min_samples_leaf,
+ min_weight_fraction_leaf=self.config.min_weight_fraction_leaf,
+ max_features=self.config.max_features,
+ max_leaf_nodes=self.config.max_leaf_nodes,
+ min_impurity_decrease=self.config.min_impurity_decrease,
+ bootstrap=self.config.bootstrap,
+ oob_score=self.config.oob_score,
+ n_jobs=self.config.n_jobs,
+ random_state=self.config.random_state,
+ verbose=self.config.verbose,
+ warm_start=self.config.warm_start,
+ ccp_alpha=self.config.ccp_alpha,
+ max_samples=self.config.max_samples,
+ monotonic_cst=self.config.monotonic_cst,
+ )
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ if self.config.task == "classification":
+ self.random_forest = RandomForestClassifier(
+ n_estimators=self.config.n_estimators,
+ criterion=self.config.criterion if self.config.criterion else "gini",
+ max_depth=self.config.max_depth if self.config.max_depth else None,
+ min_samples_split=self.config.min_samples_split,
+ min_samples_leaf=self.config.min_samples_leaf,
+ min_weight_fraction_leaf=self.config.min_weight_fraction_leaf,
+ max_features=self.config.max_features,
+ max_leaf_nodes=self.config.max_leaf_nodes,
+ min_impurity_decrease=self.config.min_impurity_decrease,
+ bootstrap=self.config.bootstrap,
+ oob_score=self.config.oob_score,
+ n_jobs=self.config.n_jobs,
+ random_state=self.config.random_state,
+ verbose=self.config.verbose,
+ warm_start=self.config.warm_start,
+ class_weight=self.config.class_weight
+ if self.config.class_weight and self.config.class_weight != "None"
+ else None,
+ ccp_alpha=self.config.ccp_alpha,
+ max_samples=self.config.max_samples,
+ monotonic_cst=self.config.monotonic_cst,
+ )
+ else:
+ self.random_forest = RandomForestRegressor(
+ n_estimators=self.config.n_estimators,
+ criterion=self.config.criterion
+ if self.config.criterion
+ else "squared_error",
+                max_depth=self.config.max_depth if self.config.max_depth else None,
+ min_samples_split=self.config.min_samples_split,
+ min_samples_leaf=self.config.min_samples_leaf,
+ min_weight_fraction_leaf=self.config.min_weight_fraction_leaf,
+ max_features=self.config.max_features,
+ max_leaf_nodes=self.config.max_leaf_nodes,
+ min_impurity_decrease=self.config.min_impurity_decrease,
+ bootstrap=self.config.bootstrap,
+ oob_score=self.config.oob_score,
+ n_jobs=self.config.n_jobs,
+ random_state=self.config.random_state,
+ verbose=self.config.verbose,
+ warm_start=self.config.warm_start,
+ ccp_alpha=self.config.ccp_alpha,
+ max_samples=self.config.max_samples,
+ monotonic_cst=self.config.monotonic_cst,
+ )
+ if self.config.use_predict_proba:
+ self.output_dtype = "float32"
+ return self
+
+ @property
+ def feature_importances_(self):
+ return self.config.estimator.feature_importances_
+
+ @property
+ def feature_names_in_(self):
+ return self.config.feature_names_in_
+
+ def fit(
+ self,
+ X,
+ y=None,
+ input_columns: SelectedColumnTypes = None,
+ target_column: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "RandomForestModel":
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns, target_column
+ )
+ return self._process_fit(
+ X,
+ y,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def predict(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._method_prefix = "_predict"
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def predict_proba(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
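+        """Return class probabilities from the fitted forest.
+
+        A minimal sketch (``test_data`` is illustrative, not part of this
+        API)::
+
+            proba = model.predict_proba(test_data, output_format="numpy")
+        """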
+ self._method_prefix = "_predict_proba"
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _fit_sklearn(self, X, y):
+ self.config.estimator = self.random_forest.fit(X, y)
+ return self
+
+ def _predict_sklearn(self, X):
+ return self.config.estimator.predict(X)
+
+ def _predict_proba_sklearn(self, X):
+ return self.config.estimator.predict_proba(X)
+
+ @staticmethod
+ def suggest_params(data):
+ params = {
+ "n_estimators": ("suggest_int", (100, 2000), {}),
+ "max_depth": ("suggest_int", (0, 11), {}),
+ "min_samples_split": ("suggest_int", (2, 20), {}),
+ "min_samples_leaf": ("suggest_int", (1, 20), {}),
+ # "bootstrap": ("suggest_categorical", ([True, False],), {}),
+ # "max_features": ("suggest_categorical", (["sqrt", "log2"],), {}),
+ "n_jobs": -1,
+ }
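+        # Each tuple is presumably (suggest_method, args, kwargs) for an
+        # Optuna-style trial, e.g. ("suggest_int", (100, 2000), {}) becoming
+        # trial.suggest_int("n_estimators", 100, 2000); bare values such as
+        # "n_jobs": -1 are fixed parameters. (Inferred from the entry shapes;
+        # the consumer of this dict is not shown here.)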
+
+ return params
+
+ @staticmethod
+ def suggest_first_trial():
+ return {
+ "n_estimators": 100,
+ "max_depth": 0,
+ "min_samples_split": 2,
+ "min_samples_leaf": 1,
+ "n_jobs": -1,
+ }
+
+
+class RandomForestForClassification(RandomForestModel):
+ """Random Forest model for classification tasks"""
+
+ @property
+ def classes_(self):
+ return self.config.estimator.classes_
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ kwargs.pop("task", None)
+ self.config.task = "classification"
+ return super().set_params(**kwargs)
+
+ @staticmethod
+ def suggest_params(data):
+ params = RandomForestModel.suggest_params(data)
+ params["criterion"] = ("suggest_categorical", (["gini", "entropy"],), {})
+ params["class_weight"] = (
+ "suggest_categorical",
+ (["balanced", "balanced_subsample", "None"],),
+ {},
+ )
+ return params
+
+ @staticmethod
+ def suggest_first_trial():
+ return {
+ **RandomForestModel.suggest_first_trial(),
+ "criterion": "gini",
+ "class_weight": "balanced",
+ }
+
+
+class RandomForestForRegression(RandomForestModel):
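+    """Random Forest model for regression tasks"""
+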
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ kwargs.pop("task", None)
+ self.config.task = "regression"
+ return super().set_params(**kwargs)
+
+ @staticmethod
+ def suggest_params(data):
+ params = RandomForestModel.suggest_params(data)
+ params["criterion"] = (
+ "suggest_categorical",
+ (["squared_error", "absolute_error", "poisson"],),
+ {},
+ )
+ return params
diff --git a/src/biofit/preprocessing/__init__.py b/src/biofit/preprocessing/__init__.py
new file mode 100644
index 0000000..fe2129d
--- /dev/null
+++ b/src/biofit/preprocessing/__init__.py
@@ -0,0 +1,10 @@
+# ruff: noqa
+# from .auto import *
+from .feature_selection import *
+from .filtering import *
+from .scaling import *
+from .feature_extraction import *
+from .imputation import *
+from .encoding import *
+from .resampling import *
+from .transformation import *
diff --git a/src/biofit/preprocessing/encoding/__init__.py b/src/biofit/preprocessing/encoding/__init__.py
new file mode 100644
index 0000000..de1415a
--- /dev/null
+++ b/src/biofit/preprocessing/encoding/__init__.py
@@ -0,0 +1,3 @@
+# ruff: noqa
+from .label_binarizing import *
+from .label_encoding import *
diff --git a/src/biofit/preprocessing/encoding/encoding.py b/src/biofit/preprocessing/encoding/encoding.py
new file mode 100644
index 0000000..9f5e44e
--- /dev/null
+++ b/src/biofit/preprocessing/encoding/encoding.py
@@ -0,0 +1,14 @@
+from dataclasses import dataclass, field
+
+from biofit.processing import BaseProcessor, ProcessorConfig
+
+
+@dataclass
+class EncoderConfig(ProcessorConfig):
+ """Base class for feature extraction processor configurations."""
+
+ processor_type: str = field(default="encoding", init=False, repr=False)
+
+
+class Encoder(BaseProcessor):
+ """Base class for feature extraction processors."""
diff --git a/src/biofit/preprocessing/encoding/label_binarizing/__init__.py b/src/biofit/preprocessing/encoding/label_binarizing/__init__.py
new file mode 100644
index 0000000..c9ebf6a
--- /dev/null
+++ b/src/biofit/preprocessing/encoding/label_binarizing/__init__.py
@@ -0,0 +1,2 @@
+# ruff: noqa
+from .label_binarizing import LabelBinarizer, LabelBinarizerConfig
diff --git a/src/biofit/preprocessing/encoding/label_binarizing/label_binarizing.py b/src/biofit/preprocessing/encoding/label_binarizing/label_binarizing.py
new file mode 100644
index 0000000..362c926
--- /dev/null
+++ b/src/biofit/preprocessing/encoding/label_binarizing/label_binarizing.py
@@ -0,0 +1,405 @@
+from dataclasses import dataclass, field
+from typing import List, Type, Union
+
+from biocore.utils.import_util import is_biosets_available
+import numpy as np
+import pyarrow as pa
+from biocore import DataHandler
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.utils import logging
+
+from ..encoding import Encoder, EncoderConfig
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class LabelBinarizerConfig(EncoderConfig):
+ """Configuration class for label binarizer."""
+
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("ClassLabel")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("ClassLabel")], init=False, repr=False
+ )
+
+ processor_name: str = field(default="label_binarizing", init=False, repr=False)
+ _transform_process_desc: str = field(default=None, init=False, repr=False)
+ positive_labels: list = field(default_factory=list)
+ negative_labels: list = field(default_factory=list)
+ as_one_hot: bool = False
+ names: list = field(default_factory=list, init=False, repr=False)
+
+ # auto attributes
+ label_mapping: dict = field(default_factory=dict, init=False, repr=False)
+
+ def __post_init__(self):
+ if self.as_one_hot:
+ self._transform_process_desc = "One-hot encoding labels"
+ self._fit_process_desc = "Determining unique labels"
+ if self.names:
+ self._transform_process_desc = (
+ f"One-hot encoding labels with {self.names}"
+ )
+ self.label_mapping = {
+ str(label): i for i, label in enumerate(self.names)
+ }
+ else:
+            if not isinstance(self.positive_labels, list):
+                self.positive_labels = [self.positive_labels]
+            if not isinstance(self.negative_labels, list):
+                self.negative_labels = [self.negative_labels]
+            if not self.positive_labels and not self.negative_labels:
+                raise ValueError(
+                    "At least one of positive_labels or negative_labels must be provided"
+                )
+
+            if self.positive_labels:
+                self._transform_process_desc = (
+                    f"Binarizing labels with positive_labels={self.positive_labels}"
+                )
+                if self.negative_labels:
+                    self._transform_process_desc += (
+                        f" and negative_labels={self.negative_labels}"
+                    )
+            else:
+                self._transform_process_desc = (
+                    f"Binarizing labels with negative_labels={self.negative_labels}"
+                )
+
+        # Runs for both branches; for one-hot encoding the mapping is rebuilt
+        # from `names` at fit time.
+        self.label_mapping = {str(label): 1 for label in self.positive_labels}
+        self.label_mapping.update({str(label): 0 for label in self.negative_labels})
+
+
+class LabelBinarizer(Encoder):
+ config_class = LabelBinarizerConfig
+ config: LabelBinarizerConfig
+
+ def __init__(
+ self,
+ positive_labels: list = [],
+ negative_labels: list = [],
+ names: list = [],
+ as_one_hot: bool = False,
+ config: LabelBinarizerConfig = None,
+ **kwargs,
+ ):
+ """
+ Args:
+ positive_labels (list, *optional*):
+ The labels to be considered as positive. Default is `[]`.
+ negative_labels (list, *optional*):
+ The labels to be considered as negative. Default is `[]`.
+ as_one_hot (bool, *optional*):
+ Whether to encode the labels as one-hot vectors. Default is `False`.
+ names (list, *optional*):
+ The names of the classes. This is required when `as_one_hot` is `True`.
+ config (LabelBinarizerConfig, *optional*):
+ **kwargs:
+ Arguments that are passed to ProcessorConfig.
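+
+        Example (a minimal sketch; ``table`` and the column name are
+        illustrative)::
+
+            binarizer = LabelBinarizer(
+                positive_labels=["case"], negative_labels=["control"]
+            )
+            encoded = binarizer.fit_transform(table, input_columns="label")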
+ """
+ super().__init__(
+ config=config,
+ positive_labels=positive_labels,
+ negative_labels=negative_labels,
+ as_one_hot=as_one_hot,
+ names=names,
+ **kwargs,
+ )
+ self.output_feature_type = get_feature("BinClassLabel")
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ self.config.__post_init__()
+ return self
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "LabelBinarizer":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _get_features_out(self, X, **kwargs):
+ kwargs.pop("one_to_one_features", None)
+ features = super()._get_features_out(X, one_to_one_features=True, **kwargs)
+
+ input_names = DataHandler.get_column_names(X, generate_cols=True)
+ input_names = [input_names[i] for i in self._selected_indices]
+ if self.config.as_one_hot and self.config.names:
+            # in case the user provided pre-encoded labels of get_feature("ClassLabel")
+ for name in input_names:
+ label_feature = features.pop(name, None)
+ if isinstance(label_feature, get_feature("ClassLabel")):
+ label_names = label_feature.names
+ if self.config.names and isinstance(self.config.names[0], str):
+ self.config.label_mapping = {
+ str(label_names.index(label)): i
+ for i, label in enumerate(self.config.names)
+ }
+ features.update(
+ {
+ name: get_feature("ClassLabel")(names=["negative", "positive"])
+ for name in self.config.names
+ }
+ )
+ return features
+
+ if is_biosets_available():
+ for name in input_names:
+                # in case the user provided pre-encoded labels of get_feature("ClassLabel")
+ label_feature = features.get(name)
+ if isinstance(label_feature, get_feature("ClassLabel")):
+ label_names = label_feature.names
+ if self.config.positive_labels and isinstance(
+ self.config.positive_labels[0], str
+ ):
+ self.config.label_mapping = {
+ str(label_names.index(label)): 1
+ for label in self.config.positive_labels
+ }
+ if self.config.negative_labels and isinstance(
+ self.config.negative_labels[0], str
+ ):
+ self.config.label_mapping.update(
+ {
+ str(label_names.index(label)): 0
+ for label in self.config.negative_labels
+ }
+ )
+ label_feature = get_feature("BinClassLabel")(
+ names=self.config.names or ["negative", "positive"],
+ positive_labels=self.config.positive_labels,
+ negative_labels=self.config.negative_labels,
+ )
+ features[name] = label_feature
+
+ return features
+
+ def _process_fit_input(self, input, **kwargs):
+ if not self.config.as_one_hot or self.config.names:
+ kwargs["fn_kwargs"]["fn"] = None
+ if self.config.names:
+ self.config.label_mapping = {
+ str(label): i for i, label in enumerate(self.config.names)
+ }
+ return super()._process_fit_input(input, **kwargs)
+
+ def _check_data(self, X):
+ X_dims = DataHandler.get_shape(X)
+ if len(X_dims) > 1:
+ if X_dims[1] == 1:
+ out = DataHandler.select_column(X, 0)
+ else:
+ raise ValueError(
+ f"Expected input to have 1 column, got {X_dims[1]} columns"
+ )
+ else:
+ out = X
+ return DataHandler.to_format(out, "list")
+
+ def _fit_array(self, X):
+ out = self._check_data(X)
+ self.config.names = DataHandler.to_format(DataHandler.unique(out), "list")
+ self.config._transform_process_desc = (
+ f"Encoding labels with: {self.config.label_mapping}"
+ )
+ return self
+
+ def _partial_fit_array(self, X):
+ out = self._check_data(X)
+ labs = DataHandler.to_format(DataHandler.unique(out), "list")
+ self.config.names = np.unique(self.config.names + labs).tolist()
+ return self
+
+ def _pool_fit(self, fitted_processors):
+ names = []
+ for processor in fitted_processors:
+ if isinstance(processor, LabelBinarizer):
+ names += processor.config.names
+ self.config.names = np.unique(names).tolist()
+ self.config.label_mapping = {
+ str(label): i for i, label in enumerate(self.config.names)
+ }
+ self.config._transform_process_desc = (
+ f"Encoding labels with: {self.config.label_mapping}"
+ )
+ return self
+
+ def _process_fit_output(self, input, out):
+ if self.config.as_one_hot:
+ self.config._n_features_out = len(self.config.names)
+ self.output_dtype = "int64"
+ else:
+ self.config._n_features_out = None
+ self.config.label_mapping = {
+ str(k): v for k, v in self.config.label_mapping.items()
+ }
+ return super()._process_fit_output(input, out)
+
+    def _one_hot_transform(self, X):
+        col = self._check_data(X)
+
+        # Allocate one extra column as a sink for unseen labels: unknown
+        # values map to index -1, i.e. the last column, which is dropped
+        # before returning.
+        labs = np.zeros((len(col), len(self.config.names) + 1), dtype=np.int64)
+        inds = [self.config.label_mapping.get(str(val), -1) for val in col]
+        labs[np.arange(len(col)), inds] = 1
+        return labs[:, :-1]
+
+    def _binarize_transform(self, X: Union[pa.Table, pa.Array]):
+        col = self._check_data(X)
+
+        # Fallback for labels absent from the mapping: -1 (unassigned) when
+        # both positive and negative labels were given, 0 when only positive
+        # labels were given, and 1 when only negative labels were given.
+        if self.config.positive_labels and self.config.negative_labels:
+            default = -1
+        elif self.config.positive_labels:
+            default = 0
+        else:
+            default = 1
+
+        labs = [self.config.label_mapping.get(str(val), default) for val in col]
+        return pa.array(labs)
+
+ def _transform_array(self, X):
+ if not self.config.as_one_hot:
+ return self._binarize_transform(X)
+ else:
+ return self._one_hot_transform(X)
diff --git a/src/biofit/preprocessing/encoding/label_encoding/__init__.py b/src/biofit/preprocessing/encoding/label_encoding/__init__.py
new file mode 100644
index 0000000..cbb513d
--- /dev/null
+++ b/src/biofit/preprocessing/encoding/label_encoding/__init__.py
@@ -0,0 +1,2 @@
+# ruff: noqa
+from .label_encoding import LabelEncoder, LabelEncoderConfig
diff --git a/src/biofit/preprocessing/encoding/label_encoding/label_encoding.py b/src/biofit/preprocessing/encoding/label_encoding/label_encoding.py
new file mode 100644
index 0000000..dce470a
--- /dev/null
+++ b/src/biofit/preprocessing/encoding/label_encoding/label_encoding.py
@@ -0,0 +1,243 @@
+from dataclasses import dataclass, field
+from typing import List, Type
+
+import numpy as np
+from biocore import DataHandler
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.utils import logging
+
+from ..encoding import Encoder, EncoderConfig
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class LabelEncoderConfig(EncoderConfig):
+ """Configuration class for label encoding."""
+
+ _fit_process_desc: str = field(
+ default="Determining unique labels", init=False, repr=False
+ )
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("ClassLabel")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("ClassLabel")], init=False, repr=False
+ )
+
+ processor_name: str = field(default="label_encoding", init=False, repr=False)
+ _transform_process_desc: str = field(default=None, init=False, repr=False)
+
+ names: list = field(default_factory=list, init=False, repr=False)
+
+ # auto attributes
+ label_mapping: dict = field(default_factory=dict, init=False, repr=False)
+
+ def __post_init__(self):
+ if self.names:
+ self._transform_process_desc = f"Encoding labels with {self.names}"
+ self.label_mapping = {label: i for i, label in enumerate(self.names)}
+
+
+class LabelEncoder(Encoder):
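+    """Encode categorical labels as integer codes.
+
+    A minimal sketch (``table`` and the column name are illustrative)::
+
+        encoder = LabelEncoder()
+        encoded = encoder.fit_transform(table, input_columns="label")
+
+    Labels unseen during fitting are encoded as ``-1``.
+    """
+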
+ config_class = LabelEncoderConfig
+ config: LabelEncoderConfig
+ output_feature_type = get_feature("ClassLabel")
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ self.config.__post_init__()
+ return self
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "LabelEncoder":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _get_features_out(self, X, **kwargs):
+ features = super()._get_features_out(X, **kwargs)
+ if features:
+ input_names = self.config.feature_names_in_ or self.config.feature_idx_in_
+ for name in input_names:
+ label_feature = features.get(name)
+ if isinstance(label_feature, get_feature("ClassLabel")):
+ label_names = label_feature.names
+ if self.config.names and isinstance(self.config.names[0], str):
+ self.config.label_mapping = {
+ label_names.index(label): i
+ for i, label in enumerate(self.config.names)
+ }
+
+ label_feature = get_feature("ClassLabel")(names=self.config.names)
+ features[name] = label_feature
+
+ return features
+
+ def _check_data(self, X):
+ X_dims = DataHandler.get_shape(X)
+ if len(X_dims) > 1:
+ if X_dims[1] == 1:
+ out = DataHandler.select_column(X, 0)
+ else:
+ raise ValueError(
+ f"Expected input to have 1 column, got {X_dims[1]} columns"
+ )
+ else:
+ out = X
+ return out
+
+ def _process_fit_input(self, input, **kwargs):
+ if self.config.names:
+ kwargs["fn_kwargs"]["fn"] = None
+ return super()._process_fit_input(input, **kwargs)
+
+ def _fit_array(self, X):
+ out = self._check_data(X)
+ self.config.names = DataHandler.to_format(DataHandler.unique(out), "list")
+ self.config._transform_process_desc = (
+ f"Encoding labels with: {self.config.label_mapping}"
+ )
+ return self
+
+ def _partial_fit_array(self, X):
+ out = self._check_data(X)
+ labs = DataHandler.to_format(DataHandler.unique(out), "list")
+ self.config.names = np.unique(self.config.names + labs).tolist()
+ return self
+
+ def _pool_fit(self, fitted_processors):
+ names = []
+ for processor in fitted_processors:
+ if isinstance(processor, LabelEncoder):
+ names += processor.config.names
+ self.config.names = np.unique(names).tolist()
+ self.config.label_mapping = {
+ label: i for i, label in enumerate(self.config.names)
+ }
+ self.config._transform_process_desc = (
+ f"Encoding labels with: {self.config.label_mapping}"
+ )
+ return self
+
+ def _transform_array(self, X):
+ labs = self._check_data(X)
+ labs = DataHandler.to_format(labs, "list")
+ return [self.config.label_mapping.get(lab, -1) for lab in labs]
diff --git a/src/biofit/preprocessing/feature_extraction/__init__.py b/src/biofit/preprocessing/feature_extraction/__init__.py
new file mode 100644
index 0000000..54f0582
--- /dev/null
+++ b/src/biofit/preprocessing/feature_extraction/__init__.py
@@ -0,0 +1,3 @@
+# ruff: noqa
+from .pcoa import *
+from .pca import *
diff --git a/src/biofit/preprocessing/feature_extraction/feature_extraction.py b/src/biofit/preprocessing/feature_extraction/feature_extraction.py
new file mode 100644
index 0000000..556aa27
--- /dev/null
+++ b/src/biofit/preprocessing/feature_extraction/feature_extraction.py
@@ -0,0 +1,25 @@
+from dataclasses import dataclass, field
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import BaseProcessor, ProcessorConfig, SelectedFeatureTypes
+
+
+@dataclass
+class FeatureExtractorConfig(ProcessorConfig):
+ """Base class for feature extraction processor configurations."""
+
+ processor_type: str = field(default="feature_extraction", init=False, repr=False)
+ _fit_unused_feature_types: SelectedFeatureTypes = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: SelectedFeatureTypes = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+
+
+class FeatureExtractor(BaseProcessor):
+ """Base class for feature extraction processors."""
diff --git a/src/biofit/preprocessing/feature_extraction/pca/__init__.py b/src/biofit/preprocessing/feature_extraction/pca/__init__.py
new file mode 100644
index 0000000..a3c48f4
--- /dev/null
+++ b/src/biofit/preprocessing/feature_extraction/pca/__init__.py
@@ -0,0 +1,5 @@
+# ruff: noqa
+from .pca import (
+ PCAFeatureExtractor,
+ PCAFeatureExtractorConfig,
+)
diff --git a/src/biofit/preprocessing/feature_extraction/pca/pca.py b/src/biofit/preprocessing/feature_extraction/pca/pca.py
new file mode 100644
index 0000000..b2a9b95
--- /dev/null
+++ b/src/biofit/preprocessing/feature_extraction/pca/pca.py
@@ -0,0 +1,235 @@
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Union
+
+import numpy as np
+import pandas as pd
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.utils import logging
+from sklearn.decomposition import PCA
+
+from ..feature_extraction import FeatureExtractor, FeatureExtractorConfig
+
+if TYPE_CHECKING:
+ import polars as pl
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class PCAFeatureExtractorConfig(FeatureExtractorConfig):
+ _fit_process_desc: str = field(
+ default="Determining the principal components", init=False, repr=False
+ )
+ _transform_process_desc: str = field(
+ default="Transforming the input data to principal components",
+ init=False,
+ repr=False,
+ )
+ processor_name: str = field(default="pca", init=False, repr=False)
+
+ input_columns: List[str] = None
+ n_components: int = None
+ copy: bool = True
+ whiten: bool = False
+ svd_solver: str = "auto"
+ tol: float = 0.0
+ iterated_power: str = "auto"
+ n_oversamples: int = 10
+ power_iteration_normalizer: str = "auto"
+ random_state: int = None
+
+
+class PCAFeatureExtractor(FeatureExtractor):
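+    """Principal component feature extraction built on scikit-learn's PCA.
+
+    A minimal sketch (``table`` is illustrative)::
+
+        extractor = PCAFeatureExtractor(n_components=2)
+        components = extractor.fit_transform(table, output_format="pandas")
+    """
+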
+ output_dtype = "float64"
+ config_class = PCAFeatureExtractorConfig
+ config: PCAFeatureExtractorConfig
+
+ def __init__(
+ self,
+ input_columns: List[str] = None,
+ n_components: int = None,
+ copy: bool = True,
+ whiten: bool = False,
+ svd_solver: str = "auto",
+ tol: float = 0.0,
+ iterated_power: str = "auto",
+ n_oversamples: int = 10,
+ power_iteration_normalizer: str = "auto",
+ random_state: int = None,
+ config: PCAFeatureExtractorConfig = None,
+ **kwargs,
+ ):
+ super().__init__(
+ config=config,
+ input_columns=input_columns,
+ n_components=n_components,
+ copy=copy,
+ whiten=whiten,
+ svd_solver=svd_solver,
+ tol=tol,
+ iterated_power=iterated_power,
+ n_oversamples=n_oversamples,
+ power_iteration_normalizer=power_iteration_normalizer,
+ random_state=random_state,
+ **kwargs,
+ )
+ self.pca = PCA(
+ n_components=self.config.n_components,
+ copy=self.config.copy,
+ whiten=self.config.whiten,
+ svd_solver=self.config.svd_solver,
+ tol=self.config.tol,
+ iterated_power=self.config.iterated_power,
+ n_oversamples=self.config.n_oversamples,
+ power_iteration_normalizer=self.config.power_iteration_normalizer,
+ random_state=self.config.random_state,
+ )
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ self.pca = PCA(
+ n_components=self.config.n_components,
+ copy=self.config.copy,
+ whiten=self.config.whiten,
+ svd_solver=self.config.svd_solver,
+ tol=self.config.tol,
+ iterated_power=self.config.iterated_power,
+ n_oversamples=self.config.n_oversamples,
+ power_iteration_normalizer=self.config.power_iteration_normalizer,
+ random_state=self.config.random_state,
+ )
+ return self
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "PCAFeatureExtractor":
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns or self.config.input_columns
+ )
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(
+ input_columns or self.config.input_columns
+ )
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _fit_sklearn(self, X: Union[pd.DataFrame, "pl.DataFrame", np.ndarray]):
+ self.config.estimator = self.pca.fit(X)
+ return self
+
+ def _process_fit_output(self, input, out):
+ self.config._n_features_out = self.config.estimator.n_components_
+ return super()._process_fit_output(input, out)
+
+ def _transform_sklearn(self, X: Union[pd.DataFrame, "pl.DataFrame", np.ndarray]):
+ return self.config.estimator.transform(X)
diff --git a/src/biofit/preprocessing/feature_extraction/pca/plot_pca.py b/src/biofit/preprocessing/feature_extraction/pca/plot_pca.py
new file mode 100644
index 0000000..504938e
--- /dev/null
+++ b/src/biofit/preprocessing/feature_extraction/pca/plot_pca.py
@@ -0,0 +1,82 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+from biofit.utils.types import Unset
+from biofit.visualization.plotting_utils import plot_dimension_reduction
+
+from ..plot_feature_extraction import (
+ FeatureExtractorPlotter,
+ FeatureExtractorPlotterConfig,
+)
+
+
+@dataclass
+class PCAFeatureExtractorPlotterConfig(FeatureExtractorPlotterConfig):
+ processor_name: str = field(default="pca", init=False, repr=False)
+ title: str = "PCA Plot"
+ n_components: int = 3
+ label_column: str = None
+ group_column: str = None
+
+
+class PCAFeatureExtractorPlotter(FeatureExtractorPlotter):
+ config_class = PCAFeatureExtractorPlotterConfig
+ config: PCAFeatureExtractorPlotterConfig
+
+ def __init__(
+ self,
+ n_components: int = 3,
+ label_column: str = None,
+ group_column: str = None,
+ config: Optional[PCAFeatureExtractorPlotterConfig] = None,
+ **kwargs,
+ ):
+ super().__init__(
+ config=config,
+ n_components=n_components,
+ label_column=label_column,
+ group_column=group_column,
+ **kwargs,
+ )
+
+ def plot(
+ self,
+ X,
+ labels=None,
+ group=None,
+ input_columns=None,
+ label_column=None,
+ group_column=None,
+ precomputed: bool = Unset("False"),
+ pca_kwargs: dict = Unset("{}"),
+ title: str = Unset("PCA Plot"),
+ n_components: int = Unset(3),
+ path=None,
+ **kwargs,
+ ):
+ """Plot the PCA plot.
+
+ Args:
+ X: The input data.
+ labels: The labels for the data.
+ group: The group for the data.
+ input_columns: The input columns.
+ label_column: The label column.
+ group_column: The group column.
+ precomputed: Whether the input data is precomputed.
+ pca_kwargs: The additional kwargs for computing PCA. Only used if precomputed is False.
+ **kwargs: The additional kwargs.
+
+ """
+ return plot_dimension_reduction(
+ X,
+ labels=labels,
+ group=group,
+ input_columns=input_columns,
+ label_column=label_column,
+ group_column=group_column,
+ method="pca" if not precomputed else None,
+ method_kwargs=pca_kwargs,
+ output_dir=path,
+ **kwargs,
+ )
diff --git a/src/biofit/preprocessing/feature_extraction/pcoa/__init__.py b/src/biofit/preprocessing/feature_extraction/pcoa/__init__.py
new file mode 100644
index 0000000..d1cbef5
--- /dev/null
+++ b/src/biofit/preprocessing/feature_extraction/pcoa/__init__.py
@@ -0,0 +1,10 @@
+# ruff: noqa
+from .pcoa import (
+ PCoAFeatureExtractor,
+ PCoAFeatureExtractorConfig,
+ PCoAFeatureExtractorConfigForOTU,
+)
+from .plot_pcoa import (
+ PCoAFeatureExtractorPlotter,
+ PCoAFeatureExtractorPlotterConfig,
+)
diff --git a/src/biofit/preprocessing/feature_extraction/pcoa/pcoa.py b/src/biofit/preprocessing/feature_extraction/pcoa/pcoa.py
new file mode 100644
index 0000000..aedd511
--- /dev/null
+++ b/src/biofit/preprocessing/feature_extraction/pcoa/pcoa.py
@@ -0,0 +1,413 @@
+# Portions of this module are derived from the Pyckmeans project available at:
+# https://github.com/TankredO/pyckmeans
+
+# Pyckmeans is licensed under the MIT License. The following is a copy of the original license under which the Pyckmeans software is distributed:
+
+# MIT License
+
+# Copyright (c) 2021 Tankred Ott
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+"""Principal Coordinate Analysis (PCoA) feature extraction module."""
+
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Optional, Type
+
+import numpy as np
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.stat import DistanceStat, DistanceStatConfig
+from biofit.utils import logging
+
+from ..feature_extraction import FeatureExtractor, FeatureExtractorConfig
+
+if TYPE_CHECKING:
+ pass
+
+logger = logging.get_logger(__name__)
+
+
+def _center_mat(dmat: np.ndarray) -> np.ndarray:
+ """_center_mat
+
+ Center n*n matrix.
+
+ Parameters
+ ----------
+ dmat : np.ndarray
+ n*n matrix.
+
+ Returns
+ -------
+ np.ndarray
+ Centered matrix.
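+
+    Notes
+    -----
+    Computes the double centering ``J @ dmat @ J`` with
+    ``J = I - (1/n) * ones((n, n))``.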
+ """
+
+ n = dmat.shape[0]
+ mat = np.full((n, n), -1 / n)
+ mat[np.diag_indices(n)] += 1
+
+ return mat.dot(dmat).dot(mat)
+
+
+class InvalidCorrectionTypeError(Exception):
+ """InvalidCorrectionTypeError"""
+
+
+class NegativeEigenvaluesCorrectionError(Exception):
+ """FailedCorrectionError
+
+ Error, signalling that the correction of negative eigenvalues failed.
+ """
+
+
+class NegativeEigenvaluesWarning(Warning):
+ """NegativeEigenvaluesWarning
+
+ Warning, signalling that negative eigenvalues were encountered.
+ """
+
+
+def pcoa(
+ x: np.ndarray,
+ correction: Optional[str] = None,
+ eps: float = 1e-8,
+):
+ """pcoa
+
+ Principle Coordinate Analysis.
+
+ Parameters
+ ----------
+ dist : Union[np.ndarray, pyckmeans.distance.DistanceMatrix]
+ n*n distance matrix either as np ndarray or as pyckmeans DistanceMatrix.
+ correction: Optional[str]
+ Correction for negative eigenvalues, by default None.
+ Available corrections are:
+ - None: negative eigenvalues are set to 0
+ - lingoes: Lingoes correction
+ - cailliez: Cailliet correction
+ eps : float, optional
+ Eigenvalues smaller than eps will be dropped. By default 0.0001
+
+ Returns
+ -------
+ Tuple[np.ndarray, np.ndarray]
+ Eigenvalues and eigenvectors.
+
+ Raises
+ ------
+ InvalidCorrectionTypeError
+ Raised if an unknown correction type is passed.
+ NegativeEigenvaluesCorrectionError
+ Raised if correction parameter is set and correction of negative
+ eigenvalues is not successful.
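+
+    Examples
+    --------
+    A minimal sketch (the distance values are illustrative only):
+
+    >>> import numpy as np
+    >>> dist = np.array(
+    ...     [[0.0, 1.0, 2.0], [1.0, 0.0, 1.5], [2.0, 1.5, 0.0]]
+    ... )
+    >>> eigvals, coords = pcoa(dist)
+    >>> coords.shape[0] == dist.shape[0]
+    True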
+ """
+
+ if correction is not None and correction not in ["lingoes", "cailliez"]:
+ msg = (
+ f'Unknown correction type "{correction}". '
+ + 'Available correction types are: "lingoes", "cailliez"'
+ )
+ raise InvalidCorrectionTypeError(msg)
+
+ # center matrix
+ dmat_centered = _center_mat((x * x) / -2)
+
+ # eigen decomposition
+ eigvals, eigvecs = np.linalg.eigh(dmat_centered, "U")
+
+ # order descending
+ ord_idcs = np.argsort(eigvals)[::-1]
+ eigvals = eigvals[ord_idcs]
+ eigvecs = eigvecs[:, ord_idcs]
+
+ # get min eigenvalue
+ min_eigval = np.min(eigvals)
+
+ # set small eigenvalues to 0
+ zero_eigval_idcs = np.nonzero(np.abs(eigvals) < eps)[0]
+ eigvals[zero_eigval_idcs] = 0
+
+ # no negative eigenvalues
+ if min_eigval > -eps:
+ fze_idx = len(np.nonzero(eigvals > eps)[0]) # index of first zero in eigvals
+ vectors = eigvecs[:, :fze_idx] * np.sqrt(eigvals[:fze_idx])
+
+ return eigvals, vectors
+
+ # negative eigenvalues
+ else:
+ fze_idx = len(np.nonzero(eigvals > eps)[0]) # index of first zero in eigvals
+ vectors = eigvecs[:, :fze_idx] * np.sqrt(eigvals[:fze_idx])
+
+ # negative eigenvalues, no correction
+ if not correction:
+            logger.warning(
+ "Negative eigenvalues encountered but no correction applied. "
+ "Negative eigenvalues will be treated as 0."
+ )
+
+ return eigvals, vectors
+
+ # negative eigenvalues, correction
+
+ # -- correct distance matrix
+ # lingoes correction
+ if correction == "lingoes":
+ corr_1 = -min_eigval
+
+ # corrected distance matrix
+ x_ncorr = -0.5 * ((x * x) + 2 * corr_1)
+ elif correction == "cailliez":
+ dmat_centered_2 = _center_mat(-0.5 * x)
+
+ # prepare matrix for correction
+ upper = np.c_[np.zeros((x.shape[0], x.shape[0])), 2 * dmat_centered]
+ lower = np.c_[np.diag(np.full(x.shape[0], -1)), -4 * dmat_centered_2]
+ sp_mat = np.r_[upper, lower]
+
+ corr_2 = np.max(np.real(np.linalg.eigvals(sp_mat)))
+
+ # corrected distance matrix
+ x_ncorr = -0.5 * (x + corr_2) ** 2
+
+ # -- apply PCoA to corrected distance matrix
+ x_ncorr[np.diag_indices(x_ncorr.shape[0])] = 0
+ x_ncorr = _center_mat(x_ncorr)
+
+ eigvals_ncorr, eigvecs_ncorr = np.linalg.eigh(x_ncorr, "U")
+
+ # order descending
+ ord_idcs_ncorr = np.argsort(eigvals_ncorr)[::-1]
+ eigvals_ncorr = eigvals_ncorr[ord_idcs_ncorr]
+ eigvecs_ncorr = eigvecs_ncorr[:, ord_idcs_ncorr]
+
+ # get min eigenvalue
+ min_eigval_ncorr = np.min(eigvals_ncorr)
+
+ # set small eigenvalues to 0
+ zero_eigval_idcs_ncorr = np.nonzero(np.abs(eigvals_ncorr) < eps)[0]
+ eigvals_ncorr[zero_eigval_idcs_ncorr] = 0
+
+ if min_eigval_ncorr < -eps:
+ msg = (
+ "Correction failed. There are still negative eigenvalues after applying "
+ + f"{correction.capitalize()} correction."
+ )
+ raise NegativeEigenvaluesCorrectionError(msg)
+
+ fze_idx_ncorr = len(
+ np.nonzero(eigvals_ncorr > eps)[0]
+ ) # index of first zero in eigvals
+ vectors_ncorr = eigvecs_ncorr[:, :fze_idx_ncorr] * np.sqrt(
+ eigvals_ncorr[:fze_idx_ncorr]
+ )
+
+ return eigvals_ncorr, vectors_ncorr
+
+
+@dataclass
+class PCoAFeatureExtractorConfig(FeatureExtractorConfig):
+ processor_name: str = field(default="pcoa", init=False, repr=False)
+ output_template_name: str = field(default="Dim{i+1}", init=False, repr=False)
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _fit_process_desc: str = field(
+ default="",
+ init=False,
+ repr=False,
+ )
+ _transform_process_desc: str = field(
+ default="Applying Principal Coordinate Analysis (PCoA) to the input data",
+ init=False,
+ repr=False,
+ )
+ n_components: int = None
+ correction: str = None
+ eps: float = 1e-3
+
+ metric: str = "braycurtis"
+ p: float = 2
+ w: float = None
+ V: float = None
+ VI: float = None
+ squareform: bool = True
+
+ # fitted attributes
+ vectors: np.ndarray = None
+ eigvals: np.ndarray = None
+
+ def __post_init__(self):
+ self._fit_process_desc = f"Calculating PCoA using {self.metric} distance"
+
+
+@dataclass
+class PCoAFeatureExtractorConfigForOTU(PCoAFeatureExtractorConfig):
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ dataset_name: str = field(default="otu", init=False, repr=False)
+ correction: str = "cailliez"
+
+
+class PCoAFeatureExtractor(FeatureExtractor):
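+    """Principal coordinate feature extraction from pairwise sample distances.
+
+    A minimal sketch (``table`` is illustrative)::
+
+        extractor = PCoAFeatureExtractor(n_components=3, metric="braycurtis")
+        coords = extractor.fit(table).transform(table)
+    """
+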
+ output_dtype = "float64"
+
+ config_class = PCoAFeatureExtractorConfig
+ config: PCoAFeatureExtractorConfig
+
+ def __init__(
+ self,
+ n_components: int = None,
+ correction: str = None,
+ eps: float = 1e-3,
+ metric: str = "braycurtis",
+ p: float = 2,
+ w: float = None,
+ V: float = None,
+ VI: float = None,
+ squareform: bool = True,
+ config: Optional[PCoAFeatureExtractorConfig] = None,
+ **kwargs,
+ ):
+ super().__init__(
+ config=config,
+ n_components=n_components,
+ correction=correction,
+ eps=eps,
+ metric=metric,
+ p=p,
+ w=w,
+ V=V,
+ VI=VI,
+ squareform=squareform,
+ **kwargs,
+ )
+ distance_config = DistanceStatConfig.from_config(self.config)
+ self.distance = DistanceStat(distance_config)
+ self.config._n_features_out = self.config.n_components
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ distance_config = DistanceStatConfig.from_config(self.config)
+ self.distance = DistanceStat(distance_config)
+ self.config._n_features_out = self.config.n_components
+ return self
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "PCoAFeatureExtractor":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _fit_numpy(self, X: "np.ndarray"):
+ self.config.eigvals, self.config.vectors = pcoa(
+ self.distance._transform_numpy(X),
+ correction=self.config.correction,
+ eps=self.config.eps,
+ )
+
+ return self
+
+ def _process_fit_output(self, input, output):
+ if self.config.n_components is None:
+ self.config.n_components = self.config.vectors.shape[1]
+ self.config._n_features_out = self.config.n_components
+ return super()._process_fit_output(input, output)
+
+    def _transform_numpy(self, X: "np.ndarray"):
+        # No out-of-sample projection here: return the coordinates computed
+        # during fit, truncated to the first n_components dimensions.
+        return self.config.vectors[:, : self.config.n_components]
diff --git a/src/biofit/preprocessing/feature_extraction/pcoa/plot_pcoa.py b/src/biofit/preprocessing/feature_extraction/pcoa/plot_pcoa.py
new file mode 100644
index 0000000..eb2b329
--- /dev/null
+++ b/src/biofit/preprocessing/feature_extraction/pcoa/plot_pcoa.py
@@ -0,0 +1,91 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+from biofit.processing import SelectedColumnTypes
+from biofit.utils.types import Unset
+from biofit.visualization.plotting_utils import plot_dimension_reduction
+
+from ..plot_feature_extraction import (
+ FeatureExtractorPlotter,
+ FeatureExtractorPlotterConfig,
+)
+
+
+@dataclass
+class PCoAFeatureExtractorPlotterConfig(FeatureExtractorPlotterConfig):
+ processor_name: str = field(default="pcoa", init=False, repr=False)
+ title: str = "PCoA Plot"
+ n_components: int = 3
+ label_column: str = None
+ group_column: str = None
+
+
+class PCoAFeatureExtractorPlotter(FeatureExtractorPlotter):
+ config_class = PCoAFeatureExtractorPlotterConfig
+ config: PCoAFeatureExtractorPlotterConfig
+
+ def __init__(
+ self,
+ n_components: int = 3,
+ label_column: str = None,
+ group_column: str = None,
+ config: Optional[PCoAFeatureExtractorPlotterConfig] = None,
+ **kwargs,
+ ):
+ super().__init__(
+ config=config,
+ n_components=n_components,
+ label_column=label_column,
+ group_column=group_column,
+ **kwargs,
+ )
+
+ def plot(
+ self,
+ X,
+ labels=None,
+ group=None,
+ input_columns: SelectedColumnTypes = None,
+ label_column: SelectedColumnTypes = None,
+ group_column: SelectedColumnTypes = None,
+ precomputed: bool = Unset("False"),
+        pcoa_kwargs: dict = Unset("{}"),
+        title: str = Unset("PCoA Plot"),
+ n_components: int = Unset(3),
+ path=None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show: bool = True,
+ ):
+ """Plot the PCoA plot.
+
+ Args:
+ X: The input data.
+ labels: The labels for the data.
+ group: The group for the data.
+ input_columns: The input columns.
+ label_column: The label column.
+ group_column: The group column.
+ precomputed: Whether the input data is precomputed.
+            pcoa_kwargs: Additional kwargs for computing PCoA. Only used when precomputed is False.
+            title: The title of the plot.
+
+ """
+ return plot_dimension_reduction(
+ X,
+ labels=labels,
+ group=group,
+ input_columns=input_columns,
+ label_column=label_column,
+ group_column=group_column,
+ method="pcoa" if not precomputed else None,
+            method_kwargs=pcoa_kwargs,
+ output_dir=path,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ show=show,
+ )
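
A hypothetical usage sketch for the plotter above, assuming a plain pandas table and per-sample labels; the import path follows this file's location, and `show=False` simply skips display:

```python
import pandas as pd

from biofit.preprocessing.feature_extraction.pcoa.plot_pcoa import (
    PCoAFeatureExtractorPlotter,
)

X = pd.DataFrame({"otu1": [0, 3, 5, 2], "otu2": [2, 0, 1, 4], "otu3": [1, 1, 4, 0]})
labels = ["case", "control", "case", "control"]

plotter = PCoAFeatureExtractorPlotter(n_components=2)
plotter.plot(X, labels=labels, path="pcoa_plot", device="pdf", show=False)
```
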
diff --git a/src/biofit/preprocessing/feature_extraction/plot_feature_extraction.py b/src/biofit/preprocessing/feature_extraction/plot_feature_extraction.py
new file mode 100644
index 0000000..ae81b18
--- /dev/null
+++ b/src/biofit/preprocessing/feature_extraction/plot_feature_extraction.py
@@ -0,0 +1,15 @@
+from dataclasses import dataclass, field
+
+from biofit.visualization.plotting import BasePlotter, PlotterConfig
+
+
+@dataclass
+class FeatureExtractorPlotterConfig(PlotterConfig):
+ processor_type: str = field(default="feature_extractor", init=False, repr=False)
+
+
+class FeatureExtractorPlotter(BasePlotter):
+ """Base class for feature extraction processors."""
+
+ config_class = FeatureExtractorPlotterConfig
+ config: FeatureExtractorPlotterConfig
diff --git a/src/biofit/preprocessing/feature_selection/__init__.py b/src/biofit/preprocessing/feature_selection/__init__.py
new file mode 100644
index 0000000..6959d8e
--- /dev/null
+++ b/src/biofit/preprocessing/feature_selection/__init__.py
@@ -0,0 +1,2 @@
+# ruff: noqa
+from .min_prevalence_feature_selector import *
diff --git a/src/biofit/preprocessing/feature_selection/feature_selection.py b/src/biofit/preprocessing/feature_selection/feature_selection.py
new file mode 100644
index 0000000..68bfbd7
--- /dev/null
+++ b/src/biofit/preprocessing/feature_selection/feature_selection.py
@@ -0,0 +1,105 @@
+from dataclasses import dataclass, field
+
+from biocore import DataHandler
+from biocore.utils import get_kwargs
+
+from biofit.processing import BaseProcessor, ProcessorConfig
+from biofit.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class FeatureSelectorConfig(ProcessorConfig):
+ processor_type: str = field(default="feature_selection", init=False, repr=False)
+
+
+class FeatureSelector(BaseProcessor):
+ """Base class for feature selection processors."""
+
+ config_class = FeatureSelectorConfig
+ config: FeatureSelectorConfig
+
+ def run(self, X, runner=None, fn_kwargs: dict = {}, **map_kwargs):
+ fn_kwargs = self._prepare_runner(X, **fn_kwargs)
+ if "func_type" in fn_kwargs and fn_kwargs["func_type"] != "_fit":
+ input_format = fn_kwargs["in_format_kwargs"]["target_format"]
+ fn_kwargs["out_format_kwargs"]["target_format"] = input_format
+ runner = None
+ return super().run(X, runner=runner, fn_kwargs=fn_kwargs, **map_kwargs)
+
+ def _transform_any(self, X, selected_indices=None):
+ if selected_indices is None:
+ return X
+ return DataHandler.select_columns(X, selected_indices)
+
+ def _process_transform_batch_input(self, X, *fn_args, **fn_kwargs):
+ func = fn_kwargs.get("fn", None)
+ in_format_kwargs = fn_kwargs.get("in_format_kwargs", {})
+ in_format_kwargs["input_columns"] = None
+ input = DataHandler.to_format(X, **in_format_kwargs)
+ _fn_kwargs = get_kwargs(fn_kwargs, func)
+
+ selected_indices = fn_kwargs.get("selected_indices", None)
+ unused_indices = fn_kwargs.get("unused_indices", None)
+ keep_unused_columns = fn_kwargs.get("keep_unused_columns", None)
+ if self.config._feature_names_out and DataHandler.supports_named_columns(input):
+ feature_idx_out = DataHandler.get_column_indices(
+ input, self.config._feature_names_out, raise_if_missing=False
+ )
+ elif self.config._feature_idx_out is not None:
+ feature_idx_out = [
+ selected_indices[i] for i in self.config._feature_idx_out
+ ]
+ else:
+ raise ValueError(
+ "FeatureSelectorConfig requires either _feature_idx_out, as well as "
+ "_feature_idx_out when input format supports named columns."
+ )
+ if keep_unused_columns:
+ feature_idx_out = list(sorted(feature_idx_out + unused_indices))
+ _fn_kwargs["selected_indices"] = feature_idx_out
+ return input, fn_args, _fn_kwargs
+
+ def _process_fit_output(self, input, out):
+ idx_out = self.config._feature_idx_out
+ names_in = self.config.feature_names_in_
+ names_out = self.config._feature_names_out
+ if idx_out is not None:
+ if names_in is not None and (
+ names_out is None or len(idx_out) < len(names_out)
+ ):
+ names_out = [names_in[i] for i in idx_out]
+ self.config._feature_names_out = names_out
+
+ return super()._process_fit_output(input, out)
+
+ def _process_transform_batch_output(self, input, out, **fn_kwargs):
+ # do nothing
+ return out
+
+ def _process_transform_output(self, output, input, *args, **kwargs):
+ unused_indices = kwargs.get("unused_indices", None)
+ new_fingerprint = kwargs.get("fingerprint", None)
+ selected_indices = self.config._feature_idx_out
+ if DataHandler.supports_named_columns(input) and self.config._feature_names_out:
+ input_cols = DataHandler.get_column_names(input)
+ col_dict = dict(zip(input_cols, range(len(input_cols))))
+ selected_indices = {
+ col_dict[name]
+ for name in self.config._feature_names_out
+ if name in col_dict
+ }
+
+ before_filtering = len(kwargs.get("selected_indices", []))
+ logger.info(
+ f'"{self.__class__.__name__}": Selected {len(selected_indices)} out of '
+ f"{before_filtering} input features"
+ )
+
+ if unused_indices:
+ selected_indices = set(unused_indices).union(selected_indices)
+
+ return DataHandler.select_columns(
+ input, sorted(list(selected_indices)), new_fingerprint=new_fingerprint
+ )
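
The index bookkeeping above reduces to merging the selected feature indices with any "unused" (e.g. metadata) column indices and sorting, so the output preserves the original column order. A toy illustration with made-up indices:

```python
# Illustrative values only; in practice these come from the fitted config.
selected_indices = {2, 5, 7}  # features that passed selection
unused_indices = [0, 1]  # metadata columns kept regardless

keep = sorted(set(unused_indices).union(selected_indices))
print(keep)  # [0, 1, 2, 5, 7]
```
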
diff --git a/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/__init__.py b/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/__init__.py
new file mode 100644
index 0000000..762618c
--- /dev/null
+++ b/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/__init__.py
@@ -0,0 +1,20 @@
+# ruff: noqa
+from .min_prevalence_feature_selector import (
+ MinPrevalenceFeatureSelector,
+ MinPrevalenceFeatureSelectorConfig,
+ MinPrevalenceFeatureSelectorConfigForMetagenomics,
+ MinPrevalenceFeatureSelectorConfigForOTU,
+ MinPrevalenceFeatureSelectorConfigForSNP,
+)
+from .plot_min_prevalence_feature_selector import (
+ MinPrevalenceFeatureSelectorPlotter,
+ MinPrevalencePlotterConfig,
+ MinPrevalencePlotterConfigForASV,
+ MinPrevalencePlotterConfigForGenomics,
+ MinPrevalencePlotterConfigForMetagenomics,
+ MinPrevalencePlotterConfigForOTU,
+ MinPrevalencePlotterConfigForProteomics,
+ MinPrevalencePlotterConfigForMaldi,
+ MinPrevalencePlotterConfigForReadCount,
+ MinPrevalencePlotterConfigForSNP,
+)
diff --git a/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/min_prevalence_feature_selector.py b/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/min_prevalence_feature_selector.py
new file mode 100644
index 0000000..ab76a9a
--- /dev/null
+++ b/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/min_prevalence_feature_selector.py
@@ -0,0 +1,409 @@
+"""
+Feature selector that filters features based on the number of samples they are present in.
+"""
+
+from dataclasses import dataclass, field
+from typing import List, Optional, Tuple, Type
+
+import numpy as np
+from biocore import DataHandler
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.stat import ColumnMissingnessStat, ColumnMissingnessStatConfig
+from biofit.utils import logging
+
+from ..feature_selection import FeatureSelector, FeatureSelectorConfig
+
+logger = logging.get_logger(__name__)
+
+FILTER_FEATURES_DOCSTRING = """
+Filter features based on the number of samples they are present in.
+
+Args:
+    X: Input data
+    min_prevalence: Minimum number of samples a feature must be present in to be kept.
+
+Returns:
+    Filtered data.
+"""
+
+
+def _filter_features(nrows: int, total_missing, min_prevalence, cols=None):
+ """
+ SampleFilter features in a pandas DataFrame based on their presence in the dataset.
+
+ Args:
+ X (pd.DataFrame): The input DataFrame containing the features.
+ min_prevalence (float, optional): The minimum required presence of a feature in the dataset.
+ depth (float, optional): The minimum value to be considered as present.
+
+ Returns:
+ List[int]: A list of column indices that pass the filtering condition.
+ """
+    if min_prevalence >= 1:
+        # count semantics: keep features present in at least `min_prevalence` samples
+        col_to_keep = total_missing <= (nrows - min_prevalence)
+    else:
+        # fraction semantics: keep features present in >= `min_prevalence` of samples
+        col_to_keep = total_missing <= (nrows * (1 - min_prevalence))
+ if cols and len(cols) == len(col_to_keep):
+ return [c for c, b in zip(cols, col_to_keep) if b]
+ else:
+ return [i for i, b in enumerate(col_to_keep) if b]
+
+
+@dataclass
+class MinPrevalenceFeatureSelectorConfig(FeatureSelectorConfig):
+ processor_name: str = field(
+ default="min_prevalence_feature_selector", init=False, repr=False
+ )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+
+ min_prevalence: float = 0.5
+ depth: int = None
+
+ def __post_init__(self):
+ if self.depth is None:
+ if isinstance(self.min_prevalence, float) and self.min_prevalence < 1:
+ self._fit_process_desc = (
+ "Removing features that are present in less than "
+ f"{self.min_prevalence * 100:.0f}% of samples"
+ )
+ elif (
+ isinstance(self.min_prevalence, (int, float))
+ and self.min_prevalence >= 1
+ ):
+ self._fit_process_desc = f"Removing features that are present in less than {self.min_prevalence} samples"
+
+ elif isinstance(self.min_prevalence, float) and self.min_prevalence < 1:
+ self._fit_process_desc = (
+ f"Removing features with <{self.min_prevalence * 100:.0f}% "
+ f"samples above {self.depth} counts"
+ )
+ elif isinstance(self.min_prevalence, (int, float)) and self.min_prevalence >= 1:
+ self._fit_process_desc = (
+ f"Removing features with <{self.min_prevalence} "
+ f"samples above {self.depth} counts"
+ )
+
+
+@dataclass
+class MinPrevalenceFeatureSelectorConfigForMetagenomics(
+ MinPrevalenceFeatureSelectorConfig
+):
+ dataset_name: str = field(default="metagenomics", init=False, repr=False)
+ _fit_input_feature_types: List[Tuple[Type, Type]] = field(
+ default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Tuple[Type, Type]] = field(
+ default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))],
+ init=False,
+ repr=False,
+ )
+ depth: int = 100
+ min_prevalence: float = 0.01
+
+
+@dataclass
+class MinPrevalenceFeatureSelectorConfigForOTU(MinPrevalenceFeatureSelectorConfig):
+ dataset_name: str = field(default="otu", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ depth: int = 0
+ min_prevalence: float = 0.01
+
+ def __post_init__(self):
+ if self.depth is None:
+ if isinstance(self.min_prevalence, float) and self.min_prevalence < 1:
+ self._fit_process_desc = (
+ "Removing OTUs that are present in less than "
+ f"{self.min_prevalence * 100:.0f}% of samples"
+ )
+ elif (
+ isinstance(self.min_prevalence, (int, float))
+ and self.min_prevalence >= 1
+ ):
+ self._fit_process_desc = f"Removing OTUs that are present in less than {self.min_prevalence} samples"
+
+ elif isinstance(self.min_prevalence, float) and self.min_prevalence < 1:
+ self._fit_process_desc = (
+ f"Removing OTUs with <{self.min_prevalence * 100:.0f}% "
+ f"samples above {self.depth} counts"
+ )
+ elif isinstance(self.min_prevalence, (int, float)) and self.min_prevalence >= 1:
+ self._fit_process_desc = (
+ f"Removing OTUs with <{self.min_prevalence} "
+ f"samples above {self.depth} counts"
+ )
+
+
+@dataclass
+class MinPrevalenceFeatureSelectorConfigForSNP(MinPrevalenceFeatureSelectorConfig):
+ dataset_name: str = field(default="snp", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False
+ )
+ depth: int = 0
+ min_prevalence: float = 0.012
+
+
+class MinPrevalenceFeatureSelector(FeatureSelector):
+ """
+ Feature selector that filters features based on the number of samples they are present in.
+
+ - config:
+ - min_prevalence (Union[str, float], optional):
+ The minimum prevalence of features to keep a sample.
+ The threshold to which we determine the minimum prevalence percentage or count of present values to maintain a sample or feature.
+ Any value passed >= 1 will be the minimum count while any value < 1 will be a percentage of the total values.
+ If "auto", the minimum prevalence is calculated as the first quartile minus 1.5 times the interquartile range. Defaults to "auto".
+ - depth (Union[int, Unset], optional):
+ The minimum value that we consider something to be present. Defaults to None.
+ """
+
+ config_class = MinPrevalenceFeatureSelectorConfig
+ config: MinPrevalenceFeatureSelectorConfig
+
+ def __init__(
+ self,
+ min_prevalence: float = 0.5,
+ depth: int = None,
+ config: Optional[MinPrevalenceFeatureSelectorConfig] = None,
+ **kwargs,
+ ):
+ super().__init__(
+ config=config, depth=depth, min_prevalence=min_prevalence, **kwargs
+ )
+        self.set_params()
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ if len(kwargs) > 0:
+ self.config = self.config.replace_defaults(**kwargs)
+ col_missingness_config = ColumnMissingnessStatConfig.from_config(self.config)
+ col_missingness_config = col_missingness_config.replace_defaults(**kwargs)
+ self.missingness = ColumnMissingnessStat(config=col_missingness_config)
+ self.total_missing = None
+ self._input_columns = None
+ return self
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "MinPrevalenceFeatureSelector":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _process_fit_input(self, input, **kwargs):
+ self.height = DataHandler.get_shape(input)[0]
+ return input, kwargs
+
+ def _fit_numpy(self, X):
+ self.total_missing = self.missingness._transform_numpy(X).flatten()
+ return self
+
+ def _partial_fit_numpy(self, X):
+ if self.total_missing is None:
+ self.total_missing = self.missingness._transform_numpy(X).flatten()
+ else:
+ self.total_missing += self.missingness._transform_numpy(X).flatten()
+ return self
+
+ def _fit_pandas(self, X):
+ self.total_missing = self.missingness._transform_pandas(X).values.flatten()
+ return self
+
+ def _partial_fit_pandas(self, X):
+ if self.total_missing is None:
+ self.total_missing = self.missingness._transform_pandas(X).values.flatten()
+ else:
+ self.total_missing += self.missingness._transform_pandas(X).values.flatten()
+ return self
+
+ def _fit_polars(self, X):
+ self.total_missing = self.missingness._transform_polars(X).to_numpy().flatten()
+ return self
+
+ def _partial_fit_polars(self, X):
+ if self.total_missing is None:
+ self.total_missing = (
+ self.missingness._transform_polars(X).to_numpy().flatten()
+ )
+ else:
+ self.total_missing += (
+ self.missingness._transform_polars(X).to_numpy().flatten()
+ )
+ return self
+
+ def _fit_arrow(self, X):
+ self.total_missing = np.array(
+ [v[0] for _, v in self.missingness._transform_arrow(X).to_pydict().items()]
+ )
+ return self
+
+ def _partial_fit_arrow(self, X):
+ if self.total_missing is None:
+ total_missing = self.missingness._transform_arrow(X)
+ self.total_missing = np.array(
+ [v[0] for _, v in total_missing.to_pydict().items()]
+ )
+ else:
+ other_missing = self.missingness._transform_arrow(X)
+ other_missing = np.array(
+ [v[0] for _, v in other_missing.to_pydict().items()]
+ )
+ self.total_missing += other_missing
+ return self
+
+ def _pool_fit(self, out: List["MinPrevalenceFeatureSelector"]):
+ new_self = out[0]
+ if len(out) > 1:
+ logger.info("Pooling results")
+ new_self.total_missing = np.sum(
+ np.vstack([x.total_missing for x in out]), axis=0
+ )
+ logger.info(
+ f"Total missing values: {new_self.total_missing.sum()} "
+ f"({new_self.total_missing.mean():.2f} per feature)"
+ )
+ return new_self
+
+ def _process_fit_output(self, input, out):
+ self.config._feature_idx_out = _filter_features(
+ self.height,
+ self.total_missing,
+ self.config.min_prevalence,
+ )
+ return super()._process_fit_output(input, out)
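
A worked example of the prevalence rule in `_filter_features`, treating zeros as absent (i.e. `depth=0`), followed by the equivalent selector call; running the selector end-to-end on raw numpy assumes the numpy handlers shown above cover the whole pipeline:

```python
import numpy as np

from biofit.preprocessing.feature_selection.min_prevalence_feature_selector import (
    MinPrevalenceFeatureSelector,
)

X = np.array(
    [
        [0, 3, 1],
        [0, 2, 0],
        [5, 1, 0],
        [0, 4, 2],
    ]
)
nrows = X.shape[0]
total_missing = (X == 0).sum(axis=0)  # [3, 0, 2] zeros per feature
min_prevalence = 0.5  # keep features present in >= 50% of samples
keep = total_missing <= nrows * (1 - min_prevalence)
print(keep)  # [False  True  True] -> feature 0 (present in 1/4 samples) is dropped

selector = MinPrevalenceFeatureSelector(min_prevalence=0.5, depth=0)
X_filtered = selector.fit_transform(X)  # same two surviving features as above
```
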
diff --git a/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/plot_min_prevalence_feature_selector.py b/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/plot_min_prevalence_feature_selector.py
new file mode 100644
index 0000000..df84702
--- /dev/null
+++ b/src/biofit/preprocessing/feature_selection/min_prevalence_feature_selector/plot_min_prevalence_feature_selector.py
@@ -0,0 +1,372 @@
+from dataclasses import dataclass, field
+from typing import List, Type
+
+import pyarrow as pa
+from biocore.utils.import_util import is_datasets_available
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biofit.utils.types import Unset
+
+from ..plot_feature_selection import (
+ FeatureSelectorPlotter,
+ FeatureSelectorPlotterConfig,
+)
+
+
+@dataclass
+class MinPrevalencePlotterConfig(FeatureSelectorPlotterConfig):
+ plot_process_desc: str = field(
+ default="Plotting presence feature selection", init=False, repr=False
+ )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("METADATA_FEATURE_TYPES"),
+ get_feature("METADATA_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("METADATA_FEATURE_TYPES"),
+ get_feature("METADATA_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _compare: bool = field(default=True, init=False, repr=False)
+ _add_labels: bool = field(default=False, init=False, repr=False)
+ processor_name: str = field(
+ default="min_prevalence_feature_selector", init=False, repr=False
+ )
+
+ sample_xlab: str = "Sum of Counts"
+ sample_main: str = "Sample Distribution"
+ feature_xlab: str = "Sum of Counts"
+ feature_main: str = "Feature Distribution"
+ legend_position: str = "top"
+ legend_title: str = "Max Missing Feature Selection"
+ before_name: str = "Before"
+ after_name: str = "After"
+ xlog: str = None
+ ylog: str = None
+ ncol: int = 2
+ include_non_zero_sum: bool = False
+ non_zero_samp_xlab: str = None
+ non_zero_samp_main: str = None
+ non_zero_feat_xlab: str = None
+ non_zero_feat_main: str = None
+
+
+@dataclass
+class MinPrevalencePlotterConfigForMetagenomics(MinPrevalencePlotterConfig):
+ dataset_name: str = field(default="metagenomics", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")],
+ init=False,
+ repr=False,
+ )
+
+ feature_main = "Taxa Total Abundance Distribution"
+ non_zero_samp_xlab: str = "Species Richness"
+ non_zero_samp_main: str = "Richness Distribution"
+ non_zero_feat_xlab: str = "Species Prevalence"
+ non_zero_feat_main: str = "Prevalence Across Samples"
+ xlog = "log2_1p"
+ include_non_zero_sum: bool = True
+
+
+@dataclass
+class MinPrevalencePlotterConfigForOTU(MinPrevalencePlotterConfigForMetagenomics):
+ dataset_name: str = field(default="otu", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")],
+ init=False,
+ repr=False,
+ )
+ feature_main: str = "OTU Total Abundance Distribution"
+ include_non_zero_sum: bool = True
+
+
+@dataclass
+class MinPrevalencePlotterConfigForASV(MinPrevalencePlotterConfigForMetagenomics):
+ dataset_name: str = field(default="asv", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")],
+ init=False,
+ repr=False,
+ )
+ include_non_zero_sum: bool = True
+
+
+@dataclass
+class MinPrevalencePlotterConfigForGenomics(MinPrevalencePlotterConfig):
+ dataset_name: str = field(default="genomics", init=False, repr=False)
+    _fit_input_feature_types: List[Type] = field(
+        default_factory=lambda: [
+            (get_feature("ReadCount"), get_feature("GenomicVariant")),
+            (get_feature("ReadCount"), get_feature("GenomicVariant")),
+        ],
+        init=False,
+        repr=False,
+    )
+    _transform_input_feature_types: List[Type] = field(
+        default_factory=lambda: [
+            (get_feature("ReadCount"), get_feature("GenomicVariant")),
+            (get_feature("ReadCount"), get_feature("GenomicVariant")),
+        ],
+        init=False,
+        repr=False,
+    )
+
+
+@dataclass
+class MinPrevalencePlotterConfigForSNP(MinPrevalencePlotterConfig):
+ dataset_name: str = field(default="snp", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("GenomicVariant"),
+ get_feature("GenomicVariant"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("GenomicVariant"),
+ get_feature("GenomicVariant"),
+ ],
+ init=False,
+ repr=False,
+ )
+
+
+@dataclass
+class MinPrevalencePlotterConfigForReadCount(MinPrevalencePlotterConfig):
+ dataset_name: str = field(default="read_count", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("ReadCount"), get_feature("ReadCount")],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("ReadCount"), get_feature("ReadCount")],
+ init=False,
+ repr=False,
+ )
+
+
+@dataclass
+class MinPrevalencePlotterConfigForProteomics(MinPrevalencePlotterConfig):
+ dataset_name: str = field(default="proteomics", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Expression"), get_feature("Expression")],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Expression"), get_feature("Expression")],
+ init=False,
+ repr=False,
+ )
+
+
+@dataclass
+class MinPrevalencePlotterConfigForMaldi(MinPrevalencePlotterConfig):
+ dataset_name: str = field(default="maldi", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("PeakIntensity"),
+ get_feature("PeakIntensity"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("PeakIntensity"),
+ get_feature("PeakIntensity"),
+ ],
+ init=False,
+ repr=False,
+ )
+
+
+class MinPrevalenceFeatureSelectorPlotter(FeatureSelectorPlotter):
+ config_class = MinPrevalencePlotterConfig
+ config: MinPrevalencePlotterConfig
+
+ def __init__(
+ self,
+ sample_xlab: str = "Sum of Counts",
+ sample_main: str = "Sample Distribution",
+ feature_xlab: str = "Sum of Counts",
+ feature_main: str = "Feature Distribution",
+ legend_position: str = "top",
+ legend_title: str = "Max Missing Feature Selection",
+ before_name: str = "Before",
+ after_name: str = "After",
+ xlog: str = None,
+ ylog: str = None,
+ ncol: int = 2,
+ include_non_zero_sum: bool = False,
+ non_zero_samp_xlab: str = None,
+ non_zero_samp_main: str = None,
+ non_zero_feat_xlab: str = None,
+ non_zero_feat_main: str = None,
+ config: MinPrevalencePlotterConfig = None,
+ **kwargs,
+ ):
+ super().__init__(
+ config=config,
+ sample_xlab=sample_xlab,
+ sample_main=sample_main,
+ feature_xlab=feature_xlab,
+ feature_main=feature_main,
+ legend_position=legend_position,
+ legend_title=legend_title,
+ before_name=before_name,
+ after_name=after_name,
+ xlog=xlog,
+ ylog=ylog,
+ ncol=ncol,
+ include_non_zero_sum=include_non_zero_sum,
+ non_zero_samp_xlab=non_zero_samp_xlab,
+ non_zero_samp_main=non_zero_samp_main,
+ non_zero_feat_xlab=non_zero_feat_xlab,
+ non_zero_feat_main=non_zero_feat_main,
+ **kwargs,
+ )
+
+ def plot_pandas(self, x1, x2):
+ row_sums = [pa.array(x1.sum(axis=1)), pa.array(x2.sum(axis=1))]
+ col_sums = [pa.array(x1.sum(axis=0)), pa.array(x2.sum(axis=0))]
+ input = [row_sums, col_sums]
+ mains = [self.config.sample_main, self.config.feature_main]
+ xlabs = [self.config.sample_xlab, self.config.feature_xlab]
+ if self.config.include_non_zero_sum:
+ non_zero_sample_sums = [
+ pa.array(x1[x1 > 0].count(axis=1)),
+ pa.array(x2[x2 > 0].count(axis=1)),
+ ]
+ non_zero_feature_sums = [
+ pa.array(x1[x1 > 0].count(axis=0)),
+ pa.array(x2[x2 > 0].count(axis=0)),
+ ]
+ mains.extend(
+ [
+ self.config.non_zero_samp_main,
+ self.config.non_zero_feat_main,
+ ]
+ )
+ xlabs.extend(
+ [
+ self.config.non_zero_samp_xlab,
+ self.config.non_zero_feat_xlab,
+ ]
+ )
+ input.extend([non_zero_sample_sums, non_zero_feature_sums])
+
+ self.plotter(
+ list_of_sums=input,
+ path=self.config.path,
+ xlabs=xlabs,
+ mains=mains,
+ ncol=self.config.ncol,
+ legend_position=self.config.legend_position,
+ legend_title=self.config.legend_title,
+ before_name=self.config.before_name,
+ after_name=self.config.after_name,
+ xlog=self.config.xlog,
+ ylog=self.config.ylog,
+ )
+
+ def plot(
+ self,
+ x1,
+ x2=None,
+ input_columns1: SelectedColumnTypes = None,
+ input_columns2: SelectedColumnTypes = None,
+ sample_xlab: str = Unset('"Sum of Counts"'),
+ sample_main: str = Unset('"Sample Distribution"'),
+ feature_xlab: str = Unset('"Sum of Counts"'),
+ feature_main: str = Unset('"Feature Distribution"'),
+ legend_position: str = Unset('"top"'),
+ legend_title: str = Unset('"Max Missing Feature Selection"'),
+ before_name: str = Unset('"Before"'),
+ after_name: str = Unset('"After"'),
+ xlog: str = Unset("None"),
+ ylog: str = Unset("None"),
+ ncol: int = Unset("2"),
+ include_non_zero_sum: bool = Unset("False"),
+ non_zero_samp_xlab: str = Unset("None"),
+ non_zero_samp_main: str = Unset("None"),
+ non_zero_feat_xlab: str = Unset("None"),
+ non_zero_feat_main: str = Unset("None"),
+ path: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show: bool = True,
+ ):
+ if x2 is None:
+ if is_datasets_available():
+ from biosets import Bioset
+ from datasets import Dataset as HfDataset
+
+ if isinstance(x1, (Bioset, HfDataset)):
+ _, _, prev_replays = self._get_replays(x1)
+ if prev_replays:
+ x2 = Bioset.from_replays(prev_replays)
+ return self._plot(x2, x1)
+ else:
+ raise ValueError(
+ "Must provide the before and after feature selection datasets."
+ )
+
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns1, input_columns2
+ )
+ return self._plot(
+ x1,
+ x2,
+ sample_xlab=sample_xlab,
+ sample_main=sample_main,
+ feature_xlab=feature_xlab,
+ feature_main=feature_main,
+ legend_position=legend_position,
+ legend_title=legend_title,
+ before_name=before_name,
+ after_name=after_name,
+ xlog=xlog,
+ ylog=ylog,
+ ncol=ncol,
+ include_non_zero_sum=include_non_zero_sum,
+ non_zero_samp_xlab=non_zero_samp_xlab,
+ non_zero_samp_main=non_zero_samp_main,
+ non_zero_feat_xlab=non_zero_feat_xlab,
+ non_zero_feat_main=non_zero_feat_main,
+ path=path,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ show=show,
+ )
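
A hypothetical before/after call for the plotter above, assuming pandas inputs. When `x2` is omitted and the input is a `Bioset`, the plotter instead tries to reconstruct the pre-selection table from the replay history:

```python
import pandas as pd

from biofit.preprocessing.feature_selection.min_prevalence_feature_selector import (
    MinPrevalenceFeatureSelectorPlotter,
)

before = pd.DataFrame({"f1": [0, 0, 5], "f2": [3, 2, 1], "f3": [0, 4, 2]})
after = before[["f2", "f3"]]  # e.g. the output of MinPrevalenceFeatureSelector

plotter = MinPrevalenceFeatureSelectorPlotter(include_non_zero_sum=True)
plotter.plot(before, after, path="prevalence_qc", device="pdf", show=False)
```
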
diff --git a/src/biofit/preprocessing/feature_selection/plot_feature_selection.R b/src/biofit/preprocessing/feature_selection/plot_feature_selection.R
new file mode 100644
index 0000000..77220f3
--- /dev/null
+++ b/src/biofit/preprocessing/feature_selection/plot_feature_selection.R
@@ -0,0 +1,57 @@
+
+source(file.path(R_SCRIPTS_PATH, "plotting_utils.R"))
+
+
+plot_feature_selector <- function(
+ list_of_sums, path,
+ xlabs,
+ mains,
+ legend_position = "top",
+ legend_title = "Feature Selection",
+ before_name = "Before",
+ after_name = "After",
+ xlog = NULL,
+    ylog = NULL,
+    ncol = 2,
+    ...) {
+
+ suppressPackageStartupMessages(require(ggplot2))
+ suppressPackageStartupMessages(require(RColorBrewer))
+ suppressPackageStartupMessages(require(circlize))
+ suppressPackageStartupMessages(require(patchwork))
+
+ if (!is.list(list_of_sums) && !is.vector(list_of_sums)) {
+ list_of_sums <- list(list_of_sums)
+ }
+
+ if (!is.list(xlabs) && !is.vector(xlabs)) {
+ xlabs <- list(xlabs)
+ }
+
+ if (!is.list(mains) && !is.vector(mains)) {
+ mains <- list(mains)
+ }
+
+ if (length(xlabs) != length(list_of_sums) && length(xlabs) == 1) {
+ xlabs <- rep(xlabs, length(list_of_sums))
+ }
+
+ if (length(mains) != length(list_of_sums) && length(mains) == 1) {
+ mains <- rep(mains, length(list_of_sums))
+ }
+ plots <- NULL
+    for (i in seq_along(list_of_sums)) {
+ x1 <- as.vector(list_of_sums[[i]][[1]])
+ x2 <- as.vector(list_of_sums[[i]][[2]])
+ if (is.null(plots)) {
+ plots <- generate_comparison_histogram(x1, x2, xlab = xlabs[[i]], title = mains[[i]], xlog = xlog, ylog = ylog)
+ } else {
+ plots <- plots + generate_comparison_histogram(x1, x2, xlab = xlabs[[i]], title = mains[[i]], xlog = xlog, ylog = ylog)
+ }
+ }
+
+    grid <- plots + plot_layout(guides = "collect", ncol = ncol) &
+ theme(text = element_text(size = 8), legend.position = legend_position)
+ save_plots(path, plot = grid, width = 6, height = 6.5, dpi = 600)
+}
+
diff --git a/src/biofit/preprocessing/feature_selection/plot_feature_selection.py b/src/biofit/preprocessing/feature_selection/plot_feature_selection.py
new file mode 100644
index 0000000..4254ace
--- /dev/null
+++ b/src/biofit/preprocessing/feature_selection/plot_feature_selection.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+
+from biofit.visualization.plotting import BasePlotter, PlotterConfig
+
+
+@dataclass
+class FeatureSelectorPlotterConfig(PlotterConfig):
+ processor_type: str = field(default="feature_selection", init=False, repr=False)
+ r_source: str = field(
+ default=(Path(__file__).parent / "plot_feature_selection.R").as_posix(),
+ init=False,
+ repr=False,
+ )
+ main_method: str = field(default="plot_feature_selector", init=False, repr=False)
+
+
+class FeatureSelectorPlotter(BasePlotter):
+ config_class = FeatureSelectorPlotterConfig
+ config: FeatureSelectorPlotterConfig
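
The config above is all the wiring a plotter needs: `r_source` points the base plotter at an R script next to the module, and `main_method` names the entry-point function inside it. A hypothetical subclass following the same pattern (file and function names here are illustrative only):

```python
from dataclasses import dataclass, field
from pathlib import Path

from biofit.visualization.plotting import BasePlotter, PlotterConfig


@dataclass
class MyFilterPlotterConfig(PlotterConfig):
    processor_type: str = field(default="filtering", init=False, repr=False)
    r_source: str = field(
        default=(Path(__file__).parent / "plot_my_filter.R").as_posix(),
        init=False,
        repr=False,
    )
    main_method: str = field(default="plot_my_filter", init=False, repr=False)


class MyFilterPlotter(BasePlotter):
    config_class = MyFilterPlotterConfig
    config: MyFilterPlotterConfig
```
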
diff --git a/src/biofit/preprocessing/feature_selection/rfe/plot_rfe.R b/src/biofit/preprocessing/feature_selection/rfe/plot_rfe.R
new file mode 100644
index 0000000..774cd2c
--- /dev/null
+++ b/src/biofit/preprocessing/feature_selection/rfe/plot_rfe.R
@@ -0,0 +1,25 @@
+source(file.path(R_SCRIPTS_PATH, "plotting_utils.R"))
+
+plot_rfe_feature_selector_for_genomics <- function(
+ x1, x2,
+ legend_title = "Feature Selection by Recursive Feature Elimination",
+ ...) {
+ args <- list(...)
+ args$x1 <- x1
+ args$x2 <- x2
+ args$legend_title <- if ("legend_title" %in% names(args)) args$legend_title else legend_title
+ do.call(plot_feature_selector_for_genomics, args)
+}
+
+
+plot_rfe_feature_selector_snp <- function(
+ x1, x2,
+ feature_main = "By GenomicVariantss",
+ ...) {
+ args <- list(...)
+ args$x1 <- x1
+ args$x2 <- x2
+ args$feature_main <- if ("feature_main" %in% names(args)) args$feature_main else feature_main
+ do.call(plot_rfe_feature_selector_for_genomics, args)
+}
+
diff --git a/src/biofit/preprocessing/feature_selection/rfe/plot_rfe.py b/src/biofit/preprocessing/feature_selection/rfe/plot_rfe.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/biofit/preprocessing/filtering/__init__.py b/src/biofit/preprocessing/filtering/__init__.py
new file mode 100644
index 0000000..6949be0
--- /dev/null
+++ b/src/biofit/preprocessing/filtering/__init__.py
@@ -0,0 +1,4 @@
+# ruff: noqa
+from .min_prevalence_sample_filter import *
+from .missing_labels import *
+from .row_abundance import *
diff --git a/src/biofit/preprocessing/filtering/filtering.py b/src/biofit/preprocessing/filtering/filtering.py
new file mode 100644
index 0000000..38ca3da
--- /dev/null
+++ b/src/biofit/preprocessing/filtering/filtering.py
@@ -0,0 +1,78 @@
+from dataclasses import dataclass, field
+
+import pandas as pd
+import pyarrow as pa
+from biocore import DataHandler
+from biocore.utils.import_util import is_datasets_available
+from biocore.utils.inspect import get_kwargs
+from biocore.utils.py_util import is_bioset, is_dataset, is_iterable_dataset
+
+from biofit.processing import BaseProcessor, ProcessorConfig
+from biofit.utils import logging
+from biofit.utils.table_util import string_to_arrow
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class SampleFilterConfig(ProcessorConfig):
+ processor_type: str = field(default="filtering", init=False, repr=False)
+
+
+class SampleFilter(BaseProcessor):
+ """Base class for filtering processors.
+
+ NOTE: All transformation functions must return bools.
+ """
+
+ _feature_dependent = (
+        False  # sample filtering does not depend on features when transforming
+ )
+
+ def run(self, X, runner=None, fn_kwargs: dict = {}, **map_kwargs):
+ fn_kwargs = self._prepare_runner(X, **fn_kwargs)
+ if fn_kwargs["func_type"] != "_fit":
+ if is_datasets_available():
+ if (
+ is_bioset(X) or is_dataset(X, iterable=False)
+ ) and "out_type" not in fn_kwargs:
+ from datasets import Dataset
+
+ runner = Dataset.map
+ map_kwargs = get_kwargs(map_kwargs, runner)
+ elif is_iterable_dataset(X):
+ from datasets import IterableDataset
+
+ runner = IterableDataset.map
+ map_kwargs = get_kwargs(map_kwargs, runner)
+
+ return super().run(X, runner=runner, fn_kwargs=fn_kwargs, **map_kwargs)
+
+ def _process_transform_batch_output(self, input, out, **fn_kwargs):
+        bools = DataHandler.to_numpy(out).flatten().tolist() if len(out) > 0 else []
+ inds = [i for i, x in enumerate(bools) if x]
+ if len(inds):
+ return DataHandler.select_rows(input, inds)
+ else:
+ cols = DataHandler.get_column_names(input, generate_cols=True)
+ schema = pa.schema(
+ [
+ pa.field(k, string_to_arrow(v))
+ for k, v in DataHandler.get_dtypes(input).items()
+ ]
+ )
+ return pa.Table.from_pandas(
+ pd.DataFrame(columns=cols), preserve_index=False, schema=schema
+ )
+
+ def _process_transform_output(self, output, *args, **kwargs):
+ init_num_rows = DataHandler.get_shape(args[0])[0]
+ final_num_rows = DataHandler.get_shape(output)[0]
+ if final_num_rows != init_num_rows:
+ logger.info(
+ f'"{self.__class__.__name__}": Selected {final_num_rows} out '
+ f"of {init_num_rows} samples"
+ )
+ else:
+ logger.info(f'"{self.__class__.__name__}": no samples were removed')
+ return output
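
To make the docstring's contract concrete: each `_transform_*` returns one boolean per row, and `_process_transform_batch_output` turns that mask into a row selection (or an empty table with the original schema when nothing passes). A toy sketch of the mask semantics:

```python
import numpy as np

X = np.array([[0, 0, 1], [2, 3, 4], [0, 5, 0]])
present = (X > 0).sum(axis=1)  # features present per sample: [1, 3, 1]
keep = (present / X.shape[1]) > 0.5  # e.g. min_prevalence = 0.5
print(keep.tolist())  # [False, True, False] -> only the middle sample is kept
```
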
diff --git a/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/__init__.py b/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/__init__.py
new file mode 100644
index 0000000..0180563
--- /dev/null
+++ b/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/__init__.py
@@ -0,0 +1,20 @@
+# ruff: noqa
+from .min_prevalence_sample_filter import (
+ MinPrevalenceSampleFilter,
+ MinPrevalenceRowSampleFilterConfig,
+ MinPrevalenceRowSampleFilterConfigForOTU,
+ MinPrevalenceRowSampleFilterConfigForSNP,
+ MinPrevalenceRowSampleFilterConfigForMaldi,
+)
+from .plot_min_prevalence_sample_filter import (
+ MinPrevalenceRowPlotterConfig,
+ MinPrevalenceRowPlotterConfigForMetagenomics,
+ MinPrevalenceRowPlotterConfigForOTU,
+ MinPrevalenceRowPlotterConfigForSNP,
+ MinPrevalenceRowPlotterConfigForASV,
+ MinPrevalenceRowPlotterConfigForMaldi,
+ MinPrevalenceRowPlotterConfigForGenomics,
+ MinPrevalenceRowPlotterConfigForReadCount,
+ MinPrevalenceRowPlotterConfigForProteomics,
+ MinPrevalenceRowPlotter,
+)
diff --git a/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/min_prevalence_sample_filter.py b/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/min_prevalence_sample_filter.py
new file mode 100644
index 0000000..174e31e
--- /dev/null
+++ b/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/min_prevalence_sample_filter.py
@@ -0,0 +1,373 @@
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Optional, Type, Union
+
+import numpy as np
+import pandas as pd
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.stat import RowMissingnessStat, RowMissingnessStatConfig
+from biofit.utils import Unset, logging
+
+from ..filtering import SampleFilter, SampleFilterConfig
+
+if TYPE_CHECKING:
+ import polars as pl
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class MinPrevalenceRowSampleFilterConfig(SampleFilterConfig):
+ # process description
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ processor_name: str = field(
+ default="min_prevalence_sample_filter", init=False, repr=False
+ )
+
+ # default values
+ min_prevalence: float = "auto"
+ depth: int = None
+
+
+@dataclass
+class MinPrevalenceRowSampleFilterConfigForOTU(MinPrevalenceRowSampleFilterConfig):
+ # dataset description
+ dataset_name: str = field(default="otu", init=False, repr=False)
+
+ # override default values
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ depth: int = 0
+ min_prevalence: float = "auto"
+
+    def __post_init__(self):
+        if self.depth is None:
+            if isinstance(self.min_prevalence, float) and self.min_prevalence < 1:
+                self._transform_process_desc = f"Removing samples with <{self.min_prevalence * 100:.0f}% present OTUs"
+            elif isinstance(self.min_prevalence, (int, float)) and self.min_prevalence >= 1:
+                self._transform_process_desc = f"Removing samples with <{self.min_prevalence} present OTUs"
+        elif isinstance(self.min_prevalence, float) and self.min_prevalence < 1:
+            self._transform_process_desc = f"Removing samples with <{self.min_prevalence * 100:.0f}% OTUs over {self.depth} counts"
+        elif isinstance(self.min_prevalence, (int, float)) and self.min_prevalence >= 1:
+            self._transform_process_desc = f"Removing samples with <{self.min_prevalence} OTUs over {self.depth} counts"
+
+
+@dataclass
+class MinPrevalenceRowSampleFilterConfigForSNP(MinPrevalenceRowSampleFilterConfig):
+ # dataset description
+ dataset_name: str = field(default="snp", init=False, repr=False)
+
+ # override default values
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False
+ )
+ depth: int = 0
+ min_prevalence: float = 0.012
+
+
+@dataclass
+class MinPrevalenceRowSampleFilterConfigForMaldi(MinPrevalenceRowSampleFilterConfig):
+ # dataset description
+ dataset_name: str = field(default="maldi", init=False, repr=False)
+
+ # override default values
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("PeakIntensity")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("PeakIntensity")], init=False, repr=False
+ )
+ depth: int = 0
+ min_prevalence: float = 0.4
+
+
+class MinPrevalenceSampleFilter(SampleFilter):
+ # main config class
+ config_class = MinPrevalenceRowSampleFilterConfig
+ config: MinPrevalenceRowSampleFilterConfig
+
+ IQRs = None
+
+ def __init__(
+ self,
+ config: Optional[MinPrevalenceRowSampleFilterConfig] = None,
+ *,
+ min_prevalence: Union[str, float] = "auto",
+ depth: Union[int, Unset] = None,
+ **kwargs,
+ ):
+ """SampleFilter samples based on the minimum prevalence of features.
+
+ Args:
+ config (MinPrevalenceRowSampleFilterConfig, optional):
+ The configuration for the filter. Defaults to None. If a configuration is provided, the
+ below arguments are ignored. To override the configuration, use the set_params method.
+ min_prevalence (Union[str, float], optional):
+ The minimum prevalence of features to keep a sample.
+ The threshold to which we determine the minimum prevalence percentage or count of present values to maintain a sample or feature.
+ Any value passed >= 1 will be the minimum count while any value < 1 will be a percentage of the total values.
+ If "auto", the minimum prevalence is calculated as the first quartile minus 1.5 times the interquartile range. Defaults to "auto".
+ depth (Union[int, Unset], optional):
+ The minimum value that we consider something to be present. Defaults to None.
+ **kwargs:
+ Additional keyword arguments to pass to the filter. See the ProcessorConfig class for more details.
+ """
+ super().__init__(
+ config=config, min_prevalence=min_prevalence, depth=depth, **kwargs
+ )
+ row_missingness_config = RowMissingnessStatConfig.from_config(self.config)
+ self.missingness = RowMissingnessStat(row_missingness_config)
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ row_missingness_config = RowMissingnessStatConfig.from_config(self.config)
+ self.missingness = RowMissingnessStat(row_missingness_config)
+ return self
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "MinPrevalenceSampleFilter":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _process_fit_input(self, input, **kwargs):
+ if self.config.min_prevalence != "auto":
+ kwargs["fn_kwargs"]["fn"] = None
+ return super()._process_fit_input(input, **kwargs)
+
+ def _fit_polars(self, X: "pl.DataFrame"):
+ row_missingness = self.missingness._transform_polars(X)
+ row_missingness = row_missingness.get_column(row_missingness.columns[0])
+ row_present = X.shape[1] - row_missingness
+ iqr = [row_present.quantile(0.25), row_present.quantile(0.75)]
+ self.config.min_prevalence = iqr[0] - 1.5 * (iqr[1] - iqr[0])
+ return self
+
+ def _partial_fit_polars(self, X: "pl.DataFrame"):
+ row_missingness = self.missingness._transform_polars(X)
+ row_missingness = row_missingness.get_column(row_missingness.columns[0])
+ row_present = X.shape[1] - row_missingness
+ iqr = [row_present.quantile(0.25), row_present.quantile(0.75)]
+ if self.IQRs is None:
+ self.IQRs = [[iqr[0]], [iqr[1]]]
+ else:
+ self.IQRs[0].append(iqr[0])
+ self.IQRs[1].append(iqr[1])
+ return self
+
+ def _fit_pandas(self, X: "pd.DataFrame"):
+ row_missingness = self.missingness._transform_pandas(X).iloc[:, 0]
+ row_present = X.shape[1] - row_missingness
+ iqr = [row_present.quantile(0.25), row_present.quantile(0.75)]
+ self.config.min_prevalence = iqr[0] - 1.5 * (iqr[1] - iqr[0])
+ return self
+
+ def _partial_fit_pandas(self, X: "pd.DataFrame"):
+ row_missingness = self.missingness._transform_pandas(X).iloc[:, 0]
+ row_present = X.shape[1] - row_missingness
+ iqr = [row_present.quantile(0.25), row_present.quantile(0.75)]
+ if self.IQRs is None:
+ self.IQRs = [[iqr[0]], [iqr[1]]]
+ else:
+ self.IQRs[0].append(iqr[0])
+ self.IQRs[1].append(iqr[1])
+ return self
+
+ def _fit_numpy(self, X: np.ndarray):
+ row_missingness = self.missingness._transform_numpy(X)[:, 0]
+ row_present = X.shape[1] - row_missingness
+ iqr = [np.quantile(row_present, 0.25), np.quantile(row_present, 0.75)]
+ self.config.min_prevalence = iqr[0] - 1.5 * (iqr[1] - iqr[0])
+ return self
+
+ def _partial_fit_numpy(self, X: np.ndarray):
+ row_missingness = self.missingness._transform_numpy(X)[:, 0]
+ row_present = X.shape[1] - row_missingness
+ iqr = [np.quantile(row_present, 0.25), np.quantile(row_present, 0.75)]
+ if self.IQRs is None:
+ self.IQRs = [[iqr[0]], [iqr[1]]]
+ else:
+ self.IQRs[0].append(iqr[0])
+ self.IQRs[1].append(iqr[1])
+ return self
+
+ def _pool_fit_any(self, partial_results: List["MinPrevalenceSampleFilter"]):
+ IQRs = [[], []]
+ for result in partial_results:
+ IQRs[0].extend(result.IQRs[0])
+ IQRs[1].extend(result.IQRs[1])
+ IQRs[0] = np.mean(IQRs[0])
+ IQRs[1] = np.mean(IQRs[1])
+ self.config.min_prevalence = IQRs[0] - 1.5 * (IQRs[1] - IQRs[0])
+ return self
+
+    def _process_fit_output(self, input, out):
+        min_prev = out.config.min_prevalence
+        if isinstance(min_prev, float) and min_prev < 1:
+            logger.info(f"Minimum prevalence set to {min_prev * 100:.0f}%")
+        else:
+            logger.info(f"Minimum prevalence set to {min_prev:.0f} features")
+        return super()._process_fit_output(input, out)
+
+ def _transform_polars(self, X: "pl.DataFrame"):
+        import polars as pl
+
+ total_present: pl.DataFrame = self.missingness._transform_polars(
+ X
+ ).with_columns((X.shape[1] - pl.col("*")).alias("sum"))
+
+ if self.config.min_prevalence < 1:
+ total_present = total_present.with_columns(pl.col("sum") / X.shape[1])
+ tests = total_present.with_columns(pl.col("sum") > self.config.min_prevalence)
+ if len(tests) == 1:
+ return tests[0]
+ return tests.to_numpy()[:, 0].tolist()
+
+ def _transform_pandas(self, X: pd.DataFrame):
+ total_present = X.shape[1] - self.missingness._transform_pandas(X)
+ if self.config.min_prevalence < 1:
+ total_present = total_present / X.shape[1]
+ tests = total_present > self.config.min_prevalence
+ if len(tests) == 1:
+ return tests[0]
+ return tests.values[:, 0].tolist()
+
+ def _transform_numpy(self, X: np.ndarray):
+ total_present = X.shape[1] - self.missingness._transform_numpy(X)
+ if self.config.min_prevalence < 1:
+ total_present = total_present / X.shape[1]
+ tests = total_present > self.config.min_prevalence
+ if len(tests) == 1:
+ return tests[0]
+ return tests[:, 0].tolist()
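
A sketch of the "auto" threshold computed by the `_fit_*` methods above: per-sample counts of present features are summarised by their quartiles, and samples below the classic lower outlier fence `Q1 - 1.5 * IQR` are removed:

```python
import numpy as np

present_per_sample = np.array([120, 130, 128, 40, 125, 131])  # toy richness values
q1, q3 = np.quantile(present_per_sample, [0.25, 0.75])
fence = q1 - 1.5 * (q3 - q1)  # ~108.9 for these values
keep = present_per_sample > fence
print(keep.tolist())  # [True, True, True, False, True, True]
```
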
diff --git a/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/plot_min_prevalence_sample_filter.py b/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/plot_min_prevalence_sample_filter.py
new file mode 100644
index 0000000..9cd8930
--- /dev/null
+++ b/src/biofit/preprocessing/filtering/min_prevalence_sample_filter/plot_min_prevalence_sample_filter.py
@@ -0,0 +1,304 @@
+from dataclasses import dataclass, field
+from typing import List, Type
+
+import pyarrow as pa
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biofit.utils.types import Unset
+
+from ..plot_filtering import (
+ SampleFilterPlotter,
+ SampleFilterPlotterConfig,
+)
+
+
+@dataclass
+class MinPrevalenceRowPlotterConfig(SampleFilterPlotterConfig):
+ plot_process_desc: str = field(
+ default="Plotting presence feature selection", init=False, repr=False
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("METADATA_FEATURE_TYPES"),
+ get_feature("METADATA_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _compare: bool = field(default=True, init=False, repr=False)
+ processor_name: str = field(
+ default="min_prevalence_sample_filter", init=False, repr=False
+ )
+
+ sample_xlab: str = "Sum of Counts"
+ sample_main: str = "Sample Distribution"
+ feature_xlab: str = "Sum of Counts"
+ feature_main: str = "Feature Distribution"
+ legend_position: str = "top"
+ legend_title: str = "Min Prevalence Sample SampleFiltering"
+ before_name: str = "Before"
+ after_name: str = "After"
+ xlog: str = None
+ ylog: str = None
+ ncol: int = 2
+ include_non_zero_sum: bool = False
+ non_zero_samp_xlab: str = None
+ non_zero_samp_main: str = None
+ non_zero_feat_xlab: str = None
+ non_zero_feat_main: str = None
+
+
+@dataclass
+class MinPrevalenceRowPlotterConfigForMetagenomics(MinPrevalenceRowPlotterConfig):
+ dataset_name: str = field(default="metagenomics", init=False, repr=False)
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ [get_feature("Abundance"), get_feature("ReadCount")],
+ [get_feature("Abundance"), get_feature("ReadCount")],
+ ],
+ init=False,
+ repr=False,
+ )
+
+ feature_main = "Taxa Total Abundance Distribution"
+ non_zero_samp_xlab: str = "Species Richness"
+ non_zero_samp_main: str = "Richness Distribution"
+ non_zero_feat_xlab: str = "Species Prevalence"
+ non_zero_feat_main: str = "Prevalence Across Samples"
+ xlog = "log2_1p"
+ include_non_zero_sum: bool = True
+
+
+@dataclass
+class MinPrevalenceRowPlotterConfigForOTU(MinPrevalenceRowPlotterConfigForMetagenomics):
+ dataset_name: str = field(default="otu", init=False, repr=False)
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")],
+ init=False,
+ repr=False,
+ )
+ feature_main: str = "OTU Total Abundance Distribution"
+ include_non_zero_sum: bool = True
+
+
+@dataclass
+class MinPrevalenceRowPlotterConfigForASV(MinPrevalenceRowPlotterConfigForMetagenomics):
+ dataset_name: str = field(default="asv", init=False, repr=False)
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")],
+ init=False,
+ repr=False,
+ )
+ include_non_zero_sum: bool = True
+
+
+@dataclass
+class MinPrevalenceRowPlotterConfigForGenomics(MinPrevalenceRowPlotterConfig):
+ dataset_name: str = field(default="genomics", init=False, repr=False)
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ (get_feature("ReadCount"), get_feature("GenomicVariant")),
+ (get_feature("ReadCount"), get_feature("GenomicVariant")),
+ ],
+ init=False,
+ repr=False,
+ )
+
+
+@dataclass
+class MinPrevalenceRowPlotterConfigForSNP(MinPrevalenceRowPlotterConfig):
+ dataset_name: str = field(default="snp", init=False, repr=False)
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("GenomicVariant"),
+ get_feature("GenomicVariant"),
+ ],
+ init=False,
+ repr=False,
+ )
+
+
+@dataclass
+class MinPrevalenceRowPlotterConfigForMaldi(MinPrevalenceRowPlotterConfig):
+ dataset_name: str = field(default="maldi", init=False, repr=False)
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("PeakIntensity"),
+ get_feature("PeakIntensity"),
+ ],
+ init=False,
+ repr=False,
+ )
+ feature_main: str = "Peak Total Intensity Distribution"
+ non_zero_samp_xlab: str = "Peak Richness"
+ non_zero_samp_main: str = "Richness Distribution"
+ non_zero_feat_xlab: str = "Peak Prevalence"
+ non_zero_feat_main: str = "Prevalence Across Samples"
+ xlog = "log2_1p"
+ include_non_zero_sum: bool = True
+
+
+@dataclass
+class MinPrevalenceRowPlotterConfigForReadCount(MinPrevalenceRowPlotterConfig):
+ dataset_name: str = field(default="read_count", init=False, repr=False)
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("ReadCount"), get_feature("ReadCount")],
+ init=False,
+ repr=False,
+ )
+
+
+@dataclass
+class MinPrevalenceRowPlotterConfigForProteomics(MinPrevalenceRowPlotterConfig):
+ dataset_name: str = field(default="proteomics", init=False, repr=False)
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("PeakIntensity"),
+ get_feature("PeakIntensity"),
+ ],
+ init=False,
+ repr=False,
+ )
+
+
+class MinPrevalenceRowPlotter(SampleFilterPlotter):
+ config_class = MinPrevalenceRowPlotterConfig
+ config: MinPrevalenceRowPlotterConfig
+
+ def __init__(
+ self,
+ sample_xlab: str = Unset('"Sum of Counts"'),
+ sample_main: str = Unset('"Sample Distribution"'),
+ feature_xlab: str = Unset('"Sum of Counts"'),
+ feature_main: str = Unset('"Feature Distribution"'),
+ legend_position: str = Unset('"top"'),
+ legend_title: str = Unset('"Min Prevalence Sample SampleFiltering"'),
+ before_name: str = Unset('"Before"'),
+ after_name: str = Unset('"After"'),
+ xlog: str = Unset("None"),
+ ylog: str = Unset("None"),
+ ncol: int = Unset("2"),
+ include_non_zero_sum: bool = Unset("False"),
+ non_zero_samp_xlab: str = Unset("None"),
+ non_zero_samp_main: str = Unset("None"),
+ non_zero_feat_xlab: str = Unset("None"),
+ non_zero_feat_main: str = Unset("None"),
+ config: MinPrevalenceRowPlotterConfig = None,
+ **kwargs,
+ ):
+ super().__init__(
+ config=config,
+ sample_xlab=sample_xlab,
+ sample_main=sample_main,
+ feature_xlab=feature_xlab,
+ feature_main=feature_main,
+ legend_position=legend_position,
+ legend_title=legend_title,
+ before_name=before_name,
+ after_name=after_name,
+ xlog=xlog,
+ ylog=ylog,
+ ncol=ncol,
+ include_non_zero_sum=include_non_zero_sum,
+ non_zero_samp_xlab=non_zero_samp_xlab,
+ non_zero_samp_main=non_zero_samp_main,
+ non_zero_feat_xlab=non_zero_feat_xlab,
+ non_zero_feat_main=non_zero_feat_main,
+ **kwargs,
+ )
+
+ def plot_pandas(self, x1, x2):
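+        # Compare before (x1) and after (x2) filtering: per-sample row sums and
+        # per-feature column sums, plus non-zero counts (richness / prevalence)
+        # when include_non_zero_sum is set.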
+ row_sums = [pa.array(x1.sum(axis=1)), pa.array(x2.sum(axis=1))]
+ col_sums = [pa.array(x1.sum(axis=0)), pa.array(x2.sum(axis=0))]
+ input = [row_sums, col_sums]
+ mains = [self.config.sample_main, self.config.feature_main]
+ xlabs = [self.config.sample_xlab, self.config.feature_xlab]
+ if self.config.include_non_zero_sum:
+ non_zero_sample_sums = [
+ pa.array(x1[x1 > 0].count(axis=1)),
+ pa.array(x2[x2 > 0].count(axis=1)),
+ ]
+ non_zero_feature_sums = [
+ pa.array(x1[x1 > 0].count(axis=0)),
+ pa.array(x2[x2 > 0].count(axis=0)),
+ ]
+ mains.extend(
+ [
+ self.config.non_zero_samp_main,
+ self.config.non_zero_feat_main,
+ ]
+ )
+ xlabs.extend(
+ [
+ self.config.non_zero_samp_xlab,
+ self.config.non_zero_feat_xlab,
+ ]
+ )
+ input.extend([non_zero_sample_sums, non_zero_feature_sums])
+
+ self.plotter(
+ list_of_sums=input,
+ xlabs=xlabs,
+ mains=mains,
+ **self.config.get_params(),
+ )
+
+ def plot(
+ self,
+ x1,
+ x2=None,
+ input_columns1: SelectedColumnTypes = None,
+ input_columns2: SelectedColumnTypes = None,
+ sample_xlab: str = Unset('"Sum of Counts"'),
+ sample_main: str = Unset('"Sample Distribution"'),
+ feature_xlab: str = Unset('"Sum of Counts"'),
+ feature_main: str = Unset('"Feature Distribution"'),
+ legend_position: str = Unset('"top"'),
+ legend_title: str = Unset('"Min Prevalence Sample SampleFiltering"'),
+ before_name: str = Unset('"Before"'),
+ after_name: str = Unset('"After"'),
+ xlog: str = Unset("None"),
+ ylog: str = Unset("None"),
+ ncol: int = Unset("2"),
+ include_non_zero_sum: bool = Unset("False"),
+ non_zero_samp_xlab: str = Unset("None"),
+ non_zero_samp_main: str = Unset("None"),
+ non_zero_feat_xlab: str = Unset("None"),
+ non_zero_feat_main: str = Unset("None"),
+ path: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show: bool = True,
+ ):
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns1, input_columns2
+ )
+ return self._plot(
+ x1,
+ x2,
+ sample_xlab=sample_xlab,
+ sample_main=sample_main,
+ feature_xlab=feature_xlab,
+ feature_main=feature_main,
+ legend_position=legend_position,
+ legend_title=legend_title,
+ before_name=before_name,
+ after_name=after_name,
+ xlog=xlog,
+ ylog=ylog,
+ ncol=ncol,
+ include_non_zero_sum=include_non_zero_sum,
+ non_zero_samp_xlab=non_zero_samp_xlab,
+ non_zero_samp_main=non_zero_samp_main,
+ non_zero_feat_xlab=non_zero_feat_xlab,
+ non_zero_feat_main=non_zero_feat_main,
+ path=path,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ show=show,
+ )
diff --git a/src/biofit/preprocessing/filtering/missing_labels/__init__.py b/src/biofit/preprocessing/filtering/missing_labels/__init__.py
new file mode 100644
index 0000000..2e9f104
--- /dev/null
+++ b/src/biofit/preprocessing/filtering/missing_labels/__init__.py
@@ -0,0 +1,2 @@
+# ruff: noqa
+from .missing_labels import MissingLabelsSampleFilter, MissingLabelsSampleFilterConfig
diff --git a/src/biofit/preprocessing/filtering/missing_labels/missing_labels.py b/src/biofit/preprocessing/filtering/missing_labels/missing_labels.py
new file mode 100644
index 0000000..d8d3bd6
--- /dev/null
+++ b/src/biofit/preprocessing/filtering/missing_labels/missing_labels.py
@@ -0,0 +1,221 @@
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Optional, Type, Union
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pyarrow.compute as pc
+from biocore import DataHandler
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biofit.utils import logging
+
+from ..filtering import SampleFilter, SampleFilterConfig
+
+if TYPE_CHECKING:
+ import polars as pl
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class MissingLabelsSampleFilterConfig(SampleFilterConfig):
+ # process description
+ _transform_process_desc: str = field(
+ default="SampleFilter out rows with missing labels", init=False, repr=False
+ )
+ processor_name: str = field(default="missing_labels", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("TARGET_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("TARGET_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+
+ missing_label: Optional[Union[str, int]] = "auto"
+
+ def __post_init__(self):
+ self._transform_process_desc = (
+ f"SampleFiltering out rows with labels equaling to {self.missing_label}"
+ )
+
+
+class MissingLabelsSampleFilter(SampleFilter):
+ """Remove samples that are not labeled
+
+ - config:
+ - mising_label: The value we deem as missing and want to remove. Default is "auto".
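+
+    Example (a minimal sketch; the column name and label values below are
+    illustrative only):
+
+        >>> import pandas as pd
+        >>> df = pd.DataFrame({"otu1": [5, 2, 7], "label": [0, -1, 1]})
+        >>> filt = MissingLabelsSampleFilter(missing_label=-1)
+        >>> filtered = filt.fit_transform(df, input_columns="label")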
+ """
+
+ # main config class
+ config_class = MissingLabelsSampleFilterConfig
+ config: MissingLabelsSampleFilterConfig
+
+ def __init__(
+ self,
+ config: Optional[MissingLabelsSampleFilterConfig] = None,
+ missing_label: Optional[Union[str, int]] = "auto",
+ **kwargs,
+ ):
+ super().__init__(config=config, missing_label=missing_label, **kwargs)
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "MissingLabelsSampleFilter":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _process_transform_input(self, X, **kwargs):
+ selected_indices = kwargs["fn_kwargs"].get("selected_indices", None) or [0]
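+        # "auto" resolution: treat -1 as the missing marker when the label column
+        # looks categorical (few unique values); otherwise treat None as missing.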
+ if self.config.missing_label == "auto":
+ if DataHandler.is_categorical(X, selected_indices[0], threshold=30):
+ self.config.missing_label = -1
+ kwargs["desc"] = self._transform_process_desc = (
+ "SampleFiltering out rows with labels equaling to -1"
+ )
+ else:
+ self.config.missing_label = None
+ kwargs["desc"] = self._transform_process_desc = (
+ "SampleFiltering out rows with labels equaling to None"
+ )
+ return super()._process_transform_input(X, **kwargs)
+
+ def _transform_arrow(self, X: Union[pa.Table, pa.Array]):
+ if isinstance(X, pa.Table):
+ return (
+ pc.not_equal(X.column(0), self.config.missing_label).to_numpy().tolist()
+ )
+ return pc.not_equal(X, self.config.missing_label).to_numpy().tolist()
+
+ def _transform_pandas(self, X: pd.DataFrame):
+ return X.ne(self.config.missing_label).values.tolist()
+
+ def _transform_polars(self, X: "pl.DataFrame"):
+ import polars as pl
+
+ if isinstance(X, pl.Series):
+ return X.ne(self.config.missing_label).to_numpy().tolist()
+ return (
+ X.get_column(X.columns[0]).ne(self.config.missing_label).to_numpy().tolist()
+ )
+
+ def _transform_numpy(self, X: np.ndarray):
+ return X != self.config.missing_label
diff --git a/src/biofit/preprocessing/filtering/plot_filtering.R b/src/biofit/preprocessing/filtering/plot_filtering.R
new file mode 100644
index 0000000..23e8a9b
--- /dev/null
+++ b/src/biofit/preprocessing/filtering/plot_filtering.R
@@ -0,0 +1,69 @@
+source(file.path(R_SCRIPTS_PATH, "plotting_utils.R"))
+
+
+plot_filter <- function(
+ list_of_sums, path,
+ xlabs,
+ mains,
+ legend_position = "top",
+ legend_title = "Filtering",
+ before_name = "Before",
+ after_name = "After",
+ xlog = NULL,
+    ylog = NULL,
+    ncol = 2,
+    ...) {
+
+ suppressPackageStartupMessages(require(ggplot2))
+ suppressPackageStartupMessages(require(RColorBrewer))
+ suppressPackageStartupMessages(require(circlize))
+ suppressPackageStartupMessages(require(patchwork))
+
+ if (!is.list(list_of_sums) && !is.vector(list_of_sums)) {
+ list_of_sums <- list(list_of_sums)
+ }
+
+ if (!is.list(xlabs) && !is.vector(xlabs)) {
+ xlabs <- list(xlabs)
+ }
+
+ if (!is.list(mains) && !is.vector(mains)) {
+ mains <- list(mains)
+ }
+
+ if (length(xlabs) != length(list_of_sums) && length(xlabs) == 1) {
+ xlabs <- rep(xlabs, length(list_of_sums))
+ }
+
+ if (length(mains) != length(list_of_sums) && length(mains) == 1) {
+ mains <- rep(mains, length(list_of_sums))
+ }
+ plots <- NULL
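+  # Build one before/after comparison histogram per pair of sums and compose
+  # them with patchwork's `+` operator so a single shared legend can be collected.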
+ for (i in seq_along(list_of_sums)) {
+ x1 <- as.vector(list_of_sums[[i]][[1]])
+ x2 <- as.vector(list_of_sums[[i]][[2]])
+ if (is.null(plots)) {
+ plots <- generate_comparison_histogram(
+ x1, x2,
+ xlab = xlabs[[i]], title = mains[[i]], xlog = xlog, ylog = ylog,
+ legend_title = legend_title, subplot_title1 = before_name,
+ subplot_title2 = after_name
+ )
+ } else {
+ plots <- plots + generate_comparison_histogram(
+ x1, x2,
+ xlab = xlabs[[i]], title = mains[[i]], xlog = xlog, ylog = ylog,
+ legend_title = legend_title, subplot_title1 = before_name,
+ subplot_title2 = after_name
+ )
+ }
+ }
+
+ grid <- plots + patchwork::plot_layout(guides = "collect", ncol = ncol) &
+ ggplot2::theme(
+ text = ggplot2::element_text(size = 8),
+      legend.position = legend_position
+    )
+
+ save_plots(path, plot = grid, width = 6, height = 6.5, dpi = 600)
+}
diff --git a/src/biofit/preprocessing/filtering/plot_filtering.py b/src/biofit/preprocessing/filtering/plot_filtering.py
new file mode 100644
index 0000000..2bf2dff
--- /dev/null
+++ b/src/biofit/preprocessing/filtering/plot_filtering.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+
+from biofit.visualization.plotting import BasePlotter, PlotterConfig
+
+
+@dataclass
+class SampleFilterPlotterConfig(PlotterConfig):
+ processor_type: str = field(default="filtering", init=False, repr=False)
+ r_source: str = field(
+ default=(Path(__file__).parent / "plot_filtering.R").as_posix(),
+ init=False,
+ repr=False,
+ )
+ main_method: str = field(default="plot_filter", init=False, repr=False)
+
+
+class SampleFilterPlotter(BasePlotter):
+ config_class = SampleFilterPlotterConfig
+ config: SampleFilterPlotterConfig
diff --git a/src/biofit/preprocessing/filtering/row_abundance/__init__.py b/src/biofit/preprocessing/filtering/row_abundance/__init__.py
new file mode 100644
index 0000000..5a7c699
--- /dev/null
+++ b/src/biofit/preprocessing/filtering/row_abundance/__init__.py
@@ -0,0 +1,11 @@
+# ruff: noqa
+from .row_abundance import (
+ AbundanceSampleFilterConfigForOTU,
+ AbundanceSampleFilterConfig,
+ AbundanceSampleFilter,
+)
+from .plot_row_abundance import (
+ AbundanceSampleFilterPlotterConfigForOTU,
+ AbundanceSampleFilterPlotterConfig,
+ AbundanceSampleFilterPlotter,
+)
diff --git a/src/biofit/preprocessing/filtering/row_abundance/plot_row_abundance.py b/src/biofit/preprocessing/filtering/row_abundance/plot_row_abundance.py
new file mode 100644
index 0000000..655622b
--- /dev/null
+++ b/src/biofit/preprocessing/filtering/row_abundance/plot_row_abundance.py
@@ -0,0 +1,327 @@
+from dataclasses import dataclass, field
+from typing import List, Optional, Type
+
+import pyarrow as pa
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biofit.utils.types import Unset
+
+from ..plot_filtering import (
+ SampleFilterPlotter,
+ SampleFilterPlotterConfig,
+)
+
+
+@dataclass
+class AbundanceSampleFilterPlotterConfig(SampleFilterPlotterConfig):
+    plot_process_desc: str = field(
+        default="Plotting abundance sample filtering", init=False, repr=False
+    )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("METADATA_FEATURE_TYPES"),
+ get_feature("METADATA_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("METADATA_FEATURE_TYPES"),
+ get_feature("METADATA_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _compare: bool = field(default=True, init=False, repr=False)
+ processor_name: str = field(default="row_abundance", init=False, repr=False)
+
+ sample_xlab: str = "Sum of Counts"
+ sample_main: str = "Sample Distribution"
+ feature_xlab: str = "Sum of Counts"
+ feature_main: str = "Feature Distribution"
+ legend_position: str = "top"
+ legend_title: str = "Abundance SampleFiltering"
+ before_name: str = "Before"
+ after_name: str = "After"
+ xlog: str = None
+ ylog: str = None
+ ncol: int = 2
+ include_non_zero_sum: bool = False
+ non_zero_samp_xlab: str = None
+ non_zero_samp_main: str = None
+ non_zero_feat_xlab: str = None
+ non_zero_feat_main: str = None
+
+
+@dataclass
+class AbundanceSampleFilterPlotterConfigForMetagenomics(
+ AbundanceSampleFilterPlotterConfig
+):
+ dataset_name: str = field(default="metagenomics", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")],
+ init=False,
+ repr=False,
+ )
+
+ feature_main = "Taxa Total Abundance Distribution"
+ non_zero_samp_xlab: str = "Species Richness"
+ non_zero_samp_main: str = "Richness Distribution"
+ non_zero_feat_xlab: str = "Species Prevalence"
+ non_zero_feat_main: str = "Prevalence Across Samples"
+ xlog = "log2_1p"
+ include_non_zero_sum: bool = True
+
+
+@dataclass
+class AbundanceSampleFilterPlotterConfigForOTU(
+ AbundanceSampleFilterPlotterConfigForMetagenomics
+):
+ dataset_name: str = field(default="otu", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")],
+ init=False,
+ repr=False,
+ )
+ feature_main: str = "OTU Total Abundance Distribution"
+ include_non_zero_sum: bool = True
+
+
+@dataclass
+class AbundanceSampleFilterPlotterConfigForASV(
+ AbundanceSampleFilterPlotterConfigForMetagenomics
+):
+ dataset_name: str = field(default="asv", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance"), get_feature("Abundance")],
+ init=False,
+ repr=False,
+ )
+ include_non_zero_sum: bool = True
+
+
+@dataclass
+class AbundanceSampleFilterPlotterConfigForGenomics(AbundanceSampleFilterPlotterConfig):
+ dataset_name: str = field(default="genomics", init=False, repr=False)
+    _fit_input_feature_types: List[Type] = field(
+        default_factory=lambda: [
+            (get_feature("ReadCount"), get_feature("GenomicVariant")),
+            (get_feature("ReadCount"), get_feature("GenomicVariant")),
+        ],
+        init=False,
+        repr=False,
+    )
+    _transform_input_feature_types: List[Type] = field(
+        default_factory=lambda: [
+            (get_feature("ReadCount"), get_feature("GenomicVariant")),
+            (get_feature("ReadCount"), get_feature("GenomicVariant")),
+        ],
+        init=False,
+        repr=False,
+    )
+
+
+@dataclass
+class AbundanceSampleFilterPlotterConfigForSNP(AbundanceSampleFilterPlotterConfig):
+ dataset_name: str = field(default="snp", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("GenomicVariant"),
+ get_feature("GenomicVariant"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("GenomicVariant"),
+ get_feature("GenomicVariant"),
+ ],
+ init=False,
+ repr=False,
+ )
+
+
+@dataclass
+class AbundanceSampleFilterPlotterConfigForReadCount(
+ AbundanceSampleFilterPlotterConfig
+):
+ dataset_name: str = field(default="read_count", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("ReadCount"), get_feature("ReadCount")],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("ReadCount"), get_feature("ReadCount")],
+ init=False,
+ repr=False,
+ )
+
+
+@dataclass
+class AbundanceSampleFilterPlotterConfigForProteomics(
+ AbundanceSampleFilterPlotterConfig
+):
+ dataset_name: str = field(default="proteomics", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Expression"), get_feature("Expression")],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Expression"), get_feature("Expression")],
+ init=False,
+ repr=False,
+ )
+
+
+class AbundanceSampleFilterPlotter(SampleFilterPlotter):
+ config_class = AbundanceSampleFilterPlotterConfig
+ config: AbundanceSampleFilterPlotterConfig
+
+ def __init__(
+ self,
+ sample_xlab: str = Unset('"Sum of Counts"'),
+ sample_main: str = Unset('"Sample Distribution"'),
+ feature_xlab: str = Unset('"Sum of Counts"'),
+ feature_main: str = Unset('"Feature Distribution"'),
+ legend_position: str = Unset('"top"'),
+ legend_title: str = Unset('"Sample Abundance SampleFiltering"'),
+ before_name: str = Unset('"Before"'),
+ after_name: str = Unset('"After"'),
+ xlog: str = Unset("None"),
+ ylog: str = Unset("None"),
+ ncol: int = Unset("2"),
+ include_non_zero_sum: bool = Unset("False"),
+ non_zero_samp_xlab: str = Unset("None"),
+ non_zero_samp_main: str = Unset("None"),
+ non_zero_feat_xlab: str = Unset("None"),
+ non_zero_feat_main: str = Unset("None"),
+ config: Optional[AbundanceSampleFilterPlotterConfig] = None,
+ **kwargs,
+ ):
+ super().__init__(
+ config=config,
+ sample_xlab=sample_xlab,
+ sample_main=sample_main,
+ feature_xlab=feature_xlab,
+ feature_main=feature_main,
+ legend_position=legend_position,
+ legend_title=legend_title,
+ before_name=before_name,
+ after_name=after_name,
+ xlog=xlog,
+ ylog=ylog,
+ ncol=ncol,
+ include_non_zero_sum=include_non_zero_sum,
+ non_zero_samp_xlab=non_zero_samp_xlab,
+ non_zero_samp_main=non_zero_samp_main,
+ non_zero_feat_xlab=non_zero_feat_xlab,
+ non_zero_feat_main=non_zero_feat_main,
+ **kwargs,
+ )
+
+ def plot_pandas(self, x1, x2):
+ row_sums = [pa.array(x1.sum(axis=1)), pa.array(x2.sum(axis=1))]
+ col_sums = [pa.array(x1.sum(axis=0)), pa.array(x2.sum(axis=0))]
+ input = [row_sums, col_sums]
+ mains = [self.config.sample_main, self.config.feature_main]
+ xlabs = [self.config.sample_xlab, self.config.feature_xlab]
+ if self.config.include_non_zero_sum:
+ non_zero_sample_sums = [
+ pa.array(x1[x1 > 0].count(axis=1)),
+ pa.array(x2[x2 > 0].count(axis=1)),
+ ]
+ non_zero_feature_sums = [
+ pa.array(x1[x1 > 0].count(axis=0)),
+ pa.array(x2[x2 > 0].count(axis=0)),
+ ]
+ mains.extend(
+ [
+ self.config.non_zero_samp_main,
+ self.config.non_zero_feat_main,
+ ]
+ )
+ xlabs.extend(
+ [
+ self.config.non_zero_samp_xlab,
+ self.config.non_zero_feat_xlab,
+ ]
+ )
+ input.extend([non_zero_sample_sums, non_zero_feature_sums])
+
+ self.plotter(
+ list_of_sums=input,
+ xlabs=xlabs,
+ mains=mains,
+ **self.config.get_params(),
+ )
+
+ def plot(
+ self,
+ x1,
+ x2=None,
+ input_columns1: SelectedColumnTypes = None,
+ input_columns2: SelectedColumnTypes = None,
+ sample_xlab: str = Unset('"Sum of Counts"'),
+ sample_main: str = Unset('"Sample Distribution"'),
+ feature_xlab: str = Unset('"Sum of Counts"'),
+ feature_main: str = Unset('"Feature Distribution"'),
+ legend_position: str = Unset('"top"'),
+ legend_title: str = Unset('"Sample Abundance SampleFiltering"'),
+ before_name: str = Unset('"Before"'),
+ after_name: str = Unset('"After"'),
+ xlog: str = Unset("None"),
+ ylog: str = Unset("None"),
+ ncol: int = Unset("2"),
+ include_non_zero_sum: bool = Unset("False"),
+ non_zero_samp_xlab: str = Unset("None"),
+ non_zero_samp_main: str = Unset("None"),
+ non_zero_feat_xlab: str = Unset("None"),
+ non_zero_feat_main: str = Unset("None"),
+ path: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show: bool = True,
+ ):
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns1, input_columns2
+ )
+ return self._plot(
+ x1,
+ x2,
+ sample_xlab=sample_xlab,
+ sample_main=sample_main,
+ feature_xlab=feature_xlab,
+ feature_main=feature_main,
+ legend_position=legend_position,
+ legend_title=legend_title,
+ before_name=before_name,
+ after_name=after_name,
+ xlog=xlog,
+ ylog=ylog,
+ ncol=ncol,
+ include_non_zero_sum=include_non_zero_sum,
+ non_zero_samp_xlab=non_zero_samp_xlab,
+ non_zero_samp_main=non_zero_samp_main,
+ non_zero_feat_xlab=non_zero_feat_xlab,
+ non_zero_feat_main=non_zero_feat_main,
+ path=path,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ show=show,
+ )
diff --git a/src/biofit/preprocessing/filtering/row_abundance/row_abundance.py b/src/biofit/preprocessing/filtering/row_abundance/row_abundance.py
new file mode 100644
index 0000000..635f3ae
--- /dev/null
+++ b/src/biofit/preprocessing/filtering/row_abundance/row_abundance.py
@@ -0,0 +1,312 @@
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Type, Union
+
+import numpy as np
+import pandas as pd
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biofit.utils import logging
+
+from ..filtering import SampleFilter, SampleFilterConfig
+
+if TYPE_CHECKING:
+ import polars as pl
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class AbundanceSampleFilterConfig(SampleFilterConfig):
+ lower_threshold: Union[int, str] = "auto"
+ upper_threshold: Union[int, str] = None
+ processor_name: str = field(default="row_abundance", init=False, repr=False)
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+
+
+@dataclass
+class AbundanceSampleFilterConfigForOTU(AbundanceSampleFilterConfig):
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ dataset_name: str = field(default="otu", init=False, repr=False)
+ lower_threshold: Union[int, str] = "auto"
+ upper_threshold: Union[int, str] = "auto"
+
+ def __post_init__(self):
+ self._transform_process_desc = f"SampleFiltering out samples with less than {self.lower_threshold} total otu abundance"
+
+
+class AbundanceSampleFilter(SampleFilter):
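+    """Filter out samples whose total abundance falls outside the configured
+    thresholds.
+
+    With "auto" thresholds, the bounds are derived from the IQR of the
+    per-sample sums using Tukey's fences (see the `_fit_*` methods below).
+    """
+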
+ config_class = AbundanceSampleFilterConfig
+ config: AbundanceSampleFilterConfig
+
+ IQRs = None
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "AbundanceSampleFilter":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _process_fit_input(self, input, **kwargs):
+ if (
+ self.config.lower_threshold != "auto"
+ and self.config.upper_threshold != "auto"
+ ):
+ kwargs["fn_kwargs"]["fn"] = None
+ return super()._process_fit_input(input, **kwargs)
+
+ def _fit_polars(self, X: "pl.DataFrame"):
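+        # "auto" thresholds use Tukey's fences on the per-sample totals:
+        # lower = Q1 - 1.5 * IQR, upper = Q3 + 1.5 * IQR, with IQR = Q3 - Q1.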
+ iqr = [X.sum_horizontal().quantile(0.25), X.sum_horizontal().quantile(0.75)]
+ if self.config.lower_threshold == "auto":
+ self.config.lower_threshold = iqr[0] - 1.5 * (iqr[1] - iqr[0])
+ if self.config.upper_threshold == "auto":
+ self.config.upper_threshold = iqr[1] + 1.5 * (iqr[1] - iqr[0])
+ return self
+
+ def _partial_fit_polars(self, X: "pl.DataFrame"):
+ iqr = [X.sum_horizontal().quantile(0.25), X.sum_horizontal().quantile(0.75)]
+ if self.IQRs is None:
+ self.IQRs = [[iqr[0]], [iqr[1]]]
+ else:
+ self.IQRs[0].append(iqr[0])
+ self.IQRs[1].append(iqr[1])
+ return self
+
+ def _fit_pandas(self, X: "pd.DataFrame"):
+ iqr = [X.sum(axis=1).quantile(0.25), X.sum(axis=1).quantile(0.75)]
+ if self.config.lower_threshold == "auto":
+ self.config.lower_threshold = iqr[0] - 1.5 * (iqr[1] - iqr[0])
+ if self.config.upper_threshold == "auto":
+ self.config.upper_threshold = iqr[1] + 1.5 * (iqr[1] - iqr[0])
+ return self
+
+ def _partial_fit_pandas(self, X: "pd.DataFrame"):
+ iqr = [X.sum(axis=1).quantile(0.25), X.sum(axis=1).quantile(0.75)]
+ if self.IQRs is None:
+ self.IQRs = [[iqr[0]], [iqr[1]]]
+ else:
+ self.IQRs[0].append(iqr[0])
+ self.IQRs[1].append(iqr[1])
+ return self
+
+ def _fit_numpy(self, X: np.ndarray):
+ iqr = [np.quantile(X.sum(axis=1), 0.25), np.quantile(X.sum(axis=1), 0.75)]
+ if self.config.lower_threshold == "auto":
+ self.config.lower_threshold = iqr[0] - 1.5 * (iqr[1] - iqr[0])
+ if self.config.upper_threshold == "auto":
+ self.config.upper_threshold = iqr[1] + 1.5 * (iqr[1] - iqr[0])
+ return self
+
+ def _partial_fit_numpy(self, X: np.ndarray):
+ iqr = [np.quantile(X.sum(axis=1), 0.25), np.quantile(X.sum(axis=1), 0.75)]
+ if self.IQRs is None:
+ self.IQRs = [[iqr[0]], [iqr[1]]]
+ else:
+ self.IQRs[0].append(iqr[0])
+ self.IQRs[1].append(iqr[1])
+ return self
+
+ def _pool_fit_any(self, partial_results: List["AbundanceSampleFilter"]):
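+        # Combine batch-level quartile estimates by averaging them, then derive
+        # the Tukey fences from the pooled Q1/Q3 (an approximation of the exact
+        # quartiles over the full dataset).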
+ IQRs = [[], []]
+ for result in partial_results:
+ IQRs[0].extend(result.IQRs[0])
+ IQRs[1].extend(result.IQRs[1])
+ IQRs[0] = np.mean(IQRs[0])
+ IQRs[1] = np.mean(IQRs[1])
+ if self.config.lower_threshold == "auto":
+ self.config.lower_threshold = IQRs[0] - 1.5 * (IQRs[1] - IQRs[0])
+ if self.config.upper_threshold == "auto":
+ self.config.upper_threshold = IQRs[1] + 1.5 * (IQRs[1] - IQRs[0])
+ return self
+
+ def _process_fit_output(self, input, out):
+ if out.config.lower_threshold and out.config.upper_threshold:
+ logger.info(
+ f"Using {out.config.lower_threshold} as lower threshold and "
+ f"{out.config.upper_threshold} as upper threshold for filtering samples"
+ )
+        elif out.config.lower_threshold:
+ logger.info(
+ f"Using {out.config.lower_threshold} as lower threshold for filtering samples"
+ )
+ elif out.config.upper_threshold:
+ logger.info(
+ f"Using {out.config.upper_threshold} as upper threshold for filtering samples"
+ )
+ return super()._process_fit_output(input, out)
+
+ def _process_transform_input(self, input, **kwargs):
+        if self.config.lower_threshold and self.config.upper_threshold:
+            kwargs["desc"] = (
+                f"Filtering out samples with less than {self.config.lower_threshold:.2f} or more than {self.config.upper_threshold:.2f} total abundance"
+            )
+        elif self.config.lower_threshold:
+            kwargs["desc"] = (
+                f"Filtering out samples with less than {self.config.lower_threshold:.2f} total abundance"
+            )
+        elif self.config.upper_threshold:
+            kwargs["desc"] = (
+                f"Filtering out samples with more than {self.config.upper_threshold:.2f} total abundance"
+            )
+
+ return super()._process_transform_input(input, **kwargs)
+
+ def _transform_polars(self, X: "pl.DataFrame"):
+ row_sum = X.sum_horizontal()
+ if self.config.upper_threshold and self.config.lower_threshold:
+ return (row_sum > self.config.lower_threshold) & (
+ row_sum < self.config.upper_threshold
+ )
+ elif self.config.upper_threshold:
+ return row_sum < self.config.upper_threshold
+ elif self.config.lower_threshold:
+ return row_sum > self.config.lower_threshold
+
+ def _transform_pandas(self, X: "pd.DataFrame"):
+ row_sum = X.sum(axis=1)
+ if self.config.upper_threshold and self.config.lower_threshold:
+ return (row_sum > self.config.lower_threshold) & (
+ row_sum < self.config.upper_threshold
+ )
+ elif self.config.upper_threshold:
+ return row_sum < self.config.upper_threshold
+ elif self.config.lower_threshold:
+ return row_sum > self.config.lower_threshold
+
+ def _transform_numpy(self, X: np.ndarray):
+ row_sum = X.sum(axis=1)
+ if self.config.upper_threshold and self.config.lower_threshold:
+ return (row_sum > self.config.lower_threshold) & (
+ row_sum < self.config.upper_threshold
+ )
+ elif self.config.upper_threshold:
+ return row_sum < self.config.upper_threshold
+ elif self.config.lower_threshold:
+ return row_sum > self.config.lower_threshold
diff --git a/src/biofit/preprocessing/imputation/__init__.py b/src/biofit/preprocessing/imputation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/biofit/preprocessing/imputation/imputation.py b/src/biofit/preprocessing/imputation/imputation.py
new file mode 100644
index 0000000..46ad684
--- /dev/null
+++ b/src/biofit/preprocessing/imputation/imputation.py
@@ -0,0 +1,14 @@
+from dataclasses import dataclass, field
+
+from biofit.processing import BaseProcessor, ProcessorConfig
+
+
+@dataclass
+class ImputationConfig(ProcessorConfig):
+ """Base class for imputation processor configurations."""
+
+ processor_name: str = field(default="imputation", init=False, repr=False)
+
+
+class Imputation(BaseProcessor):
+ """Base class for imputation processors."""
diff --git a/src/biofit/preprocessing/resampling/__init__.py b/src/biofit/preprocessing/resampling/__init__.py
new file mode 100644
index 0000000..72aaaf1
--- /dev/null
+++ b/src/biofit/preprocessing/resampling/__init__.py
@@ -0,0 +1,2 @@
+# ruff: noqa
+from .upsampling import *
diff --git a/src/biofit/preprocessing/resampling/resampling.py b/src/biofit/preprocessing/resampling/resampling.py
new file mode 100644
index 0000000..a54da1a
--- /dev/null
+++ b/src/biofit/preprocessing/resampling/resampling.py
@@ -0,0 +1,52 @@
+from dataclasses import dataclass, field
+
+import pandas as pd
+import pyarrow as pa
+from biocore import DataHandler
+
+from biofit.processing import BaseProcessor, ProcessorConfig
+from biofit.utils import logging
+from biofit.utils.table_util import string_to_arrow
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class ResamplerConfig(ProcessorConfig):
+ processor_type: str = field(default="resampling", init=False, repr=False)
+
+
+class Resampler(BaseProcessor):
+    # Resampling does not depend on feature values when transforming
+    _feature_dependent = False
+
+ def _process_transform_batch_output(self, input, out, **fn_kwargs):
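+        # Select the resampled rows; when no indices are returned, emit an empty
+        # table that still preserves the input's column names and Arrow dtypes.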
+ inds = out["indices"]
+ if len(inds):
+ return DataHandler.select_rows(input, inds)
+ else:
+ cols = DataHandler.get_column_names(input, generate_cols=True)
+ schema = pa.schema(
+ [
+ pa.field(k, string_to_arrow(v))
+ for k, v in DataHandler.get_dtypes(input).items()
+ ]
+ )
+ return pa.Table.from_pandas(
+ pd.DataFrame(columns=cols), preserve_index=False, schema=schema
+ )
+
+ def _process_transform_output(self, output, *args, **kwargs):
+ init_num_rows = DataHandler.get_shape(args[0])[0]
+ final_num_rows = DataHandler.get_shape(output)[0]
+ if final_num_rows != init_num_rows:
+ logger.info(
+ f'"{self.__class__.__name__}": Resampled to {final_num_rows} from '
+ f"{init_num_rows} samples"
+ )
+ else:
+ logger.info(f'"{self.__class__.__name__}": No resampling performed')
+ return output
diff --git a/src/biofit/preprocessing/resampling/upsampling/__init__.py b/src/biofit/preprocessing/resampling/upsampling/__init__.py
new file mode 100644
index 0000000..f61112e
--- /dev/null
+++ b/src/biofit/preprocessing/resampling/upsampling/__init__.py
@@ -0,0 +1,7 @@
+# ruff: noqa
+from .upsampling import (
+ UpSamplerConfig,
+ UpSampler,
+ UpSamplerConfigForOTU,
+ UpSamplerConfigForSNP,
+)
diff --git a/src/biofit/preprocessing/resampling/upsampling/upsampling.py b/src/biofit/preprocessing/resampling/upsampling/upsampling.py
new file mode 100644
index 0000000..cf06d0f
--- /dev/null
+++ b/src/biofit/preprocessing/resampling/upsampling/upsampling.py
@@ -0,0 +1,479 @@
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Type, Union
+
+import numpy as np
+from biocore import DataHandler
+from biocore.utils.import_util import is_imblearn_available, requires_backends
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.utils import logging
+
+from ..resampling import Resampler, ResamplerConfig
+
+if TYPE_CHECKING:
+ from sklearn.base import BaseEstimator
+
+ from biofit.processing import BaseProcessor
+
+logger = logging.get_logger(__name__)
+
+OVER_SAMPLING_METHODS = {}
+
+if is_imblearn_available():
+ from imblearn.over_sampling import (
+ ADASYN,
+ SMOTE,
+ SMOTEN,
+ SMOTENC,
+ SVMSMOTE,
+ KMeansSMOTE,
+ RandomOverSampler,
+ )
+
+ OVER_SAMPLING_METHODS = {
+ "random": RandomOverSampler,
+ "smote": SMOTE,
+ "smoten": SMOTEN,
+ "smotenc": SMOTENC,
+ "svmsmote": SVMSMOTE,
+ "kmeanssmote": KMeansSMOTE,
+ "adasyn": ADASYN,
+ }
+
+
+@dataclass
+class UpSamplerConfig(ResamplerConfig):
+ # process descriptions
+ processor_name: str = field(default="upsampling", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+
+ input_columns: List[str] = field(default=None, init=False, repr=False)
+ target_column: str = field(default=None, init=False, repr=False)
+ method: str = "random"
+
+ sampling_strategy: str = "auto"
+ random_state: int = None
+
+ # RandomOverSampler specific attributes
+ shrinkage: dict = None
+
+ # BaseSMOTE attributes
+    k_neighbors: int = 5
+    n_jobs: int = None
+
+ # SMOTEN specific attributes
+ categorical_encoder: Union["BaseEstimator", "BaseProcessor"] = None
+
+ # SMOTENC specific attributes
+ categorical_features: list = None
+
+ # SVMSMOTE specific attributes
+ m_neighbors: int = 10
+ svm_estimator: Union["BaseEstimator", "BaseProcessor"] = None
+ out_step: float = 0.5
+
+ # KMeansSMOTE specific attributes
+ kmeans_estimator: Union[int, "BaseEstimator", "BaseProcessor"] = None
+ density_exponent: Union[float, str] = "auto"
+ cluster_balance_threshold: Union[float, str] = "auto"
+
+ # ADASYN specific attributes
+ n_neighbors: int = 5
+
+ def __post_init__(self):
+ requires_backends("upsampling", "imblearn")
+ if self.random_state is None:
+ self.random_state = np.random.randint(0, np.iinfo(np.int32).max)
+ if self.method not in OVER_SAMPLING_METHODS:
+ raise ValueError(
+ f"Invalid method {self.method}. Valid methods are {list(OVER_SAMPLING_METHODS.keys())}"
+ )
+
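+        # Forward only the kwargs relevant to the chosen sampler; imblearn
+        # samplers raise on unknown keyword arguments.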
+ self.sampler_kwargs_ = {}
+ self.sampler_kwargs_["sampling_strategy"] = self.sampling_strategy
+ self.sampler_kwargs_["random_state"] = self.random_state
+
+ if self.method == "random":
+ self.sampler_kwargs_["sampling_strategy"] = self.sampling_strategy
+
+ if self.method in [
+ "smote",
+ "smoten",
+ "smotenc",
+ "svmsmote",
+ "kmeanssmote",
+ "adasyn",
+ ]:
+ self.sampler_kwargs_["n_jobs"] = self.n_jobs
+
+ if self.method in ["smote", "smoten", "smotenc", "svmsmote", "kmeanssmote"]:
+ self.sampler_kwargs_["k_neighbors"] = self.k_neighbors
+
+ if self.method in ["smoten", "smotenc"]:
+ self.sampler_kwargs_["categorical_encoder"] = self.categorical_encoder
+
+ if self.method == "smotenc":
+ self.sampler_kwargs_["categorical_features"] = self.categorical_features
+
+ if self.method == "svmsmote":
+ self.sampler_kwargs_["m_neighbors"] = self.m_neighbors
+ self.sampler_kwargs_["svm_estimator"] = self.svm_estimator
+ self.sampler_kwargs_["out_step"] = self.out_step
+
+ if self.method == "kmeanssmote":
+ self.sampler_kwargs_["kmeans_estimator"] = self.kmeans_estimator
+ self.sampler_kwargs_["density_exponent"] = self.density_exponent
+ self.sampler_kwargs_["cluster_balance_threshold"] = (
+ self.cluster_balance_threshold
+ )
+
+ if self.method == "adasyn":
+ self.sampler_kwargs_["n_neighbors"] = self.n_neighbors
+
+
+@dataclass
+class UpSamplerConfigForMetagenomics(UpSamplerConfig):
+ # dataset specific attributes
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ (get_feature("Abundance"), get_feature("ReadCount")),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="metagenomics", init=False, repr=False)
+
+
+@dataclass
+class UpSamplerConfigForOTU(UpSamplerConfig):
+ # dataset specific attributes
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("Abundance"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ dataset_name: str = field(default="otu", init=False, repr=False)
+
+
+@dataclass
+class UpSamplerConfigForSNP(UpSamplerConfig):
+ # dataset specific attributes
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("GenomicVariant"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False
+ )
+ dataset_name: str = field(default="snp", init=False, repr=False)
+
+
+class UpSampler(Resampler):
+ output_dtype = "float32"
+
+ # config class
+ config_class = UpSamplerConfig
+ config: UpSamplerConfig
+
+ def __init__(
+ self,
+ method: str = "random",
+ sampling_strategy: str = "auto",
+ random_state: int = None,
+ shrinkage: dict = None,
+ k_neighbors=5,
+ n_jobs=None,
+ categorical_encoder: Union["BaseEstimator", "BaseProcessor"] = None,
+ categorical_features: list = None,
+ m_neighbors: int = 10,
+ svm_estimator: Union["BaseEstimator", "BaseProcessor"] = None,
+ out_step: float = 0.5,
+ kmeans_estimator: Union[int, "BaseEstimator", "BaseProcessor"] = None,
+ density_exponent: Union[float, str] = "auto",
+ cluster_balance_threshold: Union[float, str] = "auto",
+ n_neighbors: int = 5,
+ config: Union[
+ UpSamplerConfig,
+ UpSamplerConfigForOTU,
+ UpSamplerConfigForSNP,
+ UpSamplerConfigForMetagenomics,
+ ] = None,
+ **kwargs,
+ ):
+ """
+ Upsample the minority class(es) in the dataset.
+
+ Args:
+ method (str, 'random'):
+ The method to use for oversampling. Possible options are:
+ - 'random': Random Over Sampler
+ - 'smote': Synthetic Minority Over-sampling Technique
+                - 'smoten': SMOTE for nominal (categorical) features only
+ - 'smotenc': SMOTE for numerical and categorical features
+ - 'svmsmote': SVM-SMOTE
+ - 'kmeanssmote': KMeans-SMOTE
+ - 'adasyn': Adaptive Synthetic Sampling Approach for Imbalanced Learning
+
+ sampling_strategy (str, 'auto'):
+ The sampling strategy to use for oversampling. Possible options are:
+ - 'auto': Automatically resample the minority class(es) to the majority class(es) size.
+ - 'all': Resample all classes to the same size.
+
+ random_state (int, *optional*):
+ The random state to use for sampling.
+
+ shrinkage (dict, *optional*):
+ The shrinkage parameter for RandomOverSampler.
+
+ k_neighbors (int, 5):
+ The number of nearest neighbors to use for SMOTE, SMOTEN, SMOTENC, SVMSMOTE, and KMeansSMOTE.
+
+ n_jobs (int, *optional*):
+ The number of jobs to use for SMOTE, SMOTEN, SMOTENC, SVMSMOTE, and KMeansSMOTE.
+
+ categorical_encoder (Union[BaseEstimator, BaseProcessor], *optional*):
+ The encoder to use for SMOTEN.
+
+ categorical_features (list, *optional*):
+ The list of categorical features to use for SMOTENC.
+
+ m_neighbors (int, 10):
+ The number of nearest neighbors to use for SVMSMOTE.
+
+ svm_estimator (Union[BaseEstimator, BaseProcessor], *optional*):
+ The estimator to use for SVMSMOTE.
+
+ out_step (float, 0.5):
+ The outlier step to use for SVMSMOTE.
+
+ kmeans_estimator (Union[int, BaseEstimator, BaseProcessor], *optional*):
+ The estimator to use for KMeansSMOTE.
+
+ density_exponent (Union[float, str], 'auto'):
+ The exponent to use for KMeansSMOTE.
+
+ cluster_balance_threshold (Union[float, str], 'auto'):
+ The balance threshold to use for KMeansSMOTE.
+
+ n_neighbors (int, 5):
+ The number of nearest neighbors to use for ADASYN.
+
+ config (Union[UpSamplerConfig, UpSamplerConfigForOTU, UpSamplerConfigForSNP, UpSamplerConfigForMetagenomics], *optional*):
+ The configuration to use for the upsampling process. If provided, the other arguments are ignored.
+ Use set_params to update the configuration after initialization.
+ **kwargs:
+ Additional keyword arguments to be passed to ProcessorConfig
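+
+        Example (a minimal sketch, assuming the imblearn backend is installed;
+        the column names and values below are illustrative only):
+
+            >>> import pandas as pd
+            >>> X = pd.DataFrame({"otu1": [5, 0, 3, 8], "label": [0, 0, 0, 1]})
+            >>> up = UpSampler(method="random", random_state=42)
+            >>> balanced = up.fit_transform(
+            ...     X, input_columns="otu1", target_column="label"
+            ... )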
+ """
+ super().__init__(
+ config=config,
+ method=method,
+ sampling_strategy=sampling_strategy,
+ random_state=random_state,
+ shrinkage=shrinkage,
+ k_neighbors=k_neighbors,
+ n_jobs=n_jobs,
+ categorical_encoder=categorical_encoder,
+ categorical_features=categorical_features,
+ m_neighbors=m_neighbors,
+ svm_estimator=svm_estimator,
+ out_step=out_step,
+ kmeans_estimator=kmeans_estimator,
+ density_exponent=density_exponent,
+ cluster_balance_threshold=cluster_balance_threshold,
+ n_neighbors=n_neighbors,
+ **kwargs,
+ )
+
+ self.over_sampler = OVER_SAMPLING_METHODS[self.config.method](
+ **self.config.sampler_kwargs_
+ )
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ self.over_sampler = OVER_SAMPLING_METHODS[self.config.method](
+ **self.config.sampler_kwargs_
+ )
+ return self
+
+ def fit(
+ self,
+ X,
+ y=None,
+ input_columns: SelectedColumnTypes = None,
+ target_column: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "UpSampler":
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns, target_column
+ )
+ return self._process_fit(
+ X,
+ y,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ y=None,
+ input_columns: SelectedColumnTypes = None,
+ target_column: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = True,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ y,
+ input_columns=input_columns,
+ target_column=target_column,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _fit_pandas(self, X, y):
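+        # Fit the imblearn sampler once on the full data; only the resampled row
+        # indices are kept and applied later in `_transform_any`.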
+ self.over_sampler.fit_resample(X, DataHandler.select_column(y, 0))
+ self.config.sample_indices = self.over_sampler.sample_indices_
+ return self
+
+ def _fit_numpy(self, X, y):
+ self.over_sampler.fit_resample(X, DataHandler.select_column(y, 0))
+ self.config.sample_indices = self.over_sampler.sample_indices_
+ return self
+
+ def _process_transform_input(self, X, **kwargs):
+ batch_size = kwargs.get("batch_size", None)
+ if batch_size is not None and batch_size < DataHandler.get_shape(X)[0]:
+ logger.warning(
+ "Upsampling does not support batch processing. Ignoring batched and batch_size parameters."
+ )
+ kwargs["batch_size"] = None
+ self.config.fingerprint = kwargs.get("new_fingerprint", None)
+ kwargs["features"] = None
+ return X, kwargs
+
+ def _transform_any(self, X):
+ return {"indices": self.config.sample_indices.flatten()}
diff --git a/src/biofit/preprocessing/scaling/__init__.py b/src/biofit/preprocessing/scaling/__init__.py
new file mode 100644
index 0000000..f5808d2
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/__init__.py
@@ -0,0 +1,9 @@
+# ruff: noqa
+from .tmm import *
+from .css import *
+from .clr import *
+from .relative_abundance import *
+from .plot_scaling import (
+ ScalerPlotter,
+ ScalerPlotterConfig,
+)
diff --git a/src/biofit/preprocessing/scaling/clr/__init__.py b/src/biofit/preprocessing/scaling/clr/__init__.py
new file mode 100644
index 0000000..11c1e56
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/clr/__init__.py
@@ -0,0 +1,2 @@
+# ruff: noqa
+from .clr import CLRScaler, CLRScalerConfig
diff --git a/src/biofit/preprocessing/scaling/clr/clr.py b/src/biofit/preprocessing/scaling/clr/clr.py
new file mode 100644
index 0000000..bd39977
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/clr/clr.py
@@ -0,0 +1,187 @@
+from dataclasses import dataclass, field
+from typing import List, Type
+
+import numpy as np
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biofit.utils import logging
+
+from ..scaling import Scaler, ScalerConfig
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class CLRScalerConfig(ScalerConfig):
+ """
+    Configuration for CLRScaler.
+ """
+
+ processor_name: str = field(default="css", init=False, repr=False)
+ _fit_process_desc: str = field(
+ default="Calculating cumulative sum scaling percentile", init=False, repr=False
+ )
+ _transform_process_desc: str = field(
+ default="Applying cumulative sum scaling", init=False, repr=False
+ )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+
+ pseudocount: float = 1e-6
+ input_columns: List[str] = None
+
+
+class CLRScaler(Scaler):
+ """
+    CLRScaler applies the centered log-ratio (CLR) transformation to the input data.
+ """
+
+ output_dtype = "float64"
+ config_class = CLRScalerConfig
+ config: CLRScalerConfig
+
+ def __init__(
+ self, pseudocount: float = 1e-6, config: CLRScalerConfig = None, **kwargs
+ ):
+ super().__init__(config=config, pseudocount=pseudocount, **kwargs)
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "CLRScaler":
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns or self.config.input_columns
+ )
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(
+ input_columns or self.config.input_columns
+ )
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _transform_numpy(self, X: np.ndarray):
+ X = X + self.config.pseudocount
+ # Calculating the geometric mean of each sample
+ from scipy.stats import gmean
+
+ r_geo_mean = gmean(X, axis=1)[:, np.newaxis]
+
+ return np.log(X / r_geo_mean)
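
The CLR transform above is the textbook centered log-ratio: add a pseudocount, divide each sample by its geometric mean, take the log. A self-contained sketch of the same arithmetic:

```python
# Standalone sketch of the CLR arithmetic in _transform_numpy above.
import numpy as np
from scipy.stats import gmean

def clr(X: np.ndarray, pseudocount: float = 1e-6) -> np.ndarray:
    X = X + pseudocount                         # avoid log(0) on zero counts
    geo_mean = gmean(X, axis=1)[:, np.newaxis]  # geometric mean per sample (row)
    return np.log(X / geo_mean)

counts = np.array([[10.0, 0.0, 30.0], [5.0, 5.0, 5.0]])
# Each CLR-transformed row sums to ~0 by construction:
print(clr(counts).sum(axis=1))  # ≈ [0.0, 0.0]
```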
diff --git a/src/biofit/preprocessing/scaling/css/__init__.py b/src/biofit/preprocessing/scaling/css/__init__.py
new file mode 100644
index 0000000..1a32968
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/css/__init__.py
@@ -0,0 +1,19 @@
+# ruff: noqa
+from .css import (
+ CumulativeSumScaler,
+ CumulativeSumScalerConfig,
+ CumulativeSumScalerConfigForOTU,
+ CumulativeSumScalerConfigForMetagenomics,
+ # CumulativeSumScalerConfigForGenomics,
+ # CumulativeSumScalerConfigForSNP,
+)
+from .plot_css import (
+ CumulativeSumScalerPlotter,
+ CumulativeSumScalerPlotterConfig,
+ CumulativeSumScalerPlotterConfigForMetagenomics,
+ CumulativeSumScalerPlotterConfigForOTU,
+ CumulativeSumScalerPlotterConfigForASV,
+ CumulativeSumScalerPlotterConfigForGenomics,
+ CumulativeSumScalerPlotterConfigForSNP,
+ CumulativeSumScalerPlotterConfigForReadCount,
+)
diff --git a/src/biofit/preprocessing/scaling/css/css.py b/src/biofit/preprocessing/scaling/css/css.py
new file mode 100644
index 0000000..19abdea
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/css/css.py
@@ -0,0 +1,348 @@
+from dataclasses import dataclass, field
+from typing import List, Type
+
+import numpy as np
+from scipy.interpolate import interp1d
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.utils import logging
+
+from ..scaling import Scaler, ScalerConfig
+
+logger = logging.get_logger(__name__)
+
+
+def css_percentile_fast(mat, relative_threshold=0.1):
+ """Calculates the percentile for which to sum counts up to and scale by.
+
+ Args:
+
+ """
+ # Check for sample with one or zero features
+ if np.any(np.sum(mat > 0, axis=1) <= 1):
+ raise ValueError("Warning: sample with one or zero features")
+
+ mat = mat.astype(float)
+ # Sort each row with elements greater than zero
+ smat = np.sort(mat * (mat > 0), axis=1)[:, ::-1]
+ smat[smat == 0] = np.nan
+
+    # Pad with NaNs to make all rows the same length
+ max_length = np.max(np.nansum(smat > 0, axis=1))
+ padded_smat = np.full((smat.shape[0], max_length), np.nan)
+ for i in range(smat.shape[0]):
+ valid_values = smat[i, smat[i] > 0]
+ padded_smat[i, : len(valid_values)] = valid_values
+
+ # Calculate quantiles for each row
+ quantiles = np.nanquantile(
+ padded_smat, np.linspace(0, 1, padded_smat.shape[1]), axis=1
+ )
+
+ padded_smat[np.isnan(padded_smat)] = 0
+
+ # Compute column means, ignoring NaNs
+ ref1 = np.nanmean(padded_smat, axis=0)[::-1]
+
+ # Calculate differences
+ diffr = ref1[:, np.newaxis] - quantiles
+
+ # Calculate median of absolute differences
+ diffr1 = np.nanmedian(np.abs(diffr), axis=1)
+
+ diffr1[diffr1 == 0] = np.nan
+
+ # Determine threshold
+ rel_diff = np.abs(np.diff(diffr1)) / diffr1[:-1]
+
+    candidates = np.where(rel_diff > relative_threshold)[0]
+    if len(candidates) == 0:
+        # no jump exceeds the threshold; fall back to the default percentile
+        return 0.5
+    x = (candidates[0] + 1) / len(diffr1)
+    if x <= 0.5:
+        logger.info(
+            f"Percentile calculated at {x}, which is at or below 0.5. Using the default of 0.5 instead."
+        )
+        x = 0.5
+ return x
+
+
+def css_percentile(mat: np.ndarray, approx=False, relative_threshold=0.1):
+ if np.any(np.sum(mat, axis=1) == 0):
+ raise ValueError("Warning: empty feature", mat)
+
+ # Sorting each row
+ smat = np.sort(mat, axis=1)
+ ref = np.mean(smat, axis=0)
+ orig_dtype = mat.dtype
+ mat = mat.astype(float) if not np.issubdtype(orig_dtype, np.floating) else mat
+
+ if not mat.flags["WRITEABLE"]:
+ # copy the array if it is not writeable
+ mat = np.array(mat)
+ mat[mat == 0] = np.nan
+
+ refS = np.sort(ref)
+
+ k = np.where(refS > 0)[0][0]
+ lo = len(refS) - k
+
+ if not approx:
+ diffr = np.apply_along_axis(
+ lambda row: refS[k:] - np.nanquantile(row, np.linspace(0, 1, lo), axis=0),
+ axis=1,
+ arr=mat,
+ )
+ else:
+
+ def f(row):
+ srow = np.sort(row)
+ rrow = (srow[0], np.nanmax(srow))
+ y = np.arange(len(row))
+ return refS[k:] - interp1d(row, y)(np.linspace(*rrow, lo))
+
+ diffr = np.apply_along_axis(f, axis=1, arr=mat)
+
+ mat[np.isnan(mat)] = 0
+ mat = mat.astype(orig_dtype) if not np.issubdtype(orig_dtype, np.floating) else mat
+
+ diffr2 = np.nanmedian(np.abs(diffr), axis=0)
+
+ rel_diff = np.abs(np.diff(diffr2)) / diffr2[:-1]
+ if len(rel_diff) == 0:
+ return 0.5
+ x = (np.where(rel_diff > relative_threshold)[0][0] + 1) / len(diffr2)
+ if x <= 0.5:
+ logger.info(
+ f"Percentile calculated at {x}, which is less than 0.5. Using default value instead."
+ )
+ x = 0.5
+
+ return x
+
+
+@dataclass
+class CumulativeSumScalerConfig(ScalerConfig):
+ """
+ Configuration for CumulativeSumScaler.
+ """
+
+ processor_name: str = field(default="css", init=False, repr=False)
+ _fit_process_desc: str = field(
+ default="Calculating cumulative sum scaling percentile", init=False, repr=False
+ )
+ _transform_process_desc: str = field(
+ default="Applying cumulative sum scaling", init=False, repr=False
+ )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+
+ scale: int = 1000
+ relative_threshold: float = 0.5
+ approx: bool = False
+ percentile: float = None
+
+
+@dataclass
+class CumulativeSumScalerConfigForMetagenomics(CumulativeSumScalerConfig):
+ """
+    Configuration for CumulativeSumScaler, tailored to metagenomics data.
+ """
+
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [(get_feature("ReadCount"), get_feature("Abundance"))],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [(get_feature("ReadCount"), get_feature("Abundance"))],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="metagenomics", init=False, repr=False)
+
+
+@dataclass
+class CumulativeSumScalerConfigForOTU(CumulativeSumScalerConfig):
+ """
+    Configuration for CumulativeSumScaler, tailored to OTU abundance data.
+    """
+
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ dataset_name: str = field(default="otu", init=False, repr=False)
+
+
+class CumulativeSumScaler(Scaler):
+ """
+    CumulativeSumScaler applies cumulative sum scaling (CSS) to the input data.
+ """
+
+ output_dtype = "float64"
+ config_class = CumulativeSumScalerConfig
+ config: CumulativeSumScalerConfig
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ if self.config.percentile is not None:
+            # Passing datasets to the fit method has overhead (e.g. converting
+            # between formats, preparing mapping jobs) that is unnecessary when
+            # the percentile is already known, so we remove the fit method to
+            # avoid it being called.
+ # TODO: This is a temporary solution. We should find a better way to handle this
+ delattr(self, "fit_numpy")
+ return self
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "CumulativeSumScaler":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _fit_numpy(self, X: np.ndarray):
+        # The fast implementation of CSS requires every sample to have at least
+        # two non-zero features; otherwise fall back to the slower implementation.
+ if np.any(np.sum(X > 0, axis=1) <= 1):
+ self.config.percentile = css_percentile(X)
+ else:
+ self.config.percentile = css_percentile_fast(X)
+ return self
+
+ def _transform_numpy(self, X: np.ndarray):
+ xx = np.where(X == 0, np.nan, X)
+
+ # Compute row quantiles
+ qs = np.nanquantile(xx, self.config.percentile, axis=1, keepdims=True)
+
+ xx_adj = xx - np.finfo(float).eps
+ norm_factors = np.nansum(np.where(xx_adj <= qs, xx_adj, 0), axis=1)
+ norm_factors[norm_factors == 0] = np.nan
+ norm_factors = norm_factors / self.config.scale
+
+ return X / norm_factors[:, np.newaxis]
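
Given the fitted percentile, the transform divides each sample by the (scaled) sum of its counts at or below that percentile's quantile. The same step, extracted into a standalone function for readability:

```python
# Standalone restatement of the CSS transform step in _transform_numpy above.
import numpy as np

def css_transform(X: np.ndarray, percentile: float, scale: float = 1000.0) -> np.ndarray:
    xx = np.where(X == 0, np.nan, X.astype(float))   # zeros don't contribute
    qs = np.nanquantile(xx, percentile, axis=1, keepdims=True)
    xx_adj = xx - np.finfo(float).eps                # keep values equal to the quantile
    norm = np.nansum(np.where(xx_adj <= qs, xx_adj, 0), axis=1)
    norm[norm == 0] = np.nan                         # degenerate samples become NaN
    return X / (norm / scale)[:, np.newaxis]

counts = np.array([[0.0, 2.0, 4.0, 100.0], [1.0, 3.0, 5.0, 7.0]])
print(css_transform(counts, percentile=0.5))
```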
diff --git a/src/biofit/preprocessing/scaling/css/plot_css.py b/src/biofit/preprocessing/scaling/css/plot_css.py
new file mode 100644
index 0000000..49e3013
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/css/plot_css.py
@@ -0,0 +1,297 @@
+from dataclasses import dataclass, field
+from typing import List, Optional, Type
+
+from biocore import DataHandler
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biocore.utils.py_util import is_bioset
+from biofit.utils.types import Unset
+
+from ..plot_scaling import ScalerPlotter, ScalerPlotterConfig
+
+
+@dataclass
+class CumulativeSumScalerPlotterConfig(ScalerPlotterConfig):
+ processor_name: str = field(default="css", init=False, repr=False)
+ _compare: bool = field(default=True, init=False, repr=False)
+    _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ None,
+ None,
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("METADATA_FEATURE_TYPES"),
+ get_feature("METADATA_FEATURE_TYPES"),
+ None,
+ None,
+ ],
+ init=False,
+ repr=False,
+ )
+
+ input_columns1: Optional[str] = None
+ input_columns2: Optional[str] = None
+ label_name1: Optional[str] = None
+ label_name2: Optional[str] = None
+ ylab: Optional[str] = None
+ xlab: Optional[str] = None
+ ylim: Optional[list] = None
+ before_title: str = "Before CSS"
+ after_title: str = "After CSS"
+ legend_position: str = "top"
+ add_box: bool = True
+ horizontal_plot: bool = False
+ order: bool = False
+ col_set: str = "Set1"
+ cols: Optional[str] = None
+ log_num: Optional[str] = None
+ show_outliers: bool = True
+
+
+@dataclass
+class CumulativeSumScalerPlotterConfigForMetagenomics(CumulativeSumScalerPlotterConfig):
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ (get_feature("Abundance"), get_feature("ReadCount")),
+ (get_feature("Abundance"), get_feature("ReadCount")),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="metagenomics", init=False, repr=False)
+ log_num: Optional[str] = "log2_1p"
+
+
+@dataclass
+class CumulativeSumScalerPlotterConfigForOTU(CumulativeSumScalerPlotterConfig):
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("Abundance"),
+ get_feature("Abundance"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="otu", init=False, repr=False)
+ ylab: Optional[str] = "OTU Abundance"
+ log_num: Optional[str] = "log10_1p"
+
+
+@dataclass
+class CumulativeSumScalerPlotterConfigForASV(CumulativeSumScalerPlotterConfig):
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("Abundance"),
+ get_feature("Abundance"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="asv", init=False, repr=False)
+
+
+@dataclass
+class CumulativeSumScalerPlotterConfigForGenomics(CumulativeSumScalerPlotterConfig):
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("GenomicVariant"),
+ get_feature("GenomicVariant"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("GenomicVariant"),
+ get_feature("GenomicVariant"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="genomics", init=False, repr=False)
+
+
+@dataclass
+class CumulativeSumScalerPlotterConfigForSNP(CumulativeSumScalerPlotterConfig):
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("GenomicVariant"),
+ get_feature("GenomicVariant"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("GenomicVariant"),
+ get_feature("GenomicVariant"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="snp", init=False, repr=False)
+
+
+@dataclass
+class CumulativeSumScalerPlotterConfigForReadCount(CumulativeSumScalerPlotterConfig):
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("ReadCount"),
+ get_feature("ReadCount"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("ReadCount"),
+ get_feature("ReadCount"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="read_count", init=False, repr=False)
+
+
+class CumulativeSumScalerPlotter(ScalerPlotter):
+ config_class = CumulativeSumScalerPlotterConfig
+ config: CumulativeSumScalerPlotterConfig
+
+ def __init__(
+ self,
+ ylab: Optional[str] = Unset("None"),
+ xlab: Optional[str] = Unset("None"),
+ ylim: Optional[list] = Unset("None"),
+ title: str = Unset('"Violin Plot"'),
+ legend_position: str = Unset('"top"'),
+ add_box: bool = Unset("True"),
+ horizontal_plot: bool = Unset("False"),
+ order: bool = Unset("False"),
+ col_set: str = Unset('"Set1"'),
+ cols: Optional[str] = Unset("None"),
+ log_num: Optional[int] = Unset("None"),
+ show_outliers: bool = Unset("True"),
+ config: CumulativeSumScalerPlotterConfig = None,
+ **kwargs,
+ ):
+ super().__init__(
+ config=config,
+ ylab=ylab,
+ xlab=xlab,
+ ylim=ylim,
+ title=title,
+ legend_position=legend_position,
+ add_box=add_box,
+ horizontal_plot=horizontal_plot,
+ order=order,
+ col_set=col_set,
+ cols=cols,
+ log_num=log_num,
+ show_outliers=show_outliers,
+ **kwargs,
+ )
+
+ def plot(
+ self,
+ x1,
+ x2=None,
+ y1=None,
+ y2=None,
+ input_columns1: SelectedColumnTypes = None,
+ input_columns2: SelectedColumnTypes = None,
+ label_name1: SelectedColumnTypes = None,
+ label_name2: SelectedColumnTypes = None,
+ ylab: Optional[str] = Unset("None"),
+ xlab: Optional[str] = Unset("None"),
+ ylim: Optional[list] = Unset("None"),
+ title: str = Unset('"Violin Plot"'),
+ legend_position: str = Unset('"top"'),
+ add_box: bool = Unset("True"),
+ horizontal_plot: bool = Unset("False"),
+ order: bool = Unset("False"),
+ col_set: str = Unset('"Set1"'),
+ cols: Optional[str] = Unset("None"),
+ log_num: Optional[int] = Unset("None"),
+ show_outliers: bool = Unset("True"),
+ path: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show: bool = True,
+ ):
+ if is_bioset(x2):
+ from biosets import get_target
+
+ if y2 is None:
+ y2 = get_target(x2)
+ if x1 is None or x2 is None:
+ raise ValueError("Must provide the before and after normalization.")
+ if y1 is not None and y2 is None:
+ raise ValueError("Must provide the target for the after normalization.")
+ if y1 is None:
+ if label_name1 is None:
+ raise ValueError(
+ "Must provide the target for the before normalization."
+ )
+ y1 = DataHandler.select_columns(x1, label_name1)
+ x1 = DataHandler.drop_columns(x1, label_name1)
+ if y2 is None:
+ if label_name2 is None:
+ raise ValueError("Must provide the target for the after normalization.")
+ y2 = DataHandler.select_columns(x2, label_name2)
+ x2 = DataHandler.drop_columns(x2, label_name2)
+
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns1, input_columns2, label_name1, label_name2
+ )
+ return self._plot(
+ x1,
+ x2,
+ y1,
+ y2,
+ ylab=ylab,
+ xlab=xlab,
+ ylim=ylim,
+ title=title,
+ legend_position=legend_position,
+ add_box=add_box,
+ horizontal_plot=horizontal_plot,
+ order=order,
+ col_set=col_set,
+ cols=cols,
+ log_num=log_num,
+ show_outliers=show_outliers,
+ path=path,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ show=show,
+ )
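
A hypothetical end-to-end call for the plotter above (the frame, column names, and output path are illustrative only; it assumes unused columns are kept in the scaled output, which is the `fit_transform` default):

```python
# Hypothetical usage; 'otu1', 'otu2', 'label', and the output path are made up.
import pandas as pd
from biofit.preprocessing.scaling.css import (
    CumulativeSumScaler,
    CumulativeSumScalerPlotter,
)

raw = pd.DataFrame(
    {"otu1": [5, 2, 4], "otu2": [1, 3, 5], "label": ["a", "b", "a"]}
)
scaled = CumulativeSumScaler().fit_transform(raw, input_columns=["otu1", "otu2"])

CumulativeSumScalerPlotter().plot(
    raw,
    scaled,
    input_columns1=["otu1", "otu2"],
    input_columns2=["otu1", "otu2"],
    label_name1="label",  # target is split out of x1/x2 when y1/y2 are omitted
    label_name2="label",
    path="css_violin.pdf",
)
```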
diff --git a/src/biofit/preprocessing/scaling/plot_scaling.R b/src/biofit/preprocessing/scaling/plot_scaling.R
new file mode 100644
index 0000000..568d166
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/plot_scaling.R
@@ -0,0 +1,56 @@
+source(file.path(R_SCRIPTS_PATH, "plotting_utils.R"))
+
+
+plot_scaler <- function(
+ x1,
+ x2 = NULL,
+ y1 = NULL,
+ y2 = NULL,
+ path = NULL,
+ input_columns1 = NULL,
+ input_columns2 = NULL,
+ label_name1 = NULL,
+ label_name2 = NULL,
+ ylab = NULL,
+ xlab = NULL,
+ ylim = NULL,
+ before_title = NULL,
+ after_title = NULL,
+ legend_position = NULL,
+ add_box = TRUE,
+ horizontal_plot = FALSE,
+ order = FALSE,
+ col_set = "Set1",
+ cols = NULL,
+ log_num = NULL,
+ show_outliers = TRUE) {
+
+ suppressPackageStartupMessages(require(patchwork))
+
+ if (is.null(y2)) {
+ y2 <- y1
+ }
+ p1 <- generate_violin(
+ x1, y1,
+ column = input_columns1, label_name = label_name1, ylim = ylim,
+ xlab = xlab, ylab = ylab, title = before_title, legend_position = legend_position,
+ add_box = add_box, horizontal_plot = horizontal_plot, order = order,
+ col_set = col_set, cols = cols, log_num = log_num, show_outliers = show_outliers
+ )
+ if (!is.null(x2)) {
+ p2 <- generate_violin(
+ x2, y2,
+ column = input_columns2, label_name = label_name2, ylim = ylim,
+ xlab = xlab, ylab = ylab, title = after_title, legend_position = legend_position,
+ add_box = add_box, horizontal_plot = horizontal_plot, order = order,
+ col_set = col_set, cols = cols, log_num = log_num, show_outliers = show_outliers
+ )
+ the_plot <- p1 / p2
+ } else {
+ the_plot <- p1
+ }
+ if (!is.null(path)) {
+ save_plots(path, plot = the_plot, width = 6, height = 6, dpi = 600)
+ }
+ return(path)
+}
diff --git a/src/biofit/preprocessing/scaling/plot_scaling.py b/src/biofit/preprocessing/scaling/plot_scaling.py
new file mode 100644
index 0000000..a8c1c66
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/plot_scaling.py
@@ -0,0 +1,51 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+
+from biocore import DataHandler
+from biocore.utils.import_util import is_biosets_available
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedFeatureTypes
+from biofit.visualization.plotting import BasePlotter, PlotterConfig
+
+
+@dataclass
+class ScalerPlotterConfig(PlotterConfig):
+ processor_type: str = field(default="scaling", init=False, repr=False)
+ r_source: str = field(
+ default=(Path(__file__).parent / "plot_scaling.R").as_posix(),
+ init=False,
+ repr=False,
+ )
+ main_method: str = field(default="plot_scaler", init=False, repr=False)
+ _fit_unused_feature_types: SelectedFeatureTypes = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: SelectedFeatureTypes = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+
+
+class ScalerPlotter(BasePlotter):
+ config_class = ScalerPlotterConfig
+ config: ScalerPlotterConfig
+
+ def plot_dataset(self, x1, x2, y1, y2):
+ if is_biosets_available():
+ from biosets import decode
+
+ y1 = decode(y1) if y1 is not None else None
+ y2 = decode(y2) if y2 is not None else None
+ return self.plot_arrow(
+ DataHandler.to_arrow(x1),
+ DataHandler.to_arrow(x2),
+ DataHandler.to_arrow(y1) if y1 is not None else None,
+ DataHandler.to_arrow(y2) if y2 is not None else None,
+ )
+
+ def plot_arrow(self, x1, x2, y1, y2):
+ return self.plotter(x1, x2, y1, y2, **self.config.get_params())
diff --git a/src/biofit/preprocessing/scaling/relative_abundance/__init__.py b/src/biofit/preprocessing/scaling/relative_abundance/__init__.py
new file mode 100644
index 0000000..136618e
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/relative_abundance/__init__.py
@@ -0,0 +1,11 @@
+# ruff: noqa
+from .plot_relative_abundance import (
+ RelativeAbundancePlotter,
+ RelativeAbundancePlotterConfig,
+ RelativeAbundancePlotterConfigForASV,
+ RelativeAbundancePlotterConfigForMaldi,
+ RelativeAbundancePlotterConfigForMetagenomics,
+ RelativeAbundancePlotterConfigForOTU,
+ RelativeAbundancePlotterConfigForSNP,
+)
+from .relative_abundance import RelativeAbundanceScaler, RelativeAbundanceScalerConfig
diff --git a/src/biofit/preprocessing/scaling/relative_abundance/plot_relative_abundance.py b/src/biofit/preprocessing/scaling/relative_abundance/plot_relative_abundance.py
new file mode 100644
index 0000000..a09dbc1
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/relative_abundance/plot_relative_abundance.py
@@ -0,0 +1,274 @@
+from dataclasses import dataclass, field
+from typing import List, Optional, Type
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biocore.utils.py_util import is_bioset
+from biofit.utils.types import Unset
+
+from ..plot_scaling import ScalerPlotter, ScalerPlotterConfig
+
+
+@dataclass
+class RelativeAbundancePlotterConfig(ScalerPlotterConfig):
+ processor_name: str = field(default="relative_abundance", init=False, repr=False)
+ _compare: bool = field(default=True, init=False, repr=False)
+    _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ None,
+ None,
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("METADATA_FEATURE_TYPES"),
+ get_feature("METADATA_FEATURE_TYPES"),
+ None,
+ None,
+ ],
+ init=False,
+ repr=False,
+ )
+
+ input_columns1: Optional[str] = None
+ input_columns2: Optional[str] = None
+ label_name1: Optional[str] = None
+ label_name2: Optional[str] = None
+ ylab: Optional[str] = None
+ xlab: Optional[str] = None
+ ylim: Optional[list] = None
+ before_title: str = "Before Relative Abundance"
+ after_title: str = "After Relative Abundance"
+ legend_position: str = "top"
+ add_box: bool = True
+ horizontal_plot: bool = False
+ order: bool = False
+ col_set: str = "Set1"
+ cols: Optional[str] = None
+ log_num: Optional[str] = None
+ show_outliers: bool = True
+
+
+@dataclass
+class RelativeAbundancePlotterConfigForMetagenomics(RelativeAbundancePlotterConfig):
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ (get_feature("Abundance"), get_feature("ReadCount")),
+ (get_feature("Abundance"), get_feature("ReadCount")),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="metagenomics", init=False, repr=False)
+ log_num: Optional[str] = "log2_1p"
+
+
+@dataclass
+class RelativeAbundancePlotterConfigForOTU(RelativeAbundancePlotterConfig):
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("Abundance"),
+ get_feature("Abundance"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="otu", init=False, repr=False)
+ ylab: Optional[str] = "OTU Abundance"
+ log_num: Optional[str] = "log10_1p"
+
+
+@dataclass
+class RelativeAbundancePlotterConfigForASV(RelativeAbundancePlotterConfig):
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("Abundance"),
+ get_feature("Abundance"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="asv", init=False, repr=False)
+
+
+@dataclass
+class RelativeAbundancePlotterConfigForSNP(RelativeAbundancePlotterConfig):
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("GenomicVariant"),
+ get_feature("GenomicVariant"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("GenomicVariant"),
+ get_feature("GenomicVariant"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="snp", init=False, repr=False)
+
+
+@dataclass
+class RelativeAbundancePlotterConfigForMaldi(RelativeAbundancePlotterConfig):
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("PeakIntensity"),
+ get_feature("PeakIntensity"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("PeakIntensity"),
+ get_feature("PeakIntensity"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="maldi", init=False, repr=False)
+ ylab: Optional[str] = "Peak Intensity"
+
+
+class RelativeAbundancePlotter(ScalerPlotter):
+ config_class = RelativeAbundancePlotterConfig
+ config: RelativeAbundancePlotterConfig
+
+ def __init__(
+ self,
+ input_columns1: SelectedColumnTypes = None,
+ input_columns2: SelectedColumnTypes = None,
+ label_name1: Optional[str] = None,
+ label_name2: Optional[str] = None,
+ ylab: Optional[str] = Unset("None"),
+ xlab: Optional[str] = Unset("None"),
+ ylim: Optional[list] = Unset("None"),
+ title: str = Unset('"Violin Plot"'),
+ legend_position: str = Unset('"top"'),
+ add_box: bool = Unset("True"),
+ horizontal_plot: bool = Unset("False"),
+ order: bool = Unset("False"),
+ col_set: str = Unset('"Set1"'),
+ cols: Optional[str] = Unset("None"),
+ log_num: Optional[int] = Unset("None"),
+ show_outliers: bool = Unset("True"),
+ config: RelativeAbundancePlotterConfig = None,
+ **kwargs,
+ ):
+ super().__init__(
+ config=config,
+ input_columns1=input_columns1,
+ input_columns2=input_columns2,
+ label_name1=label_name1,
+ label_name2=label_name2,
+ ylab=ylab,
+ xlab=xlab,
+ ylim=ylim,
+ title=title,
+ legend_position=legend_position,
+ add_box=add_box,
+ horizontal_plot=horizontal_plot,
+ order=order,
+ col_set=col_set,
+ cols=cols,
+ log_num=log_num,
+ show_outliers=show_outliers,
+ **kwargs,
+ )
+
+ def plot(
+ self,
+ x1,
+ x2=None,
+ y1=None,
+ y2=None,
+ input_columns1: SelectedColumnTypes = None,
+ input_columns2: SelectedColumnTypes = None,
+ label_name1: SelectedColumnTypes = None,
+ label_name2: SelectedColumnTypes = None,
+ ylab: Optional[str] = Unset("None"),
+ xlab: Optional[str] = Unset("None"),
+ ylim: Optional[list] = Unset("None"),
+ title: str = Unset('"Violin Plot"'),
+ legend_position: str = Unset('"top"'),
+ add_box: bool = Unset("True"),
+ horizontal_plot: bool = Unset("False"),
+ order: bool = Unset("False"),
+ col_set: str = Unset('"Set1"'),
+ cols: Optional[str] = Unset("None"),
+ log_num: Optional[int] = Unset("None"),
+ show_outliers: bool = Unset("True"),
+ path: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show: bool = True,
+ ):
+ if is_bioset(x1):
+ from biosets import get_target
+
+ if y1 is None:
+ y1 = get_target(x1)
+
+ if is_bioset(x2):
+ from biosets import get_target
+
+ if y2 is None:
+ y2 = get_target(x2)
+ if x1 is None or x2 is None or y1 is None or y2 is None:
+ raise ValueError("Must provide the before and after normalization.")
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns1, input_columns2, label_name1, label_name2
+ )
+ return self._plot(
+ x1,
+ x2,
+ y1,
+ y2,
+ input_columns1=input_columns1,
+ input_columns2=input_columns2,
+ label_name1=label_name1,
+ label_name2=label_name2,
+ ylab=ylab,
+ xlab=xlab,
+ ylim=ylim,
+ title=title,
+ legend_position=legend_position,
+ add_box=add_box,
+ horizontal_plot=horizontal_plot,
+ order=order,
+ col_set=col_set,
+ cols=cols,
+ log_num=log_num,
+ show_outliers=show_outliers,
+ path=path,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ show=show,
+ )
diff --git a/src/biofit/preprocessing/scaling/relative_abundance/relative_abundance.py b/src/biofit/preprocessing/scaling/relative_abundance/relative_abundance.py
new file mode 100644
index 0000000..1f76867
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/relative_abundance/relative_abundance.py
@@ -0,0 +1,188 @@
+from dataclasses import dataclass, field
+from typing import List, Type
+
+import numpy as np
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biofit.utils import logging
+
+from ..scaling import Scaler, ScalerConfig
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class RelativeAbundanceScalerConfig(ScalerConfig):
+ """
+ Configuration for RelativeAbundanceScaler.
+ """
+
+ processor_name: str = field(default="relative_abundance", init=False, repr=False)
+ _fit_process_desc: str = field(
+ default="Calculating relative abundance scaling", init=False, repr=False
+ )
+ _transform_process_desc: str = field(
+ default="Applying relative abundance scaling", init=False, repr=False
+ )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+
+ pseudocount: float = 1e-6
+ input_columns: List[str] = None
+
+
+class RelativeAbundanceScaler(Scaler):
+ """
+ RelativeAbundanceScaler applies relative abundance scaling to the input data.
+ """
+
+ output_dtype = "float64"
+ config_class = RelativeAbundanceScalerConfig
+ config: RelativeAbundanceScalerConfig
+
+ def __init__(
+ self,
+ pseudocount: float = 1e-6,
+ config: RelativeAbundanceScalerConfig = None,
+ **kwargs,
+ ):
+ super().__init__(config=config, pseudocount=pseudocount, **kwargs)
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "RelativeAbundanceScaler":
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns or self.config.input_columns
+ )
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(
+ input_columns or self.config.input_columns
+ )
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _transform_numpy(self, X: np.ndarray):
+ # Calculating the sum of each sample
+ row_sums = np.sum(X, axis=1)[:, np.newaxis]
+ if np.any(row_sums == 0):
+ row_sums += self.config.pseudocount
+ return X / row_sums
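
`_transform_numpy` above converts counts to per-sample proportions; the pseudocount is added to the row sums only when some sample is all zeros, to keep the division finite. An equivalent standalone version:

```python
# Standalone equivalent of the relative-abundance transform above.
import numpy as np

def relative_abundance(X: np.ndarray, pseudocount: float = 1e-6) -> np.ndarray:
    row_sums = X.sum(axis=1, keepdims=True).astype(float)
    if np.any(row_sums == 0):
        row_sums = row_sums + pseudocount  # keep all-zero samples finite
    return X / row_sums

counts = np.array([[2.0, 3.0, 5.0], [0.0, 0.0, 0.0]])
print(relative_abundance(counts))  # first row sums to 1.0
```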
diff --git a/src/biofit/preprocessing/scaling/scaling.py b/src/biofit/preprocessing/scaling/scaling.py
new file mode 100644
index 0000000..3dda512
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/scaling.py
@@ -0,0 +1,24 @@
+from dataclasses import dataclass, field
+from typing import List, Type
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import BaseProcessor, ProcessorConfig
+
+
+@dataclass
+class ScalerConfig(ProcessorConfig):
+ processor_type: str = field(default="scaling", init=False, repr=False)
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+
+
+class Scaler(BaseProcessor):
+ pass
diff --git a/src/biofit/preprocessing/scaling/tmm/__init__.py b/src/biofit/preprocessing/scaling/tmm/__init__.py
new file mode 100644
index 0000000..31f9c51
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/tmm/__init__.py
@@ -0,0 +1,7 @@
+# ruff: noqa
+from .tmm import (
+ TMMScaler,
+ TMMScalerConfig,
+ TMMScalerConfigForOTU,
+ TMMScalerConfigForMetagenomics,
+)
diff --git a/src/biofit/preprocessing/scaling/tmm/tmm.R b/src/biofit/preprocessing/scaling/tmm/tmm.R
new file mode 100644
index 0000000..41bedb6
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/tmm/tmm.R
@@ -0,0 +1,89 @@
+# adapted from ruppinlab/sklearn-extensions and edgeR::calcNormFactors source code
+
+source(file.path(R_SCRIPTS_PATH, "utils.R"))
+
+edger_tmm_ref_column <- function(counts, lib_size=colSums(counts), p=0.75) {
+ y <- t(t(counts) / lib_size)
+ f <- apply(y, 2, function(x) quantile(x, p=p))
+    which.min(abs(f - mean(f)))
+}
+
+edger_tmm_fit <- function(X) {
+ suppressPackageStartupMessages(require(edgeR))
+ suppressPackageStartupMessages(require(arrow))
+ X <- as.matrix(as.data.frame(X))
+ counts <- t(X)
+ ref_sample <- counts[, edger_tmm_ref_column(counts)]
+ return(ref_sample)
+}
+
+edger_tmm_cpm_transform <- function(X, ref_samples, log=TRUE, prior_count=2) {
+ suppressPackageStartupMessages(require(edgeR))
+ suppressPackageStartupMessages(require(arrow))
+ X <- as.matrix(convert_to_dataframe(X))
+ counts <- t(X)
+ ref_sample_mask <- apply(counts, 2, function(c) all(c == ref_samples))
+ if (any(ref_sample_mask)) {
+ dge <- edgeR::DGEList(counts=counts)
+ dge <- edgeR::calcNormFactors(
+ dge, method="TMM", refColumn=min(which(ref_sample_mask))
+ )
+ cpms <- edgeR::cpm(dge, log=log, prior.count=prior_count)
+ } else {
+ counts <- cbind(counts, ref_samples)
+ colnames(counts) <- NULL
+ dge <- edgeR::DGEList(counts=counts)
+ dge <- edgeR::calcNormFactors(dge, method="TMM", refColumn=ncol(dge))
+ cpms <- edgeR::cpm(dge, log=log, prior.count=prior_count)
+ cpms <- cpms[, -ncol(cpms)]
+ }
+ return(arrow::as_arrow_table(as.data.frame(t(cpms))))
+}
+
+edger_tmm_tpm_transform <- function(
+ X, feature_meta, ref_samples, log=TRUE, prior_count=2, meta_col="Length"
+) {
+ if (is.null(feature_meta)) stop("feature_meta cannot be NULL")
+ suppressPackageStartupMessages(require(edgeR))
+ suppressPackageStartupMessages(require(arrow))
+ X <- as.matrix(convert_to_dataframe(X))
+ counts <- t(X)
+ ref_sample_mask <- apply(counts, 2, function(c) all(c == ref_samples))
+ if (any(ref_sample_mask)) {
+ dge <- edgeR::DGEList(counts=counts, genes=feature_meta)
+ dge <- edgeR::calcNormFactors(
+ dge, method="TMM", refColumn=min(which(ref_sample_mask))
+ )
+ } else {
+ counts <- cbind(counts, ref_samples)
+ colnames(counts) <- NULL
+ dge <- edgeR::DGEList(counts=counts, genes=feature_meta)
+ dge <- edgeR::calcNormFactors(dge, method="TMM", refColumn=ncol(dge))
+ }
+ if (log) {
+        # XXX: edgeR doesn't have built-in support for log TPM w/ prior.count,
+        # so we replicate its internal logic manually
+        # TODO: use effectiveLibSizes() in newer edgeR versions
+        lib_size <- dge$samples$lib.size * dge$samples$norm.factors
+        scaled_prior_count <- prior_count * lib_size / mean(lib_size)
+ adj_lib_size <- lib_size + 2 * scaled_prior_count
+ fpkms <- t(
+ (t(dge$counts) + scaled_prior_count) / adj_lib_size
+ ) * 1e6 / dge$genes[[meta_col]] * 1e3
+ tpms <- log2(t(t(fpkms) / colSums(fpkms)) * 1e6)
+ } else {
+ fpkms <- edgeR::rpkm(
+ dge, gene.length=meta_col, log=log, prior.count=prior_count
+ )
+ tpms <- t(t(fpkms) / colSums(fpkms)) * 1e6
+ }
+ if (!any(ref_sample_mask)) tpms <- tpms[, -ncol(tpms)]
+ return(arrow::as_arrow_table(t(tpms)))
+}
+
+edger_cpm_transform <- function(X, log=TRUE, prior_count=2) {
+ suppressPackageStartupMessages(require(edgeR))
+ suppressPackageStartupMessages(require(arrow))
+ return(arrow::as_arrow_table(t(edgeR::cpm(t(X), log=log, prior.count=prior_count))))
+}
diff --git a/src/biofit/preprocessing/scaling/tmm/tmm.py b/src/biofit/preprocessing/scaling/tmm/tmm.py
new file mode 100644
index 0000000..c78511d
--- /dev/null
+++ b/src/biofit/preprocessing/scaling/tmm/tmm.py
@@ -0,0 +1,281 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import List, Optional, Type
+
+import pyarrow as pa
+
+from biofit.integration.biosets import get_feature
+from biofit.integration.R import RCaller
+from biofit.integration.R.r_caller import PackageNotInstalledError
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.utils import logging
+
+from ..scaling import Scaler, ScalerConfig
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class TMMScalerConfig(ScalerConfig):
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None],
+ init=False,
+ repr=False,
+ )
+ # process descriptions
+ _fit_process_desc: str = field(
+ default="Determining reference samples for TMM normalization",
+ init=False,
+ repr=False,
+ )
+ _transform_process_desc: str = field(
+ default="Applying TMM normalization", init=False, repr=False
+ )
+ processor_name: str = field(default="tmm", init=False, repr=False)
+
+ # attributes
+ r_source: str = field(
+ default=Path(__file__).with_suffix(".R").as_posix(), init=False, repr=False
+ )
+ fit_func_name: str = field(default="edger_tmm_fit", init=False, repr=False)
+ cpm_transform_func_name: str = field(
+ default="edger_tmm_cpm_transform", init=False, repr=False
+ )
+ tpm_transform_func_name: str = field(
+ default="edger_tmm_tpm_transform", init=False, repr=False
+ )
+
+ log: bool = True
+ prior_count: int = 2
+ meta_col: str = "Length"
+ gene_col: str = "genes"
+
+ # estimated attributes
+ ref_samples: Optional[pa.Table] = field(default=None, init=False, repr=False)
+
+
+class TMMScalerConfigForMetagenomics(TMMScalerConfig):
+ # dataset specific attributes
+ _input_feature_types: List[Type] = (
+ get_feature("Abundance"),
+ get_feature("ReadCount"),
+ )
+ dataset_name = "metagenomics"
+
+
+class TMMScalerConfigForOTU(TMMScalerConfig):
+ # dataset specific attributes
+ _input_feature_types: List[Type] = get_feature("Abundance")
+ dataset_name = "otu"
+
+
+class TMMScalerConfigForSNP(TMMScalerConfig):
+ # dataset specific attributes
+ _input_feature_types: List[Type] = get_feature("GenomicVariant")
+ dataset_name = "snp"
+
+
+class TMMScaler(Scaler):
+ output_dtype = "float64"
+
+ # config class
+ config_class = TMMScalerConfig
+ config: TMMScalerConfig
+
+ def __init__(
+ self,
+ log: bool = True,
+ prior_count: int = 2,
+ meta_col: str = "Length",
+ gene_col: str = "genes",
+ config: Optional[TMMScalerConfig] = None,
+ **kwargs,
+ ):
+ super().__init__(
+ config=config,
+ log=log,
+ prior_count=prior_count,
+ meta_col=meta_col,
+ gene_col=gene_col,
+ **kwargs,
+ )
+ r_caller = RCaller.from_script(self.config.r_source)
+ install_missing = kwargs.get("install_missing")
+ try:
+ r_caller.verify_r_dependencies(
+ cran_dependencies=["BiocManager"],
+ bioconductor_dependencies=["edgeR"],
+ install_missing=install_missing,
+ )
+ except PackageNotInstalledError:
+ raise PackageNotInstalledError(
+ "TMMScale requires the following R package: edgeR. To "
+ "install, initialize the TMMScaler with install_missing=True or run "
+ "R -e 'BiocManager::install(\"edgeR\")' in your terminal."
+ )
+ self.fit_func = r_caller.get_method(self.config.fit_func_name)
+ self.cpm_transform = r_caller.get_method(self.config.cpm_transform_func_name)
+ self.tpm_transform = r_caller.get_method(self.config.tpm_transform_func_name)
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ genes=None,
+ input_columns: SelectedColumnTypes = None,
+ gene_col: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns, gene_col)
+ return self._process_transform(
+ X,
+ genes,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ genes=None,
+ input_columns: SelectedColumnTypes = None,
+ gene_col: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ genes=genes,
+ gene_col=gene_col,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+
+        # reload the R methods from the configured script (self.config.r_source)
+ r_caller = RCaller.from_script(self.config.r_source)
+ self.fit_func = r_caller.get_method(self.config.fit_func_name)
+ self.cpm_transform = r_caller.get_method(self.config.cpm_transform_func_name)
+ self.tpm_transform = r_caller.get_method(self.config.tpm_transform_func_name)
+ return self
+
+ def _fit_arrow(self, X: pa.Table):
+ self.config.ref_samples = self.fit_func(X)
+ return self
+
+ def _transform_arrow(self, X: pa.Table, genes=None):
+ if genes:
+ return self.tpm_transform(
+ X=X,
+ feature_meta=genes,
+ ref_samples=self.config.ref_samples,
+ log=self.config.log,
+ prior_count=self.config.prior_count,
+ meta_col=self.config.meta_col,
+ )
+ else:
+ return self.cpm_transform(
+ X=X,
+ ref_samples=self.config.ref_samples,
+ log=self.config.log,
+ prior_count=self.config.prior_count,
+ )
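
The transform above defers the normalization itself to edgeR. Conceptually, CPM with TMM is counts over an effective library size (library size times the TMM factor), scaled to a million. A simplified numpy illustration with the edgeR-derived norm factors stubbed out (edgeR's log branch additionally scales the prior count by library size, as tmm.R shows):

```python
# Simplified illustration of the CPM arithmetic; TMM norm factors are stubbed,
# whereas biofit obtains them from edgeR through RCaller.
import numpy as np

counts = np.array([[100.0, 900.0], [50.0, 450.0]])  # samples x features
lib_size = counts.sum(axis=1)
norm_factors = np.array([1.0, 1.0])                 # stand-in for edgeR TMM factors

effective_lib = lib_size * norm_factors
cpm = counts / effective_lib[:, np.newaxis] * 1e6
log_cpm = np.log2(cpm + 2.0)                        # prior count of 2, simplified
print(log_cpm)
```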
diff --git a/src/biofit/preprocessing/transformation/__init__.py b/src/biofit/preprocessing/transformation/__init__.py
new file mode 100644
index 0000000..d5035f9
--- /dev/null
+++ b/src/biofit/preprocessing/transformation/__init__.py
@@ -0,0 +1,2 @@
+# ruff: noqa
+from .log import *
diff --git a/src/biofit/preprocessing/transformation/log/__init__.py b/src/biofit/preprocessing/transformation/log/__init__.py
new file mode 100644
index 0000000..d81b7a9
--- /dev/null
+++ b/src/biofit/preprocessing/transformation/log/__init__.py
@@ -0,0 +1,2 @@
+# ruff: noqa
+from .log import LogTransformer, LogTransformerConfig, LogTransformerConfigForOTU
diff --git a/src/biofit/preprocessing/transformation/log/log.py b/src/biofit/preprocessing/transformation/log/log.py
new file mode 100644
index 0000000..d07cdf9
--- /dev/null
+++ b/src/biofit/preprocessing/transformation/log/log.py
@@ -0,0 +1,201 @@
+from dataclasses import dataclass, field
+from typing import List, Type
+
+import numpy as np
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import (
+ SelectedColumnTypes,
+ SelectedFeatureTypes,
+ sync_backup_config,
+)
+from biofit.utils import logging
+
+from ..transformation import Transformer, TransformerConfig
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class LogTransformerConfig(TransformerConfig):
+ processor_name: str = field(default="log", init=False, repr=False)
+ _transform_process_desc: str = field(
+ default="Applying log transformation", init=False, repr=False
+ )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+
+ base = np.e
+ shift = 0
+ estimator = None
+
+
+@dataclass
+class LogTransformerConfigForOTU(LogTransformerConfig):
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ dataset_name: str = field(default="otu", init=False, repr=False)
+ base = 2
+ shift = 1
+
+
+class LogTransformer(Transformer):
+ output_dtype = "float64"
+ config_class = LogTransformerConfig
+ config: LogTransformerConfig
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ func = None
+        if self.config.base == np.e:
+            if self.config.shift == 0:
+                func = np.log
+            elif self.config.shift == 1:
+                func = np.log1p
+        elif self.config.shift == 0:
+            # fast paths only apply when no shift is requested; otherwise fall
+            # through to the generic change-of-base implementation below
+            if self.config.base == 10:
+                func = np.log10
+            elif self.config.base == 2:
+                func = np.log2
+
+ if func is None:
+
+ def func(x):
+ return np.log(x + self.config.shift) / np.log(self.config.base)
+
+ self.log_transform = func
+ return self
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "LogTransformer":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _transform_sklearn(self, X):
+ return self.log_transform(X)
diff --git a/src/biofit/preprocessing/transformation/transformation.py b/src/biofit/preprocessing/transformation/transformation.py
new file mode 100644
index 0000000..61f144b
--- /dev/null
+++ b/src/biofit/preprocessing/transformation/transformation.py
@@ -0,0 +1,17 @@
+from dataclasses import dataclass, field
+
+from biofit.processing import BaseProcessor, ProcessorConfig
+
+
+@dataclass
+class TransformerConfig(ProcessorConfig):
+ """Base class for Transformer processor configuration."""
+
+ processor_type: str = field(default="transformation", init=False, repr=False)
+
+
+class Transformer(BaseProcessor):
+ """Base class for Transformer processors."""
+
+ config_class = TransformerConfig
+ config: TransformerConfig
diff --git a/src/biofit/processing.py b/src/biofit/processing.py
new file mode 100644
index 0000000..7ade0eb
--- /dev/null
+++ b/src/biofit/processing.py
@@ -0,0 +1,3696 @@
+import copy
+import importlib
+import inspect
+import json
+import os
+import re
+import sys
+import tempfile
+import time
+from collections.abc import Callable
+from dataclasses import MISSING, dataclass, field, fields
+from functools import wraps
+from multiprocessing import Pool
+from pathlib import Path
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+import joblib
+import numpy as np
+import pandas as pd
+import pandas.api.types as pdt
+import pyarrow as pa
+from biocore import DataHandler, get_data_format
+from biocore.utils.import_util import (
+ is_polars_available,
+ is_ray_available,
+ is_biosets_available,
+ is_datasets_available,
+)
+from biocore.utils.inspect import get_kwargs, get_required_args
+from biocore.utils.naming import camelcase_to_snakecase
+from sklearn.utils.validation import (
+ NotFittedError,
+)
+
+import biofit.config
+from biofit.utils import (
+ Unset,
+ determine_upcast,
+ fingerprint_from_data,
+ fingerprint_from_kwargs,
+ generate_cache_dir,
+ get_cache_file_name,
+ init_arrow_buffer_and_writer,
+ logging,
+ move_temp_file,
+ update_fingerprint,
+ version,
+)
+from biofit.utils.file_utils import expand_path, is_remote_url
+from biofit.utils.fingerprint import Hasher, is_caching_enabled
+from biofit.utils.py_util import iflatmap_unordered
+from biocore.utils.py_util import (
+ is_bioset,
+ is_dataset,
+ is_dataset_dict,
+ is_iterable_dataset,
+)
+
+from biofit.utils.table_util import string_to_arrow
+
+if TYPE_CHECKING:
+ from datasets.features.features import Features
+
+logger = logging.get_logger(__name__)
+
+T_CONFIG = TypeVar("T_CONFIG", bound="BaseConfig")
+T_PROC = TypeVar("T_PROC", bound="BaseProcessor")
+
+SelectedFeatureTypes = Union[Type, Tuple[Type], List[Union[Type, Tuple[Type]]]]
+
+SelectedColumnTypes = Union[
+ str,
+ int,
+ List[str],
+ List[int],
+ List[Union[List[str], List[int]]],
+ SelectedFeatureTypes,
+]
+
+
+class NonExistentCacheError(Exception):
+ """Used when we expect the existence of a cache"""
+
+ pass
+
+
+# based on conversion speed, we will use the following order of preference
+_ORDERED_FORMATS = [
+ "arrow",
+ "pyarrow",
+ "torch",
+ "pt",
+ "tf",
+ "tensorflow",
+ "pandas",
+ "pd",
+ "numpy",
+ "np",
+ "dicts",
+ "dict",
+ "list",
+]
+
+_DATAFRAME_FORMATS = [
+ "pandas",
+ "pd",
+]
+
+_SKLEARN_FORMATS = [
+ "numpy",
+ "np",
+ "pandas",
+ "pd",
+ "series",
+]
+
+_SPECIAL_FORMATS = ["any", "all", "array", "dataframe", "sklearn"]
+
+_ARROW_WRITEABLE_FORMATS = ["arrow", "pyarrow"]
+# when https://github.com/huggingface/datasets/pull/6762 is merged, we can add
+# "polars" and "pl" to the list of writeable formats
+
+if is_polars_available():
+ pandas_pos = _ORDERED_FORMATS.index("pandas")
+ _ORDERED_FORMATS.insert(pandas_pos, "polars")
+ _ORDERED_FORMATS.insert(pandas_pos + 1, "pl")
+
+ _DATAFRAME_FORMATS.insert(0, "polars")
+ _DATAFRAME_FORMATS.insert(1, "pl")
+
+ _SKLEARN_FORMATS.insert(0, "polars")
+ _SKLEARN_FORMATS.insert(1, "pl")
+
+
+def sync_backup_config(func):
+ @wraps(func)
+ def wrapper(self, **kwargs):
+ if hasattr(self, "config_"):
+ self._fingerprint = None
+ self.config_.replace_defaults(**kwargs)
+ return func(self, **kwargs)
+
+ return wrapper
+
+
+def _generate_get_feature_names_out(estimator, n_features_out):
+ """
+ Modified _generate_get_feature_names_out from sklearn to convert estimator name to snakecase
+ """
+
+ def name_formatter(template, index):
+ # Check if there is any placeholder {i...} in the template
+ if re.search(r"\{i(\+([0-9]+))?\}", template) is None:
+ # No placeholders, use default formatting {name}{i}
+ template = f"{template}{index}"
+
+ # Regex to find {i} or {i+n} where n is an integer
+ pattern = r"\{i(\+([0-9]+))?\}"
+ matches = re.finditer(pattern, template)
+
+ for match in matches:
+ full_match = match.group(0)
+ increment = match.group(2)
+ if increment:
+ value = index + int(increment)
+ else:
+ value = index
+
+ template = template.replace(full_match, str(value))
+
+ return template
+
+ estimator_name = None
+ if hasattr(estimator, "config"):
+ if getattr(estimator.config, "_feature_names_out", None) is not None:
+ return np.asarray(estimator.config._feature_names_out)
+ elif getattr(estimator.config, "output_template_name", None) is not None:
+ estimator_name = estimator.config.output_template_name
+ elif getattr(estimator.config, "processor_name", None):
+ estimator_name = estimator.config.processor_name
+ if n_features_out > 1:
+ estimator_name += "_"
+ if estimator_name is None:
+ estimator_name = camelcase_to_snakecase(estimator.__class__.__name__)
+ estimator_name = estimator_name.split("_for_")[0]
+ estimator_name = "_".join(estimator_name.split("_")[:-1])
+ if n_features_out > 1:
+ estimator_name += "_"
+
+ if n_features_out > 1:
+ out = [name_formatter(estimator_name, i) for i in range(n_features_out)]
+ else:
+ out = [estimator_name]
+
+ return np.asarray(out, dtype=object)
+
+
+REPLAY_ACTIONS = [
+ (
+ "from_config_transform",
+ {"path": "path", "fingerprint": "fingerprint", "ext": ".json"},
+ ),
+]
+
+
+def post_recording(*args, **kwargs):
+ out, _, func_args, func_kwargs = args
+ if not hasattr(out, "_fingerprint") or not hasattr(out, "replays"):
+ return out
+
+ info = kwargs.get("info", None)
+
+ if len(func_args) > 0:
+ self: "BaseProcessor" = func_args[0]
+ func_args = func_args[1:]
+ else:
+ self = func_kwargs["self"]
+ del func_kwargs["self"]
+
+ if len(func_args) > 0:
+ ds = func_args[0]
+ func_args = func_args[1:]
+ else:
+ ds = func_kwargs["X"]
+ del func_kwargs["X"]
+
+ out_fingerprint = None
+ if info:
+ out_fingerprint = info.get("new_fingerprint", None)
+
+ if not out_fingerprint:
+ out_fingerprint = getattr(out, "_fingerprint", None)
+
+ fingerprint = self.fingerprint
+
+ cache_file_names = []
+ if info:
+ cache_file_name = info.get("cache_file_name", None)
+ if not cache_file_name:
+ cache_dir = info.get("cache_dir", None)
+ if cache_dir and os.path.exists(cache_dir) and fingerprint:
+ file_list = os.listdir(cache_dir)
+ cache_file_names = [
+ os.path.join(cache_dir, f)
+ for f in file_list
+ if fingerprint in f or (out_fingerprint and out_fingerprint in f)
+ ]
+ else:
+ if cache_file_name and not isinstance(cache_file_name, list):
+ cache_file_names = [cache_file_name]
+
+ out._fingerprint = out_fingerprint
+ replays = ds.replays or out.replays or []
+ out.replays = replays.copy()
+ entity_path = f"{self.__module__}.{self.__class__.__name__}"
+
+ if cache_file_names:
+ for func_name, replay_info in REPLAY_ACTIONS:
+ path_arg = replay_info["path"]
+ fingerprint_arg = replay_info["fingerprint"]
+ ext = replay_info["ext"]
+ paths = []
+ for file in cache_file_names:
+ if file.endswith(ext):
+ paths.append(file)
+
+ if len(paths) == 1:
+ new_kwargs = {path_arg: paths[0]}
+ if fingerprint_arg:
+ new_kwargs[fingerprint_arg] = fingerprint
+ out.replays.append(
+ (
+ out._fingerprint,
+ entity_path,
+ func_name,
+ (),
+ {**new_kwargs},
+ )
+ )
+ break
+
+ return out
+
+
+def keep_dataset_fingerprint(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if len(args) > 0:
+ self = args[0]
+ args = args[1:]
+ else:
+ self = kwargs["self"]
+ del kwargs["self"]
+
+ if len(args) > 0:
+ ds = args[0]
+ args = args[1:]
+ else:
+ ds = kwargs["X"]
+ del kwargs["X"]
+
+ old_fingerprint = getattr(ds, "_fingerprint", None)
+ out = func(self, ds, *args, **kwargs)
+ if old_fingerprint:
+ ds._fingerprint = old_fingerprint
+ return out
+
+ return wrapper
+
+
+@dataclass
+class BaseConfig:
+ @classmethod
+ def from_config_file(
+ cls, path: str, ignore_none=False, add_new_attr=True
+ ) -> T_CONFIG:
+ """
+ Load configuration from a JSON file.
+
+ Args:
+ path (str or os.PathLike): Path to the configuration file.
+ ignore_none (bool): If True, ignore None values during loading.
+ add_new_attr (bool):
+ If True, allow adding new attributes that are not explicitly defined.
+
+ Returns:
+ T_CONFIG:
+ An instance of the configuration class with attributes loaded from the
+ file.
+ """
+ if isinstance(path, os.PathLike):
+ path = str(path)
+
+ with open(path, "r") as f:
+ states = json.load(f)
+ return cls.from_dict(states, ignore_none=ignore_none, add_new_attr=add_new_attr)
+
+ @classmethod
+ def from_config(
+ cls,
+ config_or_path: Union[str, os.PathLike, dict, "BaseConfig"],
+ ignore_none=False,
+ add_new_attr=False,
+ ) -> T_CONFIG:
+ """
+ Load configuration from a file, dictionary, or another BaseConfig instance.
+
+ Args:
+ config_or_path (Union[str, os.PathLike, dict, BaseConfig]):
+ Configuration data or path.
+ ignore_none (bool): If True, ignore None values during loading.
+ add_new_attr (bool):
+ If True, allow adding new attributes that are not explicitly defined.
+
+ Returns:
+ T_CONFIG: An instance of the configuration class.
+ """
+ if isinstance(config_or_path, (str, os.PathLike)):
+ return cls.from_config_file(
+ config_or_path, ignore_none=ignore_none, add_new_attr=add_new_attr
+ )
+ elif isinstance(config_or_path, dict):
+ return cls.from_dict(
+ config_or_path, ignore_none=ignore_none, add_new_attr=add_new_attr
+ )
+ elif isinstance(config_or_path, BaseConfig):
+ return cls.from_dict(
+ config_or_path.to_dict(deep=False),
+ ignore_none=ignore_none,
+ add_new_attr=add_new_attr,
+ )
+ else:
+ raise ValueError(f"Unsupported config type {type(config_or_path)}")
+
+ @classmethod
+ def from_dict(cls, states, ignore_none=False, add_new_attr=False) -> T_CONFIG:
+ """
+ Load configuration from a dictionary.
+
+ Args:
+ states (dict): Dictionary containing configuration states.
+ ignore_none (bool): If True, ignore None values during the assignment.
+ add_new_attr (bool):
+ If True, allow adding new attributes that are not in the init.
+
+ Returns:
+ T_CONFIG:
+ An instance of the configuration class with attributes set according to
+ `states`.
+ """
+
+ def _from_dict(obj):
+ if isinstance(obj, dict):
+ if len(obj) == 2 and "path" in obj and "format" in obj:
+ if obj["format"] == "joblib":
+ with open(obj["path"], "rb") as f:
+ return joblib.load(f)
+
+ return DataHandler.to_format(
+ obj["path"], target_format=obj["format"]
+ )
+ elif "__module__" in obj and "__class__" in obj and "__dict__" in obj:
+ module = importlib.import_module(obj.get("__module__"))
+ _cls = getattr(module, obj["__class__"])
+ if not hasattr(_cls, "from_dict"):
+ raise ValueError(
+ f"from_dict method not found for class {_cls.__name__}"
+ )
+ return _cls.from_dict(obj["__dict__"])
+ else:
+ return {k: _from_dict(v) for k, v in obj.items()}
+ else:
+ return obj
+
+ _states = copy.deepcopy(states)
+ cls_kwargs = get_kwargs(_states, cls.__init__)
+ self = cls(
+ **{
+ k: _from_dict(v)
+ for k, v in cls_kwargs.items()
+ if not isinstance(v, Unset)
+ }
+ )
+
+ for k in cls_kwargs:
+ _states.pop(k, None)
+
+ attributes = {k: _from_dict(v) for k, v in _states.items()}
+
+ self = self.replace_defaults(
+ ignore_none=ignore_none, add_new_attr=add_new_attr, **attributes
+ )
+ return self
+
+ def to_dict(
+ self,
+ deep=True,
+ save_nested_complex_obj=False,
+ path=None,
+ fingerprint=None,
+ ):
+ """
+ Convert the configuration to a dictionary.
+
+ Args:
+ deep (bool): If True, recursively convert all attributes to dictionaries.
+ save_nested_complex_obj (bool):
+ If True, save complex nested objects to disk.
+ path (str): Base path for saving complex nested objects.
+ fingerprint (str):
+ Optional fingerprint string to distinguish file paths.
+
+ Returns:
+ dict: A dictionary representation of the configuration.
+ """
+ base_dir = None
+ if save_nested_complex_obj:
+ if path:
+ base_dir = os.path.dirname(path)
+ else:
+ base_dir = tempfile.mkdtemp()
+ path = tempfile.NamedTemporaryFile("w", dir=base_dir, delete=False).name
+
+ def convert_to_dict(obj, name=None):
+ """
+ Convert an object's attributes to a dictionary, recursively processing
+ nested objects. Handles complex objects like arrays, datasets, and models
+ by potentially saving them to disk and replaces them with a reference in
+ the dictionary.
+
+ Args:
+ obj (Any): The object to be converted into a dictionary format.
+ name (str, optional):
+ Optional name to be used for generating file paths when saving
+ complex objects. Helps in creating more readable and traceable file
+ names.
+
+ Returns:
+ dict:
+ A dictionary representing the object. For simple types, returns the
+ object itself. For complex objects, returns a dictionary with
+ metadata such as path and format or fully serialized object states.
+
+ Raises:
+ Exception:
+ If there is a problem with writing the data to disk or any other
+ operational issue during the conversion process, it raises an
+ Exception to indicate failure.
+
+ Notes:
+ This method deals with different types of objects including simple data
+ types, complex machine learning models, and data structures. Depending
+ on the 'save_nested_complex_obj' flag in the `to_dict` method, complex
+ objects like machine learning models may either be saved to a file and
+ represented by a path, or fully serialized into the dictionary.
+
+ For saving data structures like arrays or datasets, it may use a
+ temporary directory or a specified path to store intermediate files.
+ This method handles the creation and cleanup of these files as needed.
+ """
+ if DataHandler.is_array_like(obj):
+ fp = None
+ _format = get_data_format(obj)
+ arrow_writer_kwargs = {}
+ if base_dir and save_nested_complex_obj:
+ if is_datasets_available():
+ from datasets import Features
+
+ if is_bioset(obj) or is_dataset(obj):
+ arrow_writer_kwargs["features"] = obj._info.features
+ obj = obj.data
+ else:
+ obj = DataHandler.to_format(obj, target_format="arrow")
+ arrow_writer_kwargs["features"] = (
+ Features.from_arrow_schema(obj.schema)
+ )
+ else:
+ obj = DataHandler.to_format(obj, target_format="arrow")
+ arrow_writer_kwargs["features"] = obj.schema
+ file_suffix = ""
+ if name:
+ file_suffix = f"-{name}"
+ new_fingerprint = update_fingerprint(
+ fingerprint or "", file_suffix + obj.__class__.__name__
+ )
+ fp = os.path.join(
+ base_dir, f"cache-{new_fingerprint}{file_suffix}.arrow"
+ )
+ arrow_writer_kwargs["fingerprint"] = new_fingerprint
+
+ out = {
+ "path": fp,
+ "format": _format,
+ }
+ if base_dir and save_nested_complex_obj:
+ _, writer, tmp_file = init_arrow_buffer_and_writer(
+ cache_file_name=fp, **arrow_writer_kwargs
+ )
+ try:
+ if writer is not None:
+ writer.write_table(obj)
+ writer.finalize()
+ if tmp_file is not None:
+ move_temp_file(tmp_file, fp)
+
+ except (Exception, KeyboardInterrupt):
+ if writer is not None:
+ writer.finalize()
+ if tmp_file is not None:
+ tmp_file.close()
+ if os.path.exists(tmp_file.name):
+ os.remove(tmp_file.name)
+ raise
+ return out
+ if hasattr(obj, "to_dict") and callable(obj.to_dict):
+ states = (
+ obj.to_dict(
+ save_nested_complex_obj=save_nested_complex_obj,
+ path=path,
+ fingerprint=fingerprint,
+ )
+ if isinstance(obj, BaseConfig)
+ else obj.to_dict()
+ )
+ return {
+ "__module__": obj.__class__.__module__,
+ "__class__": obj.__class__.__name__,
+ "__dict__": states,
+ }
+ if (
+ "sklearn" in obj.__class__.__module__
+ or "lightgbm" in obj.__class__.__module__
+ or "xgboost" in obj.__class__.__module__
+ or "catboost" in obj.__class__.__module__
+ ):
+ fp = None
+ _format = "joblib"
+ if base_dir and save_nested_complex_obj:
+ file_suffix = ""
+ if name:
+ file_suffix = f"-{name}"
+ new_fingerprint = update_fingerprint(
+ fingerprint or "", file_suffix + obj.__class__.__name__
+ )
+ fp = os.path.join(
+ base_dir, f"cache-{new_fingerprint}{file_suffix}.joblib"
+ )
+
+ with open(fp, "wb") as f:
+ joblib.dump(obj, f)
+
+ return {
+ "path": fp,
+ "format": _format,
+ }
+
+ if pdt.is_dict_like(obj):
+ return {k: convert_to_dict(v, k) for k, v in obj.items()}
+ if pdt.is_list_like(obj):
+ return [convert_to_dict(v) for v in obj]
+ if isinstance(obj, (str, bool, type(None))):
+ return obj
+ if pdt.is_integer(obj):
+ return int(obj)
+ if pdt.is_number(obj):
+ return float(obj)
+
+ if deep:
+ if hasattr(self, "__getstate__"): # python>=3.11
+ states = copy.deepcopy(
+ convert_to_dict(
+ {
+ k: v
+ for k, v in self.__getstate__().items()
+ if not isinstance(v, type)
+ }
+ )
+ )
+ else:
+ states = copy.deepcopy(
+ convert_to_dict(
+ {
+ k: v
+ for k, v in self.__dict__.items()
+ if not isinstance(v, type)
+ }
+ )
+ )
+ else:
+ states = copy.deepcopy(self.__dict__)
+
+ states["config_name"] = self.__class__.__name__
+ return states
+
+ def save_to_cache(self, path, fingerprint=None):
+ """
+ Save the configuration to a cache file in JSON format.
+
+ Args:
+ path (str): Path where the configuration will be saved.
+ fingerprint (str): Optional fingerprint string to distinguish file paths.
+ """
+ base_dir = os.path.dirname(path)
+ os.makedirs(base_dir, exist_ok=True)
+ files = os.listdir(base_dir)
+ try:
+ states = self.to_dict(
+ path=path, save_nested_complex_obj=True, fingerprint=fingerprint
+ )
+ except (Exception, KeyboardInterrupt):
+ new_files = set(os.listdir(base_dir)) - set(files)
+ for f in new_files:
+ fp = os.path.join(base_dir, f)
+ if os.path.exists(fp):
+ os.remove(fp)
+ raise
+
+ with open(path, "w") as f:
+ json.dump(states, f, check_circular=False)
+
+ def replace_defaults(
+ self,
+ ignore_none=False,
+ add_new_attr=False,
+ return_unused_kwargs=False,
+ **states,
+ ):
+ """
+ Replace default values of the config instance with provided values.
+
+ Args:
+ ignore_none (bool): If True, ignore None values during the replacement.
+ add_new_attr (bool):
+ If True, allow adding new attributes that are not explicitly defined.
+ return_unused_kwargs (bool):
+ If True, return unused keyword arguments.
+
+ Returns:
+ self or (self, dict):
+ The configuration instance or a tuple of the instance and unused
+ kwargs.
+ """
+ unused_keys = []
+ for k, v in states.items():
+ if isinstance(v, Unset):
+ continue
+ if add_new_attr or hasattr(self, k):
+ if not ignore_none or v is not None:
+ setattr(self, k, v)
+ else:
+ unused_keys.append(k)
+ if return_unused_kwargs:
+ return self, {k: states[k] for k in unused_keys}
+ return self
+
+ def _repr_mimebundle_(self, **kwargs):
+ """
+ Return the MIME bundle for the representation of the estimator.
+
+ This function is utilized by Jupyter environments to display the estimator.
+
+ Args:
+ **kwargs: Arbitrary keyword arguments.
+
+ Returns:
+ dict: A MIME type bundle representing the estimator.
+ """
+ from sklearn._config import get_config
+ from sklearn.utils._estimator_html_repr import estimator_html_repr
+
+ output = {"text/plain": repr(self)}
+ if get_config()["display"] == "diagram":
+ output["text/html"] = estimator_html_repr(self)
+ return output
+
+ def get_params(self, deep=True, show_init_only=True, show_repr_only=True):
+ """
+ Get parameters for the estimator.
+
+ Args:
+ deep (bool): If True, return parameters of nested objects.
+
+ Returns:
+ dict: Dictionary of parameters.
+ """
+ params = {}
+ args = [
+ f.name
+ for f in fields(self)
+ if (not show_init_only or f.init) and (not show_repr_only or f.repr)
+ ]
+ for param in args:
+ obj = getattr(self, param, "not_found")
+ if isinstance(obj, str) and obj == "not_found":
+ # check in dataclass fields for default values
+ for f in fields(self):
+ if f.name == param:
+ if f.default_factory is not MISSING:
+ obj = f.default_factory()
+ else:
+ obj = f.default
+ break
+ if hasattr(obj, "get_params") and deep:
+ params[param] = obj.get_params(deep=deep)
+ elif not pdt.is_complex(obj) or deep:
+ params[param] = obj
+ return params
+
+
+def get_processor_from_config_name(
+ config_name, processor_type=None, processor_name=None
+):
+ try:
+ package = "biofit"
+ module_name = "models"
+ if processor_type:
+ package = f"{package}.{module_name}"
+ module_name = processor_type
+ if processor_name:
+ package = f"{package}.{module_name}"
+ module_name = processor_name
+ package = package.replace("-", "_")
+ module_name = module_name.replace("-", "_")
+ module = importlib.import_module(f".{module_name}", package=package)
+ config_cls = getattr(module, config_name)
+ except (ModuleNotFoundError, AttributeError):
+ return None
+
+ return config_cls
+
+
+@dataclass
+class FitTransformConfig(BaseConfig):
+ # common attributes
+
+ map_kwargs: dict = field(default_factory=lambda: {"fn_kwargs": {}})
+ version: str = "0.0.0"
+
+ # only here to transmit the values to the next config
+ load_from_cache_file = is_caching_enabled()
+
+ def populate_map_kwargs(self):
+ if "keep_in_memory" not in self.map_kwargs:
+ self.map_kwargs["keep_in_memory"] = not self.cache_output
+
+ if "cache_file_name" not in self.map_kwargs:
+ self.map_kwargs["cache_file_name"] = self.cache_file_name
+
+ if "num_proc" not in self.map_kwargs:
+ self.map_kwargs["num_proc"] = self.num_proc
+
+ return self
+
+ @classmethod
+ def prepare_config(
+ cls,
+ **kwargs,
+ ):
+ self = cls()
+ self = self.replace_defaults(**kwargs)
+ self = self.populate_map_kwargs()
+ return self
+
+
+@dataclass
+class FitConfig(FitTransformConfig):
+ # to be used for ray data only
+ concurrency: Union[Tuple[int, int], int] = None
+
+
+@dataclass
+class TransformConfig(FitConfig):
+ @classmethod
+ def from_config(cls, config: FitTransformConfig):
+ if isinstance(config, FitTransformConfig):
+ states = copy.deepcopy(config.__dict__)
+ elif isinstance(config, dict):
+ states = copy.deepcopy(config)
+ else:
+ return super().from_config(copy.deepcopy(config))
+ return cls.prepare_config(**states)
+
+
+@dataclass
+class ProcessorConfig(BaseConfig):
+ f"""
+ Stores the parameters for all processors.
+
+ The config should contain all the parameters for transforming the data without the
+ need to repeat fitting.
+
+ Args:
+ output_format (str, *optional*):
+ The output format of the transformed data. The format will be the same as
+ the input data if not provided. Possible values are {_ORDERED_FORMATS}.
+ input_columns (_SelectedColumnTypes, *optional*):
+ The input columns to be used for fitting the processor and transforming the
+ data. If more than one table or array is provided, a list of lists will
+ correspond to each input argument. A single list will be applied to only
+ the first input argument. Only `datasets.Bioset`,
+ `datasets.IterableDataset`, or `biofit.Bioset` input data support column
+ name selection via `datasets.Features` object.
+ unused_columns (_SelectedColumnTypes, *optional*):
+ The columns that are not used for fitting the processor. This is ignored if
+ `input_columns` is provided. A single list or item will be applied to all
+ input arguments, while a list of lists will correspond to each input
+ argument.
+ keep_unused_columns (bool):
+ Whether to keep the unused columns in the final output. Default is `True`.
+ raise_if_missing (bool):
+ Whether to raise an error if the input columns provided are missing during
+ fitting. Default is `True`. If `False`, the processor will use the columns
+ that are present in the input data and ignore the missing columns. Use this
+ when the pipeline contains feature selection before the processor.
+ enable_caching (bool, *optional*):
+ Whether to disable or enable writing and reading from cache. Default is
+ `True`.
+ cache_dir (str, *optional*):
+ The directory to store the cached data. Default is `None`.
+ version (str):
+ The version of the processor. Default is `"0.0.0"`. This is used for
+ caching purposes. Set this to a new version when the processor is updated
+ and the parameters are the same as a previous version.
+
+ Attributes:
+ _fit_process_desc (str):
+ The description next to the progress bar during fitting.
+ _transform_process_desc (str):
+ The description next to the progress bar during transformation.
+ _input_feature_types (_SelectedFeatureTypes, *optional*):
+ The input feature types that will be applied by the processor. Only used
+ when input has a `features` attribute containing `datasets.Features`, such
+ as a `datasets.Bioset`, `datasets.IterableDataset`, or `biofit.Bioset`.
+ When `input_columns` is provided, this attribute is ignored.
+ _unused_feature_types (_SelectedFeatureTypes, *optional*):
+ The feature types that are not used for fitting the processor. This is
+ ignored if `input_columns` is provided.
+ features_out_suffix (str, *optional*):
+ The suffix to be added to the output feature names.
+ features_out_prefix (str): The prefix to be added to the output feature names.
+ processor_type (str, *optional*):
+ The type of the processor (e.g. feature_selection, scaling, imputation,
+ etc.). Used for auto class instantiation. Must be the same as the name of
+ the parent module where the processor is defined. A `None` value implies
+ that the processor has no parent module.
+ processor_name (str):
+ The name of the processor (e.g. select_k_best, min_max_scaler,
+ simple_imputer, etc.). Used for auto class instantiation. Must be the same
+ as the module name where the processor is defined.
+ dataset_name (str, *optional*):
+ The name of the dataset the processor is applied to. Used for auto class
+ instantiation based on the type of the dataset. A `None` value implies that
+ the processor is not dataset-specific.
+ output_template_name (str, *optional*):
+ The name of the output template. This is used to generate the output
+ feature names.
+ is_fitted (bool): Whether the processor is fitted to the input data.
+ _batch_method_prefix (str):
+ The prefix to be added to the batch method name. This is used to call the
+ batch method during fitting.
+ _input_columns (List[_SelectedColumnTypes], *optional*):
+ The parsed input columns, with number of lists equaling to the number of
+ input arguments. This is generated from
+ `TransformationMixin._set_input_columns_and_arity`.
+ n_features_in_ (int): The number of input features.
+ _n_features_out (int, *optional*):
+ The number of output features. Anything other than `None` implies that the
+ processor is not one-to-one. Set this during fit or before transformation
+ only if the transformation results in new features *and* the number of
+ output features is not equal to the number of input features. For example,
+ feature extraction, feature generation, etc. Preprocessing steps like
+ feature selection does not result in *new* features and should be kept as
+ `None`.
+ _features_out (Features, *optional*):
+ The `datasets.Features` object for the output features. This is inferred
+ before transformation. Used for caching the output data as an arrow ipc
+ stream.
+ feature_idx_in_ (List[int], *optional*):
+ The column indices of the input that were used for fitting. This is
+ automatically set.
+ feature_names_in_ (List[str], *optional*):
+ The column names of the input features. This is automatically set during
+ fitting. Only set if the input data during fit supports column names.
+ Transformations will use this attribute to select the input columns, if
+ supported. Otherwise, `feature_idx_in_` will be used.
+ target_idx_in_ (List[int], *optional*):
+ The column indices of the target that were used for fitting. This is
+ automatically set.
+ target_name_in_ (List[str], *optional*):
+ The column names of the target features. This is automatically set during
+ fitting. Only set if the target data during fit supports column names.
+ Transformations will use this attribute to select the target columns, if
+ supported. Otherwise, `target_idx_in_` will be used.
+ one_to_one_features (bool):
+ Whether the transformation results in a one-to-one mapping of input to
+ output features. This will be `True` if `_n_features_out` is `None` and
+ `False` otherwise.
+ _returns_tuple (bool):
+ Whether the transform method returns a tuple of data. See
+ `MinPrevalenceRowSampleFilter` for an example.
+ _data_fingerprint (str):
+ The fingerprint of the input data. This is used to recognize the input data
+ during transformation. If the input is the same as the data used for
+ fitting, information from the fit process is reused to process the input
+ data for transformation (e.g. selecting the same columns, etc.).
+ """
+
+ output_format: str = field(default=None, kw_only=True, init=True, repr=False)
+ input_columns: SelectedColumnTypes = field(
+ default=None, kw_only=True, init=True, repr=False
+ )
+ unused_columns: SelectedColumnTypes = field(
+ default=None, kw_only=True, init=True, repr=False
+ )
+ keep_unused_columns: bool = field(default=True, kw_only=True, init=True, repr=False)
+ raise_if_missing: bool = field(default=True, kw_only=True, init=True, repr=False)
+ enable_caching: bool = field(default=True, kw_only=True, init=True, repr=False)
+ cache_output: bool = field(default=True, kw_only=True, init=True, repr=False)
+ load_from_cache_file: bool = field(
+ default=True, kw_only=True, init=True, repr=False
+ )
+ cache_dir: str = field(default=None, kw_only=True, init=True, repr=False)
+ version: str = field(
+ default=version.__version__, kw_only=True, init=True, repr=True
+ )
+
+ _fit_process_desc: str = field(
+ default="Fitting the processor to the input data", init=False, repr=False
+ )
+ _transform_process_desc: str = field(
+ default="Transforming the input data", init=False, repr=False
+ )
+
+ _fit_input_feature_types: SelectedFeatureTypes = field(
+ default=None, init=False, repr=False
+ )
+ _transform_input_feature_types: SelectedFeatureTypes = field(
+ default=None, init=False, repr=False
+ )
+ _fit_unused_feature_types: SelectedFeatureTypes = field(
+ default=None, init=False, repr=False
+ )
+ _transform_unused_feature_types: SelectedFeatureTypes = field(
+ default=None, init=False, repr=False
+ )
+
+ _input_columns: SelectedColumnTypes = field(
+ default=None, kw_only=True, init=False, repr=False
+ )
+ features_out_suffix: str = field(default=None, init=False, repr=False)
+ features_out_prefix: str = field(default=None, init=False, repr=False)
+ processor_type: str = field(default="", init=False, repr=False)
+ processor_name: str = field(default="", init=False, repr=False)
+ dataset_name: str = field(default="", init=False, repr=False)
+
+ output_template_name: str = field(default=None, init=False, repr=False)
+
+ # automatically generated attributes
+ _batch_method_prefix: str = field(default="_partial", init=False, repr=False)
+ is_fitted: bool = field(default=False, init=False, repr=False)
+ n_features_in_: int = field(default=None, init=False, repr=False)
+ _n_features_out: int = field(default=None, init=False, repr=False)
+ _features_out: list = field(default=None, init=False, repr=False)
+ feature_idx_in_: List[int] = field(default=None, init=False, repr=False)
+ feature_names_in_: List[str] = field(default=None, init=False, repr=False)
+ extra_idx_in_: List[List[int]] = field(default=None, init=False, repr=False)
+ extra_names_in_: List[List[str]] = field(default=None, init=False, repr=False)
+ _feature_names_out: List[str] = field(default=None, init=False, repr=False)
+ _feature_idx_out: List[int] = field(default=None, init=False, repr=False)
+ _returns_tuple: bool = field(default=False, init=False, repr=False)
+ _data_fingerprint: str = field(default=None, init=False, repr=False)
+
+ @property
+ def one_to_one_features(self):
+ return self._n_features_out is None
+
+ def to_dict(self, *args, deep=True, **kwargs):
+ states = super().to_dict(*args, deep=deep, **kwargs)
+ states["_fit_process_desc"] = self._fit_process_desc
+ states["_transform_process_desc"] = self._transform_process_desc
+ if deep:
+
+ def set_nested_feature_type(ft_name):
+ ft_obj = getattr(self, ft_name)
+ if isinstance(ft_obj, type):
+ states[ft_name] = [ft_obj.__name__]
+ elif isinstance(ft_obj, tuple):
+ states[ft_name] = [tuple(ft.__name__ for ft in ft_obj)]
+ elif isinstance(ft_obj, list):
+ states[ft_name] = []
+ for ft in ft_obj:
+ if isinstance(ft, type):
+ states[ft_name].append(ft.__name__)
+ elif isinstance(ft, tuple):
+ states[ft_name].append(tuple(f.__name__ for f in ft))
+ elif ft is None:
+ states[ft_name].append(None)
+ elif ft_obj is None:
+ states[ft_name] = None
+
+ set_nested_feature_type("_fit_input_feature_types")
+ set_nested_feature_type("_transform_input_feature_types")
+ set_nested_feature_type("_fit_unused_feature_types")
+ set_nested_feature_type("_transform_unused_feature_types")
+
+ states["features_out_suffix"] = self.features_out_suffix
+ states["features_out_prefix"] = self.features_out_prefix
+ states["processor_name"] = self.processor_name
+ states["processor_type"] = self.processor_type
+ states["dataset_name"] = self.dataset_name
+ return states
+
+ @classmethod
+ def from_dict(
+ cls, states: dict, ignore_none: bool = False, add_new_attr: bool = False
+ ) -> T_CONFIG:
+ self = super().from_dict(
+ states, ignore_none=ignore_none, add_new_attr=add_new_attr
+ )
+
+ if is_datasets_available():
+ from datasets.features.features import _FEATURE_TYPES
+ else:
+ _FEATURE_TYPES = {}
+
+ def get_nested_feature_type(ft):
+ if isinstance(ft, str):
+ if not is_datasets_available():
+ raise ValueError(
+ "Trying to load cache using datasets.Feature without datasets "
+ "installed. Please install datasets to load the cache."
+ )
+ return _FEATURE_TYPES.get(ft)
+ elif isinstance(ft, (tuple, list)):
+ return tuple(get_nested_feature_type(f) for f in ft)
+ return ft
+
+ if self._fit_input_feature_types:
+ if not isinstance(self._fit_input_feature_types, list):
+ self._fit_input_feature_types = [self._fit_input_feature_types]
+ self._fit_input_feature_types = [
+ get_nested_feature_type(ft) for ft in self._fit_input_feature_types
+ ]
+ if self._fit_unused_feature_types:
+ if not isinstance(self._fit_unused_feature_types, list):
+ self._fit_unused_feature_types = [self._fit_unused_feature_types]
+ self._fit_unused_feature_types = [
+ get_nested_feature_type(ft) for ft in self._fit_unused_feature_types
+ ]
+ if self._transform_input_feature_types:
+ if not isinstance(self._transform_input_feature_types, list):
+ self._transform_input_feature_types = [
+ self._transform_input_feature_types
+ ]
+ self._transform_input_feature_types = [
+ get_nested_feature_type(ft)
+ for ft in self._transform_input_feature_types
+ ]
+ if self._transform_unused_feature_types:
+ if not isinstance(self._transform_unused_feature_types, list):
+ self._transform_unused_feature_types = [
+ self._transform_unused_feature_types
+ ]
+ self._transform_unused_feature_types = [
+ get_nested_feature_type(ft)
+ for ft in self._transform_unused_feature_types
+ ]
+ return self
+
+ def replace_defaults(self, ignore_none=False, add_new_attr=False, **states):
+ for k, v in states.items():
+ if isinstance(v, Unset):
+ continue
+ if (add_new_attr or hasattr(self, k)) and (
+ not ignore_none or v is not None
+ ):
+ setattr(self, k, v)
+ return self
+
+
+class TransformationMixin:
+ def _get_method(self, formats, func_type, prefix=None):
+ """
+ Retrieves processing methods based on the function type and target format.
+
+ Args:
+ formats (str or List[str]): The target format(s).
+ func_type (str): The type of processing method.
+ prefix (str, *optional*):
+ The prefix to be added to the method name (e.g "_partial" for batch
+ processing methods). Default is None.
+
+ Returns:
+ A list of processing methods based on the target format.
+ """
+ funcs = []
+
+ if isinstance(formats, str):
+ formats = [formats]
+ for format in formats + _SPECIAL_FORMATS:
+ if prefix:
+ func = getattr(self, f"{prefix}{func_type}_{format}", None)
+ if func is not None:
+ funcs.append(func)
+ else:
+ func = getattr(self, f"{func_type}_{format}", None)
+ if func is not None:
+ funcs.append(func)
+ return funcs
+
+ def _has_method(self, formats, func_type, prefix=None):
+ """
+ Checks if at least one method exists for the given
+ [prefix][func_type]_[format], or [func_type]_[format] if no prefix is given.
+
+ Args:
+ formats (str or List[str]): The target format(s).
+ func_type (str): The type of processing method.
+ prefix (str, *optional*):
+ The prefix to be added to the method name (e.g. "_partial" for batch
+ processing methods). Default is None.
+
+ Returns:
+ True if a method exists for any of the given formats, False
+ otherwise.
+ """
+
+ if isinstance(formats, str):
+ formats = [formats]
+ for format in formats + _SPECIAL_FORMATS:
+ if prefix:
+ func = getattr(self, f"{prefix}{func_type}_{format}", None)
+ if func is not None:
+ return True
+ else:
+ func = getattr(self, f"{func_type}_{format}", None)
+ if func is not None:
+ return True
+ return False
+
+ def _get_target_func(
+ self,
+ funcs,
+ source_format,
+ target_formats=None,
+ accepted_formats=_ORDERED_FORMATS,
+ ):
+ formats = []
+ new_funcs = []
+ for f in accepted_formats:
+ for fun in funcs:
+ if fun.__name__.endswith(f"_{f}"):
+ formats.append(f)
+ new_funcs.append(fun)
+
+ if formats:
+ # this class has a method for the target format
+ to_format = formats[0]
+ if source_format in formats:
+ to_format = source_format
+ if target_formats is not None:
+ if not isinstance(target_formats, list):
+ target_formats = [target_formats]
+ for target_format in target_formats:
+ if target_format in formats:
+ to_format = target_format
+ break
+ else:
+ logger.warning(
+ f"{self.__class__.__name__} does not support "
+ f"`{target_formats}`. Formatting input to `{to_format}` instead."
+ )
+
+ return new_funcs[formats.index(to_format)], to_format
+
+ any_funcs = [
+ f
+ for f in funcs
+ if f.__name__.endswith("_any") or f.__name__.endswith("_all")
+ ]
+ if any_funcs:
+ return any_funcs[0], source_format
+
+ tbl_funcs = [f for f in funcs if f.__name__.endswith("_array")]
+ if tbl_funcs:
+ target_formats = _ORDERED_FORMATS
+ funcs = tbl_funcs
+ else:
+ df_funcs = [f for f in funcs if f.__name__.endswith("_dataframe")]
+ if df_funcs:
+ target_formats = _DATAFRAME_FORMATS
+ funcs = df_funcs
+ else:
+ arr_funcs = [f for f in funcs if f.__name__.endswith("_sklearn")]
+ if arr_funcs:
+ target_formats = _SKLEARN_FORMATS
+ funcs = arr_funcs
+
+ if funcs:
+ func = funcs[0]
+ to_format = None
+ if target_formats is not None:
+ # we assume that the first function handles all formats within
+ # target_formats (e.g. sklearn functions accepting polars, numpy,
+ # and pandas)
+ if isinstance(target_formats, list):
+ # prioritize input format over output format if both are supported
+ if source_format in target_formats:
+ to_format = source_format
+ else:
+ to_format = target_formats[0]
+
+ return func, to_format
+
+ return None, target_formats
+
+ def _parse_column_selection(self, *args, from_config=False):
+ if from_config:
+ return self._parse_column_selection_from_config(*args)
+ return self._parse_column_selection_from_self(*args)
+
+ def _parse_column_selection_from_config(self, *args, from_config=False):
+ out = {}
+ if len(args):
+ if self.config.feature_names_in_:
+ out[args[0]] = self.config.feature_names_in_
+ else:
+ out[args[0]] = [f"{i}" for i in self.config.feature_idx_in_]
+ if len(args) > 1:
+ for i, arg in enumerate(args):
+ if (
+ self.config.extra_names_in_
+ and len(self.config.extra_names_in_) > i
+ and self.config.extra_names_in_[i]
+ ):
+ out[arg] = self.config.extra_names_in_[i]
+ elif (
+ self.config.extra_idx_in_
+ and len(self.config.extra_idx_in_) > i
+ and self.config.extra_idx_in_[i]
+ ):
+ out[arg] = [f"{j}" for j in self.extra_idx_in_[i]]
+ return out
+
+ def _parse_column_selection_from_self(self, *args, from_config=False):
+ out = {}
+ if len(args):
+ if self.feature_names_in_:
+ out[args[0]] = self.feature_names_in_
+ else:
+ out[args[0]] = [f"{i}" for i in self.feature_idx_in_]
+ if len(args) > 1:
+ for i, arg in enumerate(args):
+ if (
+ self.extra_names_in_
+ and len(self.extra_names_in_) > i
+ and self.extra_names_in_[i]
+ ):
+ out[arg] = self.extra_names_in_[i]
+ elif (
+ self.extra_idx_in_
+ and len(self.extra_idx_in_) > i
+ and self.extra_idx_in_[i]
+ ):
+ out[arg] = [f"{j}" for j in self.extra_idx_in_[i]]
+ return out
+
+ def _set_input_columns_and_arity(self, *args):
+ input_columns = None
+ if len(args) > 1:
+ input_columns = [None] * len(args)
+ for i, arg in enumerate(args):
+ if arg is not None:
+ if isinstance(arg, (str, int)):
+ input_columns[i] = [arg]
+ elif isinstance(arg, list):
+ input_columns[i] = arg
+ else:
+ input_columns = args[0] or None
+ if isinstance(input_columns, (str, int)):
+ input_columns = [input_columns]
+ input_columns = [input_columns]
+ return input_columns
+
+ def _reinsert_columns(
+ self, input, out, indices, unused_indices, one_to_one_features=False
+ ):
+ out_dims = DataHandler.get_shape(out)
+ x_dims = DataHandler.get_shape(input)
+
+ if unused_indices and x_dims[0] == out_dims[0]:
+ other_col_names = DataHandler.get_column_names(input, generate_cols=True)
+ other_col_names = [other_col_names[i] for i in unused_indices]
+ other_cols = DataHandler.select_columns(input, other_col_names)
+ other_dims = DataHandler.get_shape(other_cols)
+ if len(other_dims) == 1:
+ other_dims = (other_dims[0], 1)
+ other_cols = DataHandler.to_frame(other_cols, "__input__")
+ if len(out_dims) == 1:
+ out_dims = (out_dims[0], 1)
+ out = DataHandler.to_frame(out, "__output__")
+ if one_to_one_features:
+ other_inds = unused_indices
+ out_inds = indices
+ else:
+ other_inds = list(range(other_dims[1]))
+ out_inds = list(range(other_dims[1], other_dims[1] + out_dims[1]))
+
+ if other_dims[1] > out_dims[1]:
+ out = DataHandler.concat([other_cols, out], axis=1)
+ inds = list(np.argsort(other_inds + out_inds))
+ else:
+ out = DataHandler.concat([out, other_cols], axis=1)
+ inds = list(np.argsort(out_inds + other_inds))
+
+ out = DataHandler.select_columns(out, inds)
+ return out
+
+ def _make_columns_exclusive(self, columns):
+ new_set = columns.copy()
+ for i in reversed(range(0, len(columns) - 1)):
+ if columns[i] is not None and columns[i + 1] is not None:
+ new_set[i] = list(set(columns[i]) - set(columns[i + 1]))
+ return new_set
+
+ def _get_columns(
+ self,
+ X,
+ *args,
+ input_columns=None,
+ input_feature_types=None,
+ unused_columns=None,
+ unused_feature_types=None,
+ raise_if_missing=True,
+ ):
+ assert X is not None, "Input data is None"
+ first_arg_row_num = DataHandler.get_shape(X)[0]
+ assert first_arg_row_num > 0, "Input data has no rows"
+ assert input_columns is None or isinstance(input_columns, list), (
+ f"input_columns must be a list of column names or indices, "
+ f"but got {type(input_columns)}"
+ )
+
+ def get_columns(
+ X,
+ input_columns=None,
+ unused_columns=None,
+ generate_cols=False,
+ raise_if_missing=True,
+ ):
+ if X is None:
+ return None, None, None
+ col_names = DataHandler.get_column_names(X, generate_cols=True)
+ col_names_set = set(col_names)
+ if input_columns:
+ if isinstance(input_columns, tuple) or isinstance(input_columns, type):
+ feature_type = input_columns
+ if isinstance(feature_type, type):
+ feature_type = (feature_type,)
+
+ if is_datasets_available():
+ from datasets.features.features import _FEATURE_TYPES
+
+ if all(f in _FEATURE_TYPES.values() for f in feature_type):
+ try:
+ input_columns = (
+ DataHandler.get_column_names_by_feature_type(
+ X, feature_type=feature_type
+ )
+ )
+ except ValueError:
+ if generate_cols:
+ input_columns = DataHandler.get_column_names(
+ X, generate_cols=True
+ )
+ else:
+ input_columns = None
+
+ elif input_columns:
+ if isinstance(input_columns, (str, int)):
+ input_columns = [input_columns]
+ if not isinstance(input_columns[0], int):
+ missing_columns = set(input_columns) - col_names_set
+ if missing_columns and raise_if_missing:
+ raise ValueError(
+ f"Columns {missing_columns} not found in input dataset"
+ )
+ else:
+ input_columns = [
+ c for c in input_columns if c in col_names_set
+ ]
+ elif unused_columns:
+ if isinstance(unused_columns, (str, int)):
+ unused_columns = [unused_columns]
+ if isinstance(unused_columns, tuple) or isinstance(
+ unused_columns, type
+ ):
+ feature_type = unused_columns
+ if isinstance(feature_type, type):
+ feature_type = (feature_type,)
+ if is_datasets_available():
+ from datasets.features.features import _FEATURE_TYPES
+ else:
+ _FEATURE_TYPES = {}
+ if all(f in _FEATURE_TYPES.values() for f in feature_type):
+ try:
+ unused_columns = (
+ DataHandler.get_column_names_by_feature_type(
+ X, feature_type=feature_type
+ )
+ or []
+ )
+ except ValueError:
+ if generate_cols:
+ unused_columns = []
+ else:
+ unused_columns = None
+ if unused_columns is not None:
+ unused_columns = set(unused_columns)
+ input_columns = [c for c in col_names if c not in unused_columns]
+ elif generate_cols:
+ input_columns = DataHandler.get_column_names(X, generate_cols=True)
+
+ if input_columns:
+ if isinstance(input_columns[0], str):
+ input_indices = DataHandler.get_column_indices(X, input_columns)
+ else:
+ input_indices = input_columns
+ if DataHandler.supports_named_columns(get_data_format(X)):
+ cols = DataHandler.get_column_names(X, generate_cols=True)
+ input_columns = [cols[idx] for idx in input_indices]
+ else:
+ input_columns = None
+ unused_indices = list(
+ sorted(set(range(len(col_names))) - set(input_indices))
+ )
+ else:
+ return None, None, None
+
+ return input_columns, input_indices, unused_indices
+
+ arity = 1
+ if input_columns and isinstance(input_columns, list):
+ arity = len(input_columns)
+ else:
+ return None, None, None, None, None, None, None
+
+ def parse_inputs(i):
+ _input_columns = input_columns[i]
+
+ _input_feature_types = None
+ if input_feature_types is not None:
+ _input_feature_types = input_feature_types[i]
+
+ _unused_columns = None
+ if unused_columns is not None:
+ _unused_columns = unused_columns[i]
+
+ _unused_feature_types = None
+ if unused_feature_types is not None:
+ _unused_feature_types = unused_feature_types[i]
+
+ return (
+ _input_columns or None,
+ _input_feature_types or None,
+ _unused_columns or None,
+ _unused_feature_types or None,
+ )
+
+ _input_columns, _input_feature_types, _unused_columns, _unused_feature_types = (
+ parse_inputs(0)
+ )
+
+ feature_names_in, feature_idx_in, unused_idx_in = get_columns(
+ X,
+ input_columns=_input_columns or _input_feature_types,
+ unused_columns=_unused_columns or _unused_feature_types,
+ generate_cols=True,
+ raise_if_missing=raise_if_missing,
+ )
+ if _input_columns is not None and feature_idx_in is None:
+ raise ValueError(
+ f"Columns {_input_columns} not found in {DataHandler.get_column_names(X)}"
+ )
+ extra_names_in = None
+ extra_idx_in = None
+ unused_extra_idx_in = None
+ offsets = None
+ assert (
+ arity == len(args) + 1
+ ), f"Number of column sets ({arity}) must match the arity ({len(args) + 1})"
+ if arity > 1 or len(args):
+ extra_names_in = []
+ extra_idx_in = []
+ offsets = []
+ if len(args) and not all(arg is None for arg in args):
+ unused_extra_idx_in = []
+ x_dims = DataHandler.get_shape(X)
+ if len(x_dims) == 1:
+ offset = 1
+ else:
+ offset = x_dims[1]
+ for i, arg in enumerate(args):
+ (
+ _input_columns,
+ _input_feature_types,
+ _unused_columns,
+ _unused_feature_types,
+ ) = parse_inputs(i + 1)
+
+ if not is_bioset(X) and not is_dataset(X):
+ _unused_feature_types = None
+ _input_feature_types = None
+ if arg is None:
+ # look into the first input
+ _input_columns = _input_columns or _input_feature_types
+ _unused_columns = _unused_columns or _unused_feature_types
+ if _input_columns is None and _unused_columns is None:
+ _extra_names_in, _extra_idx_in, _unused_extra_idx_in = (
+ None,
+ None,
+ None,
+ )
+ else:
+ _extra_names_in, _extra_idx_in, _unused_extra_idx_in = (
+ get_columns(
+ X,
+ input_columns=_input_columns
+ or _input_feature_types,
+ unused_columns=_unused_columns
+ or _unused_feature_types,
+ generate_cols=True,
+ raise_if_missing=raise_if_missing,
+ )
+ )
+ offsets.append(0)
+ else:
+ arg_dim = DataHandler.get_shape(arg)
+ _extra_names_in, _extra_idx_in, _unused_extra_idx_in = (
+ get_columns(
+ arg,
+ input_columns=_input_columns or _input_feature_types,
+ unused_columns=_unused_columns or _unused_feature_types,
+ generate_cols=True,
+ raise_if_missing=raise_if_missing,
+ )
+ )
+
+ # only offset when the tables can be combined
+ if first_arg_row_num == arg_dim[0]:
+ offsets.append(offset)
+ if len(arg_dim) == 1:
+ offset += 1
+ else:
+ offset += arg_dim[1]
+ extra_names_in.append(_extra_names_in)
+ extra_idx_in.append(_extra_idx_in)
+ unused_extra_idx_in.append(_unused_extra_idx_in)
+
+ else:
+ for i in range(1, arity):
+ (
+ _input_columns,
+ _input_feature_types,
+ _unused_columns,
+ _unused_feature_types,
+ ) = parse_inputs(i)
+ _extra_names_in, _extra_idx_in, _unused_extra_idx_in = get_columns(
+ X,
+ input_columns=_input_columns or _input_feature_types,
+ unused_columns=_unused_columns or _unused_feature_types,
+ generate_cols=False,
+ raise_if_missing=raise_if_missing,
+ )
+ extra_names_in.append(_extra_names_in)
+ extra_idx_in.append(_extra_idx_in)
+ _extra_idx_in = set(_extra_idx_in or [])
+ unused_idx_in = [
+ idx for idx in unused_idx_in if idx not in _extra_idx_in
+ ]
+ offsets.append(0)
+ return (
+ feature_names_in,
+ feature_idx_in,
+ unused_idx_in,
+ extra_names_in,
+ extra_idx_in,
+ unused_extra_idx_in,
+ offsets,
+ )
+
+ def generate_fingerprint(self, fingerprint, config: BaseConfig):
+ hash = Hasher()
+ hash.update(fingerprint)
+ hash_str = f"{self.__module__}.{self.__class__.__name__}"
+ if hasattr(config, "version"):
+ hash_str += f"@{config.version}"
+ hash.update(hash_str)
+ fingerprint = hash.hexdigest()
+ fingerprint = fingerprint_from_kwargs(fingerprint, config.get_params())
+
+ return fingerprint
+
+
+class BaseProcessor(TransformationMixin):
+ """
+ Configures and manages data processing operations, supporting transformations and
+ handling of various configurations and states, primarily designed for batch
+ processing of data with options for multiprocessing and caching.
+
+ Attributes:
+ update_fingerprint (bool):
+ Flag to determine if the fingerprint should be updated after processing.
+ output_dtype (type): Data type for the output features.
+ config (ProcessorConfig): Configuration object specifying processor settings.
+ cache_files (list): List of cache file paths used during processing.
+
+ Raises:
+ ValueError: If an unsupported configuration type is provided.
+ """
+
+ # process attributes
+ output_dtype = None
+
+ # config attributes
+ config_class = ProcessorConfig
+ config: ProcessorConfig = None
+
+ # internal attributes for transformation
+ _feature_dependent = True
+ _input_columns = None
+ _method_prefix: str = "_transform"
+ _fingerprint = None
+ _is_multiprocessing = False
+ extra_names_in_ = []
+ _selected_indices = None
+ _unused_indices = None
+ _extra_indices = None
+ _unused_extra_indices = None
+
+ # automatically generated attributes
+ cache_files = None
+
+ def __init__(self, config: Optional[ProcessorConfig] = None, **kwargs):
+ add_new_attr = kwargs.pop("add_new_attr", False)
+ ignore_none = kwargs.pop("ignore_none", False)
+
+ if config is None:
+ if hasattr(self, "config_class"):
+ self.config = self.config_class.from_dict(
+ kwargs, ignore_none=ignore_none, add_new_attr=add_new_attr
+ )
+ elif isinstance(config, ProcessorConfig):
+ self.config = config
+ elif isinstance(config, dict):
+ self.config = self.config_class.from_dict(
+ config, ignore_none=ignore_none, add_new_attr=add_new_attr
+ )
+ else:
+ raise ValueError(f"Unsupported config type {type(config)}")
+ if config is None:
+ self = self.set_params(**kwargs)
+ if kwargs.get("_function", None):
+ self._function = kwargs["_function"]
+
+ @classmethod
+ def _from_config(cls, config: ProcessorConfig, **kwargs):
+ """Instantiates the processor from a configuration object."""
+ return cls(config=config, **kwargs)
+
+ def __call__(self, batch: Union[pd.DataFrame, Dict[str, np.ndarray]], **kwargs):
+ self = self._process_batches(batch, **kwargs)
+ return batch
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ """Sets the parameters of the processor"""
+ self.config = self.config.replace_defaults(**kwargs)
+ return self
+
+ @property
+ def is_fitted(self):
+ """bool: Whether the processor has been fitted."""
+ return self.config.is_fitted
+
+ @property
+ def has_fit(self):
+ """bool: Whether a fit function is found for the processor"""
+ return self._get_method(_ORDERED_FORMATS, func_type="_fit") or self._get_method(
+ _ORDERED_FORMATS, func_type="_fit", prefix=self.config._batch_method_prefix
+ )
+
+ @property
+ def fingerprint(self):
+ """str: The fingerprint of the processor."""
+ return self._parse_fingerprint(self._fingerprint)
+
+ def _parse_fingerprint(self, fingerprint):
+ """Parses the fingerprint and returns the base fingerprint and the processor suffix."""
+ base_fingerprint = fingerprint
+ processor_suffix = f"-{self.config.processor_name}"
+ if self.config.processor_type:
+ processor_suffix += f"-{self.config.processor_type}"
+ if self.config.dataset_name:
+ processor_suffix += f"-{self.config.dataset_name}"
+ return f"{base_fingerprint}{processor_suffix}"
+
+ def _reset(self, config: ProcessorConfig):
+ """Resets the processor to its initial state."""
+ # reinstantiate the processor
+ self._fingerprint = None
+ self.__init__(config=config)
+
+ @classmethod
+ def from_config(
+ cls, path_or_config: Union[str, os.PathLike, dict, "BaseConfig"], **kwargs
+ ) -> T_PROC:
+ """Instantiates the processor from a configuration file or object."""
+ config = cls.config_class.from_config(path_or_config, add_new_attr=True)
+ return cls(config=config, **kwargs)
+
+ @classmethod
+ def from_config_transform(cls, X, path: str, **kwargs):
+ """Transforms the input data using the processor configuration."""
+ self = cls.from_config(path, **kwargs)
+ self.config.is_fitted = True
+ return self.transform(X)
+
+ def populate_map_kwargs(self, map_kwargs, cache_output, cache_file_name, num_proc):
+ """
+ Populates the map_kwargs with default values if not provided.
+
+ Args:
+ map_kwargs (dict): The keyword arguments for the map function.
+ cache_output (bool): Whether to keep the processed data in memory.
+ cache_file_name (str): The name of the cache file.
+ num_proc (int): The number of processes to use for multiprocessing.
+
+ Returns:
+ dict: The updated map_kwargs.
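+
+ Example (illustrative, with `processor` an instance of a concrete subclass):
+ >>> kwargs = processor.populate_map_kwargs(
+ ... {}, cache_output=False, cache_file_name=None, num_proc=2
+ ... )
+ >>> kwargs["keep_in_memory"], kwargs["num_proc"]
+ (True, 2)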
+ """
+ if "keep_in_memory" not in map_kwargs:
+ map_kwargs["keep_in_memory"] = not cache_output
+
+ if "cache_file_name" not in map_kwargs:
+ map_kwargs["cache_file_name"] = cache_file_name
+
+ if "num_proc" not in map_kwargs:
+ map_kwargs["num_proc"] = num_proc
+
+ return map_kwargs
+
+ def _validate_fit_params(self, arity):
+ if self.config._input_columns is not None:
+ if self.config._fit_input_feature_types is not None and len(
+ self.config._input_columns
+ ) != len(self.config._fit_input_feature_types):
+ example_arg = ", ".join(["None" for _ in range(arity)])
+ assert False, (
+ "`_fit_input_feature_types` is defined in "
+ f"{self.config.__class__.__name__} but does not match the arity of "
+ f"the fit function in {self.__class__.__name__} (i.e. len("
+ "self.config._fit_input_feature_types) != "
+ "len(self.config._input_columns) -> "
+ f"{len(self.config._fit_input_feature_types)} != "
+ f"{len(self.config._input_columns)}).\n"
+ "This can be corrected by doing, for example:\n"
+ f"_fit_input_feature_types = field(\n"
+ f" default_factory=lambda: [{example_arg}], init=False, "
+ "repr=False\n"
+ ")"
+ )
+ if self.config._fit_unused_feature_types is not None and len(
+ self.config._input_columns
+ ) != len(self.config._fit_unused_feature_types):
+ example_arg = ", ".join(["None" for _ in range(arity)])
+ assert False, (
+ "`_fit_unused_feature_types` is defined in "
+ f"{self.config.__class__.__name__} but does not match the arity of "
+ f"the fit function in {self.__class__.__name__} (i.e. len("
+ "self.config._fit_unused_feature_types) != "
+ "len(self.config._input_columns) -> "
+ f"{len(self.config._fit_unused_feature_types)} != "
+ f"{len(self.config._input_columns)}).\n"
+ "This can be corrected by doing, for example:\n"
+ f"_fit_unused_feature_types = field(\n"
+ f" default_factory=lambda: [{example_arg}], init=False, "
+ "repr=False\n"
+ ")"
+ )
+
+ def fit(
+ self,
+ X,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ """Must be implemented by subclasses if the processor is trainable."""
+ # only use this fit method if no concrete fit method is found
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def _process_fit(
+ self,
+ X,
+ *args,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> T_PROC:
+ """
+ Fits the processor to the input data, preparing it for transformation. This process
+ may involve learning parameters from the data, validating input formats, and setting
+ up caching mechanisms.
+
+ Args:
+ X (Union[np.ndarray, pd.DataFrame, Bioset, IterableDataset]):
+ The input data to fit the processor on. Can be a variety of types including
+ numpy arrays, pandas DataFrames, or Hugging Face's `datasets` objects.
+ *args (Any, optional):
+ Additional input data to fit the processor on, such as target data for
+ supervised learning tasks. Defaults to None.
+ raise_if_missing (bool, optional):
+ Whether to raise an error if any requested input columns are missing from
+ the data. Defaults to None, which follows the processor's default behavior.
+ cache_output (bool, optional):
+ Whether to cache the fitted processor on disk. Defaults to None, which
+ follows the processor's caching configuration.
+ load_from_cache_file (bool, optional):
+ Whether to load the fitted processor from a cache file if available.
+ Defaults to None, which follows the processor's default behavior.
+ batched (bool, optional):
+ Whether to process the data in batches. This can be beneficial for large
+ datasets. Defaults to True.
+ batch_size (int, optional):
+ Size of each batch when `batched` is True. Defaults to 1000.
+ batch_format (str, optional):
+ The format to convert each batch to before processing (e.g. "numpy" or
+ "pandas"). Defaults to None, which keeps the input format.
+ num_proc (int, optional):
+ The number of processes to use for multiprocessing. Defaults to None,
+ which disables multiprocessing.
+ map_kwargs (dict, optional):
+ Additional keyword arguments to pass to the map function when processing
+ datasets.
+ cache_dir (str, Path, optional):
+ Directory where the fitted processor should be cached. Defaults to None,
+ which uses the processor's default cache directory.
+ cache_file_name (str, optional):
+ The name of the cache file (file name only; use `cache_dir` to set the
+ directory). Defaults to None, which derives the name from the fingerprint.
+ fingerprint (str, optional):
+ A unique identifier for the processing operation, used for caching.
+ Defaults to None.
+
+ Returns:
+ self: The fitted processor instance.
+
+ Raises:
+ ValueError: If `X` is a dataset dictionary, or if `input_columns` are specified but do not exist in the dataset.
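+
+ Example (illustrative sketch; `MyProcessor` and `X_train` are hypothetical):
+ >>> proc = MyProcessor()
+ >>> proc = proc.fit(X_train, batched=True, batch_size=500) # calls _process_fit
+ >>> proc.is_fitted
+ True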
+ """
+
+ funcs = self._get_method(_ORDERED_FORMATS, func_type="_fit")
+ assert self.config._input_columns is not None or len(funcs) == 0, (
+ f"The `fit` method of `{self.__class__.__name__}` must call:\n"
+ "```\n"
+ "self.config._input_columns = self._set_input_columns_and_arity(*args)"
+ "\n```\n"
+ "Where `*args` are the columns for each input dataset."
+ )
+
+ if not hasattr(self, "config_"):
+ self.config_ = copy.deepcopy(self.config)
+ else:
+ self._reset(copy.deepcopy(self.config_))
+
+ if is_dataset_dict(X):
+ raise ValueError(
+ "Please provide the dataset directly instead of the dictionary: processor.fit(dataset['train'])"
+ )
+
+ if cache_output is None:
+ cache_output = self.config.enable_caching and self.config.cache_output
+
+ if cache_output:
+ self.config._data_fingerprint = getattr(
+ X, "_fingerprint", None
+ ) or fingerprint_from_data(X)
+ else:
+ self.config._data_fingerprint = None
+
+ if cache_output:
+ self._fingerprint = fingerprint
+
+ if not fingerprint:
+ self._fingerprint = fingerprint_from_kwargs(
+ self.config._data_fingerprint,
+ self.config.get_params(),
+ )
+
+ if cache_dir is not None:
+ cache_dir = expand_path(str(cache_dir))
+ cache_dir = os.path.join(cache_dir, "processors")
+
+ cache_dir = generate_cache_dir(
+ self,
+ self.config._data_fingerprint,
+ root_dir=cache_dir or biofit.config.BIOFIT_PROCESSORS_CACHE,
+ )
+
+ if cache_dir:
+ if cache_file_name:
+ if is_remote_url(cache_file_name):
+ raise ValueError(
+ "`cache_file_name` is a remote URL. Please provide the "
+ "file name only. You can specify the directory using "
+ "`cache_dir`."
+ )
+ elif os.path.isabs(cache_file_name):
+ raise ValueError(
+ "`cache_file_name` is an absolute path. Please provide the "
+ "file name only. You can specify the directory using "
+ "`cache_dir`."
+ )
+ self.cache_files = [
+ {
+ "filename": get_cache_file_name(
+ cache_dir, self.fingerprint, cache_file_name
+ )
+ }
+ ]
+
+ self._validate_fit_params(len(args) + 1)
+ return self._fit(
+ X,
+ *args,
+ funcs=funcs,
+ cache_file_name=cache_file_name,
+ input_columns=self.config._input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ cache_dir=cache_dir,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ )
+
+ @keep_dataset_fingerprint
+ def _fit(
+ self,
+ X,
+ *args,
+ funcs: List[Callable] = None,
+ cache_file_name: str = None,
+ input_columns: List[str] = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = False,
+ load_from_cache_file: bool = None,
+ cache_dir: str = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ ):
+ """
+ Fits the processor to the input data.
+
+ Args:
+ X (Any): The input data.
+ *args (Any, optional): Additional inputs, such as target data.
+ The remaining keyword arguments mirror those of `_process_fit`.
+
+ Returns:
+ The fitted processor instance.
+ """
+
+ # update fit_config with dataset and processor info
+
+ if cache_output is None:
+ cache_output = self.config.enable_caching and self.config.cache_output
+
+ if load_from_cache_file is None:
+ load_from_cache_file = (
+ self.config.enable_caching and self.config.load_from_cache_file
+ )
+ try:
+ if load_from_cache_file and self.cache_files:
+ cache_file_name = self.cache_files[0]["filename"]
+ self = self.load_processed_estimator_from_cache(cache_file_name)
+ logger.info(f"Loading cached processed estimator at {cache_file_name}")
+ else:
+ raise NonExistentCacheError
+ except NonExistentCacheError:
+ (
+ selected_columns,
+ selected_indices,
+ unused_indices,
+ extra_columns,
+ extra_indices,
+ unused_extra_indices,
+ offsets,
+ ) = self._get_columns(
+ X,
+ *args,
+ input_columns=input_columns,
+ input_feature_types=self.config._fit_input_feature_types,
+ unused_columns=self.config.unused_columns,
+ unused_feature_types=self.config._fit_unused_feature_types,
+ raise_if_missing=raise_if_missing
+ if raise_if_missing is not None
+ else True,
+ )
+
+ self.config.feature_idx_in_ = selected_indices
+ self.config.feature_names_in_ = selected_columns
+ self.config.extra_names_in_ = extra_columns
+ self.config.extra_idx_in_ = extra_indices
+
+ extra = []
+ extra_to_pass = []
+ # Two options: either all args have the same number of rows as X, in which
+ # case we can combine them, or they do not, and we must pass them separately.
+ # Combining lets us apply multiprocessing or batching with the same indices,
+ # and it interoperates with HF's Bioset.map.
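+ # e.g. X of shape (100, 20) and y of shape (100,) can be concatenated
+ # column-wise and batched with shared row indices, whereas a y with a
+ # different number of rows is passed through separately, untouched.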
+ if len(args) and not all(arg is None for arg in args):
+ main_dims = DataHandler.get_shape(X)
+ combine_all = True
+ for arg in args:
+ if arg is not None:
+ arg_dims = DataHandler.get_shape(arg)
+ if arg_dims[0] != main_dims[0]:
+ combine_all = False
+ break
+ for i, arg in enumerate(args):
+ if DataHandler.supports_named_columns(arg) and combine_all:
+ cols = DataHandler.get_column_names(arg, generate_cols=True)
+ cols = [f"{c}_{i}" for c in cols]
+ extra.append(DataHandler.set_column_names(arg, cols))
+ else:
+ extra.append(arg)
+ if combine_all:
+ input = DataHandler.concat(
+ [X] + [ext for ext in extra if ext is not None], axis=1
+ )
+ extra_to_pass = None
+ else:
+ extra_to_pass = extra
+ extra = None
+ else:
+ input = X
+
+ fit_map_kwargs, pooler = self._prepare_fit_kwargs(
+ funcs,
+ input,
+ X,
+ extra_inputs=extra,
+ extra_untouched_inputs=extra_to_pass,
+ selected_indices=selected_indices,
+ unused_indices=unused_indices,
+ extra_indices=extra_indices,
+ unused_extra_indices=unused_extra_indices,
+ offsets=offsets,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ batch_format=batch_format,
+ batched=batched,
+ batch_size=batch_size,
+ )
+ input, fit_map_kwargs = self._process_fit_input(input, **fit_map_kwargs)
+ runner = None
+ out = self
+ if fit_map_kwargs["fn_kwargs"]["fn"]:
+ if is_ray_available() and "ray" in sys.modules:
+ import ray.data
+
+ if isinstance(X, ray.data.Dataset):
+ fit_ray_kwargs = self._convert_map_kwargs_to_ray_kwargs(
+ fit_map_kwargs, batch_format=batch_format, is_fit=True
+ )
+ return X, fit_ray_kwargs
+
+ pooler = pooler if pooler is not None else self._pool_fit
+
+ @wraps(BaseProcessor.map)
+ def runner(*args, **map_kwargs):
+ out = self.map(*args, **map_kwargs)
+ if len(out) == 1:
+ return out[0]
+ return pooler(out)
+
+ if is_iterable_dataset(input):
+ from datasets import IterableDataset
+
+ @wraps(IterableDataset.map)
+ def runner(*args, **map_kwargs):
+ if len(args) > 0:
+ ds = args[0]
+ args = args[1:]
+ else:
+ ds = map_kwargs.pop("self")
+ return ds.map(*args, **map_kwargs)
+
+ out = self.run(input, runner=runner, **fit_map_kwargs)
+
+ self = self._process_fit_output(input, out)
+
+ if (
+ self.config.feature_idx_in_ is not None
+ and self.config.feature_names_in_ is None
+ ):
+ self.config.n_features_in_ = len(self.config.feature_idx_in_)
+
+ self.config._n_features_out = self.config._n_features_out or getattr(
+ self, "n_features_out", None
+ )
+
+ temp_file = None
+ if cache_output and self.cache_files:
+ cache_file_name = (
+ Path(self.cache_files[0]["filename"]).resolve().as_posix()
+ )
+ cache_dir = os.path.dirname(cache_file_name)
+ temp_file = tempfile.NamedTemporaryFile(
+ "wb", dir=cache_dir, delete=False
+ )
+ try:
+ self.config.save_to_cache(
+ temp_file.name, fingerprint=self.fingerprint
+ )
+ except (Exception, KeyboardInterrupt):
+ temp_file.close()
+ if os.path.exists(temp_file.name):
+ os.remove(temp_file.name)
+ raise
+
+ if temp_file and self.cache_files:
+ move_temp_file(temp_file, cache_file_name)
+ self.config.is_fitted = True
+ return self
+
+ def _validate_transform_params(self, arity):
+ """Validates the input arguments for the transform function.
+ This is used to ensure that the input columns match the expected arity of the
+ transform function. Arity is defined by the length of self._input_columns.
+ """
+ if self._input_columns is not None:
+ if self.config._transform_input_feature_types is not None and len(
+ self._input_columns
+ ) != len(self.config._transform_input_feature_types):
+ example_arg = ", ".join(["None" for _ in range(arity)])
+ assert False, (
+ "`_transform_input_feature_types` is defined in "
+ f"{self.config.__class__.__name__} but does not match the arity of "
+ f"the transform function in {self.__class__.__name__} (i.e. len("
+ "self.config._transform_input_feature_types) != "
+ "len(self._input_columns) -> "
+ f"{len(self.config._transform_input_feature_types)} != "
+ f"{len(self._input_columns)}).\n"
+ "This can be corrected by doing, for example:\n"
+ f"_transform_input_feature_types = field(\n"
+ f" default_factory=lambda: [{example_arg}], init=False, "
+ "repr=False\n"
+ ")"
+ )
+ if self.config._transform_unused_feature_types is not None and len(
+ self._input_columns
+ ) != len(self.config._transform_unused_feature_types):
+ example_arg = ", ".join(["None" for _ in range(arity)])
+ assert False, (
+ "`_transform_unused_feature_types` is defined in "
+ f"{self.config.__class__.__name__} but does not match the arity of "
+ f"the transform function in {self.__class__.__name__} (i.e. len("
+ "self.config._transform_unused_feature_types) != "
+ "len(self._input_columns) -> "
+ f"{len(self.config._transform_unused_feature_types)} != "
+ f"{len(self._input_columns)}).\n"
+ "This can be corrected by doing, for example:\n"
+ f"_transform_unused_feature_types = field(\n"
+ f" default_factory=lambda: [{example_arg}], init=False, "
+ "repr=False\n"
+ ")"
+ )
+
+ def _process_transform(
+ self,
+ X,
+ *args,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ method_type = self._method_prefix[1:]
+
+ assert self._input_columns is not None, (
+ f"The `{method_type}` method of `{self.__class__.__name__}` must call:\n"
+ "```\n"
+ "self._input_columns = self._set_input_columns_and_arity(*args)"
+ "\n```\n"
+ "Where `*args` are the columns for each input dataset."
+ )
+
+ if is_dataset_dict(X):
+ raise ValueError(
+ "Please provide the dataset directly instead of the dictionary: processor.transform(dataset['train'])"
+ )
+ data_fingerprint = getattr(X, "_fingerprint", None) or fingerprint_from_data(X)
+ cache_output = (
+ cache_output
+ if cache_output is not None
+ else (self.config.enable_caching and self.config.cache_output)
+ )
+
+ if cache_output:
+ if not fingerprint:
+ hash = Hasher()
+ hash.update(self._fingerprint)
+ hash.update(data_fingerprint)
+ fingerprint = fingerprint_from_kwargs(
+ hash.hexdigest(),
+ {
+ "input_columns": self._input_columns,
+ "unused_columns": self.config.unused_columns,
+ },
+ )
+
+ if cache_dir is not None:
+ cache_dir = expand_path(str(cache_dir))
+ cache_dir = os.path.join(cache_dir, "datasets")
+ cache_dir = cache_dir or biofit.config.BIOFIT_DATASETS_CACHE
+ cache_dir = generate_cache_dir(
+ X,
+ data_fingerprint,
+ root_dir=cache_dir,
+ )
+
+ if cache_file_name:
+ if is_remote_url(cache_file_name):
+ raise ValueError(
+ "`cache_file_name` is a remote URL. Please provide the "
+ "file name only. You can specify the directory using "
+ "`cache_dir`."
+ )
+ elif os.path.isabs(cache_file_name):
+ raise ValueError(
+ "`cache_file_name` is an absolute path. Please provide the "
+ "file name only. You can specify the directory using "
+ "`cache_dir`."
+ )
+
+ keep_unused_columns = (
+ keep_unused_columns
+ if keep_unused_columns is not None
+ else self.config.keep_unused_columns
+ )
+ self.keep_unused_columns = keep_unused_columns
+ if (
+ self._input_columns is None
+ and self._feature_dependent
+ and data_fingerprint == self.config._data_fingerprint
+ ):
+ if DataHandler.supports_named_columns(X):
+ cols = self.config.feature_names_in_ or self.config.feature_idx_in_
+ else:
+ cols = self.config.feature_idx_in_
+ if (
+ cols
+ and self.config.extra_idx_in_
+ and len(args) > 0
+ and len(args) == len(self.config.extra_idx_in_)
+ ):
+ cols = [cols]
+ for i in range(len(self.config.extra_idx_in_)):
+ if self.config.extra_names_in_[
+ i
+ ] and DataHandler.supports_named_columns(args[i]):
+ cols.append(self.config.extra_names_in_[i])
+ else:
+ cols.append(self.config.extra_idx_in_[i])
+
+ self._input_columns = cols
+
+ self._validate_transform_params(len(args) + 1)
+ return self._transform(
+ X,
+ *args,
+ input_columns=self._input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_file_name=cache_file_name,
+ cache_dir=cache_dir,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ @keep_dataset_fingerprint
+ def _transform(
+ self,
+ X,
+ *args,
+ input_columns: List[str] = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_file_name: str = None,
+ cache_dir: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ """
+ Transforms the input data.
+
+ Args:
+ X (Any): The input data.
+ *args (Any, optional): Additional inputs to transform alongside `X`.
+ The remaining keyword arguments mirror those of `_process_transform`.
+
+ Returns:
+ Any: The transformed data.
+
+ Raises:
+ NotFittedError: If the processor has not been fitted.
+ """
+ if not self.is_fitted:
+ raise NotFittedError
+
+ (
+ _,
+ self._selected_indices,
+ self._unused_indices,
+ _,
+ self._extra_indices,
+ self._unused_extra_indices,
+ offsets,
+ ) = self._get_columns(
+ X,
+ *args,
+ input_columns=input_columns,
+ input_feature_types=self.config._transform_input_feature_types,
+ unused_columns=self.config.unused_columns,
+ unused_feature_types=self.config._transform_unused_feature_types,
+ raise_if_missing=raise_if_missing if raise_if_missing is not None else True,
+ )
+
+ if len(args) and not all(arg is None for arg in args):
+ extra = []
+ for i, arg in enumerate(args):
+ if arg is not None:
+ if DataHandler.supports_named_columns(arg):
+ cols = DataHandler.get_column_names(arg, generate_cols=True)
+ cols = [f"{c}_{i}" for c in cols]
+ extra.append(DataHandler.set_column_names(arg, cols))
+ else:
+ extra.append(arg)
+ input = DataHandler.concat([X] + extra, axis=1)
+ else:
+ input = X
+
+ if load_from_cache_file is None:
+ load_from_cache_file = (
+ self.config.enable_caching and self.config.load_from_cache_file
+ )
+
+ trans_map_kwargs = self._prepare_transform_kwargs(
+ input,
+ X,
+ *args,
+ selected_indices=self._selected_indices,
+ unused_indices=self._unused_indices,
+ extra_indices=self._extra_indices,
+ unused_extra_indices=self._unused_extra_indices,
+ offsets=offsets,
+ cache_dir=cache_dir,
+ new_fingerprint=fingerprint,
+ map_kwargs=map_kwargs or {"fn_kwargs": {}},
+ batch_format=batch_format,
+ batched=batched,
+ batch_size=batch_size,
+ cache_output=cache_output,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ num_proc=num_proc,
+ keep_unused_columns=keep_unused_columns,
+ )
+
+ out = None
+ if trans_map_kwargs["fn_kwargs"]["fn"]:
+
+ @wraps(BaseProcessor.map)
+ def runner(*args, **map_kwargs):
+ return DataHandler.concat(self.map(*args, **map_kwargs))
+
+ if is_bioset(X) or is_dataset(X, iterable=False):
+ if is_biosets_available():
+ from biosets import Bioset
+
+ wrap_map = Bioset.map
+ else:
+ from datasets import Dataset
+
+ wrap_map = Dataset.map
+
+ @wraps(wrap_map)
+ def runner(*args, **map_kwargs):
+ if len(args) > 0:
+ ds = args[0]
+ args = args[1:]
+ else:
+ ds = map_kwargs.pop("self")
+ return ds.map(*args, **map_kwargs)
+
+ elif is_iterable_dataset(X):
+ from datasets import IterableDataset
+
+ @wraps(IterableDataset.map)
+ def runner(*args, **map_kwargs):
+ if len(args) > 0:
+ ds = args[0]
+ args = args[1:]
+ else:
+ ds = map_kwargs.pop("self")
+ return ds.map(*args, **map_kwargs)
+
+ out = self.run(input, runner=runner, **trans_map_kwargs)
+
+ out = self._process_transform_output(
+ out,
+ X,
+ *args,
+ output_format=output_format,
+ keep_unused_columns=keep_unused_columns,
+ selected_indices=self._selected_indices,
+ unused_indices=self._unused_indices,
+ fingerprint=trans_map_kwargs["new_fingerprint"],
+ )
+
+ return out
+
+ def _process_extra_inds(
+ self, orig_input, extra_inputs, extra_indices, unused_extra_indices
+ ):
+ """
+ Processes extra indices for additional inputs.
+
+ This method adjusts the indices for extra inputs by adding an offset based on the
+ dimensions of the original input and the extra inputs. It ensures that the indices
+ are correctly aligned with the combined input dimensions.
+
+ Args:
+ orig_input: The original input data.
+ extra_inputs: A list of additional input data.
+ extra_indices: A list of indices corresponding to the extra inputs.
+ unused_extra_indices: A list of unused indices corresponding to the extra inputs.
+
+ Returns:
+ A tuple containing:
+ - extra_inds: A list of adjusted indices for the extra inputs.
+ - unused_extra_inds: A list of adjusted unused indices for the extra inputs.
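+
+ Example:
+ If `orig_input` has 5 columns and the first extra input selects indices
+ [0, 2], the adjusted indices become [5, 7] in the combined input; the
+ offset then grows by that input's column count before the next extra
+ input is processed.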
+ """
+ assert not extra_inputs or extra_indices is not None, (
+ "`extra_indices` was returned as `None` from "
+ f"`{self.__class__.__name__}`. "
+ f"Was `{self.__class__.__name__}._input_columns` or "
+ f"`{self.__class__.__name__}.config._input_columns` set correctly?"
+ )
+ extra_inds = copy.deepcopy(extra_indices)
+ unused_extra_inds = copy.deepcopy(unused_extra_indices)
+ if extra_inputs and not all(arg is None for arg in extra_inputs):
+ x_dims = DataHandler.get_shape(orig_input)
+ if len(x_dims) == 1:
+ x_dims = (x_dims[0], 1)
+ offset = x_dims[1]
+ extra_inds = []
+ unused_extra_inds = []
+ if unused_extra_indices is None:
+ unused_extra_indices = [None] * len(extra_indices)
+ for inds, un_inds, arg in zip(
+ extra_indices, unused_extra_indices, extra_inputs
+ ):
+ if arg is not None:
+ if inds is not None and len(inds) > 0:
+ extra_inds.append([i + offset for i in inds])
+ else:
+ extra_inds.append(None)
+ if un_inds is not None and len(un_inds) > 0:
+ unused_extra_inds.append([i + offset for i in un_inds])
+ else:
+ unused_extra_inds.append(None)
+
+ arg_dims = DataHandler.get_shape(arg)
+ if len(arg_dims) == 1:
+ arg_dims = (arg_dims[0], 1)
+ offset += arg_dims[1]
+ else:
+ extra_inds.append(None)
+ unused_extra_inds.append(None)
+ return extra_inds, unused_extra_inds
+
+ def _prepare_fit_kwargs(
+ self,
+ funcs,
+ combined_inputs,
+ orig_input,
+ extra_inputs,
+ extra_untouched_inputs,
+ selected_indices=None,
+ unused_indices=None,
+ extra_indices=None,
+ unused_extra_indices=None,
+ offsets=None,
+ map_kwargs={"fn_kwargs": {}},
+ batch_format=None,
+ batched=None,
+ batch_size=None,
+ num_proc=None,
+ ):
+ original_format = get_data_format(combined_inputs)
+
+ poolers = None
+ if (
+ batch_size is not None
+ and batch_size < DataHandler.get_shape(combined_inputs)[0]
+ and funcs
+ ):
+ batchable_funcs = self._get_method(
+ _ORDERED_FORMATS,
+ func_type="_fit",
+ prefix=self.config._batch_method_prefix,
+ )
+ if not batchable_funcs:
+ if batched is not None:
+ logger.warning_once(
+ f"There are no batched fit functions available for {self.__class__.__name__}. "
+ "Using non-batched fit functions."
+ )
+ # fall back to a single batch spanning the entire input
+ batched = True
+ batch_size = None
+ poolers = None
+ else:
+ poolers = self._get_method(_ORDERED_FORMATS, func_type="_pool_fit")
+ funcs = batchable_funcs
+ if original_format == "ray" and batch_format is None:
+ batch_format = ["pandas", "pd", "numpy", "np"]
+
+ func, batch_format = self._get_target_func(funcs, original_format, batch_format)
+ pooler = None
+ if poolers:
+ pooler, _ = self._get_target_func(poolers, original_format, batch_format)
+
+ func_args = inspect.getfullargspec(func).args if func else []
+
+ indices_args = ["indices", "indexes", "index", "ind", "inds", "idx", "i"]
+ rank_args = ["rank", "rnk", "r"]
+ with_indices = any(arg in func_args for arg in indices_args)
+ with_rank = any(arg in func_args for arg in rank_args)
+ map_kwargs = map_kwargs or {}
+ map_kwargs = copy.deepcopy(map_kwargs)
+ map_kwargs["with_indices"] = with_indices
+ map_kwargs["with_rank"] = with_rank
+
+ if "num_proc" not in map_kwargs:
+ map_kwargs["num_proc"] = num_proc
+
+ if (
+ func
+ and "num_proc" in map_kwargs
+ and not func.__name__.startswith(self.config._batch_method_prefix)
+ ):
+ # cannot use num_proc with non-batched fit functions
+ map_kwargs.pop("num_proc")
+
+ map_kwargs["desc"] = self.config._fit_process_desc
+ if batched is None or batched:
+ map_kwargs["batched"] = True
+ map_kwargs["batch_size"] = batch_size
+ else:
+ map_kwargs["batched"] = False
+ map_kwargs["batch_size"] = 1
+
+ extra_inds, unused_extra_inds = self._process_extra_inds(
+ orig_input=orig_input,
+ extra_inputs=extra_inputs,
+ extra_indices=extra_indices,
+ unused_extra_indices=unused_extra_indices,
+ )
+
+ fn_kwargs = {
+ "fn": func,
+ "func_type": "_fit",
+ "extra_untouched_inputs": extra_untouched_inputs,
+ "selected_indices": selected_indices,
+ "unused_indices": unused_indices,
+ "extra_indices": extra_inds,
+ "unused_extra_indices": unused_extra_inds,
+ "with_metadata": "metadata" in func_args,
+ "in_format_kwargs": {
+ "target_format": batch_format,
+ },
+ "out_format_kwargs": {
+ "target_format": None,
+ },
+ }
+
+ if "fn_kwargs" in map_kwargs:
+ map_kwargs["fn_kwargs"].update(fn_kwargs)
+ else:
+ map_kwargs["fn_kwargs"] = fn_kwargs
+
+ map_kwargs["new_fingerprint"] = self.fingerprint
+
+ return map_kwargs, pooler
+
+ def _prepare_transform_kwargs(
+ self,
+ combined_inputs,
+ orig_input,
+ *extra_inputs,
+ selected_indices,
+ unused_indices,
+ extra_indices,
+ unused_extra_indices,
+ offsets=None,
+ batch_format=None,
+ map_kwargs={"fn_kwargs": {}},
+ batched=None,
+ batch_size=1000,
+ cache_output=True,
+ cache_file_name=None,
+ cache_dir=None,
+ load_from_cache_file=True,
+ new_fingerprint=None,
+ num_proc=None,
+ keep_unused_columns=None,
+ ):
+ original_format = get_data_format(combined_inputs)
+ input_format = batch_format
+
+ funcs = self._get_method(_ORDERED_FORMATS, func_type=self._method_prefix)
+ func, batch_format = self._get_target_func(funcs, original_format, input_format)
+
+ func_args = inspect.getfullargspec(func).args if func else []
+ indices_args = [
+ "indices",
+ "indexes",
+ "index",
+ "ind",
+ "inds",
+ "idx",
+ "i",
+ ]
+ rank_args = ["rank", "rnk", "r"]
+ with_indices = any(arg in func_args for arg in indices_args)
+ with_rank = any(arg in func_args for arg in rank_args)
+ map_kwargs = copy.deepcopy(map_kwargs)
+ map_kwargs["with_indices"] = with_indices
+ map_kwargs["with_rank"] = with_rank
+
+ map_kwargs["desc"] = getattr(
+ self.config,
+ self._method_prefix + "_process_desc",
+ "Transforming data",
+ )
+
+ if "keep_in_memory" not in map_kwargs:
+ map_kwargs["keep_in_memory"] = not cache_output
+
+ if "cache_file_name" not in map_kwargs:
+ map_kwargs["cache_file_name"] = cache_file_name
+
+ if "num_proc" not in map_kwargs:
+ map_kwargs["num_proc"] = num_proc
+
+ if batched is None or batched:
+ map_kwargs["batched"] = True
+ map_kwargs["batch_size"] = batch_size
+ else:
+ map_kwargs["batched"] = False
+
+ if "load_from_cache_file" not in map_kwargs:
+ map_kwargs["load_from_cache_file"] = load_from_cache_file
+
+ output_format = original_format
+ if batch_format in _ARROW_WRITEABLE_FORMATS:
+ output_format = batch_format
+ elif original_format not in _ARROW_WRITEABLE_FORMATS:
+ output_format = "arrow"
+
+ map_kwargs["new_fingerprint"] = new_fingerprint
+
+ features_out = None
+
+ extra_inds, unused_extra_inds = self._process_extra_inds(
+ orig_input=orig_input,
+ extra_inputs=extra_inputs,
+ extra_indices=extra_indices,
+ unused_extra_indices=unused_extra_indices,
+ )
+
+ fn_kwargs = {
+ "fn": func,
+ "func_type": self._method_prefix,
+ "with_metadata": "metadata" in func_args,
+ "selected_indices": selected_indices,
+ "unused_indices": unused_indices,
+ "extra_indices": extra_inds,
+ "unused_extra_indices": unused_extra_inds,
+ "keep_unused_columns": keep_unused_columns,
+ "in_format_kwargs": {
+ "target_format": batch_format,
+ },
+ "out_format_kwargs": {
+ "target_format": output_format,
+ },
+ }
+
+ if "fn_kwargs" in map_kwargs:
+ map_kwargs["fn_kwargs"].update(fn_kwargs)
+ else:
+ map_kwargs["fn_kwargs"] = fn_kwargs
+
+ combined_inputs, map_kwargs = self._process_transform_input(
+ combined_inputs, **map_kwargs
+ )
+
+ if hasattr(orig_input, "cache_files") and orig_input.cache_files:
+ if cache_file_name is None:
+ cache_files = orig_input.cache_files[0]["filename"]
+ cache_dir = os.path.dirname(cache_files)
+ cache_file_name = f"cache-{new_fingerprint}.arrow"
+ cache_file_name = os.path.join(cache_dir, cache_file_name)
+
+ if not load_from_cache_file or not (
+ cache_file_name and os.path.exists(cache_file_name)
+ ):
+ if "features" not in map_kwargs or map_kwargs["features"] is None:
+ unsel_inds = list(unused_indices) if unused_indices else []
+ if extra_inds:
+ unsel_inds += [
+ i
+ for sub in [
+ inds if inds is not None else [] for inds in extra_inds
+ ]
+ for i in sub
+ ]
+ if unused_extra_inds:
+ unsel_inds += [
+ i
+ for sub in [
+ inds if inds is not None else []
+ for inds in unused_extra_inds
+ ]
+ for i in sub
+ ]
+ unsel_inds = sorted(unsel_inds)
+ features_out = self._get_features_out(
+ combined_inputs,
+ selected_indices=copy.deepcopy(selected_indices),
+ unselected_indices=unsel_inds,
+ one_to_one_features=self.config.one_to_one_features,
+ n_features_out=self.config._n_features_out,
+ keep_unused_columns=keep_unused_columns,
+ )
+ map_kwargs["features"] = features_out
+ else:
+ features_out = map_kwargs["features"]
+
+ map_kwargs["fn_kwargs"]["feature_names"] = list(features_out.keys())
+ map_kwargs["features"] = features_out
+ return map_kwargs
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ **kwargs,
+ ):
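+ """Fits the processor on `X`, then transforms `X` in a single call.
+
+ Equivalent to `self.fit(X, *args, **kwargs).transform(X, output_format=...)`,
+ where `output_format` is popped from `kwargs` before fitting.
+ """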
+ output_format = kwargs.pop("output_format", None)
+ return self.fit(X, *args, **kwargs).transform(
+ X, output_format=output_format, **kwargs
+ )
+
+ def _process_transform_input(self, X, **kwargs):
+ return X, kwargs
+
+ def _process_transform_output(self, output, input, *args, **kwargs):
+ output_format = kwargs.get("output_format", None) or get_data_format(input)
+ if output_format:
+ if output is None:
+ raise ValueError(
+ f"The output format is specified as `{output_format}` but the "
+ f"output from the {self._method_prefix.removeprefix('_')} method "
+ "is `None`."
+ )
+ output = DataHandler.to_format(output, output_format)
+ return output
+
+ def get_params(self, deep=True, show_init_only=True, show_repr_only=True):
+ """Get the parameters of the processor."""
+ return self.config.get_params(
+ deep=deep, show_init_only=show_init_only, show_repr_only=show_repr_only
+ )
+
+ def load_processed_estimator_from_cache(
+ self, cache_file_name: Optional[Union[Path, str]] = None, **kwargs
+ ):
+ """Load a processed estimator from cache if it exists, otherwise throw an error."""
+ # Check if we've already cached this computation (indexed by a hash)
+ if cache_file_name and os.path.exists(cache_file_name):
+ if isinstance(cache_file_name, Path):
+ cache_file_name = cache_file_name.resolve().as_posix()
+ self.config = self.config_class.from_config_file(cache_file_name)
+ return self
+ else:
+ raise NonExistentCacheError
+
+ def _get_features_out(
+ self,
+ X,
+ selected_indices=None,
+ unselected_indices=None,
+ one_to_one_features=True,
+ n_features_out=None,
+ keep_unused_columns=False,
+ ) -> "Features":
+ features_out = None
+ if unselected_indices is not None:
+ unsel_inds = set(unselected_indices)
+ else:
+ unsel_inds = set()
+ cols = DataHandler.get_column_names(X, generate_cols=True)
+ assert cols is not None, "Could not generate column names from input data"
+ if one_to_one_features:
+ if selected_indices is not None:
+ sel_inds = set(selected_indices)
+ else:
+ sel_inds = set(range(len(cols)))
+ else:
+ sel_inds = set()
+
+ if is_bioset(X) or is_dataset(X):
+ # get the output features, as well as the features that need to be reinserted
+ features = X._info.features.copy()
+ elif is_datasets_available():
+ from datasets.features import Value
+
+ features = {k: Value(dtype=v) for k, v in DataHandler.get_dtypes(X).items()}
+ else:
+ features = {k: None for k in cols}
+
+ if keep_unused_columns:
+ sel_inds = sel_inds.union(unsel_inds)
+
+ out_cols = self._get_feature_names_out(
+ n_features_out=n_features_out,
+ input_features=cols,
+ useful_feature_inds=list(sorted(sel_inds)),
+ one_to_one_features=one_to_one_features,
+ )
+ features_out = {}
+ if one_to_one_features:
+ pa_type = None
+ if self.output_dtype:
+ pa_type = string_to_arrow(self.output_dtype)
+ pos = 0
+ for i in range(len(cols)):
+ if i in unsel_inds:
+ if not keep_unused_columns:
+ continue
+ features_out[out_cols[pos]] = features[cols[i]]
+ pos += 1
+ else:
+ k = out_cols[pos]
+ pos += 1
+ v = features[cols[i]]
+ if pa_type and hasattr(v, "dtype") and hasattr(v, "pa_type"):
+ setattr(v, "dtype", self.output_dtype)
+ setattr(v, "pa_type", pa_type)
+ features_out[k] = v
+ else:
+ if keep_unused_columns:
+ features_out.update({cols[i]: features[cols[i]] for i in unsel_inds})
+ # the transformed features are always appended to the end
+ out_cols = out_cols[len(unsel_inds) :]
+ if is_datasets_available():
+ from datasets.features import Value
+
+ if self.output_dtype:
+ features_out.update(
+ {k: Value(dtype=self.output_dtype) for k in out_cols}
+ )
+ else:
+ dtypes = list(
+ set([features[cols[i]].dtype for i in selected_indices])
+ )
+ dtype = determine_upcast(dtypes)
+ features_out.update({k: Value(dtype=dtype) for k in out_cols})
+ else:
+ features_out.update({k: None for k in out_cols})
+
+ if is_datasets_available():
+ from datasets.features import Features
+
+ features_out = Features(features_out)
+
+ return features_out
+
+ def _prepare_runner(self, X, **fn_kwargs):
+ if fn_kwargs.get("with_metadata", False):
+ if (is_bioset(X) or is_dataset(X)) and is_biosets_available():
+ from biosets import get_feature_metadata
+
+ feat_metadata = get_feature_metadata(X)
+ feat_arrow_tbl = pa.Table.from_pylist(list(feat_metadata.values()))
+ try:
+ feat_arrow_tbl = feat_arrow_tbl.add_column(
+ 0, "features", pa.array(list(feat_metadata.keys()))
+ )
+ fn_kwargs["metadata"] = feat_arrow_tbl
+ except Exception:
+ pass
+
+ return fn_kwargs
+
+ def run(
+ self,
+ X,
+ runner: Optional[Callable] = None,
+ fn_kwargs: dict = {},
+ **map_kwargs,
+ ):
+ fn_kwargs = self._prepare_runner(X, **fn_kwargs)
+ if runner:
+ if is_bioset(X) or is_dataset(X):
+ format = "arrow"
+ if (
+ "in_format_kwargs" in fn_kwargs
+ and "target_format" in fn_kwargs["in_format_kwargs"]
+ ):
+ format = fn_kwargs["in_format_kwargs"]["target_format"]
+ if format not in _ARROW_WRITEABLE_FORMATS:
+ format = "arrow"
+ with X.formatted_as(format):
+ kwargs = get_kwargs({"fn_kwargs": fn_kwargs, **map_kwargs}, runner)
+ return runner(X, self._process_batches, **kwargs)
+ else:
+ kwargs = get_kwargs({"fn_kwargs": fn_kwargs, **map_kwargs}, runner)
+ return runner(X, self._process_batches, **kwargs)
+ else:
+ return self._process_batches(X, **fn_kwargs)
+
+ def _get_feature_names_out(
+ self,
+ input_features,
+ n_features_out=0,
+ useful_feature_inds=None,
+ one_to_one_features=True,
+ ):
+ """
+ Builds the output feature names after transformation.
+
+ Args:
+ input_features (list): The input feature names.
+ n_features_out (int, optional): Number of output features when the
+ transform is not one-to-one. Defaults to 0.
+ useful_feature_inds (list, optional): Indices of the input features kept
+ in the output. Defaults to None, meaning all input features.
+ one_to_one_features (bool, optional): Whether each output feature maps to
+ exactly one input feature. Defaults to True.
+
+ Returns:
+ list: The list of output feature names.
+ """
+ if input_features is None:
+ raise ValueError(
+ "Input features must be provided to generate output features"
+ )
+ if useful_feature_inds is None:
+ useful_feature_inds = range(len(input_features))
+ out_features = [input_features[i] for i in useful_feature_inds]
+ if one_to_one_features:
+ if self.config.features_out_prefix:
+ out_features = [
+ f"{self.config.features_out_prefix}{i}" for i in out_features
+ ]
+ if self.config.features_out_suffix:
+ out_features = [
+ f"{i}{self.config.features_out_suffix}" for i in out_features
+ ]
+ return out_features
+ else:
+ _out_features = _generate_get_feature_names_out(self, n_features_out)
+ if self.config.features_out_prefix:
+ _out_features = [
+ f"{self.config.features_out_prefix}{col}"
+ for col in _out_features
+ ]
+ if self.config.features_out_suffix:
+ _out_features = [
+ f"{col}{self.config.features_out_suffix}"
+ for col in _out_features
+ ]
+ out_features.extend(_out_features)
+ return out_features
+
+ ## Functions for Datasets and IterableDatasets type ##
+
+ def _convert_map_kwargs_to_ray_kwargs(
+ self,
+ map_kwargs: dict,
+ batch_format=None,
+ is_fit=True,
+ ):
+ ray_kwargs = None
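+ # For fitting, the processor class itself is passed as a callable-class UDF
+ # (constructed once per worker via `fn_constructor_args`); for transforms,
+ # the bound `_process_batches` method is mapped directly.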
+ if is_fit:
+ ray_kwargs = {
+ "fn": self.__class__,
+ "num_cpus": map_kwargs.get("num_proc", None),
+ "num_gpus": map_kwargs.get("num_gpus", None),
+ "batch_format": batch_format
+ if batch_format in ["numpy", "pandas"]
+ else None,
+ "batch_size": map_kwargs.get("batch_size", None),
+ "concurrency": 1,
+ "fn_kwargs": map_kwargs["fn_kwargs"],
+ "fn_constructor_args": (self.config,),
+ "zero_copy_batch": True,
+ }
+ else:
+ ray_kwargs = {
+ "fn": self._process_batches,
+ "num_cpus": map_kwargs.get("num_proc", None),
+ "num_gpus": map_kwargs.get("num_gpus", None),
+ "batch_format": batch_format
+ if batch_format in ["numpy", "pandas"]
+ else None,
+ "batch_size": map_kwargs.get("batch_size", None),
+ "fn_kwargs": map_kwargs["fn_kwargs"],
+ "fn_constructor_args": (self.config,),
+ "zero_copy_batch": False,
+ }
+
+ return ray_kwargs
+
+ def _dummy_func(self, X, *args, **kwargs):
+ """for creating fingerprint and returning the input data as is when no function is provided."""
+ return X
+
+ def _process_batches(self, X, *fn_args, **fn_kwargs):
+ func = fn_kwargs.get("fn", None)
+ if fn_kwargs["func_type"] == "_fit":
+ input, _fn_args, _fn_kwargs = self._process_fit_batch_input(
+ X, *fn_args, **fn_kwargs
+ )
+ else:
+ input, _fn_args, _fn_kwargs = self._process_transform_batch_input(
+ X, *fn_args, **fn_kwargs
+ )
+
+ out = func(input, *_fn_args, **_fn_kwargs)
+
+ if isinstance(out, BaseProcessor):
+ return self._process_fit_batch_output(out)
+ return self._process_transform_batch_output(X, out, **fn_kwargs)
+
+ def _validate_fit_func_args(self, func, *fn_args):
+ required_args = get_required_args(func)
+ noun_plural = "argument" if len(required_args) == 1 else "arguments"
+ verb_tense = "was" if len(fn_args) == 0 else "were"
+ assert len(required_args) == (len(fn_args) + 1), (
+ f"`{self.__class__.__name__}.fit` requires {len(required_args)} "
+ f"{noun_plural} {tuple(required_args)}, "
+ f"but only {len(fn_args) + 1} {verb_tense} "
+ "provided. Either provide the missing arguments or provide the input "
+ f"columns found in '{required_args[0]}', if applicable."
+ )
+
+ def _process_fit_batch_input(self, X, *fn_args, **fn_kwargs):
+ func = fn_kwargs.get("fn", None)
+ assert X is not None, "No input data provided."
+ assert func is not None, "No function provided for processing the data."
+ in_format_kwargs = fn_kwargs.get("in_format_kwargs", {})
+ selected_indices = fn_kwargs.get("selected_indices", [])
+ in_format_kwargs["input_columns"] = selected_indices
+ input = DataHandler.to_format(X, **in_format_kwargs)
+ _fn_kwargs = get_kwargs(fn_kwargs, func)
+ extra_indices = fn_kwargs.get("extra_indices", [])
+ extra_untouched_inputs = fn_kwargs.get("extra_untouched_inputs", [])
+ fn_args = list(fn_args)
+ if extra_untouched_inputs:
+ for arg in extra_untouched_inputs:
+ in_format_kwargs["input_columns"] = None
+ arg = DataHandler.to_format(arg, **in_format_kwargs)
+ fn_args.append(arg)
+ elif extra_indices is not None:
+ for cols in extra_indices:
+ if cols is not None and len(cols):
+ in_format_kwargs["input_columns"] = cols
+ arg = DataHandler.to_format(X, **in_format_kwargs)
+ fn_args.append(arg)
+
+ self._validate_fit_func_args(func, *fn_args)
+ return input, tuple(fn_args), _fn_kwargs
+
+ def _process_fit_batch_output(self, out):
+ return out
+
+ def _process_transform_batch_input(self, X, *fn_args, **fn_kwargs):
+ assert X is not None, "No input data was provided for processing."
+ func = fn_kwargs.get("fn", None)
+ in_format_kwargs = fn_kwargs.get("in_format_kwargs", {})
+ selected_indices = fn_kwargs.get("selected_indices", [])
+ in_format_kwargs["input_columns"] = selected_indices
+ input = DataHandler.to_format(X, **in_format_kwargs)
+ _fn_kwargs = get_kwargs(fn_kwargs, func)
+ extra_indices = fn_kwargs.get("extra_indices", [])
+ if extra_indices:
+ fn_args = list(fn_args)
+ for i, cols in enumerate(extra_indices):
+ if cols is not None and len(cols):
+ in_format_kwargs["input_columns"] = cols
+ extra_arg = DataHandler.to_format(X, **in_format_kwargs)
+ fn_args.insert(i, extra_arg)
+ return input, fn_args, _fn_kwargs
+
+ def _process_transform_batch_output(self, input, output, **fn_kwargs):
+ selected_indices = fn_kwargs.get("selected_indices", None)
+ unused_indices = fn_kwargs.get("unused_indices", None)
+ keep_unused_columns = fn_kwargs.get("keep_unused_columns", None)
+ feature_names = fn_kwargs.get("feature_names", None)
+ out_dims = DataHandler.get_shape(output)
+ if len(out_dims) == 1:
+ out_dims = (out_dims[0], 1)
+ output = DataHandler.to_frame(output, "__output__")
+
+ out_format_kwargs = fn_kwargs.get("out_format_kwargs", {})
+ output = DataHandler.to_format(output, **out_format_kwargs)
+ if keep_unused_columns:
+ output = self._reinsert_columns(
+ input,
+ output,
+ selected_indices,
+ unused_indices,
+ one_to_one_features=self.config.one_to_one_features,
+ )
+ if feature_names is not None and len(feature_names) > 0:
+ output = DataHandler.set_column_names(output, feature_names)
+ return output
+
+ def _process_fit_input(self, input, **kwargs):
+ return input, kwargs
+
+ def _process_fit_output(self, input, out):
+ return out
+
+ def __repr__(self):
+ from sklearn.utils._pprint import _EstimatorPrettyPrinter
+
+ N_MAX_ELEMENTS_TO_SHOW = 30
+ pp = _EstimatorPrettyPrinter(
+ compact=True,
+ indent=1,
+ indent_at_name=True,
+ n_max_elements_to_show=N_MAX_ELEMENTS_TO_SHOW,
+ )
+ repr_ = self.__class__.__name__ + pp.pformat(self.config).removeprefix(
+ self.config.__class__.__name__
+ )
+ return repr_
+
+ def _repr_mimebundle_(self, **kwargs):
+ """Mime bundle used by jupyter kernels to display estimator"""
+ from sklearn._config import get_config
+ from sklearn.utils._estimator_html_repr import estimator_html_repr
+
+ output = {"text/plain": repr(self)}
+ if get_config()["display"] == "diagram":
+ output["text/html"] = estimator_html_repr(self)
+ return output
+
+ def shard(self, X, num_shards, index, contiguous):
+ if not 0 <= index < num_shards:
+ raise ValueError("index should be in [0, num_shards-1]")
+ num_rows = DataHandler.get_shape(X)[0]
+ if contiguous:
+ div = num_rows // num_shards
+ mod = num_rows % num_shards
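+ # e.g. num_rows=10, num_shards=3 -> contiguous shards cover rows 0-3, 4-6
+ # and 7-9 (sizes 4, 3, 3); the first `mod` shards each receive one extra row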
+ start = div * index + min(index, mod)
+ end = start + div + (1 if index < mod else 0)
+ indices = range(start, end)
+ else:
+ indices = np.arange(index, num_rows, num_shards)
+
+ return DataHandler.select_rows(X, indices)
+
+ def _pool_fit(self, fitted_processors):
+ """Pool the results of the map function."""
+ out = fitted_processors[0]
+ if len(fitted_processors) > 1:
+ raise NotImplementedError(
+ f"Pooling results from multiple processes is not supported for {self.__class__.__name__}"
+ )
+ return out
+
+ def map(
+ self,
+ X,
+ function: Optional[Callable] = None,
+ with_indices: bool = False,
+ with_rank: bool = False,
+ batched: bool = False,
+ batch_size: Optional[int] = 1000,
+ drop_last_batch: bool = False,
+ cache_output: bool = True,
+ input_columns: Optional[list] = None,
+ fn_kwargs: Optional[dict] = None,
+ num_proc: Optional[int] = None,
+ desc: Optional[str] = None,
+ ):
+ if batch_size is None:
+ batch_size = DataHandler.get_shape(X)[0]
+ num_shards = num_proc if num_proc is not None else 1
+ if batched and drop_last_batch:
+ pbar_total = (
+ DataHandler.get_shape(X)[0]
+ // num_shards
+ // batch_size
+ * num_shards
+ * batch_size
+ )
+ else:
+ pbar_total = DataHandler.get_shape(X)[0]
+
+ if num_proc:
+ if is_bioset(X) or is_dataset(X):
+ shards = [
+ X.shard(
+ num_shards=num_shards,
+ index=rank,
+ contiguous=True,
+ keep_in_memory=not cache_output,
+ )
+ for rank in range(num_proc)
+ ]
+ else:
+ shards = [
+ self.shard(
+ X,
+ num_shards=num_proc,
+ index=rank,
+ contiguous=True,
+ )
+ for rank in range(num_proc)
+ ]
+ else:
+ shards = [X]
+
+ dataset_kwargs = {
+ "shard": X,
+ "function": function,
+ "with_indices": with_indices,
+ "with_rank": with_rank,
+ "batched": batched,
+ "batch_size": batch_size,
+ "drop_last_batch": drop_last_batch,
+ "input_columns": input_columns,
+ "fn_kwargs": fn_kwargs,
+ }
+
+ kwargs_per_job = [
+ {
+ **dataset_kwargs,
+ "shard": shards[rank],
+ "rank": rank,
+ "offset": sum(len(s) for s in shards[:rank]),
+ }
+ for rank in range(num_shards)
+ ]
+
+ if len(kwargs_per_job) < num_shards:
+ logger.info(
+ f"Reprocessing {len(kwargs_per_job)}/{num_shards} shards because some of them were missing from the cache."
+ )
+
+ processed_data = [None] * num_shards
+ shards_done = 0
+ if num_proc is not None and num_proc > 1:
+ with Pool(len(kwargs_per_job)) as pool:
+ logger.info(f"Spawning {num_proc} processes")
+ if is_datasets_available():
+ from datasets.utils import tqdm
+ else:
+ from tqdm.auto import tqdm
+ with tqdm(
+ unit=" examples",
+ total=pbar_total,
+ desc=(desc or "Map") + f" (num_proc={num_proc})",
+ ) as pbar:
+ for rank, done, content in iflatmap_unordered(
+ pool,
+ BaseProcessor._map_single,
+ kwargs_iterable=kwargs_per_job,
+ ):
+ if done:
+ shards_done += 1
+ logger.debug(
+ f"Finished processing shard number {rank} of {num_shards}."
+ )
+ processed_data[rank] = content
+ else:
+ pbar.update(content)
+ else:
+ processed_data = None
+ if is_datasets_available():
+ from datasets.utils import tqdm
+ else:
+ from tqdm.auto import tqdm
+
+ with tqdm(
+ unit=" examples",
+ total=pbar_total,
+ desc=desc or "Map",
+ ) as pbar:
+ for rank, done, content in BaseProcessor._map_single(**dataset_kwargs):
+ if done:
+ shards_done += 1
+ logger.debug(
+ f"Finished processing shard number {rank} of {num_shards}."
+ )
+ processed_data = content
+ else:
+ pbar.update(content)
+ assert processed_data is not None, "Failed to retrieve the result from map"
+
+ return [processed_data]
+ for kwargs in kwargs_per_job:
+ del kwargs["shard"]
+ return processed_data
+
+ @staticmethod
+ def _map_single(
+ shard: Any,
+ function: Optional[Callable] = None,
+ with_indices: bool = False,
+ with_rank: bool = False,
+ input_columns: Optional[List[str]] = None,
+ batched: bool = False,
+ batch_size: Optional[int] = 1000,
+ drop_last_batch: bool = False,
+ fn_kwargs: Optional[dict] = None,
+ rank: Optional[int] = None,
+ offset: int = 0,
+ ) -> Iterable[Tuple[int, bool, Union[int, Any]]]:
+ """Apply a function to all the elements in the table (individually or in batches)
+ and update the table (if function does update examples).
+
+ Args:
+ shard (`Bioset` or `datasets.Dataset`): dataset shard to map the transform on.
+ function (`Callable`): with one of the following signature:
+ - `function(example: Dict[str, Any]) -> Dict[str, Any]` if `batched=False` and `with_indices=False` and `with_rank=False`
+ - `function(example: Dict[str, Any], *extra_args) -> Dict[str, Any]` if `batched=False` and `with_indices=True` and/or `with_rank=True` (one extra arg for each)
+ - `function(batch: Dict[str, List]) -> Dict[str, List]` if `batched=True` and `with_indices=False` and `with_rank=False`
+ - `function(batch: Dict[str, List], *extra_args) -> Dict[str, List]` if `batched=True` and `with_indices=True` and/or `with_rank=True` (one extra arg for each)
+
+ For advanced usage, the function can also return a `pyarrow.Table`.
+ Moreover if your function returns nothing (`None`), then `map` will run your function and return the dataset unchanged.
+ If no function is provided, default to identity function: lambda x: x
+ with_indices (`bool`, defaults to `False`): Provide example indices to `function`. Note that in this case the signature of `function` should be `def function(example, idx[, rank]): ...`.
+ with_rank (`bool`, default `False`): Provide process rank to `function`. Note that in this case the signature of `function` should be `def function(example[, idx], rank): ...`.
+ input_columns (`Optional[List[str]]`, defaults to `None`): The columns to be passed into `function` as
+ positional arguments. If `None`, a dict mapping to all formatted columns is passed as one argument.
+ batched (`bool`, defaults to `False`): Provide batch of examples to `function`
+ batch_size (`int`, optional, defaults to `1000`): Number of examples per batch provided to `function` if `batched=True`
+ `batch_size <= 0` or `batch_size == None`: Provide the full dataset as a single batch to `function`
+ drop_last_batch (`bool`, default: `False`): Whether a last batch smaller than the batch_size should be
+ dropped instead of being processed by the function.
+ fn_kwargs (`Dict`, optional, defaults to `None`): Keyword arguments to be passed to `function`
+ rank: (`int`, optional, defaults to `None`): If specified, this is the process rank when doing multiprocessing
+ offset: (`int`, defaults to 0): If specified, this is an offset applied to the indices passed to `function` if `with_indices=True`.
+ """
+ if fn_kwargs is None:
+ fn_kwargs = {}
+
+ # If we do batch computation but no batch size is provided, default to the full dataset
+ if batched and (batch_size is None or batch_size <= 0):
+ batch_size = DataHandler.get_shape(shard)[0]
+
+ # We set this variable to True after processing the first example/batch in
+ # `apply_function_on_filtered_inputs` if the map function returns a dict.
+ # If set to False, no new arrow table will be created
+
+ update_data = None
+ input_formatter = None
+ if is_bioset(shard) or is_dataset(shard):
+ from datasets.arrow_dataset import get_formatter
+
+ format_kwargs = shard._format_kwargs.copy()
+ # Lazy formatting is only available for the default format (None/python)
+ if not input_columns and shard._format_type is None:
+ format_kwargs["lazy"] = True
+ input_formatter = get_formatter(
+ shard._format_type,
+ features=shard._info.features,
+ **format_kwargs,
+ )
+
+ def apply_function_on_filtered_inputs(data, indices, offset=0):
+ """Utility to apply the function on a selection of columns."""
+ nonlocal update_data
+ if (
+ isinstance(data, pa.Table)
+ and input_formatter
+ and is_datasets_available()
+ ):
+ from datasets.arrow_dataset import format_table
+
+ inputs = format_table(
+ data,
+ 0 if not batched else range(data.num_rows),
+ format_columns=input_columns,
+ formatter=input_formatter,
+ )
+ else:
+ inputs = data
+ fn_args = (
+ [inputs]
+ if input_columns is None
+ else [DataHandler.select_column(inputs, col) for col in input_columns]
+ )
+ if offset == 0:
+ effective_indices = indices
+ else:
+ effective_indices = (
+ [i + offset for i in indices]
+ if isinstance(indices, list)
+ else indices + offset
+ )
+ additional_args = ()
+ if with_indices:
+ additional_args += (effective_indices,)
+ if with_rank:
+ additional_args += (rank,)
+ processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
+ return processed_inputs
+
+ num_examples_progress_update = 0
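+ # This generator yields (rank, done, content) tuples: progress updates are
+ # yielded with done=False and a running example count, and the final result
+ # (a fitted processor or the collected batch outputs) is yielded with done=True.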
+ # Optionally initialize the writer as a context manager
+ try:
+ if is_bioset(shard) or is_dataset(shard, iterable=False):
+ arrow_formatted_shard = shard.with_format("arrow")
+
+ if not batched:
+ shard_iterable = enumerate(arrow_formatted_shard)
+ else:
+ num_rows = (
+ DataHandler.get_shape(shard)[0]
+ if not drop_last_batch
+ else DataHandler.get_shape(shard)[0] // batch_size * batch_size
+ )
+ shard_iterable = zip(
+ range(0, num_rows, batch_size),
+ arrow_formatted_shard.iter(
+ batch_size, drop_last_batch=drop_last_batch
+ ),
+ )
+ else:
+ if not batched:
+ shard_iterable = enumerate(shard)
+ else:
+ num_rows = (
+ DataHandler.get_shape(shard)[0]
+ if not drop_last_batch
+ else DataHandler.get_shape(shard)[0] // batch_size * batch_size
+ )
+ shard_iterable = zip(
+ range(0, num_rows, batch_size),
+ DataHandler.iter(
+ shard, batch_size, drop_last_batch=drop_last_batch
+ ),
+ )
+ processors = None
+ if not batched:
+ _time = time.time()
+ for i, example in shard_iterable:
+ processor = apply_function_on_filtered_inputs(
+ example, i, offset=offset
+ )
+ if isinstance(processor, BaseProcessor):
+ processors = processor
+ else:
+ processors = processors or []
+ processors.append(processor)
+
+ num_examples_progress_update += 1
+ if time.time() > _time + biofit.config.PBAR_REFRESH_TIME_INTERVAL:
+ _time = time.time()
+ yield rank, False, num_examples_progress_update
+ num_examples_progress_update = 0
+ else:
+ _time = time.time()
+ for i, batch in shard_iterable:
+ num_examples_in_batch = DataHandler.get_shape(batch)[0]
+ indices = list(
+ range(
+ *(
+ slice(i, i + batch_size).indices(
+ DataHandler.get_shape(shard)[0]
+ )
+ )
+ )
+ ) # Something simpler?
+ processor = apply_function_on_filtered_inputs(
+ batch,
+ indices,
+ offset=offset,
+ )
+ if isinstance(processor, BaseProcessor):
+ processors = processor
+ else:
+ processors = processors or []
+ processors.append(processor)
+ num_examples_progress_update += num_examples_in_batch
+ if time.time() > _time + biofit.config.PBAR_REFRESH_TIME_INTERVAL:
+ _time = time.time()
+ yield rank, False, num_examples_progress_update
+ num_examples_progress_update = 0
+
+ except (Exception, KeyboardInterrupt):
+ yield rank, False, num_examples_progress_update
+ raise
+
+ yield rank, False, num_examples_progress_update
+
+ if isinstance(processors, BaseProcessor):
+ yield rank, True, processors
+ elif isinstance(processors, list):
+ if len(processors) > 0:
+ yield rank, True, DataHandler.concat(processors)
+ else:
+ yield rank, True, None
+ else:
+ yield rank, True, processors
+
+ def cleanup_cache_files(self, X=None, cache_dir=None, cache_file_name=None) -> int:
+        """Clean up the cache files generated by the processor.
+
+        Returns:
+            int: The number of cache files that were removed.
+        """
+ count = 0
+ cache_files = []
+ if not self.cache_files and X is not None:
+ data_fingerprint = getattr(
+ X, "_fingerprint", None
+ ) or fingerprint_from_data(X)
+
+ if cache_dir is not None:
+ cache_dir = os.path.join(expand_path(str(cache_dir)), "processors")
+ cache_dir = generate_cache_dir(X, data_fingerprint, cache_dir=cache_dir)
+ if cache_dir:
+ cache_files = [
+ {
+ "filename": get_cache_file_name(
+ cache_dir, self.fingerprint, cache_file_name
+ )
+ }
+ ]
+ else:
+ cache_files = self.cache_files
+
+ for cache_file in cache_files:
+ if os.path.exists(cache_file["filename"]):
+ os.remove(cache_file["filename"])
+ count += 1
+ return count
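+
+
+# The shard-mapping generator above streams progress as (rank, False, n_examples)
+# tuples and finishes with a single (rank, True, result) tuple, where the result
+# is either a fitted processor or the concatenated transform output. A minimal,
+# hypothetical consumer sketch (names are illustrative, not part of the API):
+#
+#     progress = 0
+#     results = {}
+#     for rank, done, content in shard_generator:
+#         if done:
+#             results[rank] = content  # fitted processor or concatenated output
+#         else:
+#             progress += content      # content is an example count for progress bars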
diff --git a/src/biofit/stat/__init__.py b/src/biofit/stat/__init__.py
new file mode 100644
index 0000000..e71b198
--- /dev/null
+++ b/src/biofit/stat/__init__.py
@@ -0,0 +1,9 @@
+# ruff: noqa
+from .col_sum import *
+from .col_mean import *
+from .col_missingness import *
+from .row_sum import *
+from .row_mean import *
+from .row_missingness import *
+from .distance import *
+from .correlation import *
diff --git a/src/biofit/stat/col_mean/__init__.py b/src/biofit/stat/col_mean/__init__.py
new file mode 100644
index 0000000..7e79fb9
--- /dev/null
+++ b/src/biofit/stat/col_mean/__init__.py
@@ -0,0 +1,6 @@
+# ruff: noqa
+from .col_mean import (
+ ColumnMeanStat,
+ ColumnMeanStatConfig,
+ ColumnMeanStatConfigForOTU,
+)
diff --git a/src/biofit/stat/col_mean/col_mean.py b/src/biofit/stat/col_mean/col_mean.py
new file mode 100644
index 0000000..bca2964
--- /dev/null
+++ b/src/biofit/stat/col_mean/col_mean.py
@@ -0,0 +1,208 @@
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Type
+
+import numpy as np
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biofit.utils import logging
+
+from ..stat import Stat, StatConfig
+
+if TYPE_CHECKING:
+ import polars as pl
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class ColumnMeanStatConfig(StatConfig):
+ # process description
+ _transform_process_desc: str = field(
+ default="Calculating column means", init=False, repr=False
+ )
+ processor_name: str = field(default="col_mean", init=False, repr=False)
+
+
+@dataclass
+class ColumnMeanStatConfigForOTU(ColumnMeanStatConfig):
+    """Computes the mean of each column in the OTUAbundance feature.
+
+    It is provided for autostat.
+    """
+
+ # dataset specific attributes
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ dataset_name: str = field(default="otu", init=False, repr=False)
+
+
+class ColumnMeanStat(Stat):
+ # config attributes
+ _config_class = ColumnMeanStatConfig
+ config: ColumnMeanStatConfig
+
+ sums_ = None
+ counts_ = 0
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "ColumnMeanStat":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _fit_numpy(self, X: np.ndarray, y=None):
+ self.config.means = np.mean(X, axis=0)
+ return self
+
+ def _partial_fit_numpy(self, X: np.ndarray, y=None):
+ self.counts_ += X.shape[0]
+ if self.sums_ is None:
+ self.sums_ = np.sum(X, axis=0)
+ else:
+ self.sums_ += np.sum(X, axis=0)
+
+        # the final mean is computed in _pool_fit from the pooled sums and counts
+ return self
+
+ def _fit_polars(self, X: "pl.DataFrame"):
+ self.config.means = X.mean()
+ return self
+
+ def _partial_fit_polars(self, X: "pl.DataFrame"):
+ self.counts_ += X.shape[0]
+ if self.sums_ is None:
+ self.sums_ = X.sum()
+ else:
+ self.sums_ += X.sum()
+
+        # the final mean is computed in _pool_fit from the pooled sums and counts
+ return self
+
+ def _pool_fit(self, fitted_processors: List["ColumnMeanStat"]):
+ self.sums_ = sum([p.sums_ for p in fitted_processors])
+ self.counts_ = sum([p.counts_ for p in fitted_processors])
+ self.config.means = self.sums_ / self.counts_
+ return self
+
+ def _process_transform_output(self, output, input, *args, **kwargs):
+ return super()._process_transform_output(
+ self.config.means, input, *args, **kwargs
+ )
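+
+
+# Streaming means are computed by pooling per-batch sums and row counts:
+# mean = (sum_1 + ... + sum_k) / (n_1 + ... + n_k). A minimal numpy sketch of
+# the same pooling logic (made-up batches, for illustration only):
+#
+#     import numpy as np
+#
+#     batches = [np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([[5.0, 6.0]])]
+#     sums = sum(np.sum(b, axis=0) for b in batches)  # per-column sums
+#     counts = sum(b.shape[0] for b in batches)       # total number of rows
+#     means = sums / counts                           # array([3., 4.])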
diff --git a/src/biofit/stat/col_missingness/__init__.py b/src/biofit/stat/col_missingness/__init__.py
new file mode 100644
index 0000000..d6fad16
--- /dev/null
+++ b/src/biofit/stat/col_missingness/__init__.py
@@ -0,0 +1,8 @@
+# ruff: noqa
+from .col_missingness import (
+ ColumnMissingnessStat,
+ ColumnMissingnessStatConfig,
+ ColumnMissingnessStatConfigForOTU,
+ ColumnMissingnessStatConfigForSNP,
+ ColumnMissingnessStatConfigForMetagenomics,
+)
diff --git a/src/biofit/stat/col_missingness/col_missingness.py b/src/biofit/stat/col_missingness/col_missingness.py
new file mode 100644
index 0000000..0af8cbc
--- /dev/null
+++ b/src/biofit/stat/col_missingness/col_missingness.py
@@ -0,0 +1,321 @@
+import os
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Optional, Type, Union
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pyarrow.compute as pc
+from biocore import DataHandler
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.utils import logging
+
+from ..stat import Stat, StatConfig
+
+if TYPE_CHECKING:
+ import polars as pl
+
+logger = logging.get_logger(__name__)
+
+
+def is_multiprocess_mode():
+ return os.getppid() > 1
+
+
+def _filter_features_pandas(X: pd.DataFrame, depth: Optional[float] = None):
+ """
+    Count the number of missing values in each column of a pandas DataFrame.
+
+    Args:
+        X (pd.DataFrame): The input DataFrame containing the features.
+        depth (float, optional): If specified, values less than or equal to depth
+            are also counted as missing. Defaults to None.
+
+    Returns:
+        pd.DataFrame: A single-row DataFrame with the missing-value count per column.
+ """
+ total_missing = X.isnull().sum(axis=0)
+ if depth is not None:
+ total_missing += (X <= depth).sum(axis=0)
+ return total_missing.to_frame().transpose()
+
+
+def _filter_features_numpy(X: np.ndarray, depth: Optional[float] = None):
+    total_missing = np.sum(np.isnan(X), axis=0)
+ if depth is not None:
+ total_missing += np.sum(X <= depth, axis=0)
+ return total_missing[np.newaxis, :]
+
+
+def _filter_features_polars(X: "pl.DataFrame", depth: Optional[float] = None):
+    import polars  # a no-op if polars was already imported
+ total_missing = X.null_count().sum()
+ if depth is not None:
+ less_than_depth = X.with_columns(polars.col("*") <= depth).sum()
+ total_missing += less_than_depth
+
+ return total_missing
+
+
+@dataclass
+class ColumnMissingnessStatConfig(StatConfig):
+ # process description
+ transform_process_name: str = field(
+ default="Calculating the number of missing values", init=False, repr=False
+ )
+ processor_name: str = field(default="col_missingness", init=False, repr=False)
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+
+ # config attributes
+ depth: Optional[Union[float, int]] = None
+
+
+@dataclass
+class ColumnMissingnessStatConfigForOTU(ColumnMissingnessStatConfig):
+    """Computes the number of missing values in each column of the
+    OTUAbundance feature. Values at or below ``depth`` are counted as missing.
+    """
+
+ transform_process_name: str = field(
+ default="Calculating species richness", init=False, repr=False
+ )
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ dataset_name: str = field(default="otu", init=False, repr=False)
+
+ depth: Optional[Union[float, int]] = 0
+
+
+@dataclass
+class ColumnMissingnessStatConfigForSNP(ColumnMissingnessStatConfig):
+    """Computes the number of missing values in each column of the
+    GenomicVariant feature. Values at or below ``depth`` are counted as missing.
+    """
+
+ transform_process_name: str = field(
+ default="Calculating species richness", init=False, repr=False
+ )
+    _fit_input_feature_types: List[Type] = field(
+        default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False
+    )
+    _transform_input_feature_types: List[Type] = field(
+        default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False
+    )
+ dataset_name: str = field(default="snp", init=False, repr=False)
+
+ depth: Optional[Union[float, int]] = 0
+
+
+@dataclass
+class ColumnMissingnessStatConfigForMetagenomics(ColumnMissingnessStatConfig):
+    """Computes the number of missing values in each column of Abundance or
+    ReadCount features. Values at or below ``depth`` are counted as missing.
+    """
+
+ transform_process_name: str = field(
+ default="Calculating species richness", init=False, repr=False
+ )
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="metagenomics", init=False, repr=False)
+
+ depth: Optional[Union[float, int]] = 0
+
+
+class ColumnMissingnessStat(Stat):
+ # config attributes
+ _config_class = ColumnMissingnessStatConfig
+ config: ColumnMissingnessStatConfig
+ output_dtype = "int64"
+
+ def __init__(
+ self,
+ depth: Optional[Union[float, int]] = None,
+ config: Optional[ColumnMissingnessStatConfig] = None,
+ **kwargs,
+ ):
+ super().__init__(config=config, depth=depth, **kwargs)
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ return self
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "ColumnMissingnessStat":
+        self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+        return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _process_transform_input(self, X, **kwargs):
+ kwargs["batch_size"] = DataHandler.get_shape(X)[0]
+ return X, kwargs
+
+ def _transform_numpy(self, X: np.ndarray) -> np.ndarray:
+ return _filter_features_numpy(X, self.config.depth)
+
+ def _transform_pandas(self, X: pd.DataFrame) -> pd.DataFrame:
+ return _filter_features_pandas(X, self.config.depth)
+
+ def _transform_polars(self, X: "pl.DataFrame") -> "pl.DataFrame":
+ return _filter_features_polars(X, self.config.depth)
+
+ def _transform_arrow(self, X: pa.Table) -> pa.Table:
+ if self.config.depth is not None:
+ depth = self.config.depth
+ return pa.table(
+ {
+ k: [
+ pc.filter(
+                            # or_kleene: null | true evaluates to true, so null
+                            # entries are kept and counted
+                            v, pc.or_kleene(pc.less_equal(v, depth), pc.is_null(v))
+ ).length()
+ ]
+ for k, v in zip(X.column_names, X.columns)
+ }
+ )
+ else:
+ return pa.table(
+ {
+ k: [pc.filter(v, pc.is_null(v)).length()]
+ for k, v in zip(X.column_names, X.columns)
+ }
+ )
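+
+
+# With `depth` set, a value counts as missing when it is null or at most `depth`.
+# An illustrative example calling the private arrow helper directly (made-up
+# values; depth=0 mirrors the OTU config default):
+#
+#     import pyarrow as pa
+#
+#     X = pa.table({"taxon_a": [0, 5, None], "taxon_b": [1, 2, 3]})
+#     ColumnMissingnessStat(depth=0)._transform_arrow(X)
+#     # -> taxon_a: [2] (one zero and one null), taxon_b: [0]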
diff --git a/src/biofit/stat/col_sum/__init__.py b/src/biofit/stat/col_sum/__init__.py
new file mode 100644
index 0000000..75606c4
--- /dev/null
+++ b/src/biofit/stat/col_sum/__init__.py
@@ -0,0 +1,6 @@
+# ruff: noqa
+from .col_sum import (
+ ColumnSumStat,
+ ColumnSumStatConfig,
+ ColumnSumStatConfigForOTU,
+)
diff --git a/src/biofit/stat/col_sum/col_sum.py b/src/biofit/stat/col_sum/col_sum.py
new file mode 100644
index 0000000..f83e6b9
--- /dev/null
+++ b/src/biofit/stat/col_sum/col_sum.py
@@ -0,0 +1,277 @@
+"""
+This module provides classes for computing the sum of each column in input data.
+
+Classes:
+- ColumnSumStat: Computes the sum of each column in the input data.
+- ColumnSumStatConfig: Configuration for ColumnSumStat.
+- ColumnSumStatConfigForOTU: Configuration tailored to the OTUAbundance feature,
+  where each column sum is the total abundance of one OTU.
+
+These classes provide methods for fitting and transforming the input data using
+different libraries such as numpy and polars.
+"""
+
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Callable, List, Optional, Type
+
+import numpy as np
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biofit.utils import logging
+
+from ..stat import Stat, StatConfig
+
+if TYPE_CHECKING:
+ import polars as pl
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class ColumnSumStatConfig(StatConfig):
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _fit_process_desc: str = field(
+ default="Calculating column sums", init=False, repr=False
+ )
+ _transform_process_desc: str = field(
+ default="Calculating column sums", init=False, repr=False
+ )
+ processor_name: str = field(default="col_sum", init=False, repr=False)
+
+
+class ColumnSumStat(Stat):
+ """
+ Computes the sum of each column in the input data.
+ """
+
+ # config attributes
+ _config_class = ColumnSumStatConfig
+ config: ColumnSumStatConfig
+ output_dtype = "float64"
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "ColumnSumStat":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _fit_numpy(self, X: np.ndarray):
+ """
+ Fits the model to the input data using numpy.
+
+ Args:
+ X: Input data as a numpy array.
+
+ Returns:
+ Fitted model.
+ """
+ self.config.sums_ = np.sum(X, axis=0)
+ return self
+
+ def _partial_fit_numpy(self, X: np.ndarray):
+ """
+ Updates the fitted model with additional input data using numpy.
+
+ Args:
+ X: Additional input data as a numpy array.
+
+ Returns:
+ Updated fitted model.
+ """
+ if self.config.sums_ is None:
+ self.config.sums_ = np.sum(X, axis=0)
+ else:
+ self.config.sums_ += np.sum(X, axis=0)
+ return self
+
+ def _fit_polars(self, X: "pl.DataFrame"):
+ """
+ Fits the model to the input data using polars.
+
+ Args:
+ X: Input data as a polars DataFrame.
+
+ Returns:
+ Fitted model.
+ """
+ self.config.sums_ = X.sum()
+ return self
+
+ def _partial_fit_polars(self, X: "pl.DataFrame"):
+ """
+ Updates the column sums with additional input data using polars.
+
+ Args:
+ X: Additional input data as a polars DataFrame.
+
+ Returns:
+ Updated fitted model.
+ """
+ if self.config.sums_ is None:
+ self.config.sums_ = X.sum()
+ else:
+ self.config.sums_ += X.sum()
+ return self
+
+ def _pool_fit(self, fitted_processors: List["ColumnSumStat"]):
+ """
+ Pools the fitted models from different batches.
+
+ Args:
+ fitted_processors: List of fitted models.
+
+ Returns:
+ Pooled fitted model.
+ """
+ self.config.sums_ = sum(
+ [processor.config.sums_ for processor in fitted_processors]
+ )
+ return self
+
+ def _process_transform_output(self, output, input, *args, **kwargs):
+ return super()._process_transform_output(
+ self.config.sums_, input, *args, **kwargs
+ )
+
+
+@dataclass
+class ColumnSumStatConfigForOTU(ColumnSumStatConfig):
+ # process description
+ _transform_process_desc: str = field(
+        default="Calculating the total abundance of each OTU across samples",
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="otu", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
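+
+
+# A minimal usage sketch (made-up data; assumes the "numpy" output format
+# string is supported). The transform emits one row of per-column totals,
+# i.e. the total abundance of each OTU:
+#
+#     import numpy as np
+#
+#     X = np.array([[1.0, 0.0], [2.0, 3.0]])
+#     ColumnSumStat().fit_transform(X, output_format="numpy")
+#     # -> array([[3., 3.]])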
diff --git a/src/biofit/stat/correlation/__init__.py b/src/biofit/stat/correlation/__init__.py
new file mode 100644
index 0000000..d0cbfdf
--- /dev/null
+++ b/src/biofit/stat/correlation/__init__.py
@@ -0,0 +1,6 @@
+# ruff: noqa
+
+from .correlation import (
+ CorrelationStat,
+ CorrelationStatConfig,
+)
diff --git a/src/biofit/stat/correlation/correlation.py b/src/biofit/stat/correlation/correlation.py
new file mode 100644
index 0000000..4e15820
--- /dev/null
+++ b/src/biofit/stat/correlation/correlation.py
@@ -0,0 +1,235 @@
+"""
+Correlation calculation
+"""
+
+from dataclasses import dataclass, field
+
+from biocore import DataHandler
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import (
+ SelectedColumnTypes,
+ SelectedFeatureTypes,
+ sync_backup_config,
+)
+from biofit.utils import logging
+
+from ..stat import Stat, StatConfig
+
+logger = logging.get_logger(__name__)
+
+CORRELATION_STAT_DOCSTRING = """
+Compute the correlation of each feature with the target variable.
+
+Args:
+    X: Input data
+    y: Target variable
+    input_columns: Columns to compute the correlation for
+    target_column: Target column
+    **kwargs: Additional keyword arguments
+
+Returns:
+    The correlation of each input column with the target variable.
+"""
+
+
+@dataclass
+class CorrelationStatConfig(StatConfig):
+ _transform_input_feature_types: SelectedFeatureTypes = field(
+ default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: SelectedFeatureTypes = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None],
+ init=False,
+ repr=False,
+ )
+ processor_name: str = field(default="correlation", init=False, repr=False)
+
+ input_columns: str = None
+ target_column: str = None
+ method: str = "auto"
+
+ def __post_init__(self):
+ if self.method not in [
+ "auto",
+ "pearsonr",
+ "spearmanr",
+ "kendalltau",
+ "pointbiserialr",
+ ]:
+ raise ValueError(
+                f"Method {self.method} not supported. Supported methods are: auto, pearsonr, spearmanr, kendalltau, pointbiserialr"
+ )
+ if self.method != "auto":
+ self._transform_process_desc = (
+ f"Calculating {self.method} correlation with target variable"
+ )
+
+
+class CorrelationStat(Stat):
+ """
+ Correlation calculation based on the correlation with the target variable.
+ """
+
+    _config_class = CorrelationStatConfig
+ config: CorrelationStatConfig
+ output_dtype = "float64"
+ func = None
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ import scipy.stats
+
+ if self.config.method != "auto":
+ self.func = getattr(scipy.stats, self.config.method)
+ return self
+
+ def transform(
+ self,
+ X,
+ y,
+ input_columns: SelectedColumnTypes = None,
+ target_column: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(
+ input_columns, target_column
+ )
+ return self._process_transform(
+ X,
+ y,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ y=None,
+ input_columns: SelectedColumnTypes = None,
+ target_column: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ y,
+ input_columns=input_columns,
+ target_column=target_column,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _check_data(self, X, y):
+ if self.config.method == "pointbiserialr":
+            target_dtype = next(iter(DataHandler.get_dtypes(y).values()))
+            if not any(t in target_dtype for t in ["int", "bool", "str"]):
+                raise ValueError(
+                    f"Point biserial correlation can only be used with binary target variables. Data type: {target_dtype}"
+ )
+ n_unique = DataHandler.nunique(
+ y, DataHandler.get_column_names(y, generate_cols=True)[0]
+ )
+ if n_unique != 2:
+ raise ValueError(
+ f"Point biserial correlation can only be used with binary target variables. Number of unique values: {n_unique}"
+ )
+ return y, X # the first argument must be the binary target variable
+ return X, y
+
+ def _process_transform_input(self, X, **kwargs):
+ if self.config.method == "auto":
+ import scipy.stats
+
+ inds = kwargs["fn_kwargs"]["extra_indices"][0]
+ is_categorical = DataHandler.is_categorical(X, inds, threshold=0.05)
+ if is_categorical:
+ n_classes = DataHandler.nunique(X, inds)
+ if n_classes == 2:
+ self.config.method = "pointbiserialr"
+ kwargs["desc"] = (
+ "Calculating point biserial correlation with target variable"
+ )
+ else:
+ self.config.method = "pearsonr"
+ kwargs["desc"] = (
+ "Calculating pearson correlation with target variable"
+ )
+ else:
+ self.config.method = "pearsonr"
+ kwargs["desc"] = "Calculating pearson correlation with target variable"
+ self.func = getattr(scipy.stats, self.config.method)
+ return X, kwargs
+
+ def _transform_sklearn(self, X, y):
+ cols = DataHandler.get_column_names(X, generate_cols=True)
+ corrs = {}
+ for col in cols:
+ corr, _ = self.func(
+ *self._check_data(
+ DataHandler.select_column(X, col), DataHandler.select_column(y, 0)
+ )
+ )
+ corrs[str(col)] = [corr]
+ return corrs
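+
+
+# With method="auto", the target's type drives the choice: a binary categorical
+# target selects pointbiserialr and anything else falls back to pearsonr. A
+# sketch of the underlying scipy call for a binary target (made-up values):
+#
+#     import scipy.stats
+#
+#     x = [1.2, 3.4, 2.2, 4.1]  # one feature column
+#     y = [0, 1, 0, 1]          # binary target
+#     corr, _ = scipy.stats.pointbiserialr(y, x)  # binary variable goes first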
diff --git a/src/biofit/stat/distance/__init__.py b/src/biofit/stat/distance/__init__.py
new file mode 100644
index 0000000..8632b03
--- /dev/null
+++ b/src/biofit/stat/distance/__init__.py
@@ -0,0 +1,9 @@
+# ruff: noqa
+from .distance import (
+ DistanceStat,
+ DistanceStatConfig,
+ DistanceStatConfigForASV,
+ DistanceStatConfigForOTU,
+ DistanceStatConfigForMetagenomics,
+    DistanceStatConfigForReadCount,
+    DistanceStatConfigForSNP,
+)
diff --git a/src/biofit/stat/distance/distance.py b/src/biofit/stat/distance/distance.py
new file mode 100644
index 0000000..0c62e3b
--- /dev/null
+++ b/src/biofit/stat/distance/distance.py
@@ -0,0 +1,296 @@
+from dataclasses import dataclass, field
+from typing import List, Optional, Type
+
+import numpy as np
+from biocore import DataHandler
+from sklearn.metrics import DistanceMetric
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+
+from ..stat import Stat, StatConfig
+
+
+@dataclass
+class DistanceStatConfig(StatConfig):
+ """
+ A base class for distance metrics.
+
+ Inherits all attributes and methods from StatConfig.
+
+    This class is tailored to handle any distance metric supported by
+    scikit-learn's ``DistanceMetric``, such as euclidean, braycurtis, and jaccard.
+ """
+
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ processor_name: str = field(default="distance", init=False, repr=False)
+ metric: str = "euclidean"
+
+
+@dataclass
+class DistanceStatConfigForMetagenomics(DistanceStatConfig):
+ """
+ Calculates distance metrics specifically for metagenomics data.
+
+ Inherits all configs from DistanceStatConfig.
+
+ This class is tailored to handle metagenomics data, including OTU abundance,
+ ASV abundance, and read counts, using the 'braycurtis' metric by default.
+ """
+
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))],
+ init=False,
+ repr=False,
+ )
+ dataset_name: str = field(default="metagenomics", init=False, repr=False)
+ metric: str = "braycurtis"
+
+
+@dataclass
+class DistanceStatConfigForOTU(DistanceStatConfig):
+ """
+ A subclass for calculating distance metrics on OTU abundance data.
+
+ Inherits all configs from DistanceStatConfig, but is specifically
+ tailored for OTU abundance data, defaulting to the 'braycurtis' metric.
+ """
+
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ dataset_name: str = field(default="otu", init=False, repr=False)
+ metric: str = "braycurtis"
+
+
+@dataclass
+class DistanceStatConfigForASV(DistanceStatConfig):
+ """
+ A subclass for calculating distance metrics on ASV abundance data.
+
+    Inherits all configs from DistanceStatConfig, but is specifically
+ tailored for ASV abundance data, defaulting to the 'braycurtis' metric.
+ """
+
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ dataset_name: str = field(default="asv", init=False, repr=False)
+ metric: str = "braycurtis"
+
+
+@dataclass
+class DistanceStatConfigForReadCount(DistanceStatConfig):
+ """
+ A subclass for calculating distance metrics on read count data.
+
+    Inherits all configs from DistanceStatConfig, but is specifically
+ tailored for read count data, defaulting to the 'braycurtis' metric.
+ """
+
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("ReadCount")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("ReadCount")], init=False, repr=False
+ )
+ dataset_name: str = field(default="read_count", init=False, repr=False)
+ metric: str = "braycurtis"
+
+
+@dataclass
+class DistanceStatConfigForSNP(DistanceStatConfig):
+    """
+    A subclass for calculating distance metrics on SNP data.
+
+    Inherits all configs from DistanceStatConfig, but is specifically
+    tailored for genomic variant data, defaulting to the 'jaccard' metric.
+ """
+
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False
+ )
+ dataset_name: str = field(default="snp", init=False, repr=False)
+
+ metric: str = "jaccard"
+
+
+class DistanceStat(Stat):
+ """
+ A base class for calculating distance metrics on genomic data.
+
+ Attributes:
+        metric (str): The name of the distance metric to use, as accepted by
+            ``sklearn.metrics.DistanceMetric.get_metric``. Defaults to 'euclidean'.
+
+    Methods:
+        transform(X, ...): Calculates the pairwise distances between the samples in X.
+ """
+
+ _config_class = DistanceStatConfig
+ config: DistanceStatConfig
+ output_dtype = "float64"
+
+ def __init__(
+ self,
+ config: Optional[DistanceStatConfig] = None,
+ metric: Optional[str] = "euclidean",
+ **kwargs,
+ ):
+ super().__init__(config=config, metric=metric, **kwargs)
+ self.pdist = DistanceMetric.get_metric(self.config.metric).pairwise
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ self.pdist = DistanceMetric.get_metric(self.config.metric).pairwise
+ return self
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "DistanceStat":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self.config._n_features_out = DataHandler.get_shape(X)[0]
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _transform_numpy(self, X: np.ndarray):
+ return self.pdist(X, X)
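+
+
+# The transform yields an n_samples x n_samples matrix. A minimal sketch of the
+# same computation via scikit-learn directly (made-up data):
+#
+#     import numpy as np
+#     from sklearn.metrics import DistanceMetric
+#
+#     X = np.array([[1.0, 0.0], [0.0, 1.0]])
+#     DistanceMetric.get_metric("braycurtis").pairwise(X, X)
+#     # -> array([[0., 1.],
+#     #           [1., 0.]])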
diff --git a/src/biofit/stat/row_mean/__init__.py b/src/biofit/stat/row_mean/__init__.py
new file mode 100644
index 0000000..a0ec42e
--- /dev/null
+++ b/src/biofit/stat/row_mean/__init__.py
@@ -0,0 +1,6 @@
+# ruff: noqa
+from .row_mean import (
+ RowMeanStat,
+ RowMeanStatConfig,
+ RowMeanStatConfigForOTU,
+)
diff --git a/src/biofit/stat/row_mean/row_mean.py b/src/biofit/stat/row_mean/row_mean.py
new file mode 100644
index 0000000..a6eecc2
--- /dev/null
+++ b/src/biofit/stat/row_mean/row_mean.py
@@ -0,0 +1,178 @@
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Type
+
+import numpy as np
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biofit.utils import logging
+
+from ..stat import Stat, StatConfig
+
+if TYPE_CHECKING:
+ import polars as pl
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class RowMeanStatConfig(StatConfig):
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ # process description
+ _transform_process_desc: str = field(
+ default="Calculating row means", init=False, repr=False
+ )
+ processor_name: str = field(default="row_mean", init=False, repr=False)
+ _n_features_out: int = field(default=1, init=False, repr=False)
+
+
+@dataclass
+class RowMeanStatConfigForOTU(RowMeanStatConfig):
+    """Computes the mean of each row in the OTUAbundance feature.
+
+    It is provided for autostat.
+    """
+
+ # dataset specific attributes
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ dataset_name: str = field(default="otu", init=False, repr=False)
+
+
+class RowMeanStat(Stat):
+ # config attributes
+ _config_class = RowMeanStatConfig
+ config: RowMeanStatConfig
+ output_dtype = "float64"
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "RowMeanStat":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _transform_numpy(self, X: np.ndarray, y=None):
+        if X.ndim == 1:
+            # treat a 1-D array as a single feature column
+            X = X.reshape(-1, 1)
+        return np.mean(X, axis=1)[:, None]
+
+ def _transform_polars(self, X: "pl.DataFrame", y=None):
+ return X.mean_horizontal().to_frame()
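+
+
+# Each row collapses to its mean, so the output has a single column. An
+# illustrative call to the private numpy helper (made-up values):
+#
+#     import numpy as np
+#
+#     RowMeanStat()._transform_numpy(np.array([[1.0, 3.0], [2.0, 4.0]]))
+#     # -> array([[2.], [3.]])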
diff --git a/src/biofit/stat/row_missingness/__init__.py b/src/biofit/stat/row_missingness/__init__.py
new file mode 100644
index 0000000..ce074fd
--- /dev/null
+++ b/src/biofit/stat/row_missingness/__init__.py
@@ -0,0 +1,8 @@
+# ruff: noqa
+from .row_missingness import (
+ RowMissingnessStat,
+ RowMissingnessStatConfig,
+ RowMissingnessStatConfigForOTU,
+ RowMissingnessStatConfigForSNP,
+ RowMissingnessStatConfigForMetagenomics,
+)
diff --git a/src/biofit/stat/row_missingness/row_missingness.py b/src/biofit/stat/row_missingness/row_missingness.py
new file mode 100644
index 0000000..8d93818
--- /dev/null
+++ b/src/biofit/stat/row_missingness/row_missingness.py
@@ -0,0 +1,296 @@
+import os
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Optional, Type
+
+import numpy as np
+import pandas as pd
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biofit.utils import logging
+
+from ..stat import Stat, StatConfig
+
+if TYPE_CHECKING:
+ import polars as pl
+
+logger = logging.get_logger(__name__)
+
+
+def is_multiprocess_mode():
+ return os.getppid() > 1
+
+
+def _filter_samples_pandas(X: pd.DataFrame, depth: Optional[float] = None):
+ """
+    Count the number of missing values in each row of a pandas DataFrame.
+
+    Args:
+        X (pd.DataFrame): The input DataFrame containing the samples.
+        depth (float, optional): If specified, numeric values less than or equal
+            to depth are also counted as missing. Defaults to None.
+
+    Returns:
+        pd.DataFrame: A single-column DataFrame with the missing-value count per row.
+ """
+ total_missing = X.isnull().sum(axis=1)
+ if depth is not None:
+ numeric_rows = X.select_dtypes(include=np.number).index
+ total_missing[numeric_rows] += (X.loc[numeric_rows] <= depth).sum(axis=1)
+ return total_missing.to_frame()
+
+
+def _filter_samples_numpy(X: np.ndarray, depth: Optional[float] = None):
+    total_missing = np.sum(np.isnan(X), axis=1)
+ if depth is not None:
+ total_missing += np.sum(X <= depth, axis=1)
+ return total_missing[:, np.newaxis]
+
+
+def _filter_samples_polars(X: "pl.DataFrame", depth: Optional[float] = None):
+    import polars  # a no-op if polars was already imported
+ total_missing = X.with_columns(
+ polars.col("*").is_null() | polars.col("*").is_nan()
+ ).sum_horizontal()
+ if depth is not None:
+ total_missing += X.with_columns(polars.col("*") <= depth).sum_horizontal()
+ return total_missing.to_frame()
+
+
+@dataclass
+class RowMissingnessStatConfig(StatConfig):
+ # process description
+ _transform_process_desc: str = field(
+ default="Calculating the number of missing values for each sample",
+ init=False,
+ repr=False,
+ )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ processor_name: str = field(default="row_missingness", init=False, repr=False)
+ _n_features_out: int = field(default=1, init=False, repr=False)
+
+ # config attributes
+ depth: Optional[float] = None
+
+
+@dataclass
+class RowMissingnessStatConfigForOTU(RowMissingnessStatConfig):
+    """Computes the number of missing values in each row of the OTUAbundance
+    feature. It is provided for autostat.
+    """
+
+ # dataset specific attributes
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ transform_process_name: str = field(
+ default="Calculating sample richness", init=False, repr=False
+ )
+ dataset_name: str = field(default="otu", init=False, repr=False)
+
+ # config attributes
+    depth: Optional[float] = 100
+
+
+@dataclass
+class RowMissingnessStatConfigForSNP(RowMissingnessStatConfig):
+    """Computes the number of missing values in each row of the GenomicVariant
+    feature. It is provided for autostat.
+    """
+
+ # dataset specific attributes
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("GenomicVariant")], init=False, repr=False
+ )
+ transform_process_name: str = field(
+ default="Calculating sample richness", init=False, repr=False
+ )
+ dataset_name: str = field(default="snp", init=False, repr=False)
+
+ # config attributes
+    depth: Optional[float] = 100
+
+
+@dataclass
+class RowMissingnessStatConfigForMetagenomics(RowMissingnessStatConfig):
+    """Computes the number of missing values in each row of Abundance or
+    ReadCount features. It is provided for autostat.
+    """
+
+ # dataset specific attributes
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [(get_feature("Abundance"), get_feature("ReadCount"))],
+ init=False,
+ repr=False,
+ )
+ transform_process_name: str = field(
+ default="Calculating sample richness", init=False, repr=False
+    )
+    dataset_name: str = field(default="metagenomics", init=False, repr=False)
+
+ # config attributes
+ depth: int = 100
+
+
+class RowMissingnessStat(Stat):
+ # feature attributes
+ one_to_one_features = False
+ n_features_out = 1
+
+ # config attributes
+ _config_class = RowMissingnessStatConfig
+ config: RowMissingnessStatConfig
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "RowMissingnessStat":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _transform_numpy(self, X: np.ndarray) -> np.ndarray:
+ return _filter_samples_numpy(X, self.config.depth)
+
+ def _transform_pandas(self, X: pd.DataFrame) -> pd.DataFrame:
+ return _filter_samples_pandas(X, self.config.depth)
+
+ def _transform_polars(self, X: "pl.DataFrame") -> "pl.DataFrame":
+ return _filter_samples_polars(X, self.config.depth)
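+
+
+# With the default depth=None only nulls/NaNs are counted; with a depth, values
+# at or below it are counted as well. An illustrative call to the private numpy
+# helper (made-up values):
+#
+#     import numpy as np
+#
+#     X = np.array([[np.nan, 5.0], [1.0, 2.0]])
+#     RowMissingnessStat()._transform_numpy(X)
+#     # -> array([[1], [0]])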
diff --git a/src/biofit/stat/row_sum/__init__.py b/src/biofit/stat/row_sum/__init__.py
new file mode 100644
index 0000000..d9e93c0
--- /dev/null
+++ b/src/biofit/stat/row_sum/__init__.py
@@ -0,0 +1,6 @@
+# ruff: noqa
+from .row_sum import (
+ RowSumStat,
+ RowSumStatConfig,
+ RowSumStatConfigForOTU,
+)
diff --git a/src/biofit/stat/row_sum/row_sum.py b/src/biofit/stat/row_sum/row_sum.py
new file mode 100644
index 0000000..2d55eda
--- /dev/null
+++ b/src/biofit/stat/row_sum/row_sum.py
@@ -0,0 +1,150 @@
+"""
+This module provides classes for computing the sum of rows and columns in input data.
+
+Classes:
+- RowSumStat: Computes the sum of each row in the input data.
+- ColumnSumStat: Computes the sum of each column in the input data.
+- ColumnStatForOTU: Computes the sum of each column in the OTUAbundance feature.
+- RowSumStatForOTU: Computes the sum of each row in the OTUAbundance feature.
+- TotalOTUAbundanceStat: Alias for RowSumStatForOTU.
+
+These classes provide methods for transforming and fitting the input data using different libraries such as numpy, pandas, and polars.
+
+Note: The ColumnStatForOTU and RowSumStatForOTU classes are provided for autostat and are equivalent to ColumnSumStat and RowSumStat respectively.
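+
+Example (an illustrative sketch, not taken from the package docs; it assumes a
+plain pandas DataFrame input and the default output handling):
+
+```python
+import pandas as pd
+
+from biofit.stat.row_sum import RowSumStat
+
+X = pd.DataFrame({"otu1": [1.0, 2.0], "otu2": [3.0, 4.0]})
+row_sums = RowSumStat().fit_transform(X)  # expected: one column of row sums (4.0 and 6.0)
+```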
+"""
+
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Type
+
+import numpy as np
+import pandas as pd
+from biocore.utils.import_util import requires_backends
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes
+from biofit.utils import logging
+
+from ..stat import Stat, StatConfig
+
+if TYPE_CHECKING:
+ import polars as pl
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class RowSumStatConfig(StatConfig):
+ # process description
+ _transform_process_desc: str = field(
+ default="Calculating row sums", init=False, repr=False
+ )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ processor_name: str = field(default="row_sum", init=False, repr=False)
+ _n_features_out: int = field(default=1, init=False, repr=False)
+
+
+@dataclass
+class RowSumStatConfigForOTU(RowSumStatConfig):
+ """Computes the sum of each row in the OTUAbundance feature."""
+
+ # dataset specific attributes
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("Abundance")], init=False, repr=False
+ )
+ dataset_name: str = field(default="otu", init=False, repr=False)
+
+
+class RowSumStat(Stat):
+ """
+ Computes the sum of each row in the input data.
+ """
+
+ # config attributes
+ _config_class = RowSumStatConfig
+ config: RowSumStatConfig
+ output_dtype = "float64"
+
+ def fit(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "RowSumStat":
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_fit(
+ X,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ )
+
+ def _transform_numpy(self, X: np.ndarray):
+ """
+ Transforms the input data using numpy.
+
+ Args:
+ X: Input data as a numpy array.
+
+ Returns:
+ Transformed data as a numpy array.
+ """
+        if len(X.shape) == 1:
+            # a 1-D input is treated as a single row; keepdims preserves an
+            # indexable array so the result has shape (1, 1)
+            return np.sum(X, keepdims=True)[:, None]
+        return np.sum(X, axis=1)[:, None]
+
+ def _transform_pandas(self, X: pd.DataFrame):
+ """
+ Transforms the input data using pandas.
+
+ Args:
+ X: Input data as a pandas DataFrame.
+
+ Returns:
+ Transformed data as a pandas DataFrame.
+ """
+ return X.astype("float64").sum(axis=1).to_frame()
+
+ def _transform_polars(self, X: "pl.DataFrame"):
+ """
+ Transforms the input data using polars.
+
+ Args:
+ X: Input data as a polars DataFrame.
+
+ Returns:
+ Transformed data as a polars DataFrame.
+ """
+ requires_backends(self._transform_polars, "polars")
+ import polars as pl
+
+ return X.cast(pl.Float64).sum_horizontal().to_frame()
diff --git a/src/biofit/stat/stat.py b/src/biofit/stat/stat.py
new file mode 100644
index 0000000..929d388
--- /dev/null
+++ b/src/biofit/stat/stat.py
@@ -0,0 +1,103 @@
+from dataclasses import dataclass, field
+
+from biofit.processing import BaseProcessor, ProcessorConfig, SelectedColumnTypes
+
+
+@dataclass
+class StatConfig(ProcessorConfig):
+ processor_type: str = field(default="stat", init=False, repr=False)
+
+
+class Stat(BaseProcessor):
+ """Base class for statistical processors."""
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ *args,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = True,
+ raise_if_missing: bool = False,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ num_proc: int = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ return self.fit(
+ X,
+ input_columns=input_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _process_transform(self, *args, **kwargs):
+ kwargs["keep_unused_columns"] = False
+ return super()._process_transform(*args, **kwargs)
diff --git a/src/biofit/stat/summary/__init__.py b/src/biofit/stat/summary/__init__.py
new file mode 100644
index 0000000..7457204
--- /dev/null
+++ b/src/biofit/stat/summary/__init__.py
@@ -0,0 +1,2 @@
+# ruff: noqa
+from .summary import SummaryStatForOTU
diff --git a/src/biofit/stat/summary/summary.py b/src/biofit/stat/summary/summary.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/biofit/train.py b/src/biofit/train.py
new file mode 100644
index 0000000..fad4ab5
--- /dev/null
+++ b/src/biofit/train.py
@@ -0,0 +1,564 @@
+import copy
+import os
+from pathlib import Path
+from typing import TYPE_CHECKING, Callable, List, Union
+
+import numpy as np
+import pandas as pd
+from biocore import DataHandler
+from sklearn.base import BaseEstimator
+from sklearn.model_selection import (
+ BaseCrossValidator,
+)
+from sklearn.pipeline import Pipeline
+
+from biofit.auto.processing_auto import ProcessorPipeline
+from biofit.processing import BaseProcessor
+from biofit.train_eval_utils import (
+ _flatten_pipeline,
+ _get_data,
+ split,
+)
+from biofit.utils import (
+ enable_full_determinism,
+ logging,
+)
+
+if TYPE_CHECKING:
+ import polars as pl
+ from datasets import Dataset
+
+logger = logging.get_logger(__name__)
+
+
+def preprocess(
+ preprocessor,
+ x_train,
+ y_train=None,
+ x_valid=None,
+ y_valid=None,
+ cache_dir=None,
+ transform_only=False,
+ raise_error=False,
+):
+ try:
+ if isinstance(preprocessor, (Pipeline, ProcessorPipeline)):
+ for proc in preprocessor.steps:
+ p = proc[-1] if isinstance(proc, tuple) else proc
+ if isinstance(p, BaseProcessor):
+ extra_kwargs = {
+ "cache_dir": cache_dir,
+ # "load_from_cache_file": False,
+ }
+ # p.config.enable_caching = False
+ else:
+ extra_kwargs = {}
+ if transform_only:
+ x_train = p.transform(x_train, **extra_kwargs)
+ else:
+ x_train = p.fit_transform(x_train, **extra_kwargs)
+ if x_valid is not None:
+ x_valid = p.transform(x_valid, **extra_kwargs)
+ else:
+ if isinstance(preprocessor, BaseProcessor):
+ extra_kwargs = {
+ "cache_dir": cache_dir,
+ # "load_from_cache_file": False,
+ }
+ # p.config.enable_caching = False
+ else:
+ extra_kwargs = {}
+ if transform_only:
+ x_train = preprocessor.transform(x_train, **extra_kwargs)
+ else:
+ x_train = preprocessor.fit_transform(x_train, **extra_kwargs)
+ if x_valid is not None:
+ x_valid = preprocessor.transform(x_valid, **extra_kwargs)
+ except ValueError:
+ if raise_error:
+ raise
+ else:
+ logger.info("Preprocessing failed, using nan_to_num")
+ _train_cols = x_train.columns.tolist()
+ x_train = np.nan_to_num(x_train)
+ # convert back to dataframe
+ x_train = pd.DataFrame(x_train, columns=_train_cols)
+ if x_valid is not None:
+ _valid_cols = x_valid.columns.tolist()
+ x_valid = np.nan_to_num(x_valid)
+ x_valid = pd.DataFrame(x_valid, columns=_valid_cols)
+ return preprocess(
+ preprocessor,
+ x_train,
+ y_train,
+ x_valid,
+ y_valid,
+ cache_dir,
+ transform_only,
+ True,
+ )
+
+ return x_train, y_train, x_valid, y_valid
+
+
+def train(
+ model: Union[BaseEstimator, BaseProcessor],
+ data: Union[pd.DataFrame, "pl.DataFrame", "Dataset"],
+ target: Union[
+ pd.Series, "pl.Series", pd.DataFrame, "pl.DataFrame", "Dataset"
+ ] = None,
+ valid_data: Union[pd.DataFrame, "pl.DataFrame", "Dataset"] = None,
+ valid_target: Union[
+ pd.Series, "pl.Series", pd.DataFrame, "pl.DataFrame", "Dataset"
+ ] = None,
+ groups: Union[pd.Series, "pl.Series"] = None,
+ input_columns: Union[List[str], str] = "auto",
+ target_columns: Union[List[str], str] = "auto",
+ group_name: str = None,
+ preprocessor: Union[BaseEstimator, BaseProcessor] = None,
+ eval_metric: Union[str, Callable] = None,
+ task: str = None,
+ cv: BaseCrossValidator = None,
+ random_state: Union[List[int], int] = 42,
+ save_indices: bool = False,
+ output_dir: str = None,
+ cache_dir: str = None,
+):
+ """
+ Train a model or processor on the provided data, with optional preprocessing,
+ validation data, and cross-validation.
+
+ This function supports training models with various data formats (pandas DataFrame,
+ polars DataFrame, Dataset), optional preprocessing pipelines, cross-validation
+ strategies, and random seeds for reproducibility.
+
+ Parameters:
+ model (Union[BaseEstimator, BaseProcessor]):
+ The model or processor to train. This can be a scikit-learn estimator, a
+ pipeline, or a custom estimator that follows scikit-learn's API.
+ data (Union[pd.DataFrame, pl.DataFrame, "Dataset"]):
+ The training data. It can be a pandas DataFrame, polars DataFrame, or a
+ Dataset object.
+ target (Union[pd.Series, pl.Series, pd.DataFrame, pl.DataFrame, "Dataset"], optional):
+ The target values for supervised learning. If None, the target columns must
+ be specified in `data` using `target_columns`.
+ valid_data (Union[pd.DataFrame, pl.DataFrame, "Dataset"], optional):
+ Optional validation data. If not provided, validation can be done using
+ cross-validation.
+ valid_target (Union[pd.Series, pl.Series, pd.DataFrame, pl.DataFrame, "Dataset"], optional):
+ Target values for the validation data.
+ groups (Union[pd.Series, pl.Series], optional):
+ Group labels for the samples used while splitting the dataset into
+ train/test set.
+ input_columns (Union[List[str], str], optional):
+ Names of the input feature columns. If 'auto', the function will attempt to
+ infer the input columns from the data. Defaults to "auto".
+ target_columns (Union[List[str], str], optional):
+ Names of the target columns in `data`. If 'auto', the function will attempt
+ to infer the target columns. Defaults to "auto".
+ group_name (str, optional):
+ Name of the group column in `data` if groups are specified within the data.
+ preprocessor (Union[BaseEstimator, BaseProcessor], optional):
+ Preprocessing pipeline or estimator to apply before training the model.
+ eval_metric (Union[str, Callable], optional):
+ Evaluation metric or a callable function to evaluate the model's
+ performance.
+ task (str, optional):
+ Type of task to perform (e.g., 'classification', 'regression',
+ 'multilabel_classification', 'multi_regression').
+        cv (BaseCrossValidator, optional):
+ Cross-validation splitting strategy. If provided, the model will be trained
+ using cross-validation.
+ random_state (Union[List[int], int], optional):
+ Random seed(s) for reproducibility. Can be a single integer or a list of
+ integers. Defaults to 42.
+ save_indices (bool, optional):
+ Whether to save the indices of the train and validation splits. Defaults to
+ False.
+ output_dir (str, optional):
+ Directory where outputs will be saved. Defaults to None.
+ cache_dir (str, optional):
+ Directory where cache files will be stored. Defaults to None.
+
+ Returns:
+ Trained model or pipeline. The return type depends on the inputs:
+
+ - If `random_state` is an integer and `cv` is None, returns a single trained model or pipeline.
+ - If `random_state` is a list of integers, returns a list of trained models or pipelines, one for each random seed.
+ - If `cv` is specified, returns a list of trained models or pipelines, one for each cross-validation fold.
+
+ Examples:
+ 1. Basic usage with pandas DataFrame:
+
+ ```python
+    from biofit.train import train
+ from biofit.models import LogisticRegressionForClassification
+ import pandas as pd
+
+ # Create sample data
+ data = pd.DataFrame({
+ 'age': [25, 32, 47, 51, 62],
+ 'salary': [50000, 60000, 70000, 80000, 90000],
+ 'purchased': [0, 1, 0, 1, 0]
+ })
+
+ # Train the model
+ model = train(
+ model=LogisticRegressionForClassification(),
+ data=data,
+ target_columns='purchased',
+ input_columns=['age', 'salary']
+ )
+ ```
+
+ 2. Using validation data:
+
+ ```python
+ from biofit.train import train
+ from biofit.models import RandomForestForClassification
+ import pandas as pd
+
+ # Training data
+ train_data = pd.DataFrame({
+ 'feature': [1,2,3,4,5],
+ 'target': [2,4,6,8,10]
+ })
+
+ # Validation data
+ valid_data = pd.DataFrame({
+ 'feature': [6,7],
+ 'target': [12,14]
+ })
+
+ # Train the model with validation data
+ model = train(
+ model=RandomForestForClassification(),
+ data=train_data,
+ valid_data=valid_data,
+ target_columns='target',
+ input_columns='feature'
+ )
+ ```
+
+ 3. Cross-validation with random seed list:
+
+ ```python
+ from biofit.train import train
+    from biofit.models import RandomForestForClassification
+ from sklearn.model_selection import StratifiedKFold
+ import pandas as pd
+
+ # Sample data
+ data = pd.DataFrame({
+ 'feature1': [1,2,3,4,5,6],
+ 'feature2': [6,5,4,3,2,1],
+ 'target': [0,1,0,1,0,1]
+ })
+
+ cv = StratifiedKFold(n_splits=3)
+ random_seeds = [42, 7, 21]
+
+ # Train the model with cross-validation and multiple random seeds
+ models = train(
+        model=RandomForestForClassification(),
+ data=data,
+ target_columns='target',
+ input_columns=['feature1', 'feature2'],
+ cv=cv,
+ random_state=random_seeds
+ )
+
+ # models is a list of models trained with different random seeds
+ ```
+
+ 4. Using a preprocessor pipeline:
+
+ ```python
+ from biofit.train import train
+ from biofit.auto.processing_auto import ProcessorPipeline
+ from biofit.models import LogisticRegressionForClassification
+    from sklearn.compose import ColumnTransformer
+    from sklearn.preprocessing import OneHotEncoder, StandardScaler
+ import pandas as pd
+
+ # Sample data with categorical feature
+ data = pd.DataFrame({
+ 'age': [25, 32, 47, 51, 62],
+ 'gender': ['M', 'F', 'M', 'F', 'M'],
+ 'purchased': [0, 1, 0, 1, 0]
+ })
+
+ # Define preprocessing
+ numeric_features = ['age']
+ numeric_transformer = StandardScaler()
+
+ categorical_features = ['gender']
+ categorical_transformer = OneHotEncoder()
+
+ preprocessor = ColumnTransformer(
+ transformers=[
+ ('num', numeric_transformer, numeric_features),
+ ('cat', categorical_transformer, categorical_features)
+ ])
+
+ # Create a pipeline that first preprocesses the data, then applies the model
+ pipeline = ProcessorPipeline(steps=[
+ ('preprocessor', preprocessor),
+ ('classifier', LogisticRegressionForClassification())
+ ])
+
+ # Train the model with preprocessor
+ model = train(
+ model=pipeline,
+ data=data,
+ target_columns='purchased',
+ input_columns=['age', 'gender']
+ )
+ ```
+
+ 5. Training with a Dataset object:
+
+ ```python
+ from biofit.train import train
+ from datasets import load_dataset
+
+ # Load a dataset from the datasets library
+ dataset = load_dataset('csv', data_files='data.csv')
+
+ # Train the model
+ model = train(
+ model=SomeModel(),
+ data=dataset['train'],
+ target_columns='label',
+ input_columns='text'
+ )
+ ```
+
+ Notes:
+ - The `train` function can handle various data formats and will attempt to
+ process the data accordingly.
+ - If `input_columns` or `target_columns` are set to 'auto', the function will
+ try to infer the columns automatically based on the data.
+ - The `preprocessor` can be any transformer that follows the scikit-learn API,
+ including pipelines.
+ - When using cross-validation (`cv` is not None), the function will train models
+ for each fold and return a list of models.
+    - If multiple random seeds are provided (`random_state` is a list), the
+      function will train one pipeline per seed and return the full list of
+      fitted pipelines, in the same order as the seeds.
+ """
+ x_train, y_train, x_valid, y_valid, groups, _, input_columns, target_columns = (
+ _get_data(
+ data=data,
+ target=target,
+ valid_data=valid_data,
+ valid_target=valid_target,
+ groups=groups,
+ group_name=group_name,
+ input_columns=input_columns,
+ target_columns=target_columns,
+ format="pandas",
+ target_required=True,
+ )
+ )
+ if cache_dir is None and output_dir is not None:
+ cache_dir = (Path(output_dir) / ".cache").resolve().as_posix()
+
+ def fit(
+ x_train,
+ y_train,
+ x_valid,
+ y_valid,
+ model,
+ preprocessor,
+ random_state=None,
+ cache_dir=None,
+ ):
+ if isinstance(model, Pipeline):
+ if preprocessor is None and len(model.steps) > 1:
+ preprocessor = _flatten_pipeline(
+ [p[-1] if isinstance(p, tuple) else p for p in model.steps[:-1]]
+ )
+ preprocessor = Pipeline(preprocessor)
+ model = (
+ model.steps[-1][1]
+ if isinstance(model.steps[-1], tuple)
+ else model.steps[-1]
+ )
+ if preprocessor:
+ x_train, y_train, x_valid, y_valid = preprocess(
+ preprocessor, x_train, y_train, x_valid, y_valid, cache_dir
+ )
+
+ m_params = (
+ model.get_params() if hasattr(model, "get_params") else model.__dict__
+ )
+ if hasattr(model, "set_params"):
+ if "random_state" in m_params:
+ model.set_params(random_state=random_state)
+ elif "seed" in m_params:
+ model.set_params(seed=random_state)
+ else:
+ if "random_state" in m_params:
+ model.random_state = random_state
+ elif "seed" in m_params:
+ model.seed = random_state
+ models = []
+ if task in ("multilabel_classification", "multi_regression"):
+            # also multi_column_regression; deep-copy so each target column
+            # gets its own independently fitted estimator
+            models = [copy.deepcopy(model) for _ in range(y_train.shape[1])]
+ for idx, _m in enumerate(models):
+ if _m.__module__.startswith("xgboost"):
+ _m.fit(
+ x_train,
+ y_train[:, idx],
+                        eval_set=[(x_valid, y_valid[:, idx])]
+                        if x_valid is not None
+                        else None,
+                        verbose=False,
+ )
+ elif _m.__module__.startswith("biofit"):
+ old_val = _m.config.enable_caching
+ _m.config.enable_caching = False
+ if (
+ "early_stopping_rounds" in m_params
+ and m_params["early_stopping_rounds"] is not None
+ and m_params["early_stopping_rounds"] > 0
+ ):
+ _m.fit(
+ x_train,
+ y_train[:, idx],
+ eval_set=[(x_valid, y_valid[:, idx])]
+ if x_valid is not None
+ else None,
+ # load_from_cache_file=False,
+ )
+ else:
+ _m.fit(
+ x_train,
+ y_train[:, idx],
+ load_from_cache_file=False,
+ )
+ _m.config.enable_caching = old_val
+ else:
+ _m.fit(
+ x_train,
+ y_train[:, idx],
+ )
+
+ else:
+ if model.__module__.startswith("xgboost"):
+ model.fit(
+ x_train,
+ y_train,
+                eval_set=[(x_valid, y_valid)] if x_valid is not None else None,
+                verbose=False,
+ )
+ elif model.__module__.startswith("biofit"):
+ old_val = model.config.enable_caching
+ model.config.enable_caching = False
+ model.fit(
+ x_train,
+ y_train,
+ load_from_cache_file=False,
+ )
+ model.config.enable_caching = old_val
+ else:
+ model.fit(x_train, y_train)
+
+ models = [model]
+ return models, preprocessor
+
+
+ def training_loop(random_state):
+ enable_full_determinism(random_state)
+ if cv is None:
+ _m = copy.deepcopy(model)
+ _p = copy.deepcopy(preprocessor)
+
+ fitted_model_, preprocessor_ = fit(
+ x_train=x_train,
+ y_train=y_train,
+ x_valid=x_valid,
+ y_valid=y_valid,
+ model=_m,
+ preprocessor=_p,
+ random_state=random_state,
+ cache_dir=cache_dir,
+ )
+
+ if isinstance(fitted_model_, list) and len(fitted_model_) == 1:
+ fitted_model_ = fitted_model_[0]
+ return fitted_model_, preprocessor_
+ else:
+ fitted_models = []
+ fitted_preprocessors = []
+ if hasattr(cv, "__dict__") and "random_state" in cv.__dict__:
+ cv.random_state = random_state
+ for fold, (train_idx, valid_idx) in enumerate(
+ split(cv, X=x_train, y=y_train, groups=groups)
+ ):
+            _cache_dir = (
+                os.path.join(cache_dir, f"fold_{fold}") if cache_dir else None
+            )
+ _m = copy.deepcopy(model)
+ _p = copy.deepcopy(preprocessor)
+
+ m_params = _m.get_params()
+ if "random_state" in m_params:
+ _m.set_params(random_state=random_state)
+ elif "seed" in m_params:
+ _m.set_params(seed=random_state)
+ xtrain_fold, xvalid_fold = (
+ DataHandler.select_rows(x_train, train_idx),
+ DataHandler.select_rows(x_train, valid_idx),
+ )
+ ytrain_fold, yvalid_fold = (
+ DataHandler.select_rows(y_train, train_idx),
+ DataHandler.select_rows(y_train, valid_idx),
+ )
+ fitted_model_fold, fitted_preprocessor_fold = fit(
+ xtrain_fold,
+ ytrain_fold,
+ xvalid_fold,
+ yvalid_fold,
+ _m,
+ _p,
+ random_state=random_state,
+ cache_dir=_cache_dir,
+ )
+ if isinstance(fitted_model_fold, list) and len(fitted_model_fold) == 1:
+ fitted_model_fold = fitted_model_fold[0]
+ fitted_models.append(fitted_model_fold)
+ fitted_preprocessors.append(fitted_preprocessor_fold)
+
+ return fitted_models, fitted_preprocessors
+
+ def create_pipeline(model, preprocessor):
+ pipeline = []
+ if isinstance(model, list):
+ for m, p in zip(model, preprocessor):
+ if isinstance(p, (Pipeline, ProcessorPipeline)):
+ p = ProcessorPipeline(_flatten_pipeline(p))
+ pipeline.append(
+ Pipeline([("preprocessor", p), ("model", m)]) if p else m
+ )
+ else:
+ if isinstance(preprocessor, (Pipeline, ProcessorPipeline)):
+ preprocessor = ProcessorPipeline(_flatten_pipeline(preprocessor))
+ pipeline = (
+ Pipeline([("preprocessor", preprocessor), ("model", model)])
+ if preprocessor
+ else model
+ )
+ return pipeline
+
+ if isinstance(random_state, list):
+ pipeline_list = []
+ for rs in random_state:
+ m, p = training_loop(rs)
+ pipeline_list.append(create_pipeline(m, p))
+
+ return pipeline_list
+ else:
+ fitted_model, fitted_preprocessor = training_loop(random_state)
+ return create_pipeline(fitted_model, fitted_preprocessor)
diff --git a/src/biofit/train_eval_utils.py b/src/biofit/train_eval_utils.py
new file mode 100644
index 0000000..2b7aa24
--- /dev/null
+++ b/src/biofit/train_eval_utils.py
@@ -0,0 +1,1406 @@
+import copy
+import os
+import sys
+from dataclasses import asdict
+from typing import TYPE_CHECKING, Callable, Generator, List, Union
+
+import joblib
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import yaml
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import (
+ is_catboost_available,
+ is_lightgbm_available,
+ is_plotly_available,
+ is_xgboost_available,
+)
+from biocore.utils.inspect import get_kwargs
+from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
+from sklearn.model_selection import (
+ GroupKFold,
+ KFold,
+ LeaveOneGroupOut,
+ StratifiedGroupKFold,
+ StratifiedKFold,
+ train_test_split,
+)
+from sklearn.pipeline import Pipeline
+from sklearn.svm import SVC, SVR
+
+from biofit.auto.processing_auto import ProcessorPipeline
+from biofit.models import (
+ LightGBMForClassification,
+ LightGBMForRegression,
+ RandomForestForClassification,
+ RandomForestForRegression,
+)
+from biofit.models.lasso.lasso import LassoForClassification, LassoForRegression
+from biofit.models.models import Model
+from biofit.processing import BaseProcessor
+from biofit.utils import logging
+from biocore.utils.py_util import is_bioset, is_dataset
+from biofit.visualization.plotting_utils import plot_feature_importance
+
+if TYPE_CHECKING:
+ import optuna
+
+logger = logging.get_logger(__name__)
+
+
+CLASSIFICATION_TASKS = ["binary_classification", "multiclass_classification"]
+REGRESSION_TASKS = ["regression", "multi_regression"]
+
+_MODELS = {
+ "binary_classification": {
+ "random_forest": RandomForestForClassification,
+ "lightgbm": LightGBMForClassification,
+ "svm": SVC,
+ "gradient_boosting": GradientBoostingClassifier,
+ "lasso": LassoForClassification,
+ },
+ "multiclass_classification": {
+ "random_forest": RandomForestForClassification,
+ "lightgbm": LightGBMForClassification,
+ "svm": SVC,
+ "gradient_boosting": GradientBoostingClassifier,
+ "lasso": LassoForClassification,
+ },
+ "regression": {
+ "random_forest": RandomForestForRegression,
+ "lightgbm": LightGBMForRegression,
+ "svm": SVR,
+ "gradient_boosting": GradientBoostingRegressor,
+ "lasso": LassoForRegression,
+ },
+ "multi_regression": {
+ "random_forest": RandomForestForRegression,
+ "lightgbm": LightGBMForRegression,
+ "svm": SVR,
+ "gradient_boosting": GradientBoostingRegressor,
+ "lasso": LassoForRegression,
+ },
+}
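+
+# Model-name strings resolve through this registry by task; for example
+# (illustrative), get_model("random_forest", task="binary_classification")
+# below looks up _MODELS["binary_classification"]["random_forest"], i.e.
+# RandomForestForClassification.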
+
+if is_xgboost_available():
+ from xgboost import XGBClassifier, XGBRegressor
+
+ _MODELS["binary_classification"]["xgboost"] = XGBClassifier
+ _MODELS["multiclass_classification"]["xgboost"] = XGBClassifier
+ _MODELS["regression"]["xgboost"] = XGBRegressor
+ _MODELS["multi_regression"]["xgboost"] = XGBRegressor
+
+if is_lightgbm_available():
+ _MODELS["binary_classification"]["lightgbm"] = LightGBMForClassification
+ _MODELS["multiclass_classification"]["lightgbm"] = LightGBMForClassification
+ _MODELS["regression"]["lightgbm"] = LightGBMForRegression
+ _MODELS["multi_regression"]["lightgbm"] = LightGBMForRegression
+
+if is_catboost_available():
+ from catboost import CatBoostClassifier, CatBoostRegressor
+
+ _MODELS["binary_classification"]["catboost"] = CatBoostClassifier
+ _MODELS["multiclass_classification"]["catboost"] = CatBoostClassifier
+ _MODELS["regression"]["catboost"] = CatBoostRegressor
+ _MODELS["multi_regression"]["catboost"] = CatBoostRegressor
+
+
+_CV = {
+ "binary_classification": {
+ "stratified_kfold": StratifiedKFold,
+ "kfold": KFold,
+ "group_kfold": GroupKFold,
+ "stratified_group_kfold": StratifiedGroupKFold,
+ "logo": LeaveOneGroupOut,
+ },
+ "multiclass_classification": {
+ "stratified_kfold": StratifiedKFold,
+ "kfold": KFold,
+ "group_kfold": GroupKFold,
+ "stratified_group_kfold": StratifiedGroupKFold,
+ "logo": LeaveOneGroupOut,
+ },
+ "regression": {
+ "stratified_kfold": StratifiedKFold,
+ "kfold": KFold,
+ },
+ "multi_regression": {
+ "stratified_kfold": StratifiedKFold,
+ "kfold": KFold,
+ },
+ "multilabel_classification": {
+ "stratified_kfold": StratifiedKFold,
+ "kfold": KFold,
+ },
+}
+
+
+def _flatten_pipeline(p):
+ pipe = []
+ if isinstance(p, list):
+ for step in p:
+ pipe.extend(_flatten_pipeline(step))
+ elif isinstance(p, (Pipeline, ProcessorPipeline)):
+ for step in p.steps:
+ pipe.extend(
+ _flatten_pipeline(step[-1] if isinstance(step, tuple) else step)
+ )
+ else:
+ pipe.append(p)
+ return pipe
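+
+# For example (illustrative): flattening Pipeline([("a", s1), ("b",
+# ProcessorPipeline([s2, s3]))]) yields [s1, s2, s3], unwrapping (name, step)
+# tuples and nested pipelines alike.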
+
+
+def _get_name_and_params(obj):
+ obj_params = None
+ obj_name = None
+ if obj:
+ if isinstance(obj, (ProcessorPipeline, Pipeline, list)):
+ if isinstance(obj, (Pipeline, ProcessorPipeline)):
+ obj = [step[-1] if isinstance(step, tuple) else step for step in obj]
+ obj_params = [
+ p.get_params() if isinstance(p, BaseProcessor) else p.__dict__
+ for p in obj
+ ]
+ obj_name = [
+ p.config.processor_name
+ if isinstance(p, BaseProcessor)
+ and getattr(p.config, "processor_name", None)
+ else p.__class__.__name__
+ for p in obj
+ ]
+ else:
+            obj_params = (
+                obj.get_params() if hasattr(obj, "get_params") else obj.__dict__
+            )
+            obj_name = (
+                obj.config.processor_name
+                if isinstance(obj, BaseProcessor)
+                and getattr(obj.config, "processor_name", None)
+                else obj.__class__.__name__
+ )
+ return obj_name, obj_params
+
+
+def _get_processor_info(
+ models,
+ preprocessors=None,
+):
+ if isinstance(models, list) and len(models) == 1:
+ models = models[0]
+
+ preprocessor_name, preprocessor_params = None, None
+ new_preprocessors = preprocessors
+ if isinstance(models, Pipeline) and preprocessors is None:
+ if len(models.steps) > 1:
+ new_preprocessors = _flatten_pipeline(
+ [p[-1] if isinstance(p, tuple) else p for p in models.steps[:-1]]
+ )
+ if isinstance(new_preprocessors, list):
+ if len(new_preprocessors) == 1:
+ new_preprocessors = new_preprocessors[0]
+ else:
+ new_preprocessors = ProcessorPipeline(new_preprocessors)
+
+ new_models = (
+ models.steps[-1][-1]
+ if isinstance(models.steps[-1], tuple)
+ else models.steps[-1]
+ )
+ model_name, model_params = _get_name_and_params(new_models)
+ if new_preprocessors is not None:
+ preprocessor_name, preprocessor_params = _get_name_and_params(
+ new_preprocessors
+ )
+ elif isinstance(models, list):
+ new_models, model_name, model_params = [], [], []
+ new_preprocessors, preprocessor_name, preprocessor_params = [], [], []
+ for i, m in enumerate(models):
+ _preprocessors = None
+ if isinstance(preprocessors, list):
+ _preprocessors = preprocessors[i]
+ if _preprocessors is None:
+ if isinstance(m, Pipeline):
+ if len(m.steps) > 1:
+ _preprocessors = _flatten_pipeline(
+ [p[-1] if isinstance(p, tuple) else p for p in m.steps[:-1]]
+ )
+ m = (
+ m.steps[-1][-1]
+ if isinstance(m.steps[-1], tuple)
+ else m.steps[-1]
+ )
+ if isinstance(_preprocessors, list):
+ if len(_preprocessors) == 1:
+ _preprocessors = _preprocessors[0]
+ else:
+ _preprocessors = ProcessorPipeline(_preprocessors)
+
+ _name, _params = _get_name_and_params(m)
+ new_models.append(m)
+ model_name.append(_name)
+ model_params.append(_params)
+ if _preprocessors:
+ _name, _params = _get_name_and_params(_preprocessors)
+ preprocessor_name.append(_name)
+ preprocessor_params.append(_params)
+ new_preprocessors.append(_preprocessors)
+ if len(preprocessor_name) == 0 or len(preprocessor_params) == 0:
+ preprocessor_name, preprocessor_params = None, None
+ else:
+ model_name, model_params = _get_name_and_params(models)
+ new_models = models
+ if preprocessors is not None:
+ new_preprocessors = preprocessors
+ preprocessor_name, preprocessor_params = _get_name_and_params(preprocessors)
+
+ return (
+ new_models,
+ model_name,
+ model_params,
+ new_preprocessors,
+ preprocessor_name,
+ preprocessor_params,
+ )
+
+
+def save_model_and_preprocessor(
+ output_dir,
+ models,
+ preprocessors=None,
+):
+ if models is not None:
+ joblib.dump(models, os.path.join(output_dir, "model.joblib"))
+
+ if preprocessors is not None:
+ if isinstance(preprocessors, list):
+ preprocessors = ProcessorPipeline(preprocessors)
+ joblib.dump(preprocessors, os.path.join(output_dir, "preprocessor.joblib"))
+
+
+def save_feature_importances(
+ output_dir,
+ feat_importances,
+ preprocessor,
+ data,
+ target,
+ target_columns,
+ save_plot=True,
+ label_names=None,
+ **kwargs,
+):
+ feat_importances = feat_importances.reset_index(names=["features"])
+ if save_plot:
+ transformed_dataset = (
+ preprocessor.fit_transform(data, load_from_cache_file=False)
+ if preprocessor
+ else data
+ )
+ if kwargs.get("input_columns") is not None:
+ trans_columns = set(
+ DataHandler.get_column_names(transformed_dataset, generate_cols=True)
+ )
+ kwargs["input_columns"] = [
+ col for col in kwargs["input_columns"] if col in trans_columns
+ ]
+
+ if is_bioset(data) or is_dataset(data):
+ y = DataHandler.select_columns(target, columns=target_columns)
+ else:
+ transformed_dataset = DataHandler.to_pandas(transformed_dataset)
+ y = DataHandler.to_pandas(target, columns=target_columns)
+
+ plot_dir = os.path.join(output_dir, "plots")
+ params = dict(
+ y=y,
+ path=plot_dir,
+ target_columns=target_columns,
+ )
+ params.update(kwargs)
+
+ if feat_importances.iloc[:, 1:].sum().sum() == 0:
+ logger.warning(
+ "All feature importances are zero. Please check the model and data. "
+ "Skipping feature importance plot."
+ )
+ else:
+ path = params.pop("path", None)
+ params["output_dir"] = path
+ plot_feature_importance(
+ X=transformed_dataset,
+ feature_importances=feat_importances,
+ label_names=label_names,
+ **params,
+ )
+
+ feature_metadata = None
+ if "feature_metadata" in kwargs:
+ feature_metadata = kwargs["feature_metadata"]
+ elif is_bioset(data):
+ from biosets import get_feature_metadata
+
+ feature_metadata = get_feature_metadata(data)
+
+ if feature_metadata is not None:
+ if isinstance(feature_metadata, dict):
+ feature_metadata = pd.DataFrame(
+ list(feature_metadata.values()), index=list(feature_metadata.keys())
+ )
+ feature_metadata = feature_metadata.reset_index(names=["features"])
+ if feat_importances.shape[1] > 2:
+ feat_importances["median"] = feat_importances.iloc[:, 1:].median(axis=1)
+ feat_importances = feature_metadata.merge(
+ feat_importances, how="inner", on="features"
+ )
+ feat_importances.to_csv(
+ os.path.join(output_dir, "feature_importances.csv"), index=False
+ )
+
+
+def save_study(
+ output_dir,
+ study,
+):
+ from optuna.visualization import (
+ plot_param_importances,
+ plot_slice,
+ )
+
+ joblib.dump(study, os.path.join(output_dir, "study.joblib"))
+ trials_df = study.trials_dataframe()
+ best_trial_col = [""] * len(trials_df)
+ best_trial_col[study.best_trial.number] = "Best Trial"
+ trials_df["Best Trial"] = best_trial_col
+ trials_df.to_csv(os.path.join(output_dir, "trials.csv"), index=False)
+ num_trials = len(study.trials)
+
+ plot_dir = os.path.join(output_dir, "plots")
+ param_importances_path = os.path.join(plot_dir, "param_importances_plot.pdf")
+ slice_path = os.path.join(plot_dir, "slice_plot.pdf")
+
+ os.makedirs(plot_dir, exist_ok=True)
+
+ if is_plotly_available():
+ try:
+ if num_trials is not None and num_trials > 1:
+ plot_param_importances(study).write_image(
+ param_importances_path, format="pdf", engine="kaleido"
+ )
+ plot_slice(study).write_image(slice_path, format="pdf", engine="kaleido")
+ except (ValueError, RuntimeError) as e:
+ logger.warning(f"Failed to save optimization plots: {e}")
+ else:
+ logger.warning_once(
+ "Plotly is not installed. Optimization plots will not be saved."
+ )
+
+
+def save_metrics(
+ output_dir,
+ metrics,
+):
+    if isinstance(metrics, list):
+        # one row per metrics dict; a length-1 list yields a single-row frame
+        metrics = pd.DataFrame(metrics)
+ elif isinstance(metrics, dict):
+ metrics = pd.DataFrame({k: [v] for k, v in metrics.items()})
+ metrics.to_csv(os.path.join(output_dir, "metrics.csv"), index=False)
+
+
+def save_confusion_matrix(output_dir, confusion_matrices, label_names=None):
+ if not isinstance(confusion_matrices, list) or len(confusion_matrices) == 1:
+ if isinstance(confusion_matrices, list):
+ confusion_matrices = confusion_matrices[0]
+ confusion_matrices.to_csv(os.path.join(output_dir, "confusion_matrix.csv"))
+ else:
+ if label_names is None:
+ label_names = [f"label_{i}" for i in range(len(confusion_matrices))]
+ for i, cm in enumerate(confusion_matrices):
+ cm.to_csv(
+ os.path.join(output_dir, f"confusion_matrices_{label_names[i]}.csv")
+ )
+
+
+def save_predictions(
+ output_dir,
+ preds,
+ data,
+ target=None,
+ label_names=None,
+ target_columns=None,
+ index=None,
+):
+ # save as csv
+ if index is None:
+ index = range(len(data))
+
+ if label_names is None:
+ if is_bioset(data) or is_dataset(data):
+ if target_columns is not None:
+ if isinstance(target_columns, list):
+ target_columns = target_columns[0]
+ if target_columns in data._info.features:
+ label_names = data._info.features[target_columns].names
+ if label_names is None and (is_bioset(target) or is_dataset(target)):
+ target_columns = target_columns or DataHandler.get_column_names(target)[0]
+ if target_columns is not None:
+ if isinstance(target_columns, list):
+ target_columns = target_columns[0]
+ if target_columns in target._info.features:
+ label_names = target._info.features[target_columns].names
+
+ if is_bioset(data) or is_dataset(data):
+ from biosets.features import Sample
+
+ names = [
+ name
+ for name, feat in data._info.features.items()
+ if isinstance(feat, Sample)
+ ]
+ if len(names) == 1:
+ names = names[0]
+ sample_ids = DataHandler.to_numpy(data, names).flatten()
+
+ if sample_ids is not None:
+ if index is not None and isinstance(index, (list, np.ndarray)):
+ sample_ids = np.array(sample_ids)[np.array(index)]
+ if isinstance(preds.index[0], int):
+ preds.index = [sample_ids[i] for i in preds.index]
+ else:
+ preds.index = sample_ids
+ preds.index.name = names
+
+ if label_names is not None:
+ preds["predicted"] = DataHandler.argmax(
+ preds.iloc[:, : len(label_names)], axis=1
+ )
+ if target is not None:
+ preds["actual"] = DataHandler.to_numpy(target).flatten()[index].astype(int)
+ preds["actual"] = preds["actual"].map(
+ {i: name for i, name in enumerate(label_names)}
+ )
+ col_names = label_names + ["predicted", "actual"]
+
+ if is_bioset(target) or is_dataset(target):
+ from biosets.features import BinClassLabel
+
+ feat = target._info.features[target_columns]
+ if isinstance(feat, BinClassLabel) and (
+ feat.positive_labels is not None or feat.negative_labels is not None
+ ):
+ if feat.id in DataHandler.get_column_names(data):
+ preds["actual_original"] = DataHandler.to_pandas(
+ data, feat.id
+ ).values.flatten()[index]
+ col_names.append("actual_original")
+ else:
+ col_names = label_names + ["predicted"]
+ preds["predicted"] = preds["predicted"].map(
+ {i: name for i, name in enumerate(label_names)}
+ )
+ preds.columns = col_names
+ else:
+ preds["predicted"] = DataHandler.argmax(preds, axis=1)
+ if target is not None:
+ preds["actual"] = (
+ DataHandler.to_pandas(target).iloc[index, 0].values.astype(int)
+ )
+
+ if preds.index.name is not None:
+ preds.to_csv(os.path.join(output_dir, "predictions.csv"), index=True)
+ else:
+ preds.to_csv(os.path.join(output_dir, "predictions.csv"), index=False)
+
+
+def _save_train_results(
+ data,
+ orig_target,
+ target,
+ output_dir,
+ feature_importances=None,
+ confusion_matrices=None,
+ study: "optuna.Study" = None,
+ num_trials=None,
+ time_limit=None,
+ models=None,
+ metrics=None,
+ preds=None,
+ input_columns: List[str] = None,
+ target_columns: List[str] = None,
+ label_names: List[str] = None,
+ cv: Union[KFold, StratifiedKFold, GroupKFold, StratifiedGroupKFold] = None,
+ group_name: str = None,
+ outer_cv: Union[KFold, StratifiedKFold, GroupKFold, StratifiedGroupKFold] = None,
+ outer_group_name: str = None,
+ task: str = None,
+ eval_metric: Union[str, Callable] = None,
+ random_state: Union[int, List[int]] = None,
+ use_suggested_hyperparameters: bool = True,
+ unused_columns: List[str] = None,
+ valid_split: str = None,
+ index: List[int] = None,
+ feature_importance_plot_params=None,
+ **kwargs,
+):
+ orig_data = data
+ data, target, _, _, _, _, _, target_columns = _get_data(
+ data=data,
+ target=target,
+ valid_data=None,
+ valid_target=None,
+ groups=None,
+ group_name=None,
+ input_columns=input_columns,
+ target_columns=target_columns,
+ format=None,
+ target_required=True,
+ )
+
+ if isinstance(preds, list) and len(preds) == 1:
+ preds = preds[0]
+
+ os.makedirs(output_dir, exist_ok=True)
+ if study:
+ save_study(output_dir, study)
+
+    preprocessors = None
+    model_name, model_params = None, None
+    preprocessor_name, preprocessor_params = None, None
+ if models is not None:
+ (
+ models,
+ model_name,
+ model_params,
+ preprocessors,
+ preprocessor_name,
+ preprocessor_params,
+ ) = _get_processor_info(models)
+ save_model_and_preprocessor(output_dir, models, preprocessors)
+
+ if metrics is not None:
+ save_metrics(output_dir, metrics)
+
+ if preds is not None:
+ save_predictions(
+ output_dir,
+ preds,
+ orig_data,
+ target if orig_target is None else orig_target,
+ label_names,
+ target_columns,
+ index,
+ )
+
+ if confusion_matrices is not None:
+ save_confusion_matrix(output_dir, confusion_matrices)
+
+ if feature_importances is not None:
+ feature_importance_plot_params = feature_importance_plot_params or {}
+ save_feature_importances(
+ output_dir,
+ feature_importances,
+ preprocessors,
+ orig_data,
+ target,
+ target_columns,
+ input_columns=input_columns,
+ label_names=label_names,
+ **feature_importance_plot_params,
+ )
+
+ params = {
+ "preprocessing": [],
+ "optimization": {},
+ "training": {},
+ "cross_validation": {},
+ "model": {},
+ "misc": {},
+ }
+
+ # General Information
+ if task is not None:
+ params["training"]["task"] = task
+
+ if label_names is not None:
+ if not isinstance(target_columns, list):
+ tc = [target_columns]
+ else:
+ tc = target_columns
+
+ if not isinstance(label_names, (np.ndarray, list)) or not isinstance(
+ label_names[0], (np.ndarray, list)
+ ):
+ label_names = [label_names]
+
+ if "classification" in task:
+ target_type = "labels"
+ else:
+ target_type = "targets"
+
+ params["training"][target_type] = {}
+ for i, lab_name in enumerate(tc):
+ if lab_name is not None:
+ class_feat = None
+ if (
+ is_bioset(target) or is_dataset(target)
+ ) and lab_name in target._info.features:
+ class_feat = target._info.features[lab_name]
+ elif is_bioset(data) or is_dataset(data):
+ class_feat = data._info.features[lab_name]
+ if class_feat is not None:
+ if "biosets" in sys.modules and isinstance(
+ class_feat, sys.modules["biosets"].BinClassLabel
+ ):
+ _names = class_feat.names
+ positive_labels = class_feat.positive_labels
+ negative_labels = class_feat.negative_labels
+ if positive_labels is not None or negative_labels is not None:
+ params["training"][target_type][lab_name] = {
+ _names[0]: negative_labels,
+ _names[1]: positive_labels,
+ }
+ else:
+ params["training"][target_type][lab_name] = _names
+ else:
+ lab_name = lab_name or f"{i + 1}"
+ params["training"][target_type][lab_name] = class_feat.names
+ else:
+ lab_name = lab_name or f"{i + 1}"
+ params["training"][target_type][lab_name] = DataHandler.to_list(
+ label_names[i]
+ )
+
+ if len(tc) == 1:
+ if tc[0] is None or tc[0] not in params["training"][target_type]:
+ p = params["training"][target_type]
+ if len(p):
+ params["training"][target_type] = params["training"][target_type][
+ next(iter(p))
+ ]
+ else:
+ params["training"][target_type] = params["training"][target_type][tc[0]]
+
+ if target_columns is not None:
+ params["training"]["target_columns"] = target_columns
+ if eval_metric is not None:
+ params["training"]["eval_metric"] = (
+ eval_metric.__name__ if callable(eval_metric) else eval_metric
+ )
+ if random_state is not None:
+ params["training"]["random_state"] = random_state
+
+ if valid_split is not None:
+ params["training"]["valid_split"] = valid_split
+
+ # Hyperparameter Tuning Information
+ if num_trials is not None:
+ params["optimization"]["num_trials"] = num_trials
+ if time_limit is not None:
+ params["optimization"]["time_limit"] = time_limit
+ # list tuned hyperparameters and their ranges
+ if study is not None:
+ search_space = study.get_trials()[0].distributions
+ params["optimization"]["search_space"] = {}
+ for param_name, dist in search_space.items():
+ dist_pars = dist.__dict__
+ dist_pars.pop("step", None)
+ params["optimization"]["search_space"][param_name] = {
+ "type": (dist.__class__.__name__),
+ **dist_pars,
+ }
+
+ # Cross-Validation Information
+ if cv is not None:
+ params["cross_validation"]["cv"] = {
+ "type": cv.__class__.__name__,
+ "params": cv.__dict__,
+ }
+
+ if outer_cv is not None:
+ params["cross_validation"]["outer_cv"] = {
+ "type": outer_cv.__class__.__name__,
+ "params": outer_cv.__dict__,
+ }
+
+ # Model Information
+ if model_name is not None:
+ params["model"]["name"] = model_name
+ params["model"]["params"] = model_params if model_params is not None else {}
+
+ # Preprocessor Information
+ if preprocessor_name is not None and preprocessor_params is not None:
+ if isinstance(preprocessor_name, list):
+ params["preprocessing"] = [
+ {"name": name, "params": params}
+ for name, params in zip(preprocessor_name, preprocessor_params)
+ ]
+ else:
+ params["preprocessing"] = {
+ "name": preprocessor_name,
+ "params": preprocessor_params,
+ }
+
+    if output_dir is not None:
+ params["misc"]["output_dir"] = output_dir
+
+ if unused_columns is not None:
+ params["misc"]["unused_columns"] = unused_columns
+
+ if outer_group_name is not None:
+ params["misc"]["outer_group_name"] = outer_group_name
+
+ if group_name is not None:
+ params["misc"]["group_name"] = group_name
+
+ with open(os.path.join(output_dir, "training_params.yaml"), "w") as f:
+ # uncomment for debugging
+ # print(yaml.safe_dump(params, sort_keys=False))
+ yaml.safe_dump(params, f, sort_keys=False)
+
+
+def _init_cv(
+ cv, groups=None, cv_params=None, sub_task="classification", shuffle=True, seed=42
+):
+ n_splits = 5
+ if isinstance(cv, int):
+ n_splits = cv
+ if groups is not None:
+ cv = LeaveOneGroupOut
+ else:
+ cv = next(iter(_CV[sub_task].values()))
+
+ if isinstance(cv, str):
+ cv = _CV[sub_task][cv]
+
+ if isinstance(cv, type):
+ if cv_params is None:
+ if cv.__module__.split(".")[0] in ["biofit", "sklearn"]:
+ cv_params = {}
+ cv_params["shuffle"] = shuffle
+ cv_params["random_state"] = seed if shuffle else None
+ cv_params["n_splits"] = n_splits
+ else:
+ raise ValueError(
+ "Please provide cv_params for custom cross validation."
+ )
+ cv_params = get_kwargs(cv_params, cv)
+ cv = cv(**cv_params)
+ return cv
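+
+# Usage sketches for `_init_cv` (illustrative, with hypothetical inputs):
+#
+#   _init_cv(5, sub_task="binary_classification")
+#   # -> StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
+#
+#   _init_cv("kfold", sub_task="regression", shuffle=False)
+#   # -> KFold(n_splits=5, shuffle=False, random_state=None)
+#
+# A splitter instance that is already constructed is returned unchanged.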
+
+
+def get_task(models):
+    # despite its name, this retrieves the configured `class_names` from the
+    # model (or from the first configured step of a pipeline), if present
+ if isinstance(models, list) and len(models):
+ model = models[0]
+ else:
+ model = models
+ if isinstance(model, (ProcessorPipeline, Pipeline)):
+ for m in model.steps:
+ val = m
+ if isinstance(val, tuple):
+ val = val[-1]
+ if hasattr(val, "config") and hasattr(val.config, "class_names"):
+ return val.config.class_names
+
+ elif hasattr(model, "config") and hasattr(model.config, "class_names"):
+ return model.config.class_names
+ return None
+
+
+def set_class_names(models, class_names):
+ if class_names is None:
+ return models
+ if isinstance(models, list) and len(models):
+ return [set_class_names(m, class_names) for m in models]
+ if isinstance(models, (ProcessorPipeline, Pipeline)):
+ for val in models.steps:
+ if isinstance(val, tuple) and hasattr(val[-1], "config"):
+ val[-1].config.class_names = class_names
+ if hasattr(val, "config"):
+ val.config.class_names = class_names
+ elif hasattr(models, "config"):
+ models.config.class_names = class_names
+ return models
+
+
+def infer_task(data=None, target=None, target_columns=None, task=None):
+ sub_task = None
+ if target is None and data is not None:
+ _, target, _, _, _, _, _, target_columns = _get_data(
+ data=data,
+ target=target,
+ valid_data=None,
+ valid_target=None,
+ groups=None,
+ group_name=None,
+ input_columns=None,
+ target_columns=target_columns,
+ format=None,
+ target_required=False,
+ )
+ if target is None and task is None:
+ raise ValueError("Target is required to infer task.")
+
+ target_dims = DataHandler.get_shape(target) if target is not None else []
+ class_names = None
+ if task is None and (is_bioset(target) or is_dataset(target)):
+ from biosets.features import BinClassLabel, ClassLabel, RegressionTarget
+
+ if isinstance(target_columns, list):
+ target_columns = target_columns[0]
+ if target_columns and target_columns in target._info.features:
+ feat = target._info.features[target_columns]
+ else:
+ feat = [
+ feat
+ for feat in target._info.features.values()
+ if isinstance(feat, (BinClassLabel, ClassLabel, RegressionTarget))
+ ][0]
+ if isinstance(feat, (BinClassLabel, ClassLabel)):
+ class_names = target._info.features[target_columns].names
+ task = "classification"
+ elif isinstance(feat, RegressionTarget):
+ task = "regression"
+ if task == "classification":
+ if len(target_dims) == 1 or target_dims[1] == 1:
+ n_classes = None
+ if target_columns and (is_bioset(target) or is_dataset(target)):
+ if isinstance(target_columns, list):
+ target_columns = target_columns[0]
+ class_names = target._info.features[target_columns].names
+ n_classes = len(class_names)
+ else:
+ n_classes = DataHandler.nunique(target)
+
+ if n_classes > 2:
+ sub_task = "multiclass_classification"
+ else:
+ sub_task = "binary_classification"
+ else:
+ sub_task = "multilabel_classification"
+ elif task == "regression":
+ if len(target_dims) > 1 and target_dims[1] > 1:
+ sub_task = "multiregression"
+ else:
+ sub_task = "regression"
+ elif task in CLASSIFICATION_TASKS or task in REGRESSION_TASKS:
+ sub_task = task
+ else:
+ raise ValueError(
+ "Invalid task. Please specify either a task, such as 'classification' or "
+ f"'regression', or a sub task, such as {CLASSIFICATION_TASKS + REGRESSION_TASKS}."
+ )
+ return sub_task, class_names
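+
+# For example (illustrative): a single-column, two-class target with
+# task="classification" yields ("binary_classification", class_names), where
+# class_names is populated only for Dataset targets with named labels. A
+# multi-column target with task="regression" yields ("multi_regression", None).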
+
+
+def _iter_preprocessor(preprocessor, condition=None):
+ if isinstance(preprocessor, (Pipeline, ProcessorPipeline)):
+ start = False
+ for n, p in preprocessor.steps:
+ if isinstance(p, BaseProcessor):
+ if condition == "sample independent":
+ if p.has_fit:
+ yield p.config.processor_name, p
+ else:
+ break
+ elif condition == "sample dependent":
+ if not p.has_fit:
+ start = True
+ yield p
+ elif start:
+ yield p
+ else:
+ yield p
+ else:
+ yield p
+ elif isinstance(preprocessor, list):
+ for p in preprocessor:
+ yield p
+ else:
+ yield preprocessor
+
+
+def preprocess(
+ preprocessor,
+ x_train,
+ y_train=None,
+ x_valid=None,
+ y_valid=None,
+ cache_dir=None,
+ transform_only=False,
+ raise_error=False,
+):
+ try:
+ if isinstance(preprocessor, (Pipeline, ProcessorPipeline)):
+ for proc in preprocessor.steps:
+ p = proc[-1] if isinstance(proc, tuple) else proc
+ if isinstance(p, BaseProcessor):
+ extra_kwargs = {
+ "cache_dir": cache_dir,
+ "load_from_cache_file": False,
+ }
+ else:
+ extra_kwargs = {}
+ if transform_only:
+ x_train = p.transform(x_train, **extra_kwargs)
+ else:
+ x_train = p.fit_transform(x_train, **extra_kwargs)
+ if x_valid is not None:
+ x_valid = p.transform(x_valid, **extra_kwargs)
+ else:
+ if isinstance(preprocessor, BaseProcessor):
+ extra_kwargs = {"cache_dir": cache_dir, "load_from_cache_file": False}
+ else:
+ extra_kwargs = {}
+ if transform_only:
+ x_train = preprocessor.transform(x_train, **extra_kwargs)
+ else:
+ x_train = preprocessor.fit_transform(x_train, **extra_kwargs)
+ if x_valid is not None:
+ x_valid = preprocessor.transform(x_valid, **extra_kwargs)
+ except ValueError:
+ if raise_error:
+ raise
+ else:
+ logger.info("Preprocessing failed, using nan_to_num")
+ _train_cols = x_train.columns.tolist()
+ x_train = np.nan_to_num(x_train)
+ # convert back to dataframe
+ x_train = pd.DataFrame(x_train, columns=_train_cols)
+ if x_valid is not None:
+ _valid_cols = x_valid.columns.tolist()
+ x_valid = np.nan_to_num(x_valid)
+ x_valid = pd.DataFrame(x_valid, columns=_valid_cols)
+ return preprocess(
+ preprocessor,
+ x_train,
+ y_train,
+ x_valid,
+ y_valid,
+ cache_dir,
+ transform_only,
+ True,
+ )
+
+ return x_train, y_train, x_valid, y_valid
+
+
+def split(cv, X, y=None, groups=None, indices=None):
+ if cv is None:
+ return [(None, None)]
+ # check if cv is a generator
+ if isinstance(cv, Generator):
+ return cv
+ if indices is not None:
+ return cv.split(
+ DataHandler.select_rows(X, indices),
+ y=DataHandler.select_column(DataHandler.select_rows(y, indices), 0)
+ if y is not None
+ else None,
+ groups=DataHandler.select_column(
+ DataHandler.select_rows(groups, indices), 0
+ )
+ if groups is not None
+ else None,
+ )
+ return cv.split(
+ X,
+ y=DataHandler.select_column(y, 0) if y is not None else None,
+ groups=DataHandler.select_column(groups, 0) if groups is not None else None,
+ )
+
+
+def get_model(models_or_name, task=None):
+ if isinstance(models_or_name, list) and len(models_or_name):
+ mon = models_or_name[0]
+ else:
+ mon = models_or_name
+
+ if mon is None:
+ return None
+
+ model = None
+ model_cls = None
+ if isinstance(mon, str):
+ model_name = mon
+ model_cls = _MODELS[task][model_name]
+ elif isinstance(mon, type):
+ model_cls = mon
+ model_name = mon.__module__.split(".")[-1]
+ elif isinstance(mon, Model):
+ model_name = mon.config.processor_name
+        model_cls = mon.__class__
+ model = mon
+ elif isinstance(mon, (ProcessorPipeline, Pipeline)):
+ for m in _flatten_pipeline(mon):
+ out = get_model(m, task)
+ if out is not None:
+ return out
+ return None
+ elif hasattr(mon, "predict"):
+ model_cls = mon.__class__
+ model_name = mon.__class__.__module__.split(".")[-1]
+ model = mon
+ else:
+ return None
+ return model, model_cls, model_name
+
+
+def get_model_info(model, task=None):
+    out = get_model(model, task=task)
+ if out is not None:
+ model, _, _ = out
+ if hasattr(model, "config"):
+ return asdict(model.config)
+ elif hasattr(model, "get_params"):
+ return model.get_params()
+ else:
+ return model.__dict__
+
+ return {}
+
+
+def _split_data(
+ data, target, valid_split, seed, shuffle, stratify_by_column, keep_in_memory=True
+):
+ def _convert_to_in_memory(data, indices=None):
+ from datasets import InMemoryTable, MemoryMappedTable
+
+ try:
+ if data._indices or indices is not None:
+ if indices is None:
+ indices = data._indices.column("indices")
+
+ if isinstance(data._data, MemoryMappedTable):
+ table = MemoryMappedTable._apply_replays(
+ data._data.table, data._data.replays
+ )
+ else:
+ table = data._data.table
+ table = table.take(indices)
+ data._indices = None
+ data._data = InMemoryTable(table)
+
+ return data
+ except pa.ArrowInvalid:
+ logger.error(
+ "Table is too large for row selection. Please set keep_in_memory=False or "
+ "shuffle=False until https://github.com/apache/arrow/issues/25822 has been resolved"
+ )
+ raise
+
+ if is_bioset(data) or is_dataset(data):
+ if keep_in_memory:
+ _convert_to_in_memory(data)
+
+ train_data, valid_data = data.train_test_split(
+ test_size=valid_split,
+ seed=seed if shuffle else None,
+ shuffle=shuffle,
+ stratify_by_column=stratify_by_column if shuffle else None,
+ ).values()
+ if not shuffle and not stratify_by_column:
+ target, valid_target = target.train_test_split(
+ test_size=valid_split,
+ seed=seed,
+ shuffle=shuffle,
+ stratify_by_column=stratify_by_column,
+            ).values()
+ else:
+ train_indices = train_data._indices.column("indices")
+ if valid_data._indices:
+ valid_indices = valid_data._indices.column("indices")
+ else:
+ valid_indices = pa.array(
+ list(
+ sorted(
+ set(range(train_data.num_rows))
+ - set(DataHandler.to_numpy(train_indices).tolist())
+ )
+ )
+ )
+
+ if keep_in_memory:
+ train_data = _convert_to_in_memory(train_data, train_indices)
+ valid_data = _convert_to_in_memory(valid_data, valid_indices)
+                # derive the validation target from a copy before the training
+                # target is narrowed in place
+                valid_target = _convert_to_in_memory(
+                    copy.deepcopy(target), valid_indices
+                )
+                target = _convert_to_in_memory(target, train_indices)
+ else:
+ valid_target = copy.deepcopy(target)
+ valid_target = valid_target.select(valid_indices)
+ target = target.select(train_indices)
+ else:
+ if stratify_by_column is not None:
+ stratify = DataHandler.to_numpy(data, stratify_by_column)
+ else:
+ stratify = target
+ train_data, valid_data, target, valid_target = train_test_split(
+ data,
+ target,
+ test_size=valid_split,
+ random_state=seed,
+ shuffle=shuffle,
+ stratify=stratify,
+ )
+ return train_data, valid_data, target, valid_target
+
+
+def _get_label_names(data, target, target_columns, u_labels=None):
+ if isinstance(target_columns, list):
+ tc = target_columns[0]
+ else:
+ tc = target_columns
+ if (is_bioset(data) or is_dataset(data)) and tc in data._info.features:
+ labels = data._info.features[tc].names
+ elif (is_bioset(target) or is_dataset(target)) and tc in target._info.features:
+ labels = target._info.features[tc].names
+ else:
+ labels = DataHandler.to_list(u_labels)
+ return labels
+
+
+def _get_data(
+ data,
+ target=None,
+ valid_data=None,
+ valid_target=None,
+ valid_split=None,
+ groups=None,
+ outer_groups=None,
+ group_name=None,
+ outer_group_name=None,
+ input_columns="auto",
+ target_columns="auto",
+ valid_input_columns="auto",
+ valid_target_columns="auto",
+ target_required=True,
+ shuffle=True,
+ seed=42,
+ keep_in_memory=True,
+ format="pandas",
+):
+ if DataHandler.supports_named_columns(data):
+ use_auto_target_columns = target_columns == "auto"
+    else:
+        use_auto_target_columns = False
+ input_columns = None if input_columns == "auto" else input_columns
+ target_columns = None if target_columns == "auto" else target_columns
+
+ if valid_data is not None and not DataHandler.supports_named_columns(valid_data):
+ valid_input_columns = (
+ None if valid_input_columns == "auto" else valid_input_columns
+ )
+ valid_target_columns = (
+ None if valid_target_columns == "auto" else valid_target_columns
+ )
+
+ def get_target_if_none(data, target_columns):
+ target = None
+ if target_columns is not None:
+ sel_kwargs = {}
+ if is_bioset(data) or is_dataset(data):
+ sel_kwargs = {"keep_old_fingerprint": False}
+ if isinstance(target_columns, str):
+ target = DataHandler.select_columns(
+ data, [target_columns], **sel_kwargs
+ )
+ else:
+ target = DataHandler.select_columns(data, target_columns, **sel_kwargs)
+ return target
+
+ def handle_target_columns(data, target_columns, required=True):
+ if target_columns == "auto":
+ if is_bioset(data):
+ from biosets import get_target_col_names
+
+ target_columns = get_target_col_names(data)
+ elif required:
+ raise ValueError(
+ "`data` must be a `Dataset` to automatically infer target columns. "
+ "Please provide `target_columns` or `target`."
+ )
+ if isinstance(target_columns, (type, tuple)):
+ if is_bioset(data) or is_dataset(data):
+ target_columns = [
+ k
+ for k, v in data._info.features.items()
+ if isinstance(v, target_columns)
+ ]
+ if target_columns == "auto":
+ target_columns = None
+ if target_columns is None and required:
+ raise ValueError("Target columns must be provided if target is None.")
+
+ if isinstance(target_columns, list) and len(target_columns) == 1:
+ target_columns = target_columns[0]
+ return target_columns
+
+ def handle_input_columns(data, input_columns):
+ if is_bioset(data) or is_dataset(data):
+ if input_columns == "auto" and is_bioset(data):
+ from biosets import get_data_col_names
+
+ input_columns = get_data_col_names(data)
+ elif isinstance(input_columns, (type, tuple)):
+ input_columns = [
+ k
+ for k, v in data._info.features.items()
+ if isinstance(v, input_columns)
+ ]
+ if input_columns == "auto":
+ input_columns = None
+
+ return input_columns
+
+ x_train = None
+ y_train = None
+ y_valid = None
+ x_valid = None
+
+ if isinstance(data, dict):
+ if isinstance(valid_split, str):
+ valid_data = data.get(valid_split)
+ else:
+ valid_data = data.get(
+ "valid", data.get("test", data.get("validation"))
+ )
+ if len(data) < 3:
+ train_split = [
+ k
+ for k in data.keys()
+ if k not in ["valid", "test", "validation", valid_split]
+ ]
+ if not train_split:
+ raise ValueError("Train split not found in dataset.")
+ data = data.get(train_split[0])
+ else:
+ data = data.get("train", data.get("training"))
+
+ if target is None:
+ target_columns = handle_target_columns(
+ data, target_columns, required=target_required
+ )
+ y_train = get_target_if_none(data, target_columns)
+
+ else:
+ y_train = target
+ if DataHandler.supports_named_columns(target):
+ target_columns = DataHandler.get_column_names(target)[0]
+
+ if valid_target is None and valid_data is not None:
+ if valid_target_columns is None or valid_target_columns == "auto":
+ valid_target_columns = handle_target_columns(
+ valid_data,
+ "auto" if use_auto_target_columns else target_columns,
+ required=False,
+ )
+ if valid_target_columns is not None:
+ y_valid = get_target_if_none(valid_data, valid_target_columns)
+ else:
+ y_valid = get_target_if_none(valid_data, target_columns)
+ if valid_target_columns is not None and target_columns is not None:
+ if valid_target_columns != target_columns:
+ y_valid = DataHandler.set_column_names(
+ y_valid, target_columns
+ )
+ elif valid_target_columns is not None:
+ target_columns = valid_target_columns
+
+ elif valid_target is not None:
+ y_valid = valid_target
+
+ if group_name is not None and groups is None:
+ groups = DataHandler.select_column(data, group_name)
+ if outer_group_name is not None and outer_groups is None:
+ outer_groups = DataHandler.select_column(data, outer_group_name)
+ if (
+ group_name is None
+ and groups is not None
+ and DataHandler.supports_named_columns(groups)
+ ):
+ group_name = DataHandler.get_column_names(groups)[0]
+ if (
+ outer_group_name is None
+ and outer_groups is not None
+ and DataHandler.supports_named_columns(outer_groups)
+ ):
+ outer_group_name = DataHandler.get_column_names(outer_groups)[0]
+
+ if valid_data is None and valid_split is not None:
+ data, valid_data, y_train, y_valid = _split_data(
+ data,
+ y_train,
+ valid_split,
+ seed=seed,
+ shuffle=shuffle,
+ stratify_by_column=group_name,
+ keep_in_memory=keep_in_memory,
+ )
+
+ input_columns = handle_input_columns(data, input_columns)
+ if valid_data is not None:
+ valid_input_columns = handle_input_columns(valid_data, valid_input_columns)
+ if input_columns is not None:
+ sel_kwargs = {}
+ if is_bioset(data) or is_dataset(data):
+ sel_kwargs = {"keep_old_fingerprint": False}
+ x_train = DataHandler.select_columns(data, input_columns, **sel_kwargs)
+ if valid_data is not None:
+ sel_kwargs = {}
+ if is_bioset(valid_data) or is_dataset(valid_data):
+ sel_kwargs = {"keep_old_fingerprint": False}
+ valid_input_columns = valid_input_columns or input_columns
+ x_valid = DataHandler.select_columns(
+ valid_data, valid_input_columns, **sel_kwargs
+ )
+ elif (
+ target is not None
+ and DataHandler.supports_named_columns(data)
+ and DataHandler.supports_named_columns(target)
+ ):
+ _input_columns = DataHandler.get_column_names(data)
+ _target_columns = DataHandler.get_column_names(target)
+ if group_name is not None:
+ _target_columns.append(group_name)
+ if outer_group_name is not None:
+ _target_columns.append(outer_group_name)
+ _intersecting_cols = set(_input_columns) & set(_target_columns)
+ if _intersecting_cols:
+ x_train = DataHandler.drop_columns(data, list(_intersecting_cols))
+ if (
+ valid_data is not None
+ and DataHandler.supports_named_columns(valid_data)
+ and DataHandler.supports_named_columns(x_train)
+ ):
+ sel_kwargs = {}
+ if is_bioset(valid_data) or is_dataset(valid_data):
+ sel_kwargs = {"keep_old_fingerprint": False}
+ x_valid = DataHandler.select_columns(
+ valid_data, DataHandler.get_column_names(x_train), **sel_kwargs
+ )
+ else:
+ x_train = data
+ if valid_data is not None:
+ x_valid = valid_data
+ input_columns = DataHandler.get_column_names(x_train)
+ else:
+ x_train = data
+ if valid_data is not None:
+ x_valid = valid_data
+ input_columns = DataHandler.get_column_names(x_train)
+
+ if format is not None:
+ x_train = DataHandler.to_format(x_train, format)
+ if y_train is not None:
+ y_train = DataHandler.to_format(y_train, format)
+ if groups is not None:
+ groups = DataHandler.to_format(groups, format)
+ if outer_groups is not None:
+ outer_groups = DataHandler.to_format(outer_groups, format)
+ if x_valid is not None:
+ x_valid = DataHandler.to_format(x_valid, format)
+ if y_valid is not None:
+ y_valid = DataHandler.to_format(y_valid, format)
+
+ return (
+ x_train,
+ y_train,
+ x_valid,
+ y_valid,
+ groups,
+ outer_groups,
+ input_columns,
+ target_columns,
+ )
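+
+# Example call (illustrative; the "label" column name is hypothetical): extract
+# train/validation features and targets as pandas objects, along with the
+# resolved column names:
+#
+#   (x_train, y_train, x_valid, y_valid,
+#    groups, outer_groups, input_cols, target_cols) = _get_data(
+#       data, target_columns="label", valid_split=0.2, format="pandas"
+#   )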
diff --git a/src/biofit/utils/__init__.py b/src/biofit/utils/__init__.py
new file mode 100644
index 0000000..89b7a6e
--- /dev/null
+++ b/src/biofit/utils/__init__.py
@@ -0,0 +1,60 @@
+# ruff: noqa
+from .file_utils import (
+ add_end_docstrings,
+ add_start_docstrings,
+ estimate_dataset_size,
+ has_ext,
+ has_separator,
+ hash_url_to_filename,
+ is_file_name,
+ is_local_path,
+ is_relative_path,
+ is_remote_url,
+ move_temp_file,
+ url_or_path_join,
+ url_or_path_parent,
+)
+from .fingerprint import (
+ Hasher,
+ _build_cache_dir,
+ disable_caching,
+ enable_caching,
+ fingerprint_from_data,
+ fingerprint_from_kwargs,
+ generate_cache_dir,
+ get_cache_file_name,
+ is_caching_enabled,
+ update_fingerprint,
+)
+from .logging import (
+ disable_progress_bar,
+ enable_progress_bar,
+ set_verbosity,
+ set_verbosity_debug,
+ set_verbosity_error,
+ set_verbosity_info,
+ set_verbosity_warning,
+ silence,
+ unsilence,
+)
+from .py_util import (
+ as_py,
+ enable_full_determinism,
+ set_seed,
+)
+from .recorder import (
+ load_module_or_class,
+ record_step,
+)
+from .table_util import (
+ concat_blocks,
+ determine_upcast,
+ init_arrow_buffer_and_writer,
+ is_binary_like,
+ is_fixed_width,
+ is_large_binary_like,
+ read_arrow_table,
+ upcast_tables,
+ write_arrow_table,
+)
+from .types import Unset
diff --git a/src/biofit/utils/_dill.py b/src/biofit/utils/_dill.py
new file mode 100644
index 0000000..8ef2698
--- /dev/null
+++ b/src/biofit/utils/_dill.py
@@ -0,0 +1,469 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Modified by: Patrick Smyth
+# Date: 2024
+# Summary of changes:
+# - removed `config` to track the dill version. Version is parsed directly from the dill
+# package.
+"""Extends `dill` to support pickling more types and produce more consistent dumps."""
+
+import os
+import sys
+from io import BytesIO
+from types import CodeType, FunctionType
+
+import dill
+from packaging import version
+
+
+class Pickler(dill.Pickler):
+ dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy())
+ _legacy_no_dict_keys_sorting = False
+
+ def save(self, obj, save_persistent_id=True):
+ obj_type = type(obj)
+ if obj_type not in self.dispatch:
+ if "regex" in sys.modules:
+ import regex # type: ignore
+
+ if obj_type is regex.Pattern:
+ pklregister(obj_type)(_save_regexPattern)
+ if "spacy" in sys.modules:
+ import spacy # type: ignore
+
+ if issubclass(obj_type, spacy.Language):
+ pklregister(obj_type)(_save_spacyLanguage)
+ if "tiktoken" in sys.modules:
+ import tiktoken # type: ignore
+
+ if obj_type is tiktoken.Encoding:
+ pklregister(obj_type)(_save_tiktokenEncoding)
+ if "torch" in sys.modules:
+ import torch # type: ignore
+
+ if issubclass(obj_type, torch.Tensor):
+ pklregister(obj_type)(_save_torchTensor)
+
+ if obj_type is torch.Generator:
+ pklregister(obj_type)(_save_torchGenerator)
+
+ # Unwrap `torch.compile`-ed modules
+ if issubclass(obj_type, torch.nn.Module):
+ obj = getattr(obj, "_orig_mod", obj)
+ if "transformers" in sys.modules:
+ import transformers # type: ignore
+
+ if issubclass(obj_type, transformers.PreTrainedTokenizerBase):
+ pklregister(obj_type)(_save_transformersPreTrainedTokenizerBase)
+
+ # Unwrap `torch.compile`-ed functions
+ if obj_type is FunctionType:
+ obj = getattr(obj, "_torchdynamo_orig_callable", obj)
+ dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id)
+
+ def _batch_setitems(self, items):
+ if self._legacy_no_dict_keys_sorting:
+ return super()._batch_setitems(items)
+ # Ignore the order of keys in a dict
+ try:
+ # Faster, but fails for unorderable elements
+ items = sorted(items)
+ except Exception: # TypeError, decimal.InvalidOperation, etc.
+ from datasets.fingerprint import Hasher
+
+ items = sorted(items, key=lambda x: Hasher.hash(x[0]))
+ dill.Pickler._batch_setitems(self, items)
+
+ def memoize(self, obj):
+ # Don't memoize strings since two identical strings can have different Python ids
+ if type(obj) is not str: # noqa: E721
+ dill.Pickler.memoize(self, obj)
+
+
+def pklregister(t):
+ """Register a custom reducer for the type."""
+
+ def proxy(func):
+ Pickler.dispatch[t] = func
+ return func
+
+ return proxy
+
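+# Usage sketch: register a reducer for a custom type so it pickles
+# deterministically (the `Point` class here is hypothetical):
+#
+#   @pklregister(Point)
+#   def _save_point(pickler, obj):
+#       pickler.save_reduce(Point, (obj.x, obj.y), obj=obj)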
+
+def dump(obj, file):
+ """Pickle an object to a file."""
+ Pickler(file, recurse=True).dump(obj)
+
+
+def dumps(obj):
+ """Pickle an object to a string."""
+ file = BytesIO()
+ dump(obj, file)
+ return file.getvalue()
+
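+# Because `Pickler._batch_setitems` sorts dict items, `dumps` is stable under
+# key insertion order, e.g. (illustrative):
+#
+#   assert dumps({"a": 1, "b": 2}) == dumps({"b": 2, "a": 1})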
+
+if version.parse(dill.__version__) < version.parse("0.3.6"):
+
+ def log(pickler, msg):
+ dill._dill.log.info(msg)
+
+elif version.parse(dill.__version__).release[:3] in [
+ version.parse("0.3.6").release,
+ version.parse("0.3.7").release,
+ version.parse("0.3.8").release,
+]:
+
+ def log(pickler, msg):
+ dill._dill.logger.trace(pickler, msg)
+
+
+@pklregister(set)
+def _save_set(pickler, obj):
+ log(pickler, f"Se: {obj}")
+ try:
+ # Faster, but fails for unorderable elements
+ args = (sorted(obj),)
+ except Exception: # TypeError, decimal.InvalidOperation, etc.
+ from datasets.fingerprint import Hasher
+
+ args = (sorted(obj, key=Hasher.hash),)
+
+ pickler.save_reduce(set, args, obj=obj)
+ log(pickler, "# Se")
+
+
+def _save_regexPattern(pickler, obj):
+ import regex # type: ignore
+
+ log(pickler, f"Re: {obj}")
+ args = (obj.pattern, obj.flags)
+ pickler.save_reduce(regex.compile, args, obj=obj)
+ log(pickler, "# Re")
+
+
+def _save_tiktokenEncoding(pickler, obj):
+ import tiktoken # type: ignore
+
+ log(pickler, f"Enc: {obj}")
+ args = (obj.name, obj._pat_str, obj._mergeable_ranks, obj._special_tokens)
+ pickler.save_reduce(tiktoken.Encoding, args, obj=obj)
+ log(pickler, "# Enc")
+
+
+def _save_torchTensor(pickler, obj):
+ import torch # type: ignore
+
+ # `torch.from_numpy` is not picklable in `torch>=1.11.0`
+ def create_torchTensor(np_array, dtype=None):
+ tensor = torch.from_numpy(np_array)
+ if dtype:
+ tensor = tensor.type(dtype)
+ return tensor
+
+ log(pickler, f"To: {obj}")
+ if obj.dtype == torch.bfloat16:
+ args = (obj.detach().to(torch.float).cpu().numpy(), torch.bfloat16)
+ else:
+ args = (obj.detach().cpu().numpy(),)
+ pickler.save_reduce(create_torchTensor, args, obj=obj)
+ log(pickler, "# To")
+
+
+def _save_torchGenerator(pickler, obj):
+ import torch # type: ignore
+
+ def create_torchGenerator(state):
+ generator = torch.Generator()
+ generator.set_state(state)
+ return generator
+
+ log(pickler, f"Ge: {obj}")
+ args = (obj.get_state(),)
+ pickler.save_reduce(create_torchGenerator, args, obj=obj)
+ log(pickler, "# Ge")
+
+
+def _save_spacyLanguage(pickler, obj):
+ import spacy # type: ignore
+
+ def create_spacyLanguage(config, bytes):
+ lang_cls = spacy.util.get_lang_class(config["nlp"]["lang"])
+ lang_inst = lang_cls.from_config(config)
+ return lang_inst.from_bytes(bytes)
+
+ log(pickler, f"Sp: {obj}")
+ args = (obj.config, obj.to_bytes())
+ pickler.save_reduce(create_spacyLanguage, args, obj=obj)
+ log(pickler, "# Sp")
+
+
+def _save_transformersPreTrainedTokenizerBase(pickler, obj):
+ log(pickler, f"Tok: {obj}")
+ # Ignore the `cache` attribute
+ state = obj.__dict__
+ if "cache" in state and isinstance(state["cache"], dict):
+ state["cache"] = {}
+ pickler.save_reduce(type(obj), (), state=state, obj=obj)
+ log(pickler, "# Tok")
+
+
+if version.parse(dill.__version__) < version.parse("0.3.6"):
+
+ @pklregister(CodeType)
+ def _save_code(pickler, obj):
+ """
+ From dill._dill.save_code
+ This is a modified version that removes the origin (filename + line no.)
+ of functions created in notebooks or shells for example.
+ """
+ dill._dill.log.info(f"Co: {obj}")
+ # The filename of a function is the .py file where it is defined.
+ # Filenames of functions created in notebooks or shells start with '<'
+ # ex: '<ipython-input-...>' for ipython, and '<stdin>' for shell
+ # Filenames of functions created in ipykernel look like
+ # f"{tempdir}/ipykernel_{id1}/{id2}.py"
+ # Moreover lambda functions have a special name: '<lambda>'
+ # ex: (lambda x: x).__code__.co_name == "<lambda>"  # True
+ #
+ # For the hashing mechanism we ignore where the function has been defined
+ # More specifically:
+ # - we ignore the filename of special functions (filename starts with '<')
+ # - we always ignore the line number
+ # - we only use the base name of the file instead of the whole path,
+ # to be robust in case a script is moved for example.
+ #
+ # Only those two lines are different from the original implementation:
+ co_filename = (
+ ""
+ if obj.co_filename.startswith("<")
+ or (
+ len(obj.co_filename.split(os.path.sep)) > 1
+ and obj.co_filename.split(os.path.sep)[-2].startswith("ipykernel_")
+ )
+ or obj.co_name == "<lambda>"
+ else os.path.basename(obj.co_filename)
+ )
+ co_firstlineno = 1
+ # The rest is the same as in the original dill implementation
+ if dill._dill.PY3:
+ if hasattr(obj, "co_posonlyargcount"):
+ args = (
+ obj.co_argcount,
+ obj.co_posonlyargcount,
+ obj.co_kwonlyargcount,
+ obj.co_nlocals,
+ obj.co_stacksize,
+ obj.co_flags,
+ obj.co_code,
+ obj.co_consts,
+ obj.co_names,
+ obj.co_varnames,
+ co_filename,
+ obj.co_name,
+ co_firstlineno,
+ obj.co_lnotab,
+ obj.co_freevars,
+ obj.co_cellvars,
+ )
+ else:
+ args = (
+ obj.co_argcount,
+ obj.co_kwonlyargcount,
+ obj.co_nlocals,
+ obj.co_stacksize,
+ obj.co_flags,
+ obj.co_code,
+ obj.co_consts,
+ obj.co_names,
+ obj.co_varnames,
+ co_filename,
+ obj.co_name,
+ co_firstlineno,
+ obj.co_lnotab,
+ obj.co_freevars,
+ obj.co_cellvars,
+ )
+ else:
+ args = (
+ obj.co_argcount,
+ obj.co_nlocals,
+ obj.co_stacksize,
+ obj.co_flags,
+ obj.co_code,
+ obj.co_consts,
+ obj.co_names,
+ obj.co_varnames,
+ co_filename,
+ obj.co_name,
+ co_firstlineno,
+ obj.co_lnotab,
+ obj.co_freevars,
+ obj.co_cellvars,
+ )
+ pickler.save_reduce(CodeType, args, obj=obj)
+ dill._dill.log.info("# Co")
+ return
+
+elif version.parse(dill.__version__).release[:3] in [
+ version.parse("0.3.6").release,
+ version.parse("0.3.7").release,
+ version.parse("0.3.8").release,
+]:
+ # From: https://github.com/uqfoundation/dill/blob/dill-0.3.6/dill/_dill.py#L1104
+ @pklregister(CodeType)
+ def save_code(pickler, obj):
+ dill._dill.logger.trace(pickler, "Co: %s", obj)
+
+ ############################################################################################################
+ # Modification here for huggingface/datasets
+ # The filename of a function is the .py file where it is defined.
+ # Filenames of functions created in notebooks or shells start with '<'
+ # ex: '<ipython-input-...>' for ipython, and '<stdin>' for shell
+ # Filenames of functions created in ipykernel look like
+ # f"{tempdir}/ipykernel_{id1}/{id2}.py"
+ # Moreover lambda functions have a special name: '<lambda>'
+ # ex: (lambda x: x).__code__.co_name == "<lambda>"  # True
+ #
+ # For the hashing mechanism we ignore where the function has been defined
+ # More specifically:
+ # - we ignore the filename of special functions (filename starts with '<')
+ # - we always ignore the line number
+ # - we only use the base name of the file instead of the whole path,
+ # to be robust in case a script is moved for example.
+ #
+ # Only those two lines are different from the original implementation:
+ co_filename = (
+ ""
+ if obj.co_filename.startswith("<")
+ or (
+ len(obj.co_filename.split(os.path.sep)) > 1
+ and obj.co_filename.split(os.path.sep)[-2].startswith("ipykernel_")
+ )
+ or obj.co_name == "<lambda>"
+ else os.path.basename(obj.co_filename)
+ )
+ co_firstlineno = 1
+ # The rest is the same as in the original dill implementation, except for the replacements:
+ # - obj.co_filename => co_filename
+ # - obj.co_firstlineno => co_firstlineno
+ ############################################################################################################
+
+ if hasattr(obj, "co_endlinetable"): # python 3.11a (20 args)
+ args = (
+ obj.co_lnotab, # for < python 3.10 [not counted in args]
+ obj.co_argcount,
+ obj.co_posonlyargcount,
+ obj.co_kwonlyargcount,
+ obj.co_nlocals,
+ obj.co_stacksize,
+ obj.co_flags,
+ obj.co_code,
+ obj.co_consts,
+ obj.co_names,
+ obj.co_varnames,
+ co_filename, # Modification for huggingface/datasets ############################################
+ obj.co_name,
+ obj.co_qualname,
+ co_firstlineno, # Modification for huggingface/datasets #########################################
+ obj.co_linetable,
+ obj.co_endlinetable,
+ obj.co_columntable,
+ obj.co_exceptiontable,
+ obj.co_freevars,
+ obj.co_cellvars,
+ )
+ elif hasattr(obj, "co_exceptiontable"): # python 3.11 (18 args)
+ args = (
+ obj.co_lnotab, # for < python 3.10 [not counted in args]
+ obj.co_argcount,
+ obj.co_posonlyargcount,
+ obj.co_kwonlyargcount,
+ obj.co_nlocals,
+ obj.co_stacksize,
+ obj.co_flags,
+ obj.co_code,
+ obj.co_consts,
+ obj.co_names,
+ obj.co_varnames,
+ co_filename, # Modification for huggingface/datasets ############################################
+ obj.co_name,
+ obj.co_qualname,
+ co_firstlineno, # Modification for huggingface/datasets #########################################
+ obj.co_linetable,
+ obj.co_exceptiontable,
+ obj.co_freevars,
+ obj.co_cellvars,
+ )
+ elif hasattr(obj, "co_linetable"): # python 3.10 (16 args)
+ args = (
+ obj.co_lnotab, # for < python 3.10 [not counted in args]
+ obj.co_argcount,
+ obj.co_posonlyargcount,
+ obj.co_kwonlyargcount,
+ obj.co_nlocals,
+ obj.co_stacksize,
+ obj.co_flags,
+ obj.co_code,
+ obj.co_consts,
+ obj.co_names,
+ obj.co_varnames,
+ co_filename, # Modification for huggingface/datasets ############################################
+ obj.co_name,
+ co_firstlineno, # Modification for huggingface/datasets #########################################
+ obj.co_linetable,
+ obj.co_freevars,
+ obj.co_cellvars,
+ )
+ elif hasattr(obj, "co_posonlyargcount"): # python 3.8 (16 args)
+ args = (
+ obj.co_argcount,
+ obj.co_posonlyargcount,
+ obj.co_kwonlyargcount,
+ obj.co_nlocals,
+ obj.co_stacksize,
+ obj.co_flags,
+ obj.co_code,
+ obj.co_consts,
+ obj.co_names,
+ obj.co_varnames,
+ co_filename, # Modification for huggingface/datasets ############################################
+ obj.co_name,
+ co_firstlineno, # Modification for huggingface/datasets #########################################
+ obj.co_lnotab,
+ obj.co_freevars,
+ obj.co_cellvars,
+ )
+ else: # python 3.7 (15 args)
+ args = (
+ obj.co_argcount,
+ obj.co_kwonlyargcount,
+ obj.co_nlocals,
+ obj.co_stacksize,
+ obj.co_flags,
+ obj.co_code,
+ obj.co_consts,
+ obj.co_names,
+ obj.co_varnames,
+ co_filename, # Modification for huggingface/datasets ############################################
+ obj.co_name,
+ co_firstlineno, # Modification for huggingface/datasets #########################################
+ obj.co_lnotab,
+ obj.co_freevars,
+ obj.co_cellvars,
+ )
+
+ pickler.save_reduce(dill._dill._create_code, args, obj=obj)
+ dill._dill.logger.trace(pickler, "# Co")
+ return
diff --git a/src/biofit/utils/doc.py b/src/biofit/utils/doc.py
new file mode 100644
index 0000000..8bcd3b7
--- /dev/null
+++ b/src/biofit/utils/doc.py
@@ -0,0 +1,1213 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Doc utilities: Utilities related to documentation
+"""
+
+import functools
+import re
+import types
+
+
+def add_start_docstrings(*docstr):
+ def docstring_decorator(fn):
+ fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
+ return fn
+
+ return docstring_decorator
+
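+# Illustrative use: prepend a shared preamble to several functions' docstrings
+# (the decorator simply joins the given strings onto `fn.__doc__`):
+#
+#   @add_start_docstrings("Shared intro.\n\n")
+#   def fn():
+#       """Function-specific details."""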
+
+def add_start_docstrings_to_model_forward(*docstr):
+ def docstring_decorator(fn):
+ docstring = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
+ class_name = f"[`{fn.__qualname__.split('.')[0]}`]"
+ intro = f" The {class_name} forward method, overrides the `__call__` special method."
+ note = r"""
+
+ <Tip>
+
+ Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
+ instance afterwards instead of this since the former takes care of running the pre and post processing steps while
+ the latter silently ignores them.
+
+ </Tip>
+"""
+
+ fn.__doc__ = intro + note + docstring
+ return fn
+
+ return docstring_decorator
+
+
+def add_end_docstrings(*docstr):
+ def docstring_decorator(fn):
+ fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
+ return fn
+
+ return docstring_decorator
+
+
+PT_RETURN_INTRODUCTION = r"""
+ Returns:
+ [`{full_output_type}`] or `tuple(torch.FloatTensor)`: A [`{full_output_type}`] or a tuple of
+ `torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
+ elements depending on the configuration ([`{_config_class}`]) and inputs.
+
+"""
+
+
+TF_RETURN_INTRODUCTION = r"""
+ Returns:
+ [`{full_output_type}`] or `tuple(tf.Tensor)`: A [`{full_output_type}`] or a tuple of `tf.Tensor` (if
+ `return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the
+ configuration ([`{_config_class}`]) and inputs.
+
+"""
+
+
+def _get_indent(t):
+ """Returns the indentation in the first line of t"""
+ search = re.search(r"^(\s*)\S", t)
+ return "" if search is None else search.groups()[0]
+
+
+def _convert_output_args_doc(output_args_doc):
+ """Convert output_args_doc to display properly."""
+ # Split output_arg_doc in blocks argument/description
+ indent = _get_indent(output_args_doc)
+ blocks = []
+ current_block = ""
+ for line in output_args_doc.split("\n"):
+ # If the indent is the same as the beginning, the line is the name of a new arg.
+ if _get_indent(line) == indent:
+ if len(current_block) > 0:
+ blocks.append(current_block[:-1])
+ current_block = f"{line}\n"
+ else:
+ # Otherwise it's part of the description of the current arg.
+ # We need to remove 2 spaces from the indentation.
+ current_block += f"{line[2:]}\n"
+ blocks.append(current_block[:-1])
+
+ # Format each block for proper rendering
+ for i in range(len(blocks)):
+ blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i])
+ blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i])
+
+ return "\n".join(blocks)
+
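+# For example (illustrative), an argument block such as
+#
+#   loss (`torch.FloatTensor`):
+#       The computed loss.
+#
+# is rewritten into a single markdown bullet:
+#
+#   - **loss** (`torch.FloatTensor`) -- The computed loss.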
+
+def _prepare_output_docstrings(output_type, _config_class, min_indent=None):
+ """
+ Prepares the return part of the docstring using `output_type`.
+ """
+ output_docstring = output_type.__doc__
+
+ # Remove the head of the docstring to keep the list of args only
+ lines = output_docstring.split("\n")
+ i = 0
+ while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
+ i += 1
+ if i < len(lines):
+ params_docstring = "\n".join(lines[(i + 1) :])
+ params_docstring = _convert_output_args_doc(params_docstring)
+ else:
+ raise ValueError(
+ f"No `Args` or `Parameters` section is found in the docstring of `{output_type.__name__}`. Make sure it has "
+ "docstring and contain either `Args` or `Parameters`."
+ )
+
+ # Add the return introduction
+ full_output_type = f"{output_type.__module__}.{output_type.__name__}"
+ intro = (
+ TF_RETURN_INTRODUCTION
+ if output_type.__name__.startswith("TF")
+ else PT_RETURN_INTRODUCTION
+ )
+ intro = intro.format(full_output_type=full_output_type, _config_class=_config_class)
+ result = intro + params_docstring
+
+ # Apply minimum indent if necessary
+ if min_indent is not None:
+ lines = result.split("\n")
+ # Find the indent of the first nonempty line
+ i = 0
+ while len(lines[i]) == 0:
+ i += 1
+ indent = len(_get_indent(lines[i]))
+ # If too small, add indentation to all nonempty lines
+ if indent < min_indent:
+ to_add = " " * (min_indent - indent)
+ lines = [(f"{to_add}{line}" if len(line) > 0 else line) for line in lines]
+ result = "\n".join(lines)
+
+ return result
+
+
+FAKE_MODEL_DISCLAIMER = """
+ <Tip warning={true}>
+
+ This example uses a random model as the real ones are all very big. To get proper results, you should use
+ {real_checkpoint} instead of {fake_checkpoint}. If you get out-of-memory when loading that checkpoint, you can try
+ adding `device_map="auto"` in the `from_pretrained` call.
+
+ </Tip>
+"""
+
+
+PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer(
+ ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
+ ... )
+
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits
+
+ >>> predicted_token_class_ids = logits.argmax(-1)
+
+ >>> # Note that tokens are classified rather than input words, which means that
+ >>> # there might be more predicted token classes than words.
+ >>> # Multiple token classes might account for the same word
+ >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
+ >>> predicted_tokens_classes
+ {expected_output}
+
+ >>> labels = predicted_token_class_ids
+ >>> loss = model(**inputs, labels=labels).loss
+ >>> round(loss.item(), 2)
+ {expected_loss}
+ ```
+"""
+
+PT_QUESTION_ANSWERING_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+
+ >>> inputs = tokenizer(question, text, return_tensors="pt")
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> answer_start_index = outputs.start_logits.argmax()
+ >>> answer_end_index = outputs.end_logits.argmax()
+
+ >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
+ >>> tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
+ {expected_output}
+
+ >>> # target is "nice puppet"
+ >>> target_start_index = torch.tensor([{qa_target_start_index}])
+ >>> target_end_index = torch.tensor([{qa_target_end_index}])
+
+ >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
+ >>> loss = outputs.loss
+ >>> round(loss.item(), 2)
+ {expected_loss}
+ ```
+"""
+
+PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
+ Example of single-label classification:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, {model_class}
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits
+
+ >>> predicted_class_id = logits.argmax().item()
+ >>> model.config.id2label[predicted_class_id]
+ {expected_output}
+
+ >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
+ >>> num_labels = len(model.config.id2label)
+ >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)
+
+ >>> labels = torch.tensor([1])
+ >>> loss = model(**inputs, labels=labels).loss
+ >>> round(loss.item(), 2)
+ {expected_loss}
+ ```
+
+ Example of multi-label classification:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, {model_class}
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}", problem_type="multi_label_classification")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits
+
+ >>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(dim=0) > 0.5]
+
+ >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
+ >>> num_labels = len(model.config.id2label)
+ >>> model = {model_class}.from_pretrained(
+ ... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification"
+ ... )
+
+ >>> labels = torch.sum(
+ ... torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), dim=1
+ ... ).to(torch.float)
+ >>> loss = model(**inputs, labels=labels).loss
+ ```
+"""
+
+PT_MASKED_LM_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits
+
+ >>> # retrieve index of {mask}
+ >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+
+ >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
+ >>> tokenizer.decode(predicted_token_id)
+ {expected_output}
+
+ >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
+ >>> # mask labels of non-{mask} tokens
+ >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
+
+ >>> outputs = model(**inputs, labels=labels)
+ >>> round(outputs.loss.item(), 2)
+ {expected_loss}
+ ```
+"""
+
+PT_BASE_MODEL_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```
+"""
+
+PT_MULTIPLE_CHOICE_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ >>> choice0 = "It is eaten with a fork and a knife."
+ >>> choice1 = "It is eaten while held in the hand."
+ >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
+
+ >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
+ >>> outputs = model(**{{k: v.unsqueeze(0) for k, v in encoding.items()}}, labels=labels) # batch size is 1
+
+ >>> # the linear classifier still needs to be trained
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits
+ ```
+"""
+
+PT_CAUSAL_LM_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> import torch
+ >>> from transformers import AutoTokenizer, {model_class}
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs, labels=inputs["input_ids"])
+ >>> loss = outputs.loss
+ >>> logits = outputs.logits
+ ```
+"""
+
+PT_SPEECH_BASE_MODEL_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, {model_class}
+ >>> import torch
+ >>> from biofit import load_dataset
+
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+ >>> dataset = dataset.sort("id")
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
+
+ >>> processor = AutoProcessor.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> # audio file is decoded on the fly
+ >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ >>> list(last_hidden_states.shape)
+ {expected_output}
+ ```
+"""
+
+PT_SPEECH_CTC_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, {model_class}
+ >>> from biofit import load_dataset
+ >>> import torch
+
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+ >>> dataset = dataset.sort("id")
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
+
+ >>> processor = AutoProcessor.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> # audio file is decoded on the fly
+ >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits
+ >>> predicted_ids = torch.argmax(logits, dim=-1)
+
+ >>> # transcribe speech
+ >>> transcription = processor.batch_decode(predicted_ids)
+ >>> transcription[0]
+ {expected_output}
+
+ >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids
+
+ >>> # compute loss
+ >>> loss = model(**inputs).loss
+ >>> round(loss.item(), 2)
+ {expected_loss}
+ ```
+"""
+
+PT_SPEECH_SEQ_CLASS_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoFeatureExtractor, {model_class}
+ >>> from biofit import load_dataset
+ >>> import torch
+
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+ >>> dataset = dataset.sort("id")
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> # audio file is decoded on the fly
+ >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits
+
+ >>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
+ >>> predicted_label = model.config.id2label[predicted_class_ids]
+ >>> predicted_label
+ {expected_output}
+
+ >>> # compute loss - target_label is e.g. "down"
+ >>> target_label = model.config.id2label[0]
+ >>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
+ >>> loss = model(**inputs).loss
+ >>> round(loss.item(), 2)
+ {expected_loss}
+ ```
+"""
+
+
+PT_SPEECH_FRAME_CLASS_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoFeatureExtractor, {model_class}
+ >>> from biofit import load_dataset
+ >>> import torch
+
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+ >>> dataset = dataset.sort("id")
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> # audio file is decoded on the fly
+ >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits
+
+ >>> probabilities = torch.sigmoid(logits[0])
+ >>> # labels is a one-hot array of shape (num_frames, num_speakers)
+ >>> labels = (probabilities > 0.5).long()
+ >>> labels[0].tolist()
+ {expected_output}
+ ```
+"""
+
+
+PT_SPEECH_XVECTOR_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoFeatureExtractor, {model_class}
+ >>> from biofit import load_dataset
+ >>> import torch
+
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+ >>> dataset = dataset.sort("id")
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
+
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> # audio file is decoded on the fly
+ >>> inputs = feature_extractor(
+ ... [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True
+ ... )
+ >>> with torch.no_grad():
+ ... embeddings = model(**inputs).embeddings
+
+ >>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
+
+ >>> # the resulting embeddings can be used for cosine similarity-based retrieval
+ >>> cosine_sim = torch.nn.CosineSimilarity(dim=-1)
+ >>> similarity = cosine_sim(embeddings[0], embeddings[1])
+ >>> threshold = 0.7 # the optimal threshold is dataset-dependent
+ >>> if similarity < threshold:
+ ... print("Speakers are not the same!")
+ >>> round(similarity.item(), 2)
+ {expected_output}
+ ```
+"""
+
+PT_VISION_BASE_MODEL_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, {model_class}
+ >>> import torch
+ >>> from biofit import load_dataset
+
+ >>> dataset = load_dataset("huggingface/cats-image")
+ >>> image = dataset["test"]["image"][0]
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = image_processor(image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... outputs = model(**inputs)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ >>> list(last_hidden_states.shape)
+ {expected_output}
+ ```
+"""
+
+PT_VISION_SEQ_CLASS_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, {model_class}
+ >>> import torch
+ >>> from biofit import load_dataset
+
+ >>> dataset = load_dataset("huggingface/cats-image")
+ >>> image = dataset["test"]["image"][0]
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = image_processor(image, return_tensors="pt")
+
+ >>> with torch.no_grad():
+ ... logits = model(**inputs).logits
+
+ >>> # model predicts one of the 1000 ImageNet classes
+ >>> predicted_label = logits.argmax(-1).item()
+ >>> print(model.config.id2label[predicted_label])
+ {expected_output}
+ ```
+"""
+
+
+PT_SAMPLE_DOCSTRINGS = {
+ "SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
+ "QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
+ "TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
+ "MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
+ "MaskedLM": PT_MASKED_LM_SAMPLE,
+ "LMHead": PT_CAUSAL_LM_SAMPLE,
+ "BaseModel": PT_BASE_MODEL_SAMPLE,
+ "SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE,
+ "CTC": PT_SPEECH_CTC_SAMPLE,
+ "AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE,
+ "AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE,
+ "AudioXVector": PT_SPEECH_XVECTOR_SAMPLE,
+ "VisionBaseModel": PT_VISION_BASE_MODEL_SAMPLE,
+ "ImageClassification": PT_VISION_SEQ_CLASS_SAMPLE,
+}
+
+
+TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+ >>> import tensorflow as tf
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer(
+ ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="tf"
+ ... )
+
+ >>> logits = model(**inputs).logits
+ >>> predicted_token_class_ids = tf.math.argmax(logits, axis=-1)
+
+ >>> # Note that tokens are classified rather than input words, which means that
+ >>> # there might be more predicted token classes than words.
+ >>> # Multiple token classes might account for the same word
+ >>> predicted_tokens_classes = [model.config.id2label[t] for t in predicted_token_class_ids[0].numpy().tolist()]
+ >>> predicted_tokens_classes
+ {expected_output}
+ ```
+
+ ```python
+ >>> labels = predicted_token_class_ids
+ >>> loss = tf.math.reduce_mean(model(**inputs, labels=labels).loss)
+ >>> round(float(loss), 2)
+ {expected_loss}
+ ```
+"""
+
+TF_QUESTION_ANSWERING_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+ >>> import tensorflow as tf
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+
+ >>> inputs = tokenizer(question, text, return_tensors="tf")
+ >>> outputs = model(**inputs)
+
+ >>> answer_start_index = int(tf.math.argmax(outputs.start_logits, axis=-1)[0])
+ >>> answer_end_index = int(tf.math.argmax(outputs.end_logits, axis=-1)[0])
+
+ >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
+ >>> tokenizer.decode(predict_answer_tokens)
+ {expected_output}
+ ```
+
+ ```python
+ >>> # target is "nice puppet"
+ >>> target_start_index = tf.constant([{qa_target_start_index}])
+ >>> target_end_index = tf.constant([{qa_target_end_index}])
+
+ >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
+ >>> loss = tf.math.reduce_mean(outputs.loss)
+ >>> round(float(loss), 2)
+ {expected_loss}
+ ```
+"""
+
+TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+ >>> import tensorflow as tf
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
+
+ >>> logits = model(**inputs).logits
+
+ >>> predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
+ >>> model.config.id2label[predicted_class_id]
+ {expected_output}
+ ```
+
+ ```python
+ >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
+ >>> num_labels = len(model.config.id2label)
+ >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)
+
+ >>> labels = tf.constant(1)
+ >>> loss = model(**inputs, labels=labels).loss
+ >>> round(float(loss), 2)
+ {expected_loss}
+ ```
+"""
+
+TF_MASKED_LM_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+ >>> import tensorflow as tf
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="tf")
+ >>> logits = model(**inputs).logits
+
+ >>> # retrieve index of {mask}
+ >>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0])
+ >>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index)
+
+ >>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1)
+ >>> tokenizer.decode(predicted_token_id)
+ {expected_output}
+ ```
+
+ ```python
+ >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
+ >>> # mask labels of non-{mask} tokens
+ >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
+
+ >>> outputs = model(**inputs, labels=labels)
+ >>> round(float(outputs.loss), 2)
+ {expected_loss}
+ ```
+"""
+
+TF_BASE_MODEL_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+ >>> import tensorflow as tf
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
+ >>> outputs = model(inputs)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```
+"""
+
+TF_MULTIPLE_CHOICE_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+ >>> import tensorflow as tf
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ >>> choice0 = "It is eaten with a fork and a knife."
+ >>> choice1 = "It is eaten while held in the hand."
+
+ >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="tf", padding=True)
+ >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
+ >>> outputs = model(inputs) # batch size is 1
+
+ >>> # the linear classifier still needs to be trained
+ >>> logits = outputs.logits
+ ```
+"""
+
+TF_CAUSAL_LM_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+ >>> import tensorflow as tf
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
+ >>> outputs = model(inputs)
+ >>> logits = outputs.logits
+ ```
+"""
+
+TF_SPEECH_BASE_MODEL_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, {model_class}
+ >>> from biofit import load_dataset
+
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+ >>> dataset = dataset.sort("id")
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
+
+ >>> processor = AutoProcessor.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> # audio file is decoded on the fly
+ >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
+ >>> outputs = model(**inputs)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ >>> list(last_hidden_states.shape)
+ {expected_output}
+ ```
+"""
+
+TF_SPEECH_CTC_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoProcessor, {model_class}
+ >>> from biofit import load_dataset
+ >>> import tensorflow as tf
+
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+ >>> dataset = dataset.sort("id")
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
+
+ >>> processor = AutoProcessor.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> # audio file is decoded on the fly
+ >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
+ >>> logits = model(**inputs).logits
+ >>> predicted_ids = tf.math.argmax(logits, axis=-1)
+
+ >>> # transcribe speech
+ >>> transcription = processor.batch_decode(predicted_ids)
+ >>> transcription[0]
+ {expected_output}
+ ```
+
+ ```python
+ >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="tf").input_ids
+
+ >>> # compute loss
+ >>> loss = model(**inputs).loss
+ >>> round(float(loss), 2)
+ {expected_loss}
+ ```
+"""
+
+TF_VISION_BASE_MODEL_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, {model_class}
+ >>> from biofit import load_dataset
+
+ >>> dataset = load_dataset("huggingface/cats-image")
+ >>> image = dataset["test"]["image"][0]
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = image_processor(image, return_tensors="tf")
+ >>> outputs = model(**inputs)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ >>> list(last_hidden_states.shape)
+ {expected_output}
+ ```
+"""
+
+TF_VISION_SEQ_CLASS_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoImageProcessor, {model_class}
+ >>> import tensorflow as tf
+ >>> from biofit import load_dataset
+
+ >>> dataset = load_dataset("huggingface/cats-image")
+ >>> image = dataset["test"]["image"][0]
+
+ >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = image_processor(image, return_tensors="tf")
+ >>> logits = model(**inputs).logits
+
+ >>> # model predicts one of the 1000 ImageNet classes
+ >>> predicted_label = int(tf.math.argmax(logits, axis=-1))
+ >>> print(model.config.id2label[predicted_label])
+ {expected_output}
+ ```
+"""
+
+TF_SAMPLE_DOCSTRINGS = {
+ "SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE,
+ "QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE,
+ "TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE,
+ "MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE,
+ "MaskedLM": TF_MASKED_LM_SAMPLE,
+ "LMHead": TF_CAUSAL_LM_SAMPLE,
+ "BaseModel": TF_BASE_MODEL_SAMPLE,
+ "SpeechBaseModel": TF_SPEECH_BASE_MODEL_SAMPLE,
+ "CTC": TF_SPEECH_CTC_SAMPLE,
+ "VisionBaseModel": TF_VISION_BASE_MODEL_SAMPLE,
+ "ImageClassification": TF_VISION_SEQ_CLASS_SAMPLE,
+}
+
+
+FLAX_TOKEN_CLASSIFICATION_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
+
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ ```
+"""
+
+FLAX_QUESTION_ANSWERING_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+ >>> inputs = tokenizer(question, text, return_tensors="jax")
+
+ >>> outputs = model(**inputs)
+ >>> start_scores = outputs.start_logits
+ >>> end_scores = outputs.end_logits
+ ```
+"""
+
+FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
+
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ ```
+"""
+
+FLAX_MASKED_LM_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="jax")
+
+ >>> outputs = model(**inputs)
+ >>> logits = outputs.logits
+ ```
+"""
+
+FLAX_BASE_MODEL_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
+ >>> outputs = model(**inputs)
+
+ >>> last_hidden_states = outputs.last_hidden_state
+ ```
+"""
+
+FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ >>> choice0 = "It is eaten with a fork and a knife."
+ >>> choice1 = "It is eaten while held in the hand."
+
+ >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="jax", padding=True)
+ >>> outputs = model(**{{k: v[None, :] for k, v in encoding.items()}})
+
+ >>> logits = outputs.logits
+ ```
+"""
+
+FLAX_CAUSAL_LM_SAMPLE = r"""
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, {model_class}
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
+ >>> outputs = model(**inputs)
+
+ >>> # retrieve logits for the next token
+ >>> next_token_logits = outputs.logits[:, -1]
+ ```
+"""
+
+FLAX_SAMPLE_DOCSTRINGS = {
+ "SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
+ "QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE,
+ "TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE,
+ "MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE,
+ "MaskedLM": FLAX_MASKED_LM_SAMPLE,
+ "BaseModel": FLAX_BASE_MODEL_SAMPLE,
+ "LMHead": FLAX_CAUSAL_LM_SAMPLE,
+}
+
+
+def filter_outputs_from_example(docstring, **kwargs):
+ """
+ Removes the lines testing an output with the doctest syntax in a code sample when it's set to `None`.
+ """
+ for key, value in kwargs.items():
+ if value is not None:
+ continue
+
+ doc_key = "{" + key + "}"
+ docstring = re.sub(rf"\n([^\n]+)\n\s+{doc_key}\n", "\n", docstring)
+
+ return docstring
+
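+# For example (illustrative): with `expected_output=None`, a doctest pair like
+#
+#   >>> predicted_tokens_classes
+#   {expected_output}
+#
+# is removed from the code sample, so the rendered docstring never shows an
+# unverified output line.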
+
+def add_code_sample_docstrings(
+ *docstr,
+ processor_class=None,
+ checkpoint=None,
+ output_type=None,
+ _config_class=None,
+ mask="[MASK]",
+ qa_target_start_index=14,
+ qa_target_end_index=15,
+ model_cls=None,
+ modality=None,
+ expected_output=None,
+ expected_loss=None,
+ real_checkpoint=None,
+ revision=None,
+):
+ def docstring_decorator(fn):
+ # model_class defaults to function's class if not specified otherwise
+ model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
+
+ if model_class[:2] == "TF":
+ sample_docstrings = TF_SAMPLE_DOCSTRINGS
+ elif model_class[:4] == "Flax":
+ sample_docstrings = FLAX_SAMPLE_DOCSTRINGS
+ else:
+ sample_docstrings = PT_SAMPLE_DOCSTRINGS
+
+ # putting all kwargs for docstrings in a dict to be used
+ # with the `.format(**doc_kwargs)`. Note that string might
+ # be formatted with non-existing keys, which is fine.
+ doc_kwargs = {
+ "model_class": model_class,
+ "processor_class": processor_class,
+ "checkpoint": checkpoint,
+ "mask": mask,
+ "qa_target_start_index": qa_target_start_index,
+ "qa_target_end_index": qa_target_end_index,
+ "expected_output": expected_output,
+ "expected_loss": expected_loss,
+ "real_checkpoint": real_checkpoint,
+ "fake_checkpoint": checkpoint,
+ "true": "{true}", # For syntax that conflicts with formatting.
+ }
+
+ if (
+ "SequenceClassification" in model_class
+ or "AudioClassification" in model_class
+ ) and modality == "audio":
+ code_sample = sample_docstrings["AudioClassification"]
+ elif "SequenceClassification" in model_class:
+ code_sample = sample_docstrings["SequenceClassification"]
+ elif "QuestionAnswering" in model_class:
+ code_sample = sample_docstrings["QuestionAnswering"]
+ elif "TokenClassification" in model_class:
+ code_sample = sample_docstrings["TokenClassification"]
+ elif "MultipleChoice" in model_class:
+ code_sample = sample_docstrings["MultipleChoice"]
+ elif "MaskedLM" in model_class or model_class in [
+ "FlaubertWithLMHeadModel",
+ "XLMWithLMHeadModel",
+ ]:
+ code_sample = sample_docstrings["MaskedLM"]
+ elif "LMHead" in model_class or "CausalLM" in model_class:
+ code_sample = sample_docstrings["LMHead"]
+ elif "CTC" in model_class:
+ code_sample = sample_docstrings["CTC"]
+ elif "AudioFrameClassification" in model_class:
+ code_sample = sample_docstrings["AudioFrameClassification"]
+ elif "XVector" in model_class and modality == "audio":
+ code_sample = sample_docstrings["AudioXVector"]
+ elif "Model" in model_class and modality == "audio":
+ code_sample = sample_docstrings["SpeechBaseModel"]
+ elif "Model" in model_class and modality == "vision":
+ code_sample = sample_docstrings["VisionBaseModel"]
+ elif "Model" in model_class or "Encoder" in model_class:
+ code_sample = sample_docstrings["BaseModel"]
+ elif "ImageClassification" in model_class:
+ code_sample = sample_docstrings["ImageClassification"]
+ else:
+ raise ValueError(f"Docstring can't be built for model {model_class}")
+
+ code_sample = filter_outputs_from_example(
+ code_sample, expected_output=expected_output, expected_loss=expected_loss
+ )
+ if real_checkpoint is not None:
+ code_sample = FAKE_MODEL_DISCLAIMER + code_sample
+ func_doc = (fn.__doc__ or "") + "".join(docstr)
+ output_doc = (
+ ""
+ if output_type is None
+ else _prepare_output_docstrings(output_type, _config_class)
+ )
+ built_doc = code_sample.format(**doc_kwargs)
+ if revision is not None:
+ if re.match(r"^refs/pr/\\d+", revision):
+ raise ValueError(
+ f"The provided revision '{revision}' is incorrect. It should point to"
+ " a pull request reference on the hub like 'refs/pr/6'"
+ )
+ built_doc = built_doc.replace(
+ f'from_pretrained("{checkpoint}")',
+ f'from_pretrained("{checkpoint}", revision="{revision}")',
+ )
+ fn.__doc__ = func_doc + output_doc + built_doc
+ return fn
+
+ return docstring_decorator
+
+
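+ # Example (hypothetical sketch): a PyTorch-style model could decorate its forward
+ # method so that a ready-made code sample is appended to its docstring. The class
+ # and checkpoint names below are placeholders, not real identifiers.
+ #
+ # class MyModelForSequenceClassification(...):
+ # @add_code_sample_docstrings(
+ # processor_class="AutoTokenizer",
+ # checkpoint="my-org/my-checkpoint",
+ # expected_output=None,
+ # expected_loss=None,
+ # )
+ # def forward(self, input_ids=None, labels=None):
+ # """..."""
+
+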
+def replace_return_docstrings(output_type=None, _config_class=None):
+ def docstring_decorator(fn):
+ func_doc = fn.__doc__
+ lines = func_doc.split("\n")
+ i = 0
+ while i < len(lines) and re.search(r"^\s*Returns?:\s*$", lines[i]) is None:
+ i += 1
+ if i < len(lines):
+ indent = len(_get_indent(lines[i]))
+ lines[i] = _prepare_output_docstrings(
+ output_type, _config_class, min_indent=indent
+ )
+ func_doc = "\n".join(lines)
+ else:
+ raise ValueError(
+ f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, "
+ f"current docstring is:\n{func_doc}"
+ )
+ fn.__doc__ = func_doc
+ return fn
+
+ return docstring_decorator
+
+
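+ # Example (hypothetical sketch): the decorated function keeps a bare "Returns:"
+ # line in its docstring as a placeholder, which the decorator replaces with the
+ # rendered documentation of `output_type`. The names below are placeholders.
+ #
+ # @replace_return_docstrings(output_type=MyModelOutput, _config_class=MyConfig)
+ # def forward(self, input_ids):
+ # """Runs the forward pass.
+ #
+ # Returns:
+ # """
+
+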
+def copy_func(f):
+ """Returns a copy of a function f."""
+ # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
+ g = types.FunctionType(
+ f.__code__,
+ f.__globals__,
+ name=f.__name__,
+ argdefs=f.__defaults__,
+ closure=f.__closure__,
+ )
+ g = functools.update_wrapper(g, f)
+ g.__kwdefaults__ = f.__kwdefaults__
+ return g
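+
+
+ # Example (illustrative): the copy is independent of the original, so metadata
+ # such as the docstring can be changed without touching `f`.
+ #
+ # def f(x, scale=1):
+ # return x * scale
+ #
+ # g = copy_func(f)
+ # g.__doc__ = "Scaled identity (copied)."
+ # assert f(2) == g(2)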
diff --git a/src/biofit/utils/file_utils.py b/src/biofit/utils/file_utils.py
new file mode 100644
index 0000000..0d780ac
--- /dev/null
+++ b/src/biofit/utils/file_utils.py
@@ -0,0 +1,234 @@
+"""
+This file is adapted from the datasets library, which in turn is adapted from the AllenNLP library.
+
+datasets
+~~~~~~~~
+Utilities for working with the local dataset cache.
+This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
+Copyright by the AllenNLP authors.
+"""
+
+import io
+import os
+import posixpath
+import shutil
+import sys
+import tempfile
+from pathlib import Path
+from typing import Optional, TypeVar, Union
+from urllib.parse import unquote, urlparse
+
+from huggingface_hub.utils import insecure_hashlib
+
+from . import logging
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+INCOMPLETE_SUFFIX = ".incomplete"
+
+PathLike = Union[str, Path]
+T = TypeVar("T", str, Path)
+
+
+def is_remote_url(url_or_filename: Union[str, Path]) -> bool:
+ return urlparse(str(url_or_filename)).scheme != "" and not os.path.ismount(
+ urlparse(str(url_or_filename)).scheme + ":/"
+ )
+
+
+def is_local_path(url_or_filename: Union[str, Path]) -> bool:
+ # On unix the scheme of a local path is empty (for both absolute and relative),
+ # while on windows the scheme is the drive name (ex: "c") for absolute paths.
+ # for details on the windows behavior, see https://bugs.python.org/issue42215
+ url_or_filename = Path(url_or_filename).resolve().as_posix()
+
+ return urlparse(url_or_filename).scheme == "" or os.path.ismount(
+ urlparse(url_or_filename).scheme + ":/"
+ )
+
+
+def is_relative_path(url_or_filename: Union[str, Path]) -> bool:
+ return urlparse(str(url_or_filename)).scheme == "" and not os.path.isabs(
+ str(url_or_filename)
+ )
+
+
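+ # Examples (illustrative, unix semantics) of how the three predicates above
+ # classify inputs:
+ #
+ # is_remote_url("https://example.com/data.csv") -> True
+ # is_remote_url("data/file.csv") -> False
+ # is_local_path("/home/user/data.csv") -> True
+ # is_relative_path("data/file.csv") -> True
+ # is_relative_path("/home/user/data.csv") -> False
+
+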
+def expand_path(path):
+ """
+ Check if a path is relative and expand it if necessary.
+ Handles file paths and URLs, including user home directory expansion.
+
+ :param path: str representing the path or URL to check and expand
+ :return: str with the expanded absolute path or full URL
+ """
+ # Parse the path as a URL
+ parsed = urlparse(path)
+
+ # If it's a URL (not a local file path)
+ if parsed.scheme and parsed.netloc:
+ return path # Return the original URL
+
+ # If it's a file URL, convert it to a local path
+ if parsed.scheme == "file":
+ path = unquote(parsed.path)
+ # On Windows, remove leading slash
+ if sys.platform == "win32" and path.startswith("/"):
+ path = path[1:]
+ else:
+ # It's a regular path, use the full original string
+ path = unquote(path)
+
+ # Convert to Path object
+ path = Path(path)
+
+ # Expand user's home directory if present
+ path = path.expanduser()
+
+ # Check if the path is absolute
+ if path.is_absolute():
+ return str(path.resolve())
+
+ # If path is relative, make it absolute
+ return os.path.normpath(str((Path.cwd() / path).resolve()))
+
+
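+ # Example (illustrative, assuming a unix home of /home/user and a current working
+ # directory of /home/user/proj):
+ #
+ # expand_path("~/data.csv") -> "/home/user/data.csv"
+ # expand_path("data/file.csv") -> "/home/user/proj/data/file.csv"
+ # expand_path("https://example.com/data.csv") -> "https://example.com/data.csv"
+
+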
+def relative_to_absolute_path(path: T) -> T:
+ """Convert relative path to absolute path."""
+ abs_path_str = os.path.abspath(os.path.expanduser(os.path.expandvars(str(path))))
+ return Path(abs_path_str) if isinstance(path, Path) else abs_path_str
+
+
+def is_file_name(url_or_path_or_file: T) -> bool:
+ if is_local_path(url_or_path_or_file):
+ if is_relative_path(url_or_path_or_file):
+ if "/" not in Path(url_or_path_or_file).as_posix():
+ return True
+ return False
+
+
+def has_ext(url_or_path: Union[str, Path], ext: Optional[str] = None) -> bool:
+ if ext is None:
+ return Path(url_or_path).suffix != ""
+ return Path(url_or_path).suffix == ext
+
+
+def has_separator(url_or_path: Union[str, Path]) -> bool:
+ return "/" in str(url_or_path) or "\\" in str(url_or_path)
+
+
+def url_or_path_join(base_name: str, *pathnames: str) -> str:
+ if is_remote_url(base_name):
+ return posixpath.join(
+ base_name,
+ *(str(pathname).replace(os.sep, "/").lstrip("/") for pathname in pathnames),
+ )
+ else:
+ return Path(base_name, *pathnames).as_posix()
+
+
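+ # Example (illustrative): remote bases are joined POSIX-style while local bases
+ # go through Path:
+ #
+ # url_or_path_join("s3://bucket/data", "sub", "file.csv")
+ # -> "s3://bucket/data/sub/file.csv"
+ # url_or_path_join("data", "sub", "file.csv") -> "data/sub/file.csv"
+
+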
+def url_or_path_parent(url_or_path: str) -> str:
+ if is_remote_url(url_or_path):
+ return url_or_path[: url_or_path.rindex("/")]
+ else:
+ return os.path.dirname(url_or_path)
+
+
+def hash_url_to_filename(url, etag=None):
+ """
+ Convert `url` into a hashed filename in a repeatable way.
+ If `etag` is specified, append its hash to the url's, delimited
+ by a period.
+ If the url ends with .py, the '.py' suffix is appended to the hashed
+ filename so that the cached file is still recognized as a python script.
+ """
+ url_bytes = url.encode("utf-8")
+ url_hash = insecure_hashlib.sha256(url_bytes)
+ filename = url_hash.hexdigest()
+
+ if etag:
+ etag_bytes = etag.encode("utf-8")
+ etag_hash = insecure_hashlib.sha256(etag_bytes)
+ filename += "." + etag_hash.hexdigest()
+
+ if url.endswith(".py"):
+ filename += ".py"
+
+ return filename
+
+
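+ # Example (illustrative; digests abbreviated): the filename is the sha256 hex
+ # digest of the url, optionally followed by a period and the sha256 of the etag,
+ # with ".py" re-appended for python scripts:
+ #
+ # hash_url_to_filename("https://example.com/a.py", etag='"abc"')
+ # -> "<sha256(url)>.<sha256(etag)>.py"
+
+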
+def add_start_docstrings(*docstr):
+ def docstring_decorator(fn):
+ fn.__doc__ = (
+ "".join(docstr) + "\n\n" + (fn.__doc__ if fn.__doc__ is not None else "")
+ )
+ return fn
+
+ return docstring_decorator
+
+
+def add_end_docstrings(*docstr):
+ def docstring_decorator(fn):
+ fn.__doc__ = (
+ (fn.__doc__ if fn.__doc__ is not None else "") + "\n\n" + "".join(docstr)
+ )
+ return fn
+
+ return docstring_decorator
+
+
+def estimate_dataset_size(paths):
+ return sum(path.stat().st_size for path in paths)
+
+
+def readline(f: io.RawIOBase):
+ # From: https://github.com/python/cpython/blob/d27e2f4d118e7a9909b6a3e5da06c5ff95806a85/Lib/_pyio.py#L525
+ res = bytearray()
+ while True:
+ b = f.read(1)
+ if not b:
+ break
+ res += b
+ if res.endswith(b"\n"):
+ break
+ return bytes(res)
+
+
+def move_temp_file(
+ temp_file: Union[tempfile._TemporaryFileWrapper, str, Path], final_file: str
+):
+ if isinstance(temp_file, tempfile._TemporaryFileWrapper):
+ temp_file.close()
+ temp_file_name = Path(temp_file.name)
+ elif not isinstance(temp_file, Path):
+ temp_file_name = Path(temp_file)
+ else:
+ temp_file_name = temp_file
+
+ if not isinstance(final_file, Path):
+ final_file = Path(final_file)
+
+ if temp_file_name.exists():
+ if final_file.exists():
+ final_file.unlink()
+ # is source a windows path?
+ if temp_file_name.resolve().drive:
+ if len(str(temp_file_name)) > 255:
+ src_file = "\\\\?\\" + str(temp_file_name.resolve())
+ else:
+ src_file = temp_file_name.as_posix()
+ else:
+ src_file = temp_file_name.as_posix()
+
+ if final_file.resolve().drive:
+ if len(str(final_file)) > 255:
+ dst_file = "\\\\?\\" + str(final_file.resolve())
+ else:
+ dst_file = final_file.as_posix()
+ else:
+ dst_file = final_file.as_posix()
+ shutil.move(src_file, dst_file)
+ umask = os.umask(0o666)
+ os.umask(umask)
+ os.chmod(dst_file, 0o666 & ~umask)
+ return dst_file
+ else:
+ raise FileNotFoundError(f"Temporary file {temp_file_name.name} not found.")
diff --git a/src/biofit/utils/fingerprint.py b/src/biofit/utils/fingerprint.py
new file mode 100644
index 0000000..38c3976
--- /dev/null
+++ b/src/biofit/utils/fingerprint.py
@@ -0,0 +1,400 @@
+"""
+NOTE: The contents of this file have been inlined from the fingerprint and _dill module in the datasets package's source code
+https://github.com/huggingface/datasets/blob/c47cc141c9e6e0edafffdcfde55b171612f1de76/src/datasets/fingerprint.py
+
+This module has fixes / adaptations for biofit use cases that make it different from the original
+datasets library
+
+The following modifications have been made:
+ - Added python source code parsing logic to the `Hasher` class. This is used to hash the contents of a python file
+ without comments since we only care about the code that is actually executed.
+ - Added the _hash_python_lines function to this file without importing
+ huggingface_hub
+
+datasets
+~~~~~~~~
+Copyright 2023 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import functools
+import hashlib
+import importlib
+import inspect
+import os
+import posixpath
+import random
+import re
+from pathlib import Path
+import sys
+from types import FunctionType, ModuleType
+from typing import Any, Dict, List, Union
+
+from biocore.utils.import_util import is_datasets_available
+import xxhash
+from .version import Version
+from ._dill import dumps
+
+import biofit.config as config
+import biofit.utils.logging as logging
+from biofit.utils.version import __version__
+
+from .file_utils import is_file_name, is_remote_url
+
+_CACHING_ENABLED = True
+
+logger = logging.get_logger(__name__)
+
+fingerprint_rng = random.Random()
+
+
+def _hash_python_lines(lines: List[str]) -> str:
+ filtered_lines = []
+ for line in lines:
+ line = re.sub(r"#.*", "", line) # remove comments
+ if line:
+ filtered_lines.append(line)
+ full_str = "\n".join(filtered_lines)
+
+ # Make a hash from all this code
+ full_bytes = full_str.encode("utf-8")
+ _kwargs = {"usedforsecurity": False} if sys.version_info >= (3, 9) else {}
+ sha256 = functools.partial(hashlib.sha256, **_kwargs)
+ return sha256(full_bytes).hexdigest()
+
+
+def generate_random_fingerprint(nbits: int = 64) -> str:
+ return f"{fingerprint_rng.getrandbits(nbits):0{nbits//4}x}"
+
+
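+ # Example (illustrative): with the default nbits=64, the fingerprint is rendered
+ # as 16 lowercase hex characters, e.g. "9f2c0a1b3d4e5f60" (not a real output).
+
+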
+class Hasher:
+ """Hasher that accepts python objects as inputs."""
+
+ dispatch: Dict = {}
+
+ def __init__(self):
+ self.m = xxhash.xxh64()
+
+ @classmethod
+ def hash_bytes(cls, value: Union[bytes, List[bytes]]) -> str:
+ value = [value] if isinstance(value, bytes) else value
+ m = xxhash.xxh64()
+ for x in value:
+ m.update(x)
+ return m.hexdigest()
+
+ @classmethod
+ def hash(cls, value: Any) -> str:
+ return cls.hash_bytes(dumps(value))
+
+ @staticmethod
+ def _hash_python_lines(module: Union[ModuleType, FunctionType, type]) -> str:
+ filtered_lines = []
+ lines = inspect.getsource(module).splitlines()
+ for line in lines:
+ line = re.sub(r"#.*", "", line) # remove comments
+ if line:
+ filtered_lines.append(line)
+ return "\n".join(filtered_lines)
+
+ def update(self, value: Any) -> None:
+ header_for_update = f"=={type(value)}=="
+ value_for_update = self.hash(value)
+ self.m.update(header_for_update.encode("utf8"))
+ self.m.update(value_for_update.encode("utf-8"))
+
+ def hexdigest(self) -> str:
+ return self.m.hexdigest()
+
+
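+ # Example (illustrative): incremental hashing of arbitrary picklable objects.
+ #
+ # h = Hasher()
+ # h.update("normalize")
+ # h.update({"with_std": True})
+ # digest = h.hexdigest() # a 16-character xxh64 hex string
+
+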
+def enable_caching():
+ """
+ When applying transforms on a dataset, the data are stored in cache files.
+ The caching mechanism makes it possible to reload an existing cache file if it has already been computed.
+
+ Reloading a dataset is possible since the cache files are named using the dataset fingerprint, which is updated
+ after each transform.
+
+ If disabled, the library will no longer reload cached datasets files when applying transforms to the datasets.
+ More precisely, if the caching is disabled:
+ - cache files are always recreated
+ - cache files are written to a temporary directory that is deleted when session closes
+ - cache files are named using a random hash instead of the dataset fingerprint
+ - use [`~datasets.Dataset.save_to_disk`] to save a transformed dataset or it will be deleted when session closes
+ - caching doesn't affect [`~datasets.load_dataset`]. If you want to regenerate a dataset from scratch you should use
+ the `download_mode` parameter in [`~datasets.load_dataset`].
+ """
+ global _CACHING_ENABLED
+ _CACHING_ENABLED = True
+
+
+def disable_caching():
+ """
+ When applying transforms on a dataset, the data are stored in cache files.
+ The caching mechanism makes it possible to reload an existing cache file if it has already been computed.
+
+ Reloading a dataset is possible since the cache files are named using the dataset fingerprint, which is updated
+ after each transform.
+
+ If disabled, the library will no longer reload cached datasets files when applying transforms to the datasets.
+ More precisely, if the caching is disabled:
+ - cache files are always recreated
+ - cache files are written to a temporary directory that is deleted when session closes
+ - cache files are named using a random hash instead of the dataset fingerprint
+ - use [`~datasets.Dataset.save_to_disk`] to save a transformed dataset or it will be deleted when session closes
+ - caching doesn't affect [`~datasets.load_dataset`]. If you want to regenerate a dataset from scratch you should use
+ the `download_mode` parameter in [`~datasets.load_dataset`].
+ """
+ global _CACHING_ENABLED
+ _CACHING_ENABLED = False
+
+
+def is_caching_enabled() -> bool:
+ """
+ When applying transforms on a dataset, the data are stored in cache files.
+ The caching mechanism makes it possible to reload an existing cache file if it has already been computed.
+
+ Reloading a dataset is possible since the cache files are named using the dataset fingerprint, which is updated
+ after each transform.
+
+ If disabled, the library will no longer reload cached datasets files when applying transforms to the datasets.
+ More precisely, if the caching is disabled:
+ - cache files are always recreated
+ - cache files are written to a temporary directory that is deleted when session closes
+ - cache files are named using a random hash instead of the dataset fingerprint
+ """
+ global _CACHING_ENABLED
+ return bool(_CACHING_ENABLED)
+
+
+def fingerprint_from_kwargs(fingerprint, kwargs):
+ hash = Hasher()
+ if fingerprint:
+ hash.update(fingerprint)
+
+ for key, value in kwargs.items():
+ if isinstance(key, str) and "fingerprint" in key:
+ continue
+ hash.update(key)
+ if isinstance(value, dict):
+ hash.update(fingerprint_from_kwargs(fingerprint, value))
+ else:
+ hash.update(str(value))
+
+ return hash.hexdigest()
+
+
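+ # Example (illustrative): keys containing "fingerprint" are skipped and nested
+ # dicts are hashed recursively, so both calls below produce the same digest:
+ #
+ # fingerprint_from_kwargs(None, {"a": 1, "new_fingerprint": "xyz"})
+ # fingerprint_from_kwargs(None, {"a": 1})
+
+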
+def fingerprint_from_data(data):
+ from biocore import DataHandler, get_data_format
+
+ if hasattr(data, "_fingerprint"):
+ return data._fingerprint
+ if hasattr(data, "fingerprint"):
+ return data.fingerprint
+ hasher = Hasher()
+ original_format = get_data_format(data)
+
+ if original_format is not None:
+ features = DataHandler.get_column_names(data, generate_cols=True)
+ n_samples = DataHandler.get_shape(data)[0]
+ first_row = DataHandler.select_row(data, 0)
+ hasher.update(original_format)
+ hasher.update(features)
+ hasher.update(n_samples)
+ hasher.update(first_row)
+ elif hasattr(data, "__dict__"):
+ state = data.__dict__
+ for key in sorted(state):
+ hasher.update(key)
+ try:
+ hasher.update(state[key])
+ except Exception:
+ hasher.update(str(state[key]))
+ else:
+ raise ValueError("Data object is not hashable")
+ return hasher.hexdigest()
+
+
+def generate_cache_dir(
+ X, fingerprint, cache_dir=None, root_dir=config.BIOFIT_DATASETS_CACHE
+):
+ if cache_dir is None:
+ if isinstance(root_dir, Path):
+ root_dir = root_dir.as_posix()
+ if (
+ root_dir == config.BIOFIT_DATASETS_CACHE.as_posix()
+ and hasattr(X, "cache_files")
+ and X.cache_files
+ ):
+ cache_dir = os.path.dirname(X.cache_files[0]["filename"])
+ else:
+ cache_dir = _build_cache_dir(
+ X,
+ fingerprint,
+ cache_dir_root=root_dir,
+ )
+
+ if cache_dir:
+ Path(cache_dir).mkdir(parents=True, exist_ok=True)
+ return Path(cache_dir).resolve().as_posix()
+ return None
+
+
+def get_cache_file_name(cache_dir, fingerprint, cache_file_name=None):
+ if cache_file_name:
+ if is_file_name(cache_file_name):
+ cache_file_name = Path(cache_dir) / Path(cache_file_name).with_suffix(
+ ".json"
+ )
+ else:
+ cache_file_name = Path(cache_file_name).with_suffix(".json")
+ else:
+ cache_file_name = Path(cache_dir) / f"cache-{fingerprint}.json"
+ return cache_file_name.resolve().as_posix()
+
+
+def _relative_data_dir(
+ fingerprint, builder_name, dataset_name, version=None, with_hash=True
+) -> str:
+ """
+ Constructs a relative directory path for a dataset based on its properties.
+
+ Args:
+ fingerprint (str): Fingerprint that disambiguates the dataset directory.
+ builder_name (str): Name of the builder that produced the dataset.
+ dataset_name (str): Name of the dataset (or config).
+ version (str, optional): Version to include in the path.
+ with_hash (bool, optional): Include hash information in the path.
+
+ Returns:
+ str: Relative path for the dataset.
+ """
+ builder_data_dir = posixpath.join(builder_name, f"{dataset_name}-{fingerprint}")
+ if version:
+ version = str(version) if isinstance(version, str) else __version__
+ builder_data_dir = posixpath.join(builder_data_dir, version)
+ if with_hash and is_datasets_available():
+ hash = _hash_python_lines(
+ inspect.getsource(
+ getattr(importlib.import_module("datasets.table"), "InMemoryTable")
+ ).splitlines()
+ )
+ builder_data_dir = posixpath.join(builder_data_dir, hash)
+ return builder_data_dir
+
+
+def _build_cache_dir(
+ obj,
+ fingerprint: str,
+ cache_dir_root: str = config.BIOFIT_DATASETS_CACHE,
+) -> str:
+ """
+ Builds the cache directory path for storing processed dataset versions.
+
+ Args:
+ obj (object): The dataset or processor whose cache directory is being built.
+ fingerprint (str): Fingerprint identifying the processed state.
+ cache_dir_root (str, optional): Root directory for caching datasets.
+
+ Returns:
+ str: The path to the dataset's cache directory.
+ """
+ if (
+ hasattr(obj, "version")
+ and hasattr(obj, "config_name")
+ and hasattr(obj, "builder_name")
+ ):
+ version = str(obj.version) if isinstance(obj.version, str) else __version__
+ dataset_name = obj.config_name or "default"
+ builder_name = obj.builder_name or "in_memory"
+ elif (
+ hasattr(obj, "config")
+ and hasattr(obj.config, "version")
+ and hasattr(obj.config, "processor_name")
+ and hasattr(obj.config, "processor_type")
+ ):
+ version = (
+ str(obj.config.version)
+ if obj.config.version and not str(obj.config.version) == "0.0.0"
+ else __version__
+ )
+ dataset_name = obj.config.processor_name or "default"
+ builder_name = obj.config.processor_type or "in_memory"
+
+ else:
+ version = __version__
+ dataset_name = "default"
+ builder_name = "in_memory"
+
+ builder_data_dir = posixpath.join(
+ cache_dir_root,
+ _relative_data_dir(
+ fingerprint=fingerprint,
+ builder_name=builder_name,
+ dataset_name=dataset_name,
+ ),
+ )
+ version_data_dir = posixpath.join(
+ cache_dir_root,
+ _relative_data_dir(
+ fingerprint=fingerprint,
+ builder_name=builder_name,
+ dataset_name=dataset_name,
+ version=version,
+ ),
+ )
+
+ def _other_versions_on_disk():
+ """Returns previous versions on disk."""
+ if not os.path.exists(builder_data_dir):
+ return []
+
+ version_dirnames = []
+ for dir_name in os.listdir(builder_data_dir):
+ try:
+ version_dirnames.append((Version(dir_name), dir_name))
+ except ValueError: # Invalid version (ex: incomplete data dir)
+ pass
+ version_dirnames.sort(reverse=True)
+ return version_dirnames
+
+ # Check and warn if other versions exist.
+ if not is_remote_url(builder_data_dir):
+ version_dirs = _other_versions_on_disk()
+ if version_dirs:
+ other_version = version_dirs[0][0]
+ if other_version != version:
+ warn_msg = (
+ f"Found a different version {str(other_version)} of dataset {dataset_name} in "
+ f"cache_dir {cache_dir_root}. Using currently defined version "
+ f"{str(version)}."
+ )
+ logger.warning(warn_msg)
+ if not os.path.exists(version_data_dir):
+ os.makedirs(version_data_dir, exist_ok=True)
+
+ return version_data_dir
+
+
+def update_fingerprint(fingerprint, value, key=None):
+ if key is None and value is None:
+ return fingerprint
+ hasher = Hasher()
+ if fingerprint:
+ hasher.update(fingerprint)
+ if key:
+ hasher.update(key)
+ try:
+ hasher.update(value)
+ except Exception:
+ return generate_random_fingerprint()
+ else:
+ return hasher.hexdigest()
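+
+
+ # Example (illustrative): chaining fingerprints across successive transforms; an
+ # unhashable value falls back to a random fingerprint instead of raising.
+ #
+ # fp = update_fingerprint(None, "normalize", key="op")
+ # fp = update_fingerprint(fp, {"with_std": True}, key="params")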
diff --git a/src/biofit/utils/generic.py b/src/biofit/utils/generic.py
new file mode 100644
index 0000000..42365e1
--- /dev/null
+++ b/src/biofit/utils/generic.py
@@ -0,0 +1,536 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Generic utilities
+"""
+
+import inspect
+import tempfile
+import contextlib
+import shutil
+import os
+import stat
+from pathlib import Path
+from collections import OrderedDict, UserDict
+from collections.abc import MutableMapping
+from contextlib import ExitStack, contextmanager
+from dataclasses import fields, is_dataclass
+from enum import Enum
+from typing import Any, ContextManager, List, Tuple, Union, Optional, Generator
+
+import numpy as np
+
+# optional imports
+try:
+ import polars as pl
+except ImportError:
+ pl = None
+
+
+class cached_property(property):
+ """
+ Descriptor that mimics @property but caches output in member variable.
+
+ From tensorflow_datasets
+
+ Built-in in functools from Python 3.8.
+ """
+
+ def __get__(self, obj, objtype=None):
+ # See docs.python.org/3/howto/descriptor.html#properties
+ if obj is None:
+ return self
+ if self.fget is None:
+ raise AttributeError("unreadable attribute")
+ attr = "__cached_" + self.fget.__name__
+ cached = getattr(obj, attr, None)
+ if cached is None:
+ cached = self.fget(obj)
+ setattr(obj, attr, cached)
+ return cached
+
+
+def _set_write_permission_and_retry(func, path, excinfo):
+ os.chmod(path, stat.S_IWRITE)
+ func(path)
+
+
+@contextlib.contextmanager
+def SoftTemporaryDirectory(
+ suffix: Optional[str] = None,
+ prefix: Optional[str] = None,
+ dir: Optional[Union[Path, str]] = None,
+ **kwargs,
+) -> Generator[str, None, None]:
+ """
+ Context manager to create a temporary directory and safely delete it.
+
+ If tmp directory cannot be deleted normally, we set the WRITE permission and retry.
+ If cleanup still fails, we give up but don't raise an exception. This is equivalent
+ to `tempfile.TemporaryDirectory(..., ignore_cleanup_errors=True)` introduced in
+ Python 3.10.
+
+ See https://www.scivision.dev/python-tempfile-permission-error-windows/.
+ """
+ tmpdir = tempfile.TemporaryDirectory(
+ prefix=prefix, suffix=suffix, dir=dir, **kwargs
+ )
+ try:
+ yield tmpdir.name
+ finally:
+ try:
+ # First try a normal cleanup
+ shutil.rmtree(tmpdir.name)
+ except Exception:
+ # If that failed, set the write permission and retry
+ try:
+ shutil.rmtree(tmpdir.name, onerror=_set_write_permission_and_retry)
+ except Exception:
+ pass
+
+ # And finally, clean up the tmpdir wrapper.
+ # If it fails again, give up but do not raise
+ try:
+ tmpdir.cleanup()
+ except Exception:
+ pass
+
+
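+ # Example (illustrative) of the context manager above:
+ #
+ # with SoftTemporaryDirectory(prefix="biofit-") as tmp:
+ # (Path(tmp) / "scratch.txt").write_text("hello")
+
+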
+# vendored from distutils.util
+def strtobool(val):
+ """Convert a string representation of truth to true (1) or false (0).
+
+ True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'.
+ Raises ValueError if 'val' is anything else.
+ """
+ val = val.lower()
+ if val in {"y", "yes", "t", "true", "on", "1"}:
+ return 1
+ if val in {"n", "no", "f", "false", "off", "0"}:
+ return 0
+ raise ValueError(f"invalid truth value {val!r}")
+
+
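+ # Examples: strtobool("YES") == 1 and strtobool("off") == 0, while
+ # strtobool("maybe") raises ValueError.
+
+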
+def infer_framework_from_repr(x):
+ """
+ Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the
+ frameworks in a smart order, without the need to import the frameworks).
+ """
+ representation = str(type(x))
+ if representation.startswith("
+
+ You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple
+ before.
+
+
+ """
+
+ def __init_subclass__(cls) -> None:
+ """Register subclasses as pytree nodes.
+
+ This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
+ `static_graph=True` with modules that output `ModelOutput` subclasses.
+ """
+ pass
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # Subclasses of ModelOutput must use the @dataclass decorator
+ # This check is done in __init__ because the @dataclass decorator operates after __init_subclass__
+ # issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed
+ # Just need to check that the current class is not ModelOutput
+ is_modeloutput_subclass = self.__class__ != ModelOutput
+
+ if is_modeloutput_subclass and not is_dataclass(self):
+ raise TypeError(
+ f"{self.__module__}.{self.__class__.__name__} is not a dataclasss."
+ " This is a subclass of ModelOutput and so must use the @dataclass decorator."
+ )
+
+ def __post_init__(self):
+ """Check the ModelOutput dataclass.
+
+ Only occurs if @dataclass decorator has been used.
+ """
+ class_fields = fields(self)
+
+ # Safety and consistency checks
+ if not len(class_fields):
+ raise ValueError(f"{self.__class__.__name__} has no fields.")
+ if not all(field.default is None for field in class_fields[1:]):
+ raise ValueError(
+ f"{self.__class__.__name__} should not have more than one required field."
+ )
+
+ first_field = getattr(self, class_fields[0].name)
+ other_fields_are_none = all(
+ getattr(self, field.name) is None for field in class_fields[1:]
+ )
+
+ if other_fields_are_none and not is_tensor(first_field):
+ if isinstance(first_field, dict):
+ iterator = first_field.items()
+ first_field_iterator = True
+ else:
+ try:
+ iterator = iter(first_field)
+ first_field_iterator = True
+ except TypeError:
+ first_field_iterator = False
+
+ # if we provided an iterator as first field and the iterator is a (key, value) iterator
+ # set the associated fields
+ if first_field_iterator:
+ for idx, element in enumerate(iterator):
+ if (
+ not isinstance(element, (list, tuple))
+ or not len(element) == 2
+ or not isinstance(element[0], str)
+ ):
+ if idx == 0:
+ # If we do not have an iterator of key/values, set it as attribute
+ self[class_fields[0].name] = first_field
+ else:
+ # If we have a mixed iterator, raise an error
+ raise ValueError(
+ f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
+ )
+ break
+ setattr(self, element[0], element[1])
+ if element[1] is not None:
+ self[element[0]] = element[1]
+ elif first_field is not None:
+ self[class_fields[0].name] = first_field
+ else:
+ for field in class_fields:
+ v = getattr(self, field.name)
+ if v is not None:
+ self[field.name] = v
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(
+ f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
+ )
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(
+ f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
+ )
+
+ def pop(self, *args, **kwargs):
+ raise Exception(
+ f"You cannot use ``pop`` on a {self.__class__.__name__} instance."
+ )
+
+ def update(self, *args, **kwargs):
+ raise Exception(
+ f"You cannot use ``update`` on a {self.__class__.__name__} instance."
+ )
+
+ def __getitem__(self, k):
+ if isinstance(k, str):
+ inner_dict = dict(self.items())
+ return inner_dict[k]
+ else:
+ return self.to_tuple()[k]
+
+ def __setattr__(self, name, value):
+ if name in self.keys() and value is not None:
+ # Don't call self.__setitem__ to avoid recursion errors
+ super().__setitem__(name, value)
+ super().__setattr__(name, value)
+
+ def __setitem__(self, key, value):
+ # Will raise a KeyException if needed
+ super().__setitem__(key, value)
+ # Don't call self.__setattr__ to avoid recursion errors
+ super().__setattr__(key, value)
+
+ def __reduce__(self):
+ if not is_dataclass(self):
+ return super().__reduce__()
+ callable, _args, *remaining = super().__reduce__()
+ args = tuple(getattr(self, field.name) for field in fields(self))
+ return callable, args, *remaining
+
+ def to_tuple(self) -> Tuple[Any]:
+ """
+ Convert self to a tuple containing all the attributes/keys that are not `None`.
+ """
+ return tuple(self[k] for k in self.keys())
+
+
+class ExplicitEnum(str, Enum):
+ """
+ Enum with more explicit error message for missing values.
+ """
+
+ @classmethod
+ def _missing_(cls, value):
+ raise ValueError(
+ f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
+ )
+
+
+class ContextManagers:
+ """
+ Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
+ in the `fastcore` library.
+ """
+
+ def __init__(self, context_managers: List[ContextManager]):
+ self.context_managers = context_managers
+ self.stack = ExitStack()
+
+ def __enter__(self):
+ for context_manager in self.context_managers:
+ self.stack.enter_context(context_manager)
+
+ def __exit__(self, *args, **kwargs):
+ self.stack.__exit__(*args, **kwargs)
+
+
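+ # Example (illustrative): enter several context managers as a single unit.
+ #
+ # with ContextManagers([open("a.txt", "w"), open("b.txt", "w")]):
+ # ... # both files are open here and are closed together on exit
+
+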
+def can_return_loss(model_class):
+ """
+ Check if a given model can return loss.
+
+ Args:
+ model_class (`type`): The class of the model.
+ """
+ framework = infer_framework(model_class)
+ if framework == "tf":
+ signature = inspect.signature(model_class.call) # TensorFlow models
+ elif framework == "pt":
+ signature = inspect.signature(model_class.forward) # PyTorch models
+ else:
+ signature = inspect.signature(model_class.__call__) # Flax models
+
+ for p in signature.parameters:
+ if p == "return_loss" and signature.parameters[p].default is True:
+ return True
+
+ return False
+
+
+def find_labels(model_class):
+ """
+ Find the labels used by a given model.
+
+ Args:
+ model_class (`type`): The class of the model.
+ """
+ model_name = model_class.__name__
+ framework = infer_framework(model_class)
+ if framework == "tf":
+ signature = inspect.signature(model_class.call) # TensorFlow models
+ elif framework == "pt":
+ signature = inspect.signature(model_class.forward) # PyTorch models
+ else:
+ signature = inspect.signature(model_class.__call__) # Flax models
+
+ if "QuestionAnswering" in model_name:
+ return [
+ p
+ for p in signature.parameters
+ if "label" in p or p in ("start_positions", "end_positions")
+ ]
+ else:
+ return [p for p in signature.parameters if "label" in p]
+
+
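+ # Example (illustrative, borrowing transformers naming conventions): for a class
+ # named like "XForQuestionAnswering" whose signature includes start_positions and
+ # end_positions, find_labels returns ["start_positions", "end_positions"]; for
+ # most other models it returns ["labels"].
+
+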
+def flatten_dict(d: MutableMapping, parent_key: str = "", delimiter: str = "."):
+ """Flatten a nested dict into a single level dict."""
+
+ def _flatten_dict(d, parent_key="", delimiter="."):
+ for k, v in d.items():
+ key = str(parent_key) + delimiter + str(k) if parent_key else k
+ if v and isinstance(v, MutableMapping):
+ yield from flatten_dict(v, key, delimiter=delimiter).items()
+ else:
+ yield key, v
+
+ return dict(_flatten_dict(d, parent_key, delimiter))
+
+
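+ # Example: flatten_dict({"a": {"b": 1, "c": {"d": 2}}}) returns
+ # {"a.b": 1, "a.c.d": 2}.
+
+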
+@contextmanager
+def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
+ if use_temp_dir:
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ yield tmp_dir
+ else:
+ yield working_dir
+
+
+def add_model_info_to_auto_map(auto_map, repo_id):
+ """
+ Adds the information of the repo_id to a given auto map.
+ """
+ for key, value in auto_map.items():
+ if isinstance(value, (tuple, list)):
+ auto_map[key] = [
+ f"{repo_id}--{v}" if (v is not None and "--" not in v) else v
+ for v in value
+ ]
+ elif value is not None and "--" not in value:
+ auto_map[key] = f"{repo_id}--{value}"
+
+ return auto_map
+
+
+def infer_framework(model_class):
+ """
+ Infers the framework of a given model without using isinstance(), because we cannot guarantee that the relevant
+ classes are imported or available.
+ """
+ for base_class in inspect.getmro(model_class):
+ module = base_class.__module__
+ name = base_class.__name__
+ if (
+ module.startswith("tensorflow")
+ or module.startswith("keras")
+ or name == "TFPreTrainedModel"
+ ):
+ return "tf"
+ elif module.startswith("torch") or name == "PreTrainedModel":
+ return "pt"
+ elif (
+ module.startswith("flax")
+ or module.startswith("jax")
+ or name == "FlaxPreTrainedModel"
+ ):
+ return "flax"
+ raise TypeError(f"Could not infer framework from class {model_class}.")
diff --git a/src/biofit/utils/gorilla.py b/src/biofit/utils/gorilla.py
new file mode 100644
index 0000000..7462939
--- /dev/null
+++ b/src/biofit/utils/gorilla.py
@@ -0,0 +1,1009 @@
+# __ __ __
+# .-----.-----.----|__| | .---.-.
+# | _ | _ | _| | | | _ |
+# |___ |_____|__| |__|__|__|___._|
+# |_____|
+#
+
+"""
+NOTE: The contents of this file have been inlined from the gorilla package's source code
+https://github.com/christophercrouzet/gorilla/blob/v0.3.0/gorilla.py
+
+This module has fixes / adaptations for MLflow use cases that make it different from the original
+gorilla library
+
+The following modifications have been made:
+ - Modified the `get_original_attribute` logic: search from child classes to parent classes,
+ and for each class check the "_gorilla_original_{attr_name}" attribute first. This ensures
+ the correct original attribute is retrieved in all cases, e.g. when some classes in the
+ hierarchy have been patched while others have not; in that case the previous code risked
+ returning the wrong original attribute.
+ - Made `get_original_attribute` support bypassing the descriptor protocol.
+ - Removed the `get_attribute` function; call `get_original_attribute` with
+ `bypass_descriptor_protocol=True` instead.
+ - Reverting a patch now leaves no side effects: the object is restored to exactly its
+ original state.
+ - Remove `create_patches` and `patches` methods.
+
+gorilla
+~~~~~~~
+
+Convenient approach to monkey patching.
+
+:copyright: Copyright 2014-2017 by Christopher Crouzet.
+:license: MIT, see LICENSE for details.
+"""
+
+import base64
+import collections
+import copy
+import importlib
+import importlib.util
+import inspect
+import json
+import pkgutil
+import sys
+from types import ModuleType
+from typing import Optional, Union
+
+import dill
+from biofit.utils import logging
+
+__version__ = "0.3.0"
+logger = logging.get_logger(__name__)
+
+
+class PackageNotFoundError(Exception):
+ """Package not found error."""
+
+
+class AttributeNotFoundError(Exception):
+ """Attribute not found error."""
+
+
+class SameSourceAndDestinationError(Exception):
+ """Same source and destination error."""
+
+
+def _iteritems(d, **kwargs):
+ return iter(d.items(**kwargs))
+
+
+def _load_module(name):
+ return importlib.import_module(name)
+
+
+# Pattern for each internal attribute name.
+_PATTERN = "_gorilla_%s"
+
+ # Pattern for the name of the overridden attributes to be stored.
+_ORIGINAL_NAME = _PATTERN % ("original_%s",)
+
+# Pattern for the name of the patch attributes to be stored.
+_ACTIVE_PATCH = "_gorilla_active_patch_%s"
+
+# Attribute for the decorator data.
+_DECORATOR_DATA = _PATTERN % ("decorator_data",)
+
+
+def default_filter(name, obj):
+ """Attribute filter.
+
+ It filters out module attributes, and also methods starting with an
+ underscore `_`.
+
+ This is used as the default filter for the :func:`create_patches` function
+ and the :func:`patches` decorator.
+
+ Parameters
+ ----------
+ name : str
+ Attribute name.
+ obj : object
+ Attribute value.
+
+ Returns
+ -------
+ bool
+ Whether the attribute should be returned.
+ """
+ return not (isinstance(obj, ModuleType) or name.startswith("_"))
+
+
+class DecoratorData:
+ """Decorator data.
+
+ Attributes
+ ----------
+ patches : list of gorilla.Patch
+ Patches created through the decorators.
+ override : dict
+ Any overriding value defined by the :func:`destination`, :func:`name`,
+ and :func:`settings` decorators.
+ filter : bool or None
+ Value defined by the :func:`filter` decorator, if any, or `None`
+ otherwise.
+ """
+
+ def __init__(self):
+ """Constructor."""
+ self.patches = []
+ self.override = {}
+ self.filter = None
+
+
+class Settings:
+ """Define the patching behaviour.
+
+ Attributes
+ ----------
+ allow_hit : bool
+ A hit occurs when an attribute at the destination already exists with
+ the name given by the patch. If `False`, the patch process won't
+ allow setting a new value for the attribute by raising an exception.
+ Defaults to `False`.
+ store_hit : bool
+ If `True` and :attr:`allow_hit` is also set to `True`, then any
+ attribute at the destination that is hit is stored under a different
+ name before being overwritten by the patch. Defaults to `True`.
+ """
+
+ def __init__(self, **kwargs):
+ """Constructor.
+
+ Parameters
+ ----------
+ kwargs
+ Keyword arguments, see the attributes.
+ """
+ self.allow_hit = False
+ self.store_hit = True
+ self._update(**kwargs)
+
+ def __repr__(self):
+ values = ", ".join(
+ [f"{key}={value!r}" for key, value in sorted(_iteritems(self.__dict__))]
+ )
+ return f"{type(self).__name__}({values})"
+
+ def __eq__(self, other):
+ if isinstance(other, type(self)):
+ return self.__dict__ == other.__dict__
+
+ return NotImplemented
+
+ def __ne__(self, other):
+ is_equal = self.__eq__(other)
+ return is_equal if is_equal is NotImplemented else not is_equal
+
+ def _update(self, **kwargs):
+ """Update some settings.
+
+ Parameters
+ ----------
+ kwargs
+ Settings to update.
+ """
+ self.__dict__.update(**kwargs)
+
+ def to_dict(self):
+ return self.__dict__
+
+ @classmethod
+ def from_dict(cls, json):
+ return cls(**json)
+
+ def to_json(self, filename=None):
+ if filename:
+ with open(filename, "w") as f:
+ return json.dump(self.to_dict(), f)
+ else:
+ return json.dumps(self.to_dict())
+
+ @classmethod
+ def from_json(cls, json_str):
+ return cls.from_dict(json.loads(json_str))
+
+ @classmethod
+ def from_json_file(cls, filename):
+ with open(filename, "r") as f:
+ return cls.from_json(f.read())
+
+
+class Patch:
+ """Describe all the information required to apply a patch.
+
+ Attributes
+ ----------
+ destination(`object`):
+ Patch destination.
+ name(`str`):
+ Name of the attribute at the destination.
+ obj(`object`):
+ Attribute value.
+ source_name(`str`, *optional*):
+ Name of the attribute at the source. If `None`, then the name of the attribute is the same as `name`.
+ source(`Union[str, ModuleType]`, *optional*):
+ The fully qualified name of the object. If `None`, then the object will not be serializable.
+ For example, if object `obj` is defined in module `package1.submodule1`, then `source`
+ should be `package1.submodule1.obj`. If object is a module, then `source` should be the module name.
+ settings(`gorilla.Settings`, *optional*):
+ Settings. If `None`, the default settings are used.
+
+ Warning
+ -------
+ It is highly recommended to use the output of the function
+ :func:`get_original_attribute` with `bypass_descriptor_protocol=True` for
+ setting the attribute :attr:`obj`. This will ensure that the descriptor
+ protocol is bypassed instead of possibly retrieving attributes invalid for
+ patching, such as bound methods.
+ """
+
+ def __init__(
+ self,
+ destination: object,
+ name: str,
+ obj: object,
+ source: Optional[Union[str, ModuleType]] = None,
+ source_name: Optional[str] = None,
+ settings: Optional[Settings] = None,
+ ):
+ """Constructor.
+
+ Parameters
+ ----------
+ destination : object
+ See the :attr:`~Patch.destination` attribute.
+ name : str
+ See the :attr:`~Patch.name` attribute.
+ obj : object
+ See the :attr:`~Patch.obj` attribute.
+ source : str or module, optional
+ See the :attr:`~Patch.source` attribute.
+ source_name : str, optional
+ See the :attr:`~Patch.source_name` attribute.
+ settings : gorilla.Settings, optional
+ See the :attr:`~Patch.settings` attribute.
+ """
+ self.destination = destination
+ self.name = name
+ self.obj = obj
+ self.source_name = source_name or name
+ if source:
+ if isinstance(source, ModuleType):
+ self.source = source.__name__
+ if not hasattr(source, self.source_name):
+ raise AttributeNotFoundError(
+ f"Cannot find attribute {self.source_name} in module {self.source}"
+ )
+ else:
+ self.source = source
+ if not importlib.util.find_spec(self.source):
+ raise PackageNotFoundError(f"Cannot find module {self.source}")
+ if not hasattr(importlib.import_module(self.source), self.source_name):
+ raise AttributeNotFoundError(
+ f"Cannot find attribute {self.source_name} in module {self.source}"
+ )
+
+ if self.source == destination.__name__: # don't patch the same module
+ raise SameSourceAndDestinationError(
+ f"Cannot patch {self.source_name} in its own module {self.source}"
+ )
+
+ self.settings = settings
+ self.is_inplace_patch = None
+
+ def __repr__(self):
+ return "{}(destination={!r}, name={!r}, source={!r}, settings={!r})".format(
+ type(self).__name__,
+ self.destination.__name__,
+ self.name,
+ self.source,
+ self.settings,
+ )
+
+ def __eq__(self, other):
+ if isinstance(other, type(self)):
+ return self.__dict__ == other.__dict__
+
+ return NotImplemented
+
+ def __ne__(self, other):
+ is_equal = self.__eq__(other)
+ return is_equal if is_equal is NotImplemented else not is_equal
+
+ def __hash__(self): # pylint: disable=useless-super-delegation
+ return super().__hash__()
+
+ def _update(self, **kwargs):
+ """Update some attributes.
+
+ If a 'settings' attribute is passed as a dict, then it will update the
+ content of the settings, if any, instead of completely overwriting it.
+
+ Parameters
+ ----------
+ kwargs
+ Attributes to update.
+
+ Raises
+ ------
+ ValueError
+ The setting doesn't exist.
+ """
+ for key, value in _iteritems(kwargs):
+ if key == "settings":
+ if isinstance(value, dict):
+ if self.settings is None:
+ self.settings = Settings(**value)
+ else:
+ self.settings._update(**value)
+ else:
+ self.settings = copy.deepcopy(value)
+ else:
+ setattr(self, key, value)
+
+ def to_dict(self, pickle=False):
+ if self.source is None and not pickle:
+ raise ValueError(
+ "Cannot serialize Patch object without `obj` source. Please specify source when creating Patch object or set pickle=True to pickle the object."
+ )
+ destination_name = getattr(self.destination, "__name__", None)
+ destination_module_name = getattr(
+ inspect.getmodule(self.destination), "__name__", None
+ )
+ destination_value = None
+ if not pickle:
+ if not type(self.destination).__name__ == "module":
+ destination_value = destination_name
+ elif pickle and not type(self.destination).__name__ == "module":
+ destination_value = base64.b64encode(dill.dumps(self.destination)).decode(
+ "utf-8"
+ )
+
+ destination = {
+ "type": type(self.destination).__name__ if not pickle else "pickle",
+ "path": destination_module_name,
+ "value": destination_value,
+ "version": find_package_version(destination_module_name),
+ }
+
+ obj = {
+ "type": type(self.obj).__name__ if not pickle else "pickle",
+ "path": self.source,
+ "value": self.source_name
+ if not pickle
+ else base64.b64encode(dill.dumps(self.obj)).decode("utf-8"),
+ "version": find_package_version(self.source),
+ }
+
+ return {
+ "destination": destination,
+ "obj": obj,
+ "name": self.name,
+ "source": self.source,
+ "source_name": self.source_name,
+ "settings": self.settings.__dict__ if self.settings else None,
+ }
+
+ @classmethod
+ def from_dict(cls, json):
+ def reconstruct_object(obj_details):
+ if obj_details["type"] == "module":
+ # The object is a module
+ obj = importlib.import_module(obj_details["path"])
+ elif obj_details["type"] == "pickle":
+ # The object is a pickled object
+ obj = dill.loads(base64.b64decode(obj_details["value"].encode("utf-8")))
+ return obj
+ elif obj_details["value"]:
+ # Object within a module
+ module = importlib.import_module(obj_details["path"])
+ obj = getattr(module, obj_details["value"])
+ else:
+ raise ValueError(f"Path must be specified for `{obj_details['type']}`")
+
+ # Check version
+ current_version = find_package_version(obj_details["path"])
+ if obj_details["version"] and current_version != obj_details["version"]:
+ logger.debug(
+ f"Version mismatch for {obj_details['path']}: "
+ f"Expected {obj_details['version']}, got {current_version}"
+ )
+ raise ValueError(
+ f"Version mismatch for {obj_details['path']}: "
+ f"Expected {obj_details['version']}, got {current_version}"
+ )
+
+ return obj
+
+ destination = reconstruct_object(json["destination"])
+ obj = reconstruct_object(json["obj"])
+ settings = Settings(**json["settings"]) if json["settings"] else None
+
+ return cls(
+ destination=destination,
+ name=json["name"],
+ obj=obj,
+ source=json["source"],
+ source_name=json["source_name"],
+ settings=settings,
+ )
+
+ def to_json(self, filename=None):
+ if filename:
+ with open(filename, "w") as f:
+ return json.dump(self.to_dict(), f)
+ else:
+ return json.dumps(self.to_dict())
+
+ def to_json_file(self, filename):
+ with open(filename, "w") as f:
+ return json.dump(self.to_dict(), f)
+
+ @classmethod
+ def from_json(cls, json_str):
+ return cls.from_dict(json.loads(json_str))
+
+ @classmethod
+ def from_json_file(cls, filename):
+ with open(filename, "r") as f:
+ return cls.from_json(f.read())
+
+
+def apply(patch):
+ """Apply a patch.
+
+ The patch's :attr:`~Patch.obj` attribute is injected into the patch's
+ :attr:`~Patch.destination` under the patch's :attr:`~Patch.name`.
+
+ This is a wrapper around calling
+ `setattr(patch.destination, patch.name, patch.obj)`.
+
+ Parameters
+ ----------
+ patch : gorilla.Patch
+ Patch.
+
+ Raises
+ ------
+ RuntimeError
+ Overwriting an existing attribute is not allowed when the setting
+ :attr:`Settings.allow_hit` is set to `True`.
+
+ Note
+ ----
+ If both the attributes :attr:`Settings.allow_hit` and
+ :attr:`Settings.store_hit` are `True` but that the target attribute seems
+ to have already been stored, then it won't be stored again to avoid losing
+ the original attribute that was stored the first time around.
+ """
+ # is_inplace_patch = True means the patch will overwrite an attribute defined
+ # directly on the destination
+ patch.is_inplace_patch = patch.name in patch.destination.__dict__
+ settings = Settings() if patch.settings is None else patch.settings
+
+ curr_active_patch = _ACTIVE_PATCH % (patch.name,)
+ if curr_active_patch in patch.destination.__dict__:
+ logger.debug(
+ f"Patch {patch.name} on {destination.__name__} already existed. Overwrite old patch."
+ )
+
+ # When a hit occurs due to an attribute at the destination already existing
+ # with the patch's name, the existing attribute is referred to as 'target'.
+ try:
+ target = get_original_attribute(
+ patch.destination, patch.name, bypass_descriptor_protocol=True
+ )
+ except AttributeError:
+ pass
+ else:
+ if not settings.allow_hit:
+ raise RuntimeError(
+ "An attribute named '%s' already exists at the destination " # noqa: UP031
+ "'%s'. Set a different name through the patch object to avoid "
+ "a name clash or set the setting 'allow_hit' to True to "
+ "overwrite the attribute. In the latter case, it is "
+ "recommended to also set the 'store_hit' setting to True in "
+ "order to store the original attribute under a different "
+ "name so it can still be accessed."
+ % (patch.name, patch.destination.__name__)
+ )
+
+ if settings.store_hit:
+ original_name = _ORIGINAL_NAME % (patch.name,)
+ setattr(patch.destination, original_name, target)
+
+ setattr(patch.destination, patch.name, patch.obj)
+ setattr(patch.destination, curr_active_patch, patch)
+
+
+def revert(patch):
+ """Revert a patch.
+ Parameters
+ ----------
+ patch : gorilla.Patch
+ Patch.
+ Note
+ ----
+ This is only possible if the attribute :attr:`Settings.store_hit` was set
+ to `True` when applying the patch and overriding an existing attribute.
+
+ Notice:
+ This method is taken from
+ https://github.com/christophercrouzet/gorilla/blob/v0.4.0/gorilla.py#L318-L351
+ with modifications for autologging disablement purposes.
+ """
+ # If a curr_active_patch attribute has not been set on the destination class for the
+ # current patch, then the patch has not been applied and there is nothing to revert.
+ curr_active_patch = _ACTIVE_PATCH % (patch.name,)
+ if curr_active_patch not in patch.destination.__dict__:
+ # already reverted.
+ return
+
+ original_name = _ORIGINAL_NAME % (patch.name,)
+
+ if patch.is_inplace_patch:
+ # check whether original_name is in destination. We cannot use hasattr because it will
+ # try to get attribute from parent classes if attribute not found in destination class.
+ if original_name not in patch.destination.__dict__:
+ raise RuntimeError(
+ "Cannot revert the attribute named '%s' since the setting " # noqa: UP031
+ "'store_hit' was not set to True when applying the patch."
+ % (patch.destination.__name__,)
+ )
+ # restore original method
+ # when reverting a patch, we need to restore the raw attribute at the patch point,
+ # so get the original attribute bypassing the descriptor protocol
+ original = object.__getattribute__(patch.destination, original_name)
+ setattr(patch.destination, patch.name, original)
+ else:
+ # delete patched method
+ delattr(patch.destination, patch.name)
+
+ if original_name in patch.destination.__dict__:
+ delattr(patch.destination, original_name)
+ delattr(patch.destination, curr_active_patch)
+
+
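+ # Example (illustrative sketch): patch a hypothetical `target_module.greet`,
+ # call the patched version, then restore the original.
+ #
+ # def patched_greet():
+ # return "patched"
+ #
+ # greet_patch = Patch(
+ # target_module, # a previously imported module (placeholder name)
+ # "greet",
+ # patched_greet,
+ # settings=Settings(allow_hit=True, store_hit=True),
+ # )
+ # apply(greet_patch)
+ # assert target_module.greet() == "patched"
+ # revert(greet_patch) # target_module.greet is the original again
+
+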
+def patch(destination, name=None, settings=None):
+ """Decorator to create a patch.
+
+ The object being decorated becomes the :attr:`~Patch.obj` attribute of the
+ patch.
+
+ Parameters
+ ----------
+ destination : object
+ Patch destination.
+ name : str
+ Name of the attribute at the destination.
+ settings : gorilla.Settings
+ Settings.
+
+ Returns
+ -------
+ object
+ The decorated object.
+
+ See Also
+ --------
+ :class:`Patch`.
+ """
+
+ def decorator(wrapped):
+ base = _get_base(wrapped)
+ name_ = base.__name__ if name is None else name
+ settings_ = copy.deepcopy(settings)
+ patch = Patch(destination, name_, wrapped, settings=settings_)
+ data = get_decorator_data(base, set_default=True)
+ data.patches.append(patch)
+ return wrapped
+
+ return decorator
+
+
+def destination(value):
+ """Modifier decorator to update a patch's destination.
+
+ This only modifies the behaviour of the :func:`create_patches` function
+ and the :func:`patches` decorator, given that their parameter
+ `use_decorators` is set to `True`.
+
+ Parameters
+ ----------
+ value : object
+ Patch destination.
+
+ Returns
+ -------
+ object
+ The decorated object.
+ """
+
+ def decorator(wrapped):
+ data = get_decorator_data(_get_base(wrapped), set_default=True)
+ data.override["destination"] = value
+ return wrapped
+
+ return decorator
+
+
+def name(value):
+ """Modifier decorator to update a patch's name.
+
+ This only modifies the behaviour of the :func:`create_patches` function
+ and the :func:`patches` decorator, given that their parameter
+ `use_decorators` is set to `True`.
+
+ Parameters
+ ----------
+ value : object
+ Patch name.
+
+ Returns
+ -------
+ object
+ The decorated object.
+ """
+
+ def decorator(wrapped):
+ data = get_decorator_data(_get_base(wrapped), set_default=True)
+ data.override["name"] = value
+ return wrapped
+
+ return decorator
+
+
+def settings(**kwargs):
+ """Modifier decorator to update a patch's settings.
+
+ This only modifies the behaviour of the :func:`create_patches` function
+ and the :func:`patches` decorator, given that their parameter
+ `use_decorators` is set to `True`.
+
+ Parameters
+ ----------
+ kwargs
+ Settings to update. See :class:`Settings` for the list.
+
+ Returns
+ -------
+ object
+ The decorated object.
+ """
+
+ def decorator(wrapped):
+ data = get_decorator_data(_get_base(wrapped), set_default=True)
+ data.override.setdefault("settings", {}).update(kwargs)
+ return wrapped
+
+ return decorator
+
+
+def filter(value): # pylint: disable=redefined-builtin
+ """Modifier decorator to force the inclusion or exclusion of an attribute.
+
+ This only modifies the behaviour of the :func:`create_patches` function
+ and the :func:`patches` decorator, given that their parameter
+ `use_decorators` is set to `True`.
+
+ Parameters
+ ----------
+ value : bool
+ `True` to force inclusion, `False` to force exclusion, and `None`
+ to inherit from the behaviour defined by :func:`create_patches` or
+ :func:`patches`.
+
+ Returns
+ -------
+ object
+ The decorated object.
+ """
+
+ def decorator(wrapped):
+ data = get_decorator_data(_get_base(wrapped), set_default=True)
+ data.filter = value
+ return wrapped
+
+ return decorator
+
+
+def find_patches(modules, recursive=True):
+ """Find all the patches created through decorators.
+
+ Parameters
+ ----------
+ modules : list of module
+ Modules and/or packages to search the patches in.
+ recursive : bool
+ `True` to search recursively in subpackages.
+
+ Returns
+ -------
+ list of gorilla.Patch
+ Patches found.
+
+ Raises
+ ------
+ TypeError
+ The input is not a valid package or module.
+
+ See Also
+ --------
+ :func:`patch`, :func:`patches`.
+ """
+ out = []
+ modules = (
+ module
+ for package in modules
+ for module in _module_iterator(package, recursive=recursive)
+ )
+ for module in modules:
+ members = _get_members(module, filter=None)
+ for _, value in members:
+ base = _get_base(value)
+ decorator_data = get_decorator_data(base)
+ if decorator_data is None:
+ continue
+
+ out.extend(decorator_data.patches)
+
+ return out
+
+
+def get_original_attribute(obj, name, bypass_descriptor_protocol=False):
+ """Retrieve an overridden attribute that has been stored.
+
+ Parameters
+ ----------
+ obj : object
+ Object to search the attribute in.
+ name : str
+ Name of the attribute.
+ bypass_descriptor_protocol : bool
+ If `True`, bypass the descriptor protocol. When storing the original
+ method during patching, or restoring it while reverting a patch, set
+ this to `True` to retrieve the raw attribute object.
+
+ Returns
+ -------
+ object
+ The attribute found.
+
+ Raises
+ ------
+ AttributeError
+ The attribute couldn't be found.
+
+ Note
+ ----
+ If `store_hit=False` was used, this method may, in specific cases, return
+ the patched attribute instead of the original one after the patch is applied.
+
+ See Also
+ --------
+ :attr:`Settings.allow_hit`.
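+
+ Examples
+ --------
+ A minimal sketch, assuming a hypothetical class `Point` whose `distance`
+ method has been patched and whose original still needs to be called::
+
+ original = get_original_attribute(Point, 'distance')
+ value = original(point, x, y)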
+ """
+
+ original_name = _ORIGINAL_NAME % (name,)
+ curr_active_patch = _ACTIVE_PATCH % (name,)
+
+ def _get_attr(obj_, name_):
+ if bypass_descriptor_protocol:
+ return object.__getattribute__(obj_, name_)
+ else:
+ return getattr(obj_, name_)
+
+ no_original_stored_err = (
+ "Original attribute %s was not stored when patching; setting "
+ "store_hit=True will address this."
+ )
+
+ if inspect.isclass(obj):
+ # Search from child classes to parent classes, checking the "original_name"
+ # attribute first. This ensures the correct original attribute is found in all
+ # cases, e.g., when some classes in the hierarchy have been patched and others
+ # have not, where checking the plain name first could return the wrong attribute.
+ for obj_ in inspect.getmro(obj):
+ if original_name in obj_.__dict__:
+ return _get_attr(obj_, original_name)
+ elif name in obj_.__dict__:
+ if curr_active_patch in obj_.__dict__:
+ patch = getattr(obj, curr_active_patch)
+ if patch.is_inplace_patch:
+ raise RuntimeError(
+ no_original_stored_err % (f"{obj_.__name__}.{name}",)
+ )
+ else:
+ # non inplace patch, we can get original methods in parent classes.
+ # so go on checking parent classes
+ continue
+ return _get_attr(obj_, name)
+ else:
+ # go on checking parent classes
+ continue
+ raise AttributeError(f"'{type(obj)}' object has no attribute '{name}'")
+ else:
+ try:
+ return _get_attr(obj, original_name)
+ except AttributeError:
+ if curr_active_patch in obj.__dict__:
+ raise RuntimeError(
+ no_original_stored_err % (f"{type(obj).__name__}.{name}",)
+ )
+ return _get_attr(obj, name)
+
+
+def get_decorator_data(obj, set_default=False):
+ """Retrieve any decorator data from an object.
+
+ Parameters
+ ----------
+ obj : object
+ Object.
+ set_default : bool
+ If no data is found, a default one is set on the object and returned,
+ otherwise `None` is returned.
+
+ Returns
+ -------
+ gorilla.DecoratorData
+ The decorator data or `None`.
+ """
+ if inspect.isclass(obj):
+ datas = getattr(obj, _DECORATOR_DATA, {})
+ data = datas.setdefault(obj, None)
+ if data is None and set_default:
+ data = DecoratorData()
+ datas[obj] = data
+ setattr(obj, _DECORATOR_DATA, datas)
+ else:
+ data = getattr(obj, _DECORATOR_DATA, None)
+ if data is None and set_default:
+ data = DecoratorData()
+ setattr(obj, _DECORATOR_DATA, data)
+
+ return data
+
+
+def _get_base(obj):
+ """Unwrap decorators to retrieve the base object.
+
+ Parameters
+ ----------
+ obj : object
+ Object.
+
+ Returns
+ -------
+ object
+ The base object found or the input object otherwise.
+ """
+ if hasattr(obj, "__func__"):
+ obj = obj.__func__
+ elif isinstance(obj, property):
+ obj = obj.fget
+ elif isinstance(obj, (classmethod, staticmethod)):
+ # Fallback for Python < 2.7 back when no `__func__` attribute
+ # was defined for those descriptors.
+ obj = obj.__get__(None, object)
+ else:
+ return obj
+
+ return _get_base(obj)
+
+
+def _get_members(obj, traverse_bases=True, filter=default_filter, recursive=True):
+ """Retrieve the member attributes of a module or a class.
+
+ The descriptor protocol is bypassed.
+
+ Parameters
+ ----------
+ obj : module or class
+ Object.
+ traverse_bases : bool
+ If the object is a class, the base classes are also traversed.
+ filter : function
+ Attributes for which the function returns `False` are skipped. The
+ function needs to define two parameters: `name`, the attribute name,
+ and `obj`, the attribute value. If `None`, no attribute is skipped.
+ recursive : bool
+ `True` to search recursively through subclasses.
+
+ Returns
+ -------
+ list of (name, value)
+ A list of tuples each containing the name and the value of the
+ attribute.
+ """
+ if filter is None:
+ filter = _true
+
+ out = []
+ stack = collections.deque((obj,))
+ while stack:
+ obj = stack.popleft()
+ if traverse_bases and inspect.isclass(obj):
+ roots = [base for base in inspect.getmro(obj) if base not in (type, object)]
+ else:
+ roots = [obj]
+
+ members = []
+ seen = set()
+ for root in roots:
+ for name, value in _iteritems(getattr(root, "__dict__", {})):
+ if name not in seen and filter(name, value):
+ members.append((name, value))
+
+ seen.add(name)
+
+ members = sorted(members)
+ for _, value in members:
+ if recursive and inspect.isclass(value):
+ stack.append(value)
+
+ out.extend(members)
+
+ return out
+
+
+def _module_iterator(root, recursive=True):
+ """Iterate over modules.
+
+ Parameters
+ ----------
+ root : module
+ Root module or package to iterate from.
+ recursive : bool
+ `True` to iterate within subpackages.
+
+ Yields
+ ------
+ module
+ The modules found.
+ """
+ yield root
+
+ stack = collections.deque((root,))
+ while stack:
+ package = stack.popleft()
+ # The '__path__' attribute of a package might return a list of paths if
+ # the package is referenced as a namespace.
+ paths = getattr(package, "__path__", [])
+ for path in paths:
+ modules = pkgutil.iter_modules([path])
+ for _, name, is_package in modules:
+ module_name = f"{package.__name__}.{name}"
+ module = sys.modules.get(module_name, None)
+ if module is None:
+ # Import the module through the finder to support package
+ # namespaces.
+ try:
+ module = _load_module(module_name)
+ except ImportError:
+ # Modules that fail to import are treated as optional; skip them.
+ continue
+ if is_package:
+ if recursive:
+ stack.append(module)
+ yield module
+ else:
+ yield module
+
+
+def find_package_version(module_name: str):
+ """Return the `__version__` of the closest parent package of `module_name`
+ that defines one, or `None` if none does."""
+ if not module_name:
+ return None
+ mods = module_name.split(".")
+ for i in range(len(mods), 0, -1):
+ m = ".".join(mods[:i])
+ mod = importlib.import_module(m)
+ if hasattr(mod, "__version__"):
+ return mod.__version__
+ return None
+
+
+def _true(*args, **kwargs): # pylint: disable=unused-argument
+ """Return `True`."""
+ return True
diff --git a/src/biofit/utils/logging.py b/src/biofit/utils/logging.py
new file mode 100644
index 0000000..2cf8c55
--- /dev/null
+++ b/src/biofit/utils/logging.py
@@ -0,0 +1,447 @@
+# coding=utf-8
+# Copyright 2020 Optuna, Hugging Face
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Logging utilities."""
+
+import functools
+import logging
+import os
+import sys
+import threading
+import warnings
+from logging import (
+ CRITICAL, # NOQA
+ DEBUG, # NOQA
+ ERROR, # NOQA
+ FATAL, # NOQA
+ INFO, # NOQA
+ NOTSET, # NOQA
+ WARN, # NOQA
+ WARNING, # NOQA
+)
+from logging import captureWarnings as _captureWarnings
+from typing import Optional
+
+from biocore.utils.import_util import is_datasets_available
+from tqdm import auto as tqdm_lib
+
+_lock = threading.Lock()
+_default_handler: Optional[logging.Handler] = None
+
+log_levels = {
+ "detail": logging.DEBUG, # will also print filename and line number
+ "debug": logging.DEBUG,
+ "info": logging.INFO,
+ "warning": logging.WARNING,
+ "error": logging.ERROR,
+ "critical": logging.CRITICAL,
+}
+
+_default_log_level = logging.WARNING
+
+_tqdm_active = True
+
+
+def _get_default_logging_level():
+ """
+ If the BIOFIT_VERBOSITY env var is set to one of the valid choices, return that as the new default level.
+ Otherwise, fall back to `_default_log_level`.
+ """
+ env_level_str = os.getenv("BIOFIT_VERBOSITY", None)
+ if env_level_str:
+ if env_level_str in log_levels:
+ return log_levels[env_level_str]
+ else:
+ logging.getLogger().warning(
+ f"Unknown option BIOFIT_VERBOSITY={env_level_str}, "
+ f"has to be one of: {', '.join(log_levels.keys())}"
+ )
+
+ return _default_log_level
+
+
+def _get_library_name() -> str:
+ return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> logging.Logger:
+ return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+ global _default_handler
+
+ with _lock:
+ if _default_handler:
+ # This library has already configured the library root logger.
+ return
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
+ # set defaults based on https://github.com/pyinstaller/pyinstaller/issues/7334#issuecomment-1357447176
+ if sys.stderr is None:
+ sys.stderr = open(os.devnull, "w")
+
+ _default_handler.flush = sys.stderr.flush
+
+ # Apply our default configuration to the library root logger.
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.addHandler(_default_handler)
+ library_root_logger.setLevel(_get_default_logging_level())
+ # if logging level is debug, we add pathname and lineno to formatter for easy debugging
+ if os.getenv("BIOFIT_VERBOSITY", None) == "detail":
+ formatter = logging.Formatter(
+ "[%(levelname)s|%(pathname)s:%(lineno)s] %(asctime)s >> %(message)s"
+ )
+ _default_handler.setFormatter(formatter)
+
+ library_root_logger.propagate = False
+
+
+def _reset_library_root_logger() -> None:
+ global _default_handler
+
+ with _lock:
+ if not _default_handler:
+ return
+
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.removeHandler(_default_handler)
+ library_root_logger.setLevel(logging.NOTSET)
+ _default_handler = None
+
+
+def get_log_levels_dict():
+ return log_levels
+
+
+def captureWarnings(capture):
+ """
+ Calls the `captureWarnings` method from the logging library to enable management of the warnings emitted by the
+ `warnings` library.
+
+ Read more about this method here:
+ https://docs.python.org/3/library/logging.html#integration-with-the-warnings-module
+
+ All warnings will be logged through the `py.warnings` logger.
+
+ Careful: this method also adds a handler to this logger if it does not already have one, and updates the logging
+ level of that logger to match the library's root logger.
+ """
+ logger = get_logger("py.warnings")
+
+ if not logger.handlers:
+ logger.addHandler(_default_handler)
+
+ logger.setLevel(_get_library_root_logger().level)
+
+ _captureWarnings(capture)
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+ """
+ Return a logger with the specified name.
+
+ This function is not supposed to be directly accessed unless you are writing a custom biofit module.
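+
+ Example (a minimal sketch):
+
+ ```py
+ >>> from biofit.utils import logging
+ >>> logger = logging.get_logger(__name__)
+ >>> logger.info("message")
+ ```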
+ """
+
+ if name is None:
+ name = _get_library_name()
+
+ _configure_library_root_logger()
+ return logging.getLogger(name)
+
+
+def get_verbosity() -> int:
+ """
+ Return the current level of biofit's root logger as an int.
+
+ Returns:
+ `int`: The logging level.
+
+ biofit has the following logging levels:
+
+ - 50: `biofit.logging.CRITICAL` or `biofit.logging.FATAL`
+ - 40: `biofit.logging.ERROR`
+ - 30: `biofit.logging.WARNING` or `biofit.logging.WARN`
+ - 20: `biofit.logging.INFO`
+ - 10: `biofit.logging.DEBUG`
+
+ """
+
+ _configure_library_root_logger()
+ return _get_library_root_logger().getEffectiveLevel()
+
+
+def set_verbosity(verbosity: int) -> None:
+ """
+ Set the verbosity level for biofit's root logger.
+
+ Args:
+ verbosity (`int`):
+ Logging level, e.g., one of:
+
+ - `biofit.logging.CRITICAL` or `biofit.logging.FATAL`
+ - `biofit.logging.ERROR`
+ - `biofit.logging.WARNING` or `biofit.logging.WARN`
+ - `biofit.logging.INFO`
+ - `biofit.logging.DEBUG`
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().setLevel(verbosity)
+
+
+def set_verbosity_info():
+ """Set the verbosity to the `INFO` level."""
+ return set_verbosity(INFO)
+
+
+def set_verbosity_warning():
+ """Set the verbosity to the `WARNING` level."""
+ return set_verbosity(WARNING)
+
+
+def set_verbosity_debug():
+ """Set the verbosity to the `DEBUG` level."""
+ return set_verbosity(DEBUG)
+
+
+def set_verbosity_error():
+ """Set the verbosity to the `ERROR` level."""
+ return set_verbosity(ERROR)
+
+
+def disable_default_handler() -> None:
+ """Disable the default handler of the biofit's root logger."""
+
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().removeHandler(_default_handler)
+
+
+def enable_default_handler() -> None:
+ """Enable the default handler of the biofit's root logger."""
+
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().addHandler(_default_handler)
+
+
+def set_default_handler(handler: logging.Handler) -> None:
+ """Set the default handler of biofit's root logger."""
+ global _default_handler
+
+ _reset_library_root_logger()
+
+ assert handler is not None
+ # without the global declaration above, this assignment would only bind a
+ # local name and the handler would never be installed
+ _default_handler = handler
+ _get_library_root_logger().addHandler(handler)
+
+
+def add_handler(handler: logging.Handler) -> None:
+ """adds a handler to the biofit's root logger."""
+
+ _configure_library_root_logger()
+
+ assert handler is not None
+ _get_library_root_logger().addHandler(handler)
+
+
+def set_formatter(formatter: logging.Formatter) -> None:
+ """adds a formatter to the biofit's default handler"""
+ global _default_handler
+
+ _configure_library_root_logger()
+
+ assert formatter is not None
+ _default_handler.setFormatter(formatter)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+ """removes given handler from the biofit's root logger."""
+
+ _configure_library_root_logger()
+
+ assert handler is not None and handler in _get_library_root_logger().handlers
+ _get_library_root_logger().removeHandler(handler)
+
+
+def disable_propagation() -> None:
+ """
+ Disable propagation of the library log outputs. Note that log propagation is disabled by default.
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = False
+
+
+def enable_propagation() -> None:
+ """
+ Enable propagation of the library log outputs. Please disable biofit's default handler to
+ prevent double logging if the root logger has been configured.
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = True
+
+
+def enable_explicit_format() -> None:
+ """
+ Enable explicit formatting for every biofit logger. The explicit formatter is as follows:
+ ```
+ [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
+ ```
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ formatter = logging.Formatter(
+ "[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s"
+ )
+ handler.setFormatter(formatter)
+
+
+def reset_format() -> None:
+ """
+ Resets the formatting for biofit's loggers.
+
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ handler.setFormatter(None)
+
+
+def warning_advice(self, *args, **kwargs):
+ """
+ This method is identical to `logger.warning()`, but if env var BIOFIT_NO_ADVISORY_WARNINGS=1 is set, this
+ warning will not be printed
+ """
+ no_advisory_warnings = os.getenv("BIOFIT_NO_ADVISORY_WARNINGS", False)
+ if no_advisory_warnings:
+ return
+ self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_advice = warning_advice
+
+
+@functools.lru_cache(None)
+def warning_once(self, *args, **kwargs):
+ """
+ This method is identical to `logger.warning()`, but will emit the warning with the same message only once
+
+ Note: The cache is keyed on the function arguments, so two different callers using the same arguments will hit the cache.
+ The assumption here is that all warning messages are unique across the code. If they aren't, switch to
+ another type of cache that includes the caller frame information in the hashing function.
+ """
+ self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_once = warning_once
+
+
+class EmptyTqdm:
+ """Dummy tqdm which doesn't do anything."""
+
+ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
+ self._iterator = args[0] if args else None
+
+ def __iter__(self):
+ return iter(self._iterator)
+
+ def __getattr__(self, _):
+ """Return empty function."""
+
+ def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
+ return
+
+ return empty_fn
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ return
+
+
+class _tqdm_cls:
+ def __call__(self, *args, **kwargs):
+ if _tqdm_active:
+ return tqdm_lib.tqdm(*args, **kwargs)
+ else:
+ return EmptyTqdm(*args, **kwargs)
+
+ def set_lock(self, *args, **kwargs):
+ self._lock = None
+ if _tqdm_active:
+ return tqdm_lib.tqdm.set_lock(*args, **kwargs)
+
+ def get_lock(self):
+ if _tqdm_active:
+ return tqdm_lib.tqdm.get_lock()
+
+
+tqdm = _tqdm_cls()
+
+
+def is_progress_bar_enabled() -> bool:
+ """Return a boolean indicating whether tqdm progress bars are enabled."""
+ global _tqdm_active
+ return bool(_tqdm_active)
+
+
+def enable_progress_bar():
+ """Enable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = True
+
+
+def disable_progress_bar():
+ """Disable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = False
+
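+# Example usage (a minimal sketch): the module-level `tqdm` callable defined
+# above respects these toggles.
+#
+# disable_progress_bar()
+# for _ in tqdm(range(10)): # returns an EmptyTqdm; nothing is drawn
+# pass
+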
+
+def silence():
+ """Silence biofit (and, if installed, datasets) logging and progress bars,
+ returning the previous state for :func:`unsilence`."""
+ warnings.filterwarnings("ignore")
+ verb = get_verbosity()
+ is_pb_enabled = is_progress_bar_enabled()
+ set_verbosity(0)
+ disable_progress_bar()
+ dataset_verb = None
+ is_ds_pb_enabled = None
+ if is_datasets_available():
+ import datasets
+
+ dataset_verb = datasets.logging.get_verbosity()
+ is_ds_pb_enabled = datasets.is_progress_bar_enabled()
+ datasets.disable_progress_bars()
+ return verb, dataset_verb, is_pb_enabled, is_ds_pb_enabled
+
+
+def unsilence(verbosity, dataset_verbosity, is_pb_enabled, is_ds_pb_enabled):
+ warnings.filterwarnings("default")
+ set_verbosity(verbosity)
+ if is_pb_enabled:
+ enable_progress_bar()
+ if is_ds_pb_enabled is not None and is_datasets_available():
+ import datasets
+
+ datasets.logging.set_verbosity(dataset_verbosity)
+ if is_ds_pb_enabled:
+ # only re-enable progress bars if they were on before silencing
+ datasets.enable_progress_bars()
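+
+
+# Example usage (a minimal sketch): temporarily mute library output and
+# restore the previous state afterwards.
+#
+# state = silence()
+# try:
+# ... # noisy operation
+# finally:
+# unsilence(*state)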
diff --git a/src/biofit/utils/py_util.py b/src/biofit/utils/py_util.py
new file mode 100644
index 0000000..37c33ce
--- /dev/null
+++ b/src/biofit/utils/py_util.py
@@ -0,0 +1,209 @@
+# Copyright 2024 Patrick Smyth and Hugging Face authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import multiprocessing
+import os
+import queue
+import random
+import sys
+from datetime import date, datetime, time, timedelta
+from decimal import Decimal
+from queue import Empty
+from typing import Callable, Iterable, Set, TypeVar, Union
+
+import multiprocess
+import multiprocess.pool
+import numpy as np
+from biocore.utils.import_util import (
+ is_datasets_available,
+ is_polars_available,
+ is_tf_available,
+ is_torch_available,
+)
+
+Y = TypeVar("Y")
+
+
+def is_temporal(val):
+ return isinstance(val, (datetime, date, time, timedelta))
+
+
+def is_decimal(val):
+ return isinstance(val, Decimal)
+
+
+def as_py(val):
+ # convert temporal, decimal, bytes and nested values to plain Python equivalents
+ if isinstance(val, datetime):
+ return val.timestamp()
+ elif isinstance(val, date):
+ return val.toordinal()
+ elif isinstance(val, time):
+ return val.hour * 3600 + val.minute * 60 + val.second + val.microsecond / 1e6
+ elif isinstance(val, timedelta):
+ return val.total_seconds()
+ elif isinstance(val, Decimal):
+ return float(val)
+ elif isinstance(val, np.float16):
+ return float(val)
+ elif isinstance(val, bytes):
+ return f"base64:{val.hex()}"
+ elif isinstance(val, dict):
+ return {k: as_py(v) for k, v in val.items()}
+ return val
+
+
+def enable_full_determinism(seed: int, warn_only: bool = False):
+ """
+ Helper function for reproducible behavior during distributed training. See
+ - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+ - https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism for tensorflow
+ """
+ # set seed first
+ set_seed(seed)
+
+ if is_torch_available():
+ import torch
+
+ # Enable PyTorch deterministic mode. This potentially requires either the environment
+ # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+ # depending on the CUDA version, so we set them both here
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+ torch.use_deterministic_algorithms(True, warn_only=warn_only)
+
+ # Enable CUDNN deterministic mode
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ if is_tf_available():
+ import tensorflow as tf
+
+ tf.config.experimental.enable_op_determinism()
+
+
+def set_seed(seed: int):
+ """
+ Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).
+
+ Args:
+ seed (`int`): The seed to set.
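+
+ Example (a minimal sketch):
+
+ ```py
+ >>> set_seed(42) # seeds `random` and `numpy`, plus polars/torch/tf when installed
+ ```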
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+
+ if is_polars_available():
+ from polars import set_random_seed
+
+ set_random_seed(seed)
+
+ def is_torch_npu_available():
+ try:
+ import torch
+
+ return torch.npu.is_available()
+ except ImportError:
+ return False
+
+ if is_torch_available():
+ import torch
+
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ if is_tf_available():
+ import tensorflow as tf
+
+ tf.random.set_seed(seed)
+
+
+# This function is a copy of the one in datasets.utils.py_util.py, Copyright 2024
+# Hugging Face authors. Licensed under the Apache 2.0 license. See the license file for
+# details at https://www.apache.org/licenses/LICENSE-2.0
+def _get_pool_pid(
+ pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool],
+) -> Set[int]:
+ return {f.pid for f in pool._pool}
+
+
+# This function is a copy of the one in datasets.utils.py_util.py, Copyright 2024
+# Hugging Face authors. Licensed under the Apache 2.0 license. See the license file for
+# details at https://www.apache.org/licenses/LICENSE-2.0
+def _write_generator_to_queue(
+ queue: queue.Queue, func: Callable[..., Iterable[Y]], kwargs: dict
+) -> int:
+ for i, result in enumerate(func(**kwargs)):
+ queue.put(result)
+ return i
+
+
+# This function is a copy of the one in datasets.utils.py_util.py, Copyright 2024
+# Hugging Face authors. Licensed under the Apache 2.0 license. See the license file for
+# details at https://www.apache.org/licenses/LICENSE-2.0
+def iflatmap_unordered(
+ pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool],
+ func: Callable[..., Iterable[Y]],
+ *,
+ kwargs_iterable: Iterable[dict],
+) -> Iterable[Y]:
+ """
+ Similar to `itertools.chain.from_iterable(map(func, iterable))` but with a potentially more efficient implementation
+ that doesn't require storing all the intermediate results in memory.
+
+ Args:
+ pool (`Pool`): The multiprocessing pool to run the calls in.
+ func (`Callable`): A function returning an iterable, called once per element of `kwargs_iterable`.
+ kwargs_iterable (`Iterable[dict]`): The keyword arguments to pass to each call of `func`.
+
+ Returns:
+ Iterator: The flattened, unordered results.
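+
+ Example (a minimal sketch; `gen` is a hypothetical generator function):
+
+ ```py
+ >>> def gen(n):
+ ... yield from range(n)
+ >>> with multiprocess.Pool(2) as pool:
+ ... out = list(iflatmap_unordered(pool, gen, kwargs_iterable=[{"n": 2}, {"n": 3}]))
+ ```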
+ """
+ if is_datasets_available():
+ from datasets.utils.py_utils import iflatmap_unordered
+
+ return iflatmap_unordered(pool, func, kwargs_iterable=kwargs_iterable)
+ else:
+ initial_pool_pid = _get_pool_pid(pool)
+ pool_changed = False
+ manager_cls = (
+ multiprocessing.Manager
+ if isinstance(pool, multiprocessing.pool.Pool)
+ else multiprocess.Manager
+ )
+ with manager_cls() as manager:
+ queue = manager.Queue()
+ async_results = [
+ pool.apply_async(_write_generator_to_queue, (queue, func, kwargs))
+ for kwargs in kwargs_iterable
+ ]
+ try:
+ while True:
+ try:
+ yield queue.get(timeout=0.05)
+ except Empty:
+ if (
+ all(async_result.ready() for async_result in async_results)
+ and queue.empty()
+ ):
+ break
+ if _get_pool_pid(pool) != initial_pool_pid:
+ pool_changed = True
+ # One of the subprocesses has died. We should not wait forever.
+ raise RuntimeError(
+ "One of the subprocesses has abruptly died during map operation."
+ "To debug the error, disable multiprocessing."
+ )
+ finally:
+ if not pool_changed:
+ # we get the result in case there's an error to raise
+ [async_result.get(timeout=0.05) for async_result in async_results]
diff --git a/src/biofit/utils/recorder.py b/src/biofit/utils/recorder.py
new file mode 100644
index 0000000..9bb1aff
--- /dev/null
+++ b/src/biofit/utils/recorder.py
@@ -0,0 +1,156 @@
+from typing import List, Optional, Union, Callable
+import sys
+import inspect
+import importlib
+from functools import wraps
+from pathlib import Path
+
+
+from .. import config
+from .file_utils import is_file_name, PathLike
+from . import logging
+
+logger = logging.get_logger(__name__)
+
+
+def load_module_or_class(full_path):
+ if full_path is None:
+ return None
+ try:
+ module_path, entity_name = full_path.rsplit(".", 1)
+ module = (
+ importlib.import_module(module_path)
+ if module_path not in sys.modules
+ else sys.modules[module_path]
+ )
+ except ImportError:
+ # Handle the case where the module can't be imported
+ logger.error(f"Failed to import module: {module_path}")
+ return None
+
+ try:
+ entity = (
+ getattr(module, entity_name)
+ if entity_name not in sys.modules
+ else sys.modules[entity_name]
+ )
+ except AttributeError:
+ # Handle the case where the entity isn't found in the module
+ logger.error(f"{entity_name} not found in module: {module_path}")
+ return None
+
+ return entity
+
+
+def load_method(full_path, method_name):
+ entity = load_module_or_class(full_path)
+ if entity:
+ return getattr(entity, method_name)
+ return None
+
+
+def _get_cache_dir(
+ cache_files: Optional[List[dict]] = None, cache_file_name: Optional[PathLike] = None
+) -> Optional[Path]:
+ # check if cache_file_name is a path
+ cache_dir = None
+ if cache_file_name is not None:
+ if isinstance(cache_file_name, PathLike):
+ cache_file_name = Path(cache_file_name)
+ if "/" in cache_file_name.as_posix():
+ cache_dir = cache_file_name.resolve().parent
+
+ if not cache_dir and cache_files:
+ cache_dir = Path(cache_files[0]["filename"]).resolve().parent
+
+ return cache_dir
+
+
+def _get_cache_info(
+ cache_files: Optional[List[dict]] = None,
+ cache_dir: Optional[Path] = None,
+ cache_file_name: Optional[Union[str, Path]] = None,
+ file_ext=".arrow",
+):
+ if cache_file_name:
+ if is_file_name(cache_file_name):
+ cache_dir = cache_dir or Path.cwd()
+ cache_file_name = Path(cache_dir) / Path(cache_file_name).with_suffix(
+ file_ext
+ )
+ else:
+ cache_file_name = Path(cache_file_name).with_suffix(file_ext)
+
+ cache_dir = _get_cache_dir(cache_files, cache_file_name)
+ return cache_file_name, cache_dir
+
+
+UNRECORDED_METHODS = ["train_test_split"]
+_RECORDER_ACTIVE = False
+
+
+def pre_recording(func, *args, **kwargs):
+ new_fingerprint = kwargs.get("new_fingerprint", kwargs.get("fingerprint", None))
+
+ signature = inspect.signature(func)
+ cache_file_name = kwargs.get("cache_file_name", None)
+ if not cache_file_name and "cache_file_name" in signature.parameters:
+ arg_pos = list(signature.parameters).index("cache_file_name")
+ if len(args) > arg_pos:
+ cache_file_name = args[arg_pos]
+ cache_dir = None
+ self = args[0] if args else kwargs.get("self", None)
+ if getattr(self, "cache_files", None):
+ cache_file_name, cache_dir = _get_cache_info(
+ self.cache_files, cache_dir, cache_file_name
+ )
+ if isinstance(cache_dir, Path):
+ cache_dir = cache_dir.resolve().as_posix()
+ if isinstance(cache_file_name, Path):
+ cache_file_name = cache_file_name.resolve().as_posix()
+
+ return {
+ "cache_file_name": cache_file_name,
+ "cache_dir": cache_dir,
+ "new_fingerprint": new_fingerprint,
+ }
+
+
+def toggle_recorder():
+ global _RECORDER_ACTIVE
+ _RECORDER_ACTIVE = not _RECORDER_ACTIVE
+
+
+def record_step(
+ replay_func: Optional[str] = None,
+ pre_recording: Optional[Callable] = pre_recording,
+ post_recording: Optional[Callable] = None,
+):
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if not config.RECORDER_ENABLED or _RECORDER_ACTIVE:
+ return func(*args, **kwargs)
+
+ toggle_recorder()
+ extra_info = {}
+ if pre_recording:
+ extra_info = pre_recording(func, *args, **kwargs)
+ out = func(*args, **kwargs)
+ if post_recording:
+ out = post_recording(
+ out,
+ func.__name__,
+ args,
+ kwargs,
+ replay_func=replay_func,
+ info=extra_info,
+ )
+
+ toggle_recorder()
+
+ return out
+
+ return wrapper
+
+ return decorator
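+
+
+# Example usage (a minimal sketch; the dotted replay path is hypothetical):
+#
+# @record_step(replay_func="my_pkg.replay")
+# def transform(self, X, cache_file_name=None, new_fingerprint=None):
+# ...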
diff --git a/src/biofit/utils/table_util.py b/src/biofit/utils/table_util.py
new file mode 100644
index 0000000..2af9312
--- /dev/null
+++ b/src/biofit/utils/table_util.py
@@ -0,0 +1,863 @@
+import copy
+import inspect
+import os
+import re
+import tempfile
+from collections import defaultdict
+from functools import wraps
+from pathlib import Path
+from typing import List, Union
+
+import pyarrow as pa
+from biocore.utils.import_util import (
+ is_datasets_available,
+ is_rpy2_arrow_available,
+ is_rpy2_available,
+ requires_backends,
+)
+
+from biofit import config
+
+from . import logging
+from .file_utils import move_temp_file
+
+logger = logging.get_logger(__name__)
+
+
+# Function adapted from the Hugging Face Team, Copyright 2024 The Hugging Face Team
+# Licensed under the Apache License, Version 2.0. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Modified by: Patrick Smyth
+# Date: 2024
+# Summary of changes:
+# - Changed names from pyarrow to pa for consistency
+def _arrow_to_datasets_dtype(arrow_type: pa.DataType) -> str:
+ """
+ _arrow_to_datasets_dtype takes a pa.DataType and converts it to a datasets string dtype.
+ In effect, `dt == string_to_arrow(_arrow_to_datasets_dtype(dt))`
+ """
+ if pa.types.is_null(arrow_type):
+ return "null"
+ elif pa.types.is_boolean(arrow_type):
+ return "bool"
+ elif pa.types.is_int8(arrow_type):
+ return "int8"
+ elif pa.types.is_int16(arrow_type):
+ return "int16"
+ elif pa.types.is_int32(arrow_type):
+ return "int32"
+ elif pa.types.is_int64(arrow_type):
+ return "int64"
+ elif pa.types.is_uint8(arrow_type):
+ return "uint8"
+ elif pa.types.is_uint16(arrow_type):
+ return "uint16"
+ elif pa.types.is_uint32(arrow_type):
+ return "uint32"
+ elif pa.types.is_uint64(arrow_type):
+ return "uint64"
+ elif pa.types.is_float16(arrow_type):
+ return "float16" # pa dtype is "halffloat"
+ elif pa.types.is_float32(arrow_type):
+ return "float32" # pa dtype is "float"
+ elif pa.types.is_float64(arrow_type):
+ return "float64" # pa dtype is "double"
+ elif pa.types.is_time32(arrow_type):
+ return f"time32[{pa.type_for_alias(str(arrow_type)).unit}]"
+ elif pa.types.is_time64(arrow_type):
+ return f"time64[{pa.type_for_alias(str(arrow_type)).unit}]"
+ elif pa.types.is_timestamp(arrow_type):
+ if arrow_type.tz is None:
+ return f"timestamp[{arrow_type.unit}]"
+ elif arrow_type.tz:
+ return f"timestamp[{arrow_type.unit}, tz={arrow_type.tz}]"
+ else:
+ raise ValueError(f"Unexpected timestamp object {arrow_type}.")
+ elif pa.types.is_date32(arrow_type):
+ return "date32" # pa dtype is "date32[day]"
+ elif pa.types.is_date64(arrow_type):
+ return "date64" # pa dtype is "date64[ms]"
+ elif pa.types.is_duration(arrow_type):
+ return f"duration[{arrow_type.unit}]"
+ elif pa.types.is_decimal128(arrow_type):
+ return f"decimal128({arrow_type.precision}, {arrow_type.scale})"
+ elif pa.types.is_decimal256(arrow_type):
+ return f"decimal256({arrow_type.precision}, {arrow_type.scale})"
+ elif pa.types.is_binary(arrow_type):
+ return "binary"
+ elif pa.types.is_large_binary(arrow_type):
+ return "large_binary"
+ elif pa.types.is_string(arrow_type):
+ return "string"
+ elif pa.types.is_large_string(arrow_type):
+ return "large_string"
+ elif pa.types.is_dictionary(arrow_type):
+ return _arrow_to_datasets_dtype(arrow_type.value_type)
+ else:
+ raise ValueError(
+ f"Arrow type {arrow_type} does not have a datasets dtype equivalent."
+ )
+
+
+# Function adapted from the Hugging Face Team, Copyright 2024 The Hugging Face Team
+# Licensed under the Apache License, Version 2.0. You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+def string_to_arrow(datasets_dtype: str) -> pa.DataType:
+ """
+ string_to_arrow takes a datasets string dtype and converts it to a pyarrow.DataType.
+
+ In effect, `dt == string_to_arrow(_arrow_to_datasets_dtype(dt))`
+
+ This is necessary because the datasets.Value() primitive type is constructed using a string dtype
+
+ Value(dtype=str)
+
+ But Features.type (via `get_nested_type()`) expects to resolve Features into a pyarrow Schema,
+ which means that each Value() must be able to resolve into a corresponding pyarrow.DataType, which is the
+ purpose of this function.
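+
+ Example (a minimal sketch):
+
+ ```py
+ >>> string_to_arrow("timestamp[us, tz=UTC]")
+ TimestampType(timestamp[us, tz=UTC])
+ >>> string_to_arrow("decimal128(10, 2)")
+ Decimal128Type(decimal128(10, 2))
+ ```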
+ """
+
+ def _dtype_error_msg(dtype, pa_dtype, examples=None, urls=None):
+ msg = f"{dtype} is not a validly formatted string representation of the pyarrow {pa_dtype} type."
+ if examples:
+ examples = (
+ ", ".join(examples[:-1]) + " or " + examples[-1]
+ if len(examples) > 1
+ else examples[0]
+ )
+ msg += f"\nValid examples include: {examples}."
+ if urls:
+ urls = (
+ ", ".join(urls[:-1]) + " and " + urls[-1] if len(urls) > 1 else urls[0]
+ )
+ msg += f"\nFor more insformation, see: {urls}."
+ return msg
+
+ if datasets_dtype in pa.__dict__:
+ return pa.__dict__[datasets_dtype]()
+
+ if (datasets_dtype + "_") in pa.__dict__:
+ return pa.__dict__[datasets_dtype + "_"]()
+
+ timestamp_matches = re.search(r"^timestamp\[(.*)\]$", datasets_dtype)
+ if timestamp_matches:
+ timestamp_internals = timestamp_matches.group(1)
+ internals_matches = re.search(
+ r"^(s|ms|us|ns),\s*tz=([a-zA-Z0-9/_+\-:]*)$", timestamp_internals
+ )
+ if timestamp_internals in ["s", "ms", "us", "ns"]:
+ return pa.timestamp(timestamp_internals)
+ elif internals_matches:
+ return pa.timestamp(internals_matches.group(1), internals_matches.group(2))
+ else:
+ raise ValueError(
+ _dtype_error_msg(
+ datasets_dtype,
+ "timestamp",
+ examples=["timestamp[us]", "timestamp[us, tz=America/New_York"],
+ urls=[
+ "https://arrow.apache.org/docs/python/generated/pyarrow.timestamp.html"
+ ],
+ )
+ )
+
+ duration_matches = re.search(r"^duration\[(.*)\]$", datasets_dtype)
+ if duration_matches:
+ duration_internals = duration_matches.group(1)
+ if duration_internals in ["s", "ms", "us", "ns"]:
+ return pa.duration(duration_internals)
+ else:
+ raise ValueError(
+ _dtype_error_msg(
+ datasets_dtype,
+ "duration",
+ examples=["duration[s]", "duration[us]"],
+ urls=[
+ "https://arrow.apache.org/docs/python/generated/pyarrow.duration.html"
+ ],
+ )
+ )
+
+ time_matches = re.search(r"^time(.*)\[(.*)\]$", datasets_dtype)
+ if time_matches:
+ time_internals_bits = time_matches.group(1)
+ if time_internals_bits == "32":
+ time_internals_unit = time_matches.group(2)
+ if time_internals_unit in ["s", "ms"]:
+ return pa.time32(time_internals_unit)
+ else:
+ raise ValueError(
+ f"{time_internals_unit} is not a valid unit for the pyarrow time32 type. Supported units: s (second) and ms (millisecond)."
+ )
+ elif time_internals_bits == "64":
+ time_internals_unit = time_matches.group(2)
+ if time_internals_unit in ["us", "ns"]:
+ return pa.time64(time_internals_unit)
+ else:
+ raise ValueError(
+ f"{time_internals_unit} is not a valid unit for the pyarrow time64 type. Supported units: us (microsecond) and ns (nanosecond)."
+ )
+ else:
+ raise ValueError(
+ _dtype_error_msg(
+ datasets_dtype,
+ "time",
+ examples=["time32[s]", "time64[us]"],
+ urls=[
+ "https://arrow.apache.org/docs/python/generated/pyarrow.time32.html",
+ "https://arrow.apache.org/docs/python/generated/pyarrow.time64.html",
+ ],
+ )
+ )
+
+ decimal_matches = re.search(r"^decimal(.*)\((.*)\)$", datasets_dtype)
+ if decimal_matches:
+ decimal_internals_bits = decimal_matches.group(1)
+ if decimal_internals_bits == "128":
+ decimal_internals_precision_and_scale = re.search(
+ r"^(\d+),\s*(-?\d+)$", decimal_matches.group(2)
+ )
+ if decimal_internals_precision_and_scale:
+ precision = decimal_internals_precision_and_scale.group(1)
+ scale = decimal_internals_precision_and_scale.group(2)
+ return pa.decimal128(int(precision), int(scale))
+ else:
+ raise ValueError(
+ _dtype_error_msg(
+ datasets_dtype,
+ "decimal128",
+ examples=["decimal128(10, 2)", "decimal128(4, -2)"],
+ urls=[
+ "https://arrow.apache.org/docs/python/generated/pyarrow.decimal128.html"
+ ],
+ )
+ )
+ elif decimal_internals_bits == "256":
+ decimal_internals_precision_and_scale = re.search(
+ r"^(\d+),\s*(-?\d+)$", decimal_matches.group(2)
+ )
+ if decimal_internals_precision_and_scale:
+ precision = decimal_internals_precision_and_scale.group(1)
+ scale = decimal_internals_precision_and_scale.group(2)
+ return pa.decimal256(int(precision), int(scale))
+ else:
+ raise ValueError(
+ _dtype_error_msg(
+ datasets_dtype,
+ "decimal256",
+ examples=["decimal256(30, 2)", "decimal256(38, -4)"],
+ urls=[
+ "https://arrow.apache.org/docs/python/generated/pyarrow.decimal256.html"
+ ],
+ )
+ )
+ else:
+ raise ValueError(
+ _dtype_error_msg(
+ datasets_dtype,
+ "decimal",
+ examples=["decimal128(12, 3)", "decimal256(40, 6)"],
+ urls=[
+ "https://arrow.apache.org/docs/python/generated/pyarrow.decimal128.html",
+ "https://arrow.apache.org/docs/python/generated/pyarrow.decimal256.html",
+ ],
+ )
+ )
+
+ raise ValueError(
+ f"Neither {datasets_dtype} nor {datasets_dtype + '_'} seems to be a pyarrow data type. "
+ f"Please make sure to use a correct data type, see: "
+ f"https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions"
+ )
+
+
+def is_binary_like(data_type: pa.DataType) -> bool:
+ return pa.types.is_binary(data_type) or pa.types.is_unicode(data_type)
+
+
+def is_large_binary_like(data_type: pa.DataType) -> bool:
+ return pa.types.is_large_binary(data_type) or pa.types.is_large_unicode(data_type)
+
+
+def is_fixed_width(data_type: pa.DataType) -> bool:
+ return (
+ pa.types.is_primitive(data_type)
+ or pa.types.is_dictionary(data_type)
+ or pa.types.is_large_list(data_type)
+ )
+
+
+def upcast_tables(tables: List[pa.Table]):
+ cols = None
+ cols_diff_dtypes = defaultdict(set)
+ for table in tables:
+ if cols is None:
+ cols = {col.name: col.type for col in table.schema}
+ else:
+ cols.update(
+ {col.name: col.type for col in table.schema if col.name not in cols}
+ )
+ for col in table.schema:
+ if cols[col.name] != col.type:
+ cols_diff_dtypes[col.name].update({col.type, cols[col.name]})
+ cols_to_cast = {}
+ for col, types in cols_diff_dtypes.items():
+ types = [_arrow_to_datasets_dtype(t) for t in types]
+ cols_to_cast[col] = string_to_arrow(determine_upcast(types))
+
+ new_tables = []
+ for table in tables:
+ cols = table.schema.names
+ casts = pa.schema(
+ [
+ (col.name, cols_to_cast[col.name]) if col.name in cols_to_cast else col
+ for col in table.schema
+ ]
+ )
+ new_tables.append(table.cast(casts))
+ return new_tables
+
+
+def determine_upcast(dtype_list):
+ """
+ Determines the upcasted data type from a list of data types based on a predefined hierarchy.
+
+ Args:
+ dtype_list (list): List of data types to be considered for upcasting.
+
+ Returns:
+ str: The upcasted data type.
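+
+ Example:
+
+ ```py
+ >>> determine_upcast(["int32", "float32"])
+ 'float32'
+ ```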
+ """
+
+ # Define the hierarchy of dtypes with their priority levels
+ dtype_hierarchy = {
+ "null": 0,
+ "bool": 1,
+ "int8": 2,
+ "int16": 3,
+ "int32": 4,
+ "int": 5,
+ "int64": 5,
+ "uint8": 6,
+ "uint16": 7,
+ "uint32": 8,
+ "uint64": 9,
+ "float16": 10,
+ "float": 11,
+ "float32": 11,
+ "float64": 12,
+ "time32[s]": 13,
+ "time32[ms]": 14,
+ "time64[us]": 15,
+ "time64[ns]": 16,
+ "timestamp[s]": 17,
+ "timestamp[ms]": 18,
+ "timestamp[us]": 19,
+ "timestamp[ns]": 20,
+ "date32": 21,
+ "date64": 22,
+ "duration[s]": 23,
+ "duration[ms]": 24,
+ "duration[us]": 25,
+ "duration[ns]": 26,
+ "decimal128": 27,
+ "decimal256": 28,
+ "binary": 29,
+ "large_binary": 30,
+ "string": 31,
+ "large_string": 32,
+ }
+
+ highest_priority = 0
+ upcast_dtype = "null" # the lowest in hierarchy
+
+ for dtype in dtype_list:
+ priority = dtype_hierarchy.get(dtype, None)
+ if priority is None:
+ # `not priority` would wrongly reject "null", whose priority is 0
+ raise ValueError(f"Invalid dtype found: {dtype}")
+ if priority > highest_priority:
+ highest_priority = priority
+ upcast_dtype = dtype
+
+ return upcast_dtype
+
+
+def concat_blocks(pa_tables: List[pa.Table], axis: int = 0, append=True) -> pa.Table:
+ if axis == 0:
+ # we set promote=True to fill missing columns with null values
+ if config.PYARROW_VERSION.major < 14:
+ return pa.concat_tables(pa_tables, promote=True)
+ else:
+ return pa.concat_tables(pa_tables, promote_options="permissive")
+ elif axis == 1:
+ for i, table in enumerate(pa_tables):
+ if i == 0:
+ pa_table = table
+ else:
+ for name, col in zip(table.column_names, table.columns):
+ if append:
+ pa_table = pa_table.append_column(name, col)
+ else:
+ pa_table = pa_table.add_column(0, name, col)
+ return pa_table
+ else:
+ raise ValueError("'axis' must be either 0 or 1")
+
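+# Example (a minimal sketch): row-wise vs. column-wise concatenation.
+#
+# t1 = pa.table({"a": [1]})
+# t2 = pa.table({"b": [2]})
+# concat_blocks([t1, t2], axis=0) # 2 rows; missing columns are null-filled
+# concat_blocks([t1, t2], axis=1) # 1 row; columns a and b side by side
+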
+
+def _arrow_join(
+ left_table: pa.Table,
+ right_table: pa.Table,
+ keys: Union[str, List[str]],
+ right_keys: Union[str, List[str]] = None,
+ join_type="left outer",
+ left_suffix=None,
+ right_suffix=None,
+ coalesce_keys=True,
+ use_threads=True,
+) -> pa.Table:
+ """Extends arrow's join to support joining on struct columns at any nested level.
+
+ Args:
+ right_table (`Table`):
+ The table to join to the current one, acting as the right table in the join operation.
+
+ keys (`Union[str, List[str]]`):
+ The columns from current table that should be used as keys of the join operation left side.
+
+ right_keys (`Union[str, List[str]]`, *optional*):
+ The columns from the right_table that should be used as keys on the join operation right side. When None use the same key names as the left table.
+
+ join_type (`str`, Defaults to 'left outer'):
+ The kind of join that should be performed, one of (“left semi”, “right semi”, “left anti”, “right anti”, “inner”, “left outer”, “right outer”, “full outer”)
+
+ left_suffix (`str`, *optional*):
+ Which suffix to add to left column names. This prevents confusion when the columns in left and right tables have colliding names.
+
+ right_suffix (`str`, *optional*):
+ Which suffix to add to the right column names. This prevents confusion when the columns in left and right tables have colliding names.
+
+ coalesce_keys (`bool`, Defaults to True):
+ If the duplicated keys should be omitted from one of the sides in the join result.
+
+ use_threads (`bool`, Defaults to True):
+ Whether to use multithreading or not.
+
+
+ """
+
+ if isinstance(keys, str):
+ keys = [keys]
+
+ if right_keys is None:
+ right_keys = keys
+ else:
+ if isinstance(right_keys, str):
+ right_keys = [right_keys]
+
+ left_cast = []
+ right_cast = []
+
+ def _get_struct_columns_and_prepare_cast(
+ left_schema: pa.Schema, right_schema: pa.Schema
+ ) -> dict:
+ struct_columns = []
+ keys = set()
+
+ def process_nested_dict(key, nested_schema):
+ if key in keys:
+ return
+ keys.add(key)
+ # Recurse into struct types, collecting their flattened key paths
+ if pa.types.is_struct(nested_schema):
+ full_keys = []
+ original_keys = []
+ for nested_field in nested_schema:
+ # Recursively process the nested dictionary
+ original_keys.append(nested_field.name)
+ full_key = f"{key}.{nested_field.name}"
+ full_keys.append(full_key)
+ process_nested_dict(full_key, nested_field.type)
+ struct_columns.append((key, original_keys, full_keys))
+
+ for field in left_schema:
+ if pa.types.is_struct(field.type):
+ left_cast.append(pa.field(field.name, field.type))
+ process_nested_dict(field.name, field.type)
+ elif (
+ pa.types.is_list(field.type)
+ or pa.types.is_large_list(field.type)
+ or pa.types.is_fixed_size_list(field.type)
+ ):
+ raise pa.ArrowNotImplementedError(
+ "Joining on lists is not supported. Please load them as a struct before joining. For example, change the column with value [1,2,3] to {'a': 1, 'b': 2, 'c': 3}"
+ )
+ elif pa.types.is_null(field.type):
+ left_cast.append(pa.field(field.name, pa.string()))
+ elif (
+ not pa.types.is_fixed_size_list(field.type)
+ and not is_fixed_width(field.type)
+ and not is_binary_like(field.type)
+ and not is_large_binary_like(field.type)
+ ):
+ left_cast.append(pa.field(field.name, pa.string()))
+ else:
+ left_cast.append(pa.field(field.name, field.type))
+
+ for field in right_schema:
+ if pa.types.is_struct(field.type):
+ right_cast.append(pa.field(field.name, field.type))
+ process_nested_dict(field.name, field.type)
+ elif (
+ pa.types.is_list(field.type)
+ or pa.types.is_large_list(field.type)
+ or pa.types.is_fixed_size_list(field.type)
+ ):
+ raise pa.ArrowNotImplementedError(
+ "Joining on lists is not supported. Please load them as a struct before joining. For example, change the column with value [1,2,3] to {'a': 1, 'b': 2, 'c': 3}"
+ )
+ elif pa.types.is_null(field.type):
+ right_cast.append(pa.field(field.name, pa.string()))
+ elif (
+ not pa.types.is_fixed_size_list(field.type)
+ and not is_fixed_width(field.type)
+ and not is_binary_like(field.type)
+ and not is_large_binary_like(field.type)
+ ):
+ right_cast.append(pa.field(field.name, pa.string()))
+ else:
+ right_cast.append(pa.field(field.name, field.type))
+
+ return struct_columns
+
+ def reconstruct_table(joined_table, struct_cols):
+ """
+ Reconstruct struct columns from flattened columns based on original schema.
+ """
+ for column, orig_names, nested_columns in struct_cols:
+ nested_data = {
+ sub_col: joined_table[col].combine_chunks()
+ for col, sub_col in zip(nested_columns, orig_names)
+ }
+ reconstructed_nested_col = pa.StructArray.from_arrays(
+ nested_data.values(), nested_data.keys()
+ )
+ index = joined_table.schema.get_field_index(nested_columns[0])
+ joined_table = joined_table.drop(nested_columns).add_column(
+ index, column, reconstructed_nested_col
+ )
+ return joined_table
+
+ def get_nested_level(schema, level=0):
+ """
+ Get the maximum level of nesting in a struct schema.
+ """
+ max_level = level
+ for field in schema:
+ if pa.types.is_struct(field.type):
+ max_level = max(
+ max_level, get_nested_level(field.type, level=level + 1)
+ )
+ return max_level
+
+ def flatten_table(table, max_level, current_level=0):
+ """
+ Recursively flatten a table.
+ """
+ if current_level == max_level:
+ return table
+ return flatten_table(
+ table.flatten(), max_level, current_level=current_level + 1
+ )
+
+ left_nested_level = get_nested_level(left_table.schema)
+ right_nested_level = get_nested_level(right_table.schema)
+ struct_columns = _get_struct_columns_and_prepare_cast(
+ left_table.schema, right_table.schema
+ )
+
+ colliding_names = list(
+ (set(right_table.column_names) & set(left_table.column_names))
+ - set(keys)
+ - set(right_keys)
+ )
+
+ for left_key, right_key in zip(keys, right_keys):
+ if left_table[left_key].type != right_table[right_key].type:
+ index = right_table.schema.get_field_index(right_key)
+ right_cast[index] = pa.field(right_key, left_table[left_key].type)
+
+ left_table = left_table.cast(pa.schema(left_cast)) if left_cast else left_table
+ right_table = right_table.cast(pa.schema(right_cast)) if right_cast else right_table
+
+ return reconstruct_table(
+ flatten_table(left_table, left_nested_level).join(
+ flatten_table(right_table.drop(colliding_names), right_nested_level),
+ keys=keys,
+ right_keys=right_keys,
+ join_type=join_type,
+ left_suffix=left_suffix,
+ right_suffix=right_suffix,
+ coalesce_keys=coalesce_keys,
+ use_threads=use_threads,
+ ),
+ struct_columns,
+ )
+
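+# Example (a minimal sketch): a left outer join on a shared key column.
+#
+# left = pa.table({"id": [1, 2], "x": [0.1, 0.2]})
+# right = pa.table({"id": [2], "y": ["b"]})
+# _arrow_join(left, right, keys="id") # "y" is null for id == 1
+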
+
+def init_arrow_buffer_and_writer(
+ cache_file_name,
+ fingerprint=None,
+ features=None,
+ writer_batch_size=None,
+ keep_in_memory=False,
+ disable_nullable=False,
+):
+ if isinstance(cache_file_name, Path):
+ cache_file_name = cache_file_name.resolve().as_posix()
+ if is_datasets_available():
+ from datasets import arrow_writer
+
+ # Prepare output buffer and batched writer in memory or on file if we update the table
+ if keep_in_memory or cache_file_name is None:
+ buf_writer = pa.BufferOutputStream()
+ tmp_file = None
+ writer = arrow_writer.ArrowWriter(
+ features=features,
+ stream=buf_writer,
+ writer_batch_size=writer_batch_size,
+ update_features=False,
+ fingerprint=fingerprint,
+ disable_nullable=disable_nullable,
+ )
+ else:
+ buf_writer = None
+ tmp_file = tempfile.NamedTemporaryFile(
+ "wb", dir=os.path.dirname(cache_file_name), delete=False
+ )
+ writer = arrow_writer.ArrowWriter(
+ features=features,
+ path=tmp_file.name,
+ writer_batch_size=writer_batch_size,
+ update_features=False,
+ fingerprint=fingerprint,
+ disable_nullable=disable_nullable,
+ )
+ else:
+ buf_writer = None
+ tmp_file = tempfile.NamedTemporaryFile(
+ "wb", dir=os.path.dirname(cache_file_name), delete=False
+ )
+ if features is not None and isinstance(features, dict):
+ features = pa.schema([(k, string_to_arrow(v)) for k, v in features.items()])
+ if features is not None and not isinstance(features, pa.Schema):
+ raise ValueError(
+ "features must be a dict or a pa.Schema when datasets is not installed"
+ )
+ writer = pa.ipc.RecordBatchStreamWriter(tmp_file, schema=features)
+
+ def finalize():
+ writer.close()
+
+ writer.finalize = finalize
+ return buf_writer, writer, tmp_file
+
+
+def write_arrow_table(
+ table: pa.Table,
+ cache_file_name: Union[str, Path],
+ fingerprint=None,
+ features=None,
+ writer_batch_size=None,
+ disable_nullable=False,
+):
+ # requires_backends("write_arrow_table", "datasets")
+ from datasets import Features
+
+ if features is None:
+ if isinstance(table, (pa.ChunkedArray, pa.Array)):
+ data_type = table.type
+ features = _arrow_to_datasets_dtype(data_type)
+ features = Features.from_arrow_schema(table.schema)
+ tmp_file = None
+
+ try:
+ _, writer, tmp_file = init_arrow_buffer_and_writer(
+ cache_file_name,
+ fingerprint=fingerprint,
+ features=features,
+ writer_batch_size=writer_batch_size,
+ disable_nullable=disable_nullable,
+ )
+ writer.write_table(table)
+ writer.finalize()
+ except:
+ if tmp_file and os.path.exists(tmp_file.name):
+ tmp_file.close()
+ os.remove(tmp_file.name)
+ raise
+ return move_temp_file(tmp_file, cache_file_name)
+
+
+def python_to_r(value):
+ """Converts a Python primitive type into its R equivalent."""
+
+ if value is None:
+ return "NULL"
+
+ elif isinstance(value, bool):
+ return "TRUE" if value else "FALSE"
+
+ elif isinstance(value, (int, float)):
+ return str(value)
+
+ elif isinstance(value, complex):
+ return f"complex(real = {value.real}, imaginary = {value.imag})"
+
+ elif isinstance(value, str):
+ return f'"{value}"' # Add quotes around strings for R
+
+ elif isinstance(value, list) or isinstance(value, tuple):
+ # Recursively convert elements of the list/tuple to R format
+ converted_elements = ", ".join(python_to_r(elem) for elem in value)
+ return f"c({converted_elements})"
+
+ elif isinstance(value, dict):
+ # Convert dict to named list in R
+ converted_items = [
+ f"{python_to_r(k)} = {python_to_r(v)}" for k, v in value.items()
+ ]
+ return f"list({', '.join(converted_items)})"
+
+ elif isinstance(value, set):
+ # Convert set to unique vector in R
+ converted_elements = ", ".join(python_to_r(elem) for elem in sorted(value))
+ return f"unique(c({converted_elements}))"
+
+ elif isinstance(value, range):
+ # Convert range to sequence in R
+ return f"seq({value.start}, {value.stop - 1}, by = {value.step})"
+
+ elif isinstance(value, bytes):
+ # Handle bytes conversion to raw in R (if needed, otherwise leave as unsupported)
+ return f"as.raw(c({', '.join(hex(b) for b in value)}))"
+
+ else:
+ raise TypeError(f"Type {type(value)} is not supported.")
+
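+# Example (a quick check of the conversion rules above):
+#
+# python_to_r({"a": [1, 2], "ok": True}) # -> 'list("a" = c(1, 2), "ok" = TRUE)'
+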
+
+def debug_r_script(path):
+ path = Path(path).expanduser().resolve()
+ path.mkdir(parents=True, exist_ok=True)
+ path = path.as_posix()
+
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **func_kwargs):
+ name = func.__name__
+ if is_rpy2_arrow_available():
+ from rpy2.robjects import ListVector, default_converter
+ from rpy2_arrow.arrow import converter
+
+ _converter = default_converter + converter
+ elif is_rpy2_available():
+ from rpy2.robjects import ListVector, conversion
+
+ _converter = (
+ conversion.get_conversion()
+ if getattr(conversion, "get_conversion", None)
+ else conversion.converter
+ )
+ else:
+ # suggest installing rpy2_arrow if rpy2 is not available
+ requires_backends(name, "rpy2_arrow")
+
+ import rpy2.rinterface
+ import rpy2.robjects as ro
+ from rpy2.rinterface import Sexp
+
+ self = args[0]
+ if hasattr(self, "plotter"):
+ method_args = [
+ p.name
+ for p in inspect.signature(self.plotter).parameters.values()
+ if p != p.VAR_KEYWORD
+ ]
+ else:
+ method_args = [
+ p.name
+ for p in inspect.signature(func).parameters.values()
+ if p != p.VAR_KEYWORD
+ ][1:]
+
+ kwargs = copy.deepcopy(func_kwargs)
+ for i, arg in enumerate(args[1:]):
+ kwargs[method_args[i]] = arg
+
+ kwargs.update(self.config.get_params())
+
+ def convert_to_r(arg):
+ if isinstance(arg, Sexp):
+ return arg
+ elif arg is None:
+ return ro.NULL
+ elif isinstance(arg, (list, tuple)):
+ return _converter.py2rpy([convert_to_r(a) for a in arg])
+ elif isinstance(arg, dict):
+ return ListVector(arg)
+ else:
+ return _converter.py2rpy(arg)
+
+ debug_script = "options(error = traceback)\n"
+ debug_script += (
+ f'R_SCRIPTS_PATH <- "{config.R_SCRIPTS.resolve().as_posix()}"\n'
+ )
+ debug_script += 'source(file.path(R_SCRIPTS_PATH, "plotting_utils.R"))\n'
+ debug_script += 'require("arrow")\n'
+ tmp_dir = tempfile.gettempdir()
+ with rpy2.rinterface.local_context() as r_context:
+ save_vars = []
+ for k, v in kwargs.items():
+ if isinstance(v, (pa.Array, pa.Table, pa.ChunkedArray)):
+ cache_file_name = f"{tmp_dir}/{k}.arrow"
+ write_arrow_table(v, cache_file_name)
+ debug_script += (
+ f"{k} <- as.data.frame(arrow::read_ipc_stream("
+ f"'{cache_file_name}', as_data_frame = FALSE))\n"
+ )
+ else:
+ try:
+ debug_script += f"{k} <- {python_to_r(v)}\n"
+ except TypeError:
+ save_vars.append(k)
+ r_context[k] = convert_to_r(v)
+ if save_vars:
+ code = ", ".join(r_context.keys())
+ code = f"save({code}, file = '{tmp_dir}/data.RData')"
+ ro.r(code)
+ debug_script = debug_script + f"load('{tmp_dir}/data.RData')\n"
+ script = self.r_caller.r_code.split("\n")
+ first_start = False
+ second_start = False
+ for line in script:
+ if first_start and second_start:
+ debug_script += f"{line[2:]}\n"
+ if line.startswith(self.config.main_method):
+ first_start = True
+ if first_start and line.endswith("{"):
+ second_start = True
+ elif line.startswith("}") and first_start and second_start:
+ break
+ with open(f"{path}/debug.R", "w") as f:
+ f.write(debug_script)
+ return func(*args, **func_kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+def read_arrow_table(
+ cache_file_name: Union[str, Path],
+):
+ return pa.ipc.open_stream(cache_file_name).read_all()
diff --git a/src/biofit/utils/types.py b/src/biofit/utils/types.py
new file mode 100644
index 0000000..8ba7f9b
--- /dev/null
+++ b/src/biofit/utils/types.py
@@ -0,0 +1,27 @@
+class Unset:
+ """A class to represent an unset value.
+
+ This is used to differentiate between a value that is not set and a value that is set to None.
+ Optionally, a description can be provided to indicate what the actual default is when Unset is used.
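+
+ Example:
+
+ ```py
+ >>> alpha = Unset("1.0")
+ >>> bool(alpha)
+ False
+ >>> alpha
+ Unset (default: 1.0)
+ ```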
+ """
+
+ def __init__(self, description: str = ""):
+ self.description = description
+
+ def __repr__(self) -> str:
+ """Return the string representation of the class, including any default description if provided.
+
+ Returns:
+ The string representation of the class.
+ """
+ if self.description:
+ return f"Unset (default: {self.description})"
+ return "Unset"
+
+ def __bool__(self) -> bool:
+ """Return False when the class is used in a boolean context.
+
+ Returns:
+ False
+ """
+ return False
diff --git a/src/biofit/utils/version.py b/src/biofit/utils/version.py
new file mode 100644
index 0000000..fbd2e38
--- /dev/null
+++ b/src/biofit/utils/version.py
@@ -0,0 +1,115 @@
+# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Modified by: Patrick Smyth
+# Date: 2024
+# Summary of changes:
+# - Added __version__ variable
+"""Version utils."""
+
+import dataclasses
+import re
+from dataclasses import dataclass
+from functools import total_ordering
+from typing import Optional, Union
+
+_VERSION_REG = re.compile(r"^(?P<major>\d+)" r"\.(?P<minor>\d+)" r"\.(?P<patch>\d+)$")
+
+
+@total_ordering
+@dataclass
+class Version:
+ """Dataset version `MAJOR.MINOR.PATCH`.
+
+ Args:
+ version_str (`str`):
+ The dataset version.
+ description (`str`):
+ A description of what is new in this version.
+        major (`str` or `int`, *optional*):
+            The major version, parsed from `version_str`.
+        minor (`str` or `int`, *optional*):
+            The minor version, parsed from `version_str`.
+        patch (`str` or `int`, *optional*):
+            The patch version, parsed from `version_str`.
+
+ Example:
+
+ ```py
+ >>> VERSION = datasets.Version("1.0.0")
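+    >>> # comparisons accept plain version strings (illustrative):
+    >>> VERSION < "2.0.0"
+    True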
+ ```
+ """
+
+ version_str: str
+ description: Optional[str] = None
+ major: Optional[Union[str, int]] = None
+ minor: Optional[Union[str, int]] = None
+ patch: Optional[Union[str, int]] = None
+
+ def __post_init__(self):
+ self.major, self.minor, self.patch = _str_to_version_tuple(self.version_str)
+
+ def __repr__(self):
+ return f"{self.tuple[0]}.{self.tuple[1]}.{self.tuple[2]}"
+
+ @property
+ def tuple(self):
+ return self.major, self.minor, self.patch
+
+ def _validate_operand(self, other):
+ if isinstance(other, str):
+ return Version(other)
+ elif isinstance(other, Version):
+ return other
+ raise TypeError(f"{other} (type {type(other)}) cannot be compared to version.")
+
+ def __eq__(self, other):
+ try:
+ other = self._validate_operand(other)
+ except (TypeError, ValueError):
+ return False
+ else:
+ return self.tuple == other.tuple
+
+ def __lt__(self, other):
+ other = self._validate_operand(other)
+ return self.tuple < other.tuple
+
+ def __hash__(self):
+ return hash(_version_tuple_to_str(self.tuple))
+
+ @classmethod
+ def from_dict(cls, dic):
+ field_names = {f.name for f in dataclasses.fields(cls)}
+ return cls(**{k: v for k, v in dic.items() if k in field_names})
+
+ def _to_yaml_string(self) -> str:
+ return self.version_str
+
+
+def _str_to_version_tuple(version_str):
+ """Return the tuple (major, minor, patch) version extracted from the str."""
+ res = _VERSION_REG.match(version_str)
+ if not res:
+ raise ValueError(
+ f"Invalid version '{version_str}'. Format should be x.y.z with {{x,y,z}} being digits."
+ )
+ return tuple(
+ int(v) for v in [res.group("major"), res.group("minor"), res.group("patch")]
+ )
+
+
+def _version_tuple_to_str(version_tuple):
+ """Return the str version from the version tuple (major, minor, patch)."""
+ return ".".join(str(v) for v in version_tuple)
+
+
+__version__ = "0.0.0"
diff --git a/src/biofit/visualization/__init__.py b/src/biofit/visualization/__init__.py
new file mode 100644
index 0000000..3248c7a
--- /dev/null
+++ b/src/biofit/visualization/__init__.py
@@ -0,0 +1,22 @@
+# ruff: noqa
+from .feature_importance import (
+ FeatureImportancePlotter,
+ FeatureImportancePlotterConfig,
+ FeatureImportancePlotterConfigForOTU,
+)
+from .sample_metadata import SampleMetadataPlotter, SampleMetadataPlotterConfig
+from .plotting_utils import (
+ generate_violin,
+ generate_barplot,
+ generate_scatterplot,
+ generate_histogram,
+ generate_comparison_histogram,
+ plot_correlation,
+ plot_feature_distribution,
+ compare_feature_distributions,
+ plot_sample_distribution,
+ compare_sample_distributions,
+ plot_dimension_reduction,
+ plot_feature_importance,
+ plot_sample_metadata,
+)
diff --git a/src/biofit/visualization/barplot.py b/src/biofit/visualization/barplot.py
new file mode 100644
index 0000000..0b651c5
--- /dev/null
+++ b/src/biofit/visualization/barplot.py
@@ -0,0 +1,198 @@
+import textwrap
+from dataclasses import dataclass, field
+from typing import List, Optional, Type
+
+from biocore import DataHandler
+
+import biofit.config as config
+from biofit.integration.biosets import get_feature
+from biofit.integration.R import RCaller
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.utils.types import Unset
+from biofit.visualization.plotting import BasePlotter, PlotterConfig
+
+
+@dataclass
+class BarPlotConfig(PlotterConfig):
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES"), None],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES"), None],
+ init=False,
+ repr=False,
+ )
+
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None, None],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None, None],
+ init=False,
+ repr=False,
+ )
+
+ processor_type: str = field(default="scaling", init=False, repr=False)
+ r_source: str = field(
+ default=(config.R_SCRIPTS / "plotting_utils.R").as_posix(),
+ init=False,
+ repr=False,
+ )
+ main_method: str = field(default="generate_barplot", init=False, repr=False)
+
+ label_name: str = None
+ value_name: Optional[str] = None
+ groupby: Optional[str] = None
+ xlab: Optional[str] = None
+ ylab: Optional[str] = None
+ title: str = "Bar Plot"
+ col_set: str = "Set1"
+ col_outline: str = "grey30"
+ col_labels: str = "black"
+ cols: Optional[str] = None
+ prop: bool = False
+ add_count_lab: bool = True
+ vars_as_entered: bool = False
+ legend_position: str = "top"
+ font_size: float = 3.25
+
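+# Illustrative usage (a sketch, not from the project docs; assumes a pandas
+# DataFrame `df` with an "abundance" value column and a "group" column):
+#
+#   plotter = BarPlotter(title="Mean abundance")
+#   plotter.plot(df, None, None, value_name="abundance", groupby="group")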
+
+class BarPlotter(BasePlotter):
+ _config_class = BarPlotConfig
+ config: BarPlotConfig
+
+ def __init__(
+ self,
+ xlab: Optional[str] = None,
+ ylab: Optional[str] = None,
+ title: str = Unset('"Bar Plot"'),
+ col_set: str = Unset('"Set1"'),
+ col_labels: str = Unset('"black"'),
+ col_outline: str = Unset('"grey30"'),
+ cols: Optional[List[str]] = Unset("None"),
+ prop: bool = Unset("False"),
+ add_count_lab: bool = Unset("True"),
+ vars_as_entered: bool = Unset("False"),
+ legend_position: str = Unset('"top"'),
+ font_size: float = Unset("3.25"),
+ path=None,
+ config: Optional[BarPlotConfig] = None,
+ ):
+ super().__init__(
+ config=config,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ col_set=col_set,
+            col_labels=col_labels,
+            col_outline=col_outline,
+ cols=cols,
+ prop=prop,
+ add_count_lab=add_count_lab,
+ vars_as_entered=vars_as_entered,
+ legend_position=legend_position,
+ font_size=font_size,
+ path=path,
+ )
+
+ self = self.__post_init__()
+
+ def __post_init__(self):
+ if self.config.r_source and self.config.main_method:
+ self.r_caller = RCaller.from_script(self.config.r_source)
+ self.r_caller.verify_r_dependencies()
+ enter_code = textwrap.dedent(
+ """
+ suppressPackageStartupMessages(require(ggplot2))
+ """
+ )
+ exit_code = textwrap.dedent(
+ """
+ ggplot2::ggsave(path, results)
+ """
+ )
+ self.plotter = self.r_caller.get_method(
+ self.config.main_method, enter_code, exit_code
+ )
+
+ return self
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ self = self.__post_init__()
+ return self
+
+ def plot(
+ self,
+ x,
+ y,
+ group,
+ label_name: str = None,
+ value_name: Optional[str] = None,
+ groupby: Optional[str] = None,
+ xlab: Optional[str] = Unset("None"),
+ ylab: Optional[str] = Unset("None"),
+ title: str = Unset('"Bar Plot"'),
+ col_set: str = Unset('"Set1"'),
+ col_labels: str = Unset('"black"'),
+ col_outline: str = Unset('"grey30"'),
+ cols: Optional[List[str]] = Unset("None"),
+ prop: bool = Unset("False"),
+ add_count_lab: bool = Unset("True"),
+ vars_as_entered: bool = Unset("False"),
+ legend_position: str = Unset('"top"'),
+ font_size: float = Unset("3.25"),
+ path: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show: bool = True,
+ ):
+ self.config._input_columns = self._set_input_columns_and_arity(
+ value_name, label_name, groupby
+ )
+ return self._plot(
+ x,
+ y,
+ group,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ col_set=col_set,
+ col_labels=col_labels,
+ col_outline=col_outline,
+ cols=cols,
+ prop=prop,
+ add_count_lab=add_count_lab,
+ vars_as_entered=vars_as_entered,
+ legend_position=legend_position,
+ font_size=font_size,
+ path=path,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ show=show,
+ )
+
+ def plot_dataset(self, x, y, group):
+        from biosets import Dataset, decode
+        from datasets import Dataset as HfDataset
+
+ if isinstance(x, (Dataset, HfDataset)):
+ x = decode(x)
+ if isinstance(group, (Dataset, HfDataset)):
+ group = decode(group)
+ return self.plot_arrow(
+ DataHandler.to_arrow(x),
+ DataHandler.to_arrow(y) if y is not None else None,
+ DataHandler.to_arrow(group) if group is not None else None,
+ )
+
+ def plot_arrow(self, x, y=None, group=None):
+ kwargs = self.config.get_params()
+ self.plotter(x, y, group, **kwargs)
diff --git a/src/biofit/visualization/dimension_reduction.py b/src/biofit/visualization/dimension_reduction.py
new file mode 100644
index 0000000..4ff8f02
--- /dev/null
+++ b/src/biofit/visualization/dimension_reduction.py
@@ -0,0 +1,435 @@
+import json
+import warnings
+from dataclasses import dataclass, field
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+from biocore import DataHandler
+from biocore.utils.import_util import requires_backends
+
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, SelectedFeatureTypes
+from biofit.utils.types import Unset
+from biofit.visualization.plotting import BasePlotter, PlotterConfig
+from biofit.visualization.plotting_utils import get_distinct_colors
+
+
+@dataclass
+class DimensionReducerPlotterConfig(PlotterConfig):
+ processor_type: str = field(default="feature_extractor", init=False, repr=False)
+ _fit_input_feature_types: SelectedFeatureTypes = field(
+ default_factory=lambda: [
+ None,
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("METADATA_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: SelectedFeatureTypes = field(
+ default_factory=lambda: [
+ None,
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("METADATA_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _fit_unused_feature_types: SelectedFeatureTypes = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None, None],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: SelectedFeatureTypes = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None, None],
+ init=False,
+ repr=False,
+ )
+
+ title: str = "Dimension Reduction Plot"
+ colormap: str = "nipy_spectral"
+ n_components: int = 3
+ label_column: str = None
+ group_column: str = None
+
+
+class DimensionReductionPlotter(BasePlotter):
+ """Base class for feature extraction processors."""
+
+ _config_class = DimensionReducerPlotterConfig
+ config: DimensionReducerPlotterConfig
+
+ def __init__(
+ self,
+ label_column: str = None,
+ group_column: str = None,
+ title: str = Unset('"Dimension Reduction Plot"'),
+ colormap: str = Unset('"nipy_spectral"'),
+ n_components: int = Unset("3"),
+ config: DimensionReducerPlotterConfig = None,
+ **kwargs,
+ ):
+ super().__init__(
+ config=config,
+ n_components=n_components,
+ label_column=label_column,
+ group_column=group_column,
+ title=title,
+ colormap=colormap,
+ **kwargs,
+ )
+
+ def plot(
+ self,
+ X,
+ labels=None,
+ group=None,
+ input_columns: SelectedColumnTypes = None,
+ label_column: SelectedColumnTypes = None,
+ group_column: SelectedColumnTypes = None,
+ title: str = Unset('"Dimension Reduction Plot"'),
+ colormap: str = Unset('"nipy_spectral"'),
+ n_components: int = Unset("3"),
+ dimension_reducer=None,
+ path: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show=True,
+ ):
+ """Plot the PCoA plot.
+
+ Args:
+ X (array_like):
+ The input data. Can be a Dataset, polars/pandas DataFrame, numpy array, or Arrow table.
+ labels (array_like, *optional*):
+ The labels for the data. Dictates the color of the points in the plot.
+ Must be an array of values with the same length as the number of rows in X.
+ If not provided, the points will be colored by group.
+ group (array_like, *optional*):
+ The group for the data. Dictates the shape of the points in the plot.
+ Must be an array of values with the same length as the number of rows in X.
+ pcoa (array_like, *optional*):
+ The fitted PCoAFeatureExtractor object. Used to extract the eigvals for the explained variance ratio.
+ If not provided, the explained variance ratio will not be displayed.
+ **kwargs:
+ Additional keyword arguments to pass to the plot:
+ n_components (int, 2):
+ The number of components to plot. Defaults to 2.
+ label_name (str, *optional*):
+ The name of the labels. Used for the legend and retrieving the values from X if labels is not provided.
+ group_name (str, *optional*):
+ The name of the group. Used for the legend and retrieving the values from X if group is not provided.
+
+ """
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns, label_column, group_column
+ )
+ return self._plot(
+ X,
+ labels,
+ group,
+            n_components=n_components,
+            title=title,
+            colormap=colormap,
+ dimension_reducer=dimension_reducer,
+ path=path,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ show=show,
+ )
+
+ def plot_dataset(self, X, labels=None, group=None, dimension_reducer=None):
+ from biosets import decode
+
+ if labels is not None:
+ labels = decode(labels)
+
+ return self.plot_arrow(
+ DataHandler.to_arrow(X),
+ DataHandler.to_arrow(labels) if labels is not None else None,
+ DataHandler.to_arrow(group) if group is not None else None,
+ dimension_reducer,
+ )
+
+ def plot_arrow(self, X, labels=None, group=None, dimension_reducer=None):
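+        # Build a converter that maps integer label encodings back to their
+        # string names using the HuggingFace ClassLabel metadata stored on the
+        # Arrow schema, falling back to the identity function when absent.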
+ def get_int2str_converter(x) -> np.ndarray:
+ if not isinstance(x, pa.Table):
+ return lambda x: x
+
+ def get_features(x: pa.Table):
+ metadata = x.schema.metadata
+ if metadata:
+ if b"huggingface" in metadata:
+ metadata = json.loads(metadata[b"huggingface"].decode("utf-8"))
+
+ if "info" in metadata and "features" in metadata["info"]:
+ return metadata["info"]["features"]
+
+ return {}
+
+ metadata = get_features(x)
+ lab_col = x.column_names[0]
+ if metadata and lab_col in metadata:
+ label_metadata = metadata[lab_col]
+ if label_metadata["_type"] == "ClassLabel":
+ label_metadata.pop("_type")
+ cls_label = get_feature("ClassLabel")(**label_metadata)
+ return cls_label.int2str
+ return lambda x: x
+
+ def pairplot(
+ tbl: pd.DataFrame,
+ n_components,
+ ex_var_ratio=None,
+ label_name=None,
+ group_name=None,
+ marker_dict=None,
+ color_dict=None,
+ ):
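+            # n_components x n_components grid: pairwise scatter plots off the
+            # diagonal, per-label KDE density curves on the diagonal.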
+ requires_backends(pairplot, "seaborn")
+ requires_backends(pairplot, "matplotlib")
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from matplotlib.lines import Line2D
+
+ features = tbl.columns[:n_components]
+ fig, axes = plt.subplots(n_components, n_components, figsize=(15, 15))
+
+ for i, f1 in enumerate(features):
+ for j, f2 in enumerate(features):
+ ax = axes[i, j]
+ if i != j:
+ if label_name is not None and group_name is not None:
+ for (lab, grp), df_group in tbl.groupby(
+ [
+ label_name,
+ group_name,
+ ]
+ ):
+ sns.scatterplot(
+ x=f2,
+ y=f1,
+ data=df_group,
+ ax=ax,
+ color=color_dict[lab],
+ marker=marker_dict[grp],
+ )
+ elif label_name is not None:
+ for lab, df_lab in tbl.groupby(label_name):
+ sns.scatterplot(
+ x=f2,
+ y=f1,
+ data=df_lab,
+ ax=ax,
+ color=color_dict[lab],
+ )
+ elif group_name is not None:
+ for grp, df_group in tbl.groupby(group_name):
+ sns.scatterplot(
+ x=f2,
+ y=f1,
+ data=df_group,
+ ax=ax,
+ marker=marker_dict[grp],
+ )
+ else:
+ sns.scatterplot(x=f2, y=f1, data=tbl, ax=ax)
+
+ else:
+ if label_name is not None:
+ for lab, df_lab in tbl.groupby(label_name):
+ sns.kdeplot(
+ df_lab[f1],
+ ax=ax,
+ color=color_dict[lab],
+ fill=True,
+ )
+ else:
+ sns.kdeplot(tbl[f1], ax=ax, fill=True)
+ if ex_var_ratio is not None:
+ plt.text(
+ 0.1,
+ 0.9,
+ f"Dim{i + 1}: {float(ex_var_ratio[i]) * 100:.2f}%",
+ transform=ax.transAxes,
+ )
+ # Set only left and bottom borders
+ ax.spines["top"].set_visible(False)
+ ax.spines["right"].set_visible(False)
+
+ # Enable ticks on left and bottom only
+ ax.xaxis.set_tick_params(which="both", bottom=True)
+ ax.yaxis.set_tick_params(which="both", left=True)
+
+ # Set axis labels
+ if i == n_components - 1:
+ ax.set_xlabel(f2)
+ else:
+ ax.set_xlabel("")
+ ax.set_xticklabels([])
+
+ if j == 0:
+ ax.set_ylabel(f1)
+ else:
+ ax.set_ylabel("")
+ ax.set_yticklabels([])
+
+ if label_name:
+ label_lines = [
+ Line2D(
+ [0],
+ [0],
+ color=color,
+ marker="o",
+ linestyle="",
+ markersize=10,
+ )
+ for color in color_dict.values()
+ ]
+ u_labels = list(color_dict.keys())
+ n_labels = len(u_labels)
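+                # Nudge the legend above the figure centre; 0.216/15 appears to
+                # be the per-entry legend height in figure coordinates.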
+ label_height = 0.5 + 0.216 / 15 * (n_labels + 1) / 2 + 0.01
+ label_legend = fig.legend(
+ label_lines,
+ u_labels,
+ title=label_name,
+ loc="center left",
+ bbox_to_anchor=(0.90, label_height),
+ )
+
+ if group_name:
+ group_lines = [
+ Line2D(
+ [0],
+ [0],
+ color="gray",
+ marker=marker,
+ linestyle="",
+ markersize=10,
+ )
+ for marker in marker_dict.values()
+ ]
+ u_group = list(marker_dict.keys())
+ n_group = len(u_group)
+
+ group_height = 0.5 - 0.216 / 15 * (n_group + 1) / 2 - 0.01
+ fig.legend(
+ group_lines,
+ list(marker_dict.keys()),
+ title=group_name,
+ loc="center left",
+ bbox_to_anchor=(0.90, group_height),
+ )
+ if label_name:
+ fig.add_artist(label_legend)
+
+ if self.config.title:
+ fig.suptitle(self.config.title, fontsize=16)
+
+ return fig
+
+ warnings.filterwarnings("ignore")
+ n_components = self.config.n_components
+ tbl = DataHandler.to_format(X, "pandas").iloc[:, :n_components]
+ labs = None
+ label_name = self.config.label_column or "labels"
+ u_labels = None
+ grp = None
+ grp_name = self.config.group_column or None
+ u_group = None
+ if labels is not None:
+ labs = DataHandler.to_format(labels, "pandas")
+ if isinstance(labs, pd.DataFrame):
+ label_name = labs.columns[0] if label_name is None else label_name
+ labs = labs.iloc[:, 0]
+ elif isinstance(labs, pd.Series):
+ label_name = labs.name if label_name is None else label_name
+ labs = labs.fillna("None")
+ u_labels = np.unique(labs)
+ if group is not None:
+ grp = DataHandler.to_format(group, "pandas")
+ if isinstance(grp, pd.DataFrame):
+ grp_name = grp.columns[0] if grp_name is None else grp_name
+ grp = grp.iloc[:, 0]
+ elif isinstance(grp, pd.Series):
+ grp_name = grp.name if grp_name is None else grp_name
+ u_group = np.unique(grp)
+
+ if labs is not None:
+ converter = get_int2str_converter(labels)
+ dtype = DataHandler.get_dtypes(labels)
+ if "int" in next(iter(dtype.values())):
+ encodings = labs.to_numpy()
+ encodings_dims = encodings.shape
+ if len(encodings_dims) == 1:
+ labs = converter(encodings)
+ elif encodings_dims[1] == 1:
+ labs = converter(encodings[:, 0])
+
+ if labs is not None:
+ u_labels = np.unique(labs)
+
+ ex_var_ratio = None
+ if dimension_reducer is not None:
+ if hasattr(dimension_reducer.config, "eigvals"):
+ eigvals = dimension_reducer.config.eigvals
+ ex_var_ratio = eigvals / eigvals.sum()
+ elif hasattr(dimension_reducer.config, "estimator"):
+ ex_var_ratio = (
+ dimension_reducer.config.estimator.explained_variance_ratio_
+ )
+
+ tbl.columns = [f"Dim{i + 1}" for i in range(len(tbl.columns))]
+
+ marker_styles = ["o", "v", "^", "<", ">", "s", "p", "*", "H", "X"]
+ if u_group is not None:
+ marker_dict = {
+ label: marker_styles[i % len(marker_styles)]
+ for i, label in enumerate(u_group)
+ }
+ else:
+ marker_dict = None
+
+ if labs is not None and grp is not None:
+ tbl = tbl.assign(**{label_name: labs, grp_name: grp})
+ elif labs is not None:
+ tbl = tbl.assign(**{label_name: labs})
+ elif grp is not None:
+ tbl = tbl.assign(**{grp_name: grp})
+
+ if labs is not None:
+ color_dict = {
+ label: color
+ for label, color in zip(
+ u_labels, get_distinct_colors(len(u_labels), self.config.colormap)
+ )
+ }
+ elif grp is not None:
+ color_dict = {
+ label: color
+ for label, color in zip(
+ u_group, get_distinct_colors(len(u_group), self.config.colormap)
+ )
+ }
+ else:
+ color_dict = None
+
+ fig = pairplot(
+ tbl,
+ n_components,
+ ex_var_ratio,
+ label_name,
+ grp_name,
+ marker_dict,
+ color_dict,
+ )
+
+ warnings.filterwarnings("default")
+
+ # save the figure to the path
+ if self.config.path:
+ fig.savefig(self.config.path)
+ return fig
diff --git a/src/biofit/visualization/feature_importance.py b/src/biofit/visualization/feature_importance.py
new file mode 100644
index 0000000..ef3bf78
--- /dev/null
+++ b/src/biofit/visualization/feature_importance.py
@@ -0,0 +1,355 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional, Type
+
+import numpy as np
+import pandas as pd
+from biocore import DataHandler
+
+from biofit.integration.biosets import get_feature
+from biofit.integration.R.r_caller import RCaller
+from biofit.processing import SelectedColumnTypes
+from biofit.utils.types import Unset
+from biofit.visualization.plotting import (
+ BasePlotter,
+ PlotterConfig,
+)
+
+if TYPE_CHECKING:
+ from biosets import Dataset
+
+
+@dataclass
+class FeatureImportancePlotterConfig(PlotterConfig):
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ None,
+ None,
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("METADATA_FEATURE_TYPES_NOT_TARGET"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ None,
+ None,
+ get_feature("TARGET_FEATURE_TYPES"),
+ get_feature("METADATA_FEATURE_TYPES_NOT_TARGET"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("METADATA_FEATURE_TYPES"),
+ None,
+ None,
+ None,
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("METADATA_FEATURE_TYPES"),
+ None,
+ None,
+ None,
+ ],
+ init=False,
+ repr=False,
+ )
+ r_source: str = field(
+ default=(Path(__file__).parent / "plot_feature_importance.R")
+ .resolve()
+ .as_posix(),
+ init=False,
+ repr=False,
+ )
+ main_method: str = field(default="plot_feature_importance", init=False, repr=False)
+
+ input_columns: SelectedColumnTypes = None
+ target_columns: SelectedColumnTypes = None
+ sample_metadata_columns: SelectedColumnTypes = None
+ sample_column: str = None
+ plot_top: int = 15
+ feature_meta_name: str = None
+ feature_column: str = None
+ cols: List[str] = None
+ colHeat: List[str] = None
+ dat_log: str = field(default=None, init=True, repr=False)
+ show_column_names: bool = False
+ scale_legend_title: str = "Value"
+ column_title: str = "Samples"
+ row_title: str = "Features"
+ plot_title: str = "Values"
+
+ def __post_init__(self):
+        if self.dat_log in ("log2_1p", "log2"):
+            self.plot_title = "log2\n" + self.plot_title
+        elif self.dat_log in ("log10_1p", "log10"):
+            self.plot_title = "log10\n" + self.plot_title
+
+
+@dataclass
+class FeatureImportancePlotterConfigForMetagenomics(FeatureImportancePlotterConfig):
+ dat_log: str = "log2_1p"
+ plot_title: str = "Abundance"
+ row_title: str = "Taxa"
+
+
+@dataclass
+class FeatureImportancePlotterConfigForOTU(
+ FeatureImportancePlotterConfigForMetagenomics
+):
+ row_title: str = "OTUs"
+
+
+class FeatureImportancePlotter(BasePlotter):
+ _config_class = FeatureImportancePlotterConfig
+ config: FeatureImportancePlotterConfig
+
+ def __init__(
+ self,
+ input_columns: SelectedColumnTypes = None,
+ target_columns: SelectedColumnTypes = None,
+ sample_metadata_columns: SelectedColumnTypes = None,
+ sample_column: str = None,
+ plot_top: int = Unset("15"),
+ feature_meta_name: str = Unset("None"),
+ feature_column: str = Unset("None"),
+ cols: List[str] = Unset("None"),
+ colHeat: List[str] = Unset("None"),
+ dat_log: str = Unset("field(default=None, init=True, repr=False)"),
+ show_column_names: bool = Unset("False"),
+ scale_legend_title: str = Unset('"Value"'),
+ column_title: str = Unset('"Samples"'),
+ row_title: str = Unset('"Features"'),
+ plot_title: str = Unset('"Values"'),
+ path: Optional[str] = None,
+ install_missing: bool = None,
+ config: Optional[FeatureImportancePlotterConfig] = None,
+ ):
+        super().__init__(
+            input_columns=input_columns,
+            target_columns=target_columns,
+            sample_metadata_columns=sample_metadata_columns,
+            sample_column=sample_column,
+            path=path,
+            plot_top=plot_top,
+ feature_meta_name=feature_meta_name,
+ feature_column=feature_column,
+ cols=cols,
+ colHeat=colHeat,
+ dat_log=dat_log,
+ show_column_names=show_column_names,
+ scale_legend_title=scale_legend_title,
+ column_title=column_title,
+ row_title=row_title,
+ plot_title=plot_title,
+ config=config,
+ install_missing=install_missing,
+ )
+ self.r_caller = RCaller.from_script(self.config.r_source)
+ self.r_caller.verify_r_dependencies(
+ bioconductor_dependencies=["ComplexHeatmap"],
+ install_missing=install_missing,
+ )
+ self.plotter = self.r_caller.get_method(self.config.main_method)
+
+ def plot_dataset(
+ self,
+ X: "Dataset",
+ feature_importances: "Dataset",
+ y: "Dataset",
+ sample_metadata: "Dataset" = None,
+ feature_metadata: dict = None,
+ ):
+ from biosets import decode, get_feature_metadata, get_sample_col_name
+
+ if feature_metadata is None:
+ feature_metadata = get_feature_metadata(X)
+
+ if self.config.sample_column is None:
+ self.config.sample_column = (
+ get_sample_col_name(sample_metadata)
+ if sample_metadata is not None
+ else get_sample_col_name(X)
+ )
+
+ self.plot_pandas(
+ X=DataHandler.to_pandas(X),
+ feature_importances=DataHandler.to_pandas(feature_importances),
+ y=DataHandler.to_pandas(decode(y)) if y is not None else None,
+ sample_metadata=DataHandler.to_pandas(sample_metadata)
+ if sample_metadata is not None
+ else None,
+ feature_metadata=feature_metadata,
+ )
+
+ def plot_pandas(
+ self,
+ X: pd.DataFrame,
+ feature_importances: pd.DataFrame,
+ y: pd.DataFrame,
+ sample_metadata: pd.DataFrame = None,
+ feature_metadata: dict = None,
+ ):
+ if self.config.dat_log == "log2_1p":
+ X = np.log2(X + 1)
+ elif self.config.dat_log == "log2":
+ X = np.log2(X)
+ elif self.config.dat_log == "log10_1p":
+ X = np.log10(X + 1)
+ elif self.config.dat_log == "log10":
+ X = np.log10(X)
+
+ if feature_importances is None:
+ raise ValueError("Please provide feature importances.")
+
+ self.config.feature_column = (
+ self.config.feature_column or feature_importances.columns[0]
+ )
+ if self.config.feature_column not in feature_importances.columns:
+ raise ValueError(
+ f"Feature column '{self.config.feature_column}' not found in feature "
+ "importances. Please provide the column name found in both "
+ "feature importances and feature metadata (if provided)."
+ )
+
+ feat_import_cols = [
+ c for c in feature_importances.columns if c != self.config.feature_column
+ ]
+
+        if isinstance(feature_metadata, dict) and len(feature_metadata) > 0:
+ feature_metadata = pd.DataFrame(
+ list(feature_metadata.values()), index=list(feature_metadata.keys())
+ )
+ feature_metadata = feature_metadata.reset_index(
+ names=[self.config.feature_column]
+ )
+
+ if len(feat_import_cols) > 1:
+ medians = feature_importances.loc[:, feat_import_cols].median(axis=1)
+ sorted_inds = np.argsort(np.abs(medians))[::-1]
+ feature_importances = feature_importances.iloc[sorted_inds, :].head(
+ self.config.plot_top
+ )
+ X = DataHandler.to_arrow(X, preserve_index=False)
+ feature_importances = DataHandler.to_arrow(
+ feature_importances, preserve_index=False
+ )
+ if y is not None:
+ y = DataHandler.to_arrow(y, preserve_index=False)
+ if feature_metadata is not None and len(feature_metadata) > 0:
+ feature_metadata = DataHandler.to_arrow(
+ feature_metadata, preserve_index=False
+ )
+ if sample_metadata is not None:
+ sample_metadata = DataHandler.to_arrow(
+ sample_metadata, preserve_index=False
+ )
+
+ self._plotter(
+ X,
+ y=y,
+ feature_importances=feature_importances,
+ sample_metadata=sample_metadata,
+ feature_metadata=feature_metadata,
+ )
+
+ def _plotter(
+ self,
+ X,
+ y,
+ feature_importances,
+ sample_metadata,
+ feature_metadata,
+ ):
+ params = self.config.get_params()
+ if feature_metadata is not None:
+ feature_metadata_columns = DataHandler.get_column_names(feature_metadata)
+ feature_meta_name = params.get("feature_meta_name")
+ if feature_meta_name is not None:
+ if isinstance(feature_meta_name, (str, int)):
+ feature_meta_name = [feature_meta_name]
+ missing_cols = set(feature_meta_name) - set(feature_metadata_columns)
+ if missing_cols:
+ raise ValueError(
+ f"Feature metadata columns {list(missing_cols)} not found in "
+ "feature metadata."
+ )
+
+ self.plotter(
+ X,
+ y=y,
+ feature_importances=feature_importances,
+ sample_metadata=sample_metadata,
+ feature_metadata=feature_metadata,
+ **params,
+ )
+
+ def plot(
+ self,
+ X,
+ feature_importances,
+ y=None,
+ sample_metadata=None,
+ feature_metadata: dict = None,
+ input_columns: SelectedColumnTypes = None,
+ target_columns: SelectedColumnTypes = None,
+ sample_metadata_columns: SelectedColumnTypes = None,
+ sample_column: str = None,
+ plot_top: int = Unset("15"),
+ feature_meta_name: str = Unset("None"),
+ feature_column: str = Unset("None"),
+ cols: List[str] = Unset("None"),
+ colHeat: List[str] = Unset("None"),
+ dat_log: str = Unset("field(default=None, init=True, repr=False)"),
+ show_column_names: bool = Unset("False"),
+ scale_legend_title: str = Unset('"Value"'),
+ column_title: str = Unset('"Samples"'),
+ row_title: str = Unset('"Features"'),
+ plot_title: str = Unset('"Values"'),
+ path: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show=True,
+ ):
+ # feature_importances are not selected by columns so its None
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns, None, target_columns, sample_metadata_columns
+ )
+ if (
+ feature_importances is not None
+ and DataHandler.get_shape(feature_importances)[0] == 0
+ ):
+ raise ValueError("Feature importances is empty.")
+        return self._plot(
+ X,
+ feature_importances,
+ y,
+ sample_metadata,
+ feature_metadata=feature_metadata,
+ plot_top=plot_top,
+ feature_meta_name=feature_meta_name,
+ sample_column=sample_column,
+ feature_column=feature_column,
+ cols=cols,
+ colHeat=colHeat,
+ dat_log=dat_log,
+ show_column_names=show_column_names,
+ scale_legend_title=scale_legend_title,
+ column_title=column_title,
+ row_title=row_title,
+ plot_title=plot_title,
+ show=show,
+ path=path,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ )
diff --git a/src/biofit/visualization/histogram.py b/src/biofit/visualization/histogram.py
new file mode 100644
index 0000000..1ee7cad
--- /dev/null
+++ b/src/biofit/visualization/histogram.py
@@ -0,0 +1,430 @@
+import textwrap
+from dataclasses import dataclass, field
+from typing import List, Optional, Type
+
+import numpy as np
+
+import biofit.config as config
+from biofit.integration.biosets import get_feature
+from biofit.integration.R import RCaller
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.utils.types import Unset
+from biofit.visualization.plotting import BasePlotter, PlotterConfig
+
+
+def prepare_data_for_hist(x1, x2=None):
+ # Calculate row sums and column sums for dataset 1
+ data1_sample = x1.sum(axis=1)
+ data1_feature = x1.sum(axis=0)
+
+ list_sums = [data1_sample, data1_feature]
+
+ # If dataset 2 is provided, calculate its row and column sums
+ if x2 is not None:
+ data2_sample = x2.sum(axis=1)
+ data2_feature = x2.sum(axis=0)
+ list_sums.extend([data2_sample, data2_feature])
+
+ return list_sums
+
+
+def non_zero_sums(x1, x2=None):
+ # Sum all non-zero values for dataset 1
+ x1_non_zero_row = (x1 != 0).sum(axis=1)
+ x1_non_zero_col = (x1 != 0).sum(axis=0)
+ sums = [x1_non_zero_row, x1_non_zero_col]
+
+ # If dataset 2 is provided, calculate its non-zero sums
+ if x2 is not None:
+ x2_non_zero_row = (x2 != 0).sum(axis=1)
+ x2_non_zero_col = (x2 != 0).sum(axis=0)
+ sums.extend([x2_non_zero_row, x2_non_zero_col])
+
+ return sums
+
+
+def prepare_axis_label(label, log_type):
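+    # e.g. prepare_axis_label("Reads", "log2_1p") -> "Reads (log2(x+1))"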
+ if "1p" in log_type:
+ if "_1p" in log_type:
+ label_log = log_type.replace("_1p", "")
+ label = f"{label} ({label_log}(x+1))"
+ else:
+ label = f"{label} (ln(x+1))"
+ elif log_type == "log":
+ label = f"{label} (ln)"
+ else:
+ label = f"{label} ({log_type})"
+ return label
+
+
+def log_transformation(x, log_type):
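+    # Mirrors prepare_axis_label, e.g. log_transformation(x, "log2_1p") == np.log2(1 + x).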
+ if "1p" in log_type:
+ if "_1p" in log_type:
+ label_log = log_type.replace("_1p", "")
+ if label_log == "log10":
+ return np.log10(1 + x)
+ elif label_log == "log2":
+ return np.log2(1 + x)
+ else:
+ return np.log1p(x)
+ elif log_type == "log":
+ return np.log(x)
+ elif log_type == "log2":
+ return np.log2(x)
+ elif log_type == "log10":
+ return np.log10(x)
+ return x
+
+
+@dataclass
+class HistogramConfig(PlotterConfig):
+ processor_type: str = field(default="scaling", init=False, repr=False)
+ _unused_feature_types: List[Type] = field(
+        default_factory=lambda: get_feature("METADATA_FEATURE_TYPES"),
+ init=False,
+ repr=False,
+ )
+ r_source: str = field(
+ default=(config.R_SCRIPTS / "plotting_utils.R").as_posix(),
+ init=False,
+ repr=False,
+ )
+ main_method: str = field(default="generate_histogram", init=False, repr=False)
+
+ xlab: str = "X"
+ ylab: str = "Frequency"
+ title: str = "Histogram"
+ bins: int = 30
+ font_size: int = 8
+ col_fill: str = "grey40"
+ col_outline: str = "white"
+ col_fill = "grey40"
+ col_outline = ("black",)
+ x1_name: str = "Before"
+ x2_name: str = "After"
+ xlog: Optional[str] = None
+ ylog: Optional[str] = None
+
+ def __post_init__(self):
+ if self.xlog not in [
+ None,
+ "log2",
+ "log10",
+ "log",
+ "log2_1p",
+ "log10_1p",
+ "log1p",
+ ]:
+ raise ValueError(
+ f"Invalid value for xlog: {self.xlog}. Must be one of: None, 'log2', 'log10', 'log', 'log2_1p', 'log10_1p', 'log1p'"
+ )
+ if self.ylog not in [
+ None,
+ "log2",
+ "log10",
+ "log",
+ "log2_1p",
+ "log10_1p",
+ "log1p",
+ ]:
+ raise ValueError(
+ f"Invalid value for ylog: {self.ylog}. Must be one of: None, 'log2', 'log10', 'log', 'log2_1p', 'log10_1p', 'log1p'"
+ )
+
+
+class HistogramPlotter(BasePlotter):
+ _config_class = HistogramConfig
+ config: HistogramConfig
+
+ def __init__(
+ self,
+ xlab: str = Unset('"X"'),
+ ylab: str = Unset('"Frequency"'),
+ title: str = Unset('"Histogram"'),
+ bins: int = Unset("30"),
+ font_size: int = Unset("8"),
+ col_fill: str = Unset('"grey40"'),
+ col_outline: str = Unset('"white"'),
+ xlog: Optional[str] = Unset("None"),
+ ylog: Optional[str] = Unset("None"),
+ install_missing: bool = None,
+ config: Optional[HistogramConfig] = None,
+ ):
+ super().__init__(
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ bins=bins,
+ font_size=font_size,
+ col_fill=col_fill,
+ col_outline=col_outline,
+ xlog=xlog,
+ ylog=ylog,
+ config=config,
+ install_missing=install_missing,
+ )
+ if self.config.r_source and self.config.main_method:
+ self.r_caller = RCaller.from_script(self.config.r_source)
+ enter_code = textwrap.dedent(
+ """
+ suppressPackageStartupMessages(require(ggplot2))
+ """
+ )
+ exit_code = textwrap.dedent(
+ """
+ ggplot2::ggsave(path, results)
+ """
+ )
+ self.plotter = self.r_caller.get_method(
+ self.config.main_method, enter_code, exit_code
+ )
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ if self.config.r_source and self.config.main_method:
+ self.r_caller = RCaller.from_script(self.config.r_source)
+ enter_code = textwrap.dedent(
+ """
+ suppressPackageStartupMessages(require(ggplot2))
+ """
+ )
+ exit_code = textwrap.dedent(
+ """
+ ggplot2::ggsave(path, results)
+ """
+ )
+ self.plotter = self.r_caller.get_method(
+ self.config.main_method, enter_code, exit_code
+ )
+
+ return self
+
+ def plot(
+ self,
+ x,
+ input_columns: SelectedColumnTypes = None,
+ xlab: str = Unset('"X"'),
+ ylab: str = Unset('"Frequency"'),
+ title: str = Unset('"Histogram"'),
+ bins: int = Unset("30"),
+ font_size: int = Unset("8"),
+ col_fill: str = Unset('"grey40"'),
+ col_outline: str = Unset('"white"'),
+ xlog: Optional[str] = Unset("None"),
+ ylog: Optional[str] = Unset("None"),
+ path: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show: bool = True,
+ ):
+ self.config._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._plot(
+ x,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ bins=bins,
+ font_size=font_size,
+ col_fill=col_fill,
+ col_outline=col_outline,
+ xlog=xlog,
+ ylog=ylog,
+ path=path,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ show=show,
+ )
+
+ def plot_arrow(self, x):
+ kwargs = self.config.get_params()
+ context_kwargs = {
+ "path": kwargs.pop("path", None),
+ }
+ self.plotter(x, context_kwargs=context_kwargs, **kwargs)
+
+
+@dataclass
+class ComparisonHistogramConfig(PlotterConfig):
+ processor_type: str = field(default="scaling", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [None, None], init=False, repr=False
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [None, None], init=False, repr=False
+ )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("METADATA_FEATURE_TYPES"),
+ get_feature("METADATA_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("METADATA_FEATURE_TYPES"),
+ get_feature("METADATA_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ r_source: str = field(
+ default=(config.R_SCRIPTS / "plotting_utils.R").as_posix(),
+ init=False,
+ repr=False,
+ )
+ main_method: str = field(
+ default="generate_comparison_histogram", init=False, repr=False
+ )
+
+ xlab: Optional[str] = None
+ ylab: str = "Count"
+ title: str = "Comparison Histogram"
+ bins: int = 30
+ alpha: float = 0.6
+ legend_title: str = "Legend"
+ legend_position: str = "top"
+ subplot_title1: str = "Before"
+ subplot_title2: str = "After"
+ col_set: str = "Set1"
+ cols: Optional[List[str]] = None
+ xlog: Optional[bool] = None
+ ylog: Optional[bool] = None
+
+
+class ComparisonHistogramPlotter(BasePlotter):
+ _config_class = ComparisonHistogramConfig
+ config: ComparisonHistogramConfig
+
+ def __init__(
+ self,
+ xlab: Optional[str] = Unset("None"),
+ ylab: str = Unset("None"),
+ title: str = Unset("None"),
+ bins: int = Unset("None"),
+ alpha: float = Unset("None"),
+ legend_title: str = Unset("None"),
+ legend_position: str = Unset("None"),
+ subplot_title1: str = Unset("None"),
+ subplot_title2: str = Unset("None"),
+ col_set: str = Unset("None"),
+ cols: Optional[List[str]] = Unset("None"),
+ xlog: Optional[bool] = Unset("None"),
+ ylog: Optional[bool] = Unset("None"),
+ install_missing: bool = None,
+ config: Optional[ComparisonHistogramConfig] = None,
+ ):
+ super().__init__(
+ config=config,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ bins=bins,
+ alpha=alpha,
+ legend_title=legend_title,
+ legend_position=legend_position,
+ subplot_title1=subplot_title1,
+ subplot_title2=subplot_title2,
+ col_set=col_set,
+ cols=cols,
+ xlog=xlog,
+ ylog=ylog,
+ install_missing=install_missing,
+ )
+ self.plotter = None
+ if self.config.r_source and self.config.main_method:
+ self.r_caller = RCaller.from_script(self.config.r_source)
+ enter_code = textwrap.dedent(
+ """
+ suppressPackageStartupMessages(require(ggplot2))
+ """
+ )
+ exit_code = textwrap.dedent(
+ """
+ ggplot2::ggsave(path, results)
+ """
+ )
+ self.plotter = self.r_caller.get_method(
+ self.config.main_method, enter_code, exit_code
+ )
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ if self.config.r_source and self.config.main_method:
+ self.r_caller = RCaller.from_script(self.config.r_source)
+ enter_code = textwrap.dedent(
+ """
+ suppressPackageStartupMessages(require(ggplot2))
+ """
+ )
+ exit_code = textwrap.dedent(
+ """
+ ggplot2::ggsave(path, results)
+ """
+ )
+ self.plotter = self.r_caller.get_method(
+ self.config.main_method, enter_code, exit_code
+ )
+
+ return self
+
+ def plot(
+ self,
+ x1,
+ x2=None,
+ column1: SelectedColumnTypes = None,
+ column2: SelectedColumnTypes = None,
+ xlab: Optional[str] = Unset("None"),
+ ylab: str = Unset("None"),
+ title: str = Unset("None"),
+ bins: int = Unset("None"),
+ alpha: float = Unset("None"),
+ legend_title: str = Unset("None"),
+ legend_position: str = Unset("None"),
+ subplot_title1: str = Unset("None"),
+ subplot_title2: str = Unset("None"),
+ col_set: str = Unset("None"),
+ cols: Optional[List[str]] = Unset("None"),
+ xlog: Optional[bool] = Unset("None"),
+ ylog: Optional[bool] = Unset("None"),
+ path: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show: bool = True,
+ ):
+ self.config._input_columns = self._set_input_columns_and_arity(column1, column2)
+ return self._plot(
+ x1,
+ x2,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ bins=bins,
+ alpha=alpha,
+ legend_title=legend_title,
+ legend_position=legend_position,
+ col_set=col_set,
+ cols=cols,
+ subplot_title1=subplot_title1,
+ subplot_title2=subplot_title2,
+ xlog=xlog,
+ ylog=ylog,
+ path=path,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ show=show,
+ )
+
+ def plot_arrow(self, x1, x2):
+ kwargs = self.config.get_params()
+ self.plotter(x1, x2, **kwargs)
diff --git a/src/biofit/visualization/plot_feature_importance.R b/src/biofit/visualization/plot_feature_importance.R
new file mode 100644
index 0000000..d14f0bf
--- /dev/null
+++ b/src/biofit/visualization/plot_feature_importance.R
@@ -0,0 +1,216 @@
+source(file.path(R_SCRIPTS_PATH, "plotting_utils.R"))
+source(file.path(R_SCRIPTS_PATH, "utils.R"))
+
+#'
+#' Feature importance plot with feature data in the samples.
+#' @param X Arrow data table that is logged or in non-logged form depending on the omics data
+#' @param y Arrow labels
+#' @param sample_metadata Arrow sample metadata
+#' @param feature_importances Arrow feature importance table: the first column is the feature ID ("feature"); the second column onwards holds the importance per final model seed
+#' @param path the path of the PDF file to save
+#' @param plot_top the number of top features to plot. Default is 15.
+#' @param feature_meta_name NULL to print the feature ID; a single character string such as "species" to print that column; or a character vector of several columns, e.g. c("genus", "species"), to print the values collapsed together. Default is NULL.
+#' @param plot_title name of the data type, eg "Presence", or "log2 Abundance". Default is NULL.
+#' @param column_title eg. "Sample" or "Isolate". Default is NULL.
+#' @param row_title eg. "OTU" or "SNP". Default is NULL.
+#' @examples
+#' \dontrun{
+#' # usage for OTU data:
+#' plot_feature_importance(X, y, sample_metadata, feature_importances, plot_top = 30, feature_meta_name = c("feature", "genus", "species"), plot_title = "log2\nAbundance", column_title = "Sample", row_title = "OTU")
+#' # usage for metagenomics data:
+#' plot_feature_importance(X, y, sample_metadata, feature_importances, plot_top = 30, feature_meta_name = c("genus", "species"), plot_title = "log2\nAbundance", column_title = "Sample", row_title = "Taxonomy")
+#' # usage for transcriptomics/proteomics (non-MALDI) data: eg
+#' plot_feature_importance(X, y, sample_metadata, feature_importances, plot_top = 30, feature_meta_name = c("gene"), plot_title = "log2\nExpression", column_title = "Sample", row_title = "Gene")
+#' # usage for genomics data:
+#' plot_feature_importance(X, y, sample_metadata, feature_importances, plot_top = 30, feature_meta_name = NULL, plot_title = "Presence", column_title = "Isolate", row_title = "SNP")
+#' # usage for MALDI-TOF data:
+#' plot_feature_importance(X, y, sample_metadata, feature_importances, plot_top = 30, feature_meta_name = "DaRange", plot_title = "log10\nAbundance", column_title = "Isolate", row_title = "Da range")
+#' }
+#'
+plot_feature_importance <- function(
+ X,
+ feature_importances,
+ path,
+ y = NULL,
+ sample_metadata = NULL,
+ feature_metadata = NULL,
+ input_columns = NULL,
+ target_columns = NULL,
+ sample_metadata_columns = NULL,
+    feature_column = NULL,
+    sample_column = NULL,
+ feature_meta_name = NULL,
+ plot_top = 15,
+ cols = NULL,
+ col_heat = NULL,
+ show_column_names = FALSE,
+ plot_title = "Relative Abundance",
+ column_title = "Samples",
+ row_title = "Taxonomic Relative Abundance") {
+ width <- 13
+ height <- 7
+
+ X <- convert_to_dataframe(X)
+ input_columns <- get_default_columns(X, input_columns, NULL)
+ X <- validate_data(X, input_columns)
+
+ if (!is.null(y)) {
+ y <- convert_to_dataframe(y)
+ target_columns <- get_default_columns(y, target_columns)
+ y <- validate_data(y, target_columns)
+ }
+
+ if (!is.null(sample_metadata)) {
+ sample_metadata <- convert_to_dataframe(sample_metadata)
+ sample_metadata_columns <- get_default_columns(
+ sample_metadata, sample_metadata_columns, NULL
+ )
+ sample_metadata <- validate_data(sample_metadata, sample_metadata_columns)
+ }
+
+ if (is.null(sample_column)) {
+ sample_column <- sample_metadata_columns[1]
+ }
+
+
+ if (!is.null(feature_metadata)) {
+ feature_metadata <- convert_to_dataframe(feature_metadata)
+ if (is.list(feature_meta_name)) {
+ feature_meta_name <- as.vector(unlist(feature_meta_name))
+ }
+ feature_meta_name <- get_default_columns(
+ feature_metadata, feature_meta_name
+ )
+ feature_metadata <- validate_data(feature_metadata, feature_meta_name)
+ rownames(feature_metadata) <- feature_metadata[, feature_column]
+ }
+
+ suppressPackageStartupMessages(require(ComplexHeatmap))
+ suppressPackageStartupMessages(require(grid))
+ suppressPackageStartupMessages(require(circlize))
+ suppressPackageStartupMessages(require(RColorBrewer))
+
+ ## get label and feature information
+ ## convert arrow data format to data frame
+ if (!is.null(feature_metadata)) {
+ feature_metadata <- convert_to_dataframe(feature_metadata)
+ }
+
+
+ feature_importances <- convert_to_dataframe(feature_importances)
+ top_feat_id <- feature_importances[, 1]
+ rownames(feature_importances) <- feature_importances[, feature_column]
+ # drop feature column
+ not_feature_column <- setdiff(colnames(feature_importances), feature_column)
+ feature_importances <- feature_importances[, not_feature_column]
+
+
+ y <- factor(as.vector(convert_to_dataframe(y)[, 1]))
+
+ rownames(X) <- sample_metadata[, sample_column]
+
+ ### make sure plot_top is numeric
+ if (!is.numeric(plot_top)) {
+ stop("plot_top, the number of top features to plot, must be numeric.")
+ }
+
+ ### color selected from plotting_utils.R to be consistent
+ if (is.null(cols)) {
+ cols <- color_select(length(levels(y)))
+ names(cols) <- levels(y)
+ }
+
+ ### if col_heat not specified
+ if (is.null(col_heat)) {
+ col_heat <- brewer.pal(9, "YlOrRd")
+ }
+
+ ## Define the feature/text annotation with row names
+ ## plot feature ID if not specified
+ if (is.null(feature_meta_name) || is.null(feature_metadata)) {
+ feat_on_plot <- top_feat_id
+ } else {
+ ## for more than one feature meta columns to display
+ if (length(feature_meta_name) > 1) {
+ if (!(feature_column %in% feature_meta_name)) {
+ feature_meta_name <- c(feature_column, feature_meta_name)
+ }
+ feat_on_plot <- apply(
+ feature_metadata[top_feat_id, feature_meta_name],
+ 1,
+ paste,
+ collapse = "\n"
+ )
+ } else { ## for one feature meta columns to display
+ feat_on_plot <- feature_metadata[top_feat_id, feature_meta_name]
+ }
+ }
+
+ ##################
+ ## generate feature importance plot
+ ##################
+ ### the feature boxplot/barplot
+ # remove UNINTEGRATED from top_feat_id
+ hdat2plot <- t(data.matrix(X[, top_feat_id]))
+
+ if (ncol(feature_importances) == 1) {
+ ha2 <- ComplexHeatmap::rowAnnotation(
+ ` ` = row_anno_points(
+ feature_importances[top_feat_id, ],
+ axis = TRUE, outline = FALSE
+ ),
+ width = unit(3, "cm")
+ )
+ } else {
+ ha2 <- ComplexHeatmap::rowAnnotation(
+ ` ` = row_anno_boxplot(
+ data.matrix(feature_importances[top_feat_id, ]),
+ axis = TRUE,
+ outline = FALSE
+ ),
+ width = unit(3, "cm")
+ )
+ }
+
+ ### the sample annotation in the column
+ ### (Added a name hack here to be Importance
+ ### as the label is above feature importance box/dot plot)
+ ha1 <- ComplexHeatmap::HeatmapAnnotation(
+ df = data.frame(Importance = y),
+ show_legend = TRUE,
+ col = list(Importance = cols),
+ annotation_legend_param = list(
+ title = target_columns,
+ legend_direction = "vertical"
+ )
+ )
+
+ ### add feature names to rows
+ text_annotation <- rowAnnotation(
+ text = anno_text(
+ feat_on_plot,
+ just = "left",
+ gp = gpar(fontsize = 10)
+ )
+ )
+
+ ### plot histogram and add the 3 components from above
+ h2plot <- ComplexHeatmap::Heatmap(hdat2plot,
+ cluster_rows = FALSE, col = col_heat, top_annotation = ha1,
+ heatmap_legend_param = list(
+ title = plot_title, color_bar = "continuous",
+ legend_direction = "vertical"
+ ),
+ show_column_names = show_column_names,
+ row_title = row_title, # Title for the rows
+ column_title = column_title, # Title for the columns
+ row_title_side = "left", # Position of the row title
+ column_title_side = "top"
+ ) +
+ ha2 + text_annotation # , top_annotation_height = unit(1, "cm")
+ start_device(path, width = width, height = height, units = "in", res = 300)
+ draw(h2plot)
+ dev.off()
+}
diff --git a/src/biofit/visualization/plot_sample_metadata.R b/src/biofit/visualization/plot_sample_metadata.R
new file mode 100644
index 0000000..efe3a3d
--- /dev/null
+++ b/src/biofit/visualization/plot_sample_metadata.R
@@ -0,0 +1,140 @@
+source(file.path(R_SCRIPTS_PATH, "plotting_utils.R"))
+source(file.path(R_SCRIPTS_PATH, "utils.R"))
+
+plot_sample_metadata <- function(
+ data, outcome = NULL,
+ sample_metadata_columns = NULL,
+ outcome_column = NULL,
+ path,
+ device = "pdf") {
+
+ suppressPackageStartupMessages(require(ggplot2))
+ suppressPackageStartupMessages(require(RColorBrewer))
+ suppressPackageStartupMessages(require(circlize))
+ suppressPackageStartupMessages(require(patchwork))
+ width <- 7
+ height <- 7
+  if (substr(device, 1, 1) != ".") {
+ device <- paste0(".", device)
+ }
+ data <- convert_to_dataframe(data)
+ sample_metadata_columns <- get_default_columns(
+ data, sample_metadata_columns,
+ max_cols = NULL
+ )
+ data <- validate_data(data, sample_metadata_columns)
+
+ if (!is.null(outcome)) {
+ outcome <- convert_to_dataframe(outcome)
+ outcome_column <- get_default_columns(outcome, outcome_column, max_cols = 1)
+ outcome <- validate_data(outcome, outcome_column)
+ data <- concatenate_datasets(
+ data[, sample_metadata_columns], outcome,
+ how = "horizontal"
+ )
+ } else if (!is.null(outcome_column)) {
+ outcome_column <- get_default_columns(data, outcome_column, max_cols = 1)
+ } else {
+ stop("Please provide either an outcome or an outcome column")
+ }
+
+ # get filename from path
+ file_name <- gsub("\\.[^.]*$", "", basename(path))
+ path <- dirname(path)
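+  # e.g. path "results/meta.pdf" -> file_name "meta", output directory "results"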
+ # Get Outcome Type
+ outcome_type <- detect_var_type(data, outcome_column)
+
+ metadata_dir <- path
+ if (!dir.exists(metadata_dir)) {
+ dir.create(metadata_dir)
+ }
+
+ # Comparison Visualizations
+ if (outcome_type == "categorical") {
+ suppressPackageStartupMessages(require(forcats))
+ # Outcome variable visualization
+ cat_metadata <- plot_cat_metadata(data, outcome_column)
+ cat_metadata_fn <- paste0(file_name, "_dist_", outcome_column, device)
+ save_plots(cat_metadata_fn,
+ plot = cat_metadata,
+ path = metadata_dir, width = width, height = height
+ )
+
+
+ plot_num <- 0
+ # Loop through all the columns
+ for (col in names(data)) {
+ plot_num <- plot_num + 1
+ # If the col is the same as the outcome
+ if (col == outcome_column) next
+
+ # Get the data type of the current column
+ col_type <- detect_var_type(data, col)
+
+ # If it is other, we skip for now
+ if (col_type == "other") {
+ print(paste0(col, " was not plotted"))
+ next
+ }
+
+ # categorical and numerical comparisons
+ if (col_type == "categorical") {
+ comp_metadata <- plot_cat_vs_cat_metadata(data, outcome_column, col)
+ } else if (col_type == "numerical") {
+ comp_metadata <- plot_cat_vs_num_metadata(data, outcome_column, col)
+ }
+
+ # Save Comparison
+ comp_metadata_fn <- paste0(
+ file_name, plot_num, "_",
+ outcome_column, "_vs_", col, device
+ )
+ save_plots(comp_metadata_fn,
+ plot = comp_metadata,
+ path = metadata_dir, width = width * 2, height = height
+ )
+ }
+ } else if (outcome_type == "numerical") {
+ # Outcome variable visualization
+ num_metadata <- plot_num_metadata(data, outcome_column, font_size = 11)
+ num_metadata_fn <- paste0(file_name, "_dist_", outcome_column, device)
+ save_plots(num_metadata_fn,
+ plot = num_metadata, path = metadata_dir,
+ width = width, height = height
+ )
+ plot_num <- 0
+ # Loop through all of the columns
+ for (col in names(data)) {
+ plot_num <- plot_num + 1
+ # If the col is the same as the outcome
+ if (col == outcome_column) next
+
+ # Get the data type of the current column
+ col_type <- detect_var_type(data, col)
+
+ # If it is other, we skip for now
+ if (col_type == "other") {
+ print(paste0(col, " was not plotted"))
+ next
+ }
+
+ # categorical and numerical comparisons
+ if (col_type == "categorical") {
+ comp_metadata <- plot_cat_vs_num_metadata(data, col, outcome_column)
+ } else if (col_type == "numerical") {
+ comp_metadata <- plot_num_vs_num_metadata(data, outcome_column, col)
+ }
+
+ comp_metadata_fn <- paste0(
+ file_name, plot_num, "_", outcome_column, "_vs_", col, device
+ )
+ save_plots(
+ comp_metadata_fn,
+ plot = comp_metadata,
+ path = metadata_dir, width = width * 2, height = height
+ )
+ }
+ } else if (outcome_type == "other") {
+ print("Outcome variable has too many levels")
+ }
+}
diff --git a/src/biofit/visualization/plotting.py b/src/biofit/visualization/plotting.py
new file mode 100644
index 0000000..e1a0502
--- /dev/null
+++ b/src/biofit/visualization/plotting.py
@@ -0,0 +1,429 @@
+import inspect
+import os
+import sys
+import tempfile
+import warnings
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import List, Optional
+
+from biocore import DataHandler
+from biocore.utils.import_util import is_ipywidgets_available, is_matplotlib_available
+from biocore.utils.naming import camelcase_to_snakecase
+
+from biofit.integration.R.r_caller import RCaller
+from biofit.processing import (
+ _ORDERED_FORMATS,
+ TransformationMixin,
+ sync_backup_config,
+)
+from biofit.utils import (
+ _build_cache_dir,
+ fingerprint_from_data,
+ has_ext,
+ has_separator,
+ logging,
+ move_temp_file,
+)
+from biocore.utils.py_util import is_dataset_dict
+from biofit.visualization.plotting_utils import (
+ display_image_carousel,
+ is_in_notebook,
+)
+
+from ..processing import BaseConfig, SelectedColumnTypes, SelectedFeatureTypes
+
+logger = logging.get_logger(__name__)
+
+
+ORDERED_PLOTTER_FORMATS = _ORDERED_FORMATS + ["dataset", "ds"]
+
+
+def _processor_info_from_fingerprint(fingerprint: str):
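+    # Fingerprints are "-"-delimited:
+    # "<hash>-<processor_name>-<processor_type>[-<dataset_name>][-<suffix>]"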
+ if fingerprint is None:
+ return "", "", ""
+ processor_info = fingerprint.split("-")
+ ds_name = ""
+ if len(processor_info) == 5:
+ _, processor_name, processor_type, ds_name, _ = processor_info
+ elif len(processor_info) == 4:
+ _, processor_name, processor_type, ds_name = processor_info
+ elif len(processor_info) == 3:
+ _, processor_name, processor_type = processor_info
+ else:
+ return "", "", ""
+ return processor_name, processor_type, ds_name
+
+
+@dataclass
+class PlotterConfig(BaseConfig):
+ path: str = field(default=None, kw_only=True, init=True, repr=True)
+ device: str = field(default="pdf", kw_only=True, init=True, repr=False)
+ fingerprint: str = field(default=None, kw_only=True, init=True, repr=False)
+ unused_columns: SelectedColumnTypes = field(
+ default=None, kw_only=True, init=True, repr=False
+ )
+ raise_if_missing: bool = field(default=True, kw_only=True, init=True, repr=False)
+ cache_dir: str = field(default=None, kw_only=True, init=True, repr=False)
+ version: str = field(default="0.0.0", kw_only=True, init=True, repr=False)
+
+ _input_columns: SelectedColumnTypes = field(default=None, init=False, repr=False)
+ _compare: bool = field(default=False, init=False, repr=False)
+ _fit_input_feature_types: SelectedFeatureTypes = field(
+ default=None, init=False, repr=False
+ )
+ _fit_unused_feature_types: SelectedFeatureTypes = field(
+ default=None, init=False, repr=False
+ )
+ _transform_input_feature_types: SelectedFeatureTypes = field(
+ default=None, init=False, repr=False
+ )
+ _transform_unused_feature_types: SelectedFeatureTypes = field(
+ default=None, init=False, repr=False
+ )
+ r_code: str = field(default=None, init=False, repr=False)
+ r_source: str = field(default=None, init=False, repr=False)
+ main_method: str = field(default=None, init=False, repr=False)
+
+ processor_type: str = field(default="", init=False, repr=False)
+ processor_name: str = field(default="", init=False, repr=False)
+ dataset_name: str = field(default="", init=False, repr=False)
+
+ # automatically generated attributes
+
+ feature_idx_in_: List[int] = field(default=None, init=False, repr=False)
+ feature_names_in_: List[str] = field(default=None, init=False, repr=False)
+ extra_idx_in_: List[List[int]] = field(default=None, init=False, repr=False)
+ extra_names_in_: List[List[str]] = field(default=None, init=False, repr=False)
+
+
+class BasePlotter(TransformationMixin):
+ """_summary_
+
+ Attributes:
+ r_code (_type_): _description_
+ r_source (_type_): _description_
+ feature_type (_type_): _description_
+ dtype (str): specifies the type of the plotter. Can be 'plotter' or 'plotter_for'
+
+ Raises:
+ NotImplementedError: _description_
+ ValueError: _description_
+ ValueError: _description_
+ ValueError: _description_
+ ValueError: _description_
+ ValueError: _description_
+ ValueError: _description_
+ ValueError: _description_
+
+ Returns:
+ _type_: _description_
+ """
+
+ r_caller: RCaller = None
+ _config_class = PlotterConfig
+ plotter: Optional[str] = None
+ config: PlotterConfig
+
+ def __init__(self, config: Optional[PlotterConfig] = None, **kwargs):
+ add_new_attr = kwargs.pop("add_new_attr", False)
+ ignore_none = kwargs.pop("ignore_none", False)
+
+ if config is None:
+ if hasattr(self, "_config_class"):
+ self.config = self._config_class.from_dict(
+ kwargs, ignore_none=ignore_none, add_new_attr=add_new_attr
+ )
+ elif isinstance(config, PlotterConfig):
+ self.config = config
+ elif isinstance(config, dict):
+ self.config = self._config_class.from_dict(
+ config, ignore_none=ignore_none, add_new_attr=add_new_attr
+ )
+ else:
+ raise ValueError(f"Unsupported config type {type(config)}")
+        if config is None:
+            self.set_params(**kwargs)
+ if self.config.r_source and self.config.main_method and self.plotter is None:
+ self.r_caller = RCaller.from_script(self.config.r_source)
+ self.r_caller.verify_r_dependencies(
+ install_missing=kwargs.get("install_missing")
+ )
+ self.plotter = self.r_caller.get_method(self.config.main_method)
+ if kwargs.get("_function", None):
+ self._function = kwargs["_function"]
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ if self.config.r_source and self.config.main_method:
+ self.r_caller = RCaller.from_script(self.config.r_source)
+ self.plotter = self.r_caller.get_method(self.config.main_method)
+
+ return self
+
+ @classmethod
+ def _from_config(cls, config: PlotterConfig, **kwargs):
+ return cls(config=config, **kwargs)
+
+ def plot(self, X, *args, **kwargs):
+ if is_dataset_dict(X):
+ raise ValueError(
+ "Cannot plot a DatasetDict or IterableDatasetDict. Please provide a Dataset or IterableDataset."
+ )
+ return self._plot(X, *args, **kwargs)
+
+ def _plot(self, X, *args, **kwargs):
+ """
+ Transforms the input data.
+
+ Args:
+ X (Any): The input data.
+ *args: Additional arguments. These are additional table or array-like data to be used in the plotting.
+ **kwargs: Additional keyword arguments. These are objects or parameters to be used in the plotting.
+
+ Returns:
+ Any: The computed processor.
+ """
+ show = kwargs.pop("show", True)
+ output_dir = kwargs.pop("output_dir", None)
+ self.config, kwargs = self.config.replace_defaults(
+ ignore_none=True, return_unused_kwargs=True, **kwargs
+ )
+
+ args = list(args)
+
+ plot_funcs = self._get_method(ORDERED_PLOTTER_FORMATS, "plot")
+ plot_func, target_format = self._get_target_func(
+ plot_funcs,
+ source_format=DataHandler.get_format(X),
+ target_formats=kwargs.get("input_format", None),
+ accepted_formats=ORDERED_PLOTTER_FORMATS,
+ )
+
+ fingerprint = kwargs.pop("fingerprint", None) or self.config.fingerprint
+
+ image_paths = None
+ if plot_func:
+ (
+ self.feature_names_in_,
+ self.feature_idx_in_,
+ _,
+ self.extra_names_in_,
+ self.extra_idx_in_,
+ _,
+ _,
+ ) = self._get_columns(
+ X,
+ *args,
+ input_columns=self.config._input_columns,
+ unused_columns=self.config.unused_columns,
+ input_feature_types=self.config._transform_input_feature_types,
+ unused_feature_types=self.config._transform_unused_feature_types,
+ raise_if_missing=self.config.raise_if_missing,
+ )
+
+ path_to_save = self.config.path
+
+ # required args are ones without defaults
+ required_args = [
+ p.default == p.empty
+ for p in inspect.signature(plot_func).parameters.values()
+ ][1:]
+ if (
+ not self.config._compare
+ and self.extra_idx_in_ is not None
+ and len(self.extra_idx_in_) > 0
+ and DataHandler.supports_named_columns(X)
+ ):
+ new_cols = self._make_columns_exclusive(
+ [self.feature_names_in_] + self.extra_names_in_
+ )
+
+ self.feature_names_in_ = new_cols[0]
+ self.feature_idx_in_ = sorted(
+ DataHandler.get_column_indices(X, self.feature_names_in_)
+ )
+
+ self.extra_names_in_ = new_cols[1:]
+ extra_idx_in_ = []
+ for i, names in enumerate(self.extra_names_in_):
+ if len(args) and args[i] is not None:
+ extra_idx_in_.append(
+ sorted(
+ DataHandler.get_column_indices(
+ args[i], names, raise_if_missing=False
+ )
+ )
+ )
+ elif names is not None:
+ extra_idx_in_.append(
+ sorted(
+ DataHandler.get_column_indices(
+ X, names, raise_if_missing=True
+ )
+ )
+ )
+ else:
+ if required_args[i]:
+ raise ValueError(
+ f"Missing required argument {i} for plot function {plot_func}"
+ )
+ extra_idx_in_.append(self.extra_idx_in_[i])
+ self.extra_idx_in_ = extra_idx_in_
+
+ data_fingerprint = fingerprint_from_data(X)
+ if fingerprint is None:
+ fingerprint = self.generate_fingerprint(data_fingerprint, self.config)
+
+ input, args, kwargs = self._process_plot_input(X, *args, **kwargs)
+ args = list(args)
+ if len(args) == 0:
+ if self.extra_idx_in_:
+ for i, inds in enumerate(self.extra_idx_in_):
+ if inds is not None:
+ args.append(
+ DataHandler.to_format(
+ X, target_format, input_columns=inds
+ )
+ )
+ else:
+ args.append(None)
+ else:
+ for i, (inds, arg) in enumerate(zip(self.extra_idx_in_, args)):
+ if arg is not None:
+ args[i] = DataHandler.to_format(
+ arg, target_format, input_columns=inds
+ )
+ elif inds is not None:
+ args[i] = DataHandler.to_format(
+ X, target_format, input_columns=inds
+ )
+
+ input = DataHandler.to_format(
+ input, target_format, input_columns=self.feature_idx_in_
+ )
+            # mkdtemp already creates the directory; just normalize the path
+            temp_dir = Path(tempfile.mkdtemp()).resolve().as_posix()
+            temp_file = os.path.join(temp_dir, fingerprint)
+ self.config.path = temp_file + f".{self.config.device}"
+ output_dir = None if output_dir is None else Path(output_dir)
+ file_name = None
+ if path_to_save:
+ if has_separator(path_to_save):
+ if has_ext(path_to_save):
+ output_dir = (
+ Path(path_to_save).parent
+ if output_dir is None
+ else output_dir
+ )
+ file_name = Path(path_to_save).name
+ else:
+ output_dir = (
+ Path(path_to_save) if output_dir is None else output_dir
+ )
+ file_name = None
+ elif has_ext(path_to_save):
+ file_name = str(path_to_save)
+
+ if output_dir is None:
+ if kwargs.get("processor", None) and kwargs["processor"].cache_files:
+ output_dir = (
+ Path(kwargs["processor"].cache_files[0]["filename"]).parent
+ / "plots"
+ )
+ elif hasattr(X, "cache_files") and X.cache_files:
+ output_dir = Path(X.cache_files[0]["filename"]).parent / "plots"
+ else:
+ output_dir = Path(_build_cache_dir(X, data_fingerprint)) / "plots"
+ if file_name is None:
+ cls_name = camelcase_to_snakecase(
+ self.__class__.__name__.replace("_plotter", "")
+ )
+ file_name = f"{fingerprint}_{cls_name}.{self.config.device}"
+
+ fig = plot_func(input, *args, **kwargs)
+ if "matplotlib" in sys.modules:
+ import matplotlib.pyplot as plt
+
+ if output_dir:
+
+                def get_or_move_images(images, output_dir, move_images=False):
+                    # Always return a list of paths so the display logic below
+                    # can reliably tell single plots from carousels.
+                    image_paths = []
+                    for fn in images:
+                        old_name = fn.resolve().as_posix()
+                        # a single plot keeps the user-facing file name
+                        new_name = (
+                            f"{output_dir}/{file_name}"
+                            if len(images) == 1
+                            else f"{output_dir}/{fn.name}"
+                        )
+                        if move_images:
+                            image_paths.append(new_name)
+                            move_temp_file(old_name, new_name)
+                        else:
+                            image_paths.append(old_name)
+                    if move_images and len(image_paths) == 1:
+                        logger.info(f"Saved plot to {image_paths[0]}")
+                    elif move_images and image_paths:
+                        logger.info(f"Saved {len(images)} plots to {output_dir}")
+
+                    return image_paths
+
+ output_dir.mkdir(parents=True, exist_ok=True)
+ output_dir = output_dir.resolve().as_posix()
+ # move all files within temp_dir to output_dir
+ images = [fn for fn in Path(temp_dir).glob(f"*.{self.config.device}")]
+ image_paths = get_or_move_images(images, output_dir, move_images=True)
+
+ _image_paths = []
+ if self.config.device != "png":
+ _images = [fn for fn in Path(temp_dir).glob("*.png")]
+ _image_paths = get_or_move_images(
+ _images, output_dir, move_images=False
+ )
+ if len(_image_paths) > 0:
+ image_paths = _image_paths
+ images = _images
+
+ if len(image_paths) > 1:
+ if is_in_notebook() and show and is_ipywidgets_available():
+ display_image_carousel(image_paths, "png")
+ elif len(image_paths) == 1:
+ if is_in_notebook() and show:
+ if (
+ is_matplotlib_available()
+ and "matplotlib" in sys.modules
+ and isinstance(fig, plt.Figure)
+ ):
+ # ignore warning about non-interactive backend
+ warnings.filterwarnings("ignore")
+ fig.show()
+ warnings.filterwarnings("default")
+ else:
+ from IPython.display import Image, display
+
+ display(
+ Image(
+                                image_paths[0],
+ embed=True,
+ format="png",
+ width=720,
+ )
+ )
+ elif (
+ is_matplotlib_available()
+ and "matplotlib" in sys.modules
+ and isinstance(fig, plt.Figure)
+ ):
+ plt.close(fig)
+ else:
+ logger.warning("No plots were generated")
+        if isinstance(image_paths, list) and len(image_paths) == 1:
+            # keep the convention of returning a bare path for a single plot
+            image_paths = image_paths[0]
+        return image_paths
+
+ def _process_plot_input(self, input, *args, **kwargs):
+ return input, args, kwargs
diff --git a/src/biofit/visualization/plotting_utils.py b/src/biofit/visualization/plotting_utils.py
new file mode 100644
index 0000000..ff6a63e
--- /dev/null
+++ b/src/biofit/visualization/plotting_utils.py
@@ -0,0 +1,1210 @@
+import sys
+from os import PathLike
+try:
+    from types import NoneType  # available on Python 3.10+
+except ImportError:  # Python < 3.10
+    NoneType = type(None)
+from typing import List, Optional, Union
+
+import numpy as np
+import pandas as pd
+from biocore import DataHandler
+from biocore.utils import requires_backends
+from sklearn.pipeline import Pipeline
+
+from biofit.processing import SelectedColumnTypes
+from biofit.utils.types import Unset
+
+
+def is_in_notebook():
+ try:
+ from IPython import get_ipython
+
+ if "IPKernelApp" in get_ipython().config:
+ return True
+ except Exception:
+ return False
+ return False
+
+
+def display_image_carousel(image_paths, format="png"):
+ """
+ Displays an image carousel with left and right arrow buttons in a Jupyter notebook.
+
+ Parameters:
+ image_paths (list): List of paths to images.
+ output_dir (str): Directory to indicate in the log if limit is exceeded.
+ max_images (int): Maximum number of images to include in the carousel.
+ """
+ requires_backends(display_image_carousel, "ipywidgets")
+ import ipywidgets as widgets
+ from IPython.display import Image, clear_output, display
+ from ipywidgets import Button, HBox, Layout, VBox
+
+ image_index = 0
+
+ button_layout = Layout(width="100px")
+ left_button = Button(description="Previous", layout=button_layout)
+ right_button = Button(description="Next", layout=button_layout)
+ image_box = widgets.Output()
+
+ def show_image():
+ """Helper function to show the image."""
+ with image_box:
+ clear_output(wait=True)
+ display(
+ Image(image_paths[image_index], embed=True, format=format, width=720)
+ )
+
+ def on_left_button_clicked(b):
+ nonlocal image_index
+ if image_index > 0:
+ image_index -= 1
+ show_image()
+
+ def on_right_button_clicked(b):
+ nonlocal image_index
+ if image_index < len(image_paths) - 1:
+ image_index += 1
+ show_image()
+
+ left_button.on_click(on_left_button_clicked)
+ right_button.on_click(on_right_button_clicked)
+
+ show_image() # Show the first image initially
+
+ navigation_box = HBox([left_button, right_button])
+ display(VBox([navigation_box, image_box]))
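+
+# Hedged usage sketch (inside a notebook, with hypothetical paths):
+#
+#   display_image_carousel(["plots/fig1.png", "plots/fig2.png"], format="png")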
+
+
+def get_distinct_colors(n, colormap="nipy_spectral"):
+ """
+ Generate a list of n distinct colors using a specified colormap.
+ Switched to 'nipy_spectral' for higher contrast and brightness differentiation.
+
+ :param n: Number of distinct colors to generate.
+ :param colormap: Name of the matplotlib colormap to use.
+ :return: List of RGBA colors.
+ """
+ requires_backends(get_distinct_colors, "matplotlib")
+ from matplotlib import colormaps
+
+ cmap = colormaps.get_cmap(colormap)
+ return [cmap(i / n) for i in range(n)]
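+
+# For example, ten well-separated RGBA colors for categorical plotting
+# (requires matplotlib):
+#
+#   colors = get_distinct_colors(10)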
+
+
+def generate_violin(
+ x,
+ y=None,
+ column: SelectedColumnTypes = None,
+ label_name: SelectedColumnTypes = None,
+ xlab: str = Unset('"Labels"'),
+ ylab: str = Unset('"Value"'),
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ output_dir: Union[PathLike, str] = None,
+):
+ """Generates a violin plot of the input data.
+
+ Args:
+ x: Input data.
+ other_x: Other input data.
+ y: Target data.
+ other_x: Other target data.
+ **kwargs: Additional arguments.
+ Returns:
+ Plotter object.
+ """
+ from biofit.visualization.violin import ViolinPlotter
+
+ return ViolinPlotter().plot(
+ x,
+ y,
+ column=column,
+ label_name=label_name,
+ xlab=xlab,
+ ylab=ylab,
+ path=output_dir,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ )
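+
+# Hedged example (hypothetical DataFrame and column names):
+#
+#   generate_violin(df, column="abundance", label_name="disease_state",
+#                   output_dir="plots", device="png")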
+
+
+def generate_scatterplot(
+ x,
+ y=None,
+ group=None,
+ xdata: SelectedColumnTypes = None,
+ ydata: SelectedColumnTypes = None,
+ groupby: SelectedColumnTypes = None,
+ xlab: str = Unset("None"),
+ ylab: str = Unset("None"),
+ title: str = Unset('"Scatterplot"'),
+ alpha: str = Unset("1"),
+ col_set: str = Unset('"Set1"'),
+ cols: List[str] = Unset("None"),
+ xlog: str = Unset("None"),
+ ylog: str = Unset("None"),
+ output_dir: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+):
+ from biofit.visualization.scatterplot import ScatterPlotter
+
+ plotter = ScatterPlotter()
+ return plotter.plot(
+ x,
+ y,
+ group,
+ xdata=xdata,
+ ydata=ydata,
+ groupby=groupby,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ alpha=alpha,
+ col_set=col_set,
+ cols=cols,
+ xlog=xlog,
+ ylog=ylog,
+ path=output_dir,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ )
+
+
+def generate_histogram(
+ x,
+ input_columns: SelectedColumnTypes = None,
+ xlab: str = Unset('"X"'),
+ ylab: str = Unset('"Frequency"'),
+ title: str = Unset('"Histogram"'),
+ bins: int = Unset("30"),
+ font_size: int = Unset("8"),
+ col_fill: str = Unset('"grey40"'),
+ col_outline: str = Unset('"white"'),
+ xlog: Optional[str] = Unset("None"),
+ ylog: Optional[str] = Unset("None"),
+ output_dir: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+):
+ """Generates a histogram of the input data.
+
+ Args:
+ x: Input data.
+ other_x: Other input data.
+ y: Target data.
+ other_x: Other target data.
+ **kwargs: Additional arguments.
+ Returns:
+ Plotter object.
+ """
+ from biofit.visualization.histogram import HistogramPlotter
+
+ plotter = HistogramPlotter()
+
+ return plotter.plot(
+ x,
+ input_columns=input_columns,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ bins=bins,
+ font_size=font_size,
+ col_fill=col_fill,
+ col_outline=col_outline,
+ xlog=xlog,
+ ylog=ylog,
+ path=output_dir,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ )
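+
+# Hedged example (hypothetical column name):
+#
+#   generate_histogram(df, input_columns="age", bins=20, output_dir="plots")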
+
+
+def generate_barplot(
+ x,
+ y=None,
+ group=None,
+ label_name: SelectedColumnTypes = None,
+ value_name=None,
+ groupby: Optional[str] = None,
+ xlab: Optional[str] = Unset("None"),
+ ylab: Optional[str] = Unset("None"),
+ title: str = Unset('"Bar Plot"'),
+ col_set: str = Unset('"Set1"'),
+ col_labels: str = Unset('"black"'),
+ col_outline: str = Unset('"grey30"'),
+ cols: Optional[List[str]] = Unset("None"),
+ prop: bool = Unset("False"),
+ add_count_lab: bool = Unset("True"),
+ vars_as_entered: bool = Unset("False"),
+ legend_position: str = Unset('"top"'),
+ font_size: float = Unset("3.25"),
+ output_dir: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+):
+ """Generates a bar plot of the input data.
+
+ Args:
+ x: Input data.
+ other_x: Other input data.
+ y: Target data.
+ other_x: Other target data.
+ **kwargs: Additional arguments.
+ Returns:
+ Plotter object.
+ """
+ from biofit.visualization.barplot import BarPlotter
+
+ plotter = BarPlotter()
+
+ return plotter.plot(
+ x,
+ y,
+ group,
+ label_name=label_name,
+ value_name=value_name,
+ groupby=groupby,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ col_set=col_set,
+ col_labels=col_labels,
+ col_outline=col_outline,
+ cols=cols,
+ prop=prop,
+ add_count_lab=add_count_lab,
+ vars_as_entered=vars_as_entered,
+ legend_position=legend_position,
+ font_size=font_size,
+ path=output_dir,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ )
+
+
+def generate_comparison_histogram(
+ x1,
+ x2=None,
+ column1: str = None,
+ column2: str = None,
+ xlab: Optional[str] = Unset("None"),
+ ylab: str = Unset("None"),
+ title: str = Unset("None"),
+ bins: int = Unset("None"),
+ alpha: float = Unset("None"),
+ legend_title: str = Unset("None"),
+ legend_position: str = Unset("None"),
+ subplot_title1: str = Unset("None"),
+ subplot_title2: str = Unset("None"),
+ col_set: str = Unset("None"),
+ cols: Optional[List[str]] = Unset("None"),
+ xlog: Optional[bool] = Unset("None"),
+ ylog: Optional[bool] = Unset("None"),
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ output_dir=None,
+):
+ """Generates a comparison histogram of the input data.
+
+ Args:
+ x: Input data.
+ other_x: Other input data.
+ y: Target data.
+ other_x: Other target data.
+ **kwargs: Additional arguments.
+ Returns:
+ Plotter object.
+ """
+ from biofit.visualization.histogram import ComparisonHistogramPlotter
+
+ plotter = ComparisonHistogramPlotter()
+
+ return plotter.plot(
+ x1,
+ x2=x2,
+ column1=column1,
+ column2=column2,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ bins=bins,
+ alpha=alpha,
+ legend_title=legend_title,
+ legend_position=legend_position,
+ col_set=col_set,
+ cols=cols,
+ subplot_title1=subplot_title1,
+ subplot_title2=subplot_title2,
+ xlog=xlog,
+ ylog=ylog,
+ path=output_dir,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ )
+
+
+def plot_correlation(
+ X,
+ y=None,
+ group=None,
+ input_columns=None,
+ target_column=None,
+ groupby: Optional[str] = None,
+ precomputed=False,
+ method="auto",
+ label_name: str = Unset('"None"'),
+ value_name: Optional[str] = Unset("None"),
+ top_k: int = Unset("30"),
+ xlab: Optional[str] = Unset("None"),
+ ylab: Optional[str] = Unset("None"),
+ title: str = Unset('"Bar Plot"'),
+ col_set: str = Unset('"Set1"'),
+ cols: Optional[List[str]] = Unset("None"),
+ prop: bool = Unset("False"),
+ add_count_lab: bool = Unset("True"),
+ vars_as_entered: bool = Unset("False"),
+ legend_position: str = Unset('"top"'),
+ font_size: float = Unset("3.25"),
+ output_dir=None,
+ file_name=None,
+):
+ """Plot the correlation matrix of the input data.
+
+ Args:
+ X: Input data
+ y: Target variable
+ input_columns: Columns to filter
+ target_column: Target column
+ top_k: Number of top features to plot
+ label_name: Name of the labels
+ value_name: Name of the values
+ groupby: Groupby column
+ xlab: X-axis label
+ ylab: Y-axis label
+ title: Plot title
+ col_set: Color set
+ cols: Columns to plot
+ prop: Whether to plot proportions
+ add_count_lab: Whether to add count labels
+ vars_as_entered: Whether the variables are entered as is
+ legend_position: Position of the legend
+ font_size: Font size
+ Returns:
+ SampleFiltered data.
+ """
+ if not precomputed:
+ from biofit.stat import CorrelationStat
+
+ corr_stat = CorrelationStat(method=method)
+ corrs = corr_stat.fit_transform(
+ X, y, input_columns, target_column, output_format="pandas"
+ )
+ name_map = {
+ "pearsonr": "Pearson Correlation",
+ "spearmanr": "Spearman Correlation",
+ "kendalltau": "Kendall Correlation",
+ "pointbiserialr": "Point Biserial Correlation",
+ }
+ value_name = name_map.get(corr_stat.config.method, "Correlation")
+ if isinstance(groupby, (Unset, NoneType)):
+ groupby = "Features"
+ corrs = corrs.melt(var_name=groupby, value_name=value_name)
+ # drop na
+ corrs = corrs.dropna()
+ sorted_inds = np.argsort(np.abs(corrs[value_name]))[::-1]
+ if isinstance(top_k, (Unset, NoneType)):
+ top_k = 30
+ corrs = corrs.iloc[sorted_inds[:top_k]]
+ return generate_barplot(
+ corrs,
+ None,
+ None,
+ label_name=label_name,
+ value_name=value_name,
+ groupby=groupby,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ col_set=col_set,
+ cols=cols,
+ prop=prop,
+ add_count_lab=add_count_lab,
+ vars_as_entered=vars_as_entered,
+ legend_position=legend_position,
+ font_size=font_size,
+ output_dir=output_dir,
+ )
+
+ return generate_barplot(
+ X,
+ y,
+ None,
+ label_name=label_name,
+ value_name=value_name,
+ groupby=groupby,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ col_set=col_set,
+ cols=cols,
+ prop=prop,
+ add_count_lab=add_count_lab,
+ vars_as_entered=vars_as_entered,
+ legend_position=legend_position,
+ font_size=font_size,
+ output_dir=output_dir,
+ )
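+
+# Hedged example (hypothetical columns): plot the 20 strongest feature-target
+# correlations as a bar plot:
+#
+#   plot_correlation(X, y, target_column="label", top_k=20, output_dir="plots")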
+
+
+def plot_feature_distribution(
+ X,
+ input_columns: SelectedColumnTypes = None,
+ value_name=None,
+ aggregate=None,
+ aggregate_kwargs={},
+ xlab: str = Unset('"X"'),
+ ylab: str = Unset('"Frequency"'),
+ title: str = Unset('"Histogram"'),
+ bins: int = Unset("30"),
+ font_size: int = Unset("8"),
+ col_fill: str = Unset('"grey40"'),
+ col_outline: str = Unset('"white"'),
+ xlog: Optional[str] = Unset("None"),
+ ylog: Optional[str] = Unset("None"),
+ output_dir: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+):
+ """Plot the feature distribution of the input data.
+
+ Args:
+ X: Input data
+ columns: Columns to plot
+ aggregate: Aggregate each feature by: 'mean', 'median', 'sum', 'std', 'var', 'min', 'max', or 'presence'
+ **kwargs: Additional keyword arguments
+
+ Returns:
+ Plot object.
+ """
+    x_dims = DataHandler.get_shape(X)
+    precomputed = aggregate is None and (
+        value_name is not None or len(x_dims) == 1 or x_dims[1] == 1
+    )
+
+    if aggregate == "presence":
+        from biofit.stat import ColumnMissingnessStat
+
+        value_name = value_name or "Presence"
+        # presence per feature = number of samples minus the missing entries
+        num_rows = x_dims[0]
+        missingness = ColumnMissingnessStat(
+            input_columns=input_columns, **aggregate_kwargs
+        ).fit_transform(X, output_format="pandas")
+        data = num_rows - missingness
+    elif aggregate == "sum":
+        from biofit.stat import ColumnSumStat
+
+        value_name = value_name or "Sum"
+        data = ColumnSumStat(
+            input_columns=input_columns, **aggregate_kwargs
+        ).fit_transform(X, output_format="pandas")
+    elif aggregate == "mean":
+        from biofit.stat import ColumnMeanStat
+
+        value_name = value_name or "Mean"
+        data = ColumnMeanStat(
+            input_columns=input_columns, **aggregate_kwargs
+        ).fit_transform(X, output_format="pandas")
+    else:
+        # no aggregation requested: fall back to the raw data
+        if (
+            input_columns is None
+            and "biosets" in sys.modules
+            and isinstance(X, getattr(sys.modules["biosets"], "Bioset"))
+        ):
+            from biosets import get_data
+
+            data = get_data(X)
+        else:
+            data = X
+
+    value_name = value_name or "Value"
+
+    data = DataHandler.to_pandas(data)
+
+    if aggregate is None and input_columns:
+        data = data[input_columns]
+
+ if not precomputed:
+ data = data.melt(value_name=value_name)
+
+ return generate_histogram(
+ data,
+ input_columns=input_columns,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ bins=bins,
+ font_size=font_size,
+ col_fill=col_fill,
+ col_outline=col_outline,
+ xlog=xlog,
+ ylog=ylog,
+ output_dir=output_dir,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ )
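+
+# Hedged example: histogram of per-feature means (hypothetical data):
+#
+#   plot_feature_distribution(X, aggregate="mean", bins=50, output_dir="plots")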
+
+
+def compare_feature_distributions(
+ X1,
+ X2,
+ columns1: SelectedColumnTypes = None,
+ columns2: SelectedColumnTypes = None,
+ value_name=None,
+ aggregate=None,
+ aggregate_kwargs={},
+ xlab: Optional[str] = Unset("None"),
+ ylab: str = Unset("None"),
+ title: str = Unset("None"),
+ bins: int = Unset("None"),
+ alpha: float = Unset("None"),
+ legend_title: str = Unset("None"),
+ legend_position: str = Unset("None"),
+ subplot_title1: str = Unset("None"),
+ subplot_title2: str = Unset("None"),
+ col_set: str = Unset("None"),
+ cols: Optional[List[str]] = Unset("None"),
+ xlog: Optional[bool] = Unset("None"),
+ ylog: Optional[bool] = Unset("None"),
+ output_dir: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+):
+ """Compare the feature distribution of the input data.
+
+ Args:
+ X1: Input data 1
+ X2: Input data 2
+ columns1: Columns to plot for data 1
+ columns2: Columns to plot for data 2. If not provided, columns1 will be used for data 2.
+        aggregate: Aggregate each feature by 'mean', 'sum', or 'presence'
+ **kwargs: Additional keyword arguments
+ Returns:
+ Plot object.
+ """
+ x1_dims = DataHandler.get_shape(X1)
+ x2_dims = DataHandler.get_shape(X2)
+ precomputed = (
+ aggregate is None
+ and (value_name is not None or len(x1_dims) == 1 or x1_dims[1] == 1)
+ and (value_name is not None or len(x2_dims) == 1 or x2_dims[1] == 1)
+ )
+
+ if columns1 and not columns2:
+ columns2 = columns1
+
+ if aggregate == "presence":
+ from biofit.stat import ColumnMissingnessStat
+
+ value_name = value_name or "Presence"
+ num_rows1 = DataHandler.get_shape(X1)[0]
+ num_rows2 = DataHandler.get_shape(X2)[0]
+
+ missingness1 = ColumnMissingnessStat(
+ input_columns=columns1, **aggregate_kwargs
+ ).fit_transform(X1, output_format="pandas")
+ missingness2 = ColumnMissingnessStat(
+ input_columns=columns2, **aggregate_kwargs
+ ).fit_transform(X2, output_format="pandas")
+
+ data1 = num_rows1 - missingness1
+ data2 = num_rows2 - missingness2
+ elif aggregate == "sum":
+ from biofit.stat import ColumnSumStat
+
+ value_name = value_name or "Sum"
+ data1 = ColumnSumStat(input_columns=columns1, **aggregate_kwargs).fit_transform(
+ X1, output_format="pandas"
+ )
+ data2 = ColumnSumStat(input_columns=columns2, **aggregate_kwargs).fit_transform(
+ X2, output_format="pandas"
+ )
+ elif aggregate == "mean":
+ from biofit.stat import ColumnMeanStat
+
+ value_name = value_name or "Mean"
+ data1 = ColumnMeanStat(
+ input_columns=columns1, **aggregate_kwargs
+ ).fit_transform(X1, output_format="pandas")
+ data2 = ColumnMeanStat(
+ input_columns=columns2, **aggregate_kwargs
+ ).fit_transform(X2, output_format="pandas")
+ else:
+ data1 = X1
+ data2 = X2
+
+ value_name = value_name or "Value"
+
+ data1 = DataHandler.to_pandas(data1)
+ data2 = DataHandler.to_pandas(data2)
+
+ if columns1:
+ data1 = data1[columns1]
+
+ if columns2:
+ data2 = data2[columns2]
+
+ if not precomputed:
+ if data1.shape[1] > 2:
+ data1 = data1.melt(value_name=value_name)
+
+ if data2.shape[1] > 2:
+ data2 = data2.melt(value_name=value_name)
+
+ return generate_comparison_histogram(
+ data1,
+ data2,
+ column1=value_name,
+ column2=value_name,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ bins=bins,
+ alpha=alpha,
+ legend_title=legend_title,
+ legend_position=legend_position,
+ col_set=col_set,
+ cols=cols,
+ subplot_title1=subplot_title1,
+ subplot_title2=subplot_title2,
+ xlog=xlog,
+ ylog=ylog,
+ output_dir=output_dir,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ )
+
+
+def plot_sample_distribution(
+ X,
+ input_columns: SelectedColumnTypes = None,
+ value_name=None,
+ aggregate=None,
+ aggregate_kwargs={},
+ xlab: str = Unset('"X"'),
+ ylab: str = Unset('"Frequency"'),
+ title: str = Unset('"Histogram"'),
+ bins: int = Unset("30"),
+ font_size: int = Unset("8"),
+ col_fill: str = Unset('"grey40"'),
+ col_outline: str = Unset('"white"'),
+ xlog: Optional[str] = Unset("None"),
+ ylog: Optional[str] = Unset("None"),
+ output_dir: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+):
+ """Plot the feature distribution of the input data.
+
+ Args:
+ X: Input data
+ columns: Columns to plot
+ aggregate: Aggregate each feature by: 'mean', 'median', 'sum', 'std', 'var', 'min', 'max', or 'presence'
+ **kwargs: Additional keyword arguments
+
+ Returns:
+ Plot object.
+ """
+ x_dims = DataHandler.get_shape(X)
+ precomputed = aggregate is None and (
+ value_name is not None or len(x_dims) == 1 or x_dims[1] == 1
+ )
+
+ if aggregate == "presence":
+ from biofit.stat import RowMissingnessStat
+
+ value_name = value_name or "Presence"
+ num_rows = DataHandler.get_shape(X)[0]
+ missingness = RowMissingnessStat(
+ input_columns=input_columns, **aggregate_kwargs
+ ).fit_transform(X, output_format="pandas")
+ data = num_rows - missingness
+ elif aggregate == "sum":
+ from biofit.stat import RowSumStat
+
+ value_name = value_name or "Sum"
+ data = RowSumStat(
+ input_columns=input_columns, **aggregate_kwargs
+ ).fit_transform(X, output_format="pandas")
+ elif aggregate == "mean":
+ from biofit.stat import RowMeanStat
+
+ value_name = value_name or "Mean"
+ data = RowMeanStat(
+ input_columns=input_columns, **aggregate_kwargs
+ ).fit_transform(X, output_format="pandas")
+
+ value_name = value_name or "Value"
+
+ if (
+ input_columns is None
+ and "biosets" in sys.modules
+ and isinstance(X, getattr(sys.modules["biosets"], "Bioset"))
+ ):
+ from biosets import get_data
+
+ data = get_data(X)
+ else:
+ data = X
+
+ data = DataHandler.to_pandas(data)
+
+ if input_columns:
+ data = data[input_columns]
+
+ if not precomputed:
+ if data.shape[1] > 2:
+ data = data.melt(value_name=value_name)
+
+ return generate_histogram(
+ data,
+ input_columns=value_name,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ bins=bins,
+ font_size=font_size,
+ col_fill=col_fill,
+ col_outline=col_outline,
+ xlog=xlog,
+ ylog=ylog,
+ output_dir=output_dir,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ )
+
+
+def compare_sample_distributions(
+ X1,
+ X2,
+ columns1: SelectedColumnTypes = None,
+ columns2: SelectedColumnTypes = None,
+ value_name=None,
+ aggregate=None,
+ aggregate_kwargs={},
+ xlab: Optional[str] = Unset("None"),
+ ylab: str = Unset("None"),
+ title: str = Unset("None"),
+ bins: int = Unset("None"),
+ alpha: float = Unset("None"),
+ legend_title: str = Unset("None"),
+ legend_position: str = Unset("None"),
+ subplot_title1: str = Unset("None"),
+ subplot_title2: str = Unset("None"),
+ col_set: str = Unset("None"),
+ cols: Optional[List[str]] = Unset("None"),
+ xlog: Optional[bool] = Unset("None"),
+ ylog: Optional[bool] = Unset("None"),
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ output_dir=None,
+):
+ """Compare the feature distribution of the input data.
+
+ Args:
+ X1: Input data 1
+ X2: Input data 2
+ columns1: Columns to plot for data 1
+ columns2: Columns to plot for data 2. If not provided, columns1 will be used for data 2.
+ aggregate: Aggregate each feature by: 'mean', 'median', 'sum', 'std', 'var', 'min', 'max', or 'presence'
+ **kwargs: Additional keyword arguments
+
+ Returns:
+ Plot object.
+ """
+ x1_dims = DataHandler.get_shape(X1)
+ x2_dims = DataHandler.get_shape(X2)
+ precomputed = (
+ aggregate is None
+ and (value_name is not None or len(x1_dims) == 1 or x1_dims[1] == 1)
+ and (value_name is not None or len(x2_dims) == 1 or x2_dims[1] == 1)
+ )
+
+ if columns1 and not columns2:
+ columns2 = columns1
+
+ if aggregate == "presence":
+ from biofit.stat import RowMissingnessStat
+
+ value_name = value_name or "Presence"
+ num_cols1 = x1_dims[1] if len(x1_dims) == 2 else 1
+ num_cols2 = x2_dims[1] if len(x2_dims) == 2 else 1
+
+ missingness1 = RowMissingnessStat(
+ input_columns=columns1, keep_unused_columns=False, **aggregate_kwargs
+ ).fit_transform(X1, output_format="pandas")
+ missingness2 = RowMissingnessStat(
+ input_columns=columns2, keep_unused_columns=False, **aggregate_kwargs
+ ).fit_transform(X2, output_format="pandas")
+
+ data1 = num_cols1 - missingness1
+ data2 = num_cols2 - missingness2
+ data1.columns = [value_name]
+ data2.columns = [value_name]
+ elif aggregate == "sum":
+ from biofit.stat import RowSumStat
+
+ value_name = value_name or "Sum"
+ data1 = RowSumStat(
+ input_columns=columns1, keep_unused_columns=False, **aggregate_kwargs
+ ).fit_transform(X1, output_format="pandas")
+ data2 = RowSumStat(
+ input_columns=columns2, keep_unused_columns=False, **aggregate_kwargs
+ ).fit_transform(X2, output_format="pandas")
+ data1.columns = [value_name]
+ data2.columns = [value_name]
+ elif aggregate == "mean":
+ from biofit.stat import RowMeanStat
+
+ value_name = value_name or "Mean"
+ data1 = RowMeanStat(
+ input_columns=columns1, keep_unused_columns=False, **aggregate_kwargs
+ ).fit_transform(X1, output_format="pandas")
+ data2 = RowMeanStat(
+ input_columns=columns2, keep_unused_columns=False, **aggregate_kwargs
+ ).fit_transform(X2, output_format="pandas")
+ data1.columns = [value_name]
+ data2.columns = [value_name]
+
+ value_name = value_name or "Value"
+
+ data1 = DataHandler.to_pandas(data1)
+ data2 = DataHandler.to_pandas(data2)
+
+    if aggregate is None:
+        # column selection only applies to unaggregated data; the stats above
+        # already reduce each dataset to a single named column
+        if columns1:
+            data1 = data1[columns1]
+
+        if columns2:
+            data2 = data2[columns2]
+
+ if not precomputed:
+ if data1.shape[1] > 1:
+ data1 = data1.melt(value_name=value_name)
+ else:
+ data1.columns = [value_name]
+
+ if data2.shape[1] > 1:
+ data2 = data2.melt(value_name=value_name)
+ else:
+            data2.columns = [value_name]
+
+ if value_name not in data1.columns:
+ raise ValueError(
+ f"Column '{value_name}' not found in data1. Found columns: {data1.columns}"
+ )
+ if value_name not in data2.columns:
+ raise ValueError(
+ f"Column '{value_name}' not found in data2. Found columns: {data2.columns}"
+ )
+
+ return generate_comparison_histogram(
+ data1,
+ data2,
+ column1=value_name,
+ column2=value_name,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ bins=bins,
+ alpha=alpha,
+ legend_title=legend_title,
+ legend_position=legend_position,
+ col_set=col_set,
+ cols=cols,
+ subplot_title1=subplot_title1,
+ subplot_title2=subplot_title2,
+ xlog=xlog,
+ ylog=ylog,
+ output_dir=output_dir,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ )
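+
+# Hedged example (hypothetical splits): compare per-sample means between a
+# train and a test set:
+#
+#   compare_sample_distributions(X_train, X_test, aggregate="mean",
+#                                subplot_title1="Train", subplot_title2="Test")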
+
+
+def plot_dimension_reduction(
+ X,
+ labels=None,
+ group=None,
+ input_columns: SelectedColumnTypes = None,
+ label_column: SelectedColumnTypes = None,
+ group_column: SelectedColumnTypes = None,
+ method=None,
+ method_kwargs={},
+ title: str = Unset('"Dimension Reduction Plot"'),
+ colormap: str = Unset('"nipy_spectral"'),
+ n_components: int = Unset("3"),
+ dimension_reducer=None,
+ output_dir: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show=True,
+):
+ """Plot the dimension reduction plot.
+
+ Args:
+ X: Input data
+ labels: Labels for the data
+ group: Group for the data
+ dimension_reducer: Dimension reducer object
+ **kwargs: Additional keyword arguments
+
+ Returns:
+ Plot object.
+ """
+ requires_backends(plot_sample_metadata, ["matplotlib", "seaborn"])
+ if labels is None and label_column is None:
+ if "biosets" in sys.modules and isinstance(
+ X, getattr(sys.modules["biosets"], "Bioset")
+ ):
+ from biosets import get_target
+
+ labels = get_target(X)
+ elif label_column:
+ labels = DataHandler.select_columns(X, label_column)
+ if "biosets" in sys.modules and isinstance(
+ X, getattr(sys.modules["biosets"], "Bioset")
+ ):
+ from biosets import decode
+
+ labels = decode(labels)
+
+ if input_columns is not None:
+ X = DataHandler.select_columns(X, input_columns)
+ input_columns = None
+ if method is None:
+ from biofit.preprocessing import PCAFeatureExtractor
+
+ method = PCAFeatureExtractor
+ if isinstance(method, type):
+ method = method()
+ if isinstance(method, str):
+ from biofit.auto.processing_auto import AutoPreprocessor
+
+ if method_kwargs is None or not isinstance(method_kwargs, dict):
+ method_kwargs = {}
+
+ method = AutoPreprocessor.for_processor(method, **method_kwargs)
+ if not method.is_fitted:
+ data = method.fit_transform(X, load_from_cache_file=False)
+ else:
+ data = method.transform(X)
+
+ from biofit.visualization.dimension_reduction import DimensionReductionPlotter
+
+ return DimensionReductionPlotter().plot(
+ data,
+ labels,
+ group,
+ input_columns=input_columns,
+ label_column=label_column,
+ group_column=group_column,
+ n_components=n_components,
+ dimension_reducer=method,
+ show=show,
+ path=output_dir,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ )
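+
+# Hedged example (hypothetical dataset): a PCA projection colored by labels:
+#
+#   plot_dimension_reduction(X, label_column="label", n_components=2,
+#                            output_dir="plots")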
+
+
+def get_feature_importances(models, label_names=None):
+ if not isinstance(models, list):
+ models = [models]
+ tables = []
+ for fold, model in enumerate(models):
+ if isinstance(model, Pipeline):
+ _m = model[-1]
+ else:
+ _m = model
+ features = _m.config.feature_names_in_
+ feature_importances = _m.feature_importances_
+ classes = label_names
+ if classes is None:
+ if hasattr(_m, "config") and hasattr(_m.config, "class_names"):
+ classes = _m.config.class_names
+ if classes is None:
+ classes = getattr(_m, "classes_", None)
+
+ if classes is not None and (
+ (len(feature_importances) // len(classes)) == len(features)
+ ):
+ # this means its a one vs all classifier
+ _tables = []
+ for i, c in enumerate(classes):
+ _tables.append(
+ pd.Series(
+ feature_importances[
+ i * len(features) : (i + 1) * len(features)
+ ],
+ index=features,
+ name=f"importances_{fold + 1}_{c}_vs_all",
+ )
+ )
+ tables.extend(_tables)
+ else:
+ tables.append(
+ pd.DataFrame(
+ {
+ f"importances_{fold + 1}": feature_importances,
+ },
+ index=features,
+ )
+ )
+
+ if len(tables) == 1:
+ return tables[0]
+
+ # Concatenate tables horizontally
+ return pd.concat(tables, axis=1, ignore_index=False, copy=False)
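+
+# Shape note: for a one-vs-all classifier over classes ["a", "b"], this yields
+# one "importances_<fold>_<class>_vs_all" column per class, indexed by feature
+# name; otherwise one "importances_<fold>" column per model.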
+
+
+def plot_feature_importance(
+ feature_importances,
+ models=None,
+ X=None,
+ y=None,
+ sample_metadata=None,
+ feature_metadata: dict = None,
+ input_columns: SelectedColumnTypes = None,
+ target_columns: SelectedColumnTypes = None,
+ sample_metadata_columns: SelectedColumnTypes = None,
+ sample_column: str = None,
+ label_names: List[str] = None,
+ plot_top: int = Unset("15"),
+ feature_meta_name: str = Unset("None"),
+ feature_column: str = Unset("None"),
+ cols: List[str] = Unset("None"),
+ colHeat: List[str] = Unset("None"),
+    dat_log: str = Unset("None"),
+ show_column_names: bool = Unset("False"),
+ scale_legend_title: str = Unset('"Value"'),
+ column_title: str = Unset('"Samples"'),
+ row_title: str = Unset('"Features"'),
+ plot_title: str = Unset('"Values"'),
+ output_dir: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show=True,
+):
+ """Plot the feature importance of the input data.
+
+ Args:
+ X: Input data
+ models: List of models
+ y: Target variable
+ sample_metadata: Sample metadata
+ feature_metadata: Feature metadata
+ Returns:
+ Plot object.
+ """
+
+ from biofit.visualization.feature_importance import FeatureImportancePlotter
+
+ if feature_importances is None and models is not None:
+ feature_importances = get_feature_importances(models, label_names)
+ feature_importances = feature_importances.reset_index(names=["features"])
+ elif feature_importances is None:
+ raise ValueError("feature_importances or models must be provided")
+
+    return FeatureImportancePlotter().plot(
+ X,
+ y=y,
+ sample_metadata=sample_metadata,
+ input_columns=input_columns,
+ target_columns=target_columns,
+ sample_metadata_columns=sample_metadata_columns,
+ feature_importances=feature_importances,
+ feature_metadata=feature_metadata,
+ plot_top=plot_top,
+ feature_meta_name=feature_meta_name,
+ sample_column=sample_column,
+ feature_column=feature_column,
+ cols=cols,
+ colHeat=colHeat,
+ dat_log=dat_log,
+ show_column_names=show_column_names,
+ scale_legend_title=scale_legend_title,
+ column_title=column_title,
+ row_title=row_title,
+ plot_title=plot_title,
+ show=show,
+ path=output_dir,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ )
+
+
+def plot_sample_metadata(
+ sample_metadata,
+ outcome_data=None,
+ sample_metadata_columns: Optional[SelectedColumnTypes] = None,
+ outcome_column: Optional[SelectedColumnTypes] = None,
+ output_dir: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+):
+ requires_backends(plot_sample_metadata, ["rpy2"])
+ from biofit.visualization.sample_metadata import SampleMetadataPlotter
+
+    return SampleMetadataPlotter().plot(
+ sample_metadata,
+ outcome_data,
+ sample_metadata_columns=sample_metadata_columns,
+ outcome_column=outcome_column,
+ path=output_dir,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ )
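+
+# Hedged example (hypothetical columns; requires rpy2): plot each metadata
+# column against a binary outcome:
+#
+#   plot_sample_metadata(metadata, outcome_column="disease_state",
+#                        output_dir="plots")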
diff --git a/src/biofit/visualization/report_generation.py b/src/biofit/visualization/report_generation.py
new file mode 100644
index 0000000..8d9be83
--- /dev/null
+++ b/src/biofit/visualization/report_generation.py
@@ -0,0 +1,35 @@
+import os
+import sys
+from typing import TYPE_CHECKING, Optional
+
+import joblib
+from biocore.utils.import_util import is_optuna_available
+
+if TYPE_CHECKING:
+ import optuna
+
+
+def report_generation(data, study: Optional["optuna.Study"] = None, cache_dir=None):
+ if cache_dir is None and (
+ (
+ "biosets" in sys.modules
+ and isinstance(data, getattr(sys.modules["biosets"], "Bioset"))
+ )
+ or (
+ "datasets" in sys.modules
+ and isinstance(data, getattr(sys.modules["datasets"], "Dataset"))
+ )
+ ):
+ cache_dir = os.path.dirname(data.cache_files[0]["filename"])
+
+    if study is None and cache_dir:
+        with open(os.path.join(cache_dir, "study.joblib"), "rb") as f:
+            study = joblib.load(f)
+
+    if study is None:
+        raise ValueError("Study object is required for optuna report generation")
+
+    if is_optuna_available():
+        import optuna.visualization as ov
+
+        return ov.plot_optimization_history(study)
diff --git a/src/biofit/visualization/sample_metadata.py b/src/biofit/visualization/sample_metadata.py
new file mode 100644
index 0000000..1794958
--- /dev/null
+++ b/src/biofit/visualization/sample_metadata.py
@@ -0,0 +1,128 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional, Type
+
+from biocore import DataHandler
+
+from biofit.integration import RCaller
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.visualization.plotting import BasePlotter, PlotterConfig
+
+if TYPE_CHECKING:
+ from datasets import Dataset
+
+
+@dataclass
+class SampleMetadataPlotterConfig(PlotterConfig):
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("METADATA_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [
+ get_feature("METADATA_FEATURE_TYPES"),
+ get_feature("TARGET_FEATURE_TYPES"),
+ ],
+ init=False,
+ repr=False,
+ )
+ r_source: str = field(
+ default=(Path(__file__).parent / "plot_sample_metadata.R").as_posix(),
+ init=False,
+ repr=False,
+ )
+ main_method: str = field(default="plot_sample_metadata", init=False, repr=False)
+ sample_metadata_columns: Optional[SelectedColumnTypes] = None
+ outcome_column: Optional[SelectedColumnTypes] = None
+
+
+class SampleMetadataPlotter(BasePlotter):
+ _config_class = SampleMetadataPlotterConfig
+ config: SampleMetadataPlotterConfig
+
+ def __init__(
+ self,
+ config: Optional[SampleMetadataPlotterConfig] = None,
+ sample_metadata_columns: Optional[SelectedColumnTypes] = None,
+ outcome_column: Optional[SelectedColumnTypes] = None,
+ install_missing: bool = None,
+ **kwargs,
+ ):
+ super().__init__(
+ config=config,
+ sample_metadata_columns=sample_metadata_columns,
+ outcome_column=outcome_column,
+ install_missing=install_missing,
+ **kwargs,
+ )
+ r_source = (Path(__file__).parent / "plot_sample_metadata.R").as_posix()
+ r_caller = RCaller.from_script(r_source)
+ self.plotter = r_caller.get_method(self.config.main_method)
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ r_source = (Path(__file__).parent / "plot_sample_metadata.R").as_posix()
+ r_caller = RCaller.from_script(r_source)
+ self.plotter = r_caller.get_method(self.config.main_method)
+ return self
+
+ def plot_dataset(self, X: "Dataset", y: "Dataset" = None):
+ from biosets import decode
+
+ if y is not None and isinstance(
+ next(iter(y._info.features.values())), get_feature("ClassLabel")
+ ):
+ y = decode(y)
+ current_name = next(iter(y._info.features.keys()))
+ original_name = next(iter(y._info.features.values())).id or current_name
+ if current_name != original_name:
+ y = DataHandler.rename_column(y, current_name, original_name)
+ if original_name in X.column_names:
+ X = DataHandler.drop_column(X, original_name)
+ self.plot_arrow(
+ DataHandler.to_arrow(X), DataHandler.to_arrow(y) if y is not None else None
+ )
+
+ def plot_arrow(self, X, y):
+ if self.config.outcome_column is None and y is None:
+ raise ValueError(
+ "No outcome column provided. Please provide the outcome data "
+ "or specify the outcome column."
+ )
+
+ if self.config.outcome_column is None:
+ self.config.outcome_column = list(y.column_names)[0]
+ self.plotter(X, outcome=y, **self.config.get_params())
+
+ def plot(
+ self,
+ sample_metadata,
+ outcome_data: SelectedColumnTypes = None,
+ sample_metadata_columns: Optional[SelectedColumnTypes] = None,
+ outcome_column: Optional[SelectedColumnTypes] = None,
+ path: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show: bool = True,
+ ):
+ self.config._input_columns = self._set_input_columns_and_arity(
+ sample_metadata_columns, outcome_column
+ )
+ return self._plot(
+ sample_metadata,
+ outcome_data,
+ path=path,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ show=show,
+ )
diff --git a/src/biofit/visualization/scatterplot.py b/src/biofit/visualization/scatterplot.py
new file mode 100644
index 0000000..e874e32
--- /dev/null
+++ b/src/biofit/visualization/scatterplot.py
@@ -0,0 +1,188 @@
+import textwrap
+from dataclasses import dataclass, field
+from typing import List, Optional, Type
+
+import biofit.config as config
+from biofit.integration.R import RCaller
+from biofit.integration.biosets import get_feature
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.utils.types import Unset
+from biofit.visualization.plotting import BasePlotter, PlotterConfig
+
+
+@dataclass
+class ScatterPlotConfig(PlotterConfig):
+ processor_type: str = field(default="scaling", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES"), None],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES"), None],
+ init=False,
+ repr=False,
+ )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None, None],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None, None],
+ init=False,
+ repr=False,
+ )
+ r_source: str = field(
+ default=(config.R_SCRIPTS / "plotting_utils.R").as_posix(),
+ init=False,
+ repr=False,
+ )
+ main_method: str = field(default="generate_scatterplot", init=False, repr=False)
+
+ groupby: str = None
+ xdata: str = None
+ ydata: str = None
+ xlab: str = None
+ ylab: str = None
+ title: str = "Scatterplot"
+    alpha: float = 1
+ col_set: str = "Set1"
+ cols: List[str] = None
+ xlog: str = None
+ ylog: str = None
+
+
+class ScatterPlotter(BasePlotter):
+ _config_class = ScatterPlotConfig
+ config: ScatterPlotConfig
+
+ def __init__(
+ self,
+ groupby: str = None,
+ xdata: str = None,
+ ydata: str = None,
+ xlab: str = Unset("None"),
+ ylab: str = Unset("None"),
+ title: str = Unset('"Scatterplot"'),
+ alpha: str = Unset("1"),
+ col_set: str = Unset('"Set1"'),
+ cols: List[str] = Unset("None"),
+ xlog: str = Unset("None"),
+ ylog: str = Unset("None"),
+ install_missing: bool = None,
+ config: Optional[ScatterPlotConfig] = None,
+ ):
+ super().__init__(
+            config=config,
+            groupby=groupby,
+            xdata=xdata,
+            ydata=ydata,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ alpha=alpha,
+ col_set=col_set,
+ cols=cols,
+ xlog=xlog,
+ ylog=ylog,
+ install_missing=install_missing,
+ )
+ self.plotter = None
+ if self.config.r_source and self.config.main_method:
+ self.r_caller = RCaller.from_script(self.config.r_source)
+ enter_code = textwrap.dedent(
+ """
+ suppressPackageStartupMessages(require(ggplot2))
+ """
+ )
+ exit_code = textwrap.dedent(
+ """
+ ggplot2::ggsave(path, results)
+ """
+ )
+ self.plotter = self.r_caller.get_method(
+ self.config.main_method, enter_code, exit_code
+ )
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ if self.config.r_source and self.config.main_method:
+ self.r_caller = RCaller.from_script(self.config.r_source)
+ enter_code = textwrap.dedent(
+ """
+ suppressPackageStartupMessages(require(ggplot2))
+ """
+ )
+ exit_code = textwrap.dedent(
+ """
+ ggplot2::ggsave(path, results)
+ """
+ )
+ self.plotter = self.r_caller.get_method(
+ self.config.main_method, enter_code, exit_code
+ )
+
+ return self
+
+ def plot(
+ self,
+ x,
+ y=None,
+ group=None,
+ xdata: SelectedColumnTypes = None,
+ ydata: SelectedColumnTypes = None,
+ groupby: SelectedColumnTypes = None,
+ xlab: str = Unset("None"),
+ ylab: str = Unset("None"),
+ title: str = Unset('"Scatterplot"'),
+ alpha: str = Unset("1"),
+ col_set: str = Unset('"Set1"'),
+ cols: List[str] = Unset("None"),
+ xlog: str = Unset("None"),
+ ylog: str = Unset("None"),
+ path: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show: bool = True,
+ ):
+ self.config._input_columns = self._set_input_columns_and_arity(
+ xdata, ydata, groupby
+ )
+ return self._plot(
+ x,
+ y,
+ group,
+ xlab=xlab,
+ ylab=ylab,
+ title=title,
+ alpha=alpha,
+ col_set=col_set,
+ cols=cols,
+ xlog=xlog,
+ ylog=ylog,
+ path=path,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ show=show,
+ )
+
+ def plot_dataset(self, x, y=None, group=None):
+ from biosets import Bioset, decode
+
+ if isinstance(group, Bioset):
+ group = decode(group)
+
+ return self.plot_arrow(
+ x._data.table,
+ y._data.table if y else None,
+ group._data.table if group else None,
+ )
+
+ def plot_arrow(self, x, y=None, group=None):
+ kwargs = self.config.get_params()
+ context_kwargs = {
+ "path": kwargs.pop("path", None),
+ }
+ self.plotter(x, y, group, context_kwargs=context_kwargs, **kwargs)
diff --git a/src/biofit/visualization/violin.py b/src/biofit/visualization/violin.py
new file mode 100644
index 0000000..786415f
--- /dev/null
+++ b/src/biofit/visualization/violin.py
@@ -0,0 +1,140 @@
+import textwrap
+from dataclasses import dataclass, field
+from typing import List, Type
+
+import biofit.config as config
+from biofit.integration.biosets import get_feature
+from biofit.integration.R import RCaller
+from biofit.processing import SelectedColumnTypes, sync_backup_config
+from biofit.utils.types import Unset
+from biofit.visualization.plotting import BasePlotter, PlotterConfig
+
+
+@dataclass
+class ViolinConfig(PlotterConfig):
+ processor_type: str = field(default="scaling", init=False, repr=False)
+ _fit_input_feature_types: List[Type] = field(
+ default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _transform_input_feature_types: List[Type] = field(
+ default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _fit_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None],
+ init=False,
+ repr=False,
+ )
+ _transform_unused_feature_types: List[Type] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None],
+ init=False,
+ repr=False,
+ )
+
+ r_source: str = field(
+ default=(config.R_SCRIPTS / "plotting_utils.R").as_posix(),
+ init=False,
+ repr=False,
+ )
+ main_method: str = field(default="generate_violin", init=False, repr=False)
+
+ column: str = None
+ label_name: str = "labels"
+ xlab: str = "Labels"
+ ylab: str = "Value"
+
+
+class ViolinPlotter(BasePlotter):
+ _config_class = ViolinConfig
+ config: ViolinConfig
+
+ def __init__(
+ self,
+ column: str = None,
+ label_name: str = None,
+ xlab: str = Unset('"Labels"'),
+ ylab: str = Unset('"Value"'),
+ install_missing: bool = None,
+ config=None,
+ ):
+        super().__init__(
+            config=config,
+            column=column,
+            label_name=label_name,
+            xlab=xlab,
+            ylab=ylab,
+            install_missing=install_missing,
+        )
+ self.plotter = None
+ if self.config.r_source and self.config.main_method:
+ self.r_caller = RCaller.from_script(self.config.r_source)
+ enter_code = textwrap.dedent(
+ """
+ suppressPackageStartupMessages(require(ggplot2))
+ """
+ )
+ exit_code = textwrap.dedent(
+ """
+ ggplot2::ggsave(path, results)
+ """
+ )
+ self.plotter = self.r_caller.get_method(
+ self.config.main_method, enter_code, exit_code
+ )
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ if self.config.r_source and self.config.main_method:
+ self.r_caller = RCaller.from_script(self.config.r_source)
+ enter_code = textwrap.dedent(
+ """
+ suppressPackageStartupMessages(require(ggplot2))
+ """
+ )
+ exit_code = textwrap.dedent(
+ """
+ ggplot2::ggsave(path, results)
+ """
+ )
+ self.plotter = self.r_caller.get_method(
+ self.config.main_method, enter_code, exit_code
+ )
+
+ return self
+
+ def plot(
+ self,
+ x,
+ y=None,
+ column: SelectedColumnTypes = None,
+ label_name: SelectedColumnTypes = None,
+ xlab: str = Unset('"Labels"'),
+ ylab: str = Unset('"Value"'),
+ path: str = None,
+ device: str = "pdf",
+ fingerprint: str = None,
+ unused_columns: SelectedColumnTypes = None,
+ raise_if_missing: bool = True,
+ show: bool = True,
+ ):
+ self.config._input_columns = self._set_input_columns_and_arity(
+ column, label_name
+ )
+ return self._plot(
+ x,
+ y,
+ xlab=xlab,
+ ylab=ylab,
+ path=path,
+ device=device,
+ fingerprint=fingerprint,
+ unused_columns=unused_columns,
+ raise_if_missing=raise_if_missing,
+ show=show,
+ )
+
+ def plot_arrow(self, x, y=None):
+ self.plotter(
+ x,
+ y,
+ **self.config.get_params(),
+ )
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..7b503c4
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,108 @@
+from pathlib import Path
+
+import pytest
+from biocore.utils.import_util import is_biosets_available, is_datasets_available
+from biofit.utils.logging import silence
+
+# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), './mock_packages')))
+
+pytest_plugins = ["tests.fixtures.files", "tests.fixtures.fsspec"]
+
+
+def pytest_collection_modifyitems(config, items):
+ # Mark tests as "unit" by default if not marked as "integration" (or already marked as "unit")
+ for item in items:
+ if any(marker in item.keywords for marker in ["integration", "unit"]):
+ continue
+ item.add_marker(pytest.mark.unit)
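+
+
+# A test opts out of the default "unit" marker by declaring one explicitly
+# (sketch):
+#
+#   @pytest.mark.integration
+#   def test_end_to_end(): ...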
+
+
+def pytest_configure(config):
+ config.addinivalue_line(
+ "markers", "torchaudio_latest: mark test to run with torchaudio>=0.12"
+ )
+
+
+@pytest.fixture(autouse=True)
+def set_test_cache_config(tmp_path_factory, monkeypatch):
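+    # Redirect all biofit/biosets/datasets cache locations into pytest's
+    # session temp directory so tests never touch the user's real caches.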
+ test_cache_home = tmp_path_factory.getbasetemp() / "cache"
+ test_patches_cache = test_cache_home / "patches"
+ test_datasets_cache = test_cache_home / "datasets"
+ test_processors_cache = test_cache_home / "processors"
+ test_modules_cache = test_cache_home / "modules"
+ monkeypatch.setattr("biofit.config.BIOFIT_CACHE_HOME", Path(test_cache_home))
+ monkeypatch.setattr("biofit.config.BIOFIT_PATCHES_CACHE", Path(test_patches_cache))
+ test_downloaded_datasets_path = test_datasets_cache / "downloads"
+ test_extracted_datasets_path = test_datasets_cache / "downloads" / "extracted"
+ if is_biosets_available():
+ monkeypatch.setattr(
+ "biosets.config.BIOSETS_DATASETS_CACHE", Path(test_datasets_cache)
+ )
+
+ monkeypatch.setattr(
+ "biosets.config.DOWNLOADED_BIOSETS_PATH",
+ str(test_downloaded_datasets_path),
+ )
+
+ monkeypatch.setattr(
+ "biosets.config.EXTRACTED_BIOSETS_PATH",
+ str(test_extracted_datasets_path),
+ )
+
+ if is_datasets_available():
+ monkeypatch.setattr(
+ "datasets.config.HF_DATASETS_CACHE", str(test_datasets_cache)
+ )
+ monkeypatch.setattr("datasets.config.HF_MODULES_CACHE", str(test_modules_cache))
+ monkeypatch.setattr(
+ "datasets.config.DOWNLOADED_DATASETS_PATH",
+ str(test_downloaded_datasets_path),
+ )
+ monkeypatch.setattr(
+ "datasets.config.EXTRACTED_DATASETS_PATH", str(test_extracted_datasets_path)
+ )
+
+ monkeypatch.setattr(
+ "biofit.config.BIOFIT_PROCESSORS_CACHE", Path(test_processors_cache)
+ )
+ monkeypatch.setattr("biofit.config.BIOFIT_MODULES_CACHE", Path(test_modules_cache))
+
+
+# @pytest.fixture(autouse=True, scope="session")
+# def disable_tqdm_output():
+# disable_progress_bar()
+
+
+# @pytest.fixture(autouse=True, scope="session")
+# def set_info_verbosity():
+# set_verbosity_info()
+
+
+@pytest.fixture(autouse=True, scope="session")
+def silence_output():
+ silence()
+
+
+@pytest.fixture(autouse=True)
+def set_update_download_counts_to_false(monkeypatch):
+ # don't take tests into account when counting downloads
+ if is_datasets_available():
+ monkeypatch.setattr("datasets.config.HF_UPDATE_DOWNLOAD_COUNTS", False)
+
+
+@pytest.fixture
+def set_sqlalchemy_silence_uber_warning(monkeypatch):
+    # Required to suppress RemovedIn20Warning when features are not compatible
+    # with SQLAlchemy 2.0. To be removed once SQLAlchemy 2.0 is supported.
+ try:
+ monkeypatch.setattr("sqlalchemy.util.deprecations.SILENCE_UBER_WARNING", True)
+ except AttributeError:
+ pass
+
+
+@pytest.fixture(autouse=True, scope="session")
+def zero_time_out_for_remote_code():
+ if is_datasets_available():
+ import datasets.config
+
+ datasets.config.TIME_OUT_REMOTE_CODE = 0
diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py
new file mode 100644
index 0000000..e3022a8
--- /dev/null
+++ b/tests/fixtures/files.py
@@ -0,0 +1,1426 @@
+import contextlib
+import csv
+import json
+import os
+import sqlite3
+import tarfile
+import textwrap
+import zipfile
+from decimal import Decimal
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq
+import pytest
+from biocore.utils.import_util import (
+ is_biosets_available,
+ requires_backends,
+)
+from biofit.integration.biosets import get_feature
+from biofit.utils import enable_full_determinism
+from biofit.utils.py_util import set_seed
+from sklearn.datasets import make_classification
+
+# Constants
+ALPHANUMERIC = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+SEED = 42
+enable_full_determinism(SEED)
+PA_DATA = {
+ "null": lambda num_rows: np.array([None] * num_rows),
+ "bool": lambda num_rows: np.random.choice([True, False], size=num_rows),
+ "one_hot": lambda num_rows: np.random.choice([0, 1], size=num_rows),
+ "multi_bins": lambda num_rows: np.random.choice([0, 1], size=num_rows),
+ "int8": lambda num_rows: np.random.normal(loc=0, scale=127, size=num_rows).astype(
+ np.int8
+ ),
+ "int16": lambda num_rows: np.random.normal(
+ loc=0, scale=32767, size=num_rows
+ ).astype(np.int16),
+ "int32": lambda num_rows: np.random.normal(
+ loc=0, scale=2147483647, size=num_rows
+ ).astype(np.int32),
+ "int64": lambda num_rows: np.random.normal(
+ loc=0, scale=9223372036854775807, size=num_rows
+ ).astype(np.int64),
+ "uint8": lambda num_rows: np.abs(
+ np.random.normal(loc=127.5, scale=127.5, size=num_rows)
+ ).astype(np.uint8),
+ "uint16": lambda num_rows: np.abs(
+ np.random.normal(loc=32767.5, scale=32767.5, size=num_rows)
+ ).astype(np.uint16),
+ "uint32": lambda num_rows: np.abs(
+ np.random.normal(loc=2147483647.5, scale=2147483647.5, size=num_rows)
+ ).astype(np.uint32),
+ "uint64": lambda num_rows: np.abs(
+ np.random.normal(
+ loc=9223372036854775807.5, scale=9223372036854775807.5, size=num_rows
+ )
+ ).astype(np.uint64),
+ "float16": lambda num_rows: np.random.normal(size=num_rows).astype(np.float16),
+ "float32": lambda num_rows: np.random.normal(size=num_rows).astype(np.float32),
+ "float64": lambda num_rows: np.random.normal(size=num_rows),
+ "string": lambda num_rows: np.array(
+ [
+ "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), size=5))
+ for _ in range(num_rows)
+ ]
+ ),
+ "binary": lambda num_rows: np.array(
+ [
+ bytes(
+ "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), size=5)),
+ "utf-8",
+ )
+ for _ in range(num_rows)
+ ]
+ ),
+ "large_string": lambda num_rows: np.array(
+ [
+ "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), size=5))
+ for _ in range(num_rows)
+ ]
+ ),
+ "large_binary": lambda num_rows: np.array(
+ [
+ bytes(
+ "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), size=5)),
+ "utf-8",
+ )
+ for _ in range(num_rows)
+ ]
+ ),
+ "date32": lambda num_rows: np.array(
+ [np.random.randint(1, 2147483647, size=num_rows).tolist()]
+ ),
+ "date64": lambda num_rows: np.array(
+ [np.random.randint(1, 2147483647, size=num_rows).tolist()]
+ ),
+ "time32": lambda num_rows: np.random.normal(
+ loc=0, scale=2147483647, size=num_rows
+ ).astype(np.int32),
+ "time64": lambda num_rows: np.random.normal(
+ loc=0, scale=9223372036854775807, size=num_rows
+ ).astype(np.int64),
+ "timestamp": lambda num_rows: np.array(
+ [
+ pd.Timestamp(np.random.choice(pd.date_range("2020-01-01", periods=365)))
+ for _ in range(num_rows)
+ ]
+ ),
+ "duration": lambda num_rows: np.random.normal(
+ loc=500, scale=250, size=num_rows
+ ).astype(int),
+ "decimal128": lambda num_rows: np.array(
+ [Decimal(np.random.randint(1, 1000)) for _ in range(num_rows)]
+ ),
+ "struct": lambda num_rows: [
+ {"a": np.random.randint(1, 1000)} for _ in range(num_rows)
+ ],
+}
+
+
+PA_FIELDS = {
+ "null": pa.field("null", pa.null()),
+ "bool": pa.field("bool", pa.bool_()),
+ "int8": pa.field("int8", pa.int8()),
+ "int16": pa.field("int16", pa.int16()),
+ "int32": pa.field("int32", pa.int32()),
+ "int64": pa.field("int64", pa.int64()),
+ "uint8": pa.field("uint8", pa.uint8()),
+ "uint16": pa.field("uint16", pa.uint16()),
+ "uint32": pa.field("uint32", pa.uint32()),
+ "uint64": pa.field("uint64", pa.uint64()),
+ "float16": pa.field("float16", pa.float16()),
+ "float32": pa.field("float32", pa.float32()),
+ "float64": pa.field("float64", pa.float64()),
+ "string": pa.field("string", pa.string()),
+ "binary": pa.field("binary", pa.binary()),
+ "large_string": pa.field("large_string", pa.large_string()),
+ "large_binary": pa.field("large_binary", pa.large_binary()),
+ "date32": pa.field("date32", pa.date32()),
+ "date64": pa.field("date64", pa.date64()),
+ "time32": pa.field("time32", pa.time32("s")),
+ "time64": pa.field("time64", pa.time64("ns")),
+ "timestamp": pa.field("timestamp", pa.timestamp("s")),
+ "duration": pa.field("duration", pa.duration("s")),
+ "decimal128": pa.field("decimal128", pa.decimal128(38, 9)),
+ "struct": pa.field("struct", pa.struct([pa.field("a", pa.int32())])),
+}
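+
+# PA_DATA and PA_FIELDS share keys (PA_DATA additionally provides "one_hot" and
+# "multi_bins", which have no Arrow field of their own). A sketch of how the
+# two pair up when building a table:
+#
+#   cols = ["int32", "string"]
+#   tbl = pa.table(
+#       {k: PA_DATA[k](10) for k in cols},
+#       schema=pa.schema([PA_FIELDS[k] for k in cols]),
+#   )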
+
+
+def create_directory(path):
+ """
+ Create a directory if it doesn't exist.
+ """
+ try:
+ os.makedirs(path, exist_ok=True)
+ except OSError as e:
+ print(f"Error creating directory {path}: {e}")
+ raise
+
+
+def _create_all_arrow_types_dataframe(num_rows=100, feature_type=None):
+ from datasets.features import Value
+
+ if feature_type is None:
+ feature_type = Value
+    # only generate the types that have a matching Arrow field/feature
+    data = {k: PA_DATA[k](num_rows) for k in PA_FIELDS}
+
+ features = {
+ k: feature_type(k) if k != "struct" else {"a": feature_type("int32")}
+ for k in PA_FIELDS.keys()
+ }
+
+ return data, features
+
+
+def create_omic_dataset(
+ num_rows=100,
+ num_cols=None,
+ dtype="all",
+ sample="sample_id",
+ batch="batch",
+ label="label",
+ multi_class=False,
+ task="classification",
+ label_type="int",
+ metadata=True,
+ input_feature=None,
+ sparse=False,
+ missing_labels=False,
+):
+ """
+ Create a sample dataframe with predefined structure.
+ """
+ from datasets import Value
+
+ if input_feature is None:
+ input_feature = Value
+ data = {}
+ features = {}
+ enable_full_determinism(SEED)
+
+ if sample:
+ data[sample] = [str(i) for i in range(num_rows)]
+ features[sample] = get_feature("Sample")("string")
+ if batch:
+ data[batch] = [str(i) for i in range(num_rows)]
+ features[batch] = get_feature("Batch")("string")
+ metadata_value_options = {
+ "multi_classification_int": [i % 3 for i in range(num_rows)],
+ "multi_classification_str": [
+ ALPHANUMERIC[i % len(ALPHANUMERIC)] for i in range(num_rows)
+ ],
+ "bin_classification_bool": [
+ True if i % 2 == 0 else False for i in range(num_rows)
+ ],
+ "bin_classification_int": [i % 2 for i in range(num_rows)],
+ "bin_classification_str": [
+ "positive" if i > num_rows // 2 else "negative" for i in range(num_rows)
+ ],
+ "regression": np.random.randn(num_rows),
+ }
+ metadata_feature_options = {
+ "multi_classification_int": "int8",
+ "multi_classification_str": "string",
+ "bin_classification_bool": "bool",
+ "bin_classification_int": "int8",
+ "bin_classification_str": "string",
+ "regression": "float32",
+ }
+ if label:
+ if task == "classification":
+ if multi_class:
+ label_name = "multi"
+ else:
+ label_name = "bin"
+ label_name += f"_classification_{label_type}"
+ else:
+ label_name = "regression"
+
+ data[label] = metadata_value_options.pop(label_name)
+ if missing_labels:
+ # Randomly set 10% of labels to -1 if classification, else set to None
+ if task == "classification":
+ indices_to_replace = np.random.choice(
+ num_rows, int(num_rows * 0.1), replace=False
+ )
+ data[label] = [
+ -1 if i in indices_to_replace else lab
+ for i, lab in enumerate(data[label])
+ ]
+ else:
+ indices_to_replace = np.random.choice(
+ num_rows, int(num_rows * 0.1), replace=False
+ )
+ data[label] = [
+ None if i in indices_to_replace else lab
+ for i, lab in enumerate(data[label])
+ ]
+ label_dtype = metadata_feature_options.pop(label_name)
+ if label_name == "regression":
+ features[label] = get_feature("RegressionTarget")(label_dtype)
+ else:
+            names = sorted(set(data[label]), key=str)  # sorted for determinism
+ if not isinstance(names[0], str):
+ names = [str(n) for n in names]
+ else:
+ name_map = {n: i for i, n in enumerate(names)}
+ data[label] = [name_map[n] for n in data[label]]
+
+ num_classes = len(names)
+
+ features[label] = get_feature("ClassLabel")(
+ num_classes=num_classes, names=names
+ )
+ if metadata:
+ if isinstance(metadata, str):
+ data.update(metadata_value_options)
+ features.update(
+ {
+ k: get_feature("Metadata")(dtype=v)
+ for k, v in metadata_feature_options.items()
+ }
+ )
+ else:
+ for label, v in metadata_value_options.items():
+ data[label] = v
+ features[label] = Value(metadata_feature_options[label])
+
+ if dtype == "all":
+ ext_data, ext_features = _create_all_arrow_types_dataframe(
+ num_rows=num_rows, feature_type=input_feature
+ )
+ else:
+ ext_data = {}
+ ext_features = {}
+ if sparse and isinstance(sparse, bool):
+ sparse = 0.8
+ if num_cols is None:
+ num_cols = 1
+
+ dtype_to_pa = {
+ "multi_bins": "int32",
+ "one_hot": "int32",
+ }
+
+ for i in range(num_cols):
+ arr: np.ndarray = PA_DATA[dtype](num_rows)
+ ext_data[f"{dtype}_{i}"] = arr.tolist()
+
+ ext_features[f"{dtype}_{i}"] = input_feature(
+ dtype=dtype_to_pa.get(dtype, dtype),
+ metadata={
+ "my_metadata_str": ALPHANUMERIC[
+ np.random.randint(0, len(ALPHANUMERIC))
+ ],
+ "my_metadata_int": np.random.randint(0, 100),
+ },
+ )
+ if sparse:
+ mat = np.array([ext_data[f"{dtype}_{i}"] for i in range(num_cols)]).T
+ for i in range(num_cols):
+ arr = np.array(ext_data[f"{dtype}_{i}"])
+ total_values = arr.size
+ if isinstance(sparse, list):
+ _sparse = sparse[i]
+ else:
+ _sparse = sparse
+ if isinstance(_sparse, bool):
+ _sparse = np.random.uniform(0.1, 0.9)
+
+ num_to_replace = max(
+ min(int(total_values * (_sparse)), total_values - 1), 0
+ )
+ indices_to_replace = np.random.choice(
+ total_values, num_to_replace, replace=False
+ )
+ # check if replacing with 0 would make a row all 0s
+ for idx in indices_to_replace:
+ if dtype in [
+ "one_hot",
+ "multi_bins",
+ "uint8",
+ "uint16",
+ "uint32",
+ "uint64",
+ ]:
+                    if np.sum(mat[idx, :i] > 0) + np.sum(mat[idx, i + 1 :] > 0) == 0:
+ indices_to_replace = np.delete(
+ indices_to_replace, np.where(indices_to_replace == idx)
+ )
+ else:
+ arr[idx] = 0
+ else:
+                    if all(v is None for v in mat[idx, :i]) and all(
+                        v is None for v in mat[idx, i + 1 :]
+                    ):
+ indices_to_replace = np.delete(
+ indices_to_replace, np.where(indices_to_replace == idx)
+ )
+ else:
+ arr[idx] = 0
+ ext_data[f"{dtype}_{i}"] = arr.tolist()
+
+ data.update(ext_data)
+ features.update(ext_features)
+ if is_biosets_available():
+ import biosets
+ import datasets
+
+ return biosets.Bioset.from_dict(data, features=datasets.Features(features))
+ return pd.DataFrame(data)
+
+
+def create_feature_dataframe(num_cols=100, feature_id="feature"):
+ """
+ Create a feature dataframe with predefined structure.
+ """
+ enable_full_determinism(SEED)
+ data = {
+ feature_id: [str(i) for i in range(num_cols)],
+ }
+
+ fields = [
+ pa.field(feature_id, pa.string()),
+ ]
+
+    ext_data, _ = _create_all_arrow_types_dataframe(num_rows=num_cols)
+    data.update(ext_data)
+    # the second return value holds dataset features, not pyarrow fields
+    fields.extend(PA_FIELDS.values())
+
+ return pa.table(data, schema=pa.schema(fields))
+
+
+def directory_exists_with_files(path, expected_files):
+ """
+ Check if a directory exists with the expected files.
+ """
+ if not os.path.exists(path):
+ return False
+ if not all(os.path.exists(os.path.join(path, file)) for file in expected_files):
+ return False
+ return True
+
+
+# def save_dataframes(dfs, data_dir, filenames):
+# """
+# Save a list of dataframes to CSV in the specified directory.
+# """
+# for df, filename in zip(dfs, filenames):
+# file_ext = filename.split(".")[-1]
+# if file_ext in ["parquet"]:
+# tbl = pa.Table.from_pandas(df) if isinstance(df, pd.DataFrame) else df
+# if "float16" in tbl.schema.names:
+# tbl = tbl.drop(["float16"]) # not supported by parquet
+# writer = ParquetWriter(
+# path=os.path.join(data_dir, filename), schema=tbl.schema
+# )
+# writer.write_table(tbl)
+# elif file_ext in ["arrow"]:
+# tbl = pa.Table.from_pandas(df) if isinstance(df, pd.DataFrame) else df
+# writer = ArrowWriter(
+# path=os.path.join(data_dir, filename), schema=tbl.schema
+# )
+# writer.write_table(tbl)
+# elif file_ext in ["csv"]:
+# df.to_csv(os.path.join(data_dir, filename), index=False)
+# elif file_ext in ["tsv", "txt"]:
+# df.to_csv(os.path.join(data_dir, filename), sep="\t", index=False)
+
+
+# def create_fake_data_dir(data, base_dir, overwrite=False):
+# for name, filenames, dfs, _ in data:
+# data_dir = f"{base_dir}/{name}"
+# os.makedirs(data_dir, exist_ok=True)
+# if not directory_exists_with_files(data_dir, filenames) or overwrite:
+# save_dataframes(dfs, data_dir, filenames)
+
+
+def create_dataset_with_sklearn(
+ path,
+ experiment_type,
+ n_features=20,
+ n_samples=50,
+ multi_class=False,
+ lab_as_str=True,
+ dataset_type="snp",
+ label_column="label",
+ add_missing_labels=False,
+):
+ if multi_class:
+ num_classes = 3
+ else:
+ num_classes = 2
+
+ samples, labs, names = create_data(
+ n_samples,
+ n_features,
+ num_classes,
+ experiment_type,
+ lab_as_str=lab_as_str,
+        add_missing_labels=add_missing_labels,
+ )
+ data = create_sample_metadata(n_samples)
+ features = None
+ if is_biosets_available():
+ if experiment_type == "snp":
+ features = create_features(
+ n_features, get_feature("GenomicVariant"), "int8", num_classes, names
+ )
+ elif experiment_type in ["otu", "asv"]:
+ features = create_features(
+ n_features, get_feature("Abundance"), "int32", num_classes, names
+ )
+ elif experiment_type == "maldi":
+ features = create_features(
+ n_features, get_feature("PeakIntensity"), "int32", num_classes, names
+ )
+
+ data = pd.concat(
+ [
+ data,
+            samples,  # already a DataFrame with feature_{i} columns
+ pd.DataFrame(labs.reshape(-1, 1), columns=[label_column]),
+ ],
+ axis=1,
+ )
+ if is_biosets_available():
+ import biosets
+ import datasets
+
+ ds = biosets.Dataset.from_pandas(data, features=datasets.Features(features))
+
+ ds.info.builder_name = dataset_type
+ os.makedirs(path, exist_ok=True)
+ ds.save_to_disk(path)
+ else:
+ data.to_csv(path, index=False)
+
+
+def _add_missing_labels(labs, num_classes=2, lab_as_str=True, n_samples=50):
+ if isinstance(labs, np.ndarray):
+ new_labs = labs.tolist()
+ else:
+ new_labs = labs
+ # make 10% of labels missing for each class
+ for i in range(num_classes):
+ if lab_as_str:
+ indices_to_replace = np.random.choice(
+ np.where(labs == ALPHANUMERIC[i % len(ALPHANUMERIC)])[0],
+ int(n_samples * 0.1),
+ replace=False,
+ )
+ for idx in indices_to_replace:
+ new_labs[idx] = None
+ else:
+ indices_to_replace = np.random.choice(
+ np.where(labs == i)[0], int(n_samples * 0.1), replace=False
+ )
+ for idx in indices_to_replace:
+ new_labs[idx] = -1
+ return np.array(new_labs)
+
+
+def create_data(
+ n_samples,
+ n_features,
+ num_classes,
+ experiment_type,
+ lab_as_str=False,
+ add_missing_labels=False,
+):
+ samples, labs = make_classification(
+ n_samples=n_samples,
+ n_features=n_features,
+ n_classes=num_classes,
+ n_informative=int(n_features * 0.8),
+ n_redundant=int(n_features * 0.1),
+ random_state=SEED,
+ )
+
+ if experiment_type in ["snp"]:
+ median = np.median(samples.flatten())
+ samples[samples < median] = 0
+ samples[samples >= median] = 1
+ samples = samples.astype(np.int32)
+ elif experiment_type in ["otu", "asv", "maldi"]:
+ samples = np.abs(np.round(np.quantile(samples.flatten(), 0.25))) + samples
+ samples[samples < 0] = 0
+ samples = samples.astype(np.int32)
+ if lab_as_str:
+        # keep label order aligned with the sample rows (sorting would break
+        # the X-y pairing)
+        labs = np.array(
+            [ALPHANUMERIC[i % len(ALPHANUMERIC)] for i in labs.tolist()]
+        )
+ names = [ALPHANUMERIC[i % len(ALPHANUMERIC)] for i in range(num_classes)]
+ else:
+ names = [str(i) for i in range(num_classes)]
+ labs = labs.astype(np.int32)
+
+ if add_missing_labels:
+ labs = _add_missing_labels(
+ labs, num_classes=num_classes, lab_as_str=lab_as_str, n_samples=n_samples
+ )
+
+ samples = pd.DataFrame(samples, columns=[f"feature_{i}" for i in range(n_features)])
+ return samples, labs, names
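+
+
+# Example (sketch): create_data(20, 5, 2, "otu") returns a 20x5 DataFrame of
+# non-negative int32 counts, a length-20 int32 label vector, and
+# names == ["0", "1"].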
+
+
+def create_features(
+ n_features,
+ FeatureType=None,
+ dtype=None,
+ num_classes=None,
+ names=None,
+ with_metadata=True,
+ as_dict=False,
+):
+ metadata_feature_options = {
+ "multi_int": "int32",
+ "multi_str": "string",
+ "bin_bool": "bool",
+ "bin_int": "int32",
+ "bin_str": "string",
+ "floating": "float64",
+ }
+ if not as_dict:
+        requires_backends("create_features", "biosets")
+
+ features = {}
+ if with_metadata:
+ features.update(
+ {
+ "sample_id": get_feature("Sample")("string"),
+ }
+ )
+ features.update(
+ {
+ k: get_feature("Metadata")(dtype=v)
+ for k, v in metadata_feature_options.items()
+ }
+ )
+
+ for i in range(n_features):
+ features[f"feature_{i}"] = FeatureType(
+ dtype="int32",
+ metadata={
+ "my_metadata_str": ALPHANUMERIC[
+ np.random.randint(0, len(ALPHANUMERIC))
+ ],
+ "my_metadata_int": np.random.randint(0, 100),
+ },
+ )
+ if num_classes is not None:
+ if num_classes > 2:
+ features["label"] = get_feature("ClassLabel")(
+ num_classes=num_classes, names=names
+ )
+ else:
+ features["label"] = get_feature("BinClassLabel")(
+ num_classes=num_classes, names=names
+ )
+
+ else:
+ features = {}
+ if with_metadata:
+ features = {
+ "sample_id": {},
+ }
+ features.update({k: {} for k in metadata_feature_options.keys()})
+
+ for i in range(n_features):
+ features[f"feature_{i}"] = {
+ "my_metadata_str": ALPHANUMERIC[
+ np.random.randint(0, len(ALPHANUMERIC))
+ ],
+ "my_metadata_int": np.random.randint(0, 100),
+ }
+ if num_classes is not None:
+ features["label"] = {
+ "num_classes": num_classes,
+ "names": names,
+ }
+
+ return features
+
+
+def create_sample_metadata(n_samples=100):
+ return pd.DataFrame(
+ {
+ "sample_id": [f"s{i}" for i in range(n_samples)],
+ "multi_int": [i % 3 for i in range(n_samples)],
+ "multi_str": [
+ ALPHANUMERIC[i % len(ALPHANUMERIC)] for i in range(n_samples)
+ ],
+ "bin_bool": [True if i % 2 == 0 else False for i in range(n_samples)],
+ "bin_int": [i % 2 for i in range(n_samples)],
+ "bin_str": [
+ "positive" if i > n_samples // 2 else "negative"
+ for i in range(n_samples)
+ ],
+ "floating": np.random.randn(n_samples),
+ }
+ )
+
+
+@pytest.fixture(scope="session")
+def count_data():
+ data, y, _ = create_data(20, 5, 2, "otu")
+ y = pd.DataFrame(y[:, None], columns=["label"])
+ return data, y
+
+
+@pytest.fixture(scope="session")
+def count_data_multi_class():
+ data, y, _ = create_data(20, 5, 3, "otu")
+ y = pd.DataFrame(y[:, None], columns=["label"])
+ return data, y
+
+
+@pytest.fixture(scope="session")
+def count_data_missing_labels():
+ data, y, _ = create_data(20, 5, 2, "otu", add_missing_labels=True)
+ y = pd.DataFrame(y[:, None], columns=["label"])
+ return data, y
+
+
+@pytest.fixture(scope="session")
+def binary_data():
+ data, y, _ = create_data(20, 5, 2, "snp")
+ y = pd.DataFrame(y[:, None], columns=["label"])
+ return data, y
+
+
+@pytest.fixture(scope="session")
+def float_data():
+ data, y, _ = create_data(20, 5, 2, None)
+ y = pd.DataFrame(y[:, None], columns=["label"])
+ return data, y
+
+
+@pytest.fixture(scope="session")
+def classification_data():
+ data, y, _ = create_data(100, 5, 2, None)
+ y = pd.DataFrame(y[:, None], columns=["label"])
+ return data, y
+
+
+@pytest.fixture(scope="session")
+def classification_data_multi_class():
+ data, y, _ = create_data(100, 5, 3, None)
+ y = pd.DataFrame(y[:, None], columns=["label"])
+ return data, y
+
+
+@pytest.fixture(scope="session")
+def feature_metadata():
+ return create_features(5, as_dict=True, with_metadata=False)
+
+
+@pytest.fixture(scope="session")
+def sample_metadata():
+ return create_sample_metadata(20)
+
+
+@pytest.fixture(scope="session")
+def biodataset():
+ return create_omic_dataset(10, num_cols=3, dtype="float32", metadata="metadata")
+
+
+@pytest.fixture(scope="session")
+def snp_dataset_path(tmp_path_factory):
+ set_seed(SEED)
+ path = str(tmp_path_factory.mktemp("data") / "SNP")
+ create_dataset_with_sklearn(path, "snp")
+ return path
+
+
+@pytest.fixture(scope="session")
+def otu_dataset_path(tmp_path_factory):
+ set_seed(SEED)
+ path = str(tmp_path_factory.mktemp("data") / "OTU")
+    create_dataset_with_sklearn(path, "otu")
+ return path
+
+
+@pytest.fixture(scope="session")
+def otu_dataset_missing_labels_path(tmp_path_factory):
+ set_seed(SEED)
+ path = str(tmp_path_factory.mktemp("data") / "OTU")
+ create_dataset_with_sklearn(path, "otu", lab_as_str=False, add_missing_labels=True)
+ return path
+
+
+@pytest.fixture(scope="session")
+def otu_dataset_multi_class_path(tmp_path_factory):
+ set_seed(SEED)
+ path = str(tmp_path_factory.mktemp("data") / "OTU")
+ create_dataset_with_sklearn(path, "otu", multi_class=True)
+ return path
+
+
+@pytest.fixture(scope="session")
+def maldi_dataset_path(tmp_path_factory):
+ set_seed(SEED)
+ path = str(tmp_path_factory.mktemp("data") / "MALDI")
+ create_dataset_with_sklearn(path, "maldi")
+ return path
+
+
+@pytest.fixture(scope="session")
+def snp_dataset():
+ ds = create_omic_dataset(
+ num_rows=10,
+ num_cols=3,
+ dtype="multi_bins",
+ metadata="metadata",
+ input_feature=get_feature("GenomicVariant"),
+ sparse=0.8,
+ )
+ ds.info.builder_name = "snp"
+ return ds
+
+
+@pytest.fixture(scope="session")
+def maldi_dataset():
+ ds = create_omic_dataset(
+ num_rows=10,
+ num_cols=3,
+ dtype="multi_bins",
+ metadata="metadata",
+ input_feature=get_feature("PeakIntensity"),
+ sparse=0.8,
+ )
+ ds.info.builder_name = "maldi"
+ return ds
+
+
+# @pytest.fixture(scope="session")
+# def camda_dataset():
+# camda_dir = "./tests/data/CAMDA"
+# camda_metadata_files = os.path.join(camda_dir, "camda.pheno.csv")
+# camda_feature_metadata_files = os.path.join(camda_dir, "camda.feature.csv")
+# ds = load_dataset(
+# "otu",
+# data_dir=camda_dir,
+# sample_metadata_files=camda_metadata_files,
+# feature_metadata_files=camda_feature_metadata_files,
+# label_column="City2",
+# cache_dir="./.cache",
+# )
+# ds.cleanup_cache_files()
+# return ds
+#
+
+# @pytest.fixture(scope="session")
+# def camda_dataset_files_only():
+# camda_dir = "./tests/data/CAMDA"
+# data_files = os.path.join(camda_dir, "*matrix*.csv")
+# return load_dataset(
+# dataset_type="otu",
+# name="camda",
+# data_files=data_files,
+# label_column="City2",
+# )
+
+
+# @pytest.fixture(scope="session")
+# def camda_dataset_no_polars():
+# camda_dir = "./tests/data/CAMDA"
+# camda_metadata_files = os.path.join(camda_dir, "camda.pheno.csv")
+# camda_feature_metadata_files = os.path.join(camda_dir, "camda.feature.csv")
+# return load_dataset(
+# "otu",
+# data_dir=camda_dir,
+# sample_metadata_files=camda_metadata_files,
+# feature_metadata_files=camda_feature_metadata_files,
+# label_column="City2",
+# cache_dir="./.cache",
+# use_polars=False,
+# )
+#
+
+# @pytest.fixture(scope="session")
+# def tb_dataset():
+# tb_dir = "./tests/data/genomics_TB"
+# dataset = load_dataset(
+# "snp",
+# "TB",
+# data_dir=tb_dir,
+# label_column="Isoniazid",
+# keep_in_memory=False,
+# cache_dir="./.cache",
+# )
+# dataset.cleanup_cache_files()
+# return dataset
+#
+
+
+@pytest.fixture(scope="session")
+def arrow_file(tmp_path_factory, dataset):
+ filename = str(tmp_path_factory.mktemp("data") / "file.arrow")
+ dataset.map(cache_file_name=filename)
+ return filename
+
+
+# FILE_CONTENT + files
+
+
+FILE_CONTENT = """\
+ Text data.
+ Second line of data."""
+
+
+@pytest.fixture(scope="session")
+def text_file(tmp_path_factory):
+ filename = tmp_path_factory.mktemp("data") / "file.txt"
+ data = FILE_CONTENT
+ with open(filename, "w") as f:
+ f.write(data)
+ return filename
+
+
+@pytest.fixture(scope="session")
+def bz2_file(tmp_path_factory):
+ import bz2
+
+ path = tmp_path_factory.mktemp("data") / "file.txt.bz2"
+ data = bytes(FILE_CONTENT, "utf-8")
+ with bz2.open(path, "wb") as f:
+ f.write(data)
+ return path
+
+
+@pytest.fixture(scope="session")
+def gz_file(tmp_path_factory):
+ import gzip
+
+ path = str(tmp_path_factory.mktemp("data") / "file.txt.gz")
+ data = bytes(FILE_CONTENT, "utf-8")
+ with gzip.open(path, "wb") as f:
+ f.write(data)
+ return path
+
+
+@pytest.fixture(scope="session")
+def lz4_file(tmp_path_factory):
+ try:
+ import lz4.frame
+
+ path = tmp_path_factory.mktemp("data") / "file.txt.lz4"
+ data = bytes(FILE_CONTENT, "utf-8")
+ with lz4.frame.open(path, "wb") as f:
+ f.write(data)
+ return path
+ except ImportError:
+ pytest.skip("lz4 not available")
+
+
+@pytest.fixture(scope="session")
+def seven_zip_file(tmp_path_factory, text_file):
+ try:
+ import py7zr
+
+ path = tmp_path_factory.mktemp("data") / "file.txt.7z"
+ with py7zr.SevenZipFile(path, "w") as archive:
+ archive.write(text_file, arcname=os.path.basename(text_file))
+ return path
+ except ImportError:
+ pytest.skip("py7zr not available")
+
+
+@pytest.fixture(scope="session")
+def tar_file(tmp_path_factory, text_file):
+ import tarfile
+
+ path = tmp_path_factory.mktemp("data") / "file.txt.tar"
+ with tarfile.TarFile(path, "w") as f:
+ f.add(text_file, arcname=os.path.basename(text_file))
+ return path
+
+
+@pytest.fixture(scope="session")
+def xz_file(tmp_path_factory):
+ import lzma
+
+ path = tmp_path_factory.mktemp("data") / "file.txt.xz"
+ data = bytes(FILE_CONTENT, "utf-8")
+ with lzma.open(path, "wb") as f:
+ f.write(data)
+ return path
+
+
+@pytest.fixture(scope="session")
+def zip_file(tmp_path_factory, text_file):
+ import zipfile
+
+ path = tmp_path_factory.mktemp("data") / "file.txt.zip"
+ with zipfile.ZipFile(path, "w") as f:
+ f.write(text_file, arcname=os.path.basename(text_file))
+ return path
+
+
+@pytest.fixture(scope="session")
+def zstd_file(tmp_path_factory):
+ try:
+ import zstandard as zstd
+
+ path = tmp_path_factory.mktemp("data") / "file.txt.zst"
+ data = bytes(FILE_CONTENT, "utf-8")
+ with zstd.open(path, "wb") as f:
+ f.write(data)
+ return path
+ except ImportError:
+ pytest.skip("zstandard not available")
+
+
+# xml_file
+
+
+@pytest.fixture(scope="session")
+def xml_file(tmp_path_factory):
+ filename = tmp_path_factory.mktemp("data") / "file.xml"
+ data = textwrap.dedent(
+ """\
+
+
+
+
+
+ Contingut 1
+ Content 1
+
+
+ Contingut 2
+ Content 2
+
+
+ Contingut 3
+ Content 3
+
+
+ Contingut 4
+ Content 4
+
+
+ Contingut 5
+ Content 5
+
+
+ """
+ )
+ with open(filename, "w") as f:
+ f.write(data)
+ return filename
+
+
+DATA = [
+ {"col_1": "0", "col_2": 0, "col_3": 0.0},
+ {"col_1": "1", "col_2": 1, "col_3": 1.0},
+ {"col_1": "2", "col_2": 2, "col_3": 2.0},
+ {"col_1": "3", "col_2": 3, "col_3": 3.0},
+]
+DATA2 = [
+ {"col_1": "4", "col_2": 4, "col_3": 4.0},
+ {"col_1": "5", "col_2": 5, "col_3": 5.0},
+]
+DATA_DICT_OF_LISTS = {
+ "col_1": ["0", "1", "2", "3"],
+ "col_2": [0, 1, 2, 3],
+ "col_3": [0.0, 1.0, 2.0, 3.0],
+}
+
+DATA_312 = [
+ {"col_3": 0.0, "col_1": "0", "col_2": 0},
+ {"col_3": 1.0, "col_1": "1", "col_2": 1},
+]
+
+DATA_STR = [
+ {"col_1": "s0", "col_2": 0, "col_3": 0.0},
+ {"col_1": "s1", "col_2": 1, "col_3": 1.0},
+ {"col_1": "s2", "col_2": 2, "col_3": 2.0},
+ {"col_1": "s3", "col_2": 3, "col_3": 3.0},
+]
+
+
+@pytest.fixture(scope="session")
+def dataset_dict():
+ return DATA_DICT_OF_LISTS
+
+
+@pytest.fixture(scope="session")
+def arrow_path(tmp_path_factory):
+ import datasets
+
+ dataset = datasets.Dataset.from_dict(DATA_DICT_OF_LISTS)
+ path = str(tmp_path_factory.mktemp("data") / "dataset.arrow")
+ dataset.map(cache_file_name=path)
+ return path
+
+
+@pytest.fixture(scope="session")
+def sqlite_path(tmp_path_factory):
+ path = str(tmp_path_factory.mktemp("data") / "dataset.sqlite")
+ with contextlib.closing(sqlite3.connect(path)) as con:
+ cur = con.cursor()
+ cur.execute("CREATE TABLE dataset(col_1 text, col_2 int, col_3 real)")
+ for item in DATA:
+ cur.execute(
+ "INSERT INTO dataset(col_1, col_2, col_3) VALUES (?, ?, ?)",
+ tuple(item.values()),
+ )
+ con.commit()
+ return path
+
+
+@pytest.fixture(scope="session")
+def csv_path(tmp_path_factory):
+ path = str(tmp_path_factory.mktemp("data") / "dataset.csv")
+ with open(path, "w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=["col_1", "col_2", "col_3"])
+ writer.writeheader()
+ for item in DATA:
+ writer.writerow(item)
+ return path
+
+
+@pytest.fixture(scope="session")
+def csv2_path(tmp_path_factory):
+ path = str(tmp_path_factory.mktemp("data") / "dataset2.csv")
+ with open(path, "w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=["col_1", "col_2", "col_3"])
+ writer.writeheader()
+ for item in DATA:
+ writer.writerow(item)
+ return path
+
+
+@pytest.fixture(scope="session")
+def bz2_csv_path(csv_path, tmp_path_factory):
+ import bz2
+
+ path = tmp_path_factory.mktemp("data") / "dataset.csv.bz2"
+ with open(csv_path, "rb") as f:
+ data = f.read()
+ # data = bytes(FILE_CONTENT, "utf-8")
+ with bz2.open(path, "wb") as f:
+ f.write(data)
+ return path
+
+
+@pytest.fixture(scope="session")
+def zip_csv_path(csv_path, csv2_path, tmp_path_factory):
+ path = tmp_path_factory.mktemp("zip_csv_path") / "csv-dataset.zip"
+ with zipfile.ZipFile(path, "w") as f:
+ f.write(csv_path, arcname=os.path.basename(csv_path))
+ f.write(csv2_path, arcname=os.path.basename(csv2_path))
+ return path
+
+
+@pytest.fixture(scope="session")
+def zip_uppercase_csv_path(csv_path, csv2_path, tmp_path_factory):
+ path = tmp_path_factory.mktemp("data") / "dataset.csv.zip"
+ with zipfile.ZipFile(path, "w") as f:
+ f.write(csv_path, arcname=os.path.basename(csv_path.replace(".csv", ".CSV")))
+ f.write(csv2_path, arcname=os.path.basename(csv2_path.replace(".csv", ".CSV")))
+ return path
+
+
+@pytest.fixture(scope="session")
+def zip_csv_with_dir_path(csv_path, csv2_path, tmp_path_factory):
+ path = tmp_path_factory.mktemp("data") / "dataset_with_dir.csv.zip"
+ with zipfile.ZipFile(path, "w") as f:
+ f.write(csv_path, arcname=os.path.join("main_dir", os.path.basename(csv_path)))
+ f.write(
+ csv2_path, arcname=os.path.join("main_dir", os.path.basename(csv2_path))
+ )
+ return path
+
+
+@pytest.fixture(scope="session")
+def parquet_path(tmp_path_factory):
+ path = str(tmp_path_factory.mktemp("data") / "dataset.parquet")
+ schema = pa.schema(
+ {
+ "col_1": pa.string(),
+ "col_2": pa.int64(),
+ "col_3": pa.float64(),
+ }
+ )
+ with open(path, "wb") as f:
+ writer = pq.ParquetWriter(f, schema=schema)
+ pa_table = pa.Table.from_pydict(
+ {k: [DATA[i][k] for i in range(len(DATA))] for k in DATA[0]}, schema=schema
+ )
+ writer.write_table(pa_table)
+ writer.close()
+ return path
+
+
+@pytest.fixture(scope="session")
+def json_list_of_dicts_path(tmp_path_factory):
+ path = str(tmp_path_factory.mktemp("data") / "dataset.json")
+ data = {"data": DATA}
+ with open(path, "w") as f:
+ json.dump(data, f)
+ return path
+
+
+@pytest.fixture(scope="session")
+def json_dict_of_lists_path(tmp_path_factory):
+ path = str(tmp_path_factory.mktemp("data") / "dataset.json")
+ data = {"data": DATA_DICT_OF_LISTS}
+ with open(path, "w") as f:
+ json.dump(data, f)
+ return path
+
+
+@pytest.fixture(scope="session")
+def jsonl_path(tmp_path_factory):
+ path = str(tmp_path_factory.mktemp("data") / "dataset.jsonl")
+ with open(path, "w") as f:
+ for item in DATA:
+ f.write(json.dumps(item) + "\n")
+ return path
+
+
+@pytest.fixture(scope="session")
+def jsonl2_path(tmp_path_factory):
+ path = str(tmp_path_factory.mktemp("data") / "dataset2.jsonl")
+ with open(path, "w") as f:
+ for item in DATA:
+ f.write(json.dumps(item) + "\n")
+ return path
+
+
+@pytest.fixture(scope="session")
+def jsonl_312_path(tmp_path_factory):
+ path = str(tmp_path_factory.mktemp("data") / "dataset_312.jsonl")
+ with open(path, "w") as f:
+ for item in DATA_312:
+ f.write(json.dumps(item) + "\n")
+ return path
+
+
+@pytest.fixture(scope="session")
+def jsonl_str_path(tmp_path_factory):
+ path = str(tmp_path_factory.mktemp("data") / "dataset-str.jsonl")
+ with open(path, "w") as f:
+ for item in DATA_STR:
+ f.write(json.dumps(item) + "\n")
+ return path
+
+
+@pytest.fixture(scope="session")
+def text_gz_path(tmp_path_factory, text_path):
+ import gzip
+
+ path = str(tmp_path_factory.mktemp("data") / "dataset.txt.gz")
+ with open(text_path, "rb") as orig_file:
+ with gzip.open(path, "wb") as zipped_file:
+ zipped_file.writelines(orig_file)
+ return path
+
+
+@pytest.fixture(scope="session")
+def jsonl_gz_path(tmp_path_factory, jsonl_path):
+ import gzip
+
+ path = str(tmp_path_factory.mktemp("data") / "dataset.jsonl.gz")
+ with open(jsonl_path, "rb") as orig_file:
+ with gzip.open(path, "wb") as zipped_file:
+ zipped_file.writelines(orig_file)
+ return path
+
+
+@pytest.fixture(scope="session")
+def zip_jsonl_path(jsonl_path, jsonl2_path, tmp_path_factory):
+ path = tmp_path_factory.mktemp("data") / "dataset.jsonl.zip"
+ with zipfile.ZipFile(path, "w") as f:
+ f.write(jsonl_path, arcname=os.path.basename(jsonl_path))
+ f.write(jsonl2_path, arcname=os.path.basename(jsonl2_path))
+ return path
+
+
+@pytest.fixture(scope="session")
+def zip_nested_jsonl_path(zip_jsonl_path, jsonl_path, jsonl2_path, tmp_path_factory):
+ path = tmp_path_factory.mktemp("data") / "dataset_nested.jsonl.zip"
+ with zipfile.ZipFile(path, "w") as f:
+ f.write(
+ zip_jsonl_path,
+ arcname=os.path.join("nested", os.path.basename(zip_jsonl_path)),
+ )
+ return path
+
+
+@pytest.fixture(scope="session")
+def zip_jsonl_with_dir_path(jsonl_path, jsonl2_path, tmp_path_factory):
+ path = tmp_path_factory.mktemp("data") / "dataset_with_dir.jsonl.zip"
+ with zipfile.ZipFile(path, "w") as f:
+ f.write(
+ jsonl_path, arcname=os.path.join("main_dir", os.path.basename(jsonl_path))
+ )
+ f.write(
+ jsonl2_path, arcname=os.path.join("main_dir", os.path.basename(jsonl2_path))
+ )
+ return path
+
+
+@pytest.fixture(scope="session")
+def tar_jsonl_path(jsonl_path, jsonl2_path, tmp_path_factory):
+ path = tmp_path_factory.mktemp("data") / "dataset.jsonl.tar"
+ with tarfile.TarFile(path, "w") as f:
+ f.add(jsonl_path, arcname=os.path.basename(jsonl_path))
+ f.add(jsonl2_path, arcname=os.path.basename(jsonl2_path))
+ return path
+
+
+@pytest.fixture(scope="session")
+def tar_nested_jsonl_path(tar_jsonl_path, jsonl_path, jsonl2_path, tmp_path_factory):
+ path = tmp_path_factory.mktemp("data") / "dataset_nested.jsonl.tar"
+ with tarfile.TarFile(path, "w") as f:
+ f.add(
+ tar_jsonl_path,
+ arcname=os.path.join("nested", os.path.basename(tar_jsonl_path)),
+ )
+ return path
+
+
+@pytest.fixture(scope="session")
+def text_path(tmp_path_factory):
+ data = ["0", "1", "2", "3"]
+ path = str(tmp_path_factory.mktemp("data") / "dataset.txt")
+ with open(path, "w") as f:
+ for item in data:
+ f.write(item + "\n")
+ return path
+
+
+@pytest.fixture(scope="session")
+def text2_path(tmp_path_factory):
+ data = ["0", "1", "2", "3"]
+ path = str(tmp_path_factory.mktemp("data") / "dataset2.txt")
+ with open(path, "w") as f:
+ for item in data:
+ f.write(item + "\n")
+ return path
+
+
+@pytest.fixture(scope="session")
+def text_dir(tmp_path_factory):
+ data = ["0", "1", "2", "3"]
+ path = tmp_path_factory.mktemp("data_text_dir") / "dataset.txt"
+ with open(path, "w") as f:
+ for item in data:
+ f.write(item + "\n")
+ return path.parent
+
+
+@pytest.fixture(scope="session")
+def text_dir_with_unsupported_extension(tmp_path_factory):
+ data = ["0", "1", "2", "3"]
+ path = tmp_path_factory.mktemp("data") / "dataset.abc"
+ with open(path, "w") as f:
+ for item in data:
+ f.write(item + "\n")
+    return path.parent
+
+
+@pytest.fixture(scope="session")
+def zip_text_path(text_path, text2_path, tmp_path_factory):
+ path = tmp_path_factory.mktemp("data") / "dataset.text.zip"
+ with zipfile.ZipFile(path, "w") as f:
+ f.write(text_path, arcname=os.path.basename(text_path))
+ f.write(text2_path, arcname=os.path.basename(text2_path))
+ return path
+
+
+@pytest.fixture(scope="session")
+def zip_text_with_dir_path(text_path, text2_path, tmp_path_factory):
+ path = tmp_path_factory.mktemp("data") / "dataset_with_dir.text.zip"
+ with zipfile.ZipFile(path, "w") as f:
+ f.write(
+ text_path, arcname=os.path.join("main_dir", os.path.basename(text_path))
+ )
+ f.write(
+ text2_path, arcname=os.path.join("main_dir", os.path.basename(text2_path))
+ )
+ return path
+
+
+@pytest.fixture(scope="session")
+def zip_unsupported_ext_path(text_path, text2_path, tmp_path_factory):
+ path = tmp_path_factory.mktemp("data") / "dataset.ext.zip"
+ with zipfile.ZipFile(path, "w") as f:
+ f.write(text_path, arcname=os.path.basename("unsupported.ext"))
+ f.write(text2_path, arcname=os.path.basename("unsupported_2.ext"))
+ return path
+
+
+@pytest.fixture(scope="session")
+def text_path_with_unicode_new_lines(tmp_path_factory):
+ text = "\n".join(["First", "Second\u2029with Unicode new line", "Third"])
+ path = str(tmp_path_factory.mktemp("data") / "dataset_with_unicode_new_lines.txt")
+ with open(path, "w", encoding="utf-8") as f:
+ f.write(text)
+ return path
+
+
+@pytest.fixture(scope="session")
+def image_file():
+ return os.path.join("tests", "features", "data", "test_image_rgb.jpg")
+
+
+@pytest.fixture(scope="session")
+def audio_file():
+ return os.path.join("tests", "features", "data", "test_audio_44100.wav")
+
+
+@pytest.fixture(scope="session")
+def bio_dir():
+ return os.path.join("tests", "features", "data", "CAMDA")
+
+
+@pytest.fixture(scope="session")
+def metadata_bio_files():
+ return os.path.join("tests", "features", "data", "CAMDA", "camda.pheno.csv")
+
+
+@pytest.fixture(scope="session")
+def anno_bio_files():
+ return os.path.join("tests", "features", "data", "CAMDA", "camda.feature.csv")
+
+
+@pytest.fixture(scope="session")
+def zip_image_path(image_file, tmp_path_factory):
+ path = tmp_path_factory.mktemp("data") / "dataset.img.zip"
+ with zipfile.ZipFile(path, "w") as f:
+ f.write(image_file, arcname=os.path.basename(image_file))
+ f.write(
+ image_file, arcname=os.path.basename(image_file).replace(".jpg", "2.jpg")
+ )
+ return path
+
+
+@pytest.fixture(scope="session")
+def data_dir_with_hidden_files(tmp_path_factory):
+ data_dir = tmp_path_factory.mktemp("data_dir")
+
+ (data_dir / "subdir").mkdir()
+ with open(data_dir / "subdir" / "train.txt", "w") as f:
+ f.write("foo\n" * 10)
+ with open(data_dir / "subdir" / "test.txt", "w") as f:
+ f.write("bar\n" * 10)
+ # hidden file
+ with open(data_dir / "subdir" / ".test.txt", "w") as f:
+ f.write("bar\n" * 10)
+
+ # hidden directory
+ (data_dir / ".subdir").mkdir()
+ with open(data_dir / ".subdir" / "train.txt", "w") as f:
+ f.write("foo\n" * 10)
+ with open(data_dir / ".subdir" / "test.txt", "w") as f:
+ f.write("bar\n" * 10)
+
+ return data_dir
diff --git a/tests/fixtures/fsspec.py b/tests/fixtures/fsspec.py
new file mode 100644
index 0000000..b7bd9a9
--- /dev/null
+++ b/tests/fixtures/fsspec.py
@@ -0,0 +1,120 @@
+import posixpath
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+from fsspec.implementations.local import (
+ AbstractFileSystem,
+ LocalFileSystem,
+ stringify_path,
+)
+from fsspec.registry import _registry as _fsspec_registry
+
+
+class MockFileSystem(AbstractFileSystem):
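+    """Filesystem that maps "mock://" paths onto a local directory tree for tests."""
+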
+ protocol = "mock"
+
+ def __init__(self, *args, local_root_dir, **kwargs):
+ super().__init__()
+ self._fs = LocalFileSystem(*args, **kwargs)
+ self.local_root_dir = Path(local_root_dir).resolve().as_posix() + "/"
+
+ def mkdir(self, path, *args, **kwargs):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs.mkdir(path, *args, **kwargs)
+
+ def makedirs(self, path, *args, **kwargs):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs.makedirs(path, *args, **kwargs)
+
+ def rmdir(self, path):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs.rmdir(path)
+
+ def ls(self, path, detail=True, *args, **kwargs):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ out = self._fs.ls(path, detail=detail, *args, **kwargs)
+ if detail:
+ return [
+ {**info, "name": info["name"][len(self.local_root_dir) :]}
+ for info in out
+ ]
+ else:
+ return [name[len(self.local_root_dir) :] for name in out]
+
+ def info(self, path, *args, **kwargs):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ out = dict(self._fs.info(path, *args, **kwargs))
+ out["name"] = out["name"][len(self.local_root_dir) :]
+ return out
+
+ def cp_file(self, path1, path2, *args, **kwargs):
+ path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1))
+ path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2))
+ return self._fs.cp_file(path1, path2, *args, **kwargs)
+
+ def rm_file(self, path, *args, **kwargs):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs.rm_file(path, *args, **kwargs)
+
+ def rm(self, path, *args, **kwargs):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs.rm(path, *args, **kwargs)
+
+ def _open(self, path, *args, **kwargs):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs._open(path, *args, **kwargs)
+
+ def created(self, path):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs.created(path)
+
+ def modified(self, path):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs.modified(path)
+
+ @classmethod
+ def _strip_protocol(cls, path):
+ path = stringify_path(path)
+ if path.startswith("mock://"):
+ path = path[7:]
+ return path
+
+
+class TmpDirFileSystem(MockFileSystem):
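+    """MockFileSystem rooted at the class-level ``tmp_dir`` (set by the ``tmpfs`` fixture)."""
+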
+ protocol = "tmp"
+ tmp_dir = None
+
+ def __init__(self, *args, **kwargs):
+ assert self.tmp_dir is not None, "TmpDirFileSystem.tmp_dir is not set"
+ super().__init__(*args, **kwargs, local_root_dir=self.tmp_dir, auto_mkdir=True)
+
+ @classmethod
+ def _strip_protocol(cls, path):
+ path = stringify_path(path)
+ if path.startswith("tmp://"):
+ path = path[6:]
+ return path
+
+
+@pytest.fixture
+def mock_fsspec():
+ _fsspec_registry["mock"] = MockFileSystem
+ _fsspec_registry["tmp"] = TmpDirFileSystem
+ yield
+ del _fsspec_registry["mock"]
+ del _fsspec_registry["tmp"]
+
+
+@pytest.fixture
+def mockfs(tmp_path_factory, mock_fsspec):
+ local_fs_dir = tmp_path_factory.mktemp("mockfs")
+ return MockFileSystem(local_root_dir=local_fs_dir, auto_mkdir=True)
+
+
+@pytest.fixture
+def tmpfs(tmp_path_factory, mock_fsspec):
+ tmp_fs_dir = tmp_path_factory.mktemp("tmpfs")
+ with patch.object(TmpDirFileSystem, "tmp_dir", tmp_fs_dir):
+ yield TmpDirFileSystem()
+ TmpDirFileSystem.clear_instance_cache()
diff --git a/tests/preprocessing/__init__.py b/tests/preprocessing/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/preprocessing/outputs/histogram.pdf b/tests/preprocessing/outputs/histogram.pdf
new file mode 100644
index 0000000..895cb37
Binary files /dev/null and b/tests/preprocessing/outputs/histogram.pdf differ
diff --git a/tests/preprocessing/outputs/histogram.png b/tests/preprocessing/outputs/histogram.png
new file mode 100644
index 0000000..13a4674
Binary files /dev/null and b/tests/preprocessing/outputs/histogram.png differ
diff --git a/tests/preprocessing/test_abundance_filtering.py b/tests/preprocessing/test_abundance_filtering.py
new file mode 100644
index 0000000..1e05d65
--- /dev/null
+++ b/tests/preprocessing/test_abundance_filtering.py
@@ -0,0 +1,66 @@
+import biofit.config
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available, is_polars_available
+from biofit.preprocessing.filtering.row_abundance import AbundanceSampleFilter
+
+from tests.utils import create_bioset
+
+handler = DataHandler()
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_abundance_filter_otu(count_data, sample_metadata, format):
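+    # "<format>_cached" cases run after their plain counterpart, so the second
+    # pass should be served from biofit's fingerprint cache.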
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+ proc = AbundanceSampleFilter(load_from_cache_file=load_from_cache_file)
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ input_columns = list(X.columns)
+ if format == "numpy":
+ data = handler.to_format(X, format)
+ proc.fit_transform(data, cache_dir=cache_dir)
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ proc.fit_transform(otu_dataset, cache_dir=cache_dir)
+ else:
+ if format == "polars" and not is_polars_available():
+ pytest.skip("test requires polars")
+ data = handler.to_format(otu_dataset, format)
+ proc.fit_transform(
+ data,
+ input_columns=input_columns,
+ cache_dir=cache_dir,
+ )
diff --git a/tests/preprocessing/test_auto_plotting.py b/tests/preprocessing/test_auto_plotting.py
new file mode 100644
index 0000000..80d8903
--- /dev/null
+++ b/tests/preprocessing/test_auto_plotting.py
@@ -0,0 +1,58 @@
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+
+import biofit.config
+from biofit.auto import AutoPlotter
+from tests.utils import require_biosets, require_rpy2
+
+handler = DataHandler()
+
+
+pytestmark = pytest.mark.integration
+
+
+@require_biosets
+@require_rpy2
+@pytest.mark.parametrize("format", ["dataset", "dataset_cached"])
+def test_auto_plotting_otu(count_data, sample_metadata, format):
+ from biosets.features import Abundance
+ from datasets.features import Features
+
+ format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+
+ otu_dataset = DataHandler.to_bioset(otu_dataset)
+ otu_dataset._info.features = Features(
+ {
+ k: Abundance(dtype="int64") if k in X.columns else v
+ for k, v in otu_dataset._info.features.items()
+ }
+ )
+ proc = AutoPlotter.for_dataset("otu", path=cache_dir)
+ proc.plot(otu_dataset, path=cache_dir)
+
+
+@require_biosets
+@require_rpy2
+@pytest.mark.parametrize("format", ["dataset", "dataset_cached"])
+def test_auto_plotting_snp(binary_data, sample_metadata, format):
+ from biosets.features import Abundance
+ from datasets.features import Features
+
+ format = format.replace("_cached", "")
+ X, y = binary_data
+ snp_dataset = pd.concat([sample_metadata, X, y], axis=1)
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+
+ snp_dataset = DataHandler.to_bioset(snp_dataset)
+ snp_dataset._info.features = Features(
+ {
+ k: Abundance(dtype="int64") if k in X.columns else v
+ for k, v in snp_dataset._info.features.items()
+ }
+ )
+ proc = AutoPlotter.for_dataset("snp", path=cache_dir)
+ proc.plot(snp_dataset, path=cache_dir)
diff --git a/tests/preprocessing/test_auto_preprocessing.py b/tests/preprocessing/test_auto_preprocessing.py
new file mode 100644
index 0000000..d5ecd85
--- /dev/null
+++ b/tests/preprocessing/test_auto_preprocessing.py
@@ -0,0 +1,67 @@
+import biofit.config
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available, is_polars_available
+from biofit.auto import AutoPreprocessor
+
+from tests.utils import create_bioset
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "dataset",
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_auto_preprocessor(float_data, sample_metadata, format):
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+ X, y = float_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+ proc = AutoPreprocessor.for_dataset(
+ "snp", load_from_cache_file=load_from_cache_file
+ )
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ input_columns = list(X.columns)
+ if format == "numpy":
+ data = DataHandler.to_format(X, format)
+ proc.fit_transform(data, cache_dir=cache_dir)
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import GenomicVariant, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=GenomicVariant,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ proc.fit_transform(otu_dataset, cache_dir=cache_dir)
+ else:
+ if format == "polars" and not is_polars_available():
+ pytest.skip("test requires polars")
+ data = DataHandler.to_format(otu_dataset, format)
+ proc.fit_transform(
+ data,
+ input_columns=input_columns,
+ cache_dir=cache_dir,
+ )
+ # TODO: add assertions for the output
diff --git a/tests/preprocessing/test_css.py b/tests/preprocessing/test_css.py
new file mode 100644
index 0000000..a41818c
--- /dev/null
+++ b/tests/preprocessing/test_css.py
@@ -0,0 +1,95 @@
+import biofit.config
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import (
+ is_biosets_available,
+ is_polars_available,
+ is_rpy2_available,
+)
+from biofit.preprocessing import CumulativeSumScaler
+from biofit.preprocessing.scaling.css.plot_css import (
+ CumulativeSumScalerPlotter,
+ CumulativeSumScalerPlotterConfigForOTU,
+)
+
+from tests.utils import create_bioset
+
+handler = DataHandler()
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_css(float_data, sample_metadata, format):
+    load_from_cache_file = "_cached" in format
+    format = format.replace("_cached", "")
+    X, y = float_data
+    otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+    proc = CumulativeSumScaler(load_from_cache_file=load_from_cache_file)
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ if is_rpy2_available():
+ plotter = CumulativeSumScalerPlotter(
+ config=CumulativeSumScalerPlotterConfigForOTU()
+ )
+ else:
+ plotter = None
+ input_columns = list(X.columns)
+ if format == "numpy":
+ data = handler.to_format(X, format)
+ trans_data = proc.fit_transform(data, cache_dir=cache_dir)
+ target = handler.to_format(y, format)
+ if plotter is not None:
+ plotter.plot(x1=data, x2=trans_data, y1=target, y2=target)
+
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ trans_data = proc.fit_transform(otu_dataset, cache_dir=cache_dir)
+ if plotter is not None:
+ plotter.plot(otu_dataset, trans_data)
+ else:
+ if format == "polars" and not is_polars_available():
+ pytest.skip("test requires polars")
+ data = handler.to_format(otu_dataset, format)
+ trans_data = proc.fit_transform(
+ data,
+ input_columns=input_columns,
+ cache_dir=cache_dir,
+ )
+ if plotter is not None:
+ plotter.plot(
+ x1=data,
+ x2=trans_data,
+ input_columns1=input_columns,
+ input_columns2=input_columns,
+ label_name1="labels",
+ label_name2="labels",
+ )
+ # TODO: add assertions for the output
diff --git a/tests/preprocessing/test_label_binarizer.py b/tests/preprocessing/test_label_binarizer.py
new file mode 100644
index 0000000..c9169b5
--- /dev/null
+++ b/tests/preprocessing/test_label_binarizer.py
@@ -0,0 +1,86 @@
+from unittest.mock import patch
+
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available, is_polars_available
+
+import biofit.config
+from biofit.preprocessing import LabelBinarizer
+from biofit.processing import NonExistentCacheError
+from tests.utils import create_bioset
+
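+# Output expected from LabelBinarizer(negative_labels="a") on the three-class
+# labels below: "a" maps to 0; "b" and "c" map to 1.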
+EXPECTED_LABELS = [1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0]
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_label_binarizer(count_data_multi_class, sample_metadata, format):
+ should_load_from_cache = "_cached" in format
+ format = format.replace("_cached", "")
+
+ X, y = count_data_multi_class
+ y = pd.DataFrame({y.columns[0]: [["a", "b", "c"][i] for i in y.values.flatten()]})
+ otu_dataset_multi_class = pd.concat([sample_metadata, X, y], axis=1)
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ with patch.object(
+ NonExistentCacheError, "__init__", return_value=None
+ ) as mock_error:
+ if format == "numpy":
+ proc = LabelBinarizer(negative_labels="a")
+ data = DataHandler.to_format(y, format)
+ out = proc.fit_transform(data, cache_dir=cache_dir)
+ assert out.shape[1] == 1
+ assert out.flatten().tolist() == EXPECTED_LABELS
+
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, ClassLabel
+
+ otu_dataset_multi_class = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=ClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ proc = LabelBinarizer(negative_labels="a")
+ out = proc.fit_transform(otu_dataset_multi_class, cache_dir=cache_dir)
+ else:
+ if format == "polars" and not is_polars_available():
+ pytest.skip("test requires polars")
+ proc = LabelBinarizer(negative_labels="a")
+ data = DataHandler.to_format(otu_dataset_multi_class, format)
+ out = proc.fit_transform(
+ data, input_columns=y.columns[0], cache_dir=cache_dir
+ )
+ assert DataHandler.get_shape(out)[1] == 13
+ assert (
+ DataHandler.to_numpy(out, y.columns[0]).flatten().tolist()
+ == EXPECTED_LABELS
+ )
+ if should_load_from_cache:
+ # ensure that NonExistentCacheError was not raised
+ mock_error.assert_not_called()
+ else:
+ mock_error.assert_called_once()
diff --git a/tests/preprocessing/test_log.py b/tests/preprocessing/test_log.py
new file mode 100644
index 0000000..3bb5bba
--- /dev/null
+++ b/tests/preprocessing/test_log.py
@@ -0,0 +1,66 @@
+import biofit.config
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available, is_polars_available
+from biofit.preprocessing import LogTransformer
+
+from tests.utils import create_bioset
+
+handler = DataHandler()
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_log_transformer(count_data, sample_metadata, format):
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+ proc = LogTransformer(shift=1, load_from_cache_file=load_from_cache_file)
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ input_columns = list(X.columns)
+ if format == "numpy":
+ data = handler.to_format(X, format)
+ proc.fit_transform(data, cache_dir=cache_dir)
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ proc.fit_transform(otu_dataset, cache_dir=cache_dir)
+ else:
+ if format == "polars" and not is_polars_available():
+ pytest.skip("test requires polars")
+ data = handler.to_format(otu_dataset, format)
+ proc.fit_transform(
+ data,
+ input_columns=input_columns,
+ cache_dir=cache_dir,
+ )
diff --git a/tests/preprocessing/test_min_prevalence_features.py b/tests/preprocessing/test_min_prevalence_features.py
new file mode 100644
index 0000000..5d06c53
--- /dev/null
+++ b/tests/preprocessing/test_min_prevalence_features.py
@@ -0,0 +1,91 @@
+import biofit.config
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available, is_rpy2_available
+from biofit.preprocessing.feature_selection import (
+ MinPrevalenceFeatureSelector,
+ MinPrevalenceFeatureSelectorPlotter,
+ MinPrevalencePlotterConfigForOTU,
+)
+from tests.utils import create_bioset
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_min_missing(count_data, sample_metadata, format):
+ new_format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+    # non-zero prevalence per feature is 0.55, 0.35, 0.50, 0.70 and 0.30; three features clear the 0.4 cutoff
+ proc = MinPrevalenceFeatureSelector(min_prevalence=0.4, depth=0)
+ if is_rpy2_available():
+ plotter = MinPrevalenceFeatureSelectorPlotter(
+ config=MinPrevalencePlotterConfigForOTU()
+ )
+ else:
+ plotter = None
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ input_columns = list(X.columns)
+ if new_format == "numpy":
+ data = DataHandler.to_format(X, new_format)
+ out = proc.fit_transform(data, cache_dir=cache_dir)
+ if plotter is not None:
+ plotter.plot(x1=data, x2=out)
+ # has no metadata
+ assert DataHandler.get_shape(out)[1] == 3
+ else:
+ if new_format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+ # column does not need to be specified for dataset format
+ out = proc.fit_transform(otu_dataset, cache_dir=cache_dir)
+ if plotter is not None:
+ plotter.plot(x1=otu_dataset, x2=out)
+ else:
+ data = DataHandler.to_format(otu_dataset, new_format)
+ out = proc.fit_transform(
+ data,
+ input_columns=input_columns,
+ cache_dir=cache_dir,
+ )
+ input_columns2 = [
+ col for col in DataHandler.get_column_names(out) if col in input_columns
+ ]
+ if plotter is not None:
+ plotter.plot(
+ x1=data,
+ x2=out,
+ input_columns1=input_columns,
+ input_columns2=input_columns2,
+ )
+ # has metadata
+ assert DataHandler.get_shape(out)[1] == 11
+ # only calculated when not cached
+ if "cached" not in format:
+ assert proc.total_missing.tolist() == [9, 13, 10, 6, 14]
diff --git a/tests/preprocessing/test_min_prevalence_samples.py b/tests/preprocessing/test_min_prevalence_samples.py
new file mode 100644
index 0000000..ce5cb70
--- /dev/null
+++ b/tests/preprocessing/test_min_prevalence_samples.py
@@ -0,0 +1,68 @@
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available
+
+import biofit.config
+from biofit.preprocessing.filtering.min_prevalence_sample_filter import (
+ MinPrevalenceSampleFilter,
+)
+from tests.utils import create_bioset
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_max_missing_row(count_data, sample_metadata, format):
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+
+ proc = MinPrevalenceSampleFilter(
+ load_from_cache_file=load_from_cache_file, min_prevalence=0.5, depth=0
+ )
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ input_columns = list(X.columns)
+ if format == "numpy":
+ data = DataHandler.to_format(X, format)
+ out = proc.fit_transform(data, cache_dir=cache_dir)
+ else:
+ if format == "dataset":
+            if not is_biosets_available():
+                pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ out = proc.fit_transform(otu_dataset, cache_dir=cache_dir)
+ else:
+ data = DataHandler.to_format(otu_dataset, format)
+ out = proc.fit_transform(
+ data,
+ input_columns=input_columns,
+ cache_dir=cache_dir,
+ )
+ assert DataHandler.get_shape(out)[0] == 10
diff --git a/tests/preprocessing/test_missing_labels.py b/tests/preprocessing/test_missing_labels.py
new file mode 100644
index 0000000..7130ac9
--- /dev/null
+++ b/tests/preprocessing/test_missing_labels.py
@@ -0,0 +1,63 @@
+import biofit.config
+import pandas as pd
+import pytest
+from biocore.utils.import_util import is_biosets_available
+from biofit.preprocessing.filtering.missing_labels import MissingLabelsSampleFilter
+
+from tests.utils import create_bioset
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_missing_labels(count_data_missing_labels, sample_metadata, format):
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+ from biocore import DataHandler
+
+ handler = DataHandler()
+ X, y = count_data_missing_labels
+ otu_dataset_missing_labels = pd.concat([sample_metadata, X, y], axis=1)
+ proc = MissingLabelsSampleFilter(
+ load_from_cache_file=load_from_cache_file, missing_label="auto"
+ )
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ if format == "numpy":
+ data = handler.to_format(y, format)
+ out = proc.fit_transform(data, cache_dir=cache_dir)
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset_missing_labels = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+ # column does not need to be specified for dataset format
+ out = proc.fit_transform(otu_dataset_missing_labels, cache_dir=cache_dir)
+ else:
+ data = handler.to_format(otu_dataset_missing_labels, format)
+ out = proc.fit_transform(
+ data, input_columns=y.columns[0], cache_dir=cache_dir
+ )
+ assert handler.get_shape(out)[0] == 16
diff --git a/tests/preprocessing/test_pca.py b/tests/preprocessing/test_pca.py
new file mode 100644
index 0000000..01438b8
--- /dev/null
+++ b/tests/preprocessing/test_pca.py
@@ -0,0 +1,77 @@
+import biofit.config
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available
+from biofit.preprocessing import PCAFeatureExtractor
+
+from tests.utils import create_bioset
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_pca_feature_extractor(count_data, sample_metadata, format):
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+
+ proc = PCAFeatureExtractor(load_from_cache_file=load_from_cache_file)
+ expected = [-0.37224401, 1.02218086, -1.25448252, -0.22033803, -0.2220483]
+
+ input_columns = list(X.columns)
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ if format == "numpy":
+ data = DataHandler.to_format(X, format)
+ out = proc.fit_transform(data)
+ assert DataHandler.get_shape(out)[1] == 5
+ assert DataHandler.to_list(DataHandler.select_row(out, 0)) == pytest.approx(
+ expected, abs=1e-3
+ )
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ out = proc.fit_transform(otu_dataset)
+ else:
+ data = DataHandler.to_format(otu_dataset, format)
+ out = proc.fit_transform(
+ data, input_columns=input_columns, cache_dir=cache_dir
+ )
+        # 8 metadata/label columns + 5 PCA components
+ assert DataHandler.get_shape(out)[1] == 13
+ assert DataHandler.select_row(
+ DataHandler.to_numpy(
+ DataHandler.select_columns(
+ out, [f"pca_{i}" for i in range(len(input_columns))]
+ )
+ ),
+ 0,
+ ) == pytest.approx(expected, abs=1e-3)
diff --git a/tests/preprocessing/test_pcoa.py b/tests/preprocessing/test_pcoa.py
new file mode 100644
index 0000000..1609fdd
--- /dev/null
+++ b/tests/preprocessing/test_pcoa.py
@@ -0,0 +1,99 @@
+import biofit.config
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available
+from biofit.preprocessing import (
+ PCoAFeatureExtractor,
+)
+
+from tests.utils import create_bioset
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_pcoa_feature_extractor(count_data, sample_metadata, format):
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+
+ proc = PCoAFeatureExtractor(
+ correction="cailliez", load_from_cache_file=load_from_cache_file
+ )
+ input_columns = list(X.columns)
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ if format == "numpy":
+ data = DataHandler.to_format(X, format)
+ out = proc.fit_transform(data, cache_dir=cache_dir)[0]
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ out = proc.fit_transform(otu_dataset, cache_dir=cache_dir)
+ out = DataHandler.to_numpy(
+ DataHandler.drop_columns(
+ out, list(sample_metadata.columns) + list(y.columns)
+ )
+ )[0]
+ else:
+ data = DataHandler.to_format(otu_dataset, format)
+ out = proc.fit_transform(
+ data, input_columns=input_columns, cache_dir=cache_dir
+ )
+ out = DataHandler.to_numpy(
+ DataHandler.drop_columns(
+ out, list(sample_metadata.columns) + list(y.columns)
+ )
+ )[0]
+
+ assert out == pytest.approx(
+ [
+ 1.94211301e-01,
+ 6.14681757e-01,
+ -1.57956243e-01,
+ -5.19891966e-01,
+ 1.95193082e-01,
+ 1.75318366e-01,
+ 4.31596285e-02,
+ -2.40616503e-01,
+ 6.36366314e-02,
+ -1.08292981e-01,
+ 1.35341980e-02,
+ 6.58959194e-16,
+ 8.61517981e-16,
+ 3.13058241e-16,
+ 1.03875001e-02,
+ -1.46641936e-01,
+ 1.47741248e-01,
+ 7.09769578e-02,
+ ],
+ abs=1e-2,
+ )
diff --git a/tests/preprocessing/test_r_caller.py b/tests/preprocessing/test_r_caller.py
new file mode 100644
index 0000000..056c2a9
--- /dev/null
+++ b/tests/preprocessing/test_r_caller.py
@@ -0,0 +1,70 @@
+import os
+
+import biofit.config
+from biofit.integration.R.r_caller import RCaller
+from tests.utils import require_rpy2
+
+
+@require_rpy2
+def test_r_caller(count_data):
+ X, y = count_data
+ otu_dataset = X
+ X = otu_dataset.data.table
+
+ os.makedirs(os.path.join(biofit.config.BIOFIT_CACHE_HOME, "outputs"), exist_ok=True)
+ output_path = os.path.join(biofit.config.BIOFIT_CACHE_HOME, "outputs/histogram.png")
+
+ r = RCaller.from_script(
+ """
+ plot_histogram <- function(X, output_path, breaks=30) {
+ X <- as.data.frame(X)
+ library(ggplot2)
+ df <- data.frame(abundance=rowSums(X))
+ p <- ggplot(df, aes(x=abundance)) +
+ geom_histogram(bins=breaks) +
+ theme_minimal()
+ }
+ """
+ )
+ r.verify_r_dependencies(bioconductor_dependencies=["edgeR"], install_missing=True)
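+    # exit_code is extra R code run after plot_histogram; `results` is assumed to hold its return value (the ggplot object)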
+ func = r.get_method("plot_histogram", exit_code="ggsave(output_path, plot=results)")
+
+ func(X, output_path)
+
+
+@require_rpy2
+def test_r_caller_create_dataframe(count_data):
+ X, y = count_data
+ otu_dataset = X
+ X = otu_dataset.data.table
+
+ r = RCaller.from_script(
+ """
+ create_dataframe <- function(X) {
+ X <- as.data.frame(X)
+ return(X)
+ }
+ """
+ )
+ func = r.get_method("create_dataframe")
+
+ func(X)
+
+
+@require_rpy2
+def test_r_caller_create_otu_dataframe(count_data):
+ X, y = count_data
+ otu_dataset = X
+ X = otu_dataset.data.table
+
+ r = RCaller.from_script(
+ """
+ create_otu_dataframe <- function(X) {
+ X <- as.data.frame(X)
+ df <- data.frame(abundance=rowSums(X))
+ }
+ """
+ )
+ func = r.get_method("create_otu_dataframe")
+
+ func(X)
diff --git a/tests/preprocessing/test_tmm.py b/tests/preprocessing/test_tmm.py
new file mode 100644
index 0000000..b89cf00
--- /dev/null
+++ b/tests/preprocessing/test_tmm.py
@@ -0,0 +1,90 @@
+import biofit.config
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available, is_polars_available
+from biofit.preprocessing import TMMScaler
+
+from tests.utils import create_bioset, require_rpy2
+
+FORMATS = [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+]
+
+pytestmark = pytest.mark.unit
+
+
+@require_rpy2
+@pytest.mark.parametrize("format", FORMATS)
+def test_otu_tmm(count_data, sample_metadata, format):
+ format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+ proc = TMMScaler(install_missing=True)
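+    # install_missing=True lets the R backend install missing dependencies (TMM comes from edgeR) before scaling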
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ input_columns = list(X.columns)
+ if format == "numpy":
+ data = DataHandler.to_format(X, format)
+ trans_data = proc.fit_transform(data, cache_dir=cache_dir)
+ trans_data = pd.DataFrame(trans_data, columns=input_columns)
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ trans_data = proc.fit_transform(otu_dataset, cache_dir=cache_dir)
+ else:
+ if format == "polars" and not is_polars_available():
+ pytest.skip("test requires polars")
+ data = DataHandler.to_format(otu_dataset, format)
+ trans_data = proc.fit_transform(
+ data,
+ input_columns=input_columns,
+ cache_dir=cache_dir,
+ )
+ trans_data = DataHandler.to_pandas(trans_data)[input_columns]
+
+ expected_means = [
+ 16.29848284590347,
+ 16.109477689114623,
+ 16.1056032673168,
+ 16.26133423278604,
+ 16.365374835244474,
+ 16.52574914648284,
+ 16.03496933879711,
+ 16.333477800711826,
+ 16.329265689665576,
+ 16.414420051652822,
+ 16.327363474257353,
+ 16.3420904314305,
+ 16.51850994922633,
+ 16.372262272154607,
+ 16.643263063183195,
+ 16.429082863879305,
+ 16.62725573408411,
+ 16.268096219909324,
+ 16.187416726134824,
+ 16.238862083800324,
+ ]
+
+ assert trans_data.mean().tolist() == pytest.approx(expected_means, rel=1e-6)
diff --git a/tests/preprocessing/test_upsampling.py b/tests/preprocessing/test_upsampling.py
new file mode 100644
index 0000000..4c04ca8
--- /dev/null
+++ b/tests/preprocessing/test_upsampling.py
@@ -0,0 +1,70 @@
+import biofit.config
+import pandas as pd
+import pytest
+from biocore.utils.import_util import is_biosets_available
+from biofit.preprocessing.resampling.upsampling import (
+ UpSampler,
+ UpSamplerConfigForOTU,
+)
+
+from tests.utils import create_bioset
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_upsampling(count_data, sample_metadata, format):
+ from biocore import DataHandler
+
+ format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+ handler = DataHandler()
+
+ proc = UpSampler(config=UpSamplerConfigForOTU())
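+    # upsampling should balance the two classes; the 50/50 label ratio is asserted at the end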
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ if format == "numpy":
+ data = handler.to_format(X, format)
+ labs = handler.to_format(y, format)
+ out = proc.fit_transform(data, labs, cache_dir=cache_dir)
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ out = proc.fit_transform(otu_dataset)
+ else:
+ data = handler.to_format(otu_dataset, format)
+ out = proc.fit_transform(
+ data,
+ input_columns=list(X.columns),
+ target_column=y.columns[0],
+ cache_dir=cache_dir,
+ )
+ labs = handler.to_numpy(out, y.columns[0])
+ assert labs.sum() / len(labs) == 0.5
diff --git a/tests/stat/__init__.py b/tests/stat/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/stat/test_correlation.py b/tests/stat/test_correlation.py
new file mode 100644
index 0000000..dfe6976
--- /dev/null
+++ b/tests/stat/test_correlation.py
@@ -0,0 +1,89 @@
+import biofit.config
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available, is_polars_available
+from biofit.stat import CorrelationStat
+from biofit.stat.correlation.correlation import CorrelationStatConfig
+
+from tests.utils import create_bioset
+
+handler = DataHandler()
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_correlation(count_data, sample_metadata, format):
+ from biocore import DataHandler
+
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+ handler = DataHandler()
+
+ proc = CorrelationStat(
+ config=CorrelationStatConfig(load_from_cache_file=load_from_cache_file)
+ )
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ if format == "numpy":
+ data = handler.to_format(X, format)
+ labs = handler.to_format(y, format)
+ out = proc.fit_transform(data, labs, cache_dir=cache_dir)
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ out = proc.fit_transform(otu_dataset)
+ else:
+ if format == "polars" and not is_polars_available():
+ pytest.skip("test requires polars")
+ data = handler.to_format(otu_dataset, format)
+ out = proc.fit_transform(
+ data,
+ input_columns=list(X.columns),
+ target_column=y.columns[0],
+ cache_dir=cache_dir,
+ )
+ out = handler.to_list(out)
+ if format in ["arrow", "dataset"]:
+ out = [out[0] for out in out]
+
+ if isinstance(out[0], list):
+ out = out[0]
+
+ expected_vals = [
+ 0.12309149097933272,
+ 0.34874291623145787,
+ -1.3877787807814457e-17,
+ 0.09325048082403137,
+ 0.6081636405595372,
+ ]
+ assert out == pytest.approx(expected_vals, abs=1e-6)
diff --git a/tests/stat/test_distance.py b/tests/stat/test_distance.py
new file mode 100644
index 0000000..918d2d0
--- /dev/null
+++ b/tests/stat/test_distance.py
@@ -0,0 +1,100 @@
+import biofit.config
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available
+from biofit.stat.distance import DistanceStat
+from biofit.stat.distance.distance import DistanceStatConfigForOTU
+
+from tests.utils import create_bioset
+
+handler = DataHandler()
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_distance(count_data, sample_metadata, format):
+ from biocore import DataHandler
+
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+ handler = DataHandler()
+
+ proc = DistanceStat(
+ config=DistanceStatConfigForOTU(load_from_cache_file=load_from_cache_file)
+ )
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ if format == "numpy":
+ data = handler.to_format(X, format)
+ out = proc.fit_transform(data, cache_dir=cache_dir)
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ out = proc.fit_transform(otu_dataset)
+ else:
+ data = handler.to_format(otu_dataset, format)
+ out = proc.fit_transform(
+ data,
+ input_columns=list(X.columns),
+ cache_dir=cache_dir,
+ )
+ out = handler.to_list(out)
+ if format in ["arrow", "dataset"]:
+ out = [out[0] for out in out]
+
+ if isinstance(out[0], list):
+ out = out[0]
+
+ expected_vals = [
+ 0.0,
+ 1.0,
+ 1.0,
+ 0.3333333333333333,
+ 1.0,
+ 0.5,
+ 0.7777777777777778,
+ 0.6,
+ 0.42857142857142855,
+ 0.4,
+ 0.7142857142857143,
+ 0.75,
+ 0.25,
+ 0.42857142857142855,
+ 1.0,
+ 0.75,
+ 0.7777777777777778,
+ 1.0,
+ 0.7142857142857143,
+ 0.6,
+ ]
+ assert out == pytest.approx(expected_vals, abs=1e-6)
diff --git a/tests/stat/test_mean.py b/tests/stat/test_mean.py
new file mode 100644
index 0000000..9c8ef54
--- /dev/null
+++ b/tests/stat/test_mean.py
@@ -0,0 +1,157 @@
+import biofit.config
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available
+from biofit.stat.col_mean.col_mean import ColumnMeanStat, ColumnMeanStatConfigForOTU
+from biofit.stat.row_mean.row_mean import RowMeanStat, RowMeanStatConfigForOTU
+
+from tests.utils import create_bioset
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_col_mean(count_data, sample_metadata, format):
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+
+ proc = ColumnMeanStat(
+ config=ColumnMeanStatConfigForOTU(load_from_cache_file=load_from_cache_file)
+ )
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ if format == "numpy":
+ data = DataHandler.to_format(X, format)
+ out = proc.fit_transform(data, cache_dir=cache_dir)
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ out = proc.fit_transform(otu_dataset)
+ else:
+ data = DataHandler.to_format(otu_dataset, format)
+ out = proc.fit_transform(
+ data,
+ input_columns=list(X.columns),
+ cache_dir=cache_dir,
+ )
+ out = DataHandler.to_list(out)
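+    # unwrap single-element rows (arrow/dataset outputs) into a flat list of column means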
+ out = [out[0] if isinstance(out, list) and len(out) == 1 else out for out in out]
+
+ if isinstance(out[0], list):
+ out = out[0]
+
+ expected_values = [0.8, 0.6, 0.7, 0.75, 0.45]
+
+ assert out == pytest.approx(expected_values, abs=1e-6)
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_row_mean(count_data, sample_metadata, format):
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+
+ proc = RowMeanStat(
+ config=RowMeanStatConfigForOTU(load_from_cache_file=load_from_cache_file)
+ )
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ if format == "numpy":
+ data = DataHandler.to_format(X, format)
+ out = proc.fit_transform(data, cache_dir=cache_dir)
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ out = proc.fit_transform(otu_dataset)
+ else:
+ data = DataHandler.to_format(otu_dataset, format)
+ out = proc.fit_transform(
+ data,
+ input_columns=list(X.columns),
+ cache_dir=cache_dir,
+ )
+ out = DataHandler.to_list(out)
+ out = [out[0] if isinstance(out, list) and len(out) == 1 else out for out in out]
+
+ if isinstance(out[0], list):
+ out = out[0]
+
+ expected_values = [
+ 0.6,
+ 0.2,
+ 0.0,
+ 0.6,
+ 0.4,
+ 0.2,
+ 1.2,
+ 0.4,
+ 0.8,
+ 1.4,
+ 0.8,
+ 1.0,
+ 1.0,
+ 0.8,
+ 0.2,
+ 1.0,
+ 1.2,
+ 0.2,
+ 0.8,
+ 0.4,
+ ]
+ assert out == pytest.approx(expected_values, abs=1e-6)
diff --git a/tests/stat/test_missingness.py b/tests/stat/test_missingness.py
new file mode 100644
index 0000000..ea62504
--- /dev/null
+++ b/tests/stat/test_missingness.py
@@ -0,0 +1,150 @@
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available
+
+import biofit.config
+from biofit.stat import (
+ ColumnMissingnessStat,
+ RowMissingnessStat,
+)
+from biofit.stat.col_missingness.col_missingness import (
+ ColumnMissingnessStatConfigForOTU,
+)
+from biofit.stat.row_missingness.row_missingness import RowMissingnessStatConfigForOTU
+from tests.utils import create_bioset
+
+handler = DataHandler()
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_col_missingness(count_data, sample_metadata, format):
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+ proc = ColumnMissingnessStat(
+ config=ColumnMissingnessStatConfigForOTU(
+ depth=0, load_from_cache_file=load_from_cache_file
+ )
+ )
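+    # depth=0: counts of zero are presumably treated as missing by the stat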
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ input_columns = list(X.columns)
+ if format == "numpy":
+ data = handler.to_format(X, format)
+ out = proc.fit_transform(data, cache_dir=cache_dir)
+ else:
+ if format == "dataset":
+            if not is_biosets_available():
+                pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ out = proc.fit_transform(otu_dataset)
+ else:
+ data = handler.to_format(otu_dataset, format)
+ out = proc.fit_transform(
+ data,
+ input_columns=input_columns,
+ cache_dir=cache_dir,
+ )
+    out = handler.to_list(out)
+ out = [out[0] if isinstance(out, list) and len(out) == 1 else out for out in out]
+
+ if isinstance(out[0], list):
+ out = out[0]
+
+ expected_values = [9, 13, 10, 6, 14]
+
+ assert out == pytest.approx(expected_values, abs=1e-6)
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_row_missingness(count_data, sample_metadata, format):
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+ proc = RowMissingnessStat(
+ config=RowMissingnessStatConfigForOTU(
+ depth=0, load_from_cache_file=load_from_cache_file
+ )
+ )
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ if format == "numpy":
+ data = handler.to_format(X, format)
+ out = proc.fit_transform(data, cache_dir=cache_dir)
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ out = proc.fit_transform(otu_dataset)
+ else:
+ data = handler.to_format(otu_dataset, format)
+ out = proc.fit_transform(
+ data,
+ input_columns=list(X.columns),
+ cache_dir=cache_dir,
+ )
+ out = handler.to_list(out)
+ out = [out[0] if isinstance(out, list) and len(out) == 1 else out for out in out]
+
+ if isinstance(out[0], list):
+ out = out[0]
+
+ expected_values = [3, 4, 5, 2, 4, 4, 1, 4, 1, 0, 2, 2, 2, 2, 4, 1, 1, 4, 3, 3]
+ assert out == pytest.approx(expected_values, abs=1e-6)
diff --git a/tests/stat/test_sum.py b/tests/stat/test_sum.py
new file mode 100644
index 0000000..663bf11
--- /dev/null
+++ b/tests/stat/test_sum.py
@@ -0,0 +1,157 @@
+import biofit.config
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available
+from biofit.stat.col_sum.col_sum import ColumnSumStat, ColumnSumStatConfigForOTU
+from biofit.stat.row_sum.row_sum import RowSumStat, RowSumStatConfigForOTU
+
+from tests.utils import create_bioset
+
+pytestmark = pytest.mark.unit
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_col_sum(count_data, sample_metadata, format):
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+
+ proc = ColumnSumStat(
+ config=ColumnSumStatConfigForOTU(load_from_cache_file=load_from_cache_file)
+ )
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ if format == "numpy":
+ data = DataHandler.to_format(X, format)
+ out = proc.fit_transform(data, cache_dir=cache_dir)
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ out = proc.fit_transform(otu_dataset)
+ else:
+ data = DataHandler.to_format(otu_dataset, format)
+ out = proc.fit_transform(
+ data,
+ input_columns=list(X.columns),
+ cache_dir=cache_dir,
+ )
+ out = DataHandler.to_list(out)
+ out = [out[0] if isinstance(out, list) and len(out) == 1 else out for out in out]
+
+ if isinstance(out[0], list):
+ out = out[0]
+
+ expected_values = [16, 12, 14, 15, 9]
+
+ assert out == expected_values
+
+
+@pytest.mark.parametrize(
+ "format",
+ [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+ ],
+)
+def test_row_sum(count_data, sample_metadata, format):
+ load_from_cache_file = "_cached" in format
+ format = format.replace("_cached", "")
+ X, y = count_data
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+
+ proc = RowSumStat(
+ config=RowSumStatConfigForOTU(load_from_cache_file=load_from_cache_file)
+ )
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ if format == "numpy":
+ data = DataHandler.to_format(X, format)
+ out = proc.fit_transform(data, cache_dir=cache_dir)
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ # column does not need to be specified for dataset format
+ out = proc.fit_transform(otu_dataset)
+ else:
+ data = DataHandler.to_format(otu_dataset, format)
+ out = proc.fit_transform(
+ data,
+ input_columns=list(X.columns),
+ cache_dir=cache_dir,
+ )
+ out = DataHandler.to_list(out)
+ out = [out[0] if isinstance(out, list) and len(out) == 1 else out for out in out]
+
+ if isinstance(out[0], list):
+ out = out[0]
+
+ expected_values = [
+ 3.0,
+ 1.0,
+ 0.0,
+ 3.0,
+ 2.0,
+ 1.0,
+ 6.0,
+ 2.0,
+ 4.0,
+ 7.0,
+ 4.0,
+ 5.0,
+ 5.0,
+ 4.0,
+ 1.0,
+ 5.0,
+ 6.0,
+ 1.0,
+ 4.0,
+ 2.0,
+ ]
+ assert out == pytest.approx(expected_values, abs=1e-6)
diff --git a/tests/test_eval.py b/tests/test_eval.py
new file mode 100644
index 0000000..c3b4dff
--- /dev/null
+++ b/tests/test_eval.py
@@ -0,0 +1,692 @@
+import os
+
+import numpy as np
+import pandas as pd
+import pytest
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import is_biosets_available, is_lightgbm_available
+from biofit.eval import evaluate
+from biofit.models.lasso.lasso import LassoForClassification
+from biofit.models.lightgbm.lightgbm import LightGBMForClassification
+from biofit.models.random_forest.random_forest import RandomForestForClassification
+from biofit.utils.py_util import set_seed
+
+from tests.utils import create_bioset
+
+SUPPORTED_MODELS = [
+ "lightgbm",
+ "lasso",
+ "random_forest",
+ # "svm",
+]
+
+FORMATS = [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+ "pandas_shuffle",
+ "polars_shuffle",
+ "arrow_shuffle",
+ "dataset_shuffle",
+ "pandas_shuffle",
+ "polars_seperate",
+ "arrow_seperate",
+ "dataset_seperate",
+ "pandas_cached",
+ "polars_cached",
+ "numpy_cached",
+ "arrow_cached",
+ "dataset_cached",
+]
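+# suffixes: *_shuffle permutes columns to check order invariance, *_seperate fits on X and y passed separately, *_cached re-runs the same configuration a second time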
+
+EXPECTED_VALS = {
+ "binary_classification": {
+ "lightgbm": {
+ "expected_preds": [
+ [0.9535689791828371, 0.0464310208171629],
+ [0.9669298162770052, 0.03307018372299475],
+ [0.12790019069249647, 0.8720998093075035],
+ [0.03757079636968397, 0.962429203630316],
+ [0.1846289135978405, 0.8153710864021595],
+ [0.05110607323534777, 0.9488939267646522],
+ [0.8978708835102494, 0.1021291164897507],
+ [0.11111038569821852, 0.8888896143017815],
+ [0.44921187388480155, 0.5507881261151985],
+ [0.855423097498373, 0.144576902501627],
+ [0.8978708835102494, 0.1021291164897507],
+ [0.14105769325058248, 0.8589423067494175],
+ [0.08034206860142967, 0.9196579313985703],
+ [0.0517397023865368, 0.9482602976134632],
+ [0.7298271986950035, 0.27017280130499655],
+ [0.9540495562336279, 0.045950443766372154],
+ [0.9165930028884138, 0.08340699711158621],
+ [0.10931200167088684, 0.8906879983291132],
+ [0.9385507945257092, 0.061449205474290745],
+ [0.8948719161714075, 0.10512808382859253],
+ [0.8973383378571904, 0.10266166214280961],
+ [0.11639013253708874, 0.8836098674629113],
+ [0.956052390852163, 0.04394760914783706],
+ [0.9612289867007701, 0.038771013299229905],
+ [0.2755781540774306, 0.7244218459225694],
+ [0.947484475608107, 0.05251552439189295],
+ [0.9013976735797605, 0.09860232642023944],
+ [0.037727020474712436, 0.9622729795252876],
+ [0.9753215964143493, 0.02467840358565063],
+ [0.7162318325330577, 0.2837681674669423],
+ [0.10419336890743847, 0.8958066310925615],
+ [0.4044412343918701, 0.5955587656081299],
+ [0.9730656228707297, 0.026934377129270236],
+ [0.11318279020889954, 0.8868172097911005],
+ [0.036295649741235, 0.963704350258765],
+ [0.48345509058787506, 0.5165449094121249],
+ [0.9670021170214806, 0.03299788297851934],
+ [0.05377327588513203, 0.946226724114868],
+ [0.07216314186313177, 0.9278368581368682],
+ [0.8387195459495832, 0.1612804540504168],
+ [0.9685902354188991, 0.03140976458110088],
+ [0.10931200167088684, 0.8906879983291132],
+ [0.03757079636968397, 0.962429203630316],
+ [0.22858955836206873, 0.7714104416379313],
+ [0.972921875983505, 0.027078124016494988],
+ [0.7921953615847763, 0.20780463841522362],
+ [0.9675038597561006, 0.032496140243899345],
+ [0.11327252137195587, 0.8867274786280441],
+ [0.07512499995791921, 0.9248750000420808],
+ [0.035676949760042875, 0.9643230502399571],
+ [0.8834766324397875, 0.11652336756021252],
+ [0.9660311161320242, 0.033968883867975835],
+ [0.5879269539460168, 0.41207304605398315],
+ [0.7434865393955636, 0.2565134606044363],
+ [0.11382126481839294, 0.8861787351816071],
+ [0.17946079428899442, 0.8205392057110056],
+ [0.05801699845027686, 0.9419830015497231],
+ [0.26289780273304264, 0.7371021972669574],
+ [0.8444282511876444, 0.15557174881235555],
+ ],
+ "expected_metrics": {
+ "logloss": 0.16449683853287383,
+ "logloss_weighted": 0.16449683853287383,
+ "auc": np.float64(0.9956),
+ "f1": np.float64(0.98),
+ "accuracy": 0.98,
+ "balanced_accuracy": np.float64(0.98),
+ "precision": np.float64(0.98),
+ "recall": np.float64(0.98),
+ "specificity": np.float64(0.98),
+ },
+ },
+ "lasso": {
+ "expected_preds": [
+ [0.7816936873086878, 0.21830631269131215],
+ [0.31624928527411866, 0.6837507147258813],
+ [0.31175710811128376, 0.6882428918887162],
+ [0.3327395758214362, 0.6672604241785638],
+ [0.30591660281237865, 0.6940833971876214],
+ [0.27231695202635364, 0.7276830479736464],
+ [0.7881033635355769, 0.2118966364644231],
+ [0.23754846287095566, 0.7624515371290443],
+ [0.25642166051074533, 0.7435783394892547],
+ [0.32707066935266993, 0.6729293306473301],
+ [0.6891296880681221, 0.3108703119318779],
+ [0.5981991819110841, 0.40180081808891593],
+ [0.25115314289900026, 0.7488468571009997],
+ [0.5734621177488353, 0.42653788225116473],
+ [0.32864980374690667, 0.6713501962530933],
+ [0.6462823930480115, 0.3537176069519885],
+ [0.6384700780622443, 0.3615299219377557],
+ [0.3350062784077973, 0.6649937215922027],
+ [0.48056137630666895, 0.519438623693331],
+ [0.5153027484675532, 0.48469725153244675],
+ [0.3115503620089727, 0.6884496379910273],
+ [0.20750955945016547, 0.7924904405498345],
+ [0.4022895107130673, 0.5977104892869327],
+ [0.5570701111034967, 0.4429298888965033],
+ [0.46231437748106563, 0.5376856225189344],
+ [0.5336473794374168, 0.46635262056258314],
+ [0.8974911435749764, 0.10250885642502364],
+ [0.5698967574554465, 0.4301032425445534],
+ [0.8699974284228695, 0.13000257157713047],
+ [0.18044140781683704, 0.819558592183163],
+ [0.3065589974178936, 0.6934410025821064],
+ [0.7661916377320677, 0.23380836226793236],
+ [0.46470999173188965, 0.5352900082681104],
+ [0.30624315426509185, 0.6937568457349081],
+ [0.29512360607920796, 0.704876393920792],
+ [0.6455959800560582, 0.35440401994394183],
+ [0.6765435056409852, 0.3234564943590148],
+ [0.29782956675631833, 0.7021704332436817],
+ [0.21753403619181588, 0.7824659638081841],
+ [0.7948865980678086, 0.20511340193219144],
+ [0.824249495776864, 0.17575050422313598],
+ [0.39307615244097227, 0.6069238475590277],
+ [0.4115183697470993, 0.5884816302529007],
+ [0.5313818223941422, 0.4686181776058578],
+ [0.5972677846657815, 0.4027322153342186],
+ [0.5477712695165392, 0.4522287304834608],
+ [0.74534004641736, 0.25465995358264004],
+ [0.35323735627572894, 0.6467626437242711],
+ [0.4471767406356947, 0.5528232593643053],
+ [0.23015436437358872, 0.7698456356264113],
+ [0.6962745670430985, 0.3037254329569014],
+ [0.8336952862635421, 0.16630471373645794],
+ [0.42658011594906, 0.57341988405094],
+ [0.30674813687003766, 0.6932518631299623],
+ [0.24919868347662733, 0.7508013165233727],
+ [0.36639539420598566, 0.6336046057940143],
+ [0.6101778533199422, 0.3898221466800577],
+ [0.7757372692814383, 0.22426273071856173],
+ [0.5574921398591706, 0.44250786014082943],
+ ],
+ "expected_metrics": {
+ "logloss": 0.5865464808741376,
+ "logloss_weighted": 0.5865464808741376,
+ "auc": np.float64(0.7508),
+ "f1": np.float64(0.6990291262135923),
+ "accuracy": 0.69,
+ "balanced_accuracy": np.float64(0.69),
+ "precision": np.float64(0.6792452830188679),
+ "recall": np.float64(0.72),
+ "specificity": np.float64(0.66),
+ },
+ },
+ "random_forest": {
+ "expected_preds": [
+ [0.98, 0.02],
+ [0.95, 0.05],
+ [0.16, 0.84],
+ [0.03, 0.97],
+ [0.08, 0.92],
+ [0.0, 1.0],
+ [0.97, 0.03],
+ [0.02, 0.98],
+ [0.68, 0.32],
+ [0.8, 0.2],
+ [0.84, 0.16],
+ [0.11, 0.89],
+ [0.16, 0.84],
+ [0.06, 0.94],
+ [0.75, 0.25],
+ [0.89, 0.11],
+ [0.87, 0.13],
+ [0.08, 0.92],
+ [0.79, 0.21],
+ [0.91, 0.09],
+ [0.75, 0.25],
+ [0.07, 0.93],
+ [0.96, 0.04],
+ [0.96, 0.04],
+ [0.19, 0.81],
+ [0.91, 0.09],
+ [0.98, 0.02],
+ [0.05, 0.95],
+ [0.99, 0.01],
+ [0.64, 0.36],
+ [0.2, 0.8],
+ [0.28, 0.72],
+ [0.9, 0.1],
+ [0.17, 0.83],
+ [0.0, 1.0],
+ [0.35, 0.65],
+ [0.95, 0.05],
+ [0.01, 0.99],
+ [0.11, 0.89],
+ [0.86, 0.14],
+ [0.95, 0.05],
+ [0.05, 0.95],
+ [0.04, 0.96],
+ [0.23, 0.77],
+ [0.93, 0.07],
+ [0.82, 0.18],
+ [0.91, 0.09],
+ [0.18, 0.82],
+ [0.11, 0.89],
+ [0.01, 0.99],
+ [0.96, 0.04],
+ [0.99, 0.01],
+ [0.83, 0.17],
+ [0.6, 0.4],
+ [0.06, 0.94],
+ [0.3, 0.7],
+ [0.06, 0.94],
+ [0.3, 0.7],
+ [0.87, 0.13],
+ ],
+ "expected_metrics": {
+ "logloss": 0.13027149699327187,
+ "logloss_weighted": 0.13027149699327187,
+ "auc": np.float64(1.0),
+ "f1": np.float64(1.0),
+ "accuracy": 1.0,
+ "balanced_accuracy": np.float64(1.0),
+ "precision": np.float64(1.0),
+ "recall": np.float64(1.0),
+ "specificity": np.float64(1.0),
+ },
+ },
+ },
+ "multi_class_classification": {
+ "lightgbm": {
+ "expected_preds": [
+ [0.593146054103032, 0.3544863375913485, 0.05236760830561952],
+ [0.0162947366442909, 0.08966806138760489, 0.8940372019681042],
+ [0.7314477097580847, 0.08987852215806323, 0.1786737680838522],
+ [0.9513440600973485, 0.02558212227442155, 0.02307381762822989],
+ [0.08076421066646375, 0.8471446489996138, 0.07209114033392237],
+ [0.024941657594269584, 0.8349486385255367, 0.14010970388019373],
+ [0.8080818936953171, 0.11152838686995886, 0.08038971943472402],
+ [0.3745713232222257, 0.1005204106433833, 0.524908266134391],
+ [0.03084559346731874, 0.7175825421822467, 0.2515718643504347],
+ [0.9264354653778086, 0.040485517926432894, 0.03307901669575851],
+ [0.26169193369944177, 0.6432778275261304, 0.09503023877442779],
+ [0.05390047046164553, 0.12321047791027544, 0.822889051628079],
+ [0.9893123044084093, 0.0033716628402068512, 0.007316032751383729],
+ [0.04989245219405661, 0.5164105654452859, 0.43369698236065757],
+ [0.8965695794118023, 0.04655120246031727, 0.05687921812788048],
+ [0.09683252885215721, 0.8737826902595416, 0.02938478088830118],
+ [0.018979028711999975, 0.3854551955203541, 0.595565775767646],
+ [0.015064114607673476, 0.06491993148499259, 0.9200159539073338],
+ [0.2501257382333796, 0.07557633061268863, 0.6742979311539318],
+ [0.01260103830566784, 0.950366380134126, 0.03703258156020627],
+ [0.057080075231720706, 0.8878873343593096, 0.05503259040896966],
+ [0.14887379039370122, 0.7141974476107982, 0.13692876199550066],
+ [0.022279213512249003, 0.9685747767923893, 0.009146009695361583],
+ [0.36701992326651967, 0.4247852501339898, 0.20819482659949057],
+ [0.9432880138214005, 0.02093608544659706, 0.03577590073200245],
+ [0.21445840357507803, 0.12505753739912728, 0.6604840590257947],
+ [0.9850150858045772, 0.005457083883926245, 0.009527830311496462],
+ [0.008149987927395205, 0.9720361230965696, 0.019813888976035123],
+ [0.5787640351932142, 0.3587723056494069, 0.06246365915737892],
+ [0.058630546812436284, 0.11549220701190936, 0.8258772461756545],
+ [0.9415956567242569, 0.0017096868247902064, 0.056694656450952965],
+ [0.8401430809052793, 0.005159964181618216, 0.1546969549131024],
+ [0.3373342066378342, 0.05918842349172757, 0.6034773698704382],
+ [0.011080447400212695, 0.9116047621274613, 0.07731479047232594],
+ [0.3693202163184197, 0.26089826875819366, 0.3697815149233866],
+ [0.09434666971449106, 0.2436016225971083, 0.6620517076884006],
+ [0.012044889962176637, 0.15740806448518946, 0.8305470455526338],
+ [0.9625788280022749, 0.006700416704897087, 0.030720755292827962],
+ [0.0314960379961621, 0.018228459472588898, 0.9502755025312489],
+ [0.07180604640327914, 0.21617724025134022, 0.7120167133453806],
+ [0.9909239929523648, 0.002762349978571398, 0.006313657069063688],
+ [0.8467504500331269, 0.014471453132731507, 0.13877809683414177],
+ [0.17742037837331487, 0.7239156681031946, 0.0986639535234905],
+ [0.963211960496519, 0.006581222131420159, 0.030206817372060775],
+ [0.9867858584643275, 0.002130638765929315, 0.01108350276974304],
+ [0.011223976533424277, 0.9234131157628203, 0.06536290770375551],
+ [0.013413129940471473, 0.9751884803477013, 0.011398389711827031],
+ [0.8890369114051427, 0.057398540656221825, 0.053564547938635396],
+ [0.7393617652660448, 0.1762448283804414, 0.08439340635351382],
+ [0.0026998424032764717, 0.8961457621237923, 0.10115439547293113],
+ [0.04922197852764869, 0.9004333884777317, 0.050344632994619644],
+ [0.009004793995118152, 0.8620266528963505, 0.12896855310853148],
+ [0.004300474674760467, 0.9410398929508006, 0.05465963237443907],
+ [0.016551281611928693, 0.07217063275491635, 0.911278085633155],
+ [0.22490788573213896, 0.7139179765022333, 0.06117413776562761],
+ [0.9519439180445933, 0.010465037470341117, 0.037591044485065436],
+ [0.9358052195418675, 0.011508047597263087, 0.05268673286086945],
+ [0.07397151864688836, 0.08594553167721526, 0.8400829496758964],
+ [0.9486043978269892, 0.019309808345722143, 0.03208579382728866],
+ ],
+ "expected_metrics": {
+ "mlogloss": 0.1996362230208467,
+ "mlogloss_weighted": 0.19985753361762523,
+ "accuracy": 0.97,
+ "balanced_accuracy": np.float64(0.9702911467617351),
+ "f1_macro": np.float64(0.969998492386552),
+ "f1_micro": np.float64(0.97),
+ "f1_weighted": np.float64(0.9699954771596563),
+ "precision_macro": np.float64(0.9705882352941178),
+ "precision_micro": np.float64(0.97),
+ "precision_weighted": np.float64(0.9708823529411765),
+ "recall_macro": np.float64(0.9702911467617351),
+ "recall_micro": np.float64(0.97),
+ "recall_weighted": np.float64(0.97),
+ "specificity_macro": np.float64(0.9850746268656717),
+ "specificity_weighted": np.float64(0.9852238805970149),
+ },
+ },
+ "lasso": {
+ "expected_preds": [
+ [0.727717303959388, 0.23690670804022768, 0.03537598800038429],
+ [0.18692727398823294, 0.11157108480807601, 0.7015016412036912],
+ [0.34309525593719364, 0.1765917153709202, 0.48031302869188613],
+ [0.5538856681121752, 0.16189352573200325, 0.2842208061558217],
+ [0.6211387488472769, 0.14887956386778411, 0.22998168728493895],
+ [0.08136652003632926, 0.7405643083523318, 0.178069171611339],
+ [0.759360225300843, 0.13038683194736123, 0.11025294275179577],
+ [0.7439536052085204, 0.11809572286257745, 0.13795067192890215],
+ [0.23857065761662077, 0.41094065879844277, 0.35048868358493646],
+ [0.7701695624738958, 0.16052746174977564, 0.06930297577632857],
+ [0.2933931613744791, 0.60926996361502, 0.09733687501050099],
+ [0.147242819833966, 0.4228473626130698, 0.42990981755296426],
+ [0.9621888073691544, 0.0008644135412697095, 0.036946779089575964],
+ [0.01705558890825475, 0.11664143509420513, 0.8663029759975402],
+ [0.3893194843740159, 0.3280460247431107, 0.28263449088287346],
+ [0.39448406752573534, 0.25421966960835124, 0.3512962628659135],
+ [0.07706673345013908, 0.792292432198681, 0.13064083435118],
+ [0.2569720803873218, 0.022213793505098937, 0.7208141261075793],
+ [0.4051741803656034, 0.30944209380006654, 0.2853837258343301],
+ [0.03527785507203288, 0.8355004491958612, 0.12922169573210607],
+ [0.13010121295474616, 0.5270237025408331, 0.3428750845044207],
+ [0.48615577104471547, 0.28824697580116326, 0.22559725315412132],
+ [0.06686616038986118, 0.9040761159727345, 0.029057723637404262],
+ [0.22016199312436038, 0.18946190384464945, 0.5903761030309902],
+ [0.887645228005912, 0.05455414993455491, 0.057800622059533145],
+ [0.10492818962033762, 0.23688245995737642, 0.658189350422286],
+ [0.6561865190807977, 0.06952320441570069, 0.27429027650350163],
+ [0.056170413275787576, 0.7340325149666672, 0.20979707175754517],
+ [0.3467095278403483, 0.49654341645234773, 0.156747055707304],
+ [0.14269187448939388, 0.29671434149381504, 0.5605937840167912],
+ [0.8597096595076208, 0.003057035249782796, 0.13723330524259655],
+ [0.5412733542712894, 0.06458744307929185, 0.39413920264941876],
+ [0.768266090194226, 0.04860121636266081, 0.1831326934431133],
+ [0.2116624126426145, 0.5105721747127452, 0.27776541264464033],
+ [0.13321208490679914, 0.6243477238858032, 0.24244019120739774],
+ [0.049943991269237, 0.42310719519145185, 0.5269488135393112],
+ [0.15853416173853788, 0.3785164681414171, 0.4629493701200449],
+ [0.9085156781663416, 0.03340196488561009, 0.05808235694804848],
+ [0.16236725942071278, 0.14571741821409478, 0.6919153223651924],
+ [0.11524093905630521, 0.20831478654951238, 0.6764442743941824],
+ [0.847819055621279, 0.009838453305903732, 0.1423424910728173],
+ [0.09772981091930481, 0.012648051184822403, 0.8896221378958727],
+ [0.4151595507653043, 0.45189425384898724, 0.1329461953857084],
+ [0.8300230582936061, 0.003552826486441063, 0.16642411521995287],
+ [0.9509389107752908, 0.0020869188052771133, 0.04697417041943202],
+ [0.11616015578071352, 0.7491230372424216, 0.13471680697686486],
+ [0.42501082628592196, 0.34458353378516193, 0.230405639928916],
+ [0.3747387868476141, 0.4898130370047601, 0.1354481761476259],
+ [0.25984451234985434, 0.39972345466155934, 0.3404320329885864],
+ [0.0670850052509371, 0.7474444842869399, 0.185470510462123],
+ [0.3192185535429023, 0.4248999983362068, 0.25588144812089086],
+ [0.08374375640315046, 0.6594847821394243, 0.25677146145742524],
+ [0.035162467740786844, 0.8535509674940139, 0.11128656476519908],
+ [0.2098548230221486, 0.3014144284213309, 0.4887307485565204],
+ [0.027996101265055064, 0.8653440168301139, 0.10665988190483106],
+ [0.636044154536507, 0.050639795480214446, 0.31331604998327856],
+ [0.744216538349922, 0.08092876869163469, 0.17485469295844333],
+ [0.1723251482898169, 0.014796167284032756, 0.8128786844261504],
+ [0.8461973185412733, 0.10200854616835917, 0.05179413529036747],
+ ],
+ "expected_metrics": {
+ "mlogloss": 0.7093916124597448,
+ "mlogloss_weighted": 0.7094112118294117,
+ "accuracy": 0.73,
+ "balanced_accuracy": np.float64(0.7302436125965537),
+ "f1_macro": np.float64(0.7299465240641712),
+ "f1_micro": np.float64(0.73),
+ "f1_weighted": np.float64(0.7299197860962567),
+ "precision_macro": np.float64(0.7305194805194807),
+ "precision_micro": np.float64(0.73),
+ "precision_weighted": np.float64(0.7307142857142856),
+ "recall_macro": np.float64(0.7302436125965537),
+ "recall_micro": np.float64(0.73),
+ "recall_weighted": np.float64(0.73),
+ "specificity_macro": np.float64(0.8650685964118799),
+ "specificity_weighted": np.float64(0.8652057892356401),
+ },
+ },
+ "random_forest": {
+ "expected_preds": [
+ [0.87, 0.13, 0.0],
+ [0.0, 0.06, 0.94],
+ [0.75, 0.09, 0.16],
+ [0.79, 0.16, 0.05],
+ [0.22, 0.77, 0.01],
+ [0.07, 0.74, 0.19],
+ [0.87, 0.11, 0.02],
+ [0.27, 0.01, 0.72],
+ [0.02, 0.83, 0.15],
+ [0.92, 0.06, 0.02],
+ [0.28, 0.64, 0.08],
+ [0.09, 0.12, 0.79],
+ [0.98, 0.01, 0.01],
+ [0.05, 0.7, 0.25],
+ [0.69, 0.23, 0.08],
+ [0.1, 0.86, 0.04],
+ [0.06, 0.24, 0.7],
+ [0.07, 0.04, 0.89],
+ [0.24, 0.09, 0.67],
+ [0.06, 0.93, 0.01],
+ [0.15, 0.81, 0.04],
+ [0.15, 0.65, 0.2],
+ [0.03, 0.94, 0.03],
+ [0.24, 0.62, 0.14],
+ [0.96, 0.01, 0.03],
+ [0.1, 0.04, 0.86],
+ [0.94, 0.05, 0.01],
+ [0.04, 0.92, 0.04],
+ [0.76, 0.18, 0.06],
+ [0.01, 0.07, 0.92],
+ [0.79, 0.07, 0.14],
+ [0.88, 0.05, 0.07],
+ [0.27, 0.09, 0.64],
+ [0.02, 0.95, 0.03],
+ [0.69, 0.15, 0.16],
+ [0.01, 0.2, 0.79],
+ [0.04, 0.14, 0.82],
+ [0.94, 0.03, 0.03],
+ [0.03, 0.06, 0.91],
+ [0.14, 0.14, 0.72],
+ [0.98, 0.01, 0.01],
+ [0.79, 0.02, 0.19],
+ [0.2, 0.76, 0.04],
+ [0.9, 0.03, 0.07],
+ [0.99, 0.0, 0.01],
+ [0.01, 0.94, 0.05],
+ [0.02, 0.89, 0.09],
+ [0.83, 0.04, 0.13],
+ [0.7, 0.16, 0.14],
+ [0.01, 0.91, 0.08],
+ [0.06, 0.84, 0.1],
+ [0.05, 0.86, 0.09],
+ [0.02, 0.83, 0.15],
+ [0.0, 0.14, 0.86],
+ [0.13, 0.72, 0.15],
+ [0.81, 0.06, 0.13],
+ [0.85, 0.02, 0.13],
+ [0.06, 0.01, 0.93],
+ [0.92, 0.04, 0.04],
+ ],
+ "expected_metrics": {
+ "mlogloss": 0.18626589694153936,
+ "mlogloss_weighted": 0.1865393454742004,
+ "accuracy": 1.0,
+ "balanced_accuracy": np.float64(1.0),
+ "f1_macro": np.float64(1.0),
+ "f1_micro": np.float64(1.0),
+ "f1_weighted": np.float64(1.0),
+ "precision_macro": np.float64(1.0),
+ "precision_micro": np.float64(1.0),
+ "precision_weighted": np.float64(1.0),
+ "recall_macro": np.float64(1.0),
+ "recall_micro": np.float64(1.0),
+ "recall_weighted": np.float64(1.0),
+ "specificity_macro": np.float64(1.0),
+ "specificity_weighted": np.float64(1.0),
+ },
+ },
+ },
+}
+
+
+pytestmark = pytest.mark.integration
+
+
+@pytest.mark.parametrize("format", FORMATS)
+@pytest.mark.parametrize("model_name", SUPPORTED_MODELS)
+def test_eval_binary_classification(
+ classification_data, sample_metadata, model_name, format
+):
+ run_test_eval(
+ classification_data,
+ sample_metadata,
+ model_name,
+ format,
+ "binary_classification",
+ )
+
+
+@pytest.mark.parametrize("format", FORMATS)
+@pytest.mark.parametrize("model_name", SUPPORTED_MODELS)
+def test_eval_multi_class_classification(
+ classification_data_multi_class, sample_metadata, model_name, format
+):
+ run_test_eval(
+ classification_data_multi_class,
+ sample_metadata,
+ model_name,
+ format,
+ "multi_class_classification",
+ )
+
+
+def run_test_eval(classification_data, sample_metadata, model_name, format, task):
+ format = format.replace("_cached", "")
+ X, y = classification_data
+ _run_eval(model_name, X, y, sample_metadata, format, task)
+
+
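+# _run_eval exercises several input layouts, driven by the FORMATS
+# parametrization defined earlier in this file:
+#   "numpy"          -> X and y passed as arrays
+#   "<fmt>_seperate" -> X and y passed separately in the given format
+#   "<fmt>_shuffle"  -> column order permuted between fit and evaluate
+#   otherwise        -> a single combined table or dataset
+# Every variant must reproduce the predictions and metrics in EXPECTED_VALS.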
+def _run_eval(model_name, X, y, sample_metadata, format, task):
+ set_seed(42)
+ if model_name == "lightgbm":
+ if not is_lightgbm_available():
+ pytest.skip("test requires lightgbm")
+ model = LightGBMForClassification()
+ elif model_name == "lasso":
+ model = LassoForClassification()
+ elif model_name == "random_forest":
+ model = RandomForestForClassification()
+
+ otu_dataset = pd.concat([sample_metadata, X, y], axis=1)
+
+ input_columns = list(X.columns)
+ target_column = list(y.columns)[0]
+ if format == "numpy":
+ data = DataHandler.to_numpy(X)
+ target = DataHandler.to_numpy(y)
+ model.fit(data, target)
+ preds, metrics = evaluate(model, data, target)
+ elif "seperate" in format:
+ new_format = format.replace("_seperate", "")
+ data = X
+ target = y
+ if new_format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ model.fit(data, target)
+ preds, metrics = evaluate(model, otu_dataset, target)
+ else:
+ data = DataHandler.to_format(data, new_format)
+ target = DataHandler.to_format(target, new_format)
+ model.fit(data, target)
+ preds, metrics = evaluate(model, data, target)
+ elif "shuffle" in format:
+        # Verify that the output is identical when the columns are shuffled
+ new_format = format.replace("_shuffle", "")
+ columns = DataHandler.get_column_names(otu_dataset)
+ np.random.seed(42)
+ columns = np.random.permutation(columns).tolist()
+ if new_format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ data = otu_dataset
+ model.fit(data)
+ data = DataHandler.select_columns(data, columns)
+ preds, metrics = evaluate(
+ model,
+ data,
+ output_dir=os.path.dirname(otu_dataset.cache_files[0]["filename"]),
+ )
+ else:
+ data = DataHandler.to_format(otu_dataset, new_format)
+ model.fit(
+ data,
+ target_column=target_column,
+ input_columns=input_columns,
+ )
+ data = DataHandler.select_columns(data, columns)
+ columns = [col for col in columns if col in input_columns]
+            # add extra constant (zero-valued) columns
+ data_dim = DataHandler.get_shape(data)
+ for i in range(5):
+ data = DataHandler.append_column(
+ data, f"extra_col_{i}", np.array([0] * data_dim[0], dtype=np.int32)
+ )
+ columns += [f"extra_col_{i}" for i in range(5)]
+ preds, metrics = evaluate(
+ model,
+ data,
+ target_columns=target_column,
+ input_columns=input_columns,
+ )
+ else:
+ if format == "dataset":
+ if not is_biosets_available():
+ pytest.skip("test requires biosets")
+ from biosets.features import Abundance, BinClassLabel
+
+ otu_dataset = create_bioset(
+ X=X,
+ y=y,
+ sample_metadata=sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ data = otu_dataset
+ model.fit(data)
+ preds, metrics = evaluate(
+ model,
+ data,
+ output_dir=os.path.dirname(otu_dataset.cache_files[0]["filename"]),
+ )
+ else:
+ data = DataHandler.to_format(otu_dataset, format)
+ model.fit(
+ data,
+ target_column=target_column,
+ input_columns=input_columns,
+ )
+ preds, metrics = evaluate(
+ model,
+ data,
+ target_columns=target_column,
+ input_columns=input_columns,
+ )
+
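+    # Compare only the first 59 rows, matching the number of rows stored in
+    # each expected_preds table above.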
+ ev = EXPECTED_VALS[task][model_name]
+ for i, pred in enumerate(preds.values.tolist()[:59]):
+ val = ev["expected_preds"][i]
+ assert pred == pytest.approx(
+ val, abs=0.1
+ ), f"Prediction {pred} does not match the expected value {val}."
+
+ for metric, expected_value in ev["expected_metrics"].items():
+ assert metrics[metric] == pytest.approx(
+ expected_value, abs=1e-2
+ ), f"Metric {metric} does not match the expected value."
diff --git a/tests/test_patcher.py b/tests/test_patcher.py
new file mode 100644
index 0000000..4bc16fc
--- /dev/null
+++ b/tests/test_patcher.py
@@ -0,0 +1,309 @@
+import importlib
+import os
+import shutil
+import sys
+import tempfile
+import threading
+import unittest
+from unittest.mock import patch
+
+import biofit.config
+import pytest
+from biofit.integration.patcher import (
+ Patcher,
+ PatcherConfig,
+ get_hashed_patches,
+)
+
+SOURCE_MODULE_CODE = """
+SOURCE_CONSTANT = 42
+CONSTANT_WITH_SAME_NAME = 42
+
+def source_function():
+ return "new function"
+
+class SourceClass:
+ def method(self):
+ return "new method"
+
+class ClassWithSameName:
+ def method(self):
+ return "new method"
+
+def function_with_same_name():
+ return "new function"
+"""
+
+TARGET_MODULE_CODE = """
+TARGET_CONSTANT = 24
+CONSTANT_WITH_SAME_NAME = 24
+def target_function():
+ return "original function"
+
+class TargetClass:
+ def method(self):
+ return "original method"
+
+class ClassWithSameName:
+ def method(self):
+ return "original method"
+
+def function_with_same_name():
+ return "original function"
+"""
+
+
+pytestmark = pytest.mark.unit
+
+
+class SingletonMeta(type):
+ _instances = {}
+
+ def __call__(cls, *args, **kwargs):
+ if cls not in cls._instances:
+ cls._instances[cls] = super(SingletonMeta, cls).__call__(*args, **kwargs)
+ return cls._instances[cls]
+
+
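+# SingletonMeta caches one instance per class, so constructing the same class
+# twice returns the same object. Illustrative sketch (hypothetical name):
+#
+#   class Foo(metaclass=SingletonMeta): ...
+#   assert Foo() is Foo()
+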
+# Define MockPatcherConfig and MockPatcher outside of setUp
+class MockPatcherConfig(PatcherConfig, metaclass=SingletonMeta):
+ def __init__(self, target_module, source_module):
+ self.root = importlib.import_module(target_module)
+ self.patch_targets = [self.root]
+
+ self.module_paths = [source_module]
+
+ self.patches = get_hashed_patches(module_paths=self.module_paths)
+ super().__init__(
+ patches=self.patches,
+ root=self.root,
+ patch_targets=self.patch_targets,
+ )
+
+ def get_mock_patches(self, entity_paths):
+ patches = {}
+ for path in entity_paths:
+ module_name, attr_name = path.rsplit(".", 1)
+ module = importlib.import_module(module_name)
+ obj = getattr(module, attr_name)
+ patches[attr_name] = (obj, module_name)
+ return patches
+
+
+class MockPatcher(Patcher, metaclass=SingletonMeta):
+ def __init__(self, target_module, source_module):
+ config = MockPatcherConfig(target_module, source_module)
+ super().__init__(config=config)
+
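+# Typical flow exercised by the tests below: entering the patcher swaps the
+# matching names in target_module for source_module's versions, and exiting
+# restores the originals.
+#
+#   with MockPatcher("target_module", "source_module"):
+#       target_module.function_with_same_name()  # -> "new function"
+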
+
+class TestPatcher(unittest.TestCase):
+ def setUp(self):
+ # Create temporary directories for source and target modules
+ self.source_module_dir = tempfile.mkdtemp()
+ self.target_module_dir = tempfile.mkdtemp()
+ self.cache_dir = biofit.config.BIOFIT_PATCHES_CACHE
+
+ # Names of the modules
+ self.source_module_name = "source_module"
+
+ # Write the source module code to a file
+ self.source_module_path = os.path.join(
+ self.source_module_dir, f"{self.source_module_name}.py"
+ )
+ with open(self.source_module_path, "w") as f:
+ f.write(SOURCE_MODULE_CODE)
+
+ self.target_module_name = "target_module"
+ # Write the target module code to a file
+ self.target_module_path = os.path.join(
+ self.target_module_dir, f"{self.target_module_name}.py"
+ )
+ with open(self.target_module_path, "w") as f:
+ f.write(TARGET_MODULE_CODE)
+
+ # Add both module directories to sys.path
+ sys.path.insert(0, self.target_module_dir)
+ sys.path.insert(0, self.source_module_dir)
+
+ # Import the source and target modules
+ self.source_module = importlib.import_module(self.source_module_name)
+ self.target_module = importlib.import_module(self.target_module_name)
+
+ # Initialize MockPatcher
+ self.mock_patcher = MockPatcher(
+ target_module=self.target_module_name,
+ source_module=self.source_module_name,
+ )
+
+ def tearDown(self):
+ # Clean up the temporary directories and sys.path
+ if os.path.exists(self.source_module_dir):
+ shutil.rmtree(self.source_module_dir)
+ if os.path.exists(self.target_module_dir):
+ shutil.rmtree(self.target_module_dir)
+ if self.cache_dir.exists():
+ shutil.rmtree(self.cache_dir)
+ sys.path.remove(self.source_module_dir)
+ sys.path.remove(self.target_module_dir)
+
+ # Remove the source and target modules from sys.modules
+ if self.source_module_name in sys.modules:
+ del sys.modules[self.source_module_name]
+ if self.target_module_name in sys.modules:
+ del sys.modules[self.target_module_name]
+
+ # Clear singleton instances
+ SingletonMeta._instances.clear()
+
+ def test_mock_patcher_apply_patches(self):
+ # Apply patches using MockPatcher and test if the functions are patched
+ with self.mock_patcher:
+ self.assertEqual(
+ self.target_module.function_with_same_name(), "new function"
+ )
+ class_with_same_name = self.target_module.ClassWithSameName()
+ self.assertEqual(class_with_same_name.method(), "new method")
+ constant_with_same_name = self.target_module.CONSTANT_WITH_SAME_NAME
+ self.assertEqual(constant_with_same_name, 42)
+
+ # After exiting the context manager, the original functions should be restored
+ self.assertEqual(
+ self.target_module.function_with_same_name(), "original function"
+ )
+ class_with_same_name = self.target_module.ClassWithSameName()
+ self.assertEqual(class_with_same_name.method(), "original method")
+ constant_with_same_name = self.target_module.CONSTANT_WITH_SAME_NAME
+ self.assertEqual(constant_with_same_name, 24)
+
+ def test_mock_patcher_singleton(self):
+ # Ensure that MockPatcher is a singleton
+ another_patcher = MockPatcher(
+            target_module=self.target_module_name,
+            source_module=self.source_module_name,
+ )
+ self.assertIs(self.mock_patcher, another_patcher)
+
+ def test_mock_patcher_config_singleton(self):
+ # Ensure that MockPatcherConfig is a singleton
+ config1 = self.mock_patcher.config
+ config2 = self.mock_patcher.config
+ self.assertIs(config1, config2)
+
+ def test_mock_patcher_revert_patches(self):
+ # Apply patches and revert them manually
+ self.mock_patcher._apply_patches()
+ self.assertEqual(self.target_module.function_with_same_name(), "new function")
+
+ self.mock_patcher._revert_patches()
+ self.assertEqual(
+ self.target_module.function_with_same_name(), "original function"
+ )
+
+ def test_mock_patcher_clear_cache(self):
+ # Test clearing the cache
+ self.mock_patcher.config.clear_cache()
+ # ensure that the cache directory is empty
+ self.assertFalse(self.cache_dir.exists())
+
+ def test_mock_patcher_exception_handling(self):
+ # Test that exceptions within the context manager do not leave patches applied
+ try:
+ with self.mock_patcher:
+ self.assertEqual(
+ self.target_module.function_with_same_name(), "new function"
+ )
+ raise ValueError("Test exception")
+ except ValueError:
+ pass
+
+ # Ensure that after the exception, the original functions are restored
+ self.assertEqual(
+ self.target_module.function_with_same_name(), "original function"
+ )
+
+ def test_mock_patcher_thread_safety(self):
+ # Test that the MockPatcher is thread-safe
+ def thread_target():
+ with self.mock_patcher:
+ self.assertEqual(
+ self.target_module.function_with_same_name(), "new function"
+ )
+ self.assertEqual(
+ self.target_module.ClassWithSameName().method(), "new method"
+ )
+
+ thread = threading.Thread(target=thread_target)
+ thread.start()
+ thread.join()
+
+ # Ensure that after the thread has finished, the original functions are restored
+ self.assertEqual(
+ self.target_module.function_with_same_name(), "original function"
+ )
+ self.assertEqual(
+ self.target_module.ClassWithSameName().method(), "original method"
+ )
+
+ def test_mock_patcher_context_manager_nesting(self):
+ # Test nesting the MockPatcher context manager
+ with self.mock_patcher:
+ self.assertEqual(
+ self.target_module.function_with_same_name(), "new function"
+ )
+ with self.mock_patcher:
+ self.assertEqual(
+ self.target_module.function_with_same_name(), "new function"
+ )
+ self.assertEqual(
+ self.target_module.function_with_same_name(), "new function"
+ )
+ self.assertEqual(
+ self.target_module.function_with_same_name(), "original function"
+ )
+
+ def test_mock_patcher_invalid_patch(self):
+ # Test handling of an invalid patch (e.g., invalid source module)
+        invalid_entity_paths = ["nonexistent_module.nonexistent_function"]
+        # Capture the test case: inside the config class below, `self` is the
+        # config instance, which has no `target_module` attribute.
+        test_case = self
+
+        class InvalidMockPatcherConfig(PatcherConfig, metaclass=SingletonMeta):
+            def __init__(self):
+                self.patches = self.get_mock_patches(invalid_entity_paths)
+                self.root = test_case.target_module
+                self.patch_targets = [test_case.target_module]
+ super().__init__(
+ patches=self.patches,
+ root=self.root,
+ patch_targets=self.patch_targets,
+ )
+
+ def get_mock_patches(self, entity_paths):
+ patches = {}
+ for path in entity_paths:
+ try:
+ module_name, attr_name = path.rsplit(".", 1)
+ module = importlib.import_module(module_name)
+ obj = getattr(module, attr_name)
+ patches[attr_name] = (obj, module_name)
+                    except (ImportError, AttributeError) as e:
+                        raise ValueError(f"Invalid entity path: {path}") from e
+ return patches
+
+ with self.assertRaises(ValueError):
+ InvalidMockPatcherConfig()
+
+ def test_mock_patcher_with_no_cache(self):
+ # Test the MockPatcher when caching is disabled
+ with patch("biofit.integration.patcher.is_caching_enabled", return_value=False):
+ mock_patcher_no_cache = MockPatcher(
+                target_module=self.target_module_name,
+                source_module=self.source_module_name,
+ )
+
+ with mock_patcher_no_cache:
+ self.assertEqual(
+ self.target_module.function_with_same_name(), "new function"
+ )
+
+ self.assertEqual(
+ self.target_module.function_with_same_name(), "original function"
+ )
diff --git a/tests/test_plot_feature_importances.py b/tests/test_plot_feature_importances.py
new file mode 100644
index 0000000..2d27201
--- /dev/null
+++ b/tests/test_plot_feature_importances.py
@@ -0,0 +1,350 @@
+import os
+import shutil
+import unittest
+
+import biofit.config
+import numpy as np
+import pandas as pd
+import pytest
+from biofit.visualization.feature_importance import FeatureImportancePlotter
+
+from tests.utils import require_rpy2
+
+pytestmark = pytest.mark.unit
+
+
+@require_rpy2
+class TestFeatureImportancePlotter(unittest.TestCase):
+ @pytest.fixture(autouse=True)
+ def inject_fixtures(self, count_data, sample_metadata, feature_metadata):
+ self.X, self.y = count_data
+ self.feature_metadata = feature_metadata
+ self.sample_metadata = sample_metadata
+ self.column_names = self.X.columns
+ self.metadata_columns = self.sample_metadata.columns
+ self.data = pd.concat([sample_metadata, *count_data], axis=1)
+
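+        # Build a mock importance table: one row per feature and one score
+        # column per run; random integers stand in for real importance values.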
+ self.feature_importances = pd.DataFrame(
+ {
+ "features": self.column_names,
+ "importances_1": np.random.randint(0, 255, len(self.column_names)),
+ "importances_2": np.random.randint(0, 255, len(self.column_names)),
+ }
+ )
+
+ def setUp(self):
+ self.plotter = FeatureImportancePlotter()
+
+ def tearDown(self):
+ # Clean up the cache directory
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ if os.path.exists(cache_dir):
+ shutil.rmtree(cache_dir)
+
+ def _assert_plot_output(self, output_path):
+ # Retrieve the image paths from output_path
+ image_paths = [os.path.join(output_path, f) for f in os.listdir(output_path)]
+
+ # Check that image_paths is not empty
+ self.assertIsNotNone(image_paths)
+ self.assertTrue(len(image_paths) > 0)
+
+ # Check that the plot files were created and are not empty
+ for img_path in image_paths:
+ self.assertTrue(os.path.exists(img_path))
+ self.assertGreater(os.path.getsize(img_path), 0)
+
+ def test_plot_with_valid_params(self):
+ self.plotter.set_params(
+ plot_top=10,
+ dat_log="log2_1p",
+ show_column_names=True,
+ scale_legend_title="Abundance",
+ column_title="Samples",
+ row_title="Features",
+ plot_title="Feature Importances",
+ feature_meta_name=["features", "my_metadata_str"],
+ )
+
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ self.plotter.plot(
+ X=self.X,
+ y=self.y,
+ sample_metadata=self.sample_metadata,
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=False,
+ )
+
+ self._assert_plot_output(output_path)
+
+ def test_plot_minimal_params(self):
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ self.plotter.plot(
+ self.data,
+ feature_importances=self.feature_importances,
+ path=output_path,
+ show=False,
+ )
+
+ self._assert_plot_output(output_path)
+
+ def test_plot_missing_feature_meta_name(self):
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ self.plotter.set_params(
+ feature_meta_name="non_existent_column",
+ )
+
+ with self.assertRaises(ValueError) as context:
+ self.plotter.plot(
+ X=self.X,
+ y=self.y,
+ sample_metadata=self.sample_metadata,
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=False,
+ )
+        self.assertIn(
+            "Feature metadata columns ['non_existent_column'] not found in "
+            "feature metadata.",
+            str(context.exception),
+        )
+
+ def test_plot_with_invalid_dat_log(self):
+ self.plotter.set_params(dat_log="invalid_log")
+
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ # Assuming that invalid dat_log values are handled without raising an error
+ self.plotter.plot(
+ X=self.X,
+ y=self.y,
+ sample_metadata=self.sample_metadata,
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=False,
+ )
+
+ self._assert_plot_output(output_path)
+
+ def test_plot_with_plot_top_equal_to_num_features(self):
+ num_features = len(self.column_names)
+ self.plotter.set_params(plot_top=num_features)
+
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ self.plotter.plot(
+ X=self.X,
+ y=self.y,
+ sample_metadata=self.sample_metadata,
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=False,
+ )
+
+ self._assert_plot_output(output_path)
+
+ def test_plot_without_dat_log(self):
+ self.plotter.set_params(dat_log=None)
+
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ self.plotter.plot(
+ X=self.X,
+ y=self.y,
+ sample_metadata=self.sample_metadata,
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=False,
+ )
+
+ self._assert_plot_output(output_path)
+
+ def test_plot_with_show_column_names(self):
+ self.plotter.set_params(show_column_names=True)
+
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ self.plotter.plot(
+ X=self.X,
+ y=self.y,
+ sample_metadata=self.sample_metadata,
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=False,
+ )
+
+ self._assert_plot_output(output_path)
+
+ def test_plot_with_custom_scale_legend_title(self):
+ self.plotter.set_params(scale_legend_title="Custom Legend Title")
+
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ self.plotter.plot(
+ X=self.X,
+ y=self.y,
+ sample_metadata=self.sample_metadata,
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=False,
+ )
+
+ self._assert_plot_output(output_path)
+
+ def test_plot_with_input_columns(self):
+ # Select a subset of columns
+ selected_columns = self.column_names[:5]
+ self.plotter.set_params(input_columns=selected_columns)
+
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ self.plotter.plot(
+ X=self.X,
+ y=self.y,
+ sample_metadata=self.sample_metadata,
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=False,
+ )
+
+ self._assert_plot_output(output_path)
+
+ def test_plot_with_invalid_input_columns(self):
+ invalid_columns = ["non_existent_column"]
+
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ with self.assertRaises(ValueError) as context:
+ self.plotter.plot(
+ X=self.X,
+ y=self.y,
+ input_columns=invalid_columns,
+ sample_metadata=self.sample_metadata,
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=False,
+ )
+ self.assertIn(
+ "Columns {'non_existent_column'} not found in input dataset",
+ str(context.exception),
+ )
+
+ def test_plot_with_custom_feature_column(self):
+ # Rename the 'features' column in feature_importances
+ self.feature_importances.rename(
+ columns={"features": "feature_names"}, inplace=True
+ )
+ self.plotter.set_params(feature_column="feature_names")
+
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ self.plotter.plot(
+ X=self.X,
+ y=self.y,
+ sample_metadata=self.sample_metadata,
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=False,
+ )
+
+ self._assert_plot_output(output_path)
+
+ def test_plot_with_invalid_feature_column(self):
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ with self.assertRaises(ValueError) as context:
+ self.plotter.plot(
+ X=self.X,
+ y=self.y,
+ feature_column="invalid_column",
+ sample_metadata=self.sample_metadata,
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=False,
+ )
+        self.assertIn(
+            "Feature column 'invalid_column' not found in feature "
+            "importances. Please provide the column name found in both "
+            "feature importances and feature metadata (if provided).",
+            str(context.exception),
+        )
+
+ def test_plot_with_no_sample_metadata(self):
+ self.plotter.set_params(sample_metadata_columns=None)
+
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ self.plotter.plot(
+ X=self.X,
+ y=self.y,
+ sample_metadata=None, # sample_metadata is None
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=False,
+ )
+
+ self._assert_plot_output(output_path)
+
+ def test_plot_with_empty_X(self):
+ empty_X = pd.DataFrame()
+
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ with self.assertRaises(AssertionError) as context:
+ self.plotter.plot(
+ X=empty_X,
+ y=self.y,
+ sample_metadata=self.sample_metadata,
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=False,
+ )
+ self.assertIn("Input data has no rows", str(context.exception))
+
+ def test_plot_output_directory(self):
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ self.plotter.plot(
+ X=self.X,
+ y=self.y,
+ sample_metadata=self.sample_metadata,
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=False,
+ )
+
+ self._assert_plot_output(output_path)
+
+ def test_plot_show_option(self):
+ output_path = biofit.config.BIOFIT_CACHE_HOME
+
+ # Since we cannot check the display in a unit test, we ensure it runs without error
+ self.plotter.plot(
+ X=self.X,
+ y=self.y,
+ sample_metadata=self.sample_metadata,
+ feature_importances=self.feature_importances,
+ feature_metadata=self.feature_metadata,
+ path=output_path,
+ show=True,
+ )
+
+ self._assert_plot_output(output_path)
diff --git a/tests/test_plot_sample_metadata.py b/tests/test_plot_sample_metadata.py
new file mode 100644
index 0000000..3b3069b
--- /dev/null
+++ b/tests/test_plot_sample_metadata.py
@@ -0,0 +1,8 @@
+# from biofit import SampleMetadataPlotter
+#
+#
+# def test_plot_sample_metadata(camda_dataset):
+# plotter = SampleMetadataPlotter()
+# plotter.plot(camda_dataset["train"])
+# plotter = SampleMetadataPlotter()
+# plotter.plot(camda_dataset["train"])
diff --git a/tests/test_plotting_utils.py b/tests/test_plotting_utils.py
new file mode 100644
index 0000000..2cf2bbb
--- /dev/null
+++ b/tests/test_plotting_utils.py
@@ -0,0 +1,697 @@
+import shutil
+import unittest
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pytest
+from biocore import DataHandler
+from biocore.utils.import_util import is_biosets_available, is_polars_available
+
+import biofit.config
+from biofit.utils.py_util import set_seed
+from biofit.visualization import (
+ generate_comparison_histogram,
+ plot_correlation,
+)
+from biofit.visualization.plotting_utils import (
+ generate_violin,
+ plot_dimension_reduction,
+ plot_feature_importance,
+ plot_sample_metadata,
+)
+from tests.utils import create_bioset, require_biosets, require_polars, require_rpy2
+
+SUPPORTED_MODELS = ["lightgbm", "lasso", "random_forest"] # , "svm"]
+
+FORMATS = [
+ "pandas",
+ "polars",
+ "numpy",
+ "arrow",
+ "dataset",
+]
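+
+# The "dataset" format relies on the optional biosets package: the fixture
+# below only builds self.dataset_all when is_biosets_available() is True, and
+# the tests that use it are guarded accordingly.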
+
+
+# These really should be integration tests, but unit tests for the individual
+# plotting classes have not been implemented yet
+pytestmark = pytest.mark.unit
+
+
+class TestPlottingUtils(unittest.TestCase):
+ @pytest.fixture(autouse=True)
+ def inject_fixtures(self, count_data, sample_metadata):
+ self.sample_metadata = sample_metadata
+ self.metadata_columns = list(sample_metadata.columns)
+ self.X, self.y = count_data
+ self.input_columns = list(self.X.columns)
+ self.target_column = list(self.y.columns)[0]
+ self.data = pd.concat([self.sample_metadata, self.X, self.y], axis=1)
+
+ self.feature_importances = pd.DataFrame(
+ {
+ "features": self.input_columns,
+ "importances_1": np.random.rand(len(self.input_columns)),
+ "importances_2": np.random.rand(len(self.input_columns)),
+ }
+ )
+ # Convert the dataset to various formats
+ if is_biosets_available():
+ from biosets.features import Abundance, BinClassLabel
+
+ self.dataset_all = create_bioset(
+ X=self.X,
+ y=self.y,
+ sample_metadata=self.sample_metadata,
+ with_feature_metadata=True,
+ feature_type=Abundance,
+ target_type=BinClassLabel,
+ )
+
+ self.pandas_all = self.data
+ if is_polars_available():
+ self.polars_all = DataHandler.to_polars(self.data)
+ self.arrow_all = DataHandler.to_arrow(self.data)
+
+ # Extract data and target in various formats
+ self.numpy_data = DataHandler.to_numpy(self.X)
+ self.numpy_target = DataHandler.to_numpy(self.y)
+
+ self.pandas_data = self.X
+ self.pandas_target = self.y
+
+ self.polars_data = DataHandler.to_polars(self.X)
+ self.polars_target = DataHandler.to_polars(self.y)
+
+ self.arrow_data = DataHandler.to_arrow(self.X)
+ self.arrow_target = DataHandler.to_arrow(self.y)
+
+ self.pandas_sample_metadata = self.sample_metadata
+ self.polars_sample_metadata = DataHandler.to_polars(self.sample_metadata)
+ self.arrow_sample_metadata = DataHandler.to_arrow(self.sample_metadata)
+
+ def setUp(self):
+ set_seed(42)
+ self.output_dir = Path(biofit.config.BIOFIT_CACHE_HOME)
+ self.output_dir.mkdir(exist_ok=True, parents=True)
+
+ def tearDown(self):
+ if self.output_dir.exists():
+ shutil.rmtree(self.output_dir)
+
+ def _assert_plot_outputs(self):
+ assert self.output_dir.is_dir()
+ assert len([f for f in self.output_dir.iterdir() if f.is_file()]) > 0
+
+ # For dataset format
+ @require_rpy2
+ def test_feature_importance_dataset(self):
+ plot_feature_importance(
+ self.feature_importances,
+ X=self.data,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_feature_importance_pandas(self):
+ plot_feature_importance(
+ self.feature_importances,
+ X=self.pandas_all,
+ input_columns=self.input_columns,
+ target_columns=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_feature_importance_pandas_separate(self):
+ plot_feature_importance(
+ self.feature_importances,
+ X=self.pandas_data,
+ y=self.pandas_target,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_feature_importance_polars(self):
+ plot_feature_importance(
+ self.feature_importances,
+ X=self.polars_all,
+ input_columns=self.input_columns,
+ target_columns=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_feature_importance_polars_separate(self):
+ plot_feature_importance(
+ self.feature_importances,
+ X=self.polars_data,
+ y=self.polars_target,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_feature_importance_arrow(self):
+ plot_feature_importance(
+ self.feature_importances,
+ X=self.arrow_all,
+ input_columns=self.input_columns,
+ target_columns=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_feature_importance_arrow_separate(self):
+ plot_feature_importance(
+ self.feature_importances,
+ X=self.arrow_data,
+ y=self.arrow_target,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_feature_importance_numpy(self):
+ feature_importances = pd.DataFrame(
+ {
+ "features": [f"col_{i}" for i in range(self.numpy_data.shape[1])],
+ "importances_1": np.random.rand(len(self.input_columns)),
+ "importances_2": np.random.rand(len(self.input_columns)),
+ }
+ )
+
+ plot_feature_importance(
+ feature_importances,
+ X=self.numpy_data,
+ y=self.numpy_target,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_sample_metadata_pandas(self):
+ input_columns = self.metadata_columns
+ plot_sample_metadata(
+ self.pandas_all,
+ sample_metadata_columns=input_columns,
+ outcome_column=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_sample_metadata_polars(self):
+ input_columns = self.metadata_columns
+ plot_sample_metadata(
+ self.polars_all,
+ sample_metadata_columns=input_columns,
+ outcome_column=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_sample_metadata_arrow(self):
+ input_columns = self.metadata_columns
+ plot_sample_metadata(
+ self.arrow_all,
+ sample_metadata_columns=input_columns,
+ outcome_column=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_sample_metadata_dataset(self):
+ plot_sample_metadata(
+ self.data,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ # Test methods for violin plot
+ @require_rpy2
+ def test_violin_plot_numpy(self):
+ generate_violin(
+ self.numpy_data,
+ self.numpy_target,
+ xlab="test",
+ ylab="test",
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_violin_plot_pandas(self):
+ generate_violin(
+ self.pandas_all,
+ xlab="test",
+ ylab="test",
+ column=self.input_columns[0],
+ label_name=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_violin_plot_polars(self):
+ generate_violin(
+ self.polars_all,
+ xlab="test",
+ ylab="test",
+ column=self.input_columns[0],
+ label_name=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_violin_plot_arrow(self):
+ generate_violin(
+ self.arrow_all,
+ xlab="test",
+ ylab="test",
+ column=self.input_columns[0],
+ label_name=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_violin_plot_dataset(self):
+ generate_violin(
+ self.data,
+ xlab="test",
+ ylab="test",
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ # Test methods for correlation plot
+ @require_rpy2
+ def test_correlation_plot_numpy(self):
+ plot_correlation(
+ self.numpy_data,
+ self.numpy_target,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_correlation_plot_pandas(self):
+ plot_correlation(
+ self.pandas_all,
+ input_columns=self.input_columns,
+ target_column=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_correlation_plot_polars(self):
+ plot_correlation(
+ self.polars_all,
+ input_columns=self.input_columns,
+ target_column=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_correlation_plot_arrow(self):
+ plot_correlation(
+ self.arrow_all,
+ input_columns=self.input_columns,
+ target_column=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_correlation_plot_dataset(self):
+ plot_correlation(
+ self.data,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ # Test methods for comparison histogram
+ @require_rpy2
+ def test_comparison_histogram_numpy(self):
+ generate_comparison_histogram(
+ self.numpy_data[:, 0],
+ self.numpy_data[:, 1],
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_comparison_histogram_pandas(self):
+ input_columns = self.input_columns
+ generate_comparison_histogram(
+ self.pandas_all,
+ column1=input_columns[0],
+ column2=input_columns[1],
+ subplot_title1=input_columns[0],
+ subplot_title2=input_columns[1],
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_comparison_histogram_polars(self):
+ input_columns = self.input_columns
+ generate_comparison_histogram(
+ self.polars_all,
+ column1=input_columns[0],
+ column2=input_columns[1],
+ subplot_title1=input_columns[0],
+ subplot_title2=input_columns[1],
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_comparison_histogram_arrow(self):
+ input_columns = self.input_columns
+ generate_comparison_histogram(
+ self.arrow_all,
+ column1=input_columns[0],
+ column2=input_columns[1],
+ subplot_title1=input_columns[0],
+ subplot_title2=input_columns[1],
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_rpy2
+ def test_comparison_histogram_dataset(self):
+ input_columns = self.input_columns
+ generate_comparison_histogram(
+ self.data,
+ column1=input_columns[0],
+ column2=input_columns[1],
+ subplot_title1=input_columns[0],
+ subplot_title2=input_columns[1],
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ # Test methods for PCoA plot
+ def test_pcoa_plot_numpy(self):
+ plot_dimension_reduction(
+ self.numpy_data,
+ labels=self.numpy_target,
+ method="pcoa",
+ method_kwargs={"correction": "cailliez"},
+ n_components=3,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ def test_pcoa_plot_pandas(self):
+ plot_dimension_reduction(
+ self.pandas_all,
+ method="pcoa",
+ method_kwargs={"correction": "cailliez"},
+ n_components=3,
+ input_columns=self.input_columns,
+ label_column=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_polars
+ def test_pcoa_plot_polars(self):
+ plot_dimension_reduction(
+ self.polars_all,
+ method="pcoa",
+ method_kwargs={"correction": "cailliez"},
+ n_components=3,
+ input_columns=self.input_columns,
+ label_column=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ def test_pcoa_plot_arrow(self):
+ plot_dimension_reduction(
+ self.arrow_all,
+ method="pcoa",
+ method_kwargs={"correction": "cailliez"},
+ n_components=3,
+ input_columns=self.input_columns,
+ label_column=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+    @require_biosets  # self.dataset_all exists only when biosets is installed
+    @require_rpy2
+ def test_pcoa_plot_dataset(self):
+ plot_dimension_reduction(
+ self.dataset_all,
+ method="pcoa",
+ method_kwargs={"correction": "cailliez"},
+ n_components=3,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ # Test methods for PCA plot
+ def test_pca_plot_numpy(self):
+ plot_dimension_reduction(
+ self.numpy_data,
+ labels=self.numpy_target,
+ method="pca",
+ n_components=3,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ def test_pca_plot_pandas(self):
+ plot_dimension_reduction(
+ self.pandas_all,
+ method="pca",
+ n_components=3,
+ input_columns=self.input_columns,
+ label_column=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ def test_pca_plot_polars(self):
+ plot_dimension_reduction(
+ self.polars_all,
+ method="pca",
+ n_components=3,
+ input_columns=self.input_columns,
+ label_column=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ def test_pca_plot_arrow(self):
+ plot_dimension_reduction(
+ self.arrow_all,
+ method="pca",
+ n_components=3,
+ input_columns=self.input_columns,
+ label_column=self.target_column,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ @require_biosets
+ def test_pca_plot_dataset(self):
+ plot_dimension_reduction(
+ self.dataset_all,
+ method="pca",
+ n_components=3,
+ output_dir=self.output_dir.as_posix(),
+ )
+ self._assert_plot_outputs()
+
+ # # Test methods for feature distribution plot
+ # def test_feature_distribution_plot_numpy(self):
+ # plot_feature_distribution(
+ # self.numpy_data,
+ # self.numpy_target,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_feature_distribution_plot_pandas(self):
+ # plot_feature_distribution(
+ # self.pandas_all,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_feature_distribution_plot_polars(self):
+ # plot_feature_distribution(
+ # self.polars_all,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_feature_distribution_plot_arrow(self):
+ # plot_feature_distribution(
+ # self.arrow_all,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_feature_distribution_plot_dataset(self):
+ # plot_feature_distribution(
+ # self.otu_dataset,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # # Test methods for compare feature distributions plot
+ # def test_compare_feature_distribution_plot_numpy(self):
+ # compare_feature_distributions(
+ # self.numpy_data,
+ # self.numpy_data,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_compare_feature_distribution_plot_pandas(self):
+ # compare_feature_distributions(
+ # self.pandas_all,
+ # self.pandas_all,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_compare_feature_distribution_plot_polars(self):
+ # compare_feature_distributions(
+ # self.polars_all,
+ # self.polars_all,
+ # columns1=self.input_columns,
+ # columns2=self.input_columns,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_compare_feature_distribution_plot_arrow(self):
+ # compare_feature_distributions(
+ # self.arrow_all,
+ # self.arrow_all,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_compare_feature_distribution_plot_dataset(self):
+ # compare_feature_distributions(
+ # self.otu_dataset,
+ # self.otu_dataset,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # # Test methods for barplot
+ # def test_barplot_numpy(self):
+ # generate_barplot(
+ # self.numpy_data[:, 0],
+ # self.numpy_target,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_barplot_pandas(self):
+ # generate_barplot(
+ # self.pandas_all,
+ # value_name=self.input_columns[0],
+ # label_name=self.target_column,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_barplot_polars(self):
+ # generate_barplot(
+ # self.polars_all,
+ # value_name=self.input_columns[0],
+ # label_name=self.target_column,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_barplot_arrow(self):
+ # generate_barplot(
+ # self.arrow_all,
+ # value_name=self.input_columns[0],
+ # label_name=self.target_column,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_barplot_dataset(self):
+ # generate_barplot(
+ # self.otu_dataset,
+ # value_name=self.input_columns[0],
+ # label_name=self.target_column,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # # Test methods for scatterplot
+ # def test_scatterplot_numpy(self):
+ # generate_scatterplot(
+ # x=self.numpy_data[:, 0],
+ # y=self.numpy_data[:, 1],
+ # group=self.numpy_target,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_scatterplot_pandas(self):
+ # input_columns = self.input_columns
+ # generate_scatterplot(
+ # x=self.pandas_all,
+ # xdata=input_columns[0],
+ # ydata=input_columns[1],
+ # groupby=self.target_column,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_scatterplot_polars(self):
+ # input_columns = self.input_columns
+ # generate_scatterplot(
+ # x=self.polars_all,
+ # xdata=input_columns[0],
+ # ydata=input_columns[1],
+ # groupby=self.target_column,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_scatterplot_arrow(self):
+ # input_columns = self.input_columns
+ # generate_scatterplot(
+ # x=self.arrow_all,
+ # xdata=input_columns[0],
+ # ydata=input_columns[1],
+ # groupby=self.target_column,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
+ #
+ # def test_scatterplot_dataset(self):
+ # input_columns = self.input_columns
+ # generate_scatterplot(
+ # x=self.otu_dataset,
+ # xdata=input_columns[0],
+ # ydata=input_columns[1],
+ # groupby=self.target_column,
+ # output_dir=self.output_dir.as_posix(),
+ # )
+ # self._assert_plot_outputs()
diff --git a/tests/test_processing.py b/tests/test_processing.py
new file mode 100644
index 0000000..1e30cf5
--- /dev/null
+++ b/tests/test_processing.py
@@ -0,0 +1,2266 @@
+import copy
+import inspect
+import shutil
+import unittest
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import List, Optional, Type, Union
+from unittest.mock import patch
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+import pytest
+from biocore import DataHandler
+
+import biofit.config
+from biofit.integration.biosets import get_feature
+from biofit.processing import (
+ BaseProcessor,
+ NonExistentCacheError,
+ ProcessorConfig,
+ SelectedColumnTypes,
+ sync_backup_config,
+)
+from biofit.utils import version
+from biofit.utils.fingerprint import generate_cache_dir
+from tests.utils import create_bioset, require_biosets, require_datasets
+
+# Mock feature types for testing purposes
+FEATURE_TYPES = Union[Type, tuple]
+
+pytestmark = pytest.mark.unit
+
+
+@dataclass
+class MockProcessorConfig(ProcessorConfig):
+ """
+ Configuration for the MockModel. Inherits from ProcessorConfig and specifies
+ feature types for automatic column selection.
+ """
+
+ # Specifies the input feature types for the fit method (arity of 2: X and y)
+ _fit_input_feature_types: List[FEATURE_TYPES] = field(
+ default_factory=lambda: [None, get_feature("TARGET_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ # Specifies the unused feature types during fitting
+ _fit_unused_feature_types: List[FEATURE_TYPES] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES"), None],
+ init=False,
+ repr=False,
+ )
+ # Specifies the unused feature types during transformation
+ _transform_unused_feature_types: List[FEATURE_TYPES] = field(
+ default_factory=lambda: [get_feature("METADATA_FEATURE_TYPES")],
+ init=False,
+ repr=False,
+ )
+ _fit_process_desc: str = field(default="Fitting test", init=False, repr=False)
+ _predict_process_desc: str = field(default="Predict test", init=False, repr=False)
+ _transform_process_desc: str = field(
+ default="Transforming test", init=False, repr=False
+ )
+ # Name of the processor
+ processor_type: str = field(default="mock", init=False, repr=False)
+ processor_name: str = field(default="processor", init=False, repr=False)
+ # Additional parameters for testing
+ processor_param_int: int = None
+ processor_param_float: float = None
+ processor_param_str: str = None
+ processor_param_bool: bool = None
+ processor_param_list: List = None
+ processor_param_dict: dict = None
+ processor_param_tuple: tuple = None
+
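+# As configured above, each feature-type list has one entry per fit argument
+# (X, y): None leaves a slot unconstrained, TARGET_FEATURE_TYPES selects the
+# target column for y, and METADATA_FEATURE_TYPES in the unused lists drops
+# metadata columns before fitting or transforming.
+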
+
+class MockModel(BaseProcessor):
+ """
+ A mock model class for testing ProcessorConfig, TransformationMixin, and BaseProcessor.
+ Implements the necessary _fit_* and _predict_* methods.
+ """
+
+ config_class = MockProcessorConfig
+ config: MockProcessorConfig
+
+ def __init__(self, config: Optional[MockProcessorConfig] = None, **kwargs):
+ # Initialize the BaseProcessor with the given configuration
+ super().__init__(config=config or self.config_class(**kwargs))
+ self.post_fit = lambda x: x
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ self.post_fit = lambda x: x
+
+ def fit(
+ self,
+ X,
+ y=None,
+ input_columns: SelectedColumnTypes = None,
+ target_column: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "MockModel":
+ """
+ Mock fit method that processes input data and fits the model.
+ """
+ # Prepare input columns (arity of 2: X and y)
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns, target_column
+ )
+ return self._process_fit(
+ X,
+ y,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def predict(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ """
+ Mock predict method that generates predictions.
+ """
+ self._method_prefix = "_predict"
+ self.config._n_features_out = 1
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ self.output_dtype = "float64"
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def predict_proba(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ """
+ Mock predict_proba method that generates predictions.
+ """
+ self._method_prefix = "_predict_proba"
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def _fit_sklearn(self, X, y, some_param=None):
+ """
+ Internal fit method for sklearn-compatible data formats.
+ """
+ # Mock fitting logic (e.g., training a simple model)
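+        # Arrow tables also stand in for the polars/pandas params so the mock
+        # stays dependency-free; only estimator["mean"] is used by predict.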
+ self.config.processor_param_numpy = np.array([1, 2, 3])
+ self.config.processor_param_arrow = pa.table({"a": [1, 2, 3]})
+ self.config.processor_param_polars = pa.table({"a": [1, 2, 3]})
+ self.config.processor_param_pandas = pa.table({"a": [1, 2, 3]})
+ self.config.estimator = {"mean": np.mean(DataHandler.to_numpy(y).flatten())}
+ return self
+
+ def _predict_sklearn(self, X):
+ """
+ Internal predict method for sklearn-compatible data formats.
+ """
+ if not self.config.is_fitted:
+ raise ValueError("Model is not fitted yet.")
+ # Mock prediction logic (e.g., returning the mean value)
+ predictions = np.full(len(X), self.config.estimator["mean"])
+ return predictions
+
+ def _predict_proba_sklearn(self, X):
+ """
+ Internal predict_proba method for sklearn-compatible data formats.
+ """
+ if not self.config.is_fitted:
+ raise ValueError("Model is not fitted yet.")
+ # Mock prediction logic (e.g., returning the mean value)
+ predictions = np.full(
+ (len(X), self.config.n_classes), self.config.estimator["mean"]
+ )
+ return predictions
+
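+# Minimal usage sketch for MockModel, as exercised in TestMockModel below
+# (X and y are placeholder tables; fit returns the fitted processor, so
+# calls can chain):
+#
+#   model = MockModel(config=MockProcessorConfig())
+#   preds = model.fit(X, y).predict(X)  # constant prediction: mean of y
+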
+
+class MockPreprocessor(BaseProcessor):
+ """
+ A mock model class for testing ProcessorConfig, TransformationMixin, and BaseProcessor.
+ Implements the necessary _fit_* and _predict_* methods.
+ """
+
+ config_class = MockProcessorConfig
+ config: MockProcessorConfig
+
+ def __init__(self, config: Optional[MockProcessorConfig] = None, **kwargs):
+ # Initialize the BaseProcessor with the given configuration
+ super().__init__(config=config or self.config_class(**kwargs))
+ self.post_fit = lambda x: x
+
+ @sync_backup_config
+ def set_params(self, **kwargs):
+ self.config = self.config.replace_defaults(**kwargs)
+ self.post_fit = lambda x: x
+
+ def fit(
+ self,
+ X,
+ y=None,
+ input_columns: SelectedColumnTypes = None,
+ target_column: SelectedColumnTypes = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ) -> "MockModel":
+ """
+ Mock fit method that processes input data and fits the model.
+ """
+ # Prepare input columns (arity of 2: X and y)
+ self.config._input_columns = self._set_input_columns_and_arity(
+ input_columns, target_column
+ )
+ return self._process_fit(
+ X,
+ y,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def transform(
+ self,
+ X,
+ input_columns: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ load_from_cache_file: bool = None,
+ batched: bool = None,
+ batch_size: int = None,
+ batch_format: str = None,
+ output_format: str = None,
+ map_kwargs: dict = None,
+ num_proc: int = None,
+ fingerprint: str = None,
+ ):
+ """
+ Mock predict method that generates predictions.
+ """
+ self._method_prefix = "_transform"
+ self._input_columns = self._set_input_columns_and_arity(input_columns)
+ return self._process_transform(
+ X,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+ def fit_transform(
+ self,
+ X,
+ y=None,
+ input_columns: SelectedColumnTypes = None,
+ target_column: SelectedColumnTypes = None,
+ keep_unused_columns: bool = None,
+ raise_if_missing: bool = None,
+ cache_output: bool = None,
+ load_from_cache_file: bool = None,
+ batched: bool = True,
+ batch_size: int = 1000,
+ output_format: str = None,
+ batch_format: str = None,
+ num_proc: int = None,
+ map_kwargs: dict = {"fn_kwargs": {}},
+ cache_dir: str = None,
+ cache_file_name: str = None,
+ fingerprint: str = None,
+ ):
+ """
+ Mock fit_transform method that processes input data and fits the model.
+ """
+ return self.fit(
+ X,
+ y,
+ input_columns=input_columns,
+ target_column=target_column,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ num_proc=num_proc,
+ map_kwargs=map_kwargs,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ fingerprint=fingerprint,
+ ).transform(
+ X,
+ input_columns=input_columns,
+ keep_unused_columns=keep_unused_columns,
+ raise_if_missing=raise_if_missing,
+ cache_output=cache_output,
+ cache_dir=cache_dir,
+ cache_file_name=cache_file_name,
+ load_from_cache_file=load_from_cache_file,
+ batched=batched,
+ batch_size=batch_size,
+ batch_format=batch_format,
+ output_format=output_format,
+ map_kwargs=map_kwargs,
+ num_proc=num_proc,
+ fingerprint=fingerprint,
+ )
+
+    def _fit_pandas(self, X, y):
+        """
+        Internal fit method for pandas-formatted data. The mock performs no
+        actual fitting.
+        """
+        return self
+
+    def _fit_arrow(self, X, y):
+        """
+        Internal fit method for Arrow-formatted data. The mock performs no
+        actual fitting.
+        """
+        return self
+
+    def _fit_polars(self, X, y):
+        """
+        Internal fit method for polars-formatted data. The mock performs no
+        actual fitting.
+        """
+        return self
+
+    def _fit_numpy(self, X, y):
+        """
+        Internal fit method for numpy-formatted data. The mock performs no
+        actual fitting.
+        """
+        return self
+
+    def _transform_pandas(self, X):
+        """
+        Internal transform method for pandas-formatted data. The mock is the
+        identity transform.
+        """
+        return X
+
+    def _transform_arrow(self, X):
+        """
+        Internal transform method for Arrow-formatted data. The mock is the
+        identity transform.
+        """
+        return X
+
+    def _transform_polars(self, X):
+        """
+        Internal transform method for polars-formatted data. The mock is the
+        identity transform.
+        """
+        return X
+
+    def _transform_numpy(self, X):
+        """
+        Internal transform method for numpy-formatted data. The mock is the
+        identity transform.
+        """
+        return X
+
+
+class TestMockModel(unittest.TestCase):
+ @pytest.fixture(autouse=True)
+ def inject_fixtures(self, count_data, sample_metadata):
+ self.sample_metadata = sample_metadata
+ self.metadata_columns = list(sample_metadata.columns)
+ self.X, self.y = count_data
+ self.column_names = list(self.X.columns)
+ self.target_column = self.y.columns[0]
+ self.data = pd.concat([self.sample_metadata, self.X, self.y], axis=1)
+
+ def setUp(self):
+ self.params = {
+ "version": version.__version__,
+ "processor_param_int": 1,
+ "processor_param_float": 1.0,
+ "processor_param_str": "test",
+ "processor_param_bool": True,
+ "processor_param_list": [1, 2, 3],
+ "processor_param_dict": {"a": 1, "b": 2},
+ "processor_param_tuple": (1, 2),
+ }
+ self.config = MockProcessorConfig(**self.params)
+ self.model = MockModel(config=self.config)
+ self.preprocessor = MockPreprocessor(config=self.config)
+ self.funcs = [
+ self.preprocessor._fit_pandas,
+ self.preprocessor._fit_numpy,
+ self.preprocessor._fit_arrow,
+ self.preprocessor._fit_polars,
+ ]
+ self.accepted_formats = ["pandas", "numpy", "arrow", "polars"]
+
+ def tearDown(self):
+ # clean up the cache directory
+ cache_dir = biofit.config.BIOFIT_CACHE_HOME
+ if cache_dir.exists():
+ shutil.rmtree(cache_dir)
+
+ def test_init(self):
+ model_params = self.model.config.get_params()
+ preprocessor_params = self.preprocessor.config.get_params()
+ self.assertEqual(model_params, preprocessor_params)
+ self.assertEqual(model_params, self.params)
+ self.assertEqual(preprocessor_params, self.params)
+
+ def test_get_method(self):
+ formats = ["numpy", "pandas", "arrow", "polars"]
+ func_type = "_fit"
+ methods = self.preprocessor._get_method(formats, func_type)
+ method_names = [method.__name__ for method in methods]
+ expected_method_names = [
+ "_fit_numpy",
+ "_fit_pandas",
+ "_fit_arrow",
+ "_fit_polars",
+ ]
+ self.assertEqual(method_names, expected_method_names)
+
+ def test_has_method(self):
+ formats = ["numpy", "pandas", "arrow", "polars"]
+ func_type = "_fit"
+ has_method = self.preprocessor._has_method(formats, func_type)
+ self.assertTrue(has_method)
+ formats = ["nonexistent_format"]
+ has_method = self.preprocessor._has_method(formats, func_type)
+ self.assertFalse(has_method)
+
+ def test_get_target_func_match_source_format(self):
+ # Testing _get_target_func method
+ # Priority should be source format
+ func, to_format = self.preprocessor._get_target_func(
+ funcs=self.funcs,
+ source_format="numpy",
+ accepted_formats=self.accepted_formats,
+ )
+
+ self.assertEqual(func.__name__, "_fit_numpy")
+ self.assertEqual(to_format, "numpy")
+
+ def test_get_target_func_priority(self):
+ # Priority should be the first format found in funcs in the order of
+ # accepted_formats
+ func, to_format = self.preprocessor._get_target_func(
+ funcs=self.funcs[-2:],
+ source_format="not_a_format",
+ accepted_formats=self.accepted_formats,
+ )
+
+ self.assertEqual(func.__name__, f"_fit_{self.accepted_formats[-2]}")
+ self.assertEqual(to_format, self.accepted_formats[-2])
+
+ def test_get_target_func_priority_order(self):
+ # Priority is the first format found in funcs in the order of
+ # target_formats
+ func, to_format = self.preprocessor._get_target_func(
+ funcs=self.funcs,
+ source_format="pandas",
+ target_formats=["polars", "arrow"],
+ )
+ self.assertEqual(func.__name__, "_fit_polars")
+ self.assertEqual(to_format, "polars")
+
+ def test_set_input_columns_and_arity(self):
+ result = self.preprocessor._set_input_columns_and_arity(
+ "col1", ["col2", "col3"]
+ )
+ expected = [["col1"], ["col2", "col3"]]
+ self.assertEqual(result, expected)
+
+ def test_reinsert_columns_valid(self):
+ """Test reinsert_columns with valid inputs and unused indices."""
+ # Input data (input)
+ input_data = pd.DataFrame(
+ {
+ "col1": np.arange(10),
+ "col2": np.arange(10, 20),
+ "col3": np.arange(20, 30),
+ "col4": np.arange(30, 40),
+ }
+ )
+ # Output data (out)
+ output_data = pd.DataFrame({"col1_transformed": np.arange(100, 110)})
+ # Indices of unused columns (unused_indices)
+ indices = [0, 3] # 'col1' and 'col4' transformed into 'col1_transformed'
+ unused_indices = [1, 2] # 'col2' and 'col3' were unused
+
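+        # _n_features_out is set here, presumably so config.one_to_one_features
+        # resolves to False for this call; it is restored to None afterwards.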
+ self.preprocessor.config._n_features_out = 1
+ result = self.preprocessor._reinsert_columns(
+ input=input_data,
+ out=output_data,
+ indices=indices,
+ unused_indices=unused_indices,
+ one_to_one_features=self.preprocessor.config.one_to_one_features,
+ )
+ self.preprocessor.config._n_features_out = None
+
+ # When one_to_one_features=False, the output columns should be
+ # appended to the end of the input data
+ # Expected result: 'col2', 'col3', and 'col1_transformed'
+ other_cols = input_data.iloc[:, unused_indices]
+ concatenated = pd.concat([output_data, other_cols], axis=1)
+ expected_output = concatenated[["col2", "col3", "col1_transformed"]]
+
+ # Assert that the result matches the expected output
+ pd.testing.assert_frame_equal(result, expected_output)
+
+ def test_reinsert_columns_one_to_one_features(self):
+ """Test reinsert_columns with one_to_one_features=True."""
+ # Input data
+ input_data = pd.DataFrame(
+ {
+ "col0": np.arange(10),
+ "col1": np.arange(10, 20),
+ "col2": np.arange(20, 30),
+ "col3": np.arange(30, 40),
+ }
+ )
+ # Output data
+ output_data = pd.DataFrame(
+ {
+ "col0_transformed": np.arange(100, 110),
+ "col3_transformed": np.arange(200, 210),
+ }
+ )
+ # Indices used
+ indices = [0, 3] # 'col0' and 'col3' transformed
+ # Unused indices
+ unused_indices = [1, 2] # 'col1' and 'col2' unused
+
+ self.preprocessor.config._n_features_out = None
+
+        # When one_to_one_features=True, the output columns should follow
+        # the same order as the input columns
+ result = self.preprocessor._reinsert_columns(
+ input=input_data,
+ out=output_data,
+ indices=indices,
+ unused_indices=unused_indices,
+ one_to_one_features=self.preprocessor.config.one_to_one_features,
+ )
+
+ # Expected result
+ other_cols = input_data.iloc[:, unused_indices]
+ expected_output = pd.concat([other_cols, output_data], axis=1)
+ expected_output = expected_output[
+ ["col0_transformed", "col1", "col2", "col3_transformed"]
+ ]
+
+ # Assert
+ pd.testing.assert_frame_equal(result, expected_output)
+
+ def test_reinsert_columns_no_unused_columns(self):
+ """Test when there are no unused columns."""
+ input_data = pd.DataFrame({"col1": np.arange(10)})
+ output_data = pd.DataFrame({"col1_transformed": np.arange(100, 110)})
+ indices = [0]
+ unused_indices = []
+
+ result = self.preprocessor._reinsert_columns(
+ input_data, output_data, indices, unused_indices
+ )
+
+ # Expected result is the same as output_data
+ pd.testing.assert_frame_equal(result, output_data)
+
+ def test_reinsert_columns_mismatched_row_counts(self):
+ """Test when input and output have different number of rows."""
+ input_data = pd.DataFrame({"col1": np.arange(10)})
+ output_data = pd.DataFrame(
+ {
+ "col1_transformed": np.arange(5) # Different number of rows
+ }
+ )
+ indices = [0]
+ unused_indices = []
+
+ result = self.preprocessor._reinsert_columns(
+ input_data, output_data, indices, unused_indices
+ )
+
+ # Since row counts differ, should return output as is
+ pd.testing.assert_frame_equal(result, output_data)
+
+ def test_reinsert_columns_invalid_indices(self):
+ """Test with invalid indices (out of bounds)."""
+ input_data = pd.DataFrame({"col1": np.arange(10)})
+ output_data = pd.DataFrame({"col1_transformed": np.arange(10)})
+ indices = [0]
+ unused_indices = [1] # Invalid index
+
+ with self.assertRaises(IndexError):
+ self.preprocessor._reinsert_columns(
+ input_data, output_data, indices, unused_indices
+ )
+
+ def test_reinsert_columns_empty_input(self):
+ """Test with empty input data."""
+ input_data = pd.DataFrame()
+ output_data = pd.DataFrame({"col_transformed": []})
+ indices = []
+ unused_indices = []
+
+ result = self.preprocessor._reinsert_columns(
+ input_data, output_data, indices, unused_indices
+ )
+
+ # Expected result is the same as output_data
+ pd.testing.assert_frame_equal(result, output_data)
+
+ def test_reinsert_columns_empty_output(self):
+ """Test with empty output data and unused columns."""
+ input_data = pd.DataFrame({"col1": [], "col2": []})
+ output_data = pd.DataFrame()
+ indices = []
+ unused_indices = [0, 1]
+
+ result = self.preprocessor._reinsert_columns(
+ input_data, output_data, indices, unused_indices
+ )
+
+ # Expected result is input_data with unused columns
+ pd.testing.assert_frame_equal(result, input_data)
+
+ def test_reinsert_columns_different_column_types(self):
+ """Test with different data types in columns."""
+ input_data = pd.DataFrame(
+ {
+ "int_col": np.arange(10),
+ "float_col": np.random.rand(10),
+ "str_col": ["text"] * 10,
+ }
+ )
+ output_data = pd.DataFrame({"int_col_transformed": np.arange(100, 110)})
+ indices = [0]
+ unused_indices = [1, 2]
+
+ result = self.preprocessor._reinsert_columns(
+ input_data, output_data, indices, unused_indices, one_to_one_features=True
+ )
+
+ other_cols = input_data[["float_col", "str_col"]]
+ expected_output = pd.concat([output_data, other_cols], axis=1)
+ expected_output = expected_output[
+ ["int_col_transformed", "float_col", "str_col"]
+ ]
+
+ pd.testing.assert_frame_equal(result, expected_output)
+
+ def test_reinsert_columns_null_values(self):
+ """Test with null values in input data."""
+ input_data = pd.DataFrame(
+ {
+ "col1": [1, np.nan, 3, np.nan, 5],
+ "col2": [np.nan, 2, np.nan, 4, np.nan],
+ }
+ )
+ output_data = pd.DataFrame({"col1_transformed": [10, 20, 30, 40, 50]})
+ indices = [0]
+ unused_indices = [1]
+
+ result = self.preprocessor._reinsert_columns(
+ input_data, output_data, indices, unused_indices, one_to_one_features=True
+ )
+
+ other_cols = input_data[["col2"]]
+ expected_output = pd.concat([output_data, other_cols], axis=1)
+ expected_output = expected_output[["col1_transformed", "col2"]]
+
+ pd.testing.assert_frame_equal(result, expected_output)
+
+ def test_reinsert_columns_non_dataframe_input(self):
+ """Test with input data that is not a DataFrame."""
+ input_data = pd.DataFrame(
+ {
+ "col1": np.arange(10),
+ "col2": np.arange(10, 20),
+ "col3": np.arange(20, 30),
+ }
+ )
+ output_data = np.arange(10)
+ indices = [0]
+ unused_indices = [1, 2]
+
+ output_data = self.preprocessor._reinsert_columns(
+ input_data, output_data, indices, unused_indices
+ )
+ # format type should match the input data
+ self.assertIsInstance(output_data, pd.DataFrame)
+
+ def test_make_columns_exclusive(self):
+ columns = [["col1", "col2"], ["col2", "col3"], ["col3", "col4"]]
+ result = self.preprocessor._make_columns_exclusive(columns)
+ expected = [["col1"], ["col2"], ["col3", "col4"]]
+ self.assertEqual(result, expected)
+
+ def test_get_columns_valid_input_columns(self):
+ """Test _get_columns with valid input_columns"""
+ input_columns = self.column_names
+ result = self.preprocessor._get_columns(
+ self.X, input_columns=[input_columns], raise_if_missing=True
+ )
+ # Unpack results
+ (
+ feature_names_in,
+ feature_idx_in,
+ unused_idx_in,
+ extra_names_in,
+ extra_idx_in,
+ unused_extra_idx_in,
+ offsets,
+ ) = result
+
+ # Assertions
+ self.assertEqual(feature_names_in, input_columns)
+ expected_indices = [self.column_names.index(col) for col in input_columns]
+ self.assertEqual(feature_idx_in, expected_indices)
+ expected_unused_indices = sorted(
+ set(range(len(self.column_names))) - set(expected_indices)
+ )
+ self.assertEqual(unused_idx_in, expected_unused_indices)
+ self.assertIsNone(extra_names_in)
+ self.assertIsNone(extra_idx_in)
+ self.assertIsNone(unused_extra_idx_in)
+ self.assertIsNone(offsets)
+
+ def test_get_columns_invalid_input_columns_raise(self):
+ """Test _get_columns with invalid input_columns and raise_if_missing=True"""
+ input_columns = self.column_names + ["non_existent_column"]
+ with self.assertRaises(ValueError) as context:
+ self.preprocessor._get_columns(
+ self.X, input_columns=[input_columns], raise_if_missing=True
+ )
+ self.assertIn(
+ str(context.exception),
+ "Columns {'non_existent_column'} not found in input dataset",
+ )
+
+ def test_get_columns_invalid_input_columns_no_raise(self):
+ """Test _get_columns with invalid input_columns and raise_if_missing=False"""
+ input_columns = self.column_names + ["non_existent_column"]
+ result = self.preprocessor._get_columns(
+ self.X, input_columns=[input_columns], raise_if_missing=False
+ )
+ (
+ feature_names_in,
+ feature_idx_in,
+ unused_idx_in,
+ extra_names_in,
+ extra_idx_in,
+ unused_extra_idx_in,
+ offsets,
+ ) = result
+
+ self.assertEqual(feature_names_in, self.column_names)
+ expected_indices = list(range(len(self.column_names)))
+ self.assertEqual(feature_idx_in, expected_indices)
+ expected_unused_indices = sorted(
+ set(range(len(self.column_names))) - set(expected_indices)
+ )
+ self.assertEqual(unused_idx_in, expected_unused_indices)
+ self.assertIsNone(extra_names_in)
+ self.assertIsNone(extra_idx_in)
+ self.assertIsNone(unused_extra_idx_in)
+ self.assertIsNone(offsets)
+
+ def test_get_columns_unused_columns(self):
+ """Test _get_columns with unused_columns specified"""
+ unused_columns = self.column_names[:2]
+ result = self.preprocessor._get_columns(
+ self.X,
+ input_columns=[None],
+ unused_columns=[unused_columns],
+ raise_if_missing=True,
+ )
+ # Unpack results
+ (
+ feature_names_in,
+ feature_idx_in,
+ unused_idx_in,
+ extra_names_in,
+ extra_idx_in,
+ unused_extra_idx_in,
+ offsets,
+ ) = result
+
+ expected_input_columns = [
+ col for col in self.column_names if col not in unused_columns
+ ]
+ self.assertEqual(feature_names_in, expected_input_columns)
+ expected_indices = [
+ self.column_names.index(col) for col in expected_input_columns
+ ]
+ self.assertEqual(feature_idx_in, expected_indices)
+ expected_unused_indices = sorted(
+ [self.column_names.index(col) for col in unused_columns]
+ )
+ self.assertEqual(unused_idx_in, expected_unused_indices)
+ self.assertIsNone(extra_names_in)
+ self.assertIsNone(extra_idx_in)
+ self.assertIsNone(unused_extra_idx_in)
+ self.assertIsNone(offsets)
+
+ def test_get_columns_with_args(self):
+ """Test _get_columns with additional args (extra inputs)"""
+ # Extra input data
+ extra_X = {
+ "extra_feature_1": np.random.rand(50),
+ "extra_feature_2": np.random.rand(50),
+ }
+ input_columns = [self.column_names, ["extra_feature_1"]]
+ result = self.preprocessor._get_columns(
+ self.X, extra_X, input_columns=input_columns, raise_if_missing=True
+ )
+ # Unpack results
+ (
+ feature_names_in,
+ feature_idx_in,
+ unused_idx_in,
+ extra_names_in,
+ extra_idx_in,
+ unused_extra_idx_in,
+ offsets,
+ ) = result
+
+ # Assertions for main input
+ self.assertEqual(feature_names_in, self.column_names)
+ expected_indices = list(range(len(self.column_names)))
+ self.assertEqual(feature_idx_in, expected_indices)
+ expected_unused_indices = sorted(
+ set(range(len(self.column_names))) - set(expected_indices)
+ )
+ self.assertEqual(unused_idx_in, expected_unused_indices)
+
+ # Assertions for extra input
+ self.assertEqual(extra_names_in, [["extra_feature_1"]])
+ self.assertEqual(extra_idx_in, [[0]])
+ self.assertEqual(unused_extra_idx_in, [[1]]) # Only 'extra_feature_2' is unused
+ self.assertEqual(offsets, []) # No offsets since extra input not concatenated
+
+ def test_get_columns_empty_input_columns(self):
+ """Test _get_columns with empty input_columns"""
+ result = self.preprocessor._get_columns(self.X, input_columns=[])
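+        # The 7-tuple is (feature_names_in, feature_idx_in, unused_idx_in,
+        # extra_names_in, extra_idx_in, unused_extra_idx_in, offsets).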
+ self.assertEqual(result, (None, None, None, None, None, None, None))
+
+ def test_get_columns_input_feature_types(self):
+ """Test _get_columns with input_feature_types specified"""
+ input_feature_types = get_feature("Abundance")
+ result = self.preprocessor._get_columns(
+ self.X,
+ input_columns=[None],
+ input_feature_types=[input_feature_types],
+ raise_if_missing=True,
+ )
+ # Unpack results
+ (
+ feature_names_in,
+ feature_idx_in,
+ unused_idx_in,
+ extra_names_in,
+ extra_idx_in,
+ unused_extra_idx_in,
+ offsets,
+ ) = result
+
+ # Assertions
+ self.assertEqual(feature_names_in, self.column_names)
+ expected_indices = list(range(len(self.column_names)))
+ self.assertEqual(feature_idx_in, expected_indices)
+ expected_unused_indices = sorted(
+ set(range(len(self.column_names))) - set(expected_indices)
+ )
+ self.assertEqual(unused_idx_in, expected_unused_indices)
+ self.assertIsNone(extra_names_in)
+ self.assertIsNone(extra_idx_in)
+ self.assertIsNone(unused_extra_idx_in)
+ self.assertIsNone(offsets)
+
+ @require_biosets
+ def test_get_columns_invalid_input_feature_types(self):
+ """Test _get_columns with invalid input_feature_types"""
+
+ X = create_bioset(
+ self.X,
+ feature_type=get_feature("Abundance"),
+ )
+
+ # Dataset uses Abundance feature type, but GenomicVariant is specified
+ result = self.preprocessor._get_columns(
+ X,
+ input_columns=[None],
+ input_feature_types=[get_feature("GenomicVariant")],
+ raise_if_missing=False,
+ )
+ # Should return None values due to invalid feature types
+ self.assertEqual(result, (None, None, None, None, None, None, None))
+
+ def test_get_columns_mismatched_arity(self):
+ """Test _get_columns with mismatched arity between input_columns and args"""
+ # No args provided, but input_columns has two lists
+ input_columns = [self.column_names, ["extra_feature_1"]]
+ with self.assertRaises(AssertionError) as context:
+ self.preprocessor._get_columns(
+ self.X, input_columns=input_columns, raise_if_missing=True
+ )
+ self.assertIn(
+ str(context.exception),
+ "Number of column sets (2) must match the arity (1)",
+ )
+
+ def test_get_columns_with_none_X(self):
+ """Test _get_columns with X as None"""
+ with self.assertRaises(AssertionError) as context:
+ self.preprocessor._get_columns(
+ None, input_columns=[self.column_names], raise_if_missing=True
+ )
+ self.assertIn("Input data is None", str(context.exception))
+
+ def test_get_columns_with_non_list_input_columns(self):
+ """Test _get_columns with input_columns not wrapped in a list"""
+ input_columns = "feature_1"
+ with self.assertRaises(AssertionError) as context:
+ self.preprocessor._get_columns(
+ self.X, input_columns=input_columns, raise_if_missing=True
+ )
+        # An AssertionError is raised because input_columns is not wrapped in a list
+ self.assertIn(
+ str(context.exception),
+ f"input_columns must be a list of column names or indices, "
+ f"but got {type(input_columns)}",
+ )
+
+ @require_biosets
+ def test_get_columns_with_unused_feature_types(self):
+ """Test _get_columns with unused_feature_types specified"""
+ X = create_bioset(
+ self.X,
+ self.y,
+ self.sample_metadata,
+ feature_type=get_feature("Abundance"),
+ target_type=get_feature("BinClassLabel"),
+ )
+ unused_feature_types = get_feature("METADATA_FEATURE_TYPES")
+ result = self.preprocessor._get_columns(
+ X,
+ input_columns=[None],
+ unused_feature_types=[unused_feature_types],
+ raise_if_missing=True,
+ )
+ # Unpack results
+ (
+ feature_names_in,
+ feature_idx_in,
+ unused_idx_in,
+ _,
+ _,
+ _,
+ _,
+ ) = result
+
+ self.assertEqual(feature_names_in, self.column_names)
+ expected_indices = [X.column_names.index(col) for col in self.column_names]
+ self.assertEqual(feature_idx_in, expected_indices)
+ expected_unused_indices = sorted(
+ [
+ X.column_names.index(col)
+ for col in self.metadata_columns + [self.target_column]
+ ]
+ )
+ self.assertEqual(unused_idx_in, expected_unused_indices)
+
+ def test_get_columns_with_all_none_parameters(self):
+ """Test _get_columns with all parameters as None"""
+ result = self.preprocessor._get_columns(
+ self.X,
+ input_columns=None,
+ input_feature_types=None,
+ unused_columns=None,
+ unused_feature_types=None,
+ raise_if_missing=False,
+ )
+ # Should return None values
+ self.assertEqual(result, (None, None, None, None, None, None, None))
+
+ def test_get_columns_with_args_none(self):
+ """Test _get_columns with args as None"""
+ result = self.preprocessor._get_columns(
+ self.data,
+ None,
+ input_columns=[
+ self.metadata_columns,
+ self.column_names,
+ ],
+ raise_if_missing=True,
+ )
+ # Unpack results
+ (
+ feature_names_in,
+ feature_idx_in,
+ _,
+ extra_names_in,
+ extra_idx_in,
+ unused_extra_idx_in,
+ offsets,
+ ) = result
+ # Assertions for main input
+ self.assertEqual(feature_names_in, self.metadata_columns)
+ cols = list(self.data.columns)
+ expected_indices = [cols.index(col) for col in self.metadata_columns]
+ self.assertEqual(feature_idx_in, expected_indices)
+
+ # Assertions for extra input (args)
+ self.assertEqual(extra_names_in, [self.column_names])
+ self.assertEqual(
+ extra_idx_in,
+ [[cols.index(col) for col in self.column_names]],
+ )
+ self.assertEqual(unused_extra_idx_in, None)
+ self.assertEqual(offsets, [0])
+
+ def test_get_columns_with_args_mismatched_rows(self):
+ """Test _get_columns with args having mismatched number of rows"""
+ # Extra input data with different number of rows
+ extra_X = {
+ "extra_feature_1": np.random.rand(self.data.shape[0] + 1),
+ }
+ input_columns = [self.column_names, ["extra_feature_1"]]
+ result = self.preprocessor._get_columns(
+ self.X, extra_X, input_columns=input_columns, raise_if_missing=True
+ )
+ # Since row numbers don't match, offsets should not be incremented
+ (_, _, _, _, _, _, offsets) = result
+ self.assertEqual(offsets, []) # No offsets should be added
+
+ def test_generate_fingerprint(self):
+ fingerprint = "initial_fingerprint"
+ generated_fp = self.preprocessor.generate_fingerprint(fingerprint, self.config)
+ self.assertIsNotNone(generated_fp)
+ self.assertIsInstance(generated_fp, str)
+
+ def test_from_config(self):
+ new_processor = MockPreprocessor._from_config(self.config)
+ self.assertIsInstance(new_processor, MockPreprocessor)
+ self.assertEqual(new_processor.config, self.config)
+
+ def test_call(self):
+ # for ray-tune compatibility
+ batch = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})
+ kwargs = {
+ "fn": self.preprocessor._transform_pandas,
+ "func_type": "_transform",
+ "selected_indices": [0, 1],
+ "unused_indices": [2, 3],
+ "keep_unused_columns": False,
+ "in_format_kwargs": {"target_format": "pandas"},
+ "out_format_kwargs": {"target_format": "pandas"},
+ }
+ result = self.preprocessor(batch, **kwargs)
+ pd.testing.assert_frame_equal(result, batch)
+
+ def test_set_params(self):
+ self.preprocessor.set_params(
+ processor_param_int=99,
+ processor_param_float=99.0,
+ processor_param_str="new",
+ processor_param_bool=False,
+ processor_param_list=[99, 99, 99],
+ processor_param_dict={"a": 99, "b": 99},
+ processor_param_tuple=(99, 99),
+ )
+ self.assertEqual(self.preprocessor.config.processor_param_int, 99)
+ self.assertEqual(self.preprocessor.config.processor_param_float, 99.0)
+ self.assertEqual(self.preprocessor.config.processor_param_str, "new")
+ self.assertEqual(self.preprocessor.config.processor_param_bool, False)
+ self.assertEqual(self.preprocessor.config.processor_param_list, [99, 99, 99])
+ self.assertEqual(
+ self.preprocessor.config.processor_param_dict, {"a": 99, "b": 99}
+ )
+ self.assertEqual(self.preprocessor.config.processor_param_tuple, (99, 99))
+
+ def test_is_fitted(self):
+ self.assertFalse(self.preprocessor.is_fitted)
+ self.preprocessor.fit(self.X, self.y)
+ self.assertTrue(self.preprocessor.is_fitted)
+
+ def test_has_fit(self):
+ self.assertTrue(self.preprocessor.has_fit)
+
+ def test_process_fit_without_input_columns(self):
+ # Test _process_fit without input_columns
+ self.assertFalse(self.model.is_fitted)
+ with self.assertRaises(AssertionError) as context:
+ self.preprocessor._process_fit(self.X, self.y)
+ self.assertIn(
+ str(context.exception),
+ "The `fit` method of `MockPreprocessor` must call:\n"
+ "```\n"
+ "self.config._input_columns = self._set_input_columns_and_arity(*args)"
+ "\n```\n"
+ "Where `*args` are the columns for each input dataset.",
+ )
+ with self.assertRaises(AssertionError) as context:
+ self.model._process_fit(self.X, self.y)
+ self.assertIn(
+ str(context.exception),
+ "The `fit` method of `MockModel` must call:\n"
+ "```\n"
+ "self.config._input_columns = self._set_input_columns_and_arity(*args)"
+ "\n```\n"
+ "Where `*args` are the columns for each input dataset.",
+ )
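+        # A conforming implementation would look roughly like this
+        # (hypothetical sketch; argument names are illustrative):
+        #
+        #     def fit(self, X, y, input_columns=None, target_column=None):
+        #         self.config._input_columns = self._set_input_columns_and_arity(
+        #             input_columns, target_column
+        #         )
+        #         return self._process_fit(X, y)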
+
+ def test_process_transform_without_input_columns(self):
+        # Test _process_transform without input_columns
+ with self.assertRaises(AssertionError) as context:
+ self.preprocessor._process_transform(self.X)
+ self.assertIn(
+ str(context.exception),
+ "The `transform` method of `MockPreprocessor` must call:\n"
+ "```\n"
+ "self._input_columns = self._set_input_columns_and_arity(*args)"
+ "\n```\n"
+ "Where `*args` are the columns for each input dataset.",
+ )
+
+ with self.assertRaises(AssertionError) as context:
+ self.model._method_prefix = "_predict"
+ self.model._process_transform(self.X)
+ self.assertIn(
+ str(context.exception),
+ "The `predict` method of `MockModel` must call:\n"
+ "```\n"
+ "self._input_columns = self._set_input_columns_and_arity(*args)"
+ "\n```\n"
+ "Where `*args` are the columns for each input dataset.",
+ )
+
+ with self.assertRaises(AssertionError) as context:
+ self.model._method_prefix = "_predict_proba"
+ self.model._process_transform(self.X)
+ self.assertIn(
+ str(context.exception),
+ "The `predict_proba` method of `MockModel` must call:\n"
+ "```\n"
+ "self._input_columns = self._set_input_columns_and_arity(*args)"
+ "\n```\n"
+ "Where `*args` are the columns for each input dataset.",
+ )
+
+ def test_process_fit_with_absolute_path_cache_file_name(self):
+ # make cache_file_name an absolute path
+ cache_file_name = (
+ biofit.config.BIOFIT_PROCESSORS_CACHE / "cache.json"
+ ).as_posix()
+
+ self.assertFalse(self.model.is_fitted)
+ with self.assertRaises(ValueError) as context:
+ self.model.config._input_columns = self.model._set_input_columns_and_arity(
+ None, None
+ )
+ self.model._process_fit(self.X, self.y, cache_file_name=cache_file_name)
+ delattr(self.model.config, "_input_columns")
+
+ self.assertIn(
+ str(context.exception),
+ "`cache_file_name` is an absolute path. Please provide the "
+ "file name only. You can specify the directory using "
+ "`cache_dir`.",
+ )
+
+ def test_process_fit_with_remote_cache_file_name(self):
+        # make cache_file_name a remote URL
+ cache_file_name = "s3://bucket/cache.json"
+
+ self.assertFalse(self.model.is_fitted)
+ with self.assertRaises(ValueError) as context:
+ self.model.config._input_columns = self.model._set_input_columns_and_arity(
+ None, None
+ )
+ self.model._process_fit(self.X, self.y, cache_file_name=cache_file_name)
+ delattr(self.model.config, "_input_columns")
+
+ self.assertIn(
+ str(context.exception),
+ "`cache_file_name` is a remote URL. Please provide the "
+ "file name only. You can specify the directory using "
+ "`cache_dir`.",
+ )
+
+    def test_process_transform_with_absolute_path_cache_file_name(self):
+ # make cache_file_name an absolute path
+ cache_file_name = (
+ biofit.config.BIOFIT_PROCESSORS_CACHE / "cache.json"
+ ).as_posix()
+
+ with self.assertRaises(ValueError) as context:
+ self.model._input_columns = self.model._set_input_columns_and_arity(
+ None, None
+ )
+ self.model._process_transform(self.X, cache_file_name=cache_file_name)
+ delattr(self.model, "_input_columns")
+
+ self.assertIn(
+ str(context.exception),
+ "`cache_file_name` is an absolute path. Please provide the "
+ "file name only. You can specify the directory using "
+ "`cache_dir`.",
+ )
+
+ def test_process_transform_with_remote_cache_file_name(self):
+        # make cache_file_name a remote URL
+ cache_file_name = "s3://bucket/cache.json"
+
+ with self.assertRaises(ValueError) as context:
+ self.model._input_columns = self.model._set_input_columns_and_arity(
+ None, None
+ )
+ self.model._process_transform(self.X, cache_file_name=cache_file_name)
+ delattr(self.model, "_input_columns")
+
+ self.assertIn(
+ str(context.exception),
+ "`cache_file_name` is a remote URL. Please provide the "
+ "file name only. You can specify the directory using "
+ "`cache_dir`.",
+ )
+
+ def test_load_processed_estimator_from_cache_invalid_file(self):
+ # Use a non-existent cache file path
+ cache_file_name = "/non/existent/cache/file"
+
+ # Expect NonExistentCacheError to be raised
+ with self.assertRaises(NonExistentCacheError):
+ self.model.load_processed_estimator_from_cache(cache_file_name)
+
+ def test_process_fit_without_loading_from_cache(self):
+ # Set up the processor with caching enabled
+ self.model.config.enable_caching = True
+ self.model.config.cache_output = True
+ self.model.config.load_from_cache_file = False
+ self.assertFalse(self.model.is_fitted)
+
+ # Mock the load_processed_estimator_from_cache method to raise NonExistentCacheError
+ with patch.object(
+ self.model,
+ "load_processed_estimator_from_cache",
+ side_effect=NonExistentCacheError,
+ ) as mock_load_cache:
+ with patch.object(
+ self.model.config,
+ "save_to_cache",
+ wraps=self.model.config.save_to_cache,
+ ) as mock_save_cache:
+                # Call _process_fit; it should fit and save to cache without
+                # attempting to load from it
+ self.model.config._input_columns = (
+ self.model._set_input_columns_and_arity(None, None)
+ )
+ self.model._process_fit(
+ self.X,
+ self.y,
+ cache_file_name="cache.json",
+ fingerprint="test",
+ )
+
+ mock_load_cache.assert_not_called()
+ mock_save_cache.assert_called_once()
+
+ cache_dir = generate_cache_dir(
+ self.model,
+ self.model.config._data_fingerprint,
+ root_dir=biofit.config.BIOFIT_PROCESSORS_CACHE,
+ )
+ cache_path = Path(cache_dir) / "cache.json"
+ self.assertTrue(cache_path.exists())
+
+ # Check that the model is marked as fitted
+ self.assertTrue(self.model.is_fitted)
+
+ def test_process_fit_without_saving_output_to_cache(self):
+ self.assertFalse(self.model.is_fitted)
+
+ # Run it once to save the cache
+ self.model.config._input_columns = self.model._set_input_columns_and_arity(
+ None, None
+ )
+ self.model._process_fit(
+ self.X,
+ self.y,
+ cache_file_name="cache.json",
+ fingerprint="test",
+ )
+
+ # Set up the processor with only loading from cache enabled
+ self.model.config.enable_caching = True
+ self.model.config.cache_output = False
+ self.model.config.load_from_cache_file = True
+
+ # Also check if fingerprint is consistent
+ old_data_fingerprint = self.model.config._data_fingerprint
+ old_processor_fingerprint = self.model.fingerprint
+
+ with patch.object(
+ self.model,
+ "load_processed_estimator_from_cache",
+ side_effect=NonExistentCacheError,
+ ) as mock_load_cache:
+ with patch.object(
+ self.model.config,
+ "save_to_cache",
+ wraps=self.model.config.save_to_cache,
+ ) as mock_save_cache:
+ # Run it again to load from cache
+ self.model.config._input_columns = (
+ self.model._set_input_columns_and_arity(None, None)
+ )
+ self.model._process_fit(
+ self.X,
+ self.y,
+ cache_file_name="cache.json",
+ fingerprint="test",
+ )
+
+ # Check to see if fingerprint is consistent
+ self.assertEqual(
+ old_data_fingerprint, self.model.config._data_fingerprint
+ )
+ self.assertEqual(old_processor_fingerprint, self.model.fingerprint)
+
+ cache_dir = generate_cache_dir(
+ self.model,
+ self.model.config._data_fingerprint,
+ root_dir=biofit.config.BIOFIT_PROCESSORS_CACHE,
+ )
+ cache_path = Path(cache_dir) / "cache.json"
+
+ # Load should be called but save should not
+ mock_load_cache.assert_called_once_with(cache_path.as_posix())
+ mock_save_cache.assert_not_called()
+
+ self.assertTrue(cache_path.exists())
+
+ # Check that the model is marked as fitted
+ self.assertTrue(self.model.is_fitted)
+
+ @require_datasets
+ def test_transform_output_consistency_with_and_without_cache(self):
+ # Fit the model first without caching
+ X = DataHandler.to_dataset(self.X)
+ y = DataHandler.to_dataset(self.y)
+ original_fingerprint = X._fingerprint
+ parameters = inspect.signature(self.model._fit_sklearn).parameters
+ with patch.object(
+ self.model, "_fit_sklearn", wraps=self.model._fit_sklearn
+ ) as mock_fit, patch("biofit.processing.BaseProcessor._validate_fit_func_args"):
+ self.model._fit_sklearn.__name__ = "_fit_sklearn"
+ # change signature of fit method
+ self.model._fit_sklearn.__signature__ = inspect.Signature(
+ parameters=list(parameters.values())
+ )
+ self.model.fit(X, y)
+
+ mock_fit.assert_called_once()
+ output_no_cache = self.model.predict(X)
+
+ # Calculate the fingerprint of the output dataset without cache
+ fingerprint_no_cache = output_no_cache._fingerprint
+
+ with patch.object(
+ self.model, "_fit_sklearn", wraps=self.model._fit_sklearn
+ ) as mock_fit:
+ # Fit the model again (should create cache files)
+ self.model._fit_sklearn.__name__ = "_fit_sklearn"
+ self.model.fit(X, y)
+
+ # Since loaded from cache, fit should not be called
+ mock_fit.assert_not_called()
+
+ # Transform the data with cache enabled
+ output_with_cache = self.model.predict(X)
+
+ # Calculate the fingerprint of the output dataset with cache
+ fingerprint_with_cache = output_with_cache._fingerprint
+
+ # Transform a separate data with cache enabled
+ other_output_with_cache = self.model.predict(y)
+
+ other_fingerprint_with_cache = other_output_with_cache._fingerprint
+
+ # Compare the outputs
+ self.assertEqual(len(output_no_cache), len(output_with_cache))
+ output_no_cache = DataHandler.to_pandas(output_no_cache)
+ output_with_cache = DataHandler.to_pandas(output_with_cache)
+ pd.testing.assert_frame_equal(output_no_cache, output_with_cache)
+ # Compare the fingerprints
+ self.assertEqual(original_fingerprint, X._fingerprint)
+        self.assertEqual(fingerprint_no_cache, fingerprint_with_cache)
+ self.assertNotEqual(fingerprint_no_cache, other_fingerprint_with_cache)
+
+ def test_parse_fingerprint(self):
+ fingerprint = "base_fingerprint"
+ parsed_fp = self.preprocessor._parse_fingerprint(fingerprint)
+ self.assertIn(fingerprint, parsed_fp)
+ self.assertIn(self.config.processor_name, parsed_fp)
+
+ def test_reset(self):
+ # Test if the reset method resets the processor to the state before the first
+ # fit was called
+ preprocessor = copy.deepcopy(self.preprocessor.fit(self.X, self.y))
+
+ # Change the parameters of the processor
+ preprocessor.config.processor_param_int += 10
+ preprocessor.config.processor_param_float = 10.0
+
+ # Reset the processor
+ preprocessor._reset(preprocessor.config_)
+
+ # Should be the same as the original processor
+ self.assertEqual(
+ preprocessor.config.get_params(), self.preprocessor.config.get_params()
+ )
+
+ def test_from_config_classmethod(self):
+ new_processor = MockPreprocessor.from_config(self.config)
+ self.assertIsInstance(new_processor, MockPreprocessor)
+ self.assertEqual(new_processor.config, self.config)
+
+ def test_validate_fit_params_input_raise(self):
+ self.preprocessor.config._fit_input_feature_types = [get_feature("Abundance")]
+ self.preprocessor.config._fit_unused_feature_types = None
+ self.preprocessor.config._transform_input_feature_types = None
+ self.preprocessor.config._transform_unused_feature_types = None
+ self.preprocessor.config._input_columns = [None, None]
+ with self.assertRaises(AssertionError) as context:
+ self.preprocessor._validate_fit_params(2)
+ self.assertIn(
+ str(context.exception),
+ "`_fit_input_feature_types` is defined in "
+ "MockProcessorConfig but does not match the arity of "
+ "the fit function in MockPreprocessor (i.e. len("
+ "self.config._fit_input_feature_types) != "
+ "len(self.config._input_columns) -> "
+ "1 != 2).\n"
+ "This can be corrected by doing, for example:\n"
+ "_fit_input_feature_types = field(\n"
+ " default_factory=lambda: [None, None], init=False, "
+ "repr=False\n"
+ ")",
+ )
+
+ def test_validate_fit_params_unused_raise(self):
+ self.preprocessor.config._fit_input_feature_types = None
+ self.preprocessor.config._fit_unused_feature_types = [get_feature("Abundance")]
+ self.preprocessor.config._transform_input_feature_types = None
+ self.preprocessor.config._transform_unused_feature_types = None
+ self.preprocessor.config._input_columns = [None, None]
+ with self.assertRaises(AssertionError) as context:
+ self.preprocessor._validate_fit_params(2)
+ self.assertIn(
+ str(context.exception),
+ "`_fit_unused_feature_types` is defined in "
+ "MockProcessorConfig but does not match the arity of "
+ "the fit function in MockPreprocessor (i.e. len("
+ "self.config._fit_unused_feature_types) != "
+ "len(self.config._input_columns) -> "
+ "1 != 2).\n"
+ "This can be corrected by doing, for example:\n"
+ "_fit_unused_feature_types = field(\n"
+ " default_factory=lambda: [None, None], init=False, "
+ "repr=False\n"
+ ")",
+ )
+
+ def test_validate_transform_params_input_raise(self):
+ self.preprocessor.config._fit_input_feature_types = None
+ self.preprocessor.config._fit_unused_feature_types = None
+ self.preprocessor.config._transform_input_feature_types = [
+ get_feature("Abundance")
+ ]
+ self.preprocessor.config._transform_unused_feature_types = None
+ self.preprocessor._input_columns = [None, None]
+ with self.assertRaises(AssertionError) as context:
+ self.preprocessor._validate_transform_params(2)
+ self.assertIn(
+ str(context.exception),
+ "`_transform_input_feature_types` is defined in "
+ "MockProcessorConfig but does not match the arity of "
+ "the transform function in MockPreprocessor (i.e. len("
+ "self.config._transform_input_feature_types) != "
+ "len(self._input_columns) -> "
+ "1 != 2).\n"
+ "This can be corrected by doing, for example:\n"
+ "_transform_input_feature_types = field(\n"
+ " default_factory=lambda: [None, None], init=False, "
+ "repr=False\n"
+ ")",
+ )
+
+ def test_validate_transform_params_unused_raise(self):
+ self.preprocessor.config._fit_input_feature_types = None
+ self.preprocessor.config._fit_unused_feature_types = None
+ self.preprocessor.config._transform_input_feature_types = None
+ self.preprocessor.config._transform_unused_feature_types = [
+ get_feature("Abundance")
+ ]
+ self.preprocessor._input_columns = [None, None]
+ with self.assertRaises(AssertionError) as context:
+ self.preprocessor._validate_transform_params(2)
+ self.assertIn(
+ str(context.exception),
+ "`_transform_unused_feature_types` is defined in "
+ "MockProcessorConfig but does not match the arity of "
+ "the transform function in MockPreprocessor (i.e. len("
+ "self.config._transform_unused_feature_types) != "
+ "len(self._input_columns) -> "
+ "1 != 2).\n"
+ "This can be corrected by doing, for example:\n"
+ "_transform_unused_feature_types = field(\n"
+ " default_factory=lambda: [None, None], init=False, "
+ "repr=False\n"
+ ")",
+ )
+
+ def test_fit_missing_input_cols(self):
+ with self.assertRaises(AssertionError) as context:
+ self.preprocessor._fit(self.X, self.y, funcs=[self.preprocessor._fit_arrow])
+
+ self.assertIn(
+ "`extra_indices` was returned as `None` from "
+ "`MockPreprocessor`. "
+ "Was `MockPreprocessor._input_columns` or "
+ "`MockPreprocessor.config._input_columns` set correctly?",
+ str(context.exception),
+ )
+
+ def test_fit_transform(self):
+ self.preprocessor.fit_transform(self.X, self.y)
+
+ def test_fit_transform_with_columns(self):
+ self.preprocessor.fit_transform(
+ self.data, input_columns=self.column_names, target_column=self.target_column
+ )
+
+ def test_fit_transform_without_automatic_column_selection(self):
+ # First, fit the processor
+ with self.assertRaises(AssertionError) as context:
+ self.preprocessor.config._fit_input_feature_types = [None, None]
+ self.preprocessor.fit_transform(self.data)
+ self.assertIn(
+ str(context.exception),
+ "`MockPreprocessor.fit` requires 2 arguments ('X', 'y'), but only 1 was "
+ "provided. Either provide the missing arguments or provide the input "
+ "columns found in 'X', if applicable.",
+ )
+
+ def test_fit_transform_with_missing_col_in_X(self):
+ # First, fit the processor
+ with self.assertRaises(AssertionError) as context:
+ self.preprocessor.fit_transform(
+ self.X
+            )  # self.X does not contain any TARGET_FEATURE_TYPES columns
+ self.assertIn(
+ str(context.exception),
+ "`MockPreprocessor.fit` requires 2 arguments ('X', 'y'), but only 1 was "
+ "provided. Either provide the missing arguments or provide the input "
+ "columns found in 'X', if applicable.",
+ )
+
+ def test_fit_transform_output_format(self):
+ # First, fit the processor
+ self.preprocessor = self.preprocessor.fit(self.X, self.y)
+ out1 = self.preprocessor.transform(self.X, output_format="numpy")
+ out2 = self.preprocessor.fit_transform(self.X, self.y, output_format="numpy")
+ self.assertIs(type(out1), np.ndarray)
+ self.assertTrue(np.array_equal(out1, out2))
+
+ def test_process_extra_inds(self):
+ extra_indices = [[1, 2]]
+ unused_extra_indices = [[3]]
+ extra_inputs = [None]
+ orig_input = None
+ result = self.preprocessor._process_extra_inds(
+ orig_input, extra_inputs, extra_indices, unused_extra_indices
+ )
+ self.assertEqual(result, (extra_indices, unused_extra_indices))
+
+ def test_prepare_fit_kwargs(self):
+ combined_inputs = self.data
+ orig_input = self.sample_metadata
+ extra_inputs = [self.X, self.y]
+ selected_indices = list(range(self.data.shape[1]))
+ extra_indices = [
+ list(range(self.X.shape[1])),
+ list(range(self.y.shape[1])),
+ ]
+ map_kwargs, pooler = self.preprocessor._prepare_fit_kwargs(
+ funcs=[
+ self.preprocessor._fit_arrow,
+ self.preprocessor._fit_pandas,
+ self.preprocessor._fit_numpy,
+ ],
+ combined_inputs=combined_inputs,
+ orig_input=orig_input,
+ selected_indices=selected_indices,
+ extra_inputs=extra_inputs,
+ extra_indices=extra_indices,
+ extra_untouched_inputs=None,
+ map_kwargs={"fn_kwargs": {}},
+ num_proc=1,
+ )
+
+ # The expected extra indices should be adjusted to match the index
+ # within the combined_inputs (i.e. offset by the number of columns
+ # in the orig_input)
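+        # For example, with 3 metadata columns, 5 X columns, and 1 y column
+        # (illustrative counts), X maps to combined indices [3..7] and y to
+        # [8]; the assertions below derive the real values from the fixtures.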
+ expected_extra_indices = [
+ list(
+ range(
+ self.sample_metadata.shape[1],
+ self.sample_metadata.shape[1] + self.X.shape[1],
+ )
+ ),
+ list(
+ range(
+ self.sample_metadata.shape[1] + self.X.shape[1],
+ self.data.shape[1],
+ )
+ ),
+ ]
+ expected_map_kwargs = {
+ "fn_kwargs": {
+ "fn": self.preprocessor._fit_pandas, # should match data type of combined_inputs
+ "func_type": "_fit",
+ "extra_untouched_inputs": None,
+ "selected_indices": selected_indices,
+ "unused_indices": None,
+ "extra_indices": expected_extra_indices, # should be adjusted to match the data type of extra_inputs
+ "unused_extra_indices": [None, None],
+ "with_metadata": False,
+ "in_format_kwargs": {"target_format": "pandas"},
+ "out_format_kwargs": {"target_format": None},
+ },
+ "with_indices": False,
+ "with_rank": False,
+ "desc": "Fitting test",
+ "batched": True,
+ "batch_size": None,
+ "new_fingerprint": "None-processor-mock",
+ }
+ for key, value in expected_map_kwargs.items():
+ if key == "fn_kwargs":
+ for k, v in value.items():
+ self.assertEqual(
+ map_kwargs[key][k], v, f"{k}: {map_kwargs[key][k]} != {v}"
+ )
+ else:
+ self.assertEqual(map_kwargs[key], value)
+ self.assertIsNone(pooler)
+
+ def test_prepare_transform_kwargs(self):
+ combined_inputs = self.data
+ orig_input = self.sample_metadata
+ extra_inputs = [self.X, self.y]
+ selected_indices = list(range(self.data.shape[1]))
+ extra_indices = [
+ list(range(self.X.shape[1])),
+ list(range(self.y.shape[1])),
+ ]
+ map_kwargs = self.preprocessor._prepare_transform_kwargs(
+ combined_inputs,
+ orig_input,
+ self.X,
+ self.y,
+ selected_indices=selected_indices,
+ extra_indices=extra_indices,
+ unused_indices=None,
+ unused_extra_indices=None,
+ map_kwargs={"fn_kwargs": {}},
+ num_proc=1,
+ )
+
+ # The expected extra indices should be adjusted to match the index
+ # within the combined_inputs (i.e. offset by the number of columns
+ # in the orig_input)
+ expected_extra_indices = [
+ list(
+ range(
+ self.sample_metadata.shape[1],
+ self.sample_metadata.shape[1] + self.X.shape[1],
+ )
+ ),
+ list(
+ range(
+ self.sample_metadata.shape[1] + self.X.shape[1],
+ self.data.shape[1],
+ )
+ ),
+ ]
+ expected_map_kwargs = {
+ "fn_kwargs": {
+ "fn": self.preprocessor._transform_pandas, # should match data type of combined_inputs
+ "func_type": "_transform",
+ "with_metadata": False,
+ "selected_indices": selected_indices,
+ "unused_indices": None,
+ "extra_indices": expected_extra_indices,
+ "unused_extra_indices": [None, None],
+ "keep_unused_columns": None,
+ "in_format_kwargs": {"target_format": "pandas"},
+ "out_format_kwargs": {"target_format": "arrow"},
+ "feature_names": list(self.sample_metadata.columns),
+ },
+ "with_indices": False,
+ "with_rank": False,
+ "desc": "Transforming test",
+ "keep_in_memory": False,
+ "cache_file_name": None,
+ "num_proc": 1,
+ "batched": True,
+ "batch_size": 1000,
+ "load_from_cache_file": True,
+ "new_fingerprint": None,
+ }
+
+ for key, value in expected_map_kwargs.items():
+ if key == "fn_kwargs":
+ for k, v in value.items():
+ self.assertEqual(
+ map_kwargs[key][k], v, f"{k}: {map_kwargs[key][k]} != {v}"
+ )
+ else:
+ self.assertEqual(map_kwargs[key], value)
+
+ def test_process_transform_input(self):
+ X = "input_data"
+ kwargs = {"key": "value"}
+ result_X, result_kwargs = self.preprocessor._process_transform_input(
+ X, **kwargs
+ )
+ self.assertEqual(result_X, X)
+ self.assertEqual(result_kwargs, kwargs)
+
+ def test_process_transform_output(self):
+ result = self.preprocessor._process_transform_output(self.X, self.X)
+ pd.testing.assert_frame_equal(result, self.X)
+
+ def test_get_params(self):
+ params = self.preprocessor.get_params()
+ self.assertIsInstance(params, dict)
+ self.assertEqual(params["processor_param_int"], 1)
+
+ def test_load_processed_estimator_from_cache(self):
+ with self.assertRaises(Exception):
+ self.preprocessor.load_processed_estimator_from_cache(
+ "nonexistent_cache_file"
+ )
+
+ # TODO: Make a better test for this
+ # def test_get_features_out(self):
+ # X = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})
+ # features_out = self.preprocessor._get_features_out(X, selected_indices=[0])
+ # self.assertIsInstance(features_out, dict)
+
+ def test_prepare_runner(self):
+ fn_kwargs = {}
+ result_kwargs = self.preprocessor._prepare_runner(None, **fn_kwargs)
+ self.assertEqual(result_kwargs, fn_kwargs)
+
+ # def test_run(self):
+ # def runner(X, **kwargs):
+ # return "processed"
+ #
+ # result = self.preprocessor.run("input_data", runner=runner)
+ # self.assertEqual(result, "processed")
+
+ def test_get_feature_names_out(self):
+ # Testing _get_feature_names_out method
+ input_features = ["feature1", "feature2"]
+ n_features_out = 2
+ useful_feature_inds = [0, 1]
+ out_features = self.preprocessor._get_feature_names_out(
+ input_features=input_features,
+ n_features_out=n_features_out,
+ useful_feature_inds=useful_feature_inds,
+ one_to_one_features=True,
+ )
+ self.assertEqual(out_features, ["feature1", "feature2"])
+
+ # Test with prefix and suffix
+ self.preprocessor.config.features_out_prefix = "prefix_"
+ self.preprocessor.config.features_out_suffix = "_suffix"
+ out_features = self.preprocessor._get_feature_names_out(
+ input_features=input_features,
+ n_features_out=n_features_out,
+ useful_feature_inds=useful_feature_inds,
+ one_to_one_features=True,
+ )
+ self.assertEqual(
+ out_features, ["prefix_feature1_suffix", "prefix_feature2_suffix"]
+ )
+
+    # def test_transform_with_unused_columns(self):
+    #     # Test transforming data while keeping unused columns
+    #     X = self.X.copy()
+    #     self.preprocessor.fit(X[["feature1", "feature2"]])
+    #
+    #     transformed_X = self.preprocessor.transform(X, keep_unused_columns=True)
+    #
+    #     # Verify that unused columns are present
+    #     self.assertIn("metadata", transformed_X.columns)
+    #     self.assertIn("target", transformed_X.columns)
+    #     # Verify that 'feature1' and 'feature2' are centered
+    #     expected_X = X[["feature1", "feature2"]] - X[["feature1", "feature2"]].mean()
+    #     pd.testing.assert_frame_equal(
+    #         transformed_X[["feature1", "feature2"]], expected_X
+    #     )
+    #     # Verify that unused columns are unchanged
+    #     pd.testing.assert_series_equal(transformed_X["metadata"], X["metadata"])
+    #     pd.testing.assert_series_equal(transformed_X["target"], X["target"])
+
+ def test_process_fit_batch_input_valid(self):
+ """
+ Test _process_fit_batch_input with valid arguments.
+ """
+ input_processed, fn_args, fn_kwargs = self.model._process_fit_batch_input(
+ DataHandler.to_arrow(self.data),
+ fn=self.model._fit_sklearn,
+ selected_indices=[0, 1, 2],
+ extra_indices=[[-1]],
+ unused_indices=[],
+ extra_untouched_inputs=None,
+ with_metadata=False,
+ in_format_kwargs={"target_format": "pandas"},
+ )
+
+ # Assertions
+ self.assertIsNotNone(input_processed)
+ self.assertIsInstance(fn_args, tuple)
+ self.assertEqual(len(fn_args), 1)
+ format = DataHandler.get_format(input_processed)
+ self.assertEqual(format, "pandas")
+
+ def test_process_fit_batch_unmatched_arity(self):
+ """
+        Test _process_fit_batch_input with mismatched arity between inputs and columns.
+ """
+ # pretty much the same as test_fit_transform_without_automatic_column_selection
+ # but directly testing the method
+ with self.assertRaises(AssertionError) as context:
+ self.model._process_fit_batch_input(
+ DataHandler.to_arrow(self.data),
+ fn=self.model._fit_sklearn,
+ selected_indices=[0, 1, 2],
+ extra_indices=[],
+ unused_indices=[],
+ extra_untouched_inputs=None,
+ with_metadata=False,
+ in_format_kwargs={"target_format": "pandas"},
+ )
+ self.assertIn(
+ str(context.exception),
+ "`MockModel.fit` requires 2 arguments ('X', 'y'), but only 1 was "
+ "provided. Either provide the missing arguments or provide the input "
+ "columns found in 'X', if applicable.",
+ )
+
+ def test_process_fit_batch_output_invalid(self):
+ """
+ Test _process_fit_batch_output with invalid output.
+ """
+ # This test doesn't really do anything, but it's here for completeness
+ output = None
+
+ processed_output = self.model._process_fit_batch_output(output)
+
+ # Assertions
+ self.assertIsNone(processed_output)
+
+ def test_process_transform_batch_input_valid(self):
+ """
+ Test _process_transform_batch_input with valid arguments.
+ """
+ input_data = self.X
+
+ input_processed, fn_args, _ = self.preprocessor._process_transform_batch_input(
+ input_data,
+ fn=self.preprocessor._transform_pandas,
+ selected_indices=[0, 1, 2],
+ unused_indices=[3],
+ keep_unused_columns=True,
+ with_metadata=False,
+ in_format_kwargs={},
+ out_format_kwargs={},
+ )
+
+ # Assertions
+ self.assertIsNotNone(input_processed)
+ self.assertIsInstance(fn_args, tuple)
+ self.assertEqual(len(fn_args), 0)
+
+ def test_process_transform_batch_input_empty(self):
+ """
+ Test _process_transform_batch_input with empty input.
+ """
+ with self.assertRaises(AssertionError) as context:
+ input_data = None
+
+ self.preprocessor._process_transform_batch_input(
+ input_data,
+ fn=self.preprocessor._transform_pandas,
+ selected_indices=[0, 1, 2],
+ unused_indices=[3],
+ keep_unused_columns=True,
+ with_metadata=False,
+ in_format_kwargs={},
+ out_format_kwargs={},
+ )
+
+ self.assertIn(
+ str(context.exception),
+ "No input data was provided for processing.",
+ )
+
+ def test_process_transform_batch_input_invalid_columns(self):
+ """
+ Test _process_transform_batch_input with invalid/missing columns.
+ """
+ with self.assertRaises(IndexError):
+ input_data = self.X
+
+ self.preprocessor._process_transform_batch_input(
+ input_data,
+ fn=self.preprocessor._transform_pandas,
+ selected_indices=[self.X.shape[1] + 1],
+ unused_indices=[3],
+ keep_unused_columns=True,
+ with_metadata=False,
+ in_format_kwargs={},
+ out_format_kwargs={},
+ )
+
+ def test_process_transform_batch_output_valid(self):
+ """
+ Test _process_transform_batch_output with valid output.
+ """
+ input = DataHandler.to_polars(self.X, [0, 1, 2, 3])
+ output = DataHandler.to_polars(self.X, [0, 1, 3])
+
+ processed_output = self.preprocessor._process_transform_batch_output(
+ input,
+ output,
+ selected_indices=[0, 1, 3],
+ unused_indices=[2],
+ keep_unused_columns=True,
+ feature_names=["test1", "test2", "int32_3", "test4"],
+ out_format_kwargs={"target_format": "arrow"},
+ one_to_one_features=True,
+ )
+
+ # Assertions
+ self.assertEqual(
+ processed_output.column_names, ["test1", "test2", "int32_3", "test4"]
+ )
+ self.assertIsInstance(processed_output, pa.Table)
+
+ # def test_process_transform_batch_output_keep_unused_columns(self):
+ # """
+ # Test _process_transform_batch_output with keeping unused columns.
+ # """
+ # transformed_X = self.preprocessor._transform_pandas(self.X)
+ #
+ # processed_output = self.preprocessor._process_transform_batch_output(
+ # self.X,
+ # transformed_X,
+ # fn_kwargs={
+ # "selected_indices": [0, 1, 2],
+ # "unused_indices": [3],
+ # "keep_unused_columns": True,
+ # "feature_names": ["sample_id", "multi_int", "multi_str", "labels"],
+ # "out_format_kwargs": {"target_format": "pandas"},
+ # },
+ # )
+ #
+ # # Assertions
+ # self.assertIn("labels", processed_output.column_names)
+ # self.assertEqual(len(processed_output.columns), 4)
+
+ # def test_process_transform_batch_output_discard_unused_columns(self):
+ # """
+ # Test _process_transform_batch_output with discarding unused columns.
+ # """
+ # transformed_X = self.preprocessor._transform_pandas(self.X)
+ #
+ # processed_output = self.preprocessor._process_transform_batch_output(
+ # self.X,
+ # transformed_X,
+ # fn_kwargs={
+ # "selected_indices": [0, 1, 2],
+ # "unused_indices": [3],
+ # "keep_unused_columns": False,
+ # "feature_names": ["sample_id", "multi_int", "multi_str"],
+ # "out_format_kwargs": {"target_format": "pandas"},
+ # },
+ # )
+ #
+ # # Assertions
+ # self.assertNotIn("labels", processed_output.columns)
+ # self.assertEqual(len(processed_output.columns), 3)
+
+ # def test_process_transform_batch_output_invalid_output_format(self):
+ # """
+ # Test _process_transform_batch_output with invalid output format.
+ # """
+ # transformed_X = self.preprocessor._transform_pandas(self.X)
+ #
+ # with self.assertRaises(ValueError):
+ # self.preprocessor._process_transform_batch_output(
+ # self.X,
+ # transformed_X,
+ # fn_kwargs={
+ # "selected_indices": [0, 1, 2],
+ # "unused_indices": [3],
+ # "keep_unused_columns": False,
+ # "feature_names": ["sample_id", "multi_int", "multi_str"],
+ # "out_format_kwargs": {"target_format": "invalid_format"},
+ # },
+ # )
+
+ # def test_process_fit_batch_input_invalid_fn_kwargs(self):
+ # """
+ # Test _process_fit_batch_input with invalid fn_kwargs.
+ # """
+ # with self.assertRaises(AttributeError):
+ # input_data = self.X
+ # y = self.y
+ #
+ # # Pass incorrect fn_kwargs (missing 'fn')
+ # self.model._process_fit_batch_input(
+ # input_data,
+ # y,
+ # fn_kwargs={
+ # "selected_indices": [0, 1, 2],
+ # "extra_indices": [],
+ # "unused_indices": [],
+ # "extra_untouched_inputs": [],
+ # "with_metadata": False,
+ # "indicate_last_batch": False,
+ # "in_format_kwargs": {},
+ # "out_format_kwargs": {},
+ # "with_target": True,
+ # },
+ # )
+
+ # def test_process_transform_batch_input_with_extra_inputs(self):
+ # """
+ # Test _process_transform_batch_input with extra inputs.
+ # """
+ # with patch.object(
+ # self.preprocessor, "_transform_pandas", return_value=self.X
+ # ) as mock_transform:
+ # input_data = self.X
+ # extra_input = pa.table({"extra_col": [4, 5, 6]})
+ #
+ # input_processed, fn_args, fn_kwargs = (
+ # self.preprocessor._process_transform_batch_input(
+ # input_data,
+ # extra_input,
+ # fn_kwargs={
+ # "fn": self.preprocessor._transform_pandas,
+ # "selected_indices": [0, 1, 2],
+ # "unused_indices": [3],
+ # "keep_unused_columns": True,
+ # "with_metadata": False,
+ # "in_format_kwargs": {},
+ # "out_format_kwargs": {},
+ # },
+ # )
+ # )
+ #
+ # # Assertions
+ # self.assertIsNotNone(input_processed)
+ # self.assertIsInstance(fn_args, tuple)
+ # self.assertEqual(len(fn_args), 1)
+ # self.assertIsInstance(fn_args[0], pa.Table)
+ # self.assertEqual(fn_kwargs["fn"], self.preprocessor._transform_pandas)
+ # mock_transform.assert_called_once()
+
+ def test_process_fit_input(self):
+ input = "input_data"
+ kwargs = {"key": "value"}
+ result_input, result_kwargs = self.preprocessor._process_fit_input(
+ input, **kwargs
+ )
+ self.assertEqual(result_input, input)
+ self.assertEqual(result_kwargs, kwargs)
+
+ def test_process_fit_output(self):
+ input = "input_data"
+ out = "output_data"
+ result = self.preprocessor._process_fit_output(input, out)
+ self.assertEqual(result, out)
+
+ # def test_repr(self):
+ # repr_str = repr(self.preprocessor)
+ # self.assertIsInstance(repr_str, str)
+ # self.assertIn("MockPreprocessor", repr_str)
+ #
+ # def test_repr_mimebundle(self):
+ # mimebundle = self.preprocessor._repr_mimebundle_()
+ # self.assertIsInstance(mimebundle, dict)
+ # self.assertIn("text/plain", mimebundle)
+ #
+ # def test_shard(self):
+ # # Testing shard method
+ # num_shards = 5
+ # for index in range(num_shards):
+ # shard = self.preprocessor.shard(
+ # self.X, num_shards=num_shards, index=index, contiguous=True
+ # )
+ # # Verify that the shards cover the entire dataset when combined
+ # if index == 0:
+ # combined_shard = shard
+ # else:
+ # combined_shard = pd.concat([combined_shard, shard], ignore_index=True)
+ # pd.testing.assert_frame_equal(
+ # combined_shard.reset_index(drop=True), self.X.reset_index(drop=True)
+ # )
+
+ # def test_pool_fit(self):
+ # # Testing _pool_fit method
+ # # Since pooling is not implemented for multiple processors, expect NotImplementedError
+ #
+ # # Create multiple fitted processors
+ # fitted_processors = [
+ # self.preprocessor.fit(self.X[["feature1", "feature2"]]) for _ in range(2)
+ # ]
+ #
+ # with self.assertRaises(NotImplementedError):
+ # self.preprocessor._pool_fit(fitted_processors)
+ #
+ # # Test with a single processor
+ # pooled_processor = self.preprocessor._pool_fit([self.preprocessor])
+ # self.assertEqual(pooled_processor, self.preprocessor)
+ #
+ # def test_map(self):
+ # X = pd.DataFrame({"col1": [1, 2, 3]})
+ #
+ # def func(batch):
+ # return batch * 2
+ #
+ # result = self.preprocessor.map(X, function=func, batched=True)
+ # expected = X * 2
+ # pd.testing.assert_frame_equal(result[0], expected)
+ #
+ # def test_map_single(self):
+ # shard = pd.DataFrame({"col1": [1, 2, 3]})
+ #
+ # def func(batch):
+ # return batch * 2
+ #
+ # gen = BaseProcessor._map_single(shard, function=func, batched=True)
+ # results = list(gen)
+ # for rank, done, content in results:
+ # if done:
+ # processed_data = content
+ # expected = shard * 2
+ # pd.testing.assert_frame_equal(processed_data[0], expected)
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..9eac24a
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,133 @@
+import unittest
+
+import pandas as pd
+from biocore.data_handling import DataHandler
+from biocore.utils.import_util import (
+ is_biosets_available,
+ is_datasets_available,
+ is_polars_available,
+ is_rpy2_available,
+)
+
+
+def require_polars(test_case):
+ """
+ Decorator marking a test that requires Polars.
+
+ These tests are skipped when Polars isn't installed.
+
+ """
+ if not is_polars_available():
+ test_case = unittest.skip("test requires Polars")(test_case)
+ return test_case
+
+
+def require_biosets(test_case):
+ """
+ Decorator marking a test that requires biosets.
+
+ These tests are skipped when biosets isn't installed.
+
+ """
+ if not is_biosets_available():
+ test_case = unittest.skip("test requires biosets")(test_case)
+ return test_case
+
+
+def require_datasets(test_case):
+ """
+ Decorator marking a test that requires datasets.
+
+ These tests are skipped when datasets isn't installed.
+
+ """
+ if not is_datasets_available():
+ test_case = unittest.skip("test requires datasets")(test_case)
+ return test_case
+
+
+def require_rpy2(test_case):
+ """
+ Decorator marking a test that requires rpy2.
+
+ These tests are skipped when rpy2 isn't installed.
+
+ """
+ if not is_rpy2_available():
+ test_case = unittest.skip("test requires rpy2")(test_case)
+ return test_case
+
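+# Example usage (illustrative, not from the source): each decorator wraps a test
+# function or test class and skips it when the optional dependency is missing:
+#
+#     @require_polars
+#     def test_polars_roundtrip(self):
+#         import polars as pl
+#         ...
+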
+
+def create_bioset(
+ X,
+ y=None,
+ sample_metadata=None,
+ with_feature_metadata=True,
+ feature_type=None,
+ target_type=None,
+):
+ """
+ Create a bioset from data.
+
+    Args:
+        X (pd.DataFrame): The data matrix.
+        y (pd.DataFrame, optional): The target column(s).
+        sample_metadata (pd.DataFrame, optional): Per-sample metadata columns to
+            prepend to the table.
+        with_feature_metadata (bool): Whether to attach a dummy metadata dict to
+            each data feature.
+        feature_type (type): Feature class used for the data columns.
+        target_type (type, optional): Feature class used for the target column.
+
+    Returns:
+        Bioset: The constructed bioset.
+
+ """
+ from biosets.features import Metadata, Sample, Batch
+ from datasets import ClassLabel, Features
+
+ data = [X]
+ if y is not None:
+ data.append(y)
+ if sample_metadata is not None:
+ data = [sample_metadata] + data
+
+ data = pd.concat(data, axis=1)
+ dtypes = DataHandler.get_dtypes(X)
+
+ features = {}
+ if sample_metadata is not None:
+ sample_metadata_dtypes = DataHandler.get_dtypes(sample_metadata)
+ for k, v in sample_metadata_dtypes.items():
+ if "sample" in k.lower():
+ features[k] = Sample(v)
+ elif "batch" in k.lower():
+ features[k] = Batch(v)
+ else:
+ features[k] = Metadata(v)
+
+ if with_feature_metadata:
+ metadata = {
+ "my_metadata_str": "str",
+ "my_metadata_int": 0,
+ "my_metadata_float": 0.0,
+ "my_metadata_bool": False,
+ "my_metadata_list": ["a", "b", "c"],
+ "my_metadata_dict": {"a": 1, "b": 2, "c": 3},
+ "my_metadata_none": None,
+ }
+ features.update(
+ {k: feature_type(dtype=v, metadata=metadata) for k, v in dtypes.items()}
+ )
+ else:
+ features.update({k: feature_type(dtype=v) for k, v in dtypes.items()})
+ if y is not None and target_type is not None:
+ if issubclass(target_type, ClassLabel):
+            # Sort for a deterministic label order; set iteration order is arbitrary.
+            names = sorted(set(y.values.flatten().tolist()))
+
+            if isinstance(names[0], int):
+                names = ["abcdefghijklmnopqrstuvwxyz"[i] for i in names if i >= 0]
+
+ features[y.columns[0]] = target_type(num_classes=len(names), names=names)
+ else:
+ features[y.columns[0]] = target_type()
+
+ ds = DataHandler.to_bioset(data)
+ ds._info.features = Features(features)
+ return ds
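+
+
+# Minimal usage sketch (illustrative; `Value`/`ClassLabel` come from `datasets`).
+# `with_feature_metadata=False` is used because `datasets.Value` does not accept
+# a `metadata` argument; a biosets feature type would be passed instead when
+# feature-level metadata is needed.
+#
+#     from datasets import ClassLabel, Value
+#
+#     X = pd.DataFrame({"feature1": [1, 2, 3], "feature2": [4, 5, 6]})
+#     y = pd.DataFrame({"target": ["a", "b", "a"]})
+#     ds = create_bioset(
+#         X,
+#         y,
+#         with_feature_metadata=False,
+#         feature_type=Value,
+#         target_type=ClassLabel,
+#     )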
diff --git a/tools/generate_auto_maps.py b/tools/generate_auto_maps.py
new file mode 100644
index 0000000..27b46ee
--- /dev/null
+++ b/tools/generate_auto_maps.py
@@ -0,0 +1,414 @@
+import ast
+import os
+import re
+from collections import defaultdict
+from pathlib import Path
+
+from biofit import DATASET_NAME_TO_OMIC_TYPE
+
+OMIC_TYPE_TO_DATASET_NAMES = defaultdict(list)
+ALL_DATASET_NAMES = set()
+for dataset_name, omic_type in DATASET_NAME_TO_OMIC_TYPE.items():
+ OMIC_TYPE_TO_DATASET_NAMES[omic_type].append(dataset_name)
+ ALL_DATASET_NAMES.add(dataset_name)
+
+
+class ClassFinder(ast.NodeVisitor):
+ def __init__(self, known_classes, module_prefix, target_names):
+ self.found_classes = {}
+ self.known_classes = known_classes
+ self.module_prefix = module_prefix
+ self.target_names = target_names
+
+ def visit_ClassDef(self, node):
+ # Check if any base of the current class is a known ProcessorConfig descendant
+ for base in node.bases:
+ if isinstance(base, ast.Name) and base.id in self.known_classes:
+ target_vals = [None] * len(self.target_names)
+ # Process each node in the class body
+ for body_node in node.body:
+ if isinstance(body_node, (ast.AnnAssign, ast.Assign)):
+ target_name = self.get_target_var_name(body_node)
+ if target_name in self.target_names:
+ index = self.target_names.index(target_name)
+ target_vals[index] = self.extract_name_value(
+ body_node.value
+ )
+ self.found_classes[node.name] = (
+ self.module_prefix,
+ base.id,
+ *target_vals,
+ )
+
+ # Assume this class is also a config now
+ self.known_classes.append(node.name)
+ self.generic_visit(node)
+
+ def get_target_var_name(self, node):
+ if isinstance(node, ast.AnnAssign):
+ return node.target.id
+ elif isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name):
+ return node.targets[0].id
+ return None
+
+ def extract_name_value(self, node):
+ if isinstance(node, ast.Name):
+ return node.id
+ elif isinstance(node, ast.Constant):
+ return node.value
+ elif isinstance(node, ast.Call):
+ # Check if the first positional argument or a keyword argument 'default' is set
+ if node.args:
+ return self.safe_literal_eval(node.args[0])
+ for kw in node.keywords:
+ if kw.arg == "default":
+ return self.safe_literal_eval(kw.value)
+ return None
+
+ def safe_literal_eval(self, node):
+ if isinstance(node, ast.Constant):
+ return node.value
+ return None
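+
+
+# Illustrative example (hypothetical class, not from the source tree): given
+#
+#     class PCAConfig(ProcessorConfig):
+#         processor_type: str = field(default="feature_extraction", init=False)
+#         processor_name: str = field(default="pca", init=False)
+#
+# a ClassFinder built with known_classes=["ProcessorConfig"] and
+# target_names=["processor_type", "processor_name", "dataset_name"] records
+#
+#     found_classes["PCAConfig"] == (
+#         module_prefix, "ProcessorConfig", "feature_extraction", "pca", None
+#     )
+#
+# and appends "PCAConfig" to known_classes so later subclasses are picked up too.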
+
+
+def finalize_configs(classes):
+ out = {}
+ for key, value in classes.items():
+ module_prefix, base_id, processor_type, processor_name, dataset_name = value
+ if base_id in classes:
+ _base_id = base_id
+ while not processor_type:
+ try:
+ _, _base_id, processor_type, _, _ = classes[_base_id]
+ except KeyError:
+ break
+
+ _base_id = base_id
+ while not processor_name:
+ try:
+ _, _base_id, _, processor_name, _ = classes[_base_id]
+ except KeyError:
+ break
+ out[key] = (module_prefix, processor_type, processor_name, dataset_name)
+ return out
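+
+
+# Sketch of the inheritance resolution above (hypothetical names): with
+#
+#     classes = {
+#         "PCAConfig": ("biofit.feature_extraction", "ProcessorConfig",
+#                       "feature_extraction", "pca", None),
+#         "PCAConfigForSNP": ("biofit.feature_extraction", "PCAConfig",
+#                             None, None, "snp"),
+#     }
+#
+# "PCAConfigForSNP" inherits processor_type/processor_name from its base, giving
+#
+#     out["PCAConfigForSNP"] == (
+#         "biofit.feature_extraction", "feature_extraction", "pca", "snp"
+#     )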
+
+
+def finalize_processors(classes, config_classes):
+ name2proc = {}
+ name2config = {}
+ name2category = {}
+ name2type = {}
+
+ for key, value in classes.items():
+ module_prefix, base_id, config_class_name = value
+ config_class = (
+ config_classes.get(config_class_name, None) if config_class_name else None
+ )
+ if config_class:
+ processor_type, processor_name, dataset_name = config_class[1:]
+ if processor_name:
+ name2proc[processor_name] = key
+ name2config[processor_name] = config_class_name
+ processor_category = module_prefix.split(".")[1]
+ if processor_type == processor_category:
+ processor_type = None
+ if processor_type:
+ name2type[processor_name] = processor_type
+ name2category[processor_name] = processor_category
+
+ dataset2config = defaultdict(dict)
+ proc2config = defaultdict(list)
+ for key, value in config_classes.items():
+ module_prefix, processor_type, processor_name, dataset_name = value
+ if processor_name:
+ proc2config[processor_name].append(
+ (module_prefix, key, processor_type, dataset_name)
+ )
+ for key, values in proc2config.items():
+ _dataset2config = {value[3]: value for value in values if value[3]}
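+        # Assumes every processor name has at least one dataset-agnostic config;
+        # otherwise this raises IndexError.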
+ generic_config = [value for value in values if not value[3]][0]
+ for dataset_name in ALL_DATASET_NAMES:
+ if dataset_name in _dataset2config:
+ _, config_class, _, _ = _dataset2config[dataset_name]
+ dataset2config[dataset_name][key] = config_class
+ else:
+ omic_type = DATASET_NAME_TO_OMIC_TYPE.get(dataset_name, None)
+ if omic_type in _dataset2config:
+ _, config_class, _, _ = _dataset2config[omic_type]
+ dataset2config[dataset_name][key] = config_class
+ else:
+ _, config_class, _, _ = generic_config
+ dataset2config[dataset_name][key] = config_class
+
+ return name2proc, name2config, name2category, name2type, dataset2config
+
+
+def get_all_processor_configs(
+ source_folder: str, module_name="biofit", known_classes="ProcessorConfig"
+):
+ if not isinstance(known_classes, list):
+ known_classes = [known_classes]
+ _processors = {}
+ package_folder = Path(source_folder) / module_name.replace(".", "/")
+ for root, dirs, files in os.walk(package_folder.as_posix()):
+ relative_root = os.path.relpath(root, start=source_folder)
+ module_prefix = relative_root.replace(os.sep, ".")
+ for file in files:
+ if file.endswith(".py"):
+ # Construct the module path relative to the root
+ module_path = os.path.join(root, file)
+                with open(module_path, "r", encoding="utf-8") as source_file:
+                    try:
+                        # Parse the file content into an AST
+                        tree = ast.parse(source_file.read(), filename=module_path)
+ # Initialize the finder and visit the AST
+ finder = ClassFinder(
+ known_classes,
+ module_prefix,
+ ["processor_type", "processor_name", "dataset_name"],
+ )
+ finder.visit(tree)
+ # Collect found processors
+ _processors.update(finder.found_classes)
+ except SyntaxError:
+ print(f"Syntax error in {module_path}, skipping.")
+
+ return finalize_configs(_processors)
+
+
+def get_all_processors(
+ source_folder: str,
+ module_name="biofit",
+ known_classes="BaseProcessor",
+ config_classes=None,
+):
+ if not isinstance(known_classes, list):
+ known_classes = [known_classes]
+ _processors = {}
+ package_folder = Path(source_folder) / module_name.replace(".", "/")
+ for root, dirs, files in os.walk(package_folder.as_posix()):
+ relative_root = os.path.relpath(root, start=source_folder)
+ module_prefix = relative_root.replace(os.sep, ".")
+ for file in files:
+ if file.endswith(".py"):
+ # Construct the module path relative to the root
+ module_path = os.path.join(root, file)
+                with open(module_path, "r", encoding="utf-8") as source_file:
+                    try:
+                        # Parse the file content into an AST
+                        tree = ast.parse(source_file.read(), filename=module_path)
+ # Initialize the finder and visit the AST
+ finder = ClassFinder(
+ known_classes, module_prefix, ["_config_class"]
+ )
+ finder.visit(tree)
+ # Collect found processors
+ _processors.update(finder.found_classes)
+ except SyntaxError:
+ print(f"Syntax error in {module_path}, skipping.")
+ return finalize_processors(_processors, config_classes)
+
+
+def create_config_mapping_constants(
+ name2proc,
+ name2config,
+ name2category,
+ name2type,
+ dataset2config,
+ name2plotter,
+ name2pltconfig,
+ dataset2pltconfig,
+):
+ processing_mapping_names_str = "PROCESSOR_MAPPING_NAMES = OrderedDict(\n [\n"
+ for key, value in name2proc.items():
+ processing_mapping_names_str += f' ("{key}", "{value}"),\n'
+ processing_mapping_names_str += " ]\n)"
+
+ plotter_mapping_names_str = "PLOTTER_MAPPING_NAMES = OrderedDict(\n [\n"
+ for key, value in name2plotter.items():
+ plotter_mapping_names_str += f' ("{key}", "{value}"),\n'
+ plotter_mapping_names_str += " ]\n)"
+
+ config_mapping_names_str = "CONFIG_MAPPING_NAMES = OrderedDict(\n [\n"
+ for key, value in name2config.items():
+ config_mapping_names_str += f' ("{key}", "{value}"),\n'
+ config_mapping_names_str += " ]\n)"
+
+ pltconfig_mapping_names_str = "PLOTTER_CONFIG_MAPPING_NAMES = OrderedDict(\n [\n"
+ for key, value in name2pltconfig.items():
+ pltconfig_mapping_names_str += f' ("{key}", "{value}"),\n'
+ pltconfig_mapping_names_str += " ]\n)"
+
+ processor_category_mapping_names_str = (
+ "PROCESSOR_CATEGORY_MAPPING_NAMES = OrderedDict(\n [\n"
+ )
+ for key, value in name2category.items():
+ processor_category_mapping_names_str += f' ("{key}", "{value}"),\n'
+ processor_category_mapping_names_str += " ]\n)"
+
+ processor_type_mapping_names_str = (
+ "PROCESSOR_TYPE_MAPPING_NAMES = OrderedDict(\n [\n"
+ )
+ for key, value in name2type.items():
+ processor_type_mapping_names_str += f' ("{key}", "{value}"),\n'
+ processor_type_mapping_names_str += " ]\n)"
+
+ head_str = processing_mapping_names_str
+ head_str += "\n\n"
+ head_str += plotter_mapping_names_str
+ head_str += "\n\n"
+ head_str += config_mapping_names_str
+ head_str += "\n\n"
+ head_str += pltconfig_mapping_names_str
+ head_str += "\n\n"
+ head_str += processor_category_mapping_names_str
+ head_str += "\n\n"
+ head_str += processor_type_mapping_names_str
+ head_str += "\n\n"
+
+ regex = re.compile(r"\W")
+
+ for key, value in dataset2config.items():
+ mapping_name = f"{regex.sub('_', key).upper()}_MAPPING_NAMES"
+ head_str += f"{mapping_name} = OrderedDict(\n [\n"
+ for k, v in value.items():
+ head_str += f' ("{k}", "{v}"),\n'
+ head_str += " ]\n)\n\n"
+
+ for key, value in dataset2pltconfig.items():
+ mapping_name = f"{regex.sub('_', key).upper()}_PLOTTER_MAPPING_NAMES"
+ head_str += f"{mapping_name} = OrderedDict(\n [\n"
+ for k, v in value.items():
+ head_str += f' ("{k}", "{v}"),\n'
+ head_str += " ]\n)"
+ head_str += "\n\n"
+
+ mapping_str = "CONFIG_MAPPING = _LazyConfigMapping(CONFIG_MAPPING_NAMES)\n"
+ mapping_str += (
+ "PLOTTER_CONFIG_MAPPING = _LazyConfigMapping(PLOTTER_CONFIG_MAPPING_NAMES)\n"
+ )
+
+ for key, value in dataset2config.items():
+ mapping_name = f"{regex.sub('_', key).upper()}_MAPPING_NAMES"
+ mapping_val = f"{regex.sub('_', key).upper()}_MAPPING"
+ mapping_str += f"{mapping_val} = _LazyConfigMapping({mapping_name})\n"
+
+ for key, value in dataset2pltconfig.items():
+ mapping_name = f"{regex.sub('_', key).upper()}_PLOTTER_MAPPING_NAMES"
+ mapping_val = f"{regex.sub('_', key).upper()}_PLOTTER_MAPPING"
+ mapping_str += f"{mapping_val} = _LazyConfigMapping({mapping_name})\n"
+
+ ds2mapper = "DATASET_TO_MAPPER = {"
+ for key, value in dataset2config.items():
+ mapping_val = f"{regex.sub('_', key).upper()}_MAPPING"
+ ds2mapper += f'"{key}": {mapping_val},'
+ ds2mapper += "}"
+
+ ds2mapper_names = "DATASET_TO_MAPPER_NAMES = {"
+ for key, value in dataset2config.items():
+ mapping_name = f"{regex.sub('_', key).upper()}_MAPPING_NAMES"
+ ds2mapper_names += f'"{key}": {mapping_name},'
+ ds2mapper_names += "}"
+
+ dsplt2mapper = "DATASET_PLT_TO_MAPPER = {"
+ for key, value in dataset2pltconfig.items():
+ mapping_val = f"{regex.sub('_', key).upper()}_PLOTTER_MAPPING"
+ dsplt2mapper += f'"{key}": {mapping_val},'
+ dsplt2mapper += "}"
+
+ dsplt2mapper_names = "DATASET_PLT_TO_MAPPER_NAMES = {"
+ for key, value in dataset2pltconfig.items():
+ mapping_name = f"{regex.sub('_', key).upper()}_PLOTTER_MAPPING_NAMES"
+ dsplt2mapper_names += f'"{key}": {mapping_name},'
+ dsplt2mapper_names += "}"
+
+ mapping_str += "\n\n"
+ mapping_str += ds2mapper
+ mapping_str += "\n\n"
+ mapping_str += ds2mapper_names
+ mapping_str += "\n\n"
+ mapping_str += dsplt2mapper
+ mapping_str += "\n\n"
+ mapping_str += dsplt2mapper_names
+ mapping_str += "\n\n"
+
+ return head_str, mapping_str
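+
+
+# For illustration, a name2proc mapping of {"pca": "PCAFeatureExtractor"}
+# (hypothetical entry) is rendered into head_str as
+#
+#     PROCESSOR_MAPPING_NAMES = OrderedDict(
+#         [
+#             ("pca", "PCAFeatureExtractor"),
+#         ]
+#     )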
+
+
+def replace_mapping(head_str, mapping_str, file_path):
+ with open(file_path, "r") as file:
+ source_code = file.read()
+
+ tree = ast.parse(source_code)
+
+ head_start_point = None
+ head_end_point = None
+ mapping_start_point = None
+ mapping_end_point = None
+
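+    # Four anchor lines are located in the target file: head_str is spliced in
+    # between the imports and config_class_to_model_type(); mapping_str goes
+    # between the _LazyConfigMapping class and DATASET_PREPROCESSOR_MAPPING_NAMES.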
+ _head_insert_point = 0
+ # Find where to end head_str
+ for node in ast.walk(tree):
+ if not head_start_point and (
+ isinstance(node, ast.ImportFrom) or isinstance(node, ast.Import)
+ ):
+ _head_insert_point = node.lineno + 1
+
+ if (
+ not head_start_point
+ and not isinstance(node, ast.ImportFrom)
+ and not isinstance(node, ast.Import)
+ ):
+ head_start_point = _head_insert_point
+
+ if (
+ isinstance(node, ast.FunctionDef)
+ and node.name == "config_class_to_model_type"
+ ):
+ head_end_point = node.lineno - 1
+
+ if isinstance(node, ast.ClassDef) and node.name == "_LazyConfigMapping":
+ mapping_start_point = node.end_lineno + 1
+
+ if isinstance(node, ast.Assign):
+ if (
+ hasattr(node.targets[0], "id")
+ and node.targets[0].id == "DATASET_PREPROCESSOR_MAPPING_NAMES"
+ ):
+ mapping_end_point = node.lineno - 1
+
+    # splitlines() treats "\r\n" as a single break; re.split(r"[\r\n]") would
+    # insert empty entries for files with CRLF line endings.
+    splitted_source_code = source_code.splitlines()
+ new_source_code = (
+ splitted_source_code[:head_start_point]
+ + [f"\n{head_str}"]
+ + splitted_source_code[head_end_point:mapping_start_point]
+ + [f"\n{mapping_str}"]
+ + splitted_source_code[mapping_end_point:]
+ )
+
+ new_source_code = "\n".join(new_source_code)
+
+ with open(file_path, "w") as file:
+ file.write(new_source_code)
+
+ file_path = Path(file_path).resolve().as_posix()
+ # use ruff to format the code
+ os.system(f"ruff format {file_path}")
+
+
+if __name__ == "__main__":
+ config_classes = get_all_processor_configs("../src")
+ classes = get_all_processors("../src", config_classes=config_classes)
+ plotter_config_classes = get_all_processor_configs(
+ "../src", known_classes="PlotterConfig"
+ )
+ plotter_classes = get_all_processors(
+ "../src", config_classes=plotter_config_classes, known_classes="BasePlotter"
+ )
+    name2plotter, name2pltconfig, _, _, dataset2pltconfig = plotter_classes
+    head_str, mapping_str = create_config_mapping_constants(
+        *classes, name2plotter, name2pltconfig, dataset2pltconfig
+    )
+
+ file_path = "../src/biofit/auto/configuration_auto.py"
+ replace_mapping(head_str, mapping_str, file_path)
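+
+    # Note: the relative "../src" paths above assume the script is run from the
+    # tools/ directory, e.g.:
+    #
+    #     cd tools && python generate_auto_maps.py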