Skip to content

Commit

Permalink
STS: When persisting STS keys, use the actual port instead of the one…
Browse files Browse the repository at this point in the history
… from the policy

'Servers MAY send this key to securely connected clients, but it will be ignored.'
-- https://ircv3.net/specs/extensions/sts#the-port-key
  • Loading branch information
progval committed Sep 3, 2021
1 parent 74073b2 commit 015ac4a
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 43 deletions.
8 changes: 6 additions & 2 deletions plugins/Channel/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,8 +950,12 @@ def alertOps(self, irc, msg, channel, s, frm=None):
if frm is not None:
s += format(_(' (from %s)'), frm)
for nick in irc.state.channels[channel].users:
if ircdb.checkCapability(msg.prefix, capability):
irc.reply(s, to=nick, private=True)
prefix = irc.state.nicksToHostmasks.get(nick)
if not prefix:
continue
if not ircdb.checkCapability(prefix, capability):
continue
irc.reply(s, to=nick, private=True)
irc.replySuccess()

@internationalizeDocstring
Expand Down
11 changes: 6 additions & 5 deletions src/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def _getNextServer(self):

def _applyStsPolicy(self, server):
network = ircdb.networks.getNetwork(self.networkName)
policy = network.stsPolicies.get(server.hostname)
(policy_port, policy) = network.stsPolicies.get(
server.hostname, (None, None))
lastDisconnect = network.lastDisconnectTimes.get(server.hostname)

if policy is None or lastDisconnect is None:
Expand All @@ -107,22 +108,22 @@ def _applyStsPolicy(self, server):

# The policy was stored, which means it was received on a secure
# connection.
policy = ircutils.parseStsPolicy(log, policy, parseDuration=True)
policy = ircutils.parseStsPolicy(log, policy, secure_connection=True)

if lastDisconnect + policy['duration'] < time.time():
log.info('STS policy expired, removing.')
network.expireStsPolicy(server.hostname)
return server

if server.port == policy['port']:
if server.port == policy_port:
log.info('Using STS policy, port %s', server.port)
else:
log.info('Using STS policy: changing port from %s to %s.',
server.port, policy['port'])
server.port, policy_port)

# Change the port, and force TLS verification, as required by the STS
# specification.
return Server(server.hostname, policy['port'], server.attempt,
return Server(server.hostname, policy_port, server.attempt,
force_tls_verification=True)

def die(self):
Expand Down
21 changes: 14 additions & 7 deletions src/ircdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,10 @@ def __repr__(self):
(self.__class__.__name__, self.stsPolicies,
self.lastDisconnectTimes)

def addStsPolicy(self, server, stsPolicy):
assert isinstance(stsPolicy, str)
self.stsPolicies[server] = stsPolicy
def addStsPolicy(self, server, port, stsPolicy):
assert isinstance(port, int), repr(port)
assert isinstance(stsPolicy, str), repr(stsPolicy)
self.stsPolicies[server] = (port, stsPolicy)

def expireStsPolicy(self, server):
if server in self.stsPolicies:
Expand All @@ -526,8 +527,10 @@ def write(s):
fd.write(s)
fd.write(os.linesep)

for (server, stsPolicy) in sorted(self.stsPolicies.items()):
write('stsPolicy %s %s' % (server, stsPolicy))
for (server, (port, stsPolicy)) in sorted(self.stsPolicies.items()):
assert isinstance(port, int), repr(port)
assert isinstance(stsPolicy, str), repr(stsPolicy)
write('stsPolicy %s %s %s' % (server, port, stsPolicy))

for (server, disconnectTime) in \
sorted(self.lastDisconnectTimes.items()):
Expand Down Expand Up @@ -667,8 +670,12 @@ def network(self, rest, lineno):
IrcNetworkCreator.name = rest

def stspolicy(self, rest, lineno):
(server, stsPolicy) = rest.split()
self.net.addStsPolicy(server, stsPolicy)
L = rest.split()
if len(L) == 2:
# Old policy missing a port. Discard it
return
(server, policyPort, stsPolicy) = L
self.net.addStsPolicy(server, int(policyPort), stsPolicy)

def lastdisconnecttime(self, rest, lineno):
(server, when) = rest.split()
Expand Down
11 changes: 8 additions & 3 deletions src/irclib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2050,7 +2050,7 @@ def _onCapSts(self, policy, msg):
or (self.driver.ssl and self.driver.anyCertValidationEnabled())

parsed_policy = ircutils.parseStsPolicy(
log, policy, parseDuration=secure_connection)
log, policy, secure_connection=secure_connection)
if parsed_policy is None:
# There was an error (and it was logged). Ignore it and proceed
# with the connection.
Expand All @@ -2065,9 +2065,14 @@ def _onCapSts(self, policy, msg):
# For future-proofing (because we don't want to write an invalid
# value), we write the raw policy received from the server instead
# of the parsed one.
log.debug('Storing STS policy: %s', policy)
log.debug('Storing STS policy for %s (TLS port %s): %s',
self.driver.currentServer.hostname,
self.driver.currentServer.port,
policy)
ircdb.networks.getNetwork(self.network).addStsPolicy(
self.driver.currentServer.hostname, policy)
self.driver.currentServer.hostname,
self.driver.currentServer.port,
policy)
else:
hostname = self.driver.currentServer.hostname
attempt = self.driver.currentServer.attempt
Expand Down
8 changes: 6 additions & 2 deletions src/ircutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,11 +1073,15 @@ def parseCapabilityKeyValue(s):
return d


