Update all packages, switch to tower-sessions

This commit is contained in:
Tyler Hallada 2023-12-20 00:25:28 -05:00
parent 6c23b3aaa3
commit c9a631a1f2
23 changed files with 1016 additions and 894 deletions

1483
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -13,31 +13,32 @@ path = "src/lib.rs"
[dependencies] [dependencies]
ammonia = "3.3.0" ammonia = "3.3.0"
ansi-to-html = "0.1" ansi-to-html = "0.2"
anyhow = "1" anyhow = "1"
argon2 = "0.5" async-trait = "0.1"
async-fred-session = "0.1" axum = { version = "0.7", features = ["form", "multipart", "query"] }
axum = { version = "0.6", features = ["form", "headers", "multipart", "query"] } axum-client-ip = "0.5"
axum-client-ip = "0.4" axum-extra = { version = "0.9", features = ["typed-header"] }
# waiting for new axum-login release which will support sqlx v. 0.7+ axum-login = "0.10"
axum-login = { git = "https://github.com/maxcountryman/axum-login", branch = "main", features = [
"postgres",
] }
bytes = "1.4" bytes = "1.4"
# TODO: replace chrono with time
chrono = { version = "0.4", features = ["serde"] } chrono = { version = "0.4", features = ["serde"] }
clap = { version = "4.4", features = ["derive", "env"] } clap = { version = "4.4", features = ["derive", "env"] }
dotenvy = "0.15" dotenvy = "0.15"
feed-rs = "1.3" feed-rs = "1.3"
futures = "0.3" futures = "0.3"
headers = "0.3" headers = "0.4"
http = "0.2.9" http = "1.0.0"
ipnetwork = "0.20" ipnetwork = "0.20"
lettre = { version = "0.10", features = ["builder"] } lettre = { version = "0.11", features = ["builder"] }
maud = { version = "0.25", features = ["axum"] } # waiting for new maud release which will support axum v. 0.7+: https://github.com/lambda-fairy/maud/pull/401
maud = { git = "https://github.com/vidhanio/maud", branch = "patch-1", features = [
"axum",
] }
notify = "6" notify = "6"
once_cell = "1.18" once_cell = "1.18"
opml = "1.1" opml = "1.1"
rand = { version = "0.8.5", features = ["min_const_gen"] } password-auth = "1.0"
readability = "0.2" readability = "0.2"
reqwest = { version = "0.11", features = ["json"] } reqwest = { version = "0.11", features = ["json"] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
@ -52,11 +53,13 @@ sqlx = { version = "0.7", features = [
"ipnetwork", "ipnetwork",
] } ] }
thiserror = "1" thiserror = "1"
time = "0.3"
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1", features = ["sync"] } tokio-stream = { version = "0.1", features = ["sync"] }
tower = "0.4" tower = "0.4"
tower-livereload = "0.8" tower-livereload = "0.9"
tower-http = { version = "0.4", features = ["trace", "fs"] } tower-http = { version = "0.5", features = ["trace", "fs"] }
tower-sessions = { version = "0.7", features = ["redis-store"] }
tracing = { version = "0.1", features = ["valuable", "attributes"] } tracing = { version = "0.1", features = ["valuable", "attributes"] }
tracing-appender = "0.2" tracing-appender = "0.2"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }

View File

