Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 45 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ type Client struct {
// forcefully killed.
processKilled bool

hostSocketDir string
unixSocketCfg UnixSocketConfig
}

// NegotiatedVersion returns the protocol version negotiated with the server.
Expand Down Expand Up @@ -240,6 +240,28 @@ type ClientConfig struct {
// SkipHostEnv allows plugins to run without inheriting the parent process'
// environment variables.
SkipHostEnv bool

// UnixSocketConfig configures additional options for any Unix sockets
// that are created. Not normally required. Not supported on Windows.
UnixSocketConfig *UnixSocketConfig
}

type UnixSocketConfig struct {
// If set, go-plugin will change the owner of any Unix sockets created to
// this group, and set them as group-writable. Can be a name or gid. The
// client process must be a member of this group or chown will fail.
Group string

// The directory to create Unix sockets in. Internally managed by go-plugin
// and deleted when the plugin is killed.
directory string
}

func unixSocketConfigFromEnv() UnixSocketConfig {
return UnixSocketConfig{
Group: os.Getenv(EnvUnixSocketGroup),
directory: os.Getenv(EnvUnixSocketDir),
}
}

// ReattachConfig is used to configure a client to reattach to an
Expand Down Expand Up @@ -445,7 +467,7 @@ func (c *Client) Kill() {
c.l.Lock()
runner := c.runner
addr := c.address
hostSocketDir := c.hostSocketDir
hostSocketDir := c.unixSocketCfg.directory
c.l.Unlock()

// If there is no runner or ID, there is nothing to kill.
Expand Down Expand Up @@ -629,15 +651,33 @@ func (c *Client) Start() (addr net.Addr, err error) {
}
}

if c.config.UnixSocketConfig != nil {
c.unixSocketCfg.Group = c.config.UnixSocketConfig.Group
}

if c.unixSocketCfg.Group != "" {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", EnvUnixSocketGroup, c.unixSocketCfg.Group))
}

