1
0
Fork 0
mirror of https://gitlab.com/famedly/conduit.git synced 2025-04-22 14:10:16 +03:00

Merge branch 'refresh-tokens' into 'next'

Draft: Refresh token support

See merge request 
This commit is contained in:
avdb 2024-11-04 09:05:59 +00:00
commit 6c020b690f
7 changed files with 181 additions and 9 deletions
src
api
client_server
ruma_wrapper
config
database
service/users

View file

@ -1,12 +1,14 @@
use std::time::Duration;
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{services, utils, Error, Result, Ruma};
use ruma::{
api::client::{
error::ErrorKind,
session::{get_login_types, login, logout, logout_all},
session::{get_login_types, login, logout, logout_all, refresh_token},
uiaa::UserIdentifier,
},
UserId,
OwnedDeviceId, UserId,
};
use serde::Deserialize;
use tracing::{info, warn};
@ -179,7 +181,16 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
.unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());
// Generate a new token for the device
let token = utils::random_string(TOKEN_LENGTH);
let access_token = utils::random_string(TOKEN_LENGTH);
let (refresh_token, expires_at) = match body.refresh_token {
false => (None, None),
_ => services()
.users
.create_refresh_token(&access_token)
.map(Some)
.map(Option::unzip)?,
};
// Determine if device_id was provided and exists in the db for this user
let device_exists = body.device_id.as_ref().map_or(false, |device_id| {
@ -190,12 +201,14 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
});
if device_exists {
services().users.set_token(&user_id, &device_id, &token)?;
services()
.users
.set_token(&user_id, &device_id, &access_token)?;
} else {
services().users.create_device(
&user_id,
&device_id,
&token,
&access_token,
body.initial_device_display_name.clone(),
)?;
}
@ -206,12 +219,14 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
#[allow(deprecated)]
Ok(login::v3::Response {
user_id,
access_token: token,
access_token,
home_server: Some(services().globals.server_name().to_owned()),
device_id,
well_known: None,
refresh_token: None,
expires_in: None,
refresh_token,
expires_in: expires_at
.map(|n| n.checked_sub(utils::millis_since_unix_epoch()).expect(""))
.map(Duration::from_millis),
})
}
@ -277,3 +292,49 @@ pub async fn logout_all_route(
Ok(logout_all::v3::Response::new())
}
pub async fn refresh_token_route(
body: Ruma<refresh_token::v3::Request>,
) -> Result<refresh_token::v3::Response> {
let expires_at = services()
.users
.get_refresh_token_ttl(&body.refresh_token)?
.ok_or_else(|| {
Error::BadRequest(
ErrorKind::UnknownToken { soft_logout: false },
"Unknown refresh token.",
)
})?;
if expires_at < utils::millis_since_unix_epoch() {
return Err(Error::BadRequest(
ErrorKind::UnknownToken { soft_logout: false },
"Expired refresh token.",
));
}
let (user_id, device_id) = {
let access_token = services()
.users
.refresh_to_access_token(&body.refresh_token)?
.expect("");
services().users.find_from_token(&access_token)?.expect("")
};
let access_token = utils::random_string(TOKEN_LENGTH);
let (refresh_token, expires_at) = services().users.create_refresh_token(&access_token)?;
let device_id: OwnedDeviceId = device_id.into();
services()
.users
.set_token(&user_id, &device_id, &access_token)?;
Ok(refresh_token::v3::Response {
access_token,
refresh_token: Some(refresh_token),
expires_in_ms: Some(Duration::from_millis(
expires_at
.checked_sub(utils::millis_since_unix_epoch())
.expect(""),
)),
})
}

View file

