use anyhow::Context; use axum::{async_trait, http}; use axum::extract::{FromRef, FromRequestParts, Query, State}; use axum::http::{header, StatusCode}; use axum::http::request::Parts; use axum::response::{IntoResponse, Redirect, Response}; use oauth2::basic::BasicClient; use oauth2::reqwest::async_http_client; use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl}; use serde::{Deserialize, Serialize}; use sqlx::SqlitePool; use tower_sessions::Session; use crate::error::AppError; use crate::{CSRF_TOKEN, USER_SESSION}; pub fn init_client() -> anyhow::Result { use std::env::var; let client_id = var("OAUTH_CLIENT_ID").context("env OAUTH_CLIENT_ID not set")?; let client_secret = var("OAUTH_CLIENT_SECRET").context("env OAUTH_CLIENT_SECRET not set")?; let auth_url = var("OAUTH_AUTH_URL").context("env OAUTH_AUTH_URL not set")?; let token_url = var("OAUTH_TOKEN_URL").context("env OAUTH_TOKEN_URL not set")?; let revoke_url = var("OAUTH_REVOKE_URL").context("env OAUTH_REVOKE_URL not set")?; let redirect_url = var("OAUTH_REDIRECT_URL").context("env OAUTH_REDIRECT_URL not set")?; let client_id = ClientId::new(client_id); let client_secret = ClientSecret::new(client_secret); let auth_url = AuthUrl::new(auth_url)?; let token_url = TokenUrl::new(token_url)?; let revoke_url = RevocationUrl::new(revoke_url)?; let redirect_url = RedirectUrl::new(redirect_url)?; let client = BasicClient::new( client_id, Some(client_secret), auth_url, Some(token_url)) .set_redirect_uri(redirect_url) .set_revocation_uri(revoke_url); Ok(client) } pub async fn auth_google( session: Session, State(oauth_client): State, ) -> anyhow::Result { let (auth_url, csrf_token) = oauth_client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("profile".to_string())) .add_scope(Scope::new("email".to_string())) .url(); session.insert(CSRF_TOKEN, csrf_token).await?; Ok(Redirect::to(auth_url.as_ref())) } pub async fn auth_authorized( session: Session, Query(query_auth): Query, State(oauth_client): State, ) -> anyhow::Result { let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL") .context("OAUTH_USER_INFO_URL not set")?; let token = oauth_client .exchange_code(AuthorizationCode::new(query_auth.code.clone())) .request_async(async_http_client) .await .context("failed in sending request request to authorization server")?; // Fetch user data let client = reqwest::Client::new(); let user_data = client .get(user_info_endpoint) .bearer_auth(token.access_token().secret()) .send() .await .context("failed in sending request to target Url")?; let user_data = user_data .json::() .await .context("failed to deserialize response as JSON")?; session.insert(USER_SESSION, user_data).await?; //TODO Redirect somewhere sane Ok(Redirect::to("/protected")) } pub async fn auth_logout( session: Session, ) -> anyhow::Result { session.remove::(USER_SESSION).await?; Ok(Redirect::to("/")) } #[derive(Debug, Deserialize)] pub struct AuthRequest { pub code: String, pub state: String, } #[derive(Debug, Serialize, Deserialize)] pub struct User { pub id: String, pub email: String, pub name: String, pub verified_email: bool, pub picture: String, } pub struct UserExtractError(http::StatusCode); impl IntoResponse for UserExtractError { fn into_response(self) -> Response { match self.0 { StatusCode::UNAUTHORIZED => { Redirect::temporary("/auth/login").into_response() } StatusCode::FORBIDDEN => { StatusCode::FORBIDDEN.into_response() } _ => StatusCode::INTERNAL_SERVER_ERROR.into_response() } } } impl From<(http::StatusCode, &'static str)> for UserExtractError { fn from(value: (StatusCode, &'static str)) -> Self { Self(value.0) } } impl From for UserExtractError { fn from(_value: tower_sessions::session::Error) -> Self { Self(StatusCode::INTERNAL_SERVER_ERROR) } } #[async_trait] impl FromRequestParts for User where SqlitePool: FromRef, S: Send + Sync, { type Rejection = UserExtractError; async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let session = Session::from_request_parts(parts, state).await .map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))?; let user: User = session.get(USER_SESSION).await .map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))? .ok_or(UserExtractError(StatusCode::UNAUTHORIZED))?; let db = SqlitePool::from_ref(state); //TODO actual verification of users if user.email != "whatswithwes@gmail.com" { Err(UserExtractError(StatusCode::FORBIDDEN))?; } Ok(user) } }