@@ -166,6 +166,164 @@ func ValidateVoteExtensions(ctx sdk.Context, reqHeight int64, extVoteInfo []abci
166166 return nil
167167}
168168
169+ // FilterVoteExtensions verifies the vote extension correctness and filters out invalid ones
170+ func FilterVoteExtensions (ctx sdk.Context , reqHeight int64 , extVoteInfo []abciTypes.ExtendedVoteInfo , round int32 , validatorSet * stakeTypes.ValidatorSet , milestoneKeeper milestoneKeeper.Keeper , logger log.Logger ) ([]abciTypes.ExtendedVoteInfo , error ) {
171+ validVoteExtensions := make ([]abciTypes.ExtendedVoteInfo , 0 )
172+
173+ // check if VEs are enabled
174+ if err := checkIfVoteExtensionsDisabled (ctx , reqHeight + 1 ); err != nil {
175+ return nil , err
176+ }
177+
178+ // check if reqHeight is the initial height
179+ if reqHeight <= retrieveVoteExtensionsEnableHeight (ctx ) {
180+ if len (extVoteInfo ) != 0 {
181+ return nil , fmt .Errorf ("non-empty VEs received at initial height %d" , reqHeight )
182+ }
183+ return nil , nil
184+ }
185+
186+ // Map to track seen validator addresses
187+ seenValidators := make (map [string ]struct {})
188+ sumVPPerBlockHash := make (map [string ]int64 )
189+
190+ ac := address.HexCodec {}
191+
192+ for _ , vote := range extVoteInfo {
193+
194+ // make sure the BlockIdFlag is valid
195+ if ! isBlockIdFlagValid (vote .BlockIdFlag ) {
196+ logger .Error ("received vote with invalid block ID flag at height, skipping" ,
197+ "blockIDFlag" , vote .BlockIdFlag .String (),
198+ "height" , reqHeight )
199+ continue
200+ }
201+ // if not BlockIDFlagCommit, skip that vote, as it doesn't have relevant information
202+ if vote .BlockIdFlag != cmtTypes .BlockIDFlagCommit {
203+ logger .Error ("wrong block id flag, skipping" ,
204+ "blockIDFlag" , vote .BlockIdFlag .String (),
205+ "height" , reqHeight )
206+ continue
207+ }
208+
209+ valAddrStr , err := ac .BytesToString (vote .Validator .Address )
210+ if err != nil {
211+ return nil , fmt .Errorf ("validator address %v is not valid" , vote .Validator .Address )
212+ }
213+
214+ if len (vote .ExtensionSignature ) == 0 {
215+ return nil , fmt .Errorf ("received empty vote extension signature at height %d from validator %s" , reqHeight , valAddrStr )
216+ }
217+
218+ voteExtension := new (sidetxs.VoteExtension )
219+ if err = voteExtension .Unmarshal (vote .VoteExtension ); err != nil {
220+ logger .Error ("error while unmarshalling vote extension" , "error" , err )
221+ continue
222+ }
223+
224+ if voteExtension .Height != reqHeight - 1 {
225+ logger .Error ("invalid height received for vote extension" , "expected" , reqHeight - 1 , "got" , voteExtension .Height )
226+ continue
227+ }
228+
229+ txHash , err := validateSideTxResponses (voteExtension .SideTxResponses )
230+ if err != nil {
231+ logger .Error ("invalid sideTxResponses detected for validator" , "validator" , valAddrStr , "txHash" , common .Bytes2Hex (txHash ), "error" , err )
232+ continue
233+ }
234+
235+ if err := milestoneAbci .ValidateMilestoneProposition (ctx , & milestoneKeeper , voteExtension .MilestoneProposition ); err != nil {
236+ logger .Error ("invalid milestone proposition detected for validator" , "validator" , valAddrStr , "error" , err )
237+ continue
238+ }
239+
240+ // Check for duplicate votes by the same validator
241+ if _ , found := seenValidators [valAddrStr ]; found {
242+ return nil , fmt .Errorf ("duplicate vote detected from validator %s at height %d" , valAddrStr , reqHeight )
243+ }
244+ // Add validator address to the map
245+ seenValidators [valAddrStr ] = struct {}{}
246+
247+ _ , validator := validatorSet .GetByAddress (valAddrStr )
248+ if validator == nil {
249+ if milestoneAbci .ShouldErrorOnValidatorNotFound (ctx .BlockHeight ()) {
250+ return nil , fmt .Errorf ("failed to get validator %s" , valAddrStr )
251+ }
252+ continue
253+ }
254+
255+ cmtPubKey , err := getValidatorPublicKey (validator )
256+ if err != nil {
257+ return nil , err
258+ }
259+
260+ cve := cmtTypes.CanonicalVoteExtension {
261+ Extension : vote .VoteExtension ,
262+ Height : reqHeight - 1 , // the vote extension was signed in the previous height
263+ Round : int64 (round ),
264+ ChainId : ctx .ChainID (),
265+ }
266+
267+ marshalDelimitedFn := func (msg proto.Message ) ([]byte , error ) {
268+ var buf bytes.Buffer
269+ if _ , err := protoio .NewDelimitedWriter (& buf ).WriteMsg (msg ); err != nil {
270+ return nil , err
271+ }
272+
273+ return buf .Bytes (), nil
274+ }
275+
276+ extSignBytes , err := marshalDelimitedFn (& cve )
277+ if err != nil {
278+ return nil , fmt .Errorf ("failed to encode CanonicalVoteExtension: %w" , err )
279+ }
280+
281+ if ! cmtPubKey .VerifySignature (extSignBytes , vote .ExtensionSignature ) {
282+ return nil , fmt .Errorf ("failed to verify validator %s vote extension signature" , valAddrStr )
283+ }
284+
285+ sumVPPerBlockHash [common .Bytes2Hex (voteExtension .BlockHash )] += validator .VotingPower
286+
287+ validVoteExtensions = append (validVoteExtensions , vote )
288+ }
289+
290+ // Ensure we have at least 2/3 voting power for the submitted vote extensions in each side tx
291+ totalVotingPower := validatorSet .GetTotalVotingPower ()
292+ sumVP := int64 (0 )
293+
294+ majorityVP := totalVotingPower * 2 / 3
295+ var majorityBlockHash string
296+ for sumVPBlockHash , vp := range sumVPPerBlockHash {
297+ if vp > majorityVP {
298+ sumVP = vp
299+ majorityBlockHash = sumVPBlockHash
300+ break
301+ }
302+ }
303+
304+ if sumVP <= majorityVP {
305+ return nil , fmt .Errorf ("insufficient cumulative voting power received to verify vote extensions; got: %d, expected: >%d" , sumVP , majorityVP )
306+ }
307+
308+ logger .Debug ("majority block hash selected" ,
309+ "blockHash" , majorityBlockHash ,
310+ "votingPower" , sumVP ,
311+ "majorityThreshold" , majorityVP )
312+
313+ filteredByBlockHash := make ([]abciTypes.ExtendedVoteInfo , 0 , len (validVoteExtensions ))
314+ for _ , vote := range validVoteExtensions {
315+ ve := new (sidetxs.VoteExtension )
316+ if err := ve .Unmarshal (vote .VoteExtension ); err != nil {
317+ continue
318+ }
319+ if common .Bytes2Hex (ve .BlockHash ) == majorityBlockHash {
320+ filteredByBlockHash = append (filteredByBlockHash , vote )
321+ }
322+ }
323+
324+ return filteredByBlockHash , nil
325+ }
326+
169327// tallyVotes tallies the votes received for the side tx
170328// It returns the lists of txs which got >2/3+ YES, NO and UNSPECIFIED votes respectively
171329func tallyVotes (extVoteInfo []abciTypes.ExtendedVoteInfo , logger log.Logger , totalVotingPower int64 , currentHeight int64 ) ([][]byte , [][]byte , [][]byte , error ) {
0 commit comments