From e9d685e02018d44f0fc84cdbf65f984150d2e0c4 Mon Sep 17 00:00:00 2001 From: Stuart Douglas Date: Fri, 4 Oct 2024 15:55:16 +0100 Subject: [PATCH] fix: check Destination url against the request URL or the ACS url, instead of just the ACS url --- service_provider.go | 25 ++++--- service_provider_test.go | 155 ++++++++++++++++++++++++++------------- 2 files changed, 119 insertions(+), 61 deletions(-) diff --git a/service_provider.go b/service_provider.go index 30b35670..cbd398b0 100644 --- a/service_provider.go +++ b/service_provider.go @@ -652,7 +652,7 @@ func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID retErr.PrivateErr = fmt.Errorf("Error during artifact resolution: %s", err) return nil, retErr } - assertion, err := sp.ParseXMLArtifactResponse(responseBody, possibleRequestIDs, artifactResolveRequest.ID) + assertion, err := sp.ParseXMLArtifactResponse(responseBody, possibleRequestIDs, artifactResolveRequest.ID, *req.URL) if err != nil { return nil, err } @@ -670,7 +670,7 @@ func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestI return nil, retErr } - assertion, err := sp.ParseXMLResponse(rawResponseBuf, possibleRequestIDs) + assertion, err := sp.ParseXMLResponse(rawResponseBuf, possibleRequestIDs, *req.URL) if err != nil { return nil, err } @@ -687,7 +687,7 @@ func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestI // properties are useful in describing which part of the parsing process // failed. However, to discourage inadvertent disclosure the diagnostic // information, the Error() method returns a static string. -func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, possibleRequestIDs []string, artifactRequestID string) (*Assertion, error) { +func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, possibleRequestIDs []string, artifactRequestID string, currentURL url.URL) (*Assertion, error) { now := TimeNow() retErr := &InvalidResponseError{ Response: string(soapResponseXML), @@ -727,10 +727,10 @@ func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, poss return nil, retErr } - return sp.parseArtifactResponse(artifactResponseEl, possibleRequestIDs, artifactRequestID, now) + return sp.parseArtifactResponse(artifactResponseEl, possibleRequestIDs, artifactRequestID, now, currentURL) } -func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Element, possibleRequestIDs []string, artifactRequestID string, now time.Time) (*Assertion, error) { +func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Element, possibleRequestIDs []string, artifactRequestID string, now time.Time, currentURL url.URL) (*Assertion, error) { retErr := &InvalidResponseError{ Now: now, Response: elementToString(artifactResponseEl), @@ -778,7 +778,7 @@ func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Eleme return nil, retErr } - assertion, err := sp.parseResponse(responseEl, possibleRequestIDs, now, signatureRequirement) + assertion, err := sp.parseResponse(responseEl, possibleRequestIDs, now, signatureRequirement, currentURL) if err != nil { retErr.PrivateErr = err return nil, retErr @@ -798,7 +798,7 @@ func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Eleme // properties are useful in describing which part of the parsing process // failed. However, to discourage inadvertent disclosure the diagnostic // information, the Error() method returns a static string. -func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleRequestIDs []string) (*Assertion, error) { +func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleRequestIDs []string, currentURL url.URL) (*Assertion, error) { now := TimeNow() var err error retErr := &InvalidResponseError{ @@ -822,7 +822,7 @@ func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleR return nil, retErr } - assertion, err := sp.parseResponse(doc.Root(), possibleRequestIDs, now, signatureRequired) + assertion, err := sp.parseResponse(doc.Root(), possibleRequestIDs, now, signatureRequired, currentURL) if err != nil { retErr.PrivateErr = err return nil, retErr @@ -844,7 +844,7 @@ const ( // This function handles decrypting the message, verifying the digital // signature on the assertion, and verifying that the specified conditions // and properties are met. -func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) { +func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement, currentURL url.URL) (*Assertion, error) { var responseSignatureErr error var responseHasSignature bool if signatureRequirement == signatureRequired { @@ -867,8 +867,11 @@ func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequ // If the response is *not* signed, the Destination may be omitted. if responseHasSignature || response.Destination != "" { - if response.Destination != sp.AcsURL.String() { - return nil, fmt.Errorf("`Destination` does not match AcsURL (expected %q, actual %q)", sp.AcsURL.String(), response.Destination) + // Per section 3.4.5.2 of the SAML spec, Destination must match the location at which the response was received, i.e. currentURL. + // Historically, we checked against the SP's ACS URL instead of currentURL, which is usually the same but may differ in query params. + // To mitigate the risk of switching to comparing against currentURL, we still allow it if the ACS URL matches, even if the current URL doesn't. + if response.Destination != currentURL.String() && response.Destination != sp.AcsURL.String() { + return nil, fmt.Errorf("`Destination` does not match requested URL or AcsURL (destination %q, requested %q, acs %q)", response.Destination, currentURL.String(), sp.AcsURL.String()) } } diff --git a/service_provider_test.go b/service_provider_test.go index 4309738c..df90cd97 100644 --- a/service_provider_test.go +++ b/service_provider_test.go @@ -471,7 +471,7 @@ func TestSPCanHandleOneloginResponse(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-d40c15c104b52691eccf0a2a5c8a15595be75423"}) assert.Check(t, err) @@ -552,7 +552,7 @@ func TestSPCanHandleOktaSignedResponseEncryptedAssertion(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-a7364d1e4432aa9085a7a8bd824ea2fa8fa8f684"}) assert.Check(t, err) @@ -593,7 +593,7 @@ func TestSPCanHandleOktaResponseEncryptedSignedAssertion(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-6d976cdde8e76df5df0a8ff58148fc0b7ec6796d"}) assert.Check(t, err) @@ -634,7 +634,7 @@ func TestSPCanHandleOktaResponseEncryptedAssertionBothSigned(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-953d4cab69ff475c5901d12e585b0bb15a7b85fe"}) assert.Check(t, err) @@ -675,7 +675,7 @@ func TestSPCanHandlePlaintextResponse(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-fd419a5ab0472645427f8e07d87a3a5dd0b2e9a6"}) assert.Check(t, err) @@ -739,7 +739,7 @@ func TestSPRejectsInjectedComment(t *testing.T) { // this is a valid response { - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-fd419a5ab0472645427f8e07d87a3a5dd0b2e9a6"}) assert.Check(t, err) @@ -752,7 +752,7 @@ func TestSPRejectsInjectedComment(t *testing.T) { y := strings.Replace(string(x), "ross@octolabs.io", "ross@octolabs.io", 1) SamlResponse = []byte(base64.StdEncoding.EncodeToString([]byte(y))) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-fd419a5ab0472645427f8e07d87a3a5dd0b2e9a6"}) @@ -774,7 +774,7 @@ func TestSPRejectsInjectedComment(t *testing.T) { y := strings.Replace(string(x), "ross@octolabs.io", "ross@octolabs.io.example.com", 1) SamlResponse = []byte(base64.StdEncoding.EncodeToString([]byte(y))) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) _, err := s.ParseResponse(&req, []string{"id-fd419a5ab0472645427f8e07d87a3a5dd0b2e9a6"}) assert.Check(t, err != nil) @@ -797,7 +797,7 @@ func TestSPCanParseResponse(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, err) @@ -944,7 +944,7 @@ func TestSPCanProcessResponseWithoutDestination(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} test.replaceDestination("") req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) @@ -972,6 +972,13 @@ func removeDestinationFromDocument(doc *etree.Document) *etree.Document { return doc } +func overrideDestinationFromDocument(doc *etree.Document, newDestination string) *etree.Document { + responseEl := doc.FindElement("//Response") + destAttr := responseEl.SelectAttr("Destination") + destAttr.Value = newDestination + return doc +} + func TestServiceProviderMismatchedDestinationsWithSignaturePresent(t *testing.T) { test := NewServiceProviderTest(t) s := ServiceProvider{ @@ -984,13 +991,61 @@ func TestServiceProviderMismatchedDestinationsWithSignaturePresent(t *testing.T) err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} s.AcsURL = mustParseURL("https://wrong/saml2/acs") bytes, _ := test.responseDom(t).WriteToBytes() req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "`Destination` does not match AcsURL (expected \"https://wrong/saml2/acs\", actual \"https://15661444.ngrok.io/saml2/acs\")")) + "`Destination` does not match requested URL or AcsURL (destination \"https://15661444.ngrok.io/saml2/acs\", requested \"https://wrong/saml2/acs\", acs \"https://wrong/saml2/acs\")")) +} + +func TestDestinationMatchesCurrentUrlButNotAcsUrlWithSignaturePresent(t *testing.T) { + test := NewServiceProviderTest(t) + s := ServiceProvider{ + Key: test.Key, + Certificate: test.Certificate, + MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"), + AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"), + IDPMetadata: &EntityDescriptor{}, + } + err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) + assert.Check(t, err) + + currentUrl := mustParseURL("https://15661444.ngrok.io/saml2/acs?current=true") + req := http.Request{PostForm: url.Values{}, URL: ¤tUrl} + bytes, _ := overrideDestinationFromDocument(test.responseDom(t), "https://15661444.ngrok.io/saml2/acs?current=true").WriteToBytes() + req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) + assertion, err := s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) + if err != nil { + t.Logf("%s", err.(*InvalidResponseError).PrivateErr) + } + assert.Check(t, err) + assert.Check(t, is.Equal("_41bd295976dadd70e1480f318e772841", assertion.Subject.NameID.Value)) +} + +func TestDestinationMatchesAcsUrlButNotCurrentUrlWithSignaturePresent(t *testing.T) { + test := NewServiceProviderTest(t) + s := ServiceProvider{ + Key: test.Key, + Certificate: test.Certificate, + MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"), + AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"), + IDPMetadata: &EntityDescriptor{}, + } + err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) + assert.Check(t, err) + + currentUrl := mustParseURL("https://15661444.ngrok.io/saml2/acs?query=param") + req := http.Request{PostForm: url.Values{}, URL: ¤tUrl} + bytes, _ := test.responseDom(t).WriteToBytes() + req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) + assertion, err := s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) + if err != nil { + t.Logf("%s", err.(*InvalidResponseError).PrivateErr) + } + assert.Check(t, err) + assert.Check(t, is.Equal("_41bd295976dadd70e1480f318e772841", assertion.Subject.NameID.Value)) } func TestServiceProviderMissingDestinationWithSignaturePresent(t *testing.T) { @@ -1005,12 +1060,12 @@ func TestServiceProviderMissingDestinationWithSignaturePresent(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} bytes, _ := removeDestinationFromDocument(addSignatureToDocument(test.responseDom(t))).WriteToBytes() req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "`Destination` does not match AcsURL (expected \"https://15661444.ngrok.io/saml2/acs\", actual \"\")")) + "`Destination` does not match requested URL or AcsURL (destination \"\", requested \"https://15661444.ngrok.io/saml2/acs\", acs \"https://15661444.ngrok.io/saml2/acs\")")) } func TestSPMismatchedDestinationsWithSignaturePresent(t *testing.T) { @@ -1025,13 +1080,13 @@ func TestSPMismatchedDestinationsWithSignaturePresent(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} test.replaceDestination("https://wrong/saml2/acs") bytes, _ := addSignatureToDocument(test.responseDom(t)).WriteToBytes() req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "`Destination` does not match AcsURL (expected \"https://15661444.ngrok.io/saml2/acs\", actual \"https://wrong/saml2/acs\")")) + "`Destination` does not match requested URL or AcsURL (destination \"https://wrong/saml2/acs\", requested \"https://15661444.ngrok.io/saml2/acs\", acs \"https://15661444.ngrok.io/saml2/acs\")")) } func TestSPMismatchedDestinationsWithNoSignaturePresent(t *testing.T) { @@ -1046,13 +1101,13 @@ func TestSPMismatchedDestinationsWithNoSignaturePresent(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} test.replaceDestination("https://wrong/saml2/acs") bytes, _ := test.responseDom(t).WriteToBytes() req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "`Destination` does not match AcsURL (expected \"https://15661444.ngrok.io/saml2/acs\", actual \"https://wrong/saml2/acs\")")) + "`Destination` does not match requested URL or AcsURL (destination \"https://wrong/saml2/acs\", requested \"https://15661444.ngrok.io/saml2/acs\", acs \"https://15661444.ngrok.io/saml2/acs\")")) } func TestSPMissingDestinationWithSignaturePresent(t *testing.T) { @@ -1067,13 +1122,13 @@ func TestSPMissingDestinationWithSignaturePresent(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} test.replaceDestination("") bytes, _ := addSignatureToDocument(test.responseDom(t)).WriteToBytes() req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "`Destination` does not match AcsURL (expected \"https://15661444.ngrok.io/saml2/acs\", actual \"\")")) + "`Destination` does not match requested URL or AcsURL (destination \"\", requested \"https://15661444.ngrok.io/saml2/acs\", acs \"https://15661444.ngrok.io/saml2/acs\")")) } func TestSPInvalidAssertions(t *testing.T) { @@ -1187,7 +1242,7 @@ func TestXswPermutationOneIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"id-d40c15c104b52691eccf0a2a5c8a15595be75423"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, @@ -1214,7 +1269,7 @@ func TestXswPermutationTwoIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"id-d40c15c104b52691eccf0a2a5c8a15595be75423"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, @@ -1241,7 +1296,7 @@ func TestXswPermutationThreeIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) @@ -1273,7 +1328,7 @@ func TestXswPermutationFourIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) @@ -1303,7 +1358,7 @@ func TestXswPermutationFiveIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, @@ -1330,7 +1385,7 @@ func TestXswPermutationSixIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, @@ -1360,7 +1415,7 @@ func TestXswPermutationSevenIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) // It's the assertion signature that can't be verified. The error message is generic and always mentions Response @@ -1391,7 +1446,7 @@ func TestXswPermutationEightIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) // It's the assertion signature that can't be verified. The error message is generic and always mentions Response @@ -1422,7 +1477,7 @@ func TestXswPermutationNineIsRejected(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(respStr)) _, err = s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) // It's the assertion signature that can't be verified. The error message is generic and always mentions Response @@ -1449,7 +1504,7 @@ func TestSPRealWorldKeyInfoHasRSAPublicKeyNotX509Cert(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(respStr)) _, err = s.ParseResponse(&req, []string{"id-3992f74e652d89c3cf1efd6c7e472abaac9bc917"}) if err != nil { @@ -1480,7 +1535,7 @@ func TestSPRealWorldAssertionSignedNotResponse(t *testing.T) { err := xml.Unmarshal(idpMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(respStr)) _, err = s.ParseResponse(&req, []string{"id-3992f74e652d89c3cf1efd6c7e472abaac9bc917"}) if err != nil { @@ -1519,7 +1574,7 @@ func TestServiceProviderCanHandleSignedAssertionsResponse(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} req.PostForm.Set("SAMLResponse", string(SamlResponse)) assertion, err := s.ParseResponse(&req, []string{"ONELOGIN_4fee3b046395c4e751011e97f8900b5273d56685"}) if err != nil { @@ -1582,7 +1637,7 @@ func TestSPResponseWithNoIssuer(t *testing.T) { err := xml.Unmarshal(test.IDPMetadata, &s.IDPMetadata) assert.Check(t, err) - req := http.Request{PostForm: url.Values{}} + req := http.Request{PostForm: url.Values{}, URL: &s.AcsURL} // Response with no (modified ServiceProviderTest.SamlResponse) samlResponse := golden.Get(t, "TestSPResponseWithNoIssuer_response") @@ -1695,7 +1750,7 @@ func TestParseXMLArtifactResponse(t *testing.T) { possibleReqIDs := []string{"id-f3c7bc7d626a4ededa6028b718e5252c6e770b94"} reqID := "id-218eb155248f7db7c85fe4e2709a3f17a70d09c7" - assertion, err := sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID) + assertion, err := sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID, sp.AcsURL) assert.Check(t, err) x, err := xml.Marshal(assertion) @@ -1727,7 +1782,7 @@ func TestParseBadXMLArtifactResponse(t *testing.T) { IDPMetadata: &EntityDescriptor{}, } - assertion, err := sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID) + assertion, err := sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "response Issuer does not match the IDP metadata (expected \"\")")) assert.Check(t, is.Nil(assertion)) @@ -1735,9 +1790,9 @@ func TestParseBadXMLArtifactResponse(t *testing.T) { err = xml.Unmarshal(test.IDPMetadata, &sp.IDPMetadata) assert.Check(t, err) - assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID) + assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, - "`Destination` does not match AcsURL (expected \"https://example.com/saml2/acs\", actual \"http://localhost:8000/saml/acs\")")) + "`Destination` does not match requested URL or AcsURL (destination \"http://localhost:8000/saml/acs\", requested \"https://example.com/saml2/acs\", acs \"https://example.com/saml2/acs\")")) assert.Check(t, is.Nil(assertion)) sp.AcsURL = mustParseURL("http://localhost:8000/saml/acs") @@ -1748,7 +1803,7 @@ func TestParseBadXMLArtifactResponse(t *testing.T) { return rv } - assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID) + assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "response IssueInstant expired at 2021-08-17 10:28:50.146 +0000 UTC")) assert.Check(t, is.Nil(assertion)) @@ -1763,38 +1818,38 @@ func TestParseBadXMLArtifactResponse(t *testing.T) { return rv } - assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID) + assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "cannot validate signature on ArtifactResponse: Cert is not valid at this time")) assert.Check(t, is.Nil(assertion)) Clock = dsig.NewFakeClockAt(TimeNow()) wrongReqID := "id-218eb155248f7db7c85fe4e2709a3f17a70d09c8" - assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, wrongReqID) + assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, wrongReqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "`InResponseTo` does not match the artifact request ID (expected id-218eb155248f7db7c85fe4e2709a3f17a70d09c8)")) assert.Check(t, is.Nil(assertion)) wrongPossibleReqIDs := []string{"id-f3c7bc7d626a4ededa6028b718e5252c6e770b95"} - assertion, err = sp.ParseXMLArtifactResponse(samlResponse, wrongPossibleReqIDs, reqID) + assertion, err = sp.ParseXMLArtifactResponse(samlResponse, wrongPossibleReqIDs, reqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "`InResponseTo` does not match any of the possible request IDs (expected [id-f3c7bc7d626a4ededa6028b718e5252c6e770b95])")) assert.Check(t, is.Nil(assertion)) // random other key sp.Key = mustParsePrivateKey(golden.Get(t, "key_2017.pem")).(*rsa.PrivateKey) - assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID) + assertion, err = sp.ParseXMLArtifactResponse(samlResponse, possibleReqIDs, reqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "failed to decrypt EncryptedAssertion: certificate does not match provided key")) assert.Check(t, is.Nil(assertion)) // no input - assertion, err = sp.ParseXMLArtifactResponse([]byte(""), possibleReqIDs, reqID) + assertion, err = sp.ParseXMLArtifactResponse([]byte(""), possibleReqIDs, reqID, sp.AcsURL) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "invalid xml: no root")) assert.Check(t, is.Nil(assertion)) - assertion, err = sp.ParseXMLArtifactResponse([]byte(""), []string{}) + assertion, err := sp.ParseXMLResponse([]byte(""), []string{}, mustParseURL("http://test.com")) assert.Check(t, is.Error(err.(*InvalidResponseError).PrivateErr, "invalid xml: no root")) assert.Check(t, is.Nil(assertion)) - assertion, err = sp.ParseXMLResponse([]byte("