Refactor and comments

demo-mode
Wes Holland 1 year ago
parent db765c18be
commit dfd7a9b6a8

@ -0,0 +1,27 @@
use askama::Template;
use askama_axum::IntoResponse;
use axum::middleware::from_extractor;
use axum::Router;
use axum::routing::get;
use crate::app::state::AppState;
use crate::auth::User;
pub mod state;
pub fn routes() -> Router<AppState> {
Router::new()
.route("/", get(index))
.route("/index.html", get(index))
// Ensure that all routes here require an authenticated user
// whether explicitly asked or not
.route_layer(from_extractor::<User>())
}
#[derive(Template)]
#[template(path = "index.html")]
struct IndexTemplate;
async fn index() -> impl IntoResponse {
IndexTemplate.into_response()
}

@ -0,0 +1,27 @@
use sqlx::SqlitePool;
use oauth2::basic::BasicClient;
use axum::extract::FromRef;
// App state. Pretty basic stuff. Gets passed around by the server to the handlers and whatnot
// Use in a handler with the state enum:
// async fn handler(State(my_app_state): State<AppState>)
#[derive(Clone)]
pub struct AppState {
pub db: SqlitePool,
pub oauth_client: BasicClient,
}
// Axum extractors for app state. These allow the handler to just use
// pieces of the App state
// async fn handler(State(my_db): State<SqlitePool>)
impl FromRef<AppState> for SqlitePool {
fn from_ref(input: &AppState) -> Self {
input.db.clone()
}
}
impl FromRef<AppState> for BasicClient {
fn from_ref(input: &AppState) -> Self {
input.oauth_client.clone()
}
}

@ -1,27 +0,0 @@
use axum::extract::FromRef;
use oauth2::basic::BasicClient;
use sqlx::SqlitePool;
// App state. Pretty basic stuff. Gets passed around by the server to the handlers and whatnot
// Use in a handler with the state enum:
// async fn handler(State(my_app_state): State<AppState>)
#[derive(Clone)]
pub struct AppState {
pub db: SqlitePool,
pub oauth_client: BasicClient,
}
// Axum extractors for app state. These allow the handler to just use
// pieces of the App state
// async fn handler(State(my_db): State<SqlitePool>)
impl FromRef<AppState> for SqlitePool {
fn from_ref(input: &AppState) -> Self {
input.db.clone()
}
}
impl FromRef<AppState> for BasicClient {
fn from_ref(input: &AppState) -> Self {
input.oauth_client.clone()
}
}

