Skip to content

Commit affa1a2

Browse files
committed
Add tests for check.get_default_compare_branch
1 parent 15c46bc commit affa1a2

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

src/towncrier/test/test_check.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
import sys
77

88
from subprocess import PIPE, Popen, call
9+
from unittest.mock import patch
910

1011
from click.testing import CliRunner
1112
from twisted.trial.unittest import TestCase
1213

14+
from towncrier import check
1315
from towncrier.check import _main as towncrier_check
1416

1517

@@ -271,3 +273,47 @@ def test_release_branch(self):
271273
# Assert
272274
self.assertEqual(0, result.exit_code, (result, result.output))
273275
self.assertIn("Checks SKIPPED: news file changes detected", result.output)
276+
277+
def test_get_default_compare_branch_missingf(self):
278+
"""
279+
If there's no recognized remote origin, exit with an error.
280+
"""
281+
runner = CliRunner()
282+
283+
with runner.isolated_filesystem():
284+
create_project()
285+
286+
result = runner.invoke(towncrier_check)
287+
288+
self.assertEqual(1, result.exit_code)
289+
self.assertEqual("Could not detect default branch. Aborting.\n", result.output)
290+
291+
def test_get_default_compare_branch_main(self):
292+
"""
293+
If there's a remote branch origin/main, prefer it over everything else.
294+
"""
295+
runner = CliRunner()
296+
297+
with runner.isolated_filesystem():
298+
create_project()
299+
300+
with patch("towncrier.check._run") as m:
301+
m.return_value = b" origin/master\n origin/main\n\n"
302+
branch = check.get_default_compare_branch(".", "utf-8")
303+
304+
self.assertEqual("origin/main", branch)
305+
306+
def test_get_default_compare_branch_fallback(self):
307+
"""
308+
If there's origin/master and no main, use it.
309+
"""
310+
runner = CliRunner()
311+
312+
with runner.isolated_filesystem():
313+
create_project()
314+
315+
with patch("towncrier.check._run") as m:
316+
m.return_value = b" origin/master\n origin/foo\n\n"
317+
branch = check.get_default_compare_branch(".", "utf-8")
318+
319+
self.assertEqual("origin/master", branch)

0 commit comments

Comments
 (0)