diff --git a/.circleci/config.yml b/.circleci/config.yml
index 231f2fc9..eadb59b6 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -1,6 +1,6 @@
.common-values:
- docker-image: &docker-image circleci/python:3.7.9
+ docker-image: &docker-image circleci/python:3.10
restore-cache: &restore-cache
keys:
@@ -19,7 +19,7 @@
name: Install package
command: |
. venv/bin/activate
- venv/bin/python3 -m pip install tensorflow==1.15.5
+ venv/bin/python3 -m pip install tensorflow==2.10
venv/bin/python3 -m pip install .
version: 2
@@ -53,6 +53,7 @@ jobs:
test:
docker:
- image: *docker-image
+ resource_class: large
steps:
- checkout
- restore_cache: *restore-cache
@@ -63,8 +64,7 @@ jobs:
name: Unit tests with Pytest
command: |
. venv/bin/activate
- venv/bin/python3 -m pytest tests
- no_output_timeout: 60m
+ venv/bin/python3 setup.py test
lint:
docker:
- image: *docker-image
@@ -78,7 +78,7 @@ jobs:
command: |
. venv/bin/activate
venv/bin/python3 -m pip install flake8
- venv/bin/python3 -m flake8 ampligraph --max-line-length 120 --ignore=W291,W293,W503
+ venv/bin/python3 -m flake8 ampligraph --max-line-length 200 --ignore=W605,W503,E231
docs:
docker:
@@ -105,34 +105,29 @@ workflows:
filters:
branches:
only:
- - master
- - develop
- - /release\/.*/
+ - main
+ - ampligraph2/develop
- pip-check:
filters:
branches:
only:
- - master
- - develop
- - /release\/.*/
+ - main
+ - ampligraph2/develop
- lint:
filters:
branches:
only:
- - master
- - develop
- - /release\/.*/
+ - main
+ - ampligraph2/develop
- docs:
filters:
branches:
only:
- - master
- - develop
- - /release\/.*/
+ - main
+ - ampligraph2/develop
- test:
filters:
branches:
only:
- - master
- - develop
- - /release\/.*/
+ - main
+ - ampligraph2/develop
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
new file mode 100644
index 00000000..0ba237a9
--- /dev/null
+++ b/.gitlab-ci.yml
@@ -0,0 +1,64 @@
+# To contribute improvements to CI/CD templates, please follow the Development guide at:
+# https://docs.gitlab.com/ee/development/cicd/templates.html
+# This specific template is located at:
+# https://gitlab.com/gitlab-org/gitlab/-/blob/master/lib/gitlab/ci/templates/Python.gitlab-ci.yml
+
+# Official language image. Look for the different tagged releases at:
+# https://hub.docker.com/r/library/python/tags/
+image: python:3.10.6
+
+# Change pip's cache directory to be inside the project directory since we can
+# only cache local items.
+variables:
+ PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
+
+# Pip's cache doesn't store the python packages
+# https://pip.pypa.io/en/stable/topics/caching/
+#
+# If you want to also cache the installed packages, you have to install
+# them in a virtualenv and cache it as well.
+cache:
+ paths:
+ - .cache/pip
+ - venv/
+
+before_script:
+ - python --version # For debugging
+ - pip install virtualenv
+ - virtualenv venv
+ - source venv/bin/activate
+ - pip install --upgrade pip
+ - pip install tensorflow==2.10
+
+codestyle:
+ script:
+ - pip install pylint
+ - pylint --fail-under=3 -v ./ampligraph
+
+test:
+ script:
+ - python setup.py test
+
+run:
+ script:
+ - pip install wheel setuptools
+ - pip wheel --wheel-dir dist --no-deps .
+ artifacts:
+ paths:
+ - dist/*.whl
+
+pages:
+ script:
+ - pip install sphinx sphinx-rtd-theme
+ - cd docs
+ - make clean autogen html
+ - mkdir ../public/
+ - mv _build/html/ ../public/
+ artifacts:
+ paths:
+ - public
+
+deploy:
+ stage: deploy
+ script: echo "Define your deployment script!"
+ environment: production
diff --git a/.readthedocs.yml b/.readthedocs.yml
index 28cc3e87..ad5c18a0 100644
--- a/.readthedocs.yml
+++ b/.readthedocs.yml
@@ -13,7 +13,7 @@ sphinx:
formats: all
python:
- version: 3.7
+ version: 3.8
install:
- requirements: docs/requirements_readthedocs.txt
- method: setuptools
diff --git a/README.md b/README.md
index c5f77a48..be0c883a 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,9 @@
[![Documentation Status](https://readthedocs.org/projects/ampligraph/badge/?version=latest)](http://ampligraph.readthedocs.io/?badge=latest)
+[![CircleCI](https://dl.circleci.com/status-badge/img/gh/Accenture/AmpliGraph/tree/main.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/Accenture/AmpliGraph/tree/main)
+
+
[Join the conversation on Slack](https://join.slack.com/t/ampligraph/shared_invite/enQtNTc2NTI0MzUxMTM5LTRkODk0MjI2OWRlZjdjYmExY2Q3M2M3NGY0MGYyMmI4NWYyMWVhYTRjZDhkZjA1YTEyMzBkMGE4N2RmNTRiZDg)
![](docs/img/slack_logo.png)
@@ -30,66 +33,68 @@ It then combines embeddings with model-specific scoring functions to predict uns
![](docs/img/kg_lp_step2.png)
-## Key Features
+## AmpliGraph 2.0.0 is now available!
+The new version features a TensorFlow 2 back-end and Keras-style APIs that make it faster and easier to use, and
+extend its support for multiple new features. Furthermore, the data input/output pipeline has changed, and support for
+some obsolete models was discontinued. See the Changelog for a more thorough list of changes, and the short
+sketch after the Key Features list below for a first look at the new APIs.
-* **Intuitive APIs**: AmpliGraph APIs are designed to reduce the code amount required to learn models that predict links in knowledge graphs.
-* **GPU-Ready**: AmpliGraph is based on TensorFlow, and it is designed to run seamlessly on CPU and GPU devices - to speed-up training.
-* **Extensible**: Roll your own knowledge graph embeddings model by extending AmpliGraph base estimators.
+## Key Features
+* **Intuitive APIs**: AmpliGraph APIs are designed to reduce the amount of code required to learn models that predict links in knowledge graphs. The new AmpliGraph 2 APIs are Keras-style, making the user experience even smoother.
+* **GPU-Ready**: AmpliGraph 2 is based on TensorFlow 2, and it is designed to run seamlessly on CPU and GPU devices - to speed up training.
+* **Extensible**: Roll your own knowledge graph embeddings model by extending AmpliGraph base estimators.
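+
+The sketch below, adapted from the evaluation example shipped with this release, shows the new Keras-style workflow end to end (the hyperparameter values are purely illustrative):
+
+```python
+import numpy as np
+from ampligraph.datasets import load_wn18
+from ampligraph.evaluation import hits_at_n_score, mrr_score
+from ampligraph.latent_features import ScoringBasedEmbeddingModel
+
+# Train a ComplEx model with the new Keras-style API.
+X = load_wn18()
+model = ScoringBasedEmbeddingModel(k=150, eta=1, scoring_type='ComplEx')
+model.compile(optimizer='adam', loss='nll')
+model.fit(X['train'], batch_size=int(X['train'].shape[0] / 10), epochs=10)
+
+# Rank test triples against filtered corruptions and compute metrics.
+filter_triples = {'test': np.concatenate([X['train'], X['valid'], X['test']], axis=0)}
+ranks = model.evaluate(X['test'], use_filter=filter_triples, corrupt_side='s,o')
+print(mrr_score(ranks), hits_at_n_score(ranks, n=10))
+```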
## Modules
AmpliGraph includes the following submodules:
* **Datasets**: helper functions to load datasets (knowledge graphs).
-* **Models**: knowledge graph embedding models. AmpliGraph contains **TransE**, **DistMult**, **ComplEx**, **HolE**, **ConvE**, **ConvKB**. (More to come!)
+* **Models**: knowledge graph embedding models. AmpliGraph 2 contains **TransE**, **DistMult**, **ComplEx** and **HolE**. (More to come!)
* **Evaluation**: metrics and evaluation protocols to assess the predictive power of the models.
* **Discovery**: High-level convenience APIs for knowledge discovery (discover new facts, cluster entities, predict near duplicates).
-
+* **Compat**: submodule that provides backward compatibility with the AmpliGraph 1.x APIs, for users already familiar with them (a minimal usage sketch is shown below).
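+
+For instance, a minimal sketch of training and evaluating a model through the backward-compatible APIs (dataset choice and hyperparameters are illustrative):
+
+```python
+import numpy as np
+from ampligraph.datasets import load_wn18rr
+from ampligraph.compat import ComplEx, evaluate_performance
+from ampligraph.evaluation import mrr_score
+
+# AmpliGraph 1.x-style model definition and training.
+X = load_wn18rr()
+model = ComplEx(k=150, eta=5, epochs=10, batches_count=50)
+model.fit(X['train'])
+
+# AmpliGraph 1.x-style evaluation with filtered corruptions.
+filter_triples = np.concatenate([X['train'], X['valid'], X['test']], axis=0)
+ranks = evaluate_performance(X['test'], model=model, filter_triples=filter_triples)
+print(mrr_score(ranks))
+```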
## Installation
### Prerequisites
* Linux, macOS, Windows
-* Python 3.7
+* Python ≥ 3.8
#### Provision a Virtual Environment
Create and activate a virtual environment (conda)
```
-conda create --name ampligraph python=3.7
+conda create --name ampligraph python=3.8
source activate ampligraph
```
#### Install TensorFlow
-AmpliGraph is built on TensorFlow 1.x.
+AmpliGraph 2 is built on TensorFlow 2.x.
Install from pip or conda:
**CPU-only**
```
-pip install "tensorflow>=1.15.2,<2.0"
+pip install "tensorflow>=2.9"
-or
+or
-conda install tensorflow'>=1.15.2,<2.0.0'
+conda install tensorflow'>=2.9'
```
-**GPU support**
+**Install TensorFlow 2 for macOS (Apple M1 chip)**
```
-pip install "tensorflow-gpu>=1.15.2,<2.0"
-
-or
-
-conda install tensorflow-gpu'>=1.15.2,<2.0.0'
+conda install -c apple tensorflow-deps
+pip install --user tensorflow-macos==2.10
+pip install --user tensorflow-metal==0.6
```
-
+In case of installation problems, refer to the [TensorFlow Plugin page on the Apple developer site](https://developer.apple.com/metal/tensorflow-plugin/).
### Install AmpliGraph
@@ -114,9 +119,9 @@ pip install -e .
### Sanity Check
```python
->> import ampligraph
->> ampligraph.__version__
-'1.4.0'
+>>> import ampligraph
+>>> ampligraph.__version__
+'2.0.0'
```
@@ -126,16 +131,20 @@ AmpliGraph includes implementations of TransE, DistMult, ComplEx, HolE, ConvE, a
Their predictive power is reported below and compared against the state-of-the-art results in literature.
[More details available here](https://docs.ampligraph.org/en/latest/experiments.html).
-| |FB15K-237 |WN18RR |YAGO3-10 | FB15k |WN18 |
-|------------------------------|----------|---------|-----------|------------|---------------|
-| Literature Best | **0.35***| 0.48* | 0.49* | **0.84**** | **0.95*** |
-| TransE (AmpliGraph) | 0.31 | 0.22 | **0.51** | 0.63 | 0.66 |
-| DistMult (AmpliGraph) | 0.31 | 0.47 | 0.50 | 0.78 | 0.82 |
-| ComplEx (AmpliGraph) | 0.32 | **0.51**| 0.49 | 0.80 | 0.94 |
-| HolE (AmpliGraph) | 0.31 | 0.47 | 0.50 | 0.80 | 0.94 |
-| ConvE (AmpliGraph) | 0.26 | 0.45 | 0.30 | 0.50 | 0.93 |
-| ConvE (1-N, AmpliGraph) | 0.32 | 0.48 | 0.40 | 0.80 | **0.95** |
-| ConvKB (AmpliGraph) | 0.23 | 0.39 | 0.30 | 0.65 | 0.80 |
+| | FB15K-237 | WN18RR | YAGO3-10 | FB15k | WN18 |
+|---------------------------|-----------|----------|----------|------------|-----------|
+| Literature Best | **0.35*** | 0.48* | 0.49* | **0.84**** | **0.95*** |
+| TransE (AmpliGraph 2) | 0.31 | 0.22 | 0.50 | 0.62 | 0.64 |
+| DistMult (AmpliGraph 2) | 0.30 | 0.47 | 0.48 | 0.71 | 0.82 |
+| ComplEx (AmpliGraph 2) | 0.31 | 0.50 | 0.49 | 0.73 | 0.94 |
+| HolE (AmpliGraph 2) | 0.30 | 0.47 | 0.47 | 0.73 | 0.94 |
+| TransE (AmpliGraph 1) | 0.31 | 0.22 | **0.51** | 0.63 | 0.66 |
+| DistMult (AmpliGraph 1) | 0.31 | 0.47 | 0.50 | 0.78 | 0.82 |
+| ComplEx (AmpliGraph 1) | 0.32 | **0.51** | 0.49 | 0.80 | 0.94 |
+| HolE (AmpliGraph 1) | 0.31 | 0.47 | 0.50 | 0.80 | 0.94 |
+| ConvE (AmpliGraph 1) | 0.26 | 0.45 | 0.30 | 0.50 | 0.93 |
+| ConvE (1-N, AmpliGraph 1) | 0.32 | 0.48 | 0.40 | 0.80 | **0.95** |
+| ConvKB (AmpliGraph 1) | 0.23 | 0.39 | 0.30 | 0.65 | 0.80 |
* Timothee Lacroix, Nicolas Usunier, and Guillaume Obozinski. Canonical tensor decomposition for knowledge base
@@ -179,6 +188,8 @@ If you instead use AmpliGraph in an academic publication, cite as:
```
@misc{ampligraph,
author= {Luca Costabello and
+ Alberto Bernardi and
+ Adrianna Janik and
Sumit Pai and
Chan Le Van and
Rory McGrath and
@@ -194,4 +205,4 @@ If you instead use AmpliGraph in an academic publication, cite as:
## License
-AmpliGraph is licensed under the Apache 2.0 License.
\ No newline at end of file
+AmpliGraph is licensed under the Apache 2.0 License.
diff --git a/ampligraph/__init__.py b/ampligraph/__init__.py
index 33b6474d..26a38d20 100644
--- a/ampligraph/__init__.py
+++ b/ampligraph/__init__.py
@@ -1,18 +1,22 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
-# This file is Licensed under the Apache License, Version 2.0.
+# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
"""AmpliGraph is a library for relational learning on knowledge graphs."""
import logging.config
-import pkg_resources
+import pkg_resources
import tensorflow as tf
+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
-__version__ = '1.4.0'
+__version__ = '2.0.0'
__all__ = ['datasets', 'latent_features', 'discovery', 'evaluation', 'utils']
-logging.config.fileConfig(pkg_resources.resource_filename(__name__, 'logger.conf'), disable_existing_loggers=False)
+logging.config.fileConfig(
+ pkg_resources.resource_filename(__name__, "logger.conf"),
+ disable_existing_loggers=False,
+)
diff --git a/ampligraph/compat/__init__.py b/ampligraph/compat/__init__.py
new file mode 100644
index 00000000..32da6c68
--- /dev/null
+++ b/ampligraph/compat/__init__.py
@@ -0,0 +1,12 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+"""Provides backward compatibility to AmpliGraph 1 APIs."""
+from .evaluate import evaluate_performance
+from .models import ComplEx, DistMult, HolE, TransE
+
+__all__ = ["evaluate_performance", "TransE", "ComplEx", "DistMult", "HolE"]
diff --git a/ampligraph/compat/evaluate.py b/ampligraph/compat/evaluate.py
new file mode 100644
index 00000000..cc92c49c
--- /dev/null
+++ b/ampligraph/compat/evaluate.py
@@ -0,0 +1,179 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+
+import logging
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+def evaluate_performance(
+ X,
+ model,
+ filter_triples=None,
+ verbose=False,
+ entities_subset=None,
+ corrupt_side="s,o",
+ batch_size=1,
+):
+ """Evaluate the performance of an embedding model.
+
+ The evaluation protocol follows the procedure defined in :cite:`bordes2013translating` and can be summarised as:
+
+ #. Artificially generate negative triples by corrupting first the subject and then the object.
+
+ #. Remove the positive triples from the set returned by (1) -- positive triples \
+ are usually the concatenation of training, validation and test sets.
+
+ #. Rank each test triple against all remaining triples returned by (2).
+
+
+ With the ranks of both object and subject corruptions, one may compute metrics such as the MRR by
+ calculating them separately and then averaging them out.
+ Note that the metrics implemented in AmpliGraph's ``evaluate.metrics`` module will already work that way
+ when provided with the input returned by ``evaluate_performance``.
+
+ The artificially generated negatives are compliant with the local closed world assumption (LCWA),
+ as described in :cite:`nickel2016review`. In practice, that means only one side of the triple is
+ corrupted at a time (i.e. either the subject or the object).
+
+ .. note::
+ The evaluation protocol assigns the worst rank to a positive test triple in case of a tie with negatives.
+ This is the agreed upon behaviour in the literature.
+
+ .. hint::
+ When ``entities_subset=None``, the method will use all distinct entities in the knowledge graph ``X``
+ to generate negatives to rank against. This might slow down the evaluation process.
+ Some corruptions may not even make sense for the task that one may be interested in.
+
+ For instance, consider the case ``<Actor, acted_in, ?>``, where we are mainly interested in those movies that
+ an actor has acted in. A sensible way to evaluate this would be to rank against all the movie entities and
+ compute the desired metrics. In such cases, where we want to focus on a subset of entities, it is recommended
+ to pass the desired entities to use for generating corruptions to ``entities_subset``. Besides, trying to
+ rank a positive against an extremely large number of negatives may be overkill.
+
+ As a reference, the popular FB15k-237 dataset has ~15k distinct entities. The evaluation protocol ranks each
+ positive against 15k corruptions per side.
+
+ Parameters
+ ----------
+ X : ndarray, shape (n, 3)
+ An array of test triples.
+ model : EmbeddingModel
+ A knowledge graph embedding model.
+ filter_triples : ndarray, shape (n, 3), or None
+ The triples used to filter negatives.
+
+ .. note::
+ When *filtered* mode is enabled (i.e., ``filter_triples`` is not `None`), to speed up the procedure,
+ we use database-based filtering. The strategy is described below:
+
+ * Store the filter_triples in the DB.
+ * For each test triple, we generate corruptions for evaluation and score them.
+ * The corruptions may contain some False Negatives. We find such statements by querying the database.
+ * From the computed scores we retrieve the scores of the False Negatives.
+ * We compute the rank of the test triple by comparing against ALL the corruptions.
+ * We then compute the number of False Negatives that are ranked higher than the test triple, and
+ subtract this value from the rank computed above to yield the final filtered rank.
+
+ **Execution Time:** This method takes ~4 minutes on FB15K using ComplEx
+ (Intel Xeon Gold 6142, 64 GB Ubuntu 16.04 box, Tesla V100 16GB).
+
+ verbose : bool
+ Verbose mode.
+
+ entities_subset: array-like
+ List of entities to use for corruptions. If `None`, will generate corruptions
+ using all distinct entities (default: `None`).
+ corrupt_side: str
+ Specifies which side of the triple to corrupt:
+
+ - `'s'`: corrupt only subject.
+ - `'o'`: corrupt only object.
+ - `'s+o'`: corrupt both subject and object.
+ - `'s,o'`: corrupt subject and object sides independently and return 2 ranks. This corresponds to the \
+ evaluation protocol used in literature, where head and tail corruptions are evaluated separately.
+
+ .. note::
+ When ``corrupt_side='s,o'`` the function will return :math:`2*n` ranks as a (n, 2) array.
+ The first column of the array represents the subject corruptions.
+ The second column of the array represents the object corruptions.
+ Otherwise, the function returns :math:`n` ranks as (n) array.
+
+ batch_size: int
+ Batch size to use for evaluation.
+
+ Returns
+ -------
+ ranks : ndarray, shape (n) or (n,2)
+ An array of ranks of test triples.
+ When ``corrupt_side='s,o'`` the function returns (n,2). The first column represents the rank against
+ subject corruptions and the second column represents the rank against object corruptions.
+ In other cases, it returns (n), i.e., rank against the specified corruptions.
+
+ Example
+ -------
+ >>> import numpy as np
+ >>> from ampligraph.datasets import load_wn18
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> from ampligraph.evaluation import mrr_score, hits_at_n_score
+ >>>
+ >>> X = load_wn18()
+ >>> model = ScoringBasedEmbeddingModel(k=150, eta=1, scoring_type='ComplEx')
+ >>> model.compile(optimizer='adam', loss='nll')
+ >>> model.fit(X['train'],
+ >>> batch_size=int(X['train'].shape[0] / 10),
+ >>> epochs=10)
+ >>> filter_triples = {'test': np.concatenate([X['train'], X['valid'], X['test']], axis=0)}
+ >>> ranks = model.evaluate(X['test'][:5],
+ >>> use_filter=filter_triples,
+ >>> corrupt_side='s+o')
+ >>> print(ranks)
+ [[ 1]
+ [116]
+ [ 1]
+ [ 1]
+ [214]]
+ >>> print(mrr_score(ranks))
+ 0.6026587173702869
+ >>> print(hits_at_n_score(ranks, n=10))
+ 0.6
+ """
+ logger.debug("Evaluating the performance of the embedding model.")
+ assert corrupt_side in [
+ "s",
+ "o",
+ "s+o",
+ "s,o",
+ ], "Invalid value for corrupt_side."
+
+ if isinstance(filter_triples, np.ndarray) or isinstance(
+ filter_triples, list
+ ):
+ filter_triples = {"valid": filter_triples}
+ elif filter_triples is None or not filter_triples:
+ filter_triples = False
+ elif isinstance(filter_triples, dict):
+ pass
+ else:
+ raise ValueError("Incorrect type for filter_triples")
+
+ return model.evaluate(
+ x=X,
+ batch_size=batch_size,
+ verbose=verbose,
+ use_filter=filter_triples,
+ corrupt_side=corrupt_side,
+ entities_subset=entities_subset,
+ callbacks=None,
+ )
diff --git a/ampligraph/compat/models.py b/ampligraph/compat/models.py
new file mode 100644
index 00000000..002d6ef7
--- /dev/null
+++ b/ampligraph/compat/models.py
@@ -0,0 +1,842 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+
+import numpy as np
+import tensorflow as tf
+
+from ampligraph.latent_features.loss_functions import get as get_loss
+from ampligraph.latent_features.models import ScoringBasedEmbeddingModel
+from ampligraph.latent_features.optimizers import get as get_optimizer
+from ampligraph.latent_features.regularizers import get as get_regularizer
+
+BACK_COMPAT_MODELS = {}
+
+
+def register_compatibility(name):
+ def insert_in_registry(class_handle):
+ BACK_COMPAT_MODELS[name] = class_handle
+ class_handle.name = name
+ return class_handle
+
+ return insert_in_registry
+
+
+class ScoringModelBase:
+ def __init__(
+ self,
+ k=100,
+ eta=2,
+ epochs=100,
+ batches_count=100,
+ seed=0,
+ embedding_model_params={
+ "corrupt_sides": ["s,o"],
+ "negative_corruption_entities": "all",
+ "norm": 1,
+ "normalize_ent_emb": False,
+ },
+ optimizer="adam",
+ optimizer_params={"lr": 0.0005},
+ loss="nll",
+ loss_params={},
+ regularizer=None,
+ regularizer_params={},
+ initializer="xavier",
+ initializer_params={"uniform": False},
+ verbose=False,
+ model=None,
+ ):
+ """Initialize the model class.
+
+ Parameters
+ ----------
+ k : int
+ Embedding space dimensionality.
+ eta : int
+ The number of negatives that must be generated at runtime during
+ training for each positive.
+ epochs : int
+ The iterations of the training loop.
+ batches_count : int
+ The number of batches in which the training set must be split
+ during the training loop.
+ seed : int
+ The seed used by the internal random numbers generator.
+ embedding_model_params : dict
+ Model-specific hyperparams, passed to the model as a dictionary.
+ Refer to model-specific documentation for details.
+ optimizer : str
+ The optimizer used to minimize the loss function. Choose between
+ 'sgd', 'adagrad', 'adam', 'momentum'.
+ optimizer_params : dict
+ Arguments specific to the optimizer, passed as a dictionary.
+ Supported keys:
+
+ - **'lr'** (float): learning rate (used by all the optimizers).
+ Default: 0.1.
+ - **'momentum'** (float): learning momentum
+ (only used when ``optimizer=momentum``). Default: 0.9.
+
+ loss : str
+ The type of loss function to use during training.
+
+ - `"pairwise"` the model will use pairwise margin-based
+ loss function.
+ - `"nll"` the model will use negative loss likelihood.
+ - `"absolute_margin"` the model will use absolute
+ margin likelihood.
+ - `"self_adversarial"` the model will use adversarial sampling
+ loss function.
+ - `"multiclass_nll"` the model will use multiclass nll loss.
+ Switch to multiclass loss defined in \
+ :cite:`chen2015` by passing ``"corrupt_side"`` as `["s","o"]` to
+ ``embedding_model_params``. To use loss defined in\
+ :cite:`kadlecBK17` pass ``"corrupt_side"``\
+ as `"o"` to embedding_model_params.
+
+ loss_params : dict
+ Dictionary of loss-specific hyperparameters.
+ regularizer : str
+ The regularization strategy to use with the loss function.
+
+ - `None`: the model will not use any regularizer (default)
+ - `LP`: the model will use :math:`L^1, L^2` or :math:`L^3`
+ regularization based on the value of
+ ``regularizer_params['p']`` in the ``regularizer_params``.
+
+ regularizer_params : dict
+ Dictionary of regularizer-specific hyperparameters.
+ initializer : str
+ The type of initializer to use.
+
+ - `"normal"`: The embeddings will be initialized from a normal
+ distribution
+ - `"uniform"`: The embeddings will be initialized from a uniform
+ distribution
+ - `"xavier"`: The embeddings will be initialized using xavier
+ strategy (default)
+
+ initializer_params : dict
+ Dictionary of initializer-specific hyperparameters.
+ verbose : bool
+ Verbose mode.
+ """
+ if model is not None:
+ self.model_name = model.scoring_type
+ else:
+ self.k = k
+ self.eta = eta
+ self.seed = seed
+
+ self.batches_count = batches_count
+
+ self.epochs = epochs
+ self.embedding_model_params = embedding_model_params
+ self.optimizer = optimizer
+ self.optimizer_params = optimizer_params
+ self.loss = loss
+ self.loss_params = loss_params
+ self.initializer = initializer
+ self.initializer_params = initializer_params
+ self.regularizer = regularizer
+ self.regularizer_params = regularizer_params
+ self.verbose = verbose
+
+ self.model = model
+ self.is_backward = True
+
+ def _get_optimizer(self, optimizer, optim_params):
+ """Get the optimizer from tf.keras.optimizers."""
+ # Use pop() with a default so a missing "lr" key does not raise a KeyError.
+ learning_rate = optim_params.pop("lr", 0.001)
+
+ if optimizer == "adam":
+ optim = tf.keras.optimizers.Adam(
+ learning_rate=learning_rate, **optim_params
+ )
+ status = True
+ elif optimizer == "adagrad":
+ optim = tf.keras.optimizers.Adagrad(
+ learning_rate=learning_rate, **optim_params
+ )
+ status = True
+ elif optimizer == "sgd":
+ optim = tf.keras.optimizers.SGD(
+ learning_rate=learning_rate, **optim_params
+ )
+ status = True
+ else:
+ optim = get_optimizer(optimizer)
+ status = False
+
+ optim_params["lr"] = learning_rate
+ return optim, status
+
+ def is_fit(self):
+ """Flag whether the model has been fitted or not."""
+ return self.model.is_fit()
+
+ def _get_initializer(self, initializer, initializer_params):
+ """Get the initializers among tf.keras.initializers."""
+ if initializer == "xavier":
+ if initializer_params["uniform"]:
+ return tf.keras.initializers.GlorotUniform(seed=self.seed)
+ else:
+ return tf.keras.initializers.GlorotNormal(seed=self.seed)
+ elif initializer == "uniform":
+ return tf.keras.initializers.RandomUniform(
+ minval=initializer_params.get("low", -0.05),
+ maxval=initializer_params.get("high", 0.05),
+ seed=self.seed,
+ )
+ elif initializer == "normal":
+ return tf.keras.initializers.RandomNormal(
+ mean=initializer_params.get("mean", 0.0),
+ stddev=initializer_params.get("std", 0.05),
+ seed=self.seed,
+ )
+ elif initializer == "constant":
+ entity_init = initializer_params.get("entity", None)
+ rel_init = initializer_params.get("relation", None)
+ assert (
+ entity_init is not None
+ ), "Please pass the `entity` initializer value"
+ assert (
+ rel_init is not None
+ ), "Please pass the `relation` initializer value"
+ return [
+ tf.constant_initializer(entity_init),
+ tf.constant_initializer(rel_init),
+ ]
+ else:
+ return tf.keras.initializers.get(initializer)
+
+ def fit(
+ self,
+ X,
+ early_stopping=False,
+ early_stopping_params={},
+ focusE_numeric_edge_values=None,
+ tensorboard_logs_path=None,
+ callbacks=None,
+ verbose=False,
+ ):
+ """Train the model (with optional early stopping).
+
+ The model is trained on a training set ``X`` using the training
+ protocol described in :cite:`trouillon2016complex`.
+
+ Parameters
+ ----------
+ X : ndarray, shape (n, 3) or str or GraphDataLoader or AbstractGraphPartitioner
+ Data OR filename of the data file OR data handle to be used for training.
+ early_stopping: bool
+ Flag to enable early stopping (default:`False`)
+ early_stopping_params: dict
+ Dictionary of hyperparameters for the early stopping heuristics.
+ The following string keys are supported:
+
+ - **"x_valid"** (ndarray, shape (n, 3) or str or
+ GraphDataLoader or AbstractGraphPartitioner) - Numpy \
+ array of validation triples OR handle of Dataset adapter
+ which would help retrieve data.
+ - **"criteria"** (str) - Criteria for early stopping
+ `'hits10'`, `'hits3'`, `'hits1'` \
+ or `'mrr'` (default: `'mrr'`).
+ - **"x_filter"** (ndarray, shape (n, 3)) - Positive triples
+ to use as filter if a `'filtered'` early \
+ stopping criteria is desired (i.e., filtered-MRR if
+ ``'criteria':'mrr'``). Note this will affect training time
+ (no filter by default). If the filter has already been set in
+ the adapter, pass `True`.
+ - **"burn_in"** (int) - Number of epochs to pass before
+ kicking in early stopping (default: 100).
+ - **"check_interval"** (int) - Early stopping interval after
+ burn-in (default:10).
+ - **"stop_interval"** (int) - Stop if criteria is performing
+ worse over n consecutive checks (default: 3).
+ - **"corruption_entities"** (list) - List of entities to be
+ used for corruptions. If `'all'`, it uses all entities
+ (default: `'all'`).
+ - **"corrupt_side"** (str) - Specifies which side to corrupt:
+ `'s'`, `'o'`, `'s+o'`, `'s,o'` \
+ (default: `'s,o'`).
+
+ focusE_numeric_edge_values: ndarray, shape (n, 1)
+ Numeric values associated with links in the training set.
+ Semantically, the numeric value can signify importance,
+ uncertainty, significance, confidence, etc. If the numeric value
+ is unknown pass a NaN weight. The model will uniformly randomly
+ assign a numeric value. One can also think about assigning
+ numeric values by looking at the distribution of it per predicate.
+ .. warning:: In the compatible version, this option only supports
+ data passed as np.array.
+ tensorboard_logs_path: str or None
+ Path to store tensorboard logs, e.g., average training loss
+ tracking per epoch (default: `None` indicating no logs will be
+ collected). When provided it will create a folder under provided
+ path and save tensorboard files there. To then view the loss in
+ the terminal run: ``tensorboard --logdir <tensorboard_logs_path>``.
+ """
+ self.model = ScoringBasedEmbeddingModel(
+ self.eta, self.k, scoring_type=self.model_name, seed=self.seed
+ )
+ if callbacks is None:
+ callbacks = []
+ if tensorboard_logs_path is not None:
+ tensorboard_callback = tf.keras.callbacks.TensorBoard(
+ log_dir=tensorboard_logs_path
+ )
+ callbacks.append(tensorboard_callback)
+
+ regularizer = self.regularizer
+ if regularizer is not None:
+ regularizer = get_regularizer(regularizer,
+ self.regularizer_params)
+
+ initializer = self.initializer
+ if initializer is not None:
+ initializer = self._get_initializer(
+ initializer, self.initializer_params
+ )
+
+ loss = get_loss(self.loss, self.loss_params)
+ optimizer, is_back_compat_optim = self._get_optimizer(
+ self.optimizer, self.optimizer_params
+ )
+
+ self.model.compile(
+ optimizer=optimizer,
+ loss=loss,
+ entity_relation_initializer=initializer,
+ entity_relation_regularizer=regularizer,
+ )
+ if not is_back_compat_optim:
+ tf.keras.backend.set_value(
+ self.model.optimizer.learning_rate,
+ self.optimizer_params.get("lr", 0.001),
+ )
+
+ if len(early_stopping_params) != 0:
+ checkpoint = tf.keras.callbacks.EarlyStopping(
+ monitor="val_{}".format(
+ early_stopping_params.get("criteria", "mrr")
+ ),
+ min_delta=0,
+ patience=early_stopping_params.get("stop_interval", 10),
+ verbose=self.verbose,
+ mode="max",
+ restore_best_weights=True,
+ )
+ callbacks.append(checkpoint)
+
+ x_filter = early_stopping_params.get("x_filter", None)
+
+ if isinstance(x_filter, np.ndarray) or isinstance(x_filter, list):
+ x_filter = {"test": x_filter}
+ elif x_filter is None or not x_filter:
+ x_filter = False
+ elif isinstance(x_filter, dict):
+ pass
+ else:
+ raise ValueError("Incorrect type for x_filter")
+
+ focusE = False
+ params_focusE = {}
+ if focusE_numeric_edge_values is not None:
+ if isinstance(
+ focusE_numeric_edge_values, np.ndarray
+ ) and isinstance(X, np.ndarray):
+ focusE = True
+ X = np.concatenate([X, focusE_numeric_edge_values], axis=1)
+ params_focusE = {
+ "non_linearity": self.embedding_model_params.get(
+ "non_linearity", "linear"
+ ),
+ "stop_epoch": self.embedding_model_params.get(
+ "stop_epoch", 251
+ ),
+ "structural_wt": self.embedding_model_params.get(
+ "structural_wt", 0.001
+ ),
+ }
+ else:
+ msg = (
+ "Either X or focusE_numeric_edge_values is not a np.array, "
+ "so focusE is not supported. "
+ "Try using the AmpliGraph 2 or AmpliGraph 1.x APIs!"
+ )
+ raise ValueError(msg)
+
+ self.model.fit(
+ X,
+ batch_size=np.ceil(X.shape[0] / self.batches_count),
+ epochs=self.epochs,
+ validation_freq=early_stopping_params.get("check_interval", 10),
+ validation_burn_in=early_stopping_params.get("burn_in", 25),
+ validation_batch_size=early_stopping_params.get("batch_size",
+ 100),
+ validation_data=early_stopping_params.get("x_valid", None),
+ validation_filter=x_filter,
+ validation_entities_subset=early_stopping_params.get(
+ "corruption_entities", None
+ ),
+ callbacks=callbacks,
+ verbose=verbose,
+ focusE=focusE,
+ focusE_params=params_focusE,
+ )
+ self.data_shape = self.model.data_shape
+
+ def get_indexes(self, X, type_of="t", order="raw2ind"):
+ """Converts given data to indexes or to raw data (according to order).
+
+ It works for triples (``type_of='t'``), entities
+ (``type_of='e'``), and relations (``type_of='r'``).
+
+ Parameters
+ ----------
+ X: np.array
+ Data to be indexed.
+ type_of: str
+ One of `['e', 't', 'r']` to specify which type of data is to be
+ indexed or converted to raw data.
+ order: str
+ One of `['raw2ind', 'ind2raw']` to specify whether to convert raw
+ data to indexes or vice versa.
+
+ Returns
+ -------
+ Y: np.array
+ Indexed data or raw data.
+ """
+ return self.model.get_indexes(X, type_of, order)
+
+ def get_count(self, concept_type="e"):
+ """Returns the count of entities and relations that were present
+ during training.
+
+ Parameters
+ ----------
+ concept_type: str
+ Indicates whether to count entities (``concept_type='e'``) or
+ relations (``concept_type='r'``). Default: ``concept_type='e'``.
+
+ Returns
+ -------
+ count: int
+ Count of the entities or relations.
+ """
+ if concept_type == "entity" or concept_type == "e":
+ return self.model.get_count("e")
+ elif concept_type == "relation" or concept_type == "r":
+ return self.model.get_count("r")
+ else:
+ raise ValueError("Invalid value for concept_type!")
+
+ def get_embeddings(self, entities, embedding_type="entity"):
+ """Get the embeddings of entities or relations.
+
+ .. Note ::
+
+ Use :meth:`ampligraph.utils.create_tensorboard_visualizations` to
+ visualize the embeddings with TensorBoard.
+
+ Parameters
+ ----------
+ entities : array-like, dtype=int, shape=[n]
+ The entities (or relations) of interest. Elements of the vector
+ must be the original string literals, and
+ not internal IDs.
+ embedding_type : str
+ If `'e'` or `'entities'`, ``entities`` argument will be
+ considered as a list of knowledge graph entities (i.e. nodes).
+ If set to `'r'` or `'relation'`, they will be treated as relation
+ types instead (i.e. predicates).
+
+ Returns
+ -------
+ embeddings : ndarray, shape [n, k]
+ An array of k-dimensional embeddings.
+ """
+ if embedding_type == "entity" or embedding_type == "e":
+ return self.model.get_embeddings(entities, "e")
+ elif embedding_type == "relation" or embedding_type == "r":
+ return self.model.get_embeddings(entities, "r")
+ else:
+ raise ValueError("Invalid value for embedding_type!")
+
+ def get_hyperparameter_dict(self):
+ """Returns hyperparameters of the model.
+
+ Returns
+ -------
+ hyperparam_dict : dict
+ Dictionary of hyperparameters that were used for training.
+ """
+ ent_idx = np.arange(self.model.data_indexer.get_entities_count())
+ rel_idx = np.arange(self.model.data_indexer.get_relations_count())
+ ent_values_raw = self.model.data_indexer.get_indexes(
+ ent_idx, "e", "ind2raw"
+ )
+ rel_values_raw = self.model.data_indexer.get_indexes(
+ rel_idx, "r", "ind2raw"
+ )
+ return dict(zip(ent_values_raw, ent_idx)), dict(
+ zip(rel_values_raw, rel_idx)
+ )
+
+ def predict(self, X):
+ """
+ Predict the scores of triples using a trained embedding model.
+
+ The function returns raw scores generated by the model.
+
+ Parameters
+ ----------
+ X : ndarray, shape (n, 3)
+ The triples to score.
+
+ Returns
+ -------
+ scores_predict : ndarray, shape (n)
+ The predicted scores for input triples.
+ """
+ return self.model.predict(X)
+
+ def calibrate(
+ self,
+ X_pos,
+ X_neg=None,
+ positive_base_rate=None,
+ batches_count=100,
+ epochs=50,
+ ):
+ """Calibrate predictions.
+
+ The method implements the heuristics described in :cite:`calibration`,
+ using Platt scaling :cite:`platt1999probabilistic`.
+
+ The calibrated predictions can be obtained with :meth:`predict_proba`
+ after calibration is done.
+
+ Parameters
+ ----------
+ X_pos : ndarray, shape (n, 3)
+ Numpy array of positive triples.
+ X_neg : ndarray, shape (n, 3)
+ Numpy array of negative triples.
+ If `None`, the negative triples are generated via corruptions
+ and the user must provide a positive base rate instead.
+ positive_base_rate: float
+ Base rate of positive statements.
+ For example, if we assume there is a fifty-fifty chance of any
+ query to be true, the base rate would be 50%. If ``X_neg`` is
+ provided and this is `None`, the relative sizes of ``X_pos``
+ and ``X_neg`` will be used to determine the base rate.
+ For example, if we have 50 positive triples and 200 negative
+ triples, the positive base rate will be assumed to be
+ 50/(50+200) = 1/5 = 0.2. This must be a value between 0 and 1.
+ batches_count: int
+ Number of batches to complete one epoch of the Platt
+ scaling training. Only applies when ``X_neg`` is `None`.
+ epochs: int
+ Number of epochs used to train the Platt scaling model.
+ Only applies when ``X_neg`` is `None`.
+
+ """
+ batch_size = int(np.ceil(X_pos.shape[0] / batches_count))
+ return self.model.calibrate(
+ X_pos, X_neg, positive_base_rate, batch_size, epochs
+ )
+
+ def predict_proba(self, X):
+ """
+ Predicts probabilities using the Platt scaling model
+ (after calibration).
+
+ Model must be calibrated beforehand with the ``calibrate`` method.
+
+ Parameters
+ ----------
+ X: ndarray, shape (n, 3)
+ Numpy array of triples to be evaluated.
+
+ Returns
+ -------
+ scores: np.array, shape (n, )
+ Calibrated scores for the input triples.
+ """
+ return self.model.predict_proba(X)
+
+ def evaluate(
+ self,
+ x=None,
+ batch_size=32,
+ verbose=True,
+ use_filter=False,
+ corrupt_side="s,o",
+ entities_subset=None,
+ callbacks=None,
+ ):
+ """
+ Evaluate the inputs against corruptions and return ranks.
+
+ Parameters
+ ----------
+ x: np.array, shape (n, 3) or str or GraphDataLoader or AbstractGraphPartitioner
+ Data OR filename of the data file OR data handle to be used for evaluation.
+ batch_size: int
+ Batch size to use during evaluation.
+ May be overridden if ``x`` is a `GraphDataLoader` or
+ `AbstractGraphPartitioner` instance.
+ verbose: bool
+ Verbosity mode.
+ use_filter: bool or dict
+ Whether to use a filter or not. If a dictionary is specified, the
+ data in the dict is concatenated and used as filter.
+ corrupt_side: str
+ Which side of the triple to corrupt. It can be the
+ subject (``corrupt_side="s"``), the object (``corrupt_side="o"``),
+ or both the subject and the object (``corrupt_side="s+o"`` or
+ ``corrupt_side="s,o"``) (default: `"s,o"`).
+ entities_subset: list or np.array
+ Subset of entities to be used for generating corruptions.
+ callbacks: list of keras.callbacks.Callback instances
+ List of callbacks to apply during evaluation.
+
+ Returns
+ -------
+ rank: np.array, shape (n, number of corrupted sides)
+ Ranking of test triples against subject corruptions and/or
+ object corruptions.
+ """
+
+ return self.model.evaluate(
+ x,
+ batch_size=batch_size,
+ verbose=verbose,
+ use_filter=use_filter,
+ corrupt_side=corrupt_side,
+ entities_subset=entities_subset,
+ callbacks=callbacks,
+ )
+
+
+@register_compatibility("TransE")
+class TransE(ScoringModelBase):
+ """Class wrapping around the ScoringBasedEmbeddingModel with the TransE
+ scoring function."""
+
+ def __init__(
+ self,
+ k=100,
+ eta=2,
+ epochs=100,
+ batches_count=100,
+ seed=0,
+ embedding_model_params={
+ "corrupt_sides": ["s,o"],
+ "negative_corruption_entities": "all",
+ "norm": 1,
+ "normalize_ent_emb": False,
+ },
+ optimizer="adam",
+ optimizer_params={"lr": 0.0005},
+ loss="nll",
+ loss_params={},
+ regularizer=None,
+ regularizer_params={},
+ initializer="xavier",
+ initializer_params={"uniform": False},
+ verbose=False,
+ model=None,
+ ):
+ """Initialize the ScoringBasedEmbeddingModel with the TransE
+ scoring function."""
+ super().__init__(
+ k,
+ eta,
+ epochs,
+ batches_count,
+ seed,
+ embedding_model_params,
+ optimizer,
+ optimizer_params,
+ loss,
+ loss_params,
+ regularizer,
+ regularizer_params,
+ initializer,
+ initializer_params,
+ verbose,
+ model,
+ )
+
+ self.model_name = "TransE"
+
+
+@register_compatibility("DistMult")
+class DistMult(ScoringModelBase):
+ """Class wrapping around the ScoringBasedEmbeddingModel with the DistMult
+ scoring function."""
+
+ def __init__(
+ self,
+ k=100,
+ eta=2,
+ epochs=100,
+ batches_count=100,
+ seed=0,
+ embedding_model_params={
+ "corrupt_sides": ["s,o"],
+ "negative_corruption_entities": "all",
+ "norm": 1,
+ "normalize_ent_emb": False,
+ },
+ optimizer="adam",
+ optimizer_params={"lr": 0.0005},
+ loss="nll",
+ loss_params={},
+ regularizer=None,
+ regularizer_params={},
+ initializer="xavier",
+ initializer_params={"uniform": False},
+ verbose=False,
+ model=None,
+ ):
+ """Initialize the ScoringBasedEmbeddingModel with the DistMult
+ scoring function."""
+ super().__init__(
+ k,
+ eta,
+ epochs,
+ batches_count,
+ seed,
+ embedding_model_params,
+ optimizer,
+ optimizer_params,
+ loss,
+ loss_params,
+ regularizer,
+ regularizer_params,
+ initializer,
+ initializer_params,
+ verbose,
+ model,
+ )
+
+ self.model_name = "DistMult"
+
+
+@register_compatibility("ComplEx")
+class ComplEx(ScoringModelBase):
+ """Class wrapping around the ScoringBasedEmbeddingModel with the ComplEx
+ scoring function."""
+
+ def __init__(
+ self,
+ k=100,
+ eta=2,
+ epochs=100,
+ batches_count=100,
+ seed=0,
+ embedding_model_params={
+ "corrupt_sides": ["s,o"],
+ "negative_corruption_entities": "all",
+ "norm": 1,
+ "normalize_ent_emb": False,
+ },
+ optimizer="adam",
+ optimizer_params={"lr": 0.0005},
+ loss="nll",
+ loss_params={},
+ regularizer=None,
+ regularizer_params={},
+ initializer="xavier",
+ initializer_params={"uniform": False},
+ verbose=False,
+ model=None,
+ ):
+ """Initialize the ScoringBasedEmbeddingModel with the ComplEx
+ scoring function."""
+ super().__init__(
+ k,
+ eta,
+ epochs,
+ batches_count,
+ seed,
+ embedding_model_params,
+ optimizer,
+ optimizer_params,
+ loss,
+ loss_params,
+ regularizer,
+ regularizer_params,
+ initializer,
+ initializer_params,
+ verbose,
+ model,
+ )
+
+ self.model_name = "ComplEx"
+
+
+@register_compatibility("HolE")
+class HolE(ScoringModelBase):
+ """Class wrapping around the ScoringBasedEmbeddingModel with the HolE
+ scoring function."""
+
+ def __init__(
+ self,
+ k=100,
+ eta=2,
+ epochs=100,
+ batches_count=100,
+ seed=0,
+ embedding_model_params={
+ "corrupt_sides": ["s,o"],
+ "negative_corruption_entities": "all",
+ "norm": 1,
+ "normalize_ent_emb": False,
+ },
+ optimizer="adam",
+ optimizer_params={"lr": 0.0005},
+ loss="nll",
+ loss_params={},
+ regularizer=None,
+ regularizer_params={},
+ initializer="xavier",
+ initializer_params={"uniform": False},
+ verbose=False,
+ model=None,
+ ):
+ """Initialize the ScoringBasedEmbeddingModel with the HolE
+ scoring function."""
+ super().__init__(
+ k,
+ eta,
+ epochs,
+ batches_count,
+ seed,
+ embedding_model_params,
+ optimizer,
+ optimizer_params,
+ loss,
+ loss_params,
+ regularizer,
+ regularizer_params,
+ initializer,
+ initializer_params,
+ verbose,
+ model,
+ )
+
+ self.model_name = "HolE"
diff --git a/ampligraph/datasets/__init__.py b/ampligraph/datasets/__init__.py
index 141f9714..fca586ca 100644
--- a/ampligraph/datasets/__init__.py
+++ b/ampligraph/datasets/__init__.py
@@ -1,21 +1,68 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-"""Helper functions to load knowledge graphs."""
-
-from .datasets import load_from_csv, load_from_rdf, load_fb15k, load_wn18, load_fb15k_237, load_from_ntriples, \
- load_yago3_10, load_wn18rr, load_wn11, load_fb13, load_onet20k, load_ppi5k, load_nl27k, load_cn15k
-
-from .abstract_dataset_adapter import AmpligraphDatasetAdapter
+"""Support for loading and managing datasets."""
+from .datasets import (
+ _load_xai_fb15k_237_experiment_log,
+ load_cn15k,
+ load_codex,
+ load_fb13,
+ load_fb15k,
+ load_fb15k_237,
+ load_from_csv,
+ load_from_ntriples,
+ load_from_rdf,
+ load_nl27k,
+ load_onet20k,
+ load_ppi5k,
+ load_wn11,
+ load_wn18,
+ load_wn18rr,
+ load_yago3_10,
+)
+from .graph_data_loader import DataIndexer, GraphDataLoader, NoBackend
+from .graph_partitioner import PARTITION_ALGO_REGISTRY, BucketGraphPartitioner
+from .source_identifier import (
+ DataSourceIdentifier,
+ chunks,
+ load_csv,
+ load_gz,
+ load_json,
+ load_tar,
+)
from .sqlite_adapter import SQLiteAdapter
-from .numpy_adapter import NumpyDatasetAdapter
-from .oneton_adapter import OneToNDatasetAdapter
-__all__ = ['load_from_csv', 'load_from_rdf', 'load_from_ntriples', 'load_wn18', 'load_fb15k',
- 'load_fb15k_237', 'load_yago3_10', 'load_wn18rr', 'load_wn11', 'load_fb13',
- 'load_onet20k', 'load_ppi5k', 'load_nl27k', 'load_cn15k',
- 'AmpligraphDatasetAdapter', 'NumpyDatasetAdapter', 'SQLiteAdapter', 'OneToNDatasetAdapter']
+__all__ = [
+ "load_from_csv",
+ "load_from_rdf",
+ "load_wn18",
+ "load_fb15k",
+ "load_fb15k_237",
+ "load_from_ntriples",
+ "load_yago3_10",
+ "load_wn18rr",
+ "load_wn11",
+ "load_fb13",
+ "load_onet20k",
+ "load_ppi5k",
+ "load_nl27k",
+ "load_cn15k",
+ "load_codex",
+ "chunks",
+ "load_json",
+ "load_gz",
+ "load_tar",
+ "load_csv",
+ "DataSourceIdentifier",
+ "DataIndexer",
+ "NoBackend",
+ "_load_xai_fb15k_237_experiment_log",
+ "SQLiteAdapter",
+ "GraphDataLoader",
+ "BucketGraphPartitioner",
+ "PARTITION_ALGO_REGISTRY",
+]
diff --git a/ampligraph/datasets/abstract_dataset_adapter.py b/ampligraph/datasets/abstract_dataset_adapter.py
deleted file mode 100644
index f6d10c14..00000000
--- a/ampligraph/datasets/abstract_dataset_adapter.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-import abc
-
-
-class AmpligraphDatasetAdapter(abc.ABC):
- """Abstract class for dataset adapters
- Developers can design in similar format to adapt data from different sources to feed to ampligraph.
- """
- def __init__(self):
- """Initialize the class variables
- """
- self.dataset = {}
-
- # relation to idx mappings
- self.rel_to_idx = {}
- # entities to idx mappings
- self.ent_to_idx = {}
- # Mapped status of each dataset
- self.mapped_status = {}
- # link weights for focusE
- self.focusE_numeric_edge_values = {}
-
- def use_mappings(self, rel_to_idx, ent_to_idx):
- """Use an existing mapping with the datasource.
- """
- self.rel_to_idx = rel_to_idx
- self.ent_to_idx = ent_to_idx
- # set the mapped status to false, since we are changing the dictionary
- for key in self.dataset.keys():
- self.mapped_status[key] = False
-
- def generate_mappings(self, use_all=False):
- """Generate mappings from either train set or use all dataset to generate mappings
- Parameters
- ----------
- use_all : boolean
- If True, it generates mapping from all the data. If False, it only uses training set to generate mappings
-
- Returns
- -------
- rel_to_idx : dictionary
- Relation to idx mapping dictionary
- ent_to_idx : dictionary
- entity to idx mapping dictionary
- """
- raise NotImplementedError('Abstract Method not implemented!')
-
- def get_size(self, dataset_type="train"):
- """Returns the size of the specified dataset
- Parameters
- ----------
- dataset_type : string
- type of the dataset
-
- Returns
- -------
- size : int
- size of the specified dataset
- """
-
- raise NotImplementedError('Abstract Method not implemented!')
-
- def data_exists(self, dataset_type="train"):
- """Checks if a dataset_type exists in the adapter.
- Parameters
- ----------
- dataset_type : string
- type of the dataset
-
- Returns
- -------
- exists : bool
- Boolean indicating if dataset_type exists in the adapter.
- """
-
- raise NotImplementedError('Abstract Method not implemented!')
-
- def set_data(self, dataset, dataset_type=None, mapped_status=False):
- """set the dataset based on the type
- Parameters
- ----------
- dataset : nd-array or dictionary
- dataset of triples
- dataset_type : string
- if the dataset parameter is an nd- array then this indicates the type of the data being based
- mapped_status : bool
- indicates whether the data has already been mapped to the indices
-
- """
- raise NotImplementedError('Abstract Method not implemented!')
-
- def map_data(self, remap=False):
- """map the data to the mappings of ent_to_idx and rel_to_idx
- Parameters
- ----------
- remap : boolean
- remap the data, if already mapped. One would do this if the dictionary is updated.
- """
- raise NotImplementedError('Abstract Method not implemented!')
-
- def set_filter(self, filter_triples):
- """set's the filter that need to be used while generating evaluation batch
- Parameters
- ----------
- filter_triples : nd-array
- triples that would be used as filter
- """
- raise NotImplementedError('Abstract Method not implemented!')
-
- def get_next_batch(self, batches_count=-1, dataset_type="train", use_filter=False):
- """Generator that returns the next batch of data.
-
- Parameters
- ----------
- dataset_type: string
- indicates which dataset to use
- batches_count: int
- number of batches per epoch (default: -1, i.e. uses batch_size of 1)
- use_filter : bool
- Flag to indicate whether to return the concepts that need to be filtered
-
- Returns
- -------
- batch_output : nd-array
- yields a batch of triples from the dataset type specified
- participating_objects : nd-array [n,1]
- all objects that were involved in the s-p-? relation. This is returned only if use_filter is set to true.
- participating_subjects : nd-array [n,1]
- all subjects that were involved in the ?-p-o relation. This is returned only if use_filter is set to true.
- """
- raise NotImplementedError('Abstract Method not implemented!')
-
- def cleanup(self):
- """Cleans up the internal state
- """
- raise NotImplementedError('Abstract Method not implemented!')
diff --git a/ampligraph/datasets/data_adapter.py b/ampligraph/datasets/data_adapter.py
new file mode 100644
index 00000000..04332b9c
--- /dev/null
+++ b/ampligraph/datasets/data_adapter.py
@@ -0,0 +1,144 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+import contextlib
+
+import tqdm
+from tensorflow.python.framework import errors
+
+from .graph_data_loader import GraphDataLoader, NoBackend
+from .graph_partitioner import AbstractGraphPartitioner
+from .partitioned_data_manager import get_partition_adapter
+
+
+class DataHandler:
+ def __init__(
+ self,
+ x,
+ model=None,
+ batch_size=1,
+ dataset_type="train",
+ epochs=1,
+ initial_epoch=0,
+ use_indexer=True,
+ use_filter=True,
+ partitioning_k=1,
+ ):
+ """Initializes the DataHandler
+
+ Parameters
+ ----------
+ x: np.array, shape (n, 3) or str or GraphDataLoader or AbstractGraphPartitioner
+ Data OR filename of the data file OR data handle to be used.
+ model: tf.keras.Model
+ Model instance.
+ batch_size: int
+ Batch size to use during training.
+ May be overridden if `x` is ``GraphDataLoader`` or ``AbstractGraphPartitioner`` instance.
+ dataset_type: string
+ Dataset type that is being passed.
+ epochs: int
+ Number of epochs to train (default: 1)
+ initial_epoch: int
+ Initial epoch number (default: 0).
+ use_indexer: bool or Mapper
+ Whether the data needs to be indexed, or whether a pre-defined indexer should be used to map
+ the data to indexes.
+ use_filter: bool or dict
+ Whether to use a filter or not. If a dictionary is specified, the data in the dict is concatenated
+ and used as filter.
+ partitioning_k: int
+ Number of partitions to create.
+ May be overridden if `x` is an ``AbstractGraphPartitioner`` instance
+ """
+ self._initial_epoch = initial_epoch
+ self._epochs = epochs
+ self._model = model
+ self._inferred_steps = None
+ self.using_partitioning = False
+
+ if partitioning_k <= 0:
+ raise ValueError("Incorrect value specified to partitioning_k")
+
+ if isinstance(x, GraphDataLoader):
+ self._adapter = x
+ self._parent_adapter = self._adapter
+ elif isinstance(x, AbstractGraphPartitioner):
+ self._parent_adapter = x._data
+ self._adapter = x
+ self.using_partitioning = True
+ # override the partitioning_k value using partitioners k
+ partitioning_k = x._k
+ else:
+ # use graph data loader by default
+ self._adapter = GraphDataLoader(
+ x,
+ backend=NoBackend,
+ batch_size=batch_size,
+ dataset_type=dataset_type,
+ use_indexer=use_indexer,
+ use_filter=use_filter,
+ in_memory=True,
+ )
+ self._parent_adapter = self._adapter
+ if partitioning_k > 1:
+ # if use partitioning then pass the graph data loader to partitioner and use
+ # partitioned data manager
+ assert (
+ model is not None
+ ), "Please pass the model to data_handler for partitioning!"
+ self._adapter = get_partition_adapter(
+ self._adapter,
+ self._model,
+ strategy="Bucket",
+ partitioning_k=partitioning_k,
+ )
+
+ self.using_partitioning = True
+
+ @contextlib.contextmanager
+ def catch_stop_iteration(self):
+ """Catches errors when an iterator runs out of data."""
+ try:
+ yield
+ except (StopIteration, errors.OutOfRangeError):
+ if self._inferred_steps is None:
+ self._inferred_steps = self._current_iter
+
+ def steps(self):
+ """Counts the number of steps in an epoch."""
+ self._current_iter = 0
+ while (
+ self._inferred_steps is None
+ or self._current_iter < self._inferred_steps
+ ):
+ self._current_iter += 1
+ yield self._current_iter
+
+ @property
+ def inferred_steps(self):
+ """Returns the number of steps in the batch."""
+ return self._inferred_steps
+
+ def enumerate_epochs(self, use_tqdm=False):
+ """Manages the (reloading) data adapter before epoch starts."""
+ for epoch in tqdm.tqdm(
+ range(self._initial_epoch, self._epochs), disable=not use_tqdm
+ ):
+ self._adapter.reload()
+ yield epoch, iter(self._adapter.get_tf_generator())
+ self._adapter.on_epoch_end()
+
+ self._adapter.on_complete()
+
+ def get_mapper(self):
+ """Returns the mapper of the main data loader class."""
+ return self._parent_adapter.backend.mapper
+
+ def get_update_partitioner_metadata(self, filepath):
+ out_dict = {}
+ if self.using_partitioning:
+ out_dict = self._adapter.get_update_metadata(filepath)
+ return out_dict
diff --git a/ampligraph/datasets/data_indexer.py b/ampligraph/datasets/data_indexer.py
new file mode 100644
index 00000000..7425d2b6
--- /dev/null
+++ b/ampligraph/datasets/data_indexer.py
@@ -0,0 +1,1716 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+"""Data indexer.
+
+This module provides a class that maps raw data to indexes and the other way around.
+It can be persisted and contains supporting functions.
+
+Example
+-------
+ >>> data = np.array([['/m/01', '/relation1', '/m/02'],
+ >>>                  ['/m/01', '/relation2', '/m/07']])
+ >>> mapper = DataIndexer(data)
+ >>> mapper.get_indexes(data)
+
+.. It extends functionality of to_idx(...) from AmpliGraph 1:
+ https://docs.ampligraph.org/en/1.3.1/generated/ampligraph.evaluation.to_idx.html?highlight=to_idx
+
+"""
+import logging
+import os
+import shelve
+import shutil
+import sqlite3
+import tempfile
+import uuid
+from datetime import datetime
+
+import numpy as np
+import pandas as pd
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+INDEXER_BACKEND_REGISTRY = {}
+
+
+class DataIndexer:
+ """Index graph unique entities and relations.
+
+ Abstract class with a unified API for the different indexer implementations (`in-memory`, `shelves`, `sqlite`).
+
+ It can support large datasets in two modalities:
+ - using dictionaries for in-memory storage
+ - using persistent dictionary storage (python shelves, sqlite), for dumping huge indexes.
+
+ Methods:
+ - create_mappings - core function that creates mappings.
+ - get_indexes - given an array of triples, returns it in an indexed form,
+ or given indexes, it returns the original triples (subject to parameters).
+ - update_mappings [NotYetImplemented] - update mappings from new data.
+
+ Properties:
+ - data - data to be indexed, either a numpy array or a generator.
+
+ Example
+ -------
+ >>> from ampligraph.datasets.data_indexer import DataIndexer
+ >>> import numpy as np
+ >>> # In-memory mapping
+ >>> data = np.array([['a','b','c'],['c','b','d'],['d','e','f']])
+ >>> mapper = DataIndexer(data, backend='in_memory')
+ >>> print(mapper.get_indexes(data))
+ [[0 0 1]
+ [1 0 2]
+ [2 1 3]]
+ >>> # Persistent mapping
+ >>> data = np.array([['a','b','c'],['c','b','d'],['d','e','f']])
+ >>> mapper = DataIndexer(data, backend='sqlite')
+ >>> print(mapper.get_indexes(data))
+ [[0 0 1]
+ [1 0 2]
+ [2 1 3]]
+ """
+
+ def __init__(self, X, backend="in_memory", **kwargs):
+ """Initialises the DataIndexer."""
+ logger.debug("Initialisation of DataIndexer.")
+ self.backend_type = backend
+ self.data = X
+ # if len(kwargs) == 0:
+ # self.backend = INDEXER_BACKEND_REGISTRY.get(backend)(X)
+ # else:
+ self.backend = INDEXER_BACKEND_REGISTRY.get(backend)(X, **kwargs)
+ if not self.backend.mapped:
+ self.backend.create_mappings()
+ self.metadata = self.backend.metadata
+
+ def update_mappings(self, X):
+ """Update existing mappings with new data."""
+ self.backend.update_mappings(X)
+
+    def get_update_metadata(self, new_file_name=None):
+        """Return backend-specific metadata (mappings or file locations) needed to restore this indexer."""
+        metadata = self.backend.get_update_metadata(new_file_name)
+        metadata["backend"] = self.backend_type
+        return metadata
+
+ def get_indexes(self, X, type_of="t", order="raw2ind"):
+ """Converts raw data to an indexed form or vice versa according to previously created mappings.
+
+ Parameters
+ ----------
+ X: array
+ Array with raw or indexed data.
+ type_of: str
+ Type of provided sample to be specified as one of the following values: `{"t", "e", "r"}`.
+ It indicates whether the provided sample is an array of triples (`"t"`), a list of entities (`"e"`)
+ or a list of relations (`"r"`).
+ order: str
+ It specifies whether it converts raw data to indexes (``order="raw2ind"``) or indexes to raw
+ data (``order="ind2raw"``)
+
+ Returns
+ -------
+ Y: array
+ Array of the same size as `sample` but with indexes of elements instead of raw data or raw data instead
+ of indexes.
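+
+        Example
+        -------
+        A minimal sketch (illustrative only; the exact index values depend on the order in which triples
+        were indexed):
+
+        >>> import numpy as np
+        >>> from ampligraph.datasets.data_indexer import DataIndexer
+        >>> X = np.array([['a', 'p', 'b'], ['b', 'q', 'c']])
+        >>> mapper = DataIndexer(X, backend="in_memory")
+        >>> idx = mapper.get_indexes(X, type_of="t", order="raw2ind")
+        >>> raw = mapper.get_indexes(idx, type_of="t", order="ind2raw")  # round-trips back to the raw triples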
+ """
+ return self.backend.get_indexes(X, type_of=type_of, order=order)
+
+ def get_relations_count(self):
+ """Get number of unique relations."""
+ return self.backend.get_relations_count()
+
+ def get_entities_count(self):
+ """Get number of unique entities."""
+ return self.backend.get_entities_count()
+
+ def clean(self):
+ """Remove persisted and in-memory objects."""
+ return self.backend.clean()
+
+ def get_entities_in_batches(self, batch_size=-1, random=False, seed=None):
+        """Generator that retrieves entities and returns them in batches.
+
+ Parameters
+ ----------
+ batch_size: int
+ Size of array that the batch should have, :math:`-1` when the whole dataset is required.
+ random: bool
+ Whether to return elements of batch in a random order (default: `False`).
+ seed: int
+ Used with ``random=True``, seed for repeatability of experiments.
+
+ Yields
+ ------
+        batch: numpy array
+            Batch of entity indexes of shape (batch_size,).
+
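+        Example
+        -------
+        A minimal sketch (illustrative only; index values depend on the indexed data):
+
+        >>> import numpy as np
+        >>> from ampligraph.datasets.data_indexer import DataIndexer
+        >>> X = np.array([['a', 'p', 'b'], ['b', 'p', 'c'], ['c', 'p', 'd']])
+        >>> mapper = DataIndexer(X, backend="in_memory")
+        >>> [batch for batch in mapper.get_entities_in_batches(batch_size=2)]
+        [array([0, 1]), array([2, 3])]
+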
+ """
+ ents_len = self.get_entities_count()
+ if batch_size == -1:
+ batch_size = ents_len
+ entities = list(range(0, ents_len, batch_size))
+ for start_index in entities:
+ if start_index + batch_size >= ents_len:
+ batch_size = ents_len - start_index
+ ents = list(range(start_index, start_index + batch_size))
+ if random:
+ np.random.seed(seed)
+ np.random.shuffle(ents)
+ yield np.array(ents)
+
+
+def register_indexer_backend(name):
+    """Decorator responsible for registering an indexer backend in the indexer backend registry.
+
+ Parameters
+ ----------
+ name: str
+ Name of the new backend.
+
+ Example
+ -------
+    >>> @register_indexer_backend("NewBackendName")
+    ... class NewBackend():
+    ...     pass
+ """
+
+ def insert_in_registry(class_handle):
+ """Checks if backend already exists and if not registers it."""
+ if name in INDEXER_BACKEND_REGISTRY.keys():
+ msg = "Indexer backend with name {} already exists!".format(name)
+ logger.error(msg)
+ raise Exception(msg)
+
+ INDEXER_BACKEND_REGISTRY[name] = class_handle
+ class_handle.name = name
+
+ return class_handle
+
+ return insert_in_registry
+
+
+@register_indexer_backend("in_memory")
+class InMemory:
+ def __init__(
+ self,
+ data,
+ entities_dict=None,
+ reversed_entities_dict=None,
+ relations_dict=None,
+ reversed_relations_dict=None,
+ root_directory=tempfile.gettempdir(),
+ name="main_partition",
+ **kwargs
+ ):
+ """Initialise backend by creating mappings.
+
+ Parameters
+ ----------
+ data: array
+ Data to be indexed.
+ entities_dict: dict or shelve path
+ Dictionary or shelve path storing entities mappings; if not provided, it is created from data.
+ reversed_entities_dict: dictionary or shelve path
+ Dictionary or shelve path storing reversed entities mappings; if not provided, it is created from data.
+ relations_dict: dictionary or shelve path
+ Dictionary or shelve path storing relations mappings; if not provided, it is created from data.
+ reversed_relations_dict: dictionary or shelve path
+ Dictionary or shelve path storing reversed relations mappings; if not provided, it is created from data.
+ root_directory: str
+ Path of the directory where to store persistent mappings.
+ """
+ self.data = data
+ self.mapped = False
+ self.metadata = {}
+ # ent to idx dict
+ self.entities_dict = entities_dict
+ self.reversed_entities_dict = reversed_entities_dict
+ # rel to idx dict
+ self.relations_dict = relations_dict
+ self.reversed_relations_dict = reversed_relations_dict
+
+ self.root_directory = root_directory
+ self.name = name
+
+ self.max_ents_index = -1
+ self.max_rels_index = -1
+ self.ents_length = 0
+ self.rev_ents_length = 0
+ self.rels_length = 0
+ self.rev_rels_length = 0
+
+ def get_all_entities(self):
+        """Returns all the (raw) entities in the dataset."""
+ return list(self.entities_dict.values())
+
+ def get_all_relations(self):
+        """Returns all the (raw) relations in the dataset."""
+ return list(self.relations_dict.values())
+
+ def create_mappings(self):
+ """Create mappings of data into indexes.
+
+        It creates four dictionaries: two with the unique entities/relations as keys and indexes as values,
+        and their reversed versions. Dispatches to the adequate function depending on whether the data is
+        provided as a single array or in chunks.
+ """
+
+ if (
+ isinstance(self.entities_dict, dict)
+ and isinstance(self.reversed_entities_dict, dict)
+ and isinstance(self.relations_dict, dict)
+ and isinstance(self.reversed_relations_dict, dict)
+ ):
+ self._update_properties()
+ logger.debug(
+                "The mappings were initialised from in-memory dictionaries."
+ )
+ elif (
+ self.entities_dict is None
+ and self.reversed_entities_dict is None
+ and self.relations_dict is None
+ and self.reversed_relations_dict is None
+ ):
+ logger.debug(
+ "The mappings will be created for data in {}.".format(
+ self.name
+ )
+ )
+
+ if isinstance(self.data, np.ndarray):
+ self.update_dictionary_mappings()
+ else:
+ self.update_dictionary_mappings_in_chunks()
+ else:
+ logger.debug(
+                "Provided initialization objects are not supported. Cannot initialise mappings."
+ )
+ self.mapped = True
+
+ def _update_properties(self):
+ """Initialise properties from the in-memory dictionary."""
+ self.max_ents_index = self._get_max_ents_index()
+ self.max_rels_index = self._get_max_rels_index()
+
+ self.ents_length = len(self.entities_dict)
+ self.rev_ents_length = len(self.reversed_entities_dict)
+ self.rels_length = len(self.relations_dict)
+ self.rev_rels_length = len(self.reversed_relations_dict)
+
+ def _get_max_ents_index(self):
+ """Get maximum index from entities dictionary."""
+ return max(self.reversed_entities_dict.values())
+
+ def _get_max_rels_index(self):
+ """Get maximum index from relations dictionary."""
+ return max(self.reversed_relations_dict.values())
+
+ def _get_starting_index_ents(self):
+ """Returns next index to continue adding elements to entities dictionary."""
+ if not self.entities_dict:
+ self.entities_dict = {}
+ self.reversed_entities_dict = {}
+ return 0
+ else:
+ return self.max_ents_index + 1
+
+ def _get_starting_index_rels(self):
+ """Returns next index to continue adding elements to relations dictionary."""
+ if not self.relations_dict:
+ self.relations_dict = {}
+ self.reversed_relations_dict = {}
+ return 0
+ else:
+ return self.max_rels_index + 1
+
+ def update_mappings(self, new_data):
+ """Update existing mappings with new data."""
+ self.update_dictionary_mappings(new_data)
+
+    def get_update_metadata(self, new_file_name=None):
+        """Return the in-memory mapping dictionaries needed to restore this indexer."""
+        metadata = {
+ "entities_dict": self.entities_dict,
+ "reversed_entities_dict": self.reversed_entities_dict,
+ "relations_dict": self.relations_dict,
+ "reversed_relations_dict": self.reversed_relations_dict,
+ }
+ return metadata
+
+ def update_dictionary_mappings(self, sample=None):
+ """Index entities and relations.
+
+        Creates in-memory dictionaries mapping entities and relations to indexes and the reverse mappings.
+ Remember to use mappings for entities with entities and relations with relations!
+ """
+ if sample is None:
+ sample = self.data
+ # logger.debug(sample)
+ i = self._get_starting_index_ents()
+ j = self._get_starting_index_rels()
+
+ for d in sample:
+ if d[0] not in self.reversed_entities_dict:
+ self.reversed_entities_dict[d[0]] = i
+ self.entities_dict[i] = d[0]
+ i += 1
+ if d[2] not in self.reversed_entities_dict:
+ self.reversed_entities_dict[d[2]] = i
+ self.entities_dict[i] = d[2]
+ i += 1
+ if d[1] not in self.reversed_relations_dict:
+ self.reversed_relations_dict[d[1]] = j
+ self.relations_dict[j] = d[1]
+ j += 1
+
+ self.max_ents_index = i - 1
+ self.max_rels_index = j - 1
+
+ self.ents_length = len(self.entities_dict)
+ self.rev_ents_length = len(self.reversed_entities_dict)
+ self.rels_length = len(self.relations_dict)
+ self.rev_rels_length = len(self.reversed_relations_dict)
+
+ if self.rev_ents_length != self.ents_length:
+ msg = "Reversed entities index size not equal to index size ({} and {})".format(
+ self.rev_ents_length, self.ents_length
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ if self.rev_rels_length != self.rels_length:
+ msg = "Reversed relations index size not equal to index size ({} and {})".format(
+ self.rev_rels_length, self.rels_length
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ logger.debug(
+ "Mappings updated with: {} ents, {} rev_ents, {} rels and {} rev_rels".format(
+ self.ents_length,
+ self.rev_ents_length,
+ self.rels_length,
+ self.rev_rels_length,
+ )
+ )
+
+ def update_dictionary_mappings_in_chunks(self):
+ """Update dictionary mappings chunk by chunk."""
+ for chunk in self.data:
+ if isinstance(chunk, np.ndarray):
+ self.update_dictionary_mappings(chunk)
+ else:
+ self.update_dictionary_mappings(chunk.values)
+
+ def get_indexes(self, sample=None, type_of="t", order="raw2ind"):
+ """Converts raw data to an indexed form or vice versa according to previously created mappings.
+
+ Parameters
+ ----------
+ sample: array
+ Array with raw or indexed data.
+ type_of: str
+ Type of provided sample to be specified as one of the following values: `{"t", "e", "r"}`.
+ It indicates whether the provided sample is an array of triples (`"t"`), a list of entities (`"e"`)
+ or a list of relations (`"r"`).
+ order: str
+ It specifies whether it converts raw data to indexes (``order="raw2ind"``) or indexes to raw
+ data (``order="ind2raw"``)
+
+ Returns
+ -------
+ Array: array
+ Array of the same size as `sample` but with indexes of elements instead of raw data or raw data instead
+ of indexes.
+ """
+ if type_of not in ["t", "e", "r"]:
+ msg = "Type (type_of) should be one of the following: t, e, r, instead got {}".format(
+ type_of
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ if type_of == "t":
+ if isinstance(sample, pd.DataFrame):
+ sample = sample.values
+ self.data_shape = sample.shape[1]
+ indexed_data = self.get_indexes_from_a_dictionary(
+ sample[:, :3], order=order
+ )
+ # focusE
+ if sample.shape[1] > 3:
+ # weights = preprocess_focusE_weights(data=sample[:, :3], weights=sample[:, 3:])
+ weights = sample[:, 3:]
+ return np.concatenate([indexed_data, weights], axis=1)
+ else:
+ return indexed_data
+ else:
+ return self.get_indexes_from_a_dictionary_single(
+ sample, type_of=type_of, order=order
+ )
+
+ def get_indexes_from_a_dictionary(self, sample, order="raw2ind"):
+ """Get indexed triples from an in-memory dictionary.
+
+ Parameters
+ ----------
+ sample: array
+ Array with raw or indexed triples.
+ order: str
+ It specifies whether it converts raw data to indexes (``order="raw2ind"``) or indexes to raw
+ data (``order="ind2raw"``)
+
+ Returns
+ -------
+ Array: array
+ Array of the same size as `sample` but with indexes of elements instead of raw data or raw data instead
+ of indexes.
+ """
+ if order == "raw2ind":
+ entities = self.reversed_entities_dict
+ relations = self.reversed_relations_dict
+ dtype = np.int32
+ elif order == "ind2raw":
+ entities = self.entities_dict
+ relations = self.relations_dict
+ dtype = str
+ else:
+ msg = "No such order available options: ind2raw, raw2ind, instead got {}.".format(
+ order
+ )
+ logger.error(msg)
+ raise Exception(msg)
+ if entities is None and relations is None:
+ msg = "Requested entities and relation mappings are empty."
+ logger.error(msg)
+ raise Exception(msg)
+
+ subjects = []
+ objects = []
+ predicates = []
+
+ invalid_keys = 0
+ for row in sample:
+ try:
+ s = entities[row[0]]
+ p = relations[row[1]]
+ o = entities[row[2]]
+ subjects.append(s)
+ predicates.append(p)
+ objects.append(o)
+ except KeyError:
+ invalid_keys += 1
+
+ if invalid_keys > 0:
+ print(
+ "\n{} triples containing invalid keys skipped!".format(
+ invalid_keys
+ )
+ )
+
+ subjects = np.array(subjects, dtype=dtype)
+ objects = np.array(objects, dtype=dtype)
+ predicates = np.array(predicates, dtype=dtype)
+
+ merged = np.stack([subjects, predicates, objects], axis=1)
+ return merged
+
+ def get_indexes_from_a_dictionary_single(
+ self, sample, type_of="e", order="raw2ind"
+ ):
+ """Get indexed elements (entities, relations) or raw data from an in-memory dictionary.
+
+ Parameters
+ ----------
+ sample: array
+ Array with raw or indexed data.
+ type_of: str
+ Type of provided sample to be specified as one of the following values: `{"t", "e", "r"}`.
+ It indicates whether the provided sample is an array of triples (`"t"`), a list of entities (`"e"`)
+ or a list of relations (`"r"`).
+ order: str
+ It specifies whether it converts raw data to indexes (``order="raw2ind"``) or indexes to raw
+ data (``order="ind2raw"``)
+
+ Returns
+ -------
+ tmp: array
+ Array of the same size as `sample` but with indexes of elements instead of raw data or raw data instead
+ of indexes.
+ """
+ if order == "raw2ind":
+ entities = self.reversed_entities_dict
+ relations = self.reversed_relations_dict
+ dtype = np.int32
+ elif order == "ind2raw":
+ entities = self.entities_dict
+ relations = self.relations_dict
+ dtype = str
+ else:
+ msg = "No such order available options: ind2raw, raw2ind, instead got {}.".format(
+ order
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ if entities is None and relations is None:
+ msg = "Requested entities and relations mappings are empty."
+ logger.error(msg)
+ raise Exception(msg)
+
+ if type_of == "e":
+ elements = np.array([entities[x] for x in sample], dtype=dtype)
+ return elements
+ elif type_of == "r":
+ elements = np.array([relations[x] for x in sample], dtype=dtype)
+ return elements
+ else:
+ if type_of not in ["r", "e"]:
+ msg = "No such option, should be r (relations) or e (entities), instead got {}".format(
+ type_of
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ def get_relations_count(self):
+ """Get number of unique relations."""
+ return len(self.relations_dict)
+
+ def get_entities_count(self):
+ """Get number of unique entities."""
+ return len(self.entities_dict)
+
+ def clean(self):
+ """Remove stored objects."""
+ del self.entities_dict
+ del self.reversed_entities_dict
+ del self.relations_dict
+ del self.reversed_relations_dict
+
+
+@register_indexer_backend("shelves")
+class Shelves:
+ def __init__(
+ self,
+ data,
+ entities_dict=None,
+ reversed_entities_dict=None,
+ relations_dict=None,
+ reversed_relations_dict=None,
+ root_directory=tempfile.gettempdir(),
+ name="main_partition",
+ **kwargs
+ ):
+ """Initialise backend by creating mappings.
+
+ Parameters
+ ----------
+ data: array
+ Data to be indexed.
+ entities_dict: dict or shelve path
+ Dictionary or shelve path storing entities mappings; if not provided, it is created from data.
+ reversed_entities_dict: dictionary or shelve path
+ Dictionary or shelve path storing reversed entities mappings; if not provided, it is created from data.
+ relations_dict: dictionary or shelve path
+ Dictionary or shelve path storing relations mappings; if not provided, it is created from data.
+ reversed_relations_dict: dictionary or shelve path
+ Dictionary or shelve path storing reversed relations mappings; if not provided, it is created from data.
+ root_directory: str
+ Path of the directory where to store persistent mappings.
+ """
+ self.data = data
+ self.mapped = False
+ self.metadata = {}
+ self.entities_dict = entities_dict
+ self.reversed_entities_dict = reversed_entities_dict
+ self.relations_dict = relations_dict
+ self.reversed_relations_dict = reversed_relations_dict
+
+ self.root_directory = root_directory
+ self.name = name
+
+ self.max_ents_index = -1
+ self.max_rels_index = -1
+ self.ents_length = 0
+ self.rev_ents_length = 0
+ self.rels_length = 0
+ self.rev_rels_length = 0
+
+ def get_all_entities(self):
+ """Returns all the (raw) entities in the dataset."""
+ return list(self.entities_dict.values())
+
+ def get_all_relations(self):
+ """Returns all the (raw) relations in the dataset."""
+ return list(self.relations_dict.values())
+
+ def get_update_metadata(self, new_file_name=None):
+        """Return the mappings (shelf paths) needed to restore this indexer."""
+ metadata = {
+ "entities_dict": self.entities_dict,
+ "reversed_entities_dict": self.reversed_entities_dict,
+ "relations_dict": self.relations_dict,
+ "reversed_relations_dict": self.reversed_relations_dict,
+ }
+ return metadata
+
+ def create_mappings(self):
+ """Creates mappings of data into indexes.
+
+        It creates four dictionaries: two having the unique entities/relations as keys and the indexes as
+        values, while the other two are the reversed versions of the previous ones.
+        This method also dispatches to the adequate function depending on whether the data is provided as a
+        single array or in chunks.
+ """
+
+ if (
+ isinstance(self.entities_dict, str)
+ and self.shelve_exists(self.entities_dict)
+ and isinstance(self.reversed_entities_dict, str)
+ and self.shelve_exists(self.reversed_entities_dict)
+ and isinstance(self.relations_dict, str)
+ and self.shelve_exists(self.relations_dict)
+ and isinstance(self.reversed_relations_dict, str)
+ and self.shelve_exists(self.reversed_relations_dict)
+ ):
+ self._update_properties()
+ logger.debug(
+                "The mappings were initialised from persistent dictionaries (shelves)."
+ )
+ elif (
+ self.entities_dict is None
+ and self.reversed_entities_dict is None
+ and self.relations_dict is None
+ and self.reversed_relations_dict is None
+ ):
+ logger.debug(
+ "The mappings will be created for data in {}.".format(
+ self.name
+ )
+ )
+
+ if isinstance(self.data, np.ndarray):
+ self.create_persistent_mappings_from_nparray()
+ else:
+ self.create_persistent_mappings_in_chunks()
+ else:
+ logger.debug(
+                "Provided initialization objects are not supported. Cannot initialise mappings."
+ )
+ self.mapped = True
+
+ def create_persistent_mappings_in_chunks(self):
+ """Creates shelves for mappings from entities and relations to indexes and the reverse mappings.
+
+ Four shelves are created in root_directory:
+        - entities_<NAME>_<DATE>.shf - with map entities -> indexes
+        - reversed_entities_<NAME>_<DATE>.shf - with map indexes -> entities
+        - relations_<NAME>_<DATE>.shf - with map relations -> indexes
+        - reversed_relations_<NAME>_<DATE>.shf - with map indexes -> relations
+
+ Remember to use mappings for entities with entities and relations with relations!
+ """
+ date = datetime.now().strftime("%d-%m-%Y_%I-%M-%S_%f_%p")
+ self.entities_dict = os.path.join(
+ self.root_directory, "entities_{}_{}.shf".format(self.name, date)
+ )
+ self.reversed_entities_dict = os.path.join(
+ self.root_directory,
+ "reversed_entities_{}_{}.shf".format(self.name, date),
+ )
+ self.relations_dict = os.path.join(
+ self.root_directory, "relations_{}_{}.shf".format(self.name, date)
+ )
+ self.reversed_relations_dict = os.path.join(
+ self.root_directory,
+ "reversed_relations_{}_{}.shf".format(self.name, date),
+ )
+
+ for chunk in self.data:
+ if isinstance(chunk, pd.DataFrame):
+ self.update_shelves(chunk.iloc[:, :3].values, rough=True)
+ else:
+ self.update_shelves(chunk[:, :3], rough=True)
+
+ logger.debug(
+ "We need to reindex all the data now so the indexes are continuous among chunks"
+ )
+ self.reindex()
+
+ self.files_id = "_{}_{}.shf".format(self.name, date)
+ files = [
+ "entities",
+ "reversed_entities",
+ "relations",
+ "reversed_relations",
+ ]
+ logger.debug(
+ "Mappings are created in the following files:\n{}\n{}\n{}\n{}".format(
+ *[x + self.files_id for x in files]
+ )
+ )
+ self.metadata.update(
+ {
+ "entities_shelf": self.entities_dict,
+ "reversed_entities_shelf": self.reversed_entities_dict,
+ "relations": self.relations_dict,
+ "reversed_relations_dict": self.reversed_relations_dict,
+ "name": self.name,
+ }
+ )
+
+ def reindex(self):
+        """Reindex the data so that indexes are continuous values from 0 to the number of unique elements - 1.
+
+        This is needed when data is provided in chunks, as we do not know the overlap
+ between chunks upfront and indexes are not continuous.
+ This guarantees that entities and relations have continuous indexes.
+ """
+ logger.debug("starting reindexing...")
+ remapped_ents_file = "remapped_ents.shf"
+ remapped_rev_ents_file = "remapped_rev_ents.shf"
+ remapped_rels_file = "remapped_rels.shf"
+ remapped_rev_rels_file = "remapped_rev_rels.shf"
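+        # Enumerate the (de-duplicated) keys of the reversed shelves to assign fresh, consecutive
+        # 0-based ids, then replace the original shelves with the remapped ones (see move_shelve below).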
+ with shelve.open(self.reversed_entities_dict) as ents:
+ with shelve.open(
+ remapped_ents_file, writeback=True
+ ) as remapped_ents:
+ with shelve.open(
+ remapped_rev_ents_file, writeback=True
+ ) as remapped_rev_ents:
+ for i, ent in enumerate(ents):
+ remapped_ents[str(i)] = str(ent)
+ remapped_rev_ents[str(ent)] = str(i)
+
+ with shelve.open(self.reversed_relations_dict) as rels:
+ with shelve.open(
+ remapped_rels_file, writeback=True
+ ) as remapped_rels:
+ with shelve.open(
+ remapped_rev_rels_file, writeback=True
+ ) as remapped_rev_rels:
+ for i, rel in enumerate(rels):
+ remapped_rels[str(i)] = str(rel)
+ remapped_rev_rels[str(rel)] = str(i)
+
+ self.move_shelve(remapped_ents_file, self.entities_dict)
+ self.move_shelve(remapped_rev_ents_file, self.reversed_entities_dict)
+ self.move_shelve(remapped_rels_file, self.relations_dict)
+ self.move_shelve(remapped_rev_rels_file, self.reversed_relations_dict)
+ logger.debug("reindexing done!")
+ self._update_properties()
+ logger.debug("properties updated")
+
+ def _update_properties(self, rough=False):
+ """Initialise properties from the persistent dictionary (shelve)."""
+
+ with shelve.open(self.entities_dict) as ents:
+ self.max_ents_index = int(max(ents.keys(), key=lambda x: int(x)))
+ self.ents_length = len(ents)
+ with shelve.open(self.relations_dict) as rels:
+ self.max_rels_index = int(max(rels.keys(), key=lambda x: int(x)))
+ self.rels_length = len(rels)
+ with shelve.open(self.reversed_entities_dict) as ents:
+ self.rev_ents_length = len(ents)
+ with shelve.open(self.reversed_relations_dict) as rels:
+ self.rev_rels_length = len(rels)
+ if not rough:
+ if not self.rev_ents_length == self.ents_length:
+ msg = "Reversed entities index size not equal to index size ({} and {})".format(
+ self.rev_ents_length, self.ents_length
+ )
+ logger.error(msg)
+ raise Exception(msg)
+ if not self.rev_rels_length == self.rels_length:
+ msg = "Reversed relations index size not equal to index size ({} and {})".format(
+ self.rev_rels_length, self.rels_length
+ )
+ logger.error(msg)
+ raise Exception(msg)
+ else:
+ logger.debug(
+                "In rough mode, the sizes may not be equal due to duplicates; "
+                "this will be fixed by reindexing at a later stage."
+ )
+ logger.debug(
+ "Reversed entities index size and index size {} and {}".format(
+ self.rev_ents_length, self.ents_length
+ )
+ )
+ logger.debug(
+ "Reversed relations index size and index size: {} and {}".format(
+ self.rev_rels_length, self.rels_length
+ )
+ )
+
+ def create_persistent_mappings_from_nparray(self):
+ """Creates shelves for mappings from entities and relations to indexes and the reverse mappings.
+
+ Four shelves are created in root_directory:
+        - entities_<NAME>_<DATE>.shf - with map entities -> indexes
+        - reversed_entities_<NAME>_<DATE>.shf - with map indexes -> entities
+        - relations_<NAME>_<DATE>.shf - with map relations -> indexes
+        - reversed_relations_<NAME>_<DATE>.shf - with map indexes -> relations
+
+ Remember to use mappings for entities with entities and relations with relations!
+ """
+
+ date = datetime.now().strftime("%d-%m-%Y_%I-%M-%S_%f_%p")
+ self.entities_dict = os.path.join(
+ self.root_directory, "entities_{}_{}.shf".format(self.name, date)
+ )
+ self.reversed_entities_dict = os.path.join(
+ self.root_directory,
+ "reversed_entities_{}_{}.shf".format(self.name, date),
+ )
+ self.relations_dict = os.path.join(
+ self.root_directory, "relations_{}_{}.shf".format(self.name, date)
+ )
+ self.reversed_relations_dict = os.path.join(
+ self.root_directory,
+ "reversed_relations_{}_{}.shf".format(self.name, date),
+ )
+ self.files_id = "_{}_{}.shf".format(self.name, date)
+ files = [
+ "entities",
+ "reversed_entities",
+ "relations",
+ "reversed_relations",
+ ]
+ logger.debug(
+ "Mappings are created in the following files:\n{}\n{}\n{}\n{}".format(
+ *[x + self.files_id for x in files]
+ )
+ )
+ self.metadata.update(
+ {
+ "entities_shelf": self.entities_dict,
+ "reversed_entities_shelf": self.reversed_entities_dict,
+ "relations": self.relations_dict,
+ "reversed_relations_dict": self.reversed_relations_dict,
+ "name": self.name,
+ }
+ )
+ self.update_shelves()
+
+ def update_shelves(self, sample=None, rough=False):
+ """Update shelves with sample or full data when sample not provided."""
+ if sample is None:
+ sample = self.data
+ if (
+ self.shelve_exists(self.entities_dict)
+ or self.shelve_exists(self.reversed_entities_dict)
+ or self.shelve_exists(self.relations_dict)
+ or self.shelve_exists(self.reversed_relations_dict)
+ ):
+            msg = "Shelves already exist and are not empty!"
+ logger.error(msg)
+ raise Exception(msg)
+
+ logger.debug("Sample: {}".format(sample))
+ entities = set(sample[:, 0]).union(set(sample[:, 2]))
+ predicates = set(sample[:, 1])
+
+ start_ents = self._get_starting_index_ents()
+ logger.debug("Start index entities: {}".format(start_ents))
+ new_indexes_ents = range(start_ents, start_ents + len(entities))
+        # maximum new index; usually less than the total when multiple chunks are
+        # provided, due to duplicates across chunks
+ if not len(new_indexes_ents) == len(entities):
+            msg = "Estimated indexes length for entities not equal to entities length ({} and {})".format(
+ len(new_indexes_ents), len(entities)
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ start_rels = self._get_starting_index_rels()
+ new_indexes_rels = range(start_rels, start_rels + len(predicates))
+        logger.debug("Start index relations: {}".format(start_rels))
+ if not len(new_indexes_rels) == len(predicates):
+ msg = "Estimated indexes length for relations not equal to relations length ({} and {})".format(
+ len(new_indexes_rels), len(predicates)
+ )
+ logger.error(msg)
+ raise Exception(msg)
+ # print("new indexes rels: ", new_indexes_rels)
+ logger.debug(
+ "index rels size: {} and rels size: {}".format(
+ len(new_indexes_rels), len(predicates)
+ )
+ )
+ logger.debug(
+            "index ents size: {} and ents size: {}".format(
+ len(new_indexes_ents), len(entities)
+ )
+ )
+
+ with shelve.open(self.entities_dict, writeback=True) as ents:
+ with shelve.open(
+ self.reversed_entities_dict, writeback=True
+ ) as reverse_ents:
+ with shelve.open(self.relations_dict, writeback=True) as rels:
+ with shelve.open(
+ self.reversed_relations_dict, writeback=True
+ ) as reverse_rels:
+ reverse_ents.update(
+ {
+ str(value): str(key)
+ for key, value in zip(
+ new_indexes_ents, entities
+ )
+ }
+ )
+ ents.update(
+ {
+ str(key): str(value)
+ for key, value in zip(
+ new_indexes_ents, entities
+ )
+ }
+ )
+ reverse_rels.update(
+ {
+ str(value): str(key)
+ for key, value in zip(
+ new_indexes_rels, predicates
+ )
+ }
+ )
+ rels.update(
+ {
+ str(key): str(value)
+ for key, value in zip(
+ new_indexes_rels, predicates
+ )
+ }
+ )
+ self._update_properties(rough=rough)
+
+ def shelve_exists(self, name):
+ """Check if shelve with a given name exists."""
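+        # Note: depending on the dbm module backing `shelve`, a shelf may be stored as three files
+        # (.bak/.dat/.dir, e.g. with dbm.dumb) or as a single file (e.g. .db with dbm.ndbm).
+        # Only the three-file layout is detected here; remove_shelve/move_shelve also handle the .db case.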
+ if not os.path.isfile(name + ".bak"):
+ return False
+ if not os.path.isfile(name + ".dat"):
+ return False
+ if not os.path.isfile(name + ".dir"):
+ return False
+ return True
+
+ def remove_shelve(self, name):
+ """Remove shelve with a given name."""
+ try:
+ os.remove(name + ".bak")
+ os.remove(name + ".dat")
+ os.remove(name + ".dir")
+ except Exception:
+ if os.path.exists(name + ".db"):
+ os.remove(name + ".db")
+
+ def move_shelve(self, source, destination):
+ """Move shelve to a different file."""
+ try:
+ os.rename(source + ".dir", destination + ".dir")
+ os.rename(source + ".dat", destination + ".dat")
+ os.rename(source + ".bak", destination + ".bak")
+ except Exception:
+ os.rename(source + ".db", destination + ".db")
+
+ def _get_starting_index_ents(self):
+ """Returns next index to continue adding elements to the entities dictionary."""
+ if not self.entities_dict:
+ return 0
+ else:
+ return self.max_ents_index + 1
+
+ def _get_starting_index_rels(self):
+ """Returns next index to continue adding elements to the relations dictionary."""
+ if not self.relations_dict:
+ return 0
+ else:
+ return self.max_rels_index + 1
+
+ def _get_max_ents_index(self):
+ """Get maximum index from entities dictionary."""
+ with shelve.open(self.entities_dict) as ents:
+ return int(max(ents.keys(), key=lambda x: int(x)))
+
+ def _get_max_rels_index(self):
+ """Get maximum index from relations dictionary."""
+ with shelve.open(self.relations_dict) as rels:
+ return int(max(rels.keys(), key=lambda x: int(x)))
+
+ def update_mappings(self, new_data):
+ self.update_shelves(new_data, rough=True)
+ self.reindex()
+
+ def get_indexes(self, sample=None, type_of="t", order="raw2ind"):
+ """Converts raw data to an indexed form or vice versa according to previously created mappings.
+
+ Parameters
+ ----------
+ sample: array
+ Array with raw or indexed data.
+ type_of: str
+ Type of provided sample to be specified as one of the following values: `{"t", "e", "r"}`.
+ It indicates whether the provided sample is an array of triples (`"t"`), a list of entities (`"e"`)
+ or a list of relations (`"r"`).
+ order: str
+ It specifies whether it converts raw data to indexes (``order="raw2ind"``) or indexes to raw
+ data (``order="ind2raw"``)
+
+ Returns
+ -------
+ tmp: array
+ Array of the same size as `sample` but with indexes of elements instead of raw data or raw data instead
+ of indexes.
+ """
+ if type_of not in ["t", "e", "r"]:
+ msg = "Type (type_of) should be one of the following: t, e, r, instead got {}".format(
+ type_of
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ if type_of == "t":
+ self.data_shape = sample.shape[1]
+ indexed_data = self.get_indexes_from_shelves(
+ sample[:, :3], order=order
+ )
+ if sample.shape[1] > 3:
+ weights = sample[:, 3:]
+ # weights = preprocess_focusE_weights(data=sample[:, :3], weights=sample[:, 3:])
+ return np.concatenate([indexed_data, weights], axis=1)
+ return indexed_data
+ else:
+ return self.get_indexes_from_shelves_single(
+ sample, type_of=type_of, order=order
+ )
+
+ def get_indexes_from_shelves(self, sample, order="raw2ind"):
+ """Get indexed triples or raw data from shelves.
+
+ Parameters
+ ----------
+ sample: array
+ Array with a fragment of data of size (N,3), where each element is either (subject, predicate, object)
+            or (indexed_subject, indexed_predicate, indexed_object).
+ order: str
+ Specify ``order="raw2ind"`` or ``order="ind2raw"`` whether to convert raw data to indexes or indexes
+ to raw data.
+
+ Returns
+ -------
+ tmp: array
+            Array of size (N,3) where each element is either (indexed_subject, indexed_predicate, indexed_object)
+ or (subject, predicate, object).
+ """
+ if isinstance(sample, pd.DataFrame):
+ sample = sample.values
+ # logger.debug(sample)
+ if order == "raw2ind":
+ entities = self.reversed_entities_dict
+ relations = self.reversed_relations_dict
+ dtype = int
+ elif order == "ind2raw":
+ entities = self.entities_dict
+ relations = self.relations_dict
+ dtype = str
+ else:
+ msg = "No such order available options: ind2raw, raw2ind, instead got {}.".format(
+ order
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ with shelve.open(entities) as ents:
+ with shelve.open(relations) as rels:
+ subjects = []
+ objects = []
+ predicates = []
+
+ invalid_keys = 0
+ for row in sample:
+ try:
+ s = ents[str(row[0])]
+ p = rels[str(row[1])]
+ o = ents[str(row[2])]
+ subjects.append(s)
+ predicates.append(p)
+ objects.append(o)
+ except KeyError:
+ invalid_keys += 1
+
+ if invalid_keys > 0:
+ print(
+ "\n{} triples containing invalid keys skipped!".format(
+ invalid_keys
+ )
+ )
+
+ out = np.array((subjects, predicates, objects), dtype=dtype).T
+ return out
+
+ def get_indexes_from_shelves_single(
+ self, sample, type_of="e", order="raw2ind"
+ ):
+ """Get indexed elements or raw data from shelves for entities or relations.
+
+ Parameters
+ ----------
+ sample: list
+ List of entities or relations indexed or in raw format.
+ type_of: str
+ ``type_of="e"`` to get indexes/raw data for entities or ``type_of="r"`` to get indexes/raw data
+ for relations.
+ order: str
+ ``order=raw2ind`` or ``order=ind2raw`` to specify whether to convert raw data to indexes or indexes
+ to raw data.
+
+ Returns
+ -------
+ tmp: array
+            Array of the same size as `sample` with indexed or raw data.
+ """
+ if order == "raw2ind":
+ entities = self.reversed_entities_dict
+ relations = self.reversed_relations_dict
+ dtype = int
+ elif order == "ind2raw":
+ entities = self.entities_dict
+ relations = self.relations_dict
+ dtype = str
+ else:
+ msg = "No such order available options: ind2raw, raw2ind, instead got {}.".format(
+ order
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ if type_of == "e":
+ with shelve.open(entities) as ents:
+ elements = [ents[str(elem)] for elem in sample]
+ return np.array(elements, dtype=dtype)
+ elif type_of == "r":
+ with shelve.open(relations) as rels:
+ elements = [rels[str(elem)] for elem in sample]
+ return np.array(elements, dtype=dtype)
+ else:
+ if type_of not in ["r", "e"]:
+ msg = "No such option, should be r (relations) or e (entities), instead got {}".format(
+ type_of
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ def get_relations_count(self):
+ """Get number of unique relations."""
+ return self.rels_length
+
+ def get_entities_count(self):
+ """Get number of unique entities."""
+ return self.ents_length
+
+ def clean(self):
+ """Remove persisted objects."""
+ self.remove_shelve(self.entities_dict)
+ self.remove_shelve(self.reversed_entities_dict)
+ self.remove_shelve(self.relations_dict)
+ self.remove_shelve(self.reversed_relations_dict)
+
+
+@register_indexer_backend("sqlite")
+class SQLite:
+ def __init__(
+ self,
+ data,
+ db_file=None,
+ root_directory=None,
+ name="main_partition",
+ **kwargs
+ ):
+ """Initialise backend by creating mappings.
+
+ Parameters
+ ----------
+        data: array
+            Data to be indexed.
+        db_file: str
+            Name of an existing database file with mappings; if not provided, a new database is created.
+        root_directory: str
+            Path of the directory where to store the database file (defaults to a temporary directory).
+        name: str
+            Name of the indexer, used in the database file name.
+ """
+ logger.debug("Initialisation of SQLite indexer.")
+ self.data = data
+ self.metadata = {}
+ if root_directory is None:
+ self.root_directory = tempfile.gettempdir()
+ else:
+ self.root_directory = root_directory
+ if db_file is not None:
+ self.db_file = os.path.join(self.root_directory, db_file)
+ self.mapped = True
+ else:
+ date = datetime.now().strftime("%d-%m-%Y_%I-%M-%S_%f_%p")
+ self.db_file = os.path.join(
+ self.root_directory, name + date + str(uuid.uuid4()) + ".db"
+ )
+
+ if os.path.exists(self.db_file):
+ os.remove(self.db_file)
+ self.mapped = False
+ self.name = name
+
+ self.max_ents_index = -1
+ self.max_rels_index = -1
+ self.ents_length = 0
+ self.rels_length = 0
+
+ def get_all_entities(self):
+ """Returns all the (raw) entities in the dataset."""
+
+ query = "select distinct name from entities"
+ with sqlite3.connect(self.db_file) as conn:
+ cursor = conn.cursor()
+ output = None
+ try:
+ cursor.execute(query)
+ output = cursor.fetchall()
+ out_val = []
+ for out in output:
+ out_val.append(out[0])
+ conn.commit()
+ except Exception as e:
+ logger.debug("Query failed. The error '{}' occurred".format(e))
+ logger.debug(query)
+ logger.debug(output)
+ return []
+
+ return out_val
+
+ def get_all_relations(self):
+ """Returns all the (raw) relations in the dataset."""
+ query = "select distinct name from relations"
+ with sqlite3.connect(self.db_file) as conn:
+ cursor = conn.cursor()
+ output = None
+ try:
+ cursor.execute(query)
+ output = cursor.fetchall()
+ out_val = []
+ for out in output:
+ out_val.append(out[0])
+ conn.commit()
+ except Exception as e:
+ logger.debug("Query failed. The error '{}' occurred".format(e))
+ logger.debug(query)
+ logger.debug(output)
+ return []
+
+ return out_val
+
+ def get_update_metadata(self, path):
+        """Copy the database file to `path` (if needed) and return the metadata required to restore this indexer."""
+ self.root_directory = path
+ self.root_directory = (
+ "." if self.root_directory == "" else self.root_directory
+ )
+ new_file_name = os.path.join(
+ self.root_directory, os.path.basename(self.db_file)
+ )
+ if not os.path.exists(new_file_name):
+ shutil.copyfile(self.db_file, new_file_name)
+ self.db_file = new_file_name
+ metadata = {
+ "root_directory": self.root_directory,
+ "db_file": os.path.basename(self.db_file),
+ "name": self.name,
+ }
+ return metadata
+
+ def create_mappings(self):
+ """Creates SQLite mappings."""
+ logger.debug("Creating SQLite mappings.")
+ if isinstance(self.data, np.ndarray):
+ self.create_persistent_mappings_from_nparray()
+ else:
+ self.create_persistent_mappings_in_chunks()
+ logger.debug("Database: {}.".format(self.db_file))
+ self.metadata.update({"db": self.db_file, "name": self.name})
+ self.mapped = True
+
+ def update_db(self, sample=None):
+ """Update database with sample or full data when sample not provided."""
+ logger.debug("Update db with data.")
+ if sample is None:
+ sample = self.data
+ logger.debug("sample = {}".format(sample))
+ subjects = sample[:, 0]
+ objects = sample[:, 2]
+ relations = sample[:, 1]
+ entities = np.concatenate((subjects, objects))
+
+ data = {"entities": entities, "relations": relations}
+ for table, elems in data.items():
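+            # Stage unique names in a temporary table first (the PRIMARY KEY de-duplicates);
+            # index_data() later rewrites them into the final table with contiguous 0-based ids.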
+ sql_create_table = """ CREATE TABLE IF NOT EXISTS tmp_{} (
+ name text PRIMARY KEY
+ );""".format(
+ table
+ )
+
+ with sqlite3.connect(self.db_file) as conn:
+ c = conn.cursor()
+ c.execute(sql_create_table)
+ conn.commit()
+
+ tab = "tmp_{}"
+ values_placeholder = "({})".format(", ".join(["?"] * 1))
+ query = "INSERT OR IGNORE INTO {} VALUES {};".format(
+ tab.format(table), values_placeholder
+ )
+ with sqlite3.connect(self.db_file) as conn:
+ c = conn.cursor()
+ tmp = [(str(v),) for v in elems]
+ c.executemany(query, tmp)
+ conn.commit()
+
+ def _get_max(self, table):
+ """Get the max value out of a table."""
+ logger.debug("Get max.")
+ query = "SELECT max(id) from {};".format(table)
+ with sqlite3.connect(self.db_file) as conn:
+ cursor = conn.cursor()
+ maxi = None
+ try:
+ cursor.execute(query)
+ maxi = cursor.fetchall()
+ conn.commit()
+ except Exception as e:
+ logger.debug("Query failed. The error '{}' occurred".format(e))
+
+ if maxi is None:
+            logger.debug("Table is empty or no such table exists.")
+ return maxi
+ elif not isinstance(maxi, list) or not isinstance(maxi[0], tuple):
+ raise ValueError(
+                "Cannot get the max id for the provided table."
+ )
+ logger.debug("Maximal value: {}.".format(maxi[0][0]))
+ return maxi[0][0]
+
+ def _get_max_ents_index(self):
+ """Get the max index for entities."""
+ return self._get_max("entities")
+
+ def _get_max_rels_index(self):
+ """Get the max index for relations."""
+ return self._get_max("relations")
+
+ def _update_properties(self):
+ """Initialise properties from the database."""
+ logger.debug("Update properties")
+ self.max_ents_index = self._get_max_ents_index()
+ self.max_rels_index = self._get_max_rels_index()
+
+ self.ents_length = self.get_entities_count()
+ self.rels_length = self.get_relations_count()
+
+ def create_persistent_mappings_from_nparray(self):
+ """Index entities and relations.
+
+ Creates sqlite db for mappings between entities and relations to indexes.
+ """
+ self.update_db()
+ self.index_data("entities")
+ self.index_data("relations")
+
+ def index_data(self, table):
+ """Create new table with persisted id of elements."""
+ logger.debug("Index data in SQLite.")
+ query = [
+ "CREATE TABLE IF NOT EXISTS {0}(id INTEGER PRIMARY KEY, name TEXT NOT NULL);".format(
+ table
+ ),
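+            # SQLite's rowid is 1-based; "rowid - 1" in the next statement yields contiguous 0-based ids.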
+ "INSERT INTO {0}(id, name) SELECT rowid - 1, name FROM tmp_{0};".format(
+ table
+ ),
+ "DROP TABLE tmp_{0};".format(table),
+ ]
+ with sqlite3.connect(self.db_file) as conn:
+ c = conn.cursor()
+ for q in query:
+ c.execute(q)
+ conn.commit()
+ self._update_properties()
+
+ def create_persistent_mappings_in_chunks(self):
+ """Index entities and relations. Creates sqlite db for mappings between
+ entities and relations to indexes in chunks.
+ """
+ for chunk in self.data:
+ if isinstance(chunk, pd.DataFrame):
+ self.update_db(sample=chunk.iloc[:, :3].values)
+ else:
+ self.update_db(chunk[:, :3])
+ logger.debug(
+ "We need to reindex all the data now so the indexes are continuous among chunks"
+ )
+ self.index_data("entities")
+ self.index_data("relations")
+
+ def update_mappings(self, new_data):
+ raise NotImplementedError(
+            "Updating existing mappings is not supported; "
+            "try creating new mappings in chunks instead."
+ )
+
+ def get_indexes(self, sample=None, type_of="t", order="raw2ind"):
+ """Converts raw data to an indexed form (or vice versa) according to previously created mappings.
+
+ Parameters
+ ----------
+ sample: array
+ Array with raw or indexed data.
+ type_of: str
+ Type of provided sample to be specified as one of the following values: `{"t", "e", "r"}`.
+ It indicates whether the provided sample is an array of triples (`"t"`), a list of entities (`"e"`)
+ or a list of relations (`"r"`).
+ order: str
+ It specifies whether it converts raw data to indexes (``order="raw2ind"``) or indexes to raw
+ data (``order="ind2raw"``)
+
+ Returns
+ -------
+ out: array
+ Array of the same size as `sample` but with indexes of elements instead of raw data or raw data instead
+ of indexes.
+ """
+ if type_of not in ["t", "e", "r"]:
+ msg = "Type (type_of) should be one of the following: t, e, r, instead got {}".format(
+ type_of
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ if type_of == "t":
+ self.data_shape = sample.shape[1]
+ indexed_data = self.get_indexes_from_db(sample[:, :3], order=order)
+ # focusE
+ if sample.shape[1] > 3:
+ weights = sample[:, 3:]
+ return np.concatenate([indexed_data, weights], axis=1)
+ return indexed_data
+ else:
+ out, _ = self.get_indexes_from_db_single(
+ sample, type_of=type_of, order=order
+ )
+ return out
+
+ def get_indexes_from_db(self, sample, order="raw2ind"):
+ """Get indexed or raw triples from the database.
+
+ Parameters
+ ----------
+ sample: ndarray
+ Numpy array with a fragment of data of size (N,3), where each element is (subject, predicate, object)
+ or (indexed_subject, indexed_predicate, indexed_object).
+ order: str
+ Specifies whether it should convert raw data to indexes (``order="raw2ind"``) or indexes
+ to raw data (``order="ind2raw"``).
+
+ Returns
+ -------
+ tmp: ndarray
+ Numpy array of size (N,3) with indexed triples, where, depending on ``order``, each element is
+ (indexed_subject, indexed_predicate, indexed_object) or (subject, predicate, object).
+ """
+ if isinstance(sample, pd.DataFrame):
+ sample = sample.values
+
+ subjects, subject_present = self.get_indexes_from_db_single(
+ sample[:, 0], type_of="e", order=order
+ )
+ objects, objects_present = self.get_indexes_from_db_single(
+ sample[:, 2], type_of="e", order=order
+ )
+ predicates, predicates_present = self.get_indexes_from_db_single(
+ sample[:, 1], type_of="r", order=order
+ )
+ if order == "raw2ind":
+ dtype = int
+ elif order == "ind2raw":
+ dtype = str
+ else:
+ msg = "No such order available options: ind2raw, raw2ind, instead got {}.".format(
+ order
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ present = (
+ np.array(subject_present)
+ & np.array(objects_present)
+ & np.array(predicates_present)
+ )
+ out = np.array((subjects, predicates, objects), dtype=dtype).T
+ before = out.shape[0]
+ out = out[present]
+ after = out.shape[0]
+
+ if before - after > 0:
+ print(
+ "\n{} triples containing invalid keys skipped!".format(
+ before - after
+ )
+ )
+
+ return out
+
+ def get_indexes_from_db_single(self, sample, type_of="e", order="raw2ind"):
+ """Get indexes or raw data from entities or relations.
+
+ Parameters
+ ----------
+ sample: list
+ List of entities or relations to get indexes for.
+ type_of: str
+ Specifies whether to get indexes/raw data for entities (``type_of="e"``) or relations (``type_of="r"``).
+ order: str
+ Specifies whether to convert raw data to indexes (``order="raw2ind"``) or indexes to raw
+ data (``order="ind2raw"``).
+
+ Returns
+ -------
+ tmp: list
+ List of indexes/raw data.
+        present: list
+            List of booleans specifying whether the mapping for each element in `sample` was present in the
+            database (`True`) or not (`False`).
+
+ """
+ if type_of == "e":
+ table = "entities"
+ elif type_of == "r":
+ table = "relations"
+ else:
+ msg = "No such option, should be r (relations) or e (entities), instead got {}".format(
+ type_of
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ if order == "raw2ind":
+ query = "select name, ifnull(id, '-1') from {0} where name in ({1});".format(
+ table, ",".join('"{}"'.format(v) for v in sample)
+ )
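+            # Note: values are interpolated into the SQL rather than bound as parameters, which
+            # assumes entity/relation names do not contain double-quote characters.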
+ with sqlite3.connect(self.db_file) as conn:
+ cursor = conn.cursor()
+ output = None
+ try:
+ cursor.execute(query)
+ output = dict(cursor.fetchall())
+ conn.commit()
+ out_values = []
+ present = []
+
+ for x in sample:
+ try:
+ out_values.append(output[str(x)])
+ present.append(True)
+ except KeyError:
+ out_values.append(str(-1))
+ present.append(False)
+
+ return out_values, present
+
+ except Exception as e:
+ logger.debug(
+ "Query failed. The error '{}' occurred".format(e)
+ )
+ logger.debug(query)
+ logger.debug(output)
+                    return [], []
+ elif order == "ind2raw":
+ query = "select * from {0} where id in ({1});".format(
+ table, ",".join('"{}"'.format(v) for v in sample)
+ )
+ with sqlite3.connect(self.db_file) as conn:
+ cursor = conn.cursor()
+ output = None
+ try:
+ cursor.execute(query)
+ output = dict(cursor.fetchall())
+ conn.commit()
+ out_values = []
+ present = []
+
+ for x in sample:
+ try:
+ out_values.append(output[x])
+ present.append(True)
+ except KeyError:
+ out_values.append(str(-1))
+ present.append(False)
+
+ return out_values, present
+
+ except Exception as e:
+ logger.debug(
+ "Query failed. The error '{}' occurred".format(e)
+ )
+                    return [], []
+ else:
+ msg = "No such order available options: ind2raw, raw2ind, instead got {}.".format(
+ order
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ def get_count(self, table, condition):
+ """Return number of unique elements in a table according to condition."""
+ logger.debug("Get count.")
+ query = "SELECT count(*) from {} {};".format(table, condition)
+ with sqlite3.connect(self.db_file) as conn:
+ cursor = conn.cursor()
+ count = None
+ try:
+ cursor.execute(query)
+ count = cursor.fetchall()
+ conn.commit()
+ except Exception as e:
+ logger.debug("Query failed. The error '{}' occurred".format(e))
+
+ if count is None:
+            logger.debug("Table is empty or no such table exists.")
+ return count
+ elif not isinstance(count, list) or not isinstance(count[0], tuple):
+ raise ValueError(
+ "Cannot get count for the table with provided condition."
+ )
+ logger.debug("Count is {}.".format(count[0][0]))
+ return count[0][0]
+
+ def get_relations_count(self, condition=""):
+ """Return number of unique relations."""
+ return self.get_count("relations", condition)
+
+ def get_entities_count(self, condition=""):
+ """Return number of unique entities."""
+ return self.get_count("entities", condition)
+
+ def _get_starting_index_ents(self):
+ """Return next index to continue adding elements to entities dictionary."""
+ if not self.db_file:
+ return 0
+ else:
+ return self.max_ents_index + 1
+
+ def _get_starting_index_rels(self):
+ """Return next index to continue adding elements to relations dictionary."""
+ if not self.db_file:
+ return 0
+ else:
+ return self.max_rels_index + 1
+
+ def clean(self):
+ """Remove the database file."""
+ os.remove(self.db_file)
+ logger.debug("Indexer Database removed.")
diff --git a/ampligraph/datasets/datasets.py b/ampligraph/datasets/datasets.py
index c91543b7..52740ae9 100644
--- a/ampligraph/datasets/datasets.py
+++ b/ampligraph/datasets/datasets.py
@@ -1,24 +1,49 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-import pandas as pd
-import os
-import numpy as np
+import hashlib
+import json
import logging
+import os
import urllib
import zipfile
-from pathlib import Path
-import hashlib
from collections import namedtuple
+from pathlib import Path
-AMPLIGRAPH_ENV_NAME = 'AMPLIGRAPH_DATA_HOME'
+import numpy as np
+import pandas as pd
-DatasetMetadata = namedtuple('DatasetMetadata', ['dataset_name', 'filename', 'url', 'train_name', 'valid_name',
- 'test_name', 'train_checksum', 'valid_checksum', 'test_checksum'])
+AMPLIGRAPH_ENV_NAME = "AMPLIGRAPH_DATA_HOME"
+
+DatasetMetadata = namedtuple(
+ "DatasetMetadata",
+ [
+ "dataset_name",
+ "filename",
+ "url",
+ "train_name",
+ "valid_name",
+ "test_name",
+ "train_checksum",
+ "valid_checksum",
+ "test_checksum",
+ "test_human_name",
+ "test_human_checksum",
+ "test_human_ids_name",
+ "test_human_ids_checksum",
+ "mapper_name",
+ "mapper_checksum",
+ "valid_negatives_name",
+ "valid_negatives_checksum",
+ "test_negatives_name",
+ "test_negatives_checksum",
+ ],
+ defaults=(None, None, None, None, None, None, None, None, None, None),
+)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -26,13 +51,14 @@
def _clean_data(X, return_idx=False):
"""
- Clean dataset X by removing unseen entities and relations from valid and test sets.
+ Clean dataset **X** by removing unseen entities and relations from valid and test sets.
Parameters
----------
X: dict
- Dicionary containing the following keys: train, valid, test.
- Each key should contain an ndarray of shape [n, 3].
+ Dictionary containing the following keys: `train`, `valid`, `test`.
+        Each key should contain an ndarray of shape [n, m] with m >= 3
+        (m > 3 if weights are included in the array).
return_idx: bool
Whether to return the indices of the remaining rows in valid and test respectively.
@@ -40,36 +66,74 @@ def _clean_data(X, return_idx=False):
Returns
-------
filtered_X: dict
- Dicionary containing the following keys: train, valid, test.
- Each key contains an ndarray of shape [n, 3].
+ Dictionary containing the following keys: `train`, `valid`, `test`.
+ Each key contains a ndarray of shape [n, m].
Valid and test do not contain entities or relations that are not present in train.
valid_idx: ndarray
- Indices of the remaining rows of the valid dataset (with respect to the original valid ndarray).
+ Indices of the remaining rows of the `valid` dataset (with respect to the original `valid` ndarray).
test_idx: ndarray
- Indices of the remaining rows of the test dataset (with respect to the original test ndarray).
+ Indices of the remaining rows of the `test` dataset (with respect to the original `test` ndarray).
"""
- if X["train"].shape[1] == 3:
- columns = ['s', 'p', 'o']
- else:
- columns = ['s', 'p', 'o', 'w']
-
- train = pd.DataFrame(X["train"], columns=columns)
- valid = pd.DataFrame(X["valid"], columns=columns)
- test = pd.DataFrame(X["test"], columns=columns)
+ filtered_X = {}
+ train = pd.DataFrame(X["train"][:,:3], columns=["s", "p", "o"])
+ filtered_X["train"] = X["train"]
+
+ valid = pd.DataFrame(X["valid"][:,:3], columns=["s", "p", "o"])
+ test = pd.DataFrame(X["test"][:,:3], columns=["s", "p", "o"])
train_ent = np.unique(np.concatenate((train.s, train.o)))
train_rel = train.p.unique()
- valid_idx = valid.s.isin(train_ent) & valid.o.isin(train_ent) & valid.p.isin(train_rel)
- test_idx = test.s.isin(train_ent) & test.o.isin(train_ent) & test.p.isin(train_rel)
+ if "valid_negatives" in X:
+ valid_negatives = pd.DataFrame(
+ X["valid_negatives"][:,:3], columns=["s", "p", "o"]
+ )
+ valid_negatives_idx = (
+ valid_negatives.s.isin(train_ent)
+ & valid_negatives.o.isin(train_ent)
+ & valid_negatives.p.isin(train_rel)
+ )
+ filtered_valid_negatives = valid_negatives[valid_negatives_idx].values
+ filtered_X["valid_negatives"] = filtered_valid_negatives
+ if "test_negatives" in X:
+ test_negatives = pd.DataFrame(
+ X["test_negatives"][:,:3], columns=["s", "p", "o"]
+ )
+ test_negatives_idx = (
+ test_negatives.s.isin(train_ent)
+ & test_negatives.o.isin(train_ent)
+ & test_negatives.p.isin(train_rel)
+ )
+        filtered_test_negatives = test_negatives[test_negatives_idx].values
+ filtered_X["test_negatives"] = filtered_test_negatives
+
+ valid_idx = (
+ valid.s.isin(train_ent)
+ & valid.o.isin(train_ent)
+ & valid.p.isin(train_rel)
+ )
+ test_idx = (
+ test.s.isin(train_ent)
+ & test.o.isin(train_ent)
+ & test.p.isin(train_rel)
+ )
+
+ # filtered_valid = valid[valid_idx].values
+ # filtered_test = test[test_idx].values
+ filtered_valid = X["valid"][valid_idx]
+ filtered_test = X["test"][test_idx]
- filtered_valid = valid[valid_idx].values
- filtered_test = test[test_idx].values
+ filtered_X["valid"] = filtered_valid
+ filtered_X["test"] = filtered_test
- filtered_X = {'train': train.values, 'valid': filtered_valid, 'test': filtered_test}
+ if "mapper" in X:
+ filtered_X["mapper"] = X["mapper"]
+ if "test-human" in X and "test-human-ids" in X:
+ filtered_X["test-human"] = X["test-human"]
+ filtered_X["test-human-ids"] = X["test-human-ids"]
if return_idx:
return filtered_X, valid_idx, test_idx
@@ -78,18 +142,17 @@ def _clean_data(X, return_idx=False):
def _get_data_home(data_home=None):
- """Get to location of the dataset folder to use.
+ """Get the location of the dataset folder to use.
Automatically determine the dataset folder to use.
- If data_home is provided this location a check is
- performed to see if the path exists and creates one if it does not.
- If data_home is None the AMPLIGRAPH_ENV_NAME dataset is used.
- If AMPLIGRAPH_ENV_NAME is not set the a default environment ``~/ampligraph_datasets`` is used.
+ If ``data_home`` is provided, a check is performed to see if the path exists and creates one if it does not.
+ If ``data_home`` is `None` the ``AMPLIGRAPH_ENV_NAME`` dataset is used.
+ If ``AMPLIGRAPH_ENV_NAME`` is not set, the default environment ``~/ampligraph_datasets`` is used.
Parameters
----------
- data_home : str
+ data_home: str
The path to the folder that contains the datasets.
Returns
@@ -101,19 +164,20 @@ def _get_data_home(data_home=None):
"""
if data_home is None:
- data_home = os.environ.get(AMPLIGRAPH_ENV_NAME, os.path.join('~', 'ampligraph_datasets'))
-
+ data_home = os.environ.get(
+ AMPLIGRAPH_ENV_NAME, os.path.join("~", "ampligraph_datasets")
+ )
data_home = os.path.expanduser(data_home)
if not os.path.exists(data_home):
os.makedirs(data_home)
- logger.debug('data_home is set to {}'.format(data_home))
+ logger.debug("data_home is set to {}".format(data_home))
return data_home
def _md5(file_path):
md5hash = hashlib.md5()
chunk_size = 4096
- with open(file_path, 'rb') as f:
+ with open(file_path, "rb") as f:
content_buffer = f.read(chunk_size)
while content_buffer:
md5hash.update(content_buffer)
@@ -127,51 +191,64 @@ def _unzip_dataset(remote, source, destination, check_md5hash=False):
Parameters
----------
- source : str
+ source: str
The path to the zipped file
- destination : str
+ destination: str
The destination directory to unzip the files to.
"""
# TODO - add error checking
- with zipfile.ZipFile(source, 'r') as zip_ref:
- logger.debug('Unzipping {} to {}'.format(source, destination))
+ with zipfile.ZipFile(source, "r") as zip_ref:
+ logger.debug("Unzipping {} to {}".format(source, destination))
zip_ref.extractall(destination)
if check_md5hash:
- for file_name, remote_checksum in [[remote.train_name, remote.train_checksum],
- [remote.valid_name, remote.valid_checksum],
- [remote.test_name, remote.test_checksum]]:
- file_path = os.path.join(destination, remote.dataset_name, file_name)
+ for file_name, remote_checksum in [
+ [remote.train_name, remote.train_checksum],
+ [remote.valid_name, remote.valid_checksum],
+ [remote.test_name, remote.test_checksum],
+ [remote.test_human_name, remote.test_human_checksum],
+ [remote.test_human_ids_name, remote.test_human_ids_checksum],
+ [remote.mapper_name, remote.mapper_checksum],
+ [remote.valid_negatives_name, remote.valid_negatives_checksum],
+ [remote.test_negatives_name, remote.test_negatives_checksum],
+ ]:
+ file_path = os.path.join(
+ destination, remote.dataset_name, file_name
+ )
checksum = _md5(file_path)
if checksum != remote_checksum:
os.remove(source)
- msg = '{} has an md5 checksum of ({}) which is different from the expected ({}), ' \
- 'the file may be corrupted.'.format(file_path, checksum, remote_checksum)
+ msg = (
+ "{} has an md5 checksum of ({}) which is different from the expected ({}), "
+ "the file may be corrupted.".format(
+ file_path, checksum, remote_checksum
+ )
+ )
logger.error(msg)
raise IOError(msg)
os.remove(source)
def _fetch_remote_data(remote, download_dir, data_home, check_md5hash=False):
- """Download a remote datasets.
+ """Download a remote dataset.
Parameters
----------
- remote : DatasetMetadata
- Named tuple containing remote datasets meta information: dataset name, dataset filename,
+ remote: DatasetMetadata
+ Named tuple containing remote dataset meta information: dataset name, dataset filename,
url, train filename, validation filename, test filename, train checksum, valid checksum, test checksum.
- download_dir : str
+ download_dir: str
The location to download the file to.
- data_home : str
+ data_home: str
The location to save the dataset.
- check_md5hash : bool
+ check_md5hash: bool
Whether to check the MD5 hash of the dataset file.
"""
- file_path = '{}.zip'.format(download_dir)
+ file_path = "{}.zip".format(download_dir)
if not Path(file_path).exists():
urllib.request.urlretrieve(remote.url, file_path)
# TODO - add error checking
@@ -186,12 +263,12 @@ def _fetch_dataset(remote, data_home=None, check_md5hash=False):
Parameters
----------
- remote : DatasetMetadata
+ remote: DatasetMetadata
Named tuple containing remote datasets meta information: dataset name, dataset filename,
url, train filename, validation filename, test filename, train checksum, valid checksum, test checksum.
- data_home : str
+ data_home: str
The location to save the dataset to.
- check_md5hash : bool
+ check_md5hash: bool
Whether to check the MD5 hash of the dataset file.
Returns
@@ -204,7 +281,7 @@ def _fetch_dataset(remote, data_home=None, check_md5hash=False):
dataset_dir = os.path.join(data_home, remote.dataset_name)
if not os.path.exists(dataset_dir):
if remote.url is None:
- msg = 'No dataset at {} and no url provided.'.format(dataset_dir)
+ msg = "No dataset at {} and no url provided.".format(dataset_dir)
logger.error(msg)
raise Exception(msg)
@@ -218,13 +295,13 @@ def _add_reciprocal_relations(triples_df):
Parameters
----------
- triples_df : Dataframe
- Dataframe of triples
+ triples_df: Dataframe
+ Dataframe of triples.
Returns
-------
- triples_df : Dataframe
- Dataframe of triples and their reciprocals
+ triples_df: Dataframe
+ Dataframe of triples and their reciprocals.
"""
# create a copy of the original triples to add reciprocal relations
df_reciprocal = triples_df.copy()
@@ -242,10 +319,14 @@ def _add_reciprocal_relations(triples_df):
return triples_df
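A hedged toy example of the reciprocal-relation augmentation described in the docstrings; the ``_reciprocal`` predicate suffix is an assumption for illustration, since the body above is only partially shown in this diff:

.. code-block:: python

    import pandas as pd

    # Toy input: one triple per row, columns 0/1/2 = subject/predicate/object.
    triples = pd.DataFrame([["alice", "knows", "bob"],
                            ["bob", "likes", "carol"]])

    # Reciprocal augmentation: swap subject and object, rename the predicate.
    reciprocal = triples.copy()
    reciprocal[[0, 2]] = reciprocal[[2, 0]].values   # swap s and o
    reciprocal[1] = reciprocal[1] + "_reciprocal"    # assumed predicate suffix
    augmented = pd.concat([triples, reciprocal], ignore_index=True)

    # augmented now holds four triples: the two originals plus
    # ("bob", "knows_reciprocal", "alice") and ("carol", "likes_reciprocal", "bob").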
-def load_from_csv(directory_path, file_name, sep='\t', header=None, add_reciprocal_rels=False):
- """Load a knowledge graph from a csv file
+def load_from_csv(
+ directory_path, file_name, sep="\t", header=None, add_reciprocal_rels=False
+):
+ """Load a knowledge graph from a .csv file.
- Loads a knowledge graph serialized in a csv file as:
+ Loads a knowledge graph serialized in a .csv file, filtering out duplicated statements. In the .csv file, each line
+ must represent a triple, with entities and relations separated by ``sep``.
+ For instance, if ``sep="\\t"``, the .csv file looks like:
.. code-block:: text
@@ -253,15 +334,12 @@ def load_from_csv(directory_path, file_name, sep='\t', header=None, add_reciproc
subj1 relationY obj2
subj3 relationZ obj2
subj4 relationY obj2
- ...
+ ...
- .. note::
- The function filters duplicated statements.
-
- .. note::
- It is recommended to use :meth:`ampligraph.evaluation.train_test_split_no_unseen` to split custom
- knowledge graphs into train, validation, and test sets. Using this function will lead to validation, test sets
- that do not include triples with entities that do not occur in the training set.
+ .. hint::
+ To split a generic knowledge graph into **training**, **validation**, and **test** sets, do not use the above
+ function, but rather :meth:`~ampligraph.evaluation.protocol.train_test_split_no_unseen`: this will return
+ validation and test sets that do not include triples with entities absent from the training set.
Parameters
@@ -269,28 +347,27 @@ def load_from_csv(directory_path, file_name, sep='\t', header=None, add_reciproc
directory_path: str
Folder where the input file is stored.
- file_name : str
+ file_name: str
File name.
- sep : str
- The subject-predicate-object separator (default \t).
- header : int, None
+ sep: str
+ The subject-predicate-object separator (default: ``"\\t"``).
+ header: int or None
The row of the header of the csv file. Same as pandas.read_csv header param.
- add_reciprocal_rels : bool
+ add_reciprocal_rels: bool
Flag which specifies whether to add reciprocal relations. For every <s, p, o> in the dataset
- this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s>. (default: False)
+ this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s> (default: `False`).
Returns
-------
-
- triples : ndarray , shape [n, 3]
+ triples: ndarray, shape (n, 3)
The actual triples of the file.
- Examples
- --------
-
+ Example
+ -------
+ >>> PATH_TO_FOLDER = 'your/path/to/folder/'
>>> from ampligraph.datasets import load_from_csv
- >>> X = load_from_csv('folder', 'dataset.csv', sep=',')
+ >>> X = load_from_csv(PATH_TO_FOLDER, 'dataset.csv', sep=',')
>>> X[:3]
array([['a', 'y', 'b'],
['b', 'y', 'a'],
@@ -299,13 +376,15 @@ def load_from_csv(directory_path, file_name, sep='\t', header=None, add_reciproc
"""
- logger.debug('Loading data from {}.'.format(file_name))
- df = pd.read_csv(os.path.join(directory_path, file_name),
- sep=sep,
- header=header,
- names=None,
- dtype=str)
- logger.debug('Dropping duplicates.')
+ logger.debug("Loading data from {}.".format(file_name))
+ df = pd.read_csv(
+ os.path.join(directory_path, file_name),
+ sep=sep,
+ header=header,
+ names=None,
+ dtype=str,
+ )
+ logger.debug("Dropping duplicates.")
df = df.drop_duplicates()
if add_reciprocal_rels:
df = _add_reciprocal_relations(df)
@@ -313,7 +392,12 @@ def load_from_csv(directory_path, file_name, sep='\t', header=None, add_reciproc
return df.values
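Following the hint in the docstring above, a sketch of splitting a custom graph loaded with ``load_from_csv`` without leaking unseen entities; the ``test_size`` values are arbitrary and the exact keyword arguments of ``train_test_split_no_unseen`` should be checked against the installed version:

.. code-block:: python

    from ampligraph.datasets import load_from_csv
    from ampligraph.evaluation import train_test_split_no_unseen

    X = load_from_csv("your/path/to/folder/", "dataset.csv", sep=",")

    # Carve out a test set first, then a validation set from the remainder;
    # both contain only entities that also appear in the training split.
    X_train_valid, X_test = train_test_split_no_unseen(X, test_size=1000)
    X_train, X_valid = train_test_split_no_unseen(X_train_valid, test_size=500)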
-def _load_dataset(dataset_metadata, data_home=None, check_md5hash=False, add_reciprocal_rels=False):
+def _load_dataset(
+ dataset_metadata,
+ data_home=None,
+ check_md5hash=False,
+ add_reciprocal_rels=False,
+):
"""Load a dataset from the details provided.
DatasetMetadata = namedtuple('DatasetMetadata', ['dataset_name', 'filename', 'url', 'train_name', 'valid_name',
@@ -321,61 +405,146 @@ def _load_dataset(dataset_metadata, data_home=None, check_md5hash=False, add_rec
Parameters
----------
- dataset_metadata : DatasetMetadata
+ dataset_metadata: DatasetMetadata
Named tuple containing remote datasets meta information: dataset name, dataset filename,
url, train filename, validation filename, test filename, train checksum, valid checksum, test checksum.
- data_home : str
- The location to save the dataset to (default: None).
+ data_home: str
+ The location to save the dataset to (default: `None`).
- check_md5hash : boolean
- If True, check the md5hash of the files after they are downloaded (default: False).
+ check_md5hash: bool
+ If True, check the md5hash of the files after they are downloaded (default: `False`).
- add_reciprocal_rels : bool
+ add_reciprocal_rels: bool
Flag which specifies whether to add reciprocal relations. For every <s, p, o> in the dataset
- this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s>. (default: False).
+ this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s> (default: `False`).
"""
-
+ dataset = {}
if dataset_metadata.dataset_name is None:
if dataset_metadata.url is None:
- raise ValueError('The dataset name or url must be provided to load a dataset.')
- dataset_metadata.dataset_name = dataset_metadata.url[dataset_metadata.url.rfind('/') + 1:dataset_metadata
- .url.rfind('.')]
+ raise ValueError(
+ "The dataset name or url must be provided to load a dataset."
+ )
+ dataset_metadata.dataset_name = dataset_metadata.url[
+ dataset_metadata.url.rfind("/")
+ + 1: dataset_metadata.url.rfind(".")
+ ]
dataset_path = _fetch_dataset(dataset_metadata, data_home, check_md5hash)
+ train = load_from_csv(
+ dataset_path,
+ dataset_metadata.train_name,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
+ dataset["train"] = train
+ valid = load_from_csv(
+ dataset_path,
+ dataset_metadata.valid_name,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
+ dataset["valid"] = valid
+ test = load_from_csv(
+ dataset_path,
+ dataset_metadata.test_name,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
+ dataset["test"] = test
+ if dataset_metadata.valid_negatives_name is not None:
+ valid_negatives = load_from_csv(
+ dataset_path,
+ dataset_metadata.valid_negatives_name,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
+ dataset["valid_negatives"] = valid_negatives
+ if dataset_metadata.test_negatives_name is not None:
+ test_negatives = load_from_csv(
+ dataset_path,
+ dataset_metadata.test_negatives_name,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
+ dataset["test_negatives"] = test_negatives
+
+ if (
+ dataset_metadata.test_human_checksum is not None
+ and dataset_metadata.test_human_ids_checksum is not None
+ ):
+ test_human = load_from_csv(
+ dataset_path, dataset_metadata.test_human_name
+ )
+ dataset["test-human"] = test_human
+ test_human_ids = load_from_csv(
+ dataset_path, dataset_metadata.test_human_ids_name
+ )
+ dataset["test-human-ids"] = test_human_ids
+ if dataset_metadata.mapper_checksum is not None:
+ mapper = load_mapper_from_json(
+ dataset_path, dataset_metadata.mapper_name
+ )
+ dataset["mapper"] = mapper
+ return dataset
+
+
+def load_mapper_from_json(directory_path, file_name):
+ """Load a mapper from a .json file.
+
+ Loads a mapper for a graph serialized in a .json file as:
- train = load_from_csv(dataset_path,
- dataset_metadata.train_name,
- add_reciprocal_rels=add_reciprocal_rels)
- valid = load_from_csv(dataset_path,
- dataset_metadata.valid_name,
- add_reciprocal_rels=add_reciprocal_rels)
- test = load_from_csv(dataset_path,
- dataset_metadata.test_name,
- add_reciprocal_rels=add_reciprocal_rels)
+ .. code-block:: text
- return {'train': train, 'valid': valid, 'test': test}
+ subj1: human_labeled_subj1
+ relationX: human_labeled_relationX
+ obj1: human_labeled_obj1
+ human_labeled_relationX: description_of_relationX
+
+ ...
+
+ Parameters
+ ----------
+
+ directory_path: str
+ Folder where the input file is stored.
+ file_name: str
+ File name.
+
+ Returns
+ -------
+
+ mapper: dict
+ Dictionary of mappings between graph entities and predicates and human-readable version of them.
+
+ Example
+ -------
+
+ >>> from ampligraph.datasets import load_mapper_from_json
+ >>> mapper = load_mapper_from_json('folder', 'mapper.json')
+ >>> mapper['/m/234fsd/']
+ 'Dog'
+ """
+
+ logger.debug("Loading mapper from {}.".format(file_name))
+ with open(os.path.join(directory_path, file_name)) as f:
+ mapper = json.loads(f.read())
+ return mapper
def load_wn18(check_md5hash=False, add_reciprocal_rels=False):
- """Load the WN18 dataset
+ """Load the WN18 dataset.
+
+ WN18 is a subset of Wordnet. It was first presented by :cite:`bordes2013translating`.
.. warning::
The dataset includes a large number of inverse relations that spilled to the test set, and its use in
- experiments has been deprecated. Use WN18RR instead.
-
- WN18 is a subset of Wordnet. It was first presented by :cite:`bordes2013translating`.
+ experiments has been deprecated. **Use WN18RR instead**.
The WN18 dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
- If ``AMPLIGRAPH_DATA_HOME`` is not set the the default ``~/ampligraph_datasets`` is checked.
-
- If the dataset is not found at either location it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
+ If ``AMPLIGRAPH_DATA_HOME`` is not set, the default ``~/ampligraph_datasets`` is checked.
+ If the dataset is not found at either location, it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
or ``~/ampligraph_datasets``.
The dataset is divided in three splits:
- - ``train``: 141,442 triples
- - ``valid`` 5,000 triples
- - ``test`` 5,000 triples
+ - `train`: 141,442 triples
+ - `valid`: 5,000 triples
+ - `test`: 5,000 triples
========= ========= ======= ======= ============ ===========
Dataset Train Valid Test Entities Relations
@@ -385,21 +554,21 @@ def load_wn18(check_md5hash=False, add_reciprocal_rels=False):
Parameters
----------
- check_md5hash : bool
- If ``True`` check the md5hash of the files. Defaults to ``False``.
+ check_md5hash: bool
+ If `True` check the md5hash of the files (default: `False`).
- add_reciprocal_rels : bool
+ add_reciprocal_rels: bool
Flag which specifies whether to add reciprocal relations. For every <s, p, o> in the dataset
- this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s>. (default: False).
+ this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s> (default: `False`).
Returns
-------
- splits : dict
- The dataset splits {'train': train, 'valid': valid, 'test': test}. Each split is an ndarray of shape [n, 3].
+ splits: dict
+ The dataset splits `{'train': train, 'valid': valid, 'test': test}`. Each split is a ndarray of shape (n, 3).
- Examples
- --------
+ Example
+ -------
>>> from ampligraph.datasets import load_wn18
>>> X = load_wn18()
>>> X['test'][:3]
@@ -410,40 +579,46 @@ def load_wn18(check_md5hash=False, add_reciprocal_rels=False):
"""
wn18 = DatasetMetadata(
- dataset_name='wn18',
- filename='wn18.zip',
- url='https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/wn18.zip',
- train_name='train.txt',
- valid_name='valid.txt',
- test_name='test.txt',
- train_checksum='7d68324d293837ac165c3441a6c8b0eb',
- valid_checksum='f4f66fec0ca83b5ebe7ad7003404e61d',
- test_checksum='b035247a8916c7ec3443fa949e1ff02c'
+ dataset_name="wn18",
+ filename="wn18.zip",
+ url="https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/wn18.zip",
+ train_name="train.txt",
+ valid_name="valid.txt",
+ test_name="test.txt",
+ train_checksum="7d68324d293837ac165c3441a6c8b0eb",
+ valid_checksum="f4f66fec0ca83b5ebe7ad7003404e61d",
+ test_checksum="b035247a8916c7ec3443fa949e1ff02c",
)
- return _load_dataset(wn18,
- data_home=None,
- check_md5hash=check_md5hash,
- add_reciprocal_rels=add_reciprocal_rels)
+ return _load_dataset(
+ wn18,
+ data_home=None,
+ check_md5hash=check_md5hash,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
-def load_wn18rr(check_md5hash=False, clean_unseen=True, add_reciprocal_rels=False):
- """Load the WN18RR dataset
+def load_wn18rr(
+ check_md5hash=False, clean_unseen=True, add_reciprocal_rels=False
+):
+ """Load the WN18RR dataset.
The dataset is described in :cite:`DettmersMS018`.
- The WN18RR dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
- If ``AMPLIGRAPH_DATA_HOME`` is not set the the default ``~/ampligraph_datasets`` is checked.
+ .. warning:: *WN18RR*'s validation set contains 198 unseen entities over 210 triples. The test set
+ has 209 unseen entities, distributed over 210 triples.
- If the dataset is not found at either location it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
+ The WN18RR dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
+ If ``AMPLIGRAPH_DATA_HOME`` is not set, the default ``~/ampligraph_datasets`` is checked.
+ If the dataset is not found at either location, it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
or ``~/ampligraph_datasets``.
- It is divided in three splits:
+ This dataset is divided in three splits:
- - ``train``
- - ``valid``
- - ``test``
+ - `train`: 86,835 triples
+ - `valid`: 3,034 triples
+ - `test`: 3,134 triples
========= ========= ======= ======= ============ ===========
Dataset Train Valid Test Entities Relations
@@ -451,28 +626,25 @@ def load_wn18rr(check_md5hash=False, clean_unseen=True, add_reciprocal_rels=Fals
WN18RR 86,835 3,034 3,134 40,943 11
========= ========= ======= ======= ============ ===========
- .. warning:: WN18RR's validation set contains 198 unseen entities over 210 triples.
- The test set has 209 unseen entities, distributed over 210 triples.
-
Parameters
----------
- clean_unseen : bool
- If ``True``, filters triples in validation and test sets that include entities not present in the training set.
+ clean_unseen: bool
+ If `True`, filters triples in validation and test sets that include entities not present in the training set.
- check_md5hash : bool
- If ``True`` check the md5hash of the datset files. Defaults to ``False``.
+ check_md5hash: bool
+ If `True` check the md5hash of the dataset files (default: `False`).
- add_reciprocal_rels : bool
+ add_reciprocal_rels: bool
Flag which specifies whether to add reciprocal relations. For every <s, p, o> in the dataset
- this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s>. (default: False).
+ this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s> (default: `False`).
Returns
-------
- splits : dict
- The dataset splits: {'train': train, 'valid': valid, 'test': test}. Each split is an ndarray of shape [n, 3].
+ splits: dict
+ The dataset splits: `{'train': train, 'valid': valid, 'test': test}`. Each split is a ndarray of shape (n, 3).
- Examples
+ Example
-------
>>> from ampligraph.datasets import load_wn18rr
@@ -483,49 +655,54 @@ def load_wn18rr(check_md5hash=False, clean_unseen=True, add_reciprocal_rels=Fals
"""
wn18rr = DatasetMetadata(
- dataset_name='wn18RR',
- filename='wn18RR.zip',
- url='https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/wn18RR.zip',
- train_name='train.txt',
- valid_name='valid.txt',
- test_name='test.txt',
- train_checksum='35e81af3ae233327c52a87f23b30ad3c',
- valid_checksum='74a2ee9eca9a8d31f1a7d4d95b5e0887',
- test_checksum='2b45ba1ba436b9d4ff27f1d3511224c9'
+ dataset_name="wn18RR",
+ filename="wn18RR.zip",
+ url="https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/wn18RR.zip",
+ train_name="train.txt",
+ valid_name="valid.txt",
+ test_name="test.txt",
+ train_checksum="35e81af3ae233327c52a87f23b30ad3c",
+ valid_checksum="74a2ee9eca9a8d31f1a7d4d95b5e0887",
+ test_checksum="2b45ba1ba436b9d4ff27f1d3511224c9",
)
if clean_unseen:
- return _clean_data(_load_dataset(wn18rr,
- data_home=None,
- check_md5hash=check_md5hash,
- add_reciprocal_rels=add_reciprocal_rels))
+ return _clean_data(
+ _load_dataset(
+ wn18rr,
+ data_home=None,
+ check_md5hash=check_md5hash,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
+ )
else:
- return _load_dataset(wn18rr,
- data_home=None,
- check_md5hash=check_md5hash,
- add_reciprocal_rels=add_reciprocal_rels)
+ return _load_dataset(
+ wn18rr,
+ data_home=None,
+ check_md5hash=check_md5hash,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
def load_fb15k(check_md5hash=False, add_reciprocal_rels=False):
- """Load the FB15k dataset
+ """Load the FB15k dataset.
+
+ FB15k is a split of Freebase, first proposed by :cite:`bordes2013translating`.
.. warning::
The dataset includes a large number of inverse relations that spilled to the test set, and its use in
- experiments has been deprecated. Use FB15k-237 instead.
-
- FB15k is a split of Freebase, first proposed by :cite:`bordes2013translating`.
+ experiments has been deprecated. **Use FB15k-237 instead**.
The FB15k dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
- If ``AMPLIGRAPH_DATA_HOME`` is not set the the default ``~/ampligraph_datasets`` is checked.
-
- If the dataset is not found at either location it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
+ If ``AMPLIGRAPH_DATA_HOME`` is not set, the default ``~/ampligraph_datasets`` is checked.
+ If the dataset is not found at either location, it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
or ``~/ampligraph_datasets``.
The dataset is divided in three splits:
- - ``train``
- - ``valid``
- - ``test``
+ - `train`: 483,142 triples
+ - `valid`: 50,000 triples
+ - `test`: 59,071 triples
========= ========= ======= ======= ============ ===========
Dataset Train Valid Test Entities Relations
@@ -535,22 +712,22 @@ def load_fb15k(check_md5hash=False, add_reciprocal_rels=False):
Parameters
----------
- check_md5hash : boolean
- If ``True`` check the md5hash of the files. Defaults to ``False``.
+ check_md5hash: bool
+ If `True` check the md5hash of the files (default: `False`).
- add_reciprocal_rels : bool
+ add_reciprocal_rels: bool
Flag which specifies whether to add reciprocal relations. For every <s, p, o> in the dataset
- this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s>. (default: False).
+ this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s> (default: `False`).
Returns
-------
- splits : dict
- The dataset splits: {'train': train, 'valid': valid, 'test': test}. Each split is an ndarray of shape [n, 3].
+ splits: dict
+ The dataset splits: `{'train': train, 'valid': valid, 'test': test}`. Each split is a ndarray of shape (n, 3).
- Examples
- --------
+ Example
+ -------
>>> from ampligraph.datasets import load_fb15k
>>> X = load_fb15k()
@@ -566,71 +743,85 @@ def load_fb15k(check_md5hash=False, add_reciprocal_rels=False):
"""
FB15K = DatasetMetadata(
- dataset_name='fb15k',
- filename='fb15k.zip',
- url='https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/fb15k.zip',
- train_name='train.txt',
- valid_name='valid.txt',
- test_name='test.txt',
- train_checksum='5a87195e68d7797af00e137a7f6929f2',
- valid_checksum='275835062bb86a86477a3c402d20b814',
- test_checksum='71098693b0efcfb8ac6cd61cf3a3b505'
+ dataset_name="fb15k",
+ filename="fb15k.zip",
+ url="https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/fb15k.zip",
+ train_name="train.txt",
+ valid_name="valid.txt",
+ test_name="test.txt",
+ train_checksum="5a87195e68d7797af00e137a7f6929f2",
+ valid_checksum="275835062bb86a86477a3c402d20b814",
+ test_checksum="71098693b0efcfb8ac6cd61cf3a3b505",
)
- return _load_dataset(FB15K,
- data_home=None,
- check_md5hash=check_md5hash,
- add_reciprocal_rels=add_reciprocal_rels)
+ return _load_dataset(
+ FB15K,
+ data_home=None,
+ check_md5hash=check_md5hash,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
-def load_fb15k_237(check_md5hash=False, clean_unseen=True, add_reciprocal_rels=False):
- """Load the FB15k-237 dataset
+def load_fb15k_237(
+ check_md5hash=False,
+ clean_unseen=True,
+ add_reciprocal_rels=False,
+ return_mapper=False,
+):
+ """Load the FB15k-237 dataset (with option to load human labeled test subset).
FB15k-237 is a reduced version of FB15K. It was first proposed by :cite:`toutanova2015representing`.
- The FB15k-237 dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
- If ``AMPLIGRAPH_DATA_HOME`` is not set the the default ``~/ampligraph_datasets`` is checked.
+ .. warning:: *FB15K-237*'s validation set contains 8 unseen entities over 9 triples. The test set has 29 unseen entities,
+ distributed over 28 triples.
- If the dataset is not found at either location it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
+ The FB15k-237 dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
+ If ``AMPLIGRAPH_DATA_HOME`` is not set, the default ``~/ampligraph_datasets`` is checked.
+ If the dataset is not found at either location, it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
or ``~/ampligraph_datasets``.
The dataset is divided in three splits:
- - ``train``
- - ``valid``
- - ``test``
+ - `train`: 272,115 triples
+ - `valid`: 17,535 triples
+ - `test`: 20,466 triples
- ========= ========= ======= ======= ============ ===========
- Dataset Train Valid Test Entities Relations
- ========= ========= ======= ======= ============ ===========
- FB15K-237 272,115 17,535 20,466 14,541 237
- ========= ========= ======= ======= ============ ===========
+ It also contains a subset of the test set with human-readable labels, exposed through two additional splits:
+ - `test-human`
+ - `test-human-ids`
+
+ ========= ========= ======= ======= ========== ======== =========
+ Dataset Train Valid Test Test-Human Entities Relations
+ ========= ========= ======= ======= ========== ======== =========
+ FB15K-237 272,115 17,535 20,466 273 14,541 237
+ ========= ========= ======= ======= ========== ======== =========
- .. warning::
- FB15K-237's validation set contains 8 unseen entities over 9 triples.
- The test set has 29 unseen entities, distributed over 28 triples.
Parameters
----------
- check_md5hash : boolean
- If ``True`` check the md5hash of the files. Defaults to ``False``.
+ check_md5hash: bool
+ If `True` check the md5hash of the files (default: `False`).
- clean_unseen : bool
- If ``True``, filters triples in validation and test sets that include entities not present in the training set.
+ clean_unseen: bool
+ If `True`, filters triples in validation and test sets that include entities not present in the training set.
- add_reciprocal_rels : bool
- Flag which specifies whether to add reciprocal relations. For every <s, p, o> in the dataset
- this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s>. (default: False).
+ add_reciprocal_rels: bool
+ Flag which specifies whether to add reciprocal relations. For every <s, p, o> in the dataset
+ this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s> (default: `False`).
+
+ return_mapper: bool
+ Whether to return human-readable labels as a dictionary in the ``X["mapper"]`` field (default: `False`).
Returns
-------
- splits : dict
- The dataset splits: {'train': train, 'valid': valid, 'test': test}. Each split is an ndarray of shape [n, 3].
+ splits: dict
+ The dataset splits: `{'train': train, 'valid': valid, 'test': test, 'test-human':test_human, 'test-human-ids': test_human_ids}`.
+ Each split is a ndarray of shape (n, 3).
- Examples
- --------
+ Example
+ -------
>>> from ampligraph.datasets import load_fb15k_237
>>> X = load_fb15k_237()
@@ -640,47 +831,73 @@ def load_fb15k_237(check_md5hash=False, clean_unseen=True, add_reciprocal_rels=F
"""
- fb15k_237 = DatasetMetadata(
- dataset_name='fb15k-237',
- filename='fb15k-237.zip',
- url='https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/fb15k-237.zip',
- train_name='train.txt',
- valid_name='valid.txt',
- test_name='test.txt',
- train_checksum='c05b87b9ac00f41901e016a2092d7837',
- valid_checksum='6a94efd530e5f43fcf84f50bc6d37b69',
- test_checksum='f5bdf63db39f455dec0ed259bb6f8628'
- )
+ if return_mapper:
+ fb15k_237 = DatasetMetadata(
+ dataset_name="fb15k-237",
+ filename="fb15k-237_human_interpretability.zip",
+ url="https://ampgraphenc.s3.eu-west-1.amazonaws.com/datasets/fb15k_237_human_interpretability.zip",
+ train_name="train.txt",
+ valid_name="valid.txt",
+ test_name="test.txt",
+ train_checksum="c05b87b9ac00f41901e016a2092d7837",
+ valid_checksum="6a94efd530e5f43fcf84f50bc6d37b69",
+ test_checksum="f5bdf63db39f455dec0ed259bb6f8628",
+ test_human_name="test_human.txt",
+ test_human_ids_name="test_human_ids.txt",
+ mapper_name="mapper.json",
+ test_human_checksum="5f43e8e2fb07846ffaf80877b0734744",
+ test_human_ids_checksum="e731d027b3bf9d4914393d75dae77dda",
+ mapper_checksum="b4dbdfaf1faf075746d2c32946be0234",
+ )
+ else:
+ fb15k_237 = DatasetMetadata(
+ dataset_name="fb15k-237",
+ filename="fb15k-237.zip",
+ url="https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/fb15k-237.zip",
+ train_name="train.txt",
+ valid_name="valid.txt",
+ test_name="test.txt",
+ train_checksum="c05b87b9ac00f41901e016a2092d7837",
+ valid_checksum="6a94efd530e5f43fcf84f50bc6d37b69",
+ test_checksum="f5bdf63db39f455dec0ed259bb6f8628",
+ )
if clean_unseen:
- return _clean_data(_load_dataset(fb15k_237,
- data_home=None,
- check_md5hash=check_md5hash,
- add_reciprocal_rels=add_reciprocal_rels))
+ return _clean_data(
+ _load_dataset(
+ fb15k_237,
+ data_home=None,
+ check_md5hash=check_md5hash,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
+ )
else:
- return _load_dataset(fb15k_237,
- data_home=None,
- check_md5hash=check_md5hash,
- add_reciprocal_rels=add_reciprocal_rels)
+ return _load_dataset(
+ fb15k_237,
+ data_home=None,
+ check_md5hash=check_md5hash,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
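A short, hedged illustration of the new ``return_mapper`` option; the exact keys stored in ``X['mapper']`` are an assumption based on the mapper format sketched in ``load_mapper_from_json`` above:

.. code-block:: python

    from ampligraph.datasets import load_fb15k_237

    X = load_fb15k_237(return_mapper=True)

    # Human-labelled test subset: "test-human" carries readable labels,
    # "test-human-ids" the corresponding machine identifiers.
    s, p, o = X["test-human-ids"][0]

    # The mapper translates machine identifiers into human-readable labels
    # (illustrative lookup; missing keys simply return None here).
    print(X["mapper"].get(s), X["mapper"].get(p), X["mapper"].get(o))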
-def load_yago3_10(check_md5hash=False, clean_unseen=True, add_reciprocal_rels=False):
- """Load the YAGO3-10 dataset
+def load_yago3_10(
+ check_md5hash=False, clean_unseen=True, add_reciprocal_rels=False
+):
+ """Load the YAGO3-10 dataset.
The dataset is a split of YAGO3 :cite:`mahdisoltani2013yago3`,
and has been first presented in :cite:`DettmersMS018`.
The YAGO3-10 dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
- If ``AMPLIGRAPH_DATA_HOME`` is not set the the default ``~/ampligraph_datasets`` is checked.
-
+ If ``AMPLIGRAPH_DATA_HOME`` is not set, the default ``~/ampligraph_datasets`` is checked.
If the dataset is not found at either location it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
or ``~/ampligraph_datasets``.
- It is divided in three splits:
+ This dataset is divided in three splits:
- - ``train``
- - ``valid``
- - ``test``
+ - `train`: 1,079,040 triples
+ - `valid`: 5,000 triples
+ - `test`: 5,000 triples
========= ========= ======= ======= ============ ===========
Dataset Train Valid Test Entities Relations
@@ -690,23 +907,23 @@ def load_yago3_10(check_md5hash=False, clean_unseen=True, add_reciprocal_rels=Fa
Parameters
----------
- check_md5hash : boolean
- If ``True`` check the md5hash of the files. Defaults to ``False``.
+ check_md5hash: bool
+ If `True` check the md5hash of the files (default: `False`).
- clean_unseen : bool
- If ``True``, filters triples in validation and test sets that include entities not present in the training set.
+ clean_unseen: bool
+ If `True`, filters triples in validation and test sets that include entities not present in the training set.
- add_reciprocal_rels : bool
+ add_reciprocal_rels: bool
Flag which specifies whether to add reciprocal relations. For every <s, p, o> in the dataset
- this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s>. (default: False).
+ this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s> (default: `False`).
Returns
-------
- splits : dict
- The dataset splits: {'train': train, 'valid': valid, 'test': test}. Each split is an ndarray of shape [n, 3].
+ splits: dict
+ The dataset splits: `{'train': train, 'valid': valid, 'test': test}`. Each split is a ndarray of shape (n, 3).
- Examples
+ Example
-------
>>> from ampligraph.datasets import load_yago3_10
@@ -716,51 +933,62 @@ def load_yago3_10(check_md5hash=False, clean_unseen=True, add_reciprocal_rels=Fa
"""
yago3_10 = DatasetMetadata(
- dataset_name='YAGO3-10',
- filename='YAGO3-10.zip',
- url='https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/YAGO3-10.zip',
- train_name='train.txt',
- valid_name='valid.txt',
- test_name='test.txt',
- train_checksum='a9da8f583ec3920570eeccf07199229a',
- valid_checksum='2d679a906f2b1ac29d74d5c948c1ad09',
- test_checksum='14bf97890b2fee774dbce5f326acd189'
+ dataset_name="YAGO3-10",
+ filename="YAGO3-10.zip",
+ url="https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/YAGO3-10.zip",
+ train_name="train.txt",
+ valid_name="valid.txt",
+ test_name="test.txt",
+ train_checksum="a9da8f583ec3920570eeccf07199229a",
+ valid_checksum="2d679a906f2b1ac29d74d5c948c1ad09",
+ test_checksum="14bf97890b2fee774dbce5f326acd189",
)
if clean_unseen:
- return _clean_data(_load_dataset(yago3_10,
- data_home=None,
- check_md5hash=check_md5hash,
- add_reciprocal_rels=add_reciprocal_rels))
+ return _clean_data(
+ _load_dataset(
+ yago3_10,
+ data_home=None,
+ check_md5hash=check_md5hash,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
+ )
else:
- return _load_dataset(yago3_10,
- data_home=None,
- check_md5hash=check_md5hash,
- add_reciprocal_rels=add_reciprocal_rels)
+ return _load_dataset(
+ yago3_10,
+ data_home=None,
+ check_md5hash=check_md5hash,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
-def load_wn11(check_md5hash=False, clean_unseen=True, add_reciprocal_rels=False):
- """Load the WordNet11 (WN11) dataset
+def load_wn11(
+ check_md5hash=False, clean_unseen=True, add_reciprocal_rels=False
+):
+ """Load the WordNet11 (WN11) dataset.
WordNet was originally proposed in `WordNet: a lexical database for English` :cite:`miller1995wordnet`.
- WN11 dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
- If ``AMPLIGRAPH_DATA_HOME`` is not set the the default ``~/ampligraph_datasets`` is checked.
+ .. note::
+ WN11 also provides positive and negative labels for the triples in the validation and test sets.
+ The positive base rate is close to 50%.
+ WN11 dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
+ If ``AMPLIGRAPH_DATA_HOME`` is not set, the default ``~/ampligraph_datasets`` is checked.
If the dataset is not found at either location, it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
or ``~/ampligraph_datasets``.
- It is divided in three splits:
+ This dataset is divided in three splits:
- - ``train``
- - ``valid``
- - ``test``
+ - `train`: 110,361 triples
+ - `valid`: 5,215 triples
+ - `test`: 21,035 triples
Both the validation and test splits are associated with labels (binary ndarrays),
with `True` for positive statements and `False` for negatives:
- - ``valid_labels``
- - ``test_labels``
+ - `valid_labels`
+ - `test_labels`
========= ========= ========== ========== ======== ======== ============ ===========
Dataset Train Valid Pos Valid Neg Test Pos Test Neg Entities Relations
@@ -770,26 +998,26 @@ def load_wn11(check_md5hash=False, clean_unseen=True, add_reciprocal_rels=False)
Parameters
----------
- check_md5hash : boolean
- If ``True`` check the md5hash of the files. Defaults to ``False``.
+ check_md5hash: bool
+ If `True` check the md5hash of the files (default: `False`).
- clean_unseen : bool
- If ``True``, filters triples in validation and test sets that include entities not present in the training set.
+ clean_unseen: bool
+ If `True`, filters triples in validation and test sets that include entities not present in the training set.
- add_reciprocal_rels : bool
+ add_reciprocal_rels: bool
Flag which specifies whether to add reciprocal relations. For every <s, p, o> in the dataset
- this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s>. (default: False).
+ this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s> (default: `False`).
Returns
-------
- splits : dict
- The dataset splits: {'train': train, 'valid': valid, 'valid_labels': valid_labels,
- 'test': test, 'test_labels': test_labels}.
- Each split containing a dataset is an ndarray of shape [n, 3].
- The labels are ndarray of shape [n].
+ splits: dict
+ The dataset splits: `{'train': train, 'valid': valid, 'valid_labels': valid_labels,
+ 'test': test, 'test_labels': test_labels}`.
+ Each split containing a dataset is a ndarray of shape (n, 3).
+ The labels are a ndarray of shape (n).
- Examples
+ Example
-------
>>> from ampligraph.datasets import load_wn11
@@ -801,63 +1029,73 @@ def load_wn11(check_md5hash=False, clean_unseen=True, add_reciprocal_rels=False)
"""
wn11 = DatasetMetadata(
- dataset_name='wordnet11',
- filename='wordnet11.zip',
- url='https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/wordnet11.zip',
- train_name='train.txt',
- valid_name='dev.txt',
- test_name='test.txt',
- train_checksum='2429c672c89e33ad4fa8e1a3ade416e4',
- valid_checksum='87bf86e225e79294a2524089614b96aa',
- test_checksum='24113b464f8042c339e3e6833c1cebdf'
+ dataset_name="wordnet11",
+ filename="wordnet11.zip",
+ url="https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/wordnet11.zip",
+ train_name="train.txt",
+ valid_name="dev.txt",
+ test_name="test.txt",
+ train_checksum="2429c672c89e33ad4fa8e1a3ade416e4",
+ valid_checksum="87bf86e225e79294a2524089614b96aa",
+ test_checksum="24113b464f8042c339e3e6833c1cebdf",
)
- dataset = _load_dataset(wn11, data_home=None,
- check_md5hash=check_md5hash,
- add_reciprocal_rels=add_reciprocal_rels)
+ dataset = _load_dataset(
+ wn11,
+ data_home=None,
+ check_md5hash=check_md5hash,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
- valid_labels = dataset['valid'][:, 3]
- test_labels = dataset['test'][:, 3]
+ valid_labels = dataset["valid"][:, 3]
+ test_labels = dataset["test"][:, 3]
- dataset['valid'] = dataset['valid'][:, 0:3]
- dataset['test'] = dataset['test'][:, 0:3]
+ dataset["valid"] = dataset["valid"][:, 0:3]
+ dataset["test"] = dataset["test"][:, 0:3]
- dataset['valid_labels'] = valid_labels == '1'
- dataset['test_labels'] = test_labels == '1'
+ dataset["valid_labels"] = valid_labels == "1"
+ dataset["test_labels"] = test_labels == "1"
if clean_unseen:
- clean_dataset, valid_idx, test_idx = _clean_data(dataset, return_idx=True)
- clean_dataset['valid_labels'] = dataset['valid_labels'][valid_idx]
- clean_dataset['test_labels'] = dataset['test_labels'][test_idx]
+ clean_dataset, valid_idx, test_idx = _clean_data(
+ dataset, return_idx=True
+ )
+ clean_dataset["valid_labels"] = dataset["valid_labels"][valid_idx]
+ clean_dataset["test_labels"] = dataset["test_labels"][test_idx]
return clean_dataset
else:
return dataset
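A quick check of the note above on the positive base rate: the label splits are boolean arrays aligned with the corresponding triple splits, so the rate is simply their mean:

.. code-block:: python

    import numpy as np
    from ampligraph.datasets import load_wn11

    X = load_wn11()
    assert X["valid"].shape[0] == X["valid_labels"].shape[0]

    # Fraction of positive (True) statements in each labelled split;
    # both values should be close to 0.5, as stated in the note.
    print(np.mean(X["valid_labels"]), np.mean(X["test_labels"]))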
-def load_fb13(check_md5hash=False, clean_unseen=True, add_reciprocal_rels=False):
- """Load the Freebase13 (FB13) dataset
+def load_fb13(
+ check_md5hash=False, clean_unseen=True, add_reciprocal_rels=False
+):
+ """Load the Freebase13 (FB13) dataset.
FB13 is a subset of Freebase :cite:`bollacker2008freebase`
and was initially presented in
`Reasoning With Neural Tensor Networks for Knowledge Base Completion` :cite:`socher2013reasoning`.
- FB13 dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
- If ``AMPLIGRAPH_DATA_HOME`` is not set the the default ``~/ampligraph_datasets`` is checked.
+ .. note::
+ FB13 also provides positive and negative labels for the triples in the validation and test sets.
+ The positive base rate is close to 50%.
+ FB13 dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
+ If ``AMPLIGRAPH_DATA_HOME`` is not set, the default ``~/ampligraph_datasets`` is checked.
If the dataset is not found at either location, it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
or ``~/ampligraph_datasets``.
- It is divided in three splits:
+ This dataset is divided in three splits:
- - ``train``
- - ``valid``
- - ``test``
+ - `train`: 316,232 triples
+ - `valid`: 11,816 triples
+ - `test`: 47,464 triples
Both the validation and test splits are associated with labels (binary ndarrays),
with `True` for positive statements and `False` for negatives:
- - ``valid_labels``
- - ``test_labels``
+ - `valid_labels`
+ - `test_labels`
========= ========= ========== ========== ======== ======== ============ ===========
Dataset Train Valid Pos Valid Neg Test Pos Test Neg Entities Relations
@@ -867,26 +1105,26 @@ def load_fb13(check_md5hash=False, clean_unseen=True, add_reciprocal_rels=False)
Parameters
----------
- check_md5hash : boolean
- If ``True`` check the md5hash of the files. Defaults to ``False``.
+ check_md5hash: bool
+ If `True` check the md5hash of the files (default: `False`).
- clean_unseen : bool
- If ``True``, filters triples in validation and test sets that include entities not present in the training set.
+ clean_unseen: bool
+ If `True`, filters triples in validation and test sets that include entities not present in the training set.
- add_reciprocal_rels : bool
+ add_reciprocal_rels: bool
Flag which specifies whether to add reciprocal relations. For every <s, p, o> in the dataset
- this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s>. (default: False).
+ this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s> (default: `False`).
Returns
-------
- splits : dict
+ splits: dict
The dataset splits: {'train': train, 'valid': valid, 'valid_labels': valid_labels,
'test': test, 'test_labels': test_labels}.
- Each split containing a dataset is an ndarray of shape [n, 3].
- The labels are ndarray of shape [n].
+ Each split containing a dataset is a ndarray of shape (n, 3).
+ The labels are a ndarray of shape (n).
- Examples
+ Example
-------
>>> from ampligraph.datasets import load_fb13
@@ -898,35 +1136,39 @@ def load_fb13(check_md5hash=False, clean_unseen=True, add_reciprocal_rels=False)
"""
fb13 = DatasetMetadata(
- dataset_name='freebase13',
- filename='freebase13.zip',
- url='https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/freebase13.zip',
- train_name='train.txt',
- valid_name='dev.txt',
- test_name='test.txt',
- train_checksum='9099ebcd85ab3ce723cfaaf34f74dceb',
- valid_checksum='c4ef7b244baa436a97c2a5e57d4ba7ed',
- test_checksum='f9af2eac7c5a86996c909bdffd295528'
+ dataset_name="freebase13",
+ filename="freebase13.zip",
+ url="https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/freebase13.zip",
+ train_name="train.txt",
+ valid_name="dev.txt",
+ test_name="test.txt",
+ train_checksum="9099ebcd85ab3ce723cfaaf34f74dceb",
+ valid_checksum="c4ef7b244baa436a97c2a5e57d4ba7ed",
+ test_checksum="f9af2eac7c5a86996c909bdffd295528",
)
- dataset = _load_dataset(fb13,
- data_home=None,
- check_md5hash=check_md5hash,
- add_reciprocal_rels=add_reciprocal_rels)
+ dataset = _load_dataset(
+ fb13,
+ data_home=None,
+ check_md5hash=check_md5hash,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
- valid_labels = dataset['valid'][:, 3]
- test_labels = dataset['test'][:, 3]
+ valid_labels = dataset["valid"][:, 3]
+ test_labels = dataset["test"][:, 3]
- dataset['valid'] = dataset['valid'][:, 0:3]
- dataset['test'] = dataset['test'][:, 0:3]
+ dataset["valid"] = dataset["valid"][:, 0:3]
+ dataset["test"] = dataset["test"][:, 0:3]
- dataset['valid_labels'] = valid_labels == '1'
- dataset['test_labels'] = test_labels == '1'
+ dataset["valid_labels"] = valid_labels == "1"
+ dataset["test_labels"] = test_labels == "1"
if clean_unseen:
- clean_dataset, valid_idx, test_idx = _clean_data(dataset, return_idx=True)
- clean_dataset['valid_labels'] = dataset['valid_labels'][valid_idx]
- clean_dataset['test_labels'] = dataset['test_labels'][test_idx]
+ clean_dataset, valid_idx, test_idx = _clean_data(
+ dataset, return_idx=True
+ )
+ clean_dataset["valid_labels"] = dataset["valid_labels"][valid_idx]
+ clean_dataset["test_labels"] = dataset["test_labels"][test_idx]
return clean_dataset
else:
return dataset
@@ -942,49 +1184,61 @@ def load_all_datasets(check_md5hash=False):
load_fb13(check_md5hash)
-def load_from_rdf(folder_name, file_name, rdf_format='nt', data_home=None, add_reciprocal_rels=False):
- """Load an RDF file
+def load_from_rdf(
+ folder_name,
+ file_name,
+ rdf_format="nt",
+ data_home=None,
+ add_reciprocal_rels=False,
+):
+ """Load an RDF file.
Loads an RDF knowledge graph using rdflib_ APIs.
- Multiple RDF serialization formats are supported (nt, ttl, rdf/xml, etc).
+ Multiple RDF serialization formats are supported (`nt`, `ttl`, `rdf`/`xml`, etc).
The entire graph will be loaded in memory, and converted into an rdflib `Graph` object.
.. _rdflib: https://rdflib.readthedocs.io/
.. warning::
Large RDF graphs should be serialized to ntriples beforehand and loaded with ``load_from_ntriples()`` instead.
+ ``load_from_ntriples()`` is indeed faster by orders of magnitude.
- .. note::
- It is recommended to use :meth:`ampligraph.evaluation.train_test_split_no_unseen` to split custom
- knowledge graphs into train, validation, and test sets. Using this function will lead to validation, test sets
- that do not include triples with entities that do not occur in the training set.
+ .. hint::
+ To split a generic knowledge graphs into **training**, **validation**, and **test** sets do not use the above
+ function, but rather :meth:`~ampligraph.evaluation.protocol.train_test_split_no_unseen`: this will return
+ validation and test sets not including triples with entities not present in the training set.
Parameters
----------
folder_name: str
Base folder where the file is stored.
- file_name : str
+ file_name: str
File name.
- rdf_format : str
- The RDF serialization format (nt, ttl, rdf/xml - see rdflib documentation).
- data_home : str
+ rdf_format: str
+ The RDF serialization format (`nt`, `ttl`, `rdf`/`xml` - see rdflib documentation).
+ data_home: str
The path to the folder that contains the datasets.
- add_reciprocal_rels : bool
+ add_reciprocal_rels: bool
Flag which specifies whether to add reciprocal relations. For every <s, p, o> in the dataset
- this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s>. (default: False).
+ this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s> (default: `False`).
Returns
-------
- triples : ndarray , shape [n, 3]
- the actual triples of the file.
+ triples: ndarray, shape (n, 3)
+ The actual triples of the file.
"""
- logger.debug('Loading rdf data from {}.'.format(file_name))
+ logger.debug("Loading rdf data from {}.".format(file_name))
data_home = _get_data_home(data_home)
from rdflib import Graph
+
g = Graph()
- g.parse(os.path.join(data_home, folder_name, file_name), format=rdf_format, publicID='http://test#')
+ g.parse(
+ os.path.join(data_home, folder_name, file_name),
+ format=rdf_format,
+ publicID="http://test#",
+ )
triples = pd.DataFrame(np.array(g))
triples = triples.drop_duplicates()
if add_reciprocal_rels:
@@ -993,12 +1247,14 @@ def load_from_rdf(folder_name, file_name, rdf_format='nt', data_home=None, add_r
return triples.values
-def load_from_ntriples(folder_name, file_name, data_home=None, add_reciprocal_rels=False):
- """Load RDF ntriples
+def load_from_ntriples(
+ folder_name, file_name, data_home=None, add_reciprocal_rels=False
+):
+ """Load a dataset of RDF ntriples.
Loads an RDF knowledge graph serialized as ntriples, without building an RDF graph in memory.
- This function should be preferred over ``load_from_rdf()``,
- since it does not load the graph into an rdflib model (and it is therefore faster by order of magnitudes).
+ This function should be preferred over ``load_from_rdf()``, since it does not load the graph into an rdflib
+ model (and it is therefore faster by orders of magnitude).
Nevertheless, it requires a ntriples_ serialization as in the example below:
.. _ntriples: https://www.w3.org/TR/n-triples/.
@@ -1008,38 +1264,40 @@ def load_from_ntriples(folder_name, file_name, data_home=None, add_reciprocal_re
_:alice <http://xmlns.com/foaf/0.1/knows> _:bob .
_:bob <http://xmlns.com/foaf/0.1/knows> _:alice .
- .. note::
- It is recommended to use :meth:`ampligraph.evaluation.train_test_split_no_unseen` to split custom
- knowledge graphs into train, validation, and test sets. Using this function will lead to validation, test sets
- that do not include triples with entities that do not occur in the training set.
+ .. hint::
+ To split a generic knowledge graphs into **training**, **validation**, and **test** sets do not use the above
+ function, but rather :meth:`~ampligraph.evaluation.protocol.train_test_split_no_unseen`: this will return
+ validation and test sets not including triples with entities not present in the training set.
Parameters
----------
folder_name: str
- base folder where the file is stored.
- file_name : str
- file name
- data_home : str
+ Base folder where the file is stored.
+ file_name: str
+ File name.
+ data_home: str
The path to the folder that contains the datasets.
- add_reciprocal_rels : bool
+ add_reciprocal_rels: bool
Flag which specifies whether to add reciprocal relations. For every <s, p, o> in the dataset
- this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s>. (default: False).
+ this creates a corresponding triple with reciprocal relation <o, p_reciprocal, s> (default: `False`).
Returns
-------
- triples : ndarray , shape [n, 3]
- the actual triples of the file.
+ triples: ndarray, shape (n, 3)
+ The actual triples of the file.
"""
- logger.debug('Loading rdf ntriples from {}.'.format(file_name))
+ logger.debug("Loading rdf ntriples from {}.".format(file_name))
data_home = _get_data_home(data_home)
- df = pd.read_csv(os.path.join(data_home, folder_name, file_name),
- sep=r'\s+',
- header=None,
- names=None,
- dtype=str,
- usecols=[0, 1, 2])
+ df = pd.read_csv(
+ os.path.join(data_home, folder_name, file_name),
+ sep=" ",
+ header=None,
+ names=None,
+ dtype=str,
+ usecols=[0, 1, 2],
+ )
# Remove trailing full stop (if present)
df[2] = df[2].apply(lambda x: x.rsplit(".", 1)[0])
@@ -1050,84 +1308,106 @@ def load_from_ntriples(folder_name, file_name, data_home=None, add_reciprocal_re
return df.values
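A minimal sketch of loading a small ntriples file with the function above; the ``toy_graph`` folder and ``toy.nt`` file are made up for illustration, and the file content mirrors the ntriples snippet in the docstring:

.. code-block:: python

    import os
    from ampligraph.datasets import load_from_ntriples

    data_home = os.path.expanduser("~/ampligraph_datasets")
    os.makedirs(os.path.join(data_home, "toy_graph"), exist_ok=True)

    # Two ntriples statements, one per line, terminated by a full stop.
    with open(os.path.join(data_home, "toy_graph", "toy.nt"), "w") as f:
        f.write("_:alice <http://xmlns.com/foaf/0.1/knows> _:bob .\n")
        f.write("_:bob <http://xmlns.com/foaf/0.1/knows> _:alice .\n")

    X = load_from_ntriples("toy_graph", "toy.nt", data_home=data_home)
    # X is a (2, 3) ndarray of subject / predicate / object strings.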
-def generate_focusE_dataset_splits(dataset, split_test_into_top_bottom=True, split_threshold=0.1):
- """ Creates the dataset splits for training models with FocusE layers
-
+# FocusE
+def generate_focusE_dataset_splits(
+ dataset, split_test_into_top_bottom=True, split_threshold=0.1
+):
+ """Creates the dataset splits for training models with FocusE layers
+
Parameters
----------
- dataset : dict
- dictionary of train, test, valid datasets of size (n,4) - where the first 3 cols are s, p, o and
- 4th is the numeric value associated with the triple
-
+ dataset: dict
+ Dictionary of train, valid, and test datasets of size (n, m), with m > 3. The first 3 columns are `subject`,
+ `predicate`, and `object`; the remaining columns hold the numeric values (potentially more than one) associated with each triple.
+
split_test_into_top_bottom: bool
- Splits the test set by numeric values and returns test_top_split and test_bottom_split by splitting
- based on sorted numeric values and returning top and bottom k% triples, where k is specified by
- `split_threshold` argument
-
+ If `True`, sorts the test set by numeric values and also returns the `test_topk` and `test_bottomk` splits,
+ containing the top and bottom k*100% of test triples, where `k` is specified by the
+ `split_threshold` argument.
+
split_threshold: float
- specifies the top and bottom percentage of triples to return
-
+ Specifies the top and bottom percentage of triples to return.
+
Returns
-------
- splits : dict
- The dataset splits: {'train': train,
- 'train_numeric_values': train_numeric_values,
- 'valid': valid,
+ splits: dict
+ The dataset splits: `{'train': train,
+ 'train_numeric_values': train_numeric_values,
+ 'valid': valid,
'valid_numeric_values': valid_numeric_values,
- 'test': test,
+ 'test': test,
'test_numeric_values': test_numeric_values,
- 'test_topk': test_topk,
+ 'test_topk': test_topk,
'test_topk_numeric_values': test_topk_numeric_values,
- 'test_bottomk': test_bottomk,
- 'test_bottomk_numeric_values': test_bottomk_numeric_values}.
- Each numeric value split contains numeric values associated with corresponding dataset split and
- is a ndarray of shape [n, 1].
- Each dataset split is a ndarray of shape [n,3]
- The topk and bottomk splits are only returned when split_test_into_top_bottom is set to True
+ 'test_bottomk': test_bottomk,
+ 'test_bottomk_numeric_values': test_bottomk_numeric_values}`.
+ Each numeric value split contains numeric values associated with the corresponding dataset split and
+ is a ndarray of shape (n, 1).
+ Each dataset split is a ndarray of shape (n,3).
+ The `topk` and `bottomk` splits are only returned when `split_test_into_top_bottom` is set to `True`; they contain
+ the triples with the highest and lowest associated numeric edge values, respectively. They are typically used at
+ evaluation time to check whether a model ranks the `_topk` triples as high as possible and the `_bottomk`
+ triples as low as possible.
"""
- dataset['train_numeric_values'] = dataset['train'][:, 3].astype(np.float32)
- dataset['valid_numeric_values'] = dataset['valid'][:, 3].astype(np.float32)
- dataset['test_numeric_values'] = dataset['test'][:, 3].astype(np.float32)
-
- dataset['train'] = dataset['train'][:, 0:3]
- dataset['valid'] = dataset['valid'][:, 0:3]
- dataset['test'] = dataset['test'][:, 0:3]
-
- sorted_indices = np.argsort(dataset['test_numeric_values'])
- dataset['test'] = dataset['test'][sorted_indices]
- dataset['test_numeric_values'] = dataset['test_numeric_values'][sorted_indices]
-
+ dataset["train_numeric_values"] = dataset["train"][:, 3:].astype(
+ np.float32
+ )
+ dataset["valid_numeric_values"] = dataset["valid"][:, 3:].astype(
+ np.float32
+ )
+ dataset["test_numeric_values"] = dataset["test"][:, 3:].astype(np.float32)
+
+ dataset["train"] = dataset["train"][:, 0:3]
+ dataset["valid"] = dataset["valid"][:, 0:3]
+ dataset["test"] = dataset["test"][:, 0:3]
+
+ sorted_indices = np.squeeze(
+ np.argsort(dataset["test_numeric_values"], axis=0)
+ )
+ dataset["test"] = dataset["test"][sorted_indices]
+ dataset["test_numeric_values"] = dataset["test_numeric_values"][
+ sorted_indices
+ ]
+
if split_test_into_top_bottom:
- split_threshold = int(split_threshold * dataset['test'].shape[0])
-
- dataset['test_bottomk'] = dataset['test'][:split_threshold]
- dataset['test_bottomk_numeric_values'] = dataset['test_numeric_values'][:split_threshold]
-
- dataset['test_topk'] = dataset['test'][-split_threshold:]
- dataset['test_topk_numeric_values'] = dataset['test_numeric_values'][-split_threshold:]
-
+ split_threshold = int(split_threshold * dataset["test"].shape[0])
+
+ dataset["test_bottomk"] = dataset["test"][:split_threshold]
+ dataset["test_bottomk_numeric_values"] = dataset[
+ "test_numeric_values"
+ ][:split_threshold]
+
+ dataset["test_topk"] = dataset["test"][-split_threshold:]
+ dataset["test_topk_numeric_values"] = dataset["test_numeric_values"][
+ -split_threshold:
+ ]
+
return dataset
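A self-contained sketch of the top/bottom splitting logic on synthetic data; the import path is an assumption (the function is defined in this module), and the numbers only illustrate that ``split_threshold=0.1`` keeps 10% of the test triples at each end of the sorted numeric values:

.. code-block:: python

    import numpy as np
    # Assumed import path; the function is defined in this datasets module.
    from ampligraph.datasets import generate_focusE_dataset_splits

    rng = np.random.default_rng(0)

    def toy_split(n):
        # Four columns: subject, predicate, object, numeric value.
        spo = np.array([[f"s{i}", "p", f"o{i}"] for i in range(n)], dtype=object)
        w = rng.random((n, 1)).astype(object)
        return np.concatenate([spo, w], axis=1)

    dataset = {"train": toy_split(100), "valid": toy_split(20), "test": toy_split(50)}
    splits = generate_focusE_dataset_splits(dataset,
                                            split_test_into_top_bottom=True,
                                            split_threshold=0.1)

    # 10% of the 50 test triples -> 5 triples at each extreme of the numeric values.
    assert splits["test_topk"].shape == (5, 3)
    assert splits["test_bottomk"].shape == (5, 3)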
-def load_onet20k(check_md5hash=False, clean_unseen=True, split_test_into_top_bottom=True, split_threshold=0.1):
- """Load the O*NET20K dataset
+def load_onet20k(
+ check_md5hash=False,
+ clean_unseen=True,
+ split_test_into_top_bottom=True,
+ split_threshold=0.1,
+):
+ """Load the O*NET20K dataset.
O*NET20K was originally proposed in :cite:`pai2021learning`.
- It a subset of `O*NET `_, a dataset that includes job descriptions, skills
- and labeled, binary relations between such concepts. Each triple is labeled with a numeric value that
- indicates the importance of that link.
-
- ONET*20K dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
- If ``AMPLIGRAPH_DATA_HOME`` is not set the the default ``~/ampligraph_datasets`` is checked.
+ It is a subset of `O*NET <https://www.onetonline.org/>`_, a dataset that includes job descriptions, skills,
+ and labeled binary relations between such concepts. Each triple is labeled with a numeric value that
+ indicates the importance of that link.
+ O*NET20K dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
+ If ``AMPLIGRAPH_DATA_HOME`` is not set, the default ``~/ampligraph_datasets`` is checked.
If the dataset is not found at either location, it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
or ``~/ampligraph_datasets``.
- It is divided in three splits:
+ This dataset is divided in three splits:
- - ``train``
- - ``valid``
- - ``test``
+ - `train`: 461,932 triples
+ - `valid`: 850 triples
+ - `test`: 2,000 triples
Each triple in these splits is associated to a numeric value which represents the importance/relevance of
the link.
@@ -1140,27 +1420,24 @@ def load_onet20k(check_md5hash=False, clean_unseen=True, split_test_into_top_bot
Parameters
----------
- check_md5hash : boolean
- If ``True`` check the md5hash of the files. Defaults to ``False``.
-
- clean_unseen : bool
- If ``True``, filters triples in validation and test sets that include entities not present in the training
+ check_md5hash: bool
+ If `True` check the md5hash of the files (default: `False`).
+ clean_unseen: bool
+ If `True`, filters triples in validation and test sets that include entities not present in the training
set.
-
+
split_test_into_top_bottom: bool
- Splits the test set by numeric values and returns test_top_split and test_bottom_split by splitting based
- on sorted numeric values and returning top and bottom k% triples, where k is specified by `split_threshold`
- argument
-
+ Splits the test set by numeric values and returns `test_top_split` and `test_bottom_split` by splitting based
+ on sorted numeric values and returning top and bottom k% triples, where `k` is specified by `split_threshold`
+ argument.
+
split_threshold: float
- specifies the top and bottom percentage of triples to return
-
+ Specifies the top and bottom percentage of triples to return.
Returns
-------
-
- splits : dict
- The dataset splits: {'train': train,
+ splits: dict
+ The dataset splits: `{'train': train,
'valid': valid,
'test': test,
'test_topk': test_topk,
@@ -1169,60 +1446,70 @@ def load_onet20k(check_md5hash=False, clean_unseen=True, split_test_into_top_bot
'valid_numeric_values':valid_numeric_values,
'test_numeric_values': test_numeric_values,
'test_topk_numeric_values': test_topk_numeric_values,
- 'test_bottomk_numeric_values': test_bottomk_numeric_values}.
-
+ 'test_bottomk_numeric_values': test_bottomk_numeric_values}`.
Each ``*_numeric_values`` split contains numeric values associated to the corresponding dataset split and
- is a ndarray of shape [n].
-
- Each dataset split is a ndarray of shape [n,3].
-
- The ``*_topk`` and ``*_bottomk`` splits are only returned when ``split_test_into_top_bottom=True``.
-
- Examples
+ is a ndarray of shape (n).
+ Each dataset split is a ndarray of shape (n,3).
+ The ``*_topk`` and ``*_bottomk`` splits are only returned when ``split_test_into_top_bottom=True``; they contain
+ the triples with the highest and lowest associated numeric edge values, respectively. They are typically used at
+ evaluation time to check whether a model ranks the ``*_topk`` triples as high as possible and the ``*_bottomk``
+ triples as low as possible.
+
+ Example
-------
-
>>> from ampligraph.datasets import load_onet20k
>>> X = load_onet20k()
+ >>> X["train"][0]
+ ['Job_27-1021.00' 'has_ability_LV' '1.A.1.b.2']
+ >>> X['train_numeric_values'][0]
+ [0.6257143]
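+
+ An illustrative call requesting the additional top/bottom test splits (key names follow the Returns
+ section above; outputs omitted):
+
+ >>> X = load_onet20k(split_test_into_top_bottom=True, split_threshold=0.1)
+ >>> test_top, test_bottom = X["test_topk"], X["test_bottomk"]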
"""
onet20k = DatasetMetadata(
- dataset_name='onet20k',
- filename='onet20k.zip',
- url='https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/onet20k.zip',
- train_name='train.tsv',
- valid_name='valid.tsv',
- test_name='test.tsv',
- train_checksum='516220427a9a18516fd7a804a6944d64',
- valid_checksum='d7806951ac3d916c5c5a0304eea064d2',
- test_checksum='e5baec19037cb0bddc5a2fe3c0f4445a'
+ dataset_name="onet20k",
+ filename="onet20k.zip",
+ url="https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/onet20k.zip",
+ train_name="train.tsv",
+ valid_name="valid.tsv",
+ test_name="test.tsv",
+ train_checksum="516220427a9a18516fd7a804a6944d64",
+ valid_checksum="d7806951ac3d916c5c5a0304eea064d2",
+ test_checksum="e5baec19037cb0bddc5a2fe3c0f4445a",
+ )
+
+ dataset = _load_dataset(
+ onet20k, data_home=None, check_md5hash=check_md5hash
)
- dataset = _load_dataset(onet20k, data_home=None,
- check_md5hash=check_md5hash)
-
if clean_unseen:
dataset = _clean_data(dataset)
-
- return generate_focusE_dataset_splits(dataset, split_test_into_top_bottom, split_threshold)
+ return generate_focusE_dataset_splits(
+ dataset, split_test_into_top_bottom, split_threshold
+ )
-def load_ppi5k(check_md5hash=False, clean_unseen=True, split_test_into_top_bottom=True, split_threshold=0.1):
- """Load the PPI5K dataset
+
+def load_ppi5k(
+ check_md5hash=False,
+ clean_unseen=True,
+ split_test_into_top_bottom=True,
+ split_threshold=0.1,
+):
+ """Load the PPI5K dataset.
Originally proposed in :cite:`chen2019embedding`, PPI5K is a subset of the protein-protein
interactions (PPI) knowledge graph :cite:`PPI`. Numeric values represent the confidence of the link
based on existing scientific literature evidence.
PPI5K is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
- If ``AMPLIGRAPH_DATA_HOME`` is not set the the default ``~/ampligraph_datasets`` is checked.
-
+ If ``AMPLIGRAPH_DATA_HOME`` is not set, the default ``~/ampligraph_datasets`` is checked.
If the dataset is not found at either location, it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
or ``~/ampligraph_datasets``.
It is divided into three splits:
- - ``train``
- - ``valid``
- - ``test``
+ - `train`: 230,929 triples
+ - `valid`: 19,017 triples
+ - `test`: 21,720 triples
Each triple in these splits is associated to a numeric value which models additional information on the
fact (importance, relevance of the link).
@@ -1230,35 +1517,31 @@ def load_ppi5k(check_md5hash=False, clean_unseen=True, split_test_into_top_botto
========= ========= ======== =========== ========== ===========
Dataset Train Valid Test Entities Relations
========= ========= ======== =========== ========== ===========
- PPI5K 230929 19017 21720 4999 7
+ PPI5K 230,929 19,017 21,720 4,999 7
========= ========= ======== =========== ========== ===========
Parameters
----------
- check_md5hash : boolean
- If ``True`` check the md5hash of the files. Defaults to ``False``.
-
- clean_unseen : bool
- If ``True``, filters triples in validation and test sets that include entities not present in the training
+ check_md5hash: bool
+ If `True`, check the md5hash of the files (default: `False`).
+ clean_unseen: bool
+ If `True`, filters triples in validation and test sets that include entities not present in the training
set.
-
+
split_test_into_top_bottom: bool
- When set to ``True``, the function also returns subsets of the test set that includes only the top-k or
- bottom-k numeric-enriched triples. splits ``test_topk``, ``test_bottomk`` and their
+ When set to `True`, the function also returns subsets of the test set that include only the top-k% or
+ bottom-k% numeric-enriched triples: the splits `test_topk`, `test_bottomk` and their
numeric values. Such splits are generated by sorting the test set by numeric values and returning the
- test_top_split and test_bottom_split by splitting based
- on sorted numeric values and returning top and bottom k% triples, where 'k' is specified by the
- ``split_threshold`` argument.
-
+ top and bottom k% triples, where `k` is specified by the ``split_threshold`` argument.
+
split_threshold: float
- specifies the top and bottom percentage of triples to return
-
+ Specifies the top and bottom percentage of triples to return.
Returns
-------
-
- splits : dict
- The dataset splits: {'train': train,
+ splits: dict
+ The dataset splits: `{'train': train,
'valid': valid,
'test': test,
'test_topk': test_topk,
@@ -1267,60 +1550,68 @@ def load_ppi5k(check_md5hash=False, clean_unseen=True, split_test_into_top_botto
'valid_numeric_values':valid_numeric_values,
'test_numeric_values': test_numeric_values,
'test_topk_numeric_values': test_topk_numeric_values,
- 'test_bottomk_numeric_values': test_bottomk_numeric_values}.
-
+ 'test_bottomk_numeric_values': test_bottomk_numeric_values}`.
Each ``*_numeric_values`` split contains numeric values associated to the corresponding dataset split and
- is a ndarray of shape [n].
-
- Each dataset split is a ndarray of shape [n,3].
-
- The ``*_topk`` and ``*_bottomk`` splits are only returned when ``split_test_into_top_bottom=True``.
-
- Examples
+ is an ndarray of shape (n,).
+ Each dataset split is an ndarray of shape (n, 3).
+ The ``*_topk`` and ``*_bottomk`` splits are only returned when ``split_test_into_top_bottom=True`` and contain
+ the triples with the highest and lowest associated numeric edge values, respectively. They are typically used
+ at evaluation time to check whether a model ranks the `_topk` triples as high as possible and the `_bottomk`
+ triples as low as possible.
+
+ Example
-------
-
>>> from ampligraph.datasets import load_ppi5k
>>> X = load_ppi5k()
+ >>> X["train"][0]
+ ['4001' '5' '4176']
+ >>> X['train_numeric_values'][0]
+ [0.329]
"""
ppi5k = DatasetMetadata(
- dataset_name='ppi5k',
- filename='ppi5k.zip',
- url='https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/ppi5k.zip',
- train_name='train.tsv',
- valid_name='valid.tsv',
- test_name='test.tsv',
- train_checksum='d8b54de3482c0d043118cbd05f2666cf',
- valid_checksum='2bd094118f4be1f4f6d6a1d4707271c1',
- test_checksum='7e6e345f496ed9a0cc58b91d4877ddd6'
+ dataset_name="ppi5k",
+ filename="ppi5k.zip",
+ url="https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/ppi5k.zip",
+ train_name="train.tsv",
+ valid_name="valid.tsv",
+ test_name="test.tsv",
+ train_checksum="d8b54de3482c0d043118cbd05f2666cf",
+ valid_checksum="2bd094118f4be1f4f6d6a1d4707271c1",
+ test_checksum="7e6e345f496ed9a0cc58b91d4877ddd6",
)
- dataset = _load_dataset(ppi5k, data_home=None,
- check_md5hash=check_md5hash)
-
+ dataset = _load_dataset(ppi5k, data_home=None, check_md5hash=check_md5hash)
+
if clean_unseen:
dataset = _clean_data(dataset)
- return generate_focusE_dataset_splits(dataset, split_test_into_top_bottom, split_threshold)
+ return generate_focusE_dataset_splits(
+ dataset, split_test_into_top_bottom, split_threshold
+ )
-def load_nl27k(check_md5hash=False, clean_unseen=True, split_test_into_top_bottom=True, split_threshold=0.1):
- """Load the NL27K dataset
+def load_nl27k(
+ check_md5hash=False,
+ clean_unseen=True,
+ split_test_into_top_bottom=True,
+ split_threshold=0.1,
+):
+ """Load the NL27K dataset.
NL27K was originally proposed in :cite:`chen2019embedding`. It is a subset of the Never Ending Language
Learning (NELL) dataset :cite:`mitchell2018never`, which collects data from web pages.
Numeric values on triples represent link uncertainty.
NL27K is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
- If ``AMPLIGRAPH_DATA_HOME`` is not set the the default ``~/ampligraph_datasets`` is checked.
-
+ If ``AMPLIGRAPH_DATA_HOME`` is not set, the default ``~/ampligraph_datasets`` is checked.
If the dataset is not found at either location, it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
or ``~/ampligraph_datasets``.
It is divided into three splits:
- - ``train``
- - ``valid``
- - ``test``
+ - `train`: 149,100 triples
+ - `valid`: 12,274 triples
+ - `test`: 14,026 triples
Each triple in these splits is associated to a numeric value which represents the importance/relevance of
the link.
@@ -1328,32 +1619,29 @@ def load_nl27k(check_md5hash=False, clean_unseen=True, split_test_into_top_botto
========= ========= ======== =========== ========== ===========
Dataset Train Valid Test Entities Relations
========= ========= ======== =========== ========== ===========
- NL27K 149100 12274 14026 27221 405
+ NL27K 149,100 12,274 14,026 27,221 405
========= ========= ======== =========== ========== ===========
Parameters
----------
- check_md5hash : boolean
- If ``True`` check the md5hash of the files. Defaults to ``False``.
-
- clean_unseen : bool
- If ``True``, filters triples in validation and test sets that include entities not present in the training
+ check_md5hash: bool
+ If `True`, check the md5hash of the files (default: `False`).
+ clean_unseen: bool
+ If `True`, filters triples in validation and test sets that include entities not present in the training
set.
-
+
split_test_into_top_bottom: bool
- Splits the test set by numeric values and returns test_top_split and test_bottom_split by splitting based
- on sorted numeric values and returning top and bottom k% triples, where k is specified by `split_threshold`
- argument
-
+ Whether to split the test set by numeric values and return the additional splits `test_topk` and
+ `test_bottomk`, containing the top and bottom k% of test triples after sorting by numeric value,
+ where `k` is specified by the `split_threshold` argument.
+
split_threshold: float
- specifies the top and bottom percentage of triples to return
-
+ Specifies the top and bottom percentage of triples to return.
Returns
-------
-
- splits : dict
- The dataset splits: {'train': train,
+ splits: dict
+ The dataset splits: `{'train': train,
'valid': valid,
'test': test,
'test_topk': test_topk,
@@ -1362,60 +1650,68 @@ def load_nl27k(check_md5hash=False, clean_unseen=True, split_test_into_top_botto
'valid_numeric_values':valid_numeric_values,
'test_numeric_values': test_numeric_values,
'test_topk_numeric_values': test_topk_numeric_values,
- 'test_bottomk_numeric_values': test_bottomk_numeric_values}.
-
+ 'test_bottomk_numeric_values': test_bottomk_numeric_values}`.
Each ``*_numeric_values`` split contains numeric values associated to the corresponding dataset split and
- is a ndarray of shape [n].
-
- Each dataset split is a ndarray of shape [n,3].
-
- The ``*_topk`` and ``*_bottomk`` splits are only returned when ``split_test_into_top_bottom=True``.
-
- Examples
+ is an ndarray of shape (n,).
+ Each dataset split is an ndarray of shape (n, 3).
+ The ``*_topk`` and ``*_bottomk`` splits are only returned when ``split_test_into_top_bottom=True`` and contain
+ the triples with the highest and lowest associated numeric edge values, respectively. They are typically used
+ at evaluation time to check whether a model ranks the `_topk` triples as high as possible and the `_bottomk`
+ triples as low as possible.
+
+ Example
-------
-
>>> from ampligraph.datasets import load_nl27k
>>> X = load_nl27k()
+ >>> X["train"][0]
+ ['concept:company:business_review' 'concept:competeswith' 'concept:company:miami_herald001']
+ >>> X['train_numeric_values'][0]
+ [0.859375]
"""
nl27k = DatasetMetadata(
- dataset_name='nl27k',
- filename='nl27k.zip',
- url='https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/nl27k.zip',
- train_name='train.tsv',
- valid_name='valid.tsv',
- test_name='test.tsv',
- train_checksum='d4ce775401d299074d98e046f13e7283',
- valid_checksum='00177fa6b9f5cec18814ee599c02eae3',
- test_checksum='2ba17f29119688d93c9d29ab40f63b3e'
+ dataset_name="nl27k",
+ filename="nl27k.zip",
+ url="https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/nl27k.zip",
+ train_name="train.tsv",
+ valid_name="valid.tsv",
+ test_name="test.tsv",
+ train_checksum="d4ce775401d299074d98e046f13e7283",
+ valid_checksum="00177fa6b9f5cec18814ee599c02eae3",
+ test_checksum="2ba17f29119688d93c9d29ab40f63b3e",
)
- dataset = _load_dataset(nl27k, data_home=None,
- check_md5hash=check_md5hash)
-
+ dataset = _load_dataset(nl27k, data_home=None, check_md5hash=check_md5hash)
+
if clean_unseen:
dataset = _clean_data(dataset)
- return generate_focusE_dataset_splits(dataset, split_test_into_top_bottom, split_threshold)
+ return generate_focusE_dataset_splits(
+ dataset, split_test_into_top_bottom, split_threshold
+ )
-def load_cn15k(check_md5hash=False, clean_unseen=True, split_test_into_top_bottom=True, split_threshold=0.1):
- """Load the CN15K dataset
+def load_cn15k(
+ check_md5hash=False,
+ clean_unseen=True,
+ split_test_into_top_bottom=True,
+ split_threshold=0.1,
+):
+ """Load the CN15K dataset.
CN15K was originally proposed in :cite:`chen2019embedding`, it is a subset of ConceptNet :cite:`CN`,
a common-sense knowledge graph built to represent general human knowledge.
Numeric values on triples represent uncertainty.
CN15k dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
- If ``AMPLIGRAPH_DATA_HOME`` is not set the the default ``~/ampligraph_datasets`` is checked.
-
+ If ``AMPLIGRAPH_DATA_HOME`` is not set, the default ``~/ampligraph_datasets`` is checked.
If the dataset is not found at either location, it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
or ``~/ampligraph_datasets``.
It is divided into three splits:
- - ``train``
- - ``valid``
- - ``test``
+ - `train`: 199,417 triples
+ - `valid`: 16,829 triples
+ - `test`: 19,224 triples
Each triple in these splits is associated to a numeric value which represents the importance/relevance of
the link.
@@ -1423,32 +1719,29 @@ def load_cn15k(check_md5hash=False, clean_unseen=True, split_test_into_top_botto
========= ========= ======== =========== ========== ===========
Dataset Train Valid Test Entities Relations
========= ========= ======== =========== ========== ===========
- CN15K 199417 16829 19224 15000 36
+ CN15K 199,417 16,829 19,224 15,000 36
========= ========= ======== =========== ========== ===========
Parameters
----------
- check_md5hash : boolean
- If ``True`` check the md5hash of the files. Defaults to ``False``.
-
- clean_unseen : bool
- If ``True``, filters triples in validation and test sets that include entities not present in the training
+ check_md5hash: bool
+ If `True`, check the md5hash of the files (default: `False`).
+ clean_unseen: bool
+ If `True`, filters triples in validation and test sets that include entities not present in the training
set.
-
+
split_test_into_top_bottom: bool
- Splits the test set by numeric values and returns test_top_split and test_bottom_split by splitting based
- on sorted numeric values and returning top and bottom k% triples, where k is specified by `split_threshold`
- argument
-
+ Whether to split the test set by numeric values and return the additional splits `test_topk` and
+ `test_bottomk`, containing the top and bottom k% of test triples after sorting by numeric value,
+ where `k` is specified by the `split_threshold` argument.
+
split_threshold: float
- specifies the top and bottom percentage of triples to return
-
+ Specifies the top and bottom percentage of triples to return.
Returns
-------
-
- splits : dict
- The dataset splits: {'train': train,
+ splits: dict
+ The dataset splits: `{'train': train,
'valid': valid,
'test': test,
'test_topk': test_topk,
@@ -1457,37 +1750,315 @@ def load_cn15k(check_md5hash=False, clean_unseen=True, split_test_into_top_botto
'valid_numeric_values':valid_numeric_values,
'test_numeric_values': test_numeric_values,
'test_topk_numeric_values': test_topk_numeric_values,
- 'test_bottomk_numeric_values': test_bottomk_numeric_values}.
-
+ 'test_bottomk_numeric_values': test_bottomk_numeric_values}`.
Each ``*_numeric_values`` split contains numeric values associated to the corresponding dataset split and
- is a ndarray of shape [n].
-
- Each dataset split is a ndarray of shape [n,3].
-
- The ``*_topk`` and ``*_bottomk`` splits are only returned when ``split_test_into_top_bottom=True``.
-
- Examples
+ is an ndarray of shape (n,).
+ Each dataset split is an ndarray of shape (n, 3).
+ The ``*_topk`` and ``*_bottomk`` splits are only returned when ``split_test_into_top_bottom=True`` and contain
+ the triples with the highest and lowest associated numeric edge values, respectively. They are typically used
+ at evaluation time to check whether a model ranks the `_topk` triples as high as possible and the `_bottomk`
+ triples as low as possible.
+
+ Example
-------
-
>>> from ampligraph.datasets import load_cn15k
>>> X = load_cn15k()
+ >>> X["train"][0]
+ ['260' '2' '13895']
+ >>> X['train_numeric_values'][0]
+ [0.8927088]
"""
cn15k = DatasetMetadata(
- dataset_name='cn15k',
- filename='cn15k.zip',
- url='https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/cn15k.zip',
- train_name='train.tsv',
- valid_name='valid.tsv',
- test_name='test.tsv',
- train_checksum='8bf2ecc8f34e7b3b544afc30abaac478',
- valid_checksum='15b63ebd7428a262ad5fe869cc944208',
- test_checksum='29df4b8d24a3d89fc7c1032b9c508112'
+ dataset_name="cn15k",
+ filename="cn15k.zip",
+ url="https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/cn15k.zip",
+ train_name="train.tsv",
+ valid_name="valid.tsv",
+ test_name="test.tsv",
+ train_checksum="8bf2ecc8f34e7b3b544afc30abaac478",
+ valid_checksum="15b63ebd7428a262ad5fe869cc944208",
+ test_checksum="29df4b8d24a3d89fc7c1032b9c508112",
)
- dataset = _load_dataset(cn15k, data_home=None,
- check_md5hash=check_md5hash)
-
+ dataset = _load_dataset(cn15k, data_home=None, check_md5hash=check_md5hash)
+
if clean_unseen:
dataset = _clean_data(dataset)
- return generate_focusE_dataset_splits(dataset, split_test_into_top_bottom, split_threshold)
+ return generate_focusE_dataset_splits(
+ dataset, split_test_into_top_bottom, split_threshold
+ )
+
+
+def _load_xai_fb15k_237_experiment_log(full=False, subset="all"):
+ """Load the XAI FB15k-237 experiment log
+
+ XAI-FB15k-237 is a reduced version of FB15K-237 containing human-readable triples.
+
+ The dataset contains several fields. By default, the returned data frame contains only triples; when
+ ``full=True`` is passed, the full data is returned (reflecting the filtering protocol).
+
+ Fields:
+
+ - predicate,
+ - predicate label,
+ - predicates_description,
+ - subject,
+ - subject_label,
+ - object_label,
+ - object.
+
+ All triples are returned as a 273 x 7 data frame.
+
+ Full Fields:
+
+ *note*: some fields come in 3 forms, marked with X, where X = {1,2,3} corresponds to the 3 triples
+ that were displayed to the annotators for a given predicate.
+
+ - predicate: evaluated predicate,
+ - predicate label: human label for predicate,
+ - predicates_description: human description of what the predicate means,
+ - question triple X: textual form of triple 1 containing predicate,
+ - subject_tripleX: subject of triple X,
+ - object_tripleX: object of triple X,
+ - subject_label_tripleX: human label of subject of triple X,
+ - object_label_tripleX: human label of object of triple X,
+ - avg rank triple X: average rank that the triple obtained among models,
+ - std rank triple X: standard deviation of the rank that the triple obtained among models,
+ - avg O rank triple X: average object rank that the triple obtained among models,
+ - std O rank triple X: standard deviation of the object rank that the triple obtained among models,
+ - avg S rank triple X: average subject rank that the triple obtained among models,
+ - std S rank triple X: standard deviation of the subject rank that the triple obtained among models,
+ - evaluated: summed score of 3 evaluators for a predicate (each evaluator gave a score of 0 - not understandable - or 1 - understandable):
+ 0 - triples with this predicate are not understandable - full agreement between annotators.
+ 1 - triples with this predicate are mostly not understandable - partial agreement between annotators.
+ 2 - triples with this predicate are mostly understandable - partial agreement between annotators.
+ 3 - triples with this predicate are clearly understandable - full agreement between annotators.
+
+ All predicates are returned as 91 x 37 records, each containing 3 triples.
+
+ ============= ========= ==========
+ Dataset Entities Relations
+ ============= ========= ==========
+ XAI-FB15K-237 446 91
+ ============= ========= ==========
+
+
+ Parameters
+ ----------
+ full: bool
+ Whether to return the full dataset or a reduced view with triples only (default: `False`).
+ subset: str
+ Subset of records to be returned (default: `"all"`):
+ - "all" - returns all records,
+ - "clear" - returns only triples which all annotators marked as understandable,
+ - "not clear" - not understandable triples,
+ - "confusing+" - mostly understandable triples,
+ - "confusing-" - mostly not understandable triples.
+
+ Returns
+ -------
+ X: pandas.DataFrame
+ Data frame containing triples (``full=False``) or records with predicates (``full=True``).
+
+ Example
+ -------
+
+ >>> from ampligraph.datasets import _load_xai_fb15k_237_experiment_log
+ >>> X = _load_xai_fb15k_237_experiment_log()
+ >>> X.head(2)
+
+ predicate predicate label predicates_description subject subject_label object_label object
+ 0 /media_common/netflix_genre/titles Titles Titles that have this Genre in Netflix@en /m/07c52 Television Friends /m/030cx
+ 1 /film/film/edited_by Edited by NaN /m/0cc5qkt War Horse Michael Kahn /m/03q8ch
+
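+ An illustrative use of the ``subset`` filter (assumes the same download as above; output omitted):
+
+ >>> # keep only predicates that all annotators marked as understandable
+ >>> X_clear = _load_xai_fb15k_237_experiment_log(full=True, subset="clear")
+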
+ """
+ import requests
+
+ url = "https://ampgraphenc.s3-eu-west-1.amazonaws.com/datasets/xai_fb15k_237.csv"
+
+ r = requests.get(url, allow_redirects=True)
+ open("xai_fb15k_237.csv", "wb").write(r.content)
+
+ mapper = {
+ "all": "all",
+ "clear": 3,
+ "not clear": 0,
+ "confusing+": 2,
+ "confusing-": 1,
+ }
+ if subset != "all":
+ if subset in mapper:
+ X = pd.read_csv("xai_fb15k_237.csv", sep=",")
+ X = X[X["evaluated"] == mapper[subset]]
+ else:
+ print("No such option!")
+ else:
+ X = pd.read_csv("xai_fb15k_237.csv", sep=",")
+
+ if full:
+ return X
+ else:
+ t1 = X[
+ [
+ "predicate",
+ "predicate label",
+ "predicates_description",
+ "subject_triple1",
+ "subject_label_triple1",
+ "object_label_triple1",
+ "object_triple1",
+ ]
+ ]
+ t2 = X[
+ [
+ "predicate",
+ "predicate label",
+ "predicates_description",
+ "subject_triple2",
+ "subject_label_triple2",
+ "object_label_triple2",
+ "object_triple2",
+ ]
+ ]
+ t3 = X[
+ [
+ "predicate",
+ "predicate label",
+ "predicates_description",
+ "subject_triple3",
+ "subject_label_triple3",
+ "object_label_triple3",
+ "object_triple3",
+ ]
+ ]
+ mapper1 = {
+ "subject_triple1": "subject",
+ "subject_label_triple1": "subject_label",
+ "object_label_triple1": "object_label",
+ "object_triple1": "object",
+ }
+ t1 = t1.rename(columns=mapper1)
+ mapper2 = {
+ "subject_triple2": "subject",
+ "subject_label_triple2": "subject_label",
+ "object_label_triple2": "object_label",
+ "object_triple2": "object",
+ }
+ t2 = t2.rename(columns=mapper2)
+ mapper3 = {
+ "subject_triple3": "subject",
+ "subject_label_triple3": "subject_label",
+ "object_label_triple3": "object_label",
+ "object_triple3": "object",
+ }
+ t3 = t3.rename(columns=mapper3)
+ # DataFrame.append is deprecated (removed in pandas 2.0); use concat instead
+ t1 = pd.concat([t1, t2, t3], ignore_index=True)
+ return t1
+
+
+def load_codex(
+ check_md5hash=False,
+ clean_unseen=True,
+ add_reciprocal_rels=False,
+ return_mapper=False,
+):
+ """Load the CoDEx-M dataset.
+
+ The dataset is described in :cite:`safavi_codex_2020`.
+
+ .. note::
+ CoDEx-M also contains ground-truth negative triples for the test and validation sets. For more information, see
+ the above reference to the original paper.
+
+ The CoDEx-M dataset is loaded from file if it exists at the ``AMPLIGRAPH_DATA_HOME`` location.
+ If ``AMPLIGRAPH_DATA_HOME`` is not set, the default ``~/ampligraph_datasets`` is checked.
+ If the dataset is not found at either location, it is downloaded and placed in ``AMPLIGRAPH_DATA_HOME``
+ or ``~/ampligraph_datasets``.
+
+ This dataset is divided in three splits:
+
+ - `train`: 185,584 triples
+ - `valid`: 10,310 triples
+ - `test`: 10,310 triples
+
+ Both the validation and test splits are associated with labels (binary ndarrays),
+ with `True` for positive statements and `False` for negatives:
+
+ - `valid_labels`
+ - `test_labels`
+
+ ========= ========= ======= ================ ======= =============== ============ ===========
+ Dataset Train Valid Valid-negatives Test Test-negatives Entities Relations
+ ========= ========= ======= ================ ======= =============== ============ ===========
+ CoDEx-M 185,584 10,310 10,310 10,311 10,311 17,050 51
+ ========= ========= ======= ================ ======= =============== ============ ===========
+
+
+ Parameters
+ ----------
+ clean_unseen: bool
+ If `True`, filters triples in validation and test sets that include entities not present in the training set.
+
+ check_md5hash: bool
+ If `True`, check the `md5hash` of the dataset files (default: `False`).
+
+ add_reciprocal_rels: bool
+ Flag which specifies whether to add reciprocal relations. For every triple <s, p, o> in the dataset,
+ this creates a corresponding triple with reciprocal relation <o, p^-1, s> (default: `False`).
+ return_mapper: bool
+ Whether to return human-readable labels in a form of dictionary in ``X["mapper"]`` field (default: `False`).
+
+ Returns
+ -------
+
+ splits: dict
+ The dataset splits: `{'train': train, 'valid': valid, 'valid_negatives': valid_negatives, 'test': test, 'test_negatives': test_negatives}`.
+ Each split is a ndarray of shape (n, 3).
+
+ Example
+ -------
+
+ >>> from ampligraph.datasets import load_codex
+ >>> X = load_codex()
+ >>> X["valid"][0]
+ array(['Q60684', 'P106', 'Q4964182'], dtype=object)
+ >>> X = load_codex(return_mapper=True)
+ >>> [X['mapper'][elem]['label'] for elem in X['valid'][0]]
+ ['Novalis', 'occupation', 'philosopher']
+
+ """
+
+ codex = DatasetMetadata(
+ dataset_name="codex",
+ filename="codex.zip",
+ url="https://s3-eu-west-1.amazonaws.com/ampligraph/datasets/codex.zip",
+ train_name="train.txt",
+ valid_name="valid.txt",
+ test_name="test.txt",
+ valid_negatives_name="valid_negatives.txt",
+ test_negatives_name="test_negatives.txt",
+ mapper_name="mapper.json" if return_mapper else None,
+ train_checksum="d507616dd7b9f6ddbacf83766efaa1dd",
+ valid_checksum="0fd5e85f41e0ba3ef6c10093cbe2a435",
+ test_checksum="7186374c5ca7075d268ccf316927041d",
+ mapper_checksum="9cf7209df69562dff36ae94f95f67e82"
+ if return_mapper
+ else None,
+ test_negatives_checksum="2dc6755e9cc54145e782480c5bb2ef44",
+ valid_negatives_checksum="381300fbd297df9db2fd05bb6cfc1f2d",
+ )
+
+ if clean_unseen:
+ return _clean_data(
+ _load_dataset(
+ codex,
+ data_home=None,
+ check_md5hash=check_md5hash,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
+ )
+ else:
+ return _load_dataset(
+ codex,
+ data_home=None,
+ check_md5hash=check_md5hash,
+ add_reciprocal_rels=add_reciprocal_rels,
+ )
diff --git a/ampligraph/datasets/graph_data_loader.py b/ampligraph/datasets/graph_data_loader.py
new file mode 100644
index 00000000..bf2deaea
--- /dev/null
+++ b/ampligraph/datasets/graph_data_loader.py
@@ -0,0 +1,907 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+"""Data loader for graphs (big and small).
+
+ This module provides the GraphDataLoader class, which can be parametrized with an artificial backend that reads data in-memory
+(:class:`~ampligraph.datasets.graph_data_loader.NoBackend`) or with a SQLite backend that stores and reads data
+on-disk (:class:`~ampligraph.datasets.sqlite_adapter.SQLiteAdapter`).
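+
+ A minimal usage sketch (toy triples, default in-memory ``NoBackend`` backend):
+
+ >>> import numpy as np
+ >>> from ampligraph.datasets import GraphDataLoader
+ >>> X = np.array([["a", "likes", "b"], ["b", "likes", "c"]])  # toy triples
+ >>> loader = GraphDataLoader(X, batch_size=2, dataset_type="train")
+ >>> batch = next(loader)  # ndarray of indexed triples of shape (batch_size, 3)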
+"""
+import logging
+import tempfile
+import uuid
+from datetime import datetime
+
+import numpy as np
+import tensorflow as tf
+
+from .data_indexer import DataIndexer
+from .source_identifier import DataSourceIdentifier
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class NoBackend:
+ """Class providing an artificial backend, that reads data into memory."""
+
+ def __init__(
+ self,
+ identifier,
+ use_indexer=True,
+ remap=False,
+ name="main_partition",
+ parent=None,
+ in_memory=True,
+ root_directory=None,
+ use_filter=False,
+ verbose=False,
+ ):
+ """Initialise NoBackend.
+
+ Parameters
+ ----------
+ identifier: DataSourceIdentifier
+ Initialized data source identifier that provides the loader.
+ use_indexer: bool or mapper object
+ Flag or mapper object to tell whether data should be indexed (default: `True`).
+ remap: bool
+ Flag for partitioner to indicate whether to remap previously indexed data to (0, )
+ (default: `False`).
+ name: str
+ Name identifying files for the indexer, partition name/id.
+ parent:
+ Parent data loader that persists data.
+ verbose: bool
+ Verbosity.
+ """
+ # in_memory = False
+ self.verbose = verbose
+ self.identifier = identifier
+ self.use_indexer = use_indexer
+ self.remap = remap
+ self.name = name
+ self.parent = parent
+ self.in_memory = in_memory
+ if root_directory is None:
+ self.root_directory = tempfile.gettempdir()
+ else:
+ self.root_directory = root_directory
+ self.use_filter = use_filter
+ self.sources = {}
+
+ def _add_dataset(self, data_source, dataset_type):
+ msg = "Adding datasets to NoBackend not possible."
+ raise NotImplementedError(msg)
+
+ def __enter__(self):
+ """Context manager enter function. Required by GraphDataLoader."""
+ return self
+
+ def __exit__(self, type, value, tb):
+ """Context manager exit function. Required by GraphDataLoader."""
+ pass
+
+ def get_output_signature(self):
+ """Get the output signature for the tf.data.Dataset object."""
+ triple_tensor = tf.TensorSpec(shape=(None, 3), dtype=tf.int32)
+ if self.data_shape > 3:
+ weights_tensor = tf.TensorSpec(
+ shape=(None, self.data_shape - 3), dtype=tf.float32
+ )
+ if self.use_filter:
+ return (
+ triple_tensor,
+ tf.RaggedTensorSpec(shape=(2, None, None), dtype=tf.int32),
+ weights_tensor,
+ )
+ else:
+ return (triple_tensor, weights_tensor)
+ if self.use_filter:
+ return (
+ triple_tensor,
+ tf.RaggedTensorSpec(shape=(2, None, None), dtype=tf.int32),
+ )
+ return triple_tensor
+
+ def _load(self, data_source, dataset_type):
+ """Loads data into self.data.
+
+ Parameters
+ ----------
+ data_source: np.array or str
+ Array or name of the file containing the data.
+ dataset_type: str
+ Kind of data to be loaded (`"train"` | `"test"` | `"validation"`).
+ """
+ logger.debug(
+ "Simple in-memory data loading of {} dataset.".format(dataset_type)
+ )
+ self.data_source = data_source
+ self.dataset_type = dataset_type
+ if isinstance(self.data_source, np.ndarray):
+ if self.use_indexer is True:
+ self.mapper = DataIndexer(
+ self.data_source,
+ backend="in_memory" if self.in_memory else "sqlite",
+ root_directory=self.root_directory,
+ )
+ self.data = self.mapper.get_indexes(self.data_source)
+ elif self.remap:
+ # create a special mapping for partitions, persistent mapping from main indexes
+ # to partition indexes
+ self.mapper = DataIndexer(
+ self.data_source,
+ backend="sqlite",
+ name=self.name,
+ root_directory=self.root_directory,
+ )
+ self.data = self.mapper.get_indexes(self.data_source)
+ else:
+ self.mapper = self.use_indexer
+ self.data = self.mapper.get_indexes(self.data_source)
+ else:
+ loader = self.identifier.fetch_loader()
+ raw_data = loader(self.data_source)
+ if self.use_indexer is True:
+ self.mapper = DataIndexer(
+ raw_data,
+ backend="in_memory" if self.in_memory else "sqlite",
+ root_directory=self.root_directory,
+ )
+ self.data = self.mapper.get_indexes(raw_data)
+ elif self.use_indexer is False:
+ if self.remap:
+ # create a special mapping for partitions, persistent mapping from
+ # main indexes to partition indexes
+ self.mapper = DataIndexer(
+ raw_data,
+ backend="sqlite",
+ name=self.name,
+ root_directory=self.root_directory,
+ )
+ self.data = self.mapper.get_indexes(raw_data)
+ else:
+ self.data = raw_data
+ logger.debug("Data won't be indexed")
+ elif isinstance(self.use_indexer, DataIndexer):
+ self.mapper = self.use_indexer
+ self.data = self.mapper.get_indexes(raw_data)
+ self.data_shape = self.mapper.backend.data_shape
+
+ def _get_triples(self, subjects=None, objects=None, entities=None):
+ """Get triples whose subjects belongs to ``subjects``, objects to ``objects``,
+ or, if neither object nor subject is provided, triples whose subject or object belong to entities.
+ """
+ if subjects is None and objects is None:
+ if entities is None:
+ msg = "You have to provide either subjects and objects indexes or general entities indexes!"
+ logger.error(msg)
+ raise Exception(msg)
+
+ subjects = entities
+ objects = entities
+ # check_subjects = np.vectorize(lambda t: t in subjects)
+ if subjects is not None and objects is not None:
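+ # keep triples whose subject and object both fall within the provided
+ # subject/object sets (checked in either role)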
+ check_triples = np.vectorize(
+ lambda t, r: (t in objects and r in subjects)
+ or (t in subjects and r in objects)
+ )
+ triples = self.data[
+ check_triples(self.data[:, 2], self.data[:, 0])
+ ]
+ elif objects is None:
+ triples = self.data[np.isin(self.data[:, 0], subjects)]
+ elif subjects is None:
+ triples = self.data[np.isin(self.data[:, 2], objects)]
+ triples = np.append(
+ triples,
+ np.array(len(triples) * [self.dataset_type]).reshape(-1, 1),
+ axis=1,
+ )
+ # triples_from_objects = self.data[check_objects(self.data[:,0])]
+ # triples = np.vstack([triples_from_subjects, triples_from_objects])
+ return triples
+
+ def get_data_size(self):
+ """Returns number of triples."""
+ return np.shape(self.data)[0]
+
+ def _get_complementary_entities(self, triples, use_filter=None):
+ """Get subjects and objects complementary to a triple (?,p,?).
+
+ Returns the participating entities in the relation ?-p-o and s-p-?.
+ Function used during evaluation.
+
+ WARNING: If the parent is set the triples returned are coming with parent indexing.
+
+ Parameters
+ ----------
+ triples: ndarray, shape (N, 3)
+ Triples `(s, p, o)` that we are querying.
+
+ Returns
+ -------
+ entities: tuple
+ Tuple containing two lists, one with the subjects and one with the objects participating in the
+ relations ?-p-o and s-p-?.
+ """
+
+ logger.debug("Getting complementary entities")
+
+ if self.parent is not None:
+ logger.debug(
+ "Parent is set, WARNING: The triples returned are coming with parent indexing."
+ )
+
+ logger.debug("Recover original indexes.")
+ triples_original_index = self.mapper.get_indexes(
+ triples, order="ind2raw"
+ )
+ # with shelve.open(self.mapper.entities_dict) as ents:
+ # with shelve.open(self.mapper.relations_dict) as rels:
+ # triples_original_index = np.array([(ents[str(xx[0])], rels[str(xx[1])],
+ # ents[str(xx[2])]) for xx in triples], dtype=np.int32)
+ logger.debug("Query parent for data.")
+ logger.debug("Original index: {}".format(triples_original_index))
+ subjects = self.parent.get_complementary_subjects(
+ triples_original_index, use_filter=use_filter
+ )
+ objects = self.parent.get_complementary_objects(
+ triples_original_index, use_filter=use_filter
+ )
+ logger.debug(
+ "What to do with this new indexes? Evaluation should happen in the original space, \
+ shouldn't it? I'm assuming it does so returning in parent indexing."
+ )
+ return subjects, objects
+ else:
+ subjects = self._get_complementary_subjects(
+ triples, use_filter=use_filter
+ )
+ objects = self._get_complementary_objects(
+ triples, use_filter=use_filter
+ )
+ return subjects, objects
+
+ def _get_complementary_subjects(self, triples, use_filter=False):
+ """Get subjects complementary to triples (?,p,o).
+
+ For a given triple retrieve all subjects coming from triples with same objects and predicates.
+
+ Parameters
+ ----------
+ triples : list or array
+ List or array of arrays with 3 elements (subject, predicate, object).
+
+ Returns
+ -------
+ subjects : list
+ Complementary subjects for each input triple.
+ """
+
+ logger.debug("Getting complementary subjects")
+
+ if self.parent is not None:
+ logger.debug(
+ "Parent is set, WARNING: The triples returned are coming with parent indexing."
+ )
+
+ logger.debug("Recover original indexes.")
+ triples_original_index = self.mapper.get_indexes(
+ triples, order="ind2raw"
+ )
+
+ # with shelve.open(self.mapper.reversed_entities_dict) as ents:
+ # with shelve.open(self.mapper.reversed_relations_dict) as rels:
+ # triples_original_index = np.array([(ents[str(xx[0])],
+ # rels[str(xx[1])],
+ # ents[str(xx[2])]) for xx in triples],
+ # dtype=np.int32)
+ logger.debug("Query parent for data.")
+ subjects = self.parent.get_complementary_subjects(
+ triples_original_index
+ )
+ logger.debug(
+ "What to do with this new indexes? Evaluation should happen in the \
+ original space, shouldn't it? I'm assuming it does so returning in parent indexing."
+ )
+ return subjects
+ elif self.use_filter is False or self.use_filter is None:
+ self.use_filter = {"train-org": self.data}
+
+ filtered = []
+ for filter_name, filter_source in self.use_filter.items():
+ source = self.get_source(filter_source, filter_name)
+
+ tmp_filter = []
+ for triple in triples:
+ tmp = source[source[:, 2] == triple[2]]
+ tmp_filter.append(list(set(tmp[tmp[:, 1] == triple[1]][:, 0])))
+ filtered.append(tmp_filter)
+ # Unpack data into one list per triple no matter what filter it comes
+ # from
+ unpacked = list(zip(*filtered))
+ subjects = []
+ for k in unpacked:
+ lst = [j for i in k for j in i]
+ subjects.append(np.array(lst, dtype=np.int32))
+
+ return subjects
+
+ def get_source(self, source, name):
+ """Loads the data specified by ``name`` and keep it in the loaded dictionary.
+
+ Used to load filter datasets.
+
+ Parameters
+ ----------
+ source: ndarray or str
+ Data source to load data from.
+ name: str
+ Name of the dataset to be loaded.
+
+ Returns
+ -------
+ Loaded data : ndarray
+ Numpy array containing loaded data indexed according to mapper.
+ """
+ if name not in self.sources:
+ if isinstance(source, np.ndarray):
+ raw_data = source
+ else:
+ identifier = DataSourceIdentifier(source)
+ loader = identifier.fetch_loader()
+ raw_data = loader(source)
+ if name != "train-org":
+ self.sources[name] = self.mapper.get_indexes(raw_data)
+ else:
+ self.sources[name] = raw_data
+ return self.sources[name]
+
+ def _get_complementary_objects(self, triples, use_filter=False):
+ """Get objects complementary to triples (s,p,?).
+
+ For a given triple retrieve all objects coming from triples with the same subjects and predicates.
+ Function used during evaluation.
+
+ Parameters
+ ----------
+ triples : list or array
+ List or array of arrays with 3 elements (subject, predicate, object).
+
+ Returns
+ -------
+ objects : list
+ Complementary objects for each input triple.
+ """
+ logger.debug("Getting complementary objects")
+
+ if self.parent is not None:
+ logger.debug(
+ "Parent is set, WARNING: The triples returned are coming with parent indexing."
+ )
+
+ logger.debug("Recover original indexes.")
+ triples_original_index = self.mapper.get_indexes(
+ triples, order="ind2raw"
+ )
+ # with shelve.open(self.mapper.reversed_entities_dict) as ents:
+ # with shelve.open(self.mapper.reversed_relations_dict) as rels:
+ # triples_original_index = np.array([(ents[str(xx[0])], rels[str(xx[1])],
+ # ents[str(xx[2])]) for xx in triples], dtype=np.int32)
+ logger.debug("Query parent for data.")
+ objects = self.parent.get_complementary_objects(
+ triples_original_index
+ )
+ logger.debug(
+ "What to do with this new indexes? Evaluation should happen in \
+ the original space, shouldn't it? I'm assuming it does so returning in parent indexing."
+ )
+ return objects
+ elif self.use_filter is False or self.use_filter is None:
+ self.use_filter = {"train-org": self.data}
+ filtered = []
+ for filter_name, filter_source in self.use_filter.items():
+ # load the filter source if not already loaded
+ source = self.get_source(filter_source, filter_name)
+ # filter
+
+ tmp_filter = []
+ for triple in triples:
+ tmp = source[source[:, 0] == triple[0]]
+ tmp_filter.append(list(set(tmp[tmp[:, 1] == triple[1]][:, 2])))
+ filtered.append(tmp_filter)
+
+ # Unpack data into one list per triple no matter what filter it comes
+ # from
+ unpacked = list(zip(*filtered))
+ objects = []
+ for k in unpacked:
+ lst = [j for i in k for j in i]
+ objects.append(np.array(lst, dtype=np.int32))
+
+ return objects
+
+ def _intersect(self, dataloader):
+ """Intersection between data and dataloader elements.
+
+ Works only when dataloader is of type `NoBackend`.
+ """
+ if not isinstance(dataloader.backend, NoBackend):
+ msg = "Intersection can only be calculated between same backends (NoBackend), \
+ instead get {}".format(
+ type(dataloader.backend)
+ )
+ logger.error(msg)
+ raise Exception(msg)
+ self.data = np.ascontiguousarray(self.data, dtype="int64")
+ dataloader.backend.data = np.ascontiguousarray(
+ dataloader.backend.data, dtype="int64"
+ )
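+ # view each (s, p, o) row as a single structured element so that
+ # np.intersect1d computes a row-wise intersection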
+ av = self.data.view([("", self.data.dtype)] * self.data.shape[1])
+ bv = dataloader.backend.data.view(
+ [("", dataloader.backend.data.dtype)]
+ * dataloader.backend.data.shape[1]
+ )
+ intersection = (
+ np.intersect1d(av, bv)
+ .view(self.data.dtype)
+ .reshape(
+ -1,
+ self.data.shape[0 if self.data.flags["F_CONTIGUOUS"] else 1],
+ )
+ )
+ return intersection
+
+ def _get_batch_generator(
+ self, batch_size, dataset_type="train", random=False, index_by=""
+ ):
+ """Data batch generator.
+
+ Parameters
+ ----------
+ batch_size: int
+ Size of a batch.
+ dataset_type: str
+ Kind of dataset that is needed (`"train"` | `"test"` | `"validation"`).
+ random: not implemented.
+ index_by: not implemented.
+
+ Returns
+ --------
+ Batch : ndarray
+ Batch of data of size `(batch_size, m)` where :math:`m≥3` and :math:`m>3` if numeric values
+ associated to edges are available.
+ """
+ if not isinstance(batch_size, int):
+ batch_size = int(batch_size)
+ length = len(self.data)
+ triples = range(0, length, batch_size)
+ for start_index in triples:
+ # if the last batch is smaller than the batch_size
+ if start_index + batch_size >= length:
+ batch_size = length - start_index
+ out = self.data[start_index: start_index + batch_size, :3]
+ if self.use_filter:
+ # get the filter values
+ participating_entities = self._get_complementary_entities(
+ out, self.use_filter
+ )
+
+ # focusE
+ if self.data_shape > 3:
+ weights = self.data[start_index: start_index + batch_size, 3:]
+ # weights = preprocess_focusE_weights(data=out,
+ # weights=weights)
+ if self.use_filter:
+ yield out, tf.ragged.constant(
+ participating_entities, dtype=tf.int32
+ ), weights
+ else:
+ yield out, weights
+
+ else:
+ if self.use_filter:
+ yield out, tf.ragged.constant(
+ participating_entities, dtype=tf.int32
+ )
+ else:
+ yield out
+
+ def _clean(self):
+ del self.data
+ self.mapper.clean()
+
+
+class GraphDataLoader:
+ """Data loader for models to ingest graph data.
+
+ This class is internally used by the model to store the data passed by the user and batch over it during
+ training and evaluation, and to obtain filters during evaluation.
+
+ It can be used by advanced users to load custom datasets which are large, for performing partitioned training.
+ The complete dataset will not get loaded in memory. It will load the data in chunks based on which partition
+ is being trained.
+
+ Example
+ -------
+ >>> from ampligraph.datasets import GraphDataLoader, BucketGraphPartitioner
+ >>> from ampligraph.datasets.sqlite_adapter import SQLiteAdapter
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> AMPLIGRAPH_DATA_HOME='/your/path/to/datasets/'
+ >>> # Graph loader - loads the data from the file, numpy array, etc and generates batches for iterating
+ >>> path_to_training = AMPLIGRAPH_DATA_HOME + 'fb15k-237/train.txt'
+ >>> dataset_loader = GraphDataLoader(path_to_training,
+ >>> backend=SQLiteAdapter, # type of backend to use
+ >>> batch_size=1000, # batch size to use while iterating over this dataset
+ >>> dataset_type='train', # dataset type
+ >>> use_filter=False, # Whether to use filter or not
+ >>> use_indexer=True) # indicates that the data needs to be mapped to index
+ >>>
+ >>> # Choose the partitioner - in this case we choose the Bucket partitioner
+ >>> partitioner = BucketGraphPartitioner(dataset_loader, k=3)
+ >>> partitioned_model = ScoringBasedEmbeddingModel(eta=2,
+ >>> k=50,
+ >>> scoring_type='DistMult')
+ >>> partitioned_model.compile(optimizer='adam', loss='multiclass_nll')
+ >>> partitioned_model.fit(partitioner, # pass the partitioner object as input to the fit function; this generates data for the model during training
+ >>> epochs=10) # number of epochs
+ >>> indexer = partitioned_model.data_handler.get_mapper() # get the mapper from the trained model
+ >>> path_to_test = AMPLIGRAPH_DATA_HOME + 'fb15k-237/test.txt'
+ >>> dataset_loader_test = GraphDataLoader(path_to_test,
+ >>> backend=SQLiteAdapter, # type of backend to use
+ >>> batch_size=400, # batch size to use while iterating over this dataset
+ >>> dataset_type='test', # dataset type
+ >>> use_indexer=indexer # mapper to map test concepts to the same indices used during training
+ >>> )
+ >>> ranks = partitioned_model.evaluate(dataset_loader_test, # pass the dataloader object to generate data for the model during evaluation
+ >>> batch_size=400)
+ >>> print(ranks)
+ [[ 85 7]
+ [ 95 9]
+ [1074 22]
+ ...
+ [ 546 95]
+ [9961 7485]
+ [1494 2]]
+
+
+ """
+
+ def __init__(
+ self,
+ data_source,
+ batch_size=1,
+ dataset_type="train",
+ backend=None,
+ root_directory=None,
+ use_indexer=True,
+ verbose=False,
+ remap=False,
+ name="main_partition",
+ parent=None,
+ in_memory=False,
+ use_filter=False,
+ ):
+ """Initialise persistent/in-memory data storage.
+
+ Parameters
+ ----------
+ data_source: str or np.array or GraphDataLoader or AbstractGraphPartitioner
+ File with data (e.g. CSV). Can be a path pointing to the file location, can be data loaded as numpy, a
+ `GraphDataLoader` or an `AbstractGraphPartitioner` instance.
+ batch_size: int
+ Size of batch.
+ dataset_type: str
+ Kind of data provided (`"train"` | `"test"` | `"valid"`).
+ backend: str
+ Name of backend class (`NoBackend`, `SQLiteAdapter`) or already initialised backend.
+ If `None`, `NoBackend` is used (in-memory processing).
+ root_directory: str
+ Path to a directory where the database will be created, and the data and mappings will be stored.
+ If `None`, the root directory is obtained through the :meth:`tempfile.gettempdir()` method
+ (default: `None`).
+ use_indexer: bool or DataIndexer
+ Flag to tell whether data should be indexed.
+ If the DataIndexer object is passed, the mappings defined in the indexer will be reused
+ to generate mappings for the current data.
+ verbose: bool
+ Verbosity.
+ remap: bool
+ Flag to be used by the graph partitioner; indicates whether previously indexed data in a partition has to
+ be remapped to new indexes (0, ). It must not be used together with ``use_indexer=True``.
+ The new remappings will be persisted.
+ name: str
+ Name of the partition. This is internally used when the data is partitioned.
+ parent: GraphDataLoader
+ Parent dataloader. This is internally used when the data is partitioned.
+ in_memory: bool
+ Persist indexes or not.
+ use_filter: bool or dict
+ If `True`, current dataset will be used as filter.
+ If `dict`, the datasets specified in the dict will be used for filtering.
+ If `False`, the true positives will not be filtered from corruptions.
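+ For instance (illustrative arrays), ``use_filter={'train': X_train, 'valid': X_valid, 'test': X_test}``
+ would filter corruptions against all three splits.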
+ """
+ self.dataset_type = dataset_type
+ self.data_source = data_source
+ self.batch_size = batch_size
+ if root_directory is None:
+ self.root_directory = tempfile.gettempdir()
+ else:
+ self.root_directory = root_directory
+ self.identifier = DataSourceIdentifier(self.data_source)
+ self.use_indexer = use_indexer
+ self.remap = remap
+ self.in_memory = in_memory
+ self.name = name
+ self.parent = parent
+ if use_filter is None or use_filter is True:
+ self.use_filter = {"train": data_source}
+ else:
+ if isinstance(use_filter, dict) or use_filter is False:
+ self.use_filter = use_filter
+ else:
+ msg = "use_filter should be a dictionary with keys as names of filters and \
+ values as data sources, instead got {}".format(
+ use_filter
+ )
+ logger.error(msg)
+ raise Exception(msg)
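+ # exactly one of use_indexer / remap may be active: indexing raw data and
+ # remapping an already-indexed partition are mutually exclusive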
+ if bool(use_indexer) != (not remap):
+ msg = (
+ "Either remap or Indexer should be specified at the same time."
+ )
+ logger.error(msg)
+ raise Exception(msg)
+ if isinstance(backend, type) and backend != NoBackend:
+ self.backend = backend(
+ "database_{}_{}.db".format(
+ datetime.now().strftime("%d-%m-%Y_%I-%M-%S_%f_%p"),
+ str(uuid.uuid4()),
+ ),
+ identifier=self.identifier,
+ root_directory=self.root_directory,
+ use_indexer=self.use_indexer,
+ remap=self.remap,
+ name=self.name,
+ parent=self.parent,
+ in_memory=self.in_memory,
+ verbose=verbose,
+ use_filter=self.use_filter,
+ )
+ logger.debug(
+ "Initialized Backend with database at: {}".format(
+ self.backend.db_path
+ )
+ )
+
+ elif backend is None or backend == NoBackend:
+ self.backend = NoBackend(
+ self.identifier,
+ use_indexer=self.use_indexer,
+ remap=self.remap,
+ name=self.name,
+ parent=self.parent,
+ in_memory=self.in_memory,
+ use_filter=self.use_filter,
+ )
+ else:
+ self.backend = backend
+
+ self.backend._load(self.data_source, dataset_type=self.dataset_type)
+ self.data_shape = self.backend.data_shape
+ self.batch_iterator = self.get_batch_generator(
+ use_filter=self.use_filter, dataset_type=self.dataset_type
+ )
+ self.metadata = self.backend.mapper.metadata
+
+ def __iter__(self):
+ """Function needed to be used as an iterator."""
+ return self
+
+ @property
+ def max_entities(self):
+ """Maximum number of entities present in the dataset mapper."""
+ return self.backend.mapper.get_entities_count()
+
+ @property
+ def max_relations(self):
+ """Maximum number of relations present in the dataset mapper."""
+ return self.backend.mapper.get_relations_count()
+
+ def __next__(self):
+ """Function needed to be used as an iterator."""
+ return self.batch_iterator.__next__()
+
+ def reload(self, use_filter=False, dataset_type="train"):
+ """Reinstantiate batch iterator."""
+ self.batch_iterator = self.get_batch_generator(
+ use_filter=use_filter, dataset_type=dataset_type
+ )
+
+ def get_batch_generator(self, dataset_type="train", use_filter=False):
+ """Get batch generator from the backend.
+
+ Parameters
+ ----------
+ dataset_type: str
+ Specifies whether data are generated for `"train"`, `"valid"` or `"test"` set.
+ """
+ return self.backend._get_batch_generator(
+ self.batch_size, dataset_type=dataset_type
+ )
+
+ def get_tf_generator(self):
+ """Generates a tensorflow.data.Dataset object."""
+ return tf.data.Dataset.from_generator(
+ self.backend._get_batch_generator,
+ output_signature=self.backend.get_output_signature(),
+ args=(self.batch_size, self.dataset_type, False, ""),
+ ).prefetch(2)
+
+ def add_dataset(self, data_source, dataset_type):
+ """Adds the dataset to the backend (if possible)."""
+ self.backend._add_dataset(data_source, dataset_type=dataset_type)
+
+ def get_data_size(self):
+ """Returns number of triples."""
+ return self.backend.get_data_size()
+
+ def intersect(self, dataloader):
+ """Returns the intersection between the current data loader and another one specified in ``dataloader``.
+
+ Parameters
+ ----------
+ dataloader: GraphDataLoader
+ Dataloader with which to calculate the intersection.
+
+ Returns
+ -------
+ intersection: ndarray
+ Array of intersecting elements.
+ """
+
+ return self.backend._intersect(dataloader)
+
+ def get_participating_entities(
+ self, triples, sides="s,o", use_filter=False
+ ):
+ """Get entities from triples with fixed subjects or fixed objects or both fixed.
+
+ Parameters
+ ----------
+ triples: list or array
+ List or array of arrays with 3 elements (subject, predicate, object).
+ sides : str
+ String specifying what entities to retrieve: `"s"` - subjects, `"o"` - objects,
+ `"s,o"` - subjects and objects, `"o,s"` - objects and subjects.
+
+ Returns
+ -------
+ entities : list
+ List of subjects (if ``sides="s"``) or objects (if ``sides="o"``) or two lists with both
+ (if ``sides="s,o"`` or ``sides="o,s"``).
+ """
+ if sides not in ["s", "o", "s,o", "o,s"]:
+ msg = "Sides should be either 's' (subject), 'o' (object), or 's,o'/'o,s' (subject, object/object, subject), \
+ instead got {}".format(
+ sides
+ )
+ logger.error(msg)
+ raise Exception(msg)
+ if "s" in sides:
+ subjects = self.get_complementary_subjects(
+ triples, use_filter=use_filter
+ )
+
+ if "o" in sides:
+ objects = self.get_complementary_objects(
+ triples, use_filter=use_filter
+ )
+
+ if sides == "s,o":
+ return subjects, objects
+ if sides == "o,s":
+ return objects, subjects
+ if sides == "s":
+ return subjects
+ if sides == "o":
+ return objects
+
+ def get_complementary_subjects(self, triples, use_filter=False):
+ """Get subjects complementary to triples (?,p,o).
+
+ For a given triple retrieve all subjects coming from triples with same objects and predicates.
+
+ Parameters
+ ----------
+ triples : list or array
+ List or array of arrays with 3 elements (subject, predicate, object).
+
+ Returns
+ -------
+ subjects : list
+ Complementary subjects for each input triple.
+ """
+ return self.backend._get_complementary_subjects(
+ triples, use_filter=use_filter
+ )
+
+ def get_complementary_objects(self, triples, use_filter=False):
+ """Get objects complementary to triples (s,p,?).
+
+ For a given triple retrieve all objects coming from triples with the same subjects and predicates.
+ Function used during evaluation.
+
+ Parameters
+ ----------
+ triples : list or array
+ List or array of arrays with 3 elements (subject, predicate, object).
+
+ Returns
+ -------
+ objects : list
+ Complementary objects for each input triple.
+ """
+ return self.backend._get_complementary_objects(
+ triples, use_filter=use_filter
+ )
+
+ def get_complementary_entities(self, triples, use_filter=False):
+ """Get subjects and objects complementary to triples (?,p,?).
+
+ Returns the participating entities in the relation ?-p-o and s-p-?.
+
+ Parameters
+ ----------
+ triples: ndarray, shape (N, 3)
+ N triples (s-p-o) that we are querying.
+
+ Returns
+ -------
+ entities: tuple
+ Tuple containing two lists, one with the subjects and one with the objects participating in the
+ relations ?-p-o and s-p-?.
+ """
+ return self.backend._get_complementary_entities(
+ triples, use_filter=use_filter
+ )
+
+ def get_triples(self, subjects=None, objects=None, entities=None):
+ """Get triples that subject is in subjects and object is in objects, or
+ triples that eiter subject or object is in entities.
+
+ Parameters
+ ----------
+ subjects: list
+ List of entities to which the subjects of the returned triples should belong.
+
+ objects: list
+ List of entities to which the objects of the returned triples should belong.
+
+ entities: list
+ List of entities to which the subjects and objects of the returned triples should belong.
+
+ Returns
+ -------
+ triples: list
+ List of triples constrained by subjects and objects.
+
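+ Example
+ -------
+ >>> # Illustrative sketch: assumes `loader` is an initialised GraphDataLoader
+ >>> # and that 0, 1 and 2 are valid internal entity indexes.
+ >>> triples = loader.get_triples(entities=[0, 1, 2])
+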
+ """
+ return self.backend._get_triples(subjects, objects, entities)
+
+ def clean(self):
+ """Cleans up the temporary files created for training/evaluation."""
+ self.backend._clean()
+
+ def on_epoch_end(self):
+ pass
+
+ def on_complete(self):
+ pass
diff --git a/ampligraph/datasets/graph_partitioner.py b/ampligraph/datasets/graph_partitioner.py
new file mode 100644
index 00000000..b3546c2e
--- /dev/null
+++ b/ampligraph/datasets/graph_partitioner.py
@@ -0,0 +1,794 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+"""
+Graph partitioning strategies.
+
+This module contains several graph partitioning strategies both based on vertices split and edges split.
+
+Attributes
+----------
+PARTITION_ALGO_REGISTRY : dict
+ Dictionary containing the names of the strategies as key and reference to the strategy class as a value.
+
+"""
+import logging
+import os
+import shelve
+import tempfile
+from abc import ABC
+from datetime import datetime
+
+import numpy as np
+
+from ampligraph.utils.profiling import timing_and_memory
+
+from .graph_data_loader import GraphDataLoader
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+PARTITION_ALGO_REGISTRY = {}
+
+
+def register_partitioning_strategy(name, manager):
+ """Decorator responsible for registering partition in the partition registry.
+
+ Parameters
+ ----------
+ name: str
+ Name of the new partition strategy.
+ manager: str
+ Name of the partitioning manager that will handle this partitioning strategy during training.
+
+ Example
+ -------
+ >>>@register_partitioning_strategy("NewStrategyName")
+ >>>class NewPartitionStrategy(AbstractGraphPartitioner):
+ >>>... pass
+ """
+
+ def insert_in_registry(class_handle):
+ """Checks if partition already exists and if not registers it."""
+ if name in PARTITION_ALGO_REGISTRY.keys():
+ msg = "Partitioning Strategy with name {} already exists!".format(
+ name
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ PARTITION_ALGO_REGISTRY[name] = class_handle
+ class_handle.name = name
+ class_handle.manager = manager
+
+ return class_handle
+
+ return insert_in_registry
+
+
+def get_number_of_partitions(n):
+ """Calculates number of partitions for Bucket Partitioner.
+
+ Parameters
+ ----------
+ n: int
+ Number of buckets with vertices.
+
+ Returns
+ -------
+ n_partitions: int
+ Number of partitions.
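+
+ Example
+ -------
+ A worked illustration: with :math:`n` buckets the Bucket partitioner creates :math:`n` same-bucket
+ partitions plus :math:`\\frac{n(n-1)}{2}` cross-bucket ones, i.e. :math:`\\frac{n(n+1)}{2}` in total.
+
+ >>> get_number_of_partitions(2)
+ 3
+ >>> get_number_of_partitions(3)
+ 6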
+ """
+ return int(n * (n + 1) / 2)
+
+
+class AbstractGraphPartitioner(ABC):
+ """Meta class defining interface for graph partitioning algorithms."""
+
+ def __init__(self, data, k=2, seed=None, root_dir=None, **kwargs):
+ """Initialise the AbstractGraphPartitioner.
+
+ Parameters
+ ----------
+ data: GraphDataLoader
+ Input data provided as a GraphDataLoader.
+ k: int
+ Number of partitions or buckets to split data into.
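+ seed: int
+ Seed used when the strategy requires randomization.
+ root_dir: str
+ Directory where the temporary partition files are stored; defaults to the system temporary directory.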
+ """
+ self.files = []
+ self.partitions = []
+ self._data = data
+ self._k = k
+ if root_dir is None:
+ self.root_dir = tempfile.gettempdir()
+ else:
+ self.root_dir = root_dir
+ self._split(seed=seed, batch_size=data.batch_size, **kwargs)
+ self.reload()
+
+ def __iter__(self):
+ """Function needed to be used as an iterator."""
+ return self
+
+ def reload(self):
+ """Reload the partition."""
+ self.generator = self.partitions_generator()
+
+ def get_data(self):
+ """Get the underlying data handler."""
+ return self._data
+
+ def partitions_generator(self):
+ """Generates partitions.
+
+ Yields
+ ------
+ next_partition : GraphDataLoader
+ Next partition as a GraphDataLoader object.
+ """
+ for partition in self.partitions:
+ partition.reload()
+ yield partition
+
+ def get_partitions_iterator(self):
+ """Re-instantiate partitions generator.
+
+ Returns
+ -------
+ Partitions generator
+ """
+ return self.partitions_generator()
+
+ def get_partitions_list(self):
+ """Returns handler for partitions list."""
+ for partition in self.partitions:
+ partition.reload()
+ return self.partitions
+
+ def __next__(self):
+ """Function needed to be used as an iterator."""
+ return next(self.generator)
+
+ def _split(self, seed=None, **kwargs):
+ """Split data into `k` equal size partitions.
+
+ Parameters
+ ----------
+ seed: int
+ Seed used for reproducibility; it is only relevant when the strategy involves randomization.
+
+ Returns
+ -------
+ partitions:
+ Partitions into which the data is divided.
+ """
+ pass
+
+ def clean(self):
+ """Remove the temporary files created for the partitions."""
+ for partition in self.partitions:
+ partition.clean()
+ for f in self.files:
+ if f.split(".")[-1] != "shf":
+ os.remove(f)
+ else:
+ try:
+ os.remove(f + ".bak")
+ os.remove(f + ".dir")
+ os.remove(f + ".dat")
+ except Exception:
+ if os.path.exists(f + ".db"):
+ os.remove(f + ".db")
+
+
+@register_partitioning_strategy("Bucket", "BucketPartitionDataManager")
+class BucketGraphPartitioner(AbstractGraphPartitioner):
+ """Bucket-based partition strategy.
+
+ This strategy first splits entities into :math:`k` buckets and creates:
+
+ + `k` partitions, where the `i`-th contains triples whose subject and object both belong to the `i`-th bucket.
+ + :math:`\\frac{(k^2-k)}{2}` partitions indexed by :math:`(i,j)`, with :math:`i,j=1,...,k`, :math:`i \\neq j`, where
+ the :math:`(i,j)`-th partition contains triples whose subject belongs to the :math:`i`-th bucket
+ and the object to the :math:`j`-th bucket, or vice versa.
+
+ Example
+ -------
+ >>> from ampligraph.datasets import load_fb15k_237, GraphDataLoader, BucketGraphPartitioner
+ >>> from ampligraph.datasets.sqlite_adapter import SQLiteAdapter
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> dataset = load_fb15k_237()
+ >>> dataset_loader = GraphDataLoader(dataset['train'],
+ >>> backend=SQLiteAdapter, # Type of backend to use
+ >>> batch_size=1000, # Batch size to use while iterating over the dataset
+ >>> dataset_type='train', # Dataset type
+ >>> use_filter=False, # Whether to use filter or not
+ >>> use_indexer=True) # indicates that the data needs to be mapped to index
+ >>> partitioner = BucketGraphPartitioner(dataset_loader, k=2)
+ >>> # create and compile a model as usual
+ >>> partitioned_model = ScoringBasedEmbeddingModel(eta=2, k=50, scoring_type='DistMult')
+ >>> partitioned_model.compile(optimizer='adam', loss='multiclass_nll')
+ >>> partitioned_model.fit(partitioner, # The partitioner object generates data for the model during training
+ >>> epochs=10) # Number of epochs
+
+ Example
+ -------
+ >>> import numpy as np
+ >>> from ampligraph.datasets import GraphDataLoader, BucketGraphPartitioner
+ >>> d = np.array([[1,1,2], [1,1,3],[1,1,4],[5,1,3],[5,1,2],[6,1,3],[6,1,2],[6,1,4],[6,1,7]])
+ >>> data = GraphDataLoader(d, batch_size=1, dataset_type="test")
+ >>> partitioner = BucketGraphPartitioner(data, k=2)
+ >>> for i, partition in enumerate(partitioner):
+ >>> print("partition ", i)
+ >>> for batch in partition:
+ >>> print(batch)
+ partition 0
+ [['0,0,1']]
+ [['0,0,2']]
+ [['0,0,3']]
+ partition 1
+ [['4,0,1']]
+ [['4,0,2']]
+ [['5,0,1']]
+ [['5,0,2']]
+ [['5,0,3']]
+ partition 2
+ [['5,0,6']]
+
+ """
+
+ def __init__(self, data, k=2, **kwargs):
+ """Initialise the BucketGraphPartitioner.
+
+ Parameters
+ ----------
+ data: GraphDataLoader
+ Input data as a GraphDataLoader.
+ k: int
+ Number of buckets to split entities (i.e., vertices) into.
+
+ """
+
+ self.partitions = []
+ super().__init__(data, k, **kwargs)
+
+ def create_single_partition(
+ self, ind1, ind2, timestamp, partition_nb, batch_size=1
+ ):
+ """Creates partition based on the two given indices of buckets.
+
+ It appends created partition to the list of partitions (self.partitions).
+
+ Parameters
+ ----------
+ ind1: int
+ Index of the first bucket needed to create partition.
+ ind2: int
+ Index of the second bucket needed to create partition.
+ timestamp: str
+ Date and time string used in the names of the bucket files (shelves).
+ partition_nb: int
+ Number assigned to the partition being created.
+
+ """
+ # logger.debug("------------------------------------------------")
+ # logger.debug("Creating partition nb: {}".format(partition_nb))
+
+ fname = "bucket_{}_{}.shf".format(ind1, timestamp)
+ with shelve.open(
+ os.path.join(self.root_dir, fname), writeback=True
+ ) as bucket_partition_1:
+ indexes_1 = bucket_partition_1["indexes"]
+ fname = "bucket_{}_{}.shf".format(ind2, timestamp)
+ with shelve.open(
+ os.path.join(self.root_dir, fname), writeback=True
+ ) as bucket_partition_2:
+ indexes_2 = bucket_partition_2["indexes"]
+
+ # logger.debug("indexes 1: {}".format(ind1, indexes_1))
+ # logger.debug("indexes 2: {}".format(ind2, indexes_2))
+
+ triples_1_2 = np.array(
+ self._data.get_triples(subjects=indexes_1, objects=indexes_2)
+ )[:, :3]
+ triples_2_1 = np.array(
+ self._data.get_triples(subjects=indexes_2, objects=indexes_1)
+ )[:, :3]
+
+ logger.debug("triples 1-2: {}".format(triples_1_2))
+ logger.debug("triples 2-1: {}".format(triples_2_1))
+ triples = np.vstack([triples_1_2, triples_2_1]).astype(np.int32)
+ # logger.debug(triples)
+ if triples.size != 0:
+ triples = np.unique(triples, axis=0)
+ # logger.debug("unique triples: {}".format(triples))
+ fname = "partition_{}_{}.csv".format(partition_nb, timestamp)
+ fname = os.path.join(self.root_dir, fname)
+ self.files.append(fname)
+ np.savetxt(fname, triples, delimiter="\t", fmt="%d")
+ # special case of GraphDataLoader to create partition datasets:
+ # with remapped indexes (0, size_of_partition),
+ # persisted, with partition number to look up remappings
+ partition_loader = GraphDataLoader(
+ fname,
+ use_indexer=False,
+ batch_size=batch_size,
+ remap=True,
+ parent=self._data,
+ name="partition_{}_buckets_{}-{}".format(
+ partition_nb, ind1, ind2
+ ),
+ )
+ self.partitions.append(partition_loader)
+ return 0 # status everything went ok
+ else:
+ return 1 # status not ok, no partition created
+
+ @timing_and_memory
+ def _split(self, seed=None, verbose=False, batch_size=1, **kwargs):
+ """Split data into `self.k` buckets based on unique entities and assign
+ accordingly triples to `k` partitions and intermediate partitions.
+
+ """
+ timestamp = datetime.now().strftime("%d-%m-%Y_%I-%M-%S_%p")
+ self.ents_size = self._data.backend.mapper.get_entities_count()
+ logger.debug(self.ents_size)
+ self.bucket_size = int(np.ceil(self.ents_size / self._k))
+ self.buckets_generator = (
+ self._data.backend.mapper.get_entities_in_batches(
+ batch_size=self.bucket_size
+ )
+ )
+
+ for i, bucket in enumerate(self.buckets_generator):
+ # dump entities in partition shelve/file
+ fname = "bucket_{}_{}.shf".format(i, timestamp)
+ fname = os.path.join(self.root_dir, fname)
+ self.files.append(fname)
+ with shelve.open(fname, writeback=True) as bucket_partition:
+ bucket_partition["indexes"] = bucket
+ # logger.debug(bucket)
+
+ partition_nb = 0
+ # ensure that the "same" bucket partitions are generated first
+ for i in range(self._k):
+ # create the k partitions where subject and object fall in the same bucket
+ status_not_ok = self.create_single_partition(
+ i, i, timestamp, partition_nb, batch_size=batch_size
+ )
+ if status_not_ok:
+ continue
+ partition_nb += 1
+
+ # Now generate across bucket partitions
+ for i in range(self._k):
+ for j in range(self._k):
+ if j > i:
+ # condition that excludes duplicated partitions
+ # from k x k possibilities, partition 0-1 and 1-0 are the
+ # same - not needed
+ status_not_ok = self.create_single_partition(
+ i, j, timestamp, partition_nb, batch_size=batch_size
+ )
+ if status_not_ok:
+ continue
+ partition_nb += 1
+
+
+@register_partitioning_strategy(
+ "RandomVertices", "GeneralPartitionDataManager"
+)
+class RandomVerticesGraphPartitioner(AbstractGraphPartitioner):
+ """Partitioning strategy that splits vertices into equal sized buckets of random entities from the graph.
+
+ Example
+ -------
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> from ampligraph.datasets import GraphDataLoader
+ >>> from ampligraph.datasets.sqlite_adapter import SQLiteAdapter
+ >>> from ampligraph.datasets.graph_partitioner import RandomVerticesGraphPartitioner
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> dataset = load_fb15k_237()
+ >>> dataset_loader = GraphDataLoader(dataset['train'],
+ >>> backend=SQLiteAdapter, # type of backend to use
+ >>> batch_size=2, # batch size to use while iterating over this dataset
+ >>> dataset_type='train', # dataset type
+ >>> use_filter=False, # whether to use filter or not
+ >>> use_indexer=True) # indicates that the data needs to be mapped to index
+ >>> partitioner = RandomVerticesGraphPartitioner(dataset_loader, k=2)
+ >>> # create and compile a model as usual
+ >>> partitioned_model = ScoringBasedEmbeddingModel(eta=2,
+ >>> k=50,
+ >>> scoring_type='DistMult')
+ >>>
+ >>> partitioned_model.compile(optimizer='adam', loss='multiclass_nll')
+ >>>
+ >>> partitioned_model.fit(partitioner, # the partitioner generates data for the model during training
+ >>> epochs=10) # number of epochs
+
+ Example
+ -------
+ >>> import numpy as np
+ >>> from ampligraph.datasets.graph_partitioner import GraphDataLoader, RandomVerticesGraphPartitioner
+ >>> d = np.array([[1,1,2], [1,1,3],[1,1,4],[5,1,3],[5,1,2],[6,1,3],[6,1,2],[6,1,4],[6,1,7]])
+ >>> data = GraphDataLoader(d, batch_size=1, dataset_type="test")
+ >>> partitioner = RandomVerticesGraphPartitioner(data, k=2)
+ >>> for i, partition in enumerate(partitioner):
+ >>> print("partition ", i)
+ >>> for batch in partition:
+ >>> print(batch)
+
+ """
+
+ def __init__(self, data, k=2, seed=None, **kwargs):
+ """Initialise the RandomVerticesGraphPartitioner.
+
+ Parameters
+ ----------
+ data: GraphDataLoader
+ Input data provided as a GraphDataLoader.
+ k: int
+ Number of buckets to split entities (i.e., vertices) into.
+ seed: int
+ Seed to be used during partitioning.
+
+ """
+ self._data = data
+ self._k = k
+ self.partitions = []
+ super().__init__(data, k, **kwargs)
+
+ @timing_and_memory
+ def _split(self, seed=None, batch_size=1, **kwargs):
+ """Split data into `k` equal size partitions by randomly drawing subset of vertices
+ of partition size and retrieving triples associated with these vertices.
+
+ """
+ timestamp = datetime.now().strftime("%d-%m-%Y_%I-%M-%S_%p")
+ self.ents_size = self._data.backend.mapper.get_entities_count()
+ # logger.debug(self.ents_size)
+ # logger.debug(backend.mapper.max_ents_index)
+ self.partition_size = int(np.ceil(self.ents_size / self._k))
+ # logger.debug(self.partition_size)
+ self.buckets_generator = (
+ self._data.backend.mapper.get_entities_in_batches(
+ batch_size=self.partition_size, random=True, seed=seed
+ )
+ )
+
+ for partition_nb, partition in enumerate(self.buckets_generator):
+ # logger.debug(partition)
+ tmp = np.array(self._data.backend._get_triples(entities=partition))
+ # tmp_subj = np.array(self._data.backend._get_triples(subjects=partition))
+ # tmp_obj = np.array(self._data.backend._get_triples(objects=partition))
+ # tmp = np.unique(np.concatenate([tmp_subj, tmp_obj], axis=0), axis=0)
+
+ if tmp.size != 0:
+ triples = tmp[:, :3].astype(np.int32)
+ # logger.debug("unique triples: {}".format(triples))
+ fname = "partition_{}_{}.csv".format(partition_nb, timestamp)
+ fname = os.path.join(self.root_dir, fname)
+ self.files.append(fname)
+ np.savetxt(fname, triples, delimiter="\t", fmt="%d")
+ # special case of GraphDataLoader to create partition datasets:
+ # with remapped indexes (0, size_of_partition),
+ # persisted, with partition number to look up remappings
+ partition_loader = GraphDataLoader(
+ fname,
+ use_indexer=False,
+ batch_size=batch_size,
+ remap=True,
+ parent=self._data,
+ name="partition_{}".format(partition_nb),
+ )
+ self.partitions.append(partition_loader)
+ else:
+ logger.debug("Partition has no triples, skipping!")
+
+
+class EdgeBasedGraphPartitioner(AbstractGraphPartitioner):
+ """Template for edge-based partitioning strategy that splits edges
+ into partitions.
+
+ To be inherited to create different edge-based strategies.
+
+ Example
+ -------
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> from ampligraph.datasets import GraphDataLoader
+ >>> from ampligraph.datasets.sqlite_adapter import SQLiteAdapter
+ >>> from ampligraph.datasets.graph_partitioner import EdgeBasedGraphPartitioner
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> dataset = load_fb15k_237()
+ >>> dataset_loader = GraphDataLoader(dataset['train'],
+ >>> backend=SQLiteAdapter, # type of backend to use
+ >>> batch_size=2, # batch size to use while iterating over this dataset
+ >>> dataset_type='train', # dataset type
+ >>> use_filter=False, # Whether to use filter or not
+ >>> use_indexer=True) # indicates that the data needs to be mapped to index
+ >>> partitioner = EdgeBasedGraphPartitioner(dataset_loader, k=2)
+ >>> # create and compile a model as usual
+ >>> partitioned_model = ScoringBasedEmbeddingModel(eta=2,
+ >>> k=50,
+ >>> scoring_type='DistMult')
+ >>>
+ >>> partitioned_model.compile(optimizer='adam', loss='multiclass_nll')
+ >>>
+ >>> partitioned_model.fit(partitioner, # the partitioner generates data for the model during training
+ >>> epochs=10) # number of epochs
+
+ """
+
+ def __init__(self, data, k=2, random=False, index_by="", **kwargs):
+ """Initialise the EdgeBasedGraphPartitioner.
+
+ Parameters
+ ----------
+ data: GraphDataLoader
+ Input data as a GraphDataLoader.
+ k: int
+ Number of partitions to split the edges (i.e., triples) into.
+ random: bool
+ Whether to draw edges/triples in random order.
+ index_by: str
+ Which index to use when returning triples (`"s"`, `"o"`, `"so"`, `"os"`).
+
+ """
+
+ self.partitions = []
+ self._data = data
+ self._k = k
+ super().__init__(data, k=k, random=random, index_by=index_by, **kwargs)
+
+ def get_data(self):
+ """Get the underlying data handler."""
+ return self._data
+
+ @timing_and_memory
+ def _split(
+ self, seed=None, batch_size=1, random=False, index_by="", **kwargs
+ ):
+ """Split data into `k` equal size partitions by randomly drawing subset of edges from dataset.
+
+ Returns
+ -------
+ partitions
+ Parts of equal size containing triples
+ """
+ timestamp = datetime.now().strftime("%d-%m-%Y_%I-%M-%S_%p")
+ self.size = self._data.backend.get_data_size()
+
+ self.partition_size = int(np.ceil(self.size / self._k))
+ logger.debug(self.partition_size)
+ generator = self._data.backend._get_batch_generator(
+ random=random,
+ batch_size=self.partition_size,
+ dataset_type=self._data.dataset_type,
+ index_by=index_by,
+ )
+
+ for partition_nb, partition in enumerate(generator):
+ fname = "partition_{}_{}.csv".format(partition_nb, timestamp)
+ fname = os.path.join(self.root_dir, fname)
+ self.files.append(fname)
+ np.savetxt(
+ fname, np.array(partition, dtype=int), delimiter="\t", fmt="%d"
+ )
+ # special case of GraphDataLoader to create partition datasets:
+ # with remapped indexes (0, size_of_partition),
+ # persisted, with partition number to look up remappings
+ partition_loader = GraphDataLoader(
+ fname,
+ use_indexer=False,
+ batch_size=batch_size,
+ remap=True,
+ parent=self._data,
+ name="partition_{}".format(partition_nb),
+ )
+ self.partitions.append(partition_loader)
+
+
+@register_partitioning_strategy("RandomEdges", "GeneralPartitionDataManager")
+class RandomEdgesGraphPartitioner(EdgeBasedGraphPartitioner):
+ """Partitioning strategy that splits edges into equal size
+ partitions randomly drawing triples from the data.
+
+ Example
+ -------
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> from ampligraph.datasets import GraphDataLoader
+ >>> from ampligraph.datasets.sqlite_adapter import SQLiteAdapter
+ >>> from ampligraph.datasets.graph_partitioner import RandomEdgesGraphPartitioner
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> dataset = load_fb15k_237()
+ >>> dataset_loader = GraphDataLoader(dataset['train'],
+ >>> backend=SQLiteAdapter, # type of backend to use
+ >>> batch_size=2, # batch size to use while iterating over this dataset
+ >>> dataset_type='train', # dataset type
+ >>> use_filter=False, # Whether to use filter or not
+ >>> use_indexer=True) # indicates that the data needs to be mapped to index
+ >>> partitioner = RandomEdgesGraphPartitioner(dataset_loader, k=2)
+ >>> # create and compile a model as usual
+ >>> partitioned_model = ScoringBasedEmbeddingModel(eta=2,
+ >>> k=50,
+ >>> scoring_type='DistMult')
+ >>>
+ >>> partitioned_model.compile(optimizer='adam', loss='multiclass_nll')
+ >>>
+ >>> partitioned_model.fit(partitioner, # the partitioner generates data for the model during training
+ >>> epochs=10) # number of epochs
+
+ """
+
+ def __init__(self, data, k=2, **kwargs):
+ """Initialise the RandomEdgesGraphPartitioner.
+
+ Parameters
+ ----------
+ data: GraphDataLoader
+ Input data as a GraphDataLoader.
+ k: int
+ Number of partitions to split the edges (i.e., triples) into.
+
+ """
+ self.partitions = []
+ self._data = data
+ self._k = k
+ super().__init__(data, k, random=True, index_by="", **kwargs)
+
+
+@register_partitioning_strategy("Naive", "GeneralPartitionDataManager")
+class NaiveGraphPartitioner(EdgeBasedGraphPartitioner):
+ """Partitioning strategy that splits edges into equal size
+ partitions drawing triples from the data sequentially.
+
+ Example
+ -------
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> from ampligraph.datasets import GraphDataLoader
+ >>> from ampligraph.datasets.sqlite_adapter import SQLiteAdapter
+ >>> from ampligraph.datasets.graph_partitioner import NaiveGraphPartitioner
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> dataset = load_fb15k_237()
+ >>> dataset_loader = GraphDataLoader(dataset['train'],
+ >>> backend=SQLiteAdapter, # type of backend to use
+ >>> batch_size=2, # batch size to use while iterating over this dataset
+ >>> dataset_type='train', # dataset type
+ >>> use_filter=False, # Whether to use filter or not
+ >>> use_indexer=True) # indicates that the data needs to be mapped to index
+ >>> partitioner = NaiveGraphPartitioner(dataset_loader, k=2)
+ >>> # create and compile a model as usual
+ >>> partitioned_model = ScoringBasedEmbeddingModel(eta=2,
+ >>> k=50,
+ >>> scoring_type='DistMult')
+ >>>
+ >>> partitioned_model.compile(optimizer='adam', loss='multiclass_nll')
+ >>>
+ >>> partitioned_model.fit(partitioner, # the partitioner generates data for the model during training
+ >>> epochs=10) # number of epochs
+
+ """
+
+ def __init__(self, data, k=2, **kwargs):
+ """Initialise the NaiveGraphPartitioner.
+
+ Parameters
+ ----------
+ data: GraphDataLoader
+ Input data as a GraphDataLoader.
+ k: int
+ Number of partitions to split the edges (i.e., triples) into.
+
+ """
+ self.partitions = []
+ super().__init__(data, k, random=False, index_by="", **kwargs)
+
+
+@register_partitioning_strategy("SortedEdges", "GeneralPartitionDataManager")
+class SortedEdgesGraphPartitioner(EdgeBasedGraphPartitioner):
+ """Partitioning strategy that splits edges into equal size
+ partitions retrieving triples from the data ordered by subject.
+
+ Example
+ -------
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> from ampligraph.datasets import GraphDataLoader
+ >>> from ampligraph.datasets.sqlite_adapter import SQLiteAdapter
+ >>> from ampligraph.datasets.graph_partitioner import SortedEdgesGraphPartitioner
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> dataset = load_fb15k_237()
+ >>> dataset_loader = GraphDataLoader(dataset['train'],
+ >>> backend=SQLiteAdapter, # type of backend to use
+ >>> batch_size=2, # batch size to use while iterating over this dataset
+ >>> dataset_type='train', # dataset type
+ >>> use_filter=False, # Whether to use filter or not
+ >>> use_indexer=True) # indicates that the data needs to be mapped to index
+ >>> partitioner = SortedEdgesGraphPartitioner(dataset_loader, k=2)
+ >>> # create and compile a model as usual
+ >>> partitioned_model = ScoringBasedEmbeddingModel(eta=2,
+ >>> k=50,
+ >>> scoring_type='DistMult')
+ >>>
+ >>> partitioned_model.compile(optimizer='adam', loss='multiclass_nll')
+ >>>
+ >>> partitioned_model.fit(partitioner, # the partitioner generates data for the model during training
+ >>> epochs=10) # number of epochs
+
+ """
+
+ def __init__(self, data, k=2, **kwargs):
+ """Initialise the SortedEdgesGraphPartitioner.
+
+ Parameters
+ ----------
+ data: GraphDataLoader
+ Input data as a GraphDataLoader.
+ k: int
+ Number of partitions to split the edges (i.e., triples) into.
+
+ """
+
+ self.partitions = []
+ super().__init__(data, k, random=False, index_by="s", **kwargs)
+
+
+@register_partitioning_strategy(
+ "DoubleSortedEdges", "GeneralPartitionDataManager"
+)
+class DoubleSortedEdgesGraphPartitioner(EdgeBasedGraphPartitioner):
+ """Partitioning strategy that splits edges into equal size
+ partitions retrieving triples from the data ordered by subject and object.
+
+ Example
+ -------
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> from ampligraph.datasets import GraphDataLoader
+ >>> from ampligraph.datasets.sqlite_adapter import SQLiteAdapter
+ >>> from ampligraph.datasets.graph_partitioner import DoubleSortedEdgesGraphPartitioner
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> dataset = load_fb15k_237()
+ >>> dataset_loader = GraphDataLoader(dataset['train'],
+ >>> backend=SQLiteAdapter, # type of backend to use
+ >>> batch_size=2, # batch size to use while iterating over this dataset
+ >>> dataset_type='train', # dataset type
+ >>> use_filter=False, # Whether to use filter or not
+ >>> use_indexer=True) # indicates that the data needs to be mapped to index
+ >>> partitioner = DoubleSortedEdgesGraphPartitioner(dataset_loader, k=2)
+ >>> # create and compile a model as usual
+ >>> partitioned_model = ScoringBasedEmbeddingModel(eta=2,
+ >>> k=50,
+ >>> scoring_type='DistMult')
+ >>>
+ >>> partitioned_model.compile(optimizer='adam', loss='multiclass_nll')
+ >>>
+ >>> partitioned_model.fit(partitioner, # the partitioner generates data for the model during training
+ >>> epochs=10) # number of epochs
+
+ """
+
+ def __init__(self, data, k=2, **kwargs):
+ """Initialise the DoubleSortedEdgesGraphPartitioner.
+
+ Parameters
+ ----------
+ data: GraphDataLoader
+ Input data as a GraphDataLoader.
+ k: int
+ Number of partitions to split the edges (i.e., triples) into.
+
+ """
+ self.partitions = []
+ super().__init__(data, k, random=False, index_by="so", **kwargs)
+
+
+def main():
+ pass
+
+
+if __name__ == "__main__":
+ main()
diff --git a/ampligraph/datasets/numpy_adapter.py b/ampligraph/datasets/numpy_adapter.py
deleted file mode 100644
index 12228ee3..00000000
--- a/ampligraph/datasets/numpy_adapter.py
+++ /dev/null
@@ -1,214 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-import numpy as np
-from ..datasets import AmpligraphDatasetAdapter, SQLiteAdapter
-
-
-class NumpyDatasetAdapter(AmpligraphDatasetAdapter):
-
- def __init__(self):
- """Initialize the class variables
- """
- super(NumpyDatasetAdapter, self).__init__()
- # NumpyDatasetAdapter uses SQLAdapter to filter (if filters are set)
- self.filter_adapter = None
-
- def generate_mappings(self, use_all=False):
- """Generate mappings from either train set or use all dataset to generate mappings
-
- Parameters
- ----------
- use_all : boolean
- If True, it generates mapping from all the data. If False, it only uses training set to generate mappings
-
- Returns
- -------
- rel_to_idx : dictionary
- Relation to idx mapping dictionary
- ent_to_idx : dictionary
- entity to idx mapping dictionary
- """
- from ..evaluation import create_mappings
- if use_all:
- complete_dataset = []
- for key in self.dataset.keys():
- complete_dataset.append(self.dataset[key])
- self.rel_to_idx, self.ent_to_idx = create_mappings(np.concatenate(complete_dataset, axis=0))
-
- else:
- self.rel_to_idx, self.ent_to_idx = create_mappings(self.dataset["train"])
-
- return self.rel_to_idx, self.ent_to_idx
-
- def use_mappings(self, rel_to_idx, ent_to_idx):
- """Use an existing mapping with the datasource.
- """
- super().use_mappings(rel_to_idx, ent_to_idx)
-
- def get_size(self, dataset_type="train"):
- """Returns the size of the specified dataset
- Parameters
- ----------
- dataset_type : string
- type of the dataset
-
- Returns
- -------
- size : int
- size of the specified dataset
- """
- return self.dataset[dataset_type].shape[0]
-
- def data_exists(self, dataset_type="train"):
- """Checks if a dataset_type exists in the adapter.
- Parameters
- ----------
- dataset_type : string
- type of the dataset
-
- Returns
- -------
- exists : bool
- Boolean indicating if dataset_type exists in the adapter.
- """
-
- return dataset_type in self.dataset.keys()
-
- def get_next_batch(self, batches_count=-1, dataset_type="train", use_filter=False):
- """Generator that returns the next batch of data.
-
- Parameters
- ----------
- batches_count: int
- number of batches per epoch (default: -1, i.e. uses batch_size of 1)
- dataset_type: string
- indicates which dataset to use
- use_filter : bool
- Flag to indicate whether to return the concepts that need to be filtered
-
- Returns
- -------
- batch_output : nd-array
- yields a batch of triples from the dataset type specified
- participating_objects : nd-array [n,1]
- all objects that were involved in the s-p-? relation. This is returned only if use_filter is set to true.
- participating_subjects : nd-array [n,1]
- all subjects that were involved in the ?-p-o relation. This is returned only if use_filter is set to true.
- """
- # if data is not already mapped, then map before returning the batch
- if not self.mapped_status[dataset_type]:
- self.map_data()
-
- if batches_count == -1:
- batch_size = 1
- batches_count = self.get_size(dataset_type)
- else:
- batch_size = int(np.ceil(self.get_size(dataset_type) / batches_count))
-
- for i in range(batches_count):
- output = []
- out = np.int32(self.dataset[dataset_type][(i * batch_size):((i + 1) * batch_size), :])
- output.append(out)
-
- try:
- focusE_numeric_edge_values_batch = self.focusE_numeric_edge_values[
- dataset_type][(i * batch_size):((i + 1) * batch_size), :]
- output.append(focusE_numeric_edge_values_batch)
- except KeyError:
- pass
-
- if use_filter:
- # get the filter values by querying the database
- participating_objects, participating_subjects = self.filter_adapter.get_participating_entities(out)
- output.append(participating_objects)
- output.append(participating_subjects)
-
- yield output
-
- def map_data(self, remap=False):
- """map the data to the mappings of ent_to_idx and rel_to_idx
- Parameters
- ----------
- remap : boolean
- remap the data, if already mapped. One would do this if the dictionary is updated.
- """
- from ..evaluation import to_idx
- if len(self.rel_to_idx) == 0 or len(self.ent_to_idx) == 0:
- self.generate_mappings()
-
- for key in self.dataset.keys():
- if (not self.mapped_status[key]) or (remap is True):
- self.dataset[key] = to_idx(self.dataset[key],
- ent_to_idx=self.ent_to_idx,
- rel_to_idx=self.rel_to_idx)
- self.mapped_status[key] = True
-
- def _validate_data(self, data):
- """ Validates the data
- """
- if type(data) != np.ndarray:
- msg = 'Invalid type for input data. Expected ndarray, got {}'.format(type(data))
- raise ValueError(msg)
-
- if (np.shape(data)[1]) != 3:
- msg = 'Invalid size for input data. Expected number of column 3, got {}'.format(np.shape(data)[1])
- raise ValueError(msg)
-
- def set_data(self, dataset, dataset_type=None, mapped_status=False, focusE_numeric_edge_values=None):
- """set the dataset based on the type.
- Note: If you pass the same dataset type (which exists) it will be overwritten
-
- Parameters
- ----------
- dataset : nd-array or dictionary
- dataset of triples
- dataset_type : string
- if the dataset parameter is an nd- array then this indicates the type of the data being based
- mapped_status : bool
- indicates whether the data has already been mapped to the indices
- focusE_numeric_edge_values: nd-array
- list of all the numeric values associated the link
- """
- if isinstance(dataset, dict):
- for key in dataset.keys():
- self._validate_data(dataset[key])
- self.dataset[key] = dataset[key]
- self.mapped_status[key] = mapped_status
- if focusE_numeric_edge_values is not None:
- self.focusE_numeric_edge_values[key] = focusE_numeric_edge_values[key]
- elif dataset_type is not None:
- self._validate_data(dataset)
- self.dataset[dataset_type] = dataset
- self.mapped_status[dataset_type] = mapped_status
- if focusE_numeric_edge_values is not None:
- self.focusE_numeric_edge_values[dataset_type] = focusE_numeric_edge_values
- else:
- raise Exception("Incorrect usage. Expected a dictionary or a combination of dataset and it's type.")
-
- # If the concept-idx mappings are present, then map the passed dataset
- if not (len(self.rel_to_idx) == 0 or len(self.ent_to_idx) == 0):
- self.map_data()
-
- def set_filter(self, filter_triples, mapped_status=False):
- """set's the filter that need to be used while generating evaluation batch
- Note: This adapter uses SQL backend to do filtering
- Parameters
- ----------
- filter_triples : nd-array
- triples that would be used as filter
- """
- self.filter_adapter = SQLiteAdapter()
- self.filter_adapter.use_mappings(self.rel_to_idx, self.ent_to_idx)
- self.filter_adapter.set_data(filter_triples, "filter", mapped_status)
-
- def cleanup(self):
- """Cleans up the internal state.
- """
- if self.filter_adapter is not None:
- self.filter_adapter.cleanup()
- self.filter_adapter = None
diff --git a/ampligraph/datasets/oneton_adapter.py b/ampligraph/datasets/oneton_adapter.py
deleted file mode 100644
index a96f8912..00000000
--- a/ampligraph/datasets/oneton_adapter.py
+++ /dev/null
@@ -1,449 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-import numpy as np
-from ..datasets import NumpyDatasetAdapter
-import logging
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
-
-class OneToNDatasetAdapter(NumpyDatasetAdapter):
- r"""1-to-N Dataset Adapter.
-
- Given a triples dataset X comprised of n triples in the form (s, p, o), this dataset adapter will
- generate one-hot outputs for each (s, p) tuple to all entities o that are found in X.
-
- E.g: X = [[a, p, b],
- [a, p, d],
- [c, p, d],
- [c, p, e],
- [c, p, f]]
-
- Gives a one-hot vector mapping of entities to indices:
-
- Entities: [a, b, c, d, e, f]
- Indices: [0, 1, 2, 3, 4, 5]
-
- One-hot outputs are produced for each (s, p) tuple to all valid object indices in the dataset:
-
- # [a, b, c, d, e, f]
- (a, p) : [0, 1, 0, 1, 0, 0]
-
- The ```get_next_batch``` function yields the (s, p, o) triple and one-hot vector corresponding to the (s, p)
- tuple.
-
- If batches are generated with ```unique_pairs=True``` then only one instance of each unique (s, p) tuple
- is returned:
-
- (a, p) : [0, 1, 0, 1, 0, 0]
- (c, p) : [0, 0, 0, 1, 1, 1]
-
- Otherwise batch outputs are generated in dataset order (required for evaluating test set, but gives a higher
- weight to more frequent (s, p) pairs if used during model training):
-
- (a, p) : [0, 1, 0, 1, 0, 0]
- (a, p) : [0, 1, 0, 1, 0, 0]
- (c, p) : [0, 0, 0, 1, 1, 1]
- (c, p) : [0, 0, 0, 1, 1, 1]
- (c, p) : [0, 0, 0, 1, 1, 1]
-
- """
-
- def __init__(self, low_memory=False):
- """Initialize the class variables
-
- Parameters
- ----------
- low_memory : bool
- If low_memory flag set to True the output vectors indices are generated on-the-fly in the batch yield
- function, which lowers memory usage but increases training time.
-
- """
- super(OneToNDatasetAdapter, self).__init__()
-
- self.filter_mapping = None
- self.filtered_status = {}
- self.paired_status = {}
- self.output_mapping = None
- self.output_onehot = {}
- self.low_memory = low_memory
-
- def set_filter(self, filter_triples, mapped_status=False):
- """ Set the filter to be used while generating batch outputs.
-
- Parameters
- ----------
- filter_triples : nd-array
- Triples to be used as a filter.
- mapped_status : bool
- Bool indicating if filter has already been mapped to internal indices.
-
- """
-
- self.set_data(filter_triples, 'filter', mapped_status)
- self.filter_mapping = self.generate_output_mapping('filter')
-
- def generate_outputs(self, dataset_type='train', use_filter=False, unique_pairs=True):
- """Generate one-hot outputs for a dataset.
-
- Parameters
- ----------
- dataset_type : string
- Indicates which dataset to generate outputs for.
- use_filter : bool
- Bool indicating whether to generate outputs using the filter set by `set_filter()`. Default: False
- unique_pairs : bool
- Bool indicating whether to generate outputs according to unique pairs of (subject, predicate), otherwise
- will generate outputs in same row-order as the triples in the specified dataset. Default: True.
-
- """
-
- if dataset_type not in self.dataset.keys():
- msg = 'Unable to generate outputs: dataset `{}` not found. ' \
- 'Use `set_data` to set dataset in adapter first.'.format(dataset_type)
- raise KeyError(msg)
-
- if dataset_type in ['valid', 'test']:
- if unique_pairs:
- # This is just a friendly warning - in most cases the test and valid sets should NOT be unique_pairs.
- msg = 'Generating outputs for dataset `{}` with unique_pairs=True. ' \
- 'Are you sure this is desired behaviour?'.format(dataset_type)
- logger.warning(msg)
-
- if use_filter:
- if self.filter_mapping is None:
- msg = 'Filter not found: cannot generate one-hot outputs with `use_filter=True` ' \
- 'if a filter has not been set.'
- raise ValueError(msg)
- else:
- output_dict = self.filter_mapping
- else:
- if self.output_mapping is None:
- msg = 'Output mapping was not created before generating one-hot vectors. '
- raise ValueError(msg)
- else:
- output_dict = self.output_mapping
-
- if self.low_memory:
- # With low_memory=True the output indices are generated on the fly in the batch yield function
- pass
- else:
- if unique_pairs:
- X = np.unique(self.dataset[dataset_type][:, [0, 1]], axis=0).astype(np.int32)
- else:
- X = self.dataset[dataset_type]
-
- # Initialize np.array of shape [len(X), num_entities]
- self.output_onehot[dataset_type] = np.zeros((len(X), len(self.ent_to_idx)), dtype=np.int8)
-
- # Set one-hot indices using output_dict
- for i, x in enumerate(X):
- indices = output_dict.get((x[0], x[1]), [])
- self.output_onehot[dataset_type][i, indices] = 1
-
- # Set flags indicating filter and unique pair status of outputs for given dataset.
- self.filtered_status[dataset_type] = use_filter
- self.paired_status[dataset_type] = unique_pairs
-
- def generate_output_mapping(self, dataset_type='train'):
- """ Creates dictionary keyed on (subject, predicate) to list of objects
-
- Parameters
- ----------
- dataset_type : string
- Indicates which dataset to generate output mapping from.
-
- Returns
- -------
- dict
- """
-
- # if data is not already mapped, then map before creating output map
- if not self.mapped_status[dataset_type]:
- self.map_data()
-
- output_mapping = dict()
-
- for s, p, o in self.dataset[dataset_type]:
- output_mapping.setdefault((s, p), []).append(o)
-
- return output_mapping
-
- def set_output_mapping(self, output_dict, clear_outputs=True):
- """ Set the mapping used to generate one-hot outputs vectors.
-
- Setting a new output mapping will clear_outputs any previously generated outputs, as otherwise
- can lead to a situation where old outputs are returned from batch function.
-
- Parameters
- ----------
- output_dict : dict
- (subject, predicate) to object indices
- clear_outputs: bool
- Clears any one hot outputs held by the adapter, as otherwise can lead to a situation where onehot
- outputs generated by a different mapping are returned from the batch function. Default: True.
-
- """
-
- self.output_mapping = output_dict
-
- # Clear any onehot outputs previously generated
- if clear_outputs:
- self.clear_outputs()
-
- def clear_outputs(self, dataset_type=None):
- """ Clears generated one-hot outputs currently held by the adapter.
-
- Parameters
- ----------
- dataset_type: string
- indicates which dataset to clear_outputs. Default: None (clears all).
-
- """
-
- if dataset_type is None:
- self.output_onehot = {}
- self.filtered_status = {}
- self.paired_status = {}
- else:
- del self.output_onehot[dataset_type]
- del self.filtered_status[dataset_type]
- del self.paired_status[dataset_type]
-
- def verify_outputs(self, dataset_type, use_filter, unique_pairs):
- """Verifies if one-hot outputs currently held in adapter correspond to the use_filter and unique_pairs
- options.
-
- Parameters
- ----------
- dataset_type: string
- indicates which dataset to use
- use_filter : bool
- Flag to indicate whether the one-hot outputs are generated from filtered or unfiltered datasets
- unique_pairs : bool
- Flag to indicate whether the one-hot outputs are generated by unique (s, p) pairs or in dataset order.
-
- Returns
- -------
- bool
- If False then outputs must be re-generated for the specified dataset and parameters.
-
- """
-
- if dataset_type not in self.output_onehot.keys():
- # One-hot outputs have not been generated for this dataset_type
- return False
-
- if dataset_type not in self.filtered_status.keys():
- # This shouldn't happen.
- logger.debug('Dataset {} is in adapter, but filtered_status is not set.'.format(dataset_type))
- return False
-
- if dataset_type not in self.paired_status.keys():
- logger.debug('Dataset {} is in adapter, but paired_status is not set.'.format(dataset_type))
- return False
-
- if use_filter != self.filtered_status[dataset_type]:
- return False
-
- if unique_pairs != self.paired_status[dataset_type]:
- return False
-
- return True
-
- def get_next_batch(self, batches_count=-1, dataset_type='train', use_filter=False, unique_pairs=True):
- """Generator that returns the next batch of data.
-
- Parameters
- ----------
- batches_count: int
- number of batches per epoch (default: -1, i.e. uses batch_size of 1)
- dataset_type: string
- indicates which dataset to use
- use_filter : bool
- Flag to indicate whether the one-hot outputs are generated from filtered or unfiltered datasets
- unique_pairs : bool
- Flag to indicate whether the one-hot outputs are generated by unique (s, p) pairs or in dataset order.
-
- Returns
- -------
- batch_output : nd-array, shape=[batch_size, 3]
- A batch of triples from the dataset type specified. If unique_pairs=True, then the object column
- will be set to zeros.
- batch_onehot : nd-array
- A batch of onehot arrays corresponding to `batch_output` triples
- """
-
- # if data is not already mapped, then map before returning the batch
- if not self.mapped_status[dataset_type]:
- self.map_data()
-
- if unique_pairs:
- X = np.unique(self.dataset[dataset_type][:, [0, 1]], axis=0).astype(np.int32)
- X = np.c_[X, np.zeros(len(X))] # Append dummy object columns
- else:
- X = self.dataset[dataset_type]
- dataset_size = len(X)
-
- if batches_count == -1:
- batch_size = 1
- batches_count = dataset_size
- else:
- batch_size = int(np.ceil(dataset_size / batches_count))
-
- if use_filter and self.filter_mapping is None:
- msg = 'Cannot set `use_filter=True` if a filter has not been set in the adapter. '
- logger.error(msg)
- raise ValueError(msg)
-
- if not self.low_memory:
-
- if not self.verify_outputs(dataset_type, use_filter=use_filter, unique_pairs=unique_pairs):
- # Verifies that onehot outputs are as expected given filter and unique_pair settings
- msg = 'Generating one-hot outputs for {} [filtered: {}, unique_pairs: {}]'\
- .format(dataset_type, use_filter, unique_pairs)
- logger.info(msg)
- self.generate_outputs(dataset_type, use_filter=use_filter, unique_pairs=unique_pairs)
-
- # Yield batches
- for i in range(batches_count):
-
- out = np.int32(X[(i * batch_size):((i + 1) * batch_size), :])
- out_onehot = self.output_onehot[dataset_type][(i * batch_size):((i + 1) * batch_size), :]
-
- yield out, out_onehot
-
- else:
- # Low-memory, generate one-hot outputs per batch on the fly
- if use_filter:
- output_dict = self.filter_mapping
- else:
- output_dict = self.output_mapping
-
- # Yield batches
- for i in range(batches_count):
-
- out = np.int32(X[(i * batch_size):((i + 1) * batch_size), :])
- out_onehot = np.zeros(shape=[out.shape[0], len(self.ent_to_idx)], dtype=np.int32)
-
- for j, x in enumerate(out):
- indices = output_dict.get((x[0], x[1]), [])
- out_onehot[j, indices] = 1
-
- yield out, out_onehot
-
- def get_next_batch_subject_corruptions(self, batch_size=-1, dataset_type='train', use_filter=True):
- """Batch generator for subject corruptions.
-
- To avoid multiple redundant forward-passes through the network, subject corruptions are performed once for
- each relation, and results accumulated for valid test triples.
-
- If there are no test triples for a relation, then that relation is ignored.
-
- Use batch_size to control memory usage (as a batch_size*N tensor will be allocated, where N is number
- of unique entities.)
-
- Parameters
- ----------
- batch_size: int
- Maximum batch size returned.
- dataset_type: string
- indicates which dataset to use
- use_filter : bool
- Flag to indicate whether to return the one-hot outputs are generated from filtered or unfiltered datasets
-
- Returns
- -------
-
- test_triples : nd-array of shape (?, 3)
- The set of all triples from the dataset type specified that include the predicate currently returned
- in batch_triples.
- batch_triples : nd-array of shape (M, 3), where M is the subject corruption batch size.
- A batch of triples corresponding to subject corruptions of just one predicate.
- batch_onehot : nd-array of shape (M, N), where N is number of unique entities.
- A batch of onehot arrays corresponding to the batch_triples output.
-
- """
-
- if use_filter:
- output_dict = self.filter_mapping
- else:
- output_dict = self.output_mapping
-
- if batch_size == -1:
- batch_size = len(self.ent_to_idx)
-
- ent_list = np.array(list(self.ent_to_idx.values()))
- rel_list = np.array(list(self.rel_to_idx.values()))
-
- for rel in rel_list:
-
- # Select test triples that have this relation
- rel_idx = self.dataset[dataset_type][:, 1] == rel
- test_triples = self.dataset[dataset_type][rel_idx]
-
- ent_idx = 0
-
- while ent_idx < len(ent_list):
-
- ents = ent_list[ent_idx:ent_idx + batch_size]
- ent_idx += batch_size
-
- # Note: the object column is just a dummy value so set to 0
- out = np.stack([ents, np.repeat(rel, len(ents)), np.repeat(0, len(ents))], axis=1)
-
- # Set one-hot filter
- out_filter = np.zeros((out.shape[0], len(ent_list)), dtype=np.int8)
- for j, x in enumerate(out):
- indices = output_dict.get((x[0], x[1]), [])
- out_filter[j, indices] = 1
-
- yield test_triples, out, out_filter
-
- def _validate_data(self, data):
- """Validates the data
- """
- if type(data) != np.ndarray:
- msg = 'Invalid type for input data. Expected ndarray, got {}'.format(type(data))
- raise ValueError(msg)
-
- if (np.shape(data)[1]) != 3:
- msg = 'Invalid size for input data. Expected number of column 3, got {}'.format(np.shape(data)[1])
- raise ValueError(msg)
-
- def set_data(self, dataset, dataset_type=None, mapped_status=False):
- """Set the dataset based on the type.
-
- Note: If you pass the same dataset type (which exists) it will be overwritten
-
- Parameters
- ----------
- dataset : nd-array or dictionary
- dataset of triples
- dataset_type : string
- if the dataset parameter is an nd- array then this indicates the type of the data being based
- mapped_status : bool
- indicates whether the data has already been mapped to the indices
-
- """
- if isinstance(dataset, dict):
- for key in dataset.keys():
- self._validate_data(dataset[key])
- self.dataset[key] = dataset[key]
- self.mapped_status[key] = mapped_status
- elif dataset_type is not None:
- self._validate_data(dataset)
- self.dataset[dataset_type] = dataset
- self.mapped_status[dataset_type] = mapped_status
- else:
- raise Exception("Incorrect usage. Expected a dictionary or a combination of dataset and it's type.")
-
- # If the concept-idx mappings are present, then map the passed dataset
- if not (len(self.rel_to_idx) == 0 or len(self.ent_to_idx) == 0):
- print('Mapping set data: {}'.format(dataset_type))
- self.map_data()
diff --git a/ampligraph/datasets/partitioned_data_manager.py b/ampligraph/datasets/partitioned_data_manager.py
new file mode 100644
index 00000000..90cbedf8
--- /dev/null
+++ b/ampligraph/datasets/partitioned_data_manager.py
@@ -0,0 +1,993 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+import abc
+import glob
+import logging
+import os
+import shelve
+import shutil
+import tempfile
+from datetime import datetime
+
+import numpy as np
+import tensorflow as tf
+
+from .graph_partitioner import (
+ PARTITION_ALGO_REGISTRY,
+ AbstractGraphPartitioner,
+)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+PARTITION_MANAGER_REGISTRY = {}
+
+
+def register_partitioning_manager(name):
+ """Decorator responsible for registering partition manager in the partition manager registry.
+
+ Parameters
+ ----------
+ name: str
+ Name of the new partition manager.
+
+ Example
+ -------
+ >>>@register_partitioning_manager("NewManagerName")
+ >>>class NewManagerName(PartitionDataManager):
+ >>>... pass
+ """
+
+ def insert_in_registry(class_handle):
+ """Checks if partition manager already exists and if not registers it."""
+ if name in PARTITION_MANAGER_REGISTRY.keys():
+ msg = "Partitioning Manager with name {} already exists!".format(
+ name
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ PARTITION_MANAGER_REGISTRY[name] = class_handle
+ class_handle.name = name
+ return class_handle
+
+ return insert_in_registry
+
+
+class PartitionDataManager(abc.ABC):
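+ """Base class for partition data managers.
+
+ A partition data manager wraps a graph partitioner, persists per-partition entity/relation
+ embeddings and optimizer state in shelve files, and feeds the model one partition at a time
+ during training.
+ """
+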
+ def __init__(
+ self,
+ dataset_loader,
+ model,
+ strategy="Bucket",
+ partitioner_k=3,
+ root_directory=None,
+ ent_map_fname=None,
+ ent_meta_fname=None,
+ rel_map_fname=None,
+ rel_meta_fname=None,
+ ):
+ """Initializes the Partitioning Data Manager.
+
+ Uses/Creates partitioner and generates and manages partition related parameters.
+ This is the base class.
+
+ Parameters
+ ----------
+ dataset_loader : AbstractGraphPartitioner or GraphDataLoader
+ Either an instance of AbstractGraphPartitioner or GraphDataLoader.
+ model : tf.keras.Model
+ The model that is being trained.
+ strategy : string
+ Type of partitioning strategy to use.
+ root_directory : str
+ Directory where the partition manager files will be stored.
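+ partitioner_k : int
+ Number of partitions the data is split into (used only when a partitioner has to be created).
+ ent_map_fname, ent_meta_fname, rel_map_fname, rel_meta_fname : str
+ Optional paths to existing shelve files holding entity/relation partition embeddings and metadata;
+ when all four are provided they are reused instead of being created from scratch.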
+ """
+ self._model = model
+ self.k = self._model.k
+ self.internal_k = self._model.internal_k
+ self.eta = self._model.eta
+ self.partitioner_k = partitioner_k
+ if (
+ ent_map_fname is not None
+ and ent_meta_fname is not None
+ and rel_map_fname is not None
+ and rel_meta_fname is not None
+ ):
+ self.root_directory = os.path.dirname(ent_map_fname)
+ self.timestamp = os.path.basename(ent_map_fname).split("_")[-1][
+ :-4
+ ]
+ self.ent_map_fname = ent_map_fname
+ self.ent_meta_fname = ent_meta_fname
+ self.rel_map_fname = rel_map_fname
+ self.rel_meta_fname = rel_meta_fname
+
+ else:
+ if root_directory is not None:
+ self.root_directory = root_directory
+ else:
+ self.root_directory = tempfile.gettempdir()
+ self.timestamp = datetime.now().strftime("%d-%m-%Y_%I-%M-%S_%f_%p")
+ self.ent_map_fname = os.path.join(
+ self.root_directory, "ent_partition_{}".format(self.timestamp)
+ )
+ self.ent_meta_fname = os.path.join(
+ self.root_directory, "ent_metadata_{}".format(self.timestamp)
+ )
+ self.rel_map_fname = os.path.join(
+ self.root_directory, "rel_partition_{}".format(self.timestamp)
+ )
+ self.rel_meta_fname = os.path.join(
+ self.root_directory, "rel_metadata_{}".format(self.timestamp)
+ )
+
+ if isinstance(dataset_loader, AbstractGraphPartitioner):
+ self.partitioner = dataset_loader
+ self.partitioner_k = self.partitioner._k
+ else:
+ print("Partitioning may take a while...")
+ self.partitioner = PARTITION_ALGO_REGISTRY.get(strategy)(
+ dataset_loader, k=self.partitioner_k
+ )
+
+ self.num_ents = (
+ self.partitioner._data.backend.mapper.backend.ents_length
+ )
+ self.num_rels = (
+ self.partitioner._data.backend.mapper.backend.rels_length
+ )
+ self.max_ent_size = 0
+ for i in range(len(self.partitioner.partitions)):
+ self.max_ent_size = max(
+ self.max_ent_size,
+ self.partitioner.partitions[
+ i
+ ].backend.mapper.backend.ents_length,
+ )
+
+ self._generate_partition_params()
+
+ def _copy_files(self, base):
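+ """Copy all files whose names start with `base` into the manager's root directory."""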
+ for file in glob.glob(base + "*"):
+ shutil.copy(file, self.root_directory)
+
+ def get_update_metadata(self, filepath):
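+ """Point the manager at `filepath`, copy the partition shelve files there and return
+ a metadata dictionary with the updated paths and `partitioner_k`."""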
+ self.root_directory = filepath
+ self.root_directory = (
+ "." if self.root_directory == "" else self.root_directory
+ )
+
+ # new_file_name = os.path.join(self.root_directory, '*{}.bak'.format(self.timestamp))
+ try:
+ self._copy_files(self.ent_map_fname)
+ self._copy_files(self.ent_meta_fname)
+ self._copy_files(self.rel_map_fname)
+ self._copy_files(self.rel_meta_fname)
+ self.ent_map_fname = os.path.join(
+ self.root_directory, "ent_partition_{}".format(self.timestamp)
+ )
+ self.ent_meta_fname = os.path.join(
+ self.root_directory, "ent_metadata_{}".format(self.timestamp)
+ )
+ self.rel_map_fname = os.path.join(
+ self.root_directory, "rel_partition_{}".format(self.timestamp)
+ )
+ self.rel_meta_fname = os.path.join(
+ self.root_directory, "rel_metadata_{}".format(self.timestamp)
+ )
+ except shutil.SameFileError:
+ pass
+
+ metadata = {
+ "root_directory": self.root_directory,
+ "partitioner_k": self.partitioner_k,
+ "ent_map_fname": self.ent_map_fname,
+ "ent_meta_fname": self.ent_meta_fname,
+ "rel_map_fname": self.rel_map_fname,
+ "rel_meta_fname": self.rel_meta_fname,
+ }
+ return metadata
+
+ @property
+ def max_entities(self):
+ """Returns the maximum entity size that can occur in a partition."""
+ return self.max_ent_size
+
+ @property
+ def max_relations(self):
+ """Returns the maximum relation size that can occur in a partition."""
+ return self.num_rels
+
+ def _generate_partition_params(self):
+ """Generates the metadata needed for persisting and loading partition embeddings and other parameters."""
+ raise NotImplementedError("Abstract method not implemented")
+
+ def _update_partion_embeddings(self, graph_data_loader, partition_number):
+ """Persists the embeddings and other parameters after a partition is trained.
+
+ Parameters
+ ----------
+ graph_data_loader : GraphDataLoader
+ Data loader of the current partition that was trained.
+ partition_number: int
+ Partition number of the current partition that was trained.
+ """
+ raise NotImplementedError("Abstract method not implemented")
+
+ def _change_partition(self, graph_data_loader, partition_number):
+ """Gets a new partition to train and loads all the parameters of the partition.
+
+ Parameters
+ ----------
+ graph_data_loader : GraphDataLoader
+ Data loader of the next partition that will be trained.
+ partition_number: int
+ Partition number of the next partition that will be trained.
+ """
+ raise NotImplementedError("Abstract method not implemented")
+
+ def data_generator(self):
+ """Generates the data to be trained from the current partition.
+
+ Once the partition data is exhausted, the current parameters are persisted; the partition is changed
+ and the model is notified.
+
+ Yields
+ ------
+ batch_data_from_current_partition: array of shape (n,3)
+ A batch of triples from the current partition being trained.
+ """
+ for i, partition_data in enumerate(self.partitioner):
+ # partition_data is an object of graph data loader
+ # Perform tasks related to change of partition
+ self._change_partition(partition_data, i)
+ try:
+ while True:
+ # generate data from the current partition
+ batch_data_from_current_partition = next(partition_data)
+ yield batch_data_from_current_partition
+
+ except StopIteration:
+ # No more data in current partition (parsed fully once), so the partition is trained
+ # Hence persist the params related to the current partition.
+ self._update_partion_embeddings(partition_data, i)
+
+ def get_tf_generator(self):
+ """Returns tensorflow data generator."""
+ return tf.data.Dataset.from_generator(
+ self.data_generator,
+ output_types=tf.dtypes.int32,
+ output_shapes=(None, 3),
+ ).prefetch(0)
+
+ def __iter__(self):
+ """Function needed to be used as an iterator."""
+ return self
+
+ def __next__(self):
+ """Function needed to be used as an iterator."""
+ return next(self.batch_iterator)
+
+ def reload(self):
+ """Reload the data for the next epoch."""
+ self.partitioner.reload()
+ self.batch_iterator = iter(self.data_generator())
+
+ def on_epoch_end(self):
+ """Activities to be performed at the end of an epoch."""
+ pass
+
+ def on_complete(self):
+ """Activities to be performed at the end of training.
+
+ The manager persists the data (splits the entity partitions into individual embeddings).
+ """
+ pass
+
+
+@register_partitioning_manager("GeneralPartitionDataManager")
+class GeneralPartitionDataManager(PartitionDataManager):
+ """Manages the partitioning related controls.
+
+ Handles data generation and informs the model about changes in partition.
+ """
+
+ def __init__(
+ self,
+ dataset_loader,
+ model,
+ strategy="RandomEdges",
+ partitioner_k=3,
+ root_directory=None,
+ ):
+ """Initialize the Partitioning Data Manager.
+
+ Uses/Creates partitioner and generates partition related parameters.
+
+ Parameters
+ ----------
+ dataset_loader : AbstractGraphPartitioner or GraphDataLoader
+ Either an instance of AbstractGraphPartitioner or GraphDataLoader.
+ model: tf.keras.Model
+ The model that is being trained.
+ strategy: str
+            Type of partitioning strategy to use.
+        partitioner_k: int
+            Number of partitions the data is split into.
+        root_directory: str
+            Directory where the partition manager files will be stored.
+ """
+ super(GeneralPartitionDataManager, self).__init__(
+ dataset_loader, model, strategy, partitioner_k, root_directory
+ )
+
+ def _generate_partition_params(self):
+ """Generates the metadata needed for persisting and loading partition embeddings and other parameters."""
+
+ # create entity embeddings and optimizer hyperparams for all entities
+
+ # compute each partition size
+ update_part_size = int(np.ceil(self.num_ents / self.partitioner_k))
+ num_optimizer_hyperparams = (
+ self._model.optimizer.get_hyperparam_count()
+ )
+ # for each partition
+ for part_num in range(self.partitioner_k):
+ with shelve.open(
+ self.ent_map_fname, writeback=True
+ ) as ent_partition:
+ # create the key (entity index) and value (optim params and
+ # embs)
+ for i in range(
+ update_part_size * part_num,
+ min(update_part_size * (part_num + 1), self.num_ents),
+ ):
+ out_dict_key = str(i)
+ opt_param = np.zeros(
+ shape=(1, num_optimizer_hyperparams, self.internal_k),
+ dtype=np.float32,
+ )
+ # ent_emb = xavier(self.num_ents, self.internal_k, num_ents_bucket)
+ ent_emb = self._model.encoding_layer.ent_init(
+ shape=(1, self.internal_k), dtype=tf.float32
+ ).numpy()
+ ent_partition.update({out_dict_key: [opt_param, ent_emb]})
+
+ # create relation embeddings and optimizer hyperparams for all relations
+ # relations are not partitioned
+ with shelve.open(self.rel_map_fname, writeback=True) as rel_partition:
+ for i in range(self.num_rels):
+ out_dict_key = str(i)
+ # TODO change the hardcoding from 3 to actual hyperparam of
+ # optim
+ opt_param = np.zeros(
+ shape=(1, num_optimizer_hyperparams, self.internal_k),
+ dtype=np.float32,
+ )
+ # rel_emb = xavier(self.num_rels, self.internal_k, self.num_rels)
+ rel_emb = self._model.encoding_layer.rel_init(
+ shape=(1, self.internal_k), dtype=tf.float32
+ ).numpy()
+ rel_partition.update({out_dict_key: [opt_param, rel_emb]})
+
+ def _update_partion_embeddings(self, graph_data_loader, partition_number):
+ """Persists the embeddings and other parameters after a partition is trained.
+
+ Parameters
+ ----------
+ graph_data_loader : GraphDataLoader
+ Data loader of the current partition that was trained.
+ partition_number: int
+ Partition number of the current partition that was trained.
+ """
+ # set the trained params back for persisting (exclude paddings)
+ self.all_ent_embs = self._model.encoding_layer.ent_emb.numpy()[
+ : len(self.ent_original_ids), :
+ ]
+ self.all_rel_embs = self._model.encoding_layer.rel_emb.numpy()[
+ : len(self.rel_original_ids), :
+ ]
+
+ # get the optimizer params related to the embeddings
+ (
+ ent_opt_hyperparams,
+ rel_opt_hyperparams,
+ ) = self._model.optimizer.get_entity_relation_hyperparams()
+
+ # get the number of params that are created by the optimizer
+ num_opt_hyperparams = self._model.optimizer.get_hyperparam_count()
+
+ # depending on optimizer, you can have 0 or more params
+ if num_opt_hyperparams > 0:
+ # store the params
+ original_ent_hyperparams = []
+ original_rel_hyperparams = []
+
+ # get all the different params related to entities and relations
+ # eg: beta1, beta2 related to embeddings (when using adam)
+ for i in range(num_opt_hyperparams):
+ original_ent_hyperparams.append(
+ ent_opt_hyperparams[i][: len(self.ent_original_ids)]
+ )
+ original_rel_hyperparams.append(
+ rel_opt_hyperparams[i][: len(self.rel_original_ids)]
+ )
+
+            # store for persistence
+ self.all_rel_opt_params = np.stack(original_rel_hyperparams, 1)
+ self.all_ent_opt_params = np.stack(original_ent_hyperparams, 1)
+
+ # Open the buckets related to the partition and concat
+
+ try:
+ # persist entity related embs and optim params
+ ent_partition = shelve.open(self.ent_map_fname, writeback=True)
+ for i, key in enumerate(self.ent_original_ids):
+ ent_partition[str(key)] = [
+ self.all_ent_opt_params[i: i + 1],
+ self.all_ent_embs[i: i + 1],
+ ]
+
+ finally:
+ ent_partition.close()
+
+ try:
+ # persist relation related embs and optim params
+ rel_partition = shelve.open(self.rel_map_fname, writeback=True)
+ for i, key in enumerate(self.rel_original_ids):
+ rel_partition[str(key)] = [
+ self.all_rel_opt_params[i: i + 1],
+ self.all_rel_embs[i: i + 1],
+ ]
+
+ finally:
+ rel_partition.close()
+
+ def _change_partition(self, graph_data_loader, partition_number):
+ """Gets a new partition to train and loads all the parameters of the partition.
+
+ Parameters
+ ----------
+ graph_data_loader : GraphDataLoader
+ Data loader of the next partition that will be trained.
+ partition_number: int
+            Partition number of the next partition that will be trained.
+ """
+ # from the graph data loader of the current partition get the original
+ # entity ids
+ ent_count_in_partition = (
+ graph_data_loader.backend.mapper.get_entities_count()
+ )
+ self.ent_original_ids = graph_data_loader.backend.mapper.get_indexes(
+ np.arange(ent_count_in_partition), type_of="e", order="ind2raw"
+ )
+ """
+ with shelve.open(graph_data_loader.backend.mapper.entities_dict) as partition:
+ # get the partition keys(remapped 0 - partition size)
+ partition_keys = sorted([int(key) for key in partition.keys()])
+ # get the original key's i.e. original entity ids (between 0 and total entities in dataset)
+ self.ent_original_ids = [partition[str(key)] for key in partition_keys]
+ """
+ with shelve.open(self.ent_map_fname) as partition:
+ self.all_ent_embs = []
+ self.all_ent_opt_params = []
+ for key in self.ent_original_ids:
+ self.all_ent_opt_params.append(partition[key][0])
+ self.all_ent_embs.append(partition[key][1])
+ self.all_ent_embs = np.concatenate(self.all_ent_embs, 0)
+ self.all_ent_opt_params = np.concatenate(
+ self.all_ent_opt_params, 0
+ )
+
+ rel_count_in_partition = (
+ graph_data_loader.backend.mapper.get_relations_count()
+ )
+ self.rel_original_ids = graph_data_loader.backend.mapper.get_indexes(
+ np.arange(rel_count_in_partition), type_of="r", order="ind2raw"
+ )
+ """
+ with shelve.open(graph_data_loader.backend.mapper.relations_dict) as partition:
+ partition_keys = sorted([int(key) for key in partition.keys()])
+ self.rel_original_ids = [partition[str(key)] for key in partition_keys]
+ """
+
+ with shelve.open(self.rel_map_fname) as partition:
+ self.all_rel_embs = []
+ self.all_rel_opt_params = []
+ for key in self.rel_original_ids:
+ self.all_rel_opt_params.append(partition[key][0])
+ self.all_rel_embs.append(partition[key][1])
+ self.all_rel_embs = np.concatenate(self.all_rel_embs, 0)
+ self.all_rel_opt_params = np.concatenate(
+ self.all_rel_opt_params, 0
+ )
+
+ # notify the model about the partition change
+ self._model.partition_change_updates(
+ len(self.ent_original_ids), self.all_ent_embs, self.all_rel_embs
+ )
+
+ # Optimizer params will exist only after it has been persisted once
+ if self._model.current_epoch > 1:
+ # TODO: needs to be better handled
+ # get the optimizer params of the embs that will be trained
+ rel_optim_hyperparams = []
+ ent_optim_hyperparams = []
+
+ num_opt_hyperparams = self._model.optimizer.get_hyperparam_count()
+ for i in range(num_opt_hyperparams):
+ rel_hyperparam_i = self.all_rel_opt_params[:, i, :]
+ rel_hyperparam_i = np.pad(
+ rel_hyperparam_i,
+ ((0, self.num_rels - rel_hyperparam_i.shape[0]), (0, 0)),
+ "constant",
+ constant_values=(0),
+ )
+ rel_optim_hyperparams.append(rel_hyperparam_i)
+
+ ent_hyperparam_i = self.all_ent_opt_params[:, i, :]
+ ent_hyperparam_i = np.pad(
+ ent_hyperparam_i,
+ (
+ (0, self.max_ent_size - ent_hyperparam_i.shape[0]),
+ (0, 0),
+ ),
+ "constant",
+ constant_values=(0),
+ )
+ ent_optim_hyperparams.append(ent_hyperparam_i)
+
+ # notify the optimizer and update the optimizer hyperparams
+ self._model.optimizer.set_entity_relation_hyperparams(
+ ent_optim_hyperparams, rel_optim_hyperparams
+ )
+
+ def on_complete(self):
+ """Activities to be performed at the end of training.
+
+ The manager persists the data (splits the entity partitions into individual embeddings).
+ """
+ update_part_size = int(np.ceil(self.num_ents / self.partitioner_k))
+ for part_num in range(self.partitioner_k):
+ with shelve.open(
+ self.ent_map_fname, writeback=True
+ ) as ent_partition:
+ for i in range(
+ update_part_size * part_num,
+ min(update_part_size * (part_num + 1), self.num_ents),
+ ):
+ ent_partition[str(i)] = ent_partition[str(i)][1][0]
+
+ # create relation embeddings and optimizer hyperparams for all relations
+ # relations are not partitioned
+ with shelve.open(self.rel_map_fname, writeback=True) as rel_partition:
+ for i in range(self.num_rels):
+ rel_partition[str(i)] = rel_partition[str(i)][1][0]
+
+
+@register_partitioning_manager("BucketPartitionDataManager")
+class BucketPartitionDataManager(PartitionDataManager):
+ """Manages the partitioning related controls.
+
+    Handles data generation and informs the model about changes in partition.
+ """
+
+ def __init__(
+ self,
+ dataset_loader,
+ model,
+ strategy="Bucket",
+ partitioner_k=3,
+ root_directory=None,
+ ):
+ """Initialize the Partitioning Data Manager.
+ Uses/Creates partitioner and generates partition related parameters.
+
+ Parameters
+ ----------
+ dataset_loader : AbstractGraphPartitioner or GraphDataLoader
+ Either an instance of AbstractGraphPartitioner or GraphDataLoader.
+ model: tf.keras.Model
+ The model that is being trained.
+ strategy: str
+            Type of partitioning strategy to use.
+        partitioner_k: int
+            Number of buckets to split entities into.
+        root_directory: str
+ Directory where the partition manager files will be stored.
+ """
+ super(BucketPartitionDataManager, self).__init__(
+ dataset_loader, model, strategy, partitioner_k, root_directory
+ )
+
+ def _generate_partition_params(self):
+ """Generates the metadata needed for persisting and loading partition embeddings and other parameters."""
+
+ num_optimizer_hyperparams = (
+ self._model.optimizer.get_hyperparam_count()
+ )
+
+ # create entity embeddings and optimizer hyperparams for all entities
+ for i in range(self.partitioner_k):
+ with shelve.open(
+ self.ent_map_fname, writeback=True
+ ) as ent_partition:
+ with shelve.open(self.partitioner.files[i]) as bucket:
+ out_dict_key = str(i)
+ num_ents_bucket = bucket["indexes"].shape[0]
+ # print(num_ents_bucket)
+ # TODO change the hardcoding from 3 to actual hyperparam of
+ # optim
+ opt_param = np.zeros(
+ shape=(
+ num_ents_bucket,
+ num_optimizer_hyperparams,
+ self.internal_k,
+ ),
+ dtype=np.float32,
+ )
+ ent_emb = self._model.encoding_layer.ent_init(
+ shape=(num_ents_bucket, self.internal_k),
+ dtype=tf.float32,
+ ).numpy()
+ ent_partition.update({out_dict_key: [opt_param, ent_emb]})
+
+ # create relation embeddings and optimizer hyperparams for all relations
+ # relations are not partitioned
+ with shelve.open(self.rel_map_fname, writeback=True) as rel_partition:
+ out_dict_key = str(0)
+ # TODO change the hardcoding from 3 to actual hyperparam of optim
+ opt_param = np.zeros(
+ shape=(
+ self.num_rels,
+ num_optimizer_hyperparams,
+ self.internal_k,
+ ),
+ dtype=np.float32,
+ )
+ rel_emb = self._model.encoding_layer.rel_init(
+ shape=(self.num_rels, self.internal_k), dtype=tf.float32
+ ).numpy()
+ rel_partition.update({out_dict_key: [opt_param, rel_emb]})
+
+ # for every partition
+ for i in range(len(self.partitioner.partitions)):
+ # get the source and dest bucket
+ # print(self.partitioner.partitions[i].backend.mapper.metadata['name'])
+ splits = (
+ self.partitioner.partitions[i]
+ .backend.mapper.metadata["name"]
+ .split("-")
+ )
+ source_bucket = splits[0][-1]
+ dest_bucket = splits[1]
+ all_keys_merged_buckets = []
+ # get all the unique entities present in the buckets
+ with shelve.open(
+ self.partitioner.files[int(source_bucket)]
+ ) as bucket:
+ all_keys_merged_buckets.extend(bucket["indexes"])
+ if source_bucket != dest_bucket:
+ with shelve.open(
+ self.partitioner.files[int(dest_bucket)]
+ ) as bucket:
+ all_keys_merged_buckets.extend(bucket["indexes"])
+
+ # since we would be concatenating the bucket embeddings, let's find what 0, 1, 2 etc indices of
+ # embedding matrix means.
+ # bucket entity value to ent_emb matrix index mappings eg: 2001 ->
+ # 0, 2002->1, 2003->2, ...
+ merged_bucket_to_ent_mat_mappings = {}
+ for key, val in zip(
+ all_keys_merged_buckets,
+ np.arange(0, len(all_keys_merged_buckets)),
+ ):
+ merged_bucket_to_ent_mat_mappings[key] = val
+ emb_mat_order = []
+
+ # partitions do not contain all entities of the bucket they belong to.
+ # they will produce data from 0->n idx. So we need to remap the get position of the
+ # entities of the partition in the concatenated emb matrix
+ # data_index -> original_ent_index -> ent_emb_matrix mappings (a->b->c) 0->2002->1, 1->2003->2
+ # (because 2001 may not exist in this partition)
+ # a->b mapping
+ num_ents_bucket = self.partitioner.partitions[
+ i
+ ].backend.mapper.get_entities_count()
+ sorted_partition_keys = np.arange(num_ents_bucket)
+ sorted_partition_values = self.partitioner.partitions[
+ i
+ ].backend.mapper.get_indexes(
+ sorted_partition_keys, type_of="e", order="ind2raw"
+ )
+ for val in sorted_partition_values:
+ # a->b->c mapping
+ emb_mat_order.append(
+ merged_bucket_to_ent_mat_mappings[int(val)]
+ )
+
+ # store it
+ with shelve.open(self.ent_meta_fname, writeback=True) as metadata:
+ metadata[str(i)] = emb_mat_order
+
+ rel_mat_order = []
+ # with
+ # shelve.open(self.partitioner.partitions[i].backend.mapper.metadata['relations'])
+ # as rel_sh:
+ num_rels_bucket = self.partitioner.partitions[
+ i
+ ].backend.mapper.get_relations_count()
+ sorted_partition_keys = np.arange(num_rels_bucket)
+ sorted_partition_values = self.partitioner.partitions[
+ i
+ ].backend.mapper.get_indexes(
+ sorted_partition_keys, type_of="r", order="ind2raw"
+ )
+ # a : 0 to n
+ for val in sorted_partition_values:
+ # a->b mapping
+ rel_mat_order.append(int(val))
+
+ with shelve.open(self.rel_meta_fname, writeback=True) as metadata:
+ metadata[str(i)] = rel_mat_order
+
+ def _update_partion_embeddings(self, graph_data_loader, partition_number):
+ """Persists the embeddings and other parameters after a partition is trained.
+
+ Parameters
+ ----------
+ graph_data_loader : GraphDataLoader
+ Data loader of the current partition that was trained.
+ partition_number: int
+ Partition number of the current partition that was trained.
+ """
+ # set the trained params back for persisting (exclude paddings)
+ self.all_ent_embs[
+ self.ent_original_ids
+ ] = self._model.encoding_layer.ent_emb.numpy()[
+ : len(self.ent_original_ids), :
+ ]
+ self.all_rel_embs[
+ self.rel_original_ids
+ ] = self._model.encoding_layer.rel_emb.numpy()[
+ : len(self.rel_original_ids), :
+ ]
+
+ # get the optimizer params related to the embeddings
+ (
+ ent_opt_hyperparams,
+ rel_opt_hyperparams,
+ ) = self._model.optimizer.get_entity_relation_hyperparams()
+
+ # get the number of params that are created by the optimizer
+ num_opt_hyperparams = self._model.optimizer.get_hyperparam_count()
+
+ # depending on optimizer, you can have 0 or more params
+ if num_opt_hyperparams > 0:
+ # store the params
+ original_ent_hyperparams = []
+ original_rel_hyperparams = []
+
+ # get all the different params related to entities and relations
+ # eg: beta1, beta2 related to embeddings (when using adam)
+ for i in range(num_opt_hyperparams):
+ original_ent_hyperparams.append(
+ ent_opt_hyperparams[i][: len(self.ent_original_ids)]
+ )
+ original_rel_hyperparams.append(
+ rel_opt_hyperparams[i][: len(self.rel_original_ids)]
+ )
+
+            # store for persistence
+ self.all_rel_opt_params[self.rel_original_ids, :, :] = np.stack(
+ original_rel_hyperparams, 1
+ )
+ self.all_ent_opt_params[self.ent_original_ids, :, :] = np.stack(
+ original_ent_hyperparams, 1
+ )
+
+ # Open the buckets related to the partition and concat
+ splits = graph_data_loader.backend.mapper.metadata["name"].split("-")
+ source_bucket = splits[0][-1]
+ dest_bucket = splits[1]
+
+ try:
+ # persist entity related embs and optim params
+ s = shelve.open(self.ent_map_fname, writeback=True)
+
+ # split and save self.all_ent_opt_params and self.all_ent_embs into
+ # respective buckets
+ opt_params = [
+ self.all_ent_opt_params[: self.split_opt_idx],
+ self.all_ent_opt_params[self.split_opt_idx:],
+ ]
+ emb_params = [
+ self.all_ent_embs[: self.split_emb_idx],
+ self.all_ent_embs[self.split_emb_idx:],
+ ]
+
+ s[source_bucket] = [opt_params[0], emb_params[0]]
+ s[dest_bucket] = [opt_params[1], emb_params[1]]
+
+ finally:
+ s.close()
+
+ try:
+ # persist relation related embs and optim params
+ s = shelve.open(self.rel_map_fname, writeback=True)
+ s["0"] = [self.all_rel_opt_params, self.all_rel_embs]
+
+ finally:
+ s.close()
+
+ def _change_partition(self, graph_data_loader, partition_number):
+ """Gets a new partition to train and loads all the parameters of the partition.
+
+ Parameters
+ ----------
+ graph_data_loader : GraphDataLoader
+ Data loader of the next partition that will be trained.
+ partition_number: int
+            Partition number of the next partition that will be trained.
+ """
+ try:
+ # open the meta data related to the partition
+ s = shelve.open(self.ent_meta_fname)
+ # entities mapping ids
+ self.ent_original_ids = s[str(partition_number)]
+ finally:
+ s.close()
+
+ try:
+ s = shelve.open(self.rel_meta_fname)
+            # relations mapping ids
+ self.rel_original_ids = s[str(partition_number)]
+
+ finally:
+ s.close()
+
+ # Open the buckets related to the partition and concat
+ splits = graph_data_loader.backend.mapper.metadata["name"].split("-")
+ source_bucket = splits[0][-1]
+ dest_bucket = splits[1]
+
+ try:
+ s = shelve.open(self.ent_map_fname)
+ source_bucket_params = s[source_bucket]
+ dest_source_bucket_params = s[dest_bucket]
+ # full ent embs
+ self.all_ent_embs = np.concatenate(
+ [source_bucket_params[1], dest_source_bucket_params[1]]
+ )
+ self.split_emb_idx = source_bucket_params[1].shape[0]
+
+ self.all_ent_opt_params = np.concatenate(
+ [source_bucket_params[0], dest_source_bucket_params[0]]
+ )
+ self.split_opt_idx = source_bucket_params[0].shape[0]
+
+ # now select only partition embeddings
+ ent_embs = self.all_ent_embs[self.ent_original_ids]
+ ent_opt_params = self.all_ent_opt_params[self.ent_original_ids]
+ finally:
+ s.close()
+
+ try:
+ s = shelve.open(self.rel_map_fname)
+ # full rel embs
+ self.all_rel_embs = s["0"][1]
+ self.all_rel_opt_params = s["0"][0]
+ # now select only partition embeddings
+ rel_embs = self.all_rel_embs[self.rel_original_ids]
+ rel_opt_params = self.all_rel_opt_params[self.rel_original_ids]
+ finally:
+ s.close()
+
+ # notify the model about the partition change
+ self._model.partition_change_updates(
+ len(self.ent_original_ids), ent_embs, rel_embs
+ )
+
+ # Optimizer params will exist only after it has been persisted once
+ if self._model.current_epoch > 1 or (
+ self._model.current_epoch == 1
+ and partition_number > self.partitioner_k
+ ):
+ # TODO: needs to be better handled
+ # get the optimizer params of the embs that will be trained
+ rel_optim_hyperparams = []
+ ent_optim_hyperparams = []
+
+ num_opt_hyperparams = self._model.optimizer.get_hyperparam_count()
+ for i in range(num_opt_hyperparams):
+ rel_hyperparam_i = rel_opt_params[:, i, :]
+ rel_hyperparam_i = np.pad(
+ rel_hyperparam_i,
+ ((0, self.num_rels - rel_hyperparam_i.shape[0]), (0, 0)),
+ "constant",
+ constant_values=(0),
+ )
+ rel_optim_hyperparams.append(rel_hyperparam_i)
+
+ ent_hyperparam_i = ent_opt_params[:, i, :]
+ ent_hyperparam_i = np.pad(
+ ent_hyperparam_i,
+ (
+ (0, self.max_ent_size - ent_hyperparam_i.shape[0]),
+ (0, 0),
+ ),
+ "constant",
+ constant_values=(0),
+ )
+ ent_optim_hyperparams.append(ent_hyperparam_i)
+
+ # notify the optimizer and update the optimizer hyperparams
+ self._model.optimizer.set_entity_relation_hyperparams(
+ ent_optim_hyperparams, rel_optim_hyperparams
+ )
+
+ def on_complete(self):
+ """Activities to be performed on end of training.
+
+ The manager persists the data (splits the entity partitions into individual embeddings).
+ """
+ for i in range(self.partitioner_k - 1, -1, -1):
+ with shelve.open(self.partitioner.files[i]) as bucket:
+ with shelve.open(
+ self.ent_map_fname, writeback=True
+ ) as ent_partition:
+ # get the bucket embeddings
+ # split and store separately
+ for key, val in zip(
+ bucket["indexes"], ent_partition[str(i)][1]
+ ):
+ ent_partition[str(key)] = val
+ if i != 0:
+ del ent_partition[str(i)]
+ with shelve.open(self.rel_map_fname, writeback=True) as rel_partition:
+ # get the bucket embeddings
+ # split and store separately
+ for key in range(rel_partition["0"][1].shape[0] - 1, -1, -1):
+ rel_partition[str(key)] = rel_partition["0"][1][key]
+
+
+def get_partition_adapter(
+ dataset_loader,
+ model,
+ strategy="Bucket",
+ partitioning_k=3,
+ root_directory=None,
+):
+ """Returns partition manager depending on the one registered by the partitioning strategy.
+
+ Parameters
+ ----------
+ dataset_loader: AbstractGraphPartitioner or GraphDataLoader
+ Parent dataset loader that will be used for partitioning.
+ model: tf.keras.Model
+ KGE model that will be managed while using partitioning.
+ strategy: str
+        Graph partitioning strategy.
+    partitioning_k: int
+        Number of partitions to create (used only when ``dataset_loader`` is
+        not already a partitioner).
+    root_directory: str
+        Directory where the partition manager files will be stored.
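+
+    Example
+    -------
+    >>> # Illustrative sketch, not from the original docs; assumes `loader` is a
+    >>> # GraphDataLoader over the training triples and `model` is a compiled
+    >>> # ScoringBasedEmbeddingModel.
+    >>> adapter = get_partition_adapter(loader, model,
+    >>>                                 strategy="Bucket", partitioning_k=3)
+    >>> dataset = adapter.get_tf_generator()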
+ """
+ if isinstance(dataset_loader, AbstractGraphPartitioner):
+ partitioner_manager = PARTITION_MANAGER_REGISTRY.get(
+ dataset_loader.manager
+ )(
+ dataset_loader,
+ model,
+ dataset_loader.name,
+ dataset_loader._k,
+ root_directory,
+ )
+
+ else:
+ partitioner = PARTITION_ALGO_REGISTRY.get(strategy)(
+ dataset_loader, k=partitioning_k
+ )
+ partitioner_manager = PARTITION_MANAGER_REGISTRY.get(
+ partitioner.manager
+ )(partitioner, model, strategy, partitioning_k, root_directory)
+
+ return partitioner_manager
diff --git a/ampligraph/datasets/partitioning_reporter.py b/ampligraph/datasets/partitioning_reporter.py
new file mode 100644
index 00000000..5af20ee8
--- /dev/null
+++ b/ampligraph/datasets/partitioning_reporter.py
@@ -0,0 +1,452 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+"""Reporting for graph partition strategies.
+
+This module provides reporting capabilities for partitioning strategies.
+"""
+import copy
+import logging
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.pyplot import cm
+
+from .datasets import load_fb15k_237
+from .graph_partitioner import RandomVerticesGraphPartitioner
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class PartitioningReporter:
+ """Assesses the quality of partitioning according to chosen metrics and report it.
+
+ Available metrics: edge cut, edge imbalance, vertex imbalance, time, memory usage.
+
+ Parameters
+ ----------
+ partitionings:
+ Data splits to be compared.
+
+ Example
+ -------
+
+    >>> quality = PartitioningReporter(partitionings)
+    >>> report = quality.report(visualize=False)
+ """
+
+ def __init__(self, partitionings):
+ """Initialises PartitioningReporter.
+
+ Parameters
+ ----------
+ partitionings:
+ List of partitioning strategies.
+ """
+ self.partitionings = partitionings
+
+ def get_edge_cut(self, k, partitions, avg_size=None):
+ """Calculates mean edge cut across partitions in a single partitioning.
+
+ Parameters
+ ----------
+ k: int
+ Number of partitions.
+ partitions:
+ Partitions in one partitioning.
+
+ Returns
+ -------
+        edge_cut: float
+            Average edge cut between partitions.
+        edge_cut_proportion: float or None
+            Edge cut as a percentage of the average partition size; `None` if
+            ``avg_size`` is not provided.
+ """
+
+ intersections = []
+ logger.debug(partitions)
+ for partition1 in partitions:
+ logger.debug("Partition 1: {}".format(partition1))
+ intersect = []
+ for partition2 in partitions:
+ if partition1 == partition2:
+ continue
+ inter = partition1.intersect(partition2)
+ logger.debug("Intersections: {}".format(inter))
+ intersect.append(len(inter))
+ logger.debug("Partition 2: {}".format(partition2))
+ intersections.append(np.mean(intersect))
+ logger.debug("Intersections: {}".format(intersections))
+ edge_cut = np.mean(intersections)
+ edge_cut_proportion = None
+ if avg_size:
+ # edge cut with respect to the average partition size
+ edge_cut_proportion = (edge_cut * 100) / avg_size
+ return edge_cut, edge_cut_proportion
+
+ def get_edge_imbalance(self, avg_size, max_size):
+ """Calculates edge imbalance of partitions.
+
+ Parameters
+ ----------
+ avg_size: int
+ Average size of partition.
+ max_size: int
+ Maximum size of partition.
+
+ Returns
+ -------
+ edge_imb: float
+ Edge imbalance
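+
+        Example
+        -------
+        With two partitions of 10 and 14 edges (illustrative figures),
+        ``avg_size = 12`` and ``max_size = 14``, so the edge imbalance is
+        ``14 / 12 - 1``, i.e. approximately 0.17.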
+ """
+
+ edge_imb = max_size / avg_size - 1
+ return edge_imb
+
+ def get_vertex_imbalance_and_count(self, partitions, vertex_count=False):
+ """Calculates vertex imbalance of partitions, vertex count - counts number
+ of vertices in each partition that estimates the size of partition.
+
+
+ Parameters
+ ----------
+ partitions:
+ Partitions in one partitioning.
+
+ Returns
+ -------
+ vertex_imb: float
+ Vertex imbalance.
+ vertex_count: list
+            List of per-partition vertex counts, returned only if ``vertex_count=True``;
+            e.g., for 2 partitions the list has two elements, such as (5, 6), and for
+            3 partitions three, such as (5, 2, 4).
+ """
+ lengths = []
+ for partition in partitions:
+ ents_len = partition.backend.mapper.get_entities_count()
+ lengths.append(ents_len)
+
+ vertex_imb = np.max(lengths) / np.mean(lengths) - 1
+ if vertex_count:
+ return vertex_imb, lengths
+ else:
+ return vertex_imb
+
+ def get_average_deviation_from_ideal_size_vertices(self, partitions):
+ """Metric that calculates the average difference between the
+ ideal size of partition (in terms of vertices) and the real size.
+
+ It is expressed as the percentage deviation from the ideal size.
+
+ Parameters
+ ----------
+ partitions:
+ Partitions in one partitioning.
+
+ Returns
+ -------
+ percentage_dev: float
+ Percentage vertex size partition deviation
+ """
+ k = len(partitions)
+ sizes = []
+ for partition in partitions:
+ ents_len = partition.backend.mapper.get_entities_count()
+ sizes.append(ents_len)
+ data_size = ents_len
+ ideal_size = data_size / k
+ percentage_dev = (
+ (np.sum([np.abs(ideal_size - size) for size in sizes]) / k)
+ / ideal_size
+ ) * 100
+ return percentage_dev
+
+ def get_average_deviation_from_ideal_size_edges(self, partitions):
+ """Metric that calculates the average difference between the
+ ideal size of partition (in terms of edges) and the real size.
+
+ It is expressed as percentage deviation from ideal size.
+
+ Parameters
+ ----------
+ partitions:
+ Partitions in one partitioning.
+
+ Returns
+ -------
+ percentage_dev: float
+ Percentage edge size partition deviation.
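+
+        Example
+        -------
+        With ``k=2`` partitions holding 7 and 13 of the 20 parent edges
+        (illustrative figures), the ideal size is 10 and the deviation is
+        ``((|10 - 7| + |10 - 13|) / 2) / 10 * 100 = 30%``.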
+ """
+
+ k = len(partitions)
+ sizes = []
+ for partition in partitions:
+ sizes.append(partition.get_data_size())
+ logger.debug("Parent: {}".format(partition.parent.backend.data))
+ data_size = partition.parent.get_data_size()
+ logger.debug("Parent data size: {}".format(data_size))
+ ideal_size = data_size / k
+ logger.debug("Ideal data size: {}".format(ideal_size))
+ percentage_dev = (
+ (np.sum([np.abs(ideal_size - size) for size in sizes]) / k)
+ / ideal_size
+ ) * 100
+ return percentage_dev
+
+ def get_edges_count(self, partitions):
+ """Counts number of edges in each partition that estimates the size of partition.
+
+ Parameters
+ ---------
+ partitions:
+ Partitions in one partitioning.
+
+ Returns
+ -------
+ info: list
+ List of counts, e.g. for 2 partitions the list will be of size two with the edge count: (10,12);
+ for 3 partitions: (7,8,7).
+ """
+ info = []
+ for partition in partitions:
+ edges = partition.get_data_size()
+ info.append(edges)
+
+ return info
+
+ def get_modularity(self):
+ """Calculates modularity of partitions."""
+ raise NotImplementedError
+
+ def report_single_partitioning(
+ self, partitioning, EDGE_CUT=True, EDGE_IMB=True, VERTEX_IMB=True
+ ):
+ """Calculate different metrics for a single partition.
+
+ Parameters
+ ----------
+ partitioning:
+ Single split of data into partitions.
+ EDGE_CUT : bool
+ Flag whether to calculate edge cut or not.
+ EDGE_IMB : bool
+ Flag whether to calculate edge imbalance or not.
+ VERTEX_IMB : bool
+ Flag whether to calculate vertex imbalance or not.
+
+ Returns
+ -------
+ metrics: dict
+ Dictionary with metrics.
+ """
+ logs = partitioning[1]
+ partitioning = partitioning[0]
+ tmp = partitioning.get_data()
+ k = tmp.get_data_size()
+ partitioner = partitioning
+ partitioning = partitioner.get_partitions_list()
+ sizes = [x.get_data_size() for x in partitioning]
+ avg_size = np.mean(sizes)
+ max_size = np.max(sizes)
+ metrics = {"EDGE_IMB": None, "VERTEX_IMB": None, "EDGE_CUT": None}
+
+ if logs:
+ metrics["PARTITIONING TIME"] = logs["_SPLIT"]["time"]
+ metrics["PARTITIONING MEMORY"] = logs["_SPLIT"]["memory-bytes"]
+ if EDGE_CUT:
+ edge_cut, edge_cut_proportion = self.get_edge_cut(
+ k, partitioning, avg_size
+ )
+ metrics["EDGE_CUT"] = edge_cut
+ metrics["EDGE_CUT_PERCENTAGE"] = edge_cut_proportion
+ if EDGE_IMB:
+ edge_imb = self.get_edge_imbalance(avg_size, max_size)
+ metrics["EDGE_IMB"] = edge_imb
+ if VERTEX_IMB:
+ vertex_imb, vertex_count = self.get_vertex_imbalance_and_count(
+ partitioning, vertex_count=True
+ )
+ metrics["VERTEX_IMB"] = vertex_imb
+ metrics["VERTEX_COUNT"] = vertex_count
+ metrics["EDGES_COUNT"] = self.get_edges_count(partitioning)
+ metrics[
+ "PERCENTAGE_DEV_EDGES"
+ ] = self.get_average_deviation_from_ideal_size_edges(partitioning)
+ metrics[
+ "PERCENTAGE_DEV_VERTICES"
+ ] = self.get_average_deviation_from_ideal_size_vertices(partitioning)
+ partitioner.clean()
+ return metrics
+
+ def report(
+ self, visualize=True, barh=True
+ ): # TODO: include plotting parameters
+ """Collect individual reports for every partitioning.
+
+ Parameters
+ ----------
+ visualize : bool
+ Flag indicating whether to visualize output.
+
+ Returns
+ -------
+ reports: dict
+            Calculated metrics for all partitionings, stored in a dictionary keyed by
+            partitioning name, with the corresponding metrics dictionary as value.
+ """
+ reports = {}
+ for name, partitioning in self.partitionings.items():
+ reports[name] = self.report_single_partitioning(
+ partitioning, EDGE_IMB=True, VERTEX_IMB=True
+ )
+ k = len(self.partitionings[list(self.partitionings.keys())[0]][1])
+ if visualize:
+ plt.figure(
+ figsize=(15, 15 + 0.3 * k + 0.1 * len(self.partitionings))
+ )
+ ind = 1
+ row_size = 3
+ size = int(len(reports[list(reports.keys())[0]]) / row_size) + 1
+ for metric in reports[list(reports.keys())[0]]:
+ plot = False
+ dat = []
+ color = iter(cm.PiYG(np.linspace(0, 1, len(reports))))
+ colors_aggregate = {r: next(color) for r in reports}
+ for j, report in enumerate(reports):
+ if reports[report][metric] is not None:
+ if isinstance(reports[report][metric], list):
+ n = len(reports[report][metric])
+ color = iter(cm.seismic(np.linspace(0, 1, n)))
+ colors = {
+ "partition {}".format(i): next(color)
+ for i in range(n)
+ }
+ width = 0.8 / n
+ for i, r in enumerate(reports[report][metric]):
+ label = "partition {}".format(i)
+ dat.append(
+ {
+ "y": j + (i * width),
+ "width": r,
+ "height": width,
+ "label": label,
+ "label2": str(report),
+ "color": colors[label],
+ }
+ )
+ else:
+ colors = colors_aggregate
+ label = str(report)
+ dat.append(
+ {
+ "y": j,
+ "width": reports[report][metric],
+ "label2": label,
+ "color": colors[label],
+ }
+ )
+ plot = True
+ if plot:
+ plt.subplots_adjust(wspace=0.1, hspace=0.4)
+ plt.subplot(size, row_size, ind)
+
+ if barh:
+ unpacked = {k: [dic[k] for dic in dat] for k in dat[0]}
+ data = copy.deepcopy(unpacked)
+ del unpacked["label2"]
+ plt.barh(**unpacked, edgecolor="white")
+ else:
+ plt.bar(*list(zip(*dat)), edgecolor="white")
+
+ labels = list(colors.keys())
+ # handles = [plt.Rectangle((0, 0), 1, 1, color=colors[label]) for label in labels]
+ labels = []
+ for elem in data["label2"]:
+ if elem not in labels:
+ labels.append(elem)
+ # if type(labels) == set:
+ if (ind % row_size) == 1:
+ plt.yticks(range(len(list(labels))), list(labels))
+ else:
+ plt.yticks([])
+ plt.title(metric)
+ plt.xticks(rotation=70)
+ ind += 1
+ plt.show()
+ return reports
+
+
+def compare_partitionings(
+ list_of_partitioners, data, num_partitions=2, visualize=True
+):
+ """Wrapper around PartitioningReporter hiding logging settings.
+
+ Parameters
+    ----------
+ list_of_partitioners: list
+ List of uninitialized partitioners.
+ data: ndarray
+ Numpy array with the graph to be split into partitions.
+ num_partitions: int
+ Number of partitions required.
+ visualize : bool
+ Flag whether to visualize results or not.
+
+ Returns
+ -------
+ result: dict
+ Dictionary with metrics evaluating partitionings.
+
+ Example
+ -------
+    >>> partitioners = [NaiveGraphPartitioner,
+    >>>                 SortedEdgesGraphPartitioner,
+    >>>                 DoubleSortedEdgesGraphPartitioner]
+    >>> report = compare_partitionings(partitioners, data)
+ """
+ if isinstance(num_partitions, int):
+ n_partitions = [num_partitions] * len(list_of_partitioners)
+ else:
+ n_partitions = num_partitions
+ partitionings = {}
+ for partitioner, n in zip(list_of_partitioners, n_partitions):
+ logger.debug("Running: {}".format(partitioner.__name__))
+ logs = {}
+ if n != 0:
+ data.reload()
+ partitioner_fitted = partitioner(data, k=n, log=logs)
+ partitionings[partitioner.__name__] = (partitioner_fitted, logs)
+ reporter = PartitioningReporter(partitionings=partitionings)
+ result = reporter.report(visualize=visualize, barh=True)
+ return result
+
+
+def main():
+ """Main function with example usage."""
+ from ampligraph.datasets import GraphDataLoader
+ from ampligraph.datasets.sqlite_adapter import SQLiteAdapter
+
+ sample = load_fb15k_237()["train"]
+ data = GraphDataLoader(sample, backend=SQLiteAdapter, in_memory=False)
+ partitioners = [RandomVerticesGraphPartitioner]
+ report = compare_partitionings(partitioners, data, visualize=False)
+ print(report)
+
+
+# Expected output:
+# {'RandomVerticesGraphPartitioner': {'EDGE_IMB': 0.40953499098494706,
+# 'VERTEX_IMB': 0.03495702005730661,
+# 'EDGE_CUT': 6736.0,
+# 'PARTITIONING TIME': 139.55057835578918,
+# 'PARTITIONING MEMORY': 7473904,
+# 'EDGE_CUT_PERCENTAGE': 9.414790277719542,
+# 'VERTEX_COUNT': [7224, 6736],
+# 'EDGES_COUNT': [100848, 42246],
+# 'PERCENTAGE_DEV_EDGES': 47.41414475497492,
+# 'PERCENTAGE_DEV_VERTICES': 3.757325060324026}}
+
+if __name__ == "__main__":
+ main()
diff --git a/ampligraph/datasets/source_identifier.py b/ampligraph/datasets/source_identifier.py
new file mode 100644
index 00000000..13f22c67
--- /dev/null
+++ b/ampligraph/datasets/source_identifier.py
@@ -0,0 +1,175 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+"""Data source identifier.
+
+This module provides the main class and the supporting functions for automatic
+identification of the data source (whether it is a csv file, a tar.gz archive
+or a numpy array) and provides an adequate loader for the identified source.
+"""
+import logging
+from collections.abc import Iterable
+from itertools import chain, islice
+
+import numpy as np
+import pandas as pd
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+def load_csv(data_source, chunk_size=None, sep="\t", verbose=False, **kwargs):
+ """CSV data loader.
+
+ Parameters
+    ----------
+ data_source: str
+ csv file with data, separated by ``sep``.
+ chunk_size: int
+ The size of chunk to be used while reading the data. If used, the returned type is
+ an iterator and not a numpy array.
+ sep: str
+ Separator in the csv file, e.g. line "1,2,3\n" has ``sep=","``, while "1 2 3\n" has ``sep=" "``.
+
+ Returns
+ -------
+ data: ndarray or iter
+ Either a numpy array with data or a lazy iterator if ``chunk_size`` was provided.
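+
+    Example
+    -------
+    >>> # Illustrative sketch (not from the original docs); assumes "train.txt"
+    >>> # holds tab-separated triples.
+    >>> X = load_csv("train.txt")                           # (N, 3) numpy array
+    >>> batches = load_csv("train.txt", chunk_size=10000)   # lazy iterator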
+ """
+ data = pd.read_csv(
+ data_source, sep=sep, chunksize=chunk_size, header=None, **kwargs
+ )
+ logger.debug("data type: {}".format(type(data)))
+ logger.debug("CSV loaded, into iterator data.")
+
+ if isinstance(data, pd.DataFrame):
+ return data.values
+ else:
+ return data
+
+
+def load_json(data_source, orient="records", chunksize=None):
+ """json files data loader.
+
+ Parameters
+ ----------
+ data_source : str
+ Path to a .json file.
+ orient : str
+ Indicates the expected .json file format. The default ``orient="records"`` assumes the knowledge graph is
+ stored as a list like `[{subject_1: value, predicate_1: value, object_1: value}, ...,
+        {subject_n: value, predicate_n: value, object_n: value}]`. For more options, check the
+        Pandas documentation.
+ chunksize : int
+ The size of chunk to be used while reading the data. If used, the returned type is
+ an iterator and not a numpy array.
+
+
+ Returns
+ -------
+ data : ndarray or iter
+ Either a numpy array with data or a lazy iterator if ``chunk_size`` was provided.
+ """
+ if chunksize is not None:
+ data = pd.read_json(
+ data_source, orient=orient, lines=True, chunksize=chunksize
+ )
+ else:
+ data = pd.read_json(data_source, orient=orient)
+ logger.debug("data type: {}".format(type(data)))
+ logger.debug("JSON loaded into iterator data.")
+
+ return data.values
+
+
+def chunks(iterable, chunk_size=1):
+ """Chunks generator."""
+ iterator = iter(iterable)
+ for first in iterator:
+ yield np.array(list(chain([first], islice(iterator, chunk_size - 1))))
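+    # Example (illustrative): list(chunks(range(5), chunk_size=2)) yields
+    # np.array([0, 1]), np.array([2, 3]) and np.array([4]).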
+
+
+def load_gz(data_source, chunk_size=None, verbose=False):
+ """Gz data loader. Reads compressed file."""
+ raise NotImplementedError
+
+
+def load_tar(data_source, chunk_size=None, verbose=False):
+ """Tar data loader. Reads compressed file."""
+ raise NotImplementedError
+
+
+class DataSourceIdentifier:
+ """Class that recognizes the type of given file and provides with an
+ adequate loader.
+
+ Properties
+ ----------
+ supported_types: dict
+        Dictionary of supported types along with their loaders. To support a new
+        data type, update this dictionary with the file extension as key and the
+        loading function as value.
+
+ Example
+ -------
+    >>> identifier = DataSourceIdentifier("data.csv")
+    >>> loader = identifier.fetch_loader()
+    >>> X = loader("data.csv")
+ """
+
+ def __init__(self, data_source, verbose=False):
+ """Initialise DataSourceIdentifier.
+
+ Parameters
+ ----------
+ data_source: str
+ Name of a file to be recognized.
+ """
+ self.verbose = verbose
+ self.data_source = data_source
+ self.supported_types = {
+ "csv": load_csv,
+ "txt": load_csv,
+ "gz": load_csv,
+ "json": load_json,
+ "tar": load_tar,
+ "iter": chunks,
+ }
+ self._identify()
+
+ def fetch_loader(self):
+ """Returns adequate loader required to read identified file."""
+ logger.debug(
+ "Return adequate loader that provides loading of data source."
+ )
+ return self.supported_types[self.src]
+
+ def get_src(self):
+ """Returns identified source type."""
+ return self.src
+
+ def _identify(self):
+ """Identifies the data file type based on the file name."""
+ if isinstance(self.data_source, str):
+ self.src = (
+ self.data_source.split(".")[-1]
+ if "." in self.data_source
+ else None
+ )
+ if self.src is not None and self.src not in self.supported_types:
+ logger.debug(
+ "File type not supported! Supported types: {}".format(
+ ", ".join(self.supported_types)
+ )
+ )
+ self.src = None
+ else:
+ logger.debug("data_source is an object")
+ if isinstance(self.data_source, Iterable):
+ self.src = "iter"
+ logger.debug("data_source is an iterable")
+ else:
+ logger.error("Object type not supported")
diff --git a/ampligraph/datasets/sqlite_adapter.py b/ampligraph/datasets/sqlite_adapter.py
index 6d271970..d35ed649 100644
--- a/ampligraph/datasets/sqlite_adapter.py
+++ b/ampligraph/datasets/sqlite_adapter.py
@@ -1,466 +1,949 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-import numpy as np
-from ..datasets import AmpligraphDatasetAdapter
-import tempfile
-import sqlite3
-import time
-import os
+"""SQLite backend for storing graphs.
+
+This module provides SQLite backend for GraphDataLoader.
+
+Attributes
+----------
+DEFAULT_CHUNKSIZE: int
+    Number of rows loaded into memory at once; should be set according to the
+    available hardware capabilities (default: 30000).
+"""
import logging
+import os
+import sqlite3
+import tempfile
+from sqlite3 import Error
+from urllib.request import pathname2url
+
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+
+from ampligraph.utils.profiling import get_human_readable_size
+
+from .data_indexer import DataIndexer
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
+DEFAULT_CHUNKSIZE = 30000
+
+
+class SQLiteAdapter:
+ """Class implementing database connection.
+
+ Example
+ -------
+ >>> AMPLIGRAPH_DATA_HOME='/your/path/to/datasets/'
+ >>> # Initialize GraphDataLoader from .csv file
+ >>> data = GraphDataLoader("data.csv", backend=SQLiteAdapter)
+ >>> # Initialize GraphDataLoader from .txt file using indexer
+ >>> # to map entities to integers
+ >>> data = GraphDataLoader(AMPLIGRAPH_DATA_HOME + "fb15k/test.txt",
+ >>> backend=SQLiteAdapter("database.db",
+ >>> use_indexer=True))
+ >>> for elem in data:
+ >>> print(elem)
+ >>> break
+ [(1, 1, 2)]
+ >>> # Populate the database with raw triples for training
+ >>> with SQLiteAdapter("database.db") as backend:
+ >>> backend.populate(AMPLIGRAPH_DATA_HOME + "fb15k/train.txt",
+ >>> dataset_type="train")
+ >>> # Populate the database with indexed triples for training
+ >>> with SQLiteAdapter("database.db", use_indexer=True) as backend:
+ >>> backend.populate(AMPLIGRAPH_DATA_HOME + "fb15k/train.txt",
+ >>> dataset_type="train")
+ """
+
+ def __init__(
+ self,
+ db_name,
+ identifier=None,
+ chunk_size=DEFAULT_CHUNKSIZE,
+ root_directory=None,
+ use_indexer=True,
+ verbose=False,
+ remap=False,
+ name="main_partition",
+ parent=None,
+ in_memory=True,
+ use_filter=False,
+ ):
+ """Initialise SQLiteAdapter.
-class SQLiteAdapter(AmpligraphDatasetAdapter):
- '''SQLLite adapter
- '''
-
- def __init__(self, existing_db_name=None, ent_to_idx=None, rel_to_idx=None):
- """Initialize the class variables
Parameters
----------
- existing_db_name : string
- Name of an existing database to use.
- Assumes that the database has schema as required by the adapter and the persisted data is already mapped
- ent_to_idx : dictionary of mappings
- Mappings of entity to idx
- rel_to_idx : dictionary of mappings
- Mappings of relation to idx
+ db_name: str
+ Name of the database.
+ chunk_size: int
+ Size of a chunk to read data from while feeding
+ the database (default: DEFAULT_CHUNKSIZE).
+ root_directory: str
+ Path to a directory where the database will be created,
+ and the data and mappings will be stored. If `None`, the root
+            directory is obtained through the :meth:`tempfile.gettempdir()`
+            method (default: `None`).
+ use_indexer: DataIndexer or bool
+            Object of type DataIndexer with a pre-defined mapping, or a bool
+            flag telling whether the data should be indexed.
+ remap: bool
+            Whether to remap the data; not supported by this backend and must
+            be left as `False`.
+        parent:
+            Not implemented.
+ verbose:
+ Print status messages.
+ """
+ self.db_name = db_name
+ self.verbose = verbose
+ if identifier is None:
+ msg = "You need to provide source identifier object"
+ logger.error(msg)
+ raise Exception(msg)
+ else:
+ self.identifier = identifier
+
+ self.flag_db_open = False
+
+ if root_directory is None:
+ self.root_directory = tempfile.gettempdir()
+ else:
+ self.root_directory = root_directory
+
+ self.db_path = os.path.join(self.root_directory, self.db_name)
+ self.use_indexer = use_indexer
+ self.remap = remap
+ if self.remap:
+ msg = "Remapping is not supported for DataLoaders with SQLite\
+ Adapter as backend"
+
+ logger.error(msg)
+ raise Exception(msg)
+ self.name = name
+ self.parent = parent
+ self.in_memory = in_memory
+ self.use_filter = use_filter
+ self.sources = {}
+
+        if chunk_size is None:
+            self.chunk_size = DEFAULT_CHUNKSIZE
+            logger.debug(
+                "Currently {} only supports data given in chunks. "
+                "Setting chunksize to {}.".format(
+                    self.__name__(), DEFAULT_CHUNKSIZE
+                )
+            )
+        else:
+            self.chunk_size = chunk_size
+
+ def get_output_signature(self):
+ """Get the output signature of the tf.data.Dataset object."""
+ triple_tensor = tf.TensorSpec(shape=(None, 3), dtype=tf.int32)
+
+ # focusE
+ if self.data_shape > 3:
+ weights_tensor = tf.TensorSpec(
+ shape=(None, self.data_shape - 3), dtype=tf.float32
+ )
+ if self.use_filter:
+ return (
+ triple_tensor,
+ tf.RaggedTensorSpec(shape=(2, None, None), dtype=tf.int32),
+ weights_tensor,
+ )
+ return (triple_tensor, weights_tensor)
+ if self.use_filter:
+ return (
+ triple_tensor,
+ tf.RaggedTensorSpec(shape=(2, None, None), dtype=tf.int32),
+ )
+ return triple_tensor
+
+ def open_db(self):
+ """Open the database."""
+ db_uri = "file:{}?mode=rw".format(pathname2url(self.db_path))
+ self.connection = sqlite3.connect(db_uri, uri=True)
+ self.flag_db_open = True
+ logger.debug("----------------DB OPENED - normally -----------------")
+
+ def open_connection(self):
+ """Context manager function to open (or create if it does not exist)
+ a database connection.
+
+ """
+ if not self.flag_db_open:
+ try:
+ self.open_db()
+ except sqlite3.OperationalError:
+ logger.debug("Database does not exists. Creating one.")
+ self.connection = sqlite3.connect(self.db_path)
+ self.connection.commit()
+ self.connection.close()
+
+ self._create_database()
+ self.open_db()
+
+ def __enter__(self):
+ self.open_connection()
+ return self
+
+ def __exit__(self, type, value, tb):
+ """Context manager exit function, to be used with "with statement",
+ closes the connection and do the rollback if required.
+
"""
- super(SQLiteAdapter, self).__init__()
- # persistance status of the data
- self.persistance_status = {}
- self.dbname = existing_db_name
- # flag indicating whether we are using existing db
- self.using_existing_db = False
- self.temp_dir = None
- if self.dbname is not None:
- # If we are using existing db then the mappings need to be passed
- assert (self.rel_to_idx is not None)
- assert (self.ent_to_idx is not None)
-
- self.using_existing_db = True
- self.rel_to_idx = rel_to_idx
- self.ent_to_idx = ent_to_idx
-
- def get_db_name(self):
- """Returns the db name
+ if self.flag_db_open:
+ if tb is None:
+ self.connection.commit()
+ self.connection.close()
+ else:
+ # Exception occurred, so rollback.
+ self.connection.rollback()
+ self.flag_db_open = False
+ logger.debug("!!!!!!!!----------------DB CLOSED ----------------")
+
+ def _add_dataset(self, data_source, dataset_type):
+ """Load the data."""
+ self._load(data_source, dataset_type)
+
+ def _get_db_schema(self):
+ """Defines SQL queries to create a table with triples and indexes to
+ navigate easily on pairs subject-predicate, predicate-object.
+
+ Returns
+ -------
+ db_schema: list
+ List of SQL commands to create tables and indexes.
"""
- return self.dbname
+ if self.data_shape < 4:
+ db_schema = [
+ """CREATE TABLE triples_table (subject integer,
+ predicate integer,
+ object integer,
+ dataset_type text(50)
+ );""",
+ "CREATE INDEX triples_table_sp_idx ON triples_table (subject, predicate);",
+ "CREATE INDEX triples_table_po_idx ON triples_table (predicate, object);",
+ "CREATE INDEX triples_table_type_idx ON triples_table (dataset_type);",
+ "CREATE INDEX triples_table_sub_obj_idx ON triples_table (subject, object);",
+ "CREATE INDEX triples_table_subject_idx ON triples_table (subject);",
+ "CREATE INDEX triples_table_object_idx ON triples_table (object);",
+ ]
+ else: # focusE
+ db_schema = [
+ """CREATE TABLE triples_table (subject integer,
+ predicate integer,
+ object integer,
+ weight float,
+ dataset_type text(50)
+ );""",
+ "CREATE INDEX triples_table_sp_idx ON triples_table (subject, predicate);",
+ "CREATE INDEX triples_table_po_idx ON triples_table (predicate, object);",
+ "CREATE INDEX triples_table_type_idx ON triples_table (dataset_type);",
+ "CREATE INDEX triples_table_sub_obj_idx ON triples_table (subject, object);",
+ "CREATE INDEX triples_table_subject_idx ON triples_table (subject);",
+ "CREATE INDEX triples_table_object_idx ON triples_table (object);",
+ ]
+ return db_schema
+
+ def _get_clean_up(self):
+ """Defines SQL commands to clean the database (tables and indexes).
- def _create_schema(self):
- """Creates the database schema
+ Returns
+ -------
+ clean_up: list
+ List of SQL commands to clean tables and indexes.
"""
- if self.using_existing_db:
- return
- if self.dbname is not None:
- self.cleanup()
-
- self.temp_dir = tempfile.TemporaryDirectory(suffix=None, prefix='ampligraph_', dir=None)
- self.dbname = os.path.join(self.temp_dir.name, 'Ampligraph_{}.db'.format(int(time.time())))
-
- conn = sqlite3.connect("{}".format(self.dbname))
- cur = conn.cursor()
- cur.execute("CREATE TABLE entity_table (entity_type integer primary key);")
- cur.execute("CREATE TABLE triples_table (subject integer, \
- predicate integer, \
- object integer, \
- dataset_type text(50), \
- foreign key (object) references entity_table(entity_type), \
- foreign key (subject) references entity_table(entity_type) \
- );")
-
- cur.execute("CREATE INDEX triples_table_sp_idx ON triples_table (subject, predicate);")
- cur.execute("CREATE INDEX triples_table_po_idx ON triples_table (predicate, object);")
- cur.execute("CREATE INDEX triples_table_type_idx ON triples_table (dataset_type);")
-
- cur.execute("CREATE TABLE integrity_check (validity integer primary key);")
-
- cur.execute('INSERT INTO integrity_check VALUES (0)')
- conn.commit()
- cur.close()
- conn.close()
-
- def generate_mappings(self, use_all=False, regenerate=False):
- """Generate mappings from either train set or use all dataset to generate mappings
+ clean_up = [
+ "drop index IF EXISTS triples_table_po_idx",
+ "drop index IF EXISTS triples_table_sp_idx",
+ "drop index IF EXISTS triples_table_type_idx",
+ "drop table IF EXISTS triples_table",
+ ]
+ return clean_up
+
+ def _execute_query(self, query):
+ """Connects to the database and execute given query.
+
Parameters
----------
- use_all : boolean
- If True, it generates mapping from all the data. If False, it only uses training set to generate mappings
- regenerate : boolean
- If true, regenerates the mappings.
- If regenerating, then the database is created again(to conform to new mapping)
+ query: str
+ SQLite query to be executed.
+
Returns
-------
- rel_to_idx : dictionary
- Relation to idx mapping dictionary
- ent_to_idx : dictionary
- entity to idx mapping dictionary
+ output:
+ Result of a query with fetchall().
"""
- if (len(self.rel_to_idx) == 0 or len(self.ent_to_idx) == 0 or (regenerate is True)) \
- and (not self.using_existing_db):
- from ..evaluation import create_mappings
- self._create_schema()
- if use_all:
- complete_dataset = []
- for key in self.dataset.keys():
- complete_dataset.append(self.dataset[key])
- self.rel_to_idx, self.ent_to_idx = create_mappings(np.concatenate(complete_dataset, axis=0))
+ with self:
+ cursor = self.connection.cursor()
+ output = None
+ try:
+ cursor.execute(query)
+ output = cursor.fetchall()
+ self.connection.commit()
+ if self.verbose:
+ logger.debug(f"Query executed successfully, {query}")
+ except Error as e:
+ logger.debug(f"Query failed. The error '{e}' occurred")
+ return output
+
+ def _execute_queries(self, list_of_queries):
+ """Executes given list of queries one by one.
- else:
- self.rel_to_idx, self.ent_to_idx = create_mappings(self.dataset["train"])
+ Parameters
+ ----------
+        list_of_queries: list
+ List of SQLite queries to be executed.
- self._insert_entities_in_db()
- return self.rel_to_idx, self.ent_to_idx
+
+        Returns
+        -------
+        output: None
+            Individual query results are currently discarded
+            (TODO: collect and return the fetchall() outputs).
- def _insert_entities_in_db(self):
- """Inserts entities in the database
"""
- # TODO: can change it to just use the values of the dictionary
- pg_entity_values = np.arange(len(self.ent_to_idx)).reshape(-1, 1).tolist()
- conn = sqlite3.connect("{}".format(self.dbname))
- cur = conn.cursor()
- try:
- cur.executemany('INSERT INTO entity_table VALUES (?)', pg_entity_values)
- conn.commit()
- except sqlite3.Error:
- conn.rollback()
- cur.close()
- conn.close()
-
- def use_mappings(self, rel_to_idx, ent_to_idx):
- """Use an existing mapping with the datasource.
- """
- # cannot change mappings for an existing database.
- if self.using_existing_db:
- raise Exception('Cannot change the mappings for an existing DB')
- super().use_mappings(rel_to_idx, ent_to_idx)
- self._create_schema()
+ for query in list_of_queries:
+ self._execute_query(query)
+
+ def _insert_values_to_a_table(self, table, values):
+ """Insert data into a given table in a database.
- for key in self.dataset.keys():
- self.mapped_status[key] = False
- self.persistance_status[key] = False
+ Parameters
+ ----------
+ table: str
+ Table where to input data.
+ values: ndarray
+ Numpy array of data with shape (N,m) to be written to
+ the database. `N` is a number of entries, :math:`m=3`
+ if we only have triples and :math:`m>3` if we have numerical
+ weights associated with each triple.
+ """
+ with self:
+ if self.verbose:
+ logger.debug("inserting to a table...")
+ if len(np.shape(values)) < 2:
+ size = 1
+ else:
+ size = np.shape(values)[1]
+ cursor = self.connection.cursor()
+ try:
+ values_placeholder = "({})".format(", ".join(["?"] * size))
+ query = "INSERT INTO {} VALUES {}".format(
+ table, values_placeholder
+ )
+ precompute = [
+ (v,) if isinstance(v, int) or isinstance(v, str) else v
+ for v in values
+ ]
+ cursor.executemany(query, precompute)
+ self.connection.commit()
+ if self.verbose:
+ logger.debug("commited to table: {}".format(table))
+ except Error as e:
+ logger.debug("Error: {}".format(e))
+ # self.connection.rollback()
+ logger.debug("Values were inserted!")
+
+ def _create_database(self):
+ """Creates database."""
+ self._execute_queries(self._get_db_schema())
+
+ def _get_triples(self, subjects=None, objects=None, entities=None):
+ """Get triples whose objects belong to objects and subjects to
+ subjects, or, if not provided either object or subject, belong to
+ entities.
- self._insert_entities_in_db()
+ """
+ if subjects is None and objects is None:
+ if entities is None:
+ msg = "You have to provide either subjects and objects\
+ indexes or general entities indexes!"
+
+ logger.error(msg)
+ raise Exception(msg)
+
+ subjects = entities
+ objects = entities
+ if subjects is not None and objects is not None:
+ query = "select * from triples_table where (subject in ({0}) and\
+ object in \
+ ({1})) or (subject in ({1}) and object in ({0}));".format(
+ ",".join(str(v) for v in subjects),
+ ",".join(str(v) for v in objects),
+ )
+ elif objects is None:
+ query = (
+ "select * from triples_table where (subject in ({0}));".format(
+ ",".join(str(v) for v in subjects)
+ )
+ )
+ elif subjects is None:
+ query = (
+ "select * from triples_table where (object in ({0}));".format(
+ ",".join(str(v) for v in objects)
+ )
+ )
+ triples = np.array(self._execute_query(query))
+ triples = np.append(
+ triples[:, :3].astype("int"), triples[:, 3].reshape(-1, 1), axis=1
+ )
+ return triples
+
+ def get_indexed_triples(self, chunk, dataset_type="train"):
+ """Get indexed triples.
- def get_size(self, dataset_type="train"):
- """Returns the size of the specified dataset
Parameters
----------
- dataset_type : string
- type of the dataset
+ chunk: ndarray
+ Numpy array with a fragment of data of size (N,3), where each
+ element is: (subject, predicate, object).
+ dataset_type: str
+ Defines what kind of data we are considering
+ (`"train"`, `"test"`, `"validation"`).
Returns
-------
- size : int
- size of the specified dataset
+ tmp: ndarray
+ Numpy array of size (N, 4) with indexed triples, where each
+ element is of the form
+ (subject index, predicate index, object index, dataset_type).
"""
- select_query = "SELECT count(*) from triples_table where dataset_type ='{}'"
- conn = sqlite3.connect("{}".format(self.dbname))
- cur1 = conn.cursor()
- cur1.execute(select_query.format(dataset_type))
- out = cur1.fetchall()
- cur1.close()
- return out[0][0]
-
- def get_next_batch(self, batches_count=-1, dataset_type="train", use_filter=False):
- """Generator that returns the next batch of data.
+ if self.verbose:
+ logger.debug("getting triples...")
+ if isinstance(chunk, pd.DataFrame):
+ chunk = chunk.values
+ if self.use_indexer:
+ # logger.debug(chunk)
+ triples = self.mapper.get_indexes(chunk[:, :3])
+ if self.data_shape > 3:
+ weights = chunk[:, 3:]
+ # weights = preprocess_focusE_weights(data=triples,
+ # weights=weights)
+ return np.hstack(
+ [
+ triples,
+ weights,
+ np.array(len(triples) * [dataset_type]).reshape(-1, 1),
+ ]
+ )
+ return np.append(
+ triples,
+ np.array(len(triples) * [dataset_type]).reshape(-1, 1),
+ axis=1,
+ )
+ else:
+ return np.append(
+ chunk,
+ np.array(len(chunk) * [dataset_type]).reshape(-1, 1),
+ axis=1,
+ )
+
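+    # Shape sketch (assumes the adapter has already been indexed): a raw chunk
+    # of shape (N, 3) comes back as an (N, 4) array whose last column is the
+    # dataset type, e.g.:
+    #
+    #     chunk = np.array([["a", "likes", "b"], ["b", "likes", "c"]])
+    #     indexed = adapter.get_indexed_triples(chunk, dataset_type="train")
+    #     # indexed.shape == (2, 4): subject, predicate, object indexes + "train"
+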
+ def index_entities(self):
+ """Index the data via the definition of the DataIndexer."""
+ self.reload_data()
+ if self.use_indexer is True:
+ self.mapper = DataIndexer(
+ self.data,
+ backend="in_memory" if self.in_memory else "sqlite",
+ root_directory=self.root_directory,
+ )
+ elif self.use_indexer is False:
+ logger.debug("Data won't be indexed")
+ elif isinstance(self.use_indexer, DataIndexer):
+ self.mapper = self.use_indexer
+
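+    # Note: when `use_indexer` is `True` the indexer backend above is chosen
+    # from the `in_memory` flag, i.e. DataIndexer(..., backend="sqlite", ...)
+    # when `in_memory` is `False` (on-disk indexing) and "in_memory" otherwise.
+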
+ def is_indexed(self):
+ """Check if the current data adapter has already been indexed.
+
+ Returns
+ -------
+ Flag : bool
+ Flag indicating whether indexing took place.
+
+ """
+ if not hasattr(self, "mapper"):
+ return False
+ return True
+
+ def reload_data(self, verbose=False):
+ """Reinitialise an iterator with data."""
+ self.data = self.loader(self.data_source, chunk_size=self.chunk_size)
+ if verbose:
+ logger.debug("Data reloaded: {}".format(self.data))
+
+ def populate(
+ self,
+ data_source,
+ dataset_type="train",
+ get_indexed_triples=None,
+ loader=None,
+ ):
+ """Populate the database with data.
+
+        Note: the data must be indexed before triples can be stored.
Parameters
----------
- dataset_type: string
- indicates which dataset to use
- batches_count: int
- number of batches per epoch (default: -1, i.e. uses batch_size of 1)
- use_filter : bool
- Flag to indicate whether to return the concepts that need to be filtered
+ data_source: ndarray or str
+ Numpy array or file (e.g., csv file) with data.
+ dataset_type: str
+            Type of the data being stored
+            (`"train"` | `"test"` | `"validation"`).
+ get_indexed_triples: func
+ Function to obtain indexed triples.
+ loader: func
+ Loading function to be used to load data; if `None`, the
+ `DataSourceIdentifier` will try to identify the type of
+ ``data_source`` and return an adequate loader.
+
+ """
+ self.data_source = data_source
+ self.loader = loader
+ if loader is None:
+ self.loader = self.identifier.fetch_loader()
+ if not self.is_indexed() and self.use_indexer is not False:
+ if self.verbose:
+ logger.debug("indexing...")
+ self.index_entities()
+ else:
+ logger.debug(
+ "Data is already indexed or no\
+ indexing is required."
+ )
+ if get_indexed_triples is None:
+ get_indexed_triples = self.get_indexed_triples
+ data = self.loader(data_source, chunk_size=self.chunk_size)
+
+ self.reload_data()
+ for chunk in data: # chunk is a numpy array of size (n,m) with m=3/4
+ if chunk.shape[1] > 3:
+ # weights = preprocess_focusE_weights(data=chunk[:, :3],
+ # weights=chunk[:, 3:]) # weights normalization
+ weights = chunk[:, 3:]
+ chunk = np.concatenate([chunk[:, :3], weights], axis=1)
+ self.data_shape = chunk.shape[1]
+ values_triples = get_indexed_triples(
+ chunk, dataset_type=dataset_type
+ )
+ self._insert_values_to_a_table("triples_table", values_triples)
+ if self.verbose:
+ logger.debug("data is populated")
+
+ query = "SELECT count(*) from triples_table;"
+ _ = self._execute_query(query)
+
+ if isinstance(self.use_filter, dict):
+ for key in self.use_filter:
+ present_filters = [
+ x[0]
+ for x in self._execute_query(
+ "SELECT\
+ DISTINCT dataset_type FROM triples_table"
+ )
+ ]
+ if key not in present_filters:
+ # to allow users not to pass weights in test and validation
+ if (
+ self.data_shape > 3
+ and self.use_filter[key].shape[1] == 3
+ ):
+ nan_weights = np.empty(
+ (self.use_filter[key].shape[0], 1)
+ )
+ nan_weights.fill(np.nan)
+ self.use_filter[key] = np.concatenate(
+ [self.use_filter[key], nan_weights], axis=1
+ )
+ self.populate(self.use_filter[key], key)
+ query = "SELECT count(*) from triples_table;"
+ _ = self._execute_query(query)
+
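+    # Minimal usage sketch (file and database names are illustrative): index
+    # and store training triples from a CSV via the context manager shown in
+    # the summary() docstring below:
+    #
+    #     adapter = SQLiteAdapter("database.db")
+    #     with adapter as db:
+    #         db.populate("train.csv", dataset_type="train")
+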
+ def get_data_size(self, table="triples_table", condition=""):
+ """Gets the size of the given table [with specified condition].
+
+ Parameters
+ ----------
+ table: str
+ Table for which to obtain the size.
+ condition: str
+ Condition to count only a subset of data.
Returns
-------
- batch_output : nd-array
- yields a batch of triples from the dataset type specified
- participating_objects : nd-array [n,1]
- all objects that were involved in the s-p-? relation. This is returned only if use_filter is set to true.
- participating_subjects : nd-array [n,1]
- all subjects that were involved in the ?-p-o relation. This is returned only if use_filter is set to true.
+ count: int
+ Number of records in the table.
+
"""
- if (not self.using_existing_db) and (not self.mapped_status[dataset_type]):
- self.map_data()
+ query = "SELECT count(*) from {} {};".format(table, condition)
+ count = self._execute_query(query)
+ if count is None:
+ logger.debug("Table is empty or not such table exists.")
+ return count
+ elif not isinstance(count, list) or not isinstance(count[0], tuple):
+ raise ValueError(
+ "Cannot get count for the table with\
+ provided condition."
+ )
+ # logger.debug(count)
+ return count[0][0]
+
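+    # Example of the `condition` parameter, mirroring the call made in
+    # _get_batch_generator below:
+    #
+    #     adapter.get_data_size(condition="where dataset_type ='train'")
+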
+ def clean_up(self):
+ """Clean the database."""
+ _ = self._execute_queries(self._get_clean_up())
+
+ def remove_db(self):
+ """Remove the database file."""
+ os.remove(self.db_path)
+ logger.debug("Database removed.")
+
+ def _get_complementary_objects(self, triples, use_filter=None):
+ """For a given triple retrieve all triples with same subjects
+ and predicates.
+
+ Parameters
+ ----------
+ triples: list or array
+ List or array with Nx3 elements (subject, predicate, object).
+
+ Returns
+ -------
+ objects : list
+ Result of a query, list of objects.
- if batches_count == -1:
- batch_size = 1
- batches_count = self.get_size(dataset_type)
- else:
- batch_size = int(np.ceil(self.get_size(dataset_type) / batches_count))
-
- select_query = "SELECT subject, predicate,object FROM triples_table INDEXED BY \
- triples_table_type_idx where dataset_type ='{}' LIMIT {}, {}"
-
- for i in range(batches_count):
- conn = sqlite3.connect("{}".format(self.dbname))
- cur1 = conn.cursor()
- cur1.execute(select_query.format(dataset_type, i * batch_size, batch_size))
- out = np.array(cur1.fetchall(), dtype=np.int32)
- cur1.close()
- if use_filter:
- # get the filter values
- participating_objects, participating_subjects = self.get_participating_entities(out)
- yield out, participating_objects, participating_subjects
- else:
- yield out
+ """
+ results = []
+ if self.use_filter is False or self.use_filter is None:
+ self.use_filter = {"train": self.data}
+ filtered = []
+ valid_filters = [
+ x[0]
+ for x in self._execute_query(
+ "SELECT DISTINCT\
+ dataset_type FROM triples_table"
+ )
+ ]
+ for filter_name, filter_source in self.use_filter.items():
+ if filter_name in valid_filters:
+ tmp_filter = []
+ for triple in triples:
+ query = 'select distinct object from triples_table\
+ INDEXED BY triples_table_sp_idx where subject in \
+ ({}) and predicate in ({}) and dataset_type ="{}"'
+
+ query = query.format(triple[0], triple[1], filter_name)
+ q = self._execute_query(query)
+ tmp = list(set([y for x in q for y in x]))
+ tmp_filter.append(tmp)
+ filtered.append(tmp_filter)
+ # Unpack data into one list per triple no matter what filter
+ # it comes from
+ unpacked = list(zip(*filtered))
+ for k in unpacked:
+ lst = [j for i in k for j in i]
+ results.append(lst)
+
+ return results
+
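+    # Output sketch: one list of candidate object indexes per input triple,
+    # merged across every filter in `use_filter`, e.g. for two triples the
+    # result looks like [[o1, o2, ...], [o3, ...]].
+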
+ def _get_complementary_subjects(self, triples, use_filter=None):
+ """For a given triple retrieve all triples with same objects
+ and predicates.
+
+ Parameters
+ ----------
+        triples: list or array
+            List or array with Nx3 elements (subject, predicate, object).
+
+ Returns
+ -------
+ subjects : list
+ Result of a query, list of subjects.
- def _insert_triples(self, triples, key=""):
- """inserts triples in the database for the specified category
"""
- conn = sqlite3.connect("{}".format(self.dbname))
- key = np.array([[key]])
- for j in range(int(np.ceil(triples.shape[0] / 500000.0))):
- pg_triple_values = triples[j * 500000:(j + 1) * 500000]
- pg_triple_values = np.concatenate((pg_triple_values, np.repeat(key,
- pg_triple_values.shape[0], axis=0)), axis=1)
- pg_triple_values = pg_triple_values.tolist()
- cur = conn.cursor()
- cur.executemany('INSERT INTO triples_table VALUES (?,?,?,?)', pg_triple_values)
- conn.commit()
- cur.close()
-
- conn.close()
-
- def map_data(self, remap=False):
- """map the data to the mappings of ent_to_idx and rel_to_idx
+ results = []
+ if self.use_filter is False or self.use_filter is None:
+ self.use_filter = {"train": self.data}
+
+ filtered = []
+ valid_filters = [
+ x[0]
+ for x in self._execute_query(
+ "SELECT DISTINCT\
+ dataset_type FROM triples_table"
+ )
+ ]
+ for filter_name, filter_source in self.use_filter.items():
+ if filter_name in valid_filters:
+ tmp_filter = []
+ for triple in triples:
+ query = 'select distinct subject from triples_table \
+ INDEXED BY triples_table_po_idx where predicate \
+ in ({}) and object in ({})\
+ and dataset_type ="{}"'
+
+ query = query.format(triple[1], triple[2], filter_name)
+ q = self._execute_query(query)
+ tmp = list(set([y for x in q for y in x]))
+ tmp_filter.append(tmp)
+ filtered.append(tmp_filter)
+ # Unpack data into one list per triple no matter what
+ # filter it comes from
+ unpacked = list(zip(*filtered))
+ for k in unpacked:
+ lst = [j for i in k for j in i]
+ results.append(lst)
+ return results
+
+ def _get_complementary_entities(self, triples, use_filter=None):
+ """Returns the participating entities in the relation
+ ?-p-o and s-p-?.
+
Parameters
----------
- remap : boolean
- remap the data, if already mapped. One would do this if the dictionary is updated.
- """
- if self.using_existing_db:
- # since the assumption is that the persisted data is already mapped for an existing db
- return
- from ..evaluation import to_idx
- if len(self.rel_to_idx) == 0 or len(self.ent_to_idx) == 0:
- self.generate_mappings()
-
- for key in self.dataset.keys():
- if isinstance(self.dataset[key], np.ndarray):
- if (not self.mapped_status[key]) or (remap is True):
- self.dataset[key] = to_idx(self.dataset[key],
- ent_to_idx=self.ent_to_idx,
- rel_to_idx=self.rel_to_idx)
- self.mapped_status[key] = True
- if not self.persistance_status[key]:
- self._insert_triples(self.dataset[key], key)
- self.persistance_status[key] = True
-
- conn = sqlite3.connect("{}".format(self.dbname))
- cur = conn.cursor()
- # to maintain integrity of data
- cur.execute('Update integrity_check set validity=1 where validity=0')
- conn.commit()
-
- cur.execute('''CREATE TRIGGER IF NOT EXISTS triples_table_ins_integrity_check_trigger
- AFTER INSERT ON triples_table
- BEGIN
- Update integrity_check set validity=0 where validity=1;
- END
- ;
- ''')
- cur.execute('''CREATE TRIGGER IF NOT EXISTS triples_table_upd_integrity_check_trigger
- AFTER UPDATE ON triples_table
- BEGIN
- Update integrity_check set validity=0 where validity=1;
- END
- ;
- ''')
- cur.execute('''CREATE TRIGGER IF NOT EXISTS triples_table_del_integrity_check_trigger
- AFTER DELETE ON triples_table
- BEGIN
- Update integrity_check set validity=0 where validity=1;
- END
- ;
- ''')
-
- cur.execute('''CREATE TRIGGER IF NOT EXISTS entity_table_upd_integrity_check_trigger
- AFTER UPDATE ON entity_table
- BEGIN
- Update integrity_check set validity=0 where validity=1;
- END
- ;
- ''')
- cur.execute('''CREATE TRIGGER IF NOT EXISTS entity_table_ins_integrity_check_trigger
- AFTER INSERT ON entity_table
- BEGIN
- Update integrity_check set validity=0 where validity=1;
- END
- ;
- ''')
- cur.execute('''CREATE TRIGGER IF NOT EXISTS entity_table_del_integrity_check_trigger
- AFTER DELETE ON entity_table
- BEGIN
- Update integrity_check set validity=0 where validity=1;
- END
- ;
- ''')
- cur.close()
- conn.close()
-
- def _validate_data(self, data):
- """validates the data
- """
- if type(data) != np.ndarray:
- msg = 'Invalid type for input data. Expected ndarray, got {}'.format(type(data))
- raise ValueError(msg)
+ triples: ndarray of shape (N,3)
+ Triples (s-p-o) that we are querying.
- if (np.shape(data)[1]) != 3:
- msg = 'Invalid size for input data. Expected number of column 3, got {}'.format(np.shape(data)[1])
- raise ValueError(msg)
+ Returns
+ -------
+ entities: list, list
+ Two lists of subjects and objects participating in the
+ relations ?-p-o and s-p-?.
- def set_data(self, dataset, dataset_type=None, mapped_status=False, persistance_status=False):
- """set the dataset based on the type.
- Note: If you pass the same dataset type it will be appended
+ """
+ objects = self._get_complementary_objects(
+ triples, use_filter=use_filter
+ )
+ subjects = self._get_complementary_subjects(
+ triples, use_filter=use_filter
+ )
+ return subjects, objects
+
+ def _get_batch_generator(
+ self, batch_size=1, dataset_type="train", random=False, index_by=""
+ ):
+ """Generator that returns the next batch of data.
- #Usage for extremely large datasets:
- from ampligraph.datasets import SQLiteAdapter
- adapt = SQLiteAdapter()
+ Parameters
+ ----------
+ dataset_type: str
+ Indicates which dataset to use
+ (`"train"` | `"test"` | `"validation"`).
+ batch_size: int
+ Number of elements in a batch (default: :math:`1`).
+ index_by: str
+ Possible values: `{"", "so", "os", "s", "o"}`. It indicates
+ whether to use index and which to use:
+ index by subject (`"s"`), object (`"o"`) or both (`"so"`, `"os"`).
+            Indexes were created for these fields, so SQLite should use them
+            here to speed up retrieval; see the example below:
+ sqlite> EXPLAIN QUERY PLAN SELECT * FROM triples_table\
+ ORDER BY subject, object LIMIT 7000, 30;
+ QUERY PLAN
+ `--SCAN TABLE triples_table USING INDEX triples_table_sub_obj_idx
+ random: bool
+            Whether to get records from the database in a random order.
+
+ Yields
+ -------
+ batch_output : ndarray
+            Yields a batch of triples from the specified dataset type.
+ participating_entities : list
+ List of all entities that were involved in the s-p-? and
+ ?-p-o relations. This is returned only if ``use_filter=True``.
- #compute the mappings from the large dataset.
- #Let's assume that the mappings are already computed in rel_to_idx, ent_to_idx.
- #Set the mappings
- adapt.use_mappings(rel_to_idx, ent_to_idx)
+ """
+ if not isinstance(dataset_type, str):
+ dataset_type = dataset_type.decode("utf-8")
+ cond = f"where dataset_type ='{dataset_type}'"
+ size = self.get_data_size(condition=cond)
+ # focusE: size ppi55k = 230929
+ self.batches_count = int(np.ceil(size / batch_size))
+ logger.debug("batches count: {}".format(self.batches_count))
+ logger.debug("size of data: {}".format(size))
+ index = ""
+ if index_by != "":
+ if (
+ index_by == "s"
+ or index_by == "o"
+ or index_by == "so"
+ or index_by == "os"
+ ) and random:
+ msg = "Field index_by can only be used with random set\
+ to False and can only take values from this\
+ set: {{s,o,so,os,''}},\
+ instead got: {}".format(
+ index_by
+ )
+ logger.error(msg)
+ raise Exception(msg)
+
+ if index_by == "s":
+ index = "ORDER BY subject"
+ if index_by == "o":
+ index = "ORDER BY object"
+ if index_by == "so" or index_by == "os":
+ index = "ORDER BY subject, object"
+ if index == "" and random:
+ index = "ORDER BY random()"
+ query_template = "SELECT * FROM triples_table INDEXED BY \
+ triples_table_type_idx where\
+ dataset_type ='{}' {} LIMIT {}, {};"
+
+ for i in range(self.batches_count):
+ # logger.debug("BATCH NUMBER: {}".format(i))
+ # logger.debug(i * batch_size)
+ query = query_template.format(
+ dataset_type, index, i * batch_size, batch_size
+ )
+ # logger.debug(query)
+ out = self._execute_query(query)
+ # logger.debug(out)
+ if out:
+ triples = np.array(out)[:, :3].astype(np.int32)
+ # focusE
+ if self.data_shape > 3:
+ weights = np.array(out)[:, 3:-1]
- #load and store parts of data in the db as train test or valid
- #if you have already mapped the entity names to index, set mapped_status = True
- adapt.set_data(load_part1, 'train', mapped_status = True)
- adapt.set_data(load_part2, 'train', mapped_status = True)
- adapt.set_data(load_part3, 'train', mapped_status = True)
+ else:
+ weights = np.array([])
- #if mapped_status = False, then the adapter will map the entities to index before persisting
- adapt.set_data(load_part1, 'test', mapped_status = False)
- adapt.set_data(load_part2, 'test', mapped_status = False)
+ if self.use_filter:
+ # get the filter values
+ participating_entities = self._get_complementary_entities(
+ triples
+ )
+ if self.data_shape > 3:
+                        yield triples, tf.ragged.constant(
+                            participating_entities
+                        ), weights
+ else:
+ yield triples, tf.ragged.constant(participating_entities)
+ else:
+ if self.data_shape > 3:
+ yield triples, weights
+ else:
+ yield triples
- adapt.set_data(load_part1, 'valid', mapped_status = False)
- adapt.set_data(load_part2, 'valid', mapped_status = False)
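+    # Usage sketch (assumes filtering is disabled and plain triples, i.e.
+    # data_shape == 3): iterate deterministically over indexed training
+    # batches of 500 triples:
+    #
+    #     for batch in adapter._get_batch_generator(batch_size=500,
+    #                                               dataset_type="train"):
+    #         ...  # batch: int32 ndarray of shape (<=500, 3)
+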
+ def summary(self, count=True): # FocusE fix types
+ """Prints summary of the database.
- #create the model
- model = ComplEx(batches_count=10000, seed=0, epochs=10, k=50, eta=10)
- model.fit(adapt)
+        The information displayed includes: whether the database exists,
+        which tables it contains, how many records each table holds
+        (if ``count=True``), and the fields with their types, together
+        with an example record.
Parameters
----------
- dataset : nd-array or dictionary
- dataset of triples
- dataset_type : string
- if the dataset parameter is an nd- array then this indicates the type of the data being based
- mapped_status : bool
- indicates whether the data has already been mapped to the indices
- persistance_status : bool
- indicates whether the data has already been written to the database
+ count: bool
+            Whether to count the number of records per table
+            (can be time-consuming).
+
+ Example
+ -------
+ >>> adapter = SQLiteAdapter("database_24-06-2020_03-51-12_PM.db")
+ >>> with adapter as db:
+ >>> db.summary()
+ Summary for Database database_29-06-2020_09-37-20_AM.db
+ File size: 3.9453MB
+ Tables: triples_table
+            TRIPLES_TABLE
+            +----+---------------+-----------------+--------------+--------------------------+
+            |    | subject (int) | predicate (int) | object (int) | dataset_type (text(50))  |
+            +----+---------------+-----------------+--------------+--------------------------+
+            |e.g.| 34321         | 29218           | 38102        | train                    |
+            +----+---------------+-----------------+--------------+--------------------------+
+ Records: 59070
+
"""
- if self.using_existing_db:
- raise Exception('Cannot change the existing DB')
-
- if isinstance(dataset, dict):
- for key in dataset.keys():
- self._validate_data(dataset[key])
- self.dataset[key] = dataset[key]
- self.mapped_status[key] = mapped_status
- self.persistance_status[key] = persistance_status
- elif dataset_type is not None:
- self._validate_data(dataset)
- self.dataset[dataset_type] = dataset
- self.mapped_status[dataset_type] = mapped_status
- self.persistance_status[dataset_type] = persistance_status
+ if os.path.exists(self.db_path):
+ print("Summary for Database {}".format(self.db_name))
+ print("Located in {}".format(self.db_path))
+ file_size = os.path.getsize(self.db_path)
+ summary = """File size: {:.5}{}\nTables: {}"""
+ tables = self._execute_query(
+ "SELECT name FROM sqlite_master\
+ WHERE type='table';"
+ )
+ tables_names = ", ".join(table[0] for table in tables)
+ print(
+ summary.format(
+ *get_human_readable_size(file_size), tables_names
+ )
+ )
+ types = {"integer": "int", "float": "float", "string": "str"}
+            # float added for focusE
+ for table_name in tables:
+ result = self._execute_query(
+ "PRAGMA table_info('%s')" % table_name
+ )
+ cols_name_type = [
+ "{} ({}):".format(
+ x[1], types[x[2]] if x[2] in types else x[2]
+ )
+ for x in result
+ ] # FocusE
+ length = len(cols_name_type)
+ print(
+ "-------------\n|"
+ + table_name[0].upper()
+ + "|\n-------------\n"
+ )
+ formatted_record = "{:7s}{}\n{:7s}{}".format(
+ " ", "{:25s}" * length, "e.g.", "{:<25s}" * length
+ )
+ msg = ""
+ example = ["-"] * length
+ if count:
+ nb_records = self.get_data_size(table_name[0])
+ msg = "\n\nRecords: {}".format(nb_records)
+ if nb_records != 0:
+ record = self._execute_query(
+ f"SELECT * FROM\
+ {table_name[0]} LIMIT {1};"
+ )[0]
+ example = [str(rec) for rec in record]
+ else:
+ print("Count is set to False hence no data displayed")
+
+ print(formatted_record.format(*cols_name_type, *example), msg)
else:
- raise Exception("Incorrect usage. Expected a dictionary or a combination of dataset and it's type.")
+ logger.debug("Database does not exist.")
- if not (len(self.rel_to_idx) == 0 or len(self.ent_to_idx) == 0):
- self.map_data()
+ def _load(self, data_source, dataset_type="train"):
+ """Loads data from the data source to the database. Wrapper
+ around populate method, required by the GraphDataLoader interface.
- def get_participating_entities(self, x_triple):
- """returns the participating entities in the relation ?-p-o and s-p-?
Parameters
----------
- x_triple : nd-array (3,)
- triple (s-p-o) that we are querying
- Returns
- -------
- ent_participating_as_objects : nd-array (n,1)
- entities participating in the relation s-p-?
- ent_participating_as_subjects : nd-array (n,1)
- entities participating in the relation ?-p-o
- """
- x_triple = np.squeeze(x_triple)
- conn = sqlite3.connect("{}".format(self.dbname))
- cur1 = conn.cursor()
- cur2 = conn.cursor()
- cur_integrity = conn.cursor()
- cur_integrity.execute("SELECT * FROM integrity_check")
-
- if cur_integrity.fetchone()[0] == 0:
- raise Exception('Data integrity is corrupted. The tables have been modified.')
-
- query1 = "select " + str(x_triple[2]) + " union select distinct object from triples_table INDEXED BY \
- triples_table_sp_idx where subject=" + str(x_triple[0]) + " and predicate=" + str(x_triple[1])
- query2 = "select " + str(x_triple[0]) + " union select distinct subject from triples_table INDEXED BY \
- triples_table_po_idx where predicate=" + str(x_triple[1]) + " and object=" + str(x_triple[2])
-
- cur1.execute(query1)
- cur2.execute(query2)
-
- ent_participating_as_objects = np.array(cur1.fetchall())
- ent_participating_as_subjects = np.array(cur2.fetchall())
- '''
- if ent_participating_as_objects.ndim>=1:
- ent_participating_as_objects = np.squeeze(ent_participating_as_objects)
-
- if ent_participating_as_subjects.ndim>=1:
- ent_participating_as_subjects = np.squeeze(ent_participating_as_subjects)
- '''
- cur1.close()
- cur2.close()
- cur_integrity.close()
- conn.close()
-
- return ent_participating_as_objects, ent_participating_as_subjects
-
- def cleanup(self):
- """Clean up the database
+ data_source: str or ndarray
+            Numpy array or path to a file (e.g., a csv file) from which
+            to read data.
+ dataset_type: str
+ Kind of dataset that is being loaded
+ (`"train"` | `"test"` | `"validation"`).
+
"""
- if self.using_existing_db:
- # if using an existing db then dont remove
- self.dbname = None
- self.using_existing_db = False
- return
-
- # Drop the created tables
- if self.dbname is not None:
- conn = sqlite3.connect("{}".format(self.dbname))
- cur = conn.cursor()
- cur.execute("drop trigger IF EXISTS entity_table_del_integrity_check_trigger")
- cur.execute("drop trigger IF EXISTS entity_table_ins_integrity_check_trigger")
- cur.execute("drop trigger IF EXISTS entity_table_upd_integrity_check_trigger")
-
- cur.execute("drop trigger IF EXISTS triples_table_del_integrity_check_trigger")
- cur.execute("drop trigger IF EXISTS triples_table_upd_integrity_check_trigger")
- cur.execute("drop trigger IF EXISTS triples_table_ins_integrity_check_trigger")
- cur.execute("drop table IF EXISTS integrity_check")
- cur.execute("drop index IF EXISTS triples_table_po_idx")
- cur.execute("drop index IF EXISTS triples_table_sp_idx")
- cur.execute("drop index IF EXISTS triples_table_type_idx")
- cur.execute("drop table IF EXISTS triples_table")
- cur.execute("drop table IF EXISTS entity_table")
- cur.close()
- conn.close()
- try:
- if self.temp_dir is not None:
- self.temp_dir.cleanup()
- except OSError:
- logger.warning('Unable to remove the created temperory files.')
- logger.warning('Filename:{}'.format(self.dbname))
- self.dbname = None
+ self.data_source = data_source
+ self.populate(self.data_source, dataset_type=dataset_type)
+
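+    # Interface note: GraphDataLoader calls _load(data_source, dataset_type),
+    # which simply forwards to populate() above, e.g.
+    # adapter._load("test.csv", dataset_type="test") for an illustrative file.
+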
+ def _intersect(self, dataloader):
+ if not isinstance(dataloader.backend, SQLiteAdapter):
+ msg = "Provided dataloader should be of type SQLiteAdapter\
+ backend, instead got {}.".format(
+ type(dataloader.backend)
+ )
+ logger.error(msg)
+ raise Exception(msg)
+ raise NotImplementedError
+
+ def _clean(self):
+ os.remove(self.db_path)
+ self.mapper.clean()
diff --git a/ampligraph/discovery/__init__.py b/ampligraph/discovery/__init__.py
index 056a3f14..700dc0c2 100644
--- a/ampligraph/discovery/__init__.py
+++ b/ampligraph/discovery/__init__.py
@@ -1,20 +1,29 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-r"""This module includes a number of functions to perform knowledge discovery in graph embeddings.
+r"""This module includes a number of functions to perform knowledge discovery
+in graph embeddings.
-Functions provided include ``discover_facts`` which will generate candidate statements using one of several
-defined strategies and return triples that perform well when evaluated against corruptions, ``find_clusters`` which
-will perform link-based cluster analysis on a knowledge graph, ``find_duplicates`` which will find duplicate entities
-in a graph based on their embeddings, and ``query_topn`` which when given two elements of a triple will return
-the top_n results of all possible completions ordered by predicted score.
+Functions provided include ``discover_facts`` which will generate candidate
+statements using one of several defined strategies and return triples that
+perform well when evaluated against corruptions, ``find_clusters`` which
+will perform link-based cluster analysis on a knowledge graph,
+``find_duplicates`` which will find duplicate entities
+in a graph based on their embeddings, and ``query_topn`` which when given
+two elements of a triple will return the top_n results of all possible
+completions ordered by predicted score.
"""
-from .discovery import discover_facts, find_clusters, find_duplicates, query_topn, find_nearest_neighbours
+from .discovery import (
+ discover_facts,
+ find_clusters,
+ find_duplicates,
+ query_topn,
+)
-__all__ = ['discover_facts', 'find_clusters', 'find_duplicates', 'query_topn', 'find_nearest_neighbours']
+__all__ = ["discover_facts", "find_clusters", "find_duplicates", "query_topn"]
diff --git a/ampligraph/discovery/discovery.py b/ampligraph/discovery/discovery.py
index 7a41d755..f084eac1 100644
--- a/ampligraph/discovery/discovery.py
+++ b/ampligraph/discovery/discovery.py
@@ -1,23 +1,32 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
+
import logging
+
+import networkx as nx
import numpy as np
+from scipy import optimize, spatial
from sklearn.cluster import DBSCAN
from sklearn.neighbors import NearestNeighbors
-from scipy import optimize, spatial
-import networkx as nx
-from ..evaluation import evaluate_performance, filter_unseen_entities
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
-def discover_facts(X, model, top_n=10, strategy='random_uniform', max_candidates=100, target_rel=None, seed=0):
+def discover_facts(
+ X,
+ model,
+ top_n=10,
+ strategy="random_uniform",
+ max_candidates=100,
+ target_rel=None,
+ seed=0,
+):
"""
Discover new facts from an existing knowledge graph.
@@ -33,18 +42,18 @@ def discover_facts(X, model, top_n=10, strategy='random_uniform', max_candidates
The majority of the strategies are implemented with the same underlying principle of searching for
candidate statements:
- - from among the less frequent entities ('entity_frequency'),
- - less connected entities ('graph_degree', cluster_coefficient'),
- - | less frequent local graph structures ('cluster_triangles', 'cluster_squares'), on the assumption that densely
- connected entities are less likely to have missing true statements.
- - | The remaining strategies ('random_uniform', 'exhaustive') generate candidate statements by a random sampling
- of entity and relations and exhaustively, respectively.
+ - from among the less frequent entities (`'entity_frequency'`),
+ - less connected entities (`'graph_degree'`, `'cluster_coefficient'`),
+ - | less frequent local graph structures (`'cluster_triangles'`, `'cluster_squares'`), on the assumption that
+ densely connected entities are less likely to have missing true statements.
+ - | The remaining strategies (`'random_uniform'`, `'exhaustive'`) generate candidate statements by a random
+ sampling of entities and relations or exhaustively, respectively.
.. warning::
Due to the significant amount of computation required to evaluate all triples using the 'exhaustive' strategy,
we do not recommend its use at this time.
- The function will automatically filter entities that haven't been seen by the model, and operates on
+ The function will automatically filter entities that have not been seen by the model, and operates on
the assumption that the model provided has been fit on the data ``X`` (determined heuristically), although ``X``
may be a subset of the original data, in which case a warning is shown.
@@ -54,106 +63,158 @@ def discover_facts(X, model, top_n=10, strategy='random_uniform', max_candidates
Parameters
----------
- X : ndarray, shape [n, 3]
+ X : ndarray of shape (n, 3)
The input knowledge graph used to train ``model``, or a subset of it.
model : EmbeddingModel
The trained model that will be used to score candidate facts.
top_n : int
The cutoff position in ranking to consider a candidate triple as true positive.
- strategy: string
+ strategy: str
The candidates generation strategy:
- - 'random_uniform' : generates N candidates (N <= max_candidates) based on a uniform sampling of entities.
- - 'entity_frequency' : generates candidates by weighted sampling of entities using entity frequency.
- - 'graph_degree' : generates candidates by weighted sampling of entities with graph degree.
- - 'cluster_coefficient' : generates candidates by weighted sampling entities with clustering coefficient.
- - 'cluster_triangles' : generates candidates by weighted sampling entities with cluster triangles.
- - 'cluster_squares' : generates candidates by weighted sampling entities with cluster squares.
+ - `'random_uniform'` : generates `N` candidates (:math:`N <= max_candidates`) based on a uniform sampling of
+ entities.
+ - `'entity_frequency'` : generates candidates by weighted sampling of entities using entity frequency.
+ - `'graph_degree'` : generates candidates by weighted sampling of entities with graph degree.
+ - `'cluster_coefficient'` : generates candidates by weighted sampling entities with clustering coefficient.
+ - `'cluster_triangles'` : generates candidates by weighted sampling entities with cluster triangles.
+ - `'cluster_squares'` : generates candidates by weighted sampling entities with cluster squares.
max_candidates: int or float
- The maximum numbers of candidates generated by 'strategy'.
- Can be an absolute number or a percentage [0,1] of the size of the ```X``` parameter.
+ The maximum numbers of candidates generated by ``strategy``.
+ Can be an absolute number or a percentage [0,1] of the size of the `X` parameter.
target_rel : str or list(str)
         Target relations to focus on. The function will discover facts only for those specific relation types.
- If None, the function attempts to discover new facts for all relation types in the graph.
+ If `None`, the function attempts to discover new facts for all relation types in the graph.
seed : int
Seed to use for reproducible results.
Returns
-------
- X_pred : ndarray, shape [n, 3]
+ X_pred : ndarray, shape (n, 3)
A list of new facts predicted to be true.
- Examples
- --------
+ Example
+ -------
>>> import requests
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
>>> from ampligraph.datasets import load_from_csv
- >>> from ampligraph.latent_features import ComplEx
>>> from ampligraph.discovery import discover_facts
- >>>
>>> # Game of Thrones relations dataset
>>> url = 'https://ampligraph.s3-eu-west-1.amazonaws.com/datasets/GoT.csv'
>>> open('GoT.csv', 'wb').write(requests.get(url).content)
>>> X = load_from_csv('.', 'GoT.csv', sep=',')
- >>>
- >>> model = ComplEx(batches_count=10, seed=0, epochs=200, k=150, eta=5,
- >>> optimizer='adam', optimizer_params={'lr':1e-3},
- >>> loss='multiclass_nll', regularizer='LP',
- >>> regularizer_params={'p':3, 'lambda':1e-5},
- >>> verbose=True)
- >>> model.fit(X)
- >>>
- >>> discover_facts(X, model, top_n=3, max_candidates=20000, strategy='entity_frequency',
- >>> target_rel='ALLIED_WITH', seed=42)
- array([['House Reed of Greywater Watch', 'ALLIED_WITH', 'Sybelle Glover'],
- ['Hugo Wull', 'ALLIED_WITH', 'House Norrey'],
- ['House Grell', 'ALLIED_WITH', 'Delonne Allyrion'],
- ['Lorent Lorch', 'ALLIED_WITH', 'House Ruttiger']], dtype=object)
-
-
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=300,
+ >>> scoring_type='ComplEx')
+ >>> model.compile(optimizer='adam', loss='multiclass_nll')
+ >>> model.fit(X,
+ >>> batch_size=100,
+ >>> epochs=10,
+ >>> validation_freq=50,
+ >>> validation_batch_size=100,
+ >>> validation_data = dataset['valid'])
+ >>> discover_facts(X,
+ >>> model,
+ >>> top_n=100,
+ >>> strategy='random_uniform',
+ >>> max_candidates=100,
+ >>> target_rel='ALLIED_WITH',
+ >>> seed=0)
+ Epoch 1/10
+ 33/33 [==============================] - 1s 27ms/step - loss: 177.7778
+ Epoch 2/10
+ 33/33 [==============================] - 0s 6ms/step - loss: 177.4795
+ Epoch 3/10
+ 33/33 [==============================] - 0s 6ms/step - loss: 176.9654
+ Epoch 4/10
+ 33/33 [==============================] - 0s 6ms/step - loss: 175.8453
+ Epoch 5/10
+ 33/33 [==============================] - 0s 6ms/step - loss: 173.4385
+ Epoch 6/10
+ 33/33 [==============================] - 0s 6ms/step - loss: 168.8143
+ Epoch 7/10
+ 33/33 [==============================] - 0s 6ms/step - loss: 161.2919
+ Epoch 8/10
+ 33/33 [==============================] - 0s 6ms/step - loss: 151.3496
+ Epoch 9/10
+ 33/33 [==============================] - 0s 6ms/step - loss: 140.4268
+ Epoch 10/10
+ 33/33 [==============================] - 0s 5ms/step - loss: 129.8206
+ 3175 triples containing invalid keys skipped!
+ (array([['House Nymeros Martell of Sunspear', 'ALLIED_WITH',
+ 'House Mallister of Seagard'],
+ ['Ben', 'ALLIED_WITH', 'House Mallister of Seagard'],
+ ['Selwyn Tarth', 'ALLIED_WITH', 'House Mallister of Seagard'],
+ ['Clarence Charlton', 'ALLIED_WITH', 'House Woods'],
+ ['Selwyn Tarth', 'ALLIED_WITH', 'House Woods'],
+ ['Dacks', 'ALLIED_WITH', 'Titus Peake'],
+ ['Barra', 'ALLIED_WITH', 'Titus Peake'],
+ ['House Chelsted', 'ALLIED_WITH', 'Denys Darklyn'],
+ ['Crow Spike Keep', 'ALLIED_WITH', 'Denys Darklyn'],
+ ['Selwyn Tarth', 'ALLIED_WITH', 'Denys Darklyn'],
+ ['House Chelsted', 'ALLIED_WITH', 'House Belmore of Strongsong'],
+ ['Barra', 'ALLIED_WITH', 'House Belmore of Strongsong'],
+ ['Walder Frey', 'ALLIED_WITH', 'House Belmore of Strongsong']],
+ dtype=object),
+ array([ 2. , 53. , 73. , 42. , 18. , 59.5, 86. , 76.5, 31. , 60.5, 31.5,
+ 32. , 24. ]))
"""
+ if model.is_backward:
+ model = model.model
if not model.is_fitted:
- msg = 'Model is not fitted.'
+ msg = "Model is not fitted."
logger.error(msg)
raise ValueError(msg)
- if not model.is_fitted_on(X):
- msg = 'Model might not be fitted on this data.'
- logger.warning(msg)
- # raise ValueError(msg)
-
- if strategy not in ['random_uniform', 'entity_frequency', 'graph_degree', 'cluster_coefficient',
- 'cluster_triangles', 'cluster_squares']:
- msg = '%s is not a valid strategy.' % strategy
+ # if not model.is_fitted_on(X):
+ # msg = 'Model might not be fitted on this data.'
+ # logger.warning(msg)
+ # raise ValueError(msg)
+
+ if strategy not in [
+ "random_uniform",
+ "entity_frequency",
+ "graph_degree",
+ "cluster_coefficient",
+ "cluster_triangles",
+ "cluster_squares",
+ ]:
+ msg = "%s is not a valid strategy." % strategy
logger.error(msg)
raise ValueError(msg)
- if strategy == 'exhaustive':
- msg = 'Strategy is `exhaustive`, ignoring max_candidates.'
+ if strategy == "exhaustive":
+ msg = "Strategy is `exhaustive`, ignoring max_candidates."
logger.info(msg)
if isinstance(max_candidates, float):
- logger.debug('Converting max_candidates float value {} to int value {}'.format(max_candidates,
- int(max_candidates * len(X))))
+ logger.debug(
+ "Converting max_candidates float value {} to int value {}".format(
+ max_candidates, int(max_candidates * len(X))
+ )
+ )
max_candidates = int(max_candidates * len(X))
if isinstance(target_rel, str):
target_rel = [target_rel]
if target_rel is None:
- msg = 'No target relation specified. Using all relations to generate candidate statements.'
+ msg = "No target relation specified. Using all relations to generate candidate statements."
logger.info(msg)
- rel_list = [x for x in model.rel_to_idx.keys()]
+ rel_list = [x for x in model.data_indexer.backend.get_all_relations()]
else:
missing_rels = []
for rel in target_rel:
- if rel not in model.rel_to_idx.keys():
+ if rel not in model.data_indexer.backend.get_all_relations():
missing_rels.append(rel)
if len(missing_rels) > 0:
- msg = 'Target relation(s) not found in model: {}'.format(missing_rels)
+ msg = "Target relation(s) not found in model: {}".format(
+ missing_rels
+ )
logger.error(msg)
raise ValueError(msg)
@@ -163,23 +224,31 @@ def discover_facts(X, model, top_n=10, strategy='random_uniform', max_candidates
np.random.seed(seed)
# Remove unseen entities
- X_filtered = filter_unseen_entities(X, model)
+ # X_filtered = filter_unseen_entities(X, model)
discoveries = []
discovery_ranks = []
# Iterate through relations
for relation in rel_list:
+ logger.info("Generating candidates for relation: %s" % relation)
- logger.info('Generating candidates for relation: %s' % relation)
-
- candidates = generate_candidates(X_filtered, strategy, relation, max_candidates, seed=seed)
+ candidates = generate_candidates(
+ X, strategy, relation, max_candidates, seed=seed
+ )
- logger.debug('Generated %d candidate statements.' % len(candidates))
+ logger.debug("Generated %d candidate statements." % len(candidates))
# Get ranks of candidate statements
- ranks = evaluate_performance(candidates, model=model, filter_triples=X, use_default_protocol=True,
- verbose=False)
+ # ranks = evaluate_performance(candidates, model=model, filter_triples=X, use_default_protocol=True,
+ # verbose=False)
+
+ ranks = model.evaluate(
+ candidates,
+ use_filter={"test": X},
+ corrupt_side="s,o",
+ verbose=False,
+ )
# Select candidate statements within the top_n predicted ranks standard protocol evaluates against
# corruptions on both sides, we just average the ranks here
@@ -189,93 +258,104 @@ def discover_facts(X, model, top_n=10, strategy='random_uniform', max_candidates
discoveries.append(candidates[preds])
discovery_ranks.append(avg_ranks[preds])
- logger.info('Discovered %d facts' % len(discoveries))
+ logger.info("Discovered %d facts" % len(discoveries))
return np.hstack(discoveries), np.hstack(discovery_ranks)
-def generate_candidates(X, strategy, target_rel, max_candidates, consolidate_sides=False, seed=0):
- """ Generate candidate statements from an existing knowledge graph using a defined strategy.
+def generate_candidates(
+ X, strategy, target_rel, max_candidates, consolidate_sides=False, seed=0
+):
+ """Generate candidate statements from an existing knowledge graph using a defined strategy.
- Parameters
- ----------
+ Parameters
+ ----------
+ X: np.array, shape (n, 3)
+ Triples from which to discover new facts.
+ strategy: str
+ The candidates generation strategy.
+ - `'random_uniform'` : generates `N` candidates (:math:`N <= max_candidates`) based on a uniform random
+ sampling of head and tail entities.
+ - `'entity_frequency'` : generates candidates by sampling entities with low frequency.
+ - `'graph_degree'` : generates candidates by sampling entities with a low graph degree.
+ - `'cluster_coefficient'` : generates candidates by sampling entities with a low clustering coefficient.
+ - `'cluster_triangles'` : generates candidates by sampling entities with a low number of cluster triangles.
+ - `'cluster_squares'` : generates candidates by sampling entities with a low number of cluster squares.
+ max_candidates: int or float
+ The maximum numbers of candidates generated by ``strategy``.
+ Can be an absolute number or a percentage [0,1].
+ This does not guarantee the number of candidates generated.
+ target_rel : str
+ Target relation to focus on. The function will generate candidate
+ statements only with this specific relation type.
+ consolidate_sides: bool
+ If `True` will generate candidate statements as a product of unique head and tail entities, otherwise will
+ consider head and tail entities separately (default: `False`).
+ seed : int
+ Seed to use for reproducible results.
- strategy: string
- The candidates generation strategy.
- - 'random_uniform' : generates N candidates (N <= max_candidates) based on a uniform random sampling of
- head and tail entities.
- - 'entity_frequency' : generates candidates by sampling entities with low frequency.
- - 'graph_degree' : generates candidates by sampling entities with a low graph degree.
- - 'cluster_coefficient' : generates candidates by sampling entities with a low clustering coefficient.
- - 'cluster_triangles' : generates candidates by sampling entities with a low number of cluster triangles.
- - 'cluster_squares' : generates candidates by sampling entities with a low number of cluster squares.
- max_candidates: int or float
- The maximum numbers of candidates generated by 'strategy'.
- Can be an absolute number or a percentage [0,1].
- This does not guarantee the number of candidates generated.
- target_rel : str
- Target relation to focus on. The function will generate candidate
- statements only with this specific relation type.
- consolidate_sides: bool
- If True will generate candidate statements as a product of
- unique head and tail entities, otherwise will
- consider head and tail entities separately. Default: False.
- seed : int
- Seed to use for reproducible results.
+ Returns
+ -------
+ X_candidates : ndarray, shape (n, 3)
+ A list of candidate statements.
- Returns
- -------
- X_candidates : ndarray, shape [n, 3]
- A list of candidate statements.
-
-
- Examples
- --------
- >>> import numpy as np
- >>> from ampligraph.discovery.discovery import generate_candidates
- >>>
- >>> X = np.array([['a', 'y', 'b'],
- >>> ['b', 'y', 'a'],
- >>> ['a', 'y', 'c'],
- >>> ['c', 'y', 'a'],
- >>> ['a', 'y', 'd'],
- >>> ['c', 'y', 'd'],
- >>> ['b', 'y', 'c'],
- >>> ['f', 'y', 'e']])
-
- >>> X_candidates = generate_candidates(X, strategy='graph_degree', target_rel='y', max_candidates=3)
- >>> ([['a', 'y', 'e'],
- >>> ['f', 'y', 'a'],
- >>> ['c', 'y', 'e']])
- """
+ Example
+ -------
+ >>> import numpy as np
+ >>> from ampligraph.discovery.discovery import generate_candidates
+ >>>
+ >>> X = np.array([['a', 'y', 'b'],
+ >>> ['b', 'y', 'a'],
+ >>> ['a', 'y', 'c'],
+ >>> ['c', 'y', 'a'],
+ >>> ['a', 'y', 'd'],
+ >>> ['c', 'y', 'd'],
+ >>> ['b', 'y', 'c'],
+ >>> ['f', 'y', 'e']])
+
+ >>> X_candidates = generate_candidates(X, strategy='graph_degree', target_rel='y', max_candidates=3)
+ >>> ([['a', 'y', 'e'],
+ >>> ['f', 'y', 'a'],
+ >>> ['c', 'y', 'e']])
- if strategy not in ['random_uniform', 'entity_frequency',
- 'graph_degree', 'cluster_coefficient',
- 'cluster_triangles', 'cluster_squares']:
- msg = '%s is not a valid candidate generation strategy.' % strategy
+ """
+ if (
+ X.shape[1] > 3
+    ):  # drop the weight columns when weights are provided along with the triples
+ X = X[:, :3]
+ if strategy not in [
+ "random_uniform",
+ "entity_frequency",
+ "graph_degree",
+ "cluster_coefficient",
+ "cluster_triangles",
+ "cluster_squares",
+ ]:
+ msg = "%s is not a valid candidate generation strategy." % strategy
raise ValueError(msg)
if target_rel not in np.unique(X[:, 1]):
# No error as may be case where target_rel is not in X
- msg = 'Target relation is not found in triples.'
+ msg = "Target relation is not found in triples."
logger.warning(msg)
if not isinstance(max_candidates, (float, int)):
- msg = 'Parameter max_candidates must be a float or int.'
+ msg = "Parameter max_candidates must be a float or int."
raise ValueError(msg)
if max_candidates <= 0:
- msg = 'Parameter max_candidates must be a positive integer ' \
- 'or float in range (0,1].'
+ msg = (
+ "Parameter max_candidates must be a positive integer "
+ "or float in range (0,1]."
+ )
raise ValueError(msg)
if isinstance(max_candidates, float):
max_candidates = int(max_candidates * len(X))
def _filter_candidates(X_candidates, X, remove_reflexive=True):
- """ Inner function to filter candidate statements from X_candidates that are in X.
- """
+ """Inner function to filter candidate statements from X_candidates that are in X."""
X_candidates = _setdiff2d(X_candidates, X)
# Filter statements that are ['x', rel, 'x']
if remove_reflexive:
@@ -295,29 +375,41 @@ def _filter_candidates(X_candidates, X, remove_reflexive=True):
e_s = np.unique(X[:, 0])
e_o = np.unique(X[:, 2])
- logger.info('Generating candidates using {} strategy.'.format(strategy))
-
- if strategy == 'random_uniform':
+ logger.info("Generating candidates using {} strategy.".format(strategy))
- # Take close to sqrt of max_candidates so that: len(meshgrid result) == max_candidates
- sample_size = int(np.sqrt(max_candidates) + 10) # +10 to allow for reduction in sampled array due to filtering
+ if strategy == "random_uniform":
+ # Take close to sqrt of max_candidates so that: len(meshgrid result) ==
+ # max_candidates
+ # +10 to allow for reduction in sampled array due to filtering
+ sample_size = int(np.sqrt(max_candidates) + 10)
- X_candidates = np.zeros([max_candidates, 3], dtype=object) # Pre-allocate X_candidates array
- num_retries, max_retries = 0, 5 # Retry up to 5 times to reach max_candidates
+ # Pre-allocate X_candidates array
+ X_candidates = np.zeros([max_candidates, 3], dtype=object)
+ num_retries, max_retries = (
+ 0,
+ 5,
+ ) # Retry up to 5 times to reach max_candidates
start_idx, end_idx = 0, 0 #
while end_idx <= max_candidates - 1:
sample_e_s = np.random.choice(e_s, size=sample_size, replace=False)
sample_e_o = np.random.choice(e_o, size=sample_size, replace=False)
- gen_candidates = np.array(np.meshgrid(sample_e_s, target_rel, sample_e_o)).T.reshape(-1, 3)
+ gen_candidates = np.array(
+ np.meshgrid(sample_e_s, target_rel, sample_e_o)
+ ).T.reshape(-1, 3)
gen_candidates = _filter_candidates(gen_candidates, X)
- # Select either all of gen_candidates or just enough to fill X_candidates
- select_idx = min(len(gen_candidates), len(X_candidates) - start_idx)
+ # Select either all of gen_candidates or just enough to fill
+ # X_candidates
+ select_idx = min(
+ len(gen_candidates), len(X_candidates) - start_idx
+ )
end_idx = start_idx + select_idx
- X_candidates[start_idx:end_idx, :] = gen_candidates[0:select_idx, :]
+ X_candidates[start_idx:end_idx, :] = gen_candidates[
+ 0:select_idx, :
+ ]
start_idx = end_idx
num_retries += 1
@@ -327,21 +419,30 @@ def _filter_candidates(X_candidates, X, remove_reflexive=True):
# end_idx will equal max_candidates in most cases, but could be less
return X_candidates[0:end_idx, :]
- elif strategy == 'entity_frequency':
-
+ elif strategy == "entity_frequency":
# Get entity counts and sort them in ascending order
if consolidate_sides:
- e_s_counts = np.array(np.unique(X[:, [0, 2]], return_counts=True)).T
+ e_s_counts = np.array(
+ np.unique(X[:, [0, 2]], return_counts=True)
+ ).T
e_o_counts = e_s_counts
else:
e_s_counts = np.array(np.unique(X[:, 0], return_counts=True)).T
e_o_counts = np.array(np.unique(X[:, 2], return_counts=True)).T
- e_s_weights = e_s_counts[:, 1].astype(np.float64) / np.sum(e_s_counts[:, 1].astype(np.float64))
- e_o_weights = e_o_counts[:, 1].astype(np.float64) / np.sum(e_o_counts[:, 1].astype(np.float64))
-
- elif strategy in ['graph_degree', 'cluster_coefficient', 'cluster_triangles', 'cluster_squares']:
-
+ e_s_weights = e_s_counts[:, 1].astype(np.float64) / np.sum(
+ e_s_counts[:, 1].astype(np.float64)
+ )
+ e_o_weights = e_o_counts[:, 1].astype(np.float64) / np.sum(
+ e_o_counts[:, 1].astype(np.float64)
+ )
+
+ elif strategy in [
+ "graph_degree",
+ "cluster_coefficient",
+ "cluster_triangles",
+ "cluster_squares",
+ ]:
# Create networkx graph
G = nx.Graph()
for row in X:
@@ -349,13 +450,13 @@ def _filter_candidates(X_candidates, X, remove_reflexive=True):
G.add_edge(row[0], row[2], name=row[1])
# Calculate node metrics
- if strategy == 'graph_degree':
+ if strategy == "graph_degree":
C = {i: j for i, j in G.degree()}
- elif strategy == 'cluster_coefficient':
+ elif strategy == "cluster_coefficient":
C = nx.algorithms.cluster.clustering(G)
- elif strategy == 'cluster_triangles':
+ elif strategy == "cluster_triangles":
C = nx.algorithms.cluster.triangles(G)
- elif strategy == 'cluster_squares':
+ elif strategy == "cluster_squares":
C = nx.algorithms.cluster.square_clustering(G)
e_s_weights = np.array([C[x] for x in e_s], dtype=np.float64)
@@ -364,22 +465,34 @@ def _filter_candidates(X_candidates, X, remove_reflexive=True):
e_s_weights = e_s_weights / np.sum(e_s_weights)
e_o_weights = e_o_weights / np.sum(e_o_weights)
- # Take close to sqrt of max_candidates so that: len(meshgrid result) == max_candidates
- sample_size = int(np.sqrt(max_candidates) + 10) # +10 to allow for reduction in sampled array due to filtering
-
- X_candidates = np.zeros([max_candidates, 3], dtype=object) # Pre-allocate X_candidates array
- num_retries, max_retries = 0, 5 # Retry up to 5 times to reach max_candidates
+ # Take close to sqrt of max_candidates so that: len(meshgrid result) ==
+ # max_candidates
+ # +10 to allow for reduction in sampled array due to filtering
+ sample_size = int(np.sqrt(max_candidates) + 10)
+
+ # Pre-allocate X_candidates array
+ X_candidates = np.zeros([max_candidates, 3], dtype=object)
+ num_retries, max_retries = (
+ 0,
+ 5,
+ ) # Retry up to 5 times to reach max_candidates
start_idx, end_idx = 0, 0
while end_idx <= max_candidates - 1:
-
- sample_e_s = np.random.choice(e_s, size=sample_size, replace=True, p=e_s_weights)
- sample_e_o = np.random.choice(e_o, size=sample_size, replace=True, p=e_o_weights)
-
- gen_candidates = np.array(np.meshgrid(sample_e_s, target_rel, sample_e_o)).T.reshape(-1, 3)
+ sample_e_s = np.random.choice(
+ e_s, size=sample_size, replace=True, p=e_s_weights
+ )
+ sample_e_o = np.random.choice(
+ e_o, size=sample_size, replace=True, p=e_o_weights
+ )
+
+ gen_candidates = np.array(
+ np.meshgrid(sample_e_s, target_rel, sample_e_o)
+ ).T.reshape(-1, 3)
gen_candidates = _filter_candidates(gen_candidates, X)
- # Select either all of gen_candidates or just enough to fill X_candidates
+ # Select either all of gen_candidates or just enough to fill
+ # X_candidates
select_idx = min(len(gen_candidates), len(X_candidates) - start_idx)
end_idx = start_idx + select_idx
@@ -396,7 +509,7 @@ def _filter_candidates(X_candidates, X, remove_reflexive=True):
def _setdiff2d(A, B):
- """ Utility function equivalent to numpy.setdiff1d on 2d arrays.
+ """Utility function equivalent to numpy.setdiff1d on 2d arrays.
Parameters
----------
@@ -407,24 +520,24 @@ def _setdiff2d(A, B):
Returns
-------
- np.array, shape [k, m]
+ subset_A : np.array, shape [k, m]
Rows of A that are not in B.
"""
if len(A.shape) != 2 or len(B.shape) != 2:
- raise RuntimeError('Input arrays must be 2-dimensional.')
+ raise RuntimeError("Input arrays must be 2-dimensional.")
tmp = np.prod(np.swapaxes(A[:, :, None], 1, 2) == B, axis=2)
- return A[~ np.sum(np.cumsum(tmp, axis=0) * tmp == 1, axis=1).astype(bool)]
+ return A[~np.sum(np.cumsum(tmp, axis=0) * tmp == 1, axis=1).astype(bool)]
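+
+# Quick sanity-check sketch of _setdiff2d (illustrative values):
+# _setdiff2d(np.array([[1, 2], [3, 4]]), np.array([[3, 4]])) keeps only the
+# rows of A that do not appear in B, i.e. array([[1, 2]]).
+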
-def find_clusters(X, model, clustering_algorithm=DBSCAN(), mode="entity"):
+def find_clusters(X, model, clustering_algorithm=DBSCAN(), mode="e"):
"""
Perform link-based cluster analysis on a knowledge graph.
The clustering happens on the embedding space of the entities and relations.
- For example, if we cluster some entities of a model that uses `k=100` (i.e. embedding space of size 100),
+ For example, if we cluster some entities of a model that uses :math:`k=100` (i.e. embedding space of size 100),
we will apply the chosen clustering algorithm on the 100-dimensional space of the provided input samples.
Clustering can be used to evaluate the quality of the knowledge embeddings, by comparing to natural clusters.
@@ -437,15 +550,15 @@ def find_clusters(X, model, clustering_algorithm=DBSCAN(), mode="entity"):
Please see `scikit-learn documentation `_
for a list of algorithms, their parameters, and pros and cons.
- Clustering is exclusive (i.e. a triple is assigned to one and only one cluster).
+ Clustering is exclusive (i.e., a triple is assigned to one and only one cluster).
Parameters
----------
- X : ndarray, shape [n, 3] or [n]
+ X : ndarray, shape (n, 3) or (n)
The input to be clustered.
``X`` can either be the triples of a knowledge graph, its entities, or its relations.
- The argument ``mode`` defines whether ``X`` is supposed an array of triples
+ The argument ``mode`` defines whether ``X`` is supposed to be an array of triples
or an array of either entities or relations.
model : EmbeddingModel
The fitted model that will be used to generate the embeddings.
@@ -453,18 +566,20 @@ def find_clusters(X, model, clustering_algorithm=DBSCAN(), mode="entity"):
``fit()`` or from a helper function such as :meth:`ampligraph.evaluation.select_best_model_ranking`.
clustering_algorithm : object
The initialized object of the clustering algorithm.
- It should be ready to apply the `fit_predict` method.
+ It should be ready to apply the :meth:`fit_predict` method.
Please see: `scikit-learn documentation `_
to understand the clustering API provided by scikit-learn.
The default clustering model is
`sklearn's DBSCAN `_
with its default parameters.
- mode: string
- Clustering mode. Choose from:
+ mode: str
+ Clustering mode.
- - | 'entity' (default): the algorithm will cluster the embeddings of the provided entities.
- - | 'relation': the algorithm will cluster the embeddings of the provided relations.
- - | 'triple' : the algorithm will cluster the concatenation
+ Choose from:
+
+ - | `'e'` (default): the algorithm will cluster the embeddings of the provided entities.
+ - | `'r'`: the algorithm will cluster the embeddings of the provided relations.
+ - | `'t'` : the algorithm will cluster the concatenation
of the embeddings of the subject, predicate and object for each triple.
Returns
@@ -472,8 +587,8 @@ def find_clusters(X, model, clustering_algorithm=DBSCAN(), mode="entity"):
labels : ndarray, shape [n]
Index of the cluster each triple belongs to.
- Examples
- --------
+ Example
+ -------
>>> # Note seaborn, matplotlib, adjustText are not AmpliGraph dependencies.
>>> # and must therefore be installed manually as:
>>> #
@@ -491,7 +606,7 @@ def find_clusters(X, model, clustering_algorithm=DBSCAN(), mode="entity"):
>>> from adjustText import adjust_text
>>>
>>> from ampligraph.datasets import load_from_csv
- >>> from ampligraph.latent_features import ComplEx
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
>>> from ampligraph.discovery import find_clusters
>>>
>>> # International football matches triples
@@ -502,30 +617,26 @@ def find_clusters(X, model, clustering_algorithm=DBSCAN(), mode="entity"):
>>> open('football.csv', 'wb').write(requests.get(url).content)
>>> X = load_from_csv('.', 'football.csv', sep=',')[:, 1:]
>>>
- >>> model = ComplEx(batches_count=50,
- >>> epochs=300,
- >>> k=100,
- >>> eta=20,
- >>> optimizer='adam',
- >>> optimizer_params={'lr':1e-4},
- >>> loss='multiclass_nll',
- >>> regularizer='LP',
- >>> regularizer_params={'p':3, 'lambda':1e-5},
- >>> seed=0,
- >>> verbose=True)
- >>> model.fit(X)
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=300,
+ >>> scoring_type='ComplEx')
+ >>> model.compile(optimizer='adam', loss='multiclass_nll')
+ >>> model.fit(X,
+ >>> batch_size=10000,
+ >>> epochs=10)
>>>
>>> df = pd.DataFrame(X, columns=["s", "p", "o"])
>>>
>>> teams = np.unique(np.concatenate((df.s[df.s.str.startswith("Team")],
>>> df.o[df.o.str.startswith("Team")])))
- >>> team_embeddings = model.get_embeddings(teams, embedding_type='entity')
+ >>> team_embeddings = model.get_embeddings(teams, embedding_type='e')
>>>
>>> embeddings_2d = PCA(n_components=2).fit_transform(np.array([i for i in team_embeddings]))
>>>
>>> # Find clusters of embeddings using KMeans
+ >>>
>>> kmeans = KMeans(n_clusters=6, n_init=100, max_iter=500)
- >>> clusters = find_clusters(teams, model, kmeans, mode='entity')
+ >>> clusters = find_clusters(teams, model, kmeans, mode='e')
>>>
>>> # Plot results
>>> df = pd.DataFrame({"teams": teams, "clusters": "cluster" + pd.Series(clusters).astype(str),
@@ -542,9 +653,13 @@ def find_clusters(X, model, clustering_algorithm=DBSCAN(), mode="entity"):
>>> texts.append(plt.text(point['embedding1']+.02, point['embedding2'], str(point['teams'])))
>>> adjust_text(texts)
- .. image:: ../../docs/img/clustering/clustered_embeddings_docstring.png
+ .. image:: ../img/clustering/clustered_embeddings_docstring.png
+ :align: center
"""
+ if model.is_backward:
+ model = model.model
+
if not model.is_fitted:
msg = "Model has not been fitted."
logger.error(msg)
@@ -555,25 +670,27 @@ def find_clusters(X, model, clustering_algorithm=DBSCAN(), mode="entity"):
logger.error(msg)
raise ValueError(msg)
- modes = ("triple", "entity", "relation")
+ modes = ("t", "e", "r")
if mode not in modes:
- msg = "Argument `mode` must be one of the following: {}.".format(", ".join(modes))
+ msg = "Argument `mode` must be one of the following: {}.".format(
+ ", ".join(modes)
+ )
logger.error(msg)
raise ValueError(msg)
- if mode == "triple" and (len(X.shape) != 2 or X.shape[1] != 3):
- msg = "For 'triple' mode the input X must be a matrix with three columns."
+ if mode == "t" and (len(X.shape) != 2 or X.shape[1] != 3):
+ msg = "For 't' mode the input X must be a matrix with three columns."
logger.error(msg)
raise ValueError(msg)
- if mode in ("entity", "relation") and len(X.shape) != 1:
- msg = "For 'entity' or 'relation' mode the input X must be an array."
+ if mode in ("e", "r") and len(X.shape) != 1:
+ msg = "For 'e' or 'r' mode the input X must be an array."
raise ValueError(msg)
- if mode == "triple":
- s = model.get_embeddings(X[:, 0], embedding_type='entity')
- p = model.get_embeddings(X[:, 1], embedding_type='relation')
- o = model.get_embeddings(X[:, 2], embedding_type='entity')
+ if mode == "t":
+ s = model.get_embeddings(X[:, 0], embedding_type="e")
+ p = model.get_embeddings(X[:, 1], embedding_type="r")
+ o = model.get_embeddings(X[:, 2], embedding_type="e")
emb = np.hstack((s, p, o))
else:
emb = model.get_embeddings(X, embedding_type=mode)
@@ -581,14 +698,21 @@ def find_clusters(X, model, clustering_algorithm=DBSCAN(), mode="entity"):
return clustering_algorithm.fit_predict(emb)
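A minimal usage sketch of the new short mode values, assuming `model` and `X` are the fitted model and triples from the docstring example above (the cluster count is illustrative, not taken from the library):

import numpy as np
from sklearn.cluster import KMeans
from ampligraph.discovery import find_clusters

# mode='r' simply switches which embeddings are handed to the clustering
# algorithm: here the relation embeddings instead of the entity embeddings.
relations = np.unique(X[:, 1])
rel_clusters = find_clusters(relations, model, KMeans(n_clusters=3), mode='r')
print(dict(zip(relations, rel_clusters)))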
-def find_duplicates(X, model, mode="entity", metric='l2', tolerance='auto',
- expected_fraction_duplicates=0.1, verbose=False):
+def find_duplicates(
+ X,
+ model,
+ mode="e",
+ metric="l2",
+ tolerance="auto",
+ expected_fraction_duplicates=0.1,
+ verbose=False,
+):
r"""
Find duplicate entities, relations or triples in a graph based on their embeddings.
For example, say you have a movie dataset that was scraped off the web with possible duplicate movies.
The movies in this case are the entities.
- Therefore, you would use the 'entity' mode to find all the movies that could de duplicates of each other.
+ Therefore, you would use the `"e"` mode to find all the movies that could be duplicates of each other.
Duplicates are defined as points whose distance in the embedding space is smaller than
some given threshold (called the tolerance).
@@ -607,21 +731,23 @@ def find_duplicates(X, model, mode="entity", metric='l2', tolerance='auto',
Parameters
----------
- X : ndarray, shape [n, 3] or [n]
+ X : ndarray, shape (n, 3) or (n)
The input to be clustered.
- X can either be the triples of a knowledge graph, its entities, or its relations.
- The argument `mode` defines whether X is supposed an array of triples
+ ``X`` can either be the triples of a knowledge graph, its entities, or its relations.
+ The argument ``mode`` defines whether ``X`` is supposed to be an array of triples
or an array of either entities or relations.
model : EmbeddingModel
The fitted model that will be used to generate the embeddings.
This model must have been fully trained already, be it directly with ``fit()``
or from a helper function such as :meth:`ampligraph.evaluation.select_best_model_ranking`.
- mode: string
+ mode: str
+ Specifies whether to look for duplicate entities, relations, or triples.
+
Choose from:
- - | 'entity' (default): the algorithm will find duplicates of the provided entities based on their embeddings.
- - | 'relation': the algorithm will find duplicates of the provided relations based on their embeddings.
- - | 'triple' : the algorithm will find duplicates of the concatenation
+ - | `'e'` (default): the algorithm will find duplicates of the provided entities based on their embeddings.
+ - | `'r'`: the algorithm will find duplicates of the provided relations based on their embeddings.
+ - | `'t'` : the algorithm will find duplicates of the concatenation
of the embeddings of the subject, predicate and object for each provided triple.
metric: str
@@ -629,14 +755,14 @@ def find_duplicates(X, model, mode="entity", metric='l2', tolerance='auto',
`See options here `_.
tolerance: int or str
Minimum distance (depending on the chosen ``metric``) to define one entity as the duplicate of another.
- If 'auto', it will be determined automatically in a way that you get the ``expected_fraction_duplicates``.
- The 'auto' option can be much slower than the regular one, as the finding duplicate internal procedure
+ If `'auto'`, the tolerance is determined automatically so that the fraction of duplicates found matches ``expected_fraction_duplicates``.
+ The `'auto'` option can be much slower than passing a numeric value, as the internal duplicate-finding procedure
will be repeated multiple times.
expected_fraction_duplicates: float
- Expected fraction of duplicates to be found. It is used only when ``tolerance`` is 'auto'.
+ Expected fraction of duplicates to be found. It is used only when ``tolerance='auto'``.
Should be between 0 and 1 (default: 0.1).
verbose: bool
- Whether to print evaluation messages during optimisation (if ``tolerance`` is 'auto'). Default: False.
+ Whether to print evaluation messages during optimisation when ``tolerance='auto'`` (default: `False`).
Returns
-------
@@ -646,14 +772,14 @@ def find_duplicates(X, model, mode="entity", metric='l2', tolerance='auto',
Each frozenset will contain at least two entities.
tolerance: float
- Tolerance used to find the duplicates (useful in the case of the automatic tolerance option).
+ Tolerance used to find the duplicates (useful if the automatic tolerance option is selected).
- Examples
- --------
+ Example
+ -------
>>> import pandas as pd
>>> import numpy as np
>>> import re
- >>>
+ >>> from ampligraph.latent_features.models import ScoringBasedEmbeddingModel
>>> # The IMDB dataset used here is part of the Movies5 dataset found on:
>>> # The Magellan Data Repository (https://sites.google.com/site/anhaidgroup/projects/data)
>>> import requests
@@ -686,99 +812,124 @@ def find_duplicates(X, model, mode="entity", metric='l2', tolerance='auto',
>>> genres_triples = [(movie_id, "hasGenre", g) for g in genres]
>>> duration_triple = (movie_id, "hasDuration", duration)
>>>
+ >>>
>>> imdb_triples.extend(directors_triples)
>>> imdb_triples.extend(actors_triples)
>>> imdb_triples.extend(genres_triples)
>>> imdb_triples.append(duration_triple)
>>>
>>> # Training knowledge graph embedding with ComplEx model
- >>> from ampligraph.latent_features import ComplEx
- >>>
- >>> model = ComplEx(batches_count=10,
- >>> seed=0,
- >>> epochs=200,
- >>> k=150,
- >>> eta=5,
- >>> optimizer='adam',
- >>> optimizer_params={'lr':1e-3},
- >>> loss='multiclass_nll',
- >>> regularizer='LP',
- >>> regularizer_params={'p':3, 'lambda':1e-5},
- >>> verbose=True)
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
>>>
>>> imdb_triples = np.array(imdb_triples)
- >>> model.fit(imdb_triples)
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=300,
+ >>> scoring_type='ComplEx')
+ >>> model.compile(optimizer='adam', loss='multiclass_nll')
+ >>> model.fit(imdb_triples,
+ >>> batch_size=10000,
+ >>> epochs=10)
>>>
>>> # Finding duplicates movies (entities)
>>> from ampligraph.discovery import find_duplicates
>>>
>>> entities = np.unique(imdb_triples[:, 0])
- >>> dups, _ = find_duplicates(entities, model, mode='entity', tolerance=0.4)
- >>> print(list(dups)[:3])
- [frozenset({'ID4048', 'ID4049'}), frozenset({'ID5994', 'ID5993'}), frozenset({'ID6447', 'ID6448'})]
- >>> print(imdb[imdb.id.isin((4048, 4049, 5994, 5993, 6447, 6448))][['movie_name', 'year']])
- movie_name year
- 4048 Ulterior Motives 1993
- 4049 Ulterior Motives 1993
- 5993 Chinese Hercules 1973
- 5994 Chinese Hercules 1973
- 6447 The Stranglers of Bombay 1959
- 6448 The Stranglers of Bombay 1959
-
+ >>> dups, _ = find_duplicates(entities, model, mode='e', tolerance=0.45)
+ >>> id_list = []
+ >>> for data in dups:
+ >>> for i in data:
+ >>> id_list.append(int(i[2:]))
+ >>> print(imdb.iloc[id_list[:6]][['movie_name', 'year']])
+ Epoch 1/10
+ 7/7 [==============================] - 1s 122ms/step - loss: 15612.8799
+ Epoch 2/10
+ 7/7 [==============================] - 0s 20ms/step - loss: 15610.5010
+ Epoch 3/10
+ 7/7 [==============================] - 0s 19ms/step - loss: 15607.7412
+ Epoch 4/10
+ 7/7 [==============================] - 0s 19ms/step - loss: 15604.0674
+ Epoch 5/10
+ 7/7 [==============================] - 0s 20ms/step - loss: 15598.9365
+ Epoch 6/10
+ 7/7 [==============================] - 0s 19ms/step - loss: 15591.7188
+ Epoch 7/10
+ 7/7 [==============================] - 0s 19ms/step - loss: 15581.6055
+ Epoch 8/10
+ 7/7 [==============================] - 0s 20ms/step - loss: 15567.6807
+ Epoch 9/10
+ 7/7 [==============================] - 0s 20ms/step - loss: 15548.8184
+ Epoch 10/10
+ 7/7 [==============================] - 0s 21ms/step - loss: 15523.8721
+ movie_name year
+ 5198 Duel to Death 1983
+ 5199 Duel to Death 1983
+ 2649 The Eliminator 2004
+ 2650 The Eliminator 2004
+ 3967 Lipstick Camera 1994
+ 3968 Lipstick Camera 1994
"""
+ if model.is_backward:
+ model = model.model
+
if not model.is_fitted:
msg = "Model has not been fitted."
logger.error(msg)
raise ValueError(msg)
- modes = ("triple", "entity", "relation")
+ modes = ("t", "e", "r")
if mode not in modes:
- msg = "Argument `mode` must be one of the following: {}.".format(", ".join(modes))
+ msg = "Argument `mode` must be one of the following: {}.".format(
+ ", ".join(modes)
+ )
logger.error(msg)
raise ValueError(msg)
- if mode == "triple" and (len(X.shape) != 2 or X.shape[1] != 3):
- msg = "For 'triple' mode the input X must be a matrix with three columns."
+ if mode == "t" and (len(X.shape) != 2 or X.shape[1] != 3):
+ msg = "For 't' mode the input X must be a matrix with three columns."
logger.error(msg)
raise ValueError(msg)
- if mode in ("entity", "relation") and len(X.shape) != 1:
- msg = "For 'entity' or 'relation' mode the input X must be an array."
+ if mode in ("e", "r") and len(X.shape) != 1:
+ msg = "For 'e' or 'r' mode the input X must be an array."
logger.error(msg)
raise ValueError(msg)
- if mode == "triple":
- s = model.get_embeddings(X[:, 0], embedding_type='entity')
- p = model.get_embeddings(X[:, 1], embedding_type='relation')
- o = model.get_embeddings(X[:, 2], embedding_type='entity')
+ if mode == "t":
+ s = model.get_embeddings(X[:, 0], embedding_type="e")
+ p = model.get_embeddings(X[:, 1], embedding_type="r")
+ o = model.get_embeddings(X[:, 2], embedding_type="e")
emb = np.hstack((s, p, o))
else:
emb = model.get_embeddings(X, embedding_type=mode)
def get_dups(tol):
"""
- Given tolerance, finds duplicate entities in a graph based on their embeddings.
+ Given a tolerance, finds duplicate entities in a graph based on their embeddings.
- Parameters
- ----------
- tol: float
- Minimum distance (depending on the chosen metric) to define one entity as the duplicate of another.
+ Parameters
+ ----------
+ tol: float
+ Minimum distance (depending on the chosen metric) to define one entity as the duplicate of another.
- Returns
- -------
- duplicates : set of frozensets
- Each entry in the duplicates set is a frozenset containing all entities that were found to be duplicates
- according to the metric and tolerance.
- Each frozenset will contain at least two entities.
+ Returns
+ -------
+ duplicates : set of frozensets
+ Each entry in the duplicates set is a frozenset containing all entities that were found to be duplicates
+ according to the metric and tolerance.
+ Each frozenset will contain at least two entities.
"""
nn = NearestNeighbors(metric=metric, radius=tol)
nn.fit(emb)
neighbors = nn.radius_neighbors(emb)[1]
- idx_dups = ((i, row) for i, row in enumerate(neighbors) if len(row) > 1)
- if mode == "triple":
- dups = {frozenset(tuple(X[idx]) for idx in row) for i, row in idx_dups}
+ idx_dups = (
+ (i, row) for i, row in enumerate(neighbors) if len(row) > 1
+ )
+ if mode == "t":
+ dups = {
+ frozenset(tuple(X[idx]) for idx in row) for i, row in idx_dups
+ }
else:
dups = {frozenset(X[idx] for idx in row) for i, row in idx_dups}
return dups
@@ -793,24 +944,43 @@ def opt(tol, info):
duplicates = get_dups(tol)
fraction_duplicates = len(set().union(*duplicates)) / len(emb)
if verbose:
- info['Nfeval'] += 1
- logger.info("Eval {}: tol: {}, duplicate fraction: {}".format(info['Nfeval'], tol, fraction_duplicates))
+ info["Nfeval"] += 1
+ logger.info(
+ "Eval {}: tol: {}, duplicate fraction: {}".format(
+ info["Nfeval"], tol, fraction_duplicates
+ )
+ )
return fraction_duplicates - expected_fraction_duplicates
- if tolerance == 'auto':
+ if tolerance == "auto":
max_distance = spatial.distance_matrix(emb, emb).max()
- tolerance = optimize.bisect(opt, 0.0, max_distance, xtol=1e-3, maxiter=50, args=({'Nfeval': 0}, ))
+ tolerance = optimize.bisect(
+ opt,
+ 0.0,
+ max_distance,
+ xtol=1e-3,
+ maxiter=50,
+ args=({"Nfeval": 0},),
+ )
return get_dups(tolerance), tolerance
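A small sketch of how the ``tolerance='auto'`` bisection above is typically used: let it pick the radius once, then reuse the returned value for faster repeated calls. It assumes `entities` and `model` as defined in the docstring example; the 0.05 fraction is illustrative.

from ampligraph.discovery import find_duplicates

# Let the bisection choose a radius so that roughly 5% of the entities end up
# in some duplicate set; the chosen radius is returned alongside the duplicates.
dups, tol = find_duplicates(entities, model, mode='e',
                            tolerance='auto',
                            expected_fraction_duplicates=0.05)

# Later calls can skip the (slow) search by passing the radius explicitly.
dups_again, _ = find_duplicates(entities, model, mode='e', tolerance=tol)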
-def query_topn(model, top_n=10, head=None, relation=None, tail=None, ents_to_consider=None, rels_to_consider=None):
+def query_topn(
+ model,
+ top_n=10,
+ head=None,
+ relation=None,
+ tail=None,
+ ents_to_consider=None,
+ rels_to_consider=None,
+):
"""Queries the model with two elements of a triple and returns the top_n results of
all possible completions ordered by score predicted by the model.
- For example, given a pair in the arguments, the model will score
- all possible triples , filling in the missing element with known
- entities, and return the top_n triples ordered by score. If given a
+ For example, given a `<head, relation>` pair in the arguments, the model will score
+ all possible triples `<head, relation, ?tail>`, filling in the missing element with known
+ entities, and return the top_n triples ordered by score. If given a `<head, tail>`
pair it will fill in the missing element with known relations.
.. note::
@@ -823,131 +993,145 @@ def query_topn(model, top_n=10, head=None, relation=None, tail=None, ents_to_con
The trained model that will be used to score triple completions.
top_n : int
The number of completed triples to return.
- head : string
+ head : str
An entity string to query.
- relation : string
+ relation : str
A relation string to query.
- tail :
+ tail : str
An object string to query.
ents_to_consider: array-like
- List of entities to use for triple completions. If None, will generate completions using all distinct entities.
- (Default: None.)
+ List of entities to use for triple completions. If `None`, will generate completions using all distinct entities
+ (default: `None`).
rels_to_consider: array-like
- List of relations to use for triple completions. If None, will generate completions using all distinct
- relations. (Default: None.)
+ List of relations to use for triple completions. If `None`, will generate completions using all distinct
+ relations (default: `None`).
Returns
-------
- X : ndarray, shape [n, 3]
+ X : ndarray, shape (n, 3)
A list of triples ordered by score.
- S : ndarray, shape [n]
+ S : ndarray, shape (n)
A list of scores.
- Examples
- --------
+ Example
+ -------
>>> import requests
>>> from ampligraph.datasets import load_from_csv
- >>> from ampligraph.latent_features import ComplEx
>>> from ampligraph.discovery import discover_facts
>>> from ampligraph.discovery import query_topn
- >>>
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
>>> # Game of Thrones relations dataset
>>> url = 'https://ampligraph.s3-eu-west-1.amazonaws.com/datasets/GoT.csv'
>>> open('GoT.csv', 'wb').write(requests.get(url).content)
>>> X = load_from_csv('.', 'GoT.csv', sep=',')
>>>
- >>> model = ComplEx(batches_count=10, seed=0, epochs=200, k=150, eta=5,
- >>> optimizer='adam', optimizer_params={'lr':1e-3}, loss='multiclass_nll',
- >>> regularizer='LP', regularizer_params={'p':3, 'lambda':1e-5},
- >>> verbose=True)
- >>> model.fit(X)
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=150,
+ >>> scoring_type='TransE')
+ >>> model.compile(optimizer='adagrad', loss='pairwise')
+ >>> model.fit(X,
+ >>> batch_size=100,
+ >>> epochs=20,
+ >>> verbose=False)
>>>
>>> query_topn(model, top_n=5,
- >>> head='Catelyn Stark', relation='ALLIED_WITH', tail=None,
+ >>> head='Eddard Stark', relation='ALLIED_WITH', tail=None,
>>> ents_to_consider=None, rels_to_consider=None)
>>>
- (array([['Catelyn Stark', 'ALLIED_WITH', 'House Tully of Riverrun'],
- ['Catelyn Stark', 'ALLIED_WITH', 'House Stark of Winterfell'],
- ['Catelyn Stark', 'ALLIED_WITH', 'House Wayn'],
- ['Catelyn Stark', 'ALLIED_WITH', 'House Mollen'],
- ['Catelyn Stark', 'ALLIED_WITH', 'Orton Merryweather']],
- dtype='<U...')
>>> import numpy as np
>>> from ampligraph.evaluation.metrics import hits_at_n_score
>>> rankings = np.array([1, 12, 6, 2])
>>> hits_at_n_score(rankings, n=3)
0.5
+
"""
- logger.debug('Calculating Hits@n.')
+ logger.debug("Calculating Hits@n.")
if isinstance(ranks, list):
- logger.debug('Converting ranks to numpy array.')
+ logger.debug("Converting ranks to numpy array.")
ranks = np.asarray(ranks)
ranks = ranks.reshape(-1)
return np.sum(ranks <= n) / len(ranks)
def mrr_score(ranks):
- r"""Mean Reciprocal Rank (MRR)
+ r"""Mean Reciprocal Rank (MRR).
The function computes the mean of the reciprocal of elements of a vector of rankings ``ranks``.
@@ -125,16 +127,16 @@ def mrr_score(ranks):
Parameters
----------
- ranks: ndarray or list, shape [n] or [n,2]
+ ranks: ndarray or list, shape (n) or (n,2)
Input ranks of n test statements.
Returns
-------
mrr_score: float
- The MRR score
+ The MRR score.
- Examples
- --------
+ Example
+ -------
>>> import numpy as np
>>> from ampligraph.evaluation.metrics import mrr_score
>>> rankings = np.array([1, 12, 6, 2])
@@ -142,18 +144,18 @@ def mrr_score(ranks):
0.4375
"""
- logger.debug('Calculating the Mean Reciprocal Rank.')
+ logger.debug("Calculating the Mean Reciprocal Rank.")
if isinstance(ranks, list):
- logger.debug('Converting ranks to numpy array.')
+ logger.debug("Converting ranks to numpy array.")
ranks = np.asarray(ranks)
ranks = ranks.reshape(-1)
return np.sum(1 / ranks) / len(ranks)
def rank_score(y_true, y_pred, pos_lab=1):
- """Rank of a triple
+ """Computes the rank of a triple.
- The rank of a positive element against a list of negatives.
+ The rank of a positive element against a list of negatives.
.. math::
@@ -161,9 +163,9 @@ def rank_score(y_true, y_pred, pos_lab=1):
Parameters
----------
- y_true : ndarray, shape [n]
+ y_true : ndarray, shape (n)
An array of binary labels. The array only contains one positive.
- y_pred : ndarray, shape [n]
+ y_pred : ndarray, shape (n)
An array of scores, for the positive element and the n-1 negatives.
pos_lab : int
The value of the positive label (default = 1).
@@ -173,8 +175,8 @@ def rank_score(y_true, y_pred, pos_lab=1):
rank : int
The rank of the positive element against the negatives.
- Examples
- --------
+ Example
+ -------
>>> import numpy as np
>>> from ampligraph.evaluation.metrics import rank_score
>>> y_pred = np.array([.434, .65, .21, .84])
@@ -184,7 +186,7 @@ def rank_score(y_true, y_pred, pos_lab=1):
"""
- logger.debug('Calculating the Rank Score.')
+ logger.debug("Calculating the Rank Score.")
idx = np.argsort(y_pred)[::-1]
y_ord = y_true[idx]
rank = np.where(y_ord == pos_lab)[0][0] + 1
@@ -192,9 +194,9 @@ def rank_score(y_true, y_pred, pos_lab=1):
def mr_score(ranks):
- r"""Mean Rank (MR)
+ r"""Mean Rank (MR).
- The function computes the mean of of a vector of rankings ``ranks``.
+ The function computes the mean of a vector of rankings ``ranks``.
It can be used in conjunction with the learning to rank evaluation protocol of
:meth:`ampligraph.evaluation.evaluate_performance`.
@@ -232,16 +234,16 @@ def mr_score(ranks):
Parameters
----------
- ranks: ndarray or list, shape [n] or [n,2]
+ ranks: ndarray or list, shape (n) or (n,2)
Input ranks of n test statements.
Returns
-------
mr_score: float
- The MR score
+ The MR score.
- Examples
- --------
+ Example
+ -------
>>> from ampligraph.evaluation import mr_score
>>> ranks= [5, 3, 4, 10, 1]
>>> mr_score(ranks)
@@ -249,9 +251,9 @@ def mr_score(ranks):
"""
- logger.debug('Calculating the Mean Average Rank score.')
+ logger.debug("Calculating the Mean Average Rank score.")
if isinstance(ranks, list):
- logger.debug('Converting ranks to numpy array.')
+ logger.debug("Converting ranks to numpy array.")
ranks = np.asarray(ranks)
ranks = ranks.reshape(-1)
return np.sum(ranks) / len(ranks)
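On a single vector of ranks the three metrics compare as follows (a small worked example, independent of any dataset):

import numpy as np
from ampligraph.evaluation import hits_at_n_score, mr_score, mrr_score

ranks = np.array([1, 12, 6, 2])
print(mr_score(ranks))              # (1 + 12 + 6 + 2) / 4 = 5.25
print(mrr_score(ranks))             # (1 + 1/12 + 1/6 + 1/2) / 4 = 0.4375
print(hits_at_n_score(ranks, n=3))  # 2 of the 4 ranks are <= 3, i.e. 0.5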
diff --git a/ampligraph/evaluation/protocol.py b/ampligraph/evaluation/protocol.py
index 5ac95212..1a948e00 100644
--- a/ampligraph/evaluation/protocol.py
+++ b/ampligraph/evaluation/protocol.py
@@ -1,4 +1,4 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
@@ -6,19 +6,15 @@
# http://www.apache.org/licenses/LICENSE-2.0
#
-from collections.abc import Iterable
-from itertools import product, islice
import logging
-import warnings
+from collections.abc import Iterable
+from itertools import islice, product
-import pandas as pd
import numpy as np
+import pandas as pd
from tqdm import tqdm
-import tensorflow as tf
-from ..evaluation import mrr_score, hits_at_n_score, mr_score
-from ..datasets import AmpligraphDatasetAdapter, NumpyDatasetAdapter, OneToNDatasetAdapter
-# from ampligraph.latent_features.models import ConvE
+from ..evaluation import hits_at_n_score, mr_score, mrr_score
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -26,42 +22,47 @@
TOO_MANY_ENTITIES_TH = 50000
-def _train_test_split_no_unseen_fast(X, test_size=100, seed=0, allow_duplication=False, filtered_test_predicates=None):
+def train_test_split_no_unseen(
+ X,
+ test_size=100,
+ seed=0,
+ allow_duplication=False,
+ filtered_test_predicates=None,
+):
"""Split into train and test sets.
- This function carves out a test set that contains only entities
- and relations which also occur in the training set.
-
- This is an improved version which is much faster - since this doesnt sample like earlier approach but rather
- shuffles indices and gets the test set of required size by selecting from the shuffled indices only triples
- which do not disconnect entities/relations.
+ This function carves out a test set that contains only entities
+ and relations which also occur in the training set.
+
+ This is an improved, much faster version: instead of sampling triples as in the earlier approach, it shuffles
+ the indices and builds a test set of the required size by selecting, from the shuffled indices, only triples
+ that do not disconnect entities or relations.
Parameters
----------
- X : ndarray, size[n, 3]
+ X : ndarray, shape (n, 3)
The dataset to split.
test_size : int, float
- If int, the number of triples in the test set.
- If float, the percentage of total triples.
+ If `int`, the number of triples in the test set.
+ If `float`, the percentage of total triples.
seed : int
A random seed used to split the dataset.
- allow_duplication: boolean
+ allow_duplication: bool
Flag to indicate if the test set can contain duplicated triples.
filtered_test_predicates: None, list
- If None, all predicate types will be considered for the test set.
- If list, only the predicate types in the list will be considered for
+ If `None`, all predicate types will be considered for the test set.
+ If `list`, only the predicate types in the list will be considered for
the test set.
Returns
-------
- X_train : ndarray, size[n, 3]
+ X_train : ndarray, shape (n, 3)
The training set.
- X_test : ndarray, size[n, 3]
+ X_test : ndarray, shape (n, 3)
The test set.
- Examples
- --------
-
+ Example
+ -------
>>> import numpy as np
>>> from ampligraph.evaluation import train_test_split_no_unseen
>>> # load your dataset to X
@@ -88,8 +89,8 @@ def _train_test_split_no_unseen_fast(X, test_size=100, seed=0, allow_duplication
array([['f', 'y', 'e'],
['c', 'y', 'd']], dtype='<U1')
>>> # if you want to split into train/valid/test datasets, call it 2 times
- >>> X_train_valid, X_test = train_test_split_no_unseen(X, test_size=2)
- >>> X_train, X_valid = train_test_split_no_unseen(X_train_valid, test_size=2)
+ >>> X_train_valid, X_test = train_test_split_no_unseen(X, test_size=2)
+ >>> X_train, X_valid = train_test_split_no_unseen(X_train_valid, test_size=2)
>>> X_train
array([['a', 'y', 'b'],
['a', 'y', 'd'],
@@ -103,9 +104,6 @@ def _train_test_split_no_unseen_fast(X, test_size=100, seed=0, allow_duplication
array([['b', 'y', 'c'],
['b', 'y', 'a']], dtype='<U1')
- if dict_entities[test_triple[0]] > 0 and \
- dict_rels[test_triple[1]] > 0 and \
- dict_entities[test_triple[2]] > 0:
-
+ if (
+ dict_entities[test_triple[0]] > 0
+ and dict_rels[test_triple[1]] > 0
+ and dict_entities[test_triple[2]] > 0
+ ):
# Can safely add the triple to the test set
idx_test.append(idx)
if len(idx_test) == test_size:
# Since we found the requested test set of given size
# add all the remaining indices of candidates to training set
idx_train.extend(list(all_indices_shuffled[i + 1:]))
-
+
# break out of the loop
break
-
+
else:
- # since removing this triple results in unseen entities, add it to training
+ # since removing this triple results in unseen entities, add it to
+ # training
dict_entities[test_triple[0]] = dict_entities[test_triple[0]] + 1
dict_rels[test_triple[1]] = dict_rels[test_triple[1]] + 1
dict_entities[test_triple[2]] = dict_entities[test_triple[2]] + 1
idx_train.append(idx)
-
+
if len(idx_test) != test_size:
# if we cannot get the test set of required size that means we cannot get unique triples
# in the test set without creating unseen entities
if allow_duplication:
- # if duplication is allowed, randomly choose from the existing test set and create duplicates
- duplicate_idx = np.random.choice(idx_test, size=(test_size - len(idx_test))).tolist()
+ # if duplication is allowed, randomly choose from the existing test
+ # set and create duplicates
+ duplicate_idx = np.random.choice(
+ idx_test, size=(test_size - len(idx_test))
+ ).tolist()
idx_test.extend(list(duplicate_idx))
else:
- # throw an exception since we cannot get unique triples in the test set without creating
+ # throw an exception since we cannot get unique triples in the test set without creating
# unseen entities
- raise Exception("Cannot create a test split of the desired size. "
- "Some entities will not occur in both training and test set. "
- "Set allow_duplication=True,"
- "remove filter on test predicates or "
- "set test_size to a smaller value.")
-
+ raise Exception(
+ "Cannot create a test split of the desired size. "
+ "Some entities will not occur in both training and test set. "
+ "Set allow_duplication=True,"
+ "remove filter on test predicates or "
+ "set test_size to a smaller value."
+ )
+
if X_train is None:
X_train = X_test_candidates[idx_train]
else:
X_train_subset = X_test_candidates[idx_train]
X_train = np.concatenate([X_train, X_train_subset])
X_test = X_test_candidates[idx_test]
-
+
X_train = np.random.permutation(X_train)
X_test = np.random.permutation(X_test)
- return X_train, X_test
-
-
-def _train_test_split_no_unseen_old(X, test_size=100, seed=0, allow_duplication=False, filtered_test_predicates=None):
- """Split into train and test sets.
-
- This function carves out a test set that contains only entities
- and relations which also occur in the training set.
-
- This is very slow as it runs an infinite loop and samples a triples and appends to test set and checks if it is
- unique or not. This is very time consuming process and highly inefficient.
-
- Parameters
- ----------
- X : ndarray, size[n, 3]
- The dataset to split.
- test_size : int, float
- If int, the number of triples in the test set.
- If float, the percentage of total triples.
- seed : int
- A random seed used to split the dataset.
- allow_duplication: boolean
- Flag to indicate if the test set can contain duplicated triples.
- filtered_test_predicates: None, list
- If None, all predicate types will be considered for the test set.
- If list, only the predicate types in the list will be considered for
- the test set.
-
- Returns
- -------
- X_train : ndarray, size[n, 3]
- The training set.
- X_test : ndarray, size[n, 3]
- The test set.
-
- Examples
- --------
-
- >>> import numpy as np
- >>> from ampligraph.evaluation import train_test_split_no_unseen
- >>> # load your dataset to X
- >>> X = np.array([['a', 'y', 'b'],
- >>> ['f', 'y', 'e'],
- >>> ['b', 'y', 'a'],
- >>> ['a', 'y', 'c'],
- >>> ['c', 'y', 'a'],
- >>> ['a', 'y', 'd'],
- >>> ['c', 'y', 'd'],
- >>> ['b', 'y', 'c'],
- >>> ['f', 'y', 'e']])
- >>> # if you want to split into train/test datasets
- >>> X_train, X_test = train_test_split_no_unseen(X, test_size=2, backward_compatible=True)
- >>> X_train
- array([['a', 'y', 'b'],
- ['f', 'y', 'e'],
- ['b', 'y', 'a'],
- ['c', 'y', 'a'],
- ['c', 'y', 'd'],
- ['b', 'y', 'c'],
- ['f', 'y', 'e']], dtype='<U1')
- >>> X_test
- array([['a', 'y', 'c'],
- ['a', 'y', 'd']], dtype='>> # if you want to split into train/valid/test datasets, call it 2 times
- >>> X_train_valid, X_test = train_test_split_no_unseen(X, test_size=2, backward_compatible=True)
- >>> X_train, X_valid = train_test_split_no_unseen(X_train_valid, test_size=2, backward_compatible=True)
- >>> X_train
- array([['a', 'y', 'b'],
- ['b', 'y', 'a'],
- ['c', 'y', 'd'],
- ['b', 'y', 'c'],
- ['f', 'y', 'e']], dtype='<U1')
- >>> X_valid
- array([['f', 'y', 'e'],
- ['c', 'y', 'a']], dtype='<U1')
- >>> X_test
- array([['a', 'y', 'c'],
- ['a', 'y', 'd']], dtype='<U1')
- if dict_subs[X[i, 0]] > 1 and dict_objs[X[i, 2]] > 1 and dict_rels[X[i, 1]] > 1:
- dict_subs[X[i, 0]] -= 1
- dict_objs[X[i, 2]] -= 1
- dict_rels[X[i, 1]] -= 1
- if allow_duplication:
- idx_test = np.append(idx_test, i)
- else:
- idx_test = np.unique(np.append(idx_test, i))
-
- loop_count += 1
-
- # in case can't find solution
- if loop_count == tolerance:
- if allow_duplication:
- raise Exception("Cannot create a test split of the desired size. "
- "Some entities will not occur in both training and test set. "
- "Change seed values, remove filter on test predicates or set "
- "test_size to a smaller value.")
- else:
- raise Exception("Cannot create a test split of the desired size. "
- "Some entities will not occur in both training and test set. "
- "Set allow_duplication=True,"
- "change seed values, remove filter on test predicates or "
- "set test_size to a smaller value.")
-
- logger.debug('Completed random search.')
-
- idx = np.arange(len(X))
- idx_train = np.setdiff1d(idx, idx_test)
- logger.debug('Train test split completed.')
-
- return X[idx_train, :], X[idx_test, :]
-
-
-def train_test_split_no_unseen(X, test_size=100, seed=0, allow_duplication=False,
- filtered_test_predicates=None, backward_compatible=False):
- """Split into train and test sets.
-
- This function carves out a test set that contains only entities
- and relations which also occur in the training set.
-
- Parameters
- ----------
- X : ndarray, size[n, 3]
- The dataset to split.
- test_size : int, float
- If int, the number of triples in the test set.
- If float, the percentage of total triples.
- seed : int
- A random seed used to split the dataset.
- allow_duplication: boolean
- Flag to indicate if the test set can contain duplicated triples.
- filtered_test_predicates: None, list
- If None, all predicate types will be considered for the test set.
- If list, only the predicate types in the list will be considered for
- the test set.
- backward_compatible: boolean
- Uses the old (slower) version of the API for reproducibility of splits in older pipelines(if any)
- Avoid setting this to True, unless necessary. Set this flag only if you want to use the
- train_test_split_no_unseen of Ampligraph versions 1.3.2 and below. The older version is slow and inefficient
-
- Returns
- -------
- X_train : ndarray, size[n, 3]
- The training set.
- X_test : ndarray, size[n, 3]
- The test set.
-
- Examples
- --------
-
- >>> import numpy as np
- >>> from ampligraph.evaluation import train_test_split_no_unseen
- >>> # load your dataset to X
- >>> X = np.array([['a', 'y', 'b'],
- >>> ['f', 'y', 'e'],
- >>> ['b', 'y', 'a'],
- >>> ['a', 'y', 'c'],
- >>> ['c', 'y', 'a'],
- >>> ['a', 'y', 'd'],
- >>> ['c', 'y', 'd'],
- >>> ['b', 'y', 'c'],
- >>> ['f', 'y', 'e']])
- >>> # if you want to split into train/test datasets
- >>> X_train, X_test = train_test_split_no_unseen(X, test_size=2)
- >>> X_train
- array([['a', 'y', 'd'],
- ['b', 'y', 'a'],
- ['a', 'y', 'c'],
- ['f', 'y', 'e'],
- ['a', 'y', 'b'],
- ['c', 'y', 'a'],
- ['b', 'y', 'c']], dtype='<U1')
- >>> X_test
- array([['f', 'y', 'e'],
- ['c', 'y', 'd']], dtype='<U1')
- >>> # if you want to split into train/valid/test datasets, call it 2 times
- >>> X_train_valid, X_test = train_test_split_no_unseen(X, test_size=2)
- >>> X_train, X_valid = train_test_split_no_unseen(X_train_valid, test_size=2)
- >>> X_train
- array([['a', 'y', 'b'],
- ['a', 'y', 'd'],
- ['a', 'y', 'c'],
- ['c', 'y', 'a'],
- ['f', 'y', 'e']], dtype='<U1')
- >>> X_valid
- array([['c', 'y', 'd'],
- ['f', 'y', 'e']], dtype='<U1')
- >>> X_test
- array([['b', 'y', 'c'],
- ['b', 'y', 'a']], dtype='<U1')
- For eg, consider the case <Actor, acted_in, ?>, where we are mainly interested in such movies that an actor
- has acted in. A sensible way to evaluate this would be to rank against all the movie entities and compute
- the desired metrics. In such cases, where focus us on particular task, it is recommended to pass the desired
- entities to use to generate corruptions to ``entities_subset``. Besides, trying to rank a positive against an
- extremely large number of negatives may be overkilling.
-
- As a reference, the popular FB15k-237 dataset has ~15k distinct entities. The evaluation protocol ranks each
- positives against 15k corruptions per side.
-
- Parameters
- ----------
- X : ndarray, shape [n, 3]
- An array of test triples.
- model : EmbeddingModel
- A knowledge graph embedding model
- filter_triples : ndarray of shape [n, 3] or None
- The triples used to filter negatives.
-
- .. note::
- When *filtered* mode is enabled (i.e. `filtered_triples` is not ``None``),
- to speed up the procedure, we use a database based filtering. This strategy is as described below:
-
- * Store the filter_triples in the DB
- * For each test triple, we generate corruptions for evaluation and score them.
- * The corruptions may contain some False Negatives. We find such statements by quering the database.
- * From the computed scores we retrieve the scores of the False Negatives.
- * We compute the rank of the test triple by comparing against ALL the corruptions.
- * We then compute the number of False negatives that are ranked higher than the test triple; and then
- subtract this value from the above computed rank to yield the final filtered rank.
-
- **Execution Time:** This method takes ~4 minutes on FB15K using ComplEx
- (Intel Xeon Gold 6142, 64 GB Ubuntu 16.04 box, Tesla V100 16GB)
-
- verbose : bool
- Verbose mode
- filter_unseen : bool
- This can be set to False to skip filtering of unseen entities if train_test_split_unseen() was used to
- split the original dataset.
-
- entities_subset: array-like
- List of entities to use for corruptions. If None, will generate corruptions
- using all distinct entities. Default is None.
- corrupt_side: string
- Specifies which side of the triple to corrupt:
-
- - 's': corrupt only subject.
- - 'o': corrupt only object.
- - 's+o': corrupt both subject and object.
- - 's,o': corrupt subject and object sides independently and return 2 ranks. This corresponds to the \
- evaluation protocol used in literature, where head and tail corruptions are evaluated separately.
-
- .. note::
- When ``corrupt_side='s,o'`` the function will return 2*n ranks as a [n, 2] array.
- The first column of the array represents the subject corruptions.
- The second column of the array represents the object corruptions.
- Otherwise, the function returns n ranks as [n] array.
-
- ranking_strategy: string
- Specifies the type of score comparison strategy to use while ranking:
-
- - 'worst': assigns the worst rank when scores are equal
- - 'best': assigns the best rank when scores are equal
- - 'middle': assigns the middle rank when scores are equal
-
- Our recommendation is to use ``worst``.
- Think of a model which assigns constant score to any triples. If you use the ``best`` strategy then
- the ranks will always be 1 (which is incorrect because the model has not learnt anything). If you choose
- this model and try to do knowledge discovery, you will not be able to deduce anything as all triples will
- get the same scores. So to be on safer side while choosing the model, we would recommend either ``worst``
- or ``middle`` strategy.
-
- use_default_protocol: bool
- Flag to indicate whether to use the standard protocol used in literature defined in
- :cite:`bordes2013translating` (default: False).
- If set to `True`, ``corrupt_side`` will be set to `'s,o'`.
- This corresponds to the evaluation protocol used in literature, where head and tail corruptions
- are evaluated separately, i.e. in corrupt_side='s,o' mode
-
- Returns
- -------
- ranks : ndarray, shape [n] or [n,2] depending on the value of corrupt_side.
- An array of ranks of test triples.
- When ``corrupt_side='s,o'`` the function returns [n,2]. The first column represents the rank against
- subject corruptions and the second column represents the rank against object corruptions.
- In other cases, it returns [n] i.e. rank against the specified corruptions.
-
- Examples
- --------
- >>> import numpy as np
- >>> from ampligraph.datasets import load_wn18
- >>> from ampligraph.latent_features import ComplEx
- >>> from ampligraph.evaluation import evaluate_performance, mrr_score, hits_at_n_score
- >>>
- >>> X = load_wn18()
- >>> model = ComplEx(batches_count=10, seed=0, epochs=10, k=150, eta=1,
- >>> loss='nll', optimizer='adam')
- >>> model.fit(np.concatenate((X['train'], X['valid'])))
- >>>
- >>> filter_triples = np.concatenate((X['train'], X['valid'], X['test']))
- >>> ranks = evaluate_performance(X['test'][:5], model=model,
- >>> filter_triples=filter_triples,
- >>> corrupt_side='s+o',
- >>> use_default_protocol=False)
- >>> ranks
- array([ 1, 582, 543, 6, 31])
- >>> mrr_score(ranks)
- 0.24049691297347323
- >>> hits_at_n_score(ranks, n=10)
- 0.4
- """
-
- from ampligraph.latent_features import ConvE # avoids circular import hell
-
- dataset_handle = None
-
- # try-except block is mainly to handle clean up in case of exception or manual stop in jupyter notebook
- try:
- if use_default_protocol:
- logger.warning('DeprecationWarning: use_default_protocol will be removed in future. '
- 'Please use corrupt_side argument instead.')
- corrupt_side = 's,o'
-
- logger.debug('Evaluating the performance of the embedding model.')
- assert corrupt_side in ['s', 'o', 's+o', 's,o'], 'Invalid value for corrupt_side.'
- if isinstance(X, np.ndarray):
-
- if filter_unseen:
- X = filter_unseen_entities(X, model, verbose=verbose)
- else:
- logger.warning("If your test set or filter triples contain unseen entities you may get a"
- "runtime error. You can filter them by setting filter_unseen=True")
-
- if isinstance(model, ConvE):
- dataset_handle = OneToNDatasetAdapter()
- else:
- dataset_handle = NumpyDatasetAdapter()
-
- dataset_handle.use_mappings(model.rel_to_idx, model.ent_to_idx)
- dataset_handle.set_data(X, 'test')
-
- elif isinstance(X, AmpligraphDatasetAdapter):
- dataset_handle = X
- else:
- msg = "X must be either a numpy array or an AmpligraphDatasetAdapter."
- logger.error(msg)
- raise ValueError(msg)
-
- if filter_triples is not None:
- if isinstance(filter_triples, np.ndarray):
- logger.debug('Getting filtered triples.')
-
- if filter_unseen:
- filter_triples = filter_unseen_entities(filter_triples, model, verbose=verbose)
- dataset_handle.set_filter(filter_triples)
- model.set_filter_for_eval()
- elif isinstance(X, AmpligraphDatasetAdapter):
- if not isinstance(filter_triples, bool):
- raise Exception('Expected a boolean type')
- if filter_triples is True:
- model.set_filter_for_eval()
- else:
- raise Exception('Invalid datatype for filter. Expected a numpy array or preset data in the adapter.')
-
- eval_dict = {}
-
- # #186: print warning when trying to evaluate with too many entities.
- # Thus will likely result in shooting in your feet, as the protocol will be excessively hard.
- check_filter_size(model, entities_subset)
-
- if entities_subset is not None:
- idx_entities = np.asarray([idx for uri, idx in model.ent_to_idx.items() if uri in entities_subset])
- eval_dict['corruption_entities'] = idx_entities
-
- logger.debug('Evaluating the test set by corrupting side : {}'.format(corrupt_side))
- eval_dict['corrupt_side'] = corrupt_side
-
- assert ranking_strategy in ['worst', 'best', 'middle'], 'Invalid ranking_strategy!'
-
- eval_dict['ranking_strategy'] = ranking_strategy
-
- logger.debug('Configuring evaluation protocol.')
- model.configure_evaluation_protocol(eval_dict)
-
- logger.debug('Making predictions.')
- ranks = model.get_ranks(dataset_handle)
-
- logger.debug('Ending Evaluation')
- model.end_evaluation()
-
- logger.debug('Returning ranks of positive test triples obtained by corrupting {}.'.format(corrupt_side))
- return np.array(ranks)
-
- except BaseException as e:
- model.end_evaluation()
- if dataset_handle is not None:
- dataset_handle.cleanup()
- raise e
-
-
-def check_filter_size(model, corruption_entities):
- """ Raise a warning when trying to evaluate with too many entities.
-
- Doing so will likely result in shooting in your feet, as the protocol will be excessively hard,
- hence the warning message.
-
- Addresses #186.
-
- Parameters
- ----------
- model : the model
- corruption_entities : the corruption_entities used in the protocol
-
- Returns
- -------
- None.
-
- """
-
- warn_msg = """You are attempting to use %d distinct entities to generate synthetic negatives in the evaluation
- protocol. This may be unnecessary and will lead to a 'harder' task. Besides, it will lead to a much slower
- evaluation procedure. We recommended to set the 'corruption_entities' argument to a reasonably sized set
- of entities. The size of corruption_entities depends on your domain-specific task."""
-
- if corruption_entities is None:
- ent_for_corruption_size = len(model.ent_to_idx)
- else:
- ent_for_corruption_size = len(corruption_entities)
-
- if ent_for_corruption_size >= TOO_MANY_ENTITIES_TH:
- warnings.warn(warn_msg % ent_for_corruption_size)
- logger.warning(warn_msg, ent_for_corruption_size)
+ return X_train, X_test
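A short sketch of two options not exercised by the docstring example: a fractional ``test_size`` and restricting the test split to a given predicate. The toy graph below is illustrative, not taken from the library's examples.

import numpy as np
from ampligraph.evaluation import train_test_split_no_unseen

X = np.array([['a', 'y', 'b'], ['b', 'y', 'c'], ['a', 'z', 'c'],
              ['c', 'y', 'a'], ['b', 'z', 'a'], ['a', 'y', 'd']])

# test_size may be a fraction instead of an absolute count, and
# filtered_test_predicates limits which predicates may enter the test set.
X_train, X_test = train_test_split_no_unseen(X,
                                             test_size=0.2,
                                             filtered_test_predicates=['y'])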
def filter_unseen_entities(X, model, verbose=False):
@@ -958,7 +201,7 @@ def filter_unseen_entities(X, model, verbose=False):
Parameters
----------
- X : ndarray, shape [n, 3]
+ X : ndarray, shape (n, 3)
An array of test triples.
model : ampligraph.latent_features.EmbeddingModel
A knowledge graph embedding model.
@@ -967,16 +210,20 @@ def filter_unseen_entities(X, model, verbose=False):
Returns
-------
- filtered X : ndarray, shape [n, 3]
+ filtered X : ndarray, shape (n, 3)
An array of test triples containing no unseen entities.
"""
- logger.debug('Finding entities in the dataset that are not previously seen by model')
+ logger.debug(
+ "Finding entities in the dataset that are not previously seen by model"
+ )
ent_seen = np.unique(list(model.ent_to_idx.keys()))
- df = pd.DataFrame(X, columns=['s', 'p', 'o'])
+ df = pd.DataFrame(X, columns=["s", "p", "o"])
filtered_df = df[df.s.isin(ent_seen) & df.o.isin(ent_seen)]
n_removed_ents = df.shape[0] - filtered_df.shape[0]
if n_removed_ents > 0:
- msg = 'Removing {} triples containing unseen entities. '.format(n_removed_ents)
+ msg = "Removing {} triples containing unseen entities. ".format(
+ n_removed_ents
+ )
if verbose:
logger.info(msg)
logger.debug(msg)
@@ -984,83 +231,52 @@ def filter_unseen_entities(X, model, verbose=False):
return X
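For illustration, this helper only inspects the model's `ent_to_idx` mapping, so a stand-in object is enough to see its behaviour. This is a hypothetical sketch (the `FakeModel` class and module path are assumptions based on where the function is defined in this diff):

import numpy as np
from ampligraph.evaluation.protocol import filter_unseen_entities

class FakeModel:
    # Only the entity vocabulary is read by filter_unseen_entities.
    ent_to_idx = {'a': 0, 'b': 1, 'c': 2}

X_test = np.array([['a', 'y', 'b'],
                   ['a', 'y', 'unseen'],   # dropped: 'unseen' is not in ent_to_idx
                   ['c', 'y', 'a']])
filtered = filter_unseen_entities(X_test, FakeModel(), verbose=True)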
-def _remove_unused_params(params):
- """
- Removed unused parameters considering the registries.
-
- For example, if the regularization is None, there is no need for the regularization parameter lambda.
-
- Parameters
- ----------
- params: dict
- Dictionary with parameters.
-
- Returns
- -------
- params: dict
- Param dict without unused parameters.
- """
- from ..latent_features import LOSS_REGISTRY, REGULARIZER_REGISTRY, MODEL_REGISTRY, \
- OPTIMIZER_REGISTRY, INITIALIZER_REGISTRY
-
- def _param_without_unused(param, registry, category_type, category_type_params):
- """Remove one particular nested param (if unused) given a registry"""
- if category_type_params in param and category_type in registry:
- expected_params = registry[category_type].external_params
- params[category_type_params] = {k: v for k, v in param[category_type_params].items() if
- k in expected_params}
- else:
- params[category_type_params] = {}
-
- params = params.copy()
-
- if "loss" in params and "loss_params" in params:
- _param_without_unused(params, LOSS_REGISTRY, params["loss"], "loss_params")
- if "regularizer" in params and "regularizer_params" in params:
- _param_without_unused(params, REGULARIZER_REGISTRY, params["regularizer"], "regularizer_params")
- if "optimizer" in params and "optimizer_params" in params:
- _param_without_unused(params, OPTIMIZER_REGISTRY, params["optimizer"], "optimizer_params")
- if "initializer" in params and "initializer_params" in params:
- _param_without_unused(params, INITIALIZER_REGISTRY, params["initializer"], "initializer_params")
- if "embedding_model_params" in params and "model_name" in params:
- _param_without_unused(params, MODEL_REGISTRY, params["model_name"], "embedding_model_params")
-
- return params
-
-
def _flatten_nested_keys(dictionary):
"""
- Flatten the nested values of a dictionary into tuple keys
- E.g. {"a": {"b": [1], "c": [2]}} becomes {("a", "b"): [1], ("a", "c"): [2]}
+ Flatten the nested values of a dictionary into tuple keys.
+
+ E.g., {"a": {"b": [1], "c": [2]}} becomes {("a", "b"): [1], ("a", "c"): [2]}
"""
# Find the parameters that are nested dictionaries
- nested_keys = {k for k, v in dictionary.items() if type(v) is dict}
+ nested_keys = {k for k, v in dictionary.items() if isinstance(v, dict)}
# Flatten them into tuples
- flattened_nested_keys = {(nk, k): dictionary[nk][k] for nk in nested_keys for k in dictionary[nk]}
+ flattened_nested_keys = {
+ (nk, k): dictionary[nk][k]
+ for nk in nested_keys
+ for k in dictionary[nk]
+ }
# Get original dictionary without the nested keys
- dictionary_without_nested_keys = {k: v for k, v in dictionary.items() if k not in nested_keys}
+ dictionary_without_nested_keys = {
+ k: v for k, v in dictionary.items() if k not in nested_keys
+ }
# Return merged dicts
return {**dictionary_without_nested_keys, **flattened_nested_keys}
def _unflatten_nested_keys(dictionary):
"""
- Unflatten the nested values of a dictionary based on the keys that are tuples
- E.g. {("a", "b"): [1], ("a", "c"): [2]} becomes {"a": {"b": [1], "c": [2]}}
+ Unflatten the nested values of a dictionary based on the keys that are tuples.
+
+ E.g., {("a", "b"): [1], ("a", "c"): [2]} becomes {"a": {"b": [1], "c": [2]}}
"""
# Find the parameters that are nested dictionaries
- nested_keys = {k[0] for k in dictionary if type(k) is tuple}
+ nested_keys = {k[0] for k in dictionary if isinstance(k, tuple)}
# Select the parameters which were originally nested and unflatten them
- nested_dict = {nk: {k[1]: v for k, v in dictionary.items() if k[0] == nk} for nk in nested_keys}
+ nested_dict = {
+ nk: {k[1]: v for k, v in dictionary.items() if k[0] == nk}
+ for nk in nested_keys
+ }
# Get original dictionary without the nested keys
- dictionary_without_nested_keys = {k: v for k, v in dictionary.items() if type(k) is not tuple}
+ dictionary_without_nested_keys = {
+ k: v for k, v in dictionary.items() if not isinstance(k, tuple)
+ }
# Return merged dicts
return {**dictionary_without_nested_keys, **nested_dict}
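The two helpers are inverses of each other; a tiny round-trip illustration using plain dict literals (the parameter names are only examples):

# _flatten_nested_keys turns nested dicts into tuple keys; _unflatten_nested_keys
# reverses the transformation, so the round trip returns the original dict.
nested = {"k": [100], "optimizer_params": {"lr": [1e-3, 1e-4]}}
flat = _flatten_nested_keys(nested)
# flat == {"k": [100], ("optimizer_params", "lr"): [1e-3, 1e-4]}
assert _unflatten_nested_keys(flat) == nested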
def _get_param_hash(param):
- """
- Get the hash of a param dictionary.
+ """Get the hash of a param dictionary.
+
It first unflattens nested dicts, removes unused nested parameters, nests them again and then creates a frozenset
based on the resulting items (tuples).
Note that the flattening and unflattening dict functions are idempotent.
@@ -1074,20 +290,24 @@ def _get_param_hash(param):
Returns
-------
- str
+ hash : int
Hash of the param dictionary.
"""
# Remove parameters that are not used by particular configurations
- # For example, if the regularization is None, there is no need for the regularization lambda
- flattened_params = _flatten_nested_keys(_remove_unused_params(_unflatten_nested_keys(param)))
+ # For example, if the regularization is None, there is no need for the
+ # regularization lambda
+ flattened_params = _flatten_nested_keys(_unflatten_nested_keys(param))
+
return hash(frozenset(flattened_params.items()))
class ParamHistory(object):
"""
- Used to evaluates whether a particular parameter configuration has already been previously seen or not.
+ Used to evaluate whether a particular parameter configuration has already been seen.
+
To achieve that, we hash each parameter configuration, removing unused parameters first.
"""
+
def __init__(self):
"""The param history is a set of hashes."""
self.param_hash_history = set()
@@ -1104,6 +324,7 @@ def __contains__(self, other):
def _next_hyperparam(param_grid):
"""
Iterator that gets the next parameter combination from a dictionary containing lists of parameters.
+
The parameter combinations are deterministic and go over all possible combinations present in the parameter grid.
Parameters
@@ -1121,7 +342,8 @@ def _next_hyperparam(param_grid):
"""
param_history = ParamHistory()
- # Flatten nested dictionaries so we can apply itertools.product to get all possible parameter combinations
+ # Flatten nested dictionaries so we can apply itertools.product to get all
+ # possible parameter combinations
flattened_param_grid = _flatten_nested_keys(param_grid)
for values in product(*flattened_param_grid.values()):
@@ -1133,8 +355,9 @@ def _next_hyperparam(param_grid):
continue
else:
param_history.add(param)
- # Yields nested configuration (unflattened) without useless parameters
- yield _remove_unused_params(_unflatten_nested_keys(param))
+ # Yields nested configuration (unflattened) without useless
+ # parameters
+ yield _unflatten_nested_keys(param)
def _sample_parameters(param_grid):
@@ -1144,7 +367,6 @@ def _sample_parameters(param_grid):
Parameters
----------
-
param_grid: dict
Parameter configurations.
Example::
@@ -1152,8 +374,7 @@ def _sample_parameters(param_grid):
Returns
-------
-
- param: dict
+ params: dict
Return dictionary containing sampled parameters.
"""
@@ -1161,9 +382,9 @@ def _sample_parameters(param_grid):
for k, v in param_grid.items():
if callable(v):
param[k] = v()
- elif type(v) is dict:
+ elif isinstance(v, dict):
param[k] = _sample_parameters(v)
- elif isinstance(v, Iterable) and type(v) is not str:
+ elif isinstance(v, Iterable) and not isinstance(v, str):
param[k] = np.random.choice(v)
else:
param[k] = v
@@ -1173,6 +394,7 @@ def _sample_parameters(param_grid):
def _next_hyperparam_random(param_grid):
"""
Iterator that gets the next parameter combination from a dictionary containing lists of parameters or callables.
+
The parameter combinations are randomly chosen each iteration.
Parameters
@@ -1198,7 +420,7 @@ def _next_hyperparam_random(param_grid):
continue
else:
param_history.add(param)
- yield _remove_unused_params(param)
+ yield param
def _scalars_into_lists(param_grid):
@@ -1213,16 +435,31 @@ def _scalars_into_lists(param_grid):
param_grid = {"k": [50, 100], "eta": lambda: np.random.choice([1, 2, 3]}
"""
for k, v in param_grid.items():
- if not (callable(v) or isinstance(v, Iterable)) or type(v) is str:
+ if not (callable(v) or isinstance(v, Iterable)) or isinstance(v, str):
param_grid[k] = [v]
- elif type(v) is dict:
+ elif isinstance(v, dict):
_scalars_into_lists(v)
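A brief sketch of how a parameter grid flows through these helpers: scalars get wrapped into lists, lists are expanded exhaustively by `_next_hyperparam` (grid search), while callables are intended for `_next_hyperparam_random` (random search). The names mirror the docstring example above; the values are illustrative.

import numpy as np

param_grid = {
    "k": [50, 100],                 # list: every value is tried in grid search
    "eta": 5,                       # scalar: wrapped into [5] by _scalars_into_lists
    "optimizer_params": {
        "lr": lambda: np.random.uniform(1e-4, 1e-2),  # callable: random search only
    },
}
_scalars_into_lists(param_grid)

# Grid search over a callable-free grid yields every unique combination once.
grid_combinations = list(_next_hyperparam({"k": [50, 100], "eta": [1, 5]}))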
-def select_best_model_ranking(model_class, X_train, X_valid, X_test, param_grid, max_combinations=None,
- param_grid_random_seed=0, use_filter=True, early_stopping=False,
- early_stopping_params=None, use_test_for_selection=False, entities_subset=None,
- corrupt_side='s,o', use_default_protocol=False, retrain_best_model=False, verbose=False):
+def select_best_model_ranking(
+ model_class,
+ X_train,
+ X_valid,
+ X_test,
+ param_grid,
+ max_combinations=None,
+ param_grid_random_seed=0,
+ use_filter=True,
+ early_stopping=False,
+ early_stopping_params=None,
+ use_test_for_selection=False,
+ entities_subset=None,
+ corrupt_side="s,o",
+ focusE=False,
+ focusE_params={},
+ retrain_best_model=False,
+ verbose=False,
+):
"""Model selection routine for embedding models via either grid search or random search.
For grid search, pass a fixed ``param_grid`` and leave ``max_combinations`` as `None`
@@ -1245,13 +482,13 @@ def select_best_model_ranking(model_class, X_train, X_valid, X_test, param_grid,
Parameters
----------
- model_class : class
- The class of the EmbeddingModel to evaluate (TransE, DistMult, ComplEx, etc).
- X_train : ndarray, shape [n, 3]
+ model_class : str
+ The class of the EmbeddingModel to evaluate (`'TransE'`, `'DistMult'`, `'ComplEx'`, etc).
+ X_train : ndarray, shape (n, 3)
An array of training triples.
- X_valid : ndarray, shape [n, 3]
+ X_valid : ndarray, shape (n, 3)
An array of validation triples.
- X_test : ndarray, shape [n, 3]
+ X_test : ndarray, shape (n, 3)
An array of test triples.
param_grid : dict
A grid of hyperparameters to use in model selection. The routine will train a model for each combination
@@ -1265,28 +502,28 @@ def select_best_model_ranking(model_class, X_train, X_valid, X_test, param_grid,
or ``"lr": lambda: np.random.uniform(0.01, 0.1)``.
max_combinations: int
Maximum number of combinations to explore.
- By default (None) all combinations will be explored,
+ By default (`None`) all combinations will be explored,
which makes it incompatible with random parameters for random search.
param_grid_random_seed: int
Random seed for the parameters that are callables and random.
use_filter : bool
- If True, will use the entire input dataset X to compute filtered MRR (default: True).
+ If `True`, it will use the entire input dataset `X` to compute filtered MRR (default: `True`).
early_stopping: bool
- Flag to enable early stopping (default:False).
+ Flag to enable early stopping (default: `False`).
- If set to ``True``, the training loop adopts the following early stopping heuristic:
+ If set to `True`, the training loop adopts the following early stopping heuristic:
- The model will be trained regardless of early stopping for ``burn_in`` epochs.
- Every ``check_interval`` epochs the method will compute the metric specified in ``criteria``.
If such metric decreases for ``stop_interval`` checks, we stop training early.
- Note the metric is computed on ``x_valid``. This is usually a validation set that you held out.
+ Note the metric is computed on ``X_valid``. This is usually a validation set that you held out.
- Also, because ``criteria`` is a ranking metric, it requires generating negatives.
- Entities used to generate corruptions can be specified, as long as the side(s) of a triple to corrupt.
+ Also, since ``criteria`` is a ranking metric, it requires generating negatives.
+ Entities used to generate corruptions can be specified, as well as the side(s) of the triple to corrupt.
The method supports filtered metrics, by passing an array of positives to ``x_filter``. This will be used to
- filter the negatives generated on the fly (i.e. the corruptions).
+ filter the negatives generated on the fly (i.e., the corruptions).
.. note::
@@ -1298,44 +535,39 @@ def select_best_model_ranking(model_class, X_train, X_valid, X_test, param_grid,
early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}
- Note the size of validation set also contributes to such overhead.
+ Note that the size of the validation set also contributes to such overhead.
In most cases a smaller validation set would be enough.
early_stopping_params: dict
Dictionary of parameters for early stopping.
The following keys are supported:
-
- * x_valid: ndarray, shape [n, 3] : Validation set to be used for early stopping. Uses X['valid'] by default.
-
- * criteria: criteria for early stopping ``hits10``, ``hits3``, ``hits1`` or ``mrr``. (default)
-
- * x_filter: ndarray, shape [n, 3] : Filter to be used(no filter by default)
-
- * burn_in: Number of epochs to pass before kicking in early stopping(default: 100)
-
- * check_interval: Early stopping interval after burn-in(default:10)
-
- * stop_interval: Stop if criteria is performing worse over n consecutive checks (default: 3)
-
+ * x_valid: ndarray, shape (n, 3) - Validation set to be used for early stopping (default: `X['valid']`).
+ * criteria: Criteria for early stopping ``hits10``, ``hits3``, ``hits1`` or ``mrr`` (default: `"mrr"`).
+ * x_filter: ndarray, shape (n, 3) - Filter to be used (default: `None`).
+ * burn_in: Number of epochs to pass before kicking in early stopping (default: 100).
+ * check_interval: Early stopping interval after burn-in (default: 10).
+ * stop_interval: Stop if criteria is performing worse over `n` consecutive checks (default: 3).
+
+ focusE: bool
+ Whether to use the focusE layer (default: `False`). If `True`, make sure you pass the weights as an additional
+ column concatenated after the training triples.
+ focusE_params: dict
+ Dictionary of parameters if focusE is activated.
use_test_for_selection:bool
- Use test set for model selection. If False, uses validation set (default: False).
+ Use test set for model selection. If `False`, uses validation set (default: `False`).
entities_subset: array-like
- List of entities to use for corruptions. If None, will generate corruptions
- using all distinct entities (default: None).
- corrupt_side: string
+ List of entities to use for corruptions. If `None`, will generate corruptions
+ using all distinct entities (default: `None`).
+ corrupt_side: str
Specifies which side to corrupt the entities:
- ``s`` is to corrupt only subject.
- ``o`` is to corrupt only object.
- ``s+o`` is to corrupt both subject and object.
- ``s,o`` is to corrupt both subject and object but ranks are computed separately (default).
- use_default_protocol: bool
- Flag to indicate whether to evaluate head and tail corruptions separately(default:False).
- If this is set to true, it will ignore corrupt_side argument and corrupt both head
- and tail separately and rank triples i.e. corrupt_side='s,o' mode.
+ `"s"` to corrupt only subject.
+ `"o"` to corrupt only object.
+ `"s+o"` to corrupt both subject and object.
+ `"s,o"` to corrupt both subject and object but ranks are computed separately (default).
retrain_best_model: bool
- Flag to indicate whether best model should be re-trained at the end with the validation set used in the search.
- Default: False.
+ Flag to indicate whether the best model should be re-trained at the end with the validation set used in the search
+ (default: `False`).
verbose : bool
Verbose mode for the model selection procedure (which is independent of the verbose mode in the model fit).
@@ -1355,11 +587,11 @@ def select_best_model_ranking(model_class, X_train, X_valid, X_test, param_grid,
best_mrr_train : float
The MRR (unfiltered) of the best model computed over the validation set in the model selection loop.
- ranks_test : ndarray, shape [n] or [n,2] depending on the value of corrupt_side.
+ ranks_test : ndarray, shape (n) or (n,2)
An array of ranks of test triples.
- When ``corrupt_side='s,o'`` the function returns [n,2]. The first column represents the rank against
- subject corruptions and the second column represents the rank against object corruptions.
- In other cases, it returns [n] i.e. rank against the specified corruptions.
+ When ``corrupt_side='s,o'`` the function returns an array of shape (n,2). The first column represents the
+ rank against subject corruptions and the second column represents the rank against object corruptions.
+ In other cases, it returns an array of size (n), i.e., rank against the specified corruptions.
mrr_test : float
The MRR (filtered) of the best model, retrained on the concatenation of training and validation sets,
@@ -1369,70 +601,95 @@ def select_best_model_ranking(model_class, X_train, X_valid, X_test, param_grid,
A list containing all the intermediate experimental results:
the model parameters and the corresponding validation metrics.
- Examples
- --------
+ Example
+ -------
>>> from ampligraph.datasets import load_wn18
- >>> from ampligraph.latent_features import ComplEx
>>> from ampligraph.evaluation import select_best_model_ranking
>>> import numpy as np
>>>
>>> X = load_wn18()
- >>>
- >>> model_class = ComplEx
+ >>> model_class = 'ComplEx'
>>> param_grid = {
- >>> "batches_count": [50],
- >>> "seed": 0,
- >>> "epochs": [100],
- >>> "k": [100, 200],
- >>> "eta": [5, 10, 15],
- >>> "loss": ["pairwise", "nll"],
- >>> "loss_params": {
- >>> "margin": [2]
- >>> },
- >>> "embedding_model_params": {
- >>> },
- >>> "regularizer": ["LP", None],
- >>> "regularizer_params": {
- >>> "p": [1, 3],
- >>> "lambda": [1e-4, 1e-5]
- >>> },
- >>> "optimizer": ["adagrad", "adam"],
- >>> "optimizer_params": {
- >>> "lr": lambda: np.random.uniform(0.0001, 0.01)
- >>> },
- >>> "verbose": False
- >>> }
- >>> select_best_model_ranking(model_class, X['train'], X['valid'], X['test'],
- >>> param_grid,
- >>> max_combinations=100,
- >>> use_filter=True,
- >>> verbose=True,
+ >>> "batches_count": [50],
+ >>> "seed": 0,
+ >>> "epochs": [4000],
+ >>> "k": [100, 200],
+ >>> "eta": [5,10,15],
+ >>> "loss": ["pairwise", "nll"],
+ >>> "loss_params": {
+ >>> "margin": [2]
+ >>> },
+ >>> "embedding_model_params": {},
+ >>> "regularizer": ["LP", None],
+ >>> "regularizer_params": {
+ >>> "p": [1, 3],
+ >>> "lambda": [1e-4, 1e-5]
+ >>> },
+ >>> "optimizer": ["adagrad", "adam"],
+ >>> "optimizer_params":{
+ >>> "lr": lambda: np.random.uniform(0.0001, 0.01)
+ >>> },
+ >>> "verbose": False
+ >>> }
+ >>> select_best_model_ranking(model_class, X['train'], X['valid'], X['test'], param_grid,
+ >>> max_combinations=100, use_filter=True, verbose=True,
>>> early_stopping=True)
"""
- logger.debug('Starting gridsearch over hyperparameters. {}'.format(param_grid))
- if use_default_protocol:
- logger.warning('DeprecationWarning: use_default_protocol will be removed in future. \
- Please use corrupt_side argument instead.')
- corrupt_side = 's,o'
+ from importlib import import_module
+
+ from ..compat import evaluate_performance
+
+ compat_module = import_module("ampligraph.compat")
+ model_class = getattr(compat_module, model_class)
+
+ logger.debug(
+ "Starting gridsearch over hyperparameters. {}".format(param_grid)
+ )
if early_stopping_params is None:
early_stopping_params = {}
- # Verify missing parameters for the model class (default values will be used)
- undeclared_args = set(model_class.__init__.__code__.co_varnames[1:]) - set(param_grid.keys())
+ # Verify missing parameters for the model class (default values will be
+ # used)
+ undeclared_args = set(model_class.__init__.__code__.co_varnames[1:]) - set(
+ param_grid.keys()
+ )
if len(undeclared_args) != 0:
- logger.debug("The following arguments were not defined in the parameter grid"
- " and thus the default values will be used: {}".format(', '.join(undeclared_args)))
+ logger.debug(
+ "The following arguments were not defined in the parameter grid"
+ " and thus the default values will be used: {}".format(
+ ", ".join(undeclared_args)
+ )
+ )
param_grid["model_name"] = model_class.name
_scalars_into_lists(param_grid)
if max_combinations is not None:
np.random.seed(param_grid_random_seed)
- model_params_combinations = islice(_next_hyperparam_random(param_grid), max_combinations)
+ model_params_combinations = islice(
+ _next_hyperparam_random(param_grid), max_combinations
+ )
else:
model_params_combinations = _next_hyperparam(param_grid)
+ max_combinations = 1
+ for param in param_grid.values():
+ if isinstance(param, list):
+ max_combinations *= len(param)
+ elif isinstance(param, dict):
+ try:
+ max_combinations *= int(
+ np.prod(
+ [
+ len(el)
+ for el in param.values()
+ if isinstance(el, list)
+ ]
+ )
+ )
+ except Exception as e:
+ logger.debug("Exception " + e)
best_mrr_train = 0
best_model = None
@@ -1440,13 +697,29 @@ def select_best_model_ranking(model_class, X_train, X_valid, X_test, param_grid,
if early_stopping:
try:
- early_stopping_params['x_valid']
+ early_stopping_params["x_valid"]
except KeyError:
- logger.debug('Early stopping enable but no x_valid parameter set. Setting x_valid to {}'.format(X_valid))
- early_stopping_params['x_valid'] = X_valid
+ logger.debug(
+ "Early stopping enable but no x_valid parameter set. Setting x_valid to {}".format(
+ X_valid
+ )
+ )
+ early_stopping_params["x_valid"] = X_valid
+
+ focusE_numeric_edge_values = None
+ if focusE:
+ assert isinstance(X_train, np.ndarray) and X_train.shape[1] > 3, (
+ "Weights are missing! Concatenate them to X_train"
+ "in order to use FocusE!"
+ )
+ focusE_numeric_edge_values = X_train[:, 3:]
+ param_grid["embedding_model_params"] = {
+ **param_grid["embedding_model_params"],
+ **focusE_params,
+ }
if use_filter:
- X_filter = np.concatenate((X_train, X_valid, X_test))
+ X_filter = {"train": X_train, "valid": X_valid, "test": X_test}
else:
X_filter = None
@@ -1465,20 +738,35 @@ def evaluation(ranks):
hits_10 = hits_at_n_score(ranks, n=10)
return mrr, mr, hits_1, hits_3, hits_10
- for model_params in tqdm(model_params_combinations, total=max_combinations, disable=(not verbose)):
+ print("Grid search initialized successfully, training starting!")
+ print("Maximum number of combinations: ", max_combinations)
+ for model_params in tqdm(
+ model_params_combinations, total=max_combinations
+ ):
+ print()
current_result = {
"model_name": model_params["model_name"],
- "model_params": model_params
+ "model_params": model_params,
}
del model_params["model_name"]
try:
model = model_class(**model_params)
- model.fit(X_train, early_stopping, early_stopping_params)
- ranks = evaluate_performance(selection_dataset, model=model,
- filter_triples=X_filter, verbose=verbose,
- entities_subset=entities_subset,
- use_default_protocol=use_default_protocol,
- corrupt_side=corrupt_side)
+ model.fit(
+ X_train,
+ early_stopping,
+ early_stopping_params,
+ focusE_numeric_edge_values=focusE_numeric_edge_values,
+ verbose=verbose,
+ )
+
+ ranks = evaluate_performance(
+ selection_dataset,
+ model=model,
+ filter_triples=X_filter,
+ verbose=verbose,
+ entities_subset=entities_subset,
+ corrupt_side=corrupt_side,
+ )
curr_mrr, mr, hits_1, hits_3, hits_10 = evaluation(ranks)
@@ -1487,11 +775,17 @@ def evaluation(ranks):
"mr": mr,
"hits_1": hits_1,
"hits_3": hits_3,
- "hits_10": hits_10
+ "hits_10": hits_10,
}
- info = 'mr: {} mrr: {} hits 1: {} hits 3: {} hits 10: {}, model: {}, params: {}'.format(
- mr, curr_mrr, hits_1, hits_3, hits_10, type(model).__name__, model_params
+ info = "mr: {} mrr: {} hits 1: {} hits 3: {} hits 10: {}, model: {}, params: {}".format(
+ mr,
+ curr_mrr,
+ hits_1,
+ hits_3,
+ hits_10,
+ type(model).__name__,
+ model_params,
)
logger.debug(info)
@@ -1503,34 +797,61 @@ def evaluation(ranks):
best_model = model
best_params = model_params
except Exception as e:
- current_result["results"] = {
- "exception": str(e)
- }
+ current_result["results"] = {"exception": str(e)}
if verbose:
- logger.error('Exception occurred for parameters:{}'.format(model_params))
+ logger.error(
+ "Exception occurred for parameters:{}".format(model_params)
+ )
logger.error(str(e))
else:
pass
experimental_history.append(current_result)
-
+ print("Combination tried, on to the next one!")
if best_model is not None:
if retrain_best_model:
- best_model.fit(np.concatenate((X_train, X_valid)), early_stopping, early_stopping_params)
-
- ranks_test = evaluate_performance(X_test, model=best_model,
- filter_triples=X_filter, verbose=verbose,
- entities_subset=entities_subset,
- use_default_protocol=use_default_protocol,
- corrupt_side=corrupt_side)
-
- test_mrr, test_mr, test_hits_1, test_hits_3, test_hits_10 = evaluation(ranks_test)
-
- info = \
- 'Best model test results: mr: {} mrr: {} hits 1: {} hits 3: {} hits 10: {}, model: {}, params: {}'.format(
- test_mrr, test_mr, test_hits_1, test_hits_3, test_hits_10, type(best_model).__name__, best_params
+ if focusE:
+ assert (
+ isinstance(X_valid, np.ndarray) and X_valid.shape[1] > 3
+ ), (
+ "Validation set is used as training"
+ "data for retraining the best model,"
+ "but weights are missing."
+ "Concatenate them to X_valid!"
+ )
+ focusE_numeric_edge_values = np.concatenate(
+ [focusE_numeric_edge_values, X_valid[:, 3:]], axis=0
+ )
+ best_model.fit(
+ np.concatenate((X_train, X_valid)),
+ early_stopping,
+ early_stopping_params,
+ focusE_numeric_edge_values=focusE_numeric_edge_values,
)
+ ranks_test = evaluate_performance(
+ X_test,
+ model=best_model,
+ filter_triples=X_filter,
+ verbose=verbose,
+ entities_subset=entities_subset,
+ corrupt_side=corrupt_side,
+ )
+
+ test_mrr, test_mr, test_hits_1, test_hits_3, test_hits_10 = evaluation(
+ ranks_test
+ )
+
+ info = "Best model test results: mr: {} mrr: {} hits 1: {} hits 3: {} hits 10: {}, model: {}, params: {}".format(
+ test_mrr,
+ test_mr,
+ test_hits_1,
+ test_hits_3,
+ test_hits_10,
+ type(best_model).__name__,
+ best_params,
+ )
+
logger.debug(info)
if verbose:
logger.info(info)
@@ -1540,7 +861,7 @@ def evaluation(ranks):
"mr": test_mr,
"hits_1": test_hits_1,
"hits_3": test_hits_3,
- "hits_10": test_hits_10
+ "hits_10": test_hits_10,
}
else:
ranks_test = []
@@ -1550,7 +871,14 @@ def evaluation(ranks):
"mr": np.nan,
"hits_1": np.nan,
"hits_3": np.nan,
- "hits_10": np.nan
+ "hits_10": np.nan,
}
- return best_model, best_params, best_mrr_train, ranks_test, test_evaluation, experimental_history
+ return (
+ best_model,
+ best_params,
+ best_mrr_train,
+ ranks_test,
+ test_evaluation,
+ experimental_history,
+ )
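To make the new signature and the six-element return value above concrete, here is a minimal usage sketch; the tiny hyperparameter grid and the two-combination budget are illustrative assumptions, while the string model name, the argument order, and the returned tuple follow the code in this hunk.

# Illustrative sketch only; hyperparameter values are placeholders.
from ampligraph.datasets import load_wn18
from ampligraph.evaluation import select_best_model_ranking

X = load_wn18()

param_grid = {
    "batches_count": [50],
    "epochs": [50],
    "k": [100],
    "eta": [5, 10],
    "loss": ["pairwise"],
    "optimizer": ["adam"],
    "optimizer_params": {"lr": [1e-3]},
}

# The routine now takes the model name as a string and returns a 6-tuple,
# in the same order as the return statement above.
(best_model, best_params, best_mrr_train,
 ranks_test, test_evaluation, experimental_history) = select_best_model_ranking(
    "ComplEx", X["train"], X["valid"], X["test"], param_grid,
    max_combinations=2, use_filter=True, verbose=True,
)
print(test_evaluation["mrr"], test_evaluation["hits_10"])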
diff --git a/ampligraph/latent_features/__init__.py b/ampligraph/latent_features/__init__.py
index 79b4ef2f..5ed41cdd 100644
--- a/ampligraph/latent_features/__init__.py
+++ b/ampligraph/latent_features/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
@@ -7,35 +7,32 @@
#
r"""This module includes neural graph embedding models and support functions.
-Knowledge graph embedding models are neural architectures that encode concepts from a knowledge graph
-(i.e. entities :math:`\mathcal{E}` and relation types :math:`\mathcal{R}`) into low-dimensional, continuous vectors
-:math:`\in \mathcal{R}^k`. Such *knowledge graph embeddings* have applications in knowledge graph completion,
-entity resolution, and link-based clustering, just to cite a few :cite:`nickel2016review`.
+Knowledge graph embedding models are neural architectures that encode concepts
+from a knowledge graph (i.e., entities :math:`\mathcal{E}` and relation types
+:math:`\mathcal{R}`) into low-dimensional, continuous vectors :math:`\in
+\mathcal{R}^k`. Such *knowledge graph embeddings* have applications in
+knowledge graph completion, entity resolution, and link-based clustering,
+just to cite a few :cite:`nickel2016review`.
"""
+from .loss_functions import (
+ AbsoluteMarginLoss,
+ NLLLoss,
+ NLLMulticlass,
+ PairwiseLoss,
+ SelfAdversarialLoss,
+)
+from .models import ScoringBasedEmbeddingModel
+from .regularizers import LP_regularizer
-from .models.EmbeddingModel import EmbeddingModel, MODEL_REGISTRY, set_entity_threshold, reset_entity_threshold
-from .models.TransE import TransE
-from .models.DistMult import DistMult
-from .models.ComplEx import ComplEx
-from .models.HolE import HolE
-from .models.RandomBaseline import RandomBaseline
-from .models.ConvKB import ConvKB
-from .models.ConvE import ConvE
-
-from .loss_functions import Loss, AbsoluteMarginLoss, SelfAdversarialLoss, NLLLoss, PairwiseLoss,\
- NLLMulticlass, BCELoss, LOSS_REGISTRY
-from .regularizers import Regularizer, LPRegularizer, REGULARIZER_REGISTRY
-from .optimizers import Optimizer, AdagradOptimizer, AdamOptimizer, MomentumOptimizer, SGDOptimizer, OPTIMIZER_REGISTRY
-from .initializers import Initializer, RandomNormal, RandomUniform, Xavier, Constant, INITIALIZER_REGISTRY
-from .misc import get_entity_triples
-from ..utils import save_model, restore_model
-
-__all__ = ['LOSS_REGISTRY', 'REGULARIZER_REGISTRY', 'MODEL_REGISTRY', 'OPTIMIZER_REGISTRY', 'INITIALIZER_REGISTRY',
- 'set_entity_threshold', 'reset_entity_threshold',
- 'EmbeddingModel', 'TransE', 'DistMult', 'ComplEx', 'HolE', 'ConvKB', 'ConvE', 'RandomBaseline',
- 'Loss', 'AbsoluteMarginLoss', 'SelfAdversarialLoss', 'NLLLoss', 'PairwiseLoss', 'BCELoss', 'NLLMulticlass',
- 'Regularizer', 'LPRegularizer', 'Optimizer', 'AdagradOptimizer', 'AdamOptimizer', 'MomentumOptimizer',
- 'SGDOptimizer', 'Initializer', 'RandomNormal', 'RandomUniform', 'Xavier', 'Constant',
- 'get_entity_triples',
- 'save_model', 'restore_model']
+__all__ = [
+ "layers",
+ "models",
+ "ScoringBasedEmbeddingModel",
+ "PairwiseLoss",
+ "NLLLoss",
+ "AbsoluteMarginLoss",
+ "SelfAdversarialLoss",
+ "NLLMulticlass",
+ "LP_regularizer",
+]
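The rewritten package init above narrows the public surface considerably; the sketch below simply spells out what remains importable after this patch, and notes that the legacy class names are resolved through ampligraph.compat, as the model-selection code earlier in this diff does.

# Public names after this patch (taken from the __all__ list above).
from ampligraph.latent_features import (
    ScoringBasedEmbeddingModel,
    PairwiseLoss,
    NLLLoss,
    AbsoluteMarginLoss,
    SelfAdversarialLoss,
    NLLMulticlass,
    LP_regularizer,
)

# The legacy model classes (TransE, DistMult, ComplEx, ...) are no longer exported
# here; the evaluation code earlier in this patch resolves them by name through
# ampligraph.compat, e.g. getattr(import_module("ampligraph.compat"), "ComplEx").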
diff --git a/ampligraph/latent_features/constants.py b/ampligraph/latent_features/constants.py
deleted file mode 100644
index fa876453..00000000
--- a/ampligraph/latent_features/constants.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-#######################################################################################################
-# If not specified, following defaults will be used at respective locations
-
-DEFAULT_INITIALIZER = 'xavier'
-
-# Default learning rate for the optimizers
-DEFAULT_LR = 0.0005
-
-# Default momentum for the optimizers
-DEFAULT_MOMENTUM = 0.9
-
-# Default burn in for early stopping
-DEFAULT_BURN_IN_EARLY_STOPPING = 100
-
-# Default check interval for early stopping
-DEFAULT_CHECK_INTERVAL_EARLY_STOPPING = 10
-
-# Default stop interval for early stopping
-DEFAULT_STOP_INTERVAL_EARLY_STOPPING = 3
-
-# default evaluation criteria for early stopping
-DEFAULT_CRITERIA_EARLY_STOPPING = 'mrr'
-
-# default value which indicates whether to normalize the embeddings after each batch update
-DEFAULT_NORMALIZE_EMBEDDINGS = False
-
-# Default side to corrupt for evaluation
-DEFAULT_CORRUPT_SIDE_EVAL = 's,o'
-
-# default hyperparameter for transE
-DEFAULT_NORM_TRANSE = 1
-
-# default value for the way in which the corruptions are to be generated while training/testing.
-# Uses all entities
-DEFAULT_CORRUPTION_ENTITIES = 'all'
-
-# Threshold (on number of unique entities) to categorize the data as Huge Dataset (to warn user)
-ENTITY_WARN_THRESHOLD = 5e5
-
-# Default value for k (embedding size)
-DEFAULT_EMBEDDING_SIZE = 100
-
-# Default value for eta (number of corrputions to be generated for training)
-DEFAULT_ETA = 2
-
-# Default value for number of epochs
-DEFAULT_EPOCH = 100
-
-# Default value for batch count
-DEFAULT_BATCH_COUNT = 100
-
-# Default value for seed
-DEFAULT_SEED = 0
-
-# Default value for optimizer
-DEFAULT_OPTIM = "adam"
-
-# Default value for loss type
-DEFAULT_LOSS = "nll"
-
-# Default value for regularizer type
-DEFAULT_REGULARIZER = None
-
-# Default value for verbose
-DEFAULT_VERBOSE = False
-
-# Specifies how to generate corruptions for training - default does s and o together and applies the loss
-DEFAULT_CORRUPT_SIDE_TRAIN = ['s,o']
-
-# Subject corruption with a OneToNDatasetAdapter requires an N*N matrix (where N is number of unique entities).
-# Specify a batch size to reduce memory overhead.
-DEFAULT_SUBJECT_CORRUPTION_BATCH_SIZE = 10000
-
-# Default hyperparameters for ConvEmodel
-DEFAULT_CONVE_CONV_FILTERS = 32
-DEFAULT_CONVE_KERNEL_SIZE = 3
-DEFAULT_CONVE_DROPOUT_EMBED = 0.2
-DEFAULT_CONVE_DROPOUT_CONV = 0.3
-DEFAULT_CONVE_DROPOUT_DENSE = 0.2
-DEFAULT_CONVE_USE_BIAS = True
-DEFAULT_CONVE_USE_BATCHNORM = True
-
-# Default value for comparision strategy to use while comparing scores of corruptions against positive
-DEFAULT_RANK_COMPARE_STRATEGY = 'worst'
-
-# Score comparision precision
-# (Multiplies the score with this value and truncates the decimal part for comparision)
-SCORE_COMPARISION_PRECISION = 1e5
diff --git a/ampligraph/latent_features/initializers.py b/ampligraph/latent_features/initializers.py
deleted file mode 100644
index 11f21df4..00000000
--- a/ampligraph/latent_features/initializers.py
+++ /dev/null
@@ -1,561 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-import tensorflow as tf
-import abc
-import logging
-import numpy as np
-from sklearn.utils import check_random_state
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
-INITIALIZER_REGISTRY = {}
-
-# Default value of lower bound for uniform sampling
-DEFAULT_UNIFORM_LOW = -0.05
-
-# Default value of upper bound for uniform sampling
-DEFAULT_UNIFORM_HIGH = 0.05
-
-# Default value of mean for Gaussian sampling
-DEFAULT_NORMAL_MEAN = 0
-
-# Default value of std for Gaussian sampling
-DEFAULT_NORMAL_STD = 0.05
-
-# Default value indicating whether to use xavier uniform or normal
-DEFAULT_XAVIER_IS_UNIFORM = False
-
-
-def register_initializer(name, external_params=[], class_params={}):
- def insert_in_registry(class_handle):
- INITIALIZER_REGISTRY[name] = class_handle
- class_handle.name = name
- INITIALIZER_REGISTRY[name].external_params = external_params
- INITIALIZER_REGISTRY[name].class_params = class_params
- return class_handle
-
- return insert_in_registry
-
-
-class Initializer(abc.ABC):
- """Abstract class for initializer .
- """
-
- name = ""
- external_params = []
- class_params = {}
-
- def __init__(self, initializer_params={}, verbose=True, seed=0):
- """Initialize the Class
-
- Parameters
- ----------
- initializer_params : dict
- dictionary of hyperparams that would be used by the initializer.
- verbose : bool
- set/reset verbose mode
- seed : int/np.random.RandomState
- random state for random number generator
- """
- self.verbose = verbose
- self._initializer_params = {}
- if isinstance(seed, int):
- self.random_generator = check_random_state(seed)
- else:
- self.random_generator = seed
- self._init_hyperparams(initializer_params)
-
- def _display_params(self):
- """Display the parameter values
- """
- logger.info('\n------ Initializer -----')
- logger.info('Name : {}'.format(self.name))
- for key, value in self._initializer_params.items():
- logger.info('{} : {}'.format(key, value))
-
- def _init_hyperparams(self, hyperparam_dict):
- """ Initializes the hyperparameters.
-
- Parameters
- ----------
- hyperparam_dict: dictionary
- Consists of key value pairs. The initializer will check the keys to get the corresponding params
- """
- raise NotImplementedError('Abstract Method not implemented!')
-
- def _get_tf_initializer(self, in_shape=None, out_shape=None, concept='e'):
- """Create a tensorflow node for initializer
-
- Parameters
- ----------
- in_shape: int
- number of inputs to the layer (fan in)
- out_shape: int
- number of outputs of the layer (fan out)
- concept: char
- concept type (e for entity, r for relation)
-
- Returns
- -------
- initializer_instance: An Initializer instance.
- """
- raise NotImplementedError('Abstract Method not implemented!')
-
- def _get_np_initializer(self, in_shape=None, out_shape=None, concept='e'):
- """Create an initialized numpy array
-
- Parameters
- ----------
- in_shape: int
- number of inputs to the layer (fan in)
- out_shape: int
- number of outputs of the layer (fan out)
- concept: char
- concept type (e for entity, r for relation)
-
- Returns
- -------
- initialized_values: n-d array
- Initialized weights
- """
- raise NotImplementedError('Abstract Method not implemented!')
-
- def get_entity_initializer(self, in_shape=None, out_shape=None, init_type='tf'):
- """ Initializer for entity embeddings
-
- Parameters
- ----------
- in_shape: int
- number of inputs to the layer (fan in)
- out_shape: int
- number of outputs of the layer (fan out)
- init_type: string
- Type of initializer ('tf' for tensorflow, 'np' for numpy)
-
- Returns
- -------
- initialized_values: tf.Op or n-d array
- Weights initializer
- """
- assert init_type in ['tf', 'np'], 'Invalid initializer type!'
- if init_type == 'tf':
- return self._get_tf_initializer(in_shape, out_shape, 'e')
- else:
- return self._get_np_initializer(in_shape, out_shape, 'e')
-
- def get_relation_initializer(self, in_shape=None, out_shape=None, init_type='tf'):
- """ Initializer for relation embeddings
-
- Parameters
- ----------
- in_shape: int
- number of inputs to the layer (fan in)
- out_shape: int
- number of outputs of the layer (fan out)
- init_type: string
- Type of initializer ('tf' for tensorflow, 'np' for numpy)
-
- Returns
- -------
- initialized_values: tf.Op or n-d array
- Weights initializer
- """
- assert init_type in ['tf', 'np'], 'Invalid initializer type!'
- if init_type == 'tf':
- return self._get_tf_initializer(in_shape, out_shape, 'r')
- else:
- return self._get_np_initializer(in_shape, out_shape, 'r')
-
-
-@register_initializer("normal", ["mean", "std"])
-class RandomNormal(Initializer):
- r"""Initializes from a normal distribution with specified ``mean`` and ``std``
-
- .. math::
-
- \mathcal{N} (\mu, \sigma)
-
- """
-
- name = ""
- external_params = []
- class_params = {}
-
- def __init__(self, initializer_params={}, verbose=True, seed=0):
- """Initialize the Random Normal initialization strategy
-
- Parameters
- ----------
- initializer_params : dict
- Consists of key-value pairs. The initializer will check the keys to get the corresponding params:
-
- - **mean**: (float). Mean of the weights(default: 0)
- - **std**: (float): std of the weights (default: 0.05)
-
- Example: ``initializer_params={'mean': 0, 'std': 0.01}``
- verbose : bool
- Enable/disable verbose mode
- seed : int/np.random.RandomState
- random state for random number generator
- """
-
- super(RandomNormal, self).__init__(initializer_params, verbose, seed)
-
- def _init_hyperparams(self, hyperparam_dict):
- """ Initializes the hyperparameters.
-
- Parameters
- ----------
- hyperparam_dict : dictionary
- Consists of key value pairs. The initializer will check the keys to get the corresponding params
- """
- self._initializer_params['mean'] = hyperparam_dict.get('mean', DEFAULT_NORMAL_MEAN)
- self._initializer_params['std'] = hyperparam_dict.get('std', DEFAULT_NORMAL_STD)
-
- if self.verbose:
- self._display_params()
-
- def _get_tf_initializer(self, in_shape=None, out_shape=None, concept='e'):
- """Create a tensorflow node for initializer
-
- Parameters
- ----------
- in_shape: int
- number of inputs to the layer (fan in)
- out_shape: int
- number of outputs of the layer (fan out)
- concept: char
- concept type (e for entity, r for relation)
-
- Returns
- -------
- initializer_instance: An Initializer instance.
- """
- return tf.random_normal_initializer(mean=self._initializer_params['mean'],
- stddev=self._initializer_params['std'],
- dtype=tf.float32)
-
- def _get_np_initializer(self, in_shape, out_shape, concept='e'):
- """Create an initialized numpy array
-
- Parameters
- ----------
- in_shape: int
- number of inputs to the layer (fan in)
- out_shape: int
- number of outputs of the layer (fan out)
- concept: char
- concept type (e for entity, r for relation)
-
- Returns
- -------
- out: n-d array
- matrix initialized from a normal distribution of specified mean and std
- """
- return self.random_generator.normal(self._initializer_params['mean'],
- self._initializer_params['std'],
- size=(in_shape, out_shape)).astype(np.float32)
-
-
-@register_initializer("uniform", ["low", "high"])
-class RandomUniform(Initializer):
- r"""Initializes from a uniform distribution with specified ``low`` and ``high``
-
- .. math::
-
- \mathcal{U} (low, high)
-
- """
-
- name = ""
- external_params = []
- class_params = {}
-
- def __init__(self, initializer_params={}, verbose=True, seed=0):
- """Initialize the Uniform initialization strategy
-
- Parameters
- ----------
- initializer_params : dict
- Consists of key-value pairs. The initializer will check the keys to get the corresponding params:
-
- - **low**: (float). lower bound for uniform number (default: -0.05)
- - **high**: (float): upper bound for uniform number (default: 0.05)
-
- Example: ``initializer_params={'low': 0, 'high': 0.01}``
- verbose : bool
- Enable/disable verbose mode
- seed : int/np.random.RandomState
- random state for random number generator
- """
-
- super(RandomUniform, self).__init__(initializer_params, verbose, seed)
-
- def _init_hyperparams(self, hyperparam_dict):
- """ Initializes the hyperparameters.
-
- Parameters
- ----------
- hyperparam_dict : dictionary
- Consists of key value pairs. The initializer will check the keys to get the corresponding params
- """
- self._initializer_params['low'] = hyperparam_dict.get('low', DEFAULT_UNIFORM_LOW)
- self._initializer_params['high'] = hyperparam_dict.get('high', DEFAULT_UNIFORM_HIGH)
-
- if self.verbose:
- self._display_params()
-
- def _get_tf_initializer(self, in_shape=None, out_shape=None, concept='e'):
- """Create a tensorflow node for initializer
-
- Parameters
- ----------
- in_shape: int
- number of inputs to the layer (fan in)
- out_shape: int
- number of outputs of the layer (fan out)
- concept: char
- concept type (e for entity, r for relation)
-
- Returns
- -------
- initializer_instance: An Initializer instance.
- """
- return tf.random_uniform_initializer(minval=self._initializer_params['low'],
- maxval=self._initializer_params['high'],
- dtype=tf.float32)
-
- def _get_np_initializer(self, in_shape, out_shape, concept='e'):
- """Create an initialized numpy array
-
- Parameters
- ----------
- in_shape: int
- number of inputs to the layer (fan in)
- out_shape: int
- number of outputs of the layer (fan out)
- concept: char
- concept type (e for entity, r for relation)
-
- Returns
- -------
- out: n-d array
- matrix initialized from a uniform distribution of specified low and high bounds
- """
- return self.random_generator.uniform(self._initializer_params['low'],
- self._initializer_params['high'],
- size=(in_shape, out_shape)).astype(np.float32)
-
-
-@register_initializer("xavier", ["uniform"])
-class Xavier(Initializer):
- r"""Follows the xavier strategy for initialization of layers :cite:`glorot2010understanding`.
-
- If ``uniform`` is set to True, then it initializes the layer from the following uniform distribution:
-
- .. math::
-
- \mathcal{U} ( - \sqrt{ \frac{6}{ fan_{in} + fan_{out} } }, \sqrt{ \frac{6}{ fan_{in} + fan_{out} } } )
-
- If ``uniform`` is False, then it initializes the layer from the following normal distribution:
-
- .. math::
-
- \mathcal{N} ( 0, \sqrt{ \frac{2}{ fan_{in} + fan_{out} } } )
-
- where :math:`fan_{in}` and :math:`fan_{out}` are number of input units and output units of the layer respectively.
-
- """
-
- name = ""
- external_params = []
- class_params = {}
-
- def __init__(self, initializer_params={}, verbose=True, seed=0):
- """Initialize the Xavier strategy
-
- Parameters
- ----------
- initializer_params : dict
- Consists of key-value pairs. The initializer will check the keys to get the corresponding params:
-
- - **uniform**: (bool). indicates whether to use Xavier Uniform or Xavier Normal initializer.
-
- Example: ``initializer_params={'uniform': False}``
- verbose : bool
- Enable/disable verbose mode
- seed : int/np.random.RandomState
- random state for random number generator
- """
-
- super(Xavier, self).__init__(initializer_params, verbose, seed)
-
- def _init_hyperparams(self, hyperparam_dict):
- """ Initializes the hyperparameters.
-
- Parameters
- ----------
- hyperparam_dict : dictionary
- Consists of key value pairs. The initializer will check the keys to get the corresponding params
- """
- self._initializer_params['uniform'] = hyperparam_dict.get('uniform', DEFAULT_XAVIER_IS_UNIFORM)
-
- if self.verbose:
- self._display_params()
-
- def _get_tf_initializer(self, in_shape=None, out_shape=None, concept='e'):
- """Create a tensorflow node for initializer
-
- Parameters
- ----------
- in_shape: int
- number of inputs to the layer (fan in)
- out_shape: int
- number of outputs of the layer (fan out)
- concept: char
- concept type (e for entity, r for relation)
-
- Returns
- -------
- initializer_instance: An Initializer instance.
- """
- return tf.contrib.layers.xavier_initializer(uniform=self._initializer_params['uniform'],
- dtype=tf.float32)
-
- def _get_np_initializer(self, in_shape, out_shape, concept='e'):
- """Create an initialized numpy array
-
- Parameters
- ----------
- in_shape: int
- number of inputs to the layer (fan in)
- out_shape: int
- number of outputs of the layer (fan out)
- concept: char
- concept type (e for entity, r for relation)
-
- Returns
- -------
- out: n-d array
- matrix initialized using xavier uniform or xavier normal initializer
- """
- if self._initializer_params['uniform']:
- limit = np.sqrt(6 / (in_shape + out_shape))
- return self.random_generator.uniform(-limit, limit, size=(in_shape, out_shape)).astype(np.float32)
- else:
- std = np.sqrt(2 / (in_shape + out_shape))
- return self.random_generator.normal(0, std, size=(in_shape, out_shape)).astype(np.float32)
-
-
-@register_initializer("constant", ["entity", "relation"])
-class Constant(Initializer):
- r"""Initializes with the constant values provided by the user
-
- """
-
- name = ""
- external_params = []
- class_params = {}
-
- def __init__(self, initializer_params={}, verbose=True, seed=0):
- """Initialize the the constant values provided by the user
-
- Parameters
- ----------
- initializer_params : dict
- Consists of key-value pairs. The initializer will check the keys to get the corresponding params:
-
- - **entity**: (np.ndarray.float32). Initial values for entity embeddings
- - **relation**: (np.ndarray.float32). Initial values for relation embeddings
-
- Example: ``initializer_params={'entity': ent_init_value, 'relation': rel_init_value}``
- verbose : bool
- Enable/disable verbose mode
- seed : int/np.random.RandomState
- random state for random number generator
- """
-
- super(Constant, self).__init__(initializer_params, verbose, seed)
-
- def _init_hyperparams(self, hyperparam_dict):
- """ Initializes the hyperparameters.
-
- Parameters
- ----------
- hyperparam_dict : dictionary
- Consists of key value pairs. The initializer will check the keys to get the corresponding params
- """
- try:
- self._initializer_params['entity'] = hyperparam_dict['entity']
- self._initializer_params['relation'] = hyperparam_dict['relation']
- except KeyError:
- raise Exception('Initial values of both entity and relation embeddings need to '
- 'be passed to the initializer!')
- if self.verbose:
- self._display_params()
-
- def _get_tf_initializer(self, in_shape=None, out_shape=None, concept='e'):
- """Create a tensorflow node for initializer
-
- Parameters
- ----------
- in_shape: int
- number of inputs to the layer (fan in)
- out_shape: int
- number of outputs of the layer (fan out)
- concept: char
- concept type (e for entity, r for relation)
-
- Returns
- -------
- initializer_instance: An Initializer instance.
- """
-
- if concept == 'e':
- assert self._initializer_params['entity'].shape[0] == in_shape and \
- self._initializer_params['entity'].shape[1] == out_shape, \
- "Invalid shape for entity initializer!"
-
- return tf.compat.v1.constant_initializer(self._initializer_params['entity'], dtype=tf.float32)
- else:
- assert self._initializer_params['relation'].shape[0] == in_shape and \
- self._initializer_params['relation'].shape[1] == out_shape, \
- "Invalid shape for relation initializer!"
-
- return tf.compat.v1.constant_initializer(self._initializer_params['relation'], dtype=tf.float32)
-
- def _get_np_initializer(self, in_shape, out_shape, concept='e'):
- """Create an initialized numpy array
-
- Parameters
- ----------
- in_shape: int
- number of inputs to the layer (fan in)
- out_shape: int
- number of outputs of the layer (fan out)
- concept: char
- concept type (e for entity, r for relation)
-
- Returns
- -------
- out: n-d array
- matrix initialized using constant values supplied by the user
- """
- if concept == 'e':
- assert self._initializer_params['entity'].shape[0] == in_shape and \
- self._initializer_params['entity'].shape[1] == out_shape, \
- "Invalid shape for entity initializer!"
-
- return self._initializer_params['entity']
- else:
- assert self._initializer_params['relation'].shape[0] == in_shape and \
- self._initializer_params['relation'].shape[1] == out_shape, \
- "Invalid shape for relation initializer!"
-
- return self._initializer_params['relation']
diff --git a/ampligraph/latent_features/layers/__init__.py b/ampligraph/latent_features/layers/__init__.py
new file mode 100644
index 00000000..53bad495
--- /dev/null
+++ b/ampligraph/latent_features/layers/__init__.py
@@ -0,0 +1,8 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+__all__ = ["scoring", "encoding", "corruption_generation", "calibration"]
diff --git a/ampligraph/latent_features/layers/calibration/__init__.py b/ampligraph/latent_features/layers/calibration/__init__.py
new file mode 100644
index 00000000..efa6bf09
--- /dev/null
+++ b/ampligraph/latent_features/layers/calibration/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+from .calibrate import CalibrationLayer
+
+__all__ = ["CalibrationLayer"]
diff --git a/ampligraph/latent_features/layers/calibration/calibrate.py b/ampligraph/latent_features/layers/calibration/calibrate.py
new file mode 100644
index 00000000..d7db0bc6
--- /dev/null
+++ b/ampligraph/latent_features/layers/calibration/calibrate.py
@@ -0,0 +1,132 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+import tensorflow as tf
+import numpy as np
+
+
+class CalibrationLayer(tf.keras.layers.Layer):
+ """Layer to calibrate the model outputs.
+
+ The class implements the heuristics described in :cite:`calibration`,
+ using Platt scaling :cite:`platt1999probabilistic`.
+
+ See the docs of :meth:`~ampligraph.latent_features.models.ScoringBasedEmbeddingModel.calibrate()` for more details.
+ """
+
+ def get_config(self):
+ config = super(CalibrationLayer, self).get_config()
+ config.update(
+ {
+ "pos_size": self.pos_size,
+ "neg_size": self.neg_size,
+ "positive_base_rate": self.positive_base_rate,
+ }
+ )
+ return config
+
+ def __init__(
+ self, pos_size=0, neg_size=0, positive_base_rate=None, **kwargs
+ ):
+ self.pos_size = pos_size
+ self.neg_size = pos_size if neg_size == 0 else neg_size
+
+ if positive_base_rate is not None:
+ if positive_base_rate <= 0 or positive_base_rate >= 1:
+ raise ValueError(
+ "Positive_base_rate must be a value between 0 and 1."
+ )
+ else:
+ assert pos_size > 0 and neg_size > 0, "pos_size and neg_size must be > 0."
+
+ positive_base_rate = pos_size / (pos_size + neg_size)
+
+ self.positive_base_rate = positive_base_rate
+ self.w_init = tf.constant_initializer(kwargs.pop("calib_w", 0.0))
+ self.b_init = tf.constant_initializer(
+ kwargs.pop(
+ "calib_b",
+ np.log((self.neg_size + 1.0) / (self.pos_size + 1.0)).astype(
+ np.float32
+ ),
+ )
+ )
+ super(CalibrationLayer, self).__init__(**kwargs)
+
+ def build(self, input_shape):
+ """
+ Build method.
+ """
+ self.calib_w = self.add_weight(
+ "calib_w",
+ shape=(),
+ initializer=self.w_init,
+ dtype=tf.float32,
+ trainable=True,
+ )
+
+ self.calib_b = self.add_weight(
+ "calib_b",
+ shape=(),
+ initializer=self.b_init,
+ dtype=tf.float32,
+ trainable=True,
+ )
+ self.built = True
+
+ def call(
+ self, scores_pos, scores_neg=tf.convert_to_tensor(()), training=0
+ ):
+ """
+ Call method.
+ """
+ if training:
+ scores_all = tf.concat([scores_pos, scores_neg], axis=0)
+ else:
+ scores_all = scores_pos
+
+ logits = -(self.calib_w * scores_all + self.calib_b)
+
+ if training:
+ labels = tf.concat(
+ [
+ tf.cast(
+ tf.fill(
+ scores_pos.shape,
+ (self.pos_size + 1.0) / (self.pos_size + 2.0),
+ ),
+ tf.float32,
+ ),
+ tf.cast(
+ tf.fill(scores_neg.shape, 1 / (self.neg_size + 2.0)),
+ tf.float32,
+ ),
+ ],
+ axis=0,
+ )
+ weigths_pos = scores_neg.shape[0] / scores_pos.shape[0]
+ weights_neg = (
+ 1.0 - self.positive_base_rate
+ ) / self.positive_base_rate
+ weights = tf.concat(
+ [
+ tf.cast(
+ tf.fill(scores_pos.shape, weigths_pos), tf.float32
+ ),
+ tf.cast(
+ tf.fill(scores_neg.shape, weights_neg), tf.float32
+ ),
+ ],
+ axis=0,
+ )
+ loss = tf.reduce_mean(
+ weights
+ * tf.nn.sigmoid_cross_entropy_with_logits(labels, logits)
+ )
+ return loss
+ else:
+ return tf.math.sigmoid(logits)
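As a quick orientation, a minimal sketch of the calibration layer above in inference mode; the positive/negative counts are illustrative assumptions, while the constructor arguments, the training flag, and the sigmoid(-(calib_w * score + calib_b)) output follow the code in this file.

# Minimal sketch (illustrative sizes); uses only the constructor and call() above.
import tensorflow as tf

# 100 positives / 50 negatives -> positive_base_rate = 2/3,
# calib_b initialised to log(51 / 101), calib_w initialised to 0.
calib = CalibrationLayer(pos_size=100, neg_size=50)

scores = tf.constant([-3.0, 0.0, 4.5], dtype=tf.float32)

# Inference mode (training=0): returns sigmoid(-(calib_w * score + calib_b)).
probs = calib(scores, training=0)
print(probs.numpy())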
diff --git a/ampligraph/latent_features/layers/corruption_generation/CorruptionGenerationLayerTrain.py b/ampligraph/latent_features/layers/corruption_generation/CorruptionGenerationLayerTrain.py
new file mode 100644
index 00000000..33572a85
--- /dev/null
+++ b/ampligraph/latent_features/layers/corruption_generation/CorruptionGenerationLayerTrain.py
@@ -0,0 +1,97 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+import tensorflow as tf
+
+
+class CorruptionGenerationLayerTrain(tf.keras.layers.Layer):
+ """Generates corruptions during training.
+
+ The corruption might involve either subject or object using
+ entities sampled uniformly at random from the loaded graph.
+ """
+
+ def get_config(self):
+ config = super(CorruptionGenerationLayerTrain, self).get_config()
+ config.update({"seed": self.seed})
+ return config
+
+ def __init__(self, seed=0, **kwargs):
+ """
+ Initializes the corruption generation layer.
+
+ Parameters
+ ----------
+ seed: int
+ Random seed used when sampling corruptions (default: 0).
+ """
+ self.seed = seed
+ super(CorruptionGenerationLayerTrain, self).__init__(**kwargs)
+
+ def call(self, pos, ent_size, eta):
+ """
+ Generates corruption for the positives supplied.
+
+ Parameters
+ ----------
+ pos: array-like, shape (n, 3)
+ Batch of input triples (positives).
+ ent_size: int
+ Number of unique entities present in the partition.
+ eta: int
+ Number of corruptions to generate per positive triple.
+
+ Returns
+ -------
+ corruptions: array-like, shape (n * eta, 3)
+ Corruptions of the triples.
+ """
+ # size and reshape the dataset to sample corruptions
+ dataset = tf.reshape(
+ tf.tile(tf.reshape(pos, [-1]), [eta]),
+ [tf.shape(input=pos)[0] * eta, 3],
+ )
+ # generate a mask which will tell which subject needs to be corrupted
+ # (random uniform sampling)
+ keep_subj_mask = tf.cast(
+ tf.random.uniform(
+ [tf.shape(input=dataset)[0]],
+ 0,
+ 2,
+ dtype=tf.int32,
+ seed=self.seed,
+ ),
+ tf.bool,
+ )
+ # If we are not corrupting the subject then corrupt the object
+ keep_obj_mask = tf.logical_not(keep_subj_mask)
+
+ # cast it to integer (0/1)
+ keep_subj_mask = tf.cast(keep_subj_mask, tf.int32)
+ keep_obj_mask = tf.cast(keep_obj_mask, tf.int32)
+ # generate the n * eta replacements (uniformly randomly)
+ replacements = tf.random.uniform(
+ [tf.shape(dataset)[0]], 0, ent_size, dtype=tf.int32, seed=self.seed
+ )
+ # keep subjects of dataset where keep_subject is 1 and zero it where keep_subject is 0
+ # now add replacements where keep_subject is 0 (i.e. keep_object is 1)
+ subjects = tf.math.add(
+ tf.math.multiply(keep_subj_mask, dataset[:, 0]),
+ tf.math.multiply(keep_obj_mask, replacements),
+ )
+ # keep relations as it is
+ relationships = dataset[:, 1]
+ # keep objects of dataset where keep_object is 1 and zero it where keep_object is 0
+ # now add replacements where keep_object is 0 (i.e. keep_subject is 1)
+ objects = tf.math.add(
+ tf.math.multiply(keep_obj_mask, dataset[:, 2]),
+ tf.math.multiply(keep_subj_mask, replacements),
+ )
+ # stack the generated subject, reln and object entities and create the
+ # corruptions
+ corruptions = tf.transpose(
+ a=tf.stack([subjects, relationships, objects])
+ )
+ return corruptions
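A small sketch of the corruption layer above on a toy batch; the five-entity graph and eta=3 are illustrative assumptions, while the call signature (positives, ent_size, eta) and the (n * eta, 3) output shape follow the code in this file.

# Minimal sketch on a toy graph; entity/relation ids are assumed to be
# integer-encoded already, as the int32 ops above require.
import tensorflow as tf

corruptor = CorruptionGenerationLayerTrain(seed=0)

pos = tf.constant([[0, 0, 1],
                   [2, 1, 3]], dtype=tf.int32)

# eta=3 corruptions per positive -> 6 corrupted triples, each with either the
# subject or the object replaced by an entity id sampled uniformly in [0, 5).
corruptions = corruptor(pos, ent_size=5, eta=3)
print(corruptions.shape)  # (6, 3)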
diff --git a/ampligraph/latent_features/layers/corruption_generation/__init__.py b/ampligraph/latent_features/layers/corruption_generation/__init__.py
new file mode 100644
index 00000000..e841eccb
--- /dev/null
+++ b/ampligraph/latent_features/layers/corruption_generation/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+from .CorruptionGenerationLayerTrain import CorruptionGenerationLayerTrain
+
+__all__ = ["CorruptionGenerationLayerTrain"]
diff --git a/ampligraph/latent_features/layers/encoding/EmbeddingLookupLayer.py b/ampligraph/latent_features/layers/encoding/EmbeddingLookupLayer.py
new file mode 100644
index 00000000..db53cec5
--- /dev/null
+++ b/ampligraph/latent_features/layers/encoding/EmbeddingLookupLayer.py
@@ -0,0 +1,346 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+import tensorflow as tf
+import numpy as np
+
+
+class EmbeddingLookupLayer(tf.keras.layers.Layer):
+ def get_config(self):
+ config = super(EmbeddingLookupLayer, self).get_config()
+
+ config.update(
+ {
+ "k": self.k,
+ "max_ent_size": self._max_ent_size_internal,
+ "max_rel_size": self._max_rel_size_internal,
+ "entity_kernel_initializer": self.ent_init,
+ "entity_kernel_regularizer": self.ent_regularizer,
+ "relation_kernel_initializer": self.rel_init,
+ "relation_kernel_regularizer": self.rel_regularizer,
+ }
+ )
+
+ return config
+
+ def __init__(
+ self,
+ k,
+ max_ent_size=None,
+ max_rel_size=None,
+ entity_kernel_initializer="glorot_uniform",
+ entity_kernel_regularizer=None,
+ relation_kernel_initializer="glorot_uniform",
+ relation_kernel_regularizer=None,
+ **kwargs
+ ):
+ """
+ Initializes the embeddings of the model.
+
+ Parameters
+ ----------
+ k: int
+ Embedding size.
+ max_ent_size: int
+ Max entities that can occur in any partition (default: `None`).
+ max_rel_size: int
+ Max relations that can occur in any partition (default: `None`).
+ entity_kernel_initializer: str (name of objective function), objective function or
+ `tf.keras.initializers.Initializer` instance
+ An objective function is any callable with the signature ``init = fn(shape)``.
+ Initializer of the entity embeddings.
+ entity_kernel_regularizer: str (name of objective function), objective function or
+ `tf.keras.regularizers.Regularizer` instance
+ Regularizer of the entity embeddings.
+ relation_kernel_initializer: str (name of objective function), objective function or
+ `tf.keras.initializers.Initializer` instance
+ An objective function is any callable with the signature ``init = fn(shape)``.
+ Initializer of the relation embeddings.
+ relation_kernel_regularizer: str (name of objective function), objective function or
+ `tf.keras.regularizers.Regularizer` instance
+ Regularizer of the relation embeddings.
+ seed: int
+ Random seed.
+
+ """
+ super(EmbeddingLookupLayer, self).__init__(**kwargs)
+
+ self._max_ent_size_internal = max_ent_size
+ self._max_rel_size_internal = max_rel_size
+ self.k = k
+
+ self.ent_partition = None
+ self.rel_partition = None
+
+ self.max_ent_size = max_ent_size
+ self.max_rel_size = max_rel_size
+
+ self._has_enough_args_to_build_ent_emb = True
+ self._has_enough_args_to_build_rel_emb = True
+
+ if self.max_ent_size is None:
+ self._has_enough_args_to_build_ent_emb = False
+
+ if self.max_rel_size is None:
+ self._has_enough_args_to_build_rel_emb = False
+
+ self.ent_init = entity_kernel_initializer
+ self.rel_init = relation_kernel_initializer
+
+ self.ent_regularizer = entity_kernel_regularizer
+ self.rel_regularizer = relation_kernel_regularizer
+
+ def set_ent_rel_initial_value(self, ent_init, rel_init):
+ """
+ Sets the initial value of entity and relation embedding matrix.
+
+ This function is mainly used during the partitioned training where the full embedding matrix is
+ initialized outside the model.
+ """
+ self.ent_partition = ent_init
+ self.rel_partition = rel_init
+
+ def set_initializer(self, initializer):
+ """
+ Set the initializer of the weights of this layer.
+
+ Parameters
+ ----------
+ initializer: str (name of objective function) or objective function or `tf.keras.initializers.Initializer` or list
+ Initializer of the entity and relation embeddings. This is either a single value or a list of size 2.
+ If it is a single value, then both the entities and relations will be initialized based on
+ the same initializer. If it is a list, the first initializer will be used for entities and the second
+ for relations. Any callable with the signature ``init = fn(shape)`` can be interpreted as an objective
+ function.
+
+ """
+ if isinstance(initializer, list):
+ assert (
+ len(initializer) == 2
+ ), "Incorrect length for initializer. Assumed 2 got {}".format(
+ len(initializer)
+ )
+ self.ent_init = tf.keras.initializers.get(initializer[0])
+ self.rel_init = tf.keras.initializers.get(initializer[1])
+ else:
+ self.ent_init = tf.keras.initializers.get(initializer)
+ self.rel_init = tf.keras.initializers.get(initializer)
+
+ def set_regularizer(self, regularizer):
+ """
+ Set the regularizer of the weights of this layer.
+
+ Parameters
+ ----------
+ regularizer: str (name of objective function) or objective function or `tf.keras.regularizers.Regularizer` instance or list
+ Regularizer of the weights determining entity and relation embeddings.
+ If it is a single value, then both the entities and relations will be regularized based on
+ the same regularizer. If it is a list, the first regularizer will be used for entities and the second
+ for relations.
+
+ """
+
+ if isinstance(regularizer, list):
+ assert (
+ len(regularizer) == 2
+ ), "Incorrect length for regularizer. Expected 2, got {}".format(
+ len(regularizer)
+ )
+ self.ent_regularizer = tf.keras.regularizers.get(regularizer[0])
+ self.rel_regularizer = tf.keras.regularizers.get(regularizer[1])
+ else:
+ self.ent_regularizer = tf.keras.regularizers.get(regularizer)
+ self.rel_regularizer = tf.keras.regularizers.get(regularizer)
+
+ @property
+ def max_ent_size(self):
+ """Returns the size of the entity embedding matrix."""
+ return self._max_ent_size_internal
+
+ @max_ent_size.setter
+ def max_ent_size(self, value):
+ """Setter for the max entity size property.
+
+ The layer is buildable only if this property is set.
+ """
+ if value is not None and value > 0:
+ self._max_ent_size_internal = value
+ self._has_enough_args_to_build_ent_emb = True
+
+ @property
+ def max_rel_size(self):
+ """Returns the size of relation embedding matrix."""
+ return self._max_rel_size_internal
+
+ @max_rel_size.setter
+ def max_rel_size(self, value):
+ """Setter for the max relation size property.
+
+ The layer is buildable only if this property is set.
+ """
+ if value is not None and value > 0:
+ self._max_rel_size_internal = value
+ self._has_enough_args_to_build_rel_emb = True
+
+ def build(self, input_shape):
+ """Builds the embedding lookup error.
+
+ The trainable weights are created based on the hyperparams.
+ """
+ # create the trainable variables for entity embeddings
+ if self._has_enough_args_to_build_ent_emb:
+ self.ent_emb = self.add_weight(
+ "ent_emb",
+ shape=[self._max_ent_size_internal, self.k],
+ initializer=self.ent_init,
+ regularizer=self.ent_regularizer,
+ dtype=tf.float32,
+ trainable=True,
+ )
+
+ if self.ent_partition is not None:
+ paddings_ent = [
+ [
+ 0,
+ self._max_ent_size_internal
+ - self.ent_partition.shape[0],
+ ],
+ [0, 0],
+ ]
+ self.ent_emb.assign(
+ np.pad(
+ self.ent_partition,
+ paddings_ent,
+ "constant",
+ constant_values=0,
+ )
+ )
+ del self.ent_partition
+ self.ent_partition = None
+
+ else:
+ raise TypeError(
+ "Not enough arguments to build Encoding Layer. Please set max_ent_size property."
+ )
+
+ # create the trainable variables for relation embeddings
+ if self._has_enough_args_to_build_rel_emb:
+ self.rel_emb = self.add_weight(
+ "rel_emb",
+ shape=[self._max_rel_size_internal, self.k],
+ initializer=self.rel_init,
+ regularizer=self.rel_regularizer,
+ dtype=tf.float32,
+ trainable=True,
+ )
+
+ if self.rel_partition is not None:
+ paddings_rel = [
+ [
+ 0,
+ self._max_rel_size_internal
+ - self.rel_partition.shape[0],
+ ],
+ [0, 0],
+ ]
+ self.rel_emb.assign(
+ np.pad(
+ self.rel_partition,
+ paddings_rel,
+ "constant",
+ constant_values=0,
+ )
+ )
+ del self.rel_partition
+ self.rel_partition = None
+ else:
+ raise TypeError(
+ "Not enough arguments to build Encoding Layer. Please set max_rel_size property."
+ )
+
+ self.built = True
+
+ def partition_change_updates(self, partition_ent_emb, partition_rel_emb):
+ """Perform the changes that are required when the partition is changed during training.
+
+ Parameters
+ ----------
+ partition_ent_emb:
+ Entity embeddings that need to be trained for the partition
+ (all triples of the partition will have an embedding in this matrix).
+ partition_rel_emb:
+ Relation embeddings that need to be trained for the partition
+ (all triples of the partition will have an embedding in this matrix).
+
+ """
+
+ # if the number of entities in the partition is less than the required size of the embedding matrix,
+ # pad it. This is needed because the trainable variable size can't change dynamically.
+ # Once defined, it stays fixed. Hence padding is needed.
+ paddings_ent = tf.constant(
+ [
+ [0, self._max_ent_size_internal - partition_ent_emb.shape[0]],
+ [0, 0],
+ ]
+ )
+ paddings_rel = tf.constant(
+ [
+ [0, self._max_rel_size_internal - partition_rel_emb.shape[0]],
+ [0, 0],
+ ]
+ )
+
+ # once padded, assign it to the trainable variable
+ self.ent_emb.assign(
+ tf.pad(
+ partition_ent_emb, paddings_ent, "CONSTANT", constant_values=0
+ )
+ )
+ self.rel_emb.assign(
+ tf.pad(
+ partition_rel_emb, paddings_rel, "CONSTANT", constant_values=0
+ )
+ )
+
+ def call(self, triples):
+ """
+ Looks up the embeddings of entities and relations of the triples.
+
+ Parameters
+ ----------
+ triples : ndarray, shape (n, 3)
+ Batch of input triples.
+
+ Returns
+ -------
+ emb_triples : list
+ List of embeddings of subjects, predicates, objects.
+ """
+ # look up in the respective embedding matrix
+ e_s = tf.nn.embedding_lookup(self.ent_emb, triples[:, 0])
+ e_p = tf.nn.embedding_lookup(self.rel_emb, triples[:, 1])
+ e_o = tf.nn.embedding_lookup(self.ent_emb, triples[:, 2])
+ return [e_s, e_p, e_o]
+
+ def compute_output_shape(self, input_shape):
+ """Returns the output shape of outputs of call function.
+
+ Parameters
+ ----------
+ input_shape: list
+ Shape of inputs of call function.
+
+ Returns
+ -------
+ output_shape: list
+ Shape of outputs of call function.
+ """
+ assert isinstance(input_shape, list)
+ batch_size, _ = input_shape
+ return [
+ (batch_size, self.k),
+ (batch_size, self.k),
+ (batch_size, self.k),
+ ]
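For orientation, a minimal sketch of the lookup layer above used standalone; the toy sizes (10 entities, 4 relations, k=8) and the integer-encoded triples are illustrative assumptions, while the constructor arguments and the [e_s, e_p, e_o] output follow the code in this file.

# Minimal sketch with illustrative sizes: 10 entities, 4 relations, k=8.
import tensorflow as tf

lookup = EmbeddingLookupLayer(k=8, max_ent_size=10, max_rel_size=4)

triples = tf.constant([[0, 1, 2],
                       [3, 0, 7]], dtype=tf.int32)

# Returns the subject, predicate and object embeddings, each of shape (2, 8).
e_s, e_p, e_o = lookup(triples)
print(e_s.shape, e_p.shape, e_o.shape)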
diff --git a/ampligraph/latent_features/layers/encoding/__init__.py b/ampligraph/latent_features/layers/encoding/__init__.py
new file mode 100644
index 00000000..4d64397f
--- /dev/null
+++ b/ampligraph/latent_features/layers/encoding/__init__.py
@@ -0,0 +1,10 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+from .EmbeddingLookupLayer import EmbeddingLookupLayer
+
+__all__ = ["EmbeddingLookupLayer"]
diff --git a/ampligraph/latent_features/layers/scoring/AbstractScoringLayer.py b/ampligraph/latent_features/layers/scoring/AbstractScoringLayer.py
new file mode 100644
index 00000000..e952741e
--- /dev/null
+++ b/ampligraph/latent_features/layers/scoring/AbstractScoringLayer.py
@@ -0,0 +1,429 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+import tensorflow as tf
+
+# Precision for floating point comparison
+COMPARISION_PRECISION = 1e3
+
+# Scoring layer registry. Every scoring function must be registered in
+# this registry.
+SCORING_LAYER_REGISTRY = {}
+
+
+def register_layer(name, external_params=None, class_params=None):
+ """Register the scoring function using this decorator.
+
+ Parameters
+ -----------
+ name: str
+ Name of the scoring function to be used to register the class.
+ external_params: list of strings
+ If there are any scoring function hyperparams, register their names.
+ class_params: dict
+ Parameters that may be used internally across various models.
+ """
+ if external_params is None:
+ external_params = []
+ if class_params is None:
+ class_params = {}
+
+ def insert_in_registry(class_handle):
+ assert (
+ name not in SCORING_LAYER_REGISTRY.keys()
+ ), "Scoring Layer with name {} \
+ already exists!".format(
+ name
+ )
+
+ # store the class handle in the registry with name as key
+ SCORING_LAYER_REGISTRY[name] = class_handle
+ # create a class level variable and store the name
+ class_handle.name = name
+
+ # store other params related to the scoring function in the registry
+ # this will be used later during model selection, etc
+ SCORING_LAYER_REGISTRY[name].external_params = external_params
+ SCORING_LAYER_REGISTRY[name].class_params = class_params
+ return class_handle
+
+ return insert_in_registry
+
+
+class AbstractScoringLayer(tf.keras.layers.Layer):
+ """Abstract class for scoring layer."""
+
+ def get_config(self):
+ config = super(AbstractScoringLayer, self).get_config()
+ config.update({"k": self.internal_k})
+ return config
+
+ def __init__(self, k):
+ """Initializes the scoring layer.
+
+ Parameters
+ ----------
+ k: int
+ Embedding size.
+ """
+ super(AbstractScoringLayer, self).__init__()
+ # store the embedding size. (concrete models may overwrite this)
+ self.internal_k = k
+
+ def call(self, triples):
+ """Interface to the external world.
+ Computes the scores of the triples.
+
+ Parameters
+ ----------
+ triples: array-like, shape (n, 3)
+ Batch of input triples.
+
+ Returns
+ -------
+ scores: tf.Tensor, shape (n,1)
+ Tensor of scores of inputs.
+ """
+ return self._compute_scores(triples)
+
+ def _compute_scores(self, triples):
+ """Abstract function to compute scores. Override this method in concrete classes.
+
+ Parameters
+ -----------
+ triples: array-like, shape (n, 3)
+ Batch of input triples.
+
+ Returns
+ --------
+ scores: tf.Tensor, shape (n,1)
+ Tensor of scores of inputs.
+ """
+ raise NotImplementedError("Abstract method not implemented!")
+
+ def _get_object_corruption_scores(self, triples, ent_matrix):
+ """Abstract function to compute object corruption scores.
+
+ Evaluate the inputs against object corruptions and scores of the corruptions.
+
+ Parameters
+ -----------
+ triples: array-like, shape (n, k)
+ Batch of input embeddings.
+ ent_matrix: array-like, shape (m, k)
+ Slice of embedding matrix (corruptions).
+
+ Returns
+ --------
+ scores: tf.Tensor, shape (n,1)
+ Scores of object corruptions (corruptions defined by `ent_embs` matrix).
+ """
+ raise NotImplementedError("Abstract method not implemented!")
+
+ def _get_subject_corruption_scores(self, triples, ent_matrix):
+ """Abstract function to compute subject corruption scores.
+
+ Evaluate the inputs against subject corruptions and scores of the corruptions.
+
+ Parameters
+ -----------
+ triples: array-like, shape (n, k)
+ Batch of input embeddings.
+ ent_matrix: array-like, shape (m, k)
+ Slice of embedding matrix (corruptions).
+
+ Returns
+ --------
+ scores: tf.Tensor, shape (n,1)
+ Scores of subject corruptions (corruptions defined by `ent_embs` matrix).
+ """
+ raise NotImplementedError("Abstract method not implemented!")
+
+ def get_ranks(
+ self,
+ triples,
+ ent_matrix,
+ start_ent_id,
+ end_ent_id,
+ filters,
+ mapping_dict,
+ corrupt_side="s,o",
+ comparison_type="worst",
+ ):
+ """Computes the ranks of triples against their corruptions.
+
+ Ranks are computed by corrupting triple subject and/or object with the embeddings in ent_matrix.
+
+ Parameters
+ -----------
+ triples: array-like, shape (n, k)
+ Batch of input embeddings.
+ ent_matrix: array-like, shape (m, k)
+ Slice of embedding matrix (corruptions).
+ start_ent_id: int
+ Original id of the first row of embedding matrix (used during partitioned approach).
+ end_ent_id: int
+ Original id of the last row of embedding matrix (used during partitioned approach).
+ filters: list of lists
+ Size of list is either 1 or 2 depending on ``corrupt_side``.
+ Size of the internal list is equal to the size of the input triples.
+ Each list contains an array of filters (True Positives) related to the specified side of triples to corrupt.
+ corrupt_side: str
+ Which side to corrupt during evaluation.
+ comparison_type: str
+ Indicates how to break ties (default: `worst`, i.e., assigns the worst rank to the test triple).
+ One of the three types can be passed: `"best"`, `"middle"`, `"worst"`.
+
+ Returns
+ --------
+ ranks: tf.Tensor, shape (n,2)
+ Ranks of triple against subject and object corruptions (corruptions defined by `ent_embs` matrix).
+ """
+ # compute the score of true positives
+ triple_score = self._compute_scores(triples)
+
+ # Handle the floating point comparison by multiplying by reqd precision and casting to int
+ # before comparing
+ triple_score = tf.cast(triple_score * COMPARISION_PRECISION, tf.int32)
+
+ out_ranks = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
+ filter_index = 0
+ if tf.strings.regex_full_match(corrupt_side, ".*s.*"):
+ # compute the score by corrupting the subject side of triples by
+ # ent_matrix
+ sub_corr_score = self._get_subject_corruption_scores(
+ triples, ent_matrix
+ )
+ # Handle the floating point comparison by multiplying by reqd precision and casting to int
+ # before comparing
+ sub_corr_score = tf.cast(
+ sub_corr_score * COMPARISION_PRECISION, tf.int32
+ )
+
+ # if pos score: 0.5, corr_score: 0.5, 0.5, 0.3, 0.6, 0.5, 0.5
+ if comparison_type == "best":
+                # returns: 1, i.e., only 1 corruption has a score strictly greater
+                # than the positive (optimistic)
+ sub_rank = tf.reduce_sum(
+ tf.cast(
+ tf.expand_dims(triple_score, 1) < sub_corr_score,
+ tf.int32,
+ ),
+ 1,
+ )
+ elif comparison_type == "middle":
+                # returns: 3, i.e., 1 + ceil(4/2): only 1 corruption has a score strictly greater
+                # than the positive and 4 corruptions have the same score (their middle rank
+                # is ceil(4/2) = 2), so 1 + 2 = 3
+ sub_rank = tf.reduce_sum(
+ tf.cast(
+ tf.expand_dims(triple_score, 1) < sub_corr_score,
+ tf.int32,
+ ),
+ 1,
+ )
+ part = tf.cast(
+ tf.expand_dims(triple_score, 1) == sub_corr_score, tf.int32
+ )
+ sub_rank += tf.cast(
+ tf.math.ceil(tf.reduce_sum(part, 1) / 2), tf.int32
+ )
+ else:
+                # returns: 5, i.e., 5 corruptions have a score >= the positive;
+                # this strategy therefore assigns the worst rank (pessimistic)
+
+ # compare True positive score against their respective
+ # corruptions and get rank.
+ sub_rank = tf.reduce_sum(
+ tf.cast(
+ tf.expand_dims(triple_score, 1) <= sub_corr_score,
+ tf.int32,
+ ),
+ 1,
+ )
+
+ if filters.shape[0] > 0:
+ for i in tf.range(tf.shape(triple_score)[0]):
+                    # get the ids of True positives that need to be filtered
+ filter_ids = filters[filter_index][i]
+
+ if mapping_dict.size() > 0:
+ filter_ids = mapping_dict.lookup(filter_ids)
+ filter_ids = tf.reshape(filter_ids, (-1,))
+
+ filter_ids_selector = tf.math.greater_equal(
+ filter_ids, 0
+ )
+ filter_ids = tf.boolean_mask(
+ filter_ids, filter_ids_selector, axis=0
+ )
+
+ # This is done for partitioning (where the full emb matrix is not used)
+ # this gets only the filter ids of the current partition
+ # being used for generating corruption
+ filter_ids_selector = tf.logical_and(
+ filter_ids >= start_ent_id, filter_ids <= end_ent_id
+ )
+
+ filter_ids = tf.boolean_mask(
+ filter_ids, filter_ids_selector
+ )
+ # from entity id convert to index in the current partition
+ filter_ids = filter_ids - start_ent_id
+
+ # get the score of the corruptions which are actually True
+ # positives
+ score_filter = tf.gather(
+ tf.squeeze(tf.gather_nd(sub_corr_score, [[i]])),
+ filter_ids,
+ )
+ # check how many of those were ranked higher than the test
+ # triple
+ num_filters_ranked_higher = tf.reduce_sum(
+ tf.cast(
+ tf.gather(triple_score, [i]) <= score_filter,
+ tf.int32,
+ )
+ )
+ # adjust the rank of the test triple accordingly
+ sub_rank = tf.tensor_scatter_nd_sub(
+ sub_rank, [[i]], [num_filters_ranked_higher]
+ )
+
+ out_ranks = out_ranks.write(out_ranks.size(), sub_rank)
+
+ if tf.strings.regex_full_match(corrupt_side, ".*o.*"):
+ # compute the score by corrupting the object side of triples by
+ # ent_matrix
+ obj_corr_score = self._get_object_corruption_scores(
+ triples, ent_matrix
+ )
+
+ # Handle the floating point comparison by multiplying by reqd precision and casting to int
+ # before comparing
+ obj_corr_score = tf.cast(
+ obj_corr_score * COMPARISION_PRECISION, tf.int32
+ )
+
+ # if pos score: 0.5, corr_score: 0.5, 0.5, 0.3, 0.6, 0.5, 0.5
+ if comparison_type == "best":
+                # returns: 1, i.e., only 1 corruption has a score strictly greater
+                # than the positive (optimistic)
+ obj_rank = tf.reduce_sum(
+ tf.cast(
+ tf.expand_dims(triple_score, 1) < obj_corr_score,
+ tf.int32,
+ ),
+ 1,
+ )
+ elif comparison_type == "middle":
+                # returns: 3, i.e., 1 + ceil(4/2): only 1 corruption has a score strictly greater
+                # than the positive and 4 corruptions have the same score (their middle rank
+                # is ceil(4/2) = 2), so 1 + 2 = 3
+ obj_rank = tf.reduce_sum(
+ tf.cast(
+ tf.expand_dims(triple_score, 1) < obj_corr_score,
+ tf.int32,
+ ),
+ 1,
+ )
+ part = tf.cast(
+ tf.expand_dims(triple_score, 1) == obj_corr_score, tf.int32
+ )
+ obj_rank += tf.cast(
+ tf.math.ceil(tf.reduce_sum(part, 1) / 2), tf.int32
+ )
+ else:
+                # returns: 5, i.e., 5 corruptions have a score >= the positive;
+                # this strategy therefore assigns the worst rank (pessimistic)
+
+ # compare True positive score against their respective
+ # corruptions and get rank.
+ obj_rank = tf.reduce_sum(
+ tf.cast(
+ tf.expand_dims(triple_score, 1) <= obj_corr_score,
+ tf.int32,
+ ),
+ 1,
+ )
+
+ if filters.shape[0] > 0:
+ for i in tf.range(tf.shape(triple_score)[0]):
+ if corrupt_side in ["s", "o"] and filters.shape[0] == 1:
+ filter_index = 0
+ else:
+ filter_index = 1
+                    # get the ids of True positives that need to be filtered
+ filter_ids = filters[filter_index][i]
+
+ if mapping_dict.size() > 0:
+ filter_ids = mapping_dict.lookup(filter_ids)
+ filter_ids = tf.reshape(filter_ids, (-1,))
+
+ filter_ids_selector = tf.math.greater_equal(
+ filter_ids, 0
+ )
+ filter_ids = tf.boolean_mask(
+ filter_ids, filter_ids_selector, axis=0
+ )
+
+                    # This is done for partitioning (where the full emb matrix is not used)
+ # this gets only the filter ids of the current partition
+ # being used for generating corruption
+ filter_ids_selector = tf.logical_and(
+ filter_ids >= start_ent_id, filter_ids <= end_ent_id
+ )
+ filter_ids = tf.boolean_mask(
+ filter_ids, filter_ids_selector
+ )
+ # from entity id convert to index in the current partition
+ filter_ids = filter_ids - start_ent_id
+
+ # get the score of the corruptions which are actually True
+ # positives
+ score_filter = tf.gather(
+ tf.squeeze(tf.gather_nd(obj_corr_score, [[i]])),
+ filter_ids,
+ )
+ # check how many of those were ranked higher than the test
+ # triple
+ num_filters_ranked_higher = tf.reduce_sum(
+ tf.cast(
+ tf.gather(triple_score, [i]) <= score_filter,
+ tf.int32,
+ )
+ )
+ # adjust the rank of the test triple accordingly
+ obj_rank = tf.tensor_scatter_nd_sub(
+ obj_rank, [[i]], [num_filters_ranked_higher]
+ )
+
+ out_ranks = out_ranks.write(out_ranks.size(), obj_rank)
+
+ out_ranks = out_ranks.stack()
+ return out_ranks
+
+ def compute_output_shape(self, input_shape):
+ """Returns the output shape of the outputs of the call function.
+
+ Parameters
+ -----------
+        input_shape: list
+ Shape of inputs of call function.
+
+ Returns
+ --------
+        output_shape: list
+ Shape of outputs of call function.
+ """
+ assert isinstance(input_shape, list)
+ batch_size, _ = input_shape
+ return [batch_size, 1]
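The `best`/`middle`/`worst` tie-breaking in `get_ranks` can be reproduced with the toy scores quoted in its comments (positive 0.5 against corruptions 0.5, 0.5, 0.3, 0.6, 0.5, 0.5). A small standalone sketch (not part of this changeset) of the integer-cast comparisons:

```python
import tensorflow as tf

PRECISION = 1e3   # same idea as COMPARISION_PRECISION: compare scores as integers

pos = tf.cast(tf.constant([0.5]) * PRECISION, tf.int32)                      # shape (1,)
corr = tf.cast(tf.constant([[0.5, 0.5, 0.3, 0.6, 0.5, 0.5]]) * PRECISION,
               tf.int32)                                                     # shape (1, 6)

best = tf.reduce_sum(tf.cast(tf.expand_dims(pos, 1) < corr, tf.int32), 1)    # strictly greater -> [1]
worst = tf.reduce_sum(tf.cast(tf.expand_dims(pos, 1) <= corr, tf.int32), 1)  # greater or equal -> [5]
ties = tf.reduce_sum(tf.cast(tf.expand_dims(pos, 1) == corr, tf.int32), 1)   # equal scores     -> [4]
middle = best + tf.cast(tf.math.ceil(ties / 2), tf.int32)                    # 1 + ceil(4/2)    -> [3]

print(best.numpy(), middle.numpy(), worst.numpy())   # [1] [3] [5]
```

Casting to integers after multiplying by the precision constant makes the equality test for ties well defined despite floating-point noise.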
diff --git a/ampligraph/latent_features/layers/scoring/ComplEx.py b/ampligraph/latent_features/layers/scoring/ComplEx.py
new file mode 100644
index 00000000..94532c9b
--- /dev/null
+++ b/ampligraph/latent_features/layers/scoring/ComplEx.py
@@ -0,0 +1,151 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+import tensorflow as tf
+from .AbstractScoringLayer import register_layer, AbstractScoringLayer
+
+
+@register_layer("ComplEx")
+class ComplEx(AbstractScoringLayer):
+ r"""Complex Embeddings (ComplEx) scoring layer.
+
+ The ComplEx model :cite:`trouillon2016complex` is an extension of
+ the :class:`ampligraph.latent_features.DistMult` bilinear diagonal model.
+
+ ComplEx scoring function is based on the trilinear Hermitian dot product in :math:`\mathbb{C}`:
+
+ .. math::
+ f_{ComplEx}=Re(\langle \mathbf{r}_p, \mathbf{e}_s, \overline{\mathbf{e}_o} \rangle)
+
+ .. note::
+ Since ComplEx embeddings belong to :math:`\mathbb{C}`, this model uses twice as many parameters as
+ :class:`ampligraph.latent_features.DistMult`.
+ """
+
+ def get_config(self):
+ config = super(ComplEx, self).get_config()
+ return config
+
+ def __init__(self, k):
+ super(ComplEx, self).__init__(k)
+        # internally ComplEx uses k embeddings for the real part and k embeddings for the
+        # imaginary part, hence it uses 2 * k embeddings overall
+ self.internal_k = 2 * k
+
+ def _compute_scores(self, triples):
+ """Compute scores using the ComplEx scoring function.
+
+ Parameters
+ ----------
+ triples: array, shape (n, 3)
+ Batch of input triples.
+
+ Returns
+ -------
+ scores: tf.Tensor
+ Tensor with scores of the inputs.
+ """
+ # split the embeddings of s, p, o into 2 parts (real and img part)
+ e_s_real, e_s_img = tf.split(triples[0], 2, axis=1)
+ e_p_real, e_p_img = tf.split(triples[1], 2, axis=1)
+ e_o_real, e_o_img = tf.split(triples[2], 2, axis=1)
+
+ # apply the complex scoring function
+ scores = tf.reduce_sum(
+ (e_s_real * (e_p_real * e_o_real + e_p_img * e_o_img))
+ + (e_s_img * (e_p_real * e_o_img - e_p_img * e_o_real)),
+ axis=1,
+ )
+ return scores
+
+ def _get_subject_corruption_scores(self, triples, ent_matrix):
+ """Compute subject corruption scores.
+
+ Evaluate the inputs against subject corruptions and scores of the corruptions.
+
+ Parameters
+ ----------
+ triples: array-like, shape (n, k)
+ Batch of input embeddings.
+ ent_matrix: array-like, shape (m, k)
+ Slice of embedding matrix (corruptions).
+
+ Returns
+ -------
+ scores: tf.Tensor, shape (n, 1)
+ Scores of subject corruptions (corruptions defined by `ent_embs` matrix).
+ """
+ # split the embeddings of s, p, o into 2 parts (real and img part)
+ e_s_real, e_s_img = tf.split(triples[0], 2, axis=1)
+ e_p_real, e_p_img = tf.split(triples[1], 2, axis=1)
+ e_o_real, e_o_img = tf.split(triples[2], 2, axis=1)
+
+ # split the corruption entity embeddings into 2 parts (real and img
+ # part)
+ ent_real, ent_img = tf.split(ent_matrix, 2, axis=1)
+
+ # compute the subject corruption score using ent_real, ent_img
+ # (corruption embeddings) as subject embeddings
+ sub_corr_score = tf.reduce_sum(
+ ent_real
+ * (
+ tf.expand_dims(e_p_real * e_o_real, 1)
+ + tf.expand_dims(e_p_img * e_o_img, 1)
+ )
+ + (
+ ent_img
+ * (
+ tf.expand_dims(e_p_real * e_o_img, 1)
+ - tf.expand_dims(e_p_img * e_o_real, 1)
+ )
+ ),
+ axis=2,
+ )
+ return sub_corr_score
+
+ def _get_object_corruption_scores(self, triples, ent_matrix):
+ """Compute object corruption scores.
+
+ Evaluate the inputs against object corruptions and scores of the corruptions.
+
+ Parameters
+ ----------
+ triples: array-like, shape (n, k)
+ Batch of input embeddings.
+ ent_matrix: array-like, shape (m, k)
+ Slice of embedding matrix (corruptions).
+
+ Returns
+ -------
+ scores: tf.Tensor, shape (n, 1)
+ Scores of object corruptions (corruptions defined by `ent_embs` matrix).
+ """
+ # split the embeddings of s, p, o into 2 parts (real and img part)
+ e_s_real, e_s_img = tf.split(triples[0], 2, axis=1)
+ e_p_real, e_p_img = tf.split(triples[1], 2, axis=1)
+ e_o_real, e_o_img = tf.split(triples[2], 2, axis=1)
+
+ # split the corruption entity embeddings into 2 parts (real and img
+ # part)
+ ent_real, ent_img = tf.split(ent_matrix, 2, axis=1)
+
+ # compute the object corruption score using ent_real, ent_img
+ # (corruption embeddings) as object embeddings
+ obj_corr_score = tf.reduce_sum(
+ (
+ tf.expand_dims(e_s_real * e_p_real, 1)
+ - tf.expand_dims(e_s_img * e_p_img, 1)
+ )
+ * ent_real
+ + (
+ tf.expand_dims(e_s_img * e_p_real, 1)
+ + tf.expand_dims(e_s_real * e_p_img, 1)
+ )
+ * ent_img,
+ axis=2,
+ )
+ return obj_corr_score
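The split real/imaginary arithmetic in `_compute_scores` is the expanded form of the Hermitian product in the class docstring. A standalone check (not part of this changeset) against NumPy complex arithmetic for a single random triple:

```python
import numpy as np

rng = np.random.default_rng(0)
k = 4                                            # embedding size (2k real values are stored)
s_re, s_im = rng.normal(size=k), rng.normal(size=k)
p_re, p_im = rng.normal(size=k), rng.normal(size=k)
o_re, o_im = rng.normal(size=k), rng.normal(size=k)

# expanded real-valued form, as implemented in _compute_scores
score_split = np.sum(s_re * (p_re * o_re + p_im * o_im)
                     + s_im * (p_re * o_im - p_im * o_re))

# Re(<r_p, e_s, conj(e_o)>) computed directly with complex numbers
s, p, o = s_re + 1j * s_im, p_re + 1j * p_im, o_re + 1j * o_im
score_direct = np.real(np.sum(p * s * np.conj(o)))

assert np.isclose(score_split, score_direct)
```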
diff --git a/ampligraph/latent_features/layers/scoring/DistMult.py b/ampligraph/latent_features/layers/scoring/DistMult.py
new file mode 100644
index 00000000..de2cdbd8
--- /dev/null
+++ b/ampligraph/latent_features/layers/scoring/DistMult.py
@@ -0,0 +1,99 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+import tensorflow as tf
+from .AbstractScoringLayer import register_layer, AbstractScoringLayer
+
+
+@register_layer("DistMult")
+class DistMult(AbstractScoringLayer):
+ r"""DistMult scoring layer.
+
+ The model as described in :cite:`yang2014embedding`.
+
+ The bilinear diagonal DistMult model uses the trilinear dot product as scoring function:
+
+ .. math::
+ f_{DistMult}=\langle \mathbf{r}_p, \mathbf{e}_s, \mathbf{e}_o \rangle
+
+ where :math:`\mathbf{e}_{s}` is the embedding of the subject, :math:`\mathbf{r}_{p}` the embedding
+ of the predicate and :math:`\mathbf{e}_{o}` the embedding of the object.
+ """
+
+ def get_config(self):
+ config = super(DistMult, self).get_config()
+ return config
+
+ def __init__(self, k):
+ super(DistMult, self).__init__(k)
+
+ def _compute_scores(self, triples):
+        """Compute scores using the DistMult scoring function.
+
+ Parameters
+ ----------
+ triples: array-like, shape (n, 3)
+ Batch of input triples.
+
+ Returns
+ -------
+ scores: tf.Tensor, shape (n,1)
+ Tensor of scores of inputs.
+ """
+ # compute scores as sum(s * p * o)
+ scores = tf.reduce_sum(triples[0] * triples[1] * triples[2], 1)
+ return scores
+
+ def _get_subject_corruption_scores(self, triples, ent_matrix):
+ """Compute subject corruption scores.
+
+ Evaluate the inputs against subject corruptions and scores of the corruptions.
+
+ Parameters
+ ----------
+ triples: array-like, shape (n, k)
+ Batch of input embeddings.
+ ent_matrix: array-like, shape (m, k)
+ Slice of embedding matrix (corruptions).
+
+ Returns
+ -------
+ scores: tf.Tensor, shape (n, 1)
+ Scores of subject corruptions (corruptions defined by `ent_embs` matrix).
+ """
+ rel_emb, obj_emb = triples[1], triples[2]
+ # compute the score by broadcasting the corruption embeddings(ent_matrix) and using the scoring function
+ # compute scores as sum(s_corr * p * o)
+ sub_corr_score = tf.reduce_sum(
+ ent_matrix * tf.expand_dims(rel_emb * obj_emb, 1), 2
+ )
+ return sub_corr_score
+
+ def _get_object_corruption_scores(self, triples, ent_matrix):
+ """Compute object corruption scores.
+
+ Evaluate the inputs against object corruptions and scores of the corruptions.
+
+ Parameters
+ ----------
+ triples: array-like, shape (n, k)
+ Batch of input embeddings.
+ ent_matrix: array-like, shape (m, k)
+ Slice of embedding matrix (corruptions).
+
+ Returns
+ -------
+ scores: tf.Tensor, shape (n, 1)
+ Scores of object corruptions (corruptions defined by `ent_embs` matrix).
+ """
+ sub_emb, rel_emb = triples[0], triples[1]
+ # compute the score by broadcasting the corruption embeddings(ent_matrix) and using the scoring function
+ # compute scores as sum(s * p * o_corr)
+ obj_corr_score = tf.reduce_sum(
+ tf.expand_dims(sub_emb * rel_emb, 1) * ent_matrix, 2
+ )
+ return obj_corr_score
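The corruption scores above are computed by broadcasting rather than by looping over the m candidate entities. A standalone sketch (not part of this changeset) showing that the broadcast form of `_get_object_corruption_scores` matches a naive per-candidate loop:

```python
import tensorflow as tf

tf.random.set_seed(0)
n, m, k = 2, 4, 3                              # 2 triples, 4 candidate objects, k = 3
sub = tf.random.normal((n, k))
rel = tf.random.normal((n, k))
ent_matrix = tf.random.normal((m, k))          # candidate object embeddings

# broadcast form used in _get_object_corruption_scores: (n, 1, k) * (m, k) -> (n, m, k)
broadcast = tf.reduce_sum(tf.expand_dims(sub * rel, 1) * ent_matrix, 2)

# naive equivalent: score every (triple, candidate) pair with sum(s * p * o_corr)
naive = tf.stack([
    tf.stack([tf.reduce_sum(sub[i] * rel[i] * ent_matrix[j]) for j in range(m)])
    for i in range(n)
])

tf.debugging.assert_near(broadcast, naive)     # both are (n, m) score matrices
```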
diff --git a/ampligraph/latent_features/layers/scoring/HolE.py b/ampligraph/latent_features/layers/scoring/HolE.py
new file mode 100644
index 00000000..9fd2ba98
--- /dev/null
+++ b/ampligraph/latent_features/layers/scoring/HolE.py
@@ -0,0 +1,89 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+from .AbstractScoringLayer import register_layer
+from .ComplEx import ComplEx
+
+
+@register_layer("HolE")
+class HolE(ComplEx):
+ r"""Holographic Embeddings (HolE) scoring layer.
+
+ The HolE model :cite:`nickel2016holographic` as re-defined by Hayashi et al. :cite:`HayashiS17`:
+
+ .. math::
+ f_{HolE}= \frac{2}{k} \, f_{ComplEx}
+
+ where :math:`k` is the size of the embeddings.
+ """
+
+ def get_config(self):
+ config = super(HolE, self).get_config()
+ return config
+
+ def __init__(self, k):
+ super(HolE, self).__init__(k)
+
+ def _compute_scores(self, triples):
+ """Compute scores using HolE scoring function.
+
+ Parameters
+ ----------
+ triples: array-like, shape (n, 3)
+ Batch of input triples.
+
+ Returns
+ -------
+        scores: tf.Tensor, shape (n,1)
+ Tensor of scores of inputs.
+ """
+ # HolE scoring is 2/k * complex_score
+ return (2 / (self.internal_k / 2)) * (super()._compute_scores(triples))
+
+ def _get_subject_corruption_scores(self, triples, ent_matrix):
+ """Compute subject corruption scores.
+
+ Evaluate the inputs against subject corruptions and scores of the corruptions.
+
+ Parameters
+ ----------
+ triples: array-like, shape (n, k)
+ Batch of input embeddings.
+ ent_matrix: array-like, shape (m, k)
+ Slice of embedding matrix (corruptions).
+
+ Returns
+ -------
+ scores: tf.Tensor, shape (n,1)
+ Scores of subject corruptions (corruptions defined by `ent_embs` matrix).
+ """
+ # HolE scoring is 2/k * complex_score
+ return (2 / (self.internal_k / 2)) * (
+ super()._get_subject_corruption_scores(triples, ent_matrix)
+ )
+
+ def _get_object_corruption_scores(self, triples, ent_matrix):
+ """Compute object corruption scores.
+
+ Evaluate the inputs against object corruptions and scores of the corruptions.
+
+ Parameters
+ ----------
+ triples: array-like, shape (n, k)
+ Batch of input embeddings.
+ ent_matrix: array-like, shape (m, k)
+ Slice of embedding matrix (corruptions).
+
+ Returns
+ -------
+ scores: tf.Tensor, shape (n,1)
+ Scores of object corruptions (corruptions defined by `ent_embs` matrix).
+ """
+ # HolE scoring is 2/k * complex_score
+ return (2 / (self.internal_k / 2)) * (
+ super()._get_object_corruption_scores(triples, ent_matrix)
+ )
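Since `ComplEx.__init__` sets `internal_k = 2 * k`, the factor `2 / (self.internal_k / 2)` used above is exactly the `2/k` of the docstring formula; a quick check (not part of this changeset):

```python
k = 100                          # user-facing embedding size passed to HolE(k)
internal_k = 2 * k               # set by ComplEx.__init__: real + imaginary parts
scaling = 2 / (internal_k / 2)   # the factor applied by the three methods above
assert scaling == 2 / k          # i.e., exactly the 2/k of the HolE formula
```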
diff --git a/ampligraph/latent_features/layers/scoring/Random.py b/ampligraph/latent_features/layers/scoring/Random.py
new file mode 100644
index 00000000..67052893
--- /dev/null
+++ b/ampligraph/latent_features/layers/scoring/Random.py
@@ -0,0 +1,82 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+import tensorflow as tf
+from .AbstractScoringLayer import register_layer, AbstractScoringLayer
+
+
+@register_layer("Random")
+class Random(AbstractScoringLayer):
+ r"""Random scoring layer."""
+
+ def get_config(self):
+ config = super(Random, self).get_config()
+ return config
+
+ def __init__(self, k):
+ super(Random, self).__init__(k)
+
+ def _compute_scores(self, triples):
+        """Compute random scores for the input triples.
+
+ Parameters
+ ----------
+ triples: array-like, shape (n, 3)
+ Batch of input triples.
+
+ Returns
+ -------
+ scores: tf.Tensor, shape (n,1)
+ Tensor of scores of inputs.
+ """
+
+ scores = tf.random.uniform(shape=[tf.shape(triples[0])[0]], seed=0)
+ return scores
+
+ def _get_subject_corruption_scores(self, triples, ent_matrix):
+ """Compute subject corruption scores.
+
+ Evaluate the inputs against subject corruptions and scores of the corruptions.
+
+ Parameters
+ ----------
+ triples: array-like, shape (n, k)
+ Batch of input embeddings.
+ ent_matrix: array-like, shape (m, k)
+ Slice of embedding matrix (corruptions).
+
+ Returns
+ -------
+ scores: tf.Tensor, shape (n, 1)
+ Scores of subject corruptions (corruptions defined by `ent_embs` matrix).
+ """
+ scores = tf.random.uniform(
+ shape=[tf.shape(triples[0])[0], tf.shape(ent_matrix)[0]], seed=0
+ )
+ return scores
+
+ def _get_object_corruption_scores(self, triples, ent_matrix):
+ """Compute object corruption scores.
+
+ Evaluate the inputs against object corruptions and scores of the corruptions.
+
+ Parameters
+ ----------
+ triples: array-like, shape (n, k)
+ Batch of input embeddings.
+ ent_matrix: array-like, shape (m, k)
+ Slice of embedding matrix (corruptions).
+
+ Returns
+ -------
+ scores: tf.Tensor, shape (n, 1)
+ Scores of object corruptions (corruptions defined by `ent_embs` matrix).
+ """
+ scores = tf.random.uniform(
+ shape=[tf.shape(triples[0])[0], tf.shape(ent_matrix)[0]], seed=0
+ )
+ return scores
diff --git a/ampligraph/latent_features/layers/scoring/TransE.py b/ampligraph/latent_features/layers/scoring/TransE.py
new file mode 100644
index 00000000..5e8ce716
--- /dev/null
+++ b/ampligraph/latent_features/layers/scoring/TransE.py
@@ -0,0 +1,114 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+import tensorflow as tf
+from .AbstractScoringLayer import register_layer, AbstractScoringLayer
+
+
+@register_layer("TransE")
+class TransE(AbstractScoringLayer):
+ r"""Translating Embeddings (TransE) scoring layer.
+
+ The model as described in :cite:`bordes2013translating`.
+
+ The scoring function of TransE computes a similarity between the embedding of the subject
+ :math:`\mathbf{e}_{sub}` translated by the embedding of the predicate :math:`\mathbf{e}_{pred}`,
+ and the embedding of the object :math:`\mathbf{e}_{obj}`,
+ using the :math:`L_1` or :math:`L_2` norm :math:`||\cdot||` (default: :math:`L_1`):
+
+ .. math::
+ f_{TransE}=-||\mathbf{e}_{sub} + \mathbf{e}_{pred} - \mathbf{e}_{obj}||
+
+ Such scoring function is then used on positive and negative triples :math:`t^+, t^-` in the loss function.
+
+ """
+
+ def get_config(self):
+ config = super(TransE, self).get_config()
+ return config
+
+ def __init__(self, k):
+ super(TransE, self).__init__(k)
+
+ def _compute_scores(self, triples):
+        """Compute scores using the TransE scoring function.
+
+ Parameters
+ ----------
+        triples: array-like, shape (n, 3)
+ Batch of input triples.
+
+ Returns
+ -------
+ scores: tf.Tensor, shape (n,1)
+ Tensor of scores of inputs.
+ """
+ # compute scores as -|| s + p - o||
+ scores = tf.negative(
+ tf.norm(triples[0] + triples[1] - triples[2], axis=1, ord=1)
+ )
+ return scores
+
+ def _get_subject_corruption_scores(self, triples, ent_matrix):
+ """Compute subject corruption scores.
+
+ Evaluate the inputs against subject corruptions and scores of the corruptions.
+
+ Parameters
+ ----------
+ triples: array-like, shape (n, k)
+ Batch of input embeddings.
+ ent_matrix: array-like, shape (m, k)
+ Slice of embedding matrix (corruptions).
+
+ Returns
+ -------
+ scores: tf.Tensor, shape (n, 1)
+ Scores of subject corruptions (corruptions defined by `ent_embs` matrix).
+ """
+ # get the subject, predicate and object embeddings of True positives
+ rel_emb, obj_emb = triples[1], triples[2]
+ # compute the score by broadcasting the corruption embeddings(ent_matrix) and using the scoring function
+ # compute scores as -|| s_corr + p - o||
+ sub_corr_score = tf.negative(
+ tf.norm(
+ ent_matrix + tf.expand_dims(rel_emb - obj_emb, 1),
+ axis=2,
+ ord=1,
+ )
+ )
+ return sub_corr_score
+
+ def _get_object_corruption_scores(self, triples, ent_matrix):
+ """Compute object corruption scores.
+
+ Evaluate the inputs against object corruptions and scores of the corruptions.
+
+ Parameters
+ ----------
+ triples: array-like, shape (n, k)
+ Batch of input embeddings.
+ ent_matrix: array-like, shape (m, k)
+ Slice of embedding matrix (corruptions).
+
+ Returns
+ -------
+ scores: tf.Tensor, shape (n, 1)
+ Scores of object corruptions (corruptions defined by `ent_embs` matrix).
+ """
+ # get the subject, predicate and object embeddings of True positives:
+ sub_emb, rel_emb = triples[0], triples[1]
+ # compute the score by broadcasting the corruption embeddings(ent_matrix) and using the scoring function
+ # compute scores as -|| s + p - o_corr||
+ obj_corr_score = tf.negative(
+ tf.norm(
+ tf.expand_dims(sub_emb + rel_emb, 1) - ent_matrix,
+ axis=2,
+ ord=1,
+ )
+ )
+ return obj_corr_score
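The subject-corruption score regroups the translation as `s' + (p - o)` so that all candidate subjects can be scored in one broadcast. A standalone spot-check (not part of this changeset) against the plain `-||s + p - o||_1` formula:

```python
import tensorflow as tf

tf.random.set_seed(0)
n, m, k = 2, 4, 3                              # 2 triples, 4 candidate subjects, k = 3
rel = tf.random.normal((n, k))
obj = tf.random.normal((n, k))
ent_matrix = tf.random.normal((m, k))          # candidate subject embeddings

# broadcast form used in _get_subject_corruption_scores: -|| s' + (p - o) ||_1 for all s'
corr_scores = -tf.norm(ent_matrix + tf.expand_dims(rel - obj, 1), axis=2, ord=1)

# spot-check one (triple, candidate) pair against the plain scoring formula
i, j = 1, 2
direct = -tf.norm(ent_matrix[j] + rel[i] - obj[i], ord=1)
tf.debugging.assert_near(corr_scores[i, j], direct)
```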
diff --git a/ampligraph/latent_features/layers/scoring/__init__.py b/ampligraph/latent_features/layers/scoring/__init__.py
new file mode 100644
index 00000000..3a093975
--- /dev/null
+++ b/ampligraph/latent_features/layers/scoring/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+from .TransE import TransE
+from .DistMult import DistMult
+from .HolE import HolE
+from .ComplEx import ComplEx
+from .Random import Random
+
+__all__ = ["TransE", "DistMult", "HolE", "ComplEx", "Random"]
diff --git a/ampligraph/latent_features/loss_functions.py b/ampligraph/latent_features/loss_functions.py
index e1f9f460..943dc8a4 100644
--- a/ampligraph/latent_features/loss_functions.py
+++ b/ampligraph/latent_features/loss_functions.py
@@ -1,14 +1,19 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-import tensorflow as tf
import abc
import logging
+import six
+import tensorflow as tf
+from tensorflow.python.keras import metrics as metrics_mod
+from tensorflow.python.keras.utils import losses_utils
+from tensorflow.python.ops import math_ops
+
LOSS_REGISTRY = {}
logger = logging.getLogger(__name__)
@@ -35,110 +40,111 @@
# Default label weighting for ConvE
DEFAULT_LABEL_WEIGHTING = False
+# default reduction of corruption loss per sample
+DEFAULT_REDUCTION = "sum"
-def register_loss(name, external_params=None, class_params=None):
+
+def register_loss(name, external_params=None):
if external_params is None:
external_params = []
- default_class_params = {'require_same_size_pos_neg': True}
-
- if class_params is None:
- class_params = default_class_params
-
- def populate_class_params():
- LOSS_REGISTRY[name].class_params = {
- 'require_same_size_pos_neg': class_params.get('require_same_size_pos_neg',
- default_class_params['require_same_size_pos_neg'])
- }
-
def insert_in_registry(class_handle):
LOSS_REGISTRY[name] = class_handle
class_handle.name = name
LOSS_REGISTRY[name].external_params = external_params
- populate_class_params()
return class_handle
return insert_in_registry
def clip_before_exp(value):
- """Clip the value for stability of exponential
- """
- return tf.clip_by_value(value,
- clip_value_min=DEFAULT_CLIP_EXP_LOWER,
- clip_value_max=DEFAULT_CLIP_EXP_UPPER)
+ """Clip the value for stability of exponential."""
+ return tf.clip_by_value(
+ value,
+ clip_value_min=DEFAULT_CLIP_EXP_LOWER,
+ clip_value_max=DEFAULT_CLIP_EXP_UPPER,
+ )
class Loss(abc.ABC):
- """Abstract class for loss function.
- """
+ """Abstract class for the loss function."""
name = ""
external_params = []
class_params = {}
- def __init__(self, eta, hyperparam_dict, verbose=False):
- """Initialize Loss.
+ def __init__(self, hyperparam_dict={}, verbose=False):
+        """Initialize the loss.
Parameters
----------
- eta: int
- number of negatives
hyperparam_dict : dict
- dictionary of hyperparams.
- (Keys are described in the hyperparameters section)
+ Dictionary of hyperparams.
+
+ - `"reduction"`: (str) - Specifies whether to `"sum"` or take the `"mean"` of loss per sample w.r.t. \
+ corruptions (default: `"sum"`).
+
+ Other Keys are described in the `hyperparameters` section.
"""
self._loss_parameters = {}
+ self._loss_parameters["reduction"] = hyperparam_dict.get(
+ "reduction", DEFAULT_REDUCTION
+ )
+ assert self._loss_parameters["reduction"] in [
+ "sum",
+ "mean",
+ ], "Invalid value for reduction!"
self._dependencies = []
+ self._user_losses = self.name
+ self._user_loss_weights = None
+
+ self._loss_metric = metrics_mod.Mean(name="loss") # Total loss.
- # perform check to see if all the required external hyperparams are passed
+ # perform check to see if all the required external hyperparams are
+ # passed
try:
- self._loss_parameters['eta'] = eta
self._init_hyperparams(hyperparam_dict)
if verbose:
- logger.info('\n--------- Loss ---------')
- logger.info('Name : {}'.format(self.name))
+ logger.info("\n--------- Loss ---------")
+ logger.info("Name : {}".format(self.name))
for key, value in self._loss_parameters.items():
- logger.info('{} : {}'.format(key, value))
+ logger.info("{} : {}".format(key, value))
except KeyError as e:
- msg = 'Some of the hyperparams for loss were not passed to the loss function.\n{}'.format(e)
+ msg = "Some of the hyperparams for loss were not passed to the loss function.\n{}".format(
+ e
+ )
logger.error(msg)
raise Exception(msg)
- def get_state(self, param_name):
- """Get the state value.
+ @property
+ def metrics(self):
+ """Per-output loss metrics."""
+ return [self._loss_metric]
- Parameters
- ----------
- param_name : string
- Name of the state for which one wants to query the value.
- Returns
- -------
- param_value:
- The value of the corresponding state.
- """
- try:
- param_value = LOSS_REGISTRY[self.name].class_params.get(param_name)
- return param_value
- except KeyError as e:
- msg = 'Invalid Keu.\n{}'.format(e)
- logger.error(msg)
- raise Exception(msg)
+ def _reduce_sample_loss(self, loss):
+ """Aggregates the loss of each sample either by adding or taking the mean w.r.t. the number of corruptions."""
+ if self._loss_parameters["reduction"] == "sum":
+ return tf.reduce_sum(loss, 0)
+ else:
+ return tf.reduce_mean(loss, 0)
def _init_hyperparams(self, hyperparam_dict):
"""Initializes the hyperparameters needed by the algorithm.
Parameters
----------
- hyperparam_dict : dictionary
- Consists of key value pairs. The Loss will check the keys to get the corresponding params.
+ hyperparam_dict : dict
+ The Loss will check the keys to get the corresponding parameters.
"""
- msg = 'This function is a placeholder in an abstract class'
+ msg = "This function is a placeholder in an abstract class."
logger.error(msg)
- NotImplementedError(msg)
+ raise NotImplementedError(msg)
- def _inputs_check(self, scores_pos, scores_neg):
- """Creates any dependencies that need to be checked before performing loss computations
+ @tf.function(experimental_relax_shapes=True)
+ def _apply_loss(self, scores_pos, scores_neg):
+ """Interface to the external world.
+
+ This function does the input checks, preprocesses input and finally applies loss function.
Parameters
----------
@@ -146,36 +152,41 @@ def _inputs_check(self, scores_pos, scores_neg):
A tensor of scores assigned to positive statements.
scores_neg : tf.Tensor
A tensor of scores assigned to negative statements.
+
+ Returns
+ -------
+ loss : tf.Tensor
+ The loss value that must be minimized.
"""
- logger.debug('Creating dependencies before loss computations.')
- self._dependencies = []
- if LOSS_REGISTRY[self.name].class_params['require_same_size_pos_neg'] and self._loss_parameters['eta'] != 1:
- logger.debug('Dependencies found: \n\tRequired same size positive and negative. \n\tEta is not 1.')
- self._dependencies.append(tf.Assert(tf.equal(tf.shape(scores_pos)[0], tf.shape(scores_neg)[0]),
- [tf.shape(scores_pos)[0], tf.shape(scores_neg)[0]]))
+ msg = "This function is a placeholder in an abstract class."
+ logger.error(msg)
+ raise NotImplementedError(msg)
- def _apply(self, scores_pos, scores_neg):
- """Apply the loss function. Every inherited class must implement this function.
- (All the TF code must go in this function.)
+ def _broadcast_score_pos(self, scores_pos, eta):
+ """Broadcast the ``score_pos`` to be of size equal to the number of corruptions.
Parameters
----------
scores_pos : tf.Tensor
A tensor of scores assigned to positive statements.
- scores_neg : tf.Tensor
- A tensor of scores assigned to negative statements.
+ eta : tf.Tensor
+ Number of corruptions.
Returns
-------
- loss : tf.Tensor
- The loss value that must be minimized.
+ scores_pos : tf.Tensor
+ Broadcasted `score_pos`.
"""
- msg = 'This function is a placeholder in an abstract class.'
- logger.error(msg)
- NotImplementedError(msg)
-
- def apply(self, scores_pos, scores_neg):
+ scores_pos = tf.reshape(
+ tf.tile(scores_pos, [eta]), [eta, tf.shape(scores_pos)[0]]
+ )
+ return scores_pos
+
+ def __call__(
+ self, scores_pos, scores_neg, eta, regularization_losses=None
+ ):
"""Interface to external world.
+
This function does the input checks, preprocesses input and finally applies loss function.
Parameters
@@ -184,19 +195,37 @@ def apply(self, scores_pos, scores_neg):
A tensor of scores assigned to positive statements.
scores_neg : tf.Tensor
A tensor of scores assigned to negative statements.
+ eta: tf.Tensor
+ Number of synthetic corruptions per positive.
+ regularization_losses: list
+ List of all regularization related losses defined in the layers.
Returns
-------
loss : tf.Tensor
The loss value that must be minimized.
"""
- self._inputs_check(scores_pos, scores_neg)
- with tf.control_dependencies(self._dependencies):
- loss = self._apply(scores_pos, scores_neg)
- return loss
+ loss_values = []
+
+ scores_neg = tf.reshape(scores_neg, [eta, -1])
-@register_loss("pairwise", ['margin'])
+ loss = self._apply_loss(scores_pos, scores_neg)
+ loss_values.append(tf.reduce_sum(loss))
+ if regularization_losses:
+ regularization_losses = losses_utils.cast_losses_to_common_dtype(
+ regularization_losses
+ )
+ reg_loss = math_ops.add_n(regularization_losses)
+ loss_values.append(reg_loss)
+
+ loss_values = losses_utils.cast_losses_to_common_dtype(loss_values)
+ total_loss = math_ops.add_n(loss_values)
+ self._loss_metric.update_state(total_loss)
+ return total_loss
+
+
+@register_loss("pairwise", ["margin"])
class PairwiseLoss(Loss):
r"""Pairwise, max-margin loss.
@@ -210,46 +239,58 @@ class PairwiseLoss(Loss):
where :math:`\gamma` is the margin, :math:`\mathcal{G}` is the set of positives,
:math:`\mathcal{C}` is the set of corruptions, :math:`f_{model}(t;\Theta)` is the model-specific scoring function.
+ Example
+ -------
+ >>> import ampligraph.latent_features.loss_functions as lfs
+ >>> loss = lfs.PairwiseLoss({'margin': 0.005, 'reduction': 'sum'})
+ >>> isinstance(loss, lfs.PairwiseLoss)
+ True
+
+ >>> loss = lfs.get('pairwise')
+ >>> isinstance(loss, lfs.PairwiseLoss)
+ True
+
"""
- def __init__(self, eta, loss_params=None, verbose=False):
- """Initialize Loss.
+ def __init__(self, loss_params={}, verbose=False):
+ """Initialize the loss.
Parameters
----------
- eta: int
- Number of negatives.
loss_params : dict
Dictionary of loss-specific hyperparams:
- - **'margin'**: (float). Margin to be used in pairwise loss computation (default: 1)
+ - `"margin"`: (float) - Margin to be used in pairwise loss computation (default: 1).
+ - `"reduction"`: (str) - Specifies whether to `"sum"` or take the `"mean"` of loss per sample \
+ w.r.t. corruptions (default: `"sum"`).
- Example: ``loss_params={'margin': 1}``
+ Example: `loss_params={'margin': 1}`.
"""
- if loss_params is None:
- loss_params = {'margin': DEFAULT_MARGIN}
- super().__init__(eta, loss_params, verbose)
+ super().__init__(loss_params, verbose)
- def _init_hyperparams(self, hyperparam_dict):
+ def _init_hyperparams(self, hyperparam_dict={}):
"""Verifies and stores the hyperparameters needed by the algorithm.
Parameters
----------
- hyperparam_dict : dictionary
- Consists of key value pairs. The Loss will check the keys to get the corresponding params
+ hyperparam_dict : dict
+ The Loss will check the keys to get the corresponding parameter.
- - **margin** - Margin to be used in pairwise loss computation(default:1)
+ - `"margin"`: (str) - Margin to be used in pairwise loss computation (default: 1).
"""
- self._loss_parameters['margin'] = hyperparam_dict.get('margin', DEFAULT_MARGIN)
+ self._loss_parameters["margin"] = hyperparam_dict.get(
+ "margin", DEFAULT_MARGIN
+ )
- def _apply(self, scores_pos, scores_neg):
+ @tf.function(experimental_relax_shapes=True)
+ def _apply_loss(self, scores_pos, scores_neg):
"""Apply the loss function.
Parameters
----------
- scores_pos : tf.Tensor, shape [n, 1]
+ scores_pos : tf.Tensor, shape (n, 1)
A tensor of scores assigned to positive statements.
- scores_neg : tf.Tensor, shape [n, 1]
+ scores_neg : tf.Tensor, shape (n, 1)
A tensor of scores assigned to negative statements.
Returns
@@ -258,14 +299,18 @@ def _apply(self, scores_pos, scores_neg):
The loss value that must be minimized.
"""
- margin = tf.constant(self._loss_parameters['margin'], dtype=tf.float32, name='margin')
- loss = tf.reduce_sum(tf.maximum(margin - scores_pos + scores_neg, 0))
+ margin = tf.constant(
+ self._loss_parameters["margin"], dtype=tf.float32, name="margin"
+ )
+ loss = self._reduce_sample_loss(
+ tf.maximum(margin - scores_pos + scores_neg, 0)
+ )
return loss
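To make the new `(eta, n)` layout concrete: `__call__` reshapes the flat negative scores into one row per corruption, `_broadcast_score_pos` tiles the positives to match, and `_reduce_sample_loss` reduces over the corruption axis. A standalone sketch (not part of this changeset) with hypothetical scores and the pairwise margin term, assuming the flat negatives arrive as one block of n scores per corruption:

```python
import tensorflow as tf

n, eta = 3, 2                                  # 3 positives, 2 corruptions each (hypothetical)
scores_pos = tf.constant([1.0, 2.0, 3.0])
# flat negative scores, assumed ordering: one block of n scores per corruption
scores_neg_flat = tf.constant([0.1, 0.2, 0.3,  # corruption 1 of positives 1..3
                               0.4, 0.5, 0.6]) # corruption 2 of positives 1..3

scores_neg = tf.reshape(scores_neg_flat, [eta, -1])   # (eta, n) = (2, 3)

# _broadcast_score_pos: tile the positives so they line up with every corruption row
scores_pos_b = tf.reshape(tf.tile(scores_pos, [eta]), [eta, tf.shape(scores_pos)[0]])

# pairwise margin term with the default "sum" reduction over the corruption axis
margin = 1.0
per_sample = tf.reduce_sum(tf.maximum(margin - scores_pos_b + scores_neg, 0), 0)
print(per_sample.numpy())   # one loss contribution per positive triple
```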
@register_loss("nll")
class NLLLoss(Loss):
- r"""Negative log-likelihood loss.
+ r"""Negative Log-Likelihood loss.
As described in :cite:`trouillon2016complex`.
@@ -273,43 +318,53 @@ class NLLLoss(Loss):
\mathcal{L}(\Theta) = \sum_{t \in \mathcal{G} \cup \mathcal{C}}log(1 + exp(-y \, f_{model}(t;\Theta)))
- where :math:`y \in {-1, 1}` is the label of the statement, :math:`\mathcal{G}` is the set of positives,
- :math:`\mathcal{C}` is the set of corruptions, :math:`f_{model}(t;\Theta)` is the model-specific scoring function.
+ where :math:`y \in \{-1, 1\}` is the label of the statement, :math:`\mathcal{G}` is the set of positives,
+ :math:`\mathcal{C}` is the set of corruptions and :math:`f_{model}(t;\Theta)` is the model-specific scoring function.
+
+ Example
+ -------
+ >>> import ampligraph.latent_features.loss_functions as lfs
+ >>> loss = lfs.NLLLoss({'reduction': 'mean'})
+ >>> isinstance(loss, lfs.NLLLoss)
+ True
+ >>> loss = lfs.get('nll')
+ >>> isinstance(loss, lfs.NLLLoss)
+ True
"""
- def __init__(self, eta, loss_params=None, verbose=False):
- """Initialize Loss.
+ def __init__(self, loss_params={}, verbose=False):
+        """Initialize the loss.
Parameters
----------
- eta: int
- Number of negatives.
loss_params : dict
- Dictionary of hyperparams. No hyperparameters are required for this loss.
+ Dictionary of hyperparams. No hyperparameters are required for this loss except for `"reduction"`.
+
+ - `"reduction"`: (str) - Specifies whether to `"sum"` or take `"mean"` of loss per sample w.r.t. \
+ corruption (default:`"sum"`).
"""
- if loss_params is None:
- loss_params = {}
- super().__init__(eta, loss_params, verbose)
+ super().__init__(loss_params, verbose)
- def _init_hyperparams(self, hyperparam_dict):
+ def _init_hyperparams(self, hyperparam_dict={}):
"""Initializes the hyperparameters needed by the algorithm.
Parameters
----------
hyperparam_dict : dictionary
- Consists of key value pairs. The Loss will check the keys to get the corresponding params.
+ The Loss will check the keys to get the corresponding parameters.
"""
return
- def _apply(self, scores_pos, scores_neg):
+ @tf.function(experimental_relax_shapes=True)
+ def _apply_loss(self, scores_pos, scores_neg):
"""Apply the loss function.
Parameters
----------
- scores_pos : tf.Tensor, shape [n, 1]
+ scores_pos : tf.Tensor, shape (n, 1)
A tensor of scores assigned to positive statements.
- scores_neg : tf.Tensor, shape [n, 1]
+ scores_neg : tf.Tensor, shape (n, 1)
A tensor of scores assigned to negative statements.
Returns
@@ -320,67 +375,78 @@ def _apply(self, scores_pos, scores_neg):
"""
scores_neg = clip_before_exp(scores_neg)
scores_pos = clip_before_exp(scores_pos)
+
+ scores_pos = self._broadcast_score_pos(scores_pos, scores_neg.shape[0])
+
scores = tf.concat([-scores_pos, scores_neg], 0)
- return tf.reduce_sum(tf.log(1 + tf.exp(scores)))
+ return self._reduce_sample_loss(tf.math.log(1 + tf.exp(scores)))
-@register_loss("absolute_margin", ['margin'])
+@register_loss("absolute_margin", ["margin"])
class AbsoluteMarginLoss(Loss):
- r"""Absolute margin , max-margin loss.
+ r"""Absolute margin, max-margin loss.
Introduced in :cite:`Hamaguchi2017`.
.. math::
- \mathcal{L}(\Theta) = \sum_{t^+ \in \mathcal{G}}\sum_{t^- \in \mathcal{C}} f_{model}(t^-;\Theta)
- - max(0, [\gamma - f_{model}(t^+;\Theta)])
+ \mathcal{L}(\Theta) = \sum_{t^+ \in \mathcal{G}}\sum_{t^- \in \mathcal{C}}
+ max(0, [\gamma - f_{model}(t^-;\Theta)]) - f_{model}(t^+;\Theta)
where :math:`\gamma` is the margin, :math:`\mathcal{G}` is the set of positives, :math:`\mathcal{C}` is the
set of corruptions, :math:`f_{model}(t;\Theta)` is the model-specific scoring function.
+ Example
+ -------
+ >>> import ampligraph.latent_features.loss_functions as lfs
+ >>> loss = lfs.AbsoluteMarginLoss({'margin': 1, 'reduction': 'mean'})
+ >>> isinstance(loss, lfs.AbsoluteMarginLoss)
+ True
+
+ >>> loss = lfs.get('absolute_margin')
+ >>> isinstance(loss, lfs.AbsoluteMarginLoss)
+ True
"""
- def __init__(self, eta, loss_params=None, verbose=False):
- """Initialize Loss
+ def __init__(self, loss_params={}, verbose=False):
+ """Initialize the loss.
Parameters
----------
- eta: int
- Number of negatives.
loss_params : dict
Dictionary of loss-specific hyperparams:
- - **'margin'**: float. Margin to be used in pairwise loss computation (default:1)
+ - `"margin"`: (float) - Margin to be used in pairwise loss computation (default: 1).
+ - `"reduction"`: (str) - Specifies whether to `"sum"` or take `"mean"` of loss per sample w.r.t.\
+ corruption (default: `"sum"`).
- Example: ``loss_params={'margin': 1}``
+ Example: ``loss_params={'margin': 1}``.
"""
- if loss_params is None:
- loss_params = {'margin': DEFAULT_MARGIN}
- super().__init__(eta, loss_params, verbose)
+ super().__init__(loss_params, verbose)
- def _init_hyperparams(self, hyperparam_dict):
+ def _init_hyperparams(self, hyperparam_dict={}):
"""Initializes the hyperparameters needed by the algorithm.
Parameters
----------
hyperparam_dict : dict
- Consists of key value pairs. The Loss will check the keys to get the corresponding params.
-
- **margin** - Margin to be used in loss computation(default:1)
+ The Loss will check the keys to get the corresponding params.
- Returns
- -------
+ `"margin"`: (str) - Margin to be used in loss computation (default: 1).
"""
- self._loss_parameters['margin'] = hyperparam_dict.get('margin', DEFAULT_MARGIN)
+ self._loss_parameters["margin"] = hyperparam_dict.get(
+ "margin", DEFAULT_MARGIN
+ )
- def _apply(self, scores_pos, scores_neg):
+ @tf.function(experimental_relax_shapes=True)
+ def _apply_loss(self, scores_pos, scores_neg):
"""Apply the loss function.
Parameters
----------
- scores_pos : tf.Tensor, shape [n, 1]
+ scores_pos : tf.Tensor, shape (n, 1)
A tensor of scores assigned to positive statements.
- scores_neg : tf.Tensor, shape [n, 1]
+ scores_neg : tf.Tensor, shape (n, 1)
A tensor of scores assigned to negative statements.
Returns
@@ -389,318 +455,312 @@ def _apply(self, scores_pos, scores_neg):
The loss value that must be minimized.
"""
- margin = tf.constant(self._loss_parameters['margin'], dtype=tf.float32, name='margin')
- loss = tf.reduce_sum(tf.maximum(margin + scores_neg, 0) - scores_pos)
+ margin = tf.constant(
+ self._loss_parameters["margin"], dtype=tf.float32, name="margin"
+ )
+ loss = self._reduce_sample_loss(
+ tf.maximum(margin + scores_neg, 0) - scores_pos
+ )
return loss
-@register_loss("self_adversarial", ['margin', 'alpha'], {'require_same_size_pos_neg': False})
+@register_loss("self_adversarial", ["margin", "alpha"])
class SelfAdversarialLoss(Loss):
- r"""Self adversarial sampling loss.
+ r"""Self Adversarial Sampling loss.
Introduced in :cite:`sun2018rotate`.
.. math::
- \mathcal{L} = -log\, \sigma(\gamma + f_{model} (\mathbf{s},\mathbf{o}))
- - \sum_{i=1}^{n} p(h_{i}^{'}, r, t_{i}^{'} ) \ log \
- \sigma(-f_{model}(\mathbf{s}_{i}^{'},\mathbf{o}_{i}^{'}) - \gamma)
+ \mathcal{L} = -log \left( \sigma(\gamma + f_{model} (\mathbf{s},\mathbf{o})) \right)
+ - \sum_{i=1}^{n} p(h'_{i}, r, t'_{i} ) \cdot log
+ \left( \sigma(-f_{model}(\mathbf{s}'_{i},\mathbf{o}'_{i}) - \gamma) \right)
where :math:`\mathbf{s}, \mathbf{o} \in \mathcal{R}^k` are the embeddings of the subject
- and object of a triple :math:`t=(s,r,o)`, :math:`\gamma` is the margin, :math:`\sigma` the sigmoid function,
- and :math:`p(s_{i}^{'}, r, o_{i}^{'} )` is the negatives sampling distribution which is defined as:
+ and object of a triple :math:`t=(s,r,o)`, :math:`\gamma \in \mathbb{R}` is the margin, :math:`\sigma` the sigmoid
+ function, and :math:`p(s'_{i}, r, o'_{i})` is the negatives sampling distribution which is defined as:
.. math::
- p(s'_j, r, o'_j | \{(s_i, r_i, o_i)\}) = \frac{\exp \alpha \, f_{model}(\mathbf{s'_j}, \mathbf{o'_j})}
- {\sum_i \exp \alpha \, f_{model}(\mathbf{s'_i}, \mathbf{o'_i})}
+ p(s'_j, r, o'_j | \{(s_i, r_i, o_i)\}) = \frac{\exp \left( \alpha \, f_{model}(\mathbf{s'_j}, \mathbf{o'_j}) \right)}
+ {\sum_i \exp \left( \alpha \, f_{model}(\mathbf{s'_i}, \mathbf{o'_i}) \right)}
- where :math:`\alpha` is the temperature of sampling, :math:`f_{model}` is the scoring function of
- the desired embeddings model.
+ where :math:`\alpha` is the temperature of sampling and :math:`f_{model}` is the scoring function of
+ the desired embedding model.
+ Example
+ -------
+ >>> import ampligraph.latent_features.loss_functions as lfs
+ >>> loss = lfs.SelfAdversarialLoss({'margin': 1, 'alpha': 0.1, 'reduction': 'mean'})
+ >>> isinstance(loss, lfs.SelfAdversarialLoss)
+ True
+ >>> loss = lfs.get('self_adversarial')
+ >>> isinstance(loss, lfs.SelfAdversarialLoss)
+ True
"""
- def __init__(self, eta, loss_params=None, verbose=False):
- """Initialize Loss
+ def __init__(self, loss_params={}, verbose=False):
+ """Initialize the loss.
Parameters
----------
- eta: int
- number of negatives
loss_params : dict
Dictionary of loss-specific hyperparams:
- - **'margin'**: (float). Margin to be used for loss computation (default: 1)
- - **'alpha'** : (float). Temperature of sampling (default:0.5)
+ - `"margin"`: (float) - Margin to be used for loss computation (default: 1).
+ - `"alpha"`: (float) - Temperature of sampling (default: 0.5).
+ - `"reduction"`: (str) - Specifies whether to `"sum"` or take the `"mean"` of the loss per sample w.r.t. \
+ corruption (default: `"sum"`).
- Example: ``loss_params={'margin': 1, 'alpha': 0.5}``
+ Example: `loss_params={'margin': 1, 'alpha': 0.5}`.
"""
- if loss_params is None:
- loss_params = {'margin': DEFAULT_MARGIN_ADVERSARIAL, 'alpha': DEFAULT_ALPHA_ADVERSARIAL}
- super().__init__(eta, loss_params, verbose)
+ super().__init__(loss_params, verbose)
- def _init_hyperparams(self, hyperparam_dict):
+ def _init_hyperparams(self, hyperparam_dict={}):
"""Initializes the hyperparameters needed by the algorithm.
Parameters
----------
- hyperparam_dict : dictionary
- Consists of key value pairs. The Loss will check the keys to get the corresponding params
-
- - **margin** - Margin to be used in adversarial loss computation (default:3)
+ hyperparam_dict : dict
+ The Loss will check the keys to get the corresponding parameters.
- - **alpha** - Temperature of sampling (default:0.5)
+        - `"margin"`: (int) - Margin to be used in adversarial loss computation (default: 3).
+ - `"alpha"`: (float) - Temperature of sampling (default: 0.5).
"""
- self._loss_parameters['margin'] = hyperparam_dict.get('margin', DEFAULT_MARGIN_ADVERSARIAL)
- self._loss_parameters['alpha'] = hyperparam_dict.get('alpha', DEFAULT_ALPHA_ADVERSARIAL)
-
- def _apply(self, scores_pos, scores_neg):
+ self._loss_parameters["margin"] = hyperparam_dict.get(
+ "margin", DEFAULT_MARGIN_ADVERSARIAL
+ )
+ self._loss_parameters["alpha"] = hyperparam_dict.get(
+ "alpha", DEFAULT_ALPHA_ADVERSARIAL
+ )
+
+ @tf.function(experimental_relax_shapes=True)
+ def _apply_loss(self, scores_pos, scores_neg):
"""Apply the loss function.
- Parameters
- ----------
- scores_pos : tf.Tensor, shape [n, 1]
- A tensor of scores assigned to positive statements.
- scores_neg : tf.Tensor, shape [n*negative_count, 1]
- A tensor of scores assigned to negative statements.
+ Parameters
+ ----------
+ scores_pos : tf.Tensor, shape (n, 1)
+ A tensor of scores assigned to positive statements.
+ scores_neg : tf.Tensor, shape (eta, n)
+ A tensor of scores assigned to negative statements.
- Returns
- -------
- loss : tf.Tensor
- The loss value that must be minimized.
+ Returns
+ -------
+ loss : tf.Tensor
+ The loss value that must be minimized.
- """
- margin = tf.constant(self._loss_parameters['margin'], dtype=tf.float32, name='margin')
- alpha = tf.constant(self._loss_parameters['alpha'], dtype=tf.float32, name='alpha')
+ """
+ margin = tf.constant(
+ self._loss_parameters["margin"], dtype=tf.float32, name="margin"
+ )
+ alpha = tf.constant(
+ self._loss_parameters["alpha"], dtype=tf.float32, name="alpha"
+ )
- # Compute p(neg_samples) based on eq 4
- scores_neg_reshaped = tf.reshape(scores_neg, [self._loss_parameters['eta'], tf.shape(scores_pos)[0]])
- p_neg = tf.nn.softmax(alpha * scores_neg_reshaped, axis=0)
+ p_neg = tf.nn.softmax(alpha * scores_neg, axis=0)
# Compute Loss based on eg 5
- loss = tf.reduce_sum(-tf.log_sigmoid(margin - tf.negative(scores_pos))) - tf.reduce_sum(
- tf.multiply(p_neg, tf.log_sigmoid(tf.negative(scores_neg_reshaped) - margin)))
+ loss = -tf.math.log_sigmoid(
+ margin - tf.negative(scores_pos)
+ ) - self._reduce_sample_loss(
+ tf.multiply(
+ p_neg, tf.math.log_sigmoid(tf.negative(scores_neg) - margin)
+ )
+ )
return loss
-@register_loss("multiclass_nll", [], {'require_same_size_pos_neg': False})
+@register_loss("multiclass_nll", [])
class NLLMulticlass(Loss):
- r"""Multiclass NLL Loss.
+ r"""Multiclass Negative Log-Likelihood loss.
- Introduced in :cite:`chen2015` where both the subject and objects are corrupted (to use it in this way pass
- corrupt_sides = ['s', 'o'] to embedding_model_params) .
+ Introduced in :cite:`chen2015`, this loss can be used when both the subject and objects are corrupted
+ (to use it, pass ``corrupt_sides=['s,o']`` in the embedding model parameters).
This loss was re-engineered in :cite:`kadlecBK17` where only the object was corrupted to get improved
- performance (to use it in this way pass corrupt_sides = 'o' to embedding_model_params).
+ performance (to use it in this way pass ``corrupt_sides ='o'`` in the embedding model parameters).
.. math::
\mathcal{L(X)} = -\sum_{x_{e_1,e_2,r_k} \in X} log\,p(e_2|e_1,r_k)
-\sum_{x_{e_1,e_2,r_k} \in X} log\,p(e_1|r_k, e_2)
- Examples
- --------
- >>> from ampligraph.latent_features import TransE
- >>> model = TransE(batches_count=1, seed=555, epochs=20, k=10,
- >>> embedding_model_params={'corrupt_sides':['s', 'o']},
- >>> loss='multiclass_nll', loss_params={})
+ Example
+ -------
+ >>> import ampligraph.latent_features.loss_functions as lfs
+ >>> loss = lfs.NLLMulticlass({'reduction': 'mean'})
+ >>> isinstance(loss, lfs.NLLMulticlass)
+ True
+
+ >>> loss = lfs.get('multiclass_nll')
+ >>> isinstance(loss, lfs.NLLMulticlass)
+ True
"""
- def __init__(self, eta, loss_params=None, verbose=False):
- """Initialize Loss
+
+ def __init__(self, loss_params={}, verbose=False):
+ """Initialize the loss.
Parameters
----------
- eta: int
- number of negatives
loss_params : dict
Dictionary of loss-specific hyperparams:
+ - `"reduction"`: (str) - Specifies whether to `"sum"` or take the `"mean"` of loss per sample w.r.t. \
+ corruption (default: `"sum"`).
+
"""
- if loss_params is None:
- loss_params = {}
- super().__init__(eta, loss_params, verbose)
+ super().__init__(loss_params, verbose)
- def _init_hyperparams(self, hyperparam_dict):
+ def _init_hyperparams(self, hyperparam_dict={}):
"""Verifies and stores the hyperparameters needed by the algorithm.
Parameters
----------
- hyperparam_dict : dictionary
- Consists of key value pairs. The Loss will check the keys to get the corresponding params
+ hyperparam_dict : dict
+ The Loss will check the keys to get the corresponding parameters.
"""
pass
- def _apply(self, scores_pos, scores_neg):
+ @tf.function(experimental_relax_shapes=True)
+ def _apply_loss(self, scores_pos, scores_neg):
"""Apply the loss function.
- Parameters
- ----------
- scores_pos : tf.Tensor, shape [n, 1]
- A tensor of scores assigned to positive statements.
- scores_neg : tf.Tensor, shape [n*negative_count, 1]
- A tensor of scores assigned to negative statements.
+ Parameters
+ ----------
+ scores_pos : tf.Tensor, shape (n, 1)
+ A tensor of scores assigned to positive statements.
+ scores_neg : tf.Tensor, shape (eta, n)
+ A tensor of scores assigned to negative statements.
- Returns
- -------
- loss : float
- The loss value that must be minimized.
+ Returns
+ -------
+ loss : float
+ The loss value that must be minimized.
- """
+ """
# Fix for numerical instability of multiclass loss
scores_pos = clip_before_exp(scores_pos)
scores_neg = clip_before_exp(scores_neg)
- scores_neg_reshaped = tf.reshape(scores_neg, [self._loss_parameters['eta'], tf.shape(scores_pos)[0]])
- neg_exp = tf.exp(scores_neg_reshaped)
+ neg_exp = tf.exp(scores_neg)
pos_exp = tf.exp(scores_pos)
- softmax_score = pos_exp / (tf.reduce_sum(neg_exp, axis=0) + pos_exp)
-
- loss = -tf.reduce_sum(tf.log(softmax_score))
+ softmax_score = pos_exp / (self._reduce_sample_loss(neg_exp) + pos_exp)
+ loss = -tf.math.log(softmax_score)
return loss
-@register_loss('bce', ['label_smoothing', 'label_weighting'], {'require_same_size_pos_neg': False})
-class BCELoss(Loss):
- r""" Binary Cross Entropy Loss.
-
- .. math::
-
- \mathcal{L} = - \frac{1}{N} \sum_{i=1}^{N} y_i \cdot log(p(y_i)) + (1-y_i) \cdot log(1-p(y_i))
-
- Examples
- --------
- >>> from ampligraph.latent_features.models import ConvE
- >>> model = ConvE(batches_count=1, seed=555, epochs=20, k=10, loss='bce', loss_params={})
+class LossFunctionWrapper(Loss):
+ """Wraps a loss function in the `Loss` class.
+
+ Example
+ -------
+ >>> import ampligraph.latent_features.loss_functions as lfs
+ >>> def user_defined_loss(scores_pos, scores_neg):
+ >>> neg_exp = tf.exp(scores_neg)
+ >>> pos_exp = tf.exp(scores_pos)
+ >>> softmax_score = pos_exp / (tf.reduce_sum(neg_exp, axis=0) + pos_exp)
+ >>> loss = -tf.math.log(softmax_score)
+ >>> return loss
+ >>> udf_loss = lfs.get(user_defined_loss)
+ >>> isinstance(udf_loss, Loss)
+ True
+ >>> isinstance(udf_loss, LossFunctionWrapper)
+ True
"""
- def __init__(self, eta, loss_params={}, verbose=False):
- """Initialize Loss
+ def __init__(self, user_defined_loss, name=None):
+ """Initializes the LossFunctionWrapper.
Parameters
----------
- loss_params : dict
- Dictionary of loss-specific hyperparams:
-
+ user_defined_loss : function_handle
+ Handle to a loss function that takes two arguments as input (``scores_pos`` and ``scores_neg``).
+ name: str
+ Name of the loss function.
"""
- super().__init__(eta, loss_params, verbose)
-
- def _inputs_check(self, y_true, y_pred):
- """ Creates any dependencies that need to be checked before performing loss computations
+ super(LossFunctionWrapper, self).__init__()
+ self._user_losses = user_defined_loss
+ self.name = name
- Parameters
- ----------
- y_true : tf.Tensor
- A tensor of ground truth values.
- y_pred : tf.Tensor
- A tensor of predicted values.
- """
-
- logger.debug('Creating dependencies before loss computations.')
-
- self._dependencies = []
- logger.debug('Dependencies found: \n\tRequired same size y_true and y_pred. ')
- self._dependencies.append(tf.Assert(tf.equal(tf.shape(y_pred)[0], tf.shape(y_true)[0]),
- [tf.shape(y_pred)[0], tf.shape(y_true)[0]]))
-
- if self._loss_parameters['label_smoothing'] is not None:
- if 'num_entities' not in self._loss_parameters.keys():
- msg = "To apply label smoothing the number of entities must be known. " \
- "Set using '_set_hyperparams('num_entities', value)'."
- logger.error(msg)
- raise Exception(msg)
-
- def _init_hyperparams(self, hyperparam_dict):
- """ Verifies and stores the hyperparameters needed by the algorithm.
-
- Parameters
- ----------
- hyperparam_dict : dictionary
- Consists of key value pairs. The Loss will check the keys to get the corresponding params
- - **label_smoothing** (float): Apply label smoothing to vector of true labels. Can improve multi-class
- classification training by using soft targets that are a weighted average of hard targets and the
- uniform distribution over labels. Default: None
- - **label_weighting** (bool): Apply label weighting to vector of true labels. Gives lower weight to
- outputs with more positives in one-hot vector. Default: False
-
- """
-
- self._loss_parameters['label_smoothing'] = hyperparam_dict.get('label_smoothing', DEFAULT_LABEL_SMOOTHING)
- self._loss_parameters['label_weighting'] = hyperparam_dict.get('label_weighting', DEFAULT_LABEL_WEIGHTING)
-
- def _set_hyperparams(self, key, value):
- """ Set a hyperparameter needed by the loss function.
+ def _init_hyperparams(self, hyperparam_dict={}):
+ """Verifies and stores the hyperparameters needed by the algorithm.
Parameters
----------
- key : key for hyperparams dictionary
- value : value for hyperparams dictionary
-
- Returns
- -------
-
+ hyperparam_dict : dict
+ The Loss will check the keys to get the corresponding parameters.
"""
+ pass
- if key in self._loss_parameters.keys():
- msg = '{} already exists in loss hyperparameters dict with value {} \n' \
- 'Overriding with value {}.'.format(key, self._loss_parameters[key], value)
- logger.info(msg)
-
- self._loss_parameters[key] = value
-
- def apply(self, y_true, y_pred):
- """ Interface to external world.
- This function does the input checks, preprocesses input and finally applies loss function.
+ @tf.function(experimental_relax_shapes=True)
+ def _apply_loss(self, scores_pos, scores_neg):
+ """Apply the loss function.
Parameters
----------
- y_true : tf.Tensor
- A tensor of ground truth values.
- y_true : tf.Tensor
- A tensor of predicted values.
+ scores_pos : tf.Tensor, shape (n, 1)
+ A tensor of scores assigned to positive statements.
+ scores_neg : tf.Tensor, shape (eta, n)
+ A tensor of scores assigned to negative statements.
Returns
-------
- loss : tf.Tensor
+ loss : float
The loss value that must be minimized.
- """
- self._inputs_check(y_true, y_pred)
- with tf.control_dependencies(self._dependencies):
- loss = self._apply(y_true, y_pred)
- return loss
-
- def _apply(self, y_true, y_pred):
- """ Apply the loss function.
-
- Parameters
- ----------
- y_true : tf.Tensor
- A tensor of true values.
- y_pred : tf.Tensor
- A tensor of predicted values.
-
- Returns
- -------
- loss : float
- The loss value that must be minimized.
-
- """
-
- if self._loss_parameters['label_smoothing'] is not None:
- y_true = tf.add((1 - self._loss_parameters['label_smoothing']) * y_true,
- (self._loss_parameters['label_smoothing']) / self._loss_parameters['num_entities'])
- if self._loss_parameters['label_weighting']:
+ """
+ return self._user_losses(scores_pos, scores_neg)
- eps = 1e-6
- wt = tf.reduce_mean(y_true)
- loss = -tf.reduce_sum((1 - wt) * y_true * tf.log_sigmoid(y_pred)
- + wt * (1 - y_true) * tf.log(1 - tf.sigmoid(y_pred) + eps))
- else:
- loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_pred))
-
- return loss
+def get(identifier, hyperparams={}):
+ """
+ Get the loss function specified by the identifier.
+
+ Parameters
+ ----------
+ identifier: Loss class instance or str or function handle
+ Instance of a Loss class (Pairwise, NLLLoss, etc.), name of a registered loss function to be used
+ (with default parameters), or handle to a function which takes two arguments (signature:
+ ``def loss_fn(scores_pos, scores_neg)``).
+ hyperparams: dict
+ Dictionary of loss-specific hyperparameters; used only when ``identifier`` is a string (ignored otherwise).
+
+ Returns
+ -------
+ loss: Loss class instance
+ Loss function.
+
+ Example
+ -------
+ >>> import ampligraph.latent_features.loss_functions as lfs
+ >>> nll_loss = lfs.get('nll')
+ >>> isinstance(nll_loss, Loss)
+ True
+
+ >>> def user_defined_loss(scores_pos, scores_neg):
+ >>> neg_exp = tf.exp(scores_neg)
+ >>> pos_exp = tf.exp(scores_pos)
+ >>> softmax_score = pos_exp / (tf.reduce_sum(neg_exp, axis=0) + pos_exp)
+ >>> loss = -tf.math.log(softmax_score)
+ >>> return loss
+ >>> udf_loss = lfs.get(user_defined_loss)
+ >>> isinstance(udf_loss, Loss)
+ True
+ """
+ if isinstance(identifier, Loss):
+ return identifier
+ elif isinstance(identifier, six.string_types):
+ if identifier not in LOSS_REGISTRY.keys():
+ raise ValueError(
+ "Could not interpret loss identifier:", identifier
+ )
+ return LOSS_REGISTRY.get(identifier)(hyperparams)
+ elif callable(identifier):
+ loss_name = identifier.__name__
+ wrapped_callable = LossFunctionWrapper(identifier, loss_name)
+ return wrapped_callable
+ else:
+ raise ValueError("Could not interpret loss identifier:", identifier)
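Not part of the diff: a minimal, hypothetical sketch of how the loss API introduced above could be exercised in isolation. It assumes the module path shown in the diff and that the base Loss class accepts the ``reduction`` hyperparameter mentioned in its docstrings; ``hinge_like_loss`` is an illustrative user-defined function, not AmpliGraph code.

import tensorflow as tf
import ampligraph.latent_features.loss_functions as lfs

# Built-in loss selected by its registered name, with loss-specific hyperparameters.
nll = lfs.get('multiclass_nll', {'reduction': 'sum'})

# A user-defined loss is wrapped transparently in a LossFunctionWrapper.
def hinge_like_loss(scores_pos, scores_neg):
    # margin of 1.0 between every positive score and each of its corruption scores
    return tf.nn.relu(1.0 + scores_neg - scores_pos)

udf_loss = lfs.get(hinge_like_loss)

# Shapes follow the docstrings above: scores_pos is (n,), scores_neg is (eta, n).
scores_pos = tf.constant([2.0, 1.5, 0.3])
scores_neg = tf.random.normal((5, 3))
print(udf_loss._apply_loss(scores_pos, scores_neg).shape)  # (5, 3): one term per corruption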
diff --git a/ampligraph/latent_features/misc.py b/ampligraph/latent_features/misc.py
deleted file mode 100644
index 124e9a8e..00000000
--- a/ampligraph/latent_features/misc.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-import numpy as np
-import logging
-
-SUBJECT = 0
-PREDICATE = 1
-OBJECT = 2
-DEBUG = True
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
-
-def get_entity_triples(entity, graph):
- """
- Given an entity label e included in the graph G, returns an list of all triples where e appears either
- as subject or object.
-
- Parameters
- ----------
- entity : str, shape [n, 1]
- An entity label.
- graph : np.ndarray, shape [n, 3]
- An ndarray of triples.
-
- Returns
- -------
- neighbours : np.ndarray, shape [n, 3]
- An ndarray of triples where e is either the subject or the object.
- """
- logger.debug('Return a list of all triples where {} appears as subject or object.'.format(entity))
- # NOTE: The current implementation is slightly faster (~15%) than the more readable one-liner:
- # rows, _ = np.where((entity == graph[:,[SUBJECT,OBJECT]]))
-
- # Get rows and cols where entity is found in graph
- rows, cols = np.where((entity == graph))
-
- # In the unlikely event that entity is found in the relation column (index 1)
- rows = rows[np.where(cols != PREDICATE)]
-
- # Subset graph to neighbourhood of entity
- neighbours = graph[rows, :]
-
- return neighbours
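For reference, a standalone NumPy sketch of what the removed ``get_entity_triples`` helper did: select every triple in which a given entity appears as subject or object, excluding matches in the predicate column. The toy array below is illustrative only.

import numpy as np

SUBJECT, PREDICATE, OBJECT = 0, 1, 2

def get_entity_triples(entity, graph):
    # rows/columns where the entity occurs anywhere in the triples
    rows, cols = np.where(entity == graph)
    # drop matches that fall in the predicate column
    rows = rows[cols != PREDICATE]
    return graph[rows, :]

X = np.array([['a', 'y', 'b'],
              ['b', 'y', 'a'],
              ['d', 'y', 'c']])
print(get_entity_triples('a', X))  # first two triples: 'a' appears as subject or object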
diff --git a/ampligraph/latent_features/models/ComplEx.py b/ampligraph/latent_features/models/ComplEx.py
deleted file mode 100644
index 45ee4518..00000000
--- a/ampligraph/latent_features/models/ComplEx.py
+++ /dev/null
@@ -1,385 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-from .EmbeddingModel import EmbeddingModel, register_model
-from ampligraph.latent_features import constants as constants
-from ampligraph.latent_features.initializers import DEFAULT_XAVIER_IS_UNIFORM
-import tensorflow as tf
-import time
-
-
-@register_model("ComplEx", ["negative_corruption_entities"])
-class ComplEx(EmbeddingModel):
- r"""Complex embeddings (ComplEx)
-
- The ComplEx model :cite:`trouillon2016complex` is an extension of
- the :class:`ampligraph.latent_features.DistMult` bilinear diagonal model
- . ComplEx scoring function is based on the trilinear Hermitian dot product in :math:`\mathcal{C}`:
-
- .. math::
-
- f_{ComplEx}=Re(\langle \mathbf{r}_p, \mathbf{e}_s, \overline{\mathbf{e}_o} \rangle)
-
- ComplEx can be improved if used alongside the nuclear 3-norm
- (the **ComplEx-N3** model :cite:`lacroix2018canonical`), which can be easily added to the
- loss function via the ``regularizer`` hyperparameter with ``p=3`` and
- a chosen regularisation weight (represented by ``lambda``), as shown in the example below.
- See also :meth:`ampligraph.latent_features.LPRegularizer`.
-
- .. note::
-
- Since ComplEx embeddings belong to :math:`\mathcal{C}`, this model uses twice as many parameters as
- :class:`ampligraph.latent_features.DistMult`.
-
- Examples
- --------
- >>> import numpy as np
- >>> from ampligraph.latent_features import ComplEx
- >>>
- >>> model = ComplEx(batches_count=2, seed=555, epochs=100, k=20, eta=5,
- >>> loss='pairwise', loss_params={'margin':1},
- >>> regularizer='LP', regularizer_params={'p': 2, 'lambda':0.1})
- >>> X = np.array([['a', 'y', 'b'],
- >>> ['b', 'y', 'a'],
- >>> ['a', 'y', 'c'],
- >>> ['c', 'y', 'a'],
- >>> ['a', 'y', 'd'],
- >>> ['c', 'y', 'd'],
- >>> ['b', 'y', 'c'],
- >>> ['f', 'y', 'e']])
- >>> model.fit(X)
- >>> model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
- [[0.019520484], [-0.14998421]]
- >>> model.get_embeddings(['f','e'], embedding_type='entity')
- array([[-0.33021057, 0.26524785, 0.0446662 , -0.07932718, -0.15453218,
- -0.22342539, -0.03382565, 0.17444217, 0.03009969, -0.33569157,
- 0.3200497 , 0.03803705, 0.05536304, -0.00929996, 0.24446663,
- 0.34408194, 0.16192885, -0.15033236, -0.19703785, -0.00783876,
- 0.1495124 , -0.3578853 , -0.04975723, -0.03930473, 0.1663541 ,
- -0.24731971, -0.141296 , 0.03150219, 0.15328223, -0.18549544,
- -0.39240393, -0.10824018, 0.03394471, -0.11075485, 0.1367736 ,
- 0.10059565, -0.32808647, -0.00472086, 0.14231135, -0.13876757],
- [-0.09483694, 0.3531292 , 0.04992269, -0.07774793, 0.1635035 ,
- 0.30610007, 0.3666711 , -0.13785957, -0.3143734 , -0.36909637,
- -0.13792469, -0.07069954, -0.0368113 , -0.16743314, 0.4090072 ,
- -0.03407392, 0.3113114 , -0.08418448, 0.21435146, 0.12006859,
- 0.08447982, -0.02025972, 0.38752195, 0.11451488, -0.0258422 ,
- -0.10990044, -0.22661531, -0.00478273, -0.0238297 , -0.14207476,
- 0.11064807, 0.20135397, 0.22501846, -0.1731076 , -0.2770435 ,
- 0.30784574, -0.15043163, -0.11599299, 0.05718031, -0.1300622 ]],
- dtype=float32)
-
- """
- def __init__(self,
- k=constants.DEFAULT_EMBEDDING_SIZE,
- eta=constants.DEFAULT_ETA,
- epochs=constants.DEFAULT_EPOCH,
- batches_count=constants.DEFAULT_BATCH_COUNT,
- seed=constants.DEFAULT_SEED,
- embedding_model_params={'negative_corruption_entities': constants.DEFAULT_CORRUPTION_ENTITIES,
- 'corrupt_sides': constants.DEFAULT_CORRUPT_SIDE_TRAIN},
- optimizer=constants.DEFAULT_OPTIM,
- optimizer_params={'lr': constants.DEFAULT_LR},
- loss=constants.DEFAULT_LOSS,
- loss_params={},
- regularizer=constants.DEFAULT_REGULARIZER,
- regularizer_params={},
- initializer=constants.DEFAULT_INITIALIZER,
- initializer_params={'uniform': DEFAULT_XAVIER_IS_UNIFORM},
- verbose=constants.DEFAULT_VERBOSE):
- """Initialize an EmbeddingModel
-
- Also creates a new Tensorflow session for training.
-
- Parameters
- ----------
- k : int
- Embedding space dimensionality
- eta : int
- The number of negatives that must be generated at runtime during training for each positive.
- epochs : int
- The iterations of the training loop.
- batches_count : int
- The number of batches in which the training set must be split during the training loop.
- seed : int
- The seed used by the internal random numbers generator.
- embedding_model_params : dict
- ComplEx-specific hyperparams:
-
- - **'negative_corruption_entities'** - Entities to be used for generation of corruptions while training.
- It can take the following values :
- ``all`` (default: all entities),
- ``batch`` (entities present in each batch),
- list of entities
- or an int (which indicates how many entities that should be used for corruption generation).
- - **corrupt_sides** : Specifies how to generate corruptions for training.
- Takes values `s`, `o`, `s+o` or any combination passed as a list
- - **'non_linearity'**: can be one of the following values ``linear``, ``softplus``, ``sigmoid``, ``tanh``
- - **'stop_epoch'**: specifies how long to decay (linearly) the numeric values from 1 to original value
- until it reachs original value.
- - **'structural_wt'**: structural influence hyperparameter [0, 1] that modulates the influence of graph
- topology.
- - **'normalize_numeric_values'**: normalize the numeric values, such that they are scaled between [0, 1]
-
- The last 4 parameters are related to FocusE layers.
-
- optimizer : string
- The optimizer used to minimize the loss function. Choose between 'sgd',
- 'adagrad', 'adam', 'momentum'.
-
- optimizer_params : dict
- Arguments specific to the optimizer, passed as a dictionary.
-
- Supported keys:
-
- - **'lr'** (float): learning rate (used by all the optimizers). Default: 0.1.
- - **'momentum'** (float): learning momentum (only used when ``optimizer=momentum``). Default: 0.9.
-
- Example: ``optimizer_params={'lr': 0.01}``
-
- loss : string
- The type of loss function to use during training.
-
- - ``pairwise`` the model will use pairwise margin-based loss function.
- - ``nll`` the model will use negative loss likelihood.
- - ``absolute_margin`` the model will use absolute margin likelihood.
- - ``self_adversarial`` the model will use adversarial sampling loss function.
- - ``multiclass_nll`` the model will use multiclass nll loss.
- Switch to multiclass loss defined in :cite:`chen2015` by passing 'corrupt_sides'
- as ['s','o'] to embedding_model_params.
- To use loss defined in :cite:`kadlecBK17` pass 'corrupt_sides' as 'o' to embedding_model_params.
-
- loss_params : dict
- Dictionary of loss-specific hyperparameters. See :ref:`loss functions `
- documentation for additional details.
-
- Example: ``optimizer_params={'lr': 0.01}`` if ``loss='pairwise'``.
-
- regularizer : string
- The regularization strategy to use with the loss function.
-
- - ``None``: the model will not use any regularizer (default)
- - 'LP': the model will use L1, L2 or L3 based on the value of ``regularizer_params['p']`` (see below).
-
- regularizer_params : dict
- Dictionary of regularizer-specific hyperparameters. See the :ref:`regularizers `
- documentation for additional details.
-
- Example: ``regularizer_params={'lambda': 1e-5, 'p': 2}`` if ``regularizer='LP'``.
-
- initializer : string
- The type of initializer to use.
-
- - ``normal``: The embeddings will be initialized from a normal distribution
- - ``uniform``: The embeddings will be initialized from a uniform distribution
- - ``xavier``: The embeddings will be initialized using xavier strategy (default)
-
- initializer_params : dict
- Dictionary of initializer-specific hyperparameters. See the
- :ref:`initializer `
- documentation for additional details.
-
- Example: ``initializer_params={'mean': 0, 'std': 0.001}`` if ``initializer='normal'``.
-
- verbose : bool
- Verbose mode.
- """
- super().__init__(k=k, eta=eta, epochs=epochs, batches_count=batches_count, seed=seed,
- embedding_model_params=embedding_model_params,
- optimizer=optimizer, optimizer_params=optimizer_params,
- loss=loss, loss_params=loss_params,
- regularizer=regularizer, regularizer_params=regularizer_params,
- initializer=initializer, initializer_params=initializer_params,
- verbose=verbose)
-
- self.internal_k = self.k * 2
-
- def _initialize_parameters(self):
- """Initialize the complex embeddings.
- """
- timestamp = int(time.time() * 1e6)
- if not self.dealing_with_large_graphs:
- self.ent_emb = tf.get_variable('ent_emb_{}'.format(timestamp),
- shape=[len(self.ent_to_idx), self.internal_k],
- initializer=self.initializer.get_entity_initializer(
- len(self.ent_to_idx), self.internal_k),
- dtype=tf.float32)
- self.rel_emb = tf.get_variable('rel_emb_{}'.format(timestamp),
- shape=[len(self.rel_to_idx), self.internal_k],
- initializer=self.initializer.get_relation_initializer(
- len(self.rel_to_idx), self.internal_k),
- dtype=tf.float32)
- else:
- # initialize entity embeddings to zero (these are reinitialized every batch by batch embeddings)
- self.ent_emb = tf.get_variable('ent_emb_{}'.format(timestamp),
- shape=[self.batch_size * 2, self.internal_k],
- initializer=tf.zeros_initializer(),
- dtype=tf.float32)
- self.rel_emb = tf.get_variable('rel_emb_{}'.format(timestamp),
- shape=[len(self.rel_to_idx), self.internal_k],
- initializer=self.initializer.get_relation_initializer(
- len(self.rel_to_idx), self.internal_k),
- dtype=tf.float32)
-
- def _fn(self, e_s, e_p, e_o):
- r"""ComplEx scoring function.
-
- .. math::
-
- f_{ComplEx}=Re(\langle \mathbf{r}_p, \mathbf{e}_s, \overline{\mathbf{e}_o} \rangle)
-
- Additional details available in :cite:`trouillon2016complex` (Equation 9).
-
- Parameters
- ----------
- e_s : Tensor, shape [n]
- The embeddings of a list of subjects.
- e_p : Tensor, shape [n]
- The embeddings of a list of predicates.
- e_o : Tensor, shape [n]
- The embeddings of a list of objects.
-
- Returns
- -------
- score : TensorFlow operation
- The operation corresponding to the ComplEx scoring function.
-
- """
-
- # Assume each embedding is made of an img and real component.
- # (These components are actually real numbers, see [trouillon2016complex].
- e_s_real, e_s_img = tf.split(e_s, 2, axis=1)
- e_p_real, e_p_img = tf.split(e_p, 2, axis=1)
- e_o_real, e_o_img = tf.split(e_o, 2, axis=1)
-
- # See Eq. 9 [trouillon2016complex):
- return tf.reduce_sum(e_p_real * e_s_real * e_o_real, axis=1) + \
- tf.reduce_sum(e_p_real * e_s_img * e_o_img, axis=1) + \
- tf.reduce_sum(e_p_img * e_s_real * e_o_img, axis=1) - \
- tf.reduce_sum(e_p_img * e_s_img * e_o_real, axis=1)
-
- def fit(self, X, early_stopping=False, early_stopping_params={}, focusE_numeric_edge_values=None,
- tensorboard_logs_path=None):
- """Train a ComplEx model.
-
- The model is trained on a training set X using the training protocol
- described in :cite:`trouillon2016complex`.
-
- Parameters
- ----------
- X : ndarray, shape [n, 3]
- The training triples
- early_stopping: bool
- Flag to enable early stopping (default:False).
-
- If set to ``True``, the training loop adopts the following early stopping heuristic:
-
- - The model will be trained regardless of early stopping for ``burn_in`` epochs.
- - Every ``check_interval`` epochs the method will compute the metric specified in ``criteria``.
-
- If such metric decreases for ``stop_interval`` checks, we stop training early.
-
- Note the metric is computed on ``x_valid``. This is usually a validation set that you held out.
-
- Also, because ``criteria`` is a ranking metric, it requires generating negatives.
- Entities used to generate corruptions can be specified, as long as the side(s) of a triple to corrupt.
- The method supports filtered metrics, by passing an array of positives to ``x_filter``. This will be used to
- filter the negatives generated on the fly (i.e. the corruptions).
-
- .. note::
-
- Keep in mind the early stopping criteria may introduce a certain overhead
- (caused by the metric computation).
- The goal is to strike a good trade-off between such overhead and saving training epochs.
-
- A common approach is to use MRR unfiltered: ::
-
- early_stopping_params={x_valid=X['valid'], 'criteria': 'mrr'}
-
- Note the size of validation set also contributes to such overhead.
- In most cases a smaller validation set would be enough.
-
- early_stopping_params: dictionary
- Dictionary of hyperparameters for the early stopping heuristics.
-
- The following string keys are supported:
-
- - **'x_valid'**: ndarray, shape [n, 3] : Validation set to be used for early stopping.
- - **'criteria'**: string : criteria for early stopping 'hits10', 'hits3', 'hits1' or 'mrr'(default).
- - **'x_filter'**: ndarray, shape [n, 3] : Positive triples to use as filter if a 'filtered'
- early stopping criteria is desired (i.e. filtered-MRR if 'criteria':'mrr').
- Note this will affect training time (no filter by default).
- - **'burn_in'**: int : Number of epochs to pass before kicking in early stopping (default: 100).
- - **check_interval'**: int : Early stopping interval after burn-in (default:10).
- - **'stop_interval'**: int : Stop if criteria is performing worse over n consecutive checks (default: 3)
- - **'corruption_entities'**: List of entities to be used for corruptions.
- If 'all', it uses all entities (default: 'all')
- - **'corrupt_side'**: Specifies which side to corrupt. 's', 'o', 's+o' (default)
-
- Example: ``early_stopping_params={x_valid=X['valid'], 'criteria': 'mrr'}``
-
- focusE_numeric_edge_values: ndarray, shape [n]
- .. _focuse_complex:
-
- If processing a knowledge graph with numeric values associated with links, this is the vector of such
- numbers. Passing this argument will activate the :ref:`FocusE layer `
- :cite:`pai2021learning`.
- Semantically, numeric values can signify importance, uncertainity, significance, confidence, etc.
- Values can be any number, and will be automatically normalised to the [0, 1] range, on a
- predicate-specific basis.
- If the numeric value is unknown pass a ``np.NaN`` value.
- The model will uniformly randomly assign a numeric value.
-
- .. note::
-
- The following toy example shows how to enable the FocusE layer
- to process edges with numeric literals: ::
-
- import numpy as np
- from ampligraph.latent_features import ComplEx
- model = ComplEx(batches_count=1, seed=555, epochs=20,
- k=10, loss='pairwise',
- loss_params={'margin':5})
- X = np.array([['a', 'y', 'b'],
- ['b', 'y', 'a'],
- ['a', 'y', 'c'],
- ['c', 'y', 'a'],
- ['a', 'y', 'd'],
- ['c', 'y', 'd'],
- ['b', 'y', 'c'],
- ['f', 'y', 'e']])
-
- # Numeric values below are associate to each triple in X.
- # They can be any number and will be automatically
- # normalised to the [0, 1] range, on a
- # predicate-specific basis.
- X_edge_values = np.array([5.34, -1.75, 0.33, 5.12,
- np.nan, 3.17, 2.76, 0.41])
-
- model.fit(X, focusE_numeric_edge_values=X_edge_values)
-
-
- tensorboard_logs_path: str or None
- Path to store tensorboard logs, e.g. average training loss tracking per epoch (default: ``None`` indicating
- no logs will be collected). When provided it will create a folder under provided path and save tensorboard
- files there. To then view the loss in the terminal run: ``tensorboard --logdir ``.
-
- """
- super().fit(X, early_stopping, early_stopping_params, focusE_numeric_edge_values,
- tensorboard_logs_path=tensorboard_logs_path)
-
- def predict(self, X, from_idx=False):
- __doc__ = super().predict.__doc__ # NOQA
- return super().predict(X, from_idx=from_idx)
-
- def calibrate(self, X_pos, X_neg=None, positive_base_rate=None, batches_count=100, epochs=50):
- __doc__ = super().calibrate.__doc__ # NOQA
- super().calibrate(X_pos, X_neg, positive_base_rate, batches_count, epochs)
-
- def predict_proba(self, X):
- __doc__ = super().calibrate.__doc__ # NOQA
- return super().predict_proba(X)
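For reference, the ComplEx scoring function deleted above (Eq. 9 of :cite:`trouillon2016complex`), rewritten as a small standalone TF2 snippet: embeddings of internal size 2k store the real and imaginary parts side by side. This is an illustrative sketch, not the new AmpliGraph 2 implementation.

import tensorflow as tf

def complex_score(e_s, e_p, e_o):
    # split each embedding into its real and imaginary halves
    s_re, s_im = tf.split(e_s, 2, axis=1)
    p_re, p_im = tf.split(e_p, 2, axis=1)
    o_re, o_im = tf.split(e_o, 2, axis=1)
    # Re(<r_p, e_s, conj(e_o)>)
    return (tf.reduce_sum(p_re * s_re * o_re, axis=1)
            + tf.reduce_sum(p_re * s_im * o_im, axis=1)
            + tf.reduce_sum(p_im * s_re * o_im, axis=1)
            - tf.reduce_sum(p_im * s_im * o_re, axis=1))

# 4 triples with k=5 (internal embedding size 2k=10)
e_s, e_p, e_o = (tf.random.normal((4, 10)) for _ in range(3))
print(complex_score(e_s, e_p, e_o).shape)  # (4,)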
diff --git a/ampligraph/latent_features/models/ConvE.py b/ampligraph/latent_features/models/ConvE.py
deleted file mode 100644
index fb489c12..00000000
--- a/ampligraph/latent_features/models/ConvE.py
+++ /dev/null
@@ -1,1154 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-import numpy as np
-import tensorflow as tf
-import logging
-from sklearn.utils import check_random_state
-from tqdm import tqdm
-from functools import partial
-import time
-
-from .EmbeddingModel import EmbeddingModel, register_model, ENTITY_THRESHOLD
-from ..initializers import DEFAULT_XAVIER_IS_UNIFORM
-from ampligraph.latent_features import constants as constants
-
-from ...datasets import OneToNDatasetAdapter
-from ..optimizers import SGDOptimizer
-from ...evaluation import to_idx
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
-
-@register_model('ConvE', ['conv_filters', 'conv_kernel_size', 'dropout_embed', 'dropout_conv',
- 'dropout_dense', 'use_bias', 'use_batchnorm'], {})
-class ConvE(EmbeddingModel):
- r""" Convolutional 2D KG Embeddings
-
- The ConvE model :cite:`DettmersMS018`.
-
- ConvE uses convolutional layers.
- :math:`g` is a non-linear activation function, :math:`\ast` is the linear convolution operator,
- :math:`vec` indicates 2D reshaping.
-
- .. math::
-
- f_{ConvE} = \langle \sigma \, (vec \, ( g \, ([ \overline{\mathbf{e}_s} ; \overline{\mathbf{r}_p} ]
- \ast \Omega )) \, \mathbf{W} )) \, \mathbf{e}_o\rangle
-
-
- .. note::
-
- ConvE does not handle 's+o' corruptions currently, nor ``large_graph`` mode.
-
-
- Examples
- --------
- >>> import numpy as np
- >>> from ampligraph.latent_features import ConvE
- >>> model = ConvE(batches_count=1, seed=22, epochs=5, k=100)
- >>>
- >>> X = np.array([['a', 'y', 'b'],
- >>> ['b', 'y', 'a'],
- >>> ['a', 'y', 'c'],
- >>> ['c', 'y', 'a'],
- >>> ['a', 'y', 'd'],
- >>> ['c', 'y', 'd'],
- >>> ['b', 'y', 'c'],
- >>> ['f', 'y', 'e']])
- >>> model.fit(X)
- >>> model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
- [0.42921206 0.38998795]
-
- """
-
- def __init__(self,
- k=constants.DEFAULT_EMBEDDING_SIZE,
- eta=constants.DEFAULT_ETA,
- epochs=constants.DEFAULT_EPOCH,
- batches_count=constants.DEFAULT_BATCH_COUNT,
- seed=constants.DEFAULT_SEED,
- embedding_model_params={'conv_filters': constants.DEFAULT_CONVE_CONV_FILTERS,
- 'conv_kernel_size': constants.DEFAULT_CONVE_KERNEL_SIZE,
- 'dropout_embed': constants.DEFAULT_CONVE_DROPOUT_EMBED,
- 'dropout_conv': constants.DEFAULT_CONVE_DROPOUT_CONV,
- 'dropout_dense': constants.DEFAULT_CONVE_DROPOUT_DENSE,
- 'use_bias': constants.DEFAULT_CONVE_USE_BIAS,
- 'use_batchnorm': constants.DEFAULT_CONVE_USE_BATCHNORM},
- optimizer=constants.DEFAULT_OPTIM,
- optimizer_params={'lr': constants.DEFAULT_LR},
- loss='bce',
- loss_params={'label_weighting': False,
- 'label_smoothing': 0.1},
- regularizer=constants.DEFAULT_REGULARIZER,
- regularizer_params={},
- initializer=constants.DEFAULT_INITIALIZER,
- initializer_params={'uniform': DEFAULT_XAVIER_IS_UNIFORM},
- low_memory=False,
- verbose=constants.DEFAULT_VERBOSE):
- """Initialize a ConvE model
-
- Also creates a new Tensorflow session for training.
-
- Parameters
- ----------
- k : int
- Embedding space dimensionality.
-
- eta : int
- The number of negatives that must be generated at runtime during training for each positive.
- Note: This parameter is not used in ConvE.
-
- epochs : int
- The iterations of the training loop.
-
- batches_count : int
- The number of batches in which the training set must be split during the training loop.
-
- seed : int
- The seed used by the internal random numbers generator.
-
- embedding_model_params : dict
- ConvE-specific hyperparams:
-
- - **conv_filters** (int): Number of convolution feature maps. Default: 32
- - **conv_kernel_size** (int): Convolution kernel size. Default: 3
- - **dropout_embed** (float|None): Dropout on the embedding layer. Default: 0.2
- - **dropout_conv** (float|None): Dropout on the convolution maps. Default: 0.3
- - **dropout_dense** (float|None): Dropout on the dense layer. Default: 0.2
- - **use_bias** (bool): Use bias layer. Default: True
- - **use_batchnorm** (bool): Use batch normalization after input, convolution, dense layers. Default: True
-
- optimizer : string
- The optimizer used to minimize the loss function. Choose between 'sgd', 'adagrad', 'adam', 'momentum'.
-
- optimizer_params : dict
- Arguments specific to the optimizer, passed as a dictionary.
-
- Supported keys:
-
- - **'lr'** (float): learning rate (used by all the optimizers). Default: 0.1.
- - **'momentum'** (float): learning momentum (only used when ``optimizer=momentum``). Default: 0.9.
-
- Example: ``optimizer_params={'lr': 0.01}``
-
- loss : string
- The type of loss function to use during training.
-
- - ``bce`` the model will use binary cross entropy loss function.
-
- loss_params : dict
- Dictionary of loss-specific hyperparameters. See :ref:`loss functions ` documentation for
- additional details.
-
- Supported keys:
-
- - **'lr'** (float): learning rate (used by all the optimizers). Default: 0.1.
- - **'momentum'** (float): learning momentum (only used when ``optimizer=momentum``). Default: 0.9.
- - **'label_smoothing'** (float): applies label smoothing to one-hot outputs. Default: 0.1.
- - **'label_weighting'** (bool): applies label weighting to one-hot outputs. Default: True
-
- Example: ``optimizer_params={'lr': 0.01, 'label_smoothing': 0.1}``
-
- regularizer : string
- The regularization strategy to use with the loss function.
-
- - ``None``: the model will not use any regularizer (default)
- - ``LP``: the model will use L1, L2 or L3 based on the value of ``regularizer_params['p']`` (see below).
-
- regularizer_params : dict
- Dictionary of regularizer-specific hyperparameters. See the
- :ref:`regularizers `
- documentation for additional details.
-
- Example: ``regularizer_params={'lambda': 1e-5, 'p': 2}`` if ``regularizer='LP'``.
-
- initializer : string
- The type of initializer to use.
-
- - ``normal``: The embeddings will be initialized from a normal distribution
- - ``uniform``: The embeddings will be initialized from a uniform distribution
- - ``xavier``: The embeddings will be initialized using xavier strategy (default)
-
- initializer_params : dict
- Dictionary of initializer-specific hyperparameters. See the
- :ref:`initializer `
- documentation for additional details.
-
- Example: ``initializer_params={'mean': 0, 'std': 0.001}`` if ``initializer='normal'``.
-
- verbose : bool
- Verbose mode.
-
- low_memory : bool
- Train ConvE with a (slower) low_memory option. If MemoryError is still encountered, try raising the
- batches_count value. Default: False.
-
- """
-
- # Add default values if not provided in embedding_model_params dict
- default_embedding_model_params = {'conv_filters': constants.DEFAULT_CONVE_CONV_FILTERS,
- 'conv_kernel_size': constants.DEFAULT_CONVE_KERNEL_SIZE,
- 'dropout_embed': constants.DEFAULT_CONVE_DROPOUT_EMBED,
- 'dropout_conv': constants.DEFAULT_CONVE_DROPOUT_CONV,
- 'dropout_dense': constants.DEFAULT_CONVE_DROPOUT_DENSE,
- 'use_batchnorm': constants.DEFAULT_CONVE_USE_BATCHNORM,
- 'use_bias': constants.DEFAULT_CONVE_USE_BATCHNORM}
-
- for key, val in default_embedding_model_params.items():
- if key not in embedding_model_params.keys():
- embedding_model_params[key] = val
-
- # Find factor pairs (i,j) of concatenated embedding dimensions, where min(i,j) >= conv_kernel_size
- n = k * 2
- emb_img_depth = 1
-
- ksize = embedding_model_params['conv_kernel_size']
- nfilters = embedding_model_params['conv_filters']
-
- emb_img_width, emb_img_height = None, None
- for i in range(int(np.sqrt(n)) + 1, ksize, -1):
- if n % i == 0:
- emb_img_width, emb_img_height = (i, int(n / i))
- break
-
- if not emb_img_width and not emb_img_height:
- msg = 'Unable to determine factor pairs for embedding reshape. Choose a smaller convolution kernel size, ' \
- 'or a larger embedding dimension.'
- logger.info(msg)
- raise ValueError(msg)
-
- embedding_model_params['embed_image_width'] = emb_img_width
- embedding_model_params['embed_image_height'] = emb_img_height
- embedding_model_params['embed_image_depth'] = emb_img_depth
-
- # Calculate dense dimension
- embedding_model_params['dense_dim'] = (emb_img_width - (ksize - 1)) * (emb_img_height - (ksize - 1)) * nfilters
-
- self.low_memory = low_memory
-
- super().__init__(k=k, eta=eta, epochs=epochs,
- batches_count=batches_count, seed=seed,
- embedding_model_params=embedding_model_params,
- optimizer=optimizer, optimizer_params=optimizer_params,
- loss=loss, loss_params=loss_params,
- regularizer=regularizer, regularizer_params=regularizer_params,
- initializer=initializer, initializer_params=initializer_params,
- verbose=verbose)
-
- def _initialize_parameters(self):
- """Initialize parameters of the model.
-
- This function creates and initializes entity and relation embeddings (with size k).
- If the graph is large, then it loads only the required entity embeddings (max:batch_size*2)
- and all relation embeddings.
- Overload this function if the parameters needs to be initialized differently.
- """
- timestamp = int(time.time() * 1e6)
- if not self.dealing_with_large_graphs:
-
- with tf.variable_scope('meta'):
- self.tf_is_training = tf.Variable(False, trainable=False)
- self.set_training_true = tf.assign(self.tf_is_training, True)
- self.set_training_false = tf.assign(self.tf_is_training, False)
-
- nfilters = self.embedding_model_params['conv_filters']
- ninput = self.embedding_model_params['embed_image_depth']
- ksize = self.embedding_model_params['conv_kernel_size']
- dense_dim = self.embedding_model_params['dense_dim']
-
- self.ent_emb = tf.get_variable('ent_emb_{}'.format(timestamp),
- shape=[len(self.ent_to_idx), self.k],
- initializer=self.initializer.get_entity_initializer(
- len(self.ent_to_idx), self.k),
- dtype=tf.float32)
- self.rel_emb = tf.get_variable('rel_emb_{}'.format(timestamp),
- shape=[len(self.rel_to_idx), self.k],
- initializer=self.initializer.get_relation_initializer(
- len(self.rel_to_idx), self.k),
- dtype=tf.float32)
-
- self.conv2d_W = tf.get_variable('conv2d_weights_{}'.format(timestamp),
- shape=[ksize, ksize, ninput, nfilters],
- initializer=tf.initializers.he_normal(seed=self.seed),
- dtype=tf.float32)
- self.conv2d_B = tf.get_variable('conv2d_bias_{}'.format(timestamp),
- shape=[nfilters],
- initializer=tf.zeros_initializer(), dtype=tf.float32)
-
- self.dense_W = tf.get_variable('dense_weights_{}'.format(timestamp),
- shape=[dense_dim, self.k],
- initializer=tf.initializers.he_normal(seed=self.seed),
- dtype=tf.float32)
- self.dense_B = tf.get_variable('dense_bias_{}'.format(timestamp),
- shape=[self.k],
- initializer=tf.zeros_initializer(), dtype=tf.float32)
-
- if self.embedding_model_params['use_batchnorm']:
-
- emb_img_dim = self.embedding_model_params['embed_image_depth']
-
- self.bn_vars = {'batchnorm_input': {'beta': np.zeros(shape=[emb_img_dim]),
- 'gamma': np.ones(shape=[emb_img_dim]),
- 'moving_mean': np.zeros(shape=[emb_img_dim]),
- 'moving_variance': np.ones(shape=[emb_img_dim])},
- 'batchnorm_conv': {'beta': np.zeros(shape=[nfilters]),
- 'gamma': np.ones(shape=[nfilters]),
- 'moving_mean': np.zeros(shape=[nfilters]),
- 'moving_variance': np.ones(shape=[nfilters])},
- 'batchnorm_dense': {'beta': np.zeros(shape=[1]), # shape = [1] for batch norm
- 'gamma': np.ones(shape=[1]),
- 'moving_mean': np.zeros(shape=[1]),
- 'moving_variance': np.ones(shape=[1])}}
-
- if self.embedding_model_params['use_bias']:
- self.bias = tf.get_variable('activation_bias_{}'.format(timestamp),
- shape=[1, len(self.ent_to_idx)],
- initializer=tf.zeros_initializer(), dtype=tf.float32)
-
- else:
- raise NotImplementedError('ConvE not implemented when dealing with large graphs.')
-
- def _get_model_loss(self, dataset_iterator):
- """Get the current loss including loss due to regularization.
- This function must be overridden if the model uses combination of different losses (eg: VAE).
-
- Parameters
- ----------
- dataset_iterator : tf.data.Iterator
- Dataset iterator.
-
- Returns
- -------
- loss : tf.Tensor
- The loss value that must be minimized.
- """
-
- # training input placeholder
- self.x_pos_tf, self.y_true = dataset_iterator.get_next()
-
- # list of dependent ops that need to be evaluated before computing the loss
- dependencies = []
-
- # run the dependencies
- with tf.control_dependencies(dependencies):
-
- # look up embeddings from input training triples
- e_s_pos, e_p_pos, e_o_pos = self._lookup_embeddings(self.x_pos_tf)
-
- # Get positive predictions
- self.y_pred = self._fn(e_s_pos, e_p_pos, e_o_pos)
-
- # Label smoothing and/or weighting is applied within Loss class
- loss = self.loss.apply(self.y_true, self.y_pred)
-
- if self.regularizer is not None:
- loss += self.regularizer.apply([self.ent_emb, self.rel_emb])
-
- return loss
-
- def _save_trained_params(self):
- """After model fitting, save all the trained parameters in trained_model_params in some order.
- The order would be useful for loading the model.
- This method must be overridden if the model has any other parameters (apart from entity-relation embeddings).
- """
-
- params_dict = {}
- params_dict['ent_emb'] = self.sess_train.run(self.ent_emb)
- params_dict['rel_emb'] = self.sess_train.run(self.rel_emb)
- params_dict['conv2d_W'] = self.sess_train.run(self.conv2d_W)
- params_dict['conv2d_B'] = self.sess_train.run(self.conv2d_B)
- params_dict['dense_W'] = self.sess_train.run(self.dense_W)
- params_dict['dense_B'] = self.sess_train.run(self.dense_B)
-
- if self.embedding_model_params['use_batchnorm']:
-
- bn_dict = {}
-
- for scope in ['batchnorm_input', 'batchnorm_conv', 'batchnorm_dense']:
-
- variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope)
- variables = [x for x in variables if 'Adam' not in x.name] # Filter out any Adam variables
-
- var_dict = {x.name.split('/')[-1].split(':')[0]: x for x in variables}
- bn_dict[scope] = {'beta': self.sess_train.run(var_dict['beta']),
- 'gamma': self.sess_train.run(var_dict['gamma']),
- 'moving_mean': self.sess_train.run(var_dict['moving_mean']),
- 'moving_variance': self.sess_train.run(var_dict['moving_variance'])}
-
- params_dict['bn_vars'] = bn_dict
-
- if self.embedding_model_params['use_bias']:
- params_dict['bias'] = self.sess_train.run(self.bias)
-
- params_dict['output_mapping'] = self.output_mapping
-
- self.trained_model_params = params_dict
-
- def _load_model_from_trained_params(self):
- """Load the model from trained params.
- While restoring make sure that the order of loaded parameters match the saved order.
- It's the duty of the embedding model to load the variables correctly.
- This method must be overridden if the model has any other parameters (apart from entity-relation embeddings)
- This function also set's the evaluation mode to do lazy loading of variables based on the number of
- distinct entities present in the graph.
- """
-
- # Generate the batch size based on entity length and batch_count
- self.batch_size = int(np.ceil(len(self.ent_to_idx) / self.batches_count))
-
- with tf.variable_scope('meta'):
- self.tf_is_training = tf.Variable(False, trainable=False)
- self.set_training_true = tf.assign(self.tf_is_training, True)
- self.set_training_false = tf.assign(self.tf_is_training, False)
-
- self.ent_emb = tf.Variable(self.trained_model_params['ent_emb'], dtype=tf.float32)
- self.rel_emb = tf.Variable(self.trained_model_params['rel_emb'], dtype=tf.float32)
-
- self.conv2d_W = tf.Variable(self.trained_model_params['conv2d_W'], dtype=tf.float32)
- self.conv2d_B = tf.Variable(self.trained_model_params['conv2d_B'], dtype=tf.float32)
- self.dense_W = tf.Variable(self.trained_model_params['dense_W'], dtype=tf.float32)
- self.dense_B = tf.Variable(self.trained_model_params['dense_B'], dtype=tf.float32)
-
- if self.embedding_model_params['use_batchnorm']:
- self.bn_vars = self.trained_model_params['bn_vars']
-
- if self.embedding_model_params['use_bias']:
- self.bias = tf.Variable(self.trained_model_params['bias'], dtype=tf.float32)
-
- self.output_mapping = self.trained_model_params['output_mapping']
-
- def _fn(self, e_s, e_p, e_o):
- r"""The ConvE scoring function.
-
- The function implements the scoring function as defined by
- .. math::
-
- f(vec(f([\overline{e_s};\overline{r_r}] * \Omega)) W ) e_o
-
- Additional details for equivalence of the models available in :cite:`Dettmers2016`.
-
-
- Parameters
- ----------
- e_s : Tensor, shape [n]
- The embeddings of a list of subjects.
- e_p : Tensor, shape [n]
- The embeddings of a list of predicates.
- e_o : Tensor, shape [n]
- The embeddings of a list of objects.
-
- Returns
- -------
- score : TensorFlow operation
- The operation corresponding to the ConvE scoring function.
-
- """
-
- def _dropout(X, rate):
- dropout_rate = tf.cond(self.tf_is_training, true_fn=lambda: tf.constant(rate),
- false_fn=lambda: tf.constant(0, dtype=tf.float32))
- out = tf.nn.dropout(X, rate=dropout_rate)
- return out
-
- def _batchnorm(X, key, axis):
-
- with tf.variable_scope(key, reuse=tf.AUTO_REUSE):
- x = tf.compat.v1.layers.batch_normalization(X, training=self.tf_is_training, axis=axis,
- beta_initializer=tf.constant_initializer(
- self.bn_vars[key]['beta']),
- gamma_initializer=tf.constant_initializer(
- self.bn_vars[key]['gamma']),
- moving_mean_initializer=tf.constant_initializer(
- self.bn_vars[key]['moving_mean']),
- moving_variance_initializer=tf.constant_initializer(
- self.bn_vars[key]['moving_variance']))
- return x
-
- # Inputs
- stacked_emb = tf.stack([e_s, e_p], axis=2)
- self.inputs = tf.reshape(stacked_emb,
- shape=[tf.shape(stacked_emb)[0], self.embedding_model_params['embed_image_height'],
- self.embedding_model_params['embed_image_width'], 1])
-
- x = self.inputs
-
- if self.embedding_model_params['use_batchnorm']:
- x = _batchnorm(x, key='batchnorm_input', axis=3)
-
- if not self.embedding_model_params['dropout_embed'] is None:
- x = _dropout(x, rate=self.embedding_model_params['dropout_embed'])
-
- # Convolution layer
- x = tf.nn.conv2d(x, self.conv2d_W, [1, 1, 1, 1], padding='VALID')
-
- if self.embedding_model_params['use_batchnorm']:
- x = _batchnorm(x, key='batchnorm_conv', axis=3)
- else:
- # Batch normalization will cancel out bias, so only add bias term if not using batchnorm
- x = tf.nn.bias_add(x, self.conv2d_B)
-
- x = tf.nn.relu(x)
-
- if not self.embedding_model_params['dropout_conv'] is None:
- x = _dropout(x, rate=self.embedding_model_params['dropout_conv'])
-
- # Dense layer
- x = tf.reshape(x, shape=[tf.shape(x)[0], self.embedding_model_params['dense_dim']])
- x = tf.matmul(x, self.dense_W)
-
- if self.embedding_model_params['use_batchnorm']:
- # Initializing batchnorm vars for dense layer with shape=[1] will still broadcast over the shape of
- # the specified axis, e.g. dense shape = [?, k], batchnorm on axis 1 will create k batchnorm vars.
- # This is layer normalization rather than batch normalization, so adding a dimension to keep batchnorm,
- # thus dense shape = [?, k, 1], batchnorm on axis 2.
- x = tf.expand_dims(x, -1)
- x = _batchnorm(x, key='batchnorm_dense', axis=2)
- x = tf.squeeze(x, -1)
- else:
- x = tf.nn.bias_add(x, self.dense_B)
-
- # Note: Reference ConvE implementation had dropout on dense layer before applying batch normalization.
- # This can cause variance shift and reduce model performance, so have moved it after as recommended in:
- # https://arxiv.org/abs/1801.05134
- if not self.embedding_model_params['dropout_dense'] is None:
- x = _dropout(x, rate=self.embedding_model_params['dropout_dense'])
-
- x = tf.nn.relu(x)
- x = tf.matmul(x, tf.transpose(self.ent_emb))
-
- if self.embedding_model_params['use_bias']:
- x = tf.add(x, self.bias)
-
- self.scores = x
-
- return self.scores
-
- def get_embeddings(self, entities, embedding_type='entity'):
- """Get the embeddings of entities or relations.
-
- .. Note ::
- Use :meth:`ampligraph.utils.create_tensorboard_visualizations` to visualize the embeddings with TensorBoard.
-
- Parameters
- ----------
- entities : array-like, dtype=int, shape=[n]
- The entities (or relations) of interest. Element of the vector
- must be the original string literals, and not internal IDs.
- embedding_type : string
- If 'entity', ``entities`` argument will be considered as a list of knowledge graph entities (i.e. nodes).
- If set to 'relation', they will be treated as relation types instead (i.e. predicates).
-
- Returns
- -------
- embeddings : ndarray, shape [n, k]
- An array of k-dimensional embeddings.
-
- """
- if not self.is_fitted:
- msg = 'Model has not been fitted.'
- logger.error(msg)
- raise RuntimeError(msg)
-
- if embedding_type == 'entity':
- emb_list = self.trained_model_params['ent_emb']
- lookup_dict = self.ent_to_idx
- elif embedding_type == 'relation':
- emb_list = self.trained_model_params['rel_emb']
- lookup_dict = self.rel_to_idx
- else:
- msg = 'Invalid entity type: {}'.format(embedding_type)
- logger.error(msg)
- raise ValueError(msg)
-
- idxs = np.vectorize(lookup_dict.get)(entities)
- return emb_list[idxs]
-
- def fit(self, X, early_stopping=False, early_stopping_params={}):
- """Train a ConvE (with optional early stopping).
-
- The model is trained on a training set X using the training protocol
- described in :cite:`DettmersMS018`.
-
- Parameters
- ----------
- X : ndarray (shape [n, 3]) or object of ConvEDatasetAdapter
- Numpy array of training triples OR handle of Dataset adapter which would help retrieve data.
- early_stopping: bool
- Flag to enable early stopping (default:``False``)
- early_stopping_params: dictionary
- Dictionary of hyperparameters for the early stopping heuristics.
-
- The following string keys are supported:
-
- - **'x_valid'**: ndarray (shape [n, 3]) or object of AmpligraphDatasetAdapter :
- Numpy array of validation triples OR handle of Dataset adapter which
- would help retrieve data.
- - **'criteria'**: string : criteria for early stopping 'hits10', 'hits3', 'hits1' or 'mrr'(default).
- - **'x_filter'**: ndarray, shape [n, 3] : Positive triples to use as filter if a 'filtered' early
- stopping criteria is desired (i.e. filtered-MRR if 'criteria':'mrr').
- Note this will affect training time (no filter by default).
- If the filter has already been set in the adapter, pass True
- - **'burn_in'**: int : Number of epochs to pass before kicking in early stopping (default: 100).
- - **check_interval'**: int : Early stopping interval after burn-in (default:10).
- - **'stop_interval'**: int : Stop if criteria is performing worse over n consecutive checks (default: 3)
- - **'corruption_entities'**: List of entities to be used for corruptions. If 'all',
- it uses all entities (default: 'all')
- - **'corrupt_side'**: Specifies which side to corrupt. 'o' (default). Note: ConvE does not
- currently support subject corruptions in early stopping.
-
- Example: ``early_stopping_params={x_valid=X['valid'], 'criteria': 'mrr'}``
-
- """
-
- self.train_dataset_handle = None
- # try-except block is mainly to handle clean up in case of exception or manual stop in jupyter notebook
- try:
- if isinstance(X, np.ndarray):
- self.train_dataset_handle = OneToNDatasetAdapter(low_memory=self.low_memory)
- self.train_dataset_handle.set_data(X, 'train')
- elif isinstance(X, OneToNDatasetAdapter):
- self.train_dataset_handle = X
- else:
- msg = 'Invalid type for input X. Expected numpy.array or OneToNDatasetAdapter object, got {}'\
- .format(type(X))
- logger.error(msg)
- raise ValueError(msg)
-
- # create internal IDs mappings
- self.rel_to_idx, self.ent_to_idx = self.train_dataset_handle.generate_mappings()
-
- if len(self.ent_to_idx) > ENTITY_THRESHOLD:
- self.dealing_with_large_graphs = True
- prefetch_batches = 0
-
- logger.warning('Your graph has a large number of distinct entities. '
- 'Found {} distinct entities'.format(len(self.ent_to_idx)))
-
- logger.warning('Changing the variable initialization strategy.')
- logger.warning('Changing the strategy to use lazy loading of variables...')
-
- if early_stopping:
- raise Exception('Early stopping not supported for large graphs')
-
- if not isinstance(self.optimizer, SGDOptimizer):
- raise Exception("This mode works well only with SGD optimizer with decay (read docs for details). "
- "Kindly change the optimizer and restart the experiment")
-
- raise NotImplementedError('ConvE not implemented when dealing with large graphs.')
-
- self.train_dataset_handle.map_data()
-
- # This is useful when we re-fit the same model (e.g. retraining in model selection)
- if self.is_fitted:
- tf.reset_default_graph()
- self.rnd = check_random_state(self.seed)
- tf.random.set_random_seed(self.seed)
-
- self.sess_train = tf.Session(config=self.tf_config)
-
- batch_size = int(np.ceil(self.train_dataset_handle.get_size("train") / self.batches_count))
- self.batch_size = batch_size
-
- if len(self.ent_to_idx) > ENTITY_THRESHOLD:
- logger.warning('Only {} embeddings would be loaded in memory per batch...'.format(batch_size * 2))
-
- self._initialize_parameters()
-
- # Output mapping is dict of (s, p) to list of existing object triple indices
- self.output_mapping = self.train_dataset_handle.generate_output_mapping(dataset_type='train')
- self.train_dataset_handle.set_output_mapping(self.output_mapping)
- self.train_dataset_handle.generate_outputs(dataset_type='train', unique_pairs=True)
- train_iter = partial(self.train_dataset_handle.get_next_batch,
- batches_count=self.batches_count,
- dataset_type='train',
- use_filter=False,
- unique_pairs=True)
-
- dataset = tf.data.Dataset.from_generator(train_iter,
- output_types=(tf.int32, tf.float32),
- output_shapes=((None, 3), (None, len(self.ent_to_idx))))
- prefetch_batches = 5
- dataset = dataset.repeat().prefetch(prefetch_batches)
- dataset_iterator = dataset.make_one_shot_iterator()
-
- # init tf graph/dataflow for training
- # init variables (model parameters to be learned - i.e. the embeddings)
- if self.loss.get_state('require_same_size_pos_neg'):
- batch_size = batch_size * self.eta
-
- # Required for label smoothing
- self.loss._set_hyperparams('num_entities', len(self.ent_to_idx))
-
- loss = self._get_model_loss(dataset_iterator)
-
- # Add update_ops for batch normalization
- update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
- with tf.control_dependencies(update_ops):
- train = self.optimizer.minimize(loss)
-
- self.early_stopping_params = early_stopping_params
-
- # early stopping
- if early_stopping:
- self._initialize_early_stopping()
-
- self.sess_train.run(tf.tables_initializer())
- self.sess_train.run(tf.global_variables_initializer())
- self.sess_train.run(self.set_training_true)
-
- # Entity embeddings normalization
- normalize_ent_emb_op = self.ent_emb.assign(tf.clip_by_norm(self.ent_emb, clip_norm=1, axes=1))
- normalize_rel_emb_op = self.rel_emb.assign(tf.clip_by_norm(self.rel_emb, clip_norm=1, axes=1))
-
- if self.embedding_model_params.get('normalize_ent_emb', constants.DEFAULT_NORMALIZE_EMBEDDINGS):
- self.sess_train.run(normalize_rel_emb_op)
- self.sess_train.run(normalize_ent_emb_op)
-
- epoch_iterator_with_progress = tqdm(range(1, self.epochs + 1), disable=(not self.verbose), unit='epoch')
-
- for epoch in epoch_iterator_with_progress:
- losses = []
- for batch in range(1, self.batches_count + 1):
- feed_dict = {}
- self.optimizer.update_feed_dict(feed_dict, batch, epoch)
-
- loss_batch, _ = self.sess_train.run([loss, train], feed_dict=feed_dict)
-
- if np.isnan(loss_batch) or np.isinf(loss_batch):
- msg = 'Loss is {}. Please change the hyperparameters.'.format(loss_batch)
- logger.error(msg)
- raise ValueError(msg)
-
- losses.append(loss_batch)
- if self.embedding_model_params.get('normalize_ent_emb', constants.DEFAULT_NORMALIZE_EMBEDDINGS):
- self.sess_train.run(normalize_ent_emb_op)
-
- if self.verbose:
- msg = 'Average Loss: {:10f}'.format(sum(losses) / (batch_size * self.batches_count))
- if early_stopping and self.early_stopping_best_value is not None:
- msg += ' — Best validation ({}): {:5f}'.format(self.early_stopping_criteria,
- self.early_stopping_best_value)
-
- logger.debug(msg)
- epoch_iterator_with_progress.set_description(msg)
-
- if early_stopping:
-
- self.sess_train.run(self.set_training_false)
- if self._perform_early_stopping_test(epoch):
- self._end_training()
- return
- self.sess_train.run(self.set_training_true)
-
- self._save_trained_params()
- self._end_training()
- except BaseException as e:
- self._end_training()
- raise e
-
- def _initialize_eval_graph(self, mode='test'):
- """ Initialize the evaluation graph with the set protocol.
-
- Parameters
- ----------
- mode: string
- Indicates which data generator to use.
-
- Returns
- -------
-
- """
-
- logger.debug('Initializing eval graph [mode: {}]'.format(mode))
-
- test_generator = partial(self.eval_dataset_handle.get_next_batch,
- batches_count=-1,
- dataset_type=mode,
- use_filter=self.is_filtered,
- unique_pairs=False)
-
- dataset = tf.data.Dataset.from_generator(test_generator,
- output_types=(tf.int32, tf.float32),
- output_shapes=((None, 3), (None, len(self.ent_to_idx))))
-
- dataset = dataset.repeat()
- dataset = dataset.prefetch(5)
- dataset_iter = dataset.make_one_shot_iterator()
-
- self.X_test_tf, self.X_test_filter_tf = dataset_iter.get_next()
-
- e_s, e_p, e_o = self._lookup_embeddings(self.X_test_tf)
-
- # Scores for all triples
- scores = tf.sigmoid(tf.squeeze(self._fn(e_s, e_p, e_o)))
-
- # Score of positive triple
- self.score_positive = tf.gather(scores, indices=self.X_test_tf[:, 2])
-
- # Scores for positive triples
- self.scores_filtered = tf.boolean_mask(scores, tf.cast(self.X_test_filter_tf, tf.bool))
-
- # Triple rank over all triples
- self.total_rank = self.perform_comparision(scores, self.score_positive)
-
- # Triple rank over positive triples
- self.filter_rank = self.perform_comparision(self.scores_filtered, self.score_positive)
-
- # Rank of triple, with other positives filtered out.
- self.rank = tf.subtract(self.total_rank, self.filter_rank) + 1
-
- # NOTE: if in doubt about the rank calculation above, consider the case where the test triple
- # has the highest score: total_rank=1 and filter_rank=1, hence rank=1.
-
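# A minimal NumPy sketch of the filtered-rank arithmetic above (toy scores, illustrative only),
# assuming a >= comparison, i.e. the worst-rank tie policy: total_rank counts all candidates
# scoring at least as high as the positive, filter_rank counts the known positives among them,
# and the +1 restores the test triple itself.
import numpy as np

scores = np.array([0.9, 0.7, 0.6, 0.4])              # scores over all candidate objects
filter_mask = np.array([1, 0, 1, 0], dtype=bool)     # candidates that are known true triples (incl. the test triple)
score_positive = 0.6                                 # score of the test triple

total_rank = np.sum(scores >= score_positive)                  # 3
filter_rank = np.sum(scores[filter_mask] >= score_positive)    # 2
rank = total_rank - filter_rank + 1                            # 2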
- def _initialize_early_stopping(self):
- """Initializes and creates evaluation graph for early stopping.
- """
-
- try:
- self.x_valid = self.early_stopping_params['x_valid']
- except KeyError:
- msg = 'x_valid must be passed for early stopping.'
- logger.error(msg)
- raise KeyError(msg)
-
- # Set eval_dataset handler
- if isinstance(self.x_valid, np.ndarray):
-
- if self.x_valid.ndim <= 1 or (np.shape(self.x_valid)[1]) != 3:
- msg = 'Invalid size for input x_valid. Expected (n,3): got {}'.format(np.shape(self.x_valid))
- logger.error(msg)
- raise ValueError(msg)
-
- # store the validation data in the data handler
- self.train_dataset_handle.set_data(self.x_valid, 'valid')
- self.eval_dataset_handle = self.train_dataset_handle
- logger.debug('Initialized eval_dataset using the train_dataset handle.')
-
- elif isinstance(self.x_valid, OneToNDatasetAdapter):
-
- if not self.x_valid.data_exists('valid'):
- msg = 'Dataset `valid` has not been set in the DatasetAdapter.'
- logger.error(msg)
- raise ValueError(msg)
-
- self.eval_dataset_handle = self.x_valid
- logger.debug('Initialized eval_dataset from AmpligraphDatasetAdapter')
-
- else:
- msg = 'Invalid type for input x_valid. Expected np.ndarray or OneToNDatasetAdapter object, \
- got {}'.format(type(self.x_valid))
- logger.error(msg)
- raise ValueError(msg)
-
- self.early_stopping_criteria = self.early_stopping_params.get('criteria',
- constants.DEFAULT_CRITERIA_EARLY_STOPPING)
-
- if self.early_stopping_criteria not in ['hits10', 'hits1', 'hits3', 'mrr']:
- msg = 'Unsupported early stopping criteria.'
- logger.error(msg)
- raise ValueError(msg)
-
- self.eval_config['corrupt_side'] = self.early_stopping_params.get('corrupt_side',
- constants.DEFAULT_CORRUPT_SIDE_EVAL)
-
- if 's' in self.eval_config['corrupt_side']:
- msg = "ConvE does not support subject corruptions in early stopping. Please change to: 'o'"
- logger.error(msg)
- raise ValueError(msg)
-
- self.early_stopping_best_value = None
- self.early_stopping_stop_counter = 0
-
- # Set filter
- if 'x_filter' in self.early_stopping_params.keys():
-
- # If the filter has already been set in the dataset adapter then just pass x_filter = True
- x_filter = self.early_stopping_params['x_filter']
- if isinstance(x_filter, np.ndarray):
-
- if x_filter.ndim <= 1 or (np.shape(x_filter)[1]) != 3:
- msg = 'Invalid size for input x_filter. Expected (n,3): got {}'.format(np.shape(x_filter))
- logger.error(msg)
- raise ValueError(msg)
-
- # set the filter triples in the data handler
- x_filter = to_idx(x_filter, ent_to_idx=self.ent_to_idx, rel_to_idx=self.rel_to_idx)
- self.eval_dataset_handle.set_filter(x_filter, mapped_status=True)
-
- # set the flag to perform filtering
- self.set_filter_for_eval()
- else:
- logger.debug('x_filter not found in early_stopping_params.')
-
- # initialize evaluation graph in validation mode i.e. to use validation set
- self._initialize_eval_graph('valid')
-
- def predict(self, X, from_idx=False):
- """Predict the scores of triples using a trained embedding model.
- The function returns raw scores generated by the model.
-
- .. note::
-
- To obtain probability estimates, calibrate the model with :func:`~ConvE.calibrate`,
- then call :func:`~ConvE.predict_proba`.
-
- Parameters
- ----------
- X : ndarray, shape [n, 3]
- The triples to score.
- from_idx : bool
- If True, will skip conversion to internal IDs. (default: False).
-
- Returns
- -------
- scores_predict : ndarray, shape [n]
- The predicted scores for input triples X.
-
- """
- if not self.is_fitted:
- msg = 'Model has not been fitted.'
- logger.error(msg)
- raise RuntimeError(msg)
-
- tf.reset_default_graph()
- self._load_model_from_trained_params()
-
- dataset_handle = OneToNDatasetAdapter(low_memory=self.low_memory)
- dataset_handle.use_mappings(self.rel_to_idx, self.ent_to_idx)
- dataset_handle.set_data(X, "test", mapped_status=from_idx)
-
- # Note: one-hot outputs are not required for prediction, but are part of the batch function
- dataset_handle.set_output_mapping(self.output_mapping)
- dataset_handle.generate_outputs(dataset_type='test', unique_pairs=False)
- self.eval_dataset_handle = dataset_handle
-
- self.rnd = check_random_state(self.seed)
- tf.random.set_random_seed(self.seed)
- self._initialize_eval_graph()
-
- with tf.Session(config=self.tf_config) as sess:
-
- sess.run(tf.tables_initializer())
- sess.run(tf.global_variables_initializer())
- sess.run(self.set_training_false)
-
- scores = []
-
- for i in tqdm(range(self.eval_dataset_handle.get_size('test'))):
-
- score = sess.run(self.score_positive)
- scores.append(score[0])
-
- return scores
-
- def get_ranks(self, dataset_handle):
- """ Used by evaluate_predictions to get the ranks for evaluation.
-
- Parameters
- ----------
- dataset_handle : Object of AmpligraphDatasetAdapter
- This contains handles of the generators that would be used to get test triples and filters
-
- Returns
- -------
- ranks : ndarray, shape [n] or [n,2] depending on the corrupt_side used for evaluation.
- An array of ranks of test triples.
- """
-
- if not self.is_fitted:
- msg = 'Model has not been fitted.'
- logger.error(msg)
- raise RuntimeError(msg)
-
- eval_protocol = self.eval_config.get('corrupt_side', constants.DEFAULT_CORRUPT_SIDE_EVAL)
-
- if 'o' in eval_protocol:
- object_ranks = self._get_object_ranks(dataset_handle)
-
- if 's' in eval_protocol:
- subject_ranks = self._get_subject_ranks(dataset_handle)
-
- if eval_protocol == 's,o':
- ranks = [[s, o] for s, o in zip(subject_ranks, object_ranks)]
- elif eval_protocol == 's':
- ranks = subject_ranks
- elif eval_protocol == 'o':
- ranks = object_ranks
-
- return ranks
-
- def _get_object_ranks(self, dataset_handle):
- """ Internal function for obtaining object ranks.
-
- Parameters
- ----------
- dataset_handle : Object of AmpligraphDatasetAdapter
- This contains handles of the generators that would be used to get test triples and filters
-
- Returns
- -------
- ranks : ndarray, shape [n]
- An array of ranks of test triples.
- """
-
- self.eval_dataset_handle = dataset_handle
-
- # Load model parameters, build tf evaluation graph for predictions
- tf.reset_default_graph()
- self.rnd = check_random_state(self.seed)
- tf.random.set_random_seed(self.seed)
- self._load_model_from_trained_params()
-
- # Set the output mapping of the dataset handle - this is superseded if a filter has been set.
- dataset_handle.set_output_mapping(self.output_mapping)
-
- self._initialize_eval_graph()
-
- with tf.Session(config=self.tf_config) as sess:
-
- sess.run(tf.tables_initializer())
- sess.run(tf.global_variables_initializer())
- sess.run(self.set_training_false)
-
- ranks = []
- for _ in tqdm(range(self.eval_dataset_handle.get_size('test')), disable=(not self.verbose)):
- rank = sess.run(self.rank)
- ranks.append(rank)
-
- return np.array(ranks)
-
- def _initialize_eval_graph_subject(self, mode='test'):
- """ Initialize the graph for evaluating subject corruptions.
-
- Parameters
- ----------
- mode: string
- Indicates which data generator to use.
-
- Returns
- -------
-
- """
-
- logger.debug('Initializing eval graph for subject corruptions [mode: {}]'.format(mode))
-
- corruption_batch_size = constants.DEFAULT_SUBJECT_CORRUPTION_BATCH_SIZE
-
- test_generator = partial(self.eval_dataset_handle.get_next_batch_subject_corruptions,
- batch_size=corruption_batch_size,
- dataset_type=mode)
-
- dataset = tf.data.Dataset.from_generator(test_generator,
- output_types=(tf.int32, tf.int32, tf.float32),
- output_shapes=((None, 3), (None, 3), (None, len(self.ent_to_idx))))
-
- dataset = dataset.repeat()
- dataset = dataset.prefetch(5)
- dataset_iter = dataset.make_one_shot_iterator()
-
- self.X_test_tf, self.subject_corr, self.X_filter_tf = dataset_iter.get_next()
-
- e_s, e_p, e_o = self._lookup_embeddings(self.subject_corr)
-
- # Scores for all triples
- self.sigmoid_scores = tf.sigmoid(tf.squeeze(self._fn(e_s, e_p, e_o)))
-
- def _get_subject_ranks(self, dataset_handle, corruption_batch_size=None):
- """ Internal function for obtaining subject ranks.
-
- This function performs subject corruptions. Output layer scores are accumulated in order to rank
- subject corruptions. This can cause high memory consumption, so a default subject corruption batch size
- is set in constants.py.
-
- Parameters
- ----------
- dataset_handle : Object of AmpligraphDatasetAdapter
- This contains handles of the generators that would be used to get test triples and filters
- corruption_batch_size : int / None
- Batch size for accumulating output layer scores for each input. The accumulated scores
- will be an np.array of shape (corruption_batch_size, num_entities) and dtype np.float32.
- Default: 10000, as set in constants.DEFAULT_SUBJECT_CORRUPTION_BATCH_SIZE.
-
- Returns
- -------
- ranks : ndarray, shape [n]
- An array of ranks of test triples.
- """
-
- self.eval_dataset_handle = dataset_handle
-
- # Load model parameters, build tf evaluation graph for predictions
- tf.reset_default_graph()
- self.rnd = check_random_state(self.seed)
- tf.random.set_random_seed(self.seed)
- self._load_model_from_trained_params()
-
- # Set the output mapping of the dataset handle - this is superseded if a filter has been set.
- dataset_handle.set_output_mapping(self.output_mapping)
-
- self._initialize_eval_graph_subject()
-
- if not corruption_batch_size:
- corruption_batch_size = constants.DEFAULT_SUBJECT_CORRUPTION_BATCH_SIZE
-
- num_entities = len(self.ent_to_idx)
- num_batch_per_relation = np.ceil(len(self.eval_dataset_handle.ent_to_idx) / corruption_batch_size)
- num_batches = int(num_batch_per_relation * len(self.eval_dataset_handle.rel_to_idx))
-
- with tf.Session(config=self.tf_config) as sess:
-
- sess.run(tf.tables_initializer())
- sess.run(tf.global_variables_initializer())
- sess.run(self.set_training_false)
-
- ranks = []
- # Accumulate scores from each index of the object in the output scores while corrupting subject
- scores_matrix_accum = []
- # Accumulate true/false statements from one-hot outputs while corrupting subject
- scores_filter_accum = []
-
- for _ in tqdm(range(num_batches), disable=(not self.verbose), unit='batch'):
-
- try:
-
- X_test, scores_matrix, scores_filter = sess.run([self.X_test_tf, self.sigmoid_scores,
- self.X_filter_tf])
-
- # Accumulate scores from X_test columns
- scores_matrix_accum.append(scores_matrix[:, X_test[:, 2]])
- scores_filter_accum.append(scores_filter[:, X_test[:, 2]])
-
- num_rows_accum = np.sum([x.shape[0] for x in scores_matrix_accum])
-
- if num_rows_accum == num_entities:
- # When num rows accumulated equals num_entities, batch has finished a single subject corruption
- # loop on a single relation
-
- if len(X_test) == 0:
- # If X_test is empty, reset accumulated scores and continue
- scores_matrix_accum, scores_filter_accum = [], []
- continue
-
- scores_matrix = np.concatenate(scores_matrix_accum)
- scores_filter = np.concatenate(scores_filter_accum)
-
- for i, x in enumerate(X_test):
- score_positive = scores_matrix[x[0], i]
- idx_negatives = np.where(scores_filter[:, i] != 1)
- score_negatives = scores_matrix[idx_negatives[0], i]
- rank = np.sum(score_negatives >= score_positive) + 1
- ranks.append(rank)
-
- # Reset accumulators
- scores_matrix_accum, scores_filter_accum = [], []
-
- except StopIteration:
- break
-
- return np.array(ranks)
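# A NumPy sketch of the per-column filtered rank computed in the loop above (toy values,
# illustrative only): each column of the accumulated score matrix corresponds to one test
# triple, the filter column marks known true triples, and only unfiltered negatives scoring
# at least as high as the true subject contribute to the rank.
import numpy as np

scores_col = np.array([0.9, 0.2, 0.7, 0.4])   # scores of every candidate subject for one test triple
filter_col = np.array([1.0, 0.0, 1.0, 0.0])   # 1.0 where the candidate forms a known true triple
subject_idx = 2                               # x[0]: index of the true subject
score_positive = scores_col[subject_idx]                          # 0.7
idx_negatives = np.where(filter_col != 1)                         # candidates 1 and 3
rank = np.sum(scores_col[idx_negatives] >= score_positive) + 1    # no negative >= 0.7, so rank = 1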
diff --git a/ampligraph/latent_features/models/ConvKB.py b/ampligraph/latent_features/models/ConvKB.py
deleted file mode 100644
index de39671e..00000000
--- a/ampligraph/latent_features/models/ConvKB.py
+++ /dev/null
@@ -1,503 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-import numpy as np
-import tensorflow as tf
-import logging
-
-from .EmbeddingModel import EmbeddingModel, register_model, ENTITY_THRESHOLD
-from ..initializers import DEFAULT_XAVIER_IS_UNIFORM
-from ampligraph.latent_features import constants as constants
-import time
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
-
-@register_model("ConvKB", {'num_filters': 32, 'filter_sizes': [1], 'dropout': 0.1})
-class ConvKB(EmbeddingModel):
- r"""Convolution-based model
-
- The ConvKB model :cite:`Nguyen2018`:
-
- .. math::
-
- f_{ConvKB} = concat\,(g\,([\mathbf{e}_s, \mathbf{r}_p, \mathbf{e}_o] * \Omega)) \cdot W
-
- where :math:`g` is a non-linear function, :math:`*` is the convolution operator,
- :math:`\cdot` is the dot product, :math:`concat` is the concatenation operator
- and :math:`\Omega` is a set of filters.
-
- .. note::
- The evaluation protocol implemented in :meth:`ampligraph.evaluation.evaluate_performance` assigns the worst rank
- to a positive test triple in case of a tie with negatives. This is the agreed upon behaviour in literature.
- The original ConvKB implementation :cite:`Nguyen2018` assigns instead the top rank, hence leading to
- `results which are not directly comparable with
- literature `_ .
- We report results obtained with the agreed-upon protocol (tie=worst rank). Note that under these conditions
- the model :ref:`does not reach the state-of-the-art results claimed in the original paper`.
-
- Examples
- --------
- >>> from ampligraph.latent_features import ConvKB
- >>> from ampligraph.datasets import load_wn18
- >>> model = ConvKB(batches_count=2, seed=22, epochs=1, k=10, eta=1,
- >>> embedding_model_params={'num_filters': 32, 'filter_sizes': [1],
- >>> 'dropout': 0.1},
- >>> optimizer='adam', optimizer_params={'lr': 0.001},
- >>> loss='pairwise', loss_params={}, verbose=True)
- >>>
- >>> X = load_wn18()
- >>>
- >>> model.fit(X['train'])
- >>>
- >>> print(model.predict(X['test'][:5]))
- [[0.2803744], [0.0866661], [0.012815937], [-0.004235901], [-0.010947697]]
- """
-
- def __init__(self,
- k=constants.DEFAULT_EMBEDDING_SIZE,
- eta=constants.DEFAULT_ETA,
- epochs=constants.DEFAULT_EPOCH,
- batches_count=constants.DEFAULT_BATCH_COUNT,
- seed=constants.DEFAULT_SEED,
- embedding_model_params={'num_filters': 32,
- 'filter_sizes': [1],
- 'dropout': 0.1},
- optimizer=constants.DEFAULT_OPTIM,
- optimizer_params={'lr': constants.DEFAULT_LR},
- loss=constants.DEFAULT_LOSS,
- loss_params={},
- regularizer=constants.DEFAULT_REGULARIZER,
- regularizer_params={},
- initializer=constants.DEFAULT_INITIALIZER,
- initializer_params={'uniform': DEFAULT_XAVIER_IS_UNIFORM},
- large_graphs=False,
- verbose=constants.DEFAULT_VERBOSE):
- """Initialize an EmbeddingModel
-
- Parameters
- ----------
- k : int
- Embedding space dimensionality.
-
- eta : int
- The number of negatives that must be generated at runtime during training for each positive.
-
- epochs : int
- The iterations of the training loop.
-
- batches_count : int
- The number of batches in which the training set must be split during the training loop.
-
- seed : int
- The seed used by the internal random numbers generator.
-
- embedding_model_params : dict
- ConvKB-specific hyperparams:
- - **num_filters** - Number of feature maps per convolution kernel. Default: 32
- - **filter_sizes** - List of convolution kernel sizes. Default: [1]
- - **dropout** - Dropout on the embedding layer. Default: 0.1
- - **'non_linearity'**: can be one of the following values ``linear``, ``softplus``, ``sigmoid``, ``tanh``
- - **'stop_epoch'**: specifies for how many epochs to decay (linearly) the numeric values from 1
- to their original value.
- - **'structural_wt'**: structural influence hyperparameter [0, 1] that modulates the influence of graph
- topology.
- - **'normalize_numeric_values'**: normalize the numeric values, such that they are scaled between [0, 1]
-
- The last 4 parameters are related to FocusE layers.
-
- optimizer : string
- The optimizer used to minimize the loss function. Choose between
- 'sgd', 'adagrad', 'adam', 'momentum'.
-
- optimizer_params : dict
- Arguments specific to the optimizer, passed as a dictionary.
-
- Supported keys:
-
- - **'lr'** (float): learning rate (used by all the optimizers). Default: 0.1.
- - **'momentum'** (float): learning momentum (only used when ``optimizer=momentum``). Default: 0.9.
-
- Example: ``optimizer_params={'lr': 0.01}``
-
- loss : string
- The type of loss function to use during training.
-
- loss_params : dict
- Dictionary of loss-specific hyperparameters. See :ref:`loss functions `
- documentation for additional details.
-
- Supported keys depend on the loss chosen (e.g. ``'margin'`` if ``loss='pairwise'``).
-
- Example: ``loss_params={'margin': 5}`` if ``loss='pairwise'``.
-
- regularizer : string
- The regularization strategy to use with the loss function.
-
- - ``None``: the model will not use any regularizer (default)
- - ``LP``: the model will use L1, L2 or L3 based on the value of ``regularizer_params['p']`` (see below).
-
- regularizer_params : dict
- Dictionary of regularizer-specific hyperparameters. See the
- :ref:`regularizers `
- documentation for additional details.
-
- Example: ``regularizer_params={'lambda': 1e-5, 'p': 2}`` if ``regularizer='LP'``.
-
- initializer : string
- The type of initializer to use.
-
- - ``normal``: The embeddings will be initialized from a normal distribution
- - ``uniform``: The embeddings will be initialized from a uniform distribution
- - ``xavier``: The embeddings will be initialized using xavier strategy (default)
-
- initializer_params : dict
- Dictionary of initializer-specific hyperparameters. See the
- :ref:`initializer `
- documentation for additional details.
-
- Example: ``initializer_params={'mean': 0, 'std': 0.001}`` if ``initializer='normal'``.
-
- large_graphs : bool
- Avoid loading entire dataset onto GPU when dealing with large graphs.
-
- verbose : bool
- Verbose mode.
- """
-
- num_filters = embedding_model_params['num_filters']
- filter_sizes = embedding_model_params['filter_sizes']
-
- if isinstance(filter_sizes, int):
- filter_sizes = [filter_sizes]
-
- dense_dim = (k * len(filter_sizes) - sum(filter_sizes) + len(filter_sizes)) * num_filters
- embedding_model_params['dense_dim'] = dense_dim
- embedding_model_params['filter_sizes'] = filter_sizes
-
- super().__init__(k=k, eta=eta, epochs=epochs,
- batches_count=batches_count, seed=seed,
- embedding_model_params=embedding_model_params,
- optimizer=optimizer, optimizer_params=optimizer_params,
- loss=loss, loss_params=loss_params,
- regularizer=regularizer, regularizer_params=regularizer_params,
- initializer=initializer, initializer_params=initializer_params,
- large_graphs=large_graphs, verbose=verbose)
-
- def _initialize_parameters(self):
- """Initialize parameters of the model.
-
- This function creates and initializes entity and relation embeddings (with size k).
- If the graph is large, then it loads only the required entity embeddings (max:batch_size*2)
- and all relation embeddings.
- Overload this function if the parameters needs to be initialized differently.
- """
-
- with tf.variable_scope('meta'):
- self.tf_is_training = tf.Variable(False, trainable=False)
- self.set_training_true = tf.assign(self.tf_is_training, True)
- self.set_training_false = tf.assign(self.tf_is_training, False)
-
- timestamp = int(time.time() * 1e6)
- if not self.dealing_with_large_graphs:
-
- self.ent_emb = tf.get_variable('ent_emb_{}'.format(timestamp),
- shape=[len(self.ent_to_idx), self.k],
- initializer=self.initializer.get_entity_initializer(
- len(self.ent_to_idx), self.k), dtype=tf.float32)
- self.rel_emb = tf.get_variable('rel_emb_{}'.format(timestamp),
- shape=[len(self.rel_to_idx), self.k],
- initializer=self.initializer.get_relation_initializer(
- len(self.rel_to_idx), self.k), dtype=tf.float32)
-
- else:
-
- self.ent_emb = tf.get_variable('ent_emb_{}'.format(timestamp),
- shape=[self.batch_size * 2, self.internal_k],
- initializer=tf.zeros_initializer(), dtype=tf.float32)
-
- self.rel_emb = tf.get_variable('rel_emb_{}'.format(timestamp),
- shape=[len(self.rel_to_idx), self.internal_k],
- initializer=self.initializer.get_relation_initializer(
- len(self.rel_to_idx), self.internal_k), dtype=tf.float32)
-
- num_filters = self.embedding_model_params['num_filters']
- filter_sizes = self.embedding_model_params['filter_sizes']
- dense_dim = self.embedding_model_params['dense_dim']
- num_outputs = 1 # i.e. a single score
-
- self.conv_weights = {}
- for i, filter_size in enumerate(filter_sizes):
- conv_shape = [3, filter_size, 1, num_filters]
- conv_name = 'conv-maxpool-{}'.format(filter_size)
- weights_init = tf.initializers.truncated_normal(seed=self.seed)
- self.conv_weights[conv_name] = {'weights': tf.get_variable('{}_W_{}'.format(conv_name, timestamp),
- shape=conv_shape,
- trainable=True, dtype=tf.float32,
- initializer=weights_init),
- 'biases': tf.get_variable('{}_B_{}'.format(conv_name, timestamp),
- shape=[num_filters],
- trainable=True, dtype=tf.float32,
- initializer=tf.zeros_initializer())}
-
- self.dense_W = tf.get_variable('dense_weights_{}'.format(timestamp),
- shape=[dense_dim, num_outputs], trainable=True,
- initializer=tf.keras.initializers.he_normal(seed=self.seed),
- dtype=tf.float32)
- self.dense_B = tf.get_variable('dense_bias_{}'.format(timestamp),
- shape=[num_outputs], trainable=False,
- initializer=tf.zeros_initializer(), dtype=tf.float32)
-
- def get_embeddings(self, entities, embedding_type='entity'):
- """Get the embeddings of entities or relations.
-
- .. Note ::
- Use :meth:`ampligraph.utils.create_tensorboard_visualizations` to visualize the embeddings with TensorBoard.
-
- Parameters
- ----------
- entities : array-like, dtype=int, shape=[n]
- The entities (or relations) of interest. Elements of the vector must be the original string literals,
- not internal IDs.
- embedding_type : string
- If 'entity', ``entities`` argument will be considered as a list of knowledge graph entities (i.e. nodes).
- If set to 'relation', they will be treated as relation types instead (i.e. predicates).
-
- Returns
- -------
- embeddings : ndarray, shape [n, k]
- An array of k-dimensional embeddings.
-
- """
- if not self.is_fitted:
- msg = 'Model has not been fitted.'
- logger.error(msg)
- raise RuntimeError(msg)
-
- if embedding_type == 'entity':
- emb_list = self.trained_model_params['ent_emb']
- lookup_dict = self.ent_to_idx
- elif embedding_type == 'relation':
- emb_list = self.trained_model_params['rel_emb']
- lookup_dict = self.rel_to_idx
- else:
- msg = 'Invalid entity type: {}'.format(embedding_type)
- logger.error(msg)
- raise ValueError(msg)
-
- idxs = np.vectorize(lookup_dict.get)(entities)
- return emb_list[idxs]
-
- def _save_trained_params(self):
- """After model fitting, save all the trained parameters in trained_model_params in some order.
- The order would be useful for loading the model.
- This method must be overridden if the model has any other parameters (apart from entity-relation embeddings).
- """
-
- params_dict = {}
-
- if not self.dealing_with_large_graphs:
- params_dict['ent_emb'] = self.sess_train.run(self.ent_emb)
- else:
- params_dict['ent_emb'] = self.ent_emb_cpu
-
- params_dict['rel_emb'] = self.sess_train.run(self.rel_emb)
-
- params_dict['conv_weights'] = {}
- for name in self.conv_weights.keys():
- params_dict['conv_weights'][name] = {'weights': self.sess_train.run(self.conv_weights[name]['weights']),
- 'biases': self.sess_train.run(self.conv_weights[name]['biases'])}
-
- params_dict['dense_W'] = self.sess_train.run(self.dense_W)
- params_dict['dense_B'] = self.sess_train.run(self.dense_B)
- self.trained_model_params = params_dict
-
- def _load_model_from_trained_params(self):
- """Load the model from trained params.
- While restoring, make sure that the order of the loaded parameters matches the saved order.
- It's the duty of the embedding model to load the variables correctly.
- This method must be overridden if the model has any other parameters (apart from entity-relation embeddings)
- This function also sets the evaluation mode to do lazy loading of variables based on the number of
- distinct entities present in the graph.
- """
-
- # Generate the batch size based on entity length and batch_count
- self.batch_size = int(np.ceil(len(self.ent_to_idx) / self.batches_count))
-
- if len(self.ent_to_idx) > ENTITY_THRESHOLD:
- self.dealing_with_large_graphs = True
-
- logger.warning('Your graph has a large number of distinct entities. '
- 'Found {} distinct entities'.format(len(self.ent_to_idx)))
-
- logger.warning('Changing the variable loading strategy to use lazy loading of variables...')
- logger.warning('Evaluation would take longer than usual.')
-
- if not self.dealing_with_large_graphs:
- self.ent_emb = tf.Variable(self.trained_model_params['ent_emb'], dtype=tf.float32)
- else:
- self.ent_emb_cpu = self.trained_model_params['ent_emb']
- self.ent_emb = tf.Variable(np.zeros((self.batch_size, self.internal_k)), dtype=tf.float32)
-
- self.rel_emb = tf.Variable(self.trained_model_params['rel_emb'], dtype=tf.float32)
-
- with tf.variable_scope('meta'):
- self.tf_is_training = tf.Variable(False, trainable=False)
- self.set_training_true = tf.assign(self.tf_is_training, True)
- self.set_training_false = tf.assign(self.tf_is_training, False)
-
- self.conv_weights = {}
- for name in self.trained_model_params['conv_weights'].keys():
- W = self.trained_model_params['conv_weights'][name]['weights']
- B = self.trained_model_params['conv_weights'][name]['biases']
- self.conv_weights[name] = {'weights': tf.Variable(W, dtype=tf.float32),
- 'biases': tf.Variable(B, dtype=tf.float32)}
-
- self.dense_W = tf.Variable(self.trained_model_params['dense_W'], dtype=tf.float32)
- self.dense_B = tf.Variable(self.trained_model_params['dense_B'], dtype=tf.float32)
-
- def _fn(self, e_s, e_p, e_o):
- r"""The ConvKB scoring function.
-
- The function implements the scoring function as defined by:
- .. math::
-
- concat\,(g\,([\mathbf{e}_s, \mathbf{r}_p, \mathbf{e}_o] * \Omega)) \cdot W
-
- Additional details on the model are available in :cite:`Nguyen2018`.
-
-
- Parameters
- ----------
- e_s : Tensor, shape [n]
- The embeddings of a list of subjects.
- e_p : Tensor, shape [n]
- The embeddings of a list of predicates.
- e_o : Tensor, shape [n]
- The embeddings of a list of objects.
-
- Returns
- -------
- score : TensorFlow operation
- The operation corresponding to the ConvKB scoring function.
-
- """
-
- # Inputs
- e_s = tf.expand_dims(e_s, 1)
- e_p = tf.expand_dims(e_p, 1)
- e_o = tf.expand_dims(e_o, 1)
-
- self.inputs = tf.expand_dims(tf.concat([e_s, e_p, e_o], axis=1), -1)
-
- pooled_outputs = []
- for name in self.conv_weights.keys():
- x = tf.nn.conv2d(self.inputs, self.conv_weights[name]['weights'], [1, 1, 1, 1], padding='VALID')
- x = tf.nn.bias_add(x, self.conv_weights[name]['biases'])
- x = tf.nn.relu(x)
- pooled_outputs.append(x)
-
- # Combine all the pooled features
- x = tf.concat(pooled_outputs, 2)
- x = tf.reshape(x, [-1, self.embedding_model_params['dense_dim']])
-
- dropout_rate = tf.cond(self.tf_is_training,
- true_fn=lambda: tf.constant(self.embedding_model_params['dropout']),
- false_fn=lambda: tf.constant(0, dtype=tf.float32))
- x = tf.nn.dropout(x, rate=dropout_rate)
-
- self.scores = tf.nn.xw_plus_b(x, self.dense_W, self.dense_B)
-
- return tf.squeeze(self.scores)
-
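# A shape-only sketch of the scoring pipeline above (toy sizes, illustrative only): the stacked
# [e_s, e_p, e_o] input is convolved with a [3, fs, 1, num_filters] kernel per filter size
# ('VALID' padding), the feature maps are concatenated and flattened, and a dense layer maps
# them to a single score per triple.
n, k, num_filters, filter_sizes = 4, 10, 32, [1, 2]
inputs_shape = (n, 3, k, 1)
conv_out_shapes = [(n, 1, k - fs + 1, num_filters) for fs in filter_sizes]
dense_dim = sum(shape[2] for shape in conv_out_shapes) * num_filters
assert dense_dim == (k * len(filter_sizes) - sum(filter_sizes) + len(filter_sizes)) * num_filters
scores_shape = (n, 1)   # after xw_plus_b; tf.squeeze then yields one score per triple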
- def fit(self, X, early_stopping=False, early_stopping_params={}, focusE_numeric_edge_values=None,
- tensorboard_logs_path=None):
- """Train a ConvKB model (with optional early stopping).
-
- The model is trained on a training set X using the training protocol described in :cite:`trouillon2016complex`.
-
- Parameters
- ----------
- X : ndarray, shape [n, 3]
- The training triples
- early_stopping: bool
- Flag to enable early stopping (default:False).
-
- If set to ``True``, the training loop adopts the following early
- stopping heuristic:
-
- - The model will be trained regardless of early stopping for ``burn_in`` epochs.
- - Every ``check_interval`` epochs the method will compute the metric specified in ``criteria``.
-
- If such metric decreases for ``stop_interval`` checks, we stop
- training early.
-
- Note the metric is computed on ``x_valid``. This is usually a
- validation set that you held out.
-
- Also, because ``criteria`` is a ranking metric, it requires
- generating negatives.
- Entities used to generate corruptions can be specified, as well
- as the side(s) of a triple to corrupt.
- The method supports filtered metrics, by passing an array of
- positives to ``x_filter``. This will be used to
- filter the negatives generated on the fly (i.e. the corruptions).
-
- .. note::
-
- Keep in mind the early stopping criteria may introduce a
- certain overhead
- (caused by the metric computation).
- The goal is to strike a good trade-off between such overhead
- and saving training epochs.
-
- A common approach is to use MRR unfiltered: ::
-
- early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}
-
- Note the size of validation set also contributes to such
- overhead.
- In most cases a smaller validation set would be enough.
-
- early_stopping_params: dictionary
- Dictionary of hyperparameters for the early stopping heuristics.
-
- The following string keys are supported:
-
- - **'x_valid'**: ndarray, shape [n, 3] : Validation set to be used for early stopping.
- - **'criteria'**: string : criteria for early stopping 'hits10', 'hits3', 'hits1' or 'mrr'(default).
- - **'x_filter'**: ndarray, shape [n, 3] : Positive triples to use as filter if a 'filtered'
- early stopping criteria is desired (i.e. filtered-MRR if 'criteria':'mrr').
- Note this will affect training time (no filter by default).
- - **'burn_in'**: int : Number of epochs to pass before kicking in early stopping (default: 100).
- - **'check_interval'**: int : Early stopping interval after burn-in (default: 10).
- - **'stop_interval'**: int : Stop if criteria is performing worse over n consecutive checks (default: 3)
- - **'corruption_entities'**: List of entities to be used for corruptions.
- If 'all', it uses all entities (default: 'all')
- - **'corrupt_side'**: Specifies which side to corrupt. 's', 'o', 's+o' (default)
-
- Example: ``early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}``
-
- focusE_numeric_edge_values: ndarray, shape (n, 1)
- Numeric values associated with links.
- Semantically, the numeric value can signify importance, uncertainty, significance, confidence, etc.
- If the numeric value is unknown pass a NaN weight. The model will uniformly randomly assign a numeric value.
- One can also consider assigning numeric values based on their per-predicate distribution.
-
- tensorboard_logs_path: str or None
- Path to store tensorboard logs, e.g. average training loss tracking per epoch (default: ``None`` indicating
- no logs will be collected). When provided it will create a folder under provided path and save tensorboard
- files there. To then view the loss in the terminal run: ``tensorboard --logdir ``.
-
- """
- super().fit(X, early_stopping, early_stopping_params, focusE_numeric_edge_values,
- tensorboard_logs_path=tensorboard_logs_path)
diff --git a/ampligraph/latent_features/models/DistMult.py b/ampligraph/latent_features/models/DistMult.py
deleted file mode 100644
index 64b47b1f..00000000
--- a/ampligraph/latent_features/models/DistMult.py
+++ /dev/null
@@ -1,331 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-from .EmbeddingModel import EmbeddingModel, register_model
-from ampligraph.latent_features import constants as constants
-from ampligraph.latent_features.initializers import DEFAULT_XAVIER_IS_UNIFORM
-import tensorflow as tf
-
-
-@register_model("DistMult",
- ["normalize_ent_emb", "negative_corruption_entities"])
-class DistMult(EmbeddingModel):
- r"""The DistMult model
-
- The model as described in :cite:`yang2014embedding`.
-
- The bilinear diagonal DistMult model uses the trilinear dot product as scoring function:
-
- .. math::
-
- f_{DistMult}=\langle \mathbf{r}_p, \mathbf{e}_s, \mathbf{e}_o \rangle
-
- where :math:`\mathbf{e}_{s}` is the embedding of the subject, :math:`\mathbf{r}_{p}` the embedding
- of the predicate and :math:`\mathbf{e}_{o}` the embedding of the object.
-
- Examples
- --------
- >>> import numpy as np
- >>> from ampligraph.latent_features import DistMult
- >>> model = DistMult(batches_count=1, seed=555, epochs=20, k=10, loss='pairwise',
- >>> loss_params={'margin':5})
- >>> X = np.array([['a', 'y', 'b'],
- >>> ['b', 'y', 'a'],
- >>> ['a', 'y', 'c'],
- >>> ['c', 'y', 'a'],
- >>> ['a', 'y', 'd'],
- >>> ['c', 'y', 'd'],
- >>> ['b', 'y', 'c'],
- >>> ['f', 'y', 'e']])
- >>> model.fit(X)
- >>> model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
- [-0.13863425, -0.09917116]
- >>> model.get_embeddings(['f','e'], embedding_type='entity')
- array([[ 0.10137264, -0.28248304, 0.6153027 , -0.13133956, -0.11675504,
- -0.37876177, 0.06027773, -0.26390398, 0.254603 , 0.1888549 ],
- [-0.6467299 , -0.13729756, 0.3074872 , 0.16966867, -0.04098966,
- 0.25289047, -0.2212451 , -0.6527815 , 0.5657673 , -0.03876532]],
- dtype=float32)
-
- """
-
- def __init__(self,
- k=constants.DEFAULT_EMBEDDING_SIZE,
- eta=constants.DEFAULT_ETA,
- epochs=constants.DEFAULT_EPOCH,
- batches_count=constants.DEFAULT_BATCH_COUNT,
- seed=constants.DEFAULT_SEED,
- embedding_model_params={'normalize_ent_emb': constants.DEFAULT_NORMALIZE_EMBEDDINGS,
- 'negative_corruption_entities': constants.DEFAULT_CORRUPTION_ENTITIES,
- 'corrupt_sides': constants.DEFAULT_CORRUPT_SIDE_TRAIN},
- optimizer=constants.DEFAULT_OPTIM,
- optimizer_params={'lr': constants.DEFAULT_LR},
- loss=constants.DEFAULT_LOSS,
- loss_params={},
- regularizer=constants.DEFAULT_REGULARIZER,
- regularizer_params={},
- initializer=constants.DEFAULT_INITIALIZER,
- initializer_params={'uniform': DEFAULT_XAVIER_IS_UNIFORM},
- verbose=constants.DEFAULT_VERBOSE):
- """Initialize an EmbeddingModel
-
- Also creates a new Tensorflow session for training.
-
- Parameters
- ----------
- k : int
- Embedding space dimensionality
- eta : int
- The number of negatives that must be generated at runtime during training for each positive.
- epochs : int
- The iterations of the training loop.
- batches_count : int
- The number of batches in which the training set must be split during the training loop.
- seed : int
- The seed used by the internal random numbers generator.
- embedding_model_params : dict
- DistMult-specific hyperparams, passed to the model as a dictionary.
-
- Supported keys:
-
- - **'normalize_ent_emb'** (bool): flag to indicate whether to normalize entity embeddings
- after each batch update (default: False).
- - **'negative_corruption_entities'** - Entities to be used for generation of corruptions while training.
- It can take the following values :
- ``all`` (default: all entities),
- ``batch`` (entities present in each batch),
- list of entities
- or an int (which indicates how many entities that should be used for corruption generation).
- - **corrupt_sides** : Specifies how to generate corruptions for training.
- Takes values `s`, `o`, `s+o` or any combination passed as a list
- - **'non_linearity'**: can be one of the following values ``linear``, ``softplus``, ``sigmoid``, ``tanh``
- - **'stop_epoch'**: specifies for how many epochs to decay (linearly) the numeric values from 1
- to their original value.
- - **'structural_wt'**: structural influence hyperparameter [0, 1] that modulates the influence of graph
- topology.
- - **'normalize_numeric_values'**: normalize the numeric values, such that they are scaled between [0, 1]
-
- The last 4 parameters are related to FocusE layers.
-
- Example: ``embedding_model_params={'normalize_ent_emb': False}``
-
- optimizer : string
- The optimizer used to minimize the loss function. Choose between 'sgd',
- 'adagrad', 'adam', 'momentum'.
-
- optimizer_params : dict
- Arguments specific to the optimizer, passed as a dictionary.
-
- Supported keys:
-
- - **'lr'** (float): learning rate (used by all the optimizers). Default: 0.1.
- - **'momentum'** (float): learning momentum (only used when ``optimizer=momentum``). Default: 0.9.
-
- Example: ``optimizer_params={'lr': 0.01}``
-
- loss : string
- The type of loss function to use during training.
-
- - ``pairwise`` the model will use pairwise margin-based loss function.
- - ``nll`` the model will use negative log-likelihood loss.
- - ``absolute_margin`` the model will use absolute margin loss.
- - ``self_adversarial`` the model will use adversarial sampling loss function.
- - ``multiclass_nll`` the model will use multiclass nll loss.
- Switch to multiclass loss defined in :cite:`chen2015` by passing 'corrupt_sides'
- as ['s','o'] to embedding_model_params.
- To use loss defined in :cite:`kadlecBK17` pass 'corrupt_sides' as 'o' to embedding_model_params.
-
- loss_params : dict
- Dictionary of loss-specific hyperparameters. See :ref:`loss functions `
- documentation for additional details.
-
- Example: ``optimizer_params={'lr': 0.01}`` if ``loss='pairwise'``.
-
- regularizer : string
- The regularization strategy to use with the loss function.
-
- - ``None``: the model will not use any regularizer (default)
- - 'LP': the model will use L1, L2 or L3 based on the value of ``regularizer_params['p']`` (see below).
-
- regularizer_params : dict
- Dictionary of regularizer-specific hyperparameters. See the :ref:`regularizers `
- documentation for additional details.
-
- Example: ``regularizer_params={'lambda': 1e-5, 'p': 2}`` if ``regularizer='LP'``.
-
- initializer : string
- The type of initializer to use.
-
- - ``normal``: The embeddings will be initialized from a normal distribution
- - ``uniform``: The embeddings will be initialized from a uniform distribution
- - ``xavier``: The embeddings will be initialized using xavier strategy (default)
-
- initializer_params : dict
- Dictionary of initializer-specific hyperparameters. See the
- :ref:`initializer `
- documentation for additional details.
-
- Example: ``initializer_params={'mean': 0, 'std': 0.001}`` if ``initializer='normal'``.
-
- verbose : bool
- Verbose mode.
- """
- super().__init__(k=k, eta=eta, epochs=epochs, batches_count=batches_count, seed=seed,
- embedding_model_params=embedding_model_params,
- optimizer=optimizer, optimizer_params=optimizer_params,
- loss=loss, loss_params=loss_params,
- regularizer=regularizer, regularizer_params=regularizer_params,
- initializer=initializer, initializer_params=initializer_params,
- verbose=verbose)
-
- def _fn(self, e_s, e_p, e_o):
- r"""DistMult
-
- .. math::
-
- f_{DistMult}=\langle \mathbf{r}_p, \mathbf{e}_s, \mathbf{e}_o \rangle
-
-
- Parameters
- ----------
- e_s : Tensor, shape [n]
- The embeddings of a list of subjects.
- e_p : Tensor, shape [n]
- The embeddings of a list of predicates.
- e_o : Tensor, shape [n]
- The embeddings of a list of objects.
-
- Returns
- -------
- score : TensorFlow operation
- The operation corresponding to the DistMult scoring function.
-
- """
-
- return tf.reduce_sum(e_s * e_p * e_o, axis=1)
-
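# A minimal NumPy sketch of the trilinear dot product above (toy embeddings, illustrative only):
import numpy as np

e_s = np.array([[0.1, 0.2], [0.3, 0.4]])   # subject embeddings, shape [n, k]
e_p = np.array([[1.0, 0.5], [0.2, 0.1]])   # predicate embeddings
e_o = np.array([[0.4, 0.3], [0.6, 0.5]])   # object embeddings
scores = np.sum(e_s * e_p * e_o, axis=1)   # element-wise product summed over k: [0.07, 0.056]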
- def fit(self, X, early_stopping=False, early_stopping_params={},
- focusE_numeric_edge_values=None, tensorboard_logs_path=None):
- """Train an DistMult.
-
- The model is trained on a training set X using the training protocol
- described in :cite:`trouillon2016complex`.
-
- Parameters
- ----------
- X : ndarray, shape [n, 3]
- The training triples
- early_stopping: bool
- Flag to enable early stopping (default:False).
-
- If set to ``True``, the training loop adopts the following early stopping heuristic:
-
- - The model will be trained regardless of early stopping for ``burn_in`` epochs.
- - Every ``check_interval`` epochs the method will compute the metric specified in ``criteria``.
-
- If such metric decreases for ``stop_interval`` checks, we stop training early.
-
- Note the metric is computed on ``x_valid``. This is usually a validation set that you held out.
-
- Also, because ``criteria`` is a ranking metric, it requires generating negatives.
- Entities used to generate corruptions can be specified, as well as the side(s) of a triple to corrupt.
- The method supports filtered metrics, by passing an array of positives to ``x_filter``. This will be used to
- filter the negatives generated on the fly (i.e. the corruptions).
-
- .. note::
-
- Keep in mind the early stopping criteria may introduce a certain overhead
- (caused by the metric computation).
- The goal is to strike a good trade-off between such overhead and saving training epochs.
-
- A common approach is to use MRR unfiltered: ::
-
- early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}
-
- Note the size of validation set also contributes to such overhead.
- In most cases a smaller validation set would be enough.
-
- early_stopping_params: dictionary
- Dictionary of hyperparameters for the early stopping heuristics.
-
- The following string keys are supported:
-
- - **'x_valid'**: ndarray, shape [n, 3] : Validation set to be used for early stopping.
- - **'criteria'**: string : criteria for early stopping 'hits10', 'hits3', 'hits1' or 'mrr'(default).
- - **'x_filter'**: ndarray, shape [n, 3] : Positive triples to use as filter if a 'filtered'
- early stopping criteria is desired (i.e. filtered-MRR if 'criteria':'mrr').
- Note this will affect training time (no filter by default).
- - **'burn_in'**: int : Number of epochs to pass before kicking in early stopping (default: 100).
- - **'check_interval'**: int : Early stopping interval after burn-in (default: 10).
- - **'stop_interval'**: int : Stop if criteria is performing worse over n consecutive checks (default: 3)
- - **'corruption_entities'**: List of entities to be used for corruptions.
- If 'all', it uses all entities (default: 'all')
- - **'corrupt_side'**: Specifies which side to corrupt. 's', 'o', 's+o' (default)
-
- Example: ``early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}``
-
- focusE_numeric_edge_values: ndarray, shape [n]
- .. _focuse_distmult:
-
- If processing a knowledge graph with numeric values associated with links, this is the vector of such
- numbers. Passing this argument will activate the :ref:`FocusE layer `
- :cite:`pai2021learning`.
- Semantically, numeric values can signify importance, uncertainty, significance, confidence, etc.
- Values can be any number, and will be automatically normalised to the [0, 1] range, on a
- predicate-specific basis.
- If the numeric value is unknown pass a ``np.NaN`` value.
- The model will uniformly randomly assign a numeric value.
-
- .. note::
-
- The following toy example shows how to enable the FocusE layer
- to process edges with numeric literals: ::
-
- import numpy as np
- from ampligraph.latent_features import DistMult
- model = DistMult(batches_count=1, seed=555, epochs=20,
- k=10, loss='pairwise',
- loss_params={'margin':5})
- X = np.array([['a', 'y', 'b'],
- ['b', 'y', 'a'],
- ['a', 'y', 'c'],
- ['c', 'y', 'a'],
- ['a', 'y', 'd'],
- ['c', 'y', 'd'],
- ['b', 'y', 'c'],
- ['f', 'y', 'e']])
-
- # Numeric values below are associated with each triple in X.
- # They can be any number and will be automatically
- # normalised to the [0, 1] range, on a
- # predicate-specific basis.
- X_edge_values = np.array([5.34, -1.75, 0.33, 5.12,
- np.nan, 3.17, 2.76, 0.41])
-
- model.fit(X, focusE_numeric_edge_values=X_edge_values)
-
-
- tensorboard_logs_path: str or None
- Path to store tensorboard logs, e.g. average training loss tracking per epoch (default: ``None`` indicating
- no logs will be collected). When provided it will create a folder under provided path and save tensorboard
- files there. To then view the loss in the terminal run: ``tensorboard --logdir ``.
-
- """
- super().fit(X, early_stopping, early_stopping_params, focusE_numeric_edge_values,
- tensorboard_logs_path=tensorboard_logs_path)
-
- def predict(self, X, from_idx=False):
- __doc__ = super().predict.__doc__ # NOQA
- return super().predict(X, from_idx=from_idx)
-
- def calibrate(self, X_pos, X_neg=None, positive_base_rate=None, batches_count=100, epochs=50):
- __doc__ = super().calibrate.__doc__ # NOQA
- super().calibrate(X_pos, X_neg, positive_base_rate, batches_count, epochs)
-
- def predict_proba(self, X):
- __doc__ = super().predict_proba.__doc__ # NOQA
- return super().predict_proba(X)
diff --git a/ampligraph/latent_features/models/EmbeddingModel.py b/ampligraph/latent_features/models/EmbeddingModel.py
deleted file mode 100644
index 5fd8206a..00000000
--- a/ampligraph/latent_features/models/EmbeddingModel.py
+++ /dev/null
@@ -1,2128 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-import numpy as np
-import tensorflow as tf
-from sklearn.utils import check_random_state
-import abc
-from tqdm import tqdm
-import logging
-from ampligraph.latent_features.loss_functions import LOSS_REGISTRY
-from ampligraph.latent_features.regularizers import REGULARIZER_REGISTRY
-from ampligraph.latent_features.optimizers import OPTIMIZER_REGISTRY, SGDOptimizer
-from ampligraph.latent_features.initializers import INITIALIZER_REGISTRY, DEFAULT_XAVIER_IS_UNIFORM
-from ampligraph.evaluation import generate_corruptions_for_fit, to_idx, generate_corruptions_for_eval, \
- hits_at_n_score, mrr_score
-from ampligraph.datasets import AmpligraphDatasetAdapter, NumpyDatasetAdapter
-from functools import partial
-from ampligraph.latent_features import constants as constants
-import time
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
-MODEL_REGISTRY = {}
-
-ENTITY_THRESHOLD = 5e5
-
-
-def set_entity_threshold(threshold):
- """Sets the entity threshold (threshold after which large graph mode is initiated)
- """
- global ENTITY_THRESHOLD
- ENTITY_THRESHOLD = threshold
-
-
-def reset_entity_threshold():
- """Resets the entity threshold
- """
- global ENTITY_THRESHOLD
- ENTITY_THRESHOLD = 5e5
-
-
-def register_model(name, external_params=None, class_params=None):
- if external_params is None:
- external_params = []
- if class_params is None:
- class_params = {}
-
- def insert_in_registry(class_handle):
- MODEL_REGISTRY[name] = class_handle
- class_handle.name = name
- MODEL_REGISTRY[name].external_params = external_params
- MODEL_REGISTRY[name].class_params = class_params
- return class_handle
-
- return insert_in_registry
-
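# A minimal usage sketch of the registry mechanism above (hypothetical toy registry and model
# name, not the library's MODEL_REGISTRY; illustrative only):
TOY_REGISTRY = {}

def toy_register_model(name, external_params=None):
    def insert_in_registry(class_handle):
        TOY_REGISTRY[name] = class_handle
        class_handle.name = name
        class_handle.external_params = external_params or []
        return class_handle
    return insert_in_registry

@toy_register_model("MyToyModel", external_params=["normalize_ent_emb"])
class MyToyModel:
    pass

assert TOY_REGISTRY["MyToyModel"] is MyToyModel and MyToyModel.name == "MyToyModel"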
-
-@tf.custom_gradient
-def custom_softplus(x):
- e = 9999 * tf.exp(x)
-
- def grad(dy):
- return dy * (1 - 1 / (1 + e))
- return tf.math.log(1 + e), grad
-
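# An illustrative numerical check of custom_softplus (toy value, not from the source):
# log(1 + 9999*e^x) equals softplus(x + log(9999)), and the custom gradient
# 1 - 1/(1 + 9999*e^x) is the corresponding sigmoid term.
import numpy as np

x, c = 0.5, 9999.0
forward = np.log(1.0 + c * np.exp(x))
grad = 1.0 - 1.0 / (1.0 + c * np.exp(x))
assert np.isclose(forward, np.logaddexp(0.0, x + np.log(c)))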
-
-class EmbeddingModel(abc.ABC):
- """Abstract class for embedding models
-
- AmpliGraph neural knowledge graph embeddings models extend this class and
- its core methods.
-
- """
-
- def __init__(self,
- k=constants.DEFAULT_EMBEDDING_SIZE,
- eta=constants.DEFAULT_ETA,
- epochs=constants.DEFAULT_EPOCH,
- batches_count=constants.DEFAULT_BATCH_COUNT,
- seed=constants.DEFAULT_SEED,
- embedding_model_params={},
- optimizer=constants.DEFAULT_OPTIM,
- optimizer_params={'lr': constants.DEFAULT_LR},
- loss=constants.DEFAULT_LOSS,
- loss_params={},
- regularizer=constants.DEFAULT_REGULARIZER,
- regularizer_params={},
- initializer=constants.DEFAULT_INITIALIZER,
- initializer_params={'uniform': DEFAULT_XAVIER_IS_UNIFORM},
- large_graphs=False,
- verbose=constants.DEFAULT_VERBOSE):
- """Initialize an EmbeddingModel
-
- Also creates a new Tensorflow session for training.
-
- Parameters
- ----------
- k : int
- Embedding space dimensionality.
-
- eta : int
- The number of negatives that must be generated at runtime during training for each positive.
-
- epochs : int
- The iterations of the training loop.
-
- batches_count : int
- The number of batches in which the training set must be split during the training loop.
-
- seed : int
- The seed used by the internal random numbers generator.
-
- embedding_model_params : dict
- Model-specific hyperparams, passed to the model as a dictionary.
- Refer to model-specific documentation for details.
-
- For FocusE Layer, following hyper-params can be passed:
-
- - **'non_linearity'**: can be one of the following values ``linear``, ``softplus``, ``sigmoid``, ``tanh``
- - **'stop_epoch'**: specifies for how many epochs to decay (linearly) the numeric values from 1
- to their original value.
- - **'structural_wt'**: structural influence hyperparameter [0, 1] that modulates the influence of graph
- topology.
- - **'normalize_numeric_values'**: normalize the numeric values, such that they are scaled between [0, 1]
-
- optimizer : string
- The optimizer used to minimize the loss function. Choose between
- 'sgd', 'adagrad', 'adam', 'momentum'.
-
- optimizer_params : dict
- Arguments specific to the optimizer, passed as a dictionary.
-
- Supported keys:
-
- - **'lr'** (float): learning rate (used by all the optimizers). Default: 0.1.
- - **'momentum'** (float): learning momentum (only used when ``optimizer=momentum``). Default: 0.9.
-
- Example: ``optimizer_params={'lr': 0.01}``
-
- loss : string
- The type of loss function to use during training.
-
- - ``pairwise`` the model will use pairwise margin-based loss function.
- - ``nll`` the model will use negative log-likelihood loss.
- - ``absolute_margin`` the model will use absolute margin loss.
- - ``self_adversarial`` the model will use adversarial sampling loss function.
- - ``multiclass_nll`` the model will use multiclass nll loss. Switch to multiclass loss defined in
- :cite:`chen2015` by passing 'corrupt_side' as ['s','o'] to embedding_model_params.
- To use loss defined in :cite:`kadlecBK17` pass 'corrupt_side' as 'o' to embedding_model_params.
-
- loss_params : dict
- Dictionary of loss-specific hyperparameters. See :ref:`loss
- functions `
- documentation for additional details.
-
- Example: ``optimizer_params={'lr': 0.01}`` if ``loss='pairwise'``.
-
- regularizer : string
- The regularization strategy to use with the loss function.
-
- - ``None``: the model will not use any regularizer (default)
- - ``LP``: the model will use L1, L2 or L3 based on the value of ``regularizer_params['p']`` (see below).
-
- regularizer_params : dict
- Dictionary of regularizer-specific hyperparameters. See the
- :ref:`regularizers `
- documentation for additional details.
-
- Example: ``regularizer_params={'lambda': 1e-5, 'p': 2}`` if ``regularizer='LP'``.
-
- initializer : string
- The type of initializer to use.
-
- - ``normal``: The embeddings will be initialized from a normal distribution
- - ``uniform``: The embeddings will be initialized from a uniform distribution
- - ``xavier``: The embeddings will be initialized using xavier strategy (default)
-
- initializer_params : dict
- Dictionary of initializer-specific hyperparameters. See the
- :ref:`initializer `
- documentation for additional details.
-
- Example: ``initializer_params={'mean': 0, 'std': 0.001}`` if ``initializer='normal'``.
-
- large_graphs : bool
- Avoid loading entire dataset onto GPU when dealing with large graphs.
-
- verbose : bool
- Verbose mode.
- """
- if (loss == "bce") ^ (self.name == "ConvE"):
- raise ValueError('Invalid Model - Loss combination. '
- 'ConvE model can be used with BCE loss only and vice versa.')
-
- # Store for restoring later.
- self.all_params = \
- {
- 'k': k,
- 'eta': eta,
- 'epochs': epochs,
- 'batches_count': batches_count,
- 'seed': seed,
- 'embedding_model_params': embedding_model_params,
- 'optimizer': optimizer,
- 'optimizer_params': optimizer_params,
- 'loss': loss,
- 'loss_params': loss_params,
- 'regularizer': regularizer,
- 'regularizer_params': regularizer_params,
- 'initializer': initializer,
- 'initializer_params': initializer_params,
- 'verbose': verbose
- }
- tf.reset_default_graph()
- self.seed = seed
- self.rnd = check_random_state(self.seed)
- tf.random.set_random_seed(seed)
-
- self.is_filtered = False
- self.use_focusE = False
- self.loss_params = loss_params
-
- self.embedding_model_params = embedding_model_params
-
- self.k = k
- self.internal_k = k
- self.epochs = epochs
- self.eta = eta
- self.regularizer_params = regularizer_params
- self.batches_count = batches_count
-
- self.dealing_with_large_graphs = large_graphs
-
- if batches_count == 1:
- logger.warning(
- 'All triples will be processed in the same batch (batches_count=1). '
- 'When processing large graphs it is recommended to batch the input knowledge graph instead.')
-
- try:
- self.loss = LOSS_REGISTRY[loss](self.eta, self.loss_params, verbose=verbose)
- except KeyError:
- msg = 'Unsupported loss function: {}'.format(loss)
- logger.error(msg)
- raise ValueError(msg)
-
- try:
- if regularizer is not None:
- self.regularizer = REGULARIZER_REGISTRY[regularizer](self.regularizer_params, verbose=verbose)
- else:
- self.regularizer = regularizer
- except KeyError:
- msg = 'Unsupported regularizer: {}'.format(regularizer)
- logger.error(msg)
- raise ValueError(msg)
-
- self.optimizer_params = optimizer_params
-
- try:
- self.optimizer = OPTIMIZER_REGISTRY[optimizer](self.optimizer_params,
- self.batches_count,
- verbose)
- except KeyError:
- msg = 'Unsupported optimizer: {}'.format(optimizer)
- logger.error(msg)
- raise ValueError(msg)
-
- self.verbose = verbose
-
- self.initializer_params = initializer_params
-
- try:
- self.initializer = INITIALIZER_REGISTRY[initializer](self.initializer_params,
- verbose,
- self.rnd)
- except KeyError:
- msg = 'Unsupported initializer: {}'.format(initializer)
- logger.error(msg)
- raise ValueError(msg)
-
- self.tf_config = tf.ConfigProto(allow_soft_placement=True)
- self.tf_config.gpu_options.allow_growth = True
- self.sess_train = None
- self.trained_model_params = []
- self.is_fitted = False
- self.eval_config = {}
- self.eval_dataset_handle = None
- self.train_dataset_handle = None
- self.is_calibrated = False
- self.calibration_parameters = []
-
- @abc.abstractmethod
- def _fn(self, e_s, e_p, e_o):
- """The scoring function of the model.
-
- Assigns a score to a list of triples, with a model-specific strategy.
- Triples are passed as lists of subject, predicate, object embeddings.
- This function must be overridden by every model to return corresponding score.
-
- Parameters
- ----------
- e_s : Tensor, shape [n]
- The embeddings of a list of subjects.
- e_p : Tensor, shape [n]
- The embeddings of a list of predicates.
- e_o : Tensor, shape [n]
- The embeddings of a list of objects.
-
- Returns
- -------
- score : TensorFlow operation
- The operation corresponding to the scoring function.
-
- """
- logger.error('_fn is a placeholder function in an abstract class')
- raise NotImplementedError("This function is a placeholder in an abstract class")
-
- def get_hyperparameter_dict(self):
- """Returns hyperparameters of the model.
-
- Returns
- -------
- hyperparam_dict : dict
- Dictionary of hyperparameters that were used for training.
-
- """
- return self.all_params
-
- def get_embedding_model_params(self, output_dict):
- """Save the model parameters in the dictionary.
-
- Parameters
- ----------
- output_dict : dictionary
- Dictionary of saved params.
- It's the duty of the model to save all the variables correctly, so that it can be used for restoring later.
-
- """
- output_dict['model_params'] = self.trained_model_params
- output_dict['large_graph'] = self.dealing_with_large_graphs
- output_dict['calibration_parameters'] = self.calibration_parameters
-
- def restore_model_params(self, in_dict):
- """Load the model parameters from the input dictionary.
-
- Parameters
- ----------
- in_dict : dictionary
- Dictionary of saved params. It's the duty of the model to load the variables correctly.
- """
-
- self.trained_model_params = in_dict['model_params']
-
- # Try catch is for backward compatibility
- try:
- self.calibration_parameters = in_dict['calibration_parameters']
- except KeyError:
- # For backward compatibility
- self.calibration_parameters = []
-
- # Try catch is for backward compatibility
- try:
- self.dealing_with_large_graphs = in_dict['large_graph']
- except KeyError:
- # For backward compatibility
- self.dealing_with_large_graphs = False
-
- def _save_trained_params(self):
- """After model fitting, save all the trained parameters in trained_model_params in some order.
- The order would be useful for loading the model.
- This method must be overridden if the model has any other parameters (apart from entity-relation embeddings).
- """
- params_to_save = []
- if not self.dealing_with_large_graphs:
- params_to_save.append(self.sess_train.run(self.ent_emb))
- else:
- params_to_save.append(self.ent_emb_cpu)
-
- params_to_save.append(self.sess_train.run(self.rel_emb))
-
- self.trained_model_params = params_to_save
-
- def _load_model_from_trained_params(self):
- """Load the model from trained params.
- While restoring, make sure that the order of the loaded parameters matches the saved order.
- It's the duty of the embedding model to load the variables correctly.
- This method must be overridden if the model has any other parameters (apart from entity-relation embeddings).
- This function also sets the evaluation mode to do lazy loading of variables based on the number of
- distinct entities present in the graph.
- """
-
- # Generate the batch size based on entity length and batch_count
- self.batch_size = int(np.ceil(len(self.ent_to_idx) / self.batches_count))
-
- if len(self.ent_to_idx) > ENTITY_THRESHOLD:
- self.dealing_with_large_graphs = True
-
- logger.warning('Your graph has a large number of distinct entities. '
- 'Found {} distinct entities'.format(len(self.ent_to_idx)))
-
- logger.warning('Changing the variable loading strategy to use lazy loading of variables...')
- logger.warning('Evaluation would take longer than usual.')
-
- if not self.dealing_with_large_graphs:
- # (We use tf.variable for future - to load and continue training)
- self.ent_emb = tf.Variable(self.trained_model_params[0], dtype=tf.float32)
- else:
- # Embeddings of all the corruptions entities will not fit on GPU.
- # During training we loaded batch_size*2 embeddings on GPU as only 2* batch_size unique
- # entities can be present in one batch.
- # During corruption generation in eval mode, one side (s/o) is fixed and only the other side varies.
- # Hence we use a batch size of 2 * training_batch_size for corruption generation i.e. those many
- # corruption embeddings would be loaded per batch on the GPU. In other words, those corruptions
- # would be processed as a batch.
-
- self.corr_batch_size = self.batch_size * 2
-
- # Load the entity embeddings on the cpu
- self.ent_emb_cpu = self.trained_model_params[0]
- # (We use tf.variable for future - to load and continue training)
- # create empty variable on GPU.
- # we initialize it with zeros because the actual embeddings will be loaded on the fly.
- self.ent_emb = tf.Variable(np.zeros((self.corr_batch_size, self.internal_k)), dtype=tf.float32)
-
- # (We use tf.variable for future - to load and continue training)
- self.rel_emb = tf.Variable(self.trained_model_params[1], dtype=tf.float32)
-
- def get_embeddings(self, entities, embedding_type='entity'):
- """Get the embeddings of entities or relations.
-
- .. Note ::
- Use :meth:`ampligraph.utils.create_tensorboard_visualizations` to visualize the embeddings with TensorBoard.
-
- Parameters
- ----------
- entities : array-like, dtype=int, shape=[n]
- The entities (or relations) of interest. Elements of the vector must be the original string literals, and
- not internal IDs.
- embedding_type : string
- If 'entity', ``entities`` argument will be considered as a list of knowledge graph entities (i.e. nodes).
- If set to 'relation', they will be treated as relation types instead (i.e. predicates).
-
- Returns
- -------
- embeddings : ndarray, shape [n, k]
- An array of k-dimensional embeddings.
-
- """
- if not self.is_fitted:
- msg = 'Model has not been fitted.'
- logger.error(msg)
- raise RuntimeError(msg)
-
- if embedding_type == 'entity':
- emb_list = self.trained_model_params[0]
- lookup_dict = self.ent_to_idx
- elif embedding_type == 'relation':
- emb_list = self.trained_model_params[1]
- lookup_dict = self.rel_to_idx
- else:
- msg = 'Invalid embedding type: {}'.format(embedding_type)
- logger.error(msg)
- raise ValueError(msg)
-
- idxs = np.vectorize(lookup_dict.get)(entities)
- return emb_list[idxs]
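- # Illustrative usage sketch, assuming a fitted model `model` and entities 'entity_a' / 'entity_b'
- # that appear in the training set (placeholder names, not part of the original source):
- #   embs = model.get_embeddings(np.array(['entity_a', 'entity_b']), embedding_type='entity')
- #   embs.shape  # expected to be (2, k)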
-
- def _lookup_embeddings(self, x, get_weight=False):
- """Get the embeddings for subjects, predicates, and objects of a list of statements used to train the model.
-
- Parameters
- ----------
- x : tensor, shape [n, 3]
- A tensor of n triples, given as internal indices of subject, predicate and object.
-
- Returns
- -------
- e_s : Tensor
- A Tensor that includes the embeddings of the subjects.
- e_p : Tensor
- A Tensor that includes the embeddings of the predicates.
- e_o : Tensor
- A Tensor that includes the embeddings of the objects.
- """
- e_s = self._entity_lookup(x[:, 0])
- e_p = tf.nn.embedding_lookup(self.rel_emb, x[:, 1])
- e_o = self._entity_lookup(x[:, 2])
-
- if get_weight:
- wt = self.weight_triple[
- self.batch_number * self.batch_size:(self.batch_number + 1) * self.batch_size]
-
- return e_s, e_p, e_o, wt
- return e_s, e_p, e_o
-
- def _entity_lookup(self, entity):
- """Get the embeddings for entities.
- Remaps the entity indices to corresponding variables in the GPU memory when dealing with large graphs.
-
- Parameters
- ----------
- entity : nd-tensor, shape [n, 1]
- Indices of the entities whose embeddings must be looked up.
- 
- Returns
- -------
- emb : Tensor
- A Tensor that includes the embeddings of the entities.
- """
-
- if self.dealing_with_large_graphs:
- remapping = self.sparse_mappings.lookup(entity)
- else:
- remapping = entity
-
- emb = tf.nn.embedding_lookup(self.ent_emb, remapping)
- return emb
-
- def _initialize_parameters(self):
- """Initialize parameters of the model.
-
- This function creates and initializes entity and relation embeddings (with size k).
- If the graph is large, then it loads only the required entity embeddings (max: batch_size*2)
- and all relation embeddings.
- Overload this function if the parameters need to be initialized differently.
- """
- timestamp = int(time.time() * 1e6)
- if not self.dealing_with_large_graphs:
- self.ent_emb = tf.get_variable('ent_emb_{}'.format(timestamp),
- shape=[len(self.ent_to_idx), self.internal_k],
- initializer=self.initializer.get_entity_initializer(
- len(self.ent_to_idx), self.internal_k),
- dtype=tf.float32)
- self.rel_emb = tf.get_variable('rel_emb_{}'.format(timestamp),
- shape=[len(self.rel_to_idx), self.internal_k],
- initializer=self.initializer.get_relation_initializer(
- len(self.rel_to_idx), self.internal_k),
- dtype=tf.float32)
- else:
- # initialize entity embeddings to zero (these are reinitialized every batch by batch embeddings)
- self.ent_emb = tf.get_variable('ent_emb_{}'.format(timestamp),
- shape=[self.batch_size * 2, self.internal_k],
- initializer=tf.zeros_initializer(),
- dtype=tf.float32)
- self.rel_emb = tf.get_variable('rel_emb_{}'.format(timestamp),
- shape=[len(self.rel_to_idx), self.internal_k],
- initializer=self.initializer.get_relation_initializer(
- len(self.rel_to_idx), self.internal_k),
- dtype=tf.float32)
-
- def _get_model_loss(self, dataset_iterator):
- """Get the current loss including loss due to regularization.
- This function must be overridden if the model uses a combination of different losses (e.g. VAE).
-
- Parameters
- ----------
- dataset_iterator : tf.data.Iterator
- Dataset iterator.
-
- Returns
- -------
- loss : tf.Tensor
- The loss value that must be minimized.
- """
- self.epoch = tf.placeholder(tf.float32)
- self.batch_number = tf.placeholder(tf.int32)
-
- if self.use_focusE:
- x_pos_tf, self.unique_entities, ent_emb_batch, weights = dataset_iterator.get_next()
-
- else:
- # get the train triples of the batch, unique entities and the corresponding embeddings
- # the latter 2 variables are passed only for large graphs.
- x_pos_tf, self.unique_entities, ent_emb_batch = dataset_iterator.get_next()
-
- # list of dependent ops that need to be evaluated before computing the loss
- dependencies = []
-
- # if the graph is large
- if self.dealing_with_large_graphs:
- # Create a dependency to load the embeddings of the batch entities dynamically
- init_ent_emb_batch = self.ent_emb.assign(ent_emb_batch, use_locking=True)
- dependencies.append(init_ent_emb_batch)
-
- # create a lookup dependency (to remap the entity indices to the corresponding indices of variables in memory)
- self.sparse_mappings = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int32, value_dtype=tf.int32,
- default_value=-1, empty_key=-2,
- deleted_key=-1)
-
- insert_lookup_op = self.sparse_mappings.insert(self.unique_entities,
- tf.reshape(tf.range(tf.shape(self.unique_entities)[0],
- dtype=tf.int32), (-1, 1)))
-
- dependencies.append(insert_lookup_op)
-
- # run the dependencies
- with tf.control_dependencies(dependencies):
- entities_size = 0
- entities_list = None
-
- x_pos = x_pos_tf
-
- e_s_pos, e_p_pos, e_o_pos = self._lookup_embeddings(x_pos)
-
- scores_pos = self._fn(e_s_pos, e_p_pos, e_o_pos)
-
- non_linearity = self.embedding_model_params.get('non_linearity', 'linear')
- if non_linearity == 'linear':
- scores_pos = scores_pos
- elif non_linearity == 'tanh':
- scores_pos = tf.tanh(scores_pos)
- elif non_linearity == 'sigmoid':
- scores_pos = tf.sigmoid(scores_pos)
- elif non_linearity == 'softplus':
- scores_pos = custom_softplus(scores_pos)
- else:
- raise ValueError('Invalid non-linearity')
-
- if self.use_focusE:
-
- epoch_before_stopping_weight = self.embedding_model_params.get('stop_epoch', 251)
- assert epoch_before_stopping_weight >= 0, "Invalid value for stop_epoch"
-
- if epoch_before_stopping_weight == 0:
- # use fixed structural weight
- structure_weight = self.embedding_model_params.get('structural_wt', 0.001)
- assert structure_weight <= 1 and structure_weight >= 0, \
- "Invalid structure_weight passed to model params!"
-
- else:
- # decay of numeric values
- # start with all triples having same numeric values and linearly decay till original value
- structure_weight = tf.maximum(1 - self.epoch / epoch_before_stopping_weight, 0.001)
-
- weights = tf.reduce_mean(weights, 1)
- weights_pos = structure_weight + (1 - structure_weight) * (1 - weights)
- weights_neg = structure_weight + (1 - structure_weight) * (
- tf.reshape(tf.tile(weights, [self.eta]), [tf.shape(weights)[0] * self.eta]))
-
- scores_pos = scores_pos * weights_pos
-
- if self.loss.get_state('require_same_size_pos_neg'):
- logger.debug('Requires the same size of positive and negative')
- scores_pos = tf.reshape(tf.tile(scores_pos, [self.eta]), [tf.shape(scores_pos)[0] * self.eta])
-
- # look up embeddings from input training triples
- negative_corruption_entities = self.embedding_model_params.get('negative_corruption_entities',
- constants.DEFAULT_CORRUPTION_ENTITIES)
-
- if negative_corruption_entities == 'all':
- '''
- if the number of entities is large then in this case ('all'),
- the corruptions are generated from the batch entities and additional random entities that
- are selected from all entities (since a total of batch_size*2 entity embeddings are loaded in memory)
- '''
- logger.debug('Using all entities for generation of corruptions during training')
- if self.dealing_with_large_graphs:
- entities_list = tf.squeeze(self.unique_entities)
- else:
- entities_size = tf.shape(self.ent_emb)[0]
- elif negative_corruption_entities == 'batch':
- # default is batch (entities_size=0 and entities_list=None)
- logger.debug('Using batch entities for generation of corruptions during training')
- elif isinstance(negative_corruption_entities, list):
- logger.debug('Using the supplied entities for generation of corruptions during training')
- entities_list = tf.squeeze(tf.constant(np.asarray([idx for uri, idx in self.ent_to_idx.items()
- if uri in negative_corruption_entities]),
- dtype=tf.int32))
- elif isinstance(negative_corruption_entities, int):
- logger.debug('Using first {} entities for generation of corruptions during \
- training'.format(negative_corruption_entities))
- entities_size = negative_corruption_entities
-
- loss = 0
- corruption_sides = self.embedding_model_params.get('corrupt_side', constants.DEFAULT_CORRUPT_SIDE_TRAIN)
- if not isinstance(corruption_sides, list):
- corruption_sides = [corruption_sides]
-
- for side in corruption_sides:
- # Generate the corruptions
- x_neg_tf = generate_corruptions_for_fit(x_pos_tf,
- entities_list=entities_list,
- eta=self.eta,
- corrupt_side=side,
- entities_size=entities_size,
- rnd=self.seed)
-
- # compute corruption scores
- e_s_neg, e_p_neg, e_o_neg = self._lookup_embeddings(x_neg_tf)
- scores_neg = self._fn(e_s_neg, e_p_neg, e_o_neg)
-
- if non_linearity == 'linear':
- scores_neg = scores_neg
- elif non_linearity == 'tanh':
- scores_neg = tf.tanh(scores_neg)
- elif non_linearity == 'sigmoid':
- scores_neg = tf.sigmoid(scores_neg)
- elif non_linearity == 'softplus':
- scores_neg = custom_softplus(scores_neg)
- else:
- raise ValueError('Invalid non-linearity')
-
- if self.use_focusE:
- scores_neg = scores_neg * weights_neg
-
- # Apply the loss function
- loss += self.loss.apply(scores_pos, scores_neg)
-
- if self.regularizer is not None:
- # Apply the regularizer
- loss += self.regularizer.apply([self.ent_emb, self.rel_emb])
-
- return loss
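- # Summary sketch of the computation above (descriptive comment only): for each corrupted side,
- # loss accumulates self.loss.apply(scores_pos, scores_neg); when a regularizer is configured,
- # regularizer.apply([ent_emb, rel_emb]) is added on top. FocusE re-weighting and the configured
- # non-linearity are applied to the raw scores before the loss is computed.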
-
- def _initialize_early_stopping(self):
- """Initializes and creates evaluation graph for early stopping.
- """
- try:
- self.x_valid = self.early_stopping_params['x_valid']
-
- if isinstance(self.x_valid, np.ndarray):
- if self.x_valid.ndim <= 1 or (np.shape(self.x_valid)[1]) != 3:
- msg = 'Invalid size for input x_valid. Expected (n,3): got {}'.format(np.shape(self.x_valid))
- logger.error(msg)
- raise ValueError(msg)
-
- # store the validation data in the data handler
- self.x_valid = to_idx(self.x_valid, ent_to_idx=self.ent_to_idx, rel_to_idx=self.rel_to_idx)
- self.train_dataset_handle.set_data(self.x_valid, "valid", mapped_status=True)
- self.eval_dataset_handle = self.train_dataset_handle
-
- elif isinstance(self.x_valid, AmpligraphDatasetAdapter):
- # this assumes that the validation data has already been set in the adapter
- self.eval_dataset_handle = self.x_valid
- else:
- msg = 'Invalid type for input x_valid. Expected ndarray/AmpligraphDatasetAdapter object, \
- got {}'.format(type(self.x_valid))
- logger.error(msg)
- raise ValueError(msg)
- except KeyError:
- msg = 'x_valid must be passed for early stopping.'
- logger.error(msg)
- raise KeyError(msg)
-
- self.early_stopping_criteria = self.early_stopping_params.get(
- 'criteria', constants.DEFAULT_CRITERIA_EARLY_STOPPING)
- if self.early_stopping_criteria not in ['hits10', 'hits1', 'hits3',
- 'mrr']:
- msg = 'Unsupported early stopping criteria.'
- logger.error(msg)
- raise ValueError(msg)
-
- self.eval_config['corruption_entities'] = self.early_stopping_params.get('corruption_entities',
- constants.DEFAULT_CORRUPTION_ENTITIES)
-
- if isinstance(self.eval_config['corruption_entities'], list):
- # convert from list of raw triples to entity indices
- logger.debug('Using the supplied entities for generation of corruptions for early stopping')
- self.eval_config['corruption_entities'] = np.asarray([idx for uri, idx in self.ent_to_idx.items()
- if uri in self.eval_config['corruption_entities']])
- elif self.eval_config['corruption_entities'] == 'all':
- logger.debug('Using all entities for generation of corruptions for early stopping')
- elif self.eval_config['corruption_entities'] == 'batch':
- logger.debug('Using batch entities for generation of corruptions for early stopping')
-
- self.eval_config['corrupt_side'] = self.early_stopping_params.get('corrupt_side',
- constants.DEFAULT_CORRUPT_SIDE_EVAL)
-
- self.early_stopping_best_value = None
- self.early_stopping_stop_counter = 0
- self.early_stopping_epoch = None
-
- try:
- # If the filter has already been set in the dataset adapter then just pass x_filter = True
- x_filter = self.early_stopping_params['x_filter']
- if isinstance(x_filter, np.ndarray):
- if x_filter.ndim <= 1 or (np.shape(x_filter)[1]) != 3:
- msg = 'Invalid size for input x_filter. Expected (n,3): got {}'.format(np.shape(x_filter))
- logger.error(msg)
- raise ValueError(msg)
- # set the filter triples in the data handler
- x_filter = to_idx(x_filter, ent_to_idx=self.ent_to_idx, rel_to_idx=self.rel_to_idx)
- self.eval_dataset_handle.set_filter(x_filter, mapped_status=True)
- # set the flag to perform filtering
- self.set_filter_for_eval()
- except KeyError:
- logger.debug('x_filter not found in early_stopping_params.')
- pass
-
- # initialize evaluation graph in validation mode i.e. to use validation set
- self._initialize_eval_graph("valid")
-
- def _perform_early_stopping_test(self, epoch):
- """Performs regular validation checks and stop early if the criteria is achieved.
-
- Parameters
- ----------
- epoch : int
- current training epoch.
- Returns
- -------
- stopped: bool
- Flag to indicate if the early stopping criterion has been met.
- """
-
- if epoch >= self.early_stopping_params.get('burn_in',
- constants.DEFAULT_BURN_IN_EARLY_STOPPING) \
- and epoch % self.early_stopping_params.get('check_interval',
- constants.DEFAULT_CHECK_INTERVAL_EARLY_STOPPING) == 0:
- # compute the ranks on the validation set
- ranks = []
-
- # Get each triple and compute the rank for that triple
- for x_test_triple in range(self.eval_dataset_handle.get_size("valid")):
- rank_triple = self.sess_train.run(self.rank)
- if self.eval_config.get('corrupt_side', constants.DEFAULT_CORRUPT_SIDE_EVAL) == 's,o':
- ranks.append(list(rank_triple))
- else:
- ranks.append(rank_triple)
-
- if self.early_stopping_criteria == 'hits10':
- current_test_value = hits_at_n_score(ranks, 10)
- elif self.early_stopping_criteria == 'hits3':
- current_test_value = hits_at_n_score(ranks, 3)
- elif self.early_stopping_criteria == 'hits1':
- current_test_value = hits_at_n_score(ranks, 1)
- elif self.early_stopping_criteria == 'mrr':
- current_test_value = mrr_score(ranks)
-
- if self.tensorboard_logs_path is not None:
- tag = "Early stopping {} current value".format(self.early_stopping_criteria)
- summary = tf.Summary(value=[tf.Summary.Value(tag=tag,
- simple_value=current_test_value)])
- self.writer.add_summary(summary, epoch)
-
- if self.early_stopping_best_value is None: # First validation iteration
- self.early_stopping_best_value = current_test_value
- self.early_stopping_first_value = current_test_value
- elif self.early_stopping_best_value >= current_test_value:
- self.early_stopping_stop_counter += 1
- if self.early_stopping_stop_counter == self.early_stopping_params.get(
- 'stop_interval', constants.DEFAULT_STOP_INTERVAL_EARLY_STOPPING):
-
- # If the best value for the criteria has not changed from
- # initial value then
- # save the model before early stopping
- if self.early_stopping_best_value == self.early_stopping_first_value:
- self._save_trained_params()
-
- if self.verbose:
- msg = 'Early stopping at epoch:{}'.format(epoch)
- logger.info(msg)
- msg = 'Best {}: {:10f}'.format(
- self.early_stopping_criteria,
- self.early_stopping_best_value)
- logger.info(msg)
-
- self.early_stopping_epoch = epoch
-
- return True
- else:
- self.early_stopping_best_value = current_test_value
- self.early_stopping_stop_counter = 0
- self._save_trained_params()
-
- if self.verbose:
- msg = 'Current best:{}'.format(self.early_stopping_best_value)
- logger.debug(msg)
- msg = 'Current:{}'.format(current_test_value)
- logger.debug(msg)
-
- return False
-
- def _end_training(self):
- """Performs clean up tasks after training.
- """
- # Reset this variable as it is reused during evaluation phase
- if self.is_filtered and self.eval_dataset_handle is not None:
- # cleanup the evaluation data (deletion of tables)
- self.eval_dataset_handle.cleanup()
- self.eval_dataset_handle = None
-
- if self.train_dataset_handle is not None:
- self.train_dataset_handle.cleanup()
- self.train_dataset_handle = None
-
- self.is_filtered = False
- self.eval_config = {}
-
- # close the tf session
- if self.sess_train is not None:
- self.sess_train.close()
-
- # set is_fitted to true to indicate that the model fitting is completed
- self.is_fitted = True
-
- def _training_data_generator(self):
- """Generates the training data.
- If we are dealing with large graphs, then along with the training triples (of the batch),
- this method returns the idx of the entities present in the batch (along with filler entities
- sampled randomly from the rest (not in batch) to load batch_size*2 entities on the GPU) and their embeddings.
- """
-
- all_ent = np.int32(np.arange(len(self.ent_to_idx)))
- unique_entities = all_ent.reshape(-1, 1)
- # generate empty embeddings for smaller graphs - as all the entity embeddings will be loaded on GPU
- entity_embeddings = np.empty(shape=(0, self.internal_k), dtype=np.float32)
- # create iterator to iterate over the train batches
- batch_iterator = iter(self.train_dataset_handle.get_next_batch(self.batches_count, "train"))
- for i in range(self.batches_count):
- out = next(batch_iterator)
-
- out_triples = out[0]
- if self.use_focusE:
- out_weights = out[1]
-
- # If large graph, load batch_size*2 entities on GPU memory
- if self.dealing_with_large_graphs:
- # find the unique entities - these HAVE to be loaded
- unique_entities = np.int32(np.unique(np.concatenate([out_triples[:, 0],
- out_triples[:, 2]],
- axis=0)))
- # Load the remaining entities by randomly selecting from the rest of the entities
- self.leftover_entities = self.rnd.permutation(np.setdiff1d(all_ent, unique_entities))
- needed = (self.batch_size * 2 - unique_entities.shape[0])
- '''
- #this is for debugging
- large_number = np.zeros((self.batch_size-unique_entities.shape[0],
- self.ent_emb_cpu.shape[1]), dtype=np.float32) + np.nan
-
- entity_embeddings = np.concatenate((self.ent_emb_cpu[unique_entities,:],
- large_number), axis=0)
- '''
- unique_entities = np.int32(np.concatenate([unique_entities, self.leftover_entities[:needed]], axis=0))
- entity_embeddings = self.ent_emb_cpu[unique_entities, :]
-
- unique_entities = unique_entities.reshape(-1, 1)
-
- if self.use_focusE:
- for col_idx in range(out_weights.shape[1]):
- # random weights are used where weights are unknown
- nan_indices = np.isnan(out_weights[:, col_idx])
- out_weights[nan_indices, col_idx] = np.random.uniform(size=(np.sum(nan_indices)))
-
- out_weights = np.mean(out_weights, 1)
- out_weights = out_weights[:, np.newaxis]
- yield out_triples, unique_entities, entity_embeddings, out_weights
- else:
- yield np.squeeze(out_triples), unique_entities, entity_embeddings
-
- def fit(self, X, early_stopping=False, early_stopping_params={}, focusE_numeric_edge_values=None,
- tensorboard_logs_path=None):
- """Train an EmbeddingModel (with optional early stopping).
-
- The model is trained on a training set X using the training protocol
- described in :cite:`trouillon2016complex`.
-
- Parameters
- ----------
- X : ndarray (shape [n, 3]) or object of AmpligraphDatasetAdapter
- Numpy array of training triples OR handle of Dataset adapter which would help retrieve data.
- early_stopping: bool
- Flag to enable early stopping (default:``False``)
- early_stopping_params: dictionary
- Dictionary of hyperparameters for the early stopping heuristics.
-
- The following string keys are supported:
-
- - **'x_valid'**: ndarray (shape [n, 3]) or object of AmpligraphDatasetAdapter :
- Numpy array of validation triples OR handle of Dataset adapter which
- would help retrieve data.
- - **'criteria'**: string : criteria for early stopping 'hits10', 'hits3', 'hits1' or 'mrr'(default).
- - **'x_filter'**: ndarray, shape [n, 3] : Positive triples to use as filter if a 'filtered' early
- stopping criteria is desired (i.e. filtered-MRR if 'criteria':'mrr').
- Note this will affect training time (no filter by default).
- If the filter has already been set in the adapter, pass True
- - **'burn_in'**: int : Number of epochs to pass before kicking in early stopping (default: 100).
- - **'check_interval'**: int : Early stopping interval after burn-in (default: 10).
- - **'stop_interval'**: int : Stop if criteria is performing worse over n consecutive checks (default: 3)
- - **'corruption_entities'**: List of entities to be used for corruptions. If 'all',
- it uses all entities (default: 'all')
- - **'corrupt_side'**: Specifies which side to corrupt. 's', 'o', 's+o', 's,o' (default)
-
- Example: ``early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}``
- focusE_numeric_edge_values: nd array (n, 1)
- Numeric values associated with links.
- Semantically, the numeric value can signify importance, uncertainty, significance, confidence, etc.
- If the numeric value is unknown, pass a NaN weight. The model will uniformly randomly assign a numeric value.
- One can also think about assigning numeric values by looking at their distribution per predicate.
-
- tensorboard_logs_path: str or None
- Path to store tensorboard logs, e.g. average training loss tracking per epoch (default: ``None`` indicating
- no logs will be collected). When provided, it will create a folder under the provided path and save tensorboard
- files there. To then view the loss in the terminal run: ``tensorboard --logdir ``.
-
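- Example (illustrative sketch; ``model``, ``X_train`` and ``X_valid`` are placeholder names for a
- model instance and for train/validation triple arrays):
- 
- >>> model.fit(X_train,
- >>>           early_stopping=True,
- >>>           early_stopping_params={'x_valid': X_valid, 'criteria': 'mrr'})
- 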
- """
- self.train_dataset_handle = None
- self.tensorboard_logs_path = tensorboard_logs_path
- # try-except block is mainly to handle clean up in case of exception or manual stop in jupyter notebook
- try:
- if isinstance(X, np.ndarray):
- if focusE_numeric_edge_values is not None:
- logger.debug("Using FocusE")
- self.use_focusE = True
- assert focusE_numeric_edge_values.shape[0] == X.shape[0], \
- "Each triple must have a numeric value (the size of the training set does not match the size" \
- "of the focusE_numeric_edge_values argument."
-
- if focusE_numeric_edge_values.ndim == 1:
- focusE_numeric_edge_values = focusE_numeric_edge_values.reshape(-1, 1)
-
- logger.debug("normalizing numeric values")
- unique_relations = np.unique(X[:, 1])
- for reln in unique_relations:
- for col_idx in range(focusE_numeric_edge_values.shape[1]):
- # here nans signify unknown numeric values
- if np.sum(np.isnan(
- focusE_numeric_edge_values[X[:, 1] == reln,
- col_idx])) != focusE_numeric_edge_values[
- X[:, 1] == reln, col_idx].shape[0]:
- min_val = np.nanmin(focusE_numeric_edge_values[X[:, 1] == reln,
- col_idx])
- max_val = np.nanmax(focusE_numeric_edge_values[X[:, 1] == reln,
- col_idx])
- if min_val == max_val:
- focusE_numeric_edge_values[X[:, 1] == reln, col_idx] = 1.0
- continue
-
- if self.embedding_model_params.get('normalize_numeric_values', True) \
- or min_val < 0 or max_val > 1:
- focusE_numeric_edge_values[X[:, 1] == reln, col_idx] = (
- focusE_numeric_edge_values[X[:, 1] == reln,
- col_idx] - min_val) / (
- max_val - min_val)
- else:
- pass # all the weights are nans
-
- # Adapt the numpy data in the internal format - to generalize
- self.train_dataset_handle = NumpyDatasetAdapter()
- self.train_dataset_handle.set_data(X, "train", focusE_numeric_edge_values=focusE_numeric_edge_values)
- elif isinstance(X, AmpligraphDatasetAdapter):
- self.train_dataset_handle = X
- else:
- msg = 'Invalid type for input X. Expected ndarray/AmpligraphDataset object, got {}'.format(type(X))
- logger.error(msg)
- raise ValueError(msg)
-
- # create internal IDs mappings
- self.rel_to_idx, self.ent_to_idx = self.train_dataset_handle.generate_mappings()
- prefetch_batches = 1
-
- if len(self.ent_to_idx) > ENTITY_THRESHOLD:
- self.dealing_with_large_graphs = True
-
- logger.warning('Your graph has a large number of distinct entities. '
- 'Found {} distinct entities'.format(len(self.ent_to_idx)))
-
- logger.warning('Changing the variable initialization strategy.')
- logger.warning('Changing the strategy to use lazy loading of variables...')
-
- if early_stopping:
- raise Exception('Early stopping not supported for large graphs')
-
- if not isinstance(self.optimizer, SGDOptimizer):
- raise Exception("This mode works well only with SGD optimizer with decay.\
- Kindly change the optimizer and restart the experiment. For details refer the following link: \n \
- https://docs.ampligraph.org/en/latest/dev_notes.html#dealing-with-large-graphs")
-
- if self.dealing_with_large_graphs:
- prefetch_batches = 0
- # CPU matrix of embeddings
- self.ent_emb_cpu = self.initializer.get_entity_initializer(len(self.ent_to_idx),
- self.internal_k,
- 'np')
-
- self.train_dataset_handle.map_data()
-
- # This is useful when we re-fit the same model (e.g. retraining in model selection)
- if self.is_fitted:
- tf.reset_default_graph()
- self.rnd = check_random_state(self.seed)
- tf.random.set_random_seed(self.seed)
-
- self.sess_train = tf.Session(config=self.tf_config)
- if self.tensorboard_logs_path is not None:
- self.writer = tf.summary.FileWriter(self.tensorboard_logs_path, self.sess_train.graph)
- batch_size = int(np.ceil(self.train_dataset_handle.get_size("train") / self.batches_count))
- # dataset = tf.data.Dataset.from_tensor_slices(X).repeat().batch(batch_size).prefetch(2)
-
- if len(self.ent_to_idx) > ENTITY_THRESHOLD:
- logger.warning('Only {} embeddings would be loaded in memory per batch...'.format(batch_size * 2))
-
- self.batch_size = batch_size
- self._initialize_parameters()
-
- if self.use_focusE:
- output_types = (tf.int32, tf.int32, tf.float32, tf.float32)
- output_shapes = ((None, 3), (None, 1), (None, self.internal_k), (None, 1))
- else:
- output_types = (tf.int32, tf.int32, tf.float32)
- output_shapes = ((None, 3), (None, 1), (None, self.internal_k))
-
- dataset = tf.data.Dataset.from_generator(self._training_data_generator,
- output_types=output_types,
- output_shapes=output_shapes)
-
- dataset = dataset.repeat().prefetch(prefetch_batches)
-
- dataset_iterator = tf.data.make_one_shot_iterator(dataset)
- # init tf graph/dataflow for training
- # init variables (model parameters to be learned - i.e. the embeddings)
-
- if self.loss.get_state('require_same_size_pos_neg'):
- batch_size = batch_size * self.eta
-
- loss = self._get_model_loss(dataset_iterator)
-
- train = self.optimizer.minimize(loss)
-
- # Entity embeddings normalization
- normalize_ent_emb_op = self.ent_emb.assign(tf.clip_by_norm(self.ent_emb, clip_norm=1, axes=1))
-
- self.early_stopping_params = early_stopping_params
-
- # early stopping
- if early_stopping:
- self._initialize_early_stopping()
-
- self.sess_train.run(tf.tables_initializer())
- self.sess_train.run(tf.global_variables_initializer())
- try:
- self.sess_train.run(self.set_training_true)
- except AttributeError:
- pass
-
- normalize_rel_emb_op = self.rel_emb.assign(tf.clip_by_norm(self.rel_emb, clip_norm=1, axes=1))
-
- if self.embedding_model_params.get('normalize_ent_emb', constants.DEFAULT_NORMALIZE_EMBEDDINGS):
- self.sess_train.run(normalize_rel_emb_op)
- self.sess_train.run(normalize_ent_emb_op)
-
- epoch_iterator_with_progress = tqdm(range(1, self.epochs + 1), disable=(not self.verbose), unit='epoch')
-
- for epoch in epoch_iterator_with_progress:
- losses = []
- for batch in range(1, self.batches_count + 1):
- feed_dict = {self.epoch: epoch, self.batch_number: batch - 1}
- self.optimizer.update_feed_dict(feed_dict, batch, epoch)
- if self.dealing_with_large_graphs:
- loss_batch, unique_entities, _ = self.sess_train.run([loss, self.unique_entities, train],
- feed_dict=feed_dict)
- self.ent_emb_cpu[np.squeeze(unique_entities), :] = \
- self.sess_train.run(self.ent_emb)[:unique_entities.shape[0], :]
- else:
- loss_batch, _ = self.sess_train.run([loss, train], feed_dict=feed_dict)
-
- if np.isnan(loss_batch) or np.isinf(loss_batch):
- msg = 'Loss is {}. Please change the hyperparameters.'.format(loss_batch)
- logger.error(msg)
- raise ValueError(msg)
-
- losses.append(loss_batch)
- if self.embedding_model_params.get('normalize_ent_emb', constants.DEFAULT_NORMALIZE_EMBEDDINGS):
- self.sess_train.run(normalize_ent_emb_op)
- if self.tensorboard_logs_path is not None:
- avg_loss = sum(losses) / (batch_size * self.batches_count)
- summary = tf.Summary(value=[tf.Summary.Value(tag="Average Loss",
- simple_value=avg_loss)])
- self.writer.add_summary(summary, epoch)
- if self.verbose:
- focusE = ''
- if self.use_focusE:
- focusE = '-FocusE'
- msg = 'Average {}{} Loss: {:10f}'.format(self.name,
- focusE,
- sum(losses) / (batch_size * self.batches_count))
- if early_stopping and self.early_stopping_best_value is not None:
- msg += ' — Best validation ({}): {:5f}'.format(self.early_stopping_criteria,
- self.early_stopping_best_value)
-
- logger.debug(msg)
- epoch_iterator_with_progress.set_description(msg)
-
- if early_stopping:
-
- try:
- self.sess_train.run(self.set_training_false)
- except AttributeError:
- pass
-
- if self._perform_early_stopping_test(epoch):
- if self.tensorboard_logs_path is not None:
- self.writer.flush()
- self.writer.close()
- self._end_training()
- return
-
- try:
- self.sess_train.run(self.set_training_true)
- except AttributeError:
- pass
- if self.tensorboard_logs_path is not None:
- self.writer.flush()
- self.writer.close()
-
- self._save_trained_params()
- self._end_training()
- except BaseException as e:
- self._end_training()
- raise e
-
- def set_filter_for_eval(self):
- """Configures to use filter
- """
- self.is_filtered = True
-
- def configure_evaluation_protocol(self, config=None):
- """Set the configuration for evaluation
-
- Parameters
- ----------
- config : dictionary
- Dictionary of parameters for evaluation configuration. Can contain following keys:
-
- - **corruption_entities**: List of entities to be used for corruptions.
- If ``all``, it uses all entities (default: ``all``)
- - **corrupt_side**: Specifies which side to corrupt. ``s``, ``o``, ``s+o``, ``s,o`` (default).
- In 's,o' mode subject and object corruptions are generated at once but ranked separately,
- which speeds up evaluation.
-
- """
- if config is None:
- config = {'corruption_entities': constants.DEFAULT_CORRUPTION_ENTITIES,
- 'corrupt_side': constants.DEFAULT_CORRUPT_SIDE_EVAL}
- self.eval_config = config
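- # Illustrative call, assuming a model instance named `model` (the dictionary values shown are
- # examples, not necessarily the library defaults):
- #   model.configure_evaluation_protocol({'corruption_entities': 'all', 'corrupt_side': 's,o'})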
-
- def _test_generator(self, mode):
- """Generates the test/validation data. If filter_triples are passed, then it returns the False Negatives
- that could be present in the generated corruptions.
-
- If we are dealing with large graphs, then along with the above, this method returns the idx of the
- entities present in the batch and their embeddings.
- """
- test_generator = partial(self.eval_dataset_handle.get_next_batch,
- dataset_type=mode,
- use_filter=self.is_filtered)
-
- batch_iterator = iter(test_generator())
- indices_obj = np.empty(shape=(0, 1), dtype=np.int32)
- indices_sub = np.empty(shape=(0, 1), dtype=np.int32)
- unique_ent = np.empty(shape=(0, 1), dtype=np.int32)
- entity_embeddings = np.empty(shape=(0, self.internal_k), dtype=np.float32)
- for i in range(self.eval_dataset_handle.get_size(mode)):
- if self.is_filtered:
- out, indices_obj, indices_sub = next(batch_iterator)
- else:
- out = next(batch_iterator)
- # since the focusE layer is not used in evaluation mode
- out = out[0]
-
- if self.dealing_with_large_graphs:
- # since we are dealing with only one triple (2 entities)
- unique_ent = np.unique(np.array([out[0, 0], out[0, 2]]))
- needed = (self.corr_batch_size - unique_ent.shape[0])
- large_number = np.zeros((needed, self.ent_emb_cpu.shape[1]), dtype=np.float32) + np.nan
- entity_embeddings = np.concatenate((self.ent_emb_cpu[unique_ent, :], large_number), axis=0)
- unique_ent = unique_ent.reshape(-1, 1)
-
- yield out, indices_obj, indices_sub, entity_embeddings, unique_ent
-
- def _generate_corruptions_for_large_graphs(self):
- """Corruption generator for large graph mode only.
- It generates corruptions in batches and also yields the corresponding entity embeddings.
- """
-
- corruption_entities = self.eval_config.get('corruption_entities', constants.DEFAULT_CORRUPTION_ENTITIES)
-
- if corruption_entities == 'all':
- all_entities_np = np.arange(len(self.ent_to_idx))
- corruption_entities = all_entities_np
- elif isinstance(corruption_entities, np.ndarray):
- corruption_entities = corruption_entities
- else:
- msg = 'Invalid type for corruption entities.'
- logger.error(msg)
- raise ValueError(msg)
-
- entity_embeddings = np.empty(shape=(0, self.internal_k), dtype=np.float32)
-
- for i in range(self.corr_batches_count):
- all_ent = corruption_entities[i * self.corr_batch_size:(i + 1) * self.corr_batch_size]
- needed = (self.corr_batch_size - all_ent.shape[0])
- large_number = np.zeros((needed, self.ent_emb_cpu.shape[1]), dtype=np.float32) + np.nan
- entity_embeddings = np.concatenate((self.ent_emb_cpu[all_ent, :], large_number), axis=0)
-
- all_ent = all_ent.reshape(-1, 1)
- yield all_ent, entity_embeddings
-
- def _initialize_eval_graph(self, mode="test"):
- """Initialize the evaluation graph.
-
- Parameters
- ----------
- mode: string
- Indicates which data generator to use.
- """
-
- # Use a data generator which returns a test triple along with the subjects and objects indices for filtering
- # The last two outputs are used if the graph is large. They are the embeddings of the entities that must be
- # loaded on the GPU before scoring and the indices of those embeddings.
- dataset = tf.data.Dataset.from_generator(partial(self._test_generator, mode=mode),
- output_types=(tf.int32, tf.int32, tf.int32, tf.float32, tf.int32),
- output_shapes=((1, 3), (None, 1), (None, 1),
- (None, self.internal_k), (None, 1)))
- dataset = dataset.repeat()
- dataset = dataset.prefetch(1)
- dataset_iter = tf.data.make_one_shot_iterator(dataset)
- self.X_test_tf, indices_obj, indices_sub, entity_embeddings, unique_ent = dataset_iter.get_next()
-
- corrupt_side = self.eval_config.get('corrupt_side', constants.DEFAULT_CORRUPT_SIDE_EVAL)
-
- # Rather than generating corruptions in batches do it at once on the GPU for small or medium sized graphs
- all_entities_np = np.arange(len(self.ent_to_idx))
-
- corruption_entities = self.eval_config.get('corruption_entities', constants.DEFAULT_CORRUPTION_ENTITIES)
-
- if corruption_entities == 'all':
- corruption_entities = all_entities_np
- elif isinstance(corruption_entities, np.ndarray):
- corruption_entities = corruption_entities
- else:
- msg = 'Invalid type for corruption entities.'
- logger.error(msg)
- raise ValueError(msg)
-
- # Dependencies that need to be run before scoring
- test_dependency = []
- # For large graphs
- if self.dealing_with_large_graphs:
- # Add a dependency to load the embeddings on the GPU
- init_ent_emb_batch = self.ent_emb.assign(entity_embeddings, use_locking=True)
- test_dependency.append(init_ent_emb_batch)
-
- # Add a dependency to create lookup tables (for remapping the entity indices to the order of variables on the GPU)
- self.sparse_mappings = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int32,
- value_dtype=tf.int32,
- default_value=-1,
- empty_key=-2,
- deleted_key=-1)
- insert_lookup_op = self.sparse_mappings.insert(unique_ent,
- tf.reshape(tf.range(tf.shape(unique_ent)[0],
- dtype=tf.int32), (-1, 1)))
- test_dependency.append(insert_lookup_op)
- if isinstance(corruption_entities, np.ndarray):
- # This is used for mapping the scores of corruption entities to the array which stores the scores.
- # Since the number of entities is low when entities_subset is used, the size of the array
- # which stores the scores would be len(entities_subset).
- # Hence while storing, the corruption entity id needs to be mapped to array index
- rankings_mappings = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int32,
- value_dtype=tf.int32,
- default_value=-1,
- empty_key=-2,
- deleted_key=-1)
-
- ranking_lookup_op = rankings_mappings.insert(corruption_entities.reshape(-1, 1),
- tf.reshape(tf.range(len(corruption_entities),
- dtype=tf.int32), (-1, 1)))
- test_dependency.append(ranking_lookup_op)
-
- # Execute the dependency
- with tf.control_dependencies(test_dependency):
- # Compute scores for positive - single triple
- e_s, e_p, e_o = self._lookup_embeddings(self.X_test_tf)
- self.score_positive = tf.squeeze(self._fn(e_s, e_p, e_o))
-
- # Generate corruptions in batches
- self.corr_batches_count = int(np.ceil(len(corruption_entities) / (self.corr_batch_size)))
-
- # Corruption generator -
- # returns corruptions and their corresponding embeddings that need to be loaded on the GPU
- corruption_generator = tf.data.Dataset.from_generator(self._generate_corruptions_for_large_graphs,
- output_types=(tf.int32, tf.float32),
- output_shapes=((None, 1),
- (None, self.internal_k)))
-
- corruption_generator = corruption_generator.repeat()
- corruption_generator = corruption_generator.prefetch(0)
-
- corruption_iter = tf.data.make_one_shot_iterator(corruption_generator)
-
- # Create tensor arrays for storing the scores of subject and object evals
- # size of this array must be equal to size of entities used for corruption.
- scores_predict_s_corruptions = tf.TensorArray(dtype=tf.float32, size=(len(corruption_entities)))
- scores_predict_o_corruptions = tf.TensorArray(dtype=tf.float32, size=(len(corruption_entities)))
-
- def loop_cond(i,
- scores_predict_s_corruptions_in,
- scores_predict_o_corruptions_in):
- return i < self.corr_batches_count
-
- def compute_score_corruptions(i,
- scores_predict_s_corruptions_in,
- scores_predict_o_corruptions_in):
- corr_dependency = []
- corr_batch, entity_embeddings_corrpt = corruption_iter.get_next()
- # if self.dealing_with_large_graphs: #for debugging
- # Add dependency to load the embeddings
- init_ent_emb_corrpt = self.ent_emb.assign(entity_embeddings_corrpt, use_locking=True)
- corr_dependency.append(init_ent_emb_corrpt)
-
- # Add dependency to remap the indices to the corresponding indices on the GPU
- insert_lookup_op2 = self.sparse_mappings.insert(corr_batch,
- tf.reshape(tf.range(tf.shape(corr_batch)[0],
- dtype=tf.int32),
- (-1, 1)))
- corr_dependency.append(insert_lookup_op2)
- # end if
-
- # Execute the dependency
- with tf.control_dependencies(corr_dependency):
- emb_corr = tf.squeeze(self._entity_lookup(corr_batch))
- if isinstance(corruption_entities, np.ndarray):
- remapping = rankings_mappings.lookup(corr_batch)
- else:
- remapping = corr_batch
- if 's' in corrupt_side:
- # compute and store the scores batch wise
- scores_predict_s_c = self._fn(emb_corr, e_p, e_o)
- scores_predict_s_corruptions_in = \
- scores_predict_s_corruptions_in.scatter(tf.squeeze(remapping),
- tf.squeeze(scores_predict_s_c))
-
- if 'o' in corrupt_side:
- scores_predict_o_c = self._fn(e_s, e_p, emb_corr)
- scores_predict_o_corruptions_in = \
- scores_predict_o_corruptions_in.scatter(tf.squeeze(remapping),
- tf.squeeze(scores_predict_o_c))
-
- return i + 1, scores_predict_s_corruptions_in, scores_predict_o_corruptions_in
-
- # compute the scores for all the corruptions
- counter, scores_predict_s_corr_out, scores_predict_o_corr_out = \
- tf.while_loop(loop_cond,
- compute_score_corruptions,
- (0,
- scores_predict_s_corruptions,
- scores_predict_o_corruptions),
- back_prop=False,
- parallel_iterations=1)
-
- if 's' in corrupt_side:
- subj_corruption_scores = scores_predict_s_corr_out.stack()
-
- if 'o' in corrupt_side:
- obj_corruption_scores = scores_predict_o_corr_out.stack()
-
- non_linearity = self.embedding_model_params.get('non_linearity', 'linear')
- if non_linearity == 'linear':
- pass
- elif non_linearity == 'tanh':
- subj_corruption_scores = tf.tanh(subj_corruption_scores)
- obj_corruption_scores = tf.tanh(obj_corruption_scores)
- self.score_positive = tf.tanh(self.score_positive)
- elif non_linearity == 'sigmoid':
- subj_corruption_scores = tf.sigmoid(subj_corruption_scores)
- obj_corruption_scores = tf.sigmoid(obj_corruption_scores)
- self.score_positive = tf.sigmoid(self.score_positive)
- elif non_linearity == 'softplus':
- subj_corruption_scores = custom_softplus(subj_corruption_scores)
- obj_corruption_scores = custom_softplus(obj_corruption_scores)
- self.score_positive = custom_softplus(self.score_positive)
- else:
- raise ValueError('Invalid non-linearity')
-
- if corrupt_side == 's+o' or corrupt_side == 's,o':
- self.scores_predict = tf.concat([obj_corruption_scores, subj_corruption_scores], axis=0)
- elif corrupt_side == 'o':
- self.scores_predict = obj_corruption_scores
- else:
- self.scores_predict = subj_corruption_scores
-
- else:
-
- # Entities that must be used while generating corruptions
- self.corruption_entities_tf = tf.constant(corruption_entities, dtype=tf.int32)
-
- corrupt_side = self.eval_config.get('corrupt_side', constants.DEFAULT_CORRUPT_SIDE_EVAL)
- # Generate corruptions
- self.out_corr = generate_corruptions_for_eval(self.X_test_tf,
- self.corruption_entities_tf,
- corrupt_side)
-
- # Compute scores for negatives
- e_s, e_p, e_o = self._lookup_embeddings(self.out_corr)
- self.scores_predict = self._fn(e_s, e_p, e_o)
-
- # Compute scores for positive
- e_s, e_p, e_o = self._lookup_embeddings(self.X_test_tf)
- self.score_positive = tf.squeeze(self._fn(e_s, e_p, e_o))
-
- non_linearity = self.embedding_model_params.get('non_linearity', 'linear')
- if non_linearity == 'linear':
- pass
- elif non_linearity == 'tanh':
- self.score_positive = tf.tanh(self.score_positive)
- self.scores_predict = tf.tanh(self.scores_predict)
- elif non_linearity == 'sigmoid':
- self.score_positive = tf.sigmoid(self.score_positive)
- self.scores_predict = tf.sigmoid(self.scores_predict)
- elif non_linearity == 'softplus':
- self.score_positive = custom_softplus(self.score_positive)
- self.scores_predict = custom_softplus(self.scores_predict)
- else:
- raise ValueError('Invalid non-linearity')
-
- if corrupt_side == 's,o':
- obj_corruption_scores = tf.slice(self.scores_predict,
- [0],
- [tf.shape(self.scores_predict)[0] // 2])
-
- subj_corruption_scores = tf.slice(self.scores_predict,
- [tf.shape(self.scores_predict)[0] // 2],
- [tf.shape(self.scores_predict)[0] // 2])
-
- # this is to remove the positives from corruptions - while ranking with filter
- positives_among_obj_corruptions_ranked_higher = tf.constant(0, dtype=tf.int32)
- positives_among_sub_corruptions_ranked_higher = tf.constant(0, dtype=tf.int32)
-
- if self.is_filtered:
- # If a list of specified entities were used for corruption generation
- if isinstance(self.eval_config.get('corruption_entities',
- constants.DEFAULT_CORRUPTION_ENTITIES), np.ndarray):
- corruption_entities = self.eval_config.get('corruption_entities',
- constants.DEFAULT_CORRUPTION_ENTITIES).astype(np.int32)
- if corruption_entities.ndim == 1:
- corruption_entities = np.expand_dims(corruption_entities, 1)
- # If the specified key is not present then it would return the length of corruption_entities
- corruption_mapping = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int32,
- value_dtype=tf.int32,
- default_value=len(corruption_entities),
- empty_key=-2,
- deleted_key=-1)
-
- insert_lookup_op = corruption_mapping.insert(corruption_entities,
- tf.reshape(tf.range(tf.shape(corruption_entities)[0],
- dtype=tf.int32), (-1, 1)))
-
- with tf.control_dependencies([insert_lookup_op]):
- # remap the indices of objects to the smaller set of corruptions
- indices_obj = corruption_mapping.lookup(indices_obj)
- # mask out the invalid indices (i.e. the entities that were not in the corruption list)
- indices_obj = tf.boolean_mask(indices_obj, indices_obj < len(corruption_entities))
- # remap the indices of subject to the smaller set of corruptions
- indices_sub = corruption_mapping.lookup(indices_sub)
- # mask out the invalid indices (i.e. the entities that were not in the corruption list)
- indices_sub = tf.boolean_mask(indices_sub, indices_sub < len(corruption_entities))
-
- # get the scores of positives present in corruptions
- if corrupt_side == 's,o':
- scores_pos_obj = tf.gather(obj_corruption_scores, indices_obj)
- scores_pos_sub = tf.gather(subj_corruption_scores, indices_sub)
- else:
- scores_pos_obj = tf.gather(self.scores_predict, indices_obj)
- if corrupt_side == 's+o':
- scores_pos_sub = tf.gather(self.scores_predict, indices_sub + len(corruption_entities))
- else:
- scores_pos_sub = tf.gather(self.scores_predict, indices_sub)
- # compute the ranks of the positives present in the corruptions and
- # see how many are ranked higher than the test triple
- if 'o' in corrupt_side:
- positives_among_obj_corruptions_ranked_higher = self.perform_comparision(scores_pos_obj,
- self.score_positive)
- if 's' in corrupt_side:
- positives_among_sub_corruptions_ranked_higher = self.perform_comparision(scores_pos_sub,
- self.score_positive)
-
- # compute the rank of the test triple and subtract the positives(from corruptions) that are ranked higher
- if corrupt_side == 's,o':
- self.rank = tf.stack([
- self.perform_comparision(subj_corruption_scores,
- self.score_positive) + 1 - positives_among_sub_corruptions_ranked_higher,
- self.perform_comparision(obj_corruption_scores,
- self.score_positive) + 1 - positives_among_obj_corruptions_ranked_higher], 0)
- else:
- self.rank = self.perform_comparision(self.scores_predict,
- self.score_positive) + 1 - \
- positives_among_sub_corruptions_ranked_higher - \
- positives_among_obj_corruptions_ranked_higher
-
- def perform_comparision(self, score_corr, score_pos):
- ''' Compares the scores of corruptions and positives using the specified strategy.
- 
- Parameters
- ----------
- score_corr:
- Tensor of scores of corruptions.
- score_pos:
- Tensor of the score of the positive triple.
- 
- Returns
- -------
- out:
- Comparison output based on the specified strategy.
- '''
- comparision_type = self.eval_config.get('ranking_strategy',
- constants.DEFAULT_RANK_COMPARE_STRATEGY)
-
- assert comparision_type in ['worst', 'best', 'middle'], 'Invalid score comparison type!'
-
- score_corr = tf.cast(score_corr * constants.SCORE_COMPARISION_PRECISION, tf.int32)
-
- score_pos = tf.cast(score_pos * constants.SCORE_COMPARISION_PRECISION, tf.int32)
-
- # if pos score: 0.5, corr_score: 0.5, 0.5, 0.3, 0.6, 0.5, 0.5
- if comparision_type == 'best':
- # returns: 1, i.e. only 1 corruption has a score greater than the positive (optimistic)
- return tf.reduce_sum(tf.cast(score_corr > score_pos, tf.int32))
- elif comparision_type == 'middle':
-
- # returns: 3, i.e. 1 + (4/2): 1 corruption has a score greater than the positive
- # and 4 corruptions have the same score (middle contribution is 4/2 = 2), so 1+2=3
- return tf.reduce_sum(tf.cast(score_corr > score_pos, tf.int32)) + \
- tf.cast(tf.math.ceil(tf.reduce_sum(tf.cast(score_corr == score_pos, tf.int32)) / 2),
- tf.int32)
- else:
- # returns: 5, i.e. 5 corruptions have a score >= the positive;
- # as you can see, this strategy returns the worst rank (pessimistic)
- return tf.reduce_sum(tf.cast(score_corr >= score_pos, tf.int32))
-
- def end_evaluation(self):
- """End the evaluation and close the Tensorflow session.
- """
-
- if self.is_filtered and self.eval_dataset_handle is not None:
- self.eval_dataset_handle.cleanup()
- self.eval_dataset_handle = None
-
- self.is_filtered = False
-
- self.eval_config = {}
-
- def get_ranks(self, dataset_handle):
- """ Used by evaluate_predictions to get the ranks for evaluation.
-
- Parameters
- ----------
- dataset_handle : Object of AmpligraphDatasetAdapter
- This contains handles of the generators that would be used to get test triples and filters
-
- Returns
- -------
- ranks : ndarray, shape [n] or [n,2] depending on the value of corrupt_side.
- An array of ranks of test triples.
- """
- if not self.is_fitted:
- msg = 'Model has not been fitted.'
- logger.error(msg)
- raise RuntimeError(msg)
-
- self.eval_dataset_handle = dataset_handle
-
- # build tf graph for predictions
- tf.reset_default_graph()
- self.rnd = check_random_state(self.seed)
- tf.random.set_random_seed(self.seed)
- # load the parameters
- self._load_model_from_trained_params()
- # build the eval graph
- self._initialize_eval_graph()
-
- with tf.Session(config=self.tf_config) as sess:
- sess.run(tf.tables_initializer())
- sess.run(tf.global_variables_initializer())
-
- try:
- sess.run(self.set_training_false)
- except AttributeError:
- pass
-
- ranks = []
-
- for _ in tqdm(range(self.eval_dataset_handle.get_size('test')), disable=(not self.verbose)):
- rank = sess.run(self.rank)
- if self.eval_config.get('corrupt_side', constants.DEFAULT_CORRUPT_SIDE_EVAL) == 's,o':
- ranks.append(list(rank))
- else:
- ranks.append(rank)
-
- return ranks
-
- def predict(self, X, from_idx=False):
- """
- Predict the scores of triples using a trained embedding model.
- The function returns raw scores generated by the model.
-
- .. note::
- To obtain probability estimates, calibrate the model with :func:`~EmbeddingModel.calibrate`,
- then call :func:`~EmbeddingModel.predict_proba`.
-
-
- Parameters
- ----------
- X : ndarray, shape [n, 3]
- The triples to score.
- from_idx : bool
- If True, will skip conversion to internal IDs. (default: False).
-
- Returns
- -------
- scores_predict : ndarray, shape [n]
- The predicted scores for input triples X.
-
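- Example (illustrative sketch; ``model`` is assumed to be a fitted instance and the triple below
- uses placeholder names):
- 
- >>> model.predict(np.array([['subject_a', 'predicate_p', 'object_b']]))
- 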
- """
- if not self.is_fitted:
- msg = 'Model has not been fitted.'
- logger.error(msg)
- raise RuntimeError(msg)
-
- tf.reset_default_graph()
- self._load_model_from_trained_params()
-
- if type(X) is not np.ndarray:
- X = np.array(X)
-
- if not self.dealing_with_large_graphs:
- if not from_idx:
- X = to_idx(X, ent_to_idx=self.ent_to_idx, rel_to_idx=self.rel_to_idx)
- x_tf = tf.Variable(X, dtype=tf.int32, trainable=False)
-
- e_s, e_p, e_o = self._lookup_embeddings(x_tf)
- scores = self._fn(e_s, e_p, e_o)
-
- non_linearity = self.embedding_model_params.get('non_linearity', 'linear')
- if non_linearity == 'linear':
- pass
- elif non_linearity == 'tanh':
- scores = tf.tanh(scores)
- elif non_linearity == 'sigmoid':
- scores = tf.sigmoid(scores)
- elif non_linearity == 'softplus':
- scores = custom_softplus(scores)
- else:
- raise ValueError('Invalid non-linearity')
-
- with tf.Session(config=self.tf_config) as sess:
- sess.run(tf.global_variables_initializer())
- return sess.run(scores)
- else:
- dataset_handle = NumpyDatasetAdapter()
- dataset_handle.use_mappings(self.rel_to_idx, self.ent_to_idx)
- dataset_handle.set_data(X, "test", mapped_status=from_idx)
-
- self.eval_dataset_handle = dataset_handle
-
- # build tf graph for predictions
- self.rnd = check_random_state(self.seed)
- tf.random.set_random_seed(self.seed)
- # load the parameters
- # build the eval graph
- self._initialize_eval_graph()
-
- with tf.Session(config=self.tf_config) as sess:
- sess.run(tf.tables_initializer())
- sess.run(tf.global_variables_initializer())
-
- try:
- sess.run(self.set_training_false)
- except AttributeError:
- pass
-
- scores = []
-
- for _ in tqdm(range(self.eval_dataset_handle.get_size('test')), disable=(not self.verbose)):
- score = sess.run(self.score_positive)
- scores.append(score)
-
- return scores
-
- def is_fitted_on(self, X):
- """ Determine heuristically if a model was fitted on the given triples.
- Parameters
- ----------
- X : ndarray, shape [n, 3]
- The triples to score.
- Returns
- -------
- bool : True if the number of unique entities and relations in X and
- the model match.
- """
-
- if not self.is_fitted:
- msg = 'Model has not been fitted.'
- logger.error(msg)
- raise RuntimeError(msg)
-
- unique_ent = np.unique(np.concatenate((X[:, 0], X[:, 2])))
- unique_rel = np.unique(X[:, 1])
-
- if not len(unique_ent) == len(self.ent_to_idx.keys()):
- return False
- elif not len(unique_rel) == len(self.rel_to_idx.keys()):
- return False
-
- return True
-
- def _calibrate_with_corruptions(self, X_pos, batches_count):
- """
- Calibrates model with corruptions. The corruptions are hard-coded to be subject and object ('s,o')
- with all available entities.
-
- Parameters
- ----------
- X_pos : ndarray (shape [n, 3])
- Numpy array of positive triples.
-
- batches_count: int
- Number of batches to complete one epoch of the Platt scaling training.
-
- Returns
- -------
- scores_pos: tf.Tensor
- Tensor with positive scores.
-
- scores_neg: tf.Tensor
- Tensor with negative scores (generated by the corruptions).
-
- dataset_handle: NumpyDatasetAdapter
- Dataset handle (only used for clean-up).
-
- """
- dataset_handle = NumpyDatasetAdapter()
- dataset_handle.use_mappings(self.rel_to_idx, self.ent_to_idx)
-
- dataset_handle.set_data(X_pos, "pos")
-
- gen_fn = partial(dataset_handle.get_next_batch, batches_count=batches_count, dataset_type="pos")
- dataset = tf.data.Dataset.from_generator(gen_fn,
- output_types=tf.int32,
- output_shapes=(1, None, 3))
- dataset = dataset.repeat().prefetch(1)
- dataset_iter = tf.data.make_one_shot_iterator(dataset)
-
- x_pos_tf = dataset_iter.get_next()[0]
-
- e_s, e_p, e_o = self._lookup_embeddings(x_pos_tf)
- scores_pos = self._fn(e_s, e_p, e_o)
-
- x_neg_tf = generate_corruptions_for_fit(x_pos_tf,
- entities_list=None,
- eta=1,
- corrupt_side='s,o',
- entities_size=len(self.ent_to_idx),
- rnd=self.seed)
-
- e_s_neg, e_p_neg, e_o_neg = self._lookup_embeddings(x_neg_tf)
- scores_neg = self._fn(e_s_neg, e_p_neg, e_o_neg)
-
- return scores_pos, scores_neg, dataset_handle
-
- def _calibrate_with_negatives(self, X_pos, X_neg):
- """
- Calibrates model with two datasets, one with positive triples and another with negative triples.
-
- Parameters
- ----------
- X_pos : ndarray (shape [n, 3])
- Numpy array of positive triples.
-
- X_neg : ndarray (shape [n, 3])
- Numpy array of negative triples.
-
- Returns
- -------
- scores_pos: tf.Tensor
- Tensor with positive scores.
-
- scores_neg: tf.Tensor
- Tensor with negative scores.
-
- """
- x_neg = to_idx(X_neg, ent_to_idx=self.ent_to_idx, rel_to_idx=self.rel_to_idx)
- x_neg_tf = tf.Variable(x_neg, dtype=tf.int32, trainable=False)
-
- x_pos = to_idx(X_pos, ent_to_idx=self.ent_to_idx, rel_to_idx=self.rel_to_idx)
- x_pos_tf = tf.Variable(x_pos, dtype=tf.int32, trainable=False)
-
- e_s, e_p, e_o = self._lookup_embeddings(x_neg_tf)
- scores_neg = self._fn(e_s, e_p, e_o)
-
- e_s, e_p, e_o = self._lookup_embeddings(x_pos_tf)
- scores_pos = self._fn(e_s, e_p, e_o)
-
- return scores_pos, scores_neg
-
- def calibrate(self, X_pos, X_neg=None, positive_base_rate=None, batches_count=100, epochs=50):
- """Calibrate predictions
-
- The method implements the heuristics described in :cite:`calibration`,
- using Platt scaling :cite:`platt1999probabilistic`.
-
- The calibrated predictions can be obtained with :meth:`predict_proba`
- after calibration is done.
-
- Ideally, calibration should be performed on a validation set that was not used to train the embeddings.
-
- There are two modes of operation, depending on the availability of negative triples:
-
- #. Both positive and negative triples are provided via ``X_pos`` and ``X_neg`` respectively. \
- The optimization is done using a second-order method (limited-memory BFGS), \
- therefore no hyperparameter needs to be specified.
-
- #. Only positive triples are provided, and the negative triples are generated by corruptions \
- just like it is done in training or evaluation. The optimization is done using a first-order method (ADAM), \
- therefore ``batches_count`` and ``epochs`` must be specified.
-
-
- Calibration is highly dependent on the base rate of positive triples.
- Therefore, for mode (2) of operation, the user is required to provide the ``positive_base_rate`` argument.
- For mode (1), that can be inferred automatically by the relative sizes of the positive and negative sets,
- but the user can override that by providing a value to ``positive_base_rate``.
-
- Defining the positive base rate is the biggest challenge when calibrating without negatives. That depends on
- the user's choice of which triples will be evaluated during test time.
- Let's take WN11 as an example: it has around 50% positive triples on both the validation set and test set,
- so naturally the positive base rate is 50%. However, should the user resample it to have 75% positives
- and 25% negatives, its previous calibration will be degraded. The user must recalibrate the model now with a
- 75% positive base rate. Therefore, this parameter depends on how the user handles the dataset and
- cannot be determined automatically or a priori.
-
- .. Note ::
- Incompatible with large graph mode (i.e. if ``self.dealing_with_large_graphs=True``).
-
- .. Note ::
- `Experiments for the ICLR-21 calibration paper are available here
- `_ :cite:`calibration`.
-
-
- Parameters
- ----------
- X_pos : ndarray (shape [n, 3])
- Numpy array of positive triples.
- X_neg : ndarray (shape [n, 3])
- Numpy array of negative triples.
-
- If `None`, the negative triples are generated via corruptions
- and the user must provide a positive base rate instead.
- positive_base_rate: float
- Base rate of positive statements.
-
- For example, if we assume there is a fifty-fifty chance of any query to be true, the base rate would be 50%.
-
- If ``X_neg`` is provided and this is `None`, the relative sizes of ``X_pos`` and ``X_neg`` will be used to
- determine the base rate. For example, if we have 50 positive triples and 200 negative triples,
- the positive base rate will be assumed to be 50/(50+200) = 1/5 = 0.2.
-
- This must be a value between 0 and 1.
- batches_count: int
- Number of batches to complete one epoch of the Platt scaling training.
- Only applies when ``X_neg`` is `None`.
- epochs: int
- Number of epochs used to train the Platt scaling model.
- Only applies when ``X_neg`` is `None`.
-
- Examples
- -------
-
- >>> import numpy as np
- >>> from sklearn.metrics import brier_score_loss, log_loss
- >>> from scipy.special import expit
- >>>
- >>> from ampligraph.datasets import load_wn11
- >>> from ampligraph.latent_features.models import TransE
- >>>
- >>> X = load_wn11()
- >>> X_valid_pos = X['valid'][X['valid_labels']]
- >>> X_valid_neg = X['valid'][~X['valid_labels']]
- >>>
- >>> model = TransE(batches_count=64, seed=0, epochs=500, k=100, eta=20,
- >>> optimizer='adam', optimizer_params={'lr':0.0001},
- >>> loss='pairwise', verbose=True)
- >>>
- >>> model.fit(X['train'])
- >>>
- >>> # Raw scores
- >>> scores = model.predict(X['test'])
- >>>
- >>> # Calibrate with positives and negatives
- >>> model.calibrate(X_valid_pos, X_valid_neg, positive_base_rate=None)
- >>> probas_pos_neg = model.predict_proba(X['test'])
- >>>
- >>> # Calibrate with just positives and base rate of 50%
- >>> model.calibrate(X_valid_pos, positive_base_rate=0.5)
- >>> probas_pos = model.predict_proba(X['test'])
- >>>
- >>> # Calibration evaluation with the Brier score loss (the smaller, the better)
- >>> print("Brier scores")
- >>> print("Raw scores:", brier_score_loss(X['test_labels'], expit(scores)))
- >>> print("Positive and negative calibration:", brier_score_loss(X['test_labels'], probas_pos_neg))
- >>> print("Positive only calibration:", brier_score_loss(X['test_labels'], probas_pos))
- Brier scores
- Raw scores: 0.4925058891371126
- Positive and negative calibration: 0.20434617882733366
- Positive only calibration: 0.22597599585144656
-
- """
- if not self.is_fitted:
- msg = 'Model has not been fitted.'
- logger.error(msg)
- raise RuntimeError(msg)
-
- if self.dealing_with_large_graphs:
- msg = "Calibration is incompatible with large graph mode."
- logger.error(msg)
- raise ValueError(msg)
-
- if positive_base_rate is not None and (positive_base_rate <= 0 or positive_base_rate >= 1):
- msg = "positive_base_rate must be a value between 0 and 1."
- logger.error(msg)
- raise ValueError(msg)
-
- dataset_handle = None
-
- try:
- tf.reset_default_graph()
- self.rnd = check_random_state(self.seed)
- tf.random.set_random_seed(self.seed)
-
- self._load_model_from_trained_params()
-
- if X_neg is not None:
- if positive_base_rate is None:
- positive_base_rate = len(X_pos) / (len(X_pos) + len(X_neg))
- scores_pos, scores_neg = self._calibrate_with_negatives(X_pos, X_neg)
- else:
- if positive_base_rate is None:
- msg = "When calibrating with randomly generated negative corruptions, " \
- "`positive_base_rate` must be set to a value between 0 and 1."
- logger.error(msg)
- raise ValueError(msg)
- scores_pos, scores_neg, dataset_handle = self._calibrate_with_corruptions(X_pos, batches_count)
-
- n_pos = len(X_pos)
- n_neg = len(X_neg) if X_neg is not None else n_pos
-
- scores_tf = tf.concat([scores_pos, scores_neg], axis=0)
- labels = tf.concat([tf.cast(tf.fill(tf.shape(scores_pos), (n_pos + 1.0) / (n_pos + 2.0)), tf.float32),
- tf.cast(tf.fill(tf.shape(scores_neg), 1 / (n_neg + 2.0)), tf.float32)],
- axis=0)
-
- # Platt scaling model
- w = tf.get_variable('w', initializer=0.0, dtype=tf.float32)
- b = tf.get_variable('b', initializer=np.log((n_neg + 1.0) / (n_pos + 1.0)).astype(np.float32),
- dtype=tf.float32)
- logits = -(w * tf.stop_gradient(scores_tf) + b)
-
- # Sample weights make sure the given positive_base_rate will be achieved irrespective of batch sizes
- weights_pos = tf.size(scores_neg) / tf.size(scores_pos)
- weights_neg = (1.0 - positive_base_rate) / positive_base_rate
- weights = tf.concat([tf.cast(tf.fill(tf.shape(scores_pos), weights_pos), tf.float32),
- tf.cast(tf.fill(tf.shape(scores_neg), weights_neg), tf.float32)], axis=0)
-
- loss = tf.losses.sigmoid_cross_entropy(labels, logits, weights=weights)
-
- optimizer = tf.train.AdamOptimizer()
- train = optimizer.minimize(loss)
-
- with tf.Session(config=self.tf_config) as sess:
- sess.run(tf.global_variables_initializer())
-
- epoch_iterator_with_progress = tqdm(range(1, epochs + 1), disable=(not self.verbose), unit='epoch')
- for _ in epoch_iterator_with_progress:
- losses = []
- for batch in range(batches_count):
- loss_batch, _ = sess.run([loss, train])
- losses.append(loss_batch)
- if self.verbose:
- msg = 'Calibration Loss: {:10f}'.format(sum(losses) / batches_count)
- logger.debug(msg)
- epoch_iterator_with_progress.set_description(msg)
-
- self.calibration_parameters = sess.run([w, b])
- self.is_calibrated = True
- finally:
- if dataset_handle is not None:
- dataset_handle.cleanup()
-
- def predict_proba(self, X):
- """
- Predicts probabilities using the Platt scaling model (after calibration).
-
- Model must be calibrated beforehand with the ``calibrate`` method.
-
- Parameters
- ----------
- X: ndarray (shape [n, 3])
- Numpy array of triples to be evaluated.
-
- Returns
- -------
- probas: ndarray (shape [n])
- Probability of each triple to be true according to the Platt scaling calibration.
-
- """
- if not self.is_calibrated:
- msg = "Model has not been calibrated. Please call `model.calibrate(...)` before predicting probabilities."
- logger.error(msg)
- raise RuntimeError(msg)
-
- tf.reset_default_graph()
-
- self._load_model_from_trained_params()
-
- w = tf.Variable(self.calibration_parameters[0], dtype=tf.float32, trainable=False)
- b = tf.Variable(self.calibration_parameters[1], dtype=tf.float32, trainable=False)
-
- x_idx = to_idx(X, ent_to_idx=self.ent_to_idx, rel_to_idx=self.rel_to_idx)
- x_tf = tf.Variable(x_idx, dtype=tf.int32, trainable=False)
-
- e_s, e_p, e_o = self._lookup_embeddings(x_tf)
- scores = self._fn(e_s, e_p, e_o)
- logits = -(w * scores + b)
- probas = tf.sigmoid(logits)
-
- with tf.Session(config=self.tf_config) as sess:
- sess.run(tf.global_variables_initializer())
- return sess.run(probas)
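
The calibration machinery removed above boils down to Platt scaling on frozen triple scores: a single weight w and bias b are fitted with a weighted cross-entropy on smoothed labels, and predict_proba then applies sigmoid(-(w * score + b)). Below is a minimal NumPy sketch of that idea, assuming pre-computed scores_pos and scores_neg arrays; the helper names are illustrative and not part of the AmpliGraph API, and plain gradient descent stands in for the L-BFGS/ADAM optimizers used by the original code.

import numpy as np

def platt_calibrate(scores_pos, scores_neg, positive_base_rate=0.5, epochs=50, lr=0.01):
    # Soft labels as in Platt (1999): positives -> (n_pos+1)/(n_pos+2), negatives -> 1/(n_neg+2)
    n_pos, n_neg = len(scores_pos), len(scores_neg)
    scores = np.concatenate([scores_pos, scores_neg]).astype(float)
    labels = np.concatenate([np.full(n_pos, (n_pos + 1.0) / (n_pos + 2.0)),
                             np.full(n_neg, 1.0 / (n_neg + 2.0))])
    # Sample weights enforce the requested positive base rate irrespective of set sizes
    weights = np.concatenate([np.full(n_pos, n_neg / n_pos),
                              np.full(n_neg, (1.0 - positive_base_rate) / positive_base_rate)])
    w, b = 0.0, np.log((n_neg + 1.0) / (n_pos + 1.0))
    for _ in range(epochs):
        probs = 1.0 / (1.0 + np.exp(w * scores + b))   # sigmoid(-(w*s + b))
        grad_logit = weights * (probs - labels)        # d(weighted BCE)/d(logit)
        w -= lr * np.mean(grad_logit * -scores)        # logit = -(w*s + b), so d(logit)/dw = -s
        b -= lr * np.mean(grad_logit * -1.0)
    return w, b

def predict_proba(scores, w, b):
    # Probability that a triple is true under the calibrated model
    return 1.0 / (1.0 + np.exp(w * np.asarray(scores) + b))

The pair (w, b) returned here plays the same role as self.calibration_parameters in the deleted TF1 implementation.
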
diff --git a/ampligraph/latent_features/models/HolE.py b/ampligraph/latent_features/models/HolE.py
deleted file mode 100644
index 07bcb377..00000000
--- a/ampligraph/latent_features/models/HolE.py
+++ /dev/null
@@ -1,278 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-from .ComplEx import ComplEx, register_model
-from ampligraph.latent_features import constants as constants
-from ampligraph.latent_features.initializers import DEFAULT_XAVIER_IS_UNIFORM
-
-
-@register_model("HolE", ["negative_corruption_entities"])
-class HolE(ComplEx):
- r"""Holographic Embeddings
-
- The HolE model :cite:`nickel2016holographic` as re-defined by Hayashi et al. :cite:`HayashiS17`:
-
- .. math::
-
- f_{HolE}= \frac{2}{n} \, f_{ComplEx}
-
- Examples
- --------
- >>> import numpy as np
- >>> from ampligraph.latent_features import HolE
- >>> model = HolE(batches_count=1, seed=555, epochs=100, k=10, eta=5,
- >>> loss='pairwise', loss_params={'margin':1},
- >>> regularizer='LP', regularizer_params={'lambda':0.1})
- >>>
- >>> X = np.array([['a', 'y', 'b'],
- >>> ['b', 'y', 'a'],
- >>> ['a', 'y', 'c'],
- >>> ['c', 'y', 'a'],
- >>> ['a', 'y', 'd'],
- >>> ['c', 'y', 'd'],
- >>> ['b', 'y', 'c'],
- >>> ['f', 'y', 'e']])
- >>> model.fit(X)
- >>> model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
- [[0.009254738], [0.00023370088]]
- """
- def __init__(self,
- k=constants.DEFAULT_EMBEDDING_SIZE,
- eta=constants.DEFAULT_ETA,
- epochs=constants.DEFAULT_EPOCH,
- batches_count=constants.DEFAULT_BATCH_COUNT,
- seed=constants.DEFAULT_SEED,
- embedding_model_params={'negative_corruption_entities': constants.DEFAULT_CORRUPTION_ENTITIES,
- 'corrupt_sides': constants.DEFAULT_CORRUPT_SIDE_TRAIN},
- optimizer=constants.DEFAULT_OPTIM,
- optimizer_params={'lr': constants.DEFAULT_LR},
- loss=constants.DEFAULT_LOSS,
- loss_params={},
- regularizer=constants.DEFAULT_REGULARIZER,
- regularizer_params={},
- initializer=constants.DEFAULT_INITIALIZER,
- initializer_params={'uniform': DEFAULT_XAVIER_IS_UNIFORM},
- verbose=constants.DEFAULT_VERBOSE):
- """Initialize an EmbeddingModel
-
- Also creates a new Tensorflow session for training.
-
- Parameters
- ----------
- k : int
- Embedding space dimensionality
- eta : int
- The number of negatives that must be generated at runtime during training for each positive.
- epochs : int
- The iterations of the training loop.
- batches_count : int
- The number of batches in which the training set must be split during the training loop.
- seed : int
- The seed used by the internal random numbers generator.
- embedding_model_params : dict
- HolE-specific hyperparams:
-
- - **negative_corruption_entities** - Entities to be used for generation of corruptions while training.
- It can take the following values :
- ``all`` (default: all entities),
- ``batch`` (entities present in each batch),
- list of entities
- or an int (which indicates how many entities should be used for corruption generation).
- - **corrupt_sides** : Specifies how to generate corruptions for training.
- Takes values `s`, `o`, `s+o` or any combination passed as a list.
- - **'non_linearity'**: can be one of the following values ``linear``, ``softplus``, ``sigmoid``, ``tanh``
- - **'stop_epoch'**: specifies how long to decay (linearly) the numeric values from 1
- until they reach their original value.
- - **'structural_wt'**: structural influence hyperparameter [0, 1] that modulates the influence of graph
- topology.
- - **'normalize_numeric_values'**: normalize the numeric values, such that they are scaled between [0, 1]
-
- The last 4 parameters are related to FocusE layers.
-
- optimizer : string
- The optimizer used to minimize the loss function. Choose between 'sgd',
- 'adagrad', 'adam', 'momentum'.
-
- optimizer_params : dict
- Arguments specific to the optimizer, passed as a dictionary.
-
- Supported keys:
-
- - **'lr'** (float): learning rate (used by all the optimizers). Default: 0.1.
- - **'momentum'** (float): learning momentum (only used when ``optimizer=momentum``). Default: 0.9.
-
- Example: ``optimizer_params={'lr': 0.01}``
-
- loss : string
- The type of loss function to use during training.
-
- - ``pairwise`` the model will use pairwise margin-based loss function.
- - ``nll`` the model will use the negative log-likelihood loss.
- - ``absolute_margin`` the model will use absolute margin likelihood.
- - ``self_adversarial`` the model will use adversarial sampling loss function.
- - ``multiclass_nll`` the model will use multiclass nll loss.
- Switch to multiclass loss defined in :cite:`chen2015` by passing
- 'corrupt_sides' as ['s','o'] to embedding_model_params.
- To use loss defined in :cite:`kadlecBK17` pass 'corrupt_sides' as 'o' to embedding_model_params.
-
- loss_params : dict
- Dictionary of loss-specific hyperparameters. See :ref:`loss functions `
- documentation for additional details.
-
- Example: ``loss_params={'margin': 1}`` if ``loss='pairwise'``.
-
- regularizer : string
- The regularization strategy to use with the loss function.
-
- - ``None``: the model will not use any regularizer (default)
- - 'LP': the model will use L1, L2 or L3 based on the value of ``regularizer_params['p']`` (see below).
-
- regularizer_params : dict
- Dictionary of regularizer-specific hyperparameters. See the :ref:`regularizers `
- documentation for additional details.
-
- Example: ``regularizer_params={'lambda': 1e-5, 'p': 2}`` if ``regularizer='LP'``.
-
- initializer : string
- The type of initializer to use.
-
- - ``normal``: The embeddings will be initialized from a normal distribution
- - ``uniform``: The embeddings will be initialized from a uniform distribution
- - ``xavier``: The embeddings will be initialized using xavier strategy (default)
-
- initializer_params : dict
- Dictionary of initializer-specific hyperparameters. See the
- :ref:`initializer `
- documentation for additional details.
-
- Example: ``initializer_params={'mean': 0, 'std': 0.001}`` if ``initializer='normal'``.
-
- verbose : bool
- Verbose mode.
- """
- super().__init__(k=k, eta=eta, epochs=epochs, batches_count=batches_count, seed=seed,
- embedding_model_params=embedding_model_params,
- optimizer=optimizer, optimizer_params=optimizer_params,
- loss=loss, loss_params=loss_params,
- regularizer=regularizer, regularizer_params=regularizer_params,
- initializer=initializer, initializer_params=initializer_params,
- verbose=verbose)
- self.internal_k = self.k * 2
-
- def _fn(self, e_s, e_p, e_o):
- """The Hole scoring function.
-
- The function implements the scoring function as defined by
- .. math::
-
- f_{HolE}= \frac{2}{n} \, f_{ComplEx}
-
- Additional details for equivalence of the models available in :cite:`HayashiS17`.
-
- Parameters
- ----------
- e_s : Tensor, shape [n]
- The embeddings of a list of subjects.
- e_p : Tensor, shape [n]
- The embeddings of a list of predicates.
- e_o : Tensor, shape [n]
- The embeddings of a list of objects.
-
- Returns
- -------
- score : TensorFlow operation
- The operation corresponding to the HolE scoring function.
-
- """
- return (2 / self.k) * (super()._fn(e_s, e_p, e_o))
-
- def fit(self, X, early_stopping=False, early_stopping_params={}, focusE_numeric_edge_values=None,
- tensorboard_logs_path=None):
- """Train a HolE model.
-
- The model is trained on a training set X using the training protocol
- described in :cite:`nickel2016holographic`.
-
- Parameters
- ----------
- X : ndarray, shape [n, 3]
- The training triples
- early_stopping: bool
- Flag to enable early stopping (default:False).
-
- If set to ``True``, the training loop adopts the following early stopping heuristic:
-
- - The model will be trained regardless of early stopping for ``burn_in`` epochs.
- - Every ``check_interval`` epochs the method will compute the metric specified in ``criteria``.
-
- If such metric decreases for ``stop_interval`` checks, we stop training early.
-
- Note the metric is computed on ``x_valid``. This is usually a validation set that you held out.
-
- Also, because ``criteria`` is a ranking metric, it requires generating negatives.
- Entities used to generate corruptions can be specified, as long as the side(s) of a triple to corrupt.
- The method supports filtered metrics, by passing an array of positives to ``x_filter``. This will be used to
- filter the negatives generated on the fly (i.e. the corruptions).
-
- .. note::
-
- Keep in mind the early stopping criteria may introduce a certain overhead
- (caused by the metric computation).
- The goal is to strike a good trade-off between such overhead and saving training epochs.
-
- A common approach is to use MRR unfiltered: ::
-
- early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}
-
- Note the size of validation set also contributes to such overhead.
- In most cases a smaller validation set would be enough.
-
- early_stopping_params: dictionary
- Dictionary of hyperparameters for the early stopping heuristics.
-
- The following string keys are supported:
-
- - **'x_valid'**: ndarray, shape [n, 3] : Validation set to be used for early stopping.
- - **'criteria'**: string : criteria for early stopping 'hits10', 'hits3', 'hits1' or 'mrr'(default).
- - **'x_filter'**: ndarray, shape [n, 3] : Positive triples to use as filter if a 'filtered'
- early stopping criteria is desired (i.e. filtered-MRR if 'criteria':'mrr').
- Note this will affect training time (no filter by default).
- - **'burn_in'**: int : Number of epochs to pass before kicking in early stopping (default: 100).
- - **'check_interval'**: int : Early stopping interval after burn-in (default: 10).
- - **'stop_interval'**: int : Stop if criteria is performing worse over n consecutive checks (default: 3)
- - **'corruption_entities'**: List of entities to be used for corruptions.
- If 'all', it uses all entities (default: 'all')
- - **'corrupt_side'**: Specifies which side to corrupt. 's', 'o', 's+o' (default)
-
- Example: ``early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}``
-
- focusE_numeric_edge_values: nd array (n, 1)
- Numeric values associated with links.
- Semantically, the numeric value can signify importance, uncertainty, significance, confidence, etc.
- If the numeric value is unknown, pass a NaN weight. The model will assign a numeric value uniformly at random.
- One can also think about assigning numeric values by looking at their distribution per predicate.
-
- tensorboard_logs_path: str or None
- Path to store tensorboard logs, e.g. average training loss tracking per epoch (default: ``None`` indicating
- no logs will be collected). When provided, it will create a folder under the given path and save tensorboard
- files there. To then view the loss in the terminal run: ``tensorboard --logdir ``.
-
- """
- super().fit(X, early_stopping, early_stopping_params, focusE_numeric_edge_values,
- tensorboard_logs_path=tensorboard_logs_path)
-
- def predict(self, X, from_idx=False):
- __doc__ = super().predict.__doc__ # NOQA
- return super().predict(X, from_idx=from_idx)
-
- def calibrate(self, X_pos, X_neg=None, positive_base_rate=None, batches_count=100, epochs=50):
- __doc__ = super().calibrate.__doc__ # NOQA
- super().calibrate(X_pos, X_neg, positive_base_rate, batches_count, epochs)
-
- def predict_proba(self, X):
- __doc__ = super().predict_proba.__doc__ # NOQA
- return super().predict_proba(X)
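
The scaling relation used by the deleted _fn above (HolE as 2/k times the ComplEx score, following Hayashi et al.) can be sanity-checked outside TensorFlow. The snippet below is an illustrative NumPy sketch with random placeholder embeddings rather than trained model parameters.

import numpy as np

k = 10                                       # embedding dimensionality per complex component
rng = np.random.default_rng(555)

# complex-valued embeddings for one (subject, predicate, object) triple
e_s = rng.normal(size=k) + 1j * rng.normal(size=k)
e_p = rng.normal(size=k) + 1j * rng.normal(size=k)
e_o = rng.normal(size=k) + 1j * rng.normal(size=k)

def complex_score(e_s, e_p, e_o):
    # ComplEx trilinear product: Re(sum_i e_p[i] * e_s[i] * conj(e_o[i]))
    return float(np.real(np.sum(e_p * e_s * np.conj(e_o))))

def hole_score(e_s, e_p, e_o, k):
    # HolE expressed as a rescaled ComplEx score, mirroring the deleted _fn
    return (2.0 / k) * complex_score(e_s, e_p, e_o)

print(hole_score(e_s, e_p, e_o, k))

In the library the real and imaginary parts are stored as one 2*k-dimensional real vector, which is why the class sets internal_k = k * 2.
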
diff --git a/ampligraph/latent_features/models/RandomBaseline.py b/ampligraph/latent_features/models/RandomBaseline.py
deleted file mode 100644
index a06885dc..00000000
--- a/ampligraph/latent_features/models/RandomBaseline.py
+++ /dev/null
@@ -1,173 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-import tensorflow as tf
-
-from .EmbeddingModel import EmbeddingModel, register_model
-from ampligraph.latent_features import constants as constants
-
-
-@register_model("RandomBaseline")
-class RandomBaseline(EmbeddingModel):
- """Random baseline
-
- A dummy model that assigns a pseudo-random score between 0 and 1,
- drawn from a uniform distribution.
-
- The model is useful whenever you need to compare the performance of
- another model on a custom knowledge graph, and no other baseline is available.
-
- .. note:: Although the model still requires invoking the ``fit()`` method,
- no actual training will be carried out.
-
- Examples
- --------
- >>> import numpy as np
- >>> from ampligraph.latent_features import RandomBaseline
- >>> model = RandomBaseline()
- >>> X = np.array([['a', 'y', 'b'],
- >>> ['b', 'y', 'a'],
- >>> ['a', 'y', 'c'],
- >>> ['c', 'y', 'a'],
- >>> ['a', 'y', 'd'],
- >>> ['c', 'y', 'd'],
- >>> ['b', 'y', 'c'],
- >>> ['f', 'y', 'e']])
- >>> model.fit(X)
- >>> model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
- [0.5488135039273248, 0.7151893663724195]
- """
-
- def __init__(self, seed=constants.DEFAULT_SEED, verbose=constants.DEFAULT_VERBOSE):
- """Initialize the model
-
- Parameters
- ----------
- seed : int
- The seed used by the internal random numbers generator.
- verbose : bool
- Verbose mode.
-
- """
- super().__init__(k=1, eta=1, epochs=1, batches_count=1, seed=seed, verbose=verbose)
- self.all_params = \
- {
- 'seed': seed,
- 'verbose': verbose
- }
-
- def _fn(self, e_s, e_p, e_o):
- """Random baseline scoring function: random number between 0 and 1.
-
- Parameters
- ----------
- e_s : Tensor, shape [n]
- The embeddings of a list of subjects.
- e_p : Tensor, shape [n]
- The embeddings of a list of predicates.
- e_o : Tensor, shape [n]
- The embeddings of a list of objects.
-
- Returns
- -------
- score : TensorFlow operation
- Random number between 0 and 1.
-
- """
- # During training TensorFlow requires that gradients with respect to the trainable variables exist
- if self.train_dataset_handle is not None:
- # Sigmoid reaches 1 quite quickly, so the `useless` variable below is 0 for all practical purposes
- useless = tf.sigmoid(tf.reduce_mean(tf.clip_by_value(e_s, 1e10, 1e11))) - 1.0
- return tf.random_uniform((tf.size(e_s),), minval=0, maxval=1) + useless
- else:
- return tf.random_uniform((tf.size(e_s),), minval=0, maxval=1)
-
- def fit(self, X, early_stopping=False, early_stopping_params={}, focusE_numeric_edge_values=None,
- tensorboard_logs_path=None):
- """Train the random model.
-
- There is no actual training involved in practice and the early stopping parameters won't have any effect.
-
- Parameters
- ----------
- X : ndarray, shape [n, 3]
- The training triples
- early_stopping: bool
- Flag to enable early stopping (default:False).
-
- If set to ``True``, the training loop adopts the following early stopping heuristic:
-
- - The model will be trained regardless of early stopping for ``burn_in`` epochs.
- - Every ``check_interval`` epochs the method will compute the metric specified in ``criteria``.
-
- If such metric decreases for ``stop_interval`` checks, we stop training early.
-
- Note the metric is computed on ``x_valid``. This is usually a validation set that you held out.
-
- Also, because ``criteria`` is a ranking metric, it requires generating negatives.
- Entities used to generate corruptions can be specified, as well as the side(s) of the triple to corrupt.
- The method supports filtered metrics, by passing an array of positives to ``x_filter``. This will be used to
- filter the negatives generated on the fly (i.e. the corruptions).
-
- .. note::
-
- Keep in mind the early stopping criteria may introduce a certain overhead
- (caused by the metric computation).
- The goal is to strike a good trade-off between such overhead and saving training epochs.
-
- A common approach is to use MRR unfiltered: ::
-
- early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}
-
- Note the size of validation set also contributes to such overhead.
- In most cases a smaller validation set would be enough.
-
- early_stopping_params: dictionary
- Dictionary of hyperparameters for the early stopping heuristics.
-
- The following string keys are supported:
-
- - **'x_valid'**: ndarray, shape [n, 3] : Validation set to be used for early stopping.
- - **'criteria'**: string : criteria for early stopping 'hits10', 'hits3', 'hits1' or 'mrr'(default).
- - **'x_filter'**: ndarray, shape [n, 3] : Positive triples to use as filter if a 'filtered'
- early stopping criteria is desired (i.e. filtered-MRR if 'criteria':'mrr').
- Note this will affect training time (no filter by default).
- - **'burn_in'**: int : Number of epochs to pass before kicking in early stopping (default: 100).
- - **check_interval'**: int : Early stopping interval after burn-in (default:10).
- - **'stop_interval'**: int : Stop if criteria is performing worse over n consecutive checks (default: 3)
- - **'corruption_entities'**: List of entities to be used for corruptions.
- If 'all', it uses all entities (default: 'all')
- - **'corrupt_side'**: Specifies which side to corrupt. 's', 'o', 's+o' (default)
-
- Example: ``early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}``
-
- focusE_numeric_edge_values: nd array (n, 1)
- Numeric values associated with links.
- Semantically, the numeric value can signify importance, uncertainty, significance, confidence, etc.
- If the numeric value is unknown, pass a NaN weight. The model will assign a numeric value uniformly at random.
- One can also think about assigning numeric values by looking at their distribution per predicate.
-
- tensorboard_logs_path: str or None
- Path to store tensorboard logs, e.g. average training loss tracking per epoch (default: ``None`` indicating
- no logs will be collected). When provided, it will create a folder under the given path and save tensorboard
- files there. To then view the loss in the terminal run: ``tensorboard --logdir ``.
-
- """
- super().fit(X, early_stopping, early_stopping_params, focusE_numeric_edge_values,
- tensorboard_logs_path=tensorboard_logs_path)
-
- def predict(self, X, from_idx=False):
- __doc__ = super().predict.__doc__ # NOQA
- return super().predict(X, from_idx=from_idx)
-
- def calibrate(self, X_pos, X_neg=None, positive_base_rate=None, batches_count=100, epochs=50):
- __doc__ = super().calibrate.__doc__ # NOQA
- super().calibrate(X_pos, X_neg, positive_base_rate, batches_count, epochs)
-
- def predict_proba(self, X):
- __doc__ = super().predict_proba.__doc__ # NOQA
- return super().predict_proba(X)
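
Because the deleted RandomBaseline scores triples with uniform noise, its ranking metrics sit at chance level. The standalone sketch below (not the library class) illustrates why the expected rank of a true triple among N random-scored corruptions is (N + 2) / 2.

import numpy as np

rng = np.random.default_rng(555)
N = 99                                       # number of corruptions per true triple
ranks = []
for _ in range(10_000):
    scores = rng.uniform(size=N + 1)         # index 0 plays the role of the true triple
    ranks.append(1 + np.sum(scores > scores[0]))

print(np.mean(ranks))                        # close to (N + 2) / 2 = 50.5

This is what makes the baseline useful as a sanity check for mean-rank and MRR figures reported on a new knowledge graph.
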
diff --git a/ampligraph/latent_features/models/ScoringBasedEmbeddingModel.py b/ampligraph/latent_features/models/ScoringBasedEmbeddingModel.py
new file mode 100644
index 00000000..70208d4b
--- /dev/null
+++ b/ampligraph/latent_features/models/ScoringBasedEmbeddingModel.py
@@ -0,0 +1,2242 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+import tensorflow as tf
+import copy
+import shelve
+import pickle
+import numpy as np
+import os
+import tempfile
+import logging
+
+from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score
+from ampligraph.datasets import data_adapter
+from ampligraph.datasets.partitioned_data_manager import PartitionDataManager
+from ampligraph.latent_features.layers.scoring.AbstractScoringLayer import (
+ SCORING_LAYER_REGISTRY,
+)
+from ampligraph.latent_features.layers.encoding import EmbeddingLookupLayer
+from ampligraph.latent_features.layers.calibration import CalibrationLayer
+from ampligraph.latent_features.layers.corruption_generation import (
+ CorruptionGenerationLayerTrain,
+)
+from ampligraph.datasets.data_indexer import DataIndexer
+from ampligraph.latent_features import optimizers
+from ampligraph.latent_features import loss_functions
+from ampligraph.evaluation import train_test_split_no_unseen
+from tensorflow.python.keras import callbacks as callbacks_module
+from tensorflow.python.keras.engine import training_utils
+from tensorflow.python.eager import def_function
+from tensorflow.python.keras import metrics as metrics_mod
+from tensorflow.python.keras.engine import compile_utils
+
+tf.config.set_soft_device_placement(False)
+tf.debugging.set_log_device_placement(False)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class ScoringBasedEmbeddingModel(tf.keras.Model):
+ """Class for handling KGE models which follows the ranking based protocol.
+
+ Example
+ -------
+ >>> # create model and compile using user defined optimizer settings and
+ >>> # user defined settings of an existing loss
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> from ampligraph.latent_features.loss_functions import SelfAdversarialLoss
+ >>> import tensorflow as tf
+ >>> X = load_fb15k_237()
+ >>> loss = SelfAdversarialLoss({'margin': 0.1, 'alpha': 5, 'reduction': 'sum'})
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=300,
+ >>> scoring_type='ComplEx',
+ >>> seed=0)
+ >>> model.compile(optimizer='adam', loss=loss)
+ >>> model.fit(X['train'],
+ >>> batch_size=10000,
+ >>> epochs=5)
+ Epoch 1/5
+ 29/29 [==============================] - 3s 87ms/step - loss: 13496.5752
+ Epoch 2/5
+ 29/29 [==============================] - 1s 36ms/step - loss: 13488.8682
+ Epoch 3/5
+ 29/29 [==============================] - 1s 35ms/step - loss: 13436.2725
+ Epoch 4/5
+ 29/29 [==============================] - 1s 35ms/step - loss: 13259.0840
+ Epoch 5/5
+ 29/29 [==============================] - 1s 34ms/step - loss: 12977.0117
+ """
+
+ @classmethod
+ def from_config(cls, config):
+ return cls(**config)
+
+ def get_config(self):
+ """Get the configuration hyper-parameters of the scoring based embedding model."""
+ config = super(ScoringBasedEmbeddingModel, self).get_config()
+ config.update(
+ {
+ "eta": self.eta,
+ "k": self.k,
+ "scoring_type": self.scoring_type,
+ "seed": self.seed,
+ "max_ent_size": self.encoding_layer._max_ent_size_internal,
+ "max_rel_size": self.encoding_layer._max_rel_size_internal,
+ }
+ )
+
+ return config
+
+ def __init__(
+ self,
+ eta,
+ k,
+ scoring_type="DistMult",
+ seed=0,
+ max_ent_size=None,
+ max_rel_size=None,
+ ):
+ """
+ Initializes the scoring based embedding model using the user specified scoring function.
+
+ Parameters
+ ----------
+ eta: int
+ Num of negatives to use during training per triple.
+ k: int
+ Embedding size.
+ scoring_type: str
+ Name of the scoring layer to use.
+
+ - ``TransE`` Translating embedding scoring function will be used
+ - ``DistMult`` DistMult embedding scoring function will be used
+ - ``ComplEx`` ComplEx embedding scoring function will be used
+ - ``HolE`` Holograph embedding scoring function will be used
+
+ seed: int
+ Random seed.
+ max_ent_size: int
+ Maximum number of entities that can occur in any partition (default: `None`).
+ max_rel_size: int
+ Maximum number of relations that can occur in any partition (default: `None`).
+ """
+ super(ScoringBasedEmbeddingModel, self).__init__()
+ # set the random seed
+ tf.random.set_seed(seed)
+ np.random.seed(seed)
+
+ self.max_ent_size = max_ent_size
+ self.max_rel_size = max_rel_size
+
+ self.eta = eta
+ self.scoring_type = scoring_type
+
+ # get the scoring layer
+ self.scoring_layer = SCORING_LAYER_REGISTRY[scoring_type](k)
+ # get the actual k depending on scoring layer
+ # Ex: the ComplEx model uses k embeddings for the real part and k for the imaginary part,
+ # so internally it has 2*k, whereas TransE uses k.
+ self.k = k
+ self.internal_k = self.scoring_layer.internal_k
+
+ # create the corruption generation layer - generates eta corruptions
+ # during training
+ self.corruption_layer = CorruptionGenerationLayerTrain()
+
+ # If it is single partition, assume that you have max_ent_size unique entities
+ # This would change if we use partitions: based on which partition is in memory
+ # this attribute is used by the corruption_layer to sample eta
+ # corruptions
+ self.num_ents = self.max_ent_size
+
+ # Create the embedding lookup layer.
+ # size of entity emb is max_ent_size * k and relation emb is
+ # max_rel_size * k
+ self.encoding_layer = EmbeddingLookupLayer(
+ self.internal_k, self.max_ent_size, self.max_rel_size
+ )
+
+ # Flag to indicate whether the partitioned training is being done
+ self.is_partitioned_training = False
+ # Variable related to data indexing (entity to idx mapping)
+ self.data_indexer = True
+
+ # Flag to indicate whether to include FocusE layer
+ self.use_focusE = False
+ self.focusE_params = {}
+
+ self.is_calibrated = False
+ self.is_fitted = False
+ self.is_backward = False
+
+ self.seed = seed
+ self.base_dir = tempfile.gettempdir()
+ self.partitioner_metadata = {}
+
+ def is_fit(self):
+ """Check whether the model has been fitted already."""
+ return self.is_fitted
+
+ def compute_output_shape(self, inputShape):
+ """Returns the output shape of the outputs of the call function.
+
+ Parameters
+ ----------
+ input_shape: tuple
+ Shape of inputs of call function.
+
+ Returns
+ -------
+ output_shape: list of tuples
+ List with the shape of outputs of call function for the input triples and the corruption scores.
+ """
+ # input triple score (batch_size, 1) and corruption score (batch_size *
+ # eta, 1)
+ return [(None, 1), (None, 1)]
+
+ def partition_change_updates(self, num_ents, ent_emb, rel_emb):
+ """Perform the changes that are required when the partition is modified during training.
+
+ Parameters
+ ----------
+ num_ents: int
+ Number of unique entities in the partition.
+ ent_emb: array-like
+ Entity embeddings that need to be trained for the partition
+ (all triples of the partition will have embeddings in this matrix).
+ rel_emb: array-like
+ relation embeddings that need to be trained for the partition
+ (all triples of the partition will have embeddings in this matrix).
+
+ """
+ # save the unique entities of the partition : will be used for
+ # corruption generation
+ self.num_ents = num_ents
+ if self.encoding_layer.built:
+ # update the trainable variable in the encoding layer
+ self.encoding_layer.partition_change_updates(ent_emb, rel_emb)
+ else:
+ # if the encoding layer has not been built then store it as an initializer
+ # this would be the case during partitioned training (first
+ # batch)
+ self.encoding_layer.set_ent_rel_initial_value(ent_emb, rel_emb)
+
+ def call(self, inputs, training=False):
+ """
+ Computes the scores of the triples and returns the corruption scores as well.
+
+ Parameters
+ ----------
+ inputs: ndarray, shape (n, 3)
+ Batch of input triples.
+
+ Returns
+ -------
+ out: list
+ Scores of the input triples and, during training, the scores of their corruptions.
+ """
+ # lookup embeddings of the inputs
+ inp_emb = self.encoding_layer(inputs)
+ # score the inputs
+ inp_score = self.scoring_layer(inp_emb)
+ # score the corruptions
+
+ if training:
+ # generate the corruptions for the input triples
+ corruptions = self.corruption_layer(
+ inputs, self.num_ents, self.eta
+ )
+ # lookup embeddings of the inputs
+ corr_emb = self.encoding_layer(corruptions)
+ corr_score = self.scoring_layer(corr_emb)
+
+ return inp_score, corr_score
+
+ else:
+ return inp_score
+
+ @tf.function(experimental_relax_shapes=True)
+ def _get_ranks(
+ self,
+ inputs,
+ ent_embs,
+ start_id,
+ end_id,
+ filters,
+ mapping_dict,
+ corrupt_side="s,o",
+ ranking_strategy="worst",
+ ):
+ """
+ Evaluate the inputs against corruptions and return ranks.
+
+ Parameters
+ ----------
+ inputs: array-like, shape (n, 3)
+ Batch of input triples.
+ ent_embs: array-like, shape (m, k)
+ Slice of embedding matrix (corruptions).
+ start_id: int
+ Original id of the first row of embedding matrix (used during partitioned approach).
+ end_id: int
+ Original id of the last row of embedding matrix (used during partitioned approach).
+ filters: list of lists
+ Size of list is either 1 or 2 depending on ``corrupt_side``.
+ Size of the internal list is equal to the size of the input triples.
+ Each list contains an array of filters (i.e., True Positives) related to the specified side of the
+ corresponding input triples.
+ corrupt_side: str
+ Which side to corrupt during evaluation.
+ ranking_strategy: str
+ Indicates how to break ties (default: `worst`, i.e., assigns the worst rank to the test triple).
+ Can be one of the three types `"best"`, `"middle"`, `"worst"`.
+
+ Returns
+ -------
+ rank: tf.Tensor, shape (n, num of sides being corrupted)
+ Ranking against subject corruptions and object corruptions
+ (corruptions defined by `ent_embs` matrix).
+ """
+ if not self.is_partitioned_training:
+ inputs = [
+ tf.nn.embedding_lookup(
+ self.encoding_layer.ent_emb, inputs[:, 0]
+ ),
+ tf.nn.embedding_lookup(
+ self.encoding_layer.rel_emb, inputs[:, 1]
+ ),
+ tf.nn.embedding_lookup(
+ self.encoding_layer.ent_emb, inputs[:, 2]
+ ),
+ ]
+
+ return self.scoring_layer.get_ranks(
+ inputs,
+ ent_embs,
+ start_id,
+ end_id,
+ filters,
+ mapping_dict,
+ corrupt_side,
+ ranking_strategy,
+ )
+
+ def build(self, input_shape):
+ """Override the build function of the Model class.
+
+ It is called on the first call to ``__call__``.
+ With this function we set some internal parameters of the encoding layers (needed to build those layers
+ themselves) based on the input data supplied by the user while calling the `~ScoringBasedEmbeddingModel.fit` method.
+ """
+ # set the max number of the entities that will be trained per partition
+ # in case of non-partitioned training, it is equal to the total number
+ # of entities of the dataset
+ self.encoding_layer.max_ent_size = self.max_ent_size
+ # set the max number of relations being trained just like above
+ self.encoding_layer.max_rel_size = self.max_rel_size
+ self.num_ents = self.max_ent_size
+ self.built = True
+
+ def compute_focusE_weights(self, weights, structure_weight):
+ """Compute positive and negative weights to scale scores if ``use_focusE=True``.
+
+ Parameters
+ ----------
+ weights: array-like, shape (n, m)
+ Batch of weights associated with the triples.
+ structure_weight: float
+ Structural influence assigned to the weights.
+
+ Returns
+ -------
+ out: tuple of two tf.Tensors, (tf.Tensor(shape=(n, 1)), tf.Tensor(shape=(n * self.eta, 1)))
+ Tuple where the first element is a tensor containing the positive weights
+ and the second is a tensor containing the negative weights.
+ """
+
+ # Weights computation
+ weights = tf.reduce_mean(weights, 1)
+ weights_pos = structure_weight + (1 - structure_weight) * (1 - weights)
+ weights_neg = structure_weight + (1 - structure_weight) * (
+ tf.reshape(
+ tf.tile(weights, [self.eta]), [tf.shape(weights)[0] * self.eta]
+ )
+ )
+
+ return weights_pos, weights_neg
+
+ def train_step(self, data):
+ """
+ Training step.
+
+ Parameters
+ ----------
+ data: array-like, shape (n, m)
+ Batch of input triples (true positives) with weights associated if m>3.
+
+ Returns
+ -------
+ out: dict
+ Dictionary of metrics computed on the outputs (e.g., loss).
+ """
+ if self.data_shape > 3:
+ triples = data[0]
+ if self.data_handler._adapter.use_filter:
+ weights = data[2]
+ else:
+ weights = data[1]
+ else:
+ triples = data
+ with tf.GradientTape() as tape:
+ # get the model predictions
+ score_pos, score_neg = self(tf.cast(triples, tf.int32), training=1)
+ # focusE layer
+ if self.use_focusE:
+ logger.debug("Using FocusE")
+ non_linearity = self.focusE_params["non_linearity"]
+ structure_weight = self.focusE_params["structural_wt"]
+
+ weights_pos, weights_neg = self.compute_focusE_weights(
+ weights=weights, structure_weight=structure_weight
+ )
+ # Computation of scores
+ score_pos = non_linearity(score_pos) * weights_pos
+ score_neg = non_linearity(score_neg) * weights_neg
+
+ # compute the loss
+ loss = self.compiled_loss(
+ score_pos,
+ score_neg,
+ self.eta,
+ regularization_losses=self.losses,
+ )
+ try:
+ # minimize the loss and update the trainable variables
+ self.optimizer.minimize(
+ loss,
+ self.encoding_layer.ent_emb,
+ self.encoding_layer.rel_emb,
+ tape,
+ )
+ except ValueError as e:
+ if self.scoring_layer.name == "Random":
+ pass
+ else:
+ raise e
+
+ return {m.name: m.result() for m in self.metrics}
+
+ def make_train_function(self):
+ """Similar to keras lib, this function returns the handle to the training step function.
+ It processes one batch of data by iterating over the dataset iterator, it computes the loss and optimizes on it.
+
+ Returns
+ -------
+ out: Function handle
+ Handle to the training step function.
+ """
+ if self.train_function is not None:
+ return self.train_function
+
+ def train_function(iterator):
+ """This is the function whose handle will be returned.
+
+ Parameters
+ ----------
+ iterator: tf.data.Iterator
+ Data iterator.
+
+ Returns
+ -------
+ output: dict
+ Return a dictionary containing values that will be passed to ``tf.keras.callbacks.Callback.on_train_batch_end``.
+ """
+ data = next(iterator)
+ output = self.train_step(data)
+ return output
+
+ if not self.run_eagerly and not self.is_partitioned_training:
+ train_function = def_function.function(
+ train_function, experimental_relax_shapes=True
+ )
+
+ self.train_function = train_function
+ return self.train_function
+
+ def get_focusE_params(self, dict_params={}):
+ """Get parameters for focusE.
+
+ Parameters
+ ----------
+ dict_params: dict
+ The following hyper-params can be passed:
+
+ - "non_linearity": can assume of the following values `"linear"`, `"softplus"`, `"sigmoid"`, `"tanh"`.
+ - "stop_epoch": specifies how long to decay (linearly) the structural influence hyper-parameter \
+ from 1 until it reaches its original value.
+ - "structural_wt": structural influence hyperparameter [0, 1] that modulates the influence of graph \
+ topology.
+
+ If the respective key is missing: ``non_linearity="linear"``, ``stop_epoch=251`` and ``structural_wt=0.001``.
+
+ Returns
+ -------
+ focusE_params : tuple
+ A tuple containing three values: the non-linearity function (`str`), the `stop_epoch` (`int`) and the
+ structure weight (`float`).
+
+ """
+ # Get the non-linearity function
+ non_linearity = dict_params.get("non_linearity", "linear")
+ if non_linearity == "linear":
+ non_linearity = tf.identity
+ elif non_linearity == "tanh":
+ non_linearity = tf.tanh
+ elif non_linearity == "sigmoid":
+ non_linearity = tf.sigmoid
+ elif non_linearity == "softplus":
+
+ def non_linearity(x):
+ return tf.math.log(1 + 9999 * tf.exp(x))
+
+ else:
+ raise ValueError("Invalid focusE non-linearity")
+
+ # Get the stop_epoch for the decay
+ stop_epoch = dict_params.get("stop_epoch", 251)
+ msg = "Invalid value for focusE stop_epoch: expected a value >=0 but got {}".format(
+ stop_epoch
+ )
+ assert stop_epoch >= 0, msg
+
+ # Get structural_wt
+ structure_weight = dict_params.get("structural_wt", 0.001)
+ assert (structure_weight <= 1) and (
+ structure_weight >= 0
+ ), "Invalid focusE 'structural_wt' passed! It has to belong to [0,1]."
+
+ # if stop_epoch == 0, a fixed structure weight is used
+ if stop_epoch > 0:
+ # linear decay of numeric values
+ structure_weight = tf.maximum(
+ 1 - self.current_epoch / stop_epoch, 0.001
+ )
+
+ return non_linearity, stop_epoch, structure_weight
+
+ def update_focusE_params(self):
+ """Update the structural weight after decay."""
+ if self.focusE_params["stop_epoch"] > 0:
+ stop_epoch = self.focusE_params["stop_epoch"]
+ self.focusE_params["structural_wt"] = tf.maximum(
+ 1 - self.current_epoch / stop_epoch, 0.001
+ )
+
+ def fit(
+ self,
+ x=None,
+ batch_size=1,
+ epochs=1,
+ verbose=True,
+ callbacks=None,
+ validation_split=0.0,
+ validation_data=None,
+ shuffle=True,
+ initial_epoch=0,
+ validation_batch_size=100,
+ validation_corrupt_side="s,o",
+ validation_freq=50,
+ validation_burn_in=100,
+ validation_filter=False,
+ validation_entities_subset=None,
+ partitioning_k=1,
+ focusE=False,
+ focusE_params={},
+ ):
+ """Fit the model on the provided data.
+
+ Parameters
+ ----------
+ x: np.array, shape (n, 3), or str or GraphDataLoader or AbstractGraphPartitioner
+ Data OR Filename of the data file OR Data Handle to be used for training.
+ batch_size: int
+ Batch size to use during training.
+ May be overridden if **x** is a GraphDataLoader or AbstractGraphPartitioner instance.
+ epochs: int
+ Number of epochs to train (default: 1).
+ verbose: bool
+ Verbosity (default: `True`).
+ callbacks: list of tf.keras.callbacks.Callback
+ List of callbacks to be used during training (default: `None`).
+ validation_split: float
+ Validation split to carve out of **x** (default: 0.0) (currently supported only when **x** is a np.array).
+ validation_data: np.array, shape (n, 3) or str or `GraphDataLoader` or `AbstractGraphPartitioner`
+ Data OR Filename of the data file OR Data Handle to be used for validation.
+ shuffle: bool
+ Indicates whether to shuffle the data after every epoch during training (default: `True`).
+ initial_epoch: int
+ Initial epoch number (default: 0).
+ validation_batch_size: int
+ Batch size to use during validation (default: 100).
+ May be overridden if ``validation_data`` is `GraphDataLoader` or `AbstractGraphPartitioner` instance.
+ validation_freq: int
+ Indicates how often to validate (default: 50).
+ validation_burn_in: int
+ The burn-in time after which the validation kicks in.
+ validation_filter: bool or dict
+ Validation filter to be used.
+ validation_entities_subset: list or np.array
+ Subset of entities to be used for generating corruptions.
+
+ .. Note ::
+
+ One can perform early stopping using the tensorflow callback ``tf.keras.callbacks.EarlyStopping``
+ as shown in the accompanying example below.
+
+ focusE: bool
+ Specify whether to include the FocusE layer (default: `False`).
+ The FocusE layer :cite:`pai2021learning` allows injecting numeric edge attributes into the scoring layer
+ of a traditional knowledge graph embedding architecture.
+ Semantically, the numeric value can signify importance, uncertainty, significance, confidence...
+ of a triple.
+
+ .. Note ::
+
+ In order to activate focusE, the training data must have shape (n, 4), where the first three columns
+ store subject, predicate and object of triples, and the 4-th column stores the numerical edge value
+ associated with each triple.
+
+ focusE_params: dict
+ If FocusE layer is included, specify its hyper-parameters.
+ The following hyper-params can be passed:
+
+ + `"non_linearity"`: can be one of the following values `"linear"`, `"softplus"`, `"sigmoid"`, `"tanh"`.
+ + `"stop_epoch"`: specifies how long to decay (linearly) the numeric values from 1 to original value.
+ + `"structural_wt"`: structural influence hyperparameter :math:`\\in [0, 1]` that modulates the influence of graph topology.
+
+ If ``focusE==True`` and ``focusE_params==dict()``, then the default values are passed:
+ ``non_linearity="linear"``, ``stop_epoch=251`` and ``structural_wt=0.001``.
+
+ partitioning_k: int
+ Num of partitions to use while training (default: 1, i.e., the data is not partitioned).
+ May be overridden if ``x`` is an `AbstractGraphPartitioner` instance.
+
+ .. Note ::
+
+ This parameter is quite useful when your dataset is extremely large and cannot fit in memory.
+ Setting this to a number strictly larger than 1 will automatically partition the data using
+ ``BucketGraphPartitioner``.
+ Kindly check out the tutorials for usage in Advanced mode.
+
+ Returns
+ -------
+ history: History object
+ Its `History.history` attribute is a record of training loss values, as well as validation loss
+ and validation metrics values.
+
+ Example
+ -------
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> X = load_fb15k_237()
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=300,
+ >>> scoring_type='ComplEx',
+ >>> seed=0)
+ >>> model.compile(optimizer='adam', loss='nll')
+ >>> model.fit(X['train'],
+ >>> batch_size=10000,
+ >>> epochs=5)
+ Epoch 1/5
+ 29/29 [==============================] - 2s 71ms/step - loss: 67361.3047
+ Epoch 2/5
+ 29/29 [==============================] - 1s 35ms/step - loss: 67318.6094
+ Epoch 3/5
+ 29/29 [==============================] - 1s 37ms/step - loss: 67020.0703
+ Epoch 4/5
+ 29/29 [==============================] - 1s 35ms/step - loss: 65867.3750
+ Epoch 5/5
+ 29/29 [==============================] - 1s 35ms/step - loss: 63517.9062
+
+ >>> # Early stopping example
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> dataset = load_fb15k_237()
+ >>> model = ScoringBasedEmbeddingModel(eta=1,
+ >>> k=10,
+ >>> scoring_type='TransE')
+ >>> model.compile(optimizer='adam', loss='multiclass_nll')
+ >>> import tensorflow as tf
+ >>> early_stop = tf.keras.callbacks.EarlyStopping(monitor="val_mrr", # which metrics to monitor
+ >>> patience=3, # if the monitored metric doesn't improve for this many checks, the model stops early
+ >>> verbose=1, # verbosity
+ >>> mode="max", # how to compare the monitored metrics; "max" means higher is better
+ >>> restore_best_weights=True) # restore the weights with best value
+ >>> # the early stopping instance needs to be passed as callback to fit function
+ >>> model.fit(dataset['train'],
+ >>> batch_size=10000,
+ >>> epochs=5,
+ >>> validation_freq=1, # validation frequency
+ >>> validation_batch_size=100, # validation batch size
+ >>> validation_burn_in=3, # burn in time
+ >>> validation_corrupt_side='s,o', # which side to corrupt
+ >>> validation_data=dataset['valid'][::100], # Validation data
+ >>> callbacks=[early_stop]) # Pass the early stopping object as a callback
+ Epoch 1/5
+ 29/29 [==============================] - 2s 82ms/step - loss: 6698.2188
+ Epoch 2/5
+ 29/29 [==============================] - 1s 34ms/step - loss: 6648.8862
+ Epoch 3/5
+ 3/3 [==============================] - 1s 446ms/step - loss: 6652.895
+ 29/29 [==============================] - 2s 84ms/step - loss: 6590.2842 - val_mrr: 0.0811 -
+ val_mr: 1776.4545 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2301 - val_hits@100: 0.4148
+ Epoch 4/5
+ 3/3 [==============================] - 0s 102ms/step - loss: 6564.021
+ 29/29 [==============================] - 1s 47ms/step - loss: 6517.4517 - val_mrr: 0.0918 -
+ val_mr: 1316.6335 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2528 - val_hits@100: 0.4716
+ Epoch 5/5
+ 3/3 [==============================] - 1s 177ms/step - loss: 6468.798
+ 29/29 [==============================] - 2s 62ms/step - loss: 6431.8696 - val_mrr: 0.0901 -
+ val_mr: 1074.8920 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2386 - val_hits@100: 0.4773
+
+ """
+ # verifies if compile has been called before calling fit
+ self._assert_compile_was_called()
+ # focusE
+ self.current_epoch = 0
+ self.use_focusE = focusE
+
+ # use train test unseen to split training set
+ if validation_split:
+ assert isinstance(
+ x, np.ndarray
+ ), "Validation split supported for numpy arrays only!"
+ x, validation_data = train_test_split_no_unseen(
+ x,
+ test_size=validation_split,
+ seed=self.seed,
+ allow_duplication=False,
+ )
+
+ with training_utils.RespectCompiledTrainableState(self):
+ # create data handler for the data
+ self.data_handler = data_adapter.DataHandler(
+ x,
+ model=self,
+ batch_size=batch_size,
+ dataset_type="train",
+ epochs=epochs,
+ initial_epoch=initial_epoch,
+ use_filter=False,
+ # if model is already
+ # trained use the old
+ # indexer
+ use_indexer=self.data_indexer,
+ partitioning_k=partitioning_k,
+ )
+
+ self.partitioner_metadata = (
+ self.data_handler.get_update_partitioner_metadata(
+ self.base_dir
+ )
+ )
+ # get the mapping details
+ self.data_indexer = self.data_handler.get_mapper()
+ # get the maximum entities and relations that will be trained
+ # (useful during partitioning)
+ self.max_ent_size = self.data_handler._adapter.max_entities
+ self.max_rel_size = self.data_handler._adapter.max_relations
+ # Number of columns (i.e., only triples or also weights?)
+ if isinstance(self.data_handler._adapter, PartitionDataManager):
+ self.data_shape = (
+ self.data_handler._parent_adapter.backend.data_shape
+ )
+ else:
+ self.data_shape = self.data_handler._adapter.backend.data_shape
+
+ # FocusE
+ if self.data_shape < 4:
+ self.use_focusE = False
+ else:
+ if self.use_focusE:
+ assert isinstance(
+ focusE_params, dict
+ ), "focusE parameters need to be in a dict!"
+ # Define FocusE params
+ (
+ non_linearity,
+ stop_epoch,
+ structure_weight,
+ ) = self.get_focusE_params(focusE_params)
+ self.focusE_params = {
+ "non_linearity": non_linearity,
+ "stop_epoch": stop_epoch,
+ "structural_wt": structure_weight,
+ }
+ else:
+ print(
+ "Data shape is {}: not only triples were given, but focusE is not active!".format(
+ self.data_shape
+ )
+ )
+
+ # Container that configures and calls `tf.keras.Callback`s.
+ if not isinstance(callbacks, callbacks_module.CallbackList):
+ callbacks = callbacks_module.CallbackList(
+ callbacks,
+ add_history=True,
+ add_progbar=verbose != 0,
+ model=self,
+ verbose=verbose,
+ epochs=epochs,
+ )
+
+ # This variable is used by callbacks to stop training in case of
+ # any error
+ self.stop_training = False
+ self.is_partitioned_training = self.data_handler.using_partitioning
+ self.optimizer.set_partitioned_training(
+ self.is_partitioned_training
+ )
+
+ # set some partition related params if it is partitioned training
+ if self.is_partitioned_training:
+ self.partitioner_k = self.data_handler._adapter.partitioner_k
+ self.encoding_layer.max_ent_size = self.max_ent_size
+ self.encoding_layer.max_rel_size = self.max_rel_size
+
+ # make the train function that will be used to process each batch
+ # of data
+ train_function = self.make_train_function()
+ # before training begins call this callback function
+ callbacks.on_train_begin()
+
+ if (
+ isinstance(validation_entities_subset, str)
+ and validation_entities_subset == "all"
+ ):
+ # 'all' is equivalent to passing None: every entity in the
+ # graph will be used for generating corruptions
+ validation_entities_subset = None
+
+ # enumerate over the data
+ for epoch, iterator in self.data_handler.enumerate_epochs():
+ # current epoch number
+ self.current_epoch = epoch
+ # before epoch begins call this callback function
+ callbacks.on_epoch_begin(epoch)
+ # Update focusE parameter
+ if self.use_focusE:
+ self.update_focusE_params()
+ # handle the stop iteration of data iterator in this scope
+ with self.data_handler.catch_stop_iteration():
+ # iterate over the dataset
+ for step in self.data_handler.steps():
+ # before a batch is processed call this callback
+ # function
+ callbacks.on_train_batch_begin(step)
+
+ # process this batch
+ logs = train_function(iterator)
+ # after a batch is processed call this callback
+ # function
+ callbacks.on_train_batch_end(step, logs)
+
+ # store the logs of the last batch of the epoch
+ epoch_logs = copy.copy(logs)
+ # if validation is enabled
+ if (
+ epoch >= (validation_burn_in - 1)
+ and validation_data is not None
+ and self._should_eval(epoch, validation_freq)
+ ):
+ if self.data_shape > 3 and validation_data.shape[1] == 3:
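+ # the model was trained with weights (data_shape > 3) but the
+ # validation set contains plain triples: pad it with NaN weights
+ # so that the number of columns matches the training data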
+ nan_weights = np.empty((validation_data.shape[0], 1))
+ nan_weights.fill(np.nan)
+ validation_data = np.concatenate(
+ [validation_data, nan_weights], axis=1
+ )
+ # evaluate on the validation
+ ranks = self.evaluate(
+ validation_data,
+ batch_size=validation_batch_size or batch_size,
+ use_filter=validation_filter,
+ dataset_type="valid",
+ corrupt_side=validation_corrupt_side,
+ entities_subset=validation_entities_subset,
+ )
+ # compute all the metrics
+ val_logs = {
+ "val_mrr": mrr_score(ranks),
+ "val_mr": mr_score(ranks),
+ "val_hits@1": hits_at_n_score(ranks, 1),
+ "val_hits@10": hits_at_n_score(ranks, 10),
+ "val_hits@100": hits_at_n_score(ranks, 100),
+ }
+ # update the epoch logs with validation details
+ epoch_logs.update(val_logs)
+
+ # after an epoch is completed, call this callback function
+ callbacks.on_epoch_end(epoch, epoch_logs)
+ if self.stop_training:
+ break
+
+ # on training end call this method
+ callbacks.on_train_end()
+ self.is_fitted = True
+ # all the training and validation logs are stored in the history
+ # object by keras.Model
+ return self.history
+
+ def get_indexes(self, X, type_of="t", order="raw2ind"):
+ """Converts given data to indexes or to raw data (according to ``order``).
+
+ It works for ``X`` containing triples, entities, or relations.
+
+ Parameters
+ ----------
+ X: np.array or list
+ Data to be indexed.
+ type_of: str
+ Specifies whether to get indexes/raw data for triples (``type_of='t'``), entities (``type_of='e'``),
+ or relations (``type_of='r'``).
+ order: str
+ Specifies whether to get indexes from raw data (``order='raw2ind'``) or
+ raw data from indexes (``order='ind2raw'``).
+
+ Returns
+ -------
+ Y: np.array
+ Indexed data or raw data.
+
+ Example
+ -------
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> X = load_fb15k_237()
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=300,
+ >>> scoring_type='ComplEx',
+ >>> seed=0)
+ >>> model.compile(optimizer='adam', loss='nll')
+ >>> model.fit(X['train'],
+ >>> batch_size=10000,
+ >>> epochs=5,
+ >>> verbose=False)
+ >>> print(model.get_indexes(['/m/027rn', '/m/06v8s0'], 'e', 'raw2ind'))
+ >>> print(model.get_indexes([3877, 0], 'e', 'ind2raw'))
+ [0, 3877]
+ ['/m/06v8s0', '/m/027rn']
+ """
+ return self.data_indexer.get_indexes(X, type_of, order)
+
+ def get_count(self, concept_type="e"):
+ """Returns the count of entities or relations that were present during training.
+
+ Parameters
+ ----------
+ concept_type: str
+ Indicates whether to count entities (``concept_type='e'``) or
+ relations (``concept_type='r'``) (default: `'e'`).
+
+ Returns
+ -------
+ count: int
+ Count of the entities or relations.
+
+ Example
+ -------
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> X = load_fb15k_237()
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=300,
+ >>> scoring_type='ComplEx',
+ >>> seed=0)
+ >>> model.compile(optimizer='adam', loss='nll')
+ >>> model.fit(X['train'],
+ >>> batch_size=10000,
+ >>> epochs=5,
+ >>> verbose=False)
+ >>> print('Entities:', model.get_count('e'))
+ >>> print('Relations:', model.get_count('r'))
+ Entities: 14505
+ Relations: 237
+ """
+ assert self.is_fitted, "Model is not fit on the data yet!"
+ if concept_type == "e":
+ return self.data_indexer.get_entities_count()
+ elif concept_type == "r":
+ return self.data_indexer.get_relations_count()
+ else:
+ raise ValueError("Invalid Concept Type (expected 'e' or 'r')")
+
+ def get_train_embedding_matrix_size(self):
+ """Returns the size of the embedding matrix used for training.
+
+ This may not be the same as (n, k) during partitioned training (where `n` is the number of unique
+ entities in the whole training set).
+ """
+ assert self.is_fitted, "Model is not fit on the data yet!"
+ return {
+ "e": self.encoding_layer.ent_emb.shape,
+ "r": self.encoding_layer.rel_emb.shape,
+ }
+
+ def save(
+ self,
+ filepath,
+ overwrite=True,
+ include_optimizer=True,
+ save_format=None,
+ signatures=None,
+ options=None,
+ save_traces=True,
+ ):
+ """Save the model."""
+ super(ScoringBasedEmbeddingModel, self).save(
+ filepath,
+ overwrite,
+ include_optimizer,
+ save_format,
+ signatures,
+ options,
+ save_traces,
+ )
+ self.save_metadata(filedir=filepath)
+
+ def save_metadata(self, filepath=None, filedir=None):
+ """Save the AmpliGraph-specific metadata (data indexer, partitioning and calibration state) to disk."""
+ # store ampligraph specific metadata
+ if filepath is not None:
+ base_dir = os.path.dirname(filepath)
+ base_dir = "." if base_dir == "" else base_dir
+ filepath = os.path.basename(filepath)
+
+ if filedir is not None:
+ base_dir = filedir
+ filepath = os.path.basename(filedir)
+
+ with open(
+ os.path.join(base_dir, filepath + "_metadata.ampkl"), "wb"
+ ) as f:
+ metadata = {
+ "is_partitioned_training": self.is_partitioned_training,
+ "max_ent_size": self.max_ent_size,
+ "max_rel_size": self.max_rel_size,
+ "eta": self.eta,
+ "k": self.k,
+ "is_fitted": self.is_fitted,
+ "is_calibrated": self.is_calibrated,
+ "is_backward": self.is_backward,
+ }
+
+ metadata.update(self.data_indexer.get_update_metadata(base_dir))
+ if self.is_partitioned_training:
+ self.partitioner_metadata = (
+ self.data_handler.get_update_partitioner_metadata(base_dir)
+ )
+ metadata.update(self.partitioner_metadata)
+
+ if self.is_calibrated:
+ metadata["calib_w"] = self.calibration_layer.calib_w.numpy()
+ metadata["calib_b"] = self.calibration_layer.calib_b.numpy()
+ metadata["pos_size"] = self.calibration_layer.pos_size
+ metadata["neg_size"] = self.calibration_layer.neg_size
+ metadata[
+ "positive_base_rate"
+ ] = self.calibration_layer.positive_base_rate
+ pickle.dump(metadata, f)
+
+ def save_weights(self, filepath, overwrite=True):
+ """Save the trainable weights.
+
+ Use this function if the training process is complete and you want to
+ use the model only for inference. Use :meth:`load_weights` to load the model weights back.
+
+ .. Note ::
+ If you want to be able to continue training, use :meth:`ampligraph.utils.save_model`
+ and :meth:`ampligraph.utils.restore_model`. These functions save and restore the entire state
+ of the graph, which allows you to continue training from where it was stopped.
+
+ Parameters
+ ----------
+ filepath: str
+ Path to save the model.
+ overwrite: bool
+ Flag which indicates whether the model, if present, needs to be overwritten or not (default: `True`).
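+
+ Example
+ -------
+ >>> # A minimal sketch with a hypothetical path, assuming the model has been fitted:
+ >>> model.save_weights('./model_weights')
+ >>> # the weights (and AmpliGraph metadata) can later be restored on a model created
+ >>> # with the same hyperparameters via ``model.load_weights('./model_weights')``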
+ """
+ # TODO: verify other formats
+
+ # call the base class method to save the weights
+ if not self.is_partitioned_training:
+ super(ScoringBasedEmbeddingModel, self).save_weights(
+ filepath, overwrite
+ )
+ self.save_metadata(filepath)
+
+ def build_full_model(self, batch_size=100):
+ """This method is called while loading the weights to build the model."""
+ self.build((batch_size, 3))
+ for i in range(len(self.layers)):
+ self.layers[i].build((batch_size, 3))
+ self.layers[i].built = True
+
+ def load_metadata(self, filepath=None, filedir=None):
+ if filedir is not None:
+ filepath = os.path.join(filedir, os.path.basename(filedir))
+
+ with open(filepath + "_metadata.ampkl", "rb") as f:
+ metadata = pickle.load(f)
+ metadata["root_directory"] = os.path.dirname(filepath)
+ metadata["root_directory"] = (
+ "."
+ if metadata["root_directory"] == ""
+ else metadata["root_directory"]
+ )
+ self.base_dir = metadata["root_directory"]
+ try:
+ metadata["db_file"] = os.path.basename(metadata["db_file"])
+ except KeyError:
+ print("Saved model does not include a db file. Skipping.")
+
+ self.data_indexer = DataIndexer([], **metadata)
+ self.is_partitioned_training = metadata["is_partitioned_training"]
+ self.max_ent_size = metadata["max_ent_size"]
+ self.max_rel_size = metadata["max_rel_size"]
+ self.is_fitted = metadata["is_fitted"]
+ self.is_backward = metadata.get("is_backward", False)
+ if self.is_partitioned_training:
+ self.partitioner_k = metadata["partitioner_k"]
+ self.partitioner_metadata = {}
+ self.partitioner_metadata["ent_map_fname"] = metadata[
+ "ent_map_fname"
+ ]
+ self.partitioner_metadata["rel_map_fname"] = metadata[
+ "rel_map_fname"
+ ]
+
+ self.is_calibrated = metadata["is_calibrated"]
+ if self.is_calibrated:
+ self.calibration_layer = CalibrationLayer(
+ metadata["pos_size"],
+ metadata["neg_size"],
+ metadata["positive_base_rate"],
+ calib_w=metadata["calib_w"],
+ calib_b=metadata["calib_b"],
+ )
+
+ def load_weights(self, filepath):
+ """Loads the model weights.
+
+ Use this function if ``save_weights`` was used to save the model.
+
+ .. Note ::
+ If you want to continue training, use :meth:`ampligraph.utils.save_model` and
+ :meth:`ampligraph.utils.restore_model`. These functions save and restore the entire state of the
+ graph, which allows you to continue training from where it stopped.
+
+ Parameters
+ ----------
+ filepath: str
+ Path from which to load the model weights and metadata.
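+
+ Example
+ -------
+ >>> # A minimal sketch, assuming the weights were previously stored with
+ >>> # ``model.save_weights('./model_weights')`` (hypothetical path) and that the
+ >>> # model was re-created with the same hyperparameters:
+ >>> model.load_weights('./model_weights')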
+ """
+ self.load_metadata(filepath)
+ self.build_full_model()
+ if not self.is_partitioned_training:
+ super(ScoringBasedEmbeddingModel, self).load_weights(filepath)
+
+ def compile(
+ self,
+ optimizer="adam",
+ loss=None,
+ entity_relation_initializer="glorot_uniform",
+ entity_relation_regularizer=None,
+ **kwargs
+ ):
+ """ Compile the model.
+
+ Parameters
+ ----------
+ optimizer: str (name of optimizer) or optimizer instance
+ The optimizer used to minimize the loss function. For pre-defined options, choose between
+ `"sgd"`, `"adagrad"`, `"adam"`, `"rmsprop"`, etc.
+ See `tf.keras.optimizers `_
+ for up-to-date details.
+
+ If a string is passed, then the default parameters of the optimizer will be used.
+
+ If you want to use custom hyperparameters you need to create an instance of the optimizer and
+ pass the instance to the compile function ::
+
+ import tensorflow as tf
+ adam_opt = tf.keras.optimizers.Adam(learning_rate=0.003)
+ model.compile(loss='pairwise', optimizer=adam_opt)
+
+ loss: str (name of objective function), objective function or `ampligraph.latent_features.loss_functions.Loss`
+
+ If a string is passed, you can use one of the following losses which will be used with their
+ default setting:
+
+ - `"pairwise"`: the model will use the pairwise margin-based loss function.
+ - `"nll"`: the model will use the negative log-likelihood loss.
+ - `"absolute_margin"`: the model will use the absolute margin loss.
+ - `"self_adversarial"`: the model will use the adversarial sampling loss function.
+ - `"multiclass_nll"`: the model will use the multiclass nll loss. ::
+
+ model.compile(loss='absolute_margin', optimizer='adam')
+
+ If you want to modify the default parameters of the loss function, you need to explicitly create an instance
+ of the loss with required hyperparameters and then pass this instance. ::
+
+ from ampligraph.latent_features import AbsoluteMarginLoss
+ ab_loss = AbsoluteMarginLoss(loss_params={'margin': 3})
+ model.compile(loss=ab_loss, optimizer='adam')
+
+ An objective function is any callable with the signature
+ ``loss = fn(score_true, score_corr, eta)`` ::
+
+ # Create a user defined loss function with the above signature
+ def userLoss(scores_pos, scores_neg):
+ # user defined loss - takes in 2 params and returns loss
+ neg_exp = tf.exp(scores_neg)
+ pos_exp = tf.exp(scores_pos)
+ # Apply softmax to the scores
+ score = pos_exp / (tf.reduce_sum(neg_exp, axis=0) + pos_exp)
+ loss = -tf.math.log(score)
+ return loss
+ # Pass this loss while compiling the model
+ model.compile(loss=userLoss, optimizer='adam')
+
+ entity_relation_initializer: str (name of initializer function), initializer function or \
+ `tf.keras.initializers.Initializer` or list.
+
+ Initializer of the entity and relation embeddings. This is either a single value or a list of size 2.
+ If a single value is passed, then both the entities and relations will be initialized based on
+ the same initializer; if a list, the first initializer will be used for entities and the second
+ for relations.
+
+ If a string is passed, then the default parameters will be used. Choose between
+ `"random_normal"`, `"random_uniform"`, `"glorot_normal"`, `"he_normal"`, etc.
+
+ See `tf.keras.initializers `_
+ for up-to-date details. ::
+
+ model.compile(loss='pairwise', optimizer='adam',
+ entity_relation_initializer='random_normal')
+
+ If the user wants to use custom hyperparameters, then an instance of the
+ ``tf.keras.initializers.Initializer`` needs to be passed. ::
+
+ import tensorflow as tf
+ init = tf.keras.initializers.RandomNormal(stddev=0.00003)
+ model.compile(loss='pairwise', optimizer='adam',
+ entity_relation_initializer=init)
+
+ If the user wants to define a custom initializer, it can be any callable with the signature `init = fn(shape)` ::
+
+ def my_init(shape):
+ return tf.random.normal(shape)
+ model.compile(loss='pairwise', optimizer='adam',
+ entity_relation_initializer=my_init)
+
+ entity_relation_regularizer: str (name of regularizer function) or regularizer function or \
+ `tf.keras.regularizers.Regularizer` instance or list
+ Regularizer of entities and relations.
+ If a single value is passed, then both the entities and relations will be regularized based on
+ the same regularizer; if a list, the first regularizer will be used for entities and the second
+ for relations.
+
+ If a string is passed, then the default parameters of the regularizers will be used. Choose between
+ `"l1"`, `"l2"`, `"l1_l2"`, etc.
+
+ See `tf.keras.regularizers `_
+ for up-to-date details. ::
+
+ model.compile(loss='pairwise', optimizer='adam',
+ entity_relation_regularizer='l2')
+
+ If the user wants to use custom hyperparameters, then an instance of the
+ ``tf.keras.regularizers.Regularizer`` needs to be passed. ::
+
+ import tensorflow as tf
+ reg = tf.keras.regularizers.L1L2(l1=0.001, l2=0.1)
+ model.compile(loss='pairwise', optimizer='adam',
+ entity_relation_regularizer=reg)
+
+ If the user wants to define a custom regularizer, it can be any callable with the signature
+ ``reg = fn(weight_matrix)``. ::
+
+ def my_reg(weight_mx):
+ return 0.01 * tf.math.reduce_sum(tf.math.abs(weight_mx))
+ model.compile(loss='pairwise', optimizer='adam',
+ entity_relation_regularizer=my_reg)
+
+ Example
+ -------
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> X = load_fb15k_237()
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=300,
+ >>> scoring_type='ComplEx',
+ >>> seed=0)
+ >>> model.compile(optimizer='adam', loss='nll')
+ >>> model.fit(X['train'],
+ >>> batch_size=10000,
+ >>> epochs=5)
+ Epoch 1/5
+ 29/29 [==============================] - 2s 61ms/step - loss: 67361.3047
+ Epoch 2/5
+ 29/29 [==============================] - 1s 35ms/step - loss: 67318.6094
+ Epoch 3/5
+ 29/29 [==============================] - 1s 34ms/step - loss: 67020.0703
+ Epoch 4/5
+ 29/29 [==============================] - 1s 34ms/step - loss: 65867.3750
+ Epoch 5/5
+ 29/29 [==============================] - 1s 34ms/step - loss: 63517.9062
+
+ """
+ # get the optimizer
+ self.optimizer = optimizers.get(optimizer)
+ self._run_eagerly = kwargs.pop("run_eagerly", None)
+ # reset the training/evaluate/predict function
+ self._reset_compile_cache()
+
+ # get the loss
+ self.compiled_loss = loss_functions.get(loss)
+ # The only metric supported during training is the mean loss
+ self.compiled_metrics = compile_utils.MetricsContainer(
+ metrics_mod.Mean(name="loss"), None, None
+ ) # Total loss.
+
+ # set the initializer and regularizer of the embedding matrices in the
+ # encoding layer
+ self.encoding_layer.set_initializer(entity_relation_initializer)
+ self.encoding_layer.set_regularizer(entity_relation_regularizer)
+ self._is_compiled = True
+
+ @property
+ def metrics(self):
+ """Returns all the metrics that will be computed during training."""
+ metrics = []
+ if self._is_compiled:
+ if self.compiled_loss is not None:
+ metrics += self.compiled_loss.metrics
+ return metrics
+
+ def get_emb_matrix_test(self, part_number=1, number_of_parts=1):
+ """Get the embedding matrix during evaluation.
+
+ Parameters
+ ----------
+ part_number: int
+ Specifies which of the ``number_of_parts`` splits of the embedding matrix to return.
+ number_of_parts: int
+ Total number of parts in which to split the embedding matrix.
+
+ Returns
+ -------
+ emb_matrix: np.array, shape (n,k)
+ Part of the embedding matrix corresponding to `part_number`.
+ start_index: int
+ Original entity index (data dict) of the first row of the `emb_matrix`.
+ end_index: int
+ Original entity index (data dict) of the last row of the `emb_matrix`.
+
+ """
+ if number_of_parts == 1:
+ if self.entities_subset.shape[0] != 0:
+ out = tf.nn.embedding_lookup(
+ self.encoding_layer.ent_emb, self.entities_subset
+ )
+ else:
+ out = self.encoding_layer.ent_emb
+ return out, 0, out.shape[0] - 1
+ else:
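+ # partitioned case: load this part of the embedding matrix from the
+ # on-disk (shelve) entity partition file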
+ with shelve.open(
+ self.partitioner_metadata["ent_map_fname"]
+ ) as ent_partition:
+ batch_size = int(
+ np.ceil(len(ent_partition.keys()) / number_of_parts)
+ )
+ indices = np.arange(
+ part_number * batch_size, (part_number + 1) * batch_size
+ ).astype(str)
+ emb_matrix = []
+ for idx in indices:
+ try:
+ emb_matrix.append(ent_partition[idx])
+ except KeyError:
+ break
+ return np.array(emb_matrix), int(indices[0]), int(indices[-1])
+
+ def make_test_function(self):
+ """Similar to the Keras API, this function returns a handle to the test step function.
+
+ The returned function processes one batch of data drawn from the dataset iterator and computes the ranks of the test triples.
+
+ Returns
+ -------
+ out: Function handle
+ Handle to the test step function.
+ """
+
+ # if self.test_function is not None:
+ # return self.test_function
+
+ def test_function(iterator):
+ # total number of parts in which to split the embedding matrix
+ # (default 1, i.e., use full matrix as it is)
+ number_of_parts = 1
+
+ # if it is partitioned training
+ if self.is_partitioned_training:
+ # split the emb matrix based on number of buckets
+ number_of_parts = self.partitioner_k
+
+ data = next(iterator)
+ if self.use_filter:
+ inputs, filters = data[0], data[1]
+ else:
+ if self.data_shape > 3:
+ inputs, filters = data[
+ 0
+ ], tf.RaggedTensor.from_row_lengths([], [])
+ else:
+ inputs, filters = data, tf.RaggedTensor.from_row_lengths(
+ [], []
+ )
+
+ # compute the output shape based on the type of corruptions to be
+ # used
+ output_shape = 0
+ if "s" in self.corrupt_side:
+ output_shape += 1
+
+ if "o" in self.corrupt_side:
+ output_shape += 1
+
+ # create an array to store the ranks based on output shape
+ overall_rank = tf.zeros(
+ (output_shape, tf.shape(inputs)[0]), dtype=np.int32
+ )
+
+ if self.is_partitioned_training:
+ inputs = self.process_model_inputs_for_test(inputs)
+
+ # run the loop based on number of parts in which the original emb
+ # matrix was generated
+ for j in range(number_of_parts):
+ # get the embedding matrix along with entity ids of first and
+ # last row of emb matrix
+ emb_mat, start_ent_id, end_ent_id = self.get_emb_matrix_test(
+ j, number_of_parts
+ )
+ # compute the rank
+ ranks = self._get_ranks(
+ inputs,
+ emb_mat,
+ start_ent_id,
+ end_ent_id,
+ filters,
+ self.mapping_dict,
+ self.corrupt_side,
+ self.ranking_strategy,
+ )
+ # store it in the output
+ for i in tf.range(output_shape):
+ overall_rank = tf.tensor_scatter_nd_add(
+ overall_rank, [[i]], [ranks[i, :]]
+ )
+
+ overall_rank = tf.transpose(
+ tf.reshape(overall_rank, (output_shape, -1))
+ )
+ # if corruption type is s+o then add s and o ranks and return the
+ # added ranks
+ if self.corrupt_side == "s+o":
+ # add the subject and object ranks
+ overall_rank = tf.reduce_sum(overall_rank, 1)
+ # return the added ranks
+ return tf.reshape(overall_rank, (-1, 1))
+
+ return overall_rank
+
+ if not self.run_eagerly and not self.is_partitioned_training:
+ test_function = def_function.function(
+ test_function, experimental_relax_shapes=True
+ )
+
+ self.test_function = test_function
+
+ return self.test_function
+
+ def process_model_inputs_for_test(self, triples):
+ """Return the processed triples.
+
+ Parameters
+ ----------
+ triples: np.array
+ Triples to be processed.
+
+ Returns
+ -------
+ out_triples: np.array or list
+ In regular (non-partitioned) mode, the triples are returned as given in input.
+ In case of partitioning, it returns the triple embeddings as a list of size 3, where each element
+ is an np.array of subject, predicate, and object embeddings, respectively.
+ """
+ if self.is_partitioned_training:
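+ # look up the subject, predicate and object embeddings from the
+ # on-disk (shelve) partition files instead of the in-memory matrices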
+ np_triples = triples.numpy()
+ sub_emb_out = []
+ obj_emb_out = []
+ rel_emb_out = []
+ with shelve.open(
+ self.partitioner_metadata["ent_map_fname"]
+ ) as ent_emb:
+ with shelve.open(
+ self.partitioner_metadata["rel_map_fname"]
+ ) as rel_emb:
+ for triple in np_triples:
+ sub_emb_out.append(ent_emb[str(triple[0])])
+ rel_emb_out.append(rel_emb[str(triple[1])])
+ obj_emb_out.append(ent_emb[str(triple[2])])
+
+ emb_out = [
+ np.array(sub_emb_out),
+ np.array(rel_emb_out),
+ np.array(obj_emb_out),
+ ]
+ return emb_out
+ else:
+ return triples
+
+ def evaluate(
+ self,
+ x=None,
+ batch_size=32,
+ verbose=True,
+ use_filter=False,
+ corrupt_side="s,o",
+ entities_subset=None,
+ ranking_strategy="worst",
+ callbacks=None,
+ dataset_type="test",
+ ):
+ """
+ Evaluate the inputs against corruptions and return ranks.
+
+ Parameters
+ ----------
+ x: np.array, shape (n,3) or str or GraphDataLoader or AbstractGraphPartitioner
+ Data OR Filename of the data file OR Data Handle to be used for evaluation.
+ batch_size: int
+ Batch size to use during evaluation.
+ May be overridden if ``x`` is a `GraphDataLoader` or `AbstractGraphPartitioner` instance.
+ verbose: bool
+ Verbosity mode.
+ use_filter: bool or dict
+ Whether to use a filter or not. If a dictionary is specified, the data in the dict is concatenated
+ and used as the filter.
+ corrupt_side: str
+ Which side of the triple to corrupt. It can be the subject (``corrupt_side="s"``),
+ the object (``corrupt_side="o"``), or both the subject and the object (``corrupt_side="s+o"`` or
+ ``corrupt_side="s,o"``) (default: `"s,o"`).
+ ranking_strategy: str
+ Indicates how to break ties when a test triple gets the same rank as a corruption.
+ Can be one of the three types: `"best"`, `"middle"`, `"worst"` (default: `"worst"`, i.e.,
+ the worst rank is assigned to the test triple).
+ entities_subset: list or np.array
+ Subset of entities to be used for generating corruptions.
+ callbacks: list of keras.callbacks.Callback instances
+ List of callbacks to apply during evaluation.
+
+ Returns
+ -------
+ rank: np.array, shape (n, number of corrupted sides)
+ Ranking of test triples against subject corruptions and/or object corruptions.
+
+ Example
+ -------
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score
+ >>> X = load_fb15k_237()
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=300,
+ >>> scoring_type='ComplEx',
+ >>> seed=0)
+ >>> model.compile(optimizer='adam', loss='nll')
+ >>> model.fit(X['train'],
+ >>> batch_size=10000,
+ >>> epochs=5)
+ Epoch 1/5
+ 29/29 [==============================] - 2s 71ms/step - loss: 67361.3047
+ Epoch 2/5
+ 29/29 [==============================] - 1s 35ms/step - loss: 67318.6094
+ Epoch 3/5
+ 29/29 [==============================] - 1s 35ms/step - loss: 67020.0703
+ Epoch 4/5
+ 29/29 [==============================] - 1s 33ms/step - loss: 65867.3750
+ Epoch 5/5
+ 29/29 [==============================] - 1s 34ms/step - loss: 63517.9062
+ >>> ranks = model.evaluate(X['test'],
+ >>> batch_size=100,
+ >>> corrupt_side='s,o',
+ >>> use_filter={'train': X['train'],
+ >>> 'valid': X['valid'],
+ >>> 'test': X['test']})
+ >>> mr_score(ranks), mrr_score(ranks), hits_at_n_score(ranks, 1), hits_at_n_score(ranks, 10), len(ranks)
+ 28 triples containing invalid keys skipped!
+ 9 triples containing invalid keys skipped!
+ 2045/2045 [==============================] - 149s 73ms/step
+ (428.44671689989235,
+ 0.25761041025282316,
+ 0.1898179861043155,
+ 0.391965945787259,
+ 20438)
+ """
+ # get the test set handler
+ self.data_handler_test = data_adapter.DataHandler(
+ x,
+ batch_size=batch_size,
+ dataset_type=dataset_type,
+ epochs=1,
+ use_filter=use_filter,
+ use_indexer=self.data_indexer,
+ )
+
+ assert corrupt_side in [
+ "s",
+ "o",
+ "s,o",
+ "s+o",
+ ], "Invalid value for corrupt_side"
+ assert ranking_strategy in [
+ "best",
+ "middle",
+ "worst",
+ ], "Invalid value for ranking_strategy"
+
+ self.corrupt_side = corrupt_side
+ self.ranking_strategy = ranking_strategy
+
+ self.entities_subset = tf.constant([])
+ self.mapping_dict = tf.lookup.experimental.DenseHashTable(
+ tf.int32, tf.int32, -1, -1, -2
+ )
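+ # mapping_dict maps original entity indices to their position within
+ # entities_subset; it is populated below only if a subset is provided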
+ if entities_subset is not None:
+ entities_subset = self.data_indexer.get_indexes(
+ entities_subset, "e"
+ )
+ self.entities_subset = tf.constant(entities_subset, dtype=tf.int32)
+ self.mapping_dict.insert(
+ self.entities_subset, tf.range(self.entities_subset.shape[0])
+ )
+
+ # flag to indicate if we are using filter or not
+ self.use_filter = (
+ self.data_handler_test._parent_adapter.backend.use_filter
+ or isinstance(
+ self.data_handler_test._parent_adapter.backend.use_filter, dict
+ )
+ )
+
+ # Container that configures and calls `tf.keras.Callback`s.
+ if not isinstance(callbacks, callbacks_module.CallbackList):
+ callbacks = callbacks_module.CallbackList(
+ callbacks,
+ add_history=True,
+ add_progbar=verbose != 0,
+ model=self,
+ verbose=verbose,
+ epochs=1,
+ steps=self.data_handler_test.inferred_steps,
+ )
+
+ test_function = self.make_test_function()
+
+ # before test begins call this callback function
+ callbacks.on_test_begin()
+
+ self.all_ranks = []
+
+ # enumerate over the data
+ for _, iterator in self.data_handler_test.enumerate_epochs():
+ # handle the stop iteration of data iterator in this scope
+ with self.data_handler_test.catch_stop_iteration():
+ # iterate over the dataset
+ for step in self.data_handler_test.steps():
+ # before a batch is processed call this callback function
+ callbacks.on_test_batch_begin(step)
+
+ # process this batch
+ overall_rank = test_function(iterator)
+ # ranks returned are 0-based (0 to n-1), so increment by 1
+ overall_rank += 1
+ # save the ranks of the batch triples
+ self.all_ranks.append(overall_rank)
+ # after a batch is processed call this callback function
+ callbacks.on_test_batch_end(step)
+ # on test end call this method
+ callbacks.on_test_end()
+ # return ranks
+ return np.concatenate(self.all_ranks)
+
+ def predict_step(self, inputs):
+ """Returns the output of the predict step on a batch of data."""
+ if self.data_shape > 3 and isinstance(inputs, tuple):
+ inputs = inputs[0]
+ score_pos = self(inputs, False)
+ return score_pos
+
+ def predict_step_partitioning(self, inputs):
+ """Returns the output of the predict step on a batch of pre-looked-up embeddings (partitioned training)."""
+ score_pos = self.scoring_layer(inputs)
+ return score_pos
+
+ def make_predict_function(self):
+ """Similar to the Keras API, this function returns a handle to the predict step function.
+
+ The returned function processes one batch of data drawn from the dataset iterator and computes the prediction scores.
+
+ Returns
+ -------
+ out: Function handle
+ Handle to the predict function.
+ """
+ if self.predict_function is not None:
+ return self.predict_function
+
+ def predict_function(iterator):
+ inputs = next(iterator)
+ if self.is_partitioned_training:
+ inputs = self.process_model_inputs_for_test(inputs)
+ outputs = self.predict_step_partitioning(inputs)
+ else:
+ outputs = self.predict_step(inputs)
+ return outputs
+
+ if not self.run_eagerly and not self.is_partitioned_training:
+ predict_function = def_function.function(
+ predict_function, experimental_relax_shapes=True
+ )
+
+ self.predict_function = predict_function
+ return self.predict_function
+
+ def predict(self, x, batch_size=32, verbose=0, callbacks=None):
+ """
+ Compute scores of the input triples.
+
+ Parameters
+ -----------
+ x: np.array, shape (n, 3) or str or GraphDataLoader or AbstractGraphPartitioner
+ Data OR Filename of the data file OR Data Handle to be used for prediction.
+ batch_size: int
+ Batch size to use during prediction.
+ May be overridden if ``x`` is a `GraphDataLoader` or `AbstractGraphPartitioner` instance.
+ verbose: bool
+ Verbosity mode.
+ callbacks: list of keras.callbacks.Callback instances
+ List of callbacks to apply during evaluation.
+
+ Returns
+ -------
+ scores: np.array, shape (n, )
+ Scores of the input triples.
+
+ Example
+ -------
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> import numpy as np
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> X = load_fb15k_237()
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=300,
+ >>> scoring_type='ComplEx',
+ >>> seed=0)
+ >>> model.compile(optimizer='adam', loss='nll')
+ >>> model.fit(X['train'],
+ >>> batch_size=10000,
+ >>> epochs=5)
+ Epoch 1/5
+ 29/29 [==============================] - 7s 228ms/step - loss: 67361.2734
+ Epoch 2/5
+ 29/29 [==============================] - 5s 184ms/step - loss: 67318.8203
+ Epoch 3/5
+ 29/29 [==============================] - 5s 187ms/step - loss: 67021.1641
+ Epoch 4/5
+ 29/29 [==============================] - 5s 188ms/step - loss: 65865.5547
+ Epoch 5/5
+ 29/29 [==============================] - 5s 188ms/step - loss: 63510.2773
+
+ >>> pred = model.predict(X['test'],
+ >>> batch_size=100)
+ >>> print(np.sort(pred))
+ [-1.0868168 -0.46582496 -0.44715863 ... 3.2484274 3.3147712 3.326 ]
+
+ """
+
+ self.data_handler_test = data_adapter.DataHandler(
+ x,
+ batch_size=batch_size,
+ dataset_type="test",
+ epochs=1,
+ use_filter=False,
+ use_indexer=self.data_indexer,
+ )
+
+ if not isinstance(callbacks, callbacks_module.CallbackList):
+ callbacks = callbacks_module.CallbackList(
+ callbacks,
+ add_history=True,
+ add_progbar=verbose != 0,
+ model=self,
+ verbose=verbose,
+ epochs=1,
+ steps=self.data_handler_test.inferred_steps,
+ )
+
+ predict_function = self.make_predict_function()
+ callbacks.on_predict_begin()
+ outputs = []
+ for _, iterator in self.data_handler_test.enumerate_epochs():
+ with self.data_handler_test.catch_stop_iteration():
+ for step in self.data_handler_test.steps():
+ callbacks.on_predict_batch_begin(step)
+ batch_outputs = predict_function(iterator)
+ outputs.append(batch_outputs)
+
+ callbacks.on_predict_batch_end(
+ step, {"outputs": batch_outputs}
+ )
+ callbacks.on_predict_end()
+ return np.concatenate(outputs)
+
+ def make_calibrate_function(self):
+ """Similar to the Keras API, this function returns a handle to the calibrate step function.
+
+ The returned function processes one batch of data drawn from the dataset iterator and computes the
+ scores used to fit the calibration layer.
+
+ Returns
+ -------
+ out: Function handle
+ Handle to the calibration function.
+ """
+
+ def calibrate_with_corruption(iterator):
+ inputs = next(iterator)
+ if self.data_shape > 3 and isinstance(inputs, tuple):
+ inputs = inputs[0]
+ if self.is_partitioned_training:
+ inp_emb = self.process_model_inputs_for_test(inputs)
+ inp_score = self.scoring_layer(inp_emb)
+
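+ # generate synthetic negatives by corrupting the positives and score them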
+ corruptions = self.corruption_layer(inputs, self.num_ents, 1)
+ corr_emb = self.encoding_layer(corruptions)
+ corr_score = self.scoring_layer(corr_emb)
+ else:
+ inp_emb = self.encoding_layer(inputs)
+ inp_score = self.scoring_layer(inp_emb)
+
+ corruptions = self.corruption_layer(inputs, self.num_ents, 1)
+ corr_emb = self.encoding_layer(corruptions)
+ corr_score = self.scoring_layer(corr_emb)
+ return inp_score, corr_score
+
+ def calibrate_with_negatives(iterator):
+ inputs = next(iterator)
+ if self.data_shape > 3 and isinstance(inputs, tuple):
+ inputs = inputs[0]
+ if self.is_partitioned_training:
+ inp_emb = self.process_model_inputs_for_test(inputs)
+ inp_score = self.scoring_layer(inp_emb)
+ else:
+ inp_emb = self.encoding_layer(inputs)
+ inp_score = self.scoring_layer(inp_emb)
+ return inp_score
+
+ if self.is_calibrate_with_corruption:
+ calibrate_fn = calibrate_with_corruption
+ else:
+ calibrate_fn = calibrate_with_negatives
+
+ if not self.run_eagerly and not self.is_partitioned_training:
+ calibrate_fn = def_function.function(
+ calibrate_fn, experimental_relax_shapes=True
+ )
+
+ return calibrate_fn
+
+ def calibrate(
+ self,
+ X_pos,
+ X_neg=None,
+ positive_base_rate=None,
+ batch_size=32,
+ epochs=50,
+ verbose=0,
+ ):
+ """Calibrate predictions.
+
+ The method implements the heuristics described in :cite:`calibration`,
+ using Platt scaling :cite:`platt1999probabilistic`.
+
+ The calibrated predictions can be obtained with :meth:`predict_proba`
+ after calibration is done.
+
+ Ideally, calibration should be performed on a validation set that was not used to train the embeddings.
+
+ There are two modes of operation, depending on the availability of negative triples:
+
+ #. Both positive and negative triples are provided via ``X_pos`` and ``X_neg`` respectively. \
+ The optimization is done using a second-order method (limited-memory BFGS), \
+ therefore no hyperparameter needs to be specified.
+
+ #. Only positive triples are provided, and the negative triples are generated by corruptions, \
+ just like it is done in training or evaluation. The optimization is done using a first-order method (Adam), \
+ therefore ``batch_size`` and ``epochs`` must be specified.
+
+
+ Calibration is highly dependent on the base rate of positive triples.
+ Therefore, for mode (2) of operation, the user is required to provide the ``positive_base_rate`` argument.
+ For mode (1), that can be inferred automatically by the relative sizes of the positive and negative sets,
+ but the user can override this behaviour by providing a value to ``positive_base_rate``.
+
+ Defining the positive base rate is the biggest challenge when calibrating without negatives. That depends on
+ the user choice of triples to be evaluated during test time.
+ Let's take the WN11 dataset as an example: it has around 50% positive triples in both the validation set
+ and the test set, so the positive base rate is 50%. However, should the user resample it to have
+ 75% positives and 25% negatives, the previous calibration would be degraded. The user must recalibrate
+ the model with a 75% positive base rate. Therefore, this parameter depends on how the user handles the
+ dataset and cannot be determined automatically or a priori.
+
+ .. Note ::
+ :cite:`calibration` `calibration experiments available here
+ `_.
+
+
+ Parameters
+ ----------
+ X_pos : np.array, shape (n,3) or str or GraphDataLoader or AbstractGraphPartitioner
+ Data OR Filename of the data file OR Data Handle to be used as positive triples.
+ X_neg : np.array, shape (n,3) or str or GraphDataLoader or AbstractGraphPartitioner
+ Data OR Filename of the data file OR Data Handle to be used as negative triples.
+
+ If `None`, the negative triples are generated via corruptions
+ and the user must provide a positive base rate instead.
+
+ positive_base_rate: float
+ Base rate of positive statements.
+
+ For example, if we assume there is an even chance for any query to be true, the base rate would be 50%.
+
+ If ``X_neg`` is provided and ``positive_base_rate=None``, the relative sizes of ``X_pos`` and ``X_neg``
+ will be used to determine the base rate. Say we have 50 positive triples and 200 negative
+ triples, the positive base rate will be assumed to be :math:`\\frac{50}{(50+200)} = \\frac{1}{5} = 0.2`.
+
+ This value must be :math:`\\in [0,1]`.
+ batch_size: int
+ Batch size for positives.
+ epochs: int
+ Number of epochs used to train the Platt scaling model.
+ Only applies when ``X_neg=None``.
+ verbose: bool
+ Verbosity (default: `False`).
+
+ Example
+ -------
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> import numpy as np
+ >>> dataset = load_fb15k_237()
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=300,
+ >>> scoring_type='ComplEx')
+ >>> model.compile(optimizer='adam', loss='nll')
+ >>> model.fit(dataset['train'],
+ >>> batch_size=10000,
+ >>> epochs=5)
+ >>> print('Raw scores (sorted):', np.sort(model.predict(dataset['test'])))
+ >>> print('Indices obtained by sorting (scores):', np.argsort(model.predict(dataset['test'])))
+ Raw scores (sorted): [-1.0689778 -0.42082012 -0.39887887 ... 3.261838 3.2755773 3.2768354 ]
+ Indices obtained by sorting (scores): [ 3834 18634 4066 ... 6237 13633 10961]
+ >>> model.calibrate(dataset['test'],
+ >>> batch_size=10000,
+ >>> positive_base_rate=0.9,
+ >>> epochs=100)
+ >>> print('Calibrated scores (sorted):', np.sort(model.predict_proba(dataset['test'])))
+ >>> print('Indices obtained by sorting (Calibrated):', np.argsort(model.predict_proba(dataset['test'])))
+ Calibrated scores (sorted): [0.49547982 0.5396996 0.54118955 ... 0.7624245 0.7631044 0.76316655]
+ Indices obtained by sorting (Calibrated): [ 3834 18634 4066 ... 6237 13633 10961]
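+
+ >>> # Mode (1) sketch: when an array of negative triples is available (here the
+ >>> # hypothetical ``X_neg_triples``), no positive base rate needs to be supplied:
+ >>> # model.calibrate(dataset['valid'], X_neg=X_neg_triples, batch_size=10000)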
+
+ """
+ self.is_calibrated = False
+ data_handler_calibrate_pos = data_adapter.DataHandler(
+ X_pos,
+ batch_size=batch_size,
+ dataset_type="test",
+ epochs=epochs,
+ use_filter=False,
+ use_indexer=self.data_indexer,
+ )
+
+ pos_size = data_handler_calibrate_pos._parent_adapter.get_data_size()
+ neg_size = pos_size
+
+ if X_neg is None:
+ assert (
+ positive_base_rate is not None
+ ), "Please provide the negatives or positive base rate!"
+ self.is_calibrate_with_corruption = True
+ else:
+ self.is_calibrate_with_corruption = False
+
+ pos_batch_count = int(np.ceil(pos_size / batch_size))
+
+ data_handler_calibrate_neg = data_adapter.DataHandler(
+ X_neg,
+ batch_size=batch_size,
+ dataset_type="test",
+ epochs=epochs,
+ use_filter=False,
+ use_indexer=self.data_indexer,
+ )
+
+ neg_size = (
+ data_handler_calibrate_neg._parent_adapter.get_data_size()
+ )
+ neg_batch_count = int(np.ceil(neg_size / batch_size))
+
+ if pos_batch_count != neg_batch_count:
+ batch_size_neg = int(np.ceil(neg_size / pos_batch_count))
+ data_handler_calibrate_neg = data_adapter.DataHandler(
+ X_neg,
+ batch_size=batch_size_neg,
+ dataset_type="test",
+ epochs=epochs,
+ use_filter=False,
+ use_indexer=self.data_indexer,
+ )
+
+ if positive_base_rate is None:
+ positive_base_rate = pos_size / (pos_size + neg_size)
+
+ if positive_base_rate is not None and (
+ positive_base_rate <= 0 or positive_base_rate >= 1
+ ):
+ raise ValueError(
+ "positive_base_rate must be a value between 0 and 1."
+ )
+
+ self.calibration_layer = CalibrationLayer(
+ pos_size, neg_size, positive_base_rate
+ )
+ calibrate_function = self.make_calibrate_function()
+
+ optimizer = tf.keras.optimizers.Adam()
+
+ if not self.is_calibrate_with_corruption:
+ negative_iterator = iter(
+ data_handler_calibrate_neg.enumerate_epochs()
+ )
+
+ for _, iterator in data_handler_calibrate_pos.enumerate_epochs(True):
+ if not self.is_calibrate_with_corruption:
+ _, neg_handle = next(negative_iterator)
+
+ with data_handler_calibrate_pos.catch_stop_iteration():
+ for step in data_handler_calibrate_pos.steps():
+ if self.is_calibrate_with_corruption:
+ scores_pos, scores_neg = calibrate_function(iterator)
+
+ else:
+ scores_pos = calibrate_function(iterator)
+ with data_handler_calibrate_neg.catch_stop_iteration():
+ scores_neg = calibrate_function(neg_handle)
+
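+ # one optimization step of the Platt-scaling parameters on this
+ # batch of positive and negative scores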
+ with tf.GradientTape() as tape:
+ out = self.calibration_layer(scores_pos, scores_neg, 1)
+
+ gradients = tape.gradient(
+ out, self.calibration_layer._trainable_weights
+ )
+ # update the trainable params
+ optimizer.apply_gradients(
+ zip(
+ gradients,
+ self.calibration_layer._trainable_weights,
+ )
+ )
+ self.is_calibrated = True
+
+ def predict_proba(self, x, batch_size=32, verbose=0, callbacks=None):
+ """
+ Compute calibrated scores (:math:`0 \\leq score \\leq 1`) for the input triples.
+
+ Parameters
+ ----------
+ x: np.array, shape (n,3) or str or GraphDataLoader or AbstractGraphPartitioner
+ Data OR Filename of the data file OR Data Handle for which to compute calibrated scores.
+ batch_size: int
+ Batch size to use during prediction.
+ May be overridden if ``x`` is a `GraphDataLoader` or `AbstractGraphPartitioner` instance.
+ verbose: bool
+ Verbosity mode (default: `False`).
+ callbacks: list of keras.callbacks.Callback instances
+ List of callbacks to apply during evaluation.
+
+ Returns
+ -------
+ scores: np.array, shape (n, )
+ Calibrated scores for the input triples.
+
+ Example
+ -------
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> import numpy as np
+ >>> dataset = load_fb15k_237()
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=300,
+ >>> scoring_type='ComplEx')
+ >>> model.compile(optimizer='adam', loss='nll')
+ >>> model.fit(dataset['train'],
+ >>> batch_size=10000,
+ >>> epochs=5)
+ >>> print('Raw scores (sorted):', np.sort(model.predict(dataset['test'])))
+ >>> print('Indices obtained by sorting (scores):', np.argsort(model.predict(dataset['test'])))
+ Raw scores (sorted): [-1.0384613 -0.46752608 -0.45149875 ... 3.2897844 3.3034315 3.3280635 ]
+ Indices obtained by sorting (scores): [ 3834 18634 4066 ... 1355 13633 10961]
+ >>> model.calibrate(dataset['test'],
+ >>> batch_size=10000,
+ >>> positive_base_rate=0.9,
+ >>> epochs=100)
+ >>> print('Calibrated scores (sorted):', np.sort(model.predict_proba(dataset['test'])))
+ >>> print('Indices obtained by sorting (Calibrated):', np.argsort(model.predict_proba(dataset['test'])))
+ Calibrated scores (sorted): [0.5553725 0.5556108 0.5568415 ... 0.6211011 0.62382233 0.6297585 ]
+ Indices obtained by sorting (Calibrated): [14573 11577 4404 ... 17817 17816 733]
+ """
+ if not self.is_calibrated:
+ msg = ("Model has not been calibrated. "
+ "Please call `model.calibrate(...)` before predicting probabilities.")
+
+ raise RuntimeError(msg)
+
+ self.data_handler_test = data_adapter.DataHandler(
+ x,
+ batch_size=batch_size,
+ dataset_type="test",
+ epochs=1,
+ use_filter=False,
+ use_indexer=self.data_indexer,
+ )
+
+ if not isinstance(callbacks, callbacks_module.CallbackList):
+ callbacks = callbacks_module.CallbackList(
+ callbacks,
+ add_history=True,
+ add_progbar=verbose != 0,
+ model=self,
+ verbose=verbose,
+ epochs=1,
+ steps=self.data_handler_test.inferred_steps,
+ )
+
+ predict_function = self.make_predict_function()
+ callbacks.on_predict_begin()
+ outputs = []
+ for _, iterator in self.data_handler_test.enumerate_epochs():
+ with self.data_handler_test.catch_stop_iteration():
+ for step in self.data_handler_test.steps():
+ callbacks.on_predict_batch_begin(step)
+ batch_outputs = predict_function(iterator)
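+ # map the raw scores to calibrated probabilities using the
+ # learned Platt-scaling (calibration) layer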
+ probas = self.calibration_layer(batch_outputs, training=0)
+ outputs.append(probas)
+
+ callbacks.on_predict_batch_end(
+ step, {"outputs": batch_outputs}
+ )
+ callbacks.on_predict_end()
+ return np.concatenate(outputs)
+
+ def get_embeddings(self, entities, embedding_type="e"):
+ """Get the embeddings of entities or relations.
+
+ .. Note ::
+
+ Use :meth:`ampligraph.utils.create_tensorboard_visualizations` to visualize the embeddings with TensorBoard.
+
+ Parameters
+ ----------
+ entities : array-like, shape=(n)
+ The entities (or relations) of interest. Elements of the vector must be the original string literals,
+ and not internal IDs.
+ embedding_type : str
+ If `'e'` is passed, ``entities`` argument will be considered as a list of knowledge graph entities
+ (i.e., nodes). If set to `'r'`, ``entities`` will be treated as relations instead.
+
+ Returns
+ -------
+ embeddings : ndarray, shape (n, k)
+ An array of `k`-dimensional embeddings.
+
+ Example
+ -------
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> from ampligraph.datasets import load_fb15k_237
+ >>> X = load_fb15k_237()
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=300,
+ >>> scoring_type='ComplEx',
+ >>> seed=0)
+ >>> model.compile(optimizer='adam', loss='nll')
+ >>> model.fit(X['train'],
+ >>> batch_size=10000,
+ >>> epochs=5,
+ >>> verbose=False)
+ >>> model.get_embeddings(['/m/027rn', '/m/06v8s0'], 'e')
+ array([[ 0.04482496 0.11973907 0.01117733 ... -0.13391922 0.11103553 -0.08132861]
+ [-0.10158381 0.08108605 -0.07608676 ... 0.0591407 0.02791426 0.07559016]], dtype=float32)
+ """
+
+ if embedding_type == "e":
+ lookup_concept = self.data_indexer.get_indexes(entities, "e")
+ if self.is_partitioned_training:
+ emb_out = []
+ with shelve.open(
+ self.partitioner_metadata["ent_map_fname"]
+ ) as ent_emb:
+ for ent_id in lookup_concept:
+ emb_out.append(ent_emb[str(ent_id)])
+ return np.array(emb_out)
+ else:
+ return tf.nn.embedding_lookup(
+ self.encoding_layer.ent_emb, lookup_concept
+ ).numpy()
+ elif embedding_type == "r":
+ lookup_concept = self.data_indexer.get_indexes(entities, "r")
+ if self.is_partitioned_training:
+ emb_out = []
+ with shelve.open(
+ self.partitioner_metadata["rel_map_fname"]
+ ) as rel_emb:
+ for rel_id in lookup_concept:
+ emb_out.append(rel_emb[str(rel_id)])
+ return np.array(emb_out)
+ else:
+ return tf.nn.embedding_lookup(
+ self.encoding_layer.rel_emb, lookup_concept
+ ).numpy()
+ else:
+ msg = "Invalid entity type: {}".format(embedding_type)
+ raise ValueError(msg)
diff --git a/ampligraph/latent_features/models/TransE.py b/ampligraph/latent_features/models/TransE.py
deleted file mode 100644
index 3743dcb8..00000000
--- a/ampligraph/latent_features/models/TransE.py
+++ /dev/null
@@ -1,337 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-from .EmbeddingModel import EmbeddingModel, register_model
-from ampligraph.latent_features import constants as constants
-from ampligraph.latent_features.initializers import DEFAULT_XAVIER_IS_UNIFORM
-import tensorflow as tf
-
-
-@register_model("TransE",
- ["norm", "normalize_ent_emb", "negative_corruption_entities"])
-class TransE(EmbeddingModel):
- r"""Translating Embeddings (TransE)
-
- The model as described in :cite:`bordes2013translating`.
-
- The scoring function of TransE computes a similarity between the embedding of the subject
- :math:`\mathbf{e}_{sub}` translated by the embedding of the predicate :math:`\mathbf{e}_{pred}`
- and the embedding of the object :math:`\mathbf{e}_{obj}`,
- using the :math:`L_1` or :math:`L_2` norm :math:`||\cdot||`:
-
- .. math::
-
- f_{TransE}=-||\mathbf{e}_{sub} + \mathbf{e}_{pred} - \mathbf{e}_{obj}||_n
-
-
- Such scoring function is then used on positive and negative triples :math:`t^+, t^-` in the loss function.
-
- Examples
- --------
- >>> import numpy as np
- >>> from ampligraph.latent_features import TransE
- >>> model = TransE(batches_count=1, seed=555, epochs=20, k=10, loss='pairwise',
- >>> loss_params={'margin':5})
- >>> X = np.array([['a', 'y', 'b'],
- >>> ['b', 'y', 'a'],
- >>> ['a', 'y', 'c'],
- >>> ['c', 'y', 'a'],
- >>> ['a', 'y', 'd'],
- >>> ['c', 'y', 'd'],
- >>> ['b', 'y', 'c'],
- >>> ['f', 'y', 'e']])
- >>> model.fit(X)
- >>> model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
- [-4.6903257, -3.9047198]
- >>> model.get_embeddings(['f','e'], embedding_type='entity')
- array([[ 0.10673896, -0.28916815, 0.6278883 , -0.1194713 , -0.10372276,
- -0.37258488, 0.06460134, -0.27879423, 0.25456288, 0.18665907],
- [-0.64494324, -0.12939683, 0.3181001 , 0.16745451, -0.03766293,
- 0.24314676, -0.23038973, -0.658638 , 0.5680542 , -0.05401703]],
- dtype=float32)
-
- """
-
- def __init__(self,
- k=constants.DEFAULT_EMBEDDING_SIZE,
- eta=constants.DEFAULT_ETA,
- epochs=constants.DEFAULT_EPOCH,
- batches_count=constants.DEFAULT_BATCH_COUNT,
- seed=constants.DEFAULT_SEED,
- embedding_model_params={'norm': constants.DEFAULT_NORM_TRANSE,
- 'normalize_ent_emb': constants.DEFAULT_NORMALIZE_EMBEDDINGS,
- 'negative_corruption_entities': constants.DEFAULT_CORRUPTION_ENTITIES,
- 'corrupt_sides': constants.DEFAULT_CORRUPT_SIDE_TRAIN},
- optimizer=constants.DEFAULT_OPTIM,
- optimizer_params={'lr': constants.DEFAULT_LR},
- loss=constants.DEFAULT_LOSS,
- loss_params={},
- regularizer=constants.DEFAULT_REGULARIZER,
- regularizer_params={},
- initializer=constants.DEFAULT_INITIALIZER,
- initializer_params={'uniform': DEFAULT_XAVIER_IS_UNIFORM},
- verbose=constants.DEFAULT_VERBOSE):
- """
- Initialize an EmbeddingModel.
-
- Also creates a new Tensorflow session for training.
-
- Parameters
- ----------
- k : int
- Embedding space dimensionality.
- eta : int
- The number of negatives that must be generated at runtime during training for each positive.
- epochs : int
- The iterations of the training loop.
- batches_count : int
- The number of batches in which the training set must be split during the training loop.
- seed : int
- The seed used by the internal random numbers generator.
- embedding_model_params : dict
- TransE-specific hyperparams, passed to the model as a dictionary.
-
- Supported keys:
-
- - **'norm'** (int): the norm to be used in the scoring function (1 or 2-norm - default: 1).
- - **'normalize_ent_emb'** (bool): flag to indicate whether to normalize entity embeddings
- after each batch update (default: False).
- - **negative_corruption_entities** : entities to be used for generation of corruptions while training.
- It can take the following values :
- ``all`` (default: all entities),
- ``batch`` (entities present in each batch),
- list of entities
- or an int (which indicates how many entities that should be used for corruption generation).
- - **corrupt_sides** : Specifies how to generate corruptions for training.
- Takes values `s`, `o`, `s+o` or any combination passed as a list.
- - **'non_linearity'**: can be one of the following values ``linear``, ``softplus``, ``sigmoid``, ``tanh``
- - **'stop_epoch'**: specifies how long to decay (linearly) the numeric values from 1 to original value
- until it reachs original value.
- - **'structural_wt'**: structural influence hyperparameter [0, 1] that modulates the influence of graph
- topology.
- - **'normalize_numeric_values'**: normalize the numeric values, such that they are scaled between [0, 1]
-
- The last 4 parameters are related to FocusE layers.
-
- Example: ``embedding_model_params={'norm': 1, 'normalize_ent_emb': False}``
-
- optimizer : string
- The optimizer used to minimize the loss function. Choose between 'sgd',
- 'adagrad', 'adam', 'momentum'.
- optimizer_params : dict
- Arguments specific to the optimizer, passed as a dictionary.
-
- Supported keys:
-
- - **'lr'** (float): learning rate (used by all the optimizers). Default: 0.1.
- - **'momentum'** (float): learning momentum (only used when ``optimizer=momentum``). Default: 0.9.
-
- Example: ``optimizer_params={'lr': 0.01}``
-
- loss : string
- The type of loss function to use during training.
-
- - ``pairwise`` the model will use pairwise margin-based loss function.
- - ``nll`` the model will use negative loss likelihood.
- - ``absolute_margin`` the model will use absolute margin likelihood.
- - ``self_adversarial`` the model will use adversarial sampling loss function.
- - ``multiclass_nll`` the model will use multiclass nll loss.
- Switch to multiclass loss defined in :cite:`chen2015`
- by passing 'corrupt_sides' as ['s','o'] to embedding_model_params.
- To use loss defined in :cite:`kadlecBK17` pass 'corrupt_sides' as 'o' to embedding_model_params.
-
- loss_params : dict
- Dictionary of loss-specific hyperparameters. See :ref:`loss functions `
- documentation for additional details.
-
- Example: ``optimizer_params={'lr': 0.01}`` if ``loss='pairwise'``.
-
- regularizer : string
- The regularization strategy to use with the loss function.
-
- - ``None``: the model will not use any regularizer (default)
- - 'LP': the model will use L1, L2 or L3 based on the value of ``regularizer_params['p']`` (see below).
-
- regularizer_params : dict
- Dictionary of regularizer-specific hyperparameters. See the :ref:`regularizers `
- documentation for additional details.
-
- Example: ``regularizer_params={'lambda': 1e-5, 'p': 2}`` if ``regularizer='LP'``.
-
- initializer : string
- The type of initializer to use.
-
- - ``normal``: The embeddings will be initialized from a normal distribution
- - ``uniform``: The embeddings will be initialized from a uniform distribution
- - ``xavier``: The embeddings will be initialized using xavier strategy (default)
-
- initializer_params : dict
- Dictionary of initializer-specific hyperparameters. See the
- :ref:`initializer `
- documentation for additional details.
-
- Example: ``initializer_params={'mean': 0, 'std': 0.001}`` if ``initializer='normal'``.
-
-
- verbose : bool
- Verbose mode
- """
- super().__init__(k=k, eta=eta, epochs=epochs,
- batches_count=batches_count, seed=seed,
- embedding_model_params=embedding_model_params,
- optimizer=optimizer, optimizer_params=optimizer_params,
- loss=loss, loss_params=loss_params,
- regularizer=regularizer, regularizer_params=regularizer_params,
- initializer=initializer, initializer_params=initializer_params,
- verbose=verbose)
-
- def _fn(self, e_s, e_p, e_o):
- r"""The TransE scoring function.
-
- .. math::
-
- f_{TransE}=-||(\mathbf{e}_s + \mathbf{r}_p) - \mathbf{e}_o||_n
-
- Parameters
- ----------
- e_s : Tensor, shape [n]
- The embeddings of a list of subjects.
- e_p : Tensor, shape [n]
- The embeddings of a list of predicates.
- e_o : Tensor, shape [n]
- The embeddings of a list of objects.
-
- Returns
- -------
- score : TensorFlow operation
- The operation corresponding to the TransE scoring function.
-
- """
-
- return tf.negative(
- tf.norm(e_s + e_p - e_o, ord=self.embedding_model_params.get('norm', constants.DEFAULT_NORM_TRANSE),
- axis=1))
-
- def fit(self, X, early_stopping=False, early_stopping_params={}, focusE_numeric_edge_values=None,
- tensorboard_logs_path=None):
- """Train an Translating Embeddings model.
-
- The model is trained on a training set X using the training protocol
- described in :cite:`trouillon2016complex`.
-
- Parameters
- ----------
- X : ndarray, shape [n, 3]
- The training triples
- early_stopping: bool
- Flag to enable early stopping (default:False).
-
- If set to ``True``, the training loop adopts the following early stopping heuristic:
-
- - The model will be trained regardless of early stopping for ``burn_in`` epochs.
- - Every ``check_interval`` epochs the method will compute the metric specified in ``criteria``.
-
- If such metric decreases for ``stop_interval`` checks, we stop training early.
-
- Note the metric is computed on ``x_valid``. This is usually a validation set that you held out.
-
- Also, because ``criteria`` is a ranking metric, it requires generating negatives.
- Entities used to generate corruptions can be specified, as long as the side(s) of a triple to corrupt.
- The method supports filtered metrics, by passing an array of positives to ``x_filter``. This will be used to
- filter the negatives generated on the fly (i.e. the corruptions).
-
- .. note::
-
- Keep in mind the early stopping criteria may introduce a certain overhead
- (caused by the metric computation).
- The goal is to strike a good trade-off between such overhead and saving training epochs.
-
- A common approach is to use MRR unfiltered: ::
-
- early_stopping_params={x_valid=X['valid'], 'criteria': 'mrr'}
-
- Note the size of validation set also contributes to such overhead.
- In most cases a smaller validation set would be enough.
-
- early_stopping_params: dictionary
- Dictionary of hyperparameters for the early stopping heuristics.
-
- The following string keys are supported:
-
- - **'x_valid'**: ndarray, shape [n, 3] : Validation set to be used for early stopping.
- - **'criteria'**: string : criteria for early stopping 'hits10', 'hits3', 'hits1' or 'mrr'(default).
- - **'x_filter'**: ndarray, shape [n, 3] : Positive triples to use as filter if a 'filtered'
- early stopping criteria is desired (i.e. filtered-MRR if 'criteria':'mrr').
- Note this will affect training time (no filter by default).
- - **'burn_in'**: int : Number of epochs to pass before kicking in early stopping (default: 100).
- - **check_interval'**: int : Early stopping interval after burn-in (default:10).
- - **'stop_interval'**: int : Stop if criteria is performing worse over n consecutive checks (default: 3)
- - **'corruption_entities'**: List of entities to be used for corruptions.
- If 'all', it uses all entities (default: 'all')
- - **'corrupt_side'**: Specifies which side to corrupt. 's', 'o', 's+o' (default)
-
- Example: ``early_stopping_params={x_valid=X['valid'], 'criteria': 'mrr'}``
-
- focusE_numeric_edge_values: ndarray, shape [n]
- .. _focuse_transe:
-
- If processing a knowledge graph with numeric values associated with links, this is the vector of such
- numbers. Passing this argument will activate the :ref:`FocusE layer `
- :cite:`pai2021learning`.
- Semantically, numeric values can signify importance, uncertainity, significance, confidence, etc.
- Values can be any number, and will be automatically normalised to the [0, 1] range, on a
- predicate-specific basis.
- If the numeric value is unknown pass a ``np.NaN`` value.
- The model will uniformly randomly assign a numeric value.
-
- .. note::
-
- The following toy example shows how to enable the FocusE layer
- to process edges with numeric literals: ::
-
- import numpy as np
- from ampligraph.latent_features import TransE
- model = TransE(batches_count=1, seed=555, epochs=20,
- k=10, loss='pairwise',
- loss_params={'margin':5})
- X = np.array([['a', 'y', 'b'],
- ['b', 'y', 'a'],
- ['a', 'y', 'c'],
- ['c', 'y', 'a'],
- ['a', 'y', 'd'],
- ['c', 'y', 'd'],
- ['b', 'y', 'c'],
- ['f', 'y', 'e']])
-
- # Numeric values below are associate to each triple in X.
- # They can be any number and will be automatically
- # normalised to the [0, 1] range, on a
- # predicate-specific basis.
- X_edge_values = np.array([5.34, -1.75, 0.33, 5.12,
- np.nan, 3.17, 2.76, 0.41])
-
- model.fit(X, focusE_numeric_edge_values=X_edge_values)
-
- tensorboard_logs_path: str or None
- Path to store tensorboard logs, e.g. average training loss tracking per epoch (default: ``None`` indicating
- no logs will be collected). When provided it will create a folder under provided path and save tensorboard
- files there. To then view the loss in the terminal run: ``tensorboard --logdir ``.
- """
- super().fit(X, early_stopping, early_stopping_params, focusE_numeric_edge_values,
- tensorboard_logs_path=tensorboard_logs_path)
-
- def predict(self, X, from_idx=False):
- __doc__ = super().predict.__doc__ # NOQA
- return super().predict(X, from_idx=from_idx)
-
- def calibrate(self, X_pos, X_neg=None, positive_base_rate=None, batches_count=100, epochs=50):
- __doc__ = super().calibrate.__doc__ # NOQA
- super().calibrate(X_pos, X_neg, positive_base_rate, batches_count, epochs)
-
- def predict_proba(self, X):
- __doc__ = super().predict_proba.__doc__ # NOQA
- return super().predict_proba(X)
diff --git a/ampligraph/latent_features/models/__init__.py b/ampligraph/latent_features/models/__init__.py
index 9857c5e7..55632b7a 100644
--- a/ampligraph/latent_features/models/__init__.py
+++ b/ampligraph/latent_features/models/__init__.py
@@ -1,17 +1,10 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-from .EmbeddingModel import EmbeddingModel
-from .TransE import TransE
-from .DistMult import DistMult
-from .ComplEx import ComplEx
-from .HolE import HolE
-from .RandomBaseline import RandomBaseline
-from .ConvKB import ConvKB
-from .ConvE import ConvE
+from .ScoringBasedEmbeddingModel import ScoringBasedEmbeddingModel
-__all__ = ['EmbeddingModel', 'TransE', 'DistMult', 'ComplEx', 'HolE', 'ConvKB', 'ConvE', 'RandomBaseline']
+__all__ = ["ScoringBasedEmbeddingModel"]
diff --git a/ampligraph/latent_features/optimizers.py b/ampligraph/latent_features/optimizers.py
index 4031ea70..619a55b3 100644
--- a/ampligraph/latent_features/optimizers.py
+++ b/ampligraph/latent_features/optimizers.py
@@ -1,451 +1,202 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-import tensorflow as tf
import abc
import logging
-import math
+import six
+import tensorflow as tf
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
-OPTIMIZER_REGISTRY = {}
-
-
-def register_optimizer(name, external_params=[], class_params={}):
- def insert_in_registry(class_handle):
- OPTIMIZER_REGISTRY[name] = class_handle
- class_handle.name = name
- OPTIMIZER_REGISTRY[name].external_params = external_params
- OPTIMIZER_REGISTRY[name].class_params = class_params
- return class_handle
-
- return insert_in_registry
-
-
-# Default learning rate for the optimizers
-DEFAULT_LR = 0.0005
-
-# Default momentum for the optimizers
-DEFAULT_MOMENTUM = 0.9
-
-DEFAULT_DECAY_CYCLE = 0
-
-DEFAULT_DECAY_CYCLE_MULTIPLE = 1
-
-DEFAULT_LR_DECAY_FACTOR = 2
-
-DEFAULT_END_LR = 1e-8
-
-DEFAULT_SINE = False
-
-
-class Optimizer(abc.ABC):
- """Abstract class for optimizer .
- """
- name = ""
- external_params = []
- class_params = {}
+class OptimizerWrapper(abc.ABC):
+ """Wrapper around tensorflow optimizer."""
- def __init__(self, optimizer_params, batches_count, verbose):
- """Initialize the Optimizer
+ def __init__(self, optimizer=None):
+ """Initialize the tensorflow Optimizer and wraps it so that it can be used with graph partitioning.
Parameters
----------
- optimizer_params : dict
- Consists of key-value pairs. The initializer will check the keys to get the corresponding params.
- batches_count: int
- number of batches in an epoch
- verbose : bool
- Enable/disable verbose mode
+ optimizer: str (name of the optimizer) or instance of `tf.keras.optimizers.Optimizer`.
+ See `tf.keras.optimizers` for the available optimizers.
"""
+ self.optimizer = optimizer
+ self.num_optimized_vars = 0
+ # number of optimizer hyperparams - Adam has 2 if amsgrad is False
+ self.number_hyperparams = 1
+ self.is_partitioned_training = False
- self.verbose = verbose
- self._optimizer_params = {}
- self._init_hyperparams(optimizer_params)
- self.batches_count = batches_count
+ # workaround for Adagrad/Adadelta/Ftrl optimizers to work on gpu
+ self.gpu_workaround = False
+ if (
+ isinstance(self.optimizer, tf.keras.optimizers.Adadelta)
+ or isinstance(self.optimizer, tf.keras.optimizers.Adagrad)
+ or isinstance(self.optimizer, tf.keras.optimizers.Ftrl)
+ ):
+ self.gpu_workaround = True
- def _display_params(self):
- """Display the parameter values
- """
- logger.info('\n------ Optimizer -----')
- logger.info('Name : {}'.format(self.name))
- for key, value in self._optimizer_params.items():
- logger.info('{} : {}'.format(key, value))
+ if isinstance(self.optimizer, tf.keras.optimizers.Adam):
+ self.number_hyperparams = 2
- def _init_hyperparams(self, hyperparam_dict):
- """ Initializes the hyperparameters needed by the algorithm.
+ def apply_gradients(self, grads_and_vars):
+ """Wrapper around apply_gradients.
- Parameters
- ----------
- hyperparam_dict : dictionary
- Consists of key value pairs. The optimizer will check the keys to get the corresponding params
+ See `tf.keras.optimizers` for more details.
"""
+ self.optimizer.apply_gradients(grads_and_vars)
- self._optimizer_params['lr'] = hyperparam_dict.get('lr', DEFAULT_LR)
- if self.verbose:
- self._display_params()
+ def set_partitioned_training(self, value=True):
+ self.is_partitioned_training = value
- def minimize(self, loss):
- """Create an optimizer to minimize the model loss
+ def minimize(self, loss, ent_emb, rel_emb, gradient_tape, other_vars=[]):
+ """Minimizes the loss with respect to entity and relation embeddings and other trainable variables.
Parameters
----------
loss: tf.Tensor
- Node which needs to be evaluated for computing the model loss.
+ Model Loss.
+ ent_emb: tf.Variable
+ Entity embedding.
+ rel_emb: tf.Variable
+ Relation embedding.
+ gradient_tape: tf.GradientTape
+ Gradient tape under which the loss computation was tracked.
+ other_vars: list
+ List of all the other trainable variables.
+ """
+ all_trainable_vars = [ent_emb, rel_emb]
+ all_trainable_vars.extend(other_vars)
+ # Total number of trainable variables in the graph
+ self.num_optimized_vars = len(all_trainable_vars)
+
+ if self.gpu_workaround:
+ # workaround - see the issue:
+ # https://github.com/tensorflow/tensorflow/issues/28090
+ with gradient_tape:
+ loss += 0.0000 * (
+ tf.reduce_sum(ent_emb) + tf.reduce_sum(rel_emb)
+ )
+
+ # Compute gradient of loss wrt trainable vars
+ gradients = gradient_tape.gradient(loss, all_trainable_vars)
+ # update the trainable params
+ self.optimizer.apply_gradients(zip(gradients, all_trainable_vars))
+
+ # Compute the number of hyperparameters related to the optimizer
+ # if self.is_partitioned_training and self.number_hyperparams == -1:
+ # optim_weights = self.optimizer.get_weights()
+ # self.number_hyperparams = 0
+ # for i in range(1, len(optim_weights), self.num_optimized_vars):
+ # self.number_hyperparams += 1
+
+ def get_hyperparam_count(self):
+ """Number of hyperparams of the optimizer being used.
+
+ E.g., `adam` has `beta1` and `beta2`; if the `amsgrad` argument is used, it also has a third.
+ """
+ return self.number_hyperparams
+
+ def get_entity_relation_hyperparams(self):
+ """Get optimizer hyperparams related to entity and relation embeddings (for partitioned training).
Returns
-------
- train: tf.Operation
- Node that needs to be evaluated for minimizing the loss during training
+ ent_hyperparams: np.array
+ Entity embedding related optimizer hyperparameters.
+ rel_hyperparams: np.array
+ Relation embedding related optimizer hyperparameters.
"""
- raise NotImplementedError('Abstract Method not implemented!')
-
- def update_feed_dict(self, feed_dict, batch_num, epoch_num):
- """Fills values of placeholders created by the optimizers.
+ optim_weights = self.optimizer.get_weights()
+ ent_hyperparams = []
+ rel_hyperparams = []
+ for i in range(1, len(optim_weights), self.num_optimized_vars):
+ ent_hyperparams.append(optim_weights[i])
+ rel_hyperparams.append(optim_weights[i + 1])
- Parameters
- ----------
- feed_dict : dict
- Dictionary that would be passed while optimizing the model loss to sess.run.
- batch_num: int
- current batch number
- epoch_num: int
- current epoch number
- """
- raise NotImplementedError('Abstract Method not implemented!')
+ return ent_hyperparams, rel_hyperparams
-
-@register_optimizer("adagrad", ['lr'])
-class AdagradOptimizer(Optimizer):
- """Wrapper around adagrad optimizer
- """
-
- def __init__(self, optimizer_params, batches_count, verbose=False):
- """Initialize the Optimizer
+ def set_entity_relation_hyperparams(
+ self, ent_hyperparams, rel_hyperparams
+ ):
+ """Sets optimizer hyperparams related to entity and relation embeddings (for partitioned training).
Parameters
----------
- optimizer_params : dict
- Consists of key-value pairs. The optimizer will check the keys to get the corresponding params:
-
- - **'lr'**: (float). Learning Rate (default: 0.0005)
-
- Example: ``optimizer_params={'lr': 0.001}``
- batches_count: int
- number of batches in an epoch
- verbose : bool
- Enable/disable verbose mode
+ ent_hyperparams: np.array
+ Entity embedding related optimizer hyperparameters.
+ rel_hyperparams: np.array
+ Relation embedding related optimizer hyperparameters.
"""
+ optim_weights = self.optimizer.get_weights()
+ for i, j in zip(
+ range(1, len(optim_weights), self.num_optimized_vars),
+ range(len(ent_hyperparams)),
+ ):
+ optim_weights[i] = ent_hyperparams[j]
+ optim_weights[i + 1] = rel_hyperparams[j]
+ self.optimizer.set_weights(optim_weights)
- super(AdagradOptimizer, self).__init__(optimizer_params, batches_count, verbose)
+ def get_weights(self):
+ """Wrapper around get weights.
- def minimize(self, loss):
- """Create an optimizer to minimize the model loss
-
- Parameters
- ----------
- loss: tf.Tensor
- Node which needs to be evaluated for computing the model loss.
-
- Returns
- -------
- train: tf.Operation
- Node that needs to be evaluated for minimizing the loss during training
+ See `tf.keras.optimizers` for more details.
"""
- self.optimizer = tf.train.AdagradOptimizer(learning_rate=self._optimizer_params['lr'])
- train = self.optimizer.minimize(loss)
- return train
-
- def update_feed_dict(self, feed_dict, batch_num, epoch_num):
- """Fills values of placeholders created by the optimizers.
-
- Parameters
- ----------
- feed_dict : dict
- Dictionary that would be passed while optimizing the model loss to sess.run.
- batch_num: int
- current batch number
- epoch_num: int
- current epoch number
- """
- return
-
-
-@register_optimizer("adam", ['lr'])
-class AdamOptimizer(Optimizer):
- """Wrapper around Adam Optimizer
- """
-
- def __init__(self, optimizer_params, batches_count, verbose=False):
- """Initialize the Optimizer
-
- Parameters
- ----------
- optimizer_params : dict
- Consists of key-value pairs. The optimizer will check the keys to get the corresponding params:
+ return self.optimizer.get_weights()
- - **'lr'**: (float). Learning Rate (default: 0.0005)
+ def set_weights(self, weights):
+ """Wrapper around set weights.
- Example: ``optimizer_params={'lr': 0.001}``
- batches_count: int
- number of batches in an epoch
- verbose : bool
- Enable/disable verbose mode
+ See `tf.keras.optimizers` for more details.
"""
+ self.optimizer.set_weights(weights)
- super(AdamOptimizer, self).__init__(optimizer_params, batches_count, verbose)
-
- def minimize(self, loss):
- """Create an optimizer to minimize the model loss
+ def get_iterations(self):
+ return self.optimizer.iterations.numpy()
- Parameters
- ----------
- loss: tf.Tensor
- Node which needs to be evaluated for computing the model loss.
+ def get_config(self):
+ return self.optimizer.get_config()
- Returns
- -------
- train: tf.Operation
- Node that needs to be evaluated for minimizing the loss during training
- """
- self.optimizer = tf.train.AdamOptimizer(learning_rate=self._optimizer_params['lr'])
+ @classmethod
+ def from_config(cls, config):
+ new_config = {}
+ new_config["class_name"] = config["name"]
- train = self.optimizer.minimize(loss)
- return train
+ del config["name"]
+ new_config["config"] = config
+ optimizer = tf.keras.optimizers.get(new_config)
+ return optimizer
- def update_feed_dict(self, feed_dict, batch_num, epoch_num):
- """Fills values of placeholders created by the optimizers.
- Parameters
- ----------
- feed_dict : dict
- Dictionary that would be passed while optimizing the model loss to sess.run.
- batch_num: int
- current batch number
- epoch_num: int
- current epoch number
- """
- return
-
-
-@register_optimizer("momentum", ['lr', 'momentum'])
-class MomentumOptimizer(Optimizer):
- """Wrapper around Momentum Optimizer
+def get(identifier):
"""
+ Get the optimizer specified by the identifier.
- def __init__(self, optimizer_params, batches_count, verbose=False):
- """Initialize the Optimizer
+ Parameters
+ ----------
+ identifier: str or tf.optimizers.Optimizer instance
+ Name of the optimizer to use (with default parameters) or instance of the class `tf.optimizers.Optimizer`.
- Parameters
- ----------
- optimizer_params : dict
- Consists of key-value pairs. The optimizer will check the keys to get the corresponding params:
+ Returns
+ -------
+ optimizer: OptimizerWrapper
+ Instance of `tf.optimizers.Optimizer` wrapped in an `OptimizerWrapper` so that graph partitioning
+ is supported.
- - **'lr'**: (float). Learning Rate (default: 0.0005)
- - **'momentum'**: (float). Momentum (default: 0.9)
-
- Example: ``optimizer_params={'lr': 0.001, 'momentum':0.90}``
- batches_count: int
- number of batches in an epoch
- verbose : bool
- Enable/disable verbose mode
- """
-
- super(MomentumOptimizer, self).__init__(optimizer_params, batches_count, verbose)
-
- def _init_hyperparams(self, hyperparam_dict):
- """ Initializes the hyperparameters needed by the algorithm.
-
- Parameters
- ----------
- hyperparam_dict : dictionary
- Consists of key value pairs. The optimizer will check the keys to get the corresponding params
- """
-
- self._optimizer_params['lr'] = hyperparam_dict.get('lr', DEFAULT_LR)
- self._optimizer_params['momentum'] = hyperparam_dict.get('momentum', DEFAULT_MOMENTUM)
-
- if self.verbose:
- self._display_params()
-
- def minimize(self, loss):
- """Create an optimizer to minimize the model loss
-
- Parameters
- ----------
- loss: tf.Tensor
- Node which needs to be evaluated for computing the model loss.
-
- Returns
- -------
- train: tf.Operation
- Node that needs to be evaluated for minimizing the loss during training
- """
- self.optimizer = tf.train.MomentumOptimizer(learning_rate=self._optimizer_params['lr'],
- momentum=self._optimizer_params['momentum'])
-
- train = self.optimizer.minimize(loss)
- return train
-
- def update_feed_dict(self, feed_dict, batch_num, epoch_num):
- """Fills values of placeholders created by the optimizers.
-
- Parameters
- ----------
- feed_dict : dict
- Dictionary that would be passed while optimizing the model loss to sess.run.
- batch_num: int
- current batch number
- epoch_num: int
- current epoch number
- """
- return
-
-
-@register_optimizer("sgd", ['lr', 'decay_cycle', 'end_lr', 'sine_decay', 'expand_factor', 'decay_lr_rate'])
-class SGDOptimizer(Optimizer):
- '''Wrapper around SGD Optimizer
- '''
- def __init__(self, optimizer_params, batches_count, verbose=False):
- """Initialize the Optimizer
-
- Parameters
- ----------
- optimizer_params : dict
- Consists of key-value pairs. The optimizer will check the keys to get the corresponding params:
-
- - **'lr'**: (float). Learning Rate upper bound (default: 0.0005)
- - **'decay_cycle'**: (int). Cycle of epoch over which to decay (default: 0)
- - **'end_lr'**: (float). Learning Rate lower bound (default: 1e-8)
- - **'cosine_decay'**: (bool). Use cosine decay or to fixed rate decay (default: False)
- - **'expand_factor'**: (float). Expand the decay cycle length by this factor after each cycle \
- (default: 1)
- - **'decay_lr_rate'**: (float). Decay factor to decay the start lr after each cycle \
- (default: 2)
-
- Example: ``optimizer_params={'lr': 0.01, 'decay_cycle':30, 'end_lr':0.0001, 'sine_decay':True}``
- batches_count: int
- number of batches in an epoch
- verbose : bool
- Enable/disable verbose mode
- """
- super(SGDOptimizer, self).__init__(optimizer_params, batches_count, verbose)
-
- def _init_hyperparams(self, hyperparam_dict):
- """ Initializes the hyperparameters needed by the algorithm.
-
- Parameters
- ----------
- hyperparam_dict : dictionary
- Consists of key value pairs. The optimizer will check the keys to get the corresponding params
- """
-
- self._optimizer_params['lr'] = hyperparam_dict.get('lr', DEFAULT_LR)
- self._optimizer_params['decay_cycle'] = hyperparam_dict.get('decay_cycle', DEFAULT_DECAY_CYCLE)
- self._optimizer_params['cosine_decay'] = hyperparam_dict.get('cosine_decay', DEFAULT_SINE)
- self._optimizer_params['expand_factor'] = hyperparam_dict.get('expand_factor', DEFAULT_DECAY_CYCLE_MULTIPLE)
- self._optimizer_params['decay_lr_rate'] = hyperparam_dict.get('decay_lr_rate', DEFAULT_LR_DECAY_FACTOR)
- self._optimizer_params['end_lr'] = hyperparam_dict.get('end_lr', DEFAULT_END_LR)
-
- if self.verbose:
- self._display_params()
-
- def minimize(self, loss):
- """Create an optimizer to minimize the model loss
-
- Parameters
- ----------
- loss: tf.Tensor
- Node which needs to be evaluated for computing the model loss.
-
- Returns
- -------
- train: tf.Operation
- Node that needs to be evaluated for minimizing the loss during training
- """
-
- # create a placeholder for learning rate
- self.lr_placeholder = tf.placeholder(tf.float32)
- # create the optimizer with the placeholder
- self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.lr_placeholder)
-
- # load the hyperparameters that would be used while generating the learning rate per batch
- # start learning rate
- self.start_lr = self._optimizer_params['lr']
- self.current_lr = self.start_lr
-
- # cycle rate for learning rate decay
- self.decay_cycle_rate = self._optimizer_params['decay_cycle']
- self.end_lr = self._optimizer_params['end_lr']
-
- # check if it is a sinudoidal decay or constant decay
- self.is_cosine_decay = self._optimizer_params['cosine_decay']
- self.next_cycle_epoch = self.decay_cycle_rate + 1
-
- # Get the cycle expand factor
- self.decay_cycle_expand_factor = self._optimizer_params['expand_factor']
-
- # Get the LR decay factor at the start of each cycle
- self.decay_lr_rate = self._optimizer_params['decay_lr_rate']
- self.curr_cycle_length = self.decay_cycle_rate
- self.curr_start = 0
-
- # create the operation that minimizes the loss
- train = self.optimizer.minimize(loss)
- return train
-
- def update_feed_dict(self, feed_dict, batch_num, epoch_num):
- """Fills values of placeholders created by the optimizers.
-
- Parameters
- ----------
- feed_dict : dict
- Dictionary that would be passed while optimizing the model loss to sess.run.
- batch_num: int
- current batch number
- epoch_num: int
- current epoch number
- """
- # Sinusoidal Decay
- if self.is_cosine_decay:
- # compute the cycle number
- current_cycle_num = \
- ((epoch_num - 1 - self.curr_start) * self.batches_count + (batch_num - 1)) / \
- (self.curr_cycle_length * self.batches_count)
- # compute a learning rate for the current batch/epoch
- self.current_lr = \
- self.end_lr + (self.start_lr - self.end_lr) * 0.5 * (1 + math.cos(math.pi * current_cycle_num))
-
- # Start the next cycle and Expand the cycle/Decay the learning rate
- if epoch_num % (self.next_cycle_epoch - 1) == 0 and batch_num == self.batches_count:
- self.curr_cycle_length = self.curr_cycle_length * self.decay_cycle_expand_factor
- self.next_cycle_epoch = self.next_cycle_epoch + self.curr_cycle_length
- self.curr_start = epoch_num
- self.start_lr = self.start_lr / self.decay_lr_rate
-
- if self.current_lr < self.end_lr:
- self.current_lr = self.end_lr
-
- # fixed rate decay
- elif self.decay_cycle_rate > 0:
- if epoch_num % (self.next_cycle_epoch) == 0 and batch_num == 1:
- if self.current_lr > self.end_lr:
- self.next_cycle_epoch = self.decay_cycle_rate + \
- ((self.next_cycle_epoch - 1) * self.decay_cycle_expand_factor) + 1
- self.current_lr = self.current_lr / self.decay_lr_rate
-
- if self.current_lr < self.end_lr:
- self.current_lr = self.end_lr
-
- # no change to the learning rate
- else:
- pass
-
- feed_dict.update({self.lr_placeholder: self.current_lr})
+ """
+ if isinstance(identifier, tf.optimizers.Optimizer):
+ return OptimizerWrapper(identifier)
+ elif isinstance(identifier, OptimizerWrapper):
+ return identifier
+ elif isinstance(identifier, six.string_types):
+ optimizer = tf.keras.optimizers.get(identifier)
+ return OptimizerWrapper(optimizer)
+ else:
+ raise ValueError(
+ "Could not interpret optimizer identifier:", identifier
+ )
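For reference, a minimal sketch of how the new OptimizerWrapper is driven; the toy embeddings and placeholder loss below are hypothetical, and inside AmpliGraph the training loop calls minimize() itself:

import tensorflow as tf
from ampligraph.latent_features.optimizers import get as get_optimizer

opt = get_optimizer("adam")   # accepts a string, a tf.keras optimizer instance, or an OptimizerWrapper
ent_emb = tf.Variable(tf.random.normal((100, 10)))   # toy entity embeddings
rel_emb = tf.Variable(tf.random.normal((20, 10)))    # toy relation embeddings

with tf.GradientTape() as tape:
    # placeholder loss; a real run would score a batch of triples here
    loss = tf.reduce_sum(ent_emb ** 2) + tf.reduce_sum(rel_emb ** 2)
opt.minimize(loss, ent_emb, rel_emb, tape)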
diff --git a/ampligraph/latent_features/pool_functions.py b/ampligraph/latent_features/pool_functions.py
deleted file mode 100755
index f9746992..00000000
--- a/ampligraph/latent_features/pool_functions.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
-#
-# This file is Licensed under the Apache License, Version 2.0.
-# A copy of the Licence is available in LICENCE, or at:
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-import tensorflow as tf
-import logging
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
-
-def sum_pooling(embeddings):
- """Sum pooling function
- Performs pooling by summation of all embeddings along neighbour axis.
-
- Parameters
- ----------
- embeddings : Tensor, shape [B, max_rel, emb_dim]
- The embeddings of a list of subjects.
-
- Returns
- -------
- v : TensorFlow operation
- Reduced vector v
-
- """
- return tf.reduce_sum(embeddings, axis=1)
-
-
-def avg_pooling(embeddings):
- """Sum pooling function
- Performs pooling by summation of all embeddings along neighbour axis.
-
- Parameters
- ----------
- embeddings : Tensor, shape [B, max_rel, emb_dim]
- The embeddings of a list of subjects.
-
- Returns
- -------
- v : TensorFlow operation
- Reduced vector v
-
- """
- return tf.reduce_mean(embeddings, axis=1)
-
-
-def max_pooling(embeddings):
- """Sum pooling function
- Performs pooling by summation of all embeddings along neighbour axis.
-
- Parameters
- ----------
- embeddings : Tensor, shape [B, max_rel, emb_dim]
- The embeddings of a list of subjects.
-
- Returns
- -------
- v : TensorFlow operation
- Reduced vector v
-
- """
- return tf.reduce_max(embeddings, axis=1)
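The removed pooling helpers were thin wrappers over TensorFlow reductions along the neighbour axis; equivalent behaviour remains available directly (the tensor below is a toy example):

import tensorflow as tf

embeddings = tf.random.normal((2, 3, 4))        # [B, max_rel, emb_dim]
summed = tf.reduce_sum(embeddings, axis=1)      # former sum_pooling
averaged = tf.reduce_mean(embeddings, axis=1)   # former avg_pooling
maxed = tf.reduce_max(embeddings, axis=1)       # former max_pooling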
diff --git a/ampligraph/latent_features/regularizers.py b/ampligraph/latent_features/regularizers.py
index 38637574..0f151914 100644
--- a/ampligraph/latent_features/regularizers.py
+++ b/ampligraph/latent_features/regularizers.py
@@ -1,233 +1,68 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-import tensorflow as tf
-import numpy as np
-import abc
-import logging
-
-REGULARIZER_REGISTRY = {}
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
+from functools import partial
-def register_regularizer(name, external_params=None, class_params=None):
- if external_params is None:
- external_params = []
- if class_params is None:
- class_params = {}
-
- def insert_in_registry(class_handle):
- REGULARIZER_REGISTRY[name] = class_handle
- class_handle.name = name
- REGULARIZER_REGISTRY[name].external_params = external_params
- REGULARIZER_REGISTRY[name].class_params = class_params
- return class_handle
+import tensorflow as tf
- return insert_in_registry
+def LP_regularizer(trainable_param, regularizer_parameters={}):
+ """Norm :math:`L^{p}` regularizer.
-# defalut lambda to be used in L1, L2 and L3 regularizer
-DEFAULT_LAMBDA = 1e-5
+ It is passed to the model as the ``entity_relation_regularizer`` argument of the
+ :meth:`~ampligraph.latent_features.models.ScoringBasedEmbeddingModel.compile` method.
-# default regularization - L2
-DEFAULT_NORM = 2
+ Parameters
+ ----------
+ trainable_param: tf.Variable
+ Trainable parameters of the model that need to be regularized.
+ regularizer_parameters: dict
+ Parameters of the regularizer:
+ - **p**: (int) - p for the LP regularizer. For example, when :math:`p=2` (default), it uses the L2 regularizer.
+ - **lambda** : (float) - Regularizer weight (default: 0.00001).
+ Returns
+ -------
+ regularizer_value: tf.Tensor
+ Weighted regularization loss computed on ``trainable_param`` (i.e., ``lambda * sum(|w|^p)``).
-class Regularizer(abc.ABC):
- """Abstract class for Regularizer.
"""
-
- name = ""
- external_params = []
- class_params = {}
-
- def __init__(self, hyperparam_dict, verbose=False):
- """Initialize the regularizer.
-
- Parameters
- ----------
- hyperparam_dict : dict
- dictionary of hyperparams
- (Keys are described in the hyperparameters section)
- """
- self._regularizer_parameters = {}
-
- # perform check to see if all the required external hyperparams are passed
- try:
- self._init_hyperparams(hyperparam_dict)
- if verbose:
- logger.info('\n------ Regularizer -----')
- logger.info('Name : {}'.format(self.name))
- for key, value in self._regularizer_parameters.items():
- logger.info('{} : {}'.format(key, value))
-
- except KeyError as e:
- msg = 'Some of the hyperparams for regularizer were not passed.\n{}'.format(e)
- logger.error(msg)
- raise Exception(msg)
-
- def get_state(self, param_name):
- """Get the state value.
-
- Parameters
- ----------
- param_name : string
- name of the state for which one wants to query the value
- Returns
- -------
- param_value:
- the value of the corresponding state
- """
- try:
- param_value = REGULARIZER_REGISTRY[self.name].class_params.get(param_name)
- return param_value
- except KeyError as e:
- msg = 'Invalid Key.\n{}'.format(e)
- logger.error(msg)
- raise Exception(msg)
-
- def _init_hyperparams(self, hyperparam_dict):
- """Initializes the hyperparameters needed by the algorithm.
-
- Parameters
- ----------
- hyperparam_dict : dictionary
- Consists of key value pairs. The regularizer will check the keys to get the corresponding params
- """
- logger.error('This function is a placeholder in an abstract class')
- raise NotImplementedError("This function is a placeholder in an abstract class")
-
- def _apply(self, trainable_params):
- """Apply the regularization function. Every inherited class must implement this function.
-
- (All the TF code must go in this function.)
-
- Parameters
- ----------
- trainable_params : list, shape [n]
- List of trainable params that should be reqularized
-
- Returns
- -------
- loss : tf.Tensor
- Regularization Loss
- """
- logger.error('This function is a placeholder in an abstract class')
- raise NotImplementedError("This function is a placeholder in an abstract class")
-
- def apply(self, trainable_params):
- """Interface to external world. This function performs input checks, input pre-processing, and
- and applies the loss function.
-
- Parameters
- ----------
- trainable_params : list, shape [n]
- List of trainable params that should be reqularized
-
- Returns
- -------
- loss : tf.Tensor
- Regularization Loss
- """
- loss = self._apply(trainable_params)
- return loss
+ return regularizer_parameters.get("lambda", 0.00001) * tf.reduce_sum(
+ tf.pow(tf.abs(trainable_param), regularizer_parameters.get("p", 2))
+ )
-@register_regularizer("LP", ['p', 'lambda'])
-class LPRegularizer(Regularizer):
- r"""Performs LP regularization
+def get(identifier, hyperparams={}):
+ """Get the regularizer specified by the identifier.
- .. math::
-
- \mathcal{L}(Reg) = \sum_{i=1}^{n} \lambda_i * \mid w_i \mid_p
-
- where n is the number of model parameters, :math:`p \in{1,2,3}` is the p-norm and
- :math:`\lambda` is the regularization weight.
-
- For example, if :math:`p=1` the function will perform L1 regularization.
- L2 regularization is obtained with :math:`p=2`.
-
- The nuclear 3-norm proposed in the ComplEx-N3 paper :cite:`lacroix2018canonical` can be obtained with
- ``regularizer_params={'p': 3}``.
+ Parameters
+ ----------
+ identifier: str, instance of `tf.keras.regularizers.Regularizer`, or callable
+ Name of the regularizer to use (with default parameters), an instance of
+ `tf.keras.regularizers.Regularizer`, or a callable function.
+ Returns
+ -------
+ regularizer: tf.keras.regularizers.Regularizer or callable
+ Regularizer instance (or callable) as accepted by `tf.keras.regularizers`.
"""
-
- def __init__(self, regularizer_params=None, verbose=False):
- """Initializes the hyperparameters needed by the algorithm.
-
- Parameters
- ----------
- regularizer_params : dictionary
- Consists of key-value pairs. The regularizer will check the keys to get the corresponding params:
-
- - **'lambda'**: (float). Weight of regularization loss for each parameter (default: 1e-5)
- - **'p'**: (int): norm (default: 2)
-
- Example: ``regularizer_params={'lambda': 1e-5, 'p': 1}``
-
- """
- if regularizer_params is None:
- regularizer_params = {'lambda': DEFAULT_LAMBDA, 'p': DEFAULT_NORM}
- super().__init__(regularizer_params, verbose)
-
- def _init_hyperparams(self, hyperparam_dict):
- """Initializes the hyperparameters needed by the algorithm.
-
- Parameters
- ----------
- hyperparam_dict : dictionary
- Consists of key value pairs. The regularizer will check the keys to get the corresponding params:
-
- 'lambda': list or float
- weight for regularizer loss for each parameter(default: 1e-5).
- If list, size must be equal to no. of parameters.
-
- 'p': int
- Norm of the regularizer (``1`` for L1 regularizer, ``2`` for L2 and so on.) (default:2)
-
- """
- self._regularizer_parameters['lambda'] = hyperparam_dict.get('lambda', DEFAULT_LAMBDA)
- self._regularizer_parameters['p'] = hyperparam_dict.get('p', DEFAULT_NORM)
- if not isinstance(self._regularizer_parameters['p'], (int, np.integer)):
- msg = 'Invalid value for regularizer parameter p:{}. Supported type int, np.int32 or np.int64'.format(
- self._regularizer_parameters['p'])
- logger.error(msg)
- raise Exception(msg)
-
- def _apply(self, trainable_params):
- """Apply the regularizer to the params.
-
- Parameters
- ----------
- trainable_params : list, shape [n]
- List of trainable params that should be reqularized.
-
- Returns
- -------
- loss : tf.Tensor
- Regularization Loss
-
- """
- if np.isscalar(self._regularizer_parameters['lambda']):
- self._regularizer_parameters['lambda'] = [self._regularizer_parameters['lambda']] * len(trainable_params)
- elif isinstance(self._regularizer_parameters['lambda'], list) and len(
- self._regularizer_parameters['lambda']) == len(trainable_params):
- pass
- else:
- logger.error('Regularizer weight must be a scalar or a list with length equal to number of params passes')
- raise ValueError(
- "Regularizer weight must be a scalar or a list with length equal to number of params passes")
-
- loss_reg = 0
- for i in range(len(trainable_params)):
- loss_reg += (self._regularizer_parameters['lambda'][i] * tf.reduce_sum(
- tf.pow(tf.abs(trainable_params[i]), self._regularizer_parameters['p'])))
-
- return loss_reg
+ if isinstance(identifier, str) and identifier == "l3":
+ hyperparams["p"] = 3
+ identifier = partial(
+ LP_regularizer, regularizer_parameters=hyperparams
+ )
+ identifier = tf.keras.regularizers.get(identifier)
+ identifier.__name__ = "LP"
+ elif isinstance(identifier, str) and identifier == "LP":
+ identifier = partial(
+ LP_regularizer, regularizer_parameters=hyperparams
+ )
+ identifier = tf.keras.regularizers.get(identifier)
+ identifier.__name__ = "LP"
+ return identifier
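A minimal sketch of the new regularizer helper: get() returns a callable that computes the weighted LP penalty, which the docstring above says is meant to be passed to compile() as entity_relation_regularizer; the toy weight matrix below is illustrative:

import tensorflow as tf
from ampligraph.latent_features.regularizers import get as get_regularizer

reg = get_regularizer('LP', {'p': 3, 'lambda': 1e-5})   # LP penalty with p=3
weights = tf.Variable(tf.random.normal((100, 10)))      # toy embedding matrix
penalty = reg(weights)                                  # lambda * sum(|w|^p)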
diff --git a/ampligraph/utils/__init__.py b/ampligraph/utils/__init__.py
index 626229a3..f87fc960 100644
--- a/ampligraph/utils/__init__.py
+++ b/ampligraph/utils/__init__.py
@@ -1,16 +1,29 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
-"""This module contains utility functions for neural knowledge graph embedding models.
+"""This module contains utility functions for neural knowledge graph
+ embedding models.
"""
-from .model_utils import save_model, restore_model, create_tensorboard_visualizations, \
- write_metadata_tsv, dataframe_to_triples
+from .model_utils import (
+ create_tensorboard_visualizations,
+ dataframe_to_triples,
+ preprocess_focusE_weights,
+ restore_model,
+ save_model,
+ write_metadata_tsv,
+)
-__all__ = ['save_model', 'restore_model', 'create_tensorboard_visualizations',
- 'write_metadata_tsv', 'dataframe_to_triples']
+__all__ = [
+ "save_model",
+ "restore_model",
+ "create_tensorboard_visualizations",
+ "write_metadata_tsv",
+ "dataframe_to_triples",
+ "preprocess_focusE_weights",
+]
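A sketch of a save/restore round trip with the updated utilities; `model` is assumed to be an already-fitted embedding model, and the `.ampkl` suffix matches what restore_model looks up by default:

from ampligraph.utils import save_model, restore_model

# `model` is assumed to be a fitted ScoringBasedEmbeddingModel (see the model example earlier in this diff).
save_model(model, model_name_path='my_kge_model.ampkl')   # writes a TF SavedModel directory
restored = restore_model('my_kge_model.ampkl')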
diff --git a/ampligraph/utils/model_utils.py b/ampligraph/utils/model_utils.py
index b39f5b53..42aeec01 100644
--- a/ampligraph/utils/model_utils.py
+++ b/ampligraph/utils/model_utils.py
@@ -1,21 +1,21 @@
-# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
+import glob
+import logging
import os
import pickle
-import importlib
+import shutil
from time import gmtime, strftime
-import glob
-import logging
-import tensorflow as tf
-from tensorflow.contrib.tensorboard.plugins import projector
import numpy as np
import pandas as pd
+import tensorflow as tf
+from tensorboard.plugins import projector
"""This module contains utility functions for neural knowledge graph embedding models.
"""
@@ -29,139 +29,117 @@
def save_model(model, model_name_path=None, protocol=pickle.HIGHEST_PROTOCOL):
"""Save a trained model to disk.
- Examples
- --------
- >>> import numpy as np
- >>> from ampligraph.latent_features import ComplEx
- >>> from ampligraph.utils import save_model
- >>> model = ComplEx(batches_count=2, seed=555, epochs=20, k=10)
- >>> X = np.array([['a', 'y', 'b'],
- >>> ['b', 'y', 'a'],
- >>> ['a', 'y', 'c'],
- >>> ['c', 'y', 'a'],
- >>> ['a', 'y', 'd'],
- >>> ['c', 'y', 'd'],
- >>> ['b', 'y', 'c'],
- >>> ['f', 'y', 'e']])
- >>> model.fit(X)
- >>> y_pred_before = model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
- >>> example_name = 'helloworld.pkl'
- >>> save_model(model, model_name_path = example_name)
- >>> print(y_pred_before)
- [-0.29721245, 0.07865551]
-
- Parameters
- ----------
- model: EmbeddingModel
- A trained neural knowledge graph embedding model,
- the model must be an instance of TransE,
- DistMult, ComplEx, or HolE.
- model_name_path: string
- The name of the model to be saved.
- If not specified, a default name model
- with current datetime is named
- and saved to the working directory
-
- """
-
- logger.debug('Saving model {}.'.format(model.__class__.__name__))
-
- obj = {
- 'class_name': model.__class__.__name__,
- 'hyperparams': model.all_params,
- 'is_fitted': model.is_fitted,
- 'ent_to_idx': model.ent_to_idx,
- 'rel_to_idx': model.rel_to_idx,
- 'is_calibrated': model.is_calibrated
- }
-
- model.get_embedding_model_params(obj)
+ Example
+ -------
+ >>> import numpy as np
+ >>> from ampligraph.latent_features import ComplEx
+ >>> from ampligraph.utils import save_model
+ >>> model = ComplEx(batches_count=2, seed=555, epochs=20, k=10)
+ >>> X = np.array([['a', 'y', 'b'],
+ >>> ['b', 'y', 'a'],
+ >>> ['a', 'y', 'c'],
+ >>> ['c', 'y', 'a'],
+ >>> ['a', 'y', 'd'],
+ >>> ['c', 'y', 'd'],
+ >>> ['b', 'y', 'c'],
+ >>> ['f', 'y', 'e']])
+ >>> model.fit(X)
+ >>> y_pred_before = model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
+ >>> example_name = 'helloworld.pkl'
+ >>> save_model(model, model_name_path=example_name)
+ >>> print(y_pred_before)
+ [-0.29721245, 0.07865551]
- logger.debug('Saving hyperparams:{}\n\tis_fitted: \
- {}'.format(model.all_params, model.is_fitted))
+ Parameters
+ ----------
+ model: EmbeddingModel
+ A trained neural knowledge graph embedding model.
+ The model must be an instance of TransE, DistMult, ComplEx, or HolE.
+ model_name_path: str
+ The name of the model to be saved.
+ If not specified, a default name with current datetime is selected and the model is saved
+ to the working directory.
+ """
+ model.data_shape = tf.Variable(
+ model.data_shape, trainable=False
+ ) # Redefine the attribute for saving it
if model_name_path is None:
- model_name_path = DEFAULT_MODEL_NAMES.format(strftime("%Y_%m_%d-%H_%M_%S", gmtime()))
-
- with open(model_name_path, 'wb') as fw:
- pickle.dump(obj, fw, protocol=protocol)
- # dump model tf
+ model_name_path = "{0}".format(strftime("%Y_%m_%d-%H_%M_%S", gmtime()))
+ if os.path.exists(model_name_path):
+ print(
+ "The path {} already exists. This save operation will overwrite the model \
+ at the specified path.".format(
+ model_name_path
+ )
+ )
+ shutil.rmtree(model_name_path)
+ if model.is_backward:
+ model = model.model
+ tf.keras.models.save_model(model, model_name_path)
+ model.save_metadata(filedir=model_name_path)
def restore_model(model_name_path=None):
- """Restore a saved model from disk.
-
- See also :meth:`save_model`.
-
- Examples
- --------
- >>> from ampligraph.utils import restore_model
- >>> import numpy as np
- >>> example_name = 'helloworld.pkl'
- >>> restored_model = restore_model(model_name_path = example_name)
- >>> y_pred_after = restored_model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
- >>> print(y_pred_after)
- [-0.29721245, 0.07865551]
-
- Parameters
- ----------
- model_name_path: string
- The name of saved model to be restored. If not specified,
- the library will try to find the default model in the working directory.
-
- Returns
- -------
- model: EmbeddingModel
- the neural knowledge graph embedding model restored from disk.
+ """Restore a trained model from disk.
+ Parameters
+ ----------
+ model_name_path : str
+ Name of the path to the model.
+
+ Returns
+ -------
+ model : EmbeddingModel
+ The neural knowledge graph embedding model restored from disk.
"""
+ from ampligraph.compat.models import BACK_COMPAT_MODELS
+ from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ from ampligraph.latent_features.layers.encoding import EmbeddingLookupLayer
+ from ampligraph.latent_features.loss_functions import LOSS_REGISTRY
+ from ampligraph.latent_features.optimizers import OptimizerWrapper
+
if model_name_path is None:
- logger.warning("There is no model name specified. \
+ logger.warning(
+ "There is no model name specified. \
We will try to lookup \
- the latest default saved model...")
- default_models = glob.glob("*.model.pkl")
+ the latest default saved model..."
+ )
+ default_models = glob.glob("*.ampkl")
if len(default_models) == 0:
- raise Exception("No default model found. Please specify \
- model_name_path...")
- else:
- model_name_path = default_models[len(default_models) - 1]
- logger.info("Will will load the model: {0} in your \
- current dir...".format(model_name_path))
-
- model = None
- logger.info('Will load model {}.'.format(model_name_path))
+ raise Exception(
+ "No default model found. Please specify \
+ model_name_path..."
+ )
try:
- with open(model_name_path, 'rb') as fr:
- restored_obj = pickle.load(fr)
-
- logger.debug('Restoring model ...')
- module = importlib.import_module("ampligraph.latent_features")
- class_ = getattr(module, restored_obj['class_name'])
- model = class_(**restored_obj['hyperparams'])
- model.is_fitted = restored_obj['is_fitted']
- model.ent_to_idx = restored_obj['ent_to_idx']
- model.rel_to_idx = restored_obj['rel_to_idx']
-
- try:
- model.is_calibrated = restored_obj['is_calibrated']
- except KeyError:
- model.is_calibrated = False
-
- model.restore_model_params(restored_obj)
+ custom_objects = {
+ "ScoringBasedEmbeddingModel": ScoringBasedEmbeddingModel,
+ "OptimizerWrapper": OptimizerWrapper,
+ "embedding_lookup_layer": EmbeddingLookupLayer,
+ }
+ custom_objects.update(LOSS_REGISTRY)
+
+ model = tf.keras.models.load_model(
+ model_name_path, custom_objects=custom_objects
+ )
+ model.load_metadata(filedir=model_name_path)
+ if model.is_backward:
+ model = BACK_COMPAT_MODELS.get(model.scoring_type)(model)
except pickle.UnpicklingError as e:
- msg = 'Error unpickling model {} : {}.'.format(model_name_path, e)
+ msg = "Error loading model {} : {}.".format(model_name_path, e)
logger.debug(msg)
raise Exception(msg)
except (IOError, FileNotFoundError):
- msg = 'No model found: {}.'.format(model_name_path)
+ msg = "No model found: {}.".format(model_name_path)
logger.debug(msg)
raise FileNotFoundError(msg)
-
return model
-def create_tensorboard_visualizations(model, loc, labels=None, write_metadata=True, export_tsv_embeddings=True):
+def create_tensorboard_visualizations(
+ model,
+ loc,
+ entities_subset="all",
+ labels=None,
+ write_metadata=True,
+ export_tsv_embeddings=True,
+):
"""Export embeddings to Tensorboard.
This function exports embeddings to disk in a format used by
@@ -170,12 +148,12 @@ def create_tensorboard_visualizations(model, loc, labels=None, write_metadata=Tr
The function exports:
* A number of checkpoint and graph embedding files in the provided location that will allow
- you to visualize embeddings using Tensorboard. This is generally for use with a
+ the visualization of the embeddings using Tensorboard. This is generally for use with a
`local Tensorboard instance `_.
- * a tab-separated file of embeddings ``embeddings_projector.tsv``. This is generally used to
+ * A tab-separated file of embeddings named `embeddings_projector.tsv`. This is generally used to
visualize embeddings by uploading to `TensorBoard Embedding Projector `_.
- * embeddings metadata (i.e. the embeddings labels from the original knowledge graph), saved to ``metadata.tsv``.
- Such file can be used in TensorBoard or uploaded to TensorBoard Embedding Projector.
+ * Embeddings metadata (i.e., the embedding labels from the original knowledge graph) saved to a file named
+ ``metadata.tsv``. Such a file can be used in TensorBoard or uploaded to the TensorBoard Embedding Projector.
The content of ``loc`` will look like: ::
@@ -189,35 +167,47 @@ def create_tensorboard_visualizations(model, loc, labels=None, write_metadata=Tr
└── projector_config.pbtxt
.. Note ::
- A TensorBoard guide is available at `this address `_.
+ A TensorBoard guide is available in the official TensorFlow documentation.
.. Note ::
- Uploading ``embeddings_projector.tsv`` and ``metadata.tsv`` to
+ Uploading `embeddings_projector.tsv` and `metadata.tsv` to
`TensorBoard Embedding Projector `_ will give a result
similar to the picture below:
.. image:: ../img/embeddings_projector.png
- Examples
- --------
- >>> import numpy as np
- >>> from ampligraph.latent_features import TransE
+ Example
+ -------
+ >>> # create a model and compile it using a user-defined optimizer and a user-defined configuration of an existing loss
+ >>> from ampligraph.latent_features import ScoringBasedEmbeddingModel
+ >>> from ampligraph.latent_features.loss_functions import SelfAdversarialLoss
+ >>> import tensorflow as tf
+ >>> optim = tf.optimizers.Adam(learning_rate=0.01)
+ >>> loss = SelfAdversarialLoss({'margin': 0.1, 'alpha': 5, 'reduction': 'sum'})
+ >>> model = ScoringBasedEmbeddingModel(eta=5,
+ >>> k=300,
+ >>> scoring_type='ComplEx',
+ >>> seed=0)
+ >>> model.compile(optimizer=optim, loss=loss)
+ >>> model.fit('./fb15k-237/train.txt',
+ >>> batch_size=10000,
+ >>> epochs=5)
+ Epoch 1/5
+ 29/29 [==============================] - 2s 67ms/step - loss: 13101.9443
+ Epoch 2/5
+ 29/29 [==============================] - 1s 20ms/step - loss: 11907.5771
+ Epoch 3/5
+ 29/29 [==============================] - 1s 21ms/step - loss: 10890.3447
+ Epoch 4/5
+ 29/29 [==============================] - 1s 20ms/step - loss: 9520.3994
+ Epoch 5/5
+ 29/29 [==============================] - 1s 20ms/step - loss: 8314.7529
>>> from ampligraph.utils import create_tensorboard_visualizations
- >>>
- >>> X = np.array([['a', 'y', 'b'],
- >>> ['b', 'y', 'a'],
- >>> ['a', 'y', 'c'],
- >>> ['c', 'y', 'a'],
- >>> ['a', 'y', 'd'],
- >>> ['c', 'y', 'd'],
- >>> ['b', 'y', 'c'],
- >>> ['f', 'y', 'e']])
- >>>
- >>> model = TransE(batches_count=1, seed=555, epochs=20, k=10, loss='pairwise',
- >>> loss_params={'margin':5})
- >>> model.fit(X)
- >>>
- >>> create_tensorboard_visualizations(model, 'tensorboard_files')
+ >>> create_tensorboard_visualizations(model,
+ >>> entities_subset='all',
+ >>> loc='./full_embeddings_vis')
+ >>> # On terminal run: tensorboard --logdir='./full_embeddings_vis' --port=8891
+ >>> # Open the browser and go to the following URL: http://127.0.0.1:8891/#projector
Parameters
@@ -225,93 +215,109 @@ def create_tensorboard_visualizations(model, loc, labels=None, write_metadata=Tr
model: EmbeddingModel
A trained neural knowledge graph embedding model, the model must be an instance of TransE,
DistMult, ComplEx, or HolE.
- loc: string
+ loc: str
Directory where the files are written.
+ entities_subset: list or str
+ List of entities whose embeddings have to be visualized, or `'all'` (default) to export all of them.
labels: pd.DataFrame
Label(s) for each embedding point in the Tensorboard visualization.
- Default behaviour is to use the embeddings labels included in the model.
- export_tsv_embeddings: bool (Default: True
- If True, will generate a tab-separated file of embeddings at the given path. This is generally used to
- visualize embeddings by uploading to `TensorBoard Embedding Projector `_.
- write_metadata: bool (Default: True)
- If True will write a file named 'metadata.tsv' in the same directory as path.
+ Default behaviour is to use the embedding labels included in the model.
+ export_tsv_embeddings: bool
+ If `True` (default), will generate a tab-separated file of embeddings at the given path.
+ This is generally used to visualize embeddings by uploading to
+ `TensorBoard Embedding Projector `_.
+ write_metadata: bool
+ If `True` (default), will write a file named `'metadata.tsv'` in the same directory as ``loc``.
"""
# Create loc if it doesn't exist
if not os.path.exists(loc):
- logger.debug('Creating Tensorboard visualization directory: %s' % loc)
+ logger.debug("Creating Tensorboard visualization directory: %s" % loc)
os.mkdir(loc)
- if not model.is_fitted:
- raise ValueError('Cannot write embeddings if model is not fitted.')
+ if not model.is_fit():
+ raise ValueError("Cannot write embeddings if model is not fitted.")
- # If no label data supplied, use model ent_to_idx keys as labels
- if labels is None:
+ if entities_subset != "all":
+ assert isinstance(
+ entities_subset, list
+ ), "Please pass a list of entities of entities_subset!"
- logger.info('Using model entity dictionary to create Tensorboard metadata.tsv')
- labels = list(model.ent_to_idx.keys())
+ if entities_subset == "all":
+ entities_index = np.arange(model.get_count("e"))
+
+ entities_label = list(
+ model.get_indexes(entities_index, type_of="e", order="ind2raw")
+ )
+ else:
+ entities_index = model.get_indexes(
+ entities_subset, type_of="e", order="raw2ind"
+ )
+ entities_label = entities_subset
+
+ if labels is not None:
+ # Check that the number of supplied labels equals the number of embeddings retrieved
+ if len(labels) != len(entities_label):
+ raise ValueError(
+ "Label data rows must equal number of embeddings."
+ )
else:
- if len(labels) != len(model.ent_to_idx):
- raise ValueError('Label data rows must equal number of embeddings.')
+ # If no label data is supplied, use the entity labels from the model
+ labels = entities_label
if write_metadata:
- logger.debug('Writing metadata.tsv to: %s' % loc)
+ logger.debug("Writing metadata.tsv to: %s" % loc)
write_metadata_tsv(loc, labels)
+ embeddings = model.get_embeddings(entities_label)
+
if export_tsv_embeddings:
tsv_filename = "embeddings_projector.tsv"
- logger.info('Writing embeddings tsv to: %s' % os.path.join(loc, tsv_filename))
- np.savetxt(os.path.join(loc, tsv_filename), model.trained_model_params[0], delimiter='\t')
-
- checkpoint_path = os.path.join(loc, 'graph_embedding.ckpt')
+ logger.info(
+ "Writing embeddings tsv to: %s" % os.path.join(loc, tsv_filename)
+ )
+ np.savetxt(os.path.join(loc, tsv_filename), embeddings, delimiter="\t")
- # Create embeddings Variable
- embedding_var = tf.Variable(model.trained_model_params[0], name='graph_embedding')
+ # Create a checkpoint with the embeddings only
+ embeddings = tf.Variable(embeddings, name="graph_embeddings")
+ checkpoint = tf.train.Checkpoint(KGE_embeddings=embeddings)
+ checkpoint.save(os.path.join(loc, "graph_embeddings.ckpt"))
- with tf.Session() as sess:
- saver = tf.train.Saver([embedding_var])
-
- sess.run(embedding_var.initializer)
-
- saver.save(sess, checkpoint_path)
-
- config = projector.ProjectorConfig()
-
- # One can add multiple embeddings.
- embedding = config.embeddings.add()
- embedding.tensor_name = embedding_var.name
-
- # Link this tensor to its metadata file (e.g. labels).
- embedding.metadata_path = 'metadata.tsv'
-
- # Saves a config file that TensorBoard will read during startup.
- projector.visualize_embeddings(tf.summary.FileWriter(loc), config)
+ # create a config to display the embeddings in the checkpoint
+ config = projector.ProjectorConfig()
+ embedding = config.embeddings.add()
+ embedding.tensor_name = "KGE_embeddings/.ATTRIBUTES/VARIABLE_VALUE"
+ embedding.metadata_path = "metadata.tsv"
+ projector.visualize_embeddings(loc, config)
def write_metadata_tsv(loc, data):
- """Write Tensorboard metadata.tsv file.
+ """Write Tensorboard `"metadata.tsv"` file.
Parameters
----------
- loc: string
+ loc: str
Directory where the file is written.
- data: list of strings, or pd.DataFrame
+ data: list of strings or pd.DataFrame
Label(s) for each embedding point in the Tensorboard visualization.
- If data is a list of strings then no header will be written. If it is a pandas DataFrame with multiple
- columns then headers will be written.
+ If ``data`` is a list of strings then no header will be written. If it is a `pandas DataFrame` with multiple
+ columns, then the headers will be written.
"""
# Write metadata.tsv
- metadata_path = os.path.join(loc, 'metadata.tsv')
+ metadata_path = os.path.join(loc, "metadata.tsv")
if isinstance(data, list):
- with open(metadata_path, 'w+', encoding='utf8') as metadata_file:
+ with open(metadata_path, "w+", encoding="utf8") as metadata_file:
for row in data:
- metadata_file.write('%s\n' % row)
+ metadata_file.write("%s\n" % row)
elif isinstance(data, pd.DataFrame):
- data.to_csv(metadata_path, sep='\t', index=False)
+ data.to_csv(metadata_path, sep="\t", index=False)
+
+ else:
+ raise ValueError("Labels must be passed as a list or a dataframe")
def dataframe_to_triples(X, schema):
@@ -319,12 +325,12 @@ def dataframe_to_triples(X, schema):
Parameters
----------
- X: pandas DataFrame with headers
- schema: List of (subject, relation_name, object) tuples
- where subject and object are in the headers of the data frame
+ X: pd.DataFrame with headers
+ schema: list of tuples
+ List of (subject, relation_name, object) tuples where subject and object are in the headers of the data frame.
- Examples
- --------
+ Example
+ -------
>>> import pandas as pd
>>> import numpy as np
>>> from ampligraph.utils.model_utils import dataframe_to_triples
@@ -340,7 +346,62 @@ def dataframe_to_triples(X, schema):
request_headers = set(np.delete(np.array(schema), 1, 1).flatten())
diff = request_headers.difference(set(X.columns))
if len(diff) > 0:
- raise Exception("Subject/Object {} are not in data frame headers".format(diff))
+ raise Exception(
+ "Subject/Object {} are not in data frame headers".format(diff)
+ )
for s, p, o in schema:
triples.extend([[si, p, oi] for si, oi in zip(X[s], X[o])])
return np.array(triples)
+
+
+def preprocess_focusE_weights(data, weights, normalize=True):
+ """Preprocessing of focusE weights.
+
+ Remove `NaNs` from the weights, average multiple weight columns into a single one, and normalize the
+ result into the [0, 1] range when ``normalize=True``.
+
+ Parameters
+ ----------
+ data: array-like, shape (n, m)
+ Array of triples (and possibly weights). If ``weights`` is `None`, ``data`` contains both triples
+ and weights (:math:`m>3`); if ``weights`` is passed, ``data`` contains only triples (:math:`m=3`).
+ weights: array-like
+ If not `None`, ``weights`` has shape (n, m-3), with :math:`m>3`.
+ normalize : bool
+ Specify whether to normalize the weights into the [0,1] range (default: `True`).
+
+ Returns
+ -------
+ processed_weights: np.array, shape (n, 1)
+ An array of preprocessed weights, averaged into a single column if more than one weight
+ column was given.
+ """
+ if weights.ndim == 1:
+ weights = weights.reshape(-1, 1)
+ logger.debug("focusE normalizing weights")
+ unique_relations = np.unique(data[:, 1])
+ for reln in unique_relations:
+ for col_idx in range(weights.shape[1]):
+ # here nans signify unknown numeric values
+ suma = np.sum(pd.isna(weights[data[:, 1] == reln, col_idx]))
+ if suma != weights[data[:, 1] == reln, col_idx].shape[0]:
+ min_val = np.nanmin(
+ weights[data[:, 1] == reln, col_idx].astype(np.float32)
+ )
+ max_val = np.nanmax(
+ weights[data[:, 1] == reln, col_idx].astype(np.float32)
+ )
+ if min_val == max_val:
+ weights[data[:, 1] == reln, col_idx] = 1.0
+ continue
+ # Normalization of the weights
+ if normalize:
+ val = (
+ weights[data[:, 1] == reln, col_idx].astype(float)
+ - min_val
+ ) / (max_val - min_val)
+ weights[data[:, 1] == reln, col_idx] = val
+ else:
+ pass # all the weights are nans
+ weights = np.mean(weights, axis=1).reshape(-1, 1)
+ return weights
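A toy sketch of preprocess_focusE_weights: weights are normalized per predicate into the [0, 1] range and averaged into a single column, with NaN marking unknown values (the triples and weights below are illustrative):

import numpy as np
from ampligraph.utils import preprocess_focusE_weights

triples = np.array([['a', 'y', 'b'],
                    ['b', 'y', 'a'],
                    ['a', 'z', 'c'],
                    ['c', 'z', 'a']])
raw_weights = np.array([5.34, -1.75, 0.33, np.nan])
processed = preprocess_focusE_weights(triples, raw_weights)   # shape (4, 1), values in [0, 1]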
diff --git a/ampligraph/utils/profiling.py b/ampligraph/utils/profiling.py
new file mode 100644
index 00000000..79680966
--- /dev/null
+++ b/ampligraph/utils/profiling.py
@@ -0,0 +1,98 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+import tracemalloc
+from functools import wraps
+from time import time
+
+
+def get_memory_size():
+ """Get memory size.
+
+ Returns
+ -------
+ total: int
+ Total size, in bytes, of the memory currently traced by tracemalloc.
+ """
+ snapshot = tracemalloc.take_snapshot()
+ stats = snapshot.statistics("lineno", cumulative=True)
+ total = sum(stat.size for stat in stats)
+ return total
+
+
+def get_human_readable_size(size_in_bytes):
+ """Convert size from bytes to human readable units.
+
+ Parameters
+ ----------
+ size_in_bytes: int
+ Original size given in bytes
+
+ Returns
+ -------
+ readable_size: tuple
+ Tuple of new size and unit, size in units GB/MB/KB/Bytes according
+ to thresholds.
+ """
+ if size_in_bytes >= 1024 * 1024 * 1024:
+ return float(size_in_bytes / (1024 * 1024 * 1024)), "GB"
+ if size_in_bytes >= 1024 * 1024:
+ return float(size_in_bytes / (1024 * 1024)), "MB"
+ if size_in_bytes >= 1024:
+ return float(size_in_bytes / 1024), "KB" # return in KB
+ return float(size_in_bytes), "Bytes"
+
+
+def timing_and_memory(f):
+ """Decorator to register time and memory used by a function f.
+
+ Parameters
+ ----------
+ f: function
+ Function for which the time and memory will be measured.
+
+ It logs the time and the memory in the dictionary passed through the `'log'`
+ keyword argument, if provided. Time is logged in seconds, memory in bytes.
+ An example dictionary entry looks like:
+ {'SPLIT': {'time': 1.62, 'memory-bytes': 789.097}},
+ where the key is the uppercased name of the measured function (or the value
+ of `'log_name'`, if given).
+
+ Requires
+ --------
+ passing **kwargs in function parameters
+ """
+
+ @wraps(f)
+ def wrapper(*args, **kwargs):
+ tracemalloc.start()
+ mem_before = get_memory_size()
+ start = time()
+ result = f(*args, **kwargs)
+ end = time()
+ mem_after = get_memory_size()
+ mem_diff = mem_after - mem_before
+ print(
+ "{}: memory before: {:.5}{}, after: {:.5}{},\
+ consumed: {:.5}{}; exec time: {:.5}s".format(
+ f.__name__,
+ *get_human_readable_size(mem_before),
+ *get_human_readable_size(mem_after),
+ *get_human_readable_size(mem_diff),
+ end - start
+ )
+ )
+
+ if "log" in kwargs:
+ name = kwargs.get("log_name", f.__name__.upper())
+ kwargs["log"][name] = {
+ "time": end - start,
+ "memory-bytes": mem_diff,
+ }
+ return result
+
+ return wrapper
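
A quick usage sketch of the helpers above (not part of the diff). The workload function `build_index` and its log dictionary are hypothetical; the import path matches the file added in this diff.

import time

from ampligraph.utils.profiling import get_human_readable_size, timing_and_memory


@timing_and_memory
def build_index(n, **kwargs):
    # Hypothetical workload: allocate a list and pause briefly.
    data = list(range(n))
    time.sleep(0.1)
    return len(data)


measurements = {}
build_index(100_000, log=measurements)
# measurements now holds something like
# {'BUILD_INDEX': {'time': 0.11, 'memory-bytes': 3600000}}

print(get_human_readable_size(3_500_000))  # roughly (3.34, 'MB')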
diff --git a/ampligraph/utils/tags.py b/ampligraph/utils/tags.py
new file mode 100644
index 00000000..73dbe36c
--- /dev/null
+++ b/ampligraph/utils/tags.py
@@ -0,0 +1,87 @@
+# Copyright 2019-2023 The AmpliGraph Authors. All Rights Reserved.
+#
+# This file is Licensed under the Apache License, Version 2.0.
+# A copy of the Licence is available in LICENCE, or at:
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+import warnings
+
+
+class experimentalWarning(Warning):
+ """Warning that is triggered when the
+ experimental function is run.
+ """
+
+ def __init__(self, message):
+ self.message = message
+
+ def __str__(self):
+ return repr(self.message)
+
+
+def experimental(func):
+ """
+ Decorator - a function that accepts another function
+ and marks it as experimental, meaning it may change in
+ future releases, or its execution is not guaranteed.
+
+ Example:
+
+ >>>@experimental
+ >>>def a_function():
+ >>> "Demonstration function"
+ >>> return "demonstration"
+
+ >>>a_function()
+ experimentalWarning: 'Experimental! Function: a_function is experimental.
+ Use at your own risk.'
+ warnings.warn(experimentalWarning(msg))
+ demonstration
+
+ To disable experimentalWarning set this in the module:
+ >>>warnings.filterwarnings("ignore", category=experimentalWarning)
+
+ """
+
+ def mark_experimental():
+ msg = f"Experimental! Function: {func.__name__} is experimental. Use \
+ at your own risk."
+
+ warnings.warn(experimentalWarning(msg))
+
+ return func()
+
+ return mark_experimental
+
+
+def deprecated(*args, **kwargs):
+ """
+ Decorator - a function that accepts another function
+ and marks it as deprecated, meaning it may be discontinued in
+ future releases, and is provided only for backward compatibility purposes.
+
+ ---------------
+ Example:
+
+ >>>@deprecated(instead="module2.another_function")
+ >>>def a_function():
+ >>> "Demonstration function"
+ >>> return "demonstration"
+
+ >>>a_function()
+ DeprecationWarning: Deprecated! Function: a_function is deprecated.
+ Instead use module2.another_function.
+ warnings.warn(DeprecationWarning(msg))
+ demonstration
+ """
+
+ def mark_deprecated(func):
+ msg = f"Deprecated! Function: {func.__name__} is deprecated. \
+ Instead use {kwargs['instead']}."
+
+ warnings.warn(DeprecationWarning(msg))
+
+ return func
+
+ return mark_deprecated
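
A brief usage sketch of the two decorators above (not part of the diff). `preview_feature`, `old_sum` and `new_api.sum_list` are invented names; the import path matches the file added in this diff. Note that, as written, `@experimental` wraps the function in a zero-argument callable, so it suits functions without parameters, and `@deprecated` emits its warning once, at decoration time.

import warnings

from ampligraph.utils.tags import deprecated, experimental, experimentalWarning


@experimental
def preview_feature():
    return "preview"


@deprecated(instead="new_api.sum_list")  # DeprecationWarning emitted here, at decoration time
def old_sum(values):
    return sum(values)


print(preview_feature())   # emits experimentalWarning, then prints "preview"
print(old_sum([1, 2, 3]))  # prints 6; behaves exactly like the undecorated function

# Experimental warnings can be silenced with:
warnings.filterwarnings("ignore", category=experimentalWarning)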
diff --git a/docs/_static/_sphinx_javascript_frameworks_compat.js b/docs/_static/_sphinx_javascript_frameworks_compat.js
new file mode 100644
index 00000000..8549469d
--- /dev/null
+++ b/docs/_static/_sphinx_javascript_frameworks_compat.js
@@ -0,0 +1,134 @@
+/*
+ * _sphinx_javascript_frameworks_compat.js
+ * ~~~~~~~~~~
+ *
+ * Compatibility shim for jQuery and underscore.js.
+ *
+ * WILL BE REMOVED IN Sphinx 6.0
+ * xref RemovedInSphinx60Warning
+ *
+ */
+
+/**
+ * select a different prefix for underscore
+ */
+$u = _.noConflict();
+
+
+/**
+ * small helper function to urldecode strings
+ *
+ * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/decodeURIComponent#Decoding_query_parameters_from_a_URL
+ */
+jQuery.urldecode = function(x) {
+ if (!x) {
+ return x
+ }
+ return decodeURIComponent(x.replace(/\+/g, ' '));
+};
+
+/**
+ * small helper function to urlencode strings
+ */
+jQuery.urlencode = encodeURIComponent;
+
+/**
+ * This function returns the parsed url parameters of the
+ * current request. Multiple values per key are supported,
+ * it will always return arrays of strings for the value parts.
+ */
+jQuery.getQueryParameters = function(s) {
+ if (typeof s === 'undefined')
+ s = document.location.search;
+ var parts = s.substr(s.indexOf('?') + 1).split('&');
+ var result = {};
+ for (var i = 0; i < parts.length; i++) {
+ var tmp = parts[i].split('=', 2);
+ var key = jQuery.urldecode(tmp[0]);
+ var value = jQuery.urldecode(tmp[1]);
+ if (key in result)
+ result[key].push(value);
+ else
+ result[key] = [value];
+ }
+ return result;
+};
+
+/**
+ * highlight a given string on a jquery object by wrapping it in
+ * span elements with the given class name.
+ */
+jQuery.fn.highlightText = function(text, className) {
+ function highlight(node, addItems) {
+ if (node.nodeType === 3) {
+ var val = node.nodeValue;
+ var pos = val.toLowerCase().indexOf(text);
+ if (pos >= 0 &&
+ !jQuery(node.parentNode).hasClass(className) &&
+ !jQuery(node.parentNode).hasClass("nohighlight")) {
+ var span;
+ var isInSVG = jQuery(node).closest("body, svg, foreignObject").is("svg");
+ if (isInSVG) {
+ span = document.createElementNS("http://www.w3.org/2000/svg", "tspan");
+ } else {
+ span = document.createElement("span");
+ span.className = className;
+ }
+ span.appendChild(document.createTextNode(val.substr(pos, text.length)));
+ node.parentNode.insertBefore(span, node.parentNode.insertBefore(
+ document.createTextNode(val.substr(pos + text.length)),
+ node.nextSibling));
+ node.nodeValue = val.substr(0, pos);
+ if (isInSVG) {
+ var rect = document.createElementNS("http://www.w3.org/2000/svg", "rect");
+ var bbox = node.parentElement.getBBox();
+ rect.x.baseVal.value = bbox.x;
+ rect.y.baseVal.value = bbox.y;
+ rect.width.baseVal.value = bbox.width;
+ rect.height.baseVal.value = bbox.height;
+ rect.setAttribute('class', className);
+ addItems.push({
+ "parent": node.parentNode,
+ "target": rect});
+ }
+ }
+ }
+ else if (!jQuery(node).is("button, select, textarea")) {
+ jQuery.each(node.childNodes, function() {
+ highlight(this, addItems);
+ });
+ }
+ }
+ var addItems = [];
+ var result = this.each(function() {
+ highlight(this, addItems);
+ });
+ for (var i = 0; i < addItems.length; ++i) {
+ jQuery(addItems[i].parent).before(addItems[i].target);
+ }
+ return result;
+};
+
+/*
+ * backward compatibility for jQuery.browser
+ * This will be supported until firefox bug is fixed.
+ */
+if (!jQuery.browser) {
+ jQuery.uaMatch = function(ua) {
+ ua = ua.toLowerCase();
+
+ var match = /(chrome)[ \/]([\w.]+)/.exec(ua) ||
+ /(webkit)[ \/]([\w.]+)/.exec(ua) ||
+ /(opera)(?:.*version|)[ \/]([\w.]+)/.exec(ua) ||
+ /(msie) ([\w.]+)/.exec(ua) ||
+ ua.indexOf("compatible") < 0 && /(mozilla)(?:.*? rv:([\w.]+)|)/.exec(ua) ||
+ [];
+
+ return {
+ browser: match[ 1 ] || "",
+ version: match[ 2 ] || "0"
+ };
+ };
+ jQuery.browser = {};
+ jQuery.browser[jQuery.uaMatch(navigator.userAgent).browser] = true;
+}
diff --git a/docs/_static/ampligraph_logo_transparent_white.png b/docs/_static/ampligraph_logo_transparent_white.png
new file mode 100644
index 00000000..f9f0daf2
Binary files /dev/null and b/docs/_static/ampligraph_logo_transparent_white.png differ
diff --git a/docs/_static/basic.css b/docs/_static/basic.css
new file mode 100644
index 00000000..7d5974c3
--- /dev/null
+++ b/docs/_static/basic.css
@@ -0,0 +1,928 @@
+/*
+ * basic.css
+ * ~~~~~~~~~
+ *
+ * Sphinx stylesheet -- basic theme.
+ *
+ * :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS.
+ * :license: BSD, see LICENSE for details.
+ *
+ */
+
+/* -- main layout ----------------------------------------------------------- */
+
+div.clearer {
+ clear: both;
+}
+
+div.section::after {
+ display: block;
+ content: '';
+ clear: left;
+}
+
+/* -- relbar ---------------------------------------------------------------- */
+
+div.related {
+ width: 100%;
+ font-size: 90%;
+}
+
+div.related h3 {
+ display: none;
+}
+
+div.related ul {
+ margin: 0;
+ padding: 0 0 0 10px;
+ list-style: none;
+}
+
+div.related li {
+ display: inline;
+}
+
+div.related li.right {
+ float: right;
+ margin-right: 5px;
+}
+
+/* -- sidebar --------------------------------------------------------------- */
+
+div.sphinxsidebarwrapper {
+ padding: 10px 5px 0 10px;
+}
+
+div.sphinxsidebar {
+ float: left;
+ width: 230px;
+ margin-left: -100%;
+ font-size: 90%;
+ word-wrap: break-word;
+ overflow-wrap : break-word;
+}
+
+div.sphinxsidebar ul {
+ list-style: none;
+}
+
+div.sphinxsidebar ul ul,
+div.sphinxsidebar ul.want-points {
+ margin-left: 20px;
+ list-style: square;
+}
+
+div.sphinxsidebar ul ul {
+ margin-top: 0;
+ margin-bottom: 0;
+}
+
+div.sphinxsidebar form {
+ margin-top: 10px;
+}
+
+div.sphinxsidebar input {
+ border: 1px solid #98dbcc;
+ font-family: sans-serif;
+ font-size: 1em;
+}
+
+div.sphinxsidebar #searchbox form.search {
+ overflow: hidden;
+}
+
+div.sphinxsidebar #searchbox input[type="text"] {
+ float: left;
+ width: 80%;
+ padding: 0.25em;
+ box-sizing: border-box;
+}
+
+div.sphinxsidebar #searchbox input[type="submit"] {
+ float: left;
+ width: 20%;
+ border-left: none;
+ padding: 0.25em;
+ box-sizing: border-box;
+}
+
+
+img {
+ border: 0;
+ max-width: 100%;
+}
+
+/* -- search page ----------------------------------------------------------- */
+
+ul.search {
+ margin: 10px 0 0 20px;
+ padding: 0;
+}
+
+ul.search li {
+ padding: 5px 0 5px 20px;
+ background-image: url(file.png);
+ background-repeat: no-repeat;
+ background-position: 0 7px;
+}
+
+ul.search li a {
+ font-weight: bold;
+}
+
+ul.search li p.context {
+ color: #888;
+ margin: 2px 0 0 30px;
+ text-align: left;
+}
+
+ul.keywordmatches li.goodmatch a {
+ font-weight: bold;
+}
+
+/* -- index page ------------------------------------------------------------ */
+
+table.contentstable {
+ width: 90%;
+ margin-left: auto;
+ margin-right: auto;
+}
+
+table.contentstable p.biglink {
+ line-height: 150%;
+}
+
+a.biglink {
+ font-size: 1.3em;
+}
+
+span.linkdescr {
+ font-style: italic;
+ padding-top: 5px;
+ font-size: 90%;
+}
+
+/* -- general index --------------------------------------------------------- */
+
+table.indextable {
+ width: 100%;
+}
+
+table.indextable td {
+ text-align: left;
+ vertical-align: top;
+}
+
+table.indextable ul {
+ margin-top: 0;
+ margin-bottom: 0;
+ list-style-type: none;
+}
+
+table.indextable > tbody > tr > td > ul {
+ padding-left: 0em;
+}
+
+table.indextable tr.pcap {
+ height: 10px;
+}
+
+table.indextable tr.cap {
+ margin-top: 10px;
+ background-color: #f2f2f2;
+}
+
+img.toggler {
+ margin-right: 3px;
+ margin-top: 3px;
+ cursor: pointer;
+}
+
+div.modindex-jumpbox {
+ border-top: 1px solid #ddd;
+ border-bottom: 1px solid #ddd;
+ margin: 1em 0 1em 0;
+ padding: 0.4em;
+}
+
+div.genindex-jumpbox {
+ border-top: 1px solid #ddd;
+ border-bottom: 1px solid #ddd;
+ margin: 1em 0 1em 0;
+ padding: 0.4em;
+}
+
+/* -- domain module index --------------------------------------------------- */
+
+table.modindextable td {
+ padding: 2px;
+ border-collapse: collapse;
+}
+
+/* -- general body styles --------------------------------------------------- */
+
+div.body {
+ min-width: 360px;
+ max-width: 800px;
+}
+
+div.body p, div.body dd, div.body li, div.body blockquote {
+ -moz-hyphens: auto;
+ -ms-hyphens: auto;
+ -webkit-hyphens: auto;
+ hyphens: auto;
+}
+
+a.headerlink {
+ visibility: hidden;
+}
+a.brackets:before,
+span.brackets > a:before{
+ content: "[";
+}
+
+a.brackets:after,
+span.brackets > a:after {
+ content: "]";
+}
+
+
+h1:hover > a.headerlink,
+h2:hover > a.headerlink,
+h3:hover > a.headerlink,
+h4:hover > a.headerlink,
+h5:hover > a.headerlink,
+h6:hover > a.headerlink,
+dt:hover > a.headerlink,
+caption:hover > a.headerlink,
+p.caption:hover > a.headerlink,
+div.code-block-caption:hover > a.headerlink {
+ visibility: visible;
+}
+
+div.body p.caption {
+ text-align: inherit;
+}
+
+div.body td {
+ text-align: left;
+}
+
+.first {
+ margin-top: 0 !important;
+}
+
+p.rubric {
+ margin-top: 30px;
+ font-weight: bold;
+}
+
+img.align-left, figure.align-left, .figure.align-left, object.align-left {
+ clear: left;
+ float: left;
+ margin-right: 1em;
+}
+
+img.align-right, figure.align-right, .figure.align-right, object.align-right {
+ clear: right;
+ float: right;
+ margin-left: 1em;
+}
+
+img.align-center, figure.align-center, .figure.align-center, object.align-center {
+ display: block;
+ margin-left: auto;
+ margin-right: auto;
+}
+
+img.align-default, figure.align-default, .figure.align-default {
+ display: block;
+ margin-left: auto;
+ margin-right: auto;
+}
+
+.align-left {
+ text-align: left;
+}
+
+.align-center {
+ text-align: center;
+}
+
+.align-default {
+ text-align: center;
+}
+
+.align-right {
+ text-align: right;
+}
+
+/* -- sidebars -------------------------------------------------------------- */
+
+div.sidebar,
+aside.sidebar {
+ margin: 0 0 0.5em 1em;
+ border: 1px solid #ddb;
+ padding: 7px;
+ background-color: #ffe;
+ width: 40%;
+ float: right;
+ clear: right;
+ overflow-x: auto;
+}
+
+p.sidebar-title {
+ font-weight: bold;
+}
+div.admonition, div.topic, blockquote {
+ clear: left;
+}
+
+/* -- topics ---------------------------------------------------------------- */
+div.topic {
+ border: 1px solid #ccc;
+ padding: 7px;
+ margin: 10px 0 10px 0;
+}
+
+p.topic-title {
+ font-size: 1.1em;
+ font-weight: bold;
+ margin-top: 10px;
+}
+
+/* -- admonitions ----------------------------------------------------------- */
+
+div.admonition {
+ margin-top: 10px;
+ margin-bottom: 10px;
+ padding: 7px;
+}
+
+div.admonition dt {
+ font-weight: bold;
+}
+
+p.admonition-title {
+ margin: 0px 10px 5px 0px;
+ font-weight: bold;
+}
+
+div.body p.centered {
+ text-align: center;
+ margin-top: 25px;
+}
+
+/* -- content of sidebars/topics/admonitions -------------------------------- */
+
+div.sidebar > :last-child,
+aside.sidebar > :last-child,
+div.topic > :last-child,
+div.admonition > :last-child {
+ margin-bottom: 0;
+}
+
+div.sidebar::after,
+aside.sidebar::after,
+div.topic::after,
+div.admonition::after,
+blockquote::after {
+ display: block;
+ content: '';
+ clear: both;
+}
+
+/* -- tables ---------------------------------------------------------------- */
+
+table.docutils {
+ margin-top: 10px;
+ margin-bottom: 10px;
+ border: 0;
+ border-collapse: collapse;
+}
+
+table.align-center {
+ margin-left: auto;
+ margin-right: auto;
+}
+
+table.align-default {
+ margin-left: auto;
+ margin-right: auto;
+}
+
+table caption span.caption-number {
+ font-style: italic;
+}
+
+table caption span.caption-text {
+}
+
+table.docutils td, table.docutils th {
+ padding: 1px 8px 1px 5px;
+ border-top: 0;
+ border-left: 0;
+ border-right: 0;
+ border-bottom: 1px solid #aaa;
+}
+
+th {
+ text-align: left;
+ padding-right: 5px;
+}
+
+table.citation {
+ border-left: solid 1px gray;
+ margin-left: 1px;
+}
+
+table.citation td {
+ border-bottom: none;
+}
+
+th > :first-child,
+td > :first-child {
+ margin-top: 0px;
+}
+
+th > :last-child,
+td > :last-child {
+ margin-bottom: 0px;
+}
+
+/* -- figures --------------------------------------------------------------- */
+
+div.figure, figure {
+ margin: 0.5em;
+ padding: 0.5em;
+}
+
+div.figure p.caption, figcaption {
+ padding: 0.3em;
+}
+
+div.figure p.caption span.caption-number,
+figcaption span.caption-number {
+ font-style: italic;
+}
+
+div.figure p.caption span.caption-text,
+figcaption span.caption-text {
+}
+
+/* -- field list styles ----------------------------------------------------- */
+
+table.field-list td, table.field-list th {
+ border: 0 !important;
+}
+
+.field-list ul {
+ margin: 0;
+ padding-left: 1em;
+}
+
+.field-list p {
+ margin: 0;
+}
+
+.field-name {
+ -moz-hyphens: manual;
+ -ms-hyphens: manual;
+ -webkit-hyphens: manual;
+ hyphens: manual;
+}
+
+/* -- hlist styles ---------------------------------------------------------- */
+
+table.hlist {
+ margin: 1em 0;
+}
+
+table.hlist td {
+ vertical-align: top;
+}
+
+/* -- object description styles --------------------------------------------- */
+
+.sig {
+ font-family: 'Consolas', 'Menlo', 'DejaVu Sans Mono', 'Bitstream Vera Sans Mono', monospace;
+}
+
+.sig-name, code.descname {
+ background-color: transparent;
+ font-weight: bold;
+}
+
+.sig-name {
+ font-size: 1.1em;
+}
+
+code.descname {
+ font-size: 1.2em;
+}
+
+.sig-prename, code.descclassname {
+ background-color: transparent;
+}
+
+.optional {
+ font-size: 1.3em;
+}
+
+.sig-paren {
+ font-size: larger;
+}
+
+.sig-param.n {
+ font-style: italic;
+}
+
+/* C++ specific styling */
+
+.sig-inline.c-texpr,
+.sig-inline.cpp-texpr {
+ font-family: unset;
+}
+
+.sig.c .k, .sig.c .kt,
+.sig.cpp .k, .sig.cpp .kt {
+ color: #0033B3;
+}
+
+.sig.c .m,
+.sig.cpp .m {
+ color: #1750EB;
+}
+
+.sig.c .s, .sig.c .sc,
+.sig.cpp .s, .sig.cpp .sc {
+ color: #067D17;
+}
+
+
+/* -- other body styles ----------------------------------------------------- */
+
+ol.arabic {
+ list-style: decimal;
+}
+
+ol.loweralpha {
+ list-style: lower-alpha;
+}
+
+ol.upperalpha {
+ list-style: upper-alpha;
+}
+
+ol.lowerroman {
+ list-style: lower-roman;
+}
+
+ol.upperroman {
+ list-style: upper-roman;
+}
+
+:not(li) > ol > li:first-child > :first-child,
+:not(li) > ul > li:first-child > :first-child {
+ margin-top: 0px;
+}
+
+:not(li) > ol > li:last-child > :last-child,
+:not(li) > ul > li:last-child > :last-child {
+ margin-bottom: 0px;
+}
+
+ol.simple ol p,
+ol.simple ul p,
+ul.simple ol p,
+ul.simple ul p {
+ margin-top: 0;
+}
+
+ol.simple > li:not(:first-child) > p,
+ul.simple > li:not(:first-child) > p {
+ margin-top: 0;
+}
+
+ol.simple p,
+ul.simple p {
+ margin-bottom: 0;
+}
+
+/* Docutils 0.17 and older (footnotes & citations) */
+dl.footnote > dt,
+dl.citation > dt {
+ float: left;
+ margin-right: 0.5em;
+}
+
+dl.footnote > dd,
+dl.citation > dd {
+ margin-bottom: 0em;
+}
+
+dl.footnote > dd:after,
+dl.citation > dd:after {
+ content: "";
+ clear: both;
+}
+
+/* Docutils 0.18+ (footnotes & citations) */
+aside.footnote > span,
+div.citation > span {
+ float: left;
+}
+aside.footnote > span:last-of-type,
+div.citation > span:last-of-type {
+ padding-right: 0.5em;
+}
+aside.footnote > p {
+ margin-left: 2em;
+}
+div.citation > p {
+ margin-left: 4em;
+}
+aside.footnote > p:last-of-type,
+div.citation > p:last-of-type {
+ margin-bottom: 0em;
+}
+aside.footnote > p:last-of-type:after,
+div.citation > p:last-of-type:after {
+ content: "";
+ clear: both;
+}
+
+/* Footnotes & citations ends */
+
+dl.field-list {
+ display: grid;
+ grid-template-columns: fit-content(30%) auto;
+}
+
+dl.field-list > dt {
+ font-weight: bold;
+ word-break: break-word;
+ padding-left: 0.5em;
+ padding-right: 5px;
+}
+
+dl.field-list > dt:after {
+ content: ":";
+}
+
+dl.field-list > dd {
+ padding-left: 0.5em;
+ margin-top: 0em;
+ margin-left: 0em;
+ margin-bottom: 0em;
+}
+
+dl {
+ margin-bottom: 15px;
+}
+
+dd > :first-child {
+ margin-top: 0px;
+}
+
+dd ul, dd table {
+ margin-bottom: 10px;
+}
+
+dd {
+ margin-top: 3px;
+ margin-bottom: 10px;
+ margin-left: 30px;
+}
+
+dl > dd:last-child,
+dl > dd:last-child > :last-child {
+ margin-bottom: 0;
+}
+
+dt:target, span.highlighted {
+ background-color: #fbe54e;
+}
+
+rect.highlighted {
+ fill: #fbe54e;
+}
+
+dl.glossary dt {
+ font-weight: bold;
+ font-size: 1.1em;
+}
+
+.versionmodified {
+ font-style: italic;
+}
+
+.system-message {
+ background-color: #fda;
+ padding: 5px;
+ border: 3px solid red;
+}
+
+.footnote:target {
+ background-color: #ffa;
+}
+
+.line-block {
+ display: block;
+ margin-top: 1em;
+ margin-bottom: 1em;
+}
+
+.line-block .line-block {
+ margin-top: 0;
+ margin-bottom: 0;
+ margin-left: 1.5em;
+}
+
+.guilabel, .menuselection {
+ font-family: sans-serif;
+}
+
+.accelerator {
+ text-decoration: underline;
+}
+
+.classifier {
+ font-style: oblique;
+}
+
+.classifier:before {
+ font-style: normal;
+ margin: 0 0.5em;
+ content: ":";
+ display: inline-block;
+}
+
+abbr, acronym {
+ border-bottom: dotted 1px;
+ cursor: help;
+}
+
+/* -- code displays --------------------------------------------------------- */
+
+pre {
+ overflow: auto;
+ overflow-y: hidden; /* fixes display issues on Chrome browsers */
+}
+
+pre, div[class*="highlight-"] {
+ clear: both;
+}
+
+span.pre {
+ -moz-hyphens: none;
+ -ms-hyphens: none;
+ -webkit-hyphens: none;
+ hyphens: none;
+ white-space: nowrap;
+}
+
+div[class*="highlight-"] {
+ margin: 1em 0;
+}
+
+td.linenos pre {
+ border: 0;
+ background-color: transparent;
+ color: #aaa;
+}
+
+table.highlighttable {
+ display: block;
+}
+
+table.highlighttable tbody {
+ display: block;
+}
+
+table.highlighttable tr {
+ display: flex;
+}
+
+table.highlighttable td {
+ margin: 0;
+ padding: 0;
+}
+
+table.highlighttable td.linenos {
+ padding-right: 0.5em;
+}
+
+table.highlighttable td.code {
+ flex: 1;
+ overflow: hidden;
+}
+
+.highlight .hll {
+ display: block;
+}
+
+div.highlight pre,
+table.highlighttable pre {
+ margin: 0;
+}
+
+div.code-block-caption + div {
+ margin-top: 0;
+}
+
+div.code-block-caption {
+ margin-top: 1em;
+ padding: 2px 5px;
+ font-size: small;
+}
+
+div.code-block-caption code {
+ background-color: transparent;
+}
+
+table.highlighttable td.linenos,
+span.linenos,
+div.highlight span.gp { /* gp: Generic.Prompt */
+ user-select: none;
+ -webkit-user-select: text; /* Safari fallback only */
+ -webkit-user-select: none; /* Chrome/Safari */
+ -moz-user-select: none; /* Firefox */
+ -ms-user-select: none; /* IE10+ */
+}
+
+div.code-block-caption span.caption-number {
+ padding: 0.1em 0.3em;
+ font-style: italic;
+}
+
+div.code-block-caption span.caption-text {
+}
+
+div.literal-block-wrapper {
+ margin: 1em 0;
+}
+
+code.xref, a code {
+ background-color: transparent;
+ font-weight: bold;
+}
+
+h1 code, h2 code, h3 code, h4 code, h5 code, h6 code {
+ background-color: transparent;
+}
+
+.viewcode-link {
+ float: right;
+}
+
+.viewcode-back {
+ float: right;
+ font-family: sans-serif;
+}
+
+div.viewcode-block:target {
+ margin: -1px -10px;
+ padding: 0 10px;
+}
+
+/* -- math display ---------------------------------------------------------- */
+
+img.math {
+ vertical-align: middle;
+}
+
+div.body div.math p {
+ text-align: center;
+}
+
+span.eqno {
+ float: right;
+}
+
+span.eqno a.headerlink {
+ position: absolute;
+ z-index: 1;
+}
+
+div.math:hover a.headerlink {
+ visibility: visible;
+}
+
+/* -- printout stylesheet --------------------------------------------------- */
+
+@media print {
+ div.document,
+ div.documentwrapper,
+ div.bodywrapper {
+ margin: 0 !important;
+ width: 100%;
+ }
+
+ div.sphinxsidebar,
+ div.related,
+ div.footer,
+ #top-link {
+ display: none;
+ }
+}
\ No newline at end of file
diff --git a/docs/_static/css/badge_only.css b/docs/_static/css/badge_only.css
new file mode 100644
index 00000000..e380325b
--- /dev/null
+++ b/docs/_static/css/badge_only.css
@@ -0,0 +1 @@
+.fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-style:normal;font-weight:400;src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix) format("embedded-opentype"),url(fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) format("woff2"),url(fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) format("truetype"),url(fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#FontAwesome) format("svg")}.fa:before{font-family:FontAwesome;font-style:normal;font-weight:400;line-height:1}.fa:before,a .fa{text-decoration:inherit}.fa:before,a .fa,li .fa{display:inline-block}li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before,.icon-book:before{content:"\f02d"}.fa-caret-down:before,.icon-caret-down:before{content:"\f0d7"}.fa-caret-up:before,.icon-caret-up:before{content:"\f0d8"}.fa-caret-left:before,.icon-caret-left:before{content:"\f0d9"}.fa-caret-right:before,.icon-caret-right:before{content:"\f0da"}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60}.rst-versions .rst-current-version:after{clear:both;content:"";display:block}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}}
\ No newline at end of file
diff --git a/docs/_static/css/fonts/Roboto-Slab-Bold.woff b/docs/_static/css/fonts/Roboto-Slab-Bold.woff
new file mode 100644
index 00000000..6cb60000
Binary files /dev/null and b/docs/_static/css/fonts/Roboto-Slab-Bold.woff differ
diff --git a/docs/_static/css/fonts/Roboto-Slab-Bold.woff2 b/docs/_static/css/fonts/Roboto-Slab-Bold.woff2
new file mode 100644
index 00000000..7059e231
Binary files /dev/null and b/docs/_static/css/fonts/Roboto-Slab-Bold.woff2 differ
diff --git a/docs/_static/css/fonts/Roboto-Slab-Regular.woff b/docs/_static/css/fonts/Roboto-Slab-Regular.woff
new file mode 100644
index 00000000..f815f63f
Binary files /dev/null and b/docs/_static/css/fonts/Roboto-Slab-Regular.woff differ
diff --git a/docs/_static/css/fonts/Roboto-Slab-Regular.woff2 b/docs/_static/css/fonts/Roboto-Slab-Regular.woff2
new file mode 100644
index 00000000..f2c76e5b
Binary files /dev/null and b/docs/_static/css/fonts/Roboto-Slab-Regular.woff2 differ
diff --git a/docs/_static/css/fonts/fontawesome-webfont.eot b/docs/_static/css/fonts/fontawesome-webfont.eot
new file mode 100644
index 00000000..e9f60ca9
Binary files /dev/null and b/docs/_static/css/fonts/fontawesome-webfont.eot differ
diff --git a/docs/_static/css/fonts/fontawesome-webfont.svg b/docs/_static/css/fonts/fontawesome-webfont.svg
new file mode 100644
index 00000000..855c845e
--- /dev/null
+++ b/docs/_static/css/fonts/fontawesome-webfont.svg
@@ -0,0 +1,2671 @@
+
+
+
diff --git a/docs/_static/css/fonts/fontawesome-webfont.ttf b/docs/_static/css/fonts/fontawesome-webfont.ttf
new file mode 100644
index 00000000..35acda2f
Binary files /dev/null and b/docs/_static/css/fonts/fontawesome-webfont.ttf differ
diff --git a/docs/_static/css/fonts/fontawesome-webfont.woff b/docs/_static/css/fonts/fontawesome-webfont.woff
new file mode 100644
index 00000000..400014a4
Binary files /dev/null and b/docs/_static/css/fonts/fontawesome-webfont.woff differ
diff --git a/docs/_static/css/fonts/fontawesome-webfont.woff2 b/docs/_static/css/fonts/fontawesome-webfont.woff2
new file mode 100644
index 00000000..4d13fc60
Binary files /dev/null and b/docs/_static/css/fonts/fontawesome-webfont.woff2 differ
diff --git a/docs/_static/css/fonts/lato-bold-italic.woff b/docs/_static/css/fonts/lato-bold-italic.woff
new file mode 100644
index 00000000..88ad05b9
Binary files /dev/null and b/docs/_static/css/fonts/lato-bold-italic.woff differ
diff --git a/docs/_static/css/fonts/lato-bold-italic.woff2 b/docs/_static/css/fonts/lato-bold-italic.woff2
new file mode 100644
index 00000000..c4e3d804
Binary files /dev/null and b/docs/_static/css/fonts/lato-bold-italic.woff2 differ
diff --git a/docs/_static/css/fonts/lato-bold.woff b/docs/_static/css/fonts/lato-bold.woff
new file mode 100644
index 00000000..c6dff51f
Binary files /dev/null and b/docs/_static/css/fonts/lato-bold.woff differ
diff --git a/docs/_static/css/fonts/lato-bold.woff2 b/docs/_static/css/fonts/lato-bold.woff2
new file mode 100644
index 00000000..bb195043
Binary files /dev/null and b/docs/_static/css/fonts/lato-bold.woff2 differ
diff --git a/docs/_static/css/fonts/lato-normal-italic.woff b/docs/_static/css/fonts/lato-normal-italic.woff
new file mode 100644
index 00000000..76114bc0
Binary files /dev/null and b/docs/_static/css/fonts/lato-normal-italic.woff differ
diff --git a/docs/_static/css/fonts/lato-normal-italic.woff2 b/docs/_static/css/fonts/lato-normal-italic.woff2
new file mode 100644
index 00000000..3404f37e
Binary files /dev/null and b/docs/_static/css/fonts/lato-normal-italic.woff2 differ
diff --git a/docs/_static/css/fonts/lato-normal.woff b/docs/_static/css/fonts/lato-normal.woff
new file mode 100644
index 00000000..ae1307ff
Binary files /dev/null and b/docs/_static/css/fonts/lato-normal.woff differ
diff --git a/docs/_static/css/fonts/lato-normal.woff2 b/docs/_static/css/fonts/lato-normal.woff2
new file mode 100644
index 00000000..3bf98433
Binary files /dev/null and b/docs/_static/css/fonts/lato-normal.woff2 differ
diff --git a/docs/_static/css/theme.css b/docs/_static/css/theme.css
new file mode 100644
index 00000000..0d9ae7e1
--- /dev/null
+++ b/docs/_static/css/theme.css
@@ -0,0 +1,4 @@
+html{box-sizing:border-box}*,:after,:before{box-sizing:inherit}article,aside,details,figcaption,figure,footer,header,hgroup,nav,section{display:block}audio,canvas,video{display:inline-block;*display:inline;*zoom:1}[hidden],audio:not([controls]){display:none}*{-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box}html{font-size:100%;-webkit-text-size-adjust:100%;-ms-text-size-adjust:100%}body{margin:0}a:active,a:hover{outline:0}abbr[title]{border-bottom:1px dotted}b,strong{font-weight:700}blockquote{margin:0}dfn{font-style:italic}ins{background:#ff9;text-decoration:none}ins,mark{color:#000}mark{background:#ff0;font-style:italic;font-weight:700}.rst-content code,.rst-content tt,code,kbd,pre,samp{font-family:monospace,serif;_font-family:courier new,monospace;font-size:1em}pre{white-space:pre}q{quotes:none}q:after,q:before{content:"";content:none}small{font-size:85%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sup{top:-.5em}sub{bottom:-.25em}dl,ol,ul{margin:0;padding:0;list-style:none;list-style-image:none}li{list-style:none}dd{margin:0}img{border:0;-ms-interpolation-mode:bicubic;vertical-align:middle;max-width:100%}svg:not(:root){overflow:hidden}figure,form{margin:0}label{cursor:pointer}button,input,select,textarea{font-size:100%;margin:0;vertical-align:baseline;*vertical-align:middle}button,input{line-height:normal}button,input[type=button],input[type=reset],input[type=submit]{cursor:pointer;-webkit-appearance:button;*overflow:visible}button[disabled],input[disabled]{cursor:default}input[type=search]{-webkit-appearance:textfield;-moz-box-sizing:content-box;-webkit-box-sizing:content-box;box-sizing:content-box}textarea{resize:vertical}table{border-collapse:collapse;border-spacing:0}td{vertical-align:top}.chromeframe{margin:.2em 0;background:#ccc;color:#000;padding:.2em 0}.ir{display:block;border:0;text-indent:-999em;overflow:hidden;background-color:transparent;background-repeat:no-repeat;text-align:left;direction:ltr;*line-height:0}.ir br{display:none}.hidden{display:none!important;visibility:hidden}.visuallyhidden{border:0;clip:rect(0 0 0 0);height:1px;margin:-1px;overflow:hidden;padding:0;position:absolute;width:1px}.visuallyhidden.focusable:active,.visuallyhidden.focusable:focus{clip:auto;height:auto;margin:0;overflow:visible;position:static;width:auto}.invisible{visibility:hidden}.relative{position:relative}big,small{font-size:100%}@media print{body,html,section{background:none!important}*{box-shadow:none!important;text-shadow:none!important;filter:none!important;-ms-filter:none!important}a,a:visited{text-decoration:underline}.ir a:after,a[href^="#"]:after,a[href^="javascript:"]:after{content:""}blockquote,pre{page-break-inside:avoid}thead{display:table-header-group}img,tr{page-break-inside:avoid}img{max-width:100%!important}@page{margin:.5cm}.rst-content .toctree-wrapper>p.caption,h2,h3,p{orphans:3;widows:3}.rst-content .toctree-wrapper>p.caption,h2,h3{page-break-after:avoid}}.btn,.fa:before,.icon:before,.rst-content .admonition,.rst-content .admonition-title:before,.rst-content .admonition-todo,.rst-content .attention,.rst-content .caution,.rst-content .code-block-caption .headerlink:before,.rst-content .danger,.rst-content .eqno .headerlink:before,.rst-content .error,.rst-content .hint,.rst-content .important,.rst-content .note,.rst-content .seealso,.rst-content .tip,.rst-content .warning,.rst-content code.download span:first-child:before,.rst-content dl dt .headerlink:before,.rst-content h1 .headerlink:before,.rst-content h2 
.headerlink:before,.rst-content h3 .headerlink:before,.rst-content h4 .headerlink:before,.rst-content h5 .headerlink:before,.rst-content h6 .headerlink:before,.rst-content p.caption .headerlink:before,.rst-content p .headerlink:before,.rst-content table>caption .headerlink:before,.rst-content tt.download span:first-child:before,.wy-alert,.wy-dropdown .caret:before,.wy-inline-validate.wy-inline-validate-danger .wy-input-context:before,.wy-inline-validate.wy-inline-validate-info .wy-input-context:before,.wy-inline-validate.wy-inline-validate-success .wy-input-context:before,.wy-inline-validate.wy-inline-validate-warning .wy-input-context:before,.wy-menu-vertical li.current>a,.wy-menu-vertical li.current>a button.toctree-expand:before,.wy-menu-vertical li.on a,.wy-menu-vertical li.on a button.toctree-expand:before,.wy-menu-vertical li button.toctree-expand:before,.wy-nav-top a,.wy-side-nav-search .wy-dropdown>a,.wy-side-nav-search>a,input[type=color],input[type=date],input[type=datetime-local],input[type=datetime],input[type=email],input[type=month],input[type=number],input[type=password],input[type=search],input[type=tel],input[type=text],input[type=time],input[type=url],input[type=week],select,textarea{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}/*!
+ * Font Awesome 4.7.0 by @davegandy - http://fontawesome.io - @fontawesome
+ * License - http://fontawesome.io/license (Font: SIL OFL 1.1, CSS: MIT License)
+ */@font-face{font-family:FontAwesome;src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713);src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix&v=4.7.0) format("embedded-opentype"),url(fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) format("woff2"),url(fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) format("truetype"),url(fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#fontawesomeregular) format("svg");font-weight:400;font-style:normal}.fa,.icon,.rst-content .admonition-title,.rst-content .code-block-caption .headerlink,.rst-content .eqno .headerlink,.rst-content code.download span:first-child,.rst-content dl dt .headerlink,.rst-content h1 .headerlink,.rst-content h2 .headerlink,.rst-content h3 .headerlink,.rst-content h4 .headerlink,.rst-content h5 .headerlink,.rst-content h6 .headerlink,.rst-content p.caption .headerlink,.rst-content p .headerlink,.rst-content table>caption .headerlink,.rst-content tt.download span:first-child,.wy-menu-vertical li.current>a button.toctree-expand,.wy-menu-vertical li.on a button.toctree-expand,.wy-menu-vertical li button.toctree-expand{display:inline-block;font:normal normal normal 14px/1 FontAwesome;font-size:inherit;text-rendering:auto;-webkit-font-smoothing:antialiased;-moz-osx-font-smoothing:grayscale}.fa-lg{font-size:1.33333em;line-height:.75em;vertical-align:-15%}.fa-2x{font-size:2em}.fa-3x{font-size:3em}.fa-4x{font-size:4em}.fa-5x{font-size:5em}.fa-fw{width:1.28571em;text-align:center}.fa-ul{padding-left:0;margin-left:2.14286em;list-style-type:none}.fa-ul>li{position:relative}.fa-li{position:absolute;left:-2.14286em;width:2.14286em;top:.14286em;text-align:center}.fa-li.fa-lg{left:-1.85714em}.fa-border{padding:.2em .25em .15em;border:.08em solid #eee;border-radius:.1em}.fa-pull-left{float:left}.fa-pull-right{float:right}.fa-pull-left.icon,.fa.fa-pull-left,.rst-content .code-block-caption .fa-pull-left.headerlink,.rst-content .eqno .fa-pull-left.headerlink,.rst-content .fa-pull-left.admonition-title,.rst-content code.download span.fa-pull-left:first-child,.rst-content dl dt .fa-pull-left.headerlink,.rst-content h1 .fa-pull-left.headerlink,.rst-content h2 .fa-pull-left.headerlink,.rst-content h3 .fa-pull-left.headerlink,.rst-content h4 .fa-pull-left.headerlink,.rst-content h5 .fa-pull-left.headerlink,.rst-content h6 .fa-pull-left.headerlink,.rst-content p .fa-pull-left.headerlink,.rst-content table>caption .fa-pull-left.headerlink,.rst-content tt.download span.fa-pull-left:first-child,.wy-menu-vertical li.current>a button.fa-pull-left.toctree-expand,.wy-menu-vertical li.on a button.fa-pull-left.toctree-expand,.wy-menu-vertical li button.fa-pull-left.toctree-expand{margin-right:.3em}.fa-pull-right.icon,.fa.fa-pull-right,.rst-content .code-block-caption .fa-pull-right.headerlink,.rst-content .eqno .fa-pull-right.headerlink,.rst-content .fa-pull-right.admonition-title,.rst-content code.download span.fa-pull-right:first-child,.rst-content dl dt .fa-pull-right.headerlink,.rst-content h1 .fa-pull-right.headerlink,.rst-content h2 .fa-pull-right.headerlink,.rst-content h3 .fa-pull-right.headerlink,.rst-content h4 .fa-pull-right.headerlink,.rst-content h5 .fa-pull-right.headerlink,.rst-content h6 .fa-pull-right.headerlink,.rst-content p .fa-pull-right.headerlink,.rst-content table>caption .fa-pull-right.headerlink,.rst-content tt.download 
span.fa-pull-right:first-child,.wy-menu-vertical li.current>a button.fa-pull-right.toctree-expand,.wy-menu-vertical li.on a button.fa-pull-right.toctree-expand,.wy-menu-vertical li button.fa-pull-right.toctree-expand{margin-left:.3em}.pull-right{float:right}.pull-left{float:left}.fa.pull-left,.pull-left.icon,.rst-content .code-block-caption .pull-left.headerlink,.rst-content .eqno .pull-left.headerlink,.rst-content .pull-left.admonition-title,.rst-content code.download span.pull-left:first-child,.rst-content dl dt .pull-left.headerlink,.rst-content h1 .pull-left.headerlink,.rst-content h2 .pull-left.headerlink,.rst-content h3 .pull-left.headerlink,.rst-content h4 .pull-left.headerlink,.rst-content h5 .pull-left.headerlink,.rst-content h6 .pull-left.headerlink,.rst-content p .pull-left.headerlink,.rst-content table>caption .pull-left.headerlink,.rst-content tt.download span.pull-left:first-child,.wy-menu-vertical li.current>a button.pull-left.toctree-expand,.wy-menu-vertical li.on a button.pull-left.toctree-expand,.wy-menu-vertical li button.pull-left.toctree-expand{margin-right:.3em}.fa.pull-right,.pull-right.icon,.rst-content .code-block-caption .pull-right.headerlink,.rst-content .eqno .pull-right.headerlink,.rst-content .pull-right.admonition-title,.rst-content code.download span.pull-right:first-child,.rst-content dl dt .pull-right.headerlink,.rst-content h1 .pull-right.headerlink,.rst-content h2 .pull-right.headerlink,.rst-content h3 .pull-right.headerlink,.rst-content h4 .pull-right.headerlink,.rst-content h5 .pull-right.headerlink,.rst-content h6 .pull-right.headerlink,.rst-content p .pull-right.headerlink,.rst-content table>caption .pull-right.headerlink,.rst-content tt.download span.pull-right:first-child,.wy-menu-vertical li.current>a button.pull-right.toctree-expand,.wy-menu-vertical li.on a button.pull-right.toctree-expand,.wy-menu-vertical li button.pull-right.toctree-expand{margin-left:.3em}.fa-spin{-webkit-animation:fa-spin 2s linear infinite;animation:fa-spin 2s linear infinite}.fa-pulse{-webkit-animation:fa-spin 1s steps(8) infinite;animation:fa-spin 1s steps(8) infinite}@-webkit-keyframes fa-spin{0%{-webkit-transform:rotate(0deg);transform:rotate(0deg)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes fa-spin{0%{-webkit-transform:rotate(0deg);transform:rotate(0deg)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.fa-rotate-90{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=1)";-webkit-transform:rotate(90deg);-ms-transform:rotate(90deg);transform:rotate(90deg)}.fa-rotate-180{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=2)";-webkit-transform:rotate(180deg);-ms-transform:rotate(180deg);transform:rotate(180deg)}.fa-rotate-270{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=3)";-webkit-transform:rotate(270deg);-ms-transform:rotate(270deg);transform:rotate(270deg)}.fa-flip-horizontal{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=0, mirror=1)";-webkit-transform:scaleX(-1);-ms-transform:scaleX(-1);transform:scaleX(-1)}.fa-flip-vertical{-ms-filter:"progid:DXImageTransform.Microsoft.BasicImage(rotation=2, mirror=1)";-webkit-transform:scaleY(-1);-ms-transform:scaleY(-1);transform:scaleY(-1)}:root .fa-flip-horizontal,:root .fa-flip-vertical,:root .fa-rotate-90,:root .fa-rotate-180,:root 
.fa-rotate-270{filter:none}.fa-stack{position:relative;display:inline-block;width:2em;height:2em;line-height:2em;vertical-align:middle}.fa-stack-1x,.fa-stack-2x{position:absolute;left:0;width:100%;text-align:center}.fa-stack-1x{line-height:inherit}.fa-stack-2x{font-size:2em}.fa-inverse{color:#fff}.fa-glass:before{content:""}.fa-music:before{content:""}.fa-search:before,.icon-search:before{content:""}.fa-envelope-o:before{content:""}.fa-heart:before{content:""}.fa-star:before{content:""}.fa-star-o:before{content:""}.fa-user:before{content:""}.fa-film:before{content:""}.fa-th-large:before{content:""}.fa-th:before{content:""}.fa-th-list:before{content:""}.fa-check:before{content:""}.fa-close:before,.fa-remove:before,.fa-times:before{content:""}.fa-search-plus:before{content:""}.fa-search-minus:before{content:""}.fa-power-off:before{content:""}.fa-signal:before{content:""}.fa-cog:before,.fa-gear:before{content:""}.fa-trash-o:before{content:""}.fa-home:before,.icon-home:before{content:""}.fa-file-o:before{content:""}.fa-clock-o:before{content:""}.fa-road:before{content:""}.fa-download:before,.rst-content code.download span:first-child:before,.rst-content tt.download span:first-child:before{content:""}.fa-arrow-circle-o-down:before{content:""}.fa-arrow-circle-o-up:before{content:""}.fa-inbox:before{content:""}.fa-play-circle-o:before{content:""}.fa-repeat:before,.fa-rotate-right:before{content:""}.fa-refresh:before{content:""}.fa-list-alt:before{content:""}.fa-lock:before{content:""}.fa-flag:before{content:""}.fa-headphones:before{content:""}.fa-volume-off:before{content:""}.fa-volume-down:before{content:""}.fa-volume-up:before{content:""}.fa-qrcode:before{content:""}.fa-barcode:before{content:""}.fa-tag:before{content:""}.fa-tags:before{content:""}.fa-book:before,.icon-book:before{content:""}.fa-bookmark:before{content:""}.fa-print:before{content:""}.fa-camera:before{content:""}.fa-font:before{content:""}.fa-bold:before{content:""}.fa-italic:before{content:""}.fa-text-height:before{content:""}.fa-text-width:before{content:""}.fa-align-left:before{content:""}.fa-align-center:before{content:""}.fa-align-right:before{content:""}.fa-align-justify:before{content:""}.fa-list:before{content:""}.fa-dedent:before,.fa-outdent:before{content:""}.fa-indent:before{content:""}.fa-video-camera:before{content:""}.fa-image:before,.fa-photo:before,.fa-picture-o:before{content:""}.fa-pencil:before{content:""}.fa-map-marker:before{content:""}.fa-adjust:before{content:""}.fa-tint:before{content:""}.fa-edit:before,.fa-pencil-square-o:before{content:""}.fa-share-square-o:before{content:""}.fa-check-square-o:before{content:""}.fa-arrows:before{content:""}.fa-step-backward:before{content:""}.fa-fast-backward:before{content:""}.fa-backward:before{content:""}.fa-play:before{content:""}.fa-pause:before{content:""}.fa-stop:before{content:""}.fa-forward:before{content:""}.fa-fast-forward:before{content:""}.fa-step-forward:before{content:""}.fa-eject:before{content:""}.fa-chevron-left:before{content:""}.fa-chevron-right:before{content:""}.fa-plus-circle:before{content:""}.fa-minus-circle:before{content:""}.fa-times-circle:before,.wy-inline-validate.wy-inline-validate-danger .wy-input-context:before{content:""}.fa-check-circle:before,.wy-inline-validate.wy-inline-validate-success 
.wy-input-context:before{content:""}.fa-question-circle:before{content:""}.fa-info-circle:before{content:""}.fa-crosshairs:before{content:""}.fa-times-circle-o:before{content:""}.fa-check-circle-o:before{content:""}.fa-ban:before{content:""}.fa-arrow-left:before{content:""}.fa-arrow-right:before{content:""}.fa-arrow-up:before{content:""}.fa-arrow-down:before{content:""}.fa-mail-forward:before,.fa-share:before{content:""}.fa-expand:before{content:""}.fa-compress:before{content:""}.fa-plus:before{content:""}.fa-minus:before{content:""}.fa-asterisk:before{content:""}.fa-exclamation-circle:before,.rst-content .admonition-title:before,.wy-inline-validate.wy-inline-validate-info .wy-input-context:before,.wy-inline-validate.wy-inline-validate-warning .wy-input-context:before{content:""}.fa-gift:before{content:""}.fa-leaf:before{content:""}.fa-fire:before,.icon-fire:before{content:""}.fa-eye:before{content:""}.fa-eye-slash:before{content:""}.fa-exclamation-triangle:before,.fa-warning:before{content:""}.fa-plane:before{content:""}.fa-calendar:before{content:""}.fa-random:before{content:""}.fa-comment:before{content:""}.fa-magnet:before{content:""}.fa-chevron-up:before{content:""}.fa-chevron-down:before{content:""}.fa-retweet:before{content:""}.fa-shopping-cart:before{content:""}.fa-folder:before{content:""}.fa-folder-open:before{content:""}.fa-arrows-v:before{content:""}.fa-arrows-h:before{content:""}.fa-bar-chart-o:before,.fa-bar-chart:before{content:""}.fa-twitter-square:before{content:""}.fa-facebook-square:before{content:""}.fa-camera-retro:before{content:""}.fa-key:before{content:""}.fa-cogs:before,.fa-gears:before{content:""}.fa-comments:before{content:""}.fa-thumbs-o-up:before{content:""}.fa-thumbs-o-down:before{content:""}.fa-star-half:before{content:""}.fa-heart-o:before{content:""}.fa-sign-out:before{content:""}.fa-linkedin-square:before{content:""}.fa-thumb-tack:before{content:""}.fa-external-link:before{content:""}.fa-sign-in:before{content:""}.fa-trophy:before{content:""}.fa-github-square:before{content:""}.fa-upload:before{content:""}.fa-lemon-o:before{content:""}.fa-phone:before{content:""}.fa-square-o:before{content:""}.fa-bookmark-o:before{content:""}.fa-phone-square:before{content:""}.fa-twitter:before{content:""}.fa-facebook-f:before,.fa-facebook:before{content:""}.fa-github:before,.icon-github:before{content:""}.fa-unlock:before{content:""}.fa-credit-card:before{content:""}.fa-feed:before,.fa-rss:before{content:""}.fa-hdd-o:before{content:""}.fa-bullhorn:before{content:""}.fa-bell:before{content:""}.fa-certificate:before{content:""}.fa-hand-o-right:before{content:""}.fa-hand-o-left:before{content:""}.fa-hand-o-up:before{content:""}.fa-hand-o-down:before{content:""}.fa-arrow-circle-left:before,.icon-circle-arrow-left:before{content:""}.fa-arrow-circle-right:before,.icon-circle-arrow-right:before{content:""}.fa-arrow-circle-up:before{content:""}.fa-arrow-circle-down:before{content:""}.fa-globe:before{content:""}.fa-wrench:before{content:""}.fa-tasks:before{content:""}.fa-filter:before{content:""}.fa-briefcase:before{content:""}.fa-arrows-alt:before{content:""}.fa-group:before,.fa-users:before{content:""}.fa-chain:before,.fa-link:before,.icon-link:before{content:""}.fa-cloud:before{content:""}.fa-flask:before{content:""}.fa-cut:before,.fa-scissors:before{content:""}.fa-copy:before,.fa-files-o:before{content:""}.fa-paperclip:before{content:""}.fa-floppy-o:before,.fa-save:before{content:""}.fa-square:before{content:""}.fa-bars:before,.fa-navicon:before,.fa-reorder:before{content:""}.f
a-list-ul:before{content:""}.fa-list-ol:before{content:""}.fa-strikethrough:before{content:""}.fa-underline:before{content:""}.fa-table:before{content:""}.fa-magic:before{content:""}.fa-truck:before{content:""}.fa-pinterest:before{content:""}.fa-pinterest-square:before{content:""}.fa-google-plus-square:before{content:""}.fa-google-plus:before{content:""}.fa-money:before{content:""}.fa-caret-down:before,.icon-caret-down:before,.wy-dropdown .caret:before{content:""}.fa-caret-up:before{content:""}.fa-caret-left:before{content:""}.fa-caret-right:before{content:""}.fa-columns:before{content:""}.fa-sort:before,.fa-unsorted:before{content:""}.fa-sort-desc:before,.fa-sort-down:before{content:""}.fa-sort-asc:before,.fa-sort-up:before{content:""}.fa-envelope:before{content:""}.fa-linkedin:before{content:""}.fa-rotate-left:before,.fa-undo:before{content:""}.fa-gavel:before,.fa-legal:before{content:""}.fa-dashboard:before,.fa-tachometer:before{content:""}.fa-comment-o:before{content:""}.fa-comments-o:before{content:""}.fa-bolt:before,.fa-flash:before{content:""}.fa-sitemap:before{content:""}.fa-umbrella:before{content:""}.fa-clipboard:before,.fa-paste:before{content:""}.fa-lightbulb-o:before{content:""}.fa-exchange:before{content:""}.fa-cloud-download:before{content:""}.fa-cloud-upload:before{content:""}.fa-user-md:before{content:""}.fa-stethoscope:before{content:""}.fa-suitcase:before{content:""}.fa-bell-o:before{content:""}.fa-coffee:before{content:""}.fa-cutlery:before{content:""}.fa-file-text-o:before{content:""}.fa-building-o:before{content:""}.fa-hospital-o:before{content:""}.fa-ambulance:before{content:""}.fa-medkit:before{content:""}.fa-fighter-jet:before{content:""}.fa-beer:before{content:""}.fa-h-square:before{content:""}.fa-plus-square:before{content:""}.fa-angle-double-left:before{content:""}.fa-angle-double-right:before{content:""}.fa-angle-double-up:before{content:""}.fa-angle-double-down:before{content:""}.fa-angle-left:before{content:""}.fa-angle-right:before{content:""}.fa-angle-up:before{content:""}.fa-angle-down:before{content:""}.fa-desktop:before{content:""}.fa-laptop:before{content:""}.fa-tablet:before{content:""}.fa-mobile-phone:before,.fa-mobile:before{content:""}.fa-circle-o:before{content:""}.fa-quote-left:before{content:""}.fa-quote-right:before{content:""}.fa-spinner:before{content:""}.fa-circle:before{content:""}.fa-mail-reply:before,.fa-reply:before{content:""}.fa-github-alt:before{content:""}.fa-folder-o:before{content:""}.fa-folder-open-o:before{content:""}.fa-smile-o:before{content:""}.fa-frown-o:before{content:""}.fa-meh-o:before{content:""}.fa-gamepad:before{content:""}.fa-keyboard-o:before{content:""}.fa-flag-o:before{content:""}.fa-flag-checkered:before{content:""}.fa-terminal:before{content:""}.fa-code:before{content:""}.fa-mail-reply-all:before,.fa-reply-all:before{content:""}.fa-star-half-empty:before,.fa-star-half-full:before,.fa-star-half-o:before{content:""}.fa-location-arrow:before{content:""}.fa-crop:before{content:""}.fa-code-fork:before{content:""}.fa-chain-broken:before,.fa-unlink:before{content:""}.fa-question:before{content:""}.fa-info:before{content:""}.fa-exclamation:before{content:""}.fa-superscript:before{content:""}.fa-subscript:before{content:""}.fa-eraser:before{content:""}.fa-puzzle-piece:before{content:""}.fa-microphone:before{content:""}.fa-microphone-slash:before{content:""}.fa-shield:before{content:""}.fa-calendar-o:before{content:""}.fa-fire-extinguisher:before{content:""}.fa-rocket:before{content:""}.fa-maxcdn:before{content:""}.fa-chevron
-circle-left:before{content:""}.fa-chevron-circle-right:before{content:""}.fa-chevron-circle-up:before{content:""}.fa-chevron-circle-down:before{content:""}.fa-html5:before{content:""}.fa-css3:before{content:""}.fa-anchor:before{content:""}.fa-unlock-alt:before{content:""}.fa-bullseye:before{content:""}.fa-ellipsis-h:before{content:""}.fa-ellipsis-v:before{content:""}.fa-rss-square:before{content:""}.fa-play-circle:before{content:""}.fa-ticket:before{content:""}.fa-minus-square:before{content:""}.fa-minus-square-o:before,.wy-menu-vertical li.current>a button.toctree-expand:before,.wy-menu-vertical li.on a button.toctree-expand:before{content:""}.fa-level-up:before{content:""}.fa-level-down:before{content:""}.fa-check-square:before{content:""}.fa-pencil-square:before{content:""}.fa-external-link-square:before{content:""}.fa-share-square:before{content:""}.fa-compass:before{content:""}.fa-caret-square-o-down:before,.fa-toggle-down:before{content:""}.fa-caret-square-o-up:before,.fa-toggle-up:before{content:""}.fa-caret-square-o-right:before,.fa-toggle-right:before{content:""}.fa-eur:before,.fa-euro:before{content:""}.fa-gbp:before{content:""}.fa-dollar:before,.fa-usd:before{content:""}.fa-inr:before,.fa-rupee:before{content:""}.fa-cny:before,.fa-jpy:before,.fa-rmb:before,.fa-yen:before{content:""}.fa-rouble:before,.fa-rub:before,.fa-ruble:before{content:""}.fa-krw:before,.fa-won:before{content:""}.fa-bitcoin:before,.fa-btc:before{content:""}.fa-file:before{content:""}.fa-file-text:before{content:""}.fa-sort-alpha-asc:before{content:""}.fa-sort-alpha-desc:before{content:""}.fa-sort-amount-asc:before{content:""}.fa-sort-amount-desc:before{content:""}.fa-sort-numeric-asc:before{content:""}.fa-sort-numeric-desc:before{content:""}.fa-thumbs-up:before{content:""}.fa-thumbs-down:before{content:""}.fa-youtube-square:before{content:""}.fa-youtube:before{content:""}.fa-xing:before{content:""}.fa-xing-square:before{content:""}.fa-youtube-play:before{content:""}.fa-dropbox:before{content:""}.fa-stack-overflow:before{content:""}.fa-instagram:before{content:""}.fa-flickr:before{content:""}.fa-adn:before{content:""}.fa-bitbucket:before,.icon-bitbucket:before{content:""}.fa-bitbucket-square:before{content:""}.fa-tumblr:before{content:""}.fa-tumblr-square:before{content:""}.fa-long-arrow-down:before{content:""}.fa-long-arrow-up:before{content:""}.fa-long-arrow-left:before{content:""}.fa-long-arrow-right:before{content:""}.fa-apple:before{content:""}.fa-windows:before{content:""}.fa-android:before{content:""}.fa-linux:before{content:""}.fa-dribbble:before{content:""}.fa-skype:before{content:""}.fa-foursquare:before{content:""}.fa-trello:before{content:""}.fa-female:before{content:""}.fa-male:before{content:""}.fa-gittip:before,.fa-gratipay:before{content:""}.fa-sun-o:before{content:""}.fa-moon-o:before{content:""}.fa-archive:before{content:""}.fa-bug:before{content:""}.fa-vk:before{content:""}.fa-weibo:before{content:""}.fa-renren:before{content:""}.fa-pagelines:before{content:""}.fa-stack-exchange:before{content:""}.fa-arrow-circle-o-right:before{content:""}.fa-arrow-circle-o-left:before{content:""}.fa-caret-square-o-left:before,.fa-toggle-left:before{content:""}.fa-dot-circle-o:before{content:""}.fa-wheelchair:before{content:""}.fa-vimeo-square:before{content:""}.fa-try:before,.fa-turkish-lira:before{content:""}.fa-plus-square-o:before,.wy-menu-vertical li 
button.toctree-expand:before{content:""}.fa-space-shuttle:before{content:""}.fa-slack:before{content:""}.fa-envelope-square:before{content:""}.fa-wordpress:before{content:""}.fa-openid:before{content:""}.fa-bank:before,.fa-institution:before,.fa-university:before{content:""}.fa-graduation-cap:before,.fa-mortar-board:before{content:""}.fa-yahoo:before{content:""}.fa-google:before{content:""}.fa-reddit:before{content:""}.fa-reddit-square:before{content:""}.fa-stumbleupon-circle:before{content:""}.fa-stumbleupon:before{content:""}.fa-delicious:before{content:""}.fa-digg:before{content:""}.fa-pied-piper-pp:before{content:""}.fa-pied-piper-alt:before{content:""}.fa-drupal:before{content:""}.fa-joomla:before{content:""}.fa-language:before{content:""}.fa-fax:before{content:""}.fa-building:before{content:""}.fa-child:before{content:""}.fa-paw:before{content:""}.fa-spoon:before{content:""}.fa-cube:before{content:""}.fa-cubes:before{content:""}.fa-behance:before{content:""}.fa-behance-square:before{content:""}.fa-steam:before{content:""}.fa-steam-square:before{content:""}.fa-recycle:before{content:""}.fa-automobile:before,.fa-car:before{content:""}.fa-cab:before,.fa-taxi:before{content:""}.fa-tree:before{content:""}.fa-spotify:before{content:""}.fa-deviantart:before{content:""}.fa-soundcloud:before{content:""}.fa-database:before{content:""}.fa-file-pdf-o:before{content:""}.fa-file-word-o:before{content:""}.fa-file-excel-o:before{content:""}.fa-file-powerpoint-o:before{content:""}.fa-file-image-o:before,.fa-file-photo-o:before,.fa-file-picture-o:before{content:""}.fa-file-archive-o:before,.fa-file-zip-o:before{content:""}.fa-file-audio-o:before,.fa-file-sound-o:before{content:""}.fa-file-movie-o:before,.fa-file-video-o:before{content:""}.fa-file-code-o:before{content:""}.fa-vine:before{content:""}.fa-codepen:before{content:""}.fa-jsfiddle:before{content:""}.fa-life-bouy:before,.fa-life-buoy:before,.fa-life-ring:before,.fa-life-saver:before,.fa-support:before{content:""}.fa-circle-o-notch:before{content:""}.fa-ra:before,.fa-rebel:before,.fa-resistance:before{content:""}.fa-empire:before,.fa-ge:before{content:""}.fa-git-square:before{content:""}.fa-git:before{content:""}.fa-hacker-news:before,.fa-y-combinator-square:before,.fa-yc-square:before{content:""}.fa-tencent-weibo:before{content:""}.fa-qq:before{content:""}.fa-wechat:before,.fa-weixin:before{content:""}.fa-paper-plane:before,.fa-send:before{content:""}.fa-paper-plane-o:before,.fa-send-o:before{content:""}.fa-history:before{content:""}.fa-circle-thin:before{content:""}.fa-header:before{content:""}.fa-paragraph:before{content:""}.fa-sliders:before{content:""}.fa-share-alt:before{content:""}.fa-share-alt-square:before{content:""}.fa-bomb:before{content:""}.fa-futbol-o:before,.fa-soccer-ball-o:before{content:""}.fa-tty:before{content:""}.fa-binoculars:before{content:""}.fa-plug:before{content:""}.fa-slideshare:before{content:""}.fa-twitch:before{content:""}.fa-yelp:before{content:""}.fa-newspaper-o:before{content:""}.fa-wifi:before{content:""}.fa-calculator:before{content:""}.fa-paypal:before{content:""}.fa-google-wallet:before{content:""}.fa-cc-visa:before{content:""}.fa-cc-mastercard:before{content:""}.fa-cc-discover:before{content:""}.fa-cc-amex:before{content:""}.fa-cc-paypal:before{content:""}.fa-cc-stripe:before{content:""}.fa-bell-slash:before{content:""}.fa-bell-slash-o:before{content:""}.fa-trash:before{content:""}.fa-copyright:before{content:""}.fa-at:before{content:""}.fa-eyedropper:before{content:""}.fa-paint-brush:before{content:""}.fa-
birthday-cake:before{content:""}.fa-area-chart:before{content:""}.fa-pie-chart:before{content:""}.fa-line-chart:before{content:""}.fa-lastfm:before{content:""}.fa-lastfm-square:before{content:""}.fa-toggle-off:before{content:""}.fa-toggle-on:before{content:""}.fa-bicycle:before{content:""}.fa-bus:before{content:""}.fa-ioxhost:before{content:""}.fa-angellist:before{content:""}.fa-cc:before{content:""}.fa-ils:before,.fa-shekel:before,.fa-sheqel:before{content:""}.fa-meanpath:before{content:""}.fa-buysellads:before{content:""}.fa-connectdevelop:before{content:""}.fa-dashcube:before{content:""}.fa-forumbee:before{content:""}.fa-leanpub:before{content:""}.fa-sellsy:before{content:""}.fa-shirtsinbulk:before{content:""}.fa-simplybuilt:before{content:""}.fa-skyatlas:before{content:""}.fa-cart-plus:before{content:""}.fa-cart-arrow-down:before{content:""}.fa-diamond:before{content:""}.fa-ship:before{content:""}.fa-user-secret:before{content:""}.fa-motorcycle:before{content:""}.fa-street-view:before{content:""}.fa-heartbeat:before{content:""}.fa-venus:before{content:""}.fa-mars:before{content:""}.fa-mercury:before{content:""}.fa-intersex:before,.fa-transgender:before{content:""}.fa-transgender-alt:before{content:""}.fa-venus-double:before{content:""}.fa-mars-double:before{content:""}.fa-venus-mars:before{content:""}.fa-mars-stroke:before{content:""}.fa-mars-stroke-v:before{content:""}.fa-mars-stroke-h:before{content:""}.fa-neuter:before{content:""}.fa-genderless:before{content:""}.fa-facebook-official:before{content:""}.fa-pinterest-p:before{content:""}.fa-whatsapp:before{content:""}.fa-server:before{content:""}.fa-user-plus:before{content:""}.fa-user-times:before{content:""}.fa-bed:before,.fa-hotel:before{content:""}.fa-viacoin:before{content:""}.fa-train:before{content:""}.fa-subway:before{content:""}.fa-medium:before{content:""}.fa-y-combinator:before,.fa-yc:before{content:""}.fa-optin-monster:before{content:""}.fa-opencart:before{content:""}.fa-expeditedssl:before{content:""}.fa-battery-4:before,.fa-battery-full:before,.fa-battery:before{content:""}.fa-battery-3:before,.fa-battery-three-quarters:before{content:""}.fa-battery-2:before,.fa-battery-half:before{content:""}.fa-battery-1:before,.fa-battery-quarter:before{content:""}.fa-battery-0:before,.fa-battery-empty:before{content:""}.fa-mouse-pointer:before{content:""}.fa-i-cursor:before{content:""}.fa-object-group:before{content:""}.fa-object-ungroup:before{content:""}.fa-sticky-note:before{content:""}.fa-sticky-note-o:before{content:""}.fa-cc-jcb:before{content:""}.fa-cc-diners-club:before{content:""}.fa-clone:before{content:""}.fa-balance-scale:before{content:""}.fa-hourglass-o:before{content:""}.fa-hourglass-1:before,.fa-hourglass-start:before{content:""}.fa-hourglass-2:before,.fa-hourglass-half:before{content:""}.fa-hourglass-3:before,.fa-hourglass-end:before{content:""}.fa-hourglass:before{content:""}.fa-hand-grab-o:before,.fa-hand-rock-o:before{content:""}.fa-hand-paper-o:before,.fa-hand-stop-o:before{content:""}.fa-hand-scissors-o:before{content:""}.fa-hand-lizard-o:before{content:""}.fa-hand-spock-o:before{content:""}.fa-hand-pointer-o:before{content:""}.fa-hand-peace-o:before{content:""}.fa-trademark:before{content:""}.fa-registered:before{content:""}.fa-creative-commons:before{content:""}.fa-gg:before{content:""}.fa-gg-circle:before{content:""}.fa-tripadvisor:before{content:""}.fa-odnoklassniki:before{content:""}.fa-odnoklassniki-square:before{content:""}.fa-get-pocket:before{content:""}.fa-wikipedia-w:before{content:""}.fa-safari:before
{content:""}.fa-chrome:before{content:""}.fa-firefox:before{content:""}.fa-opera:before{content:""}.fa-internet-explorer:before{content:""}.fa-television:before,.fa-tv:before{content:""}.fa-contao:before{content:""}.fa-500px:before{content:""}.fa-amazon:before{content:""}.fa-calendar-plus-o:before{content:""}.fa-calendar-minus-o:before{content:""}.fa-calendar-times-o:before{content:""}.fa-calendar-check-o:before{content:""}.fa-industry:before{content:""}.fa-map-pin:before{content:""}.fa-map-signs:before{content:""}.fa-map-o:before{content:""}.fa-map:before{content:""}.fa-commenting:before{content:""}.fa-commenting-o:before{content:""}.fa-houzz:before{content:""}.fa-vimeo:before{content:""}.fa-black-tie:before{content:""}.fa-fonticons:before{content:""}.fa-reddit-alien:before{content:""}.fa-edge:before{content:""}.fa-credit-card-alt:before{content:""}.fa-codiepie:before{content:""}.fa-modx:before{content:""}.fa-fort-awesome:before{content:""}.fa-usb:before{content:""}.fa-product-hunt:before{content:""}.fa-mixcloud:before{content:""}.fa-scribd:before{content:""}.fa-pause-circle:before{content:""}.fa-pause-circle-o:before{content:""}.fa-stop-circle:before{content:""}.fa-stop-circle-o:before{content:""}.fa-shopping-bag:before{content:""}.fa-shopping-basket:before{content:""}.fa-hashtag:before{content:""}.fa-bluetooth:before{content:""}.fa-bluetooth-b:before{content:""}.fa-percent:before{content:""}.fa-gitlab:before,.icon-gitlab:before{content:""}.fa-wpbeginner:before{content:""}.fa-wpforms:before{content:""}.fa-envira:before{content:""}.fa-universal-access:before{content:""}.fa-wheelchair-alt:before{content:""}.fa-question-circle-o:before{content:""}.fa-blind:before{content:""}.fa-audio-description:before{content:""}.fa-volume-control-phone:before{content:""}.fa-braille:before{content:""}.fa-assistive-listening-systems:before{content:""}.fa-american-sign-language-interpreting:before,.fa-asl-interpreting:before{content:""}.fa-deaf:before,.fa-deafness:before,.fa-hard-of-hearing:before{content:""}.fa-glide:before{content:""}.fa-glide-g:before{content:""}.fa-sign-language:before,.fa-signing:before{content:""}.fa-low-vision:before{content:""}.fa-viadeo:before{content:""}.fa-viadeo-square:before{content:""}.fa-snapchat:before{content:""}.fa-snapchat-ghost:before{content:""}.fa-snapchat-square:before{content:""}.fa-pied-piper:before{content:""}.fa-first-order:before{content:""}.fa-yoast:before{content:""}.fa-themeisle:before{content:""}.fa-google-plus-circle:before,.fa-google-plus-official:before{content:""}.fa-fa:before,.fa-font-awesome:before{content:""}.fa-handshake-o:before{content:""}.fa-envelope-open:before{content:""}.fa-envelope-open-o:before{content:""}.fa-linode:before{content:""}.fa-address-book:before{content:""}.fa-address-book-o:before{content:""}.fa-address-card:before,.fa-vcard:before{content:""}.fa-address-card-o:before,.fa-vcard-o:before{content:""}.fa-user-circle:before{content:""}.fa-user-circle-o:before{content:""}.fa-user-o:before{content:""}.fa-id-badge:before{content:""}.fa-drivers-license:before,.fa-id-card:before{content:""}.fa-drivers-license-o:before,.fa-id-card-o:before{content:""}.fa-quora:before{content:""}.fa-free-code-camp:before{content:""}.fa-telegram:before{content:""}.fa-thermometer-4:before,.fa-thermometer-full:before,.fa-thermometer:before{content:""}.fa-thermometer-3:before,.fa-thermometer-three-quarters:before{content:""}.fa-thermometer-2:before,.fa-thermometer-half:before{content:""}.fa-thermometer-1:before,.fa-thermometer-quarter:before{content:""}.fa-thermom
eter-0:before,.fa-thermometer-empty:before{content:""}.fa-shower:before{content:""}.fa-bath:before,.fa-bathtub:before,.fa-s15:before{content:""}.fa-podcast:before{content:""}.fa-window-maximize:before{content:""}.fa-window-minimize:before{content:""}.fa-window-restore:before{content:""}.fa-times-rectangle:before,.fa-window-close:before{content:""}.fa-times-rectangle-o:before,.fa-window-close-o:before{content:""}.fa-bandcamp:before{content:""}.fa-grav:before{content:""}.fa-etsy:before{content:""}.fa-imdb:before{content:""}.fa-ravelry:before{content:""}.fa-eercast:before{content:""}.fa-microchip:before{content:""}.fa-snowflake-o:before{content:""}.fa-superpowers:before{content:""}.fa-wpexplorer:before{content:""}.fa-meetup:before{content:""}.sr-only{position:absolute;width:1px;height:1px;padding:0;margin:-1px;overflow:hidden;clip:rect(0,0,0,0);border:0}.sr-only-focusable:active,.sr-only-focusable:focus{position:static;width:auto;height:auto;margin:0;overflow:visible;clip:auto}.fa,.icon,.rst-content .admonition-title,.rst-content .code-block-caption .headerlink,.rst-content .eqno .headerlink,.rst-content code.download span:first-child,.rst-content dl dt .headerlink,.rst-content h1 .headerlink,.rst-content h2 .headerlink,.rst-content h3 .headerlink,.rst-content h4 .headerlink,.rst-content h5 .headerlink,.rst-content h6 .headerlink,.rst-content p.caption .headerlink,.rst-content p .headerlink,.rst-content table>caption .headerlink,.rst-content tt.download span:first-child,.wy-dropdown .caret,.wy-inline-validate.wy-inline-validate-danger .wy-input-context,.wy-inline-validate.wy-inline-validate-info .wy-input-context,.wy-inline-validate.wy-inline-validate-success .wy-input-context,.wy-inline-validate.wy-inline-validate-warning .wy-input-context,.wy-menu-vertical li.current>a button.toctree-expand,.wy-menu-vertical li.on a button.toctree-expand,.wy-menu-vertical li button.toctree-expand{font-family:inherit}.fa:before,.icon:before,.rst-content .admonition-title:before,.rst-content .code-block-caption .headerlink:before,.rst-content .eqno .headerlink:before,.rst-content code.download span:first-child:before,.rst-content dl dt .headerlink:before,.rst-content h1 .headerlink:before,.rst-content h2 .headerlink:before,.rst-content h3 .headerlink:before,.rst-content h4 .headerlink:before,.rst-content h5 .headerlink:before,.rst-content h6 .headerlink:before,.rst-content p.caption .headerlink:before,.rst-content p .headerlink:before,.rst-content table>caption .headerlink:before,.rst-content tt.download span:first-child:before,.wy-dropdown .caret:before,.wy-inline-validate.wy-inline-validate-danger .wy-input-context:before,.wy-inline-validate.wy-inline-validate-info .wy-input-context:before,.wy-inline-validate.wy-inline-validate-success .wy-input-context:before,.wy-inline-validate.wy-inline-validate-warning .wy-input-context:before,.wy-menu-vertical li.current>a button.toctree-expand:before,.wy-menu-vertical li.on a button.toctree-expand:before,.wy-menu-vertical li button.toctree-expand:before{font-family:FontAwesome;display:inline-block;font-style:normal;font-weight:400;line-height:1;text-decoration:inherit}.rst-content .code-block-caption a .headerlink,.rst-content .eqno a .headerlink,.rst-content a .admonition-title,.rst-content code.download a span:first-child,.rst-content dl dt a .headerlink,.rst-content h1 a .headerlink,.rst-content h2 a .headerlink,.rst-content h3 a .headerlink,.rst-content h4 a .headerlink,.rst-content h5 a .headerlink,.rst-content h6 a .headerlink,.rst-content p.caption a 
.headerlink,.rst-content p a .headerlink,.rst-content table>caption a .headerlink,.rst-content tt.download a span:first-child,.wy-menu-vertical li.current>a button.toctree-expand,.wy-menu-vertical li.on a button.toctree-expand,.wy-menu-vertical li a button.toctree-expand,a .fa,a .icon,a .rst-content .admonition-title,a .rst-content .code-block-caption .headerlink,a .rst-content .eqno .headerlink,a .rst-content code.download span:first-child,a .rst-content dl dt .headerlink,a .rst-content h1 .headerlink,a .rst-content h2 .headerlink,a .rst-content h3 .headerlink,a .rst-content h4 .headerlink,a .rst-content h5 .headerlink,a .rst-content h6 .headerlink,a .rst-content p.caption .headerlink,a .rst-content p .headerlink,a .rst-content table>caption .headerlink,a .rst-content tt.download span:first-child,a .wy-menu-vertical li button.toctree-expand{display:inline-block;text-decoration:inherit}.btn .fa,.btn .icon,.btn .rst-content .admonition-title,.btn .rst-content .code-block-caption .headerlink,.btn .rst-content .eqno .headerlink,.btn .rst-content code.download span:first-child,.btn .rst-content dl dt .headerlink,.btn .rst-content h1 .headerlink,.btn .rst-content h2 .headerlink,.btn .rst-content h3 .headerlink,.btn .rst-content h4 .headerlink,.btn .rst-content h5 .headerlink,.btn .rst-content h6 .headerlink,.btn .rst-content p .headerlink,.btn .rst-content table>caption .headerlink,.btn .rst-content tt.download span:first-child,.btn .wy-menu-vertical li.current>a button.toctree-expand,.btn .wy-menu-vertical li.on a button.toctree-expand,.btn .wy-menu-vertical li button.toctree-expand,.nav .fa,.nav .icon,.nav .rst-content .admonition-title,.nav .rst-content .code-block-caption .headerlink,.nav .rst-content .eqno .headerlink,.nav .rst-content code.download span:first-child,.nav .rst-content dl dt .headerlink,.nav .rst-content h1 .headerlink,.nav .rst-content h2 .headerlink,.nav .rst-content h3 .headerlink,.nav .rst-content h4 .headerlink,.nav .rst-content h5 .headerlink,.nav .rst-content h6 .headerlink,.nav .rst-content p .headerlink,.nav .rst-content table>caption .headerlink,.nav .rst-content tt.download span:first-child,.nav .wy-menu-vertical li.current>a button.toctree-expand,.nav .wy-menu-vertical li.on a button.toctree-expand,.nav .wy-menu-vertical li button.toctree-expand,.rst-content .btn .admonition-title,.rst-content .code-block-caption .btn .headerlink,.rst-content .code-block-caption .nav .headerlink,.rst-content .eqno .btn .headerlink,.rst-content .eqno .nav .headerlink,.rst-content .nav .admonition-title,.rst-content code.download .btn span:first-child,.rst-content code.download .nav span:first-child,.rst-content dl dt .btn .headerlink,.rst-content dl dt .nav .headerlink,.rst-content h1 .btn .headerlink,.rst-content h1 .nav .headerlink,.rst-content h2 .btn .headerlink,.rst-content h2 .nav .headerlink,.rst-content h3 .btn .headerlink,.rst-content h3 .nav .headerlink,.rst-content h4 .btn .headerlink,.rst-content h4 .nav .headerlink,.rst-content h5 .btn .headerlink,.rst-content h5 .nav .headerlink,.rst-content h6 .btn .headerlink,.rst-content h6 .nav .headerlink,.rst-content p .btn .headerlink,.rst-content p .nav .headerlink,.rst-content table>caption .btn .headerlink,.rst-content table>caption .nav .headerlink,.rst-content tt.download .btn span:first-child,.rst-content tt.download .nav span:first-child,.wy-menu-vertical li .btn button.toctree-expand,.wy-menu-vertical li.current>a .btn button.toctree-expand,.wy-menu-vertical li.current>a .nav button.toctree-expand,.wy-menu-vertical li 
.nav button.toctree-expand,.wy-menu-vertical li.on a .btn button.toctree-expand,.wy-menu-vertical li.on a .nav button.toctree-expand{display:inline}.btn .fa-large.icon,.btn .fa.fa-large,.btn .rst-content .code-block-caption .fa-large.headerlink,.btn .rst-content .eqno .fa-large.headerlink,.btn .rst-content .fa-large.admonition-title,.btn .rst-content code.download span.fa-large:first-child,.btn .rst-content dl dt .fa-large.headerlink,.btn .rst-content h1 .fa-large.headerlink,.btn .rst-content h2 .fa-large.headerlink,.btn .rst-content h3 .fa-large.headerlink,.btn .rst-content h4 .fa-large.headerlink,.btn .rst-content h5 .fa-large.headerlink,.btn .rst-content h6 .fa-large.headerlink,.btn .rst-content p .fa-large.headerlink,.btn .rst-content table>caption .fa-large.headerlink,.btn .rst-content tt.download span.fa-large:first-child,.btn .wy-menu-vertical li button.fa-large.toctree-expand,.nav .fa-large.icon,.nav .fa.fa-large,.nav .rst-content .code-block-caption .fa-large.headerlink,.nav .rst-content .eqno .fa-large.headerlink,.nav .rst-content .fa-large.admonition-title,.nav .rst-content code.download span.fa-large:first-child,.nav .rst-content dl dt .fa-large.headerlink,.nav .rst-content h1 .fa-large.headerlink,.nav .rst-content h2 .fa-large.headerlink,.nav .rst-content h3 .fa-large.headerlink,.nav .rst-content h4 .fa-large.headerlink,.nav .rst-content h5 .fa-large.headerlink,.nav .rst-content h6 .fa-large.headerlink,.nav .rst-content p .fa-large.headerlink,.nav .rst-content table>caption .fa-large.headerlink,.nav .rst-content tt.download span.fa-large:first-child,.nav .wy-menu-vertical li button.fa-large.toctree-expand,.rst-content .btn .fa-large.admonition-title,.rst-content .code-block-caption .btn .fa-large.headerlink,.rst-content .code-block-caption .nav .fa-large.headerlink,.rst-content .eqno .btn .fa-large.headerlink,.rst-content .eqno .nav .fa-large.headerlink,.rst-content .nav .fa-large.admonition-title,.rst-content code.download .btn span.fa-large:first-child,.rst-content code.download .nav span.fa-large:first-child,.rst-content dl dt .btn .fa-large.headerlink,.rst-content dl dt .nav .fa-large.headerlink,.rst-content h1 .btn .fa-large.headerlink,.rst-content h1 .nav .fa-large.headerlink,.rst-content h2 .btn .fa-large.headerlink,.rst-content h2 .nav .fa-large.headerlink,.rst-content h3 .btn .fa-large.headerlink,.rst-content h3 .nav .fa-large.headerlink,.rst-content h4 .btn .fa-large.headerlink,.rst-content h4 .nav .fa-large.headerlink,.rst-content h5 .btn .fa-large.headerlink,.rst-content h5 .nav .fa-large.headerlink,.rst-content h6 .btn .fa-large.headerlink,.rst-content h6 .nav .fa-large.headerlink,.rst-content p .btn .fa-large.headerlink,.rst-content p .nav .fa-large.headerlink,.rst-content table>caption .btn .fa-large.headerlink,.rst-content table>caption .nav .fa-large.headerlink,.rst-content tt.download .btn span.fa-large:first-child,.rst-content tt.download .nav span.fa-large:first-child,.wy-menu-vertical li .btn button.fa-large.toctree-expand,.wy-menu-vertical li .nav button.fa-large.toctree-expand{line-height:.9em}.btn .fa-spin.icon,.btn .fa.fa-spin,.btn .rst-content .code-block-caption .fa-spin.headerlink,.btn .rst-content .eqno .fa-spin.headerlink,.btn .rst-content .fa-spin.admonition-title,.btn .rst-content code.download span.fa-spin:first-child,.btn .rst-content dl dt .fa-spin.headerlink,.btn .rst-content h1 .fa-spin.headerlink,.btn .rst-content h2 .fa-spin.headerlink,.btn .rst-content h3 .fa-spin.headerlink,.btn .rst-content h4 .fa-spin.headerlink,.btn .rst-content h5 
.fa-spin.headerlink,.btn .rst-content h6 .fa-spin.headerlink,.btn .rst-content p .fa-spin.headerlink,.btn .rst-content table>caption .fa-spin.headerlink,.btn .rst-content tt.download span.fa-spin:first-child,.btn .wy-menu-vertical li button.fa-spin.toctree-expand,.nav .fa-spin.icon,.nav .fa.fa-spin,.nav .rst-content .code-block-caption .fa-spin.headerlink,.nav .rst-content .eqno .fa-spin.headerlink,.nav .rst-content .fa-spin.admonition-title,.nav .rst-content code.download span.fa-spin:first-child,.nav .rst-content dl dt .fa-spin.headerlink,.nav .rst-content h1 .fa-spin.headerlink,.nav .rst-content h2 .fa-spin.headerlink,.nav .rst-content h3 .fa-spin.headerlink,.nav .rst-content h4 .fa-spin.headerlink,.nav .rst-content h5 .fa-spin.headerlink,.nav .rst-content h6 .fa-spin.headerlink,.nav .rst-content p .fa-spin.headerlink,.nav .rst-content table>caption .fa-spin.headerlink,.nav .rst-content tt.download span.fa-spin:first-child,.nav .wy-menu-vertical li button.fa-spin.toctree-expand,.rst-content .btn .fa-spin.admonition-title,.rst-content .code-block-caption .btn .fa-spin.headerlink,.rst-content .code-block-caption .nav .fa-spin.headerlink,.rst-content .eqno .btn .fa-spin.headerlink,.rst-content .eqno .nav .fa-spin.headerlink,.rst-content .nav .fa-spin.admonition-title,.rst-content code.download .btn span.fa-spin:first-child,.rst-content code.download .nav span.fa-spin:first-child,.rst-content dl dt .btn .fa-spin.headerlink,.rst-content dl dt .nav .fa-spin.headerlink,.rst-content h1 .btn .fa-spin.headerlink,.rst-content h1 .nav .fa-spin.headerlink,.rst-content h2 .btn .fa-spin.headerlink,.rst-content h2 .nav .fa-spin.headerlink,.rst-content h3 .btn .fa-spin.headerlink,.rst-content h3 .nav .fa-spin.headerlink,.rst-content h4 .btn .fa-spin.headerlink,.rst-content h4 .nav .fa-spin.headerlink,.rst-content h5 .btn .fa-spin.headerlink,.rst-content h5 .nav .fa-spin.headerlink,.rst-content h6 .btn .fa-spin.headerlink,.rst-content h6 .nav .fa-spin.headerlink,.rst-content p .btn .fa-spin.headerlink,.rst-content p .nav .fa-spin.headerlink,.rst-content table>caption .btn .fa-spin.headerlink,.rst-content table>caption .nav .fa-spin.headerlink,.rst-content tt.download .btn span.fa-spin:first-child,.rst-content tt.download .nav span.fa-spin:first-child,.wy-menu-vertical li .btn button.fa-spin.toctree-expand,.wy-menu-vertical li .nav button.fa-spin.toctree-expand{display:inline-block}.btn.fa:before,.btn.icon:before,.rst-content .btn.admonition-title:before,.rst-content .code-block-caption .btn.headerlink:before,.rst-content .eqno .btn.headerlink:before,.rst-content code.download span.btn:first-child:before,.rst-content dl dt .btn.headerlink:before,.rst-content h1 .btn.headerlink:before,.rst-content h2 .btn.headerlink:before,.rst-content h3 .btn.headerlink:before,.rst-content h4 .btn.headerlink:before,.rst-content h5 .btn.headerlink:before,.rst-content h6 .btn.headerlink:before,.rst-content p .btn.headerlink:before,.rst-content table>caption .btn.headerlink:before,.rst-content tt.download span.btn:first-child:before,.wy-menu-vertical li button.btn.toctree-expand:before{opacity:.5;-webkit-transition:opacity .05s ease-in;-moz-transition:opacity .05s ease-in;transition:opacity .05s ease-in}.btn.fa:hover:before,.btn.icon:hover:before,.rst-content .btn.admonition-title:hover:before,.rst-content .code-block-caption .btn.headerlink:hover:before,.rst-content .eqno .btn.headerlink:hover:before,.rst-content code.download span.btn:first-child:hover:before,.rst-content dl dt .btn.headerlink:hover:before,.rst-content h1 
.btn.headerlink:hover:before,.rst-content h2 .btn.headerlink:hover:before,.rst-content h3 .btn.headerlink:hover:before,.rst-content h4 .btn.headerlink:hover:before,.rst-content h5 .btn.headerlink:hover:before,.rst-content h6 .btn.headerlink:hover:before,.rst-content p .btn.headerlink:hover:before,.rst-content table>caption .btn.headerlink:hover:before,.rst-content tt.download span.btn:first-child:hover:before,.wy-menu-vertical li button.btn.toctree-expand:hover:before{opacity:1}.btn-mini .fa:before,.btn-mini .icon:before,.btn-mini .rst-content .admonition-title:before,.btn-mini .rst-content .code-block-caption .headerlink:before,.btn-mini .rst-content .eqno .headerlink:before,.btn-mini .rst-content code.download span:first-child:before,.btn-mini .rst-content dl dt .headerlink:before,.btn-mini .rst-content h1 .headerlink:before,.btn-mini .rst-content h2 .headerlink:before,.btn-mini .rst-content h3 .headerlink:before,.btn-mini .rst-content h4 .headerlink:before,.btn-mini .rst-content h5 .headerlink:before,.btn-mini .rst-content h6 .headerlink:before,.btn-mini .rst-content p .headerlink:before,.btn-mini .rst-content table>caption .headerlink:before,.btn-mini .rst-content tt.download span:first-child:before,.btn-mini .wy-menu-vertical li button.toctree-expand:before,.rst-content .btn-mini .admonition-title:before,.rst-content .code-block-caption .btn-mini .headerlink:before,.rst-content .eqno .btn-mini .headerlink:before,.rst-content code.download .btn-mini span:first-child:before,.rst-content dl dt .btn-mini .headerlink:before,.rst-content h1 .btn-mini .headerlink:before,.rst-content h2 .btn-mini .headerlink:before,.rst-content h3 .btn-mini .headerlink:before,.rst-content h4 .btn-mini .headerlink:before,.rst-content h5 .btn-mini .headerlink:before,.rst-content h6 .btn-mini .headerlink:before,.rst-content p .btn-mini .headerlink:before,.rst-content table>caption .btn-mini .headerlink:before,.rst-content tt.download .btn-mini span:first-child:before,.wy-menu-vertical li .btn-mini button.toctree-expand:before{font-size:14px;vertical-align:-15%}.rst-content .admonition,.rst-content .admonition-todo,.rst-content .attention,.rst-content .caution,.rst-content .danger,.rst-content .error,.rst-content .hint,.rst-content .important,.rst-content .note,.rst-content .seealso,.rst-content .tip,.rst-content .warning,.wy-alert{padding:12px;line-height:24px;margin-bottom:24px;background:#e7f2fa}.rst-content .admonition-title,.wy-alert-title{font-weight:700;display:block;color:#fff;background:#6ab0de;padding:6px 12px;margin:-12px -12px 12px}.rst-content .danger,.rst-content .error,.rst-content .wy-alert-danger.admonition,.rst-content .wy-alert-danger.admonition-todo,.rst-content .wy-alert-danger.attention,.rst-content .wy-alert-danger.caution,.rst-content .wy-alert-danger.hint,.rst-content .wy-alert-danger.important,.rst-content .wy-alert-danger.note,.rst-content .wy-alert-danger.seealso,.rst-content .wy-alert-danger.tip,.rst-content .wy-alert-danger.warning,.wy-alert.wy-alert-danger{background:#fdf3f2}.rst-content .danger .admonition-title,.rst-content .danger .wy-alert-title,.rst-content .error .admonition-title,.rst-content .error .wy-alert-title,.rst-content .wy-alert-danger.admonition-todo .admonition-title,.rst-content .wy-alert-danger.admonition-todo .wy-alert-title,.rst-content .wy-alert-danger.admonition .admonition-title,.rst-content .wy-alert-danger.admonition .wy-alert-title,.rst-content .wy-alert-danger.attention .admonition-title,.rst-content .wy-alert-danger.attention 
.wy-alert-title,.rst-content .wy-alert-danger.caution .admonition-title,.rst-content .wy-alert-danger.caution .wy-alert-title,.rst-content .wy-alert-danger.hint .admonition-title,.rst-content .wy-alert-danger.hint .wy-alert-title,.rst-content .wy-alert-danger.important .admonition-title,.rst-content .wy-alert-danger.important .wy-alert-title,.rst-content .wy-alert-danger.note .admonition-title,.rst-content .wy-alert-danger.note .wy-alert-title,.rst-content .wy-alert-danger.seealso .admonition-title,.rst-content .wy-alert-danger.seealso .wy-alert-title,.rst-content .wy-alert-danger.tip .admonition-title,.rst-content .wy-alert-danger.tip .wy-alert-title,.rst-content .wy-alert-danger.warning .admonition-title,.rst-content .wy-alert-danger.warning .wy-alert-title,.rst-content .wy-alert.wy-alert-danger .admonition-title,.wy-alert.wy-alert-danger .rst-content .admonition-title,.wy-alert.wy-alert-danger .wy-alert-title{background:#f29f97}.rst-content .admonition-todo,.rst-content .attention,.rst-content .caution,.rst-content .warning,.rst-content .wy-alert-warning.admonition,.rst-content .wy-alert-warning.danger,.rst-content .wy-alert-warning.error,.rst-content .wy-alert-warning.hint,.rst-content .wy-alert-warning.important,.rst-content .wy-alert-warning.note,.rst-content .wy-alert-warning.seealso,.rst-content .wy-alert-warning.tip,.wy-alert.wy-alert-warning{background:#ffedcc}.rst-content .admonition-todo .admonition-title,.rst-content .admonition-todo .wy-alert-title,.rst-content .attention .admonition-title,.rst-content .attention .wy-alert-title,.rst-content .caution .admonition-title,.rst-content .caution .wy-alert-title,.rst-content .warning .admonition-title,.rst-content .warning .wy-alert-title,.rst-content .wy-alert-warning.admonition .admonition-title,.rst-content .wy-alert-warning.admonition .wy-alert-title,.rst-content .wy-alert-warning.danger .admonition-title,.rst-content .wy-alert-warning.danger .wy-alert-title,.rst-content .wy-alert-warning.error .admonition-title,.rst-content .wy-alert-warning.error .wy-alert-title,.rst-content .wy-alert-warning.hint .admonition-title,.rst-content .wy-alert-warning.hint .wy-alert-title,.rst-content .wy-alert-warning.important .admonition-title,.rst-content .wy-alert-warning.important .wy-alert-title,.rst-content .wy-alert-warning.note .admonition-title,.rst-content .wy-alert-warning.note .wy-alert-title,.rst-content .wy-alert-warning.seealso .admonition-title,.rst-content .wy-alert-warning.seealso .wy-alert-title,.rst-content .wy-alert-warning.tip .admonition-title,.rst-content .wy-alert-warning.tip .wy-alert-title,.rst-content .wy-alert.wy-alert-warning .admonition-title,.wy-alert.wy-alert-warning .rst-content .admonition-title,.wy-alert.wy-alert-warning .wy-alert-title{background:#f0b37e}.rst-content .note,.rst-content .seealso,.rst-content .wy-alert-info.admonition,.rst-content .wy-alert-info.admonition-todo,.rst-content .wy-alert-info.attention,.rst-content .wy-alert-info.caution,.rst-content .wy-alert-info.danger,.rst-content .wy-alert-info.error,.rst-content .wy-alert-info.hint,.rst-content .wy-alert-info.important,.rst-content .wy-alert-info.tip,.rst-content .wy-alert-info.warning,.wy-alert.wy-alert-info{background:#e7f2fa}.rst-content .note .admonition-title,.rst-content .note .wy-alert-title,.rst-content .seealso .admonition-title,.rst-content .seealso .wy-alert-title,.rst-content .wy-alert-info.admonition-todo .admonition-title,.rst-content .wy-alert-info.admonition-todo .wy-alert-title,.rst-content .wy-alert-info.admonition 
.admonition-title,.rst-content .wy-alert-info.admonition .wy-alert-title,.rst-content .wy-alert-info.attention .admonition-title,.rst-content .wy-alert-info.attention .wy-alert-title,.rst-content .wy-alert-info.caution .admonition-title,.rst-content .wy-alert-info.caution .wy-alert-title,.rst-content .wy-alert-info.danger .admonition-title,.rst-content .wy-alert-info.danger .wy-alert-title,.rst-content .wy-alert-info.error .admonition-title,.rst-content .wy-alert-info.error .wy-alert-title,.rst-content .wy-alert-info.hint .admonition-title,.rst-content .wy-alert-info.hint .wy-alert-title,.rst-content .wy-alert-info.important .admonition-title,.rst-content .wy-alert-info.important .wy-alert-title,.rst-content .wy-alert-info.tip .admonition-title,.rst-content .wy-alert-info.tip .wy-alert-title,.rst-content .wy-alert-info.warning .admonition-title,.rst-content .wy-alert-info.warning .wy-alert-title,.rst-content .wy-alert.wy-alert-info .admonition-title,.wy-alert.wy-alert-info .rst-content .admonition-title,.wy-alert.wy-alert-info .wy-alert-title{background:#6ab0de}.rst-content .hint,.rst-content .important,.rst-content .tip,.rst-content .wy-alert-success.admonition,.rst-content .wy-alert-success.admonition-todo,.rst-content .wy-alert-success.attention,.rst-content .wy-alert-success.caution,.rst-content .wy-alert-success.danger,.rst-content .wy-alert-success.error,.rst-content .wy-alert-success.note,.rst-content .wy-alert-success.seealso,.rst-content .wy-alert-success.warning,.wy-alert.wy-alert-success{background:#dbfaf4}.rst-content .hint .admonition-title,.rst-content .hint .wy-alert-title,.rst-content .important .admonition-title,.rst-content .important .wy-alert-title,.rst-content .tip .admonition-title,.rst-content .tip .wy-alert-title,.rst-content .wy-alert-success.admonition-todo .admonition-title,.rst-content .wy-alert-success.admonition-todo .wy-alert-title,.rst-content .wy-alert-success.admonition .admonition-title,.rst-content .wy-alert-success.admonition .wy-alert-title,.rst-content .wy-alert-success.attention .admonition-title,.rst-content .wy-alert-success.attention .wy-alert-title,.rst-content .wy-alert-success.caution .admonition-title,.rst-content .wy-alert-success.caution .wy-alert-title,.rst-content .wy-alert-success.danger .admonition-title,.rst-content .wy-alert-success.danger .wy-alert-title,.rst-content .wy-alert-success.error .admonition-title,.rst-content .wy-alert-success.error .wy-alert-title,.rst-content .wy-alert-success.note .admonition-title,.rst-content .wy-alert-success.note .wy-alert-title,.rst-content .wy-alert-success.seealso .admonition-title,.rst-content .wy-alert-success.seealso .wy-alert-title,.rst-content .wy-alert-success.warning .admonition-title,.rst-content .wy-alert-success.warning .wy-alert-title,.rst-content .wy-alert.wy-alert-success .admonition-title,.wy-alert.wy-alert-success .rst-content .admonition-title,.wy-alert.wy-alert-success .wy-alert-title{background:#1abc9c}.rst-content .wy-alert-neutral.admonition,.rst-content .wy-alert-neutral.admonition-todo,.rst-content .wy-alert-neutral.attention,.rst-content .wy-alert-neutral.caution,.rst-content .wy-alert-neutral.danger,.rst-content .wy-alert-neutral.error,.rst-content .wy-alert-neutral.hint,.rst-content .wy-alert-neutral.important,.rst-content .wy-alert-neutral.note,.rst-content .wy-alert-neutral.seealso,.rst-content .wy-alert-neutral.tip,.rst-content .wy-alert-neutral.warning,.wy-alert.wy-alert-neutral{background:#f3f6f6}.rst-content .wy-alert-neutral.admonition-todo 
.admonition-title,.rst-content .wy-alert-neutral.admonition-todo .wy-alert-title,.rst-content .wy-alert-neutral.admonition .admonition-title,.rst-content .wy-alert-neutral.admonition .wy-alert-title,.rst-content .wy-alert-neutral.attention .admonition-title,.rst-content .wy-alert-neutral.attention .wy-alert-title,.rst-content .wy-alert-neutral.caution .admonition-title,.rst-content .wy-alert-neutral.caution .wy-alert-title,.rst-content .wy-alert-neutral.danger .admonition-title,.rst-content .wy-alert-neutral.danger .wy-alert-title,.rst-content .wy-alert-neutral.error .admonition-title,.rst-content .wy-alert-neutral.error .wy-alert-title,.rst-content .wy-alert-neutral.hint .admonition-title,.rst-content .wy-alert-neutral.hint .wy-alert-title,.rst-content .wy-alert-neutral.important .admonition-title,.rst-content .wy-alert-neutral.important .wy-alert-title,.rst-content .wy-alert-neutral.note .admonition-title,.rst-content .wy-alert-neutral.note .wy-alert-title,.rst-content .wy-alert-neutral.seealso .admonition-title,.rst-content .wy-alert-neutral.seealso .wy-alert-title,.rst-content .wy-alert-neutral.tip .admonition-title,.rst-content .wy-alert-neutral.tip .wy-alert-title,.rst-content .wy-alert-neutral.warning .admonition-title,.rst-content .wy-alert-neutral.warning .wy-alert-title,.rst-content .wy-alert.wy-alert-neutral .admonition-title,.wy-alert.wy-alert-neutral .rst-content .admonition-title,.wy-alert.wy-alert-neutral .wy-alert-title{color:#404040;background:#e1e4e5}.rst-content .wy-alert-neutral.admonition-todo a,.rst-content .wy-alert-neutral.admonition a,.rst-content .wy-alert-neutral.attention a,.rst-content .wy-alert-neutral.caution a,.rst-content .wy-alert-neutral.danger a,.rst-content .wy-alert-neutral.error a,.rst-content .wy-alert-neutral.hint a,.rst-content .wy-alert-neutral.important a,.rst-content .wy-alert-neutral.note a,.rst-content .wy-alert-neutral.seealso a,.rst-content .wy-alert-neutral.tip a,.rst-content .wy-alert-neutral.warning a,.wy-alert.wy-alert-neutral a{color:#2980b9}.rst-content .admonition-todo p:last-child,.rst-content .admonition p:last-child,.rst-content .attention p:last-child,.rst-content .caution p:last-child,.rst-content .danger p:last-child,.rst-content .error p:last-child,.rst-content .hint p:last-child,.rst-content .important p:last-child,.rst-content .note p:last-child,.rst-content .seealso p:last-child,.rst-content .tip p:last-child,.rst-content .warning p:last-child,.wy-alert p:last-child{margin-bottom:0}.wy-tray-container{position:fixed;bottom:0;left:0;z-index:600}.wy-tray-container li{display:block;width:300px;background:transparent;color:#fff;text-align:center;box-shadow:0 5px 5px 0 rgba(0,0,0,.1);padding:0 24px;min-width:20%;opacity:0;height:0;line-height:56px;overflow:hidden;-webkit-transition:all .3s ease-in;-moz-transition:all .3s ease-in;transition:all .3s ease-in}.wy-tray-container li.wy-tray-item-success{background:#27ae60}.wy-tray-container li.wy-tray-item-info{background:#2980b9}.wy-tray-container li.wy-tray-item-warning{background:#e67e22}.wy-tray-container li.wy-tray-item-danger{background:#e74c3c}.wy-tray-container li.on{opacity:1;height:56px}@media screen and (max-width:768px){.wy-tray-container{bottom:auto;top:0;width:100%}.wy-tray-container 
li{width:100%}}button{font-size:100%;margin:0;vertical-align:baseline;*vertical-align:middle;cursor:pointer;line-height:normal;-webkit-appearance:button;*overflow:visible}button::-moz-focus-inner,input::-moz-focus-inner{border:0;padding:0}button[disabled]{cursor:default}.btn{display:inline-block;border-radius:2px;line-height:normal;white-space:nowrap;text-align:center;cursor:pointer;font-size:100%;padding:6px 12px 8px;color:#fff;border:1px solid rgba(0,0,0,.1);background-color:#27ae60;text-decoration:none;font-weight:400;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;box-shadow:inset 0 1px 2px -1px hsla(0,0%,100%,.5),inset 0 -2px 0 0 rgba(0,0,0,.1);outline-none:false;vertical-align:middle;*display:inline;zoom:1;-webkit-user-drag:none;-webkit-user-select:none;-moz-user-select:none;-ms-user-select:none;user-select:none;-webkit-transition:all .1s linear;-moz-transition:all .1s linear;transition:all .1s linear}.btn-hover{background:#2e8ece;color:#fff}.btn:hover{background:#2cc36b;color:#fff}.btn:focus{background:#2cc36b;outline:0}.btn:active{box-shadow:inset 0 -1px 0 0 rgba(0,0,0,.05),inset 0 2px 0 0 rgba(0,0,0,.1);padding:8px 12px 6px}.btn:visited{color:#fff}.btn-disabled,.btn-disabled:active,.btn-disabled:focus,.btn-disabled:hover,.btn:disabled{background-image:none;filter:progid:DXImageTransform.Microsoft.gradient(enabled = false);filter:alpha(opacity=40);opacity:.4;cursor:not-allowed;box-shadow:none}.btn::-moz-focus-inner{padding:0;border:0}.btn-small{font-size:80%}.btn-info{background-color:#2980b9!important}.btn-info:hover{background-color:#2e8ece!important}.btn-neutral{background-color:#f3f6f6!important;color:#404040!important}.btn-neutral:hover{background-color:#e5ebeb!important;color:#404040}.btn-neutral:visited{color:#404040!important}.btn-success{background-color:#27ae60!important}.btn-success:hover{background-color:#295!important}.btn-danger{background-color:#e74c3c!important}.btn-danger:hover{background-color:#ea6153!important}.btn-warning{background-color:#e67e22!important}.btn-warning:hover{background-color:#e98b39!important}.btn-invert{background-color:#222}.btn-invert:hover{background-color:#2f2f2f!important}.btn-link{background-color:transparent!important;color:#2980b9;box-shadow:none;border-color:transparent!important}.btn-link:active,.btn-link:hover{background-color:transparent!important;color:#409ad5!important;box-shadow:none}.btn-link:visited{color:#9b59b6}.wy-btn-group .btn,.wy-control .btn{vertical-align:middle}.wy-btn-group{margin-bottom:24px;*zoom:1}.wy-btn-group:after,.wy-btn-group:before{display:table;content:""}.wy-btn-group:after{clear:both}.wy-dropdown{position:relative;display:inline-block}.wy-dropdown-active .wy-dropdown-menu{display:block}.wy-dropdown-menu{position:absolute;left:0;display:none;float:left;top:100%;min-width:100%;background:#fcfcfc;z-index:100;border:1px solid #cfd7dd;box-shadow:0 2px 2px 0 rgba(0,0,0,.1);padding:12px}.wy-dropdown-menu>dd>a{display:block;clear:both;color:#404040;white-space:nowrap;font-size:90%;padding:0 12px;cursor:pointer}.wy-dropdown-menu>dd>a:hover{background:#2980b9;color:#fff}.wy-dropdown-menu>dd.divider{border-top:1px solid #cfd7dd;margin:6px 0}.wy-dropdown-menu>dd.search{padding-bottom:12px}.wy-dropdown-menu>dd.search input[type=search]{width:100%}.wy-dropdown-menu>dd.call-to-action{background:#e3e3e3;text-transform:uppercase;font-weight:500;font-size:80%}.wy-dropdown-menu>dd.call-to-action:hover{background:#e3e3e3}.wy-dropdown-menu>dd.call-to-action .btn{color:#fff}.wy-dropdown.wy-dropdown-up 
.wy-dropdown-menu{bottom:100%;top:auto;left:auto;right:0}.wy-dropdown.wy-dropdown-bubble .wy-dropdown-menu{background:#fcfcfc;margin-top:2px}.wy-dropdown.wy-dropdown-bubble .wy-dropdown-menu a{padding:6px 12px}.wy-dropdown.wy-dropdown-bubble .wy-dropdown-menu a:hover{background:#2980b9;color:#fff}.wy-dropdown.wy-dropdown-left .wy-dropdown-menu{right:0;left:auto;text-align:right}.wy-dropdown-arrow:before{content:" ";border-bottom:5px solid #f5f5f5;border-left:5px solid transparent;border-right:5px solid transparent;position:absolute;display:block;top:-4px;left:50%;margin-left:-3px}.wy-dropdown-arrow.wy-dropdown-arrow-left:before{left:11px}.wy-form-stacked select{display:block}.wy-form-aligned .wy-help-inline,.wy-form-aligned input,.wy-form-aligned label,.wy-form-aligned select,.wy-form-aligned textarea{display:inline-block;*display:inline;*zoom:1;vertical-align:middle}.wy-form-aligned .wy-control-group>label{display:inline-block;vertical-align:middle;width:10em;margin:6px 12px 0 0;float:left}.wy-form-aligned .wy-control{float:left}.wy-form-aligned .wy-control label{display:block}.wy-form-aligned .wy-control select{margin-top:6px}fieldset{margin:0}fieldset,legend{border:0;padding:0}legend{width:100%;white-space:normal;margin-bottom:24px;font-size:150%;*margin-left:-7px}label,legend{display:block}label{margin:0 0 .3125em;color:#333;font-size:90%}input,select,textarea{font-size:100%;margin:0;vertical-align:baseline;*vertical-align:middle}.wy-control-group{margin-bottom:24px;max-width:1200px;margin-left:auto;margin-right:auto;*zoom:1}.wy-control-group:after,.wy-control-group:before{display:table;content:""}.wy-control-group:after{clear:both}.wy-control-group.wy-control-group-required>label:after{content:" *";color:#e74c3c}.wy-control-group .wy-form-full,.wy-control-group .wy-form-halves,.wy-control-group .wy-form-thirds{padding-bottom:12px}.wy-control-group .wy-form-full input[type=color],.wy-control-group .wy-form-full input[type=date],.wy-control-group .wy-form-full input[type=datetime-local],.wy-control-group .wy-form-full input[type=datetime],.wy-control-group .wy-form-full input[type=email],.wy-control-group .wy-form-full input[type=month],.wy-control-group .wy-form-full input[type=number],.wy-control-group .wy-form-full input[type=password],.wy-control-group .wy-form-full input[type=search],.wy-control-group .wy-form-full input[type=tel],.wy-control-group .wy-form-full input[type=text],.wy-control-group .wy-form-full input[type=time],.wy-control-group .wy-form-full input[type=url],.wy-control-group .wy-form-full input[type=week],.wy-control-group .wy-form-full select,.wy-control-group .wy-form-halves input[type=color],.wy-control-group .wy-form-halves input[type=date],.wy-control-group .wy-form-halves input[type=datetime-local],.wy-control-group .wy-form-halves input[type=datetime],.wy-control-group .wy-form-halves input[type=email],.wy-control-group .wy-form-halves input[type=month],.wy-control-group .wy-form-halves input[type=number],.wy-control-group .wy-form-halves input[type=password],.wy-control-group .wy-form-halves input[type=search],.wy-control-group .wy-form-halves input[type=tel],.wy-control-group .wy-form-halves input[type=text],.wy-control-group .wy-form-halves input[type=time],.wy-control-group .wy-form-halves input[type=url],.wy-control-group .wy-form-halves input[type=week],.wy-control-group .wy-form-halves select,.wy-control-group .wy-form-thirds input[type=color],.wy-control-group .wy-form-thirds input[type=date],.wy-control-group .wy-form-thirds 
input[type=datetime-local],.wy-control-group .wy-form-thirds input[type=datetime],.wy-control-group .wy-form-thirds input[type=email],.wy-control-group .wy-form-thirds input[type=month],.wy-control-group .wy-form-thirds input[type=number],.wy-control-group .wy-form-thirds input[type=password],.wy-control-group .wy-form-thirds input[type=search],.wy-control-group .wy-form-thirds input[type=tel],.wy-control-group .wy-form-thirds input[type=text],.wy-control-group .wy-form-thirds input[type=time],.wy-control-group .wy-form-thirds input[type=url],.wy-control-group .wy-form-thirds input[type=week],.wy-control-group .wy-form-thirds select{width:100%}.wy-control-group .wy-form-full{float:left;display:block;width:100%;margin-right:0}.wy-control-group .wy-form-full:last-child{margin-right:0}.wy-control-group .wy-form-halves{float:left;display:block;margin-right:2.35765%;width:48.82117%}.wy-control-group .wy-form-halves:last-child,.wy-control-group .wy-form-halves:nth-of-type(2n){margin-right:0}.wy-control-group .wy-form-halves:nth-of-type(odd){clear:left}.wy-control-group .wy-form-thirds{float:left;display:block;margin-right:2.35765%;width:31.76157%}.wy-control-group .wy-form-thirds:last-child,.wy-control-group .wy-form-thirds:nth-of-type(3n){margin-right:0}.wy-control-group .wy-form-thirds:nth-of-type(3n+1){clear:left}.wy-control-group.wy-control-group-no-input .wy-control,.wy-control-no-input{margin:6px 0 0;font-size:90%}.wy-control-no-input{display:inline-block}.wy-control-group.fluid-input input[type=color],.wy-control-group.fluid-input input[type=date],.wy-control-group.fluid-input input[type=datetime-local],.wy-control-group.fluid-input input[type=datetime],.wy-control-group.fluid-input input[type=email],.wy-control-group.fluid-input input[type=month],.wy-control-group.fluid-input input[type=number],.wy-control-group.fluid-input input[type=password],.wy-control-group.fluid-input input[type=search],.wy-control-group.fluid-input input[type=tel],.wy-control-group.fluid-input input[type=text],.wy-control-group.fluid-input input[type=time],.wy-control-group.fluid-input input[type=url],.wy-control-group.fluid-input input[type=week]{width:100%}.wy-form-message-inline{padding-left:.3em;color:#666;font-size:90%}.wy-form-message{display:block;color:#999;font-size:70%;margin-top:.3125em;font-style:italic}.wy-form-message p{font-size:inherit;font-style:italic;margin-bottom:6px}.wy-form-message p:last-child{margin-bottom:0}input{line-height:normal}input[type=button],input[type=reset],input[type=submit]{-webkit-appearance:button;cursor:pointer;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;*overflow:visible}input[type=color],input[type=date],input[type=datetime-local],input[type=datetime],input[type=email],input[type=month],input[type=number],input[type=password],input[type=search],input[type=tel],input[type=text],input[type=time],input[type=url],input[type=week]{-webkit-appearance:none;padding:6px;display:inline-block;border:1px solid #ccc;font-size:80%;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;box-shadow:inset 0 1px 3px #ddd;border-radius:0;-webkit-transition:border .3s linear;-moz-transition:border .3s linear;transition:border .3s linear}input[type=datetime-local]{padding:.34375em 
.625em}input[disabled]{cursor:default}input[type=checkbox],input[type=radio]{padding:0;margin-right:.3125em;*height:13px;*width:13px}input[type=checkbox],input[type=radio],input[type=search]{-webkit-box-sizing:border-box;-moz-box-sizing:border-box;box-sizing:border-box}input[type=search]::-webkit-search-cancel-button,input[type=search]::-webkit-search-decoration{-webkit-appearance:none}input[type=color]:focus,input[type=date]:focus,input[type=datetime-local]:focus,input[type=datetime]:focus,input[type=email]:focus,input[type=month]:focus,input[type=number]:focus,input[type=password]:focus,input[type=search]:focus,input[type=tel]:focus,input[type=text]:focus,input[type=time]:focus,input[type=url]:focus,input[type=week]:focus{outline:0;outline:thin dotted\9;border-color:#333}input.no-focus:focus{border-color:#ccc!important}input[type=checkbox]:focus,input[type=file]:focus,input[type=radio]:focus{outline:thin dotted #333;outline:1px auto #129fea}input[type=color][disabled],input[type=date][disabled],input[type=datetime-local][disabled],input[type=datetime][disabled],input[type=email][disabled],input[type=month][disabled],input[type=number][disabled],input[type=password][disabled],input[type=search][disabled],input[type=tel][disabled],input[type=text][disabled],input[type=time][disabled],input[type=url][disabled],input[type=week][disabled]{cursor:not-allowed;background-color:#fafafa}input:focus:invalid,select:focus:invalid,textarea:focus:invalid{color:#e74c3c;border:1px solid #e74c3c}input:focus:invalid:focus,select:focus:invalid:focus,textarea:focus:invalid:focus{border-color:#e74c3c}input[type=checkbox]:focus:invalid:focus,input[type=file]:focus:invalid:focus,input[type=radio]:focus:invalid:focus{outline-color:#e74c3c}input.wy-input-large{padding:12px;font-size:100%}textarea{overflow:auto;vertical-align:top;width:100%;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif}select,textarea{padding:.5em .625em;display:inline-block;border:1px solid #ccc;font-size:80%;box-shadow:inset 0 1px 3px #ddd;-webkit-transition:border .3s linear;-moz-transition:border .3s linear;transition:border .3s linear}select{border:1px solid #ccc;background-color:#fff}select[multiple]{height:auto}select:focus,textarea:focus{outline:0}input[readonly],select[disabled],select[readonly],textarea[disabled],textarea[readonly]{cursor:not-allowed;background-color:#fafafa}input[type=checkbox][disabled],input[type=radio][disabled]{cursor:not-allowed}.wy-checkbox,.wy-radio{margin:6px 0;color:#404040;display:block}.wy-checkbox input,.wy-radio input{vertical-align:baseline}.wy-form-message-inline{display:inline-block;*display:inline;*zoom:1;vertical-align:middle}.wy-input-prefix,.wy-input-suffix{white-space:nowrap;padding:6px}.wy-input-prefix .wy-input-context,.wy-input-suffix .wy-input-context{line-height:27px;padding:0 8px;display:inline-block;font-size:80%;background-color:#f3f6f6;border:1px solid #ccc;color:#999}.wy-input-suffix .wy-input-context{border-left:0}.wy-input-prefix .wy-input-context{border-right:0}.wy-switch{position:relative;display:block;height:24px;margin-top:12px;cursor:pointer}.wy-switch:before{left:0;top:0;width:36px;height:12px;background:#ccc}.wy-switch:after,.wy-switch:before{position:absolute;content:"";display:block;border-radius:4px;-webkit-transition:all .2s ease-in-out;-moz-transition:all .2s ease-in-out;transition:all .2s ease-in-out}.wy-switch:after{width:18px;height:18px;background:#999;left:-3px;top:-3px}.wy-switch 
span{position:absolute;left:48px;display:block;font-size:12px;color:#ccc;line-height:1}.wy-switch.active:before{background:#1e8449}.wy-switch.active:after{left:24px;background:#27ae60}.wy-switch.disabled{cursor:not-allowed;opacity:.8}.wy-control-group.wy-control-group-error .wy-form-message,.wy-control-group.wy-control-group-error>label{color:#e74c3c}.wy-control-group.wy-control-group-error input[type=color],.wy-control-group.wy-control-group-error input[type=date],.wy-control-group.wy-control-group-error input[type=datetime-local],.wy-control-group.wy-control-group-error input[type=datetime],.wy-control-group.wy-control-group-error input[type=email],.wy-control-group.wy-control-group-error input[type=month],.wy-control-group.wy-control-group-error input[type=number],.wy-control-group.wy-control-group-error input[type=password],.wy-control-group.wy-control-group-error input[type=search],.wy-control-group.wy-control-group-error input[type=tel],.wy-control-group.wy-control-group-error input[type=text],.wy-control-group.wy-control-group-error input[type=time],.wy-control-group.wy-control-group-error input[type=url],.wy-control-group.wy-control-group-error input[type=week],.wy-control-group.wy-control-group-error textarea{border:1px solid #e74c3c}.wy-inline-validate{white-space:nowrap}.wy-inline-validate .wy-input-context{padding:.5em .625em;display:inline-block;font-size:80%}.wy-inline-validate.wy-inline-validate-success .wy-input-context{color:#27ae60}.wy-inline-validate.wy-inline-validate-danger .wy-input-context{color:#e74c3c}.wy-inline-validate.wy-inline-validate-warning .wy-input-context{color:#e67e22}.wy-inline-validate.wy-inline-validate-info .wy-input-context{color:#2980b9}.rotate-90{-webkit-transform:rotate(90deg);-moz-transform:rotate(90deg);-ms-transform:rotate(90deg);-o-transform:rotate(90deg);transform:rotate(90deg)}.rotate-180{-webkit-transform:rotate(180deg);-moz-transform:rotate(180deg);-ms-transform:rotate(180deg);-o-transform:rotate(180deg);transform:rotate(180deg)}.rotate-270{-webkit-transform:rotate(270deg);-moz-transform:rotate(270deg);-ms-transform:rotate(270deg);-o-transform:rotate(270deg);transform:rotate(270deg)}.mirror{-webkit-transform:scaleX(-1);-moz-transform:scaleX(-1);-ms-transform:scaleX(-1);-o-transform:scaleX(-1);transform:scaleX(-1)}.mirror.rotate-90{-webkit-transform:scaleX(-1) rotate(90deg);-moz-transform:scaleX(-1) rotate(90deg);-ms-transform:scaleX(-1) rotate(90deg);-o-transform:scaleX(-1) rotate(90deg);transform:scaleX(-1) rotate(90deg)}.mirror.rotate-180{-webkit-transform:scaleX(-1) rotate(180deg);-moz-transform:scaleX(-1) rotate(180deg);-ms-transform:scaleX(-1) rotate(180deg);-o-transform:scaleX(-1) rotate(180deg);transform:scaleX(-1) rotate(180deg)}.mirror.rotate-270{-webkit-transform:scaleX(-1) rotate(270deg);-moz-transform:scaleX(-1) rotate(270deg);-ms-transform:scaleX(-1) rotate(270deg);-o-transform:scaleX(-1) rotate(270deg);transform:scaleX(-1) rotate(270deg)}@media only screen and (max-width:480px){.wy-form button[type=submit]{margin:.7em 0 0}.wy-form input[type=color],.wy-form input[type=date],.wy-form input[type=datetime-local],.wy-form input[type=datetime],.wy-form input[type=email],.wy-form input[type=month],.wy-form input[type=number],.wy-form input[type=password],.wy-form input[type=search],.wy-form input[type=tel],.wy-form input[type=text],.wy-form input[type=time],.wy-form input[type=url],.wy-form input[type=week],.wy-form label{margin-bottom:.3em;display:block}.wy-form input[type=color],.wy-form input[type=date],.wy-form 
input[type=datetime-local],.wy-form input[type=datetime],.wy-form input[type=email],.wy-form input[type=month],.wy-form input[type=number],.wy-form input[type=password],.wy-form input[type=search],.wy-form input[type=tel],.wy-form input[type=time],.wy-form input[type=url],.wy-form input[type=week]{margin-bottom:0}.wy-form-aligned .wy-control-group label{margin-bottom:.3em;text-align:left;display:block;width:100%}.wy-form-aligned .wy-control{margin:1.5em 0 0}.wy-form-message,.wy-form-message-inline,.wy-form .wy-help-inline{display:block;font-size:80%;padding:6px 0}}@media screen and (max-width:768px){.tablet-hide{display:none}}@media screen and (max-width:480px){.mobile-hide{display:none}}.float-left{float:left}.float-right{float:right}.full-width{width:100%}.rst-content table.docutils,.rst-content table.field-list,.wy-table{border-collapse:collapse;border-spacing:0;empty-cells:show;margin-bottom:24px}.rst-content table.docutils caption,.rst-content table.field-list caption,.wy-table caption{color:#000;font:italic 85%/1 arial,sans-serif;padding:1em 0;text-align:center}.rst-content table.docutils td,.rst-content table.docutils th,.rst-content table.field-list td,.rst-content table.field-list th,.wy-table td,.wy-table th{font-size:90%;margin:0;overflow:visible;padding:8px 16px}.rst-content table.docutils td:first-child,.rst-content table.docutils th:first-child,.rst-content table.field-list td:first-child,.rst-content table.field-list th:first-child,.wy-table td:first-child,.wy-table th:first-child{border-left-width:0}.rst-content table.docutils thead,.rst-content table.field-list thead,.wy-table thead{color:#000;text-align:left;vertical-align:bottom;white-space:nowrap}.rst-content table.docutils thead th,.rst-content table.field-list thead th,.wy-table thead th{font-weight:700;border-bottom:2px solid #e1e4e5}.rst-content table.docutils td,.rst-content table.field-list td,.wy-table td{background-color:transparent;vertical-align:middle}.rst-content table.docutils td p,.rst-content table.field-list td p,.wy-table td p{line-height:18px}.rst-content table.docutils td p:last-child,.rst-content table.field-list td p:last-child,.wy-table td p:last-child{margin-bottom:0}.rst-content table.docutils .wy-table-cell-min,.rst-content table.field-list .wy-table-cell-min,.wy-table .wy-table-cell-min{width:1%;padding-right:0}.rst-content table.docutils .wy-table-cell-min input[type=checkbox],.rst-content table.field-list .wy-table-cell-min input[type=checkbox],.wy-table .wy-table-cell-min input[type=checkbox]{margin:0}.wy-table-secondary{color:grey;font-size:90%}.wy-table-tertiary{color:grey;font-size:80%}.rst-content table.docutils:not(.field-list) tr:nth-child(2n-1) td,.wy-table-backed,.wy-table-odd td,.wy-table-striped tr:nth-child(2n-1) td{background-color:#f3f6f6}.rst-content table.docutils,.wy-table-bordered-all{border:1px solid #e1e4e5}.rst-content table.docutils td,.wy-table-bordered-all td{border-bottom:1px solid #e1e4e5;border-left:1px solid #e1e4e5}.rst-content table.docutils tbody>tr:last-child td,.wy-table-bordered-all tbody>tr:last-child td{border-bottom-width:0}.wy-table-bordered{border:1px solid #e1e4e5}.wy-table-bordered-rows td{border-bottom:1px solid #e1e4e5}.wy-table-bordered-rows tbody>tr:last-child td{border-bottom-width:0}.wy-table-horizontal td,.wy-table-horizontal th{border-width:0 0 1px;border-bottom:1px solid #e1e4e5}.wy-table-horizontal tbody>tr:last-child td{border-bottom-width:0}.wy-table-responsive{margin-bottom:24px;max-width:100%;overflow:auto}.wy-table-responsive 
table{margin-bottom:0!important}.wy-table-responsive table td,.wy-table-responsive table th{white-space:nowrap}a{color:#2980b9;text-decoration:none;cursor:pointer}a:hover{color:#3091d1}a:visited{color:#9b59b6}html{height:100%}body,html{overflow-x:hidden}body{font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;font-weight:400;color:#404040;min-height:100%;background:#edf0f2}.wy-text-left{text-align:left}.wy-text-center{text-align:center}.wy-text-right{text-align:right}.wy-text-large{font-size:120%}.wy-text-normal{font-size:100%}.wy-text-small,small{font-size:80%}.wy-text-strike{text-decoration:line-through}.wy-text-warning{color:#e67e22!important}a.wy-text-warning:hover{color:#eb9950!important}.wy-text-info{color:#2980b9!important}a.wy-text-info:hover{color:#409ad5!important}.wy-text-success{color:#27ae60!important}a.wy-text-success:hover{color:#36d278!important}.wy-text-danger{color:#e74c3c!important}a.wy-text-danger:hover{color:#ed7669!important}.wy-text-neutral{color:#404040!important}a.wy-text-neutral:hover{color:#595959!important}.rst-content .toctree-wrapper>p.caption,h1,h2,h3,h4,h5,h6,legend{margin-top:0;font-weight:700;font-family:Roboto Slab,ff-tisa-web-pro,Georgia,Arial,sans-serif}p{line-height:24px;font-size:16px;margin:0 0 24px}h1{font-size:175%}.rst-content .toctree-wrapper>p.caption,h2{font-size:150%}h3{font-size:125%}h4{font-size:115%}h5{font-size:110%}h6{font-size:100%}hr{display:block;height:1px;border:0;border-top:1px solid #e1e4e5;margin:24px 0;padding:0}.rst-content code,.rst-content tt,code{white-space:nowrap;max-width:100%;background:#fff;border:1px solid #e1e4e5;font-size:75%;padding:0 5px;font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace;color:#e74c3c;overflow-x:auto}.rst-content tt.code-large,code.code-large{font-size:90%}.rst-content .section ul,.rst-content .toctree-wrapper ul,.rst-content section ul,.wy-plain-list-disc,article ul{list-style:disc;line-height:24px;margin-bottom:24px}.rst-content .section ul li,.rst-content .toctree-wrapper ul li,.rst-content section ul li,.wy-plain-list-disc li,article ul li{list-style:disc;margin-left:24px}.rst-content .section ul li p:last-child,.rst-content .section ul li ul,.rst-content .toctree-wrapper ul li p:last-child,.rst-content .toctree-wrapper ul li ul,.rst-content section ul li p:last-child,.rst-content section ul li ul,.wy-plain-list-disc li p:last-child,.wy-plain-list-disc li ul,article ul li p:last-child,article ul li ul{margin-bottom:0}.rst-content .section ul li li,.rst-content .toctree-wrapper ul li li,.rst-content section ul li li,.wy-plain-list-disc li li,article ul li li{list-style:circle}.rst-content .section ul li li li,.rst-content .toctree-wrapper ul li li li,.rst-content section ul li li li,.wy-plain-list-disc li li li,article ul li li li{list-style:square}.rst-content .section ul li ol li,.rst-content .toctree-wrapper ul li ol li,.rst-content section ul li ol li,.wy-plain-list-disc li ol li,article ul li ol li{list-style:decimal}.rst-content .section ol,.rst-content .section ol.arabic,.rst-content .toctree-wrapper ol,.rst-content .toctree-wrapper ol.arabic,.rst-content section ol,.rst-content section ol.arabic,.wy-plain-list-decimal,article ol{list-style:decimal;line-height:24px;margin-bottom:24px}.rst-content .section ol.arabic li,.rst-content .section ol li,.rst-content .toctree-wrapper ol.arabic li,.rst-content .toctree-wrapper ol li,.rst-content section ol.arabic li,.rst-content section ol li,.wy-plain-list-decimal li,article ol 
li{list-style:decimal;margin-left:24px}.rst-content .section ol.arabic li ul,.rst-content .section ol li p:last-child,.rst-content .section ol li ul,.rst-content .toctree-wrapper ol.arabic li ul,.rst-content .toctree-wrapper ol li p:last-child,.rst-content .toctree-wrapper ol li ul,.rst-content section ol.arabic li ul,.rst-content section ol li p:last-child,.rst-content section ol li ul,.wy-plain-list-decimal li p:last-child,.wy-plain-list-decimal li ul,article ol li p:last-child,article ol li ul{margin-bottom:0}.rst-content .section ol.arabic li ul li,.rst-content .section ol li ul li,.rst-content .toctree-wrapper ol.arabic li ul li,.rst-content .toctree-wrapper ol li ul li,.rst-content section ol.arabic li ul li,.rst-content section ol li ul li,.wy-plain-list-decimal li ul li,article ol li ul li{list-style:disc}.wy-breadcrumbs{*zoom:1}.wy-breadcrumbs:after,.wy-breadcrumbs:before{display:table;content:""}.wy-breadcrumbs:after{clear:both}.wy-breadcrumbs li{display:inline-block}.wy-breadcrumbs li.wy-breadcrumbs-aside{float:right}.wy-breadcrumbs li a{display:inline-block;padding:5px}.wy-breadcrumbs li a:first-child{padding-left:0}.rst-content .wy-breadcrumbs li tt,.wy-breadcrumbs li .rst-content tt,.wy-breadcrumbs li code{padding:5px;border:none;background:none}.rst-content .wy-breadcrumbs li tt.literal,.wy-breadcrumbs li .rst-content tt.literal,.wy-breadcrumbs li code.literal{color:#404040}.wy-breadcrumbs-extra{margin-bottom:0;color:#b3b3b3;font-size:80%;display:inline-block}@media screen and (max-width:480px){.wy-breadcrumbs-extra,.wy-breadcrumbs li.wy-breadcrumbs-aside{display:none}}@media print{.wy-breadcrumbs li.wy-breadcrumbs-aside{display:none}}html{font-size:16px}.wy-affix{position:fixed;top:1.618em}.wy-menu a:hover{text-decoration:none}.wy-menu-horiz{*zoom:1}.wy-menu-horiz:after,.wy-menu-horiz:before{display:table;content:""}.wy-menu-horiz:after{clear:both}.wy-menu-horiz li,.wy-menu-horiz ul{display:inline-block}.wy-menu-horiz li:hover{background:hsla(0,0%,100%,.1)}.wy-menu-horiz li.divide-left{border-left:1px solid #404040}.wy-menu-horiz li.divide-right{border-right:1px solid #404040}.wy-menu-horiz a{height:32px;display:inline-block;line-height:32px;padding:0 16px}.wy-menu-vertical{width:300px}.wy-menu-vertical header,.wy-menu-vertical p.caption{color:#55a5d9;height:32px;line-height:32px;padding:0 1.618em;margin:12px 0 0;display:block;font-weight:700;text-transform:uppercase;font-size:85%;white-space:nowrap}.wy-menu-vertical ul{margin-bottom:0}.wy-menu-vertical li.divide-top{border-top:1px solid #404040}.wy-menu-vertical li.divide-bottom{border-bottom:1px solid #404040}.wy-menu-vertical li.current{background:#e3e3e3}.wy-menu-vertical li.current a{color:grey;border-right:1px solid #c9c9c9;padding:.4045em 2.427em}.wy-menu-vertical li.current a:hover{background:#d6d6d6}.rst-content .wy-menu-vertical li tt,.wy-menu-vertical li .rst-content tt,.wy-menu-vertical li code{border:none;background:inherit;color:inherit;padding-left:0;padding-right:0}.wy-menu-vertical li button.toctree-expand{display:block;float:left;margin-left:-1.2em;line-height:18px;color:#4d4d4d;border:none;background:none;padding:0}.wy-menu-vertical li.current>a,.wy-menu-vertical li.on a{color:#404040;font-weight:700;position:relative;background:#fcfcfc;border:none;padding:.4045em 1.618em}.wy-menu-vertical li.current>a:hover,.wy-menu-vertical li.on a:hover{background:#fcfcfc}.wy-menu-vertical li.current>a:hover button.toctree-expand,.wy-menu-vertical li.on a:hover button.toctree-expand{color:grey}.wy-menu-vertical 
li.current>a button.toctree-expand,.wy-menu-vertical li.on a button.toctree-expand{display:block;line-height:18px;color:#333}.wy-menu-vertical li.toctree-l1.current>a{border-bottom:1px solid #c9c9c9;border-top:1px solid #c9c9c9}.wy-menu-vertical .toctree-l1.current .toctree-l2>ul,.wy-menu-vertical .toctree-l2.current .toctree-l3>ul,.wy-menu-vertical .toctree-l3.current .toctree-l4>ul,.wy-menu-vertical .toctree-l4.current .toctree-l5>ul,.wy-menu-vertical .toctree-l5.current .toctree-l6>ul,.wy-menu-vertical .toctree-l6.current .toctree-l7>ul,.wy-menu-vertical .toctree-l7.current .toctree-l8>ul,.wy-menu-vertical .toctree-l8.current .toctree-l9>ul,.wy-menu-vertical .toctree-l9.current .toctree-l10>ul,.wy-menu-vertical .toctree-l10.current .toctree-l11>ul{display:none}.wy-menu-vertical .toctree-l1.current .current.toctree-l2>ul,.wy-menu-vertical .toctree-l2.current .current.toctree-l3>ul,.wy-menu-vertical .toctree-l3.current .current.toctree-l4>ul,.wy-menu-vertical .toctree-l4.current .current.toctree-l5>ul,.wy-menu-vertical .toctree-l5.current .current.toctree-l6>ul,.wy-menu-vertical .toctree-l6.current .current.toctree-l7>ul,.wy-menu-vertical .toctree-l7.current .current.toctree-l8>ul,.wy-menu-vertical .toctree-l8.current .current.toctree-l9>ul,.wy-menu-vertical .toctree-l9.current .current.toctree-l10>ul,.wy-menu-vertical .toctree-l10.current .current.toctree-l11>ul{display:block}.wy-menu-vertical li.toctree-l3,.wy-menu-vertical li.toctree-l4{font-size:.9em}.wy-menu-vertical li.toctree-l2 a,.wy-menu-vertical li.toctree-l3 a,.wy-menu-vertical li.toctree-l4 a,.wy-menu-vertical li.toctree-l5 a,.wy-menu-vertical li.toctree-l6 a,.wy-menu-vertical li.toctree-l7 a,.wy-menu-vertical li.toctree-l8 a,.wy-menu-vertical li.toctree-l9 a,.wy-menu-vertical li.toctree-l10 a{color:#404040}.wy-menu-vertical li.toctree-l2 a:hover button.toctree-expand,.wy-menu-vertical li.toctree-l3 a:hover button.toctree-expand,.wy-menu-vertical li.toctree-l4 a:hover button.toctree-expand,.wy-menu-vertical li.toctree-l5 a:hover button.toctree-expand,.wy-menu-vertical li.toctree-l6 a:hover button.toctree-expand,.wy-menu-vertical li.toctree-l7 a:hover button.toctree-expand,.wy-menu-vertical li.toctree-l8 a:hover button.toctree-expand,.wy-menu-vertical li.toctree-l9 a:hover button.toctree-expand,.wy-menu-vertical li.toctree-l10 a:hover button.toctree-expand{color:grey}.wy-menu-vertical li.toctree-l2.current li.toctree-l3>a,.wy-menu-vertical li.toctree-l3.current li.toctree-l4>a,.wy-menu-vertical li.toctree-l4.current li.toctree-l5>a,.wy-menu-vertical li.toctree-l5.current li.toctree-l6>a,.wy-menu-vertical li.toctree-l6.current li.toctree-l7>a,.wy-menu-vertical li.toctree-l7.current li.toctree-l8>a,.wy-menu-vertical li.toctree-l8.current li.toctree-l9>a,.wy-menu-vertical li.toctree-l9.current li.toctree-l10>a,.wy-menu-vertical li.toctree-l10.current li.toctree-l11>a{display:block}.wy-menu-vertical li.toctree-l2.current>a{padding:.4045em 2.427em}.wy-menu-vertical li.toctree-l2.current li.toctree-l3>a{padding:.4045em 1.618em .4045em 4.045em}.wy-menu-vertical li.toctree-l3.current>a{padding:.4045em 4.045em}.wy-menu-vertical li.toctree-l3.current li.toctree-l4>a{padding:.4045em 1.618em .4045em 5.663em}.wy-menu-vertical li.toctree-l4.current>a{padding:.4045em 5.663em}.wy-menu-vertical li.toctree-l4.current li.toctree-l5>a{padding:.4045em 1.618em .4045em 7.281em}.wy-menu-vertical li.toctree-l5.current>a{padding:.4045em 7.281em}.wy-menu-vertical li.toctree-l5.current li.toctree-l6>a{padding:.4045em 1.618em .4045em 
8.899em}.wy-menu-vertical li.toctree-l6.current>a{padding:.4045em 8.899em}.wy-menu-vertical li.toctree-l6.current li.toctree-l7>a{padding:.4045em 1.618em .4045em 10.517em}.wy-menu-vertical li.toctree-l7.current>a{padding:.4045em 10.517em}.wy-menu-vertical li.toctree-l7.current li.toctree-l8>a{padding:.4045em 1.618em .4045em 12.135em}.wy-menu-vertical li.toctree-l8.current>a{padding:.4045em 12.135em}.wy-menu-vertical li.toctree-l8.current li.toctree-l9>a{padding:.4045em 1.618em .4045em 13.753em}.wy-menu-vertical li.toctree-l9.current>a{padding:.4045em 13.753em}.wy-menu-vertical li.toctree-l9.current li.toctree-l10>a{padding:.4045em 1.618em .4045em 15.371em}.wy-menu-vertical li.toctree-l10.current>a{padding:.4045em 15.371em}.wy-menu-vertical li.toctree-l10.current li.toctree-l11>a{padding:.4045em 1.618em .4045em 16.989em}.wy-menu-vertical li.toctree-l2.current>a,.wy-menu-vertical li.toctree-l2.current li.toctree-l3>a{background:#c9c9c9}.wy-menu-vertical li.toctree-l2 button.toctree-expand{color:#a3a3a3}.wy-menu-vertical li.toctree-l3.current>a,.wy-menu-vertical li.toctree-l3.current li.toctree-l4>a{background:#bdbdbd}.wy-menu-vertical li.toctree-l3 button.toctree-expand{color:#969696}.wy-menu-vertical li.current ul{display:block}.wy-menu-vertical li ul{margin-bottom:0;display:none}.wy-menu-vertical li ul li a{margin-bottom:0;color:#d9d9d9;font-weight:400}.wy-menu-vertical a{line-height:18px;padding:.4045em 1.618em;display:block;position:relative;font-size:90%;color:#d9d9d9}.wy-menu-vertical a:hover{background-color:#4e4a4a;cursor:pointer}.wy-menu-vertical a:hover button.toctree-expand{color:#d9d9d9}.wy-menu-vertical a:active{background-color:#2980b9;cursor:pointer;color:#fff}.wy-menu-vertical a:active button.toctree-expand{color:#fff}.wy-side-nav-search{display:block;width:300px;padding:.809em;margin-bottom:.809em;z-index:200;background-color:#2980b9;text-align:center;color:#fcfcfc}.wy-side-nav-search input[type=text]{width:100%;border-radius:50px;padding:6px 12px;border-color:#2472a4}.wy-side-nav-search img{display:block;margin:auto auto .809em;height:45px;width:45px;background-color:#2980b9;padding:5px;border-radius:100%}.wy-side-nav-search .wy-dropdown>a,.wy-side-nav-search>a{color:#fcfcfc;font-size:100%;font-weight:700;display:inline-block;padding:4px 6px;margin-bottom:.809em;max-width:100%}.wy-side-nav-search .wy-dropdown>a:hover,.wy-side-nav-search>a:hover{background:hsla(0,0%,100%,.1)}.wy-side-nav-search .wy-dropdown>a img.logo,.wy-side-nav-search>a img.logo{display:block;margin:0 auto;height:auto;width:auto;border-radius:0;max-width:100%;background:transparent}.wy-side-nav-search .wy-dropdown>a.icon img.logo,.wy-side-nav-search>a.icon img.logo{margin-top:.85em}.wy-side-nav-search>div.version{margin-top:-.4045em;margin-bottom:.809em;font-weight:400;color:hsla(0,0%,100%,.3)}.wy-nav .wy-menu-vertical header{color:#2980b9}.wy-nav .wy-menu-vertical a{color:#b3b3b3}.wy-nav .wy-menu-vertical a:hover{background-color:#2980b9;color:#fff}[data-menu-wrap]{-webkit-transition:all .2s ease-in;-moz-transition:all .2s ease-in;transition:all .2s 
ease-in;position:absolute;opacity:1;width:100%;opacity:0}[data-menu-wrap].move-center{left:0;right:auto;opacity:1}[data-menu-wrap].move-left{right:auto;left:-100%;opacity:0}[data-menu-wrap].move-right{right:-100%;left:auto;opacity:0}.wy-body-for-nav{background:#fcfcfc}.wy-grid-for-nav{position:absolute;width:100%;height:100%}.wy-nav-side{position:fixed;top:0;bottom:0;left:0;padding-bottom:2em;width:300px;overflow-x:hidden;overflow-y:hidden;min-height:100%;color:#9b9b9b;background:#343131;z-index:200}.wy-side-scroll{width:320px;position:relative;overflow-x:hidden;overflow-y:scroll;height:100%}.wy-nav-top{display:none;background:#2980b9;color:#fff;padding:.4045em .809em;position:relative;line-height:50px;text-align:center;font-size:100%;*zoom:1}.wy-nav-top:after,.wy-nav-top:before{display:table;content:""}.wy-nav-top:after{clear:both}.wy-nav-top a{color:#fff;font-weight:700}.wy-nav-top img{margin-right:12px;height:45px;width:45px;background-color:#2980b9;padding:5px;border-radius:100%}.wy-nav-top i{font-size:30px;float:left;cursor:pointer;padding-top:inherit}.wy-nav-content-wrap{margin-left:300px;background:#fcfcfc;min-height:100%}.wy-nav-content{padding:1.618em 3.236em;height:100%;max-width:800px;margin:auto}.wy-body-mask{position:fixed;width:100%;height:100%;background:rgba(0,0,0,.2);display:none;z-index:499}.wy-body-mask.on{display:block}footer{color:grey}footer p{margin-bottom:12px}.rst-content footer span.commit tt,footer span.commit .rst-content tt,footer span.commit code{padding:0;font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace;font-size:1em;background:none;border:none;color:grey}.rst-footer-buttons{*zoom:1}.rst-footer-buttons:after,.rst-footer-buttons:before{width:100%;display:table;content:""}.rst-footer-buttons:after{clear:both}.rst-breadcrumbs-buttons{margin-top:12px;*zoom:1}.rst-breadcrumbs-buttons:after,.rst-breadcrumbs-buttons:before{display:table;content:""}.rst-breadcrumbs-buttons:after{clear:both}#search-results .search li{margin-bottom:24px;border-bottom:1px solid #e1e4e5;padding-bottom:24px}#search-results .search li:first-child{border-top:1px solid #e1e4e5;padding-top:24px}#search-results .search li a{font-size:120%;margin-bottom:12px;display:inline-block}#search-results .context{color:grey;font-size:90%}.genindextable li>ul{margin-left:24px}@media screen and (max-width:768px){.wy-body-for-nav{background:#fcfcfc}.wy-nav-top{display:block}.wy-nav-side{left:-300px}.wy-nav-side.shift{width:85%;left:0}.wy-menu.wy-menu-vertical,.wy-side-nav-search,.wy-side-scroll{width:auto}.wy-nav-content-wrap{margin-left:0}.wy-nav-content-wrap .wy-nav-content{padding:1.618em}.wy-nav-content-wrap.shift{position:fixed;min-width:100%;left:85%;top:0;height:100%;overflow:hidden}}@media screen and (min-width:1100px){.wy-nav-content-wrap{background:rgba(0,0,0,.05)}.wy-nav-content{margin:0;background:#fcfcfc}}@media print{.rst-versions,.wy-nav-side,footer{display:none}.wy-nav-content-wrap{margin-left:0}}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60;*zoom:1}.rst-versions .rst-current-version:after,.rst-versions .rst-current-version:before{display:table;content:""}.rst-versions 
.rst-current-version:after{clear:both}.rst-content .code-block-caption .rst-versions .rst-current-version .headerlink,.rst-content .eqno .rst-versions .rst-current-version .headerlink,.rst-content .rst-versions .rst-current-version .admonition-title,.rst-content code.download .rst-versions .rst-current-version span:first-child,.rst-content dl dt .rst-versions .rst-current-version .headerlink,.rst-content h1 .rst-versions .rst-current-version .headerlink,.rst-content h2 .rst-versions .rst-current-version .headerlink,.rst-content h3 .rst-versions .rst-current-version .headerlink,.rst-content h4 .rst-versions .rst-current-version .headerlink,.rst-content h5 .rst-versions .rst-current-version .headerlink,.rst-content h6 .rst-versions .rst-current-version .headerlink,.rst-content p .rst-versions .rst-current-version .headerlink,.rst-content table>caption .rst-versions .rst-current-version .headerlink,.rst-content tt.download .rst-versions .rst-current-version span:first-child,.rst-versions .rst-current-version .fa,.rst-versions .rst-current-version .icon,.rst-versions .rst-current-version .rst-content .admonition-title,.rst-versions .rst-current-version .rst-content .code-block-caption .headerlink,.rst-versions .rst-current-version .rst-content .eqno .headerlink,.rst-versions .rst-current-version .rst-content code.download span:first-child,.rst-versions .rst-current-version .rst-content dl dt .headerlink,.rst-versions .rst-current-version .rst-content h1 .headerlink,.rst-versions .rst-current-version .rst-content h2 .headerlink,.rst-versions .rst-current-version .rst-content h3 .headerlink,.rst-versions .rst-current-version .rst-content h4 .headerlink,.rst-versions .rst-current-version .rst-content h5 .headerlink,.rst-versions .rst-current-version .rst-content h6 .headerlink,.rst-versions .rst-current-version .rst-content p .headerlink,.rst-versions .rst-current-version .rst-content table>caption .headerlink,.rst-versions .rst-current-version .rst-content tt.download span:first-child,.rst-versions .rst-current-version .wy-menu-vertical li button.toctree-expand,.wy-menu-vertical li .rst-versions .rst-current-version button.toctree-expand{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and 
(max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}}.rst-content .toctree-wrapper>p.caption,.rst-content h1,.rst-content h2,.rst-content h3,.rst-content h4,.rst-content h5,.rst-content h6{margin-bottom:24px}.rst-content img{max-width:100%;height:auto}.rst-content div.figure,.rst-content figure{margin-bottom:24px}.rst-content div.figure .caption-text,.rst-content figure .caption-text{font-style:italic}.rst-content div.figure p:last-child.caption,.rst-content figure p:last-child.caption{margin-bottom:0}.rst-content div.figure.align-center,.rst-content figure.align-center{text-align:center}.rst-content .section>a>img,.rst-content .section>img,.rst-content section>a>img,.rst-content section>img{margin-bottom:24px}.rst-content abbr[title]{text-decoration:none}.rst-content.style-external-links a.reference.external:after{font-family:FontAwesome;content:"\f08e";color:#b3b3b3;vertical-align:super;font-size:60%;margin:0 .2em}.rst-content blockquote{margin-left:24px;line-height:24px;margin-bottom:24px}.rst-content pre.literal-block{white-space:pre;margin:0;padding:12px;font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace;display:block;overflow:auto}.rst-content div[class^=highlight],.rst-content pre.literal-block{border:1px solid #e1e4e5;overflow-x:auto;margin:1px 0 24px}.rst-content div[class^=highlight] div[class^=highlight],.rst-content pre.literal-block div[class^=highlight]{padding:0;border:none;margin:0}.rst-content div[class^=highlight] td.code{width:100%}.rst-content .linenodiv pre{border-right:1px solid #e6e9ea;margin:0;padding:12px;font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace;user-select:none;pointer-events:none}.rst-content div[class^=highlight] pre{white-space:pre;margin:0;padding:12px;display:block;overflow:auto}.rst-content div[class^=highlight] pre .hll{display:block;margin:0 -12px;padding:0 12px}.rst-content .linenodiv pre,.rst-content div[class^=highlight] pre,.rst-content pre.literal-block{font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace;font-size:12px;line-height:1.4}.rst-content div.highlight .gp,.rst-content div.highlight span.linenos{user-select:none;pointer-events:none}.rst-content div.highlight span.linenos{display:inline-block;padding-left:0;padding-right:12px;margin-right:12px;border-right:1px solid #e6e9ea}.rst-content .code-block-caption{font-style:italic;font-size:85%;line-height:1;padding:1em 0;text-align:center}@media print{.rst-content .codeblock,.rst-content div[class^=highlight],.rst-content div[class^=highlight] pre{white-space:pre-wrap}}.rst-content .admonition,.rst-content .admonition-todo,.rst-content .attention,.rst-content .caution,.rst-content .danger,.rst-content .error,.rst-content .hint,.rst-content .important,.rst-content .note,.rst-content .seealso,.rst-content .tip,.rst-content .warning{clear:both}.rst-content .admonition-todo .last,.rst-content .admonition-todo>:last-child,.rst-content .admonition .last,.rst-content .admonition>:last-child,.rst-content .attention .last,.rst-content .attention>:last-child,.rst-content .caution .last,.rst-content .caution>:last-child,.rst-content .danger .last,.rst-content .danger>:last-child,.rst-content .error .last,.rst-content .error>:last-child,.rst-content .hint .last,.rst-content .hint>:last-child,.rst-content .important .last,.rst-content .important>:last-child,.rst-content .note .last,.rst-content .note>:last-child,.rst-content .seealso 
.last,.rst-content .seealso>:last-child,.rst-content .tip .last,.rst-content .tip>:last-child,.rst-content .warning .last,.rst-content .warning>:last-child{margin-bottom:0}.rst-content .admonition-title:before{margin-right:4px}.rst-content .admonition table{border-color:rgba(0,0,0,.1)}.rst-content .admonition table td,.rst-content .admonition table th{background:transparent!important;border-color:rgba(0,0,0,.1)!important}.rst-content .section ol.loweralpha,.rst-content .section ol.loweralpha>li,.rst-content .toctree-wrapper ol.loweralpha,.rst-content .toctree-wrapper ol.loweralpha>li,.rst-content section ol.loweralpha,.rst-content section ol.loweralpha>li{list-style:lower-alpha}.rst-content .section ol.upperalpha,.rst-content .section ol.upperalpha>li,.rst-content .toctree-wrapper ol.upperalpha,.rst-content .toctree-wrapper ol.upperalpha>li,.rst-content section ol.upperalpha,.rst-content section ol.upperalpha>li{list-style:upper-alpha}.rst-content .section ol li>*,.rst-content .section ul li>*,.rst-content .toctree-wrapper ol li>*,.rst-content .toctree-wrapper ul li>*,.rst-content section ol li>*,.rst-content section ul li>*{margin-top:12px;margin-bottom:12px}.rst-content .section ol li>:first-child,.rst-content .section ul li>:first-child,.rst-content .toctree-wrapper ol li>:first-child,.rst-content .toctree-wrapper ul li>:first-child,.rst-content section ol li>:first-child,.rst-content section ul li>:first-child{margin-top:0}.rst-content .section ol li>p,.rst-content .section ol li>p:last-child,.rst-content .section ul li>p,.rst-content .section ul li>p:last-child,.rst-content .toctree-wrapper ol li>p,.rst-content .toctree-wrapper ol li>p:last-child,.rst-content .toctree-wrapper ul li>p,.rst-content .toctree-wrapper ul li>p:last-child,.rst-content section ol li>p,.rst-content section ol li>p:last-child,.rst-content section ul li>p,.rst-content section ul li>p:last-child{margin-bottom:12px}.rst-content .section ol li>p:only-child,.rst-content .section ol li>p:only-child:last-child,.rst-content .section ul li>p:only-child,.rst-content .section ul li>p:only-child:last-child,.rst-content .toctree-wrapper ol li>p:only-child,.rst-content .toctree-wrapper ol li>p:only-child:last-child,.rst-content .toctree-wrapper ul li>p:only-child,.rst-content .toctree-wrapper ul li>p:only-child:last-child,.rst-content section ol li>p:only-child,.rst-content section ol li>p:only-child:last-child,.rst-content section ul li>p:only-child,.rst-content section ul li>p:only-child:last-child{margin-bottom:0}.rst-content .section ol li>ol,.rst-content .section ol li>ul,.rst-content .section ul li>ol,.rst-content .section ul li>ul,.rst-content .toctree-wrapper ol li>ol,.rst-content .toctree-wrapper ol li>ul,.rst-content .toctree-wrapper ul li>ol,.rst-content .toctree-wrapper ul li>ul,.rst-content section ol li>ol,.rst-content section ol li>ul,.rst-content section ul li>ol,.rst-content section ul li>ul{margin-bottom:12px}.rst-content .section ol.simple li>*,.rst-content .section ol.simple li ol,.rst-content .section ol.simple li ul,.rst-content .section ul.simple li>*,.rst-content .section ul.simple li ol,.rst-content .section ul.simple li ul,.rst-content .toctree-wrapper ol.simple li>*,.rst-content .toctree-wrapper ol.simple li ol,.rst-content .toctree-wrapper ol.simple li ul,.rst-content .toctree-wrapper ul.simple li>*,.rst-content .toctree-wrapper ul.simple li ol,.rst-content .toctree-wrapper ul.simple li ul,.rst-content section ol.simple li>*,.rst-content section ol.simple li ol,.rst-content section ol.simple li 
ul,.rst-content section ul.simple li>*,.rst-content section ul.simple li ol,.rst-content section ul.simple li ul{margin-top:0;margin-bottom:0}.rst-content .line-block{margin-left:0;margin-bottom:24px;line-height:24px}.rst-content .line-block .line-block{margin-left:24px;margin-bottom:0}.rst-content .topic-title{font-weight:700;margin-bottom:12px}.rst-content .toc-backref{color:#404040}.rst-content .align-right{float:right;margin:0 0 24px 24px}.rst-content .align-left{float:left;margin:0 24px 24px 0}.rst-content .align-center{margin:auto}.rst-content .align-center:not(table){display:block}.rst-content .code-block-caption .headerlink,.rst-content .eqno .headerlink,.rst-content .toctree-wrapper>p.caption .headerlink,.rst-content dl dt .headerlink,.rst-content h1 .headerlink,.rst-content h2 .headerlink,.rst-content h3 .headerlink,.rst-content h4 .headerlink,.rst-content h5 .headerlink,.rst-content h6 .headerlink,.rst-content p.caption .headerlink,.rst-content p .headerlink,.rst-content table>caption .headerlink{opacity:0;font-size:14px;font-family:FontAwesome;margin-left:.5em}.rst-content .code-block-caption .headerlink:focus,.rst-content .code-block-caption:hover .headerlink,.rst-content .eqno .headerlink:focus,.rst-content .eqno:hover .headerlink,.rst-content .toctree-wrapper>p.caption .headerlink:focus,.rst-content .toctree-wrapper>p.caption:hover .headerlink,.rst-content dl dt .headerlink:focus,.rst-content dl dt:hover .headerlink,.rst-content h1 .headerlink:focus,.rst-content h1:hover .headerlink,.rst-content h2 .headerlink:focus,.rst-content h2:hover .headerlink,.rst-content h3 .headerlink:focus,.rst-content h3:hover .headerlink,.rst-content h4 .headerlink:focus,.rst-content h4:hover .headerlink,.rst-content h5 .headerlink:focus,.rst-content h5:hover .headerlink,.rst-content h6 .headerlink:focus,.rst-content h6:hover .headerlink,.rst-content p.caption .headerlink:focus,.rst-content p.caption:hover .headerlink,.rst-content p .headerlink:focus,.rst-content p:hover .headerlink,.rst-content table>caption .headerlink:focus,.rst-content table>caption:hover .headerlink{opacity:1}.rst-content .btn:focus{outline:2px solid}.rst-content table>caption .headerlink:after{font-size:12px}.rst-content .centered{text-align:center}.rst-content .sidebar{float:right;width:40%;display:block;margin:0 0 24px 24px;padding:24px;background:#f3f6f6;border:1px solid #e1e4e5}.rst-content .sidebar dl,.rst-content .sidebar p,.rst-content .sidebar ul{font-size:90%}.rst-content .sidebar .last,.rst-content .sidebar>:last-child{margin-bottom:0}.rst-content .sidebar .sidebar-title{display:block;font-family:Roboto Slab,ff-tisa-web-pro,Georgia,Arial,sans-serif;font-weight:700;background:#e1e4e5;padding:6px 12px;margin:-24px -24px 24px;font-size:100%}.rst-content .highlighted{background:#f1c40f;box-shadow:0 0 0 2px #f1c40f;display:inline;font-weight:700}.rst-content .citation-reference,.rst-content .footnote-reference{vertical-align:baseline;position:relative;top:-.4em;line-height:0;font-size:90%}.rst-content .hlist{width:100%}.rst-content dl dt span.classifier:before{content:" : "}.rst-content dl dt span.classifier-delimiter{display:none!important}html.writer-html4 .rst-content table.docutils.citation,html.writer-html4 .rst-content table.docutils.footnote{background:none;border:none}html.writer-html4 .rst-content table.docutils.citation td,html.writer-html4 .rst-content table.docutils.citation tr,html.writer-html4 .rst-content table.docutils.footnote td,html.writer-html4 .rst-content table.docutils.footnote 
tr{border:none;background-color:transparent!important;white-space:normal}html.writer-html4 .rst-content table.docutils.citation td.label,html.writer-html4 .rst-content table.docutils.footnote td.label{padding-left:0;padding-right:0;vertical-align:top}html.writer-html5 .rst-content dl.field-list,html.writer-html5 .rst-content dl.footnote{display:grid;grid-template-columns:max-content auto}html.writer-html5 .rst-content dl.field-list>dt,html.writer-html5 .rst-content dl.footnote>dt{padding-left:1rem}html.writer-html5 .rst-content dl.field-list>dt:after,html.writer-html5 .rst-content dl.footnote>dt:after{content:":"}html.writer-html5 .rst-content dl.field-list>dd,html.writer-html5 .rst-content dl.field-list>dt,html.writer-html5 .rst-content dl.footnote>dd,html.writer-html5 .rst-content dl.footnote>dt{margin-bottom:0}html.writer-html5 .rst-content dl.footnote{font-size:.9rem}html.writer-html5 .rst-content dl.footnote>dt{margin:0 .5rem .5rem 0;line-height:1.2rem;word-break:break-all;font-weight:400}html.writer-html5 .rst-content dl.footnote>dt>span.brackets{margin-right:.5rem}html.writer-html5 .rst-content dl.footnote>dt>span.brackets:before{content:"["}html.writer-html5 .rst-content dl.footnote>dt>span.brackets:after{content:"]"}html.writer-html5 .rst-content dl.footnote>dt>span.fn-backref{font-style:italic}html.writer-html5 .rst-content dl.footnote>dd{margin:0 0 .5rem;line-height:1.2rem}html.writer-html5 .rst-content dl.footnote>dd p,html.writer-html5 .rst-content dl.option-list kbd{font-size:.9rem}.rst-content table.docutils.footnote,html.writer-html4 .rst-content table.docutils.citation,html.writer-html5 .rst-content dl.footnote{color:grey}.rst-content table.docutils.footnote code,.rst-content table.docutils.footnote tt,html.writer-html4 .rst-content table.docutils.citation code,html.writer-html4 .rst-content table.docutils.citation tt,html.writer-html5 .rst-content dl.footnote code,html.writer-html5 .rst-content dl.footnote tt{color:#555}.rst-content .wy-table-responsive.citation,.rst-content .wy-table-responsive.footnote{margin-bottom:0}.rst-content .wy-table-responsive.citation+:not(.citation),.rst-content .wy-table-responsive.footnote+:not(.footnote){margin-top:24px}.rst-content .wy-table-responsive.citation:last-child,.rst-content .wy-table-responsive.footnote:last-child{margin-bottom:24px}.rst-content table.docutils th{border-color:#e1e4e5}html.writer-html5 .rst-content table.docutils th{border:1px solid #e1e4e5}html.writer-html5 .rst-content table.docutils td>p,html.writer-html5 .rst-content table.docutils th>p{line-height:1rem;margin-bottom:0;font-size:.9rem}.rst-content table.docutils td .last,.rst-content table.docutils td .last>:last-child{margin-bottom:0}.rst-content table.field-list,.rst-content table.field-list td{border:none}.rst-content table.field-list td p{font-size:inherit;line-height:inherit}.rst-content table.field-list td>strong{display:inline-block}.rst-content table.field-list .field-name{padding-right:10px;text-align:left;white-space:nowrap}.rst-content table.field-list .field-body{text-align:left}.rst-content code,.rst-content tt{color:#000;font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace;padding:2px 5px}.rst-content code big,.rst-content code em,.rst-content tt big,.rst-content tt em{font-size:100%!important;line-height:normal}.rst-content code.literal,.rst-content tt.literal{color:#e74c3c;white-space:normal}.rst-content code.xref,.rst-content tt.xref,a .rst-content code,a .rst-content 
tt{font-weight:700;color:#404040}.rst-content kbd,.rst-content pre,.rst-content samp{font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace}.rst-content a code,.rst-content a tt{color:#2980b9}.rst-content dl{margin-bottom:24px}.rst-content dl dt{font-weight:700;margin-bottom:12px}.rst-content dl ol,.rst-content dl p,.rst-content dl table,.rst-content dl ul{margin-bottom:12px}.rst-content dl dd{margin:0 0 12px 24px;line-height:24px}html.writer-html4 .rst-content dl:not(.docutils),html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple){margin-bottom:24px}html.writer-html4 .rst-content dl:not(.docutils)>dt,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple)>dt{display:table;margin:6px 0;font-size:90%;line-height:normal;background:#e7f2fa;color:#2980b9;border-top:3px solid #6ab0de;padding:6px;position:relative}html.writer-html4 .rst-content dl:not(.docutils)>dt:before,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple)>dt:before{color:#6ab0de}html.writer-html4 .rst-content dl:not(.docutils)>dt .headerlink,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple)>dt .headerlink{color:#404040;font-size:100%!important}html.writer-html4 .rst-content dl:not(.docutils) dl:not(.field-list)>dt,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) dl:not(.field-list)>dt{margin-bottom:6px;border:none;border-left:3px solid #ccc;background:#f0f0f0;color:#555}html.writer-html4 .rst-content dl:not(.docutils) dl:not(.field-list)>dt .headerlink,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) dl:not(.field-list)>dt .headerlink{color:#404040;font-size:100%!important}html.writer-html4 .rst-content dl:not(.docutils)>dt:first-child,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple)>dt:first-child{margin-top:0}html.writer-html4 .rst-content dl:not(.docutils) code.descclassname,html.writer-html4 .rst-content dl:not(.docutils) code.descname,html.writer-html4 .rst-content dl:not(.docutils) tt.descclassname,html.writer-html4 .rst-content dl:not(.docutils) tt.descname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) code.descclassname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) code.descname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) tt.descclassname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) tt.descname{background-color:transparent;border:none;padding:0;font-size:100%!important}html.writer-html4 .rst-content dl:not(.docutils) code.descname,html.writer-html4 .rst-content dl:not(.docutils) tt.descname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) code.descname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) tt.descname{font-weight:700}html.writer-html4 .rst-content dl:not(.docutils) 
.optional,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .optional{display:inline-block;padding:0 4px;color:#000;font-weight:700}html.writer-html4 .rst-content dl:not(.docutils) .property,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .property{display:inline-block;padding-right:8px;max-width:100%}html.writer-html4 .rst-content dl:not(.docutils) .k,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .k{font-style:italic}html.writer-html4 .rst-content dl:not(.docutils) .descclassname,html.writer-html4 .rst-content dl:not(.docutils) .descname,html.writer-html4 .rst-content dl:not(.docutils) .sig-name,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .descclassname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .descname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .sig-name{font-family:SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace;color:#000}.rst-content .viewcode-back,.rst-content .viewcode-link{display:inline-block;color:#27ae60;font-size:80%;padding-left:24px}.rst-content .viewcode-back{display:block;float:right}.rst-content p.rubric{margin-bottom:12px;font-weight:700}.rst-content code.download,.rst-content tt.download{background:inherit;padding:inherit;font-weight:400;font-family:inherit;font-size:inherit;color:inherit;border:inherit;white-space:inherit}.rst-content code.download span:first-child,.rst-content tt.download span:first-child{-webkit-font-smoothing:subpixel-antialiased}.rst-content code.download span:first-child:before,.rst-content tt.download span:first-child:before{margin-right:4px}.rst-content .guilabel{border:1px solid #7fbbe3;background:#e7f2fa;font-size:80%;font-weight:700;border-radius:4px;padding:2.4px 6px;margin:auto 2px}.rst-content .versionmodified{font-style:italic}@media screen and (max-width:480px){.rst-content .sidebar{width:100%}}span[id*=MathJax-Span]{color:#404040}.math{text-align:center}@font-face{font-family:Lato;src:url(fonts/lato-normal.woff2?bd03a2cc277bbbc338d464e679fe9942) format("woff2"),url(fonts/lato-normal.woff?27bd77b9162d388cb8d4c4217c7c5e2a) format("woff");font-weight:400;font-style:normal;font-display:block}@font-face{font-family:Lato;src:url(fonts/lato-bold.woff2?cccb897485813c7c256901dbca54ecf2) format("woff2"),url(fonts/lato-bold.woff?d878b6c29b10beca227e9eef4246111b) format("woff");font-weight:700;font-style:normal;font-display:block}@font-face{font-family:Lato;src:url(fonts/lato-bold-italic.woff2?0b6bb6725576b072c5d0b02ecdd1900d) format("woff2"),url(fonts/lato-bold-italic.woff?9c7e4e9eb485b4a121c760e61bc3707c) format("woff");font-weight:700;font-style:italic;font-display:block}@font-face{font-family:Lato;src:url(fonts/lato-normal-italic.woff2?4eb103b4d12be57cb1d040ed5e162e9d) format("woff2"),url(fonts/lato-normal-italic.woff?f28f2d6482446544ef1ea1ccc6dd5892) format("woff");font-weight:400;font-style:italic;font-display:block}@font-face{font-family:Roboto Slab;font-style:normal;font-weight:400;src:url(fonts/Roboto-Slab-Regular.woff2?7abf5b8d04d26a2cafea937019bca958) format("woff2"),url(fonts/Roboto-Slab-Regular.woff?c1be9284088d487c5e3ff0a10a92e58c) 
format("woff");font-display:block}@font-face{font-family:Roboto Slab;font-style:normal;font-weight:700;src:url(fonts/Roboto-Slab-Bold.woff2?9984f4a9bda09be08e83f2506954adbe) format("woff2"),url(fonts/Roboto-Slab-Bold.woff?bed5564a116b05148e3b3bea6fb1162a) format("woff");font-display:block}
\ No newline at end of file
diff --git a/docs/_static/doctools.js b/docs/_static/doctools.js
new file mode 100644
index 00000000..c3db08d1
--- /dev/null
+++ b/docs/_static/doctools.js
@@ -0,0 +1,264 @@
+/*
+ * doctools.js
+ * ~~~~~~~~~~~
+ *
+ * Base JavaScript utilities for all Sphinx HTML documentation.
+ *
+ * :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS.
+ * :license: BSD, see LICENSE for details.
+ *
+ */
+"use strict";
+
+const _ready = (callback) => {
+ if (document.readyState !== "loading") {
+ callback();
+ } else {
+ document.addEventListener("DOMContentLoaded", callback);
+ }
+};
+
+/**
+ * highlight a given string on a node by wrapping it in
+ * span elements with the given class name.
+ */
+const _highlight = (node, addItems, text, className) => {
+ if (node.nodeType === Node.TEXT_NODE) {
+ const val = node.nodeValue;
+ const parent = node.parentNode;
+ const pos = val.toLowerCase().indexOf(text);
+ if (
+ pos >= 0 &&
+ !parent.classList.contains(className) &&
+ !parent.classList.contains("nohighlight")
+ ) {
+ let span;
+
+ const closestNode = parent.closest("body, svg, foreignObject");
+ const isInSVG = closestNode && closestNode.matches("svg");
+ if (isInSVG) {
+ span = document.createElementNS("http://www.w3.org/2000/svg", "tspan");
+ } else {
+ span = document.createElement("span");
+ span.classList.add(className);
+ }
+
+ span.appendChild(document.createTextNode(val.substr(pos, text.length)));
+ parent.insertBefore(
+ span,
+ parent.insertBefore(
+ document.createTextNode(val.substr(pos + text.length)),
+ node.nextSibling
+ )
+ );
+ node.nodeValue = val.substr(0, pos);
+
+ if (isInSVG) {
+ const rect = document.createElementNS(
+ "http://www.w3.org/2000/svg",
+ "rect"
+ );
+ const bbox = parent.getBBox();
+ rect.x.baseVal.value = bbox.x;
+ rect.y.baseVal.value = bbox.y;
+ rect.width.baseVal.value = bbox.width;
+ rect.height.baseVal.value = bbox.height;
+ rect.setAttribute("class", className);
+ addItems.push({ parent: parent, target: rect });
+ }
+ }
+ } else if (node.matches && !node.matches("button, select, textarea")) {
+ node.childNodes.forEach((el) => _highlight(el, addItems, text, className));
+ }
+};
+const _highlightText = (thisNode, text, className) => {
+ let addItems = [];
+ _highlight(thisNode, addItems, text, className);
+ addItems.forEach((obj) =>
+ obj.parent.insertAdjacentElement("beforebegin", obj.target)
+ );
+};
+
+/**
+ * Small JavaScript module for the documentation.
+ */
+const Documentation = {
+ init: () => {
+ Documentation.highlightSearchWords();
+ Documentation.initDomainIndexTable();
+ Documentation.initOnKeyListeners();
+ },
+
+ /**
+ * i18n support
+ */
+ TRANSLATIONS: {},
+ PLURAL_EXPR: (n) => (n === 1 ? 0 : 1),
+ LOCALE: "unknown",
+
+ // gettext and ngettext don't access this so that the functions
+ // can safely be bound to a different name (_ = Documentation.gettext)
+ gettext: (string) => {
+ const translated = Documentation.TRANSLATIONS[string];
+ switch (typeof translated) {
+ case "undefined":
+ return string; // no translation
+ case "string":
+ return translated; // translation exists
+ default:
+ return translated[0]; // (singular, plural) translation tuple exists
+ }
+ },
+
+ ngettext: (singular, plural, n) => {
+ const translated = Documentation.TRANSLATIONS[singular];
+ if (typeof translated !== "undefined")
+ return translated[Documentation.PLURAL_EXPR(n)];
+ return n === 1 ? singular : plural;
+ },
+
+ addTranslations: (catalog) => {
+ Object.assign(Documentation.TRANSLATIONS, catalog.messages);
+ Documentation.PLURAL_EXPR = new Function(
+ "n",
+ `return (${catalog.plural_expr})`
+ );
+ Documentation.LOCALE = catalog.locale;
+ },
+
+ /**
+ * highlight the search words provided in the url in the text
+ */
+ highlightSearchWords: () => {
+ const highlight =
+ new URLSearchParams(window.location.search).get("highlight") || "";
+ const terms = highlight.toLowerCase().split(/\s+/).filter(x => x);
+ if (terms.length === 0) return; // nothing to do
+
+ // There should never be more than one element matching "div.body"
+ const divBody = document.querySelectorAll("div.body");
+ const body = divBody.length ? divBody[0] : document.querySelector("body");
+ window.setTimeout(() => {
+ terms.forEach((term) => _highlightText(body, term, "highlighted"));
+ }, 10);
+
+ const searchBox = document.getElementById("searchbox");
+ if (searchBox === null) return;
+ searchBox.appendChild(
+ document
+ .createRange()
+ .createContextualFragment(
+ '