Skip to content

Commit

Permalink
Replace tensorflow ST5 embedder with torch based one
Browse files Browse the repository at this point in the history
This will ensure the embedder works with M1-based macs (fixes #37).

PiperOrigin-RevId: 625611229
Change-Id: I6591565e82774e23aa40bd7a33571f8fe6a9a0d5
  • Loading branch information
vezhnick authored and copybara-github committed Apr 17, 2024
1 parent 449cd8f commit d0df6e3
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 85 deletions.
46 changes: 0 additions & 46 deletions concordia/associative_memory/embedder_st5.py

This file was deleted.

5 changes: 3 additions & 2 deletions examples/cyberball/cyberball.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"\n",
"from google.colab import widgets # pytype: disable=import-error\n",
"from IPython import display\n",
"import sentence_transformers\n",
"\n",
"from concordia.agents import basic_agent\n",
"from concordia import components as generic_components\n",
Expand All @@ -79,7 +80,6 @@
"from concordia.document import interactive_document\n",
"from concordia.associative_memory import associative_memory\n",
"from concordia.associative_memory import blank_memories\n",
"from concordia.associative_memory import embedder_st5\n",
"from concordia.associative_memory import formative_memories\n",
"from concordia.associative_memory import importance_function\n",
"from concordia.clocks import game_clock\n",
Expand All @@ -106,7 +106,8 @@
"outputs": [],
"source": [
"# Setup sentence encoder\n",
"embedder = embedder_st5.EmbedderST5()"
"st_model = sentence_transformers.SentenceTransformer('sentence-transformers/all-mpnet-base-v2')\n",
"embedder = lambda x: st_model.encode(x, show_progress_bar=False)\n"
]
},
{
Expand Down
5 changes: 3 additions & 2 deletions examples/magic_beans_for_sale.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"import datetime\n",
"\n",
"import numpy as np\n",
"import sentence_transformers\n",
"\n",
"from google.colab import widgets # pytype: disable=import-error\n",
"from IPython import display\n",
Expand All @@ -77,7 +78,6 @@
"from concordia import components as generic_components\n",
"from concordia.associative_memory import associative_memory\n",
"from concordia.associative_memory import blank_memories\n",
"from concordia.associative_memory import embedder_st5\n",
"from concordia.associative_memory import formative_memories\n",
"from concordia.associative_memory import importance_function\n",
"from concordia.clocks import game_clock\n",
Expand All @@ -102,7 +102,8 @@
"outputs": [],
"source": [
"# Setup sentence encoder\n",
"embedder = embedder_st5.EmbedderST5()"
"st_model = sentence_transformers.SentenceTransformer('sentence-transformers/all-mpnet-base-v2')\n",
"embedder = lambda x: st_model.encode(x, show_progress_bar=False)\n"
]
},
{
Expand Down
5 changes: 3 additions & 2 deletions examples/phone/calendar.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@
"import random\n",
"\n",
"from IPython import display\n",
"import sentence_transformers\n",
"\n",
"from concordia import components as generic_components\n",
"from concordia.components import agent as components\n",
"from concordia.components import game_master as gm_components\n",
"from concordia.agents import basic_agent\n",
"from concordia.associative_memory import associative_memory\n",
"from concordia.associative_memory import blank_memories\n",
"from concordia.associative_memory import embedder_st5\n",
"from concordia.associative_memory import formative_memories\n",
"from concordia.associative_memory import importance_function\n",
"from concordia.clocks import game_clock\n",
Expand All @@ -99,7 +99,8 @@
"outputs": [],
"source": [
"#@title Setup sentence encoder\n",
"embedder = embedder_st5.EmbedderST5()"
"st_model = sentence_transformers.SentenceTransformer('sentence-transformers/all-mpnet-base-v2')\n",
"embedder = lambda x: st_model.encode(x, show_progress_bar=False)\n"
]
},
{
Expand Down
1 change: 1 addition & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
absl-py
docstring-parser
gdm-concordia
sentence_transformers
termcolor
5 changes: 3 additions & 2 deletions examples/three_key_questions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
"import datetime\n",
"\n",
"import numpy as np\n",
"import sentence_transformers\n",
"\n",
"from google.colab import widgets # pytype: disable=import-error\n",
"from IPython import display\n",
Expand All @@ -84,7 +85,6 @@
"from concordia import components as generic_components\n",
"from concordia.associative_memory import associative_memory\n",
"from concordia.associative_memory import blank_memories\n",
"from concordia.associative_memory import embedder_st5\n",
"from concordia.associative_memory import formative_memories\n",
"from concordia.associative_memory import importance_function\n",
"from concordia.clocks import game_clock\n",
Expand All @@ -109,7 +109,8 @@
"outputs": [],
"source": [
"# Setup sentence encoder\n",
"embedder = embedder_st5.EmbedderST5()"
"st_model = sentence_transformers.SentenceTransformer('sentence-transformers/all-mpnet-base-v2')\n",
"embedder = lambda x: st_model.encode(x, show_progress_bar=False)\n"
]
},
{
Expand Down
5 changes: 3 additions & 2 deletions examples/village/day_in_riverbend.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@
"\n",
"from google.colab import widgets # pytype: disable=import-error\n",
"from IPython import display\n",
"import sentence_transformers\n",
"\n",
"from concordia import components as generic_components\n",
"from concordia.agents import basic_agent\n",
"from concordia.components import agent as components\n",
"from concordia.agents import basic_agent\n",
"from concordia.associative_memory import associative_memory\n",
"from concordia.associative_memory import blank_memories\n",
"from concordia.associative_memory import embedder_st5\n",
"from concordia.associative_memory import formative_memories\n",
"from concordia.associative_memory import importance_function\n",
"from concordia.clocks import game_clock\n",
Expand All @@ -105,7 +105,8 @@
"outputs": [],
"source": [
"# @title Setup sentence encoder\n",
"embedder = embedder_st5.EmbedderST5()"
"st_model = sentence_transformers.SentenceTransformer('sentence-transformers/all-mpnet-base-v2')\n",
"embedder = lambda x: st_model.encode(x, show_progress_bar=False)\n"
]
},
{
Expand Down
5 changes: 3 additions & 2 deletions examples/village/riverbend_elections.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,14 @@
"\n",
"from google.colab import widgets # pytype: disable=import-error\n",
"from IPython import display\n",
"import sentence_transformers\n",
"\n",
"from concordia import components as generic_components\n",
"from concordia.components import agent as components\n",
"from concordia.components import game_master as gm_components\n",
"from concordia.agents import basic_agent\n",
"from concordia.associative_memory import associative_memory\n",
"from concordia.associative_memory import blank_memories\n",
"from concordia.associative_memory import embedder_st5\n",
"from concordia.associative_memory import formative_memories\n",
"from concordia.associative_memory import importance_function\n",
"from concordia.clocks import game_clock\n",
Expand All @@ -104,7 +104,8 @@
"outputs": [],
"source": [
"# @title Setup sentence encoder\n",
"embedder = embedder_st5.EmbedderST5()"
"st_model = sentence_transformers.SentenceTransformer('sentence-transformers/all-mpnet-base-v2')\n",
"embedder = lambda x: st_model.encode(x, show_progress_bar=False)\n"
]
},
{
Expand Down
44 changes: 17 additions & 27 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,8 @@

"""Install script for setuptools."""

import platform
import setuptools

IS_M1_OSX = platform.system() == 'Darwin' and platform.machine() == 'arm64'

REQUIREMENTS = (
# TODO: b/312199199 - remove some requirements.
'absl-py',
'google-cloud-aiplatform',
'ipython',
'langchain',
'matplotlib',
'numpy',
'openai>=1.3.0',
'pandas<=2.0.3',
'python-dateutil',
'reactivex',
'retry',
'scipy',
'tensorflow',
'tensorflow-hub',
'tensorflow-text',
'termcolor',
'typing-extensions',
)
M1_OSX_REQUIREMENTS = tuple(set(REQUIREMENTS) - {'tensorflow-text'})


setuptools.setup(
name='gdm-concordia',
version='1.2.0',
Expand Down Expand Up @@ -75,7 +49,23 @@
packages=setuptools.find_packages(include=['concordia', 'concordia.*']),
package_data={},
python_requires='>=3.11',
install_requires=M1_OSX_REQUIREMENTS if IS_M1_OSX else REQUIREMENTS,
install_requires=(
# TODO: b/312199199 - remove some requirements.
'absl-py',
'google-cloud-aiplatform',
'ipython',
'langchain',
'matplotlib',
'numpy',
'openai>=1.3.0',
'pandas<=2.0.3',
'python-dateutil',
'reactivex',
'retry',
'scipy',
'termcolor',
'typing-extensions',
),
extras_require={
# Used in development.
'dev': [
Expand Down

0 comments on commit d0df6e3

Please sign in to comment.