diff --git a/diff/diff_test.go b/diff/diff_test.go index 50a138f..6e2ef2b 100644 --- a/diff/diff_test.go +++ b/diff/diff_test.go @@ -74,6 +74,7 @@ func TestParseHunksAndPrintHunks(t *testing.T) { {filename: "oneline_hunk.diff"}, {filename: "empty.diff"}, {filename: "sample_hunk_lines_start_with_minuses.diff"}, + {filename: "sample_hunk_lines_start_with_minuses_pluses.diff"}, } for _, test := range tests { diffData, err := ioutil.ReadFile(filepath.Join("testdata", test.filename)) @@ -736,6 +737,8 @@ func TestParseMultiFileDiffAndPrintMultiFileDiff(t *testing.T) { {filename: "sample_contains_only_added_deleted_files.diff", wantFileDiffs: 3}, {filename: "sample_onlyin_line_isnt_a_file_header.diff", wantFileDiffs: 4}, {filename: "sample_onlyin_complex_filenames.diff", wantFileDiffs: 3}, + {filename: "sample_multi_file_minuses_pluses.diff", wantFileDiffs: 2}, + {filename: "sample_multi_file_without_extended.diff", wantFileDiffs: 2}, } for _, test := range tests { diffData, err := ioutil.ReadFile(filepath.Join("testdata", test.filename)) diff --git a/diff/parse.go b/diff/parse.go index 8d5cfc2..3b79790 100644 --- a/diff/parse.go +++ b/diff/parse.go @@ -23,14 +23,14 @@ func ParseMultiFileDiff(diff []byte) ([]*FileDiff, error) { // NewMultiFileDiffReader returns a new MultiFileDiffReader that reads // a multi-file unified diff from r. func NewMultiFileDiffReader(r io.Reader) *MultiFileDiffReader { - return &MultiFileDiffReader{reader: bufio.NewReader(r)} + return &MultiFileDiffReader{reader: newLineReader(r)} } // MultiFileDiffReader reads a multi-file unified diff. type MultiFileDiffReader struct { line int offset int64 - reader *bufio.Reader + reader *lineReader // TODO(sqs): line and offset tracking in multi-file diffs is broken; add tests and fix @@ -85,7 +85,7 @@ func (r *MultiFileDiffReader) ReadFile() (*FileDiff, error) { // caused by the lack of any hunks, or a malformatted hunk, so we // need to perform the check here. 
hr := fr.HunksReader() - line, err := readLine(r.reader) + line, err := r.reader.readLine() if err != nil && err != io.EOF { return fd, err } @@ -141,14 +141,14 @@ func ParseFileDiff(diff []byte) (*FileDiff, error) { // NewFileDiffReader returns a new FileDiffReader that reads a file // unified diff. func NewFileDiffReader(r io.Reader) *FileDiffReader { - return &FileDiffReader{reader: bufio.NewReader(r)} + return &FileDiffReader{reader: &lineReader{reader: bufio.NewReader(r)}} } // FileDiffReader reads a unified file diff. type FileDiffReader struct { line int offset int64 - reader *bufio.Reader + reader *lineReader // fileHeaderLine is the first file header line, set by: // @@ -266,7 +266,7 @@ func (r *FileDiffReader) readOneFileHeader(prefix []byte) (filename string, time if r.fileHeaderLine == nil { var err error - line, err = readLine(r.reader) + line, err = r.reader.readLine() if err == io.EOF { return "", nil, &ParseError{r.line, r.offset, ErrNoFileHeader} } else if err != nil { @@ -318,7 +318,7 @@ func (r *FileDiffReader) ReadExtendedHeaders() ([]string, error) { var line []byte if r.fileHeaderLine == nil { var err error - line, err = readLine(r.reader) + line, err = r.reader.readLine() if err == io.EOF { return xheaders, &ParseError{r.line, r.offset, ErrExtendedHeadersEOF} } else if err != nil { @@ -447,7 +447,7 @@ func ParseHunks(diff []byte) ([]*Hunk, error) { // NewHunksReader returns a new HunksReader that reads unified diff hunks // from r. func NewHunksReader(r io.Reader) *HunksReader { - return &HunksReader{reader: bufio.NewReader(r)} + return &HunksReader{reader: &lineReader{reader: bufio.NewReader(r)}} } // A HunksReader reads hunks from a unified diff. 
@@ -455,7 +455,7 @@ type HunksReader struct { line int offset int64 hunk *Hunk - reader *bufio.Reader + reader *lineReader nextHunkHeaderLine []byte } @@ -474,7 +474,7 @@ func (r *HunksReader) ReadHunk() (*Hunk, error) { line = r.nextHunkHeaderLine r.nextHunkHeaderLine = nil } else { - line, err = readLine(r.reader) + line, err = r.reader.readLine() if err != nil { if err == io.EOF && r.hunk != nil { return r.hunk, nil @@ -518,12 +518,15 @@ func (r *HunksReader) ReadHunk() (*Hunk, error) { // If the line starts with `---` and the next one with `+++` we're // looking at a non-extended file header and need to abort. if bytes.HasPrefix(line, []byte("---")) { - ok, err := peekPrefix(r.reader, "+++") + ok, err := r.reader.nextLineStartsWith("+++") if err != nil { return r.hunk, err } if ok { - return r.hunk, &ParseError{r.line, r.offset, &ErrBadHunkLine{Line: line}} + ok2, _ := r.reader.nextNextLineStartsWith(string(hunkPrefix)) + if ok2 { + return r.hunk, &ParseError{r.line, r.offset, &ErrBadHunkLine{Line: line}} + } } } @@ -593,19 +596,6 @@ func linePrefix(c byte) bool { return false } -// peekPrefix peeks into the given reader to check whether the next -// bytes match the given prefix. 
-func peekPrefix(reader *bufio.Reader, prefix string) (bool, error) { - next, err := reader.Peek(len(prefix)) - if err != nil { - if err == io.EOF { - return false, nil - } - return false, err - } - return bytes.HasPrefix(next, []byte(prefix)), nil -} - // normalizeHeader takes a header of the form: // "@@ -linestart[,chunksize] +linestart[,chunksize] @@ section" // and returns two strings, with the first in the form: diff --git a/diff/reader_util.go b/diff/reader_util.go index 395fb7b..4530025 100644 --- a/diff/reader_util.go +++ b/diff/reader_util.go @@ -2,9 +2,92 @@ package diff import ( "bufio" + "bytes" + "errors" "io" ) +var ErrLineReaderUninitialized = errors.New("line reader not initialized") + +func newLineReader(r io.Reader) *lineReader { + return &lineReader{reader: bufio.NewReader(r)} +} + +// lineReader is a wrapper around a bufio.Reader that caches the next line to +// provide lookahead functionality for the next two lines. +type lineReader struct { + reader *bufio.Reader + + cachedNextLine []byte + cachedNextLineErr error +} + +// readLine returns the next unconsumed line and advances the internal cache of +// the lineReader. +func (l *lineReader) readLine() ([]byte, error) { + if l.cachedNextLine == nil && l.cachedNextLineErr == nil { + l.cachedNextLine, l.cachedNextLineErr = readLine(l.reader) + } + + if l.cachedNextLineErr != nil { + return nil, l.cachedNextLineErr + } + + next := l.cachedNextLine + + l.cachedNextLine, l.cachedNextLineErr = readLine(l.reader) + + return next, nil +} + +// nextLineStartsWith looks at the line that would be returned by the next call +// to readLine to check whether it has the given prefix. +// +// io.EOF and bufio.ErrBufferFull errors are ignored so that the function can +// be used when at the end of the file. 
+func (l *lineReader) nextLineStartsWith(prefix string) (bool, error) { + if l.cachedNextLine == nil && l.cachedNextLineErr == nil { + l.cachedNextLine, l.cachedNextLineErr = readLine(l.reader) + } + + return l.lineHasPrefix(l.cachedNextLine, prefix, l.cachedNextLineErr) +} + +// nextNextLineStartsWith checks the prefix of the line *after* the line that +// would be returned by the next readLine. +// +// io.EOF and bufio.ErrBufferFull errors are ignored so that the function can +// be used when at the end of the file. +// +// If no line is cached yet (e.g. readLine has not been called), the next line +// is read and cached first, so the check still targets the line after it; this +// method does not return ErrLineReaderUninitialized. +func (l *lineReader) nextNextLineStartsWith(prefix string) (bool, error) { + if l.cachedNextLine == nil && l.cachedNextLineErr == nil { + l.cachedNextLine, l.cachedNextLineErr = readLine(l.reader) + } + + next, err := l.reader.Peek(len(prefix)) + return l.lineHasPrefix(next, prefix, err) +} + +// lineHasPrefix checks whether the given line has the given prefix with +// bytes.HasPrefix. +// +// The readErr should be the error that was returned when the line was read. +// lineHasPrefix checks the error to adjust its return value to, e.g., return +// false and ignore the error when readErr is io.EOF. +func (l *lineReader) lineHasPrefix(line []byte, prefix string, readErr error) (bool, error) { + if readErr != nil { + if readErr == io.EOF || readErr == bufio.ErrBufferFull { + return false, nil + } + return false, readErr + } + + return bytes.HasPrefix(line, []byte(prefix)), nil +} + // readLine is a helper that mimics the functionality of calling bufio.Scanner.Scan() and // bufio.Scanner.Bytes(), but without the token size limitation. It will read and return // the next line in the Reader with the trailing newline stripped. 
It will return an diff --git a/diff/reader_util_test.go b/diff/reader_util_test.go index 8d5b2b7..8dd0016 100644 --- a/diff/reader_util_test.go +++ b/diff/reader_util_test.go @@ -66,3 +66,144 @@ index 0000000..3be2928`, }) } } + +func TestLineReader_ReadLine(t *testing.T) { + input := `diff --git a/test.go b/test.go +new file mode 100644 +index 0000000..3be2928 + + +` + + in := newLineReader(strings.NewReader(input)) + out := []string{} + for i := 0; i < 4; i++ { + l, err := in.readLine() + if err != nil { + t.Fatal(err) + } + out = append(out, string(l)) + } + + wantOut := strings.Split(input, "\n")[0:4] + if !reflect.DeepEqual(wantOut, out) { + t.Errorf("read lines not equal: want %v, got %v", wantOut, out) + } + + _, err := in.readLine() + if err != nil { + t.Fatal(err) + } + if in.cachedNextLineErr != io.EOF { + t.Fatalf("lineReader has wrong cachedNextLineErr: %s", in.cachedNextLineErr) + } + _, err = in.readLine() + if err != io.EOF { + t.Fatalf("readLine did not return io.EOF: %s", err) + } +} + +func TestLineReader_NextLine(t *testing.T) { + input := `aaa rest of line +bbbrest of line +ccc rest of line` + + in := newLineReader(strings.NewReader(input)) + + type assertion struct { + prefix string + want bool + } + + testsPerReadLine := []struct { + nextLine []assertion + nextNextLine []assertion + wantReadLineErr error + }{ + { + nextLine: []assertion{ + {prefix: "a", want: true}, + {prefix: "aa", want: true}, + {prefix: "aaa", want: true}, + {prefix: "bbb", want: false}, + {prefix: "ccc", want: false}, + }, + nextNextLine: []assertion{ + {prefix: "aaa", want: false}, + {prefix: "bbb", want: true}, + {prefix: "ccc", want: false}, + }, + }, + { + nextLine: []assertion{ + {prefix: "aaa", want: false}, + {prefix: "bbb", want: true}, + {prefix: "ccc", want: false}, + }, + nextNextLine: []assertion{ + {prefix: "aaa", want: false}, + {prefix: "bbb", want: false}, + {prefix: "ccc", want: true}, + }, + }, + { + nextLine: []assertion{ + {prefix: "aaa", want: false}, 
+ {prefix: "bbb", want: false}, + {prefix: "ccc", want: true}, + {prefix: "ddd", want: false}, + }, + nextNextLine: []assertion{ + {prefix: "aaa", want: false}, + {prefix: "bbb", want: false}, + {prefix: "ccc", want: false}, + {prefix: "ddd", want: false}, + }, + }, + { + nextLine: []assertion{ + {prefix: "aaa", want: false}, + {prefix: "bbb", want: false}, + {prefix: "ccc", want: false}, + {prefix: "ddd", want: false}, + }, + nextNextLine: []assertion{ + {prefix: "aaa", want: false}, + {prefix: "bbb", want: false}, + {prefix: "ccc", want: false}, + {prefix: "ddd", want: false}, + }, + wantReadLineErr: io.EOF, + }, + } + + for _, tc := range testsPerReadLine { + for _, assert := range tc.nextLine { + got, err := in.nextLineStartsWith(assert.prefix) + if err != nil { + t.Fatalf("nextLineStartsWith returned unexpected error: %s", err) + } + + if got != assert.want { + t.Fatalf("unexpected result for prefix %q. got=%t, want=%t", assert.prefix, got, assert.want) + } + } + + for _, assert := range tc.nextNextLine { + got, err := in.nextNextLineStartsWith(assert.prefix) + if err != nil { + t.Fatalf("nextLineStartsWith returned unexpected error: %s", err) + } + + if got != assert.want { + t.Fatalf("unexpected result for prefix %q. got=%t, want=%t", assert.prefix, got, assert.want) + } + } + + _, err := in.readLine() + if err != tc.wantReadLineErr { + t.Fatalf("readLine returned unexpected error. 
got=%s, want=%s", err, tc.wantReadLineErr) + } + + } +} diff --git a/diff/testdata/sample_hunk_lines_start_with_minuses_pluses.diff b/diff/testdata/sample_hunk_lines_start_with_minuses_pluses.diff new file mode 100644 index 0000000..1bf6f75 --- /dev/null +++ b/diff/testdata/sample_hunk_lines_start_with_minuses_pluses.diff @@ -0,0 +1,8 @@ +@@ -1,5 +1,5 @@ + select 1; +--- this is my query ++++ this is my query + select 2; + select 3; +--- this is the last line ++++ this is the last line diff --git a/diff/testdata/sample_multi_file_minuses_pluses.diff b/diff/testdata/sample_multi_file_minuses_pluses.diff new file mode 100644 index 0000000..f74d8a0 --- /dev/null +++ b/diff/testdata/sample_multi_file_minuses_pluses.diff @@ -0,0 +1,21 @@ +diff --git a/comment-last-line.sql b/comment-last-line.sql +index 04a1655..97bd115 100644 +--- a/comment-last-line.sql ++++ b/comment-last-line.sql +@@ -1,4 +1,4 @@ + select 1; ++++ invalid SQL comment + select 2; + select 3; +--- end of three queries +diff --git a/query.sql b/query.sql +index 9537d7b..234ef35 100644 +--- a/query.sql ++++ b/query.sql +@@ -1,5 +1,4 @@ + select 1; +--- this is my query + select 2; + select 3; +--- this is the last line ++++ invalid sql comment