Skip to content

Commit

Permalink
Merge pull request #2620 from actiontech/issue-2602-v2
Browse files Browse the repository at this point in the history
fix: expired token causes scannerd to be unable to access SQLite normally
  • Loading branch information
ColdWaterLW authored Sep 20, 2024
2 parents 3f884f0 + 6f65a3b commit e953fc1
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 16 deletions.
8 changes: 4 additions & 4 deletions sqle/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ func StartApi(net *gracenet.Net, exitChan chan struct{}, config config.SqleConfi
e.GET("/v1/oauth2/link", v1.Oauth2Link)
e.GET("/v1/oauth2/callback", v1.Oauth2Callback)
e.POST("/v1/oauth2/user/bind", v1.BindOauth2User)
e.POST("/v1/projects/:project_name/audit_plans/:audit_plan_name/sqls/full", v1.FullSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier())
e.POST("/v2/projects/:project_name/audit_plans/:audit_plan_name/sqls/full", v2.FullSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier())
e.POST("/v1/projects/:project_name/audit_plans/:audit_plan_name/sqls/partial", v1.PartialSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier())
e.POST("/v2/projects/:project_name/audit_plans/:audit_plan_name/sqls/partial", v2.PartialSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier())

v1Router := e.Group(apiV1)
v1Router.Use(sqleMiddleware.JWTTokenAdapter(), sqleMiddleware.JWTWithConfig(utils.JWTSecretKey), sqleMiddleware.VerifyUserIsDisabled(), sqleMiddleware.LicenseAdapter(), sqleMiddleware.OperationLogRecord())
Expand Down Expand Up @@ -392,10 +396,6 @@ func StartApi(net *gracenet.Net, exitChan chan struct{}, config config.SqleConfi
v1Router.GET("/projects/:project_name/audit_plans/:audit_plan_name/reports/:audit_plan_report_id/", v1.GetAuditPlanReport)

v1Router.GET("/projects/:project_name/audit_plans/:audit_plan_name/sqls", v1.GetAuditPlanSQLs)
v1Router.POST("/projects/:project_name/audit_plans/:audit_plan_name/sqls/full", v1.FullSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier())
v2Router.POST("/projects/:project_name/audit_plans/:audit_plan_name/sqls/full", v2.FullSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier())
v1Router.POST("/projects/:project_name/audit_plans/:audit_plan_name/sqls/partial", v1.PartialSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier())
v2Router.POST("/projects/:project_name/audit_plans/:audit_plan_name/sqls/partial", v2.PartialSyncAuditPlanSQLs, sqleMiddleware.ScannerVerifier())
v1Router.POST("/projects/:project_name/audit_plans/:audit_plan_name/trigger", v1.TriggerAuditPlan)
v1Router.PATCH("/projects/:project_name/audit_plans/:audit_plan_name/notify_config", v1.UpdateAuditPlanNotifyConfig)
v1Router.GET("/projects/:project_name/audit_plans/:audit_plan_name/notify_config", v1.GetAuditPlanNotifyConfig)
Expand Down
15 changes: 14 additions & 1 deletion sqle/api/middleware/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,24 @@ func ScannerVerifier() echo.MiddlewareFunc {
token = parts[1]
}

apnInToken, err := utils.ParseAuditPlanName(token)
apnInToken, userName, err := utils.ParseAuditPlanToken(token)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
projectName := c.Param("project_name")
apnInParam := c.Param("audit_plan_name")
// verify user
user, isExist, err := model.GetStorage().GetUserByName(userName)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
if !isExist {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("user is not exist"))
}
if user.IsDisabled() {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("current user is disabled"))
}
// verify audit plan
// 由于对生成的JWT Token的负载使用MD5算法进行预处理,因此在验证的时候也需要对param中的apn使用MD5处理
// 为了兼容老版本的JWT Token需要增加不经MD5处理的apnInParam和apnInToken的判断
if apnInToken != apnInParam && apnInToken != utils.Md5(apnInParam) {
Expand All @@ -74,6 +86,7 @@ func ScannerVerifier() echo.MiddlewareFunc {
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
// verify token in audit plan
if !apnExist || apn.Token != token {
return echo.NewHTTPError(http.StatusInternalServerError, errAuditPlanMisMatch.Error())
}
Expand Down
38 changes: 32 additions & 6 deletions sqle/api/middleware/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,28 +42,39 @@ func TestScannerVerifier(t *testing.T) {
}

{ // test audit plan name don't match the token
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
assert.NoError(t, err)
model.InitMockStorage(mockDB)
mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(testUser).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(testUser)))
token, err := jwt.CreateToken(testUser, time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(apName))
assert.NoError(t, err)
ctx, _ := newContextFunc(token, fmt.Sprintf("%s_modified", apName))
err = mw(h)(ctx)
mockDB.Close()
assert.Contains(t, err.Error(), errAuditPlanMisMatch.Error())
}

