diff --git a/pincer/client.py b/pincer/client.py index ec4f1476..85a6102a 100644 --- a/pincer/client.py +++ b/pincer/client.py @@ -17,15 +17,27 @@ from .core import HTTPClient from .core.gateway import Dispatcher from .exceptions import ( - InvalidEventName, TooManySetupArguments, NoValidSetupMethod, - NoCogManagerReturnFound, CogAlreadyExists, CogNotFound + InvalidEventName, + TooManySetupArguments, + NoValidSetupMethod, + NoCogManagerReturnFound, + CogAlreadyExists, + CogNotFound, ) from .middleware import middleware from .objects import ( - Role, Channel, DefaultThrottleHandler, User, Guild, Intents, - GuildTemplate, StickerPack + Role, + Channel, + DefaultThrottleHandler, + User, + Guild, + Intents, + GuildTemplate, + Connection, + StickerPack, File ) -from .utils.conversion import construct_client_dict +from .objects.guild.channel import GroupDMChannel +from .utils.conversion import construct_client_dict, remove_none from .utils.event_mgr import EventMgr from .utils.extraction import get_index from .utils.insertion import should_pass_cls @@ -40,6 +52,8 @@ from .objects.app.throttling import ThrottleInterface from .objects.guild import Webhook + from collections.abc import AsyncIterator + _log = logging.getLogger(__package__) MiddlewareType = Optional[Union[Coro, Tuple[str, List[Any], Dict[str, Any]]]] @@ -104,7 +118,8 @@ def decorator(func: Coro): if override: _log.warning( "Middleware overriding has been enabled for `%s`." - " This might cause unexpected behavior.", call + " This might cause unexpected behavior.", + call, ) if not override and callable(_events.get(call)): @@ -117,9 +132,7 @@ async def wrapper(cls, payload: GatewayDispatch): _log.debug("`%s` middleware has been invoked", call) return await ( - func(cls, payload) - if should_pass_cls(func) - else func(payload) + func(cls, payload) if should_pass_cls(func) else func(payload) ) _events[call] = wrapper @@ -166,12 +179,13 @@ class Client(Dispatcher): """ # noqa: E501 def __init__( - self, - token: str, *, - received: str = None, - intents: Union[Iterable, Intents] = None, - throttler: ThrottleInterface = DefaultThrottleHandler, - reconnect: bool = True, + self, + token: str, + *, + received: str = None, + intents: Intents = None, + throttler: ThrottleInterface = DefaultThrottleHandler, + reconnect: bool = True, ): if isinstance(intents, Iterable): @@ -183,7 +197,7 @@ def __init__( # Gets triggered on all events -1: self.payload_event_handler, # Use this event handler for opcode 0. - 0: self.event_handler + 0: self.event_handler, }, intents=intents or Intents.all(), reconnect=reconnect, @@ -209,10 +223,9 @@ def chat_commands(self) -> List[str]: Get a list of chat command calls which have been registered in the :class:`~pincer.commands.ChatCommandHandler`\\. """ - return list(map( - lambda cmd: cmd.app.name, - ChatCommandHandler.register.values() - )) + return list( + map(lambda cmd: cmd.app.name, ChatCommandHandler.register.values()) + ) @property def guild_ids(self) -> List[Snowflake]: @@ -274,8 +287,10 @@ async def on_ready(self): InvalidEventName If the function name is not a valid event (on_x) """ - if not iscoroutinefunction(coroutine) \ - and not isasyncgenfunction(coroutine): + if ( + not iscoroutinefunction(coroutine) + and not isasyncgenfunction(coroutine) + ): raise TypeError( "Any event which is registered must be a coroutine function" ) @@ -307,10 +322,17 @@ def get_event_coro(name: str) -> List[Optional[Coro]]: """ calls = _events.get(name.strip().lower()) - return [] if not calls else list(filter( - lambda call: iscoroutinefunction(call) or isasyncgenfunction(call), - calls - )) + return ( + [] + if not calls + else list( + filter( + lambda call: iscoroutinefunction(call) + or isasyncgenfunction(call), + calls, + ) + ) + ) def load_cog(self, path: str, package: Optional[str] = None): """Load a cog from a string path, setup method in COG may @@ -451,7 +473,7 @@ def execute_event(calls: List[Coro], *args, **kwargs): if should_pass_cls(call): call_args = ( ChatCommandHandler.managers[call.__module__], - *(arg for arg in args if arg is not None) + *(arg for arg in args if arg is not None), ) ensure_future(call(*call_args, **kwargs)) @@ -462,15 +484,11 @@ def run(self): def __del__(self): """Ensure close of the http client.""" - if hasattr(self, 'http'): + if hasattr(self, "http"): run(self.http.close()) async def handle_middleware( - self, - payload: GatewayDispatch, - key: str, - *args, - **kwargs + self, payload: GatewayDispatch, key: str, *args, **kwargs ) -> Tuple[Optional[Coro], List[Any], Dict[str, Any]]: """|coro| @@ -522,11 +540,7 @@ async def handle_middleware( ) async def execute_error( - self, - error: Exception, - name: str = "on_error", - *args, - **kwargs + self, error: Exception, name: str = "on_error", *args, **kwargs ): """|coro| @@ -623,7 +637,7 @@ async def create_guild( afk_channel_id: Optional[Snowflake] = None, afk_timeout: Optional[int] = None, system_channel_id: Optional[Snowflake] = None, - system_channel_flags: Optional[int] = None + system_channel_flags: Optional[int] = None, ) -> Guild: """Creates a guild. @@ -664,7 +678,7 @@ async def create_guild( async def create_guild(self, name: str, **kwargs) -> Guild: g = await self.http.post("guilds", data={"name": name, **kwargs}) - return await self.get_guild(g['id']) + return await self.get_guild(g["id"]) async def get_guild_template(self, code: str) -> GuildTemplate: """|coro| @@ -682,16 +696,12 @@ async def get_guild_template(self, code: str) -> GuildTemplate: """ return GuildTemplate.from_dict( construct_client_dict( - self, - await self.http.get(f"guilds/templates/{code}") + self, await self.http.get(f"guilds/templates/{code}") ) ) async def create_guild_from_template( - self, - template: GuildTemplate, - name: str, - icon: Optional[str] = None + self, template: GuildTemplate, name: str, icon: Optional[str] = None ) -> Guild: """|coro| Creates a guild from a template. @@ -715,16 +725,16 @@ async def create_guild_from_template( self, await self.http.post( f"guilds/templates/{template.code}", - data={"name": name, "icon": icon} - ) + data={"name": name, "icon": icon}, + ), ) ) async def wait_for( - self, - event_name: str, - check: CheckFunction = None, - timeout: Optional[float] = None + self, + event_name: str, + check: CheckFunction = None, + timeout: Optional[float] = None, ): """ Parameters @@ -745,11 +755,11 @@ async def wait_for( return await self.event_mgr.wait_for(event_name, check, timeout) def loop_for( - self, - event_name: str, - check: CheckFunction = None, - iteration_timeout: Optional[float] = None, - loop_timeout: Optional[float] = None + self, + event_name: str, + check: CheckFunction = None, + iteration_timeout: Optional[float] = None, + loop_timeout: Optional[float] = None, ): """ Parameters @@ -771,10 +781,7 @@ def loop_for( What the Discord API returns for this event. """ return self.event_mgr.loop_for( - event_name, - check, - iteration_timeout, - loop_timeout + event_name, check, iteration_timeout, loop_timeout ) async def get_guild(self, guild_id: int) -> Guild: @@ -853,9 +860,7 @@ async def get_channel(self, _id: int) -> Channel: return await Channel.from_id(self, _id) async def get_webhook( - self, - id: Snowflake, - token: Optional[str] = None + self, id: Snowflake, token: Optional[str] = None ) -> Webhook: """|coro| Fetch a Webhook from its identifier. @@ -875,6 +880,145 @@ async def get_webhook( """ return await Webhook.from_id(self, id, token) + + async def get_current_user(self) -> User: + """|coro| + The user object of the requester's account. + + For OAuth2, this requires the ``identify`` scope, + which will return the object *without* an email, + and optionally the ``email`` scope, + which returns the object *with* an email. + + Returns + ------- + :class:`~pincer.objects.user.user.User` + """ + return User.from_dict( + construct_client_dict( + self, + await self.http.get("users/@me") + ) + ) + + async def modify_current_user( + self, username: Optional[str] = None, avatar: Optional[File] = None + ) -> User: + """|coro| + Modify the requester's user account settings + + Parameters + ---------- + username : Optional[:class:`str`] + user's username, + if changed may cause the user's discriminator to be randomized. + |default| :data:`None` + avatar : Optional[:class:`File`] + if passed, modifies the user's avatar + a data URI scheme of JPG, GIF or PNG + |default| :data:`None` + + Returns + ------- + :class:`~pincer.objects.user.user.User` + Current modified user + """ + + avatar = avatar.uri if avatar else avatar + + user = await self.http.patch( + "users/@me", remove_none({"username": username, "avatar": avatar}) + ) + return User.from_dict(construct_client_dict(self, user)) + + async def get_current_user_guilds( + self, + before: Optional[Snowflake] = None, + after: Optional[Snowflake] = None, + limit: Optional[int] = None, + ) -> AsyncIterator[Guild]: + """|coro| + Returns a list of partial guild objects the current user is a member of. + Requires the ``guilds`` OAuth2 scope. + + Parameters + ---------- + before : Optional[:class:`~pincer.utils.snowflake.Snowflake`] + get guilds before this guild ID + after : Optional[:class:`~pincer.utils.snowflake.Snowflake`] + get guilds after this guild ID + limit : Optional[:class:`int`] + max number of guilds to return (1-200) |default| :data:`200` + + Yields + ------ + :class:`~pincer.objects.guild.guild.Guild` + A Partial Guild that the user is in + """ + guilds = await self.http.get( + "users/@me/guilds?" + + (f"{before=}&" if before else "") + + (f"{after=}&" if after else "") + + (f"{limit=}&" if limit else "") + ) + + for guild in guilds: + yield Guild.from_dict(construct_client_dict(self, guild)) + + async def leave_guild(self, _id: Snowflake): + """|coro| + Leave a guild. + + Parameters + ---------- + _id : :class:`~pincer.utils.snowflake.Snowflake` + the id of the guild that the bot will leave + """ + await self.http.delete(f"users/@me/guilds/{_id}") + self._client.guilds.pop(_id, None) + + async def create_group_dm( + self, access_tokens: List[str], nicks: Dict[Snowflake, str] + ) -> GroupDMChannel: + """|coro| + Create a new group DM channel with multiple users. + DMs created with this endpoint will not be shown in the Discord client + + Parameters + ---------- + access_tokens : List[:class:`str`] + access tokens of users that have + granted your app the ``gdm.join`` scope + + nicks : Dict[:class:`~pincer.utils.snowflake.Snowflake`, :class:`str`] + a dictionary of user ids to their respective nicknames + + Returns + ------- + :class:`~pincer.objects.guild.channel.GroupDMChannel` + group DM channel created + """ + channel = await self.http.post( + "users/@me/channels", + {"access_tokens": access_tokens, "nicks": nicks}, + ) + + return GroupDMChannel.from_dict(construct_client_dict(self, channel)) + + async def get_connections(self) -> AsyncIterator[Connection]: + """|coro| + Returns a list of connection objects. + Requires the ``connections`` OAuth2 scope. + + Yields + ------- + :class:`~pincer.objects.user.connection.Connections` + the connection objects + """ + connections = await self.http.get("users/@me/connections") + for conn in connections: + yield Connection.from_dict(conn) + async def sticker_packs(self) -> AsyncIterator[StickerPack]: """|coro| Yields sticker packs available to Nitro subscribers. diff --git a/pincer/objects/guild/channel.py b/pincer/objects/guild/channel.py index c67117ed..f14fc0e0 100644 --- a/pincer/objects/guild/channel.py +++ b/pincer/objects/guild/channel.py @@ -518,6 +518,8 @@ async def edit(self, **kwargs): """ return await super().edit(**kwargs) +class GroupDMChannel(Channel): + """A subclass of ``Channel`` for Group DMs""" class CategoryChannel(Channel): """A subclass of ``Channel`` for categories channels @@ -592,6 +594,7 @@ class ChannelMention(APIObject): _channel_type_map: Dict[ChannelType, Channel] = { ChannelType.GUILD_TEXT: TextChannel, ChannelType.GUILD_VOICE: VoiceChannel, + ChannelType.GROUP_DM: GroupDMChannel, ChannelType.GUILD_CATEGORY: CategoryChannel, ChannelType.GUILD_NEWS: NewsChannel, ChannelType.GUILD_PUBLIC_THREAD: PublicThread,