Skip to content

Commit

Permalink
Fix: fix readMessageBegin name type error
Browse files Browse the repository at this point in the history
Client: ["python"]
  • Loading branch information
bwangelme committed Nov 6, 2023
1 parent 0eab6e0 commit f139196
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
3 changes: 2 additions & 1 deletion lib/py/src/protocol/TBinaryProtocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#

from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory
from ..compat import binary_to_str
from struct import pack, unpack


Expand Down Expand Up @@ -145,7 +146,7 @@ def readMessageBegin(self):
if self.strictRead:
raise TProtocolException(type=TProtocolException.BAD_VERSION,
message='No protocol version header')
name = self.trans.readAll(sz)
name = binary_to_str(self.trans.readAll(sz))
type = self.readByte()
seqid = self.readI32()
return (name, type, seqid)
Expand Down
28 changes: 25 additions & 3 deletions lib/py/test/thrift_TBinaryProtocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,19 @@ def testField(type, data):
protocol.readStructEnd()


def testMessage(data):
def testMessage(data, strict=True):
message = {}
message['name'] = data[0]
message['type'] = data[1]
message['seqid'] = data[2]

strictRead, strictWrite = True, True
if not strict:
strictRead, strictWrite = False, False

buf = TTransport.TMemoryBuffer()
transport = TTransport.TBufferedTransportFactory().getTransport(buf)
protocol = TBinaryProtocol(transport)
protocol = TBinaryProtocol(transport, strictRead=strictRead, strictWrite=strictWrite)
protocol.writeMessageBegin(message['name'], message['type'], message['seqid'])
protocol.writeMessageEnd()

Expand All @@ -169,7 +173,7 @@ def testMessage(data):

buf = TTransport.TMemoryBuffer(data_r)
transport = TTransport.TBufferedTransportFactory().getTransport(buf)
protocol = TBinaryProtocol(transport)
protocol = TBinaryProtocol(transport, strictRead=strictRead, strictWrite=strictWrite)
result = protocol.readMessageBegin()
protocol.readMessageEnd()
return result
Expand Down Expand Up @@ -259,6 +263,24 @@ def test_TBinaryProtocol_write_read(self):
print("Assertion fail")
raise e

def test_TBinaryProtocol_no_strict_write_read(self):
TMessageType = {"T_CALL": 1, "T_REPLY": 2, "T_EXCEPTION": 3, "T_ONEWAY": 4}
test_data = [("short message name", TMessageType['T_CALL'], 0),
("1", TMessageType['T_REPLY'], 12345),
("loooooooooooooooooooooooooooooooooong", TMessageType['T_EXCEPTION'], 1 << 16),
("one way push", TMessageType['T_ONEWAY'], 12),
("Janky", TMessageType['T_CALL'], 0)]

try:
for dt in test_data:
result = testMessage(dt, strict=False)
self.assertEqual(result[0], dt[0])
self.assertEqual(result[1], dt[1])
self.assertEqual(result[2], dt[2])
except Exception as e:
print("Assertion fail")
raise e


if __name__ == '__main__':
unittest.main()

0 comments on commit f139196

Please sign in to comment.