@ -1,8 +1,7 @@
use anyhow::{anyhow, Context}; use anyhow::{anyhow, Context};
use askama::Template; use askama::Template;
use axum::{async_trait, http, Router}; use axum::{async_trait, Router};
use axum::extract::{FromRef, FromRequestParts, Query, State}; use axum::extract::{FromRequestParts, State};
use axum::http::{header, StatusCode};
use axum::http::request::Parts; use axum::http::request::Parts;
use axum::response::{IntoResponse, Redirect, Response}; use axum::response::{IntoResponse, Redirect, Response};
use axum::routing::get; use axum::routing::get;
@ -10,13 +9,16 @@ use oauth2::basic::BasicClient;
use oauth2::reqwest::async_http_client; use oauth2::reqwest::async_http_client;
use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl}; use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::SqlitePool;
use tower_sessions::Session; use tower_sessions::Session;
use crate::error::{AppError, AppForbiddenResponse}; use crate::error::{AppError, AppForbiddenResponse};
use crate::{auth, CSRF_TOKEN, USER_SESSION};
use crate::error::QueryExtractor; use crate::error::QueryExtractor;
use crate::app_state::AppState; use crate::app::state::AppState;
// This module is all the stuff related to authentication and authorization
const CSRF_TOKEN: &str = "csrf_token";
const USER_SESSION: &str = "user";
pub fn routes() -> Router<AppState> { pub fn routes() -> Router<AppState> {
Router::new() Router::new()
@ -25,6 +27,7 @@ pub fn routes() -> Router<AppState> {
.route("/auth/authorized", get(auth_authorized)) .route("/auth/authorized", get(auth_authorized))
} }
/// Using the OAUTH2 library for communication to the OAUTH Provider (google)
pub fn init_client() -> anyhow::Result<BasicClient> { pub fn init_client() -> anyhow::Result<BasicClient> {
use std::env::var; use std::env::var;
@ -53,27 +56,34 @@ pub fn init_client() -> anyhow::Result<BasicClient> {
Ok(client) Ok(client)
} }
/// Handler for when the user logs in
pub async fn auth_login( pub async fn auth_login(
session: Session, session: Session,
user: Option<User>, user: Option<User>,
State(oauth_client): State<BasicClient>, State(oauth_client): State<BasicClient>,
) -> anyhow::Result<impl IntoResponse, AppError> { ) -> anyhow::Result<impl IntoResponse, AppError> {
// Make sure we don't already have a session
if user.is_some() { if user.is_some() {
return Ok(Redirect::to("/")); return Ok(Redirect::to("/"));
} }
// STEP 1 - Get the OAUTH Redirect Info with a random state token
let (auth_url, csrf_token) = oauth_client let (auth_url, csrf_token) = oauth_client
.authorize_url(CsrfToken::new_random) .authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("profile".to_string())) .add_scope(Scope::new("profile".to_string()))
.add_scope(Scope::new("email".to_string())) .add_scope(Scope::new("email".to_string()))
.url(); .url();
// STEP 2 - Save the CSRF token to the session
session.insert(CSRF_TOKEN, csrf_token).await?; session.insert(CSRF_TOKEN, csrf_token).await?;
// STEP 3 - Redirect to oauth provider with state
Ok(Redirect::to(auth_url.as_ref())) Ok(Redirect::to(auth_url.as_ref()))
} }
/// Handler for when the user is redirected back to us from the OAUTH site
pub async fn auth_authorized( pub async fn auth_authorized(
session: Session, session: Session,
QueryExtractor(query_auth): QueryExtractor<AuthRequest>, QueryExtractor(query_auth): QueryExtractor<AuthRequest>,
@ -82,7 +92,8 @@ pub async fn auth_authorized(
let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL") let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL")
.context("OAUTH_USER_INFO_URL not set")?; .context("OAUTH_USER_INFO_URL not set")?;
// Validate that we have a stored csrf token that matches the one passed back to us // STEP 4 - Get the saved CSRF token from the state and ensure it matches
// what was returned in the query string of the redirect
let stored_csrf_token = session.remove::<CsrfToken>(CSRF_TOKEN) let stored_csrf_token = session.remove::<CsrfToken>(CSRF_TOKEN)
.await .await
.context("unable to access csrf token")? .context("unable to access csrf token")?
@ -92,13 +103,14 @@ pub async fn auth_authorized(
return Err(anyhow!("session csrf mismatch").into()) return Err(anyhow!("session csrf mismatch").into())
} }
// STEP 5 - Exchange the Authorization Code for an Access Token
let token = oauth_client let token = oauth_client
.exchange_code(AuthorizationCode::new(query_auth.code.clone())) .exchange_code(AuthorizationCode::new(query_auth.code.clone()))
.request_async(async_http_client) .request_async(async_http_client)
.await .await
.context("failed in sending request request to authorization server")?; .context("failed in sending request request to authorization server")?;
// Fetch user data // STEP 6 - Use the Access Token to pull user data like name, email, etc
let client = reqwest::Client::new(); let client = reqwest::Client::new();
let user_data = client let user_data = client
.get(user_info_endpoint) .get(user_info_endpoint)
@ -112,6 +124,7 @@ pub async fn auth_authorized(
.await .await
.context("failed to deserialize response as JSON")?; .context("failed to deserialize response as JSON")?;
// STEP 7 - Authorize the user at the application level
//TODO Check against database instead of string //TODO Check against database instead of string
let valid_users = std::env::var("AUTHORIZED_USERS") let valid_users = std::env::var("AUTHORIZED_USERS")
.context("Authorized users not set")?; .context("Authorized users not set")?;
@ -125,8 +138,10 @@ pub async fn auth_authorized(
return Ok(AppForbiddenResponse::new(&user_data.email, "application").into_response()) return Ok(AppForbiddenResponse::new(&user_data.email, "application").into_response())
} }
// STEP 8 - Save user session data
session.insert(USER_SESSION, user_data).await?; session.insert(USER_SESSION, user_data).await?;
// STEP 9 - Redirect back to the rest of the application
Ok(Redirect::to("/").into_response()) Ok(Redirect::to("/").into_response())
} }
@ -134,11 +149,13 @@ pub async fn auth_authorized(
#[template(path = "logged-out.html")] #[template(path = "logged-out.html")]
struct LoggedOutTemplate; struct LoggedOutTemplate;
/// Handler for user log-out
pub async fn auth_logout( pub async fn auth_logout(
session: Session, session: Session,
user: Option<User>, user: Option<User>,
) -> anyhow::Result<impl IntoResponse, AppError> { ) -> anyhow::Result<impl IntoResponse, AppError> {
// Logging out is as simple as clearing the user session
if user.is_some() { if user.is_some() {
session.remove::<User>(USER_SESSION).await?; session.remove::<User>(USER_SESSION).await?;
} }
@ -146,12 +163,14 @@ pub async fn auth_logout(
Ok(LoggedOutTemplate.into_response()) Ok(LoggedOutTemplate.into_response())
} }
/// Query string response for "authorized" endpoint
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct AuthRequest { pub struct AuthRequest {
pub code: String, pub code: String,
pub state: String, pub state: String,
} }
/// User information that will be return from the OAUTH authority
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct User { pub struct User {
pub id: String, pub id: String,
@ -161,6 +180,7 @@ pub struct User {
pub picture: String, pub picture: String,
} }
/// A custom error for the User extractor
pub enum UserExtractError { pub enum UserExtractError {
InternalServerError(anyhow::Error), InternalServerError(anyhow::Error),
Unauthorized, Unauthorized,
@ -175,10 +195,12 @@ impl IntoResponse for UserExtractError {
} }
} }
/// The user extractor is used to pull out the user data from the session. This can be used
/// as a guard to ensure that a user session exists. Basically an authentication
/// (but not authorization) guard
#[async_trait] #[async_trait]
impl<S> FromRequestParts<S> for User impl<S> FromRequestParts<S> for User
where where
SqlitePool: FromRef<S>,
S: Send + Sync, S: Send + Sync,
{ {
type Rejection = UserExtractError; type Rejection = UserExtractError;

@ -1,5 +1,5 @@
use sqlx::SqlitePool; use sqlx::SqlitePool;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; use sqlx::sqlite::SqliteConnectOptions;
pub async fn init(filename: &str) -> anyhow::Result<SqlitePool> { pub async fn init(filename: &str) -> anyhow::Result<SqlitePool> {
let options = SqliteConnectOptions::new() let options = SqliteConnectOptions::new()

@ -5,9 +5,11 @@ use axum::response::{IntoResponse, Response};
use axum::Router; use axum::Router;
use axum::routing::get; use axum::routing::get;
use axum::extract::FromRequestParts; use axum::extract::FromRequestParts;
use crate::app_state::AppState; use crate::app::state::AppState;
use crate::auth::User; use crate::auth::User;
// This module is all the stuff related to handling error responses
/// These are just test routes. They shouldn't really be called directly /// These are just test routes. They shouldn't really be called directly
/// as they just return an error. But they are nice for testing /// as they just return an error. But they are nice for testing
pub fn routes() -> Router<AppState> { pub fn routes() -> Router<AppState> {
@ -28,11 +30,6 @@ pub fn routes() -> Router<AppState> {
router router
} }
/// Handler that always responds with 404 Not Found
pub async fn not_found() -> impl IntoResponse {
AppNotFoundResponse.into_response()
}
/// Application error that is a thin layer over anyhow::Error but has the distinction of having /// Application error that is a thin layer over anyhow::Error but has the distinction of having
/// the piping for converting from the error to an axum response /// the piping for converting from the error to an axum response
pub struct AppError(anyhow::Error); pub struct AppError(anyhow::Error);
@ -128,3 +125,9 @@ fn always_fails() -> anyhow::Result<()> {
async fn forbidden(user: User) -> impl IntoResponse { async fn forbidden(user: User) -> impl IntoResponse {
AppForbiddenResponse::new(&user.email, "test endpoint") AppForbiddenResponse::new(&user.email, "test endpoint")
} }
/// Handler that always responds with 404 Not Found
pub async fn not_found() -> impl IntoResponse {
AppNotFoundResponse.into_response()
}

@ -1,20 +1,11 @@
use crate::app_state::AppState; use app::state::AppState;
use crate::auth::User; use anyhow::{Context, Result};
use crate::error::{AppError, AppForbiddenResponse}; use axum::Router;
use anyhow::{anyhow, Context, Result};
use askama_axum::Template;
use axum::extract::{FromRef, Query, State};
use axum::http::header::SET_COOKIE;
use axum::http::HeaderMap;
use axum::response::{IntoResponse, Redirect};
use axum::{extract::Request, handler::HandlerWithoutStateExt, http, http::StatusCode, routing::get, Router};
use tokio::signal; use tokio::signal;
use tokio::task::{AbortHandle, JoinHandle}; use tokio::task::{AbortHandle, JoinHandle};
use tower_http::{ use tower_http::{
services::{ServeDir, ServeFile},
trace::TraceLayer, trace::TraceLayer,
}; };
use tower_sessions::Session;
use tracing::info; use tracing::info;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
@ -24,15 +15,13 @@ mod error;
mod session; mod session;
mod auth; mod auth;
mod static_routes; mod static_routes;
mod app;
//NOTE TO FUTURE ME: I'm leaving a bunch of notes about these things as part of the learning //NOTE TO FUTURE ME: I'm leaving a bunch of notes about these things as part of the learning
// process. There is a lot of implementation details that are obscured by all these pieces, and it // process. There is a lot of implementation details that are obscured by all these pieces, and it
// can be hard to tell how heavy a line is. Lots of comment in this file and some of the kind of // can be hard to tell how heavy a line is. Lots of comment in this file and some of the kind of
// "first breadcrumb" type files so you have a place to pick up the threads. Hopefully you remember // "first breadcrumb" type files so you have a place to pick up the threads. Hopefully you remember
// this stuff, but in case you don't here's some notes // this stuff, but in case you don't here's some notes
const CSRF_TOKEN: &str = "csrf_token";
const USER_SESSION: &str = "user";
#[tokio::main] #[tokio::main]
async fn main() -> Result<()>{ async fn main() -> Result<()>{
@ -69,15 +58,18 @@ async fn main() -> Result<()>{
// Session // Session
let (session_layer, session_task) = session::init().await?; let (session_layer, session_task) = session::init().await?;
let auth_routes = auth::routes();
let static_routes = static_routes::routes();
let error_routes: Router<AppState> = error::routes(); // Long-running tasks
let mut tasks = vec![];
tasks.push(session_task);
let app_routes: Router<AppState> = Router::new() // Assemble all the routes to the various handlers
.route("/", get(index)); let auth_routes = auth::routes();
let static_routes = static_routes::routes();
let error_routes = error::routes();
let app_routes = app::routes();
// Top level router
let router = Router::new() let router = Router::new()
.merge(auth_routes) .merge(auth_routes)
.merge(error_routes) .merge(error_routes)
@ -88,6 +80,7 @@ async fn main() -> Result<()>{
.fallback(error::not_found) .fallback(error::not_found)
.with_state(app_state); .with_state(app_state);
// Serve
let address = "0.0.0.0:4206"; let address = "0.0.0.0:4206";
let listener = tokio::net::TcpListener::bind(address) let listener = tokio::net::TcpListener::bind(address)
.await .await
@ -95,9 +88,6 @@ async fn main() -> Result<()>{
info!("listening on {}", address); info!("listening on {}", address);
let mut tasks = vec![];
tasks.push(session_task);
axum::serve(listener, router.into_make_service()) axum::serve(listener, router.into_make_service())
.with_graceful_shutdown(shutdown_signal(tasks)) .with_graceful_shutdown(shutdown_signal(tasks))
.await.context("unable to serve")?; .await.context("unable to serve")?;
@ -106,17 +96,11 @@ async fn main() -> Result<()>{
} }
#[derive(Template)] /// This is needed to handle the shutdown of any long-running tasks
#[template(path = "index.html")] /// such as the one that clears expired sessions. This just
struct IndexTemplate<'a> { /// functions by listening for the termination signal--either
name: &'a str, /// ctrl-c or SIGTERM--triggering the abort handle for each
} /// task and then joining (awaiting) each handle
async fn index(user: User) -> impl IntoResponse {
IndexTemplate { name: user.name.as_str() }.into_response()
}
async fn shutdown_signal(tasks: Vec<JoinHandle<Result<()>>>) { async fn shutdown_signal(tasks: Vec<JoinHandle<Result<()>>>) {
let abort_handles: Vec<AbortHandle> = tasks.iter().map(|h| h.abort_handle()).collect(); let abort_handles: Vec<AbortHandle> = tasks.iter().map(|h| h.abort_handle()).collect();

@ -1,10 +1,11 @@
use axum::Router; use axum::Router;
use tower_http::services::ServeFile; use tower_http::services::ServeFile;
use crate::app_state::AppState; use crate::app::state::AppState;
pub fn routes() -> Router<AppState> { pub fn routes() -> Router<AppState> {
Router::new() Router::new()
.nest_service("/css/pico.min.css", ServeFile::new("static/css/pico.min.css")) .nest_service("/css/pico.min.css", ServeFile::new("static/css/pico.min.css"))
.nest_service("/css/custom.css", ServeFile::new("static/css/custom.css"))
.nest_service("/js/htmx.min.js", ServeFile::new("static/js/htmx.min.js")) .nest_service("/js/htmx.min.js", ServeFile::new("static/js/htmx.min.js"))
.nest_service("/favicon.ico", ServeFile::new("static/favicon.ico")) .nest_service("/favicon.ico", ServeFile::new("static/favicon.ico"))
} }

@ -1,21 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<meta name="color-scheme" content="light dark">
<link rel="stylesheet" href="css/pico.min.css">
<script src="js/htmx.min.js"></script>
<title>Test Page</title>
</head>
<body>
<main class="container">
<h1>Test Page</h1>
<p>This is a test page</p>
<article>
<h2>Card</h2>
</article>
<a href="/auth/logout">Logout</a>
</main>
</body>
</html>

@ -1,12 +1,14 @@
{% extends "main.html" %} {% extends "main.html" %}
{% block title %} Inventory App {% endblock %}
{% block content %} {% block content %}
<h1>Hello {{ name }}</h1> <p>
<p>This is a test page</p> <input name="search"
<article> placeholder="Search"
<h2>Card</h2> aria-label="Search"
</article> type="search">
<a href="/auth/logout">Logout</a> </p>
{% endblock %} {% endblock %}

@ -5,12 +5,29 @@
<meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="viewport" content="width=device-width, initial-scale=1">
<meta name="color-scheme" content="light dark"> <meta name="color-scheme" content="light dark">
<link rel="stylesheet" href="/css/pico.min.css"> <link rel="stylesheet" href="/css/pico.min.css">
<link rel="stylesheet" href="/css/custom.css">
<script src="/js/htmx.min.js"></script> <script src="/js/htmx.min.js"></script>
<title>Test Page</title> <title>{% block title %}Title{% endblock %}</title>
</head> </head>
<body> <body>
<header class="container">
<nav class="container">
<ul>
<li><h1>Inventory App</h1></li>
</ul>
<ul>
<li><a class="secondary" href="#">Overview</a></li>
<li><a class="secondary" href="#">Receiving</a></li>
<li><a class="secondary" href="#">Reports</a></li>
<li><a class="secondary" href="#">Adjustments</a></li>
<li><a class="contrast" href="/auth/logout">Logout</a></li>
</ul>
</nav>
</header>
<main class="container"> <main class="container">
{% block content %}<p>Content Missing</p>{% endblock %} {% block content %}<p>Content Missing</p>{% endblock %}
</main> </main>
<footer class="container-fluid">
</footer>
</body> </body>
</html> </html>
Loading…
Cancel
Save

Powered by TurnKey Linux.