diff --git a/cli/medperf/commands/benchmark/benchmark.py b/cli/medperf/commands/benchmark/benchmark.py
index 35d719b0d..f3b6b669a 100644
--- a/cli/medperf/commands/benchmark/benchmark.py
+++ b/cli/medperf/commands/benchmark/benchmark.py
@@ -10,6 +10,8 @@
 from medperf.commands.benchmark.associate import AssociateBenchmark
 from medperf.commands.result.create import BenchmarkExecution
 
+from medperf.utils import parse_institutions_file
+
 app = typer.Typer()
 
 
@@ -56,6 +58,9 @@ def submit(
     evaluator_mlcube: int = typer.Option(
         ..., "--evaluator-mlcube", "-e", help="Evaluator MLCube UID"
     ),
+    institutions_file: str = typer.Option(
+        None, "--institutions", "-i", help="CSV file containing the institution and email per expected participant."
+    ),
     skip_data_preparation_step: bool = typer.Option(
         False,
         "--skip-demo-data-preparation",
@@ -68,6 +73,7 @@ def submit(
     ),
 ):
     """Submits a new benchmark to the platform"""
+    institutions = parse_institutions_file(institutions_file)
     benchmark_info = {
         "name": name,
         "description": description,
@@ -77,6 +83,7 @@ def submit(
         "data_preparation_mlcube": data_preparation_mlcube,
         "reference_model_mlcube": reference_model_mlcube,
         "data_evaluator_mlcube": evaluator_mlcube,
+        "institutions": institutions,
         "state": "OPERATION" if operational else "DEVELOPMENT",
     }
     SubmitBenchmark.run(
diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py
index e03fcdb4f..3e11cf97f 100644
--- a/cli/medperf/entities/benchmark.py
+++ b/cli/medperf/entities/benchmark.py
@@ -28,6 +28,7 @@ class Benchmark(Entity, ApprovableSchema, DeployableSchema):
     data_evaluator_mlcube: int
     metadata: dict = {}
     user_metadata: dict = {}
+    institutions: dict = {}
     is_active: bool = True
 
     @staticmethod
diff --git a/cli/medperf/utils.py b/cli/medperf/utils.py
index 35aa697d6..971dd4836 100644
--- a/cli/medperf/utils.py
+++ b/cli/medperf/utils.py
@@ -2,6 +2,7 @@
 
 import re
 import os
+import csv
 import signal
 import yaml
 import random
@@ -21,7 +22,7 @@ from pexpect.exceptions import TIMEOUT
 from git import Repo, GitCommandError
 
 import medperf.config as config
-from medperf.exceptions import ExecutionError, MedperfException
+from medperf.exceptions import ExecutionError, MedperfException, InvalidArgumentError
 
 
 def get_file_hash(path: str) -> str:
@@ -469,6 +470,27 @@ def check_for_updates() -> None:
         logging.debug(e)
 
 
+def parse_institutions_file(institutions_file):
+    institutions = {}
+    if institutions_file is None:
+        return institutions
+
+    with open(institutions_file, 'r') as f:
+        reader = csv.DictReader(f)
+        fieldnames = set(reader.fieldnames)
+        exp_fieldnames = {'institution', 'email'}
+        if len(exp_fieldnames - fieldnames):
+            raise InvalidArgumentError(
+                'Institutions file must contain "institution" and "email" columns'
+            )
+        for row in reader:
+            email = row['email']
+            institution = row['institution']
+            institutions[email] = institution
+
+    return institutions
+
+
 class spawn_and_kill:
     def __init__(self, cmd, timeout=None, *args, **kwargs):
         self.cmd = cmd
diff --git a/server/benchmark/migrations/0003_benchmark_institutions.py b/server/benchmark/migrations/0003_benchmark_institutions.py
new file mode 100644
index 000000000..b35b53ceb
--- /dev/null
+++ b/server/benchmark/migrations/0003_benchmark_institutions.py
@@ -0,0 +1,18 @@
+# Generated by Django 3.2.20 on 2024-10-29 19:26
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('benchmark', '0002_alter_benchmark_demo_dataset_tarball_url'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='benchmark',
+            name='institutions',
+            field=models.JSONField(blank=True, default=dict, null=True),
+        ),
+    ]
diff --git a/server/benchmark/models.py b/server/benchmark/models.py
index 2e0d51e3d..30a3764f4 100644
--- a/server/benchmark/models.py
+++ b/server/benchmark/models.py
@@ -39,6 +39,7 @@ class Benchmark(models.Model):
         related_name="data_evaluator_mlcube",
     )
     metadata = models.JSONField(default=dict, blank=True, null=True)
+    institutions = models.JSONField(default=dict, blank=True, null=True)
     state = models.CharField(
         choices=BENCHMARK_STATE, max_length=100, default="DEVELOPMENT"
     )
diff --git a/server/benchmark/serializers.py b/server/benchmark/serializers.py
index 11f007d37..8ffaa204b 100644
--- a/server/benchmark/serializers.py
+++ b/server/benchmark/serializers.py
@@ -9,6 +9,16 @@ class Meta:
         fields = "__all__"
         read_only_fields = ["owner", "approved_at", "approval_status"]
 
+    def to_representation(self, instance):
+        representation = super().to_representation(instance)
+        request = self.context.get('request')
+
+        # Remove institutions field if user is not the owner
+        if request and request.user != instance.owner:
+            representation.pop("institutions", None)
+
+        return representation
+
     def validate(self, data):
         owner = self.context["request"].user
         pending_benchmarks = Benchmark.objects.filter(
diff --git a/server/benchmark/views.py b/server/benchmark/views.py
index 2ca3fd619..d18b1a47c 100644
--- a/server/benchmark/views.py
+++ b/server/benchmark/views.py
@@ -132,7 +132,7 @@ def get(self, request, pk, format=None):
         Retrieve a benchmark instance.
         """
         benchmark = self.get_object(pk)
-        serializer = BenchmarkSerializer(benchmark)
+        serializer = BenchmarkSerializer(benchmark, context={'request': request})
         return Response(serializer.data)
 
     def put(self, request, pk, format=None):
diff --git a/server/benchmarkdataset/serializers.py b/server/benchmarkdataset/serializers.py
index 9cc120079..2be5057a0 100644
--- a/server/benchmarkdataset/serializers.py
+++ b/server/benchmarkdataset/serializers.py
@@ -1,5 +1,6 @@
 from rest_framework import serializers
 from django.utils import timezone
+from django.contrib.auth import get_user_model
 
 from benchmark.models import Benchmark
 from dataset.models import Dataset
@@ -75,10 +76,15 @@ def create(self, validated_data):
         if approval_status != "PENDING":
             validated_data["approved_at"] = timezone.now()
         else:
-            if (
-                validated_data["dataset"].owner.id
-                == validated_data["benchmark"].owner.id
-            ):
+            dset_owner = validated_data["dataset"].owner.id
+            bmk_owner = validated_data["benchmark"].owner.id
+            expected_emails = validated_data["benchmark"].institutions.keys()
+            User = get_user_model()
+            dset_user = User.objects.get(id=dset_owner)
+
+            is_same_owner = dset_owner == bmk_owner
+            is_expected_participant = dset_user.email in expected_emails
+            if is_same_owner or is_expected_participant:
                 validated_data["approval_status"] = "APPROVED"
                 validated_data["approved_at"] = timezone.now()
         return BenchmarkDataset.objects.create(**validated_data)
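
Illustrative note (not part of the patch): the institutions file passed via the new --institutions/-i option is expected to provide "institution" and "email" columns, and parse_institutions_file turns it into an email-to-institution mapping stored in the benchmark's new institutions JSON field. The file name, institutions, and emails below are made up for illustration:

    institution,email
    Hospital A,alice@hospital-a.org
    Hospital B,bob@hospital-b.org

A minimal, runnable sketch of the same parsing behavior (standalone, not the MedPerf code itself):

    import csv
    import io

    # Hypothetical CSV contents; a real file would be read from disk instead.
    sample = "institution,email\nHospital A,alice@hospital-a.org\nHospital B,bob@hospital-b.org\n"

    # Same column handling as parse_institutions_file: map email -> institution name.
    institutions = {row["email"]: row["institution"] for row in csv.DictReader(io.StringIO(sample))}
    print(institutions)
    # {'alice@hospital-a.org': 'Hospital A', 'bob@hospital-b.org': 'Hospital B'}

On the server side, BenchmarkDatasetListSerializer.create() uses this mapping to auto-approve a dataset association when the dataset owner's email appears among the expected participants, in addition to the existing same-owner case, while BenchmarkSerializer.to_representation hides the institutions field from anyone other than the benchmark owner.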