From fce22362d401227bd41e50a26cf53b697d6fb019 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Timo=20K=C3=B6sters?= <timo@koesters.xyz>
Date: Tue, 3 Aug 2021 19:18:41 +0200
Subject: [PATCH] improvement: better auth chain calculation

---
 src/server_server.rs | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/server_server.rs b/src/server_server.rs
index a20c9abe..aa70ce01 100644
--- a/src/server_server.rs
+++ b/src/server_server.rs
@@ -1732,7 +1732,7 @@ fn get_auth_chain(starting_events: Vec<EventId>, db: &Database) -> Result<HashSe
             cached.clone()
         } else {
             drop(cache);
-            let auth_chain = get_auth_chain_recursive(&event_id, db)?;
+            let auth_chain = get_auth_chain_recursive(&event_id, HashSet::new(), db)?;
 
             cache = db.rooms.auth_chain_cache();
 
@@ -1747,19 +1747,19 @@ fn get_auth_chain(starting_events: Vec<EventId>, db: &Database) -> Result<HashSe
     Ok(full_auth_chain)
 }
 
-fn get_auth_chain_recursive(event_id: &EventId, db: &Database) -> Result<HashSet<EventId>> {
-    let mut auth_chain = HashSet::new();
-
+fn get_auth_chain_recursive(event_id: &EventId, mut found: HashSet<EventId>, db: &Database) -> Result<HashSet<EventId>> {
     if let Some(pdu) = db.rooms.get_pdu(&event_id)? {
-        auth_chain.extend(pdu.auth_events.iter().cloned());
         for auth_event in &pdu.auth_events {
-            auth_chain.extend(get_auth_chain_recursive(&auth_event, db)?);
+            if !found.contains(auth_event) {
+                found.insert(auth_event.clone());
+                found = get_auth_chain_recursive(&auth_event, found, db)?;
+            }
         }
     } else {
         warn!("Could not find pdu mentioned in auth events.");
     }
 
-    Ok(auth_chain)
+    Ok(found)
 }
 
 #[cfg_attr(