Add basic user auth
This commit is contained in:
50
src/main.rs
50
src/main.rs
@@ -1,19 +1,20 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net::SocketAddr,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
};
|
||||
use std::{collections::HashMap, net::SocketAddr, path::Path, sync::Arc};
|
||||
|
||||
use anyhow::Result;
|
||||
use axum::{
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
Extension, Router,
|
||||
};
|
||||
use axum_login::{
|
||||
axum_sessions::{async_session::MemoryStore, SessionLayer},
|
||||
AuthLayer, PostgresStore, RequireAuthorizationLayer,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use clap::Parser;
|
||||
use dotenvy::dotenv;
|
||||
use notify::Watcher;
|
||||
use rand::Rng;
|
||||
use reqwest::Client;
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
use tokio::sync::watch::channel;
|
||||
@@ -29,8 +30,10 @@ use lib::config::Config;
|
||||
use lib::domain_locks::DomainLocks;
|
||||
use lib::handlers;
|
||||
use lib::log::init_tracing;
|
||||
use lib::models::user::User;
|
||||
use lib::state::AppState;
|
||||
use lib::USER_AGENT;
|
||||
use uuid::Uuid;
|
||||
|
||||
async fn serve(app: Router, addr: SocketAddr) -> Result<()> {
|
||||
debug!("listening on {}", addr);
|
||||
@@ -40,6 +43,13 @@ async fn serve(app: Router, addr: SocketAddr) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn protected_handler(Extension(user): Extension<User>) -> impl IntoResponse {
|
||||
format!(
|
||||
"Logged in as: {}",
|
||||
user.name.unwrap_or_else(|| "No name".to_string())
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
dotenv().ok();
|
||||
@@ -54,11 +64,20 @@ async fn main() -> Result<()> {
|
||||
let domain_locks = DomainLocks::new();
|
||||
let client = Client::builder().user_agent(USER_AGENT).build()?;
|
||||
|
||||
let secret = rand::thread_rng().gen::<[u8; 64]>();
|
||||
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(config.database_max_connections)
|
||||
.connect(&config.database_url)
|
||||
.await?;
|
||||
|
||||
// TODO: store sessions in postgres eventually
|
||||
let session_store = MemoryStore::new();
|
||||
let session_layer = SessionLayer::new(session_store, &secret).with_secure(false);
|
||||
let user_store = PostgresStore::<User>::new(pool.clone())
|
||||
.with_query("select * from users where user_id = $1");
|
||||
let auth_layer = AuthLayer::new(user_store, &secret);
|
||||
|
||||
sqlx::migrate!().run(&pool).await?;
|
||||
|
||||
let crawl_scheduler = CrawlSchedulerHandle::new(
|
||||
@@ -69,14 +88,12 @@ async fn main() -> Result<()> {
|
||||
crawls.clone(),
|
||||
);
|
||||
let _ = crawl_scheduler.bootstrap().await;
|
||||
let importer = ImporterHandle::new(
|
||||
pool.clone(),
|
||||
crawl_scheduler.clone(),
|
||||
imports.clone(),
|
||||
);
|
||||
let importer = ImporterHandle::new(pool.clone(), crawl_scheduler.clone(), imports.clone());
|
||||
|
||||
let addr = format!("{}:{}", &config.host, &config.port).parse()?;
|
||||
let mut app = Router::new()
|
||||
.route("/protected", get(protected_handler))
|
||||
.route_layer(RequireAuthorizationLayer::<Uuid, User>::login())
|
||||
.route("/api/v1/feeds", get(handlers::api::feeds::get))
|
||||
.route("/api/v1/feed", post(handlers::api::feed::post))
|
||||
.route("/api/v1/feed/:id", get(handlers::api::feed::get))
|
||||
@@ -95,6 +112,11 @@ async fn main() -> Result<()> {
|
||||
.route("/log/stream", get(handlers::log::stream))
|
||||
.route("/import/opml", post(handlers::import::opml))
|
||||
.route("/import/:id/stream", get(handlers::import::stream))
|
||||
.route("/login", get(handlers::login::get))
|
||||
.route("/login", post(handlers::login::post))
|
||||
.route("/logout", get(handlers::logout::get))
|
||||
.route("/signup", get(handlers::signup::get))
|
||||
.route("/signup", post(handlers::signup::post))
|
||||
.nest_service("/static", ServeDir::new("static"))
|
||||
.with_state(AppState {
|
||||
pool,
|
||||
@@ -107,7 +129,9 @@ async fn main() -> Result<()> {
|
||||
importer,
|
||||
imports,
|
||||
})
|
||||
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()));
|
||||
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
|
||||
.layer(auth_layer)
|
||||
.layer(session_layer);
|
||||
|
||||
if cfg!(debug_assertions) {
|
||||
debug!("starting livereload");
|
||||
|
||||
Reference in New Issue
Block a user