Skip to content

Commit c52526a

Browse files
committed
Fix in MockMultipartHttpServletRequest#getMultipartHeaders
Previously this method returned headers only when a Content-Type part header was present. Now it is guaranteed to return headers (possibly empty) as long as there is a MultipartFile or Part with the given name. Closes gh-26501
1 parent 7a329eb commit c52526a

File tree

4 files changed

+68
-18
lines changed

4 files changed

+68
-18
lines changed

spring-test/src/main/java/org/springframework/mock/web/MockMultipartHttpServletRequest.java

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
1717
package org.springframework.mock.web;
1818

1919
import java.io.IOException;
20+
import java.util.ArrayList;
2021
import java.util.Collections;
2122
import java.util.Enumeration;
2223
import java.util.Iterator;
@@ -33,6 +34,7 @@
3334
import org.springframework.util.Assert;
3435
import org.springframework.util.LinkedMultiValueMap;
3536
import org.springframework.util.MultiValueMap;
37+
import org.springframework.web.multipart.MultipartException;
3638
import org.springframework.web.multipart.MultipartFile;
3739
import org.springframework.web.multipart.MultipartHttpServletRequest;
3840

@@ -155,15 +157,28 @@ public HttpHeaders getRequestHeaders() {
155157

156158
@Override
157159
public HttpHeaders getMultipartHeaders(String paramOrFileName) {
158-
String contentType = getMultipartContentType(paramOrFileName);
159-
if (contentType != null) {
160+
MultipartFile file = getFile(paramOrFileName);
161+
if (file != null) {
160162
HttpHeaders headers = new HttpHeaders();
161-
headers.add(HttpHeaders.CONTENT_TYPE, contentType);
163+
if (file.getContentType() != null) {
164+
headers.add(HttpHeaders.CONTENT_TYPE, file.getContentType());
165+
}
162166
return headers;
163167
}
164-
else {
165-
return null;
168+
try {
169+
Part part = getPart(paramOrFileName);
170+
if (part != null) {
171+
HttpHeaders headers = new HttpHeaders();
172+
for (String headerName : part.getHeaderNames()) {
173+
headers.put(headerName, new ArrayList<>(part.getHeaders(headerName)));
174+
}
175+
return headers;
176+
}
166177
}
178+
catch (Throwable ex) {
179+
throw new MultipartException("Could not access multipart servlet request", ex);
180+
}
181+
return null;
167182
}
168183

169184
}

spring-web/src/main/java/org/springframework/web/multipart/MultipartHttpServletRequest.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2011 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -60,9 +60,10 @@ public interface MultipartHttpServletRequest extends HttpServletRequest, Multipa
6060
HttpHeaders getRequestHeaders();
6161

6262
/**
63-
* Return the headers associated with the specified part of the multipart request.
64-
* <p>If the underlying implementation supports access to headers, then all headers are returned.
65-
* Otherwise, the returned headers will include a 'Content-Type' header at the very least.
63+
* Return the headers for the specified part of the multipart request.
64+
* <p>If the underlying implementation supports access to part headers,
65+
* then all headers are returned. Otherwise, e.g. for a file upload, the
66+
* returned headers may expose a 'Content-Type' if available.
6667
*/
6768
@Nullable
6869
HttpHeaders getMultipartHeaders(String paramOrFileName);

spring-web/src/testFixtures/java/org/springframework/web/testfixture/servlet/MockMultipartHttpServletRequest.java

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@
1717
package org.springframework.web.testfixture.servlet;
1818

1919
import java.io.IOException;
20+
import java.util.ArrayList;
2021
import java.util.Collections;
2122
import java.util.Enumeration;
2223
import java.util.Iterator;
@@ -33,6 +34,7 @@
3334
import org.springframework.util.Assert;
3435
import org.springframework.util.LinkedMultiValueMap;
3536
import org.springframework.util.MultiValueMap;
37+
import org.springframework.web.multipart.MultipartException;
3638
import org.springframework.web.multipart.MultipartFile;
3739
import org.springframework.web.multipart.MultipartHttpServletRequest;
3840

@@ -155,15 +157,28 @@ public HttpHeaders getRequestHeaders() {
155157

156158
@Override
157159
public HttpHeaders getMultipartHeaders(String paramOrFileName) {
158-
String contentType = getMultipartContentType(paramOrFileName);
159-
if (contentType != null) {
160+
MultipartFile file = getFile(paramOrFileName);
161+
if (file != null) {
160162
HttpHeaders headers = new HttpHeaders();
161-
headers.add(HttpHeaders.CONTENT_TYPE, contentType);
163+
if (file.getContentType() != null) {
164+
headers.add(HttpHeaders.CONTENT_TYPE, file.getContentType());
165+
}
162166
return headers;
163167
}
164-
else {
165-
return null;
168+
try {
169+
Part part = getPart(paramOrFileName);
170+
if (part != null) {
171+
HttpHeaders headers = new HttpHeaders();
172+
for (String headerName : part.getHeaderNames()) {
173+
headers.put(headerName, new ArrayList<>(part.getHeaders(headerName)));
174+
}
175+
return headers;
176+
}
166177
}
178+
catch (Throwable ex) {
179+
throw new MultipartException("Could not access multipart servlet request", ex);
180+
}
181+
return null;
167182
}
168183

169184
}

spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/RequestPartMethodArgumentResolverTests.java

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -36,6 +36,7 @@
3636
import org.springframework.http.HttpInputMessage;
3737
import org.springframework.http.MediaType;
3838
import org.springframework.http.converter.HttpMessageConverter;
39+
import org.springframework.http.converter.StringHttpMessageConverter;
3940
import org.springframework.lang.Nullable;
4041
import org.springframework.util.ReflectionUtils;
4142
import org.springframework.validation.BindingResult;
@@ -51,6 +52,7 @@
5152
import org.springframework.web.multipart.MultipartException;
5253
import org.springframework.web.multipart.MultipartFile;
5354
import org.springframework.web.multipart.support.MissingServletRequestPartException;
55+
import org.springframework.web.testfixture.method.ResolvableMethod;
5456
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
5557
import org.springframework.web.testfixture.servlet.MockHttpServletResponse;
5658
import org.springframework.web.testfixture.servlet.MockMultipartFile;
@@ -311,6 +313,22 @@ public void resolveRequestPartNotRequired() throws Exception {
311313
testResolveArgument(new SimpleBean("foo"), paramValidRequestPart);
312314
}
313315

316+
@Test // gh-26501
317+
public void resolveRequestPartWithoutContentType() throws Exception {
318+
MockMultipartHttpServletRequest servletRequest = new MockMultipartHttpServletRequest();
319+
servletRequest.addPart(new MockPart("requestPartString", "part value".getBytes(StandardCharsets.UTF_8)));
320+
ServletWebRequest webRequest = new ServletWebRequest(servletRequest, new MockHttpServletResponse());
321+
322+
List<HttpMessageConverter<?>> converters = Collections.singletonList(new StringHttpMessageConverter());
323+
RequestPartMethodArgumentResolver resolver = new RequestPartMethodArgumentResolver(converters);
324+
MethodParameter parameter = ResolvableMethod.on(getClass()).named("handle").build().arg(String.class);
325+
326+
Object actualValue = resolver.resolveArgument(
327+
parameter, new ModelAndViewContainer(), webRequest, new ValidatingBinderFactory());
328+
329+
assertThat(actualValue).isEqualTo("part value");
330+
}
331+
314332
@Test
315333
public void isMultipartRequest() throws Exception {
316334
MockHttpServletRequest request = new MockHttpServletRequest();
@@ -606,7 +624,8 @@ public void handle(
606624
@RequestPart("requestPart") Optional<List<MultipartFile>> optionalMultipartFileList,
607625
Optional<Part> optionalPart,
608626
@RequestPart("requestPart") Optional<List<Part>> optionalPartList,
609-
@RequestPart("requestPart") Optional<SimpleBean> optionalRequestPart) {
627+
@RequestPart("requestPart") Optional<SimpleBean> optionalRequestPart,
628+
@RequestPart("requestPartString") String requestPartString) {
610629
}
611630

612631
}

0 commit comments

Comments
 (0)