[SPARK-19558][SQL] Add config key to register QueryExecutionListeners automatically.

This change adds a new SQL config key that is equivalent to SparkContext's
"spark.extraListeners", allowing users to register QueryExecutionListener
instances through the Spark configuration system instead of having to
explicitly do it in code.

The code used by SparkContext to implement the feature was refactored into
a helper method in the Utils class, and SQL's ExecutionListenerManager was
modified to use it to initialize listeners declared in the configuration.

Unit tests were added to verify all the new functionality.

Author: Marcelo Vanzin <[email protected]>

Closes apache#19309 from vanzin/SPARK-19558.
Marcelo Vanzin authored and gatorsmile committed Oct 10, 2017
1 parent bfc7e1f commit bd4eb9c
Showing 9 changed files with 216 additions and 40 deletions.
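
For illustration, here is how a listener could be attached purely through the new
configuration key once this change is in place. This is a minimal sketch:
com.example.MyListener is a hypothetical class implementing QueryExecutionListener
with a no-arg constructor.

    import org.apache.spark.sql.SparkSession

    // The builder reads the static conf and registers the listener automatically;
    // no call to spark.listenerManager.register(...) is needed.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("listener-config-demo")
      .config("spark.sql.queryExecutionListeners", "com.example.MyListener")
      .getOrCreate()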
38 changes: 5 additions & 33 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2344,41 +2344,13 @@ class SparkContext(config: SparkConf) extends Logging {
    * (e.g. after the web UI and event logging listeners have been registered).
    */
   private def setupAndStartListenerBus(): Unit = {
-    // Use reflection to instantiate listeners specified via `spark.extraListeners`
     try {
-      val listenerClassNames: Seq[String] =
-        conf.get("spark.extraListeners", "").split(',').map(_.trim).filter(_ != "")
-      for (className <- listenerClassNames) {
-        // Use reflection to find the right constructor
-        val constructors = {
-          val listenerClass = Utils.classForName(className)
-          listenerClass
-            .getConstructors
-            .asInstanceOf[Array[Constructor[_ <: SparkListenerInterface]]]
-        }
-        val constructorTakingSparkConf = constructors.find { c =>
-          c.getParameterTypes.sameElements(Array(classOf[SparkConf]))
-        }
-        lazy val zeroArgumentConstructor = constructors.find { c =>
-          c.getParameterTypes.isEmpty
-        }
-        val listener: SparkListenerInterface = {
-          if (constructorTakingSparkConf.isDefined) {
-            constructorTakingSparkConf.get.newInstance(conf)
-          } else if (zeroArgumentConstructor.isDefined) {
-            zeroArgumentConstructor.get.newInstance()
-          } else {
-            throw new SparkException(
-              s"$className did not have a zero-argument constructor or a" +
-                " single-argument constructor that accepts SparkConf. Note: if the class is" +
-                " defined inside of another Scala class, then its constructors may accept an" +
-                " implicit parameter that references the enclosing class; in this case, you must" +
-                " define the listener as a top-level class in order to prevent this extra" +
-                " parameter from breaking Spark's ability to find a valid constructor.")
-          }
-        }
-        listenerBus.addToSharedQueue(listener)
-        logInfo(s"Registered listener $className")
+      conf.get(EXTRA_LISTENERS).foreach { classNames =>
+        val listeners = Utils.loadExtensions(classOf[SparkListenerInterface], classNames, conf)
+        listeners.foreach { listener =>
+          listenerBus.addToSharedQueue(listener)
+          logInfo(s"Registered listener ${listener.getClass().getName()}")
+        }
       }
     } catch {
       case e: Exception =>
core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -419,4 +419,11 @@ package object config {
     .stringConf
     .toSequence
     .createWithDefault(Nil)
+
+  private[spark] val EXTRA_LISTENERS = ConfigBuilder("spark.extraListeners")
+    .doc("Class names of listeners to add to SparkContext during initialization.")
+    .stringConf
+    .toSequence
+    .createOptional
+
 }
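
The EXTRA_LISTENERS entry above only gives the long-standing "spark.extraListeners"
key a typed ConfigEntry, so user-facing usage is unchanged. A sketch, where
com.example.AuditListener stands in for any class extending SparkListenerInterface
with a no-arg or SparkConf constructor:

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setMaster("local[*]")
      .setAppName("extra-listeners-demo")
      .set("spark.extraListeners", "com.example.AuditListener")
    // The listener is instantiated and registered during SparkContext startup.
    val sc = new SparkContext(conf)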
57 changes: 56 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.util
 
 import java.io._
 import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
+import java.lang.reflect.InvocationTargetException
 import java.math.{MathContext, RoundingMode}
 import java.net._
 import java.nio.ByteBuffer
@@ -37,7 +38,7 @@ import scala.collection.Map
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 import scala.reflect.ClassTag
-import scala.util.Try
+import scala.util.{Failure, Success, Try}
 import scala.util.control.{ControlThrowable, NonFatal}
 import scala.util.matching.Regex
 
@@ -2687,6 +2688,60 @@ private[spark] object Utils extends Logging {
   def stringToSeq(str: String): Seq[String] = {
     str.split(",").map(_.trim()).filter(_.nonEmpty)
   }
+
+  /**
+   * Create instances of extension classes.
+   *
+   * The classes in the given list must:
+   * - Be sub-classes of the given base class.
+   * - Provide either a no-arg constructor, or a 1-arg constructor that takes a SparkConf.
+   *
+   * The constructors are allowed to throw "UnsupportedOperationException" if the extension
+   * does not want to be registered; this allows the implementations to check the Spark
+   * configuration (or other state) and decide they do not need to be added. A log message
+   * is printed in that case. Other exceptions are bubbled up.
+   */
+  def loadExtensions[T](extClass: Class[T], classes: Seq[String], conf: SparkConf): Seq[T] = {
+    classes.flatMap { name =>
+      try {
+        val klass = classForName(name)
+        require(extClass.isAssignableFrom(klass),
+          s"$name is not a subclass of ${extClass.getName()}.")
+
+        val ext = Try(klass.getConstructor(classOf[SparkConf])) match {
+          case Success(ctor) =>
+            ctor.newInstance(conf)
+
+          case Failure(_) =>
+            klass.getConstructor().newInstance()
+        }
+
+        Some(ext.asInstanceOf[T])
+      } catch {
+        case _: NoSuchMethodException =>
+          throw new SparkException(
+            s"$name did not have a zero-argument constructor or a" +
+              " single-argument constructor that accepts SparkConf. Note: if the class is" +
+              " defined inside of another Scala class, then its constructors may accept an" +
+              " implicit parameter that references the enclosing class; in this case, you must" +
+              " define the class as a top-level class in order to prevent this extra" +
+              " parameter from breaking Spark's ability to find a valid constructor.")
+
+        case e: InvocationTargetException =>
+          e.getCause() match {
+            case uoe: UnsupportedOperationException =>
+              logDebug(s"Extension $name not being initialized.", uoe)
+              logInfo(s"Extension $name not being initialized.")
+              None
+
+            case null => throw e
+
+            case cause => throw cause
+          }
+      }
+    }
+  }
+
 }
 
 private[util] object CallerContext extends Logging {
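
As the scaladoc for loadExtensions notes, an extension can veto its own registration
by throwing UnsupportedOperationException from its constructor. A minimal sketch of
that pattern; the config key spark.myapp.audit.enabled is hypothetical:

    import org.apache.spark.SparkConf
    import org.apache.spark.scheduler.SparkListener

    class ConditionalAuditListener(conf: SparkConf) extends SparkListener {
      // loadExtensions logs and skips this instance instead of failing the job;
      // any other constructor exception would be rethrown.
      if (!conf.getBoolean("spark.myapp.audit.enabled", defaultValue = false)) {
        throw new UnsupportedOperationException("audit listener disabled")
      }
    }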
core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -27,7 +27,7 @@ import org.scalatest.Matchers
 
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.internal.config.LISTENER_BUS_EVENT_QUEUE_CAPACITY
+import org.apache.spark.internal.config._
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.util.{ResetSystemProperties, RpcUtils}
 
@@ -446,13 +446,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
       classOf[FirehoseListenerThatAcceptsSparkConf],
       classOf[BasicJobCounter])
     val conf = new SparkConf().setMaster("local").setAppName("test")
-      .set("spark.extraListeners", listeners.map(_.getName).mkString(","))
+      .set(EXTRA_LISTENERS, listeners.map(_.getName))
     sc = new SparkContext(conf)
     sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1)
     sc.listenerBus.listeners.asScala
       .count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1)
     sc.listenerBus.listeners.asScala
-      .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1)
+        .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1)
   }
 
   test("add and remove listeners to/from LiveListenerBus queues") {
56 changes: 55 additions & 1 deletion core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -38,9 +38,10 @@ import org.apache.commons.math3.stat.inference.ChiSquareTest
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext}
+import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.ByteUnit
+import org.apache.spark.scheduler.SparkListener
 
 class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
 
@@ -1110,4 +1111,57 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
     Utils.tryWithSafeFinallyAndFailureCallbacks {}(catchBlock = {}, finallyBlock = {})
     TaskContext.unset
   }
+
+  test("load extensions") {
+    val extensions = Seq(
+      classOf[SimpleExtension],
+      classOf[ExtensionWithConf],
+      classOf[UnregisterableExtension]).map(_.getName())
+
+    val conf = new SparkConf(false)
+    val instances = Utils.loadExtensions(classOf[Object], extensions, conf)
+    assert(instances.size === 2)
+    assert(instances.count(_.isInstanceOf[SimpleExtension]) === 1)
+
+    val extWithConf = instances.find(_.isInstanceOf[ExtensionWithConf])
+      .map(_.asInstanceOf[ExtensionWithConf])
+      .get
+    assert(extWithConf.conf eq conf)
+
+    class NestedExtension { }
+
+    val invalid = Seq(classOf[NestedExtension].getName())
+    intercept[SparkException] {
+      Utils.loadExtensions(classOf[Object], invalid, conf)
+    }
+
+    val error = Seq(classOf[ExtensionWithError].getName())
+    intercept[IllegalArgumentException] {
+      Utils.loadExtensions(classOf[Object], error, conf)
+    }
+
+    val wrongType = Seq(classOf[ListenerImpl].getName())
+    intercept[IllegalArgumentException] {
+      Utils.loadExtensions(classOf[Seq[_]], wrongType, conf)
+    }
+  }
+
 }
+
+private class SimpleExtension
+
+private class ExtensionWithConf(val conf: SparkConf)
+
+private class UnregisterableExtension {
+
+  throw new UnsupportedOperationException()
+
+}
+
+private class ExtensionWithError {
+
+  throw new IllegalArgumentException()
+
+}
+
+private class ListenerImpl extends SparkListener
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -87,4 +87,12 @@ object StaticSQLConf {
"implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.")
.stringConf
.createOptional

val QUERY_EXECUTION_LISTENERS = buildStaticConf("spark.sql.queryExecutionListeners")
.doc("List of class names implementing QueryExecutionListener that will be automatically " +
"added to newly created sessions. The classes should have either a no-arg constructor, " +
"or a constructor that expects a SparkConf argument.")
.stringConf
.toSequence
.createOptional
}
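
A sketch of a listener that this key could register automatically. The method
signatures follow the QueryExecutionListener trait (compare the test listener near
the end of this diff); the class name is illustrative:

    import org.apache.spark.sql.execution.QueryExecution
    import org.apache.spark.sql.util.QueryExecutionListener

    // Has a no-arg constructor, satisfying the contract described in the doc text.
    class TimingListener extends QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
        println(s"$funcName succeeded in ${durationNs / 1e6} ms")

      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
        println(s"$funcName failed: ${exception.getMessage}")
    }

Setting spark.sql.queryExecutionListeners to the fully qualified name of such a class
attaches it to every newly created session.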
sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -266,7 +266,8 @@ abstract class BaseSessionStateBuilder(
    * This gets cloned from parent if available, otherwise a new instance is created.
    */
   protected def listenerManager: ExecutionListenerManager = {
-    parentState.map(_.listenerManager.clone()).getOrElse(new ExecutionListenerManager)
+    parentState.map(_.listenerManager.clone()).getOrElse(
+      new ExecutionListenerManager(session.sparkContext.conf))
   }
 
   /**
sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -22,9 +22,12 @@ import java.util.concurrent.locks.ReentrantReadWriteLock
 import scala.collection.mutable.ListBuffer
 import scala.util.control.NonFatal
 
+import org.apache.spark.SparkConf
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.internal.StaticSQLConf._
+import org.apache.spark.util.Utils
 
 /**
  * :: Experimental ::
@@ -72,7 +75,14 @@ trait QueryExecutionListener {
  */
 @Experimental
 @InterfaceStability.Evolving
-class ExecutionListenerManager private[sql] () extends Logging {
+class ExecutionListenerManager private extends Logging {
+
+  private[sql] def this(conf: SparkConf) = {
+    this()
+    conf.get(QUERY_EXECUTION_LISTENERS).foreach { classNames =>
+      Utils.loadExtensions(classOf[QueryExecutionListener], classNames, conf).foreach(register)
+    }
+  }
 
   /**
    * Registers the specified [[QueryExecutionListener]].
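
The new constructor only adds a config-driven path; programmatic registration through
the session's listener manager keeps working alongside it, e.g. (assuming an active
SparkSession named spark and a listener such as the TimingListener sketched earlier):

    spark.listenerManager.register(new TimingListener())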
sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala (new file)
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.util

import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.internal.StaticSQLConf._

class ExecutionListenerManagerSuite extends SparkFunSuite {

import CountingQueryExecutionListener._

test("register query execution listeners using configuration") {
val conf = new SparkConf(false)
.set(QUERY_EXECUTION_LISTENERS, Seq(classOf[CountingQueryExecutionListener].getName()))

val mgr = new ExecutionListenerManager(conf)
assert(INSTANCE_COUNT.get() === 1)
mgr.onSuccess(null, null, 42L)
assert(CALLBACK_COUNT.get() === 1)

val clone = mgr.clone()
assert(INSTANCE_COUNT.get() === 1)

clone.onSuccess(null, null, 42L)
assert(CALLBACK_COUNT.get() === 2)
}

}

private class CountingQueryExecutionListener extends QueryExecutionListener {

import CountingQueryExecutionListener._

INSTANCE_COUNT.incrementAndGet()

override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
CALLBACK_COUNT.incrementAndGet()
}

override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
CALLBACK_COUNT.incrementAndGet()
}

}

private object CountingQueryExecutionListener {

val CALLBACK_COUNT = new AtomicInteger()
val INSTANCE_COUNT = new AtomicInteger()

}
