Skip to content

Commit 4eaf6ff

Browse files
committed
fixes
1 parent 2827f80 commit 4eaf6ff

File tree

1 file changed

+59
-142
lines changed

1 file changed

+59
-142
lines changed

selector.go

Lines changed: 59 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net"
1010
"strings"
1111
"sync"
12+
"time"
1213

1314
"github.com/miekg/dns"
1415
)
@@ -152,7 +153,6 @@ var rootIPs = []string{
152153

153154
type authNSSelector struct {
154155
sync.Mutex
155-
update sync.Mutex
156156
pool *Resolvers
157157
list []*resolver
158158
lookup map[string]*resolver
@@ -173,55 +173,59 @@ func newAuthNSSelector(r *Resolvers) *authNSSelector {
173173

174174
for _, addr := range rootIPs {
175175
if res := r.initResolver(addr); res != nil {
176-
auth.setLookup(addr, res)
176+
auth.lookup[addr] = res
177+
auth.list = append(auth.list, res)
177178
auth.rootResolvers = append(auth.rootResolvers, res)
178179
}
179180
}
180-
181-
auth.appendToList(auth.rootResolvers)
182181
return auth
183182
}
184183

185184
// GetResolver performs random selection on the pool of resolvers.
186185
func (r *authNSSelector) GetResolver(fqdn string) *resolver {
186+
r.Lock()
187+
defer r.Unlock()
188+
187189
name := strings.ToLower(RemoveLastDot(fqdn))
188190
labels := strings.Split(name, ".")
189191
if len(labels) > 1 {
190192
name = strings.Join(labels[1:], ".")
191193
}
192194

193-
servers := r.getFQDNToServers(name)
194-
if len(servers) == 0 {
195-
r.update.Lock()
195+
if _, found := r.fqdnToResolvers[name]; !found {
196196
r.populateAuthServers(name)
197-
r.update.Unlock()
198197
}
199198

200-
resolvers := r.getFQDNToResolvers(name)
201-
if len(resolvers) > 0 {
202-
return pickOneResolver(resolvers)
199+
if res, found := r.fqdnToResolvers[name]; found {
200+
return pickOneResolver(res)
203201
}
204202
return nil
205203
}
206204

207205
func (r *authNSSelector) LookupResolver(addr string) *resolver {
208-
return r.getLookup(addr)
206+
r.Lock()
207+
defer r.Unlock()
208+
209+
return r.lookup[addr]
209210
}
210211

211212
func (r *authNSSelector) AddResolver(res *resolver) {
212-
addrstr := res.address.IP.String()
213+
r.Lock()
214+
defer r.Unlock()
213215

214-
if fres := r.getLookup(addrstr); fres == nil {
215-
r.appendToList([]*resolver{res})
216-
r.setLookup(addrstr, res)
216+
addrstr := res.address.IP.String()
217+
if _, found := r.lookup[addrstr]; !found {
218+
r.list = append(r.list, res)
219+
r.lookup[addrstr] = res
217220
}
218221
}
219222

220223
func (r *authNSSelector) AllResolvers() []*resolver {
221-
list := r.getList()
224+
r.Lock()
225+
defer r.Unlock()
222226

223227
var active []*resolver
224-
for _, res := range list {
228+
for _, res := range r.list {
225229
select {
226230
case <-res.done:
227231
default:
@@ -232,10 +236,11 @@ func (r *authNSSelector) AllResolvers() []*resolver {
232236
}
233237

234238
func (r *authNSSelector) Len() int {
235-
list := r.getList()
239+
r.Lock()
240+
defer r.Unlock()
236241

237242
var count int
238-
for _, res := range list {
243+
for _, res := range r.list {
239244
select {
240245
case <-res.done:
241246
default:
@@ -246,9 +251,10 @@ func (r *authNSSelector) Len() int {
246251
}
247252

248253
func (r *authNSSelector) Close() {
249-
list := r.getList()
254+
r.Lock()
255+
defer r.Unlock()
250256

251-
for _, res := range list {
257+
for _, res := range r.list {
252258
select {
253259
case <-res.done:
254260
default:
@@ -264,12 +270,9 @@ func (r *authNSSelector) Close() {
264270
}
265271

266272
func (r *authNSSelector) populateAuthServers(fqdn string) {
267-
if s := r.getFQDNToServers(fqdn); len(s) > 0 {
268-
return
269-
}
270-
271273
labels := strings.Split(fqdn, ".")
272274
last, resolvers := r.findClosestResolverSet(fqdn, labels[len(labels)-1])
275+
273276
if len(labels) < len(strings.Split(last, ".")) {
274277
return
275278
}
@@ -282,7 +285,7 @@ func (r *authNSSelector) populateFromLabel(last, fqdn string, resolvers []*resol
282285
res := pickOneResolver(resolvers)
283286

284287
if servers := r.getNameServers(name, res); len(servers) > 0 {
285-
r.setFQDNToServers(name, servers)
288+
r.fqdnToServers[name] = servers
286289

287290
var wg sync.WaitGroup
288291
var resset []*resolver
@@ -291,13 +294,13 @@ func (r *authNSSelector) populateFromLabel(last, fqdn string, resolvers []*resol
291294
go func(name string, res *resolver) {
292295
defer wg.Done()
293296

294-
if fres := r.getServerToResolver(server); fres != nil {
297+
if fres, found := r.serverToResolver[server]; found {
295298
resset = append(resset, fres)
296299
} else if nres := r.serverNameToResolverObj(server, res); nres != nil {
297300
resset = append(resset, nres)
298-
r.appendToList([]*resolver{nres})
299-
r.setLookup(nres.address.IP.String(), nres)
300-
r.setServerToResolver(server, nres)
301+
r.list = append(r.list, nres)
302+
r.lookup[nres.address.IP.String()] = nres
303+
r.serverToResolver[server] = nres
301304
}
302305
}(server, pickOneResolver(resolvers))
303306
}
@@ -307,8 +310,7 @@ func (r *authNSSelector) populateFromLabel(last, fqdn string, resolvers []*resol
307310
resolvers = resset
308311
}
309312
}
310-
311-
r.setFQDNToResolvers(name, resolvers)
313+
r.fqdnToResolvers[name] = resolvers
312314
return false
313315
})
314316
}
@@ -319,7 +321,7 @@ func (r *authNSSelector) findClosestResolverSet(fqdn, tld string) (string, []*re
319321
_ = copy(resolvers, r.rootResolvers)
320322

321323
FQDNToRegistered(fqdn, tld, func(name string) bool {
322-
if res := r.getFQDNToResolvers(name); len(res) > 0 {
324+
if res, found := r.fqdnToResolvers[name]; found {
323325
resolvers = res
324326
return true
325327
}
@@ -331,136 +333,51 @@ func (r *authNSSelector) findClosestResolverSet(fqdn, tld string) (string, []*re
331333
}
332334

333335
func (r *authNSSelector) serverNameToResolverObj(server string, res *resolver) *resolver {
334-
ch := make(chan *dns.Msg, 1)
335-
defer close(ch)
336+
addr := res.address.IP.String() + ":53"
337+
client := dns.Client{
338+
Net: "udp",
339+
Timeout: time.Second,
340+
}
336341

337342
for i := 0; i < maxQueryAttempts; i++ {
338-
req := request{
339-
Res: res,
340-
Msg: QueryMsg(server, dns.TypeA),
341-
Result: ch,
342-
}
343-
res.queue.Append(&req)
343+
msg := QueryMsg(server, dns.TypeA)
344344

345-
select {
346-
case <-res.done:
347-
return nil
348-
case resp := <-ch:
349-
if resp != nil && resp.Rcode == dns.RcodeSuccess {
350-
for _, rr := range AnswersByType(resp, dns.TypeA) {
351-
if addr, ok := rr.(*dns.A); ok {
352-
addr := net.JoinHostPort(addr.A.String(), "53")
353-
return r.pool.initResolver(addr)
354-
}
345+
if m, _, err := client.Exchange(msg, addr); err == nil && m != nil && m.Rcode == dns.RcodeSuccess {
346+
for _, rr := range AnswersByType(m, dns.TypeA) {
347+
if record, ok := rr.(*dns.A); ok {
348+
ip := net.JoinHostPort(record.A.String(), "53")
349+
return r.pool.initResolver(ip)
355350
}
356-
return nil
357351
}
352+
break
358353
}
359354
}
360355
return nil
361356
}
362357

363358
func (r *authNSSelector) getNameServers(fqdn string, res *resolver) []string {
364-
ch := make(chan *dns.Msg, 1)
365-
defer close(ch)
359+
addr := res.address.IP.String() + ":53"
360+
client := dns.Client{
361+
Net: "udp",
362+
Timeout: time.Second,
363+
}
366364

367365
var servers []string
368-
loop:
369366
for i := 0; i < maxQueryAttempts; i++ {
370-
req := request{
371-
Res: res,
372-
Msg: QueryMsg(fqdn, dns.TypeNS),
373-
Result: ch,
374-
}
375-
res.queue.Append(&req)
367+
msg := QueryMsg(fqdn, dns.TypeNS)
376368

377-
select {
378-
case <-res.done:
379-
return nil
380-
case resp := <-ch:
381-
if resp != nil && resp.Rcode == dns.RcodeSuccess {
382-
for _, rr := range AnswersByType(resp, dns.TypeNS) {
383-
if record, ok := rr.(*dns.NS); ok {
384-
servers = append(servers, strings.ToLower(RemoveLastDot(record.Ns)))
385-
}
369+
if m, _, err := client.Exchange(msg, addr); err == nil && m != nil && m.Rcode == dns.RcodeSuccess {
370+
for _, rr := range AnswersByType(m, dns.TypeNS) {
371+
if record, ok := rr.(*dns.NS); ok {
372+
servers = append(servers, strings.ToLower(RemoveLastDot(record.Ns)))
386373
}
387-
break loop
388374
}
375+
break
389376
}
390377
}
391378
return servers
392379
}
393380

394-
func (r *authNSSelector) getList() []*resolver {
395-
r.Lock()
396-
defer r.Unlock()
397-
398-
return r.list
399-
}
400-
401-
func (r *authNSSelector) appendToList(elements []*resolver) {
402-
r.Lock()
403-
defer r.Unlock()
404-
405-
r.list = append(r.list, elements...)
406-
}
407-
408-
func (r *authNSSelector) getLookup(key string) *resolver {
409-
r.Lock()
410-
defer r.Unlock()
411-
412-
return r.lookup[key]
413-
}
414-
415-
func (r *authNSSelector) setLookup(key string, res *resolver) {
416-
r.Lock()
417-
defer r.Unlock()
418-
419-
r.lookup[key] = res
420-
}
421-
422-
func (r *authNSSelector) getFQDNToServers(key string) []string {
423-
r.Lock()
424-
defer r.Unlock()
425-
426-
return r.fqdnToServers[key]
427-
}
428-
429-
func (r *authNSSelector) setFQDNToServers(key string, servers []string) {
430-
r.Lock()
431-
defer r.Unlock()
432-
433-
r.fqdnToServers[key] = servers
434-
}
435-
436-
func (r *authNSSelector) getFQDNToResolvers(key string) []*resolver {
437-
r.Lock()
438-
defer r.Unlock()
439-
440-
return r.fqdnToResolvers[key]
441-
}
442-
443-
func (r *authNSSelector) setFQDNToResolvers(key string, resolvers []*resolver) {
444-
r.Lock()
445-
defer r.Unlock()
446-
447-
r.fqdnToResolvers[key] = resolvers
448-
}
449-
450-
func (r *authNSSelector) getServerToResolver(key string) *resolver {
451-
r.Lock()
452-
defer r.Unlock()
453-
454-
return r.serverToResolver[key]
455-
}
456-
457-
func (r *authNSSelector) setServerToResolver(key string, res *resolver) {
458-
r.Lock()
459-
defer r.Unlock()
460-
461-
r.serverToResolver[key] = res
462-
}
463-
464381
func pickOneResolver(resolvers []*resolver) *resolver {
465382
if l := len(resolvers); l > 0 {
466383
return resolvers[rand.Intn(l)]

0 commit comments

Comments
 (0)