From cb3e32c1974e5a533c8507e1820bb643fdde53d0 Mon Sep 17 00:00:00 2001 From: Elijah Ahianyo Date: Tue, 28 Nov 2023 08:20:57 +0000 Subject: [PATCH] Session.query to Session.exec (#174) --- basic_crud/basic_crud/basic_crud.py | 5 +- crm/crm/state/login.py | 6 ++- crm/crm/state/state.py | 2 +- gpt/gpt/gpt.py | 66 ++++++++++++++++----------- local_auth/local_auth/base_state.py | 11 +++-- local_auth/local_auth/login.py | 5 +- local_auth/local_auth/registration.py | 5 +- sales/sales/sales.py | 14 ++++-- twitter/twitter/state/auth.py | 7 ++- twitter/twitter/state/home.py | 39 ++++++++-------- 10 files changed, 96 insertions(+), 64 deletions(-) diff --git a/basic_crud/basic_crud/basic_crud.py b/basic_crud/basic_crud/basic_crud.py index 4ff50164..0095d48c 100644 --- a/basic_crud/basic_crud/basic_crud.py +++ b/basic_crud/basic_crud/basic_crud.py @@ -3,6 +3,7 @@ import json import httpx +from sqlmodel import select import reflex as rx @@ -35,7 +36,7 @@ class State(rx.State): def load_product(self): with rx.session() as session: - self.products = session.query(Product).all() + self.products = session.exec(select(Product)).all() yield State.reload_product @rx.background @@ -45,7 +46,7 @@ async def reload_product(self): if self.db_updated: async with self: with rx.session() as session: - self.products = session.query(Product).all() + self.products = session.exec(select(Product)).all() self._db_updated = False @rx.var diff --git a/crm/crm/state/login.py b/crm/crm/state/login.py index d75dfdfb..907e0f3f 100644 --- a/crm/crm/state/login.py +++ b/crm/crm/state/login.py @@ -1,4 +1,6 @@ import reflex as rx +from sqlmodel import select + from .models import User from .state import State @@ -11,7 +13,7 @@ class LoginState(State): def log_in(self): with rx.session() as sess: - user = sess.query(User).filter_by(email=self.email_field).first() + user = sess.exec(select(User).where(User.email == self.email_field)).first() if user and user.password == self.password_field: self.user = user return rx.redirect("/") @@ -20,7 +22,7 @@ def log_in(self): def sign_up(self): with rx.session() as sess: - user = sess.query(User).filter_by(email=self.email_field).first() + user = sess.exec(select(User).where(User.email == self.email_field)).first() if user: return rx.window_alert( "Looks like you’re already registered! Try logging in instead." diff --git a/crm/crm/state/state.py b/crm/crm/state/state.py index aeb2058f..52b2016b 100644 --- a/crm/crm/state/state.py +++ b/crm/crm/state/state.py @@ -1,6 +1,6 @@ from typing import Optional import reflex as rx -from .models import User, Contact +from .models import User class State(rx.State): diff --git a/gpt/gpt/gpt.py b/gpt/gpt/gpt.py index a176c1ab..2a837fcc 100644 --- a/gpt/gpt/gpt.py +++ b/gpt/gpt/gpt.py @@ -4,6 +4,7 @@ from openai import OpenAI import reflex as rx +from sqlmodel import select from .helpers import navbar @@ -29,6 +30,7 @@ class Question(rx.Model, table=True): class State(rx.State): """The app state.""" + show_columns = ["Question", "Answer"] username: str = "" password: str = "" @@ -42,21 +44,22 @@ def questions(self) -> list[Question]: """Get the users saved questions and answers from the database.""" with rx.session() as session: if self.logged_in: - qa = ( - session.query(Question) + qa = session.exec( + select(Question) .where(Question.username == self.username) .distinct(Question.prompt) .order_by(Question.timestamp.desc()) .limit(MAX_QUESTIONS) - .all() - ) + ).all() return [[q.prompt, q.answer] for q in qa] else: return [] def login(self): with rx.session() as session: - user = session.query(User).where(User.username == self.username).first() + user = session.exec( + select(User).where(User.username == self.username) + ).first() if (user and user.password == self.password) or self.username == "admin": self.logged_in = True return rx.redirect("/home") @@ -76,25 +79,28 @@ def signup(self): return rx.redirect("/home") def get_result(self): - if ( - rx.session() - .query(Question) - .where(Question.username == self.username) - .where(Question.prompt == self.prompt) - .first() - or rx.session() - .query(Question) - .where(Question.username == self.username) - .where( - Question.timestamp - > datetime.datetime.now() - datetime.timedelta(days=1) - ) - .count() - > MAX_QUESTIONS - ): - return rx.window_alert( - "You have already asked this question or have asked too many questions in the past 24 hours." - ) + with rx.session as session: + if ( + session.exec( + select(Question) + .where(Question.username == self.username) + .where(Question.prompt == self.prompt) + ).first() + or len( + session.exec( + select(Question) + .where(Question.username == self.username) + .where( + Question.timestamp + > datetime.datetime.now() - datetime.timedelta(days=1) + ) + ).all() + ) + > MAX_QUESTIONS + ): + return rx.window_alert( + "You have already asked this question or have asked too many questions in the past 24 hours." + ) try: response = client.completions.create( model="text-davinci-002", @@ -180,7 +186,12 @@ def login(): return rx.center( rx.vstack( rx.input(on_blur=State.set_username, placeholder="Username", width="100%"), - rx.input(type_="password", on_blur=State.set_password, placeholder="Password", width="100%"), + rx.input( + type_="password", + on_blur=State.set_password, + placeholder="Password", + width="100%", + ), rx.button("Login", on_click=State.login, width="100%"), rx.link(rx.button("Sign Up", width="100%"), href="/signup", width="100%"), ), @@ -201,7 +212,10 @@ def signup(): on_blur=State.set_username, placeholder="Username", width="100%" ), rx.input( - type_="password", on_blur=State.set_password, placeholder="Password", width="100%" + type_="password", + on_blur=State.set_password, + placeholder="Password", + width="100%", ), rx.input( type_="password", diff --git a/local_auth/local_auth/base_state.py b/local_auth/local_auth/base_state.py index 957f07fd..99b707bb 100644 --- a/local_auth/local_auth/base_state.py +++ b/local_auth/local_auth/base_state.py @@ -31,13 +31,14 @@ def authenticated_user(self) -> User: corresponding to the currently authenticated user. """ with rx.session() as session: - result = session.query(User, AuthSession).where( + result = session.exec( + select(User, AuthSession).where( AuthSession.session_id == self.auth_token, AuthSession.expiration >= datetime.datetime.now(datetime.timezone.utc), User.id == AuthSession.user_id, - ).first() - + ), + ).first() if result: user, session = result return user @@ -55,7 +56,9 @@ def is_authenticated(self) -> bool: def do_logout(self) -> None: """Destroy AuthSessions associated with the auth_token.""" with rx.session() as session: - for auth_session in session.query(AuthSession).filter_by(session_id=self.auth_token).all(): + for auth_session in session.exec( + select(AuthSession).where(AuthSession.session_id == self.auth_token) + ).all(): session.delete(auth_session) session.commit() self.auth_token = self.auth_token diff --git a/local_auth/local_auth/login.py b/local_auth/local_auth/login.py index 174e1e44..3f9012fb 100644 --- a/local_auth/local_auth/login.py +++ b/local_auth/local_auth/login.py @@ -1,5 +1,6 @@ """Login page and authentication logic.""" import reflex as rx +from sqlmodel import select from .base_state import State from .user import User @@ -25,7 +26,9 @@ def on_submit(self, form_data) -> rx.event.EventSpec: username = form_data["username"] password = form_data["password"] with rx.session() as session: - user = session.query(User).filter_by(username=username).first() + user = session.exec( + select(User).where(User.username == username) + ).one_or_none() if user is not None and not user.enabled: self.error_message = "This account is disabled." return rx.set_value("password", "") diff --git a/local_auth/local_auth/registration.py b/local_auth/local_auth/registration.py index 67857f0f..55cb4da0 100644 --- a/local_auth/local_auth/registration.py +++ b/local_auth/local_auth/registration.py @@ -5,6 +5,7 @@ from collections.abc import AsyncGenerator import reflex as rx +from sqlmodel import select from .base_state import State from .login import LOGIN_ROUTE, REGISTER_ROUTE @@ -33,7 +34,9 @@ async def handle_registration( self.error_message = "Username cannot be empty" yield rx.set_focus("username") return - existing_user = session.query(User).filter_by(username=username).first() + existing_user = session.exec( + select(User).where(User.username == username) + ).one_or_none() if existing_user is not None: self.error_message = ( f"Username {username} is already registered. Try a different name" diff --git a/sales/sales/sales.py b/sales/sales/sales.py index 2c4c2260..eb1feb8e 100644 --- a/sales/sales/sales.py +++ b/sales/sales/sales.py @@ -1,6 +1,7 @@ from openai import OpenAI import reflex as rx +from sqlmodel import select from .models import Customer @@ -64,7 +65,9 @@ class State(rx.State): def add_customer(self): """Add a customer to the database.""" with rx.session() as session: - if session.query(Customer).filter_by(email=self.email).first(): + if session.exec( + select(Customer).where(Customer.email == self.email) + ).first(): return rx.window_alert("User already exists") session.add( Customer( @@ -91,7 +94,10 @@ def onboarding_page(self): def delete_customer(self, email: str): """Delete a customer from the database.""" with rx.session() as session: - session.query(Customer).filter_by(email=email).delete() + customer = session.exec( + select(Customer).where(Customer.email == email) + ).first() + session.delete(customer) session.commit() generate_email_data: dict = {} @@ -117,7 +123,7 @@ async def call_openai(self): # save the data related to email_content self.email_content_data = response.choices[0].text # update layout of email_content manually - return rx.set_value("email_content", self.email_content_data) + return rx.set_value("email_content", self.email_content_data) def generate_email( self, @@ -144,7 +150,7 @@ def generate_email( def get_users(self) -> list[Customer]: """Get all users from the database.""" with rx.session() as session: - self.users = session.query(Customer).all() + self.users = session.exec(select(Customer)).all() return self.users def open_text_area(self): diff --git a/twitter/twitter/state/auth.py b/twitter/twitter/state/auth.py index b6649039..67da8a6b 100644 --- a/twitter/twitter/state/auth.py +++ b/twitter/twitter/state/auth.py @@ -1,5 +1,6 @@ """The authentication state.""" import reflex as rx +from sqlmodel import select from .base import State, User @@ -16,7 +17,7 @@ def signup(self): with rx.session() as session: if self.password != self.confirm_password: return rx.window_alert("Passwords do not match.") - if session.query(User).filter_by(username=self.username).first(): + if session.exec(select(User).where(User.username == self.username)).first(): return rx.window_alert("Username already exists.") self.user = User(username=self.username, password=self.password) session.add(self.user) @@ -27,7 +28,9 @@ def signup(self): def login(self): """Log in a user.""" with rx.session() as session: - user = session.query(User).filter_by(username=self.username).first() + user = session.exec( + select(User).where(User.username == self.username) + ).first() if user and user.password == self.password: self.user = user return rx.redirect("/") diff --git a/twitter/twitter/state/home.py b/twitter/twitter/state/home.py index faceedd1..36abe766 100644 --- a/twitter/twitter/state/home.py +++ b/twitter/twitter/state/home.py @@ -2,6 +2,7 @@ from datetime import datetime import reflex as rx +from sqlmodel import select from .base import Follows, State, Tweet, User @@ -33,13 +34,11 @@ def get_tweets(self): """Get tweets from the database.""" with rx.session() as session: if self.search: - self.tweets = ( - session.query(Tweet) - .filter(Tweet.content.contains(self.search)) - .all()[::-1] - ) + self.tweets = session.exec( + select(Tweet).where(Tweet.content.contains(self.search)) + ).all()[::-1] else: - self.tweets = session.query(Tweet).all()[::-1] + self.tweets = session.exec(select(Tweet)).all()[::-1] def set_search(self, search): """Set the search query.""" @@ -60,11 +59,11 @@ def following(self) -> list[Follows]: """Get a list of users the current user is following.""" if self.logged_in: with rx.session() as session: - return ( - session.query(Follows) - .filter(Follows.follower_username == self.user.username) - .all() - ) + return session.exec( + select(Follows).where( + Follows.follower_username == self.user.username + ) + ).all() return [] @rx.var @@ -72,11 +71,11 @@ def followers(self) -> list[Follows]: """Get a list of users following the current user.""" if self.logged_in: with rx.session() as session: - return ( - session.query(Follows) - .where(Follows.followed_username == self.user.username) - .all() - ) + return session.exec( + select(Follows).where( + Follows.followed_username == self.user.username + ) + ).all() return [] @rx.var @@ -85,13 +84,11 @@ def search_users(self) -> list[User]: if self.friend != "": with rx.session() as session: current_username = self.user.username if self.user is not None else "" - users = ( - session.query(User) - .filter( + users = session.exec( + select(User).where( User.username.contains(self.friend), User.username != current_username, ) - .all() - ) + ).all() return users return []