@ -1,38 +1,95 @@
use anyhow::Context; use anyhow::Context;
use argon2::password_hash::{ use async_trait::async_trait;
rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString, use axum_login::{AuthUser, AuthnBackend, UserId};
}; use password_auth;
use argon2::Argon2; use serde::Deserialize;
use sqlx::PgPool;
use uuid::Uuid;
use crate::error::{Error, Result}; use crate::{error::Result, models::user::User};
pub async fn hash_password(password: String) -> Result<String> { pub async fn generate_hash(password: String) -> Result<String> {
// Argon2 hashing is designed to be computationally intensive, // Argon2 hashing is designed to be computationally intensive,
// so we need to do this on a blocking thread. tokio::task::spawn_blocking(move || -> String { password_auth::generate_hash(password) })
tokio::task::spawn_blocking(move || -> Result<String> { .await
let salt = SaltString::generate(&mut OsRng); .context("panic in generating password hash")
let argon2 = Argon2::default(); .map_err(|e| e.into())
Ok(argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| anyhow::anyhow!("failed to generate password hash: {}", e))?
.to_string())
})
.await
.context("panic in generating password hash")?
} }
pub async fn verify_password(password: String, password_hash: String) -> Result<()> { pub async fn verify_password(password: String, password_hash: String) -> Result<()> {
tokio::task::spawn_blocking(move || -> Result<()> { tokio::task::spawn_blocking(move || -> Result<()> {
let hash = PasswordHash::new(&password_hash) password_auth::verify_password(password.as_bytes(), &password_hash)
.map_err(|e| anyhow::anyhow!("invalid password hash: {}", e))?; .map_err(|e| anyhow::anyhow!("failed to verify password hash: {}", e).into())
Argon2::default()
.verify_password(password.as_bytes(), &hash)
.map_err(|e| match e {
argon2::password_hash::Error::Password => Error::Unauthorized,
_ => anyhow::anyhow!("failed to verify password hash: {}", e).into(),
})
}) })
.await .await
.context("panic in verifying password hash")? .context("panic in verifying password hash")?
} }
impl AuthUser for User {
type Id = Uuid;
fn id(&self) -> Self::Id {
self.user_id
}
fn session_auth_hash(&self) -> &[u8] {
self.password_hash.as_bytes()
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct Credentials {
pub email: String,
pub password: String,
}
#[derive(Debug, Clone)]
pub struct Backend {
db: PgPool,
}
impl Backend {
pub fn new(db: PgPool) -> Self {
Self { db }
}
}
#[async_trait]
impl AuthnBackend for Backend {
type User = User;
type Credentials = Credentials;
type Error = sqlx::Error;
async fn authenticate(
&self,
creds: Self::Credentials,
) -> Result<Option<Self::User>, Self::Error> {
let user = User::get_by_email(&self.db, creds.email).await.ok();
if let Some(user) = user {
if verify_password(creds.password, user.password_hash.clone())
.await
.ok()
.is_some()
{
return Ok(Some(user));
}
}
Ok(None)
}
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
sqlx::query_as!(
User,
r#"select
*
from users
where user_id = $1"#,
user_id
)
.fetch_optional(&self.db)
.await
}
}
pub type AuthSession = axum_login::AuthSession<Backend>;

View File

@ -13,7 +13,7 @@ impl FromStr for IpSource {
type Err = &'static str; type Err = &'static str;
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
// SourceClientIpSource doesn't implement FromStr itself, so I have to implement it on this // SourceClientIpSource doesn't implement FromStr itself, so I have to implement it on this
// wrapping newtype. See https://github.com/imbolc/axum-client-ip/issues/11 // wrapping newtype. See https://github.com/imbolc/axum-client-ip/issues/11
let inner = match s { let inner = match s {
"RightmostForwarded" => SecureClientIpSource::RightmostForwarded, "RightmostForwarded" => SecureClientIpSource::RightmostForwarded,
@ -63,4 +63,6 @@ pub struct Config {
pub session_secret: String, pub session_secret: String,
#[clap(long, env, default_value = "ConnectInfo")] #[clap(long, env, default_value = "ConnectInfo")]
pub ip_source: IpSource, pub ip_source: IpSource,
#[clap(long, env, default_value = "100")]
pub session_duration_days: i64,
} }

15
src/handlers/account.rs Normal file
View File

@ -0,0 +1,15 @@
use axum::response::IntoResponse;
use crate::auth::AuthSession;
pub async fn get(auth: AuthSession) -> impl IntoResponse {
match auth.user {
Some(user) => {
format!(
"Logged in as: {}",
user.name.unwrap_or_else(|| "No name".to_string())
)
}
None => "Not logged in".to_string(),
}
}

View File

@ -1,7 +1,7 @@
use axum::extract::Query; use axum::extract::Query;
use axum::extract::State; use axum::extract::State;
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::TypedHeader; use axum_extra::TypedHeader;
use sqlx::PgPool; use sqlx::PgPool;
use crate::api_response::ApiResponse; use crate::api_response::ApiResponse;

View File

@ -1,7 +1,7 @@
use axum::TypedHeader;
use axum::extract::Query; use axum::extract::Query;
use axum::response::IntoResponse;
use axum::extract::State; use axum::extract::State;
use axum::response::IntoResponse;
use axum_extra::TypedHeader;
use sqlx::PgPool; use sqlx::PgPool;
use crate::api_response::ApiResponse; use crate::api_response::ApiResponse;
@ -21,7 +21,5 @@ pub async fn get(
return Ok::<ApiResponse<Vec<Feed>>, Error>(ApiResponse::Json(feeds)); return Ok::<ApiResponse<Vec<Feed>>, Error>(ApiResponse::Json(feeds));
} }
} }
Ok(ApiResponse::Html( Ok(ApiResponse::Html(feed_list(feeds, &options).into_string()))
feed_list(feeds, &options).into_string(),
))
} }

