pub mod abstraction;

pub mod account_data;
pub mod admin;
pub mod appservice;
pub mod globals;
pub mod key_backups;
pub mod media;
pub mod pusher;
pub mod rooms;
pub mod sending;
pub mod transaction_ids;
pub mod uiaa;
pub mod users;

use crate::{utils, Error, Result};
use abstraction::DatabaseEngine;
use directories::ProjectDirs;
use log::error;
use rocket::futures::{channel::mpsc, stream::FuturesUnordered, StreamExt};
use ruma::{DeviceId, ServerName, UserId};
use serde::Deserialize;
use std::{
    collections::HashMap,
    fs::{self, remove_dir_all},
    io::Write,
    sync::{Arc, RwLock},
};
use tokio::sync::Semaphore;

/// Server configuration, deserialized with serde.
///
/// Fields carrying a `#[serde(default ...)]` attribute fall back to the listed helper
/// (or the type's `Default`) when absent from the input.
#[derive(Clone, Debug, Deserialize)]
pub struct Config {
    /// This server's name (a `ruma::ServerName`).
    server_name: Box<ServerName>,
    /// Database location. NOTE(review): interpreted by `Engine::open` — confirm semantics there.
    database_path: String,
    /// Cache size for the storage engine; defaults to `default_cache_capacity()`
    /// (1024 * 1024 * 1024). NOTE(review): unit presumably bytes — confirm in the engine.
    #[serde(default = "default_cache_capacity")]
    cache_capacity: u32,
    /// Maximum request size in bytes; defaults to `default_max_request_size()` (20 MB).
    /// `Database::load_or_create` prints an error if it is below 1024.
    #[serde(default = "default_max_request_size")]
    max_request_size: u32,
    /// Size of the semaphore guarding outgoing requests in `sending::Sending`;
    /// defaults to `default_max_concurrent_requests()` (100).
    #[serde(default = "default_max_concurrent_requests")]
    max_concurrent_requests: u16,
    /// Whether registration is allowed; defaults to `true`.
    #[serde(default = "true_fn")]
    allow_registration: bool,
    /// Whether encryption is allowed; defaults to `true`.
    #[serde(default = "true_fn")]
    allow_encryption: bool,
    /// Whether federation is allowed; defaults to `false`.
    #[serde(default = "false_fn")]
    allow_federation: bool,
    /// Defaults to `false`. NOTE(review): presumably toggles Jaeger tracing — confirm at the use site.
    #[serde(default = "false_fn")]
    pub allow_jaeger: bool,
    /// Outgoing proxy settings; defaults to `ProxyConfig::None`.
    #[serde(default)]
    proxy: ProxyConfig,
    /// Optional JWT secret. NOTE(review): consumer not visible here.
    jwt_secret: Option<String>,
    /// Trusted servers; defaults to an empty list.
    #[serde(default = "Vec::new")]
    trusted_servers: Vec<Box<ServerName>>,
    /// Log filter string; defaults to `default_log()`.
    #[serde(default = "default_log")]
    pub log: String,
}

/// Serde default helper that yields `false`.
fn false_fn() -> bool {
    bool::default()
}

/// Serde default helper that yields `true`.
fn true_fn() -> bool {
    true
}

/// Default `cache_capacity`: 2^30 (i.e. 1024 * 1024 * 1024).
fn default_cache_capacity() -> u32 {
    1 << 30
}

/// Default `max_request_size`: 20 MiB, expressed in bytes.
fn default_max_request_size() -> u32 {
    20 << 20
}

/// Default `max_concurrent_requests`.
fn default_max_concurrent_requests() -> u16 {
    100
}

/// Default log filter string.
fn default_log() -> String {
    String::from("info,state_res=warn,rocket=off,_=off,sled=off")
}

/// Storage engine used when the `sled` cargo feature is enabled.
///
/// NOTE(review): both aliases are named `Engine`. Enabling both `sled` and `rocksdb`
/// makes them conflict, and enabling neither leaves `Engine` undefined.
#[cfg(feature = "sled")]
pub type Engine = abstraction::SledEngine;

/// Storage engine used when the `rocksdb` cargo feature is enabled.
#[cfg(feature = "rocksdb")]
pub type Engine = abstraction::RocksDbEngine;

