Skip to content

Commit

Permalink
Implements ShutdownCode option and ShutdownSignal os.Signal wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonmills committed Aug 8, 2022
1 parent 1124297 commit b8a589f
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 6 deletions.
34 changes: 31 additions & 3 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,9 +296,12 @@ type App struct {
errorHooks []ErrorHandler
validate bool
// Used to signal shutdowns.
donesMu sync.Mutex // guards dones and shutdownSig
dones []chan os.Signal
shutdownSig os.Signal
donesMu sync.Mutex // guards dones and shutdownSig
dones []chan os.Signal
shutdownSig os.Signal
waitsMu sync.Mutex // guards waits and shutdownCode
waits []chan ShutdownSignal
shutdownSignal *ShutdownSignal

osExit func(code int) // os.Exit override; used for testing only
}
Expand Down Expand Up @@ -737,6 +740,31 @@ func (app *App) Done() <-chan os.Signal {
return c
}

func (app *App) wait() <-chan ShutdownSignal {
c := make(chan ShutdownSignal, 1)

app.waitsMu.Lock()
defer app.waitsMu.Unlock()

if app.shutdownSignal != nil {
c <- *app.shutdownSignal
return c
}

app.waits = append(app.waits, c)
return c
}

func (app *App) Wait(ctx context.Context) (ShutdownSignal, error) {
c := app.wait()
select {
case s := <-c:
return s, nil
case <-ctx.Done():
return ShutdownSignal{}, ctx.Err()
}
}

// StartTimeout returns the configured startup timeout. Apps default to using
// DefaultTimeout, but users can configure this behavior using the
// StartTimeout option.
Expand Down
66 changes: 63 additions & 3 deletions shutdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ package fx
import (
"fmt"
"os"

"go.uber.org/multierr"
)

// Shutdowner provides a method that can manually trigger the shutdown of the
Expand All @@ -39,8 +41,26 @@ type ShutdownOption interface {
apply(*shutdowner)
}

type shutdownCode int

func (c shutdownCode) apply(s *shutdowner) {
s.exitCode = int(c)
}

// ShutdownCode implements a shutdown option that allows a user specify the
// os.Exit code that an application should exit with.
func ShutdownCode(code int) ShutdownOption {
return shutdownCode(code)
}

type shutdowner struct {
app *App
exitCode int
app *App
}

type ShutdownSignal struct {
Signal os.Signal
ExitCode int
}

// Shutdown broadcasts a signal to all of the application's Done channels
Expand All @@ -49,14 +69,25 @@ type shutdowner struct {
// In practice this means Shutdowner.Shutdown should not be called from an
// fx.Invoke, but from a fx.Lifecycle.OnStart hook.
func (s *shutdowner) Shutdown(opts ...ShutdownOption) error {
return s.app.broadcastSignal(_sigTERM)
for _, opt := range opts {
opt.apply(s)
}

return s.app.broadcastSignal(_sigTERM, s.exitCode)
}

func (app *App) shutdowner() Shutdowner {
return &shutdowner{app: app}
}

func (app *App) broadcastSignal(signal os.Signal) error {
func (app *App) broadcastSignal(signal os.Signal, code int) error {
return multierr.Combine(
app.broadcastDoneSignal(signal),
app.broadcastWaitSignal(signal, code),
)
}

func (app *App) broadcastDoneSignal(signal os.Signal) error {
app.donesMu.Lock()
defer app.donesMu.Unlock()

Expand All @@ -81,3 +112,32 @@ func (app *App) broadcastSignal(signal os.Signal) error {

return nil
}

func (app *App) broadcastWaitSignal(signal os.Signal, code int) error {
app.waitsMu.Lock()
defer app.waitsMu.Unlock()

app.shutdownSignal = &ShutdownSignal{
Signal: signal,
ExitCode: code,
}

var unsent int
for _, wait := range app.waits {
select {
case wait <- *app.shutdownSignal:
default:
// shutdown called when wait channel has already received a
// termination signal that has not been cleared
unsent++
}
}

if unsent != 0 {
return fmt.Errorf("failed to send %v codes to %v out of %v channels",
signal, unsent, len(app.waits),
)
}

return nil
}
55 changes: 55 additions & 0 deletions shutdown_code_example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) 2022 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package fx_test

import (
"context"
"fmt"
"time"

"go.uber.org/fx"
)

func ExampleShutdownCode() {
app := fx.New(
fx.Invoke(func(shutdowner fx.Shutdowner) {
// Call the shutdowner Shutdown method with a shutdown code
// option
shutdowner.Shutdown(fx.ShutdownCode(1))
}),
)

app.Run()

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

shutdown, err := app.Wait(ctx)

if err != nil {
panic(err)
}

fmt.Printf("os.Exit(%v)\n", shutdown.ExitCode)

// Output:
// os.Exit(1)
}
58 changes: 58 additions & 0 deletions shutdown_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ package fx_test

import (
"context"
"fmt"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -87,6 +89,62 @@ func TestShutdown(t *testing.T) {
assert.NotNil(t, <-done1, "done channel 1 did not receive signal")
assert.NotNil(t, <-done2, "done channel 2 did not receive signal")
})

t.Run("shutdown app with exit code(s)", func(t *testing.T) {
t.Parallel()

t.Run("default", func(t *testing.T) {
t.Parallel()
var s fx.Shutdowner
app := fxtest.New(t, fx.Populate(&s))

done := app.Done()
defer app.RequireStart().RequireStop()

assert.NoError(t, s.Shutdown(), "error returned from first shutdown call")
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

signal, err := app.Wait(ctx)
assert.NoError(t, err, "error in app wait")
assert.NotEmpty(t, signal, "no shutdown signal")
assert.NotNil(t, signal.Signal)
assert.Zero(t, signal.ExitCode)
assert.Equal(t, signal.Signal, <-done)
assert.NoError(t, ctx.Err())
})

for expected := 0; expected <= 3; expected++ {
expected := expected
t.Run(fmt.Sprintf("with exit code %v", expected), func(t *testing.T) {
t.Parallel()
var s fx.Shutdowner
app := fxtest.New(
t,
fx.Populate(&s),
)

done := app.Done()
defer app.RequireStart().RequireStop()

assert.NoError(
t,
s.Shutdown(fx.ShutdownCode(expected)),
"error in app shutdown",
)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()

signal, err := app.Wait(ctx)
assert.NoError(t, err, "error in app wait")
assert.NotEmpty(t, signal, "no shutdown signal")
assert.NotNil(t, signal.Signal)
assert.Equal(t, expected, signal.ExitCode)
assert.Equal(t, signal.Signal, <-done)
})
}
})
}

func TestDataRace(t *testing.T) {
Expand Down

0 comments on commit b8a589f

Please sign in to comment.