Skip to content

Commit

Permalink
feat(ec2): try to dispose SSM session state #5616
Browse files Browse the repository at this point in the history
## Problem
The remote connection through VSCode does not terminate the SSM session
on close consistently.

Note: We do not get an event when the remote window is closed.

## Solution
Leverage two best-effort approaches:
- only allow a single connection from the toolkit to any given EC2
Instance. If a customer attempts to open another remote window in an EC2
instance, we can use that as a sign to terminate the old session.
- on toolkit shutdown (deactivate), remote any sessions that are still
running.

## Implementation Details
- Implement `Ec2RemoteEnvManager` to manage the remote environments.
Behaves like a map of instance ids to sessions ids that most importantly
maintains the invariant that any deleted item has its session
terminated.
- Refactor `packages/core/src/awsService/ec2/commands.ts` and
`packages/core/src/awsService/ec2/activation.ts` to allow for state
tracking in `EC2ConnectionManager`. This change also gives us an
opportunity to improve the testing infrastructure for this code.
---

Co-authored-by: JadenSimon <[email protected]>
Co-authored-by: Justin M. Keyes <[email protected]>
Co-authored-by: Weinstock <[email protected]>
  • Loading branch information
4 people authored Oct 25, 2024
1 parent 7dabb7c commit 142761e
Show file tree
Hide file tree
Showing 17 changed files with 372 additions and 61 deletions.
22 changes: 14 additions & 8 deletions packages/core/src/awsService/ec2/activation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import * as vscode from 'vscode'
import { ExtContext } from '../../shared/extensions'
import { Commands } from '../../shared/vscode/commands2'
import { telemetry } from '../../shared/telemetry/telemetry'
import { Ec2InstanceNode } from './explorer/ec2InstanceNode'
import { Ec2InstanceNode, tryRefreshNode } from './explorer/ec2InstanceNode'
import { copyTextCommand } from '../../awsexplorer/commands/copyText'
import { Ec2Node } from './explorer/ec2ParentNode'
import {
Expand All @@ -15,13 +15,15 @@ import {
rebootInstance,
startInstance,
stopInstance,
refreshExplorer,
openLogDocument,
linkToLaunchInstance,
openLogDocument,
} from './commands'
import { Ec2ConnecterMap } from './connectionManagerMap'
import { ec2LogsScheme } from '../../shared/constants'
import { Ec2LogDocumentProvider } from './ec2LogDocumentProvider'

const connectionManagers = new Ec2ConnecterMap()

export async function activate(ctx: ExtContext): Promise<void> {
ctx.extensionContext.subscriptions.push(
vscode.workspace.registerTextDocumentContentProvider(ec2LogsScheme, new Ec2LogDocumentProvider())
Expand All @@ -30,7 +32,7 @@ export async function activate(ctx: ExtContext): Promise<void> {
Commands.register('aws.ec2.openTerminal', async (node?: Ec2InstanceNode) => {
await telemetry.ec2_connectToInstance.run(async (span) => {
span.record({ ec2ConnectionType: 'ssm' })
await openTerminal(node)
await openTerminal(connectionManagers, node)
})
}),

Expand All @@ -42,30 +44,30 @@ export async function activate(ctx: ExtContext): Promise<void> {
}),

Commands.register('aws.ec2.openRemoteConnection', async (node?: Ec2Node) => {
await openRemoteConnection(node)
await openRemoteConnection(connectionManagers, node)
}),

Commands.register('aws.ec2.startInstance', async (node?: Ec2Node) => {
await telemetry.ec2_changeState.run(async (span) => {
span.record({ ec2InstanceState: 'start' })
await startInstance(node)
refreshExplorer(node)
await tryRefreshNode(node)
})
}),

Commands.register('aws.ec2.stopInstance', async (node?: Ec2Node) => {
await telemetry.ec2_changeState.run(async (span) => {
span.record({ ec2InstanceState: 'stop' })
await stopInstance(node)
refreshExplorer(node)
await tryRefreshNode(node)
})
}),

Commands.register('aws.ec2.rebootInstance', async (node?: Ec2Node) => {
await telemetry.ec2_changeState.run(async (span) => {
span.record({ ec2InstanceState: 'reboot' })
await rebootInstance(node)
refreshExplorer(node)
await tryRefreshNode(node)
})
}),

Expand All @@ -76,3 +78,7 @@ export async function activate(ctx: ExtContext): Promise<void> {
})
)
}

