1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-04-22 14:10:16 +03:00

Merge branch 'sso-oidc' into 'next'

Single Sign-On via OIDC/OAuth2 (attempt )

Closes 

See merge request 
This commit is contained in:
avdb 2024-11-04 01:22:43 +00:00
commit da543a726e
21 changed files with 1104 additions and 46 deletions

View file

@ -35,13 +35,17 @@ axum = { version = "0.7", default-features = false, features = [
"json",
"matched-path",
], optional = true }
axum-extra = { version = "0.9", features = ["typed-header"] }
axum-extra = { version = "0.9", features = ["cookie", "typed-header"] }
axum-server = { version = "0.6", features = ["tls-rustls"] }
tower = { version = "0.4.13", features = ["util"] }
tower-http = { version = "0.5", features = [
"add-extension",
"cors",
"follow-redirect",
"map-request-body",
"sensitive-headers",
"set-header",
"timeout",
"trace",
"util",
] }
@ -172,6 +176,16 @@ optional = true
package = "rust-rocksdb"
version = "0.25"
[dependencies.mas-http]
features = ["client"]
git = "https://github.com/matrix-org/matrix-authentication-service"
rev = "fbc360d1a94ef2ebf63d979bb403228a700f43c8"
[dependencies.mas-oidc-client]
features = []
git = "https://github.com/matrix-org/matrix-authentication-service"
rev = "fbc360d1a94ef2ebf63d979bb403228a700f43c8"
[target.'cfg(unix)'.dependencies]
nix = { version = "0.28", features = ["resource"] }

View file

