Refactor and comments

demo-mode
Wes Holland 1 year ago
parent db765c18be
commit dfd7a9b6a8

@ -0,0 +1,27 @@
use askama::Template;
use askama_axum::IntoResponse;
use axum::middleware::from_extractor;
use axum::Router;
use axum::routing::get;
use crate::app::state::AppState;
use crate::auth::User;
pub mod state;
pub fn routes() -> Router<AppState> {
Router::new()
.route("/", get(index))
.route("/index.html", get(index))
// Ensure that all routes here require an authenticated user
// whether explicitly asked or not
.route_layer(from_extractor::<User>())
}
#[derive(Template)]
#[template(path = "index.html")]
struct IndexTemplate;
async fn index() -> impl IntoResponse {
IndexTemplate.into_response()
}

@ -0,0 +1,27 @@
use sqlx::SqlitePool;
use oauth2::basic::BasicClient;
use axum::extract::FromRef;
// App state. Pretty basic stuff. Gets passed around by the server to the handlers and whatnot
// Use in a handler with the state enum:
// async fn handler(State(my_app_state): State<AppState>)
#[derive(Clone)]
pub struct AppState {
pub db: SqlitePool,
pub oauth_client: BasicClient,
}
// Axum extractors for app state. These allow the handler to just use
// pieces of the App state
// async fn handler(State(my_db): State<SqlitePool>)
impl FromRef<AppState> for SqlitePool {
fn from_ref(input: &AppState) -> Self {
input.db.clone()
}
}
impl FromRef<AppState> for BasicClient {
fn from_ref(input: &AppState) -> Self {
input.oauth_client.clone()
}
}

@ -1,27 +0,0 @@
use axum::extract::FromRef;
use oauth2::basic::BasicClient;
use sqlx::SqlitePool;
// App state. Pretty basic stuff. Gets passed around by the server to the handlers and whatnot
// Use in a handler with the state enum:
// async fn handler(State(my_app_state): State<AppState>)
#[derive(Clone)]
pub struct AppState {
pub db: SqlitePool,
pub oauth_client: BasicClient,
}
// Axum extractors for app state. These allow the handler to just use
// pieces of the App state
// async fn handler(State(my_db): State<SqlitePool>)
impl FromRef<AppState> for SqlitePool {
fn from_ref(input: &AppState) -> Self {
input.db.clone()
}
}
impl FromRef<AppState> for BasicClient {
fn from_ref(input: &AppState) -> Self {
input.oauth_client.clone()
}
}