export async function deactivate(): Promise<void> {
connectionManagers.forEach(async (manager) => await manager.dispose())
}
23 changes: 6 additions & 17 deletions packages/core/src/awsService/ec2/commands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,25 @@
*/
import { Ec2InstanceNode } from './explorer/ec2InstanceNode'
import { Ec2Node } from './explorer/ec2ParentNode'
import { Ec2ConnectionManager } from './model'
import { Ec2Prompter, instanceFilter, Ec2Selection } from './prompter'
import { SafeEc2Instance, Ec2Client } from '../../shared/clients/ec2Client'
import { copyToClipboard } from '../../shared/utilities/messages'
import { getLogger } from '../../shared/logger'
import { ec2LogSchema } from './ec2LogDocumentProvider'
import { getAwsConsoleUrl } from '../../shared/awsConsole'
import { showRegionPrompter } from '../../auth/utils'
import { openUrl } from '../../shared/utilities/vsCodeUtils'
import { showFile } from '../../shared/utilities/textDocumentUtilities'
import { Ec2ConnecterMap } from './connectionManagerMap'
import { Ec2Prompter, Ec2Selection, instanceFilter } from './prompter'

export function refreshExplorer(node?: Ec2Node) {
if (node) {
const n = node instanceof Ec2InstanceNode ? node.parent : node
n.refreshNode().catch((e) => {
getLogger().error('refreshNode failed: %s', (e as Error).message)
})
}
}

export async function openTerminal(node?: Ec2Node) {
export async function openTerminal(connectionManagers: Ec2ConnecterMap, node?: Ec2Node) {
const selection = await getSelection(node)

const connectionManager = new Ec2ConnectionManager(selection.region)
const connectionManager = connectionManagers.getOrInit(selection.region)
await connectionManager.attemptToOpenEc2Terminal(selection)
}

export async function openRemoteConnection(node?: Ec2Node) {
export async function openRemoteConnection(connectionManagers: Ec2ConnecterMap, node?: Ec2Node) {
const selection = await getSelection(node)
const connectionManager = new Ec2ConnectionManager(selection.region)
const connectionManager = connectionManagers.getOrInit(selection.region)
await connectionManager.tryOpenRemoteConnection(selection)
}

Expand Down
26 changes: 26 additions & 0 deletions packages/core/src/awsService/ec2/connectionManagerMap.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*!
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

import { getLogger } from '../../shared'
import { Ec2Connecter } from './model'

export class Ec2ConnecterMap extends Map<string, Ec2Connecter> {
private static warnSize: number = 25

public getOrInit(regionCode: string) {
return this.has(regionCode) ? this.get(regionCode)! : this.initManager(regionCode)
}

private initManager(regionCode: string): Ec2Connecter {
if (this.size >= Ec2ConnecterMap.warnSize) {
getLogger().warn(
`Connection manager exceeded threshold of ${Ec2ConnecterMap.warnSize} with ${this.size} active connections`
)
}
const newConnectionManager = new Ec2Connecter(regionCode)
this.set(regionCode, newConnectionManager)
return newConnectionManager
}
}
14 changes: 13 additions & 1 deletion packages/core/src/awsService/ec2/explorer/ec2InstanceNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ import { SafeEc2Instance } from '../../../shared/clients/ec2Client'
import globals from '../../../shared/extensionGlobals'
import { getIconCode } from '../utils'
import { Ec2Selection } from '../prompter'
import { Ec2ParentNode } from './ec2ParentNode'
import { Ec2Node, Ec2ParentNode } from './ec2ParentNode'
import { EC2 } from 'aws-sdk'
import { getLogger } from '../../../shared'

export const Ec2InstanceRunningContext = 'awsEc2RunningNode'
export const Ec2InstanceStoppedContext = 'awsEc2StoppedNode'
Expand Down Expand Up @@ -101,3 +102,14 @@ export class Ec2InstanceNode extends AWSTreeNodeBase implements AWSResourceNode
await vscode.commands.executeCommand('aws.refreshAwsExplorerNode', this)
}
}

