From 37800f630da6bc60839d132b3d57a30f84bd03e0 Mon Sep 17 00:00:00 2001 From: sigoden Date: Thu, 7 Dec 2023 15:04:14 +0800 Subject: [PATCH] refactor: change the format of www-authenticate (#312) --- src/auth.rs | 28 ++++++++++++++++------------ src/server.rs | 7 +++---- tests/auth.rs | 10 +++++++++- 3 files changed, 28 insertions(+), 17 deletions(-) diff --git a/src/auth.rs b/src/auth.rs index 678079c..6f87e0b 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,7 +1,9 @@ +use crate::{args::Args, server::Response, utils::unix_now}; + use anyhow::{anyhow, bail, Result}; use base64::{engine::general_purpose, Engine as _}; use headers::HeaderValue; -use hyper::Method; +use hyper::{header::WWW_AUTHENTICATE, Method}; use indexmap::IndexMap; use lazy_static::lazy_static; use md5::Context; @@ -11,8 +13,6 @@ use std::{ }; use uuid::Uuid; -use crate::{args::Args, utils::unix_now}; - const REALM: &str = "DUFS"; const DIGEST_AUTH_TIMEOUT: u32 = 604800; // 7 days @@ -258,17 +258,21 @@ impl AccessPerm { } } -pub fn www_authenticate(args: &Args) -> Result { - let value = if args.auth.use_hashed_password { - format!("Basic realm=\"{}\"", REALM) +pub fn www_authenticate(res: &mut Response, args: &Args) -> Result<()> { + if args.auth.use_hashed_password { + let basic = HeaderValue::from_str(&format!("Basic realm=\"{}\"", REALM))?; + res.headers_mut().insert(WWW_AUTHENTICATE, basic); } else { let nonce = create_nonce()?; - format!( - "Digest realm=\"{}\", nonce=\"{}\", qop=\"auth\", Basic realm=\"{}\"", - REALM, nonce, REALM - ) - }; - Ok(HeaderValue::from_str(&value)?) + let digest = HeaderValue::from_str(&format!( + "Digest realm=\"{}\", nonce=\"{}\", qop=\"auth\"", + REALM, nonce + ))?; + let basic = HeaderValue::from_str(&format!("Basic realm=\"{}\"", REALM))?; + res.headers_mut().append(WWW_AUTHENTICATE, digest); + res.headers_mut().append(WWW_AUTHENTICATE, basic); + } + Ok(()) } pub fn get_auth_user(authorization: &HeaderValue) -> Option { diff --git a/src/server.rs b/src/server.rs index 41b7e2e..8d9a5e3 100644 --- a/src/server.rs +++ b/src/server.rs @@ -21,7 +21,7 @@ use headers::{ }; use hyper::header::{ HeaderValue, AUTHORIZATION, CONTENT_DISPOSITION, CONTENT_LENGTH, CONTENT_RANGE, CONTENT_TYPE, - RANGE, WWW_AUTHENTICATE, + RANGE, }; use hyper::{Body, Method, StatusCode, Uri}; use serde::Serialize; @@ -1056,9 +1056,8 @@ impl Server { fn auth_reject(&self, res: &mut Response) -> Result<()> { set_webdav_headers(res); - res.headers_mut() - .append(WWW_AUTHENTICATE, www_authenticate(&self.args)?); - // set 401 to make the browser pop up the login box + + www_authenticate(res, &self.args)?; *res.status_mut() = StatusCode::UNAUTHORIZED; Ok(()) } diff --git a/tests/auth.rs b/tests/auth.rs index fbf6347..81e4d2a 100644 --- a/tests/auth.rs +++ b/tests/auth.rs @@ -10,7 +10,15 @@ use rstest::rstest; fn no_auth(#[with(&["--auth", "user:pass@/:rw", "-A"])] server: TestServer) -> Result<(), Error> { let resp = reqwest::blocking::get(server.url())?; assert_eq!(resp.status(), 401); - assert!(resp.headers().contains_key("www-authenticate")); + let values: Vec<&str> = resp + .headers() + .get_all("www-authenticate") + .iter() + .map(|v| v.to_str().unwrap()) + .collect(); + assert!(values[0].starts_with("Digest")); + assert!(values[1].starts_with("Basic")); + let url = format!("{}file1", server.url()); let resp = fetch!(b"PUT", &url).body(b"abc".to_vec()).send()?; assert_eq!(resp.status(), 401);