Skip to content

Commit

Permalink
fix imports for pymongo<4.9
Browse files Browse the repository at this point in the history
  • Loading branch information
mabdinur committed Sep 19, 2024
1 parent d0a2565 commit baab551
Showing 1 changed file with 32 additions and 23 deletions.
55 changes: 32 additions & 23 deletions ddtrace/contrib/internal/pymongo/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@
from .client import set_address_tags


_VERSION = pymongo.version_tuple

if _VERSION >= (4, 9):
from pymongo.synchronous.pool import Connection
from pymongo.synchronous.server import Server
from pymongo.synchronous.topology import Topology
elif _VERSION >= (4, 5):
from pymongo.pool import Connection
from pymongo.server import Server
from pymongo.topology import Topology
else:
from pymongo.pool import SocketInfo as Connection
from pymongo.server import Server
from pymongo.topology import Topology


_CHECKOUT_FN_NAME = "get_socket" if pymongo.version_tuple < (4, 5) else "checkout"


Expand All @@ -41,9 +57,6 @@ def get_version():
return getattr(pymongo, "__version__", "")


_VERSION = pymongo.version_tuple


def patch():
if getattr(pymongo, "_datadog_patch", False):
return
Expand All @@ -60,43 +73,39 @@ def unpatch():

def patch_pymongo_module():
_w(pymongo.MongoClient.__init__, _trace_mongo_client_init)
_w(pymongo.synchronous.topology.Topology.select_server, _trace_topology_select_server)
_w(Topology.select_server, _trace_topology_select_server)
if _VERSION >= (3, 12):
_w(pymongo.synchronous.server.Server.run_operation, _trace_server_run_operation_and_with_response)
_w(Server.run_operation, _trace_server_run_operation_and_with_response)
elif _VERSION >= (3, 9):
_w(pymongo.synchronous.server.Server.run_operation_with_response, _trace_server_run_operation_and_with_response)
_w(Server.run_operation_with_response, _trace_server_run_operation_and_with_response)
else:
_w(pymongo.synchronous.server.Server.send_message_with_response, _trace_server_send_message_with_response)
_w(Server.send_message_with_response, _trace_server_send_message_with_response)

if _VERSION >= (4, 5):
_w(pymongo.synchronous.server.Server.checkout, traced_get_socket)
_w(pymongo.synchronous.pool.Connection.command, _trace_socket_command)
_w(pymongo.synchronous.pool.Connection.write_command, _trace_socket_write_command)
_w(Server.checkout, traced_get_socket)
else:
_w(pymongo.synchronous.server.Server.get_socket, traced_get_socket)
_w(pymongo.synchronous.pool.SocketInfo.command, _trace_socket_command)
_w(pymongo.synchronous.pool.SocketInfo.write_command, _trace_socket_write_command)
_w(Server.get_socket, traced_get_socket)
_w(Connection.command, _trace_socket_command)
_w(Connection.write_command, _trace_socket_write_command)


def unpatch_pymongo_module():
_u(pymongo.MongoClient.__init__, _trace_mongo_client_init)
_u(pymongo.synchronous.topology.Topology.select_server, _trace_topology_select_server)
_u(Topology.select_server, _trace_topology_select_server)

if _VERSION >= (3, 12):
_u(pymongo.synchronous.server.Server.run_operation, _trace_server_run_operation_and_with_response)
_u(Server.run_operation, _trace_server_run_operation_and_with_response)
elif _VERSION >= (3, 9):
_u(pymongo.synchronous.server.Server.run_operation_with_response, _trace_server_run_operation_and_with_response)
_u(Server.run_operation_with_response, _trace_server_run_operation_and_with_response)
else:
_u(pymongo.synchronous.server.Server.send_message_with_response, _trace_server_send_message_with_response)
_u(Server.send_message_with_response, _trace_server_send_message_with_response)

if _VERSION >= (4, 5):
_u(pymongo.synchronous.server.Server.checkout, traced_get_socket)
_u(pymongo.synchronous.pool.Connection.command, _trace_socket_command)
_u(pymongo.synchronous.pool.Connection.write_command, _trace_socket_write_command)
_u(Server.checkout, traced_get_socket)
else:
_u(pymongo.synchronous.server.Server.get_socket, traced_get_socket)
_u(pymongo.synchronous.pool.SocketInfo.command, _trace_socket_command)
_u(pymongo.synchronous.pool.SocketInfo.write_command, _trace_socket_write_command)
_u(Server.get_socket, traced_get_socket)
_u(Connection.command, _trace_socket_command)
_u(Connection.write_command, _trace_socket_write_command)


@contextlib.contextmanager
Expand Down

0 comments on commit baab551

Please sign in to comment.