Do reset password operations in db transaction

Also modify the signatures of the model methods to accept a generic executor
instead of a pool reference, which allows them to be used inside transactions.
Tyler Hallada 2023-10-13 14:44:40 +02:00
parent 60671d5865
commit 835e9dc748
7 changed files with 126 additions and 82 deletions
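
The overall shape of the change, as a rough sketch (hypothetical table and function names, assuming sqlx 0.7 with the `postgres` and `uuid` features enabled): model methods take `impl Executor<'_, Database = Postgres>` so they work with either a `&PgPool` or a transaction handle, and the reset-password handler runs the token delete and the password update inside a single transaction.

    use sqlx::{Executor, PgPool, Postgres};
    use uuid::Uuid;

    // Hypothetical model method: generic over any Postgres executor, so callers
    // can pass either a &PgPool or a transaction via `tx.as_mut()`.
    async fn delete_reset_token(
        db: impl Executor<'_, Database = Postgres>,
        token_id: Uuid,
    ) -> sqlx::Result<()> {
        sqlx::query("delete from user_password_reset_token where token_id = $1")
            .bind(token_id)
            .execute(db)
            .await?;
        Ok(())
    }

    // Hypothetical model method for updating the stored password hash.
    async fn update_password(
        db: impl Executor<'_, Database = Postgres>,
        user_id: Uuid,
        password_hash: String,
    ) -> sqlx::Result<()> {
        sqlx::query("update users set password_hash = $2 where user_id = $1")
            .bind(user_id)
            .bind(password_hash)
            .execute(db)
            .await?;
        Ok(())
    }

    // Both operations run in one transaction: if the password update fails, the
    // token deletion is rolled back when `tx` is dropped without committing.
    async fn reset_password(
        pool: &PgPool,
        token_id: Uuid,
        user_id: Uuid,
        password_hash: String,
    ) -> sqlx::Result<()> {
        let mut tx = pool.begin().await?;
        delete_reset_token(tx.as_mut(), token_id).await?;
        update_password(tx.as_mut(), user_id, password_hash).await?;
        tx.commit().await?;
        Ok(())
    }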

View File

@@ -71,13 +71,9 @@ pub async fn post(
     layout: Layout,
     Form(forgot_password): Form<ForgotPassword>,
 ) -> Result<Response> {
-    dbg!(&ip);
-    dbg!(&user_agent);
-    dbg!(&forgot_password.email);
     let user: User = match User::get_by_email(&pool, forgot_password.email.clone()).await {
         Ok(user) => user,
         Err(err) => {
-            dbg!(&err);
             if let Error::NotFoundString(_, _) = err {
                 info!(email = forgot_password.email, "invalid email");
                 return Ok(layout

View File

@@ -241,11 +241,11 @@ pub async fn post(
         }
     };
     info!(user_id = %user.user_id, "user exists with verified email, resetting password");
-    // TODO: do both in transaction
-    UserPasswordResetToken::delete(&pool, reset_password.token).await?;
+    let mut tx = pool.begin().await?;
+    UserPasswordResetToken::delete(tx.as_mut(), reset_password.token).await?;
     let user = match user
         .update_password(
-            &pool,
+            tx.as_mut(),
             UpdateUserPassword {
                 password: reset_password.password,
             },
@@ -289,6 +289,7 @@ pub async fn post(
         ip.into(),
         user_agent.map(|ua| ua.to_string()),
     );
+    tx.commit().await?;
     Ok(layout
         .with_subtitle("reset password")
         .targeted(hx_target)

View File

@@ -1,6 +1,6 @@
 use chrono::{DateTime, Utc};
 use serde::{Deserialize, Serialize};
-use sqlx::PgPool;
+use sqlx::{Executor, Postgres};
 use uuid::Uuid;
 use validator::{Validate, ValidationErrors};
@@ -44,9 +44,9 @@ pub struct GetEntriesOptions {
 }

 impl Entry {
-    pub async fn get(pool: &PgPool, entry_id: Uuid) -> Result<Entry> {
+    pub async fn get(db: impl Executor<'_, Database = Postgres>, entry_id: Uuid) -> Result<Entry> {
         sqlx::query_as!(Entry, "select * from entry where entry_id = $1", entry_id)
-            .fetch_one(pool)
+            .fetch_one(db)
             .await
             .map_err(|error| {
                 if let sqlx::error::Error::RowNotFound = error {
@@ -56,7 +56,10 @@ impl Entry {
             })
     }

-    pub async fn get_all(pool: &PgPool, options: &GetEntriesOptions) -> sqlx::Result<Vec<Entry>> {
+    pub async fn get_all(
+        db: impl Executor<'_, Database = Postgres>,
+        options: &GetEntriesOptions,
+    ) -> sqlx::Result<Vec<Entry>> {
         if let Some(feed_id) = options.feed_id {
             if let Some(published_before) = options.published_before {
                 if let Some(id_before) = options.id_before {
@@ -74,7 +77,7 @@ impl Entry {
                         id_before,
                         options.limit.unwrap_or(DEFAULT_ENTRIES_PAGE_SIZE)
                     )
-                    .fetch_all(pool)
+                    .fetch_all(db)
                     .await
                 } else {
                     sqlx::query_as!(
@@ -90,7 +93,7 @@ impl Entry {
                         published_before,
                         options.limit.unwrap_or(DEFAULT_ENTRIES_PAGE_SIZE)
                     )
-                    .fetch_all(pool)
+                    .fetch_all(db)
                     .await
                 }
             } else {
@@ -105,7 +108,7 @@ impl Entry {
                     feed_id,
                     options.limit.unwrap_or(DEFAULT_ENTRIES_PAGE_SIZE)
                 )
-                .fetch_all(pool)
+                .fetch_all(db)
                 .await
             }
         } else if let Some(published_before) = options.published_before {
@@ -122,7 +125,7 @@ impl Entry {
                     id_before,
                     options.limit.unwrap_or(DEFAULT_ENTRIES_PAGE_SIZE)
                 )
-                .fetch_all(pool)
+                .fetch_all(db)
                 .await
             } else {
                 sqlx::query_as!(
@@ -136,7 +139,7 @@ impl Entry {
                    published_before,
                    options.limit.unwrap_or(DEFAULT_ENTRIES_PAGE_SIZE)
                )
-                .fetch_all(pool)
+                .fetch_all(db)
                .await
             }
         } else {
@@ -149,12 +152,15 @@ impl Entry {
                 ",
                 options.limit.unwrap_or(DEFAULT_ENTRIES_PAGE_SIZE)
             )
-            .fetch_all(pool)
+            .fetch_all(db)
             .await
         }
     }

-    pub async fn create(pool: &PgPool, payload: CreateEntry) -> Result<Entry> {
+    pub async fn create(
+        db: impl Executor<'_, Database = Postgres>,
+        payload: CreateEntry,
+    ) -> Result<Entry> {
         payload.validate()?;
         sqlx::query_as!(
             Entry,
@@ -169,7 +175,7 @@ impl Entry {
             payload.feed_id,
             payload.published_at,
         )
-        .fetch_one(pool)
+        .fetch_one(db)
         .await
         .map_err(|error| {
             if let sqlx::error::Error::Database(ref psql_error) = error {
@@ -181,7 +187,10 @@ impl Entry {
         })
     }

-    pub async fn upsert(pool: &PgPool, payload: CreateEntry) -> Result<Entry> {
+    pub async fn upsert(
+        db: impl Executor<'_, Database = Postgres>,
+        payload: CreateEntry,
+    ) -> Result<Entry> {
         payload.validate()?;
         sqlx::query_as!(
             Entry,
@@ -200,7 +209,7 @@ impl Entry {
             payload.feed_id,
             payload.published_at,
         )
-        .fetch_one(pool)
+        .fetch_one(db)
         .await
         .map_err(|error| {
             if let sqlx::error::Error::Database(ref psql_error) = error {
@@ -212,7 +221,10 @@ impl Entry {
         })
     }

-    pub async fn bulk_create(pool: &PgPool, payload: Vec<CreateEntry>) -> Result<Vec<Entry>> {
+    pub async fn bulk_create(
+        db: impl Executor<'_, Database = Postgres>,
+        payload: Vec<CreateEntry>,
+    ) -> Result<Vec<Entry>> {
         let mut titles = Vec::with_capacity(payload.len());
         let mut urls = Vec::with_capacity(payload.len());
         let mut descriptions: Vec<Option<String>> = Vec::with_capacity(payload.len());
@@ -241,7 +253,7 @@ impl Entry {
             feed_ids.as_slice(),
             published_ats.as_slice(),
         )
-        .fetch_all(pool)
+        .fetch_all(db)
         .await
         .map_err(|error| {
             if let sqlx::error::Error::Database(ref psql_error) = error {
@@ -253,7 +265,10 @@ impl Entry {
         })
     }

-    pub async fn bulk_upsert(pool: &PgPool, payload: Vec<CreateEntry>) -> Result<Vec<Entry>> {
+    pub async fn bulk_upsert(
+        db: impl Executor<'_, Database = Postgres>,
+        payload: Vec<CreateEntry>,
+    ) -> Result<Vec<Entry>> {
         let mut titles = Vec::with_capacity(payload.len());
         let mut urls = Vec::with_capacity(payload.len());
         let mut descriptions: Vec<Option<String>> = Vec::with_capacity(payload.len());
@@ -286,7 +301,7 @@ impl Entry {
             feed_ids.as_slice(),
             published_ats.as_slice(),
         )
-        .fetch_all(pool)
+        .fetch_all(db)
         .await
         .map_err(|error| {
             if let sqlx::error::Error::Database(ref psql_error) = error {
@@ -298,7 +313,10 @@ impl Entry {
         })
     }

-    pub async fn update(pool: &PgPool, payload: Entry) -> Result<Entry> {
+    pub async fn update(
+        db: impl Executor<'_, Database = Postgres>,
+        payload: Entry,
+    ) -> Result<Entry> {
         sqlx::query_as!(
             Entry,
             "update entry set
@@ -321,7 +339,7 @@ impl Entry {
             payload.last_modified_header,
             payload.published_at,
         )
-        .fetch_one(pool)
+        .fetch_one(db)
         .await
         .map_err(|error| {
             if let sqlx::error::Error::Database(ref psql_error) = error {
@@ -333,12 +351,12 @@ impl Entry {
         })
     }

-    pub async fn delete(pool: &PgPool, entry_id: Uuid) -> Result<()> {
+    pub async fn delete(db: impl Executor<'_, Database = Postgres>, entry_id: Uuid) -> Result<()> {
         sqlx::query!(
             "update entry set deleted_at = now() where entry_id = $1",
             entry_id
         )
-        .execute(pool)
+        .execute(db)
         .await?;
         Ok(())
     }

View File

@@ -2,7 +2,7 @@ use std::str::FromStr;
 use chrono::{DateTime, Utc};
 use serde::{Deserialize, Serialize};
-use sqlx::{FromRow, PgPool, postgres::PgQueryResult};
+use sqlx::{postgres::PgQueryResult, Executor, FromRow, Postgres};
 use uuid::Uuid;
 use validator::Validate;
@@ -127,7 +127,7 @@ pub struct GetFeedsOptions {
 }

 impl Feed {
-    pub async fn get(pool: &PgPool, feed_id: Uuid) -> Result<Feed> {
+    pub async fn get(db: impl Executor<'_, Database = Postgres>, feed_id: Uuid) -> Result<Feed> {
         sqlx::query_as!(
             Feed,
             // Unable to SELECT * here due to https://github.com/launchbadge/sqlx/issues/1004
@@ -150,7 +150,7 @@ impl Feed {
             from feed where feed_id = $1"#,
             feed_id
         )
-        .fetch_one(pool)
+        .fetch_one(db)
         .await
         .map_err(|error| {
             if let sqlx::error::Error::RowNotFound = error {
@@ -160,7 +160,10 @@ impl Feed {
         })
     }

-    pub async fn get_all(pool: &PgPool, options: &GetFeedsOptions) -> sqlx::Result<Vec<Feed>> {
+    pub async fn get_all(
+        db: impl Executor<'_, Database = Postgres>,
+        options: &GetFeedsOptions,
+    ) -> sqlx::Result<Vec<Feed>> {
         // TODO: make sure there are indices for all of these sort options
         match options.sort.as_ref().unwrap_or(&GetFeedsSort::CreatedAt) {
             GetFeedsSort::Title => {
@@ -192,7 +195,7 @@ impl Feed {
                         options.before_id.unwrap_or(Uuid::nil()),
                         options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
                     )
-                    .fetch_all(pool)
+                    .fetch_all(db)
                     .await
                 } else {
                     sqlx::query_as!(
@@ -219,9 +222,8 @@ impl Feed {
                         "#,
                         options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
                     )
-                    .fetch_all(pool)
+                    .fetch_all(db)
                     .await
-
                 }
             }
             GetFeedsSort::CreatedAt => {
@@ -253,7 +255,7 @@ impl Feed {
                         options.before_id.unwrap_or(Uuid::nil()),
                         options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
                     )
-                    .fetch_all(pool)
+                    .fetch_all(db)
                     .await
                 } else {
                     sqlx::query_as!(
@@ -280,9 +282,8 @@ impl Feed {
                         "#,
                         options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
                     )
-                    .fetch_all(pool)
+                    .fetch_all(db)
                     .await
-
                 }
             }
             GetFeedsSort::LastCrawledAt => {
@@ -314,7 +315,7 @@ impl Feed {
                         options.before_id.unwrap_or(Uuid::nil()),
                         options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
                     )
-                    .fetch_all(pool)
+                    .fetch_all(db)
                     .await
                 } else {
                     sqlx::query_as!(
@@ -341,9 +342,8 @@ impl Feed {
                         "#,
                         options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
                     )
-                    .fetch_all(pool)
+                    .fetch_all(db)
                     .await
-
                 }
             }
             GetFeedsSort::LastEntryPublishedAt => {
@@ -375,7 +375,7 @@ impl Feed {
                         options.before_id.unwrap_or(Uuid::nil()),
                         options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
                     )
-                    .fetch_all(pool)
+                    .fetch_all(db)
                     .await
                 } else {
                     sqlx::query_as!(
@@ -402,15 +402,17 @@ impl Feed {
                         "#,
                         options.limit.unwrap_or(DEFAULT_FEEDS_PAGE_SIZE),
                     )
-                    .fetch_all(pool)
+                    .fetch_all(db)
                     .await
-
                 }
             }
         }
     }

-    pub async fn create(pool: &PgPool, payload: CreateFeed) -> Result<Feed> {
+    pub async fn create(
+        db: impl Executor<'_, Database = Postgres>,
+        payload: CreateFeed,
+    ) -> Result<Feed> {
         payload.validate()?;
         Ok(sqlx::query_as!(
             Feed,
@@ -438,11 +440,14 @@ impl Feed {
             payload.url,
             payload.description
         )
-        .fetch_one(pool)
+        .fetch_one(db)
         .await?)
     }

-    pub async fn upsert(pool: &PgPool, payload: UpsertFeed) -> Result<Feed> {
+    pub async fn upsert(
+        db: impl Executor<'_, Database = Postgres>,
+        payload: UpsertFeed,
+    ) -> Result<Feed> {
         payload.validate()?;
         Ok(sqlx::query_as!(
             Feed,
@@ -476,11 +481,15 @@ impl Feed {
             payload.feed_type as Option<FeedType>,
             payload.description
         )
-        .fetch_one(pool)
+        .fetch_one(db)
         .await?)
     }

-    pub async fn update(pool: &PgPool, feed_id: Uuid, payload: UpdateFeed) -> Result<Feed> {
+    pub async fn update(
+        db: impl Executor<'_, Database = Postgres>,
+        feed_id: Uuid,
+        payload: UpdateFeed,
+    ) -> Result<Feed> {
         payload.validate()?;

         let mut query = sqlx::QueryBuilder::new("UPDATE feed SET ");
@@ -520,20 +529,24 @@ impl Feed {
         let query = query.build_query_as();

-        Ok(query.fetch_one(pool).await?)
+        Ok(query.fetch_one(db).await?)
     }

-    pub async fn delete(pool: &PgPool, feed_id: Uuid) -> Result<()> {
+    pub async fn delete(db: impl Executor<'_, Database = Postgres>, feed_id: Uuid) -> Result<()> {
         sqlx::query!(
             "update feed set deleted_at = now() where feed_id = $1",
             feed_id
         )
-        .execute(pool)
+        .execute(db)
         .await?;
         Ok(())
     }

-    pub async fn update_crawl_error(pool: &PgPool, feed_id: Uuid, last_crawl_error: String) -> Result<PgQueryResult> {
+    pub async fn update_crawl_error(
+        db: impl Executor<'_, Database = Postgres>,
+        feed_id: Uuid,
+        last_crawl_error: String,
+    ) -> Result<PgQueryResult> {
         Ok(sqlx::query!(
             r#"update feed set
                 last_crawl_error = $2
@@ -541,11 +554,11 @@ impl Feed {
             feed_id,
             last_crawl_error,
         )
-        .execute(pool)
+        .execute(db)
         .await?)
     }

-    pub async fn save(&self, pool: &PgPool) -> Result<Feed> {
+    pub async fn save(&self, db: impl Executor<'_, Database = Postgres>) -> Result<Feed> {
         Ok(sqlx::query_as!(
             Feed,
             r#"update feed set
@@ -588,7 +601,7 @@ impl Feed {
             self.last_crawled_at,
             self.last_entry_published_at,
         )
-        .fetch_one(pool)
+        .fetch_one(db)
         .await?)
     }
 }

View File

@@ -1,7 +1,7 @@
 use axum_login::{secrecy::SecretVec, AuthUser, PostgresStore};
 use chrono::{DateTime, Utc};
 use serde::Deserialize;
-use sqlx::{FromRow, PgPool};
+use sqlx::{Executor, FromRow, Postgres};
 use uuid::Uuid;
 use validator::Validate;
@@ -55,7 +55,7 @@ impl AuthUser<Uuid> for User {
 }

 impl User {
-    pub async fn get(pool: &PgPool, user_id: Uuid) -> Result<User> {
+    pub async fn get(db: impl Executor<'_, Database = Postgres>, user_id: Uuid) -> Result<User> {
         sqlx::query_as!(
             User,
             r#"select
@@ -64,7 +64,7 @@ impl User {
             where user_id = $1"#,
             user_id
         )
-        .fetch_one(pool)
+        .fetch_one(db)
         .await
         .map_err(|error| {
             if let sqlx::error::Error::RowNotFound = error {
@@ -74,7 +74,10 @@ impl User {
         })
     }

-    pub async fn get_by_email(pool: &PgPool, email: String) -> Result<User> {
+    pub async fn get_by_email(
+        db: impl Executor<'_, Database = Postgres>,
+        email: String,
+    ) -> Result<User> {
         sqlx::query_as!(
             User,
             r#"select
@@ -83,7 +86,7 @@ impl User {
             where email = $1"#,
             email
         )
-        .fetch_one(pool)
+        .fetch_one(db)
         .await
         .map_err(|error| {
             if let sqlx::error::Error::RowNotFound = error {
@@ -93,7 +96,10 @@ impl User {
         })
     }

-    pub async fn create(pool: &PgPool, payload: CreateUser) -> Result<User> {
+    pub async fn create(
+        db: impl Executor<'_, Database = Postgres>,
+        payload: CreateUser,
+    ) -> Result<User> {
         payload.validate()?;

         let password_hash = hash_password(payload.password).await?;
@@ -117,11 +123,14 @@ impl User {
             password_hash,
             payload.name
         )
-        .fetch_one(pool)
+        .fetch_one(db)
         .await?)
     }

-    pub async fn verify_email(pool: &PgPool, user_id: Uuid) -> Result<User> {
+    pub async fn verify_email(
+        db: impl Executor<'_, Database = Postgres>,
+        user_id: Uuid,
+    ) -> Result<User> {
         sqlx::query_as!(
             User,
             r#"update users set
@@ -131,7 +140,7 @@ impl User {
             "#,
             user_id
         )
-        .fetch_one(pool)
+        .fetch_one(db)
         .await
         .map_err(|error| {
             if let sqlx::error::Error::RowNotFound = error {
@@ -141,7 +150,11 @@ impl User {
         })
     }

-    pub async fn update_password(&self, pool: &PgPool, payload: UpdateUserPassword) -> Result<User> {
+    pub async fn update_password(
+        &self,
+        db: impl Executor<'_, Database = Postgres>,
+        payload: UpdateUserPassword,
+    ) -> Result<User> {
         payload.validate()?;

         let password_hash = hash_password(payload.password).await?;
@@ -164,7 +177,7 @@ impl User {
             self.user_id,
             password_hash,
         )
-        .fetch_one(pool)
+        .fetch_one(db)
         .await?)
     }
 }

View File

@@ -1,6 +1,6 @@
 use chrono::{DateTime, Utc};
 use serde::{Deserialize, Serialize};
-use sqlx::PgPool;
+use sqlx::{Executor, Postgres};
 use uuid::Uuid;

 use crate::error::{Error, Result};
@@ -25,7 +25,10 @@ impl UserEmailVerificationToken {
         Utc::now() > self.expires_at
     }

-    pub async fn get(pool: &PgPool, token_id: Uuid) -> Result<UserEmailVerificationToken> {
+    pub async fn get(
+        db: impl Executor<'_, Database = Postgres>,
+        token_id: Uuid,
+    ) -> Result<UserEmailVerificationToken> {
         sqlx::query_as!(
             UserEmailVerificationToken,
             r#"select
@@ -34,7 +37,7 @@ impl UserEmailVerificationToken {
             where token_id = $1"#,
             token_id
         )
-        .fetch_one(pool)
+        .fetch_one(db)
         .await
         .map_err(|error| {
             if let sqlx::error::Error::RowNotFound = error {
@@ -45,7 +48,7 @@ impl UserEmailVerificationToken {
     }

     pub async fn create(
-        pool: &PgPool,
+        db: impl Executor<'_, Database = Postgres>,
         payload: CreateUserEmailVerificationToken,
     ) -> Result<UserEmailVerificationToken> {
         Ok(sqlx::query_as!(
@@ -58,20 +61,17 @@ impl UserEmailVerificationToken {
             payload.user_id,
             payload.expires_at
         )
-        .fetch_one(pool)
+        .fetch_one(db)
         .await?)
     }

-    pub async fn delete(
-        pool: &PgPool,
-        token_id: Uuid,
-    ) -> Result<()> {
+    pub async fn delete(db: impl Executor<'_, Database = Postgres>, token_id: Uuid) -> Result<()> {
         sqlx::query!(
             r#"delete from user_email_verification_token
             where token_id = $1"#,
             token_id
         )
-        .execute(pool)
+        .execute(db)
         .await?;
         Ok(())
     }

View File

@@ -1,7 +1,7 @@
 use chrono::{DateTime, Utc};
 use ipnetwork::IpNetwork;
 use serde::{Deserialize, Serialize};
-use sqlx::PgPool;
+use sqlx::{Executor, Postgres};
 use uuid::Uuid;

 use crate::error::{Error, Result};
@@ -31,7 +31,10 @@ impl UserPasswordResetToken {
         Utc::now() > self.expires_at
     }

-    pub async fn get(pool: &PgPool, token_id: Uuid) -> Result<UserPasswordResetToken> {
+    pub async fn get(
+        pool: impl Executor<'_, Database = Postgres>,
+        token_id: Uuid,
+    ) -> Result<UserPasswordResetToken> {
         sqlx::query_as!(
             UserPasswordResetToken,
             r#"select
@@ -51,7 +54,7 @@ impl UserPasswordResetToken {
     }

     pub async fn create(
-        pool: &PgPool,
+        pool: impl Executor<'_, Database = Postgres>,
         payload: CreatePasswordResetToken,
     ) -> Result<UserPasswordResetToken> {
         Ok(sqlx::query_as!(
@@ -72,7 +75,7 @@ impl UserPasswordResetToken {
     }

     pub async fn delete(
-        pool: &PgPool,
+        pool: impl Executor<'_, Database = Postgres>,
         token_id: Uuid,
     ) -> Result<()> {
         sqlx::query!(