@@ -24,7 +24,6 @@ import (
2424 "fmt"
2525 "hash/fnv"
2626 "io"
27- "io/ioutil"
2827 "log"
2928 "net/http"
3029 "net/http/cookiejar"
@@ -578,10 +577,6 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, c
578577 if err != nil {
579578 return err
580579 }
581- if err := c .requestCheck (u , parsedURL , method , requestData , depth , checkRevisit ); err != nil {
582- return err
583- }
584-
585580 if hdr == nil {
586581 hdr = http.Header {}
587582 if c .Headers != nil {
@@ -595,30 +590,22 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, c
595590 if _ , ok := hdr ["User-Agent" ]; ! ok {
596591 hdr .Set ("User-Agent" , c .UserAgent )
597592 }
598- rc , ok := requestData .(io. ReadCloser )
599- if ! ok && requestData != nil {
600- rc = ioutil . NopCloser ( requestData )
593+ req , err := http . NewRequest ( method , parsedURL . String (), requestData )
594+ if err != nil {
595+ return err
601596 }
597+ req .Header = hdr
602598 // The Go HTTP API ignores "Host" in the headers, preferring the client
603599 // to use the Host field on Request.
604- host := parsedURL .Host
605600 if hostHeader := hdr .Get ("Host" ); hostHeader != "" {
606- host = hostHeader
607- }
608- req := & http.Request {
609- Method : method ,
610- URL : parsedURL ,
611- Proto : "HTTP/1.1" ,
612- ProtoMajor : 1 ,
613- ProtoMinor : 1 ,
614- Header : hdr ,
615- Body : rc ,
616- Host : host ,
601+ req .Host = hostHeader
617602 }
618603 // note: once 1.13 is minimum supported Go version,
619604 // replace this with http.NewRequestWithContext
620605 req = req .WithContext (c .Context )
621- setRequestBody (req , requestData )
606+ if err := c .requestCheck (u , parsedURL , method , req .GetBody , depth , checkRevisit ); err != nil {
607+ return err
608+ }
622609 u = parsedURL .String ()
623610 c .wg .Add (1 )
624611 if c .Async {
@@ -628,38 +615,6 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, c
628615 return c .fetch (u , method , depth , requestData , ctx , hdr , req )
629616}
630617
631- func setRequestBody (req * http.Request , body io.Reader ) {
632- if body != nil {
633- switch v := body .(type ) {
634- case * bytes.Buffer :
635- req .ContentLength = int64 (v .Len ())
636- buf := v .Bytes ()
637- req .GetBody = func () (io.ReadCloser , error ) {
638- r := bytes .NewReader (buf )
639- return ioutil .NopCloser (r ), nil
640- }
641- case * bytes.Reader :
642- req .ContentLength = int64 (v .Len ())
643- snapshot := * v
644- req .GetBody = func () (io.ReadCloser , error ) {
645- r := snapshot
646- return ioutil .NopCloser (& r ), nil
647- }
648- case * strings.Reader :
649- req .ContentLength = int64 (v .Len ())
650- snapshot := * v
651- req .GetBody = func () (io.ReadCloser , error ) {
652- r := snapshot
653- return ioutil .NopCloser (& r ), nil
654- }
655- }
656- if req .GetBody != nil && req .ContentLength == 0 {
657- req .Body = http .NoBody
658- req .GetBody = func () (io.ReadCloser , error ) { return http .NoBody , nil }
659- }
660- }
661- }
662-
663618func (c * Collector ) fetch (u , method string , depth int , requestData io.Reader , ctx * Context , hdr http.Header , req * http.Request ) error {
664619 defer c .wg .Done ()
665620 if ctx == nil {
@@ -739,7 +694,7 @@ func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ct
739694 return err
740695}
741696
742- func (c * Collector ) requestCheck (u string , parsedURL * url.URL , method string , requestData io.Reader , depth int , checkRevisit bool ) error {
697+ func (c * Collector ) requestCheck (u string , parsedURL * url.URL , method string , getBody func () ( io.ReadCloser , error ) , depth int , checkRevisit bool ) error {
743698 if u == "" {
744699 return ErrMissingURL
745700 }
@@ -755,19 +710,23 @@ func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, re
755710 }
756711 }
757712 if checkRevisit && ! c .AllowURLRevisit {
758- h := fnv .New64a ()
759- h .Write ([]byte (u ))
760-
761- var uHash uint64
762- if method == "GET" {
763- uHash = h .Sum64 ()
764- } else if requestData != nil {
765- h .Write (streamToByte (requestData ))
766- uHash = h .Sum64 ()
767- } else {
713+ // TODO weird behaviour, it allows CheckHead to work correctly,
714+ // but it should probably better be solved with
715+ // "check-but-not-save" flag or something
716+ if method != "GET" && getBody == nil {
768717 return nil
769718 }
770719
720+ var body io.ReadCloser
721+ if getBody != nil {
722+ var err error
723+ body , err = getBody ()
724+ if err != nil {
725+ return err
726+ }
727+ defer body .Close ()
728+ }
729+ uHash := requestHash (u , body )
771730 visited , err := c .store .IsVisited (uHash )
772731 if err != nil {
773732 return err
@@ -1368,14 +1327,8 @@ func (c *Collector) parseSettingsFromEnv() {
13681327}
13691328
13701329func (c * Collector ) checkHasVisited (URL string , requestData map [string ]string ) (bool , error ) {
1371- h := fnv .New64a ()
1372- h .Write ([]byte (URL ))
1373-
1374- if requestData != nil {
1375- h .Write (streamToByte (createFormReader (requestData )))
1376- }
1377-
1378- return c .store .IsVisited (h .Sum64 ())
1330+ hash := requestHash (URL , createFormReader (requestData ))
1331+ return c .store .IsVisited (hash )
13791332}
13801333
13811334// SanitizeFileName replaces dangerous characters in a string
@@ -1487,15 +1440,11 @@ func isMatchingFilter(fs []*regexp.Regexp, d []byte) bool {
14871440 return false
14881441}
14891442
1490- func streamToByte (r io.Reader ) []byte {
1491- buf := new (bytes.Buffer )
1492- buf .ReadFrom (r )
1493-
1494- if strReader , k := r .(* strings.Reader ); k {
1495- strReader .Seek (0 , 0 )
1496- } else if bReader , kb := r .(* bytes.Reader ); kb {
1497- bReader .Seek (0 , 0 )
1443+ func requestHash (url string , body io.Reader ) uint64 {
1444+ h := fnv .New64a ()
1445+ h .Write ([]byte (url ))
1446+ if body != nil {
1447+ io .Copy (h , body )
14981448 }
1499-
1500- return buf .Bytes ()
1449+ return h .Sum64 ()
15011450}
0 commit comments