diff --git a/otoroshi/app/auth/saml/SAMLClient.scala b/otoroshi/app/auth/saml/SAMLClient.scala index bb5009430a..53a0c25122 100644 --- a/otoroshi/app/auth/saml/SAMLClient.scala +++ b/otoroshi/app/auth/saml/SAMLClient.scala @@ -55,6 +55,8 @@ import scala.concurrent.{ExecutionContext, Future} import scala.jdk.CollectionConverters.{asScalaBufferConverter, asScalaSetConverter} import scala.util.Try +import java.util.zip.Deflater + case class SAMLModule(authConfig: SamlAuthModuleConfig) extends AuthModule { import SAMLModule._ @@ -777,10 +779,18 @@ object SAMLModule { nameIDPolicy.setFormat(samlConfig.nameIDFormat.value) request.setNameIDPolicy(nameIDPolicy) + request.setID("z" + UUID.randomUUID().toString) + val issuer = buildObject(Issuer.DEFAULT_ELEMENT_NAME).asInstanceOf[Issuer] issuer.setValue(samlConfig.issuer) request.setIssuer(issuer) + val subject = buildObject(Subject.DEFAULT_ELEMENT_NAME).asInstanceOf[Subject] + val nameID = buildObject(NameID.DEFAULT_ELEMENT_NAME).asInstanceOf[NameID] + nameID.setValue("z" + UUID.randomUUID().toString) + subject.setNameID(nameID) + request.setSubject(subject) + request.setIssueInstant(Instant.now()) signSAMLObject(env, samlConfig, request.asInstanceOf[RequestAbstractType]).map { @@ -822,8 +832,10 @@ object SAMLModule { val dom = marshaller.marshall(request) XMLHelper.writeNode(dom, stringWriter) + val body = stringWriter.toString + val deflatedBody = doDeflate(body.getBytes(StandardCharsets.UTF_8)) - org.apache.commons.codec.binary.Base64.encodeBase64String(stringWriter.toString.getBytes(StandardCharsets.UTF_8)) + org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(deflatedBody) } def decodeAndValidateSamlResponse( @@ -1013,6 +1025,21 @@ object SAMLModule { } } + def doDeflate(dataBytes: Array[Byte]): Array[Byte] = { + var compBufSize = 655316 + if (compBufSize < dataBytes.length + 5) { + compBufSize = dataBytes.length + 5 + } + val compBuf = new Array[Byte](compBufSize) + val compresser = new Deflater(9, true) + compresser.setInput(dataBytes) + compresser.finish + val compressedDataLength = compresser.deflate(compBuf) + val compressedData = new Array[Byte](compressedDataLength) + System.arraycopy(compBuf, 0, compressedData, 0, compressedDataLength) + compressedData + } + def decodeAssertionWithCertificate(response: Response, certificate: BasicX509Credential): Future[Unit] = { val resolverChain = new util.ArrayList[KeyInfoCredentialResolver]() resolverChain.add(new StaticKeyInfoCredentialResolver(certificate))