diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 545cdea..601d89a 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -23,7 +23,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v3 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL diff --git a/.github/workflows/dockerhub-push.yml b/.github/workflows/dockerhub-push.yml index 49369db..07ba054 100644 --- a/.github/workflows/dockerhub-push.yml +++ b/.github/workflows/dockerhub-push.yml @@ -11,7 +11,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v3 - name: Set up QEMU uses: docker/setup-qemu-action@v1 diff --git a/.github/workflows/lint-test.yml b/.github/workflows/lint-test.yml index 6ba06a6..1531c6b 100644 --- a/.github/workflows/lint-test.yml +++ b/.github/workflows/lint-test.yml @@ -12,7 +12,7 @@ jobs: - name: Checkout code uses: actions/checkout@v3 - name: Run golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v3.1.0 with: version: latest args: --timeout 5m diff --git a/README.md b/README.md index b6184d9..0363d59 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,7 @@ This will display help for the tool. Here are all the switches it supports. | `-max-file-size` | Max Upload File Size (default 50 MB) | `simplehttpserver -max-file-size 100` | | `-sandbox` | Enable sandbox mode | `simplehttpserver -sandbox` | | `-https` | Enable HTTPS in case of http server | `simplehttpserver -https` | +| `-http1` | Enable only HTTP1 | `simplehttpserver -http1` | | `-cert` | HTTPS/TLS certificate (self generated if not specified) | `simplehttpserver -cert cert.pem` | | `-key` | HTTPS/TLS certificate private key | `simplehttpserver -key cert.key` | | `-domain` | Domain name to use for the self-generated certificate | `simplehttpserver -domain projectdiscovery.io` | @@ -128,7 +129,9 @@ simplehttpserver -rule rules.yaml -tcp -tls -domain localhost The rules are written as follows: ```yaml rules: - - match: regex + - match: regex-match + match-contains: literal-match + name: rule-name response: response data ``` @@ -137,6 +140,7 @@ For example to handle two different paths simulating an HTTP server or SMTP comm rules: # HTTP Requests - match: GET /path1 + name: redirect response: | HTTP/1.0 200 OK Server: httpd/2.0 @@ -149,6 +153,7 @@ rules: - match: GET /path2 + name: "404" response: | HTTP/1.0 404 OK Server: httpd/2.0 @@ -156,6 +161,7 @@ rules: Not found # SMTP Commands - match: "EHLO example.com" + name: smtp response: | 250-localhost Nice to meet you, [127.0.0.1] 250-PIPELINING @@ -167,6 +173,14 @@ rules: response: 250 Accepted - match: "RCPT TO: " response: 250 Accepted + + - match-contains: !!binary | + MAwCAQFgBwIBAwQAgAA= + name: "ldap" + # Request: 300c 0201 0160 0702 0103 0400 8000 0....`........ + # Response: 300c 0201 0161 070a 0100 0400 0400 0....a........ + response: !!binary | + MAwCAQFhBwoBAAQABAA= ``` ## Note diff --git a/go.mod b/go.mod index 64bffb3..2f48683 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/projectdiscovery/simplehttpserver go 1.17 require ( + github.com/fsnotify/fsnotify v1.5.1 github.com/phayes/freeport v0.0.0-20180830031419-95f893ade6f2 github.com/projectdiscovery/gologger v1.1.4 github.com/projectdiscovery/sslcert v0.0.0-20210416140253-8f56bec1bb5e @@ -14,4 +15,5 @@ require ( github.com/logrusorgru/aurora v2.0.3+incompatible // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.1 // indirect + golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c // indirect ) diff --git a/go.sum b/go.sum index aafe4a2..e239b8e 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.5.1 h1:mZcQUHVQUQWoPXXtuf9yuEXKudkV2sx1E06UadKWpgI= +github.com/fsnotify/fsnotify v1.5.1/go.mod h1:T3375wBYaZdLLcVNkcVbzGHY7f1l/uK5T5Ai1i3InKU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/json-iterator/go v1.1.10 h1:Kz6Cvnvv2wGdaG/V8yMvfkmNiXq9Ya2KUv4rouJJr68= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -31,6 +33,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/runner/banner.go b/internal/runner/banner.go index 9093db8..24febe0 100644 --- a/internal/runner/banner.go +++ b/internal/runner/banner.go @@ -8,11 +8,11 @@ const banner = ` \__ \/ / __ -__ \/ __ \/ / _ \/ /_/ / / / / / / /_/ / ___/ _ \/ ___/ | / / _ \/ ___/ ___/ / / / / / / / /_/ / / __/ __ / / / / / / ____(__ ) __/ / | |/ / __/ / /____/_/_/ /_/ /_/ .___/_/\___/_/ /_/ /_/ /_/ /_/ /____/\___/_/ |___/\___/_/ - /_/ - v0.0.4 + /_/ - v0.0.5 ` // Version is the current version -const Version = `0.0.4` +const Version = `0.0.5` // showBanner is used to show the banner to the user func showBanner() { diff --git a/internal/runner/options.go b/internal/runner/options.go index a5869d6..5086d8a 100644 --- a/internal/runner/options.go +++ b/internal/runner/options.go @@ -31,6 +31,8 @@ type Options struct { Silent bool Sandbox bool MaxFileSize int + HTTP1Only bool + MaxDumpBodySize int } // ParseOptions parses the command line options for application @@ -56,8 +58,9 @@ func ParseOptions() *Options { flag.BoolVar(&options.Version, "version", false, "Show version of the software") flag.BoolVar(&options.Silent, "silent", false, "Show only results in the output") flag.BoolVar(&options.Sandbox, "sandbox", false, "Enable sandbox mode") + flag.BoolVar(&options.HTTP1Only, "http1", false, "Enable only HTTP1") flag.IntVar(&options.MaxFileSize, "max-file-size", 50, "Max Upload File Size") - + flag.IntVar(&options.MaxDumpBodySize, "max-dump-body-size", -1, "Max Dump Body Size") flag.Parse() // Read the inputs and configure the logging diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 5806044..59c28e3 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -5,6 +5,7 @@ import ( "github.com/projectdiscovery/simplehttpserver/pkg/binder" "github.com/projectdiscovery/simplehttpserver/pkg/httpserver" "github.com/projectdiscovery/simplehttpserver/pkg/tcpserver" + "github.com/projectdiscovery/simplehttpserver/pkg/unit" ) // Runner is a client for running the enumeration process. @@ -41,6 +42,12 @@ func New(options *Options) (*Runner, error) { if err != nil { return nil, err } + watcher, err := watchFile(r.options.RulesFile, serverTCP.LoadTemplate) + if err != nil { + return nil, err + } + defer watcher.Close() + r.serverTCP = serverTCP return &r, nil } @@ -59,6 +66,8 @@ func New(options *Options) (*Runner, error) { Verbose: r.options.Verbose, Sandbox: r.options.Sandbox, MaxFileSize: r.options.MaxFileSize, + HTTP1Only: r.options.HTTP1Only, + MaxDumpBodySize: unit.ToMb(r.options.MaxDumpBodySize), }) if err != nil { return nil, err @@ -71,6 +80,10 @@ func New(options *Options) (*Runner, error) { // Run logic func (r *Runner) Run() error { if r.options.EnableTCP { + if r.options.TCPWithTLS { + gologger.Print().Msgf("Serving TCP rule based tls server on tcp://%s", r.options.ListenAddress) + return r.serverTCP.ListenAndServeTLS() + } gologger.Print().Msgf("Serving TCP rule based server on tcp://%s", r.options.ListenAddress) return r.serverTCP.ListenAndServe() } diff --git a/internal/runner/watchdog.go b/internal/runner/watchdog.go new file mode 100644 index 0000000..2cdde4c --- /dev/null +++ b/internal/runner/watchdog.go @@ -0,0 +1,36 @@ +package runner + +import ( + "log" + + "github.com/fsnotify/fsnotify" +) + +type WatchEvent func(fname string) error + +func watchFile(fname string, callback WatchEvent) (watcher *fsnotify.Watcher, err error) { + watcher, err = fsnotify.NewWatcher() + if err != nil { + return + } + go func() { + for { + select { + case event, ok := <-watcher.Events: + if !ok { + continue + } + if event.Op&fsnotify.Write == fsnotify.Write { + if err := callback(fname); err != nil { + log.Println("err", err) + } + } + case <-watcher.Errors: + // ignore errors for now + } + } + }() + + err = watcher.Add(fname) + return +} diff --git a/pkg/httpserver/authlayer.go b/pkg/httpserver/authlayer.go index f2eff4b..297d863 100644 --- a/pkg/httpserver/authlayer.go +++ b/pkg/httpserver/authlayer.go @@ -5,7 +5,7 @@ import ( "net/http" ) -func (t *HTTPServer) basicauthlayer(handler http.Handler) http.HandlerFunc { +func (t *HTTPServer) basicauthlayer(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user, pass, ok := r.BasicAuth() if !ok || user != t.options.BasicAuthUsername || pass != t.options.BasicAuthPassword { diff --git a/pkg/httpserver/httpserver.go b/pkg/httpserver/httpserver.go index 72da466..94c050e 100644 --- a/pkg/httpserver/httpserver.go +++ b/pkg/httpserver/httpserver.go @@ -1,6 +1,7 @@ package httpserver import ( + "crypto/tls" "errors" "net/http" "os" @@ -23,7 +24,9 @@ type Options struct { BasicAuthReal string Verbose bool Sandbox bool + HTTP1Only bool MaxFileSize int // 50Mb + MaxDumpBodySize int64 } // HTTPServer instance @@ -32,6 +35,9 @@ type HTTPServer struct { layers http.Handler } +// LayerHandler is the interface of all layer funcs +type Middleware func(http.Handler) http.Handler + // New http server instance with options func New(options *Options) (*HTTPServer, error) { var h HTTPServer @@ -50,18 +56,44 @@ func New(options *Options) (*HTTPServer, error) { if options.Sandbox { dir = SandboxFileSystem{fs: http.Dir(options.Folder), RootFolder: options.Folder} } - h.layers = h.loglayer(http.FileServer(dir)) + + httpHandler := http.FileServer(dir) + addHandler := func(newHandler Middleware) { + httpHandler = newHandler(httpHandler) + } + + // middleware + if options.EnableUpload { + addHandler(h.uploadlayer) + } + if options.BasicAuthUsername != "" || options.BasicAuthPassword != "" { - h.layers = h.loglayer(h.basicauthlayer(http.FileServer(dir))) + addHandler(h.basicauthlayer) } + + httpHandler = h.loglayer(httpHandler) + + // add handler + h.layers = httpHandler h.options = options return &h, nil } +func (t *HTTPServer) makeHTTPServer(tlsConfig *tls.Config) *http.Server { + httpServer := &http.Server{Addr: t.options.ListenAddress} + if t.options.HTTP1Only { + httpServer.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) + } + httpServer.TLSConfig = tlsConfig + httpServer.Handler = t.layers + return httpServer +} + // ListenAndServe requests over http func (t *HTTPServer) ListenAndServe() error { - return http.ListenAndServe(t.options.ListenAddress, t.layers) + httpServer := t.makeHTTPServer(nil) + return httpServer.ListenAndServe() } // ListenAndServeTLS requests over https @@ -73,11 +105,7 @@ func (t *HTTPServer) ListenAndServeTLS() error { if err != nil { return err } - httpServer := &http.Server{ - Addr: t.options.ListenAddress, - TLSConfig: tlsConfig, - } - httpServer.Handler = t.layers + httpServer := t.makeHTTPServer(tlsConfig) return httpServer.ListenAndServeTLS("", "") } return http.ListenAndServeTLS(t.options.ListenAddress, t.options.Certificate, t.options.CertificateKey, t.layers) diff --git a/pkg/httpserver/loglayer.go b/pkg/httpserver/loglayer.go index 0e1a87a..f3fb4f7 100644 --- a/pkg/httpserver/loglayer.go +++ b/pkg/httpserver/loglayer.go @@ -2,12 +2,9 @@ package httpserver import ( "bytes" - "io/ioutil" "net/http" "net/http/httputil" - "path" - "path/filepath" - + "time" "github.com/projectdiscovery/gologger" ) @@ -17,88 +14,49 @@ var ( EnableVerbose bool ) +func (t *HTTPServer) shouldDumpBody(bodysize int64) bool { + return t.options.MaxDumpBodySize > 0 && bodysize > t.options.MaxDumpBodySize +} + func (t *HTTPServer) loglayer(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fullRequest, _ := httputil.DumpRequest(r, true) - lrw := newLoggingResponseWriter(w) - handler.ServeHTTP(lrw, r) - - // Handles file write if enabled - if EnableUpload && r.Method == http.MethodPut { - // sandbox - calcolate absolute path - if t.options.Sandbox { - absPath, err := filepath.Abs(filepath.Join(t.options.Folder, r.URL.Path)) - if err != nil { - gologger.Print().Msgf("%s\n", err) - w.WriteHeader(http.StatusBadRequest) - return - } - // check if the path is within the configured folder - pattern := t.options.Folder + string(filepath.Separator) + "*" - matched, err := filepath.Match(pattern, absPath) - if err != nil { - gologger.Print().Msgf("%s\n", err) - w.WriteHeader(http.StatusBadRequest) - return - } else if !matched { - gologger.Print().Msg("pointing to unauthorized directory") - w.WriteHeader(http.StatusBadRequest) - return - } - } - - var ( - data []byte - err error - ) - if t.options.Sandbox { - maxFileSize := toMb(t.options.MaxFileSize) - // check header content length - if r.ContentLength > maxFileSize { - gologger.Print().Msg("request too large") - return - } - // body max length - r.Body = http.MaxBytesReader(w, r.Body, maxFileSize) - } - - data, err = ioutil.ReadAll(r.Body) - if err != nil { - gologger.Print().Msgf("%s\n", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - err = handleUpload(t.options.Folder, path.Base(r.URL.Path), data) - if err != nil { - gologger.Print().Msgf("%s\n", err) - w.WriteHeader(http.StatusInternalServerError) - return - } + var fullRequest []byte + if t.shouldDumpBody(r.ContentLength) { + fullRequest, _ = httputil.DumpRequest(r, false) + } else { + fullRequest, _ = httputil.DumpRequest(r, true) } + lrw := newLoggingResponseWriter(w, t.options.MaxDumpBodySize) + handler.ServeHTTP(lrw, r) if EnableVerbose { headers := new(bytes.Buffer) lrw.Header().Write(headers) //nolint - gologger.Print().Msgf("\nRemote Address: %s\n%s\n%s %d %s\n%s\n%s\n", r.RemoteAddr, string(fullRequest), r.Proto, lrw.statusCode, http.StatusText(lrw.statusCode), headers.String(), string(lrw.Data)) + gologger.Print().Msgf("\n[%s]\nRemote Address: %s\n%s\n%s %d %s\n%s\n%s\n", time.Now().Format("2006-01-02 15:04:05"), r.RemoteAddr, string(fullRequest), r.Proto, lrw.statusCode, http.StatusText(lrw.statusCode), headers.String(), string(lrw.Data)) } else { - gologger.Print().Msgf("%s \"%s %s %s\" %d %d", r.RemoteAddr, r.Method, r.URL, r.Proto, lrw.statusCode, len(lrw.Data)) + gologger.Print().Msgf("[%s] %s \"%s %s %s\" %d %d", time.Now().Format("2006-01-02 15:04:05"), r.RemoteAddr, r.Method, r.URL, r.Proto, lrw.statusCode, lrw.Size) } }) } type loggingResponseWriter struct { http.ResponseWriter - statusCode int - Data []byte + statusCode int + Data []byte + Size int + MaxDumpSize int64 } -func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter { - return &loggingResponseWriter{w, http.StatusOK, []byte{}} +func newLoggingResponseWriter(w http.ResponseWriter, maxSize int64) *loggingResponseWriter { + return &loggingResponseWriter{w, http.StatusOK, []byte{}, 0, maxSize} } // Write the data func (lrw *loggingResponseWriter) Write(data []byte) (int, error) { - lrw.Data = append(lrw.Data, data...) + if len(lrw.Data) < int(lrw.MaxDumpSize) { + lrw.Data = append(lrw.Data, data...) + } + lrw.Size += len(data) return lrw.ResponseWriter.Write(data) } diff --git a/pkg/httpserver/uploadlayer.go b/pkg/httpserver/uploadlayer.go index 928ac60..670d75a 100644 --- a/pkg/httpserver/uploadlayer.go +++ b/pkg/httpserver/uploadlayer.go @@ -3,21 +3,97 @@ package httpserver import ( "errors" "io/ioutil" + "net/http" + "os" + "path" "path/filepath" "strings" + + "github.com/projectdiscovery/gologger" + "github.com/projectdiscovery/simplehttpserver/pkg/unit" ) +// uploadlayer handles PUT requests and save the file to disk +func (t *HTTPServer) uploadlayer(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handles file write if enabled + if EnableUpload && r.Method == http.MethodPut { + // sandbox - calcolate absolute path + if t.options.Sandbox { + absPath, err := filepath.Abs(filepath.Join(t.options.Folder, r.URL.Path)) + if err != nil { + gologger.Print().Msgf("%s\n", err) + w.WriteHeader(http.StatusBadRequest) + return + } + // check if the path is within the configured folder + pattern := t.options.Folder + string(filepath.Separator) + "*" + matched, err := filepath.Match(pattern, absPath) + if err != nil { + gologger.Print().Msgf("%s\n", err) + w.WriteHeader(http.StatusBadRequest) + return + } else if !matched { + gologger.Print().Msg("pointing to unauthorized directory") + w.WriteHeader(http.StatusBadRequest) + return + } + } + + var ( + data []byte + err error + ) + if t.options.Sandbox { + maxFileSize := unit.ToMb(t.options.MaxFileSize) + // check header content length + if r.ContentLength > maxFileSize { + gologger.Print().Msg("request too large") + return + } + // body max length + r.Body = http.MaxBytesReader(w, r.Body, maxFileSize) + } + + data, err = ioutil.ReadAll(r.Body) + if err != nil { + gologger.Print().Msgf("%s\n", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + sanitizedPath := filepath.FromSlash(path.Clean("/" + strings.Trim(r.URL.Path, "/"))) + + err = handleUpload(t.options.Folder, sanitizedPath, data) + if err != nil { + gologger.Print().Msgf("%s\n", err) + w.WriteHeader(http.StatusInternalServerError) + return + } else { + w.WriteHeader(http.StatusCreated) + return + } + } + + handler.ServeHTTP(w, r) + }) +} + func handleUpload(base, file string, data []byte) error { // rejects all paths containing a non exhaustive list of invalid characters - This is only a best effort as the tool is meant for development if strings.ContainsAny(file, "\\`\"':") { return errors.New("invalid character") } - // allow upload only in subfolders - rel, err := filepath.Rel(base, file) - if rel == "" || err != nil { - return err + untrustedPath := filepath.Clean(filepath.Join(base, file)) + if !strings.HasPrefix(untrustedPath, filepath.Clean(base)) { + return errors.New("invalid path") + } + trustedPath := untrustedPath + + if _, err := os.Stat(path.Dir(trustedPath)); os.IsNotExist(err) { + return errors.New("invalid path") } - return ioutil.WriteFile(file, data, 0655) + return ioutil.WriteFile(trustedPath, data, 0655) } diff --git a/pkg/httpserver/util.go b/pkg/httpserver/util.go deleted file mode 100644 index 4c69d6f..0000000 --- a/pkg/httpserver/util.go +++ /dev/null @@ -1,5 +0,0 @@ -package httpserver - -func toMb(n int) int64 { - return int64(n) * 1024 * 1024 -} diff --git a/pkg/tcpserver/addr.go b/pkg/tcpserver/addr.go new file mode 100644 index 0000000..b678b30 --- /dev/null +++ b/pkg/tcpserver/addr.go @@ -0,0 +1,9 @@ +package tcpserver + +// ContextType is the key type stored in ctx +type ContextType string + +var ( + // Addr is the contextKey where the net.Addr is stored + Addr ContextType = "addr" +) diff --git a/pkg/tcpserver/responseengine.go b/pkg/tcpserver/responseengine.go index ec15da0..80fb795 100644 --- a/pkg/tcpserver/responseengine.go +++ b/pkg/tcpserver/responseengine.go @@ -6,9 +6,12 @@ import ( // BuildResponse according to rules func (t *TCPServer) BuildResponse(data []byte) ([]byte, error) { + t.mux.RLock() + defer t.mux.RUnlock() + // Process all the rules - for _, rule := range t.options.rules { - if rule.matchRegex.Match(data) { + for _, rule := range t.rules { + if rule.MatchInput(data) { return []byte(rule.Response), nil } } diff --git a/pkg/tcpserver/rule.go b/pkg/tcpserver/rule.go index 903331b..aa9e6e8 100644 --- a/pkg/tcpserver/rule.go +++ b/pkg/tcpserver/rule.go @@ -1,6 +1,9 @@ package tcpserver -import "regexp" +import ( + "regexp" + "strings" +) // RulesConfiguration from yaml type RulesConfiguration struct { @@ -9,13 +12,20 @@ type RulesConfiguration struct { // Rule to apply to various requests type Rule struct { - Match string `yaml:"match,omitempty"` - matchRegex *regexp.Regexp - Response string `yaml:"response,omitempty"` + Name string `yaml:"name,omitempty"` + Match string `yaml:"match,omitempty"` + MatchContains string `yaml:"match-contains,omitempty"` + matchRegex *regexp.Regexp + Response string `yaml:"response,omitempty"` } -// NewRule from model +// NewRule creates a new Rule - default is regex func NewRule(match, response string) (*Rule, error) { + return NewRegexRule(match, response) +} + +// NewRegexRule returns a new regex-match Rule +func NewRegexRule(match, response string) (*Rule, error) { regxp, err := regexp.Compile(match) if err != nil { return nil, err @@ -23,3 +33,33 @@ func NewRule(match, response string) (*Rule, error) { return &Rule{Match: match, matchRegex: regxp, Response: response}, nil } + +// NewLiteralRule returns a new literal-match Rule +func NewLiteralRule(match, response string) (*Rule, error) { + return &Rule{MatchContains: match, Response: response}, nil +} + +// NewRuleFromTemplate "copies" a new Rule +func NewRuleFromTemplate(r Rule) (newRule *Rule, err error) { + newRule = &Rule{ + Name: r.Name, + Response: r.Response, + MatchContains: r.MatchContains, + Match: r.Match, + } + if newRule.Match != "" { + newRule.matchRegex, err = regexp.Compile(newRule.Match) + } + + return +} + +// MatchInput returns if the input was matches with one of the matchers +func (r *Rule) MatchInput(input []byte) bool { + if r.matchRegex != nil && r.matchRegex.Match(input) { + return true + } else if r.MatchContains != "" && strings.Contains(string(input), r.MatchContains) { + return true + } + return false +} diff --git a/pkg/tcpserver/tcpserver.go b/pkg/tcpserver/tcpserver.go index 876fbb4..cbdd407 100644 --- a/pkg/tcpserver/tcpserver.go +++ b/pkg/tcpserver/tcpserver.go @@ -1,9 +1,12 @@ package tcpserver import ( + "context" "crypto/tls" + "errors" "io/ioutil" "net" + "sync" "time" "github.com/projectdiscovery/gologger" @@ -24,20 +27,35 @@ type Options struct { Verbose bool } +// CallBackFunc handles what is send back to the client, based on the incomming question +type CallBackFunc func(ctx context.Context, question []byte) (answer []byte, err error) + // TCPServer instance type TCPServer struct { options *Options listener net.Listener + + // Callbacks to retrieve information about the system + HandleMessageFnc CallBackFunc + + mux sync.RWMutex + rules []Rule } // New tcp server instance with specified options func New(options *Options) (*TCPServer, error) { - return &TCPServer{options: options}, nil + srv := &TCPServer{options: options} + srv.HandleMessageFnc = srv.BuildResponseWithContext + srv.rules = options.rules + return srv, nil } // AddRule to the server func (t *TCPServer) AddRule(rule Rule) error { - t.options.rules = append(t.options.rules, rule) + t.mux.Lock() + defer t.mux.Unlock() + + t.rules = append(t.rules, rule) return nil } @@ -51,23 +69,27 @@ func (t *TCPServer) ListenAndServe() error { return t.run() } -func (t *TCPServer) handleConnection(conn net.Conn) error { +func (t *TCPServer) handleConnection(conn net.Conn, callback CallBackFunc) error { defer conn.Close() //nolint + // Create Context + ctx := context.WithValue(context.Background(), Addr, conn.RemoteAddr()) + buf := make([]byte, 4096) for { if err := conn.SetReadDeadline(time.Now().Add(readTimeout * time.Second)); err != nil { gologger.Info().Msgf("%s\n", err) } - _, err := conn.Read(buf) + n, err := conn.Read(buf) if err != nil { return err } - gologger.Print().Msgf("%s\n", buf) + gologger.Print().Msgf("%s\n", buf[:n]) - resp, err := t.BuildResponse(buf) + resp, err := callback(ctx, buf[:n]) if err != nil { + gologger.Info().Msgf("Closing connection: %s\n", err) return err } @@ -112,7 +134,7 @@ func (t *TCPServer) run() error { if err != nil { return err } - go t.handleConnection(c) //nolint + go t.handleConnection(c, t.HandleMessageFnc) //nolint } } @@ -133,13 +155,54 @@ func (t *TCPServer) LoadTemplate(templatePath string) error { return err } + t.mux.Lock() + defer t.mux.Unlock() + + t.rules = make([]Rule, 0) for _, ruleTemplate := range config.Rules { - rule, err := NewRule(ruleTemplate.Match, ruleTemplate.Response) + rule, err := NewRuleFromTemplate(ruleTemplate) if err != nil { return err } - t.options.rules = append(t.options.rules, *rule) + t.rules = append(t.rules, *rule) } + gologger.Info().Msgf("TCP configuration loaded. Rules: %d\n", len(t.rules)) + return nil } + +// MatchRule returns the rule, which was matched first +func (t *TCPServer) MatchRule(data []byte) (rule Rule, err error) { + t.mux.RLock() + defer t.mux.RUnlock() + + // Process all the rules + for _, rule := range t.rules { + if rule.MatchInput(data) { + return rule, nil + } + } + return Rule{}, errors.New("no matched rule") +} + +// BuildResponseWithContext is a wrapper with context +func (t *TCPServer) BuildResponseWithContext(ctx context.Context, data []byte) ([]byte, error) { + return t.BuildResponse(data) +} + +// BuildResponseWithContext is a wrapper with context +func (t *TCPServer) BuildRuleResponse(ctx context.Context, data []byte) ([]byte, error) { + addr := "unknown" + if netAddr, ok := ctx.Value(Addr).(net.Addr); ok { + addr = netAddr.String() + } + rule, err := t.MatchRule(data) + if err != nil { + return []byte(":) "), err + } + + gologger.Info().Msgf("Incoming TCP request(%s) from: %s\n", rule.Name, addr) + + return []byte(rule.Response), nil +} diff --git a/pkg/unit/unit.go b/pkg/unit/unit.go new file mode 100644 index 0000000..98cdb35 --- /dev/null +++ b/pkg/unit/unit.go @@ -0,0 +1,6 @@ +package unit + +// ToMb converts bytes to megabytes +func ToMb(n int) int64 { + return int64(n) * 1024 * 1024 +}