Skip to content

Commit

Permalink
Merge pull request #1 from hicsail/restructure-test-data
Browse files Browse the repository at this point in the history
Restructure test data
  • Loading branch information
vpsx authored Sep 16, 2024
2 parents 8b4797e + 027a657 commit dbf5d2e
Show file tree
Hide file tree
Showing 3 changed files with 440 additions and 102 deletions.
285 changes: 184 additions & 101 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from bson.objectid import ObjectId
from contextlib import asynccontextmanager
from typing import Annotated, Union
from zipfile import ZipFile

from fastapi import FastAPI, File, Response, status
from fastapi.responses import FileResponse, JSONResponse
Expand Down Expand Up @@ -66,90 +67,117 @@ class Study(BaseModel):

class VisualPairedAssociatesResult(BaseModel):
"""
- vpa_split_times: Time-to-answer per picture question. Should be length <=20. In... milliseconds? Will put Int for now.
- vpa_split_scores: Correct-or-wrong per picture question. Should be length <=20.
- vpa_total_score: Number correct out of 20.
Represents the participant's response to one question in a Visual Paired Associates test.
(The split lists will be length <20 when the participant times out before finishing all 20 questions.)
- vpa_rt: Time-to-answer, in milliseconds.
- vpa_correct: True if participant answered correctly, false otherwise.
- vpa_response: Participant's response (image filename).
A Test of type VISUAL_PAIRED_ASSOCIATES should have a list of VisualPairedAssociatesResults of length <=20; there are 20 questions to a test, but the participant may time out before finishing.
"""

vpa_split_times: list[int]
vpa_split_scores: list[bool]
vpa_total_score: int
vpa_rt: int
vpa_correct: bool
vpa_response: str


class ChoiceReactionTimeResult(BaseModel):
"""
- crt_split_times: Reaction-time per question. Variable length. In... milliseconds? Will put Int for now.
- crt_split_scores: Correct-or-wrong per question. Variable length.
- crt_lefthand_correct: Number of times <left> was the correct answer and participant hit <left>.
- crt_lefthand_attempted: Number of times <left> was the correct answer and participant answered something.
- crt_righthand_correct: Number of times <right> was the correct answer and participant hit <right>.
- crt_righthand_attempted: Number of times <right> was the correct answer and participant answered something.
- crt_total_correct: Number of times participant answered correctly.
- crt_total_attempted: Number of times participant answered something.
(The participant will time out after 90 seconds, so the "attempted" counts will vary, and the split lists will vary in length.)
Represents the participant's response to one question in a Choice Reaction Time test.
- crt_rt: Reaction time, in milliseconds.
- crt_correct: True if participant answered correctly, false otherwise.
- crt_response: Participant's response ("right" or "left").
- crt_dwell: Length of time the response key was pressed/held, in milliseconds.
A Test of type CHOICE_REACTION_TIME should have a list of ChoiceReactionTimeResults, the length of which will vary according to how many questions the participant attempts in the 90 seconds allotted.
"""

crt_split_times: list[int]
crt_split_scores: list[bool]
crt_lefthand_correct: int
crt_lefthand_attempted: int
crt_righthand_correct: int
crt_righthand_attempted: int
crt_total_correct: int
crt_total_attempted: int
class RightOrLeft(enum.StrEnum):
RIGHT = enum.auto()
LEFT = enum.auto()

crt_rt: int
crt_correct: bool
crt_response: RightOrLeft
crt_dwell: int


class DigitSymbolMatchingResult(BaseModel):
"""
- dsm_split_times: Reaction-time per question. Variable length. In... milliseconds? Will put Int for now.
- dsm_split_scores: Correct-or-wrong per question. Variable length.
- dsm_total_correct: Number of times participant answered correctly.
- dsm_total_attempted: Number of times participant answered something.
Represents the participant's response to one question in a Digit Symbol Matching test.
- dsm_rt: Reaction time, in milliseconds.
- dsm_correct: True if participant answered correctly, false otherwise.
- dsm_response: Participant's response (1, 2, or 3).
(The participant will time out after 90 seconds, so the "attempted" count will vary, and the split lists will vary in length.)
A Test of type DIGIT_SYMBOL_MATCHING should have a list of DigitSymbolMatchingResults, the length of which will vary according to how many questions the participant attempts in the 90 seconds allotted.
"""

dsm_split_times: list[int]
dsm_split_scores: list[bool]
dsm_total_correct: int
dsm_total_attempted: int
class OneTwoOrThree(enum.IntEnum):
ONE = 1
TWO = 2
THREE = 3

dsm_rt: int
dsm_correct: bool
dsm_response: OneTwoOrThree


class ImmediateRecallResult(BaseModel):
"""
- ir_split_times: Time-to-answer per attempt. Variable length of 1 or 2. In... milliseconds? Will put Int for now.
- ir_score: 2 pts if correct on first attempt, 1 pt if on second attempt, 0 pts if failed both attempts.
Represents the participant's response to the sole question in an Immediate Recall test, at which they get two attempts.
- ir_rt_first: Time-to-answer for first attempt, in milliseconds.
- ir_rt_second: Time-to-answer for second attempt, in milliseconds. Optional (use only when second attempt was made).
- ir_score: 2 pts if correct on first attempt, 1 pt if correct on second attempt, 0 pts if failed both attempts.
A Test of type IMMEDIATE_RECALL should have one single ImmediateRecallResult.
"""

