diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 572a870..1614e8a 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -25,18 +25,55 @@ jobs: - name: Build run: go build -v ./... - - name: Start test docker containers - if: matrix.platform == 'ubuntu-latest' - run: | - docker-compose -f storage/mysql/docker-compose-test.yml up & - - - name: Set - if: matrix.platform == 'ubuntu-latest' - run: echo "NANODEP_MYSQL_STORAGE_TEST=1" >> $GITHUB_ENV - - name: Test run: go test -v -race ./... - name: Format if: matrix.platform == 'ubuntu-latest' run: if [ "$(gofmt -s -l . | wc -l)" -gt 0 ]; then exit 1; fi + + mysql-integration-test: + name: Integration tests for the MySQL backend. + runs-on: 'ubuntu-latest' + needs: build-test + services: + mysql: + image: mysql:8.0 + env: + MYSQL_RANDOM_ROOT_PASSWORD: yes + MYSQL_DATABASE: nanodep + MYSQL_USER: nanodep + MYSQL_PASSWORD: nanodep + ports: + - 3800:3306 + options: --health-cmd="mysqladmin ping" --health-interval=5s --health-timeout=2s --health-retries=3 + defaults: + run: + shell: bash + env: + MYSQL_PWD: nanodep + PORT: 3800 + steps: + - uses: actions/checkout@v3 + + - name: setup go + uses: actions/setup-go@v4 + with: + go-version: '1.17.x' + + - name: Verify MySQL connection + run: | + while ! mysqladmin ping --host=localhost --port=$PORT --protocol=TCP --silent; do + sleep 1 + done + + - name: Set up schema + run: | + mysql --version + mysql --user=nanodep --host=localhost --port=$PORT --protocol=TCP nanodep < ./storage/mysql/schema.sql + + - name: Set + run: echo "NANODEP_MYSQL_STORAGE_TEST_DSN=nanodep:nanodep@tcp(localhost:$PORT)/nanodep" >> $GITHUB_ENV + + - name: Test + run: go test -v ./storage/mysql diff --git a/storage/file/file_test.go b/storage/file/file_test.go index e6d56e4..6b103f4 100644 --- a/storage/file/file_test.go +++ b/storage/file/file_test.go @@ -1,18 +1,17 @@ package file import ( + "context" "testing" - "github.com/micromdm/nanodep/storage" "github.com/micromdm/nanodep/storage/test" ) func TestFileStorage(t *testing.T) { - test.Run(t, func(t *testing.T) storage.AllStorage { - s, err := New(t.TempDir()) - if err != nil { - t.Fatal(err) - } - return s - }) + s, err := New(t.TempDir()) + if err != nil { + t.Fatal(err) + } + + test.TestWithStorages(t, context.Background(), s) } diff --git a/storage/mysql/docker-compose-test.yml b/storage/mysql/docker-compose-test.yml deleted file mode 100644 index 9503909..0000000 --- a/storage/mysql/docker-compose-test.yml +++ /dev/null @@ -1,23 +0,0 @@ ---- -version: "2" -services: - mysql: - image: ${NANODEP_MYSQL_IMAGE:-mysql:8.0.19} - platform: ${NANODEP_MYSQL_PLATFORM:-linux/x86_64} - command: - [ - "mysqld", - "--datadir=/tmp/mysqldata", - "--log-bin=bin.log", - "--server-id=master-01" - ] - environment: - MYSQL_ROOT_PASSWORD: toor - MYSQL_DATABASE: nanodep - MYSQL_USER: nanodep - MYSQL_PASSWORD: insecure - tmpfs: - - /var/lib/mysql:rw,noexec,nosuid - - /tmpfs - ports: - - "4242:3306" \ No newline at end of file diff --git a/storage/mysql/mysql_test.go b/storage/mysql/mysql_test.go index 2872f66..4ee49bd 100644 --- a/storage/mysql/mysql_test.go +++ b/storage/mysql/mysql_test.go @@ -2,86 +2,23 @@ package mysql import ( "context" - "database/sql" - "fmt" "os" - "strings" "testing" - "time" _ "github.com/go-sql-driver/mysql" - "github.com/micromdm/nanodep/storage" "github.com/micromdm/nanodep/storage/test" ) func TestMySQLStorage(t *testing.T) { - testDSN := os.Getenv("NANODEP_MYSQL_STORAGE_TEST") + testDSN := os.Getenv("NANODEP_MYSQL_STORAGE_TEST_DSN") if testDSN == "" { - t.Skip("NANODEP_MYSQL_STORAGE_TEST not set") + t.Skip("NANODEP_MYSQL_STORAGE_TEST_DSN not set") } - test.Run(t, func(t *testing.T) storage.AllStorage { - dbName := initTestDB(t) - testDSN := fmt.Sprintf("nanodep:insecure@tcp(localhost:4242)/%s?charset=utf8mb4&loc=UTC", dbName) - s, err := New(WithDSN(testDSN)) - if err != nil { - t.Fatal(err) - } - return s - }) -} - -// initTestDB clears any existing data from the database. -func initTestDB(t *testing.T) string { - rootDSN := "root:toor@tcp(localhost:4242)/?charset=utf8mb4&loc=UTC" - db, err := sql.Open("mysql", rootDSN) - if err != nil { - t.Fatal(err) - } - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) - defer cancel() - for { - err := db.PingContext(ctx) - if err == nil { - break - } - t.Logf("failed to connect: %s, retrying connection", err) - select { - case <-time.After(1 * time.Second): - // OK, continue. - case <-ctx.Done(): - t.Fatal("timeout connecting to MySQL") - } - } - defer func() { - if err := db.Close(); err != nil { - t.Fatal(err) - } - }() - dbName := dbName(t) - _, err = db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", dbName)) + s, err := New(WithDSN(testDSN)) if err != nil { t.Fatal(err) } - _, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)) - if err != nil { - t.Fatal(err) - } - _, err = db.Exec(fmt.Sprintf("USE %s;", dbName)) - if err != nil { - t.Fatal(err) - } - _, err = db.Exec(Schema) - if err != nil { - t.Fatal(err) - } - _, err = db.Exec(fmt.Sprintf("GRANT ALL PRIVILEGES ON %s.* TO 'nanodep';", dbName)) - if err != nil { - t.Fatal(err) - } - return dbName -} -func dbName(t *testing.T) string { - return strings.ReplaceAll(strings.ReplaceAll(t.Name(), "/", "_"), "-", "_") + test.TestWithStorages(t, context.Background(), s) } diff --git a/storage/test/test.go b/storage/test/test.go index 86e3a3c..5e7088b 100644 --- a/storage/test/test.go +++ b/storage/test/test.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "errors" + "math/rand" "testing" "time" @@ -13,194 +14,194 @@ import ( "github.com/micromdm/nanodep/tokenpki" ) -// Run runs a battery of tests on the storage.AllStorage returned by storageFn. -func Run(t *testing.T, storageFn func(t *testing.T) storage.AllStorage) { - ctx := context.Background() +// TestWithStorages runs multiple tests with different storage provided by storageFn. +func TestWithStorages(t *testing.T, ctx context.Context, store storage.AllStorage) { + depName1, depName2 := genRandName(4), genRandName(4) - // Test retrieval methods on empty storage. t.Run("empty", func(t *testing.T) { - const name = "empty" - - s := storageFn(t) - - if _, _, err := s.RetrieveTokenPKI(ctx, name); !errors.Is(err, storage.ErrNotFound) { - t.Fatalf("unexpected error: %s", err) - } - - if _, err := s.RetrieveAuthTokens(ctx, name); !errors.Is(err, storage.ErrNotFound) { - t.Fatalf("unexpected error: %s", err) - } - - config, err := s.RetrieveConfig(ctx, name) - checkErr(t, err) - if config != nil { - t.Fatalf("expected non-existent config: %+v", config) - } - - // Profile assigner storing and retrieval. - profileUUID, modTime, err := s.RetrieveAssignerProfile(ctx, name) - checkErr(t, err) - if profileUUID != "" { - t.Fatal("expected empty profileUUID") - } - if !modTime.IsZero() { - t.Fatal("expected zero modTime") - } - - cursor, err := s.RetrieveCursor(ctx, name) - checkErr(t, err) - if cursor != "" { - t.Fatal("expected empty cursor") - } + TestEmpty(t, ctx, depName1, store) }) - testWithName := func(t *testing.T, name string, s storage.AllStorage) { - // PKI storing and retrieval. - if _, _, err := s.RetrieveTokenPKI(ctx, name); !errors.Is(err, storage.ErrNotFound) { - t.Fatalf("unexpected error: %s", err) - } - pemCert, pemKey := generatePKI(t, "basicdn", 1) - err := s.StoreTokenPKI(ctx, name, pemCert, pemKey) - checkErr(t, err) - pemCert2, pemKey2, err := s.RetrieveTokenPKI(ctx, name) - checkErr(t, err) - if !bytes.Equal(pemCert, pemCert2) { - t.Fatalf("pem cert mismatch: %s vs. %s", pemCert, pemCert2) - } - if !bytes.Equal(pemKey, pemKey2) { - t.Fatalf("pem key mismatch: %s vs. %s", pemKey, pemKey2) - } - - // Token storing and retrieval. - if _, err := s.RetrieveAuthTokens(ctx, name); !errors.Is(err, storage.ErrNotFound) { - t.Fatalf("unexpected error: %s", err) - } - tokens := &client.OAuth1Tokens{ - ConsumerKey: "CK_9af2f8218b150c351ad802c6f3d66abe", - ConsumerSecret: "CS_9af2f8218b150c351ad802c6f3d66abe", - AccessToken: "AT_9af2f8218b150c351ad802c6f3d66abe", - AccessSecret: "AS_9af2f8218b150c351ad802c6f3d66abe", - AccessTokenExpiry: time.Now().UTC(), - } - err = s.StoreAuthTokens(ctx, name, tokens) - checkErr(t, err) - tokens2, err := s.RetrieveAuthTokens(ctx, name) - checkErr(t, err) - checkTokens(t, tokens, tokens2) - tokens3 := &client.OAuth1Tokens{ - ConsumerKey: "foo_CK_9af2f8218b150c351ad802c6f3d66abe", - ConsumerSecret: "foo_CS_9af2f8218b150c351ad802c6f3d66abe", - AccessToken: "foo_AT_9af2f8218b150c351ad802c6f3d66abe", - AccessSecret: "foo_AS_9af2f8218b150c351ad802c6f3d66abe", - AccessTokenExpiry: time.Now().Add(5 * time.Second).UTC(), - } - err = s.StoreAuthTokens(ctx, name, tokens3) - checkErr(t, err) - tokens4, err := s.RetrieveAuthTokens(ctx, name) - checkErr(t, err) - checkTokens(t, tokens3, tokens4) - - // Config storing and retrieval. - config, err := s.RetrieveConfig(ctx, name) - checkErr(t, err) - if config != nil { - t.Fatalf("expected not-existing config: %+v", config) - } - config = &client.Config{ - BaseURL: "https://config.example.com", - } - err = s.StoreConfig(ctx, name, config) - checkErr(t, err) - config2, err := s.RetrieveConfig(ctx, name) - checkErr(t, err) - if *config != *config2 { - t.Fatalf("config mismatch: %+v vs. %+v", config, config2) - } - config2 = &client.Config{ - BaseURL: "https://config2.example.com", - } - err = s.StoreConfig(ctx, name, config2) - checkErr(t, err) - config3, err := s.RetrieveConfig(ctx, name) - checkErr(t, err) - if *config2 != *config3 { - t.Fatalf("config mismatch: %+v vs. %+v", config2, config3) - } - - // Profile assigner storing and retrieval. - profileUUID, modTime, err := s.RetrieveAssignerProfile(ctx, name) - checkErr(t, err) - if profileUUID != "" { - t.Fatal("expected empty profileUUID") - } - if !modTime.IsZero() { - t.Fatal("expected zero modTime") - } - profileUUID = "43277A13FBCA0CFC" - err = s.StoreAssignerProfile(ctx, name, profileUUID) - checkErr(t, err) - profileUUID2, modTime, err := s.RetrieveAssignerProfile(ctx, name) - checkErr(t, err) - if profileUUID != profileUUID2 { - t.Fatalf("profileUUID mismatch: %s vs. %s", profileUUID, profileUUID2) - } - now := time.Now() - if modTime.Before(now.Add(-1*time.Minute)) || modTime.After(now.Add(1*time.Minute)) { - t.Fatalf("mismatch modTime, expected: %s (+/- 1m), actual: %s", now, modTime) - } - time.Sleep(1 * time.Second) - profileUUID3 := "foo_43277A13FBCA0CFC" - err = s.StoreAssignerProfile(ctx, name, profileUUID3) - checkErr(t, err) - profileUUID4, modTime2, err := s.RetrieveAssignerProfile(ctx, name) - checkErr(t, err) - if profileUUID3 != profileUUID4 { - t.Fatalf("profileUUID mismatch: %s vs. %s", profileUUID, profileUUID3) - } - if modTime2 == modTime { - t.Fatalf("expected time update: %s", modTime2) - } - now = time.Now() - if modTime2.Before(now.Add(-1*time.Minute)) || modTime2.After(now.Add(1*time.Minute)) { - t.Fatalf("mismatch modTime, expected: %s (+/- 1m), actual: %s", now, modTime) - } - - cursor, err := s.RetrieveCursor(ctx, name) - checkErr(t, err) - if cursor != "" { - t.Fatal("expected empty cursor") - } - cursor = "MTY1NzI2ODE5Ny0x" - err = s.StoreCursor(ctx, name, cursor) - checkErr(t, err) - cursor2, err := s.RetrieveCursor(ctx, name) - checkErr(t, err) - if cursor != cursor2 { - t.Fatalf("cursor mismatch: %s vs. %s", cursor, cursor2) - } - cursor2 = "foo_MTY1NzI2ODE5Ny0x" - err = s.StoreCursor(ctx, name, cursor2) - checkErr(t, err) - cursor3, err := s.RetrieveCursor(ctx, name) - checkErr(t, err) - if cursor2 != cursor3 { - t.Fatalf("cursor mismatch: %s vs. %s", cursor2, cursor3) - } - } - - t.Run("basic", func(t *testing.T) { - storage := storageFn(t) - testWithName(t, "basic", storage) + t.Run("basic-name1", func(t *testing.T) { + TestWitName(t, ctx, depName1, store) }) - t.Run("multiple-names", func(t *testing.T) { - storage := storageFn(t) - testWithName(t, "name1", storage) - testWithName(t, "name2", storage) + t.Run("basic-name2", func(t *testing.T) { + TestWitName(t, ctx, depName2, store) }) } +// TestEmpty tests retrieval methods on an empty/missing name. +func TestEmpty(t *testing.T, ctx context.Context, name string, s storage.AllStorage) { + if _, _, err := s.RetrieveTokenPKI(ctx, name); !errors.Is(err, storage.ErrNotFound) { + t.Fatalf("unexpected error: %s", err) + } + + if _, err := s.RetrieveAuthTokens(ctx, name); !errors.Is(err, storage.ErrNotFound) { + t.Fatalf("unexpected error: %s", err) + } + + config, err := s.RetrieveConfig(ctx, name) + checkErr(t, err) + if config != nil { + t.Fatalf("expected non-existent config: %+v", config) + } + + // Profile assigner storing and retrieval. + profileUUID, modTime, err := s.RetrieveAssignerProfile(ctx, name) + checkErr(t, err) + if profileUUID != "" { + t.Fatal("expected empty profileUUID") + } + if !modTime.IsZero() { + t.Fatal("expected zero modTime") + } + cursor, err := s.RetrieveCursor(ctx, name) + checkErr(t, err) + if cursor != "" { + t.Fatal("expected empty cursor") + } +} + +func TestWitName(t *testing.T, ctx context.Context, name string, s storage.AllStorage) { + // PKI storing and retrieval. + if _, _, err := s.RetrieveTokenPKI(ctx, name); !errors.Is(err, storage.ErrNotFound) { + t.Fatalf("unexpected error: %s", err) + } + pemCert, pemKey := generatePKI(t, "basicdn", 1) + err := s.StoreTokenPKI(ctx, name, pemCert, pemKey) + checkErr(t, err) + pemCert2, pemKey2, err := s.RetrieveTokenPKI(ctx, name) + checkErr(t, err) + if !bytes.Equal(pemCert, pemCert2) { + t.Fatalf("pem cert mismatch: %s vs. %s", pemCert, pemCert2) + } + if !bytes.Equal(pemKey, pemKey2) { + t.Fatalf("pem key mismatch: %s vs. %s", pemKey, pemKey2) + } + + // Token storing and retrieval. + if _, err := s.RetrieveAuthTokens(ctx, name); !errors.Is(err, storage.ErrNotFound) { + t.Fatalf("unexpected error: %s", err) + } + tokens := &client.OAuth1Tokens{ + ConsumerKey: "CK_9af2f8218b150c351ad802c6f3d66abe", + ConsumerSecret: "CS_9af2f8218b150c351ad802c6f3d66abe", + AccessToken: "AT_9af2f8218b150c351ad802c6f3d66abe", + AccessSecret: "AS_9af2f8218b150c351ad802c6f3d66abe", + AccessTokenExpiry: time.Now().UTC(), + } + err = s.StoreAuthTokens(ctx, name, tokens) + checkErr(t, err) + tokens2, err := s.RetrieveAuthTokens(ctx, name) + checkErr(t, err) + checkTokens(t, tokens, tokens2) + tokens3 := &client.OAuth1Tokens{ + ConsumerKey: "foo_CK_9af2f8218b150c351ad802c6f3d66abe", + ConsumerSecret: "foo_CS_9af2f8218b150c351ad802c6f3d66abe", + AccessToken: "foo_AT_9af2f8218b150c351ad802c6f3d66abe", + AccessSecret: "foo_AS_9af2f8218b150c351ad802c6f3d66abe", + AccessTokenExpiry: time.Now().Add(5 * time.Second).UTC(), + } + err = s.StoreAuthTokens(ctx, name, tokens3) + checkErr(t, err) + tokens4, err := s.RetrieveAuthTokens(ctx, name) + checkErr(t, err) + checkTokens(t, tokens3, tokens4) + + // Config storing and retrieval. + config, err := s.RetrieveConfig(ctx, name) + checkErr(t, err) + if config != nil { + t.Fatalf("expected not-existing config: %+v", config) + } + config = &client.Config{ + BaseURL: "https://config.example.com", + } + err = s.StoreConfig(ctx, name, config) + checkErr(t, err) + config2, err := s.RetrieveConfig(ctx, name) + checkErr(t, err) + if *config != *config2 { + t.Fatalf("config mismatch: %+v vs. %+v", config, config2) + } + config2 = &client.Config{ + BaseURL: "https://config2.example.com", + } + err = s.StoreConfig(ctx, name, config2) + checkErr(t, err) + config3, err := s.RetrieveConfig(ctx, name) + checkErr(t, err) + if *config2 != *config3 { + t.Fatalf("config mismatch: %+v vs. %+v", config2, config3) + } + + // Profile assigner storing and retrieval. + profileUUID, modTime, err := s.RetrieveAssignerProfile(ctx, name) + checkErr(t, err) + if profileUUID != "" { + t.Fatal("expected empty profileUUID") + } + if !modTime.IsZero() { + t.Fatal("expected zero modTime") + } + profileUUID = "43277A13FBCA0CFC" + err = s.StoreAssignerProfile(ctx, name, profileUUID) + checkErr(t, err) + profileUUID2, modTime, err := s.RetrieveAssignerProfile(ctx, name) + checkErr(t, err) + if profileUUID != profileUUID2 { + t.Fatalf("profileUUID mismatch: %s vs. %s", profileUUID, profileUUID2) + } + now := time.Now() + if modTime.Before(now.Add(-1*time.Minute)) || modTime.After(now.Add(1*time.Minute)) { + t.Fatalf("mismatch modTime, expected: %s (+/- 1m), actual: %s", now, modTime) + } + time.Sleep(1 * time.Second) + profileUUID3 := "foo_43277A13FBCA0CFC" + err = s.StoreAssignerProfile(ctx, name, profileUUID3) + checkErr(t, err) + profileUUID4, modTime2, err := s.RetrieveAssignerProfile(ctx, name) + checkErr(t, err) + if profileUUID3 != profileUUID4 { + t.Fatalf("profileUUID mismatch: %s vs. %s", profileUUID, profileUUID3) + } + if modTime2 == modTime { + t.Fatalf("expected time update: %s", modTime2) + } + now = time.Now() + if modTime2.Before(now.Add(-1*time.Minute)) || modTime2.After(now.Add(1*time.Minute)) { + t.Fatalf("mismatch modTime, expected: %s (+/- 1m), actual: %s", now, modTime) + } + + cursor, err := s.RetrieveCursor(ctx, name) + checkErr(t, err) + if cursor != "" { + t.Fatal("expected empty cursor") + } + cursor = "MTY1NzI2ODE5Ny0x" + err = s.StoreCursor(ctx, name, cursor) + checkErr(t, err) + cursor2, err := s.RetrieveCursor(ctx, name) + checkErr(t, err) + if cursor != cursor2 { + t.Fatalf("cursor mismatch: %s vs. %s", cursor, cursor2) + } + cursor2 = "foo_MTY1NzI2ODE5Ny0x" + err = s.StoreCursor(ctx, name, cursor2) + checkErr(t, err) + cursor3, err := s.RetrieveCursor(ctx, name) + checkErr(t, err) + if cursor2 != cursor3 { + t.Fatalf("cursor mismatch: %s vs. %s", cursor2, cursor3) + } +} + func checkTokens(t *testing.T, t1 *client.OAuth1Tokens, t2 *client.OAuth1Tokens) { + if t1 == nil || t2 == nil { + t.Fatalf("check tokens nil") + return + } if t1.ConsumerKey != t2.ConsumerKey { t.Fatalf("tokens consumer_key mismatch: %s vs. %s", t1.ConsumerKey, t2.ConsumerKey) } @@ -236,3 +237,15 @@ func generatePKI(t *testing.T, cn string, days int64) (pemCert []byte, pemKey [] pemKey = tokenpki.PEMRSAPrivateKey(key) return pemCert, pemKey } + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func genRandName(length int) string { + result := make([]byte, length) + for i := 0; i < length; i++ { + result[i] = byte(rand.Intn(26) + 'a') + } + return "go_test_dep_name." + string(result) +}