Add basic user auth

This commit is contained in:
Tyler Hallada 2023-09-25 01:35:26 -04:00
parent ec394fc170
commit 306059c355
30 changed files with 1433 additions and 663 deletions

1470
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -12,19 +12,25 @@ path = "src/lib.rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
ammonia = "3.3.0"
ansi-to-html = "0.1" ansi-to-html = "0.1"
anyhow = "1" anyhow = "1"
argon2 = "0.5"
axum = { version = "0.6", features = ["form", "headers", "multipart"] } axum = { version = "0.6", features = ["form", "headers", "multipart"] }
# waiting for new axum-login release which will support sqlx v. 0.7+
axum-login = { git = "https://github.com/maxcountryman/axum-login", branch = "main", features = ["postgres"] }
bytes = "1.4" bytes = "1.4"
chrono = { version = "0.4", features = ["serde"] } chrono = { version = "0.4", features = ["serde"] }
clap = { version = "4.3", features = ["derive", "env"] } clap = { version = "4.4", features = ["derive", "env"] }
dotenvy = "0.15" dotenvy = "0.15"
feed-rs = "1.3" feed-rs = "1.3"
futures = "0.3" futures = "0.3"
http = "0.2.9"
maud = { version = "0.25", features = ["axum"] } maud = { version = "0.25", features = ["axum"] }
notify = "6" notify = "6"
once_cell = "1.17" once_cell = "1.18"
opml = "1.1" opml = "1.1"
rand = { version = "0.8.5", features = ["min_const_gen"] }
readability = "0.2" readability = "0.2"
reqwest = { version = "0.11", features = ["json"] } reqwest = { version = "0.11", features = ["json"] }
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }
@ -46,8 +52,6 @@ tower-http = { version = "0.4", features = ["trace", "fs"] }
tracing = { version = "0.1", features = ["valuable", "attributes"] } tracing = { version = "0.1", features = ["valuable", "attributes"] }
tracing-appender = "0.2" tracing-appender = "0.2"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.3", features = ["serde"] } uuid = { version = "1.4", features = ["serde"] }
url = "2.4" url = "2.4"
validator = { version = "0.16", features = ["derive"] } validator = { version = "0.16", features = ["derive"] }
ammonia = "3.3.0"
http = "0.2.9"

View File

@ -6,4 +6,5 @@ drop table _sqlx_migrations cascade;
drop collation case_insensitive; drop collation case_insensitive;
drop table entry cascade; drop table entry cascade;
drop table feed cascade; drop table feed cascade;
drop table users cascade;
drop type feed_type; drop type feed_type;

View File

