Skip to content

Commit

Permalink
Support delta-sharing-capabilities header in DeltaSharingRestClient (#…
Browse files Browse the repository at this point in the history
…338)

* delta sharing server changes

* Delta Sharing Server to Support queryDeltaLog

* fix import

* fix tests

* actions work

* handle delta-sharing-capabilities header

* update header name

* Initialize DeltaSharingRestClient in its own file

* clean up imports

* further cleanup

* Support delta-sharing-capabilities header in DeltaSharingRestClient

* fix

* tests passed

* exception tests passed

* check responded format

* add comment

* add comment
  • Loading branch information
linzhou-db authored Jul 22, 2023
1 parent ef18d8d commit 228f22e
Show file tree
Hide file tree
Showing 10 changed files with 724 additions and 186 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ lazy val commonSettings = Seq(
)
)

lazy val root = (project in file(".")).aggregate(spark, server)
lazy val root = (project in file(".")).aggregate(client, spark, server)

lazy val client = (project in file("client")) settings(
name := "delta-sharing-client",
Expand Down
168 changes: 135 additions & 33 deletions client/src/main/scala/io/delta/sharing/client/DeltaSharingClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ class DeltaSharingRestClient(
numRetries: Int = 10,
maxRetryDuration: Long = Long.MaxValue,
sslTrustAll: Boolean = false,
forStreaming: Boolean = false) extends DeltaSharingClient {
forStreaming: Boolean = false,
responseFormat: String = DeltaSharingRestClient.RESPONSE_FORMAT_PARQUET
) extends DeltaSharingClient with Logging {

@volatile private var created = false

Expand Down Expand Up @@ -193,9 +195,10 @@ class DeltaSharingRestClient(
val target =
getTargetUrl(s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/" +
s"$encodedTableName/version$encodedParam")
val (version, _) = getResponse(new HttpGet(target), true, true)
val (version, _, _) = getResponse(new HttpGet(target), true, true)
version.getOrElse {
throw new IllegalStateException("Cannot find Delta-Table-Version in the header")
throw new IllegalStateException(s"Cannot find " +
s"${DeltaSharingRestClient.RESPONSE_TABLE_VERSION_HEADER_KEY} in the header")
}
}

Expand All @@ -205,7 +208,18 @@ class DeltaSharingRestClient(
val encodedTableName = URLEncoder.encode(table.name, "UTF-8")
val target = getTargetUrl(
s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/$encodedTableName/metadata")
val (version, lines) = getNDJson(target)
val (version, respondedFormat, lines) = getNDJson(target)
if (responseFormat != respondedFormat) {
// This could only happen when the asked format is delta and the server doesn't support
// the requested format.
logWarning(s"RespondedFormat($respondedFormat) is different from requested responseFormat(" +
s"$responseFormat) for getMetadata.${table.share}.${table.schema}.${table.name}.")
}
// To ensure that it works with delta sharing server that doesn't support the requested format.
if (respondedFormat == DeltaSharingRestClient.RESPONSE_FORMAT_DELTA) {
return DeltaTableMetadata(version, lines = lines)
}

val protocol = JsonUtils.fromJson[SingleAction](lines(0)).protocol
checkProtocol(protocol)
val metadata = JsonUtils.fromJson[SingleAction](lines(1)).metaData
Expand Down Expand Up @@ -235,7 +249,7 @@ class DeltaSharingRestClient(
val encodedTableName = URLEncoder.encode(table.name, "UTF-8")
val target = getTargetUrl(
s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/$encodedTableName/query")
val (version, lines) = getNDJson(
val (version, respondedFormat, lines) = getNDJson(
target,
QueryTableRequest(
predicates,
Expand All @@ -247,6 +261,15 @@ class DeltaSharingRestClient(
jsonPredicateHints
)
)
if (responseFormat != respondedFormat) {
logWarning(s"RespondedFormat($respondedFormat) is different from requested responseFormat(" +
s"$responseFormat) for getFiles(versionAsOf-$versionAsOf, timestampAsOf-$timestampAsOf " +
s"for table ${table.share}.${table.schema}.${table.name}.")
}
// To ensure that it works with delta sharing server that doesn't support the requested format.
if (respondedFormat == DeltaSharingRestClient.RESPONSE_FORMAT_DELTA) {
return DeltaTableFiles(version, lines = lines)
}
require(versionAsOf.isEmpty || versionAsOf.get == version)
val protocol = JsonUtils.fromJson[SingleAction](lines(0)).protocol
checkProtocol(protocol)
Expand All @@ -265,8 +288,17 @@ class DeltaSharingRestClient(
val encodedTableName = URLEncoder.encode(table.name, "UTF-8")
val target = getTargetUrl(
s"/shares/$encodedShareName/schemas/$encodedSchemaName/tables/$encodedTableName/query")
val (version, lines) = getNDJson(
val (version, respondedFormat, lines) = getNDJson(
target, QueryTableRequest(Nil, None, None, None, Some(startingVersion), endingVersion, None))
if (responseFormat != respondedFormat) {
logWarning(s"RespondedFormat($respondedFormat) is different from requested responseFormat(" +
s"$responseFormat) for getFiles(startingVersion-$startingVersion, endingVersion-" +
s"$endingVersion) for table ${table.share}.${table.schema}.${table.name}.")
}
// To ensure that it works with delta sharing server that doesn't support the requested format.
if (respondedFormat == DeltaSharingRestClient.RESPONSE_FORMAT_DELTA) {
return DeltaTableFiles(version, lines = lines)
}
val protocol = JsonUtils.fromJson[SingleAction](lines(0)).protocol
checkProtocol(protocol)
val metadata = JsonUtils.fromJson[SingleAction](lines(1)).metaData
Expand Down Expand Up @@ -300,7 +332,16 @@ class DeltaSharingRestClient(

val target = getTargetUrl(
s"/shares/$encodedShare/schemas/$encodedSchema/tables/$encodedTable/changes?$encodedParams")
val (version, lines) = getNDJson(target, requireVersion = false)
val (version, respondedFormat, lines) = getNDJson(target, requireVersion = false)
if (responseFormat != respondedFormat) {
logWarning(s"RespondedFormat($respondedFormat) is different from requested responseFormat(" +
s"$responseFormat) for getCDFFiles(cdfOptions-$cdfOptions) for table " +
s"${table.share}.${table.schema}.${table.name}.")
}
// To ensure that it works with delta sharing server that doesn't support the requested format.
if (respondedFormat == DeltaSharingRestClient.RESPONSE_FORMAT_DELTA) {
return DeltaTableFiles(version, lines = lines)
}
val protocol = JsonUtils.fromJson[SingleAction](lines(0)).protocol
checkProtocol(protocol)
val metadata = JsonUtils.fromJson[SingleAction](lines(1)).metaData
Expand Down Expand Up @@ -340,30 +381,60 @@ class DeltaSharingRestClient(
}.mkString("&")
}

// Sends an HTTP GET to `target` and returns the table version paired with the response lines.
// The version is the optional value produced by getResponse; when it is absent this throws
// IllegalStateException if `requireVersion` is true, and otherwise falls back to 0L.
private def getNDJson(target: String, requireVersion: Boolean = true): (Long, Seq[String]) = {
val (version, lines) = getResponse(new HttpGet(target))
version.getOrElse {
if (requireVersion) {
throw new IllegalStateException("Cannot find Delta-Table-Version in the header")
} else {
// Callers that pass requireVersion = false accept 0L as a placeholder version.
0L
}
} -> lines
}
/**
 * Sends an HTTP GET to `target` and returns the table version, the format the server responded
 * with, and the response lines.
 *
 * @param target         fully built request URL.
 * @param requireVersion when true, a response without the table version header is an error;
 *                       when false, a missing version is reported as 0L.
 * @return (version, responded format, response lines)
 * @throws IllegalStateException if the version header is missing and `requireVersion` is true.
 */
private def getNDJson(
    target: String, requireVersion: Boolean = true): (Long, String, Seq[String]) = {
  val (maybeVersion, maybeCapabilities, responseLines) = getResponse(new HttpGet(target))
  // Resolve the version first so a missing header fails before the capabilities are parsed.
  val resolvedVersion: Long = maybeVersion match {
    case Some(found) => found
    case None if requireVersion =>
      throw new IllegalStateException(s"Cannot find " +
        s"${DeltaSharingRestClient.RESPONSE_TABLE_VERSION_HEADER_KEY} in the header")
    case None => 0L
  }
  (resolvedVersion, getRespondedFormat(maybeCapabilities), responseLines)
}

// POST variant of getNDJson: serializes `data` to JSON with JsonUtils.toJson, sends it to `target`
// as an "application/json" body, and returns the table version together with the response lines
// (and, in the three-element form, the format derived from the response capabilities).
// Throws IllegalStateException when getResponse yields no table version.
// NOTE(review): the error message below hard-codes "Delta-Table-Version", while the GET variant
// builds it from DeltaSharingRestClient.RESPONSE_TABLE_VERSION_HEADER_KEY - consider aligning.
private def getNDJson[T: Manifest](target: String, data: T): (Long, Seq[String]) = {
private def getNDJson[T: Manifest](target: String, data: T): (Long, String, Seq[String]) = {
val httpPost = new HttpPost(target)
val json = JsonUtils.toJson(data)
httpPost.setHeader("Content-type", "application/json")
httpPost.setEntity(new StringEntity(json, UTF_8))
val (version, lines) = getResponse(httpPost)
version.getOrElse {
throw new IllegalStateException("Cannot find Delta-Table-Version in the header")
} -> lines
val (version, capabilities, lines) = getResponse(httpPost)
(
// Unlike the GET variant there is no requireVersion switch: a missing version always fails.
version.getOrElse {
throw new IllegalStateException("Cannot find Delta-Table-Version in the header")
},
getRespondedFormat(capabilities),
lines
)
}

/**
 * Extracts the response format from the capabilities header value returned by the server.
 *
 * @param capabilities raw value of the capabilities response header, if the header was present.
 * @return the value stored under the response-format key, or the parquet format when the header
 *         is absent or does not carry that key.
 */
private def getRespondedFormat(capabilities: Option[String]): String = {
  val parsedCapabilities = getDeltaSharingCapabilitiesMap(capabilities)
  parsedCapabilities.getOrElse(
    DeltaSharingRestClient.RESPONSE_FORMAT,
    DeltaSharingRestClient.RESPONSE_FORMAT_PARQUET
  )
}
/**
 * Parses the value of the delta-sharing-capabilities header into a key -> value map.
 *
 * The header value is a semicolon-separated list of capabilities, each written as
 * "key=value1,value2" (values separated by commas), for example
 * "capability1=value1;capability2=value3,value4,value5". Capabilities are therefore split
 * on ';' - splitting on ',' would cut a multi-valued capability apart and would fail to
 * separate two capabilities from each other.
 *
 * @param capabilities raw header value, or None when the header was not present.
 * @return lower-cased keys mapped to their lower-cased, still comma-joined, values. Entries
 *         that are not exactly "key=value", or whose key is empty, are ignored.
 */
private def getDeltaSharingCapabilitiesMap(capabilities: Option[String]): Map[String, String] = {
  capabilities.toSeq
    // Locale.ROOT keeps the lower-casing independent of the JVM default locale.
    .flatMap(_.toLowerCase(java.util.Locale.ROOT).split(";").toSeq)
    // Trim so that optional whitespace around ';' and '=' does not end up in keys or values.
    .map(_.split("=").map(_.trim))
    // Drop malformed entries instead of recording a placeholder "" -> "" pair.
    .collect { case Array(key, value) if key.nonEmpty => key -> value }
    .toMap
}

private def getJson[R: Manifest](target: String): R = {
val (_, response) = getResponse(new HttpGet(target), false, true)
val (_, _, response) = getResponse(new HttpGet(target), false, true)
if (response.size != 1) {
throw new IllegalStateException(
"Unexpected response for target: " + target + ", response=" + response
Expand Down Expand Up @@ -394,8 +465,7 @@ class DeltaSharingRestClient(
}
}

// TODO: [linzhou] mark this as private once tests are migrated.
def prepareHeaders(httpRequest: HttpRequestBase): HttpRequestBase = {
private[client] def prepareHeaders(httpRequest: HttpRequestBase): HttpRequestBase = {
val customeHeaders = profileProvider.getCustomHeaders
if (customeHeaders.contains(HttpHeaders.AUTHORIZATION)
|| customeHeaders.contains(HttpHeaders.USER_AGENT)) {
Expand All @@ -406,7 +476,8 @@ class DeltaSharingRestClient(
}
val headers = Map(
HttpHeaders.AUTHORIZATION -> s"Bearer ${profileProvider.getProfile.bearerToken}",
HttpHeaders.USER_AGENT -> getUserAgent()
HttpHeaders.USER_AGENT -> getUserAgent(),
DeltaSharingRestClient.DELTA_SHARING_CAPABILITIES_HEADER -> getDeltaSharingCapabilities()
) ++ customeHeaders
headers.foreach(header => httpRequest.setHeader(header._1, header._2))

Expand All @@ -426,7 +497,7 @@ class DeltaSharingRestClient(
httpRequest: HttpRequestBase,
allowNoContent: Boolean = false,
fetchAsOneString: Boolean = false
): (Option[Long], Seq[String]) = {
): (Option[Long], Option[String], Seq[String]) = {
RetryUtils.runWithExponentialBackoff(numRetries, maxRetryDuration) {
val profile = profileProvider.getProfile
val response = client.execute(
Expand Down Expand Up @@ -476,7 +547,15 @@ class DeltaSharingRestClient(
s"HTTP request failed with status: $status $responseToShow. $additionalErrorInfo",
statusCode)
}
Option(response.getFirstHeader("Delta-Table-Version")).map(_.getValue.toLong) -> lines
(
Option(
response.getFirstHeader(DeltaSharingRestClient.RESPONSE_TABLE_VERSION_HEADER_KEY)
).map(_.getValue.toLong),
Option(
response.getFirstHeader(DeltaSharingRestClient.DELTA_SHARING_CAPABILITIES_HEADER)
).map(_.getValue),
lines
)
} finally {
response.close()
}
Expand All @@ -494,6 +573,13 @@ class DeltaSharingRestClient(
s"$sparkAgent/$VERSION" + DeltaSharingRestClient.USER_AGENT
}

/**
 * Builds the value sent in the delta-sharing-capabilities request header.
 *
 * The header carries semicolon-separated capabilities, each in the form "key=value1,value2"
 * with comma-separated values, e.g. "capability1=value1;capability2=value3,value4,value5".
 * Only the response-format capability is emitted here.
 */
private def getDeltaSharingCapabilities(): String = {
  DeltaSharingRestClient.RESPONSE_FORMAT + "=" + responseFormat
}

def close(): Unit = {
if (created) {
try client.close() finally created = false
Expand All @@ -509,6 +595,11 @@ object DeltaSharingRestClient extends Logging {
val CURRENT = 1

val SPARK_STRUCTURED_STREAMING = "Delta-Sharing-SparkStructuredStreaming"
val DELTA_SHARING_CAPABILITIES_HEADER = "delta-sharing-capabilities"
val RESPONSE_TABLE_VERSION_HEADER_KEY = "Delta-Table-Version"
val RESPONSE_FORMAT = "responseformat"
val RESPONSE_FORMAT_DELTA = "delta"
val RESPONSE_FORMAT_PARQUET = "parquet"

lazy val USER_AGENT = {
try {
Expand Down Expand Up @@ -546,7 +637,11 @@ object DeltaSharingRestClient extends Logging {
if (value == null) "<unknown>" else value.replace(' ', '_')
}

def apply(profileFile: String, forStreaming: Boolean = false): DeltaSharingClient = {
def apply(
profileFile: String,
forStreaming: Boolean = false,
responseFormat: String = DeltaSharingRestClient.RESPONSE_FORMAT_PARQUET
): DeltaSharingClient = {
val sqlConf = SparkSession.active.sessionState.conf

val profileProviderClass = ConfUtils.profileProviderClass(sqlConf)
Expand All @@ -565,14 +660,21 @@ object DeltaSharingRestClient extends Logging {

val clientClass = ConfUtils.clientClass(sqlConf)
Class.forName(clientClass)
.getConstructor(classOf[DeltaSharingProfileProvider],
classOf[Int], classOf[Int], classOf[Long], classOf[Boolean], classOf[Boolean])
.newInstance(profileProvider,
.getConstructor(
classOf[DeltaSharingProfileProvider],
classOf[Int],
classOf[Int],
classOf[Long],
classOf[Boolean],
classOf[Boolean],
classOf[String]
).newInstance(profileProvider,
java.lang.Integer.valueOf(timeoutInSeconds),
java.lang.Integer.valueOf(numRetries),
java.lang.Long.valueOf(maxRetryDurationMillis),
java.lang.Boolean.valueOf(sslTrustAll),
java.lang.Boolean.valueOf(forStreaming))
.asInstanceOf[DeltaSharingClient]
java.lang.Boolean.valueOf(forStreaming),
responseFormat
).asInstanceOf[DeltaSharingClient]
}
}
12 changes: 7 additions & 5 deletions client/src/main/scala/io/delta/sharing/client/model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,20 @@ private[sharing] object CDFColumnInfo {

// Result of a table-metadata request.
// `protocol` and `metadata` default to null and `lines` to Nil so the value can be built from the
// raw response lines alone, which the REST client does when the server responds in delta format
// (DeltaTableMetadata(version, lines = lines)).
private[sharing] case class DeltaTableMetadata(
version: Long,
protocol: Protocol,
metadata: Metadata)
protocol: Protocol = null,
metadata: Metadata = null,
lines: Seq[String] = Nil)

// Result of a table-files or change-data request.
// Every field except `version` has a default, so the value can be built from the raw response
// lines alone, which the REST client does when the server responds in delta format
// (DeltaTableFiles(version, lines = lines)); in that case `protocol` and `metadata` stay null.
private[sharing] case class DeltaTableFiles(
version: Long,
protocol: Protocol,
metadata: Metadata,
protocol: Protocol = null,
metadata: Metadata = null,
files: Seq[AddFile] = Nil,
addFiles: Seq[AddFileForCDF] = Nil,
cdfFiles: Seq[AddCDCFile] = Nil,
removeFiles: Seq[RemoveFile] = Nil,
additionalMetadatas: Seq[Metadata] = Nil)
additionalMetadatas: Seq[Metadata] = Nil,
lines: Seq[String] = Nil)

private[sharing] case class Share(name: String)

Expand Down
Loading

0 comments on commit 228f22e

Please sign in to comment.