Skip to content

Commit 426c287

Browse files
jboelterintcagl
authored andcommitted
crypto/tls: add VerifyPeerCertificate to tls.Config
VerifyPeerCertificate returns an error if the peer should not be trusted. It will be called after the initial handshake and before any other verification checks on the cert or chain are performed. This provides the callee an opportunity to augment the certificate verification. If VerifyPeerCertificate is not nil and returns an error, then the handshake will fail. Fixes #16363 Change-Id: I6a22f199f0e81b6f5d5f37c54d85ab878216bb22 Reviewed-on: https://go-review.googlesource.com/26654 Reviewed-by: Brad Fitzpatrick <[email protected]>
1 parent 93bca63 commit 426c287

File tree

5 files changed

+182
-1
lines changed

5 files changed

+182
-1
lines changed

src/crypto/tls/common.go

+13
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,18 @@ type Config struct {
325325
// material from the returned config will be used for session tickets.
326326
GetConfigForClient func(*ClientHelloInfo) (*Config, error)
327327

328+
// VerifyPeerCertificate, if not nil, is called after normal
329+
// certificate verification by either a TLS client or server. It
330+
// receives the raw ASN.1 certificates provided by the peer and also
331+
// any verified chains that normal processing found. If it returns a
332+
// non-nil error, the handshake is aborted and that error results.
333+
//
334+
// If normal verification fails then the handshake will abort before
335+
// considering this callback. If normal verification is disabled by
336+
// setting InsecureSkipVerify then this callback will be considered but
337+
// the verifiedChains argument will always be nil.
338+
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
339+
328340
// RootCAs defines the set of root certificate authorities
329341
// that clients use when verifying server certificates.
330342
// If RootCAs is nil, TLS uses the host's root CA set.
@@ -474,6 +486,7 @@ func (c *Config) Clone() *Config {
474486
NameToCertificate: c.NameToCertificate,
475487
GetCertificate: c.GetCertificate,
476488
GetConfigForClient: c.GetConfigForClient,
489+
VerifyPeerCertificate: c.VerifyPeerCertificate,
477490
RootCAs: c.RootCAs,
478491
NextProtos: c.NextProtos,
479492
ServerName: c.ServerName,

src/crypto/tls/handshake_client.go

+7
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,13 @@ func (hs *clientHandshakeState) doFullHandshake() error {
304304
}
305305
}
306306

307+
if c.config.VerifyPeerCertificate != nil {
308+
if err := c.config.VerifyPeerCertificate(certMsg.certificates, c.verifiedChains); err != nil {
309+
c.sendAlert(alertBadCertificate)
310+
return err
311+
}
312+
}
313+
307314
switch certs[0].PublicKey.(type) {
308315
case *rsa.PublicKey, *ecdsa.PublicKey:
309316
break

src/crypto/tls/handshake_client_test.go

+154
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,160 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
10671067
}
10681068
}
10691069