def parseStsPolicy(logger, policy, parseDuration):
def parseStsPolicy(logger, policy, secure_connection):
parsed_policy = parseCapabilityKeyValue(policy)

for key in ('port', 'duration'):
if key == 'duration' and not parseDuration:
if key == 'duration' and not secure_connection:
if key in parsed_policy:
del parsed_policy[key]
continue
elif key == 'port' and secure_connection:
if key in parsed_policy:
del parsed_policy[key]
continue
Expand Down
6 changes: 3 additions & 3 deletions test/test_drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def tearDown(self):
def testValidStsPolicy(self):
irc = irclib.Irc('test')
net = ircdb.networks.getNetwork('test')
net.addStsPolicy('example.com', 'duration=10,port=6697')
net.addStsPolicy('example.com', 6697, 'duration=10,port=12345')
net.addDisconnection('example.com')

with conf.supybot.networks.test.servers.context(
Expand All @@ -64,7 +64,7 @@ def testValidStsPolicy(self):
def testExpiredStsPolicy(self):
irc = irclib.Irc('test')
net = ircdb.networks.getNetwork('test')
net.addStsPolicy('example.com', 'duration=10,port=6697')
net.addStsPolicy('example.com', 6697, 'duration=10')
net.addDisconnection('example.com')

timeFastForward(16)
Expand All @@ -81,7 +81,7 @@ def testExpiredStsPolicy(self):
def testRescheduledStsPolicy(self):
irc = irclib.Irc('test')
net = ircdb.networks.getNetwork('test')
net.addStsPolicy('example.com', 'duration=10,port=6697')
net.addStsPolicy('example.com', 6697, 'duration=10')
net.addDisconnection('example.com')

with conf.supybot.networks.test.servers.context(
Expand Down
36 changes: 18 additions & 18 deletions test/test_ircdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,11 @@ def testDefaults(self):

def testStsPolicy(self):
n = ircdb.IrcNetwork()
n.addStsPolicy('foo', 'bar')
n.addStsPolicy('baz', 'qux')
n.addStsPolicy('foo', 123, 'bar')
n.addStsPolicy('baz', 456, 'qux')
self.assertEqual(n.stsPolicies, {
'foo': 'bar',
'baz': 'qux',
'foo': (123, 'bar'),
'baz': (456, 'qux'),
})

def testAddDisconnection(self):
Expand All @@ -374,8 +374,8 @@ def testAddDisconnection(self):

def testPreserve(self):
n = ircdb.IrcNetwork()
n.addStsPolicy('foo', 'sts1')
n.addStsPolicy('bar', 'sts2')
n.addStsPolicy('foo', 123, 'sts1')
n.addStsPolicy('bar', 456,'sts2')
n.addDisconnection('foo')
n.addDisconnection('baz')
disconnect_time_foo = n.lastDisconnectTimes['foo']
Expand All @@ -384,8 +384,8 @@ def testPreserve(self):
n.preserve(fd, indent=' ')
fd.seek(0)
self.assertCountEqual(fd.read().split('\n'), [
' stsPolicy foo sts1',
' stsPolicy bar sts2',
' stsPolicy foo 123 sts1',
' stsPolicy bar 456 sts2',
' lastDisconnectTime foo %d' % disconnect_time_foo,
' lastDisconnectTime baz %d' % disconnect_time_baz,
'',
Expand Down Expand Up @@ -467,8 +467,8 @@ def testGetSetNetwork(self):

def testPreserveOne(self):
n = ircdb.IrcNetwork()
n.addStsPolicy('foo', 'sts1')
n.addStsPolicy('bar', 'sts2')
n.addStsPolicy('foo', 123, 'sts1')
n.addStsPolicy('bar', 456, 'sts2')
n.addDisconnection('foo')
n.addDisconnection('baz')
disconnect_time_foo = n.lastDisconnectTimes['foo']
Expand All @@ -486,8 +486,8 @@ def testPreserveOne(self):
lines = fd.getvalue().split('\n')
self.assertEqual(lines.pop(0), 'network foonet')
self.assertCountEqual(lines, [
' stsPolicy foo sts1',
' stsPolicy bar sts2',
' stsPolicy foo 123 sts1',
' stsPolicy bar 456 sts2',
' lastDisconnectTime foo %d' % disconnect_time_foo,
' lastDisconnectTime baz %d' % disconnect_time_baz,
'',
Expand All @@ -496,15 +496,15 @@ def testPreserveOne(self):

def testPreserveThree(self):
n = ircdb.IrcNetwork()
n.addStsPolicy('foo', 'sts1')
n.addStsPolicy('foo', 123, 'sts1')
self.networks.setNetwork('foonet', n)

n = ircdb.IrcNetwork()
n.addStsPolicy('bar', 'sts2')
n.addStsPolicy('bar', 456, 'sts2')
self.networks.setNetwork('barnet', n)

n = ircdb.IrcNetwork()
n.addStsPolicy('baz', 'sts3')
n.addStsPolicy('baz', 789, 'sts3')
self.networks.setNetwork('baznet', n)

fd = io.StringIO()
Expand All @@ -518,13 +518,13 @@ def testPreserveThree(self):
fd.seek(0)
self.assertEqual(fd.getvalue(),
'network barnet\n'
' stsPolicy bar sts2\n'
' stsPolicy bar 456 sts2\n'
'\n'
'network baznet\n'
' stsPolicy baz sts3\n'
' stsPolicy baz 789 sts3\n'
'\n'
'network foonet\n'
' stsPolicy foo sts1\n'
' stsPolicy foo 123 sts1\n'
'\n'
)

Expand Down
17 changes: 14 additions & 3 deletions test/test_irclib.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,16 +759,27 @@ def testStsInSecureConnection(self):
self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697')))
args=('*', 'LS', 'sts=duration=42,port=12345')))

self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {
'irc.test': 'duration=42,port=6697'})
'irc.test': (6697, 'duration=42,port=12345')})
self.irc.driver.reconnect.assert_not_called()

def testStsInSecureConnectionNoPort(self):
self.irc.driver.anyCertValidationEnabled.return_value = True
self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42')))

self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {
'irc.test': (6697, 'duration=42')})
self.irc.driver.reconnect.assert_not_called()

def testStsInInsecureTlsConnection(self):
self.irc.driver.anyCertValidationEnabled.return_value = False
self.irc.driver.ssl = True
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False)
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
args=('*', 'LS', 'sts=duration=42,port=6697')))

Expand Down

0 comments on commit 015ac4a

Please sign in to comment.