diff --git a/packages/backend/server/src/__tests__/auth/guard.spec.ts b/packages/backend/server/src/__tests__/auth/guard.spec.ts index 8c1f55a451ae9..57f96d3d9fb14 100644 --- a/packages/backend/server/src/__tests__/auth/guard.spec.ts +++ b/packages/backend/server/src/__tests__/auth/guard.spec.ts @@ -48,7 +48,7 @@ test.before(async t => { u1 = await auth.signUp('u1@affine.pro', '1'); const models = app.get(Models); - const session = await models.session.create(); + const session = await models.session.createSession(); sessionId = session.id; await auth.createUserSession(u1.id, sessionId); diff --git a/packages/backend/server/src/__tests__/auth/job.spec.ts b/packages/backend/server/src/__tests__/auth/job.spec.ts new file mode 100644 index 0000000000000..930b2b5e7ba37 --- /dev/null +++ b/packages/backend/server/src/__tests__/auth/job.spec.ts @@ -0,0 +1,47 @@ +import { ScheduleModule } from '@nestjs/schedule'; +import { TestingModule } from '@nestjs/testing'; +import { PrismaClient } from '@prisma/client'; +import test from 'ava'; + +import { AuthModule, AuthService } from '../../core/auth'; +import { AuthCronJob } from '../../core/auth/job'; +import { createTestingModule } from '../utils'; + +let m: TestingModule; +let db: PrismaClient; + +test.before(async () => { + m = await createTestingModule({ + imports: [ScheduleModule.forRoot(), AuthModule], + }); + + db = m.get(PrismaClient); +}); + +test.after.always(async () => { + await m.close(); +}); + +test('should clean expired user sessions', async t => { + const auth = m.get(AuthService); + const job = m.get(AuthCronJob); + const user1 = await auth.signUp('u1@affine.pro', '1'); + const user2 = await auth.signUp('u2@affine.pro', '1'); + await auth.createUserSession(user1.id); + await auth.createUserSession(user2.id); + let userSessions = await db.userSession.findMany(); + t.is(userSessions.length, 2); + + // no expired sessions + await job.cleanExpiredUserSessions(); + userSessions = await db.userSession.findMany(); + t.is(userSessions.length, 2); + + // clean all expired sessions + await db.userSession.updateMany({ + data: { expiresAt: new Date(Date.now() - 1000) }, + }); + await job.cleanExpiredUserSessions(); + userSessions = await db.userSession.findMany(); + t.is(userSessions.length, 0); +}); diff --git a/packages/backend/server/src/__tests__/auth/service.spec.ts b/packages/backend/server/src/__tests__/auth/service.spec.ts index 6dc96c80f0401..cb6f903ff5686 100644 --- a/packages/backend/server/src/__tests__/auth/service.spec.ts +++ b/packages/backend/server/src/__tests__/auth/service.spec.ts @@ -192,8 +192,10 @@ test('should be able to signout multi accounts session', async t => { const session = await auth.createSession(); - await auth.createUserSession(u1.id, session.id); - await auth.createUserSession(u2.id, session.id); + const userSession1 = await auth.createUserSession(u1.id, session.id); + const userSession2 = await auth.createUserSession(u2.id, session.id); + t.not(userSession1.id, userSession2.id); + t.is(userSession1.sessionId, userSession2.sessionId); await auth.signOut(session.id, u1.id); diff --git a/packages/backend/server/src/core/auth/index.ts b/packages/backend/server/src/core/auth/index.ts index d85355295d5d6..94a0e4501997c 100644 --- a/packages/backend/server/src/core/auth/index.ts +++ b/packages/backend/server/src/core/auth/index.ts @@ -7,6 +7,7 @@ import { QuotaModule } from '../quota'; import { UserModule } from '../user'; import { AuthController } from './controller'; import { AuthGuard, AuthWebsocketOptionsProvider } from './guard'; +import { AuthCronJob } from './job'; import { AuthResolver } from './resolver'; import { AuthService } from './service'; @@ -16,6 +17,7 @@ import { AuthService } from './service'; AuthService, AuthResolver, AuthGuard, + AuthCronJob, AuthWebsocketOptionsProvider, ], exports: [AuthService, AuthGuard, AuthWebsocketOptionsProvider], diff --git a/packages/backend/server/src/core/auth/job.ts b/packages/backend/server/src/core/auth/job.ts new file mode 100644 index 0000000000000..1e59279dd5692 --- /dev/null +++ b/packages/backend/server/src/core/auth/job.ts @@ -0,0 +1,14 @@ +import { Injectable } from '@nestjs/common'; +import { Cron, CronExpression } from '@nestjs/schedule'; + +import { Models } from '../../models'; + +@Injectable() +export class AuthCronJob { + constructor(private readonly models: Models) {} + + @Cron(CronExpression.EVERY_DAY_AT_MIDNIGHT) + async cleanExpiredUserSessions() { + await this.models.session.cleanExpiredUserSessions(); + } +} diff --git a/packages/backend/server/src/core/auth/service.ts b/packages/backend/server/src/core/auth/service.ts index 46687eab15314..de53a620d87e6 100644 --- a/packages/backend/server/src/core/auth/service.ts +++ b/packages/backend/server/src/core/auth/service.ts @@ -1,12 +1,9 @@ import { Injectable, OnApplicationBootstrap } from '@nestjs/common'; -import { Cron, CronExpression } from '@nestjs/schedule'; -import type { User, UserSession } from '@prisma/client'; -import { PrismaClient } from '@prisma/client'; import type { CookieOptions, Request, Response } from 'express'; import { assign, pick } from 'lodash-es'; import { Config, MailService, SignUpForbidden } from '../../base'; -import { Models } from '../../models'; +import { Models, type User, type UserSession } from '../../models'; import { FeatureManagementService } from '../features/management'; import { QuotaService } from '../quota/service'; import { QuotaType } from '../quota/types'; @@ -47,7 +44,6 @@ export class AuthService implements OnApplicationBootstrap { constructor( private readonly config: Config, - private readonly db: PrismaClient, private readonly models: Models, private readonly mailer: MailService, private readonly feature: FeatureManagementService, @@ -105,14 +101,9 @@ export class AuthService implements OnApplicationBootstrap { async signOut(sessionId: string, userId?: string) { // sign out all users in the session if (!userId) { - await this.models.session.delete(sessionId); + await this.models.session.deleteSession(sessionId); } else { - await this.db.userSession.deleteMany({ - where: { - sessionId, - userId, - }, - }); + await this.models.session.deleteUserSession(userId, sessionId); } } @@ -136,7 +127,8 @@ export class AuthService implements OnApplicationBootstrap { // fallback to the first valid session if user provided userId is invalid if (!userSession) { // checked - userSession = sessions.at(-1) as UserSession; + // oxlint-disable-next-line @typescript-eslint/no-non-null-assertion + userSession = sessions.at(-1)!; } const user = await this.user.findUserById(userSession.userId); @@ -149,117 +141,50 @@ export class AuthService implements OnApplicationBootstrap { } async getUserSessions(sessionId: string) { - return this.db.userSession.findMany({ - where: { - sessionId, - OR: [{ expiresAt: { gt: new Date() } }, { expiresAt: null }], - }, - orderBy: { - createdAt: 'asc', - }, - }); + return await this.models.session.findUserSessionsBySessionId(sessionId); } - async createUserSession( - userId: string, - sessionId?: string, - ttl = this.config.auth.session.ttl - ) { - // check whether given session is valid - if (sessionId) { - const session = await this.getSession(sessionId); - - if (!session) { - sessionId = undefined; - } - } - - if (!sessionId) { - const session = await this.createSession(); - sessionId = session.id; - } - - const expiresAt = new Date(Date.now() + ttl * 1000); - - return this.db.userSession.upsert({ - where: { - sessionId_userId: { - sessionId, - userId, - }, - }, - update: { - expiresAt, - }, - create: { - sessionId, - userId, - expiresAt, - }, - }); + async createUserSession(userId: string, sessionId?: string, ttl?: number) { + return await this.models.session.createOrRefreshUserSession( + userId, + sessionId, + ttl + ); } async getUserList(sessionId: string) { - const sessions = await this.db.userSession.findMany({ - where: { - sessionId, - OR: [ - { - expiresAt: null, - }, - { - expiresAt: { - gt: new Date(), - }, - }, - ], - }, - include: { + const sessions = await this.models.session.findUserSessionsBySessionId( + sessionId, + { user: true, - }, - orderBy: { - createdAt: 'asc', - }, - }); - + } + ); return sessions.map(({ user }) => sessionUser(user)); } async createSession() { - return await this.models.session.create(); + return await this.models.session.createSession(); } async getSession(sessionId: string) { - return await this.models.session.get(sessionId); + return await this.models.session.getSession(sessionId); } async refreshUserSessionIfNeeded( res: Response, - session: UserSession, - ttr = this.config.auth.session.ttr + userSession: UserSession, + ttr?: number ): Promise { - if ( - session.expiresAt && - session.expiresAt.getTime() - Date.now() > ttr * 1000 - ) { + const newExpiresAt = await this.models.session.refreshUserSessionIfNeeded( + userSession, + ttr + ); + if (!newExpiresAt) { // no need to refresh return false; } - const newExpiresAt = new Date( - Date.now() + this.config.auth.session.ttl * 1000 - ); - - await this.db.userSession.update({ - where: { - id: session.id, - }, - data: { - expiresAt: newExpiresAt, - }, - }); - - res.cookie(AuthService.sessionCookieName, session.sessionId, { + res.cookie(AuthService.sessionCookieName, userSession.sessionId, { expires: newExpiresAt, ...this.cookieOptions, }); @@ -268,11 +193,7 @@ export class AuthService implements OnApplicationBootstrap { } async revokeUserSessions(userId: string) { - return this.db.userSession.deleteMany({ - where: { - userId, - }, - }); + return await this.models.session.deleteUserSession(userId); } getSessionOptionsFromRequest(req: Request) { @@ -412,15 +333,4 @@ export class AuthService implements OnApplicationBootstrap { to: email, }); } - - @Cron(CronExpression.EVERY_DAY_AT_MIDNIGHT) - async cleanExpiredSessions() { - await this.db.userSession.deleteMany({ - where: { - expiresAt: { - lte: new Date(), - }, - }, - }); - } } diff --git a/packages/backend/server/src/core/auth/session.ts b/packages/backend/server/src/core/auth/session.ts index 3707d5aaab952..9a42e5fcad062 100644 --- a/packages/backend/server/src/core/auth/session.ts +++ b/packages/backend/server/src/core/auth/session.ts @@ -1,8 +1,8 @@ import type { ExecutionContext } from '@nestjs/common'; import { createParamDecorator } from '@nestjs/common'; -import { User, UserSession } from '@prisma/client'; import { getRequestResponseFromContext } from '../../base'; +import type { User, UserSession } from '../../models'; /** * Used to fetch current user from the request context. @@ -37,7 +37,7 @@ import { getRequestResponseFromContext } from '../../base'; * ``` */ // interface and variable don't conflict -// eslint-disable-next-line no-redeclare +// oxlint-disable-next-line no-redeclare export const CurrentUser = createParamDecorator( (_: unknown, context: ExecutionContext) => { return getRequestResponseFromContext(context).req.session?.user; @@ -51,7 +51,7 @@ export interface CurrentUser } // interface and variable don't conflict -// eslint-disable-next-line no-redeclare +// oxlint-disable-next-line no-redeclare export const Session = createParamDecorator( (_: unknown, context: ExecutionContext) => { return getRequestResponseFromContext(context).req.session; diff --git a/packages/backend/server/src/models/index.ts b/packages/backend/server/src/models/index.ts index 0edcc2e90a63c..0e7ec89f63ce3 100644 --- a/packages/backend/server/src/models/index.ts +++ b/packages/backend/server/src/models/index.ts @@ -4,6 +4,8 @@ import { SessionModel } from './session'; import { UserModel } from './user'; import { VerificationTokenModel } from './verification-token'; +export * from './session'; +export * from './user'; export * from './verification-token'; const models = [UserModel, SessionModel, VerificationTokenModel] as const;