Skip to content

Commit

Permalink
Session.query to Session.exec (#174)
Browse files Browse the repository at this point in the history
  • Loading branch information
ElijahAhianyo authored Nov 28, 2023
1 parent d71a4eb commit cb3e32c
Show file tree
Hide file tree
Showing 10 changed files with 96 additions and 64 deletions.
5 changes: 3 additions & 2 deletions basic_crud/basic_crud/basic_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json

import httpx
from sqlmodel import select

import reflex as rx

Expand Down Expand Up @@ -35,7 +36,7 @@ class State(rx.State):

def load_product(self):
with rx.session() as session:
self.products = session.query(Product).all()
self.products = session.exec(select(Product)).all()
yield State.reload_product

@rx.background
Expand All @@ -45,7 +46,7 @@ async def reload_product(self):
if self.db_updated:
async with self:
with rx.session() as session:
self.products = session.query(Product).all()
self.products = session.exec(select(Product)).all()
self._db_updated = False

@rx.var
Expand Down
6 changes: 4 additions & 2 deletions crm/crm/state/login.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import reflex as rx
from sqlmodel import select

from .models import User
from .state import State

Expand All @@ -11,7 +13,7 @@ class LoginState(State):

def log_in(self):
with rx.session() as sess:
user = sess.query(User).filter_by(email=self.email_field).first()
user = sess.exec(select(User).where(User.email == self.email_field)).first()
if user and user.password == self.password_field:
self.user = user
return rx.redirect("/")
Expand All @@ -20,7 +22,7 @@ def log_in(self):

def sign_up(self):
with rx.session() as sess:
user = sess.query(User).filter_by(email=self.email_field).first()
user = sess.exec(select(User).where(User.email == self.email_field)).first()
if user:
return rx.window_alert(
"Looks like you’re already registered! Try logging in instead."
Expand Down
2 changes: 1 addition & 1 deletion crm/crm/state/state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional
import reflex as rx
from .models import User, Contact
from .models import User


class State(rx.State):
Expand Down
66 changes: 40 additions & 26 deletions gpt/gpt/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from openai import OpenAI

import reflex as rx
from sqlmodel import select

from .helpers import navbar

Expand All @@ -29,6 +30,7 @@ class Question(rx.Model, table=True):

class State(rx.State):
"""The app state."""

show_columns = ["Question", "Answer"]
username: str = ""
password: str = ""
Expand All @@ -42,21 +44,22 @@ def questions(self) -> list[Question]:
"""Get the users saved questions and answers from the database."""
with rx.session() as session:
if self.logged_in:
qa = (
session.query(Question)
qa = session.exec(
select(Question)
.where(Question.username == self.username)
.distinct(Question.prompt)
.order_by(Question.timestamp.desc())
.limit(MAX_QUESTIONS)
.all()
)
).all()
return [[q.prompt, q.answer] for q in qa]
else:
return []

def login(self):
with rx.session() as session:
user = session.query(User).where(User.username == self.username).first()
user = session.exec(
select(User).where(User.username == self.username)
).first()
if (user and user.password == self.password) or self.username == "admin":
self.logged_in = True
return rx.redirect("/home")
Expand All @@ -76,25 +79,28 @@ def signup(self):
return rx.redirect("/home")

def get_result(self):
if (
rx.session()
.query(Question)
.where(Question.username == self.username)
.where(Question.prompt == self.prompt)
.first()
or rx.session()
.query(Question)
.where(Question.username == self.username)
.where(
Question.timestamp
> datetime.datetime.now() - datetime.timedelta(days=1)
)
.count()
> MAX_QUESTIONS
):
return rx.window_alert(
"You have already asked this question or have asked too many questions in the past 24 hours."
)
with rx.session as session:
if (
session.exec(
select(Question)
.where(Question.username == self.username)
.where(Question.prompt == self.prompt)
).first()
or len(
session.exec(
select(Question)
.where(Question.username == self.username)
.where(
Question.timestamp
> datetime.datetime.now() - datetime.timedelta(days=1)
)
).all()
)
> MAX_QUESTIONS
):
return rx.window_alert(
"You have already asked this question or have asked too many questions in the past 24 hours."
)
try:
response = client.completions.create(
model="text-davinci-002",
Expand Down Expand Up @@ -180,7 +186,12 @@ def login():
return rx.center(
rx.vstack(
rx.input(on_blur=State.set_username, placeholder="Username", width="100%"),
rx.input(type_="password", on_blur=State.set_password, placeholder="Password", width="100%"),
rx.input(
type_="password",
on_blur=State.set_password,
placeholder="Password",
width="100%",
),
rx.button("Login", on_click=State.login, width="100%"),
rx.link(rx.button("Sign Up", width="100%"), href="/signup", width="100%"),
),
Expand All @@ -201,7 +212,10 @@ def signup():
on_blur=State.set_username, placeholder="Username", width="100%"
),
rx.input(
type_="password", on_blur=State.set_password, placeholder="Password", width="100%"
type_="password",
on_blur=State.set_password,
placeholder="Password",
width="100%",
),
rx.input(
type_="password",
Expand Down
11 changes: 7 additions & 4 deletions local_auth/local_auth/base_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ def authenticated_user(self) -> User:
corresponding to the currently authenticated user.
"""
with rx.session() as session:
result = session.query(User, AuthSession).where(
result = session.exec(
select(User, AuthSession).where(
AuthSession.session_id == self.auth_token,
AuthSession.expiration
>= datetime.datetime.now(datetime.timezone.utc),
User.id == AuthSession.user_id,
).first()

),
).first()
if result:
user, session = result
return user
Expand All @@ -55,7 +56,9 @@ def is_authenticated(self) -> bool:
def do_logout(self) -> None:
"""Destroy AuthSessions associated with the auth_token."""
with rx.session() as session:
for auth_session in session.query(AuthSession).filter_by(session_id=self.auth_token).all():
for auth_session in session.exec(
select(AuthSession).where(AuthSession.session_id == self.auth_token)
).all():
session.delete(auth_session)
session.commit()
self.auth_token = self.auth_token
Expand Down
5 changes: 4 additions & 1 deletion local_auth/local_auth/login.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Login page and authentication logic."""
import reflex as rx
from sqlmodel import select

from .base_state import State
from .user import User
Expand All @@ -25,7 +26,9 @@ def on_submit(self, form_data) -> rx.event.EventSpec:
username = form_data["username"]
password = form_data["password"]
with rx.session() as session:
user = session.query(User).filter_by(username=username).first()
user = session.exec(
select(User).where(User.username == username)
).one_or_none()
if user is not None and not user.enabled:
self.error_message = "This account is disabled."
return rx.set_value("password", "")
Expand Down
5 changes: 4 additions & 1 deletion local_auth/local_auth/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import AsyncGenerator

import reflex as rx
from sqlmodel import select

from .base_state import State
from .login import LOGIN_ROUTE, REGISTER_ROUTE
Expand Down Expand Up @@ -33,7 +34,9 @@ async def handle_registration(
self.error_message = "Username cannot be empty"
yield rx.set_focus("username")
return
existing_user = session.query(User).filter_by(username=username).first()
existing_user = session.exec(
select(User).where(User.username == username)
).one_or_none()
if existing_user is not None:
self.error_message = (
f"Username {username} is already registered. Try a different name"
Expand Down
14 changes: 10 additions & 4 deletions sales/sales/sales.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from openai import OpenAI

import reflex as rx
from sqlmodel import select

from .models import Customer

Expand Down Expand Up @@ -64,7 +65,9 @@ class State(rx.State):
def add_customer(self):
"""Add a customer to the database."""
with rx.session() as session:
if session.query(Customer).filter_by(email=self.email).first():
if session.exec(
select(Customer).where(Customer.email == self.email)
).first():
return rx.window_alert("User already exists")
session.add(
Customer(
Expand All @@ -91,7 +94,10 @@ def onboarding_page(self):
def delete_customer(self, email: str):
"""Delete a customer from the database."""
with rx.session() as session:
session.query(Customer).filter_by(email=email).delete()
customer = session.exec(
select(Customer).where(Customer.email == email)
).first()
session.delete(customer)
session.commit()

generate_email_data: dict = {}
Expand All @@ -117,7 +123,7 @@ async def call_openai(self):
# save the data related to email_content
self.email_content_data = response.choices[0].text
# update layout of email_content manually
return rx.set_value("email_content", self.email_content_data)
return rx.set_value("email_content", self.email_content_data)

def generate_email(
self,
Expand All @@ -144,7 +150,7 @@ def generate_email(
def get_users(self) -> list[Customer]:
"""Get all users from the database."""
with rx.session() as session:
self.users = session.query(Customer).all()
self.users = session.exec(select(Customer)).all()
return self.users

def open_text_area(self):
Expand Down
7 changes: 5 additions & 2 deletions twitter/twitter/state/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""The authentication state."""
import reflex as rx
from sqlmodel import select

from .base import State, User

Expand All @@ -16,7 +17,7 @@ def signup(self):
with rx.session() as session:
if self.password != self.confirm_password:
return rx.window_alert("Passwords do not match.")
if session.query(User).filter_by(username=self.username).first():
if session.exec(select(User).where(User.username == self.username)).first():
return rx.window_alert("Username already exists.")
self.user = User(username=self.username, password=self.password)
session.add(self.user)
Expand All @@ -27,7 +28,9 @@ def signup(self):
def login(self):
"""Log in a user."""
with rx.session() as session:
user = session.query(User).filter_by(username=self.username).first()
user = session.exec(
select(User).where(User.username == self.username)
).first()
if user and user.password == self.password:
self.user = user
return rx.redirect("/")
Expand Down
39 changes: 18 additions & 21 deletions twitter/twitter/state/home.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import datetime

import reflex as rx
from sqlmodel import select

from .base import Follows, State, Tweet, User

Expand Down Expand Up @@ -33,13 +34,11 @@ def get_tweets(self):
"""Get tweets from the database."""
with rx.session() as session:
if self.search:
self.tweets = (
session.query(Tweet)
.filter(Tweet.content.contains(self.search))
.all()[::-1]
)
self.tweets = session.exec(
select(Tweet).where(Tweet.content.contains(self.search))
).all()[::-1]
else:
self.tweets = session.query(Tweet).all()[::-1]
self.tweets = session.exec(select(Tweet)).all()[::-1]

def set_search(self, search):
"""Set the search query."""
Expand All @@ -60,23 +59,23 @@ def following(self) -> list[Follows]:
"""Get a list of users the current user is following."""
if self.logged_in:
with rx.session() as session:
return (
session.query(Follows)
.filter(Follows.follower_username == self.user.username)
.all()
)
return session.exec(
select(Follows).where(
Follows.follower_username == self.user.username
)
).all()
return []

@rx.var
def followers(self) -> list[Follows]:
"""Get a list of users following the current user."""
if self.logged_in:
with rx.session() as session:
return (
session.query(Follows)
.where(Follows.followed_username == self.user.username)
.all()
)
return session.exec(
select(Follows).where(
Follows.followed_username == self.user.username
)
).all()
return []

@rx.var
Expand All @@ -85,13 +84,11 @@ def search_users(self) -> list[User]:
if self.friend != "":
with rx.session() as session:
current_username = self.user.username if self.user is not None else ""
users = (
session.query(User)
.filter(
users = session.exec(
select(User).where(
User.username.contains(self.friend),
User.username != current_username,
)
.all()
)
).all()
return users
return []

0 comments on commit cb3e32c

Please sign in to comment.