@@ -19,6 +19,7 @@ import (
1919 "context"
2020 "encoding/base64"
2121 "encoding/json"
22+ "errors"
2223 "fmt"
2324
2425 "github.com/in-toto/in-toto-golang/in_toto"
@@ -37,44 +38,52 @@ import (
3738//
3839// If there's no error, and payload is empty means the predicateType did not
3940// match the attestation.
40- func AttestationToPayloadJSON (ctx context.Context , predicateType string , verifiedAttestation oci.Signature ) ([]byte , error ) {
41- // Check the predicate up front, no point in wasting time if it's invalid.
42- predicateURI , err := options .ParsePredicateType (predicateType )
43-
44- if err != nil {
45- return nil , fmt .Errorf ("invalid predicate type: %s" , predicateType )
41+ // Returns the attestation type (PredicateType) if the payload was decoded
42+ // before the error happened, or in the case the predicateType that was
43+ // requested does not match. This is useful for callers to be able to provide
44+ // better error messages. For example, if there's a typo in the predicateType,
45+ // or the predicateType is not the one they are looking for. Without returning
46+ // this, it's hard for users to know which attestations/predicateTypes were
47+ // inspected.
48+ func AttestationToPayloadJSON (ctx context.Context , predicateType string , verifiedAttestation oci.Signature ) ([]byte , string , error ) {
49+ if predicateType == "" {
50+ return nil , "" , errors .New ("missing predicate type" )
51+ }
52+ predicateURI , ok := options .PredicateTypeMap [predicateType ]
53+ if ! ok {
54+ // Not a custom one, use it as is.
55+ predicateURI = predicateType
4656 }
47-
4857 var payloadData map [string ]interface {}
4958
5059 p , err := verifiedAttestation .Payload ()
5160 if err != nil {
52- return nil , fmt .Errorf ("getting payload: %w" , err )
61+ return nil , "" , fmt .Errorf ("getting payload: %w" , err )
5362 }
5463
5564 err = json .Unmarshal (p , & payloadData )
5665 if err != nil {
57- return nil , fmt .Errorf ("unmarshaling payload data" )
66+ return nil , "" , fmt .Errorf ("unmarshaling payload data" )
5867 }
5968
6069 var decodedPayload []byte
6170 if val , ok := payloadData ["payload" ]; ok {
6271 decodedPayload , err = base64 .StdEncoding .DecodeString (val .(string ))
6372 if err != nil {
64- return nil , fmt .Errorf ("decoding payload: %w" , err )
73+ return nil , "" , fmt .Errorf ("decoding payload: %w" , err )
6574 }
6675 } else {
67- return nil , fmt .Errorf ("could not find payload in payload data" )
76+ return nil , "" , fmt .Errorf ("could not find payload in payload data" )
6877 }
6978
7079 // Only apply the policy against the requested predicate type
7180 var statement in_toto.Statement
7281 if err := json .Unmarshal (decodedPayload , & statement ); err != nil {
73- return nil , fmt .Errorf ("unmarshal in-toto statement: %w" , err )
82+ return nil , "" , fmt .Errorf ("unmarshal in-toto statement: %w" , err )
7483 }
7584 if statement .PredicateType != predicateURI {
7685 // This is not the predicate we're looking for, so skip it.
77- return nil , nil
86+ return nil , statement . PredicateType , nil
7887 }
7988
8089 // NB: In many (all?) of these cases, we could just return the
@@ -85,59 +94,59 @@ func AttestationToPayloadJSON(ctx context.Context, predicateType string, verifie
8594 case options .PredicateCustom :
8695 payload , err = json .Marshal (statement )
8796 if err != nil {
88- return nil , fmt .Errorf ("generating CosignStatement: %w" , err )
97+ return nil , statement . PredicateType , fmt .Errorf ("generating CosignStatement: %w" , err )
8998 }
9099 case options .PredicateLink :
91100 var linkStatement in_toto.LinkStatement
92101 if err := json .Unmarshal (decodedPayload , & linkStatement ); err != nil {
93- return nil , fmt .Errorf ("unmarshaling LinkStatement: %w" , err )
102+ return nil , statement . PredicateType , fmt .Errorf ("unmarshaling LinkStatement: %w" , err )
94103 }
95104 payload , err = json .Marshal (linkStatement )
96105 if err != nil {
97- return nil , fmt .Errorf ("marshaling LinkStatement: %w" , err )
106+ return nil , statement . PredicateType , fmt .Errorf ("marshaling LinkStatement: %w" , err )
98107 }
99108 case options .PredicateSLSA :
100109 var slsaProvenanceStatement in_toto.ProvenanceStatement
101110 if err := json .Unmarshal (decodedPayload , & slsaProvenanceStatement ); err != nil {
102- return nil , fmt .Errorf ("unmarshaling ProvenanceStatement): %w" , err )
111+ return nil , statement . PredicateType , fmt .Errorf ("unmarshaling ProvenanceStatement): %w" , err )
103112 }
104113 payload , err = json .Marshal (slsaProvenanceStatement )
105114 if err != nil {
106- return nil , fmt .Errorf ("marshaling ProvenanceStatement: %w" , err )
115+ return nil , statement . PredicateType , fmt .Errorf ("marshaling ProvenanceStatement: %w" , err )
107116 }
108117 case options .PredicateSPDX , options .PredicateSPDXJSON :
109118 var spdxStatement in_toto.SPDXStatement
110119 if err := json .Unmarshal (decodedPayload , & spdxStatement ); err != nil {
111- return nil , fmt .Errorf ("unmarshaling SPDXStatement: %w" , err )
120+ return nil , statement . PredicateType , fmt .Errorf ("unmarshaling SPDXStatement: %w" , err )
112121 }
113122 payload , err = json .Marshal (spdxStatement )
114123 if err != nil {
115- return nil , fmt .Errorf ("marshaling SPDXStatement: %w" , err )
124+ return nil , statement . PredicateType , fmt .Errorf ("marshaling SPDXStatement: %w" , err )
116125 }
117126 case options .PredicateCycloneDX :
118127 var cyclonedxStatement in_toto.CycloneDXStatement
119128 if err := json .Unmarshal (decodedPayload , & cyclonedxStatement ); err != nil {
120- return nil , fmt .Errorf ("unmarshaling CycloneDXStatement: %w" , err )
129+ return nil , statement . PredicateType , fmt .Errorf ("unmarshaling CycloneDXStatement: %w" , err )
121130 }
122131 payload , err = json .Marshal (cyclonedxStatement )
123132 if err != nil {
124- return nil , fmt .Errorf ("marshaling CycloneDXStatement: %w" , err )
133+ return nil , statement . PredicateType , fmt .Errorf ("marshaling CycloneDXStatement: %w" , err )
125134 }
126135 case options .PredicateVuln :
127136 var vulnStatement attestation.CosignVulnStatement
128137 if err := json .Unmarshal (decodedPayload , & vulnStatement ); err != nil {
129- return nil , fmt .Errorf ("unmarshaling CosignVulnStatement: %w" , err )
138+ return nil , statement . PredicateType , fmt .Errorf ("unmarshaling CosignVulnStatement: %w" , err )
130139 }
131140 payload , err = json .Marshal (vulnStatement )
132141 if err != nil {
133- return nil , fmt .Errorf ("marshaling CosignVulnStatement: %w" , err )
142+ return nil , statement . PredicateType , fmt .Errorf ("marshaling CosignVulnStatement: %w" , err )
134143 }
135144 default :
136145 // Valid URI type reaches here.
137146 payload , err = json .Marshal (statement )
138147 if err != nil {
139- return nil , fmt .Errorf ("generating Statement: %w" , err )
148+ return nil , statement . PredicateType , fmt .Errorf ("generating Statement: %w" , err )
140149 }
141150 }
142- return payload , nil
151+ return payload , statement . PredicateType , nil
143152}
0 commit comments