/// Outgoing HTTP proxy configuration.
///
/// Because of `rename_all = "snake_case"`, the variants are spelled `none`, `global` and
/// `by_domain` in the config. Converted into a `reqwest::Proxy` by `ProxyConfig::to_proxy`.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ProxyConfig {
    /// Do not use a proxy.
    None,
    /// Route every request through `url` (`reqwest::Proxy::all`).
    Global {
        #[serde(deserialize_with = "crate::utils::deserialize_from_str")]
        url: reqwest::Url,
    },
    /// Per-domain rules. For each request, the first entry whose rules apply is used.
    ByDomain(Vec<PartialProxyConfig>),
}
impl ProxyConfig {
    /// Builds the `reqwest` proxy described by this configuration.
    ///
    /// Returns `Ok(None)` for [`ProxyConfig::None`]. For `ByDomain`, the returned proxy
    /// consults the entries in order and uses the first one that applies to a URL.
    ///
    /// # Errors
    /// Propagates the error from `reqwest::Proxy::all` for the `Global` variant.
    pub fn to_proxy(&self) -> Result<Option<reqwest::Proxy>> {
        let proxy = match self {
            ProxyConfig::None => return Ok(None),
            ProxyConfig::Global { url } => reqwest::Proxy::all(url.clone())?,
            ProxyConfig::ByDomain(entries) => {
                // The closure must be 'static, so it owns its own copy of the rules.
                let proxies = entries.clone();
                reqwest::Proxy::custom(move |url| {
                    // first matching proxy
                    proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned()
                })
            }
        };
        Ok(Some(proxy))
    }
}
impl Default for ProxyConfig {
    fn default() -> Self {
        ProxyConfig::None
    }
}

/// One entry of `ProxyConfig::ByDomain`: a proxy URL plus include/exclude domain rules.
///
/// An empty `include` list is treated as `*` by `for_url`.
#[derive(Clone, Debug, Deserialize)]
pub struct PartialProxyConfig {
    /// Proxy URL returned by `for_url` when this entry applies.
    #[serde(deserialize_with = "crate::utils::deserialize_from_str")]
    url: reqwest::Url,
    /// Domains this entry applies to; defaults to an empty list.
    #[serde(default)]
    include: Vec<WildCardedDomain>,
    /// Domains excluded from this entry; defaults to an empty list.
    #[serde(default)]
    exclude: Vec<WildCardedDomain>,
}
impl PartialProxyConfig {
    pub fn for_url(&self, url: &reqwest::Url) -> Option<&reqwest::Url> {
        let domain = url.domain()?;
        let mut included_because = None; // most specific reason it was included
        let mut excluded_because = None; // most specific reason it was excluded
        if self.include.is_empty() {
            // treat empty include list as `*`
            included_because = Some(&WildCardedDomain::WildCard)
        }
        for wc_domain in &self.include {
            if wc_domain.matches(domain) {
                match included_because {
                    Some(prev) if !wc_domain.more_specific_than(prev) => (),
                    _ => included_because = Some(wc_domain),
                }
            }
        }
        for wc_domain in &self.exclude {
            if wc_domain.matches(domain) {
                match excluded_because {
                    Some(prev) if !wc_domain.more_specific_than(prev) => (),
                    _ => excluded_because = Some(wc_domain),
                }
            }
        }
        match (included_because, excluded_because) {
            (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
            (Some(_), None) => Some(&self.url),
            _ => None,
        }
    }
}

/// A domain name, that optionally allows a * as its first subdomain.
///
/// Parsed from config strings via `FromStr`. The input is lowercased first, because
/// domain names are case-insensitive:
/// - `"*"` becomes [`WildCardedDomain::WildCard`], which matches every domain.
/// - `"*.example.com"` becomes [`WildCardedDomain::WildCarded`] holding `".example.com"`.
///   It matches proper subdomains such as `foo.example.com`, but not `example.com` itself.
/// - Anything else becomes [`WildCardedDomain::Exact`].
///
/// All comparisons in `matches` and `more_specific_than` ignore ASCII case.
#[derive(Clone, Debug)]
pub enum WildCardedDomain {
    /// Matches any domain.
    WildCard,
    /// Matches domains ending with this suffix; the suffix keeps its leading `.`.
    WildCarded(String),
    /// Matches exactly this domain.
    Exact(String),
}
impl WildCardedDomain {
    /// Returns whether `domain` is covered by this pattern (ASCII case-insensitive).
    pub fn matches(&self, domain: &str) -> bool {
        match self {
            WildCardedDomain::WildCard => true,
            WildCardedDomain::WildCarded(d) => Self::ends_with_ignore_ascii_case(domain, d),
            WildCardedDomain::Exact(d) => domain.eq_ignore_ascii_case(d),
        }
    }

