Skip to content

Commit

Permalink
Merge pull request #727 from input-output-hk/plt-3513-improve-failures
Browse files Browse the repository at this point in the history
PLT-3515 Close communication gracefully
  • Loading branch information
jhbertra authored Oct 12, 2023
2 parents 6e165db + 604e0db commit 41af813
Show file tree
Hide file tree
Showing 23 changed files with 393 additions and 211 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Added

- BREAKING added framing to transport layer allowing for errors to be
propagated when a peer crashes.
90 changes: 60 additions & 30 deletions marlowe-protocols/src/Network/Channel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@
module Network.Channel where

import Control.Concurrent.STM (STM, newTChan, readTChan, writeTChan)
import Control.Monad (mfilter, (>=>))
import qualified Data.ByteString as BS
import Control.Exception (Exception (..))
import Control.Monad ((>=>))
import Control.Monad.Trans.Maybe (MaybeT (..))
import Data.Binary.Get (getInt64be, runGet)
import Data.Binary.Put (putInt64be, runPut)
import qualified Data.ByteString.Lazy as LBS
import Data.Functor (($>))
import Data.Text.Internal.Lazy (smallChunkSize)
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Encoding as TLE
import GHC.Generics (Generic)
import GHC.IO (mkUserError)
import Network.Socket (Socket)
import qualified Network.Socket.ByteString.Lazy as Socket
import qualified System.IO as IO
import UnliftIO (MonadIO, liftIO)
import UnliftIO (MonadIO, MonadUnliftIO, SomeException (..), catch, liftIO, mask, throwIO, try)

data Channel m a = Channel
{ send :: a -> m ()
Expand Down Expand Up @@ -42,36 +47,61 @@ hoistChannel nat Channel{..} =
, recv = nat recv
}

handlesAsChannel
:: forall m
. (MonadIO m)
=> IO.Handle
-- ^ Read handle
-> IO.Handle
-- ^ Write handle
-> Channel m LBS.ByteString
handlesAsChannel hread hwrite = Channel{..}
where
send :: LBS.ByteString -> m ()
send chunk = liftIO do
LBS.hPut hwrite chunk
IO.hFlush hwrite
data FrameStatus
= OkStatus
| ErrorStatus
deriving stock (Show, Read, Eq, Ord, Bounded, Enum, Generic)

recv :: m (Maybe LBS.ByteString)
recv = liftIO do
eof <- IO.hIsEOF hread
if eof
then pure Nothing
else Just . LBS.fromStrict <$> BS.hGetSome hread smallChunkSize
data Frame = Frame
{ frameStatus :: FrameStatus
, frameContents :: LBS.ByteString
}

socketAsChannel :: forall m. (MonadIO m) => Socket -> Channel m LBS.ByteString
socketAsChannel :: forall m. (MonadIO m) => Socket -> Channel m Frame
socketAsChannel sock = Channel{..}
where
send :: LBS.ByteString -> m ()
send = liftIO . Socket.sendAll sock
send :: Frame -> m ()
send Frame{..} = liftIO do
let headerBytes =
LBS.cons
( case frameStatus of
OkStatus -> 0
ErrorStatus -> 1
)
(runPut $ putInt64be $ LBS.length frameContents)
Socket.sendAll sock $ headerBytes <> frameContents

recv :: m (Maybe Frame)
recv = runMaybeT do
headerBytes <- liftIO $ Socket.recv sock 9
(statusByte, sizeBytes) <- MaybeT $ pure $ LBS.uncons headerBytes
frameStatus <- case statusByte of
0 -> pure OkStatus
1 -> pure ErrorStatus
_ -> throwIO $ mkUserError $ "Invalid status byte: " <> show statusByte
let contentLength = runGet getInt64be sizeBytes
frameContents <- liftIO $ Socket.recv sock contentLength
pure Frame{..}

