Skip to content

Commit

Permalink
fix: all auth flows
Browse files Browse the repository at this point in the history
  • Loading branch information
aarlaud committed Jan 23, 2025
1 parent fb70022 commit c42ea6c
Show file tree
Hide file tree
Showing 21 changed files with 521 additions and 85 deletions.
18 changes: 18 additions & 0 deletions accept-server.local.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"//": "private refers to what's internal to snyk, i.e. the snyk.io server",
"private": [
{
"//": "send any type of request to our connected clients",
"method": "any",
"path": "/*"
}
],
"public": [
{
"//": "send any type of request to our connected clients",
"method": "any",
"path": "/*"
}
]
}

50 changes: 50 additions & 0 deletions lib/client/auth/brokerServerConnection.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { getConfig } from '../../common/config/config';
import { PostFilterPreparedRequest } from '../../common/relay/prepareRequest';
import version from '../../common/utils/version';
import {
HttpResponse,
makeRequestToDownstream,
} from '../../hybrid-sdk/http/request';
import { Role } from '../types/client';

export interface BrokerServerConnectionParams {
connectionIdentifier: string;
brokerClientId: string;
authorization: string;
role: Role;
serverId: number;
}
export const renewBrokerServerConnection = async (
brokerServerConnectionParams: BrokerServerConnectionParams,
): Promise<HttpResponse> => {
const clientConfig = getConfig();
const apiHostname = clientConfig.API_BASE_URL;
const body = {
data: {
type: 'broker_connection',
attributes: {
broker_client_id: brokerServerConnectionParams.brokerClientId,
},
},
};
const url = new URL(
`${apiHostname}/hidden/brokers/connections/${brokerServerConnectionParams.connectionIdentifier}/auth/refresh`,
);
url.searchParams.append('connection_role', brokerServerConnectionParams.role);
if (brokerServerConnectionParams.serverId) {
url.searchParams.append(
'serverId',
`${brokerServerConnectionParams.serverId}`,
);
}
const req: PostFilterPreparedRequest = {
url: url.toString(),
headers: {
authorization: brokerServerConnectionParams.authorization,
'user-agent': `Snyk Broker Client ${version}`,
},
method: 'POST',
body: JSON.stringify(body),
};
return await makeRequestToDownstream(req);
};
95 changes: 79 additions & 16 deletions lib/client/socket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { fetchJwt } from './auth/oauth';
import { getServerId } from './dispatcher';
import { determineFilterType } from './utils/filterSelection';
import { notificationHandler } from './socketHandlers/notificationHandler';
import { renewBrokerServerConnection } from './auth/brokerServerConnection';

