From 52506bc01f929a5597b52c0228bc185b9312e8f4 Mon Sep 17 00:00:00 2001 From: sigoden Date: Thu, 21 Dec 2023 15:46:55 +0800 Subject: [PATCH] refactor: optimize http range parsing and handling (#323) --- src/main.rs | 4 +-- src/server.rs | 55 +++++++---------------------- src/utils.rs | 93 ++++++++++++++++++++++++++++++++++++++------------ tests/range.rs | 10 ++---- 4 files changed, 89 insertions(+), 73 deletions(-) diff --git a/src/main.rs b/src/main.rs index 6933eb6..3d0a337 100644 --- a/src/main.rs +++ b/src/main.rs @@ -95,7 +95,7 @@ fn serve(args: Args, running: Arc) -> Result>> { let (cnx, addr) = listener.accept().await.unwrap(); let Ok(stream) = tls_accepter.accept(cnx).await else { eprintln!( - "Warning during tls handshake connection from {}", + "WARNING during tls handshake connection from {}", addr ); continue; @@ -172,7 +172,7 @@ where }; match err.downcast_ref::() { Some(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {} - _ => eprintln!("Warning serving connection{}: {}", scope, err), + _ => eprintln!("WARNING serving connection{}: {}", scope, err), } } } diff --git a/src/server.rs b/src/server.rs index 5dd4add..7d7e208 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,7 +3,8 @@ use crate::auth::{www_authenticate, AccessPaths, AccessPerm}; use crate::http_utils::{body_full, IncomingStream, LengthLimitedStream}; use crate::utils::{ - decode_uri, encode_uri, get_file_mtime_and_mode, get_file_name, glob, try_get_file_name, + decode_uri, encode_uri, get_file_mtime_and_mode, get_file_name, glob, parse_range, + try_get_file_name, }; use crate::Args; @@ -716,6 +717,7 @@ impl Server { ) -> Result<()> { let (file, meta) = tokio::join!(fs::File::open(path), fs::metadata(path),); let (mut file, meta) = (file?, meta?); + let size = meta.len(); let mut use_range = true; if let Some((etag, last_modified)) = extract_cache_headers(&meta) { let cached = { @@ -747,7 +749,12 @@ impl Server { } let range = if use_range { - parse_range(headers) + headers.get(RANGE).map(|range| { + range + .to_str() + .ok() + .and_then(|range| parse_range(range, size)) + }) } else { None }; @@ -762,18 +769,12 @@ impl Server { res.headers_mut().typed_insert(AcceptRanges::bytes()); - let size = meta.len(); - if let Some(range) = range { - if range - .end - .map_or_else(|| range.start < size, |v| v >= range.start) - && file.seek(SeekFrom::Start(range.start)).await.is_ok() - { - let end = range.end.unwrap_or(size - 1).min(size - 1); - let range_size = end - range.start + 1; + if let Some((start, end)) = range { + file.seek(SeekFrom::Start(start)).await?; + let range_size = end - start + 1; *res.status_mut() = StatusCode::PARTIAL_CONTENT; - let content_range = format!("bytes {}-{}/{}", range.start, end, size); + let content_range = format!("bytes {}-{}/{}", start, end, size); res.headers_mut() .insert(CONTENT_RANGE, content_range.parse()?); res.headers_mut() @@ -1530,36 +1531,6 @@ fn extract_cache_headers(meta: &Metadata) -> Option<(ETag, LastModified)> { Some((etag, last_modified)) } -#[derive(Debug)] -struct RangeValue { - start: u64, - end: Option, -} - -fn parse_range(headers: &HeaderMap) -> Option { - let range_hdr = headers.get(RANGE)?; - let hdr = range_hdr.to_str().ok()?; - let mut sp = hdr.splitn(2, '='); - let units = sp.next()?; - if units == "bytes" { - let range = sp.next()?; - let mut sp_range = range.splitn(2, '-'); - let start: u64 = sp_range.next()?.parse().ok()?; - let end: Option = if let Some(end) = sp_range.next() { - if end.is_empty() { - None - } else { - Some(end.parse().ok()?) - } - } else { - None - }; - Some(RangeValue { start, end }) - } else { - None - } -} - fn status_forbid(res: &mut Response) { *res.status_mut() = StatusCode::FORBIDDEN; *res.body_mut() = body_full("Forbidden"); diff --git a/src/utils.rs b/src/utils.rs index a4f8def..edf8544 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -100,26 +100,75 @@ pub fn load_private_key>(filename: T) -> Result Option<(u64, u64)> { + let (unit, range) = range.split_once('=')?; + if unit != "bytes" || range.contains(',') { + return None; + } + let (start, end) = range.split_once('-')?; + if start.is_empty() { + let offset = end.parse::().ok()?; + if offset <= size { + Some((size - offset, size - 1)) + } else { + None + } + } else { + let start = start.parse::().ok()?; + if start < size { + if end.is_empty() { + Some((start, size - 1)) + } else { + let end = end.parse::().ok()?; + if end < size { + Some((start, end)) + } else { + None + } + } + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_glob_key() { + assert!(glob("", "")); + assert!(glob(".*", ".git")); + assert!(glob("abc", "abc")); + assert!(glob("a*c", "abc")); + assert!(glob("a?c", "abc")); + assert!(glob("a*c", "abbc")); + assert!(glob("*c", "abc")); + assert!(glob("a*", "abc")); + assert!(glob("?c", "bc")); + assert!(glob("a?", "ab")); + assert!(!glob("abc", "adc")); + assert!(!glob("abc", "abcd")); + assert!(!glob("a?c", "abbc")); + assert!(!glob("*.log", "log")); + assert!(glob("*.abc-cba", "xyz.abc-cba")); + assert!(glob("*.abc-cba", "123.xyz.abc-cba")); + assert!(glob("*.log", ".log")); + assert!(glob("*.log", "a.log")); + assert!(glob("*/", "abc/")); + assert!(!glob("*/", "abc")); + } + + #[test] + fn test_parse_range() { + assert_eq!(parse_range("bytes=0-499", 500), Some((0, 499))); + assert_eq!(parse_range("bytes=0-", 500), Some((0, 499))); + assert_eq!(parse_range("bytes=299-", 500), Some((299, 499))); + assert_eq!(parse_range("bytes=-500", 500), Some((0, 499))); + assert_eq!(parse_range("bytes=-300", 500), Some((200, 499))); + assert_eq!(parse_range("bytes=500-", 500), None); + assert_eq!(parse_range("bytes=-501", 500), None); + assert_eq!(parse_range("bytes=0-500", 500), None); + } } diff --git a/tests/range.rs b/tests/range.rs index 4da721b..511c244 100644 --- a/tests/range.rs +++ b/tests/range.rs @@ -23,14 +23,10 @@ fn get_file_range_beyond(server: TestServer) -> Result<(), Error> { let resp = fetch!(b"GET", format!("{}index.html", server.url())) .header("range", HeaderValue::from_static("bytes=12-20")) .send()?; - assert_eq!(resp.status(), 206); - assert_eq!( - resp.headers().get("content-range").unwrap(), - "bytes 12-17/18" - ); + assert_eq!(resp.status(), 416); + assert_eq!(resp.headers().get("content-range").unwrap(), "bytes */18"); assert_eq!(resp.headers().get("accept-ranges").unwrap(), "bytes"); - assert_eq!(resp.headers().get("content-length").unwrap(), "6"); - assert_eq!(resp.text()?, "x.html"); + assert_eq!(resp.headers().get("content-length").unwrap(), "0"); Ok(()) }