diff --git a/pkg/utils/jwt/jwt.go b/pkg/utils/jwt/jwt.go index e0ec4c5..ee6a24c 100644 --- a/pkg/utils/jwt/jwt.go +++ b/pkg/utils/jwt/jwt.go @@ -30,6 +30,7 @@ import ( w "github.com/MicroOps-cn/fuck/wrapper" "github.com/MicroOps-cn/idas/pkg/common" + "github.com/MicroOps-cn/idas/pkg/errors" "github.com/golang-jwt/jwt/v4" ) @@ -110,6 +111,50 @@ func (j *JWTConfig) UnmarshalJSON(bytes []byte) (err error) { *j = *issuer return nil } +func privateKeyToPEM(pk any) (string, error) { + asn1Bytes, err := x509.MarshalPKCS8PrivateKey(pk) + if err != nil { + return "", err + } + pemBlock := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: asn1Bytes, + } + + // 将PEM块编码为PEM格式 + pemBytes := pem.EncodeToMemory(pemBlock) + + return string(pemBytes), nil +} + +//func (j JWTConfig) MarshalJSONPB(_ *jsonpb.Marshaler) ([]byte, error) { +// return j.MarshalJSON() +//} + +func (j JWTConfig) MarshalJSON() ([]byte, error) { + type plain struct { + PrivateKey string `json:"private_key"` + Algorithm string `json:"algorithm"` + } + switch j.Algorithm.(type) { + case *jwt.SigningMethodHMAC: + return json.Marshal(plain{ + PrivateKey: string(j.PrivateKey.([]byte)), + Algorithm: j.Algorithm.Alg(), + }) + case *jwt.SigningMethodRSA, *jwt.SigningMethodECDSA: + pemStr, err := privateKeyToPEM(j.PrivateKey) + if err != nil { + return nil, err + } + return json.Marshal(plain{ + PrivateKey: pemStr, + Algorithm: j.Algorithm.Alg(), + }) + default: + return nil, errors.New("unsupported algorithm") + } +} func NewRandomKey(method string) (string, error) { switch method { @@ -189,13 +234,6 @@ func NewJWTConfig(issuerId, method, privateKey string) (*JWTConfig, error) { if err != nil { return nil, fmt.Errorf("failed to load rsa private key: %s", err) } - //pubk, err := jwt.ParseRSAPublicKeyFromPEM([]byte(publicKey)) - //if err != nil { - // return nil, fmt.Errorf("failed to load rsa public key: %s", err) - //} - //if pubk.N.Cmp(privk.N) != 0 || pubk.E != privk.E { - // return nil, fmt.Errorf("public key does not match private key") - //} jwtConfig.PublicKey = privk.Public() jwtConfig.PrivateKey = privk case "ES256", "ES384", "ES512": @@ -203,13 +241,6 @@ func NewJWTConfig(issuerId, method, privateKey string) (*JWTConfig, error) { if err != nil { return nil, fmt.Errorf("failed to load ecdsa private key: %s", err) } - //pubk, err := jwt.ParseECPublicKeyFromPEM([]byte(publicKey)) - //if err != nil { - // return nil, fmt.Errorf("failed to load ecdsa public key: %s", err) - //} - //if pubk.X.Cmp(privk.X) != 0 || pubk.Y.Cmp(privk.Y) != 0 { - // return nil, fmt.Errorf("public key does not match private key") - //} jwtConfig.PublicKey = privk.Public() jwtConfig.PrivateKey = privk switch privk.Curve.Params().BitSize {