export const createWebSocketConnectionPairs = async (
websocketConnections: WebSocketConnection[],
Expand Down Expand Up @@ -66,7 +67,6 @@ export const createWebSocketConnectionPairs = async (
} else {
logger.info(
{
connection: socketIdentifyingMetadata.friendlyName,
serverId: serverId,
},
'received server id',
Expand Down Expand Up @@ -97,6 +97,11 @@ export const createWebSocket = (
const localClientOps = Object.assign({}, clientOpts);
identifyingMetadata.identifier =
identifyingMetadata.identifier ?? localClientOps.config.brokerToken;
if (!identifyingMetadata.identifier) {
throw new Error(
`Invalid Broker Identifier/Token in websocket tunnel creation step.`,
);
}
const Socket = Primus.createSocket({
transformer: 'engine.io',
parser: 'EJSON',
Expand Down Expand Up @@ -162,30 +167,64 @@ export const createWebSocket = (
websocket.role = identifyingMetadata.role;

if (clientOpts.accessToken) {
let timeoutHandlerId;
let timeoutHandler = async () => {};
timeoutHandler = async () => {
logger.debug({}, 'Refreshing oauth access token');
clearTimeout(timeoutHandlerId);
clearTimeout(websocket.timeoutHandlerId);
clientOpts.accessToken = await fetchJwt(
clientOpts.config.API_BASE_URL,
clientOpts.config.brokerClientConfiguration.common.oauth!.clientId,
clientOpts.config.brokerClientConfiguration.common.oauth!.clientSecret,
);

websocket.transport.extraHeaders['Authorization'] =
clientOpts.accessToken!.authHeader;
// websocket.end();
// websocket.open();
timeoutHandlerId = setTimeout(
timeoutHandler,
(clientOpts.accessToken!.expiresIn - 60) * 1000,
);
websocket.transport.extraHeaders = {
Authorization: clientOpts.accessToken!.authHeader,
'x-snyk-broker-client-id': identifyingMetadata.clientId,
'x-snyk-broker-client-role': identifyingMetadata.role,
};

if (clientOpts.config.WS_TUNNEL_BOUNCE_ON_AUTH_REFRESH) {
websocket.end();
websocket.open();
} else {
logger.debug(
{
connection: `${identifyingMetadata.friendlyName}`,
role: identifyingMetadata.role,
},
'Renewing auth.',
);

const renewResponse = await renewBrokerServerConnection({
connectionIdentifier: identifyingMetadata.identifier!,
brokerClientId: identifyingMetadata.clientId,
authorization: clientOpts.accessToken!.authHeader,
role: identifyingMetadata.role,
serverId: serverId,
});
if (renewResponse.statusCode != 201) {
logger.debug(
{
connection: identifyingMetadata.identifier,
role: identifyingMetadata.role,
responseCode: renewResponse.statusCode,
},
'Failed to renew connection',
);
} else {
websocket.timeoutHandlerId = setTimeout(
timeoutHandler,
clientOpts.config.AUTH_EXPIRATION_OVERRIDE ??
(clientOpts.accessToken!.expiresIn - 60) * 1000,
);
}
}
};

timeoutHandlerId = setTimeout(
websocket.timeoutHandlerId = setTimeout(
timeoutHandler,
(clientOpts.accessToken!.expiresIn - 60) * 1000,
clientOpts.config.AUTH_EXPIRATION_OVERRIDE ??
(clientOpts.accessToken!.expiresIn - 60) * 1000,
);
}

Expand Down Expand Up @@ -235,9 +274,33 @@ export const createWebSocket = (
openHandler(websocket, localClientOps, identifyingMetadata),
);

websocket.on('close', () =>
closeHandler(localClientOps, identifyingMetadata),
);
websocket.on('service', (msg) => {
logger.info({ msg }, 'service message received');
});
// websocket.on('outgoing::open', function () {
// type OnErrorHandler = (type: string, code: number) => void;

// const originalErrorHandler: OnErrorHandler =
// websocket.socket.transport.onError;

// websocket.socket.transport.onError = (...args: [string, number]) => {
// const [type, code] = args; // Destructure for clarity
// if (code === 401) {
// logger.error({ type, code }, `Connection denied: unauthorized.`);
// } else {
// logger.error({ type, code }, `Transport error during polling.`);
// }
// originalErrorHandler.apply(websocket.socket?.transport, args);
// };
// });

websocket.on('close', () => {
if (websocket.timeoutHandlerId) {
logger.debug({}, `Clearing ${websocket.friendlyName} timers.`);
clearTimeout(websocket.timeoutHandlerId);
}
closeHandler(localClientOps, identifyingMetadata);
});

// only required if we're manually opening the connection
// websocket.open();
Expand Down
1 change: 1 addition & 0 deletions lib/client/types/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ export interface WebSocketConnection {
capabilities?: any;
on: (string, any) => any;
readyState: any;
timeoutHandlerId?: any;
}
// export interface WebSocketConnection {
// websocket: Connection;
Expand Down
1 change: 1 addition & 0 deletions lib/hybrid-sdk/clientRequestHelpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ export class HybridClientRequestHandler {
response: this.res,
streamBuffer,
streamSize: 0,
brokerAppClientId: this.res.locals.brokerAppClientId,
});
streamBuffer.pipe(this.res);
const simplifiedContextWithStreamingID = this.simplifiedContext;
Expand Down
5 changes: 4 additions & 1 deletion lib/hybrid-sdk/http/downstream-post-stream-to-server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,11 @@ class BrokerServerPostResponseHandler {

async #initHttpClientRequest() {
try {
const backendHostname = this.#config.universalBrokerEnabled
? `${this.#config.API_BASE_URL}/hidden/broker`
: this.#config.brokerServerUrl;
const url = new URL(
`${this.#config.brokerServerUrl}/response-data/${this.#brokerToken}/${
`${backendHostname}/response-data/${this.#brokerToken}/${
this.#streamingId
}`,
);
Expand Down
11 changes: 9 additions & 2 deletions lib/hybrid-sdk/http/server-post-stream-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ export interface StreamResponse {
streamBuffer: stream.PassThrough;
response: Response;
streamSize?: number;
brokerAppClientId: string;
}

export class StreamResponseHandler {
Expand All @@ -35,12 +36,18 @@ export class StreamResponseHandler {
streamingID,
streamResponse.streamBuffer,
streamResponse.response,
streamResponse.brokerAppClientId,
);
}

constructor(streamingID, streamBuffer, response) {
constructor(streamingID, streamBuffer, response, brokerAppClientId) {
this.streamingID = streamingID;
this.streamResponse = { streamBuffer, response, streamSize: 0 };
this.streamResponse = {
streamBuffer,
response,
streamSize: 0,
brokerAppClientId,
};
}

writeStatusAndHeaders = (statusAndHeaders) => {
Expand Down
46 changes: 46 additions & 0 deletions lib/server/auth/authHelpers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import { getConfig } from '../../common/config/config';
import { PostFilterPreparedRequest } from '../../common/relay/prepareRequest';
import { makeSingleRawRequestToDownstream } from '../../hybrid-sdk/http/request';
import { log as logger } from '../../logs/logger';

export const validateBrokerClientCredentials = async (
authHeaderValue: string,
brokerClientId: string,
brokerConnectionIdentifier: string,
) => {
const body = {
data: {
type: 'broker_connection',
attributes: {
broker_client_id: brokerClientId,
},
},
};

const req: PostFilterPreparedRequest = {
url: `${
getConfig().apiHostname
}/hidden/brokers/connections/${brokerConnectionIdentifier}/auth/validate?version=2024-02-08~experimental`,
headers: {
authorization: authHeaderValue,
'Content-type': 'application/vnd.api+json',
},
method: 'POST',
body: JSON.stringify(body),
};
logger.debug({ req }, `Validate Broker Client Credentials request`);
const response = await makeSingleRawRequestToDownstream(req);
logger.debug(
{ validationResponseCode: response.statusCode },
'Validate Broker Client Credentials response',
);
if (response.statusCode === 201) {
return true;
} else {
logger.debug(
{ statusCode: response.statusCode, message: response.statusText },
`Broker ${brokerConnectionIdentifier} client ID ${brokerClientId} failed validation.`,
);
return false;
}
};
30 changes: 30 additions & 0 deletions lib/server/auth/connectionWatchdog.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { getConfig } from '../../common/config/config';
import { getSocketConnections } from '../socket';
import { log as logger } from '../../logs/logger';

export const disconnectConnectionsWithStaleCreds = async () => {
const connections = getSocketConnections();
const connectionsIterator = connections.entries();
for (const [identifier, connection] of connectionsIterator) {
connection.forEach((client) => {
if (!isDateWithinAnHourAndFiveSec(client.credsValidationTime!)) {
logger.debug(
{
connection: `${identifier}`,
credsLastValidated: client.credsValidationTime,
},
'Cutting off connection.',
);
client.socket!.end();
}
});
}
};

const isDateWithinAnHourAndFiveSec = (date: string): boolean => {
const dateInMs = new Date(date); // Convert ISO string to Date
const now = Date.now(); // Get current time in milliseconds
const staleConnectionsCleanupInterval =
getConfig().STALE_CONNECTIONS_CLEANUP_FREQUENCY ?? 65 * 60 * 1000; // 1h05 hour in milliseconds
return now - dateInMs.getTime() < staleConnectionsCleanupInterval;
};
24 changes: 23 additions & 1 deletion lib/server/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import { getForwardHttpRequestHandler } from './socketHandlers/initHandlers';
import { loadAllFilters } from '../common/filter/filtersAsync';
import { FiltersType } from '../common/types/filter';
import filterRulesLoader from '../common/filter/filter-rules-loading';
import { authRefreshHandler } from './routesHandlers/authHandlers';
import { disconnectConnectionsWithStaleCreds } from './auth/connectionWatchdog';

export const main = async (serverOpts: ServerOpts) => {
logger.info({ version }, 'Broker starting in server mode');
Expand Down Expand Up @@ -57,21 +59,41 @@ export const main = async (serverOpts: ServerOpts) => {
app.use(applyPrometheusMiddleware());
}
app.get('/connection-status/:token', connectionStatusHandler);
app.post(
'/hidden/brokers/connections/:identifier/auth/refresh',
authRefreshHandler,
);
app.all(
'/broker/:token/*',
overloadHttpRequestWithConnectionDetailsMiddleware,
validateBrokerTypeMiddleware,
getForwardHttpRequestHandler(),
);

app.post('/response-data/:brokerToken/:streamingId', handlePostResponse);
if (
loadedServerOpts.config.BROKER_SERVER_MANDATORY_AUTH_ENABLED ||
loadedServerOpts.config.RESPONSE_DATA_HIDDEN_ENABLED
) {
app.post(
'/hidden/broker/response-data/:brokerToken/:streamingId',
handlePostResponse,
);
} else {
app.post('/response-data/:brokerToken/:streamingId', handlePostResponse);
}

app.get('/', (req, res) => res.status(200).json({ ok: true, version }));

app.get('/healthcheck', (req, res) =>
res.status(200).json({ ok: true, version }),
);

setInterval(
disconnectConnectionsWithStaleCreds,
loadedServerOpts.config.STALE_CONNECTIONS_CLEANUP_FREQUENCY ??
10 * 60 * 1000,
);

return {
websocket: websocket,
close: (done) => {
Expand Down
Loading

0 comments on commit c42ea6c

Please sign in to comment.