Skip to content

Commit a33c5f8

Browse files
nmeheusNathan Meheus
andauthored
openapi3filter: support primitive parsing for individual text like parts in multipart/form-data (#1090)
* Add primitive parsing * fix check --------- Co-authored-by: Nathan Meheus <[email protected]>
1 parent e00a340 commit a33c5f8

File tree

2 files changed

+131
-2
lines changed

2 files changed

+131
-2
lines changed

openapi3filter/issue949_test.go

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package openapi3filter_test
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
"mime/multipart"
10+
"net/http"
11+
"net/textproto"
12+
"strings"
13+
"testing"
14+
15+
"github.com/getkin/kin-openapi/openapi3"
16+
"github.com/getkin/kin-openapi/openapi3filter"
17+
"github.com/getkin/kin-openapi/routers/gorillamux"
18+
"github.com/stretchr/testify/require"
19+
)
20+
21+
const testSchema = `
22+
openapi: 3.0.0
23+
info:
24+
title: 'Validator'
25+
version: 0.0.1
26+
paths:
27+
/test:
28+
post:
29+
requestBody:
30+
required: true
31+
content:
32+
multipart/form-data:
33+
schema:
34+
type: object
35+
properties:
36+
file:
37+
type: string
38+
format: binary
39+
counts:
40+
type: object
41+
properties:
42+
name:
43+
type: string
44+
count:
45+
type: integer
46+
primitive:
47+
type: integer
48+
responses:
49+
'200':
50+
description: OK
51+
`
52+
53+
type count struct {
54+
Name string `json:"name"`
55+
Count int `json:"count"`
56+
}
57+
58+
func TestIssue949(t *testing.T) {
59+
loader := openapi3.NewLoader()
60+
doc, err := loader.LoadFromData([]byte(testSchema))
61+
require.NoError(t, err)
62+
63+
err = doc.Validate(context.Background())
64+
require.NoError(t, err)
65+
66+
router, err := gorillamux.NewRouter(doc)
67+
require.NoError(t, err)
68+
69+
body := &bytes.Buffer{}
70+
writer := multipart.NewWriter(body)
71+
72+
// Add the counts object to the request body
73+
h := make(textproto.MIMEHeader)
74+
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"`, "counts"))
75+
h.Set("Content-Type", "application/json")
76+
fw, err := writer.CreatePart(h)
77+
require.NoError(t, err)
78+
79+
countStruct := count{Name: "foo", Count: 7}
80+
countBody, err := json.Marshal(countStruct)
81+
require.NoError(t, err)
82+
_, err = fw.Write(countBody)
83+
require.NoError(t, err)
84+
85+
// Add the file to the request body
86+
fw, err = writer.CreateFormFile("file", "hello.txt")
87+
require.NoError(t, err)
88+
89+
_, err = io.Copy(fw, strings.NewReader("hello"))
90+
require.NoError(t, err)
91+
92+
// Add the primitive integer to the request body
93+
fw, err = writer.CreateFormField("primitive")
94+
require.NoError(t, err)
95+
_, err = fw.Write([]byte("1"))
96+
require.NoError(t, err)
97+
98+
writer.Close()
99+
100+
req, _ := http.NewRequest(http.MethodPost, "/test", bytes.NewReader(body.Bytes()))
101+
req.Header.Set("Content-Type", writer.FormDataContentType())
102+
103+
route, pathParams, err := router.FindRoute(req)
104+
require.NoError(t, err)
105+
106+
reqBody := route.Operation.RequestBody.Value
107+
108+
requestValidationInput := &openapi3filter.RequestValidationInput{
109+
Request: req,
110+
PathParams: pathParams,
111+
Route: route,
112+
}
113+
114+
err = openapi3filter.ValidateRequestBody(context.TODO(), requestValidationInput, reqBody)
115+
require.NoError(t, err)
116+
}

openapi3filter/req_resp_decoder.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import (
1717
"strconv"
1818
"strings"
1919

20-
"github.com/oasdiff/yaml3"
20+
yaml "github.com/oasdiff/yaml3"
2121

2222
"github.com/getkin/kin-openapi/openapi3"
2323
)
@@ -1472,13 +1472,26 @@ func MultipartBodyDecoder(body io.Reader, header http.Header, schema *openapi3.S
14721472
}
14731473
}
14741474

1475+
partHeader := http.Header(part.Header)
14751476
var value any
1476-
if _, value, err = decodeBody(part, http.Header(part.Header), valueSchema, subEncFn); err != nil {
1477+
if _, value, err = decodeBody(part, partHeader, valueSchema, subEncFn); err != nil {
14771478
if v, ok := err.(*ParseError); ok {
14781479
return nil, &ParseError{path: []any{name}, Cause: v}
14791480
}
14801481
return nil, fmt.Errorf("part %s: %w", name, err)
14811482
}
1483+
1484+
// Parse primitive types when no content type is explicitely provided, or the content type is set to text/plain
1485+
contentType := partHeader.Get(headerCT)
1486+
if contentType == "" || contentType == "text/plain" {
1487+
if value, err = parsePrimitive(value.(string), valueSchema); err != nil {
1488+
if v, ok := err.(*ParseError); ok {
1489+
return nil, &ParseError{path: []any{name}, Cause: v}
1490+
}
1491+
return nil, fmt.Errorf("part %s: %w", name, err)
1492+
}
1493+
}
1494+
14821495
values[name] = append(values[name], value)
14831496
}
14841497

0 commit comments

Comments
 (0)