Skip to content

Commit

Permalink
feat: implement create&delete user access token api
Browse files Browse the repository at this point in the history
  • Loading branch information
boojack committed Aug 6, 2023
1 parent ad98857 commit a902792
Show file tree
Hide file tree
Showing 10 changed files with 1,060 additions and 240 deletions.
43 changes: 43 additions & 0 deletions api/auth/auth.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package auth

import (
"fmt"
"time"

"github.com/golang-jwt/jwt/v4"
)

const (
Expand All @@ -20,3 +23,43 @@ const (
// AccessTokenCookieName is the cookie name of access token.
AccessTokenCookieName = "slash.access-token"
)

type ClaimsMessage struct {
Name string `json:"name"`
jwt.RegisteredClaims
}

// GenerateAccessToken generates an access token.
// username is the email of the user.
func GenerateAccessToken(username string, userID int32, secret string) (string, error) {
expirationTime := time.Now().Add(AccessTokenDuration)
return generateToken(username, userID, expirationTime, []byte(secret))
}

// generateToken generates a jwt token.
func generateToken(username string, userID int32, expirationTime time.Time, secret []byte) (string, error) {
// Create the JWT claims, which includes the username and expiry time.
claims := &ClaimsMessage{
Name: username,
RegisteredClaims: jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{AccessTokenAudienceName},
// In JWT, the expiry time is expressed as unix milliseconds.
ExpiresAt: jwt.NewNumericDate(expirationTime),
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: Issuer,
Subject: fmt.Sprint(userID),
},
}

// Declare the token with the HS256 algorithm used for signing, and the claims.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token.Header["kid"] = KeyID

// Create the JWT string.
tokenString, err := token.SignedString(secret)
if err != nil {
return "", err
}

return tokenString, nil
}
10 changes: 2 additions & 8 deletions api/v1/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) {
return echo.NewHTTPError(http.StatusUnauthorized, "unmatched email and password")
}

accessToken, err := GenerateAccessToken(user.Email, user.ID, secret)
accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, secret)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err)
}
Expand Down Expand Up @@ -107,7 +107,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to create user, err: %s", err)).SetInternal(err)
}

accessToken, err := GenerateAccessToken(user.Email, user.ID, secret)
accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, secret)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err)
}
Expand Down Expand Up @@ -151,12 +151,6 @@ func (s *APIV1Service) UpsertAccessTokenToStore(ctx context.Context, user *store
return nil
}

// GenerateAccessToken generates an access token for web.
func GenerateAccessToken(username string, userID int32, secret string) (string, error) {
expirationTime := time.Now().Add(auth.AccessTokenDuration)
return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret))
}

// RemoveTokensAndCookies removes the jwt token from the cookies.
func RemoveTokensAndCookies(c echo.Context) {
cookieExp := time.Now().Add(-1 * time.Hour)
Expand Down
36 changes: 1 addition & 35 deletions api/v1/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"net/http"
"strings"
"time"

"github.com/boojack/slash/api/auth"
"github.com/boojack/slash/internal/util"
Expand All @@ -21,39 +20,6 @@ const (
UserIDContextKey = "user-id"
)

type claimsMessage struct {
Name string `json:"name"`
jwt.RegisteredClaims
}

// generateToken generates a jwt token.
func generateToken(username string, userID int32, aud string, expirationTime time.Time, secret []byte) (string, error) {
// Create the JWT claims, which includes the username and expiry time.
claims := &claimsMessage{
Name: username,
RegisteredClaims: jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{aud},
// In JWT, the expiry time is expressed as unix milliseconds.
ExpiresAt: jwt.NewNumericDate(expirationTime),
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: auth.Issuer,
Subject: fmt.Sprint(userID),
},
}

// Declare the token with the HS256 algorithm used for signing, and the claims.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token.Header["kid"] = auth.KeyID

// Create the JWT string.
tokenString, err := token.SignedString(secret)
if err != nil {
return "", err
}

return tokenString, nil
}

func extractTokenFromHeader(c echo.Context) (string, error) {
authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" {
Expand Down Expand Up @@ -111,7 +77,7 @@ func JWTMiddleware(s *APIV1Service, next echo.HandlerFunc, secret string) echo.H
return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token")
}