var runner runner.Runner
switch {
case c.config.RunnerFunc != nil:
c.hostSocketDir, err = os.MkdirTemp("", "")
c.unixSocketCfg.directory, err = os.MkdirTemp("", "plugin-dir")
if err != nil {
return nil, err
}
c.logger.Trace("created temporary directory for unix sockets", "dir", c.hostSocketDir)
runner, err = c.config.RunnerFunc(c.logger, cmd, c.hostSocketDir)
// os.MkdirTemp creates folders with 0o700, so if we have a group
// configured we need to make it group-writable.
if c.unixSocketCfg.Group != "" {
err = setGroupWritable(c.unixSocketCfg.directory, c.unixSocketCfg.Group, 0o770)
if err != nil {
return nil, err
}
}
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", EnvUnixSocketDir, c.unixSocketCfg.directory))
c.logger.Trace("created temporary directory for unix sockets", "dir", c.unixSocketCfg.directory)

runner, err = c.config.RunnerFunc(c.logger, cmd, c.unixSocketCfg.directory)
if err != nil {
return nil, err
}
Expand Down
97 changes: 97 additions & 0 deletions client_unix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build !windows
// +build !windows

package plugin

import (
"fmt"
"os"
"os/exec"
"os/user"
"runtime"
"syscall"
"testing"

"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-plugin/internal/cmdrunner"
"github.com/hashicorp/go-plugin/runner"
)

func TestSetGroup(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("go-plugin doesn't support unix sockets on Windows")
}

group, err := user.LookupGroupId(fmt.Sprintf("%d", os.Getgid()))
if err != nil {
t.Fatal(err)
}
for name, tc := range map[string]struct {
group string
}{
"as integer": {fmt.Sprintf("%d", os.Getgid())},
"as name": {group.Name},
} {
t.Run(name, func(t *testing.T) {
process := helperProcess("mock")
c := NewClient(&ClientConfig{
HandshakeConfig: testHandshake,
Plugins: testPluginMap,
UnixSocketConfig: &UnixSocketConfig{
Group: tc.group,
},
RunnerFunc: func(l hclog.Logger, cmd *exec.Cmd, tmpDir string) (runner.Runner, error) {
// Run tests inside the RunnerFunc to ensure we don't race
// with the code that deletes tmpDir when the client fails
// to start properly.

// Test that it creates a directory with the proper owners and permissions.
info, err := os.Lstat(tmpDir)
if err != nil {
t.Fatal(err)
}
if info.Mode()&os.ModePerm != 0o770 {
t.Fatal(info.Mode())
}
stat, ok := info.Sys().(*syscall.Stat_t)
if !ok {
t.Fatal()
}
if stat.Gid != uint32(os.Getgid()) {
t.Fatalf("Expected %d, but got %d", os.Getgid(), stat.Gid)
}

// Check the correct environment variables were set to forward
// Unix socket config onto the plugin.
var foundUnixSocketDir, foundUnixSocketGroup bool
for _, env := range cmd.Env {
if env == fmt.Sprintf("%s=%s", EnvUnixSocketDir, tmpDir) {
foundUnixSocketDir = true
}
if env == fmt.Sprintf("%s=%s", EnvUnixSocketGroup, tc.group) {
foundUnixSocketGroup = true
}
}
if !foundUnixSocketDir {
t.Errorf("Did not find correct %s env in %v", EnvUnixSocketDir, cmd.Env)
}
if !foundUnixSocketGroup {
t.Errorf("Did not find correct %s env in %v", EnvUnixSocketGroup, cmd.Env)
}

process.Env = append(process.Env, cmd.Env...)
return cmdrunner.NewCmdRunner(l, process)
},
})
defer c.Kill()

_, err := c.Start()
if err != nil {
t.Fatalf("err should be nil, got %s", err)
}
})
}
}
8 changes: 4 additions & 4 deletions grpc_broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ type GRPCBroker struct {
doneCh chan struct{}
o sync.Once

socketDir string
unixSocketCfg UnixSocketConfig
addrTranslator runner.AddrTranslator

sync.Mutex
Expand All @@ -279,14 +279,14 @@ type gRPCBrokerPending struct {
doneCh chan struct{}
}

func newGRPCBroker(s streamer, tls *tls.Config, socketDir string, addrTranslator runner.AddrTranslator) *GRPCBroker {
func newGRPCBroker(s streamer, tls *tls.Config, unixSocketCfg UnixSocketConfig, addrTranslator runner.AddrTranslator) *GRPCBroker {
return &GRPCBroker{
streamer: s,
streams: make(map[uint32]*gRPCBrokerPending),
tls: tls,
doneCh: make(chan struct{}),

socketDir: socketDir,
unixSocketCfg: unixSocketCfg,
addrTranslator: addrTranslator,
}
}
Expand All @@ -295,7 +295,7 @@ func newGRPCBroker(s streamer, tls *tls.Config, socketDir string, addrTranslator
//
// This should not be called multiple times with the same ID at one time.
func (b *GRPCBroker) Accept(id uint32) (net.Listener, error) {
listener, err := serverListener(b.socketDir)
listener, err := serverListener(b.unixSocketCfg)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func newGRPCClient(doneCtx context.Context, c *Client) (*GRPCClient, error) {

// Start the broker.
brokerGRPCClient := newGRPCBrokerClient(conn)
broker := newGRPCBroker(brokerGRPCClient, c.config.TLSConfig, c.hostSocketDir, c.runner)
broker := newGRPCBroker(brokerGRPCClient, c.config.TLSConfig, c.unixSocketCfg, c.runner)
go broker.Run()
go brokerGRPCClient.StartStream()

Expand Down
2 changes: 1 addition & 1 deletion grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (s *GRPCServer) Init() error {
// Register the broker service
brokerServer := newGRPCBrokerServer()
plugin.RegisterGRPCBrokerServer(s.server, brokerServer)
s.broker = newGRPCBroker(brokerServer, s.TLS, "", nil)
s.broker = newGRPCBroker(brokerServer, s.TLS, unixSocketConfigFromEnv(), nil)
go s.broker.Run()

// Register the controller
Expand Down
57 changes: 33 additions & 24 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ func Serve(opts *ServeConfig) {
}

// Register a listener so we can accept a connection
listener, err := serverListener(os.Getenv(EnvUnixSocketDir))
listener, err := serverListener(unixSocketConfigFromEnv())
if err != nil {
logger.Error("plugin init error", "error", err)
return
Expand Down Expand Up @@ -496,12 +496,12 @@ func Serve(opts *ServeConfig) {
}
}

func serverListener(dir string) (net.Listener, error) {
func serverListener(unixSocketCfg UnixSocketConfig) (net.Listener, error) {
if runtime.GOOS == "windows" {
return serverListener_tcp()
}

return serverListener_unix(dir)
return serverListener_unix(unixSocketCfg)
}

func serverListener_tcp() (net.Listener, error) {
Expand Down Expand Up @@ -546,8 +546,8 @@ func serverListener_tcp() (net.Listener, error) {
return nil, errors.New("Couldn't bind plugin TCP listener")
}

func serverListener_unix(dir string) (net.Listener, error) {
tf, err := os.CreateTemp(dir, "plugin")
func serverListener_unix(unixSocketCfg UnixSocketConfig) (net.Listener, error) {
tf, err := os.CreateTemp(unixSocketCfg.directory, "plugin")
if err != nil {
return nil, err
}
Expand All @@ -569,25 +569,8 @@ func serverListener_unix(dir string) (net.Listener, error) {

// By default, unix sockets are only writable by the owner. Set up a custom
// group owner and group write permissions if configured.
if groupString := os.Getenv(EnvUnixSocketGroup); groupString != "" {
groupID, err := strconv.Atoi(groupString)
if err != nil {
group, err := user.LookupGroup(groupString)
if err != nil {
return nil, fmt.Errorf("failed to find group ID from %s=%s environment variable: %w", EnvUnixSocketGroup, groupString, err)
}
groupID, err = strconv.Atoi(group.Gid)
if err != nil {
return nil, fmt.Errorf("failed to parse %q group's Gid as an integer: %w", groupString, err)
}
}

err = os.Chown(path, -1, groupID)
if err != nil {
return nil, err
}

err = os.Chmod(path, 0o660)
if unixSocketCfg.Group != "" {
err = setGroupWritable(path, unixSocketCfg.Group, 0o660)
if err != nil {
return nil, err
}
Expand All @@ -601,6 +584,32 @@ func serverListener_unix(dir string) (net.Listener, error) {
}, nil
}

func setGroupWritable(path, groupString string, mode os.FileMode) error {
groupID, err := strconv.Atoi(groupString)
if err != nil {
group, err := user.LookupGroup(groupString)
if err != nil {
return fmt.Errorf("failed to find gid from %q: %w", groupString, err)
}
groupID, err = strconv.Atoi(group.Gid)
if err != nil {
return fmt.Errorf("failed to parse %q group's gid as an integer: %w", groupString, err)
}
}

err = os.Chown(path, -1, groupID)
if err != nil {
return err
}

err = os.Chmod(path, mode)
if err != nil {
return err
}

return nil
}

// rmListener is an implementation of net.Listener that forwards most
// calls to the listener but also removes a file as part of the close. We
// use this to cleanup the unix domain socket on close.
Expand Down
6 changes: 2 additions & 4 deletions server_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@ func TestUnixSocketGroupPermissions(t *testing.T) {
t.Fatal(err)
}
for name, tc := range map[string]struct {
gid string
group string
}{
"as integer": {fmt.Sprintf("%d", os.Getgid())},
"as name": {group.Name},
} {
t.Run(name, func(t *testing.T) {
t.Setenv(EnvUnixSocketGroup, tc.gid)

ln, err := serverListener_unix("")
ln, err := serverListener_unix(UnixSocketConfig{Group: tc.group})
if err != nil {
t.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func TestPluginGRPCConn(t testing.T, ps map[string]Plugin) (*GRPCClient, *GRPCSe
}

brokerGRPCClient := newGRPCBrokerClient(conn)
broker := newGRPCBroker(brokerGRPCClient, nil, "", nil)
broker := newGRPCBroker(brokerGRPCClient, nil, UnixSocketConfig{}, nil)
go broker.Run()
go brokerGRPCClient.StartStream()

Expand Down