|
|
|
@ -1,9 +1,11 @@
|
|
|
|
use anyhow::{anyhow, Context};
|
|
|
|
use anyhow::{anyhow, Context};
|
|
|
|
use axum::{async_trait, http};
|
|
|
|
use askama::Template;
|
|
|
|
|
|
|
|
use axum::{async_trait, http, Router};
|
|
|
|
use axum::extract::{FromRef, FromRequestParts, Query, State};
|
|
|
|
use axum::extract::{FromRef, FromRequestParts, Query, State};
|
|
|
|
use axum::http::{header, StatusCode};
|
|
|
|
use axum::http::{header, StatusCode};
|
|
|
|
use axum::http::request::Parts;
|
|
|
|
use axum::http::request::Parts;
|
|
|
|
use axum::response::{IntoResponse, Redirect, Response};
|
|
|
|
use axum::response::{IntoResponse, Redirect, Response};
|
|
|
|
|
|
|
|
use axum::routing::get;
|
|
|
|
use oauth2::basic::BasicClient;
|
|
|
|
use oauth2::basic::BasicClient;
|
|
|
|
use oauth2::reqwest::async_http_client;
|
|
|
|
use oauth2::reqwest::async_http_client;
|
|
|
|
use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl};
|
|
|
|
use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl};
|
|
|
|
@ -11,9 +13,17 @@ use serde::{Deserialize, Serialize};
|
|
|
|
use sqlx::SqlitePool;
|
|
|
|
use sqlx::SqlitePool;
|
|
|
|
use tower_sessions::Session;
|
|
|
|
use tower_sessions::Session;
|
|
|
|
|
|
|
|
|
|
|
|
use crate::error::AppError;
|
|
|
|
use crate::error::{AppError, AppForbiddenResponse};
|
|
|
|
use crate::{CSRF_TOKEN, USER_SESSION};
|
|
|
|
use crate::{auth, CSRF_TOKEN, USER_SESSION};
|
|
|
|
|
|
|
|
use crate::error::QueryExtractor;
|
|
|
|
|
|
|
|
use crate::app_state::AppState;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub fn routes() -> Router<AppState> {
|
|
|
|
|
|
|
|
Router::new()
|
|
|
|
|
|
|
|
.route("/auth/login", get(auth_login))
|
|
|
|
|
|
|
|
.route("/auth/logout", get(auth_logout))
|
|
|
|
|
|
|
|
.route("/auth/authorized", get(auth_authorized))
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
pub fn init_client() -> anyhow::Result<BasicClient> {
|
|
|
|
pub fn init_client() -> anyhow::Result<BasicClient> {
|
|
|
|
use std::env::var;
|
|
|
|
use std::env::var;
|
|
|
|
@ -43,10 +53,15 @@ pub fn init_client() -> anyhow::Result<BasicClient> {
|
|
|
|
Ok(client)
|
|
|
|
Ok(client)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
pub async fn auth_google(
|
|
|
|
pub async fn auth_login(
|
|
|
|
session: Session,
|
|
|
|
session: Session,
|
|
|
|
|
|
|
|
user: Option<User>,
|
|
|
|
State(oauth_client): State<BasicClient>,
|
|
|
|
State(oauth_client): State<BasicClient>,
|
|
|
|
) -> anyhow::Result<impl IntoResponse, AppError> {
|
|
|
|
) -> anyhow::Result<impl IntoResponse, AppError> {
|
|
|
|
|
|
|
|
if user.is_some() {
|
|
|
|
|
|
|
|
return Ok(Redirect::to("/"));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
let (auth_url, csrf_token) = oauth_client
|
|
|
|
let (auth_url, csrf_token) = oauth_client
|
|
|
|
.authorize_url(CsrfToken::new_random)
|
|
|
|
.authorize_url(CsrfToken::new_random)
|
|
|
|
.add_scope(Scope::new("profile".to_string()))
|
|
|
|
.add_scope(Scope::new("profile".to_string()))
|
|
|
|
@ -61,7 +76,7 @@ pub async fn auth_google(
|
|
|
|
|
|
|
|
|
|
|
|
pub async fn auth_authorized(
|
|
|
|
pub async fn auth_authorized(
|
|
|
|
session: Session,
|
|
|
|
session: Session,
|
|
|
|
Query(query_auth): Query<AuthRequest>,
|
|
|
|
QueryExtractor(query_auth): QueryExtractor<AuthRequest>,
|
|
|
|
State(oauth_client): State<BasicClient>,
|
|
|
|
State(oauth_client): State<BasicClient>,
|
|
|
|
) -> anyhow::Result<impl IntoResponse, AppError> {
|
|
|
|
) -> anyhow::Result<impl IntoResponse, AppError> {
|
|
|
|
let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL")
|
|
|
|
let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL")
|
|
|
|
@ -106,24 +121,29 @@ pub async fn auth_authorized(
|
|
|
|
|
|
|
|
|
|
|
|
let is_authorized = valid_users.contains(&user_data.email.as_str());
|
|
|
|
let is_authorized = valid_users.contains(&user_data.email.as_str());
|
|
|
|
|
|
|
|
|
|
|
|
if is_authorized {
|
|
|
|
if !is_authorized {
|
|
|
|
session.insert(USER_SESSION, user_data).await?;
|
|
|
|
return Ok(AppForbiddenResponse::new(&user_data.email, "application").into_response())
|
|
|
|
|
|
|
|
|
|
|
|
//TODO Redirect somewhere sane
|
|
|
|
|
|
|
|
Ok(Redirect::to("/protected").into_response())
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
else {
|
|
|
|
|
|
|
|
Ok((http::StatusCode::UNAUTHORIZED, "Unauthorized").into_response())
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
session.insert(USER_SESSION, user_data).await?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ok(Redirect::to("/").into_response())
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Template)]
|
|
|
|
|
|
|
|
#[template(path = "logged-out.html")]
|
|
|
|
|
|
|
|
struct LoggedOutTemplate;
|
|
|
|
|
|
|
|
|
|
|
|
pub async fn auth_logout(
|
|
|
|
pub async fn auth_logout(
|
|
|
|
session: Session,
|
|
|
|
session: Session,
|
|
|
|
|
|
|
|
user: Option<User>,
|
|
|
|
) -> anyhow::Result<impl IntoResponse, AppError> {
|
|
|
|
) -> anyhow::Result<impl IntoResponse, AppError> {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if user.is_some() {
|
|
|
|
|
|
|
|
session.remove::<User>(USER_SESSION).await?;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
session.remove::<User>(USER_SESSION).await?;
|
|
|
|
Ok(LoggedOutTemplate.into_response())
|
|
|
|
|
|
|
|
|
|
|
|
Ok(Redirect::to("/"))
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
|
|
@ -141,30 +161,20 @@ pub struct User {
|
|
|
|
pub picture: String,
|
|
|
|
pub picture: String,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
pub struct UserExtractError(http::StatusCode);
|
|
|
|
pub enum UserExtractError {
|
|
|
|
|
|
|
|
InternalServerError(anyhow::Error),
|
|
|
|
|
|
|
|
Unauthorized,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
impl IntoResponse for UserExtractError {
|
|
|
|
impl IntoResponse for UserExtractError {
|
|
|
|
fn into_response(self) -> Response {
|
|
|
|
fn into_response(self) -> Response {
|
|
|
|
match self.0 {
|
|
|
|
match self {
|
|
|
|
StatusCode::UNAUTHORIZED => { Redirect::temporary("/auth/login").into_response() }
|
|
|
|
UserExtractError::InternalServerError(err) => AppError::from(err).into_response(),
|
|
|
|
StatusCode::FORBIDDEN => { StatusCode::FORBIDDEN.into_response() }
|
|
|
|
UserExtractError::Unauthorized => { Redirect::temporary("/auth/login").into_response() }
|
|
|
|
_ => StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
impl From<(http::StatusCode, &'static str)> for UserExtractError {
|
|
|
|
|
|
|
|
fn from(value: (StatusCode, &'static str)) -> Self {
|
|
|
|
|
|
|
|
Self(value.0)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
impl From<tower_sessions::session::Error> for UserExtractError {
|
|
|
|
|
|
|
|
fn from(_value: tower_sessions::session::Error) -> Self {
|
|
|
|
|
|
|
|
Self(StatusCode::INTERNAL_SERVER_ERROR)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[async_trait]
|
|
|
|
#[async_trait]
|
|
|
|
impl<S> FromRequestParts<S> for User
|
|
|
|
impl<S> FromRequestParts<S> for User
|
|
|
|
where
|
|
|
|
where
|
|
|
|
@ -175,11 +185,11 @@ where
|
|
|
|
|
|
|
|
|
|
|
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
|
|
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
|
|
|
let session = Session::from_request_parts(parts, state).await
|
|
|
|
let session = Session::from_request_parts(parts, state).await
|
|
|
|
.map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))?;
|
|
|
|
.map_err(|_| UserExtractError::InternalServerError(anyhow!("session from parts failed")))?;
|
|
|
|
|
|
|
|
|
|
|
|
let user: User = session.get(USER_SESSION).await
|
|
|
|
let user = session.get(USER_SESSION).await
|
|
|
|
.map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))?
|
|
|
|
.map_err(|e| UserExtractError::InternalServerError(anyhow::Error::from(e)))?
|
|
|
|
.ok_or(UserExtractError(StatusCode::UNAUTHORIZED))?;
|
|
|
|
.ok_or(UserExtractError::Unauthorized)?;
|
|
|
|
|
|
|
|
|
|
|
|
Ok(user)
|
|
|
|
Ok(user)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|