From 3571e84f78454bf9c48b1e427c1acf6a5fae2bdd Mon Sep 17 00:00:00 2001 From: Joshua Koike Date: Fri, 3 Mar 2017 00:18:15 -0700 Subject: [PATCH] Change backing REST lib to Req (From Wreq) (#11) * Keep only REST * Qualified Wreq * Ported to Req, rest sample stub * Fixes/cleanup * Removed rest-only example, added missing ext to Framework * Separate HTTP details * Fighting type hell * Fixed type issues, thx @Lazersmoke * Documentation * Fixed GetChannelMessages args, removed redundant return * Changes for Req * Add changelog --- CHANGELOG.md | 4 + discord-hs.cabal | 11 +- examples/putstr.hs | 2 +- src/Network/Discord/Framework.hs | 10 +- src/Network/Discord/Gateway.hs | 3 +- src/Network/Discord/Rest.hs | 22 ++- src/Network/Discord/Rest/Channel.hs | 185 ++++++++-------------- src/Network/Discord/Rest/Guild.hs | 236 ++++++++++------------------ src/Network/Discord/Rest/HTTP.hs | 76 +++++++++ src/Network/Discord/Rest/Prelude.hs | 51 +++--- src/Network/Discord/Rest/User.hs | 90 ++++------- src/Network/Discord/Types.hs | 27 +++- src/Network/Discord/Types/Events.hs | 1 + stack.yaml | 9 +- 14 files changed, 345 insertions(+), 382 deletions(-) create mode 100644 CHANGELOG.md create mode 100644 src/Network/Discord/Rest/HTTP.hs diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..0255964 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,4 @@ +Changes: +- Switch from Req to Wreq (See issue #9) +- Breaking api change to UploadFile, now UploadFile fileName file + (file arg remains a LBS) diff --git a/discord-hs.cabal b/discord-hs.cabal index fcf4f8a..0c49931 100755 --- a/discord-hs.cabal +++ b/discord-hs.cabal @@ -23,20 +23,21 @@ Flag disable-docs library exposed-modules: Network.Discord - , Network.Discord.Framework - , Network.Discord.Gateway , Network.Discord.Rest , Network.Discord.Rest.Channel , Network.Discord.Rest.Guild , Network.Discord.Rest.User + , Network.Discord.Framework + , Network.Discord.Gateway , Network.Discord.Types , Network.Discord.Types.Channel + , Network.Discord.Types.Guild , Network.Discord.Types.Events , Network.Discord.Types.Gateway - , Network.Discord.Types.Guild other-modules: Paths_discord_hs , Network.Discord.Rest.Prelude , Network.Discord.Types.Prelude + , Network.Discord.Rest.HTTP -- other-extensions: build-depends: base==4.* , aeson==1.0.* @@ -46,7 +47,7 @@ library , data-default==0.7.* , hashable==1.2.* , hslogger==1.2.* - , lens==4.15.* + , http-client==0.5.* , mmorph==1.0.* , mtl==2.2.* , pipes==4.3.* @@ -59,7 +60,7 @@ library , url==2.1.* , vector==0.11.* , websockets==0.10.* - , wreq==0.5.* + , req==0.2.* , wuss==1.1.* ghc-options: -Wall hs-source-dirs: src diff --git a/examples/putstr.hs b/examples/putstr.hs index eb24d81..23119af 100644 --- a/examples/putstr.hs +++ b/examples/putstr.hs @@ -11,7 +11,7 @@ import Network.Discord.Gateway data PutStrClient = PsClient instance Client PutStrClient where - getAuth _ = Bot "TOKEN HERE" + getAuth _ = Bot "TOKEN" main :: IO () main = runWebsocket (fromJust $ importURL "wss://gateway.discord.gg") PsClient $ do diff --git a/src/Network/Discord/Framework.hs b/src/Network/Discord/Framework.hs index 40ecd2e..f2d86e0 100644 --- a/src/Network/Discord/Framework.hs +++ b/src/Network/Discord/Framework.hs @@ -7,7 +7,7 @@ module Network.Discord.Framework where import Data.Proxy import Control.Concurrent.STM - import Control.Monad.State (execStateT, get) + import Control.Monad.State (get) import Data.Aeson (Object) import Pipes ((~>)) import Pipes.Core hiding (Proxy) @@ -27,7 +27,7 @@ module Network.Discord.Framework where undefined undefined limits - + -- | Basic client implementation. Most likely suitable for most bots. data BotClient = BotClient Auth instance D.Client BotClient where @@ -51,8 +51,8 @@ module Network.Discord.Framework where runAsync c effect = do client <- liftIO . atomically $ getSTMClient c st <- asyncState client - liftIO . void $ forkFinally - (execStateT (runEffect effect) st) + liftIO . void $ forkFinally + (execDiscordM (runEffect effect) st) finish where finish (Right DiscordState{getClient = st}) = atomically $ mergeClient st @@ -63,7 +63,7 @@ module Network.Discord.Framework where -- | Event handlers for 'Gateway' events. These correspond to events listed in -- 'Event' - data D.Client c => Handle c = Null + data D.Client c => Handle c = Null | Misc (Event -> Effect DiscordM ()) | ReadyEvent (Init -> Effect DiscordM ()) | ResumedEvent (Object -> Effect DiscordM ()) diff --git a/src/Network/Discord/Gateway.hs b/src/Network/Discord/Gateway.hs index 8f7cb13..2284dcd 100755 --- a/src/Network/Discord/Gateway.hs +++ b/src/Network/Discord/Gateway.hs @@ -37,7 +37,7 @@ module Network.Discord.Gateway where runWebsocket (URL (Absolute h) path _) client inner = do rl <- newTVarIO [] runSecureClient (host h) 443 (path++"/?v=6") - $ \conn -> evalStateT (runEffect inner) + $ \conn -> evalDiscordM (runEffect inner) (DiscordState Create client conn undefined rl) runWebsocket _ _ _ = mzero @@ -94,3 +94,4 @@ module Network.Discord.Gateway where -- 'Connection' to a stream of gateway 'Event's eventCore :: Connection -> Producer Event DiscordM () eventCore conn = makeWebsocketSource conn >-> makeEvents + diff --git a/src/Network/Discord/Rest.hs b/src/Network/Discord/Rest.hs index d97acb5..ac69d6a 100644 --- a/src/Network/Discord/Rest.hs +++ b/src/Network/Discord/Rest.hs @@ -1,6 +1,7 @@ {-# LANGUAGE ExistentialQuantification, MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings, FlexibleInstances #-} {-# OPTIONS_HADDOCK prune, not-home #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} -- | Provides framework to interact with REST api gateways. Implementations specific to the -- Discord API are provided in Network.Discord.Rest.Channel, Network.Discord.Rest.Guild, -- and Network.Discord.Rest.User. @@ -13,13 +14,13 @@ module Network.Discord.Rest ) where import Control.Monad (void) import Data.Maybe (fromJust) + import Control.Exception (throwIO) - import Control.Lens + import qualified Network.HTTP.Req as R import Control.Monad.Morph (lift) import Data.Aeson.Types import Data.Hashable import Network.URL - import Network.Wreq import Pipes.Core import Network.Discord.Types as Dc @@ -27,32 +28,37 @@ module Network.Discord.Rest import Network.Discord.Rest.Guild import Network.Discord.Rest.Prelude import Network.Discord.Rest.User - + import Network.Discord.Rest.HTTP (baseUrl) + -- | Perform an API request. fetch :: (DoFetch a, Hashable a) => a -> Pipes.Core.Proxy X () c' c DiscordM Fetched fetch req = restServer +>> (request $ Fetch req) - + -- | Perform an API request, ignoring the response fetch' :: (DoFetch a, Hashable a) => a -> Pipes.Core.Proxy X () c' c DiscordM () fetch' = void . fetch - + -- | Alternative method of interacting with the REST api withApi :: Pipes.Core.Client Fetchable Fetched DiscordM Fetched -> Effect DiscordM () withApi inner = void $ restServer +>> inner - + -- | Provides a pipe to perform REST actions restServer :: Fetchable -> Server Fetchable Fetched DiscordM Fetched restServer req = lift (doFetch req) >>= respond >>= restServer + instance R.MonadHttp IO where + handleHttpException = throwIO + -- | Obtains a new gateway to connect to. getGateway :: IO URL getGateway = do - resp <- asValue =<< get (baseURL++"/gateway") - return . fromJust $ importURL =<< parseMaybe getURL (resp ^. responseBody) + r <- R.req R.GET (baseUrl R./: "gateway") R.NoReqBody R.jsonResponse mempty + return . fromJust $ importURL =<< parseMaybe getURL (R.responseBody r) + where getURL :: Value -> Parser String getURL = withObject "url" (.: "url") diff --git a/src/Network/Discord/Rest/Channel.hs b/src/Network/Discord/Rest/Channel.hs index 35315d9..b21a78f 100644 --- a/src/Network/Discord/Rest/Channel.hs +++ b/src/Network/Discord/Rest/Channel.hs @@ -5,22 +5,19 @@ module Network.Discord.Rest.Channel ( ChannelRequest(..) ) where - import Control.Monad (when) - - import Control.Concurrent.STM - import Control.Lens - import Control.Monad.Morph (lift) + import Data.Aeson import Data.ByteString.Lazy import Data.Hashable - import Data.Monoid ((<>)) - import Data.Text - import Data.Time.Clock.POSIX - import Network.Wreq - import qualified Control.Monad.State as ST (get, liftIO) - + import Data.Monoid (mempty, (<>)) + import Data.Text as T + import Network.HTTP.Client (RequestBody (..)) + import Network.HTTP.Client.MultipartFormData (partFileRequestBody) + import Network.HTTP.Req (reqBodyMultipart) + import Network.Discord.Rest.Prelude - import Network.Discord.Types as Dc + import Network.Discord.Types + import Network.Discord.Rest.HTTP -- | Data constructor for Channel requests. See data ChannelRequest a where @@ -31,13 +28,13 @@ module Network.Discord.Rest.Channel -- | Deletes a channel if its id doesn't equal to the id of guild. DeleteChannel :: Snowflake -> ChannelRequest Channel -- | Gets a messages from a channel with limit of 100 per request. - GetChannelMessages :: Snowflake -> [(Text, Text)] -> ChannelRequest [Message] + GetChannelMessages :: Snowflake -> Range -> ChannelRequest [Message] -- | Gets a message in a channel by its id. GetChannelMessage :: Snowflake -> Snowflake -> ChannelRequest Message -- | Sends a message to a channel. CreateMessage :: Snowflake -> Text -> Maybe Embed -> ChannelRequest Message -- | Sends a message with a file to a channel. - UploadFile :: Snowflake -> Text -> ByteString -> ChannelRequest Message + UploadFile :: Snowflake -> FilePath -> ByteString -> ChannelRequest Message -- | Edits a message content. EditMessage :: Message -> Text -> Maybe Embed -> ChannelRequest Message -- | Deletes a message. @@ -83,109 +80,61 @@ module Network.Discord.Rest.Channel hashWithSalt s (AddPinnedMessage chan _) = hashWithSalt s ("pin"::Text, chan) hashWithSalt s (DeletePinnedMessage chan _) = hashWithSalt s ("pin"::Text, chan) - instance Eq (ChannelRequest a) where - a == b = hash a == hash b - - instance RateLimit (ChannelRequest a) where - getRateLimit req = do - DiscordState {getRateLimits=rl} <- ST.get - now <- ST.liftIO (fmap round getPOSIXTime :: IO Int) - ST.liftIO . atomically $ do - rateLimits <- readTVar rl - case lookup (hash req) rateLimits of - Nothing -> return Nothing - Just a - | a >= now -> return $ Just a - | otherwise -> modifyTVar' rl (Dc.delete $ hash req) >> return Nothing - - setRateLimit req reset = do - DiscordState {getRateLimits=rl} <- ST.get - ST.liftIO . atomically . modifyTVar rl $ Dc.insert (hash req) reset + instance RateLimit (ChannelRequest a) instance (FromJSON a) => DoFetch (ChannelRequest a) where - doFetch req = do - waitRateLimit req - SyncFetched <$> fetch req - - -- |Sends a request, used by doFetch. - fetch :: FromJSON a => ChannelRequest a -> DiscordM a - fetch request = do - req <- baseRequest - (resp, rlRem, rlNext) <- lift $ do - resp <- case request of - GetChannel chan -> getWith req - (baseURL ++ "/channels/" ++ show chan) - - ModifyChannel chan patch -> customPayloadMethodWith "PATCH" req - (baseURL ++ "/channels/" ++ show chan) - (toJSON patch) - - DeleteChannel chan -> deleteWith req - (baseURL ++ "/channels/" ++ show chan) - - GetChannelMessages chan patch -> getWith - (Prelude.foldr (\(k, v) -> param k .~ [v]) req patch) - (baseURL ++ "/channels/" ++ show chan ++ "/messages") - - GetChannelMessage chan msg -> getWith req - (baseURL ++ "/channels/" ++ show chan ++ "/messages/" ++ show msg) - - CreateMessage chan msg embed -> postWith req - (baseURL ++ "/channels/" ++ show chan ++ "/messages") - (object $ [("content", toJSON msg)] <> maybeEmbed embed) - - UploadFile chan msg file -> postWith - (req & header "Content-Type" .~ ["multipart/form-data"]) - (baseURL ++ "/channels/" ++ show chan ++ "/messages") - ["content" := msg, "file" := file] - - EditMessage (Message msg chan _ _ _ _ _ _ _ _ _ _ _ _) new embed -> - customPayloadMethodWith "PATCH" req - (baseURL ++ "/channels/" ++ show chan ++ "/messages/" ++ show msg) - (object $ [("content", toJSON new)] <> maybeEmbed embed) - - DeleteMessage (Message msg chan _ _ _ _ _ _ _ _ _ _ _ _) -> - deleteWith req - (baseURL ++ "/channels/" ++ show chan ++ "/messages/" ++ show msg) - - BulkDeleteMessage chan msgs -> postWith req - (baseURL ++ "/channels/" ++ show chan ++ "/messages/bulk-delete") - (object - [("messages", toJSON - $ Prelude.map (\(Message msg _ _ _ _ _ _ _ _ _ _ _ _ _) -> msg) msgs)]) - - EditChannelPermissions chan perm patch -> putWith req - (baseURL ++ "/channels/" ++ show chan ++ "/permissions/" ++ show perm) - (toJSON patch) - - GetChannelInvites chan -> getWith req - (baseURL ++ "/channels/" ++ show chan ++ "/invites") - - CreateChannelInvite chan patch -> postWith req - (baseURL ++ "/channels/" ++ show chan ++ "/invites") - (toJSON patch) - - DeleteChannelPermission chan perm -> deleteWith req - (baseURL ++ "/channels/" ++ show chan ++ "/permissions/" ++ show perm) - - TriggerTypingIndicator chan -> postWith req - (baseURL ++ "/channels/" ++ show chan ++ "/typing") - (toJSON ([]::[Int])) - - GetPinnedMessages chan -> getWith req - (baseURL ++ "/channels/" ++ show chan ++ "/pins") - - AddPinnedMessage chan msg -> putWith req - (baseURL ++ "/channels/" ++ show chan ++ "/pins/" ++ show msg) - (toJSON ([]::[Int])) - - DeletePinnedMessage chan msg -> deleteWith req - (baseURL ++ "/channels/" ++ show chan ++ "/pins/" ++ show msg) - return (justRight . eitherDecode $ resp ^. responseBody - , justRight . eitherDecodeStrict $ resp ^. responseHeader "X-RateLimit-Remaining"::Int - , justRight . eitherDecodeStrict $ resp ^. responseHeader "X-RateLimit-Reset"::Int) - when (rlRem == 0) $ setRateLimit request rlNext - return resp - where - maybeEmbed :: Maybe Embed -> [(Text, Value)] - maybeEmbed = maybe [] $ \embed -> [("embed", toJSON embed)] + doFetch req = SyncFetched <$> go req + where + maybeEmbed :: Maybe Embed -> [(Text, Value)] + maybeEmbed = maybe [] $ \embed -> ["embed" .= embed] + url = baseUrl /: "channels" + go :: ChannelRequest a -> DiscordM a + go r@(GetChannel chan) = makeRequest r + $ Get (url // chan) mempty + go r@(ModifyChannel chan patch) = makeRequest r + $ Patch (url // chan) + (ReqBodyJson patch) mempty + go r@(DeleteChannel chan) = makeRequest r + $ Delete (url // chan) mempty + go r@(GetChannelMessages chan range) = makeRequest r + $ Get (url // chan /: "messages") (toQueryString range) + go r@(GetChannelMessage chan msg) = makeRequest r + $ Get (url // chan /: "messages" // msg) mempty + go r@(CreateMessage chan msg embed) = makeRequest r + $ Post (url // chan /: "messages") + (ReqBodyJson . object $ ["content" .= msg] <> maybeEmbed embed) + mempty + go r@(UploadFile chan fileName file) = do + body <- reqBodyMultipart [partFileRequestBody "file" fileName $ RequestBodyLBS file] + makeRequest r $ Post (url // chan /: "messages") + body mempty + go r@(EditMessage (Message msg chan _ _ _ _ _ _ _ _ _ _ _ _) new embed) = makeRequest r + $ Patch (url // chan /: "messages" // msg) + (ReqBodyJson . object $ ["content" .= new] <> maybeEmbed embed) + mempty + go r@(DeleteMessage (Message msg chan _ _ _ _ _ _ _ _ _ _ _ _)) = makeRequest r + $ Delete (url // chan /: "messages" // msg) mempty + go r@(BulkDeleteMessage chan msgs) = makeRequest r + $ Post (url // chan /: "messages" /: "bulk-delete") + (ReqBodyJson $ object ["messages" .= Prelude.map messageId msgs]) + mempty + go r@(EditChannelPermissions chan perm patch) = makeRequest r + $ Put (url // chan /: "permissions" // perm) + (ReqBodyJson patch) mempty + go r@(GetChannelInvites chan) = makeRequest r + $ Get (url // chan /: "invites") mempty + go r@(CreateChannelInvite chan patch) = makeRequest r + $ Post (url // chan /: "invites") + (ReqBodyJson patch) mempty + go r@(DeleteChannelPermission chan perm) = makeRequest r + $ Delete (url // chan /: "permissions" // perm) mempty + go r@(TriggerTypingIndicator chan) = makeRequest r + $ Post (url // chan /: "typing") + NoReqBody mempty + go r@(GetPinnedMessages chan) = makeRequest r + $ Get (url // chan /: "pins") mempty + go r@(AddPinnedMessage chan msg) = makeRequest r + $ Put (url // chan /: "pins" // msg) + NoReqBody mempty + go r@(DeletePinnedMessage chan msg) = makeRequest r + $ Delete (url // chan /: "pins" // msg) mempty diff --git a/src/Network/Discord/Rest/Guild.hs b/src/Network/Discord/Rest/Guild.hs index 7d6c35e..b810305 100644 --- a/src/Network/Discord/Rest/Guild.hs +++ b/src/Network/Discord/Rest/Guild.hs @@ -5,21 +5,16 @@ module Network.Discord.Rest.Guild ( GuildRequest(..) ) where - import Control.Monad (when) - - import Control.Concurrent.STM - import Control.Lens - import Control.Monad.Morph (lift) + import Data.Aeson import Data.Hashable - import Data.Text - import Data.Time.Clock.POSIX - import Network.Wreq - import qualified Control.Monad.State as ST (get, liftIO) - + import Data.Monoid (mempty) + import Data.Text as T + import Network.HTTP.Req ((=:)) + import Network.Discord.Rest.Prelude - import Network.Discord.Types as Dc - + import Network.Discord.Types + import Network.Discord.Rest.HTTP -- | Data constructor for Guild requests. See -- @@ -60,7 +55,7 @@ module Network.Discord.Rest.Guild -- | Returns a list of 'User' objects that are banned from this guild. Requires the -- 'BAN_MEMBERS' permission GetGuildBans :: Snowflake -> GuildRequest [User] - -- | Create a guild ban, and optionally delete previous messages sent by the banned + -- | Create a guild ban, and optionally Delete previous messages sent by the banned -- user. Requires the 'BAN_MEMBERS' permission. Fires a Guild Ban Add 'Event'. CreateGuildBan :: Snowflake -> Snowflake -> Integer -> GuildRequest () -- | Remove the ban for a user. Requires the 'BAN_MEMBERS' permissions. @@ -126,7 +121,7 @@ module Network.Discord.Rest.Guild hashWithSalt s (CreateGuildChannel g _) = hashWithSalt s ("guild_chan"::Text, g) hashWithSalt s (ModifyChanPosition g _) = hashWithSalt s ("guild_chan"::Text, g) hashWithSalt s (GetGuildMember g _) = hashWithSalt s ("guild_memb"::Text, g) - hashWithSalt s (ListGuildMembers g _) = hashWithSalt s ("guild_membs"::Text, g) + hashWithSalt s (ListGuildMembers g _) = hashWithSalt s ("guild_membs"::Text, g) hashWithSalt s (AddGuildMember g _ _) = hashWithSalt s ("guild_memb"::Text, g) hashWithSalt s (ModifyGuildMember g _ _) = hashWithSalt s ("guild_memb"::Text, g) hashWithSalt s (RemoveGuildMember g _) = hashWithSalt s ("guild_memb"::Text, g) @@ -154,142 +149,81 @@ module Network.Discord.Rest.Guild hashWithSalt s (GetGuildEmbed g) = hashWithSalt s ("guild_embed"::Text, g) hashWithSalt s (ModifyGuildEmbed g _) = hashWithSalt s ("guild_embed"::Text, g) - instance Eq (GuildRequest a) where - a == b = hash a == hash b - - instance RateLimit (GuildRequest a) where - getRateLimit req = do - DiscordState {getRateLimits=rl} <- ST.get - now <- ST.liftIO (fmap round getPOSIXTime :: IO Int) - ST.liftIO . atomically $ do - rateLimits <- readTVar rl - case lookup (hash req) rateLimits of - Nothing -> return Nothing - Just a - | a >= now -> return $ Just a - | otherwise -> modifyTVar' rl (Dc.delete $ hash req) >> return Nothing - - setRateLimit req reset = do - DiscordState {getRateLimits=rl} <- ST.get - ST.liftIO . atomically . modifyTVar rl $ Dc.insert (hash req) reset - + instance RateLimit (GuildRequest a) instance (FromJSON a) => DoFetch (GuildRequest a) where - doFetch req = do - waitRateLimit req - SyncFetched <$> fetch req - - fetch :: FromJSON a => GuildRequest a -> DiscordM a - fetch request = do - req <- baseRequest - (resp, rlRem, rlNext) <- lift $ do - resp <- case request of - GetGuild chan -> getWith req - (baseURL ++ "/guilds/" ++ show chan) - - ModifyGuild chan patch -> customPayloadMethodWith "PATCH" req - (baseURL ++ "/guilds/" ++ show chan) - (toJSON patch) - - DeleteGuild chan -> deleteWith req - (baseURL ++ "/guilds/" ++ show chan) - - GetGuildChannels chan -> getWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/channels") - - CreateGuildChannel chan patch -> postWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/channels") - (toJSON patch) - - ModifyChanPosition chan patch -> customPayloadMethodWith "PATCH" req - (baseURL ++ "/guilds/" ++ show chan ++ "/channels") - (toJSON patch) - - GetGuildMember chan user -> getWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/members/" ++ show user) - - ListGuildMembers chan range -> getWith req - (baseURL ++ "/guilds/" ++ show chan ++"/members?" ++ toQueryString range) - - AddGuildMember chan user patch -> customPayloadMethodWith "PUT" req - (baseURL ++ "/guilds/" ++ show chan ++ "/members/" ++ show user) - (toJSON patch) - - ModifyGuildMember chan user patch -> customPayloadMethodWith "PATCH" req - (baseURL ++ "/guilds/" ++ show chan ++ "/members/" ++ show user) - (toJSON patch) - - RemoveGuildMember chan user -> deleteWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/members/"++ show user) - - GetGuildBans chan -> getWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/bans") - - CreateGuildBan chan user msg -> customPayloadMethodWith "PUT" req - (baseURL ++ "/guilds/" ++ show chan ++ "/bans/" ++ show user) - ["delete-message-days" := msg] - - RemoveGuildBan chan user -> deleteWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/bans/" ++ show user) - - GetGuildRoles chan -> getWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/roles") - - CreateGuildRole chan -> postWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/roles") - (toJSON (""::Text)) - - ModifyGuildRolePositions chan pos -> postWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/roles") - (toJSON pos) - - ModifyGuildRole chan role patch -> postWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/roles/" ++ show role) - (toJSON patch) - - DeleteGuildRole chan role -> deleteWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/roles/" ++ show role) - - GetGuildPruneCount chan days -> getWith req - (baseURL++"/guilds/"++ show chan ++"/prune?days="++show days) - - BeginGuildPrune chan days -> postWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/prune?days=" ++ show days) - (toJSON (""::Text)) - - GetGuildVoiceRegions chan -> getWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/regions") - - GetGuildInvites chan -> getWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/invites") - - GetGuildIntegrations chan -> getWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/integrations") - - CreateGuildIntegration chan patch -> postWith req - (baseURL ++ "/guilds/"++ show chan ++ "/integrations") - (toJSON patch) - - ModifyGuildIntegration chan integ patch -> customPayloadMethodWith "PATCH" req - (baseURL ++ "/guilds/" ++ show chan ++ "/integrations/" ++ show integ) - (toJSON patch) - - DeleteGuildIntegration chan integ -> deleteWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/integrations/" ++ show integ) - - SyncGuildIntegration chan integ -> postWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/integrations/" ++ show integ) - (toJSON (""::Text)) - - GetGuildEmbed chan -> getWith req - (baseURL ++ "/guilds/" ++ show chan ++ "/embed") - - ModifyGuildEmbed chan embed -> customPayloadMethodWith "PATCH" req - (baseURL ++ "/guilds/" ++ show chan ++ "/embed") - (toJSON embed) - - return (justRight . eitherDecode $ resp ^. responseBody - , justRight . eitherDecodeStrict $ resp ^. responseHeader "X-RateLimit-Remaining"::Int - , justRight . eitherDecodeStrict $ resp ^. responseHeader "X-RateLimit-Reset"::Int) - when (rlRem == 0) $ setRateLimit request rlNext - return resp + doFetch req = SyncFetched <$> go req + where + url = baseUrl /: "guilds" + go :: GuildRequest a -> DiscordM a + go r@(GetGuild guild) = makeRequest r + $ Get (url // guild) mempty + go r@(ModifyGuild guild patch) = makeRequest r + $ Patch (url // guild) (ReqBodyJson patch) mempty + go r@(DeleteGuild guild) = makeRequest r + $ Delete (url // guild) mempty + go r@(GetGuildChannels guild) = makeRequest r + $ Get (url // guild /: "channels") mempty + go r@(CreateGuildChannel guild patch) = makeRequest r + $ Post (url // guild /: "channels") (ReqBodyJson patch) mempty + go r@(ModifyChanPosition guild patch) = makeRequest r + $ Post (url // guild /: "channels") (ReqBodyJson patch) mempty + go r@(GetGuildMember guild member) = makeRequest r + $ Get (url // guild /: "members" // member) mempty + go r@(ListGuildMembers guild range) = makeRequest r + $ Get (url // guild /: "members") (toQueryString range) + go r@(AddGuildMember guild user patch) = makeRequest r + $ Put (url // guild /: "members" // user) (ReqBodyJson patch) mempty + go r@(ModifyGuildMember guild member patch) = makeRequest r + $ Patch (url // guild /: "members" // member) (ReqBodyJson patch) mempty + go r@(RemoveGuildMember guild user) = makeRequest r + $ Delete (url // guild /: "members" // user) mempty + go r@(GetGuildBans guild) = makeRequest r + $ Get (url // guild /: "bans") mempty + go r@(CreateGuildBan guild user msgs) = makeRequest r + $ Put (url // guild /: "bans" // user) + (ReqBodyJson $ object [ "delete-message-days" .= msgs]) + mempty + go r@(RemoveGuildBan guild ban) = makeRequest r + $ Delete (url // guild /: "bans" // ban) mempty + go r@(GetGuildRoles guild) = makeRequest r + $ Get (url // guild /: "roles") mempty + go r@(CreateGuildRole guild) = makeRequest r + $ Post (url // guild /: "roles") + NoReqBody mempty + go r@(ModifyGuildRolePositions guild patch) = makeRequest r + $ Post (url // guild /: "roles") + (ReqBodyJson patch) mempty + go r@(ModifyGuildRole guild role patch) = makeRequest r + $ Post (url // guild /: "roles" // role) + (ReqBodyJson patch) mempty + go r@(DeleteGuildRole guild role) = makeRequest r + $ Delete (url // guild /: "roles" // role) mempty + go r@(GetGuildPruneCount guild days) = makeRequest r + $ Get (url // guild /: "prune") ("days" =: days) + go r@(BeginGuildPrune guild days) = makeRequest r + $ Post (url // guild /: "prune") + NoReqBody ("days" =: days) + go r@(GetGuildVoiceRegions guild) = makeRequest r + $ Get (url // guild /: "regions") mempty + go r@(GetGuildInvites guild) = makeRequest r + $ Get (url // guild /: "invites") mempty + go r@(GetGuildIntegrations guild) = makeRequest r + $ Get (url // guild /: "integrations") mempty + go r@(CreateGuildIntegration guild patch) = makeRequest r + $ Post (url // guild /: "integrations") + (ReqBodyJson patch) mempty + go r@(ModifyGuildIntegration guild integ patch) = makeRequest r + $ Patch (url // guild /: "integrations" // integ) + (ReqBodyJson patch) mempty + go r@(DeleteGuildIntegration guild integ) = makeRequest r + $ Delete (url // guild /: "integrations" // integ) + mempty + go r@(SyncGuildIntegration guild integ) = makeRequest r + $ Post (url // guild /: "integrations" // integ) + NoReqBody mempty + go r@(GetGuildEmbed guild) = makeRequest r + $ Get (url // guild /: "integrations") mempty + go r@(ModifyGuildEmbed guild patch) = makeRequest r + $ Patch (url // guild /: "embed") + (ReqBodyJson patch) mempty diff --git a/src/Network/Discord/Rest/HTTP.hs b/src/Network/Discord/Rest/HTTP.hs new file mode 100644 index 0000000..d9a6ea3 --- /dev/null +++ b/src/Network/Discord/Rest/HTTP.hs @@ -0,0 +1,76 @@ +{-# LANGUAGE GADTs, OverloadedStrings, InstanceSigs, TypeSynonymInstances #-} +{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses #-} +{-# LANGUAGE DataKinds, ScopedTypeVariables, Rank2Types #-} +-- | Provide HTTP primitives +module Network.Discord.Rest.HTTP + ( JsonRequest(..) + , R.ReqBodyJson(..) + , R.NoReqBody(..) + , baseUrl + , fetch + , makeRequest + , (//) + , (R./:) + ) where + + import Data.Semigroup ((<>)) + + import Control.Monad.State (get, when) + import Data.Aeson + import Data.ByteString.Char8 (pack, ByteString) + import Data.Maybe (fromMaybe) + import qualified Data.Text as T (pack) + import qualified Network.HTTP.Req as R + + import Data.Version (showVersion) + import Network.Discord.Rest.Prelude + import Network.Discord.Types (DiscordM, getClient, DiscordState(..), getAuth) + import Paths_discord_hs (version) + + -- | The base url (Req) for API requests + baseUrl :: R.Url 'R.Https + baseUrl = R.https "discordapp.com" R./: "api" R./: apiVersion + where apiVersion = "v6" + + -- | Construct base options with auth from Discord state + baseRequestOptions :: DiscordM Option + baseRequestOptions = do + DiscordState {getClient=client} <- get + return $ R.header "Authorization" (pack . show $ getAuth client) + <> R.header "User-Agent" (pack $ "DiscordBot (https://github.com/jano017/Discord.hs," + ++ showVersion version ++ ")") + infixl 5 // + (//) :: Show a => R.Url scheme -> a -> R.Url scheme + url // part = url R./: (T.pack $ show part) + + type Option = R.Option 'R.Https + + -- | Represtents a HTTP request made to an API that supplies a Json response + data JsonRequest r where + Delete :: FromJSON r => R.Url 'R.Https -> Option -> JsonRequest r + Get :: FromJSON r => R.Url 'R.Https -> Option -> JsonRequest r + Patch :: (FromJSON r, R.HttpBody a) => R.Url 'R.Https -> a -> Option -> JsonRequest r + Post :: (FromJSON r, R.HttpBody a) => R.Url 'R.Https -> a -> Option -> JsonRequest r + Put :: (FromJSON r, R.HttpBody a) => R.Url 'R.Https -> a -> Option -> JsonRequest r + + fetch :: FromJSON r => JsonRequest r -> DiscordM (R.JsonResponse r) + fetch (Delete url opts) = R.req R.DELETE url R.NoReqBody R.jsonResponse =<< (<> opts) <$> baseRequestOptions + fetch (Get url opts) = R.req R.GET url R.NoReqBody R.jsonResponse =<< (<> opts) <$> baseRequestOptions + fetch (Patch url body opts) = R.req R.PATCH url body R.jsonResponse =<< (<> opts) <$> baseRequestOptions + fetch (Post url body opts) = R.req R.POST url body R.jsonResponse =<< (<> opts) <$> baseRequestOptions + fetch (Put url body opts) = R.req R.PUT url body R.jsonResponse =<< (<> opts) <$> baseRequestOptions + + makeRequest :: (RateLimit a, FromJSON r) => a -> JsonRequest r -> DiscordM r + makeRequest req action = do + waitRateLimit req + resp <- fetch action + when (parseHeader resp "X-RateLimit-Remaining" 1 < 1) $ + setRateLimit req $ parseHeader resp "X-RateLimit-Reset" 0 + return $ R.responseBody resp + where + parseHeader :: R.HttpResponse resp => resp -> ByteString -> Int -> Int + parseHeader resp header def = fromMaybe def $ decodeStrict =<< R.responseHeader resp header + + -- | Base implementation of DoFetch, allows arbitrary HTTP requests to be performed + instance (FromJSON r) => DoFetch (JsonRequest r) where + doFetch req = SyncFetched . R.responseBody <$> fetch req diff --git a/src/Network/Discord/Rest/Prelude.hs b/src/Network/Discord/Rest/Prelude.hs index e9fc85d..a005619 100644 --- a/src/Network/Discord/Rest/Prelude.hs +++ b/src/Network/Discord/Rest/Prelude.hs @@ -1,52 +1,47 @@ {-# LANGUAGE ExistentialQuantification, MultiParamTypeClasses #-} -{-# LANGUAGE OverloadedStrings, FlexibleInstances #-} +{-# LANGUAGE OverloadedStrings, DataKinds #-} -- | Utility and base types and functions for the Discord Rest API module Network.Discord.Rest.Prelude where import Control.Concurrent (threadDelay) - import Data.Version (showVersion) - - import Control.Lens + + import Control.Concurrent.STM import Data.Aeson - import Data.ByteString.Char8 (pack) import Data.Default import Data.Hashable + import Data.Monoid ((<>)) import Data.Time.Clock.POSIX - import Network.Wreq + import Network.HTTP.Req (Option, Scheme(..), (=:)) import System.Log.Logger import qualified Control.Monad.State as St import Network.Discord.Types - import Paths_discord_hs (version) - - -- | Read function specialized for Integers - readInteger :: String -> Integer - readInteger = read -- | The base url for API requests baseURL :: String baseURL = "https://discordapp.com/api/v6" - -- | Construct base request with auth from Discord state - baseRequest :: DiscordM Options - baseRequest = do - DiscordState {getClient=client} <- St.get - return $ defaults - & header "Authorization" .~ [pack . show $ getAuth client] - & header "User-Agent" .~ - [pack $ "DiscordBot (https://github.com/jano017/Discord.hs," - ++ showVersion version - ++ ")"] - & header "Content-Type" .~ ["application/json"] - -- | Class for rate-limitable actions - class RateLimit a where + class Hashable a => RateLimit a where -- | Return seconds to expiration if we're waiting -- for a rate limit to reset getRateLimit :: a -> DiscordM (Maybe Int) + getRateLimit req = do + DiscordState {getRateLimits=rl} <- St.get + now <- St.liftIO (fmap round getPOSIXTime :: IO Int) + St.liftIO . atomically $ do + rateLimits <- readTVar rl + case lookup (hash req) rateLimits of + Nothing -> return Nothing + Just a + | a >= now -> return $ Just a + | otherwise -> modifyTVar' rl (delete $ hash req) >> return Nothing -- | Set seconds to the next rate limit reset when -- we hit a rate limit setRateLimit :: a -> Int -> DiscordM () + setRateLimit req reset = do + DiscordState {getRateLimits=rl} <- St.get + St.liftIO . atomically . modifyTVar rl $ insert (hash req) reset -- | If we hit a rate limit, wait for it to reset waitRateLimit :: a -> DiscordM () waitRateLimit endpoint = do @@ -87,6 +82,8 @@ module Network.Discord.Rest.Prelude where def = Range 0 18446744073709551615 100 -- | Convert a Range to a query string - toQueryString :: Range -> String - toQueryString (Range a b l) = - "after=" ++ show a ++ "&before=" ++ show b ++ "&limit=" ++ show l + toQueryString :: Range -> Option 'Https + toQueryString (Range a b l) + = "after" =: show a + <> "before" =: show b + <> "limit" =: show l diff --git a/src/Network/Discord/Rest/User.hs b/src/Network/Discord/Rest/User.hs index af269c0..c93c6f4 100644 --- a/src/Network/Discord/Rest/User.hs +++ b/src/Network/Discord/Rest/User.hs @@ -5,20 +5,15 @@ module Network.Discord.Rest.User ( UserRequest(..) ) where - import Control.Monad (when) - - import Control.Concurrent.STM - import Control.Monad.Morph (lift) - import Control.Lens + import Data.Aeson import Data.Hashable - import Data.Text - import Data.Time.Clock.POSIX - import Network.Wreq - import qualified Control.Monad.State as ST (get, liftIO) - + import Data.Monoid (mempty) + import Data.Text as T + import Network.Discord.Rest.Prelude - import Network.Discord.Types as Dc + import Network.Discord.Types + import Network.Discord.Rest.HTTP -- | Data constructor for User requests. See -- @@ -46,57 +41,36 @@ module Network.Discord.Rest.User hashWithSalt s (GetUser _) = hashWithSalt s ("user"::Text) hashWithSalt s (ModifyCurrentUser _) = hashWithSalt s ("modify_user"::Text) hashWithSalt s (GetCurrentUserGuilds _) = hashWithSalt s ("get_user_guilds"::Text) - hashWithSalt s (LeaveGuild g) = hashWithSalt s ("leaveGuild"::Text, g) + hashWithSalt s (LeaveGuild g) = hashWithSalt s ("leave_guild"::Text, g) hashWithSalt s (GetUserDMs) = hashWithSalt s ("get_dms"::Text) hashWithSalt s (CreateDM _) = hashWithSalt s ("make_dm"::Text) - instance Eq (UserRequest a) where - a == b = hash a == hash b + instance RateLimit (UserRequest a) - instance RateLimit (UserRequest a) where - getRateLimit req = do - DiscordState {getRateLimits=rl} <- ST.get - now <- ST.liftIO (fmap round getPOSIXTime :: IO Int) - ST.liftIO . atomically $ do - rateLimits <- readTVar rl - case lookup (hash req) rateLimits of - Nothing -> return Nothing - Just a - | a >= now -> return $ Just a - | otherwise -> modifyTVar' rl (Dc.delete $ hash req) >> return Nothing + instance (FromJSON a) => DoFetch (UserRequest a) where + doFetch req = SyncFetched <$> go req + where + url = baseUrl /: "users" + go :: UserRequest a -> DiscordM a + go r@(GetCurrentUser) = makeRequest r + $ Get (url /: "@me") mempty - setRateLimit req reset = do - DiscordState {getRateLimits=rl} <- ST.get - ST.liftIO . atomically . modifyTVar rl $ Dc.insert (hash req) reset + go r@(GetUser user) = makeRequest r + $ Get (url // user ) mempty - instance (FromJSON a) => DoFetch (UserRequest a) where - doFetch req = do - waitRateLimit req - SyncFetched <$> fetch req + go r@(ModifyCurrentUser patch) = makeRequest r + $ Patch (url /: "@me") (ReqBodyJson patch) mempty + + go r@(GetCurrentUserGuilds range) = makeRequest r + $ Get url $ toQueryString range + + go r@(LeaveGuild guild) = makeRequest r + $ Delete (url /: "@me" /: "guilds" // guild) mempty + + go r@(GetUserDMs) = makeRequest r + $ Get (url /: "@me" /: "channels") mempty - fetch :: FromJSON a => UserRequest a -> DiscordM a - fetch request = do - req <- baseRequest - (resp, rlRem, rlNext) <- lift $ do - resp <- case request of - GetCurrentUser -> getWith req - "/users/@me" - GetUser user -> getWith req - ("/users/" ++ show user) - ModifyCurrentUser patch -> customPayloadMethodWith "PATCH" req - "/users/@me" - (toJSON patch) - GetCurrentUserGuilds range -> getWith req - ("/users/@me/guilds?" ++ toQueryString range) - LeaveGuild guild -> deleteWith req - ("/users/@me/guilds/" ++ show guild) - GetUserDMs -> getWith req - "/users/@me/channels" - CreateDM (Snowflake user) -> postWith req - "/users/@me/channels" - ["recipient_id" := user] - return (justRight . eitherDecode $ resp ^. responseBody - , justRight . eitherDecodeStrict $ resp ^. responseHeader "X-RateLimit-Remaining"::Int - , justRight . eitherDecodeStrict $ resp ^. responseHeader "X-RateLimit-Reset"::Int) - when (rlRem == 0) $ setRateLimit request rlNext - return resp + go r@(CreateDM user) = makeRequest r + $ Post (url /: "@me" /: "channels") + (ReqBodyJson $ object ["recipient_id" .= user]) + mempty diff --git a/src/Network/Discord/Types.hs b/src/Network/Discord/Types.hs index 268a6d8..8fc4d38 100755 --- a/src/Network/Discord/Types.hs +++ b/src/Network/Discord/Types.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE RankNTypes, ExistentialQuantification #-} +{-# LANGUAGE RankNTypes, ExistentialQuantification, GeneralizedNewtypeDeriving #-} {-# OPTIONS_HADDOCK prune, not-home #-} -- | Provides types and encoding/decoding code. Types should be identical to those provided -- in the Discord API documentation. @@ -11,12 +11,14 @@ module Network.Discord.Types , module Network.Discord.Types.Guild ) where - import Data.Proxy - import Control.Monad.State (StateT) + import Data.Proxy (Proxy) + import Control.Monad.State (StateT, MonadState, evalStateT, execStateT) import Control.Concurrent.STM + import Control.Monad.IO.Class (MonadIO) import Network.WebSockets (Connection) import System.IO.Unsafe (unsafePerformIO) + import qualified Network.HTTP.Req as R (MonadHttp(..)) import Network.Discord.Types.Channel import Network.Discord.Types.Events @@ -28,8 +30,8 @@ module Network.Discord.Types data StateEnum = Create | Start | Running | InvalidReconnect | InvalidDead -- | Stores details needed to manage the gateway and bot - data DiscordState = forall a . (Client a) => DiscordState { - getState :: StateEnum -- ^ Current state of the gateway + data DiscordState = forall a . (Client a) => DiscordState + { getState :: StateEnum -- ^ Current state of the gateway , getClient :: a -- ^ Currently running bot client , getWebSocket :: Connection -- ^ Stored WebSocket gateway , getSequenceNum :: TMVar Integer -- ^ Heartbeat sequence number @@ -37,7 +39,20 @@ module Network.Discord.Types } -- | Convenience type alias for the monad most used throughout most Discord.hs operations - type DiscordM = StateT DiscordState IO + newtype DiscordM a = DiscordM (StateT DiscordState IO a) + deriving (MonadIO, MonadState DiscordState, Monad, Applicative, Functor) + + -- | Allow HTTP requests to be made from the DiscordM monad + instance R.MonadHttp DiscordM where + handleHttpException e = error $ show e + + -- | Unwrap and eval a 'DiscordM' + evalDiscordM :: DiscordM a -> DiscordState -> IO a + evalDiscordM (DiscordM inner) = evalStateT inner + + -- | Unwrap and exec a 'DiscordM' + execDiscordM :: DiscordM a -> DiscordState -> IO DiscordState + execDiscordM (DiscordM inner) = execStateT inner -- | The Client typeclass holds the majority of the user-customizable state, -- including merging states resulting from async operations. diff --git a/src/Network/Discord/Types/Events.hs b/src/Network/Discord/Types/Events.hs index bc1e5b9..50acf5a 100644 --- a/src/Network/Discord/Types/Events.hs +++ b/src/Network/Discord/Types/Events.hs @@ -90,3 +90,4 @@ module Network.Discord.Types.Events where _ -> UnknownEvent ev <$> reparse o where o = Object ob parseDispatch _ = error "Tried to parse non-Dispatch payload" + diff --git a/stack.yaml b/stack.yaml index cc991c0..09ffbf2 100644 --- a/stack.yaml +++ b/stack.yaml @@ -1,9 +1,14 @@ -flags: {} +resolver: lts-8.0 + +flags: + time-locale-compat: + old-locale: false extra-package-dbs: [] packages: - '.' extra-deps: -- wreq-0.5.0.0 + - time-locale-compat-0.1.1.3 resolver: lts-8.0 nix: shell-file: shell.nix +