Skip to content

Commit

Permalink
misc fixes to address race conditions, enforce ordering of callbacks,…
Browse files Browse the repository at this point in the history
… add unit tests
  • Loading branch information
eli-darkly committed Nov 16, 2021
1 parent e0af7d4 commit fdf68c5
Show file tree
Hide file tree
Showing 11 changed files with 454 additions and 144 deletions.
8 changes: 8 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,17 @@ workflows:
version: 2
test:
jobs:
- build-and-test
- docker-build-and-smoke-test

jobs:
build-and-test:
docker:
- image: cimg/go:1.17
steps:
- checkout
- make test

docker-build-and-smoke-test:
docker:
- image: cimg/base:2021.10
Expand Down
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@ GOLANGCI_LINT_VERSION=v1.27.0
LINTER=./bin/golangci-lint
LINTER_VERSION_FILE=./bin/.golangci-lint-version-$(GOLANGCI_LINT_VERSION)

.PHONY: build clean lint build-release publish-release docker-build docker-push docker-smoke-test
.PHONY: build clean test lint build-release publish-release docker-build docker-push docker-smoke-test

build:
go build

clean:
go clean

test:
go test ./...

$(LINTER_VERSION_FILE):
rm -f $(LINTER)
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | \
Expand Down
114 changes: 28 additions & 86 deletions framework/harness.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,21 @@
package framework

import (
"bytes"
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"sync"
"time"
)

const endpointPathPrefix = "/endpoints/"
const httpListenerTimeout = time.Second * 10

type TestHarness struct {
testServiceBaseURL string
testHarnessExternalBaseURL string
testServiceInfo TestServiceInfo
endpoints map[string]*MockEndpoint
lastEndpointID int
mockEndpoints *mockEndpointsManager
logger Logger
lock sync.Mutex
}
Expand All @@ -40,12 +35,12 @@ func NewTestHarness(
debugLogger = NullLogger()
}

externalBaseUrl := fmt.Sprintf("http://%s:%d", testHarnessExternalHostname, testHarnessPort)
externalBaseURL := fmt.Sprintf("http://%s:%d", testHarnessExternalHostname, testHarnessPort)

h := &TestHarness{
testServiceBaseURL: testServiceBaseURL,
testHarnessExternalBaseURL: externalBaseUrl,
endpoints: make(map[string]*MockEndpoint),
testHarnessExternalBaseURL: externalBaseURL,
mockEndpoints: newMockEndpointsManager(externalBaseURL, debugLogger),
logger: debugLogger,
}

Expand Down Expand Up @@ -75,86 +70,33 @@ func (h *TestHarness) TestServiceHasCapability(desired string) bool {
return false
}

func (h *TestHarness) serveHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method == "HEAD" {
w.WriteHeader(200) // we use this to test whether our own listener is active yet
return
}

if !strings.HasPrefix(req.URL.Path, endpointPathPrefix) {
h.logger.Printf("Received request for unrecognized URL path %s", req.URL.Path)
w.WriteHeader(404)
return
}
path := strings.TrimPrefix(req.URL.Path, endpointPathPrefix)
var endpointID string
slashPos := strings.Index(path, "/")
if slashPos >= 0 {
endpointID = path[0:slashPos]
path = path[slashPos:]
} else {
endpointID = path
path = ""
}
// NewEndpoint adds a new endpoint that can receive requests.
//
// The specified handler will be called for all incoming requests to the endpoint's
// base URL or any subpath of it. For instance, if the generated base URL (as reported
// by MockEndpoint.BaseURL()) is http://localhost:8111/endpoints/3, then it can also
// receive requests to http://localhost:8111/endpoints/3/some/subpath.
//
// When the handler is called, the test harness rewrites the request URL first so that
// the handler sees only the subpath. It also attaches a Context to the request whose
// Done channel will be closed if Close is called on the endpoint.
func (h *TestHarness) NewMockEndpoint(
handler http.Handler,
contextFn func(context.Context) context.Context,
logger Logger,
) *MockEndpoint {
if logger == nil {
logger = h.logger
}
return h.mockEndpoints.newMockEndpoint(handler, contextFn, logger)
}

