From feb189554e758ed27d1e309e5ec309d663e8f338 Mon Sep 17 00:00:00 2001
From: qwerty287 <80460567+qwerty287@users.noreply.github.com>
Date: Mon, 26 Feb 2024 03:39:01 +0100
Subject: [PATCH] Add API to get PR by base/head (#29242)

Closes https://github.com/go-gitea/gitea/issues/16289

Add a new API `/repos/{owner}/{repo}/pulls/{base}/{head}` to get a PR by
its base and head branch.
---
 models/issues/pull.go              | 29 ++++++++++
 routers/api/v1/api.go              |  1 +
 routers/api/v1/repo/pull.go        | 85 ++++++++++++++++++++++++++++++
 templates/swagger/v1_json.tmpl     | 50 ++++++++++++++++++
 tests/integration/api_pull_test.go | 21 ++++++++
 5 files changed, 186 insertions(+)

diff --git a/models/issues/pull.go b/models/issues/pull.go
index 2cb1e1b971..ce2fd9233d 100644
--- a/models/issues/pull.go
+++ b/models/issues/pull.go
@@ -652,6 +652,35 @@ func GetPullRequestByIssueID(ctx context.Context, issueID int64) (*PullRequest,
 	return pr, pr.LoadAttributes(ctx)
 }
 
+// GetPullRequestsByBaseHeadInfo returns the pull request by given base and head
+func GetPullRequestByBaseHeadInfo(ctx context.Context, baseID, headID int64, base, head string) (*PullRequest, error) {
+	pr := &PullRequest{}
+	sess := db.GetEngine(ctx).
+		Join("INNER", "issue", "issue.id = pull_request.issue_id").
+		Where("base_repo_id = ? AND base_branch = ? AND head_repo_id = ? AND head_branch = ?", baseID, base, headID, head)
+	has, err := sess.Get(pr)
+	if err != nil {
+		return nil, err
+	}
+	if !has {
+		return nil, ErrPullRequestNotExist{
+			HeadRepoID: headID,
+			BaseRepoID: baseID,
+			HeadBranch: head,
+			BaseBranch: base,
+		}
+	}
+
+	if err = pr.LoadAttributes(ctx); err != nil {
+		return nil, err
+	}
+	if err = pr.LoadIssue(ctx); err != nil {
+		return nil, err
+	}
+
+	return pr, nil
+}
+
 // GetAllUnmergedAgitPullRequestByPoster get all unmerged agit flow pull request
 // By poster id.
 func GetAllUnmergedAgitPullRequestByPoster(ctx context.Context, uid int64) ([]*PullRequest, error) {
diff --git a/routers/api/v1/api.go b/routers/api/v1/api.go
index 38c0c01a0a..e75d4cc964 100644
--- a/routers/api/v1/api.go
+++ b/routers/api/v1/api.go
@@ -1289,6 +1289,7 @@ func Routes() *web.Route {
 							Delete(bind(api.PullReviewRequestOptions{}), repo.DeleteReviewRequests).
 							Post(bind(api.PullReviewRequestOptions{}), repo.CreateReviewRequests)
 					})
+					m.Get("/{base}/*", repo.GetPullRequestByBaseHead)
 				}, mustAllowPulls, reqRepoReader(unit.TypeCode), context.ReferencesGitRepo())
 				m.Group("/statuses", func() {
 					m.Combo("/{sha}").Get(repo.GetCommitStatuses).
diff --git a/routers/api/v1/repo/pull.go b/routers/api/v1/repo/pull.go
index b1b486e017..ef9b4893fb 100644
--- a/routers/api/v1/repo/pull.go
+++ b/routers/api/v1/repo/pull.go
@@ -187,6 +187,91 @@ func GetPullRequest(ctx *context.APIContext) {
 	ctx.JSON(http.StatusOK, convert.ToAPIPullRequest(ctx, pr, ctx.Doer))
 }
 
+// GetPullRequest returns a single PR based on index
+func GetPullRequestByBaseHead(ctx *context.APIContext) {
+	// swagger:operation GET /repos/{owner}/{repo}/pulls/{base}/{head} repository repoGetPullRequestByBaseHead
+	// ---
+	// summary: Get a pull request by base and head
+	// produces:
+	// - application/json
+	// parameters:
+	// - name: owner
+	//   in: path
+	//   description: owner of the repo
+	//   type: string
+	//   required: true
+	// - name: repo
+	//   in: path
+	//   description: name of the repo
+	//   type: string
+	//   required: true
+	// - name: base
+	//   in: path
+	//   description: base of the pull request to get
+	//   type: string
+	//   required: true
+	// - name: head
+	//   in: path
+	//   description: head of the pull request to get
+	//   type: string
+	//   required: true
+	// responses:
+	//   "200":
+	//     "$ref": "#/responses/PullRequest"
+	//   "404":
+	//     "$ref": "#/responses/notFound"
+
+	var headRepoID int64
+	var headBranch string
+	head := ctx.Params("*")
+	if strings.Contains(head, ":") {
+		split := strings.SplitN(head, ":", 2)
+		headBranch = split[1]
+		var owner, name string
+		if strings.Contains(split[0], "/") {
+			split = strings.Split(split[0], "/")
+			owner = split[0]
+			name = split[1]
+		} else {
+			owner = split[0]
+			name = ctx.Repo.Repository.Name
+		}
+		repo, err := repo_model.GetRepositoryByOwnerAndName(ctx, owner, name)
+		if err != nil {
+			if repo_model.IsErrRepoNotExist(err) {
+				ctx.NotFound()
+			} else {
+				ctx.Error(http.StatusInternalServerError, "GetRepositoryByOwnerName", err)
+			}
+			return
+		}
+		headRepoID = repo.ID
+	} else {
+		headRepoID = ctx.Repo.Repository.ID
+		headBranch = head
+	}
+
+	pr, err := issues_model.GetPullRequestByBaseHeadInfo(ctx, ctx.Repo.Repository.ID, headRepoID, ctx.Params(":base"), headBranch)
+	if err != nil {
+		if issues_model.IsErrPullRequestNotExist(err) {
+			ctx.NotFound()
+		} else {
+			ctx.Error(http.StatusInternalServerError, "GetPullRequestByBaseHeadInfo", err)
+		}
+		return
+	}
+
+	if err = pr.LoadBaseRepo(ctx); err != nil {
+		ctx.Error(http.StatusInternalServerError, "LoadBaseRepo", err)
+		return
+	}
+	if err = pr.LoadHeadRepo(ctx); err != nil {
+		ctx.Error(http.StatusInternalServerError, "LoadHeadRepo", err)
+		return
+	}
+	ctx.JSON(http.StatusOK, convert.ToAPIPullRequest(ctx, pr, ctx.Doer))
+}
+
 // DownloadPullDiffOrPatch render a pull's raw diff or patch
 func DownloadPullDiffOrPatch(ctx *context.APIContext) {
 	// swagger:operation GET /repos/{owner}/{repo}/pulls/{index}.{diffType} repository repoDownloadPullDiffOrPatch
diff --git a/templates/swagger/v1_json.tmpl b/templates/swagger/v1_json.tmpl
index 18ab544415..973759604b 100644
--- a/templates/swagger/v1_json.tmpl
+++ b/templates/swagger/v1_json.tmpl
@@ -10769,6 +10769,56 @@
         }
       }
     },
