Skip to content

Commit c0f5de1

Browse files
dsnetbroady
authored andcommitted
[release-branch.go1.7] compress/flate: make huffmanBitWriter errors persistent
For persistent error handling, the methods of huffmanBitWriter have to be consistent about how they check errors. It must either consistently check error *before* every operation OR immediately *after* every operation. Since most of the current logic uses the previous approach, we apply the same style of error checking to writeBits and all calls to Write such that they only operate if w.err is already nil going into them. The error handling approach is brittle and easily broken by future commits to the code. In the near future, we should switch the logic to use panic at the lowest levels and a recover at the edge of the public API to ensure that errors are always persistent. Fixes #16749 Change-Id: Ie1d83e4ed8842f6911a31e23311cd3cbf38abe8c Reviewed-on: https://go-review.googlesource.com/27200 Reviewed-by: Matthew Dempsky <[email protected]> Reviewed-by: Brad Fitzpatrick <[email protected]> Reviewed-on: https://go-review.googlesource.com/28634 Reviewed-by: Joe Tsai <[email protected]> Run-TryBot: Brad Fitzpatrick <[email protected]>
1 parent 1861ac1 commit c0f5de1

File tree

3 files changed

+78
-14
lines changed

3 files changed

+78
-14
lines changed

src/compress/flate/deflate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,7 @@ func (w *Writer) Close() error {
724724
// the result of NewWriter or NewWriterDict called with dst
725725
// and w's level and dictionary.
726726
func (w *Writer) Reset(dst io.Writer) {
727-
if dw, ok := w.d.w.w.(*dictWriter); ok {
727+
if dw, ok := w.d.w.writer.(*dictWriter); ok {
728728
// w was created with NewWriterDict
729729
dw.w = dst
730730
w.d.reset(dw)

src/compress/flate/deflate_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package flate
66

77
import (
88
"bytes"
9+
"errors"
910
"fmt"
1011
"internal/testenv"
1112
"io"
@@ -631,3 +632,52 @@ func TestBestSpeed(t *testing.T) {
631632
}
632633
}
633634
}
635+
636+
var errIO = errors.New("IO error")
637+
638+
// failWriter fails with errIO exactly at the nth call to Write.
639+
type failWriter struct{ n int }
640+
641+
func (w *failWriter) Write(b []byte) (int, error) {
642+
w.n--
643+
if w.n == -1 {
644+
return 0, errIO
645+
}
646+
return len(b), nil
647+
}
648+
649+
func TestWriterPersistentError(t *testing.T) {
650+
d, err := ioutil.ReadFile("../testdata/Mark.Twain-Tom.Sawyer.txt")
651+
if err != nil {
652+
t.Fatalf("ReadFile: %v", err)
653+
}
654+
d = d[:10000] // Keep this test short
655+
656+
zw, err := NewWriter(nil, DefaultCompression)
657+
if err != nil {
658+
t.Fatalf("NewWriter: %v", err)
659+
}
660+
661+
// Sweep over the threshold at which an error is returned.
662+
// The variable i makes it such that the ith call to failWriter.Write will
663+
// return errIO. Since failWriter errors are not persistent, we must ensure
664+
// that flate.Writer errors are persistent.
665+
for i := 0; i < 1000; i++ {
666+
fw := &failWriter{i}
667+
zw.Reset(fw)
668+
669+
_, werr := zw.Write(d)
670+
cerr := zw.Close()
671+
if werr != errIO && werr != nil {
672+
t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO)
673+
}
674+
if cerr != errIO && fw.n < 0 {
675+
t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO)
676+
}
677+
if fw.n >= 0 {
678+
// At this point, the failure threshold was sufficiently high enough
679+
// that we wrote the whole stream without any errors.
680+
return
681+
}
682+
}
683+
}

src/compress/flate/huffman_bit_writer.go

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,11 @@ var offsetBase = []uint32{
7777
var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}
7878

7979
type huffmanBitWriter struct {
80-
w io.Writer
80+
// writer is the underlying writer.
81+
// Do not use it directly; use the write method, which ensures
82+
// that Write errors are sticky.
83+
writer io.Writer
84+
8185
// Data waiting to be written is bytes[0:nbytes]
8286
// and then the low nbits of bits.
8387
bits uint64
@@ -96,7 +100,7 @@ type huffmanBitWriter struct {
96100

97101
func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
98102
return &huffmanBitWriter{
99-
w: w,
103+
writer: w,
100104
literalFreq: make([]int32, maxNumLit),
101105
offsetFreq: make([]int32, offsetCodeCount),
102106
codegen: make([]uint8, maxNumLit+offsetCodeCount+1),
@@ -107,7 +111,7 @@ func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter {
107111
}
108112

109113
func (w *huffmanBitWriter) reset(writer io.Writer) {
110-
w.w = writer
114+
w.writer = writer
111115
w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil
112116
w.bytes = [bufferSize]byte{}
113117
}
@@ -129,11 +133,21 @@ func (w *huffmanBitWriter) flush() {
129133
n++
130134
}
131135
w.bits = 0
132-
_, w.err = w.w.Write(w.bytes[:n])
136+
w.write(w.bytes[:n])
133137
w.nbytes = 0
134138
}
135139

140+
func (w *huffmanBitWriter) write(b []byte) {
141+
if w.err != nil {
142+
return
143+
}
144+
_, w.err = w.writer.Write(b)
145+
}
146+
136147
func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
148+
if w.err != nil {
149+
return
150+
}
137151
w.bits |= uint64(b) << w.nbits
138152
w.nbits += nb
139153
if w.nbits >= 48 {
@@ -150,7 +164,7 @@ func (w *huffmanBitWriter) writeBits(b int32, nb uint) {
150164
bytes[5] = byte(bits >> 40)
151165
n += 6
152166
if n >= bufferFlushSize {
153-
_, w.err = w.w.Write(w.bytes[:n])
167+
w.write(w.bytes[:n])
154168
n = 0
155169
}
156170
w.nbytes = n
@@ -173,13 +187,10 @@ func (w *huffmanBitWriter) writeBytes(bytes []byte) {
173187
n++
174188
}
175189
if n != 0 {
176-
_, w.err = w.w.Write(w.bytes[:n])
177-
if w.err != nil {
178-
return
179-
}
190+
w.write(w.bytes[:n])
180191
}
181192
w.nbytes = 0
182-
_, w.err = w.w.Write(bytes)
193+
w.write(bytes)
183194
}
184195

185196
// RFC 1951 3.2.7 specifies a special run-length encoding for specifying
@@ -341,7 +352,7 @@ func (w *huffmanBitWriter) writeCode(c hcode) {
341352
bytes[5] = byte(bits >> 40)
342353
n += 6
343354
if n >= bufferFlushSize {
344-
_, w.err = w.w.Write(w.bytes[:n])
355+
w.write(w.bytes[:n])
345356
n = 0
346357
}
347358
w.nbytes = n
@@ -572,6 +583,9 @@ func (w *huffmanBitWriter) indexTokens(tokens []token) (numLiterals, numOffsets
572583
// writeTokens writes a slice of tokens to the output.
573584
// codes for literal and offset encoding must be supplied.
574585
func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) {
586+
if w.err != nil {
587+
return
588+
}
575589
for _, t := range tokens {
576590
if t < matchType {
577591
w.writeCode(leCodes[t.literal()])
@@ -676,9 +690,9 @@ func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte) {
676690
if n < bufferFlushSize {
677691
continue
678692
}
679-
_, w.err = w.w.Write(w.bytes[:n])
693+
w.write(w.bytes[:n])
680694
if w.err != nil {
681-
return
695+
return // Return early in the event of write failures
682696
}
683697
n = 0
684698
}

0 commit comments

Comments
 (0)