-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #239 from NEOS-AI/dev
Dev
- Loading branch information
Showing
23 changed files
with
429 additions
and
846 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from sqlalchemy import select, update, delete | ||
from sqlalchemy.orm import Session | ||
from uuid import UUID | ||
|
||
# custom modules | ||
from neosearch.datastore.model.chat import Chat | ||
|
||
|
||
class ChatCRUD: | ||
def __init__(self, session: Session): | ||
self.session = session | ||
|
||
def create_chat(self, title: str, user_id: UUID, visibility: str = "private"): | ||
"""Create a new chat.""" | ||
new_chat = Chat( | ||
title=title, | ||
user_id=user_id, | ||
visibility=visibility, | ||
) | ||
self.session.add(new_chat) | ||
self.session.commit() | ||
self.session.refresh(new_chat) | ||
return new_chat | ||
|
||
def get_chat_by_id(self, chat_id: UUID): | ||
"""Retrieve a chat by its ID.""" | ||
return self.session.get(Chat, chat_id) | ||
|
||
def get_all_chats(self, user_id: UUID = None, visibility: str = None): | ||
"""Retrieve all chats, optionally filtered by user_id and/or visibility.""" | ||
query = select(Chat) | ||
if user_id: | ||
query = query.where(Chat.user_id == user_id) | ||
if visibility: | ||
query = query.where(Chat.visibility == visibility) | ||
return self.session.execute(query).scalars().all() | ||
|
||
def update_chat(self, chat_id: UUID, **kwargs): | ||
"""Update a chat by its ID.""" | ||
stmt = update(Chat).where(Chat.id == chat_id).values(**kwargs) | ||
result = self.session.execute(stmt) | ||
self.session.commit() | ||
return result.rowcount # Returns the number of rows updated | ||
|
||
def delete_chat(self, chat_id: UUID): | ||
"""Delete a chat by its ID.""" | ||
stmt = delete(Chat).where(Chat.id == chat_id) | ||
result = self.session.execute(stmt) | ||
self.session.commit() | ||
return result.rowcount # Returns the number of rows deleted |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from sqlalchemy.orm import Session | ||
from uuid import UUID | ||
|
||
# custom modules | ||
from neosearch.datastore.model.document import Document | ||
|
||
|
||
class DocumentCRUD: | ||
def __init__(self, session: Session): | ||
self.session = session | ||
|
||
def create_document(self, title: str, kind: str, user_id: UUID, content: str = None): | ||
"""Create a new document.""" | ||
new_document = Document( | ||
title=title, | ||
content=content, | ||
kind=kind, | ||
user_id=user_id, | ||
) | ||
self.session.add(new_document) | ||
self.session.commit() | ||
self.session.refresh(new_document) | ||
return new_document | ||
|
||
def get_document(self, doc_id: UUID, created_at: str): | ||
"""Retrieve a document by its composite key (id and created_at).""" | ||
return ( | ||
self.session.query(Document) | ||
.filter_by(id=doc_id, created_at=created_at) | ||
.one_or_none() | ||
) | ||
|
||
def get_documents_by_user(self, user_id: UUID, kind: str = None): | ||
"""Retrieve all documents for a specific user, optionally filtered by kind.""" | ||
query = self.session.query(Document).filter_by(user_id=user_id) | ||
if kind: | ||
query = query.filter_by(kind=kind) | ||
return query.all() | ||
|
||
def update_document(self, doc_id: UUID, created_at: str, **kwargs): | ||
"""Update a document by its composite key.""" | ||
doc = self.get_document(doc_id, created_at) | ||
if doc: | ||
for key, value in kwargs.items(): | ||
setattr(doc, key, value) | ||
self.session.commit() | ||
return doc | ||
|
||
def delete_document(self, doc_id: UUID, created_at: str): | ||
"""Delete a document by its composite key.""" | ||
doc = self.get_document(doc_id, created_at) | ||
if doc: | ||
self.session.delete(doc) | ||
self.session.commit() | ||
return True | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from sqlalchemy import select, update, delete | ||
from sqlalchemy.orm import Session | ||
from uuid import UUID | ||
|
||
# custom modules | ||
from neosearch.datastore.model.message import Message | ||
|
||
|
||
class MessageCRUD: | ||
def __init__(self, session: Session): | ||
self.session = session | ||
|
||
def create_message(self, chat_id: UUID, role: str, content: dict): | ||
"""Create a new message.""" | ||
new_message = Message( | ||
chat_id=chat_id, | ||
role=role, | ||
content=content, | ||
) | ||
self.session.add(new_message) | ||
self.session.commit() | ||
self.session.refresh(new_message) | ||
return new_message | ||
|
||
def get_message_by_id(self, message_id: UUID): | ||
"""Retrieve a message by its ID.""" | ||
return self.session.get(Message, message_id) | ||
|
||
def get_all_messages(self, chat_id: UUID = None): | ||
"""Retrieve all messages, optionally filtered by chat_id.""" | ||
query = select(Message) | ||
if chat_id: | ||
query = query.where(Message.chat_id == chat_id) | ||
return self.session.execute(query).scalars().all() | ||
|
||
def update_message(self, message_id: UUID, **kwargs): | ||
"""Update a message by its ID.""" | ||
stmt = update(Message).where(Message.id == message_id).values(**kwargs) | ||
result = self.session.execute(stmt) | ||
self.session.commit() | ||
return result.rowcount # Returns the number of rows updated | ||
|
||
def delete_message(self, message_id: UUID): | ||
"""Delete a message by its ID.""" | ||
stmt = delete(Message).where(Message.id == message_id) | ||
result = self.session.execute(stmt) | ||
self.session.commit() | ||
return result.rowcount # Returns the number of rows deleted |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from sqlalchemy.orm import Session | ||
from uuid import UUID | ||
|
||
# custom modules | ||
from neosearch.datastore.model.vote import Vote | ||
|
||
|
||
class VoteCRUD: | ||
def __init__(self, session: Session): | ||
self.session = session | ||
|
||
def create_vote(self, chat_id: UUID, message_id: UUID, is_upvoted: bool): | ||
"""Create a new vote.""" | ||
new_vote = Vote( | ||
chat_id=chat_id, | ||
message_id=message_id, | ||
is_upvoted=is_upvoted, | ||
) | ||
self.session.add(new_vote) | ||
self.session.commit() | ||
return new_vote | ||
|
||
def get_vote(self, chat_id: UUID, message_id: UUID): | ||
"""Retrieve a vote by its composite key (chat_id and message_id).""" | ||
return ( | ||
self.session.query(Vote) | ||
.filter_by(chat_id=chat_id, message_id=message_id) | ||
.one_or_none() | ||
) | ||
|
||
def get_votes_by_chat(self, chat_id: UUID): | ||
"""Retrieve all votes for a specific chat.""" | ||
return self.session.query(Vote).filter_by(chat_id=chat_id).all() | ||
|
||
def update_vote(self, chat_id: UUID, message_id: UUID, is_upvoted: bool): | ||
"""Update a vote's is_upvoted value.""" | ||
vote = self.get_vote(chat_id, message_id) | ||
if vote: | ||
vote.is_upvoted = is_upvoted | ||
self.session.commit() | ||
return vote | ||
|
||
def delete_vote(self, chat_id: UUID, message_id: UUID): | ||
"""Delete a vote by its composite key.""" | ||
vote = self.get_vote(chat_id, message_id) | ||
if vote: | ||
self.session.delete(vote) | ||
self.session.commit() | ||
return True | ||
return False |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from sqlalchemy import Column, Text, ForeignKey, TIMESTAMP | ||
from sqlalchemy.dialects.postgresql import UUID, ENUM | ||
from sqlalchemy.sql import func | ||
from sqlalchemy.ext.declarative import declarative_base | ||
|
||
|
||
Base = declarative_base() | ||
|
||
visibility_enum = ENUM('public', 'private', name='visibility_enum', create_type=False) | ||
|
||
class Chat(Base): | ||
__tablename__ = 'Chat' | ||
|
||
id = Column( | ||
UUID(as_uuid=True), | ||
primary_key=True, | ||
server_default=func.gen_random_uuid(), | ||
nullable=False, | ||
) | ||
created_at = Column(TIMESTAMP, nullable=False, server_default=func.now()) | ||
title = Column(Text, nullable=False) | ||
user_id = Column(UUID(as_uuid=True), ForeignKey('User.id'), nullable=False) | ||
visibility = Column( | ||
visibility_enum, | ||
nullable=False, | ||
server_default="private" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from sqlalchemy import Column, Text, ForeignKey, TIMESTAMP, PrimaryKeyConstraint | ||
from sqlalchemy.dialects.postgresql import UUID, ENUM | ||
from sqlalchemy.sql import func | ||
from sqlalchemy.ext.declarative import declarative_base | ||
from uuid_extensions import uuid7str | ||
|
||
|
||
Base = declarative_base() | ||
|
||
# Define ENUM for the 'kind' field | ||
kind_enum = ENUM('text', 'code', name='kind_enum', create_type=False) | ||
|
||
|
||
class Document(Base): | ||
__tablename__ = 'Document' | ||
|
||
id = Column(UUID(as_uuid=True), nullable=False, server_default=uuid7str()) | ||
created_at = Column(TIMESTAMP, nullable=False, server_default=func.now()) | ||
title = Column(Text, nullable=False) | ||
content = Column(Text, nullable=True) | ||
kind = Column(kind_enum, nullable=False, server_default='text') | ||
user_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), nullable=False) | ||
|
||
# Composite primary key | ||
__table_args__ = ( | ||
PrimaryKeyConstraint('id', 'created_at', name='document_pk'), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from sqlalchemy import Column, String, JSON, ForeignKey, TIMESTAMP | ||
from sqlalchemy.dialects.postgresql import UUID | ||
from sqlalchemy.sql import func | ||
from sqlalchemy.ext.declarative import declarative_base | ||
from uuid_extensions import uuid7str | ||
|
||
|
||
Base = declarative_base() | ||
|
||
|
||
class Message(Base): | ||
__tablename__ = 'Message' | ||
|
||
id = Column(UUID(as_uuid=True), primary_key=True, server_default=uuid7str(), nullable=False) | ||
chat_id = Column(UUID(as_uuid=True), ForeignKey('Chat.id'), nullable=False) | ||
role = Column(String, nullable=False) | ||
content = Column(JSON, nullable=False) | ||
created_at = Column(TIMESTAMP, nullable=False, server_default=func.now()) | ||
|
||
def __repr__(self): | ||
return f"<Message(id={self.id}, chat_id={self.chat_id}, role={self.role}, content={self.content}, created_at={self.created_at})>" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from sqlalchemy import Column, Boolean, ForeignKey, PrimaryKeyConstraint | ||
from sqlalchemy.dialects.postgresql import UUID | ||
from sqlalchemy.ext.declarative import declarative_base | ||
|
||
|
||
Base = declarative_base() | ||
|
||
|
||
class Vote(Base): | ||
__tablename__ = 'Vote' | ||
|
||
chat_id = Column(UUID(as_uuid=True), ForeignKey('Chat.id'), nullable=False) | ||
message_id = Column(UUID(as_uuid=True), ForeignKey('Message.id'), nullable=False) | ||
is_upvoted = Column(Boolean, nullable=False) | ||
|
||
# Composite primary key | ||
__table_args__ = ( | ||
PrimaryKeyConstraint('chat_id', 'message_id', name='vote_pk'), | ||
) |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.