Skip to content

Commit 81a8a4a

Browse files
committed
crypto/ecies: improve concatKDF (ethereum#20836)
1 parent 9a21b24 commit 81a8a4a

File tree

3 files changed

+93
-110
lines changed

3 files changed

+93
-110
lines changed

crypto/ecies/ecies.go

Lines changed: 47 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ import (
3535
"crypto/elliptic"
3636
"crypto/hmac"
3737
"crypto/subtle"
38+
"encoding/binary"
3839
"errors"
39-
"fmt"
4040
"hash"
4141
"io"
4242
"math/big"
@@ -45,7 +45,6 @@ import (
4545
var (
4646
ErrImport = errors.New("ecies: failed to import key")
4747
ErrInvalidCurve = errors.New("ecies: invalid elliptic curve")
48-
ErrInvalidParams = errors.New("ecies: invalid ECIES parameters")
4948
ErrInvalidPublicKey = errors.New("ecies: invalid public key")
5049
ErrSharedKeyIsPointAtInfinity = errors.New("ecies: shared key is point at infinity")
5150
ErrSharedKeyTooBig = errors.New("ecies: shared key params are too big")
@@ -139,57 +138,39 @@ func (prv *PrivateKey) GenerateShared(pub *PublicKey, skLen, macLen int) (sk []b
139138
}
140139

141140
var (
142-
ErrKeyDataTooLong = errors.New("ecies: can't supply requested key data")
143141
ErrSharedTooLong = errors.New("ecies: shared secret is too long")
144142
ErrInvalidMessage = errors.New("ecies: invalid message")
145143
)
146144

147-
var (
148-
big2To32 = new(big.Int).Exp(big.NewInt(2), big.NewInt(32), nil)
149-
big2To32M1 = new(big.Int).Sub(big2To32, big.NewInt(1))
150-
)
151-
152-
func incCounter(ctr []byte) {
153-
if ctr[3]++; ctr[3] != 0 {
154-
return
155-
}
156-
if ctr[2]++; ctr[2] != 0 {
157-
return
158-
}
159-
if ctr[1]++; ctr[1] != 0 {
160-
return
161-
}
162-
if ctr[0]++; ctr[0] != 0 {
163-
return
164-
}
165-
}
166-
167145
// NIST SP 800-56 Concatenation Key Derivation Function (see section 5.8.1).
168-
func concatKDF(hash hash.Hash, z, s1 []byte, kdLen int) (k []byte, err error) {
169-
if s1 == nil {
170-
s1 = make([]byte, 0)
171-
}
172-
173-
reps := ((kdLen + 7) * 8) / (hash.BlockSize() * 8)
174-
if big.NewInt(int64(reps)).Cmp(big2To32M1) > 0 {
175-
fmt.Println(big2To32M1)
176-
return nil, ErrKeyDataTooLong
177-
}
178-
179-
counter := []byte{0, 0, 0, 1}
180-
k = make([]byte, 0)
181-
182-
for i := 0; i <= reps; i++ {
183-
hash.Write(counter)
146+
func concatKDF(hash hash.Hash, z, s1 []byte, kdLen int) []byte {
147+
counterBytes := make([]byte, 4)
148+
k := make([]byte, 0, roundup(kdLen, hash.Size()))
149+
for counter := uint32(1); len(k) < kdLen; counter++ {
150+
binary.BigEndian.PutUint32(counterBytes, counter)
151+
hash.Reset()
152+
hash.Write(counterBytes)
184153
hash.Write(z)
185154
hash.Write(s1)
186-
k = append(k, hash.Sum(nil)...)
187-
hash.Reset()
188-
incCounter(counter)
155+
k = hash.Sum(k)
189156
}
157+
return k[:kdLen]
158+
}
190159

191-
k = k[:kdLen]
192-
return
160+
// roundup rounds size up to the next multiple of blocksize.
161+
func roundup(size, blocksize int) int {
162+
return size + blocksize - (size % blocksize)
163+
}
164+
165+
// deriveKeys creates the encryption and MAC keys using concatKDF.
166+
func deriveKeys(hash hash.Hash, z, s1 []byte, keyLen int) (Ke, Km []byte) {
167+
K := concatKDF(hash, z, s1, 2*keyLen)
168+
Ke = K[:keyLen]
169+
Km = K[keyLen:]
170+
hash.Reset()
171+
hash.Write(Km)
172+
Km = hash.Sum(Km[:0])
173+
return Ke, Km
193174
}
194175

195176
// messageTag computes the MAC of a message (called the tag) as per
@@ -210,7 +191,6 @@ func generateIV(params *ECIESParams, rand io.Reader) (iv []byte, err error) {
210191
}
211192

212193
// symEncrypt carries out CTR encryption using the block cipher specified in the
213-
// parameters.
214194
func symEncrypt(rand io.Reader, params *ECIESParams, key, m []byte) (ct []byte, err error) {
215195
c, err := params.Cipher(key)
216196
if err != nil {
@@ -250,36 +230,27 @@ func symDecrypt(params *ECIESParams, key, ct []byte) (m []byte, err error) {
250230
// ciphertext. s1 is fed into key derivation, s2 is fed into the MAC. If the
251231
// shared information parameters aren't being used, they should be nil.
252232
func Encrypt(rand io.Reader, pub *PublicKey, m, s1, s2 []byte) (ct []byte, err error) {
253-
params := pub.Params
254-
if params == nil {
255-
if params = ParamsFromCurve(pub.Curve); params == nil {
256-
err = ErrUnsupportedECIESParameters
257-
return
258-
}
233+
params, err := pubkeyParams(pub)
234+
if err != nil {
235+
return nil, err
259236
}
237+
260238
R, err := GenerateKey(rand, pub.Curve, params)
261239
if err != nil {
262-
return
240+
return nil, err
263241
}
264242

265-
hash := params.Hash()
266243
z, err := R.GenerateShared(pub, params.KeyLen, params.KeyLen)
267244
if err != nil {
268-
return
269-
}
270-
K, err := concatKDF(hash, z, s1, params.KeyLen+params.KeyLen)
271-
if err != nil {
272-
return
245+
return nil, err
273246
}
274-
Ke := K[:params.KeyLen]
275-
Km := K[params.KeyLen:]
276-
hash.Write(Km)
277-
Km = hash.Sum(nil)
278-
hash.Reset()
247+
248+
hash := params.Hash()
249+
Ke, Km := deriveKeys(hash, z, s1, params.KeyLen)
279250

280251
em, err := symEncrypt(rand, params, Ke, m)
281252
if err != nil || len(em) <= params.BlockSize {
282-
return
253+
return nil, err
283254
}
284255

285256
d := messageTag(params.Hash, Km, em, s2)
@@ -289,21 +260,19 @@ func Encrypt(rand io.Reader, pub *PublicKey, m, s1, s2 []byte) (ct []byte, err e
289260
copy(ct, Rb)
290261
copy(ct[len(Rb):], em)
291262
copy(ct[len(Rb)+len(em):], d)
292-
return
263+
return ct, nil
293264
}
294265

295266
// Decrypt decrypts an ECIES ciphertext.
296267
func (prv *PrivateKey) Decrypt(c, s1, s2 []byte) (m []byte, err error) {
297268
if len(c) == 0 {
298269
return nil, ErrInvalidMessage
299270
}
300-
params := prv.PublicKey.Params
301-
if params == nil {
302-
if params = ParamsFromCurve(prv.PublicKey.Curve); params == nil {
303-
err = ErrUnsupportedECIESParameters
304-
return
305-
}
271+
params, err := pubkeyParams(&prv.PublicKey)
272+
if err != nil {
273+
return nil, err
306274
}
275+
307276
hash := params.Hash()
308277

309278
var (
@@ -317,12 +286,10 @@ func (prv *PrivateKey) Decrypt(c, s1, s2 []byte) (m []byte, err error) {
317286
case 2, 3, 4:
318287
rLen = (prv.PublicKey.Curve.Params().BitSize + 7) / 4
319288
if len(c) < (rLen + hLen + 1) {
320-
err = ErrInvalidMessage
321-
return
289+
return nil, ErrInvalidMessage
322290
}
323291
default:
324-
err = ErrInvalidPublicKey
325-
return
292+
return nil, ErrInvalidPublicKey
326293
}
327294

328295
mStart = rLen
@@ -332,36 +299,19 @@ func (prv *PrivateKey) Decrypt(c, s1, s2 []byte) (m []byte, err error) {
332299
R.Curve = prv.PublicKey.Curve
333300
R.X, R.Y = elliptic.Unmarshal(R.Curve, c[:rLen])
334301
if R.X == nil {
335-
err = ErrInvalidPublicKey
336-
return
337-
}
338-
if !R.Curve.IsOnCurve(R.X, R.Y) {
339-
err = ErrInvalidCurve
340-
return
302+
return nil, ErrInvalidPublicKey
341303
}
342304

343305
z, err := prv.GenerateShared(R, params.KeyLen, params.KeyLen)
344306
if err != nil {
345-
return
307+
return nil, err
346308
}
347-
348-
K, err := concatKDF(hash, z, s1, params.KeyLen+params.KeyLen)
349-
if err != nil {
350-
return
351-
}
352-
353-
Ke := K[:params.KeyLen]
354-
Km := K[params.KeyLen:]
355-
hash.Write(Km)
356-
Km = hash.Sum(nil)
357-
hash.Reset()
309+
Ke, Km := deriveKeys(hash, z, s1, params.KeyLen)
358310

359311
d := messageTag(params.Hash, Km, c[mStart:mEnd], s2)
360312
if subtle.ConstantTimeCompare(c[mEnd:], d) != 1 {
361-
err = ErrInvalidMessage
362-
return
313+
return nil, ErrInvalidMessage
363314
}
364315

365-
m, err = symDecrypt(params, Ke, c[mStart:mEnd])
366-
return
316+
return symDecrypt(params, Ke, c[mStart:mEnd])
367317
}

crypto/ecies/ecies_test.go

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,23 @@ import (
4242
"github.com/XinFinOrg/XDPoSChain/crypto"
4343
)
4444

45-
// Ensure the KDF generates appropriately sized keys.
4645
func TestKDF(t *testing.T) {
47-
msg := []byte("Hello, world")
48-
h := sha256.New()
49-
50-
k, err := concatKDF(h, msg, nil, 64)
51-
if err != nil {
52-
t.Fatal(err)
53-
}
54-
if len(k) != 64 {
55-
t.Fatalf("KDF: generated key is the wrong size (%d instead of 64\n", len(k))
46+
tests := []struct {
47+
length int
48+
output []byte
49+
}{
50+
{6, decode("858b192fa2ed")},
51+
{32, decode("858b192fa2ed4395e2bf88dd8d5770d67dc284ee539f12da8bceaa45d06ebae0")},
52+
{48, decode("858b192fa2ed4395e2bf88dd8d5770d67dc284ee539f12da8bceaa45d06ebae0700f1ab918a5f0413b8140f9940d6955")},
53+
{64, decode("858b192fa2ed4395e2bf88dd8d5770d67dc284ee539f12da8bceaa45d06ebae0700f1ab918a5f0413b8140f9940d6955f3467fd6672cce1024c5b1effccc0f61")},
54+
}
55+
56+
for _, test := range tests {
57+
h := sha256.New()
58+
k := concatKDF(h, []byte("input"), nil, test.length)
59+
if !bytes.Equal(k, test.output) {
60+
t.Fatalf("KDF: generated key %x does not match expected output %x", k, test.output)
61+
}
5662
}
5763
}
5864

@@ -293,8 +299,8 @@ func TestParamSelection(t *testing.T) {
293299

294300
func testParamSelection(t *testing.T, c testCase) {
295301
params := ParamsFromCurve(c.Curve)
296-
if params == nil && c.Expected != nil {
297-
t.Fatalf("%s (%s)\n", ErrInvalidParams.Error(), c.Name)
302+
if params == nil {
303+
t.Fatal("ParamsFromCurve returned nil")
298304
} else if params != nil && !cmpParams(params, c.Expected) {
299305
t.Fatalf("ecies: parameters should be invalid (%s)\n", c.Name)
300306
}
@@ -328,7 +334,6 @@ func testParamSelection(t *testing.T, c testCase) {
328334
if err == nil {
329335
t.Fatalf("ecies: encryption should not have succeeded (%s)\n", c.Name)
330336
}
331-
332337
}
333338

334339
// Ensure that the basic public key validation in the decryption operation
@@ -414,3 +419,11 @@ func hexKey(prv string) *PrivateKey {
414419
}
415420
return ImportECDSA(key)
416421
}
422+
423+
func decode(s string) []byte {
424+
bytes, err := hex.DecodeString(s)
425+
if err != nil {
426+
panic(err)
427+
}
428+
return bytes
429+
}

crypto/ecies/params.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import (
4040
"crypto/sha256"
4141
"crypto/sha512"
4242
"errors"
43+
"fmt"
4344
"hash"
4445

4546
ethcrypto "github.com/XinFinOrg/XDPoSChain/crypto"
@@ -49,8 +50,14 @@ var (
4950
DefaultCurve = ethcrypto.S256()
5051
ErrUnsupportedECDHAlgorithm = errors.New("ecies: unsupported ECDH algorithm")
5152
ErrUnsupportedECIESParameters = errors.New("ecies: unsupported ECIES parameters")
53+
ErrInvalidKeyLen = fmt.Errorf("ecies: invalid key size (> %d) in ECIESParams", maxKeyLen)
5254
)
5355

56+
// KeyLen is limited to prevent overflow of the counter
57+
// in concatKDF. While the theoretical limit is much higher,
58+
// no known cipher uses keys larger than 512 bytes.
59+
const maxKeyLen = 512
60+
5461
type ECIESParams struct {
5562
Hash func() hash.Hash // hash function
5663
hashAlgo crypto.Hash
@@ -115,3 +122,16 @@ func AddParamsForCurve(curve elliptic.Curve, params *ECIESParams) {
115122
func ParamsFromCurve(curve elliptic.Curve) (params *ECIESParams) {
116123
return paramsFromCurve[curve]
117124
}
125+
126+
func pubkeyParams(key *PublicKey) (*ECIESParams, error) {
127+
params := key.Params
128+
if params == nil {
129+
if params = ParamsFromCurve(key.Curve); params == nil {
130+
return nil, ErrUnsupportedECIESParameters
131+
}
132+
}
133+
if params.KeyLen > maxKeyLen {
134+
return nil, ErrInvalidKeyLen
135+
}
136+
return params, nil
137+
}

0 commit comments

Comments
 (0)