View File

@ -1,6 +1,7 @@
use axum::extract::{Query, State}; use axum::extract::{Query, State};
use axum::response::Response; use axum::response::Response;
use axum::{Form, TypedHeader}; use axum::Form;
use axum_extra::TypedHeader;
use lettre::SmtpTransport; use lettre::SmtpTransport;
use maud::{html, Markup}; use maud::{html, Markup};
use serde::Deserialize; use serde::Deserialize;
@ -9,11 +10,12 @@ use sqlx::PgPool;
use tracing::{info, warn}; use tracing::{info, warn};
use uuid::Uuid; use uuid::Uuid;
use crate::auth::AuthSession;
use crate::config::Config; use crate::config::Config;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::htmx::HXTarget; use crate::htmx::HXTarget;
use crate::mailers::email_verification::send_confirmation_email; use crate::mailers::email_verification::send_confirmation_email;
use crate::models::user::{AuthContext, User}; use crate::models::user::User;
use crate::models::user_email_verification_token::UserEmailVerificationToken; use crate::models::user_email_verification_token::UserEmailVerificationToken;
use crate::partials::confirm_email_form::{confirm_email_form, ConfirmEmailFormProps}; use crate::partials::confirm_email_form::{confirm_email_form, ConfirmEmailFormProps};
use crate::partials::layout::Layout; use crate::partials::layout::Layout;
@ -66,7 +68,7 @@ pub fn confirm_email_page(
pub async fn get( pub async fn get(
State(pool): State<PgPool>, State(pool): State<PgPool>,
auth: AuthContext, auth: AuthSession,
hx_target: Option<TypedHeader<HXTarget>>, hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout, layout: Layout,
query: Query<ConfirmEmailQuery>, query: Query<ConfirmEmailQuery>,
@ -129,7 +131,7 @@ pub async fn get(
hx_target, hx_target,
layout, layout,
form_props: ConfirmEmailFormProps { form_props: ConfirmEmailFormProps {
email: auth.current_user.map(|u| u.email), email: auth.user.map(|u| u.email),
..Default::default() ..Default::default()
}, },
..Default::default() ..Default::default()

View File

@ -2,7 +2,7 @@ use std::fs;
use axum::extract::{Path, State}; use axum::extract::{Path, State};
use axum::response::Response; use axum::response::Response;
use axum::TypedHeader; use axum_extra::TypedHeader;
use maud::{html, PreEscaped}; use maud::{html, PreEscaped};
use sqlx::PgPool; use sqlx::PgPool;

View File

@ -4,7 +4,8 @@ use axum::extract::{Path, State};
use axum::http::StatusCode; use axum::http::StatusCode;
use axum::response::sse::{Event, KeepAlive}; use axum::response::sse::{Event, KeepAlive};
use axum::response::{IntoResponse, Redirect, Response, Sse}; use axum::response::{IntoResponse, Redirect, Response, Sse};
use axum::{Form, TypedHeader}; use axum::Form;
use axum_extra::TypedHeader;
use feed_rs::parser; use feed_rs::parser;
use maud::html; use maud::html;
use serde::Deserialize; use serde::Deserialize;

View File

@ -1,6 +1,6 @@
use axum::extract::State; use axum::extract::State;
use axum::response::Response; use axum::response::Response;
use axum::TypedHeader; use axum_extra::TypedHeader;
use maud::html; use maud::html;
use sqlx::PgPool; use sqlx::PgPool;

View File

@ -1,7 +1,7 @@
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
use axum::TypedHeader;
use axum::{extract::State, Form}; use axum::{extract::State, Form};
use axum_client_ip::SecureClientIp; use axum_client_ip::SecureClientIp;
use axum_extra::TypedHeader;
use headers::UserAgent; use headers::UserAgent;
use lettre::SmtpTransport; use lettre::SmtpTransport;
use maud::html; use maud::html;
@ -10,11 +10,11 @@ use serde_with::serde_as;
use sqlx::PgPool; use sqlx::PgPool;
use tracing::{info, warn}; use tracing::{info, warn};
use crate::auth::AuthSession;
use crate::config::Config; use crate::config::Config;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::htmx::HXTarget; use crate::htmx::HXTarget;
use crate::mailers::forgot_password::send_forgot_password_email; use crate::mailers::forgot_password::send_forgot_password_email;
use crate::models::user::AuthContext;
use crate::partials::forgot_password_form::{forgot_password_form, ForgotPasswordFormProps}; use crate::partials::forgot_password_form::{forgot_password_form, ForgotPasswordFormProps};
use crate::{models::user::User, partials::layout::Layout}; use crate::{models::user::User, partials::layout::Layout};
@ -67,7 +67,7 @@ pub fn confirm_forgot_password_sent_page(
} }
pub async fn get( pub async fn get(
auth: AuthContext, auth: AuthSession,
hx_target: Option<TypedHeader<HXTarget>>, hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout, layout: Layout,
) -> Result<Response> { ) -> Result<Response> {
@ -75,7 +75,7 @@ pub async fn get(
hx_target, hx_target,
layout, layout,
ForgotPasswordFormProps { ForgotPasswordFormProps {
email: auth.current_user.map(|u| u.email), email: auth.user.map(|u| u.email),
email_error: None, email_error: None,
}, },
)) ))

View File

@ -1,6 +1,6 @@
use axum::extract::State; use axum::extract::State;
use axum::response::Response; use axum::response::Response;
use axum::TypedHeader; use axum_extra::TypedHeader;
use maud::html; use maud::html;
use sqlx::PgPool; use sqlx::PgPool;

View File

@ -2,14 +2,14 @@ use std::convert::Infallible;
use std::str::from_utf8; use std::str::from_utf8;
use std::time::Duration; use std::time::Duration;
use ansi_to_html::convert_escaped; use ansi_to_html::convert;
use axum::extract::State; use axum::extract::State;
use axum::response::sse::KeepAlive; use axum::response::sse::KeepAlive;
use axum::response::{ use axum::response::{
sse::{Event, Sse}, sse::{Event, Sse},
Response, Response,
}; };
use axum::TypedHeader; use axum_extra::TypedHeader;
use bytes::Bytes; use bytes::Bytes;
use maud::{html, PreEscaped}; use maud::{html, PreEscaped};
use tokio::sync::watch::Receiver; use tokio::sync::watch::Receiver;
@ -29,7 +29,7 @@ pub async fn get(hx_target: Option<TypedHeader<HXTarget>>, layout: Layout) -> Re
.targeted(hx_target) .targeted(hx_target)
.render(html! { .render(html! {
pre id="log" hx-sse="connect:/log/stream swap:message" hx-swap="beforeend" hx-target="#log" { pre id="log" hx-sse="connect:/log/stream swap:message" hx-swap="beforeend" hx-target="#log" {
(PreEscaped(convert_escaped(from_utf8(mem_buf.as_slices().0).unwrap()).unwrap())) (PreEscaped(convert(from_utf8(mem_buf.as_slices().0).unwrap()).unwrap()))
} }
})) }))
} }
@ -41,7 +41,7 @@ pub async fn stream(
let log_stream = log_stream.map(|line| { let log_stream = log_stream.map(|line| {
Ok(Event::default().data( Ok(Event::default().data(
html! { html! {
(PreEscaped(convert_escaped(from_utf8(&line).unwrap()).unwrap())) (PreEscaped(convert(from_utf8(&line).unwrap()).unwrap()))
} }
.into_string(), .into_string(),
)) ))

View File

@ -1,29 +1,34 @@
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
use axum::TypedHeader; use axum::Form;
use axum::{extract::State, Form}; use axum_extra::TypedHeader;
use http::HeaderValue; use http::HeaderValue;
use maud::html; use maud::html;
use serde::Deserialize; use serde::Deserialize;
use serde_with::serde_as; use serde_with::serde_as;
use sqlx::PgPool;
use tracing::info; use tracing::info;
use crate::auth::verify_password; use crate::auth::{AuthSession, Credentials};
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::htmx::{HXRedirect, HXRequest, HXTarget}; use crate::htmx::{HXRedirect, HXRequest, HXTarget};
use crate::partials::login_form::{login_form, LoginFormProps}; use crate::partials::login_form::{login_form, LoginFormProps};
use crate::{ use crate::{models::user::User, partials::layout::Layout};
models::user::{AuthContext, User},
partials::layout::Layout,
};
#[serde_as] #[serde_as]
#[derive(Deserialize)] #[derive(Debug, Clone, Deserialize)]
pub struct Login { pub struct Login {
email: String, email: String,
password: String, password: String,
} }
impl From<Login> for Credentials {
fn from(login: Login) -> Self {
Credentials {
email: login.email,
password: login.password,
}
}
}
pub fn login_page( pub fn login_page(
hx_target: Option<TypedHeader<HXTarget>>, hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout, layout: Layout,
@ -53,47 +58,30 @@ pub async fn get(hx_target: Option<TypedHeader<HXTarget>>, layout: Layout) -> Re
} }
pub async fn post( pub async fn post(
State(pool): State<PgPool>, mut auth: AuthSession,
mut auth: AuthContext,
hx_target: Option<TypedHeader<HXTarget>>, hx_target: Option<TypedHeader<HXTarget>>,
hx_request: Option<TypedHeader<HXRequest>>, hx_request: Option<TypedHeader<HXRequest>>,
layout: Layout, layout: Layout,
Form(login): Form<Login>, Form(login): Form<Login>,
) -> Result<Response> { ) -> Result<Response> {
let user: User = match User::get_by_email(&pool, login.email.clone()).await { let user: User = match auth.authenticate(login.clone().into()).await {
Ok(user) => user, Ok(Some(user)) => user,
Err(err) => { Ok(None) => {
if let Error::NotFoundString(_, _) = err { info!(email = login.email, "authentication failed");
info!(email = login.email, "invalid email"); return Ok(login_page(
return Ok(login_page( hx_target,
hx_target, layout,
layout, LoginFormProps {
LoginFormProps { email: Some(login.email),
email: Some(login.email), general_error: Some("invalid email or password".to_string()),
general_error: Some("invalid email or password".to_string()), ..Default::default()
..Default::default() },
}, ));
)); }
} else { Err(_) => {
return Err(err); return Err(Error::InternalServerError);
}
} }
}; };
if verify_password(login.password, user.password_hash.clone())
.await
.is_err()
{
info!(user_id = %user.user_id, "invalid password");
return Ok(login_page(
hx_target,
layout,
LoginFormProps {
email: Some(login.email),
general_error: Some("invalid email or password".to_string()),
..Default::default()
},
));
}
info!(user_id = %user.user_id, "login successful"); info!(user_id = %user.user_id, "login successful");
auth.login(&user) auth.login(&user)
.await .await

View File

@ -1,6 +1,12 @@
use crate::{models::user::AuthContext, htmx::HXRedirect}; use anyhow::Context;
use axum::response::{IntoResponse, Response};
pub async fn get(mut auth: AuthContext) -> HXRedirect { use crate::auth::AuthSession;
auth.logout().await; use crate::error::Result;
HXRedirect::to("/").reload(true) use crate::htmx::HXRedirect;
pub async fn get(mut auth: AuthSession) -> Result<Response> {
auth.logout()
.context("failed to logout user from session")?;
Ok(HXRedirect::to("/").reload(true).into_response())
} }

View File

@ -1,12 +1,13 @@
pub mod account;
pub mod api; pub mod api;
pub mod confirm_email; pub mod confirm_email;
pub mod entries; pub mod entries;
pub mod entry; pub mod entry;
pub mod home;
pub mod import;
pub mod feed; pub mod feed;
pub mod feeds; pub mod feeds;
pub mod forgot_password; pub mod forgot_password;
pub mod home;
pub mod import;
pub mod log; pub mod log;
pub mod login; pub mod login;
pub mod logout; pub mod logout;

View File

@ -1,6 +1,6 @@
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
use axum::TypedHeader;
use axum::{extract::State, Form}; use axum::{extract::State, Form};
use axum_extra::TypedHeader;
use http::HeaderValue; use http::HeaderValue;
use lettre::SmtpTransport; use lettre::SmtpTransport;
use maud::html; use maud::html;
@ -8,11 +8,12 @@ use serde::Deserialize;
use serde_with::{serde_as, NoneAsEmptyString}; use serde_with::{serde_as, NoneAsEmptyString};
use sqlx::PgPool; use sqlx::PgPool;
use crate::auth::AuthSession;
use crate::config::Config; use crate::config::Config;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::htmx::{HXRedirect, HXTarget}; use crate::htmx::{HXRedirect, HXTarget};
use crate::mailers::email_verification::send_confirmation_email; use crate::mailers::email_verification::send_confirmation_email;
use crate::models::user::{AuthContext, CreateUser, User}; use crate::models::user::{CreateUser, User};
use crate::partials::layout::Layout; use crate::partials::layout::Layout;
use crate::partials::register_form::{register_form, RegisterFormProps}; use crate::partials::register_form::{register_form, RegisterFormProps};
@ -61,7 +62,7 @@ pub async fn post(
State(pool): State<PgPool>, State(pool): State<PgPool>,
State(mailer): State<SmtpTransport>, State(mailer): State<SmtpTransport>,
State(config): State<Config>, State(config): State<Config>,
mut auth: AuthContext, mut auth: AuthSession,
hx_target: Option<TypedHeader<HXTarget>>, hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout, layout: Layout,
Form(register): Form<Register>, Form(register): Form<Register>,

View File

@ -1,8 +1,8 @@
use axum::extract::Query; use axum::extract::Query;
use axum::response::Response; use axum::response::Response;
use axum::TypedHeader;
use axum::{extract::State, Form}; use axum::{extract::State, Form};
use axum_client_ip::SecureClientIp; use axum_client_ip::SecureClientIp;
use axum_extra::TypedHeader;
use headers::UserAgent; use headers::UserAgent;
use lettre::SmtpTransport; use lettre::SmtpTransport;
use maud::html; use maud::html;

View File

@ -1,7 +1,5 @@
use axum::{ use axum::http::{HeaderName, HeaderValue};
headers::{self, Header}, use axum_extra::headers::{self, Header};
http::{HeaderName, HeaderValue},
};
/// Typed header implementation for the `Accept` header. /// Typed header implementation for the `Accept` header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]

View File

@ -1,23 +1,26 @@
use std::{collections::HashMap, net::SocketAddr, path::Path, sync::Arc}; use std::{collections::HashMap, net::SocketAddr, path::Path, sync::Arc};
use anyhow::Result; use anyhow::Result;
use async_fred_session::{RedisSessionStore, fred::{pool::RedisPool, types::RedisConfig}};
use axum::{ use axum::{
response::IntoResponse, error_handling::HandleErrorLayer,
routing::{get, post}, routing::{get, post},
Extension, Router, BoxError, Router,
}; };
use axum_login::{ use axum_login::{
axum_sessions::SessionLayer, AuthLayer, PostgresStore, RequireAuthorizationLayer, login_required,
tower_sessions::{fred::prelude::*, Expiry, RedisStore, SessionManagerLayer},
AuthManagerLayerBuilder,
}; };
use bytes::Bytes; use bytes::Bytes;
use clap::Parser; use clap::Parser;
use dotenvy::dotenv; use dotenvy::dotenv;
use http::StatusCode;
use lettre::transport::smtp::authentication::Credentials; use lettre::transport::smtp::authentication::Credentials;
use lettre::SmtpTransport; use lettre::SmtpTransport;
use notify::Watcher; use notify::Watcher;
use reqwest::Client; use reqwest::Client;
use sqlx::postgres::PgPoolOptions; use sqlx::postgres::PgPoolOptions;
use time::Duration;
use tokio::sync::watch::channel; use tokio::sync::watch::channel;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tower::ServiceBuilder; use tower::ServiceBuilder;
@ -25,30 +28,24 @@ use tower_http::{services::ServeDir, trace::TraceLayer};
use tower_livereload::LiveReloadLayer; use tower_livereload::LiveReloadLayer;
use tracing::debug; use tracing::debug;
use lib::actors::crawl_scheduler::CrawlSchedulerHandle;
use lib::actors::importer::ImporterHandle; use lib::actors::importer::ImporterHandle;
use lib::config::Config; use lib::config::Config;
use lib::domain_locks::DomainLocks; use lib::domain_locks::DomainLocks;
use lib::handlers; use lib::handlers;
use lib::log::init_tracing; use lib::log::init_tracing;
use lib::models::user::User;
use lib::state::AppState; use lib::state::AppState;
use lib::USER_AGENT; use lib::USER_AGENT;
use uuid::Uuid; use lib::{actors::crawl_scheduler::CrawlSchedulerHandle, auth::Backend};
async fn serve(app: Router, addr: SocketAddr) -> Result<()> { async fn serve(app: Router, addr: SocketAddr) -> Result<()> {
debug!("listening on {}", addr); debug!("listening on {}", addr);
axum::Server::bind(&addr) let listener = tokio::net::TcpListener::bind(addr).await?;
.serve(app.into_make_service_with_connect_info::<SocketAddr>()) axum::serve(
.await?; listener,
Ok(()) app.into_make_service_with_connect_info::<SocketAddr>(),
}
async fn protected_handler(Extension(user): Extension<User>) -> impl IntoResponse {
format!(
"Logged in as: {}",
user.name.unwrap_or_else(|| "No name".to_string())
) )
.await?;
Ok(())
} }
#[tokio::main] #[tokio::main]
@ -65,7 +62,8 @@ async fn main() -> Result<()> {
let domain_locks = DomainLocks::new(); let domain_locks = DomainLocks::new();
let client = Client::builder().user_agent(USER_AGENT).build()?; let client = Client::builder().user_agent(USER_AGENT).build()?;
let secret = config.session_secret.as_bytes(); // TODO: not needed anymore?
// let secret = config.session_secret.as_bytes();
let pool = PgPoolOptions::new() let pool = PgPoolOptions::new()
.max_connections(config.database_max_connections) .max_connections(config.database_max_connections)
@ -73,15 +71,27 @@ async fn main() -> Result<()> {
.await?; .await?;
let redis_config = RedisConfig::from_url(&config.redis_url)?; let redis_config = RedisConfig::from_url(&config.redis_url)?;
let redis_pool = RedisPool::new(redis_config, None, None, config.redis_pool_size)?; // TODO: https://github.com/maxcountryman/tower-sessions/issues/92
redis_pool.connect(); // let redis_pool = RedisPool::new(redis_config, None, None, config.redis_pool_size)?;
redis_pool.wait_for_connect().await?; // redis_pool.connect();
// redis_pool.wait_for_connect().await?;
let redis_client = RedisClient::new(redis_config, None, None, None);
redis_client.connect();
redis_client.wait_for_connect().await?;
let session_store = RedisSessionStore::from_pool(redis_pool, Some("async-fred-session/".into())); let session_store = RedisStore::new(redis_client);
let session_layer = SessionLayer::new(session_store, secret).with_secure(false); let session_layer = SessionManagerLayer::new(session_store)
let user_store = PostgresStore::<User>::new(pool.clone()) .with_secure(!cfg!(debug_assertions))
.with_query("select * from users where user_id = $1"); .with_expiry(Expiry::OnInactivity(Duration::days(
let auth_layer = AuthLayer::new(user_store, secret); config.session_duration_days,
)));
let backend = Backend::new(pool.clone());
let auth_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_: BoxError| async {
StatusCode::BAD_REQUEST
}))
.layer(AuthManagerLayerBuilder::new(backend, session_layer).build());
let creds = Credentials::new(config.smtp_user.clone(), config.smtp_password.clone()); let creds = Credentials::new(config.smtp_user.clone(), config.smtp_password.clone());
@ -107,8 +117,8 @@ async fn main() -> Result<()> {
let addr = format!("{}:{}", &config.host, &config.port).parse()?; let addr = format!("{}:{}", &config.host, &config.port).parse()?;
let mut app = Router::new() let mut app = Router::new()
.route("/protected", get(protected_handler)) .route("/account", get(handlers::account::get))
.route_layer(RequireAuthorizationLayer::<Uuid, User>::login()) .route_layer(login_required!(Backend, login_url = "/login"))
.route("/api/v1/feeds", get(handlers::api::feeds::get)) .route("/api/v1/feeds", get(handlers::api::feeds::get))
.route("/api/v1/feed", post(handlers::api::feed::post)) .route("/api/v1/feed", post(handlers::api::feed::post))
.route("/api/v1/feed/:id", get(handlers::api::feed::get)) .route("/api/v1/feed/:id", get(handlers::api::feed::get))
@ -152,8 +162,7 @@ async fn main() -> Result<()> {
mailer, mailer,
}) })
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http())) .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
.layer(auth_layer) .layer(auth_service)
.layer(session_layer)
.layer(ip_source_extension); .layer(ip_source_extension);
if cfg!(debug_assertions) { if cfg!(debug_assertions) {

View File

@ -1,11 +1,10 @@
use axum_login::{secrecy::SecretVec, AuthUser, PostgresStore};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::Deserialize; use serde::Deserialize;
use sqlx::{Executor, FromRow, Postgres}; use sqlx::{Executor, FromRow, Postgres};
use uuid::Uuid; use uuid::Uuid;
use validator::Validate; use validator::Validate;
use crate::auth::hash_password; use crate::auth::generate_hash;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
#[derive(Debug, Default, Clone, FromRow)] #[derive(Debug, Default, Clone, FromRow)]
@ -44,16 +43,6 @@ pub struct UpdateUserPassword {
pub password: String, pub password: String,
} }
impl AuthUser<Uuid> for User {
fn get_id(&self) -> Uuid {
self.user_id
}
fn get_password_hash(&self) -> SecretVec<u8> {
SecretVec::new(self.password_hash.clone().into())
}
}
impl User { impl User {
pub async fn get(db: impl Executor<'_, Database = Postgres>, user_id: Uuid) -> Result<User> { pub async fn get(db: impl Executor<'_, Database = Postgres>, user_id: Uuid) -> Result<User> {
sqlx::query_as!( sqlx::query_as!(
@ -101,7 +90,7 @@ impl User {
payload: CreateUser, payload: CreateUser,
) -> Result<User> { ) -> Result<User> {
payload.validate()?; payload.validate()?;
let password_hash = hash_password(payload.password).await?; let password_hash = generate_hash(payload.password).await?;
Ok(sqlx::query_as!( Ok(sqlx::query_as!(
User, User,
@ -156,7 +145,7 @@ impl User {
payload: UpdateUserPassword, payload: UpdateUserPassword,
) -> Result<User> { ) -> Result<User> {
payload.validate()?; payload.validate()?;
let password_hash = hash_password(payload.password).await?; let password_hash = generate_hash(payload.password).await?;
Ok(sqlx::query_as!( Ok(sqlx::query_as!(
User, User,
@ -181,5 +170,3 @@ impl User {
.await?) .await?)
} }
} }
pub type AuthContext = axum_login::extractors::AuthContext<Uuid, User, PostgresStore<User>>;

View File

@ -8,17 +8,17 @@ use axum::{
extract::{FromRef, FromRequestParts, State}, extract::{FromRef, FromRequestParts, State},
http::request::Parts, http::request::Parts,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
TypedHeader,
}; };
use axum_extra::TypedHeader;
use headers::HeaderValue; use headers::HeaderValue;
use maud::{html, Markup, DOCTYPE}; use maud::{html, Markup, DOCTYPE};
use crate::models::user::AuthContext; use crate::auth::AuthSession;
use crate::models::user::User;
use crate::config::Config; use crate::config::Config;
use crate::htmx::HXTarget; use crate::htmx::HXTarget;
use crate::partials::header::header; use crate::models::user::User;
use crate::partials::footer::footer; use crate::partials::footer::footer;
use crate::partials::header::header;
#[cfg(not(debug_assertions))] #[cfg(not(debug_assertions))]
use crate::{CSS_MANIFEST, JS_MANIFEST}; use crate::{CSS_MANIFEST, JS_MANIFEST};
@ -42,13 +42,12 @@ where
let State(config) = State::<Config>::from_request_parts(parts, state) let State(config) = State::<Config>::from_request_parts(parts, state)
.await .await
.map_err(|err| err.into_response())?; .map_err(|err| err.into_response())?;
let auth_context = let auth_session = AuthSession::from_request_parts(parts, state)
AuthContext::from_request_parts(parts, state) .await
.await .map_err(|err| err.into_response())?;
.map_err(|err| err.into_response())?;
Ok(Self { Ok(Self {
title: config.title, title: config.title,
user: auth_context.current_user, user: auth_session.user,
..Default::default() ..Default::default()
}) })
} }
@ -119,11 +118,11 @@ impl Layout {
self self
} }
/// If the given HX-Target is present and equal to "main-content", then this function will make /// If the given HX-Target is present and equal to "main-content", then this function will make
/// this Layout skip rendering the layout and only render the template with a hx-swap-oob /// this Layout skip rendering the layout and only render the template with a hx-swap-oob
/// <title> element to update the document title. /// <title> element to update the document title.
/// ///
/// Links and forms that are boosted with the hx-boost attribute are only updating a portion of /// Links and forms that are boosted with the hx-boost attribute are only updating a portion of
/// the page inside the layout, so there is no need to render and send the layout again. /// the page inside the layout, so there is no need to render and send the layout again.
pub fn targeted(mut self, hx_target: Option<TypedHeader<HXTarget>>) -> Self { pub fn targeted(mut self, hx_target: Option<TypedHeader<HXTarget>>) -> Self {
if let Some(hx_target) = hx_target { if let Some(hx_target) = hx_target {