Skip to content

Commit

Permalink
Merge pull request #239 from NEOS-AI/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
YeonwooSung authored Jan 12, 2025
2 parents 07eaa8b + f94ae77 commit ea79639
Show file tree
Hide file tree
Showing 23 changed files with 429 additions and 846 deletions.
46 changes: 45 additions & 1 deletion neosearch/api/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from neosearch.models.chat_models import ChatData
from neosearch.utils.logging import Logger
from neosearch.utils.events import EventCallbackHandler
from neosearch.response.chat import ChatStreamResponse
from neosearch.response.chat import ChatStreamResponse, ChatStreamResponseV2
from neosearch.settings import get_llm_model_by_id


logger = Logger()
Expand All @@ -29,6 +30,49 @@ async def chat(
) -> ChatStreamResponse:
req_id = request.state.request_id

try:
last_message_content = data.get_last_message_content()
messages = data.get_history_messages()

# get model id, and loads the corresponding model
model_id = data.get_model_id()
llm = get_llm_model_by_id(model_id)

chat_id = data.get_chat_id()

doc_ids = data.get_chat_document_ids()
filters = generate_filters(doc_ids)
logger.log_info(f"method={request.method} | {request.url} | {req_id} | 200 | details: Creating chat engine with filters: {str(filters)}") # noqa: E501

event_handler = EventCallbackHandler()

# get chat engine, and generate response with async chat stream
chat_engine = get_custom_chat_engine(
llm=llm, verbose=False
)
response = chat_engine.astream_chat(last_message_content, messages)

logger.log_debug(f"method={request.method} | {request.url} | {req_id} | 200 | details: Chat response generated") # noqa: E501

return ChatStreamResponseV2(
request, event_handler, response, data, background_tasks, chat_id
)
except Exception as e:
logger.log_error(f"method={request.method} | {request.url} | {req_id} | 500 | details: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error in chat engine: {e}",
) from e


@r.post("/base")
async def chat_base(
request: Request,
data: ChatData,
background_tasks: BackgroundTasks,
) -> ChatStreamResponse:
req_id = request.state.request_id

try:
last_message_content = data.get_last_message_content()
messages = data.get_history_messages()
Expand Down
Empty file.
50 changes: 50 additions & 0 deletions neosearch/datastore/crud/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from sqlalchemy import select, update, delete
from sqlalchemy.orm import Session
from uuid import UUID

# custom modules
from neosearch.datastore.model.chat import Chat


class ChatCRUD:
def __init__(self, session: Session):
self.session = session

def create_chat(self, title: str, user_id: UUID, visibility: str = "private"):
"""Create a new chat."""
new_chat = Chat(
title=title,
user_id=user_id,
visibility=visibility,
)
self.session.add(new_chat)
self.session.commit()
self.session.refresh(new_chat)
return new_chat

def get_chat_by_id(self, chat_id: UUID):
"""Retrieve a chat by its ID."""
return self.session.get(Chat, chat_id)

def get_all_chats(self, user_id: UUID = None, visibility: str = None):
"""Retrieve all chats, optionally filtered by user_id and/or visibility."""
query = select(Chat)
if user_id:
query = query.where(Chat.user_id == user_id)
if visibility:
query = query.where(Chat.visibility == visibility)
return self.session.execute(query).scalars().all()

def update_chat(self, chat_id: UUID, **kwargs):
"""Update a chat by its ID."""
stmt = update(Chat).where(Chat.id == chat_id).values(**kwargs)
result = self.session.execute(stmt)
self.session.commit()
return result.rowcount # Returns the number of rows updated

def delete_chat(self, chat_id: UUID):
"""Delete a chat by its ID."""
stmt = delete(Chat).where(Chat.id == chat_id)
result = self.session.execute(stmt)
self.session.commit()
return result.rowcount # Returns the number of rows deleted
56 changes: 56 additions & 0 deletions neosearch/datastore/crud/document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from sqlalchemy.orm import Session
from uuid import UUID

# custom modules
from neosearch.datastore.model.document import Document