{ // test unknown token
mockDB, _, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
assert.NoError(t, err)
model.InitMockStorage(mockDB)
token, err := jwt.CreateToken(testUser, time.Now().Add(1*time.Hour).Unix())
assert.NoError(t, err)
ctx, _ := newContextFunc(token, apName)
err = mw(h)(ctx)
assert.Contains(t, err.Error(), "unknown token")
mockDB.Close()
}

{ // test audit plan token incorrect
token, err := jwt.CreateToken(testUser, time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(apName))
assert.NoError(t, err)

mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
assert.NoError(t, err)
model.InitMockStorage(mockDB)
mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(testUser).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(testUser)))

token, err := jwt.CreateToken(testUser, time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(apName))
assert.NoError(t, err)

mock.ExpectQuery("SELECT `audit_plans`.* FROM `audit_plans` LEFT JOIN projects ON projects.id = audit_plans.project_id WHERE `audit_plans`.`deleted_at` IS NULL AND ((projects.name = ? AND audit_plans.name = ?))").
WithArgs(projectName, apName).
WillReturnRows(sqlmock.NewRows([]string{"name", "token"}).AddRow(driver.Value(testUser), "test-token"))
Expand All @@ -85,6 +96,7 @@ func TestScannerVerifier(t *testing.T) {
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
assert.NoError(t, err)
model.InitMockStorage(mockDB)
mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(testUser).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(testUser)))
mock.ExpectQuery("SELECT `audit_plans`.* FROM `audit_plans` LEFT JOIN projects ON projects.id = audit_plans.project_id WHERE `audit_plans`.`deleted_at` IS NULL AND ((projects.name = ? AND audit_plans.name = ?))").
WithArgs(projectName, apName).
WillReturnError(gorm.ErrRecordNotFound)
Expand All @@ -108,6 +120,7 @@ func TestScannerVerifier(t *testing.T) {
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
assert.NoError(t, err)
model.InitMockStorage(mockDB)
mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(testUser).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(testUser)))
mock.ExpectQuery("SELECT `audit_plans`.* FROM `audit_plans` LEFT JOIN projects ON projects.id = audit_plans.project_id WHERE `audit_plans`.`deleted_at` IS NULL AND ((projects.name = ? AND audit_plans.name = ?))").
WithArgs(projectName, apName).
WillReturnRows(sqlmock.NewRows([]string{"name", "token"}).AddRow(testUser, token))
Expand All @@ -130,6 +143,7 @@ func TestScannerVerifier(t *testing.T) {
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
assert.NoError(t, err)
model.InitMockStorage(mockDB)
mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(testUser).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(testUser)))
mock.ExpectQuery("SELECT `audit_plans`.* FROM `audit_plans` LEFT JOIN projects ON projects.id = audit_plans.project_id WHERE `audit_plans`.`deleted_at` IS NULL AND ((projects.name = ? AND audit_plans.name = ?))").
WithArgs(projectName, apName).
WillReturnRows(sqlmock.NewRows([]string{"name", "token"}).AddRow(testUser, token))
Expand Down Expand Up @@ -170,12 +184,13 @@ func TestScannerVerifierIssue1758(t *testing.T) {
return ctx, res
}
{ // test check success
token, err := jwt.CreateToken(utils.Md5(userName), time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(utils.Md5(apName120)))
token, err := jwt.CreateToken(userName, time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(utils.Md5(apName120)))
assert.NoError(t, err)

mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
assert.NoError(t, err)
model.InitMockStorage(mockDB)
mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(userName).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(userName)))
mock.ExpectQuery("SELECT `audit_plans`.* FROM `audit_plans` LEFT JOIN projects ON projects.id = audit_plans.project_id WHERE `audit_plans`.`deleted_at` IS NULL AND ((projects.name = ? AND audit_plans.name = ?))").
WithArgs(projectName, apName120).
WillReturnRows(sqlmock.NewRows([]string{"name", "token"}).AddRow(userName, token))
Expand All @@ -191,25 +206,36 @@ func TestScannerVerifierIssue1758(t *testing.T) {
assert.NoError(t, err)
}
{ // test audit plan name don't match the token
token, err := jwt.CreateToken(utils.Md5(userName), time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(utils.Md5(apName120)))
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
assert.NoError(t, err)
model.InitMockStorage(mockDB)
mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(userName).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(userName)))
token, err := jwt.CreateToken(userName, time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(utils.Md5(apName120)))
assert.NoError(t, err)
ctx, _ := newContextFunc(token, fmt.Sprintf("%s_modified", apName120))
err = mw(h)(ctx)
assert.Contains(t, err.Error(), errAuditPlanMisMatch.Error())
mockDB.Close()
}
{ // test unknown token
token, err := jwt.CreateToken(utils.Md5(userName), time.Now().Add(1*time.Hour).Unix())
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
assert.NoError(t, err)
model.InitMockStorage(mockDB)
mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(userName).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(userName)))
token, err := jwt.CreateToken(userName, time.Now().Add(1*time.Hour).Unix())
assert.NoError(t, err)
ctx, _ := newContextFunc(token, apName120)
err = mw(h)(ctx)
assert.Contains(t, err.Error(), "unknown token")
mockDB.Close()
}
{ // test old token
token, err := jwt.CreateToken(userName, time.Now().Add(1*time.Hour).Unix(), utils.WithAuditPlanName(apName120))
assert.NoError(t, err)
mockDB, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
assert.NoError(t, err)
model.InitMockStorage(mockDB)
mock.ExpectQuery("SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL AND ((login_name = ?)) ORDER BY `users`.`id` ASC LIMIT 1").WithArgs(userName).WillReturnRows(sqlmock.NewRows([]string{"name"}).AddRow(driver.Value(userName)))
mock.ExpectQuery("SELECT `audit_plans`.* FROM `audit_plans` LEFT JOIN projects ON projects.id = audit_plans.project_id WHERE `audit_plans`.`deleted_at` IS NULL AND ((projects.name = ? AND audit_plans.name = ?))").
WithArgs(projectName, apName120).
WillReturnRows(sqlmock.NewRows([]string{"name", "token"}).AddRow(userName, token))
Expand Down
18 changes: 13 additions & 5 deletions sqle/utils/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,31 @@ func WithAuditPlanName(name string) CustomClaimOption {
})
}

