From 9da9d5d4160e5a68689802f2959e16476a987394 Mon Sep 17 00:00:00 2001 From: Harshit Raj <75195728+1-Harshit@users.noreply.github.com> Date: Sun, 5 Dec 2021 20:17:00 +0530 Subject: [PATCH] Add: Refresh Token, Reset password, Async Emailing (#40) * Refactor Access Token * Add Refresh Token * Add reset password api (#5) * Make Email Async (#7) * Add: Return Name with history (#9) * Rename Files * rename again --- .gitignore | 3 +- account/refresh_token.go | 34 +++++++ account/user.go | 18 ++++ account/wallet.go | 103 ++++++++++---------- auth/access_token.go | 37 ++++++++ auth/jwt.go | 105 +++++++++------------ auth/login.go | 2 +- auth/otp.go | 7 +- auth/refresh_token.go | 77 +++++++++++++++ auth/reset_password.go | 41 ++++++++ cmd/iitk-coin/main.go | 7 ++ config.yml | 8 +- database/tables.go | 12 ++- errors/handler.go | 41 ++++---- handlers/{checklogin.go => check_login.go} | 0 handlers/login.go | 68 +++++++++---- handlers/refresh.go | 52 ++++++++++ handlers/reset_password.go | 36 +++++++ mail/config.go | 45 +++++++++ mail/send.go | 40 -------- mail/service.go | 34 +++++++ mail/test.go | 24 +++++ server/router.go | 3 + 23 files changed, 606 insertions(+), 191 deletions(-) create mode 100644 account/refresh_token.go create mode 100644 auth/access_token.go create mode 100644 auth/refresh_token.go create mode 100644 auth/reset_password.go rename handlers/{checklogin.go => check_login.go} (100%) create mode 100644 handlers/refresh.go create mode 100644 handlers/reset_password.go create mode 100644 mail/config.go delete mode 100644 mail/send.go create mode 100644 mail/service.go create mode 100644 mail/test.go diff --git a/.gitignore b/.gitignore index 0fdf57c..8f35557 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.db *.exe -*.log \ No newline at end of file +*.log +iitk-coin diff --git a/account/refresh_token.go b/account/refresh_token.go new file mode 100644 index 0000000..45288ca --- /dev/null +++ b/account/refresh_token.go @@ -0,0 +1,34 @@ +package account + +import ( + "database/sql" + + "github.com/bhuvansingla/iitk-coin/database" +) + +func UpdateRefreshToken(token string, rollNo string) error { + _, err := database.DB.Exec(("UPDATE REFRESH_TOKEN SET token = $1 WHERE rollNo = $2"), token, rollNo) + return err +} + +func DeleteRefreshToken(rollNo string) error { + return UpdateRefreshToken("", rollNo) +} + +func InvalidateAllRefreshTokens() error { + _, err := database.DB.Exec(("UPDATE REFRESH_TOKEN SET token = $1"), "") + return err +} + +func GetRefreshToken(rollNo string) (string, error) { + var token string + err := database.DB.QueryRow(("SELECT token FROM REFRESH_TOKEN WHERE rollNo = $1"), rollNo).Scan(&token) + + if err == sql.ErrNoRows { + return "", nil + } + if err != nil { + return "", err + } + return token, nil +} diff --git a/account/user.go b/account/user.go index 66f0e0b..9a75e8a 100644 --- a/account/user.go +++ b/account/user.go @@ -20,17 +20,35 @@ const ( func Create(rollNo string, hashedPasssword string, name string) error { role := NormalUser + stmt, err := database.DB.Prepare("INSERT INTO ACCOUNT (rollNo, name, password, coins, role) VALUES ($1, $2, $3, $4, $5)") if err != nil { return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) } + _, err = stmt.Exec(rollNo, name, hashedPasssword, 0, role) if err != nil { return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) } + + stmt, err = database.DB.Prepare("INSERT INTO REFRESH_TOKEN (rollNo, token) VALUES ($1, $2)") + if err != nil { + return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + + _, err = stmt.Exec(rollNo, "") + if err != nil { + return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + return nil } +func UpdatePassword(rollNo string, hashedPasssword string) error { + _, err := database.DB.Exec("UPDATE ACCOUNT SET password = $1 WHERE rollNo = $2", hashedPasssword, rollNo) + return err +} + func UserExists(rollNo string) (bool, error) { row := database.DB.QueryRow("SELECT rollNo FROM ACCOUNT WHERE rollNo=$1", rollNo) scannedRow := "" diff --git a/account/wallet.go b/account/wallet.go index 5a71f13..0fe9da2 100644 --- a/account/wallet.go +++ b/account/wallet.go @@ -2,45 +2,47 @@ package account import ( "database/sql" - + "github.com/bhuvansingla/iitk-coin/database" ) type TransactionType string const ( - REDEEM TransactionType = "REDEEM" - REWARD TransactionType = "REWARD" - TRANSFER TransactionType = "TRANSFER" + REDEEM TransactionType = "REDEEM" + REWARD TransactionType = "REWARD" + TRANSFER TransactionType = "TRANSFER" ) type RedeemHistory struct { - Type TransactionType `json:"type"` - Time int64 `json:"timeStamp"` - Id string `json:"txnID"` - Amount int64 `json:"amount"` - Item string `json:"item"` - Status RedeemStatus `json:"status"` - ActionByRollNo string `json:"actionByRollNo"` + Type TransactionType `json:"type"` + Time int `json:"timeStamp"` + Id string `json:"txnID"` + Amount int `json:"amount"` + Item string `json:"item"` + Status RedeemStatus `json:"status"` + ActionByRollNo string `json:"actionByRollNo"` + Name string `json:"name"` } type RewardHistory struct { - Type TransactionType `json:"type"` - Time int64 `json:"timeStamp"` - Id string `json:"txnID"` - Amount int64 `json:"amount"` - Remarks string `json:"remarks"` + Type TransactionType `json:"type"` + Time int `json:"timeStamp"` + Id string `json:"txnID"` + Amount int `json:"amount"` + Remarks string `json:"remarks"` } type TransferHistory struct { - Type TransactionType `json:"type"` - Time int64 `json:"timeStamp"` - Id string `json:"txnID"` - Amount int64 `json:"amount"` - Tax int64 `json:"tax"` - FromRollNo string `json:"fromRollNo"` - ToRollNo string `json:"toRollNo"` - Remarks string `json:"remarks"` + Type TransactionType `json:"type"` + Time int `json:"timeStamp"` + Id string `json:"txnID"` + Amount int `json:"amount"` + Tax int `json:"tax"` + FromRollNo string `json:"fromRollNo"` + ToRollNo string `json:"toRollNo"` + Remarks string `json:"remarks"` + Name string `json:"name"` } func GetCoinBalanceByRollNo(rollNo string) (int, error) { @@ -54,14 +56,13 @@ func GetCoinBalanceByRollNo(rollNo string) (int, error) { func GetWalletHistoryByRollNo(rollNo string) ([]interface{}, error) { queryString := ` - SELECT history.* + SELECT history.*, a.name FROM ( SELECT id, time, $2 AS type, fromRollNo, toRollNo, - NULL AS rollNo, coins, tax, NULL AS item, @@ -76,7 +77,6 @@ func GetWalletHistoryByRollNo(rollNo string) ([]interface{}, error) { $3 AS type, NULL AS fromRollNo, NULL AS toRollNo, - rollNo, coins, NULL AS tax, item, @@ -91,7 +91,6 @@ func GetWalletHistoryByRollNo(rollNo string) ([]interface{}, error) { $4 AS type, NULL AS fromRollNo, NULL AS toRollNo, - rollNo, coins, NULL AS tax, NULL AS item, @@ -101,6 +100,12 @@ func GetWalletHistoryByRollNo(rollNo string) ([]interface{}, error) { FROM REWARD_HISTORY WHERE rollNo = $1 ) history + LEFT JOIN ACCOUNT a + ON ( + history.type = $2 AND a.rollNo = CASE WHEN history.fromRollNo = $1 THEN history.toRollNo ELSE history.fromRollNo END + OR + history.type = $3 AND a.rollNo = history.actionByRollNo + ) ORDER BY history.time DESC;` rows, err := database.DB.Query(queryString, rollNo, TRANSFER, REDEEM, REWARD) @@ -110,46 +115,47 @@ func GetWalletHistoryByRollNo(rollNo string) ([]interface{}, error) { } var history []interface{} - + for rows.Next() { var ( - id int - time int64 - txType TransactionType - fromRollNo sql.NullString - toRollNo sql.NullString - rollNo sql.NullString - coins sql.NullInt64 - tax sql.NullInt64 - item sql.NullString - status sql.NullString + id int + time int + txType TransactionType + fromRollNo sql.NullString + toRollNo sql.NullString + coins sql.NullInt64 + tax sql.NullInt64 + item sql.NullString + status sql.NullString actionByRollNo sql.NullString - remarks sql.NullString + remarks sql.NullString + name sql.NullString ) - - if err := rows.Scan(&id, &time, &txType, &fromRollNo, &toRollNo, &rollNo, &coins, &tax, &item, &status, &actionByRollNo, &remarks); err != nil { + + if err := rows.Scan(&id, &time, &txType, &fromRollNo, &toRollNo, &coins, &tax, &item, &status, &actionByRollNo, &remarks, &name); err != nil { return nil, err } var historyItem interface{} - + switch txType { case REDEEM: historyItem = RedeemHistory{ Type: txType, - Time: time, + Time: int(time), Id: formatTxnID(id, REDEEM), - Amount: coins.Int64, + Amount: int(coins.Int64), Item: item.String, Status: RedeemStatus(status.String), ActionByRollNo: actionByRollNo.String, + Name: name.String, } case REWARD: historyItem = RewardHistory{ Type: txType, Time: time, Id: formatTxnID(id, REWARD), - Amount: coins.Int64, + Amount: int(coins.Int64), Remarks: remarks.String, } case TRANSFER: @@ -157,11 +163,12 @@ func GetWalletHistoryByRollNo(rollNo string) ([]interface{}, error) { Type: txType, Time: time, Id: formatTxnID(id, TRANSFER), - Amount: coins.Int64, - Tax: tax.Int64, + Amount: int(coins.Int64), + Tax: int(tax.Int64), FromRollNo: fromRollNo.String, ToRollNo: toRollNo.String, Remarks: remarks.String, + Name: name.String, } } diff --git a/auth/access_token.go b/auth/access_token.go new file mode 100644 index 0000000..a09cf76 --- /dev/null +++ b/auth/access_token.go @@ -0,0 +1,37 @@ +package auth + +import ( + "net/http" + "time" + + "github.com/bhuvansingla/iitk-coin/errors" + "github.com/spf13/viper" +) + +func GenerateAccessToken(rollNo string) (string, error) { + + expirationTime := time.Now().Add(time.Duration(viper.GetInt("JWT.ACCESS_TOKEN.EXPIRATION_TIME_IN_MIN")) * time.Minute) + + return generateToken(rollNo, expirationTime) +} + +func IsAuthorized(endpoint func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + + return func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(viper.GetString("JWT.ACCESS_TOKEN.NAME")) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("bad token")) + return + } + + err = isTokenValid(cookie) + + if err == nil { + endpoint(w, r) + return + } + + errors.WriteResponse(err, w) + } +} diff --git a/auth/jwt.go b/auth/jwt.go index 5d02256..e881d98 100644 --- a/auth/jwt.go +++ b/auth/jwt.go @@ -5,6 +5,7 @@ import ( "net/http" "time" + "github.com/bhuvansingla/iitk-coin/errors" "github.com/golang-jwt/jwt/v4" "github.com/spf13/viper" ) @@ -16,10 +17,33 @@ type Claims struct { jwt.RegisteredClaims } -func GenerateToken(rollNo string) (string, error) { +func GetRollNoFromRequest(r *http.Request) (string, error) { + cookie, err := r.Cookie(viper.GetString("JWT.ACCESS_TOKEN.NAME")) + if err != nil { + return "", err + } + return GetRollNoFromTokenCookie(cookie) +} - expirationTime := time.Now().Add(time.Duration(viper.GetInt("JWT.EXPIRATION_TIME_IN_MIN")) * time.Minute) +func GetRollNoFromTokenCookie(cookie *http.Cookie) (string, error) { + token := cookie.Value + claims := &Claims{} + _, err := jwt.ParseWithClaims(token, claims, keyFunc) + if err != nil { + return "", err + } + return claims.RollNo, nil +} +func keyFunc(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("invalid signing method") + } + return privateKey, nil +} + +func generateToken(rollNo string, expirationTime time.Time) (string, error) { + claims := &Claims{ RollNo: rollNo, RegisteredClaims: jwt.RegisteredClaims{ @@ -37,66 +61,25 @@ func GenerateToken(rollNo string) (string, error) { return tokenString, nil } -func IsAuthorized(endpoint func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { - - return func(w http.ResponseWriter, r *http.Request) { - cookie, err := r.Cookie(viper.GetString("JWT.COOKIE_NAME")) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("bad token")) - return - } - - token, err := jwt.Parse(cookie.Value, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("invalid signing method") - } - return privateKey, nil - }) - - if token.Valid { - endpoint(w, r) - return - } else if ve, ok := err.(*jwt.ValidationError); ok { - if ve.Errors&jwt.ValidationErrorMalformed != 0 { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("bad token")) - return - } else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 { - // Token is either expired or not active yet - w.WriteHeader(http.StatusUnauthorized) - w.Write([]byte("token expired")) - return - } else { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(http.StatusText(http.StatusInternalServerError))) - return - } +func isTokenValid(cookie *http.Cookie) error { + + token, err := jwt.Parse(cookie.Value, keyFunc) + + if token.Valid { + return nil + } + + jwtError, ok := err.(*jwt.ValidationError) + + if ok { + if jwtError.Errors&jwt.ValidationErrorMalformed != 0 { + return errors.NewHTTPError(err, http.StatusBadRequest, "validation malformed") + } else if jwtError.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 { + return errors.NewHTTPError(err, http.StatusUnauthorized, "token expired") } else { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte(http.StatusText(http.StatusInternalServerError))) - return + return errors.NewHTTPError(nil, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) } - - } -} - -func GetRollNoFromRequest(r *http.Request) (string, error) { - cookie, err := r.Cookie(viper.GetString("JWT.COOKIE_NAME")) - if err != nil { - return "", err - } - return GetRollNoFromTokenCookie(cookie) -} - -func GetRollNoFromTokenCookie(cookie *http.Cookie) (string, error) { - token := cookie.Value - claims := &Claims{} - _, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { - return privateKey, nil - }) - if err != nil { - return "", err + } else { + return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) } - return claims.RollNo, nil } diff --git a/auth/login.go b/auth/login.go index da0a7b2..f3c919b 100644 --- a/auth/login.go +++ b/auth/login.go @@ -28,7 +28,7 @@ func Login(rollNo string, password string) (ok bool, err error) { } if !util.CompareHashAndPassword(passwordFromRollNo, password) { - return false, errors.NewHTTPError(nil, http.StatusBadRequest, "invalid password") + return false, errors.NewHTTPError(nil, http.StatusUnauthorized, "invalid password") } return true, nil } diff --git a/auth/otp.go b/auth/otp.go index 2e725ec..a9ecba4 100644 --- a/auth/otp.go +++ b/auth/otp.go @@ -40,10 +40,13 @@ func GenerateOtp(rollNo string) error { return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) } - if err = mail.SendOTP(rollNo, otp); err != nil { - return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + emailRequest := mail.EmailRequest{ + To: rollNo, + OTP: otp, } + mail.MailChannel <- emailRequest + return nil } diff --git a/auth/refresh_token.go b/auth/refresh_token.go new file mode 100644 index 0000000..c0c4be8 --- /dev/null +++ b/auth/refresh_token.go @@ -0,0 +1,77 @@ +package auth + +import ( + "net/http" + "time" + + "github.com/bhuvansingla/iitk-coin/account" + "github.com/bhuvansingla/iitk-coin/errors" + "github.com/spf13/viper" +) + +func GenerateRefreshToken(rollNo string) (string, error) { + + expirationTime := time.Now().Add(time.Duration(viper.GetInt("JWT.REFRESH_TOKEN.EXPIRATION_TIME_IN_MIN")) * time.Minute) + + refreshToken, err := generateToken(rollNo, expirationTime) + if err != nil { + return "", err + } + + err = account.UpdateRefreshToken(refreshToken, rollNo) + if err != nil { + return "", err + } + + return refreshToken, nil +} + +func CheckRefreshTokenValidity(r *http.Request) (string, error) { + + cookie, err := r.Cookie(viper.GetString("JWT.ACCESS_TOKEN.NAME")) + if err != nil { + return "", errors.NewHTTPError(err, http.StatusBadRequest, "bad access token") + } + + err = isTokenValid(cookie) + + if err == nil { + rollNo, err := GetRollNoFromTokenCookie(cookie) + if err != nil { + return "", errors.NewHTTPError(err, http.StatusBadRequest, "bad access token") + } + return rollNo, nil + } + + clientError, ok := err.(errors.ClientError) + if !ok { + return "", errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + + if status, _ := clientError.ResponseHeaders(); status!=http.StatusUnauthorized { + return "", err + } + + cookie, err = r.Cookie(viper.GetString("JWT.REFRESH_TOKEN.NAME")) + if err != nil { + return "", errors.NewHTTPError(err, http.StatusBadRequest, "bad refresh token") + } + + rollNo, err := GetRollNoFromTokenCookie(cookie) + if err != nil { + return "", errors.NewHTTPError(err, http.StatusBadRequest, "bad refresh token") + } + + refreshToken, err := account.GetRefreshToken(rollNo) + if err != nil { + return "", errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + + if refreshToken != cookie.Value { + return "", errors.NewHTTPError(err, http.StatusBadRequest, "bad refresh token") + } + + err = isTokenValid(cookie) + + return rollNo, err +} diff --git a/auth/reset_password.go b/auth/reset_password.go new file mode 100644 index 0000000..c9ebf0a --- /dev/null +++ b/auth/reset_password.go @@ -0,0 +1,41 @@ +package auth + +import ( + "net/http" + + "github.com/bhuvansingla/iitk-coin/account" + "github.com/bhuvansingla/iitk-coin/errors" + "github.com/bhuvansingla/iitk-coin/util" +) + +func ResetPassword(rollNo string, newPassword string, otp string) error { + + userExists, err := account.UserExists(rollNo) + if err != nil { + return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + + if !userExists { + return errors.NewHTTPError(nil, http.StatusBadRequest, "account doesnot exists") + } + + if err := account.ValidatePassword(newPassword); err != nil { + return err + } + + if err := VerifyOTP(rollNo, otp); err != nil { + return err + } + + hashedPwd, err := util.HashAndSalt(newPassword) + if err != nil { + return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + + err = account.UpdatePassword(rollNo, hashedPwd) + if err != nil { + return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + + return nil +} diff --git a/cmd/iitk-coin/main.go b/cmd/iitk-coin/main.go index 40c4f3e..37a5d79 100644 --- a/cmd/iitk-coin/main.go +++ b/cmd/iitk-coin/main.go @@ -4,6 +4,7 @@ import ( _ "github.com/bhuvansingla/iitk-coin/config" "github.com/bhuvansingla/iitk-coin/database" _ "github.com/bhuvansingla/iitk-coin/logger" + "github.com/bhuvansingla/iitk-coin/mail" "github.com/bhuvansingla/iitk-coin/server" log "github.com/sirupsen/logrus" ) @@ -15,6 +16,12 @@ func main() { log.Error("Error connecting to database: %s", err) return } + + err = mail.Test() + if err != nil { + log.Error("Error sending mail: %s", err) + return + } err = server.Start() if err != nil { diff --git a/config.yml b/config.yml index a2574dc..6d315f1 100644 --- a/config.yml +++ b/config.yml @@ -17,9 +17,13 @@ TAX: INTRA_BATCH: 33 JWT: + ACCESS_TOKEN: + NAME: "access_token" + EXPIRATION_TIME_IN_MIN: 10 + REFRESH_TOKEN: + NAME: "refresh_token" + EXPIRATION_TIME_IN_MIN: 50000 PRIVATE_KEY: "this-is-a-secret" - EXPIRATION_TIME_IN_MIN: 10 - COOKIE_NAME: "token" LOGGER: LOG_LEVEL: 5 # Error: 2, Warn: 3, Info: 4, Debug: 5 diff --git a/database/tables.go b/database/tables.go index 02bf322..cc3ebbb 100644 --- a/database/tables.go +++ b/database/tables.go @@ -31,11 +31,16 @@ func createTables() (err error) { log.Error(err.Error()) return } + err = createRefreshTokenTable() + if err != nil { + log.Error(err.Error()) + return + } return } func createAccountTable() (err error) { - _, err = DB.Exec("create table if not exists ACCOUNT (rollNo text, name text, password text, coins int, role int)") + _, err = DB.Exec("create table if not exists ACCOUNT (rollNo text PRIMARY KEY NOT NULL, name text, password text, coins int, role int)") return } @@ -58,3 +63,8 @@ func createRewardHistoryTable() (err error) { _, err = DB.Exec("create table if not exists REWARD_HISTORY (id SERIAL PRIMARY KEY NOT NULL, rollNo text, coins int, time NUMERIC, remarks text)") return } + +func createRefreshTokenTable() (err error) { + _, err = DB.Exec("create table if not exists REFRESH_TOKEN (rollNo text PRIMARY KEY NOT NULL, token text)") + return +} diff --git a/errors/handler.go b/errors/handler.go index 34b4130..081605a 100644 --- a/errors/handler.go +++ b/errors/handler.go @@ -16,27 +16,32 @@ func Handler(endpoint func(http.ResponseWriter, *http.Request) error) func(http. return } - log.Error(err) + WriteResponse(err, w) + } +} - clientError, ok := err.(ClientError) +func WriteResponse(err error, w http.ResponseWriter) { - if !ok { - w.WriteHeader(http.StatusInternalServerError) - return - } + log.Error(err) - body, err := clientError.ResponseBody() - if err != nil { - log.Error(err) - w.WriteHeader(http.StatusInternalServerError) - return - } + clientError, ok := err.(ClientError) - status, headers := clientError.ResponseHeaders() - for k, v := range headers { - w.Header().Set(k, v) - } - w.WriteHeader(status) - w.Write(body) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + return + } + + body, err := clientError.ResponseBody() + if err != nil { + log.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + status, headers := clientError.ResponseHeaders() + for k, v := range headers { + w.Header().Set(k, v) } + w.WriteHeader(status) + w.Write(body) } diff --git a/handlers/checklogin.go b/handlers/check_login.go similarity index 100% rename from handlers/checklogin.go rename to handlers/check_login.go diff --git a/handlers/login.go b/handlers/login.go index 3c14c5e..8838812 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -3,7 +3,6 @@ package handlers import ( "encoding/json" "net/http" - "time" "github.com/bhuvansingla/iitk-coin/account" "github.com/bhuvansingla/iitk-coin/auth" @@ -41,29 +40,31 @@ func Login(w http.ResponseWriter, r *http.Request) error { return errors.NewHTTPError(err, http.StatusUnauthorized, "invalid credentials") } - token, err := auth.GenerateToken(loginRequest.RollNo) + return setCookiesAndRespond(loginRequest.RollNo, w) +} + +func setCookiesAndRespond(rollNo string, w http.ResponseWriter) error { + + accessToken, err := auth.GenerateAccessToken(rollNo) if err != nil { return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) } - cookie := &http.Cookie{ - Name: viper.GetString("JWT.COOKIE_NAME"), - Value: token, - Expires: time.Now().Add(time.Duration(viper.GetInt("JWT.EXPIRATION_TIME_IN_MIN")) * time.Minute), - HttpOnly: true, - Path: "/", + refreshToken, err := auth.GenerateRefreshToken(rollNo) + if err != nil { + return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) } - http.SetCookie(w, cookie) + setCookies(w, accessToken, refreshToken) - isAdmin, err := account.IsAdmin(loginRequest.RollNo) + isAdmin, err := account.IsAdmin(rollNo) if err != nil { return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) } err = json.NewEncoder(w).Encode(&LoginResponse{ IsAdmin: isAdmin, - RollNo: loginRequest.RollNo, + RollNo: rollNo, }) if err != nil { return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) @@ -72,13 +73,46 @@ func Login(w http.ResponseWriter, r *http.Request) error { return nil } -func Logout(w http.ResponseWriter, r *http.Request) error { - http.SetCookie(w, &http.Cookie{ - Name: viper.GetString("JWT.COOKIE_NAME"), - Value: "", - Expires: time.Now(), +func setCookies(w http.ResponseWriter, accessToken string, refreshToken string) { + cookie := &http.Cookie{ + Name: viper.GetString("JWT.ACCESS_TOKEN.NAME"), + Value: accessToken, HttpOnly: true, Path: "/", - }) + } + + http.SetCookie(w, cookie) + + cookie = &http.Cookie{ + Name: viper.GetString("JWT.REFRESH_TOKEN.NAME"), + Value: refreshToken, + HttpOnly: true, + Path: "/auth", + } + + http.SetCookie(w, cookie) +} + +func Logout(w http.ResponseWriter, r *http.Request) error { + + cookie, err := r.Cookie(viper.GetString("JWT.REFRESH_TOKEN.NAME")) + if err != nil { + setCookies(w, "", "") + return nil + } + + rollNo, err := auth.GetRollNoFromTokenCookie(cookie) + if err != nil { + setCookies(w, "", "") + return nil + } + + setCookies(w, "", "") + + err = account.DeleteRefreshToken(rollNo) + if err != nil { + return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + return nil } diff --git a/handlers/refresh.go b/handlers/refresh.go new file mode 100644 index 0000000..e110e40 --- /dev/null +++ b/handlers/refresh.go @@ -0,0 +1,52 @@ +package handlers + +import ( + "net/http" + + "github.com/bhuvansingla/iitk-coin/account" + "github.com/bhuvansingla/iitk-coin/auth" + "github.com/bhuvansingla/iitk-coin/errors" +) + +func RefreshToken(w http.ResponseWriter, r *http.Request) error { + + if r.Method != "GET" { + return errors.NewHTTPError(nil, http.StatusMethodNotAllowed, http.StatusText(http.StatusMethodNotAllowed)) + } + + rollNo, err := auth.CheckRefreshTokenValidity(r) + + if err != nil { + return err + } + + return setCookiesAndRespond(rollNo, w) +} + +func InvalidateRefreshTokens(w http.ResponseWriter, r *http.Request) error { + + if r.Method != "POST" { + return errors.NewHTTPError(nil, http.StatusMethodNotAllowed, http.StatusText(http.StatusMethodNotAllowed)) + } + + requestorRollNo, err := auth.GetRollNoFromRequest(r) + if err != nil { + return errors.NewHTTPError(err, http.StatusBadRequest, "invalid cookie") + } + + requestorRole, err := account.GetAccountRoleByRollNo(requestorRollNo) + if err != nil { + return err + } + + if !(requestorRole == account.GeneralSecretary || requestorRole == account.AssociateHead || requestorRole == account.CoreTeamMember) { + return errors.NewHTTPError(nil, http.StatusUnauthorized, "you don't have permission to invalidate refresh tokens") + } + + err = account.InvalidateAllRefreshTokens() + if err != nil { + return errors.NewHTTPError(err, http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + } + + return nil +} diff --git a/handlers/reset_password.go b/handlers/reset_password.go new file mode 100644 index 0000000..1518599 --- /dev/null +++ b/handlers/reset_password.go @@ -0,0 +1,36 @@ +package handlers + +import ( + "encoding/json" + "net/http" + + "github.com/bhuvansingla/iitk-coin/auth" + "github.com/bhuvansingla/iitk-coin/errors" +) + +type ResetPasswordRequest struct { + RollNo string `json:"rollNo"` + NewPassword string `json:"newPassword"` + Otp string `json:"otp"` +} + +func ResetPassword(w http.ResponseWriter, r *http.Request) error { + + if r.Method != "POST" { + return errors.NewHTTPError(nil, http.StatusMethodNotAllowed, http.StatusText(http.StatusMethodNotAllowed)) + } + + var resetPasswordRequest ResetPasswordRequest + + if err := json.NewDecoder(r.Body).Decode(&resetPasswordRequest); err != nil { + return errors.NewHTTPError(err, http.StatusBadRequest, "error decoding request body") + } + + err := auth.ResetPassword(resetPasswordRequest.RollNo, resetPasswordRequest.NewPassword, resetPasswordRequest.Otp) + + if err != nil { + return err + } + + return nil +} diff --git a/mail/config.go b/mail/config.go new file mode 100644 index 0000000..d7a78fd --- /dev/null +++ b/mail/config.go @@ -0,0 +1,45 @@ +package mail + +import ( + "net/smtp" + + "github.com/spf13/viper" +) + +type smtpServer struct { + host string + port string +} + +func (s *smtpServer) Address() string { + return s.host + ":" + s.port +} + +type EmailRequest struct { + To string + OTP string +} + +var ( + MailChannel chan EmailRequest + from string + password string + server smtpServer + auth smtp.Auth + otpValidity string +) + +func init() { + MailChannel = make(chan EmailRequest) + from = viper.GetString("MAIL.FROM") + password = viper.GetString("MAIL.PASSWORD") + + otpValidity = viper.GetString("OTP.EXPIRY_PERIOD_IN_MIN") + + server = smtpServer{host: viper.GetString("MAIL.HOST"), port: viper.GetString("MAIL.PORT")} + go mailService(MailChannel) +} + +func authorize() { + auth = smtp.PlainAuth("", from, password, server.host) +} diff --git a/mail/send.go b/mail/send.go deleted file mode 100644 index 37351ea..0000000 --- a/mail/send.go +++ /dev/null @@ -1,40 +0,0 @@ -package mail - -import ( - "net/smtp" - - log "github.com/sirupsen/logrus" - "github.com/spf13/viper" -) - -type smtpServer struct { - host string - port string -} - -func (s *smtpServer) Address() string { - return s.host + ":" + s.port -} - -func SendOTP(toRollNo string, otp string) (err error) { - - from := viper.GetString("MAIL.FROM") - password := viper.GetString("MAIL.PASSWORD") - to := []string{ - toRollNo + "@iitk.ac.in", - } - smtpServer := smtpServer{host: viper.GetString("MAIL.HOST"), port: viper.GetString("MAIL.PORT")} - - message := []byte("Your OTP is " + otp) - - auth := smtp.PlainAuth("", from, password, smtpServer.host) - - err = smtp.SendMail(smtpServer.Address(), auth, from, to, message) - - if err != nil { - log.Error("error sending mail: ", err) - return err - } - log.Info("Mail sent to ", toRollNo) - return nil -} diff --git a/mail/service.go b/mail/service.go new file mode 100644 index 0000000..64a3faa --- /dev/null +++ b/mail/service.go @@ -0,0 +1,34 @@ +package mail + +import ( + "net/smtp" + + log "github.com/sirupsen/logrus" +) + +func mailService(mailChannel chan EmailRequest) { + authorize() + for request := range mailChannel { + to := []string{request.To+"@iitk.ac.in"} + msg := []byte("To: " + request.To + "@iitk.ac.in" + "\n" + + "From: " + "IITK-Coin<" + from + ">\n" + + "Subject: IITK-Coin One Time Password\n" + + "Your OTP is " + request.OTP + "\n" + + "This OTP is valid for " + otpValidity + " minutes and don't share it with anyone." + + "\n") + + err := smtp.SendMail(server.Address(), auth, from, to, msg) + + // if error, try to login again + if err != nil { + authorize() + err = smtp.SendMail(server.Address(), auth, from, to, msg) + if err != nil { + log.Error("Error sending mail: " + err.Error()) + continue + } + } + + log.Info("Mail sent to ", request.To) + } +} diff --git a/mail/test.go b/mail/test.go new file mode 100644 index 0000000..0b34729 --- /dev/null +++ b/mail/test.go @@ -0,0 +1,24 @@ +package mail + +import ( + "net/smtp" + + log "github.com/sirupsen/logrus" +) + +func Test() (err error) { + authorize() + to := []string{from} + msg := []byte("To: " + from + "\n" + + "From: " + "IITK-Coin<" + from + ">\n" + + "Subject: IITK-Coin Test Mail\n" + + "This is a Test Mail" + + "\n") + + err = smtp.SendMail(server.Address(), auth, from, to, msg) + if err != nil { + return + } + log.Info("Test mail sent") + return +} diff --git a/server/router.go b/server/router.go index aeebc73..8d46ba4 100644 --- a/server/router.go +++ b/server/router.go @@ -13,9 +13,12 @@ func setRoutes() { http.HandleFunc("/auth/login", CORS(errors.Handler(handlers.Login))) http.HandleFunc("/auth/signup", CORS(errors.Handler(handlers.Signup))) + http.HandleFunc("/auth/reset-password", CORS(errors.Handler(handlers.ResetPassword))) http.HandleFunc("/auth/check", CORS(auth.IsAuthorized((errors.Handler(handlers.CheckLogin))))) http.HandleFunc("/auth/otp", CORS(errors.Handler(handlers.GenerateOtp))) http.HandleFunc("/auth/logout", CORS(errors.Handler(handlers.Logout))) + http.HandleFunc("/auth/refresh", CORS(errors.Handler(handlers.RefreshToken))) + http.HandleFunc("/auth/refresh/invalidate", CORS(auth.IsAuthorized(errors.Handler(handlers.InvalidateRefreshTokens)))) http.HandleFunc("/user/name", CORS(auth.IsAuthorized(errors.Handler(handlers.GetNameByRollNo))))