Commit b151a08

Merge pull request #678 from WGH-/code-cleanup

Code cleanup

2 parents a611094 + 4e8e0d7

colly.go: 31 additions, 82 deletions

@@ -24,7 +24,6 @@ import (
 	"fmt"
 	"hash/fnv"
 	"io"
-	"io/ioutil"
 	"log"
 	"net/http"
 	"net/http/cookiejar"
@@ -578,10 +577,6 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, c
 	if err != nil {
 		return err
 	}
-	if err := c.requestCheck(u, parsedURL, method, requestData, depth, checkRevisit); err != nil {
-		return err
-	}
-
 	if hdr == nil {
 		hdr = http.Header{}
 		if c.Headers != nil {
@@ -595,30 +590,22 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, c
 	if _, ok := hdr["User-Agent"]; !ok {
 		hdr.Set("User-Agent", c.UserAgent)
 	}
-	rc, ok := requestData.(io.ReadCloser)
-	if !ok && requestData != nil {
-		rc = ioutil.NopCloser(requestData)
+	req, err := http.NewRequest(method, parsedURL.String(), requestData)
+	if err != nil {
+		return err
 	}
+	req.Header = hdr
 	// The Go HTTP API ignores "Host" in the headers, preferring the client
 	// to use the Host field on Request.
-	host := parsedURL.Host
 	if hostHeader := hdr.Get("Host"); hostHeader != "" {
-		host = hostHeader
-	}
-	req := &http.Request{
-		Method:     method,
-		URL:        parsedURL,
-		Proto:      "HTTP/1.1",
-		ProtoMajor: 1,
-		ProtoMinor: 1,
-		Header:     hdr,
-		Body:       rc,
-		Host:       host,
+		req.Host = hostHeader
 	}
 	// note: once 1.13 is minimum supported Go version,
 	// replace this with http.NewRequestWithContext
 	req = req.WithContext(c.Context)
-	setRequestBody(req, requestData)
+	if err := c.requestCheck(u, parsedURL, method, req.GetBody, depth, checkRevisit); err != nil {
+		return err
+	}
 	u = parsedURL.String()
 	c.wg.Add(1)
 	if c.Async {
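
Why the hand-rolled http.Request literal and the setRequestBody helper (deleted below) are no longer needed: since Go 1.8, http.NewRequest itself recognizes *bytes.Buffer, *bytes.Reader and *strings.Reader bodies and fills in ContentLength plus a replayable GetBody. A minimal standalone sketch of that behaviour; the URL and payload here are made up for illustration:

package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// http.NewRequest special-cases *bytes.Buffer, *bytes.Reader and
	// *strings.Reader: it sets ContentLength and a GetBody that yields
	// a fresh copy of the payload on every call.
	body := bytes.NewBufferString("q=colly")
	req, err := http.NewRequest("POST", "http://example.com/search", body)
	if err != nil {
		panic(err)
	}
	fmt.Println(req.ContentLength)  // 7
	fmt.Println(req.GetBody != nil) // true

	// Replaying the body via GetBody does not disturb req.Body.
	rc, err := req.GetBody()
	if err != nil {
		panic(err)
	}
	buf := new(bytes.Buffer)
	buf.ReadFrom(rc)
	rc.Close()
	fmt.Println(buf.String()) // q=colly
}

This is exactly the bookkeeping the deleted type switch performed by hand.
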
@@ -628,38 +615,6 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, c
 	return c.fetch(u, method, depth, requestData, ctx, hdr, req)
 }
 
-func setRequestBody(req *http.Request, body io.Reader) {
-	if body != nil {
-		switch v := body.(type) {
-		case *bytes.Buffer:
-			req.ContentLength = int64(v.Len())
-			buf := v.Bytes()
-			req.GetBody = func() (io.ReadCloser, error) {
-				r := bytes.NewReader(buf)
-				return ioutil.NopCloser(r), nil
-			}
-		case *bytes.Reader:
-			req.ContentLength = int64(v.Len())
-			snapshot := *v
-			req.GetBody = func() (io.ReadCloser, error) {
-				r := snapshot
-				return ioutil.NopCloser(&r), nil
-			}
-		case *strings.Reader:
-			req.ContentLength = int64(v.Len())
-			snapshot := *v
-			req.GetBody = func() (io.ReadCloser, error) {
-				r := snapshot
-				return ioutil.NopCloser(&r), nil
-			}
-		}
-		if req.GetBody != nil && req.ContentLength == 0 {
-			req.Body = http.NoBody
-			req.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil }
-		}
-	}
-}
-
 func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ctx *Context, hdr http.Header, req *http.Request) error {
 	defer c.wg.Done()
 	if ctx == nil {
@@ -739,7 +694,7 @@ func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ct
 	return err
 }
 
-func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, requestData io.Reader, depth int, checkRevisit bool) error {
+func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, getBody func() (io.ReadCloser, error), depth int, checkRevisit bool) error {
 	if u == "" {
 		return ErrMissingURL
 	}
@@ -755,19 +710,23 @@ func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, re
 		}
 	}
 	if checkRevisit && !c.AllowURLRevisit {
-		h := fnv.New64a()
-		h.Write([]byte(u))
-
-		var uHash uint64
-		if method == "GET" {
-			uHash = h.Sum64()
-		} else if requestData != nil {
-			h.Write(streamToByte(requestData))
-			uHash = h.Sum64()
-		} else {
+		// TODO weird behaviour, it allows CheckHead to work correctly,
+		// but it should probably better be solved with
+		// "check-but-not-save" flag or something
+		if method != "GET" && getBody == nil {
 			return nil
 		}
 
+		var body io.ReadCloser
+		if getBody != nil {
+			var err error
+			body, err = getBody()
+			if err != nil {
+				return err
+			}
+			defer body.Close()
+		}
+		uHash := requestHash(u, body)
 		visited, err := c.store.IsVisited(uHash)
 		if err != nil {
 			return err
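
requestCheck now receives req.GetBody rather than the raw io.Reader: GetBody hands out a fresh ReadCloser on every call, so the revisit check can drain one copy into the hash while the transport later reads another. A small sketch of that replay behaviour, mirroring the new code path; the URL and form body are illustrative:

package main

import (
	"fmt"
	"hash/fnv"
	"io"
	"net/http"
	"strings"
)

func main() {
	req, err := http.NewRequest("POST", "http://example.com/login", strings.NewReader("user=a"))
	if err != nil {
		panic(err)
	}

	// First copy: drained by the revisit check, as requestCheck now does.
	body, err := req.GetBody()
	if err != nil {
		panic(err)
	}
	h := fnv.New64a()
	h.Write([]byte(req.URL.String()))
	io.Copy(h, body)
	body.Close()
	fmt.Println("hash:", h.Sum64())

	// Second copy: the transport still gets the full payload later.
	body, err = req.GetBody()
	if err != nil {
		panic(err)
	}
	b, _ := io.ReadAll(body) // io.ReadAll needs Go 1.16+
	body.Close()
	fmt.Println("payload:", string(b)) // user=a
}
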
@@ -1368,14 +1327,8 @@ func (c *Collector) parseSettingsFromEnv() {
 }
 
 func (c *Collector) checkHasVisited(URL string, requestData map[string]string) (bool, error) {
-	h := fnv.New64a()
-	h.Write([]byte(URL))
-
-	if requestData != nil {
-		h.Write(streamToByte(createFormReader(requestData)))
-	}
-
-	return c.store.IsVisited(h.Sum64())
+	hash := requestHash(URL, createFormReader(requestData))
+	return c.store.IsVisited(hash)
 }
 
 // SanitizeFileName replaces dangerous characters in a string
@@ -1487,15 +1440,11 @@ func isMatchingFilter(fs []*regexp.Regexp, d []byte) bool {
 	return false
 }
 
-func streamToByte(r io.Reader) []byte {
-	buf := new(bytes.Buffer)
-	buf.ReadFrom(r)
-
-	if strReader, k := r.(*strings.Reader); k {
-		strReader.Seek(0, 0)
-	} else if bReader, kb := r.(*bytes.Reader); kb {
-		bReader.Seek(0, 0)
+func requestHash(url string, body io.Reader) uint64 {
+	h := fnv.New64a()
+	h.Write([]byte(url))
+	if body != nil {
+		io.Copy(h, body)
 	}
-
-	return buf.Bytes()
+	return h.Sum64()
 }
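
The old streamToByte buffered the entire body and then seeked the reader back to zero so the request could still use it; requestHash instead streams a disposable GetBody copy straight into the FNV-64a hash, so no rewinding is needed. A standalone sketch showing how the helper keys revisits on URL plus payload; requestHash is copied from the function above, and the URL is illustrative:

package main

import (
	"fmt"
	"hash/fnv"
	"io"
	"strings"
)

// requestHash is copied verbatim from the new colly helper.
func requestHash(url string, body io.Reader) uint64 {
	h := fnv.New64a()
	h.Write([]byte(url))
	if body != nil {
		io.Copy(h, body)
	}
	return h.Sum64()
}

func main() {
	u := "http://example.com/form"
	// A bodiless GET and a POST carrying a form payload hash to
	// different values, so they are tracked as distinct visits.
	fmt.Println(requestHash(u, nil))
	fmt.Println(requestHash(u, strings.NewReader("a=1")))
}
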
