mirror of
https://gitlab.com/famedly/conduit.git
synced 2025-04-22 14:10:16 +03:00
Merge branch 'sso-oidc' into 'next'
Single Sign-On via OIDC/OAuth2 (attempt #2) Closes #134 See merge request famedly/conduit!676
This commit is contained in:
commit
da543a726e
21 changed files with 1104 additions and 46 deletions
16
Cargo.toml
16
Cargo.toml
|
@ -35,13 +35,17 @@ axum = { version = "0.7", default-features = false, features = [
|
|||
"json",
|
||||
"matched-path",
|
||||
], optional = true }
|
||||
axum-extra = { version = "0.9", features = ["typed-header"] }
|
||||
axum-extra = { version = "0.9", features = ["cookie", "typed-header"] }
|
||||
axum-server = { version = "0.6", features = ["tls-rustls"] }
|
||||
tower = { version = "0.4.13", features = ["util"] }
|
||||
tower-http = { version = "0.5", features = [
|
||||
"add-extension",
|
||||
"cors",
|
||||
"follow-redirect",
|
||||
"map-request-body",
|
||||
"sensitive-headers",
|
||||
"set-header",
|
||||
"timeout",
|
||||
"trace",
|
||||
"util",
|
||||
] }
|
||||
|
@ -172,6 +176,16 @@ optional = true
|
|||
package = "rust-rocksdb"
|
||||
version = "0.25"
|
||||
|
||||
[dependencies.mas-http]
|
||||
features = ["client"]
|
||||
git = "https://github.com/matrix-org/matrix-authentication-service"
|
||||
rev = "fbc360d1a94ef2ebf63d979bb403228a700f43c8"
|
||||
|
||||
[dependencies.mas-oidc-client]
|
||||
features = []
|
||||
git = "https://github.com/matrix-org/matrix-authentication-service"
|
||||
rev = "fbc360d1a94ef2ebf63d979bb403228a700f43c8"
|
||||
|
||||
[target.'cfg(unix)'.dependencies]
|
||||
nix = { version = "0.28", features = ["resource"] }
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ Conduit's configuration file is divided into the following sections:
|
|||
- [Global](#global)
|
||||
- [TLS](#tls)
|
||||
- [Proxy](#proxy)
|
||||
- [SSO (Single Sign-On)](#sso)
|
||||
|
||||
|
||||
## Global
|
||||
|
@ -111,3 +112,20 @@ exclude = ["*.clearnet.onion"]
|
|||
[global]
|
||||
{{#include ../conduit-example.toml:22:}}
|
||||
```
|
||||
|
||||
### SSO (Single Sign-On)
|
||||
|
||||
Authentication through SSO instead of a password can be enabled by configuring OIDC (OpenID Connect) identity providers.
|
||||
Identity providers using OAuth such as Github are not supported yet.
|
||||
|
||||
> **Note:** The `*` symbol indicates that the field is required, and the values in **parentheses** are the possible values
|
||||
|
||||
| Field | Type | Description | Default |
|
||||
| --- | --- | --- | --- |
|
||||
| `issuer`* | `Url` | The issuer URL. | N/A |
|
||||
| `name` | `string` | The name displayed on fallback pages. | `issuer` |
|
||||
| `icon` | `Url` OR `MxcUri` | The icon displayed on fallback pages. | N/A |
|
||||
| `scopes` | `array` | The scopes used to obtain extra claims which can be used for templates. | `["openid"]` |
|
||||
| `client_id`* | `string` | The provider-supplied, unique ID for the client. | N/A |
|
||||
| `client_secret`* | `string` | The provider-supplied, unique ID for the client. | N/A |
|
||||
| `authentication_method`* | `"basic" OR "post"` | The method used for client authentication. | N/A |
|
||||
|
|
|
@ -100,6 +100,12 @@ pub async fn upload_signing_keys_route(
|
|||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
let master_key = services()
|
||||
.users
|
||||
.get_master_key(Some(sender_user), sender_user, &|other| {
|
||||
sender_user == other
|
||||
})?;
|
||||
|
||||
// UIAA
|
||||
let mut uiaainfo = UiaaInfo {
|
||||
flows: vec![AuthFlow {
|
||||
|
@ -111,7 +117,15 @@ pub async fn upload_signing_keys_route(
|
|||
auth_error: None,
|
||||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
if let (Some(master_key), None) = (&body.master_key, master_key) {
|
||||
services().users.add_cross_signing_keys(
|
||||
sender_user,
|
||||
master_key,
|
||||
&body.self_signing_key,
|
||||
&body.user_signing_key,
|
||||
true,
|
||||
)?;
|
||||
} else if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
|
@ -130,16 +144,6 @@ pub async fn upload_signing_keys_route(
|
|||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
|
||||
if let Some(master_key) = &body.master_key {
|
||||
services().users.add_cross_signing_keys(
|
||||
sender_user,
|
||||
master_key,
|
||||
&body.self_signing_key,
|
||||
&body.user_signing_key,
|
||||
true, // notify so that other users see the new keys
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(upload_signing_keys::v3::Response {})
|
||||
}
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ mod room;
|
|||
mod search;
|
||||
mod session;
|
||||
mod space;
|
||||
mod sso;
|
||||
mod state;
|
||||
mod sync;
|
||||
mod tag;
|
||||
|
@ -60,6 +61,7 @@ pub use room::*;
|
|||
pub use search::*;
|
||||
pub use session::*;
|
||||
pub use space::*;
|
||||
pub use sso::*;
|
||||
pub use state::*;
|
||||
pub use sync::*;
|
||||
pub use tag::*;
|
||||
|
@ -76,3 +78,5 @@ pub const DEVICE_ID_LENGTH: usize = 10;
|
|||
pub const TOKEN_LENGTH: usize = 32;
|
||||
pub const SESSION_ID_LENGTH: usize = 32;
|
||||
pub const AUTO_GEN_PASSWORD_LENGTH: usize = 15;
|
||||
pub const AUTH_SESSION_EXPIRATION_SECS: u64 = 60 * 5;
|
||||
pub const LOGIN_TOKEN_EXPIRATION_SECS: u64 = 15;
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
use crate::{service::sso::LoginToken, services, utils, Error, Result, Ruma};
|
||||
use jsonwebtoken::{Algorithm, Validation};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
|
@ -24,10 +25,19 @@ struct Claims {
|
|||
pub async fn get_login_types_route(
|
||||
_body: Ruma<get_login_types::v3::Request>,
|
||||
) -> Result<get_login_types::v3::Response> {
|
||||
Ok(get_login_types::v3::Response::new(vec![
|
||||
let identity_providers: Vec<_> = services().sso.login_type().collect();
|
||||
let mut flows = vec![
|
||||
get_login_types::v3::LoginType::Password(Default::default()),
|
||||
get_login_types::v3::LoginType::ApplicationService(Default::default()),
|
||||
]))
|
||||
];
|
||||
|
||||
if !identity_providers.is_empty() {
|
||||
flows.push(get_login_types::v3::LoginType::Sso(
|
||||
get_login_types::v3::SsoLoginType { identity_providers },
|
||||
));
|
||||
}
|
||||
|
||||
Ok(get_login_types::v3::Response::new(flows))
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/r0/login`
|
||||
|
@ -101,35 +111,64 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
|||
user_id
|
||||
}
|
||||
login::v3::LoginInfo::Token(login::v3::Token { token }) => {
|
||||
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() {
|
||||
let token = jsonwebtoken::decode::<Claims>(
|
||||
token,
|
||||
jwt_decoding_key,
|
||||
&jsonwebtoken::Validation::default(),
|
||||
)
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?;
|
||||
let username = token.claims.sub.to_lowercase();
|
||||
let user_id =
|
||||
UserId::parse_with_server_name(username, services().globals.server_name())
|
||||
.map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
||||
})?;
|
||||
match (
|
||||
services().globals.jwt_decoding_key(),
|
||||
services().globals.config.idps.is_empty(),
|
||||
) {
|
||||
(_, false) => {
|
||||
let mut v = Validation::new(Algorithm::HS256);
|
||||
|
||||
if services().appservice.is_exclusive_user_id(&user_id).await {
|
||||
v.set_required_spec_claims(&["sub", "exp", "aud", "iss"]);
|
||||
v.validate_aud = false;
|
||||
v.validate_nbf = false;
|
||||
|
||||
services()
|
||||
.globals
|
||||
.validate_claims::<LoginToken>(token, Some(&v))
|
||||
.map(LoginToken::audience)
|
||||
.map_err(|e| {
|
||||
tracing::warn!("Invalid token: {}", e);
|
||||
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid token.")
|
||||
})?
|
||||
}
|
||||
(Some(jwt_decoding_key), _) => {
|
||||
let token = jsonwebtoken::decode::<Claims>(
|
||||
token,
|
||||
jwt_decoding_key,
|
||||
&Validation::default(),
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.")
|
||||
})?;
|
||||
let username = token.claims.sub.to_lowercase();
|
||||
let user_id =
|
||||
UserId::parse_with_server_name(username, services().globals.server_name())
|
||||
.map_err(|_| {
|
||||
Error::BadRequest(
|
||||
ErrorKind::InvalidUsername,
|
||||
"Username is invalid.",
|
||||
)
|
||||
})?;
|
||||
|
||||
if services().appservice.is_exclusive_user_id(&user_id).await {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Exclusive,
|
||||
"User id reserved by appservice.",
|
||||
));
|
||||
}
|
||||
|
||||
user_id
|
||||
}
|
||||
(None, _) => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Exclusive,
|
||||
"User id reserved by appservice.",
|
||||
ErrorKind::Unknown,
|
||||
"Token login is not supported (server has no jwt decoding key).",
|
||||
));
|
||||
}
|
||||
|
||||
user_id
|
||||
} else {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Token login is not supported (server has no jwt decoding key).",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService {
|
||||
identifier,
|
||||
user,
|
||||
|
|
473
src/api/client_server/sso.rs
Normal file
473
src/api/client_server/sso.rs
Normal file
|
@ -0,0 +1,473 @@
|
|||
use std::{borrow::Borrow, collections::HashMap, iter::Iterator, time::SystemTime};
|
||||
|
||||
use crate::{
|
||||
config::IdpConfig,
|
||||
service::sso::{
|
||||
LoginToken, ValidationData, SSO_AUTH_EXPIRATION_SECS, SSO_SESSION_COOKIE, SUBJECT_CLAIM_KEY,
|
||||
},
|
||||
services, utils, Error, Result, Ruma,
|
||||
};
|
||||
use futures_util::TryFutureExt;
|
||||
use mas_oidc_client::{
|
||||
requests::{
|
||||
authorization_code::{self, AuthorizationRequestData},
|
||||
jose::{self, JwtVerificationData},
|
||||
userinfo,
|
||||
},
|
||||
types::{
|
||||
client_credentials::ClientCredentials,
|
||||
iana::jose::JsonWebSignatureAlg,
|
||||
requests::{AccessTokenResponse, AuthorizationResponse},
|
||||
},
|
||||
};
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
media::create_content,
|
||||
session::{sso_login, sso_login_with_provider},
|
||||
},
|
||||
events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType},
|
||||
push, UserId,
|
||||
};
|
||||
use serde_json::Value;
|
||||
use tracing::{error, info, warn};
|
||||
use url::Url;
|
||||
|
||||
pub const CALLBACK_PATH: &str = "/_matrix/client/unstable/conduit/callback";
|
||||
|
||||
/// # `GET /_matrix/client/v3/login/sso/redirect`
|
||||
///
|
||||
/// Redirect the user to the SSO interfa.
|
||||
/// TODO: this should be removed once Ruma supports trailing slashes.
|
||||
pub async fn get_sso_redirect_route(
|
||||
Ruma {
|
||||
body,
|
||||
sender_user,
|
||||
sender_device,
|
||||
sender_servername,
|
||||
json_body,
|
||||
..
|
||||
}: Ruma<sso_login::v3::Request>,
|
||||
) -> Result<sso_login::v3::Response> {
|
||||
let sso_login_with_provider::v3::Response { location, cookie } =
|
||||
get_sso_redirect_with_provider_route(
|
||||
Ruma {
|
||||
body: sso_login_with_provider::v3::Request::new(
|
||||
Default::default(),
|
||||
body.redirect_url,
|
||||
),
|
||||
sender_user,
|
||||
sender_device,
|
||||
sender_servername,
|
||||
json_body,
|
||||
appservice_info: None,
|
||||
}
|
||||
.into(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(sso_login::v3::Response { location, cookie })
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}`
|
||||
///
|
||||
/// Redirects the user to the SSO interface.
|
||||
pub async fn get_sso_redirect_with_provider_route(
|
||||
body: Ruma<sso_login_with_provider::v3::Request>,
|
||||
) -> Result<sso_login_with_provider::v3::Response> {
|
||||
let idp_ids: Vec<&str> = services()
|
||||
.globals
|
||||
.config
|
||||
.idps
|
||||
.iter()
|
||||
.map(Borrow::borrow)
|
||||
.collect();
|
||||
|
||||
let provider = match &*idp_ids {
|
||||
[] => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::forbidden(),
|
||||
"Single Sign-On is disabled.",
|
||||
));
|
||||
}
|
||||
[idp_id] => services().sso.get(idp_id).expect("we know it exists"),
|
||||
[_, ..] => services().sso.get(&body.idp_id).ok_or_else(|| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Unknown identity provider.")
|
||||
})?,
|
||||
};
|
||||
|
||||
let redirect_url = body
|
||||
.redirect_url
|
||||
.parse::<Url>()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid redirect_url."))?;
|
||||
|
||||
let mut callback = services()
|
||||
.globals
|
||||
.well_known_client()
|
||||
.parse::<Url>()
|
||||
.map_err(|_| Error::bad_config("Invalid well_known_client url."))?;
|
||||
callback.set_path(CALLBACK_PATH);
|
||||
|
||||
let (auth_url, validation_data) = authorization_code::build_authorization_url(
|
||||
provider.metadata.authorization_endpoint().clone(),
|
||||
AuthorizationRequestData::new(
|
||||
provider.config.client_id.clone(),
|
||||
provider.config.scopes.clone(),
|
||||
callback,
|
||||
),
|
||||
&mut StdRng::from_entropy(),
|
||||
)
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to build authorization_url."))?;
|
||||
|
||||
let signed = services().globals.sign_claims(&ValidationData::new(
|
||||
Borrow::<str>::borrow(provider).to_owned(),
|
||||
redirect_url.to_string(),
|
||||
validation_data,
|
||||
));
|
||||
|
||||
Ok(sso_login_with_provider::v3::Response {
|
||||
location: auth_url.to_string(),
|
||||
cookie: Some(
|
||||
utils::build_cookie(
|
||||
SSO_SESSION_COOKIE,
|
||||
&signed,
|
||||
CALLBACK_PATH,
|
||||
Some(SSO_AUTH_EXPIRATION_SECS),
|
||||
)
|
||||
.to_string(),
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_conduit/client/sso/callback`
|
||||
///
|
||||
/// Validate the authorization response received from the identity provider.
|
||||
/// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect.
|
||||
/// If this is the first login, register the user, possibly interactively through a fallback page.
|
||||
pub async fn handle_callback_route(
|
||||
body: Ruma<sso_callback::Request>,
|
||||
) -> Result<sso_login_with_provider::v3::Response> {
|
||||
let sso_callback::Request {
|
||||
response:
|
||||
AuthorizationResponse {
|
||||
code,
|
||||
access_token: _,
|
||||
token_type: _,
|
||||
id_token: _,
|
||||
expires_in: _,
|
||||
},
|
||||
cookie,
|
||||
} = body.body;
|
||||
|
||||
let ValidationData {
|
||||
provider,
|
||||
redirect_url,
|
||||
inner: validation_data,
|
||||
} = services()
|
||||
.globals
|
||||
.validate_claims(&cookie, None)
|
||||
.map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid value for session cookie.")
|
||||
})?;
|
||||
|
||||
let provider = services().sso.get(&provider).ok_or_else(|| {
|
||||
Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Unknown provider for session cookie.",
|
||||
)
|
||||
})?;
|
||||
|
||||
let IdpConfig {
|
||||
client_id,
|
||||
client_secret,
|
||||
auth_method,
|
||||
..
|
||||
} = provider.config.clone();
|
||||
|
||||
let credentials = match &*auth_method {
|
||||
"basic" => ClientCredentials::ClientSecretBasic {
|
||||
client_id,
|
||||
client_secret,
|
||||
},
|
||||
"post" => ClientCredentials::ClientSecretPost {
|
||||
client_id,
|
||||
client_secret,
|
||||
},
|
||||
_ => todo!(),
|
||||
};
|
||||
let ref jwks = jose::fetch_jwks(services().sso.service(), provider.metadata.jwks_uri())
|
||||
.await
|
||||
.map_err(|_| Error::bad_config("Failed to fetch signing keys for token endpoint."))?;
|
||||
let idt_verification_data = Some(JwtVerificationData {
|
||||
jwks,
|
||||
issuer: &provider.config.issuer,
|
||||
client_id: &provider.config.client_id,
|
||||
signing_algorithm: &JsonWebSignatureAlg::Rs256,
|
||||
});
|
||||
|
||||
let (
|
||||
AccessTokenResponse {
|
||||
access_token,
|
||||
refresh_token: _,
|
||||
token_type: _,
|
||||
expires_in: _,
|
||||
scope: _,
|
||||
..
|
||||
},
|
||||
Some(id_token),
|
||||
) = authorization_code::access_token_with_authorization_code(
|
||||
services().sso.service(),
|
||||
credentials,
|
||||
provider.metadata.token_endpoint(),
|
||||
code.unwrap_or_default(),
|
||||
validation_data,
|
||||
idt_verification_data,
|
||||
SystemTime::now().into(),
|
||||
&mut StdRng::from_entropy(),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| Error::bad_config("Failed to fetch access token."))?
|
||||
else {
|
||||
unreachable!("ID token should never be empty")
|
||||
};
|
||||
|
||||
let mut userinfo = HashMap::default();
|
||||
if let Some(endpoint) = provider.metadata.userinfo_endpoint.as_ref() {
|
||||
userinfo = userinfo::fetch_userinfo(
|
||||
services().sso.service(),
|
||||
endpoint,
|
||||
&access_token,
|
||||
None,
|
||||
&id_token,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to fetch claims for userinfo endpoint: {:?}", e);
|
||||
|
||||
Error::bad_config("Failed to fetch claims for userinfo endpoint.")
|
||||
})?;
|
||||
}
|
||||
|
||||
let (_, id_token) = id_token.into_parts();
|
||||
|
||||
info!("userinfo: {:?}", &userinfo);
|
||||
info!("id_token: {:?}", &id_token);
|
||||
|
||||
let subject = match id_token.get(SUBJECT_CLAIM_KEY) {
|
||||
Some(Value::String(s)) => s.to_owned(),
|
||||
Some(Value::Number(n)) => n.to_string(),
|
||||
value => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
value
|
||||
.map(|_| {
|
||||
error!("Subject claim is missing from ID token: {id_token:?}");
|
||||
|
||||
"Subject claim is missing from ID token."
|
||||
})
|
||||
.unwrap_or("Subject claim should be a string or number."),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let user_id = match services()
|
||||
.sso
|
||||
.user_from_subject(Borrow::<str>::borrow(provider), &subject)?
|
||||
{
|
||||
Some(user_id) => user_id,
|
||||
None => {
|
||||
let mut localpart = subject.clone();
|
||||
|
||||
let user_id = loop {
|
||||
match UserId::parse_with_server_name(&*localpart, services().globals.server_name())
|
||||
.map(|user_id| {
|
||||
(
|
||||
user_id.clone(),
|
||||
services().users.exists(&user_id).unwrap_or(true),
|
||||
)
|
||||
}) {
|
||||
Ok((user_id, false)) => break user_id,
|
||||
_ => {
|
||||
let n: u8 = rand::thread_rng().gen();
|
||||
|
||||
localpart = format!("{}{}", localpart, n % 10);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
services().users.set_placeholder_password(&user_id)?;
|
||||
let displayname = id_token
|
||||
.get("preferred_username")
|
||||
.or(id_token.get("nickname"));
|
||||
let mut displayname = displayname
|
||||
.as_deref()
|
||||
.map(Value::as_str)
|
||||
.flatten()
|
||||
.unwrap_or(user_id.localpart())
|
||||
.to_owned();
|
||||
|
||||
// If enabled append lightning bolt to display name (default true)
|
||||
if services().globals.enable_lightning_bolt() {
|
||||
displayname.push_str(" ⚡️");
|
||||
}
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_displayname(&user_id, Some(displayname.clone()))?;
|
||||
|
||||
if let Some(Value::String(url)) = userinfo.get("picture").or(id_token.get("picture")) {
|
||||
let req = services()
|
||||
.globals
|
||||
.default_client()
|
||||
.get(url)
|
||||
.send()
|
||||
.and_then(reqwest::Response::bytes);
|
||||
|
||||
if let Ok(file) = req.await {
|
||||
let _ = crate::api::client_server::create_content_route(Ruma {
|
||||
body: create_content::v3::Request::new(file.to_vec()),
|
||||
sender_user: None,
|
||||
sender_device: None,
|
||||
sender_servername: None,
|
||||
json_body: None,
|
||||
appservice_info: None,
|
||||
})
|
||||
.await
|
||||
.and_then(|res| {
|
||||
tracing::info!("successfully imported avatar for {}", &user_id);
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_avatar_url(&user_id, Some(res.content_uri))
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Initial account data
|
||||
services().account_data.update(
|
||||
None,
|
||||
&user_id,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
|
||||
content: ruma::events::push_rules::PushRulesEventContent {
|
||||
global: push::Ruleset::server_default(&user_id),
|
||||
},
|
||||
})
|
||||
.expect("to json always works"),
|
||||
)?;
|
||||
|
||||
info!("New user {} registered on this server.", user_id);
|
||||
services()
|
||||
.admin
|
||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"New user {user_id} registered on this server."
|
||||
)));
|
||||
|
||||
if let Some(admin_room) = services().admin.get_admin_room()? {
|
||||
if services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.room_joined_count(&admin_room)?
|
||||
== Some(1)
|
||||
{
|
||||
services()
|
||||
.admin
|
||||
.make_user_admin(&user_id, displayname.to_owned())
|
||||
.await?;
|
||||
|
||||
warn!("Granting {} admin privileges as the first user", user_id);
|
||||
}
|
||||
}
|
||||
|
||||
user_id
|
||||
}
|
||||
};
|
||||
|
||||
let signed = services().globals.sign_claims(&LoginToken::new(
|
||||
Borrow::<str>::borrow(provider).to_owned(),
|
||||
user_id,
|
||||
));
|
||||
|
||||
let mut redirect_url: Url = redirect_url.parse().expect("");
|
||||
redirect_url
|
||||
.query_pairs_mut()
|
||||
.append_pair("loginToken", &signed);
|
||||
|
||||
Ok(sso_login_with_provider::v3::Response {
|
||||
location: redirect_url.to_string(),
|
||||
cookie: Some(utils::build_cookie(SSO_SESSION_COOKIE, "", CALLBACK_PATH, None).to_string()),
|
||||
})
|
||||
}
|
||||
|
||||
mod sso_callback {
|
||||
use axum_extra::headers::{self, HeaderMapExt};
|
||||
use http::Method;
|
||||
use mas_oidc_client::types::requests::AuthorizationResponse;
|
||||
use ruma::{
|
||||
api::{
|
||||
client::{session::sso_login_with_provider, Error},
|
||||
error::{FromHttpRequestError, HeaderDeserializationError},
|
||||
IncomingRequest, Metadata,
|
||||
},
|
||||
metadata,
|
||||
};
|
||||
|
||||
use crate::service::sso::SSO_SESSION_COOKIE;
|
||||
|
||||
pub const METADATA: Metadata = metadata! {
|
||||
method: GET,
|
||||
rate_limited: false,
|
||||
authentication: None,
|
||||
history: {
|
||||
1.0 => "/_matrix/client/unstable/conduit/callback",
|
||||
}
|
||||
};
|
||||
|
||||
pub struct Request {
|
||||
pub response: AuthorizationResponse,
|
||||
pub cookie: String,
|
||||
}
|
||||
|
||||
impl IncomingRequest for Request {
|
||||
type EndpointError = Error;
|
||||
type OutgoingResponse = sso_login_with_provider::v3::Response;
|
||||
|
||||
const METADATA: Metadata = METADATA;
|
||||
|
||||
fn try_from_http_request<B, S>(
|
||||
req: http::Request<B>,
|
||||
_path_args: &[S],
|
||||
) -> Result<Self, FromHttpRequestError>
|
||||
where
|
||||
B: AsRef<[u8]>,
|
||||
S: AsRef<str>,
|
||||
{
|
||||
if !(req.method() == METADATA.method
|
||||
|| req.method() == Method::HEAD && METADATA.method == Method::GET)
|
||||
{
|
||||
return Err(FromHttpRequestError::MethodMismatch {
|
||||
expected: METADATA.method,
|
||||
received: req.method().clone(),
|
||||
});
|
||||
}
|
||||
|
||||
let response: AuthorizationResponse =
|
||||
serde_html_form::from_str(req.uri().query().unwrap_or(""))?;
|
||||
|
||||
let Some(cookie) = req
|
||||
.headers()
|
||||
.typed_get()
|
||||
.and_then(|cookie: headers::Cookie| {
|
||||
cookie.get(SSO_SESSION_COOKIE).map(str::to_owned)
|
||||
})
|
||||
else {
|
||||
return Err(HeaderDeserializationError::MissingHeader(
|
||||
"Cookie".to_owned(),
|
||||
))?;
|
||||
};
|
||||
|
||||
Ok(Self { response, cookie })
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,16 +1,27 @@
|
|||
use std::{
|
||||
collections::BTreeMap,
|
||||
borrow::Borrow,
|
||||
collections::{BTreeMap, HashSet},
|
||||
fmt,
|
||||
hash::{Hash, Hasher},
|
||||
net::{IpAddr, Ipv4Addr},
|
||||
};
|
||||
|
||||
use ruma::{OwnedServerName, RoomVersionId};
|
||||
use serde::{de::IgnoredAny, Deserialize};
|
||||
use figment::value::{Dict, Value};
|
||||
use mas_oidc_client::types::{client_credentials::ClientCredentials, scope::Scope};
|
||||
use ruma::{
|
||||
api::client::session::get_login_types::v3::IdentityProvider, OwnedServerName, RoomVersionId,
|
||||
};
|
||||
use serde::{
|
||||
de::{self, IgnoredAny},
|
||||
Deserialize, Deserializer, Serialize,
|
||||
};
|
||||
use tracing::warn;
|
||||
use url::Url;
|
||||
|
||||
mod proxy;
|
||||
|
||||
use crate::{Error, Result};
|
||||
|
||||
use self::proxy::ProxyConfig;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
|
@ -67,6 +78,8 @@ pub struct Config {
|
|||
pub tracing_flame: bool,
|
||||
#[serde(default)]
|
||||
pub proxy: ProxyConfig,
|
||||
#[serde(default, deserialize_with = "deserialize_providers")]
|
||||
pub idps: HashSet<IdpConfig>,
|
||||
pub jwt_secret: Option<String>,
|
||||
#[serde(default = "default_trusted_servers")]
|
||||
pub trusted_servers: Vec<OwnedServerName>,
|
||||
|
@ -101,6 +114,27 @@ pub struct WellKnownConfig {
|
|||
pub server: Option<OwnedServerName>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct IdpConfig {
|
||||
pub issuer: String,
|
||||
#[serde(flatten)]
|
||||
pub inner: IdentityProvider,
|
||||
#[serde(deserialize_with = "deserialize_scopes")]
|
||||
pub scopes: Scope,
|
||||
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
pub auth_method: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||||
pub struct Template {
|
||||
pub localpart: Option<String>,
|
||||
pub displayname: Option<String>,
|
||||
pub avatar_url: Option<String>,
|
||||
pub email: Option<String>,
|
||||
}
|
||||
|
||||
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
|
||||
|
||||
impl Config {
|
||||
|
@ -244,6 +278,49 @@ impl fmt::Display for Config {
|
|||
}
|
||||
}
|
||||
|
||||
impl Borrow<str> for IdpConfig {
|
||||
fn borrow(&self) -> &str {
|
||||
&self.inner.id
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for IdpConfig {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.inner.id == other.inner.id
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for IdpConfig {}
|
||||
|
||||
impl Hash for IdpConfig {
|
||||
fn hash<H: Hasher>(&self, hasher: &mut H) {
|
||||
self.inner.id.hash(hasher)
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<ClientCredentials> for IdpConfig {
|
||||
fn into(self) -> ClientCredentials {
|
||||
let IdpConfig {
|
||||
client_id,
|
||||
client_secret,
|
||||
auth_method,
|
||||
..
|
||||
} = self;
|
||||
|
||||
match &*auth_method {
|
||||
"basic" => ClientCredentials::ClientSecretBasic {
|
||||
client_id,
|
||||
client_secret,
|
||||
},
|
||||
"post" => ClientCredentials::ClientSecretPost {
|
||||
client_id,
|
||||
client_secret,
|
||||
},
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn false_fn() -> bool {
|
||||
false
|
||||
}
|
||||
|
@ -312,3 +389,46 @@ fn default_openid_token_ttl() -> u64 {
|
|||
pub fn default_default_room_version() -> RoomVersionId {
|
||||
RoomVersionId::V10
|
||||
}
|
||||
|
||||
fn deserialize_scopes<'de, D>(deserializer: D) -> Result<Scope, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let scopes = <Vec<String>>::deserialize(deserializer)?;
|
||||
|
||||
scopes.join(" ").parse().map_err(de::Error::custom)
|
||||
}
|
||||
|
||||
fn deserialize_providers<'de, D>(deserializer: D) -> Result<HashSet<IdpConfig>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let mut result = HashSet::new();
|
||||
let dict = Dict::deserialize(deserializer)
|
||||
.map(Dict::into_iter)
|
||||
.map_err(de::Error::custom)?;
|
||||
warn!(?dict);
|
||||
|
||||
for (name, value) in dict {
|
||||
let tag = value.tag();
|
||||
|
||||
let Some(dict) = value.into_dict() else {
|
||||
return Err(de::Error::custom(Error::bad_config(
|
||||
"Invalid SSO configuration. ",
|
||||
)));
|
||||
};
|
||||
|
||||
let id = String::from("id");
|
||||
let name = name.parse().map_err(de::Error::custom)?;
|
||||
|
||||
let dict = Some((id, name)).into_iter().chain(dict).collect();
|
||||
|
||||
result.insert(
|
||||
Value::Dict(tag, dict)
|
||||
.deserialize()
|
||||
.map_err(de::Error::custom)?,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ mod media;
|
|||
mod pusher;
|
||||
mod rooms;
|
||||
mod sending;
|
||||
mod sso;
|
||||
mod transaction_ids;
|
||||
mod uiaa;
|
||||
mod users;
|
||||
|
|
29
src/database/key_value/sso.rs
Normal file
29
src/database/key_value/sso.rs
Normal file
|
@ -0,0 +1,29 @@
|
|||
use ruma::{OwnedUserId, UserId};
|
||||
|
||||
use crate::{service, utils, Error, KeyValueDatabase, Result};
|
||||
|
||||
impl service::sso::Data for KeyValueDatabase {
|
||||
fn save_subject(&self, provider: &str, user_id: &UserId, subject: &str) -> Result<()> {
|
||||
let mut key = provider.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(subject.as_bytes());
|
||||
|
||||
self.providersubjectid_userid.insert(&key, user_id.as_bytes())
|
||||
}
|
||||
|
||||
fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>> {
|
||||
let mut key = provider.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(subject.as_bytes());
|
||||
|
||||
self.providersubjectid_userid.get(&key)?.map_or(Ok(None), |bytes| {
|
||||
Some(
|
||||
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("User ID in claim_userid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("User ID in claim_userid is invalid.")),
|
||||
)
|
||||
.transpose()
|
||||
})
|
||||
}
|
||||
}
|
|
@ -119,6 +119,10 @@ impl service::users::Data for KeyValueDatabase {
|
|||
}
|
||||
}
|
||||
|
||||
fn set_placeholder_password(&self, user_id: &UserId) -> Result<()> {
|
||||
self.userid_password.insert(user_id.as_bytes(), b"0xff")
|
||||
}
|
||||
|
||||
/// Returns the displayname of a user on this homeserver.
|
||||
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
self.userid_displayname
|
||||
|
|
|
@ -50,7 +50,6 @@ pub struct KeyValueDatabase {
|
|||
pub(super) userdeviceid_metadata: Arc<dyn KvTree>, // This is also used to check if a device exists
|
||||
pub(super) userid_devicelistversion: Arc<dyn KvTree>, // DevicelistVersion = u64
|
||||
pub(super) token_userdeviceid: Arc<dyn KvTree>,
|
||||
|
||||
pub(super) onetimekeyid_onetimekeys: Arc<dyn KvTree>, // OneTimeKeyId = UserId + DeviceKeyId
|
||||
pub(super) userid_lastonetimekeyupdate: Arc<dyn KvTree>, // LastOneTimeKeyUpdate = Count
|
||||
pub(super) keychangeid_userid: Arc<dyn KvTree>, // KeyChangeId = UserId/RoomId + Count
|
||||
|
@ -64,6 +63,9 @@ pub struct KeyValueDatabase {
|
|||
|
||||
pub(super) todeviceid_events: Arc<dyn KvTree>, // ToDeviceId = UserId + DeviceId + Count
|
||||
|
||||
pub(super) userid_providersubjectid: Arc<dyn KvTree>,
|
||||
pub(super) providersubjectid_userid: Arc<dyn KvTree>,
|
||||
|
||||
//pub uiaa: uiaa::Uiaa,
|
||||
pub(super) userdevicesessionid_uiaainfo: Arc<dyn KvTree>, // User-interactive authentication
|
||||
pub(super) userdevicesessionid_uiaarequest:
|
||||
|
@ -298,6 +300,9 @@ impl KeyValueDatabase {
|
|||
userfilterid_filter: builder.open_tree("userfilterid_filter")?,
|
||||
todeviceid_events: builder.open_tree("todeviceid_events")?,
|
||||
|
||||
userid_providersubjectid: builder.open_tree("userid_providersubjectid")?,
|
||||
providersubjectid_userid: builder.open_tree("providersubjectid_userid")?,
|
||||
|
||||
userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?,
|
||||
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
|
||||
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
|
||||
|
@ -1050,6 +1055,8 @@ impl KeyValueDatabase {
|
|||
|
||||
services().admin.start_handler();
|
||||
|
||||
services().sso.start_handler().await?;
|
||||
|
||||
// Set emergency access for the conduit user
|
||||
match set_emergency_access() {
|
||||
Ok(pwd_set) => {
|
||||
|
|
|
@ -292,6 +292,11 @@ fn routes(config: &Config) -> Router {
|
|||
.ruma_route(client_server::third_party_route)
|
||||
.ruma_route(client_server::request_3pid_management_token_via_email_route)
|
||||
.ruma_route(client_server::request_3pid_management_token_via_msisdn_route)
|
||||
.ruma_route(client_server::get_sso_redirect_route)
|
||||
.ruma_route(client_server::get_sso_redirect_with_provider_route)
|
||||
// The specification will likely never introduce any endpoint for handling authorization callbacks.
|
||||
// As a workaround, we use custom path that redirects the user to the default login handler.
|
||||
.ruma_route(client_server::handle_callback_route)
|
||||
.ruma_route(client_server::get_capabilities_route)
|
||||
.ruma_route(client_server::get_pushrules_all_route)
|
||||
.ruma_route(client_server::set_pushrule_route)
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
mod data;
|
||||
pub use data::{Data, SigningKeys};
|
||||
use ruma::{
|
||||
serde::Base64, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedEventId, OwnedRoomAliasId,
|
||||
OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId,
|
||||
serde::Base64, signatures::KeyPair, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedEventId,
|
||||
OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId,
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
||||
use crate::api::server_server::DestinationResponse;
|
||||
|
||||
|
@ -17,7 +18,7 @@ use ruma::{
|
|||
DeviceId, RoomVersionId, ServerName, UserId,
|
||||
};
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
collections::{BTreeMap, HashMap, HashSet},
|
||||
error::Error as StdError,
|
||||
fs,
|
||||
future::{self, Future},
|
||||
|
@ -37,6 +38,9 @@ use tracing::{error, info};
|
|||
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
|
||||
// https://github.com/rust-lang/rust/issues/104699
|
||||
const PROBLEMATIC_CONST: &[u8] = b"0xCAFEBABE";
|
||||
|
||||
type WellKnownMap = HashMap<OwnedServerName, DestinationResponse>;
|
||||
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
|
||||
type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
|
||||
|
@ -505,6 +509,36 @@ impl Service {
|
|||
self.config.well_known_client()
|
||||
}
|
||||
|
||||
pub fn sign_claims<S: Serialize>(&self, claims: &S) -> String {
|
||||
let key = jsonwebtoken::EncodingKey::from_secret(
|
||||
self.keypair().sign(PROBLEMATIC_CONST).as_bytes(),
|
||||
);
|
||||
|
||||
jsonwebtoken::encode(&jsonwebtoken::Header::default(), claims, &key)
|
||||
.expect("signing JWTs always works")
|
||||
}
|
||||
|
||||
/// Decode and validate a macaroon with this server's macaroon key.
|
||||
pub fn validate_claims<T: DeserializeOwned>(
|
||||
&self,
|
||||
token: &str,
|
||||
validation_data: Option<&jsonwebtoken::Validation>,
|
||||
) -> jsonwebtoken::errors::Result<T> {
|
||||
let key = jsonwebtoken::DecodingKey::from_secret(
|
||||
self.keypair().sign(PROBLEMATIC_CONST).as_bytes(),
|
||||
);
|
||||
|
||||
let mut v = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256);
|
||||
|
||||
// these validations are redundant as all JWTs are stored in cookies
|
||||
v.validate_exp = false;
|
||||
v.validate_nbf = false;
|
||||
v.required_spec_claims = HashSet::new();
|
||||
|
||||
jsonwebtoken::decode::<T>(token, &key, validation_data.unwrap_or(&v))
|
||||
.map(|data| data.claims)
|
||||
}
|
||||
|
||||
pub fn shutdown(&self) {
|
||||
self.shutdown.store(true, atomic::Ordering::Relaxed);
|
||||
// On shutdown
|
||||
|
|
|
@ -19,6 +19,7 @@ pub mod pdu;
|
|||
pub mod pusher;
|
||||
pub mod rooms;
|
||||
pub mod sending;
|
||||
pub mod sso;
|
||||
pub mod transaction_ids;
|
||||
pub mod uiaa;
|
||||
pub mod users;
|
||||
|
@ -35,6 +36,7 @@ pub struct Services {
|
|||
pub globals: globals::Service,
|
||||
pub key_backups: key_backups::Service,
|
||||
pub media: media::Service,
|
||||
pub sso: Arc<sso::Service>,
|
||||
pub sending: Arc<sending::Service>,
|
||||
}
|
||||
|
||||
|
@ -51,6 +53,7 @@ impl Services {
|
|||
+ key_backups::Data
|
||||
+ media::Data
|
||||
+ sending::Data
|
||||
+ sso::Data
|
||||
+ 'static,
|
||||
>(
|
||||
db: &'static D,
|
||||
|
@ -120,6 +123,7 @@ impl Services {
|
|||
key_backups: key_backups::Service { db },
|
||||
media: media::Service { db },
|
||||
sending: sending::Service::build(db, &config),
|
||||
sso: sso::Service::build(db)?,
|
||||
|
||||
globals: globals::Service::load(db, config)?,
|
||||
})
|
||||
|
|
9
src/service/sso/data.rs
Normal file
9
src/service/sso/data.rs
Normal file
|
@ -0,0 +1,9 @@
|
|||
use ruma::{OwnedUserId, UserId};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
fn save_subject(&self, provider: &str, user_id: &UserId, subject: &str) -> Result<()>;
|
||||
|
||||
fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>>;
|
||||
}
|
213
src/service/sso/mod.rs
Normal file
213
src/service/sso/mod.rs
Normal file
|
@ -0,0 +1,213 @@
|
|||
use std::{
|
||||
borrow::Borrow,
|
||||
collections::HashSet,
|
||||
hash::{Hash, Hasher},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
api::client_server::{LOGIN_TOKEN_EXPIRATION_SECS, TOKEN_LENGTH},
|
||||
config::IdpConfig,
|
||||
utils, Error, Result,
|
||||
};
|
||||
use futures_util::future::{self};
|
||||
use http::HeaderValue;
|
||||
use mas_oidc_client::{
|
||||
http_service::HttpService,
|
||||
requests::{authorization_code::AuthorizationValidationData, discovery},
|
||||
types::oidc::VerifiedProviderMetadata,
|
||||
};
|
||||
use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUserId, UserId};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::OnceCell;
|
||||
use tower::BoxError;
|
||||
use tower_http::{set_header::SetRequestHeaderLayer, ServiceBuilderExt};
|
||||
use tracing::error;
|
||||
use url::Url;
|
||||
|
||||
use crate::services;
|
||||
|
||||
mod data;
|
||||
pub use data::Data;
|
||||
|
||||
pub const SSO_AUTH_EXPIRATION_SECS: u64 = 60 * 60;
|
||||
pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 60 * 2;
|
||||
pub const SSO_SESSION_COOKIE: &str = "sso-auth";
|
||||
pub const SUBJECT_CLAIM_KEY: &str = "sub";
|
||||
|
||||
pub struct Service {
|
||||
db: &'static dyn Data,
|
||||
service: HttpService,
|
||||
providers: OnceCell<HashSet<Provider>>,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub fn build(db: &'static dyn Data) -> Result<Arc<Self>> {
|
||||
let client = tower::ServiceBuilder::new()
|
||||
.map_err(BoxError::from)
|
||||
.layer(tower_http::timeout::TimeoutLayer::new(
|
||||
std::time::Duration::from_secs(10),
|
||||
))
|
||||
.layer(mas_http::BytesToBodyRequestLayer)
|
||||
.layer(mas_http::BodyToBytesResponseLayer)
|
||||
.layer(SetRequestHeaderLayer::overriding(
|
||||
http::header::USER_AGENT,
|
||||
HeaderValue::from_static("conduit/0.9-alpha"),
|
||||
))
|
||||
.concurrency_limit(10)
|
||||
.follow_redirects()
|
||||
.service(mas_http::make_untraced_client());
|
||||
|
||||
Ok(Arc::new(Self {
|
||||
db,
|
||||
service: HttpService::new(client),
|
||||
providers: OnceCell::new(),
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn service(&self) -> &HttpService {
|
||||
&self.service
|
||||
}
|
||||
|
||||
pub async fn start_handler(&self) -> Result<()> {
|
||||
let providers = services().globals.config.idps.iter();
|
||||
|
||||
self.providers
|
||||
.get_or_try_init(|| async move {
|
||||
future::try_join_all(providers.map(Provider::fetch_metadata))
|
||||
.await
|
||||
.map(Vec::into_iter)
|
||||
.map(HashSet::from_iter)
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get(&self, provider: &str) -> Option<&Provider> {
|
||||
let providers = self.providers.get().expect("");
|
||||
|
||||
providers.get(provider)
|
||||
}
|
||||
|
||||
pub fn login_type(&self) -> impl Iterator<Item = IdentityProvider> + '_ {
|
||||
let providers = self.providers.get().expect("");
|
||||
|
||||
providers.iter().map(|p| p.config.inner.clone())
|
||||
}
|
||||
|
||||
pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>> {
|
||||
self.db.user_from_subject(provider, subject)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Provider {
|
||||
pub config: &'static IdpConfig,
|
||||
pub metadata: VerifiedProviderMetadata,
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
pub async fn fetch_metadata(config: &'static IdpConfig) -> Result<Self> {
|
||||
discovery::discover(services().sso.service(), &config.issuer)
|
||||
.await
|
||||
.map(|metadata| Provider { config, metadata })
|
||||
.map_err(|e| {
|
||||
error!(
|
||||
"Failed to fetch identity provider metadata ({}): {}",
|
||||
&config.inner.id, e
|
||||
);
|
||||
|
||||
Error::bad_config("Failed to fetch identity provider metadata.")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Borrow<str> for Provider {
|
||||
fn borrow(&self) -> &str {
|
||||
self.config.borrow()
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Provider {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.config == other.config
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for Provider {}
|
||||
|
||||
impl Hash for Provider {
|
||||
fn hash<H: Hasher>(&self, hasher: &mut H) {
|
||||
self.config.hash(hasher)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
pub struct LoginToken {
|
||||
pub iss: String,
|
||||
pub aud: OwnedUserId,
|
||||
pub sub: String,
|
||||
pub exp: u64,
|
||||
}
|
||||
|
||||
impl LoginToken {
|
||||
pub fn new(provider: String, user_id: OwnedUserId) -> Self {
|
||||
Self {
|
||||
iss: provider,
|
||||
aud: user_id,
|
||||
sub: utils::random_string(TOKEN_LENGTH),
|
||||
exp: utils::millis_since_unix_epoch()
|
||||
.checked_add(LOGIN_TOKEN_EXPIRATION_SECS * 1000)
|
||||
.expect("time overflow"),
|
||||
}
|
||||
}
|
||||
pub fn audience(self) -> OwnedUserId {
|
||||
self.aud
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ValidationData {
|
||||
pub provider: String,
|
||||
pub redirect_url: String,
|
||||
#[serde(flatten, with = "AuthorizationValidationDataDef")]
|
||||
pub inner: AuthorizationValidationData,
|
||||
}
|
||||
|
||||
impl ValidationData {
|
||||
pub fn new(provider: String, redirect_url: String, inner: AuthorizationValidationData) -> Self {
|
||||
Self {
|
||||
provider,
|
||||
redirect_url,
|
||||
inner,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
#[serde(remote = "AuthorizationValidationData")]
|
||||
pub struct AuthorizationValidationDataDef {
|
||||
pub state: String,
|
||||
pub nonce: String,
|
||||
pub redirect_uri: Url,
|
||||
pub code_challenge_verifier: Option<String>,
|
||||
}
|
||||
|
||||
impl From<AuthorizationValidationData> for AuthorizationValidationDataDef {
|
||||
fn from(
|
||||
AuthorizationValidationData {
|
||||
state,
|
||||
nonce,
|
||||
redirect_uri,
|
||||
code_challenge_verifier,
|
||||
}: AuthorizationValidationData,
|
||||
) -> Self {
|
||||
Self {
|
||||
state,
|
||||
nonce,
|
||||
redirect_uri,
|
||||
code_challenge_verifier,
|
||||
}
|
||||
}
|
||||
}
|
34
src/service/sso/templates.rs
Normal file
34
src/service/sso/templates.rs
Normal file
|
@ -0,0 +1,34 @@
|
|||
pub fn base(title: &str, body: maud::Markup) -> maud::Markup {
|
||||
maud::html! {
|
||||
(maud::DOCTYPE)
|
||||
html lang="en" {
|
||||
head {
|
||||
meta charset="utf-8";
|
||||
meta name="viewport" content="width=device-width, initial-scale=1.0";
|
||||
link rel="icon" type="image/png" sizes="32x32" href="https://conduit.rs/conduit.svg";
|
||||
style { (FONT_FACE) }
|
||||
title { (title) }
|
||||
}
|
||||
body { (body) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn footer() -> maud::Markup {
|
||||
let info = "An open network for secure, decentralized communication.";
|
||||
|
||||
maud::html! {
|
||||
footer { p { (info) } }
|
||||
}
|
||||
}
|
||||
|
||||
const FONT_FACE: &str = r#"
|
||||
@font-face {
|
||||
font-family: 'Source Sans 3 Variable';
|
||||
font-style: normal;
|
||||
font-display: swap;
|
||||
font-weight: 200 900;
|
||||
src: url(https://cdn.jsdelivr.net/fontsource/fonts/source-sans-3:vf@latest/latin-wght-normal.woff2) format('woff2-variations');
|
||||
unicode-range: U+0000-00FF,U+0131,U+0152-0153,U+02BB-02BC,U+02C6,U+02DA,U+02DC,U+0304,U+0308,U+0329,U+2000-206F,U+2074,U+20AC,U+2122,U+2191,U+2193,U+2212,U+2215,U+FEFF,U+FFFD;
|
||||
}
|
||||
"#;
|
|
@ -217,4 +217,6 @@ pub trait Data: Send + Sync {
|
|||
|
||||
/// Find out which user an OpenID access token belongs to.
|
||||
fn find_from_openid_token(&self, token: &str) -> Result<Option<OwnedUserId>>;
|
||||
|
||||
fn set_placeholder_password(&self, user_id: &UserId) -> Result<()>;
|
||||
}
|
||||
|
|
|
@ -602,6 +602,10 @@ impl Service {
|
|||
pub fn find_from_openid_token(&self, token: &str) -> Result<Option<OwnedUserId>> {
|
||||
self.db.find_from_openid_token(token)
|
||||
}
|
||||
|
||||
pub fn set_placeholder_password(&self, user_id: &UserId) -> Result<()> {
|
||||
self.db.set_placeholder_password(user_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensure that a user only sees signatures from themselves and the target user
|
||||
|
|
|
@ -175,6 +175,22 @@ impl Error {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<mas_oidc_client::types::errors::ClientError> for Error {
|
||||
fn from(e: mas_oidc_client::types::errors::ClientError) -> Self {
|
||||
error!(
|
||||
"Failed to complete authorization callback: {} {}",
|
||||
e.error,
|
||||
e.error_description.as_deref().unwrap_or_default()
|
||||
);
|
||||
|
||||
// TODO: error conversion
|
||||
Self::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Failed to complete authorization callback.",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "persy")]
|
||||
impl<T: Into<PersyError>> From<persy::PE<T>> for Error {
|
||||
fn from(err: persy::PE<T>) -> Self {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
pub mod error;
|
||||
|
||||
use argon2::{Config, Variant};
|
||||
use axum_extra::extract::cookie::{Cookie, SameSite};
|
||||
use cmp::Ordering;
|
||||
use rand::prelude::*;
|
||||
use ring::digest;
|
||||
|
@ -8,7 +9,7 @@ use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonO
|
|||
use std::{
|
||||
cmp, fmt,
|
||||
str::FromStr,
|
||||
time::{SystemTime, UNIX_EPOCH},
|
||||
time::{Duration, SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
|
||||
pub fn millis_since_unix_epoch() -> u64 {
|
||||
|
@ -142,6 +143,29 @@ pub fn deserialize_from_str<
|
|||
deserializer.deserialize_str(Visitor(std::marker::PhantomData))
|
||||
}
|
||||
|
||||
pub fn build_cookie<'c>(
|
||||
name: &'c str,
|
||||
value: &'c str,
|
||||
path: &'c str,
|
||||
max_age: Option<u64>,
|
||||
) -> Cookie<'c> {
|
||||
let mut cookie = Cookie::new(name, value);
|
||||
|
||||
cookie.set_path(path);
|
||||
cookie.set_secure(true);
|
||||
cookie.set_http_only(true);
|
||||
cookie.set_same_site(SameSite::None);
|
||||
cookie.set_max_age(
|
||||
max_age
|
||||
.map(Duration::from_secs)
|
||||
.map(TryInto::try_into)
|
||||
.transpose()
|
||||
.expect("time overflow"),
|
||||
);
|
||||
|
||||
cookie
|
||||
}
|
||||
|
||||
// Copied from librustdoc:
|
||||
// https://github.com/rust-lang/rust/blob/cbaeec14f90b59a91a6b0f17fc046c66fa811892/src/librustdoc/html/escape.rs
|
||||
|
||||
|
|
Loading…
Reference in a new issue