// inventory-app/src/auth.rs

use anyhow::Context;
use axum::{async_trait, http};
use axum::extract::{FromRef, FromRequestParts, Query, State};
use axum::http::{header, StatusCode};
use axum::http::request::Parts;
use axum::response::{IntoResponse, Redirect, Response};
use oauth2::basic::BasicClient;
use oauth2::reqwest::async_http_client;
use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl};
use serde::{Deserialize, Serialize};
use sqlx::SqlitePool;
use tower_sessions::Session;
use crate::error::AppError;
use crate::{CSRF_TOKEN, USER_SESSION};
/// Builds the OAuth2 client from environment configuration.
///
/// Reads `OAUTH_CLIENT_ID`, `OAUTH_CLIENT_SECRET`, `OAUTH_AUTH_URL`,
/// `OAUTH_TOKEN_URL`, `OAUTH_REVOKE_URL` and `OAUTH_REDIRECT_URL`, then
/// configures the client with the redirect and revocation URLs.
///
/// # Errors
///
/// Returns an error if any of these variables is unset, or if one of the
/// URL-valued variables cannot be parsed.
pub fn init_client() -> anyhow::Result<BasicClient> {
    use std::env::var;

    // Every variable is read before any URL is parsed, so a missing variable
    // is always reported ahead of a malformed one.
    let id = var("OAUTH_CLIENT_ID").context("env OAUTH_CLIENT_ID not set")?;
    let secret = var("OAUTH_CLIENT_SECRET").context("env OAUTH_CLIENT_SECRET not set")?;
    let auth = var("OAUTH_AUTH_URL").context("env OAUTH_AUTH_URL not set")?;
    let token = var("OAUTH_TOKEN_URL").context("env OAUTH_TOKEN_URL not set")?;
    let revoke = var("OAUTH_REVOKE_URL").context("env OAUTH_REVOKE_URL not set")?;
    let redirect = var("OAUTH_REDIRECT_URL").context("env OAUTH_REDIRECT_URL not set")?;

    // URLs are parsed in the order auth, token, revocation, redirect.
    Ok(BasicClient::new(
        ClientId::new(id),
        Some(ClientSecret::new(secret)),
        AuthUrl::new(auth)?,
        Some(TokenUrl::new(token)?),
    )
    .set_revocation_uri(RevocationUrl::new(revoke)?)
    .set_redirect_uri(RedirectUrl::new(redirect)?))
}
/// Starts the OAuth2 login flow.
///
/// Builds an authorization URL that requests the `profile` and `email`
/// scopes with a freshly generated random CSRF token. The token is stored in
/// the session under `CSRF_TOKEN`, and the browser is redirected to the
/// provider.
///
/// # Errors
///
/// Fails if the CSRF token cannot be written to the session.
pub async fn auth_google(
    session: Session,
    State(oauth_client): State<BasicClient>,
) -> anyhow::Result<impl IntoResponse, AppError> {
    // Scopes are appended in this order, same as two `add_scope` calls.
    let requested = ["profile", "email"].map(|name| Scope::new(name.to_string()));
    let (provider_url, csrf) = oauth_client
        .authorize_url(CsrfToken::new_random)
        .add_scopes(requested)
        .url();

    session.insert(CSRF_TOKEN, csrf).await?;
    Ok(Redirect::to(provider_url.as_str()))
}
pub async fn auth_authorized(
session: Session,
Query(query_auth): Query<AuthRequest>,
State(oauth_client): State<BasicClient>,
) -> anyhow::Result<impl IntoResponse, AppError> {
let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL")
.context("OAUTH_USER_INFO_URL not set")?;
let token = oauth_client
.exchange_code(AuthorizationCode::new(query_auth.code.clone()))
.request_async(async_http_client)
.await
.context("failed in sending request request to authorization server")?;
// Fetch user data
let client = reqwest::Client::new();
let user_data = client
.get(user_info_endpoint)
.bearer_auth(token.access_token().secret())
.send()
.await
.context("failed in sending request to target Url")?;
let user_data = user_data
.json::<User>()
.await
.context("failed to deserialize response as JSON")?;
session.insert(USER_SESSION, user_data).await?;
//TODO Redirect somewhere sane
Ok(Redirect::to("/protected"))
}
/// Logs the current user out.
///
/// Removes the entry stored under `USER_SESSION` (read back as a [`User`])
/// and redirects to `/`.
///
/// # Errors
///
/// Propagates any error returned by the session's `remove` call.
pub async fn auth_logout(
    session: Session,
) -> anyhow::Result<impl IntoResponse, AppError> {
    let _previous: Option<User> = session.remove(USER_SESSION).await?;
    Ok(Redirect::to("/"))
}
/// Query parameters on the OAuth2 redirect callback (`auth_authorized`).
#[derive(Deserialize, Debug)]
pub struct AuthRequest {
    /// Authorization code to exchange for an access token.
    pub code: String,
    /// `state` value returned by the provider. Presumably this is the CSRF
    /// token generated in `auth_google`; confirm against the provider flow.
    pub state: String,
}
/// User profile fetched from `OAUTH_USER_INFO_URL` and cached in the session
/// under `USER_SESSION`.
///
/// NOTE(review): this field set (`verified_email`, `picture`) looks like
/// Google's userinfo response. Confirm before pointing the app at another
/// provider.
#[derive(Serialize, Deserialize, Debug)]
pub struct User {
    /// Provider-side user identifier.
    pub id: String,
    /// Email address reported by the provider.
    pub email: String,
    /// Display name.
    pub name: String,
    /// Whether the provider reports `email` as verified.
    pub verified_email: bool,
    /// Avatar image URL.
    pub picture: String,
}
/// Rejection returned when a [`User`] cannot be extracted from a request.
///
/// Carries the status code that explains why extraction failed.
pub struct UserExtractError(http::StatusCode);

