Skip to content

Commit 2719b1a

Browse files
gregory-mgopherbot
authored andcommitted
compress/flate: return error on closed stream write
Previously flate.Writer allowed writes after Close, and this behavior could lead to stream corruption. Fixes #27741 Change-Id: Iee1ac69f8199232f693dba77b275f7078257b582 Reviewed-on: https://go-review.googlesource.com/c/go/+/136475 Run-TryBot: Ian Lance Taylor <[email protected]> TryBot-Result: Gopher Robot <[email protected]> Auto-Submit: Ian Lance Taylor <[email protected]> Reviewed-by: Joseph Tsai <[email protected]> Reviewed-by: Carlos Amedee <[email protected]> Reviewed-by: Ian Lance Taylor <[email protected]> Run-TryBot: Ian Lance Taylor <[email protected]>
1 parent fdf1d76 commit 2719b1a

File tree

2 files changed

+120
-4
lines changed

2 files changed

+120
-4
lines changed

src/compress/flate/deflate.go

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package flate
66

77
import (
8+
"errors"
89
"fmt"
910
"io"
1011
"math"
@@ -699,17 +700,27 @@ func (w *dictWriter) Write(b []byte) (n int, err error) {
699700
return w.w.Write(b)
700701
}
701702

703+
var errWriteAfterClose = errors.New("compress/flate: write after close")
704+
702705
// A Writer takes data written to it and writes the compressed
703706
// form of that data to an underlying writer (see NewWriter).
704707
type Writer struct {
705708
d compressor
706709
dict []byte
710+
err error
707711
}
708712

709713
// Write writes data to w, which will eventually write the
710714
// compressed form of data to its underlying writer.
711715
func (w *Writer) Write(data []byte) (n int, err error) {
712-
return w.d.write(data)
716+
if w.err != nil {
717+
return 0, w.err
718+
}
719+
n, err = w.d.write(data)
720+
if err != nil {
721+
w.err = err
722+
}
723+
return n, err
713724
}
714725

715726
// Flush flushes any pending data to the underlying writer.
@@ -724,18 +735,37 @@ func (w *Writer) Write(data []byte) (n int, err error) {
724735
func (w *Writer) Flush() error {
725736
// For more about flushing:
726737
// https://www.bolet.org/~pornin/deflate-flush.html
727-
return w.d.syncFlush()
738+
if w.err != nil {
739+
return w.err
740+
}
741+
if err := w.d.syncFlush(); err != nil {
742+
w.err = err
743+
return err
744+
}
745+
return nil
728746
}
729747

730748
// Close flushes and closes the writer.
731749
func (w *Writer) Close() error {
732-
return w.d.close()
750+
if w.err == errWriteAfterClose {
751+
return nil
752+
}
753+
if w.err != nil {
754+
return w.err
755+
}
756+
if err := w.d.close(); err != nil {
757+
w.err = err
758+
return err
759+
}
760+
w.err = errWriteAfterClose
761+
return nil
733762
}
734763

735764
// Reset discards the writer's state and makes it equivalent to
736765
// the result of NewWriter or NewWriterDict called with dst
737766
// and w's level and dictionary.
738767
func (w *Writer) Reset(dst io.Writer) {
768+
w.err = nil
739769
if dw, ok := w.d.w.writer.(*dictWriter); ok {
740770
// w was created with NewWriterDict
741771
dw.w = dst

src/compress/flate/deflate_test.go

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,40 @@ func TestDeflate(t *testing.T) {
125125
}
126126
}
127127

128+
func TestWriterClose(t *testing.T) {
129+
b := new(bytes.Buffer)
130+
zw, err := NewWriter(b, 6)
131+
if err != nil {
132+
t.Fatalf("NewWriter: %v", err)
133+
}
134+
135+
if c, err := zw.Write([]byte("Test")); err != nil || c != 4 {
136+
t.Fatalf("Write to not closed writer: %s, %d", err, c)
137+
}
138+
139+
if err := zw.Close(); err != nil {
140+
t.Fatalf("Close: %v", err)
141+
}
142+
143+
afterClose := b.Len()
144+
145+
if c, err := zw.Write([]byte("Test")); err == nil || c != 0 {
146+
t.Fatalf("Write to closed writer: %s, %d", err, c)
147+
}
148+
149+
if err := zw.Flush(); err == nil {
150+
t.Fatalf("Flush to closed writer: %s", err)
151+
}
152+
153+
if err := zw.Close(); err != nil {
154+
t.Fatalf("Close: %v", err)
155+
}
156+
157+
if afterClose != b.Len() {
158+
t.Fatalf("Writer wrote data after close. After close: %d. After writes on closed stream: %d", afterClose, b.Len())
159+
}
160+
}
161+
128162
// A sparseReader returns a stream consisting of 0s followed by 1<<16 1s.
129163
// This tests missing hash references in a very large input.
130164
type sparseReader struct {
@@ -683,7 +717,7 @@ func (w *failWriter) Write(b []byte) (int, error) {
683717
return len(b), nil
684718
}
685719

686-
func TestWriterPersistentError(t *testing.T) {
720+
func TestWriterPersistentWriteError(t *testing.T) {
687721
t.Parallel()
688722
d, err := os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt")
689723
if err != nil {
@@ -706,19 +740,71 @@ func TestWriterPersistentError(t *testing.T) {
706740

707741
_, werr := zw.Write(d)
708742
cerr := zw.Close()
743+
ferr := zw.Flush()
709744
if werr != errIO && werr != nil {
710745
t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO)
711746
}
712747
if cerr != errIO && fw.n < 0 {
713748
t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO)
714749
}
750+
if ferr != errIO && fw.n < 0 {
751+
t.Errorf("test %d, mismatching Flush error: got %v, want %v", i, ferr, errIO)
752+
}
715753
if fw.n >= 0 {
716754
// At this point, the failure threshold was sufficiently high enough
717755
// that we wrote the whole stream without any errors.
718756
return
719757
}
720758
}
721759
}
760+
func TestWriterPersistentFlushError(t *testing.T) {
761+
zw, err := NewWriter(&failWriter{0}, DefaultCompression)
762+
if err != nil {
763+
t.Fatalf("NewWriter: %v", err)
764+
}
765+
flushErr := zw.Flush()
766+
closeErr := zw.Close()
767+
_, writeErr := zw.Write([]byte("Test"))
768+
checkErrors([]error{closeErr, flushErr, writeErr}, errIO, t)
769+
}
770+
771+
func TestWriterPersistentCloseError(t *testing.T) {
772+
// If underlying writer return error on closing stream we should persistent this error across all writer calls.
773+
zw, err := NewWriter(&failWriter{0}, DefaultCompression)
774+
if err != nil {
775+
t.Fatalf("NewWriter: %v", err)
776+
}
777+
closeErr := zw.Close()
778+
flushErr := zw.Flush()
779+
_, writeErr := zw.Write([]byte("Test"))
780+
checkErrors([]error{closeErr, flushErr, writeErr}, errIO, t)
781+
782+
// After closing writer we should persistent "write after close" error across Flush and Write calls, but return nil
783+
// on next Close calls.
784+
var b bytes.Buffer
785+
zw.Reset(&b)
786+
err = zw.Close()
787+
if err != nil {
788+
t.Fatalf("First call to close returned error: %s", err)
789+
}
790+
err = zw.Close()
791+
if err != nil {
792+
t.Fatalf("Second call to close returned error: %s", err)
793+
}
794+
795+
flushErr = zw.Flush()
796+
_, writeErr = zw.Write([]byte("Test"))
797+
checkErrors([]error{flushErr, writeErr}, errWriteAfterClose, t)
798+
}
799+
800+
func checkErrors(got []error, want error, t *testing.T) {
801+
t.Helper()
802+
for _, err := range got {
803+
if err != want {
804+
t.Errorf("Errors dosn't match\nWant: %s\nGot: %s", want, got)
805+
}
806+
}
807+
}
722808

723809
func TestBestSpeedMatch(t *testing.T) {
724810
t.Parallel()

0 commit comments

Comments
 (0)