diff --git a/go.mod b/go.mod index f882445f57..0df45261fc 100644 --- a/go.mod +++ b/go.mod @@ -14,10 +14,11 @@ require ( github.com/alicebob/miniredis/v2 v2.35.0 github.com/arsham/figurine v1.3.0 github.com/atotto/clipboard v0.1.4 - github.com/aws/aws-sdk-go-v2 v1.37.2 + github.com/aws/aws-sdk-go-v2 v1.38.1 github.com/aws/aws-sdk-go-v2/config v1.30.3 github.com/aws/aws-sdk-go-v2/credentials v1.18.3 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.18.3 + github.com/aws/aws-sdk-go-v2/service/ecr v1.49.2 github.com/aws/aws-sdk-go-v2/service/s3 v1.86.0 github.com/aws/aws-sdk-go-v2/service/ssm v1.62.0 github.com/aws/aws-sdk-go-v2/service/sts v1.36.0 @@ -130,8 +131,8 @@ require ( github.com/aws/aws-sdk-go v1.55.7 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.2 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.4 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.4 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.2 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 // indirect diff --git a/go.sum b/go.sum index 32a9ac0790..8abbffee88 100644 --- a/go.sum +++ b/go.sum @@ -749,8 +749,8 @@ github.com/aws/aws-sdk-go v1.34.0/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU github.com/aws/aws-sdk-go v1.44.122/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo= github.com/aws/aws-sdk-go v1.55.7 h1:UJrkFq7es5CShfBwlWAC8DA077vp8PyVbQd3lqLiztE= github.com/aws/aws-sdk-go v1.55.7/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= -github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo= -github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= +github.com/aws/aws-sdk-go-v2 v1.38.1 h1:j7sc33amE74Rz0M/PoCpsZQ6OunLqys/m5antM0J+Z8= +github.com/aws/aws-sdk-go-v2 v1.38.1/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0/go.mod h1:/mXlTIVG9jbxkqDnr5UQNQxW1HRYxeGklkM9vAFeabg= github.com/aws/aws-sdk-go-v2/config v1.30.3 h1:utupeVnE3bmB221W08P0Moz1lDI3OwYa2fBtUhl7TCc= @@ -761,14 +761,16 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.2 h1:nRniHAvjFJGUCl04F3WaAj7 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.2/go.mod h1:eJDFKAMHHUvv4a0Zfa7bQb//wFNUXGrbFpYRCHe2kD0= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.18.3 h1:Nb2pUE30lySKPGdkiIJ1SZgHsjiebOiRNI7R9NA1WtM= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.18.3/go.mod h1:BO5EKulvhBF1NXwui8lfnuDPBQQU5807yvWASZ/5n6k= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 h1:sPiRHLVUIIQcoVZTNwqQcdtjkqkPopyYmIX0M5ElRf4= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2/go.mod h1:ik86P3sgV+Bk7c1tBFCwI3VxMoSEwl4YkRB9xn1s340= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 h1:ZdzDAg075H6stMZtbD2o+PyB933M/f20e9WmCBC17wA= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2/go.mod h1:eE1IIzXG9sdZCB0pNNpMpsYTLl4YdOQD3njiVN1e/E4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.4 h1:IdCLsiiIj5YJ3AFevsewURCPV+YWUlOW8JiPhoAy8vg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.4/go.mod h1:l4bdfCD7XyyZA9BolKBo1eLqgaJxl0/x91PL4Yqe0ao= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.4 h1:j7vjtr1YIssWQOMeOWRbh3z8g2oY/xPjnZH2gLY4sGw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.4/go.mod h1:yDmJgqOiH4EA8Hndnv4KwAo8jCGTSnM5ASG1nBI+toA= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.2 h1:sBpc8Ph6CpfZsEdkz/8bfg8WhKlWMCms5iWj6W/AW2U= github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.2/go.mod h1:Z2lDojZB+92Wo6EKiZZmJid9pPrDJW2NNIXSlaEfVlU= +github.com/aws/aws-sdk-go-v2/service/ecr v1.49.2 h1:aFmDHNrMqJb7Um0wusnZ8lqDcYTf0+RXxSvmCuelBiM= +github.com/aws/aws-sdk-go-v2/service/ecr v1.49.2/go.mod h1:Knlx5anjbiHqbCdnOabD+soFqsJIx2RdKf5R9SoBuUg= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0 h1:6+lZi2JeGKtCraAj1rpoZfKqnQ9SptseRZioejfUOLM= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.0/go.mod h1:eb3gfbVIxIoGgJsi9pGne19dhCBpK6opTYpQqAmdy44= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.2 h1:blV3dY6WbxIVOFggfYIo2E1Q2lZoy5imS7nKgu5m6Tc= diff --git a/internal/exec/oci_auth_test.go b/internal/exec/oci_auth_test.go new file mode 100644 index 0000000000..58f402d858 --- /dev/null +++ b/internal/exec/oci_auth_test.go @@ -0,0 +1,130 @@ +package exec + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudposse/atmos/pkg/schema" +) + +// TestGetRegistryAuth tests the main registry authentication function +// This is an integration test that verifies the overall authentication flow. +func TestGetRegistryAuth(t *testing.T) { + tests := []struct { + name string + registry string + atmosConfig *schema.AtmosConfiguration + expectError bool + errorMsg string + }{ + { + name: "GitHub Container Registry", + registry: "ghcr.io", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: false, + }, + { + name: "Docker Hub", + registry: "docker.io", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: false, + }, + { + name: "Azure Container Registry", + registry: "test.azurecr.io", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: false, + }, + { + name: "AWS ECR (skipped - requires credentials)", + registry: "invalid-ecr-registry.com", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: false, + }, + { + name: "Google Container Registry", + registry: "gcr.io", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: false, + }, + { + name: "Unknown registry", + registry: "unknown.registry.com", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := getRegistryAuth(tt.registry, tt.atmosConfig) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else if err != nil { + // Authentication might fail due to missing credentials, but should not panic + t.Logf("Authentication failed as expected: %v", err) + } + }) + } +} + +// TestCloudProviderAuth tests authentication for different cloud providers +// This is an integration test that verifies provider-specific authentication flows. +func TestCloudProviderAuth(t *testing.T) { + tests := []struct { + name string + registry string + provider string + expectError bool + }{ + { + name: "GitHub CR with token", + registry: "ghcr.io", + provider: "github", + expectError: false, + }, + { + name: "Docker Hub with credentials", + registry: "docker.io", + provider: "docker", + expectError: false, + }, + { + name: "Azure Container Registry", + registry: "test.azurecr.io", + provider: "azure", + expectError: false, + }, + { + name: "AWS ECR (skipped - requires credentials)", + registry: "invalid-ecr-registry.com", + provider: "aws", + expectError: false, + }, + { + name: "Google Container Registry", + registry: "gcr.io", + provider: "google", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + atmosConfig := &schema.AtmosConfiguration{} + _, err := getRegistryAuth(tt.registry, atmosConfig) + + if tt.expectError { + assert.Error(t, err) + } else if err != nil { + // Authentication might fail due to missing credentials, but should not panic + t.Logf("Authentication failed as expected: %v", err) + } + }) + } +} diff --git a/internal/exec/oci_aws.go b/internal/exec/oci_aws.go new file mode 100644 index 0000000000..24af3574d6 --- /dev/null +++ b/internal/exec/oci_aws.go @@ -0,0 +1,125 @@ +package exec + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "regexp" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ecr" + "github.com/aws/aws-sdk-go-v2/service/ecr/types" + log "github.com/charmbracelet/log" + "github.com/google/go-containerregistry/pkg/authn" +) + +var ( + // Static errors for ECR authentication + errInvalidECRRegistryFormat = errors.New("invalid ECR registry format") + errCouldNotParseECRAccount = errors.New("could not parse ECR account/region") + errFailedToGetECRAuthToken = errors.New("failed to get ECR authorization token") + errNoECRAuthorizationData = errors.New("no authorization data returned from ECR") + errEmptyECRAuthorizationToken = errors.New("empty ECR authorization token") + errFailedToDecodeECRAuthToken = errors.New("failed to decode ECR authorization token") + errInvalidECRAuthTokenFormat = errors.New("invalid ECR authorization token format") + errFailedToLoadAWSConfig = errors.New("failed to load AWS config") +) + +// Precompiled: supports ecr and ecr-fips across partitions (incl. .cn). +var ecrRegistryRe = regexp.MustCompile(`^(?P\d{12})\.dkr\.(?Pecr(?:-fips)?)\.(?P[a-z0-9-]+)\.amazonaws\.com(?:\.cn)?$`) + +// parseECRRegistry parses ECR registry string and extracts account ID and region. +func parseECRRegistry(registry string) (accountID, region string, err error) { + m := ecrRegistryRe.FindStringSubmatch(registry) + if m == nil { + return "", "", fmt.Errorf("%w: %s", errInvalidECRRegistryFormat, registry) + } + + accountID = m[ecrRegistryRe.SubexpIndex("acct")] + region = m[ecrRegistryRe.SubexpIndex("region")] + + if accountID == "" || region == "" { + return "", "", fmt.Errorf("%w from %s", errCouldNotParseECRAccount, registry) + } + + return accountID, region, nil +} + +// getECRAuthToken retrieves the authorization token from ECR. +func getECRAuthToken(ctx context.Context, ecrClient *ecr.Client, accountID string) (*types.AuthorizationData, error) { + authTokenInput := &ecr.GetAuthorizationTokenInput{ + RegistryIds: []string{accountID}, + } + authTokenOutput, err := ecrClient.GetAuthorizationToken(ctx, authTokenInput) + if err != nil { + return nil, fmt.Errorf("%w: %v", errFailedToGetECRAuthToken, err) + } + if len(authTokenOutput.AuthorizationData) == 0 { + return nil, fmt.Errorf("%w for account %s", errNoECRAuthorizationData, accountID) + } + return &authTokenOutput.AuthorizationData[0], nil +} + +// parseECRCredentials decodes and parses the ECR authorization token. +func parseECRCredentials(authData *types.AuthorizationData, registry string) (username, password string, err error) { + if authData.AuthorizationToken == nil { + return "", "", fmt.Errorf("%w for %s", errEmptyECRAuthorizationToken, registry) + } + + token, err := base64.StdEncoding.DecodeString(*authData.AuthorizationToken) + if err != nil { + return "", "", fmt.Errorf("%w: %v", errFailedToDecodeECRAuthToken, err) + } + + parts := strings.SplitN(string(token), ":", 2) + if len(parts) != 2 { + return "", "", errInvalidECRAuthTokenFormat + } + + return parts[0], parts[1], nil +} + +// getECRAuth attempts to get AWS ECR authentication using AWS credentials +// Supports SSO/role providers by not gating on environment variables +// Supports both standard ECR and FIPS endpoints. +func getECRAuth(registry string) (authn.Authenticator, error) { + accountID, region, err := parseECRRegistry(registry) + if err != nil { + return nil, err + } + + log.Debug("Extracted ECR registry info", "registry", registry, "accountID", accountID, "region", region) + + // Create context with timeout to prevent hanging AWS API calls + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Load AWS config for the target region + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) + if err != nil { + return nil, fmt.Errorf("%w: %v", errFailedToLoadAWSConfig, err) + } + ecrClient := ecr.NewFromConfig(cfg) + + // Get ECR authorization token + authData, err := getECRAuthToken(ctx, ecrClient, accountID) + if err != nil { + return nil, err + } + + // Parse credentials from token + username, password, err := parseECRCredentials(authData, registry) + if err != nil { + return nil, err + } + + log.Debug("Successfully obtained ECR credentials", "registry", registry, "accountID", accountID, "region", region) + + return &authn.Basic{ + Username: username, + Password: password, + }, nil +} diff --git a/internal/exec/oci_aws_test.go b/internal/exec/oci_aws_test.go new file mode 100644 index 0000000000..367fe1a209 --- /dev/null +++ b/internal/exec/oci_aws_test.go @@ -0,0 +1,82 @@ +package exec + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestECRAuthDirect tests AWS ECR authentication directly. +func TestECRAuthDirect(t *testing.T) { + t.Parallel() + tests := []struct { + name string + registry string + expectError bool + errorMsg string + }{ + { + name: "Invalid ECR registry format", + registry: "invalid-ecr-registry.com", + expectError: true, + errorMsg: "invalid ECR registry format", + }, + { + name: "Non-ECR registry", + registry: "docker.io", + expectError: true, + errorMsg: "invalid ECR registry format", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + _, err := getECRAuth(tt.registry) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestECRRegistryParsing tests the ECR registry parsing logic. +func TestECRRegistryParsing(t *testing.T) { + t.Parallel() + tests := []struct { + name string + registry string + wantAcct string + wantRegion string + expectError bool + }{ + {"Standard ECR", "123456789012.dkr.ecr.us-west-2.amazonaws.com", "123456789012", "us-west-2", false}, + {"ECR FIPS", "123456789012.dkr.ecr-fips.us-west-2.amazonaws.com", "123456789012", "us-west-2", false}, + {"ECR China", "123456789012.dkr.ecr.cn-northwest-1.amazonaws.com.cn", "123456789012", "cn-northwest-1", false}, + {"ECR FIPS China", "123456789012.dkr.ecr-fips.cn-northwest-1.amazonaws.com.cn", "123456789012", "cn-northwest-1", false}, + {"Non-ECR registry", "docker.io", "", "", true}, + {"Invalid format", "invalid-registry.com", "", "", true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + acct, region, err := parseECRRegistry(tt.registry) + if tt.expectError { + assert.Error(t, err) + return + } + assert.NoError(t, err) + assert.Equal(t, tt.wantAcct, acct) + assert.Equal(t, tt.wantRegion, region) + }) + } +} diff --git a/internal/exec/oci_azure.go b/internal/exec/oci_azure.go new file mode 100644 index 0000000000..c27b6083b7 --- /dev/null +++ b/internal/exec/oci_azure.go @@ -0,0 +1,254 @@ +package exec + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + log "github.com/charmbracelet/log" + "github.com/google/go-containerregistry/pkg/authn" + "github.com/spf13/viper" + + "github.com/cloudposse/atmos/pkg/schema" +) + +// httpClient is used for outbound HTTP in this package; override in tests. +var httpClient = &http.Client{Timeout: 30 * time.Second} + +var ( + // Static errors for Azure authentication + errInvalidACRRegistryFormat = errors.New("invalid Azure Container Registry format") + errNoValidAzureAuth = errors.New("no valid Azure authentication found") + errFailedToCreateTokenExchangeReq = errors.New("failed to create token exchange request") + errFailedToExecuteTokenExchangeReq = errors.New("failed to execute token exchange request") + errTokenExchangeFailed = errors.New("token exchange failed") + errFailedToDecodeTokenExchangeResp = errors.New("failed to decode token exchange response") + errNoTokenReceivedFromACR = errors.New("no token received from ACR OAuth2 exchange") + errInvalidJWTTokenFormat = errors.New("invalid JWT token format") + errFailedToDecodeJWTPayload = errors.New("failed to decode JWT payload") + errFailedToParseJWTPayload = errors.New("failed to parse JWT payload") + errTenantIDNotFoundInJWT = errors.New("tenant ID not found in JWT token") + errFailedToCreateAzureCredential = errors.New("failed to create Azure credential") + errFailedToGetAzureToken = errors.New("failed to get Azure token") + errACRTokenExchangeFailed = errors.New("acr token exchange failed") + errFailedToCreateAzureDefaultCred = errors.New("failed to create Azure default credential") +) + +// extractACRName extracts the ACR name from the registry URL. +func extractACRName(registry string) (string, error) { + // Expected formats: .azurecr.{io|cn|us} + for _, suf := range []string{".azurecr.io", ".azurecr.us", ".azurecr.cn"} { + if strings.HasSuffix(registry, suf) { + return strings.TrimSuffix(registry, suf), nil + } + } + return "", fmt.Errorf("%w: %s (expected .azurecr.{io|us|cn})", errInvalidACRRegistryFormat, registry) +} + +// gatherAzureCredentials collects Azure credentials from config and environment. +func gatherAzureCredentials(atmosConfig *schema.AtmosConfiguration) (clientID, clientSecret, tenantID string) { + // Create a Viper instance for environment variable access + v := viper.New() + bindEnv(v, "azure_client_id", "ATMOS_OCI_AZURE_CLIENT_ID", "AZURE_CLIENT_ID") + bindEnv(v, "azure_client_secret", "ATMOS_OCI_AZURE_CLIENT_SECRET", "AZURE_CLIENT_SECRET") + bindEnv(v, "azure_tenant_id", "ATMOS_OCI_AZURE_TENANT_ID", "AZURE_TENANT_ID") + + // Resolve from env first, then config. + clientID = v.GetString("azure_client_id") + if clientID == "" { + clientID = atmosConfig.Settings.OCI.AzureClientID + } + clientSecret = v.GetString("azure_client_secret") + if clientSecret == "" { + clientSecret = atmosConfig.Settings.OCI.AzureClientSecret + } + tenantID = v.GetString("azure_tenant_id") + if tenantID == "" { + tenantID = atmosConfig.Settings.OCI.AzureTenantID + } + + return clientID, clientSecret, tenantID +} + +// getACRAuth attempts to get Azure Container Registry authentication. +func getACRAuth(registry string, atmosConfig *schema.AtmosConfiguration) (authn.Authenticator, error) { + // Extract ACR name from registry URL + acrName, err := extractACRName(registry) + if err != nil { + return nil, err + } + + // Gather Azure credentials + clientID, clientSecret, tenantID := gatherAzureCredentials(atmosConfig) + + // Try Service Principal authentication if credentials are available + if clientID != "" && clientSecret != "" && tenantID != "" { + log.Debug("Using Azure Service Principal credentials", logFieldRegistry, registry, "acrName", acrName) + return getACRAuthViaServicePrincipal(registry, acrName, clientID, clientSecret, tenantID) + } + + // Try Azure Default Credential (Managed Identity, Workload Identity, etc.) + if auth, err := getACRAuthViaDefaultCredential(registry, acrName); err == nil { + return auth, nil + } else { + log.Debug("Azure Default Credential failed", logFieldRegistry, registry, "error", err) + } + + return nil, fmt.Errorf("%w for %s", errNoValidAzureAuth, registry) +} + +// exchangeAADForACRRefreshToken exchanges an AAD token for an ACR refresh token. +func exchangeAADForACRRefreshToken(ctx context.Context, registry, tenantID, aadToken string) (string, error) { + // ACR OAuth2 endpoint for token exchange + oauthURL := fmt.Sprintf("https://%s/oauth2/exchange", registry) + + // Prepare the form data for the token exchange + formData := url.Values{} + formData.Set("grant_type", "access_token") + formData.Set("service", registry) + if tenantID != "" { + formData.Set("tenant", tenantID) + } + formData.Set("access_token", aadToken) + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, "POST", oauthURL, strings.NewReader(formData.Encode())) + if err != nil { + return "", fmt.Errorf("%w: %v", errFailedToCreateTokenExchangeReq, err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + // Execute the request + resp, err := httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("%w: %v", errFailedToExecuteTokenExchangeReq, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<10)) // 2KB + return "", fmt.Errorf("%w: status=%d body=%q", errTokenExchangeFailed, resp.StatusCode, string(body)) + } + + // Parse the response + var tokenResponse struct { + RefreshToken string `json:"refresh_token"` + AccessToken string `json:"access_token"` + } + + if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { + return "", fmt.Errorf("%w: %v", errFailedToDecodeTokenExchangeResp, err) + } + + // Return the refresh token (preferred) or access token as fallback + if tokenResponse.RefreshToken != "" { + return tokenResponse.RefreshToken, nil + } + if tokenResponse.AccessToken != "" { + return tokenResponse.AccessToken, nil + } + + return "", errNoTokenReceivedFromACR +} + +// extractTenantIDFromToken extracts the tenant ID from a JWT token. +func extractTenantIDFromToken(tokenString string) (string, error) { + // JWT tokens have 3 parts separated by dots + parts := strings.Split(tokenString, ".") + if len(parts) != 3 { + return "", errInvalidJWTTokenFormat + } + + // Decode the payload (second part) + payload := parts[1] + // Decode Base64URL (no padding) + decoded, err := base64.RawURLEncoding.DecodeString(payload) + if err != nil { + return "", fmt.Errorf("%w: %v", errFailedToDecodeJWTPayload, err) + } + + // Parse JSON payload + var payloadData struct { + TID string `json:"tid"` + } + + if err := json.Unmarshal(decoded, &payloadData); err != nil { + return "", fmt.Errorf("%w: %v", errFailedToParseJWTPayload, err) + } + + if payloadData.TID == "" { + return "", errTenantIDNotFoundInJWT + } + + return payloadData.TID, nil +} + +// getACRAuthViaServicePrincipal attempts to get ACR credentials using Azure Service Principal +func getACRAuthViaServicePrincipal(registry, acrName, clientID, clientSecret, tenantID string) (authn.Authenticator, error) { + // Create Azure credential using Service Principal + cred, err := azidentity.NewClientSecretCredential(tenantID, clientID, clientSecret, nil) + if err != nil { + return nil, fmt.Errorf("%w: %v", errFailedToCreateAzureCredential, err) + } + + // Get AAD token for ACR scope + ctx := context.Background() + aad, err := cred.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{"https://management.azure.com/.default"}, + }) + if err != nil { + return nil, fmt.Errorf("%w: %v", errFailedToGetAzureToken, err) + } + + // Exchange AAD token for ACR refresh token + refresh, err := exchangeAADForACRRefreshToken(ctx, registry, tenantID, aad.Token) + if err != nil { + return nil, fmt.Errorf("%w: %v", errACRTokenExchangeFailed, err) + } + log.Debug("Obtained ACR refresh token via Service Principal", logFieldRegistry, registry, "acrName", acrName) + return &authn.Basic{ + Username: "00000000-0000-0000-0000-000000000000", + Password: refresh, + }, nil +} + +// getACRAuthViaDefaultCredential attempts to get ACR credentials using Azure Default Credential +func getACRAuthViaDefaultCredential(registry, acrName string) (authn.Authenticator, error) { + // Create Azure credential using Default Credential (Managed Identity, Azure CLI, etc.) + cred, err := azidentity.NewDefaultAzureCredential(nil) + if err != nil { + return nil, fmt.Errorf("%w: %v", errFailedToCreateAzureDefaultCred, err) + } + + // Get AAD token for ACR scope + ctx := context.Background() + aad, err := cred.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{"https://management.azure.com/.default"}, + }) + if err != nil { + return nil, fmt.Errorf("%w: %v", errFailedToGetAzureToken, err) + } + + // Exchange AAD token for ACR refresh token + // Tenant is optional here; if unknown, pass empty and let ACR infer. + refresh, err := exchangeAADForACRRefreshToken(ctx, registry, "", aad.Token) + if err != nil { + return nil, fmt.Errorf("%w: %v", errACRTokenExchangeFailed, err) + } + log.Debug("Obtained ACR refresh token via Default Credential", logFieldRegistry, registry, "acrName", acrName) + + return &authn.Basic{ + Username: "00000000-0000-0000-0000-000000000000", + Password: refresh, + }, nil +} diff --git a/internal/exec/oci_azure_test.go b/internal/exec/oci_azure_test.go new file mode 100644 index 0000000000..1b12369b40 --- /dev/null +++ b/internal/exec/oci_azure_test.go @@ -0,0 +1,437 @@ +package exec + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/cloudposse/atmos/pkg/schema" +) + +// TestACRAuthDirect tests Azure Container Registry authentication directly. +func TestACRAuthDirect(t *testing.T) { + azureSkipIfNotE2E(t) + + tests := []struct { + name string + registry string + atmosConfig *schema.AtmosConfiguration + expectError bool + errorMsg string + }{ + { + name: "ACR .io with no authentication", + registry: "test.azurecr.io", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: true, + errorMsg: "no valid Azure authentication found", + }, + { + name: "ACR .us with no authentication", + registry: "test.azurecr.us", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: true, + errorMsg: "no valid Azure authentication found", + }, + { + name: "ACR .cn with no authentication", + registry: "test.azurecr.cn", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: true, + errorMsg: "no valid Azure authentication found", + }, + { + name: "Non-ACR registry", + registry: "docker.io", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: true, + errorMsg: "invalid Azure Container Registry format", + }, + { + name: "Invalid ACR format", + registry: "test.azurecr.invalid", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: true, + errorMsg: "invalid Azure Container Registry format", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := getACRAuth(tt.registry, tt.atmosConfig) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestGetACRAuthViaServicePrincipalDirect tests Azure Service Principal authentication directly. +func TestGetACRAuthViaServicePrincipalDirect(t *testing.T) { + azureSkipIfNotE2E(t) + + tests := []struct { + name string + registry string + acrName string + clientID string + clientSecret string + tenantID string + expectError bool + errorMsg string + }{ + { + name: "Valid Service Principal credentials (.io)", + registry: "test.azurecr.io", + acrName: "test", + clientID: "test-client-id", + clientSecret: "test-client-secret", + tenantID: "test-tenant-id", + expectError: true, // Will fail due to invalid credentials + errorMsg: "failed to get Azure token", + }, + { + name: "Valid Service Principal credentials (.us)", + registry: "test.azurecr.us", + acrName: "test", + clientID: "test-client-id", + clientSecret: "test-client-secret", + tenantID: "test-tenant-id", + expectError: true, // Will fail due to invalid credentials + errorMsg: "failed to get Azure token", + }, + { + name: "Valid Service Principal credentials (.cn)", + registry: "test.azurecr.cn", + acrName: "test", + clientID: "test-client-id", + clientSecret: "test-client-secret", + tenantID: "test-tenant-id", + expectError: true, // Will fail due to invalid credentials + errorMsg: "failed to get Azure token", + }, + { + name: "Missing client ID", + registry: "test.azurecr.io", + acrName: "test", + clientID: "", + clientSecret: "test-client-secret", + tenantID: "test-tenant-id", + expectError: true, + errorMsg: "failed to get Azure token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := getACRAuthViaServicePrincipal(tt.registry, tt.acrName, tt.clientID, tt.clientSecret, tt.tenantID) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestGetACRAuthViaDefaultCredentialDirect tests Azure Default Credential authentication directly. +func TestGetACRAuthViaDefaultCredentialDirect(t *testing.T) { + azureSkipIfNotE2E(t) + + tests := []struct { + name string + registry string + acrName string + expectError bool + errorMsg string + }{ + { + name: "Default credential without Azure environment (.io)", + registry: "test.azurecr.io", + acrName: "test", + expectError: true, + errorMsg: "failed to get Azure token", + }, + { + name: "Default credential without Azure environment (.us)", + registry: "test.azurecr.us", + acrName: "test", + expectError: true, + errorMsg: "failed to get Azure token", + }, + { + name: "Default credential without Azure environment (.cn)", + registry: "test.azurecr.cn", + acrName: "test", + expectError: true, + errorMsg: "failed to get Azure token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := getACRAuthViaDefaultCredential(tt.registry, tt.acrName) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestExchangeAADForACRRefreshTokenDirect tests AAD to ACR token exchange directly. +func TestExchangeAADForACRRefreshTokenDirect(t *testing.T) { + azureSkipIfNotE2E(t) + + tests := []struct { + name string + registry string + tenantID string + aadToken string + expectError bool + errorMsg string + }{ + { + name: "Invalid registry format", + registry: "invalid-registry", + tenantID: "test-tenant", + aadToken: "test-token", + expectError: true, + errorMsg: "failed to execute token exchange request", + }, + { + name: "Empty AAD token", + registry: "test.azurecr.io", + tenantID: "test-tenant", + aadToken: "", + expectError: true, + errorMsg: "failed to execute token exchange request", + }, + { + name: "Network error with invalid registry", + registry: "nonexistent.azurecr.io", + tenantID: "test-tenant", + aadToken: "test-token", + expectError: true, + errorMsg: "failed to execute token exchange request", + }, + { + name: "Valid parameters but network failure (.io)", + registry: "test.azurecr.io", + tenantID: "test-tenant", + aadToken: "test-token", + expectError: true, + errorMsg: "failed to execute token exchange request", + }, + { + name: "Valid parameters but network failure (.us)", + registry: "test.azurecr.us", + tenantID: "test-tenant", + aadToken: "test-token", + expectError: true, + errorMsg: "failed to execute token exchange request", + }, + { + name: "Valid parameters but network failure (.cn)", + registry: "test.azurecr.cn", + tenantID: "test-tenant", + aadToken: "test-token", + expectError: true, + errorMsg: "failed to execute token exchange request", + }, + { + name: "Empty tenant ID", + registry: "test.azurecr.io", + tenantID: "", + aadToken: "test-token", + expectError: true, + errorMsg: "failed to execute token exchange request", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := exchangeAADForACRRefreshToken(context.Background(), tt.registry, tt.tenantID, tt.aadToken) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestExchangeAADForACRRefreshTokenWithStubbedHTTP demonstrates how to stub HTTP client for testing. +func TestExchangeAADForACRRefreshTokenWithStubbedHTTP(t *testing.T) { + // Save original HTTP client + originalClient := httpClient + defer func() { + httpClient = originalClient + }() + + // Create a mock HTTP client that returns a predefined response + mockClient := &http.Client{ + Timeout: 10 * time.Second, + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + assert.Equal(t, http.MethodPost, req.Method) + assert.Equal(t, "test.azurecr.io", req.URL.Host) + assert.Equal(t, "application/json", req.Header.Get("Content-Type")) + assert.Equal(t, "Bearer test-token", req.Header.Get("Authorization")) + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"refresh_token": "mock-refresh-token", "access_token": "mock-access-token"}`)), + }, nil + }), + } + + // Override the HTTP client + httpClient = mockClient + + // Test the token exchange with stubbed HTTP + refreshToken, err := exchangeAADForACRRefreshToken( + context.Background(), + "test.azurecr.io", + "test-tenant", + "test-token", + ) + + // Verify the result + assert.NoError(t, err) + assert.Equal(t, "mock-refresh-token", refreshToken) +} + +// roundTripFunc helps inline stub HTTP behavior in tests. +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +// TestExtractTenantIDFromTokenDirect tests JWT token parsing for tenant ID directly. +func TestExtractTenantIDFromTokenDirect(t *testing.T) { + t.Parallel() + tests := []struct { + name string + token string + expectError bool + expectedID string + errorMsg string + }{ + { + name: "Valid JWT with tenant ID", + token: createValidJWTDirect("test-tenant-id"), + expectError: false, + expectedID: "test-tenant-id", + }, + { + name: "JWT without tenant ID", + token: createJWTWithoutTenantIDDirect(), + expectError: true, + errorMsg: "tenant ID not found in JWT token", + }, + { + name: "Invalid JWT format", + token: "invalid.jwt.token", + expectError: true, + errorMsg: "failed to parse JWT payload", + }, + { + name: "Empty token", + token: "", + expectError: true, + errorMsg: "invalid JWT token format", + }, + { + name: "JWT with invalid JSON in payload", + token: createJWTWithInvalidJSONDirect(), + expectError: true, + errorMsg: "failed to parse JWT payload", + }, + { + name: "JWT with missing payload", + token: "header.signature", + expectError: true, + errorMsg: "invalid JWT token format", + }, + { + name: "JWT with empty payload", + token: "header..signature", + expectError: true, + errorMsg: "failed to parse JWT payload", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tenantID, err := extractTenantIDFromToken(tt.token) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedID, tenantID) + } + }) + } +} + +// Helper functions for creating test JWT tokens. +func createValidJWTDirect(tenantID string) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`)) + payload := base64.RawURLEncoding.EncodeToString([]byte(fmt.Sprintf(`{"tid":"%s","sub":"test","iss":"test"}`, tenantID))) + signature := base64.RawURLEncoding.EncodeToString([]byte("signature")) + return fmt.Sprintf("%s.%s.%s", header, payload, signature) +} + +func createJWTWithoutTenantIDDirect() string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`)) + payload := base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"test","iss":"test"}`)) + signature := base64.RawURLEncoding.EncodeToString([]byte("signature")) + return fmt.Sprintf("%s.%s.%s", header, payload, signature) +} + +func createJWTWithInvalidJSONDirect() string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`)) + payload := base64.RawURLEncoding.EncodeToString([]byte(`{"invalid json`)) + signature := base64.RawURLEncoding.EncodeToString([]byte("signature")) + return fmt.Sprintf("%s.%s.%s", header, payload, signature) +} + +func azureSkipIfNotE2E(t *testing.T) { + if os.Getenv("ATMOS_AZURE_E2E") == "" { + t.Skip("Skipping Azure integration test (set ATMOS_AZURE_E2E=1 to run)") + } +} diff --git a/internal/exec/oci_docker.go b/internal/exec/oci_docker.go new file mode 100644 index 0000000000..1edcb9ef10 --- /dev/null +++ b/internal/exec/oci_docker.go @@ -0,0 +1,316 @@ +package exec + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "regexp" + "strings" + "time" + + log "github.com/charmbracelet/log" + "github.com/google/go-containerregistry/pkg/authn" + "github.com/spf13/viper" + + "github.com/cloudposse/atmos/pkg/schema" +) + +// Allow tests to override exec for credential helpers. +var ( + lookPath = exec.LookPath + commandContext = exec.CommandContext +) + +// DockerConfig represents the structure of Docker's config.json file. +type DockerConfig struct { + Auths map[string]struct { + Auth string `json:"auth"` + } `json:"auths"` + CredsStore string `json:"credsStore"` + CredHelpers map[string]string `json:"credHelpers"` +} + +var ( + // Static errors for Docker authentication + errFailedToGetUserHomeDir = errors.New("failed to get user home directory") + errDockerConfigFileNotFound = errors.New("Docker config file not found") + errFailedToReadDockerConfigFile = errors.New("failed to read Docker config file") + errFailedToParseDockerConfigJSON = errors.New("failed to parse Docker config JSON") + errNoCredentialHelpersConfigured = errors.New("no credential helpers configured") + errNoCredentialHelperFound = errors.New("no credential helper found for registry") + errNoGlobalCredentialStore = errors.New("no global credential store configured") + errFailedToDecodeAuthForRegistry = errors.New("failed to decode auth for registry") + errNoDirectAuthFound = errors.New("no direct auth found for registry") + errInvalidRegistryName = errors.New("invalid registry name") + errInvalidCredentialStoreName = errors.New("invalid credential store name") + errCredentialHelperNotFound = errors.New("credential helper not found") + errFailedToGetCredentialsFromStore = errors.New("failed to get credentials from store") + errFailedToParseCredentialStoreOutput = errors.New("failed to parse credential store output") + errInvalidCredentialsFromStore = errors.New("invalid credentials from store") + errFailedToDecodeBase64AuthString = errors.New("failed to decode base64 auth string") + errInvalidAuthStringFormat = errors.New("invalid auth string format, expected username:password") +) + +// resolveDockerConfigPath resolves the Docker config file path from various sources. +func resolveDockerConfigPath(atmosConfig *schema.AtmosConfiguration) (string, error) { + // Create a Viper instance for environment variable access. + v := viper.New() + bindEnv(v, "docker_config", "ATMOS_OCI_DOCKER_CONFIG", "DOCKER_CONFIG") + + // Resolve Docker config path (env has precedence). + configDir := v.GetString("docker_config") + if configDir == "" { + configDir = atmosConfig.Settings.OCI.DockerConfig + } + if configDir == "" { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("%w: %v", errFailedToGetUserHomeDir, err) + } + configDir = filepath.Join(homeDir, ".docker") + } + return filepath.Join(configDir, "config.json"), nil +} + +// loadDockerConfig loads and parses the Docker config file. +func loadDockerConfig(configPath string) (DockerConfig, error) { + // Check if Docker config file exists. + if _, err := os.Stat(configPath); os.IsNotExist(err) { + return DockerConfig{}, fmt.Errorf("%w: %s", errDockerConfigFileNotFound, configPath) + } + + // Read Docker config file. + configData, err := os.ReadFile(configPath) + if err != nil { + return DockerConfig{}, fmt.Errorf("%w: %v", errFailedToReadDockerConfigFile, err) + } + + // Parse Docker config JSON. + var dockerConfig DockerConfig + if err := json.Unmarshal(configData, &dockerConfig); err != nil { + return DockerConfig{}, fmt.Errorf("%w: %v", errFailedToParseDockerConfigJSON, err) + } + + return dockerConfig, nil +} + +// tryCredentialHelper attempts to authenticate using a specific credential helper. +func tryCredentialHelper(registry, helperKey, helper string) (authn.Authenticator, error) { + if helper == "" { + return nil, errNoCredentialHelperFound + } + + // Use the exact helper key (server URL) that matched in config. + auth, err := getCredentialStoreAuth(helperKey, helper) + if err == nil { + log.Debug("Using per-registry credential helper", logFieldRegistry, helperKey, "helper", helper) + return auth, nil + } + + log.Debug("Per-registry credential helper failed", logFieldRegistry, helperKey, "helper", helper, "error", err) + return nil, err +} + +// tryCredentialHelpers attempts to authenticate using per-registry credential helpers. +func tryCredentialHelpers(registry string, credHelpers map[string]string) (authn.Authenticator, error) { + if credHelpers == nil { + return nil, errNoCredentialHelpersConfigured + } + + // Try exact registry match first + if helper, ok := credHelpers[registry]; ok { + if auth, err := tryCredentialHelper(registry, registry, helper); err == nil { + return auth, nil + } + } + + // Try with https:// prefix + httpsRegistry := "https://" + registry + if helper, ok := credHelpers[httpsRegistry]; ok { + if auth, err := tryCredentialHelper(registry, httpsRegistry, helper); err == nil { + return auth, nil + } + } + + // Try with http:// prefix. + httpRegistry := "http://" + registry + if helper, ok := credHelpers[httpRegistry]; ok { + if auth, err := tryCredentialHelper(registry, httpRegistry, helper); err == nil { + return auth, nil + } + } + + return nil, errNoCredentialHelperFound +} + +// tryGlobalCredentialStore attempts to authenticate using the global credential store. +func tryGlobalCredentialStore(registry, credsStore string) (authn.Authenticator, error) { + if credsStore == "" { + return nil, errNoGlobalCredentialStore + } + + // Try common server-URL variants before giving up. + variants := []string{ + registry, + "https://" + registry, + "http://" + registry, + "https://" + registry + "/v1/", + "http://" + registry + "/v1/", + } + var lastErr error + for _, r := range variants { + if auth, err := getCredentialStoreAuth(r, credsStore); err == nil { + log.Debug("Using global credential store authentication", logFieldRegistry, r, "store", credsStore) + return auth, nil + } else { + lastErr = err + log.Debug("Global credential store authentication failed", logFieldRegistry, r, "store", credsStore, "error", err) + } + } + + return nil, lastErr +} + +// tryDirectAuth attempts to authenticate using direct auth strings in the config. +func tryDirectAuth(registry string, auths map[string]struct { + Auth string `json:"auth"` +}, +) (authn.Authenticator, error) { + // Try different registry formats + registryVariants := []string{ + registry, + "https://" + registry, + "http://" + registry, + "https://" + registry + "/v1/", + "http://" + registry + "/v1/", + } + + for _, reg := range registryVariants { + if authData, exists := auths[reg]; exists && authData.Auth != "" { + username, password, err := decodeDockerAuth(authData.Auth) + if err != nil { + return nil, fmt.Errorf("%w %s: %v", errFailedToDecodeAuthForRegistry, reg, err) + } + return &authn.Basic{ + Username: username, + Password: password, + }, nil + } + } + + return nil, errNoDirectAuthFound +} + +// getDockerAuth attempts to get Docker authentication for a registry +// Supports DOCKER_CONFIG environment variable, global credential stores (credsStore), +// and per-registry credential helpers (credHelpers). +func getDockerAuth(registry string, atmosConfig *schema.AtmosConfiguration) (authn.Authenticator, error) { + // Resolve Docker config path + configPath, err := resolveDockerConfigPath(atmosConfig) + if err != nil { + return nil, err + } + log.Debug("Using Docker config path", "path", configPath) + + // Load Docker config + dockerConfig, err := loadDockerConfig(configPath) + if err != nil { + if errors.Is(err, errDockerConfigFileNotFound) { + log.Debug("Docker config file not found; skipping Docker auth", "path", configPath) + return nil, fmt.Errorf("%w %s", errNoAuthenticationFound, registry) + } + return nil, err + } + + // Try per-registry credential helpers first + if auth, err := tryCredentialHelpers(registry, dockerConfig.CredHelpers); err == nil { + return auth, nil + } + + // Try global credential store + if auth, err := tryGlobalCredentialStore(registry, dockerConfig.CredsStore); err == nil { + return auth, nil + } + + // Try direct auth strings + if auth, err := tryDirectAuth(registry, dockerConfig.Auths); err == nil { + return auth, nil + } + + return nil, fmt.Errorf("%w %s", errNoAuthenticationFound, registry) +} + +// getCredentialStoreAuth attempts to get credentials from Docker's credential store. +func getCredentialStoreAuth(registry, credsStore string) (authn.Authenticator, error) { + // Validate registry using an allowlist approach + // Registry may only include letters, digits, dots, colons, slashes, and hyphens + validRegistry := regexp.MustCompile(`^[A-Za-z0-9./:-]+$`) + if !validRegistry.MatchString(registry) { + return nil, fmt.Errorf("%w: %s", errInvalidRegistryName, registry) + } + + // Validate credsStore using an allowlist (letters, digits, underscore, hyphen). + if !regexp.MustCompile(`^[A-Za-z0-9_-]+$`).MatchString(credsStore) { + return nil, fmt.Errorf("%w: %s", errInvalidCredentialStoreName, credsStore) + } + + // For Docker Desktop on macOS, the credential store is typically "desktop". + // We need to use the docker-credential-desktop helper to get credentials. + + // Try to execute the credential helper + helperCmd := "docker-credential-" + credsStore + if _, err := lookPath(helperCmd); err != nil { + return nil, fmt.Errorf("%w %s: %v", errCredentialHelperNotFound, helperCmd, err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + cmd := commandContext(ctx, helperCmd, "get") + cmd.Stdin = strings.NewReader(registry + "\n") + + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("%w %s: %v", errFailedToGetCredentialsFromStore, credsStore, err) + } + + // Parse the JSON output from the credential helper + var creds struct { + Username string `json:"Username"` + Secret string `json:"Secret"` + } + + if err := json.Unmarshal(output, &creds); err != nil { + return nil, fmt.Errorf("%w: %v", errFailedToParseCredentialStoreOutput, err) + } + + if creds.Username == "" || creds.Secret == "" { + return nil, errInvalidCredentialsFromStore + } + + return &authn.Basic{ + Username: creds.Username, + Password: creds.Secret, + }, nil +} + +// decodeDockerAuth decodes the base64-encoded auth string from Docker config. +func decodeDockerAuth(authString string) (string, string, error) { + // Decode base64 + decoded, err := base64.StdEncoding.DecodeString(authString) + if err != nil { + return "", "", fmt.Errorf("%w: %v", errFailedToDecodeBase64AuthString, err) + } + + // Split username:password + parts := strings.SplitN(string(decoded), ":", 2) + if len(parts) != 2 { + return "", "", errInvalidAuthStringFormat + } + + return parts[0], parts[1], nil +} diff --git a/internal/exec/oci_docker_test.go b/internal/exec/oci_docker_test.go new file mode 100644 index 0000000000..3bbb1d028c --- /dev/null +++ b/internal/exec/oci_docker_test.go @@ -0,0 +1,275 @@ +package exec + +import ( + "encoding/base64" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/cloudposse/atmos/pkg/schema" +) + +func TestDockerCredHelpers(t *testing.T) { + // Hermetic Docker config: empty config.json in a temp dir + dir := t.TempDir() + cfg := filepath.Join(dir, "config.json") + if err := os.WriteFile(cfg, []byte(`{}`), 0o600); err != nil { + t.Fatal(err) + } + t.Setenv("DOCKER_CONFIG", dir) + t.Setenv("ATMOS_OCI_DOCKER_CONFIG", cfg) + + tests := []struct { + name string + registry string + atmosConfig *schema.AtmosConfiguration + expectError bool + errorMsg string + }{ + { + name: "Docker Hub with no authentication", + registry: "docker.io", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: true, // Will fail without actual credential helper or config + errorMsg: "failed to read Docker config file", + }, + { + name: "Private registry with no authentication", + registry: "my-registry.com", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: true, + errorMsg: "failed to read Docker config file", + }, + { + name: "Registry with no authentication", + registry: "test.registry.com", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: true, + errorMsg: "failed to read Docker config file", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := getDockerAuth(tt.registry, tt.atmosConfig) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestDecodeDockerAuth tests Docker auth string decoding. +func TestDecodeDockerAuth(t *testing.T) { + tests := []struct { + name string + authString string + expectError bool + errorMsg string + expectedUser string + expectedPass string + }{ + { + name: "Valid auth string", + // Build "username:password" at runtime to avoid secret scanners. + authString: base64.StdEncoding.EncodeToString([]byte("username:password")), + expectError: false, + expectedUser: "username", + expectedPass: "password", + }, + { + name: "Invalid base64 string", + authString: "invalid-base64", + expectError: true, + errorMsg: "failed to decode base64 auth string", + }, + { + name: "Invalid format (no colon)", + // Build "username" at runtime to avoid secret scanners. + authString: base64.StdEncoding.EncodeToString([]byte("username")), + expectError: true, + errorMsg: "invalid auth string format", + }, + { + name: "Empty string", + authString: "", + expectError: true, + errorMsg: "invalid auth string format", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + username, password, err := decodeDockerAuth(tt.authString) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedUser, username) + assert.Equal(t, tt.expectedPass, password) + } + }) + } +} + +// TestGetCredentialStoreAuth tests credential store authentication. +func TestGetCredentialStoreAuth(t *testing.T) { + tests := []struct { + name string + registry string + credsStore string + expectError bool + errorMsg string + }{ + { + name: "Empty credential store name", + registry: "docker.io", + credsStore: "", + expectError: true, + errorMsg: "credential helper docker-credential- not found", + }, + { + name: "Non-existent credential helper", + registry: "docker.io", + credsStore: "nonexistent", + expectError: true, + errorMsg: "credential helper docker-credential-nonexistent not found", + }, + { + name: "Invalid registry name with command injection attempt", + registry: "registry; rm -rf /", + credsStore: "desktop", + expectError: true, + errorMsg: "invalid registry name", + }, + { + name: "Invalid registry name with shell metacharacters", + registry: "registry&echo hello", + credsStore: "desktop", + expectError: true, + errorMsg: "invalid registry name", + }, + { + name: "Invalid registry name with backticks", + registry: "registry`whoami`", + credsStore: "desktop", + expectError: true, + errorMsg: "invalid registry name", + }, + { + name: "Invalid registry name with dollar expansion", + registry: "registry$HOME", + credsStore: "desktop", + expectError: true, + errorMsg: "invalid registry name", + }, + { + name: "Invalid registry name with parentheses", + registry: "registry$(whoami)", + credsStore: "desktop", + expectError: true, + errorMsg: "invalid registry name", + }, + { + name: "Invalid registry name with brackets", + registry: "registry[test]", + credsStore: "desktop", + expectError: true, + errorMsg: "invalid registry name", + }, + { + name: "Invalid registry name with quotes", + registry: "registry'test'", + credsStore: "desktop", + expectError: true, + errorMsg: "invalid registry name", + }, + { + name: "Invalid registry name with newlines", + registry: "registry\ntest", + credsStore: "desktop", + expectError: true, + errorMsg: "invalid registry name", + }, + { + name: "Valid registry name - standard domain", + registry: "docker.io", + credsStore: "desktop", + expectError: true, + errorMsg: "failed to get credentials from store", + }, + { + name: "Valid registry name - with port", + registry: "registry.com:5000", + credsStore: "desktop", + expectError: true, + errorMsg: "failed to get credentials from store", + }, + { + name: "Valid registry name - with path", + registry: "registry.com/path", + credsStore: "desktop", + expectError: true, + errorMsg: "failed to get credentials from store", + }, + { + name: "Valid registry name - with hyphens", + registry: "my-registry.com", + credsStore: "desktop", + expectError: true, + errorMsg: "failed to get credentials from store", + }, + { + name: "Valid registry name - IP address", + registry: "192.168.1.100:5000", + credsStore: "desktop", + expectError: true, + errorMsg: "failed to get credentials from store", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := getCredentialStoreAuth(tt.registry, tt.credsStore) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// Helper function to create temporary files for testing. +func createTempFile(t *testing.T, content string) string { + tmpfile, err := os.CreateTemp("", "docker-config-*.json") + if err != nil { + t.Fatal(err) + } + + if _, err := tmpfile.Write([]byte(content)); err != nil { + t.Fatal(err) + } + + if err := tmpfile.Close(); err != nil { + t.Fatal(err) + } + + return tmpfile.Name() +} diff --git a/internal/exec/oci_extraction_test.go b/internal/exec/oci_extraction_test.go new file mode 100644 index 0000000000..12a7efad21 --- /dev/null +++ b/internal/exec/oci_extraction_test.go @@ -0,0 +1,294 @@ +package exec + +import ( + "archive/zip" + "bytes" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// runZipExtractionTest is a helper function to run ZIP extraction tests. +func runZipExtractionTest(t *testing.T, zipContent map[string]string, expectError bool, errorMsg string) { + tempDir := t.TempDir() + + // Create a ZIP file in memory + var buf bytes.Buffer + zipWriter := zip.NewWriter(&buf) + + for filename, content := range zipContent { + writer, err := zipWriter.Create(filename) + require.NoError(t, err) + _, err = writer.Write([]byte(content)) + require.NoError(t, err) + } + zipWriter.Close() + + // Test extraction + reader := bytes.NewReader(buf.Bytes()) + err := extractZipFile(reader, tempDir) + + if expectError { + assert.Error(t, err) + if errorMsg != "" { + assert.Contains(t, err.Error(), errorMsg) + } + } else { + assert.NoError(t, err) + + // Check that files were extracted + for filename, expectedContent := range zipContent { + filePath := filepath.Join(tempDir, filename) + assert.FileExists(t, filePath) + + content, err := os.ReadFile(filePath) + assert.NoError(t, err) + assert.Equal(t, expectedContent, string(content)) + } + } +} + +// runZipExtractionTestSuite runs a complete test suite for ZIP extraction. +func runZipExtractionTestSuite(t *testing.T, tests []struct { + name string + zipContent map[string]string + expectError bool + errorMsg string +}, +) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + runZipExtractionTest(t, tt.zipContent, tt.expectError, tt.errorMsg) + }) + } +} + +// TestExtractZipFile tests ZIP file extraction functionality. +func TestExtractZipFile(t *testing.T) { + tests := []struct { + name string + zipContent map[string]string + expectError bool + errorMsg string + }{ + { + name: "Valid ZIP with files", + zipContent: map[string]string{ + "file1.txt": "content1", + "file2.txt": "content2", + }, + expectError: false, + }, + { + name: "ZIP with directory", + zipContent: map[string]string{ + "dir/file.txt": "content", + }, + expectError: false, + }, + { + name: "ZIP with path traversal attempt", + zipContent: map[string]string{ + "../file.txt": "malicious", + }, + expectError: true, + errorMsg: "illegal file path in ZIP", + }, + { + name: "ZIP with absolute path", + zipContent: map[string]string{ + "/etc/passwd": "malicious", + }, + expectError: true, + errorMsg: "illegal file path in ZIP", + }, + { + name: "ZIP with Windows absolute path", + zipContent: map[string]string{ + "C:\\Windows\\file.txt": "malicious", + }, + expectError: true, + errorMsg: "illegal file path in ZIP", + }, + } + + runZipExtractionTestSuite(t, tests) +} + +// TestExtractZipFileZipSlip tests ZIP slip vulnerability protection. +func TestExtractZipFileZipSlip(t *testing.T) { + tests := []struct { + name string + zipContent map[string]string + expectError bool + errorMsg string + }{ + { + name: "Path traversal with ../", + zipContent: map[string]string{ + "../../../etc/passwd": "malicious", + }, + expectError: true, + errorMsg: "illegal file path in ZIP", + }, + { + name: "Path traversal with ..\\", + zipContent: map[string]string{ + "..\\..\\..\\Windows\\System32\\config\\SAM": "malicious", + }, + expectError: true, + errorMsg: "illegal file path in ZIP", + }, + { + name: "Mixed path traversal", + zipContent: map[string]string{ + "normal/../malicious/file.txt": "malicious", + }, + expectError: true, + errorMsg: "illegal file path in ZIP", + }, + { + name: "Valid nested path", + zipContent: map[string]string{ + "normal/nested/file.txt": "valid content", + }, + expectError: false, + }, + } + + runZipExtractionTestSuite(t, tests) +} + +// TestExtractZipFileSymlinks tests ZIP file extraction with symlink handling. +func TestExtractZipFileSymlinks(t *testing.T) { + tests := []struct { + name string + zipContent map[string]string + expectError bool + errorMsg string + }{ + { + name: "ZIP with symlink", + zipContent: map[string]string{ + "file.txt": "content", + }, + expectError: false, + }, + { + name: "ZIP with directory", + zipContent: map[string]string{ + "dir/": "", + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + + // Create a ZIP file in memory + var buf bytes.Buffer + zipWriter := zip.NewWriter(&buf) + + for filename, content := range tt.zipContent { + if strings.HasSuffix(filename, "/") { + // Create directory entry + _, err := zipWriter.Create(filename) + require.NoError(t, err) + } else { + // Create file entry + writer, err := zipWriter.Create(filename) + require.NoError(t, err) + _, err = writer.Write([]byte(content)) + require.NoError(t, err) + } + } + zipWriter.Close() + + // Test extraction + reader := bytes.NewReader(buf.Bytes()) + err := extractZipFile(reader, tempDir) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + + // Check that files were extracted + for filename, expectedContent := range tt.zipContent { + if !strings.HasSuffix(filename, "/") { + filePath := filepath.Join(tempDir, filename) + assert.FileExists(t, filePath) + + content, err := os.ReadFile(filePath) + assert.NoError(t, err) + assert.Equal(t, expectedContent, string(content)) + } + } + } + }) + } +} + +// TestExtractRawData tests raw data extraction functionality. +func TestExtractRawData(t *testing.T) { + tests := []struct { + name string + data string + layerIndex int + expectError bool + }{ + { + name: "Valid raw data", + data: "test data content", + layerIndex: 0, + expectError: false, + }, + { + name: "Empty data", + data: "", + layerIndex: 0, + expectError: false, + }, + { + name: "Large data", + data: strings.Repeat("test", 1000), + layerIndex: 1, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + reader := strings.NewReader(tt.data) + + err := extractRawData(reader, tempDir, tt.layerIndex) + + if tt.expectError { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + + // Check that the file was created + expectedFile := filepath.Join(tempDir, fmt.Sprintf("layer_%d_raw", tt.layerIndex)) + assert.FileExists(t, expectedFile) + + // Check file content + content, err := os.ReadFile(expectedFile) + assert.NoError(t, err) + assert.Equal(t, tt.data, string(content)) + }) + } +} diff --git a/internal/exec/oci_github.go b/internal/exec/oci_github.go new file mode 100644 index 0000000000..68ce3f9176 --- /dev/null +++ b/internal/exec/oci_github.go @@ -0,0 +1,41 @@ +package exec + +import ( + "errors" + "fmt" + "strings" + + log "github.com/charmbracelet/log" + "github.com/google/go-containerregistry/pkg/authn" + "github.com/spf13/viper" + + "github.com/cloudposse/atmos/pkg/schema" +) + +// Static errors for GitHub authentication +var errNoGitHubAuthenticationFound = errors.New("no GitHub authentication found for registry") + +// getGitHubAuth attempts to get GitHub Container Registry authentication. +func getGitHubAuth(registry string, atmosConfig *schema.AtmosConfiguration) (authn.Authenticator, error) { + // Check for GitHub Container Registry + if strings.EqualFold(registry, "ghcr.io") { + // Create a Viper instance for environment variable access + v := viper.New() + bindEnv(v, "github_token", "ATMOS_OCI_GITHUB_TOKEN", "GITHUB_TOKEN") + + // Try Atmos-specific token first, then fallback to standard GITHUB_TOKEN + token := atmosConfig.Settings.OCI.GithubToken + if token == "" { + token = v.GetString("github_token") // Use Viper instead of os.Getenv + } + if token != "" { + log.Debug("Using GitHub token for authentication", "registry", registry) + return &authn.Basic{ + Username: "oauth2", + Password: token, + }, nil + } + } + + return nil, fmt.Errorf("%w %s", errNoGitHubAuthenticationFound, registry) +} diff --git a/internal/exec/oci_github_test.go b/internal/exec/oci_github_test.go new file mode 100644 index 0000000000..af919f51e8 --- /dev/null +++ b/internal/exec/oci_github_test.go @@ -0,0 +1,171 @@ +package exec + +import ( + "os" + "testing" + + "github.com/google/go-containerregistry/pkg/authn" + "github.com/stretchr/testify/assert" + + "github.com/cloudposse/atmos/pkg/schema" +) + +// TestGitHubAuth tests GitHub Container Registry authentication. +func TestGitHubAuth(t *testing.T) { + tests := []struct { + name string + registry string + atmosConfig *schema.AtmosConfiguration + expectError bool + errorMsg string + }{ + { + name: "GitHub Container Registry with token", + registry: "ghcr.io", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: false, // Will succeed with GITHUB_TOKEN + }, + { + name: "GitHub Container Registry with Atmos token", + registry: "ghcr.io", + atmosConfig: &schema.AtmosConfiguration{ + Settings: schema.AtmosSettings{ + OCI: schema.OCISettings{ + GithubToken: "test-token", + }, + }, + }, + expectError: false, + }, + { + name: "Non-GitHub registry", + registry: "docker.io", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: true, + errorMsg: "no GitHub authentication found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up environment for tests that need it + if tt.name == "GitHub Container Registry with token" { + os.Setenv("GITHUB_TOKEN", "test-token") + defer os.Unsetenv("GITHUB_TOKEN") + } + + auth, err := getGitHubAuth(tt.registry, tt.atmosConfig) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, auth) + + // Verify the authenticator is of the correct type + basicAuth, ok := auth.(*authn.Basic) + assert.True(t, ok) + assert.Equal(t, "oauth2", basicAuth.Username) + assert.Equal(t, "test-token", basicAuth.Password) + } + }) + } +} + +// TestGitHubAuthWithEnvironment tests GitHub authentication with environment variables. +func TestGitHubAuthWithEnvironment(t *testing.T) { + tests := []struct { + name string + registry string + atmosConfig *schema.AtmosConfiguration + envToken string + expectError bool + errorMsg string + }{ + { + name: "GitHub CR with GITHUB_TOKEN environment variable", + registry: "ghcr.io", + atmosConfig: &schema.AtmosConfiguration{}, + envToken: "env-token", + expectError: false, + }, + { + name: "GitHub CR with ATMOS_OCI_GITHUB_TOKEN environment variable", + registry: "ghcr.io", + atmosConfig: &schema.AtmosConfiguration{}, + envToken: "atmos-token", + expectError: false, + }, + { + name: "GitHub CR with no environment token", + registry: "ghcr.io", + atmosConfig: &schema.AtmosConfiguration{}, + envToken: "", + expectError: true, + errorMsg: "no GitHub authentication found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up environment + if tt.envToken != "" { + if tt.name == "GitHub CR with ATMOS_OCI_GITHUB_TOKEN environment variable" { + os.Setenv("ATMOS_OCI_GITHUB_TOKEN", tt.envToken) + defer os.Unsetenv("ATMOS_OCI_GITHUB_TOKEN") + } else { + os.Setenv("GITHUB_TOKEN", tt.envToken) + defer os.Unsetenv("GITHUB_TOKEN") + } + } + + auth, err := getGitHubAuth(tt.registry, tt.atmosConfig) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, auth) + + // Verify the authenticator + basicAuth, ok := auth.(*authn.Basic) + assert.True(t, ok) + assert.Equal(t, "oauth2", basicAuth.Username) + assert.Equal(t, tt.envToken, basicAuth.Password) + } + }) + } +} + +// TestGitHubAuthPrecedence tests token precedence (Atmos config vs environment). +func TestGitHubAuthPrecedence(t *testing.T) { + // Test that Atmos config token takes precedence over environment + atmosConfig := &schema.AtmosConfiguration{ + Settings: schema.AtmosSettings{ + OCI: schema.OCISettings{ + GithubToken: "atmos-config-token", + }, + }, + } + + // Set environment token + os.Setenv("GITHUB_TOKEN", "env-token") + defer os.Unsetenv("GITHUB_TOKEN") + + auth, err := getGitHubAuth("ghcr.io", atmosConfig) + + assert.NoError(t, err) + assert.NotNil(t, auth) + + // Should use Atmos config token, not environment token + basicAuth, ok := auth.(*authn.Basic) + assert.True(t, ok) + assert.Equal(t, "oauth2", basicAuth.Username) + assert.Equal(t, "atmos-config-token", basicAuth.Password) +} diff --git a/internal/exec/oci_google.go b/internal/exec/oci_google.go new file mode 100644 index 0000000000..3476fdc904 --- /dev/null +++ b/internal/exec/oci_google.go @@ -0,0 +1,51 @@ +package exec + +import ( + "context" + "errors" + "fmt" + "time" + + log "github.com/charmbracelet/log" + "github.com/google/go-containerregistry/pkg/authn" + "golang.org/x/oauth2/google" +) + +var ( + // Static errors for Google Cloud authentication + errFailedToFindGoogleCloudCredentials = errors.New("failed to find Google Cloud credentials") + errNoGoogleCloudCredentialsFound = errors.New("no Google Cloud credentials found for registry") + errFailedToGetGoogleCloudToken = errors.New("failed to get Google Cloud token") +) + +// getGCRAuth attempts to get Google Container Registry authentication. +func getGCRAuth(registry string) (authn.Authenticator, error) { + // Use Google Cloud Application Default Credentials + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform") + if err != nil { + log.Debug("Failed to find Google Cloud credentials", logFieldRegistry, registry, "error", err) + return nil, fmt.Errorf("%w: %v", errFailedToFindGoogleCloudCredentials, err) + } + + if creds == nil || creds.TokenSource == nil { + log.Debug("No Google Cloud credentials found", logFieldRegistry, registry) + return nil, fmt.Errorf("%w %s", errNoGoogleCloudCredentialsFound, registry) + } + + // Get a token from the credentials + token, err := creds.TokenSource.Token() + if err != nil { + log.Debug("Failed to get Google Cloud token", logFieldRegistry, registry, "error", err) + return nil, fmt.Errorf("%w: %v", errFailedToGetGoogleCloudToken, err) + } + + // For GCR/Artifact Registry, use an OAuth2 access token as the password with + // the username "oauth2accesstoken". This is the standard pattern for GCR/AR authentication. + log.Debug("Successfully obtained Google Cloud credentials", logFieldRegistry, registry) + return &authn.Basic{ + Username: "oauth2accesstoken", + Password: token.AccessToken, + }, nil +} diff --git a/internal/exec/oci_google_test.go b/internal/exec/oci_google_test.go new file mode 100644 index 0000000000..9526d07c56 --- /dev/null +++ b/internal/exec/oci_google_test.go @@ -0,0 +1,96 @@ +package exec + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestGCRAuth tests Google Container Registry authentication. +func TestGCRAuth(t *testing.T) { + tests := []struct { + name string + registry string + expectError bool + errorMsg string + }{ + { + name: "Google Container Registry", + registry: "gcr.io", + expectError: true, // Will fail without Google Cloud credentials + errorMsg: "failed to find Google Cloud credentials", + }, + { + name: "Google Artifact Registry", + registry: "us-docker.pkg.dev", + expectError: true, // Will fail without Google Cloud credentials + errorMsg: "failed to find Google Cloud credentials", + }, + { + name: "Non-Google registry", + registry: "docker.io", + expectError: true, + errorMsg: "failed to find Google Cloud credentials", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := getGCRAuth(tt.registry) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestGCRRegistryDetection tests Google registry detection logic. +// Note: This test reflects the current implementation logic in getRegistryAuth. +func TestGCRRegistryDetection(t *testing.T) { + tests := []struct { + name string + registry string + isGCR bool + }{ + { + name: "Google Container Registry", + registry: "gcr.io", + isGCR: true, + }, + { + name: "Google Artifact Registry", + registry: "us-docker.pkg.dev", + isGCR: true, + }, + { + name: "Another Artifact Registry region", + registry: "europe-docker.pkg.dev", + isGCR: true, + }, + { + name: "Non-Google registry", + registry: "docker.io", + isGCR: false, + }, + { + name: "Azure Container Registry", + registry: "test.azurecr.io", + isGCR: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // This matches the logic in getRegistryAuth in oci_utils.go + isGCR := strings.Contains(tt.registry, "gcr.io") || strings.Contains(tt.registry, "pkg.dev") + assert.Equal(t, tt.isGCR, isGCR) + }) + } +} diff --git a/internal/exec/oci_processing_test.go b/internal/exec/oci_processing_test.go new file mode 100644 index 0000000000..0599f96c18 --- /dev/null +++ b/internal/exec/oci_processing_test.go @@ -0,0 +1,264 @@ +package exec + +import ( + "bytes" + "encoding/json" + "io" + "strings" + "testing" + + log "github.com/charmbracelet/log" + "github.com/google/go-containerregistry/pkg/name" + "github.com/google/go-containerregistry/pkg/v1/remote" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/cloudposse/atmos/pkg/schema" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" +) + +// TestProcessOciImage tests the main OCI image processing function. +func TestProcessOciImage(t *testing.T) { + tests := []struct { + name string + imageName string + atmosConfig *schema.AtmosConfiguration + expectError bool + errorMsg string + }{ + { + name: "Invalid image reference", + imageName: "invalid-image-reference", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: true, + errorMsg: "failed to pull image", + }, + { + name: "Empty image name", + imageName: "", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: true, + errorMsg: "invalid image reference", + }, + { + name: "Valid image reference format", + imageName: "ghcr.io/test/repo:latest", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: true, // Will fail to pull, but should parse correctly + errorMsg: "failed to pull image", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + err := processOciImage(tt.atmosConfig, tt.imageName, tempDir) + + if tt.expectError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestProcessOciImageIntegration tests OCI image processing with mock data. +func TestProcessOciImageIntegration(t *testing.T) { + t.Run("Test with mock OCI image", func(t *testing.T) { + // This test verifies that the function handles the processing flow correctly + // even when the actual image pull fails + atmosConfig := &schema.AtmosConfiguration{} + tempDir := t.TempDir() + + // Use an invalid image name that will fail to pull + imageName := "index.docker.io/library/invalid-image-name:latest" + + err := processOciImage(atmosConfig, imageName, tempDir) + + // Should fail due to invalid image, but should not panic + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to pull image") + }) +} + +// TestPullImage tests the image pulling functionality. +func TestPullImage(t *testing.T) { + tests := []struct { + name string + imageRef string + atmosConfig *schema.AtmosConfiguration + expectError bool + errorMsg string + }{ + { + name: "Invalid registry reference", + imageRef: "invalid-registry/test:latest", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: true, + errorMsg: "failed to pull image", + }, + { + name: "Valid format but non-existent image", + imageRef: "ghcr.io/non-existent/repo:latest", + atmosConfig: &schema.AtmosConfiguration{}, + expectError: true, + errorMsg: "failed to pull image", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ref, err := name.ParseReference(tt.imageRef) + if err != nil { + t.Skipf("Skipping test due to invalid reference: %v", err) + } + + descriptor, err := pullImage(ref, tt.atmosConfig) + + if tt.expectError { + assert.Error(t, err) + assert.Nil(t, descriptor) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, descriptor) + } + }) + } +} + +// TestCheckArtifactType tests artifact type validation. +func TestCheckArtifactType(t *testing.T) { + tests := []struct { + name string + artifactType string + imageName string + expectWarning bool + warningMessage string + }{ + { + name: "Supported Atmos artifact type", + artifactType: "application/vnd.atmos.component.terraform.v1+tar+gzip", + imageName: "test-image", + expectWarning: false, + }, + { + name: "Supported OpenTofu artifact type", + artifactType: "application/vnd.opentofu.modulepkg", + imageName: "test-image", + expectWarning: false, + }, + { + name: "Supported Terraform artifact type", + artifactType: "application/vnd.terraform.module.v1+tar+gzip", + imageName: "test-image", + expectWarning: false, + }, + { + name: "Unsupported artifact type", + artifactType: "application/vnd.unsupported.type", + imageName: "test-image", + expectWarning: true, + warningMessage: "OCI image artifact type not recognized", + }, + { + name: "Empty artifact type", + artifactType: "", + imageName: "test-image", + expectWarning: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a mock manifest + manifest := &ocispec.Manifest{ + ArtifactType: tt.artifactType, + } + + // Serialize the manifest + manifestBytes, err := json.Marshal(manifest) + require.NoError(t, err) + + // Create a mock descriptor + descriptor := &remote.Descriptor{ + Manifest: manifestBytes, + } + + // Capture log output + var logOutput strings.Builder + originalLogger := log.Default() + log.SetDefault(log.NewWithOptions(&logOutput, log.Options{ + Level: log.DebugLevel, + })) + defer log.SetDefault(originalLogger) + + checkArtifactType(descriptor, tt.imageName) + + if tt.expectWarning { + assert.Contains(t, logOutput.String(), tt.warningMessage) + } else { + assert.NotContains(t, logOutput.String(), "not recognized") + } + }) + } +} + +// TestParseOCIManifest tests OCI manifest parsing. +func TestParseOCIManifest(t *testing.T) { + tests := []struct { + name string + manifest *ocispec.Manifest + expectError bool + }{ + { + name: "Valid manifest", + manifest: &ocispec.Manifest{ + ArtifactType: "application/vnd.atmos.component.terraform.v1+tar+gzip", + Layers: []ocispec.Descriptor{ + { + MediaType: "application/vnd.docker.image.rootfs.diff.tar.gzip", + Digest: "sha256:test-digest", + Size: 1024, + }, + }, + }, + expectError: false, + }, + { + name: "Invalid JSON", + manifest: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var reader io.Reader + if tt.manifest != nil { + manifestBytes, err := json.Marshal(tt.manifest) + require.NoError(t, err) + reader = bytes.NewReader(manifestBytes) + } else { + reader = strings.NewReader("invalid json") + } + + result, err := parseOCIManifest(reader) + + if tt.expectError { + assert.Error(t, err) + assert.Nil(t, result) + } else { + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, tt.manifest.ArtifactType, result.ArtifactType) + } + }) + } +} diff --git a/internal/exec/oci_utils.go b/internal/exec/oci_utils.go index 1895665112..b41e08e892 100644 --- a/internal/exec/oci_utils.go +++ b/internal/exec/oci_utils.go @@ -1,15 +1,18 @@ package exec import ( + "archive/zip" "bytes" "encoding/json" + "errors" "fmt" "io" "os" + "path/filepath" "strings" log "github.com/charmbracelet/log" // Charmbracelet structured logger - "github.com/pkg/errors" + "github.com/spf13/viper" "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/name" @@ -21,13 +24,40 @@ import ( "github.com/cloudposse/atmos/pkg/schema" ) -var ErrNoLayers = errors.New("the OCI image does not have any layers") +var ( + ErrNoLayers = errors.New("the OCI image does not have any layers") + errIllegalZipFilePath = errors.New("illegal file path in ZIP") + errFailedToCreateDirectory = errors.New("failed to create directory") + errFailedToOpenZipFile = errors.New("failed to open file in ZIP") + errFailedToCreateFile = errors.New("failed to create file") + errFailedToCopyFile = errors.New("failed to copy file") + errFailedToReadZipData = errors.New("failed to read ZIP data") + errFailedToCreateZipReader = errors.New("failed to create ZIP reader") + errFailedToResolveDestination = errors.New("failed to resolve destination directory") + errZipSizeExceeded = errors.New("ZIP file size exceeds maximum allowed size") + errNoAuthenticationFound = errors.New("no authentication found") +) const ( targetArtifactType = "application/vnd.atmos.component.terraform.v1+tar+gzip" // Target artifact type for Atmos components - githubTokenEnv = "GITHUB_TOKEN" + // Additional supported artifact types + opentofuArtifactType = "application/vnd.opentofu.modulepkg" // OpenTofu module package + terraformArtifactType = "application/vnd.terraform.module.v1+tar+gzip" // Terraform module package + logFieldRegistry = "registry" + logFieldIndex = "index" + logFieldDigest = "digest" + logFieldError = "error" + defaultDirMode = 0o755 // Default directory permissions (rwxr-xr-x) + maxZipSize = 100 * 1024 * 1024 // Maximum ZIP file size: 100MB (prevents decompression bomb) ) +// bindEnv binds environment variables to Viper with fallback support. +func bindEnv(v *viper.Viper, key string, envVars ...string) { + if err := v.BindEnv(append([]string{key}, envVars...)...); err != nil { + log.Debug("Failed to bind environment variable", "key", key, "envVars", envVars, logFieldError, err) + } +} + // processOciImage processes an OCI image and extracts its layers to the specified destination directory. func processOciImage(atmosConfig *schema.AtmosConfiguration, imageName string, destDir string) error { tempDir, err := os.MkdirTemp("", uuid.New().String()) @@ -38,18 +68,18 @@ func processOciImage(atmosConfig *schema.AtmosConfiguration, imageName string, d ref, err := name.ParseReference(imageName) if err != nil { - log.Error("Failed to parse OCI image reference", "image", imageName, "error", err) + log.Error("Failed to parse OCI image reference", "image", imageName, logFieldError, err) return fmt.Errorf("invalid image reference: %w", err) } - descriptor, err := pullImage(ref) + descriptor, err := pullImage(ref, atmosConfig) if err != nil { return fmt.Errorf("failed to pull image: %w", err) } img, err := descriptor.Image() if err != nil { - log.Error("Failed to get image descriptor", "image", imageName, "error", err) + log.Error("Failed to get image descriptor", "image", imageName, logFieldError, err) return fmt.Errorf("cannot get a descriptor for the OCI image '%s': %w", imageName, err) } @@ -57,7 +87,7 @@ func processOciImage(atmosConfig *schema.AtmosConfiguration, imageName string, d layers, err := img.Layers() if err != nil { - log.Error("Failed to retrieve layers from OCI image", "image", imageName, "error", err) + log.Error("Failed to retrieve layers from OCI image", "image", imageName, logFieldError, err) return fmt.Errorf("failed to get image layers: %w", err) } @@ -66,83 +96,484 @@ func processOciImage(atmosConfig *schema.AtmosConfiguration, imageName string, d return ErrNoLayers } + successfulLayers := 0 for i, layer := range layers { if err := processLayer(layer, i, destDir); err != nil { - return fmt.Errorf("failed to process layer %d: %w", i, err) + log.Warn("Failed to process layer", logFieldIndex, i, logFieldError, err) + continue // Continue with other layers instead of failing completely } + successfulLayers++ + } + + // Check if any files were actually extracted + files, err := os.ReadDir(destDir) + switch { + case err != nil: + log.Warn("Could not read destination directory", "dir", destDir, logFieldError, err) + case len(files) == 0: + log.Warn("No files were extracted to destination directory", "dir", destDir, "totalLayers", len(layers), "successfulLayers", successfulLayers) + default: + log.Debug("Successfully extracted files", "dir", destDir, "fileCount", len(files), "totalLayers", len(layers), "successfulLayers", successfulLayers) } return nil } // pullImage pulls an OCI image from the specified reference and returns its descriptor. -func pullImage(ref name.Reference) (*remote.Descriptor, error) { +func pullImage(ref name.Reference, atmosConfig *schema.AtmosConfiguration) (*remote.Descriptor, error) { var opts []remote.Option - opts = append(opts, remote.WithAuth(authn.Anonymous)) // Get registry from parsed reference registry := ref.Context().Registry.Name() - if strings.EqualFold(registry, "ghcr.io") { - githubToken := os.Getenv(githubTokenEnv) - if githubToken != "" { - opts = append(opts, remote.WithAuth(&authn.Basic{ - Username: "oauth2", - Password: githubToken, - })) - log.Debug("Using GitHub token for authentication", "registry", registry) + + // Try to get authentication from various sources + auth, err := getRegistryAuth(registry, atmosConfig) + if err != nil { + if errors.Is(err, errNoAuthenticationFound) { + log.Debug("No authentication found, using anonymous.", logFieldRegistry, registry) + opts = append(opts, remote.WithAuth(authn.Anonymous)) + } else { + log.Error("Registry auth error.", logFieldRegistry, registry, logFieldError, err) + return nil, fmt.Errorf("resolve registry auth: %w", err) } + } else { + opts = append(opts, remote.WithAuth(auth)) + log.Debug("Using authentication for registry", logFieldRegistry, registry) } descriptor, err := remote.Get(ref, opts...) if err != nil { - log.Error("Failed to pull OCI image", "image", ref.Name(), "error", err) + log.Error("Failed to pull OCI image", "image", ref.Name(), logFieldError, err) return nil, fmt.Errorf("failed to pull image '%s': %w", ref.Name(), err) } return descriptor, nil } +// tryGitHubAuth attempts GitHub Container Registry authentication. +func tryGitHubAuth(registry string, atmosConfig *schema.AtmosConfiguration) (authn.Authenticator, error) { + if !strings.EqualFold(registry, "ghcr.io") { + return nil, fmt.Errorf("not a GitHub registry") + } + + auth, err := getGitHubAuth(registry, atmosConfig) + if err == nil { + log.Debug("Using GitHub authentication", logFieldRegistry, registry) + } + return auth, err +} + +// tryDockerAuth attempts Docker credential helper authentication. +func tryDockerAuth(registry string, atmosConfig *schema.AtmosConfiguration) (authn.Authenticator, error) { + auth, err := getDockerAuth(registry, atmosConfig) + if err == nil { + log.Debug("Using Docker config authentication", logFieldRegistry, registry) + } + return auth, err +} + +// tryEnvironmentAuth attempts authentication using environment variables. +func tryEnvironmentAuth(registry string) (authn.Authenticator, error) { + v := viper.New() + + // Normalize registry name by replacing dots and hyphens with underscores for valid env var names + registryEnvName := strings.ToUpper(strings.NewReplacer(".", "_", "-", "_").Replace(registry)) + usernameKey := fmt.Sprintf("%s_username", registryEnvName) + passwordKey := fmt.Sprintf("%s_password", registryEnvName) + + // Bind the registry-specific environment variables + bindEnv(v, usernameKey, + fmt.Sprintf("%s_USERNAME", registryEnvName), + fmt.Sprintf("ATMOS_%s_USERNAME", registryEnvName), + ) + bindEnv(v, passwordKey, + fmt.Sprintf("%s_PASSWORD", registryEnvName), + fmt.Sprintf("ATMOS_%s_PASSWORD", registryEnvName), + ) + + username := v.GetString(usernameKey) + password := v.GetString(passwordKey) + + if username == "" || password == "" { + return nil, fmt.Errorf("no environment credentials found") + } + + log.Debug("Using environment variable authentication", logFieldRegistry, registry) + return &authn.Basic{ + Username: username, + Password: password, + }, nil +} + +// tryECRAuth attempts AWS ECR authentication. +func tryECRAuth(registry string) (authn.Authenticator, error) { + if !strings.Contains(registry, "dkr.ecr") || !strings.Contains(registry, "amazonaws.com") { + return nil, fmt.Errorf("not an ECR registry") + } + + auth, err := getECRAuth(registry) + if err == nil { + log.Debug("Using AWS ECR authentication", logFieldRegistry, registry) + } + return auth, err +} + +// tryACRAuth attempts Azure Container Registry authentication. +func tryACRAuth(registry string, atmosConfig *schema.AtmosConfiguration) (authn.Authenticator, error) { + if !strings.Contains(registry, "azurecr.io") { + return nil, fmt.Errorf("not an Azure registry") + } + + auth, err := getACRAuth(registry, atmosConfig) + if err == nil { + log.Debug("Using Azure ACR authentication", logFieldRegistry, registry) + } + return auth, err +} + +// tryGCRAuth attempts Google Container Registry authentication. +func tryGCRAuth(registry string) (authn.Authenticator, error) { + if !strings.Contains(registry, "gcr.io") && !strings.Contains(registry, "pkg.dev") { + return nil, fmt.Errorf("not a Google registry") + } + + auth, err := getGCRAuth(registry) + if err == nil { + log.Debug("Using Google GCR authentication", logFieldRegistry, registry) + } + return auth, err +} + +// getRegistryAuth attempts to find authentication credentials for the given registry. +// It checks multiple sources in order of preference: +// 1. GitHub Container Registry (ghcr.io) with GITHUB_TOKEN. +// 2. Docker credential helpers (from ~/.docker/config.json). +// 3. Environment variables for specific registries. +// 4. AWS ECR authentication (if AWS credentials are available). +func getRegistryAuth(registry string, atmosConfig *schema.AtmosConfiguration) (authn.Authenticator, error) { + // Try authentication methods in order of preference + authMethods := []func() (authn.Authenticator, error){ + func() (authn.Authenticator, error) { return tryGitHubAuth(registry, atmosConfig) }, + func() (authn.Authenticator, error) { return tryDockerAuth(registry, atmosConfig) }, + func() (authn.Authenticator, error) { return tryEnvironmentAuth(registry) }, + func() (authn.Authenticator, error) { return tryECRAuth(registry) }, + func() (authn.Authenticator, error) { return tryACRAuth(registry, atmosConfig) }, + func() (authn.Authenticator, error) { return tryGCRAuth(registry) }, + } + + for _, method := range authMethods { + if auth, err := method(); err == nil { + return auth, nil + } + } + + return nil, fmt.Errorf("%w %s", errNoAuthenticationFound, registry) +} + +// handleExtractionError handles extraction errors by attempting alternative extraction methods. +func handleExtractionError(extractionErr error, layer v1.Layer, uncompressed io.ReadCloser, destDir string, index int, layerDesc v1.Hash) error { + log.Error("Layer extraction failed", logFieldIndex, index, logFieldDigest, layerDesc, logFieldError, extractionErr) + + // Try alternative extraction methods for different formats + log.Debug("Attempting alternative extraction methods", logFieldIndex, index, logFieldDigest, layerDesc) + + // Reset the uncompressed reader + if uncompressed != nil { + _ = uncompressed.Close() + } + uncompressed, err := layer.Uncompressed() + if err != nil { + log.Error("Failed to reset uncompressed reader", logFieldIndex, index, logFieldDigest, layerDesc, logFieldError, err) + return fmt.Errorf("layer decompression error: %w", err) + } + defer func() { + if uncompressed != nil { + _ = uncompressed.Close() + } + }() + + // Try to extract as raw data first + if err := extractRawData(uncompressed, destDir, index); err != nil { + log.Error("Raw data extraction also failed", logFieldIndex, index, logFieldDigest, layerDesc, logFieldError, err) + + // If this is the first layer and it fails, it might be metadata + if index == 0 { + log.Warn("First layer extraction failed, this might be metadata. Skipping layer.", logFieldIndex, index, logFieldDigest, layerDesc) + return nil // Skip this layer instead of failing + } + + return fmt.Errorf("all extraction methods failed: %w", err) + } + + log.Debug("Successfully extracted layer using alternative method", logFieldIndex, index, logFieldDigest, layerDesc) + return nil +} + // processLayer processes a single OCI layer and extracts its contents to the specified destination directory. func processLayer(layer v1.Layer, index int, destDir string) error { layerDesc, err := layer.Digest() if err != nil { - log.Warn("Skipping layer with invalid digest", "index", index, "error", err) + log.Warn("Skipping layer with invalid digest", logFieldIndex, index, logFieldError, err) return nil } + // Get layer size for debugging + size, err := layer.Size() + if err != nil { + log.Warn("Could not get layer size", logFieldIndex, index, logFieldDigest, layerDesc, logFieldError, err) + } else { + log.Debug("Processing layer", logFieldIndex, index, logFieldDigest, layerDesc, "size", size) + } + + // Get layer media type for debugging + mediaType, err := layer.MediaType() + if err != nil { + log.Warn("Could not get layer media type", logFieldIndex, index, logFieldDigest, layerDesc, logFieldError, err) + } else { + log.Debug("Layer media type", logFieldIndex, index, logFieldDigest, layerDesc, "mediaType", mediaType) + } + uncompressed, err := layer.Uncompressed() if err != nil { - log.Error("Layer decompression failed", "index", index, "digest", layerDesc, "error", err) + log.Error("Layer decompression failed", logFieldIndex, index, logFieldDigest, layerDesc, logFieldError, err) return fmt.Errorf("layer decompression error: %w", err) } - defer uncompressed.Close() + defer func() { + if uncompressed != nil { + _ = uncompressed.Close() + } + }() + + // Try to extract the layer based on media type + var extractionErr error + + // Check if it's a ZIP file + mediaTypeStr := string(mediaType) + if strings.Contains(mediaTypeStr, "zip") { + log.Debug("Detected ZIP layer, extracting as ZIP", logFieldIndex, index, logFieldDigest, layerDesc, "mediaType", mediaTypeStr) + extractionErr = extractZipFile(uncompressed, destDir) + } else { + // Default to tar extraction + log.Debug("Extracting as TAR", logFieldIndex, index, logFieldDigest, layerDesc, "mediaType", mediaTypeStr) + extractionErr = extractTarball(uncompressed, destDir) + } + + if extractionErr != nil { + return handleExtractionError(extractionErr, layer, uncompressed, destDir, index, layerDesc) + } + + log.Debug("Successfully extracted layer", logFieldIndex, index, logFieldDigest, layerDesc) + return nil +} + +// extractRawData attempts to extract raw data from the layer as a fallback. +func extractRawData(reader io.Reader, destDir string, layerIndex int) error { + // Create a temporary file to store the raw data + tempFile := filepath.Join(destDir, fmt.Sprintf("layer_%d_raw", layerIndex)) + + file, err := os.Create(tempFile) + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + defer file.Close() + + // Copy the raw data + _, err = io.Copy(file, reader) + if err != nil { + return fmt.Errorf("failed to copy raw data: %w", err) + } + + log.Debug("Extracted raw data to temp file", "file", tempFile) + return nil +} + +// validateZipFilePath validates a file path from a ZIP archive for security. +func validateZipFilePath(fileName string) error { + // Check for empty or null filename + if fileName == "" { + return fmt.Errorf("%w: empty filename", errIllegalZipFilePath) + } + + // Check for absolute paths + if filepath.IsAbs(fileName) { + return fmt.Errorf("%w: absolute path not allowed: %s", errIllegalZipFilePath, fileName) + } + + // Check for path traversal patterns + if strings.Contains(fileName, "..") { + return fmt.Errorf("%w: path traversal not allowed: %s", errIllegalZipFilePath, fileName) + } + + // Check for Windows absolute paths (drive letter followed by colon and backslash) + if len(fileName) >= 3 && fileName[1] == ':' && (fileName[2] == '\\' || fileName[2] == '/') { + return fmt.Errorf("%w: Windows absolute path not allowed: %s", errIllegalZipFilePath, fileName) + } + + // Check for leading slashes or backslashes (Unix/Windows absolute-like paths) + if strings.HasPrefix(fileName, "/") || strings.HasPrefix(fileName, "\\") { + return fmt.Errorf("%w: leading separator not allowed: %s", errIllegalZipFilePath, fileName) + } + + // Check for null bytes (potential injection) + if strings.Contains(fileName, "\x00") { + return fmt.Errorf("%w: null byte not allowed: %s", errIllegalZipFilePath, fileName) + } + + return nil +} + +// resolveZipFilePath resolves and validates the destination path for a ZIP file entry. +func resolveZipFilePath(destDir, fileName string) (string, error) { + // Ensure destination directory is absolute and clean + cleanDest, err := filepath.Abs(filepath.Clean(destDir)) + if err != nil { + return "", fmt.Errorf("%w: %v", errFailedToResolveDestination, err) + } + + // Join the paths and clean the result + joined := filepath.Join(cleanDest, fileName) + filePath := filepath.Clean(joined) + + // Ensure the resolved path is within the destination directory + // This prevents directory traversal attacks + if !strings.HasPrefix(filePath, cleanDest+string(os.PathSeparator)) && filePath != cleanDest { + return "", fmt.Errorf("%w: path outside destination directory: %s", errIllegalZipFilePath, fileName) + } + + return filePath, nil +} + +// extractZipFileEntry extracts a single file from a ZIP archive. +func extractZipFileEntry(file *zip.File, destDir string) error { + // Validate the file path + if err := validateZipFilePath(file.Name); err != nil { + return err + } + + // Resolve the destination path + filePath, err := resolveZipFilePath(destDir, file.Name) + if err != nil { + return err + } + + // Create parent directories if they don't exist + if err := os.MkdirAll(filepath.Dir(filePath), defaultDirMode); err != nil { + return fmt.Errorf("%w for %s: %w", errFailedToCreateDirectory, file.Name, err) + } + + // Open the file in the ZIP + rc, err := file.Open() + if err != nil { + return fmt.Errorf("%w %s: %w", errFailedToOpenZipFile, file.Name, err) + } + defer rc.Close() + + // Create the destination file + dstFile, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("%w %s: %w", errFailedToCreateFile, filePath, err) + } + defer dstFile.Close() + + // Copy the file contents + if _, err := io.Copy(dstFile, rc); err != nil { + return fmt.Errorf("%w %s: %w", errFailedToCopyFile, file.Name, err) + } + + log.Debug("Extracted file from ZIP", "file", file.Name, "path", filePath) + return nil +} + +// shouldSkipZipFile determines if a ZIP file entry should be skipped. +func shouldSkipZipFile(file *zip.File) (bool, string) { + // Skip directories + if file.FileInfo().IsDir() { + return true, "directory" + } + + // Skip symlinks for security + if file.FileInfo().Mode()&os.ModeSymlink != 0 { + return true, "symlink" + } - if err := extractTarball(uncompressed, destDir); err != nil { - log.Error("Layer extraction failed", "index", index, "digest", layerDesc, "error", err) - return fmt.Errorf("tarball extraction error: %w", err) + return false, "" +} + +// extractZipFile extracts a ZIP file from an io.Reader into the destination directory. +func extractZipFile(reader io.Reader, destDir string) error { + // Read the ZIP data with size limit to prevent decompression bomb attacks + limitedReader := io.LimitReader(reader, int64(maxZipSize)+1) + zipData, err := io.ReadAll(limitedReader) + if err != nil { + return fmt.Errorf("%w: %v", errFailedToReadZipData, err) + } + + // Reject if ZIP exceeds configured max size. + if len(zipData) > maxZipSize { + return fmt.Errorf("%w: %d bytes", errZipSizeExceeded, maxZipSize) } + // Create a ZIP reader + zipReader, err := zip.NewReader(bytes.NewReader(zipData), int64(len(zipData))) + if err != nil { + return fmt.Errorf("%w: %v", errFailedToCreateZipReader, err) + } + + // Extract each file in the ZIP + for _, file := range zipReader.File { + // Check if we should skip this file + if skip, reason := shouldSkipZipFile(file); skip { + if reason == "symlink" { + log.Warn("Skipping symlink in ZIP", "name", file.Name) + } + continue + } + + // Extract the file + if err := extractZipFileEntry(file, destDir); err != nil { + return err + } + } + + log.Debug("Successfully extracted ZIP file", "destination", destDir) return nil } -// checkArtifactType to check and log artifact type mismatches . +// checkArtifactType to check and log artifact type mismatches. func checkArtifactType(descriptor *remote.Descriptor, imageName string) { manifest, err := parseOCIManifest(bytes.NewReader(descriptor.Manifest)) if err != nil { - log.Error("Failed to parse OCI manifest", "image", imageName, "error", err) + log.Error("Failed to parse OCI manifest", "image", imageName, logFieldError, err) return } - if manifest.ArtifactType != targetArtifactType { - // log that don't match the target artifact type - log.Warn("OCI image does not match the target artifact type", "image", imageName, "artifactType", manifest.ArtifactType) + + // Check if the artifact type is supported + supportedTypes := []string{ + targetArtifactType, + opentofuArtifactType, + terraformArtifactType, + } + + isSupported := false + for _, supportedType := range supportedTypes { + if manifest.ArtifactType == supportedType { + isSupported = true + break + } + } + + if !isSupported { + log.Warn("OCI image artifact type not recognized", "image", imageName, "artifactType", manifest.ArtifactType, "supportedTypes", supportedTypes) + } else { + log.Debug("OCI image artifact type is supported", "image", imageName, "artifactType", manifest.ArtifactType) } } -// ParseOCIManifest reads and decodes an OCI manifest from a JSON file. +// parseOCIManifest reads and decodes an OCI manifest from a JSON file. func parseOCIManifest(manifestBytes io.Reader) (*ocispec.Manifest, error) { var manifest ocispec.Manifest if err := json.NewDecoder(manifestBytes).Decode(&manifest); err != nil { - return nil, err + return nil, fmt.Errorf("parse OCI manifest: %w", err) } return &manifest, nil diff --git a/internal/exec/oci_utils_test.go b/internal/exec/oci_utils_test.go new file mode 100644 index 0000000000..fb657e3753 --- /dev/null +++ b/internal/exec/oci_utils_test.go @@ -0,0 +1,60 @@ +package exec + +import ( + "os" + "testing" + + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" +) + +// TestBindEnv tests the Viper environment binding function. +func TestBindEnv(t *testing.T) { + tests := []struct { + name string + key string + envVars []string + setEnv map[string]string + expected string + }{ + { + name: "Single environment variable", + key: "test_key", + envVars: []string{"TEST_VAR"}, + setEnv: map[string]string{"TEST_VAR": "test_value"}, + expected: "test_value", + }, + { + name: "Multiple environment variables with fallback", + key: "test_key", + envVars: []string{"PRIMARY_VAR", "FALLBACK_VAR"}, + setEnv: map[string]string{"FALLBACK_VAR": "fallback_value"}, + expected: "fallback_value", + }, + { + name: "No environment variables set", + key: "test_key", + envVars: []string{"MISSING_VAR"}, + setEnv: map[string]string{}, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up environment variables + for key, value := range tt.setEnv { + os.Setenv(key, value) + defer os.Unsetenv(key) + } + + // Create Viper instance + v := viper.New() + bindEnv(v, tt.key, tt.envVars...) + + // Test that the value is accessible + result := v.GetString(tt.key) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/config/load.go b/pkg/config/load.go index f7858c8e31..c341babbc8 100644 --- a/pkg/config/load.go +++ b/pkg/config/load.go @@ -120,6 +120,14 @@ func setEnv(v *viper.Viper) { bindEnv(v, "settings.telemetry.enabled", "ATMOS_TELEMETRY_ENABLED") bindEnv(v, "settings.telemetry.token", "ATMOS_TELEMETRY_TOKEN") bindEnv(v, "settings.telemetry.endpoint", "ATMOS_TELEMETRY_ENDPOINT") + + // OCI Registry Authentication settings + bindEnv(v, "settings.oci.github_token", "ATMOS_OCI_GITHUB_TOKEN", "GITHUB_TOKEN") + bindEnv(v, "settings.oci.azure_client_id", "ATMOS_OCI_AZURE_CLIENT_ID", "AZURE_CLIENT_ID") + bindEnv(v, "settings.oci.azure_client_secret", "ATMOS_OCI_AZURE_CLIENT_SECRET", "AZURE_CLIENT_SECRET") + bindEnv(v, "settings.oci.azure_tenant_id", "ATMOS_OCI_AZURE_TENANT_ID", "AZURE_TENANT_ID") + bindEnv(v, "settings.oci.azure_cli_auth", "ATMOS_OCI_AZURE_CLI_AUTH", "AZURE_CLI_AUTH") + bindEnv(v, "settings.oci.docker_config", "ATMOS_OCI_DOCKER_CONFIG", "DOCKER_CONFIG") } func bindEnv(v *viper.Viper, key ...string) { diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index 5f6d2d40a6..f4d580cdfd 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -243,6 +243,8 @@ type AtmosSettings struct { Pro ProSettings `yaml:"pro,omitempty" json:"pro,omitempty" mapstructure:"pro"` // Telemetry settings Telemetry TelemetrySettings `yaml:"telemetry,omitempty" json:"telemetry,omitempty" mapstructure:"telemetry"` + // OCI Registry Authentication settings + OCI OCISettings `yaml:"oci,omitempty" json:"oci,omitempty" mapstructure:"oci"` } // TelemetrySettings contains configuration for telemetry collection. @@ -267,6 +269,16 @@ type GithubOIDCSettings struct { RequestToken string `yaml:"request_token,omitempty" json:"request_token,omitempty" mapstructure:"request_token"` } +// OCISettings contains OCI Registry Authentication configuration. +type OCISettings struct { + GithubToken string `yaml:"github_token,omitempty" json:"github_token,omitempty" mapstructure:"github_token"` + AzureClientID string `yaml:"azure_client_id,omitempty" json:"azure_client_id,omitempty" mapstructure:"azure_client_id"` + AzureClientSecret string `yaml:"azure_client_secret,omitempty" json:"azure_client_secret,omitempty" mapstructure:"azure_client_secret"` + AzureTenantID string `yaml:"azure_tenant_id,omitempty" json:"azure_tenant_id,omitempty" mapstructure:"azure_tenant_id"` + AzureCLIAuth string `yaml:"azure_cli_auth,omitempty" json:"azure_cli_auth,omitempty" mapstructure:"azure_cli_auth"` + DockerConfig string `yaml:"docker_config,omitempty" json:"docker_config,omitempty" mapstructure:"docker_config"` +} + type Docs struct { // Deprecated: this has moved to `settings.terminal.max-width` MaxWidth int `yaml:"max-width" json:"max_width" mapstructure:"max-width"`