    /// Returns whether `self` is strictly more specific than `other`:
    /// - Anything except `WildCard` beats `WildCard`.
    /// - An exact domain beats a wildcard that covers it.
    /// - A longer wildcard suffix beats a shorter one that it extends.
    pub fn more_specific_than(&self, other: &Self) -> bool {
        match (self, other) {
            (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
            (_, WildCardedDomain::WildCard) => true,
            (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
            (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => {
                !a.eq_ignore_ascii_case(b) && Self::ends_with_ignore_ascii_case(a, b)
            }
            _ => false,
        }
    }

    /// ASCII case-insensitive `str::ends_with`. This works on bytes, so a non-ASCII
    /// `s` cannot cause a char-boundary panic.
    fn ends_with_ignore_ascii_case(s: &str, suffix: &str) -> bool {
        let (s, suffix) = (s.as_bytes(), suffix.as_bytes());
        s.len() >= suffix.len() && s[s.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
    }
}
impl std::str::FromStr for WildCardedDomain {
    type Err = std::convert::Infallible;
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // Domain names are case-insensitive; store a normalised form.
        let s = s.to_ascii_lowercase();
        // maybe do some domain validation?
        Ok(if s == "*" {
            WildCardedDomain::WildCard
        } else if s.starts_with("*.") {
            // Keep the leading '.' so only proper subdomains match.
            WildCardedDomain::WildCarded(s[1..].to_owned())
        } else {
            WildCardedDomain::Exact(s)
        })
    }
}
impl<'de> serde::Deserialize<'de> for WildCardedDomain {
    /// Deserializes a `WildCardedDomain` from a string via `crate::utils::deserialize_from_str`.
    fn deserialize<D: serde::Deserializer<'de>>(
        deserializer: D,
    ) -> std::result::Result<Self, D::Error> {
        crate::utils::deserialize_from_str(deserializer)
    }
}

/// Handle to all persistent state, grouped by subsystem.
///
/// Built by `Database::load_or_create`, which opens every tree and wires up the
/// admin and sending channels.
pub struct Database {
    /// Global state, loaded from the `global` and `server_signingkeys` trees plus the config.
    pub globals: globals::Globals,
    pub users: users::Users,
    pub uiaa: uiaa::Uiaa,
    pub rooms: rooms::Rooms,
    pub account_data: account_data::AccountData,
    pub media: media::Media,
    pub key_backups: key_backups::KeyBackups,
    pub transaction_ids: transaction_ids::TransactionIds,
    /// Holds the sending-queue trees, a request-limiting semaphore and the channel sender.
    pub sending: sending::Sending,
    /// Holds the sender half of the admin channel.
    pub admin: admin::Admin,
    pub appservice: appservice::Appservice,
    pub pusher: pusher::PushData,
}

impl Database {
    /// Best-effort removal of the on-disk data directory for `server_name`.
    ///
    /// Deletes `<data_dir>/<server_name>`, where `data_dir` comes from
    /// `ProjectDirs::from("xyz", "koesters", "conduit")`. Any error from the removal
    /// itself is ignored.
    ///
    /// # Errors
    /// Returns an error only if the OS does not return a valid home directory path.
    pub fn try_remove(server_name: &str) -> Result<()> {
        let mut path = ProjectDirs::from("xyz", "koesters", "conduit")
            .ok_or_else(|| Error::bad_config("The OS didn't return a valid home directory path."))?
            .data_dir()
            .to_path_buf();
        path.push(server_name);
        // Deliberately ignored: the directory may not exist.
        let _ = remove_dir_all(path);

        Ok(())
    }

    /// Load an existing database or create a new one.
    ///
    /// The steps, in order:
    /// 1. Open the storage engine and every tree.
    /// 2. Build `Globals` from the config.
    /// 3. Run the schema migrations up to version 4, each gated on `database_version()`.
    /// 4. Clear stored presence data.
    /// 5. Hand the channel receivers to the admin and sending handlers.
    ///
    /// A `max_request_size` below 1024 only prints an error; loading continues.
    ///
    /// # Errors
    /// Propagates errors from opening the engine or trees, loading globals, and any
    /// migration step (including filesystem errors while moving media).
    pub async fn load_or_create(config: Config) -> Result<Arc<Self>> {
        let builder = Engine::open(&config)?;

        if config.max_request_size < 1024 {
            eprintln!("ERROR: Max request size is less than 1KB. Please increase it.");
        }

        // Unbounded channels; the senders are stored in `admin`/`sending` and the
        // receivers are passed to `start_handler` below.
        let (admin_sender, admin_receiver) = mpsc::unbounded();
        let (sending_sender, sending_receiver) = mpsc::unbounded();

        let db = Arc::new(Self {
            users: users::Users {
                userid_password: builder.open_tree("userid_password")?,
                userid_displayname: builder.open_tree("userid_displayname")?,
                userid_avatarurl: builder.open_tree("userid_avatarurl")?,
                userdeviceid_token: builder.open_tree("userdeviceid_token")?,
                userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?,
                userid_devicelistversion: builder.open_tree("userid_devicelistversion")?,
                token_userdeviceid: builder.open_tree("token_userdeviceid")?,
                onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?,
                userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?,
                keychangeid_userid: builder.open_tree("keychangeid_userid")?,
                keyid_key: builder.open_tree("keyid_key")?,
                userid_masterkeyid: builder.open_tree("userid_masterkeyid")?,
                userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?,
                userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?,
                todeviceid_events: builder.open_tree("todeviceid_events")?,
            },
            uiaa: uiaa::Uiaa {
                userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?,
                userdevicesessionid_uiaarequest: builder
                    .open_tree("userdevicesessionid_uiaarequest")?,
            },
            rooms: rooms::Rooms {
                edus: rooms::RoomEdus {
                    readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
                    roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt
                    roomuserid_lastprivatereadupdate: builder
                        .open_tree("roomuserid_lastprivatereadupdate")?,
                    typingid_userid: builder.open_tree("typingid_userid")?,
                    roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?,
                    presenceid_presence: builder.open_tree("presenceid_presence")?,
                    userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?,
                },
                pduid_pdu: builder.open_tree("pduid_pdu")?,
                eventid_pduid: builder.open_tree("eventid_pduid")?,
                roomid_pduleaves: builder.open_tree("roomid_pduleaves")?,

                alias_roomid: builder.open_tree("alias_roomid")?,
                aliasid_alias: builder.open_tree("aliasid_alias")?,
                publicroomids: builder.open_tree("publicroomids")?,

                tokenids: builder.open_tree("tokenids")?,

                roomserverids: builder.open_tree("roomserverids")?,
                serverroomids: builder.open_tree("serverroomids")?,
                userroomid_joined: builder.open_tree("userroomid_joined")?,
                roomuserid_joined: builder.open_tree("roomuserid_joined")?,
                roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?,
                userroomid_invitestate: builder.open_tree("userroomid_invitestate")?,
                roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?,
                userroomid_leftstate: builder.open_tree("userroomid_leftstate")?,
                roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?,

                userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?,
                userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?,

                statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?,
                stateid_shorteventid: builder.open_tree("stateid_shorteventid")?,
                eventid_shorteventid: builder.open_tree("eventid_shorteventid")?,
                shorteventid_eventid: builder.open_tree("shorteventid_eventid")?,
                shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?,
                roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?,
                statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?,

                eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?,
                prevevent_parent: builder.open_tree("prevevent_parent")?,
            },
            account_data: account_data::AccountData {
                roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?,
            },
            media: media::Media {
                mediaid_file: builder.open_tree("mediaid_file")?,
            },
            key_backups: key_backups::KeyBackups {
                backupid_algorithm: builder.open_tree("backupid_algorithm")?,
                backupid_etag: builder.open_tree("backupid_etag")?,
                backupkeyid_backup: builder.open_tree("backupkeyid_backup")?,
            },
            transaction_ids: transaction_ids::TransactionIds {
                userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?,
            },
            sending: sending::Sending {
                servername_educount: builder.open_tree("servername_educount")?,
                servernamepduids: builder.open_tree("servernamepduids")?,
                servercurrentevents: builder.open_tree("servercurrentevents")?,
                // Limits concurrent outgoing requests to `max_concurrent_requests`.
                maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)),
                sender: sending_sender,
            },
            admin: admin::Admin {
                sender: admin_sender,
            },
            appservice: appservice::Appservice {
                cached_registrations: Arc::new(RwLock::new(HashMap::new())),
                id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?,
            },
            pusher: pusher::PushData {
                senderkey_pusher: builder.open_tree("senderkey_pusher")?,
            },
            // Last field because it takes ownership of `config`.
            globals: globals::Globals::load(
                builder.open_tree("global")?,
                builder.open_tree("server_signingkeys")?,
                config,
            )?,
        });

        // MIGRATIONS
        // Each step runs only if the stored version is below its target, then bumps it.
        // TODO: database versions of new dbs should probably not be 0
        if db.globals.database_version()? < 1 {
            // Build the reverse index `serverroomids` (servername 0xff roomid) from the
            // existing `roomserverids` (roomid 0xff servername) keys.
            for (roomserverid, _) in db.rooms.roomserverids.iter() {
                let mut parts = roomserverid.split(|&b| b == 0xff);
                let room_id = parts.next().expect("split always returns one element");
                let servername = match parts.next() {
                    Some(s) => s,
                    None => {
                        error!("Migration: Invalid roomserverid in db.");
                        continue;
                    }
                };
                let mut serverroomid = servername.to_vec();
                serverroomid.push(0xff);
                serverroomid.extend_from_slice(room_id);

                db.rooms.serverroomids.insert(&serverroomid, &[])?;
            }

            db.globals.bump_database_version(1)?;

            println!("Migration: 0 -> 1 finished");
        }

        if db.globals.database_version()? < 2 {
            // We accidentally inserted hashed versions of "" into the db instead of just ""
            for (userid, password) in db.users.userid_password.iter() {
                let password = utils::string_from_bytes(&password);

                // Non-UTF-8 values and verification errors count as "not an empty hash".
                let empty_hashed_password = password.map_or(false, |password| {
                    argon2::verify_encoded(&password, b"").unwrap_or(false)
                });

                if empty_hashed_password {
                    db.users.userid_password.insert(&userid, b"")?;
                }
            }

            db.globals.bump_database_version(2)?;

            println!("Migration: 1 -> 2 finished");
        }

        if db.globals.database_version()? < 3 {
            // Move media to filesystem: write each non-empty blob to the path from
            // `get_media_file`, then replace the stored value with an empty one.
            for (key, content) in db.media.mediaid_file.iter() {
                if content.len() == 0 {
                    continue;
                }

                let path = db.globals.get_media_file(&key);
                let mut file = fs::File::create(path)?;
                file.write_all(&content)?;
                db.media.mediaid_file.insert(&key, &[])?;
            }

            db.globals.bump_database_version(3)?;

            println!("Migration: 2 -> 3 finished");
        }

        if db.globals.database_version()? < 4 {
            // Add federated users to db as deactivated
            // (every remote member of a room joined by a non-deactivated local user).
            for our_user in db.users.iter() {
                let our_user = our_user?;
                if db.users.is_deactivated(&our_user)? {
                    continue;
                }
                for room in db.rooms.rooms_joined(&our_user) {
                    for user in db.rooms.room_members(&room?) {
                        let user = user?;
                        if user.server_name() != db.globals.server_name() {
                            println!("Migration: Creating user {}", user);
                            db.users.create(&user, None)?;
                        }
                    }
                }
            }

            db.globals.bump_database_version(4)?;

            println!("Migration: 3 -> 4 finished");
        }

        // This data is probably outdated
        db.rooms.edus.presenceid_presence.clear()?;

        // Hand the channel receivers to the background handlers; each gets its own Arc.
        db.admin.start_handler(Arc::clone(&db), admin_receiver);
        db.sending.start_handler(Arc::clone(&db), sending_receiver);

        Ok(db)
    }

