diff --git a/.github/workflows/publish-backend-docker.yml b/.github/workflows/publish-backend-docker.yml index c56fa90..05f485f 100644 --- a/.github/workflows/publish-backend-docker.yml +++ b/.github/workflows/publish-backend-docker.yml @@ -25,7 +25,7 @@ jobs: uses: docker/login-action@v1 with: registry: ghcr.io - username: ${{ github.actor }} + username: eldpswp99 password: ${{ secrets.DOCKER_TOKEN }} - name: Build image @@ -51,4 +51,4 @@ jobs: username: eldpswp99 password: ${{ secrets.DOCKER_TOKEN }} - name: Docker compose up - run: SHA=${GITHUB_SHA} docker compose -f ~/docker-compose.yml up -d fooriend-backend \ No newline at end of file + run: SHA=${GITHUB_SHA} docker compose -f ~/docker-compose.yml up -d diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ad9b490..641a9fb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,9 +2,11 @@ name: 테스트 on: push: - branches: [ "main" ] + branches: + - "main" pull_request: - branches: [ "main" ] + branches: + - "main" jobs: test: diff --git a/backend/package-lock.json b/backend/package-lock.json index ba50972..03f35b4 100644 --- a/backend/package-lock.json +++ b/backend/package-lock.json @@ -21,6 +21,7 @@ "@types/bcryptjs": "^2.4.5", "@types/jsonwebtoken": "^9.0.3", "@types/multer": "^1.4.8", + "axios": "^1.6.0", "bcryptjs": "^2.4.3", "class-transformer": "^0.5.1", "class-validator": "^0.14.0", @@ -4324,8 +4325,17 @@ "node_modules/asynckit": { "version": "0.4.0", "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", - "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==", - "dev": true + "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==" + }, + "node_modules/axios": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.0.tgz", + "integrity": "sha512-EZ1DYihju9pwVB+jg67ogm+Tmqc6JmhamRN6I4Zt8DfZu5lbcQGw3ozH9lFejSJgs/ibaef3A9PMXPLeefFGJg==", + "dependencies": { + "follow-redirects": "^1.15.0", + "form-data": "^4.0.0", + "proxy-from-env": "^1.1.0" + } }, "node_modules/babel-jest": { "version": "29.7.0", @@ -5065,7 +5075,6 @@ "version": "1.0.8", "resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz", "integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==", - "dev": true, "dependencies": { "delayed-stream": "~1.0.0" }, @@ -5498,7 +5507,6 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz", "integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==", - "dev": true, "engines": { "node": ">=0.4.0" } @@ -6356,6 +6364,25 @@ "integrity": "sha512-36yxDn5H7OFZQla0/jFJmbIKTdZAQHngCedGxiMmpNfEZM0sdEeT+WczLQrjK6D7o2aiyLYDnkw0R3JK0Qv1RQ==", "dev": true }, + "node_modules/follow-redirects": { + "version": "1.15.3", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.3.tgz", + "integrity": "sha512-1VzOtuEM8pC9SFU1E+8KfTjZyMztRsgEfwQl44z8A25uy13jSzTj6dyK2Df52iV0vgHCfBwLhDWevLn95w5v6Q==", + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/RubenVerborgh" + } + ], + "engines": { + "node": ">=4.0" + }, + "peerDependenciesMeta": { + "debug": { + "optional": true + } + } + }, "node_modules/foreground-child": { "version": "3.1.1", "resolved": 
"https://registry.npmjs.org/foreground-child/-/foreground-child-3.1.1.tgz", @@ -6404,7 +6431,6 @@ "version": "4.0.0", "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.0.tgz", "integrity": "sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==", - "dev": true, "dependencies": { "asynckit": "^0.4.0", "combined-stream": "^1.0.8", @@ -9091,6 +9117,11 @@ "node": ">= 0.10" } }, + "node_modules/proxy-from-env": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz", + "integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==" + }, "node_modules/pump": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.0.tgz", diff --git a/backend/package.json b/backend/package.json index 9f80cdf..52fa1dc 100644 --- a/backend/package.json +++ b/backend/package.json @@ -32,6 +32,7 @@ "@types/bcryptjs": "^2.4.5", "@types/jsonwebtoken": "^9.0.3", "@types/multer": "^1.4.8", + "axios": "^1.6.0", "bcryptjs": "^2.4.3", "class-transformer": "^0.5.1", "class-validator": "^0.14.0", diff --git a/backend/src/review/dtos/in-dtos/restaurant.dto.ts b/backend/src/review/dtos/in-dtos/restaurant.dto.ts index 95e9d18..b33bda7 100644 --- a/backend/src/review/dtos/in-dtos/restaurant.dto.ts +++ b/backend/src/review/dtos/in-dtos/restaurant.dto.ts @@ -4,6 +4,9 @@ export class RestaurantDto { @IsString() googleMapPlaceId: string; + @IsString() + name: string; + @IsNumber() latitude: number; diff --git a/backend/src/review/dtos/out-dtos/restaurant.dto.ts b/backend/src/review/dtos/out-dtos/restaurant.dto.ts index 0754756..eea0efa 100644 --- a/backend/src/review/dtos/out-dtos/restaurant.dto.ts +++ b/backend/src/review/dtos/out-dtos/restaurant.dto.ts @@ -5,11 +5,19 @@ export class RestaurantDto { private googleMapPlaceId: string; private longitude: number; private latitude: number; + private name: string; - constructor({ id, googleMapPlaceId, longitude, latitude }: RestaurantEntity) { + constructor({ + id, + googleMapPlaceId, + longitude, + latitude, + name, + }: RestaurantEntity) { this.id = id; this.googleMapPlaceId = googleMapPlaceId; this.longitude = longitude; this.latitude = latitude; + this.name = name; } } diff --git a/backend/src/review/ml-remote.ts b/backend/src/review/ml-remote.ts new file mode 100644 index 0000000..9f59b56 --- /dev/null +++ b/backend/src/review/ml-remote.ts @@ -0,0 +1,19 @@ +import axios from 'axios'; + +export async function getReceiptOcr(imageUrl: string) { + return ( + await axios.post(`${process.env.ML_URL}ocr`, { + image_url: imageUrl, + }) + ).data; +} + +export async function getReviewIsPositive(review: string) { + return ( + ( + await axios.post(`${process.env.ML_URL}review`, { + review, + }) + ).data['result'] === '긍정' + ); +} diff --git a/backend/src/review/models/restaurant.entity.ts b/backend/src/review/models/restaurant.entity.ts index e8c81eb..e454b7e 100644 --- a/backend/src/review/models/restaurant.entity.ts +++ b/backend/src/review/models/restaurant.entity.ts @@ -13,6 +13,9 @@ export class RestaurantEntity extends IssuedAtMetaEntity { @Column({ type: 'float8' }) longitude: number; + @Column({ type: 'varchar', default: '' }) + name: string; + @OneToMany(() => ReviewEntity, (review) => review.restaurant) reviews: ReviewEntity[]; } diff --git a/backend/src/review/models/review.entity.ts b/backend/src/review/models/review.entity.ts index 0b5100d..856eab5 100644 --- a/backend/src/review/models/review.entity.ts +++ 
b/backend/src/review/models/review.entity.ts @@ -18,6 +18,12 @@ export class ReviewEntity extends IssuedAtMetaEntity { @Column({ type: 'varchar' }) content: string; + @Column({ type: 'boolean', default: false }) + isPositive: boolean; + + @Column('varchar', { array: true, default: [] }) + menu: string[]; + get receiptImage() { return this.images.find((image) => image.isReceipt); } diff --git a/backend/src/review/repositories/restaurant.repository.ts b/backend/src/review/repositories/restaurant.repository.ts index d35c00a..edb71f9 100644 --- a/backend/src/review/repositories/restaurant.repository.ts +++ b/backend/src/review/repositories/restaurant.repository.ts @@ -6,7 +6,7 @@ import { RestaurantDto } from '../dtos/in-dtos/restaurant.dto'; @CustomRepository(RestaurantEntity) export class RestaurantRepository extends Repository { async findOrCreate(data: RestaurantDto) { - const { googleMapPlaceId, longitude, latitude } = data; + const { googleMapPlaceId, longitude, latitude, name } = data; const restaurant = await this.findOne({ where: { googleMapPlaceId }, }); @@ -17,6 +17,7 @@ export class RestaurantRepository extends Repository { googleMapPlaceId, longitude, latitude, + name, }).save(); } } diff --git a/backend/src/review/review.service.ts b/backend/src/review/review.service.ts index 6ce2d3f..a98e53e 100644 --- a/backend/src/review/review.service.ts +++ b/backend/src/review/review.service.ts @@ -1,4 +1,4 @@ -import { Injectable } from '@nestjs/common'; +import { BadRequestException, Injectable } from '@nestjs/common'; import { UserEntity } from '../user/models/user.entity'; import { CreateReviewDto } from './dtos/in-dtos/createReview.dto'; import { RestaurantRepository } from './repositories/restaurant.repository'; @@ -7,6 +7,7 @@ import { In } from 'typeorm'; import { ReviewEntity } from './models/review.entity'; import { ReviewAdjacentQueryDto } from './dtos/in-dtos/review-adjacent-query.dto'; import { getDistance } from 'geolib'; +import { getReceiptOcr, getReviewIsPositive } from './ml-remote'; import { RestaurantEntity } from './models/restaurant.entity'; @Injectable() @@ -29,8 +30,22 @@ export class ReviewService { id: In(imageIds.concat([receiptImageId ?? 
-1])), }); const receiptImage = images.find((image) => image.id === receiptImageId); + let menu = []; if (receiptImage) { await receiptImage.markAsReceipt(); + try { + const receiptData = await getReceiptOcr(receiptImage.url); + menu = receiptData['menu']; + } catch (e) { + throw new BadRequestException('잘못된 영수증입니다.'); + } + } + + let isPositive = false; + try { + isPositive = await getReviewIsPositive(content); + } catch (e) { + throw new BadRequestException('리뷰 분석 중 오류가 발생했습니다.'); } return await ReviewEntity.create({ @@ -38,6 +53,8 @@ export class ReviewService { user, restaurant, images, + isPositive, + menu, }).save(); } diff --git a/backend/src/test/review/create-review.spec.ts b/backend/src/test/review/create-review.spec.ts index f7c514b..040833b 100644 --- a/backend/src/test/review/create-review.spec.ts +++ b/backend/src/test/review/create-review.spec.ts @@ -11,6 +11,12 @@ import { validateDtoKeys } from '../utils'; import { RestaurantEntity } from '../../review/models/restaurant.entity'; import { RestaurantFixture } from '../fixture/restaurant.fixture'; import { ReviewEntity } from '../../review/models/review.entity'; +import axios from 'axios'; +import { getReceiptOcr, getReviewIsPositive } from '../../review/ml-remote'; +import { ImageFixture } from '../fixture/image.fixture'; + +// mock ml-remote.ts +jest.mock('../../review/ml-remote.ts'); describe('Create Review test', () => { let testServer: NestExpressApplication; @@ -49,6 +55,16 @@ describe('Create Review test', () => { .expect(HttpStatus.CREATED); accessToken = body.accessToken; + + jest.resetAllMocks(); + + (getReceiptOcr as jest.Mock).mockResolvedValue({ + title: '김태준의 탕탕집', + address: '서울 강남구 학동로4길 12. 1,2층(논현동)', + date: '2018/01/30', + menu: ['낙지탕탕이'], + }); + (getReviewIsPositive as jest.Mock).mockResolvedValue(true); }); it('unauthorized', async () => { @@ -59,6 +75,7 @@ describe('Create Review test', () => { googleMapPlaceId: 'ChIJN1t_tDeuEmsRUsoyG83frY4', latitude: 37.4224764, longitude: -122.0842499, + name: '두근두근쭈꾸미', }, content: 'content', }) @@ -73,6 +90,7 @@ describe('Create Review test', () => { googleMapPlaceId: 'ChIJN1t_tDeuEmsRUsoyG83frY4', latitude: 37.4224764, longitude: -122.0842499, + name: '두근두근쭈꾸미', }, content: 'content', imageIds: [], @@ -100,15 +118,109 @@ describe('Create Review test', () => { expect(reviewEntity.user.id).toEqual(user.id); }); + it('긍정부정 / 영수증 반영', async () => { + const image = await ImageFixture.create({}); + + const receiptImageId = await ImageFixture.create({}); + + await supertest(testServer.getHttpServer()) + .post('/reviews') + .send({ + restaurant: { + googleMapPlaceId: 'ChIJN1t_tDeuEmsRUsoyG83frY4', + latitude: 37.4224764, + longitude: -122.0842499, + name: '두근두근쭈꾸미', + }, + content: 'content', + imageIds: [image.id], + receiptImageId: receiptImageId.id, + }) + .set('Authorization', `Bearer ${accessToken}`) + .expect(HttpStatus.CREATED); + + const reviewEntity = await ReviewEntity.findOneOrFail({ + where: { + content: 'content', + }, + relations: { + restaurant: true, + images: true, + user: true, + }, + }); + + expect(reviewEntity.content).toEqual('content'); + expect(reviewEntity.restaurant.googleMapPlaceId).toEqual( + 'ChIJN1t_tDeuEmsRUsoyG83frY4', + ); + expect(reviewEntity.restaurant.latitude).toEqual(37.4224764); + expect(reviewEntity.restaurant.longitude).toEqual(-122.0842499); + expect(reviewEntity.user.id).toEqual(user.id); + expect(reviewEntity.isPositive).toEqual(true); + expect(reviewEntity.menu).toEqual(['낙지탕탕이']); + }); + + it('영수증 에러나면 400', async () => { + 
const image = await ImageFixture.create({}); + + const receiptImageId = await ImageFixture.create({}); + + jest.clearAllMocks(); + (getReceiptOcr as jest.Mock).mockRejectedValue(new Error('error')); + + await supertest(testServer.getHttpServer()) + .post('/reviews') + .send({ + restaurant: { + googleMapPlaceId: 'ChIJN1t_tDeuEmsRUsoyG83frY4', + latitude: 37.4224764, + longitude: -122.0842499, + name: '두근두근쭈꾸미', + }, + content: 'content', + imageIds: [image.id], + receiptImageId: receiptImageId.id, + }) + .set('Authorization', `Bearer ${accessToken}`) + .expect(HttpStatus.BAD_REQUEST); + }); + + it('영수증 에러나면 400', async () => { + const image = await ImageFixture.create({}); + + const receiptImageId = await ImageFixture.create({}); + + jest.clearAllMocks(); + (getReviewIsPositive as jest.Mock).mockRejectedValue(new Error('error')); + + await supertest(testServer.getHttpServer()) + .post('/reviews') + .send({ + restaurant: { + googleMapPlaceId: 'ChIJN1t_tDeuEmsRUsoyG83frY4', + latitude: 37.4224764, + longitude: -122.0842499, + name: '두근두근쭈꾸미', + }, + content: 'content', + imageIds: [image.id], + receiptImageId: receiptImageId.id, + }) + .set('Authorization', `Bearer ${accessToken}`) + .expect(HttpStatus.BAD_REQUEST); + }); + it('레스토랑 없었으면 생성한다.', async () => { const restaurantCount = await RestaurantEntity.count({}); - const { body } = await supertest(testServer.getHttpServer()) + await supertest(testServer.getHttpServer()) .post('/reviews') .send({ restaurant: { googleMapPlaceId: 'ChIJN1t_tDeuEmsRUsoyG83frY4', latitude: 37.4224764, longitude: -122.0842499, + name: '두근두근쭈꾸미', }, content: 'content', imageIds: [], @@ -123,12 +235,13 @@ describe('Create Review test', () => { const restaurant = await RestaurantFixture.create({}); const restaurantCount = await RestaurantEntity.count({}); - const { body } = await supertest(testServer.getHttpServer()) + await supertest(testServer.getHttpServer()) .post('/reviews') .send({ restaurant: { googleMapPlaceId: restaurant.googleMapPlaceId, latitude: restaurant.latitude, + name: '두근두근쭈꾸미', longitude: restaurant.longitude, }, content: 'content', @@ -144,6 +257,7 @@ describe('Create Review test', () => { { restaurant: { googleMapPlaceId: 'ChIJN1t_tDeuEmsRUsoyG83frY4', + name: '두근두근쭈꾸미', latitude: 37.4224764, }, content: 'content', @@ -153,6 +267,7 @@ describe('Create Review test', () => { restaurant: { googleMapPlaceId: 'ChIJN1t_tDeuEmsRUsoyG83frY4', latitude: 37.4224764, + name: '두근두근쭈꾸미', longitude: -122.0842499, }, imageIds: [], @@ -160,24 +275,28 @@ describe('Create Review test', () => { { restaurant: { googleMapPlaceId: 'ChIJN1t_tDeuEmsRUsoyG83frY4', + name: '두근두근쭈꾸미', latitude: 37.4224764, longitude: -122.0842499, }, content: 'content', }, + { + restaurant: { + googleMapPlaceId: 'ChIJN1t_tDeuEmsRUsoyG83frY4', + latitude: 37.4224764, + longitude: -122.0842499, + }, + content: 'content', + imageIds: [1, 2, 3], + }, ])('input validation', async (requestBody) => { - const { body } = await supertest(testServer.getHttpServer()) + await supertest(testServer.getHttpServer()) .post('/reviews') .send({ - restaurant: { - googleMapPlaceId: 'ChIJN1t_tDeuEmsRUsoyG83frY4', - latitude: 37.4224764, - longitude: -122.0842499, - }, - content: 'content', - imageIds: [], + ...requestBody, }) .set('Authorization', `Bearer ${accessToken}`) - .expect(HttpStatus.CREATED); + .expect(HttpStatus.BAD_REQUEST); }); }); diff --git a/backend/src/test/review/validateReviewList.ts b/backend/src/test/review/validateReviewList.ts index f50482c..195ece8 100644 --- 
a/backend/src/test/review/validateReviewList.ts +++ b/backend/src/test/review/validateReviewList.ts @@ -31,7 +31,13 @@ export function validateRestaurantList(body: any) { } export function validateRestaurant(body: any) { - validateDtoKeys(body, ['id', 'googleMapPlaceId', 'longitude', 'latitude']); + validateDtoKeys(body, [ + 'id', + 'googleMapPlaceId', + 'longitude', + 'latitude', + 'name', + ]); } export function validateUserSummary(body: any) { diff --git a/ml/.gitignore b/ml/.gitignore new file mode 100644 index 0000000..e3a0395 --- /dev/null +++ b/ml/.gitignore @@ -0,0 +1,190 @@ +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +clova/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### venv ### +# Virtualenv +# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ +[Bb]in +[Ii]nclude +[Ll]ib +[Ll]ib64 +[Ll]ocal +[Ss]cripts +pyvenv.cfg +pip-selfcheck.json + +# End of https://www.toptal.com/developers/gitignore/api/python,venv + +pictures/* +model/* +model.pt \ No newline at end of file diff --git a/ml/Dockerfile b/ml/Dockerfile new file mode 100644 index 0000000..31ba4bf --- /dev/null +++ b/ml/Dockerfile @@ -0,0 +1,17 @@ +FROM python:3.11 + +ENV PYTHONDONTWRITEBYTECODE 1 +ENV PYTHONUNBUFFERED 1 + +RUN apt-get update +RUN apt-get install -y cmake + +COPY . . +RUN pip install -r requirements.txt + +EXPOSE 81 + +ENTRYPOINT ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "81"] + + + diff --git a/ml/dto.py b/ml/dto.py new file mode 100644 index 0000000..a414756 --- /dev/null +++ b/ml/dto.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +class OcrModel(BaseModel): + image_url: str + +class ReviewModel(BaseModel): + review: str diff --git a/ml/main.py b/ml/main.py new file mode 100644 index 0000000..d0ef5e7 --- /dev/null +++ b/ml/main.py @@ -0,0 +1,37 @@ +from io import BytesIO + +import requests +import uuid +from PIL import Image +from fastapi import FastAPI + +from dto import OcrModel, ReviewModel +from ocr.execution import ocr_receipt +from review.predict import predict_review + +app = FastAPI() + + +@app.get("/") +async def root(): + return {"message": "Hello World"} + + +@app.post("/ocr") +async def receipt_ocr(data: OcrModel): + image_res = requests.get(data.image_url) + image = Image.open(BytesIO(image_res.content)) + jpg_image = image.convert('RGB') + uuid_string = uuid.uuid4() + receipt_path = 'pictures/' + str(uuid_string) + '.jpg' + jpg_image.save(receipt_path) + response = ocr_receipt(receipt_path) + + return response + + +@app.post("/review") +async def review_classification(data: ReviewModel): + return predict_review(data.review) + + diff --git a/ml/ocr/assets/receipt.jpeg b/ml/ocr/assets/receipt.jpeg new file mode 100644 index 0000000..a579642 Binary files /dev/null and b/ml/ocr/assets/receipt.jpeg differ diff --git a/ml/ocr/execution.py b/ml/ocr/execution.py new file mode 100644 index 0000000..a778952 --- /dev/null +++ b/ml/ocr/execution.py @@ -0,0 +1,80 @@ +import json +import os +import platform +import re +import time +import uuid + +import cv2 +import numpy as np +import requests +from dotenv import load_dotenv +from 
matplotlib import pyplot as plt +from PIL import Image, ImageDraw, ImageFont + +load_dotenv() + +api_url = os.environ.get("api_url") +secret_key = os.environ.get("secret_key") + + +# path = "assets/receipt.jpeg" + + +def ocr_receipt(path): + files = [("file", open(path, "rb"))] + + request_json = { + "images": [{"format": "jpg", "name": "demo"}], + "requestId": str(uuid.uuid4()), + "version": "V2", + "timestamp": int(round(time.time() * 1000)), + } + + payload = {"message": json.dumps(request_json).encode("UTF-8")} + + headers = { + "X-OCR-SECRET": secret_key, + } + + response = requests.request("POST", api_url, headers=headers, data=payload, files=files) + + response_body = json.loads(response.text) + + images = response_body["images"] + images_receipt = images[0].get("receipt") + + receipt_title = images_receipt["result"]["storeInfo"]["name"]["text"] + receipt_address = images_receipt["result"]["storeInfo"]["addresses"][0]["text"] + receipt_date = images_receipt["result"]["paymentInfo"]["date"]["text"] + + sub_results = images_receipt["result"]["subResults"] + receipt_menu = [] + for sub_result in sub_results: + items = sub_result.get("items", []) + for item in items: + menu_name = item.get("name", {}).get("text", "") + receipt_menu.append(menu_name) + + # receipt_price = int( + # float( + # re.sub( + # r"[^\uAC00-\uD7A30-9a-zA-Z\s]", + # "", + # images_receipt["result"]["totalPrice"]["price"]["text"], + # ) + # ) + # ) + + receipt_data = { + "title": receipt_title, + "address": receipt_address, + "date": receipt_date, + "menu": receipt_menu, + # "price": receipt_price, + } + + return receipt_data + + +# print(ocr_receipt(path)) diff --git a/ml/requirements.txt b/ml/requirements.txt new file mode 100644 index 0000000..f6b29e4 --- /dev/null +++ b/ml/requirements.txt @@ -0,0 +1,93 @@ +annotated-types==0.6.0 +anyio==3.7.1 +appnope==0.1.3 +asttokens==2.4.0 +attrs==23.1.0 +backcall==0.2.0 +black==23.10.1 +certifi==2023.7.22 +charset-normalizer==3.3.1 +click==8.1.7 +contourpy==1.1.1 +cycler==0.12.1 +Cython==3.0.4 +decorator==5.1.1 +dynaconf==3.2.3 +executing==2.0.0 +fastapi==0.104.0 +fastjsonschema==2.18.1 +filelock==3.12.4 +fonttools==4.43.1 +fsspec==2023.10.0 +gluonnlp==0.10.0 +graphviz==0.8.4 +h11==0.14.0 +huggingface-hub==0.17.3 +idna==3.4 +import-ipynb==0.1.4 +ipython==8.16.1 +jedi==0.19.1 +Jinja2==3.1.2 +joblib==1.3.2 +jsonschema==4.19.1 +jsonschema-specifications==2023.7.1 +jupyter_core==5.4.0 +keras==2.14.0 +kiwisolver==1.4.5 +MarkupSafe==2.1.3 +matplotlib==3.8.0 +matplotlib-inline==0.1.6 +mpmath==1.3.0 +mxnet==1.6.0 +mypy-extensions==1.0.0 +nbformat==5.9.2 +networkx==3.2 +numpy==1.26.1 +opencv-contrib-python==4.8.1.78 +opencv-python==4.8.1.78 +packaging==23.2 +pandas==2.1.1 +parso==0.8.3 +pathspec==0.11.2 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==10.1.0 +platformdirs==3.11.0 +prompt-toolkit==3.0.39 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pydantic==1.10.13 +pydantic_core==2.10.1 +Pygments==2.16.1 +pyparsing==3.1.1 +python-dateutil==2.8.2 +python-dotenv==1.0.0 +pytz==2023.3.post1 +PyYAML==6.0.1 +referencing==0.30.2 +regex==2023.10.3 +requests==2.31.0 +rpds-py==0.10.6 +safetensors==0.4.0 +scikit-learn==1.3.2 +scipy==1.11.3 +six==1.16.0 +sniffio==1.3.0 +SQLAlchemy==1.4.49 +sqlalchemy2-stubs==0.0.2a35 +sqlmodel==0.0.9 +stack-data==0.6.3 +starlette==0.27.0 +sympy==1.12 +threadpoolctl==3.2.0 +tokenizers==0.14.1 +torch==2.1.0 +tqdm==4.66.1 +traitlets==5.12.0 +transformers==4.34.1 +typer==0.9.0 +typing_extensions==4.8.0 +tzdata==2023.3 +urllib3==2.0.7 +uvicorn==0.23.2 +wcwidth==0.2.8 diff 
--git a/ml/review/model.py b/ml/review/model.py new file mode 100644 index 0000000..62fefa7 --- /dev/null +++ b/ml/review/model.py @@ -0,0 +1,50 @@ +import numpy as np +import torch +from torch import nn +from torch.utils.data import Dataset + +from tokenization import BERTSentenceTransform + + +class BERTClassifier(nn.Module): + def __init__(self, bert, hidden_size=768, num_classes=2, dr_rate=None, params=None): + super(BERTClassifier, self).__init__() + self.bert = bert + self.dr_rate = dr_rate + + self.classifier = nn.Linear(hidden_size, num_classes) + if dr_rate: + self.dropout = nn.Dropout(p=dr_rate) + + def gen_attention_mask(self, token_ids, valid_length): + attention_mask = torch.zeros_like(token_ids) + for i, v in enumerate(valid_length): + attention_mask[i][:v] = 1 + return attention_mask.float() + + def forward(self, token_ids, valid_length, segment_ids): + attention_mask = self.gen_attention_mask(token_ids, valid_length) + + _, pooler = self.bert( + input_ids=token_ids, + token_type_ids=segment_ids.long(), + attention_mask=attention_mask.float().to(token_ids.device), + ) + if self.dr_rate: + out = self.dropout(pooler) + return self.classifier(out) + + +class BERTDataset(Dataset): + def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair): + transform = BERTSentenceTransform( + bert_tokenizer, max_seq_length=max_len, vocab=vocab, pad=pad, pair=pair + ) + self.sentences = [transform([i[sent_idx]]) for i in dataset] + self.labels = [np.int32(i[label_idx]) for i in dataset] + + def __getitem__(self, i): + return self.sentences[i] + (self.labels[i],) + + def __len__(self): + return len(self.labels) diff --git a/ml/review/predict.py b/ml/review/predict.py new file mode 100644 index 0000000..7f8cf56 --- /dev/null +++ b/ml/review/predict.py @@ -0,0 +1,42 @@ +import torch +from torch import nn +from transformers import AutoTokenizer, BertModel + +tokenizer = AutoTokenizer.from_pretrained("WhitePeak/bert-base-cased-Korean-sentiment") + +device = "cuda" if torch.cuda.is_available() else "cpu" +model = torch.load("model.pt", map_location=device) + + +class CustomModel(nn.Module): + def __init__(self, num_classes=2): + super(CustomModel, self).__init__() + self.bert = BertModel.from_pretrained("bert-base-uncased") + self.dropout = nn.Dropout(0.1) + self.classifier = nn.Linear(768, num_classes) + + def forward(self, input_ids, attention_mask): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + pooled_output = outputs.pooler_output + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + return logits + + +def predict_review(sentence): + inputs = tokenizer(sentence, return_tensors="pt") + + outputs = model(**inputs) + + probabilities = torch.nn.functional.softmax(outputs.logits, dim=1) + + predicted_class = torch.argmax(probabilities, dim=1).item() + + result = [] + + if predicted_class == 1: + result.append("긍정") + else: + result.append("부정") + + return result[0] diff --git a/ml/review/preprocessing.py b/ml/review/preprocessing.py new file mode 100644 index 0000000..028e847 --- /dev/null +++ b/ml/review/preprocessing.py @@ -0,0 +1,18 @@ +import pandas as pd +from sklearn.model_selection import train_test_split + + +def preprocessing(data): + data_list = [] + + for q, label in zip(data["review"], data["y"]): + data = [] + data.append(q) + data.append(str(label)) + + data_list.append(data) + + dataset_train, dataset_test = train_test_split( + data_list, test_size=0.2, shuffle=True, random_state=23 + ) 
+ return dataset_train, dataset_test diff --git a/ml/review/tokenization.py b/ml/review/tokenization.py new file mode 100644 index 0000000..c18bee0 --- /dev/null +++ b/ml/review/tokenization.py @@ -0,0 +1,160 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team and Jangwon Park +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Tokenization classes for KoBERT model """ + +import gluonnlp as nlp +import numpy as np +from kobert_tokenizer import KoBERTTokenizer +from transformers import BertModel + +ko_tokenizer = KoBERTTokenizer.from_pretrained("skt/kobert-base-v1") +ko_bertmodel = BertModel.from_pretrained("skt/kobert-base-v1", return_dict=False) +ko_vocab = nlp.vocab.BERTVocab.from_sentencepiece(ko_tokenizer.vocab_file, padding_token="[PAD]") + + +class BERTSentenceTransform: + r"""BERT style data transformation. + + Parameters + ---------- + tokenizer : BERTTokenizer. + Tokenizer for the sentences. + max_seq_length : int. + Maximum sequence length of the sentences. + pad : bool, default True + Whether to pad the sentences to maximum length. + pair : bool, default True + Whether to transform sentences or sentence pairs. + """ + + def __init__(self, tokenizer, max_seq_length, vocab, pad=True, pair=True): + self._tokenizer = tokenizer + self._max_seq_length = max_seq_length + self._pad = pad + self._pair = pair + self._vocab = vocab + + def __call__(self, line): + """Perform transformation for sequence pairs or single sequences. + + The transformation is processed in the following steps: + - tokenize the input sequences + - insert [CLS], [SEP] as necessary + - generate type ids to indicate whether a token belongs to the first + sequence or the second sequence. + - generate valid length + + For sequence pairs, the input is a tuple of 2 strings: + text_a, text_b. + + Inputs: + text_a: 'is this jacksonville ?' + text_b: 'no it is not' + Tokenization: + text_a: 'is this jack ##son ##ville ?' + text_b: 'no it is not .' + Processed: + tokens: '[CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]' + type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 + valid_length: 14 + + For single sequences, the input is a tuple of single string: + text_a. + + Inputs: + text_a: 'the dog is hairy .' + Tokenization: + text_a: 'the dog is hairy .' + Processed: + text_a: '[CLS] the dog is hairy . [SEP]' + type_ids: 0 0 0 0 0 0 0 + valid_length: 7 + + Parameters + ---------- + line: tuple of str + Input strings. For sequence pairs, the input is a tuple of 2 strings: + (text_a, text_b). For single sequences, the input is a tuple of single + string: (text_a,). 
+ + Returns + ------- + np.array: input token ids in 'int32', shape (batch_size, seq_length) + np.array: valid length in 'int32', shape (batch_size,) + np.array: input token type ids in 'int32', shape (batch_size, seq_length) + + """ + + # convert to unicode + text_a = line[0] + if self._pair: + assert len(line) == 2 + text_b = line[1] + + tokens_a = self._tokenizer.tokenize(text_a) + tokens_b = None + + if self._pair: + tokens_b = self._tokenizer(text_b) + + if tokens_b: + # Modifies `tokens_a` and `tokens_b` in place so that the total + # length is less than the specified length. + # Account for [CLS], [SEP], [SEP] with "- 3" + self._truncate_seq_pair(tokens_a, tokens_b, self._max_seq_length - 3) + else: + # Account for [CLS] and [SEP] with "- 2" + if len(tokens_a) > self._max_seq_length - 2: + tokens_a = tokens_a[0 : (self._max_seq_length - 2)] + + # The embedding vectors for `type=0` and `type=1` were learned during + # pre-training and are added to the wordpiece embedding vector + # (and position vector). This is not *strictly* necessary since + # the [SEP] token unambiguously separates the sequences, but it makes + # it easier for the model to learn the concept of sequences. + + # For classification tasks, the first vector (corresponding to [CLS]) is + # used as as the "sentence vector". Note that this only makes sense because + # the entire model is fine-tuned. + # vocab = self._tokenizer.vocab + vocab = self._vocab + tokens = [] + tokens.append(vocab.cls_token) + tokens.extend(tokens_a) + tokens.append(vocab.sep_token) + segment_ids = [0] * len(tokens) + + if tokens_b: + tokens.extend(tokens_b) + tokens.append(vocab.sep_token) + segment_ids.extend([1] * (len(tokens) - len(segment_ids))) + + input_ids = self._tokenizer.convert_tokens_to_ids(tokens) + + # The valid length of sentences. Only real tokens are attended to. + valid_length = len(input_ids) + + if self._pad: + # Zero-pad up to the sequence length. 
+ padding_length = self._max_seq_length - valid_length + # use padding tokens for the rest + input_ids.extend([vocab[vocab.padding_token]] * padding_length) + segment_ids.extend([0] * padding_length) + + return ( + np.array(input_ids, dtype="int32"), + np.array(valid_length, dtype="int32"), + np.array(segment_ids, dtype="int32"), + ) diff --git a/ml/review/train.py b/ml/review/train.py new file mode 100644 index 0000000..7c38b37 --- /dev/null +++ b/ml/review/train.py @@ -0,0 +1,120 @@ +import gluonnlp as nlp +import numpy as np +import pandas as pd +import torch +from kobert_tokenizer import KoBERTTokenizer +from model import BERTClassifier, BERTDataset +from preprocessing import preprocessing +from torch import nn +from tqdm.notebook import tqdm +from transformers import AdamW, BertModel +from transformers.optimization import get_cosine_schedule_with_warmup + +data = pd.read_csv("review_data.csv") + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +PATH = "../model" + + +def calc_accuracy(X, Y): + max_vals, max_indices = torch.max(X, 1) + train_acc = (max_indices == Y).sum().data.cpu().numpy() / max_indices.size()[0] + return train_acc + + +max_len = 64 +batch_size = 64 +warmup_ratio = 0.1 +num_epochs = 5 +max_grad_norm = 1 +log_interval = 200 +learning_rate = 5e-5 + +ko_tokenizer = KoBERTTokenizer.from_pretrained("skt/kobert-base-v1") +ko_bertmodel = BertModel.from_pretrained("skt/kobert-base-v1", return_dict=False) +ko_vocab = nlp.vocab.BERTVocab.from_sentencepiece(ko_tokenizer.vocab_file, padding_token="[PAD]") + + +def train(data): + dataset_train, dataset_test = preprocessing(data) + data_train = BERTDataset(dataset_train, 0, 1, ko_tokenizer, ko_vocab, max_len, True, False) + data_test = BERTDataset(dataset_test, 0, 1, ko_tokenizer, ko_vocab, max_len, True, False) + + train_dataloader = torch.utils.data.DataLoader( + data_train, batch_size=batch_size, num_workers=5 + ) + test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=5) + + model = BERTClassifier(ko_bertmodel, dr_rate=0.5).to(device) + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) + ], + "weight_decay": 0.01, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate) + loss_fn = nn.CrossEntropyLoss() + + t_total = len(train_dataloader) * num_epochs + warmup_step = int(t_total * warmup_ratio) + + scheduler = get_cosine_schedule_with_warmup( + optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total + ) + + train_history = [] + test_history = [] + loss_history = [] + + for e in range(num_epochs): + train_acc = 0.0 + test_acc = 0.0 + model.train() + for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate( + tqdm(train_dataloader) + ): + optimizer.zero_grad() + token_ids = token_ids.long().to(device) + segment_ids = segment_ids.long().to(device) + valid_length = valid_length + label = label.long().to(device) + out = model(token_ids, valid_length, segment_ids) + + loss = loss_fn(out, label) + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) + optimizer.step() + scheduler.step() # Update learning rate schedule + train_acc += calc_accuracy(out, label) + if batch_id % log_interval == 0: + print( + "epoch {} batch id {} loss {} train acc 
{}".format( + e + 1, batch_id + 1, loss.data.cpu().numpy(), train_acc / (batch_id + 1) + ) + ) + train_history.append(train_acc / (batch_id + 1)) + loss_history.append(loss.data.cpu().numpy()) + print("epoch {} train acc {}".format(e + 1, train_acc / (batch_id + 1))) + model.eval() + for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate( + tqdm(test_dataloader) + ): + token_ids = token_ids.long().to(device) + segment_ids = segment_ids.long().to(device) + valid_length = valid_length + label = label.long().to(device) + out = model(token_ids, valid_length, segment_ids) + test_acc += calc_accuracy(out, label) + print("epoch {} test acc {}".format(e + 1, test_acc / (batch_id + 1))) + test_history.append(test_acc / (batch_id + 1)) + torch.save(model, PATH)