[Issue #1578] Setup adding task information to all task logs #2196

Open · wants to merge 8 commits into main
2 changes: 2 additions & 0 deletions api/src/data_migration/command/load_transform.py
@@ -10,6 +10,7 @@
import src.adapters.db.flask_db as flask_db
import src.db.foreign
import src.db.models.staging
from src.task.ecs_background_task import ecs_background_task
from src.task.opportunities.set_current_opportunities_task import SetCurrentOpportunitiesTask

from ..data_migration_blueprint import data_migration_blueprint
@@ -31,6 +32,7 @@
"--insert-chunk-size", default=4000, help="chunk size for load inserts", show_default=True
)
@flask_db.with_db_session()
@ecs_background_task(task_name="load-transform")
def load_transform(
db_session: db.Session, load: bool, transform: bool, set_current: bool, insert_chunk_size: int
) -> None:
37 changes: 20 additions & 17 deletions api/src/logging/flask_logger.py
@@ -18,6 +18,7 @@
"""

import logging
import os
import time
import uuid

@@ -26,6 +27,8 @@
logger = logging.getLogger(__name__)
EXTRA_LOG_DATA_ATTR = "extra_log_data"

_GLOBAL_LOG_CONTEXT: dict = {}


def init_app(app_logger: logging.Logger, app: flask.Flask) -> None:
"""Initialize the Flask app logger.
@@ -50,7 +53,7 @@ def init_app(app_logger: logging.Logger, app: flask.Flask) -> None:
# set on the ancestors.
# See https://docs.python.org/3/library/logging.html#logging.Logger.propagate
for handler in app_logger.handlers:
handler.addFilter(_add_app_context_info_to_log_record)
handler.addFilter(_add_global_context_info_to_log_record)
handler.addFilter(_add_request_context_info_to_log_record)

# Add request context data to every log record for the current request
@@ -63,6 +66,11 @@ def init_app(app_logger: logging.Logger, app: flask.Flask) -> None:
app.before_request(_log_start_request)
app.after_request(_log_end_request)

# Add some metadata to all log messages globally
add_extra_data_to_global_logs(
{"app.name": app.name, "environment": os.environ.get("ENVIRONMENT")}
)

app_logger.info("initialized flask logger")


@@ -77,6 +85,12 @@ def add_extra_data_to_current_request_logs(
setattr(flask.g, EXTRA_LOG_DATA_ATTR, extra_log_data)


def add_extra_data_to_global_logs(data: dict[str, str | int | float | bool | None]) -> None:
"""Add metadata to all logs for the rest of the lifecycle of this app process"""
global _GLOBAL_LOG_CONTEXT
_GLOBAL_LOG_CONTEXT.update(data)


def _track_request_start_time() -> None:
"""Store the request start time in flask.g"""
flask.g.request_start_time = time.perf_counter()
@@ -117,20 +131,6 @@ def _log_end_request(response: flask.Response) -> flask.Response:
return response


def _add_app_context_info_to_log_record(record: logging.LogRecord) -> bool:
"""Add app context data to the log record.

If there is no app context, then do not add any data.
"""
if not flask.has_app_context():
return True

assert flask.current_app is not None
record.__dict__ |= _get_app_context_info(flask.current_app)

return True


def _add_request_context_info_to_log_record(record: logging.LogRecord) -> bool:
"""Add request context data to the log record.

@@ -146,8 +146,11 @@ def _add_request_context_info_to_log_record(record: logging.LogRecord) -> bool:
return True


def _get_app_context_info(app: flask.Flask) -> dict:
return {"app.name": app.name}
def _add_global_context_info_to_log_record(record: logging.LogRecord) -> bool:
global _GLOBAL_LOG_CONTEXT
record.__dict__ |= _GLOBAL_LOG_CONTEXT

return True


def _get_request_context_info(request: flask.Request) -> dict:
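Taken together, these flask_logger changes replace the per-app-context filter with a process-wide context dict, so task metadata shows up on every record even outside a Flask request. A minimal sketch of the intended behavior, assuming `init_app` is wired up as in this repo; the `batch_id` field and logger names are hypothetical:

```python
import logging

import flask

from src.logging.flask_logger import add_extra_data_to_global_logs, init_app

app = flask.Flask(__name__)

app_logger = logging.getLogger("src")
app_logger.addHandler(logging.StreamHandler())  # filters are attached per-handler
init_app(app_logger, app)  # seeds the global context with app.name / environment

logger = logging.getLogger("src.example")  # propagates up to the filtered handler

# From here on, every record handled by this process carries batch_id,
# whether or not a Flask request context is active.
add_extra_data_to_global_logs({"batch_id": "2024-06-01"})
logger.info("starting batch")  # record.__dict__ now also holds app.name,
                               # environment, and batch_id
```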
2 changes: 2 additions & 0 deletions api/src/search/backend/load_search_data.py
@@ -5,6 +5,7 @@
from src.adapters.db import flask_db
from src.search.backend.load_opportunities_to_index import LoadOpportunitiesToIndex
from src.search.backend.load_search_data_blueprint import load_search_data_blueprint
from src.task.ecs_background_task import ecs_background_task


@load_search_data_blueprint.cli.command(
@@ -16,6 +17,7 @@
help="Whether to run a full refresh, or only incrementally update oppportunities",
)
@flask_db.with_db_session()
@ecs_background_task(task_name="load-opportunity-data-opensearch")
def load_opportunity_data(db_session: db.Session, full_refresh: bool) -> None:
search_client = search.SearchClient()

137 changes: 137 additions & 0 deletions api/src/task/ecs_background_task.py
@@ -0,0 +1,137 @@
import contextlib
import logging
import os
import time
import uuid
from functools import wraps
from typing import Callable, Generator, ParamSpec, TypeVar

import requests

from src.logging.flask_logger import add_extra_data_to_global_logs

logger = logging.getLogger(__name__)

P = ParamSpec("P")
T = TypeVar("T")


def ecs_background_task(task_name: str) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""
Decorator for any ECS Task entrypoint function.

This encapsulates the setup required by all ECS tasks, making it easy to:
- add new shared initialization steps for logging
- write new ECS task code without thinking about the boilerplate

Usage:

TASK_NAME = "my-cool-task"

@task_blueprint.cli.command(TASK_NAME, help="For running my cool task")
@ecs_background_task(TASK_NAME)
@flask_db.with_db_session()
def entrypoint(db_session: db.Session):
do_cool_stuff()

Parameters:
task_name (str): Name of the ECS task

IMPORTANT: Do not specify this decorator before the task command.
Click effectively rewrites your function to be a main function
and any decorators from before the "task_blueprint.cli.command(...)"
line are discarded.
See: https://click.palletsprojects.com/en/8.1.x/quickstart/#basic-concepts-creating-a-command
"""

def decorator(f: Callable[P, T]) -> Callable[P, T]:
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
with _ecs_background_task_impl(task_name):
return f(*args, **kwargs)

return wrapper

return decorator


@contextlib.contextmanager
def _ecs_background_task_impl(task_name: str) -> Generator[None, None, None]:
# The actual implementation; see the docstring on the
# decorator above for details on usage

start = time.perf_counter()
_add_log_metadata(task_name)

# initialize New Relic here when we add that

logger.info("Starting ECS task %s", task_name)

try:
yield
except Exception:
# We want to make certain that any exception will always
# be logged as an error
# logger.exception is just an alias for logger.error(<msg>, exc_info=True)
logger.exception("ECS task failed", extra={"status": "error"})
raise

end = time.perf_counter()
duration = round((end - start), 3)
logger.info(
"Completed ECS task %s",
task_name,
extra={"ecs_task_duration_sec": duration, "status": "success"},
)


def _add_log_metadata(task_name: str) -> None:
# Note: we also set an "aws.ecs.task_name" field pulled from ECS metadata,
# which may differ, as that value is based on our infra setup
# while this one is whatever was passed to the @ecs_background_task decorator
add_extra_data_to_global_logs({"task_name": task_name, "task_uuid": str(uuid.uuid4())})
add_extra_data_to_global_logs(_get_ecs_metadata())


def _get_ecs_metadata() -> dict:
"""
Retrieves ECS metadata from an AWS-provided metadata URI. This URI is injected into all ECS tasks by AWS as an environment variable.
See https://docs.aws.amazon.com/AmazonECS/latest/userguide/task-metadata-endpoint-v4-fargate.html for more.
"""
ecs_metadata_uri = os.environ.get("ECS_CONTAINER_METADATA_URI_V4")

if os.environ.get("ENVIRONMENT", "local") == "local" or ecs_metadata_uri is None:
logger.info(
"ECS metadata not available for local environments. Run this task on ECS to see metadata."
)
return {}

task_metadata = requests.get(ecs_metadata_uri, timeout=1) # 1sec timeout
logger.info("Retrieved task metadata from ECS")
metadata_json = task_metadata.json()

ecs_task_name = metadata_json["Name"]
ecs_task_id = metadata_json["Labels"]["com.amazonaws.ecs.task-arn"].split("/")[-1]
ecs_taskdef = ":".join(
[
metadata_json["Labels"]["com.amazonaws.ecs.task-definition-family"],
metadata_json["Labels"]["com.amazonaws.ecs.task-definition-version"],
]
)
cloudwatch_log_group = metadata_json["LogOptions"]["awslogs-group"]
cloudwatch_log_stream = metadata_json["LogOptions"]["awslogs-stream"]

# Step function only
sfn_execution_id = os.environ.get("SFN_EXECUTION_ID")
sfn_id = sfn_execution_id.split(":")[-2] if sfn_execution_id is not None else None

return {
"aws.ecs.task_name": ecs_task_name,
"aws.ecs.task_id": ecs_task_id,
"aws.ecs.task_definition": ecs_taskdef,
# these will be added automatically by the New Relic log ingester, but
# declare them explicitly to be safe and for non-log usages
"aws.cloudwatch.log_group": cloudwatch_log_group,
"aws.cloudwatch.log_stream": cloudwatch_log_stream,
"aws.step_function.id": sfn_id,
}
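For reference, an abbreviated sketch of the v4 metadata payload this function parses. The values below are invented, and only the keys accessed above are shown; real responses contain many more fields:

```python
# Illustrative payload only; values are made up, but the keys are the
# ones _get_ecs_metadata() reads above.
metadata_json = {
    "Name": "api",
    "Labels": {
        "com.amazonaws.ecs.task-arn": (
            "arn:aws:ecs:us-east-1:123456789012:task/api-cluster/abc123def456"
        ),
        "com.amazonaws.ecs.task-definition-family": "api-task",
        "com.amazonaws.ecs.task-definition-version": "42",
    },
    "LogOptions": {
        "awslogs-group": "service/api",
        "awslogs-stream": "service/api/abc123def456",
    },
}

# Derived fields, per the parsing above:
#   aws.ecs.task_id         -> "abc123def456" (last segment of the task ARN)
#   aws.ecs.task_definition -> "api-task:42"
```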
2 changes: 2 additions & 0 deletions api/src/task/opportunities/export_opportunity_data_task.py
@@ -14,6 +14,7 @@
from src.api.opportunities_v1.opportunity_schemas import OpportunityV1Schema
from src.db.models.opportunity_models import CurrentOpportunitySummary, Opportunity
from src.services.opportunities_v1.opportunity_to_csv import opportunities_to_csv
from src.task.ecs_background_task import ecs_background_task
from src.task.task import Task
from src.task.task_blueprint import task_blueprint
from src.util.datetime_util import get_now_us_eastern_datetime
@@ -27,6 +28,7 @@
help="Generate JSON and CSV files containing an export of all opportunity data",
)
@flask_db.with_db_session()
@ecs_background_task(task_name="export-opportunity-data")
def export_opportunity_data(db_session: db.Session) -> None:
ExportOpportunityDataTask(db_session).run()

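Wiring any new task follows the same shape as the diffs above. A minimal sketch with a hypothetical `generate-report` command (the command name and function body are illustrative; imports and decorator order mirror the changed files). Per the docstring warning, the `cli.command` decorator must come first, since decorators placed above it are discarded:

```python
import src.adapters.db as db
import src.adapters.db.flask_db as flask_db
from src.task.ecs_background_task import ecs_background_task
from src.task.task_blueprint import task_blueprint


# Hypothetical task; the command decorator is outermost so click keeps
# the ecs_background_task and with_db_session wrappers below it.
@task_blueprint.cli.command("generate-report", help="Hypothetical reporting task")
@flask_db.with_db_session()
@ecs_background_task(task_name="generate-report")
def generate_report(db_session: db.Session) -> None:
    # Task body runs with the start/duration/failure logging the decorator adds.
    ...
```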
59 changes: 59 additions & 0 deletions api/tests/src/task/test_ecs_background_task.py
@@ -0,0 +1,59 @@
import logging
import time

import pytest

from src.logging.flask_logger import add_extra_data_to_global_logs
from src.task.ecs_background_task import ecs_background_task


def test_ecs_background_task(app, caplog):
# We pull in the app fixture so it's initialized;
# global logging params like the task name are stored on the app
caplog.set_level(logging.INFO)

@ecs_background_task(task_name="my_test_task_name")
def my_test_func(param1, param2):
# Add a brief sleep so that we can test the duration logic
time.sleep(0.2) # 0.2s
add_extra_data_to_global_logs({"example_param": 12345})

return param1 + param2

# Verify the function works uneventfully
assert my_test_func(1, 2) == 3

for record in caplog.records:
extra = record.__dict__
assert extra["task_name"] == "my_test_task_name"

last_record = caplog.records[-1].__dict__
# Make sure the ECS task duration was tracked
allowed_error = 0.1
assert last_record["ecs_task_duration_sec"] == pytest.approx(0.2, abs=allowed_error)
# Make sure the extra we added was put in this automatically
assert last_record["example_param"] == 12345
assert last_record["message"] == "Completed ECS task my_test_task_name"


def test_ecs_background_task_when_erroring(app, caplog):
caplog.set_level(logging.INFO)

@ecs_background_task(task_name="my_error_test_task_name")
def my_test_error_func():
add_extra_data_to_global_logs({"another_param": "hello"})

raise ValueError("I am an error")

with pytest.raises(ValueError, match="I am an error"):
my_test_error_func()

for record in caplog.records:
extra = record.__dict__
assert extra["task_name"] == "my_error_test_task_name"

last_record = caplog.records[-1].__dict__

assert last_record["another_param"] == "hello"
assert last_record["levelname"] == "ERROR"
assert last_record["message"] == "ECS task failed"