Skip to content

Commit

Permalink
[core] Fixed the PacketFilter configuration not counting the AEAD AUT…
Browse files Browse the repository at this point in the history
…H tag (Haivision#2880).

* Fixed build break due to const use of UniquePtr.
  • Loading branch information
ethouris authored and maxsharabayko committed Apr 26, 2024
1 parent 6142612 commit 1802702
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 3 deletions.
17 changes: 16 additions & 1 deletion srtcore/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5609,6 +5609,14 @@ bool srt::CUDT::prepareConnectionObjects(const CHandShake &hs, HandshakeSide hsd
return true;
}

int srt::CUDT::getAuthTagSize() const
{
if (m_pCryptoControl && m_pCryptoControl->getCryptoMode() == CSrtConfig::CIPHER_MODE_AES_GCM)
return HAICRYPT_AUTHTAG_MAX;

return 0;
}

bool srt::CUDT::prepareBuffers(CUDTException* eout)
{
if (m_pSndBuffer)
Expand All @@ -5620,7 +5628,14 @@ bool srt::CUDT::prepareBuffers(CUDTException* eout)
try
{
// CryptoControl has to be initialized and in case of RESPONDER the KM REQ must be processed (interpretSrtHandshake(..)) for the crypto mode to be deduced.
const int authtag = (m_pCryptoControl && m_pCryptoControl->getCryptoMode() == CSrtConfig::CIPHER_MODE_AES_GCM) ? HAICRYPT_AUTHTAG_MAX : 0;
const int authtag = getAuthTagSize();

SRT_ASSERT(m_iMaxSRTPayloadSize != 0);

HLOGC(rslog.Debug, log << CONID() << "Creating buffers: snd-plsize=" << m_iMaxSRTPayloadSize
<< " snd-bufsize=" << 32
<< " authtag=" << authtag);

m_pSndBuffer = new CSndBuffer(32, m_iMaxSRTPayloadSize, authtag);
SRT_ASSERT(m_iPeerISN != -1);
m_pRcvBuffer = new srt::CRcvBuffer(m_iPeerISN, m_config.iRcvBufSize, m_pRcvQueue->m_pUnitQueue, m_config.bMessageAPI);
Expand Down
1 change: 1 addition & 0 deletions srtcore/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ class CUDT
/// Allocates sender and receiver buffers and loss lists.
SRT_ATR_NODISCARD SRT_ATTR_REQUIRES(m_ConnectionLock)
bool prepareBuffers(CUDTException* eout);
int getAuthTagSize() const;

SRT_ATR_NODISCARD SRT_ATTR_REQUIRES(m_ConnectionLock)
EConnectStatus postConnect(const CPacket* response, bool rendezvous, CUDTException* eout) ATR_NOEXCEPT;
Expand Down
3 changes: 3 additions & 0 deletions srtcore/fec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,9 @@ void FECFilterBuiltin::ClipData(Group& g, uint16_t length_net, uint8_t kflg,
g.flag_clip = g.flag_clip ^ kflg;
g.timestamp_clip = g.timestamp_clip ^ timestamp_hw;

HLOGC(pflog.Debug, log << "FEC CLIP: data pkt.size=" << payload_size
<< " to a clip buffer size=" << payloadSize());

// Payload goes "as is".
for (size_t i = 0; i < payload_size; ++i)
{
Expand Down
8 changes: 7 additions & 1 deletion srtcore/packetfilter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,15 @@ bool srt::PacketFilter::configure(CUDT* parent, CUnitQueue* uq, const std::strin
init.socket_id = parent->socketID();
init.snd_isn = parent->sndSeqNo();
init.rcv_isn = parent->rcvSeqNo();
init.payload_size = parent->OPT_PayloadSize();

// XXX This is a formula for a full "SRT payload" part that undergoes transmission,
// might be nice to have this formula as something more general.
init.payload_size = parent->OPT_PayloadSize() + parent->getAuthTagSize();
init.rcvbuf_size = parent->m_config.iRcvBufSize;

HLOGC(pflog.Debug, log << "PFILTER: @" << init.socket_id << " payload size="
<< init.payload_size << " rcvbuf size=" << init.rcvbuf_size);

// Found a filter, so call the creation function
m_filter = selector->second->Create(init, m_provided, confstr);
if (!m_filter)
Expand Down
2 changes: 1 addition & 1 deletion srtcore/utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ class UniquePtr: public std::auto_ptr<T>
bool operator==(const element_type* two) const { return get() == two; }
bool operator!=(const element_type* two) const { return get() != two; }

operator bool () { return 0!= get(); }
operator bool () const { return 0!= get(); }
};

// A primitive one-argument versions of Sprint and Printable
Expand Down

0 comments on commit 1802702

Please sign in to comment.