Skip to content

Commit 76e5fda

Browse files
Trolloldemashutosh-narkar
authored andcommitted
fmt: report wrong arity for built-in functions
This commit fixes the index out of range error discussed in #5646 and adds error handling to avoid that `fmt` panics when such errors are encountered by the formatter The error handling procedure introduced is similar to the one used by the Scanner struct responsible for the parsing of a `rego` file, since it uses a slice which is filled with all the eventual errors that can be found during the format procedure Fixes: #5646 Signed-off-by: Gianluca Oldani <[email protected]>
1 parent 6cb6ed3 commit 76e5fda

File tree

2 files changed

+178
-6
lines changed

2 files changed

+178
-6
lines changed

cmd/fmt_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@ package cmd
22

33
import (
44
"bytes"
5+
"fmt"
56
"io"
67
"os"
78
"path/filepath"
89
"strings"
910
"testing"
1011

12+
"github.com/open-policy-agent/opa/ast"
13+
"github.com/open-policy-agent/opa/format"
1114
"github.com/open-policy-agent/opa/util/test"
1215
)
1316

@@ -30,6 +33,27 @@ const unformatted = `
3033
3134
`
3235

36+
const singleWrongArity = `package test
37+
38+
p {
39+
a := 1
40+
b := 2
41+
plus(a, b, c) == 3
42+
}
43+
`
44+
45+
const MultipleWrongArity = `package test
46+
47+
p {
48+
x:=5
49+
y:=7
50+
z:=6
51+
plus([x, y]) == 3
52+
and(true, false, false) == false
53+
plus(a, x, y, z)
54+
}
55+
`
56+
3357
func TestFmtFormatFile(t *testing.T) {
3458
params := fmtCommandParams{}
3559
var stdout bytes.Buffer
@@ -269,3 +293,89 @@ func TestFmtFailFileChangesDiff(t *testing.T) {
269293
}
270294
})
271295
}
296+
297+
func TestFmtSingleWrongArityError(t *testing.T) {
298+
params := fmtCommandParams{}
299+
var stdout bytes.Buffer
300+
301+
files := map[string]string{
302+
"policy.rego": singleWrongArity,
303+
}
304+
305+
test.WithTempFS(files, func(path string) {
306+
policyFile := filepath.Join(path, "policy.rego")
307+
info, err := os.Stat(policyFile)
308+
err = formatFile(&params, &stdout, policyFile, info, err)
309+
if err == nil {
310+
t.Fatalf("Expected error but did not receive one")
311+
}
312+
313+
loc := ast.Location{File: policyFile, Row: 6}
314+
errExp := ast.NewError(ast.TypeErr, &loc, "%s: %s", "plus", "arity mismatch")
315+
errExp.Details = &format.ArityFormatErrDetail{
316+
Have: []string{"var", "var", "var"},
317+
Want: []string{"number", "number"},
318+
}
319+
expectedErrs := ast.Errors(make([]*ast.Error, 1))
320+
expectedErrs[0] = errExp
321+
expectedSingleWrongArityErr := newError("failed to parse Rego source file: %v", fmt.Errorf("%s: %v", policyFile, expectedErrs))
322+
323+
if err != expectedSingleWrongArityErr {
324+
t.Fatalf("Expected:%s\n\nGot:%s\n\n", expectedSingleWrongArityErr, err)
325+
}
326+
})
327+
}
328+
329+
func TestFmtMultipleWrongArityError(t *testing.T) {
330+
params := fmtCommandParams{}
331+
var stdout bytes.Buffer
332+
333+
files := map[string]string{
334+
"policy.rego": MultipleWrongArity,
335+
}
336+
337+
test.WithTempFS(files, func(path string) {
338+
policyFile := filepath.Join(path, "policy.rego")
339+
info, err := os.Stat(policyFile)
340+
err = formatFile(&params, &stdout, policyFile, info, err)
341+
if err == nil {
342+
t.Fatalf("Expected error but did not receive one")
343+
}
344+
345+
locations := []ast.Location{
346+
{File: policyFile, Row: 7},
347+
{File: policyFile, Row: 8},
348+
{File: policyFile, Row: 9},
349+
}
350+
haveStrings := [][]string{
351+
{"array"},
352+
{"boolean", "boolean", "boolean"},
353+
{"var", "var", "var", "var"},
354+
}
355+
wantStrings := [][]string{
356+
{"number", "number"},
357+
{"set[any]", "set[any]"},
358+
{"number", "number"},
359+
}
360+
operators := []string{
361+
"plus",
362+
"and",
363+
"plus",
364+
}
365+
expectedErrs := ast.Errors(make([]*ast.Error, 3))
366+
for i := 0; i < 3; i++ {
367+
loc := locations[i]
368+
errExp := ast.NewError(ast.TypeErr, &loc, "%s: %s", operators[i], "arity mismatch")
369+
errExp.Details = &format.ArityFormatErrDetail{
370+
Have: haveStrings[i],
371+
Want: wantStrings[i],
372+
}
373+
expectedErrs[i] = errExp
374+
}
375+
expectedMultipleWrongArityErr := newError("failed to parse Rego source file: %v", fmt.Errorf("%s: %v", policyFile, expectedErrs))
376+
377+
if err != expectedMultipleWrongArityErr {
378+
t.Fatalf("Expected:%s\n\nGot:%s\n\n", expectedMultipleWrongArityErr, err)
379+
}
380+
})
381+
}