impl IntoResponse for UserExtractError {
    /// Converts the stored status into a response:
    /// - `UNAUTHORIZED` becomes a temporary redirect to `/auth/login`.
    /// - `FORBIDDEN` is returned as a bare `FORBIDDEN` response.
    /// - Anything else becomes `INTERNAL_SERVER_ERROR`.
    fn into_response(self) -> Response {
        let Self(status) = self;
        if status == StatusCode::UNAUTHORIZED {
            return Redirect::temporary("/auth/login").into_response();
        }
        if status == StatusCode::FORBIDDEN {
            return StatusCode::FORBIDDEN.into_response();
        }
        StatusCode::INTERNAL_SERVER_ERROR.into_response()
    }
}
impl From<(http::StatusCode, &'static str)> for UserExtractError {
fn from(value: (StatusCode, &'static str)) -> Self {
Self(value.0)
}
}
impl From<tower_sessions::session::Error> for UserExtractError {
fn from(_value: tower_sessions::session::Error) -> Self {
Self(StatusCode::INTERNAL_SERVER_ERROR)
}
}
/// Extracts the logged-in [`User`] from the session and authorizes them.
///
/// Rejections (see [`UserExtractError`]):
/// - No session can be extracted, or the session read fails:
///   `INTERNAL_SERVER_ERROR`.
/// - No user is stored under `USER_SESSION`: `UNAUTHORIZED`, which redirects
///   to login.
/// - The email is not the allowlisted address, or the provider has not
///   verified it: `FORBIDDEN`.
#[async_trait]
impl<S> FromRequestParts<S> for User
where
    // Not used yet. It is kept so database-backed verification (see the TODO
    // below) can be added without changing the bound on the app state.
    SqlitePool: FromRef<S>,
    S: Send + Sync,
{
    type Rejection = UserExtractError;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        //TODO actual verification of users (e.g. via the SqlitePool in `state`)
        const ALLOWED_EMAIL: &str = "whatswithwes@gmail.com";

        let session = Session::from_request_parts(parts, state)
            .await
            .map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))?;
        // Session errors go through `From<tower_sessions::session::Error>`,
        // which yields INTERNAL_SERVER_ERROR.
        let user: User = session
            .get(USER_SESSION)
            .await?
            .ok_or(UserExtractError(StatusCode::UNAUTHORIZED))?;

        // An unverified email is an unconfirmed claim, so it must not satisfy
        // the allowlist.
        if !user.verified_email || user.email != ALLOWED_EMAIL {
            return Err(UserExtractError(StatusCode::FORBIDDEN));
        }
        Ok(user)
    }
}
