Skip to content

Commit

Permalink
p2p: convert additional Multiplexer method to use generic implementat…
Browse files Browse the repository at this point in the history
…ion (#6034)
  • Loading branch information
cce authored Jun 20, 2024
1 parent a513cbf commit 0dd0cb7
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 66 deletions.
2 changes: 1 addition & 1 deletion data/txHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ func (handler *TxHandler) Start() {
{
Tag: protocol.TxnTag,
// create anonymous struct to hold the two functions and satisfy the network.MessageProcessor interface
MessageProcessor: struct {
MessageHandler: struct {
network.ProcessorValidateFunc
network.ProcessorHandleFunc
}{
Expand Down
13 changes: 6 additions & 7 deletions network/gossipNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,18 +239,17 @@ func (f ProcessorHandleFunc) Handle(message ValidatedMessage) OutgoingMessage {
return f(message)
}

// TaggedMessageHandler receives one type of broadcast messages
type TaggedMessageHandler struct {
type taggedMessageDispatcher[T any] struct {
Tag
MessageHandler
MessageHandler T
}

// TaggedMessageHandler receives one type of broadcast messages
type TaggedMessageHandler = taggedMessageDispatcher[MessageHandler]

// TaggedMessageProcessor receives one type of broadcast messages
// and performs two stage processing: validating and handling
type TaggedMessageProcessor struct {
Tag
MessageProcessor
}
type TaggedMessageProcessor = taggedMessageDispatcher[MessageProcessor]

// Propagate is a convenience function to save typing in the common case of a message handler telling us to propagate an incoming message
// "return network.Propagate(msg)" instead of "return network.OutgoingMsg{network.Broadcast, msg.Tag, msg.Data}"
Expand Down
66 changes: 19 additions & 47 deletions network/multiplexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,6 @@ func getMap[T any](source *atomic.Value) map[Tag]T {
return nil
}

// getHandlersMap retrieves the handlers map.
func (m *Multiplexer) getHandlersMap() map[Tag]MessageHandler {
return getMap[MessageHandler](&m.msgHandlers)
}

// getProcessorsMap retrieves the processors map.
func (m *Multiplexer) getProcessorsMap() map[Tag]MessageProcessor {
return getMap[MessageProcessor](&m.msgHandlers)
}

// Retrieves the handler for the given message Tag from the given value while.
func getHandler[T any](source *atomic.Value, tag Tag) (T, bool) {
if handlers := getMap[T](source); handlers != nil {
Expand All @@ -77,41 +67,31 @@ func (m *Multiplexer) getProcessor(tag Tag) (MessageProcessor, bool) {

// Handle is the "input" side of the multiplexer. It dispatches the message to the previously defined handler.
func (m *Multiplexer) Handle(msg IncomingMessage) OutgoingMessage {
handler, ok := m.getHandler(msg.Tag)

if ok {
outmsg := handler.Handle(msg)
return outmsg
if handler, ok := m.getHandler(msg.Tag); ok {
return handler.Handle(msg)
}
return OutgoingMessage{}
}

// Validate is an alternative "input" side of the multiplexer. It dispatches the message to the previously defined validator.
func (m *Multiplexer) Validate(msg IncomingMessage) ValidatedMessage {
handler, ok := m.getProcessor(msg.Tag)

if ok {
outmsg := handler.Validate(msg)
return outmsg
if handler, ok := m.getProcessor(msg.Tag); ok {
return handler.Validate(msg)
}
return ValidatedMessage{}
}

// Process is the second step of message handling after validation. It dispatches the message to the previously defined processor.
func (m *Multiplexer) Process(msg ValidatedMessage) OutgoingMessage {
handler, ok := m.getProcessor(msg.Tag)

if ok {
outmsg := handler.Handle(msg)
return outmsg
if handler, ok := m.getProcessor(msg.Tag); ok {
return handler.Handle(msg)
}
return OutgoingMessage{}
}

// RegisterHandlers registers the set of given message handlers.
func (m *Multiplexer) RegisterHandlers(dispatch []TaggedMessageHandler) {
mp := make(map[Tag]MessageHandler)
if existingMap := m.getHandlersMap(); existingMap != nil {
func registerMultiplexer[T any](target *atomic.Value, dispatch []taggedMessageDispatcher[T]) {
mp := make(map[Tag]T)
if existingMap := getMap[T](target); existingMap != nil {
for k, v := range existingMap {
mp[k] = v
}
Expand All @@ -122,28 +102,20 @@ func (m *Multiplexer) RegisterHandlers(dispatch []TaggedMessageHandler) {
}
mp[v.Tag] = v.MessageHandler
}
m.msgHandlers.Store(mp)
target.Store(mp)
}

// RegisterHandlers registers the set of given message handlers.
func (m *Multiplexer) RegisterHandlers(dispatch []TaggedMessageHandler) {
registerMultiplexer(&m.msgHandlers, dispatch)
}

// RegisterProcessors registers the set of given message handlers.
func (m *Multiplexer) RegisterProcessors(dispatch []TaggedMessageProcessor) {
mp := make(map[Tag]MessageProcessor)
if existingMap := m.getProcessorsMap(); existingMap != nil {
for k, v := range existingMap {
mp[k] = v
}
}
for _, v := range dispatch {
if _, has := mp[v.Tag]; has {
panic(fmt.Sprintf("Already registered a handler for tag %v", v.Tag))
}
mp[v.Tag] = v.MessageProcessor
}
m.msgProcessors.Store(mp)
registerMultiplexer(&m.msgProcessors, dispatch)
}

// ClearProcessors deregisters all the existing message handlers other than the one provided in the excludeTags list
func clear[T any](target *atomic.Value, excludeTags []Tag) {
func clearMultiplexer[T any](target *atomic.Value, excludeTags []Tag) {
if len(excludeTags) == 0 {
target.Store(make(map[Tag]T))
return
Expand All @@ -168,10 +140,10 @@ func clear[T any](target *atomic.Value, excludeTags []Tag) {

// ClearHandlers deregisters all the existing message handlers other than the one provided in the excludeTags list
func (m *Multiplexer) ClearHandlers(excludeTags []Tag) {
clear[MessageHandler](&m.msgHandlers, excludeTags)
clearMultiplexer[MessageHandler](&m.msgHandlers, excludeTags)
}

// ClearProcessors deregisters all the existing message handlers other than the one provided in the excludeTags list
func (m *Multiplexer) ClearProcessors(excludeTags []Tag) {
clear[MessageProcessor](&m.msgProcessors, excludeTags)
clearMultiplexer[MessageProcessor](&m.msgProcessors, excludeTags)
}
21 changes: 10 additions & 11 deletions network/p2pNetwork_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func TestP2PSubmitTX(t *testing.T) {
passThroughHandler := []TaggedMessageProcessor{
{
Tag: protocol.TxnTag,
MessageProcessor: struct {
MessageHandler: struct {
ProcessorValidateFunc
ProcessorHandleFunc
}{
Expand Down Expand Up @@ -195,7 +195,7 @@ func TestP2PSubmitTXNoGossip(t *testing.T) {
passThroughHandler := []TaggedMessageProcessor{
{
Tag: protocol.TxnTag,
MessageProcessor: struct {
MessageHandler: struct {
ProcessorValidateFunc
ProcessorHandleFunc
}{
Expand Down Expand Up @@ -829,22 +829,21 @@ func TestP2PRelay(t *testing.T) {
return netA.hasPeers() && netB.hasPeers()
}, 2*time.Second, 50*time.Millisecond)

makeCounterHandler := func(numExpected int) ([]TaggedMessageProcessor, *int, chan struct{}) {
numActual := 0
makeCounterHandler := func(numExpected int) ([]TaggedMessageProcessor, *atomic.Uint32, chan struct{}) {
var numActual atomic.Uint32
counterDone := make(chan struct{})
counterHandler := []TaggedMessageProcessor{
{
Tag: protocol.TxnTag,
MessageProcessor: struct {
MessageHandler: struct {
ProcessorValidateFunc
ProcessorHandleFunc
}{
ProcessorValidateFunc(func(msg IncomingMessage) ValidatedMessage {
return ValidatedMessage{Action: Accept, Tag: msg.Tag, ValidatorData: nil}
}),
ProcessorHandleFunc(func(msg ValidatedMessage) OutgoingMessage {
numActual++
if numActual >= numExpected {
if count := numActual.Add(1); int(count) >= numExpected {
close(counterDone)
}
return OutgoingMessage{Action: Ignore}
Expand Down Expand Up @@ -916,10 +915,10 @@ func TestP2PRelay(t *testing.T) {
select {
case <-counterDone:
case <-time.After(2 * time.Second):
if *count < expectedMsgs {
require.Failf(t, "One or more messages failed to reach destination network", "%d > %d", expectedMsgs, *count)
} else if *count > expectedMsgs {
require.Failf(t, "One or more messages that were expected to be dropped, reached destination network", "%d < %d", expectedMsgs, *count)
if c := count.Load(); c < expectedMsgs {
require.Failf(t, "One or more messages failed to reach destination network", "%d > %d", expectedMsgs, c)
} else if c > expectedMsgs {
require.Failf(t, "One or more messages that were expected to be dropped, reached destination network", "%d < %d", expectedMsgs, c)
}
}
}
Expand Down

0 comments on commit 0dd0cb7

Please sign in to comment.