format/format.go

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@ import (
1010
"fmt"
1111
"regexp"
1212
"sort"
13+
"strings"
1314

1415
"github.com/open-policy-agent/opa/ast"
1516
"github.com/open-policy-agent/opa/internal/future"
17+
"github.com/open-policy-agent/opa/types"
1618
)
1719

1820
// Opts lets you control the code formatting via `AstWithOpts()`.
@@ -37,6 +39,7 @@ func Source(filename string, src []byte) ([]byte, error) {
3739
if err != nil {
3840
return nil, err
3941
}
42+
4043
formatted, err := Ast(module)
4144
if err != nil {
4245
return nil, fmt.Errorf("%s: %v", filename, err)
@@ -142,6 +145,7 @@ func AstWithOpts(x interface{}, opts Opts) ([]byte, error) {
142145

143146
w := &writer{
144147
indent: "\t",
148+
errs: make([]*ast.Error, 0),
145149
}
146150

147151
switch x := x.(type) {
@@ -178,6 +182,9 @@ func AstWithOpts(x interface{}, opts Opts) ([]byte, error) {
178182
return nil, fmt.Errorf("not an ast element: %v", x)
179183
}
180184

185+
if len(w.errs) > 0 {
186+
return nil, w.errs
187+
}
181188
return squashTrailingNewlines(w.buf.Bytes()), nil
182189
}
183190

@@ -228,6 +235,7 @@ type writer struct {
228235
inline bool
229236
beforeEnd *ast.Comment
230237
delay bool
238+
errs ast.Errors
231239
}
232240

233241
func (w *writer) writeModule(module *ast.Module, o fmtOpts) {
@@ -587,7 +595,7 @@ func (w *writer) writeSomeDecl(decl *ast.SomeDecl, comments []*ast.Comment) []*a
587595
w.write(",")
588596
}
589597
case ast.Call:
590-
comments = w.writeInOperator(false, val[1:], comments)
598+
comments = w.writeInOperator(false, val[1:], comments, decl.Location, ast.BuiltinMap[val[0].String()].Decl)
591599
}
592600
}
593601

@@ -622,7 +630,7 @@ func (w *writer) writeFunctionCall(expr *ast.Expr, comments []*ast.Comment) []*a
622630

623631
switch operator {
624632
case ast.Member.Name, ast.MemberWithKey.Name:
625-
return w.writeInOperator(false, terms[1:], comments)
633+
return w.writeInOperator(false, terms[1:], comments, terms[0].Location, ast.BuiltinMap[terms[0].String()].Decl)
626634
}
627635

628636
bi, ok := ast.BuiltinMap[operator]
@@ -647,6 +655,9 @@ func (w *writer) writeFunctionCall(expr *ast.Expr, comments []*ast.Comment) []*a
647655
comments = w.writeTerm(terms[2], comments)
648656
return comments
649657
}
658+
// NOTE(Trolloldem): in this point we are operating with a built-in function with the
659+
// wrong arity even when the assignment notation is used
660+
w.errs = append(w.errs, ArityFormatMismatchError(terms[1:], terms[0].String(), terms[0].Location, bi.Decl))
650661
return w.writeFunctionCallPlain(terms, comments)
651662
}
652663

@@ -708,7 +719,7 @@ func (w *writer) writeTermParens(parens bool, term *ast.Term, comments []*ast.Co
708719
case ast.Var:
709720
w.write(w.formatVar(x))
710721
case ast.Call:
711-
comments = w.writeCall(parens, x, comments)
722+
comments = w.writeCall(parens, x, term.Location, comments)
712723
case fmt.Stringer:
713724
w.write(x.String())
714725
}
@@ -760,7 +771,7 @@ func (w *writer) formatVar(v ast.Var) string {
760771
return v.String()
761772
}
762773