claims := &claimsMessage{}
claims := &auth.ClaimsMessage{}
_, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
Expand Down
7 changes: 1 addition & 6 deletions api/v2/acl.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,6 @@ const (
UserIDContextKey ContextKey = iota
)

type claimsMessage struct {
Name string `json:"name"`
jwt.RegisteredClaims
}

// GRPCAuthInterceptor is the auth interceptor for gRPC server.
type GRPCAuthInterceptor struct {
Store *store.Store
Expand Down Expand Up @@ -93,7 +88,7 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr
if accessTokenStr == "" {
return 0, status.Errorf(codes.Unauthenticated, "access token not found")
}
claims := &claimsMessage{}
claims := &auth.ClaimsMessage{}
_, err := jwt.ParseWithClaims(accessTokenStr, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
Expand Down
125 changes: 118 additions & 7 deletions api/v2/user_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package v2
import (
"context"

"github.com/boojack/slash/api/auth"
apiv2pb "github.com/boojack/slash/proto/gen/api/v2"
storepb "github.com/boojack/slash/proto/gen/store"
"github.com/boojack/slash/store"
"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
Expand Down Expand Up @@ -45,7 +47,7 @@ func (s *UserService) GetUser(ctx context.Context, request *apiv2pb.GetUserReque
return response, nil
}

func (s *UserService) GetUserAccessTokens(ctx context.Context, request *apiv2pb.GetUserAccessTokensRequest) (*apiv2pb.GetUserAccessTokensResponse, error) {
func (s *UserService) ListUserAccessTokens(ctx context.Context, request *apiv2pb.ListUserAccessTokensRequest) (*apiv2pb.ListUserAccessTokensResponse, error) {
userID := ctx.Value(UserIDContextKey).(int32)
if userID != request.Id {
return nil, status.Errorf(codes.PermissionDenied, "Permission denied")
Expand All @@ -56,9 +58,9 @@ func (s *UserService) GetUserAccessTokens(ctx context.Context, request *apiv2pb.
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
}

accessTokens := []*apiv2pb.GetUserAccessTokensResponse_AccessToken{}
accessTokens := []*apiv2pb.UserAccessToken{}
for _, userAccessToken := range userAccessTokens {
claims := &claimsMessage{}
claims := &auth.ClaimsMessage{}
_, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
Expand All @@ -75,20 +77,129 @@ func (s *UserService) GetUserAccessTokens(ctx context.Context, request *apiv2pb.
continue
}

accessTokens = append(accessTokens, &apiv2pb.GetUserAccessTokensResponse_AccessToken{
accessTokens = append(accessTokens, &apiv2pb.UserAccessToken{
AccessToken: userAccessToken.AccessToken,
Description: userAccessToken.Description,
ExpiresTime: timestamppb.New(claims.ExpiresAt.Time),
CreatedTime: timestamppb.New(claims.IssuedAt.Time),
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
ExpiresAt: timestamppb.New(claims.ExpiresAt.Time),
})
}

response := &apiv2pb.GetUserAccessTokensResponse{
response := &apiv2pb.ListUserAccessTokensResponse{
AccessTokens: accessTokens,
}
return response, nil
}

func (s *UserService) CreateUserAccessToken(ctx context.Context, request *apiv2pb.CreateUserAccessTokenRequest) (*apiv2pb.CreateUserAccessTokenResponse, error) {
userID := ctx.Value(UserIDContextKey).(int32)
if userID != request.Id {
return nil, status.Errorf(codes.PermissionDenied, "Permission denied")
}

user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}

accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, s.Secret)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
}

claims := &auth.ClaimsMessage{}
_, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(s.Secret), nil
}
}
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to parse access token: %v", err)
}

// Upsert the access token to user setting store.
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert access token to store: %v", err)
}

response := &apiv2pb.CreateUserAccessTokenResponse{
AccessToken: &apiv2pb.UserAccessToken{
AccessToken: accessToken,
Description: request.Description,
IssuedAt: timestamppb.New(claims.IssuedAt.Time),
ExpiresAt: timestamppb.New(claims.ExpiresAt.Time),
},
}
return response, nil
}

func (s *UserService) DeleteUserAccessToken(ctx context.Context, request *apiv2pb.DeleteUserAccessTokenRequest) (*apiv2pb.DeleteUserAccessTokenResponse, error) {
userID := ctx.Value(UserIDContextKey).(int32)
if userID != request.Id {
return nil, status.Errorf(codes.PermissionDenied, "Permission denied")
}

userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, userID)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list access tokens: %v", err)
}
updatedUserAccessTokens := []*storepb.AccessTokensUserSetting_AccessToken{}
for _, userAccessToken := range userAccessTokens {
if userAccessToken.AccessToken == request.AccessToken {
continue
}
updatedUserAccessTokens = append(updatedUserAccessTokens, userAccessToken)
}
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: userID,
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
Value: &storepb.UserSetting_AccessTokensUserSetting{
AccessTokensUserSetting: &storepb.AccessTokensUserSetting{
AccessTokens: updatedUserAccessTokens,
},
},
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert user setting: %v", err)
}

