Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 22 additions & 27 deletions pkg/fixer/fixer.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const (

// Fixer must be instantiated via NewFixer.
type Fixer struct {
registeredFixes map[string]any
registeredFixes map[string]fixes.Fix
onConflictOperation OnConflictOperation
registeredRoots []string
versionsMap map[string]ast.RegoVersion
Expand All @@ -35,7 +35,7 @@ type Fixer struct {
// NewFixer instantiates a Fixer.
func NewFixer() *Fixer {
return &Fixer{
registeredFixes: make(map[string]any),
registeredFixes: make(map[string]fixes.Fix),
registeredRoots: make([]string, 0),
onConflictOperation: OnConflictError,
}
Expand Down Expand Up @@ -78,17 +78,11 @@ func (f *Fixer) RegisterRoots(roots ...string) *Fixer {
}

func (f *Fixer) GetFixForName(name string) (fixes.Fix, bool) {
fix, ok := f.registeredFixes[name]
if !ok {
return nil, false
if fix, ok := f.registeredFixes[name]; ok {
return fix, true
}

fixInstance, ok := fix.(fixes.Fix)
if !ok {
return nil, false
}

return fixInstance, true
return nil, false
}

func (f *Fixer) Fix(ctx context.Context, l *linter.Linter, fp fileprovider.FileProvider) (*Report, error) {
Expand Down Expand Up @@ -119,16 +113,13 @@ func (f *Fixer) FixViolations(
return nil, fmt.Errorf("failed to list files: %w", err)
}

// rangeValCopy may be expensive, but this is not critical enough
// to motivate cluttering the code
//nolint:gocritic
for _, violation := range violations {
fixInstance, ok := f.GetFixForName(violation.Title)
for i := range violations {
fixInstance, ok := f.GetFixForName(violations[i].Title)
if !ok {
return nil, fmt.Errorf("no fix for violation %s", violation.Title)
return nil, fmt.Errorf("no fix for violation %s", violations[i].Title)
}

file := violation.Location.File
file := violations[i].Location.File

fc, err := fp.Get(file)
if err != nil {
Expand All @@ -143,7 +134,7 @@ func (f *Fixer) FixViolations(
fixResults, err := fixInstance.Fix(&fixes.FixCandidate{Filename: file, Contents: fc}, &fixes.RuntimeOptions{
BaseDir: util.FindClosestMatchingRoot(abs, f.registeredRoots),
Config: config,
Locations: []report.Location{violation.Location},
Locations: []report.Location{violations[i].Location},
})
if err != nil {
return nil, fmt.Errorf("failed to fix %s: %w", file, err)
Expand Down Expand Up @@ -200,6 +191,11 @@ func (f *Fixer) applyLinterFixes(
versionsMap = f.versionsMap
}

li, err := l.WithDisableAll(true).WithEnabledRules(fixableEnabledRules...).Prepare(ctx)
if err != nil {
return fmt.Errorf("failed to prepare linter for fixing: %w", err)
}

for {
fixMadeInIteration := false

Expand All @@ -208,7 +204,7 @@ func (f *Fixer) applyLinterFixes(
return fmt.Errorf("failed to create linter input: %w", err)
}

rep, err := l.WithDisableAll(true).WithEnabledRules(fixableEnabledRules...).WithInputModules(&in).Lint(ctx)
rep, err := li.WithInputModules(&in).Lint(ctx)
if err != nil {
return fmt.Errorf("failed to lint before fixing: %w", err)
}
Expand All @@ -217,19 +213,18 @@ func (f *Fixer) applyLinterFixes(
break
}

//nolint:gocritic
for _, violation := range rep.Violations {
fixInstance, ok := f.GetFixForName(violation.Title)
for i := range rep.Violations {
fixInstance, ok := f.GetFixForName(rep.Violations[i].Title)
if !ok {
return fmt.Errorf("no fix for violation %s", violation.Title)
return fmt.Errorf("no fix for violation %s", rep.Violations[i].Title)
}

config, err := l.GetConfig()
if err != nil {
return fmt.Errorf("failed to get config: %w", err)
}

file := violation.Location.File
file := rep.Violations[i].Location.File

abs, err := filepath.Abs(file)
if err != nil {
Expand All @@ -246,7 +241,7 @@ func (f *Fixer) applyLinterFixes(
fixResults, err := fixInstance.Fix(&fixCandidate, &fixes.RuntimeOptions{
BaseDir: util.FindClosestMatchingRoot(abs, f.registeredRoots),
Config: config,
Locations: []report.Location{violation.Location},
Locations: []report.Location{rep.Violations[i].Location},
})
if err != nil {
return fmt.Errorf("failed to fix %s: %w", file, err)
Expand All @@ -259,7 +254,7 @@ func (f *Fixer) applyLinterFixes(
fixResult := fixResults[0]

if fixResult.Rename != nil {
if err = f.handleRename(fp, fixReport, startingFiles, fixResult); err != nil {
if err := f.handleRename(fp, fixReport, startingFiles, fixResult); err != nil {
return err
}

Expand Down
40 changes: 37 additions & 3 deletions pkg/fixer/fixer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ deny = true

memfp := fileprovider.NewInMemoryFileProvider(policies)

input, err := memfp.ToInput(map[string]ast.RegoVersion{
mainDir: ast.RegoV1,
})
input, err := memfp.ToInput(map[string]ast.RegoVersion{mainDir: ast.RegoV1})
if err != nil {
t.Fatalf("failed to create input: %v", err)
}
Expand Down Expand Up @@ -196,3 +194,39 @@ func TestFixViolations(t *testing.T) {
}
}
}

// 150116720 ns/op 101567417 B/op 2359322 allocs/op
// 132816578 ns/op 89093239 B/op 2068892 allocs/op // Linter.Prepare()
func BenchmarkFixViolations(b *testing.B) {
rootPath := testutil.Must(filepath.Abs(filepath.FromSlash("/root")))(b)
mainDir := filepath.Join(rootPath, "main")
mainRegoFile := filepath.Join(mainDir, "main.rego")

policies := map[string]string{
mainRegoFile: `package test

allow if {
true #no space
}
deny = true
`,
}

memfp := fileprovider.NewInMemoryFileProvider(policies)

input, err := memfp.ToInput(map[string]ast.RegoVersion{mainDir: ast.RegoV1})
if err != nil {
b.Fatalf("failed to create input: %v", err)
}

l := linter.NewLinter().WithEnableAll(true).WithInputModules(&input)
f := NewFixer().RegisterFixes(fixes.NewDefaultFixes()...).RegisterRoots(rootPath).
SetRegoVersionsMap(map[string]ast.RegoVersion{mainDir: ast.RegoV1})

for b.Loop() {
_, err := f.Fix(b.Context(), &l, memfp)
if err != nil {
b.Fatalf("failed to fix: %v", err)
}
}
}
2 changes: 1 addition & 1 deletion pkg/fixer/fixes/nowhitespacecomment.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (n *NoWhitespaceComment) Fix(fc *FixCandidate, opts *RuntimeOptions) ([]Fix
fixed := false

for _, loc := range opts.Locations {
if loc.Row > len(lines) || loc.Column > len(lines[loc.Row-1]) || loc.Column < 1 {
if loc.Row < 1 || loc.Column < 1 || loc.Row > len(lines) || loc.Column > len(lines[loc.Row-1]) {
continue
}

Expand Down
8 changes: 4 additions & 4 deletions pkg/linter/linter.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,10 @@ func (l Linter) Lint(ctx context.Context) (report.Report, error) {
l.stopTimer(regalmetrics.RegalFilterIgnoredModules)
}

if len(l.inputPaths) == 0 && l.inputModules == nil && len(l.overriddenAggregates) == 0 {
return report.Report{}, errors.New("nothing provided to lint")
}

regoReport, err := l.lint(ctx, input)
if err != nil {
return report.Report{}, fmt.Errorf("failed to lint using Rego rules: %w", err)
Expand Down Expand Up @@ -531,10 +535,6 @@ func (l Linter) notPrepared() Linter {
}

func (l Linter) validate(conf *config.Config) error {
if len(l.inputPaths) == 0 && l.inputModules == nil && len(l.overriddenAggregates) == 0 {
return errors.New("nothing provided to lint")
}

if l.customRuleError != nil {
return fmt.Errorf("failed to load custom rules: %w", l.customRuleError)
}
Expand Down
Loading