+    "/repos/{owner}/{repo}/pulls/{base}/{head}": {
+      "get": {
+        "produces": [
+          "application/json"
+        ],
+        "tags": [
+          "repository"
+        ],
+        "summary": "Get a pull request by base and head",
+        "operationId": "repoGetPullRequestByBaseHead",
+        "parameters": [
+          {
+            "type": "string",
+            "description": "owner of the repo",
+            "name": "owner",
+            "in": "path",
+            "required": true
+          },
+          {
+            "type": "string",
+            "description": "name of the repo",
+            "name": "repo",
+            "in": "path",
+            "required": true
+          },
+          {
+            "type": "string",
+            "description": "base of the pull request to get",
+            "name": "base",
+            "in": "path",
+            "required": true
+          },
+          {
+            "type": "string",
+            "description": "head of the pull request to get",
+            "name": "head",
+            "in": "path",
+            "required": true
+          }
+        ],
+        "responses": {
+          "200": {
+            "$ref": "#/responses/PullRequest"
+          },
+          "404": {
+            "$ref": "#/responses/notFound"
+          }
+        }
+      }
+    },
     "/repos/{owner}/{repo}/pulls/{index}": {
       "get": {
         "produces": [
diff --git a/tests/integration/api_pull_test.go b/tests/integration/api_pull_test.go
index f4640521ca..d49711c5c3 100644
--- a/tests/integration/api_pull_test.go
+++ b/tests/integration/api_pull_test.go
@@ -61,6 +61,27 @@ func TestAPIViewPulls(t *testing.T) {
 	}
 }
 
+func TestAPIViewPullsByBaseHead(t *testing.T) {
+	defer tests.PrepareTestEnv(t)()
+	repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
+	owner := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: repo.OwnerID})
+
+	ctx := NewAPITestContext(t, "user2", repo.Name, auth_model.AccessTokenScopeReadRepository)
+
+	req := NewRequestf(t, "GET", "/api/v1/repos/%s/%s/pulls/master/branch2", owner.Name, repo.Name).
+		AddTokenAuth(ctx.Token)
+	resp := ctx.Session.MakeRequest(t, req, http.StatusOK)
+
+	pull := &api.PullRequest{}
+	DecodeJSON(t, resp, pull)
+	assert.EqualValues(t, 3, pull.Index)
+	assert.EqualValues(t, 2, pull.ID)
+
+	req = NewRequestf(t, "GET", "/api/v1/repos/%s/%s/pulls/master/branch-not-exist", owner.Name, repo.Name).
+		AddTokenAuth(ctx.Token)
+	ctx.Session.MakeRequest(t, req, http.StatusNotFound)
+}
+
 // TestAPIMergePullWIP ensures that we can't merge a WIP pull request
 func TestAPIMergePullWIP(t *testing.T) {
 	defer tests.PrepareTestEnv(t)()