1070+
func TestVerifyPeerCertificate(t *testing.T) {
1071+
issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1072+
if err != nil {
1073+
panic(err)
1074+
}
1075+
1076+
rootCAs := x509.NewCertPool()
1077+
rootCAs.AddCert(issuer)
1078+
1079+
now := func() time.Time { return time.Unix(1476984729, 0) }
1080+
1081+
sentinelErr := errors.New("TestVerifyPeerCertificate")
1082+
1083+
verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1084+
if l := len(rawCerts); l != 1 {
1085+
return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
1086+
}
1087+
if len(validatedChains) == 0 {
1088+
return errors.New("got len(validatedChains) = 0, wanted non-zero")
1089+
}
1090+
*called = true
1091+
return nil
1092+
}
1093+
1094+
tests := []struct {
1095+
configureServer func(*Config, *bool)
1096+
configureClient func(*Config, *bool)
1097+
validate func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
1098+
}{
1099+
{
1100+
configureServer: func(config *Config, called *bool) {
1101+
config.InsecureSkipVerify = false
1102+
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1103+
return verifyCallback(called, rawCerts, validatedChains)
1104+
}
1105+
},
1106+
configureClient: func(config *Config, called *bool) {
1107+
config.InsecureSkipVerify = false
1108+
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1109+
return verifyCallback(called, rawCerts, validatedChains)
1110+
}
1111+
},
1112+
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1113+
if clientErr != nil {
1114+
t.Errorf("#%d: client handshake failed: %v", testNo, clientErr)
1115+
}
1116+
if serverErr != nil {
1117+
t.Errorf("#%d: server handshake failed: %v", testNo, serverErr)
1118+
}
1119+
if !clientCalled {
1120+
t.Error("#%d: client did not call callback", testNo)
1121+
}
1122+
if !serverCalled {
1123+
t.Error("#%d: server did not call callback", testNo)
1124+
}
1125+
},
1126+
},
1127+
{
1128+
configureServer: func(config *Config, called *bool) {
1129+
config.InsecureSkipVerify = false
1130+
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1131+
return sentinelErr
1132+
}
1133+
},
1134+
configureClient: func(config *Config, called *bool) {
1135+
config.VerifyPeerCertificate = nil
1136+
},
1137+
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1138+
if serverErr != sentinelErr {
1139+
t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
1140+
}
1141+
},
1142+
},
1143+
{
1144+
configureServer: func(config *Config, called *bool) {
1145+
config.InsecureSkipVerify = false
1146+
},
1147+
configureClient: func(config *Config, called *bool) {
1148+
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1149+
return sentinelErr
1150+
}
1151+
},
1152+
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1153+
if clientErr != sentinelErr {
1154+
t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
1155+
}
1156+
},
1157+
},
1158+
{
1159+
configureServer: func(config *Config, called *bool) {
1160+
config.InsecureSkipVerify = false
1161+
},
1162+
configureClient: func(config *Config, called *bool) {
1163+
config.InsecureSkipVerify = true
1164+
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1165+
if l := len(rawCerts); l != 1 {
1166+
return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
1167+
}
1168+
// With InsecureSkipVerify set, this
1169+
// callback should still be called but
1170+
// validatedChains must be empty.
1171+
if l := len(validatedChains); l != 0 {
1172+
return errors.New("got len(validatedChains) = 0, wanted zero")
1173+
}
1174+
*called = true
1175+
return nil
1176+
}
1177+
},
1178+
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1179+
if clientErr != nil {
1180+
t.Errorf("#%d: client handshake failed: %v", testNo, clientErr)
1181+
}
1182+
if serverErr != nil {
1183+
t.Errorf("#%d: server handshake failed: %v", testNo, serverErr)
1184+
}
1185+
if !clientCalled {
1186+
t.Error("#%d: client did not call callback", testNo)
1187+
}
1188+
},
1189+
},
1190+
}
1191+
1192+
for i, test := range tests {
1193+
c, s := net.Pipe()
1194+
done := make(chan error)
1195+
1196+
var clientCalled, serverCalled bool
1197+
1198+
go func() {
1199+
config := testConfig.Clone()
1200+
config.ServerName = "example.golang"
1201+
config.ClientAuth = RequireAndVerifyClientCert
1202+
config.ClientCAs = rootCAs
1203+
config.Time = now
1204+
test.configureServer(config, &serverCalled)
1205+
1206+
err = Server(s, config).Handshake()
1207+
s.Close()
1208+
done <- err
1209+
}()
1210+
1211+
config := testConfig.Clone()
1212+
config.ServerName = "example.golang"
1213+
config.RootCAs = rootCAs
1214+
config.Time = now
1215+
test.configureClient(config, &clientCalled)
1216+
clientErr := Client(c, config).Handshake()
1217+
c.Close()
1218+
serverErr := <-done
1219+
1220+
test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
1221+
}
1222+
}
1223+
10701224
// brokenConn wraps a net.Conn and causes all Writes after a certain number to
10711225
// fail with brokenConnErr.
10721226
type brokenConn struct {

src/crypto/tls/handshake_server.go

+7
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,13 @@ func (hs *serverHandshakeState) processCertsFromClient(certificates [][]byte) (c
747747
c.verifiedChains = chains
748748
}
749749

750+
if c.config.VerifyPeerCertificate != nil {
751+
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
752+
c.sendAlert(alertBadCertificate)
753+
return nil, err
754+
}
755+
}
756+
750757
if len(certs) == 0 {
751758
return nil, nil
752759
}

src/crypto/tls/tls_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ func TestClone(t *testing.T) {
477477
case "Rand":
478478
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
479479
continue
480-
case "Time", "GetCertificate", "GetConfigForClient":
480+
case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate":
481481
// DeepEqual can't compare functions.
482482
continue
483483
case "Certificates":

0 commit comments

Comments
 (0)