recv :: m (Maybe LBS.ByteString)
recv = liftIO $ mfilter (not . LBS.null) . pure <$> Socket.recv sock (fromIntegral smallChunkSize)
withSocketChannel
:: (MonadUnliftIO m, MonadFail m)
=> Socket
-> (Channel m Frame -> m a)
-> m a
withSocketChannel socket f = do
let channel = socketAsChannel socket
mask \restore -> do
result <- try $ restore $ f channel
case result of
Left (SomeException ex) -> do
let errorFrame = Frame ErrorStatus $ TLE.encodeUtf8 $ TL.pack $ displayException ex
-- Ignore any errors sending the error to the peer. For example, if we're throwing
-- an exception because the peer failed, we would end up writing to a broken pipe.
-- Additionally, we don't really care if this succeeds or not from the point of view
-- of this peer's next steps - which should always be to re-throw the exception.
send channel errorFrame `catch` \SomeException{} -> pure ()
throwIO ex
Right a -> pure a

effectChannel
:: (Monad m)
Expand Down
5 changes: 3 additions & 2 deletions marlowe-protocols/src/Network/Channel/Typed.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ import Control.Monad.Trans.Resource (ResourceT)
import Data.Binary (put)
import Data.Binary.Put (runPut)
import qualified Data.ByteString.Lazy as LBS
import Data.Data (Typeable)
import Network.Channel (socketAsChannel)
import Network.Protocol.Codec (BinaryMessage)
import Network.Protocol.Codec (BinaryMessage, ShowProtocol)
import Network.Protocol.Driver.Trace
import Network.Protocol.Peer.Trace
import Network.Socket (
Expand Down Expand Up @@ -304,7 +305,7 @@ driverToChannel inj driver = go $ startDStateTraced driver
}

tcpClientChannel
:: (MonadUnliftIO m, MonadEvent r s m, HasSpanContext r, BinaryMessage ps)
:: (MonadUnliftIO m, MonadEvent r s m, HasSpanContext r, BinaryMessage ps, ShowProtocol ps, Typeable ps)
=> InjectSelector (TcpClientSelector ps) s
-> HostName
-> PortNumber
Expand Down
3 changes: 1 addition & 2 deletions marlowe-protocols/src/Network/Protocol/ChainSeek/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ import Data.Proxy (Proxy (..))
import Data.Text (Text)
import qualified Data.Text as T
import GHC.Show (showList__, showSpace)
import Network.Protocol.Codec (BinaryMessage (..))
import Network.Protocol.Codec (BinaryMessage (..), ShowProtocol (..))
import Network.Protocol.Codec.Spec (
ArbitraryMessage (..),
MessageEq (..),
MessageVariations (..),
ShowProtocol (..),
SomePeerHasAgency (..),
Variations (..),
)
Expand Down
59 changes: 50 additions & 9 deletions marlowe-protocols/src/Network/Protocol/Codec.hs
Original file line number Diff line number Diff line change
@@ -1,40 +1,81 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ViewPatterns #-}

module Network.Protocol.Codec where

import Control.Exception (Exception)
import Control.Exception (Exception (..))
import Control.Monad (mfilter)
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put (runPut)
import qualified Data.ByteString as BS
import Data.ByteString.Base16 (encodeBase16)
import qualified Data.ByteString.Lazy as LBS
import Network.TypedProtocol (Message, Protocol)
import Data.Data (Typeable)
import qualified Data.Text as T
import Network.TypedProtocol (Message, Protocol (..))
import Network.TypedProtocol.Codec

