Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change the tf/task jwt token management #27

Merged
merged 4 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var (
useServiceHost bool
serviceName string
dashboard string
fswatchImage string
)

func main() {
Expand Down Expand Up @@ -60,6 +61,8 @@ func main() {
viper.BindPFlag("service-name", pflag.Lookup("service-name"))
pflag.StringVar(&dashboard, "dashboard", "", "Connect to the dashboard with api credentials")
viper.BindPFlag("dashboard", pflag.Lookup("dashboard"))
pflag.StringVar(&fswatchImage, "fswatch-image", "ghcr.io/galleybytes/fswatch:0.11.0", "Docker image for fswatch (log-service)")
viper.BindPFlag("fswatch-image", pflag.Lookup("fswatch-image"))
pflag.Parse()

pflag.Set("alsologtostderr", "false")
Expand All @@ -80,6 +83,7 @@ func main() {
useServiceHost = viper.GetBool("use-service-host")
serviceName = viper.GetString("service-name")
dashboard = viper.GetString("dashboard")
fswatchImage = viper.GetString("fswatch-image")

clientset := kubernetes.NewForConfigOrDie(NewConfigOrDie(os.Getenv("KUBECONFIG")))
var database *gorm.DB
Expand Down Expand Up @@ -109,7 +113,7 @@ func main() {
}
}

apiHandler := api.NewAPIHandler(database, clientset, ssoConfig, &serviceIP, &dashboard)
apiHandler := api.NewAPIHandler(database, clientset, ssoConfig, &serviceIP, &dashboard, fswatchImage)
apiHandler.RegisterRoutes()
fmt.Printf("Starting server on %s\n", addr)
apiHandler.Server.Run(addr)
Expand Down
35 changes: 19 additions & 16 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@ import (
)

type APIHandler struct {
Server *gin.Engine
DB *gorm.DB
clientset kubernetes.Interface
ssoConfig *SSOConfig
serviceIP *string
tenant string
Cache *cache.Cache
dashboard *string
Server *gin.Engine
DB *gorm.DB
clientset kubernetes.Interface
ssoConfig *SSOConfig
serviceIP *string
tenant string
Cache *cache.Cache
dashboard *string
fswatchImage string
}

type SSOConfig struct {
Expand Down Expand Up @@ -55,16 +56,17 @@ func NewSAMLConfig(issuer, recipient, metadataURL string) (*SSOConfig, error) {
}, nil
}

func NewAPIHandler(db *gorm.DB, clientset kubernetes.Interface, ssoConfig *SSOConfig, serviceIP, dashboard *string) *APIHandler {
func NewAPIHandler(db *gorm.DB, clientset kubernetes.Interface, ssoConfig *SSOConfig, serviceIP, dashboard *string, fswatchImage string) *APIHandler {

return &APIHandler{
Server: gin.Default(),
DB: db,
clientset: clientset,
ssoConfig: ssoConfig,
serviceIP: serviceIP,
Cache: cache.New(20 * time.Second),
dashboard: dashboard,
Server: gin.Default(),
DB: db,
clientset: clientset,
ssoConfig: ssoConfig,
serviceIP: serviceIP,
Cache: cache.New(20 * time.Second),
dashboard: dashboard,
fswatchImage: fswatchImage,
}
}

Expand All @@ -85,6 +87,7 @@ func (h APIHandler) RegisterRoutes() {
})

preauth.POST("/login", h.login)
preauth.POST("/refresh", h.loginWithRefreshToken)
preauth.GET("/connect", h.defaultConnectMethod) // Determine preferred auth method
preauth.GET("/sso", h.ssoRedirecter)
preauth.POST("/sso/saml", h.samlConnecter)
Expand Down
58 changes: 44 additions & 14 deletions pkg/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,29 @@ func validateJwt(c *gin.Context) {
unauthorized(c, err.Error())
}

token, err := jwt.Parse(userProvidedJWT, func(token *jwt.Token) (interface{}, error) {
_, err = doValidation(userProvidedJWT)
if err != nil {
unauthorized(c, err.Error())
return
}
}

func doValidation(jwtToken string) (*jwt.Token, error) {
token, err := jwt.Parse(jwtToken, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("there was an error in parsing")
}
return []byte(jwtSigningKey), nil
})

if err != nil {
unauthorized(c, err.Error())
return
return nil, err
}

if token == nil {
unauthorized(c, "invalid token")
return
return nil, fmt.Errorf("invalid token")
}

return token, nil
}

func (h APIHandler) login(c *gin.Context) {
Expand All @@ -94,7 +101,26 @@ func (h APIHandler) login(c *gin.Context) {
return
}

token, err := generateJWT(jsonData.Username)
token, err := generateJWT(jsonData.Username, 12)
if err != nil {
unauthorized(c, fmt.Sprintf("Error issuing JWT: %s", err.Error()))
return
}

c.JSON(http.StatusOK, response(http.StatusOK, "", []string{token}))
}

func (h APIHandler) loginWithRefreshToken(c *gin.Context) {
jsonData := struct {
RefreshToken string `json:"refresh_token"`
}{}
err := c.BindJSON(&jsonData)
if err != nil {
unauthorized(c, fmt.Sprintf("Error paring data: %s", err.Error()))
return
}

token, err := NewTaskTokenFromRefreshToken(h.DB, jsonData.RefreshToken, GetApiURL(c, h.serviceIP), h.clientset)
if err != nil {
unauthorized(c, fmt.Sprintf("Error issuing JWT: %s", err.Error()))
return
Expand All @@ -103,13 +129,12 @@ func (h APIHandler) login(c *gin.Context) {
c.JSON(http.StatusOK, response(http.StatusOK, "", []string{token}))
}

func generateJWT(username string) (string, error) {
func generateJWT(username string, durationHours time.Duration) (string, error) {
token := jwt.New(jwt.SigningMethodHS256)
claims := token.Claims.(jwt.MapClaims)

claims["authorized"] = true
claims["username"] = username
claims["exp"] = time.Now().Add(time.Hour * 12).Unix()
claims["exp"] = time.Now().Add(time.Hour * durationHours).Unix()

tokenString, err := token.SignedString([]byte(jwtSigningKey))

Expand Down Expand Up @@ -174,7 +199,7 @@ func (h APIHandler) samlConnecter(c *gin.Context) {
return
}

jwtToken, err := generateJWT(username)
jwtToken, err := generateJWT(username, 12)
if err != nil {
c.AbortWithError(http.StatusNotAcceptable, err)
return
Expand Down Expand Up @@ -224,10 +249,15 @@ func fetchIDPCertificate(metadataURL string) (*x509.Certificate, error) {
// can be removed for a better method soon.
//
// Grant 30 days of access per issued token.
func generateTaskJWT(resourceUUID, tenant, clientName, generation string) (string, error) {
func generateTaskJWT(resourceUUID, tenant, clientName, generation string) (string, string, error) {
token := jwt.New(jwt.SigningMethodHS256)
claims := token.Claims.(jwt.MapClaims)

refreshToken, err := generateJWT(resourceUUID, 17520) // 2 years
if err != nil {
return "", "", fmt.Errorf("Could not generate refresh token: %s", err.Error())
}
claims["refresh_token"] = refreshToken
claims["exp"] = time.Now().Add(time.Hour * 720).Unix()
claims["authorized"] = true
claims["resourceUUID"] = resourceUUID
Expand All @@ -236,9 +266,9 @@ func generateTaskJWT(resourceUUID, tenant, clientName, generation string) (strin
tokenString, err := token.SignedString([]byte(jwtSigningKey))

if err != nil {
return "", fmt.Errorf("something went wrong: %s", err.Error())
return "", "", fmt.Errorf("something went wrong: %s", err.Error())
}
return tokenString, nil
return tokenString, refreshToken, nil
}

// Check that the taskJWT is correctly formatted with all the claim fields defined
Expand Down
Loading
Loading