export async function tryRefreshNode(node?: Ec2Node) {
if (node) {
const n = node instanceof Ec2InstanceNode ? node.parent : node
try {
await n.refreshNode()
} catch (e) {
getLogger().error('refreshNode failed: %s', (e as Error).message)
}
}
}
47 changes: 27 additions & 20 deletions packages/core/src/awsService/ec2/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
*/
import * as vscode from 'vscode'
import { Session } from 'aws-sdk/clients/ssm'
import { IAM, SSM } from 'aws-sdk'
import { EC2, IAM, SSM } from 'aws-sdk'
import { Ec2Selection } from './prompter'
import { getOrInstallCli } from '../../shared/utilities/cliUtils'
import { isCloud9 } from '../../shared/extensionUtilities'
Expand All @@ -25,20 +25,24 @@ import { createBoundProcess } from '../../codecatalyst/model'
import { getLogger } from '../../shared/logger/logger'
import { CancellationError, Timeout } from '../../shared/utilities/timeoutUtils'
import { showMessageWithCancel } from '../../shared/utilities/messages'
import { SshConfig, sshLogFileLocation } from '../../shared/sshConfig'
import { SshConfig } from '../../shared/sshConfig'
import { SshKeyPair } from './sshKeyPair'
import { Ec2SessionTracker } from './remoteSessionManager'
import { getEc2SsmEnv } from './utils'

export type Ec2ConnectErrorCode = 'EC2SSMStatus' | 'EC2SSMPermission' | 'EC2SSMConnect' | 'EC2SSMAgentStatus'

