diff --git a/WORKSPACE b/WORKSPACE
index 5eb4e72dc..6a7331046 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -290,6 +290,17 @@ http_archive(
     ],
 )
 
+http_archive(
+    name = "daos",
+    build_file = "//third_party:daos.BUILD",
+    sha256 = "9789a5a0065cfa4249105f1676b9eba89f68b54bc03083140549b7a8a8f615d3",
+    strip_prefix = "daos-2.0.2",
+    urls = [
+        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/daos-stack/daos/archive/refs/tags/v2.0.2.tar.gz",
+        "https://github.com/daos-stack/daos/archive/refs/tags/v2.0.2.tar.gz",
+    ],
+)
+
 http_archive(
     name = "dav1d",
     build_file = "//third_party:dav1d.BUILD",
diff --git a/docs/daos_tf_docs.md b/docs/daos_tf_docs.md
new file mode 100644
index 000000000..0fe232a4e
--- /dev/null
+++ b/docs/daos_tf_docs.md
@@ -0,0 +1,107 @@
+# DAOS-TensorFlow IO GUIDE
+
+## Table Of Contents
+
+- [Features](#features)
+- [Prerequisites](#prerequisites)
+- [Environment Setup](#environment-setup)
+- [Building](#building)
+- [Testing](#testing)
+- [Example](#example)
+
+## Features
+
+* A filesystem plugin that uses the DAOS DFS layer for efficient access to Intel's DAOS filesystem. The Read-Only Memory Region interface remains unsupported.
+
+## Prerequisites
+
+* A valid DAOS installation, currently based on [version v2.0.2](https://github.com/daos-stack/daos/releases/tag/v2.0.2)
+  * An installation guide is available [here](https://docs.daos.io/admin/installation/)
+
+## Environment Setup
+
+Assuming you are in a terminal in the repository root directory:
+
+* Install the latest versions of the following dependencies
+  * CentOS 8
+    ```
+    $ yum install -y python3 python3-devel gcc gcc-c++ git unzip which make
+    ```
+  * Ubuntu 20.04
+    ```
+    $ sudo apt-get -y -qq update
+    $ sudo apt-get -y -qq install gcc g++ git unzip curl python3-pip
+    ```
+* Download the Bazel installer
+  ```
+  $ curl -sSOL https://github.com/bazelbuild/bazel/releases/download/$(cat .bazelversion)/bazel-$(cat .bazelversion)-installer-linux-x86_64.sh
+  ```
+* Install Bazel
+  ```
+  $ bash -x -e bazel-$(cat .bazelversion)-installer-linux-x86_64.sh
+  ```
+* Update pip and install pytest
+  ```
+  $ python3 -m pip install -U pip
+  $ python3 -m pip install pytest
+  ```
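+* (Optional) Sanity-check that the DAOS client tools and libraries are visible before building. This is a minimal sketch, assuming DAOS was installed system-wide and registered with the dynamic linker; if DAOS lives in a custom prefix, the libraries will only be found once the environment variables described in the Building section are set.
+  ```
+  # Quick sanity check of the DAOS installation (adjust to your setup).
+  $ command -v dmg || echo "dmg not found in PATH"
+  $ command -v daos || echo "daos not found in PATH"
+  $ ldconfig -p | grep -E "libdaos|libdfs" || echo "DAOS libraries not visible to the dynamic linker"
+  ```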
+
+## Building
+
+Assuming you are in a terminal in the repository root directory:
+
+* Configure and install TensorFlow (the currently supported version is TensorFlow 2.6.2)
+  ```
+  $ ./configure.sh
+  ## Set python3 as default.
+  $ ln -s /usr/bin/python3 /usr/bin/python
+  ```
+
+* At this point, all libraries and dependencies should be installed.
+  * Make sure the environment variable **LIBRARY_PATH** includes the paths to all DAOS libraries
+  * Make sure the environment variable **LD_LIBRARY_PATH** includes the paths to:
+    * All DAOS libraries
+    * The TensorFlow framework (libtensorflow and libtensorflow_framework)
+    * If not, find the required libraries and add their paths to the environment variable
+      ```
+      export LD_LIBRARY_PATH="<required-library-paths>:$LD_LIBRARY_PATH"
+      ```
+  * Make sure the environment variables **CPLUS_INCLUDE_PATH** and **C_INCLUDE_PATH** include the path to:
+    * The TensorFlow headers (usually in /usr/local/lib64/python3.6/site-packages/tensorflow/include)
+    * If not, find the required headers and add their paths to the environment variables
+      ```
+      export CPLUS_INCLUDE_PATH="<required-include-paths>:$CPLUS_INCLUDE_PATH"
+      export C_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:$C_INCLUDE_PATH
+      ```
+
+* Build the project using Bazel
+  ```
+  bazel build --action_env=LIBRARY_PATH=$LIBRARY_PATH -s --verbose_failures --spawn_strategy=standalone //tensorflow_io/... //tensorflow_io_gcs_filesystem/...
+  ```
+  This should take a few minutes. Note that sandboxing may result in build failures when using Docker containers for DAOS due to mounting issues; the **--spawn_strategy=standalone** flag in the command above bypasses sandboxing for this reason. (When the sandbox is disabled, an error may be thrown for an undefined type z_crc_t due to a conflict in header files. If that happens, find the crypt.h file in the Bazel cache under /external/zlib/contrib/minizip/crypt.h, add the line **typedef unsigned long z_crc_t;** to it, then re-build.)
+
+## Testing
+
+Assuming you are in a terminal in the repository root directory:
+
+* Run the following command for the simple serial test to validate the build. Note that tests need to be run with the TFIO_DATAPATH variable set to the location of the built binaries.
+  ```
+  $ TFIO_DATAPATH=bazel-bin python3 -m pytest -s -v tests/test_serialization.py
+  ```
+
+* Run the following commands to run the dfs plugin tests:
+  ```
+  # Create the required pool and container and export the environment variables needed by the dfs tests.
+  $ source tests/test_dfs/dfs_init.sh
+  # Run the dfs tests.
+  $ TFIO_DATAPATH=bazel-bin python3 -m pytest -s -v tests/test_dfs.py
+  # Cleanup: delete the pools and containers created for the tests.
+  $ bash ./tests/test_dfs/dfs_cleanup.sh
+  ```
+
+## Example
+
+Please refer to [the DAOS notebook example in the docs/tutorials folder](tutorials/daos.ipynb).
+
diff --git a/docs/tutorials/daos.ipynb b/docs/tutorials/daos.ipynb
new file mode 100644
index 000000000..102e26326
--- /dev/null
+++ b/docs/tutorials/daos.ipynb
@@ -0,0 +1,583 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "c774b6fdf4b1"
+   },
+   "source": [
+    "##### Copyright 2022 The TensorFlow IO Authors."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "906e07f6e562" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "id": "1f9e30da", + "metadata": { + "id": "7857033a12ad" + }, + "source": [ + "# DAOS Filesystem with Tensorflow (Using MNIST)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "41881aed035d" + }, + "source": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + " \n", + " Download notebook\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "22b37505", + "metadata": { + "id": "e5708f933d28" + }, + "source": [ + "## Overview\n", + "\n", + "This tutorial shows how to use read and write files on [DAOS Filesystem](https://docs.daos.io/) with TensorFlow, through TensorFlow IO's DAOS file system integration.\n", + "\n", + "A machine running DAOS natively or through a [docker emulator](https://github.com/daos-stack/daos/tree/master/utils/docker) is needed to run this tutorial and/or use the Tensorflow IO DAOS integration. The DAOS Pool and Container used for this tutorial will be created and deleted within this tutorial, where you will be training and testing a simple Neural Network on the MNIST Dataset loaded from the DAOS File System Plugin.\n", + "\n", + "The pool and container id or label are part of the filename uri:\n", + "```\n", + "daos:////\n", + "daos:///cont-label/\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "8a3ef6a9", + "metadata": { + "id": "b27be7087d0e" + }, + "source": [ + "## Setup and usage" + ] + }, + { + "cell_type": "markdown", + "id": "1ad41a1a", + "metadata": { + "id": "e20c1d316af6" + }, + "source": [ + "### Install required packages, and restart runtime" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "5e35916b", + "metadata": { + "id": "5de1951509cb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Defaulting to user installation because normal site-packages is not writeable\n", + "Requirement already satisfied: tensorflow-io in /home/omar/.local/lib/python3.8/site-packages (0.20.0)\n", + "Requirement already satisfied: tensorflow<2.7.0,>=2.6.0 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow-io) (2.6.0)\n", + "Requirement already satisfied: tensorflow-io-gcs-filesystem==0.20.0 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow-io) (0.20.0)\n", + "Requirement already satisfied: gast==0.4.0 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (0.4.0)\n", + "Requirement already satisfied: grpcio<2.0,>=1.37.0 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (1.39.0)\n", + "Requirement already satisfied: protobuf>=3.9.2 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (3.17.3)\n", + "Requirement already satisfied: tensorboard~=2.6 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (2.6.0)\n", + "Requirement already satisfied: tensorflow-estimator~=2.6 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (2.6.0)\n", + "Requirement already satisfied: typing-extensions~=3.7.4 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (3.7.4.3)\n", + "Requirement already satisfied: termcolor~=1.1.0 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (1.1.0)\n", + "Requirement already satisfied: wrapt~=1.12.1 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (1.12.1)\n", + "Requirement already satisfied: google-pasta~=0.2 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (0.2.0)\n", + "Requirement already satisfied: keras~=2.6 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (2.6.0)\n", + "Requirement already satisfied: 
six~=1.15.0 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (1.15.0)\n", + "Requirement already satisfied: numpy~=1.19.2 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (1.19.5)\n", + "Requirement already satisfied: opt-einsum~=3.3.0 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (3.3.0)\n", + "Requirement already satisfied: wheel~=0.35 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (0.37.0)\n", + "Requirement already satisfied: astunparse~=1.6.3 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (1.6.3)\n", + "Requirement already satisfied: absl-py~=0.10 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (0.13.0)\n", + "Requirement already satisfied: keras-preprocessing~=1.1.2 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (1.1.2)\n", + "Requirement already satisfied: h5py~=3.1.0 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (3.1.0)\n", + "Requirement already satisfied: clang~=5.0 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (5.0)\n", + "Requirement already satisfied: flatbuffers~=1.12.0 in /home/omar/.local/lib/python3.8/site-packages (from tensorflow<2.7.0,>=2.6.0->tensorflow-io) (1.12)\n", + "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /home/omar/.local/lib/python3.8/site-packages (from tensorboard~=2.6->tensorflow<2.7.0,>=2.6.0->tensorflow-io) (0.4.6)\n", + "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /home/omar/.local/lib/python3.8/site-packages (from tensorboard~=2.6->tensorflow<2.7.0,>=2.6.0->tensorflow-io) (1.8.0)\n", + "Requirement already satisfied: markdown>=2.6.8 in /home/omar/.local/lib/python3.8/site-packages (from tensorboard~=2.6->tensorflow<2.7.0,>=2.6.0->tensorflow-io) (3.3.4)\n", + "Requirement already satisfied: google-auth<2,>=1.6.3 in /home/omar/.local/lib/python3.8/site-packages (from tensorboard~=2.6->tensorflow<2.7.0,>=2.6.0->tensorflow-io) (1.35.0)\n", + "Requirement already satisfied: setuptools>=41.0.0 in /usr/lib/python3/dist-packages (from tensorboard~=2.6->tensorflow<2.7.0,>=2.6.0->tensorflow-io) (45.2.0)\n", + "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /home/omar/.local/lib/python3.8/site-packages (from tensorboard~=2.6->tensorflow<2.7.0,>=2.6.0->tensorflow-io) (0.6.1)\n", + "Requirement already satisfied: requests<3,>=2.21.0 in /usr/lib/python3/dist-packages (from tensorboard~=2.6->tensorflow<2.7.0,>=2.6.0->tensorflow-io) (2.22.0)\n", + "Requirement already satisfied: werkzeug>=0.11.15 in /home/omar/.local/lib/python3.8/site-packages (from tensorboard~=2.6->tensorflow<2.7.0,>=2.6.0->tensorflow-io) (2.0.1)\n", + "Requirement already satisfied: pyasn1-modules>=0.2.1 in /home/omar/.local/lib/python3.8/site-packages (from google-auth<2,>=1.6.3->tensorboard~=2.6->tensorflow<2.7.0,>=2.6.0->tensorflow-io) (0.2.8)\n", + "Requirement already satisfied: rsa<5,>=3.1.4 in /home/omar/.local/lib/python3.8/site-packages (from google-auth<2,>=1.6.3->tensorboard~=2.6->tensorflow<2.7.0,>=2.6.0->tensorflow-io) (4.7.2)\n", + "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /home/omar/.local/lib/python3.8/site-packages (from 
google-auth<2,>=1.6.3->tensorboard~=2.6->tensorflow<2.7.0,>=2.6.0->tensorflow-io) (4.2.2)\n", + "Requirement already satisfied: requests-oauthlib>=0.7.0 in /home/omar/.local/lib/python3.8/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.6->tensorflow<2.7.0,>=2.6.0->tensorflow-io) (1.3.0)\n", + "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /home/omar/.local/lib/python3.8/site-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard~=2.6->tensorflow<2.7.0,>=2.6.0->tensorflow-io) (0.4.8)\n", + "Requirement already satisfied: oauthlib>=3.0.0 in /usr/lib/python3/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.6->tensorflow<2.7.0,>=2.6.0->tensorflow-io) (3.1.0)\n", + "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 22.0.4 is available.\n", + "You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n" + ] + } + ], + "source": [ + "try:\n", + " %tensorflow_version 2.x \n", + "except Exception:\n", + " pass\n", + "\n", + "!pip install tensorflow-io" + ] + }, + { + "cell_type": "markdown", + "id": "bf7de300", + "metadata": { + "id": "d5e736c41c99" + }, + "source": [ + "### Create Pool and Container" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "79528fed", + "metadata": { + "id": "fb83b02da201" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/usr/bin/sh: 1: dmg: not found\n", + "/usr/bin/sh: 1: daos: not found\n" + ] + } + ], + "source": [ + "!dmg -i pool create -s 500M TEST_POOL\n", + "!daos cont create --pool=TEST_POOL --type=POSIX TEST_CONT" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f9e03445ca2b" + }, + "source": [ + "Importing the needed libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "c9d707f548ed" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_io as tfio" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f5707958e9b2" + }, + "source": [ + "Initializing our dfs path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8b3d02b4bdce" + }, + "outputs": [], + "source": [ + "dfs_url = \"daos://TEST_POOL/TEST_CONT/\" # This the path you'll be using to load and access the dataset\n", + "pwd = !pwd\n", + "posix_url = pwd[0] + \"/tests/test_dfs/\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cb041488c5cc" + }, + "source": [ + "Install Datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3d67936c99b0" + }, + "outputs": [], + "source": [ + "!wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz -O $(pwd)/tests/test_dfs/train.gz\n", + "!wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz -O $(pwd)/tests/test_dfs/train_labels.gz\n", + "!wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz -O $(pwd)/tests/test_dfs/test.gz\n", + "!wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz -O $(pwd)/tests/test_dfs/test_labels.gz" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b40e9a30808c" + }, + "source": [ + "Copying the Data from the POSIX Filesystem to the DAOS Filesystem under the pool and container you just created" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "c7a9cb50149f" + }, + "outputs": [], + "source": [ + "file_names = [\"train.gz\", \"test.gz\", 
\"train_labels.gz\", \"test_labels.gz\"]\n", + "for file in file_names:\n", + " tf.io.gfile.copy(posix_url + file, dfs_url + file, True)\n", + " \n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1879a8438874" + }, + "source": [ + "Checking Our Training Images and Training Labels Exist under the specified pool and container" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4831a44b46c4" + }, + "outputs": [], + "source": [ + "images = dfs_url + \"train.gz\"\n", + "labels = dfs_url + \"train_labels.gz\"\n", + "if tf.io.gfile.exists(images) and tf.io.gfile.exists(labels):\n", + " print(\"True\")\n", + "else:\n", + " print(\"False\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d7453b98fc59" + }, + "source": [ + "Loading MNIST Data from the DFS using tensorflow-io's built in MNIST loading functionality" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4b904e495cf5" + }, + "outputs": [], + "source": [ + "d_train = tfio.IODataset.from_mnist(\n", + " images,\n", + " labels\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "43e0588f94e4" + }, + "source": [ + "Pre-processing and Building a simple Keras Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "250109055744" + }, + "outputs": [], + "source": [ + "# Shuffle the elements of the dataset.\n", + "d_train = d_train.shuffle(buffer_size=1024)\n", + "\n", + "# By default image data is uint8, so convert to float32 using map().\n", + "d_train = d_train.map(lambda x, y: (tf.image.convert_image_dtype(x, tf.float32), y))\n", + "\n", + "# prepare batches the data just like any other tf.data.Dataset\n", + "d_train = d_train.batch(32)\n", + "\n", + "# Build the model.\n", + "model = tf.keras.models.Sequential(\n", + " [\n", + " tf.keras.layers.Flatten(input_shape=(28, 28)),\n", + " tf.keras.layers.Dense(512, activation=tf.nn.relu),\n", + " tf.keras.layers.Dropout(0.2),\n", + " tf.keras.layers.Dense(10, activation=tf.nn.softmax),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4129e8e2c1b4" + }, + "source": [ + "Compiling the model you just built" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7c9302dea1da" + }, + "outputs": [], + "source": [ + "model.compile(\n", + " optimizer=\"adam\", loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6193df954a7c" + }, + "source": [ + "And finally, training on the dataset for 5 epochs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "629b3388da02" + }, + "outputs": [], + "source": [ + "history = model.fit(d_train, epochs=15, steps_per_epoch=100)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bfb6caec04ac" + }, + "source": [ + "Plot of Loss vs Epoch during Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "14d751d25850" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.plot(history.history['loss'])\n", + "plt.title('model loss')\n", + "plt.ylabel('loss')\n", + "plt.xlabel('epoch')\n", + "plt.legend(['train'], loc='upper left')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fdcea14afec8" + }, + "source": [ + "Check Test Data is available" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "metadata": { + "id": "9c235960fd75" + }, + "outputs": [], + "source": [ + "test_images = dfs_url + \"test.gz\"\n", + "test_labels = dfs_url + \"test_labels.gz\"\n", + "if tf.io.gfile.exists(test_images) and tf.io.gfile.exists(test_labels):\n", + " print(\"True\")\n", + "else:\n", + " print(\"False\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7bcaff8a9330" + }, + "source": [ + "Apply same pre-processing and batching on test data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ecd03ced1c6a" + }, + "outputs": [], + "source": [ + "d_test = tfio.IODataset.from_mnist(\n", + " test_images,\n", + " test_labels,\n", + ")\n", + "\n", + "# Shuffle the elements of the dataset.\n", + "d_test = d_test.shuffle(buffer_size=1024)\n", + "\n", + "# By default image data is uint8, so convert to float32 using map().\n", + "d_test = d_test.map(lambda x, y: (tf.image.convert_image_dtype(x, tf.float32), y))\n", + "\n", + "# prepare batches the data just like any other tf.data.Dataset\n", + "d_test = d_test.batch(32)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c4684eb989dd" + }, + "source": [ + "Evaluate our model on both test and train data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "de7d13fa71ef" + }, + "outputs": [], + "source": [ + "_, train_acc = model.evaluate(d_train, verbose=0)\n", + "_, test_acc = model.evaluate(d_test, verbose=0)\n", + "print('Train: %.3f, Test: %.3f' % (train_acc, test_acc))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6a006d7d2fcb" + }, + "source": [ + "Prediction Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "701b75aea6c5" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "iterator = iter(d_test)\n", + "elem = iterator.get_next()[0][0]\n", + "plt.imshow(elem)\n", + "prediction = model.predict(np.array([elem]))\n", + "result = np.where(prediction[0] == np.amax(prediction[0]))\n", + "print(\"Predicted Value is\" ,result[0][0])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "21b12f92e0a2" + }, + "source": [ + "### Cleanup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "d847aaa57005" + }, + "outputs": [], + "source": [ + "!dmg -i pool destroy -f TEST_POOL" + ] + } + ], + "metadata": { + "colab": { + "name": "daos.ipynb", + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow_io/core/filesystems/BUILD b/tensorflow_io/core/filesystems/BUILD index 780b6a2e4..e1a4bc1c3 100644 --- a/tensorflow_io/core/filesystems/BUILD +++ b/tensorflow_io/core/filesystems/BUILD @@ -39,6 +39,12 @@ cc_library( "//tensorflow_io/core/filesystems/hdfs", "//tensorflow_io/core/filesystems/http", "//tensorflow_io/core/filesystems/s3", - ], + ] + select({ + "@bazel_tools//src/conditions:windows": [], + "@bazel_tools//src/conditions:darwin": [], + "//conditions:default": [ + "//tensorflow_io/core/filesystems/dfs", + ], + }), alwayslink = 1, ) diff --git a/tensorflow_io/core/filesystems/dfs/BUILD b/tensorflow_io/core/filesystems/dfs/BUILD new file mode 100644 index 000000000..8fbdc704d --- /dev/null +++ b/tensorflow_io/core/filesystems/dfs/BUILD @@ -0,0 +1,35 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +load( + "//:tools/build/tensorflow_io.bzl", + "tf_io_copts", +) + 
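+# The dfs (DAOS) filesystem plugin is only built on Linux; the select() branches below leave the sources and the @daos dependency empty on Windows and macOS.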
+cc_library( + name = "dfs", + srcs = [] + select({ + "@bazel_tools//src/conditions:windows": [], + "@bazel_tools//src/conditions:darwin": [], + "//conditions:default": [ + "dfs_filesystem.cc", + "dfs_utils.cc", + "dfs_utils.h", + ], + }), + copts = tf_io_copts(), + linkstatic = True, + deps = [ + "//tensorflow_io/core/filesystems:filesystem_plugins_header", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + ] + select({ + "@bazel_tools//src/conditions:windows": [], + "@bazel_tools//src/conditions:darwin": [], + "//conditions:default": [ + "@daos", + ], + }), + alwayslink = 1, +) diff --git a/tensorflow_io/core/filesystems/dfs/dfs_filesystem.cc b/tensorflow_io/core/filesystems/dfs/dfs_filesystem.cc new file mode 100644 index 000000000..d3bbb1de5 --- /dev/null +++ b/tensorflow_io/core/filesystems/dfs/dfs_filesystem.cc @@ -0,0 +1,785 @@ +#include + +#include "tensorflow_io/core/filesystems/dfs/dfs_utils.h" +#undef NDEBUG +#include + +namespace tensorflow { +namespace io { +namespace dfs { + +// SECTION 1. Implementation for `TF_RandomAccessFile` +// ---------------------------------------------------------------------------- +namespace tf_random_access_file { +typedef struct DFSRandomAccessFile { + dfs_path_t dpath; + DFS* daos; + dfs_t* daos_fs; + dfs_obj_t* daos_file; + std::vector buffers; + daos_size_t file_size; + bool caching; + size_t buff_size; + size_t num_of_buffers; + DFSRandomAccessFile(dfs_path_t* path, dfs_obj_t* obj, daos_handle_t eq_handle) + : dpath(*path) { + daos = dpath.getDAOS(); + daos_fs = dpath.getFsys(); + daos_file = obj; + + if (dpath.getCachedSize(file_size) != 0) { + daos->libdfs->dfs_get_size(daos_fs, obj, &file_size); + dpath.setCachedSize(file_size); + } + if (char* env_caching = std::getenv("TF_IO_DAOS_CACHING")) { + caching = atoi(env_caching) > 0; + } else { + caching = false; + } + + if (caching) { + if (char* env_num_of_buffers = std::getenv("TF_IO_DAOS_NUM_OF_BUFFERS")) { + num_of_buffers = atoi(env_num_of_buffers); + } else { + num_of_buffers = NUM_OF_BUFFERS; + } + + if (char* env_buff_size = std::getenv("TF_IO_DAOS_BUFFER_SIZE")) { + buff_size = GetStorageSize(env_buff_size); + } else { + buff_size = BUFF_SIZE; + } + for (size_t i = 0; i < num_of_buffers; i++) { + buffers.push_back(ReadBuffer(i, daos, eq_handle, buff_size)); + } + } + } + + int64_t ReadNoCache(uint64_t offset, size_t n, char* buffer, + TF_Status* status) { + int rc; + d_sg_list_t rsgl; + d_iov_t iov; + d_iov_set(&iov, (void*)buffer, n); + rsgl.sg_nr = 1; + rsgl.sg_iovs = &iov; + + daos_size_t read_size; + + rc = daos->libdfs->dfs_read(daos_fs, daos_file, &rsgl, offset, &read_size, + NULL); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, ""); + return read_size; + } + + if (read_size != n) { + TF_SetStatus(status, TF_OUT_OF_RANGE, ""); + return read_size; + } + + TF_SetStatus(status, TF_OK, ""); + return read_size; + } +} DFSRandomAccessFile; + +void Cleanup(TF_RandomAccessFile* file) { + int rc = 0; + auto dfs_file = static_cast(file->plugin_file); + dfs_file->buffers.clear(); + + rc = dfs_file->daos->libdfs->dfs_release(dfs_file->daos_file); + assert(rc == 0); + dfs_file->daos_fs = nullptr; + delete dfs_file; +} + +int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, + char* ret, TF_Status* status) { + auto dfs_file = static_cast(file->plugin_file); + if (offset >= dfs_file->file_size) { + TF_SetStatus(status, TF_OUT_OF_RANGE, ""); + return -1; + } + + if (!dfs_file->caching) { + return dfs_file->ReadNoCache(offset, n, ret, status); 
+ } + + size_t ret_offset = 0; + size_t curr_offset = offset; + int64_t total_bytes = 0; + size_t ret_size = offset + n; + while (curr_offset < ret_size && curr_offset < dfs_file->file_size) { + int64_t read_bytes = 0; + for (auto& read_buf : dfs_file->buffers) { + if (read_buf.CacheHit(curr_offset)) { + read_bytes = read_buf.CopyFromCache(ret, ret_offset, curr_offset, n, + dfs_file->file_size, status); + break; + } + } + + if (read_bytes < 0) { + return -1; + } + + if (read_bytes > 0) { + curr_offset += read_bytes; + ret_offset += read_bytes; + total_bytes += read_bytes; + n -= read_bytes; + continue; + } + + size_t async_offset = curr_offset; + for (size_t i = 0; i < dfs_file->buffers.size(); i++) { + if (async_offset > dfs_file->file_size) break; + dfs_file->buffers[i].ReadAsync(dfs_file->daos_fs, dfs_file->daos_file, + async_offset, dfs_file->file_size); + async_offset += dfs_file->buff_size; + } + } + + return total_bytes; +} + +} // namespace tf_random_access_file + +// SECTION 2. Implementation for `TF_WritableFile` +// ---------------------------------------------------------------------------- +namespace tf_writable_file { +typedef struct DFSWritableFile { + dfs_path_t dpath; + DFS* daos; + dfs_t* daos_fs; + dfs_obj_t* daos_file; + daos_size_t file_size; + bool size_known; + + DFSWritableFile(dfs_path_t* path, dfs_obj_t* obj) : dpath(*path) { + daos = dpath.getDAOS(); + daos_fs = dpath.getFsys(); + daos_file = obj; + size_known = false; + daos_size_t dummy; // initialize file_size + get_file_size(dummy); + } + + int get_file_size(daos_size_t& size) { + if (!size_known) { + int rc = daos->libdfs->dfs_get_size(daos_fs, daos_file, &file_size); + if (rc != 0) { + return rc; + } + dpath.setCachedSize(file_size); + size_known = true; + } + size = file_size; + return 0; + } + + void set_file_size(daos_size_t size) { + dpath.setCachedSize(size); + file_size = size; + size_known = true; + } + + void unset_file_size(void) { + dpath.clearCachedSize(); + size_known = false; + } +} DFSWritableFile; + +void Cleanup(TF_WritableFile* file) { + auto dfs_file = static_cast(file->plugin_file); + dfs_file->daos->libdfs->dfs_release(dfs_file->daos_file); + dfs_file->daos_fs = nullptr; + delete dfs_file; +} + +void Append(const TF_WritableFile* file, const char* buffer, size_t n, + TF_Status* status) { + d_sg_list_t wsgl; + d_iov_t iov; + int rc; + auto dfs_file = static_cast(file->plugin_file); + + d_iov_set(&iov, (void*)buffer, n); + wsgl.sg_nr = 1; + wsgl.sg_iovs = &iov; + + daos_size_t cur_file_size; + rc = dfs_file->get_file_size(cur_file_size); + if (rc != 0) { + TF_SetStatus(status, TF_INTERNAL, "Cannot determine file size"); + return; + } + + rc = dfs_file->daos->libdfs->dfs_write(dfs_file->daos_fs, dfs_file->daos_file, + &wsgl, cur_file_size, NULL); + if (rc) { + TF_SetStatus(status, TF_RESOURCE_EXHAUSTED, ""); + dfs_file->unset_file_size(); + return; + } + + dfs_file->set_file_size(cur_file_size + n); + TF_SetStatus(status, TF_OK, ""); +} + +int64_t Tell(const TF_WritableFile* file, TF_Status* status) { + auto dfs_file = static_cast(file->plugin_file); + + daos_size_t cur_file_size; + int rc = dfs_file->get_file_size(cur_file_size); + if (rc != 0) { + TF_SetStatus(status, TF_INTERNAL, "Cannot determine file size"); + return -1; + } + + TF_SetStatus(status, TF_OK, ""); + return cur_file_size; +} + +void Close(const TF_WritableFile* file, TF_Status* status) { + auto dfs_file = static_cast(file->plugin_file); + dfs_file->daos->libdfs->dfs_release(dfs_file->daos_file); + dfs_file->daos_fs = 
nullptr; + dfs_file->daos_file = nullptr; + TF_SetStatus(status, TF_OK, ""); +} + +} // namespace tf_writable_file + +// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion` +// ---------------------------------------------------------------------------- +namespace tf_read_only_memory_region { +void Cleanup(TF_ReadOnlyMemoryRegion* region) {} + +const void* Data(const TF_ReadOnlyMemoryRegion* region) { return nullptr; } + +uint64_t Length(const TF_ReadOnlyMemoryRegion* region) { return 0; } + +} // namespace tf_read_only_memory_region + +// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem +// ---------------------------------------------------------------------------- +namespace tf_dfs_filesystem { + +void atexit_handler(void); // forward declaration + +static TF_Filesystem* dfs_filesystem; + +void Init(TF_Filesystem* filesystem, TF_Status* status) { + filesystem->plugin_filesystem = new (std::nothrow) DFS(status); + + // tensorflow never calls Cleanup(), see + // https://github.com/tensorflow/tensorflow/issues/27535 + // The workaround is to implement its code via atexit() which in turn + // requires that a static pointer to the plugin be kept for use at exit time. + if (TF_GetCode(status) == TF_OK) { + dfs_filesystem = filesystem; + std::atexit(atexit_handler); + } +} + +void Cleanup(TF_Filesystem* filesystem) { + auto daos = static_cast(filesystem->plugin_filesystem); + delete daos; +} + +void atexit_handler(void) { + // delete dfs_filesystem; + Cleanup(dfs_filesystem); +} + +void NewFile(const TF_Filesystem* filesystem, const char* path, File_Mode mode, + int flags, dfs_path_t& dpath, dfs_obj_t** obj, TF_Status* status) { + int rc; + auto daos = static_cast(filesystem->plugin_filesystem); + + rc = daos->Setup(daos, path, dpath, status); + if (rc) return; + + daos->dfsNewFile(&dpath, mode, flags, obj, status); +} + +void NewWritableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + + dfs_path_t dpath; + dfs_obj_t* obj = NULL; + NewFile(filesystem, path, WRITE, S_IRUSR | S_IWUSR | S_IFREG, dpath, &obj, + status); + if (TF_GetCode(status) != TF_OK) return; + + file->plugin_file = new tf_writable_file::DFSWritableFile(&dpath, obj); +} + +void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, + TF_RandomAccessFile* file, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + auto daos = static_cast(filesystem->plugin_filesystem); + + dfs_path_t dpath; + dfs_obj_t* obj = NULL; + NewFile(filesystem, path, READ, S_IRUSR | S_IFREG, dpath, &obj, status); + if (TF_GetCode(status) != TF_OK) return; + + auto random_access_file = new tf_random_access_file::DFSRandomAccessFile( + &dpath, obj, daos->mEventQueueHandle); + if (random_access_file->caching) { + size_t async_offset = 0; + for (size_t i = 0; i < random_access_file->num_of_buffers; i++) { + if (async_offset > random_access_file->file_size) break; + random_access_file->buffers[i].ReadAsync( + random_access_file->daos_fs, random_access_file->daos_file, + async_offset, random_access_file->file_size); + async_offset += random_access_file->buff_size; + } + } + file->plugin_file = random_access_file; +} + +void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + + dfs_path_t dpath; + dfs_obj_t* obj = NULL; + NewFile(filesystem, path, APPEND, S_IRUSR | S_IWUSR | S_IFREG, dpath, &obj, + status); + if 
(TF_GetCode(status) != TF_OK) return; + + file->plugin_file = new tf_writable_file::DFSWritableFile(&dpath, obj); +} + +static void PathExists(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + int rc; + auto daos = static_cast(filesystem->plugin_filesystem); + + dfs_path_t dpath; + rc = daos->Setup(daos, path, dpath, status); + if (rc) return; + + dfs_obj_t* obj; + rc = daos->dfsLookUp(&dpath, &obj, status); + if (rc) return; + + rc = daos->libdfs->dfs_release(obj); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, ""); + } +} + +static void CreateDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + int rc; + auto daos = static_cast(filesystem->plugin_filesystem); + + dfs_path_t dpath; + rc = daos->Setup(daos, path, dpath, status); + if (rc) return; + + daos->dfsCreateDir(&dpath, status); +} + +static void RecursivelyCreateDir(const TF_Filesystem* filesystem, + const char* path, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + int rc; + auto daos = static_cast(filesystem->plugin_filesystem); + + dfs_path_t dpath; + rc = daos->Setup(daos, path, dpath, status); + if (rc) return; + + size_t next_dir = 0; + std::string dir_string; + std::string path_string = dpath.getRelPath(); + do { + next_dir = path_string.find("/", next_dir); + if (next_dir == 0) { + dpath.setRelPath("/"); + } else { + dpath.setRelPath(path_string.substr(0, next_dir)); + } + if (next_dir != std::string::npos) next_dir++; + TF_SetStatus(status, TF_OK, ""); + daos->dfsCreateDir(&dpath, status); + if ((TF_GetCode(status) != TF_OK) && + (TF_GetCode(status) != TF_ALREADY_EXISTS)) { + return; + } + } while (next_dir != std::string::npos); + + if (TF_GetCode(status) == TF_ALREADY_EXISTS) { + TF_SetStatus(status, TF_OK, ""); // per modular_filesystem_test suite + } +} + +void DeleteFileSystemEntry(const TF_Filesystem* filesystem, const char* path, + bool recursive, bool is_dir, TF_Status* status) { + int rc; + auto daos = static_cast(filesystem->plugin_filesystem); + + dfs_path_t dpath; + rc = daos->Setup(daos, path, dpath, status); + if (rc) { + return; + } + daos->dfsDeleteObject(&dpath, is_dir, recursive, status); +} + +static void DeleteSingleDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + bool recursive = false; + bool is_dir = true; + DeleteFileSystemEntry(filesystem, path, recursive, is_dir, status); +} + +static void RecursivelyDeleteDir(const TF_Filesystem* filesystem, + const char* path, uint64_t* undeleted_files, + uint64_t* undeleted_dirs, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + bool recursive = true; + bool is_dir = true; + DeleteFileSystemEntry(filesystem, path, recursive, is_dir, status); + if (TF_GetCode(status) == TF_NOT_FOUND || + TF_GetCode(status) == TF_FAILED_PRECONDITION) { + *undeleted_dirs = 1; + *undeleted_files = 0; + } else { + *undeleted_dirs = 0; + *undeleted_files = 0; + } +} + +// Note: the signature for is_directory() has a bool for the return value, but +// tensorflow does not use this, instead it interprets the status field to get +// the result. A value of TF_OK indicates that the object is a directory, and +// a value of TF_FAILED_PRECONDITION indicates that the object is a file. All +// other status values throw an exception. 
+ +static bool IsDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + int rc; + auto daos = static_cast(filesystem->plugin_filesystem); + + dfs_path_t dpath; + rc = daos->Setup(daos, path, dpath, status); + if (rc) return false; + + dfs_obj_t* obj; + rc = daos->dfsLookUp(&dpath, &obj, status); + if (rc) return false; + + bool is_dir = daos->dfsIsDirectory(obj); + + rc = daos->libdfs->dfs_release(obj); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, ""); + return false; + } + + if (!is_dir) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, ""); + return false; + } + return true; +} + +static int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + int rc; + auto daos = static_cast(filesystem->plugin_filesystem); + + dfs_path_t dpath; + rc = daos->Setup(daos, path, dpath, status); + if (rc) return -1; + + daos_size_t size; + if (dpath.getCachedSize(size) == 0) { + return size; + } + dfs_obj_t* obj; + rc = daos->dfsLookUp(&dpath, &obj, status); + if (rc) { + return -1; + } + + if (daos->dfsIsDirectory(obj)) { + daos->libdfs->dfs_release(obj); + TF_SetStatus(status, TF_FAILED_PRECONDITION, ""); + return -1; + } + + daos->libdfs->dfs_get_size(dpath.getFsys(), obj, &size); + dpath.setCachedSize(size); + + rc = daos->libdfs->dfs_release(obj); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, ""); + return -1; + } + return size; +} + +static void DeleteFile(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + bool recursive = false; + bool is_dir = false; + DeleteFileSystemEntry(filesystem, path, recursive, is_dir, status); +} + +static void RenameFile(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + int rc; + auto daos = static_cast(filesystem->plugin_filesystem); + + dfs_path_t src_dpath; + rc = daos->Setup(daos, src, src_dpath, status); + if (rc) return; + + dfs_path_t dst_dpath; + rc = daos->Setup(daos, dst, dst_dpath, status); + if (rc) return; + + if (src_dpath.getFsys() != dst_dpath.getFsys()) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, "Non-Matching Pool/Container"); + return; + } + + // Source object must exist + dfs_obj_t* temp_obj; + rc = daos->dfsLookUp(&src_dpath, &temp_obj, status); + if (rc) { + return; + } + + // Source object cannot be a directory + bool is_dir = daos->dfsIsDirectory(temp_obj); + daos_size_t src_size; + + if (is_dir) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, ""); + } else { + if (src_dpath.getCachedSize(src_size) != 0) { + daos->libdfs->dfs_get_size(src_dpath.getFsys(), temp_obj, &src_size); + } + } + + rc = daos->libdfs->dfs_release(temp_obj); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, ""); + } + if (TF_GetCode(status) != TF_OK) { + return; + } + + // Destination object may or may not exist, but must not be a directory. + rc = daos->dfsLookUp(&dst_dpath, &temp_obj, status); + if (rc) { + if (TF_GetCode(status) != TF_NOT_FOUND) { + return; + } + TF_SetStatus(status, TF_OK, ""); + } else { + bool is_dir = daos->dfsIsDirectory(temp_obj); + rc = daos->libdfs->dfs_release(temp_obj); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, ""); + return; + } + if (is_dir) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, ""); + return; + } + } + + // Open the parent objects. Note that these are cached directory entries, + // not to be closed by this function. 
+ + dfs_obj_t* parent_src = NULL; + rc = daos->dfsFindParent(&src_dpath, &parent_src, status); + if (rc) { + TF_SetStatus(status, TF_NOT_FOUND, ""); + return; + } + + dfs_obj_t* parent_dst = NULL; + rc = daos->dfsFindParent(&dst_dpath, &parent_dst, status); + if (rc) { + TF_SetStatus(status, TF_NOT_FOUND, ""); + return; + } + + if (!daos->dfsIsDirectory(parent_dst)) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, ""); + return; + } + + std::string src_name = src_dpath.getBaseName(); + std::string dst_name = dst_dpath.getBaseName(); + + rc = daos->libdfs->dfs_move(src_dpath.getFsys(), parent_src, src_name.c_str(), + parent_dst, dst_name.c_str(), NULL); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, ""); + return; + } + + dst_dpath.setCachedSize(src_size); + src_dpath.clearCachedSize(); +} + +static void Stat(const TF_Filesystem* filesystem, const char* path, + TF_FileStatistics* stats, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + int rc; + auto daos = static_cast(filesystem->plugin_filesystem); + + dfs_path_t dpath; + rc = daos->Setup(daos, path, dpath, status); + if (rc) return; + + dfs_obj_t* obj; + rc = daos->dfsLookUp(&dpath, &obj, status); + if (rc) return; + + struct stat stbuf; + rc = daos->libdfs->dfs_ostat(dpath.getFsys(), obj, &stbuf); + if (rc) { + daos->libdfs->dfs_release(obj); + TF_SetStatus(status, TF_INTERNAL, ""); + return; + } + + stats->length = stbuf.st_size; + stats->mtime_nsec = static_cast(stbuf.st_mtime) * 1e9; + if (daos->dfsIsDirectory(obj)) { + stats->is_directory = true; + } else { + stats->is_directory = false; + } + + rc = daos->libdfs->dfs_release(obj); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, ""); + } +} + +static int GetChildren(const TF_Filesystem* filesystem, const char* path, + char*** entries, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + int rc; + auto daos = static_cast(filesystem->plugin_filesystem); + + dfs_path_t dpath; + rc = daos->Setup(daos, path, dpath, status); + if (rc) return -1; + + dfs_obj_t* obj; + rc = daos->dfsLookUp(&dpath, &obj, status); + if (rc) { + return -1; + } + + if (!daos->dfsIsDirectory(obj)) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, ""); + daos->libdfs->dfs_release(obj); + return -1; + } + + std::vector children; + rc = daos->dfsReadDir(dpath.getFsys(), obj, children); + daos->libdfs->dfs_release(obj); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, ""); + return -1; + } + + uint32_t nr = children.size(); + + CopyEntries(entries, children); + + return nr; +} + +static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) { + // Note: this function should be doing the equivalent of the + // lexically_normalize() function available in newer compilers. 
+ return strdup(uri); +} + +static void FlushCaches(const TF_Filesystem* filesystem) { + auto daos = static_cast(filesystem->plugin_filesystem); + + daos->clearAllDirCaches(); + daos->clearAllSizeCaches(); +} + +} // namespace tf_dfs_filesystem + +void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, const char* uri) { + TF_SetFilesystemVersionMetadata(ops); + ops->scheme = strdup(uri); + + ops->random_access_file_ops = static_cast( + plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE)); + ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup; + ops->random_access_file_ops->read = tf_random_access_file::Read; + + ops->writable_file_ops = static_cast( + plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE)); + ops->writable_file_ops->cleanup = tf_writable_file::Cleanup; + ops->writable_file_ops->append = tf_writable_file::Append; + ops->writable_file_ops->tell = tf_writable_file::Tell; + ops->writable_file_ops->close = tf_writable_file::Close; + + ops->read_only_memory_region_ops = static_cast( + plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE)); + ops->read_only_memory_region_ops->cleanup = + tf_read_only_memory_region::Cleanup; + ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data; + ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length; + + ops->filesystem_ops = static_cast( + plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE)); + ops->filesystem_ops->init = tf_dfs_filesystem::Init; + ops->filesystem_ops->cleanup = tf_dfs_filesystem::Cleanup; + ops->filesystem_ops->new_random_access_file = + tf_dfs_filesystem::NewRandomAccessFile; + ops->filesystem_ops->new_writable_file = tf_dfs_filesystem::NewWritableFile; + ops->filesystem_ops->new_appendable_file = + tf_dfs_filesystem::NewAppendableFile; + ops->filesystem_ops->path_exists = tf_dfs_filesystem::PathExists; + ops->filesystem_ops->create_dir = tf_dfs_filesystem::CreateDir; + ops->filesystem_ops->delete_dir = tf_dfs_filesystem::DeleteSingleDir; + ops->filesystem_ops->recursively_create_dir = + tf_dfs_filesystem::RecursivelyCreateDir; + ops->filesystem_ops->is_directory = tf_dfs_filesystem::IsDir; + ops->filesystem_ops->delete_recursively = + tf_dfs_filesystem::RecursivelyDeleteDir; + ops->filesystem_ops->get_file_size = tf_dfs_filesystem::GetFileSize; + ops->filesystem_ops->delete_file = tf_dfs_filesystem::DeleteFile; + ops->filesystem_ops->rename_file = tf_dfs_filesystem::RenameFile; + ops->filesystem_ops->stat = tf_dfs_filesystem::Stat; + ops->filesystem_ops->get_children = tf_dfs_filesystem::GetChildren; + ops->filesystem_ops->translate_name = tf_dfs_filesystem::TranslateName; + ops->filesystem_ops->flush_caches = tf_dfs_filesystem::FlushCaches; +} + +} // namespace dfs +} // namespace io +} // namespace tensorflow diff --git a/tensorflow_io/core/filesystems/dfs/dfs_utils.cc b/tensorflow_io/core/filesystems/dfs/dfs_utils.cc new file mode 100644 index 000000000..568fbdcd6 --- /dev/null +++ b/tensorflow_io/core/filesystems/dfs/dfs_utils.cc @@ -0,0 +1,971 @@ + +#include "tensorflow_io/core/filesystems/dfs/dfs_utils.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#undef NDEBUG +#include + +std::string GetStorageString(uint64_t size) { + if (size < KILO) { + return std::to_string(size); + } else if (size < MEGA) { + return std::to_string(size / KILO) + "K"; + } else if (size < GEGA) { + return std::to_string(size / MEGA) + "M"; + } else if (size < TERA) { + return std::to_string(size / GEGA) + "G"; + } else { + 
return std::to_string(size / TERA) + "T"; + } +} + +size_t GetStorageSize(std::string size) { + char size_char = size.back(); + size_t curr_scale = 1; + switch (size_char) { + case 'K': + size.pop_back(); + curr_scale *= 1024; + return (size_t)atoi(size.c_str()) * curr_scale; + case 'M': + size.pop_back(); + curr_scale *= 1024 * 1024; + return (size_t)atoi(size.c_str()) * curr_scale; + case 'G': + size.pop_back(); + curr_scale *= 1024 * 1024 * 1024; + return (size_t)atoi(size.c_str()) * curr_scale; + case 'T': + size.pop_back(); + curr_scale *= 1024 * 1024 * 1024; + return (size_t)atoi(size.c_str()) * curr_scale * 1024; + default: + return atoi(size.c_str()); + } +} + +mode_t GetFlags(File_Mode mode) { + switch (mode) { + case READ: + return O_RDONLY; + case WRITE: + return O_WRONLY | O_CREAT; + case APPEND: + return O_WRONLY | O_APPEND | O_CREAT; + case READWRITE: + return O_RDWR | O_CREAT; + default: + return -1; + } +} + +int ParseUUID(const std::string& str, uuid_t uuid) { + return uuid_parse(str.c_str(), uuid); +} + +void CopyEntries(char*** entries, std::vector& results) { + *entries = static_cast( + tensorflow::io::plugin_memory_allocate(results.size() * sizeof(char*))); + + for (uint32_t i = 0; i < results.size(); i++) { + (*entries)[i] = static_cast(tensorflow::io::plugin_memory_allocate( + results[i].size() * sizeof(char))); + if (results[i][0] == '/') results[i].erase(0, 1); + strcpy((*entries)[i], results[i].c_str()); + } +} + +bool Match(const std::string& filename, const std::string& pattern) { + return fnmatch(pattern.c_str(), filename.c_str(), FNM_PATHNAME) == 0; +} + +DFS::DFS(TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + + // Try to load the necessary daos libraries. + libdfs.reset(new libDFS(status)); + if (TF_GetCode(status) != TF_OK) { + libdfs.reset(nullptr); + return; + } + + int rc = libdfs->daos_init(); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, "Error initializing DAOS library"); + return; + } + + rc = libdfs->daos_eq_create(&mEventQueueHandle); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, + "Error initializing DAOS event queue handle"); + return; + } +} + +DFS::~DFS() { + int rc; + clearAllDirCaches(); + clearAllSizeCaches(); + + rc = libdfs->daos_eq_destroy(mEventQueueHandle, 0); + assert(rc == 0); + ClearConnections(); + + libdfs->daos_fini(); + libdfs.reset(nullptr); +} + +int DFS::ParseDFSPath(const std::string& path, std::string& pool_string, + std::string& cont_string, std::string& filename) { + struct duns_attr_t attr = {0}; + attr.da_flags = DUNS_NO_CHECK_PATH; + + int rc = libdfs->duns_resolve_path(path.c_str(), &attr); + if (rc == 0) { + pool_string = attr.da_pool; + cont_string = attr.da_cont; + filename = attr.da_rel_path == NULL ? 
"/" : attr.da_rel_path; + if (filename.back() == '/' && filename.size() > 1) filename.pop_back(); + libdfs->duns_destroy_attr(&attr); + } + return rc; +} + +int DFS::Setup(DFS* daos, const std::string path, dfs_path_t& dpath, + TF_Status* status) { + int allow_cont_creation = 1; + int rc; + + std::string pool, cont, rel_path; + rc = ParseDFSPath(path, pool, cont, rel_path); + if (rc) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, ""); + return rc; + } + Connect(daos, pool, cont, allow_cont_creation, status); + if (TF_GetCode(status) != TF_OK) { + return -1; + } + + pool_info_t* po_inf = pools[pool]; + cont_info_t* cont_info = po_inf->containers[cont]; + dfs_path_t res(cont_info, rel_path); + dpath = res; + return 0; +} + +void DFS::Connect(DFS* daos, std::string& pool_string, std::string& cont_string, + int allow_cont_creation, TF_Status* status) { + int rc; + + rc = ConnectPool(pool_string, status); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, "Error Connecting to Pool"); + return; + } + + rc = ConnectContainer(daos, pool_string, cont_string, allow_cont_creation, + status); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, "Error Connecting to Container"); + return; + } + + TF_SetStatus(status, TF_OK, ""); +} + +int DFS::Query(id_handle_t pool, id_handle_t container, dfs_t* daos_fs) { + int rc; + daos_pool_info_t pool_info; + daos_cont_info_t cont_info; + if (daos_fs) { + memset(&pool_info, 'D', sizeof(daos_pool_info_t)); + pool_info.pi_bits = DPI_ALL; + rc = libdfs->daos_pool_query(pool.second, NULL, &pool_info, NULL, NULL); + if (rc) return rc; + rc = libdfs->daos_cont_query(container.second, &cont_info, NULL, NULL); + if (rc) return rc; + std::cout << "Pool " << pool.first << " ntarget=" << pool_info.pi_ntargets + << std::endl; + std::cout << "Pool space info:" << std::endl; + std::cout << "- Target(VOS) count:" << pool_info.pi_space.ps_ntargets + << std::endl; + std::cout << "- SCM:" << std::endl; + std::cout << " Total size: " + << GetStorageString(pool_info.pi_space.ps_space.s_total[0]); + std::cout << " Free: " + << GetStorageString(pool_info.pi_space.ps_space.s_free[0]) + << std::endl; + std::cout << "- NVMe:" << std::endl; + std::cout << " Total size: " + << GetStorageString(pool_info.pi_space.ps_space.s_total[1]); + std::cout << " Free: " + << GetStorageString(pool_info.pi_space.ps_space.s_free[1]) + << std::endl; + std::cout << std::endl + << "Connected Container: " << container.first << std::endl; + + return 0; + } + + return -1; +} + +int DFS::ClearConnections() { + int rc; + + for (;;) { + auto pool_it = pools.cbegin(); + if (pool_it == pools.cend()) { + break; + } + rc = DisconnectPool((*pool_it++).first); + if (rc) return rc; + } + return 0; +} + +void DFS::clearDirCache(dir_cache_t& dir_cache) { + for (auto kv = dir_cache.begin(); kv != dir_cache.end();) { + dfs_obj_t* dir = kv->second; + libdfs->dfs_release(dir); + kv = dir_cache.erase(kv); + } +} + +void DFS::clearAllDirCaches(void) { + for (auto pool_it = pools.cbegin(); pool_it != pools.cend();) { + for (auto cont_it = ((*pool_it).second->containers).cbegin(); + cont_it != ((*pool_it).second->containers).cend();) { + cont_info_t* cont = (*cont_it).second; + clearDirCache(cont->dir_map); + cont_it++; + } + pool_it++; + } +} + +void DFS::clearSizeCache(size_cache_t& size_cache) { size_cache.clear(); } + +void DFS::clearAllSizeCaches(void) { + for (auto pool_it = pools.cbegin(); pool_it != pools.cend();) { + for (auto cont_it = ((*pool_it).second->containers).cbegin(); + cont_it != 
((*pool_it).second->containers).cend();) { + cont_info_t* cont = (*cont_it).second; + clearSizeCache(cont->size_map); + cont_it++; + } + pool_it++; + } +} + +int DFS::dfsDeleteObject(dfs_path_t* dpath, bool is_dir, bool recursive, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + dfs_obj_t* temp_obj; + + int rc = dfsLookUp(dpath, &temp_obj, status); + if (rc) return -1; + + if (dfsIsDirectory(temp_obj)) { + if (!is_dir) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, "Object is a directory"); + libdfs->dfs_release(temp_obj); + return -1; + } + } else { + if (is_dir && !recursive) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, "Object is not a directory"); + libdfs->dfs_release(temp_obj); + return -1; + } + } + + dfs_obj_t* parent; + rc = dfsFindParent(dpath, &parent, status); + if (rc) { + libdfs->dfs_release(temp_obj); + return -1; + } + + rc = libdfs->dfs_remove(dpath->getFsys(), parent, + dpath->getBaseName().c_str(), recursive, NULL); + libdfs->dfs_release(temp_obj); + if (rc) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, + "Error Deleting Existing Object"); + return -1; + } + + if (is_dir) { + if (recursive) { + dpath->clearFsysCachedDirs(); + dpath->clearFsysCachedSizes(); + } else { + dpath->clearCachedDir(); + } + } else { + dpath->clearCachedSize(); + } + + TF_SetStatus(status, TF_OK, ""); + return 0; +} + +void DFS::dfsNewFile(dfs_path_t* dpath, File_Mode file_mode, int flags, + dfs_obj_t** obj, TF_Status* status) { + int rc; + dfs_obj_t* temp_obj; + mode_t open_flags; + + rc = dfsLookUp(dpath, &temp_obj, status); + if (rc) { + if (TF_GetCode(status) != TF_NOT_FOUND) { + return; + } + if (file_mode == READ) { + return; + } + TF_SetStatus(status, TF_OK, ""); + } + + if (temp_obj != NULL && dfsIsDirectory(temp_obj)) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, ""); + libdfs->dfs_release(temp_obj); + return; + } + + if (temp_obj != NULL && file_mode == READ) { + *obj = temp_obj; + return; + } + + if (!rc && file_mode == WRITE) { + rc = dfsDeleteObject(dpath, false, false, status); + if (rc) { + libdfs->dfs_release(temp_obj); + return; + } + } + + open_flags = GetFlags(file_mode); + + dfs_obj_t* parent; + mode_t parent_mode; + rc = dfsFindParent(dpath, &parent, status); + if (rc) { + libdfs->dfs_release(temp_obj); + return; + } + rc = libdfs->dfs_get_mode(parent, &parent_mode); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, "Cannot retrieve object mode"); + libdfs->dfs_release(temp_obj); + return; + } + if (parent != NULL && !S_ISDIR(parent_mode)) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, ""); + libdfs->dfs_release(temp_obj); + return; + } + + std::string base_name = dpath->getBaseName(); + rc = libdfs->dfs_open(dpath->getFsys(), parent, base_name.c_str(), flags, + open_flags, 0, 0, NULL, obj); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, "Error Creating Writable File"); + libdfs->dfs_release(temp_obj); + return; + } +} + +// Look up an object path and return its dfs_obj_t* in *obj. If this routine +// returns zero then *obj is guaranteed to contain a valid dfs_obj_t* which +// must be released by the caller. If non-zero, then *obj will be nullptr. If +// the object happens to be a directory, a dup of the dfs_obj_t* will also be +// separately cached in the filesystem's directory cache. + +int DFS::dfsLookUp(dfs_path_t* dpath, dfs_obj_t** obj, TF_Status* status) { + *obj = NULL; + int rc; + + // Check if the object path is for a directory we have seen before. 
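+  // If so, hand back a dup of the cached handle so the caller can release its copy without invalidating the cache entry.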
+ + dfs_obj_t* _obj = dpath->getCachedDir(); + if (_obj) { + rc = libdfs->dfs_dup(dpath->getFsys(), _obj, O_RDWR, obj); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, "dfs_dup() of open directory failed"); + return -1; + } + return 0; + } + + if (dpath->isRoot()) { + rc = libdfs->dfs_lookup(dpath->getFsys(), dpath->getRelPath().c_str(), + O_RDWR, &_obj, NULL, NULL); + } else { + dfs_obj_t* parent = NULL; + rc = dfsFindParent(dpath, &parent, status); + if (rc) return -1; + + dfs_t* fsys = dpath->getFsys(); + std::string basename = dpath->getBaseName(); + rc = libdfs->dfs_lookup_rel(fsys, parent, basename.c_str(), O_RDWR, &_obj, + NULL, NULL); + } + if (rc != 0) { + if (rc == ENOENT) { + TF_SetStatus(status, TF_NOT_FOUND, ""); + } else { + TF_SetStatus(status, TF_FAILED_PRECONDITION, ""); + } + return -1; + } + + if (!dfsIsDirectory(_obj)) { + *obj = _obj; + return 0; + } + + // The object is a directory, so return the original dfs_obj_t* and store a + // dup of the original in the filesystem's directory cache. + + rc = libdfs->dfs_dup(dpath->getFsys(), _obj, O_RDWR, obj); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, "dfs_dup() of open directory failed"); + return -1; + } + dpath->setCachedDir(_obj); + return 0; +} + +// Given an object pathname, return the dfs_obj_t* for its parent directory. +// Cache the parent directory if not already cached. The caller should not +// release the parent dfs_obj_t*. + +int DFS::dfsFindParent(dfs_path_t* dpath, dfs_obj_t** obj, TF_Status* status) { + *obj = NULL; + int rc; + + dfs_path_t parent_dpath = *dpath; + parent_dpath.setRelPath(dpath->getParentPath()); + + dfs_obj_t* _obj = parent_dpath.getCachedDir(); + if (_obj) { + *obj = _obj; + return 0; + } + + rc = libdfs->dfs_lookup(parent_dpath.getFsys(), + parent_dpath.getRelPath().c_str(), O_RDWR, &_obj, + NULL, NULL); + if (rc) { + if (rc == ENOENT) { + TF_SetStatus(status, TF_NOT_FOUND, ""); + } else { + TF_SetStatus(status, TF_FAILED_PRECONDITION, ""); + } + return -1; + } + + if (!dfsIsDirectory(_obj)) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, + "parent object is not a directory"); + libdfs->dfs_release(_obj); + return -1; + } + + parent_dpath.setCachedDir(_obj); + *obj = _obj; + + return 0; +} + +int DFS::dfsCreateDir(dfs_path_t* dpath, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + dfs_obj_t* temp_obj; + int rc; + + rc = dfsLookUp(dpath, &temp_obj, status); + if (!rc) { + if (dfsIsDirectory(temp_obj)) { + libdfs->dfs_release(temp_obj); + TF_SetStatus(status, TF_ALREADY_EXISTS, ""); + return 0; + } else { + TF_SetStatus(status, TF_FAILED_PRECONDITION, ""); + libdfs->dfs_release(temp_obj); + return -1; + } + } else if (TF_GetCode(status) != TF_NOT_FOUND) { + return -1; + } + + TF_SetStatus(status, TF_OK, ""); + dfs_obj_t* parent; + rc = dfsFindParent(dpath, &parent, status); + if (rc) { + return rc; + } + + rc = libdfs->dfs_mkdir(dpath->getFsys(), parent, dpath->getBaseName().c_str(), + S_IWUSR | S_IRUSR | S_IXUSR, 0); + if (rc) { + TF_SetStatus(status, TF_INTERNAL, "Error Creating Directory"); + } + + return rc; +} + +bool DFS::dfsIsDirectory(dfs_obj_t* obj) { + if (obj == NULL) { + return true; + } + mode_t mode; + libdfs->dfs_get_mode(obj, &mode); + if (S_ISDIR(mode)) { + return true; + } + return false; +} + +int DFS::dfsReadDir(dfs_t* daos_fs, dfs_obj_t* obj, + std::vector& children) { + int rc = 0; + daos_anchor_t anchor = {0}; + uint32_t nr = STACK; + struct dirent* dirs = (struct dirent*)malloc(nr * sizeof(struct dirent)); + while (!daos_anchor_is_eof(&anchor)) { + rc = 
libdfs->dfs_readdir(daos_fs, obj, &anchor, &nr, dirs); + if (rc) { + return rc; + } + + for (uint32_t i = 0; i < nr; i++) { + children.emplace_back(dirs[i].d_name); + } + } + + free(dirs); + return rc; +} + +int DFS::ConnectPool(std::string pool_string, TF_Status* status) { + int rc = 0; + + if (pools.find(pool_string) != pools.end()) { + return 0; + } + + daos_handle_t poh; + daos_pool_info_t info; + rc = libdfs->daos_pool_connect2(pool_string.c_str(), NULL, DAOS_PC_RW, &poh, + &info, NULL); + if (rc == 0) { + pool_info_t* po_inf = new pool_info_t(); + po_inf->poh = poh; + pools[pool_string] = po_inf; + } + return rc; +} + +int DFS::ConnectContainer(DFS* daos, std::string pool_string, + std::string cont_string, int allow_creation, + TF_Status* status) { + int rc = 0; + + pool_info_t* po_inf = pools[pool_string]; + auto search = po_inf->containers.find(cont_string); + if (search != po_inf->containers.end()) { + return 0; + } + + daos_handle_t coh; + daos_cont_info_t info; + rc = libdfs->daos_cont_open2(po_inf->poh, cont_string.c_str(), DAOS_COO_RW, + &coh, &info, NULL); + if (rc == -DER_NONEXIST) { + if (allow_creation) { + rc = libdfs->dfs_cont_create_with_label(po_inf->poh, cont_string.c_str(), + NULL, NULL, &coh, NULL); + } + } + if (rc != 0) return rc; + + dfs_t* daos_fs; + rc = libdfs->dfs_mount(po_inf->poh, coh, O_RDWR, &daos_fs); + if (rc != 0) return rc; + + cont_info_t* co_inf = new cont_info_t(); + co_inf->coh = coh; + co_inf->daos = daos; + co_inf->daos_fs = daos_fs; + co_inf->pool = pool_string; + co_inf->cont = cont_string; + + po_inf->containers[cont_string] = co_inf; + return 0; +} + +int DFS::DisconnectPool(std::string pool_string) { + int rc = 0; + pool_info_t* po_inf = pools[pool_string]; + + for (;;) { + auto cont_it = po_inf->containers.cbegin(); + if (cont_it == po_inf->containers.cend()) { + break; + } + rc = DisconnectContainer(pool_string, (*cont_it++).first); + if (rc) return rc; + } + + rc = libdfs->daos_pool_disconnect(po_inf->poh, NULL); + if (rc == 0) { + delete po_inf; + pools.erase(pool_string); + } + return rc; +} + +int DFS::DisconnectContainer(std::string pool_string, std::string cont_string) { + int rc = 0; + cont_info_t* co_inf = pools[pool_string]->containers[cont_string]; + + if (co_inf->daos_fs) { + rc = libdfs->dfs_umount(co_inf->daos_fs); + if (rc) return rc; + co_inf->daos_fs = nullptr; + } + + rc = libdfs->daos_cont_close(co_inf->coh, nullptr); + if (rc == 0) { + delete co_inf; + pools[pool_string]->containers.erase(cont_string); + } + return rc; +} + +ReadBuffer::ReadBuffer(size_t aId, DFS* daos, daos_handle_t aEqh, size_t size) + : id(aId), daos(daos), buffer_size(size), eqh(aEqh) { + buffer = new char[size]; + buffer_offset = ULONG_MAX; + event = new daos_event_t; + int rc = daos->libdfs->daos_event_init(event, eqh, nullptr); + assert(rc == 0); +} + +ReadBuffer::~ReadBuffer() { + if (event != nullptr) { + WaitEvent(); + int rc = daos->libdfs->daos_event_fini(event); + assert(rc == 0); + delete event; + } + if (buffer != nullptr) { + delete[] buffer; + } +} + +ReadBuffer::ReadBuffer(ReadBuffer&& read_buffer) { + eqh = read_buffer.eqh; + buffer_size = read_buffer.buffer_size; + buffer = std::move(read_buffer.buffer); + event = std::move(read_buffer.event); + buffer_offset = ULONG_MAX; + id = read_buffer.id; + daos = read_buffer.daos; + read_buffer.buffer = nullptr; + read_buffer.event = nullptr; +} + +bool ReadBuffer::CacheHit(const size_t pos) { + return pos >= buffer_offset && (pos < buffer_offset + buffer_size); +} + +void 
ReadBuffer::WaitEvent() { + bool event_status; + int rc = daos->libdfs->daos_event_test(event, DAOS_EQ_WAIT, &event_status); + assert(rc == 0 && event_status == true); +} + +int ReadBuffer::ReadAsync(dfs_t* daos_fs, dfs_obj_t* file, const size_t off, + const size_t file_size) { + if (off >= file_size) { + return 0; + } + size_t buffer_actual_size = + buffer_size > (file_size - off) ? (file_size - off) : buffer_size; + WaitEvent(); + d_iov_set(&iov, (void*)buffer, buffer_actual_size); + rsgl.sg_nr = 1; + rsgl.sg_iovs = &iov; + buffer_offset = off; + int rc = daos->libdfs->daos_event_fini(event); + assert(rc == 0); + rc = daos->libdfs->daos_event_init(event, eqh, nullptr); + assert(rc == 0); + event->ev_error = daos->libdfs->dfs_read(daos_fs, file, &rsgl, buffer_offset, + &read_size, event); + return 0; +} + +int ReadBuffer::CopyData(char* ret, const size_t ret_offset, const size_t off, + const size_t n) { + WaitEvent(); + if (event->ev_error != DER_SUCCESS) { + return event->ev_error; + } + memcpy(ret + ret_offset, buffer + (off - buffer_offset), n); + return 0; +} + +int64_t ReadBuffer::CopyFromCache(char* ret, const size_t ret_offset, + const size_t off, const size_t n, + const daos_size_t file_size, + TF_Status* status) { + size_t aRead_size; + aRead_size = off + n > file_size ? file_size - off : n; + aRead_size = off + aRead_size > buffer_offset + buffer_size + ? buffer_offset + buffer_size - off + : aRead_size; + int rc = CopyData(ret, ret_offset, off, aRead_size); + if (rc) { + TF_SetStatusFromIOError(status, rc, "I/O error on dfs_read() call"); + return -1; + } + + if (off + n > file_size) { + TF_SetStatus(status, TF_OUT_OF_RANGE, ""); + } else { + TF_SetStatus(status, TF_OK, ""); + } + + return static_cast(aRead_size); +} + +dfs_path_t::dfs_path_t(cont_info_t* cont_info, std::string rel_path) + : cont_info(cont_info), rel_path(rel_path) {} + +dfs_path_t& dfs_path_t::operator=(dfs_path_t other) { + cont_info = other.cont_info; + rel_path = other.rel_path; + return *this; +} + +std::string dfs_path_t::getFullPath(void) { + std::string full_path = + "/" + cont_info->pool + "/" + cont_info->cont + rel_path; + return full_path; +} + +std::string dfs_path_t::getRelPath(void) { return rel_path; } + +std::string dfs_path_t::getParentPath(void) { + if (rel_path == "/") { + return rel_path; // root is its own parent + } + + std::string parent_path; + size_t slash_pos = rel_path.rfind("/"); + if (slash_pos == 0) { + parent_path = "/"; + } else { + parent_path = rel_path.substr(0, slash_pos); + } + return parent_path; +} + +std::string dfs_path_t::getBaseName(void) { + size_t base_start = rel_path.rfind("/") + 1; + std::string base_name = rel_path.substr(base_start); + return base_name; +} + +DFS* dfs_path_t::getDAOS(void) { return cont_info->daos; } + +dfs_t* dfs_path_t::getFsys(void) { return cont_info->daos_fs; } + +void dfs_path_t::setRelPath(std::string new_path) { rel_path = new_path; } + +bool dfs_path_t::isRoot(void) { return (rel_path == "/"); } + +dfs_obj_t* dfs_path_t::getCachedDir(void) { + auto search = cont_info->dir_map.find(rel_path); + if (search != cont_info->dir_map.end()) { + return search->second; + } else { + return nullptr; + } +} + +void dfs_path_t::setCachedDir(dfs_obj_t* dir_obj) { + cont_info->dir_map[rel_path] = dir_obj; +} + +void dfs_path_t::clearCachedDir(void) { cont_info->dir_map.erase(rel_path); } + +void dfs_path_t::clearFsysCachedDirs(void) { + cont_info->daos->clearDirCache(cont_info->dir_map); +} + +int dfs_path_t::getCachedSize(daos_size_t& size) { + auto 
search = cont_info->size_map.find(rel_path); + if (search != cont_info->size_map.end()) { + size = search->second; + return 0; + } else { + return -1; + } +} + +void dfs_path_t::setCachedSize(daos_size_t size) { + cont_info->size_map[rel_path] = size; +} + +void dfs_path_t::clearCachedSize(void) { cont_info->size_map.erase(rel_path); } + +void dfs_path_t::clearFsysCachedSizes(void) { + cont_info->daos->clearSizeCache(cont_info->size_map); +} + +static void* LoadSharedLibrary(const char* library_filename, + TF_Status* status) { + std::string full_path; + char* libdir; + void* handle; + + if ((libdir = std::getenv("TF_IO_DAOS_LIBRARY_DIR")) != nullptr) { + full_path = libdir; + if (full_path.back() != '/') full_path.push_back('/'); + full_path.append(library_filename); + handle = dlopen(full_path.c_str(), RTLD_NOW | RTLD_LOCAL); + if (handle != nullptr) { + TF_SetStatus(status, TF_OK, ""); + return handle; + } + } + + // Check for the library in the installation location used by rpms. + full_path = "/usr/lib64/"; + full_path += library_filename; + handle = dlopen(full_path.c_str(), RTLD_NOW | RTLD_LOCAL); + if (handle != nullptr) { + TF_SetStatus(status, TF_OK, ""); + return handle; + } + + // Check for the library in the location used when building DAOS fom source. + full_path = "/opt/daos/lib64/"; + full_path += library_filename; + handle = dlopen(full_path.c_str(), RTLD_NOW | RTLD_LOCAL); + if (handle != nullptr) { + TF_SetStatus(status, TF_OK, ""); + return handle; + } + + std::string error_message = + absl::StrCat("Library (", library_filename, ") not found: ", dlerror()); + TF_SetStatus(status, TF_NOT_FOUND, error_message.c_str()); + return nullptr; +} + +static void* GetSymbolFromLibrary(void* handle, const char* symbol_name, + TF_Status* status) { + if (handle == nullptr) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, "library handle cannot be null"); + return nullptr; + } + void* symbol = dlsym(handle, symbol_name); + if (symbol == nullptr) { + std::string error_message = + absl::StrCat("Symbol (", symbol_name, ") not found: ", dlerror()); + TF_SetStatus(status, TF_NOT_FOUND, error_message.c_str()); + return nullptr; + } + + TF_SetStatus(status, TF_OK, ""); + return symbol; +} + +template +void BindFunc(void* handle, const char* name, std::function* func, + TF_Status* status) { + *func = reinterpret_cast( + GetSymbolFromLibrary(handle, name, status)); +} + +libDFS::~libDFS() { + if (libdaos_handle_ != nullptr) { + dlclose(libdaos_handle_); + } + if (libdfs_handle_ != nullptr) { + dlclose(libdfs_handle_); + } + if (libduns_handle_ != nullptr) { + dlclose(libduns_handle_); + } +} + +void libDFS::LoadAndBindDaosLibs(TF_Status* status) { +#define LOAD_DFS_LIBRARY(handle, library_filename, status) \ + do { \ + handle = LoadSharedLibrary(library_filename, status); \ + if (TF_GetCode(status) != TF_OK) return; \ + } while (0); + + LOAD_DFS_LIBRARY(libdaos_handle_, "libdaos.so", status); + LOAD_DFS_LIBRARY(libdfs_handle_, "libdfs.so", status); + LOAD_DFS_LIBRARY(libduns_handle_, "libduns.so", status); + +#undef LOAD_DFS_LIBRARY + +#define BIND_DFS_FUNC(handle, function) \ + do { \ + BindFunc(handle, #function, &function, status); \ + if (TF_GetCode(status) != TF_OK) return; \ + } while (0); + + BIND_DFS_FUNC(libdaos_handle_, daos_cont_close); + BIND_DFS_FUNC(libdaos_handle_, daos_cont_open2); + BIND_DFS_FUNC(libdaos_handle_, daos_cont_query); + BIND_DFS_FUNC(libdaos_handle_, daos_event_init); + BIND_DFS_FUNC(libdaos_handle_, daos_event_fini); + BIND_DFS_FUNC(libdaos_handle_, daos_event_test); 
+ BIND_DFS_FUNC(libdaos_handle_, daos_eq_create); + BIND_DFS_FUNC(libdaos_handle_, daos_eq_destroy); + BIND_DFS_FUNC(libdaos_handle_, daos_fini); + BIND_DFS_FUNC(libdaos_handle_, daos_init); + BIND_DFS_FUNC(libdaos_handle_, daos_pool_connect2); + BIND_DFS_FUNC(libdaos_handle_, daos_pool_disconnect); + BIND_DFS_FUNC(libdaos_handle_, daos_pool_query); + + BIND_DFS_FUNC(libdfs_handle_, dfs_cont_create_with_label); + BIND_DFS_FUNC(libdfs_handle_, dfs_dup); + BIND_DFS_FUNC(libdfs_handle_, dfs_get_mode); + BIND_DFS_FUNC(libdfs_handle_, dfs_get_size); + BIND_DFS_FUNC(libdfs_handle_, dfs_lookup); + BIND_DFS_FUNC(libdfs_handle_, dfs_lookup_rel); + BIND_DFS_FUNC(libdfs_handle_, dfs_mkdir); + BIND_DFS_FUNC(libdfs_handle_, dfs_mount); + BIND_DFS_FUNC(libdfs_handle_, dfs_move); + BIND_DFS_FUNC(libdfs_handle_, dfs_open); + BIND_DFS_FUNC(libdfs_handle_, dfs_ostat); + BIND_DFS_FUNC(libdfs_handle_, dfs_read); + BIND_DFS_FUNC(libdfs_handle_, dfs_readdir); + BIND_DFS_FUNC(libdfs_handle_, dfs_release); + BIND_DFS_FUNC(libdfs_handle_, dfs_remove); + BIND_DFS_FUNC(libdfs_handle_, dfs_umount); + BIND_DFS_FUNC(libdfs_handle_, dfs_write); + + BIND_DFS_FUNC(libduns_handle_, duns_destroy_attr); + BIND_DFS_FUNC(libduns_handle_, duns_resolve_path); + +#undef BIND_DFS_FUNC +} diff --git a/tensorflow_io/core/filesystems/dfs/dfs_utils.h b/tensorflow_io/core/filesystems/dfs/dfs_utils.h new file mode 100644 index 000000000..1522c8115 --- /dev/null +++ b/tensorflow_io/core/filesystems/dfs/dfs_utils.h @@ -0,0 +1,304 @@ +#ifndef TENSORFLOW_IO_CORE_FILESYSTEMS_DFS_DFS_FILESYSTEM_H_ +#define TENSORFLOW_IO_CORE_FILESYSTEMS_DFS_DFS_FILESYSTEM_H_ + +#define KILO 1e3 +#define MEGA 1e6 +#define GEGA 1e9 +#define TERA 1e12 +#define POOL_START 6 +#define CONT_START 43 +#define PATH_START 80 +#define STACK 24 +#define NUM_OF_BUFFERS 256 +#define BUFF_SIZE 4 * 1024 * 1024 + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/c/logging.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow_io/core/filesystems/filesystem_plugins.h" + +typedef std::unordered_map dir_cache_t; +typedef std::unordered_map size_cache_t; + +class DFS; + +// Class for per-DFS-filesystem state variables, one per container in the +// 'containers' map. +class cont_info_t { + public: + daos_handle_t coh; + DFS* daos; + std::string pool; + std::string cont; + dfs_t* daos_fs; + dir_cache_t dir_map; + size_cache_t size_map; +}; + +typedef struct pool_info { + daos_handle_t poh; + std::unordered_map containers; +} pool_info_t; + +// Class for per-DFS-file state variables and common path operations. State +// includes the filesystem in which the file resides. 
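+// For example, a dfs_path_t with rel_path "/dir/file" in pool "pool1" and
+// container "cont1" (illustrative names) reports getParentPath() as "/dir",
+// getBaseName() as "file", and getFullPath() as "/pool1/cont1/dir/file"; the
+// container root "/" is treated as its own parent.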
+class dfs_path_t { + public: + dfs_path_t() { cont_info = nullptr; }; + dfs_path_t(cont_info_t* cont_info, std::string rel_path); + dfs_path_t& operator=(dfs_path_t other); + DFS* getDAOS(void); + dfs_t* getFsys(void); + std::string getFullPath(void); + std::string getRelPath(void); + std::string getParentPath(void); + std::string getBaseName(void); + void setRelPath(std::string); + bool isRoot(void); + + dfs_obj_t* getCachedDir(void); + void setCachedDir(dfs_obj_t* dir_obj); + void clearCachedDir(void); + void clearFsysCachedDirs(void); + + int getCachedSize(daos_size_t& size); + void setCachedSize(daos_size_t size); + void clearCachedSize(void); + void clearFsysCachedSizes(void); + + private: + cont_info_t* cont_info; + std::string rel_path; +}; + +typedef std::pair id_handle_t; + +enum File_Mode { READ, WRITE, APPEND, READWRITE }; + +std::string GetStorageString(uint64_t size); + +size_t GetStorageSize(std::string size); + +int ParseUUID(const std::string& str, uuid_t uuid); + +class libDFS { + public: + explicit libDFS(TF_Status* status) { LoadAndBindDaosLibs(status); } + + ~libDFS(); + + std::function daos_cont_close; + + std::function + daos_cont_open2; + + std::function + daos_cont_query; + + std::function + daos_event_init; + + std::function daos_event_fini; + + std::function daos_event_test; + + std::function daos_eq_create; + + std::function daos_eq_destroy; + + std::function daos_fini; + + std::function daos_init; + + std::function + daos_pool_connect2; + + std::function daos_pool_disconnect; + + std::function + daos_pool_query; + + std::function + dfs_cont_create_with_label; + + std::function dfs_dup; + + std::function dfs_get_mode; + + std::function dfs_get_size; + + std::function + dfs_lookup; + + std::function + dfs_lookup_rel; + + std::function + dfs_mkdir; + + std::function dfs_mount; + + std::function + dfs_move; + + std::function + dfs_open; + + std::function dfs_ostat; + + std::function + dfs_read; + + std::function + dfs_readdir; + + std::function dfs_release; + + std::function + dfs_remove; + + std::function dfs_umount; + + std::function + dfs_write; + + std::function duns_destroy_attr; + + std::function duns_resolve_path; + + private: + void LoadAndBindDaosLibs(TF_Status* status); + + void* libdaos_handle_; + void* libdfs_handle_; + void* libduns_handle_; +}; + +// Singlton class for the DFS plugin, containing all its global state. 
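+// Pool and container connections are established lazily by Setup()/Connect()
+// and cached in the 'pools' map; DisconnectContainer(), DisconnectPool() and
+// ClearConnections() unmount and close them again, containers before pools.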
+class DFS { + public: + daos_handle_t mEventQueueHandle; + std::unique_ptr libdfs; + std::unordered_map pools; + + explicit DFS(TF_Status* status); + + int ParseDFSPath(const std::string& path, std::string& pool_string, + std::string& cont_string, std::string& filename); + + int Setup(DFS* daos, const std::string path, dfs_path_t& dpath, + TF_Status* status); + + void Connect(DFS* daos, std::string& pool_string, std::string& cont_string, + int allow_cont_creation, TF_Status* status); + + int Query(id_handle_t pool, id_handle_t container, dfs_t* daos_fs); + + int ClearConnections(); + + void clearDirCache(dir_cache_t& dir_cache); + + void clearAllDirCaches(void); + + void clearSizeCache(size_cache_t& size_cache); + + void clearAllSizeCaches(void); + + void dfsNewFile(dfs_path_t* dpath, File_Mode mode, int flags, dfs_obj_t** obj, + TF_Status* status); + + int dfsFindParent(dfs_path_t* dpath, dfs_obj_t** obj, TF_Status* status); + + int dfsCreateDir(dfs_path_t* dpath, TF_Status* status); + + int dfsDeleteObject(dfs_path_t* dpath, bool is_dir, bool recursive, + TF_Status* status); + + bool dfsIsDirectory(dfs_obj_t* obj); + + int dfsReadDir(dfs_t* daos_fs, dfs_obj_t* obj, + std::vector& children); + + int dfsLookUp(dfs_path_t* dpath, dfs_obj_t** obj, TF_Status* status); + + dfs_obj_t* lookup_insert_dir(const char* name, mode_t* mode); + + ~DFS(); + + private: + int ConnectPool(std::string pool_string, TF_Status* status); + + int ConnectContainer(DFS* daos, std::string pool_string, + std::string cont_string, int allow_creation, + TF_Status* status); + + int DisconnectPool(std::string pool_string); + + int DisconnectContainer(std::string pool_string, std::string cont_string); +}; + +void CopyEntries(char*** entries, std::vector& results); + +class ReadBuffer { + public: + ReadBuffer(size_t id, DFS* daos, daos_handle_t eqh, size_t size); + + ReadBuffer(ReadBuffer&&); + + ~ReadBuffer(); + + bool CacheHit(const size_t pos); + + void WaitEvent(); + + int ReadAsync(dfs_t* dfs, dfs_obj_t* file, const size_t off, + const size_t file_size); + + int CopyData(char* ret, const size_t ret_offset, const size_t offset, + const size_t n); + + int64_t CopyFromCache(char* ret, const size_t ret_offset, const size_t off, + const size_t n, const daos_size_t file_size, + TF_Status* status); + + private: + size_t id; + DFS* daos; + char* buffer; + size_t buffer_offset; + size_t buffer_size; + daos_handle_t eqh; + daos_event_t* event; + d_sg_list_t rsgl; + d_iov_t iov; + daos_size_t read_size; +}; + +#endif // TENSORFLOW_IO_CORE_FILESYSTEMS_DFS_DFS_FILESYSTEM_H_ diff --git a/tensorflow_io/core/filesystems/filesystem_plugins.cc b/tensorflow_io/core/filesystems/filesystem_plugins.cc index 5db773bf3..b9166213c 100644 --- a/tensorflow_io/core/filesystems/filesystem_plugins.cc +++ b/tensorflow_io/core/filesystems/filesystem_plugins.cc @@ -30,6 +30,9 @@ TFIO_PLUGIN_EXPORT void TF_InitPlugin(TF_FilesystemPluginInfo* info) { info->plugin_memory_allocate = tensorflow::io::plugin_memory_allocate; info->plugin_memory_free = tensorflow::io::plugin_memory_free; info->num_schemes = 7; +#if !defined(__APPLE__) && !defined(_MSC_VER) + info->num_schemes = info->num_schemes + 1; +#endif info->ops = static_cast( tensorflow::io::plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0]))); @@ -40,4 +43,7 @@ TFIO_PLUGIN_EXPORT void TF_InitPlugin(TF_FilesystemPluginInfo* info) { tensorflow::io::hdfs::ProvideFilesystemSupportFor(&info->ops[4], "hdfs"); tensorflow::io::hdfs::ProvideFilesystemSupportFor(&info->ops[5], "viewfs"); 
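+  // Register the "daos" scheme only on platforms where the DFS plugin is
+  // built (skipped on macOS and MSVC), matching the num_schemes adjustment
+  // made above.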
tensorflow::io::hdfs::ProvideFilesystemSupportFor(&info->ops[6], "har"); +#if !defined(__APPLE__) && !defined(_MSC_VER) + tensorflow::io::dfs::ProvideFilesystemSupportFor(&info->ops[7], "daos"); +#endif } diff --git a/tensorflow_io/core/filesystems/filesystem_plugins.h b/tensorflow_io/core/filesystems/filesystem_plugins.h index b3e708a37..3ec23cebd 100644 --- a/tensorflow_io/core/filesystems/filesystem_plugins.h +++ b/tensorflow_io/core/filesystems/filesystem_plugins.h @@ -50,6 +50,12 @@ void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, const char* uri); } // namespace s3 +namespace dfs { + +void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, const char* uri); + +} // namespace dfs + } // namespace io } // namespace tensorflow diff --git a/tests/test_dfs.py b/tests/test_dfs.py new file mode 100644 index 000000000..40bfaadd6 --- /dev/null +++ b/tests/test_dfs.py @@ -0,0 +1,184 @@ +""" +Tests for Tensorflow-IO DFS Plugin +""" +import os +import sys +import pytest + +import tensorflow as tf +import tensorflow_io as tfio + +if sys.platform in ["darwin", "win32"]: + pytest.skip("Incompatible", allow_module_level=True) + + +class DFSTest(tf.test.TestCase): + """Test Class for DFS""" + + def __init__(self, methodName="runTest"): # pylint: disable=invalid-name + + self.pool = os.environ["POOL_LABEL"] + self.pool_uuid = os.environ["POOL_UUID"] + self.container = os.environ["CONT_LABEL"] + self.container_uuid = os.environ["CONT_UUID"] + self.path_root = "daos://" + os.path.join(self.pool, self.container) + self.path_root_with_uuid = "daos://" + os.path.join( + self.pool_uuid, self.container_uuid + ) + super().__init__(methodName) + + def _path_to(self, path): + return os.path.join(self.path_root, path) + + def _uuid_path_to(self, path): + return os.path.join(self.path_root_with_uuid, path) + + def test_exists(self): + """Test Root exists""" + self.assertTrue(tf.io.gfile.isdir(self.path_root)) + + def test_create_file(self): + """Test create file.""" + # Setup and check preconditions. + file_name = self._path_to("testfile") + if tf.io.gfile.exists(file_name): + tf.io.gfile.remove(file_name) + # Create file. + with tf.io.gfile.GFile(file_name, "w") as write_file: + write_file.write("") + # Check that file was created. + self.assertTrue(tf.io.gfile.exists(file_name)) + + tf.io.gfile.remove(file_name) + + def test_write_read_file(self): + """Test write/read file.""" + # Setup and check preconditions. + file_name = self._path_to("writereadfile") + if tf.io.gfile.exists(file_name): + tf.io.gfile.remove(file_name) + + # Write data. + with tf.io.gfile.GFile(file_name, "w") as write_file: + write_file.write("Hello\n, world!") + + # Read data. + with tf.io.gfile.GFile(file_name, "r") as read_file: + file_read = read_file.read() + self.assertEqual(file_read, "Hello\n, world!") + + def test_write_read_file_uuid(self): + """Test write/read file.""" + # Setup and check preconditions. + file_name = self._uuid_path_to("writereadfile") + if tf.io.gfile.exists(file_name): + tf.io.gfile.remove(file_name) + + # Write data. + with tf.io.gfile.GFile(file_name, "w") as write_file: + write_file.write("Hello\n, world!") + + # Read data. 
+ with tf.io.gfile.GFile(file_name, "r") as read_file: + file_read = read_file.read() + self.assertEqual(file_read, "Hello\n, world!") + + def test_wildcard_matching(self): + """Test glob patterns""" + dir_name = self._path_to("wildcard") + tf.io.gfile.mkdir(dir_name) + for ext in [".txt", ".md"]: + for i in range(3): + file_path = self._path_to(f"wildcard/{i}{ext}") + with tf.io.gfile.GFile(file_path, "w") as write_file: + write_file.write("") + + txt_files = tf.io.gfile.glob(self._path_to("wildcard/*.txt")) + self.assertEqual(3, len(txt_files)) + for i, name in enumerate(txt_files): + self.assertEqual(self._path_to(f"wildcard/{i}.txt"), name) + tf.io.gfile.rmtree(self._path_to("wildcard")) + + def test_delete_recursively(self): + """Test delete recursively.""" + # Setup and check preconditions. + dir_name = self._path_to("recursive") + file_name = self._path_to("recursive/1") + + tf.io.gfile.mkdir(dir_name) + with tf.io.gfile.GFile(file_name, "w") as write_file: + write_file.write("") + + self.assertTrue(tf.io.gfile.isdir(dir_name)) + self.assertTrue(tf.io.gfile.exists(file_name)) + + # Delete directory recursively. + tf.io.gfile.rmtree(dir_name) + + # Check that directory was deleted. + self.assertFalse(tf.io.gfile.exists(dir_name)) + self.assertFalse(tf.io.gfile.exists(file_name)) + + def test_is_directory(self): + """Test is directory.""" + # Setup and check preconditions. + parent = self._path_to("isdir") + dir_name = self._path_to("isdir/1") + file_name = self._path_to("7.txt") + tf.io.gfile.mkdir(parent) + with tf.io.gfile.GFile(file_name, "w") as write_file: + write_file.write("") + tf.io.gfile.mkdir(dir_name) + # Check that directory is a directory. + self.assertTrue(tf.io.gfile.isdir(dir_name)) + # Check that file is not a directory. + self.assertFalse(tf.io.gfile.isdir(file_name)) + + def test_list_directory(self): + """Test list directory.""" + # Setup and check preconditions. + dir_name = self._path_to("listdir") + tf.io.gfile.mkdir(dir_name) + file_names = [self._path_to(f"listdir/{i}") for i in range(1, 33)] + + for file_name in file_names: + with tf.io.gfile.GFile(file_name, "w") as write_file: + write_file.write("") + # Get list of files in directory. + ls_result = tf.io.gfile.listdir(dir_name) + # Check that list of files is correct. + self.assertEqual(len(file_names), len(ls_result)) + for element in ["1", "2", "3"]: + self.assertTrue(element in ls_result) + + def test_make_dirs(self): + """Test make dirs.""" + # Setup and check preconditions. + dir_name = self.path_root + # Make directory. + tf.io.gfile.mkdir(dir_name) + # Check that directory was created. + self.assertTrue(tf.io.gfile.isdir(dir_name)) + + parent = self._path_to("test") + dir_name = self._path_to("test/directory") + tf.io.gfile.mkdir(parent) + tf.io.gfile.makedirs(dir_name) + self.assertTrue(tf.io.gfile.isdir(dir_name)) + + def test_remove(self): + """Test remove.""" + # Setup and check preconditions. + file_name = self._path_to("file_to_be_removed") + self.assertFalse(tf.io.gfile.exists(file_name)) + with tf.io.gfile.GFile(file_name, "w") as write_file: + write_file.write("") + self.assertTrue(tf.io.gfile.exists(file_name)) + # Remove file. + tf.io.gfile.remove(file_name) + # Check that file was removed. 
+ self.assertFalse(tf.io.gfile.exists(file_name)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tests/test_dfs/dfs_cleanup.sh b/tests/test_dfs/dfs_cleanup.sh new file mode 100755 index 000000000..b025bb499 --- /dev/null +++ b/tests/test_dfs/dfs_cleanup.sh @@ -0,0 +1,2 @@ +#!/usr/bin/env bash +dmg -i pool destroy -f TEST_POOL \ No newline at end of file diff --git a/tests/test_dfs/dfs_init.sh b/tests/test_dfs/dfs_init.sh new file mode 100755 index 000000000..f440804ed --- /dev/null +++ b/tests/test_dfs/dfs_init.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +OUTPUT=$(dmg -i pool create -s 2G TEST_POOL) +export POOL_UUID=`echo -e $OUTPUT | cut -d':' -f 3 | cut -d ' ' -f 2 | xargs` +echo "$POOL_UUID" +export POOL_LABEL='TEST_POOL' +echo "$POOL_LABEL" +OUTPUT=$(daos cont create --pool=TEST_POOL --type=POSIX TEST_CONT) +export CONT_UUID=`echo -e $OUTPUT | cut -d ' ' -f 4 | xargs` +echo "$CONT_UUID" +export CONT_LABEL='TEST_CONT' +echo "$CONT_LABEL" + diff --git a/third_party/daos.BUILD b/third_party/daos.BUILD new file mode 100644 index 000000000..5b56bd05b --- /dev/null +++ b/third_party/daos.BUILD @@ -0,0 +1,32 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "daos", + hdrs = glob( + [ + "src/include/**/*.h", + ], + ) + [ + "src/include/daos_version.h", + ], + copts = [], + includes = ["src/include"], + deps = [ + "@util_linux//:uuid", + ], +) + +genrule( + name = "daos_version_h", + srcs = [ + "src/include/daos_version.h.in", + ], + outs = [ + "src/include/daos_version.h", + ], + cmd = ("sed " + + "-e 's/@TMPL_MAJOR@/2/g' " + + "-e 's/@TMPL_MINOR@/0/g' " + + "-e 's/@TMPL_FIX@/2/g' " + + "$< >$@"), +)
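For reference, a minimal usage sketch of the new daos:// scheme. The pool and container labels below are the TEST_POOL/TEST_CONT values created by tests/test_dfs/dfs_init.sh and are assumptions for illustration, not part of the patch:

```python
import tensorflow as tf
import tensorflow_io  # noqa: F401  # importing registers the daos:// filesystem plugin

# Assumes a POSIX container TEST_CONT in pool TEST_POOL, as created by
# tests/test_dfs/dfs_init.sh.
path = "daos://TEST_POOL/TEST_CONT/example.txt"

# Write a small file through the DFS plugin.
with tf.io.gfile.GFile(path, "w") as f:
    f.write("hello from DAOS")

# Read it back through the same daos:// path.
with tf.io.gfile.GFile(path, "r") as f:
    print(f.read())
```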