Skip to content

Commit

Permalink
Fix client tests
Browse files Browse the repository at this point in the history
  • Loading branch information
linuxdaemon committed Jun 1, 2020
1 parent 262e5c3 commit 7e2c2c5
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 36 deletions.
74 changes: 47 additions & 27 deletions cloudbot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
from cloudbot import clients
from cloudbot.client import Client
from cloudbot.config import Config
from cloudbot.event import Event, CommandEvent, RegexEvent, EventType
from cloudbot.event import CommandEvent, Event, EventType, RegexEvent
from cloudbot.hook import Action
from cloudbot.plugin import PluginManager
from cloudbot.reloader import PluginReloader, ConfigReloader
from cloudbot.util import database, formatting, async_util
from cloudbot.reloader import ConfigReloader, PluginReloader
from cloudbot.util import async_util, database, formatting
from cloudbot.util.mapping import KeyFoldDict

logger = logging.getLogger("cloudbot")
Expand Down Expand Up @@ -58,28 +58,34 @@ def clean_name(n):
:type n: str
:rtype: str
"""
return re.sub('[^A-Za-z0-9_]+', '', n.replace(" ", "_"))
return re.sub("[^A-Za-z0-9_]+", "", n.replace(" ", "_"))


def get_cmd_regex(event):
conn = event.conn
is_pm = event.chan.lower() == event.nick.lower()
command_prefix = re.escape(conn.config.get('command_prefix', '.'))
command_prefix = re.escape(conn.config.get("command_prefix", "."))
conn_nick = re.escape(event.conn.nick)
cmd_re = re.compile(
r"""
^
# Prefix or nick
(?:
(?P<prefix>[""" + command_prefix + r"""])""" + ('?' if is_pm else '') + r"""
(?P<prefix>["""
+ command_prefix
+ r"""])"""
+ ("?" if is_pm else "")
+ r"""
|
""" + conn_nick + r"""[,;:]+\s+
"""
+ conn_nick
+ r"""[,;:]+\s+
)
(?P<command>\w+) # Command
(?:$|\s+)
(?P<text>.*) # Text
""",
re.IGNORECASE | re.VERBOSE
re.IGNORECASE | re.VERBOSE,
)
return cmd_re

Expand Down Expand Up @@ -126,7 +132,7 @@ def __init__(self, loop=asyncio.get_event_loop()):
self.memory = collections.defaultdict()

# declare and create data folder
self.data_dir = os.path.abspath('data')
self.data_dir = os.path.abspath("data")
if not os.path.exists(self.data_dir):
logger.debug("Data folder not found, creating.")
os.mkdir(self.data_dir)
Expand All @@ -141,11 +147,14 @@ def __init__(self, loop=asyncio.get_event_loop()):
self.config_reloading_enabled = reloading_conf.get("config_reloading", True)

# this doesn't REALLY need to be here but it's nice
self.user_agent = self.config.get('user_agent', 'CloudBot/3.0 - CloudBot Refresh '
'<https://github.com/CloudBotIRC/CloudBot/>')
self.user_agent = self.config.get(
"user_agent",
"CloudBot/3.0 - CloudBot Refresh "
"<https://github.com/CloudBotIRC/CloudBot/>",
)

# setup db
db_path = self.config.get('database', 'sqlite:///cloudbot.db')
db_path = self.config.get("database", "sqlite:///cloudbot.db")
self.db_engine = create_engine(db_path)
self.db_factory = sessionmaker(bind=self.db_engine)
self.db_session = scoped_session(self.db_factory)
Expand Down Expand Up @@ -201,15 +210,14 @@ def register_client(self, name, cls):

def create_connections(self):
""" Create a BotConnection for all the networks defined in the config """
for config in self.config['connections']:
for config in self.config["connections"]:
# strip all spaces and capitalization from the connection name
name = clean_name(config['name'])
nick = config['nick']
name = clean_name(config["name"])
nick = config["nick"]
_type = config.get("type", "irc")

self.connections[name] = self.get_client(_type)(
self, _type, name, nick, config=config,
channels=config['channels']
self, _type, name, nick, config=config, channels=config["channels"]
)
logger.debug("[%s] Created connection.", name)

Expand Down Expand Up @@ -283,7 +291,9 @@ async def _init_routine(self):
conn.active = True

# Connect to servers
await asyncio.gather(*[conn.try_connect() for conn in self.connections.values()], loop=self.loop)
await asyncio.gather(
*[conn.try_connect() for conn in self.connections.values()], loop=self.loop
)
logger.debug("Connections created.")

# Run a manual garbage collection cycle, to clean up any unused objects created during initialization
Expand All @@ -294,7 +304,7 @@ def load_clients(self):
Load all clients from the "clients" directory
"""
scanner = Scanner(bot=self)
scanner.scan(clients, categories=['cloudbot.client'])
scanner.scan(clients, categories=["cloudbot.client"])