763-
func (w *writer) writeCall(parens bool, x ast.Call, comments []*ast.Comment) []*ast.Comment {
774+
func (w *writer) writeCall(parens bool, x ast.Call, loc *ast.Location, comments []*ast.Comment) []*ast.Comment {
764775
bi, ok := ast.BuiltinMap[x[0].String()]
765776
if !ok || bi.Infix == "" {
766777
return w.writeFunctionCallPlain(x, comments)
@@ -769,13 +780,22 @@ func (w *writer) writeCall(parens bool, x ast.Call, comments []*ast.Comment) []*
769780
if bi.Infix == "in" {
770781
// NOTE(sr): `in` requires special handling, mirroring what happens in the parser,
771782
// since there can be one or two lhs arguments.
772-
return w.writeInOperator(true, x[1:], comments)
783+
return w.writeInOperator(true, x[1:], comments, loc, bi.Decl)
773784
}
774785

775786
// TODO(tsandall): improve to consider precedence?
776787
if parens {
777788
w.write("(")
778789
}
790+
791+
// NOTE(Trolloldem): writeCall is only invoked when the function call is a term
792+
// of another function. The only valid arity is the one of the
793+
// built-in function
794+
if len(bi.Decl.Args()) != len(x)-1 {
795+
w.errs = append(w.errs, ArityFormatMismatchError(x[1:], x[0].String(), loc, bi.Decl))
796+
return comments
797+
}
798+
779799
comments = w.writeTermParens(true, x[1], comments)
780800
w.write(" " + bi.Infix + " ")
781801
comments = w.writeTermParens(true, x[2], comments)
@@ -786,7 +806,16 @@ func (w *writer) writeCall(parens bool, x ast.Call, comments []*ast.Comment) []*
786806
return comments
787807
}
788808

789-
func (w *writer) writeInOperator(parens bool, operands []*ast.Term, comments []*ast.Comment) []*ast.Comment {
809+
func (w *writer) writeInOperator(parens bool, operands []*ast.Term, comments []*ast.Comment, loc *ast.Location, f *types.Function) []*ast.Comment {
810+
if len(operands) != len(f.Args()) {
811+
// The number of operands does not math the arity of the `in` operator
812+
operator := ast.Member.Name
813+
if len(f.Args()) == 3 {
814+
operator = ast.MemberWithKey.Name
815+
}
816+
w.errs = append(w.errs, ArityFormatMismatchError(operands, operator, loc, f))
817+
return comments
818+
}
790819
kw := "in"
791820
switch len(operands) {
792821
case 2:
@@ -1356,3 +1385,36 @@ func ensureFutureKeywordImport(imps []*ast.Import, kw string) []*ast.Import {
13561385
imp.Location = defaultLocation(imp)
13571386
return append(imps, imp)
13581387
}
1388+
1389+
// ArgErrDetail but for `fmt` checks since compiler has not run yet.
1390+
type ArityFormatErrDetail struct {
1391+
Have []string `json:"have"`
1392+
Want []string `json:"want"`
1393+
}
1394+
1395+
// arityMismatchError but for `fmt` checks since the compiler has not run yet.
1396+
func ArityFormatMismatchError(operands []*ast.Term, operator string, loc *ast.Location, f *types.Function) *ast.Error {
1397+
want := make([]string, len(f.Args()))
1398+
for i := range f.Args() {
1399+
want[i] = types.Sprint(f.Args()[i])
1400+
}
1401+
1402+
have := make([]string, len(operands))
1403+
for i := 0; i < len(operands); i++ {
1404+
have[i] = ast.TypeName(operands[i].Value)
1405+
}
1406+
err := ast.NewError(ast.TypeErr, loc, "%s: %s", operator, "arity mismatch")
1407+
err.Details = &ArityFormatErrDetail{
1408+
Have: have,
1409+
Want: want,
1410+
}
1411+
return err
1412+
}
1413+
1414+
// Lines returns the string representation of the detail.
1415+
func (d *ArityFormatErrDetail) Lines() []string {
1416+
return []string{
1417+
"have: " + "(" + strings.Join(d.Have, ",") + ")",
1418+
"want: " + "(" + strings.Join(d.Want, ",") + ")",
1419+
}
1420+
}

0 commit comments

Comments
 (0)