ir_split_times: list[int]
ir_score: int
class ZeroOneOrTwo(enum.IntEnum):
ZERO = 0
ONE = 1
TWO = 2

ir_rt_first: int
ir_rt_second: Union[int, None] = None
ir_score: ZeroOneOrTwo


class DelayedRecallResult(BaseModel):
"""
- dr_time: Time-to-answer for the one attempt. In... milliseconds? Will put Int for now.
Represents the participant's response to the sole question in a Delayed Recall test.
- dr_rt: Time-to-answer, in milliseconds.
- dr_score: 0-5 pts corresponding to number of animals correctly recalled.
A Test of type DELAYED_RECALL should have one single DelayedRecallResult.
"""

dr_time: int
dr_score: int
class OneToFive(enum.IntEnum):
ONE = 1
TWO = 2
THREE = 3
FOUR = 4
FIVE = 5

dr_rt: int
dr_score: OneToFive


class SpatialMemoryResult(BaseModel):
"""
- sm_split_times: Time-to-answer per puzzle. Variable length <=5. In... milliseconds? Will put Int for now.
- sm_split_scores: Correct-or-wrong per puzzle. Variable length <=5.
- sm_total_correct: 0-5 pts.
Represents the participant's response to one question in a Spatial Memory test.
(No total-attempted field because this may be fewer than 5 if participant timed out, but never more than 5.)
- sm_rt: Time-to-answer, in milliseconds.
- sm_correct: True if participant answered correctly, false otherwise.
A Test of type SPATIAL_MEMORY should have a list of SpatialMemoryResults of length 0 to 5; there are 5 questions to a test, but the participant may time out before finishing.
"""

sm_split_times: list[int]
sm_split_scores: list[bool]
sm_total_correct: int
sm_rt: int
sm_correct: bool


class TestIn(BaseModel):
Expand All @@ -159,23 +187,27 @@ class TestIn(BaseModel):

study_id: str
time_started: datetime.datetime
time_elapsed_milliseconds: int # or timedelta, if desired...
device_info: str
# Optional field for potential msgs like "participant timed out" (added by frontend/client) or
# "could not find study" (added by backend/server) or any other such.
# (So the server - this codebase - should append to this field, not replace it.)
notes: Union[str, None] = None
result: Union[
VisualPairedAssociatesResult,
ChoiceReactionTimeResult,
DigitSymbolMatchingResult,
list[VisualPairedAssociatesResult],
list[ChoiceReactionTimeResult],
list[DigitSymbolMatchingResult],
ImmediateRecallResult,
DelayedRecallResult,
SpatialMemoryResult,
None,
] = None
list[SpatialMemoryResult],
]


class Test(TestIn):
"""
The subset of Test fields that this server provides.
"""

# Let DB maintain _id field; manage test_id separately.
# No strong reason except it slightly simplifies CSV export by removing the need to rename the field.
test_id: str
Expand All @@ -186,6 +218,25 @@ class ErrorMessage(BaseModel):
message: str


test_type_to_result_type = {
TestType.IMMEDIATE_RECALL: ImmediateRecallResult,
TestType.DELAYED_RECALL: DelayedRecallResult,
TestType.CHOICE_REACTION_TIME: ChoiceReactionTimeResult,
TestType.VISUAL_PAIRED_ASSOCIATES: VisualPairedAssociatesResult,
TestType.DIGIT_SYMBOL_MATCHING: DigitSymbolMatchingResult,
TestType.SPATIAL_MEMORY: SpatialMemoryResult,
}

result_type_to_test_type = {
ImmediateRecallResult: TestType.IMMEDIATE_RECALL,
DelayedRecallResult: TestType.DELAYED_RECALL,
ChoiceReactionTimeResult: TestType.CHOICE_REACTION_TIME,
VisualPairedAssociatesResult: TestType.VISUAL_PAIRED_ASSOCIATES,
DigitSymbolMatchingResult: TestType.DIGIT_SYMBOL_MATCHING,
SpatialMemoryResult: TestType.SPATIAL_MEMORY,
}


@app.get("/")
def read_root():
return {"Hello": "World"}
Expand Down Expand Up @@ -334,24 +385,12 @@ def insert_test(test: TestIn, response: Response):
new_test_dict = test.dict()