def process(self, event):
"""
Expand Down Expand Up @@ -348,12 +358,16 @@ def add_hook(hook, _event):
cmd_match = get_cmd_regex(event).match(event.content)

if cmd_match:
command_prefix = event.conn.config.get('command_prefix', '.')
prefix = cmd_match.group('prefix') or command_prefix[0]
command = cmd_match.group('command').lower()
text = cmd_match.group('text').strip()
command_prefix = event.conn.config.get("command_prefix", ".")
prefix = cmd_match.group("prefix") or command_prefix[0]
command = cmd_match.group("command").lower()
text = cmd_match.group("text").strip()
cmd_event = partial(
CommandEvent, text=text, triggered_command=command, base_event=event, cmd_prefix=prefix
CommandEvent,
text=text,
triggered_command=command,
base_event=event,
cmd_prefix=prefix,
)
if command in self.plugin_manager.commands:
command_hook = self.plugin_manager.commands[command]
Expand All @@ -373,7 +387,9 @@ def add_hook(hook, _event):
command_event = cmd_event(hook=command_hook)
add_hook(command_hook, command_event)
else:
commands = sorted(command for command, plugin in potential_matches)
commands = sorted(
command for command, plugin in potential_matches
)
txt_list = formatting.get_text_list(commands)
event.notice("Possible matches: {}".format(txt_list))

Expand All @@ -390,12 +406,16 @@ def add_hook(hook, _event):
regex_match = regex.search(event.content)
if regex_match:
regex_matched = True
regex_event = RegexEvent(hook=regex_hook, match=regex_match, base_event=event)
regex_event = RegexEvent(
hook=regex_hook, match=regex_match, base_event=event
)
if not add_hook(regex_hook, regex_event):
# The hook has an action of Action.HALT* so stop adding new tasks
break

tasks.sort(key=lambda t: t[0].priority)

for _hook, _event in tasks:
async_util.wrap_future(self.plugin_manager.launch(_hook, _event))
async_util.wrap_future(
self.plugin_manager.launch(_hook, _event), loop=self.loop
)
13 changes: 4 additions & 9 deletions tests/core_tests/irc_client_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
from asyncio import Task
from unittest.mock import MagicMock

from cloudbot.clients.irc import _IrcProtocol
Expand All @@ -11,13 +10,11 @@ def _filter_event(self, event):
return {k: v for k, v in dict(event).items() if not callable(v)}

def test_data_received(self):
conn, out, proto = self.make_proto()
_, out, proto = self.make_proto()
proto.data_received(
b":server.host COMMAND this is :a command\r\n:server.host PRIVMSG me :hi\r\n"
)

conn.loop.run_until_complete(asyncio.gather(*Task.all_tasks(conn.loop)))

assert out == [
{
"chan": None,
Expand Down Expand Up @@ -62,24 +59,22 @@ def test_data_received(self):
def make_proto(self):
conn = MagicMock()
conn.nick = "me"
conn.loop = asyncio.get_event_loop_policy().new_event_loop()
conn.loop = conn.bot.loop = asyncio.get_event_loop_policy().new_event_loop()
out = []

async def func(e):
def func(e):
out.append(self._filter_event(e))

conn.bot.process = func
proto = _IrcProtocol(conn)
return conn, out, proto

def test_broken_line_doesnt_interrupt(self):
conn, out, proto = self.make_proto()
_, out, proto = self.make_proto()
proto.data_received(
b":server.host COMMAND this is :a command\r\nPRIVMSG\r\n:server.host PRIVMSG me :hi\r\n"
)

conn.loop.run_until_complete(asyncio.gather(*Task.all_tasks(conn.loop)))

assert out == [
{
"chan": None,
Expand Down

0 comments on commit 7e2c2c5

Please sign in to comment.