diff --git a/be1-go/channel/authentication/authentication_test.go b/be1-go/channel/authentication/authentication_test.go index 3a1cf90ebf..7f8f711996 100644 --- a/be1-go/channel/authentication/authentication_test.go +++ b/be1-go/channel/authentication/authentication_test.go @@ -213,7 +213,7 @@ type fakeHub struct { // newFakeHub returns a fake Hub. func newFakeHub(publicOrg kyber.Point, log zerolog.Logger, laoFac channel.LaoFactory) (*fakeHub, error) { - schemaValidator, err := validation.NewSchemaValidator(log) + schemaValidator, err := validation.NewSchemaValidator() if err != nil { return nil, xerrors.Errorf("failed to create the schema validator: %v", err) } @@ -285,7 +285,7 @@ func (h *fakeHub) Sign(data []byte) ([]byte, error) { func (h *fakeHub) NotifyWitnessMessage(_ string, _ string, _ string) {} // GetPeersInfo implements channel.HubFunctionalities -func (h *fakeHub) GetPeersInfo() []method.ServerInfo { +func (h *fakeHub) GetPeersInfo() []method.GreetServerParams { return nil } diff --git a/be1-go/channel/channel.go b/be1-go/channel/channel.go index 6e5e479201..73ca9423a8 100644 --- a/be1-go/channel/channel.go +++ b/be1-go/channel/channel.go @@ -101,7 +101,7 @@ type HubFunctionalities interface { SendAndHandleMessage(method.Broadcast) error NotifyWitnessMessage(messageId string, publicKey string, signature string) GetClientServerAddress() string - GetPeersInfo() []method.ServerInfo + GetPeersInfo() []method.GreetServerParams } // Broadcastable defines a channel that can broadcast diff --git a/be1-go/channel/chirp/chirp_test.go b/be1-go/channel/chirp/chirp_test.go index 830d7193d9..667fc17c7a 100644 --- a/be1-go/channel/chirp/chirp_test.go +++ b/be1-go/channel/chirp/chirp_test.go @@ -559,7 +559,7 @@ type fakeHub struct { // NewFakeHub returns a fake Hub. func NewFakeHub(publicOrg kyber.Point, log zerolog.Logger, laoFac channel.LaoFactory) (*fakeHub, error) { - schemaValidator, err := validation.NewSchemaValidator(log) + schemaValidator, err := validation.NewSchemaValidator() if err != nil { return nil, xerrors.Errorf("failed to create the schema validator: %v", err) } @@ -626,7 +626,7 @@ func (h *fakeHub) Sign(data []byte) ([]byte, error) { func (h *fakeHub) NotifyWitnessMessage(messageId string, publicKey string, signature string) {} // GetPeersInfo implements channel.HubFunctionalities -func (h *fakeHub) GetPeersInfo() []method.ServerInfo { +func (h *fakeHub) GetPeersInfo() []method.GreetServerParams { return nil } diff --git a/be1-go/channel/coin/coin_test.go b/be1-go/channel/coin/coin_test.go index cd8f5e4523..8f7ff8d5ae 100644 --- a/be1-go/channel/coin/coin_test.go +++ b/be1-go/channel/coin/coin_test.go @@ -744,7 +744,7 @@ type fakeHub struct { // NewFakeHub returns a fake Hub. func NewFakeHub(publicOrg kyber.Point, log zerolog.Logger, laoFac channel.LaoFactory) (*fakeHub, error) { - schemaValidator, err := validation.NewSchemaValidator(log) + schemaValidator, err := validation.NewSchemaValidator() if err != nil { return nil, xerrors.Errorf("failed to create the schema validator: %v", err) } @@ -811,7 +811,7 @@ func (h *fakeHub) Sign(data []byte) ([]byte, error) { func (h *fakeHub) NotifyWitnessMessage(messageId string, publicKey string, signature string) {} // GetPeersInfo implements channel.HubFunctionalities -func (h *fakeHub) GetPeersInfo() []method.ServerInfo { +func (h *fakeHub) GetPeersInfo() []method.GreetServerParams { return nil } diff --git a/be1-go/channel/consensus/consensus_test.go b/be1-go/channel/consensus/consensus_test.go index 5f9ebd8719..75cc1e8a43 100644 --- a/be1-go/channel/consensus/consensus_test.go +++ b/be1-go/channel/consensus/consensus_test.go @@ -1949,7 +1949,7 @@ type fakeHub struct { // NewFakeHub returns a fake Hub. func NewFakeHub(pubKeyOwner kyber.Point, log zerolog.Logger, laoFac channel.LaoFactory) (*fakeHub, error) { - schemaValidator, err := validation.NewSchemaValidator(log) + schemaValidator, err := validation.NewSchemaValidator() if err != nil { return nil, xerrors.Errorf("failed to create the schema validator: %v", err) } @@ -2017,7 +2017,7 @@ func (h *fakeHub) Sign(data []byte) ([]byte, error) { func (h *fakeHub) NotifyWitnessMessage(messageId string, publicKey string, signature string) {} // GetPeersInfo implements channel.HubFunctionalities -func (h *fakeHub) GetPeersInfo() []method.ServerInfo { +func (h *fakeHub) GetPeersInfo() []method.GreetServerParams { return nil } diff --git a/be1-go/channel/election/election_test.go b/be1-go/channel/election/election_test.go index 8ede712942..32263115b3 100644 --- a/be1-go/channel/election/election_test.go +++ b/be1-go/channel/election/election_test.go @@ -693,7 +693,7 @@ type fakeHub struct { // NewFakeHub returns a fake Hub. func NewFakeHub(publicOrg kyber.Point, log zerolog.Logger, laoFac channel.LaoFactory) (*fakeHub, error) { - schemaValidator, err := validation.NewSchemaValidator(log) + schemaValidator, err := validation.NewSchemaValidator() if err != nil { return nil, xerrors.Errorf("failed to create the schema validator: %v", err) } @@ -754,7 +754,7 @@ func (h *fakeHub) Sign(data []byte) ([]byte, error) { func (h *fakeHub) NotifyWitnessMessage(messageId string, publicKey string, signature string) {} // GetPeersInfo implements channel.HubFunctionalities -func (h *fakeHub) GetPeersInfo() []method.ServerInfo { +func (h *fakeHub) GetPeersInfo() []method.GreetServerParams { return nil } diff --git a/be1-go/channel/generalChirping/generalChirping_test.go b/be1-go/channel/generalChirping/generalChirping_test.go index e70ba74ac9..dfaacd7daf 100644 --- a/be1-go/channel/generalChirping/generalChirping_test.go +++ b/be1-go/channel/generalChirping/generalChirping_test.go @@ -211,7 +211,7 @@ type fakeHub struct { // NewFakeHub returns a fake Hub. func NewFakeHub(publicOrg kyber.Point, log zerolog.Logger, laoFac channel.LaoFactory) (*fakeHub, error) { - schemaValidator, err := validation.NewSchemaValidator(log) + schemaValidator, err := validation.NewSchemaValidator() if err != nil { return nil, xerrors.Errorf("failed to create the schema validator: %v", err) } @@ -278,7 +278,7 @@ func (h *fakeHub) Sign(data []byte) ([]byte, error) { func (h *fakeHub) NotifyWitnessMessage(messageId string, publicKey string, signature string) {} // GetPeersInfo implements channel.HubFunctionalities -func (h *fakeHub) GetPeersInfo() []method.ServerInfo { +func (h *fakeHub) GetPeersInfo() []method.GreetServerParams { return nil } diff --git a/be1-go/channel/lao/lao_test.go b/be1-go/channel/lao/lao_test.go index 2d7b76ff3c..6d5d37efff 100644 --- a/be1-go/channel/lao/lao_test.go +++ b/be1-go/channel/lao/lao_test.go @@ -745,7 +745,7 @@ type fakeHub struct { // NewFakeHub returns a fake Hub. func NewFakeHub(clientAddress string, publicOrg kyber.Point, log zerolog.Logger, laoFac channel.LaoFactory) (*fakeHub, error) { - schemaValidator, err := validation.NewSchemaValidator(log) + schemaValidator, err := validation.NewSchemaValidator() if err != nil { return nil, xerrors.Errorf("failed to create the schema validator: %v", err) } @@ -809,19 +809,19 @@ func (h *fakeHub) Sign(data []byte) ([]byte, error) { func (h *fakeHub) NotifyWitnessMessage(messageId string, publicKey string, signature string) {} // GetPeersInfo implements channel.HubFunctionalities -func (h *fakeHub) GetPeersInfo() []method.ServerInfo { - peer1 := method.ServerInfo{ +func (h *fakeHub) GetPeersInfo() []method.GreetServerParams { + peer1 := method.GreetServerParams{ PublicKey: "", ClientAddress: "wss://localhost:9002/client", ServerAddress: "", } - peer2 := method.ServerInfo{ + peer2 := method.GreetServerParams{ PublicKey: "", ClientAddress: "wss://localhost:9004/client", ServerAddress: "", } - return []method.ServerInfo{peer1, peer2} + return []method.GreetServerParams{peer1, peer2} } func (h *fakeHub) GetSchemaValidator() validation.SchemaValidator { diff --git a/be1-go/channel/reaction/reaction_test.go b/be1-go/channel/reaction/reaction_test.go index f895673b6b..ba953c0e8f 100644 --- a/be1-go/channel/reaction/reaction_test.go +++ b/be1-go/channel/reaction/reaction_test.go @@ -603,7 +603,7 @@ type fakeHub struct { // NewFakeHub returns a fake Hub. func NewFakeHub(publicOrg kyber.Point, log zerolog.Logger, laoFac channel.LaoFactory) (*fakeHub, error) { - schemaValidator, err := validation.NewSchemaValidator(log) + schemaValidator, err := validation.NewSchemaValidator() if err != nil { return nil, xerrors.Errorf("failed to create the schema validator: %v", err) } @@ -670,7 +670,7 @@ func (h *fakeHub) Sign(data []byte) ([]byte, error) { func (h *fakeHub) NotifyWitnessMessage(messageId string, publicKey string, signature string) {} // GetPeersInfo implements channel.HubFunctionalities -func (h *fakeHub) GetPeersInfo() []method.ServerInfo { +func (h *fakeHub) GetPeersInfo() []method.GreetServerParams { return nil } diff --git a/be1-go/cli/cli.go b/be1-go/cli/cli.go index c906b010e5..34b173615d 100644 --- a/be1-go/cli/cli.go +++ b/be1-go/cli/cli.go @@ -5,17 +5,22 @@ import ( "encoding/base64" "encoding/json" "fmt" + "github.com/rs/zerolog" "golang.org/x/exp/slices" "net/url" "os" popstellar "popstellar" - "popstellar/channel/lao" "popstellar/crypto" "popstellar/hub" - "popstellar/hub/standard_hub" + "popstellar/internal/popserver" + "popstellar/internal/popserver/config" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/database/sqlite" + "popstellar/internal/popserver/state" + "popstellar/internal/popserver/utils" "popstellar/network" "popstellar/network/socket" - "popstellar/popcha" + "popstellar/validation" "sync" "time" @@ -54,10 +59,84 @@ type ServerConfig struct { OtherServers []string `json:"other-servers"` } +func (s *ServerConfig) newHub(l *zerolog.Logger) (hub.Hub, error) { + // compute the client server address if it wasn't provided + if s.ClientAddress == "" { + s.ClientAddress = fmt.Sprintf("ws://%s:%d/client", s.PublicAddress, s.ClientPort) + } + // compute the server server address if it wasn't provided + if s.ServerAddress == "" { + s.ServerAddress = fmt.Sprintf("ws://%s:%d/server", s.PublicAddress, s.ServerPort) + } + + path := "./database-a/" + sqlite.DefaultPath + + if s.ClientPort == 9002 { + path = "./database-b/" + sqlite.DefaultPath + } + + var point kyber.Point = nil + err := ownerKey(s.PublicKey, &point) + if err != nil { + return nil, err + } + + schemaValidator, err := validation.NewSchemaValidator() + if err != nil { + return nil, err + } + + db, err := sqlite.NewSQLite(path, true) + if err != nil { + return nil, err + } + + database.InitDatabase(&db) + + serverPublicKey, serverSecretKey, err := db.GetServerKeys() + if err != nil { + serverSecretKey = crypto.Suite.Scalar().Pick(crypto.Suite.RandomStream()) + serverPublicKey = crypto.Suite.Point().Mul(serverSecretKey, nil) + + err := db.StoreServerKeys(serverPublicKey, serverSecretKey) + if err != nil { + return nil, err + } + } + + utils.InitUtils(l, schemaValidator) + + state.InitState(l) + + config.InitConfig(point, serverPublicKey, serverSecretKey, s.ClientAddress, s.ServerAddress) + + channels, err := db.GetAllChannels() + if err != nil { + return nil, err + } + + for _, channel := range channels { + alreadyExist, errAnswer := state.HasChannel(channel) + if errAnswer != nil { + return nil, errAnswer + } + if alreadyExist { + continue + } + + errAnswer = state.AddChannel(channel) + if errAnswer != nil { + return nil, errAnswer + } + } + + return popserver.NewHub(), nil +} + // Serve parses the CLI arguments and spawns a hub and a websocket server for // the server func Serve(cliCtx *cli.Context) error { - log := popstellar.Logger + poplog := popstellar.Logger configFilePath := cliCtx.String("config-file") var serverConfig ServerConfig @@ -80,39 +159,22 @@ func Serve(cliCtx *cli.Context) error { } } - computeAddresses(&serverConfig) - - var point kyber.Point = nil - ownerKey(serverConfig.PublicKey, &point) - - // create user hub - h, err := standard_hub.NewHub(point, serverConfig.ClientAddress, serverConfig.ServerAddress, log.With().Str("role", "server").Logger(), - lao.NewChannel) + h, err := serverConfig.newHub(&poplog) if err != nil { - return xerrors.Errorf("failed create the hub: %v", err) + return err } // start the processing loop h.Start() - // Start the PoPCHA Authorization Server. It will run internally on localhost, the address of the server given in - // the config file will be the one used externally. - authorizationSrv, err := popcha.NewAuthServer(h, "localhost", serverConfig.AuthPort, - log.With().Str("role", "authorization server").Logger()) - if err != nil { - return xerrors.Errorf("Error while starting the PoPCHA server: %v", err) - } - authorizationSrv.Start() - <-authorizationSrv.Started - // Start websocket server for clients clientSrv := network.NewServer(h, serverConfig.PrivateAddress, serverConfig.ClientPort, socket.ClientSocketType, - log.With().Str("role", "client websocket").Logger()) + poplog.With().Str("role", "client websocket").Logger()) clientSrv.Start() // Start a websocket server for remote servers serverSrv := network.NewServer(h, serverConfig.PrivateAddress, serverConfig.ServerPort, socket.ServerSocketType, - log.With().Str("role", "server websocket").Logger()) + poplog.With().Str("role", "server websocket").Logger()) serverSrv.Start() // create wait group which waits for goroutines to finish @@ -150,7 +212,7 @@ func Serve(cliCtx *cli.Context) error { go serverConnectionLoop(h, wg, done, serverConfig.OtherServers, updatedServersChan, &connectedServers) // Wait for a Ctrl-C - err = network.WaitAndShutdownServers(cliCtx.Context, authorizationSrv, clientSrv, serverSrv) + err = network.WaitAndShutdownServers(cliCtx.Context, nil, clientSrv, serverSrv) if err != nil { return err } @@ -158,7 +220,6 @@ func Serve(cliCtx *cli.Context) error { h.Stop() <-clientSrv.Stopped <-serverSrv.Stopped - <-authorizationSrv.Stopped // notify channs to stop close(done) @@ -173,7 +234,7 @@ func Serve(cliCtx *cli.Context) error { select { case <-channsClosed: case <-time.After(time.Second * 10): - log.Error().Msg("channs didn't close after timeout, exiting") + poplog.Error().Msg("channs didn't close after timeout, exiting") } return nil @@ -241,7 +302,7 @@ func connectToServers(h hub.Hub, wg *sync.WaitGroup, done chan struct{}, servers func connectToSocket(address string, h hub.Hub, wg *sync.WaitGroup, done chan struct{}) error { - log := popstellar.Logger + poplog := popstellar.Logger urlString := fmt.Sprintf("ws://%s/server", address) u, err := url.Parse(urlString) @@ -254,10 +315,10 @@ func connectToSocket(address string, h hub.Hub, return xerrors.Errorf("failed to dial to %s: %v", u.String(), err) } - log.Info().Msgf("connected to server at %s", urlString) + poplog.Info().Msgf("connected to server at %s", urlString) remoteSocket := socket.NewServerSocket(h.Receiver(), - h.OnSocketClose(), ws, wg, done, log) + h.OnSocketClose(), ws, wg, done, poplog) wg.Add(2) go remoteSocket.WritePump() @@ -307,20 +368,20 @@ func startWithConfigFile(configFilename string) (ServerConfig, error) { func loadConfig(configFilename string) (ServerConfig, error) { bytes, err := os.ReadFile(configFilename) if err != nil { - return ServerConfig{}, xerrors.Errorf("could not read config file: %w", err) + return ServerConfig{}, xerrors.Errorf("could not read serverConfig file: %w", err) } - var config ServerConfig - err = json.Unmarshal(bytes, &config) + var serverConfig ServerConfig + err = json.Unmarshal(bytes, &serverConfig) if err != nil { - return ServerConfig{}, xerrors.Errorf("could not unmarshal config file: %w", err) + return ServerConfig{}, xerrors.Errorf("could not unmarshal serverConfig file: %w", err) } - if config.ServerPort == config.ClientPort { + if serverConfig.ServerPort == serverConfig.ClientPort { return ServerConfig{}, xerrors.Errorf("client and server ports must be different") - } else if config.ServerPort == config.AuthPort || config.ClientPort == config.AuthPort { + } else if serverConfig.ServerPort == serverConfig.AuthPort || serverConfig.ClientPort == serverConfig.AuthPort { return ServerConfig{}, xerrors.Errorf("PoPCHA Authentication port must be unique\"") } - return config, nil + return serverConfig, nil } // startWithFlags returns the ServerConfig using the command line flags @@ -399,15 +460,3 @@ func updateServersState(servers []string, connectedServers *map[string]bool) { } } } - -// computeAddresses computes the client and server addresses if they were not provided -func computeAddresses(serverConfig *ServerConfig) { - // compute the client server address if it wasn't provided - if serverConfig.ClientAddress == "" { - serverConfig.ClientAddress = fmt.Sprintf("ws://%s:%d/client", serverConfig.PublicAddress, serverConfig.ClientPort) - } - // compute the server server address if it wasn't provided - if serverConfig.ServerAddress == "" { - serverConfig.ServerAddress = fmt.Sprintf("ws://%s:%d/server", serverConfig.PublicAddress, serverConfig.ServerPort) - } -} diff --git a/be1-go/cli/database-a/dummy b/be1-go/cli/database-a/dummy new file mode 100644 index 0000000000..e69de29bb2 diff --git a/be1-go/cli/database-b/dummy b/be1-go/cli/database-b/dummy new file mode 100644 index 0000000000..e69de29bb2 diff --git a/be1-go/cli/pop_test.go b/be1-go/cli/pop_test.go index 530b433922..ea429c6c4d 100644 --- a/be1-go/cli/pop_test.go +++ b/be1-go/cli/pop_test.go @@ -11,7 +11,6 @@ import ( const waitUp = time.Second * 2 func TestConnectMultipleServers(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) wait := sync.WaitGroup{} @@ -63,7 +62,6 @@ func TestConnectMultipleServers(t *testing.T) { } func TestConnectMultipleServersWithoutPK(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) wait := sync.WaitGroup{} diff --git a/be1-go/docs/README.md b/be1-go/docs/README.md index c295a05df4..293733194c 100644 --- a/be1-go/docs/README.md +++ b/be1-go/docs/README.md @@ -36,37 +36,36 @@ The project is organized into different modules as follows ``` . -├── channel # contains the abstract definition of a channel -│   ├── authentication # channel implementation for an authentication channel -│   ├── chirp # channel implementation for a chirp channel -│ ├── coin # channel implementation for a coin channel -│   ├── consensus # channel implementation for a consensus channel -│   ├── election # channel implementation for an election channel -│   ├── generalChirping # channel implementation for a universal post channel -│   ├── lao # channel implementation for a LAO channel -│   ├── reaction # channel implementation for a reaction channel -│   └── registry # helper for registry -├── cli # command line interface -├── crypto # defines the cryptographic suite +├── channel # contains the abstract definition of a channel NEED TO BE DELETED +├── cli # command line interface +├── crypto # defines the cryptographic suite ├── docs -├── hub # contains the abstract definition of a hub -│   ├── standard_hub # hub implementation -├── inbox # helper to store messages used by channels -├── message # message types and marshaling/unmarshaling logic -├── network # module to set up Websocket connections -│   └── socket # module to send/receive data over the wire -├── popcha # HTTP server and back-end logic for PoPCHA -└── validation # module to validate incoming/outgoing messages +├── hub # contains the abstract definition of a hub NEED TO BE DELETED +├── inbox # helper to store messages used by channels NEED TO BE DELETED +├── internal +│   ├── depgraph # tool to generate the dependencies graph +│   └── popserver # entry point of the messages received by the sockets +│   ├── config # singleton with the server config informations and server keys +│   ├── database # singleton with the database + implementations of the database +│   ├── generatortest # query and message generators only use for the tests +│   ├── handler # handlers for each query, answer and channel type (entry point is func HandleIncomingMessage) +│   ├── state # singleton with the temporary states of the server (peers, queries, and subscriptions) +│   ├── type # every types use in the implementation +│   └── utils # singleton with the log instance and the schema validator +├── message # message types and marshaling/unmarshaling logic +├── network # module to set up Websocket connections +│   └── socket # module to send/receive data over the wire +├── popcha # HTTP server and back-end logic for PoPCHA NEED TO BE REFACTOR +└── validation # module to validate incoming/outgoing messages ``` -The entry point is the cli with bulk of the implementation logic in the hub -module. +The entry point is the `cli` with bulk of the implementation logic in the `popserver` package. The following diagram represents the relations between the packages in the application.
- Global architecture + Global architecture
#### Architecture @@ -125,42 +124,41 @@ websocket connections. ##### Processing messages in the application layer -The incoming messages received by the `ReadPump` are propagated up the stack to -the `Hub` which is responsible for processing it and sending, depending on the message's nature, a: -- `Result` to the request. -- `Error` -- `Broadcast` -- `GreetServer` back to a server that has not been greeted yet. -- `GetMessagesById` in response to a heartbeat if it is missing some messages. +The incoming messages received by the `ReadPump` are propagated up the stack to the `Hub`. The `Hub`, on receiving a message, +processes it by invoking the `HandleIncomingMessage` method from the package `handler` and in case of `Error`, while processing the message, it will log it. +In parallel, the `Hub` will send a `Heartbeat` every 30 seconds to all the connected servers. -A hub, on receiving a message, processes it by invoking the -`handleIncomingMessage` method where its handled depending on which `Socket` the -message originates from. - -The flowchart below describes the flow of data and how messages are processed. +The flowchart below describes all the possible way for a message inside the handlers from package `handler`.
- Flowchart + Flowchart

- Flowchart last updated at the end of Spring 2023 + Flowchart last updated on 11.05.2024 and everything in red is still missing in the refactoring

-The hubs themselves contain multiple `Channels` with the `Root` channel being -the default one, where messages for creation of new LAOs may be published for -instance. Another example of a channel would be one for an `Election` which -would be a sub-channel within the LAO channel. +We use `Socket.SendError` to send an `Error` back to the client. We use this function only in two places, inside `HandleIncomingMessage` +in case the format of message is wrong or inside `handleQuery` because we should never answer an error to an answer to avoid loops. + +We use `Socket.SendResult` to send a `Result` back to the client when there is no error after processing its query. We use it only inside `query.go` at the end of each method. + +We check the Mid-level communication inside `channel.go`. + +##### Database +
+ Flowchart +
+ +

+ SQL database schema last updated at 11.05.2024 +

-The hubs use `Socket.SendError` to send an `Error` back to the client. We -suggest using `message.NewError` and `message.NewErrorf` to create these error -messages and wrap them using `xerrors.Errorf` with the `%w` format specifier if -required. The rule of thumb is the leaf/last method called from the hub should -create/return a `message.Error` and intermediate methods should propagate it up -by wrapping it until it reaches a point where `Socket.SendError` is invoked. +The database is used to store the state of the server. It is implemented in the `database` package. +We use the Repository pattern to interact with the database. +The current implementation uses a SQLite database. +For testing we use [github.com/vektra/mockery](https://github.com/vektra/mockery) to mock the database. -The hubs have a separate goroutine that is not shown in the flowchart and that -sends a `Heartbeat` message to the servers every 30 seconds. ##### Message definitions diff --git a/be1-go/docs/images/database.png b/be1-go/docs/images/database.png new file mode 100644 index 0000000000..55bab8f37f Binary files /dev/null and b/be1-go/docs/images/database.png differ diff --git a/be1-go/docs/images/dependencies/dependencies.dot b/be1-go/docs/images/dependencies/dependencies.dot new file mode 100644 index 0000000000..b5e89a811c --- /dev/null +++ b/be1-go/docs/images/dependencies/dependencies.dot @@ -0,0 +1,27 @@ +strict digraph { +labelloc="t"; +label =
(generated 10 May 24 - 08:56:20)>; +graph [fontname = "helvetica"]; +graph [fontname = "helvetica"]; +node [fontname = "helvetica"]; +edge [fontname = "helvetica"]; +node [shape=box,style=rounded]; +start=0; +ratio = fill; +rankdir="LR"; +"internal/popserver" -> "internal/popserver/config" [minlen=1]; +"internal/popserver" -> "internal/popserver/handler" [minlen=1]; +"internal/popserver" -> "internal/popserver/state" [minlen=1]; +"internal/popserver" -> "internal/popserver/utils" [minlen=1]; +"internal/popserver" -> "network/socket" [minlen=1]; +"internal/popserver/database" -> "internal/popserver/types" [minlen=1]; +"internal/popserver/handler" -> "internal/popserver/config" [minlen=1]; +"internal/popserver/handler" -> "internal/popserver/database" [minlen=1]; +"internal/popserver/handler" -> "internal/popserver/state" [minlen=1]; +"internal/popserver/handler" -> "internal/popserver/types" [minlen=1]; +"internal/popserver/handler" -> "internal/popserver/utils" [minlen=1]; +"internal/popserver/handler" -> "network/socket" [minlen=1]; +"internal/popserver/state" -> "internal/popserver/types" [minlen=1]; +"internal/popserver/state" -> "network/socket" [minlen=1]; +"internal/popserver/types" -> "network/socket" [minlen=1]; +} diff --git a/be1-go/docs/images/dependencies/dependencies.png b/be1-go/docs/images/dependencies/dependencies.png new file mode 100644 index 0000000000..d79cd34237 Binary files /dev/null and b/be1-go/docs/images/dependencies/dependencies.png differ diff --git a/be1-go/docs/images/global architecture.png b/be1-go/docs/images/global architecture.png deleted file mode 100644 index e34780ab56..0000000000 Binary files a/be1-go/docs/images/global architecture.png and /dev/null differ diff --git a/be1-go/docs/images/handler/handler.dot b/be1-go/docs/images/handler/handler.dot new file mode 100644 index 0000000000..5ecf2315a8 --- /dev/null +++ b/be1-go/docs/images/handler/handler.dot @@ -0,0 +1,199 @@ +strict digraph { +node [shape=box,style=rounded]; +start=0; +ratio = fill; +rankdir="LR"; + +subgraph cluster_incoming_message { +node [shape=box,style=rounded]; +label = "incoming_message.go"; +"HandleIncomingMessage"; + +} + +subgraph cluster_query { +node [shape=box,style=rounded]; +label = "query.go"; +"handleQuery"; +"handleGetMessagesByID"; +"handleHeartbeat"; +"handleCatchup"; +"handlePublish"; +"handleUnsubscribe"; +"handleSubscribe"; +"handleGreetserver"; +} + +subgraph cluster_answer{ +node [shape=box,style=rounded]; +label = "answer.go"; +"handleGetMessagesByIDAnswer"; +"handleAnswer"; +} + +subgraph cluster_channel{ +node [shape=box,style=rounded]; +label = "channel.go"; +"handleChannel"; +} + +subgraph cluster_root{ +node [shape=box,style=rounded]; +label = "root.go"; +"handleChannelRoot"; +"handleLaoCreate"; +} + +subgraph cluster_lao{ +node [shape=box,style=rounded]; +label = "lao.go"; +"handleElectionSetup"; +"handleRollCallReOpen"; +"handleRollCallClose"; +"handleRollCallOpen"; +"handleRollCallCreate"; +"handleMeetingState" [ fillcolor="1 0.2 1" style=filled]; +"handleMeetingCreate" [ fillcolor="1 0.2 1" style=filled]; +"handleMessageWitness" [ fillcolor="1 0.2 1" style=filled]; +"handleLaoState" [ fillcolor="1 0.2 1" style=filled]; +"handleLaoUpdate" [ fillcolor="1 0.2 1" style=filled]; +"handleChannelLao"; +} + +subgraph cluster_election{ +node [shape=box,style=rounded]; +label = "election.go"; +"handleElectionEnd"; +"handleVoteCastVote" +"handleElectionOpen"; +"handleChannelElection"; +} + +subgraph cluster_chirp{ +node [shape=box,style=rounded]; +label = "chirp.go"; +"handleChirpDelete"; +"handleChirpAdd"; +"handleChannelChirp"; +} + +subgraph cluster_reaction{ +node [shape=box,style=rounded]; +label = "reaction.go"; +"handleReactionDelete"; +"handleReactionAdd"; +"handleChannelReaction"; +} + +subgraph cluster_coin{ +node [shape=box,style=rounded]; +label = "coin.go"; +"handleCoinPostTransaction"; +"handleChannelCoin"; +} + +subgraph cluster_consensus{ +node [shape=box,fillcolor="1 0.2 1" style=filled]; +label = "consensus.go"; +"handleConsensusFailure"; +"handleConsensusLearn"; +"handleConsensusAccept"; +"handleConsensusPropose"; +"handleConsensusPromise"; +"handleConsensusPrepare"; +"handleConsensusElectAccept"; +"handleConsensusElect"; +"handleChannelConsensus"; +} + +subgraph cluster_authentification{ +node [shape=box,fillcolor="1 0.2 1" style=filled]; +label = "authentification.go"; +"handleAuthenticateUser"; +"handleChannelAuthentication"; +} + +"HandleIncomingMessage" -> { +"handleQuery" +"handleAnswer" +} + +"handleQuery" -> { +"handleGreetserver" +"handleSubscribe" +"handleUnsubscribe" +"handlePublish" +"handleCatchup" +"handleHeartbeat" +"handleGetMessagesByID" +} + +"handleAnswer" -> { +"handleGetMessagesByIDAnswer"; +} + +"handlePublish" -> "handleChannel"; +"handleGetMessagesByIDAnswer" -> "handleChannel"; + +"handleChannel" -> { +"handleChannelRoot"; +"handleChannelLao"; +"handleChannelElection"; +"handleChannelChirp"; +"handleChannelReaction"; +"handleChannelCoin"; +"handleChannelConsensus"; +"handleChannelAuthentication"; +} + +"handleChannelRoot" -> "handleLaoCreate"; + +"handleChannelLao" -> { +"handleElectionSetup" +"handleRollCallReOpen" +"handleRollCallClose" +"handleRollCallOpen" +"handleRollCallCreate" +"handleMeetingState" +"handleMeetingCreate" +"handleMessageWitness" +"handleLaoState" +"handleLaoUpdate" +} + +"handleChannelElection" -> { +"handleElectionEnd"; +"handleVoteCastVote" +"handleElectionOpen"; +} + +"handleChannelChirp" -> { +"handleChirpDelete"; +"handleChirpAdd"; +} + +"handleChannelReaction" -> { +"handleReactionDelete"; +"handleReactionAdd"; +} + +"handleChannelCoin" -> { +"handleCoinPostTransaction"; +} + +"handleChannelConsensus" -> { +"handleConsensusFailure"; +"handleConsensusLearn"; +"handleConsensusAccept"; +"handleConsensusPropose"; +"handleConsensusPromise"; +"handleConsensusPrepare"; +"handleConsensusElectAccept"; +"handleConsensusElect"; +} + +"handleChannelAuthentication" -> { +"handleAuthenticateUser"; +} + +} diff --git a/be1-go/docs/images/handler/handler.png b/be1-go/docs/images/handler/handler.png new file mode 100644 index 0000000000..7ac9089fec Binary files /dev/null and b/be1-go/docs/images/handler/handler.png differ diff --git a/be1-go/go.mod b/be1-go/go.mod index 1d52d3f62e..cefd55e543 100644 --- a/be1-go/go.mod +++ b/be1-go/go.mod @@ -11,6 +11,7 @@ require ( github.com/golang-jwt/jwt/v4 v4.5.0 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.4.2 + github.com/pkg/errors v0.9.1 github.com/rs/xid v1.3.0 github.com/rs/zerolog v1.25.0 github.com/rzajac/zltest v0.12.0 @@ -19,32 +20,46 @@ require ( github.com/urfave/cli/v2 v2.3.0 github.com/zitadel/oidc/v2 v2.1.2 go.dedis.ch/kyber/v3 v3.0.13 - golang.org/x/exp v0.0.0-20230321023759-10a507213a29 + golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 golang.org/x/sync v0.4.0 golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 gopkg.in/yaml.v2 v2.2.3 + modernc.org/sqlite v1.29.5 ) require ( github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/golang/protobuf v1.5.3 // indirect + github.com/google/uuid v1.3.0 // indirect github.com/gorilla/schema v1.2.0 // indirect github.com/gorilla/securecookie v1.1.1 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/kr/pretty v0.2.1 // indirect + github.com/mattn/go-isatty v0.0.16 // indirect github.com/muhlemmer/gu v0.3.1 // indirect + github.com/ncruces/go-strftime v0.1.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rs/cors v1.8.3 // indirect github.com/russross/blackfriday/v2 v2.0.1 // indirect github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect + github.com/stretchr/objx v0.5.0 // indirect go.dedis.ch/fixbuf v1.0.3 // indirect golang.org/x/crypto v0.14.0 // indirect golang.org/x/net v0.16.0 // indirect golang.org/x/oauth2 v0.6.0 // indirect - golang.org/x/sys v0.13.0 // indirect + golang.org/x/sys v0.16.0 // indirect golang.org/x/text v0.13.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.29.1 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect + modernc.org/libc v1.41.0 // indirect + modernc.org/mathutil v1.6.0 // indirect + modernc.org/memory v1.7.2 // indirect + modernc.org/strutil v1.2.0 // indirect + modernc.org/token v1.1.0 // indirect ) diff --git a/be1-go/go.sum b/be1-go/go.sum index 30c5019d86..0ddfd5658b 100644 --- a/be1-go/go.sum +++ b/be1-go/go.sum @@ -15,6 +15,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:ma github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -29,6 +31,8 @@ github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= @@ -39,17 +43,28 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/muhlemmer/gu v0.3.1 h1:7EAqmFrW7n3hETvuAdmFmn4hS8W+z3LgKtrnow+YzNM= github.com/muhlemmer/gu v0.3.1/go.mod h1:YHtHR+gxM+bKEIIs7Hmi9sPT3ZDUvTN/i88wQpZkrdM= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rs/cors v1.8.3 h1:O+qNyWn7Z+F9M0ILBHgMVPuB1xTOucVd5gtaYyXBpRo= github.com/rs/cors v1.8.3/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4= @@ -96,10 +111,12 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug= -golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20231108232855-2478ac86f678 h1:mchzmB1XO2pMaKFRqk/+MV3mgGG96aqaPXaMifQU47w= +golang.org/x/exp v0.0.0-20231108232855-2478ac86f678/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -122,9 +139,10 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -135,6 +153,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= +golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -158,3 +178,19 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.1.3/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= +modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= +modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= +modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI= +modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= +modernc.org/libc v1.41.0 h1:g9YAc6BkKlgORsUWj+JwqoB1wU3o4DE3bM3yvA3k+Gk= +modernc.org/libc v1.41.0/go.mod h1:w0eszPsiXoOnoMJgrXjglgLuDy/bt5RR4y3QzUUeodY= +modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= +modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= +modernc.org/memory v1.7.2 h1:Klh90S215mmH8c9gO98QxQFsY+W451E8AnzjoE2ee1E= +modernc.org/memory v1.7.2/go.mod h1:NO4NVCQy0N7ln+T9ngWqOQfi7ley4vpwvARR+Hjw95E= +modernc.org/sqlite v1.29.5 h1:8l/SQKAjDtZFo9lkJLdk8g9JEOeYRG4/ghStDCCTiTE= +modernc.org/sqlite v1.29.5/go.mod h1:S02dvcmm7TnTRvGhv8IGYyLnIt7AS2KPaB1F/71p75U= +modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= +modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/be1-go/hub/hub.go b/be1-go/hub/hub.go index 60d4cc8f14..9045062c4a 100644 --- a/be1-go/hub/hub.go +++ b/be1-go/hub/hub.go @@ -10,7 +10,7 @@ import ( // Hub defines the methods a PoP server must implement to receive messages // and handle clients. type Hub interface { - // NotifyNewServer add a socket for the hub to send message to other servers + // NotifyNewServer adds a Socket for the hub to send message to other servers NotifyNewServer(socket.Socket) // Start invokes the processing loop for the hub. @@ -22,11 +22,11 @@ type Hub interface { // Receiver returns a channel that may be used to process incoming messages Receiver() chan<- socket.IncomingMessage - // OnSocketClose returns a channel which accepts socket ids on connection + // OnSocketClose returns a channel which accepts Socket ids on connection // close events. This allows the hub to cleanup clients which close without - // sending an unsubscribe message + // sending an Unsubscribe message OnSocketClose() chan<- string - // SendGreetServer sends a greet server message in the socket + // SendGreetServer sends a greet server message in the Socket SendGreetServer(socket.Socket) error } diff --git a/be1-go/hub/standard_hub/hub_state/peers.go b/be1-go/hub/standard_hub/hub_state/peers.go index ea314b65cb..bec2e1c201 100644 --- a/be1-go/hub/standard_hub/hub_state/peers.go +++ b/be1-go/hub/standard_hub/hub_state/peers.go @@ -1,7 +1,7 @@ package hub_state import ( - "golang.org/x/xerrors" + "popstellar/message/answer" "popstellar/message/query/method" "sync" @@ -13,7 +13,7 @@ import ( type Peers struct { sync.RWMutex // peersInfo stores the info of the peers: public key, client and server endpoints associated with the socket ID - peersInfo map[string]method.ServerInfo + peersInfo map[string]method.GreetServerParams // peersGreeted stores the peers that were greeted by the socket ID peersGreeted map[string]struct{} } @@ -21,19 +21,19 @@ type Peers struct { // NewPeers creates a new Peers structure func NewPeers() Peers { return Peers{ - peersInfo: make(map[string]method.ServerInfo), + peersInfo: make(map[string]method.GreetServerParams), peersGreeted: make(map[string]struct{}), } } // AddPeerInfo adds a peer's info to the table -func (p *Peers) AddPeerInfo(socketId string, info method.ServerInfo) error { +func (p *Peers) AddPeerInfo(socketId string, info method.GreetServerParams) error { p.Lock() defer p.Unlock() currentInfo, ok := p.peersInfo[socketId] if ok { - return xerrors.Errorf( + return answer.NewInvalidActionError( "cannot add %s because peersInfo[%s] already contains %s", info, socketId, currentInfo) } @@ -50,10 +50,10 @@ func (p *Peers) AddPeerGreeted(socketId string) { } // GetAllPeersInfo returns a copy of the peers' info slice -func (p *Peers) GetAllPeersInfo() []method.ServerInfo { +func (p *Peers) GetAllPeersInfo() []method.GreetServerParams { p.RLock() defer p.RUnlock() - peersInfo := make([]method.ServerInfo, 0, len(p.peersInfo)) + peersInfo := make([]method.GreetServerParams, 0, len(p.peersInfo)) for _, info := range p.peersInfo { if !slices.Contains(peersInfo, info) { peersInfo = append(peersInfo, info) diff --git a/be1-go/hub/standard_hub/standard_hub.go b/be1-go/hub/standard_hub/standard_hub.go index bb08024c7a..70ac150110 100644 --- a/be1-go/hub/standard_hub/standard_hub.go +++ b/be1-go/hub/standard_hub/standard_hub.go @@ -99,7 +99,7 @@ type Hub struct { func NewHub(pubKeyOwner kyber.Point, clientServerAddress string, serverServerAddress string, log zerolog.Logger, laoFac channel.LaoFactory, ) (*Hub, error) { - schemaValidator, err := validation.NewSchemaValidator(log) + schemaValidator, err := validation.NewSchemaValidator() if err != nil { return nil, xerrors.Errorf("failed to create the schema validator: %v", err) } @@ -236,7 +236,7 @@ func (h *Hub) SendGreetServer(socket socket.Socket) error { return xerrors.Errorf("failed to marshal server public key: %v", err) } - serverInfo := method.ServerInfo{ + serverInfo := method.GreetServerParams{ PublicKey: base64.URLEncoding.EncodeToString(pk), ServerAddress: h.serverServerAddress, ClientAddress: h.clientServerAddress, @@ -594,7 +594,7 @@ func (h *Hub) NotifyWitnessMessage(messageId string, publicKey string, signature h.hubInbox.AddWitnessSignature(messageId, publicKey, signature) } -func (h *Hub) GetPeersInfo() []method.ServerInfo { +func (h *Hub) GetPeersInfo() []method.GreetServerParams { return h.peers.GetAllPeersInfo() } diff --git a/be1-go/hub/standard_hub/standard_hub_test.go b/be1-go/hub/standard_hub/standard_hub_test.go index 9920f62640..d491d8a44b 100644 --- a/be1-go/hub/standard_hub/standard_hub_test.go +++ b/be1-go/hub/standard_hub/standard_hub_test.go @@ -1801,7 +1801,7 @@ func Test_Handle_GreetServer_First_Time(t *testing.T) { sock := &fakeSocket{} - serverInfo := method.ServerInfo{ + serverInfo := method.GreetServerParams{ PublicKey: "", ServerAddress: "ws://localhost:9003/server", ClientAddress: "ws://localhost:9002/client", @@ -1856,7 +1856,7 @@ func Test_Handle_GreetServer_Already_Greeted(t *testing.T) { // reset socket message sock.msg = nil - serverInfo := method.ServerInfo{ + serverInfo := method.GreetServerParams{ PublicKey: "", ServerAddress: "ws://localhost:9003/server", ClientAddress: "ws://localhost:9002/client", @@ -1894,13 +1894,13 @@ func Test_Handle_GreetServer_Already_Received(t *testing.T) { hub, err := NewHub(keypair.public, "", "", nolog, nil) require.NoError(t, err) - serverInfo1 := method.ServerInfo{ + serverInfo1 := method.GreetServerParams{ PublicKey: "", ServerAddress: "ws://localhost:9003/server", ClientAddress: "ws://localhost:9002/client", } - serverInfo2 := method.ServerInfo{ + serverInfo2 := method.GreetServerParams{ PublicKey: "", ServerAddress: "ws://localhost:9005/server", ClientAddress: "ws://localhost:9004/client", diff --git a/be1-go/internal/depgraph/dep.yml b/be1-go/internal/depgraph/dep.yml index 48f4061ca1..200f58f5ff 100644 --- a/be1-go/internal/depgraph/dep.yml +++ b/be1-go/internal/depgraph/dep.yml @@ -2,14 +2,7 @@ modname: popstellar overwrite: true outfile: graph.dot includes: - - popstellar/* + - popstellar/internal/popserver/* + - popstellar/network/socket excludes: - - popstellar/internal - - popstellar/crypto - - popstellar/cli - - popstellar/docs - - popstellar/message - - popstellar$ interfaces: - - hub - - channel \ No newline at end of file diff --git a/be1-go/internal/popserver/config/config.go b/be1-go/internal/popserver/config/config.go new file mode 100644 index 0000000000..ed37e64ab6 --- /dev/null +++ b/be1-go/internal/popserver/config/config.go @@ -0,0 +1,96 @@ +package config + +import ( + "encoding/base64" + "go.dedis.ch/kyber/v3" + "popstellar/message/answer" + "sync" +) + +var once sync.Once +var instance *config + +type config struct { + ownerPubKey kyber.Point + serverPubKey kyber.Point + serverSecretKey kyber.Scalar + clientServerAddress string + serverServerAddress string +} + +func InitConfig(ownerPubKey, serverPubKey kyber.Point, serverSecretKey kyber.Scalar, + clientServerAddress, serverServerAddress string) { + once.Do(func() { + instance = &config{ + ownerPubKey: ownerPubKey, + serverPubKey: serverPubKey, + serverSecretKey: serverSecretKey, + clientServerAddress: clientServerAddress, + serverServerAddress: serverServerAddress, + } + }) +} + +// ONLY FOR TEST PURPOSE +// SetConfig is only here to be used to reset the config before each test +func SetConfig(ownerPubKey, serverPubKey kyber.Point, serverSecretKey kyber.Scalar, + clientServerAddress, serverServerAddress string) { + instance = &config{ + ownerPubKey: ownerPubKey, + serverPubKey: serverPubKey, + serverSecretKey: serverSecretKey, + clientServerAddress: clientServerAddress, + serverServerAddress: serverServerAddress, + } +} + +func getConfig() (*config, *answer.Error) { + if instance == nil { + errAnswer := answer.NewInternalServerError("config was not instantiated") + return nil, errAnswer + } + + return instance, nil +} + +func GetOwnerPublicKeyInstance() (kyber.Point, *answer.Error) { + config, errAnswer := getConfig() + if errAnswer != nil { + return nil, errAnswer + } + + return config.ownerPubKey, nil +} + +func GetServerPublicKeyInstance() (kyber.Point, *answer.Error) { + config, errAnswer := getConfig() + if errAnswer != nil { + return nil, errAnswer + } + + return config.serverPubKey, nil +} + +func GetServerSecretKeyInstance() (kyber.Scalar, *answer.Error) { + config, errAnswer := getConfig() + if errAnswer != nil { + return nil, errAnswer + } + + return config.serverSecretKey, nil +} + +func GetServerInfo() (string, string, string, *answer.Error) { + config, errAnswer := getConfig() + if errAnswer != nil { + return "", "", "", errAnswer + } + + pkBuf, err := config.serverPubKey.MarshalBinary() + if err != nil { + errAnswer := answer.NewInternalServerError("failed to unmarshall server public key", err) + return "", "", "", errAnswer + } + + return base64.URLEncoding.EncodeToString(pkBuf), instance.clientServerAddress, instance.serverServerAddress, nil +} diff --git a/be1-go/internal/popserver/database/database.go b/be1-go/internal/popserver/database/database.go new file mode 100644 index 0000000000..6bc4eef2cb --- /dev/null +++ b/be1-go/internal/popserver/database/database.go @@ -0,0 +1,63 @@ +package database + +import ( + "popstellar/internal/popserver/database/repository" + "popstellar/message/answer" + "sync" +) + +var once sync.Once +var instance repository.Repository + +func InitDatabase(db repository.Repository) { + once.Do(func() { + instance = db + }) +} + +// ONLY FOR TEST PURPOSE +// SetDatabase is only here to be used to reset the database before each test +func SetDatabase(mockRepo *repository.MockRepository) { + instance = mockRepo +} + +func getInstance() (repository.Repository, *answer.Error) { + if instance == nil { + errAnswer := answer.NewInternalServerError("database was not instantiated") + return nil, errAnswer + } + + return instance, nil +} + +func GetQueryRepositoryInstance() (repository.QueryRepository, *answer.Error) { + return getInstance() +} + +func GetChannelRepositoryInstance() (repository.ChannelRepository, *answer.Error) { + return getInstance() +} + +func GetRootRepositoryInstance() (repository.RootRepository, *answer.Error) { + return getInstance() +} + +func GetLAORepositoryInstance() (repository.LAORepository, *answer.Error) { + return getInstance() +} + +func GetChirpRepositoryInstance() (repository.ChirpRepository, *answer.Error) { + return getInstance() +} + +func GetCoinRepositoryInstance() (repository.CoinRepository, *answer.Error) { + return getInstance() +} + +func GetElectionRepositoryInstance() (repository.ElectionRepository, *answer.Error) { + return getInstance() +} + +func GetReactionRepositoryInstance() (repository.ReactionRepository, *answer.Error) { + return getInstance() +} diff --git a/be1-go/internal/popserver/database/repository/mock_repository.go b/be1-go/internal/popserver/database/repository/mock_repository.go new file mode 100644 index 0000000000..a1a5e3f155 --- /dev/null +++ b/be1-go/internal/popserver/database/repository/mock_repository.go @@ -0,0 +1,966 @@ +// Code generated by mockery v2.42.1. DO NOT EDIT. + +package repository + +import ( + message "popstellar/message/query/method/message" + + mock "github.com/stretchr/testify/mock" + kyber "go.dedis.ch/kyber/v3" + + types "popstellar/internal/popserver/types" +) + +// MockRepository is an autogenerated mock type for the Repository type +type MockRepository struct { + mock.Mock +} + +// CheckPrevCreateOrCloseID provides a mock function with given fields: channel, nextID +func (_m *MockRepository) CheckPrevCreateOrCloseID(channel string, nextID string) (bool, error) { + ret := _m.Called(channel, nextID) + + if len(ret) == 0 { + panic("no return value specified for CheckPrevCreateOrCloseID") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string, string) (bool, error)); ok { + return rf(channel, nextID) + } + if rf, ok := ret.Get(0).(func(string, string) bool); ok { + r0 = rf(channel, nextID) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(channel, nextID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CheckPrevOpenOrReopenID provides a mock function with given fields: channel, nextID +func (_m *MockRepository) CheckPrevOpenOrReopenID(channel string, nextID string) (bool, error) { + ret := _m.Called(channel, nextID) + + if len(ret) == 0 { + panic("no return value specified for CheckPrevOpenOrReopenID") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string, string) (bool, error)); ok { + return rf(channel, nextID) + } + if rf, ok := ret.Get(0).(func(string, string) bool); ok { + r0 = rf(channel, nextID) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(channel, nextID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetAllMessagesFromChannel provides a mock function with given fields: channelID +func (_m *MockRepository) GetAllMessagesFromChannel(channelID string) ([]message.Message, error) { + ret := _m.Called(channelID) + + if len(ret) == 0 { + panic("no return value specified for GetAllMessagesFromChannel") + } + + var r0 []message.Message + var r1 error + if rf, ok := ret.Get(0).(func(string) ([]message.Message, error)); ok { + return rf(channelID) + } + if rf, ok := ret.Get(0).(func(string) []message.Message); ok { + r0 = rf(channelID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]message.Message) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(channelID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetChannelType provides a mock function with given fields: channel +func (_m *MockRepository) GetChannelType(channel string) (string, error) { + ret := _m.Called(channel) + + if len(ret) == 0 { + panic("no return value specified for GetChannelType") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(string) (string, error)); ok { + return rf(channel) + } + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(channel) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(channel) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetElectionAttendees provides a mock function with given fields: electionID +func (_m *MockRepository) GetElectionAttendees(electionID string) (map[string]struct{}, error) { + ret := _m.Called(electionID) + + if len(ret) == 0 { + panic("no return value specified for GetElectionAttendees") + } + + var r0 map[string]struct{} + var r1 error + if rf, ok := ret.Get(0).(func(string) (map[string]struct{}, error)); ok { + return rf(electionID) + } + if rf, ok := ret.Get(0).(func(string) map[string]struct{}); ok { + r0 = rf(electionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]struct{}) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(electionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetElectionCreationTime provides a mock function with given fields: electionID +func (_m *MockRepository) GetElectionCreationTime(electionID string) (int64, error) { + ret := _m.Called(electionID) + + if len(ret) == 0 { + panic("no return value specified for GetElectionCreationTime") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(string) (int64, error)); ok { + return rf(electionID) + } + if rf, ok := ret.Get(0).(func(string) int64); ok { + r0 = rf(electionID) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(electionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetElectionQuestions provides a mock function with given fields: electionID +func (_m *MockRepository) GetElectionQuestions(electionID string) (map[string]types.Question, error) { + ret := _m.Called(electionID) + + if len(ret) == 0 { + panic("no return value specified for GetElectionQuestions") + } + + var r0 map[string]types.Question + var r1 error + if rf, ok := ret.Get(0).(func(string) (map[string]types.Question, error)); ok { + return rf(electionID) + } + if rf, ok := ret.Get(0).(func(string) map[string]types.Question); ok { + r0 = rf(electionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]types.Question) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(electionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetElectionQuestionsWithValidVotes provides a mock function with given fields: electionID +func (_m *MockRepository) GetElectionQuestionsWithValidVotes(electionID string) (map[string]types.Question, error) { + ret := _m.Called(electionID) + + if len(ret) == 0 { + panic("no return value specified for GetElectionQuestionsWithValidVotes") + } + + var r0 map[string]types.Question + var r1 error + if rf, ok := ret.Get(0).(func(string) (map[string]types.Question, error)); ok { + return rf(electionID) + } + if rf, ok := ret.Get(0).(func(string) map[string]types.Question); ok { + r0 = rf(electionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]types.Question) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(electionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetElectionSecretKey provides a mock function with given fields: electionID +func (_m *MockRepository) GetElectionSecretKey(electionID string) (kyber.Scalar, error) { + ret := _m.Called(electionID) + + if len(ret) == 0 { + panic("no return value specified for GetElectionSecretKey") + } + + var r0 kyber.Scalar + var r1 error + if rf, ok := ret.Get(0).(func(string) (kyber.Scalar, error)); ok { + return rf(electionID) + } + if rf, ok := ret.Get(0).(func(string) kyber.Scalar); ok { + r0 = rf(electionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(kyber.Scalar) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(electionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetElectionType provides a mock function with given fields: electionID +func (_m *MockRepository) GetElectionType(electionID string) (string, error) { + ret := _m.Called(electionID) + + if len(ret) == 0 { + panic("no return value specified for GetElectionType") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(string) (string, error)); ok { + return rf(electionID) + } + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(electionID) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(electionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetLAOOrganizerPubKey provides a mock function with given fields: electionID +func (_m *MockRepository) GetLAOOrganizerPubKey(electionID string) (kyber.Point, error) { + ret := _m.Called(electionID) + + if len(ret) == 0 { + panic("no return value specified for GetLAOOrganizerPubKey") + } + + var r0 kyber.Point + var r1 error + if rf, ok := ret.Get(0).(func(string) (kyber.Point, error)); ok { + return rf(electionID) + } + if rf, ok := ret.Get(0).(func(string) kyber.Point); ok { + r0 = rf(electionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(kyber.Point) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(electionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetLaoWitnesses provides a mock function with given fields: laoID +func (_m *MockRepository) GetLaoWitnesses(laoID string) (map[string]struct{}, error) { + ret := _m.Called(laoID) + + if len(ret) == 0 { + panic("no return value specified for GetLaoWitnesses") + } + + var r0 map[string]struct{} + var r1 error + if rf, ok := ret.Get(0).(func(string) (map[string]struct{}, error)); ok { + return rf(laoID) + } + if rf, ok := ret.Get(0).(func(string) map[string]struct{}); ok { + r0 = rf(laoID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]struct{}) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(laoID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetMessageByID provides a mock function with given fields: ID +func (_m *MockRepository) GetMessageByID(ID string) (message.Message, error) { + ret := _m.Called(ID) + + if len(ret) == 0 { + panic("no return value specified for GetMessageByID") + } + + var r0 message.Message + var r1 error + if rf, ok := ret.Get(0).(func(string) (message.Message, error)); ok { + return rf(ID) + } + if rf, ok := ret.Get(0).(func(string) message.Message); ok { + r0 = rf(ID) + } else { + r0 = ret.Get(0).(message.Message) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(ID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetMessagesByID provides a mock function with given fields: IDs +func (_m *MockRepository) GetMessagesByID(IDs []string) (map[string]message.Message, error) { + ret := _m.Called(IDs) + + if len(ret) == 0 { + panic("no return value specified for GetMessagesByID") + } + + var r0 map[string]message.Message + var r1 error + if rf, ok := ret.Get(0).(func([]string) (map[string]message.Message, error)); ok { + return rf(IDs) + } + if rf, ok := ret.Get(0).(func([]string) map[string]message.Message); ok { + r0 = rf(IDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]message.Message) + } + } + + if rf, ok := ret.Get(1).(func([]string) error); ok { + r1 = rf(IDs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetOrganizerPubKey provides a mock function with given fields: laoID +func (_m *MockRepository) GetOrganizerPubKey(laoID string) (kyber.Point, error) { + ret := _m.Called(laoID) + + if len(ret) == 0 { + panic("no return value specified for GetOrganizerPubKey") + } + + var r0 kyber.Point + var r1 error + if rf, ok := ret.Get(0).(func(string) (kyber.Point, error)); ok { + return rf(laoID) + } + if rf, ok := ret.Get(0).(func(string) kyber.Point); ok { + r0 = rf(laoID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(kyber.Point) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(laoID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetParamsForGetMessageByID provides a mock function with given fields: params +func (_m *MockRepository) GetParamsForGetMessageByID(params map[string][]string) (map[string][]string, error) { + ret := _m.Called(params) + + if len(ret) == 0 { + panic("no return value specified for GetParamsForGetMessageByID") + } + + var r0 map[string][]string + var r1 error + if rf, ok := ret.Get(0).(func(map[string][]string) (map[string][]string, error)); ok { + return rf(params) + } + if rf, ok := ret.Get(0).(func(map[string][]string) map[string][]string); ok { + r0 = rf(params) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string][]string) + } + } + + if rf, ok := ret.Get(1).(func(map[string][]string) error); ok { + r1 = rf(params) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetParamsHeartbeat provides a mock function with given fields: +func (_m *MockRepository) GetParamsHeartbeat() (map[string][]string, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetParamsHeartbeat") + } + + var r0 map[string][]string + var r1 error + if rf, ok := ret.Get(0).(func() (map[string][]string, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() map[string][]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string][]string) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetReactionSender provides a mock function with given fields: messageID +func (_m *MockRepository) GetReactionSender(messageID string) (string, error) { + ret := _m.Called(messageID) + + if len(ret) == 0 { + panic("no return value specified for GetReactionSender") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(string) (string, error)); ok { + return rf(messageID) + } + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(messageID) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(messageID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetResultForGetMessagesByID provides a mock function with given fields: params +func (_m *MockRepository) GetResultForGetMessagesByID(params map[string][]string) (map[string][]message.Message, error) { + ret := _m.Called(params) + + if len(ret) == 0 { + panic("no return value specified for GetResultForGetMessagesByID") + } + + var r0 map[string][]message.Message + var r1 error + if rf, ok := ret.Get(0).(func(map[string][]string) (map[string][]message.Message, error)); ok { + return rf(params) + } + if rf, ok := ret.Get(0).(func(map[string][]string) map[string][]message.Message); ok { + r0 = rf(params) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string][]message.Message) + } + } + + if rf, ok := ret.Get(1).(func(map[string][]string) error); ok { + r1 = rf(params) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetRollCallState provides a mock function with given fields: channel +func (_m *MockRepository) GetRollCallState(channel string) (string, error) { + ret := _m.Called(channel) + + if len(ret) == 0 { + panic("no return value specified for GetRollCallState") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(string) (string, error)); ok { + return rf(channel) + } + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(channel) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(channel) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetServerKeys provides a mock function with given fields: +func (_m *MockRepository) GetServerKeys() (kyber.Point, kyber.Scalar, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetServerKeys") + } + + var r0 kyber.Point + var r1 kyber.Scalar + var r2 error + if rf, ok := ret.Get(0).(func() (kyber.Point, kyber.Scalar, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() kyber.Point); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(kyber.Point) + } + } + + if rf, ok := ret.Get(1).(func() kyber.Scalar); ok { + r1 = rf() + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(kyber.Scalar) + } + } + + if rf, ok := ret.Get(2).(func() error); ok { + r2 = rf() + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// HasChannel provides a mock function with given fields: channel +func (_m *MockRepository) HasChannel(channel string) (bool, error) { + ret := _m.Called(channel) + + if len(ret) == 0 { + panic("no return value specified for HasChannel") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string) (bool, error)); ok { + return rf(channel) + } + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(channel) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(channel) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// HasMessage provides a mock function with given fields: messageID +func (_m *MockRepository) HasMessage(messageID string) (bool, error) { + ret := _m.Called(messageID) + + if len(ret) == 0 { + panic("no return value specified for HasMessage") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string) (bool, error)); ok { + return rf(messageID) + } + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(messageID) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(messageID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// IsAttendee provides a mock function with given fields: laoPath, poptoken +func (_m *MockRepository) IsAttendee(laoPath string, poptoken string) (bool, error) { + ret := _m.Called(laoPath, poptoken) + + if len(ret) == 0 { + panic("no return value specified for IsAttendee") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string, string) (bool, error)); ok { + return rf(laoPath, poptoken) + } + if rf, ok := ret.Get(0).(func(string, string) bool); ok { + r0 = rf(laoPath, poptoken) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(laoPath, poptoken) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// IsElectionEnded provides a mock function with given fields: electionID +func (_m *MockRepository) IsElectionEnded(electionID string) (bool, error) { + ret := _m.Called(electionID) + + if len(ret) == 0 { + panic("no return value specified for IsElectionEnded") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string) (bool, error)); ok { + return rf(electionID) + } + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(electionID) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(electionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// IsElectionStarted provides a mock function with given fields: electionID +func (_m *MockRepository) IsElectionStarted(electionID string) (bool, error) { + ret := _m.Called(electionID) + + if len(ret) == 0 { + panic("no return value specified for IsElectionStarted") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string) (bool, error)); ok { + return rf(electionID) + } + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(electionID) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(electionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// IsElectionStartedOrEnded provides a mock function with given fields: electionID +func (_m *MockRepository) IsElectionStartedOrEnded(electionID string) (bool, error) { + ret := _m.Called(electionID) + + if len(ret) == 0 { + panic("no return value specified for IsElectionStartedOrEnded") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(string) (bool, error)); ok { + return rf(electionID) + } + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(electionID) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(electionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// StoreChirpMessages provides a mock function with given fields: channel, generalChannel, msg, generalMsg +func (_m *MockRepository) StoreChirpMessages(channel string, generalChannel string, msg message.Message, generalMsg message.Message) error { + ret := _m.Called(channel, generalChannel, msg, generalMsg) + + if len(ret) == 0 { + panic("no return value specified for StoreChirpMessages") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, string, message.Message, message.Message) error); ok { + r0 = rf(channel, generalChannel, msg, generalMsg) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// StoreElection provides a mock function with given fields: laoPath, electionPath, electionPubKey, electionSecretKey, msg +func (_m *MockRepository) StoreElection(laoPath string, electionPath string, electionPubKey kyber.Point, electionSecretKey kyber.Scalar, msg message.Message) error { + ret := _m.Called(laoPath, electionPath, electionPubKey, electionSecretKey, msg) + + if len(ret) == 0 { + panic("no return value specified for StoreElection") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, string, kyber.Point, kyber.Scalar, message.Message) error); ok { + r0 = rf(laoPath, electionPath, electionPubKey, electionSecretKey, msg) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// StoreElectionEndWithResult provides a mock function with given fields: channelID, msg, electionResultMsg +func (_m *MockRepository) StoreElectionEndWithResult(channelID string, msg message.Message, electionResultMsg message.Message) error { + ret := _m.Called(channelID, msg, electionResultMsg) + + if len(ret) == 0 { + panic("no return value specified for StoreElectionEndWithResult") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, message.Message, message.Message) error); ok { + r0 = rf(channelID, msg, electionResultMsg) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// StoreElectionWithElectionKey provides a mock function with given fields: laoPath, electionPath, electionPubKey, electionSecretKey, msg, electionKeyMsg +func (_m *MockRepository) StoreElectionWithElectionKey(laoPath string, electionPath string, electionPubKey kyber.Point, electionSecretKey kyber.Scalar, msg message.Message, electionKeyMsg message.Message) error { + ret := _m.Called(laoPath, electionPath, electionPubKey, electionSecretKey, msg, electionKeyMsg) + + if len(ret) == 0 { + panic("no return value specified for StoreElectionWithElectionKey") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, string, kyber.Point, kyber.Scalar, message.Message, message.Message) error); ok { + r0 = rf(laoPath, electionPath, electionPubKey, electionSecretKey, msg, electionKeyMsg) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// StoreLaoWithLaoGreet provides a mock function with given fields: channels, laoID, organizerPubBuf, msg, laoGreetMsg +func (_m *MockRepository) StoreLaoWithLaoGreet(channels map[string]string, laoID string, organizerPubBuf []byte, msg message.Message, laoGreetMsg message.Message) error { + ret := _m.Called(channels, laoID, organizerPubBuf, msg, laoGreetMsg) + + if len(ret) == 0 { + panic("no return value specified for StoreLaoWithLaoGreet") + } + + var r0 error + if rf, ok := ret.Get(0).(func(map[string]string, string, []byte, message.Message, message.Message) error); ok { + r0 = rf(channels, laoID, organizerPubBuf, msg, laoGreetMsg) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// StoreMessageAndData provides a mock function with given fields: channelID, msg +func (_m *MockRepository) StoreMessageAndData(channelID string, msg message.Message) error { + ret := _m.Called(channelID, msg) + + if len(ret) == 0 { + panic("no return value specified for StoreMessageAndData") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, message.Message) error); ok { + r0 = rf(channelID, msg) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// StoreRollCallClose provides a mock function with given fields: channels, laoID, msg +func (_m *MockRepository) StoreRollCallClose(channels []string, laoID string, msg message.Message) error { + ret := _m.Called(channels, laoID, msg) + + if len(ret) == 0 { + panic("no return value specified for StoreRollCallClose") + } + + var r0 error + if rf, ok := ret.Get(0).(func([]string, string, message.Message) error); ok { + r0 = rf(channels, laoID, msg) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// StoreServerKeys provides a mock function with given fields: electionPubKey, electionSecretKey +func (_m *MockRepository) StoreServerKeys(electionPubKey kyber.Point, electionSecretKey kyber.Scalar) error { + ret := _m.Called(electionPubKey, electionSecretKey) + + if len(ret) == 0 { + panic("no return value specified for StoreServerKeys") + } + + var r0 error + if rf, ok := ret.Get(0).(func(kyber.Point, kyber.Scalar) error); ok { + r0 = rf(electionPubKey, electionSecretKey) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewMockRepository creates a new instance of MockRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *MockRepository { + mock := &MockRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/be1-go/internal/popserver/database/repository/repository.go b/be1-go/internal/popserver/database/repository/repository.go new file mode 100644 index 0000000000..539cff3023 --- /dev/null +++ b/be1-go/internal/popserver/database/repository/repository.go @@ -0,0 +1,185 @@ +package repository + +import ( + "go.dedis.ch/kyber/v3" + "popstellar/internal/popserver/types" + "popstellar/message/query/method/message" +) + +type Repository interface { + QueryRepository + AnswerRepository + ChannelRepository + RootRepository + ElectionRepository + LAORepository + ChirpRepository + CoinRepository + ReactionRepository + + // StoreServerKeys stores the keys of the server + StoreServerKeys(electionPubKey kyber.Point, electionSecretKey kyber.Scalar) error + + // GetServerKeys get the keys of the server + GetServerKeys() (kyber.Point, kyber.Scalar, error) + + // StoreMessageAndData stores a message with an object and an action inside the database. + StoreMessageAndData(channelID string, msg message.Message) error + + // GetMessagesByID returns a set of messages by their IDs. + GetMessagesByID(IDs []string) (map[string]message.Message, error) + + // GetMessageByID returns a message by its ID. + GetMessageByID(ID string) (message.Message, error) +} + +// ======================= Query ========================== + +type QueryRepository interface { + GetResultForGetMessagesByID(params map[string][]string) (map[string][]message.Message, error) + + // GetParamsForGetMessageByID returns the params to do the getMessageByID msg in reponse of heartbeat + GetParamsForGetMessageByID(params map[string][]string) (map[string][]string, error) + + // GetAllMessagesFromChannel return all the messages received + sent on a channel + GetAllMessagesFromChannel(channelID string) ([]message.Message, error) + + GetParamsHeartbeat() (map[string][]string, error) +} + +// ======================= Answer ========================== + +type AnswerRepository interface { +} + +// ======================= Channel ========================== + +type ChannelRepository interface { + // HasChannel returns true if the channel already exists. + HasChannel(channel string) (bool, error) + + // HasMessage returns true if the message already exists. + HasMessage(messageID string) (bool, error) + + // GetChannelType returns the type of the channel. + GetChannelType(channel string) (string, error) +} + +type RootRepository interface { + + // StoreLaoWithLaoGreet stores a list of "sub" channels, a message and a lao greet message inside the database. + StoreLaoWithLaoGreet( + channels map[string]string, + laoID string, + organizerPubBuf []byte, + msg, laoGreetMsg message.Message) error + + // StoreMessageAndData stores a message inside the database. + StoreMessageAndData(channelID string, msg message.Message) error + + // HasChannel returns true if the channel already exists. + HasChannel(channel string) (bool, error) +} + +type LAORepository interface { + // GetLaoWitnesses returns the list of witnesses of a LAO. + GetLaoWitnesses(laoID string) (map[string]struct{}, error) + + // GetOrganizerPubKey returns the organizer public key of a LAO. + GetOrganizerPubKey(laoID string) (kyber.Point, error) + + // GetRollCallState returns the state of th lao roll call. + GetRollCallState(channel string) (string, error) + + // CheckPrevOpenOrReopenID returns true if the previous roll call open or reopen has the same ID + CheckPrevOpenOrReopenID(channel, nextID string) (bool, error) + + // CheckPrevCreateOrCloseID returns true if the previous roll call create or close has the same ID + CheckPrevCreateOrCloseID(channel, nextID string) (bool, error) + + // StoreRollCallClose stores a list of chirp channels and a rollCallClose message inside the database. + StoreRollCallClose(channels []string, laoID string, msg message.Message) error + + // StoreElectionWithElectionKey stores an electionSetup message and an election key message inside the database. + StoreElectionWithElectionKey( + laoPath, electionPath string, + electionPubKey kyber.Point, + electionSecretKey kyber.Scalar, + msg, electionKeyMsg message.Message) error + + //StoreElection stores an electionSetup message inside the database. + StoreElection( + laoPath, electionPath string, + electionPubKey kyber.Point, + electionSecretKey kyber.Scalar, + msg message.Message) error + + // StoreMessageAndData stores a message with an object and an action inside the database. + StoreMessageAndData(channelID string, msg message.Message) error + + // HasMessage returns true if the message already exists. + HasMessage(messageID string) (bool, error) +} + +type ElectionRepository interface { + + // GetLAOOrganizerPubKey returns the organizer public key of an election. + GetLAOOrganizerPubKey(electionID string) (kyber.Point, error) + + // GetElectionSecretKey returns the secret key of an election. + GetElectionSecretKey(electionID string) (kyber.Scalar, error) + + // IsElectionStartedOrEnded returns true if the election is started or ended. + IsElectionStartedOrEnded(electionID string) (bool, error) + + // IsElectionEnded returns true if the election is ended. + IsElectionEnded(electionID string) (bool, error) + + //IsElectionStarted returns true if the election is started. + IsElectionStarted(electionID string) (bool, error) + + // GetElectionType returns the type of an election. + GetElectionType(electionID string) (string, error) + + // GetElectionCreationTime returns the creation time of an election. + GetElectionCreationTime(electionID string) (int64, error) + + // GetElectionAttendees returns the attendees of an election. + GetElectionAttendees(electionID string) (map[string]struct{}, error) + + // GetElectionQuestions returns the questions of an election. + GetElectionQuestions(electionID string) (map[string]types.Question, error) + + // GetElectionQuestionsWithValidVotes returns the questions of an election with valid votes. + GetElectionQuestionsWithValidVotes(electionID string) (map[string]types.Question, error) + + // StoreElectionEndWithResult stores a message and an election result message inside the database. + StoreElectionEndWithResult(channelID string, msg, electionResultMsg message.Message) error + + // StoreMessageAndData stores a message with an object and an action inside the database. + StoreMessageAndData(channelID string, msg message.Message) error +} + +type ChirpRepository interface { + // HasMessage returns true if the message already exists. + HasMessage(messageID string) (bool, error) + + // StoreChirpMessages stores a chirp message and a generalChirp broadcast inside the database. + StoreChirpMessages(channel, generalChannel string, msg, generalMsg message.Message) error +} + +type CoinRepository interface { + // StoreMessageAndData stores a message with an object and an action inside the database. + StoreMessageAndData(channelID string, msg message.Message) error +} + +type ReactionRepository interface { + // IsAttendee returns if the user has participated in the last roll-call from the LAO + IsAttendee(laoPath string, poptoken string) (bool, error) + + // GetReactionSender returns a reaction sender + GetReactionSender(messageID string) (string, error) + + // StoreMessageAndData stores a message with an object and an action inside the database. + StoreMessageAndData(channelID string, msg message.Message) error +} diff --git a/be1-go/internal/popserver/database/sqlite/sqlite.go b/be1-go/internal/popserver/database/sqlite/sqlite.go new file mode 100644 index 0000000000..abcc72eca4 --- /dev/null +++ b/be1-go/internal/popserver/database/sqlite/sqlite.go @@ -0,0 +1,1265 @@ +package sqlite + +import ( + "database/sql" + "encoding/base64" + "encoding/json" + "errors" + "go.dedis.ch/kyber/v3" + "golang.org/x/xerrors" + _ "modernc.org/sqlite" + "popstellar/crypto" + "popstellar/internal/popserver/types" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "strings" + "time" +) + +func (s *SQLite) StoreServerKeys(electionPubKey kyber.Point, electionSecretKey kyber.Scalar) error { + dbLock.Lock() + defer dbLock.Unlock() + + tx, err := s.database.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + electionPubBuf, err := electionPubKey.MarshalBinary() + if err != nil { + return err + } + electionSecBuf, err := electionSecretKey.MarshalBinary() + if err != nil { + return err + } + + _, err = tx.Exec(insertKeys, serverKeysPath, electionPubBuf, electionSecBuf) + if err != nil { + return err + } + + return tx.Commit() +} + +func (s *SQLite) GetServerKeys() (kyber.Point, kyber.Scalar, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var serverPubBuf []byte + var serverSecBuf []byte + err := s.database.QueryRow(selectKeys, serverKeysPath).Scan(&serverPubBuf, &serverSecBuf) + if err != nil { + return nil, nil, err + } + serverPubKey := crypto.Suite.Point() + err = serverPubKey.UnmarshalBinary(serverPubBuf) + if err != nil { + return nil, nil, err + } + serverSecKey := crypto.Suite.Scalar() + err = serverSecKey.UnmarshalBinary(serverSecBuf) + if err != nil { + return nil, nil, err + } + + return serverPubKey, serverSecKey, nil +} + +func (s *SQLite) StoreMessageAndData(channelPath string, msg message.Message) error { + dbLock.Lock() + defer dbLock.Unlock() + + tx, err := s.database.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if err = addPendingSignatures(tx, &msg); err != nil { + return err + } + + messageData, err := base64.URLEncoding.DecodeString(msg.Data) + if err != nil { + return err + } + + msgByte, err := json.Marshal(msg) + if err != nil { + return err + } + _, err = tx.Exec(insertMessage, msg.MessageID, msgByte, messageData, time.Now().UnixNano()) + if err != nil { + return err + + } + _, err = tx.Exec(insertChannelMessage, channelPath, msg.MessageID, true) + if err != nil { + return err + + } + return tx.Commit() +} + +func addPendingSignatures(tx *sql.Tx, msg *message.Message) error { + rows, err := tx.Query(selectPendingSignatures, msg.MessageID) + if err != nil { + return err + } + + for rows.Next() { + var witness string + var signature string + if err = rows.Scan(&witness, &signature); err != nil { + return err + } + msg.WitnessSignatures = append(msg.WitnessSignatures, message.WitnessSignature{ + Witness: witness, + Signature: signature, + }) + } + + if err = rows.Err(); err != nil { + return err + } + + _, err = tx.Exec(deletePendingSignatures, msg.MessageID) + return err +} + +// GetMessagesByID returns a set of messages by their IDs. +func (s *SQLite) GetMessagesByID(IDs []string) (map[string]message.Message, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + if len(IDs) == 0 { + return make(map[string]message.Message), nil + } + + IDsInterface := make([]interface{}, len(IDs)) + for i, v := range IDs { + IDsInterface[i] = v + } + rows, err := s.database.Query("SELECT messageID, message "+ + "FROM message "+ + "WHERE messageID IN ("+strings.Repeat("?,", len(IDs)-1)+"?"+")", IDsInterface...) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return nil, err + } else if errors.Is(err, sql.ErrNoRows) { + return make(map[string]message.Message), nil + } + + messagesByID := make(map[string]message.Message, len(IDs)) + for rows.Next() { + var messageID string + var messageByte []byte + if err = rows.Scan(&messageID, &messageByte); err != nil { + return nil, err + } + + var msg message.Message + if err = json.Unmarshal(messageByte, &msg); err != nil { + return nil, err + } + messagesByID[messageID] = msg + } + + if err = rows.Err(); err != nil { + return nil, err + } + return messagesByID, nil +} + +// GetMessageByID returns a message by its ID. +func (s *SQLite) GetMessageByID(ID string) (message.Message, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var messageByte []byte + err := s.database.QueryRow(selectMessage, ID).Scan(&messageByte) + if err != nil { + return message.Message{}, err + } + + var msg message.Message + if err = json.Unmarshal(messageByte, &msg); err != nil { + return message.Message{}, err + } + return msg, nil +} + +// AddWitnessSignature stores a pending signature inside the SQLite database. +func (s *SQLite) AddWitnessSignature(messageID string, witness string, signature string) error { + dbLock.Lock() + defer dbLock.Unlock() + + tx, err := s.database.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + witnessSignature, err := json.Marshal(message.WitnessSignature{ + Witness: witness, + Signature: signature, + }) + if err != nil { + return err + } + + res, err := tx.Exec(updateMsg, witnessSignature, messageID) + if err != nil { + return err + } + changes, err := res.RowsAffected() + if err != nil { + return err + } + if changes == 0 { + _, err := tx.Exec(insertPendingSignatures, messageID, witness, signature) + if err != nil { + return err + } + } + return tx.Commit() +} + +// StoreChannel mainly used for testing and storing the root channel +func (s *SQLite) StoreChannel(channelPath, channelType, laoPath string) error { + dbLock.Lock() + defer dbLock.Unlock() + + _, err := s.database.Exec(insertChannel, channelPath, channelTypeToID[channelType], laoPath) + return err +} + +func (s *SQLite) GetAllChannels() ([]string, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + rows, err := s.database.Query(selectAllChannels) + if err != nil { + return nil, err + } + + var channels []string + for rows.Next() { + var channelPath string + if err = rows.Scan(&channelPath); err != nil { + return nil, err + } + channels = append(channels, channelPath) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return channels, nil +} + +//====================================================================================================================== +// QueryRepository interface implementation +//====================================================================================================================== + +// GetChannelType returns the type of the channelPath. +func (s *SQLite) GetChannelType(channelPath string) (string, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var channelType string + err := s.database.QueryRow(selectChannelType, channelPath).Scan(&channelType) + return channelType, err +} + +// GetAllMessagesFromChannel returns all the messages received + sent on a channel sorted by stored time. +func (s *SQLite) GetAllMessagesFromChannel(channelPath string) ([]message.Message, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + rows, err := s.database.Query(selectAllMessagesFromChannel, channelPath) + if err != nil { + return nil, err + } + + messages := make([]message.Message, 0) + for rows.Next() { + var messageByte []byte + if err = rows.Scan(&messageByte); err != nil { + return nil, err + } + var msg message.Message + if err = json.Unmarshal(messageByte, &msg); err != nil { + return nil, err + } + messages = append(messages, msg) + } + + if rows.Err() != nil { + return nil, err + } + + return messages, nil +} + +func (s *SQLite) GetResultForGetMessagesByID(params map[string][]string) (map[string][]message.Message, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var interfaces []interface{} + // isBaseChannel must be true + interfaces = append(interfaces, true) + for _, value := range params { + for _, v := range value { + interfaces = append(interfaces, v) + } + } + + if len(interfaces) == 1 { + return make(map[string][]message.Message), nil + } + + rows, err := s.database.Query("SELECT message, channelPath "+ + "FROM message JOIN channelMessage on message.messageID = channelMessage.messageID "+ + "WHERE isBaseChannel = ? "+ + "AND message.messageID IN ("+strings.Repeat("?,", len(interfaces)-2)+"?"+") ", interfaces...) + if err != nil { + return nil, err + } + + result := make(map[string][]message.Message) + for rows.Next() { + var messageByte []byte + var channelPath string + if err = rows.Scan(&messageByte, &channelPath); err != nil { + return nil, err + } + var msg message.Message + if err = json.Unmarshal(messageByte, &msg); err != nil { + return nil, err + } + result[channelPath] = append(result[channelPath], msg) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return result, nil +} + +func (s *SQLite) GetParamsHeartbeat() (map[string][]string, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + rows, err := s.database.Query(selectBaseChannelMessages, true) + if err != nil { + return nil, err + } + + result := make(map[string][]string) + for rows.Next() { + var channelPath string + var messageID string + if err = rows.Scan(&messageID, &channelPath); err != nil { + return nil, err + } + if _, ok := result[channelPath]; !ok { + result[channelPath] = make([]string, 0) + } + result[channelPath] = append(result[channelPath], messageID) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return result, nil +} + +func (s *SQLite) GetParamsForGetMessageByID(params map[string][]string) (map[string][]string, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var interfaces []interface{} + // isBaseChannel must be true + interfaces = append(interfaces, true) + for _, value := range params { + for _, v := range value { + interfaces = append(interfaces, v) + } + } + + if len(interfaces) == 1 { + return make(map[string][]string), nil + } + + rows, err := s.database.Query("SELECT message.messageID, channelPath "+ + "FROM message JOIN channelMessage on message.messageID = channelMessage.messageID "+ + "WHERE isBaseChannel = ? "+ + "AND message.messageID IN ("+strings.Repeat("?,", len(interfaces)-2)+"?"+") ", interfaces...) + if err != nil { + return nil, err + } + + result := make(map[string]struct{}) + for rows.Next() { + var messageID string + var channelPath string + if err = rows.Scan(&messageID, &channelPath); err != nil { + return nil, err + } + result[messageID] = struct{}{} + } + + if err = rows.Err(); err != nil { + return nil, err + } + + missingIDs := make(map[string][]string) + for channel, messageIDs := range params { + for _, messageID := range messageIDs { + if _, ok := result[messageID]; !ok { + missingIDs[channel] = append(missingIDs[channel], messageID) + } + } + } + return missingIDs, nil +} + +//====================================================================================================================== +// ChannelRepository interface implementation +//====================================================================================================================== + +func (s *SQLite) HasChannel(channelPath string) (bool, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var c string + err := s.database.QueryRow(selectChannelPath, channelPath).Scan(&c) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return false, nil + } else if err != nil && !errors.Is(err, sql.ErrNoRows) { + return false, err + } else { + return true, nil + } +} + +func (s *SQLite) HasMessage(messageID string) (bool, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var msgID string + err := s.database.QueryRow(selectMessageID, messageID).Scan(&msgID) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return false, nil + } else if err != nil && !errors.Is(err, sql.ErrNoRows) { + return false, err + } else { + return true, nil + } +} + +//====================================================================================================================== +// RootRepository interface implementation +//====================================================================================================================== + +func (s *SQLite) StoreLaoWithLaoGreet( + channels map[string]string, + laoPath string, + organizerPubBuf []byte, + msg, laoGreetMsg message.Message) error { + + dbLock.Lock() + defer dbLock.Unlock() + + tx, err := s.database.Begin() + if err != nil { + return err + } + + msgByte, err := json.Marshal(msg) + if err != nil { + return err + } + laoGreetMsgByte, err := json.Marshal(laoGreetMsg) + if err != nil { + return err + } + + messageData, err := base64.URLEncoding.DecodeString(msg.Data) + if err != nil { + return err + } + laoGreetData, err := base64.URLEncoding.DecodeString(laoGreetMsg.Data) + if err != nil { + return err + } + + storedTime := time.Now().UnixNano() + + for channel, channelType := range channels { + _, err = tx.Exec(insertChannel, channel, channelTypeToID[channelType], laoPath) + if err != nil { + return err + } + } + + _, err = tx.Exec(insertMessage, msg.MessageID, msgByte, messageData, storedTime) + if err != nil { + return err + } + _, err = tx.Exec(insertChannelMessage, "/root", msg.MessageID, true) + if err != nil { + return err + } + + _, err = tx.Exec(insertChannelMessage, laoPath, msg.MessageID, false) + if err != nil { + return err + } + + _, err = tx.Exec(insertPublicKey, laoPath, organizerPubBuf) + if err != nil { + return err + } + _, err = tx.Exec(insertMessage, laoGreetMsg.MessageID, laoGreetMsgByte, laoGreetData, storedTime) + if err != nil { + return err + } + _, err = tx.Exec(insertChannelMessage, laoPath, laoGreetMsg.MessageID, false) + if err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + + defer tx.Rollback() + return nil +} + +//====================================================================================================================== +// LaoRepository interface implementation +//====================================================================================================================== + +func (s *SQLite) GetOrganizerPubKey(laoPath string) (kyber.Point, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var organizerPubBuf []byte + err := s.database.QueryRow(selectPublicKey, laoPath).Scan(&organizerPubBuf) + if err != nil { + return nil, err + } + organizerPubKey := crypto.Suite.Point() + err = organizerPubKey.UnmarshalBinary(organizerPubBuf) + if err != nil { + return nil, err + } + return organizerPubKey, nil +} + +func (s *SQLite) GetRollCallState(channelPath string) (string, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var state string + err := s.database.QueryRow(selectLastRollCallMessage, messagedata.RollCallObject, channelPath).Scan(&state) + if err != nil { + return "", err + } + return state, nil +} + +func (s *SQLite) CheckPrevOpenOrReopenID(channel, nextID string) (bool, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var lastMsg []byte + var lastAction string + + err := s.database.QueryRow(selectLastRollCallMessageInList, channel, messagedata.RollCallObject, + messagedata.RollCallActionOpen, messagedata.RollCallActionReOpen).Scan(&lastMsg, &lastAction) + + if err != nil && errors.Is(err, sql.ErrNoRows) { + return false, nil + } else if err != nil { + return false, err + } + + switch lastAction { + case messagedata.RollCallActionOpen: + var rollCallOpen messagedata.RollCallOpen + err = json.Unmarshal(lastMsg, &rollCallOpen) + if err != nil { + return false, err + } + return rollCallOpen.UpdateID == nextID, nil + case messagedata.RollCallActionReOpen: + var rollCallReOpen messagedata.RollCallReOpen + err = json.Unmarshal(lastMsg, &rollCallReOpen) + if err != nil { + return false, err + } + return rollCallReOpen.UpdateID == nextID, nil + } + + return false, nil +} + +func (s *SQLite) CheckPrevCreateOrCloseID(channel, nextID string) (bool, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var lastMsg []byte + var lastAction string + + err := s.database.QueryRow(selectLastRollCallMessageInList, channel, messagedata.RollCallObject, + messagedata.RollCallActionCreate, messagedata.RollCallActionClose).Scan(&lastMsg, &lastAction) + + if err != nil && errors.Is(err, sql.ErrNoRows) { + return false, nil + } else if err != nil { + return false, err + } + + switch lastAction { + case messagedata.RollCallActionCreate: + var rollCallCreate messagedata.RollCallCreate + err = json.Unmarshal(lastMsg, &rollCallCreate) + if err != nil { + return false, err + } + return rollCallCreate.ID == nextID, nil + case messagedata.RollCallActionClose: + var rollCallClose messagedata.RollCallClose + err = json.Unmarshal(lastMsg, &rollCallClose) + if err != nil { + return false, err + } + return rollCallClose.UpdateID == nextID, nil + } + + return false, nil +} + +func (s *SQLite) GetLaoWitnesses(laoPath string) (map[string]struct{}, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var witnesses []string + err := s.database.QueryRow(selectLaoWitnesses, laoPath, messagedata.LAOObject, messagedata.LAOActionCreate).Scan(&witnesses) + if err != nil { + return nil, err + } + + var witnessesMap = make(map[string]struct{}) + for _, witness := range witnesses { + witnessesMap[witness] = struct{}{} + } + + return witnessesMap, nil +} + +func (s *SQLite) StoreRollCallClose(channels []string, laoPath string, msg message.Message) error { + dbLock.Lock() + defer dbLock.Unlock() + + tx, err := s.database.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + messageData, err := base64.URLEncoding.DecodeString(msg.Data) + if err != nil { + return err + } + + msgBytes, err := json.Marshal(msg) + if err != nil { + return err + } + + _, err = tx.Exec(insertMessage, msg.MessageID, msgBytes, messageData, time.Now().UnixNano()) + if err != nil { + return err + } + _, err = tx.Exec(insertChannelMessage, laoPath, msg.MessageID, true) + if err != nil { + return err + } + for _, channel := range channels { + _, err = tx.Exec(insertChannel, channel, channelTypeToID[ChirpType], laoPath) + if err != nil { + return err + } + } + err = tx.Commit() + if err != nil { + return err + } + return nil + +} + +func (s *SQLite) storeElectionHelper( + tx *sql.Tx, + storedTime int64, + laoPath, electionPath string, + electionPubKey kyber.Point, + electionSecretKey kyber.Scalar, + msg message.Message) error { + + msgBytes, err := json.Marshal(msg) + if err != nil { + return err + } + messageData, err := base64.URLEncoding.DecodeString(msg.Data) + if err != nil { + return err + } + + electionPubBuf, err := electionPubKey.MarshalBinary() + if err != nil { + return err + } + electionSecretBuf, err := electionSecretKey.MarshalBinary() + if err != nil { + return err + } + + _, err = tx.Exec(insertMessage, msg.MessageID, msgBytes, messageData, storedTime) + if err != nil { + return err + } + _, err = tx.Exec(insertChannelMessage, laoPath, msg.MessageID, true) + if err != nil { + return err + } + _, err = tx.Exec(insertChannel, electionPath, channelTypeToID[ElectionType], laoPath) + if err != nil { + return err + } + _, err = tx.Exec(insertChannelMessage, electionPath, msg.MessageID, false) + if err != nil { + return err + } + _, err = tx.Exec(insertKeys, electionPath, electionPubBuf, electionSecretBuf) + if err != nil { + return err + } + + return nil +} + +func (s *SQLite) StoreElection( + laoPath, electionPath string, + electionPubKey kyber.Point, + electionSecretKey kyber.Scalar, + msg message.Message) error { + + dbLock.Lock() + defer dbLock.Unlock() + + tx, err := s.database.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + storedTime := time.Now().UnixNano() + + err = s.storeElectionHelper(tx, storedTime, laoPath, electionPath, electionPubKey, electionSecretKey, msg) + if err != nil { + return err + } + + return tx.Commit() +} + +func (s *SQLite) StoreElectionWithElectionKey( + laoPath, electionPath string, + electionPubKey kyber.Point, + electionSecretKey kyber.Scalar, + msg, electionKeyMsg message.Message) error { + + dbLock.Lock() + defer dbLock.Unlock() + + tx, err := s.database.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + storedTime := time.Now().UnixNano() + + err = s.storeElectionHelper(tx, storedTime, laoPath, electionPath, electionPubKey, electionSecretKey, msg) + if err != nil { + return err + } + + electionKey, err := base64.URLEncoding.DecodeString(electionKeyMsg.Data) + if err != nil { + return err + } + electionKeyMsgBytes, err := json.Marshal(electionKeyMsg) + if err != nil { + return err + } + + _, err = tx.Exec(insertMessage, electionKeyMsg.MessageID, electionKeyMsgBytes, electionKey, storedTime) + if err != nil { + return err + } + _, err = tx.Exec(insertChannelMessage, electionPath, electionKeyMsg.MessageID, false) + if err != nil { + return err + } + + return tx.Commit() +} + +//====================================================================================================================== +// ElectionRepository interface implementation +//====================================================================================================================== + +func (s *SQLite) GetLAOOrganizerPubKey(electionPath string) (kyber.Point, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + tx, err := s.database.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() + + var electionPubBuf []byte + err = tx.QueryRow(selectLaoOrganizerKey, electionPath).Scan(&electionPubBuf) + if err != nil { + return nil, err + } + electionPubKey := crypto.Suite.Point() + err = electionPubKey.UnmarshalBinary(electionPubBuf) + if err != nil { + return nil, err + } + + err = tx.Commit() + if err != nil { + return nil, err + } + + return electionPubKey, nil +} + +func (s *SQLite) GetElectionSecretKey(electionPath string) (kyber.Scalar, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var electionSecretBuf []byte + err := s.database.QueryRow(selectSecretKey, electionPath).Scan(&electionSecretBuf) + if err != nil { + return nil, err + } + + electionSecretKey := crypto.Suite.Scalar() + err = electionSecretKey.UnmarshalBinary(electionSecretBuf) + if err != nil { + return nil, err + } + return electionSecretKey, nil +} + +func (s *SQLite) getElectionState(electionPath string) (string, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var state string + err := s.database.QueryRow(selectLastElectionMessage, electionPath, messagedata.ElectionObject, messagedata.VoteActionCastVote).Scan(&state) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return "", err + } + return state, nil +} + +func (s *SQLite) IsElectionStartedOrEnded(electionPath string) (bool, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + state, err := s.getElectionState(electionPath) + if err != nil { + return false, err + } + + return state == messagedata.ElectionActionOpen || state == messagedata.ElectionActionEnd, nil +} + +func (s *SQLite) IsElectionStarted(electionPath string) (bool, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + state, err := s.getElectionState(electionPath) + if err != nil { + return false, err + } + return state == messagedata.ElectionActionOpen, nil +} + +func (s *SQLite) IsElectionEnded(electionPath string) (bool, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + state, err := s.getElectionState(electionPath) + if err != nil { + return false, err + } + return state == messagedata.ElectionActionEnd, nil +} + +func (s *SQLite) GetElectionCreationTime(electionPath string) (int64, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var creationTime int64 + err := s.database.QueryRow(selectElectionCreationTime, electionPath, messagedata.ElectionObject, messagedata.ElectionActionSetup).Scan(&creationTime) + if err != nil { + return 0, err + } + return creationTime, nil +} + +func (s *SQLite) GetElectionType(electionPath string) (string, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var electionType string + err := s.database.QueryRow(selectElectionType, electionPath, messagedata.ElectionObject, messagedata.ElectionActionSetup).Scan(&electionType) + if err != nil { + return "", err + } + return electionType, nil +} + +func (s *SQLite) GetElectionAttendees(electionPath string) (map[string]struct{}, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var rollCallCloseBytes []byte + err := s.database.QueryRow(selectElectionAttendees, + electionPath, + messagedata.RollCallObject, + messagedata.RollCallActionClose, + messagedata.RollCallObject, + messagedata.RollCallActionClose, + ).Scan(&rollCallCloseBytes) + if err != nil { + return nil, err + } + + var rollCallClose messagedata.RollCallClose + err = json.Unmarshal(rollCallCloseBytes, &rollCallClose) + if err != nil { + return nil, err + } + + attendeesMap := make(map[string]struct{}) + for _, attendee := range rollCallClose.Attendees { + attendeesMap[attendee] = struct{}{} + } + return attendeesMap, nil +} + +func (s *SQLite) getElectionSetup(electionPath string, tx *sql.Tx) (messagedata.ElectionSetup, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var electionSetupBytes []byte + err := tx.QueryRow(selectElectionSetup, electionPath, messagedata.ElectionObject, messagedata.ElectionActionSetup).Scan(&electionSetupBytes) + if err != nil { + return messagedata.ElectionSetup{}, err + } + + var electionSetup messagedata.ElectionSetup + err = json.Unmarshal(electionSetupBytes, &electionSetup) + if err != nil { + return messagedata.ElectionSetup{}, err + } + return electionSetup, nil + +} + +func (s *SQLite) GetElectionQuestions(electionPath string) (map[string]types.Question, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + tx, err := s.database.Begin() + if err != nil { + return nil, err + + } + defer tx.Rollback() + + electionSetup, err := s.getElectionSetup(electionPath, tx) + if err != nil { + return nil, err + + } + questions, err := getQuestionsFromMessage(electionSetup) + if err != nil { + return nil, err + } + + err = tx.Commit() + if err != nil { + return nil, err + + } + return questions, nil +} + +func (s *SQLite) GetElectionQuestionsWithValidVotes(electionPath string) (map[string]types.Question, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + tx, err := s.database.Begin() + if err != nil { + return nil, err + } + defer tx.Rollback() + + electionSetup, err := s.getElectionSetup(electionPath, tx) + if err != nil { + return nil, err + } + questions, err := getQuestionsFromMessage(electionSetup) + if err != nil { + return nil, err + } + + rows, err := tx.Query(selectCastVotes, electionPath, messagedata.ElectionObject, messagedata.VoteActionCastVote) + if err != nil { + return nil, err + } + + for rows.Next() { + var voteBytes []byte + var msgID string + var sender string + if err = rows.Scan(&voteBytes, &msgID, &sender); err != nil { + return nil, err + } + var vote messagedata.VoteCastVote + err = json.Unmarshal(voteBytes, &vote) + if err != nil { + return nil, err + } + err = updateVote(msgID, sender, vote, questions) + if err != nil { + return nil, err + } + } + if err = rows.Err(); err != nil { + return nil, err + } + err = tx.Commit() + if err != nil { + return nil, err + } + return questions, nil +} + +func getQuestionsFromMessage(electionSetup messagedata.ElectionSetup) (map[string]types.Question, error) { + questions := make(map[string]types.Question) + for _, question := range electionSetup.Questions { + ballotOptions := make([]string, len(question.BallotOptions)) + copy(ballotOptions, question.BallotOptions) + _, ok := questions[question.ID] + if ok { + return nil, xerrors.Errorf("duplicate question ID") + } + questions[question.ID] = types.Question{ + ID: []byte(question.ID), + BallotOptions: ballotOptions, + ValidVotes: make(map[string]types.ValidVote), + Method: question.VotingMethod, + } + } + return questions, nil +} + +func updateVote(msgID, sender string, castVote messagedata.VoteCastVote, questions map[string]types.Question) error { + for idx, vote := range castVote.Votes { + question, ok := questions[vote.Question] + if !ok { + return xerrors.Errorf("question not found for vote number %d sent by %s", idx, sender) + } + earlierVote, ok := question.ValidVotes[sender] + if !ok || earlierVote.VoteTime < castVote.CreatedAt { + question.ValidVotes[sender] = types.ValidVote{ + MsgID: msgID, + ID: vote.ID, + VoteTime: castVote.CreatedAt, + Index: vote.Vote, + } + } + } + return nil +} + +func (s *SQLite) StoreElectionEndWithResult(channelPath string, msg, electionResultMsg message.Message) error { + dbLock.Lock() + defer dbLock.Unlock() + + tx, err := s.database.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + msgBytes, err := json.Marshal(msg) + if err != nil { + return err + } + messageData, err := base64.URLEncoding.DecodeString(msg.Data) + if err != nil { + return err + } + electionResult, err := base64.URLEncoding.DecodeString(electionResultMsg.Data) + if err != nil { + return err + } + electionResultMsgBytes, err := json.Marshal(electionResultMsg) + if err != nil { + return err + } + storedTime := time.Now().UnixNano() + + _, err = tx.Exec(insertMessage, msg.MessageID, msgBytes, messageData, storedTime) + if err != nil { + return err + } + _, err = tx.Exec(insertChannelMessage, channelPath, msg.MessageID, true) + if err != nil { + return err + } + _, err = tx.Exec(insertMessage, electionResultMsg.MessageID, electionResultMsgBytes, electionResult, storedTime) + if err != nil { + return err + } + _, err = tx.Exec(insertChannelMessage, channelPath, electionResultMsg.MessageID, false) + if err != nil { + return err + } + err = tx.Commit() + return err +} + +//====================================================================================================================== +// ChirpRepository interface implementation +//====================================================================================================================== + +func (s *SQLite) StoreChirpMessages(channel, generalChannel string, msg, generalMsg message.Message) error { + dbLock.Lock() + defer dbLock.Unlock() + + tx, err := s.database.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + msgBytes, err := json.Marshal(msg) + if err != nil { + return err + } + messageData, err := base64.URLEncoding.DecodeString(msg.Data) + if err != nil { + return err + } + generalMsgBytes, err := json.Marshal(generalMsg) + if err != nil { + return err + } + generalMessageData, err := base64.URLEncoding.DecodeString(generalMsg.Data) + if err != nil { + return err + } + storedTime := time.Now().UnixNano() + + _, err = tx.Exec(insertMessage, msg.MessageID, msgBytes, messageData, storedTime) + if err != nil { + return err + } + _, err = tx.Exec(insertChannelMessage, channel, msg.MessageID, true) + if err != nil { + return err + } + _, err = tx.Exec(insertMessage, generalMsg.MessageID, generalMsgBytes, generalMessageData, storedTime) + if err != nil { + return err + } + _, err = tx.Exec(insertChannelMessage, generalChannel, generalMsg.MessageID, false) + if err != nil { + return err + } + err = tx.Commit() + return err +} + +//====================================================================================================================== +// ReactionRepository interface implementation +//====================================================================================================================== + +func (s *SQLite) IsAttendee(laoPath, poptoken string) (bool, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var rollCallCloseBytes []byte + err := s.database.QueryRow(selectLastRollCallClose, laoPath, messagedata.RollCallObject, messagedata.RollCallActionClose).Scan(&rollCallCloseBytes) + if err != nil { + return false, err + } + + var rollCallClose messagedata.RollCallClose + err = json.Unmarshal(rollCallCloseBytes, &rollCallClose) + if err != nil { + return false, err + } + + for _, attendee := range rollCallClose.Attendees { + if attendee == poptoken { + return true, nil + } + } + + return false, nil +} + +func (s *SQLite) GetReactionSender(messageID string) (string, error) { + dbLock.RLock() + defer dbLock.RUnlock() + + var sender string + var object string + var action string + err := s.database.QueryRow(selectSender, messageID).Scan(&sender, &object, &action) + if err != nil && errors.Is(err, sql.ErrNoRows) { + return "", nil + } else if err != nil { + return "", err + + } + + if object != messagedata.ReactionObject || action != messagedata.ReactionActionAdd { + return "", xerrors.New("unexpected object or action") + } + return sender, nil +} diff --git a/be1-go/internal/popserver/database/sqlite/sqlite_const.go b/be1-go/internal/popserver/database/sqlite/sqlite_const.go new file mode 100644 index 0000000000..33d3092378 --- /dev/null +++ b/be1-go/internal/popserver/database/sqlite/sqlite_const.go @@ -0,0 +1,295 @@ +package sqlite + +const ( + DefaultPath = "sqlite.db" + serverKeysPath = "server_keys" +) + +const ( + RootType = "root" + LaoType = "lao" + ElectionType = "election" + ChirpType = "chirp" + ReactionType = "reaction" + ConsensusType = "consensus" + CoinType = "coin" + AuthType = "auth" + PopChaType = "popcha" + GeneralChirpType = "generalChirp" +) + +var channelTypeToID = map[string]string{ + RootType: "1", + LaoType: "2", + ElectionType: "3", + ChirpType: "4", + ReactionType: "5", + ConsensusType: "6", + PopChaType: "7", + CoinType: "8", + AuthType: "9", + GeneralChirpType: "10", +} + +var channelTypes = []string{ + RootType, + LaoType, + ElectionType, + ChirpType, + ReactionType, + ConsensusType, + PopChaType, + CoinType, + AuthType, + GeneralChirpType, +} + +const ( + createMessage = ` + CREATE TABLE IF NOT EXISTS message ( + messageID TEXT, + message TEXT, + messageData TEXT NULL, + storedTime BIGINT, + PRIMARY KEY (messageID) + )` + + createChannelType = ` + CREATE TABLE IF NOT EXISTS channelType ( + ID INTEGER, + type TEXT, + PRIMARY KEY (ID) + )` + + createChannel = ` + CREATE TABLE IF NOT EXISTS channel ( + channelPath TEXT, + typeID TEXT, + laoPath TEXT NULL, + FOREIGN KEY (laoPath) REFERENCES channel(channelPath), + FOREIGN KEY (typeID) REFERENCES channelType(ID), + PRIMARY KEY (channelPath) + )` + + createKey = ` + CREATE TABLE IF NOT EXISTS key ( + channelPath TEXT, + publicKey TEXT, + secretKey TEXT NULL, + FOREIGN KEY (channelPath) REFERENCES channel(channelPath), + PRIMARY KEY (channelPath) + )` + + createChannelMessage = ` + CREATE TABLE IF NOT EXISTS channelMessage ( + channelPath TEXT, + messageID TEXT, + isBaseChannel BOOLEAN, + FOREIGN KEY (messageID) REFERENCES message(messageID), + FOREIGN KEY (channelPath) REFERENCES channel(channelPath), + PRIMARY KEY (channelPath, messageID) + )` + + createPendingSignatures = ` + CREATE TABLE IF NOT EXISTS pendingSignatures ( + messageID TEXT, + witness TEXT, + signature TEXT UNIQUE, + PRIMARY KEY (messageID, witness) + )` +) + +const ( + insertChannelMessage = `INSERT INTO channelMessage (channelPath, messageID, isBaseChannel) VALUES (?, ?, ?)` + insertMessage = `INSERT INTO message (messageID, message, messageData, storedTime) VALUES (?, ?, ?, ?)` + insertChannel = `INSERT INTO channel (channelPath, typeID, laoPath) VALUES (?, ?, ?)` + insertOrIgnoreChannel = `INSERT OR IGNORE INTO channel (channelPath, typeID, laoPath) VALUES (?, ?, ?)` + insertKeys = `INSERT INTO key (channelPath, publicKey, secretKey) VALUES (?, ?, ?)` + insertPublicKey = `INSERT INTO key (channelPath, publicKey) VALUES (?, ?)` + insertPendingSignatures = `INSERT INTO pendingSignatures (messageID, witness, signature) VALUES (?, ?, ?)` +) + +const ( + selectKeys = `SELECT publicKey, secretKey FROM key WHERE channelPath = ?` + + selectPublicKey = `SELECT publicKey FROM key WHERE channelPath = ?` + + selectSecretKey = `SELECT secretKey FROM key WHERE channelPath = ?` + + selectPendingSignatures = `SELECT witness, signature FROM pendingSignatures WHERE messageID = ?` + + selectMessage = `SELECT message FROM message WHERE messageID = ?` + + selectAllChannels = `SELECT channelPath FROM channel` + + selectChannelType = `SELECT type FROM channelType JOIN channel on channel.typeID = channelType.ID WHERE channelPath = ?` + + selectAllMessagesFromChannel = ` + SELECT message.message + FROM message + JOIN channelMessage ON message.messageID = channelMessage.messageID + WHERE channelMessage.channelPath = ? + ORDER BY message.storedTime DESC` + + selectBaseChannelMessages = `SELECT messageID, channelPath FROM channelMessage WHERE isBaseChannel = ?` + + selectChannelPath = `SELECT channelPath FROM channel WHERE channelPath = ?` + + selectMessageID = `SELECT messageID FROM message WHERE messageID = ?` + + selectLastRollCallMessage = ` + SELECT json_extract(messageData, '$.action') + FROM message + WHERE storedTime = ( + SELECT MAX(storedTime) + FROM ( + SELECT * + FROM message + JOIN channelMessage ON message.messageID = channelMessage.messageID + ) + WHERE json_extract(messageData, '$.object') = ? AND channelPath = ? + )` + + selectLastRollCallMessageInList = ` + SELECT message.messageData, json_extract(message.messageData, '$.action') + FROM message + JOIN channelMessage ON message.messageID = channelMessage.messageID + WHERE channelMessage.channelPath = ? + AND json_extract(message.messageData, '$.object') = ? + AND json_extract(message.messageData, '$.action') IN (?, ?) + ORDER BY message.storedTime DESC + LIMIT 1` + + selectLaoWitnesses = ` + SELECT json_extract(messageData, '$.witnesses') + FROM ( + SELECT * + FROM message + JOIN channelMessage ON message.messageID = channelMessage.messageID + ) + WHERE channelPath = ? + AND json_extract(messageData, '$.object') = ? + AND json_extract(messageData, '$.action') = ?` + + selectLaoOrganizerKey = ` + SELECT publicKey + FROM key + WHERE channelPath = ( + SELECT laoPath + FROM channel + WHERE channelPath = ? + ) +` + + selectLastElectionMessage = ` + SELECT json_extract(messageData, '$.action') + FROM message + WHERE storedTime = ( + SELECT MAX(storedTime) + FROM ( + SELECT * + FROM message + JOIN channelMessage ON message.messageID = channelMessage.messageID + ) + WHERE channelPath = ? + AND json_extract(messageData, '$.object') = ? + AND json_extract(messageData, '$.action') != ? + )` + + selectElectionCreationTime = ` + SELECT json_extract(messageData, '$.created_at') + FROM ( + SELECT * + FROM message + JOIN channelMessage ON message.messageID = channelMessage.messageID + ) + WHERE channelPath = ? + AND json_extract(messageData, '$.object') = ? + AND json_extract(messageData, '$.action') = ?` + + selectElectionType = ` + SELECT json_extract(messageData, '$.version') + FROM ( + SELECT * + FROM message + JOIN channelMessage ON message.messageID = channelMessage.messageID + ) + WHERE channelPath = ? + AND json_extract(messageData, '$.object') = ? + AND json_extract(messageData, '$.action') = ?` + + selectElectionAttendees = ` + SELECT joined.messageData + FROM ( + SELECT * + FROM message + JOIN channelMessage ON message.messageID = channelMessage.messageID + ) joined + JOIN channel c ON joined.channelPath = c.laoPath + WHERE c.channelPath = ? + AND json_extract(joined.messageData, '$.object') = ? + AND json_extract(joined.messageData, '$.action') = ? + AND joined.storedTime = ( + SELECT MAX(storedTime) + FROM ( + SELECT * + FROM message + JOIN channelMessage ON message.messageID = channelMessage.messageID + ) + WHERE channelPath = c.laoPath + AND json_extract(messageData, '$.object') = ? + AND json_extract(messageData, '$.action') = ? + )` + + selectElectionSetup = ` + SELECT messageData + FROM ( + SELECT * + FROM message + JOIN channelMessage ON message.messageID = channelMessage.messageID + ) + WHERE channelPath = ? + AND json_extract(messageData, '$.object') = ? + AND json_extract(messageData, '$.action') = ?` + + selectCastVotes = ` + SELECT messageData, messageID, json_extract(message, '$.sender') + FROM ( + SELECT * + FROM message + JOIN channelMessage ON message.messageID = channelMessage.messageID + ) + WHERE channelPath = ? + AND json_extract(messageData, '$.object') = ? + AND json_extract(messageData, '$.action') = ?` + + selectLastRollCallClose = ` + SELECT messageData + FROM message + WHERE storedTime = ( + SELECT MAX(storedTime) + FROM ( + SELECT * + FROM message + JOIN channelMessage ON message.messageID = channelMessage.messageID + ) + WHERE channelPath = ? + AND json_extract(messageData, '$.object') = ? + AND json_extract(messageData, '$.action') = ? + )` + + selectSender = ` + SELECT json_extract(message, '$.sender'), + json_extract(messageData, '$.object'), + json_extract(messageData, '$.action') + FROM message + WHERE messageID = ?` +) + +const ( + deletePendingSignatures = `DELETE FROM pendingSignatures WHERE messageID = ?` +) + +const ( + updateMsg = `UPDATE OR IGNORE message SET message = json_insert(message,'$.witness_signatures[#]', json(?)) WHERE messageID = ?` +) diff --git a/be1-go/internal/popserver/database/sqlite/sqlite_init.go b/be1-go/internal/popserver/database/sqlite/sqlite_init.go new file mode 100644 index 0000000000..cbbf0a0ed8 --- /dev/null +++ b/be1-go/internal/popserver/database/sqlite/sqlite_init.go @@ -0,0 +1,119 @@ +package sqlite + +import ( + "database/sql" + database2 "popstellar/internal/popserver/database/repository" + "sync" +) + +var dbLock sync.RWMutex + +// SQLite is a wrapper around the SQLite database. +type SQLite struct { + database2.Repository + database *sql.DB +} + +//====================================================================================================================== +// Database initialization +//====================================================================================================================== + +// NewSQLite returns a new SQLite instance. +func NewSQLite(path string, foreignKeyOn bool) (SQLite, error) { + dbLock.Lock() + defer dbLock.Unlock() + + db, err := sql.Open("sqlite", path) + if err != nil { + return SQLite{}, err + } + + if !foreignKeyOn { + _, err = db.Exec("PRAGMA foreign_keys = OFF;") + if err != nil { + db.Close() + return SQLite{}, err + } + } + + tx, err := db.Begin() + if err != nil { + db.Close() + return SQLite{}, err + } + defer tx.Rollback() + + _, err = tx.Exec(createMessage) + if err != nil { + db.Close() + return SQLite{}, err + } + + _, err = tx.Exec(createChannelType) + if err != nil { + db.Close() + return SQLite{}, err + } + + err = fillChannelTypes(tx) + if err != nil { + db.Close() + return SQLite{}, err + } + + _, err = tx.Exec(createKey) + if err != nil { + db.Close() + return SQLite{}, err + } + + _, err = tx.Exec(createChannel) + if err != nil { + db.Close() + return SQLite{}, err + } + + _, err = tx.Exec(insertOrIgnoreChannel, "/root", channelTypeToID[RootType], "") + if err != nil { + db.Close() + return SQLite{}, err + } + + _, err = tx.Exec(createChannelMessage) + if err != nil { + db.Close() + return SQLite{}, err + } + + _, err = tx.Exec(createPendingSignatures) + if err != nil { + db.Close() + return SQLite{}, err + } + + err = tx.Commit() + if err != nil { + db.Close() + return SQLite{}, err + } + + return SQLite{database: db}, nil +} + +// Close closes the SQLite database. +func (s *SQLite) Close() error { + dbLock.Lock() + defer dbLock.Unlock() + + return s.database.Close() +} + +func fillChannelTypes(tx *sql.Tx) error { + for _, channelType := range channelTypes { + _, err := tx.Exec("INSERT INTO channelType (type) VALUES (?)", channelType) + if err != nil { + return err + } + } + return nil +} diff --git a/be1-go/internal/popserver/database/sqlite/sqlite_test.go b/be1-go/internal/popserver/database/sqlite/sqlite_test.go new file mode 100644 index 0000000000..c570d274ef --- /dev/null +++ b/be1-go/internal/popserver/database/sqlite/sqlite_test.go @@ -0,0 +1,828 @@ +package sqlite + +import ( + "encoding/base64" + "encoding/json" + "github.com/stretchr/testify/require" + "os" + "path/filepath" + "popstellar/crypto" + "popstellar/internal/popserver/generatortest" + "popstellar/internal/popserver/types" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "sort" + "testing" +) + +//====================================================================================================================== +// Repository interface implementation tests +//====================================================================================================================== + +func Test_SQLite_GetMessageByID(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + + defer lite.Close() + defer os.RemoveAll(dir) + + testMessages := newTestMessages() + for _, m := range testMessages { + err = lite.StoreMessageAndData(m.channel, m.msg) + require.NoError(t, err) + } + + expected := []message.Message{testMessages[0].msg, + testMessages[1].msg, + testMessages[2].msg, + testMessages[3].msg} + IDs := []string{"ID1", "ID2", "ID3", "ID4"} + for i, elem := range IDs { + msg, err := lite.GetMessageByID(elem) + require.NoError(t, err) + require.Equal(t, expected[i], msg) + } +} + +func Test_SQLite_GetMessagesByID(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + testMessages := newTestMessages() + for _, m := range testMessages { + err = lite.StoreMessageAndData(m.channel, m.msg) + require.NoError(t, err) + } + + IDs := []string{"ID1", "ID2", "ID3", "ID4"} + expected := map[string]message.Message{"ID1": testMessages[0].msg, + "ID2": testMessages[1].msg, + "ID3": testMessages[2].msg, + "ID4": testMessages[3].msg} + + messages, err := lite.GetMessagesByID(IDs) + require.NoError(t, err) + require.Equal(t, expected, messages) +} + +func Test_SQLite_AddWitnessSignature(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + testMessages := newTestMessages() + for _, m := range testMessages { + err = lite.StoreMessageAndData(m.channel, m.msg) + require.NoError(t, err) + } + // Add signatures to message1 + expected := []message.WitnessSignature{{Witness: "witness1", Signature: "sig1"}, {Witness: "witness2", Signature: "sig2"}} + err = lite.AddWitnessSignature("ID1", "witness1", "sig1") + require.NoError(t, err) + err = lite.AddWitnessSignature("ID1", "witness2", "sig2") + require.NoError(t, err) + + //Verify that the signature have been added to the message + msg1, err := lite.GetMessageByID("ID1") + require.NoError(t, err) + require.Equal(t, expected, msg1.WitnessSignatures) + + message4 := message.Message{Data: base64.URLEncoding.EncodeToString([]byte("data4")), + Sender: "sender4", + Signature: "sig4", + MessageID: "ID5", + WitnessSignatures: []message.WitnessSignature{}, + } + + // Add signatures to message4 who is not currently stored + err = lite.AddWitnessSignature("ID5", "witness2", "sig3") + require.NoError(t, err) + + //Verify that the signature has been added to the message + err = lite.StoreMessageAndData("channel1", message4) + require.NoError(t, err) + expected = []message.WitnessSignature{{Witness: "witness2", Signature: "sig3"}} + msg4, err := lite.GetMessageByID("ID5") + require.NoError(t, err) + require.Equal(t, expected, msg4.WitnessSignatures) +} + +//====================================================================================================================== +// Helper functions +//====================================================================================================================== + +func newFakeSQLite(t *testing.T) (SQLite, string, error) { + dir, err := os.MkdirTemp("", "test-") + require.NoError(t, err) + + fn := filepath.Join(dir, "test.DB") + lite, err := NewSQLite(fn, false) + require.NoError(t, err) + + return lite, dir, nil +} + +type testMessage struct { + msg message.Message + channel string +} + +func newTestMessages() []testMessage { + message1 := message.Message{Data: base64.URLEncoding.EncodeToString([]byte("data1")), + Sender: "sender1", + Signature: "sig1", + MessageID: "ID1", + WitnessSignatures: []message.WitnessSignature{}, + } + + message2 := message.Message{Data: base64.URLEncoding.EncodeToString([]byte("data2")), + Sender: "sender2", + Signature: "sig2", + MessageID: "ID2", + WitnessSignatures: []message.WitnessSignature{}, + } + + message3 := message.Message{Data: base64.URLEncoding.EncodeToString([]byte("data3")), + Sender: "sender3", + Signature: "sig3", + MessageID: "ID3", + WitnessSignatures: []message.WitnessSignature{}, + } + message4 := message3 + message4.MessageID = "ID4" + + return []testMessage{{msg: message1, channel: "channel1"}, + {msg: message2, channel: "channel2"}, + {msg: message3, channel: "channel1/subChannel1"}, + {msg: message4, channel: "channel1"}, + } +} + +func Test_SQLite_GetAllMessagesFromChannel(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + testMessages := newTestMessages() + for _, m := range testMessages { + err = lite.StoreMessageAndData(m.channel, m.msg) + require.NoError(t, err) + } + + expected := []message.Message{testMessages[3].msg, testMessages[0].msg} + messages, err := lite.GetAllMessagesFromChannel("channel1") + require.NoError(t, err) + require.Equal(t, expected, messages) +} + +func Test_SQLite_GetChannelType(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + err = lite.StoreChannel("channel1", "root", "") + require.NoError(t, err) + + channelType, err := lite.GetChannelType("channel1") + require.NoError(t, err) + require.Equal(t, "root", channelType) +} + +func Test_SQLite_GetResultForGetMessagesByID(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + testMessages := newTestMessages() + for _, m := range testMessages { + err = lite.StoreMessageAndData(m.channel, m.msg) + require.NoError(t, err) + } + + expected := map[string][]message.Message{ + "channel1": {testMessages[0].msg, testMessages[3].msg}, + "channel2": {testMessages[1].msg}, + "channel1/subChannel1": {testMessages[2].msg}} + params := map[string][]string{ + "channel1": {"ID1", "ID4"}, + "channel2": {"ID2"}, + "channel1/subChannel1": {"ID3"}} + result, err := lite.GetResultForGetMessagesByID(params) + require.NoError(t, err) + require.Equal(t, expected, result) +} + +func Test_SQLite_GetParamsForGetMessageByID(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + testMessages := newTestMessages() + for _, m := range testMessages { + err = lite.StoreMessageAndData(m.channel, m.msg) + require.NoError(t, err) + } + params := map[string][]string{ + "channel1": {"other_ID1", "other_ID4", "ID1", "ID4"}, + "channel2": {"other_ID2", "ID2"}, + "channel1/subChannel1": {"other_ID3", "ID3"}, + "other_channel": {"other_ID5", "other_ID6"}} + + expected := map[string][]string{ + "channel1": {"other_ID1", "other_ID4"}, + "channel2": {"other_ID2"}, + "channel1/subChannel1": {"other_ID3"}, + "other_channel": {"other_ID5", "other_ID6"}} + result, err := lite.GetParamsForGetMessageByID(params) + require.NoError(t, err) + require.Equal(t, expected, result) +} + +func Test_SQLite_HasChannel(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + err = lite.StoreChannel( + "channel1", + "root", + "") + require.NoError(t, err) + + ok, err := lite.HasChannel("channel1") + require.NoError(t, err) + require.True(t, ok) + + ok, err = lite.HasChannel("channel2") + require.NoError(t, err) + require.False(t, ok) +} + +func TestSQLite_HasMessage(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + message5 := message.Message{Data: base64.URLEncoding.EncodeToString([]byte("data5")), + Sender: "sender5", + Signature: "sig5", + MessageID: "ID5", + WitnessSignatures: []message.WitnessSignature{}, + } + + err = lite.StoreMessageAndData("channel1", message5) + require.NoError(t, err) + + ok, err := lite.HasMessage("ID5") + require.NoError(t, err) + require.True(t, ok) + + ok, err = lite.HasMessage("ID1") + require.NoError(t, err) + require.False(t, ok) +} + +func Test_SQLite_StoreLaoWithLaoGreet(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + channels := map[string]string{ + "laoPath": "lao", + "channel1": "chirp", + "channel2": "coin", + "channel3": "auth", + "channel4": "consensus", + "channel5": "reaction"} + + secret := crypto.Suite.Scalar().Pick(crypto.Suite.RandomStream()) + point := crypto.Suite.Point().Mul(secret, nil) + organizerPubKey := point + organizerPubBuf, err := organizerPubKey.MarshalBinary() + require.NoError(t, err) + + organizerPubBuf64 := base64.URLEncoding.EncodeToString(organizerPubBuf) + + laoID := "laoID" + + laoCreateMsg := generatortest.NewLaoCreateMsg(t, "sender1", laoID, "laoName", 123456789, + organizerPubBuf64, nil) + + laoGreet := messagedata.LaoGreet{ + Object: "lao", + Action: "greet", + LaoID: laoID, + Frontend: "frontend", + Address: "address", + Peers: []messagedata.Peer{{Address: "peer1"}, {Address: "peer2"}}, + } + laoGreetBytes, err := json.Marshal(laoGreet) + require.NoError(t, err) + + laoGreetMsg := message.Message{Data: base64.URLEncoding.EncodeToString(laoGreetBytes), + Sender: "sender2", + Signature: "sig2", + MessageID: "ID2", + WitnessSignatures: []message.WitnessSignature{}} + + err = lite.StoreLaoWithLaoGreet(channels, laoID, organizerPubBuf, laoCreateMsg, laoGreetMsg) + require.NoError(t, err) + + expected := []message.Message{laoGreetMsg, laoCreateMsg} + + sort.Slice(expected, func(i, j int) bool { + return expected[i].MessageID < expected[j].MessageID + }) + messages, err := lite.GetAllMessagesFromChannel(laoID) + require.NoError(t, err) + + sort.Slice(expected, func(i, j int) bool { + return messages[i].MessageID < messages[j].MessageID + }) + require.Equal(t, expected, messages) + + expected = []message.Message{laoCreateMsg} + messages, err = lite.GetAllMessagesFromChannel("/root") + require.NoError(t, err) + require.Equal(t, expected, messages) + + for channel, expectedType := range channels { + ok, err := lite.HasChannel(channel) + require.NoError(t, err) + require.True(t, ok) + channelType, err := lite.GetChannelType(channel) + require.NoError(t, err) + require.Equal(t, expectedType, channelType) + } + + returnedKey, err := lite.GetOrganizerPubKey(laoID) + require.NoError(t, err) + organizerPubKey.Equal(returnedKey) + require.True(t, organizerPubKey.Equal(returnedKey)) + + // Test that we can retrieve the organizer public key from the election channel + electionPath := "electionID" + err = lite.StoreChannel(electionPath, "election", laoID) + require.NoError(t, err) + returnedKey, err = lite.GetLAOOrganizerPubKey(electionPath) + require.NoError(t, err) + require.True(t, organizerPubKey.Equal(returnedKey)) + +} + +func Test_SQLite_GetRollCallState(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + rollCallCreate := generatortest.NewRollCallCreateMsg(t, "sender1", "name", "createID", 1, 2, 10, nil) + rollCallOpen := generatortest.NewRollCallOpenMsg(t, "sender1", "openID", "createID", 4, nil) + rollCallClose := generatortest.NewRollCallCloseMsg(t, "sender1", "closeID", "openID", 8, nil, nil) + states := []string{"create", "open", "close"} + messages := []message.Message{rollCallCreate, rollCallOpen, rollCallClose} + + for i, msg := range messages { + err = lite.StoreMessageAndData("channel1", msg) + require.NoError(t, err) + state, err := lite.GetRollCallState("channel1") + require.NoError(t, err) + require.Equal(t, states[i], state) + } +} + +func Test_SQLite_CheckPrevOpenOrReopenID(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + rollCallOpen := generatortest.NewRollCallOpenMsg(t, "sender1", "openID", "createID", 4, nil) + rollCallReopen := generatortest.NewRollCallReOpenMsg(t, "sender1", "reopenID", "closeID", 12, nil) + + err = lite.StoreMessageAndData("channel1", rollCallOpen) + require.NoError(t, err) + + ok, err := lite.CheckPrevOpenOrReopenID("channel1", "openID") + require.NoError(t, err) + require.True(t, ok) + + err = lite.StoreMessageAndData("channel1", rollCallReopen) + require.NoError(t, err) + + ok, err = lite.CheckPrevOpenOrReopenID("channel1", "reopenID") + require.NoError(t, err) + require.True(t, ok) +} + +func Test_SQLite_CheckPrevCreateOrCloseID(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + rollCallCreate := generatortest.NewRollCallCreateMsg(t, "sender1", "name", "createID", 1, 2, 10, nil) + rollCallClose := generatortest.NewRollCallCloseMsg(t, "sender1", "closeID", "openID", 8, nil, nil) + + err = lite.StoreMessageAndData("channel1", rollCallCreate) + require.NoError(t, err) + + ok, err := lite.CheckPrevCreateOrCloseID("channel1", "createID") + require.NoError(t, err) + require.True(t, ok) + + err = lite.StoreMessageAndData("channel1", rollCallClose) + require.NoError(t, err) + + ok, err = lite.CheckPrevCreateOrCloseID("channel1", "closeID") + require.NoError(t, err) + require.True(t, ok) +} + +func Test_SQLite_StoreRollCallClose(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + channels := []string{"channel1", "channel2", "channel3"} + laoID := "laoID" + + rollCallClose := generatortest.NewRollCallCloseMsg(t, "sender1", "closeID", "openID", 8, nil, nil) + + err = lite.StoreRollCallClose(channels, laoID, rollCallClose) + require.NoError(t, err) + + expected := []message.Message{rollCallClose} + messages, err := lite.GetAllMessagesFromChannel(laoID) + require.NoError(t, err) + require.Equal(t, expected, messages) + + for _, channel := range channels { + ok, err := lite.HasChannel(channel) + require.NoError(t, err) + require.True(t, ok) + } +} + +func Test_SQLite_StoreElectionWithElectionKey(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + laoID := "laoID" + electionID := "electionID" + secret := crypto.Suite.Scalar().Pick(crypto.Suite.RandomStream()) + point := crypto.Suite.Point().Mul(secret, nil) + + electionPubBuf, err := point.MarshalBinary() + require.NoError(t, err) + + electionSetupMsg := generatortest.NewElectionSetupMsg(t, "sender1", "ID1", laoID, "electionName", + "version", 1, 2, 3, nil, nil) + + electionKey := messagedata.ElectionKey{ + Object: "election", + Action: "key", + Election: electionID, + Key: base64.URLEncoding.EncodeToString(electionPubBuf), + } + + electionKeyBytes, err := json.Marshal(electionKey) + require.NoError(t, err) + + electionKeyMsg := message.Message{ + Data: base64.URLEncoding.EncodeToString(electionKeyBytes), + Sender: "sender1", + Signature: "sig1", + MessageID: "ID2", + WitnessSignatures: []message.WitnessSignature{}, + } + + err = lite.StoreElectionWithElectionKey(laoID, electionID, point, secret, electionSetupMsg, electionKeyMsg) + require.NoError(t, err) + + expected := []message.Message{electionSetupMsg} + messages, err := lite.GetAllMessagesFromChannel(laoID) + require.NoError(t, err) + require.Equal(t, expected, messages) + + expected = []message.Message{electionKeyMsg, electionSetupMsg} + messages, err = lite.GetAllMessagesFromChannel(electionID) + require.NoError(t, err) + require.Equal(t, expected, messages) + + returnedSecretKey, err := lite.GetElectionSecretKey(electionID) + require.NoError(t, err) + require.True(t, secret.Equal(returnedSecretKey)) +} + +func Test_SQLite_StoreElection(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + laoID := "laoID" + electionID := "electionID" + secret := crypto.Suite.Scalar().Pick(crypto.Suite.RandomStream()) + point := crypto.Suite.Point().Mul(secret, nil) + + electionSetupMsg := generatortest.NewElectionSetupMsg(t, "sender1", "ID1", laoID, "electionName", + "version", 1, 2, 3, nil, nil) + + err = lite.StoreElection(laoID, electionID, point, secret, electionSetupMsg) + require.NoError(t, err) + + expected := []message.Message{electionSetupMsg} + messages, err := lite.GetAllMessagesFromChannel(laoID) + require.NoError(t, err) + require.Equal(t, expected, messages) + + expected = []message.Message{electionSetupMsg} + messages, err = lite.GetAllMessagesFromChannel(electionID) + require.NoError(t, err) + require.Equal(t, expected, messages) + + returnedSecretKey, err := lite.GetElectionSecretKey(electionID) + require.NoError(t, err) + require.True(t, secret.Equal(returnedSecretKey)) +} + +func Test_SQLite_IsElectionStartedOrTerminated(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + electionPath := "electionPath" + electionID := "electionID" + laoID := "laoID" + ok, err := lite.IsElectionStartedOrEnded(electionPath) + require.NoError(t, err) + require.False(t, ok) + + electionOpenMsg := generatortest.NewElectionOpenMsg(t, "sender1", laoID, electionID, 1, nil) + + err = lite.StoreMessageAndData(electionID, electionOpenMsg) + require.NoError(t, err) + ok, err = lite.IsElectionStartedOrEnded(electionID) + require.NoError(t, err) + require.True(t, ok) + + ok, err = lite.IsElectionStarted(electionID) + require.NoError(t, err) + require.True(t, ok) + + ok, err = lite.IsElectionEnded(electionID) + require.NoError(t, err) + require.False(t, ok) + + electionCloseMsg := generatortest.NewElectionCloseMsg(t, "sender1", laoID, electionID, "", 1, nil) + + err = lite.StoreMessageAndData(electionID, electionCloseMsg) + require.NoError(t, err) + ok, err = lite.IsElectionStartedOrEnded(electionID) + require.NoError(t, err) + require.True(t, ok) + + ok, err = lite.IsElectionEnded(electionID) + require.NoError(t, err) + require.True(t, ok) + + ok, err = lite.IsElectionStarted(electionID) + require.NoError(t, err) + require.False(t, ok) +} + +func Test_SQLite_GetElectionCreationTimeAndType(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + laoPath := "laoPath" + electionPath := "electionPath" + creationTime := int64(123456789) + + electionSetupMsg := generatortest.NewElectionSetupMsg(t, "sender1", "ID1", laoPath, "electionName", + messagedata.OpenBallot, creationTime, 2, 3, nil, nil) + + err = lite.StoreMessageAndData(electionPath, electionSetupMsg) + require.NoError(t, err) + + returnedTime, err := lite.GetElectionCreationTime(electionPath) + require.NoError(t, err) + require.Equal(t, creationTime, returnedTime) + + electionType, err := lite.GetElectionType(electionPath) + require.NoError(t, err) + require.Equal(t, messagedata.OpenBallot, electionType) +} + +func Test_SQLite_GetElectionAttendees(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + electionID := "electionID" + laoID := "laoID" + attendees := []string{"attendee1", "attendee2", "attendee3"} + expected := map[string]struct{}{"attendee1": {}, "attendee2": {}, "attendee3": {}} + + rollCallCloseMsg := generatortest.NewRollCallCloseMsg(t, "sender1", "closeID", "openID", 8, attendees, nil) + + err = lite.StoreMessageAndData(laoID, rollCallCloseMsg) + require.NoError(t, err) + + err = lite.StoreChannel(electionID, "election", laoID) + require.NoError(t, err) + + returnedAttendees, err := lite.GetElectionAttendees(electionID) + require.NoError(t, err) + require.Equal(t, expected, returnedAttendees) +} + +func Test_SQLite_GetElectionQuestionsWithVotes(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + electionPath := "electionPath" + laoPath := "laoPath" + laoID := "laoID" + electionID := "electionID" + questions := []messagedata.ElectionSetupQuestion{ + { + ID: "questionID1", + Question: "question1", + VotingMethod: "Plurality", + BallotOptions: []string{"Option1", "Option2"}, + }, + } + + electionSetupMsg := generatortest.NewElectionSetupMsg(t, "sender1", "ID1", laoPath, "electionName", + messagedata.OpenBallot, 1, 2, 3, questions, nil) + + err = lite.StoreMessageAndData(electionPath, electionSetupMsg) + require.NoError(t, err) + + data64, err := base64.URLEncoding.DecodeString(electionSetupMsg.Data) + require.NoError(t, err) + + var electionSetup messagedata.ElectionSetup + err = json.Unmarshal(data64, &electionSetup) + require.NoError(t, err) + + expected, err := getQuestionsFromMessage(electionSetup) + require.NoError(t, err) + + // Add votes to the election + vote1 := generatortest.VoteString{ID: "voteID1", Question: "questionID1", Vote: "Option1"} + votes := []generatortest.VoteString{vote1} + castVoteMsg := generatortest.NewVoteCastVoteStringMsg(t, "sender1", laoID, electionID, + 1, votes, nil) + + err = lite.StoreMessageAndData(electionPath, castVoteMsg) + require.NoError(t, err) + + question1 := expected["questionID1"] + question1.ValidVotes = map[string]types.ValidVote{ + "sender1": {MsgID: castVoteMsg.MessageID, ID: "voteID1", VoteTime: 1, Index: "Option1"}, + } + expected["questionID1"] = question1 + + result, err := lite.GetElectionQuestionsWithValidVotes(electionPath) + require.NoError(t, err) + require.Equal(t, expected, result) + + // Add more votes to the election + vote2 := generatortest.VoteString{ID: "voteID2", Question: "questionID1", Vote: "Option2"} + votes = []generatortest.VoteString{vote2} + castVoteMsg = generatortest.NewVoteCastVoteStringMsg(t, "sender1", laoID, electionID, + 2, votes, nil) + + err = lite.StoreMessageAndData(electionPath, castVoteMsg) + require.NoError(t, err) + + question1 = expected["questionID1"] + question1.ValidVotes = map[string]types.ValidVote{ + "sender1": {MsgID: castVoteMsg.MessageID, ID: "voteID2", VoteTime: 2, Index: "Option2"}, + } + expected["questionID1"] = question1 + + result, err = lite.GetElectionQuestionsWithValidVotes(electionPath) + require.NoError(t, err) + require.Equal(t, expected, result) +} + +func Test_SQLite_StoreElectionEndWithResult(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + electionPath := "electionPath" + laoID := "laoID" + electionID := "electionID" + + electionEndMsg := generatortest.NewElectionCloseMsg(t, "sender1", laoID, electionID, "", 1, nil) + electionResultMsg := generatortest.NewElectionResultMsg(t, "sender2", nil, nil) + + err = lite.StoreElectionEndWithResult(electionPath, electionEndMsg, electionResultMsg) + require.NoError(t, err) + + expected := []message.Message{electionEndMsg, electionResultMsg} + messages, err := lite.GetAllMessagesFromChannel(electionPath) + require.NoError(t, err) + require.Equal(t, expected, messages) +} + +func Test_SQLite_StoreChirpMessages(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + chirpPath := "chirpID" + generalChirpPath := "generalChirpID" + + chirpMsg := generatortest.NewChirpAddMsg(t, "sender1", nil, 1) + generalChirpMsg := message.Message{ + Data: base64.URLEncoding.EncodeToString([]byte("data")), + Sender: "sender1", + Signature: "sig2", + MessageID: "ID2", + } + + err = lite.StoreChirpMessages(chirpPath, generalChirpPath, chirpMsg, generalChirpMsg) + require.NoError(t, err) + + expected := []message.Message{chirpMsg} + messages, err := lite.GetAllMessagesFromChannel(chirpPath) + require.NoError(t, err) + require.Equal(t, expected, messages) + + expected = []message.Message{generalChirpMsg} + messages, err = lite.GetAllMessagesFromChannel(generalChirpPath) + require.NoError(t, err) + require.Equal(t, expected, messages) + +} + +func Test_SQLite_IsAttendee(t *testing.T) { + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + attendees := []string{"attendee1", "attendee2", "attendee3"} + laoID := "laoID" + + rollCallCloseMsg := generatortest.NewRollCallCloseMsg(t, "sender1", "closeID", "openID", + 8, attendees, nil) + + err = lite.StoreMessageAndData(laoID, rollCallCloseMsg) + require.NoError(t, err) + + ok, err := lite.IsAttendee(laoID, "attendee1") + require.NoError(t, err) + require.True(t, ok) + + ok, err = lite.IsAttendee(laoID, "attendee4") + require.NoError(t, err) + require.False(t, ok) +} + +func Test_SQLite_GetReactionSender(t *testing.T) { + + lite, dir, err := newFakeSQLite(t) + require.NoError(t, err) + defer lite.Close() + defer os.RemoveAll(dir) + + reactionAddMsg := generatortest.NewReactionAddMsg(t, "sender1", nil, "", "chirpID", 1) + + sender, err := lite.GetReactionSender(reactionAddMsg.MessageID) + require.NoError(t, err) + require.Equal(t, "", sender) + + err = lite.StoreMessageAndData("channel1", reactionAddMsg) + require.NoError(t, err) + sender, err = lite.GetReactionSender(reactionAddMsg.MessageID) + require.NoError(t, err) + require.Equal(t, "sender1", sender) +} diff --git a/be1-go/internal/popserver/generatortest/chirp.go b/be1-go/internal/popserver/generatortest/chirp.go new file mode 100644 index 0000000000..e918aa9aec --- /dev/null +++ b/be1-go/internal/popserver/generatortest/chirp.go @@ -0,0 +1,45 @@ +package generatortest + +import ( + "encoding/json" + "github.com/stretchr/testify/require" + "go.dedis.ch/kyber/v3" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "testing" +) + +func NewChirpAddMsg(t *testing.T, sender string, senderSK kyber.Scalar, timestamp int64) message.Message { + + chirpAdd := messagedata.ChirpAdd{ + Object: messagedata.ChirpObject, + Action: messagedata.ChirpActionAdd, + Text: "just a chirp", + Timestamp: timestamp, + } + + dataBuf, err := json.Marshal(chirpAdd) + require.NoError(t, err) + + msg := newMessage(t, sender, senderSK, dataBuf) + + return msg +} + +func NewChirpDeleteMsg(t *testing.T, sender string, senderSK kyber.Scalar, chirpID string, + timestamp int64) message.Message { + + chirpAdd := messagedata.ChirpDelete{ + Object: messagedata.ChirpObject, + Action: messagedata.ChirpActionDelete, + ChirpID: chirpID, + Timestamp: timestamp, + } + + dataBuf, err := json.Marshal(chirpAdd) + require.NoError(t, err) + + msg := newMessage(t, sender, senderSK, dataBuf) + + return msg +} diff --git a/be1-go/internal/popserver/generatortest/election.go b/be1-go/internal/popserver/generatortest/election.go new file mode 100644 index 0000000000..6c20e9ef28 --- /dev/null +++ b/be1-go/internal/popserver/generatortest/election.go @@ -0,0 +1,133 @@ +package generatortest + +import ( + "encoding/json" + "github.com/stretchr/testify/require" + "go.dedis.ch/kyber/v3" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "testing" +) + +func NewElectionOpenMsg(t *testing.T, sender, lao, election string, openedAt int64, + senderSK kyber.Scalar) message.Message { + electionOpen := messagedata.ElectionOpen{ + Object: messagedata.ElectionObject, + Action: messagedata.ElectionActionOpen, + Lao: lao, + Election: election, + OpenedAt: openedAt, + } + + electionOpenBuf, err := json.Marshal(electionOpen) + require.NoError(t, err) + + msg := newMessage(t, sender, senderSK, electionOpenBuf) + + return msg +} + +func NewElectionCloseMsg(t *testing.T, sender, lao, election, registeredVotes string, openedAt int64, + senderSK kyber.Scalar) message.Message { + electionEnd := messagedata.ElectionEnd{ + Object: messagedata.ElectionObject, + Action: messagedata.ElectionActionEnd, + Lao: lao, + Election: election, + CreatedAt: openedAt, + RegisteredVotes: registeredVotes, + } + + electionEndBuf, err := json.Marshal(electionEnd) + require.NoError(t, err) + + msg := newMessage(t, sender, senderSK, electionEndBuf) + + return msg +} + +func NewElectionResultMsg(t *testing.T, sender string, questions []messagedata.ElectionResultQuestion, + senderSK kyber.Scalar) message.Message { + electionResult := messagedata.ElectionResult{ + Object: messagedata.ElectionObject, + Action: messagedata.ElectionActionResult, + Questions: questions, + } + + electionResultBuf, err := json.Marshal(electionResult) + require.NoError(t, err) + + msg := newMessage(t, sender, senderSK, electionResultBuf) + + return msg +} + +func NewVoteCastVoteIntMsg(t *testing.T, sender, lao, election string, createdAt int64, votes []VoteInt, + senderSK kyber.Scalar) message.Message { + castVote := VoteCastVoteInt{ + Object: messagedata.ElectionObject, + Action: messagedata.VoteActionCastVote, + Lao: lao, + Election: election, + CreatedAt: createdAt, + Votes: votes, + } + + castVoteBuf, err := json.Marshal(castVote) + require.NoError(t, err) + + return newMessage(t, sender, senderSK, castVoteBuf) +} + +func NewVoteCastVoteStringMsg(t *testing.T, sender, lao, election string, createdAt int64, votes []VoteString, + senderSK kyber.Scalar) message.Message { + castVote := VoteCastVoteString{ + Object: messagedata.ElectionObject, + Action: messagedata.VoteActionCastVote, + Lao: lao, + Election: election, + CreatedAt: createdAt, + Votes: votes, + } + + castVoteBuf, err := json.Marshal(castVote) + require.NoError(t, err) + + return newMessage(t, sender, senderSK, castVoteBuf) +} + +type VoteCastVoteInt struct { + Object string `json:"object"` + Action string `json:"action"` + Lao string `json:"lao"` + Election string `json:"election"` + + // CreatedAt is a Unix timestamp + CreatedAt int64 `json:"created_at"` + + Votes []VoteInt `json:"votes"` +} + +type VoteInt struct { + ID string `json:"id"` + Question string `json:"question"` + Vote int `json:"vote"` +} + +type VoteCastVoteString struct { + Object string `json:"object"` + Action string `json:"action"` + Lao string `json:"lao"` + Election string `json:"election"` + + // CreatedAt is a Unix timestamp + CreatedAt int64 `json:"created_at"` + + Votes []VoteString `json:"votes"` +} + +type VoteString struct { + ID string `json:"id"` + Question string `json:"question"` + Vote string `json:"vote"` +} diff --git a/be1-go/internal/popserver/generatortest/generatortest.go b/be1-go/internal/popserver/generatortest/generatortest.go new file mode 100644 index 0000000000..39d8838af2 --- /dev/null +++ b/be1-go/internal/popserver/generatortest/generatortest.go @@ -0,0 +1,67 @@ +package generatortest + +import ( + "encoding/base64" + "encoding/json" + "github.com/stretchr/testify/require" + "go.dedis.ch/kyber/v3" + "go.dedis.ch/kyber/v3/sign/schnorr" + "popstellar/crypto" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "testing" +) + +func newMessage(t *testing.T, sender string, senderSK kyber.Scalar, data []byte) message.Message { + data64 := base64.URLEncoding.EncodeToString(data) + + signature64 := base64.URLEncoding.EncodeToString([]byte(sender)) + + if senderSK != nil { + signatureBuf, err := schnorr.Sign(crypto.Suite, senderSK, data) + require.NoError(t, err) + + signature64 = base64.URLEncoding.EncodeToString(signatureBuf) + } + + messageID64 := messagedata.Hash(data64, signature64) + + return message.Message{ + Data: data64, + Sender: sender, + Signature: signature64, + MessageID: messageID64, + WitnessSignatures: []message.WitnessSignature{}, + } +} + +func NewNothingMsg(t *testing.T, sender string, senderSK kyber.Scalar) message.Message { + data := struct { + Object string `json:"object"` + Action string `json:"action"` + Not string `json:"not"` + }{ + Object: "lao", + Action: "nothing", + Not: "no", + } + buf, err := json.Marshal(data) + require.NoError(t, err) + + return newMessage(t, sender, senderSK, buf) +} + +func NewNothingQuery(t *testing.T, id int) []byte { + wrongQuery := struct { + Jsonrpc string `json:"Jsonrpc"` + ID int `json:"ID"` + }{ + Jsonrpc: "2.0", + ID: id, + } + + wrongQueryBuf, err := json.Marshal(&wrongQuery) + require.NoError(t, err) + + return wrongQueryBuf +} diff --git a/be1-go/internal/popserver/generatortest/lao.go b/be1-go/internal/popserver/generatortest/lao.go new file mode 100644 index 0000000000..80e035460c --- /dev/null +++ b/be1-go/internal/popserver/generatortest/lao.go @@ -0,0 +1,137 @@ +package generatortest + +import ( + "encoding/json" + "github.com/stretchr/testify/require" + "go.dedis.ch/kyber/v3" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "testing" +) + +func NewLaoStateMsg(t *testing.T, organizer, laoID, name, modificationID string, creation, lastModified int64, + organizerSK kyber.Scalar) message.Message { + laoState := messagedata.LaoState{ + Object: messagedata.LAOObject, + Action: messagedata.LAOActionState, + ID: laoID, + Name: name, + Creation: creation, + LastModified: lastModified, + Organizer: organizer, + Witnesses: []string{}, + ModificationID: modificationID, + ModificationSignatures: []messagedata.ModificationSignature{}, + } + + buf, err := json.Marshal(laoState) + require.NoError(t, err) + + msg := newMessage(t, organizer, organizerSK, buf) + + return msg +} + +func NewRollCallCreateMsg(t *testing.T, sender, laoName, createID string, creation, start, end int64, + senderSK kyber.Scalar) message.Message { + rollCallCreate := messagedata.RollCallCreate{ + Object: messagedata.RollCallObject, + Action: messagedata.RollCallActionCreate, + ID: createID, + Name: laoName, + Creation: creation, + ProposedStart: start, + ProposedEnd: end, + Location: "Location", + Description: "Description", + } + + buf, err := json.Marshal(rollCallCreate) + require.NoError(t, err) + + msg := newMessage(t, sender, senderSK, buf) + + return msg +} + +func NewRollCallOpenMsg(t *testing.T, sender, updateID, opens string, openedAt int64, + senderSK kyber.Scalar) message.Message { + + rollCallOpen := messagedata.RollCallOpen{ + Object: messagedata.RollCallObject, + Action: messagedata.RollCallActionOpen, + UpdateID: updateID, + Opens: opens, + OpenedAt: openedAt, + } + + buf, err := json.Marshal(rollCallOpen) + require.NoError(t, err) + + msg := newMessage(t, sender, senderSK, buf) + + return msg +} + +func NewRollCallReOpenMsg(t *testing.T, sender, updateID, opens string, openedAt int64, + senderSK kyber.Scalar) message.Message { + + rollCallReOpen := messagedata.RollCallOpen{ + Object: messagedata.RollCallObject, + Action: messagedata.RollCallActionReOpen, + UpdateID: updateID, + Opens: opens, + OpenedAt: openedAt, + } + + buf, err := json.Marshal(rollCallReOpen) + require.NoError(t, err) + + msg := newMessage(t, sender, senderSK, buf) + + return msg +} + +func NewRollCallCloseMsg(t *testing.T, sender, updateID, closes string, closedAt int64, attendees []string, + senderSK kyber.Scalar) message.Message { + + rollCallClose := messagedata.RollCallClose{ + Object: messagedata.RollCallObject, + Action: messagedata.RollCallActionClose, + UpdateID: updateID, + Closes: closes, + ClosedAt: closedAt, + Attendees: attendees, + } + + buf, err := json.Marshal(rollCallClose) + require.NoError(t, err) + + msg := newMessage(t, sender, senderSK, buf) + + return msg +} + +func NewElectionSetupMsg(t *testing.T, sender, ID, setupLao, electionName, version string, + createdAt, start, end int64, questions []messagedata.ElectionSetupQuestion, senderSK kyber.Scalar) message.Message { + + electionSetup := messagedata.ElectionSetup{ + Object: messagedata.ElectionObject, + Action: messagedata.ElectionActionSetup, + ID: ID, + Lao: setupLao, + Name: electionName, + Version: version, + CreatedAt: createdAt, + StartTime: start, + EndTime: end, + Questions: questions, + } + + buf, err := json.Marshal(electionSetup) + require.NoError(t, err) + + msg := newMessage(t, sender, senderSK, buf) + + return msg +} diff --git a/be1-go/internal/popserver/generatortest/query.go b/be1-go/internal/popserver/generatortest/query.go new file mode 100644 index 0000000000..43803a9308 --- /dev/null +++ b/be1-go/internal/popserver/generatortest/query.go @@ -0,0 +1,151 @@ +package generatortest + +import ( + "encoding/json" + "github.com/stretchr/testify/require" + jsonrpc "popstellar/message" + "popstellar/message/query" + "popstellar/message/query/method" + "popstellar/message/query/method/message" + "testing" +) + +func NewGreetServerQuery(t *testing.T, publicKey, clientAddress, serverAddress string) []byte { + serverInfo := method.GreetServerParams{ + PublicKey: publicKey, + ServerAddress: clientAddress, + ClientAddress: serverAddress, + } + + greetServer := method.GreetServer{ + Base: query.Base{ + JSONRPCBase: jsonrpc.JSONRPCBase{ + JSONRPC: "2.0", + }, + + Method: query.MethodGreetServer, + }, + Params: serverInfo, + } + + greetServerBuf, err := json.Marshal(&greetServer) + require.NoError(t, err) + + return greetServerBuf +} + +func NewSubscribeQuery(t *testing.T, queryID int, channel string) []byte { + subscribe := method.Subscribe{ + Base: query.Base{ + JSONRPCBase: jsonrpc.JSONRPCBase{ + JSONRPC: "2.0", + }, + + Method: query.MethodSubscribe, + }, + ID: queryID, + Params: method.SubscribeParams{Channel: channel}, + } + + subscribeBuf, err := json.Marshal(&subscribe) + require.NoError(t, err) + + return subscribeBuf +} + +func NewUnsubscribeQuery(t *testing.T, queryID int, channel string) []byte { + unsubscribe := method.Unsubscribe{ + Base: query.Base{ + JSONRPCBase: jsonrpc.JSONRPCBase{ + JSONRPC: "2.0", + }, + + Method: query.MethodUnsubscribe, + }, + ID: queryID, + Params: method.UnsubscribeParams{Channel: channel}, + } + + unsubscribeBuf, err := json.Marshal(&unsubscribe) + require.NoError(t, err) + + return unsubscribeBuf +} + +func NewPublishQuery(t *testing.T, queryID int, channel string, msg message.Message) []byte { + publish := method.Publish{ + Base: query.Base{ + JSONRPCBase: jsonrpc.JSONRPCBase{ + JSONRPC: "2.0", + }, + + Method: query.MethodPublish, + }, + ID: queryID, + Params: method.PublishParams{ + Channel: channel, + Message: msg, + }, + } + + publishBuf, err := json.Marshal(&publish) + require.NoError(t, err) + + return publishBuf +} + +func NewCatchupQuery(t *testing.T, queryID int, channel string) []byte { + catchup := method.Catchup{ + Base: query.Base{ + JSONRPCBase: jsonrpc.JSONRPCBase{ + JSONRPC: "2.0", + }, + + Method: query.MethodCatchUp, + }, + ID: queryID, + Params: method.CatchupParams{Channel: channel}, + } + + catchupBuf, err := json.Marshal(&catchup) + require.NoError(t, err) + + return catchupBuf +} + +func NewHeartbeatQuery(t *testing.T, msgIDsByChannel map[string][]string) []byte { + heartbeat := method.Heartbeat{ + Base: query.Base{ + JSONRPCBase: jsonrpc.JSONRPCBase{ + JSONRPC: "2.0", + }, + + Method: query.MethodHeartbeat, + }, + Params: msgIDsByChannel, + } + + heartbeatBuf, err := json.Marshal(&heartbeat) + require.NoError(t, err) + + return heartbeatBuf +} + +func NewGetMessagesByIDQuery(t *testing.T, queryID int, msgIDsByChannel map[string][]string) []byte { + getMessagesByID := method.GetMessagesById{ + Base: query.Base{ + JSONRPCBase: jsonrpc.JSONRPCBase{ + JSONRPC: "2.0", + }, + + Method: query.MethodGetMessagesById, + }, + ID: queryID, + Params: msgIDsByChannel, + } + + getMessagesByIDBuf, err := json.Marshal(&getMessagesByID) + require.NoError(t, err) + + return getMessagesByIDBuf +} diff --git a/be1-go/internal/popserver/generatortest/reaction.go b/be1-go/internal/popserver/generatortest/reaction.go new file mode 100644 index 0000000000..abe494c0bf --- /dev/null +++ b/be1-go/internal/popserver/generatortest/reaction.go @@ -0,0 +1,47 @@ +package generatortest + +import ( + "encoding/json" + "github.com/stretchr/testify/require" + "go.dedis.ch/kyber/v3" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "testing" +) + +func NewReactionAddMsg(t *testing.T, sender string, senderSK kyber.Scalar, reactionCodePoint, ChirpID string, + timestamp int64) message.Message { + + reactionAdd := messagedata.ReactionAdd{ + Object: messagedata.ReactionObject, + Action: messagedata.ReactionActionAdd, + ReactionCodepoint: reactionCodePoint, + ChirpID: ChirpID, + Timestamp: timestamp, + } + + dataBuf, err := json.Marshal(reactionAdd) + require.NoError(t, err) + + msg := newMessage(t, sender, senderSK, dataBuf) + + return msg +} + +func NewReactionDeleteMsg(t *testing.T, sender string, senderSK kyber.Scalar, reactionID string, + timestamp int64) message.Message { + + reactionDelete := messagedata.ReactionDelete{ + Object: messagedata.ReactionObject, + Action: messagedata.ReactionActionDelete, + ReactionID: reactionID, + Timestamp: timestamp, + } + + dataBuf, err := json.Marshal(reactionDelete) + require.NoError(t, err) + + msg := newMessage(t, sender, senderSK, dataBuf) + + return msg +} diff --git a/be1-go/internal/popserver/generatortest/root.go b/be1-go/internal/popserver/generatortest/root.go new file mode 100644 index 0000000000..8a60e2cfda --- /dev/null +++ b/be1-go/internal/popserver/generatortest/root.go @@ -0,0 +1,30 @@ +package generatortest + +import ( + "encoding/json" + "github.com/stretchr/testify/require" + "go.dedis.ch/kyber/v3" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "testing" +) + +func NewLaoCreateMsg(t *testing.T, sender string, ID string, laoName string, creation int64, organizer string, + senderSK kyber.Scalar) message.Message { + laoCreate := messagedata.LaoCreate{ + Object: messagedata.LAOObject, + Action: messagedata.LAOActionCreate, + ID: ID, + Name: laoName, + Creation: creation, + Organizer: organizer, + Witnesses: []string{}, + } + + buf, err := json.Marshal(laoCreate) + require.NoError(t, err) + + msg := newMessage(t, sender, senderSK, buf) + + return msg +} diff --git a/be1-go/internal/popserver/handler/answer.go b/be1-go/internal/popserver/handler/answer.go new file mode 100644 index 0000000000..73e37de183 --- /dev/null +++ b/be1-go/internal/popserver/handler/answer.go @@ -0,0 +1,124 @@ +package handler + +import ( + "encoding/json" + "popstellar/internal/popserver/state" + "popstellar/internal/popserver/utils" + "popstellar/message/answer" + "popstellar/message/query/method/message" + "sort" +) + +const maxRetry = 10 + +func handleAnswer(msg []byte) *answer.Error { + var answerMsg answer.Answer + + err := json.Unmarshal(msg, &answerMsg) + if err != nil { + errAnswer := answer.NewJsonUnmarshalError(err.Error()) + return errAnswer.Wrap("handleAnswer") + } + + if answerMsg.Result == nil { + utils.LogInfo("received an error, nothing to handle") + // don't send any error to avoid infinite error loop as a server will + // send an error to another server that will create another error + return nil + } + if answerMsg.Result.IsEmpty() { + utils.LogInfo("expected isn't an answer to a popquery, nothing to handle") + return nil + } + + errAnswer := state.SetQueryReceived(*answerMsg.ID) + if errAnswer != nil { + return errAnswer.Wrap("handleAnswer") + } + + errAnswer = handleGetMessagesByIDAnswer(answerMsg) + if errAnswer != nil { + return errAnswer.Wrap("handleAnswer") + } + + return nil +} + +func handleGetMessagesByIDAnswer(msg answer.Answer) *answer.Error { + result := msg.Result.GetMessagesByChannel() + msgsByChan := make(map[string]map[string]message.Message) + + // Unmarshal each message + for channelID, rawMsgs := range result { + msgsByChan[channelID] = make(map[string]message.Message) + for _, rawMsg := range rawMsgs { + var msg message.Message + err := json.Unmarshal(rawMsg, &msg) + if err == nil { + msgsByChan[channelID][msg.MessageID] = msg + continue + } + + errAnswer := answer.NewInvalidMessageFieldError("failed to unmarshal: %v", err) + utils.LogError(errAnswer.Wrap("handleGetMessagesByIDAnswer")) + } + + if len(msgsByChan[channelID]) == 0 { + delete(msgsByChan, channelID) + } + } + + // Handle every message and discard them if handled without error + handleMessagesByChannel(msgsByChan) + + return nil +} + +func handleMessagesByChannel(msgsByChannel map[string]map[string]message.Message) { + // Handle every messages + for i := 0; i < maxRetry; i++ { + // Sort by channelID length + sortedChannelIDs := getSortedChannels(msgsByChannel) + + tryToHandleMessages(msgsByChannel, sortedChannelIDs) + + if len(msgsByChannel) == 0 { + return + } + } +} + +func tryToHandleMessages(msgsByChannel map[string]map[string]message.Message, sortedChannelIDs []string) { + for _, channelID := range sortedChannelIDs { + msgs := msgsByChannel[channelID] + for msgID, msg := range msgs { + errAnswer := handleChannel(channelID, msg) + if errAnswer == nil { + delete(msgsByChannel[channelID], msgID) + continue + } + + if errAnswer.Code == answer.InvalidMessageFieldErrorCode { + delete(msgsByChannel[channelID], msgID) + } + + errAnswer = errAnswer.Wrap(msgID).Wrap("tryToHandleMessages") + utils.LogError(errAnswer) + } + + if len(msgsByChannel[channelID]) == 0 { + delete(msgsByChannel, channelID) + } + } +} + +func getSortedChannels(msgsByChannel map[string]map[string]message.Message) []string { + sortedChannelIDs := make([]string, 0) + for channelID := range msgsByChannel { + sortedChannelIDs = append(sortedChannelIDs, channelID) + } + sort.Slice(sortedChannelIDs, func(i, j int) bool { + return len(sortedChannelIDs[i]) < len(sortedChannelIDs[j]) + }) + return sortedChannelIDs +} diff --git a/be1-go/internal/popserver/handler/answer_test.go b/be1-go/internal/popserver/handler/answer_test.go new file mode 100644 index 0000000000..b6468d2302 --- /dev/null +++ b/be1-go/internal/popserver/handler/answer_test.go @@ -0,0 +1,108 @@ +package handler + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "github.com/stretchr/testify/require" + "go.dedis.ch/kyber/v3/sign/schnorr" + "popstellar/crypto" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/database/repository" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "testing" + "time" +) + +func Test_handleMessagesByChannel(t *testing.T) { + mockRepository := repository.NewMockRepository(t) + database.SetDatabase(mockRepository) + + type input struct { + name string + messages map[string]map[string]message.Message + expected map[string]map[string]message.Message + } + + keypair := GenerateKeyPair(t) + now := time.Now().Unix() + name := "LAO X" + + laoID := messagedata.Hash(base64.URLEncoding.EncodeToString(keypair.PublicBuf), fmt.Sprintf("%d", now), name) + + data := messagedata.LaoCreate{ + Object: messagedata.LAOObject, + Action: messagedata.LAOActionCreate, + ID: laoID, + Name: name, + Creation: now, + Organizer: base64.URLEncoding.EncodeToString(keypair.PublicBuf), + Witnesses: []string{}, + } + + dataBuf, err := json.Marshal(data) + require.NoError(t, err) + signature, err := schnorr.Sign(crypto.Suite, keypair.Private, dataBuf) + require.NoError(t, err) + + dataBase64 := base64.URLEncoding.EncodeToString(dataBuf) + signatureBase64 := base64.URLEncoding.EncodeToString(signature) + + msgValid := message.Message{ + Data: dataBase64, + Sender: base64.URLEncoding.EncodeToString(keypair.PublicBuf), + Signature: signatureBase64, + MessageID: messagedata.Hash(dataBase64, signatureBase64), + WitnessSignatures: []message.WitnessSignature{}, + } + + msgWithInvalidField := message.Message{ + Data: "wrong data", + Sender: "wrong sender", + Signature: "wrong signature", + MessageID: "wrong messageID", + WitnessSignatures: []message.WitnessSignature{}, + } + + inputs := make([]input, 0) + + // blacklist without invalid field error + + messages := make(map[string]map[string]message.Message) + messages["/root"] = make(map[string]message.Message) + messages["/root"][msgValid.MessageID] = msgValid + messages["/root"][msgWithInvalidField.MessageID] = msgWithInvalidField + messages["/root/lao1"] = make(map[string]message.Message) + messages["/root/lao1"][msgValid.MessageID] = msgValid + messages["/root/lao1"][msgWithInvalidField.MessageID] = msgWithInvalidField + + expected := make(map[string]map[string]message.Message) + expected["/root"] = make(map[string]message.Message) + expected["/root"][msgValid.MessageID] = msgValid + expected["/root/lao1"] = make(map[string]message.Message) + expected["/root/lao1"][msgValid.MessageID] = msgValid + + mockRepository.On("HasMessage", msgValid.MessageID).Return(false, nil) + mockRepository.On("GetChannelType", "/root").Return("", nil) + mockRepository.On("GetChannelType", "/root/lao1").Return("", nil) + + inputs = append(inputs, input{ + name: "blacklist without invalid field error", + messages: messages, + expected: expected, + }) + + for _, i := range inputs { + t.Run(i.name, func(t *testing.T) { + handleMessagesByChannel(i.messages) + + for k0, v0 := range i.expected { + for k1 := range v0 { + require.Equal(t, i.expected[k0][k1], i.messages[k0][k1]) + } + } + }) + } + +} diff --git a/be1-go/internal/popserver/handler/channel.go b/be1-go/internal/popserver/handler/channel.go new file mode 100644 index 0000000000..43c08acfc3 --- /dev/null +++ b/be1-go/internal/popserver/handler/channel.go @@ -0,0 +1,184 @@ +package handler + +import ( + "encoding/base64" + "encoding/json" + "go.dedis.ch/kyber/v3" + "go.dedis.ch/kyber/v3/sign/schnorr" + "popstellar/crypto" + "popstellar/internal/popserver/config" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/database/sqlite" + "popstellar/internal/popserver/state" + "popstellar/internal/popserver/utils" + jsonrpc "popstellar/message" + "popstellar/message/answer" + "popstellar/message/messagedata" + "popstellar/message/query" + "popstellar/message/query/method" + "popstellar/message/query/method/message" + "popstellar/validation" +) + +func handleChannel(channelPath string, msg message.Message) *answer.Error { + errAnswer := verifyMessage(msg) + if errAnswer != nil { + return errAnswer.Wrap("handleChannel") + } + + db, errAnswer := database.GetChannelRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleChannel") + } + + msgAlreadyExists, err := db.HasMessage(msg.MessageID) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("if message exists: %v", err) + return errAnswer.Wrap("handleChannel") + } + if msgAlreadyExists { + errAnswer := answer.NewInvalidActionError("message %s was already received", msg.MessageID) + return errAnswer.Wrap("handleChannel") + } + + channelType, err := db.GetChannelType(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("channel type: %v", err) + return errAnswer.Wrap("handleChannel") + } + + switch channelType { + case sqlite.RootType: + errAnswer = handleChannelRoot(msg) + case sqlite.LaoType: + errAnswer = handleChannelLao(channelPath, msg) + case sqlite.ElectionType: + errAnswer = handleChannelElection(channelPath, msg) + case sqlite.ChirpType: + errAnswer = handleChannelChirp(channelPath, msg) + case sqlite.ReactionType: + errAnswer = handleChannelReaction(channelPath, msg) + case sqlite.CoinType: + errAnswer = handleChannelCoin(channelPath, msg) + default: + errAnswer = answer.NewInvalidResourceError("unknown channel type for %s", channelPath) + } + + if errAnswer != nil { + return errAnswer.Wrap("handleChannel") + } + + return nil +} + +// util for the channels + +func verifyMessage(msg message.Message) *answer.Error { + dataBytes, err := base64.URLEncoding.DecodeString(msg.Data) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to decode data: %v", err) + return errAnswer.Wrap("verifyMessage") + } + + publicKeySender, err := base64.URLEncoding.DecodeString(msg.Sender) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to decode public key: %v", err) + return errAnswer.Wrap("verifyMessage") + } + + signatureBytes, err := base64.URLEncoding.DecodeString(msg.Signature) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to decode signature: %v", err) + return errAnswer.Wrap("verifyMessage") + } + + err = schnorr.VerifyWithChecks(crypto.Suite, publicKeySender, dataBytes, signatureBytes) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to verify signature : %v", err) + return errAnswer.Wrap("verifyMessage") + } + + expectedMessageID := messagedata.Hash(msg.Data, msg.Signature) + if expectedMessageID != msg.MessageID { + errAnswer := answer.NewInvalidActionError("messageID is wrong: expected %s found %s", + expectedMessageID, msg.MessageID) + return errAnswer.Wrap("verifyMessage") + } + return nil +} + +func verifyDataAndGetObjectAction(msg message.Message) (string, string, *answer.Error) { + jsonData, err := base64.URLEncoding.DecodeString(msg.Data) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to decode message data: %v", err) + return "", "", errAnswer.Wrap("verifyDataAndGetObjectAction") + } + + // validate message data against the json schema + errAnswer := utils.VerifyJSON(jsonData, validation.Data) + if errAnswer != nil { + return "", "", errAnswer.Wrap("verifyDataAndGetObjectAction") + } + + // get object#action + object, action, err := messagedata.GetObjectAndAction(jsonData) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to get object#action: %v", err) + return "", "", errAnswer.Wrap("verifyDataAndGetObjectAction") + } + return object, action, nil +} + +func Sign(data []byte) ([]byte, *answer.Error) { + var errAnswer *answer.Error + + serverSecretKey, errAnswer := config.GetServerSecretKeyInstance() + if errAnswer != nil { + return nil, errAnswer.Wrap("Sign") + } + + signatureBuf, err := schnorr.Sign(crypto.Suite, serverSecretKey, data) + if err != nil { + errAnswer = answer.NewInternalServerError("failed to sign the data: %v", err) + return nil, errAnswer.Wrap("Sign") + } + return signatureBuf, nil +} + +// generateKeys generates and returns a key pair +func generateKeys() (kyber.Point, kyber.Scalar) { + secret := crypto.Suite.Scalar().Pick(crypto.Suite.RandomStream()) + point := crypto.Suite.Point().Mul(secret, nil) + return point, secret +} + +func broadcastToAllClients(msg message.Message, channel string) *answer.Error { + rpcMessage := method.Broadcast{ + Base: query.Base{ + JSONRPCBase: jsonrpc.JSONRPCBase{ + JSONRPC: "2.0", + }, + Method: "broadcast", + }, + Params: struct { + Channel string `json:"channel"` + Message message.Message `json:"message"` + }{ + channel, + msg, + }, + } + + buf, err := json.Marshal(&rpcMessage) + if err != nil { + errAnswer := answer.NewInternalServerError("failed to marshal broadcast query: %v", err) + return errAnswer.Wrap("broadcastToAllClients") + } + + errAnswer := state.SendToAll(buf, channel) + if errAnswer != nil { + return errAnswer.Wrap("broadcastToAllClients") + } + + return nil +} diff --git a/be1-go/internal/popserver/handler/channel_test.go b/be1-go/internal/popserver/handler/channel_test.go new file mode 100644 index 0000000000..b76b95d064 --- /dev/null +++ b/be1-go/internal/popserver/handler/channel_test.go @@ -0,0 +1,197 @@ +package handler + +import ( + "encoding/base64" + "github.com/stretchr/testify/require" + "go.dedis.ch/kyber/v3" + "golang.org/x/xerrors" + "popstellar/crypto" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/database/repository" + "popstellar/internal/popserver/generatortest" + "popstellar/message/query/method/message" + "testing" + "time" +) + +// the public key used in every lao_create json files in the test_data/root folder +const ownerPubBuf64 = "3yPmdBu8DM7jT30IKqkPjuFFIHnubO0z4E0dV7dR4sY=" + +func Test_handleChannel(t *testing.T) { + mockRepository := repository.NewMockRepository(t) + database.SetDatabase(mockRepository) + + keypair := GenerateKeyPair(t) + sender := base64.URLEncoding.EncodeToString(keypair.PublicBuf) + + type input struct { + name string + channel string + message message.Message + contains string + } + + args := make([]input, 0) + + // Test 1: failed to handled message because unknown channel type + + channel := "unknown" + msg := generatortest.NewChirpAddMsg(t, sender, keypair.Private, time.Now().Unix()) + + mockRepository.On("HasMessage", msg.MessageID).Return(false, nil) + mockRepository.On("GetChannelType", channel).Return("", nil) + + args = append(args, input{ + name: "Test 1", + channel: channel, + message: msg, + contains: "unknown channel type for " + channel, + }) + + // Test 2: failed to handled message because db is disconnected when querying the channel type + + channel = "disconnectedDB" + msg = generatortest.NewChirpAddMsg(t, sender, keypair.Private, time.Now().Unix()) + + mockRepository.On("HasMessage", msg.MessageID).Return(false, nil) + mockRepository.On("GetChannelType", channel). + Return("", xerrors.Errorf("DB is disconnected")) + + args = append(args, input{ + name: "Test 2", + channel: channel, + message: msg, + contains: "DB is disconnected", + }) + + // Test 3: failed to handled message because message already exists + + msg = generatortest.NewChirpAddMsg(t, sender, keypair.Private, time.Now().Unix()) + + mockRepository.On("HasMessage", msg.MessageID).Return(true, nil) + + args = append(args, input{ + name: "Test 3", + message: msg, + contains: "message " + msg.MessageID + " was already received", + }) + + // Test 4: failed to handled message because db is disconnected when querying if the message already exists + + msg = generatortest.NewChirpAddMsg(t, sender, keypair.Private, time.Now().Unix()) + + mockRepository.On("HasMessage", msg.MessageID). + Return(false, xerrors.Errorf("DB is disconnected")) + + args = append(args, input{ + name: "Test 4", + message: msg, + contains: "DB is disconnected", + }) + + // Test 5: failed to handled message because the format of messageID + + msg = generatortest.NewChirpAddMsg(t, sender, keypair.Private, time.Now().Unix()) + expectedMsgID := msg.MessageID + msg.MessageID = base64.URLEncoding.EncodeToString([]byte("wrong messageID")) + + args = append(args, input{ + name: "Test 5", + message: msg, + contains: "messageID is wrong: expected " + expectedMsgID + " found " + msg.MessageID, + }) + + // Test 6: failed to handled message because wrong sender + + msg = generatortest.NewChirpAddMsg(t, sender, keypair.Private, time.Now().Unix()) + msg.Sender = base64.URLEncoding.EncodeToString([]byte("wrong sender")) + + args = append(args, input{ + name: "Test 6", + message: msg, + contains: "failed to verify signature", + }) + + // Test 7: failed to handled message because wrong data + + msg = generatortest.NewChirpAddMsg(t, sender, keypair.Private, time.Now().Unix()) + msg.Data = base64.URLEncoding.EncodeToString([]byte("wrong data")) + + args = append(args, input{ + name: "Test 7", + message: msg, + contains: "failed to verify signature", + }) + + // Test 8: failed to handled message because wrong signature + + msg = generatortest.NewChirpAddMsg(t, sender, keypair.Private, time.Now().Unix()) + msg.Data = base64.URLEncoding.EncodeToString([]byte("wrong signature")) + + args = append(args, input{ + name: "Test 8", + message: msg, + contains: "failed to verify signature", + }) + + // Test 9: failed to handled message because wrong signature encoding + + msg = generatortest.NewChirpAddMsg(t, sender, keypair.Private, time.Now().Unix()) + msg.Signature = "wrong signature" + + args = append(args, input{ + name: "Test 9", + message: msg, + contains: "failed to decode signature", + }) + + // Test 10: failed to handled message because wrong signature encoding + + msg = generatortest.NewChirpAddMsg(t, sender, keypair.Private, time.Now().Unix()) + msg.Sender = "wrong sender" + + args = append(args, input{ + name: "Test 10", + message: msg, + contains: "failed to decode public key", + }) + + // Test 11: failed to handled message because wrong signature encoding + + msg = generatortest.NewChirpAddMsg(t, sender, keypair.Private, time.Now().Unix()) + msg.Data = "wrong data" + + args = append(args, input{ + name: "Test 11", + message: msg, + contains: "failed to decode data", + }) + + for _, arg := range args { + t.Run(arg.name, func(t *testing.T) { + errAnswer := handleChannel(arg.channel, arg.message) + require.NotNil(t, errAnswer) + require.Contains(t, errAnswer.Error(), arg.contains) + }) + } + +} + +type Keypair struct { + Public kyber.Point + PublicBuf []byte + Private kyber.Scalar + PrivateBuf []byte +} + +func GenerateKeyPair(t *testing.T) Keypair { + secret := crypto.Suite.Scalar().Pick(crypto.Suite.RandomStream()) + point := crypto.Suite.Point().Mul(secret, nil) + + publicBuf, err := point.MarshalBinary() + require.NoError(t, err) + privateBuf, err := secret.MarshalBinary() + require.NoError(t, err) + + return Keypair{point, publicBuf, secret, privateBuf} +} diff --git a/be1-go/internal/popserver/handler/chirp.go b/be1-go/internal/popserver/handler/chirp.go new file mode 100644 index 0000000000..ecf4fc7c44 --- /dev/null +++ b/be1-go/internal/popserver/handler/chirp.go @@ -0,0 +1,192 @@ +package handler + +import ( + "encoding/base64" + "encoding/json" + "popstellar/internal/popserver/config" + "popstellar/internal/popserver/database" + "popstellar/message/answer" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "strings" +) + +func handleChannelChirp(channelPath string, msg message.Message) *answer.Error { + object, action, errAnswer := verifyDataAndGetObjectAction(msg) + if errAnswer != nil { + return errAnswer.Wrap("handleChannelChirp") + } + + switch object + "#" + action { + case messagedata.ChirpObject + "#" + messagedata.ChirpActionAdd: + errAnswer = handleChirpAdd(channelPath, msg) + case messagedata.ChirpObject + "#" + messagedata.ChirpActionDelete: + errAnswer = handleChirpDelete(channelPath, msg) + default: + errAnswer = answer.NewInvalidMessageFieldError("failed to handle %s#%s, invalid object#action", object, action) + } + if errAnswer != nil { + return errAnswer.Wrap("handleChannelChirp") + } + + generalMsg, errAnswer := createChirpNotify(channelPath, msg) + if errAnswer != nil { + return errAnswer.Wrap("handleChannelChirp") + } + + generalChirpsChannelID, ok := strings.CutSuffix(channelPath, Social+"/"+msg.Sender) + if !ok { + errAnswer := answer.NewInvalidMessageFieldError("invalid channel path %s", channelPath) + return errAnswer.Wrap("handleChannelChirp") + } + + db, errAnswer := database.GetChirpRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleChannelChirp") + } + + err := db.StoreChirpMessages(channelPath, generalChirpsChannelID, msg, generalMsg) + if err != nil { + errAnswer = answer.NewStoreDatabaseError(err.Error()) + return errAnswer.Wrap("handleChannelChirp") + } + + errAnswer = broadcastToAllClients(msg, channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleChannelChirp") + } + + errAnswer = broadcastToAllClients(generalMsg, generalChirpsChannelID) + if errAnswer != nil { + return errAnswer.Wrap("handleChannelChirp") + } + + return nil +} + +func handleChirpAdd(channelID string, msg message.Message) *answer.Error { + var data messagedata.ChirpAdd + errAnswer := msg.UnmarshalMsgData(&data) + if errAnswer != nil { + return errAnswer.Wrap("handleChirpAdd") + } + + errAnswer = verifyChirpMessage(channelID, msg, data) + if errAnswer != nil { + return errAnswer.Wrap("handleChirpAdd") + } + + return nil +} + +func handleChirpDelete(channelID string, msg message.Message) *answer.Error { + var data messagedata.ChirpDelete + errAnswer := msg.UnmarshalMsgData(&data) + if errAnswer != nil { + return errAnswer.Wrap("handleChirpDelete") + } + + errAnswer = verifyChirpMessage(channelID, msg, data) + if errAnswer != nil { + return errAnswer.Wrap("handleChirpDelete") + } + + db, errAnswer := database.GetChirpRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleChirpDelete") + } + + msgToDeleteExists, err := db.HasMessage(data.ChirpID) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("if message exists: %v", err) + return errAnswer.Wrap("handleChirpDelete") + } + if !msgToDeleteExists { + errAnswer := answer.NewInvalidResourceError("cannot delete unknown chirp") + return errAnswer.Wrap("handleChirpDelete") + } + + return nil +} + +func verifyChirpMessage(channelID string, msg message.Message, chirpMsg messagedata.Verifiable) *answer.Error { + err := chirpMsg.Verify() + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("invalid message: %v", err) + return errAnswer.Wrap("verifyChirpMessage") + } + + if !strings.HasSuffix(channelID, msg.Sender) { + errAnswer := answer.NewAccessDeniedError("only the owner of the channel can post chirps") + return errAnswer.Wrap("verifyChirpMessage") + } + + return nil +} + +func createChirpNotify(channelID string, msg message.Message) (message.Message, *answer.Error) { + jsonData, err := base64.URLEncoding.DecodeString(msg.Data) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to decode the data: %v", err) + return message.Message{}, errAnswer.Wrap("createChirpNotify") + } + + object, action, err := messagedata.GetObjectAndAction(jsonData) + action = "notify_" + action + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to read the data: %v", err) + return message.Message{}, errAnswer.Wrap("createChirpNotify") + } + + timestamp, err := messagedata.GetTime(jsonData) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to read the data: %v", err) + return message.Message{}, errAnswer.Wrap("createChirpNotify") + } + + newData := messagedata.ChirpBroadcast{ + Object: object, + Action: action, + ChirpID: msg.MessageID, + Channel: channelID, + Timestamp: timestamp, + } + + dataBuf, err := json.Marshal(newData) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to marshal: %v", err) + return message.Message{}, errAnswer.Wrap("createChirpNotify") + } + + data64 := base64.URLEncoding.EncodeToString(dataBuf) + + serverPublicKey, errAnswer := config.GetServerPublicKeyInstance() + if errAnswer != nil { + return message.Message{}, errAnswer.Wrap("createChirpNotify") + } + + pkBuf, err := serverPublicKey.MarshalBinary() + if err != nil { + errAnswer := answer.NewInternalServerError("failed to unmarshall server public key", err) + return message.Message{}, errAnswer.Wrap("createChirpNotify") + } + pk64 := base64.URLEncoding.EncodeToString(pkBuf) + + signatureBuf, errAnswer := Sign(dataBuf) + if errAnswer != nil { + return message.Message{}, errAnswer.Wrap("createChirpNotify") + } + signature64 := base64.URLEncoding.EncodeToString(signatureBuf) + + messageID64 := messagedata.Hash(data64, signature64) + + newMsg := message.Message{ + Data: data64, + Sender: pk64, + Signature: signature64, + MessageID: messageID64, + WitnessSignatures: make([]message.WitnessSignature, 0), + } + + return newMsg, nil +} diff --git a/be1-go/internal/popserver/handler/chirp_test.go b/be1-go/internal/popserver/handler/chirp_test.go new file mode 100644 index 0000000000..494c62796f --- /dev/null +++ b/be1-go/internal/popserver/handler/chirp_test.go @@ -0,0 +1,182 @@ +package handler + +import ( + "encoding/base64" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "popstellar/crypto" + "popstellar/internal/popserver/config" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/database/repository" + "popstellar/internal/popserver/generatortest" + "popstellar/internal/popserver/state" + "popstellar/internal/popserver/types" + "popstellar/message/query/method/message" + "strings" + "testing" + "time" +) + +func Test_handleChannelChirp(t *testing.T) { + subs := types.NewSubscribers() + queries := types.NewQueries(&noLog) + peers := types.NewPeers() + + state.SetState(subs, peers, queries) + + organizerBuf, err := base64.URLEncoding.DecodeString(ownerPubBuf64) + require.NoError(t, err) + + ownerPublicKey := crypto.Suite.Point() + err = ownerPublicKey.UnmarshalBinary(organizerBuf) + require.NoError(t, err) + + serverSecretKey := crypto.Suite.Scalar().Pick(crypto.Suite.RandomStream()) + serverPublicKey := crypto.Suite.Point().Mul(serverSecretKey, nil) + + config.SetConfig(ownerPublicKey, serverPublicKey, serverSecretKey, "clientAddress", "serverAddress") + + mockRepository := repository.NewMockRepository(t) + database.SetDatabase(mockRepository) + + sender := "3yPmdBu8DM7jT30IKqkPjuFFIHnubO0z4E0dV7dR4sY=" + wrongSender := "3yPmdBu8DM7jT30IKqkPjuFFIHnubO0z4E0dV7dR4sK=" + chirpID := "AAAAdBu8DM7jT30IKqkPjuFFIHnubO0z4E0dV7dR4sK=" + + var args []input + + // Test 1: successfully add a chirp and notify it + + channelID := "/root/lao1/social/" + sender + + args = append(args, input{ + name: "Test 1", + channel: channelID, + msg: newChirpAddMsg(t, channelID, sender, time.Now().Unix(), mockRepository, false), + isError: false, + contains: "", + }) + + // Test 2: failed to add chirp because not owner of the channel + + channelID = "/root/lao2/social/" + sender + + args = append(args, input{ + name: "Test 2", + channel: channelID, + msg: newChirpAddMsg(t, channelID, wrongSender, time.Now().Unix(), mockRepository, true), + isError: true, + contains: "only the owner of the channel can post chirps", + }) + + // Test 3: failed to add chirp because negative timestamp + + channelID = "/root/lao3/social/" + sender + + args = append(args, input{ + name: "Test 3", + channel: channelID, + msg: newChirpAddMsg(t, channelID, sender, -1, mockRepository, true), + isError: true, + contains: "invalid message field", + }) + + // Test 4: successfully delete a chirp and notify it + + channelID = "/root/lao4/social/" + sender + + args = append(args, input{ + name: "Test 4", + channel: channelID, + msg: newChirpDeleteMsg(t, channelID, sender, chirpID, time.Now().Unix(), mockRepository, false), + isError: false, + contains: "", + }) + + // Test 5: failed to delete chirp because not owner of the channel + + channelID = "/root/lao5/social/" + sender + + args = append(args, input{ + name: "Test 5", + channel: channelID, + msg: newChirpDeleteMsg(t, channelID, wrongSender, chirpID, time.Now().Unix(), mockRepository, true), + isError: true, + contains: "only the owner of the channel can post chirps", + }) + + // Test 6: failed to delete chirp because negative timestamp + + channelID = "/root/lao6/social/" + sender + + args = append(args, input{ + name: "Test 6", + channel: channelID, + msg: newChirpDeleteMsg(t, channelID, sender, chirpID, -1, mockRepository, true), + isError: true, + contains: "invalid message field", + }) + + // Tests all cases + + for _, arg := range args { + t.Run(arg.name, func(t *testing.T) { + errAnswer := handleChannelChirp(arg.channel, arg.msg) + if arg.isError { + require.NotNil(t, errAnswer) + require.Contains(t, errAnswer.Error(), arg.contains) + } else { + require.Nil(t, errAnswer) + } + }) + } + +} + +func newChirpAddMsg(t *testing.T, channelID string, sender string, timestamp int64, + mockRepository *repository.MockRepository, isError bool) message.Message { + + msg := generatortest.NewChirpAddMsg(t, sender, nil, timestamp) + + errAnswer := state.AddChannel(channelID) + require.Nil(t, errAnswer) + + if isError { + return msg + } + + chirpNotifyChannelID, _ := strings.CutSuffix(channelID, Social+"/"+msg.Sender) + + errAnswer = state.AddChannel(chirpNotifyChannelID) + require.Nil(t, errAnswer) + + mockRepository.On("StoreChirpMessages", channelID, chirpNotifyChannelID, mock.AnythingOfType("message.Message"), + mock.AnythingOfType("message.Message")).Return(nil) + + return msg +} + +func newChirpDeleteMsg(t *testing.T, channelID string, sender string, chirpID string, + timestamp int64, mockRepository *repository.MockRepository, isError bool) message.Message { + + msg := generatortest.NewChirpDeleteMsg(t, sender, nil, chirpID, timestamp) + + errAnswer := state.AddChannel(channelID) + require.Nil(t, errAnswer) + + if isError { + return msg + } + + mockRepository.On("HasMessage", chirpID).Return(true, nil) + + chirpNotifyChannelID, _ := strings.CutSuffix(channelID, Social+"/"+msg.Sender) + + errAnswer = state.AddChannel(chirpNotifyChannelID) + require.Nil(t, errAnswer) + + mockRepository.On("StoreChirpMessages", channelID, chirpNotifyChannelID, mock.AnythingOfType("message.Message"), + mock.AnythingOfType("message.Message")).Return(nil) + + return msg +} diff --git a/be1-go/internal/popserver/handler/coin.go b/be1-go/internal/popserver/handler/coin.go new file mode 100644 index 0000000000..4a769f8841 --- /dev/null +++ b/be1-go/internal/popserver/handler/coin.go @@ -0,0 +1,60 @@ +package handler + +import ( + "popstellar/internal/popserver/database" + "popstellar/message/answer" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" +) + +func handleChannelCoin(channelPath string, msg message.Message) *answer.Error { + object, action, errAnswer := verifyDataAndGetObjectAction(msg) + if errAnswer != nil { + return errAnswer.Wrap("handleChannelCoin") + } + + switch object + "#" + action { + case messagedata.CoinObject + "#" + messagedata.CoinActionPostTransaction: + errAnswer = handleCoinPostTransaction(msg) + default: + errAnswer = answer.NewInvalidMessageFieldError("failed to handle %s#%s, invalid object#action", object, action) + } + if errAnswer != nil { + return errAnswer.Wrap("handleChannelCoin") + } + + db, errAnswer := database.GetCoinRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleChannelCoin") + } + + err := db.StoreMessageAndData(channelPath, msg) + if err != nil { + errAnswer = answer.NewStoreDatabaseError(err.Error()) + return errAnswer.Wrap("handleChannelCoin") + } + + errAnswer = broadcastToAllClients(msg, channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleChannelCoin") + } + + return nil +} + +func handleCoinPostTransaction(msg message.Message) *answer.Error { + var data messagedata.PostTransaction + + errAnswer := msg.UnmarshalMsgData(&data) + if errAnswer != nil { + return errAnswer.Wrap("handleCoinPostTransaction") + } + + err := data.Verify() + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("invalid data: %v", err) + return errAnswer.Wrap("handleCoinPostTransaction") + } + + return nil +} diff --git a/be1-go/internal/popserver/handler/coin_test.go b/be1-go/internal/popserver/handler/coin_test.go new file mode 100644 index 0000000000..e332a89ccf --- /dev/null +++ b/be1-go/internal/popserver/handler/coin_test.go @@ -0,0 +1,192 @@ +package handler + +import ( + "encoding/base64" + "encoding/json" + "github.com/stretchr/testify/require" + "os" + "path/filepath" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/database/repository" + "popstellar/internal/popserver/state" + "popstellar/internal/popserver/types" + "popstellar/message/messagedata" + "popstellar/message/query/method" + "popstellar/message/query/method/message" + "popstellar/network/socket" + "testing" +) + +const coinPath string = "../../../validation/protocol/examples/messageData/coin" + +type inputTestHandleChannelCoin struct { + name string + channelID string + message message.Message + hasError bool + sockets []*socket.FakeSocket +} + +func Test_handleChannelCoin(t *testing.T) { + subs := types.NewSubscribers() + queries := types.NewQueries(&noLog) + peers := types.NewPeers() + + state.SetState(subs, peers, queries) + + mockRepository := repository.NewMockRepository(t) + database.SetDatabase(mockRepository) + + inputs := make([]inputTestHandleChannelCoin, 0) + + // Tests that the channel works correctly when it receives a transaction + + inputs = append(inputs, newSuccessTestHandleChannelCoin(t, + "post_transaction.json", + "send transaction", + mockRepository)) + + // Tests that the channel works correctly when it receives a large transaction + + inputs = append(inputs, newSuccessTestHandleChannelCoin(t, + "post_transaction_max_amount.json", + "send transaction max amount", + mockRepository)) + + // Tests that the channel rejects transactions that exceed the maximum amount + + inputs = append(inputs, newFailTestHandleChannelCoin(t, + "post_transaction_overflow_amount.json", + "send transaction overflow amount")) + + // Tests that the channel accepts transactions with zero amounts + + inputs = append(inputs, newSuccessTestHandleChannelCoin(t, + "post_transaction_zero_amount.json", + "send transaction zero amount", + mockRepository)) + + // Tests that the channel rejects transactions with negative amounts + + inputs = append(inputs, newFailTestHandleChannelCoin(t, + "post_transaction_negative_amount.json", + "send transaction negative amount")) + + // Tests that the channel rejects Transaction with wrong id + + inputs = append(inputs, newFailTestHandleChannelCoin(t, + "post_transaction_wrong_transaction_id.json", + "send transaction wrong id")) + + // Tests that the channel rejects Transaction with bad signature + + inputs = append(inputs, newFailTestHandleChannelCoin(t, + "post_transaction_bad_signature.json", + "send transaction bad signature")) + + // Tests that the channel works correctly when it receives a transaction + + inputs = append(inputs, newSuccessTestHandleChannelCoin(t, + "post_transaction_coinbase.json", + "send transaction coinbase", + mockRepository)) + + // Tests all cases + + for _, i := range inputs { + t.Run(i.name, func(t *testing.T) { + errAnswer := handleChannelCoin(i.channelID, i.message) + if i.hasError { + require.NotNil(t, errAnswer) + } else { + require.Nil(t, errAnswer) + + for _, s := range i.sockets { + require.NotNil(t, s.Msg) + + var msg method.Broadcast + err := json.Unmarshal(s.Msg, &msg) + require.NoError(t, err) + + require.Equal(t, i.message, msg.Params.Message) + } + } + }) + } + +} + +func newSuccessTestHandleChannelCoin(t *testing.T, filename string, name string, mockRepository *repository.MockRepository) inputTestHandleChannelCoin { + laoID := messagedata.Hash(name) + var sender = "M5ZychEi5rwm22FjwjNuljL1qMJWD2sE7oX9fcHNMDU=" + var channelID = "/root/" + laoID + "/coin" + + file := filepath.Join(coinPath, filename) + buf, err := os.ReadFile(file) + require.NoError(t, err) + + buf64 := base64.URLEncoding.EncodeToString(buf) + + m := message.Message{ + Data: buf64, + Sender: sender, + Signature: "h", + MessageID: messagedata.Hash(buf64, "h"), + WitnessSignatures: []message.WitnessSignature{}, + } + + mockRepository.On("StoreMessageAndData", channelID, m).Return(nil) + + sockets := []*socket.FakeSocket{ + {Id: laoID + "0"}, + {Id: laoID + "1"}, + {Id: laoID + "2"}, + {Id: laoID + "3"}, + } + + errAnswer := state.AddChannel(channelID) + require.Nil(t, errAnswer) + + for _, s := range sockets { + errAnswer := state.Subscribe(s, channelID) + require.Nil(t, errAnswer) + } + + return inputTestHandleChannelCoin{ + name: name, + channelID: channelID, + message: m, + hasError: false, + sockets: sockets, + } +} + +func newFailTestHandleChannelCoin(t *testing.T, filename string, name string) inputTestHandleChannelCoin { + laoID := messagedata.Hash(name) + var sender = "M5ZychEi5rwm22FjwjNuljL1qMJWD2sE7oX9fcHNMDU=" + var channelID = "/root/" + laoID + "/coin" + + file := filepath.Join(coinPath, filename) + buf, err := os.ReadFile(file) + require.NoError(t, err) + + buf64 := base64.URLEncoding.EncodeToString(buf) + + m := message.Message{ + Data: buf64, + Sender: sender, + Signature: "h", + MessageID: messagedata.Hash(buf64, "h"), + WitnessSignatures: []message.WitnessSignature{}, + } + + errAnswer := state.AddChannel(channelID) + require.Nil(t, errAnswer) + + return inputTestHandleChannelCoin{ + name: name, + channelID: channelID, + message: m, + hasError: true, + } +} diff --git a/be1-go/internal/popserver/handler/election.go b/be1-go/internal/popserver/handler/election.go new file mode 100644 index 0000000000..4b793bf052 --- /dev/null +++ b/be1-go/internal/popserver/handler/election.go @@ -0,0 +1,578 @@ +package handler + +import ( + "bytes" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "popstellar/crypto" + "popstellar/internal/popserver/config" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/types" + "popstellar/message/answer" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "sort" +) + +const ( + voteFlag = "Vote" +) + +func handleChannelElection(channelPath string, msg message.Message) *answer.Error { + object, action, errAnswer := verifyDataAndGetObjectAction(msg) + if errAnswer != nil { + return errAnswer.Wrap("handleChannelElection") + } + + storeMessage := true + + switch object + "#" + action { + case messagedata.ElectionObject + "#" + messagedata.VoteActionCastVote: + errAnswer = handleVoteCastVote(msg, channelPath) + case messagedata.ElectionObject + "#" + messagedata.ElectionActionOpen: + errAnswer = handleElectionOpen(msg, channelPath) + case messagedata.ElectionObject + "#" + messagedata.ElectionActionEnd: + errAnswer = handleElectionEnd(msg, channelPath) + storeMessage = false + default: + errAnswer = answer.NewInvalidMessageFieldError("failed to handle %s#%s, invalid object#action", object, action) + } + if errAnswer != nil { + return errAnswer.Wrap("handleChannelElection") + } + + if storeMessage { + db, errAnswer := database.GetElectionRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleChannelElection") + } + + err := db.StoreMessageAndData(channelPath, msg) + if err != nil { + errAnswer = answer.NewStoreDatabaseError(err.Error()) + return errAnswer.Wrap("handleChannelElection") + } + } + + return nil +} + +func handleVoteCastVote(msg message.Message, channelPath string) *answer.Error { + var voteCastVote messagedata.VoteCastVote + errAnswer := msg.UnmarshalMsgData(&voteCastVote) + if errAnswer != nil { + return errAnswer.Wrap("handleVoteCastVote") + } + + errAnswer = verifySenderElection(msg, channelPath, false) + if errAnswer != nil { + return errAnswer.Wrap("handleVoteCastVote") + } + + //verify message data + errAnswer = voteCastVote.Verify(channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleVoteCastVote") + } + + db, errAnswer := database.GetElectionRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleVoteCastVote") + } + + if voteCastVote.CreatedAt < 0 { + errAnswer := answer.NewInvalidMessageFieldError("cast vote created at is negative") + return errAnswer.Wrap("handleVoteCastVote") + } + + // verify VoteCastVote created after election createdAt + createdAt, err := db.GetElectionCreationTime(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("election creation time: %v", err) + return errAnswer.Wrap("handleVoteCastVote") + } + + if createdAt > voteCastVote.CreatedAt { + errAnswer := answer.NewInvalidMessageFieldError("cast vote cannot have a creation time prior to election setup") + return errAnswer.Wrap("handleVoteCastVote") + } + + // verify votes + for i, vote := range voteCastVote.Votes { + err := verifyVote(vote, channelPath, voteCastVote.Election) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to validate vote %d: %v", i, err) + return errAnswer.Wrap("handleVoteCastVote") + } + } + + // Just store the vote cast if the election has ended because will not have any influence on the result + ended, err := db.IsElectionEnded(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("election end status: %v", err) + return errAnswer.Wrap("handleVoteCastVote") + } + if ended { + return nil + } + + // verify that the election is open + started, err := db.IsElectionStarted(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("election start status: %v", err) + return errAnswer.Wrap("handleVoteCastVote") + } + + if !started { + errAnswer := answer.NewInvalidMessageFieldError("election is not started") + return errAnswer.Wrap("handleVoteCastVote") + } + + errAnswer = broadcastToAllClients(msg, channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleVoteCastVote") + } + + return nil +} + +func handleElectionOpen(msg message.Message, channelPath string) *answer.Error { + var electionOpen messagedata.ElectionOpen + errAnswer := msg.UnmarshalMsgData(&electionOpen) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionOpen") + } + + errAnswer = verifySenderElection(msg, channelPath, true) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionOpen") + } + + // verify message data + errAnswer = electionOpen.Verify(channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionOpen") + } + + db, errAnswer := database.GetElectionRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleElectionOpen") + } + + // verify if the election was already started or terminated + ok, err := db.IsElectionStartedOrEnded(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("election start or termination status: %v", err) + return errAnswer.Wrap("handleElectionOpen") + } + if ok { + errAnswer := answer.NewInvalidMessageFieldError("election is already started or ended") + return errAnswer.Wrap("handleElectionOpen") + } + + createdAt, err := db.GetElectionCreationTime(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("election creation time: %v", err) + return errAnswer.Wrap("handleElectionOpen") + } + if electionOpen.OpenedAt < createdAt { + errAnswer := answer.NewInvalidMessageFieldError("election open cannot have a creation time prior to election setup") + return errAnswer.Wrap("handleElectionOpen") + } + + errAnswer = broadcastToAllClients(msg, channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionOpen") + } + + return nil +} + +func handleElectionEnd(msg message.Message, channelPath string) *answer.Error { + var electionEnd messagedata.ElectionEnd + errAnswer := msg.UnmarshalMsgData(&electionEnd) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionEnd") + } + + errAnswer = verifySenderElection(msg, channelPath, true) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionEnd") + } + + errAnswer = verifyElectionEnd(electionEnd, channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionEnd") + } + + db, errAnswer := database.GetElectionRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleElectionEnd") + } + + questions, err := db.GetElectionQuestionsWithValidVotes(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("election questions: %v", err) + return errAnswer.Wrap("handleElectionEnd") + } + + if len(electionEnd.RegisteredVotes) != 0 { + errAnswer = verifyRegisteredVotes(electionEnd, questions) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionEnd") + } + } + + electionResultMsg, errAnswer := createElectionResult(questions, channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionEnd") + } + + err = db.StoreElectionEndWithResult(channelPath, msg, electionResultMsg) + if err != nil { + errAnswer := answer.NewStoreDatabaseError("election end and election result: %v", err) + return errAnswer.Wrap("handleElectionEnd") + } + + errAnswer = broadcastToAllClients(msg, channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionEnd") + } + errAnswer = broadcastToAllClients(electionResultMsg, channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionEnd") + } + + return nil +} + +func verifyElectionEnd(electionEnd messagedata.ElectionEnd, channelPath string) *answer.Error { + // verify message data + errAnswer := electionEnd.Verify(channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionEnd") + + } + + db, errAnswer := database.GetElectionRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleElectionEnd") + } + + // verify if the election is started + started, err := db.IsElectionStarted(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("election start status: %v", err) + return errAnswer.Wrap("handleElectionEnd") + } + if !started { + errAnswer := answer.NewInvalidMessageFieldError("election was not started") + return errAnswer.Wrap("handleElectionEnd") + } + + // verify if the timestamp is stale + createdAt, err := db.GetElectionCreationTime(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("election creation time: %v", err) + return errAnswer.Wrap("handleElectionEnd") + } + if electionEnd.CreatedAt < createdAt { + errAnswer := answer.NewInvalidMessageFieldError("election end cannot have a creation time prior to election setup") + return errAnswer.Wrap("handleElectionEnd") + } + + return nil +} + +func verifySenderElection(msg message.Message, channelPath string, onlyOrganizer bool) *answer.Error { + senderBuf, err := base64.URLEncoding.DecodeString(msg.Sender) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to decode sender: %v", err) + return errAnswer.Wrap("verifySender") + } + senderPubKey := crypto.Suite.Point() + err = senderPubKey.UnmarshalBinary(senderBuf) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to unmarshal sender: %v", err) + return errAnswer.Wrap("verifySender") + } + + db, errAnswer := database.GetElectionRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("verifySender") + } + + organizerPubKey, err := db.GetLAOOrganizerPubKey(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("lao organizer pk: %v", err) + return errAnswer.Wrap("verifySender") + } + + if onlyOrganizer && !senderPubKey.Equal(organizerPubKey) { + errAnswer := answer.NewInvalidMessageFieldError("sender is not the organizer of the channel") + return errAnswer.Wrap("verifySender") + } + + if onlyOrganizer { + return nil + } + + attendees, err := db.GetElectionAttendees(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("election attendees: %v", err) + return errAnswer.Wrap("verifySender") + } + + _, ok := attendees[msg.Sender] + if !ok && !senderPubKey.Equal(organizerPubKey) { + errAnswer := answer.NewInvalidMessageFieldError("sender is not an attendee or the organizer of the election") + return errAnswer.Wrap("verifySender") + } + + return nil +} + +func verifyVote(vote messagedata.Vote, channelPath, electionID string) *answer.Error { + db, errAnswer := database.GetElectionRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleElectionOpen") + } + + questions, err := db.GetElectionQuestions(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("election questions: %v", err) + return errAnswer.Wrap("verifyVote") + } + question, ok := questions[vote.Question] + if !ok { + errAnswer := answer.NewInvalidMessageFieldError("Question does not exist") + return errAnswer.Wrap("verifyVote") + } + electionType, err := db.GetElectionType(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("election type: %v", err) + return errAnswer.Wrap("verifyVote") + } + var voteString string + switch electionType { + case messagedata.OpenBallot: + voteInt, ok := vote.Vote.(int) + if !ok { + errAnswer := answer.NewInvalidMessageFieldError("vote in open ballot should be an integer") + return errAnswer.Wrap("verifyVote") + } + voteString = fmt.Sprintf("%d", voteInt) + case messagedata.SecretBallot: + voteString, ok = vote.Vote.(string) + if !ok { + errAnswer := answer.NewInvalidMessageFieldError("vote in secret ballot should be a string") + return errAnswer.Wrap("verifyVote") + } + voteBytes, err := base64.URLEncoding.DecodeString(voteString) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("vote should be base64 encoded: %v", err) + return errAnswer.Wrap("verifyVote") + } + if len(voteBytes) != 64 { + errAnswer := answer.NewInvalidMessageFieldError("vote should be 64 bytes long") + return errAnswer.Wrap("verifyVote") + } + default: + errAnswer := answer.NewInvalidMessageFieldError("invalid election type: %s", electionType) + return errAnswer.Wrap("verifyVote") + } + hash := messagedata.Hash(voteFlag, electionID, string(question.ID), voteString) + if vote.ID != hash { + errAnswer := answer.NewInvalidMessageFieldError("vote ID is not the expected hash") + return errAnswer.Wrap("verifyVote") + } + return nil +} + +func verifyRegisteredVotes(electionEnd messagedata.ElectionEnd, questions map[string]types.Question) *answer.Error { + var voteIDs []string + for _, question := range questions { + for _, validVote := range question.ValidVotes { + voteIDs = append(voteIDs, validVote.ID) + } + } + // sort vote IDs + sort.Strings(voteIDs) + + // hash all valid vote ids + validVotesHash := messagedata.Hash(voteIDs...) + + // compare registered votes with local saved votes + if electionEnd.RegisteredVotes != validVotesHash { + errAnswer := answer.NewInvalidMessageFieldError("registered votes is %s, should be sorted and equal to %s", + electionEnd.RegisteredVotes, validVotesHash) + return errAnswer.Wrap("verifyRegisteredVotes") + } + return nil +} + +func createElectionResult(questions map[string]types.Question, channelPath string) (message.Message, *answer.Error) { + resultElection, errAnswer := computeElectionResult(questions, channelPath) + if errAnswer != nil { + return message.Message{}, errAnswer.Wrap("createElectionResult") + } + + buf, err := json.Marshal(resultElection) + if err != nil { + errAnswer := answer.NewInternalServerError("marshal election result: %v", err) + return message.Message{}, errAnswer.Wrap("createElectionResult") + } + buf64 := base64.URLEncoding.EncodeToString(buf) + + serverPubKey, errAnswer := config.GetServerPublicKeyInstance() + if errAnswer != nil { + return message.Message{}, errAnswer.Wrap("createElectionResult") + } + serverPubBuf, err := serverPubKey.MarshalBinary() + if err != nil { + errAnswer := answer.NewInternalServerError("failed to marshal server public key: %v", err) + return message.Message{}, errAnswer.Wrap("createElectionResult") + } + signatureBuf, errAnswer := Sign(buf) + if errAnswer != nil { + return message.Message{}, errAnswer.Wrap("createElectionResult") + } + + signature := base64.URLEncoding.EncodeToString(signatureBuf) + + electionResultMsg := message.Message{ + Data: buf64, + Sender: base64.URLEncoding.EncodeToString(serverPubBuf), + Signature: signature, + MessageID: messagedata.Hash(buf64, signature), + WitnessSignatures: []message.WitnessSignature{}, + } + + return electionResultMsg, nil +} + +func computeElectionResult(questions map[string]types.Question, channelPath string) ( + messagedata.ElectionResult, *answer.Error) { + db, errAnswer := database.GetElectionRepositoryInstance() + if errAnswer != nil { + return messagedata.ElectionResult{}, errAnswer.Wrap("computeElectionResult") + } + + electionType, err := db.GetElectionType(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("election type: %v", err) + return messagedata.ElectionResult{}, errAnswer.Wrap("computeElectionResult") + } + + result := make([]messagedata.ElectionResultQuestion, 0) + + for id, question := range questions { + if question.Method != messagedata.PluralityMethod { + continue + } + votesPerBallotOption := make([]int, len(question.BallotOptions)) + for _, validVote := range question.ValidVotes { + index, ok := getVoteIndex(validVote, electionType, channelPath) + if ok && index >= 0 && index < len(question.BallotOptions) { + votesPerBallotOption[index]++ + } + } + var questionResults []messagedata.ElectionResultQuestionResult + for i, options := range question.BallotOptions { + questionResults = append(questionResults, messagedata.ElectionResultQuestionResult{ + BallotOption: options, + Count: votesPerBallotOption[i], + }) + } + + electionResult := messagedata.ElectionResultQuestion{ + ID: id, + Result: questionResults, + } + result = append(result, electionResult) + } + + resultElection := messagedata.ElectionResult{ + Object: messagedata.ElectionObject, + Action: messagedata.ElectionActionResult, + Questions: result, + } + + return resultElection, nil +} + +func getVoteIndex(vote types.ValidVote, electionType, channelPath string) (int, bool) { + switch electionType { + case messagedata.OpenBallot: + index, _ := vote.Index.(int) + return index, true + + case messagedata.SecretBallot: + encryptedVote, _ := vote.Index.(string) + index, err := decryptVote(encryptedVote, channelPath) + if err != nil { + return index, false + } + return index, true + } + return -1, false +} + +func decryptVote(vote, channelPath string) (int, *answer.Error) { + voteBuff, err := base64.URLEncoding.DecodeString(vote) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to decode vote: %v", err) + return -1, errAnswer.Wrap("decryptVote") + } + if len(voteBuff) != 64 { + errAnswer := answer.NewInvalidMessageFieldError("vote should be 64 bytes long") + return -1, errAnswer.Wrap("decryptVote") + } + + // K and C are respectively the first and last 32 bytes of the vote + K := crypto.Suite.Point() + C := crypto.Suite.Point() + + err = K.UnmarshalBinary(voteBuff[:32]) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to unmarshal K: %v", err) + return -1, errAnswer.Wrap("decryptVote") + } + err = C.UnmarshalBinary(voteBuff[32:]) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to unmarshal C: %v", err) + return -1, errAnswer.Wrap("decryptVote") + } + + db, errAnswer := database.GetElectionRepositoryInstance() + if errAnswer != nil { + return -1, errAnswer.Wrap("decryptVote") + } + + electionSecretKey, err := db.GetElectionSecretKey(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("election secret key: %v", err) + return -1, errAnswer.Wrap("decryptVote") + } + + // performs the ElGamal decryption + S := crypto.Suite.Point().Mul(electionSecretKey, K) + data, err := crypto.Suite.Point().Sub(C, S).Data() + if err != nil { + errAnswer := answer.NewInternalServerError("failed to decrypt vote: %v", err) + return -1, errAnswer.Wrap("decryptVote") + } + + var index uint16 + + // interprets the data as a big endian int + buf := bytes.NewReader(data) + err = binary.Read(buf, binary.BigEndian, &index) + if err != nil { + errAnswer := answer.NewInternalServerError("failed to interpret decrypted data: %v", err) + return -1, errAnswer.Wrap("decryptVote") + } + return int(index), nil +} diff --git a/be1-go/internal/popserver/handler/election_test.go b/be1-go/internal/popserver/handler/election_test.go new file mode 100644 index 0000000000..89bf46e852 --- /dev/null +++ b/be1-go/internal/popserver/handler/election_test.go @@ -0,0 +1,540 @@ +package handler + +import ( + "encoding/base64" + "fmt" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.dedis.ch/kyber/v3" + "popstellar/crypto" + "popstellar/internal/popserver/config" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/database/repository" + "popstellar/internal/popserver/generatortest" + state "popstellar/internal/popserver/state" + "popstellar/internal/popserver/types" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "testing" +) + +func Test_handleChannelElection(t *testing.T) { + var args []input + + mockRepository := repository.NewMockRepository(t) + database.SetDatabase(mockRepository) + + subs := types.NewSubscribers() + queries := types.NewQueries(&noLog) + peers := types.NewPeers() + + ownerPubBuf, err := base64.URLEncoding.DecodeString(ownerPubBuf64) + require.NoError(t, err) + + ownerPublicKey := crypto.Suite.Point() + err = ownerPublicKey.UnmarshalBinary(ownerPubBuf) + require.NoError(t, err) + + serverSecretKey := crypto.Suite.Scalar().Pick(crypto.Suite.RandomStream()) + serverPublicKey := crypto.Suite.Point().Mul(serverSecretKey, nil) + + config.SetConfig(ownerPublicKey, serverPublicKey, serverSecretKey, "clientAddress", "serverAddress") + + state.SetState(subs, peers, queries) + + laoID := base64.URLEncoding.EncodeToString([]byte("laoID")) + electionID := base64.URLEncoding.EncodeToString([]byte("electionID")) + channelPath := "/root/" + laoID + "/" + electionID + + // Test 1 Error when ElectionOpen sender is not the same as the lao organizer + args = append(args, input{ + name: "Test 1", + msg: newElectionOpenMsg(t, ownerPublicKey, wrongSender, laoID, electionID, channelPath, "", + -1, true, mockRepository), + channel: channelPath, + isError: true, + contains: "sender is not the organizer of the channel", + }) + + wrongChannelPath := "/root/" + base64.URLEncoding.EncodeToString([]byte("wrongLaoID")) + "/" + electionID + + // Test 2 Error when ElectionOpen lao id is not the same as the channel + args = append(args, input{ + name: "Test 2", + msg: newElectionOpenMsg(t, ownerPublicKey, ownerPubBuf64, laoID, electionID, wrongChannelPath, "", + -1, true, mockRepository), + channel: wrongChannelPath, + isError: true, + contains: "lao id is not the same as the channel", + }) + + wrongChannelPath = "/root/" + laoID + "/" + base64.URLEncoding.EncodeToString([]byte("wrongElectionID")) + + // Test 3 Error when ElectionOpen election id is not the same as the channel + args = append(args, input{ + name: "Test 3", + msg: newElectionOpenMsg(t, ownerPublicKey, ownerPubBuf64, laoID, electionID, wrongChannelPath, "", + -1, true, mockRepository), + channel: wrongChannelPath, + isError: true, + contains: "election id is not the same as the channel", + }) + + // Test 4 Error when Election is already started or ended + args = append(args, input{ + name: "Test 4", + msg: newElectionOpenMsg(t, ownerPublicKey, ownerPubBuf64, laoID, electionID, channelPath, messagedata.ElectionActionOpen, + -1, true, mockRepository), + channel: channelPath, + isError: true, + contains: "election is already started or ended", + }) + + //to avoid conflicts with the previous test + electionID = base64.URLEncoding.EncodeToString([]byte("electionID2")) + channelPath = "/root/" + laoID + "/" + electionID + // Test 5 Error when ElectionOpen opened at before createdAt + args = append(args, input{ + name: "Test 5", + msg: newElectionOpenMsg(t, ownerPublicKey, ownerPubBuf64, laoID, electionID, channelPath, messagedata.ElectionActionSetup, + 2, true, mockRepository), + channel: channelPath, + isError: true, + contains: "election open cannot have a creation time prior to election setup", + }) + + //to avoid conflicts with the previous test + electionID = base64.URLEncoding.EncodeToString([]byte("electionID3")) + channelPath = "/root/" + laoID + "/" + electionID + + errAnswer := state.AddChannel(channelPath) + require.Nil(t, errAnswer) + + // Test 6: Success when ElectionOpen is valid + args = append(args, input{ + name: "Test 6", + msg: newElectionOpenMsg(t, ownerPublicKey, ownerPubBuf64, laoID, electionID, channelPath, messagedata.ElectionActionSetup, + 1, false, mockRepository), + channel: channelPath, + isError: false, + contains: "", + }) + + laoID = base64.URLEncoding.EncodeToString([]byte("electionID4")) + channelPath = "/root/" + laoID + "/" + electionID + + // Test 7 Error when ElectionEnd sender is not the same as the lao organizer + args = append(args, input{ + name: "Test 7", + msg: newElectionEndMsg(t, ownerPublicKey, wrongSender, laoID, electionID, channelPath, "", "", + -1, true, mockRepository), + channel: channelPath, + isError: true, + contains: "sender is not the organizer of the channel", + }) + + wrongChannelPath = "/root/" + base64.URLEncoding.EncodeToString([]byte("wrongLaoID2")) + "/" + electionID + + // Test 8 Error when ElectionEnd lao id is not the same as the channel + args = append(args, input{ + name: "Test 8", + msg: newElectionEndMsg(t, ownerPublicKey, ownerPubBuf64, laoID, electionID, wrongChannelPath, "", "", + -1, true, mockRepository), + channel: wrongChannelPath, + isError: true, + contains: "lao id is not the same as the channel", + }) + + wrongChannelPath = "/root/" + laoID + "/" + base64.URLEncoding.EncodeToString([]byte("wrongElectionID2")) + + // Test 9 Error when ElectionEnd election id is not the same as the channel + args = append(args, input{ + name: "Test 9", + msg: newElectionEndMsg(t, ownerPublicKey, ownerPubBuf64, laoID, electionID, wrongChannelPath, "", "", + -1, true, mockRepository), + channel: wrongChannelPath, + isError: true, + contains: "election id is not the same as the channel", + }) + + // Test 10 Error when ElectionEnd is not started + args = append(args, input{ + name: "Test 10", + msg: newElectionEndMsg(t, ownerPublicKey, ownerPubBuf64, laoID, electionID, channelPath, messagedata.ElectionActionEnd, "", + -1, true, mockRepository), + channel: channelPath, + isError: true, + contains: "election was not started", + }) + + //to avoid conflicts with the previous test + electionID = base64.URLEncoding.EncodeToString([]byte("electionID5")) + channelPath = "/root/" + laoID + "/" + electionID + + // Test 11 Error when ElectionEnd creation time is before ElectionSetup creation time + args = append(args, input{ + name: "Test 11", + msg: newElectionEndMsg(t, ownerPublicKey, ownerPubBuf64, laoID, electionID, channelPath, messagedata.ElectionActionOpen, "", + 2, true, mockRepository), + channel: channelPath, + isError: true, + contains: "election end cannot have a creation time prior to election setup", + }) + + //to avoid conflicts with the previous test + electionID = base64.URLEncoding.EncodeToString([]byte("electionID6")) + channelPath = "/root/" + laoID + "/" + electionID + + wrongVotes := messagedata.Hash("wrongVotes") + + // Test 12 Error when ElectionEnd is not the expected hash + args = append(args, input{ + name: "Test 12", + msg: newElectionEndMsg(t, ownerPublicKey, ownerPubBuf64, laoID, electionID, channelPath, messagedata.ElectionActionOpen, wrongVotes, + 1, true, mockRepository), + channel: channelPath, + isError: true, + contains: fmt.Sprintf("registered votes is %s, should be sorted and equal to", wrongVotes), + }) + + //to avoid conflicts with the previous test + electionID = base64.URLEncoding.EncodeToString([]byte("electionID7")) + channelPath = "/root/" + laoID + "/" + electionID + + registeredVotes := messagedata.Hash("voteID1", "voteID2", "voteID3") + + errAnswer = state.AddChannel(channelPath) + require.Nil(t, errAnswer) + + // Test 13: Success when ElectionEnd is valid + args = append(args, input{ + name: "Test 13", + msg: newElectionEndMsg(t, ownerPublicKey, ownerPubBuf64, laoID, electionID, channelPath, messagedata.ElectionActionOpen, registeredVotes, + 1, false, mockRepository), + channel: channelPath, + isError: false, + contains: "", + }) + + votes := []generatortest.VoteInt{ + { + ID: base64.URLEncoding.EncodeToString([]byte("voteID1")), + Question: base64.URLEncoding.EncodeToString([]byte("questionID1")), + Vote: 1, + }, + } + + // Test 14 Error when VoteCastVote sender is not the same as the lao organizer + args = append(args, input{ + name: "Test 14", + msg: newVoteCastVoteIntMsg(t, wrongSender, laoID, electionID, channelPath, "", "", + -1, votes, nil, ownerPublicKey, mockRepository, true), + channel: channelPath, + isError: true, + contains: "sender is not an attendee or the organizer of the election", + }) + + // Test 15 Error when VoteCastVote lao id is not the same as the channel + wrongChannelPath = "/root/" + base64.URLEncoding.EncodeToString([]byte("wrongLaoID3")) + "/" + electionID + + args = append(args, input{ + name: "Test 15", + msg: newVoteCastVoteIntMsg(t, ownerPubBuf64, laoID, electionID, wrongChannelPath, "", "", + -1, votes, nil, ownerPublicKey, mockRepository, true), + channel: wrongChannelPath, + isError: true, + contains: "lao id is not the same as the channel", + }) + + // Test 16 Error when VoteCastVote election id is not the same as the channel + wrongChannelPath = "/root/" + laoID + "/" + base64.URLEncoding.EncodeToString([]byte("wrongElectionID3")) + + args = append(args, input{ + name: "Test 16", + msg: newVoteCastVoteIntMsg(t, ownerPubBuf64, laoID, electionID, wrongChannelPath, "", "", + -1, votes, nil, ownerPublicKey, mockRepository, true), + channel: wrongChannelPath, + isError: true, + contains: "election id is not the same as the channel", + }) + + //to avoid conflicts with the previous test + electionID = base64.URLEncoding.EncodeToString([]byte("electionID9")) + channelPath = "/root/" + laoID + "/" + electionID + + //Test 17 Error when VoteCastVote createdAt is before electionSetup createdAt + args = append(args, input{ + name: "Test 17", + msg: newVoteCastVoteIntMsg(t, ownerPubBuf64, laoID, electionID, channelPath, "", "", + 2, votes, nil, ownerPublicKey, mockRepository, true), + channel: channelPath, + isError: true, + contains: "cast vote cannot have a creation time prior to election setup", + }) + + //to avoid conflicts with the previous test + electionID = base64.URLEncoding.EncodeToString([]byte("electionID10")) + channelPath = "/root/" + laoID + "/" + electionID + + //Test 18 Error when VoteCastVote question is not present in election setup + questions := map[string]types.Question{ + base64.URLEncoding.EncodeToString([]byte("questionID2")): {ID: []byte(base64.URLEncoding.EncodeToString([]byte("questionID2")))}, + base64.URLEncoding.EncodeToString([]byte("questionID3")): {ID: []byte(base64.URLEncoding.EncodeToString([]byte("questionID3")))}, + } + + args = append(args, input{ + name: "Test 18", + msg: newVoteCastVoteIntMsg(t, ownerPubBuf64, laoID, electionID, channelPath, "", "", + 0, votes, questions, ownerPublicKey, mockRepository, true), + channel: channelPath, + isError: true, + contains: "Question does not exist", + }) + + //to avoid conflicts with the previous test + electionID = base64.URLEncoding.EncodeToString([]byte("electionID11")) + channelPath = "/root/" + laoID + "/" + electionID + + //Test 19 Error when VoteCastVote contains a string vote in an OpenBallot election + stringVotes := []generatortest.VoteString{ + { + ID: base64.URLEncoding.EncodeToString([]byte("voteID1")), + Question: base64.URLEncoding.EncodeToString([]byte("questionID2")), + Vote: base64.URLEncoding.EncodeToString([]byte("1")), + }, + } + + args = append(args, input{ + name: "Test 19", + msg: newVoteCastVoteStringMsg(t, ownerPubBuf64, laoID, electionID, channelPath, messagedata.OpenBallot, + 0, stringVotes, questions, ownerPublicKey, mockRepository), + channel: channelPath, + isError: true, + contains: "vote in open ballot should be an integer", + }) + + //to avoid conflicts with the previous test + electionID = base64.URLEncoding.EncodeToString([]byte("electionID12")) + channelPath = "/root/" + laoID + "/" + electionID + + //Test 20 Error when VoteCastVote contains a int vote in an SecretBallot election + intVotes := []generatortest.VoteInt{ + { + ID: base64.URLEncoding.EncodeToString([]byte("voteID1")), + Question: base64.URLEncoding.EncodeToString([]byte("questionID2")), + Vote: 1, + }, + } + + args = append(args, input{ + name: "Test 20", + msg: newVoteCastVoteIntMsg(t, ownerPubBuf64, laoID, electionID, channelPath, "", messagedata.SecretBallot, + 0, intVotes, questions, ownerPublicKey, mockRepository, true), + channel: channelPath, + isError: true, + contains: "vote in secret ballot should be a string", + }) + + //to avoid conflicts with the previous test + electionID = base64.URLEncoding.EncodeToString([]byte("electionID13")) + channelPath = "/root/" + laoID + "/" + electionID + + //Test 21 Error when a vote ID in VoteCastVote is not the expected hash + + args = append(args, input{ + name: "Test 21", + msg: newVoteCastVoteIntMsg(t, ownerPubBuf64, laoID, electionID, channelPath, "", messagedata.OpenBallot, + 0, intVotes, questions, ownerPublicKey, mockRepository, true), + channel: channelPath, + isError: true, + contains: "vote ID is not the expected hash", + }) + + //to avoid conflicts with the previous test + electionID = base64.URLEncoding.EncodeToString([]byte("electionID14")) + channelPath = "/root/" + laoID + "/" + electionID + + //Test 22 Success when election is already ended + questionID := base64.URLEncoding.EncodeToString([]byte("questionID2")) + voteID := messagedata.Hash(voteFlag, electionID, questionID, "1") + + votes = []generatortest.VoteInt{ + { + ID: voteID, + Question: questionID, + Vote: 1, + }, + } + + errAnswer = subs.AddChannel(channelPath) + require.Nil(t, errAnswer) + + args = append(args, input{ + name: "Test 22", + msg: newVoteCastVoteIntMsg(t, ownerPubBuf64, laoID, electionID, channelPath, messagedata.ElectionActionEnd, messagedata.OpenBallot, + 0, votes, questions, ownerPublicKey, mockRepository, false), + channel: channelPath, + isError: false, + contains: "", + }) + + //to avoid conflicts with the previous test + electionID = base64.URLEncoding.EncodeToString([]byte("electionID14")) + channelPath = "/root/" + laoID + "/" + electionID + + //Test 23 Success when election is started + args = append(args, input{ + name: "Test 23", + msg: newVoteCastVoteIntMsg(t, ownerPubBuf64, laoID, electionID, channelPath, messagedata.ElectionActionOpen, "", + -1, votes, nil, ownerPublicKey, mockRepository, true), + channel: channelPath, + isError: false, + contains: "", + }) + + for _, arg := range args { + t.Run(arg.name, func(t *testing.T) { + errAnswer := handleChannelElection(arg.channel, arg.msg) + if arg.isError { + require.Contains(t, errAnswer.Error(), arg.contains) + } else { + require.Nil(t, errAnswer) + } + }) + } +} + +func newElectionOpenMsg(t *testing.T, owner kyber.Point, sender, laoID, electionID, channelPath, state string, + createdAt int64, isError bool, mockRepository *repository.MockRepository) message.Message { + + msg := generatortest.NewElectionOpenMsg(t, sender, laoID, electionID, 1, nil) + + mockRepository.On("GetLAOOrganizerPubKey", channelPath).Return(owner, nil) + + if createdAt >= 0 { + mockRepository.On("GetElectionCreationTime", channelPath).Return(createdAt, nil) + } + + if state != "" { + mockRepository.On("IsElectionStartedOrEnded", channelPath). + Return(state == messagedata.ElectionActionOpen || state == messagedata.ElectionActionEnd, nil) + } + + if !isError { + mockRepository.On("StoreMessageAndData", channelPath, msg).Return(nil) + } + + return msg +} + +func newElectionEndMsg(t *testing.T, owner kyber.Point, sender, laoID, electionID, channelPath, state, votes string, + createdAt int64, isError bool, mockRepository *repository.MockRepository) message.Message { + + msg := generatortest.NewElectionCloseMsg(t, sender, laoID, electionID, votes, 1, nil) + + mockRepository.On("GetLAOOrganizerPubKey", channelPath).Return(owner, nil) + + if state != "" { + mockRepository.On("IsElectionStarted", channelPath). + Return(state == messagedata.ElectionActionOpen, nil) + } + + if createdAt >= 0 { + mockRepository.On("GetElectionCreationTime", channelPath).Return(createdAt, nil) + } + + if votes != "" { + questions := map[string]types.Question{ + "questionID1": { + ID: []byte("questionID1"), + ValidVotes: map[string]types.ValidVote{ + "voteID1": { + ID: "voteID1", + }, + "VoteID2": { + ID: "voteID2", + }, + }, + }, + "questionID2": { + ID: []byte("questionID2"), + ValidVotes: map[string]types.ValidVote{ + "voteID3": { + ID: "voteID3", + }, + }, + }, + } + + mockRepository.On("GetElectionQuestionsWithValidVotes", channelPath).Return(questions, nil) + } + + if !isError { + mockRepository.On("GetElectionType", channelPath).Return(messagedata.OpenBallot, nil) + mockRepository.On("StoreElectionEndWithResult", channelPath, msg, mock.AnythingOfType("message.Message")). + Return(nil) + } + + return msg +} + +func newVoteCastVoteIntMsg(t *testing.T, sender, laoID, electionID, electionPath, state, electionType string, + createdAt int64, votes []generatortest.VoteInt, questions map[string]types.Question, owner kyber.Point, + mockRepository *repository.MockRepository, isEroor bool) message.Message { + + msg := generatortest.NewVoteCastVoteIntMsg(t, sender, laoID, electionID, 1, votes, nil) + mockRepository.On("GetLAOOrganizerPubKey", electionPath).Return(owner, nil) + mockRepository.On("GetElectionAttendees", electionPath).Return(map[string]struct{}{ownerPubBuf64: {}}, nil) + + if state == messagedata.ElectionActionOpen { + mockRepository.On("IsElectionStarted", electionPath). + Return(true, nil) + } + + if state == messagedata.ElectionActionEnd { + mockRepository.On("IsElectionEnded", electionPath). + Return(false, nil) + mockRepository.On("IsElectionStarted", electionPath). + Return(true, nil) + } + + if createdAt >= 0 { + mockRepository.On("GetElectionCreationTime", electionPath).Return(createdAt, nil) + } + + if electionType != "" { + mockRepository.On("GetElectionType", electionPath).Return(electionType, nil) + } + + if questions != nil { + mockRepository.On("GetElectionQuestions", electionPath).Return(questions, nil) + } + + if !isEroor { + mockRepository.On("StoreMessageAndData", electionPath, msg).Return(nil) + } + return msg +} + +func newVoteCastVoteStringMsg(t *testing.T, sender, laoID, electionID, electionPath, electionType string, + createdAt int64, votes []generatortest.VoteString, questions map[string]types.Question, owner kyber.Point, + mockRepository *repository.MockRepository) message.Message { + + msg := generatortest.NewVoteCastVoteStringMsg(t, sender, laoID, electionID, 1, votes, nil) + mockRepository.On("GetLAOOrganizerPubKey", electionPath).Return(owner, nil) + mockRepository.On("GetElectionAttendees", electionPath).Return(map[string]struct{}{ownerPubBuf64: {}}, nil) + + if createdAt >= 0 { + mockRepository.On("GetElectionCreationTime", electionPath).Return(createdAt, nil) + } + + if electionType != "" { + mockRepository.On("GetElectionType", electionPath).Return(electionType, nil) + } + + if questions != nil { + mockRepository.On("GetElectionQuestions", electionPath).Return(questions, nil) + } + + return msg +} diff --git a/be1-go/internal/popserver/handler/incoming_message.go b/be1-go/internal/popserver/handler/incoming_message.go new file mode 100644 index 0000000000..e908cc2082 --- /dev/null +++ b/be1-go/internal/popserver/handler/incoming_message.go @@ -0,0 +1,40 @@ +package handler + +import ( + "popstellar/internal/popserver/utils" + "popstellar/message" + "popstellar/message/answer" + "popstellar/network/socket" + "popstellar/validation" +) + +func HandleIncomingMessage(socket socket.Socket, msg []byte) error { + errAnswer := utils.VerifyJSON(msg, validation.GenericMessage) + if errAnswer != nil { + errAnswer = errAnswer.Wrap("handleMessage") + socket.SendError(nil, errAnswer) + return errAnswer + } + + rpcType, err := message.GetType(msg) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to get rpc type: %v", err).Wrap("handleMessage") + socket.SendError(nil, errAnswer) + return errAnswer + } + + switch rpcType { + case message.RPCTypeQuery: + errAnswer = handleQuery(socket, msg) + case message.RPCTypeAnswer: + errAnswer = handleAnswer(msg) + default: + errAnswer = answer.NewInvalidMessageFieldError("jsonRPC is of unknown type") + } + + if errAnswer != nil { + return errAnswer.Wrap("handleMessage") + } + + return nil +} diff --git a/be1-go/internal/popserver/handler/incoming_message_test.go b/be1-go/internal/popserver/handler/incoming_message_test.go new file mode 100644 index 0000000000..6b51662aa1 --- /dev/null +++ b/be1-go/internal/popserver/handler/incoming_message_test.go @@ -0,0 +1,71 @@ +package handler + +import ( + "encoding/base64" + "fmt" + "github.com/rs/zerolog" + "github.com/stretchr/testify/require" + "io" + "os" + "popstellar/internal/popserver/generatortest" + "popstellar/internal/popserver/utils" + "popstellar/network/socket" + "popstellar/validation" + "testing" +) + +var noLog = zerolog.New(io.Discard) + +func TestMain(m *testing.M) { + schemaValidator, err := validation.NewSchemaValidator() + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + + utils.InitUtils(&noLog, schemaValidator) + + exitVal := m.Run() + + os.Exit(exitVal) +} + +func Test_handleIncomingMessage(t *testing.T) { + type input struct { + name string + message []byte + contains string + } + + args := make([]input, 0) + + // Test 1: failed to handled popanswer because wrong json + + args = append(args, input{ + name: "Test 1", + message: generatortest.NewNothingQuery(t, 999), + contains: "invalid json", + }) + + // Test 2: failed to handled popanswer because wrong publish popanswer format + + msg := generatortest.NewNothingMsg(t, base64.URLEncoding.EncodeToString([]byte("sender")), nil) + msg.MessageID = "wrong messageID" + + args = append(args, input{ + name: "Test 2", + message: generatortest.NewPublishQuery(t, 1, "/root/lao1", msg), + contains: "invalid json", + }) + + // run all tests + + for _, arg := range args { + t.Run(arg.name, func(t *testing.T) { + fakeSocket := socket.FakeSocket{Id: "1"} + err := HandleIncomingMessage(&fakeSocket, arg.message) + require.Error(t, err) + require.Contains(t, err.Error(), arg.contains) + }) + } +} diff --git a/be1-go/internal/popserver/handler/lao.go b/be1-go/internal/popserver/handler/lao.go new file mode 100644 index 0000000000..92e02aebe5 --- /dev/null +++ b/be1-go/internal/popserver/handler/lao.go @@ -0,0 +1,478 @@ +package handler + +import ( + "encoding/base64" + "encoding/json" + "go.dedis.ch/kyber/v3" + "go.dedis.ch/kyber/v3/sign/schnorr" + "golang.org/x/exp/slices" + "popstellar/crypto" + "popstellar/internal/popserver/config" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/state" + "popstellar/message/answer" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "strings" +) + +func handleChannelLao(channelPath string, msg message.Message) *answer.Error { + object, action, errAnswer := verifyDataAndGetObjectAction(msg) + if errAnswer != nil { + return errAnswer.Wrap("handleChannelLao") + } + + storeMessage := true + switch object + "#" + action { + case messagedata.LAOObject + "#" + messagedata.LAOActionState: + errAnswer = handleLaoState(msg, channelPath) + case messagedata.LAOObject + "#" + messagedata.LAOActionUpdate: + errAnswer = handleLaoUpdate(msg) + case messagedata.MessageObject + "#" + messagedata.MessageActionWitness: + errAnswer = handleMessageWitness(msg) + case messagedata.MeetingObject + "#" + messagedata.MeetingActionCreate: + errAnswer = handleMeetingCreate(msg) + case messagedata.MeetingObject + "#" + messagedata.MeetingActionState: + errAnswer = handleMeetingState(msg) + case messagedata.RollCallObject + "#" + messagedata.RollCallActionClose: + storeMessage = false + errAnswer = handleRollCallClose(msg, channelPath) + case messagedata.RollCallObject + "#" + messagedata.RollCallActionCreate: + errAnswer = handleRollCallCreate(msg, channelPath) + case messagedata.RollCallObject + "#" + messagedata.RollCallActionOpen: + errAnswer = handleRollCallOpen(msg, channelPath) + case messagedata.RollCallObject + "#" + messagedata.RollCallActionReOpen: + errAnswer = handleRollCallReOpen(msg, channelPath) + case messagedata.ElectionObject + "#" + messagedata.ElectionActionSetup: + storeMessage = false + errAnswer = handleElectionSetup(msg, channelPath) + default: + errAnswer = answer.NewInvalidMessageFieldError("failed to handle %s#%s, invalid object#action", object, action) + } + if errAnswer != nil { + return errAnswer.Wrap("handleChannelLao") + } + + if storeMessage { + db, errAnswer := database.GetLAORepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleChannelLao") + } + + err := db.StoreMessageAndData(channelPath, msg) + if err != nil { + errAnswer = answer.NewStoreDatabaseError("message: %v", err) + return errAnswer.Wrap("handleChannelLao") + } + } + + errAnswer = broadcastToAllClients(msg, channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleChannelLao") + } + return nil +} + +func handleRollCallCreate(msg message.Message, channelPath string) *answer.Error { + var rollCallCreate messagedata.RollCallCreate + errAnswer := msg.UnmarshalMsgData(&rollCallCreate) + if errAnswer != nil { + return errAnswer.Wrap("handleRollCallCreate") + } + + errAnswer = rollCallCreate.Verify(channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleRollCallCreate") + } + + return nil +} + +func handleRollCallOpen(msg message.Message, channelPath string) *answer.Error { + var rollCallOpen messagedata.RollCallOpen + errAnswer := msg.UnmarshalMsgData(&rollCallOpen) + if errAnswer != nil { + return errAnswer.Wrap("handleRollCallOpen") + } + + errAnswer = rollCallOpen.Verify(channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleRollCallOpen") + } + + db, errAnswer := database.GetLAORepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleRollCallOpen") + } + + ok, err := db.CheckPrevCreateOrCloseID(channelPath, rollCallOpen.Opens) + if err != nil { + errAnswer = answer.NewQueryDatabaseError("if previous id exists: %v", err) + return errAnswer.Wrap("handleRollCallOpen") + } else if !ok { + errAnswer = answer.NewInvalidMessageFieldError("previous id does not exist") + return errAnswer.Wrap("handleRollCallOpen") + } + return nil +} + +func handleRollCallReOpen(msg message.Message, channelPath string) *answer.Error { + var rollCallReOpen messagedata.RollCallReOpen + errAnswer := msg.UnmarshalMsgData(&rollCallReOpen) + if errAnswer != nil { + return errAnswer.Wrap("handleRollCallReOpen") + } + + errAnswer = handleRollCallOpen(msg, channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleRollCallReOpen") + } + + return nil +} + +func handleRollCallClose(msg message.Message, channelPath string) *answer.Error { + var rollCallClose messagedata.RollCallClose + errAnswer := msg.UnmarshalMsgData(&rollCallClose) + if errAnswer != nil { + return errAnswer.Wrap("handleRollCallClose") + } + + errAnswer = rollCallClose.Verify(channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleRollCallClose") + } + + db, errAnswer := database.GetLAORepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleRollCallClose") + } + + ok, err := db.CheckPrevOpenOrReopenID(channelPath, rollCallClose.Closes) + if err != nil { + errAnswer = answer.NewQueryDatabaseError("if previous id exists: %v", err) + return errAnswer.Wrap("handleRollCallClose") + } else if !ok { + errAnswer = answer.NewInvalidMessageFieldError("previous id does not exist") + return errAnswer.Wrap("handleRollCallClose") + } + + newChannels, errAnswer := createNewAttendeeChannels(channelPath, rollCallClose) + if errAnswer != nil { + return errAnswer.Wrap("handleRollCallClose") + } + if len(newChannels) == 0 { + return nil + } + + err = db.StoreRollCallClose(newChannels, channelPath, msg) + if err != nil { + errAnswer = answer.NewStoreDatabaseError("channels and message: %v", err) + return errAnswer.Wrap("handleRollCallClose") + } + + return nil +} + +func createNewAttendeeChannels(channelPath string, rollCallClose messagedata.RollCallClose) ([]string, *answer.Error) { + channels := make([]string, 0, len(rollCallClose.Attendees)) + + for _, popToken := range rollCallClose.Attendees { + _, err := base64.URLEncoding.DecodeString(popToken) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to decode poptoken: %v", err) + return nil, errAnswer.Wrap("handleRollCallClose") + } + chirpingChannelPath := channelPath + Social + "/" + popToken + channels = append(channels, chirpingChannelPath) + } + + newChannels := make([]string, 0) + for _, channelPath := range channels { + alreadyExists, errAnswer := state.HasChannel(channelPath) + if errAnswer != nil { + return nil, errAnswer + } + if alreadyExists { + continue + } + errAnswer = state.AddChannel(channelPath) + if errAnswer != nil { + return nil, errAnswer + } + + newChannels = append(newChannels, channelPath) + } + + return newChannels, nil +} + +func handleElectionSetup(msg message.Message, channelPath string) *answer.Error { + var electionSetup messagedata.ElectionSetup + errAnswer := msg.UnmarshalMsgData(&electionSetup) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionSetup") + } + + errAnswer = verifySenderLao(channelPath, msg) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionSetup") + } + + laoID, _ := strings.CutPrefix(channelPath, RootPrefix) + + errAnswer = electionSetup.Verify(laoID) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionSetup") + } + + for _, question := range electionSetup.Questions { + errAnswer = question.Verify(electionSetup.ID) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionSetup") + } + } + + errAnswer = storeElection(msg, electionSetup, channelPath) + if errAnswer != nil { + return errAnswer.Wrap("handleElectionSetup") + } + return nil +} + +func verifySenderLao(channelPath string, msg message.Message) *answer.Error { + senderBuf, err := base64.URLEncoding.DecodeString(msg.Sender) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to decode sender public key: %v", err) + return errAnswer + } + senderPubKey := crypto.Suite.Point() + err = senderPubKey.UnmarshalBinary(senderBuf) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to unmarshal sender public key: %v", err) + return errAnswer + } + + db, errAnswer := database.GetLAORepositoryInstance() + if errAnswer != nil { + return errAnswer + } + + organizePubKey, err := db.GetOrganizerPubKey(channelPath) + if err != nil { + errAnswer = answer.NewQueryDatabaseError("organizer public key: %v", err) + return errAnswer + } + + if !organizePubKey.Equal(senderPubKey) { + errAnswer = answer.NewAccessDeniedError("sender public key does not match organizer public key: %s != %s", + senderPubKey, organizePubKey) + return errAnswer + } + + return nil +} + +func storeElection(msg message.Message, electionSetup messagedata.ElectionSetup, channelPath string) *answer.Error { + var errAnswer *answer.Error + + electionPubKey, electionSecretKey := generateKeys() + var electionKeyMsg message.Message + electionPath := channelPath + "/" + electionSetup.ID + + db, errAnswer := database.GetLAORepositoryInstance() + if errAnswer != nil { + return errAnswer + } + + if electionSetup.Version == messagedata.SecretBallot { + electionKeyMsg, errAnswer = createElectionKey(electionSetup.ID, electionPubKey) + if errAnswer != nil { + return errAnswer + } + err := db.StoreElectionWithElectionKey(channelPath, electionPath, electionPubKey, electionSecretKey, msg, electionKeyMsg) + if err != nil { + errAnswer = answer.NewStoreDatabaseError("election setup message: %v", err) + return errAnswer + } + } else { + err := db.StoreElection(channelPath, electionPath, electionPubKey, electionSecretKey, msg) + if err != nil { + errAnswer = answer.NewStoreDatabaseError("election setup message: %v", err) + return errAnswer + } + } + + errAnswer = state.AddChannel(electionPath) + if errAnswer != nil { + return errAnswer + } + + return nil +} + +func createElectionKey(electionID string, electionPubKey kyber.Point) (message.Message, *answer.Error) { + electionPubBuf, err := electionPubKey.MarshalBinary() + if err != nil { + errAnswer := answer.NewInternalServerError("failed to marshal election public key: %v", err) + return message.Message{}, errAnswer.Wrap("createAndSendElectionKey") + } + msgData := messagedata.ElectionKey{ + Object: messagedata.ElectionObject, + Action: messagedata.ElectionActionKey, + Election: electionID, + Key: base64.URLEncoding.EncodeToString(electionPubBuf), + } + + dataBuf, err := json.Marshal(&msgData) + if err != nil { + errAnswer := answer.NewInternalServerError("failed to marshal message data: %v", err) + return message.Message{}, errAnswer.Wrap("createAndSendElectionKey") + } + newData64 := base64.URLEncoding.EncodeToString(dataBuf) + + serverPublicKey, errAnswer := config.GetServerPublicKeyInstance() + if errAnswer != nil { + return message.Message{}, errAnswer.Wrap("createAndSendElectionKey") + } + + serverPubBuf, err := serverPublicKey.MarshalBinary() + if err != nil { + errAnswer := answer.NewInternalServerError("failed to unmarshall server secret key", err) + return message.Message{}, errAnswer.Wrap("createAndSendElectionKey") + } + signatureBuf, errAnswer := Sign(dataBuf) + if errAnswer != nil { + return message.Message{}, errAnswer.Wrap("createAndSendElectionKey") + } + signature := base64.URLEncoding.EncodeToString(signatureBuf) + electionKeyMsg := message.Message{ + Data: newData64, + Sender: base64.URLEncoding.EncodeToString(serverPubBuf), + Signature: signature, + MessageID: messagedata.Hash(newData64, signature), + WitnessSignatures: []message.WitnessSignature{}, + } + return electionKeyMsg, nil +} + +// Not working +func handleLaoState(msg message.Message, channelPath string) *answer.Error { + var laoState messagedata.LaoState + errAnswer := msg.UnmarshalMsgData(&laoState) + if errAnswer != nil { + return errAnswer.Wrap("handleLaoState") + } + + db, errAnswer := database.GetLAORepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleLaoState") + } + + ok, err := db.HasMessage(laoState.ModificationID) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("if message exists: %v", err) + return errAnswer.Wrap("handleLaoState") + } else if !ok { + errAnswer := answer.NewInvalidMessageFieldError("message corresponding to modificationID %s does not exist", laoState.ModificationID) + return errAnswer.Wrap("handleLaoState") + } + + witnesses, err := db.GetLaoWitnesses(channelPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("lao witnesses: %v", err) + return errAnswer.Wrap("handleLaoState") + } + + // Check if the signatures match + expected := len(witnesses) + match := 0 + for _, modificationSignature := range laoState.ModificationSignatures { + err = schnorr.VerifyWithChecks(crypto.Suite, []byte(modificationSignature.Witness), + []byte(laoState.ModificationID), []byte(modificationSignature.Signature)) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to verify signature for witness: %s", modificationSignature.Witness) + return errAnswer.Wrap("handleLaoState") + } + if _, ok := witnesses[modificationSignature.Witness]; ok { + match++ + } + } + + if match != expected { + errAnswer := answer.NewInvalidMessageFieldError("not enough witness signatures provided. Needed %d got %d", expected, match) + return errAnswer.Wrap("handleLaoState") + } + + var updateMsgData messagedata.LaoUpdate + + err = msg.UnmarshalData(&updateMsgData) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to unmarshal update message data: %v", err) + return errAnswer.Wrap("handleLaoState") + } + + err = updateMsgData.Verify() + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to verify update message data: %v", err) + return errAnswer.Wrap("handleLaoState") + } + + errAnswer = compareLaoUpdateAndState(updateMsgData, laoState) + if errAnswer != nil { + return errAnswer.Wrap("handleLaoState") + } + return nil +} + +func compareLaoUpdateAndState(update messagedata.LaoUpdate, state messagedata.LaoState) *answer.Error { + if update.LastModified != state.LastModified { + errAnswer := answer.NewInvalidMessageFieldError("mismatch between last modified: expected %d got %d", + update.LastModified, state.LastModified) + return errAnswer.Wrap("compareLaoUpdateAndState") + } + + if update.Name != state.Name { + errAnswer := answer.NewInvalidMessageFieldError("mismatch between name: expected %s got %s", + update.Name, state.Name) + return errAnswer.Wrap("compareLaoUpdateAndState") + } + + numUpdateWitnesses := len(update.Witnesses) + numStateWitnesses := len(state.Witnesses) + + if numUpdateWitnesses != numStateWitnesses { + errAnswer := answer.NewInvalidMessageFieldError("mismatch between witness count") + return errAnswer.Wrap("compareLaoUpdateAndState") + } + + match := 0 + for _, updateWitness := range update.Witnesses { + if slices.Contains(state.Witnesses, updateWitness) { + match++ + } + } + if match != numUpdateWitnesses { + errAnswer := answer.NewInvalidMessageFieldError("mismatch between witness keys") + return errAnswer.Wrap("compareLaoUpdateAndState") + } + return nil +} + +// Not implemented yet +func handleLaoUpdate(msg message.Message) *answer.Error { + return nil +} + +// Not implemented yet +func handleMeetingCreate(msg message.Message) *answer.Error { + return nil +} + +// Not implemented yet +func handleMeetingState(msg message.Message) *answer.Error { + return nil +} + +// Not implemented yet +func handleMessageWitness(msg message.Message) *answer.Error { return nil } diff --git a/be1-go/internal/popserver/handler/lao_test.go b/be1-go/internal/popserver/handler/lao_test.go new file mode 100644 index 0000000000..9631e89fb1 --- /dev/null +++ b/be1-go/internal/popserver/handler/lao_test.go @@ -0,0 +1,403 @@ +package handler + +import ( + "encoding/base64" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.dedis.ch/kyber/v3" + "popstellar/crypto" + "popstellar/internal/popserver/config" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/database/repository" + "popstellar/internal/popserver/generatortest" + "popstellar/internal/popserver/state" + "popstellar/internal/popserver/types" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "strconv" + "strings" + "testing" + "time" +) + +func Test_handleChannelLao(t *testing.T) { + subs := types.NewSubscribers() + queries := types.NewQueries(&noLog) + peers := types.NewPeers() + + state.SetState(subs, peers, queries) + + ownerPubBuf, err := base64.URLEncoding.DecodeString(ownerPubBuf64) + require.NoError(t, err) + + ownerPublicKey := crypto.Suite.Point() + err = ownerPublicKey.UnmarshalBinary(ownerPubBuf) + require.NoError(t, err) + + serverSecretKey := crypto.Suite.Scalar().Pick(crypto.Suite.RandomStream()) + serverPublicKey := crypto.Suite.Point().Mul(serverSecretKey, nil) + + config.SetConfig(ownerPublicKey, serverPublicKey, serverSecretKey, "clientAddress", "serverAddress") + + var args []input + mockRepository := repository.NewMockRepository(t) + database.SetDatabase(mockRepository) + + laoID := base64.URLEncoding.EncodeToString([]byte("laoID")) + errAnswer := subs.AddChannel(laoID) + require.Nil(t, errAnswer) + + // Test 1:Success For LaoState message + args = append(args, input{ + name: "Test 1", + msg: newLaoStateMsg(t, ownerPubBuf64, laoID, mockRepository), + channel: laoID, + isError: false, + contains: "", + }) + + creation := time.Now().Unix() + start := creation + 2 + end := start + 1 + + // Test 2: Error when RollCallCreate ID is not the expected hash + args = append(args, input{ + name: "Test 2", + msg: newRollCallCreateMsg(t, ownerPubBuf64, laoID, wrongLaoName, creation, start, end, true, mockRepository), + channel: laoID, + isError: true, + contains: "roll call id is", + }) + + // Test 3: Error when RollCallCreate proposed start is before creation + args = append(args, input{ + name: "Test 3", + msg: newRollCallCreateMsg(t, ownerPubBuf64, laoID, goodLaoName, creation, creation-1, end, true, mockRepository), + channel: laoID, + isError: true, + contains: "roll call proposed start time should be greater than creation time", + }) + + // Test 4: Error when RollCallCreate proposed end is before proposed start + args = append(args, input{ + name: "Test 4", + msg: newRollCallCreateMsg(t, ownerPubBuf64, laoID, goodLaoName, creation, start, start-1, true, mockRepository), + channel: laoID, + isError: true, + contains: "roll call proposed end should be greater than proposed start", + }) + + // Test 5: Success for RollCallCreate message + args = append(args, input{ + name: "Test 5", + msg: newRollCallCreateMsg(t, ownerPubBuf64, laoID, goodLaoName, creation, start, end, false, mockRepository), + channel: laoID, + isError: false, + contains: "", + }) + + opens := base64.URLEncoding.EncodeToString([]byte("opens")) + wrongOpens := base64.URLEncoding.EncodeToString([]byte("wrongOpens")) + + // Test 6: Error when RollCallOpen ID is not the expected hash + args = append(args, input{ + name: "Test 6", + msg: newRollCallOpenMsg(t, ownerPubBuf64, laoID, wrongOpens, "", time.Now().Unix(), true, mockRepository), + channel: laoID, + isError: true, + contains: "roll call update id is", + }) + + // Test 7: Error when RollCallOpen opens is not the same as previous RollCallCreate + args = append(args, input{ + name: "Test 7", + msg: newRollCallOpenMsg(t, ownerPubBuf64, laoID, opens, wrongOpens, time.Now().Unix(), true, mockRepository), + channel: laoID, + isError: true, + contains: "previous id does not exist", + }) + + laoID = base64.URLEncoding.EncodeToString([]byte("laoID2")) + errAnswer = subs.AddChannel(laoID) + require.Nil(t, errAnswer) + + // Test 8: Success for RollCallOpen message + args = append(args, input{ + name: "Test 8", + msg: newRollCallOpenMsg(t, ownerPubBuf64, laoID, opens, opens, time.Now().Unix(), false, mockRepository), + channel: laoID, + isError: false, + contains: "", + }) + + closes := base64.URLEncoding.EncodeToString([]byte("closes")) + wrongCloses := base64.URLEncoding.EncodeToString([]byte("wrongCloses")) + + // Test 9: Error when RollCallClose ID is not the expected hash + args = append(args, input{ + name: "Test 9", + msg: newRollCallCloseMsg(t, ownerPubBuf64, laoID, wrongCloses, "", time.Now().Unix(), true, mockRepository), + channel: laoID, + isError: true, + contains: "roll call update id is", + }) + + // Test 10: Error when RollCallClose closes is not the same as previous RollCallOpen + args = append(args, input{ + name: "Test 10", + msg: newRollCallCloseMsg(t, ownerPubBuf64, laoID, closes, wrongCloses, time.Now().Unix(), true, mockRepository), + channel: laoID, + isError: true, + contains: "previous id does not exist", + }) + + laoID = base64.URLEncoding.EncodeToString([]byte("laoID3")) + errAnswer = subs.AddChannel(laoID) + require.Nil(t, errAnswer) + + // Test 11: Success for RollCallClose message + args = append(args, input{ + name: "Test 11", + msg: newRollCallCloseMsg(t, ownerPubBuf64, laoID, closes, closes, time.Now().Unix(), false, mockRepository), + channel: laoID, + isError: false, + contains: "", + }) + + electionsName := "electionName" + question := "question" + wrongQuestion := "wrongQuestion" + // Test 12: Error when sender is not the organizer of the lao for ElectionSetup + args = append(args, input{ + name: "Test 12", + msg: newElectionSetupMsg(t, ownerPublicKey, wrongSender, laoID, laoID, electionsName, question, messagedata.OpenBallot, + creation, start, end, true, mockRepository), + channel: laoID, + isError: true, + contains: "sender public key does not match organizer public key", + }) + + wrongLaoID := base64.URLEncoding.EncodeToString([]byte("wrongLaoID")) + // Test 13: Error when ElectionSetup lao is not the same as the channel + args = append(args, input{ + name: "Test 13", + msg: newElectionSetupMsg(t, ownerPublicKey, ownerPubBuf64, wrongLaoID, laoID, electionsName, question, messagedata.OpenBallot, + creation, start, end, true, mockRepository), + channel: laoID, + isError: true, + contains: "lao id is", + }) + + // Test 14: Error when ElectionSetup ID is not the expected hash + args = append(args, input{ + name: "Test 14", + msg: newElectionSetupMsg(t, ownerPublicKey, ownerPubBuf64, laoID, laoID, "wrongName", question, messagedata.OpenBallot, + creation, start, end, true, mockRepository), + channel: laoID, + isError: true, + contains: "election id is", + }) + + // Test 15: Error when proposedStart is before createdAt + args = append(args, input{ + name: "Test 15", + msg: newElectionSetupMsg(t, ownerPublicKey, ownerPubBuf64, laoID, laoID, electionsName, question, messagedata.OpenBallot, + creation, creation-1, end, true, mockRepository), + channel: laoID, + isError: true, + contains: "election start should be greater that creation time", + }) + + // Test 16: Error when proposedEnd is before proposedStart + args = append(args, input{ + name: "Test 16", + msg: newElectionSetupMsg(t, ownerPublicKey, ownerPubBuf64, laoID, laoID, electionsName, question, messagedata.OpenBallot, + creation, start, start-1, true, mockRepository), + channel: laoID, + isError: true, + contains: "election end should be greater that start time", + }) + + // Test 17: Error when ElectionSetup question is empty + args = append(args, input{ + name: "Test 17", + msg: newElectionSetupMsg(t, ownerPublicKey, ownerPubBuf64, laoID, laoID, electionsName, "", messagedata.OpenBallot, + creation, start, end, true, mockRepository), + channel: laoID, + isError: true, + contains: "Question is empty", + }) + + //Test 18: Error when question hash is not the same as the expected hash + args = append(args, input{ + name: "Test 18", + msg: newElectionSetupMsg(t, ownerPublicKey, ownerPubBuf64, laoID, laoID, electionsName, wrongQuestion, messagedata.OpenBallot, + creation, start, end, true, mockRepository), + channel: laoID, + isError: true, + contains: "Question id is", + }) + + laoID = base64.URLEncoding.EncodeToString([]byte("laoID4")) + errAnswer = subs.AddChannel(laoID) + require.Nil(t, errAnswer) + + // Test 19: Success for ElectionSetup message + args = append(args, input{ + name: "Test 19", + msg: newElectionSetupMsg(t, ownerPublicKey, ownerPubBuf64, laoID, laoID, electionsName, question, messagedata.OpenBallot, + creation, start, end, false, mockRepository), + channel: laoID, + isError: false, + contains: "", + }) + + for _, arg := range args { + t.Run(arg.name, func(t *testing.T) { + errAnswer := handleChannelLao(arg.channel, arg.msg) + if arg.isError { + require.NotNil(t, errAnswer) + require.Contains(t, errAnswer.Error(), arg.contains) + } else { + require.Nil(t, errAnswer) + } + }) + } +} + +func newLaoStateMsg(t *testing.T, organizer, laoID string, mockRepository *repository.MockRepository) message.Message { + modificationID := base64.URLEncoding.EncodeToString([]byte("modificationID")) + name := "laoName" + creation := time.Now().Unix() + lastModified := time.Now().Unix() + + msg := generatortest.NewLaoStateMsg(t, organizer, laoID, name, modificationID, creation, lastModified, nil) + + mockRepository.On("HasMessage", modificationID). + Return(true, nil) + mockRepository.On("GetLaoWitnesses", laoID). + Return(map[string]struct{}{}, nil) + mockRepository.On("StoreMessageAndData", laoID, msg). + Return(nil) + + return msg +} + +func newRollCallCreateMsg(t *testing.T, sender, laoID, laoName string, creation, start, end int64, isError bool, + mockRepository *repository.MockRepository) message.Message { + + createID := messagedata.Hash( + messagedata.RollCallFlag, + strings.ReplaceAll(laoID, RootPrefix, ""), + strconv.Itoa(int(creation)), + goodLaoName, + ) + + msg := generatortest.NewRollCallCreateMsg(t, sender, laoName, createID, creation, start, end, nil) + + if !isError { + mockRepository.On("StoreMessageAndData", laoID, msg).Return(nil) + } + + return msg +} + +func newRollCallOpenMsg(t *testing.T, sender, laoID, opens, prevID string, openedAt int64, isError bool, + mockRepository *repository.MockRepository) message.Message { + + openID := messagedata.Hash( + messagedata.RollCallFlag, + strings.ReplaceAll(laoID, RootPrefix, ""), + base64.URLEncoding.EncodeToString([]byte("opens")), + strconv.Itoa(int(openedAt)), + ) + + msg := generatortest.NewRollCallOpenMsg(t, sender, openID, opens, openedAt, nil) + + if !isError { + mockRepository.On("StoreMessageAndData", laoID, msg).Return(nil) + } + if prevID != "" { + mockRepository.On("CheckPrevCreateOrCloseID", laoID, opens).Return(opens == prevID, nil) + } + + return msg +} + +func newRollCallCloseMsg(t *testing.T, sender, laoID, closes, prevID string, closedAt int64, isError bool, + mockRepository *repository.MockRepository) message.Message { + + closeID := messagedata.Hash( + messagedata.RollCallFlag, + strings.ReplaceAll(laoID, RootPrefix, ""), + base64.URLEncoding.EncodeToString([]byte("closes")), + strconv.Itoa(int(closedAt)), + ) + + attendees := []string{base64.URLEncoding.EncodeToString([]byte("a")), base64.URLEncoding.EncodeToString([]byte("b"))} + + msg := generatortest.NewRollCallCloseMsg(t, sender, closeID, closes, closedAt, attendees, nil) + + if !isError { + var channels []string + for _, attendee := range attendees { + channels = append(channels, laoID+Social+"/"+attendee) + } + mockRepository.On("StoreRollCallClose", channels, laoID, msg).Return(nil) + } + if prevID != "" { + mockRepository.On("CheckPrevOpenOrReopenID", laoID, closes).Return(closes == prevID, nil) + } + + return msg +} + +func newElectionSetupMsg(t *testing.T, organizer kyber.Point, sender, + setupLao, laoID, electionName, question, version string, + createdAt, start, end int64, + isError bool, mockRepository *repository.MockRepository) message.Message { + + electionSetupID := messagedata.Hash( + messagedata.ElectionFlag, + setupLao, + strconv.Itoa(int(createdAt)), + "electionName", + ) + + var questions []messagedata.ElectionSetupQuestion + if question != "" { + questionID := messagedata.Hash("Question", electionSetupID, "question") + questions = append(questions, messagedata.ElectionSetupQuestion{ + ID: questionID, + Question: question, + VotingMethod: "Plurality", + BallotOptions: []string{"Option1", "Option2"}, + WriteIn: false, + }) + } else { + questionID := messagedata.Hash("Question", electionSetupID, "") + questions = append(questions, messagedata.ElectionSetupQuestion{ + ID: questionID, + Question: "", + VotingMethod: "Plurality", + BallotOptions: []string{"Option1", "Option2"}, + WriteIn: false, + }) + } + + msg := generatortest.NewElectionSetupMsg(t, sender, electionSetupID, setupLao, electionName, version, createdAt, start, + end, questions, nil) + + mockRepository.On("GetOrganizerPubKey", laoID).Return(organizer, nil) + + if !isError { + mockRepository.On("StoreElection", + laoID, + laoID+"/"+electionSetupID, + mock.AnythingOfType("*edwards25519.point"), + mock.AnythingOfType("*edwards25519.scalar"), + msg).Return(nil) + } + + return msg +} diff --git a/be1-go/internal/popserver/handler/query.go b/be1-go/internal/popserver/handler/query.go new file mode 100644 index 0000000000..dbe9635ca5 --- /dev/null +++ b/be1-go/internal/popserver/handler/query.go @@ -0,0 +1,286 @@ +package handler + +import ( + "encoding/json" + "popstellar/internal/popserver/config" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/state" + jsonrpc "popstellar/message" + "popstellar/message/answer" + "popstellar/message/query" + "popstellar/message/query/method" + "popstellar/network/socket" +) + +func handleQuery(socket socket.Socket, msg []byte) *answer.Error { + var queryBase query.Base + + err := json.Unmarshal(msg, &queryBase) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to unmarshal: %v", err).Wrap("handleQuery") + socket.SendError(nil, errAnswer) + return errAnswer + } + + var id *int = nil + var errAnswer *answer.Error + + switch queryBase.Method { + case query.MethodCatchUp: + id, errAnswer = handleCatchUp(socket, msg) + case query.MethodGetMessagesById: + id, errAnswer = handleGetMessagesByID(socket, msg) + case query.MethodGreetServer: + id, errAnswer = handleGreetServer(socket, msg) + case query.MethodHeartbeat: + errAnswer = handleHeartbeat(socket, msg) + case query.MethodPublish: + id, errAnswer = handlePublish(socket, msg) + case query.MethodSubscribe: + id, errAnswer = handleSubscribe(socket, msg) + case query.MethodUnsubscribe: + id, errAnswer = handleUnsubscribe(socket, msg) + default: + errAnswer = answer.NewInvalidResourceError("unexpected method: '%s'", queryBase.Method) + } + + if errAnswer != nil && queryBase.Method != query.MethodGreetServer && queryBase.Method != query.MethodHeartbeat { + errAnswer = errAnswer.Wrap("handleQuery") + socket.SendError(id, errAnswer) + return errAnswer + } + + return nil +} + +func handleGreetServer(socket socket.Socket, byteMessage []byte) (*int, *answer.Error) { + var greetServer method.GreetServer + + err := json.Unmarshal(byteMessage, &greetServer) + if err != nil { + errAnswer := answer.NewJsonUnmarshalError(err.Error()) + return nil, errAnswer.Wrap("handleGreetServer") + } + + errAnswer := state.AddPeerInfo(socket.ID(), greetServer.Params) + if errAnswer != nil { + return nil, errAnswer.Wrap("handleGreetServer") + } + + isGreeted, errAnswer := state.IsPeerGreeted(socket.ID()) + if errAnswer != nil { + return nil, errAnswer.Wrap("handleGreetServer") + } + if isGreeted { + return nil, nil + } + + serverPublicKey, clientAddress, serverAddress, errAnswer := config.GetServerInfo() + if errAnswer != nil { + return nil, errAnswer.Wrap("handleGreetServer") + } + + greetServerParams := method.GreetServerParams{ + PublicKey: serverPublicKey, + ServerAddress: serverAddress, + ClientAddress: clientAddress, + } + + serverGreet := &method.GreetServer{ + Base: query.Base{ + JSONRPCBase: jsonrpc.JSONRPCBase{ + JSONRPC: "2.0", + }, + Method: query.MethodGreetServer, + }, + Params: greetServerParams, + } + + buf, err := json.Marshal(serverGreet) + if err != nil { + errAnswer := answer.NewInternalServerError("failed to marshal: %v", err) + return nil, errAnswer.Wrap("handleGreetServer") + } + + socket.Send(buf) + + errAnswer = state.AddPeerGreeted(socket.ID()) + if errAnswer != nil { + return nil, errAnswer.Wrap("handleGreetServer") + } + + return nil, nil +} + +func handleSubscribe(socket socket.Socket, msg []byte) (*int, *answer.Error) { + var subscribe method.Subscribe + + err := json.Unmarshal(msg, &subscribe) + if err != nil { + errAnswer := answer.NewJsonUnmarshalError(err.Error()) + return nil, errAnswer.Wrap("handleSubscribe") + } + + if Root == subscribe.Params.Channel { + errAnswer := answer.NewInvalidActionError("cannot Subscribe to root channel") + return &subscribe.ID, errAnswer.Wrap("handleSubscribe") + } + + errAnswer := state.Subscribe(socket, subscribe.Params.Channel) + if errAnswer != nil { + return &subscribe.ID, errAnswer.Wrap("handleSubscribe") + } + + socket.SendResult(subscribe.ID, nil, nil) + + return &subscribe.ID, nil +} + +func handleUnsubscribe(socket socket.Socket, msg []byte) (*int, *answer.Error) { + var unsubscribe method.Unsubscribe + + err := json.Unmarshal(msg, &unsubscribe) + if err != nil { + errAnswer := answer.NewJsonUnmarshalError(err.Error()) + return nil, errAnswer.Wrap("handleUnsubscribe") + } + + if Root == unsubscribe.Params.Channel { + errAnswer := answer.NewInvalidActionError("cannot Unsubscribe from root channel") + return &unsubscribe.ID, errAnswer.Wrap("handleUnsubscribe") + } + + errAnswer := state.Unsubscribe(socket, unsubscribe.Params.Channel) + if errAnswer != nil { + return &unsubscribe.ID, errAnswer.Wrap("handleUnsubscribe") + } + + socket.SendResult(unsubscribe.ID, nil, nil) + + return &unsubscribe.ID, nil +} + +func handlePublish(socket socket.Socket, msg []byte) (*int, *answer.Error) { + var publish method.Publish + + err := json.Unmarshal(msg, &publish) + if err != nil { + errAnswer := answer.NewJsonUnmarshalError(err.Error()) + return nil, errAnswer.Wrap("handlePublish") + } + + errAnswer := handleChannel(publish.Params.Channel, publish.Params.Message) + if errAnswer != nil { + return &publish.ID, errAnswer.Wrap("handlePublish") + } + + socket.SendResult(publish.ID, nil, nil) + + return &publish.ID, nil +} + +func handleCatchUp(socket socket.Socket, msg []byte) (*int, *answer.Error) { + var catchup method.Catchup + + err := json.Unmarshal(msg, &catchup) + if err != nil { + errAnswer := answer.NewJsonUnmarshalError(err.Error()) + return nil, errAnswer.Wrap("handleCatchUp") + } + + db, errAnswer := database.GetQueryRepositoryInstance() + if errAnswer != nil { + return &catchup.ID, errAnswer.Wrap("handleCatchUp") + } + + result, err := db.GetAllMessagesFromChannel(catchup.Params.Channel) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("all message from channel %s: %v", catchup.Params.Channel, err) + return &catchup.ID, errAnswer.Wrap("handleCatchUp") + } + + socket.SendResult(catchup.ID, result, nil) + + return &catchup.ID, nil +} + +func handleHeartbeat(socket socket.Socket, byteMessage []byte) *answer.Error { + var heartbeat method.Heartbeat + + err := json.Unmarshal(byteMessage, &heartbeat) + if err != nil { + errAnswer := answer.NewJsonUnmarshalError(err.Error()) + return errAnswer.Wrap("handleHeartbeat") + } + + db, errAnswer := database.GetQueryRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleHeartbeat") + } + + result, err := db.GetParamsForGetMessageByID(heartbeat.Params) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("params for get messages by id: %v", err) + return errAnswer.Wrap("handleHeartbeat") + } + + if len(result) == 0 { + return nil + } + + queryId, errAnswer := state.GetNextID() + if errAnswer != nil { + return errAnswer.Wrap("handleHeartbeat") + } + + getMessagesById := method.GetMessagesById{ + Base: query.Base{ + JSONRPCBase: jsonrpc.JSONRPCBase{ + JSONRPC: "2.0", + }, + Method: query.MethodGetMessagesById, + }, + ID: queryId, + Params: result, + } + + buf, err := json.Marshal(getMessagesById) + if err != nil { + errAnswer := answer.NewInternalServerError("failed to marshal: %v", err) + return errAnswer.Wrap("handleHeartbeat") + } + + socket.Send(buf) + + errAnswer = state.AddQuery(queryId, getMessagesById) + if errAnswer != nil { + return errAnswer.Wrap("handleHeartbeat") + } + + return nil +} + +func handleGetMessagesByID(socket socket.Socket, msg []byte) (*int, *answer.Error) { + var getMessagesById method.GetMessagesById + + err := json.Unmarshal(msg, &getMessagesById) + if err != nil { + errAnswer := answer.NewJsonUnmarshalError(err.Error()) + return nil, errAnswer.Wrap("handleGetMessageByID") + } + + db, errAnswer := database.GetQueryRepositoryInstance() + if errAnswer != nil { + return &getMessagesById.ID, errAnswer.Wrap("handleGetMessageByID") + } + + result, err := db.GetResultForGetMessagesByID(getMessagesById.Params) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("result for get messages by id: %v", err) + return &getMessagesById.ID, errAnswer.Wrap("handleGetMessageByID") + } + + socket.SendResult(getMessagesById.ID, nil, result) + + return &getMessagesById.ID, nil +} diff --git a/be1-go/internal/popserver/handler/query_test.go b/be1-go/internal/popserver/handler/query_test.go new file mode 100644 index 0000000000..3e88abf10b --- /dev/null +++ b/be1-go/internal/popserver/handler/query_test.go @@ -0,0 +1,642 @@ +package handler + +import ( + "encoding/json" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + "popstellar/crypto" + "popstellar/internal/popserver/config" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/database/repository" + "popstellar/internal/popserver/generatortest" + "popstellar/internal/popserver/state" + "popstellar/internal/popserver/types" + "popstellar/message/query/method" + "popstellar/message/query/method/message" + "popstellar/network/socket" + "testing" +) + +func Test_handleQuery(t *testing.T) { + type input struct { + name string + message []byte + contains string + } + + args := make([]input, 0) + + // Test 1: failed to handled popquery because unknown method + + msg := generatortest.NewNothingQuery(t, 999) + + args = append(args, input{ + name: "Test 1", + message: msg, + contains: "unexpected method", + }) + + // run all tests + + for _, arg := range args { + t.Run(arg.name, func(t *testing.T) { + fakeSocket := socket.FakeSocket{Id: "fakesocket"} + errAnswer := handleQuery(&fakeSocket, arg.message) + require.NotNil(t, errAnswer) + require.Contains(t, errAnswer.Error(), arg.contains) + }) + } +} + +func Test_handleGreetServer(t *testing.T) { + subs := types.NewSubscribers() + queries := types.NewQueries(&noLog) + peers := types.NewPeers() + + state.SetState(subs, peers, queries) + + serverSecretKey := crypto.Suite.Scalar().Pick(crypto.Suite.RandomStream()) + serverPublicKey := crypto.Suite.Point().Mul(serverSecretKey, nil) + + config.SetConfig(nil, serverPublicKey, serverSecretKey, "clientAddress", "serverAddress") + + type input struct { + name string + socket socket.FakeSocket + message []byte + needGreet bool + isError bool + contains string + } + + args := make([]input, 0) + + greetServer := generatortest.NewGreetServerQuery(t, "pk", "client", "server") + + // Test 1: reply with greet server when receiving a greet server from a new server + + fakeSocket := socket.FakeSocket{Id: "1"} + + args = append(args, input{ + name: "Test 1", + socket: fakeSocket, + message: greetServer, + needGreet: true, + isError: false, + }) + + // Test 2: doesn't reply with greet server when already greeted the server + + fakeSocket = socket.FakeSocket{Id: "2"} + + peers.AddPeerGreeted(fakeSocket.Id) + + args = append(args, input{ + name: "Test 2", + message: greetServer, + socket: fakeSocket, + needGreet: false, + isError: false, + }) + + // Test 3: return an error if the socket ID is already used by another server + + fakeSocket = socket.FakeSocket{Id: "3"} + + err := peers.AddPeerInfo(fakeSocket.Id, method.GreetServerParams{}) + require.NoError(t, err) + + args = append(args, input{ + name: "Test 3", + socket: fakeSocket, + message: greetServer, + isError: true, + contains: "failed to add peer", + }) + + // run all tests + + for _, arg := range args { + t.Run(arg.name, func(t *testing.T) { + id, errAnswer := handleGreetServer(&arg.socket, arg.message) + if arg.isError { + require.NotNil(t, errAnswer) + require.Contains(t, errAnswer.Error(), arg.contains) + require.Nil(t, id) + } else if arg.needGreet { + require.Nil(t, errAnswer) + require.NotNil(t, arg.socket.Msg) + } else { + require.Nil(t, errAnswer) + require.Nil(t, arg.socket.Msg) + } + }) + } +} + +func Test_handleSubscribe(t *testing.T) { + subs := types.NewSubscribers() + queries := types.NewQueries(&noLog) + peers := types.NewPeers() + + state.SetState(subs, peers, queries) + + type input struct { + name string + socket socket.FakeSocket + ID int + channel string + message []byte + isError bool + contains string + } + + args := make([]input, 0) + + // Test 1: successfully subscribe to a channel + + fakeSocket := socket.FakeSocket{Id: "1"} + ID := 1 + channel := "/root/lao1" + + errAnswer := subs.AddChannel(channel) + require.Nil(t, errAnswer) + + args = append(args, input{ + name: "Test 1", + socket: fakeSocket, + ID: ID, + channel: channel, + message: generatortest.NewSubscribeQuery(t, ID, channel), + isError: false, + }) + + // Test 2: failed to subscribe to an unknown channel + + fakeSocket = socket.FakeSocket{Id: "2"} + ID = 2 + channel = "/root/lao2" + + args = append(args, input{ + name: "Test 2", + socket: fakeSocket, + ID: ID, + channel: channel, + message: generatortest.NewSubscribeQuery(t, ID, channel), + isError: true, + contains: "cannot Subscribe to unknown channel", + }) + + // cannot Subscribe to root + + fakeSocket = socket.FakeSocket{Id: "3"} + ID = 3 + channel = "/root" + + args = append(args, input{ + name: "Test 3", + socket: fakeSocket, + ID: ID, + channel: channel, + message: generatortest.NewSubscribeQuery(t, ID, channel), + isError: true, + contains: "cannot Subscribe to root channel", + }) + + // run all tests + + for _, arg := range args { + t.Run(arg.name, func(t *testing.T) { + id, errAnswer := handleSubscribe(&arg.socket, arg.message) + if arg.isError { + require.NotNil(t, errAnswer) + require.Contains(t, errAnswer.Error(), arg.contains) + require.Equal(t, arg.ID, *id) + } else { + require.Nil(t, errAnswer) + isSubscribed, err := subs.IsSubscribed(arg.channel, &arg.socket) + require.NoError(t, err) + require.True(t, isSubscribed) + } + }) + } +} + +func Test_handleUnsubscribe(t *testing.T) { + subs := types.NewSubscribers() + queries := types.NewQueries(&noLog) + peers := types.NewPeers() + + state.SetState(subs, peers, queries) + + type input struct { + name string + socket socket.FakeSocket + ID int + channel string + message []byte + isError bool + contains string + } + + args := make([]input, 0) + + // Test 1: successfully unsubscribe from a subscribed channel + + fakeSocket := socket.FakeSocket{Id: "1"} + ID := 1 + channel := "/root/lao1" + + errAnswer := subs.AddChannel(channel) + require.Nil(t, errAnswer) + + errAnswer = subs.Subscribe(channel, &fakeSocket) + require.Nil(t, errAnswer) + + args = append(args, input{ + name: "Test 1", + socket: fakeSocket, + ID: ID, + channel: channel, + message: generatortest.NewUnsubscribeQuery(t, ID, channel), + isError: false, + }) + + // Test 2: failed to unsubscribe because not subscribed to channel + + fakeSocket = socket.FakeSocket{Id: "2"} + ID = 2 + channel = "/root/lao2" + + errAnswer = subs.AddChannel(channel) + require.Nil(t, errAnswer) + + args = append(args, input{ + name: "Test 2", + socket: fakeSocket, + ID: ID, + channel: channel, + message: generatortest.NewUnsubscribeQuery(t, ID, channel), + isError: true, + contains: "cannot Unsubscribe from a channel not subscribed", + }) + + // Test 3: failed to unsubscribe because unknown channel + + fakeSocket = socket.FakeSocket{Id: "3"} + ID = 3 + channel = "/root/lao3" + + args = append(args, input{ + name: "Test 3", + socket: fakeSocket, + ID: ID, + channel: channel, + message: generatortest.NewUnsubscribeQuery(t, ID, channel), + isError: true, + contains: "cannot Unsubscribe from unknown channel", + }) + + // Test 3: failed to unsubscribe because cannot unsubscribe from root channel + + fakeSocket = socket.FakeSocket{Id: "4"} + ID = 4 + channel = "/root" + + args = append(args, input{ + name: "Test 4", + socket: fakeSocket, + ID: ID, + channel: channel, + message: generatortest.NewUnsubscribeQuery(t, ID, channel), + isError: true, + contains: "cannot Unsubscribe from root channel", + }) + + // run all tests + + for _, arg := range args { + t.Run(arg.name, func(t *testing.T) { + id, errAnswer := handleUnsubscribe(&arg.socket, arg.message) + if arg.isError { + require.NotNil(t, errAnswer) + require.Contains(t, errAnswer.Error(), arg.contains) + require.Equal(t, arg.ID, *id) + } else { + require.Nil(t, errAnswer) + + isSubscribe, err := subs.IsSubscribed(arg.channel, &arg.socket) + require.NoError(t, err) + require.False(t, isSubscribe) + } + }) + } +} + +func Test_handleCatchUp(t *testing.T) { + subs := types.NewSubscribers() + queries := types.NewQueries(&noLog) + peers := types.NewPeers() + + state.SetState(subs, peers, queries) + + mockRepository := repository.NewMockRepository(t) + database.SetDatabase(mockRepository) + + type input struct { + name string + socket socket.FakeSocket + ID int + message []byte + expected []message.Message + isError bool + contains string + } + + args := make([]input, 0) + + // Test 1: successfully catchup 4 messages on a channel + + fakeSocket := socket.FakeSocket{Id: "1"} + ID := 1 + channel := "/root/lao1" + messagesToCatchUp := []message.Message{ + generatortest.NewNothingMsg(t, "sender1", nil), + generatortest.NewNothingMsg(t, "sender2", nil), + generatortest.NewNothingMsg(t, "sender3", nil), + generatortest.NewNothingMsg(t, "sender4", nil), + } + + mockRepository.On("GetAllMessagesFromChannel", channel).Return(messagesToCatchUp, nil) + + args = append(args, input{ + name: "Test 1", + socket: fakeSocket, + ID: ID, + message: generatortest.NewCatchupQuery(t, ID, channel), + expected: messagesToCatchUp, + isError: false, + }) + + // Test 2: failed to catchup because DB is disconnected + + fakeSocket = socket.FakeSocket{Id: "2"} + ID = 2 + channel = "/root/lao2" + + mockRepository.On("GetAllMessagesFromChannel", channel). + Return(nil, xerrors.Errorf("DB is disconnected")) + + args = append(args, input{ + name: "Test 2", + socket: fakeSocket, + ID: ID, + message: generatortest.NewCatchupQuery(t, ID, channel), + isError: true, + contains: "DB is disconnected", + }) + + // run all tests + + for _, arg := range args { + t.Run(arg.name, func(t *testing.T) { + id, errAnswer := handleCatchUp(&arg.socket, arg.message) + if arg.isError { + require.NotNil(t, errAnswer) + require.Contains(t, errAnswer.Error(), arg.contains) + require.NotNil(t, id) + require.Equal(t, arg.ID, *id) + } else { + require.Nil(t, errAnswer) + require.Equal(t, arg.expected, arg.socket.Res) + } + }) + } +} + +func Test_handleHeartbeat(t *testing.T) { + subs := types.NewSubscribers() + queries := types.NewQueries(&noLog) + peers := types.NewPeers() + + state.SetState(subs, peers, queries) + + mockRepository := repository.NewMockRepository(t) + database.SetDatabase(mockRepository) + + type input struct { + name string + socket socket.FakeSocket + message []byte + expected map[string][]string + isError bool + contains string + } + + msgIDs := []string{"msg0", "msg1", "msg2", "msg3", "msg4", "msg5", "msg6"} + + args := make([]input, 0) + + // Test 1: successfully handled heartbeat with some messages to catching up + + fakeSocket := socket.FakeSocket{Id: "1"} + + heartbeatMsgIDs1 := make(map[string][]string) + heartbeatMsgIDs1["/root"] = []string{ + msgIDs[0], + msgIDs[1], + msgIDs[2], + } + heartbeatMsgIDs1["root/lao1"] = []string{ + msgIDs[3], + msgIDs[4], + } + heartbeatMsgIDs1["root/lao2"] = []string{ + msgIDs[5], + msgIDs[6], + } + + expected1 := make(map[string][]string) + expected1["/root"] = []string{ + msgIDs[1], + msgIDs[2], + } + expected1["root/lao1"] = []string{ + msgIDs[4], + } + + mockRepository.On("GetParamsForGetMessageByID", heartbeatMsgIDs1).Return(expected1, nil) + + args = append(args, input{ + name: "Test 1", + socket: fakeSocket, + message: generatortest.NewHeartbeatQuery(t, heartbeatMsgIDs1), + expected: expected1, + isError: false, + }) + + // Test 2: successfully handled heartbeat with nothing to catching up + + fakeSocket = socket.FakeSocket{Id: "2"} + + heartbeatMsgIDs2 := make(map[string][]string) + heartbeatMsgIDs2["/root"] = []string{ + msgIDs[0], + msgIDs[1], + msgIDs[2], + } + + mockRepository.On("GetParamsForGetMessageByID", heartbeatMsgIDs2).Return(nil, nil) + + args = append(args, input{ + name: "Test 2", + socket: fakeSocket, + message: generatortest.NewHeartbeatQuery(t, heartbeatMsgIDs2), + isError: false, + }) + + // Test 3: failed to handled heartbeat because DB is disconnected + + fakeSocket = socket.FakeSocket{Id: "3"} + + heartbeatMsgIDs3 := make(map[string][]string) + heartbeatMsgIDs3["/root"] = []string{ + msgIDs[0], + msgIDs[1], + msgIDs[2], + } + heartbeatMsgIDs3["root/lao1"] = []string{ + msgIDs[3], + msgIDs[4], + } + + mockRepository.On("GetParamsForGetMessageByID", heartbeatMsgIDs3). + Return(nil, xerrors.Errorf("DB is disconnected")) + + args = append(args, input{ + name: "failed to popquery DB", + socket: fakeSocket, + message: generatortest.NewHeartbeatQuery(t, heartbeatMsgIDs3), + isError: true, + contains: "DB is disconnected", + }) + + // run all tests + + for _, arg := range args { + t.Run(arg.name, func(t *testing.T) { + errAnswer := handleHeartbeat(&arg.socket, arg.message) + if arg.isError { + require.NotNil(t, errAnswer) + } else if arg.expected != nil { + require.Nil(t, errAnswer) + require.NotNil(t, arg.socket.Msg) + + var getMessageByID method.GetMessagesById + err := json.Unmarshal(arg.socket.Msg, &getMessageByID) + require.NoError(t, err) + + require.Equal(t, arg.expected, getMessageByID.Params) + } else { + require.Nil(t, errAnswer) + require.Nil(t, arg.socket.Msg) + } + }) + } +} + +func Test_handleGetMessagesByID(t *testing.T) { + subs := types.NewSubscribers() + queries := types.NewQueries(&noLog) + peers := types.NewPeers() + + state.SetState(subs, peers, queries) + + mockRepository := repository.NewMockRepository(t) + database.SetDatabase(mockRepository) + + type input struct { + name string + socket socket.FakeSocket + ID int + message []byte + expected map[string][]message.Message + isError bool + contains string + } + + args := make([]input, 0) + + // Test 1: successfully handled getMessagesByID and sent the result + + fakeSocket := socket.FakeSocket{Id: "1"} + ID := 1 + + expected1 := make(map[string][]message.Message) + expected1["/root"] = []message.Message{ + generatortest.NewNothingMsg(t, "sender1", nil), + generatortest.NewNothingMsg(t, "sender2", nil), + generatortest.NewNothingMsg(t, "sender3", nil), + generatortest.NewNothingMsg(t, "sender4", nil), + } + expected1["/root/lao1"] = []message.Message{ + generatortest.NewNothingMsg(t, "sender5", nil), + generatortest.NewNothingMsg(t, "sender6", nil), + } + + paramsGetMessagesByID1 := make(map[string][]string) + for k, v := range expected1 { + paramsGetMessagesByID1[k] = make([]string, 0) + for _, w := range v { + paramsGetMessagesByID1[k] = append(paramsGetMessagesByID1[k], w.MessageID) + } + } + + mockRepository.On("GetResultForGetMessagesByID", paramsGetMessagesByID1).Return(expected1, nil) + + args = append(args, input{ + name: "Test 1", + socket: fakeSocket, + ID: ID, + message: generatortest.NewGetMessagesByIDQuery(t, ID, paramsGetMessagesByID1), + expected: expected1, + isError: false, + }) + + // Test 2: failed to handled getMessagesByID because DB is disconnected + + fakeSocket = socket.FakeSocket{Id: "2"} + ID = 2 + + paramsGetMessagesByID2 := make(map[string][]string) + + mockRepository.On("GetResultForGetMessagesByID", paramsGetMessagesByID2). + Return(nil, xerrors.Errorf("DB is disconnected")) + + args = append(args, input{ + name: "Test 2", + socket: fakeSocket, + ID: ID, + message: generatortest.NewGetMessagesByIDQuery(t, ID, paramsGetMessagesByID2), + isError: true, + contains: "DB is disconnected", + }) + + // run all tests + + for _, arg := range args { + t.Run(arg.name, func(t *testing.T) { + id, errAnswer := handleGetMessagesByID(&arg.socket, arg.message) + if arg.isError { + require.NotNil(t, errAnswer) + require.NotNil(t, id) + require.Contains(t, errAnswer.Error(), arg.contains) + require.Equal(t, arg.ID, *id) + } else { + require.Nil(t, errAnswer) + require.NotNil(t, arg.expected) + require.Equal(t, arg.expected, arg.socket.MissingMsgs) + } + }) + } +} diff --git a/be1-go/internal/popserver/handler/reaction.go b/be1-go/internal/popserver/handler/reaction.go new file mode 100644 index 0000000000..75edff974c --- /dev/null +++ b/be1-go/internal/popserver/handler/reaction.go @@ -0,0 +1,109 @@ +package handler + +import ( + "popstellar/internal/popserver/database" + "popstellar/message/answer" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "strings" +) + +func handleChannelReaction(channel string, msg message.Message) *answer.Error { + object, action, errAnswer := verifyDataAndGetObjectAction(msg) + if errAnswer != nil { + return errAnswer.Wrap("handleChannelReaction") + } + + db, errAnswer := database.GetReactionRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleChannelReaction") + } + + laoPath, _ := strings.CutSuffix(channel, Social+Reactions) + isAttendee, err := db.IsAttendee(laoPath, msg.Sender) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("if is attendee: %v", err) + return errAnswer.Wrap("handleChannelReaction") + } + if !isAttendee { + errAnswer := answer.NewAccessDeniedError("user not inside roll-call") + return errAnswer.Wrap("handleChannelReaction") + } + + switch object + "#" + action { + case messagedata.ReactionObject + "#" + messagedata.ReactionActionAdd: + errAnswer = handleReactionAdd(msg) + case messagedata.ReactionObject + "#" + messagedata.ReactionActionDelete: + errAnswer = handleReactionDelete(msg) + default: + errAnswer = answer.NewInvalidMessageFieldError("failed to handle %s#%s, invalid object#action", object, action) + } + if errAnswer != nil { + return errAnswer.Wrap("handleChannelReaction") + } + + err = db.StoreMessageAndData(channel, msg) + if err != nil { + errAnswer := answer.NewStoreDatabaseError(err.Error()) + return errAnswer.Wrap("handleChannelReaction") + } + + errAnswer = broadcastToAllClients(msg, channel) + if errAnswer != nil { + return errAnswer.Wrap("handleChannelReaction") + } + + return nil + +} + +func handleReactionAdd(msg message.Message) *answer.Error { + var reactMsg messagedata.ReactionAdd + errAnswer := msg.UnmarshalMsgData(&reactMsg) + if errAnswer != nil { + return errAnswer.Wrap("handleReactionAdd") + } + + err := reactMsg.Verify() + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("invalid message: %v", err) + return errAnswer.Wrap("handleReactionAdd") + } + + return nil +} + +func handleReactionDelete(msg message.Message) *answer.Error { + var delReactMsg messagedata.ReactionDelete + errAnswer := msg.UnmarshalMsgData(&delReactMsg) + if errAnswer != nil { + return errAnswer.Wrap("handleReactionDelete") + } + + err := delReactMsg.Verify() + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("invalid message: %v", err) + return errAnswer.Wrap("handleReactionDelete") + } + + db, errAnswer := database.GetReactionRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("handleReactionDelete") + } + reactSender, err := db.GetReactionSender(delReactMsg.ReactionID) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("sender of the reaction %s: %v", delReactMsg.ReactionID, err) + return errAnswer.Wrap("handleReactionDelete") + } + if reactSender == "" { + errAnswer := answer.NewInvalidResourceError("unknown reaction") + return errAnswer.Wrap("handleReactionDelete") + } + + if msg.Sender != reactSender { + errAnswer := answer.NewAccessDeniedError("only the owner of the reaction can delete it") + return errAnswer.Wrap("handleReactionDelete") + } + + return nil +} diff --git a/be1-go/internal/popserver/handler/reaction_test.go b/be1-go/internal/popserver/handler/reaction_test.go new file mode 100644 index 0000000000..342b442aae --- /dev/null +++ b/be1-go/internal/popserver/handler/reaction_test.go @@ -0,0 +1,280 @@ +package handler + +import ( + "encoding/base64" + "github.com/stretchr/testify/require" + "popstellar/crypto" + "popstellar/internal/popserver/config" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/database/repository" + "popstellar/internal/popserver/generatortest" + "popstellar/internal/popserver/state" + "popstellar/internal/popserver/types" + "popstellar/message/query/method/message" + "strings" + "testing" + "time" +) + +func Test_handleChannelReaction(t *testing.T) { + subs := types.NewSubscribers() + queries := types.NewQueries(&noLog) + peers := types.NewPeers() + + state.SetState(subs, peers, queries) + + organizerBuf, err := base64.URLEncoding.DecodeString(ownerPubBuf64) + require.NoError(t, err) + + ownerPublicKey := crypto.Suite.Point() + err = ownerPublicKey.UnmarshalBinary(organizerBuf) + require.NoError(t, err) + + serverSecretKey := crypto.Suite.Scalar().Pick(crypto.Suite.RandomStream()) + serverPublicKey := crypto.Suite.Point().Mul(serverSecretKey, nil) + + config.SetConfig(ownerPublicKey, serverPublicKey, serverSecretKey, "clientAddress", "serverAddress") + + mockRepository := repository.NewMockRepository(t) + database.SetDatabase(mockRepository) + + sender := "3yPmdBu8DM7jT30IKqkPjuFFIHnubO0z4E0dV7dR4sY=" + //wrongSender := "3yPmdBu8DM7jT30IKqkPjuFFIHnubO0z4E0dV7dR4sK=" + chirpID := "AAAAdBu8DM7jT30IKqkPjuFFIHnubO0z4E0dV7dR4sK=" + invalidChirpID := "NotGooD" + + var args []input + + // Test 1: successfully add a reaction 👍 + + laoID := "lao1" + channelID := RootPrefix + laoID + Social + Reactions + + args = append(args, input{ + name: "Test 1", + channel: channelID, + msg: newReactionAddMsg(t, channelID, sender, "👍", chirpID, time.Now().Unix(), mockRepository, + false, false), + isError: false, + contains: "", + }) + + // Test 2: successfully add a reaction 👎 + + laoID = "lao2" + channelID = RootPrefix + laoID + Social + Reactions + + args = append(args, input{ + name: "Test 2", + channel: channelID, + msg: newReactionAddMsg(t, channelID, sender, "👎", chirpID, time.Now().Unix(), mockRepository, + false, false), + isError: false, + contains: "", + }) + + // Test 3: successfully add a reaction ❤️ + + laoID = "lao3" + channelID = RootPrefix + laoID + Social + Reactions + + args = append(args, input{ + name: "Test 3", + channel: channelID, + msg: newReactionAddMsg(t, channelID, sender, "❤️", chirpID, time.Now().Unix(), mockRepository, + false, false), + isError: false, + contains: "", + }) + + // Test 4: failed to add a reaction because wrong chirpID + + laoID = "lao4" + channelID = RootPrefix + laoID + Social + Reactions + + args = append(args, input{ + name: "Test 4", + channel: channelID, + msg: newReactionAddMsg(t, channelID, sender, "👍", invalidChirpID, time.Now().Unix(), mockRepository, + true, false), + isError: true, + contains: "invalid message field", + }) + + // Test 5: failed to add a reaction because negative timestamp + + laoID = "lao5" + channelID = RootPrefix + laoID + Social + Reactions + + args = append(args, input{ + name: "Test 5", + channel: channelID, + msg: newReactionAddMsg(t, channelID, sender, "👍", chirpID, -1, mockRepository, + true, false), + isError: true, + contains: "invalid message field", + }) + + // Test 6: failed to add a reaction because didn't participate in roll-call + + laoID = "lao6" + channelID = RootPrefix + laoID + Social + Reactions + + args = append(args, input{ + name: "Test 6", + channel: channelID, + msg: newReactionAddMsg(t, channelID, sender, "👍", chirpID, time.Now().Unix(), mockRepository, + false, true), + isError: true, + contains: "user not inside roll-call", + }) + + // Test 7: successfully delete a reaction + + laoID = "lao7" + channelID = RootPrefix + laoID + Social + Reactions + reactionID := "AAAAdBu8DM7jT30IKqkPjuFFIHnubO0z4E0dV7dR4sK=" + + args = append(args, input{ + name: "Test 7", + channel: channelID, + msg: newReactionDeleteMsg(t, channelID, sender, reactionID, time.Now().Unix(), mockRepository, + false, false, false, false), + isError: false, + contains: "", + }) + + // Test 8: failed to delete a reaction because negative timestamp + + laoID = "lao8" + channelID = RootPrefix + laoID + Social + Reactions + reactionID = "AAAAABu8DM7jT30IKqkPjuFFIHnubO0z4E0dV7dR4sK=" + + args = append(args, input{ + name: "Test 8", + channel: channelID, + msg: newReactionDeleteMsg(t, channelID, sender, reactionID, -1, mockRepository, + true, false, false, false), + isError: true, + contains: "invalid message field", + }) + + // Test 9: failed to delete a reaction because reaction doesn't exist + + laoID = "lao9" + channelID = RootPrefix + laoID + Social + Reactions + reactionID = "AAAAdBB8DM7jT30IKqkPjuFFIHnubO0z4E0dV7dR4sK=" + + args = append(args, input{ + name: "Test 9", + channel: channelID, + msg: newReactionDeleteMsg(t, channelID, sender, reactionID, time.Now().Unix(), mockRepository, + false, true, false, false), + isError: true, + contains: "unknown reaction", + }) + + // Test 10: failed to delete a reaction because not owner + + laoID = "lao10" + channelID = RootPrefix + laoID + Social + Reactions + reactionID = "AAAAdBB8DM7jT30IKqkPjuFFIHnubO0z4E0dV7dR4KK=" + + args = append(args, input{ + name: "Test 10", + channel: channelID, + msg: newReactionDeleteMsg(t, channelID, sender, reactionID, time.Now().Unix(), mockRepository, + false, false, true, false), + isError: true, + contains: "only the owner of the reaction can delete it", + }) + + // Test 11: failed to delete a reaction because didn't participate in roll-call + + laoID = "lao11" + channelID = RootPrefix + laoID + Social + Reactions + reactionID = "AAAAdBB8DM7jT30IKqkPjuFFIHnubO0z4E0dV7dRYKK=" + + args = append(args, input{ + name: "Test 11", + channel: channelID, + msg: newReactionDeleteMsg(t, channelID, sender, reactionID, time.Now().Unix(), mockRepository, + false, false, false, true), + isError: true, + contains: "user not inside roll-call", + }) + + // Tests all cases + + for _, arg := range args { + t.Run(arg.name, func(t *testing.T) { + errAnswer := handleChannelReaction(arg.channel, arg.msg) + if arg.isError { + require.NotNil(t, errAnswer) + require.Contains(t, errAnswer.Error(), arg.contains) + } else { + require.Nil(t, errAnswer) + } + }) + } + +} + +func newReactionAddMsg(t *testing.T, channelID string, sender string, reactionCodePoint, chirpID string, timestamp int64, + mockRepository *repository.MockRepository, hasInvalidField, isNotAttendee bool) message.Message { + + msg := generatortest.NewReactionAddMsg(t, sender, nil, reactionCodePoint, chirpID, timestamp) + + errAnswer := state.AddChannel(channelID) + require.Nil(t, errAnswer) + + laoPath, _ := strings.CutSuffix(channelID, Social+Reactions) + + if !hasInvalidField && !isNotAttendee { + mockRepository.On("IsAttendee", laoPath, sender).Return(true, nil) + mockRepository.On("StoreMessageAndData", channelID, msg).Return(nil) + } + + if isNotAttendee { + mockRepository.On("IsAttendee", laoPath, sender).Return(false, nil) + } + + return msg +} + +func newReactionDeleteMsg(t *testing.T, channelID string, sender string, reactionID string, timestamp int64, + mockRepository *repository.MockRepository, hasInvalidField, hasNotReaction, isNotOwner, isNotAttendee bool) message.Message { + + msg := generatortest.NewReactionDeleteMsg(t, sender, nil, reactionID, timestamp) + + errAnswer := state.AddChannel(channelID) + require.Nil(t, errAnswer) + + laoPath, _ := strings.CutSuffix(channelID, Social+Reactions) + + if !hasInvalidField && !hasNotReaction && !isNotOwner && !isNotAttendee { + mockRepository.On("IsAttendee", laoPath, sender).Return(true, nil) + + mockRepository.On("GetReactionSender", reactionID).Return(sender, nil) + + mockRepository.On("StoreMessageAndData", channelID, msg).Return(nil) + } + + if hasNotReaction { + mockRepository.On("IsAttendee", laoPath, sender).Return(true, nil) + + mockRepository.On("GetReactionSender", reactionID).Return("", nil) + } + + if isNotOwner { + mockRepository.On("IsAttendee", laoPath, sender).Return(true, nil) + + mockRepository.On("GetReactionSender", reactionID).Return("notSender", nil) + } + + if isNotAttendee { + mockRepository.On("IsAttendee", laoPath, sender).Return(false, nil) + } + + return msg +} diff --git a/be1-go/internal/popserver/handler/root.go b/be1-go/internal/popserver/handler/root.go new file mode 100644 index 0000000000..9062738e21 --- /dev/null +++ b/be1-go/internal/popserver/handler/root.go @@ -0,0 +1,230 @@ +package handler + +import ( + "encoding/base64" + "encoding/json" + "popstellar/crypto" + "popstellar/internal/popserver/config" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/database/sqlite" + "popstellar/internal/popserver/state" + "popstellar/message/answer" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" +) + +const ( + Root = "/root" + RootPrefix = "/root/" + Social = "/social" + Chirps = "/chirps" + Reactions = "/reactions" + Consensus = "/consensus" + Coin = "/coin" + Auth = "/authentication" +) + +func handleChannelRoot(msg message.Message) *answer.Error { + object, action, errAnswer := verifyDataAndGetObjectAction(msg) + if errAnswer != nil { + return errAnswer.Wrap("handleChannelRoot") + } + + switch object + "#" + action { + case messagedata.LAOObject + "#" + messagedata.LAOActionCreate: + errAnswer = handleLaoCreate(msg) + default: + errAnswer = answer.NewInvalidMessageFieldError("failed to handle %s#%s, invalid object#action", object, action) + } + + if errAnswer != nil { + return errAnswer.Wrap("handleChannelRoot") + } + + return nil +} + +func handleLaoCreate(msg message.Message) *answer.Error { + var laoCreate messagedata.LaoCreate + errAnswer := msg.UnmarshalMsgData(&laoCreate) + if errAnswer != nil { + return errAnswer.Wrap("handleLaoCreate") + } + + laoPath := RootPrefix + laoCreate.ID + organizerPubBuf, errAnswer := verifyLaoCreation(msg, laoCreate, laoPath) + if errAnswer != nil { + return errAnswer.Wrap("handleLaoCreate") + } + laoGreetMsg, errAnswer := createLaoGreet(organizerPubBuf, laoCreate.ID) + if errAnswer != nil { + return errAnswer.Wrap("handleLaoCreate") + } + errAnswer = createLaoAndChannels(msg, laoGreetMsg, organizerPubBuf, laoPath) + if errAnswer != nil { + return errAnswer.Wrap("handleLaoCreate") + } + return nil +} + +func verifyLaoCreation(msg message.Message, laoCreate messagedata.LaoCreate, laoPath string) ([]byte, *answer.Error) { + db, errAnswer := database.GetRootRepositoryInstance() + if errAnswer != nil { + return nil, errAnswer.Wrap("verifyLAOCreation") + } + + ok, err := db.HasChannel(laoPath) + if err != nil { + errAnswer := answer.NewQueryDatabaseError("if lao already exists: %v", err) + return nil, errAnswer.Wrap("verifyLAOCreation") + } else if ok { + errAnswer := answer.NewDuplicateResourceError("failed to create lao: duplicate lao path: %s", laoPath) + return nil, errAnswer.Wrap("verifyLAOCreation") + } + + err = laoCreate.Verify() + if err != nil { + errAnswer := answer.NewInvalidActionError("failed to verify message data: %v", err) + return nil, errAnswer.Wrap("verifyLAOCreation") + } + + senderPubBuf, err := base64.URLEncoding.DecodeString(msg.Sender) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to decode public key of the sender: %v", err) + return nil, errAnswer.Wrap("verifyLAOCreation") + } + + senderPubKey := crypto.Suite.Point() + err = senderPubKey.UnmarshalBinary(senderPubBuf) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to unmarshal public key of the sender: %v", err) + return nil, errAnswer.Wrap("verifyLAOCreation") + } + + organizerPubBuf, err := base64.URLEncoding.DecodeString(laoCreate.Organizer) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to decode public key of the organizer: %v", err) + return nil, errAnswer.Wrap("verifyLAOCreation") + } + + organizerPubKey := crypto.Suite.Point() + err = organizerPubKey.UnmarshalBinary(organizerPubBuf) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("failed to unmarshal public key of the organizer: %v", err) + return nil, errAnswer.Wrap("verifyLAOCreation") + } + // Check if the sender and organizer fields of the create#lao message are equal + if !organizerPubKey.Equal(senderPubKey) { + errAnswer := answer.NewAccessDeniedError("sender's public key does not match the organizer public key: %s != %s", + senderPubKey, organizerPubKey) + return nil, errAnswer.Wrap("verifyLAOCreation") + } + + ownerPublicKey, errAnswer := config.GetOwnerPublicKeyInstance() + if errAnswer != nil { + return nil, errAnswer.Wrap("verifyLAOCreation") + } + + // Check if the sender of the LAO creation message is the owner + if ownerPublicKey != nil && !ownerPublicKey.Equal(senderPubKey) { + errAnswer := answer.NewAccessDeniedError("sender's public key does not match the owner public key: %s != %s", + senderPubKey, ownerPublicKey) + return nil, errAnswer.Wrap("verifyLAOCreation") + } + return organizerPubBuf, nil +} + +func createLaoAndChannels(msg, laoGreetMsg message.Message, organizerPubBuf []byte, laoPath string) *answer.Error { + channels := map[string]string{ + laoPath: sqlite.LaoType, + laoPath + Social + Chirps: sqlite.ChirpType, + laoPath + Social + Reactions: sqlite.ReactionType, + laoPath + Consensus: sqlite.ConsensusType, + laoPath + Coin: sqlite.CoinType, + laoPath + Auth: sqlite.AuthType, + } + + db, errAnswer := database.GetRootRepositoryInstance() + if errAnswer != nil { + return errAnswer.Wrap("createLaoAndSubChannels") + } + + err := db.StoreLaoWithLaoGreet(channels, laoPath, organizerPubBuf, msg, laoGreetMsg) + if err != nil { + errAnswer := answer.NewStoreDatabaseError("lao and sub channels: %v", err) + return errAnswer.Wrap("createLaoAndSubChannels") + } + + for channelPath := range channels { + errAnswer := state.AddChannel(channelPath) + if errAnswer != nil { + return errAnswer.Wrap("createLaoAndSubChannels") + } + } + return nil +} + +func createLaoGreet(organizerBuf []byte, laoID string) (message.Message, *answer.Error) { + peersInfo, errAnswer := state.GetAllPeersInfo() + if errAnswer != nil { + return message.Message{}, errAnswer.Wrap("createAndSendLaoGreet") + } + + knownPeers := make([]messagedata.Peer, 0, len(peersInfo)) + for _, info := range peersInfo { + knownPeers = append(knownPeers, messagedata.Peer{Address: info.ClientAddress}) + } + + _, clientServerAddress, _, errAnswer := config.GetServerInfo() + if errAnswer != nil { + return message.Message{}, errAnswer.Wrap("createAndSendLaoGreet") + } + + msgData := messagedata.LaoGreet{ + Object: messagedata.LAOObject, + Action: messagedata.LAOActionGreet, + LaoID: laoID, + Frontend: base64.URLEncoding.EncodeToString(organizerBuf), + Address: clientServerAddress, + Peers: knownPeers, + } + + // Marshall the message data + dataBuf, err := json.Marshal(&msgData) + if err != nil { + errAnswer := answer.NewInternalServerError("failed to marshal message data: %v", err) + return message.Message{}, errAnswer.Wrap("createAndSendLaoGreet") + } + + newData64 := base64.URLEncoding.EncodeToString(dataBuf) + + serverPublicKey, errAnswer := config.GetServerPublicKeyInstance() + if errAnswer != nil { + return message.Message{}, errAnswer.Wrap("createAndSendLaoGreet") + } + + // Marshall the server public key + serverPubBuf, err := serverPublicKey.MarshalBinary() + if err != nil { + errAnswer := answer.NewInternalServerError("failed to marshal server public key: %v", err) + return message.Message{}, errAnswer.Wrap("createAndSendLaoGreet") + } + + // Sign the data + signatureBuf, errAnswer := Sign(dataBuf) + if errAnswer != nil { + return message.Message{}, errAnswer.Wrap("createAndSendLaoGreet") + } + + signature := base64.URLEncoding.EncodeToString(signatureBuf) + + laoGreetMsg := message.Message{ + Data: newData64, + Sender: base64.URLEncoding.EncodeToString(serverPubBuf), + Signature: signature, + MessageID: messagedata.Hash(newData64, signature), + WitnessSignatures: []message.WitnessSignature{}, + } + + return laoGreetMsg, nil +} diff --git a/be1-go/internal/popserver/handler/root_test.go b/be1-go/internal/popserver/handler/root_test.go new file mode 100644 index 0000000000..178f073859 --- /dev/null +++ b/be1-go/internal/popserver/handler/root_test.go @@ -0,0 +1,147 @@ +package handler + +import ( + "encoding/base64" + "fmt" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "popstellar/crypto" + "popstellar/internal/popserver/config" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/database/repository" + "popstellar/internal/popserver/database/sqlite" + "popstellar/internal/popserver/generatortest" + "popstellar/internal/popserver/state" + "popstellar/internal/popserver/types" + "popstellar/message/messagedata" + "popstellar/message/query/method/message" + "testing" + "time" +) + +const ( + // wrongSender A public key different from the owner public key + wrongSender = "M5ZychEi5rwm22FjwjNuljL1qMJWD2sE7oX9fcHNMDU=" + goodLaoName = "laoName" + wrongLaoName = "wrongLaoName" +) + +type input struct { + name string + channel string + msg message.Message + isError bool + contains string +} + +func Test_handleChannelRoot(t *testing.T) { + subs := types.NewSubscribers() + queries := types.NewQueries(&noLog) + peers := types.NewPeers() + + state.SetState(subs, peers, queries) + + organizerBuf, err := base64.URLEncoding.DecodeString(ownerPubBuf64) + require.NoError(t, err) + + ownerPublicKey := crypto.Suite.Point() + err = ownerPublicKey.UnmarshalBinary(organizerBuf) + require.NoError(t, err) + + serverSecretKey := crypto.Suite.Scalar().Pick(crypto.Suite.RandomStream()) + serverPublicKey := crypto.Suite.Point().Mul(serverSecretKey, nil) + + config.SetConfig(ownerPublicKey, serverPublicKey, serverSecretKey, "clientAddress", "serverAddress") + + var args []input + mockRepository := repository.NewMockRepository(t) + database.SetDatabase(mockRepository) + + ownerPubBuf, err := ownerPublicKey.MarshalBinary() + require.NoError(t, err) + owner := base64.URLEncoding.EncodeToString(ownerPubBuf) + + // Test 1: error when different organizer and sender keys + args = append(args, input{ + name: "Test 1", + msg: newLaoCreateMsg(t, owner, wrongSender, goodLaoName, mockRepository, true), + isError: true, + contains: "sender's public key does not match the organizer public key", + }) + + // Test 2: error when different sender and owner keys + args = append(args, input{ + name: "Test 2", + msg: newLaoCreateMsg(t, wrongSender, wrongSender, goodLaoName, mockRepository, true), + isError: true, + contains: "sender's public key does not match the owner public key", + }) + + // Test 3: error when the lao name is not the same as the one used for the laoID + args = append(args, input{ + name: "Test 3", + msg: newLaoCreateMsg(t, owner, owner, wrongLaoName, mockRepository, true), + isError: true, + contains: "failed to verify message data: invalid message field: lao id", + }) + + // Test 4: error when message data is not lao_create + args = append(args, input{ + name: "Test 4", + msg: generatortest.NewNothingMsg(t, owner, nil), + isError: true, + contains: "failed to validate schema", + }) + + // Test 5: success + args = append(args, input{ + name: "Test 5", + msg: newLaoCreateMsg(t, owner, owner, goodLaoName, mockRepository, false), + isError: false, + contains: "", + }) + + for _, arg := range args { + t.Run(arg.name, func(t *testing.T) { + errAnswer := handleChannelRoot(arg.msg) + if arg.isError { + require.NotNil(t, errAnswer) + require.Contains(t, errAnswer.Error(), arg.contains) + } else { + require.Nil(t, errAnswer) + } + }) + } +} + +func newLaoCreateMsg(t *testing.T, organizer, sender, laoName string, mockRepository *repository.MockRepository, isError bool) message.Message { + creation := time.Now().Unix() + laoID := messagedata.Hash( + organizer, + fmt.Sprintf("%d", creation), + goodLaoName, + ) + + msg := generatortest.NewLaoCreateMsg(t, sender, laoID, laoName, creation, organizer, nil) + + mockRepository.On("HasChannel", RootPrefix+laoID).Return(false, nil) + if !isError { + laoPath := RootPrefix + laoID + organizerBuf, err := base64.URLEncoding.DecodeString(organizer) + require.NoError(t, err) + channels := map[string]string{ + laoPath: sqlite.LaoType, + laoPath + Social + Chirps: sqlite.ChirpType, + laoPath + Social + Reactions: sqlite.ReactionType, + laoPath + Consensus: sqlite.ConsensusType, + laoPath + Coin: sqlite.CoinType, + laoPath + Auth: sqlite.AuthType, + } + mockRepository.On("StoreLaoWithLaoGreet", + channels, + laoPath, + organizerBuf, + msg, mock.AnythingOfType("message.Message")).Return(nil) + } + return msg +} diff --git a/be1-go/internal/popserver/hub.go b/be1-go/internal/popserver/hub.go new file mode 100644 index 0000000000..07d0c62cb4 --- /dev/null +++ b/be1-go/internal/popserver/hub.go @@ -0,0 +1,155 @@ +package popserver + +import ( + "encoding/json" + "golang.org/x/xerrors" + "popstellar/internal/popserver/config" + "popstellar/internal/popserver/database" + "popstellar/internal/popserver/handler" + "popstellar/internal/popserver/state" + "popstellar/internal/popserver/types" + "popstellar/internal/popserver/utils" + jsonrpc "popstellar/message" + "popstellar/message/query" + "popstellar/message/query/method" + "popstellar/network/socket" + "time" +) + +const heartbeatDelay = 30 * time.Second + +type Hub struct { + messageChan chan socket.IncomingMessage + stop chan struct{} + closedSockets chan string + serverSockets types.Sockets +} + +func NewHub() *Hub { + return &Hub{ + messageChan: make(chan socket.IncomingMessage), + stop: make(chan struct{}), + closedSockets: make(chan string), + serverSockets: types.NewSockets(), + } +} + +func (h *Hub) NotifyNewServer(socket socket.Socket) { + h.serverSockets.Upsert(socket) +} + +func (h *Hub) Start() { + go func() { + ticker := time.NewTicker(heartbeatDelay) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + h.sendHeartbeatToServers() + case <-h.stop: + utils.LogInfo("stopping the heartbeat") + return + } + } + }() + go func() { + utils.LogInfo("start the Hub") + for { + select { + case incomingMessage := <-h.messageChan: + utils.LogInfo("start handling a message") + err := handler.HandleIncomingMessage(incomingMessage.Socket, incomingMessage.Message) + if err != nil { + utils.LogError(err) + } else { + utils.LogInfo("successfully handled a message") + } + case <-h.closedSockets: + utils.LogInfo("stopping the Sockets") + return + case <-h.stop: + utils.LogInfo("stopping the Hub") + return + } + } + }() +} + +func (h *Hub) Stop() { + close(h.stop) +} + +func (h *Hub) Receiver() chan<- socket.IncomingMessage { + return h.messageChan +} + +func (h *Hub) OnSocketClose() chan<- string { + return h.closedSockets +} + +func (h *Hub) SendGreetServer(socket socket.Socket) error { + serverPublicKey, clientAddress, serverAddress, errAnswer := config.GetServerInfo() + if errAnswer != nil { + return xerrors.Errorf(errAnswer.Error()) + } + + greetServerParams := method.GreetServerParams{ + PublicKey: serverPublicKey, + ServerAddress: serverAddress, + ClientAddress: clientAddress, + } + + serverGreet := &method.GreetServer{ + Base: query.Base{ + JSONRPCBase: jsonrpc.JSONRPCBase{ + JSONRPC: "2.0", + }, + Method: query.MethodGreetServer, + }, + Params: greetServerParams, + } + + buf, err := json.Marshal(serverGreet) + if err != nil { + return xerrors.Errorf("failed to marshal: %v", err) + } + + socket.Send(buf) + + errAnswer = state.AddPeerGreeted(socket.ID()) + if errAnswer != nil { + return xerrors.Errorf(errAnswer.Error()) + } + return nil +} + +// sendHeartbeatToServers sends a heartbeat message to all servers +func (h *Hub) sendHeartbeatToServers() { + + db, errAnswer := database.GetQueryRepositoryInstance() + if errAnswer != nil { + return + } + + params, err := db.GetParamsHeartbeat() + if err != nil { + return + } + + heartbeatMessage := method.Heartbeat{ + Base: query.Base{ + JSONRPCBase: jsonrpc.JSONRPCBase{ + JSONRPC: "2.0", + }, + Method: "heartbeat", + }, + Params: params, + } + + buf, err := json.Marshal(heartbeatMessage) + if err != nil { + utils.LogError(err) + } + h.serverSockets.SendToAll(buf) +} diff --git a/be1-go/internal/popserver/state/state.go b/be1-go/internal/popserver/state/state.go new file mode 100644 index 0000000000..65bed7164e --- /dev/null +++ b/be1-go/internal/popserver/state/state.go @@ -0,0 +1,209 @@ +package state + +import ( + "github.com/rs/zerolog" + "popstellar/internal/popserver/types" + "popstellar/message/answer" + "popstellar/message/query/method" + "popstellar/network/socket" + "sync" +) + +var once sync.Once +var instance *state + +type state struct { + subs Subscriber + peers Peerer + queries Querier +} + +type Subscriber interface { + AddChannel(channel string) *answer.Error + HasChannel(channel string) bool + Subscribe(channel string, socket socket.Socket) *answer.Error + Unsubscribe(channel string, socket socket.Socket) *answer.Error + SendToAll(buf []byte, channel string) *answer.Error +} + +type Peerer interface { + AddPeerInfo(socketID string, info method.GreetServerParams) error + AddPeerGreeted(socketID string) + GetAllPeersInfo() []method.GreetServerParams + IsPeerGreeted(socketID string) bool +} + +type Querier interface { + GetQueryState(ID int) (bool, error) + GetNextID() int + SetQueryReceived(ID int) error + AddQuery(ID int, query method.GetMessagesById) +} + +func InitState(log *zerolog.Logger) { + once.Do(func() { + instance = &state{ + subs: types.NewSubscribers(), + peers: types.NewPeers(), + queries: types.NewQueries(log), + } + }) +} + +// ONLY FOR TEST PURPOSE +// SetState is only here to be used to reset the state before each test +func SetState(subs Subscriber, peers Peerer, queries Querier) { + instance = &state{ + subs: subs, + peers: peers, + queries: queries, + } +} + +func getSubs() (Subscriber, *answer.Error) { + if instance == nil || instance.subs == nil { + return nil, answer.NewInternalServerError("subscriber was not instantiated") + } + + return instance.subs, nil +} + +func AddChannel(channel string) *answer.Error { + subs, errAnswer := getSubs() + if errAnswer != nil { + return errAnswer + } + + return subs.AddChannel(channel) +} + +func HasChannel(channel string) (bool, *answer.Error) { + subs, errAnswer := getSubs() + if errAnswer != nil { + return false, errAnswer + } + + return subs.HasChannel(channel), nil +} + +func Subscribe(socket socket.Socket, channel string) *answer.Error { + subs, errAnswer := getSubs() + if errAnswer != nil { + return errAnswer + } + + return subs.Subscribe(channel, socket) +} + +func Unsubscribe(socket socket.Socket, channel string) *answer.Error { + subs, errAnswer := getSubs() + if errAnswer != nil { + return errAnswer + } + + return subs.Unsubscribe(channel, socket) +} + +func SendToAll(buf []byte, channel string) *answer.Error { + subs, errAnswer := getSubs() + if errAnswer != nil { + return errAnswer + } + + return subs.SendToAll(buf, channel) +} + +func getPeers() (Peerer, *answer.Error) { + if instance == nil || instance.peers == nil { + return nil, answer.NewInternalServerError("peerer was not instantiated") + } + + return instance.peers, nil +} + +func AddPeerInfo(socketID string, info method.GreetServerParams) *answer.Error { + peers, errAnswer := getPeers() + if errAnswer != nil { + return errAnswer + } + + err := peers.AddPeerInfo(socketID, info) + if err != nil { + errAnswer := answer.NewInvalidActionError("failed to add peer: %v", err) + return errAnswer + } + + return nil +} + +func AddPeerGreeted(socketID string) *answer.Error { + peers, errAnswer := getPeers() + if errAnswer != nil { + return errAnswer + } + + peers.AddPeerGreeted(socketID) + + return nil +} + +func GetAllPeersInfo() ([]method.GreetServerParams, *answer.Error) { + peers, errAnswer := getPeers() + if errAnswer != nil { + return nil, errAnswer + } + + return peers.GetAllPeersInfo(), nil +} + +func IsPeerGreeted(socketID string) (bool, *answer.Error) { + peers, errAnswer := getPeers() + if errAnswer != nil { + return false, errAnswer + } + + return peers.IsPeerGreeted(socketID), nil +} + +func getQueries() (Querier, *answer.Error) { + if instance == nil || instance.queries == nil { + return nil, answer.NewInternalServerError("querier was not instantiated") + } + + return instance.queries, nil +} + +func GetNextID() (int, *answer.Error) { + queries, errAnswer := getQueries() + if errAnswer != nil { + return -1, errAnswer + } + + return queries.GetNextID(), nil +} + +func SetQueryReceived(ID int) *answer.Error { + queries, errAnswer := getQueries() + if errAnswer != nil { + return errAnswer + } + + err := queries.SetQueryReceived(ID) + if err != nil { + errAnswer := answer.NewInvalidActionError("%v", err) + return errAnswer + } + + return nil +} + +func AddQuery(ID int, query method.GetMessagesById) *answer.Error { + queries, errAnswer := getQueries() + if errAnswer != nil { + return errAnswer + } + + queries.AddQuery(ID, query) + + return nil +} diff --git a/be1-go/internal/popserver/types/peers.go b/be1-go/internal/popserver/types/peers.go new file mode 100644 index 0000000000..41c7859591 --- /dev/null +++ b/be1-go/internal/popserver/types/peers.go @@ -0,0 +1,70 @@ +package types + +import ( + "popstellar/message/answer" + "popstellar/message/query/method" + "sync" + + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" +) + +// Peers stores the peers' information +type Peers struct { + sync.RWMutex + // peersInfo stores the info of the peers: public key, client and server endpoints associated with the socket ID + peersInfo map[string]method.GreetServerParams + // peersGreeted stores the peers that were greeted by the socket ID + peersGreeted map[string]struct{} +} + +// NewPeers creates a new Peers structure +func NewPeers() *Peers { + return &Peers{ + peersInfo: make(map[string]method.GreetServerParams), + peersGreeted: make(map[string]struct{}), + } +} + +// AddPeerInfo adds a peer's info to the table +func (p *Peers) AddPeerInfo(socketId string, info method.GreetServerParams) error { + p.Lock() + defer p.Unlock() + + currentInfo, ok := p.peersInfo[socketId] + if ok { + return answer.NewInvalidActionError( + "cannot add %s because peersInfo[%s] already contains %s", + info, socketId, currentInfo) + } + + p.peersInfo[socketId] = info + return nil +} + +// AddPeerGreeted adds a peer's socket ID to the slice of peers greeted +func (p *Peers) AddPeerGreeted(socketId string) { + p.Lock() + defer p.Unlock() + p.peersGreeted[socketId] = struct{}{} +} + +// GetAllPeersInfo returns a copy of the peers' info slice +func (p *Peers) GetAllPeersInfo() []method.GreetServerParams { + p.RLock() + defer p.RUnlock() + peersInfo := make([]method.GreetServerParams, 0, len(p.peersInfo)) + for _, info := range p.peersInfo { + if !slices.Contains(peersInfo, info) { + peersInfo = append(peersInfo, info) + } + } + return peersInfo +} + +// IsPeerGreeted returns true if the peer was greeted, otherwise it returns false +func (p *Peers) IsPeerGreeted(socketId string) bool { + p.RLock() + defer p.RUnlock() + return slices.Contains(maps.Keys(p.peersGreeted), socketId) +} diff --git a/be1-go/internal/popserver/types/queries.go b/be1-go/internal/popserver/types/queries.go new file mode 100644 index 0000000000..6d154f857b --- /dev/null +++ b/be1-go/internal/popserver/types/queries.go @@ -0,0 +1,83 @@ +package types + +import ( + "popstellar/message/query/method" + "sync" + + "github.com/rs/zerolog" + "golang.org/x/xerrors" +) + +// Queries let the hub remember all queries that it sent to other servers +type Queries struct { + sync.Mutex + // state stores the ID of the server's queries and their state. False for a + // query not yet answered, else true. + state map[int]bool + // getMessagesByIdQueries stores the server's getMessagesByIds queries by their ID. + getMessagesByIdQueries map[int]method.GetMessagesById + // nextID store the ID of the next query + nextID int + // zerolog + log *zerolog.Logger +} + +// NewQueries creates a new queries struct +func NewQueries(log *zerolog.Logger) *Queries { + return &Queries{ + state: make(map[int]bool), + getMessagesByIdQueries: make(map[int]method.GetMessagesById), + log: log, + } +} + +// GetQueryState returns a given query's state +func (q *Queries) GetQueryState(id int) (bool, error) { + q.Lock() + defer q.Unlock() + + state, ok := q.state[id] + if !ok { + return false, xerrors.Errorf("query with id %d not found", id) + } + return state, nil +} + +// GetNextID returns the next query ID +func (q *Queries) GetNextID() int { + q.Lock() + defer q.Unlock() + + id := q.nextID + q.nextID++ + return id +} + +// SetQueryReceived sets the state of the query with the given ID as received +func (q *Queries) SetQueryReceived(id int) error { + q.Lock() + defer q.Unlock() + + currentState, ok := q.state[id] + + if !ok { + return xerrors.Errorf("query with id %d not found", id) + } + + if currentState { + q.log.Info().Msgf("query with id %d already answered", id) + return nil + } + + q.state[id] = true + return nil +} + +// AddQuery adds the given query to the table +func (q *Queries) AddQuery(id int, query method.GetMessagesById) { + q.Lock() + defer q.Unlock() + + q.getMessagesByIdQueries[id] = query + q.state[id] = false +} diff --git a/be1-go/internal/popserver/types/question.go b/be1-go/internal/popserver/types/question.go new file mode 100644 index 0000000000..c44f5e222d --- /dev/null +++ b/be1-go/internal/popserver/types/question.go @@ -0,0 +1,31 @@ +package types + +type Question struct { + // ID represents the ID of the Question. + ID []byte + + // ballotOptions represents different ballot options. + BallotOptions []string + + // validVotes represents the list of all valid votes. The key represents + // the public key of the person casting the vote. + ValidVotes map[string]ValidVote + + // method represents the voting method of the election. Either "Plurality" + // or "Approval". + Method string +} + +type ValidVote struct { + // msgID represents the ID of the message containing the cast vote + MsgID string + + // ID represents the ID of the valid cast vote + ID string + + // voteTime represents the time of the creation of the vote + VoteTime int64 + + // index represents the index of the ballot options + Index interface{} +} diff --git a/be1-go/internal/popserver/types/sockets.go b/be1-go/internal/popserver/types/sockets.go new file mode 100644 index 0000000000..189a09d1e7 --- /dev/null +++ b/be1-go/internal/popserver/types/sockets.go @@ -0,0 +1,59 @@ +package types + +import ( + "popstellar/network/socket" + "sync" +) + +// NewSockets returns a new initialized Sockets +func NewSockets() Sockets { + return Sockets{ + store: make(map[string]socket.Socket), + } +} + +// Sockets provides thread-functionalities around a socket store. +type Sockets struct { + sync.RWMutex + store map[string]socket.Socket +} + +// Len returns the number of Sockets. +func (s *Sockets) Len() int { + return len(s.store) +} + +// SendToAll sends a message to all Sockets. +func (s *Sockets) SendToAll(buf []byte) { + s.RLock() + defer s.RUnlock() + + for _, s := range s.store { + s.Send(buf) + } +} + +// Upsert upserts a socket into the Sockets store. +func (s *Sockets) Upsert(socket socket.Socket) { + s.Lock() + defer s.Unlock() + + s.store[socket.ID()] = socket +} + +// Delete deletes a socket from the store. Returns false +// if the socket is not present in the store and true +// on success. +func (s *Sockets) Delete(ID string) bool { + s.Lock() + defer s.Unlock() + + _, ok := s.store[ID] + if !ok { + return false + } + + delete(s.store, ID) + + return true +} diff --git a/be1-go/internal/popserver/types/subscribers.go b/be1-go/internal/popserver/types/subscribers.go new file mode 100644 index 0000000000..9845ff8a2d --- /dev/null +++ b/be1-go/internal/popserver/types/subscribers.go @@ -0,0 +1,107 @@ +package types + +import ( + "golang.org/x/xerrors" + "popstellar/message/answer" + "popstellar/network/socket" + "sync" +) + +type Subscribers struct { + sync.RWMutex + list map[string]map[string]socket.Socket +} + +func NewSubscribers() *Subscribers { + return &Subscribers{ + list: make(map[string]map[string]socket.Socket), + } +} + +func (s *Subscribers) AddChannel(channel string) *answer.Error { + s.Lock() + defer s.Unlock() + + _, ok := s.list[channel] + if ok { + return answer.NewInvalidActionError("channel %s already exists", channel) + } + + s.list[channel] = make(map[string]socket.Socket) + + return nil +} + +func (s *Subscribers) Subscribe(channel string, socket socket.Socket) *answer.Error { + s.Lock() + defer s.Unlock() + + _, ok := s.list[channel] + if !ok { + return answer.NewInvalidResourceError("cannot Subscribe to unknown channel") + } + + s.list[channel][socket.ID()] = socket + + return nil +} + +func (s *Subscribers) Unsubscribe(channel string, socket socket.Socket) *answer.Error { + s.Lock() + defer s.Unlock() + + _, ok := s.list[channel] + if !ok { + return answer.NewInvalidResourceError("cannot Unsubscribe from unknown channel") + } + + _, ok = s.list[channel][socket.ID()] + if !ok { + return answer.NewInvalidActionError("cannot Unsubscribe from a channel not subscribed") + } + + delete(s.list[channel], socket.ID()) + + return nil +} + +// SendToAll sends a message to all sockets. +func (s *Subscribers) SendToAll(buf []byte, channel string) *answer.Error { + s.RLock() + defer s.RUnlock() + + sockets, ok := s.list[channel] + if !ok { + return answer.NewInvalidResourceError("failed to send to all clients, channel %s not found", channel) + } + for _, v := range sockets { + v.Send(buf) + } + + return nil +} + +func (s *Subscribers) HasChannel(channel string) bool { + s.RLock() + defer s.RUnlock() + + _, ok := s.list[channel] + + return ok +} + +func (s *Subscribers) IsSubscribed(channel string, socket socket.Socket) (bool, error) { + s.RLock() + defer s.RUnlock() + + sockets, ok := s.list[channel] + if !ok { + return false, xerrors.Errorf("channel doesn't exist") + } + _, ok = sockets[socket.ID()] + if !ok { + return false, nil + } + + return true, nil +} diff --git a/be1-go/internal/popserver/utils/utils.go b/be1-go/internal/popserver/utils/utils.go new file mode 100644 index 0000000000..5d7a74ac5b --- /dev/null +++ b/be1-go/internal/popserver/utils/utils.go @@ -0,0 +1,56 @@ +package utils + +import ( + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "popstellar/message/answer" + "popstellar/validation" + "sync" +) + +var once sync.Once +var instance *utils + +type utils struct { + log *zerolog.Logger + schemaValidator *validation.SchemaValidator +} + +func InitUtils(log *zerolog.Logger, schemaValidator *validation.SchemaValidator) { + once.Do(func() { + instance = &utils{ + log: log, + schemaValidator: schemaValidator, + } + }) +} + +func VerifyJSON(msg []byte, st validation.SchemaType) *answer.Error { + if instance == nil || instance.schemaValidator == nil { + return answer.NewInternalServerError("schema validator was not instantiated").Wrap("VerifyJSON") + } + + err := instance.schemaValidator.VerifyJSON(msg, st) + if err != nil { + errAnswer := answer.NewInvalidMessageFieldError("invalid json: %v", err).Wrap("VerifyJSON") + return errAnswer + } + + return nil +} + +func LogInfo(msg string) { + if instance == nil || instance.log == nil { + return + } + + instance.log.Info().Msg(msg) +} + +func LogError(err error) { + if instance == nil || instance.log == nil { + return + } + + log.Error().Msg(err.Error()) +} diff --git a/be1-go/message/answer/error.go b/be1-go/message/answer/error.go index ff365efeb6..668ae8b278 100644 --- a/be1-go/message/answer/error.go +++ b/be1-go/message/answer/error.go @@ -2,6 +2,15 @@ package answer import "fmt" +const ( + InvalidActionErrorCode = -1 + InvalidResourceErrorCode = -2 + DuplicateResourceErrorCode = -3 + InvalidMessageFieldErrorCode = -4 + AccessDeniedErrorCode = -5 + InternalServerErrorCode = -6 +) + // Error defines a JSON RPC error type Error struct { Code int `json:"code"` @@ -13,6 +22,13 @@ func (e *Error) Error() string { return e.Description } +func (e *Error) Wrap(description string) *Error { + return &Error{ + Code: e.Code, + Description: fmt.Sprintf(description+": %v", e.Description), + } +} + // NewError returns a *message.Error func NewError(code int, description string) *Error { return &Error{ @@ -31,35 +47,45 @@ func NewErrorf(code int, format string, values ...interface{}) *Error { // NewInvalidActionError returns an error with the code -1 for an invalid action. func NewInvalidActionError(format string, a ...interface{}) *Error { - return NewErrorf(-1, "invalid action: "+format, a...) -} - -// NewInvalidObjectError returns an error with the code -1 for an invalid object. -func NewInvalidObjectError(format string, a ...interface{}) *Error { - return NewErrorf(-1, "invalid object: "+format, a...) + return NewErrorf(InvalidActionErrorCode, "invalid action: "+format, a...) } // NewInvalidResourceError returns an error with -2 for an object with invalid resources func NewInvalidResourceError(format string, a ...interface{}) *Error { - return NewErrorf(-2, "invalid resource: "+format, a...) + return NewErrorf(InvalidResourceErrorCode, "invalid resource: "+format, a...) } // NewDuplicateResourceError returns an error with -3 for a resource that already exists func NewDuplicateResourceError(format string, a ...interface{}) *Error { - return NewErrorf(-3, "duplicate resource: "+format, a...) + return NewErrorf(DuplicateResourceErrorCode, "duplicate resource: "+format, a...) } // NewInvalidMessageFieldError returns an error with -4 when a message field is bogus func NewInvalidMessageFieldError(format string, a ...interface{}) *Error { - return NewErrorf(-4, "invalid message field: "+format, a...) + return NewErrorf(InvalidMessageFieldErrorCode, "invalid message field: "+format, a...) +} + +// NewJsonUnmarshalError returns an error with -4 when it is impossible to unmarshal a json message +func NewJsonUnmarshalError(format string, a ...interface{}) *Error { + return NewErrorf(InvalidMessageFieldErrorCode, "failed to unmarshal JSON: "+format, a...) } // NewAccessDeniedError returns an error with -5 when an access is denied for the sender func NewAccessDeniedError(format string, a ...interface{}) *Error { - return NewErrorf(-5, "access denied: "+format, a...) + return NewErrorf(AccessDeniedErrorCode, "access denied: "+format, a...) } // NewInternalServerError returns an error with -6 when there is an internal server error func NewInternalServerError(format string, a ...interface{}) *Error { - return NewErrorf(-6, "internal server error: "+format, a...) + return NewErrorf(InternalServerErrorCode, "internal server error: "+format, a...) +} + +// NewQueryDatabaseError returns an error with -6 when there is an error with a database query +func NewQueryDatabaseError(format string, a ...interface{}) *Error { + return NewErrorf(InternalServerErrorCode, "failed to query from database: "+format, a...) +} + +// NewStoreDatabaseError returns an error with -6 when there is an error with a database store +func NewStoreDatabaseError(format string, a ...interface{}) *Error { + return NewErrorf(InternalServerErrorCode, "failed to store inside database: "+format, a...) } diff --git a/be1-go/message/messagedata/election_end.go b/be1-go/message/messagedata/election_end.go index 7d793eb789..54fcf3b786 100644 --- a/be1-go/message/messagedata/election_end.go +++ b/be1-go/message/messagedata/election_end.go @@ -1,5 +1,11 @@ package messagedata +import ( + "encoding/base64" + "popstellar/message/answer" + "strings" +) + // ElectionEnd defines a message data type ElectionEnd struct { Object string `json:"object"` @@ -27,3 +33,54 @@ func (ElectionEnd) GetAction() string { func (ElectionEnd) NewEmpty() MessageData { return &ElectionEnd{} } + +func (message ElectionEnd) Verify(electionPath string) *answer.Error { + var errAnswer *answer.Error + + _, err := base64.URLEncoding.DecodeString(message.Lao) + if err != nil { + errAnswer = answer.NewInvalidMessageFieldError("failed to decode lao: %v", err) + return errAnswer + } + + _, err = base64.URLEncoding.DecodeString(message.Election) + if err != nil { + errAnswer = answer.NewInvalidMessageFieldError("failed to decode election: %v", err) + return errAnswer + } + + noRoot := strings.ReplaceAll(electionPath, RootPrefix, "") + IDs := strings.Split(noRoot, "/") + if len(IDs) != 2 { + errAnswer = answer.NewInvalidMessageFieldError("failed to split channel: %v", message) + return errAnswer + } + laoID := IDs[0] + electionID := IDs[1] + + // verify if lao id is the same as the channel + if message.Lao != laoID { + errAnswer = answer.NewInvalidMessageFieldError("lao id is not the same as the channel") + return errAnswer + } + + // verify if election id is the same as the channel + if message.Election != electionID { + errAnswer = answer.NewInvalidMessageFieldError("election id is not the same as the channel") + return errAnswer + } + + // verify message created at is positive + if message.CreatedAt < 0 { + errAnswer = answer.NewInvalidMessageFieldError("message created at is negative") + return errAnswer + } + + // verify registered votes are base64URL encoded + if _, err := base64.URLEncoding.DecodeString(message.RegisteredVotes); err != nil { + errAnswer = answer.NewInvalidMessageFieldError("registered votes are not base64 encoded") + return errAnswer + } + + return nil +} diff --git a/be1-go/message/messagedata/election_open.go b/be1-go/message/messagedata/election_open.go index 78b25d8b9f..de3874586c 100644 --- a/be1-go/message/messagedata/election_open.go +++ b/be1-go/message/messagedata/election_open.go @@ -1,5 +1,11 @@ package messagedata +import ( + "encoding/base64" + "popstellar/message/answer" + "strings" +) + // ElectionOpen defines a message data type ElectionOpen struct { Object string `json:"object"` @@ -25,3 +31,49 @@ func (ElectionOpen) GetAction() string { func (ElectionOpen) NewEmpty() MessageData { return &ElectionOpen{} } + +func (message ElectionOpen) Verify(electionPath string) *answer.Error { + var errAnswer *answer.Error + _, err := base64.URLEncoding.DecodeString(message.Lao) + if err != nil { + errAnswer = answer.NewInvalidMessageFieldError("failed to decode lao: %v", err) + return errAnswer + } + + _, err = base64.URLEncoding.DecodeString(message.Election) + if err != nil { + errAnswer = answer.NewInvalidMessageFieldError("failed to decode election: %v", err) + return errAnswer + } + noRoot := strings.ReplaceAll(electionPath, RootPrefix, "") + + IDs := strings.Split(noRoot, "/") + if len(IDs) != 2 { + errAnswer = answer.NewInvalidMessageFieldError("failed to split channel: %v", electionPath) + return errAnswer + } + laoID := IDs[0] + electionID := IDs[1] + + // verify if lao id is the same as the channel + if message.Lao != laoID { + errAnswer = answer.NewInvalidMessageFieldError("lao id is not the same as the channel") + errAnswer = errAnswer.Wrap("handleElectionOpen") + return errAnswer + } + + // verify if election id is the same as the channel + if message.Election != electionID { + errAnswer = answer.NewInvalidMessageFieldError("election id is not the same as the channel") + errAnswer = errAnswer.Wrap("handleElectionOpen") + return errAnswer + } + + // verify opened at is positive + if message.OpenedAt < 0 { + errAnswer = answer.NewInvalidMessageFieldError("opened at is negative") + errAnswer = errAnswer.Wrap("handleElectionOpen") + return errAnswer + } + return nil +} diff --git a/be1-go/message/messagedata/election_setup.go b/be1-go/message/messagedata/election_setup.go index d8c5cd6515..e81fc1ef23 100644 --- a/be1-go/message/messagedata/election_setup.go +++ b/be1-go/message/messagedata/election_setup.go @@ -1,5 +1,11 @@ package messagedata +import ( + "encoding/base64" + "popstellar/message/answer" + "strconv" +) + // ElectionSetup defines a message data type ElectionSetup struct { Object string `json:"object"` @@ -21,20 +27,63 @@ type ElectionSetup struct { Questions []ElectionSetupQuestion `json:"questions"` } -const ( - // OpenBallot is a type of election - OpenBallot = "OPEN_BALLOT" - // SecretBallot is a type of election - SecretBallot = "SECRET_BALLOT" -) +const ElectionFlag = "Election" -// ElectionSetupQuestion defines a question of an election setup -type ElectionSetupQuestion struct { - ID string `json:"id"` - Question string `json:"question"` - VotingMethod string `json:"voting_method"` - BallotOptions []string `json:"ballot_options"` - WriteIn bool `json:"write_in"` +func (message ElectionSetup) Verify(laoID string) *answer.Error { + var errAnswer *answer.Error + _, err := base64.URLEncoding.DecodeString(message.Lao) + if err != nil { + errAnswer = answer.NewInvalidMessageFieldError("failed to decode lao: %v", err) + return errAnswer + } + + if message.Lao != laoID { + errAnswer = answer.NewInvalidMessageFieldError("lao id is %s, should be %s", message.Lao, laoID) + return errAnswer + } + + _, err = base64.URLEncoding.DecodeString(message.ID) + if err != nil { + errAnswer = answer.NewInvalidMessageFieldError("failed to decode election id: %v", err) + return errAnswer + } + + // verify election setup message id + expectedID := Hash( + ElectionFlag, + laoID, + strconv.Itoa(int(message.CreatedAt)), + message.Name, + ) + if message.ID != expectedID { + errAnswer = answer.NewInvalidMessageFieldError("election id is %s, should be %s", message.ID, expectedID) + return errAnswer + } + if len(message.Name) == 0 { + errAnswer = answer.NewInvalidMessageFieldError("election name is empty") + return errAnswer + } + if message.Version != OpenBallot && message.Version != SecretBallot { + errAnswer = answer.NewInvalidMessageFieldError("election version is %s, should be %s or %s", message.Version, OpenBallot, SecretBallot) + return errAnswer + } + if message.CreatedAt < 0 { + errAnswer = answer.NewInvalidMessageFieldError("election created at is %d, should be minimum 0", message.CreatedAt) + return errAnswer + } + if message.StartTime < message.CreatedAt { + errAnswer = answer.NewInvalidMessageFieldError("election start should be greater that creation time") + return errAnswer + } + if message.EndTime < message.StartTime { + errAnswer = answer.NewInvalidMessageFieldError("election end should be greater that start time") + return errAnswer + } + if len(message.Questions) == 0 { + errAnswer = answer.NewInvalidMessageFieldError("election contains no questions") + return errAnswer + } + return nil } // GetObject implements MessageData @@ -51,3 +100,50 @@ func (ElectionSetup) GetAction() string { func (ElectionSetup) NewEmpty() MessageData { return &ElectionSetup{} } + +const ( + // OpenBallot is a type of election + OpenBallot = "OPEN_BALLOT" + // SecretBallot is a type of election + SecretBallot = "SECRET_BALLOT" + questionFlag = "Question" + PluralityMethod = "Plurality" + ApprovalMethod = "Approval" +) + +// ElectionSetupQuestion defines a question of an election setup +type ElectionSetupQuestion struct { + ID string `json:"id"` + Question string `json:"question"` + VotingMethod string `json:"voting_method"` + BallotOptions []string `json:"ballot_options"` + WriteIn bool `json:"write_in"` +} + +func (q ElectionSetupQuestion) Verify(electionSetupID string) *answer.Error { + var errAnswer *answer.Error + _, err := base64.URLEncoding.DecodeString(q.ID) + if err != nil { + errAnswer = answer.NewInvalidMessageFieldError("failed to decode Question id: %v", err) + return errAnswer + } + expectedID := Hash( + questionFlag, + electionSetupID, + q.Question, + ) + if q.ID != expectedID { + errAnswer = answer.NewInvalidMessageFieldError("Question id is %s, should be %s", q.ID, expectedID) + return errAnswer + } + if len(q.Question) == 0 { + errAnswer = answer.NewInvalidMessageFieldError("Question is empty") + return errAnswer + } + if q.VotingMethod != PluralityMethod && q.VotingMethod != ApprovalMethod { + errAnswer = answer.NewInvalidMessageFieldError("Question voting method is %s, should be %s or %s", + q.VotingMethod, PluralityMethod, ApprovalMethod) + return errAnswer + } + return nil +} diff --git a/be1-go/message/messagedata/roll_call_close.go b/be1-go/message/messagedata/roll_call_close.go index 54000ec84f..42fcc5d451 100644 --- a/be1-go/message/messagedata/roll_call_close.go +++ b/be1-go/message/messagedata/roll_call_close.go @@ -1,5 +1,12 @@ package messagedata +import ( + "encoding/base64" + "popstellar/message/answer" + "strconv" + "strings" +) + // RollCallClose defines a message data type RollCallClose struct { Object string `json:"object"` @@ -14,6 +21,38 @@ type RollCallClose struct { Attendees []string `json:"attendees"` } +func (message RollCallClose) Verify(laoPath string) *answer.Error { + var errAnswer *answer.Error + _, err := base64.URLEncoding.DecodeString(message.UpdateID) + if err != nil { + errAnswer = answer.NewInvalidMessageFieldError("failed to decode roll call update ID: %v", err) + return errAnswer + } + + expectedID := Hash( + RollCallFlag, + strings.ReplaceAll(laoPath, RootPrefix, ""), + message.Closes, + strconv.Itoa(int(message.ClosedAt)), + ) + if message.UpdateID != expectedID { + errAnswer = answer.NewInvalidMessageFieldError("roll call update id is %s, should be %s", message.UpdateID, expectedID) + return errAnswer + } + + _, err = base64.URLEncoding.DecodeString(message.Closes) + if err != nil { + errAnswer = answer.NewInvalidMessageFieldError("failed to decode roll call closes: %v", err) + return errAnswer + } + + if message.ClosedAt < 0 { + errAnswer = answer.NewInvalidMessageFieldError("roll call closed at is %d, should be minimum 0", message.ClosedAt) + return errAnswer + } + return nil +} + // GetObject implements MessageData func (RollCallClose) GetObject() string { return RollCallObject diff --git a/be1-go/message/messagedata/roll_call_create.go b/be1-go/message/messagedata/roll_call_create.go index a38858555d..d02947f140 100644 --- a/be1-go/message/messagedata/roll_call_create.go +++ b/be1-go/message/messagedata/roll_call_create.go @@ -1,5 +1,12 @@ package messagedata +import ( + "encoding/base64" + "popstellar/message/answer" + "strconv" + "strings" +) + // RollCallCreate defines a message data type RollCallCreate struct { Object string `json:"object"` @@ -20,6 +27,54 @@ type RollCallCreate struct { Description string `json:"description"` } +const RollCallFlag = "R" + +func (message RollCallCreate) Verify(laoPath string) *answer.Error { + var errAnswer *answer.Error + // verify id is base64URL encoded + _, err := base64.URLEncoding.DecodeString(message.ID) + if err != nil { + errAnswer = answer.NewInvalidMessageFieldError("failed to decode roll call ID: %v", err) + errAnswer = errAnswer.Wrap("handleRollCallCreate") + return errAnswer + } + + // verify roll call create message id + expectedID := Hash( + RollCallFlag, + strings.ReplaceAll(laoPath, RootPrefix, ""), + strconv.Itoa(int(message.Creation)), + message.Name, + ) + if message.ID != expectedID { + errAnswer = answer.NewInvalidMessageFieldError("roll call id is %s, should be %s", message.ID, expectedID) + errAnswer = errAnswer.Wrap("handleRollCallCreate") + return errAnswer + } + + // verify creation is positive + if message.Creation < 0 { + errAnswer = answer.NewInvalidMessageFieldError("roll call creation is %d, should be minimum 0", message.Creation) + errAnswer = errAnswer.Wrap("handleRollCallCreate") + return errAnswer + } + + // verify proposed start after creation + if message.ProposedStart < message.Creation { + errAnswer = answer.NewInvalidMessageFieldError("roll call proposed start time should be greater than creation time") + errAnswer = errAnswer.Wrap("handleRollCallCreate") + return errAnswer + } + + // verify proposed end after proposed start + if message.ProposedEnd < message.ProposedStart { + errAnswer = answer.NewInvalidMessageFieldError("roll call proposed end should be greater than proposed start") + errAnswer = errAnswer.Wrap("handleRollCallCreate") + return errAnswer + } + return nil +} + // GetObject implements MessageData func (RollCallCreate) GetObject() string { return RollCallObject diff --git a/be1-go/message/messagedata/roll_call_open.go b/be1-go/message/messagedata/roll_call_open.go index 42ec37aabf..2fc620c8f0 100644 --- a/be1-go/message/messagedata/roll_call_open.go +++ b/be1-go/message/messagedata/roll_call_open.go @@ -1,5 +1,12 @@ package messagedata +import ( + "encoding/base64" + "popstellar/message/answer" + "strconv" + "strings" +) + // RollCallOpen defines a message data type RollCallOpen struct { Object string `json:"object"` @@ -11,6 +18,37 @@ type RollCallOpen struct { OpenedAt int64 `json:"opened_at"` } +func (message RollCallOpen) Verify(laoPath string) *answer.Error { + var errAnswer *answer.Error + _, err := base64.URLEncoding.DecodeString(message.UpdateID) + if err != nil { + errAnswer = answer.NewInvalidMessageFieldError("failed to decode roll call update ID: %v", err) + return errAnswer + } + expectedID := Hash( + RollCallFlag, + strings.ReplaceAll(laoPath, RootPrefix, ""), + message.Opens, + strconv.Itoa(int(message.OpenedAt)), + ) + if message.UpdateID != expectedID { + errAnswer = answer.NewInvalidMessageFieldError("roll call update id is %s, should be %s", message.UpdateID, expectedID) + return errAnswer + } + + _, err = base64.URLEncoding.DecodeString(message.Opens) + if err != nil { + errAnswer = answer.NewInvalidMessageFieldError("failed to decode roll call opens: %v", err) + return errAnswer + } + + if message.OpenedAt < 0 { + errAnswer = answer.NewInvalidMessageFieldError("roll call opened at is %d, should be minimum 0", message.OpenedAt) + return errAnswer + } + return nil +} + // GetObject implements MessageData func (RollCallOpen) GetObject() string { return RollCallObject diff --git a/be1-go/message/messagedata/vote_cast_vote.go b/be1-go/message/messagedata/vote_cast_vote.go index 1090823279..8b8df771bc 100644 --- a/be1-go/message/messagedata/vote_cast_vote.go +++ b/be1-go/message/messagedata/vote_cast_vote.go @@ -1,8 +1,10 @@ package messagedata import ( + "encoding/base64" "encoding/json" "popstellar/message/answer" + "strings" ) // VoteCastVote defines a message data @@ -60,6 +62,42 @@ func (v *Vote) UnmarshalJSON(b []byte) error { return nil } +func (message VoteCastVote) Verify(electionPath string) *answer.Error { + var errAnswer *answer.Error + // verify lao id is base64URL encoded + _, err := base64.URLEncoding.DecodeString(message.Lao) + if err != nil { + errAnswer = answer.NewInvalidMessageFieldError("failed to decode lao: %v", err) + return errAnswer + } + // verify election id is base64URL encoded + _, err = base64.URLEncoding.DecodeString(message.Election) + if err != nil { + errAnswer = answer.NewInvalidMessageFieldError("failed to decode election: %v", err) + return errAnswer + } + // split channel to [lao id, election id] + noRoot := strings.ReplaceAll(electionPath, RootPrefix, "") + IDs := strings.Split(noRoot, "/") + if len(IDs) != 2 { + errAnswer = answer.NewInvalidMessageFieldError("failed to split channel: %v", electionPath) + return errAnswer + } + laoID := IDs[0] + electionID := IDs[1] + // verify if lao id is the same as the channel + if message.Lao != laoID { + errAnswer = answer.NewInvalidMessageFieldError("lao id is not the same as the channel") + return errAnswer + } + // verify if election id is the same as the channel + if message.Election != electionID { + errAnswer = answer.NewInvalidMessageFieldError("election id is not the same as the channel") + return errAnswer + } + return nil +} + // GetObject implements MessageData func (VoteCastVote) GetObject() string { return ElectionObject diff --git a/be1-go/message/query/method/broadcast.go b/be1-go/message/query/method/broadcast.go index 7abc77b3f4..75b3dd917c 100644 --- a/be1-go/message/query/method/broadcast.go +++ b/be1-go/message/query/method/broadcast.go @@ -9,8 +9,10 @@ import ( type Broadcast struct { query.Base - Params struct { - Channel string `json:"channel"` - Message message.Message `json:"message"` - } `json:"params"` + Params BroadcastParams `json:"params"` +} + +type BroadcastParams struct { + Channel string `json:"channel"` + Message message.Message `json:"message"` } diff --git a/be1-go/message/query/method/catchup.go b/be1-go/message/query/method/catchup.go index ac7ef40020..008bc8282f 100644 --- a/be1-go/message/query/method/catchup.go +++ b/be1-go/message/query/method/catchup.go @@ -8,7 +8,9 @@ type Catchup struct { ID int `json:"id"` - Params struct { - Channel string `json:"channel"` - } `json:"params"` + Params CatchupParams `json:"params"` +} + +type CatchupParams struct { + Channel string `json:"channel"` } diff --git a/be1-go/message/query/method/greet_server.go b/be1-go/message/query/method/greet_server.go index 1803f5f8df..7fdcab7434 100644 --- a/be1-go/message/query/method/greet_server.go +++ b/be1-go/message/query/method/greet_server.go @@ -2,7 +2,7 @@ package method import "popstellar/message/query" -type ServerInfo struct { +type GreetServerParams struct { PublicKey string `json:"public_key"` ServerAddress string `json:"server_address"` ClientAddress string `json:"client_address"` @@ -11,5 +11,5 @@ type ServerInfo struct { // GreetServer defines a JSON RPC greetServer message type GreetServer struct { query.Base - Params ServerInfo `json:"params"` + Params GreetServerParams `json:"params"` } diff --git a/be1-go/message/query/method/message/message.go b/be1-go/message/query/method/message/message.go index 6e9941ecee..f283a18798 100644 --- a/be1-go/message/query/method/message/message.go +++ b/be1-go/message/query/method/message/message.go @@ -33,6 +33,24 @@ func (m Message) UnmarshalData(e interface{}) error { return nil } +// UnmarshalMsgData fills the provided elements with the message data stored in the +// data field. Recall that the Data field contains a base64URL representation of +// a message data, it takes care of properly decoding it. The provided element +// 'e' MUST be a pointer. +func (m Message) UnmarshalMsgData(e interface{}) *answer.Error { + jsonData, err := base64.URLEncoding.DecodeString(m.Data) + if err != nil { + return answer.NewInvalidMessageFieldError("failed to decode base64: %v", err) + } + + err = json.Unmarshal(jsonData, e) + if err != nil { + return answer.NewInvalidMessageFieldError("failed to unmarshal jsonData: %v", err) + } + + return nil +} + // WitnessSignature defines a witness signature in a message type WitnessSignature struct { Witness string `json:"witness"` diff --git a/be1-go/message/query/method/publish.go b/be1-go/message/query/method/publish.go index 4c7b2428b7..0a0651ede7 100644 --- a/be1-go/message/query/method/publish.go +++ b/be1-go/message/query/method/publish.go @@ -11,8 +11,10 @@ type Publish struct { ID int `json:"id"` - Params struct { - Channel string `json:"channel"` - Message message.Message `json:"message"` - } `json:"params"` + Params PublishParams `json:"params"` +} + +type PublishParams struct { + Channel string `json:"channel"` + Message message.Message `json:"message"` } diff --git a/be1-go/message/query/method/subscribe.go b/be1-go/message/query/method/subscribe.go index c42ed07d6d..e4e4ad0d3d 100644 --- a/be1-go/message/query/method/subscribe.go +++ b/be1-go/message/query/method/subscribe.go @@ -8,7 +8,9 @@ type Subscribe struct { ID int `json:"id"` - Params struct { - Channel string `json:"channel"` - } `json:"params"` + Params SubscribeParams `json:"params"` +} + +type SubscribeParams struct { + Channel string `json:"channel"` } diff --git a/be1-go/message/query/method/unsubscribe.go b/be1-go/message/query/method/unsubscribe.go index 35ce37793b..0c74d00e96 100644 --- a/be1-go/message/query/method/unsubscribe.go +++ b/be1-go/message/query/method/unsubscribe.go @@ -8,7 +8,9 @@ type Unsubscribe struct { ID int `json:"id"` - Params struct { - Channel string `json:"channel"` - } `json:"params"` + Params UnsubscribeParams `json:"params"` +} + +type UnsubscribeParams struct { + Channel string `json:"channel"` } diff --git a/be1-go/message/test/answer/answer_test.go b/be1-go/message/test/answer/answer_test.go index 480ab7961c..0114af314f 100644 --- a/be1-go/message/test/answer/answer_test.go +++ b/be1-go/message/test/answer/answer_test.go @@ -51,10 +51,6 @@ func Test_Error_functions(t *testing.T) { require.Equal(t, -1, invalidAction.Code) require.Equal(t, "invalid action: "+formatString, invalidAction.Description) - invalidObject := answer.NewInvalidObjectError(formatString) - require.Equal(t, -1, invalidObject.Code) - require.Equal(t, "invalid object: "+formatString, invalidObject.Description) - invalidResource := answer.NewInvalidResourceError(formatString) require.Equal(t, -2, invalidResource.Code) require.Equal(t, "invalid resource: "+formatString, invalidResource.Description) diff --git a/be1-go/message/test/answer/error_test.go b/be1-go/message/test/answer/error_test.go index a2bdad0492..ca861dda9b 100644 --- a/be1-go/message/test/answer/error_test.go +++ b/be1-go/message/test/answer/error_test.go @@ -11,8 +11,4 @@ func Test_Error_Constructor(t *testing.T) { err := answer.NewInvalidActionError("@@@") require.Equal(t, -1, err.Code) require.Equal(t, "invalid action: @@@", err.Description) - - err = answer.NewInvalidObjectError("@@@") - require.Equal(t, -1, err.Code) - require.Equal(t, "invalid object: @@@", err.Description) } diff --git a/be1-go/network/socket/fake_socket.go b/be1-go/network/socket/fake_socket.go new file mode 100644 index 0000000000..f8dc5d4cef --- /dev/null +++ b/be1-go/network/socket/fake_socket.go @@ -0,0 +1,49 @@ +package socket + +import "popstellar/message/query/method/message" + +// FakeSocket is a fake implementation of a Socket +// +// - implements socket.Socket +type FakeSocket struct { + Socket + + ResultID int + Res []message.Message + MissingMsgs map[string][]message.Message + Msg []byte + + Err error + + // the Socket ID + Id string +} + +// Send implements socket.Socket +func (f *FakeSocket) Send(msg []byte) { + f.Msg = msg +} + +// SendResult implements socket.Socket +func (f *FakeSocket) SendResult(id int, res []message.Message, missingMsgs map[string][]message.Message) { + f.ResultID = id + f.Res = res + f.MissingMsgs = missingMsgs +} + +// SendError implements socket.Socket +func (f *FakeSocket) SendError(id *int, err error) { + f.Err = err +} + +func (f *FakeSocket) ID() string { + return f.Id +} + +func (f *FakeSocket) GetMessage() []byte { + return f.Msg +} + +func (f *FakeSocket) Type() SocketType { + return ClientSocketType +} diff --git a/be1-go/network/socket/socket_impl.go b/be1-go/network/socket/socket_impl.go index ef49801f45..ace6b8460c 100644 --- a/be1-go/network/socket/socket_impl.go +++ b/be1-go/network/socket/socket_impl.go @@ -99,8 +99,7 @@ func (s *baseSocket) ReadPump() { select { case <-s.done: return - default: - s.receiver <- msg + case s.receiver <- msg: } } } diff --git a/be1-go/sonar-project.properties b/be1-go/sonar-project.properties index 17ebcc64e6..0f9d4e6e5b 100644 --- a/be1-go/sonar-project.properties +++ b/be1-go/sonar-project.properties @@ -7,7 +7,7 @@ sonar.go.coverage.reportPaths=./coverage.out # Path patterns of the source files sonar.sources=. -sonar.exclusions=**/*_test.go,**/test/** +sonar.exclusions=**/*_test.go,**/test/**, **/generator/**, **/mock** # Path patterns of the test files sonar.tests=. diff --git a/be1-go/validation/schema_validator.go b/be1-go/validation/schema_validator.go index 4919fd3378..73978aa5f5 100644 --- a/be1-go/validation/schema_validator.go +++ b/be1-go/validation/schema_validator.go @@ -10,7 +10,6 @@ import ( "popstellar/message/answer" "strings" - "github.com/rs/zerolog" "github.com/santhosh-tekuri/jsonschema/v3" "golang.org/x/xerrors" ) @@ -19,7 +18,6 @@ import ( type SchemaValidator struct { genericMessageSchema *jsonschema.Schema dataSchema *jsonschema.Schema - log zerolog.Logger } // SchemaType denotes the type of schema. @@ -58,8 +56,6 @@ func (s SchemaValidator) VerifyJSON(msg []byte, st SchemaType) error { reader := bytes.NewBuffer(msg[:]) var schema *jsonschema.Schema - s.log.Info().Msg("verifying msg follows the schema") - switch st { case GenericMessage: schema = s.genericMessageSchema @@ -71,7 +67,6 @@ func (s SchemaValidator) VerifyJSON(msg []byte, st SchemaType) error { err := schema.Validate(reader) if err != nil { - s.log.Err(err).Msg("failed to validate schema") return answer.NewErrorf(-4, "failed to validate schema: %v", err) } @@ -79,10 +74,9 @@ func (s SchemaValidator) VerifyJSON(msg []byte, st SchemaType) error { } // NewSchemaValidator returns a Schema Validator -func NewSchemaValidator(log zerolog.Logger) (*SchemaValidator, error) { +func NewSchemaValidator() (*SchemaValidator, error) { gmCompiler := jsonschema.NewCompiler() dataCompiler := jsonschema.NewCompiler() - log = log.With().Str("role", "base hub").Logger() // recurse over the protocol directory and load all the files err := fs.WalkDir(protocolFS, "protocol", func(path string, d fs.DirEntry, err error) error { @@ -129,7 +123,6 @@ func NewSchemaValidator(log zerolog.Logger) (*SchemaValidator, error) { return &SchemaValidator{ genericMessageSchema: gmSchema, dataSchema: dataSchema, - log: log, }, nil } diff --git a/be1-go/validation/schema_validator_test.go b/be1-go/validation/schema_validator_test.go index e93c71453d..069b0b8690 100644 --- a/be1-go/validation/schema_validator_test.go +++ b/be1-go/validation/schema_validator_test.go @@ -2,16 +2,13 @@ package validation import ( "encoding/base64" - "io" "testing" - "github.com/rs/zerolog" - "github.com/stretchr/testify/require" ) func TestSchemaValidator_New(t *testing.T) { - validator, err := NewSchemaValidator(nolog) + validator, err := NewSchemaValidator() require.NoError(t, err) require.NotNil(t, validator.genericMessageSchema) require.NotNil(t, validator.dataSchema) @@ -20,7 +17,7 @@ func TestSchemaValidator_New(t *testing.T) { func TestSchemaValidator_ValidateResponse(t *testing.T) { response := `{"jsonrpc":"2.0","result":[{"message_id":"8ADlKroQD5VNOPGJ5aYfSQXapwfRp1gxvU5oK85jPUs=","data":"eyJvYmplY3QiOiJsYW8iLCJhY3Rpb24iOiJjcmVhdGUiLCJuYW1lIjoiV2ViIFZvdGluZyBUZXN0IiwiY3JlYXRpb24iOjE2MjMzNDUwNzEsIm9yZ2FuaXplciI6IkhoQVhDQ190TlBOQnk3WUgxWjRkRjl5Qk42TmExd01KQUNtYlZhMngzZGM9Iiwid2l0bmVzc2VzIjpbXSwiaWQiOiI0ZWlzc1M0Vk5fQm1JeXZLUW9TSkFQZjF0a2hTcFF1U2dCdnU3Tzc2QWFBPSJ9","sender":"HhAXCC_tNPNBy7YH1Z4dF9yBN6Na1wMJACmbVa2x3dc=","signature":"j-bpWF-4eB0WGSbxUgQSFVcP4BRXG2AvfndjY4RbCN7DWPlCEfunVraPAg_4qpOWJs8FODZZQOai-w_YPMHWBg==","witness_signatures":[]}],"id":40}` - validator, err := NewSchemaValidator(nolog) + validator, err := NewSchemaValidator() require.NoError(t, err) err = validator.VerifyJSON([]byte(response), GenericMessage) @@ -32,7 +29,7 @@ func TestSchemaValidator_ValidateDataLAOCreate(t *testing.T) { dataBuf, err := base64.URLEncoding.DecodeString(dataEncoded) require.NoError(t, err) - validator, err := NewSchemaValidator(nolog) + validator, err := NewSchemaValidator() require.NoError(t, err) err = validator.VerifyJSON(dataBuf, Data) @@ -42,7 +39,7 @@ func TestSchemaValidator_ValidateDataLAOCreate(t *testing.T) { func TestSchemaValidator_ValidateCatchupRequest(t *testing.T) { request := `{"jsonrpc":"2.0","method":"catchup","params":{"channel":"/root/kHQbFsB2Q_JxH55pJMfKNV3Mje5hPHjI7AZ4HIOlp40="},"id":39}` - validator, err := NewSchemaValidator(nolog) + validator, err := NewSchemaValidator() require.NoError(t, err) err = validator.VerifyJSON([]byte(request), GenericMessage) @@ -52,7 +49,7 @@ func TestSchemaValidator_ValidateCatchupRequest(t *testing.T) { func TestSchemaValidator_ValidateCatchupRequestBadChannel(t *testing.T) { request := `{"jsonrpc":"2.0","method":"catchup","params":{"channel":"kHQbFsB2Q_JxH55pJMfKNV3Mje5hPHjI7AZ4HIOlp40="},"id":39}` - validator, err := NewSchemaValidator(nolog) + validator, err := NewSchemaValidator() require.NoError(t, err) err = validator.VerifyJSON([]byte(request), GenericMessage) @@ -62,7 +59,7 @@ func TestSchemaValidator_ValidateCatchupRequestBadChannel(t *testing.T) { func TestSchemaValidator_ValidateAnswer(t *testing.T) { response := `{"jsonrpc":"2.0","result":0,"id":38}` - validator, err := NewSchemaValidator(nolog) + validator, err := NewSchemaValidator() require.NoError(t, err) err = validator.VerifyJSON([]byte(response), GenericMessage) @@ -72,7 +69,7 @@ func TestSchemaValidator_ValidateAnswer(t *testing.T) { func TestSchemaValidator_ValidatePublish(t *testing.T) { request := `{"method":"publish","id":11,"params":{"channel":"/root/4eissS4VN_BmIyvKQoSJAPf1tkhSpQuSgBvu7O76AaA=/cqAJNbhYsUcgqbqQKiDyCnlAcWKgeG1z-pz1acLr134=","message":{"data":"eyJvYmplY3QiOiJlbGVjdGlvbiIsImFjdGlvbiI6ImNhc3Rfdm90ZSIsImxhbyI6IjRlaXNzUzRWTl9CbUl5dktRb1NKQVBmMXRraFNwUXVTZ0J2dTdPNzZBYUE9IiwiY3JlYXRlZF9hdCI6MTYyMzM0NTY2MCwidm90ZXMiOlt7ImlkIjoieUIybEtHR1lSY0Robm4wWVdhOGJ0MlhrU0FHazBXNVNZQ3dBWHdWeVJIdz0iLCJxdWVzdGlvbiI6IlhPSmVDdXRsNzlJU1RFRkhWSVhTdUJmRWh6czF2V3lJYlZ3NFJ0U3FYSlk9Iiwidm90ZSI6WzFdfV0sImVsZWN0aW9uIjoiY3FBSk5iaFlzVWNncWJxUUtpRHlDbmxBY1dLZ2VHMXotcHoxYWNMcjEzND0ifQ==","sender":"Wto5aKBnfU0fIX2x1c_KB_-fVaW5COfOu-jLWkOIaWE=","signature":"Dy2EfE55nj9z4-d7xTqZV31pYRpwf2m4Rnleq7wTNvddWp1BbDEJnpg5uYfMt7qqHkSw3cZJKHxAnTD0quk2DQ==","message_id":"fc9ZXkNDjfhAc51PaLyBaBN-LOMA5nNx9sr9Xbvl2Ng=","witness_signatures":[]}},"jsonrpc":"2.0"}` - validator, err := NewSchemaValidator(nolog) + validator, err := NewSchemaValidator() require.NoError(t, err) err = validator.VerifyJSON([]byte(request), GenericMessage) @@ -82,7 +79,7 @@ func TestSchemaValidator_ValidatePublish(t *testing.T) { func TestSchemaValidator_ValidatePublishBadMethod(t *testing.T) { request := `{"method":"publish_foo","id":11,"params":{"channel":"/root/4eissS4VN_BmIyvKQoSJAPf1tkhSpQuSgBvu7O76AaA=/cqAJNbhYsUcgqbqQKiDyCnlAcWKgeG1z-pz1acLr134=","message":{"data":"eyJvYmplY3QiOiJlbGVjdGlvbiIsImFjdGlvbiI6ImNhc3Rfdm90ZSIsImxhbyI6IjRlaXNzUzRWTl9CbUl5dktRb1NKQVBmMXRraFNwUXVTZ0J2dTdPNzZBYUE9IiwiY3JlYXRlZF9hdCI6MTYyMzM0NTY2MCwidm90ZXMiOlt7ImlkIjoieUIybEtHR1lSY0Robm4wWVdhOGJ0MlhrU0FHazBXNVNZQ3dBWHdWeVJIdz0iLCJxdWVzdGlvbiI6IlhPSmVDdXRsNzlJU1RFRkhWSVhTdUJmRWh6czF2V3lJYlZ3NFJ0U3FYSlk9Iiwidm90ZSI6WzFdfV0sImVsZWN0aW9uIjoiY3FBSk5iaFlzVWNncWJxUUtpRHlDbmxBY1dLZ2VHMXotcHoxYWNMcjEzND0ifQ==","sender":"Wto5aKBnfU0fIX2x1c_KB_-fVaW5COfOu-jLWkOIaWE=","signature":"Dy2EfE55nj9z4-d7xTqZV31pYRpwf2m4Rnleq7wTNvddWp1BbDEJnpg5uYfMt7qqHkSw3cZJKHxAnTD0quk2DQ==","message_id":"fc9ZXkNDjfhAc51PaLyBaBN-LOMA5nNx9sr9Xbvl2Ng=","witness_signatures":[]}},"jsonrpc":"2.0"}` - validator, err := NewSchemaValidator(nolog) + validator, err := NewSchemaValidator() require.NoError(t, err) err = validator.VerifyJSON([]byte(request), GenericMessage) @@ -94,7 +91,7 @@ func TestSchemaValidator_ValidateCastVoteData(t *testing.T) { dataBuf, err := base64.URLEncoding.DecodeString(request) require.NoError(t, err) - validator, err := NewSchemaValidator(nolog) + validator, err := NewSchemaValidator() require.NoError(t, err) err = validator.VerifyJSON(dataBuf, Data) @@ -104,7 +101,7 @@ func TestSchemaValidator_ValidateCastVoteData(t *testing.T) { func TestSchemaValidator_ValidatePublishChirp(t *testing.T) { request := `{"jsonrpc":"2.0","method":"publish","id":11,"params":{"channel":"/root/krnBHWK2LtM_iQw20D_jJPObQ-NzmOHTzCGvBt7kq58=","message":{"data":"eyJvYmplY3QiOiJjaGlycCIsImFjdGlvbiI6ImFkZCIsInRleHQiOiJjaGlycCBjaGlycCIsInRpbWVzdGFtcCI6MTYzMzgxMDU1OX0=","message_id":"FaRCz2GE9ZQ_qjG171i04rhXkgB86VQ_EbWFMEN1lr0=","sender":"ljvhZLAzFaC7U8_dF9QU253DsMLC7JjPHYRi3wz-u9s=","signature":"ObqRiV4wlwAnB668FyNI9cVnxVNRpfMBHz2UhIVSw_VBxgMty33AyHkDdFs46l_5umccD3jFOIwBZpp96QY9CA==","witness_signatures":[]}}}` - validator, err := NewSchemaValidator(nolog) + validator, err := NewSchemaValidator() require.NoError(t, err) err = validator.VerifyJSON([]byte(request), GenericMessage) @@ -113,5 +110,3 @@ func TestSchemaValidator_ValidatePublishChirp(t *testing.T) { // ----------------------------------------------------------------------------- // Utility functions - -var nolog = zerolog.New(io.Discard)