Skip to content

Commit

Permalink
add deflate method on saml request
Browse files Browse the repository at this point in the history
  • Loading branch information
Zwiterrion committed Jul 19, 2023
1 parent ab0df77 commit ed9f79b
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion otoroshi/app/auth/saml/SAMLClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ import scala.concurrent.{ExecutionContext, Future}
import scala.jdk.CollectionConverters.{asScalaBufferConverter, asScalaSetConverter}
import scala.util.Try

import java.util.zip.Deflater

case class SAMLModule(authConfig: SamlAuthModuleConfig) extends AuthModule {

import SAMLModule._
Expand Down Expand Up @@ -777,10 +779,18 @@ object SAMLModule {
nameIDPolicy.setFormat(samlConfig.nameIDFormat.value)
request.setNameIDPolicy(nameIDPolicy)

request.setID("z" + UUID.randomUUID().toString)

val issuer = buildObject(Issuer.DEFAULT_ELEMENT_NAME).asInstanceOf[Issuer]
issuer.setValue(samlConfig.issuer)
request.setIssuer(issuer)

val subject = buildObject(Subject.DEFAULT_ELEMENT_NAME).asInstanceOf[Subject]
val nameID = buildObject(NameID.DEFAULT_ELEMENT_NAME).asInstanceOf[NameID]
nameID.setValue("z" + UUID.randomUUID().toString)
subject.setNameID(nameID)
request.setSubject(subject)

request.setIssueInstant(Instant.now())

signSAMLObject(env, samlConfig, request.asInstanceOf[RequestAbstractType]).map {
Expand Down Expand Up @@ -822,8 +832,10 @@ object SAMLModule {
val dom = marshaller.marshall(request)

XMLHelper.writeNode(dom, stringWriter)
val body = stringWriter.toString
val deflatedBody = doDeflate(body.getBytes(StandardCharsets.UTF_8))

org.apache.commons.codec.binary.Base64.encodeBase64String(stringWriter.toString.getBytes(StandardCharsets.UTF_8))
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(deflatedBody)
}

def decodeAndValidateSamlResponse(
Expand Down Expand Up @@ -1013,6 +1025,21 @@ object SAMLModule {
}
}

def doDeflate(dataBytes: Array[Byte]): Array[Byte] = {
var compBufSize = 655316
if (compBufSize < dataBytes.length + 5) {
compBufSize = dataBytes.length + 5
}
val compBuf = new Array[Byte](compBufSize)
val compresser = new Deflater(9, true)
compresser.setInput(dataBytes)
compresser.finish
val compressedDataLength = compresser.deflate(compBuf)
val compressedData = new Array[Byte](compressedDataLength)
System.arraycopy(compBuf, 0, compressedData, 0, compressedDataLength)
compressedData
}

def decodeAssertionWithCertificate(response: Response, certificate: BasicX509Credential): Future[Unit] = {
val resolverChain = new util.ArrayList[KeyInfoCredentialResolver]()
resolverChain.add(new StaticKeyInfoCredentialResolver(certificate))
Expand Down

0 comments on commit ed9f79b

Please sign in to comment.