@@ -15,8 +15,72 @@ import (
1515 "crypto/sha256"
1616 "crypto/x509"
1717 "encoding/pem"
18+ "sync"
1819)
1920
21+ // server pub keys registry
22+ var (
23+ serverPubKeyLock sync.RWMutex
24+ serverPubKeyRegistry map [string ]* rsa.PublicKey
25+ )
26+
27+ // RegisterServerPubKey registers a server RSA public key which can be used to
28+ // send data in a secure manner to the server without receiving the public key
29+ // in a potentially insecure way from the server first.
30+ // Registered keys can afterwards be used adding serverPubKey=<name> to the DSN.
31+ //
32+ // Note: The provided rsa.PublicKey instance is exclusively owned by the driver
33+ // after registering it and may not be modified.
34+ //
35+ // data, err := ioutil.ReadFile("mykey.pem")
36+ // if err != nil {
37+ // log.Fatal(err)
38+ // }
39+ //
40+ // block, _ := pem.Decode(data)
41+ // if block == nil || block.Type != "PUBLIC KEY" {
42+ // log.Fatal("failed to decode PEM block containing public key")
43+ // }
44+ //
45+ // pub, err := x509.ParsePKIXPublicKey(block.Bytes)
46+ // if err != nil {
47+ // log.Fatal(err)
48+ // }
49+ //
50+ // if rsaPubKey, ok := pub.(*rsa.PublicKey); ok {
51+ // mysql.RegisterServerPubKey("mykey", rsaPubKey)
52+ // } else {
53+ // log.Fatal("not a RSA public key")
54+ // }
55+ //
56+ func RegisterServerPubKey (name string , pubKey * rsa.PublicKey ) {
57+ serverPubKeyLock .Lock ()
58+ if serverPubKeyRegistry == nil {
59+ serverPubKeyRegistry = make (map [string ]* rsa.PublicKey )
60+ }
61+
62+ serverPubKeyRegistry [name ] = pubKey
63+ serverPubKeyLock .Unlock ()
64+ }
65+
66+ // DeregisterServerPubKey removes the public key registered with the given name.
67+ func DeregisterServerPubKey (name string ) {
68+ serverPubKeyLock .Lock ()
69+ if serverPubKeyRegistry != nil {
70+ delete (serverPubKeyRegistry , name )
71+ }
72+ serverPubKeyLock .Unlock ()
73+ }
74+
75+ func getServerPubKey (name string ) (pubKey * rsa.PublicKey ) {
76+ serverPubKeyLock .RLock ()
77+ if v , ok := serverPubKeyRegistry [name ]; ok {
78+ pubKey = v
79+ }
80+ serverPubKeyLock .RUnlock ()
81+ return
82+ }
83+
2084// Hash password using pre 4.1 (old password) method
2185// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
2286type myRnd struct {
@@ -154,6 +218,25 @@ func scrambleSHA256Password(scramble []byte, password string) []byte {
154218 return message1
155219}
156220
221+ func encryptPassword (password string , seed []byte , pub * rsa.PublicKey ) ([]byte , error ) {
222+ plain := make ([]byte , len (password )+ 1 )
223+ copy (plain , password )
224+ for i := range plain {
225+ j := i % len (seed )
226+ plain [i ] ^= seed [j ]
227+ }
228+ sha1 := sha1 .New ()
229+ return rsa .EncryptOAEP (sha1 , rand .Reader , pub , plain , nil )
230+ }
231+
232+ func (mc * mysqlConn ) sendEncryptedPassword (seed []byte , pub * rsa.PublicKey ) error {
233+ enc , err := encryptPassword (mc .cfg .Passwd , seed , pub )
234+ if err != nil {
235+ return err
236+ }
237+ return mc .writeAuthSwitchPacket (enc , false )
238+ }
239+
157240func (mc * mysqlConn ) auth (authData []byte , plugin string ) ([]byte , bool , error ) {
158241 switch plugin {
159242 case "caching_sha2_password" :
@@ -187,6 +270,25 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error)
187270 authResp := scramblePassword (authData [:20 ], mc .cfg .Passwd )
188271 return authResp , false , nil
189272
273+ case "sha256_password" :
274+ if len (mc .cfg .Passwd ) == 0 {
275+ return nil , true , nil
276+ }
277+ if mc .cfg .tls != nil || mc .cfg .Net == "unix" {
278+ // write cleartext auth packet
279+ return []byte (mc .cfg .Passwd ), true , nil
280+ }
281+
282+ pubKey := mc .cfg .pubKey
283+ if pubKey == nil {
284+ // request public key from server
285+ return []byte {1 }, false , nil
286+ }
287+
288+ // encrypted password
289+ enc , err := encryptPassword (mc .cfg .Passwd , authData , pubKey )
290+ return enc , false , err
291+
190292 default :
191293 errLog .Print ("unknown auth plugin:" , plugin )
192294 return nil , false , ErrUnknownPlugin
@@ -206,6 +308,9 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
206308 // sent and we have to keep using the cipher sent in the init packet.
207309 if authData == nil {
208310 authData = oldAuthData
311+ } else {
312+ // copy data from read buffer to owned slice
313+ copy (oldAuthData , authData )
209314 }
210315
211316 plugin = newPlugin
@@ -223,6 +328,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
223328 if err != nil {
224329 return err
225330 }
331+
226332 // Do not allow to change the auth plugin more than once
227333 if newPlugin != "" {
228334 return ErrMalformPkt
@@ -251,48 +357,34 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
251357 return err
252358 }
253359 } else {
254- seed := oldAuthData
255-
256- // TODO: allow to specify a local file with the pub key via
257- // the DSN
258-
259- // request public key
260- data := mc .buf .takeSmallBuffer (4 + 1 )
261- data [4 ] = cachingSha2PasswordRequestPublicKey
262- mc .writePacket (data )
263-
264- // parse public key
265- data , err := mc .readPacket ()
266- if err != nil {
267- return err
268- }
269-
270- block , _ := pem .Decode (data [1 :])
271- pub , err := x509 .ParsePKIXPublicKey (block .Bytes )
272- if err != nil {
273- return err
360+ pubKey := mc .cfg .pubKey
361+ if pubKey == nil {
362+ // request public key from server
363+ data := mc .buf .takeSmallBuffer (4 + 1 )
364+ data [4 ] = cachingSha2PasswordRequestPublicKey
365+ mc .writePacket (data )
366+
367+ // parse public key
368+ data , err := mc .readPacket ()
369+ if err != nil {
370+ return err
371+ }
372+
373+ block , _ := pem .Decode (data [1 :])
374+ pkix , err := x509 .ParsePKIXPublicKey (block .Bytes )
375+ if err != nil {
376+ return err
377+ }
378+ pubKey = pkix .(* rsa.PublicKey )
274379 }
275380
276381 // send encrypted password
277- plain := make ([]byte , len (mc .cfg .Passwd )+ 1 )
278- copy (plain , mc .cfg .Passwd )
279- for i := range plain {
280- j := i % len (seed )
281- plain [i ] ^= seed [j ]
282- }
283- sha1 := sha1 .New ()
284- enc , err := rsa .EncryptOAEP (sha1 , rand .Reader , pub .(* rsa.PublicKey ), plain , nil )
382+ err = mc .sendEncryptedPassword (oldAuthData , pubKey )
285383 if err != nil {
286384 return err
287385 }
288-
289- if err = mc .writeAuthSwitchPacket (enc , false ); err != nil {
290- return err
291- }
292- }
293- if err = mc .readResultOK (); err == nil {
294- return nil // auth successful
295386 }
387+ return mc .readResultOK ()
296388
297389 default :
298390 return ErrMalformPkt
@@ -301,6 +393,25 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
301393 return ErrMalformPkt
302394 }
303395
396+ case "sha256_password" :
397+ switch len (authData ) {
398+ case 0 :
399+ return nil // auth successful
400+ default :
401+ block , _ := pem .Decode (authData )
402+ pub , err := x509 .ParsePKIXPublicKey (block .Bytes )
403+ if err != nil {
404+ return err
405+ }
406+
407+ // send encrypted password
408+ err = mc .sendEncryptedPassword (oldAuthData , pub .(* rsa.PublicKey ))
409+ if err != nil {
410+ return err
411+ }
412+ return mc .readResultOK ()
413+ }
414+
304415 default :
305416 return nil // auth successful
306417 }
0 commit comments