Do reset password operations in db transaction

And modify signature of model methods to accept an executor instead of a
pool connection which will also allow transactions.
This commit is contained in:
Tyler Hallada 2023-10-13 14:44:40 +02:00
parent 60671d5865
commit 835e9dc748
7 changed files with 126 additions and 82 deletions

View File

@ -71,13 +71,9 @@ pub async fn post(
layout: Layout,
Form(forgot_password): Form<ForgotPassword>,
) -> Result<Response> {
dbg!(&ip);
dbg!(&user_agent);
dbg!(&forgot_password.email);
let user: User = match User::get_by_email(&pool, forgot_password.email.clone()).await {
Ok(user) => user,
Err(err) => {
dbg!(&err);
if let Error::NotFoundString(_, _) = err {
info!(email = forgot_password.email, "invalid email");
return Ok(layout

View File

@ -241,11 +241,11 @@ pub async fn post(
}
};
info!(user_id = %user.user_id, "user exists with verified email, resetting password");
// TODO: do both in transaction
UserPasswordResetToken::delete(&pool, reset_password.token).await?;
let mut tx = pool.begin().await?;
UserPasswordResetToken::delete(tx.as_mut(), reset_password.token).await?;
let user = match user
.update_password(
&pool,
tx.as_mut(),
UpdateUserPassword {
password: reset_password.password,
},
@ -289,6 +289,7 @@ pub async fn post(
ip.into(),
user_agent.map(|ua| ua.to_string()),
);
tx.commit().await?;
Ok(layout
.with_subtitle("reset password")
.targeted(hx_target)

View File

@ -1,6 +1,6 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use sqlx::{Executor, Postgres};
use uuid::Uuid;
use validator::{Validate, ValidationErrors};
@ -44,9 +44,9 @@ pub struct GetEntriesOptions {
}
impl Entry {
pub async fn get(pool: &PgPool, entry_id: Uuid) -> Result<Entry> {
pub async fn get(db: impl Executor<'_, Database = Postgres>, entry_id: Uuid) -> Result<Entry> {
sqlx::query_as!(Entry, "select * from entry where entry_id = $1", entry_id)
.fetch_one(pool)
.fetch_one(db)
.await
.map_err(|error| {
if let sqlx::error::Error::RowNotFound = error {
@ -56,7 +56,10 @@ impl Entry {
})
}
pub async fn get_all(pool: &PgPool, options: &GetEntriesOptions) -> sqlx::Result<Vec<Entry>> {
pub async fn get_all(
db: impl Executor<'_, Database = Postgres>,
options: &GetEntriesOptions,
) -> sqlx::Result<Vec<Entry>> {
if let Some(feed_id) = options.feed_id {
if let Some(published_before) = options.published_before {
if let Some(id_before) = options.id_before {
@ -74,7 +77,7 @@ impl Entry {
id_before,
options.limit.unwrap_or(DEFAULT_ENTRIES_PAGE_SIZE)
)
.fetch_all(pool)
.fetch_all(db)
.await
} else {
sqlx::query_as!(
@ -90,7 +93,7 @@ impl Entry {
published_before,
options.limit.unwrap_or(DEFAULT_ENTRIES_PAGE_SIZE)
)
.fetch_all(pool)
.fetch_all(db)
.await
}
} else {
@ -105,7 +108,7 @@ impl Entry {
feed_id,
options.limit.unwrap_or(DEFAULT_ENTRIES_PAGE_SIZE)
)
.fetch_all(pool)
.fetch_all(db)
.await
}
} else if let Some(published_before) = options.published_before {
@ -122,7 +125,7 @@ impl Entry {
id_before,
options.limit.unwrap_or(DEFAULT_ENTRIES_PAGE_SIZE)
)
.fetch_all(pool)
.fetch_all(db)
.await
} else {
sqlx::query_as!(
@ -136,7 +139,7 @@ impl Entry {
published_before,
options.limit.unwrap_or(DEFAULT_ENTRIES_PAGE_SIZE)
)
.fetch_all(pool)
.fetch_all(db)
.await
}
} else {
@ -149,12 +152,15 @@ impl Entry {
",
options.limit.unwrap_or(DEFAULT_ENTRIES_PAGE_SIZE)
)
.fetch_all(pool)
.fetch_all(db)
.await
}
}
pub async fn create(pool: &PgPool, payload: CreateEntry) -> Result<Entry> {
pub async fn create(
db: impl Executor<'_, Database = Postgres>,
payload: CreateEntry,
) -> Result<Entry> {
payload.validate()?;
sqlx::query_as!(
Entry,
@ -169,7 +175,7 @@ impl Entry {
payload.feed_id,
payload.published_at,
)
.fetch_one(pool)
.fetch_one(db)
.await
.map_err(|error| {
if let sqlx::error::Error::Database(ref psql_error) = error {
@ -181,7 +187,10 @@ impl Entry {
})
}
pub async fn upsert(pool: &PgPool, payload: CreateEntry) -> Result<Entry> {
pub async fn upsert(
db: impl Executor<'_, Database = Postgres>,
payload: CreateEntry,
) -> Result<Entry> {
payload.validate()?;
sqlx::query_as!(
Entry,
@ -200,7 +209,7 @@ impl Entry {
payload.feed_id,
payload.published_at,
)
.fetch_one(pool)
.fetch_one(db)
.await
.map_err(|error| {
if let sqlx::error::Error::Database(ref psql_error) = error {
@ -212,7 +221,10 @@ impl Entry {
})
}
pub async fn bulk_create(pool: &PgPool, payload: Vec<CreateEntry>) -> Result<Vec<Entry>> {
pub async fn bulk_create(
db: impl Executor<'_, Database = Postgres>,
payload: Vec<CreateEntry>,
) -> Result<Vec<Entry>> {
let mut titles = Vec::with_capacity(payload.len());
let mut urls = Vec::with_capacity(payload.len());
let mut descriptions: Vec<Option<String>> = Vec::with_capacity(payload.len());
@ -241,7 +253,7 @@ impl Entry {
feed_ids.as_slice(),
published_ats.as_slice(),
)
.fetch_all(pool)
.fetch_all(db)
.await
.map_err(|error| {
if let sqlx::error::Error::Database(ref psql_error) = error {
@ -253,7 +265,10 @@ impl Entry {
})
}
pub async fn bulk_upsert(pool: &PgPool, payload: Vec<CreateEntry>) -> Result<Vec<Entry>> {
pub async fn bulk_upsert(
db: impl Executor<'_, Database = Postgres>,
payload: Vec<CreateEntry>,
) -> Result<Vec<Entry>> {
let mut titles = Vec::with_capacity(payload.len());
let mut urls = Vec::with_capacity(payload.len());
let mut descriptions: Vec<Option<String>> = Vec::with_capacity(payload.len());
@ -286,7 +301,7 @@ impl Entry {
feed_ids.as_slice(),
published_ats.as_slice(),
)
.fetch_all(pool)
.fetch_all(db)
.await
.map_err(|error| {
if let sqlx::error::Error::Database(ref psql_error) = error {
@ -298,7 +313,10 @@ impl Entry {
})
}
pub async fn update(pool: &PgPool, payload: Entry) -> Result<Entry> {
pub async fn update(
db: impl Executor<'_, Database = Postgres>,
payload: Entry,
) -> Result<Entry> {
sqlx::query_as!(
Entry,
"update entry set
@ -321,7 +339,7 @@ impl Entry {
payload.last_modified_header,
payload.published_at,
)
.fetch_one(pool)
.fetch_one(db)
.await
.map_err(|error| {
if let sqlx::error::Error::Database(ref psql_error) = error {
@ -333,12 +351,12 @@ impl Entry {
})
}
pub async fn delete(pool: &PgPool, entry_id: Uuid) -> Result<()> {
pub async fn delete(db: impl Executor<'_, Database = Postgres>, entry_id: Uuid) -> Result<()> {
sqlx::query!(
"update entry set deleted_at = now() where entry_id = $1",
entry_id
)
.execute(pool)
.execute(db)
.await?;
Ok(())
}

View File

@ -2,7 +2,7 @@ use std::str::FromStr;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{FromRow, PgPool, postgres::PgQueryResult};
use sqlx::{postgres::PgQueryResult, Executor, FromRow, Postgres};
use uuid::Uuid;
use validator::Validate;
@ -127,7 +127,7 @@ pub struct GetFeedsOptions {
}
impl Feed {
pub async fn get(pool: &PgPool, feed_id: Uuid) -> Result<Feed> {
pub async fn get(db: impl Executor<'_, Database = Postgres>, feed_id: Uuid) -> Result<Feed> {
sqlx::query_as!(
Feed,
// Unable to SELECT * here due to https://github.com/launchbadge/sqlx/issues/1004
@ -150,7 +150,7 @@ impl Feed {
from feed where feed_id = $1"#,
feed_id
)
.fetch_one(pool)
.fetch_one(db)
.await
.map_err(|error| {
if let sqlx::error::Error::RowNotFound = error {
@ -160,7 +160,10 @@ impl Feed {
})
}
pub async fn get_all(pool: &PgPool, options: &GetFeedsOptions) -> sqlx::Result<Vec<Feed>> {
pub async fn get_all(
db: impl Executor<'_, Database = Postgres>,
options: &GetFeedsOptions,
) -> sqlx::Result<Vec<Feed>> {
// TODO: make sure there are indices for all of these sort options
match options.sort.as_ref().unwrap_or(&GetFeedsSort::CreatedAt) {
GetFeedsSort::Title => {
@ -192,7 +195,7 @@ impl Feed {
options.before_id.unwrap_or(Uuid::nil()),
options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
)
.fetch_all(pool)
.fetch_all(db)
.await
} else {
sqlx::query_as!(
@ -219,9 +222,8 @@ impl Feed {
"#,
options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
)
.fetch_all(pool)
.fetch_all(db)
.await
}
}
GetFeedsSort::CreatedAt => {
@ -253,7 +255,7 @@ impl Feed {
options.before_id.unwrap_or(Uuid::nil()),
options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
)
.fetch_all(pool)
.fetch_all(db)
.await
} else {
sqlx::query_as!(
@ -280,9 +282,8 @@ impl Feed {
"#,
options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
)
.fetch_all(pool)
.fetch_all(db)
.await
}
}
GetFeedsSort::LastCrawledAt => {
@ -314,7 +315,7 @@ impl Feed {
options.before_id.unwrap_or(Uuid::nil()),
options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
)
.fetch_all(pool)
.fetch_all(db)
.await
} else {
sqlx::query_as!(
@ -341,9 +342,8 @@ impl Feed {
"#,
options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
)
.fetch_all(pool)
.fetch_all(db)
.await
}
}
GetFeedsSort::LastEntryPublishedAt => {
@ -375,7 +375,7 @@ impl Feed {
options.before_id.unwrap_or(Uuid::nil()),
options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
)
.fetch_all(pool)
.fetch_all(db)
.await
} else {
sqlx::query_as!(
@ -402,15 +402,17 @@ impl Feed {
"#,
options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
)
.fetch_all(pool)
.fetch_all(db)
.await
}
}
}
}
pub async fn create(pool: &PgPool, payload: CreateFeed) -> Result<Feed> {
pub async fn create(
db: impl Executor<'_, Database = Postgres>,
payload: CreateFeed,
) -> Result<Feed> {
payload.validate()?;
Ok(sqlx::query_as!(
Feed,
@ -438,11 +440,14 @@ impl Feed {
payload.url,
payload.description
)
.fetch_one(pool)
.fetch_one(db)
.await?)
}
pub async fn upsert(pool: &PgPool, payload: UpsertFeed) -> Result<Feed> {
pub async fn upsert(
db: impl Executor<'_, Database = Postgres>,
payload: UpsertFeed,
) -> Result<Feed> {
payload.validate()?;
Ok(sqlx::query_as!(
Feed,
@ -476,11 +481,15 @@ impl Feed {
payload.feed_type as Option<FeedType>,
payload.description
)
.fetch_one(pool)
.fetch_one(db)
.await?)
}
pub async fn update(pool: &PgPool, feed_id: Uuid, payload: UpdateFeed) -> Result<Feed> {
pub async fn update(
db: impl Executor<'_, Database = Postgres>,
feed_id: Uuid,
payload: UpdateFeed,
) -> Result<Feed> {
payload.validate()?;
let mut query = sqlx::QueryBuilder::new("UPDATE feed SET ");
@ -520,20 +529,24 @@ impl Feed {
let query = query.build_query_as();
Ok(query.fetch_one(pool).await?)
Ok(query.fetch_one(db).await?)
}
pub async fn delete(pool: &PgPool, feed_id: Uuid) -> Result<()> {
pub async fn delete(db: impl Executor<'_, Database = Postgres>, feed_id: Uuid) -> Result<()> {
sqlx::query!(
"update feed set deleted_at = now() where feed_id = $1",
feed_id
)
.execute(pool)
.execute(db)
.await?;
Ok(())
}
pub async fn update_crawl_error(pool: &PgPool, feed_id: Uuid, last_crawl_error: String) -> Result<PgQueryResult> {
pub async fn update_crawl_error(
db: impl Executor<'_, Database = Postgres>,
feed_id: Uuid,
last_crawl_error: String,
) -> Result<PgQueryResult> {
Ok(sqlx::query!(
r#"update feed set
last_crawl_error = $2
@ -541,11 +554,11 @@ impl Feed {
feed_id,
last_crawl_error,
)
.execute(pool)
.execute(db)
.await?)
}
pub async fn save(&self, pool: &PgPool) -> Result<Feed> {
pub async fn save(&self, db: impl Executor<'_, Database = Postgres>) -> Result<Feed> {
Ok(sqlx::query_as!(
Feed,
r#"update feed set
@ -588,7 +601,7 @@ impl Feed {
self.last_crawled_at,
self.last_entry_published_at,
)
.fetch_one(pool)
.fetch_one(db)
.await?)
}
}

View File

@ -1,7 +1,7 @@
use axum_login::{secrecy::SecretVec, AuthUser, PostgresStore};
use chrono::{DateTime, Utc};
use serde::Deserialize;
use sqlx::{FromRow, PgPool};
use sqlx::{Executor, FromRow, Postgres};
use uuid::Uuid;
use validator::Validate;
@ -55,7 +55,7 @@ impl AuthUser<Uuid> for User {
}
impl User {
pub async fn get(pool: &PgPool, user_id: Uuid) -> Result<User> {
pub async fn get(db: impl Executor<'_, Database = Postgres>, user_id: Uuid) -> Result<User> {
sqlx::query_as!(
User,
r#"select
@ -64,7 +64,7 @@ impl User {
where user_id = $1"#,
user_id
)
.fetch_one(pool)
.fetch_one(db)
.await
.map_err(|error| {
if let sqlx::error::Error::RowNotFound = error {
@ -74,7 +74,10 @@ impl User {
})
}
pub async fn get_by_email(pool: &PgPool, email: String) -> Result<User> {
pub async fn get_by_email(
db: impl Executor<'_, Database = Postgres>,
email: String,
) -> Result<User> {
sqlx::query_as!(
User,
r#"select
@ -83,7 +86,7 @@ impl User {
where email = $1"#,
email
)
.fetch_one(pool)
.fetch_one(db)
.await
.map_err(|error| {
if let sqlx::error::Error::RowNotFound = error {
@ -93,7 +96,10 @@ impl User {
})
}
pub async fn create(pool: &PgPool, payload: CreateUser) -> Result<User> {
pub async fn create(
db: impl Executor<'_, Database = Postgres>,
payload: CreateUser,
) -> Result<User> {
payload.validate()?;
let password_hash = hash_password(payload.password).await?;
@ -117,11 +123,14 @@ impl User {
password_hash,
payload.name
)
.fetch_one(pool)
.fetch_one(db)
.await?)
}
pub async fn verify_email(pool: &PgPool, user_id: Uuid) -> Result<User> {
pub async fn verify_email(
db: impl Executor<'_, Database = Postgres>,
user_id: Uuid,
) -> Result<User> {
sqlx::query_as!(
User,
r#"update users set
@ -131,7 +140,7 @@ impl User {
"#,
user_id
)
.fetch_one(pool)
.fetch_one(db)
.await
.map_err(|error| {
if let sqlx::error::Error::RowNotFound = error {
@ -141,7 +150,11 @@ impl User {
})
}
pub async fn update_password(&self, pool: &PgPool, payload: UpdateUserPassword) -> Result<User> {
pub async fn update_password(
&self,
db: impl Executor<'_, Database = Postgres>,
payload: UpdateUserPassword,
) -> Result<User> {
payload.validate()?;
let password_hash = hash_password(payload.password).await?;
@ -164,7 +177,7 @@ impl User {
self.user_id,
password_hash,
)
.fetch_one(pool)
.fetch_one(db)
.await?)
}
}

View File

@ -1,6 +1,6 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use sqlx::{Executor, Postgres};
use uuid::Uuid;
use crate::error::{Error, Result};
@ -25,7 +25,10 @@ impl UserEmailVerificationToken {
Utc::now() > self.expires_at
}
pub async fn get(pool: &PgPool, token_id: Uuid) -> Result<UserEmailVerificationToken> {
pub async fn get(
db: impl Executor<'_, Database = Postgres>,
token_id: Uuid,
) -> Result<UserEmailVerificationToken> {
sqlx::query_as!(
UserEmailVerificationToken,
r#"select
@ -34,7 +37,7 @@ impl UserEmailVerificationToken {
where token_id = $1"#,
token_id
)
.fetch_one(pool)
.fetch_one(db)
.await
.map_err(|error| {
if let sqlx::error::Error::RowNotFound = error {
@ -45,7 +48,7 @@ impl UserEmailVerificationToken {
}
pub async fn create(
pool: &PgPool,
db: impl Executor<'_, Database = Postgres>,
payload: CreateUserEmailVerificationToken,
) -> Result<UserEmailVerificationToken> {
Ok(sqlx::query_as!(
@ -58,20 +61,17 @@ impl UserEmailVerificationToken {
payload.user_id,
payload.expires_at
)
.fetch_one(pool)
.fetch_one(db)
.await?)
}
pub async fn delete(
pool: &PgPool,
token_id: Uuid,
) -> Result<()> {
pub async fn delete(db: impl Executor<'_, Database = Postgres>, token_id: Uuid) -> Result<()> {
sqlx::query!(
r#"delete from user_email_verification_token
where token_id = $1"#,
token_id
)
.execute(pool)
.execute(db)
.await?;
Ok(())
}

View File

@ -1,7 +1,7 @@
use chrono::{DateTime, Utc};
use ipnetwork::IpNetwork;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use sqlx::{Executor, Postgres};
use uuid::Uuid;
use crate::error::{Error, Result};
@ -31,7 +31,10 @@ impl UserPasswordResetToken {
Utc::now() > self.expires_at
}
pub async fn get(pool: &PgPool, token_id: Uuid) -> Result<UserPasswordResetToken> {
pub async fn get(
pool: impl Executor<'_, Database = Postgres>,
token_id: Uuid,
) -> Result<UserPasswordResetToken> {
sqlx::query_as!(
UserPasswordResetToken,
r#"select
@ -51,7 +54,7 @@ impl UserPasswordResetToken {
}
pub async fn create(
pool: &PgPool,
pool: impl Executor<'_, Database = Postgres>,
payload: CreatePasswordResetToken,
) -> Result<UserPasswordResetToken> {
Ok(sqlx::query_as!(
@ -72,7 +75,7 @@ impl UserPasswordResetToken {
}
pub async fn delete(
pool: &PgPool,
pool: impl Executor<'_, Database = Postgres>,
token_id: Uuid,
) -> Result<()> {
sqlx::query!(