diff --git a/code_submitter/extract_archives.py b/code_submitter/extract_archives.py index 60ea3f9..4f77826 100644 --- a/code_submitter/extract_archives.py +++ b/code_submitter/extract_archives.py @@ -3,31 +3,43 @@ import asyncio import zipfile import argparse +from typing import cast from pathlib import Path import databases +from sqlalchemy.sql import select from . import utils, config +from .tables import Session -async def async_main(output_archive: Path) -> None: +async def async_main(output_archive: Path, session_name: str) -> None: output_archive.parent.mkdir(parents=True, exist_ok=True) database = databases.Database(config.DATABASE_URL) + session_id = cast(int, await database.fetch_one(select([ + Session.c.id, + ]).where( + Session.c.name == session_name, + ))) + with zipfile.ZipFile(output_archive) as zf: async with database.transaction(): - utils.collect_submissions(database, zf) + utils.collect_submissions(database, zf, session_id) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() + parser.add_argument('session_name', type=str) parser.add_argument('output_archive', type=Path) return parser.parse_args() def main(args: argparse.Namespace) -> None: - asyncio.get_event_loop().run_until_complete(async_main(args.output_archive)) + asyncio.get_event_loop().run_until_complete( + async_main(args.output_archive, args.session_name), + ) if __name__ == '__main__': diff --git a/code_submitter/server.py b/code_submitter/server.py index c382620..361d311 100644 --- a/code_submitter/server.py +++ b/code_submitter/server.py @@ -1,6 +1,7 @@ import io import zipfile -import datetime +import itertools +from typing import cast import databases from sqlalchemy.sql import select @@ -16,7 +17,7 @@ from . import auth, utils, config from .auth import User, BLUESHIRT_SCOPE -from .tables import Archive, ChoiceHistory +from .tables import Archive, Session, ChoiceHistory, ChoiceForSession database = databases.Database(config.DATABASE_URL, force_rollback=config.TESTING) templates = Jinja2Templates(directory='templates') @@ -49,10 +50,34 @@ async def homepage(request: Request) -> Response: Archive.c.created.desc(), ), ) + sessions = await database.fetch_all( + Session.select().order_by(Session.c.created.desc()), + ) + sessions_and_archives = await database.fetch_all( + select([ + Archive.c.id, + Session.c.name, + ]).select_from( + Archive.join(ChoiceHistory).join(ChoiceForSession).join(Session), + ).where( + Archive.c.id.in_(x['id'] for x in uploads), + ).order_by( + Archive.c.id, + ), + ) + sessions_by_upload = { + grouper: [x['name'] for x in items] + for grouper, items in itertools.groupby( + sessions_and_archives, + key=lambda y: cast(int, y['id']), + ) + } return templates.TemplateResponse('index.html', { 'request': request, 'chosen': chosen, 'uploads': uploads, + 'sessions': sessions, + 'sessions_by_upload': sessions_by_upload, 'BLUESHIRT_SCOPE': BLUESHIRT_SCOPE, }) @@ -123,14 +148,39 @@ async def upload(request: Request) -> Response: @requires(['authenticated', BLUESHIRT_SCOPE]) +async def create_session(request: Request) -> Response: + user: User = request.user + form = await request.form() + + await utils.create_session(database, form['name'], by_username=user.username) + + return RedirectResponse( + request.url_for('homepage'), + # 302 so that the browser switches to GET + status_code=302, + ) + + +@requires(['authenticated', BLUESHIRT_SCOPE]) +@database.transaction() async def download_submissions(request: Request) -> Response: + session_id = cast(int, request.path_params['session_id']) + + session = await database.fetch_one( + Session.select().where(Session.c.id == session_id), + ) + + if session is None: + return Response( + f"{session_id!r} is not a valid session id", + status_code=404, + ) + buffer = io.BytesIO() with zipfile.ZipFile(buffer, mode='w') as zf: - await utils.collect_submissions(database, zf) + await utils.collect_submissions(database, zf, session_id) - filename = 'submissions-{now}.zip'.format( - now=datetime.datetime.now(datetime.timezone.utc), - ) + filename = f"submissions-{session['name']}.zip" return Response( buffer.getvalue(), @@ -142,7 +192,12 @@ async def download_submissions(request: Request) -> Response: routes = [ Route('/', endpoint=homepage, methods=['GET']), Route('/upload', endpoint=upload, methods=['POST']), - Route('/download-submissions', endpoint=download_submissions, methods=['GET']), + Route('/create-session', endpoint=create_session, methods=['POST']), + Route( + '/download-submissions/{session_id:int}', + endpoint=download_submissions, + methods=['GET'], + ), ] middleware = [ diff --git a/code_submitter/tables.py b/code_submitter/tables.py index 4c8e5e0..783f27c 100644 --- a/code_submitter/tables.py +++ b/code_submitter/tables.py @@ -38,3 +38,39 @@ server_default=sqlalchemy.func.now(), ), ) + + +# At the point of downloading the archives in order to run matches, you create a +# Session. The act of doing that will also create the required ChoiceForSession +# rows to record which items will be contained in the download. +Session = sqlalchemy.Table( + 'session', + metadata, + sqlalchemy.Column('id', sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column('name', sqlalchemy.String, unique=True, nullable=False), + + sqlalchemy.Column('username', sqlalchemy.String, nullable=False), + + sqlalchemy.Column( + 'created', + sqlalchemy.DateTime(timezone=True), + nullable=False, + server_default=sqlalchemy.func.now(), + ), +) + +# TODO: constrain such that each team can only have one choice per session? +ChoiceForSession = sqlalchemy.Table( + 'choice_for_session', + metadata, + sqlalchemy.Column( + 'choice_id', + sqlalchemy.ForeignKey('choice_history.id'), + primary_key=True, + ), + sqlalchemy.Column( + 'session_id', + sqlalchemy.ForeignKey('session.id'), + primary_key=True, + ), +) diff --git a/code_submitter/utils.py b/code_submitter/utils.py index 8d79bd7..27e1cbc 100644 --- a/code_submitter/utils.py +++ b/code_submitter/utils.py @@ -1,40 +1,78 @@ -from typing import Dict, Tuple +from typing import cast, Dict, Tuple from zipfile import ZipFile import databases from sqlalchemy.sql import select -from .tables import Archive, ChoiceHistory +from .tables import Archive, Session, ChoiceHistory, ChoiceForSession async def get_chosen_submissions( database: databases.Database, + session_id: int, ) -> Dict[str, Tuple[int, bytes]]: """ Return a mapping of teams to their the chosen archive. """ - # Note: Ideally we'd group by team in SQL, however that doesn't seem to work - # properly -- we don't get the ordering applied before the grouping. - rows = await database.fetch_all( select([ Archive.c.id, Archive.c.team, Archive.c.content, - ChoiceHistory.c.created, ]).select_from( - Archive.join(ChoiceHistory), - ).order_by( - Archive.c.team, - ChoiceHistory.c.created.asc(), + Archive.join(ChoiceHistory).join(ChoiceForSession), + ).where( + Session.c.id == session_id, ), ) - # Rely on later keys replacing earlier occurrences of the same key. return {x['team']: (x['id'], x['content']) for x in rows} +async def create_session( + database: databases.Database, + name: str, + *, + by_username: str, +) -> int: + """ + Return a mapping of teams to their the chosen archive. + """ + + # Note: Ideally we'd group by team in SQL, however that doesn't seem to work + # properly -- we don't get the ordering applied before the grouping. + + async with database.transaction(): + rows = await database.fetch_all( + select([ + ChoiceHistory.c.id, + Archive.c.team, + ]).select_from( + Archive.join(ChoiceHistory), + ).order_by( + Archive.c.team, + ChoiceHistory.c.created.asc(), + ), + ) + + session_id = cast(int, await database.execute( + Session.insert().values(name=name, username=by_username), + )) + + # Rely on later keys replacing earlier occurrences of the same key. + choice_by_team = {x['team']: x['id'] for x in rows} + await database.execute_many( + ChoiceForSession.insert(), + [ + {'choice_id': x, 'session_id': session_id} + for x in choice_by_team.values() + ], + ) + + return session_id + + def summarise(submissions: Dict[str, Tuple[int, bytes]]) -> str: return "".join( "{}: {}\n".format(team, id_) @@ -45,8 +83,9 @@ def summarise(submissions: Dict[str, Tuple[int, bytes]]) -> str: async def collect_submissions( database: databases.Database, zipfile: ZipFile, + session_id: int, ) -> None: - submissions = await get_chosen_submissions(database) + submissions = await get_chosen_submissions(database, session_id) for team, (_, content) in submissions.items(): zipfile.writestr(f'{team.upper()}.zip', content) diff --git a/migrations/versions/27f63e48c6c4_create_sessions.py b/migrations/versions/27f63e48c6c4_create_sessions.py new file mode 100644 index 0000000..ed85ec2 --- /dev/null +++ b/migrations/versions/27f63e48c6c4_create_sessions.py @@ -0,0 +1,49 @@ +"""Create Sessions + +Revision ID: 27f63e48c6c4 +Revises: d4e3b890e3d7 +Create Date: 2021-01-09 11:57:18.916146 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = '27f63e48c6c4' +down_revision = 'd4e3b890e3d7' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + 'session', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('username', sa.String(), nullable=False), + sa.Column( + 'created', + sa.DateTime(timezone=True), + server_default=sa.text('(CURRENT_TIMESTAMP)'), + nullable=False, + ), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('name'), + ) + op.create_table( + 'choice_for_session', + sa.Column('choice_id', sa.Integer(), nullable=False), + sa.Column('session_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['choice_id'], ['choice_history.id']), + sa.ForeignKeyConstraint(['session_id'], ['session.id']), + sa.PrimaryKeyConstraint('choice_id', 'session_id'), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('choice_for_session') + op.drop_table('session') + # ### end Alembic commands ### diff --git a/templates/index.html b/templates/index.html index fb4ec22..16bf280 100644 --- a/templates/index.html +++ b/templates/index.html @@ -31,12 +31,54 @@

Virtual Competition Code Submission

+
+
+

Sessions

+ + + + + + {% if BLUESHIRT_SCOPE in request.auth.scopes %} + + {% endif %} + + {% for session in sessions %} + + + + + + {% if BLUESHIRT_SCOPE in request.auth.scopes %} + + {% endif %} + + {% endfor %} +
NameCreatedByDownload
{{ session.name }}{{ session.created }}{{ session.username }} + + ▼ + +
+
+
{% if BLUESHIRT_SCOPE in request.auth.scopes %}
- - Download current chosen submissions - +
+

Create a new session

+
+ + +
+ +
{% endif %} @@ -83,7 +125,7 @@

Upload a new submission for team {{ request.user.team }}

{% endif %}
-
+

Your team's uploads

@@ -91,6 +133,7 @@

Your team's uploads

+ {% for upload in uploads %} Your team's uploads {% endif %} + {% endfor %}
Uploaded By SelectedSessions
{{ ', '.join(sessions_by_upload.get(upload.id, ())) }}
diff --git a/tests/tests_app.py b/tests/tests_app.py index ad91d12..18f1e62 100644 --- a/tests/tests_app.py +++ b/tests/tests_app.py @@ -4,9 +4,15 @@ from unittest import mock import test_utils +from sqlalchemy.sql import select from starlette.testclient import TestClient -from code_submitter.tables import Archive, ChoiceHistory +from code_submitter.tables import ( + Archive, + Session, + ChoiceHistory, + ChoiceForSession, +) class AppTests(test_utils.DatabaseTestCase): @@ -16,11 +22,11 @@ def setUp(self) -> None: # App import must happen after TESTING environment setup from code_submitter.server import app - def url_for(name: str) -> str: + def url_for(name: str, **path_params: str) -> str: # While it makes for uglier tests, we do need to use more absolute # paths here so that the urls emitted contain the root_path from the # ASGI server and in turn work correctly under proxy. - return 'http://testserver{}'.format(app.url_path_for(name)) + return 'http://testserver{}'.format(app.url_path_for(name, **path_params)) test_client = TestClient(app) self.session = test_client.__enter__() @@ -277,8 +283,50 @@ def test_upload_archive_without_robot_py(self) -> None: ) self.assertEqual([], choices, "Should not have created a choice") + def test_create_session_requires_blueshirt(self) -> None: + response = self.session.post( + self.url_for('create_session'), + data={'name': "Test session"}, + ) + self.assertEqual(403, response.status_code) + + def test_create_session(self) -> None: + self.session.auth = ('blueshirt', 'blueshirt') + + response = self.session.post( + self.url_for('create_session'), + data={'name': "Test session"}, + ) + self.assertEqual(302, response.status_code) + self.assertEqual( + self.url_for('homepage'), + response.headers['location'], + ) + + session, = self.await_( + self.database.fetch_all(select([ + Session.c.name, + Session.c.username, + ])), + ) + + self.assertEqual( + { + 'name': 'Test session', + 'username': 'blueshirt', + }, + dict(session), + "Should have created a session", + ) + def test_no_download_link_for_non_blueshirt(self) -> None: - download_url = self.url_for('download_submissions') + session_id = self.await_(self.database.execute( + Session.insert().values( + name="Test session", + username='blueshirt', + ), + )) + download_url = self.url_for('download_submissions', session_id=session_id) response = self.session.get(self.url_for('homepage')) @@ -288,20 +336,50 @@ def test_no_download_link_for_non_blueshirt(self) -> None: def test_shows_download_link_for_blueshirt(self) -> None: self.session.auth = ('blueshirt', 'blueshirt') - download_url = self.url_for('download_submissions') + session_id = self.await_(self.database.execute( + Session.insert().values( + name="Test session", + username='blueshirt', + ), + )) + download_url = self.url_for('download_submissions', session_id=session_id) response = self.session.get(self.url_for('homepage')) html = response.text self.assertIn(download_url, html) def test_download_submissions_requires_blueshirt(self) -> None: - response = self.session.get(self.url_for('download_submissions')) + session_id = self.await_(self.database.execute( + Session.insert().values( + name="Test session", + username='blueshirt', + ), + )) + response = self.session.get( + self.url_for('download_submissions', session_id=session_id), + ) self.assertEqual(403, response.status_code) + def test_download_submissions_when_invalid_session(self) -> None: + self.session.auth = ('blueshirt', 'blueshirt') + response = self.session.get( + self.url_for('download_submissions', session_id='4'), + ) + self.assertEqual(404, response.status_code) + def test_download_submissions_when_none(self) -> None: self.session.auth = ('blueshirt', 'blueshirt') - response = self.session.get(self.url_for('download_submissions')) + session_id = self.await_(self.database.execute( + Session.insert().values( + name="Test session", + username='blueshirt', + ), + )) + + response = self.session.get( + self.url_for('download_submissions', session_id=session_id), + ) self.assertEqual(200, response.status_code) with zipfile.ZipFile(io.BytesIO(response.content)) as zf: @@ -322,7 +400,7 @@ def test_download_submissions(self) -> None: created=datetime.datetime(2020, 8, 8, 12, 0), ), )) - self.await_(self.database.execute( + choice_id = self.await_(self.database.execute( ChoiceHistory.insert().values( archive_id=8888888888, username='test_user', @@ -330,7 +408,22 @@ def test_download_submissions(self) -> None: ), )) - response = self.session.get(self.url_for('download_submissions')) + session_id = self.await_(self.database.execute( + Session.insert().values( + name="Test session", + username='blueshirt', + ), + )) + self.await_(self.database.execute( + ChoiceForSession.insert().values( + choice_id=choice_id, + session_id=session_id, + ), + )) + + response = self.session.get( + self.url_for('download_submissions', session_id=session_id), + ) self.assertEqual(200, response.status_code) with zipfile.ZipFile(io.BytesIO(response.content)) as zf: diff --git a/tests/tests_utils.py b/tests/tests_utils.py index 2a4bd48..ffac35a 100644 --- a/tests/tests_utils.py +++ b/tests/tests_utils.py @@ -5,7 +5,12 @@ import test_utils from code_submitter import utils -from code_submitter.tables import Archive, ChoiceHistory +from code_submitter.tables import ( + Archive, + Session, + ChoiceHistory, + ChoiceForSession, +) class UtilsTests(test_utils.InTransactionTestCase): @@ -41,18 +46,20 @@ def setUp(self) -> None: )) def test_get_chosen_submissions_nothing_chosen(self) -> None: - result = self.await_(utils.get_chosen_submissions(self.database)) + result = self.await_( + utils.get_chosen_submissions(self.database, session_id=0), + ) self.assertEqual({}, result) def test_get_chosen_submissions_multiple_chosen(self) -> None: - self.await_(self.database.execute( + choice_id_1 = self.await_(self.database.execute( ChoiceHistory.insert().values( archive_id=8888888888, username='someone_else', created=datetime.datetime(2020, 8, 8, 12, 0), ), )) - self.await_(self.database.execute( + choice_id_2 = self.await_(self.database.execute( ChoiceHistory.insert().values( archive_id=1111111111, username='test_user', @@ -66,8 +73,28 @@ def test_get_chosen_submissions_multiple_chosen(self) -> None: created=datetime.datetime(2020, 2, 2, 12, 0), ), )) + session_id = self.await_(self.database.execute( + Session.insert().values( + name="Test session", + username='blueshirt', + ), + )) + self.await_(self.database.execute( + ChoiceForSession.insert().values( + choice_id=choice_id_1, + session_id=session_id, + ), + )) + self.await_(self.database.execute( + ChoiceForSession.insert().values( + choice_id=choice_id_2, + session_id=session_id, + ), + )) - result = self.await_(utils.get_chosen_submissions(self.database)) + result = self.await_( + utils.get_chosen_submissions(self.database, session_id), + ) self.assertEqual( { 'SRZ2': (1111111111, b'1111111111'), @@ -76,15 +103,56 @@ def test_get_chosen_submissions_multiple_chosen(self) -> None: result, ) - def test_collect_submissions(self) -> None: - self.await_(self.database.execute( + def test_create_session(self) -> None: + choice_id_1 = self.await_(self.database.execute( ChoiceHistory.insert().values( archive_id=8888888888, username='someone_else', created=datetime.datetime(2020, 8, 8, 12, 0), ), )) + choice_id_2 = self.await_(self.database.execute( + ChoiceHistory.insert().values( + archive_id=1111111111, + username='test_user', + created=datetime.datetime(2020, 3, 3, 12, 0), + ), + )) self.await_(self.database.execute( + ChoiceHistory.insert().values( + archive_id=2222222222, + username='test_user', + created=datetime.datetime(2020, 2, 2, 12, 0), + ), + )) + + session_id = self.await_(utils.create_session( + self.database, + "Test Session", + by_username='the-user', + )) + + choices = self.await_(self.database.fetch_all( + ChoiceForSession.select(), + )) + + self.assertEqual( + [ + {'choice_id': choice_id_1, 'session_id': session_id}, + {'choice_id': choice_id_2, 'session_id': session_id}, + ], + [dict(x) for x in choices], + ) + + def test_collect_submissions(self) -> None: + choice_id_1 = self.await_(self.database.execute( + ChoiceHistory.insert().values( + archive_id=8888888888, + username='someone_else', + created=datetime.datetime(2020, 8, 8, 12, 0), + ), + )) + choice_id_2 = self.await_(self.database.execute( ChoiceHistory.insert().values( archive_id=1111111111, username='test_user', @@ -98,9 +166,27 @@ def test_collect_submissions(self) -> None: created=datetime.datetime(2020, 2, 2, 12, 0), ), )) + session_id = self.await_(self.database.execute( + Session.insert().values( + name="Test session", + username='blueshirt', + ), + )) + self.await_(self.database.execute( + ChoiceForSession.insert().values( + choice_id=choice_id_1, + session_id=session_id, + ), + )) + self.await_(self.database.execute( + ChoiceForSession.insert().values( + choice_id=choice_id_2, + session_id=session_id, + ), + )) with zipfile.ZipFile(io.BytesIO(), mode='w') as zf: - self.await_(utils.collect_submissions(self.database, zf)) + self.await_(utils.collect_submissions(self.database, zf, session_id)) self.assertEqual( {