Skip to content

Commit

Permalink
Update TabTransformer example for Keras 3. (keras-team#1581)
Browse files Browse the repository at this point in the history
  • Loading branch information
hertschuh authored Nov 8, 2023
1 parent bd69633 commit f54c41f
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 192 deletions.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
128 changes: 69 additions & 59 deletions examples/structured_data/ipynb/tabtransformer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,7 @@
"The Transformer layers transform the embeddings of categorical features\n",
"into robust contextual embeddings to achieve higher predictive accuracy.\n",
"\n",
"This example should be run with TensorFlow 2.7 or higher,\n",
"as well as [TensorFlow Addons](https://www.tensorflow.org/addons/overview),\n",
"which can be installed using the following command:\n",
"\n",
"```python\n",
"pip install -U tensorflow-addons\n",
"```\n",
"\n",
"## Setup"
]
Expand All @@ -48,14 +42,16 @@
},
"outputs": [],
"source": [
"import keras\n",
"from keras import layers\n",
"from keras import ops\n",
"\n",
"import math\n",
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import tensorflow_addons as tfa\n",
"import matplotlib.pyplot as plt"
"from tensorflow import data as tf_data\n",
"import matplotlib.pyplot as plt\n",
"from functools import partial"
]
},
{
Expand Down Expand Up @@ -289,19 +285,43 @@
" return features, target_index, weights\n",
"\n",
"\n",
"lookup_dict = {}\n",
"for feature_name in CATEGORICAL_FEATURE_NAMES:\n",
" vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]\n",
" # Create a lookup to convert string values to integer indices.\n",
" # Since we are not using a mask token, nor expecting any out of vocabulary\n",
" # (oov) token, we set mask_token to None and num_oov_indices to 0.\n",
" lookup = layers.StringLookup(\n",
" vocabulary=vocabulary, mask_token=None, num_oov_indices=0\n",
" )\n",
" lookup_dict[feature_name] = lookup\n",
"\n",
"\n",
"def encode_categorical(batch_x, batch_y, weights):\n",
" for feature_name in CATEGORICAL_FEATURE_NAMES:\n",
" batch_x[feature_name] = lookup_dict[feature_name](batch_x[feature_name])\n",
"\n",
" return batch_x, batch_y, weights\n",
"\n",
"\n",
"def get_dataset_from_csv(csv_file_path, batch_size=128, shuffle=False):\n",
" dataset = tf.data.experimental.make_csv_dataset(\n",
" csv_file_path,\n",
" batch_size=batch_size,\n",
" column_names=CSV_HEADER,\n",
" column_defaults=COLUMN_DEFAULTS,\n",
" label_name=TARGET_FEATURE_NAME,\n",
" num_epochs=1,\n",
" header=False,\n",
" na_value=\"?\",\n",
" shuffle=shuffle,\n",
" ).map(prepare_example, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)\n",
" return dataset.cache()\n"
" dataset = (\n",
" tf_data.experimental.make_csv_dataset(\n",
" csv_file_path,\n",
" batch_size=batch_size,\n",
" column_names=CSV_HEADER,\n",
" column_defaults=COLUMN_DEFAULTS,\n",
" label_name=TARGET_FEATURE_NAME,\n",
" num_epochs=1,\n",
" header=False,\n",
" na_value=\"?\",\n",
" shuffle=shuffle,\n",
" )\n",
" .map(prepare_example, num_parallel_calls=tf_data.AUTOTUNE, deterministic=False)\n",
" .map(encode_categorical)\n",
" )\n",
" return dataset.cache()\n",
""
]
},
{
Expand Down Expand Up @@ -331,8 +351,7 @@
" weight_decay,\n",
" batch_size,\n",
"):\n",
"\n",
" optimizer = tfa.optimizers.AdamW(\n",
" optimizer = keras.optimizers.AdamW(\n",
" learning_rate=learning_rate, weight_decay=weight_decay\n",
" )\n",
"\n",
Expand All @@ -355,7 +374,8 @@
"\n",
" print(f\"Validation accuracy: {round(accuracy * 100, 2)}%\")\n",
"\n",
" return history\n"
" return history\n",
""
]
},
{
Expand Down Expand Up @@ -385,13 +405,14 @@
" for feature_name in FEATURE_NAMES:\n",
" if feature_name in NUMERIC_FEATURE_NAMES:\n",
" inputs[feature_name] = layers.Input(\n",
" name=feature_name, shape=(), dtype=tf.float32\n",
" name=feature_name, shape=(), dtype=\"float32\"\n",
" )\n",
" else:\n",
" inputs[feature_name] = layers.Input(\n",
" name=feature_name, shape=(), dtype=tf.string\n",
" name=feature_name, shape=(), dtype=\"float32\"\n",
" )\n",
" return inputs\n"
" return inputs\n",
""
]
},
{
Expand All @@ -417,45 +438,34 @@
"source": [
"\n",
"def encode_inputs(inputs, embedding_dims):\n",
"\n",
" encoded_categorical_feature_list = []\n",
" numerical_feature_list = []\n",
"\n",
" for feature_name in inputs:\n",
" if feature_name in CATEGORICAL_FEATURE_NAMES:\n",
"\n",
" # Get the vocabulary of the categorical feature.\n",
" vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]\n",
"\n",
" # Create a lookup to convert string values to integer indices.\n",
" # Since we are not using a mask token nor expecting any out of vocabulary\n",
" # (oov) token, we set mask_token to None and num_oov_indices to 0.\n",
" lookup = layers.StringLookup(\n",
" vocabulary=vocabulary,\n",
" mask_token=None,\n",
" num_oov_indices=0,\n",
" output_mode=\"int\",\n",
" )\n",
" # Create a lookup to convert string values to integer indices.\n",
" # Since we are not using a mask token, nor expecting any out of vocabulary\n",
" # (oov) token, we set mask_token to None and num_oov_indices to 0.\n",
"\n",
" # Convert the string input values into integer indices.\n",
" encoded_feature = lookup(inputs[feature_name])\n",
"\n",
" # Create an embedding layer with the specified dimensions.\n",
" embedding = layers.Embedding(\n",
" input_dim=len(vocabulary), output_dim=embedding_dims\n",
" )\n",
"\n",
" # Convert the index values to embedding representations.\n",
" encoded_categorical_feature = embedding(encoded_feature)\n",
" encoded_categorical_feature = embedding(inputs[feature_name])\n",
" encoded_categorical_feature_list.append(encoded_categorical_feature)\n",
"\n",
" else:\n",
"\n",
" # Use the numerical features as-is.\n",
" numerical_feature = tf.expand_dims(inputs[feature_name], -1)\n",
" numerical_feature = ops.expand_dims(inputs[feature_name], -1)\n",
" numerical_feature_list.append(numerical_feature)\n",
"\n",
" return encoded_categorical_feature_list, numerical_feature_list\n"
" return encoded_categorical_feature_list, numerical_feature_list\n",
""
]
},
{
Expand All @@ -477,14 +487,14 @@
"source": [
"\n",
"def create_mlp(hidden_units, dropout_rate, activation, normalization_layer, name=None):\n",
"\n",
" mlp_layers = []\n",
" for units in hidden_units:\n",
" mlp_layers.append(normalization_layer),\n",
" mlp_layers.append(normalization_layer()),\n",
" mlp_layers.append(layers.Dense(units, activation=activation))\n",
" mlp_layers.append(layers.Dropout(dropout_rate))\n",
"\n",
" return keras.Sequential(mlp_layers, name=name)\n"
" return keras.Sequential(mlp_layers, name=name)\n",
""
]
},
{
Expand All @@ -510,7 +520,6 @@
"def create_baseline_model(\n",
" embedding_dims, num_mlp_blocks, mlp_hidden_units_factors, dropout_rate\n",
"):\n",
"\n",
" # Create model inputs.\n",
" inputs = create_model_inputs()\n",
" # encode features.\n",
Expand All @@ -530,7 +539,7 @@
" hidden_units=feedforward_units,\n",
" dropout_rate=dropout_rate,\n",
" activation=keras.activations.gelu,\n",
" normalization_layer=layers.LayerNormalization(epsilon=1e-6),\n",
" normalization_layer=layers.LayerNormalization,\n",
" name=f\"feedforward_{layer_idx}\",\n",
" )(features)\n",
"\n",
Expand All @@ -543,7 +552,7 @@
" hidden_units=mlp_hidden_units,\n",
" dropout_rate=dropout_rate,\n",
" activation=keras.activations.selu,\n",
" normalization_layer=layers.BatchNormalization(),\n",
" normalization_layer=layers.BatchNormalization,\n",
" name=\"MLP\",\n",
" )(features)\n",
"\n",
Expand Down Expand Up @@ -644,15 +653,14 @@
" dropout_rate,\n",
" use_column_embedding=False,\n",
"):\n",
"\n",
" # Create model inputs.\n",
" inputs = create_model_inputs()\n",
" # encode features.\n",
" encoded_categorical_feature_list, numerical_feature_list = encode_inputs(\n",
" inputs, embedding_dims\n",
" )\n",
" # Stack categorical feature embeddings for the Transformer.\n",
" encoded_categorical_features = tf.stack(encoded_categorical_feature_list, axis=1)\n",
" encoded_categorical_features = ops.stack(encoded_categorical_feature_list, axis=1)\n",
" # Concatenate numerical features.\n",
" numerical_features = layers.concatenate(numerical_feature_list)\n",
"\n",
Expand All @@ -662,7 +670,7 @@
" column_embedding = layers.Embedding(\n",
" input_dim=num_columns, output_dim=embedding_dims\n",
" )\n",
" column_indices = tf.range(start=0, limit=num_columns, delta=1)\n",
" column_indices = ops.arange(start=0, stop=num_columns, step=1)\n",
" encoded_categorical_features = encoded_categorical_features + column_embedding(\n",
" column_indices\n",
" )\n",
Expand All @@ -687,7 +695,9 @@
" hidden_units=[embedding_dims],\n",
" dropout_rate=dropout_rate,\n",
" activation=keras.activations.gelu,\n",
" normalization_layer=layers.LayerNormalization(epsilon=1e-6),\n",
" normalization_layer=partial(\n",
" layers.LayerNormalization, epsilon=1e-6\n",
" ), # using partial to provide keyword arguments before initialization\n",
" name=f\"feedforward_{block_idx}\",\n",
" )(x)\n",
" # Skip connection 2.\n",
Expand All @@ -713,7 +723,7 @@
" hidden_units=mlp_hidden_units,\n",
" dropout_rate=dropout_rate,\n",
" activation=keras.activations.selu,\n",
" normalization_layer=layers.BatchNormalization(),\n",
" normalization_layer=layers.BatchNormalization,\n",
" name=\"MLP\",\n",
" )(features)\n",
"\n",
Expand Down Expand Up @@ -794,7 +804,7 @@
"\n",
"| Trained Model | Demo |\n",
"| :--: | :--: |\n",
"| [![Generic badge](https://img.shields.io/badge/🤗%20Model-TabTransformer-black.svg)](https://huggingface.co/keras-io/tab_transformer) | [![Generic badge](https://img.shields.io/badge/🤗%20Spaces-TabTransformer-black.svg)](https://huggingface.co/spaces/keras-io/TabTransformer_Classification) |"
"| [![Generic badge](https://img.shields.io/badge/\ud83e\udd17%20Model-TabTransformer-black.svg)](https://huggingface.co/keras-io/tab_transformer) | [![Generic badge](https://img.shields.io/badge/\ud83e\udd17%20Spaces-TabTransformer-black.svg)](https://huggingface.co/spaces/keras-io/TabTransformer_Classification) |"
]
}
],
Expand Down
Loading

0 comments on commit f54c41f

Please sign in to comment.