You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
170 lines
5.2 KiB
170 lines
5.2 KiB
use anyhow::Context;
|
|
use axum::{async_trait, http};
|
|
use axum::extract::{FromRef, FromRequestParts, Query, State};
|
|
use axum::http::{header, StatusCode};
|
|
use axum::http::request::Parts;
|
|
use axum::response::{IntoResponse, Redirect, Response};
|
|
use oauth2::basic::BasicClient;
|
|
use oauth2::reqwest::async_http_client;
|
|
use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl};
|
|
use serde::{Deserialize, Serialize};
|
|
use sqlx::SqlitePool;
|
|
use tower_sessions::Session;
|
|
|
|
use crate::error::AppError;
|
|
use crate::{CSRF_TOKEN, USER_SESSION};
|
|
|
|
|
|
pub fn init_client() -> anyhow::Result<BasicClient> {
|
|
use std::env::var;
|
|
|
|
let client_id = var("OAUTH_CLIENT_ID").context("env OAUTH_CLIENT_ID not set")?;
|
|
let client_secret = var("OAUTH_CLIENT_SECRET").context("env OAUTH_CLIENT_SECRET not set")?;
|
|
let auth_url = var("OAUTH_AUTH_URL").context("env OAUTH_AUTH_URL not set")?;
|
|
let token_url = var("OAUTH_TOKEN_URL").context("env OAUTH_TOKEN_URL not set")?;
|
|
let revoke_url = var("OAUTH_REVOKE_URL").context("env OAUTH_REVOKE_URL not set")?;
|
|
let redirect_url = var("OAUTH_REDIRECT_URL").context("env OAUTH_REDIRECT_URL not set")?;
|
|
|
|
let client_id = ClientId::new(client_id);
|
|
let client_secret = ClientSecret::new(client_secret);
|
|
let auth_url = AuthUrl::new(auth_url)?;
|
|
let token_url = TokenUrl::new(token_url)?;
|
|
let revoke_url = RevocationUrl::new(revoke_url)?;
|
|
let redirect_url = RedirectUrl::new(redirect_url)?;
|
|
|
|
let client = BasicClient::new(
|
|
client_id,
|
|
Some(client_secret),
|
|
auth_url,
|
|
Some(token_url))
|
|
.set_redirect_uri(redirect_url)
|
|
.set_revocation_uri(revoke_url);
|
|
|
|
Ok(client)
|
|
}
|
|
|
|
pub async fn auth_google(
|
|
session: Session,
|
|
State(oauth_client): State<BasicClient>,
|
|
) -> anyhow::Result<impl IntoResponse, AppError> {
|
|
let (auth_url, csrf_token) = oauth_client
|
|
.authorize_url(CsrfToken::new_random)
|
|
.add_scope(Scope::new("profile".to_string()))
|
|
.add_scope(Scope::new("email".to_string()))
|
|
.url();
|
|
|
|
session.insert(CSRF_TOKEN, csrf_token).await?;
|
|
|
|
Ok(Redirect::to(auth_url.as_ref()))
|
|
}
|
|
|
|
|
|
pub async fn auth_authorized(
|
|
session: Session,
|
|
Query(query_auth): Query<AuthRequest>,
|
|
State(oauth_client): State<BasicClient>,
|
|
) -> anyhow::Result<impl IntoResponse, AppError> {
|
|
let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL")
|
|
.context("OAUTH_USER_INFO_URL not set")?;
|
|
|
|
let token = oauth_client
|
|
.exchange_code(AuthorizationCode::new(query_auth.code.clone()))
|
|
.request_async(async_http_client)
|
|
.await
|
|
.context("failed in sending request request to authorization server")?;
|
|
|
|
// Fetch user data
|
|
let client = reqwest::Client::new();
|
|
let user_data = client
|
|
.get(user_info_endpoint)
|
|
.bearer_auth(token.access_token().secret())
|
|
.send()
|
|
.await
|
|
.context("failed in sending request to target Url")?;
|
|
|
|
let user_data = user_data
|
|
.json::<User>()
|
|
.await
|
|
.context("failed to deserialize response as JSON")?;
|
|
|
|
session.insert(USER_SESSION, user_data).await?;
|
|
|
|
//TODO Redirect somewhere sane
|
|
Ok(Redirect::to("/protected"))
|
|
}
|
|
|
|
pub async fn auth_logout(
|
|
session: Session,
|
|
) -> anyhow::Result<impl IntoResponse, AppError> {
|
|
|
|
session.remove::<User>(USER_SESSION).await?;
|
|
|
|
Ok(Redirect::to("/"))
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct AuthRequest {
|
|
pub code: String,
|
|
pub state: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub struct User {
|
|
pub id: String,
|
|
pub email: String,
|
|
pub name: String,
|
|
pub verified_email: bool,
|
|
pub picture: String,
|
|
}
|
|
|
|
/// Rejection produced when a [`User`] cannot be extracted from a request.
///
/// Wraps the HTTP status describing the failure. Its `IntoResponse` impl turns
/// `UNAUTHORIZED` into a redirect to the login route, passes `FORBIDDEN` through
/// and maps everything else to `INTERNAL_SERVER_ERROR`.
#[derive(Debug)]
pub struct UserExtractError(http::StatusCode);
|
|
|
|
impl IntoResponse for UserExtractError {
|
|
fn into_response(self) -> Response {
|
|
match self.0 {
|
|
StatusCode::UNAUTHORIZED => { Redirect::temporary("/auth/login").into_response() }
|
|
StatusCode::FORBIDDEN => { StatusCode::FORBIDDEN.into_response() }
|
|
_ => StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<(http::StatusCode, &'static str)> for UserExtractError {
|
|
fn from(value: (StatusCode, &'static str)) -> Self {
|
|
Self(value.0)
|
|
}
|
|
}
|
|
|
|
impl From<tower_sessions::session::Error> for UserExtractError {
|
|
fn from(_value: tower_sessions::session::Error) -> Self {
|
|
Self(StatusCode::INTERNAL_SERVER_ERROR)
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl<S> FromRequestParts<S> for User
|
|
where
|
|
SqlitePool: FromRef<S>,
|
|
S: Send + Sync,
|
|
{
|
|
type Rejection = UserExtractError;
|
|
|
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
|
let session = Session::from_request_parts(parts, state).await
|
|
.map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))?;
|
|
|
|
let user: User = session.get(USER_SESSION).await
|
|
.map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))?
|
|
.ok_or(UserExtractError(StatusCode::UNAUTHORIZED))?;
|
|
|
|
let db = SqlitePool::from_ref(state);
|
|
|
|
//TODO actual verification of users
|
|
if user.email != "whatswithwes@gmail.com" {
|
|
Err(UserExtractError(StatusCode::FORBIDDEN))?;
|
|
}
|
|
|
|
Ok(user)
|
|
}
|
|
}
|