# infer the test_type...
# ("match type(test.result): case xxxResult: blah" complained of name capture. Wow!)
if type(test.result) is VisualPairedAssociatesResult:
new_test_dict.update({"test_type": TestType.VISUAL_PAIRED_ASSOCIATES})
elif type(test.result) is ChoiceReactionTimeResult:
new_test_dict.update({"test_type": TestType.CHOICE_REACTION_TIME})
elif type(test.result) is DigitSymbolMatchingResult:
new_test_dict.update({"test_type": TestType.DIGIT_SYMBOL_MATCHING})
elif type(test.result) is ImmediateRecallResult:
new_test_dict.update({"test_type": TestType.IMMEDIATE_RECALL})
elif type(test.result) is DelayedRecallResult:
new_test_dict.update({"test_type": TestType.DELAYED_RECALL})
elif type(test.result) is SpatialMemoryResult:
new_test_dict.update({"test_type": TestType.SPATIAL_MEMORY})
else:
# should not get here
raise Exception(
"Could not infer test type; should have been caught by type validation"
if isinstance(test.result, list):
new_test_dict.update(
{"test_type": result_type_to_test_type[type(test.result[0])]}
)
else:
new_test_dict.update({"test_type": result_type_to_test_type[type(test.result)]})

# check that the study_id corresponds to an existing study...
# (note: does _not_ check for an existing submitted test result of this type for this study)
Expand Down Expand Up @@ -385,41 +424,85 @@ def get_tests_as_list():
return [t for t in all_tests]


@app.get("/tests/download-file")
def get_tests_as_csv_file() -> FileResponse:
def write_single_test_type_to_csv_file(
csvfile,
test_type: TestType,
participant_id: str = None,
):
studies = db.get_collection("studies")
tests = db.get_collection("tests")
all_tests = tests.find({}, {"_id": 0}) # Exclude _id field

with open("tests.csv", "w", newline="") as csvfile:

# Get the fieldnames, but throw out the "result" and "study_id" fields.
# (The study_id field will get added back in below, along with all the other study fields.)
fields = Test.model_fields
fields.pop("result")
fields.pop("study_id")
fieldnames = fields.keys()
# Instead, normalize the result fields into the fieldnames.
# (This is why, in the xyzResults class fields, all the field names are prefixed.)
fieldnames ^= VisualPairedAssociatesResult.model_fields.keys()
fieldnames ^= ChoiceReactionTimeResult.model_fields.keys()
fieldnames ^= DigitSymbolMatchingResult.model_fields.keys()
fieldnames ^= ImmediateRecallResult.model_fields.keys()
fieldnames ^= DelayedRecallResult.model_fields.keys()
fieldnames ^= SpatialMemoryResult.model_fields.keys()
# Also normalize the study fields into the fieldnames.
fieldnames ^= Study.model_fields.keys()
all_tests = tests.find({"test_type": test_type}, {"_id": 0}) # Exclude _id field

writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
# Get the fieldnames, but throw out the "result" and "study_id" fields.
# (The study_id field will get added back in below, along with all the other study fields.)
fields = Test.model_fields.copy()
fields.pop("result")
fields.pop("study_id")
fieldnames = fields.keys()

writer.writeheader()
for test in all_tests:
test_result = test.pop("result")
# Add the result (sub)fields for the requested test type into the fieldnames.
fieldnames ^= test_type_to_result_type[test_type].model_fields.keys()

# Also normalize the study fields into the fieldnames.
fieldnames ^= Study.model_fields.keys()

writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()

for test in all_tests:
study_id = test.pop("study_id")
study = studies.find_one({"study_id": study_id})
if participant_id and study["participant_id"] != participant_id:
continue
study.pop("_id")

test_result = test.pop("result")

if isinstance(test_result, list):
for question in test_result:
test.update(question)
test.update(study)
writer.writerow(test)
else:
test.update(test_result)
study_id = test.pop("study_id")
study = studies.find_one({"study_id": study_id})
study.pop("_id")
test.update(study)
writer.writerow(test)

return FileResponse("tests.csv")

@app.get("/tests/zip-archive/download-file")
def get_tests_as_csv_zip_archive(participant_id: str = None) -> FileResponse:
"""
Download results data on all test types, one CSV file per test type, combined into a ZIP archive.
If `participant_id` is given, restrict the results to those concerning that `participant_id`.
If no tests are found for a particular test_type or participant_id, an empty CSV file is generated, containing just the header (field names).
"""

with ZipFile("all-tests.zip", "w") as zipfile:
for test_type in TestType:
csv_filename = test_type.value + ".csv"
with open(csv_filename, "w", newline="") as csvfile:
write_single_test_type_to_csv_file(csvfile, test_type, participant_id)
zipfile.write(csv_filename)

return FileResponse("all-tests.zip")


@app.get("/tests/single-test-type/download-file")
def get_single_test_type_as_csv_file(
test_type: TestType,
participant_id: str = None,
) -> FileResponse:
"""
Download results data on one test type, returned in a CSV file.
If `participant_id` is given, restrict the results to those concerning that `participant_id`.
If no tests are found for a particular test_type or participant_id, an empty CSV file is generated, containing just the header (field names).
"""

with open("test.csv", "w", newline="") as csvfile:
write_single_test_type_to_csv_file(csvfile, test_type, participant_id)

return FileResponse("test.csv")
Loading

0 comments on commit dbf5d2e

Please sign in to comment.