-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add pre-commit hooks and blacken everything.
- Loading branch information
Showing
22 changed files
with
285 additions
and
236 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
repos: | ||
|
||
- repo: https://github.com/psf/black | ||
rev: 23.7.0 # update with `pre-commit autoupdate` | ||
hooks: | ||
- id: black | ||
language_version: python3 # Should be a command that runs python3.6+ | ||
files: ^(tests|dallinger|dallinger_scripts|demos)/|setup.py | ||
|
||
- repo: https://github.com/PyCQA/flake8 | ||
rev: '6.0.0' | ||
hooks: | ||
- id: flake8 | ||
- repo: https://github.com/pycqa/isort | ||
rev: 5.12.0 | ||
hooks: | ||
- id: isort | ||
args: ["--profile", "black", "--filter-files"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,6 @@ | ||
from dallinger.experiments import Griduniverse | ||
from bams.learners import ActiveLearner | ||
from bams.query_strategies import ( | ||
BALD, | ||
HyperCubePool, | ||
RandomStrategy, | ||
) | ||
from bams.query_strategies import BALD | ||
from dallinger.experiments import Griduniverse | ||
|
||
NDIM = 1 | ||
POOL_SIZE = 500 | ||
|
@@ -14,40 +10,45 @@ | |
|
||
collected_data = {} | ||
|
||
|
||
def num_colors(x): | ||
"""x is the fraction of the total players who are on a single team""" | ||
return int(round(1.0 / x)) | ||
|
||
|
||
def closest_valid_x(x): | ||
x = x[0] | ||
x = max(1.0/6.0, x) # 1/6 is the lowest valid value as we have at most 6 teams | ||
x = max(1.0 / 6.0, x) # 1/6 is the lowest valid value as we have at most 6 teams | ||
num_teams = num_colors(x) | ||
return (1.0 / num_teams, ) | ||
return (1.0 / num_teams,) | ||
|
||
|
||
def scale_up(threshold, dim): | ||
"""Rescale up to actual values""" | ||
out = int(dim * threshold) | ||
return out | ||
|
||
|
||
def scale_down(threshold, dim): | ||
"""Rescale 0 =< output =< 1""" | ||
out = float(dim/threshold) if threshold else 0.0 | ||
out = float(dim / threshold) if threshold else 0.0 | ||
return out | ||
|
||
|
||
def oracle(x): | ||
"""Run a GU game by scaling up the features so they can be input into the game. | ||
Then scale them done so the active learner can understand them. | ||
""" | ||
grid_config = { | ||
"mode": u'live', | ||
"recruiter": u'mturk', | ||
"bot_policy": u"AdvantageSeekingBot", | ||
u'contact_email_on_error': u"[email protected]", | ||
"contact_email_on_error": u"[email protected]", | ||
u'organization_name': u'UC Berkeley', | ||
u'description': u'Play an interactive game', | ||
"dyno_type": u"performance-l", | ||
"redis_size": u"premium-5", | ||
"mode": "live", | ||
"recruiter": "mturk", | ||
"bot_policy": "AdvantageSeekingBot", | ||
"contact_email_on_error": "[email protected]", | ||
"contact_email_on_error": "[email protected]", | ||
"organization_name": "UC Berkeley", | ||
"description": "Play an interactive game", | ||
"dyno_type": "performance-l", | ||
"redis_size": "premium-5", | ||
"num_dynos_worker": 4, | ||
"num_dynos_web": 1, | ||
"max_participants": 12, | ||
|
@@ -67,6 +68,7 @@ def oracle(x): | |
collected_data[data.source] = grid_config | ||
return results | ||
|
||
|
||
def main(): | ||
learner = ActiveLearner( | ||
query_strategy=BALD(dim=NDIM), | ||
|
@@ -80,13 +82,15 @@ def main(): | |
x = learner.next_query() | ||
x = closest_valid_x(x) | ||
y = learner.query(oracle, x) | ||
print x, y | ||
print((x, y)) | ||
learner.update(x, y) | ||
print(learner.posteriors) | ||
print(learner.map_model) | ||
import pdb; pdb.set_trace() | ||
print collected_data | ||
import pdb | ||
|
||
pdb.set_trace() | ||
print(collected_data) | ||
|
||
|
||
if __name__ == '__main__': | ||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.