Skip to content

Commit

Permalink
[omm] CRUD API functionality for storing credentials (#1543)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dcallies authored Feb 9, 2024
1 parent b6c2f15 commit ff80691
Show file tree
Hide file tree
Showing 13 changed files with 399 additions and 50 deletions.
9 changes: 8 additions & 1 deletion open-media-match/.devcontainer/omm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from threatexchange.content_type.photo import PhotoContent
from threatexchange.content_type.video import VideoContent
from threatexchange.exchanges.impl.static_sample import StaticSampleSignalExchangeAPI
from threatexchange.exchanges.impl.fb_threatexchange_api import (
FBThreatExchangeSignalExchangeAPI,
)

# Database configuration
DBUSER = "media_match"
Expand All @@ -34,7 +37,11 @@
STORAGE_IFACE_INSTANCE = DefaultOMMStore(
signal_types=[PdqSignal, VideoMD5Signal],
content_types=[PhotoContent, VideoContent],
exchange_types=[StaticSampleSignalExchangeAPI, InfiniteRandomExchange],
exchange_types=[
StaticSampleSignalExchangeAPI,
InfiniteRandomExchange,
FBThreatExchangeSignalExchangeAPI,
],
)

# Debugging stuff
Expand Down
66 changes: 59 additions & 7 deletions open-media-match/src/OpenMediaMatch/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,17 @@
import os
import datetime
import sys
import random
import typing as t

import click
import flask
from flask.logging import default_handler
from flask_apscheduler import APScheduler

from threatexchange.signal_type.signal_base import SignalType, CanGenerateRandomSignal
from threatexchange.signal_type.pdq.signal import PdqSignal
from threatexchange.signal_type.md5 import VideoMD5Signal
from threatexchange.exchanges import auth
from threatexchange.exchanges.signal_exchange_api import TSignalExchangeAPICls

from OpenMediaMatch.storage.interface import IUnifiedStore
from OpenMediaMatch.storage import interface
from OpenMediaMatch.storage.postgres.impl import DefaultOMMStore
from OpenMediaMatch.background_tasks import (
build_index,
Expand All @@ -34,7 +32,6 @@
)
from OpenMediaMatch.persistence import get_storage
from OpenMediaMatch.blueprints import development, hashing, matching, curation, ui
from OpenMediaMatch.storage.interface import BankConfig
from OpenMediaMatch.utils import dev_utils


Expand Down Expand Up @@ -90,7 +87,7 @@ def create_app() -> flask.Flask:
app.config["STORAGE_IFACE_INSTANCE"] = DefaultOMMStore()
storage = app.config["STORAGE_IFACE_INSTANCE"]
assert isinstance(
storage, IUnifiedStore
storage, interface.IUnifiedStore
), "STORAGE_IFACE_INSTANCE is not an instance of IUnifiedStore"

_setup_task_logging(app.logger)
Expand Down Expand Up @@ -209,4 +206,59 @@ def build_indices():
storage = get_storage()
build_index.build_all_indices(storage, storage, storage)

@app.cli.command("auth")
@click.argument("api_name", callback=_get_api_cfg)
@click.option(
    "--from-str",
    help="attempt to use the private _from_str method to auth",
)
@click.option("--unset", is_flag=True, help="clear credentials")
def set_credentials(
    api_name: interface.SignalExchangeAPIConfig, from_str: str | None, unset: bool
) -> None:
    """
    Persist credentials for apis.
    Using the lookup mechanisms built into threatexchange.exchanges.auth
    attempt to find credentials in the local environment.
    The easiest way is usually via an environment variable.
    Example, for fb_threatexchange:
    TX_ACCESS_TOKEN='12345678|facefaceface' flask auth fb_threatexchange
    """
    # The `api_name` argument has already been resolved into a config object
    # by the _get_api_cfg callback; click ties the parameter name to the
    # argument name, so only the local variable can be renamed.
    api_cfg = api_name
    storage = get_storage()
    api_cls = api_cfg.api_cls
    cred_cls: auth.CredentialHelper = api_cls.get_credential_cls()  # type: ignore

    if unset:
        # --unset wins over every other way of supplying credentials.
        api_cfg.credentials = None
    else:
        if from_str is not None:
            creds = cred_cls._from_str(from_str)
            if creds is None or not creds._are_valid():
                raise click.UsageError("Invalid 'from-str'")
        else:
            # Fall back to the credential class's own lookup. Failures are
            # turned into click usage errors, chained so the original cause
            # stays visible in tracebacks.
            try:
                creds = cred_cls.get(api_cls)
            except auth.SignalExchangeAPIMissingAuthException as e:
                raise click.UsageError(e.pretty_str()) from e
            except auth.SignalExchangeAPIInvalidAuthException as e:
                raise click.UsageError(e.message) from e
        api_cfg.credentials = creds
    storage.exchange_api_config_update(api_cfg)

return app


def _get_api_cfg(ctx: click.Context, param: click.Parameter, value: str):
    """
    Click argument callback: turn an API name into its stored config.

    Looks `value` up among the configs returned by the storage and returns
    the matching config object. Raises click.BadParameter when the name is
    unknown or when the API class is not a SignalExchangeWithAuth subclass.
    """
    known_configs = get_storage().exchange_apis_get_configs()
    selected = known_configs.get(value)
    if selected is None:
        raise click.BadParameter("No such api")
    if issubclass(selected.api_cls, auth.SignalExchangeWithAuth):
        return selected
    raise click.BadParameter("api doesn't take authentification")
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _fetch(
return
log("Fetching signals for %s from %s", collab.name, collab.api)

api_cls = collab_store.exchange_get_type_configs().get(collab.api)
api_cls = collab_store.exchange_apis_get_installed().get(collab.api)
assert (
api_cls is not None
), f"No such SignalExchangeAPI '{collab.api}' - maybe it was deleted?"
Expand Down
36 changes: 33 additions & 3 deletions open-media-match/src/OpenMediaMatch/blueprints/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

from flask import Blueprint, Response, request, jsonify, abort
from sqlalchemy.exc import IntegrityError
from werkzeug.exceptions import HTTPException

from threatexchange.utils import dataclass_json
from threatexchange.signal_type.signal_base import SignalType
from werkzeug.exceptions import HTTPException
from threatexchange.exchanges import auth

from OpenMediaMatch import persistence
from OpenMediaMatch.utils import flask_utils
Expand Down Expand Up @@ -180,10 +182,38 @@ def _get_collab(name: str):
# Fetching/Exchanges (aka collaborations)
@bp.route("/exchanges/apis", methods=["GET"])
def exchange_api_list() -> list[str]:
exchange_apis = persistence.get_storage().exchange_get_type_configs()
exchange_apis = persistence.get_storage().exchange_apis_get_configs()
return list(exchange_apis)


@bp.route("/exchanges/api/<string:api_name>", methods=["GET", "POST"])
def exchange_api_config_get_or_update(api_name: str) -> dict[str, t.Any]:
    """
    Get, and on POST first update, the stored config for one exchange API.

    POST expects a JSON object payload. Its optional `credential_json`
    field is handled as follows:
      * missing or null: credentials are left untouched
      * empty/falsy value (e.g. {}): stored credentials are cleared
      * non-empty object: loaded as the API's credentials

    Both methods answer with a summary that never echoes the credentials
    themselves: whether the API supports authentification, and whether
    credentials are currently set.

    Aborts with 400 for an unknown API name, a non-object payload, a
    non-object `credential_json`, or credentials sent to an API that does
    not support them.
    """
    storage = persistence.get_storage()
    api_cfg = storage.exchange_apis_get_configs().get(api_name)
    if api_cfg is None:
        abort(400, f"no such Exchange API '{api_name}'")

    if request.method == "POST":
        raw_json = request.json
        if not isinstance(raw_json, dict):
            abort(400, "this endpoint expects a json object payload")
        cred_json = raw_json.get("credential_json")
        if cred_json is not None:
            if not cred_json:
                api_cfg.credentials = None
            elif not isinstance(cred_json, dict):
                # Reject strings/lists/numbers here rather than letting the
                # dataclass loader fail on them with a server error.
                abort(400, "Field `credential_json` must be object")
            else:
                try:
                    api_cfg.set_credentials_from_json_dict(cred_json)
                except ValueError as e:
                    # Raised when the API does not support credentials:
                    # that is a client error, not a server error.
                    abort(400, str(e))

        storage.exchange_api_config_update(api_cfg)

    return {
        "supports_authentification": api_cfg.supports_auth,
        "has_set_authentification": api_cfg.credentials is not None,
    }


@bp.route("/exchanges", methods=["POST"])
def exchange_create():
"""
Expand All @@ -205,7 +235,7 @@ def exchange_create():
abort(400, "Field `api_json` must be object")

storage = persistence.get_storage()
api_types = storage.exchange_get_type_configs()
api_types = storage.exchange_apis_get_installed()

if api_type_name is None:
abort(400, "Field `api_type` is required")
Expand Down
19 changes: 18 additions & 1 deletion open-media-match/src/OpenMediaMatch/blueprints/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,23 @@ def _index_info() -> dict[str, dict[str, t.Any]]:
return index


def _api_cls_info() -> dict[str, dict[str, t.Any]]:
    """
    Build the per-API template data for the exchange API list.

    Maps each configured exchange API name to a dict holding an
    "auth_note" string: empty when the API does not support auth,
    otherwise a hint on whether credentials are stored.
    """
    info: dict[str, dict[str, t.Any]] = {}
    for api_name, api_cfg in get_storage().exchange_apis_get_configs().items():
        if not api_cfg.supports_auth:
            note = ""
        elif api_cfg.credentials is None:
            note = "(may need auth)"
        else:
            note = "(has credentials)"
        info[api_name] = {"auth_note": note}
    return info


def _collab_info() -> dict[str, dict[str, t.Any]]:
storage = get_storage()
collabs = storage.exchanges_get()
Expand Down Expand Up @@ -91,7 +108,7 @@ def home():
template_vars = {
"signal": curation.get_all_signal_types(),
"content": curation.get_all_content_types(),
"exchangeApiList": curation.exchange_api_list(),
"exchange_apis": _api_cls_info(),
"bankList": curation.banks_index(),
"production": current_app.config.get("PRODUCTION", True),
"index": _index_info(),
Expand Down
43 changes: 40 additions & 3 deletions open-media-match/src/OpenMediaMatch/storage/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@

import flask

from threatexchange.utils import dataclass_json
from threatexchange.content_type.content_base import ContentType
from threatexchange.signal_type.signal_base import SignalType
from threatexchange.signal_type.index import SignalTypeIndex
from threatexchange.exchanges import auth
from threatexchange.exchanges.fetch_state import (
FetchCheckpointBase,
CollaborationConfigBase,
Expand Down Expand Up @@ -195,6 +197,29 @@ def get_last_index_build_checkpoint(
"""


@dataclass
class SignalExchangeAPIConfig:
    """
    Per-API configuration for an installed SignalExchangeAPI class.

    Pairs the API class with the credentials (if any) stored for it.
    """

    api_cls: TSignalExchangeAPICls
    credentials: t.Optional[auth.CredentialHelper] = None

    @property
    def supports_auth(self):
        """True when the API class is a SignalExchangeWithAuth subclass."""
        return issubclass(self.api_cls, auth.SignalExchangeWithAuth)

    def set_credentials_from_json_dict(self, d: dict[str, t.Any]) -> None:
        """
        Replace `credentials` with ones loaded from a plain dict.

        The dict is loaded into the credential class reported by the API
        class. Raises ValueError if the API does not support credentials.
        """
        if self.supports_auth:
            auth_api_cls = t.cast(auth.SignalExchangeWithAuth, self.api_cls)
            credential_cls = auth_api_cls.get_credential_cls()
            self.credentials = dataclass_json.dataclass_load_dict(d, credential_cls)
            return
        raise ValueError(f"{self.api_cls.get_name()} does not support credentials")


@dataclass(kw_only=True)
class FetchStatus:
checkpoint_ts: t.Optional[int]
Expand Down Expand Up @@ -226,11 +251,23 @@ def get_default(cls) -> t.Self:
class ISignalExchangeStore(metaclass=abc.ABCMeta):
"""Interface for accessing SignalExchange configuration"""

@abc.abstractmethod
def exchange_get_type_configs(self) -> t.Mapping[str, TSignalExchangeAPICls]:
def exchange_apis_get_installed(self) -> t.Mapping[str, TSignalExchangeAPICls]:
"""
Return all installed SignalExchange types.
"""
return {k: v.api_cls for k, v in self.exchange_apis_get_configs().items()}

@abc.abstractmethod
def exchange_apis_get_configs(self) -> t.Mapping[str, SignalExchangeAPIConfig]:
"""
Returns the configuration for all installed exchange types
"""

@abc.abstractmethod
def exchange_api_config_update(self, cfg: SignalExchangeAPIConfig) -> None:
"""
Update the config for an installed exchange API.
"""

@abc.abstractmethod
def exchange_update(
Expand All @@ -239,7 +276,7 @@ def exchange_update(
"""
Create or update a collaboration/exchange.
If create is false, if the name doesn't .
If create is false, if the name doesn't exist it will throw
If create is true, if the name already exists it will throw
"""

Expand Down
14 changes: 11 additions & 3 deletions open-media-match/src/OpenMediaMatch/storage/mocked.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,19 @@ def get_last_index_build_checkpoint(
return None

# Exchanges
def exchange_get_type_configs(self) -> t.Mapping[str, TSignalExchangeAPICls]:
return {e.get_name(): e for e in (StaticSampleSignalExchangeAPI,)}
def exchange_apis_get_configs(
    self,
) -> t.Mapping[str, interface.SignalExchangeAPIConfig]:
    """
    Return the config for every exchange API known to the mock.

    Only StaticSampleSignalExchangeAPI is installed, with no credentials.
    Named to match the abstract method on the storage interface; under the
    previous name the abstract method was left unimplemented.
    """
    return {
        e.get_name(): interface.SignalExchangeAPIConfig(e)
        for e in (StaticSampleSignalExchangeAPI,)
    }

def exchange_api_config_update(
    self, cfg: interface.SignalExchangeAPIConfig
) -> None:
    """
    Updating exchange API config is not supported by the mock.

    Always raises NotImplementedError (a subclass of Exception, so callers
    catching the previous generic Exception still work).
    """
    raise NotImplementedError("Not implemented")

# Keep the names this mock used before so existing callers keep working.
exchange_type_get_configs = exchange_apis_get_configs
exchange_type_update = exchange_api_config_update

def exchange_get_api_instance(self, api_cls_name: str) -> TSignalExchangeAPI:
return self.exchange_get_type_configs()[api_cls_name]()
return self.exchange_type_get_configs()[api_cls_name].api_cls()

def exchange_update(
self, cfg: CollaborationConfigBase, *, create: bool = False
Expand Down
37 changes: 37 additions & 0 deletions open-media-match/src/OpenMediaMatch/storage/postgres/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from sqlalchemy.sql import func

from threatexchange.exchanges.collab_config import CollaborationConfigBase
from threatexchange.exchanges import auth
from threatexchange.exchanges.signal_exchange_api import (
TSignalExchangeAPICls,
SignalExchangeAPI,
Expand All @@ -60,6 +61,7 @@
FetchStatus,
SignalTypeIndexBuildCheckpoint,
BankContentIterationItem,
SignalExchangeAPIConfig,
)


Expand Down Expand Up @@ -489,3 +491,38 @@ class SignalTypeOverride(db.Model): # type: ignore[name-defined]
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(255), unique=True, index=True)
enabled_ratio: Mapped[float] = mapped_column(default=1.0)


class ExchangeAPIConfig(db.Model):  # type: ignore[name-defined]
    """
    Store any per-API config we might need.
    """

    id: Mapped[int] = mapped_column(primary_key=True)
    api: Mapped[str] = mapped_column(unique=True, index=True)
    # If the credentials can't be produced at docker build time, here's a
    # backup location to store them. You'll have to modify the OMM code to
    # use them how your API expects if it's not one of the natively supported
    # Exchange types.
    # This should correspond to a
    # threatexchange.exchanges.auth.CredentialHelper object
    default_credentials_json: Mapped[t.Dict[str, t.Any]] = mapped_column(
        JSON, default=None
    )

    def serialize_credentials(self, creds: auth.CredentialHelper | None) -> None:
        """Store `creds` as a JSON dict; None is stored as an empty dict."""
        self.default_credentials_json = (
            {} if creds is None else dataclass_json.dataclass_dump_dict(creds)
        )

    def as_storage_iface_cls(
        self, api_cls: TSignalExchangeAPICls
    ) -> SignalExchangeAPIConfig:
        """
        Convert this row into the storage interface's config object.

        Credentials are only loaded when `api_cls` is a
        SignalExchangeWithAuth subclass and a non-empty credential dict is
        stored; otherwise the returned config carries no credentials.
        """
        if (
            issubclass(api_cls, auth.SignalExchangeWithAuth)
            and self.default_credentials_json
        ):
            loaded = dataclass_json.dataclass_load_dict(
                self.default_credentials_json, api_cls.get_credential_cls()
            )
            return SignalExchangeAPIConfig(api_cls, loaded)
        return SignalExchangeAPIConfig(api_cls, None)
Loading

0 comments on commit ff80691

Please sign in to comment.