h.lock.Lock()
e := h.endpoints[endpointID]
h.lock.Unlock()
if e == nil {
h.logger.Printf("Received request for unrecognized endpoint %s", req.URL.Path)
w.WriteHeader(404)
func (h *TestHarness) serveHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == "HEAD" {
w.WriteHeader(200) // we use this to test whether our own listener is active yet
return
}

var body []byte
if req.Body != nil {
data, err := ioutil.ReadAll(req.Body)
req.Body.Close()
if err != nil {
h.logger.Printf("Unexpected error trying to read request body: %s", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
body = data
}

e.lock.Lock()
ctx, canceller := context.WithCancel(req.Context())
cancellerPtr := &canceller
e.cancels = append(e.cancels, cancellerPtr)
e.lock.Unlock()

incoming := IncomingRequestInfo{
Headers: req.Header,
Method: req.Method,
Body: body,
Context: ctx,
}
select { // non-blocking push
case e.newConns <- incoming:
break
default:
h.logger.Printf("Incoming connection channel was full for %s", req.URL)
}

transformedReq := req.WithContext(ctx)
url := *req.URL
url.Path = path
transformedReq.URL = &url
if body != nil {
transformedReq.Body = ioutil.NopCloser(bytes.NewBuffer(body))
}

e.handler.ServeHTTP(w, transformedReq)

e.lock.Lock()
for i, c := range e.cancels {
if c == cancellerPtr { // can't compare functions with ==, but can compare pointers
e.cancels = append(e.cancels[:i], e.cancels[i+1:]...)
break
}
}
e.lock.Unlock()
h.mockEndpoints.serveHTTP(w, r)
}

func startServer(port int, handler http.Handler) error {
Expand Down
62 changes: 62 additions & 0 deletions framework/message_sorting_queue.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package framework

import (
"sort"
"sync"
)

type MessageSortingQueue struct {
C chan []byte
lastCounter int
deferred []deferredMessage
lock sync.Mutex
closeOnce sync.Once
}

type deferredMessage struct {
counter int
message []byte
}

func NewMessageSortingQueue(channelSize int) *MessageSortingQueue {
return &MessageSortingQueue{C: make(chan []byte, channelSize)}
}

func (q *MessageSortingQueue) Accept(counter int, message []byte) {
q.lock.Lock()
if counter > q.lastCounter+1 {
q.deferred = append(q.deferred, deferredMessage{counter: counter, message: message})
sort.Slice(q.deferred, func(i, j int) bool { return q.deferred[i].counter < q.deferred[j].counter })
q.lock.Unlock()
return
}
q.lastCounter = counter
q.C <- message
for len(q.deferred) > 0 {
next := q.deferred[0]
if next.counter != q.lastCounter+1 {
break
}
q.deferred = q.deferred[1:]
q.lastCounter++
q.C <- next.message
}
q.lock.Unlock()
}

func (q *MessageSortingQueue) Deferred() [][]byte {
q.lock.Lock()
ret := make([][]byte, 0, len(q.deferred))
for _, d := range q.deferred {
ret = append(ret, d.message)
}
q.lock.Unlock()
return ret
}

func (q *MessageSortingQueue) Close() {
q.closeOnce.Do(func() {
close(q.C)
q.C = nil
})
}
79 changes: 79 additions & 0 deletions framework/message_sorting_queue_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package framework

import (
"fmt"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func fakeItemData(counter int) []byte {
return []byte(fmt.Sprintf("item-%d", counter))
}

func acceptTestItems(q *MessageSortingQueue, counters ...int) {
for _, c := range counters {
q.Accept(c, fakeItemData(c))
}
}

func expectTestItems(t *testing.T, q *MessageSortingQueue, counters ...int) {
for _, c := range counters {
select {
case item := <-q.C:
assert.Equal(t, string(fakeItemData(c)), string(item))
case <-time.After(time.Second):
var deferredList []string
for _, d := range q.Deferred() {
deferredList = append(deferredList, string(d))
}
require.Fail(t, "timed out waiting for item from queue",
"was waiting for item %d; deferred items were [%v]", strings.Join(deferredList, ","))
}
}
}

func expectDeferredItems(t *testing.T, q *MessageSortingQueue, counters ...int) {
var expected, actual []string
for _, c := range counters {
expected = append(expected, string(fakeItemData(c)))
}
for _, d := range q.Deferred() {
actual = append(actual, string(d))
}
assert.Equal(t, expected, actual, "did not see expected items in deferred list")
}

func TestMessageSortingQueueWithMessagesInOrder(t *testing.T) {
q := NewMessageSortingQueue(10)
acceptTestItems(q, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
expectDeferredItems(t, q) // should be empty
expectTestItems(t, q, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
}

func TestMessageSortingQueueWithMessagesOutOfOrder(t *testing.T) {
q := NewMessageSortingQueue(10)

acceptTestItems(q, 3)
expectDeferredItems(t, q, 3)

acceptTestItems(q, 2)
expectDeferredItems(t, q, 2, 3)

acceptTestItems(q, 6)
expectDeferredItems(t, q, 2, 3, 6)

acceptTestItems(q, 1)
expectTestItems(t, q, 1, 2, 3)
expectDeferredItems(t, q, 6)

acceptTestItems(q, 5)
expectDeferredItems(t, q, 5, 6)

acceptTestItems(q, 4)
expectTestItems(t, q, 4, 5, 6)
expectDeferredItems(t, q) // empty
}
Loading

0 comments on commit fdf68c5

Please sign in to comment.