Skip to content
This repository was archived by the owner on Jul 11, 2022. It is now read-only.

Commit 00a3025

Browse files
zsolambv
authored and committed
Preserve line endings when formatting a file in place (pytest-dev#288)
1 parent dbe2616 commit 00a3025

File tree

4 files changed

+68
-15
lines changed

4 files changed

+68
-15
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,8 @@ More details can be found in [CONTRIBUTING](CONTRIBUTING.md).
720720
* fixed stdin handling not working correctly if an old version of Click was
721721
used (#276)
722722

723+
* *Black* now preserves line endings when formatting a file in place (#258)
724+
723725

724726
### 18.5b1
725727

black.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from concurrent.futures import Executor, ProcessPoolExecutor
55
from enum import Enum, Flag
66
from functools import partial, wraps
7+
import io
78
import keyword
89
import logging
910
from multiprocessing import Manager
@@ -465,8 +466,9 @@ def format_file_in_place(
465466
"""
466467
if src.suffix == ".pyi":
467468
mode |= FileMode.PYI
468-
with tokenize.open(src) as src_buffer:
469-
src_contents = src_buffer.read()
469+
470+
with open(src, "rb") as buf:
471+
newline, encoding, src_contents = prepare_input(buf.read())
470472
try:
471473
dst_contents = format_file_contents(
472474
src_contents, line_length=line_length, fast=fast, mode=mode
@@ -475,7 +477,7 @@ def format_file_in_place(
475477
return False
476478

477479
if write_back == write_back.YES:
478-
with open(src, "w", encoding=src_buffer.encoding) as f:
480+
with open(src, "w", encoding=encoding, newline=newline) as f:
479481
f.write(dst_contents)
480482
elif write_back == write_back.DIFF:
481483
src_name = f"{src} (original)"
@@ -484,7 +486,14 @@ def format_file_in_place(
484486
if lock:
485487
lock.acquire()
486488
try:
487-
sys.stdout.write(diff_contents)
489+
f = io.TextIOWrapper(
490+
sys.stdout.buffer,
491+
encoding=encoding,
492+
newline=newline,
493+
write_through=True,
494+
)
495+
f.write(diff_contents)
496+
f.detach()
488497
finally:
489498
if lock:
490499
lock.release()
@@ -503,7 +512,7 @@ def format_stdin_to_stdout(
503512
`line_length`, `fast`, `is_pyi`, and `force_py36` arguments are passed to
504513
:func:`format_file_contents`.
505514
"""
506-
src = sys.stdin.read()
515+
newline, encoding, src = prepare_input(sys.stdin.buffer.read())
507516
dst = src
508517
try:
509518
dst = format_file_contents(src, line_length=line_length, fast=fast, mode=mode)
@@ -514,11 +523,25 @@ def format_stdin_to_stdout(
514523

515524
finally:
516525
if write_back == WriteBack.YES:
517-
sys.stdout.write(dst)
526+
f = io.TextIOWrapper(
527+
sys.stdout.buffer,
528+
encoding=encoding,
529+
newline=newline,
530+
write_through=True,
531+
)
532+
f.write(dst)
533+
f.detach()
518534
elif write_back == WriteBack.DIFF:
519535
src_name = "<stdin> (original)"
520536
dst_name = "<stdin> (formatted)"
521-
sys.stdout.write(diff(src, dst, src_name, dst_name))
537+
f = io.TextIOWrapper(
538+
sys.stdout.buffer,
539+
encoding=encoding,
540+
newline=newline,
541+
write_through=True,
542+
)
543+
f.write(diff(src, dst, src_name, dst_name))
544+
f.detach()
522545

523546

524547
def format_file_contents(
@@ -579,6 +602,19 @@ def format_str(
579602
return dst_contents
580603

581604

605+
def prepare_input(src: bytes) -> Tuple[str, str, str]:
    """Analyze `src` and return a tuple of (newline, encoding, decoded_contents)

    Where `newline` is either CRLF or LF, and `decoded_contents` is decoded with
    universal newlines (i.e. only LF).

    `newline` is taken from the first line of `src` only.  Empty input has no
    first line, so it falls back to LF and decodes to an empty string.
    """
    srcbuf = io.BytesIO(src)
    encoding, lines = tokenize.detect_encoding(srcbuf.readline)
    # detect_encoding() returns no lines at all for empty input, so guard the
    # lookup instead of indexing lines[0] unconditionally (IndexError).
    newline = "\r\n" if lines and lines[0].endswith(b"\r\n") else "\n"
    # detect_encoding() consumed up to two lines; rewind to decode everything.
    srcbuf.seek(0)
    with io.TextIOWrapper(srcbuf, encoding) as reader:
        return newline, encoding, reader.read()
616+
617+
582618
GRAMMARS = [
583619
pygram.python_grammar_no_print_statement_no_exec_statement,
584620
pygram.python_grammar_no_print_statement,
@@ -590,8 +626,7 @@ def lib2to3_parse(src_txt: str) -> Node:
590626
"""Given a string with source, return the lib2to3 Node."""
591627
grammar = pygram.python_grammar_no_print_statement
592628
if src_txt[-1] != "\n":
593-
nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
594-
src_txt += nl
629+
src_txt += "\n"
595630
for grammar in GRAMMARS:
596631
drv = driver.Driver(grammar, pytree.convert)
597632
try:

docs/reference/reference_functions.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ Parsing
6161

6262
.. autofunction:: black.lib2to3_unparse
6363

64+
.. autofunction:: black.prepare_input
65+
6466
Split functions
6567
---------------
6668

tests/test_black.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from concurrent.futures import ThreadPoolExecutor
44
from contextlib import contextmanager
55
from functools import partial
6-
from io import StringIO
6+
from io import BytesIO, TextIOWrapper
77
import os
88
from pathlib import Path
99
import sys
@@ -121,8 +121,9 @@ def test_piping(self) -> None:
121121
source, expected = read_data("../black")
122122
hold_stdin, hold_stdout = sys.stdin, sys.stdout
123123
try:
124-
sys.stdin, sys.stdout = StringIO(source), StringIO()
125-
sys.stdin.name = "<stdin>"
124+
sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
125+
sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
126+
sys.stdin.buffer.name = "<stdin>" # type: ignore
126127
black.format_stdin_to_stdout(
127128
line_length=ll, fast=True, write_back=black.WriteBack.YES
128129
)
@@ -139,8 +140,9 @@ def test_piping_diff(self) -> None:
139140
expected, _ = read_data("expression.diff")
140141
hold_stdin, hold_stdout = sys.stdin, sys.stdout
141142
try:
142-
sys.stdin, sys.stdout = StringIO(source), StringIO()
143-
sys.stdin.name = "<stdin>"
143+
sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
144+
sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
145+
sys.stdin.buffer.name = "<stdin>" # type: ignore
144146
black.format_stdin_to_stdout(
145147
line_length=ll, fast=True, write_back=black.WriteBack.DIFF
146148
)
@@ -204,7 +206,7 @@ def test_expression_diff(self) -> None:
204206
tmp_file = Path(black.dump_to_file(source))
205207
hold_stdout = sys.stdout
206208
try:
207-
sys.stdout = StringIO()
209+
sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
208210
self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
209211
sys.stdout.seek(0)
210212
actual = sys.stdout.read()
@@ -1108,6 +1110,18 @@ def test_invalid_include_exclude(self) -> None:
11081110
result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
11091111
self.assertEqual(result.exit_code, 2)
11101112

1113+
def test_preserves_line_endings(self) -> None:
    """In-place formatting keeps the line terminator the file started with."""
    source_lines = ["def f( ):", " pass"]
    with TemporaryDirectory() as workspace:
        test_file = Path(workspace) / "test.py"
        for nl in ("\n", "\r\n"):
            # Write raw bytes so no newline translation happens on the way in.
            test_file.write_bytes(nl.join(source_lines).encode())
            ff(test_file, write_back=black.WriteBack.YES)
            updated_contents: bytes = test_file.read_bytes()
            self.assertIn(nl.encode(), updated_contents)  # type: ignore
            if nl != "\n":
                continue
            # LF input must not come back with any CRLF in it.
            self.assertNotIn(b"\r\n", updated_contents)  # type: ignore
1124+
11111125

11121126
if __name__ == "__main__":
11131127
unittest.main()

0 commit comments

Comments
 (0)