@ -1,8 +1,7 @@
use anyhow::{anyhow, Context};
use askama::Template;
use axum::{async_trait, http, Router};
use axum::extract::{FromRef, FromRequestParts, Query, State};
use axum::http::{header, StatusCode};
use axum::{async_trait, Router};
use axum::extract::{FromRequestParts, State};
use axum::http::request::Parts;
use axum::response::{IntoResponse, Redirect, Response};
use axum::routing::get;
@ -10,13 +9,16 @@ use oauth2::basic::BasicClient;
use oauth2::reqwest::async_http_client;
use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl};
use serde::{Deserialize, Serialize};
use sqlx::SqlitePool;
use tower_sessions::Session;
use crate::error::{AppError, AppForbiddenResponse};
use crate::{auth, CSRF_TOKEN, USER_SESSION};
use crate::error::QueryExtractor;
use crate::app_state::AppState;
use crate::app::state::AppState;
// This module is all the stuff related to authentication and authorization
const CSRF_TOKEN: &str = "csrf_token";
const USER_SESSION: &str = "user";
pub fn routes() -> Router<AppState> {
Router::new()
@ -25,6 +27,7 @@ pub fn routes() -> Router<AppState> {
.route("/auth/authorized", get(auth_authorized))
}
/// Using the OAUTH2 library for communication to the OAUTH Provider (google)
pub fn init_client() -> anyhow::Result<BasicClient> {
use std::env::var;
@ -53,27 +56,34 @@ pub fn init_client() -> anyhow::Result<BasicClient> {
Ok(client)
}
/// Handler for when the user logs in
pub async fn auth_login(
session: Session,
user: Option<User>,
State(oauth_client): State<BasicClient>,
) -> anyhow::Result<impl IntoResponse, AppError> {
// Make sure we don't already have a session
if user.is_some() {
return Ok(Redirect::to("/"));
}
// STEP 1 - Get the OAUTH Redirect Info with a random state token
let (auth_url, csrf_token) = oauth_client
.authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("profile".to_string()))
.add_scope(Scope::new("email".to_string()))
.url();
// STEP 2 - Save the CSRF token to the session
session.insert(CSRF_TOKEN, csrf_token).await?;
// STEP 3 - Redirect to oauth provider with state
Ok(Redirect::to(auth_url.as_ref()))
}
/// Handler for when the user is redirected back to us from the OAUTH site
pub async fn auth_authorized(
session: Session,
QueryExtractor(query_auth): QueryExtractor<AuthRequest>,
@ -82,7 +92,8 @@ pub async fn auth_authorized(
let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL")
.context("OAUTH_USER_INFO_URL not set")?;
// Validate that we have a stored csrf token that matches the one passed back to us
// STEP 4 - Get the saved CSRF token from the state and ensure it matches
// what was returned in the query string of the redirect
let stored_csrf_token = session.remove::<CsrfToken>(CSRF_TOKEN)
.await
.context("unable to access csrf token")?
@ -92,13 +103,14 @@ pub async fn auth_authorized(
return Err(anyhow!("session csrf mismatch").into())
}
// STEP 5 - Exchange the Authorization Code for an Access Token
let token = oauth_client
.exchange_code(AuthorizationCode::new(query_auth.code.clone()))
.request_async(async_http_client)
.await
.context("failed in sending request request to authorization server")?;
// Fetch user data
// STEP 6 - Use the Access Token to pull user data like name, email, etc
let client = reqwest::Client::new();
let user_data = client
.get(user_info_endpoint)
@ -112,6 +124,7 @@ pub async fn auth_authorized(
.await
.context("failed to deserialize response as JSON")?;
// STEP 7 - Authorize the user at the application level
//TODO Check against database instead of string
let valid_users = std::env::var("AUTHORIZED_USERS")
.context("Authorized users not set")?;
@ -125,8 +138,10 @@ pub async fn auth_authorized(
return Ok(AppForbiddenResponse::new(&user_data.email, "application").into_response())
}
// STEP 8 - Save user session data
session.insert(USER_SESSION, user_data).await?;
// STEP 9 - Redirect back to the rest of the application
Ok(Redirect::to("/").into_response())
}
@ -134,11 +149,13 @@ pub async fn auth_authorized(
#[template(path = "logged-out.html")]
struct LoggedOutTemplate;
/// Handler for user log-out
pub async fn auth_logout(
session: Session,
user: Option<User>,
) -> anyhow::Result<impl IntoResponse, AppError> {
// Logging out is as simple as clearing the user session
if user.is_some() {
session.remove::<User>(USER_SESSION).await?;
}
@ -146,12 +163,14 @@ pub async fn auth_logout(
Ok(LoggedOutTemplate.into_response())
}
/// Query string response for "authorized" endpoint
#[derive(Debug, Deserialize)]
pub struct AuthRequest {
pub code: String,
pub state: String,
}
/// User information that will be return from the OAUTH authority
#[derive(Debug, Serialize, Deserialize)]
pub struct User {
pub id: String,
@ -161,6 +180,7 @@ pub struct User {
pub picture: String,
}
/// A custom error for the User extractor
pub enum UserExtractError {
InternalServerError(anyhow::Error),
Unauthorized,
@ -175,10 +195,12 @@ impl IntoResponse for UserExtractError {
}
}
/// The user extractor is used to pull out the user data from the session. This can be used
/// as a guard to ensure that a user session exists. Basically an authentication
/// (but not authorization) guard
#[async_trait]
impl<S> FromRequestParts<S> for User
where
SqlitePool: FromRef<S>,
S: Send + Sync,
{
type Rejection = UserExtractError;

@ -1,5 +1,5 @@
use sqlx::SqlitePool;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use sqlx::sqlite::SqliteConnectOptions;
pub async fn init(filename: &str) -> anyhow::Result<SqlitePool> {
let options = SqliteConnectOptions::new()

@ -5,9 +5,11 @@ use axum::response::{IntoResponse, Response};
use axum::Router;
use axum::routing::get;
use axum::extract::FromRequestParts;
use crate::app_state::AppState;
use crate::app::state::AppState;
use crate::auth::User;
// This module is all the stuff related to handling error responses
/// These are just test routes. They shouldn't really be called directly
/// as they just return an error. But they are nice for testing
pub fn routes() -> Router<AppState> {
@ -28,11 +30,6 @@ pub fn routes() -> Router<AppState> {
router
}
/// Handler that always responds with 404 Not Found
pub async fn not_found() -> impl IntoResponse {
AppNotFoundResponse.into_response()
}
/// Application error that is a thin layer over anyhow::Error but has the distinction of having
/// the piping for converting from the error to an axum response
pub struct AppError(anyhow::Error);
@ -128,3 +125,9 @@ fn always_fails() -> anyhow::Result<()> {
async fn forbidden(user: User) -> impl IntoResponse {
AppForbiddenResponse::new(&user.email, "test endpoint")
}
/// Handler that always responds with 404 Not Found
pub async fn not_found() -> impl IntoResponse {
AppNotFoundResponse.into_response()
}

@ -1,20 +1,11 @@
use crate::app_state::AppState;
use crate::auth::User;
use crate::error::{AppError, AppForbiddenResponse};
use anyhow::{anyhow, Context, Result};
use askama_axum::Template;
use axum::extract::{FromRef, Query, State};
use axum::http::header::SET_COOKIE;
use axum::http::HeaderMap;
use axum::response::{IntoResponse, Redirect};
use axum::{extract::Request, handler::HandlerWithoutStateExt, http, http::StatusCode, routing::get, Router};
use app::state::AppState;
use anyhow::{Context, Result};
use axum::Router;
use tokio::signal;
use tokio::task::{AbortHandle, JoinHandle};
use tower_http::{
services::{ServeDir, ServeFile},
trace::TraceLayer,
};
use tower_sessions::Session;
use tracing::info;
use tracing_subscriber::EnvFilter;
@ -24,15 +15,13 @@ mod error;
mod session;
mod auth;
mod static_routes;
mod app;
//NOTE TO FUTURE ME: I'm leaving a bunch of notes about these things as part of the learning
// process. There is a lot of implementation details that are obscured by all these pieces, and it
// can be hard to tell how heavy a line is. Lots of comment in this file and some of the kind of
// "first breadcrumb" type files so you have a place to pick up the threads. Hopefully you remember
// this stuff, but in case you don't here's some notes
const CSRF_TOKEN: &str = "csrf_token";
const USER_SESSION: &str = "user";
#[tokio::main]
async fn main() -> Result<()>{
@ -69,15 +58,18 @@ async fn main() -> Result<()>{
// Session
let (session_layer, session_task) = session::init().await?;
let auth_routes = auth::routes();
let static_routes = static_routes::routes();
let error_routes: Router<AppState> = error::routes();
// Long-running tasks
let mut tasks = vec![];
tasks.push(session_task);
let app_routes: Router<AppState> = Router::new()
.route("/", get(index));
// Assemble all the routes to the various handlers
let auth_routes = auth::routes();
let static_routes = static_routes::routes();
let error_routes = error::routes();
let app_routes = app::routes();
// Top level router
let router = Router::new()
.merge(auth_routes)
.merge(error_routes)
@ -88,6 +80,7 @@ async fn main() -> Result<()>{
.fallback(error::not_found)
.with_state(app_state);
// Serve
let address = "0.0.0.0:4206";
let listener = tokio::net::TcpListener::bind(address)
.await
@ -95,9 +88,6 @@ async fn main() -> Result<()>{
info!("listening on {}", address);
let mut tasks = vec![];
tasks.push(session_task);
axum::serve(listener, router.into_make_service())
.with_graceful_shutdown(shutdown_signal(tasks))
.await.context("unable to serve")?;
@ -106,17 +96,11 @@ async fn main() -> Result<()>{
}
#[derive(Template)]
#[template(path = "index.html")]
struct IndexTemplate<'a> {
name: &'a str,
}
async fn index(user: User) -> impl IntoResponse {
IndexTemplate { name: user.name.as_str() }.into_response()
}
/// This is needed to handle the shutdown of any long-running tasks
/// such as the one that clears expired sessions. This just
/// functions by listening for the termination signal--either
/// ctrl-c or SIGTERM--triggering the abort handle for each
/// task and then joining (awaiting) each handle
async fn shutdown_signal(tasks: Vec<JoinHandle<Result<()>>>) {
let abort_handles: Vec<AbortHandle> = tasks.iter().map(|h| h.abort_handle()).collect();

@ -1,10 +1,11 @@
use axum::Router;
use tower_http::services::ServeFile;
use crate::app_state::AppState;
use crate::app::state::AppState;
pub fn routes() -> Router<AppState> {
Router::new()
.nest_service("/css/pico.min.css", ServeFile::new("static/css/pico.min.css"))
.nest_service("/css/custom.css", ServeFile::new("static/css/custom.css"))
.nest_service("/js/htmx.min.js", ServeFile::new("static/js/htmx.min.js"))
.nest_service("/favicon.ico", ServeFile::new("static/favicon.ico"))
}

@ -1,21 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<meta name="color-scheme" content="light dark">
<link rel="stylesheet" href="css/pico.min.css">
<script src="js/htmx.min.js"></script>
<title>Test Page</title>
</head>
<body>
<main class="container">
<h1>Test Page</h1>
<p>This is a test page</p>
<article>
<h2>Card</h2>
</article>
<a href="/auth/logout">Logout</a>
</main>
</body>
</html>

@ -1,12 +1,14 @@
{% extends "main.html" %}
{% block title %} Inventory App {% endblock %}
{% block content %}
<h1>Hello {{ name }}</h1>
<p>This is a test page</p>
<article>
<h2>Card</h2>
</article>
<a href="/auth/logout">Logout</a>
<p>
<input name="search"
placeholder="Search"
aria-label="Search"
type="search">
</p>
{% endblock %}

@ -5,12 +5,29 @@
<meta name="viewport" content="width=device-width, initial-scale=1">
<meta name="color-scheme" content="light dark">
<link rel="stylesheet" href="/css/pico.min.css">
<link rel="stylesheet" href="/css/custom.css">
<script src="/js/htmx.min.js"></script>
<title>Test Page</title>
<title>{% block title %}Title{% endblock %}</title>
</head>
<body>
<header class="container">
<nav class="container">
<ul>
<li><h1>Inventory App</h1></li>
</ul>
<ul>
<li><a class="secondary" href="#">Overview</a></li>
<li><a class="secondary" href="#">Receiving</a></li>
<li><a class="secondary" href="#">Reports</a></li>
<li><a class="secondary" href="#">Adjustments</a></li>
<li><a class="contrast" href="/auth/logout">Logout</a></li>
</ul>
</nav>
</header>
<main class="container">
{% block content %}<p>Content Missing</p>{% endblock %}
</main>
<footer class="container-fluid">
</footer>
</body>
</html>
Loading…
Cancel
Save

Powered by TurnKey Linux.