From ed253e236cff1cdab7bd18dc03a4046da4ca3c4f Mon Sep 17 00:00:00 2001
From: Gabriel Souza Franco <gabrielfrancosouza@gmail.com>
Date: Wed, 21 Apr 2021 00:35:44 -0300
Subject: [PATCH] chore: document FedDest, fix tests

---
 src/server_server.rs | 62 ++++++++++++++++++++++++++++----------------
 1 file changed, 40 insertions(+), 22 deletions(-)

diff --git a/src/server_server.rs b/src/server_server.rs
index ac38f4d6..553f9449 100644
--- a/src/server_server.rs
+++ b/src/server_server.rs
@@ -45,6 +45,20 @@ use std::{
 #[cfg(feature = "conduit_bin")]
 use rocket::{get, post, put};
 
+/// Wraps either an literal IP address plus port, or a hostname plus complement
+/// (colon-plus-port if it was specified).
+///
+/// Note: A `FedDest::Named` might contain an IP address in string form if there
+/// was no port specified to construct a SocketAddr with.
+///
+/// # Examples:
+/// ```rust,ignore
+/// FedDest::Literal("198.51.100.3:8448".parse()?);
+/// FedDest::Literal("[2001:db8::4:5]:443".parse()?);
+/// FedDest::Named("matrix.example.org".to_owned(), "".to_owned());
+/// FedDest::Named("matrix.example.org".to_owned(), ":8448".to_owned());
+/// FedDest::Named("198.51.100.5".to_owned(), "".to_owned());
+/// ```
 #[derive(Clone, Debug, PartialEq)]
 enum FedDest {
     Literal(SocketAddr),
@@ -52,21 +66,21 @@ enum FedDest {
 }
 
 impl FedDest {
-    fn into_https_url(self) -> String {
+    fn into_https_string(self) -> String {
         match self {
             Self::Literal(addr) => format!("https://{}", addr),
             Self::Named(host, port) => format!("https://{}{}", host, port),
         }
     }
 
-    fn into_uri(self) -> String {
+    fn into_uri_string(self) -> String {
         match self {
             Self::Literal(addr) => addr.to_string(),
             Self::Named(host, ref port) => host + port,
         }
     }
 
-    fn host(&self) -> String {
+    fn hostname(&self) -> String {
         match &self {
             Self::Literal(addr) => addr.ip().to_string(),
             Self::Named(host, _) => host.clone(),
@@ -99,21 +113,23 @@ where
     } else {
         let result = find_actual_destination(globals, &destination).await;
         let (actual_destination, host) = result.clone();
-        let result = (result.0.into_https_url(), result.1.into_uri());
+        let result_string = (result.0.into_https_string(), result.1.into_uri_string());
         globals
             .actual_destination_cache
             .write()
             .unwrap()
-            .insert(Box::<ServerName>::from(destination), result.clone());
-        if actual_destination.host() != host.host() {
+            .insert(Box::<ServerName>::from(destination), result_string.clone());
+        let dest_hostname = actual_destination.hostname();
+        let host_hostname = host.hostname();
+        if dest_hostname != host_hostname {
             globals.tls_name_override.write().unwrap().insert(
-                actual_destination.host(),
-                webpki::DNSNameRef::try_from_ascii_str(&host.host())
+                dest_hostname,
+                webpki::DNSNameRef::try_from_ascii_str(&host_hostname)
                     .unwrap()
                     .to_owned(),
             );
         }
-        result
+        result_string
     };
 
     let mut http_request = request
@@ -317,6 +333,8 @@ async fn find_actual_destination(
         }
     };
 
+    // Can't use get_ip_with_port here because we don't want to add a port
+    // to an IP address if it wasn't specified
     let hostname = if let Ok(addr) = hostname.parse::<SocketAddr>() {
         FedDest::Literal(addr)
     } else if let Ok(addr) = hostname.parse::<IpAddr>() {
@@ -1743,45 +1761,45 @@ pub async fn fetch_required_signing_keys(
 
 #[cfg(test)]
 mod tests {
-    use super::{add_port_to_hostname, get_ip_with_port};
+    use super::{FedDest, add_port_to_hostname, get_ip_with_port};
 
     #[test]
     fn ips_get_default_ports() {
         assert_eq!(
-            get_ip_with_port(String::from("1.1.1.1")),
-            Some(String::from("1.1.1.1:8448"))
+            get_ip_with_port("1.1.1.1"),
+            Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap()))
         );
         assert_eq!(
-            get_ip_with_port(String::from("dead:beef::")),
-            Some(String::from("[dead:beef::]:8448"))
+            get_ip_with_port("dead:beef::"),
+            Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap()))
         );
     }
 
     #[test]
     fn ips_keep_custom_ports() {
         assert_eq!(
-            get_ip_with_port(String::from("1.1.1.1:1234")),
-            Some(String::from("1.1.1.1:1234"))
+            get_ip_with_port("1.1.1.1:1234"),
+            Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap()))
         );
         assert_eq!(
-            get_ip_with_port(String::from("[dead::beef]:8933")),
-            Some(String::from("[dead::beef]:8933"))
+            get_ip_with_port("[dead::beef]:8933"),
+            Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap()))
         );
     }
 
     #[test]
     fn hostnames_get_default_ports() {
         assert_eq!(
-            add_port_to_hostname(String::from("example.com")),
-            "example.com:8448"
+            add_port_to_hostname("example.com"),
+            FedDest::Named(String::from("example.com"), String::from(":8448"))
         )
     }
 
     #[test]
     fn hostnames_keep_custom_ports() {
         assert_eq!(
-            add_port_to_hostname(String::from("example.com:1337")),
-            "example.com:1337"
+            add_port_to_hostname("example.com:1337"),
+            FedDest::Named(String::from("example.com"), String::from(":1337"))
         )
     }
 }