Skip to content

Commit

Permalink
Merge pull request #570 from roboflow/add-authorizer-for-dedicated-de…
Browse files Browse the repository at this point in the history
…ployment

authorizer for dedicated deployment
  • Loading branch information
PawelPeczek-Roboflow authored Aug 9, 2024
2 parents 53e5f08 + 160569e commit a264772
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 1 deletion.
4 changes: 4 additions & 0 deletions inference/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,3 +410,7 @@

HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
DEVICE = os.getenv("DEVICE")

DEDICATED_DEPLOYMENT_WORKSPACE_URL = os.environ.get(
"DEDICATED_DEPLOYMENT_WORKSPACE_URL", None
)
71 changes: 70 additions & 1 deletion inference/core/interfaces/http/http_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
CORE_MODEL_SAM_ENABLED,
CORE_MODEL_YOLO_WORLD_ENABLED,
CORE_MODELS_ENABLED,
DEDICATED_DEPLOYMENT_WORKSPACE_URL,
DISABLE_WORKFLOW_ENDPOINTS,
LAMBDA,
LEGACY_ROUTE_ENABLED,
Expand Down Expand Up @@ -143,7 +144,11 @@
serialise_workflow_result,
)
from inference.core.managers.base import ModelManager
from inference.core.roboflow_api import get_workflow_specification
from inference.core.roboflow_api import (
get_roboflow_dataset_type,
get_roboflow_workspace,
get_workflow_specification,
)
from inference.core.utils.notebooks import start_notebook
from inference.core.workflows.core_steps.common.entities import StepExecutionMode
from inference.core.workflows.core_steps.common.query_language.errors import (
Expand Down Expand Up @@ -446,6 +451,70 @@ async def count_errors(request: Request, call_next):
self.model_manager.num_errors += 1
return response

if DEDICATED_DEPLOYMENT_WORKSPACE_URL:
cached_api_keys = set()
cached_projects = set()

@app.middleware("http")
async def check_authorization(request: Request, call_next):
# exclude / and /info (health check)
if request.url.path in ["/", "/info"]:
return await call_next(request)

def _unauthorized_response(msg):
return JSONResponse(
status_code=401,
content={
"status": 401,
"message": msg,
},
)

# check api_key
req_params = request.query_params
json_params = (
await request.json()
if request.headers.get("content-type", None) == "application/json"
else dict()
)
api_key = req_params.get("api_key", None) or json_params.get(
"api_key", None
)

if api_key not in cached_api_keys:
try:
workspace_url = (
get_roboflow_workspace(api_key)
if api_key is not None
else None
)

if workspace_url != DEDICATED_DEPLOYMENT_WORKSPACE_URL:
return _unauthorized_response("Unauthorized api_key")

cached_api_keys.add(api_key)
except RoboflowAPINotAuthorizedError as e:
return _unauthorized_response("Unauthorized api_key")

# check project_url
model_id = json_params.get("model_id", "")
project_url = (
req_params.get("project", None)
or json_params.get("project", None)
or model_id.split("/")[0]
)
if project_url is not None and project_url not in cached_projects:
try:
_ = get_roboflow_dataset_type(
api_key, DEDICATED_DEPLOYMENT_WORKSPACE_URL, project_url
)

cached_projects.add(project_url)
except RoboflowAPINotNotFoundError as e:
return _unauthorized_response("Unauthorized project")

return await call_next(request)

self.app = app
self.model_manager = model_manager

Expand Down

0 comments on commit a264772

Please sign in to comment.