diff --git a/pdns/dnsdistdist/dnsdist-nghttp2-in.cc b/pdns/dnsdistdist/dnsdist-nghttp2-in.cc index 4def059246cf..9907938dae43 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2-in.cc +++ b/pdns/dnsdistdist/dnsdist-nghttp2-in.cc @@ -401,9 +401,6 @@ void IncomingHTTP2Connection::handleIO() } } else { - d_currentPos = 0; - d_proxyProtocolNeed = 0; - d_buffer.clear(); d_state = State::waitingForQuery; handleConnectionReady(); } diff --git a/pdns/dnsdistdist/dnsdist-tcp.cc b/pdns/dnsdistdist/dnsdist-tcp.cc index 1e02000fe4ad..d2516352146c 100644 --- a/pdns/dnsdistdist/dnsdist-tcp.cc +++ b/pdns/dnsdistdist/dnsdist-tcp.cc @@ -902,6 +902,9 @@ IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::hand d_proxyProtocolValues = make_unique>(std::move(proxyProtocolValues)); } + d_currentPos = 0; + d_proxyProtocolNeed = 0; + d_buffer.clear(); return ProxyProtocolResult::Done; } } @@ -1084,15 +1087,14 @@ void IncomingTCPConnectionState::handleIO() if (!d_lastIOBlocked && d_state == State::readingProxyProtocolHeader) { auto status = handleProxyProtocolPayload(); if (status == ProxyProtocolResult::Done) { + d_buffer.resize(sizeof(uint16_t)); + if (isProxyPayloadOutsideTLS()) { d_state = State::doingHandshake; iostate = handleHandshake(now); } else { d_state = State::readingQuerySize; - d_buffer.resize(sizeof(uint16_t)); - d_currentPos = 0; - d_proxyProtocolNeed = 0; } } else if (status == ProxyProtocolResult::Error) { diff --git a/regression-tests.dnsdist/test_ProxyProtocol.py b/regression-tests.dnsdist/test_ProxyProtocol.py index 2ed60e08bc9f..78677b3a7149 100644 --- a/regression-tests.dnsdist/test_ProxyProtocol.py +++ b/regression-tests.dnsdist/test_ProxyProtocol.py @@ -142,7 +142,6 @@ class TestProxyProtocol(ProxyProtocolTest): addAction("values-action.proxy.tests.powerdns.com.", SetProxyProtocolValuesAction({ ["1"]="dnsdist", ["255"]="proxy-protocol"})) """ _config_params = ['_proxyResponderPort'] - _verboseMode = True def testProxyUDP(self): """ @@ -379,6 +378,8 @@ class TestProxyProtocolIncoming(ProxyProtocolTest): _config_template = """ addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=true}) addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=false}) + addTLSLocal("127.0.0.1:%d", "%s", "%s", {proxyProtocolOutsideTLS=true}) + addTLSLocal("127.0.0.1:%d", "%s", "%s", {proxyProtocolOutsideTLS=false}) setProxyProtocolACL( { "127.0.0.1/32" } ) newServer{address="127.0.0.1:%d", useProxyProtocol=true, proxyProtocolAdvertiseTLS=true} @@ -421,7 +422,9 @@ class TestProxyProtocolIncoming(ProxyProtocolTest): _caCert = 'ca.pem' _dohServerPPOutsidePort = pickAvailablePort() _dohServerPPInsidePort = pickAvailablePort() - _config_params = ['_dohServerPPOutsidePort', '_serverCert', '_serverKey', '_dohServerPPInsidePort', '_serverCert', '_serverKey', '_proxyResponderPort'] + _dotServerPPOutsidePort = pickAvailablePort() + _dotServerPPInsidePort = pickAvailablePort() + _config_params = ['_dohServerPPOutsidePort', '_serverCert', '_serverKey', '_dohServerPPInsidePort', '_serverCert', '_serverKey', '_dotServerPPOutsidePort', '_serverCert', '_serverKey', '_dotServerPPInsidePort', '_serverCert', '_serverKey', '_proxyResponderPort'] def testNoHeader(self): """ @@ -666,7 +669,7 @@ def testProxyDoHSeveralQueriesOverConnectionPPOutside(self): conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0) reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort)) - (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn) + (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, conn=conn) (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) self.assertTrue(receivedProxyPayload) self.assertTrue(receivedDNSData) @@ -682,7 +685,7 @@ def testProxyDoHSeveralQueriesOverConnectionPPOutside(self): for idx in range(5): receivedResponse = None toProxyQueue.put(response, True, 2.0) - (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn) + (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, conn=conn) (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) self.assertTrue(receivedProxyPayload) self.assertTrue(receivedDNSData) @@ -719,7 +722,7 @@ def testProxyDoHSeveralQueriesOverConnectionPPInside(self): conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0) reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort)) - (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn) + (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, conn=conn) (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) self.assertTrue(receivedProxyPayload) self.assertTrue(receivedDNSData) @@ -735,7 +738,7 @@ def testProxyDoHSeveralQueriesOverConnectionPPInside(self): for idx in range(5): receivedResponse = None toProxyQueue.put(response, True, 2.0) - (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn) + (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, conn=conn) (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) self.assertTrue(receivedProxyPayload) self.assertTrue(receivedDNSData) @@ -748,6 +751,108 @@ def testProxyDoHSeveralQueriesOverConnectionPPInside(self): self.assertEqual(receivedResponse, response) self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort) + def testProxyDoTSeveralQueriesOverConnectionPPOutside(self): + """ + Incoming Proxy Protocol: Several queries over the same connection (DoT, PP outside TLS) + """ + name = 'several-queries.dot-outside.proxy-protocol-incoming.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + + toProxyQueue.put(response, True, 2.0) + + wire = query.to_wire() + + reverseProxyPort = pickAvailablePort() + reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dotServerPPOutsidePort]) + reverseProxy.start() + time.sleep(1) + + receivedResponse = None + conn = self.openTLSConnection(reverseProxyPort, self._serverName, self._caCert, timeout=2.0) + self.sendTCPQueryOverConnection(conn, query, response=response) + receivedResponse = self.recvTCPResponseOverConnection(conn) + (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) + self.assertTrue(receivedProxyPayload) + self.assertTrue(receivedDNSData) + self.assertTrue(receivedResponse) + + receivedQuery = dns.message.from_wire(receivedDNSData) + receivedQuery.id = query.id + receivedResponse.id = response.id + self.assertEqual(receivedQuery, query) + self.assertEqual(receivedResponse, response) + self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort) + + for idx in range(5): + receivedResponse = None + toProxyQueue.put(response, True, 2.0) + self.sendTCPQueryOverConnection(conn, query, response=response) + receivedResponse = self.recvTCPResponseOverConnection(conn) + (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) + self.assertTrue(receivedProxyPayload) + self.assertTrue(receivedDNSData) + self.assertTrue(receivedResponse) + + receivedQuery = dns.message.from_wire(receivedDNSData) + receivedQuery.id = query.id + receivedResponse.id = response.id + self.assertEqual(receivedQuery, query) + self.assertEqual(receivedResponse, response) + self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort) + + def testProxyDoTSeveralQueriesOverConnectionPPInside(self): + """ + Incoming Proxy Protocol: Several queries over the same connection (DoT, PP inside TLS) + """ + name = 'several-queries.dot-inside.proxy-protocol-incoming.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + + toProxyQueue.put(response, True, 2.0) + + wire = query.to_wire() + + reverseProxyPort = pickAvailablePort() + tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + tlsContext.load_cert_chain(self._serverCert, self._serverKey) + tlsContext.set_alpn_protocols(['dot']) + reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dotServerPPInsidePort, tlsContext, self._caCert, self._serverName]) + reverseProxy.start() + + receivedResponse = None + time.sleep(1) + conn = self.openTLSConnection(reverseProxyPort, self._serverName, self._caCert, timeout=2.0) + + self.sendTCPQueryOverConnection(conn, query, response=response) + receivedResponse = self.recvTCPResponseOverConnection(conn) + (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) + self.assertTrue(receivedProxyPayload) + self.assertTrue(receivedDNSData) + self.assertTrue(receivedResponse) + + receivedQuery = dns.message.from_wire(receivedDNSData) + receivedQuery.id = query.id + receivedResponse.id = response.id + self.assertEqual(receivedQuery, query) + self.assertEqual(receivedResponse, response) + + for idx in range(5): + receivedResponse = None + toProxyQueue.put(response, True, 2.0) + self.sendTCPQueryOverConnection(conn, query, response=response) + receivedResponse = self.recvTCPResponseOverConnection(conn) + (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) + self.assertTrue(receivedProxyPayload) + self.assertTrue(receivedDNSData) + self.assertTrue(receivedResponse) + + receivedQuery = dns.message.from_wire(receivedDNSData) + receivedQuery.id = query.id + receivedResponse.id = response.id + self.assertEqual(receivedQuery, query) + self.assertEqual(receivedResponse, response) + @classmethod def tearDownClass(cls): cls._sock.close() @@ -768,7 +873,6 @@ class TestProxyProtocolNotExpected(DNSDistTest): """ # NORMAL responder, does not expect a proxy protocol payload! _config_params = ['_testServerPort'] - _verboseMode = True def testNoHeader(self): """ @@ -910,7 +1014,6 @@ class TestDOHWithOutgoingProxyProtocol(DNSDistDOHTest): setACL( { "::1/128", "127.0.0.0/8" } ) """ _config_params = ['_proxyResponderPort', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_dohWithH2OServerPort', '_serverCert', '_serverKey'] - _verboseMode = True def testTruncation(self): """