Skip to content

Commit

Permalink
Warm start a strategy based on config provided seed conditions (facebo…
Browse files Browse the repository at this point in the history
…okresearch#487)

Summary:
Pull Request resolved: facebookresearch#487

There is a desire to warm start a strategy with data filtered on master table values recorded from previous experiments.

Data can be filtered on any data other than master.unique_id. This means that the below are all valid filter criteria:
- experiment_name
- experiment_description
- experiment_id
- participant_id
- anything stored in extra_metadata

Criteria will follow AND logic between fields, but inclusive OR logic within the same field.

Data will be further filtered out if it does not meet these criteria:
- Each parameter matches on name and type
- Each outcome matches on name and type
- Stimuli per trial are the same

allow-large-files

Differential Revision: D66731597
  • Loading branch information
unrealdev12 authored and facebook-github-bot committed Dec 20, 2024
1 parent 63f22b1 commit a78ee13
Show file tree
Hide file tree
Showing 10 changed files with 840 additions and 22 deletions.
24 changes: 18 additions & 6 deletions aepsych/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,14 @@ def to_dict(self, deduplicate: bool = True) -> Dict[str, Any]:
_dict[section][setting] = self[section][setting]
return _dict

# Turn the metadata section into JSON.
def jsonifyMetadata(self, only_extra: bool = False) -> str:
"""Return a json string of the metadata section.
def get_metadata(self, only_extra: bool = False) -> Dict[Any, Any]:
"""Return a dictionary of the metadata section.
Args:
only_extra (bool): Only jsonify the extra meta data.
only_extra (bool, optional): Only gather the extra metadata. Defaults to False.
Returns:
str: A json string representing the metadata dictionary or an empty string
if there is no metadata to return.
Dict[Any, Any]: a collection of the metadata stored in this config.
"""
configdict = self.to_dict()
metadata = configdict["metadata"].copy()
Expand All @@ -172,7 +170,21 @@ def jsonifyMetadata(self, only_extra: bool = False) -> str:
]
for name in default_metadata:
metadata.pop(name, None)

return metadata

# Turn the metadata section into JSON.
def jsonifyMetadata(self, only_extra: bool = False) -> str:
"""Return a json string of the metadata section.
Args:
only_extra (bool): Only jsonify the extra meta data.
Returns:
str: A json string representing the metadata dictionary or an empty string
if there is no metadata to return.
"""
metadata = self.get_metadata(only_extra)
if len(metadata.keys()) == 0:
return ""
else:
Expand Down
475 changes: 475 additions & 0 deletions aepsych/database/data_fetcher.py

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions aepsych/database/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def execute_sql_query(self, query: str, vals: Dict[str, str]) -> List[Any]:
List[Any]: The results of the query.
"""
with self.session_scope() as session:
return session.execute(query, vals).fetchall()
return session.execute(query, vals).all()

def get_master_records(self) -> List[tables.DBMasterTable]:
"""Grab the list of master records.
Expand Down Expand Up @@ -239,14 +239,14 @@ def get_raw_for(self, master_id: int) -> Optional[List[tables.DbRawTable]]:

return None

def get_all_params_for(self, master_id: int) -> Optional[List[tables.DbRawTable]]:
def get_all_params_for(self, master_id: int) -> Optional[List[tables.DbParamTable]]:
"""Get the parameters for all the iterations of a specific experiment.
Args:
master_id (int): The master id.
Returns:
List[tables.DbRawTable] or None: The parameters or None if they don't exist.
List[tables.DbParamTable] or None: The parameters or None if they don't exist.
"""
warnings.warn(
"get_all_params_for is the same as get_param_for since there can only be one instance of any master_id",
Expand All @@ -266,14 +266,14 @@ def get_all_params_for(self, master_id: int) -> Optional[List[tables.DbRawTable]

return None

def get_param_for(self, master_id: int) -> Optional[List[tables.DbRawTable]]:
def get_param_for(self, master_id: int) -> Optional[List[tables.DbParamTable]]:
"""Get the parameters for a specific iteration of a specific experiment.
Args:
master_id (int): The master id.
Returns:
List[tables.DbRawTable] or None: The parameters or None if they don't exist.
List[tables.DbParamTable] or None: The parameters or None if they don't exist.
"""
raw_record = self.get_raw_for(master_id)

Expand All @@ -284,14 +284,14 @@ def get_param_for(self, master_id: int) -> Optional[List[tables.DbRawTable]]:

return None

def get_all_outcomes_for(self, master_id: int) -> Optional[List[tables.DbRawTable]]:
def get_all_outcomes_for(self, master_id: int) -> Optional[List[tables.DbOutcomeTable]]:
"""Get the outcomes for all the iterations of a specific experiment.
Args:
master_id (int): The master id.
Returns:
List[tables.DbRawTable] or None: The outcomes or None if they don't exist.
List[tables.DbOutcomeTable] or None: The outcomes or None if they don't exist.
"""
warnings.warn(
"get_all_outcomes_for is the same as get_outcome_for since there can only be one instance of any master_id",
Expand All @@ -311,14 +311,14 @@ def get_all_outcomes_for(self, master_id: int) -> Optional[List[tables.DbRawTabl

return None

def get_outcome_for(self, master_id: int) -> Optional[List[tables.DbRawTable]]:
def get_outcome_for(self, master_id: int) -> Optional[List[tables.DbOutcomeTable]]:
"""Get the outcomes for a specific iteration of a specific experiment.
Args:
master_id (int): The master id.
Returns:
List[tables.DbRawTable] or None: The outcomes or None if they don't exist.
List[tables.DbOutcomeTable] or None: The outcomes or None if they don't exist.
"""
raw_record = self.get_raw_for(master_id)

Expand Down
11 changes: 7 additions & 4 deletions aepsych/server/message_handlers/handle_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from aepsych.config import Config
from aepsych.strategy import SequentialStrategy
from aepsych.version import __version__
from aepsych.database.data_fetcher import DataFetcher

logger = utils_logging.getLogger(logging.INFO)

Expand All @@ -33,6 +34,10 @@ def _configure(server, config):
server.strat = SequentialStrategy.from_config(config)
server.strat_id = server.n_strats - 1 # 0-index strats

for strat in server.strat.strat_list:
fetcher = DataFetcher.from_config(config, strat.name)
fetcher.warm_start_strat(server, strat)

return server.strat_id


Expand Down Expand Up @@ -88,13 +93,11 @@ def handle_setup(server, request):
par_id = tempconfig["metadata"].get("participant_id", fallback=None)

# This may be populated when replaying
if server._db_master_record is not None:
exp_id = tempconfig["metadata"].get("experiment_id", fallback=None)
if exp_id is None and server._db_master_record is not None:
exp_id = server._db_master_record.experiment_id
else:
exp_id = tempconfig["metadata"].get("experiment_id", fallback=None)

extra_metadata = tempconfig.jsonifyMetadata(only_extra=True)

server._db_master_record = server.db.record_setup(
description=exp_desc,
name=exp_name,
Expand Down
2 changes: 1 addition & 1 deletion aepsych/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, socket=None, database_path=None):
self.exit_server_loop = False
self._db_master_record = None
self._db_raw_record = None
self.db = db.Database(database_path)
self.db: db.Database = db.Database(database_path)
self.skip_computations = False
self.strat_names = None

Expand Down
15 changes: 15 additions & 0 deletions aepsych/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,21 @@ def n_trials(self) -> int:
)
return self.min_asks

def pre_warm_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
    """Seed the strategy with previously collected data points.

    The points are passed through ``normalize_inputs`` and the first two
    returned values are stored as the strategy's ``x`` and ``y``. The third
    returned value is deliberately discarded so that warm start data does
    not affect the trials run length.

    Args:
        x (torch.Tensor): The input data points.
        y (torch.Tensor): The output data points.
    """
    # Warming the model shouldn't affect strategy.n, so the third value is dropped.
    normalized_x, normalized_y, _ = self.normalize_inputs(x, y)
    self.x = normalized_x
    self.y = normalized_y
    # New data has been attached, so the model is no longer marked fresh.
    self._model_is_fresh = False

def add_data(
self, x: Union[np.ndarray, torch.Tensor], y: Union[np.ndarray, torch.Tensor]
) -> None:
Expand Down
6 changes: 6 additions & 0 deletions aepsych/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,9 @@ def get_dims(config: Config) -> int:
except KeyError:
# Likely old style of parameter definition, fallback to looking at a bound
return len(config.getlist("common", "lb", element_type=float))

def generate_default_outcome_names(count: int) -> List[str]:
    """Build the default list of outcome names for ``count`` outcomes.

    Args:
        count (int): Number of outcomes to name.

    Returns:
        List[str]: ``["outcome"]`` when there is exactly one outcome, otherwise
            ``["outcome_0", ..., "outcome_{count - 1}"]``. The list is empty
            when ``count`` is zero or negative.
    """
    if count == 1:
        return ["outcome"]

    # Format the index into the name: "outcome_" + i raises TypeError because
    # a str and an int cannot be concatenated.
    return [f"outcome_{i}" for i in range(count)]
9 changes: 7 additions & 2 deletions tests/server/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,18 @@


class BaseServerTestCase(unittest.TestCase):
# Exposed as a property so that tests requiring a specific database can override it.
@property
def database_path(self):
    """Return a randomly named database path for this test case.

    A new UUID is generated on every access, so callers should read the
    property once and reuse the value.
    """
    unique_stem = uuid.uuid4().hex
    return f"./{unique_stem}_test_server.db"

def setUp(self):
# setup logger
server.logger = utils_logging.getLogger(logging.DEBUG, "logs")
# random port
socket = server.sockets.PySocket(port=0)
# random database path name without dashes
database_path = "./{}_test_server.db".format(str(uuid.uuid4().hex))
database_path = self.database_path
self.s = server.AEPsychServer(socket=socket, database_path=database_path)
self.db_name = database_path.split("/")[1]
self.db_path = database_path
Expand All @@ -88,7 +93,7 @@ def dummy_create_setup(self, server, request=None):
)


class ServerTestCase(BaseServerTestCase):
class ServerTestCase(BaseServerTestCase):
def test_final_strat_serialization(self):
setup_request = {
"type": "setup",
Expand Down
Binary file added tests/test_databases/1000_outcome.db
Binary file not shown.
Loading

0 comments on commit a78ee13

Please sign in to comment.