From 92b715e0f2c655169e82fb76f7b109b8de21a095 Mon Sep 17 00:00:00 2001
From: zeripath <art27@cantab.net>
Date: Sat, 29 Jan 2022 11:37:08 +0000
Subject: [PATCH] Attempt to prevent the deadlock in the QueueDiskChannel Test
 again (#18415)

* Attempt to prevent the deadlock in the QueueDiskChannel Test again

This time we're going to adjust the pause tests to only test the right
flag.

* Only switch off pushback once we know that we are not pushing anything else
* Ensure full redirection occurs
* More nicely handle a closed datachan
* And handle similar problems in queue_channel_test

Signed-off-by: Andrew Thornton <art27@cantab.net>
---
 modules/queue/queue_bytefifo.go            |   5 +-
 modules/queue/queue_channel.go             |   5 +-
 modules/queue/queue_channel_test.go        |  69 +++++++++-----
 modules/queue/queue_disk_channel.go        |  15 ++-
 modules/queue/queue_disk_channel_test.go   | 102 ++++++++++-----------
 modules/queue/unique_queue_channel.go      |   5 +-
 modules/queue/unique_queue_disk_channel.go |  13 ++-
 7 files changed, 119 insertions(+), 95 deletions(-)

diff --git a/modules/queue/queue_bytefifo.go b/modules/queue/queue_bytefifo.go
index 7f2acf3deb..bf153d70bb 100644
--- a/modules/queue/queue_bytefifo.go
+++ b/modules/queue/queue_bytefifo.go
@@ -205,7 +205,10 @@ loop:
 				// tell the pool to shutdown.
 				q.baseCtxCancel()
 				return
-			case data := <-q.dataChan:
+			case data, ok := <-q.dataChan:
+				if !ok {
+					return
+				}
 				if err := q.PushBack(data); err != nil {
 					log.Error("Unable to push back data into queue %s", q.name)
 				}
diff --git a/modules/queue/queue_channel.go b/modules/queue/queue_channel.go
index 105388f421..5469c03100 100644
--- a/modules/queue/queue_channel.go
+++ b/modules/queue/queue_channel.go
@@ -117,7 +117,10 @@ func (q *ChannelQueue) FlushWithContext(ctx context.Context) error {
 		select {
 		case <-paused:
 			return nil
-		case data := <-q.dataChan:
+		case data, ok := <-q.dataChan:
+			if !ok {
+				return nil
+			}
 			if unhandled := q.handle(data); unhandled != nil {
 				log.Error("Unhandled Data whilst flushing queue %d", q.qid)
 			}
diff --git a/modules/queue/queue_channel_test.go b/modules/queue/queue_channel_test.go
index b700b28a14..f779ed5f8a 100644
--- a/modules/queue/queue_channel_test.go
+++ b/modules/queue/queue_channel_test.go
@@ -9,6 +9,7 @@ import (
 	"testing"
 	"time"
 
+	"code.gitea.io/gitea/modules/log"
 	"github.com/stretchr/testify/assert"
 )
 
@@ -111,7 +112,6 @@ func TestChannelQueue_Pause(t *testing.T) {
 			if pausable, ok := queue.(Pausable); ok {
 				pausable.Pause()
 			}
-			pushBack = false
 			lock.Unlock()
 			return data
 		}
@@ -123,7 +123,9 @@ func TestChannelQueue_Pause(t *testing.T) {
 		}
 		return nil
 	}
-	nilFn := func(_ func()) {}
+
+	queueShutdown := []func(){}
+	queueTerminate := []func(){}
 
 	queue, err = NewChannelQueue(handle,
 		ChannelQueueConfiguration{
@@ -139,7 +141,34 @@ func TestChannelQueue_Pause(t *testing.T) {
 		}, &testData{})
 	assert.NoError(t, err)
 
-	go queue.Run(nilFn, nilFn)
+	go queue.Run(func(shutdown func()) {
+		lock.Lock()
+		defer lock.Unlock()
+		queueShutdown = append(queueShutdown, shutdown)
+	}, func(terminate func()) {
+		lock.Lock()
+		defer lock.Unlock()
+		queueTerminate = append(queueTerminate, terminate)
+	})
+
+	// Shutdown and Terminate in defer
+	defer func() {
+		lock.Lock()
+		callbacks := make([]func(), len(queueShutdown))
+		copy(callbacks, queueShutdown)
+		lock.Unlock()
+		for _, callback := range callbacks {
+			callback()
+		}
+		lock.Lock()
+		log.Info("Finally terminating")
+		callbacks = make([]func(), len(queueTerminate))
+		copy(callbacks, queueTerminate)
+		lock.Unlock()
+		for _, callback := range callbacks {
+			callback()
+		}
+	}()
 
 	test1 := testData{"A", 1}
 	test2 := testData{"B", 2}
@@ -155,14 +184,11 @@ func TestChannelQueue_Pause(t *testing.T) {
 
 	pausable.Pause()
 
-	paused, resumed := pausable.IsPausedIsResumed()
+	paused, _ := pausable.IsPausedIsResumed()
 
 	select {
 	case <-paused:
-	case <-resumed:
-		assert.Fail(t, "Queue should not be resumed")
-		return
-	default:
+	case <-time.After(100 * time.Millisecond):
 		assert.Fail(t, "Queue is not paused")
 		return
 	}
@@ -179,10 +205,11 @@ func TestChannelQueue_Pause(t *testing.T) {
 	assert.Nil(t, result2)
 
 	pausable.Resume()
+	_, resumed := pausable.IsPausedIsResumed()
 
 	select {
 	case <-resumed:
-	default:
+	case <-time.After(100 * time.Millisecond):
 		assert.Fail(t, "Queue should be resumed")
 	}
 
@@ -199,47 +226,47 @@ func TestChannelQueue_Pause(t *testing.T) {
 	pushBack = true
 	lock.Unlock()
 
-	paused, resumed = pausable.IsPausedIsResumed()
+	_, resumed = pausable.IsPausedIsResumed()
 
 	select {
-	case <-paused:
-		assert.Fail(t, "Queue should not be paused")
-		return
 	case <-resumed:
-	default:
+	case <-time.After(100 * time.Millisecond):
 		assert.Fail(t, "Queue is not resumed")
 		return
 	}
 
 	queue.Push(&test1)
+	paused, _ = pausable.IsPausedIsResumed()
 
 	select {
 	case <-paused:
 	case <-handleChan:
 		assert.Fail(t, "handler chan should not contain test1")
 		return
-	case <-time.After(500 * time.Millisecond):
+	case <-time.After(100 * time.Millisecond):
 		assert.Fail(t, "queue should be paused")
 		return
 	}
 
-	paused, resumed = pausable.IsPausedIsResumed()
+	lock.Lock()
+	pushBack = false
+	lock.Unlock()
+
+	paused, _ = pausable.IsPausedIsResumed()
 
 	select {
 	case <-paused:
-	case <-resumed:
-		assert.Fail(t, "Queue should not be resumed")
-		return
-	default:
+	case <-time.After(100 * time.Millisecond):
 		assert.Fail(t, "Queue is not paused")
 		return
 	}
 
 	pausable.Resume()
+	_, resumed = pausable.IsPausedIsResumed()
 
 	select {
 	case <-resumed:
-	default:
+	case <-time.After(100 * time.Millisecond):
 		assert.Fail(t, "Queue should be resumed")
 	}
 
diff --git a/modules/queue/queue_disk_channel.go b/modules/queue/queue_disk_channel.go
index 3b21575a0e..0494698e0e 100644
--- a/modules/queue/queue_disk_channel.go
+++ b/modules/queue/queue_disk_channel.go
@@ -313,14 +313,13 @@ func (q *PersistableChannelQueue) Shutdown() {
 	q.channelQueue.Wait()
 	q.internal.(*LevelQueue).Wait()
 	// Redirect all remaining data in the chan to the internal channel
-	go func() {
-		log.Trace("PersistableChannelQueue: %s Redirecting remaining data", q.delayedStarter.name)
-		for data := range q.channelQueue.dataChan {
-			_ = q.internal.Push(data)
-			atomic.AddInt64(&q.channelQueue.numInQueue, -1)
-		}
-		log.Trace("PersistableChannelQueue: %s Done Redirecting remaining data", q.delayedStarter.name)
-	}()
+	log.Trace("PersistableChannelQueue: %s Redirecting remaining data", q.delayedStarter.name)
+	close(q.channelQueue.dataChan)
+	for data := range q.channelQueue.dataChan {
+		_ = q.internal.Push(data)
+		atomic.AddInt64(&q.channelQueue.numInQueue, -1)
+	}
+	log.Trace("PersistableChannelQueue: %s Done Redirecting remaining data", q.delayedStarter.name)
 
 	log.Debug("PersistableChannelQueue: %s Shutdown", q.delayedStarter.name)
 }
diff --git a/modules/queue/queue_disk_channel_test.go b/modules/queue/queue_disk_channel_test.go
index 026982fd92..f092bb1f56 100644
--- a/modules/queue/queue_disk_channel_test.go
+++ b/modules/queue/queue_disk_channel_test.go
@@ -207,7 +207,6 @@ func TestPersistableChannelQueue_Pause(t *testing.T) {
 				log.Info("pausing")
 				pausable.Pause()
 			}
-			pushBack = false
 			lock.Unlock()
 			return data
 		}
@@ -248,6 +247,25 @@ func TestPersistableChannelQueue_Pause(t *testing.T) {
 		queueTerminate = append(queueTerminate, terminate)
 	})
 
+	// Shutdown and Terminate in defer
+	defer func() {
+		lock.Lock()
+		callbacks := make([]func(), len(queueShutdown))
+		copy(callbacks, queueShutdown)
+		lock.Unlock()
+		for _, callback := range callbacks {
+			callback()
+		}
+		lock.Lock()
+		log.Info("Finally terminating")
+		callbacks = make([]func(), len(queueTerminate))
+		copy(callbacks, queueTerminate)
+		lock.Unlock()
+		for _, callback := range callbacks {
+			callback()
+		}
+	}()
+
 	test1 := testData{"A", 1}
 	test2 := testData{"B", 2}
 
@@ -263,14 +281,11 @@ func TestPersistableChannelQueue_Pause(t *testing.T) {
 	assert.Equal(t, test1.TestInt, result1.TestInt)
 
 	pausable.Pause()
-	paused, resumed := pausable.IsPausedIsResumed()
+	paused, _ := pausable.IsPausedIsResumed()
 
 	select {
 	case <-paused:
-	case <-resumed:
-		assert.Fail(t, "Queue should not be resumed")
-		return
-	default:
+	case <-time.After(100 * time.Millisecond):
 		assert.Fail(t, "Queue is not paused")
 		return
 	}
@@ -287,14 +302,11 @@ func TestPersistableChannelQueue_Pause(t *testing.T) {
 	assert.Nil(t, result2)
 
 	pausable.Resume()
-	paused, resumed = pausable.IsPausedIsResumed()
+	_, resumed := pausable.IsPausedIsResumed()
 
 	select {
-	case <-paused:
-		assert.Fail(t, "Queue should be resumed")
-		return
 	case <-resumed:
-	default:
+	case <-time.After(100 * time.Millisecond):
 		assert.Fail(t, "Queue should be resumed")
 		return
 	}
@@ -308,24 +320,27 @@ func TestPersistableChannelQueue_Pause(t *testing.T) {
 	assert.Equal(t, test2.TestString, result2.TestString)
 	assert.Equal(t, test2.TestInt, result2.TestInt)
 
+	// Set pushBack to so that the next handle will result in a Pause
 	lock.Lock()
 	pushBack = true
 	lock.Unlock()
 
-	paused, resumed = pausable.IsPausedIsResumed()
+	// Ensure that we're still resumed
+	_, resumed = pausable.IsPausedIsResumed()
 
 	select {
-	case <-paused:
-		assert.Fail(t, "Queue should not be paused")
-		return
 	case <-resumed:
-	default:
+	case <-time.After(100 * time.Millisecond):
 		assert.Fail(t, "Queue is not resumed")
 		return
 	}
 
+	// push test1
 	queue.Push(&test1)
 
+	// Now as this is handled it should pause
+	paused, _ = pausable.IsPausedIsResumed()
+
 	select {
 	case <-paused:
 	case <-handleChan:
@@ -336,27 +351,16 @@ func TestPersistableChannelQueue_Pause(t *testing.T) {
 		return
 	}
 
-	paused, resumed = pausable.IsPausedIsResumed()
-
-	select {
-	case <-paused:
-	case <-resumed:
-		assert.Fail(t, "Queue should not be resumed")
-		return
-	default:
-		assert.Fail(t, "Queue is not paused")
-		return
-	}
+	lock.Lock()
+	pushBack = false
+	lock.Unlock()
 
 	pausable.Resume()
 
-	paused, resumed = pausable.IsPausedIsResumed()
+	_, resumed = pausable.IsPausedIsResumed()
 	select {
-	case <-paused:
-		assert.Fail(t, "Queue should not be paused")
-		return
 	case <-resumed:
-	default:
+	case <-time.After(500 * time.Millisecond):
 		assert.Fail(t, "Queue should be resumed")
 		return
 	}
@@ -373,6 +377,7 @@ func TestPersistableChannelQueue_Pause(t *testing.T) {
 	lock.Lock()
 	callbacks := make([]func(), len(queueShutdown))
 	copy(callbacks, queueShutdown)
+	queueShutdown = queueShutdown[:0]
 	lock.Unlock()
 	// Now shutdown the queue
 	for _, callback := range callbacks {
@@ -402,6 +407,7 @@ func TestPersistableChannelQueue_Pause(t *testing.T) {
 	lock.Lock()
 	callbacks = make([]func(), len(queueTerminate))
 	copy(callbacks, queueTerminate)
+	queueShutdown = queueTerminate[:0]
 	lock.Unlock()
 	for _, callback := range callbacks {
 		callback()
@@ -453,14 +459,11 @@ func TestPersistableChannelQueue_Pause(t *testing.T) {
 	case <-paused:
 	}
 
-	paused, resumed = pausable.IsPausedIsResumed()
+	paused, _ = pausable.IsPausedIsResumed()
 
 	select {
 	case <-paused:
-	case <-resumed:
-		assert.Fail(t, "Queue should not be resumed")
-		return
-	default:
+	case <-time.After(500 * time.Millisecond):
 		assert.Fail(t, "Queue is not paused")
 		return
 	}
@@ -472,14 +475,15 @@ func TestPersistableChannelQueue_Pause(t *testing.T) {
 	default:
 	}
 
+	lock.Lock()
+	pushBack = false
+	lock.Unlock()
+
 	pausable.Resume()
-	paused, resumed = pausable.IsPausedIsResumed()
+	_, resumed = pausable.IsPausedIsResumed()
 	select {
-	case <-paused:
-		assert.Fail(t, "Queue should not be paused")
-		return
 	case <-resumed:
-	default:
+	case <-time.After(500 * time.Millisecond):
 		assert.Fail(t, "Queue should be resumed")
 		return
 	}
@@ -506,18 +510,4 @@ func TestPersistableChannelQueue_Pause(t *testing.T) {
 
 	assert.Equal(t, test2.TestString, result4.TestString)
 	assert.Equal(t, test2.TestInt, result4.TestInt)
-	lock.Lock()
-	callbacks = make([]func(), len(queueShutdown))
-	copy(callbacks, queueShutdown)
-	lock.Unlock()
-	for _, callback := range callbacks {
-		callback()
-	}
-	lock.Lock()
-	callbacks = make([]func(), len(queueTerminate))
-	copy(callbacks, queueTerminate)
-	lock.Unlock()
-	for _, callback := range callbacks {
-		callback()
-	}
 }
diff --git a/modules/queue/unique_queue_channel.go b/modules/queue/unique_queue_channel.go
index 59210855a1..b7282e6c6c 100644
--- a/modules/queue/unique_queue_channel.go
+++ b/modules/queue/unique_queue_channel.go
@@ -178,7 +178,10 @@ func (q *ChannelUniqueQueue) FlushWithContext(ctx context.Context) error {
 		default:
 		}
 		select {
-		case data := <-q.dataChan:
+		case data, ok := <-q.dataChan:
+			if !ok {
+				return nil
+			}
 			if unhandled := q.handle(data); unhandled != nil {
 				log.Error("Unhandled Data whilst flushing queue %d", q.qid)
 			}
diff --git a/modules/queue/unique_queue_disk_channel.go b/modules/queue/unique_queue_disk_channel.go
index ac7919926f..5ee1c396fc 100644
--- a/modules/queue/unique_queue_disk_channel.go
+++ b/modules/queue/unique_queue_disk_channel.go
@@ -282,13 +282,12 @@ func (q *PersistableChannelUniqueQueue) Shutdown() {
 	q.channelQueue.Wait()
 	q.internal.(*LevelUniqueQueue).Wait()
 	// Redirect all remaining data in the chan to the internal channel
-	go func() {
-		log.Trace("PersistableChannelUniqueQueue: %s Redirecting remaining data", q.delayedStarter.name)
-		for data := range q.channelQueue.dataChan {
-			_ = q.internal.Push(data)
-		}
-		log.Trace("PersistableChannelUniqueQueue: %s Done Redirecting remaining data", q.delayedStarter.name)
-	}()
+	close(q.channelQueue.dataChan)
+	log.Trace("PersistableChannelUniqueQueue: %s Redirecting remaining data", q.delayedStarter.name)
+	for data := range q.channelQueue.dataChan {
+		_ = q.internal.Push(data)
+	}
+	log.Trace("PersistableChannelUniqueQueue: %s Done Redirecting remaining data", q.delayedStarter.name)
 
 	log.Debug("PersistableChannelUniqueQueue: %s Shutdown", q.delayedStarter.name)
 }