[Improve] spark-submit improvements (#3900)
wolfboys authored Jul 20, 2024
1 parent af65569 commit cd9db37
Showing 11 changed files with 175 additions and 399 deletions.
ChildFirstClassLoader.scala
@@ -15,13 +15,11 @@
* limitations under the License.
*/

package org.apache.streampark.flink.proxy
package org.apache.streampark.common.util

import java.io.{File, IOException}
import java.net.{URL, URLClassLoader}
import java.util
import java.util.function.Consumer
import java.util.regex.Pattern

import scala.util.Try

@@ -36,40 +34,12 @@ import scala.util.Try
class ChildFirstClassLoader(
urls: Array[URL],
parent: ClassLoader,
flinkResourcePattern: Pattern,
classLoadingExceptionHandler: Consumer[Throwable])
parentFirstClasses: List[String],
loadJarFilter: String => Boolean)
extends URLClassLoader(urls, parent) {

ClassLoader.registerAsParallelCapable()

def this(urls: Array[URL], parent: ClassLoader, flinkResourcePattern: Pattern) {
this(
urls,
parent,
flinkResourcePattern,
(t: Throwable) => throw t)
}

ClassLoader.registerAsParallelCapable()

private val FLINK_PATTERN =
Pattern.compile("flink-(.*).jar", Pattern.CASE_INSENSITIVE | Pattern.DOTALL)

private val JAR_PROTOCOL = "jar"

private val PARENT_FIRST_PATTERNS = List(
"java.",
"javax.xml",
"org.slf4j",
"org.apache.log4j",
"org.apache.logging",
"org.apache.commons.logging",
"org.apache.commons.cli",
"ch.qos.logback",
"org.xml",
"org.w3c",
"org.apache.hadoop")

@throws[ClassNotFoundException]
override def loadClass(name: String, resolve: Boolean): Class[_] = {
try {
@@ -78,7 +48,7 @@ class ChildFirstClassLoader(
super.findLoadedClass(name) match {
case null =>
// check whether the class should go parent-first
PARENT_FIRST_PATTERNS.find(name.startsWith) match {
parentFirstClasses.find(name.startsWith) match {
case Some(_) => super.loadClass(name, resolve)
case _ => Try(findClass(name)).getOrElse(super.loadClass(name, resolve))
}
@@ -90,9 +60,7 @@
}
}
} catch {
case e: Throwable =>
classLoadingExceptionHandler.accept(e)
null
case e: Throwable => throw e
}
}
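
The parent-first check is a plain prefix match over the injected list; a minimal illustration of the semantics (the class names are arbitrary examples):

val parentFirst = List("java.", "org.slf4j", "org.apache.hadoop")
parentFirst.find("org.slf4j.LoggerFactory".startsWith) // Some("org.slf4j") -> delegate to parent
parentFirst.find("com.example.UserJob".startsWith)     // None -> child (findClass) is tried first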

@@ -105,20 +73,14 @@
}

/**
* e.g. flinkResourcePattern: flink-1.12 <p> flink-1.12.jar/resource flink-1.14.jar/resource
* other.jar/resource \=> after filterFlinkShimsResource \=> flink-1.12.jar/resource
* other.jar/resource
*
* @param urlClassLoaderResource
* @return
*/
private def filterFlinkShimsResource(urlClassLoaderResource: URL): URL = {
if (urlClassLoaderResource != null && JAR_PROTOCOL == urlClassLoaderResource.getProtocol) {
private def filterResource(urlClassLoaderResource: URL): URL = {
if (urlClassLoaderResource != null && "jar" == urlClassLoaderResource.getProtocol) {
val spec = urlClassLoaderResource.getFile
val filename = new File(spec.substring(0, spec.indexOf("!/"))).getName
val matchState =
FLINK_PATTERN.matcher(filename).matches && !flinkResourcePattern.matcher(filename).matches
if (matchState) {
val jarName = new File(spec.substring(0, spec.indexOf("!/"))).getName
if (loadJarFilter(jarName)) {
return null
}
}
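
For reference, this is how the jar name falls out of a "jar:" resource URL, as a worked example (the path is hypothetical):

import java.io.File
import java.net.URL

val res = new URL("jar:file:/opt/libs/flink-table-1.14.0.jar!/org/apache/flink/Foo.class")
val spec = res.getFile // "file:/opt/libs/flink-table-1.14.0.jar!/org/apache/flink/Foo.class"
val jarName = new File(spec.substring(0, spec.indexOf("!/"))).getName // "flink-table-1.14.0.jar"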
@@ -127,7 +89,7 @@

private def addResources(result: util.List[URL], resources: util.Enumeration[URL]) = {
while (resources.hasMoreElements) {
val urlClassLoaderResource = filterFlinkShimsResource(resources.nextElement)
val urlClassLoaderResource = filterResource(resources.nextElement)
if (urlClassLoaderResource != null) {
result.add(urlClassLoaderResource)
}
@@ -145,7 +107,7 @@
addResources(result, parent.getResources(name))
}
new util.Enumeration[URL]() {
final private[proxy] val iter = result.iterator
private[this] val iter = result.iterator

override def hasMoreElements: Boolean = iter.hasNext

ClassLoaderObjectInputStream.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.streampark.flink.proxy
package org.apache.streampark.common.util

import java.io.{InputStream, IOException, ObjectInputStream, ObjectStreamClass}
import java.lang.reflect.Proxy
FlinkClientTrait.scala
@@ -64,7 +64,7 @@ trait FlinkClientTrait extends Logger {
|--------------------------------------- flink job start ---------------------------------------
| userFlinkHome : ${submitRequest.flinkVersion.flinkHome}
| flinkVersion : ${submitRequest.flinkVersion.version}
| appName : ${submitRequest.appName}
| appName : ${submitRequest.effectiveAppName}
| devMode : ${submitRequest.developmentMode.name()}
| execMode : ${submitRequest.executionMode.name()}
| k8sNamespace : ${submitRequest.kubernetesNamespace}
FlinkShimsProxy.scala
@@ -19,7 +19,7 @@ package org.apache.streampark.flink.proxy

import org.apache.streampark.common.Constant
import org.apache.streampark.common.conf.{ConfigKeys, FlinkVersion}
import org.apache.streampark.common.util.{ClassLoaderUtils, Logger}
import org.apache.streampark.common.util.{ChildFirstClassLoader, ClassLoaderObjectInputStream, ClassLoaderUtils, Logger}
import org.apache.streampark.common.util.ImplicitsUtils._

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, ObjectOutputStream}
@@ -35,13 +35,28 @@ object FlinkShimsProxy extends Logger {

private[this] val VERIFY_SQL_CLASS_LOADER_CACHE = MutableMap[String, ClassLoader]()

private[this] val FLINK_JAR_PATTERN = Pattern.compile("flink-(.*).jar", Pattern.CASE_INSENSITIVE | Pattern.DOTALL)

private[this] val INCLUDE_PATTERN: Pattern = Pattern.compile("(streampark-shaded-jackson-)(.*).jar", Pattern.CASE_INSENSITIVE | Pattern.DOTALL)

private[this] def getFlinkShimsResourcePattern(majorVersion: String) =
Pattern.compile(s"flink-(.*)-$majorVersion(.*).jar", Pattern.CASE_INSENSITIVE | Pattern.DOTALL)

private[this] lazy val FLINK_SHIMS_PREFIX = "streampark-flink-shims_flink"

private[this] lazy val PARENT_FIRST_PATTERNS = List(
"java.",
"javax.xml",
"org.slf4j",
"org.apache.log4j",
"org.apache.logging",
"org.apache.commons.logging",
"org.apache.commons.cli",
"ch.qos.logback",
"org.xml",
"org.w3c",
"org.apache.hadoop")

/**
* Get shimsClassLoader to execute for scala API
*
@@ -97,10 +112,16 @@
new ChildFirstClassLoader(
shimsUrls.toArray,
Thread.currentThread().getContextClassLoader,
getFlinkShimsResourcePattern(flinkVersion.majorVersion))
PARENT_FIRST_PATTERNS,
jarName => loadJarFilter(jarName, flinkVersion))
})
}

private def loadJarFilter(jarName: String, flinkVersion: FlinkVersion): Boolean = {
val childFirstPattern = getFlinkShimsResourcePattern(flinkVersion.majorVersion)
FLINK_JAR_PATTERN.matcher(jarName).matches && !childFirstPattern.matcher(jarName).matches
}
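
The net effect of loadJarFilter for, say, Flink 1.14: any flink-*.jar whose name does not carry the active major version is dropped from resource lookups, while non-Flink jars pass through. The regex behavior can be checked in isolation (the jar names are hypothetical):

import java.util.regex.Pattern

val flinkJar = Pattern.compile("flink-(.*).jar", Pattern.CASE_INSENSITIVE | Pattern.DOTALL)
val shims114 = Pattern.compile("flink-(.*)-1.14(.*).jar", Pattern.CASE_INSENSITIVE | Pattern.DOTALL)
def filteredOut(jar: String): Boolean =
  flinkJar.matcher(jar).matches && !shims114.matcher(jar).matches

filteredOut("flink-table-api-java-1.12.0.jar")    // true  -> skipped, wrong version
filteredOut("flink-table-api-java-1.14.3.jar")    // false -> kept
filteredOut("streampark-shaded-jackson-2.13.jar") // false -> not a flink jar, kept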

private def addShimsUrls(flinkVersion: FlinkVersion, addShimUrl: File => Unit): Unit = {
val appHome = System.getProperty(ConfigKeys.KEY_APP_HOME)
require(
@@ -177,7 +198,8 @@
new ChildFirstClassLoader(
shimsUrls.toArray,
Thread.currentThread().getContextClassLoader,
getFlinkShimsResourcePattern(flinkVersion.majorVersion))
PARENT_FIRST_PATTERNS,
jarName => loadJarFilter(jarName, flinkVersion))
})
}

SubmitRequest.scala
@@ -52,41 +52,37 @@ case class SubmitRequest(
@Nullable buildResult: BuildResult,
@Nullable extraParameter: JavaMap[String, Any]) {

val DEFAULT_SUBMIT_PARAM = Map[String, Any](
"spark.driver.cores" -> "1",
"spark.driver.memory" -> "1g",
"spark.executor.cores" -> "1",
"spark.executor.memory" -> "1g")

private[this] lazy val appProperties: Map[String, String] = getParameterMap(
KEY_SPARK_PROPERTY_PREFIX)

lazy val appMain: String = this.developmentMode match {
case FlinkDevelopmentMode.FLINK_SQL =>
Constant.STREAMPARK_SPARKSQL_CLIENT_CLASS
case FlinkDevelopmentMode.FLINK_SQL => Constant.STREAMPARK_SPARKSQL_CLIENT_CLASS
case _ => appProperties(KEY_FLINK_APPLICATION_MAIN_CLASS)
}

lazy val effectiveAppName: String =
if (this.appName == null) appProperties(KEY_FLINK_APP_NAME)
else this.appName
lazy val effectiveAppName: String = if (this.appName == null) {
appProperties(KEY_FLINK_APP_NAME)
} else {
this.appName
}

lazy val libs: List[URL] = {
val path = s"${Workspace.local.APP_WORKSPACE}/$id/lib"
Try(new File(path).listFiles().map(_.toURI.toURL).toList)
.getOrElse(List.empty[URL])
}

lazy val classPaths: List[URL] = sparkVersion.sparkLibs ++ libs

lazy val flinkSQL: String = extraParameter.get(KEY_FLINK_SQL()).toString

lazy val userJarFile: File = {
lazy val userJarPath: String = {
executionMode match {
case _ =>
checkBuildResult()
new File(buildResult.asInstanceOf[ShadedBuildResponse].shadedJarPath)
}
}

lazy val safePackageProgram: Boolean = {
sparkVersion.version.split("\\.").map(_.trim.toInt) match {
case Array(a, b, c) if a >= 3 => b > 1
case _ => false
buildResult.asInstanceOf[ShadedBuildResponse].shadedJarPath
}
}
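
The spark-submit defaults that previously lived in SparkConfiguration.defaultParameters now sit on the request itself as DEFAULT_SUBMIT_PARAM, and effectiveAppName resolves the display name. A hedged sketch of how a client might combine them (the submitRequest instance, property keys, and merge are illustrative, not part of this diff):

// Hypothetical user-supplied spark.* properties:
val userProps = Map("spark.executor.memory" -> "4g", "spark.app.name" -> "my-job")

// User values override the per-request defaults:
val effectiveConf: Map[String, Any] = submitRequest.DEFAULT_SUBMIT_PARAM ++ userProps

// Same fallback shape as effectiveAppName: explicit appName, else the configured property.
val name = Option(submitRequest.appName).getOrElse(userProps("spark.app.name"))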

@@ -133,7 +129,7 @@
}

@throws[IOException]
def isSymlink(file: File): Boolean = {
private def isSymlink(file: File): Boolean = {
if (file == null) throw new NullPointerException("File must not be null")
Files.isSymbolicLink(file.toPath)
}
@@ -163,7 +159,7 @@
}

@throws[Exception]
def checkBuildResult(): Unit = {
private def checkBuildResult(): Unit = {
executionMode match {
case _ =>
if (this.buildResult == null) {
SparkConfiguration.scala
@@ -17,11 +17,4 @@

package org.apache.streampark.spark.client.conf

object SparkConfiguration {
val defaultParameters = Map[String, Any](
"spark.driver.cores" -> "1",
"spark.driver.memory" -> "1g",
"spark.executor.cores" -> "1",
"spark.executor.memory" -> "1g")

}
object SparkConfiguration {}