@ -55,6 +55,10 @@ header.header nav ul li {
margin-left: 16px; margin-left: 16px;
} }
header.header nav .auth {
margin-left: auto;
}
/* Footer */ /* Footer */
footer.footer { footer.footer {
@ -187,7 +191,7 @@ form.feed-form .form-grid textarea {
form.feed-form .form-grid button { form.feed-form .form-grid button {
font-size: 14px; font-size: 14px;
padding: 4px 8px; padding: 4px 8px;
grid-column: 3 / 4; grid-column: 3 / 3;
} }
ul.stream-messages { ul.stream-messages {
@ -217,3 +221,38 @@ header.feed-header button {
padding: 4px 8px; padding: 4px 8px;
margin-left: 24px; margin-left: 24px;
} }
/* Signup & Login */
.auth-form-grid {
display: grid;
grid-template-columns: fit-content(100%) minmax(100px, 400px);
grid-gap: 16px;
width: 100%;
margin: 16px;
margin-bottom: 32px;
}
.auth-form-grid label {
font-size: 16px;
font-weight: bold;
grid-column: 1;
text-align: right;
}
.auth-form-grid input {
font-size: 16px;
grid-column: 2;
}
.auth-form-grid button {
font-size: 14px;
padding: 4px 8px;
grid-column: 2;
margin-left: auto;
}
.auth-form-grid span.error {
font-size: 16px;
grid-column: 2 / 3;
}

View File

@ -68,3 +68,15 @@ create table if not exists "entry" (
create index on "entry" (published_at desc) where deleted_at is null; create index on "entry" (published_at desc) where deleted_at is null;
create unique index on "entry" (url, feed_id); create unique index on "entry" (url, feed_id);
select trigger_updated_at('"entry"'); select trigger_updated_at('"entry"');
create table if not exists "users" (
user_id uuid primary key default uuid_generate_v1mc(),
password_hash text not null,
email text not null collate case_insensitive,
name text,
created_at timestamptz not null default now(),
updated_at timestamptz,
deleted_at timestamptz
);
create unique index on "users" (email);
select trigger_updated_at('"users"');

38
src/auth.rs Normal file
View File

@ -0,0 +1,38 @@
use anyhow::Context;
use argon2::password_hash::{
rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString,
};
use argon2::Argon2;
use crate::error::{Error, Result};
pub async fn hash_password(password: String) -> Result<String> {
// Argon2 hashing is designed to be computationally intensive,
// so we need to do this on a blocking thread.
tokio::task::spawn_blocking(move || -> Result<String> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
Ok(argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| anyhow::anyhow!("failed to generate password hash: {}", e))?
.to_string())
})
.await
.context("panic in generating password hash")?
}
pub async fn verify_password(password: String, password_hash: String) -> Result<()> {
tokio::task::spawn_blocking(move || -> Result<()> {
let hash = PasswordHash::new(&password_hash)
.map_err(|e| anyhow::anyhow!("invalid password hash: {}", e))?;
Argon2::default()
.verify_password(password.as_bytes(), &hash)
.map_err(|e| match e {
argon2::password_hash::Error::Password => Error::Unauthorized,
_ => anyhow::anyhow!("failed to verify password hash: {}", e).into(),
})
})
.await
.context("panic in verifying password hash")?
}

View File

@ -34,13 +34,22 @@ pub enum Error {
NoFile, NoFile,
#[error("{0}: {1} not found")] #[error("{0}: {1} not found")]
NotFound(&'static str, Uuid), NotFoundUuid(&'static str, Uuid),
#[error("{0}: {1} not found")]
NotFoundString(&'static str, String),
#[error("referenced {0} not found")] #[error("referenced {0} not found")]
RelationNotFound(&'static str), RelationNotFound(&'static str),
#[error("an internal server error occurred")] #[error("an internal server error occurred")]
InternalServerError, InternalServerError,
#[error("unauthorized")]
Unauthorized,
#[error("bad request: {0}")]
BadRequest(&'static str)
} }
pub type Result<T, E = Error> = ::std::result::Result<T, E>; pub type Result<T, E = Error> = ::std::result::Result<T, E>;
@ -81,7 +90,9 @@ impl Error {
use Error::*; use Error::*;
match self { match self {
NotFound(_, _) => StatusCode::NOT_FOUND, NotFoundUuid(_, _) | NotFoundString(_, _) => StatusCode::NOT_FOUND,
Unauthorized => StatusCode::UNAUTHORIZED,
BadRequest(_) => StatusCode::BAD_REQUEST,
InternalServerError | Sqlx(_) | Anyhow(_) | Reqwest(_) => { InternalServerError | Sqlx(_) | Anyhow(_) | Reqwest(_) => {
StatusCode::INTERNAL_SERVER_ERROR StatusCode::INTERNAL_SERVER_ERROR
} }

View File

@ -27,7 +27,9 @@ pub async fn get(
let content = fs::read_to_string(content_path).unwrap_or_else(|_| "No content".to_string()); let content = fs::read_to_string(content_path).unwrap_or_else(|_| "No content".to_string());
Ok(layout.with_subtitle(&title).render(html! { Ok(layout.with_subtitle(&title).render(html! {
article { article {
header {
h2 class="title" { a href=(entry.url) { (title) } } h2 class="title" { a href=(entry.url) { (title) } }
}
div { div {
span class="published" { span class="published" {
strong { "Published: " } strong { "Published: " }

View File

@ -164,7 +164,7 @@ pub async fn stream(
let mut crawls = crawls.lock().await; let mut crawls = crawls.lock().await;
crawls.remove(&id.as_uuid()) crawls.remove(&id.as_uuid())
} }
.ok_or_else(|| Error::NotFound("feed stream", id.as_uuid()))?; .ok_or_else(|| Error::NotFoundUuid("feed stream", id.as_uuid()))?;
let stream = BroadcastStream::new(receiver); let stream = BroadcastStream::new(receiver);
let feed_id = format!("feed-{}", id); let feed_id = format!("feed-{}", id);

View File

@ -14,7 +14,7 @@ pub async fn get(State(pool): State<PgPool>, layout: Layout) -> Result<Response>
let options = GetFeedsOptions::default(); let options = GetFeedsOptions::default();
let feeds = Feed::get_all(&pool, &options).await?; let feeds = Feed::get_all(&pool, &options).await?;
Ok(layout.with_subtitle("feeds").render(html! { Ok(layout.with_subtitle("feeds").render(html! {
h2 { "Feeds" } header { h2 { "Feeds" } }
div class="feeds" { div class="feeds" {
ul id="feeds" { ul id="feeds" {
(feed_list(feeds, &options)) (feed_list(feeds, &options))

View File

@ -59,7 +59,7 @@ pub async fn stream(
let mut imports = imports.lock().await; let mut imports = imports.lock().await;
imports.remove(&id.as_uuid()) imports.remove(&id.as_uuid())
} }
.ok_or_else(|| Error::NotFound("import stream", id.as_uuid()))?; .ok_or_else(|| Error::NotFoundUuid("import stream", id.as_uuid()))?;
let stream = BroadcastStream::new(receiver); let stream = BroadcastStream::new(receiver);
let stream = stream.map(move |msg| match msg { let stream = stream.map(move |msg| match msg {

69
src/handlers/login.rs Normal file
View File

@ -0,0 +1,69 @@
use axum::response::{IntoResponse, Redirect, Response};
use axum::{extract::State, Form};
use maud::html;
use serde::Deserialize;
use serde_with::serde_as;
use sqlx::PgPool;
use crate::auth::verify_password;
use crate::error::{Error, Result};
use crate::partials::login_form::{login_form, LoginFormProps};
use crate::{
models::user::{AuthContext, User},
partials::layout::Layout,
};
#[serde_as]
#[derive(Deserialize)]
pub struct Login {
email: String,
password: String,
}
pub async fn get(layout: Layout) -> Result<Response> {
Ok(layout.with_subtitle("login").render(html! {
header {
h2 { "Login" }
}
(login_form(LoginFormProps::default()))
}))
}
pub async fn post(
State(pool): State<PgPool>,
mut auth: AuthContext,
Form(login): Form<Login>,
) -> Result<Response> {
let user: User = match User::get_by_email(&pool, login.email.clone()).await {
Ok(user) => user,
Err(err) => {
if let Error::NotFoundString(_, _) = err {
// Error::BadRequest("invalid email or password")
return Ok(login_form(LoginFormProps {
email: Some(login.email),
general_error: Some("invalid email or password".to_string()),
..Default::default()
})
.into_response());
} else {
return Err(err);
}
}
};
if verify_password(login.password, user.password_hash.clone())
.await
.is_err()
{
// return Err(Error::BadRequest("invalid email or password"));
return Ok(login_form(LoginFormProps {
email: Some(login.email),
general_error: Some("invalid email or password".to_string()),
..Default::default()
})
.into_response());
}
auth.login(&user)
.await
.map_err(|_| Error::InternalServerError)?;
Ok(Redirect::to("/").into_response())
}

8
src/handlers/logout.rs Normal file
View File

@ -0,0 +1,8 @@
use axum::response::Redirect;
use crate::models::user::AuthContext;
pub async fn get(mut auth: AuthContext) -> Redirect {
auth.logout().await;
Redirect::to("/")
}

View File

@ -6,3 +6,6 @@ pub mod import;
pub mod feed; pub mod feed;
pub mod feeds; pub mod feeds;
pub mod log; pub mod log;
pub mod login;
pub mod logout;
pub mod signup;

111
src/handlers/signup.rs Normal file
View File

@ -0,0 +1,111 @@
use axum::response::{IntoResponse, Redirect, Response};
use axum::{extract::State, Form};
use maud::html;
use serde::Deserialize;
use serde_with::{serde_as, NoneAsEmptyString};
use sqlx::PgPool;
use crate::error::{Error, Result};
use crate::models::user::{AuthContext, CreateUser, User};
use crate::partials::layout::Layout;
use crate::partials::signup_form::{signup_form, SignupFormProps};
#[serde_as]
#[derive(Debug, Deserialize)]
pub struct Signup {
pub email: String,
pub password: String,
pub password_confirmation: String,
#[serde_as(as = "NoneAsEmptyString")]
pub name: Option<String>,
}
pub async fn get(layout: Layout) -> Result<Response> {
Ok(layout.with_subtitle("signup").render(html! {
header {
h2 { "Signup" }
}
(signup_form(SignupFormProps::default()))
}))
}
pub async fn post(
State(pool): State<PgPool>,
mut auth: AuthContext,
Form(signup): Form<Signup>,
) -> Result<Response> {
if signup.password != signup.password_confirmation {
// return Err(Error::BadRequest("passwords do not match"));
return Ok(signup_form(SignupFormProps {
email: Some(signup.email),
name: signup.name,
password_error: Some("passwords do not match".to_string()),
..Default::default()
})
.into_response());
}
let user = match User::create(
&pool,
CreateUser {
email: signup.email.clone(),
password: signup.password.clone(),
name: signup.name.clone(),
},
)
.await
{
Ok(user) => user,
Err(err) => {
if let Error::InvalidEntity(validation_errors) = err {
let field_errors = validation_errors.field_errors();
dbg!(&validation_errors);
dbg!(&field_errors);
return Ok(signup_form(SignupFormProps {
email: Some(signup.email),
name: signup.name,
email_error: field_errors.get("email").map(|&errors| {
errors
.iter()
.filter_map(|error| error.message.clone().map(|m| m.to_string()))
.collect::<Vec<String>>()
.join(", ")
}),
name_error: field_errors.get("name").map(|&errors| {
errors
.iter()
.filter_map(|error| error.message.clone().map(|m| m.to_string()))
.collect::<Vec<String>>()
.join(", ")
}),
password_error: field_errors.get("password").map(|&errors| {
errors
.iter()
.filter_map(|error| error.message.clone().map(|m| m.to_string()))
.collect::<Vec<String>>()
.join(", ")
}),
..Default::default()
})
.into_response());
}
if let Error::Sqlx(sqlx::error::Error::Database(db_error)) = &err {
if let Some(constraint) = db_error.constraint() {
if constraint == "users_email_idx" {
return Ok(signup_form(SignupFormProps {
email: Some(signup.email),
name: signup.name,
email_error: Some("email already exists".to_string()),
..Default::default()
})
.into_response());
}
}
}
return Err(err);
}
};
auth.login(&user)
.await
.map_err(|_| Error::InternalServerError)?;
Ok(Redirect::to("/").into_response())
}

View File

@ -1,5 +1,6 @@
pub mod actors; pub mod actors;
pub mod api_response; pub mod api_response;
pub mod auth;
pub mod config; pub mod config;
pub mod domain_locks; pub mod domain_locks;
pub mod error; pub mod error;

View File

@ -1,19 +1,20 @@
use std::{ use std::{collections::HashMap, net::SocketAddr, path::Path, sync::Arc};
collections::HashMap,
net::SocketAddr,
path::Path,
sync::Arc,
};
use anyhow::Result; use anyhow::Result;
use axum::{ use axum::{
response::IntoResponse,
routing::{get, post}, routing::{get, post},
Router, Extension, Router,
};
use axum_login::{
axum_sessions::{async_session::MemoryStore, SessionLayer},
AuthLayer, PostgresStore, RequireAuthorizationLayer,
}; };
use bytes::Bytes; use bytes::Bytes;
use clap::Parser; use clap::Parser;
use dotenvy::dotenv; use dotenvy::dotenv;
use notify::Watcher; use notify::Watcher;
use rand::Rng;
use reqwest::Client; use reqwest::Client;
use sqlx::postgres::PgPoolOptions; use sqlx::postgres::PgPoolOptions;
use tokio::sync::watch::channel; use tokio::sync::watch::channel;
@ -29,8 +30,10 @@ use lib::config::Config;
use lib::domain_locks::DomainLocks; use lib::domain_locks::DomainLocks;
use lib::handlers; use lib::handlers;
use lib::log::init_tracing; use lib::log::init_tracing;
use lib::models::user::User;
use lib::state::AppState; use lib::state::AppState;
use lib::USER_AGENT; use lib::USER_AGENT;
use uuid::Uuid;
async fn serve(app: Router, addr: SocketAddr) -> Result<()> { async fn serve(app: Router, addr: SocketAddr) -> Result<()> {
debug!("listening on {}", addr); debug!("listening on {}", addr);
@ -40,6 +43,13 @@ async fn serve(app: Router, addr: SocketAddr) -> Result<()> {
Ok(()) Ok(())
} }
async fn protected_handler(Extension(user): Extension<User>) -> impl IntoResponse {
format!(
"Logged in as: {}",
user.name.unwrap_or_else(|| "No name".to_string())
)
}
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
dotenv().ok(); dotenv().ok();
@ -54,11 +64,20 @@ async fn main() -> Result<()> {
let domain_locks = DomainLocks::new(); let domain_locks = DomainLocks::new();
let client = Client::builder().user_agent(USER_AGENT).build()?; let client = Client::builder().user_agent(USER_AGENT).build()?;
let secret = rand::thread_rng().gen::<[u8; 64]>();
let pool = PgPoolOptions::new() let pool = PgPoolOptions::new()
.max_connections(config.database_max_connections) .max_connections(config.database_max_connections)
.connect(&config.database_url) .connect(&config.database_url)
.await?; .await?;
// TODO: store sessions in postgres eventually
let session_store = MemoryStore::new();
let session_layer = SessionLayer::new(session_store, &secret).with_secure(false);
let user_store = PostgresStore::<User>::new(pool.clone())
.with_query("select * from users where user_id = $1");
let auth_layer = AuthLayer::new(user_store, &secret);
sqlx::migrate!().run(&pool).await?; sqlx::migrate!().run(&pool).await?;
let crawl_scheduler = CrawlSchedulerHandle::new( let crawl_scheduler = CrawlSchedulerHandle::new(
@ -69,14 +88,12 @@ async fn main() -> Result<()> {
crawls.clone(), crawls.clone(),
); );
let _ = crawl_scheduler.bootstrap().await; let _ = crawl_scheduler.bootstrap().await;
let importer = ImporterHandle::new( let importer = ImporterHandle::new(pool.clone(), crawl_scheduler.clone(), imports.clone());
pool.clone(),
crawl_scheduler.clone(),
imports.clone(),
);
let addr = format!("{}:{}", &config.host, &config.port).parse()?; let addr = format!("{}:{}", &config.host, &config.port).parse()?;
let mut app = Router::new() let mut app = Router::new()
.route("/protected", get(protected_handler))
.route_layer(RequireAuthorizationLayer::<Uuid, User>::login())
.route("/api/v1/feeds", get(handlers::api::feeds::get)) .route("/api/v1/feeds", get(handlers::api::feeds::get))
.route("/api/v1/feed", post(handlers::api::feed::post)) .route("/api/v1/feed", post(handlers::api::feed::post))
.route("/api/v1/feed/:id", get(handlers::api::feed::get)) .route("/api/v1/feed/:id", get(handlers::api::feed::get))
@ -95,6 +112,11 @@ async fn main() -> Result<()> {
.route("/log/stream", get(handlers::log::stream)) .route("/log/stream", get(handlers::log::stream))
.route("/import/opml", post(handlers::import::opml)) .route("/import/opml", post(handlers::import::opml))
.route("/import/:id/stream", get(handlers::import::stream)) .route("/import/:id/stream", get(handlers::import::stream))
.route("/login", get(handlers::login::get))
.route("/login", post(handlers::login::post))
.route("/logout", get(handlers::logout::get))
.route("/signup", get(handlers::signup::get))
.route("/signup", post(handlers::signup::post))
.nest_service("/static", ServeDir::new("static")) .nest_service("/static", ServeDir::new("static"))
.with_state(AppState { .with_state(AppState {
pool, pool,
@ -107,7 +129,9 @@ async fn main() -> Result<()> {
importer, importer,
imports, imports,
}) })
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http())); .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
.layer(auth_layer)
.layer(session_layer);
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
debug!("starting livereload"); debug!("starting livereload");

View File

@ -50,7 +50,7 @@ impl Entry {
.await .await
.map_err(|error| { .map_err(|error| {
if let sqlx::error::Error::RowNotFound = error { if let sqlx::error::Error::RowNotFound = error {
return Error::NotFound("entry", entry_id); return Error::NotFoundUuid("entry", entry_id);
} }
Error::Sqlx(error) Error::Sqlx(error)
}) })

View File

@ -154,7 +154,7 @@ impl Feed {
.await .await
.map_err(|error| { .map_err(|error| {
if let sqlx::error::Error::RowNotFound = error { if let sqlx::error::Error::RowNotFound = error {
return Error::NotFound("feed", feed_id); return Error::NotFoundUuid("feed", feed_id);
} }
Error::Sqlx(error) Error::Sqlx(error)
}) })

View File

@ -1,2 +1,3 @@
pub mod entry; pub mod entry;
pub mod feed; pub mod feed;
pub mod user;

113
src/models/user.rs Normal file
View File

@ -0,0 +1,113 @@
use axum_login::{secrecy::SecretVec, AuthUser, PostgresStore};
use chrono::{DateTime, Utc};
use serde::Deserialize;
use sqlx::{FromRow, PgPool};
use uuid::Uuid;
use validator::Validate;
use crate::auth::hash_password;
use crate::error::{Error, Result};
#[derive(Debug, Default, Clone, FromRow)]
pub struct User {
pub user_id: Uuid,
pub email: String,
pub password_hash: String,
pub name: Option<String>,
pub created_at: DateTime<Utc>,
pub updated_at: Option<DateTime<Utc>>,
pub deleted_at: Option<DateTime<Utc>>,
}
#[derive(Debug, Deserialize, Default, Validate)]
pub struct CreateUser {
#[validate(email(message = "email must be a valid email address"))]
pub email: String,
#[validate(length(
min = 8,
max = 255,
message = "password must be between 8 and 255 characters long"
))]
pub password: String,
#[validate(length(max = 255, message = "name must be less than 255 characters long"))]
pub name: Option<String>,
}
impl AuthUser<Uuid> for User {
fn get_id(&self) -> Uuid {
self.user_id
}
fn get_password_hash(&self) -> SecretVec<u8> {
SecretVec::new(self.password_hash.clone().into())
}
}
impl User {
pub async fn get(pool: &PgPool, user_id: Uuid) -> Result<User> {
sqlx::query_as!(
User,
r#"select
*
from users
where user_id = $1"#,
user_id
)
.fetch_one(pool)
.await
.map_err(|error| {
if let sqlx::error::Error::RowNotFound = error {
return Error::NotFoundUuid("user", user_id);
}
Error::Sqlx(error)
})
}
pub async fn get_by_email(pool: &PgPool, email: String) -> Result<User> {
sqlx::query_as!(
User,
r#"select
*
from users
where email = $1"#,
email
)
.fetch_one(pool)
.await
.map_err(|error| {
if let sqlx::error::Error::RowNotFound = error {
return Error::NotFoundString("user", email);
}
Error::Sqlx(error)
})
}
pub async fn create(pool: &PgPool, payload: CreateUser) -> Result<User> {
payload.validate()?;
let password_hash = hash_password(payload.password).await?;
Ok(sqlx::query_as!(
User,
r#"insert into users (
email, password_hash, name
) values (
$1, $2, $3
) returning
user_id,
email,
password_hash,
name,
created_at,
updated_at,
deleted_at
"#,
payload.email,
password_hash,
payload.name
)
.fetch_one(pool)
.await?)
}
}
pub type AuthContext = axum_login::extractors::AuthContext<Uuid, User, PostgresStore<User>>;

View File

@ -2,7 +2,7 @@ use maud::{html, Markup};
pub fn add_feed_form() -> Markup { pub fn add_feed_form() -> Markup {
html! { html! {
form hx-post="/feed" class="feed-form" { form hx-post="/feed" hx-swap="outerHTML" class="feed-form" {
div class="form-grid" { div class="form-grid" {
label for="url" { "URL: " } label for="url" { "URL: " }
input type="text" id="url" name="url" placeholder="https://example.com/feed.xml" required="true"; input type="text" id="url" name="url" placeholder="https://example.com/feed.xml" required="true";

View File

@ -1,6 +1,9 @@
use maud::{html, Markup}; use maud::{html, Markup};
pub fn header(title: &str) -> Markup { use crate::models::user::User;
use crate::partials::user_name::user_name;
pub fn header(title: &str, user: Option<User>) -> Markup {
html! { html! {
header class="header" { header class="header" {
nav { nav {
@ -9,6 +12,17 @@ pub fn header(title: &str) -> Markup {
li { a href="/feeds" { "feeds" } } li { a href="/feeds" { "feeds" } }
li { a href="/log" { "log" } } li { a href="/log" { "log" } }
} }
div class="auth" {
@if let Some(user) = user {
(user_name(user))
span { " | " }
a href="/logout" { "logout" }
} @else {
a href="/login" { "login" }
span { " | " }
a href="/signup" { "signup" }
}
}
} }
} }
} }

View File

@ -9,10 +9,14 @@ use axum::{
http::request::Parts, http::request::Parts,
response::{Html, IntoResponse, Response}, response::{Html, IntoResponse, Response},
}; };
use axum_login::{extractors::AuthContext, SqlxStore};
use maud::{html, Markup, DOCTYPE}; use maud::{html, Markup, DOCTYPE};
use sqlx::PgPool;
use uuid::Uuid;
use crate::partials::header::header; use crate::partials::header::header;
use crate::{config::Config, partials::footer::footer}; use crate::{config::Config, partials::footer::footer};
use crate::models::user::User;
#[cfg(not(debug_assertions))] #[cfg(not(debug_assertions))]
use crate::{CSS_MANIFEST, JS_MANIFEST}; use crate::{CSS_MANIFEST, JS_MANIFEST};
@ -20,6 +24,7 @@ use crate::{CSS_MANIFEST, JS_MANIFEST};
pub struct Layout { pub struct Layout {
pub title: String, pub title: String,
pub subtitle: Option<String>, pub subtitle: Option<String>,
pub user: Option<User>,
} }
#[async_trait] #[async_trait]
@ -34,8 +39,12 @@ where
let State(config) = State::<Config>::from_request_parts(parts, state) let State(config) = State::<Config>::from_request_parts(parts, state)
.await .await
.map_err(|err| err.into_response())?; .map_err(|err| err.into_response())?;
let auth_context = AuthContext::<Uuid, User, SqlxStore<PgPool, User>>::from_request_parts(parts, state)
.await
.map_err(|err| err.into_response())?;
Ok(Self { Ok(Self {
title: config.title, title: config.title,
user: auth_context.current_user,
..Default::default() ..Default::default()
}) })
} }
@ -101,6 +110,11 @@ impl Layout {
self self
} }
pub fn with_user(mut self, user: User) -> Self {
self.user = Some(user);
self
}
fn full_title(&self) -> String { fn full_title(&self) -> String {
if let Some(subtitle) = &self.subtitle { if let Some(subtitle) = &self.subtitle {
format!("{} - {}", self.title, subtitle) format!("{} - {}", self.title, subtitle)
@ -124,7 +138,7 @@ impl Layout {
} }
} }
body hx-booster="true" { body hx-booster="true" {
(header(&self.title)) (header(&self.title, self.user))
(template) (template)
(footer()) (footer())
} }

View File

@ -0,0 +1,36 @@
use maud::{html, Markup};
#[derive(Debug, Default)]
pub struct LoginFormProps {
pub email: Option<String>,
pub email_error: Option<String>,
pub password_error: Option<String>,
pub general_error: Option<String>,
}
pub fn login_form(props: LoginFormProps) -> Markup {
let LoginFormProps {
email,
email_error,
password_error,
general_error,
} = props;
html! {
form hx-post="/login" hx-swap="outerHTML" class="auth-form-grid" {
label for="email" { "Email" }
input type="email" name="email" id="email" placeholder="Email" value=(email.unwrap_or_default()) required;
@if let Some(email_error) = email_error {
span class="error" { (email_error) }
}
label for="email" { "Password" }
input type="password" name="password" id="password" placeholder="Password" minlength="8" maxlength="255" required;
@if let Some(password_error) = password_error {
span class="error" { (password_error) }
}
button type="submit" { "Submit" }
@if let Some(general_error) = general_error {
span class="error" { (general_error) }
}
}
}
}

View File

@ -6,4 +6,7 @@ pub mod feed_list;
pub mod footer; pub mod footer;
pub mod header; pub mod header;
pub mod layout; pub mod layout;
pub mod login_form;
pub mod opml_import_form; pub mod opml_import_form;
pub mod signup_form;
pub mod user_name;

View File

@ -2,7 +2,7 @@ use maud::{html, Markup, PreEscaped};
pub fn opml_import_form() -> Markup { pub fn opml_import_form() -> Markup {
html! { html! {
form id="opml-import-form" hx-post="/import/opml" hx-encoding="multipart/form-data" class="feed-form" { form id="opml-import-form" hx-post="/import/opml" hx-swap="outerHTML" hx-encoding="multipart/form-data" class="feed-form" {
div class="form-grid" { div class="form-grid" {
label for="opml" { "OPML: " } label for="opml" { "OPML: " }
input type="file" id="opml" name="opml" required="true" accept="text/x-opml,application/xml,text/xml"; input type="file" id="opml" name="opml" required="true" accept="text/x-opml,application/xml,text/xml";

View File

@ -0,0 +1,47 @@
use maud::{html, Markup, PreEscaped};
#[derive(Debug, Default)]
pub struct SignupFormProps {
pub email: Option<String>,
pub name: Option<String>,
pub email_error: Option<String>,
pub name_error: Option<String>,
pub password_error: Option<String>,
pub general_error: Option<String>,
}
pub fn signup_form(props: SignupFormProps) -> Markup {
let SignupFormProps {
email,
name,
email_error,
name_error,
password_error,
general_error,
} = props;
html! {
form hx-post="/signup" hx-swap="outerHTML" class="auth-form-grid" {
label for="email" { "Email *" }
input type="email" name="email" id="email" placeholder="Email" value=(email.unwrap_or_default()) required;
@if let Some(email_error) = email_error {
span class="error" { (email_error) }
}
label for="name" { (PreEscaped("Name &nbsp;")) }
input type="text" name="name" id="name" value=(name.unwrap_or_default()) placeholder="Name" maxlength="255";
@if let Some(name_error) = name_error {
span class="error" { (name_error) }
}
label for="email" { "Password *" }
input type="password" name="password" id="password" placeholder="Password" minlength="8" maxlength="255" required;
@if let Some(password_error) = password_error {
span class="error" { (password_error) }
}
label for="password_confirmation" { "Confirm Password *" }
input type="password" name="password_confirmation" id="password_confirmation" placeholder="Confirm Password" required;
button type="submit" { "Submit" }
@if let Some(general_error) = general_error {
span class="error" { (general_error) }
}
}
}
}

10
src/partials/user_name.rs Normal file
View File

@ -0,0 +1,10 @@
use maud::{html, Markup};
use crate::models::user::User;
pub fn user_name(user: User) -> Markup {
let name = user.name.unwrap_or(user.email);
html! {
a href="/account" { (name) }
}
}

View File

@ -1,12 +1,11 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::{broadcast, watch, Mutex};
use axum::extract::FromRef; use axum::extract::FromRef;
use bytes::Bytes; use bytes::Bytes;
use reqwest::Client; use reqwest::Client;
use sqlx::PgPool; use sqlx::PgPool;
use tokio::sync::{broadcast, watch, Mutex};
use uuid::Uuid; use uuid::Uuid;
use crate::actors::importer::{ImporterHandle, ImporterHandleMessage}; use crate::actors::importer::{ImporterHandle, ImporterHandleMessage};