Skip to content

Commit 364fb49

Browse files
committed
zstd: faster next state update in BMI2 version of decode
Use the Go-code approach: use single getBits to obtain three bitfields.
1 parent 6ebbb85 commit 364fb49

File tree

3 files changed

+236
-159
lines changed

3 files changed

+236
-159
lines changed

zstd/_generate/gen.go

Lines changed: 102 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,48 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute
289289

290290
// Update states, max tablelog 28
291291
{
292-
Comment("Update Literal Length State")
293-
o.updateState(name+"_llState", llState, brValue, brBitsRead, "llTable")
294-
Comment("Update Match Length State")
295-
o.updateState(name+"_mlState", mlState, brValue, brBitsRead, "mlTable")
296-
Comment("Update Offset State")
297-
o.updateState(name+"_ofState", ofState, brValue, brBitsRead, "ofTable")
292+
if o.bmi2 {
293+
// Get total number of bits (it is safe, as nBits is <= 9, thus 3*9 < 255)
294+
total := GP64()
295+
LEAQ(Mem{Base: llState, Index: mlState, Scale: 1}, total)
296+
ADDQ(ofState, total)
297+
MOVBQZX(total.As8(), total) // total = llState.As8() + mlState.As8() + ofState.As8()
298+
299+
// Read `total` bits
300+
bits := o.getBitsValue(name+"_getBits", total, brValue, brBitsRead)
301+
302+
// Update states
303+
Comment("Update Offset State")
304+
{
305+
nBits := ofState // Note: SHRXQ uses lower 6 bits of shift amount and BZHIQ lower 8 bits of count
306+
lowBits := GP64()
307+
BZHIQ(nBits, bits, lowBits) // lowBits = bits & ((1 << nBits) - 1))
308+
SHRXQ(nBits, bits, bits) // bits >>= nBits
309+
o.nextState(name+"_ofState", ofState, lowBits, "ofTable")
310+
}
311+
Comment("Update Match Length State")
312+
{
313+
nBits := mlState
314+
lowBits := GP64()
315+
BZHIQ(nBits, bits, lowBits) // lowBits = bits & ((1 << nBits) - 1))
316+
SHRXQ(nBits, bits, bits) // lowBits >>= nBits
317+
o.nextState(name+"_mlState", mlState, lowBits, "mlTable")
318+
}
319+
Comment("Update Literal Length State")
320+
{
321+
nBits := llState
322+
lowBits := GP64()
323+
BZHIQ(nBits, bits, lowBits) // lowBits = bits & ((1 << nBits) - 1))
324+
o.nextState(name+"_llState", llState, lowBits, "llTable")
325+
}
326+
} else {
327+
Comment("Update Literal Length State")
328+
o.updateState(name+"_llState", llState, brValue, brBitsRead, "llTable")
329+
Comment("Update Match Length State")
330+
o.updateState(name+"_mlState", mlState, brValue, brBitsRead, "mlTable")
331+
Comment("Update Offset State")
332+
o.updateState(name+"_ofState", ofState, brValue, brBitsRead, "ofTable")
333+
}
298334
}
299335
Label(name + "_skip_update")
300336

@@ -624,6 +660,39 @@ func (o options) updateState(name string, state, brValue, brBitsRead reg.GPVirtu
624660
MOVQ(Mem{Base: tablePtr, Index: DX, Scale: 8}, state)
625661
}
626662

663+
func (o options) nextState(name string, state, lowBits reg.GPVirtual, table string) {
664+
DX := GP64()
665+
if o.bmi2 {
666+
tmp := GP64()
667+
MOVQ(U32(16|(16<<8)), tmp)
668+
BEXTRQ(tmp, state, DX)
669+
} else {
670+
MOVQ(state, DX)
671+
SHRQ(U8(16), DX)
672+
MOVWQZX(DX.As16(), DX)
673+
}
674+
675+
ADDQ(lowBits, DX)
676+
677+
// Load table pointer
678+
tablePtr := GP64()
679+
Comment("Load ctx." + table)
680+
ctx := Dereference(Param("ctx"))
681+
tableA, err := ctx.Field(table).Base().Resolve()
682+
if err != nil {
683+
panic(err)
684+
}
685+
MOVQ(tableA.Addr, tablePtr)
686+
687+
// Check if below tablelog
688+
assert(func(ok LabelRef) {
689+
CMPQ(DX, U32(512))
690+
JB(ok)
691+
})
692+
// Load new state
693+
MOVQ(Mem{Base: tablePtr, Index: DX, Scale: 8}, state)
694+
}
695+
627696
// getBits will return nbits bits from brValue.
628697
// If nbits == 0 it *may* jump to jmpZero, otherwise 0 is returned.
629698
func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual, jmpZero LabelRef) reg.GPVirtual {
@@ -649,6 +718,33 @@ func (o options) getBits(name string, nBits, brValue, brBitsRead reg.GPVirtual,
649718
return BX
650719
}
651720

721+
// getBits will return nbits bits from brValue.
722+
// If nbits == 0 then 0 is returned.
723+
func (o options) getBitsValue(name string, nBits, brValue, brBitsRead reg.GPVirtual) reg.GPVirtual {
724+
BX := GP64()
725+
CX := reg.CL
726+
if o.bmi2 {
727+
LEAQ(Mem{Base: brBitsRead, Index: nBits, Scale: 1}, CX.As64())
728+
MOVQ(brValue, BX)
729+
MOVQ(CX.As64(), brBitsRead)
730+
ROLQ(CX, BX)
731+
BZHIQ(nBits, BX, BX)
732+
} else {
733+
XORQ(BX, BX)
734+
CMPQ(nBits, U8(0))
735+
JZ(LabelRef(name + "_get_bits_value_zero"))
736+
MOVQ(brBitsRead, CX.As64())
737+
ADDQ(nBits, brBitsRead)
738+
MOVQ(brValue, BX)
739+
SHLQ(CX, BX)
740+
MOVQ(nBits, CX.As64())
741+
NEGQ(CX.As64())
742+
SHRQ(CX, BX)
743+
Label(name + "_get_bits_value_zero")
744+
}
745+
return BX
746+
}
747+
652748
func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual, offsets *[3]reg.GPVirtual) (offset reg.GPVirtual) {
653749
offset = GP64()
654750
MOVQ(moP, offset)

zstd/seqdec.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error {
188188
}
189189
}
190190
}
191+
191192
// Add final literals
192193
copy(out[t:], s.literals)
193194
if debugDecoder {
@@ -203,12 +204,11 @@ func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error {
203204

204205
// decode sequences from the stream with the provided history.
205206
func (s *sequenceDecs) decodeSync(hist []byte) error {
206-
if true {
207-
supported, err := s.decodeSyncSimple(hist)
208-
if supported {
209-
return err
210-
}
207+
supported, err := s.decodeSyncSimple(hist)
208+
if supported {
209+
return err
211210
}
211+
212212
br := s.br
213213
seqs := s.nSeqs
214214
startSize := len(s.out)
@@ -396,6 +396,7 @@ func (s *sequenceDecs) decodeSync(hist []byte) error {
396396
ofState = ofTable[ofState.newState()&maxTableMask]
397397
} else {
398398
bits := br.get32BitsFast(nBits)
399+
399400
lowBits := uint16(bits >> ((ofState.nbBits() + mlState.nbBits()) & 31))
400401
llState = llTable[(llState.newState()+lowBits)&maxTableMask]
401402

0 commit comments

Comments
 (0)