Skip to content

Commit

Permalink
feat: implement very basic ORM (#9)
Browse files Browse the repository at this point in the history
At this stage, there is an easy way to query, insert, and delete with
super simple filters. This is enough of a scaffolding to be able to
quickly make it much more advanced.
  • Loading branch information
m4tx authored Aug 19, 2024
1 parent 7867952 commit 88067d7
Show file tree
Hide file tree
Showing 22 changed files with 834 additions and 43 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ Cargo.lock

# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb

# Test databases
*.db
*.sqlite3
9 changes: 6 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ members = [
"flareon-admin",
"flareon-auth",
"flareon-macros",
"flareon-orm",
# Examples
"examples/hello-world",
"examples/todo-list",
Expand All @@ -22,7 +21,9 @@ axum = "0.7.5"
bytes = "1.6.1"
chrono = { version = "0.4.38", features = ["serde"] }
clap = { version = "4.5.8", features = ["derive", "env"] }
convert_case = "0.6.0"
derive_builder = "0.20.0"
derive_more = { version = "1.0.0", features = ["full"] }
env_logger = "0.11.3"
flareon = { path = "flareon" }
flareon_macros = { path = "flareon-macros" }
Expand All @@ -31,8 +32,10 @@ indexmap = "2.2.6"
itertools = "0.13.0"
log = "0.4.22"
regex = "1.10.5"
sea-query = "0.32.0-rc.1"
sea-query-binder = { version = "0.7.0-rc.1", features = ["sqlx-any", "runtime-tokio"] }
serde = "1.0.203"
slug = "0.1.5"
tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] }
tower = "0.4.13"
sqlx = { version = "0.8.0", features = ["runtime-tokio", "sqlite"] }
thiserror = "1.0.61"
tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] }
45 changes: 36 additions & 9 deletions examples/todo-list/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
use std::sync::Arc;

use askama::Template;
use flareon::db::query::ExprEq;
use flareon::db::{model, Database, Model};
use flareon::forms::Form;
use flareon::prelude::{Body, Error, FlareonApp, FlareonProject, Response, Route, StatusCode};
use flareon::request::Request;
use flareon::reverse;
use tokio::sync::RwLock;
use tokio::sync::OnceCell;

#[derive(Debug, Clone)]
#[model]
struct TodoItem {
id: i32,
title: String,
}

Expand All @@ -19,10 +23,12 @@ struct IndexTemplate<'a> {
todo_items: Vec<TodoItem>,
}

static TODOS: RwLock<Vec<TodoItem>> = RwLock::const_new(Vec::new());
static DB: OnceCell<Database> = OnceCell::const_new();

