diff --git a/sentry_sdk/integrations/spark/spark_driver.py b/sentry_sdk/integrations/spark/spark_driver.py index c6470f2302..a86f16344d 100644 --- a/sentry_sdk/integrations/spark/spark_driver.py +++ b/sentry_sdk/integrations/spark/spark_driver.py @@ -9,6 +9,7 @@ from typing import Optional from sentry_sdk._types import Event, Hint + from pyspark import SparkContext class SparkIntegration(Integration): @@ -17,7 +18,7 @@ class SparkIntegration(Integration): @staticmethod def setup_once(): # type: () -> None - patch_spark_context_init() + _setup_sentry_tracing() def _set_app_properties(): @@ -37,7 +38,7 @@ def _set_app_properties(): def _start_sentry_listener(sc): - # type: (Any) -> None + # type: (SparkContext) -> None """ Start java gateway server to add custom `SparkListener` """ @@ -49,7 +50,51 @@ def _start_sentry_listener(sc): sc._jsc.sc().addSparkListener(listener) -def patch_spark_context_init(): +def _add_event_processor(sc): + # type: (SparkContext) -> None + scope = sentry_sdk.get_isolation_scope() + + @scope.add_event_processor + def process_event(event, hint): + # type: (Event, Hint) -> Optional[Event] + with capture_internal_exceptions(): + if sentry_sdk.get_client().get_integration(SparkIntegration) is None: + return event + + if sc._active_spark_context is None: + return event + + event.setdefault("user", {}).setdefault("id", sc.sparkUser()) + + event.setdefault("tags", {}).setdefault( + "executor.id", sc._conf.get("spark.executor.id") + ) + event["tags"].setdefault( + "spark-submit.deployMode", + sc._conf.get("spark.submit.deployMode"), + ) + event["tags"].setdefault("driver.host", sc._conf.get("spark.driver.host")) + event["tags"].setdefault("driver.port", sc._conf.get("spark.driver.port")) + event["tags"].setdefault("spark_version", sc.version) + event["tags"].setdefault("app_name", sc.appName) + event["tags"].setdefault("application_id", sc.applicationId) + event["tags"].setdefault("master", sc.master) + event["tags"].setdefault("spark_home", sc.sparkHome) + + event.setdefault("extra", {}).setdefault("web_url", sc.uiWebUrl) + + return event + + +def _activate_integration(sc): + # type: (SparkContext) -> None + + _start_sentry_listener(sc) + _set_app_properties() + _add_event_processor(sc) + + +def _patch_spark_context_init(): # type: () -> None from pyspark import SparkContext @@ -59,51 +104,22 @@ def patch_spark_context_init(): def _sentry_patched_spark_context_init(self, *args, **kwargs): # type: (SparkContext, *Any, **Any) -> Optional[Any] rv = spark_context_init(self, *args, **kwargs) - _start_sentry_listener(self) - _set_app_properties() - - scope = sentry_sdk.get_isolation_scope() - - @scope.add_event_processor - def process_event(event, hint): - # type: (Event, Hint) -> Optional[Event] - with capture_internal_exceptions(): - if sentry_sdk.get_client().get_integration(SparkIntegration) is None: - return event - - if self._active_spark_context is None: - return event - - event.setdefault("user", {}).setdefault("id", self.sparkUser()) - - event.setdefault("tags", {}).setdefault( - "executor.id", self._conf.get("spark.executor.id") - ) - event["tags"].setdefault( - "spark-submit.deployMode", - self._conf.get("spark.submit.deployMode"), - ) - event["tags"].setdefault( - "driver.host", self._conf.get("spark.driver.host") - ) - event["tags"].setdefault( - "driver.port", self._conf.get("spark.driver.port") - ) - event["tags"].setdefault("spark_version", self.version) - event["tags"].setdefault("app_name", self.appName) - event["tags"].setdefault("application_id", self.applicationId) - event["tags"].setdefault("master", self.master) - event["tags"].setdefault("spark_home", self.sparkHome) - - event.setdefault("extra", {}).setdefault("web_url", self.uiWebUrl) - - return event - + _activate_integration(self) return rv SparkContext._do_init = _sentry_patched_spark_context_init +def _setup_sentry_tracing(): + # type: () -> None + from pyspark import SparkContext + + if SparkContext._active_spark_context is not None: + _activate_integration(SparkContext._active_spark_context) + return + _patch_spark_context_init() + + class SparkListener: def onApplicationEnd(self, applicationEnd): # noqa: N802,N803 # type: (Any) -> None @@ -208,10 +224,21 @@ class Java: class SentryListener(SparkListener): + def _add_breadcrumb( + self, + level, # type: str + message, # type: str + data=None, # type: Optional[dict[str, Any]] + ): + # type: (...) -> None + sentry_sdk.get_global_scope().add_breadcrumb( + level=level, message=message, data=data + ) + def onJobStart(self, jobStart): # noqa: N802,N803 # type: (Any) -> None message = "Job {} Started".format(jobStart.jobId()) - sentry_sdk.add_breadcrumb(level="info", message=message) + self._add_breadcrumb(level="info", message=message) _set_app_properties() def onJobEnd(self, jobEnd): # noqa: N802,N803 @@ -227,14 +254,14 @@ def onJobEnd(self, jobEnd): # noqa: N802,N803 level = "warning" message = "Job {} Failed".format(jobEnd.jobId()) - sentry_sdk.add_breadcrumb(level=level, message=message, data=data) + self._add_breadcrumb(level=level, message=message, data=data) def onStageSubmitted(self, stageSubmitted): # noqa: N802,N803 # type: (Any) -> None stage_info = stageSubmitted.stageInfo() message = "Stage {} Submitted".format(stage_info.stageId()) data = {"attemptId": stage_info.attemptId(), "name": stage_info.name()} - sentry_sdk.add_breadcrumb(level="info", message=message, data=data) + self._add_breadcrumb(level="info", message=message, data=data) _set_app_properties() def onStageCompleted(self, stageCompleted): # noqa: N802,N803 @@ -255,4 +282,4 @@ def onStageCompleted(self, stageCompleted): # noqa: N802,N803 message = "Stage {} Completed".format(stage_info.stageId()) level = "info" - sentry_sdk.add_breadcrumb(level=level, message=message, data=data) + self._add_breadcrumb(level=level, message=message, data=data) diff --git a/tests/integrations/asgi/test_asgi.py b/tests/integrations/asgi/test_asgi.py index e0a3900a38..f3bc7147bf 100644 --- a/tests/integrations/asgi/test_asgi.py +++ b/tests/integrations/asgi/test_asgi.py @@ -128,7 +128,6 @@ async def app(scope, receive, send): @pytest.fixture def asgi3_custom_transaction_app(): - async def app(scope, receive, send): sentry_sdk.get_current_scope().set_transaction_name("foobar", source="custom") await send( diff --git a/tests/integrations/spark/test_spark.py b/tests/integrations/spark/test_spark.py index 58c8862ee2..44ba9f8728 100644 --- a/tests/integrations/spark/test_spark.py +++ b/tests/integrations/spark/test_spark.py @@ -1,6 +1,7 @@ import pytest import sys from unittest.mock import patch + from sentry_sdk.integrations.spark.spark_driver import ( _set_app_properties, _start_sentry_listener, @@ -18,8 +19,22 @@ ################ -def test_set_app_properties(): - spark_context = SparkContext(appName="Testing123") +@pytest.fixture(scope="function") +def sentry_init_with_reset(sentry_init): + from sentry_sdk.integrations import _processed_integrations + + yield lambda: sentry_init(integrations=[SparkIntegration()]) + _processed_integrations.remove("spark") + + +@pytest.fixture(scope="function") +def create_spark_context(): + yield lambda: SparkContext(appName="Testing123") + SparkContext._active_spark_context.stop() + + +def test_set_app_properties(create_spark_context): + spark_context = create_spark_context() _set_app_properties() assert spark_context.getLocalProperty("sentry_app_name") == "Testing123" @@ -30,9 +45,8 @@ def test_set_app_properties(): ) -def test_start_sentry_listener(): - spark_context = SparkContext.getOrCreate() - +def test_start_sentry_listener(create_spark_context): + spark_context = create_spark_context() gateway = spark_context._gateway assert gateway._callback_server is None @@ -41,9 +55,28 @@ def test_start_sentry_listener(): assert gateway._callback_server is not None -def test_initialize_spark_integration(sentry_init): - sentry_init(integrations=[SparkIntegration()]) - SparkContext.getOrCreate() +@patch("sentry_sdk.integrations.spark.spark_driver._patch_spark_context_init") +def test_initialize_spark_integration_before_spark_context_init( + mock_patch_spark_context_init, + sentry_init_with_reset, + create_spark_context, +): + sentry_init_with_reset() + create_spark_context() + + mock_patch_spark_context_init.assert_called_once() + + +@patch("sentry_sdk.integrations.spark.spark_driver._activate_integration") +def test_initialize_spark_integration_after_spark_context_init( + mock_activate_integration, + create_spark_context, + sentry_init_with_reset, +): + create_spark_context() + sentry_init_with_reset() + + mock_activate_integration.assert_called_once() @pytest.fixture @@ -54,88 +87,83 @@ def sentry_listener(): return listener -@pytest.fixture -def mock_add_breadcrumb(): - with patch("sentry_sdk.add_breadcrumb") as mock: - yield mock - - -def test_sentry_listener_on_job_start(sentry_listener, mock_add_breadcrumb): +def test_sentry_listener_on_job_start(sentry_listener): listener = sentry_listener + with patch.object(listener, "_add_breadcrumb") as mock_add_breadcrumb: - class MockJobStart: - def jobId(self): # noqa: N802 - return "sample-job-id-start" + class MockJobStart: + def jobId(self): # noqa: N802 + return "sample-job-id-start" - mock_job_start = MockJobStart() - listener.onJobStart(mock_job_start) + mock_job_start = MockJobStart() + listener.onJobStart(mock_job_start) - mock_add_breadcrumb.assert_called_once() - mock_hub = mock_add_breadcrumb.call_args + mock_add_breadcrumb.assert_called_once() + mock_hub = mock_add_breadcrumb.call_args - assert mock_hub.kwargs["level"] == "info" - assert "sample-job-id-start" in mock_hub.kwargs["message"] + assert mock_hub.kwargs["level"] == "info" + assert "sample-job-id-start" in mock_hub.kwargs["message"] @pytest.mark.parametrize( "job_result, level", [("JobSucceeded", "info"), ("JobFailed", "warning")] ) -def test_sentry_listener_on_job_end( - sentry_listener, mock_add_breadcrumb, job_result, level -): +def test_sentry_listener_on_job_end(sentry_listener, job_result, level): listener = sentry_listener + with patch.object(listener, "_add_breadcrumb") as mock_add_breadcrumb: - class MockJobResult: - def toString(self): # noqa: N802 - return job_result + class MockJobResult: + def toString(self): # noqa: N802 + return job_result - class MockJobEnd: - def jobId(self): # noqa: N802 - return "sample-job-id-end" + class MockJobEnd: + def jobId(self): # noqa: N802 + return "sample-job-id-end" - def jobResult(self): # noqa: N802 - result = MockJobResult() - return result + def jobResult(self): # noqa: N802 + result = MockJobResult() + return result - mock_job_end = MockJobEnd() - listener.onJobEnd(mock_job_end) + mock_job_end = MockJobEnd() + listener.onJobEnd(mock_job_end) - mock_add_breadcrumb.assert_called_once() - mock_hub = mock_add_breadcrumb.call_args + mock_add_breadcrumb.assert_called_once() + mock_hub = mock_add_breadcrumb.call_args - assert mock_hub.kwargs["level"] == level - assert mock_hub.kwargs["data"]["result"] == job_result - assert "sample-job-id-end" in mock_hub.kwargs["message"] + assert mock_hub.kwargs["level"] == level + assert mock_hub.kwargs["data"]["result"] == job_result + assert "sample-job-id-end" in mock_hub.kwargs["message"] -def test_sentry_listener_on_stage_submitted(sentry_listener, mock_add_breadcrumb): +def test_sentry_listener_on_stage_submitted(sentry_listener): listener = sentry_listener + with patch.object(listener, "_add_breadcrumb") as mock_add_breadcrumb: - class StageInfo: - def stageId(self): # noqa: N802 - return "sample-stage-id-submit" + class StageInfo: + def stageId(self): # noqa: N802 + return "sample-stage-id-submit" - def name(self): - return "run-job" + def name(self): + return "run-job" - def attemptId(self): # noqa: N802 - return 14 + def attemptId(self): # noqa: N802 + return 14 - class MockStageSubmitted: - def stageInfo(self): # noqa: N802 - stageinf = StageInfo() - return stageinf + class MockStageSubmitted: + def stageInfo(self): # noqa: N802 + stageinf = StageInfo() + return stageinf - mock_stage_submitted = MockStageSubmitted() - listener.onStageSubmitted(mock_stage_submitted) + mock_stage_submitted = MockStageSubmitted() + listener.onStageSubmitted(mock_stage_submitted) - mock_add_breadcrumb.assert_called_once() - mock_hub = mock_add_breadcrumb.call_args + mock_add_breadcrumb.assert_called_once() + mock_hub = mock_add_breadcrumb.call_args - assert mock_hub.kwargs["level"] == "info" - assert "sample-stage-id-submit" in mock_hub.kwargs["message"] - assert mock_hub.kwargs["data"]["attemptId"] == 14 - assert mock_hub.kwargs["data"]["name"] == "run-job" + assert mock_hub.kwargs["level"] == "info" + assert "sample-stage-id-submit" in mock_hub.kwargs["message"] + assert mock_hub.kwargs["data"]["attemptId"] == 14 + assert mock_hub.kwargs["data"]["name"] == "run-job" @pytest.fixture @@ -175,39 +203,39 @@ def stageInfo(self): # noqa: N802 def test_sentry_listener_on_stage_completed_success( - sentry_listener, mock_add_breadcrumb, get_mock_stage_completed + sentry_listener, get_mock_stage_completed ): listener = sentry_listener + with patch.object(listener, "_add_breadcrumb") as mock_add_breadcrumb: + mock_stage_completed = get_mock_stage_completed(failure_reason=False) + listener.onStageCompleted(mock_stage_completed) - mock_stage_completed = get_mock_stage_completed(failure_reason=False) - listener.onStageCompleted(mock_stage_completed) - - mock_add_breadcrumb.assert_called_once() - mock_hub = mock_add_breadcrumb.call_args + mock_add_breadcrumb.assert_called_once() + mock_hub = mock_add_breadcrumb.call_args - assert mock_hub.kwargs["level"] == "info" - assert "sample-stage-id-submit" in mock_hub.kwargs["message"] - assert mock_hub.kwargs["data"]["attemptId"] == 14 - assert mock_hub.kwargs["data"]["name"] == "run-job" - assert "reason" not in mock_hub.kwargs["data"] + assert mock_hub.kwargs["level"] == "info" + assert "sample-stage-id-submit" in mock_hub.kwargs["message"] + assert mock_hub.kwargs["data"]["attemptId"] == 14 + assert mock_hub.kwargs["data"]["name"] == "run-job" + assert "reason" not in mock_hub.kwargs["data"] def test_sentry_listener_on_stage_completed_failure( - sentry_listener, mock_add_breadcrumb, get_mock_stage_completed + sentry_listener, get_mock_stage_completed ): listener = sentry_listener - - mock_stage_completed = get_mock_stage_completed(failure_reason=True) - listener.onStageCompleted(mock_stage_completed) - - mock_add_breadcrumb.assert_called_once() - mock_hub = mock_add_breadcrumb.call_args - - assert mock_hub.kwargs["level"] == "warning" - assert "sample-stage-id-submit" in mock_hub.kwargs["message"] - assert mock_hub.kwargs["data"]["attemptId"] == 14 - assert mock_hub.kwargs["data"]["name"] == "run-job" - assert mock_hub.kwargs["data"]["reason"] == "failure-reason" + with patch.object(listener, "_add_breadcrumb") as mock_add_breadcrumb: + mock_stage_completed = get_mock_stage_completed(failure_reason=True) + listener.onStageCompleted(mock_stage_completed) + + mock_add_breadcrumb.assert_called_once() + mock_hub = mock_add_breadcrumb.call_args + + assert mock_hub.kwargs["level"] == "warning" + assert "sample-stage-id-submit" in mock_hub.kwargs["message"] + assert mock_hub.kwargs["data"]["attemptId"] == 14 + assert mock_hub.kwargs["data"]["name"] == "run-job" + assert mock_hub.kwargs["data"]["reason"] == "failure-reason" ################