Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ml: create ML module; add scikit spam model #1063

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions invenio.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ from zenodo_rdm.github.schemas import CitationMetadataSchema
from zenodo_rdm.legacy.resources import record_serializers
from zenodo_rdm.metrics.config import METRICS_CACHE_UPDATE_INTERVAL
from zenodo_rdm.moderation.errors import UserBlockedException
from zenodo_rdm.moderation.handlers import CommunityScoreHandler, RecordScoreHandler
from zenodo_rdm.moderation.handlers import CommunityModerationHandler, RecordModerationHandler
from zenodo_rdm.openaire.records.components import OpenAIREComponent
from zenodo_rdm.permissions import (
ZenodoCommunityPermissionPolicy,
Expand Down Expand Up @@ -817,11 +817,11 @@ RDM_RECORDS_SERVICE_COMPONENTS = DefaultRecordsComponents + [
"""Addd OpenAIRE component to records service."""

RDM_CONTENT_MODERATION_HANDLERS = [
RecordScoreHandler(),
RecordModerationHandler(),
]
"""Records content moderation handlers."""
RDM_COMMUNITY_CONTENT_MODERATION_HANDLERS = [
CommunityScoreHandler(),
CommunityModerationHandler(),
]
"""Community content moderation handlers."""

Expand Down Expand Up @@ -1062,3 +1062,5 @@ COMMUNITIES_SHOW_BROWSE_MENU_ENTRY = True

JOBS_ADMINISTRATION_ENABLED = True
"""Enable Jobs administration view."""

SPAM_DETECTOR_MODEL="spam-scikit:1.0.0"
2 changes: 2 additions & 0 deletions site/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,14 @@ invenio_base.apps =
zenodo_rdm_moderation = zenodo_rdm.moderation.ext:ZenodoModeration
invenio_openaire = zenodo_rdm.openaire.ext:OpenAIRE
zenodo_rdm_stats = zenodo_rdm.stats.ext:ZenodoStats
zenodo_rdm_ml = zenodo_rdm.ml.ext:ZenodoML
invenio_base.api_apps =
zenodo_rdm_legacy = zenodo_rdm.legacy.ext:ZenodoLegacy
profiler = zenodo_rdm.profiler:Profiler
zenodo_rdm_metrics = zenodo_rdm.metrics.ext:ZenodoMetrics
zenodo_rdm_moderation = zenodo_rdm.moderation.ext:ZenodoModeration
invenio_openaire = zenodo_rdm.openaire.ext:OpenAIRE
zenodo_rdm_ml = zenodo_rdm.ml.ext:ZenodoML
invenio_base.api_blueprints =
zenodo_rdm_legacy = zenodo_rdm.legacy.views:blueprint
zenodo_rdm_legacy_records = zenodo_rdm.legacy.views:create_legacy_records_bp
Expand Down
7 changes: 7 additions & 0 deletions site/zenodo_rdm/ml/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# Zenodo-RDM is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.
"""Machine learning module."""
41 changes: 41 additions & 0 deletions site/zenodo_rdm/ml/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# Zenodo-RDM is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.
"""Base class for ML models."""


class MLModel:
"""Base class for ML models."""

def __init__(self, version=None, **kwargs):
"""Constructor."""
self.version = version

def process(self, data, preprocess=None, postprocess=None, raise_exc=True):
"""Pipeline function to call pre/post process with predict."""
try:
preprocessor = preprocess or self.preprocess
postprocessor = postprocess or self.postprocess

preprocessed = preprocessor(data)
prediction = self.predict(preprocessed)
return postprocessor(prediction)
except Exception as e:
if raise_exc:
raise e
return None

def predict(self, data):
"""Predict method to be implemented by subclass."""
raise NotImplementedError()

def preprocess(self, data):
"""Preprocess data."""
return data

def postprocess(self, data):
"""Postprocess data."""
return data
21 changes: 21 additions & 0 deletions site/zenodo_rdm/ml/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# ZenodoRDM is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.

"""Machine learning config."""

from .models import SpamDetectorScikit

ML_MODELS = {
"spam_scikit": SpamDetectorScikit,
}
"""Machine learning models."""

# NOTE Model URL and model host need to be formattable strings for the model name.
ML_KUBEFLOW_MODEL_URL = "CHANGE-{0}-ME"
ML_KUBEFLOW_MODEL_HOST = "{0}-CHANGE"
ML_KUBEFLOW_TOKEN = "CHANGE SECRET"
"""Kubeflow connection config."""
49 changes: 49 additions & 0 deletions site/zenodo_rdm/ml/ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# ZenodoRDM is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.

"""ZenodoRDM machine learning module."""

from flask import current_app

from . import config


class ZenodoML:
"""Zenodo machine learning extension."""

def __init__(self, app=None):
"""Extension initialization."""
if app:
self.init_app(app)

@staticmethod
def init_config(app):
"""Initialize configuration."""
for k in dir(config):
if k.startswith("ML_"):
app.config.setdefault(k, getattr(config, k))

def init_app(self, app):
"""Flask application initialization."""
self.init_config(app)
app.extensions["zenodo-ml"] = self

def _parse_model_name_version(self, model):
"""Parse model name and version."""
vals = model.rsplit(":")
version = vals[1] if len(vals) > 1 else None
return vals[0], version

def models(self, model, **kwargs):
"""Return model based on model name."""
models = current_app.config.get("ML_MODELS", {})
model_name, version = self._parse_model_name_version(model)

if model_name not in models:
raise ValueError("Model not found/registered.")

return models[model_name](version=version, **kwargs)
85 changes: 85 additions & 0 deletions site/zenodo_rdm/ml/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# ZenodoRDM is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.
"""Model definitions."""


import json
import string

import requests
from bs4 import BeautifulSoup
from flask import current_app

from .base import MLModel


class SpamDetectorScikit(MLModel):
"""Spam detection model based on Sklearn."""

MODEL_NAME = "sklearn-spam"
MAX_WORDS = 4000

def __init__(self, version, **kwargs):
"""Constructor. Makes version required."""
super().__init__(version, **kwargs)

def preprocess(self, data):
"""Preprocess data.

Parse HTML, remove punctuation and truncate to max chars.
"""
text = BeautifulSoup(data, "html.parser").get_text()
trans_table = str.maketrans(string.punctuation, " " * len(string.punctuation))
parts = text.translate(trans_table).lower().strip().split(" ")
if len(parts) >= self.MAX_WORDS:
parts = parts[: self.MAX_WORDS]
return " ".join(parts)

def postprocess(self, data):
"""Postprocess data.

Gives spam and ham probability.
"""
result = {
"spam": data["outputs"][0]["data"][0],
"ham": data["outputs"][0]["data"][1],
}
return result

def _send_request_kubeflow(self, data):
"""Send predict request to Kubeflow."""
payload = {
"inputs": [
{
"name": "input-0",
"shape": [1],
"datatype": "BYTES",
"data": [f"{data}"],
}
]
}
model_ref = self.MODEL_NAME + "-" + self.version
url = current_app.config.get("ML_KUBEFLOW_MODEL_URL").format(model_ref)
host = current_app.config.get("ML_KUBEFLOW_MODEL_HOST").format(model_ref)
access_token = current_app.config.get("ML_KUBEFLOW_TOKEN")
r = requests.post(
url,
headers={
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
"Host": host,
},
json=payload,
)
if r.status_code != 200:
raise requests.RequestException("Prediction was not successful.", request=r)
return json.loads(r.text)

def predict(self, data):
"""Get prediction from model."""
prediction = self._send_request_kubeflow(data)
return prediction
12 changes: 12 additions & 0 deletions site/zenodo_rdm/ml/proxies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2024 CERN.
#
# ZenodoRDM is free software; you can redistribute it and/or modify
# it under the terms of the MIT License; see LICENSE file for more details.
"""Proxy objects for easier access to application objects."""

from flask import current_app
from werkzeug.local import LocalProxy

current_ml_models = LocalProxy(lambda: current_app.extensions["zenodo-ml"])
26 changes: 13 additions & 13 deletions site/zenodo_rdm/moderation/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,21 @@
MODERATION_EXEMPT_USERS = []
"""List of users exempt from moderation."""

MODERATION_RECORD_SCORE_RULES = [
verified_user_rule,
links_rule,
files_rule,
text_sanitization_rule,
match_query_rule,
]
MODERATION_RECORD_SCORE_RULES = {
"verified_user_rule": verified_user_rule,
"links_rule": links_rule,
"files_rule": files_rule,
"text_sanitization_rule": text_sanitization_rule,
"match_query_rule": match_query_rule,
}
"""Scoring rules for record moderation."""

MODERATION_COMMUNITY_SCORE_RULES = [
links_rule,
text_sanitization_rule,
verified_user_rule,
match_query_rule,
]
MODERATION_COMMUNITY_SCORE_RULES = {
"links_rule": links_rule,
"text_sanitization_rule": text_sanitization_rule,
"verified_user_rule": verified_user_rule,
"match_query_rule": match_query_rule,
}
"""Scoring rules for communtiy moderation."""

MODERATION_PERCOLATOR_INDEX_PREFIX = "moderation-queries"
Expand Down
30 changes: 20 additions & 10 deletions site/zenodo_rdm/moderation/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from .uow import ExceptionOp


class BaseScoreHandler:
class BaseModerationHandler:
"""Base handler to calculate moderation scores based on rules."""

def __init__(self, rules=None):
Expand All @@ -56,7 +56,11 @@ def rules(self):
"""Get scoring rules."""
if isinstance(self._rules, str):
return current_app.config[self._rules]
return self._rules or []
return self._rules or {}

def evaluate_result(self, params):
"""Evaluate aggregate result based on params."""
return sum(params.values())

@property
def should_apply_actions(self):
Expand All @@ -77,17 +81,19 @@ def run(self, identity, draft=None, record=None, user=None, uow=None):
)
return

score = 0
for rule in self.rules:
score += rule(identity, draft=draft, record=record)
results = {}
for name, rule in self.rules.items():
results[name] = rule(identity, draft=draft, record=record)

action_ctx = {
"user_id": user.id,
"record_pid": record.pid.pid_value,
"score": score,
"results": results,
}
current_app.logger.debug("Moderation score calculated", extra=action_ctx)
if score > current_scores.spam_threshold:

evaluation = self.evaluate_result(results)
if evaluation > current_scores.spam_threshold:
action_ctx["action"] = "block"
if self.should_apply_actions:
# If user is verified, we need to (re)open the moderation
Expand All @@ -102,9 +108,11 @@ def run(self, identity, draft=None, record=None, user=None, uow=None):
"Block moderation action triggered",
extra=action_ctx,
)
elif score < current_scores.ham_threshold:

elif evaluation < current_scores.ham_threshold:
action_ctx["action"] = "approve"

# If the user is already verified, we don't need to verify again
if user.verified:
current_app.logger.debug(
"User is verified, skipping moderation actions",
Expand Down Expand Up @@ -187,7 +195,7 @@ def _block(self, user, uow, action_ctx):
raise UserBlockedException()


class RecordScoreHandler(BaseHandler, BaseScoreHandler):
class RecordModerationHandler(BaseHandler, BaseModerationHandler):
"""Handler for calculating scores for records."""

def __init__(self):
Expand Down Expand Up @@ -222,7 +230,9 @@ def publish(self, identity, draft=None, record=None, uow=None, **kwargs):
self.run(identity, record=record, user=user, uow=uow)


class CommunityScoreHandler(community_moderation.BaseHandler, BaseScoreHandler):
class CommunityModerationHandler(
community_moderation.BaseHandler, BaseModerationHandler
):
"""Handler for calculating scores for communities."""

def __init__(self):
Expand Down