Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert unit tests to Typescript #258

Merged
merged 3 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { TProtocolVersion } from '../thrift/TCLIService_types';
import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } from './contracts/IDBSQLClient';
import IDriver from './contracts/IDriver';
import IClientContext, { ClientConfig } from './contracts/IClientContext';
import IThriftClient from './contracts/IThriftClient';
import HiveDriver from './hive/HiveDriver';
import DBSQLSession from './DBSQLSession';
import IDBSQLSession from './contracts/IDBSQLSession';
Expand Down Expand Up @@ -43,26 +44,28 @@ function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) {
};
}

export type ThriftLibrary = Pick<typeof thrift, 'createClient'>;

export default class DBSQLClient extends EventEmitter implements IDBSQLClient, IClientContext {
private static defaultLogger?: IDBSQLLogger;

private readonly config: ClientConfig;

private connectionProvider?: IConnectionProvider;
protected connectionProvider?: IConnectionProvider;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we test through the public interface. Do we need to expose these to get reasonable test coverage? I don't know the standards in the Node world, but in general my approach is, if there are methods that should be private, but I need to expose in order to test, I make a new type where those methods make sense as public methods, and have the original type call the new type to accomplish the work. This allows you to mock the new class in the tests of the original class, and not feel compelled to have a large public interface that consumers really shouldn't call. Given how big a change we are already doing, I won't block this PR on that, but it's something we should consider before we make another release of this library.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I took another look at these changes and totally agree with you. I'll revert all similar changes (private -> protected) and temporarily just suppress TS errors with @ts-expect-error

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done 275bdbf

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now I kept code that uses private fields and methods - this PR is already too big to add more refactoring. However, it turns out that TS allows to access private and protected fields using string indexing (obj['prop']) - TS still checks if property exists and keeps things type-safe, so I think it's still better than untyped code or with errors suppressed


private authProvider?: IAuthentication;
protected authProvider?: IAuthentication;

private client?: TCLIService.Client;
protected client?: IThriftClient;

private readonly driver = new HiveDriver({
context: this,
});

private readonly logger: IDBSQLLogger;

private readonly thrift = thrift;
protected thrift: ThriftLibrary = thrift;

private sessions = new CloseableCollection<DBSQLSession>();
protected sessions = new CloseableCollection<DBSQLSession>();

private static getDefaultLogger(): IDBSQLLogger {
if (!this.defaultLogger) {
Expand Down Expand Up @@ -99,7 +102,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
this.logger.log(LogLevel.info, 'Created DBSQLClient');
}

private getConnectionOptions(options: ConnectionOptions): IConnectionOptions {
protected getConnectionOptions(options: ConnectionOptions): IConnectionOptions {
return {
host: options.host,
port: options.port || 443,
Expand All @@ -113,7 +116,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
};
}

private initAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication {
protected createAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication {
if (authProvider) {
return authProvider;
}
Expand Down Expand Up @@ -143,6 +146,10 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
}
}

protected createConnectionProvider(options: ConnectionOptions): IConnectionProvider {
return new HttpConnection(this.getConnectionOptions(options), this);
}

/**
* Connects DBSQLClient to endpoint
* @public
Expand All @@ -153,9 +160,9 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
* const session = client.connect({host, path, token});
*/
public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise<IDBSQLClient> {
this.authProvider = this.initAuthProvider(options, authProvider);
this.authProvider = this.createAuthProvider(options, authProvider);

this.connectionProvider = new HttpConnection(this.getConnectionOptions(options), this);
this.connectionProvider = this.createConnectionProvider(options);

const thriftConnection = await this.connectionProvider.getThriftConnection();

Expand Down Expand Up @@ -238,7 +245,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
return this.connectionProvider;
}

public async getClient(): Promise<TCLIService.Client> {
public async getClient(): Promise<IThriftClient> {
const connectionProvider = await this.getConnectionProvider();

if (!this.client) {
Expand Down
20 changes: 10 additions & 10 deletions lib/DBSQLOperation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,23 @@ async function delay(ms?: number): Promise<void> {
}

export default class DBSQLOperation implements IOperation {
private readonly context: IClientContext;
protected readonly context: IClientContext;

private readonly operationHandle: TOperationHandle;
protected readonly operationHandle: TOperationHandle;

public onClose?: () => void;
protected readonly _data: RowSetProvider;

private readonly _data: RowSetProvider;
protected readonly closeOperation?: TCloseOperationResp;

private readonly closeOperation?: TCloseOperationResp;
protected closed: boolean = false;

private closed: boolean = false;
protected cancelled: boolean = false;

private cancelled: boolean = false;
protected metadata?: TGetResultSetMetadataResp;

private metadata?: TGetResultSetMetadataResp;
protected state: number = TOperationState.INITIALIZED_STATE;

private state: number = TOperationState.INITIALIZED_STATE;
public onClose?: () => void;

// Once operation is finished or fails - cache status response, because subsequent calls
// to `getOperationStatus()` may fail with irrelevant errors, e.g. HTTP 404
Expand Down Expand Up @@ -376,7 +376,7 @@ export default class DBSQLOperation implements IOperation {
return this.metadata;
}

private async getResultHandler(): Promise<ResultSlicer<any>> {
protected async getResultHandler(): Promise<ResultSlicer<any>> {
const metadata = await this.fetchMetadata();
const resultFormat = definedOrError(metadata.resultFormat);

Expand Down
6 changes: 3 additions & 3 deletions lib/DBSQLSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,11 @@ export default class DBSQLSession implements IDBSQLSession {

private readonly sessionHandle: TSessionHandle;

private isOpen = true;
protected isOpen = true;

public onClose?: () => void;
protected operations = new CloseableCollection<DBSQLOperation>();

private operations = new CloseableCollection<DBSQLOperation>();
public onClose?: () => void;

constructor({ handle, context }: DBSQLSessionConstructorOptions) {
this.sessionHandle = handle;
Expand Down
113 changes: 65 additions & 48 deletions lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,50 +6,26 @@ import { OAuthScopes, scopeDelimiter } from './OAuthScope';
import IClientContext from '../../../contracts/IClientContext';
import AuthenticationError from '../../../errors/AuthenticationError';

export type DefaultOpenAuthUrlFunction = (authUrl: string) => Promise<void>;

export type CustomOpenAuthUrlFunction = (
authUrl: string,
defaultOpenAuthUrl: DefaultOpenAuthUrlFunction,
) => Promise<void>;

export interface AuthorizationCodeOptions {
client: BaseClient;
ports: Array<number>;
context: IClientContext;
openAuthUrl?: CustomOpenAuthUrlFunction;
}

async function startServer(
host: string,
port: number,
requestHandler: (req: IncomingMessage, res: ServerResponse) => void,
): Promise<Server> {
const server = http.createServer(requestHandler);

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.listen(port, host, () => {
server.off('error', errorListener);
resolve(server);
});
});
async function defaultOpenAuthUrl(authUrl: string) {
await open(authUrl);
}

async function stopServer(server: Server): Promise<void> {
if (!server.listening) {
return;
}

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.close(() => {
server.off('error', errorListener);
resolve();
});
});
async function openAuthUrl(authUrl: string, defaultOpenUrl: DefaultOpenAuthUrlFunction) {
return defaultOpenUrl(authUrl);
}

export interface AuthorizationCodeFetchResult {
Expand All @@ -65,16 +41,12 @@ export default class AuthorizationCode {

private readonly host: string = 'localhost';

private readonly ports: Array<number>;
private readonly options: AuthorizationCodeOptions;

constructor(options: AuthorizationCodeOptions) {
this.client = options.client;
this.ports = options.ports;
this.context = options.context;
}

private async openUrl(url: string) {
return open(url);
this.options = options;
}

public async fetch(scopes: OAuthScopes): Promise<AuthorizationCodeFetchResult> {
Expand All @@ -84,7 +56,7 @@ export default class AuthorizationCode {

let receivedParams: CallbackParamsType | undefined;

const server = await this.startServer((req, res) => {
const server = await this.createServer((req, res) => {
const params = this.client.callbackParams(req);
if (params.state === state) {
receivedParams = params;
Expand All @@ -108,7 +80,8 @@ export default class AuthorizationCode {
redirect_uri: redirectUri,
});

await this.openUrl(authUrl);
const openUrl = this.options.openAuthUrl ?? openAuthUrl;
await openUrl(authUrl, defaultOpenAuthUrl);
await server.stopped();

if (!receivedParams || !receivedParams.code) {
Expand All @@ -122,11 +95,11 @@ export default class AuthorizationCode {
return { code: receivedParams.code, verifier: verifierString, redirectUri };
}

private async startServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
for (const port of this.ports) {
private async createServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
for (const port of this.options.ports) {
const host = this.host; // eslint-disable-line prefer-destructuring
try {
const server = await startServer(host, port, requestHandler); // eslint-disable-line no-await-in-loop
const server = await this.startServer(host, port, requestHandler); // eslint-disable-line no-await-in-loop
this.context.getLogger().log(LogLevel.info, `Listening for OAuth authorization callback at ${host}:${port}`);

let resolveStopped: () => void;
Expand All @@ -140,7 +113,7 @@ export default class AuthorizationCode {
host,
port,
server,
stop: () => stopServer(server).then(resolveStopped).catch(rejectStopped),
stop: () => this.stopServer(server).then(resolveStopped).catch(rejectStopped),
stopped: () => stoppedPromise,
};
} catch (error) {
Expand All @@ -156,6 +129,50 @@ export default class AuthorizationCode {
throw new AuthenticationError('Failed to start server: all ports are in use');
}

protected createHttpServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
return http.createServer(requestHandler);
}

private async startServer(
host: string,
port: number,
requestHandler: (req: IncomingMessage, res: ServerResponse) => void,
): Promise<Server> {
const server = this.createHttpServer(requestHandler);

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.listen(port, host, () => {
server.off('error', errorListener);
resolve(server);
});
});
}

private async stopServer(server: Server): Promise<void> {
if (!server.listening) {
return;
}

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.close(() => {
server.off('error', errorListener);
resolve();
});
});
}

private renderCallbackResponse(): string {
const applicationName = 'Databricks Sql Connector';

Expand Down
12 changes: 8 additions & 4 deletions lib/connection/auth/DatabricksOAuth/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import IClientContext from '../../../contracts/IClientContext';

export { OAuthFlow };

interface DatabricksOAuthOptions extends OAuthManagerOptions {
export interface DatabricksOAuthOptions extends OAuthManagerOptions {
scopes?: OAuthScopes;
persistence?: OAuthPersistence;
headers?: HeadersInit;
Expand All @@ -16,16 +16,16 @@ interface DatabricksOAuthOptions extends OAuthManagerOptions {
export default class DatabricksOAuth implements IAuthentication {
private readonly context: IClientContext;

private readonly options: DatabricksOAuthOptions;
protected readonly options: DatabricksOAuthOptions;

private readonly manager: OAuthManager;
protected readonly manager: OAuthManager;

private readonly defaultPersistence = new OAuthPersistenceCache();

constructor(options: DatabricksOAuthOptions) {
this.context = options.context;
this.options = options;
this.manager = OAuthManager.getManager(this.options);
this.manager = this.createManager(this.options);
}

public async authenticate(): Promise<HeadersInit> {
Expand All @@ -46,4 +46,8 @@ export default class DatabricksOAuth implements IAuthentication {
Authorization: `Bearer ${token.accessToken}`,
};
}

protected createManager(options: OAuthManagerOptions): OAuthManager {
return OAuthManager.getManager(options);
}
}
8 changes: 4 additions & 4 deletions lib/connection/auth/PlainHttpAuthentication.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ interface PlainHttpAuthenticationOptions {
}

export default class PlainHttpAuthentication implements IAuthentication {
private readonly context: IClientContext;
protected readonly context: IClientContext;

private readonly username: string;
protected readonly username: string;

private readonly password: string;
protected readonly password: string;

private readonly headers: HeadersInit;
protected readonly headers: HeadersInit;

constructor(options: PlainHttpAuthenticationOptions) {
this.context = options.context;
Expand Down
4 changes: 2 additions & 2 deletions lib/connection/connections/HttpConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export default class HttpConnection implements IConnectionProvider {

private readonly context: IClientContext;

private headers: HeadersInit = {};
protected headers: HeadersInit = {};

private connection?: ThriftHttpConnection;

Expand All @@ -36,7 +36,7 @@ export default class HttpConnection implements IConnectionProvider {
});
}

public async getAgent(): Promise<http.Agent> {
public async getAgent(): Promise<http.Agent | undefined> {
if (!this.agent) {
if (this.options.proxy !== undefined) {
this.agent = this.createProxyAgent(this.options.proxy);
Expand Down
Loading
Loading