// ParseAuditPlanName used by echo middleware which only verify api request to audit plan related.
func ParseAuditPlanName(tokenString string) (string, error) {
// ParseAuditPlanToken used by echo middleware which only verify api request to audit plan related.
func ParseAuditPlanToken(tokenString string) (string, string, error) {
keyFunc := func(t *jwt.Token) (interface{}, error) {
return JWTSecretKey, nil
}
token, err := jwt.Parse(tokenString, keyFunc)
if err != nil {
return "", err
if e, ok := err.(*jwt.ValidationError); ok {
if e.Errors != jwt.ValidationErrorExpired {
return "", "", err
}
}
}
// claims can only be jwt.MapClaims
//nolint:forcetypeassert
claims := token.Claims.(jwt.MapClaims)
apn, ok := claims["apn"]
if !ok {
return "", jwt.NewValidationError("unknown token", jwt.ValidationErrorClaimsInvalid)
return "", "", jwt.NewValidationError("unknown token", jwt.ValidationErrorClaimsInvalid)
}
userName, ok := claims["name"]
if !ok {
return "", "", jwt.NewValidationError("unknown token", jwt.ValidationErrorClaimsInvalid)
}
return apn.(string), nil
return apn.(string), userName.(string), nil
}

func GetUserNameFromJWTToken(token string) (string, error) {
Expand Down

0 comments on commit e953fc1

Please sign in to comment.