return &apiv2pb.DeleteUserAccessTokenResponse{}, nil
}

func (s *UserService) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken string) error {
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return errors.Wrap(err, "failed to get user access tokens")
}
userAccessToken := storepb.AccessTokensUserSetting_AccessToken{
AccessToken: accessToken,
Description: "user sign in",
}
userAccessTokens = append(userAccessTokens, &userAccessToken)
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
Value: &storepb.UserSetting_AccessTokensUserSetting{
AccessTokensUserSetting: &storepb.AccessTokensUserSetting{
AccessTokens: userAccessTokens,
},
},
}); err != nil {
return errors.Wrap(err, "failed to upsert user setting")
}
return nil
}

func convertUserFromStore(user *store.User) *apiv2pb.User {
return &apiv2pb.User{
Id: int32(user.ID),
Expand Down
53 changes: 42 additions & 11 deletions proto/api/v2/user_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,21 @@ service UserService {
option (google.api.http) = {get: "/api/v2/users/{id}"};
option (google.api.method_signature) = "id";
}
// GetUserAccessTokens returns a list of access tokens for a user.
rpc GetUserAccessTokens(GetUserAccessTokensRequest) returns (GetUserAccessTokensResponse) {
// ListUserAccessTokens returns a list of access tokens for a user.
rpc ListUserAccessTokens(ListUserAccessTokensRequest) returns (ListUserAccessTokensResponse) {
option (google.api.http) = {get: "/api/v2/users/{id}/access_tokens"};
option (google.api.method_signature) = "id";
}
// CreateUserAccessToken creates a new access token for a user.
rpc CreateUserAccessToken(CreateUserAccessTokenRequest) returns (CreateUserAccessTokenResponse) {
option (google.api.http) = {post: "/api/v2/users/{id}/access_tokens"};
option (google.api.method_signature) = "id";
}
// DeleteUserAccessToken deletes an access token for a user.
rpc DeleteUserAccessToken(DeleteUserAccessTokenRequest) returns (DeleteUserAccessTokenResponse) {
option (google.api.http) = {delete: "/api/v2/users/{id}/access_tokens/{access_token}"};
option (google.api.method_signature) = "id,access_token";
}
}

message User {
Expand Down Expand Up @@ -54,16 +64,37 @@ message GetUserResponse {
User user = 1;
}

message GetUserAccessTokensRequest {
message ListUserAccessTokensRequest {
// id is the user id.
int32 id = 1;
}

message GetUserAccessTokensResponse {
message AccessToken {
string access_token = 1;
string description = 2;
google.protobuf.Timestamp created_time = 3;
google.protobuf.Timestamp expires_time = 4;
}
repeated AccessToken access_tokens = 1;
message ListUserAccessTokensResponse {
repeated UserAccessToken access_tokens = 1;
}

message CreateUserAccessTokenRequest {
// id is the user id.
int32 id = 1;
string description = 2;
}

message CreateUserAccessTokenResponse {
UserAccessToken access_token = 1;
}

message DeleteUserAccessTokenRequest {
// id is the user id.
int32 id = 1;
// access_token is the access token to delete.
string access_token = 2;
}

message DeleteUserAccessTokenResponse {}

message UserAccessToken {
string access_token = 1;
string description = 2;
google.protobuf.Timestamp issued_at = 3;
google.protobuf.Timestamp expires_at = 4;
}
Loading

0 comments on commit a902792

Please sign in to comment.