interface Ec2RemoteEnv extends VscodeRemoteConnection {
export interface Ec2RemoteEnv extends VscodeRemoteConnection {
selection: Ec2Selection
keyPair: SshKeyPair
ssmSession: SSM.StartSessionResponse
}

export class Ec2ConnectionManager {
export class Ec2Connecter implements vscode.Disposable {
protected ssmClient: SsmClient
protected ec2Client: Ec2Client
protected iamClient: DefaultIamClient
protected sessionManager: Ec2SessionTracker

private policyDocumentationUri = vscode.Uri.parse(
'https://docs.aws.amazon.com/systems-manager/latest/userguide/session-manager-getting-started-instance-profile.html'
Expand All @@ -52,6 +56,7 @@ export class Ec2ConnectionManager {
this.ssmClient = this.createSsmSdkClient()
this.ec2Client = this.createEc2SdkClient()
this.iamClient = this.createIamSdkClient()
this.sessionManager = new Ec2SessionTracker(regionCode, this.ssmClient)
}

protected createSsmSdkClient(): SsmClient {
Expand All @@ -66,6 +71,18 @@ export class Ec2ConnectionManager {
return new DefaultIamClient(this.regionCode)
}

public async addActiveSession(sessionId: SSM.SessionId, instanceId: EC2.InstanceId): Promise<void> {
await this.sessionManager.addSession(instanceId, sessionId)
}

public async dispose(): Promise<void> {
await this.sessionManager.dispose()
}

public isConnectedTo(instanceId: string): boolean {
return this.sessionManager.isConnectedTo(instanceId)
}

public async getAttachedIamRole(instanceId: string): Promise<IAM.Role | undefined> {
const IamInstanceProfile = await this.ec2Client.getAttachedIamInstanceProfile(instanceId)
if (IamInstanceProfile && IamInstanceProfile.Arn) {
Expand Down Expand Up @@ -183,6 +200,7 @@ export class Ec2ConnectionManager {
this.throwGeneralConnectionError(selection, err as Error)
}
}

public async prepareEc2RemoteEnvWithProgress(selection: Ec2Selection, remoteUser: string): Promise<Ec2RemoteEnv> {
const timeout = new Timeout(60000)
await showMessageWithCancel('AWS: Opening remote connection...', timeout)
Expand All @@ -204,8 +222,10 @@ export class Ec2ConnectionManager {

throw err
}
const session = await this.ssmClient.startSession(selection.instanceId, 'AWS-StartSSHSession')
const vars = getEc2SsmEnv(selection, ssm, session)
const ssmSession = await this.ssmClient.startSession(selection.instanceId, 'AWS-StartSSHSession')
await this.addActiveSession(selection.instanceId, ssmSession.SessionId!)

const vars = getEc2SsmEnv(selection, ssm, ssmSession)
const envProvider = async () => {
return { [sshAgentSocketVariable]: await startSshAgent(), ...vars }
}
Expand All @@ -223,6 +243,7 @@ export class Ec2ConnectionManager {
SessionProcess,
selection,
keyPair,
ssmSession,
}
}

Expand Down Expand Up @@ -267,17 +288,3 @@ export class Ec2ConnectionManager {
throw new ToolkitError(`Unrecognized OS name ${osName} on instance ${instanceId}`, { code: 'UnknownEc2OS' })
}
}

function getEc2SsmEnv(selection: Ec2Selection, ssmPath: string, session: SSM.StartSessionResponse): NodeJS.ProcessEnv {
return Object.assign(
{
AWS_REGION: selection.region,
AWS_SSM_CLI: ssmPath,
LOG_FILE_LOCATION: sshLogFileLocation('ec2', selection.instanceId),
STREAM_URL: session.StreamUrl,
SESSION_ID: session.SessionId,
TOKEN: session.TokenValue,
},
process.env
)
}
8 changes: 8 additions & 0 deletions packages/core/src/awsService/ec2/prompter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import { isValidResponse } from '../../shared/wizards/wizard'
import { CancellationError } from '../../shared/utilities/timeoutUtils'
import { AsyncCollection } from '../../shared/utilities/asyncCollection'
import { getIconCode } from './utils'
import { Ec2Node } from './explorer/ec2ParentNode'
import { Ec2InstanceNode } from './explorer/ec2InstanceNode'

export type instanceFilter = (instance: SafeEc2Instance) => boolean
export interface Ec2Selection {
Expand Down Expand Up @@ -72,3 +74,9 @@ export class Ec2Prompter {
)
}
}

export async function getSelection(node?: Ec2Node, filter?: instanceFilter): Promise<Ec2Selection> {
const prompter = new Ec2Prompter(filter)
const selection = node && node instanceof Ec2InstanceNode ? node.toSelection() : await prompter.promptUser()
return selection
}
40 changes: 40 additions & 0 deletions packages/core/src/awsService/ec2/remoteSessionManager.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*!
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

import { EC2, SSM } from 'aws-sdk'
import { SsmClient } from '../../shared/clients/ssmClient'
import { Disposable } from 'vscode'

export class Ec2SessionTracker extends Map<EC2.InstanceId, SSM.SessionId> implements Disposable {
public constructor(
readonly regionCode: string,
protected ssmClient: SsmClient
) {
super()
}

public async addSession(instanceId: EC2.InstanceId, sessionId: SSM.SessionId): Promise<void> {
if (this.isConnectedTo(instanceId)) {
const existingSessionId = this.get(instanceId)!
await this.ssmClient.terminateSessionFromId(existingSessionId)
this.set(instanceId, sessionId)
} else {
this.set(instanceId, sessionId)
}
}

private async disconnectEnv(instanceId: EC2.InstanceId): Promise<void> {
await this.ssmClient.terminateSessionFromId(this.get(instanceId)!)
this.delete(instanceId)
}

public async dispose(): Promise<void> {
this.forEach(async (_sessionId, instanceId) => await this.disconnectEnv(instanceId))
}

public isConnectedTo(instanceId: EC2.InstanceId): boolean {
return this.has(instanceId)
}
}
26 changes: 26 additions & 0 deletions packages/core/src/awsService/ec2/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
*/

import { SafeEc2Instance } from '../../shared/clients/ec2Client'
import { copyToClipboard } from '../../shared/utilities/messages'
import { Ec2Selection } from './prompter'
import { sshLogFileLocation } from '../../shared/sshConfig'
import { SSM } from 'aws-sdk'

export function getIconCode(instance: SafeEc2Instance) {
if (instance.LastSeenStatus === 'running') {
Expand All @@ -16,3 +20,25 @@ export function getIconCode(instance: SafeEc2Instance) {

return 'loading~spin'
}

export async function copyInstanceId(instanceId: string): Promise<void> {
await copyToClipboard(instanceId, 'Id')
}

export function getEc2SsmEnv(
selection: Ec2Selection,
ssmPath: string,
session: SSM.StartSessionResponse
): NodeJS.ProcessEnv {
return Object.assign(
{
AWS_REGION: selection.region,
AWS_SSM_CLI: ssmPath,
LOG_FILE_LOCATION: sshLogFileLocation('ec2', selection.instanceId),
STREAM_URL: session.StreamUrl,
SESSION_ID: session.SessionId,
TOKEN: session.TokenValue,
},
process.env
)
}
Loading

0 comments on commit 142761e

Please sign in to comment.