Skip to content

Commit

Permalink
[FEAT] QA 내용 반영
Browse files Browse the repository at this point in the history
[FEAT] QA 내요 반영
resolves #1 #2
  • Loading branch information
sunnyineverywhere authored Sep 11, 2023
2 parents a0a9b1e + 64fc4f0 commit 6fe1985
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 14 deletions.
43 changes: 35 additions & 8 deletions crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import scheme
from sqlalchemy.orm import Session, joinedload
from model import Account, Question, InterviewQuestion, Inteverview
from model import Account, Question, InterviewQuestion, Inteverview, Answer
import random


Expand All @@ -20,23 +20,33 @@ def find_account_by_email(db: Session, email):
# question
def find_all_question(db: Session):
db_questions = db.query(Question).all()
return random.sample(db_questions, 10)
return list(db_questions)


def find_question_by_id(db: Session, question_id: int):
return db.query(Question).get(question_id)


def find_question_by_categories(db: Session, categories: list):
db_questions = None
if len(categories) == 1:
categories.append(None)
categories.append(None)
db_questions = db.query(Question).filter(
categories[0] == Question.category
)
elif len(categories) == 2:
categories.append(None)
db_questions = db.query(Question).filter(or_(
categories[0] == Question.category, categories[1] == Question.category
))
elif len(categories) == 3:
db_questions = db.query(Question).filter(or_(
categories[0] == Question.category, categories[1] == Question.category, categories[2] == Question.category))
else:
db_questions = db.query(Question)

db_questions = db.query(Question).filter(or_(
categories[0] == Question.category, categories[1] == Question.category, categories[2] == Question.category))
return random.sample(list(db_questions), 10)
if len(list(db_questions)) >= 10:
return random.sample(list(db_questions), 10)
else:
return list(db_questions)


# interview
Expand Down Expand Up @@ -116,6 +126,23 @@ def update_interview_question_additional_answer(db: Session, sequence: int, ques
db.commit()


def update_question_gpt_answer(db: Session, db_question, gpt_answer):
answer = Answer(
question=db_question.id,
content=gpt_answer,
type="TYPE_GPT"
)
db.add(answer)
db.commit()
db.refresh(answer)

def delete_gpt_answers(db: Session):
db.query(Answer).filter(
Answer.type == "TYPE_GPT"
).delete()
db.commit()


def find_interview_question_by_pk(db, iq_id: int):
return list(db.query(InterviewQuestion).options(
joinedload(InterviewQuestion.question_model)).filter(InterviewQuestion.id == iq_id))[0]
18 changes: 13 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,7 @@ async def root():
async def start_interview(req: InterviewStartReq, Authorization: str | None = Header(default=None),
db: Session = Depends(get_db)):
account = crud.find_account_by_email(db=db, email=jwt_util.decode_jwt(access_token=Authorization))
questions = []
if len(req.categories) < 1:
questions = crud.find_all_question(db=db)
else:
questions = crud.find_question_by_categories(db=db, categories=req.categories)
questions = crud.find_question_by_categories(db=db, categories=req.categories)
interview, interview_question = crud.create_interview(db=db, account=account, questions=questions,
categories=req.categories)
return {
Expand Down Expand Up @@ -119,3 +115,15 @@ async def answer_interview_additional(req: AdditionalInterviewReq, Authorization
@app.get("/interview/all/{id}")
async def get_interviw_question_by_pk(id: int, db: Session = Depends(get_db)):
return crud.find_interview_question_by_pk(db=db, iq_id=id)


# 배포금지
@app.put("/gpt/question")
async def modify_gpt_question(db: Session = Depends(get_db)):
crud.delete_gpt_answers(db)
questions = crud.find_all_question(db=db)
for q in questions:
answer = gpt_util.get_gpt_answer_static(question=q)
print(q.title)
print(answer)
crud.update_question_gpt_answer(db=db, db_question=q, gpt_answer=answer)
8 changes: 8 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ class Question(Base):
category = Column(String)


class Answer(Base):
__tablename__ = "answer"
id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
question = Column(BigInteger, ForeignKey("question.id"))
content = Column(String)
type = Column(String)


class Inteverview(Base):
__tablename__ = "interview"
id = Column(BigInteger, primary_key=True, index=True, autoincrement=True)
Expand Down
13 changes: 12 additions & 1 deletion util/gpt_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,21 @@ def get_gpt_answer(question, prompt, model="gpt-3.5-turbo-0613"):
return response.choices[0].message["content"]


def get_gpt_answer_static(question, model="gpt-3.5-turbo-0613"):
content = "나는 " + question.title + "이라는 질문에 대해 구체적이고 정확한 답변을 알려줘. 그리고 답변에 대해 설명해줘."
messages = [{"role": "assistant", "content": content}]
response = openai.ChatCompletion.create(
model=model,
messages=messages,
temperature=0
)
return response.choices[0].message["content"]


def get_gpt_questions(question, answer, model="gpt-3.5-turbo-0613"):
content = "나는 " + question + "이라는 질문에 대해 나는 " + answer + " 라고 답변했어." \
+ "네가 면접관이라면, 나에게 할 추가 질문을 3개 줘." \
+ "json 형식으로, {'question_1':"", 'question_2':"", 'question_3':""}로 응답해줘."
+ "json 형식으로, {'question_1':"", 'question_2':"", 'question_3':""}로 응답해줘."
messages = [{"role": "assistant", "content": content}]
response = openai.ChatCompletion.create(
model=model,
Expand Down

0 comments on commit 6fe1985

Please sign in to comment.