async fn index(request: Request) -> Result<Response, Error> {
let todo_items = (*TODOS.read().await).clone();
let db = DB.get().unwrap();

let todo_items = TodoItem::objects().all(db).await.unwrap();
let index_template = IndexTemplate {
request: &request,
todo_items,
Expand All @@ -45,22 +51,30 @@ async fn add_todo(mut request: Request) -> Result<Response, Error> {
let todo_form = TodoForm::from_request(&mut request).await.unwrap();

{
let mut todos = TODOS.write().await;
todos.push(TodoItem {
let db = DB.get().unwrap();
TodoItem {
id: 0,
title: todo_form.title,
});
}
.save(db)
.await
.unwrap();
}

Ok(reverse!(request, "index"))
}

async fn remove_todo(request: Request) -> Result<Response, Error> {
let todo_id = request.path_param("todo_id").expect("todo_id not found");
let todo_id = todo_id.parse::<usize>().expect("todo_id is not a number");
let todo_id = todo_id.parse::<i32>().expect("todo_id is not a number");

{
let mut todos = TODOS.write().await;
todos.remove(todo_id);
let db = DB.get().unwrap();
TodoItem::objects()
.filter(<TodoItem as Model>::Fields::ID.eq(todo_id))
.delete(db)
.await
.unwrap();
}

Ok(reverse!(request, "index"))
Expand All @@ -70,6 +84,19 @@ async fn remove_todo(request: Request) -> Result<Response, Error> {
async fn main() {
env_logger::init();

let db = DB
.get_or_init(|| async { Database::new("sqlite::memory:").await.unwrap() })
.await;
db.execute(
r"
CREATE TABLE todo_item (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL
);",
)
.await
.unwrap();

let todo_app = FlareonApp::builder()
.urls([
Route::with_handler_and_name("/", Arc::new(Box::new(index)), "index"),
Expand Down
2 changes: 1 addition & 1 deletion examples/todo-list/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ <h1>TODO List</h1>
<ul id="todo-list">
{% for todo in todo_items %}
<li>
{% let todo_id = loop.index0 %}
{% let todo_id = todo.id %}
<form action="{{ flareon::reverse_str!(request, "remove-todo", "todo_id" => todo_id) }}" method="post">
<span>{{ todo.title }}</span>
<button type="submit">Remove</button>
Expand Down
1 change: 1 addition & 0 deletions flareon-admin/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[must_use]
pub fn add(left: u64, right: u64) -> u64 {
left + right
}
Expand Down
1 change: 1 addition & 0 deletions flareon-auth/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[must_use]
pub fn add(left: u64, right: u64) -> u64 {
left + right
}
Expand Down
1 change: 1 addition & 0 deletions flareon-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ name = "tests"
path = "tests/compile_tests.rs"

[dependencies]
convert_case.workspace = true
darling = "0.20.10"
proc-macro-crate = "3.1.0"
proc-macro2 = "1.0.86"
Expand Down
17 changes: 17 additions & 0 deletions flareon-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
mod form;
mod model;

use darling::ast::NestedMeta;
use darling::Error;
use proc_macro::TokenStream;
use proc_macro_crate::crate_name;
use quote::quote;
use syn::parse_macro_input;

use crate::form::impl_form_for_struct;
use crate::model::impl_model_for_struct;

/// Derive the [`Form`] trait for a struct.
///
Expand All @@ -22,6 +26,19 @@ pub fn derive_form(input: TokenStream) -> TokenStream {
token_stream.into()
}

#[proc_macro_attribute]
pub fn model(args: TokenStream, input: TokenStream) -> TokenStream {
let attr_args = match NestedMeta::parse_meta_list(args.into()) {
Ok(v) => v,
Err(e) => {
return TokenStream::from(Error::from(e).write_errors());
}
};
let ast = parse_macro_input!(input as syn::DeriveInput);
let token_stream = impl_model_for_struct(attr_args, ast);
token_stream.into()
}

pub(crate) fn flareon_ident() -> proc_macro2::TokenStream {
let flareon_crate = crate_name("flareon").expect("flareon is not present in `Cargo.toml`");
match flareon_crate {
Expand Down
178 changes: 178 additions & 0 deletions flareon-macros/src/model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
use convert_case::{Case, Casing};
use darling::ast::NestedMeta;
use darling::{FromDeriveInput, FromField};
use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote, ToTokens, TokenStreamExt};

use crate::flareon_ident;

pub fn impl_model_for_struct(_args: Vec<NestedMeta>, ast: syn::DeriveInput) -> TokenStream {
let opts = match ModelOpts::from_derive_input(&ast) {
Ok(val) => val,
Err(err) => {
return err.write_errors();
}
};

let mut builder = opts.as_model_builder();
for field in opts.fields() {
builder.push_field(field);
}

quote!(#ast #builder)
}

#[derive(Debug, FromDeriveInput)]
#[darling(forward_attrs(allow, doc, cfg), supports(struct_named))]
struct ModelOpts {
ident: syn::Ident,
data: darling::ast::Data<darling::util::Ignored, Field>,
}

impl ModelOpts {
fn fields(&self) -> Vec<&Field> {
self.data
.as_ref()
.take_struct()
.expect("Only structs are supported")
.fields
}

fn field_count(&self) -> usize {
self.fields().len()
}

fn as_model_builder(&self) -> ModelBuilder {
let table_name = self.ident.to_string().to_case(Case::Snake);

ModelBuilder {
name: self.ident.clone(),
table_name,
fields_struct_name: format_ident!("{}Fields", self.ident),
fields_as_columns: Vec::with_capacity(self.field_count()),
fields_as_from_db: Vec::with_capacity(self.field_count()),
fields_as_get_values: Vec::with_capacity(self.field_count()),
fields_as_field_refs: Vec::with_capacity(self.field_count()),
}
}
}

#[derive(Debug, Clone, FromField)]
#[darling(attributes(form))]
struct Field {
ident: Option<syn::Ident>,
ty: syn::Type,
}

#[derive(Debug)]
struct ModelBuilder {
name: Ident,
table_name: String,
fields_struct_name: Ident,
fields_as_columns: Vec<TokenStream>,
fields_as_from_db: Vec<TokenStream>,
fields_as_get_values: Vec<TokenStream>,
fields_as_field_refs: Vec<TokenStream>,
}

impl ToTokens for ModelBuilder {
fn to_tokens(&self, tokens: &mut TokenStream) {
tokens.append_all(self.build_model_impl());
tokens.append_all(self.build_fields_struct());
}
}

impl ModelBuilder {
fn push_field(&mut self, field: &Field) {
let orm_ident = orm_ident();

let name = field.ident.as_ref().unwrap();
let const_name = format_ident!("{}", name.to_string().to_case(Case::UpperSnake));
let ty = &field.ty;
let index = self.fields_as_columns.len();

let column_name = name.to_string().to_case(Case::Snake);
let is_auto = column_name == "id";

{
let mut field_as_column = quote!(#orm_ident::Column::new(
#orm_ident::Identifier::new(#column_name)
));
if is_auto {
field_as_column.append_all(quote!(.auto(true)));
}
self.fields_as_columns.push(field_as_column);
}

self.fields_as_from_db.push(quote!(
#name: db_row.get::<#ty>(#index)?
));

self.fields_as_get_values.push(quote!(
#index => &self.#name as &dyn #orm_ident::ValueRef
));

self.fields_as_field_refs.push(quote!(
pub const #const_name: #orm_ident::query::FieldRef<#ty> =
#orm_ident::query::FieldRef::<#ty>::new(#orm_ident::Identifier::new(#column_name));
));
}

fn build_model_impl(&self) -> TokenStream {
let orm_ident = orm_ident();

let name = &self.name;
let table_name = &self.table_name;
let fields_struct_name = &self.fields_struct_name;
let fields_as_columns = &self.fields_as_columns;
let fields_as_from_db = &self.fields_as_from_db;
let fields_as_get_values = &self.fields_as_get_values;

quote! {
#[automatically_derived]
impl #orm_ident::Model for #name {
type Fields = #fields_struct_name;

const COLUMNS: &'static [#orm_ident::Column] = &[
#(#fields_as_columns,)*
];
const TABLE_NAME: #orm_ident::Identifier = #orm_ident::Identifier::new(#table_name);

fn from_db(db_row: #orm_ident::Row) -> #orm_ident::Result<Self> {
Ok(Self {
#(#fields_as_from_db,)*
})
}

fn get_values(&self, columns: &[usize]) -> Vec<&dyn #orm_ident::ValueRef> {
columns
.iter()
.map(|&column| match column {
#(#fields_as_get_values,)*
_ => panic!("Unknown column index: {}", column),
})
.collect()
}
}
}
}

fn build_fields_struct(&self) -> TokenStream {
let fields_struct_name = &self.fields_struct_name;
let fields_as_field_refs = &self.fields_as_field_refs;

quote! {
#[derive(::core::fmt::Debug)]
pub struct #fields_struct_name;

impl #fields_struct_name {
#(#fields_as_field_refs)*
}
}
}
}

fn orm_ident() -> TokenStream {
let crate_ident = flareon_ident();
quote! { #crate_ident::db }
}
6 changes: 6 additions & 0 deletions flareon-macros/tests/compile_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,9 @@ fn test_derive_form() {
let t = trybuild::TestCases::new();
t.pass("tests/ui/derive_form.rs");
}

#[test]
fn test_attr_model() {
let t = trybuild::TestCases::new();
t.pass("tests/ui/attr_model.rs");
}
14 changes: 14 additions & 0 deletions flareon-macros/tests/ui/attr_model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use flareon::db::{model, Model};

#[derive(Debug)]
#[model]
struct MyModel {
id: i32,
name: std::string::String,
description: String,
visits: i32,
}

fn main() {
println!("{:?}", MyModel::TABLE_NAME);
}
Loading

0 comments on commit 88067d7

Please sign in to comment.