mirror of
https://gitlab.com/famedly/conduit.git
synced 2025-04-22 14:10:16 +03:00
Merge branch 'refresh-tokens' into 'next'
Draft: Refresh token support See merge request famedly/conduit!714
This commit is contained in:
commit
6c020b690f
7 changed files with 181 additions and 9 deletions
|
@ -1,12 +1,14 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
session::{get_login_types, login, logout, logout_all},
|
||||
session::{get_login_types, login, logout, logout_all, refresh_token},
|
||||
uiaa::UserIdentifier,
|
||||
},
|
||||
UserId,
|
||||
OwnedDeviceId, UserId,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use tracing::{info, warn};
|
||||
|
@ -179,7 +181,16 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
|||
.unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());
|
||||
|
||||
// Generate a new token for the device
|
||||
let token = utils::random_string(TOKEN_LENGTH);
|
||||
let access_token = utils::random_string(TOKEN_LENGTH);
|
||||
|
||||
let (refresh_token, expires_at) = match body.refresh_token {
|
||||
false => (None, None),
|
||||
_ => services()
|
||||
.users
|
||||
.create_refresh_token(&access_token)
|
||||
.map(Some)
|
||||
.map(Option::unzip)?,
|
||||
};
|
||||
|
||||
// Determine if device_id was provided and exists in the db for this user
|
||||
let device_exists = body.device_id.as_ref().map_or(false, |device_id| {
|
||||
|
@ -190,12 +201,14 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
|||
});
|
||||
|
||||
if device_exists {
|
||||
services().users.set_token(&user_id, &device_id, &token)?;
|
||||
services()
|
||||
.users
|
||||
.set_token(&user_id, &device_id, &access_token)?;
|
||||
} else {
|
||||
services().users.create_device(
|
||||
&user_id,
|
||||
&device_id,
|
||||
&token,
|
||||
&access_token,
|
||||
body.initial_device_display_name.clone(),
|
||||
)?;
|
||||
}
|
||||
|
@ -206,12 +219,14 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
|||
#[allow(deprecated)]
|
||||
Ok(login::v3::Response {
|
||||
user_id,
|
||||
access_token: token,
|
||||
access_token,
|
||||
home_server: Some(services().globals.server_name().to_owned()),
|
||||
device_id,
|
||||
well_known: None,
|
||||
refresh_token: None,
|
||||
expires_in: None,
|
||||
refresh_token,
|
||||
expires_in: expires_at
|
||||
.map(|n| n.checked_sub(utils::millis_since_unix_epoch()).expect(""))
|
||||
.map(Duration::from_millis),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -277,3 +292,49 @@ pub async fn logout_all_route(
|
|||
|
||||
Ok(logout_all::v3::Response::new())
|
||||
}
|
||||
|
||||
pub async fn refresh_token_route(
|
||||
body: Ruma<refresh_token::v3::Request>,
|
||||
) -> Result<refresh_token::v3::Response> {
|
||||
let expires_at = services()
|
||||
.users
|
||||
.get_refresh_token_ttl(&body.refresh_token)?
|
||||
.ok_or_else(|| {
|
||||
Error::BadRequest(
|
||||
ErrorKind::UnknownToken { soft_logout: false },
|
||||
"Unknown refresh token.",
|
||||
)
|
||||
})?;
|
||||
if expires_at < utils::millis_since_unix_epoch() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UnknownToken { soft_logout: false },
|
||||
"Expired refresh token.",
|
||||
));
|
||||
}
|
||||
|
||||
let (user_id, device_id) = {
|
||||
let access_token = services()
|
||||
.users
|
||||
.refresh_to_access_token(&body.refresh_token)?
|
||||
.expect("");
|
||||
services().users.find_from_token(&access_token)?.expect("")
|
||||
};
|
||||
|
||||
let access_token = utils::random_string(TOKEN_LENGTH);
|
||||
let (refresh_token, expires_at) = services().users.create_refresh_token(&access_token)?;
|
||||
|
||||
let device_id: OwnedDeviceId = device_id.into();
|
||||
services()
|
||||
.users
|
||||
.set_token(&user_id, &device_id, &access_token)?;
|
||||
|
||||
Ok(refresh_token::v3::Response {
|
||||
access_token,
|
||||
refresh_token: Some(refresh_token),
|
||||
expires_in_ms: Some(Duration::from_millis(
|
||||
expires_at
|
||||
.checked_sub(utils::millis_since_unix_epoch())
|
||||
.expect(""),
|
||||
)),
|
||||
})
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ use serde::Deserialize;
|
|||
use tracing::{debug, error, warn};
|
||||
|
||||
use super::{Ruma, RumaResponse};
|
||||
use crate::{service::appservice::RegistrationInfo, services, Error, Result};
|
||||
use crate::{service::appservice::RegistrationInfo, services, utils, Error, Result};
|
||||
|
||||
enum Token {
|
||||
Appservice(Box<RegistrationInfo>),
|
||||
|
@ -87,6 +87,17 @@ where
|
|||
if let Some(reg_info) = services().appservice.find_from_token(token).await {
|
||||
Token::Appservice(Box::new(reg_info.clone()))
|
||||
} else if let Some((user_id, device_id)) = services().users.find_from_token(token)? {
|
||||
if services()
|
||||
.users
|
||||
.get_access_token_ttl(token)?
|
||||
.is_some_and(|expires_at| expires_at < utils::millis_since_unix_epoch())
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UnknownToken { soft_logout: true },
|
||||
"Expired access token.",
|
||||
));
|
||||
}
|
||||
|
||||
Token::User((user_id, OwnedDeviceId::from(device_id)))
|
||||
} else {
|
||||
Token::Invalid
|
||||
|
|
|
@ -47,6 +47,8 @@ pub struct Config {
|
|||
#[serde(default = "false_fn")]
|
||||
pub allow_registration: bool,
|
||||
pub registration_token: Option<String>,
|
||||
#[serde(default, flatten)]
|
||||
pub refresh_token: RefreshTokenConfig,
|
||||
#[serde(default = "default_openid_token_ttl")]
|
||||
pub openid_token_ttl: u64,
|
||||
#[serde(default = "true_fn")]
|
||||
|
@ -101,6 +103,14 @@ pub struct WellKnownConfig {
|
|||
pub server: Option<OwnedServerName>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Default)]
|
||||
pub struct RefreshTokenConfig {
|
||||
#[serde(default = "default_refresh_token_ttl")]
|
||||
pub ttl: u64,
|
||||
#[serde(default = "default_access_token_ttl")]
|
||||
pub access_token_ttl: u64,
|
||||
}
|
||||
|
||||
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
|
||||
|
||||
impl Config {
|
||||
|
@ -304,6 +314,14 @@ fn default_turn_ttl() -> u64 {
|
|||
60 * 60 * 24
|
||||
}
|
||||
|
||||
fn default_refresh_token_ttl() -> u64 {
|
||||
60 * 60
|
||||
}
|
||||
|
||||
fn default_access_token_ttl() -> u64 {
|
||||
60 * 5
|
||||
}
|
||||
|
||||
fn default_openid_token_ttl() -> u64 {
|
||||
60 * 60
|
||||
}
|
||||
|
|
|
@ -945,6 +945,60 @@ impl service::users::Data for KeyValueDatabase {
|
|||
}
|
||||
}
|
||||
|
||||
fn create_refresh_token(&self, access_token: &str) -> Result<(String, u64)> {
|
||||
let crate::config::RefreshTokenConfig {
|
||||
ttl,
|
||||
access_token_ttl,
|
||||
} = services().globals.config.refresh_token;
|
||||
|
||||
let refresh_token = utils::random_string(TOKEN_LENGTH);
|
||||
|
||||
let mut value = refresh_token.as_bytes().to_vec();
|
||||
value.extend_from_slice(
|
||||
&utils::millis_since_unix_epoch()
|
||||
.checked_add(access_token_ttl * 1000)
|
||||
.expect("time is valid")
|
||||
.to_be_bytes(),
|
||||
);
|
||||
|
||||
self.accesstoken_refreshtokenttl
|
||||
.insert(access_token.as_bytes(), &value)?;
|
||||
|
||||
let mut value = access_token.as_bytes().to_vec();
|
||||
value.extend_from_slice(
|
||||
&utils::millis_since_unix_epoch()
|
||||
.checked_add(ttl * 1000)
|
||||
.expect("time is valid")
|
||||
.to_be_bytes(),
|
||||
);
|
||||
|
||||
self.refreshtoken_accesstokenttl
|
||||
.insert(refresh_token.as_bytes(), &value)?;
|
||||
|
||||
Ok((refresh_token, access_token_ttl))
|
||||
}
|
||||
|
||||
fn refresh_to_access_token(&self, refresh_token: &str) -> Result<Option<String>> {
|
||||
Ok(self
|
||||
.refreshtoken_accesstokenttl
|
||||
.get(refresh_token.as_bytes())?
|
||||
.map(|v| utils::string_from_bytes(&v[..TOKEN_LENGTH]).expect("")))
|
||||
}
|
||||
|
||||
fn get_access_token_ttl(&self, access_token: &str) -> Result<Option<u64>> {
|
||||
Ok(self
|
||||
.accesstoken_refreshtokenttl
|
||||
.get(access_token.as_bytes())?
|
||||
.map(|v| u64::from_be_bytes(v[TOKEN_LENGTH..].try_into().expect(""))))
|
||||
}
|
||||
|
||||
fn get_refresh_token_ttl(&self, refresh_token: &str) -> Result<Option<u64>> {
|
||||
Ok(self
|
||||
.refreshtoken_accesstokenttl
|
||||
.get(refresh_token.as_bytes())?
|
||||
.map(|v| u64::from_be_bytes(v[TOKEN_LENGTH..].try_into().expect(""))))
|
||||
}
|
||||
|
||||
// Creates an OpenID token, which can be used to prove that a user has access to an account (primarily for integrations)
|
||||
fn create_openid_token(&self, user_id: &UserId) -> Result<(String, u64)> {
|
||||
let token = utils::random_string(TOKEN_LENGTH);
|
||||
|
|
|
@ -59,6 +59,8 @@ pub struct KeyValueDatabase {
|
|||
pub(super) userid_selfsigningkeyid: Arc<dyn KvTree>,
|
||||
pub(super) userid_usersigningkeyid: Arc<dyn KvTree>,
|
||||
pub(super) openidtoken_expiresatuserid: Arc<dyn KvTree>, // expiresatuserid = expiresat + userid
|
||||
pub(super) accesstoken_refreshtokenttl: Arc<dyn KvTree>,
|
||||
pub(super) refreshtoken_accesstokenttl: Arc<dyn KvTree>,
|
||||
|
||||
pub(super) userfilterid_filter: Arc<dyn KvTree>, // UserFilterId = UserId + FilterId
|
||||
|
||||
|
@ -295,6 +297,8 @@ impl KeyValueDatabase {
|
|||
userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?,
|
||||
userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?,
|
||||
openidtoken_expiresatuserid: builder.open_tree("openidtoken_expiresatuserid")?,
|
||||
accesstoken_refreshtokenttl: builder.open_tree("accesstoken_refreshtokenttl")?,
|
||||
refreshtoken_accesstokenttl: builder.open_tree("refreshtoken_accesstokenttl")?,
|
||||
userfilterid_filter: builder.open_tree("userfilterid_filter")?,
|
||||
todeviceid_events: builder.open_tree("todeviceid_events")?,
|
||||
|
||||
|
|
|
@ -212,6 +212,14 @@ pub trait Data: Send + Sync {
|
|||
|
||||
fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>>;
|
||||
|
||||
fn create_refresh_token(&self, access_token: &str) -> Result<(String, u64)>;
|
||||
|
||||
fn refresh_to_access_token(&self, refresh_token: &str) -> Result<Option<String>>;
|
||||
|
||||
fn get_access_token_ttl(&self, access_token: &str) -> Result<Option<u64>>;
|
||||
|
||||
fn get_refresh_token_ttl(&self, refresh_token: &str) -> Result<Option<u64>>;
|
||||
|
||||
// Creates an OpenID token, which can be used to prove that a user has access to an account (primarily for integrations)
|
||||
fn create_openid_token(&self, user_id: &UserId) -> Result<(String, u64)>;
|
||||
|
||||
|
|
|
@ -593,6 +593,22 @@ impl Service {
|
|||
self.db.get_filter(user_id, filter_id)
|
||||
}
|
||||
|
||||
pub fn create_refresh_token(&self, access_token: &str) -> Result<(String, u64)> {
|
||||
self.db.create_refresh_token(access_token)
|
||||
}
|
||||
|
||||
pub fn refresh_to_access_token(&self, refresh_token: &str) -> Result<Option<String>> {
|
||||
self.db.refresh_to_access_token(refresh_token)
|
||||
}
|
||||
|
||||
pub fn get_access_token_ttl(&self, access_token: &str) -> Result<Option<u64>> {
|
||||
self.db.get_access_token_ttl(access_token)
|
||||
}
|
||||
|
||||
pub fn get_refresh_token_ttl(&self, refresh_token: &str) -> Result<Option<u64>> {
|
||||
self.db.get_refresh_token_ttl(refresh_token)
|
||||
}
|
||||
|
||||
// Creates an OpenID token, which can be used to prove that a user has access to an account (primarily for integrations)
|
||||
pub fn create_openid_token(&self, user_id: &UserId) -> Result<(String, u64)> {
|
||||
self.db.create_openid_token(user_id)
|
||||
|
|
Loading…
Reference in a new issue