diff --git a/rsc/src/main/java/com/cloudera/livy/rsc/RSCClient.java b/rsc/src/main/java/com/cloudera/livy/rsc/RSCClient.java index 564639e89..1835c7b37 100644 --- a/rsc/src/main/java/com/cloudera/livy/rsc/RSCClient.java +++ b/rsc/src/main/java/com/cloudera/livy/rsc/RSCClient.java @@ -61,7 +61,8 @@ public class RSCClient implements LivyClient { private ContextInfo contextInfo; private volatile boolean isAlive; - private volatile String replState; + + private SessionStateListener stateListener; RSCClient(RSCConf conf, Promise ctx) throws IOException { this.conf = conf; @@ -94,6 +95,10 @@ public void onFailure(Throwable error) { isAlive = true; } + public void registerStateListener(SessionStateListener stateListener) { + this.stateListener = stateListener; + } + private synchronized void connectToContext(final ContextInfo info) throws Exception { this.contextInfo = info; @@ -291,13 +296,6 @@ public Future getReplJobResults() throws Exception { return deferredCall(new BaseProtocol.GetReplJobResults(), ReplJobResults.class); } - /** - * @return Return the repl state. If this's not connected to a repl session, it will return null. - */ - public String getReplState() { - return replState; - } - private class ClientProtocol extends BaseProtocol { JobHandleImpl submit(Job job) { @@ -393,7 +391,10 @@ private void handle(ChannelHandlerContext ctx, JobStarted msg) { private void handle(ChannelHandlerContext ctx, ReplState msg) { LOG.trace("Received repl state for {}", msg.state); - replState = msg.state; + + if (stateListener != null) { + stateListener.onStateUpdated(msg.state); + } } } } diff --git a/rsc/src/main/java/com/cloudera/livy/rsc/SessionStateListener.java b/rsc/src/main/java/com/cloudera/livy/rsc/SessionStateListener.java new file mode 100644 index 000000000..446dc2463 --- /dev/null +++ b/rsc/src/main/java/com/cloudera/livy/rsc/SessionStateListener.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.cloudera.livy.rsc; + +public interface SessionStateListener { + + /** + * Action when state is updated. + */ + void onStateUpdated(String state); +} diff --git a/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSession.scala b/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSession.scala index 3bcdc108b..9bc83f0a2 100644 --- a/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSession.scala +++ b/server/src/main/scala/com/cloudera/livy/server/interactive/InteractiveSession.scala @@ -38,7 +38,7 @@ import org.apache.spark.launcher.SparkLauncher import com.cloudera.livy._ import com.cloudera.livy.client.common.HttpMessages._ -import com.cloudera.livy.rsc.{PingJob, RSCClient, RSCConf} +import com.cloudera.livy.rsc.{PingJob, RSCClient, RSCConf, SessionStateListener} import com.cloudera.livy.rsc.driver.Statement import com.cloudera.livy.server.recovery.SessionStore import com.cloudera.livy.sessions._ @@ -354,11 +354,12 @@ class InteractiveSession( mockApp: Option[SparkApp]) // For unit test. extends Session(id, owner, livyConf) with SessionHeartbeat - with SparkAppListener { + with SparkAppListener + with SessionStateListener { import InteractiveSession._ - private var serverSideState: SessionState = initialState + @volatile private var _state: SessionState = initialState override protected val heartbeatTimeout: FiniteDuration = { val heartbeatTimeoutInSecond = heartbeatTimeoutS @@ -373,6 +374,8 @@ class InteractiveSession( _appId = appIdHint sessionStore.save(RECOVERY_SESSION_TYPE, recoveryMetadata) heartbeat() + // Register this class to RSCClient as a session state listener + client.foreach(_.registerStateListener(this)) private val app = mockApp.orElse { if (livyConf.isRunningOnYarn()) { @@ -413,14 +416,14 @@ class InteractiveSession( override def onJobFailed(job: JobHandle[Void], cause: Throwable): Unit = errorOut() override def onJobSucceeded(job: JobHandle[Void], result: Void): Unit = { - transition(SessionState.Running()) + transition(SessionState.Idle()) } private def errorOut(): Unit = { // Other code might call stop() to close the RPC channel. When RPC channel is closing, // this callback might be triggered. Check and don't call stop() to avoid nested called // if the session is already shutting down. - if (serverSideState != SessionState.ShuttingDown()) { + if (_state != SessionState.ShuttingDown()) { transition(SessionState.Error()) stop() app.foreach { a => @@ -438,17 +441,7 @@ class InteractiveSession( InteractiveRecoveryMetadata( id, appId, appTag, kind, heartbeatTimeout.toSeconds.toInt, owner, proxyUser, rscDriverUri) - override def state: SessionState = { - if (serverSideState.isInstanceOf[SessionState.Running]) { - // If session is in running state, return the repl state from RSCClient. - client - .flatMap(s => Option(s.getReplState)) - .map(SessionState(_)) - .getOrElse(SessionState.Busy()) // If repl state is unknown, assume repl is busy. - } else { - serverSideState - } - } + override def state: SessionState = _state override def stopSession(): Unit = { try { @@ -548,24 +541,24 @@ class InteractiveSession( // If the session crashed because of the error, the session should instead go to dead state. // Since these 2 transitions are triggered by different threads, there's a race condition. // Make sure we won't transit from dead to error state. - val areSameStates = serverSideState.getClass() == newState.getClass() - val transitFromInactiveToActive = !serverSideState.isActive && newState.isActive + val areSameStates = _state.getClass() == newState.getClass() + val transitFromInactiveToActive = !_state.isActive && newState.isActive if (!areSameStates && !transitFromInactiveToActive) { - debug(s"$this session state change from ${serverSideState} to $newState") - serverSideState = newState + debug(s"$this session state change from ${_state} to $newState") + _state = newState } } private def ensureActive(): Unit = synchronized { - require(serverSideState.isActive, "Session isn't active.") + require(_state.isActive, "Session isn't active.") require(client.isDefined, "Session is active but client hasn't been created.") } private def ensureRunning(): Unit = synchronized { - serverSideState match { - case SessionState.Running() => + _state match { + case SessionState.Idle() | SessionState.Busy() => Unit case _ => - throw new IllegalStateException("Session is in state %s" format serverSideState) + throw new IllegalStateException(s"Session is in state ${_state}") } } @@ -597,4 +590,6 @@ class InteractiveSession( } override def infoChanged(appInfo: AppInfo): Unit = { this.appInfo = appInfo } + + override def onStateUpdated(state: String): Unit = { transition(SessionState(state)) } }