diff --git a/models/issues/pull.go b/models/issues/pull.go index 9f180f9ac950b..077c47c1877b3 100644 --- a/models/issues/pull.go +++ b/models/issues/pull.go @@ -1002,3 +1002,36 @@ func GetPullRequestByMergedCommit(ctx context.Context, repoID int64, sha string) return pr, nil } + +func GetPullRequestsByMergedCommit(ctx context.Context, repoID int64, sha string) (PullRequestList, error) { + prs := PullRequestList{} + err := db.GetEngine(ctx).Where("base_repo_id = ? AND merged_commit_id = ?", repoID, sha).Find(&prs) + if err != nil { + return nil, err + } + + return prs, nil +} + +// GetPullRequestsByHeadBranch returns all pull requests whose head branch is one +// of the given branch names and whose head or base repo matches the given repo ID. +// This finds both same-repo PRs (head_repo_id matches) and forked PRs (base_repo_id matches). +func GetPullRequestsByHeadBranch(ctx context.Context, repoID int64, branches []string) (PullRequestList, error) { + if len(branches) == 0 { + return nil, nil + } + + prs := PullRequestList{} + err := db.GetEngine(ctx). + Where(builder.Or( + builder.Eq{"head_repo_id": repoID}, + builder.Eq{"base_repo_id": repoID}, + )). + In("head_branch", branches). + Find(&prs) + if err != nil { + return nil, err + } + + return prs, nil +} diff --git a/modules/git/repo_commit.go b/modules/git/repo_commit.go index c10f73690cfdd..12acbb71b389d 100644 --- a/modules/git/repo_commit.go +++ b/modules/git/repo_commit.go @@ -6,10 +6,12 @@ package git import ( "bytes" + "context" "io" "os" "strconv" "strings" + "time" "code.gitea.io/gitea/modules/git/gitcmd" "code.gitea.io/gitea/modules/setting" @@ -523,6 +525,24 @@ func (repo *Repository) IsCommitInBranch(commitID, branch string) (r bool, err e return len(stdout) > 0, err } +// GetBranchesContaining returns all local branch names that contain the given commit. +// A timeout is applied to prevent slow lookups on large repositories. +func (repo *Repository) GetBranchesContaining(commitID string) ([]string, error) { + ctx, cancel := context.WithTimeout(repo.Ctx, 10*time.Second) + defer cancel() + + stdout, _, err := gitcmd.NewCommand("for-each-ref", "--format=%(refname:strip=2)"). + AddOptionValues("--contains", commitID, BranchPrefix). + WithDir(repo.Path). + RunStdString(ctx) + if err != nil { + return nil, err + } + + branches := strings.Fields(stdout) + return branches, nil +} + // GetCommitBranchStart returns the commit where the branch diverged func (repo *Repository) GetCommitBranchStart(env []string, branch, endCommitID string) (string, error) { cmd := gitcmd.NewCommand("log", prettyLogFormat) diff --git a/routers/api/v1/api.go b/routers/api/v1/api.go index 359d5af4c4bc4..5907d53126689 100644 --- a/routers/api/v1/api.go +++ b/routers/api/v1/api.go @@ -1393,6 +1393,7 @@ func Routes() *web.Router { g.MatchPath("GET", "//status", repo.GetCombinedCommitStatusByRef) g.MatchPath("GET", "//statuses", repo.GetCommitStatusesByRef) g.MatchPath("GET", "//pull", repo.GetCommitPullRequest) + g.MatchPath("GET", "//pulls", repo.GetCommitPullRequests) }) }, reqRepoReader(unit.TypeCode)) m.Group("/git", func() { diff --git a/routers/api/v1/repo/commits.go b/routers/api/v1/repo/commits.go index 2a7efa0ea6f1b..b2f8e05d20f12 100644 --- a/routers/api/v1/repo/commits.go +++ b/routers/api/v1/repo/commits.go @@ -11,6 +11,7 @@ import ( "time" issues_model "code.gitea.io/gitea/models/issues" + "code.gitea.io/gitea/models/unit" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/git" "code.gitea.io/gitea/modules/gitrepo" @@ -400,3 +401,122 @@ func GetCommitPullRequest(ctx *context.APIContext) { } ctx.JSON(http.StatusOK, convert.ToAPIPullRequest(ctx, pr, ctx.Doer)) } + +func GetCommitPullRequests(ctx *context.APIContext) { + // swagger:operation GET /repos/{owner}/{repo}/commits/{sha}/pulls repository repoGetCommitPullRequests + // --- + // summary: Get the pull requests associated with a commit + // produces: + // - application/json + // parameters: + // - name: owner + // in: path + // description: owner of the repo + // type: string + // required: true + // - name: repo + // in: path + // description: name of the repo + // type: string + // required: true + // - name: sha + // in: path + // description: SHA of the commit to get + // type: string + // required: true + // - name: page + // in: query + // description: page number of results to return (1-based) + // type: integer + // - name: limit + // in: query + // description: page size of results + // type: integer + // responses: + // "200": + // "$ref": "#/responses/PullRequestList" + // "404": + // "$ref": "#/responses/notFound" + + if !ctx.Repo.CanRead(unit.TypePullRequests) { + ctx.APIErrorNotFound() + return + } + + sha := ctx.PathParam("sha") + + // Strategy 1: Find PRs where this commit is the merge commit + mergedPRs, err := issues_model.GetPullRequestsByMergedCommit(ctx, ctx.Repo.Repository.ID, sha) + if err != nil { + ctx.APIErrorInternal(err) + return + } + + // Strategy 2: Find branches containing this commit, then match to PRs + gitRepo, err := gitrepo.RepositoryFromRequestContextOrOpen(ctx, ctx.Repo.Repository) + if err != nil { + ctx.APIErrorInternal(err) + return + } + + branches, err := gitRepo.GetBranchesContaining(sha) + if err != nil { + // Intentionally ignoring errors here (e.g., invalid SHA, non-existent commit). + // This matches GitHub API behavior which returns 200 OK with an empty array + // rather than an error for unknown commits. + branches = nil + } + + branchPRs, err := issues_model.GetPullRequestsByHeadBranch(ctx, ctx.Repo.Repository.ID, branches) + if err != nil { + ctx.APIErrorInternal(err) + return + } + + // Combine and deduplicate + seen := make(map[int64]bool, len(mergedPRs)) + allPRs := make(issues_model.PullRequestList, 0, len(mergedPRs)+len(branchPRs)) + + for _, pr := range mergedPRs { + if !seen[pr.ID] { + seen[pr.ID] = true + allPRs = append(allPRs, pr) + } + } + for _, pr := range branchPRs { + if !seen[pr.ID] { + seen[pr.ID] = true + allPRs = append(allPRs, pr) + } + } + + totalCount := int64(len(allPRs)) + + // Apply pagination + listOptions := utils.GetListOptions(ctx) + start := (listOptions.Page - 1) * listOptions.PageSize + end := start + listOptions.PageSize + if start >= len(allPRs) { + allPRs = issues_model.PullRequestList{} + } else { + if end > len(allPRs) { + end = len(allPRs) + } + allPRs = allPRs[start:end] + } + + if len(allPRs) == 0 { + ctx.SetTotalCountHeader(totalCount) + ctx.JSON(http.StatusOK, []*api.PullRequest{}) + return + } + + baseRepo := ctx.Repo.Repository + apiPRs, err := convert.ToAPIPullRequests(ctx, baseRepo, allPRs, ctx.Doer) + if err != nil { + ctx.APIErrorInternal(err) + return + } + ctx.SetTotalCountHeader(totalCount) + ctx.JSON(http.StatusOK, apiPRs) +} diff --git a/templates/swagger/v1_json.tmpl b/templates/swagger/v1_json.tmpl index 570747ca57e94..1b57b8fb78639 100644 --- a/templates/swagger/v1_json.tmpl +++ b/templates/swagger/v1_json.tmpl @@ -7499,6 +7499,61 @@ } } }, + "/repos/{owner}/{repo}/commits/{sha}/pulls": { + "get": { + "produces": [ + "application/json" + ], + "tags": [ + "repository" + ], + "summary": "Get the pull requests associated with a commit", + "operationId": "repoGetCommitPullRequests", + "parameters": [ + { + "type": "string", + "description": "owner of the repo", + "name": "owner", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "name of the repo", + "name": "repo", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "SHA of the commit to get", + "name": "sha", + "in": "path", + "required": true + }, + { + "type": "integer", + "description": "page number of results to return (1-based)", + "name": "page", + "in": "query" + }, + { + "type": "integer", + "description": "page size of results", + "name": "limit", + "in": "query" + } + ], + "responses": { + "200": { + "$ref": "#/responses/PullRequestList" + }, + "404": { + "$ref": "#/responses/notFound" + } + } + } + }, "/repos/{owner}/{repo}/compare/{basehead}": { "get": { "produces": [ diff --git a/tests/integration/api_repo_get_commit_pull_request_test.go b/tests/integration/api_repo_get_commit_pull_request_test.go new file mode 100644 index 0000000000000..dd8e7025f41cc --- /dev/null +++ b/tests/integration/api_repo_get_commit_pull_request_test.go @@ -0,0 +1,80 @@ +// Copyright 2026 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package integration + +import ( + "net/http" + "testing" + + auth_model "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/unittest" + user_model "code.gitea.io/gitea/models/user" + api "code.gitea.io/gitea/modules/structs" + "code.gitea.io/gitea/tests" + + "github.com/stretchr/testify/assert" +) + +func TestAPIReposGetCommitPullRequests(t *testing.T) { + defer tests.PrepareTestEnv(t)() + + user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) + session := loginUser(t, user.Name) + token := getTokenForLoggedInUser(t, session, auth_model.AccessTokenScopeReadRepository) + + // Helper: query the /pulls endpoint and decode the response + getCommitPRs := func(t *testing.T, sha string, expectedStatus int) []*api.PullRequest { + t.Helper() + req := NewRequestf(t, "GET", "/api/v1/repos/%s/repo1/commits/%s/pulls", user.Name, sha). + AddTokenAuth(token) + resp := MakeRequest(t, req, expectedStatus) + + var prs []*api.PullRequest + DecodeJSON(t, resp, &prs) + return prs + } + + t.Run("MergedCommit", func(t *testing.T) { + // PR #1 (fixture id=1) has merged_commit_id = 1a8823cd1a9549fde083f992f6b9b87a7ab74fb3 + // This tests the DB-level lookup by merged_commit_id + mergedCommitSHA := "1a8823cd1a9549fde083f992f6b9b87a7ab74fb3" + + prs := getCommitPRs(t, mergedCommitSHA, http.StatusOK) + + assert.NotEmpty(t, prs, "Should find the PR by its merge commit SHA") + assert.Equal(t, int64(2), prs[0].Index, "Should be PR index 2 (fixture PR #1)") + assert.Equal(t, "master", prs[0].Base.Name) + }) + + t.Run("CommitInPRBranch", func(t *testing.T) { + // Commit 5c050d3b is on branch2 (PR #2, fixture id=2) and pr-to-update (PR #5, fixture id=5) + // This tests the git branch containment strategy + commitOnBranch := "5c050d3b6d2db231ab1f64e324f1b6b9a0b181c2" + + prs := getCommitPRs(t, commitOnBranch, http.StatusOK) + + assert.NotEmpty(t, prs, "Should find PRs whose branches contain this commit") + + // Verify we found at least the PR with head_branch=branch2 + foundPR2 := false + for _, pr := range prs { + if pr.Index == 3 { // PR #2 has issue_id=3 so index=3 + foundPR2 = true + assert.Equal(t, "branch2", pr.Head.Name) + } + } + assert.True(t, foundPR2, "Expected to find PR with head_branch=branch2") + }) + + t.Run("InvalidCommitSHA", func(t *testing.T) { + prs := getCommitPRs(t, "invalidsha", http.StatusOK) + assert.Empty(t, prs, "Should return empty array for invalid SHA") + }) + + t.Run("NonexistentCommit", func(t *testing.T) { + // Valid SHA format but doesn't exist in repo + prs := getCommitPRs(t, "0000000000000000000000000000000000000000", http.StatusOK) + assert.Empty(t, prs, "Should return empty array for nonexistent commit") + }) +}