From dfd7a9b6a83b383e4cbf3c497ecaaca89fddc5ad Mon Sep 17 00:00:00 2001 From: Wes Holland Date: Wed, 6 Nov 2024 15:15:41 -0600 Subject: [PATCH] Refactor and comments --- src/app/mod.rs | 27 ++++++++++++++++++++++ src/app/state.rs | 27 ++++++++++++++++++++++ src/app_state.rs | 27 ---------------------- src/auth.rs | 46 +++++++++++++++++++++++++++---------- src/db.rs | 2 +- src/error.rs | 15 +++++++----- src/main.rs | 54 ++++++++++++++++---------------------------- src/static_routes.rs | 3 ++- static/index.html | 21 ----------------- templates/index.html | 14 +++++++----- templates/main.html | 19 +++++++++++++++- 11 files changed, 145 insertions(+), 110 deletions(-) create mode 100644 src/app/mod.rs create mode 100644 src/app/state.rs delete mode 100644 static/index.html diff --git a/src/app/mod.rs b/src/app/mod.rs new file mode 100644 index 0000000..d0ef9cf --- /dev/null +++ b/src/app/mod.rs @@ -0,0 +1,27 @@ +use askama::Template; +use askama_axum::IntoResponse; +use axum::middleware::from_extractor; +use axum::Router; +use axum::routing::get; +use crate::app::state::AppState; +use crate::auth::User; + +pub mod state; + +pub fn routes() -> Router { + Router::new() + .route("/", get(index)) + .route("/index.html", get(index)) + // Ensure that all routes here require an authenticated user + // whether explicitly asked or not + .route_layer(from_extractor::()) +} + +#[derive(Template)] +#[template(path = "index.html")] +struct IndexTemplate; + +async fn index() -> impl IntoResponse { + IndexTemplate.into_response() +} + diff --git a/src/app/state.rs b/src/app/state.rs new file mode 100644 index 0000000..6f39d1d --- /dev/null +++ b/src/app/state.rs @@ -0,0 +1,27 @@ +use sqlx::SqlitePool; +use oauth2::basic::BasicClient; +use axum::extract::FromRef; + +// App state. Pretty basic stuff. Gets passed around by the server to the handlers and whatnot +// Use in a handler with the state enum: +// async fn handler(State(my_app_state): State) +#[derive(Clone)] +pub struct AppState { + pub db: SqlitePool, + pub oauth_client: BasicClient, +} + +// Axum extractors for app state. These allow the handler to just use +// pieces of the App state +// async fn handler(State(my_db): State) +impl FromRef for SqlitePool { + fn from_ref(input: &AppState) -> Self { + input.db.clone() + } +} + +impl FromRef for BasicClient { + fn from_ref(input: &AppState) -> Self { + input.oauth_client.clone() + } +} \ No newline at end of file diff --git a/src/app_state.rs b/src/app_state.rs index 427ada1..e69de29 100644 --- a/src/app_state.rs +++ b/src/app_state.rs @@ -1,27 +0,0 @@ -use axum::extract::FromRef; -use oauth2::basic::BasicClient; -use sqlx::SqlitePool; - -// App state. Pretty basic stuff. Gets passed around by the server to the handlers and whatnot -// Use in a handler with the state enum: -// async fn handler(State(my_app_state): State) -#[derive(Clone)] -pub struct AppState { - pub db: SqlitePool, - pub oauth_client: BasicClient, -} - -// Axum extractors for app state. These allow the handler to just use -// pieces of the App state -// async fn handler(State(my_db): State) -impl FromRef for SqlitePool { - fn from_ref(input: &AppState) -> Self { - input.db.clone() - } -} - -impl FromRef for BasicClient { - fn from_ref(input: &AppState) -> Self { - input.oauth_client.clone() - } -} diff --git a/src/auth.rs b/src/auth.rs index 5d7aaed..a9ef663 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,8 +1,7 @@ use anyhow::{anyhow, Context}; use askama::Template; -use axum::{async_trait, http, Router}; -use axum::extract::{FromRef, FromRequestParts, Query, State}; -use axum::http::{header, StatusCode}; +use axum::{async_trait, Router}; +use axum::extract::{FromRequestParts, State}; use axum::http::request::Parts; use axum::response::{IntoResponse, Redirect, Response}; use axum::routing::get; @@ -10,13 +9,16 @@ use oauth2::basic::BasicClient; use oauth2::reqwest::async_http_client; use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl}; use serde::{Deserialize, Serialize}; -use sqlx::SqlitePool; use tower_sessions::Session; use crate::error::{AppError, AppForbiddenResponse}; -use crate::{auth, CSRF_TOKEN, USER_SESSION}; use crate::error::QueryExtractor; -use crate::app_state::AppState; +use crate::app::state::AppState; + +// This module is all the stuff related to authentication and authorization + +const CSRF_TOKEN: &str = "csrf_token"; +const USER_SESSION: &str = "user"; pub fn routes() -> Router { Router::new() @@ -25,6 +27,7 @@ pub fn routes() -> Router { .route("/auth/authorized", get(auth_authorized)) } +/// Using the OAUTH2 library for communication to the OAUTH Provider (google) pub fn init_client() -> anyhow::Result { use std::env::var; @@ -53,27 +56,34 @@ pub fn init_client() -> anyhow::Result { Ok(client) } +/// Handler for when the user logs in pub async fn auth_login( session: Session, user: Option, State(oauth_client): State, ) -> anyhow::Result { + + // Make sure we don't already have a session if user.is_some() { return Ok(Redirect::to("/")); } - + + // STEP 1 - Get the OAUTH Redirect Info with a random state token let (auth_url, csrf_token) = oauth_client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("profile".to_string())) .add_scope(Scope::new("email".to_string())) .url(); + // STEP 2 - Save the CSRF token to the session session.insert(CSRF_TOKEN, csrf_token).await?; + // STEP 3 - Redirect to oauth provider with state Ok(Redirect::to(auth_url.as_ref())) } +/// Handler for when the user is redirected back to us from the OAUTH site pub async fn auth_authorized( session: Session, QueryExtractor(query_auth): QueryExtractor, @@ -82,7 +92,8 @@ pub async fn auth_authorized( let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL") .context("OAUTH_USER_INFO_URL not set")?; - // Validate that we have a stored csrf token that matches the one passed back to us + // STEP 4 - Get the saved CSRF token from the state and ensure it matches + // what was returned in the query string of the redirect let stored_csrf_token = session.remove::(CSRF_TOKEN) .await .context("unable to access csrf token")? @@ -92,13 +103,14 @@ pub async fn auth_authorized( return Err(anyhow!("session csrf mismatch").into()) } + // STEP 5 - Exchange the Authorization Code for an Access Token let token = oauth_client .exchange_code(AuthorizationCode::new(query_auth.code.clone())) .request_async(async_http_client) .await .context("failed in sending request request to authorization server")?; - // Fetch user data + // STEP 6 - Use the Access Token to pull user data like name, email, etc let client = reqwest::Client::new(); let user_data = client .get(user_info_endpoint) @@ -112,6 +124,7 @@ pub async fn auth_authorized( .await .context("failed to deserialize response as JSON")?; + // STEP 7 - Authorize the user at the application level //TODO Check against database instead of string let valid_users = std::env::var("AUTHORIZED_USERS") .context("Authorized users not set")?; @@ -125,8 +138,10 @@ pub async fn auth_authorized( return Ok(AppForbiddenResponse::new(&user_data.email, "application").into_response()) } + // STEP 8 - Save user session data session.insert(USER_SESSION, user_data).await?; - + + // STEP 9 - Redirect back to the rest of the application Ok(Redirect::to("/").into_response()) } @@ -134,11 +149,13 @@ pub async fn auth_authorized( #[template(path = "logged-out.html")] struct LoggedOutTemplate; +/// Handler for user log-out pub async fn auth_logout( session: Session, user: Option, ) -> anyhow::Result { - + + // Logging out is as simple as clearing the user session if user.is_some() { session.remove::(USER_SESSION).await?; } @@ -146,12 +163,14 @@ pub async fn auth_logout( Ok(LoggedOutTemplate.into_response()) } +/// Query string response for "authorized" endpoint #[derive(Debug, Deserialize)] pub struct AuthRequest { pub code: String, pub state: String, } +/// User information that will be return from the OAUTH authority #[derive(Debug, Serialize, Deserialize)] pub struct User { pub id: String, @@ -161,6 +180,7 @@ pub struct User { pub picture: String, } +/// A custom error for the User extractor pub enum UserExtractError { InternalServerError(anyhow::Error), Unauthorized, @@ -175,10 +195,12 @@ impl IntoResponse for UserExtractError { } } +/// The user extractor is used to pull out the user data from the session. This can be used +/// as a guard to ensure that a user session exists. Basically an authentication +/// (but not authorization) guard #[async_trait] impl FromRequestParts for User where - SqlitePool: FromRef, S: Send + Sync, { type Rejection = UserExtractError; diff --git a/src/db.rs b/src/db.rs index db2f179..2b6ed48 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,5 +1,5 @@ use sqlx::SqlitePool; -use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; +use sqlx::sqlite::SqliteConnectOptions; pub async fn init(filename: &str) -> anyhow::Result { let options = SqliteConnectOptions::new() diff --git a/src/error.rs b/src/error.rs index c6c6477..f87f0d3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -5,9 +5,11 @@ use axum::response::{IntoResponse, Response}; use axum::Router; use axum::routing::get; use axum::extract::FromRequestParts; -use crate::app_state::AppState; +use crate::app::state::AppState; use crate::auth::User; +// This module is all the stuff related to handling error responses + /// These are just test routes. They shouldn't really be called directly /// as they just return an error. But they are nice for testing pub fn routes() -> Router { @@ -28,11 +30,6 @@ pub fn routes() -> Router { router } -/// Handler that always responds with 404 Not Found -pub async fn not_found() -> impl IntoResponse { - AppNotFoundResponse.into_response() -} - /// Application error that is a thin layer over anyhow::Error but has the distinction of having /// the piping for converting from the error to an axum response pub struct AppError(anyhow::Error); @@ -128,3 +125,9 @@ fn always_fails() -> anyhow::Result<()> { async fn forbidden(user: User) -> impl IntoResponse { AppForbiddenResponse::new(&user.email, "test endpoint") } + +/// Handler that always responds with 404 Not Found +pub async fn not_found() -> impl IntoResponse { + AppNotFoundResponse.into_response() +} + diff --git a/src/main.rs b/src/main.rs index 21a014e..fc4c7d8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,20 +1,11 @@ -use crate::app_state::AppState; -use crate::auth::User; -use crate::error::{AppError, AppForbiddenResponse}; -use anyhow::{anyhow, Context, Result}; -use askama_axum::Template; -use axum::extract::{FromRef, Query, State}; -use axum::http::header::SET_COOKIE; -use axum::http::HeaderMap; -use axum::response::{IntoResponse, Redirect}; -use axum::{extract::Request, handler::HandlerWithoutStateExt, http, http::StatusCode, routing::get, Router}; +use app::state::AppState; +use anyhow::{Context, Result}; +use axum::Router; use tokio::signal; use tokio::task::{AbortHandle, JoinHandle}; use tower_http::{ - services::{ServeDir, ServeFile}, trace::TraceLayer, }; -use tower_sessions::Session; use tracing::info; use tracing_subscriber::EnvFilter; @@ -24,15 +15,13 @@ mod error; mod session; mod auth; mod static_routes; +mod app; //NOTE TO FUTURE ME: I'm leaving a bunch of notes about these things as part of the learning // process. There is a lot of implementation details that are obscured by all these pieces, and it // can be hard to tell how heavy a line is. Lots of comment in this file and some of the kind of // "first breadcrumb" type files so you have a place to pick up the threads. Hopefully you remember // this stuff, but in case you don't here's some notes -const CSRF_TOKEN: &str = "csrf_token"; -const USER_SESSION: &str = "user"; - #[tokio::main] async fn main() -> Result<()>{ @@ -69,15 +58,18 @@ async fn main() -> Result<()>{ // Session let (session_layer, session_task) = session::init().await?; - - let auth_routes = auth::routes(); - let static_routes = static_routes::routes(); - let error_routes: Router = error::routes(); + // Long-running tasks + let mut tasks = vec![]; + tasks.push(session_task); - let app_routes: Router = Router::new() - .route("/", get(index)); + // Assemble all the routes to the various handlers + let auth_routes = auth::routes(); + let static_routes = static_routes::routes(); + let error_routes = error::routes(); + let app_routes = app::routes(); + // Top level router let router = Router::new() .merge(auth_routes) .merge(error_routes) @@ -88,6 +80,7 @@ async fn main() -> Result<()>{ .fallback(error::not_found) .with_state(app_state); + // Serve let address = "0.0.0.0:4206"; let listener = tokio::net::TcpListener::bind(address) .await @@ -95,9 +88,6 @@ async fn main() -> Result<()>{ info!("listening on {}", address); - let mut tasks = vec![]; - tasks.push(session_task); - axum::serve(listener, router.into_make_service()) .with_graceful_shutdown(shutdown_signal(tasks)) .await.context("unable to serve")?; @@ -106,17 +96,11 @@ async fn main() -> Result<()>{ } -#[derive(Template)] -#[template(path = "index.html")] -struct IndexTemplate<'a> { - name: &'a str, -} - -async fn index(user: User) -> impl IntoResponse { - IndexTemplate { name: user.name.as_str() }.into_response() -} - - +/// This is needed to handle the shutdown of any long-running tasks +/// such as the one that clears expired sessions. This just +/// functions by listening for the termination signal--either +/// ctrl-c or SIGTERM--triggering the abort handle for each +/// task and then joining (awaiting) each handle async fn shutdown_signal(tasks: Vec>>) { let abort_handles: Vec = tasks.iter().map(|h| h.abort_handle()).collect(); diff --git a/src/static_routes.rs b/src/static_routes.rs index 3d0423d..f1d49fb 100644 --- a/src/static_routes.rs +++ b/src/static_routes.rs @@ -1,10 +1,11 @@ use axum::Router; use tower_http::services::ServeFile; -use crate::app_state::AppState; +use crate::app::state::AppState; pub fn routes() -> Router { Router::new() .nest_service("/css/pico.min.css", ServeFile::new("static/css/pico.min.css")) + .nest_service("/css/custom.css", ServeFile::new("static/css/custom.css")) .nest_service("/js/htmx.min.js", ServeFile::new("static/js/htmx.min.js")) .nest_service("/favicon.ico", ServeFile::new("static/favicon.ico")) } diff --git a/static/index.html b/static/index.html deleted file mode 100644 index 0efbd1d..0000000 --- a/static/index.html +++ /dev/null @@ -1,21 +0,0 @@ - - - - - - - - - Test Page - - -
-

Test Page

-

This is a test page

-
-

Card

-
- Logout -
- - \ No newline at end of file diff --git a/templates/index.html b/templates/index.html index 5d0198a..b6ca88b 100644 --- a/templates/index.html +++ b/templates/index.html @@ -1,12 +1,14 @@ {% extends "main.html" %} +{% block title %} Inventory App {% endblock %} + {% block content %} -

Hello {{ name }}

-

This is a test page

-
-

Card

-
-Logout +

+ +

{% endblock %} \ No newline at end of file diff --git a/templates/main.html b/templates/main.html index 1aa3f05..448b59b 100644 --- a/templates/main.html +++ b/templates/main.html @@ -5,12 +5,29 @@ + - Test Page + {% block title %}Title{% endblock %} +
+ +
{% block content %}

Content Missing

{% endblock %}
+ \ No newline at end of file