@ -13,6 +13,7 @@ Conduit's configuration file is divided into the following sections:
- [Global](#global)
- [TLS](#tls)
- [Proxy](#proxy)
- [SSO (Single Sign-On)](#sso)
## Global
@ -111,3 +112,20 @@ exclude = ["*.clearnet.onion"]
[global]
{{#include ../conduit-example.toml:22:}}
```
### SSO (Single Sign-On)
Authentication through SSO instead of a password can be enabled by configuring OIDC (OpenID Connect) identity providers.
Identity providers using OAuth such as Github are not supported yet.
> **Note:** The `*` symbol indicates that the field is required, and the values in **parentheses** are the possible values
| Field | Type | Description | Default |
| --- | --- | --- | --- |
| `issuer`* | `Url` | The issuer URL. | N/A |
| `name` | `string` | The name displayed on fallback pages. | `issuer` |
| `icon` | `Url` OR `MxcUri` | The icon displayed on fallback pages. | N/A |
| `scopes` | `array` | The scopes used to obtain extra claims which can be used for templates. | `["openid"]` |
| `client_id`* | `string` | The provider-supplied, unique ID for the client. | N/A |
| `client_secret`* | `string` | The provider-supplied, unique ID for the client. | N/A |
| `authentication_method`* | `"basic" OR "post"` | The method used for client authentication. | N/A |

View file

@ -100,6 +100,12 @@ pub async fn upload_signing_keys_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
let master_key = services()
.users
.get_master_key(Some(sender_user), sender_user, &|other| {
sender_user == other
})?;
// UIAA
let mut uiaainfo = UiaaInfo {
flows: vec![AuthFlow {
@ -111,7 +117,15 @@ pub async fn upload_signing_keys_route(
auth_error: None,
};
if let Some(auth) = &body.auth {
if let (Some(master_key), None) = (&body.master_key, master_key) {
services().users.add_cross_signing_keys(
sender_user,
master_key,
&body.self_signing_key,
&body.user_signing_key,
true,
)?;
} else if let Some(auth) = &body.auth {
let (worked, uiaainfo) =
services()
.uiaa
@ -130,16 +144,6 @@ pub async fn upload_signing_keys_route(
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
}
if let Some(master_key) = &body.master_key {
services().users.add_cross_signing_keys(
sender_user,
master_key,
&body.self_signing_key,
&body.user_signing_key,
true, // notify so that other users see the new keys
)?;
}
Ok(upload_signing_keys::v3::Response {})
}

View file

@ -23,6 +23,7 @@ mod room;
mod search;
mod session;
mod space;
mod sso;
mod state;
mod sync;
mod tag;
@ -60,6 +61,7 @@ pub use room::*;
pub use search::*;
pub use session::*;
pub use space::*;
pub use sso::*;
pub use state::*;
pub use sync::*;
pub use tag::*;
@ -76,3 +78,5 @@ pub const DEVICE_ID_LENGTH: usize = 10;
pub const TOKEN_LENGTH: usize = 32;
pub const SESSION_ID_LENGTH: usize = 32;
pub const AUTO_GEN_PASSWORD_LENGTH: usize = 15;
pub const AUTH_SESSION_EXPIRATION_SECS: u64 = 60 * 5;
pub const LOGIN_TOKEN_EXPIRATION_SECS: u64 = 15;

View file

@ -1,5 +1,6 @@
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{services, utils, Error, Result, Ruma};
use crate::{service::sso::LoginToken, services, utils, Error, Result, Ruma};
use jsonwebtoken::{Algorithm, Validation};
use ruma::{
api::client::{
error::ErrorKind,
@ -24,10 +25,19 @@ struct Claims {
pub async fn get_login_types_route(
_body: Ruma<get_login_types::v3::Request>,
) -> Result<get_login_types::v3::Response> {
Ok(get_login_types::v3::Response::new(vec![
let identity_providers: Vec<_> = services().sso.login_type().collect();
let mut flows = vec![
get_login_types::v3::LoginType::Password(Default::default()),
get_login_types::v3::LoginType::ApplicationService(Default::default()),
]))
];
if !identity_providers.is_empty() {
flows.push(get_login_types::v3::LoginType::Sso(
get_login_types::v3::SsoLoginType { identity_providers },
));
}
Ok(get_login_types::v3::Response::new(flows))
}
/// # `POST /_matrix/client/r0/login`
@ -101,35 +111,64 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
user_id
}
login::v3::LoginInfo::Token(login::v3::Token { token }) => {
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() {
let token = jsonwebtoken::decode::<Claims>(
token,
jwt_decoding_key,
&jsonwebtoken::Validation::default(),
)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?;
let username = token.claims.sub.to_lowercase();
let user_id =
UserId::parse_with_server_name(username, services().globals.server_name())
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
})?;
match (
services().globals.jwt_decoding_key(),
services().globals.config.idps.is_empty(),
) {
(_, false) => {
let mut v = Validation::new(Algorithm::HS256);
if services().appservice.is_exclusive_user_id(&user_id).await {
v.set_required_spec_claims(&["sub", "exp", "aud", "iss"]);
v.validate_aud = false;
v.validate_nbf = false;
services()
.globals
.validate_claims::<LoginToken>(token, Some(&v))
.map(LoginToken::audience)
.map_err(|e| {
tracing::warn!("Invalid token: {}", e);
Error::BadRequest(ErrorKind::InvalidParam, "Invalid token.")
})?
}
(Some(jwt_decoding_key), _) => {
let token = jsonwebtoken::decode::<Claims>(
token,
jwt_decoding_key,
&Validation::default(),
)
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.")
})?;
let username = token.claims.sub.to_lowercase();
let user_id =
UserId::parse_with_server_name(username, services().globals.server_name())
.map_err(|_| {
Error::BadRequest(
ErrorKind::InvalidUsername,
"Username is invalid.",
)
})?;
if services().appservice.is_exclusive_user_id(&user_id).await {
return Err(Error::BadRequest(
ErrorKind::Exclusive,
"User id reserved by appservice.",
));
}
user_id
}
(None, _) => {
return Err(Error::BadRequest(
ErrorKind::Exclusive,
"User id reserved by appservice.",
ErrorKind::Unknown,
"Token login is not supported (server has no jwt decoding key).",
));
}
user_id
} else {
return Err(Error::BadRequest(
ErrorKind::Unknown,
"Token login is not supported (server has no jwt decoding key).",
));
}
}
login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService {
identifier,
user,

View file

@ -0,0 +1,473 @@
use std::{borrow::Borrow, collections::HashMap, iter::Iterator, time::SystemTime};
use crate::{
config::IdpConfig,
service::sso::{
LoginToken, ValidationData, SSO_AUTH_EXPIRATION_SECS, SSO_SESSION_COOKIE, SUBJECT_CLAIM_KEY,
},
services, utils, Error, Result, Ruma,
};
use futures_util::TryFutureExt;
use mas_oidc_client::{
requests::{
authorization_code::{self, AuthorizationRequestData},
jose::{self, JwtVerificationData},
userinfo,
},
types::{
client_credentials::ClientCredentials,
iana::jose::JsonWebSignatureAlg,
requests::{AccessTokenResponse, AuthorizationResponse},
},
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use ruma::{
api::client::{
error::ErrorKind,
media::create_content,
session::{sso_login, sso_login_with_provider},
},
events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType},
push, UserId,
};
use serde_json::Value;
use tracing::{error, info, warn};
use url::Url;
pub const CALLBACK_PATH: &str = "/_matrix/client/unstable/conduit/callback";
/// # `GET /_matrix/client/v3/login/sso/redirect`
///
/// Redirect the user to the SSO interfa.
/// TODO: this should be removed once Ruma supports trailing slashes.
pub async fn get_sso_redirect_route(
Ruma {
body,
sender_user,
sender_device,
sender_servername,
json_body,
..
}: Ruma<sso_login::v3::Request>,
) -> Result<sso_login::v3::Response> {
let sso_login_with_provider::v3::Response { location, cookie } =
get_sso_redirect_with_provider_route(
Ruma {
body: sso_login_with_provider::v3::Request::new(
Default::default(),
body.redirect_url,
),
sender_user,
sender_device,
sender_servername,
json_body,
appservice_info: None,
}
.into(),
)
.await?;
Ok(sso_login::v3::Response { location, cookie })
}
/// # `GET /_matrix/client/v3/login/sso/redirect/{idpId}`
///
/// Redirects the user to the SSO interface.
pub async fn get_sso_redirect_with_provider_route(
body: Ruma<sso_login_with_provider::v3::Request>,
) -> Result<sso_login_with_provider::v3::Response> {
let idp_ids: Vec<&str> = services()
.globals
.config
.idps
.iter()
.map(Borrow::borrow)
.collect();
let provider = match &*idp_ids {
[] => {
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"Single Sign-On is disabled.",
));
}
[idp_id] => services().sso.get(idp_id).expect("we know it exists"),
[_, ..] => services().sso.get(&body.idp_id).ok_or_else(|| {
Error::BadRequest(ErrorKind::InvalidParam, "Unknown identity provider.")
})?,
};
let redirect_url = body
.redirect_url
.parse::<Url>()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid redirect_url."))?;
let mut callback = services()
.globals
.well_known_client()
.parse::<Url>()
.map_err(|_| Error::bad_config("Invalid well_known_client url."))?;
callback.set_path(CALLBACK_PATH);
let (auth_url, validation_data) = authorization_code::build_authorization_url(
provider.metadata.authorization_endpoint().clone(),
AuthorizationRequestData::new(
provider.config.client_id.clone(),
provider.config.scopes.clone(),
callback,
),
&mut StdRng::from_entropy(),
)
.map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to build authorization_url."))?;
let signed = services().globals.sign_claims(&ValidationData::new(
Borrow::<str>::borrow(provider).to_owned(),
redirect_url.to_string(),
validation_data,
));
Ok(sso_login_with_provider::v3::Response {
location: auth_url.to_string(),
cookie: Some(
utils::build_cookie(
SSO_SESSION_COOKIE,
&signed,
CALLBACK_PATH,
Some(SSO_AUTH_EXPIRATION_SECS),
)
.to_string(),
),
})
}
/// # `GET /_conduit/client/sso/callback`
///
/// Validate the authorization response received from the identity provider.
/// On success, generate a login token, add it to `redirectUrl` as a query and perform the redirect.
/// If this is the first login, register the user, possibly interactively through a fallback page.
pub async fn handle_callback_route(
body: Ruma<sso_callback::Request>,
) -> Result<sso_login_with_provider::v3::Response> {
let sso_callback::Request {
response:
AuthorizationResponse {
code,
access_token: _,
token_type: _,
id_token: _,
expires_in: _,
},
cookie,
} = body.body;
let ValidationData {
provider,
redirect_url,
inner: validation_data,
} = services()
.globals
.validate_claims(&cookie, None)
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid value for session cookie.")
})?;
let provider = services().sso.get(&provider).ok_or_else(|| {
Error::BadRequest(
ErrorKind::InvalidParam,
"Unknown provider for session cookie.",
)
})?;
let IdpConfig {
client_id,
client_secret,
auth_method,
..
} = provider.config.clone();
let credentials = match &*auth_method {
"basic" => ClientCredentials::ClientSecretBasic {
client_id,
client_secret,
},
"post" => ClientCredentials::ClientSecretPost {
client_id,
client_secret,
},
_ => todo!(),
};
let ref jwks = jose::fetch_jwks(services().sso.service(), provider.metadata.jwks_uri())
.await
.map_err(|_| Error::bad_config("Failed to fetch signing keys for token endpoint."))?;
let idt_verification_data = Some(JwtVerificationData {
jwks,
issuer: &provider.config.issuer,
client_id: &provider.config.client_id,
signing_algorithm: &JsonWebSignatureAlg::Rs256,
});
let (
AccessTokenResponse {
access_token,
refresh_token: _,
token_type: _,
expires_in: _,
scope: _,
..
},
Some(id_token),
) = authorization_code::access_token_with_authorization_code(
services().sso.service(),
credentials,
provider.metadata.token_endpoint(),
code.unwrap_or_default(),
validation_data,
idt_verification_data,
SystemTime::now().into(),
&mut StdRng::from_entropy(),
)
.await
.map_err(|_| Error::bad_config("Failed to fetch access token."))?
else {
unreachable!("ID token should never be empty")
};
let mut userinfo = HashMap::default();
if let Some(endpoint) = provider.metadata.userinfo_endpoint.as_ref() {
userinfo = userinfo::fetch_userinfo(
services().sso.service(),
endpoint,
&access_token,
None,
&id_token,
)
.await
.map_err(|e| {
tracing::error!("Failed to fetch claims for userinfo endpoint: {:?}", e);
Error::bad_config("Failed to fetch claims for userinfo endpoint.")
})?;
}
let (_, id_token) = id_token.into_parts();
info!("userinfo: {:?}", &userinfo);
info!("id_token: {:?}", &id_token);
let subject = match id_token.get(SUBJECT_CLAIM_KEY) {
Some(Value::String(s)) => s.to_owned(),
Some(Value::Number(n)) => n.to_string(),
value => {
return Err(Error::BadRequest(
ErrorKind::Unknown,
value
.map(|_| {
error!("Subject claim is missing from ID token: {id_token:?}");
"Subject claim is missing from ID token."
})
.unwrap_or("Subject claim should be a string or number."),
));
}
};
let user_id = match services()
.sso
.user_from_subject(Borrow::<str>::borrow(provider), &subject)?
{
Some(user_id) => user_id,
None => {
let mut localpart = subject.clone();
let user_id = loop {
match UserId::parse_with_server_name(&*localpart, services().globals.server_name())
.map(|user_id| {
(
user_id.clone(),
services().users.exists(&user_id).unwrap_or(true),
)
}) {
Ok((user_id, false)) => break user_id,
_ => {
let n: u8 = rand::thread_rng().gen();
localpart = format!("{}{}", localpart, n % 10);
}
}
};
services().users.set_placeholder_password(&user_id)?;
let displayname = id_token
.get("preferred_username")
.or(id_token.get("nickname"));
let mut displayname = displayname
.as_deref()
.map(Value::as_str)
.flatten()
.unwrap_or(user_id.localpart())
.to_owned();
// If enabled append lightning bolt to display name (default true)
if services().globals.enable_lightning_bolt() {
displayname.push_str(" ⚡️");
}
services()
.users
.set_displayname(&user_id, Some(displayname.clone()))?;
if let Some(Value::String(url)) = userinfo.get("picture").or(id_token.get("picture")) {
let req = services()
.globals
.default_client()
.get(url)
.send()
.and_then(reqwest::Response::bytes);
if let Ok(file) = req.await {
let _ = crate::api::client_server::create_content_route(Ruma {
body: create_content::v3::Request::new(file.to_vec()),
sender_user: None,
sender_device: None,
sender_servername: None,
json_body: None,
appservice_info: None,
})
.await
.and_then(|res| {
tracing::info!("successfully imported avatar for {}", &user_id);
services()
.users
.set_avatar_url(&user_id, Some(res.content_uri))
});
}
}
// Initial account data
services().account_data.update(
None,
&user_id,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
content: ruma::events::push_rules::PushRulesEventContent {
global: push::Ruleset::server_default(&user_id),
},
})
.expect("to json always works"),
)?;
info!("New user {} registered on this server.", user_id);
services()
.admin
.send_message(RoomMessageEventContent::notice_plain(format!(
"New user {user_id} registered on this server."
)));
if let Some(admin_room) = services().admin.get_admin_room()? {
if services()
.rooms
.state_cache
.room_joined_count(&admin_room)?
== Some(1)
{
services()
.admin
.make_user_admin(&user_id, displayname.to_owned())
.await?;
warn!("Granting {} admin privileges as the first user", user_id);
}
}
user_id
}
};
let signed = services().globals.sign_claims(&LoginToken::new(
Borrow::<str>::borrow(provider).to_owned(),
user_id,
));
let mut redirect_url: Url = redirect_url.parse().expect("");
redirect_url
.query_pairs_mut()
.append_pair("loginToken", &signed);
Ok(sso_login_with_provider::v3::Response {
location: redirect_url.to_string(),
cookie: Some(utils::build_cookie(SSO_SESSION_COOKIE, "", CALLBACK_PATH, None).to_string()),
})
}
mod sso_callback {
use axum_extra::headers::{self, HeaderMapExt};
use http::Method;
use mas_oidc_client::types::requests::AuthorizationResponse;
use ruma::{
api::{
client::{session::sso_login_with_provider, Error},
error::{FromHttpRequestError, HeaderDeserializationError},
IncomingRequest, Metadata,
},
metadata,
};
use crate::service::sso::SSO_SESSION_COOKIE;
pub const METADATA: Metadata = metadata! {
method: GET,
rate_limited: false,
authentication: None,
history: {
1.0 => "/_matrix/client/unstable/conduit/callback",
}
};
pub struct Request {
pub response: AuthorizationResponse,
pub cookie: String,
}
impl IncomingRequest for Request {
type EndpointError = Error;
type OutgoingResponse = sso_login_with_provider::v3::Response;
const METADATA: Metadata = METADATA;
fn try_from_http_request<B, S>(
req: http::Request<B>,
_path_args: &[S],
) -> Result<Self, FromHttpRequestError>
where
B: AsRef<[u8]>,
S: AsRef<str>,
{
if !(req.method() == METADATA.method
|| req.method() == Method::HEAD && METADATA.method == Method::GET)
{
return Err(FromHttpRequestError::MethodMismatch {
expected: METADATA.method,
received: req.method().clone(),
});
}
let response: AuthorizationResponse =
serde_html_form::from_str(req.uri().query().unwrap_or(""))?;
let Some(cookie) = req
.headers()
.typed_get()
.and_then(|cookie: headers::Cookie| {
cookie.get(SSO_SESSION_COOKIE).map(str::to_owned)
})
else {
return Err(HeaderDeserializationError::MissingHeader(
"Cookie".to_owned(),
))?;
};
Ok(Self { response, cookie })
}
}
}

View file

@ -1,16 +1,27 @@
use std::{
collections::BTreeMap,
borrow::Borrow,
collections::{BTreeMap, HashSet},
fmt,
hash::{Hash, Hasher},
net::{IpAddr, Ipv4Addr},
};
use ruma::{OwnedServerName, RoomVersionId};
use serde::{de::IgnoredAny, Deserialize};
use figment::value::{Dict, Value};
use mas_oidc_client::types::{client_credentials::ClientCredentials, scope::Scope};
use ruma::{
api::client::session::get_login_types::v3::IdentityProvider, OwnedServerName, RoomVersionId,
};
use serde::{
de::{self, IgnoredAny},
Deserialize, Deserializer, Serialize,
};
use tracing::warn;
use url::Url;
mod proxy;
use crate::{Error, Result};
use self::proxy::ProxyConfig;
#[derive(Clone, Debug, Deserialize)]
@ -67,6 +78,8 @@ pub struct Config {
pub tracing_flame: bool,
#[serde(default)]
pub proxy: ProxyConfig,
#[serde(default, deserialize_with = "deserialize_providers")]
pub idps: HashSet<IdpConfig>,
pub jwt_secret: Option<String>,
#[serde(default = "default_trusted_servers")]
pub trusted_servers: Vec<OwnedServerName>,
@ -101,6 +114,27 @@ pub struct WellKnownConfig {
pub server: Option<OwnedServerName>,
}
#[derive(Clone, Debug, Deserialize)]
pub struct IdpConfig {
pub issuer: String,
#[serde(flatten)]
pub inner: IdentityProvider,
#[serde(deserialize_with = "deserialize_scopes")]
pub scopes: Scope,
pub client_id: String,
pub client_secret: String,
pub auth_method: String,
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Template {
pub localpart: Option<String>,
pub displayname: Option<String>,
pub avatar_url: Option<String>,
pub email: Option<String>,
}
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
impl Config {
@ -244,6 +278,49 @@ impl fmt::Display for Config {
}
}
impl Borrow<str> for IdpConfig {
fn borrow(&self) -> &str {
&self.inner.id
}
}
impl PartialEq for IdpConfig {
fn eq(&self, other: &Self) -> bool {
self.inner.id == other.inner.id
}
}
impl Eq for IdpConfig {}
impl Hash for IdpConfig {
fn hash<H: Hasher>(&self, hasher: &mut H) {
self.inner.id.hash(hasher)
}
}
impl Into<ClientCredentials> for IdpConfig {
fn into(self) -> ClientCredentials {
let IdpConfig {
client_id,
client_secret,
auth_method,
..
} = self;
match &*auth_method {
"basic" => ClientCredentials::ClientSecretBasic {
client_id,
client_secret,
},
"post" => ClientCredentials::ClientSecretPost {
client_id,
client_secret,
},
_ => unimplemented!(),
}
}
}
fn false_fn() -> bool {
false
}
@ -312,3 +389,46 @@ fn default_openid_token_ttl() -> u64 {
pub fn default_default_room_version() -> RoomVersionId {
RoomVersionId::V10
}
fn deserialize_scopes<'de, D>(deserializer: D) -> Result<Scope, D::Error>
where
D: Deserializer<'de>,
{
let scopes = <Vec<String>>::deserialize(deserializer)?;
scopes.join(" ").parse().map_err(de::Error::custom)
}
fn deserialize_providers<'de, D>(deserializer: D) -> Result<HashSet<IdpConfig>, D::Error>
where
D: Deserializer<'de>,
{
let mut result = HashSet::new();
let dict = Dict::deserialize(deserializer)
.map(Dict::into_iter)
.map_err(de::Error::custom)?;
warn!(?dict);
for (name, value) in dict {
let tag = value.tag();
let Some(dict) = value.into_dict() else {
return Err(de::Error::custom(Error::bad_config(
"Invalid SSO configuration. ",
)));
};
let id = String::from("id");
let name = name.parse().map_err(de::Error::custom)?;
let dict = Some((id, name)).into_iter().chain(dict).collect();
result.insert(
Value::Dict(tag, dict)
.deserialize()
.map_err(de::Error::custom)?,
);
}
Ok(result)
}

View file

@ -8,6 +8,7 @@ mod media;
mod pusher;
mod rooms;
mod sending;
mod sso;
mod transaction_ids;
mod uiaa;
mod users;

View file

@ -0,0 +1,29 @@
use ruma::{OwnedUserId, UserId};
use crate::{service, utils, Error, KeyValueDatabase, Result};
impl service::sso::Data for KeyValueDatabase {
fn save_subject(&self, provider: &str, user_id: &UserId, subject: &str) -> Result<()> {
let mut key = provider.as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(subject.as_bytes());
self.providersubjectid_userid.insert(&key, user_id.as_bytes())
}
fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>> {
let mut key = provider.as_bytes().to_vec();
key.push(0xff);
key.extend_from_slice(subject.as_bytes());
self.providersubjectid_userid.get(&key)?.map_or(Ok(None), |bytes| {
Some(
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("User ID in claim_userid is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("User ID in claim_userid is invalid.")),
)
.transpose()
})
}
}

View file

@ -119,6 +119,10 @@ impl service::users::Data for KeyValueDatabase {
}
}
fn set_placeholder_password(&self, user_id: &UserId) -> Result<()> {
self.userid_password.insert(user_id.as_bytes(), b"0xff")
}
/// Returns the displayname of a user on this homeserver.
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname

View file

@ -50,7 +50,6 @@ pub struct KeyValueDatabase {
pub(super) userdeviceid_metadata: Arc<dyn KvTree>, // This is also used to check if a device exists
pub(super) userid_devicelistversion: Arc<dyn KvTree>, // DevicelistVersion = u64
pub(super) token_userdeviceid: Arc<dyn KvTree>,
pub(super) onetimekeyid_onetimekeys: Arc<dyn KvTree>, // OneTimeKeyId = UserId + DeviceKeyId
pub(super) userid_lastonetimekeyupdate: Arc<dyn KvTree>, // LastOneTimeKeyUpdate = Count
pub(super) keychangeid_userid: Arc<dyn KvTree>, // KeyChangeId = UserId/RoomId + Count
@ -64,6 +63,9 @@ pub struct KeyValueDatabase {
pub(super) todeviceid_events: Arc<dyn KvTree>, // ToDeviceId = UserId + DeviceId + Count
pub(super) userid_providersubjectid: Arc<dyn KvTree>,
pub(super) providersubjectid_userid: Arc<dyn KvTree>,
//pub uiaa: uiaa::Uiaa,
pub(super) userdevicesessionid_uiaainfo: Arc<dyn KvTree>, // User-interactive authentication
pub(super) userdevicesessionid_uiaarequest:
@ -298,6 +300,9 @@ impl KeyValueDatabase {
userfilterid_filter: builder.open_tree("userfilterid_filter")?,
todeviceid_events: builder.open_tree("todeviceid_events")?,
userid_providersubjectid: builder.open_tree("userid_providersubjectid")?,
providersubjectid_userid: builder.open_tree("providersubjectid_userid")?,
userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?,
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
@ -1050,6 +1055,8 @@ impl KeyValueDatabase {
services().admin.start_handler();
services().sso.start_handler().await?;
// Set emergency access for the conduit user
match set_emergency_access() {
Ok(pwd_set) => {

View file

@ -292,6 +292,11 @@ fn routes(config: &Config) -> Router {
.ruma_route(client_server::third_party_route)
.ruma_route(client_server::request_3pid_management_token_via_email_route)
.ruma_route(client_server::request_3pid_management_token_via_msisdn_route)
.ruma_route(client_server::get_sso_redirect_route)
.ruma_route(client_server::get_sso_redirect_with_provider_route)
// The specification will likely never introduce any endpoint for handling authorization callbacks.
// As a workaround, we use custom path that redirects the user to the default login handler.
.ruma_route(client_server::handle_callback_route)
.ruma_route(client_server::get_capabilities_route)
.ruma_route(client_server::get_pushrules_all_route)
.ruma_route(client_server::set_pushrule_route)

View file

@ -1,9 +1,10 @@
mod data;
pub use data::{Data, SigningKeys};
use ruma::{
serde::Base64, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedEventId, OwnedRoomAliasId,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId,
serde::Base64, signatures::KeyPair, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedEventId,
OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId,
};
use serde::{de::DeserializeOwned, Serialize};
use crate::api::server_server::DestinationResponse;
@ -17,7 +18,7 @@ use ruma::{
DeviceId, RoomVersionId, ServerName, UserId,
};
use std::{
collections::{BTreeMap, HashMap},
collections::{BTreeMap, HashMap, HashSet},
error::Error as StdError,
fs,
future::{self, Future},
@ -37,6 +38,9 @@ use tracing::{error, info};
use base64::{engine::general_purpose, Engine as _};
// https://github.com/rust-lang/rust/issues/104699
const PROBLEMATIC_CONST: &[u8] = b"0xCAFEBABE";
type WellKnownMap = HashMap<OwnedServerName, DestinationResponse>;
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
@ -505,6 +509,36 @@ impl Service {
self.config.well_known_client()
}
pub fn sign_claims<S: Serialize>(&self, claims: &S) -> String {
let key = jsonwebtoken::EncodingKey::from_secret(
self.keypair().sign(PROBLEMATIC_CONST).as_bytes(),
);
jsonwebtoken::encode(&jsonwebtoken::Header::default(), claims, &key)
.expect("signing JWTs always works")
}
/// Decode and validate a macaroon with this server's macaroon key.
pub fn validate_claims<T: DeserializeOwned>(
&self,
token: &str,
validation_data: Option<&jsonwebtoken::Validation>,
) -> jsonwebtoken::errors::Result<T> {
let key = jsonwebtoken::DecodingKey::from_secret(
self.keypair().sign(PROBLEMATIC_CONST).as_bytes(),
);
let mut v = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256);
// these validations are redundant as all JWTs are stored in cookies
v.validate_exp = false;
v.validate_nbf = false;
v.required_spec_claims = HashSet::new();
jsonwebtoken::decode::<T>(token, &key, validation_data.unwrap_or(&v))
.map(|data| data.claims)
}
pub fn shutdown(&self) {
self.shutdown.store(true, atomic::Ordering::Relaxed);
// On shutdown

View file

@ -19,6 +19,7 @@ pub mod pdu;
pub mod pusher;
pub mod rooms;
pub mod sending;
pub mod sso;
pub mod transaction_ids;
pub mod uiaa;
pub mod users;
@ -35,6 +36,7 @@ pub struct Services {
pub globals: globals::Service,
pub key_backups: key_backups::Service,
pub media: media::Service,
pub sso: Arc<sso::Service>,
pub sending: Arc<sending::Service>,
}
@ -51,6 +53,7 @@ impl Services {
+ key_backups::Data
+ media::Data
+ sending::Data
+ sso::Data
+ 'static,
>(
db: &'static D,
@ -120,6 +123,7 @@ impl Services {
key_backups: key_backups::Service { db },
media: media::Service { db },
sending: sending::Service::build(db, &config),
sso: sso::Service::build(db)?,
globals: globals::Service::load(db, config)?,
})

9
src/service/sso/data.rs Normal file
View file

@ -0,0 +1,9 @@
use ruma::{OwnedUserId, UserId};
use crate::Result;
pub trait Data: Send + Sync {
fn save_subject(&self, provider: &str, user_id: &UserId, subject: &str) -> Result<()>;
fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>>;
}

213
src/service/sso/mod.rs Normal file
View file

@ -0,0 +1,213 @@
use std::{
borrow::Borrow,
collections::HashSet,
hash::{Hash, Hasher},
sync::Arc,
};
use crate::{
api::client_server::{LOGIN_TOKEN_EXPIRATION_SECS, TOKEN_LENGTH},
config::IdpConfig,
utils, Error, Result,
};
use futures_util::future::{self};
use http::HeaderValue;
use mas_oidc_client::{
http_service::HttpService,
requests::{authorization_code::AuthorizationValidationData, discovery},
types::oidc::VerifiedProviderMetadata,
};
use ruma::{api::client::session::get_login_types::v3::IdentityProvider, OwnedUserId, UserId};
use serde::{Deserialize, Serialize};
use tokio::sync::OnceCell;
use tower::BoxError;
use tower_http::{set_header::SetRequestHeaderLayer, ServiceBuilderExt};
use tracing::error;
use url::Url;
use crate::services;
mod data;
pub use data::Data;
pub const SSO_AUTH_EXPIRATION_SECS: u64 = 60 * 60;
pub const SSO_TOKEN_EXPIRATION_SECS: u64 = 60 * 2;
pub const SSO_SESSION_COOKIE: &str = "sso-auth";
pub const SUBJECT_CLAIM_KEY: &str = "sub";
pub struct Service {
db: &'static dyn Data,
service: HttpService,
providers: OnceCell<HashSet<Provider>>,
}
impl Service {
pub fn build(db: &'static dyn Data) -> Result<Arc<Self>> {
let client = tower::ServiceBuilder::new()
.map_err(BoxError::from)
.layer(tower_http::timeout::TimeoutLayer::new(
std::time::Duration::from_secs(10),
))
.layer(mas_http::BytesToBodyRequestLayer)
.layer(mas_http::BodyToBytesResponseLayer)
.layer(SetRequestHeaderLayer::overriding(
http::header::USER_AGENT,
HeaderValue::from_static("conduit/0.9-alpha"),
))
.concurrency_limit(10)
.follow_redirects()
.service(mas_http::make_untraced_client());
Ok(Arc::new(Self {
db,
service: HttpService::new(client),
providers: OnceCell::new(),
}))
}
pub fn service(&self) -> &HttpService {
&self.service
}
pub async fn start_handler(&self) -> Result<()> {
let providers = services().globals.config.idps.iter();
self.providers
.get_or_try_init(|| async move {
future::try_join_all(providers.map(Provider::fetch_metadata))
.await
.map(Vec::into_iter)
.map(HashSet::from_iter)
})
.await?;
Ok(())
}
pub fn get(&self, provider: &str) -> Option<&Provider> {
let providers = self.providers.get().expect("");
providers.get(provider)
}
pub fn login_type(&self) -> impl Iterator<Item = IdentityProvider> + '_ {
let providers = self.providers.get().expect("");
providers.iter().map(|p| p.config.inner.clone())
}
pub fn user_from_subject(&self, provider: &str, subject: &str) -> Result<Option<OwnedUserId>> {
self.db.user_from_subject(provider, subject)
}
}
#[derive(Clone, Debug)]
pub struct Provider {
pub config: &'static IdpConfig,
pub metadata: VerifiedProviderMetadata,
}
impl Provider {
pub async fn fetch_metadata(config: &'static IdpConfig) -> Result<Self> {
discovery::discover(services().sso.service(), &config.issuer)
.await
.map(|metadata| Provider { config, metadata })
.map_err(|e| {
error!(
"Failed to fetch identity provider metadata ({}): {}",
&config.inner.id, e
);
Error::bad_config("Failed to fetch identity provider metadata.")
})
}
}
impl Borrow<str> for Provider {
fn borrow(&self) -> &str {
self.config.borrow()
}
}
impl PartialEq for Provider {
fn eq(&self, other: &Self) -> bool {
self.config == other.config
}
}
impl Eq for Provider {}
impl Hash for Provider {
fn hash<H: Hasher>(&self, hasher: &mut H) {
self.config.hash(hasher)
}
}
#[derive(Clone, Deserialize, Serialize)]
pub struct LoginToken {
pub iss: String,
pub aud: OwnedUserId,
pub sub: String,
pub exp: u64,
}
impl LoginToken {
pub fn new(provider: String, user_id: OwnedUserId) -> Self {
Self {
iss: provider,
aud: user_id,
sub: utils::random_string(TOKEN_LENGTH),
exp: utils::millis_since_unix_epoch()
.checked_add(LOGIN_TOKEN_EXPIRATION_SECS * 1000)
.expect("time overflow"),
}
}
pub fn audience(self) -> OwnedUserId {
self.aud
}
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ValidationData {
pub provider: String,
pub redirect_url: String,
#[serde(flatten, with = "AuthorizationValidationDataDef")]
pub inner: AuthorizationValidationData,
}
impl ValidationData {
pub fn new(provider: String, redirect_url: String, inner: AuthorizationValidationData) -> Self {
Self {
provider,
redirect_url,
inner,
}
}
}
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(remote = "AuthorizationValidationData")]
pub struct AuthorizationValidationDataDef {
pub state: String,
pub nonce: String,
pub redirect_uri: Url,
pub code_challenge_verifier: Option<String>,
}
impl From<AuthorizationValidationData> for AuthorizationValidationDataDef {
fn from(
AuthorizationValidationData {
state,
nonce,
redirect_uri,
code_challenge_verifier,
}: AuthorizationValidationData,
) -> Self {
Self {
state,
nonce,
redirect_uri,
code_challenge_verifier,
}
}
}

View file

@ -0,0 +1,34 @@
pub fn base(title: &str, body: maud::Markup) -> maud::Markup {
maud::html! {
(maud::DOCTYPE)
html lang="en" {
head {
meta charset="utf-8";
meta name="viewport" content="width=device-width, initial-scale=1.0";
link rel="icon" type="image/png" sizes="32x32" href="https://conduit.rs/conduit.svg";
style { (FONT_FACE) }
title { (title) }
}
body { (body) }
}
}
}
pub fn footer() -> maud::Markup {
let info = "An open network for secure, decentralized communication.";
maud::html! {
footer { p { (info) } }
}
}
const FONT_FACE: &str = r#"
@font-face {
font-family: 'Source Sans 3 Variable';
font-style: normal;
font-display: swap;
font-weight: 200 900;
src: url(https://cdn.jsdelivr.net/fontsource/fonts/source-sans-3:vf@latest/latin-wght-normal.woff2) format('woff2-variations');
unicode-range: U+0000-00FF,U+0131,U+0152-0153,U+02BB-02BC,U+02C6,U+02DA,U+02DC,U+0304,U+0308,U+0329,U+2000-206F,U+2074,U+20AC,U+2122,U+2191,U+2193,U+2212,U+2215,U+FEFF,U+FFFD;
}
"#;

View file

@ -217,4 +217,6 @@ pub trait Data: Send + Sync {
/// Find out which user an OpenID access token belongs to.
fn find_from_openid_token(&self, token: &str) -> Result<Option<OwnedUserId>>;
fn set_placeholder_password(&self, user_id: &UserId) -> Result<()>;
}

View file

@ -602,6 +602,10 @@ impl Service {
pub fn find_from_openid_token(&self, token: &str) -> Result<Option<OwnedUserId>> {
self.db.find_from_openid_token(token)
}
pub fn set_placeholder_password(&self, user_id: &UserId) -> Result<()> {
self.db.set_placeholder_password(user_id)
}
}
/// Ensure that a user only sees signatures from themselves and the target user

View file

@ -175,6 +175,22 @@ impl Error {
}
}
impl From<mas_oidc_client::types::errors::ClientError> for Error {
fn from(e: mas_oidc_client::types::errors::ClientError) -> Self {
error!(
"Failed to complete authorization callback: {} {}",
e.error,
e.error_description.as_deref().unwrap_or_default()
);
// TODO: error conversion
Self::BadRequest(
ErrorKind::Unknown,
"Failed to complete authorization callback.",
)
}
}
#[cfg(feature = "persy")]
impl<T: Into<PersyError>> From<persy::PE<T>> for Error {
fn from(err: persy::PE<T>) -> Self {

View file

@ -1,6 +1,7 @@
pub mod error;
use argon2::{Config, Variant};
use axum_extra::extract::cookie::{Cookie, SameSite};
use cmp::Ordering;
use rand::prelude::*;
use ring::digest;
@ -8,7 +9,7 @@ use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonO
use std::{
cmp, fmt,
str::FromStr,
time::{SystemTime, UNIX_EPOCH},
time::{Duration, SystemTime, UNIX_EPOCH},
};
pub fn millis_since_unix_epoch() -> u64 {
@ -142,6 +143,29 @@ pub fn deserialize_from_str<
deserializer.deserialize_str(Visitor(std::marker::PhantomData))
}
pub fn build_cookie<'c>(
name: &'c str,
value: &'c str,
path: &'c str,
max_age: Option<u64>,
) -> Cookie<'c> {
let mut cookie = Cookie::new(name, value);
cookie.set_path(path);
cookie.set_secure(true);
cookie.set_http_only(true);
cookie.set_same_site(SameSite::None);
cookie.set_max_age(
max_age
.map(Duration::from_secs)
.map(TryInto::try_into)
.transpose()
.expect("time overflow"),
);
cookie
}
// Copied from librustdoc:
// https://github.com/rust-lang/rust/blob/cbaeec14f90b59a91a6b0f17fc046c66fa811892/src/librustdoc/html/escape.rs