diff --git a/services/pull/merge.go b/services/pull/merge.go index e8bb3a1cdd..6266f9ab1f 100644 --- a/services/pull/merge.go +++ b/services/pull/merge.go @@ -28,6 +28,7 @@ import ( user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/cache" "code.gitea.io/gitea/modules/git" + "code.gitea.io/gitea/modules/graceful" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/notification" "code.gitea.io/gitea/modules/references" @@ -165,9 +166,10 @@ func Merge(ctx context.Context, pr *issues_model.PullRequest, doer *user_model.U go AddTestPullRequestTask(doer, pr.BaseRepo.ID, pr.BaseBranch, false, "", "") }() - // TODO: make it able to do this in a database session - mergeCtx := context.Background() - pr.MergedCommitID, err = rawMerge(mergeCtx, pr, doer, mergeStyle, expectedHeadCommitID, message) + // Run the merge in the hammer context to prevent cancellation + hammerCtx := graceful.GetManager().HammerContext() + + pr.MergedCommitID, err = rawMerge(hammerCtx, pr, doer, mergeStyle, expectedHeadCommitID, message) if err != nil { return err } @@ -176,18 +178,18 @@ func Merge(ctx context.Context, pr *issues_model.PullRequest, doer *user_model.U pr.Merger = doer pr.MergerID = doer.ID - if _, err := pr.SetMerged(ctx); err != nil { + if _, err := pr.SetMerged(hammerCtx); err != nil { log.Error("setMerged [%d]: %v", pr.ID, err) } - if err := pr.LoadIssueCtx(ctx); err != nil { + if err := pr.LoadIssueCtx(hammerCtx); err != nil { log.Error("loadIssue [%d]: %v", pr.ID, err) } - if err := pr.Issue.LoadRepo(ctx); err != nil { + if err := pr.Issue.LoadRepo(hammerCtx); err != nil { log.Error("loadRepo for issue [%d]: %v", pr.ID, err) } - if err := pr.Issue.Repo.GetOwner(ctx); err != nil { + if err := pr.Issue.Repo.GetOwner(hammerCtx); err != nil { log.Error("GetOwner for issue repo [%d]: %v", pr.ID, err) } @@ -197,17 +199,17 @@ func Merge(ctx context.Context, pr *issues_model.PullRequest, doer *user_model.U cache.Remove(pr.Issue.Repo.GetCommitsCountCacheKey(pr.BaseBranch, true)) // Resolve cross references - refs, err := pr.ResolveCrossReferences(ctx) + refs, err := pr.ResolveCrossReferences(hammerCtx) if err != nil { log.Error("ResolveCrossReferences: %v", err) return nil } for _, ref := range refs { - if err = ref.LoadIssueCtx(ctx); err != nil { + if err = ref.LoadIssueCtx(hammerCtx); err != nil { return err } - if err = ref.Issue.LoadRepo(ctx); err != nil { + if err = ref.Issue.LoadRepo(hammerCtx); err != nil { return err } close := ref.RefAction == references.XRefActionCloses