Skip to content

Commit 49e80a1

Browse files
authored
Merge pull request #499 from 0xPolygon/v0.4.4/hotfix
ABCI layer improvements
2 parents f156696 + 644fa88 commit 49e80a1

File tree

3 files changed

+169
-2
lines changed

3 files changed

+169
-2
lines changed

app/abci.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,14 @@ func (app *HeimdallApp) NewPrepareProposalHandler() sdk.PrepareProposalHandler {
4545
return nil, err
4646
}
4747

48-
if err := ValidateVoteExtensions(ctx, req.Height, req.LocalLastCommit.Votes, req.LocalLastCommit.Round, validatorSet, app.MilestoneKeeper); err != nil {
49-
logger.Error("Error occurred while validating VEs in PrepareProposal", err)
48+
validVoteExtensions, err := FilterVoteExtensions(ctx, req.Height, req.LocalLastCommit.Votes, req.LocalLastCommit.Round, validatorSet, app.MilestoneKeeper, logger)
49+
if err != nil {
50+
logger.Error("Error occurred while filtering VEs in PrepareProposal", err)
5051
return nil, err
5152
}
5253

54+
req.LocalLastCommit.Votes = validVoteExtensions
55+
5356
if err := ValidateNonRpVoteExtensions(ctx, req.Height, req.LocalLastCommit.Votes, validatorSet, app.ChainManagerKeeper, app.CheckpointKeeper, app.caller, logger); err != nil {
5457
logger.Error("Error occurred while validating non-rp VEs in PrepareProposal", err)
5558
}

app/vote_ext_utils.go

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,164 @@ func ValidateVoteExtensions(ctx sdk.Context, reqHeight int64, extVoteInfo []abci
166166
return nil
167167
}
168168

169+
// FilterVoteExtensions verifies the vote extension correctness and filters out invalid ones
170+
func FilterVoteExtensions(ctx sdk.Context, reqHeight int64, extVoteInfo []abciTypes.ExtendedVoteInfo, round int32, validatorSet *stakeTypes.ValidatorSet, milestoneKeeper milestoneKeeper.Keeper, logger log.Logger) ([]abciTypes.ExtendedVoteInfo, error) {
171+
validVoteExtensions := make([]abciTypes.ExtendedVoteInfo, 0)
172+
173+
// check if VEs are enabled
174+
if err := checkIfVoteExtensionsDisabled(ctx, reqHeight+1); err != nil {
175+
return nil, err
176+
}
177+
178+
// check if reqHeight is the initial height
179+
if reqHeight <= retrieveVoteExtensionsEnableHeight(ctx) {
180+
if len(extVoteInfo) != 0 {
181+
return nil, fmt.Errorf("non-empty VEs received at initial height %d", reqHeight)
182+
}
183+
return nil, nil
184+
}
185+
186+
// Map to track seen validator addresses
187+
seenValidators := make(map[string]struct{})
188+
sumVPPerBlockHash := make(map[string]int64)
189+
190+
ac := address.HexCodec{}
191+
192+
for _, vote := range extVoteInfo {
193+
194+
// make sure the BlockIdFlag is valid
195+
if !isBlockIdFlagValid(vote.BlockIdFlag) {
196+
logger.Error("received vote with invalid block ID flag at height, skipping",
197+
"blockIDFlag", vote.BlockIdFlag.String(),
198+
"height", reqHeight)
199+
continue
200+
}
201+
// if not BlockIDFlagCommit, skip that vote, as it doesn't have relevant information
202+
if vote.BlockIdFlag != cmtTypes.BlockIDFlagCommit {
203+
logger.Error("wrong block id flag, skipping",
204+
"blockIDFlag", vote.BlockIdFlag.String(),
205+
"height", reqHeight)
206+
continue
207+
}
208+
209+
valAddrStr, err := ac.BytesToString(vote.Validator.Address)
210+
if err != nil {
211+
return nil, fmt.Errorf("validator address %v is not valid", vote.Validator.Address)
212+
}
213+
214+
if len(vote.ExtensionSignature) == 0 {
215+
return nil, fmt.Errorf("received empty vote extension signature at height %d from validator %s", reqHeight, valAddrStr)
216+
}
217+
218+
voteExtension := new(sidetxs.VoteExtension)
219+
if err = voteExtension.Unmarshal(vote.VoteExtension); err != nil {
220+
logger.Error("error while unmarshalling vote extension", "error", err)
221+
continue
222+
}
223+
224+
if voteExtension.Height != reqHeight-1 {
225+
logger.Error("invalid height received for vote extension", "expected", reqHeight-1, "got", voteExtension.Height)
226+
continue
227+
}
228+
229+
txHash, err := validateSideTxResponses(voteExtension.SideTxResponses)
230+
if err != nil {
231+
logger.Error("invalid sideTxResponses detected for validator", "validator", valAddrStr, "txHash", common.Bytes2Hex(txHash), "error", err)
232+
continue
233+
}
234+
235+
if err := milestoneAbci.ValidateMilestoneProposition(ctx, &milestoneKeeper, voteExtension.MilestoneProposition); err != nil {
236+
logger.Error("invalid milestone proposition detected for validator", "validator", valAddrStr, "error", err)
237+
continue
238+
}
239+
240+
// Check for duplicate votes by the same validator
241+
if _, found := seenValidators[valAddrStr]; found {
242+
return nil, fmt.Errorf("duplicate vote detected from validator %s at height %d", valAddrStr, reqHeight)
243+
}
244+
// Add validator address to the map
245+
seenValidators[valAddrStr] = struct{}{}
246+
247+
_, validator := validatorSet.GetByAddress(valAddrStr)
248+
if validator == nil {
249+
if milestoneAbci.ShouldErrorOnValidatorNotFound(ctx.BlockHeight()) {
250+
return nil, fmt.Errorf("failed to get validator %s", valAddrStr)
251+
}
252+
continue
253+
}
254+
255+
cmtPubKey, err := getValidatorPublicKey(validator)
256+
if err != nil {
257+
return nil, err
258+
}
259+
260+
cve := cmtTypes.CanonicalVoteExtension{
261+
Extension: vote.VoteExtension,
262+
Height: reqHeight - 1, // the vote extension was signed in the previous height
263+
Round: int64(round),
264+
ChainId: ctx.ChainID(),
265+
}
266+
267+
marshalDelimitedFn := func(msg proto.Message) ([]byte, error) {
268+
var buf bytes.Buffer
269+
if _, err := protoio.NewDelimitedWriter(&buf).WriteMsg(msg); err != nil {
270+
return nil, err
271+
}
272+
273+
return buf.Bytes(), nil
274+
}
275+
276+
extSignBytes, err := marshalDelimitedFn(&cve)
277+
if err != nil {
278+
return nil, fmt.Errorf("failed to encode CanonicalVoteExtension: %w", err)
279+
}
280+
281+
if !cmtPubKey.VerifySignature(extSignBytes, vote.ExtensionSignature) {
282+
return nil, fmt.Errorf("failed to verify validator %s vote extension signature", valAddrStr)
283+
}
284+
285+
sumVPPerBlockHash[common.Bytes2Hex(voteExtension.BlockHash)] += validator.VotingPower
286+
287+
validVoteExtensions = append(validVoteExtensions, vote)
288+
}
289+
290+
// Ensure we have at least 2/3 voting power for the submitted vote extensions in each side tx
291+
totalVotingPower := validatorSet.GetTotalVotingPower()
292+
sumVP := int64(0)
293+
294+
majorityVP := totalVotingPower * 2 / 3
295+
var majorityBlockHash string
296+
for sumVPBlockHash, vp := range sumVPPerBlockHash {
297+
if vp > majorityVP {
298+
sumVP = vp
299+
majorityBlockHash = sumVPBlockHash
300+
break
301+
}
302+
}
303+
304+
if sumVP <= majorityVP {
305+
return nil, fmt.Errorf("insufficient cumulative voting power received to verify vote extensions; got: %d, expected: >%d", sumVP, majorityVP)
306+
}
307+
308+
logger.Debug("majority block hash selected",
309+
"blockHash", majorityBlockHash,
310+
"votingPower", sumVP,
311+
"majorityThreshold", majorityVP)
312+
313+
filteredByBlockHash := make([]abciTypes.ExtendedVoteInfo, 0, len(validVoteExtensions))
314+
for _, vote := range validVoteExtensions {
315+
ve := new(sidetxs.VoteExtension)
316+
if err := ve.Unmarshal(vote.VoteExtension); err != nil {
317+
continue
318+
}
319+
if common.Bytes2Hex(ve.BlockHash) == majorityBlockHash {
320+
filteredByBlockHash = append(filteredByBlockHash, vote)
321+
}
322+
}
323+
324+
return filteredByBlockHash, nil
325+
}
326+
169327
// tallyVotes tallies the votes received for the side tx
170328
// It returns the lists of txs which got >2/3+ YES, NO and UNSPECIFIED votes respectively
171329
func tallyVotes(extVoteInfo []abciTypes.ExtendedVoteInfo, logger log.Logger, totalVotingPower int64, currentHeight int64) ([][]byte, [][]byte, [][]byte, error) {

x/milestone/abci/abci.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,10 +478,16 @@ func ValidateMilestoneProposition(ctx sdk.Context, milestoneKeeper *keeper.Keepe
478478
return fmt.Errorf("len mismatch between hashes and tds: %d != %d", len(milestoneProp.BlockHashes), len(milestoneProp.BlockTds))
479479
}
480480

481+
duplicateBlockHashes := make(map[string]struct{})
481482
for _, blockHash := range milestoneProp.BlockHashes {
482483
if len(blockHash) != common.HashLength {
483484
return fmt.Errorf("invalid block hash length")
484485
}
486+
duplicateBlockHashes[string(blockHash)] = struct{}{}
487+
}
488+
489+
if len(duplicateBlockHashes) != len(milestoneProp.BlockHashes) {
490+
return fmt.Errorf("duplicate block hashes found")
485491
}
486492

487493
return nil

0 commit comments

Comments
 (0)