class DocumentCRUD:
def __init__(self, session: Session):
self.session = session

def create_document(self, title: str, kind: str, user_id: UUID, content: str = None):
"""Create a new document."""
new_document = Document(
title=title,
content=content,
kind=kind,
user_id=user_id,
)
self.session.add(new_document)
self.session.commit()
self.session.refresh(new_document)
return new_document

def get_document(self, doc_id: UUID, created_at: str):
"""Retrieve a document by its composite key (id and created_at)."""
return (
self.session.query(Document)
.filter_by(id=doc_id, created_at=created_at)
.one_or_none()
)

def get_documents_by_user(self, user_id: UUID, kind: str = None):
"""Retrieve all documents for a specific user, optionally filtered by kind."""
query = self.session.query(Document).filter_by(user_id=user_id)
if kind:
query = query.filter_by(kind=kind)
return query.all()

def update_document(self, doc_id: UUID, created_at: str, **kwargs):
"""Update a document by its composite key."""
doc = self.get_document(doc_id, created_at)
if doc:
for key, value in kwargs.items():
setattr(doc, key, value)
self.session.commit()
return doc

def delete_document(self, doc_id: UUID, created_at: str):
"""Delete a document by its composite key."""
doc = self.get_document(doc_id, created_at)
if doc:
self.session.delete(doc)
self.session.commit()
return True
return False
48 changes: 48 additions & 0 deletions neosearch/datastore/crud/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from sqlalchemy import select, update, delete
from sqlalchemy.orm import Session
from uuid import UUID

# custom modules
from neosearch.datastore.model.message import Message


class MessageCRUD:
def __init__(self, session: Session):
self.session = session

def create_message(self, chat_id: UUID, role: str, content: dict):
"""Create a new message."""
new_message = Message(
chat_id=chat_id,
role=role,
content=content,
)
self.session.add(new_message)
self.session.commit()
self.session.refresh(new_message)
return new_message

def get_message_by_id(self, message_id: UUID):
"""Retrieve a message by its ID."""
return self.session.get(Message, message_id)

def get_all_messages(self, chat_id: UUID = None):
"""Retrieve all messages, optionally filtered by chat_id."""
query = select(Message)
if chat_id:
query = query.where(Message.chat_id == chat_id)
return self.session.execute(query).scalars().all()

def update_message(self, message_id: UUID, **kwargs):
"""Update a message by its ID."""
stmt = update(Message).where(Message.id == message_id).values(**kwargs)
result = self.session.execute(stmt)
self.session.commit()
return result.rowcount # Returns the number of rows updated

def delete_message(self, message_id: UUID):
"""Delete a message by its ID."""
stmt = delete(Message).where(Message.id == message_id)
result = self.session.execute(stmt)
self.session.commit()
return result.rowcount # Returns the number of rows deleted
50 changes: 50 additions & 0 deletions neosearch/datastore/crud/vote.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from sqlalchemy.orm import Session
from uuid import UUID

# custom modules
from neosearch.datastore.model.vote import Vote


class VoteCRUD:
def __init__(self, session: Session):
self.session = session

def create_vote(self, chat_id: UUID, message_id: UUID, is_upvoted: bool):
"""Create a new vote."""
new_vote = Vote(
chat_id=chat_id,
message_id=message_id,
is_upvoted=is_upvoted,
)
self.session.add(new_vote)
self.session.commit()
return new_vote

def get_vote(self, chat_id: UUID, message_id: UUID):
"""Retrieve a vote by its composite key (chat_id and message_id)."""
return (
self.session.query(Vote)
.filter_by(chat_id=chat_id, message_id=message_id)
.one_or_none()
)

def get_votes_by_chat(self, chat_id: UUID):
"""Retrieve all votes for a specific chat."""
return self.session.query(Vote).filter_by(chat_id=chat_id).all()

def update_vote(self, chat_id: UUID, message_id: UUID, is_upvoted: bool):
"""Update a vote's is_upvoted value."""
vote = self.get_vote(chat_id, message_id)
if vote:
vote.is_upvoted = is_upvoted
self.session.commit()
return vote

