diff --git a/pkg/auth/httpauth/resources.go b/pkg/auth/httpauth/resources.go index 2b8a0f36..a576a377 100644 --- a/pkg/auth/httpauth/resources.go +++ b/pkg/auth/httpauth/resources.go @@ -265,19 +265,22 @@ func (res *Resources) getAccess(w http.ResponseWriter, req *http.Request) { PublicProjectID string `json:"public_project_id,omitempty"` } - var publicProjectID uuid.UUID + var publicProjectID string if len(result.PublicProjectID) > 0 { - publicProjectID, err = uuid.FromBytes(result.PublicProjectID) + publicProjectUUID, err := uuid.FromBytes(result.PublicProjectID) if err != nil { res.writeError(w, "getAccess", err.Error(), http.StatusInternalServerError) return } + if !publicProjectUUID.IsZero() { + publicProjectID = publicProjectUUID.String() + } } response.AccessGrant = result.AccessGrant response.SecretKey = result.SecretKey.ToBase32() response.Public = result.Public - response.PublicProjectID = publicProjectID.String() + response.PublicProjectID = publicProjectID w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(response) diff --git a/pkg/auth/spannerauth/spannerauth.go b/pkg/auth/spannerauth/spannerauth.go index 3fa53941..95f0151a 100644 --- a/pkg/auth/spannerauth/spannerauth.go +++ b/pkg/auth/spannerauth/spannerauth.go @@ -4,6 +4,7 @@ package spannerauth import ( + "bytes" "context" "time" @@ -17,6 +18,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "storj.io/common/uuid" "storj.io/edge/pkg/auth/authdb" ) @@ -90,7 +92,6 @@ func (d *CloudDatabase) Put(ctx context.Context, keyHash authdb.KeyHash, record "encryption_key_hash": keyHash.Bytes(), // "created_at" has default value "public": record.Public, - "public_project_id": record.PublicProjectID, "satellite_address": record.SatelliteAddress, "macaroon_head": record.MacaroonHead, "encrypted_secret_key": record.EncryptedSecretKey, @@ -99,6 +100,10 @@ func (d *CloudDatabase) Put(ctx context.Context, keyHash authdb.KeyHash, record // "invalidated_at" } + if record.PublicProjectID != nil && !bytes.Equal(record.PublicProjectID, uuid.UUID{}.Bytes()) { + in["public_project_id"] = record.PublicProjectID + } + // we do not set any expiry unless it's non-zero. If an empty time.Time was // passed, then the value should be null in the database to ensure row // deletion policy doesn't inadvertently delete non-expiring records. diff --git a/pkg/auth/spannerauth/spannerauth_test.go b/pkg/auth/spannerauth/spannerauth_test.go index 540e5e9e..482a21b4 100644 --- a/pkg/auth/spannerauth/spannerauth_test.go +++ b/pkg/auth/spannerauth/spannerauth_test.go @@ -18,6 +18,7 @@ import ( "storj.io/common/storj" "storj.io/common/testcontext" "storj.io/common/testrand" + "storj.io/common/uuid" "storj.io/edge/pkg/auth/authdb" "storj.io/edge/pkg/auth/spannerauth" "storj.io/edge/pkg/auth/spannerauth/spannerauthtest" @@ -222,6 +223,48 @@ func TestContextCanceledHandling(t *testing.T) { require.ErrorIs(t, err, context.Canceled) } +func TestEmptyPublicProjectID(t *testing.T) { + ctx := testcontext.New(t) + defer ctx.Cleanup() + + logger := zaptest.NewLogger(t) + defer ctx.Check(logger.Sync) + + server, err := spannerauthtest.ConfigureTestServer(ctx, logger) + require.NoError(t, err) + defer server.Close() + + db, err := spannerauth.Open(ctx, logger, spannerauth.Config{ + DatabaseName: "projects/P/instances/I/databases/D", + Address: server.Addr, + }) + require.NoError(t, err) + defer ctx.Check(db.Close) + + require.NoError(t, db.HealthCheck(ctx)) + + testUUID := testrand.UUID() + + test := func(publicProjectID []byte, expected []byte) { + var k authdb.KeyHash + testrand.Read(k[:]) + + record := createRandomRecord(t, time.Time{}, true) + record.PublicProjectID = publicProjectID + + require.NoError(t, db.Put(ctx, k, record)) + + actual, err := db.Get(ctx, k) + require.NoError(t, err) + + require.Equal(t, expected, actual.PublicProjectID) + } + + test(nil, nil) + test(uuid.UUID{}.Bytes(), nil) + test(testUUID.Bytes(), testUUID.Bytes()) +} + func createRandomRecord(t *testing.T, expiresAt time.Time, forcePublic bool) *authdb.Record { var secretKey authdb.SecretKey _, err := rand.Read(secretKey[:]) diff --git a/testsuite/authservice/integration_test.go b/testsuite/authservice/integration_test.go index e0b8c65c..9f400bcd 100644 --- a/testsuite/authservice/integration_test.go +++ b/testsuite/authservice/integration_test.go @@ -12,8 +12,11 @@ import ( "storj.io/common/errs2" "storj.io/common/fpath" + "storj.io/common/grant" + "storj.io/common/macaroon" "storj.io/common/memory" "storj.io/common/testcontext" + "storj.io/common/testrand" "storj.io/edge/internal/register" "storj.io/edge/pkg/auth" "storj.io/edge/pkg/auth/spannerauth" @@ -44,14 +47,19 @@ func TestAuthservice(t *testing.T) { require.NoError(t, err) defer ctx.Check(db.Close) + nonExistentSatellite := testrand.NodeID().String() + "@sa" + authConfig := auth.Config{ - Endpoint: "http://localhost:1234", - AuthToken: []string{"super-secret"}, - POSTSizeLimit: 4 * memory.KiB, - AllowedSatellites: []string{planet.Satellites[0].NodeURL().String()}, - KVBackend: "spanner://", - ListenAddr: ":0", - DRPCListenAddr: ":0", + Endpoint: "http://localhost:1234", + AuthToken: []string{"super-secret"}, + POSTSizeLimit: 4 * memory.KiB, + AllowedSatellites: []string{ + planet.Satellites[0].NodeURL().String(), + nonExistentSatellite, + }, + KVBackend: "spanner://", + ListenAddr: ":0", + DRPCListenAddr: ":0", Spanner: spannerauth.Config{ DatabaseName: "projects/P/instances/I/databases/D", Address: server.Addr, @@ -79,18 +87,41 @@ func TestAuthservice(t *testing.T) { serialized, err := planet.Uplinks[0].Access[planet.Satellites[0].ID()].Serialize() require.NoError(t, err) - runTest := func(addr string) { + runTest := func(addr, serialized string, test func(resp authclient.AuthServiceResponse)) { creds, err := register.Access(ctx, addr, serialized, false) require.NoError(t, err) resp, err := authClient.Resolve(ctx, creds.AccessKeyID, "") require.NoError(t, err) + test(resp) + } + + test := func(resp authclient.AuthServiceResponse) { require.Equal(t, serialized, resp.AccessGrant) require.Equal(t, planet.Uplinks[0].Projects[0].PublicID.String(), resp.PublicProjectID) } - runTest("http://" + auth.Address()) - runTest("drpc://" + auth.DRPCAddress()) + runTest("http://"+auth.Address(), serialized, test) + runTest("drpc://"+auth.DRPCAddress(), serialized, test) + + apiKey, err := macaroon.NewAPIKey([]byte("secret")) + require.NoError(t, err) + + ag := grant.Access{ + SatelliteAddress: nonExistentSatellite, + APIKey: apiKey, + EncAccess: grant.NewEncryptionAccess(), + } + nonExistentSatelliteAccess, err := ag.Serialize() + require.NoError(t, err) + + nonExistentSatelliteTest := func(resp authclient.AuthServiceResponse) { + require.Equal(t, nonExistentSatelliteAccess, resp.AccessGrant) + require.Equal(t, "", resp.PublicProjectID) + } + + runTest("http://"+auth.Address(), nonExistentSatelliteAccess, nonExistentSatelliteTest) + runTest("drpc://"+auth.DRPCAddress(), nonExistentSatelliteAccess, nonExistentSatelliteTest) }) }