Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: gracefully shutdown redis connection on nestjs shutdown #289

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 76 additions & 31 deletions lib/redis.core-module.ts
Original file line number Diff line number Diff line change
@@ -1,67 +1,100 @@
import { DynamicModule, Module, Global, Provider } from '@nestjs/common';
import { RedisModuleAsyncOptions, RedisModuleOptions, RedisModuleOptionsFactory } from './redis.interfaces';
import { createRedisConnection, getRedisOptionsToken, getRedisConnectionToken } from './redis.utils'
import {
DynamicModule,
Module,
Global,
Provider,
OnApplicationShutdown,
} from '@nestjs/common';
import {
RedisModuleAsyncOptions,
RedisModuleOptions,
RedisModuleOptionsFactory,
} from './redis.interfaces';
import {
createRedisConnection,
getRedisOptionsToken,
getRedisConnectionToken,
tryCloseRedisConnectionPermanently,
} from './redis.utils';
import Redis, { Cluster } from 'ioredis';

@Global()
@Module({})
export class RedisCoreModule {
export class RedisCoreModule implements OnApplicationShutdown {
private static readonly redisConnections = [] as Array<
WeakRef<Redis | Cluster>
>;

public async onApplicationShutdown() {
await Promise.all(
RedisCoreModule.redisConnections.map(async (connection) => {
const redis = connection.deref();
if (redis) {
await tryCloseRedisConnectionPermanently(redis);
}
}),
);
}

/* forRoot */
static forRoot(options: RedisModuleOptions, connection?: string): DynamicModule {

static forRoot(
options: RedisModuleOptions,
connection?: string,
): DynamicModule {
const redisOptionsProvider: Provider = {
provide: getRedisOptionsToken(connection),
useValue: options,
};

const redisConnectionProvider: Provider = {
provide: getRedisConnectionToken(connection),
useValue: createRedisConnection(options),
useValue: RedisCoreModule.createAndTrackRedisConnection(options),
};

return {
module: RedisCoreModule,
providers: [
redisOptionsProvider,
redisConnectionProvider,
],
exports: [
redisOptionsProvider,
redisConnectionProvider,
],
providers: [redisOptionsProvider, redisConnectionProvider],
exports: [redisOptionsProvider, redisConnectionProvider],
};
}

/* forRootAsync */
public static forRootAsync(options: RedisModuleAsyncOptions, connection: string): DynamicModule {

public static forRootAsync(
options: RedisModuleAsyncOptions,
connection?: string,
): DynamicModule {
const redisConnectionProvider: Provider = {
provide: getRedisConnectionToken(connection),
useFactory(options: RedisModuleOptions) {
return createRedisConnection(options)
return RedisCoreModule.createAndTrackRedisConnection(options);
},
inject: [getRedisOptionsToken(connection)],
};

return {
module: RedisCoreModule,
imports: options.imports,
providers: [...this.createAsyncProviders(options, connection), redisConnectionProvider],
providers: [
...this.createAsyncProviders(options, connection),
redisConnectionProvider,
],
exports: [redisConnectionProvider],
};
}

/* createAsyncProviders */
public static createAsyncProviders(options: RedisModuleAsyncOptions, connection?: string): Provider[] {

if(!(options.useExisting || options.useFactory || options.useClass)) {
throw new Error('Invalid configuration. Must provide useFactory, useClass or useExisting');
public static createAsyncProviders(
options: RedisModuleAsyncOptions,
connection?: string,
): Provider[] {
if (!(options.useExisting || options.useFactory || options.useClass)) {
throw new Error(
'Invalid configuration. Must provide useFactory, useClass or useExisting',
);
}

if (options.useExisting || options.useFactory) {
return [
this.createAsyncOptionsProvider(options, connection)
];
return [this.createAsyncOptionsProvider(options, connection)];
}

return [
Expand All @@ -71,10 +104,14 @@ export class RedisCoreModule {
}

/* createAsyncOptionsProvider */
public static createAsyncOptionsProvider(options: RedisModuleAsyncOptions, connection?: string): Provider {

if(!(options.useExisting || options.useFactory || options.useClass)) {
throw new Error('Invalid configuration. Must provide useFactory, useClass or useExisting');
public static createAsyncOptionsProvider(
options: RedisModuleAsyncOptions,
connection?: string,
): Provider {
if (!(options.useExisting || options.useFactory || options.useClass)) {
throw new Error(
'Invalid configuration. Must provide useFactory, useClass or useExisting',
);
}

if (options.useFactory) {
Expand All @@ -87,10 +124,18 @@ export class RedisCoreModule {

return {
provide: getRedisOptionsToken(connection),
async useFactory(optionsFactory: RedisModuleOptionsFactory): Promise<RedisModuleOptions> {
async useFactory(
optionsFactory: RedisModuleOptionsFactory,
): Promise<RedisModuleOptions> {
return await optionsFactory.createRedisModuleOptions();
},
inject: [options.useClass || options.useExisting],
};
}

protected static createAndTrackRedisConnection(options: RedisModuleOptions) {
const redis = createRedisConnection(options);
this.redisConnections.push(new WeakRef(redis));
return redis;
}
}
107 changes: 80 additions & 27 deletions lib/redis.module.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@ import { Test, TestingModule } from '@nestjs/testing';
import { RedisModule } from './redis.module';
import { getRedisConnectionToken } from './redis.utils';
import { InjectRedis } from './redis.decorators';
import { setTimeout } from 'timers/promises';

describe('RedisModule', () => {
it('Instance Redis', async () => {
const module: TestingModule = await Test.createTestingModule({
imports: [RedisModule.forRoot({
type: 'single',
options: {
host: '127.0.0.1',
port: 6379,
password: '123456',
}
})],
imports: [
RedisModule.forRoot({
type: 'single',
options: {
host: '127.0.0.1',
port: 6379,
password: '123456',
},
}),
],
}).compile();

const app = module.createNestApplication();
Expand All @@ -31,20 +34,24 @@ describe('RedisModule', () => {
const defaultConnection: string = 'default';

const module: TestingModule = await Test.createTestingModule({
imports: [RedisModule.forRoot({
type: 'single',
options: {
host: '127.0.0.1',
port: 6379,
password: '123456',
}
})],
},).compile();
imports: [
RedisModule.forRoot({
type: 'single',
options: {
host: '127.0.0.1',
port: 6379,
password: '123456',
},
}),
],
}).compile();

const app = module.createNestApplication();
await app.init();
const redisClient = module.get(getRedisConnectionToken(defaultConnection));
const redisClientTest = module.get(getRedisConnectionToken(defaultConnection));
const redisClientTest = module.get(
getRedisConnectionToken(defaultConnection),
);

expect(redisClient).toBeInstanceOf(Redis);
expect(redisClientTest).toBeInstanceOf(Redis);
Expand All @@ -53,7 +60,6 @@ describe('RedisModule', () => {
});

it('inject redis connection', async () => {

@Injectable()
class TestProvider {
constructor(@InjectRedis() private readonly redis: Redis) {}
Expand All @@ -64,14 +70,16 @@ describe('RedisModule', () => {
}

const module: TestingModule = await Test.createTestingModule({
imports: [RedisModule.forRoot({
type: 'single',
options: {
host: '127.0.0.1',
port: 6379,
password: '123456',
}
})],
imports: [
RedisModule.forRoot({
type: 'single',
options: {
host: '127.0.0.1',
port: 6379,
password: '123456',
},
}),
],
providers: [TestProvider],
}).compile();

Expand All @@ -83,4 +91,49 @@ describe('RedisModule', () => {

await app.close();
});

it('closes all redis connections on shutdown', async () => {
const module: TestingModule = await Test.createTestingModule({
imports: [
RedisModule.forRoot({
type: 'single',
options: {
host: '127.0.0.1',
port: 6379,
password: '123456',
},
}),
RedisModule.forRoot(
{
type: 'single',
options: {
host: '127.0.0.1',
port: 6379,
password: '123456',
},
},
'second',
),
],
}).compile();

const app = module.createNestApplication();
await app.init();
const defaultRedisClient = module.get<Redis>(getRedisConnectionToken());
const secondRedisClient = module.get<Redis>(
getRedisConnectionToken('second'),
);

await setTimeout(1000);

expect(defaultRedisClient.status).toBe('ready');
expect(secondRedisClient.status).toBe('ready');

await app.close();

await setTimeout(1000);

expect(defaultRedisClient.status).toBe('end');
expect(secondRedisClient.status).toBe('end');
});
});
29 changes: 23 additions & 6 deletions lib/redis.utils.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import Redis, { RedisOptions } from 'ioredis';
import Redis, { Cluster, RedisOptions } from 'ioredis';
import { RedisModuleOptions } from './redis.interfaces';
import {
REDIS_MODULE_CONNECTION,
REDIS_MODULE_CONNECTION_TOKEN,
REDIS_MODULE_OPTIONS_TOKEN
REDIS_MODULE_OPTIONS_TOKEN,
} from './redis.constants';

export function getRedisOptionsToken(connection?: string): string {
return `${ connection || REDIS_MODULE_CONNECTION }_${ REDIS_MODULE_OPTIONS_TOKEN }`;
return `${
connection || REDIS_MODULE_CONNECTION
}_${REDIS_MODULE_OPTIONS_TOKEN}`;
}

export function getRedisConnectionToken(connection?: string): string {
return `${ connection || REDIS_MODULE_CONNECTION }_${ REDIS_MODULE_CONNECTION_TOKEN }`;
return `${
connection || REDIS_MODULE_CONNECTION
}_${REDIS_MODULE_CONNECTION_TOKEN}`;
}

export function createRedisConnection(options: RedisModuleOptions) {
Expand All @@ -24,10 +28,23 @@ export function createRedisConnection(options: RedisModuleOptions) {
const { url, options: { port, host } = {} } = options;
const connectionOptions: RedisOptions = { ...commonOptions, port, host };

return url ? new Redis(url, connectionOptions) : new Redis(connectionOptions);
return url
? new Redis(url, connectionOptions)
: new Redis(connectionOptions);
default:
throw new Error('Invalid configuration');
}
}


export const tryCloseRedisConnectionPermanently = async (
redis: Redis | Cluster,
) => {
try {
await redis.quit();
} catch (error) {
if (error instanceof Error && error.message === 'Connection is closed.') {
return;
}
throw error;
}
};
15 changes: 3 additions & 12 deletions tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,8 @@
"sourceMap": false,
"outDir": "./dist",
"rootDir": "./lib",
"lib": ["es7"]
"lib": ["es7", "ES2021.WeakRef"]
},
"include": [
"lib/**/*"
],
"exclude": [
"node_modules"
]
"include": ["lib/**/*"],
"exclude": ["node_modules"]
}