def delete_vote(self, chat_id: UUID, message_id: UUID):
"""Delete a vote by its composite key."""
vote = self.get_vote(chat_id, message_id)
if vote:
self.session.delete(vote)
self.session.commit()
return True
return False
Empty file.
27 changes: 27 additions & 0 deletions neosearch/datastore/model/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from sqlalchemy import Column, Text, ForeignKey, TIMESTAMP
from sqlalchemy.dialects.postgresql import UUID, ENUM
from sqlalchemy.sql import func
from sqlalchemy.ext.declarative import declarative_base


Base = declarative_base()

visibility_enum = ENUM('public', 'private', name='visibility_enum', create_type=False)

class Chat(Base):
__tablename__ = 'Chat'

id = Column(
UUID(as_uuid=True),
primary_key=True,
server_default=func.gen_random_uuid(),
nullable=False,
)
created_at = Column(TIMESTAMP, nullable=False, server_default=func.now())
title = Column(Text, nullable=False)
user_id = Column(UUID(as_uuid=True), ForeignKey('User.id'), nullable=False)
visibility = Column(
visibility_enum,
nullable=False,
server_default="private"
)
27 changes: 27 additions & 0 deletions neosearch/datastore/model/document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from sqlalchemy import Column, Text, ForeignKey, TIMESTAMP, PrimaryKeyConstraint
from sqlalchemy.dialects.postgresql import UUID, ENUM
from sqlalchemy.sql import func
from sqlalchemy.ext.declarative import declarative_base
from uuid_extensions import uuid7str


Base = declarative_base()

# Define ENUM for the 'kind' field
kind_enum = ENUM('text', 'code', name='kind_enum', create_type=False)


class Document(Base):
__tablename__ = 'Document'

id = Column(UUID(as_uuid=True), nullable=False, server_default=uuid7str())
created_at = Column(TIMESTAMP, nullable=False, server_default=func.now())
title = Column(Text, nullable=False)
content = Column(Text, nullable=True)
kind = Column(kind_enum, nullable=False, server_default='text')
user_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)

# Composite primary key
__table_args__ = (
PrimaryKeyConstraint('id', 'created_at', name='document_pk'),
)
21 changes: 21 additions & 0 deletions neosearch/datastore/model/message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from sqlalchemy import Column, String, JSON, ForeignKey, TIMESTAMP
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.sql import func
from sqlalchemy.ext.declarative import declarative_base
from uuid_extensions import uuid7str


Base = declarative_base()


class Message(Base):
__tablename__ = 'Message'

id = Column(UUID(as_uuid=True), primary_key=True, server_default=uuid7str(), nullable=False)
chat_id = Column(UUID(as_uuid=True), ForeignKey('Chat.id'), nullable=False)
role = Column(String, nullable=False)
content = Column(JSON, nullable=False)
created_at = Column(TIMESTAMP, nullable=False, server_default=func.now())

def __repr__(self):
return f"<Message(id={self.id}, chat_id={self.chat_id}, role={self.role}, content={self.content}, created_at={self.created_at})>"
19 changes: 19 additions & 0 deletions neosearch/datastore/model/vote.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from sqlalchemy import Column, Boolean, ForeignKey, PrimaryKeyConstraint
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.ext.declarative import declarative_base


Base = declarative_base()


class Vote(Base):
__tablename__ = 'Vote'

chat_id = Column(UUID(as_uuid=True), ForeignKey('Chat.id'), nullable=False)
message_id = Column(UUID(as_uuid=True), ForeignKey('Message.id'), nullable=False)
is_upvoted = Column(Boolean, nullable=False)

# Composite primary key
__table_args__ = (
PrimaryKeyConstraint('chat_id', 'message_id', name='vote_pk'),
)
3 changes: 0 additions & 3 deletions neosearch/datastore/providers/__init__.py

This file was deleted.

Loading

0 comments on commit ea79639

Please sign in to comment.