From c3ac2a21c95d8f687e8af8821fe4385d1810bd67 Mon Sep 17 00:00:00 2001
From: sigoden <sigoden@gmail.com>
Date: Sun, 19 Jun 2022 14:23:10 +0800
Subject: [PATCH] feat: serve single file (#54)

close #53
---
 src/args.rs                       |  3 +++
 src/server.rs                     | 32 +++++++++++++++++++++++--------
 tests/{path_prefix.rs => args.rs} | 25 +++++++++++++++++++++++-
 tests/bind.rs                     |  4 +++-
 tests/fixtures.rs                 |  2 +-
 tests/http.rs                     |  1 +
 6 files changed, 56 insertions(+), 11 deletions(-)
 rename tests/{path_prefix.rs => args.rs} (54%)

diff --git a/src/args.rs b/src/args.rs
index 867e3b4..9801e0c 100644
--- a/src/args.rs
+++ b/src/args.rs
@@ -120,6 +120,7 @@ pub struct Args {
     pub addrs: Vec<IpAddr>,
     pub port: u16,
     pub path: PathBuf,
+    pub path_is_file: bool,
     pub path_prefix: String,
     pub uri_prefix: String,
     pub auth: AccessControl,
@@ -146,6 +147,7 @@ impl Args {
             .unwrap_or_else(|| vec!["0.0.0.0", "::"]);
         let addrs: Vec<IpAddr> = Args::parse_addrs(&addrs)?;
         let path = Args::parse_path(matches.value_of_os("path").unwrap_or_default())?;
+        let path_is_file = path.metadata()?.is_file();
         let path_prefix = matches
             .value_of("path-prefix")
             .map(|v| v.trim_matches('/').to_owned())
@@ -180,6 +182,7 @@ impl Args {
             addrs,
             port,
             path,
+            path_is_file,
             path_prefix,
             uri_prefix,
             auth,
diff --git a/src/server.rs b/src/server.rs
index 81f01f7..1a25b95 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -90,19 +90,26 @@ impl Server {
         let headers = req.headers();
         let method = req.method().clone();
 
-        let authorization = headers.get(AUTHORIZATION);
-        let guard_type = self.args.auth.guard(req_path, &method, authorization);
-
         if req_path == "/favicon.ico" && method == Method::GET {
             self.handle_send_favicon(headers, &mut res).await?;
             return Ok(res);
         }
 
+        let authorization = headers.get(AUTHORIZATION);
+        let guard_type = self.args.auth.guard(req_path, &method, authorization);
         if guard_type.is_reject() {
             self.auth_reject(&mut res);
             return Ok(res);
         }
 
+        let head_only = method == Method::HEAD;
+
+        if self.args.path_is_file {
+            self.handle_send_file(&self.args.path, headers, head_only, &mut res)
+                .await?;
+            return Ok(res);
+        }
+
         let path = match self.extract_path(req_path) {
             Some(v) => v,
             None => {
@@ -133,7 +140,6 @@ impl Server {
 
         match method {
             Method::GET | Method::HEAD => {
-                let head_only = method == Method::HEAD;
                 if is_dir {
                     if render_try_index && query == "zip" {
                         self.handle_zip_dir(path, head_only, &mut res).await?;
@@ -340,10 +346,7 @@ impl Server {
         res: &mut Response,
     ) -> BoxResult<()> {
         let (mut writer, reader) = tokio::io::duplex(BUF_SIZE);
-        let filename = path
-            .file_name()
-            .and_then(|v| v.to_str())
-            .ok_or_else(|| format!("Failed to get name of `{}`", path.display()))?;
+        let filename = get_file_name(path)?;
         res.headers_mut().insert(
             CONTENT_DISPOSITION,
             HeaderValue::from_str(&format!(
@@ -482,6 +485,13 @@ impl Server {
             );
         }
 
+        let filename = get_file_name(path)?;
+        res.headers_mut().insert(
+            CONTENT_DISPOSITION,
+            HeaderValue::from_str(&format!("inline; filename=\"{}\"", encode_uri(filename),))
+                .unwrap(),
+        );
+
         res.headers_mut().typed_insert(AcceptRanges::bytes());
 
         let size = meta.len();
@@ -1022,6 +1032,12 @@ fn status_no_content(res: &mut Response) {
     *res.status_mut() = StatusCode::NO_CONTENT;
 }
 
+fn get_file_name(path: &Path) -> BoxResult<&str> {
+    path.file_name()
+        .and_then(|v| v.to_str())
+        .ok_or_else(|| format!("Failed to get file name of `{}`", path.display()).into())
+}
+
 fn set_webdav_headers(res: &mut Response) {
     res.headers_mut().insert(
         "Allow",
diff --git a/tests/path_prefix.rs b/tests/args.rs
similarity index 54%
rename from tests/path_prefix.rs
rename to tests/args.rs
index dd34acf..e67ecc6 100644
--- a/tests/path_prefix.rs
+++ b/tests/args.rs
@@ -1,8 +1,11 @@
 mod fixtures;
 mod utils;
 
-use fixtures::{server, Error, TestServer};
+use assert_cmd::prelude::*;
+use assert_fs::fixture::TempDir;
+use fixtures::{port, server, tmpdir, wait_for_port, Error, TestServer};
 use rstest::rstest;
+use std::process::{Command, Stdio};
 
 #[rstest]
 fn path_prefix_index(#[with(&["--path-prefix", "xyz"])] server: TestServer) -> Result<(), Error> {
@@ -28,3 +31,23 @@ fn path_prefix_propfind(
     assert!(text.contains("<D:href>/xyz/</D:href>"));
     Ok(())
 }
+
+#[rstest]
+#[case("index.html")]
+fn serve_single_file(tmpdir: TempDir, port: u16, #[case] file: &str) -> Result<(), Error> {
+    let mut child = Command::cargo_bin("duf")?
+        .env("RUST_LOG", "false")
+        .arg(tmpdir.path().join(file))
+        .arg("-p")
+        .arg(port.to_string())
+        .stdout(Stdio::piped())
+        .spawn()?;
+
+    wait_for_port(port);
+
+    let resp = reqwest::blocking::get(format!("http://localhost:{}/index.html", port))?;
+    assert_eq!(resp.text()?, "This is index.html");
+
+    child.kill()?;
+    Ok(())
+}
diff --git a/tests/bind.rs b/tests/bind.rs
index 919d5d8..488a1ef 100644
--- a/tests/bind.rs
+++ b/tests/bind.rs
@@ -1,6 +1,6 @@
 mod fixtures;
 
-use fixtures::{port, server, tmpdir, Error, TestServer};
+use fixtures::{port, server, tmpdir, wait_for_port, Error, TestServer};
 
 use assert_cmd::prelude::*;
 use assert_fs::fixture::TempDir;
@@ -59,6 +59,8 @@ fn validate_printed_urls(tmpdir: TempDir, port: u16, #[case] args: &[&str]) -> R
         .stdout(Stdio::piped())
         .spawn()?;
 
+    wait_for_port(port);
+
     // WARN assumes urls list is terminated by an empty line
     let url_lines = BufReader::new(child.stdout.take().unwrap())
         .lines()
diff --git a/tests/fixtures.rs b/tests/fixtures.rs
index d581be3..c60747c 100644
--- a/tests/fixtures.rs
+++ b/tests/fixtures.rs
@@ -142,7 +142,7 @@ where
 }
 
 /// Wait a max of 1s for the port to become available.
-fn wait_for_port(port: u16) {
+pub fn wait_for_port(port: u16) {
     let start_wait = Instant::now();
 
     while !port_check::is_port_reachable(format!("localhost:{}", port)) {
diff --git a/tests/http.rs b/tests/http.rs
index 833da41..9cf7677 100644
--- a/tests/http.rs
+++ b/tests/http.rs
@@ -105,6 +105,7 @@ fn head_file(server: TestServer) -> Result<(), Error> {
     assert_eq!(resp.status(), 200);
     assert_eq!(resp.headers().get("content-type").unwrap(), "text/html");
     assert_eq!(resp.headers().get("accept-ranges").unwrap(), "bytes");
+    assert!(resp.headers().contains_key("content-disposition"));
     assert!(resp.headers().contains_key("etag"));
     assert!(resp.headers().contains_key("last-modified"));
     assert!(resp.headers().contains_key("content-length"));