@ -23,7 +23,7 @@ use serde::Deserialize;
use tracing::{debug, error, warn};
use super::{Ruma, RumaResponse};
use crate::{service::appservice::RegistrationInfo, services, Error, Result};
use crate::{service::appservice::RegistrationInfo, services, utils, Error, Result};
enum Token {
Appservice(Box<RegistrationInfo>),
@ -87,6 +87,17 @@ where
if let Some(reg_info) = services().appservice.find_from_token(token).await {
Token::Appservice(Box::new(reg_info.clone()))
} else if let Some((user_id, device_id)) = services().users.find_from_token(token)? {
if services()
.users
.get_access_token_ttl(token)?
.is_some_and(|expires_at| expires_at < utils::millis_since_unix_epoch())
{
return Err(Error::BadRequest(
ErrorKind::UnknownToken { soft_logout: true },
"Expired access token.",
));
}
Token::User((user_id, OwnedDeviceId::from(device_id)))
} else {
Token::Invalid

View file

@ -47,6 +47,8 @@ pub struct Config {
#[serde(default = "false_fn")]
pub allow_registration: bool,
pub registration_token: Option<String>,
#[serde(default, flatten)]
pub refresh_token: RefreshTokenConfig,
#[serde(default = "default_openid_token_ttl")]
pub openid_token_ttl: u64,
#[serde(default = "true_fn")]
@ -101,6 +103,14 @@ pub struct WellKnownConfig {
pub server: Option<OwnedServerName>,
}
#[derive(Clone, Debug, Deserialize, Default)]
pub struct RefreshTokenConfig {
#[serde(default = "default_refresh_token_ttl")]
pub ttl: u64,
#[serde(default = "default_access_token_ttl")]
pub access_token_ttl: u64,
}
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
impl Config {
@ -304,6 +314,14 @@ fn default_turn_ttl() -> u64 {
60 * 60 * 24
}
fn default_refresh_token_ttl() -> u64 {
60 * 60
}
fn default_access_token_ttl() -> u64 {
60 * 5
}
fn default_openid_token_ttl() -> u64 {
60 * 60
}

View file

@ -945,6 +945,60 @@ impl service::users::Data for KeyValueDatabase {
}
}
fn create_refresh_token(&self, access_token: &str) -> Result<(String, u64)> {
let crate::config::RefreshTokenConfig {
ttl,
access_token_ttl,
} = services().globals.config.refresh_token;
let refresh_token = utils::random_string(TOKEN_LENGTH);
let mut value = refresh_token.as_bytes().to_vec();
value.extend_from_slice(
&utils::millis_since_unix_epoch()
.checked_add(access_token_ttl * 1000)
.expect("time is valid")
.to_be_bytes(),
);
self.accesstoken_refreshtokenttl
.insert(access_token.as_bytes(), &value)?;
let mut value = access_token.as_bytes().to_vec();
value.extend_from_slice(
&utils::millis_since_unix_epoch()
.checked_add(ttl * 1000)
.expect("time is valid")
.to_be_bytes(),
);
self.refreshtoken_accesstokenttl
.insert(refresh_token.as_bytes(), &value)?;
Ok((refresh_token, access_token_ttl))
}
fn refresh_to_access_token(&self, refresh_token: &str) -> Result<Option<String>> {
Ok(self
.refreshtoken_accesstokenttl
.get(refresh_token.as_bytes())?
.map(|v| utils::string_from_bytes(&v[..TOKEN_LENGTH]).expect("")))
}
fn get_access_token_ttl(&self, access_token: &str) -> Result<Option<u64>> {
Ok(self
.accesstoken_refreshtokenttl
.get(access_token.as_bytes())?
.map(|v| u64::from_be_bytes(v[TOKEN_LENGTH..].try_into().expect(""))))
}
fn get_refresh_token_ttl(&self, refresh_token: &str) -> Result<Option<u64>> {
Ok(self
.refreshtoken_accesstokenttl
.get(refresh_token.as_bytes())?
.map(|v| u64::from_be_bytes(v[TOKEN_LENGTH..].try_into().expect(""))))
}
// Creates an OpenID token, which can be used to prove that a user has access to an account (primarily for integrations)
fn create_openid_token(&self, user_id: &UserId) -> Result<(String, u64)> {
let token = utils::random_string(TOKEN_LENGTH);

View file

@ -59,6 +59,8 @@ pub struct KeyValueDatabase {
pub(super) userid_selfsigningkeyid: Arc<dyn KvTree>,
pub(super) userid_usersigningkeyid: Arc<dyn KvTree>,
pub(super) openidtoken_expiresatuserid: Arc<dyn KvTree>, // expiresatuserid = expiresat + userid
pub(super) accesstoken_refreshtokenttl: Arc<dyn KvTree>,
pub(super) refreshtoken_accesstokenttl: Arc<dyn KvTree>,
pub(super) userfilterid_filter: Arc<dyn KvTree>, // UserFilterId = UserId + FilterId
@ -295,6 +297,8 @@ impl KeyValueDatabase {
userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?,
userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?,
openidtoken_expiresatuserid: builder.open_tree("openidtoken_expiresatuserid")?,
accesstoken_refreshtokenttl: builder.open_tree("accesstoken_refreshtokenttl")?,
refreshtoken_accesstokenttl: builder.open_tree("refreshtoken_accesstokenttl")?,
userfilterid_filter: builder.open_tree("userfilterid_filter")?,
todeviceid_events: builder.open_tree("todeviceid_events")?,

View file

@ -212,6 +212,14 @@ pub trait Data: Send + Sync {
fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>>;
fn create_refresh_token(&self, access_token: &str) -> Result<(String, u64)>;
fn refresh_to_access_token(&self, refresh_token: &str) -> Result<Option<String>>;
fn get_access_token_ttl(&self, access_token: &str) -> Result<Option<u64>>;
fn get_refresh_token_ttl(&self, refresh_token: &str) -> Result<Option<u64>>;
// Creates an OpenID token, which can be used to prove that a user has access to an account (primarily for integrations)
fn create_openid_token(&self, user_id: &UserId) -> Result<(String, u64)>;

View file

@ -593,6 +593,22 @@ impl Service {
self.db.get_filter(user_id, filter_id)
}
pub fn create_refresh_token(&self, access_token: &str) -> Result<(String, u64)> {
self.db.create_refresh_token(access_token)
}
pub fn refresh_to_access_token(&self, refresh_token: &str) -> Result<Option<String>> {
self.db.refresh_to_access_token(refresh_token)
}
pub fn get_access_token_ttl(&self, access_token: &str) -> Result<Option<u64>> {
self.db.get_access_token_ttl(access_token)
}
pub fn get_refresh_token_ttl(&self, refresh_token: &str) -> Result<Option<u64>> {
self.db.get_refresh_token_ttl(refresh_token)
}
// Creates an OpenID token, which can be used to prove that a user has access to an account (primarily for integrations)
pub fn create_openid_token(&self, user_id: &UserId) -> Result<(String, u64)> {
self.db.create_openid_token(user_id)