class ShowProtocol ps where
showsPrecMessage :: Int -> PeerHasAgency pr st -> Message ps st st' -> ShowS
default showsPrecMessage :: (Show (Message ps st st')) => Int -> PeerHasAgency pr st -> Message ps st st' -> ShowS
showsPrecMessage p _ = showsPrec p

showsPrecServerHasAgency :: forall (st :: ps). Int -> ServerHasAgency st -> ShowS
default showsPrecServerHasAgency :: forall (st :: ps). (Show (ServerHasAgency st)) => Int -> ServerHasAgency st -> ShowS
showsPrecServerHasAgency = showsPrec

showsPrecClientHasAgency :: forall (st :: ps). Int -> ClientHasAgency st -> ShowS
default showsPrecClientHasAgency :: forall (st :: ps). (Show (ClientHasAgency st)) => Int -> ClientHasAgency st -> ShowS
showsPrecClientHasAgency = showsPrec

class (Protocol ps) => BinaryMessage ps where
putMessage :: PeerHasAgency pr (st :: ps) -> Message ps st st' -> Put
getMessage :: PeerHasAgency pr (st :: ps) -> Get (SomeMessage st)

data DeserializeError = DeserializeError
newtype ShowPeerHasAgencyViaShowProtocol pr st = ShowPeerHasAgencyViaShowProtocol (PeerHasAgency pr st)

instance (ShowProtocol ps) => Show (ShowPeerHasAgencyViaShowProtocol pr (st :: ps)) where
showsPrec p (ShowPeerHasAgencyViaShowProtocol pa) = showParen (p > 10) case pa of
ClientAgency tok -> showString "ClientAgency" . showsPrecClientHasAgency p tok
ServerAgency tok -> showString "ServerAgency" . showsPrecServerHasAgency p tok

data DeserializeError ps = forall pr (st :: ps).
DeserializeError
{ message :: !String
, offset :: !ByteOffset
, unconsumedInput :: !BS.ByteString
, state :: ShowPeerHasAgencyViaShowProtocol pr st
}
deriving (Show)

instance Exception DeserializeError
deriving instance (ShowProtocol ps) => Show (DeserializeError ps)

instance (Typeable ps, ShowProtocol ps) => Exception (DeserializeError ps) where
displayException (DeserializeError{..}) =
unlines
[ "Offset: " <> show offset
, "Protocol State: " <> show state
, "Message: " <> message
, "Unconsumed Input: " <> T.unpack (encodeBase16 unconsumedInput)
]

binaryCodec :: (Applicative m, BinaryMessage ps) => Codec ps DeserializeError m LBS.ByteString
binaryCodec = Codec (encodePut . putMessage) (decodeGet . getMessage)
binaryCodec :: (Applicative m, BinaryMessage ps) => Codec ps (DeserializeError ps) m LBS.ByteString
binaryCodec = Codec (encodePut . putMessage) (decodeGet <*> getMessage)

encodePut :: (a -> Put) -> a -> LBS.ByteString
encodePut = fmap runPut

decodeGet :: (Applicative m) => Get a -> m (DecodeStep LBS.ByteString DeserializeError m a)
decodeGet = go . runGetIncremental
decodeGet
:: (Applicative m)
=> PeerHasAgency pr (st :: ps)
-> Get a
-> m (DecodeStep LBS.ByteString (DeserializeError ps) m a)
decodeGet (ShowPeerHasAgencyViaShowProtocol -> state) = go . runGetIncremental
where
go =
pure . \case
Expand Down
15 changes: 1 addition & 14 deletions marlowe-protocols/src/Network/Protocol/Codec/Spec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import Data.Word
import GHC.Generics
import GHC.Real (Ratio ((:%)))
import GHC.Show (showSpace)
import Network.Protocol.Codec (BinaryMessage, binaryCodec)
import Network.Protocol.Codec (BinaryMessage, ShowProtocol (..), binaryCodec)
import Network.TypedProtocol (PeerHasAgency, Protocol (..), SomeMessage (..))
import Network.TypedProtocol.Codec (AnyMessageAndAgency (..), Codec (..), PeerHasAgency (..), runDecoder)
import Numeric.Natural (Natural)
Expand All @@ -42,19 +42,6 @@ import Test.QuickCheck (Property, Testable (property), counterexample, forAllShr
import Test.QuickCheck.Gen (Gen)
import Test.QuickCheck.Property (failed, succeeded)

class ShowProtocol ps where
showsPrecMessage :: Int -> PeerHasAgency pr st -> Message ps st st' -> ShowS
default showsPrecMessage :: (Show (Message ps st st')) => Int -> PeerHasAgency pr st -> Message ps st st' -> ShowS
showsPrecMessage p _ = showsPrec p

showsPrecServerHasAgency :: forall (st :: ps). Int -> ServerHasAgency st -> ShowS
default showsPrecServerHasAgency :: forall (st :: ps). (Show (ServerHasAgency st)) => Int -> ServerHasAgency st -> ShowS
showsPrecServerHasAgency = showsPrec

showsPrecClientHasAgency :: forall (st :: ps). Int -> ClientHasAgency st -> ShowS
default showsPrecClientHasAgency :: forall (st :: ps). (Show (ClientHasAgency st)) => Int -> ClientHasAgency st -> ShowS
showsPrecClientHasAgency = showsPrec

class MessageEq ps where
messageEq :: AnyMessageAndAgency ps -> AnyMessageAndAgency ps -> Bool

Expand Down
Loading

0 comments on commit 41af813

Please sign in to comment.