Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SSO MFA prompt for WebUI MFA flows #49794

Merged
merged 20 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions lib/client/weblogin.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ type MFAChallengeResponse struct {
WebauthnResponse *wantypes.CredentialAssertionResponse `json:"webauthn_response,omitempty"`
// SSOResponse is a response from an SSO MFA flow.
SSOResponse *SSOResponse `json:"sso_response"`
// TODO(Joerger): DELETE IN v19.0.0, WebauthnResponse used instead.
WebauthnAssertionResponse *wantypes.CredentialAssertionResponse `json:"webauthnAssertionResponse"`
}

// SSOResponse is a json compatible [proto.SSOResponse].
Expand All @@ -124,25 +126,57 @@ type SSOResponse struct {
// GetOptionalMFAResponseProtoReq converts response to a type proto.MFAAuthenticateResponse,
// if there were any responses set. Otherwise returns nil.
func (r *MFAChallengeResponse) GetOptionalMFAResponseProtoReq() (*proto.MFAAuthenticateResponse, error) {
if r.TOTPCode != "" && r.WebauthnResponse != nil {
var availableResponses int
if r.TOTPCode != "" {
availableResponses++
}
if r.WebauthnResponse != nil {
availableResponses++
}
if r.SSOResponse != nil {
availableResponses++
}

if availableResponses > 1 {
return nil, trace.BadParameter("only one MFA response field can be set")
}

if r.TOTPCode != "" {
switch {
case r.WebauthnResponse != nil:
return &proto.MFAAuthenticateResponse{Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: wantypes.CredentialAssertionResponseToProto(r.WebauthnResponse),
}}, nil
case r.SSOResponse != nil:
return &proto.MFAAuthenticateResponse{Response: &proto.MFAAuthenticateResponse_SSO{
SSO: &proto.SSOResponse{
RequestId: r.SSOResponse.RequestID,
Token: r.SSOResponse.Token,
},
}}, nil
case r.TOTPCode != "":
return &proto.MFAAuthenticateResponse{Response: &proto.MFAAuthenticateResponse_TOTP{
TOTP: &proto.TOTPResponse{Code: r.TOTPCode},
}}, nil
}

if r.WebauthnResponse != nil {
case r.WebauthnAssertionResponse != nil:
return &proto.MFAAuthenticateResponse{Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: wantypes.CredentialAssertionResponseToProto(r.WebauthnResponse),
Webauthn: wantypes.CredentialAssertionResponseToProto(r.WebauthnAssertionResponse),
}}, nil
}

return nil, nil
}

// ParseMFAChallengeResponse parses [MFAChallengeResponse] from JSON and returns it as a [proto.MFAAuthenticateResponse].
func ParseMFAChallengeResponse(mfaResponseJSON []byte) (*proto.MFAAuthenticateResponse, error) {
Joerger marked this conversation as resolved.
Show resolved Hide resolved
var resp MFAChallengeResponse
if err := json.Unmarshal(mfaResponseJSON, &resp); err != nil {
return nil, trace.Wrap(err)
}

protoResp, err := resp.GetOptionalMFAResponseProtoReq()
return protoResp, trace.Wrap(err)
}

// CreateSSHCertReq is passed by tsh to authenticate a local user without MFA
// and receive short-lived certificates.
type CreateSSHCertReq struct {
Expand Down
12 changes: 4 additions & 8 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2763,7 +2763,7 @@ func (h *Handler) mfaLoginBegin(w http.ResponseWriter, r *http.Request, p httpro
return nil, trace.AccessDenied("invalid credentials")
}

return makeAuthenticateChallenge(mfaChallenge), nil
return makeAuthenticateChallenge(mfaChallenge, "" /*channelID*/), nil
}

// mfaLoginFinish completes the MFA login ceremony, returning a new SSH
Expand Down Expand Up @@ -4847,16 +4847,12 @@ func parseMFAResponseFromRequest(r *http.Request) error {
// context and returned.
func contextWithMFAResponseFromRequestHeader(ctx context.Context, requestHeader http.Header) (context.Context, error) {
if mfaResponseJSON := requestHeader.Get("Teleport-MFA-Response"); mfaResponseJSON != "" {
var resp mfaResponse
if err := json.Unmarshal([]byte(mfaResponseJSON), &resp); err != nil {
mfaResp, err := client.ParseMFAChallengeResponse([]byte(mfaResponseJSON))
if err != nil {
return nil, trace.Wrap(err)
}

return mfa.ContextWithMFAResponse(ctx, &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: wantypes.CredentialAssertionResponseToProto(resp.WebauthnAssertionResponse),
},
}), nil
return mfa.ContextWithMFAResponse(ctx, mfaResp), nil
}

return ctx, nil
Expand Down
8 changes: 3 additions & 5 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5573,10 +5573,6 @@ func TestCreateAppSession_RequireSessionMFA(t *testing.T) {
require.NoError(t, err)
mfaResp, err := webauthnDev.SolveAuthn(chal)
require.NoError(t, err)
mfaRespJSON, err := json.Marshal(mfaResponse{
WebauthnAssertionResponse: wantypes.CredentialAssertionResponseFromProto(mfaResp.GetWebauthn()),
})
require.NoError(t, err)

// Extract the session ID and bearer token for the current session.
rawCookie := *pack.cookies[0]
Expand Down Expand Up @@ -5610,7 +5606,9 @@ func TestCreateAppSession_RequireSessionMFA(t *testing.T) {
PublicAddr: "panel.example.com",
ClusterName: "localhost",
},
MFAResponse: string(mfaRespJSON),
MFAResponse: client.MFAChallengeResponse{
WebauthnAssertionResponse: wantypes.CredentialAssertionResponseFromProto(mfaResp.GetWebauthn()),
},
},
expectMFAVerified: true,
},
Expand Down
29 changes: 15 additions & 14 deletions lib/web/apps.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ package web

import (
"context"
"encoding/json"
"net/http"
"sort"

Expand All @@ -33,7 +32,7 @@ import (
"github.com/gravitational/teleport/api/client/proto"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/httplib"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -191,7 +190,10 @@ type CreateAppSessionRequest struct {
// AWSRole is the AWS role ARN when accessing AWS management console.
AWSRole string `json:"arn,omitempty"`
// MFAResponse is an optional MFA response used to create an MFA verified app session.
MFAResponse string `json:"mfa_response"`
MFAResponse client.MFAChallengeResponse `json:"mfaResponse"`
// TODO(Joerger): DELETE IN v19.0.0
// Backwards compatible version of MFAResponse
MFAResponseJSON string `json:"mfa_response"`
}

// CreateAppSessionResponse is a response to POST /v1/webapi/sessions/app
Expand Down Expand Up @@ -230,17 +232,16 @@ func (h *Handler) createAppSession(w http.ResponseWriter, r *http.Request, p htt
}
}

var mfaProtoResponse *proto.MFAAuthenticateResponse
if req.MFAResponse != "" {
var resp mfaResponse
if err := json.Unmarshal([]byte(req.MFAResponse), &resp); err != nil {
return nil, trace.Wrap(err)
}
mfaResponse, err := req.MFAResponse.GetOptionalMFAResponseProtoReq()
if err != nil {
return nil, trace.Wrap(err)
}

mfaProtoResponse = &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: wantypes.CredentialAssertionResponseToProto(resp.WebauthnAssertionResponse),
},
// Fallback to backwards compatible mfa response.
if mfaResponse == nil && req.MFAResponseJSON != "" {
mfaResponse, err = client.ParseMFAChallengeResponse([]byte(req.MFAResponseJSON))
if err != nil {
return nil, trace.Wrap(err)
}
}

Expand All @@ -263,7 +264,7 @@ func (h *Handler) createAppSession(w http.ResponseWriter, r *http.Request, p htt
PublicAddr: result.App.GetPublicAddr(),
ClusterName: result.ClusterName,
AWSRoleARN: req.AWSRole,
MFAResponse: mfaProtoResponse,
MFAResponse: mfaResponse,
AppName: result.App.GetName(),
URI: result.App.GetURI(),
ClientAddr: r.RemoteAddr,
Expand Down
47 changes: 22 additions & 25 deletions lib/web/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package web

import (
"context"
"encoding/json"
"errors"
"net/http"
"time"
Expand All @@ -35,7 +34,6 @@ import (
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth/authclient"
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/multiplexer"
"github.com/gravitational/teleport/lib/reversetunnelclient"
Expand All @@ -56,8 +54,8 @@ type fileTransferRequest struct {
remoteLocation string
// filename is a file name
filename string
// webauthn is an optional parameter that contains a webauthn response string used to issue single use certs
webauthn string
// mfaResponse is an optional parameter that contains an mfa response string used to issue single use certs
mfaResponse string
// fileTransferRequestID is used to find a FileTransferRequest on a session
fileTransferRequestID string
// moderatedSessonID is an ID of a moderated session that has completed a
Expand All @@ -74,11 +72,25 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou
remoteLocation: query.Get("location"),
filename: query.Get("filename"),
namespace: defaults.Namespace,
webauthn: query.Get("webauthn"),
mfaResponse: query.Get("mfaResponse"),
fileTransferRequestID: query.Get("fileTransferRequestId"),
moderatedSessionID: query.Get("moderatedSessionId"),
}

// Check for old query parameter, uses the same data structure.
// TODO(Joerger): DELETE IN v19.0.0
if req.mfaResponse == "" {
req.mfaResponse = query.Get("webauthn")
}

var mfaResponse *proto.MFAAuthenticateResponse
if req.mfaResponse != "" {
var err error
if mfaResponse, err = client.ParseMFAChallengeResponse([]byte(req.mfaResponse)); err != nil {
return nil, trace.Wrap(err)
}
}

// Send an error if only one of these params has been sent. Both should exist or not exist together
if (req.fileTransferRequestID != "") != (req.moderatedSessionID != "") {
return nil, trace.BadParameter("fileTransferRequestId and moderatedSessionId must both be included in the same request.")
Expand Down Expand Up @@ -107,7 +119,7 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou
return nil, trace.Wrap(err)
}

if mfaReq.Required && query.Get("webauthn") == "" {
if mfaReq.Required && mfaResponse == nil {
return nil, trace.AccessDenied("MFA required for file transfer")
}

Expand Down Expand Up @@ -135,8 +147,8 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou
return nil, trace.Wrap(err)
}

if req.webauthn != "" {
err = ft.issueSingleUseCert(req.webauthn, r, tc)
if req.mfaResponse != "" {
err = ft.issueSingleUseCert(mfaResponse, r, tc)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -216,21 +228,10 @@ func (f *fileTransfer) createClient(req fileTransferRequest, httpReq *http.Reque
return tc, nil
}

type mfaResponse struct {
// WebauthnResponse is the response from authenticators.
WebauthnAssertionResponse *wantypes.CredentialAssertionResponse `json:"webauthnAssertionResponse"`
}

// issueSingleUseCert will take an assertion response sent from a solved challenge in the web UI
// and use that to generate a cert. This cert is added to the Teleport Client as an authmethod that
// can be used to connect to a node.
func (f *fileTransfer) issueSingleUseCert(webauthn string, httpReq *http.Request, tc *client.TeleportClient) error {
var mfaResp mfaResponse
err := json.Unmarshal([]byte(webauthn), &mfaResp)
if err != nil {
return trace.Wrap(err)
}

func (f *fileTransfer) issueSingleUseCert(mfaResponse *proto.MFAAuthenticateResponse, httpReq *http.Request, tc *client.TeleportClient) error {
pk, err := keys.ParsePrivateKey(f.sctx.cfg.Session.GetSSHPriv())
if err != nil {
return trace.Wrap(err)
Expand All @@ -241,11 +242,7 @@ func (f *fileTransfer) issueSingleUseCert(webauthn string, httpReq *http.Request
SSHPublicKey: pk.MarshalSSHPublicKey(),
Username: f.sctx.GetUser(),
Expires: time.Now().Add(time.Minute).UTC(),
MFAResponse: &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: wantypes.CredentialAssertionResponseToProto(mfaResp.WebauthnAssertionResponse),
},
},
MFAResponse: mfaResponse,
})
if err != nil {
return trace.Wrap(err)
Expand Down
27 changes: 23 additions & 4 deletions lib/web/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ package web
import (
"context"
"net/http"
"net/url"
"strings"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"

Expand Down Expand Up @@ -201,6 +203,22 @@ func (h *Handler) createAuthenticateChallengeHandle(w http.ResponseWriter, r *ht
allowReuse = mfav1.ChallengeAllowReuse_CHALLENGE_ALLOW_REUSE_YES
}

// Prepare an sso client redirect URL in case the user has an SSO MFA device.
ssoClientRedirectURL, err := url.Parse(sso.WebMFARedirect)
if err != nil {
return nil, trace.Wrap(err)
}

// id is used by the front end to differentiate between separate ongoing SSO challenges.
id, err := uuid.NewRandom()
if err != nil {
return nil, trace.Wrap(err)
}
channelID := id.String()
query := ssoClientRedirectURL.Query()
query.Set("channel_id", channelID)
ssoClientRedirectURL.RawQuery = query.Encode()

chal, err := clt.CreateAuthenticateChallenge(r.Context(), &proto.CreateAuthenticateChallengeRequest{
Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{
ContextUser: &proto.ContextUser{},
Expand All @@ -211,13 +229,13 @@ func (h *Handler) createAuthenticateChallengeHandle(w http.ResponseWriter, r *ht
AllowReuse: allowReuse,
UserVerificationRequirement: req.UserVerificationRequirement,
},
SSOClientRedirectURL: sso.WebMFARedirect,
SSOClientRedirectURL: ssoClientRedirectURL.String(),
})
if err != nil {
return nil, trace.Wrap(err)
}

return makeAuthenticateChallenge(chal), nil
return makeAuthenticateChallenge(chal, channelID), nil
}

// createAuthenticateChallengeWithTokenHandle creates and returns MFA authenticate challenges for the user defined in token.
Expand All @@ -235,7 +253,7 @@ func (h *Handler) createAuthenticateChallengeWithTokenHandle(w http.ResponseWrit
return nil, trace.Wrap(err)
}

return makeAuthenticateChallenge(chal), nil
return makeAuthenticateChallenge(chal, "" /*channelID*/), nil
}

type createRegisterChallengeWithTokenRequest struct {
Expand Down Expand Up @@ -581,7 +599,7 @@ func (h *Handler) checkMFARequired(ctx context.Context, req *isMFARequiredReques
}

// makeAuthenticateChallenge converts proto to JSON format.
func makeAuthenticateChallenge(protoChal *proto.MFAAuthenticateChallenge) *client.MFAAuthenticateChallenge {
func makeAuthenticateChallenge(protoChal *proto.MFAAuthenticateChallenge, ssoChannelID string) *client.MFAAuthenticateChallenge {
chal := &client.MFAAuthenticateChallenge{
TOTPChallenge: protoChal.GetTOTP() != nil,
}
Expand All @@ -590,6 +608,7 @@ func makeAuthenticateChallenge(protoChal *proto.MFAAuthenticateChallenge) *clien
}
if protoChal.GetSSOChallenge() != nil {
chal.SSOChallenge = client.SSOChallengeFromProto(protoChal.GetSSOChallenge())
chal.SSOChallenge.ChannelID = ssoChannelID
}
return chal
}
Loading
Loading