diff --git a/pykafka/producer.py b/pykafka/producer.py index ff849f15b..e34b0d9e4 100644 --- a/pykafka/producer.py +++ b/pykafka/producer.py @@ -267,7 +267,7 @@ def stop_owned_brokers(): for queue_reader in queue_readers: queue_reader.join() - def produce(self, message, partition_key=None): + def produce(self, message, partition_key=None, callback=None): """Produce a message. :param message: The message to produce (use None to send null) @@ -275,6 +275,8 @@ def produce(self, message, partition_key=None): :param partition_key: The key to use when deciding which partition to send this message to :type partition_key: bytes + :param callback: function to call upon delivery receipt + :type callback: callable """ if not (isinstance(partition_key, bytes) or partition_key is None): raise TypeError("Producer.produce accepts a bytes object as partition_key, " @@ -292,7 +294,8 @@ def produce(self, message, partition_key=None): partition_id=partition_id, # We must pass our thread-local Queue instance directly, # as results will be written to it in a worker thread - delivery_report_q=self._delivery_reports.queue) + delivery_report_q=self._delivery_reports.queue, + callback=callback) self._produce(msg) if self._synchronous: @@ -373,6 +376,8 @@ def mark_as_delivered(message_batch): owned_broker.increment_messages_pending(-1 * len(message_batch)) req.delivered += len(message_batch) for msg in message_batch: + if msg.callback: + msg.callback() self._delivery_reports.put(msg) try: diff --git a/pykafka/protocol.py b/pykafka/protocol.py index 54f9e058e..597f45e46 100644 --- a/pykafka/protocol.py +++ b/pykafka/protocol.py @@ -152,6 +152,7 @@ class Message(Message, Serializable): :ivar offset: The offset of the message :ivar partition_id: The id of the partition to which this message belongs :ivar delivery_report_q: For use by :class:`pykafka.producer.Producer` + :ivar callback: For use by :class:`pykafka.producer.Producer` """ MAGIC = 0 @@ -163,7 +164,8 @@ class Message(Message, Serializable): "partition_id", "partition", "produce_attempt", - "delivery_report_q" + "delivery_report_q", + "callback" ] def __init__(self, @@ -173,7 +175,8 @@ def __init__(self, offset=-1, partition_id=-1, produce_attempt=0, - delivery_report_q=None): + delivery_report_q=None, + callback=None): self.compression_type = compression_type self.partition_key = partition_key self.value = value @@ -186,6 +189,7 @@ def __init__(self, self.produce_attempt = produce_attempt # delivery_report_q is used by the producer self.delivery_report_q = delivery_report_q + self.callback = callback def __len__(self): size = 4 + 1 + 1 + 4 + 4 diff --git a/tests/pykafka/test_protocol.py b/tests/pykafka/test_protocol.py index 7fc120761..3800df29a 100644 --- a/tests/pykafka/test_protocol.py +++ b/tests/pykafka/test_protocol.py @@ -105,7 +105,8 @@ class TestFetchAPI(unittest2.TestCase): 'partition_id': 0, 'produce_attempt': 0, 'delivery_report_q': None, - 'partition': None + 'partition': None, + 'callback': None }, { 'partition_key': b'test_key', 'compression_type': 0, @@ -114,7 +115,8 @@ class TestFetchAPI(unittest2.TestCase): 'partition_id': 0, 'produce_attempt': 0, 'delivery_report_q': None, - 'partition': None + 'partition': None, + 'callback': None }, { 'partition_key': None, 'compression_type': 0, @@ -123,7 +125,8 @@ class TestFetchAPI(unittest2.TestCase): 'partition_id': 0, 'produce_attempt': 0, 'delivery_report_q': None, - 'partition': None + 'partition': None, + 'callback': None }] def msg_to_dict(self, msg):