From 368e9e0f1bf1094aad881336c698f540c047d9ac Mon Sep 17 00:00:00 2001
From: Lunny Xiao <xiaolunwen@gmail.com>
Date: Fri, 11 Aug 2023 13:27:23 +0800
Subject: [PATCH] Add transaction when creating pull request created dirty data
 (#26259) (#26437)

Backport #26259

This PR will introduce a transaction on creating pull request so that if
some step failed, it will rollback totally. And there will be no dirty
pull request exist.

Co-authored-by: Giteabot <teabot@gitea.io>
---
 models/issues/pull.go      |   5 +-
 models/issues/pull_test.go |   4 +-
 models/issues/review.go    |  16 ++---
 routers/web/repo/issue.go  |   2 +-
 services/issue/assignee.go |  14 ++--
 services/issue/issue.go    |  26 +++----
 services/pull/patch.go     |   9 ++-
 services/pull/pull.go      | 135 ++++++++++++++++++++++---------------
 8 files changed, 122 insertions(+), 89 deletions(-)

diff --git a/models/issues/pull.go b/models/issues/pull.go
index 444ae49e05..a403ebd4dd 100644
--- a/models/issues/pull.go
+++ b/models/issues/pull.go
@@ -531,13 +531,12 @@ func (pr *PullRequest) SetMerged(ctx context.Context) (bool, error) {
 }
 
 // NewPullRequest creates new pull request with labels for repository.
-func NewPullRequest(outerCtx context.Context, repo *repo_model.Repository, issue *Issue, labelIDs []int64, uuids []string, pr *PullRequest) (err error) {
-	ctx, committer, err := db.TxContext(outerCtx)
+func NewPullRequest(ctx context.Context, repo *repo_model.Repository, issue *Issue, labelIDs []int64, uuids []string, pr *PullRequest) (err error) {
+	ctx, committer, err := db.TxContext(ctx)
 	if err != nil {
 		return err
 	}
 	defer committer.Close()
-	ctx.WithContext(outerCtx)
 
 	idx, err := db.GetNextResourceIndex(ctx, "issue_index", repo.ID)
 	if err != nil {
diff --git a/models/issues/pull_test.go b/models/issues/pull_test.go
index 7a183ac312..81f64e0104 100644
--- a/models/issues/pull_test.go
+++ b/models/issues/pull_test.go
@@ -87,14 +87,14 @@ func TestLoadRequestedReviewers(t *testing.T) {
 	user1, err := user_model.GetUserByID(db.DefaultContext, 1)
 	assert.NoError(t, err)
 
-	comment, err := issues_model.AddReviewRequest(issue, user1, &user_model.User{})
+	comment, err := issues_model.AddReviewRequest(db.DefaultContext, issue, user1, &user_model.User{})
 	assert.NoError(t, err)
 	assert.NotNil(t, comment)
 
 	assert.NoError(t, pull.LoadRequestedReviewers(db.DefaultContext))
 	assert.Len(t, pull.RequestedReviewers, 1)
 
-	comment, err = issues_model.RemoveReviewRequest(issue, user1, &user_model.User{})
+	comment, err = issues_model.RemoveReviewRequest(db.DefaultContext, issue, user1, &user_model.User{})
 	assert.NoError(t, err)
 	assert.NotNil(t, comment)
 
diff --git a/models/issues/review.go b/models/issues/review.go
index 83ca7d3d87..124f69d00d 100644
--- a/models/issues/review.go
+++ b/models/issues/review.go
@@ -557,8 +557,8 @@ func InsertReviews(reviews []*Review) error {
 }
 
 // AddReviewRequest add a review request from one reviewer
-func AddReviewRequest(issue *Issue, reviewer, doer *user_model.User) (*Comment, error) {
-	ctx, committer, err := db.TxContext(db.DefaultContext)
+func AddReviewRequest(ctx context.Context, issue *Issue, reviewer, doer *user_model.User) (*Comment, error) {
+	ctx, committer, err := db.TxContext(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -612,8 +612,8 @@ func AddReviewRequest(issue *Issue, reviewer, doer *user_model.User) (*Comment,
 }
 
 // RemoveReviewRequest remove a review request from one reviewer
-func RemoveReviewRequest(issue *Issue, reviewer, doer *user_model.User) (*Comment, error) {
-	ctx, committer, err := db.TxContext(db.DefaultContext)
+func RemoveReviewRequest(ctx context.Context, issue *Issue, reviewer, doer *user_model.User) (*Comment, error) {
+	ctx, committer, err := db.TxContext(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -673,8 +673,8 @@ func restoreLatestOfficialReview(ctx context.Context, issueID, reviewerID int64)
 }
 
 // AddTeamReviewRequest add a review request from one team
-func AddTeamReviewRequest(issue *Issue, reviewer *organization.Team, doer *user_model.User) (*Comment, error) {
-	ctx, committer, err := db.TxContext(db.DefaultContext)
+func AddTeamReviewRequest(ctx context.Context, issue *Issue, reviewer *organization.Team, doer *user_model.User) (*Comment, error) {
+	ctx, committer, err := db.TxContext(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -732,8 +732,8 @@ func AddTeamReviewRequest(issue *Issue, reviewer *organization.Team, doer *user_
 }
 
 // RemoveTeamReviewRequest remove a review request from one team
-func RemoveTeamReviewRequest(issue *Issue, reviewer *organization.Team, doer *user_model.User) (*Comment, error) {
-	ctx, committer, err := db.TxContext(db.DefaultContext)
+func RemoveTeamReviewRequest(ctx context.Context, issue *Issue, reviewer *organization.Team, doer *user_model.User) (*Comment, error) {
+	ctx, committer, err := db.TxContext(ctx)
 	if err != nil {
 		return nil, err
 	}
diff --git a/routers/web/repo/issue.go b/routers/web/repo/issue.go
index 2ad7af050b..1904ae3ba9 100644
--- a/routers/web/repo/issue.go
+++ b/routers/web/repo/issue.go
@@ -2226,7 +2226,7 @@ func UpdateIssueAssignee(ctx *context.Context) {
 				return
 			}
 
-			_, _, err = issue_service.ToggleAssignee(ctx, issue, ctx.Doer, assigneeID)
+			_, _, err = issue_service.ToggleAssigneeWithNotify(ctx, issue, ctx.Doer, assigneeID)
 			if err != nil {
 				ctx.ServerError("ToggleAssignee", err)
 				return
diff --git a/services/issue/assignee.go b/services/issue/assignee.go
index 943a761e28..2ff0a97fa3 100644
--- a/services/issue/assignee.go
+++ b/services/issue/assignee.go
@@ -33,7 +33,7 @@ func DeleteNotPassedAssignee(ctx context.Context, issue *issues_model.Issue, doe
 
 		if !found {
 			// This function also does comments and hooks, which is why we call it separately instead of directly removing the assignees here
-			if _, _, err := ToggleAssignee(ctx, issue, doer, assignee.ID); err != nil {
+			if _, _, err := ToggleAssigneeWithNotify(ctx, issue, doer, assignee.ID); err != nil {
 				return err
 			}
 		}
@@ -42,8 +42,8 @@ func DeleteNotPassedAssignee(ctx context.Context, issue *issues_model.Issue, doe
 	return nil
 }
 
-// ToggleAssignee changes a user between assigned and not assigned for this issue, and make issue comment for it.
-func ToggleAssignee(ctx context.Context, issue *issues_model.Issue, doer *user_model.User, assigneeID int64) (removed bool, comment *issues_model.Comment, err error) {
+// ToggleAssigneeWithNoNotify changes a user between assigned and not assigned for this issue, and make issue comment for it.
+func ToggleAssigneeWithNotify(ctx context.Context, issue *issues_model.Issue, doer *user_model.User, assigneeID int64) (removed bool, comment *issues_model.Comment, err error) {
 	removed, comment, err = issues_model.ToggleIssueAssignee(ctx, issue, doer, assigneeID)
 	if err != nil {
 		return
@@ -63,9 +63,9 @@ func ToggleAssignee(ctx context.Context, issue *issues_model.Issue, doer *user_m
 // ReviewRequest add or remove a review request from a user for this PR, and make comment for it.
 func ReviewRequest(ctx context.Context, issue *issues_model.Issue, doer, reviewer *user_model.User, isAdd bool) (comment *issues_model.Comment, err error) {
 	if isAdd {
-		comment, err = issues_model.AddReviewRequest(issue, reviewer, doer)
+		comment, err = issues_model.AddReviewRequest(ctx, issue, reviewer, doer)
 	} else {
-		comment, err = issues_model.RemoveReviewRequest(issue, reviewer, doer)
+		comment, err = issues_model.RemoveReviewRequest(ctx, issue, reviewer, doer)
 	}
 
 	if err != nil {
@@ -230,9 +230,9 @@ func IsValidTeamReviewRequest(ctx context.Context, reviewer *organization.Team,
 // TeamReviewRequest add or remove a review request from a team for this PR, and make comment for it.
 func TeamReviewRequest(ctx context.Context, issue *issues_model.Issue, doer *user_model.User, reviewer *organization.Team, isAdd bool) (comment *issues_model.Comment, err error) {
 	if isAdd {
-		comment, err = issues_model.AddTeamReviewRequest(issue, reviewer, doer)
+		comment, err = issues_model.AddTeamReviewRequest(ctx, issue, reviewer, doer)
 	} else {
-		comment, err = issues_model.RemoveTeamReviewRequest(issue, reviewer, doer)
+		comment, err = issues_model.RemoveTeamReviewRequest(ctx, issue, reviewer, doer)
 	}
 
 	if err != nil {
diff --git a/services/issue/issue.go b/services/issue/issue.go
index 03d34ba12d..596b27c588 100644
--- a/services/issue/issue.go
+++ b/services/issue/issue.go
@@ -27,7 +27,7 @@ func NewIssue(ctx context.Context, repo *repo_model.Repository, issue *issues_mo
 	}
 
 	for _, assigneeID := range assigneeIDs {
-		if err := AddAssigneeIfNotAssigned(ctx, issue, issue.Poster, assigneeID); err != nil {
+		if _, err := AddAssigneeIfNotAssigned(ctx, issue, issue.Poster, assigneeID, true); err != nil {
 			return err
 		}
 	}
@@ -122,7 +122,7 @@ func UpdateAssignees(ctx context.Context, issue *issues_model.Issue, oneAssignee
 	// has access to the repo.
 	for _, assignee := range allNewAssignees {
 		// Extra method to prevent double adding (which would result in removing)
-		err = AddAssigneeIfNotAssigned(ctx, issue, doer, assignee.ID)
+		_, err = AddAssigneeIfNotAssigned(ctx, issue, doer, assignee.ID, true)
 		if err != nil {
 			return err
 		}
@@ -167,36 +167,36 @@ func DeleteIssue(ctx context.Context, doer *user_model.User, gitRepo *git.Reposi
 
 // AddAssigneeIfNotAssigned adds an assignee only if he isn't already assigned to the issue.
 // Also checks for access of assigned user
-func AddAssigneeIfNotAssigned(ctx context.Context, issue *issues_model.Issue, doer *user_model.User, assigneeID int64) (err error) {
+func AddAssigneeIfNotAssigned(ctx context.Context, issue *issues_model.Issue, doer *user_model.User, assigneeID int64, notify bool) (comment *issues_model.Comment, err error) {
 	assignee, err := user_model.GetUserByID(ctx, assigneeID)
 	if err != nil {
-		return err
+		return nil, err
 	}
 
 	// Check if the user is already assigned
 	isAssigned, err := issues_model.IsUserAssignedToIssue(ctx, issue, assignee)
 	if err != nil {
-		return err
+		return nil, err
 	}
 	if isAssigned {
 		// nothing to to
-		return nil
+		return nil, nil
 	}
 
 	valid, err := access_model.CanBeAssigned(ctx, assignee, issue.Repo, issue.IsPull)
 	if err != nil {
-		return err
+		return nil, err
 	}
 	if !valid {
-		return repo_model.ErrUserDoesNotHaveAccessToRepo{UserID: assigneeID, RepoName: issue.Repo.Name}
+		return nil, repo_model.ErrUserDoesNotHaveAccessToRepo{UserID: assigneeID, RepoName: issue.Repo.Name}
 	}
 
-	_, _, err = ToggleAssignee(ctx, issue, doer, assigneeID)
-	if err != nil {
-		return err
+	if notify {
+		_, comment, err = ToggleAssigneeWithNotify(ctx, issue, doer, assigneeID)
+		return comment, err
 	}
-
-	return nil
+	_, comment, err = issues_model.ToggleIssueAssignee(ctx, issue, doer, assigneeID)
+	return comment, err
 }
 
 // GetRefEndNamesAndURLs retrieves the ref end names (e.g. refs/heads/branch-name -> branch-name)
diff --git a/services/pull/patch.go b/services/pull/patch.go
index 9277355720..f8388892e5 100644
--- a/services/pull/patch.go
+++ b/services/pull/patch.go
@@ -62,14 +62,19 @@ func TestPatch(pr *issues_model.PullRequest) error {
 	ctx, _, finished := process.GetManager().AddContext(graceful.GetManager().HammerContext(), fmt.Sprintf("TestPatch: %s", pr))
 	defer finished()
 
-	// Clone base repo.
 	prCtx, cancel, err := createTemporaryRepoForPR(ctx, pr)
 	if err != nil {
-		log.Error("createTemporaryRepoForPR %-v: %v", pr, err)
+		if !models.IsErrBranchDoesNotExist(err) {
+			log.Error("CreateTemporaryRepoForPR %-v: %v", pr, err)
+		}
 		return err
 	}
 	defer cancel()
 
+	return testPatch(ctx, prCtx, pr)
+}
+
+func testPatch(ctx context.Context, prCtx *prContext, pr *issues_model.PullRequest) error {
 	gitRepo, err := git.OpenRepository(ctx, prCtx.tmpBasePath)
 	if err != nil {
 		return fmt.Errorf("OpenRepository: %w", err)
diff --git a/services/pull/pull.go b/services/pull/pull.go
index 8880bfc6ea..b46446a1fc 100644
--- a/services/pull/pull.go
+++ b/services/pull/pull.go
@@ -23,7 +23,6 @@ import (
 	"code.gitea.io/gitea/modules/json"
 	"code.gitea.io/gitea/modules/log"
 	"code.gitea.io/gitea/modules/notification"
-	"code.gitea.io/gitea/modules/process"
 	repo_module "code.gitea.io/gitea/modules/repository"
 	"code.gitea.io/gitea/modules/setting"
 	"code.gitea.io/gitea/modules/sync"
@@ -35,73 +34,70 @@ import (
 var pullWorkingPool = sync.NewExclusivePool()
 
 // NewPullRequest creates new pull request with labels for repository.
-func NewPullRequest(ctx context.Context, repo *repo_model.Repository, pull *issues_model.Issue, labelIDs []int64, uuids []string, pr *issues_model.PullRequest, assigneeIDs []int64) error {
-	if err := TestPatch(pr); err != nil {
+func NewPullRequest(ctx context.Context, repo *repo_model.Repository, issue *issues_model.Issue, labelIDs []int64, uuids []string, pr *issues_model.PullRequest, assigneeIDs []int64) error {
+	prCtx, cancel, err := createTemporaryRepoForPR(ctx, pr)
+	if err != nil {
+		if !models.IsErrBranchDoesNotExist(err) {
+			log.Error("CreateTemporaryRepoForPR %-v: %v", pr, err)
+		}
+		return err
+	}
+	defer cancel()
+
+	if err := testPatch(ctx, prCtx, pr); err != nil {
 		return err
 	}
 
-	divergence, err := GetDiverging(ctx, pr)
+	divergence, err := git.GetDivergingCommits(ctx, prCtx.tmpBasePath, baseBranch, trackingBranch)
 	if err != nil {
 		return err
 	}
 	pr.CommitsAhead = divergence.Ahead
 	pr.CommitsBehind = divergence.Behind
 
-	if err := issues_model.NewPullRequest(ctx, repo, pull, labelIDs, uuids, pr); err != nil {
-		return err
-	}
-
-	for _, assigneeID := range assigneeIDs {
-		if err := issue_service.AddAssigneeIfNotAssigned(ctx, pull, pull.Poster, assigneeID); err != nil {
-			return err
-		}
-	}
-
-	pr.Issue = pull
-	pull.PullRequest = pr
-
-	// Now - even if the request context has been cancelled as the PR has been created
-	// in the db and there is no way to cancel that transaction we have to proceed - therefore
-	// create new context and work from there
-	prCtx, _, finished := process.GetManager().AddContext(graceful.GetManager().HammerContext(), fmt.Sprintf("NewPullRequest: %s:%d", repo.FullName(), pr.Index))
-	defer finished()
-
-	if pr.Flow == issues_model.PullRequestFlowGithub {
-		err = PushToBaseRepo(prCtx, pr)
-	} else {
-		err = UpdateRef(prCtx, pr)
-	}
-	if err != nil {
-		return err
-	}
-
-	mentions, err := issues_model.FindAndUpdateIssueMentions(ctx, pull, pull.Poster, pull.Content)
-	if err != nil {
-		return err
-	}
-
-	notification.NotifyNewPullRequest(prCtx, pr, mentions)
-	if len(pull.Labels) > 0 {
-		notification.NotifyIssueChangeLabels(prCtx, pull.Poster, pull, pull.Labels, nil)
-	}
-	if pull.Milestone != nil {
-		notification.NotifyIssueChangeMilestone(prCtx, pull.Poster, pull, 0)
-	}
+	assigneeCommentMap := make(map[int64]*issues_model.Comment)
 
 	// add first push codes comment
-	baseGitRepo, err := git.OpenRepository(prCtx, pr.BaseRepo.RepoPath())
+	baseGitRepo, err := git.OpenRepository(ctx, pr.BaseRepo.RepoPath())
 	if err != nil {
 		return err
 	}
 	defer baseGitRepo.Close()
 
-	compareInfo, err := baseGitRepo.GetCompareInfo(pr.BaseRepo.RepoPath(),
-		git.BranchPrefix+pr.BaseBranch, pr.GetGitRefName(), false, false)
-	if err != nil {
-		return err
-	}
+	if err := db.WithTx(ctx, func(ctx context.Context) error {
+		if err := issues_model.NewPullRequest(ctx, repo, issue, labelIDs, uuids, pr); err != nil {
+			return err
+		}
+
+		for _, assigneeID := range assigneeIDs {
+			comment, err := issue_service.AddAssigneeIfNotAssigned(ctx, issue, issue.Poster, assigneeID, false)
+			if err != nil {
+				return err
+			}
+			assigneeCommentMap[assigneeID] = comment
+		}
+
+		pr.Issue = issue
+		issue.PullRequest = pr
+
+		if pr.Flow == issues_model.PullRequestFlowGithub {
+			err = PushToBaseRepo(ctx, pr)
+		} else {
+			err = UpdateRef(ctx, pr)
+		}
+		if err != nil {
+			return err
+		}
+
+		compareInfo, err := baseGitRepo.GetCompareInfo(pr.BaseRepo.RepoPath(),
+			git.BranchPrefix+pr.BaseBranch, pr.GetGitRefName(), false, false)
+		if err != nil {
+			return err
+		}
+		if len(compareInfo.Commits) == 0 {
+			return nil
+		}
 
-	if len(compareInfo.Commits) > 0 {
 		data := issues_model.PushActionContent{IsForcePush: false}
 		data.CommitIDs = make([]string, 0, len(compareInfo.Commits))
 		for i := len(compareInfo.Commits) - 1; i >= 0; i-- {
@@ -115,14 +111,47 @@ func NewPullRequest(ctx context.Context, repo *repo_model.Repository, pull *issu
 
 		ops := &issues_model.CreateCommentOptions{
 			Type:        issues_model.CommentTypePullRequestPush,
-			Doer:        pull.Poster,
+			Doer:        issue.Poster,
 			Repo:        repo,
 			Issue:       pr.Issue,
 			IsForcePush: false,
 			Content:     string(dataJSON),
 		}
 
-		_, _ = issue_service.CreateComment(ctx, ops)
+		if _, err = issues_model.CreateComment(ctx, ops); err != nil {
+			return err
+		}
+
+		return nil
+	}); err != nil {
+		// cleanup: this will only remove the reference, the real commit will be clean up when next GC
+		if err1 := baseGitRepo.RemoveReference(pr.GetGitRefName()); err1 != nil {
+			log.Error("RemoveReference: %v", err1)
+		}
+		return err
+	}
+	baseGitRepo.Close() // close immediately to avoid notifications will open the repository again
+
+	mentions, err := issues_model.FindAndUpdateIssueMentions(ctx, issue, issue.Poster, issue.Content)
+	if err != nil {
+		return err
+	}
+
+	notification.NotifyNewPullRequest(ctx, pr, mentions)
+	if len(issue.Labels) > 0 {
+		notification.NotifyIssueChangeLabels(ctx, issue.Poster, issue, issue.Labels, nil)
+	}
+	if issue.Milestone != nil {
+		notification.NotifyIssueChangeMilestone(ctx, issue.Poster, issue, 0)
+	}
+	if len(assigneeIDs) > 0 {
+		for _, assigneeID := range assigneeIDs {
+			assignee, err := user_model.GetUserByID(ctx, assigneeID)
+			if err != nil {
+				return ErrDependenciesLeft
+			}
+			notification.NotifyIssueChangeAssignee(ctx, issue.Poster, issue, assignee, false, assigneeCommentMap[assigneeID])
+		}
 	}
 
 	return nil