    /// Waits until any of the trees relevant to `user_id`/`device_id` changes.
    ///
    /// A `watch_prefix` future is registered for each of the following; the method
    /// returns after the first one completes:
    /// - to-device events for this device;
    /// - the user's joined/invite/left room memberships;
    /// - for each joined room: PDUs, typing updates, read receipts, key changes and
    ///   room account data;
    /// - global account data;
    /// - key changes keyed by the user;
    /// - one-time key updates.
    ///
    /// The watched rooms are those joined when the call starts; rooms whose lookup
    /// errors are skipped.
    pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) {
        let userid_bytes = user_id.as_bytes().to_vec();
        let mut userid_prefix = userid_bytes.clone();
        userid_prefix.push(0xff);

        let mut userdeviceid_prefix = userid_prefix.clone();
        userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
        userdeviceid_prefix.push(0xff);

        let mut futures = FuturesUnordered::new();

        // To-device events for this device.
        // NOTE(review): the original comment here read "Return when *any* user changed his key"
        // with "TODO: only send for user they share a room with"; that describes key-change
        // watching, not this to-device watch — confirm where it belongs.
        futures.push(
            self.users
                .todeviceid_events
                .watch_prefix(&userdeviceid_prefix),
        );

        futures.push(self.rooms.userroomid_joined.watch_prefix(&userid_prefix));
        futures.push(
            self.rooms
                .userroomid_invitestate
                .watch_prefix(&userid_prefix),
        );
        futures.push(self.rooms.userroomid_leftstate.watch_prefix(&userid_prefix));

        // Events for rooms we are in
        for room_id in self.rooms.rooms_joined(user_id).filter_map(|r| r.ok()) {
            let roomid_bytes = room_id.as_bytes().to_vec();
            let mut roomid_prefix = roomid_bytes.clone();
            roomid_prefix.push(0xff);

            // PDUs
            futures.push(self.rooms.pduid_pdu.watch_prefix(&roomid_prefix));

            // EDUs
            // NOTE(review): watched without a 0xff separator, presumably because keys are the
            // bare room id; this may also wake for rooms whose id starts with this one.
            futures.push(
                self.rooms
                    .edus
                    .roomid_lasttypingupdate
                    .watch_prefix(&roomid_bytes),
            );

            futures.push(
                self.rooms
                    .edus
                    .readreceiptid_readreceipt
                    .watch_prefix(&roomid_prefix),
            );

            // Key changes
            futures.push(self.users.keychangeid_userid.watch_prefix(&roomid_prefix));

            // Room account data
            let mut roomuser_prefix = roomid_prefix.clone();
            roomuser_prefix.extend_from_slice(&userid_prefix);

            futures.push(
                self.account_data
                    .roomuserdataid_accountdata
                    .watch_prefix(&roomuser_prefix),
            );
        }

        // Global (non-room) account data uses a leading 0xff in place of a room id.
        let mut globaluserdata_prefix = vec![0xff];
        globaluserdata_prefix.extend_from_slice(&userid_prefix);

        futures.push(
            self.account_data
                .roomuserdataid_accountdata
                .watch_prefix(&globaluserdata_prefix),
        );

        // More key changes (used when user is not joined to any rooms)
        futures.push(self.users.keychangeid_userid.watch_prefix(&userid_prefix));

        // One time keys
        // NOTE(review): no 0xff separator here either; other users whose id starts with this
        // one may cause spurious wake-ups.
        futures.push(
            self.users
                .userid_lastonetimekeyupdate
                .watch_prefix(&userid_bytes),
        );

        // Wait until one of them finds something
        futures.next().await;
    }

    /// Flushes pending writes to disk. Currently a no-op that always returns `Ok(())`.
    pub async fn flush(&self) -> Result<()> {
        // noop while we don't use sled 1.0
        //self._db.flush_async().await?;
        Ok(())
    }
}