|
|
|
|
@ -1,8 +1,7 @@
|
|
|
|
|
use anyhow::{anyhow, Context};
|
|
|
|
|
use askama::Template;
|
|
|
|
|
use axum::{async_trait, http, Router};
|
|
|
|
|
use axum::extract::{FromRef, FromRequestParts, Query, State};
|
|
|
|
|
use axum::http::{header, StatusCode};
|
|
|
|
|
use axum::{async_trait, Router};
|
|
|
|
|
use axum::extract::{FromRequestParts, State};
|
|
|
|
|
use axum::http::request::Parts;
|
|
|
|
|
use axum::response::{IntoResponse, Redirect, Response};
|
|
|
|
|
use axum::routing::get;
|
|
|
|
|
@ -10,13 +9,16 @@ use oauth2::basic::BasicClient;
|
|
|
|
|
use oauth2::reqwest::async_http_client;
|
|
|
|
|
use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl};
|
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
|
use sqlx::SqlitePool;
|
|
|
|
|
use tower_sessions::Session;
|
|
|
|
|
|
|
|
|
|
use crate::error::{AppError, AppForbiddenResponse};
|
|
|
|
|
use crate::{auth, CSRF_TOKEN, USER_SESSION};
|
|
|
|
|
use crate::error::QueryExtractor;
|
|
|
|
|
use crate::app_state::AppState;
|
|
|
|
|
use crate::app::state::AppState;
|
|
|
|
|
|
|
|
|
|
// This module is all the stuff related to authentication and authorization
|
|
|
|
|
|
|
|
|
|
const CSRF_TOKEN: &str = "csrf_token";
|
|
|
|
|
const USER_SESSION: &str = "user";
|
|
|
|
|
|
|
|
|
|
pub fn routes() -> Router<AppState> {
|
|
|
|
|
Router::new()
|
|
|
|
|
@ -25,6 +27,7 @@ pub fn routes() -> Router<AppState> {
|
|
|
|
|
.route("/auth/authorized", get(auth_authorized))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Using the OAUTH2 library for communication to the OAUTH Provider (google)
|
|
|
|
|
pub fn init_client() -> anyhow::Result<BasicClient> {
|
|
|
|
|
use std::env::var;
|
|
|
|
|
|
|
|
|
|
@ -53,27 +56,34 @@ pub fn init_client() -> anyhow::Result<BasicClient> {
|
|
|
|
|
Ok(client)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Handler for when the user logs in
|
|
|
|
|
pub async fn auth_login(
|
|
|
|
|
session: Session,
|
|
|
|
|
user: Option<User>,
|
|
|
|
|
State(oauth_client): State<BasicClient>,
|
|
|
|
|
) -> anyhow::Result<impl IntoResponse, AppError> {
|
|
|
|
|
|
|
|
|
|
// Make sure we don't already have a session
|
|
|
|
|
if user.is_some() {
|
|
|
|
|
return Ok(Redirect::to("/"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// STEP 1 - Get the OAUTH Redirect Info with a random state token
|
|
|
|
|
let (auth_url, csrf_token) = oauth_client
|
|
|
|
|
.authorize_url(CsrfToken::new_random)
|
|
|
|
|
.add_scope(Scope::new("profile".to_string()))
|
|
|
|
|
.add_scope(Scope::new("email".to_string()))
|
|
|
|
|
.url();
|
|
|
|
|
|
|
|
|
|
// STEP 2 - Save the CSRF token to the session
|
|
|
|
|
session.insert(CSRF_TOKEN, csrf_token).await?;
|
|
|
|
|
|
|
|
|
|
// STEP 3 - Redirect to oauth provider with state
|
|
|
|
|
Ok(Redirect::to(auth_url.as_ref()))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Handler for when the user is redirected back to us from the OAUTH site
|
|
|
|
|
pub async fn auth_authorized(
|
|
|
|
|
session: Session,
|
|
|
|
|
QueryExtractor(query_auth): QueryExtractor<AuthRequest>,
|
|
|
|
|
@ -82,7 +92,8 @@ pub async fn auth_authorized(
|
|
|
|
|
let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL")
|
|
|
|
|
.context("OAUTH_USER_INFO_URL not set")?;
|
|
|
|
|
|
|
|
|
|
// Validate that we have a stored csrf token that matches the one passed back to us
|
|
|
|
|
// STEP 4 - Get the saved CSRF token from the state and ensure it matches
|
|
|
|
|
// what was returned in the query string of the redirect
|
|
|
|
|
let stored_csrf_token = session.remove::<CsrfToken>(CSRF_TOKEN)
|
|
|
|
|
.await
|
|
|
|
|
.context("unable to access csrf token")?
|
|
|
|
|
@ -92,13 +103,14 @@ pub async fn auth_authorized(
|
|
|
|
|
return Err(anyhow!("session csrf mismatch").into())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// STEP 5 - Exchange the Authorization Code for an Access Token
|
|
|
|
|
let token = oauth_client
|
|
|
|
|
.exchange_code(AuthorizationCode::new(query_auth.code.clone()))
|
|
|
|
|
.request_async(async_http_client)
|
|
|
|
|
.await
|
|
|
|
|
.context("failed in sending request request to authorization server")?;
|
|
|
|
|
|
|
|
|
|
// Fetch user data
|
|
|
|
|
// STEP 6 - Use the Access Token to pull user data like name, email, etc
|
|
|
|
|
let client = reqwest::Client::new();
|
|
|
|
|
let user_data = client
|
|
|
|
|
.get(user_info_endpoint)
|
|
|
|
|
@ -112,6 +124,7 @@ pub async fn auth_authorized(
|
|
|
|
|
.await
|
|
|
|
|
.context("failed to deserialize response as JSON")?;
|
|
|
|
|
|
|
|
|
|
// STEP 7 - Authorize the user at the application level
|
|
|
|
|
//TODO Check against database instead of string
|
|
|
|
|
let valid_users = std::env::var("AUTHORIZED_USERS")
|
|
|
|
|
.context("Authorized users not set")?;
|
|
|
|
|
@ -125,8 +138,10 @@ pub async fn auth_authorized(
|
|
|
|
|
return Ok(AppForbiddenResponse::new(&user_data.email, "application").into_response())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// STEP 8 - Save user session data
|
|
|
|
|
session.insert(USER_SESSION, user_data).await?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// STEP 9 - Redirect back to the rest of the application
|
|
|
|
|
Ok(Redirect::to("/").into_response())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -134,11 +149,13 @@ pub async fn auth_authorized(
|
|
|
|
|
#[template(path = "logged-out.html")]
|
|
|
|
|
struct LoggedOutTemplate;
|
|
|
|
|
|
|
|
|
|
/// Handler for user log-out
|
|
|
|
|
pub async fn auth_logout(
|
|
|
|
|
session: Session,
|
|
|
|
|
user: Option<User>,
|
|
|
|
|
) -> anyhow::Result<impl IntoResponse, AppError> {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Logging out is as simple as clearing the user session
|
|
|
|
|
if user.is_some() {
|
|
|
|
|
session.remove::<User>(USER_SESSION).await?;
|
|
|
|
|
}
|
|
|
|
|
@ -146,12 +163,14 @@ pub async fn auth_logout(
|
|
|
|
|
Ok(LoggedOutTemplate.into_response())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Query string response for "authorized" endpoint
|
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
|
|
|
pub struct AuthRequest {
|
|
|
|
|
pub code: String,
|
|
|
|
|
pub state: String,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// User information that will be return from the OAUTH authority
|
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
|
|
|
pub struct User {
|
|
|
|
|
pub id: String,
|
|
|
|
|
@ -161,6 +180,7 @@ pub struct User {
|
|
|
|
|
pub picture: String,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// A custom error for the User extractor
|
|
|
|
|
pub enum UserExtractError {
|
|
|
|
|
InternalServerError(anyhow::Error),
|
|
|
|
|
Unauthorized,
|
|
|
|
|
@ -175,10 +195,12 @@ impl IntoResponse for UserExtractError {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// The user extractor is used to pull out the user data from the session. This can be used
|
|
|
|
|
/// as a guard to ensure that a user session exists. Basically an authentication
|
|
|
|
|
/// (but not authorization) guard
|
|
|
|
|
#[async_trait]
|
|
|
|
|
impl<S> FromRequestParts<S> for User
|
|
|
|
|
where
|
|
|
|
|
SqlitePool: FromRef<S>,
|
|
|
|
|
S: Send + Sync,
|
|
|
|
|
{
|
|
|
|
|
type Rejection = UserExtractError;
|
|
|
|
|
|