diff --git a/spring-integration-test/src/main/java/org/springframework/integration/test/context/MockIntegrationContext.java b/spring-integration-test/src/main/java/org/springframework/integration/test/context/MockIntegrationContext.java index d0e89265b14..42082fc4494 100644 --- a/spring-integration-test/src/main/java/org/springframework/integration/test/context/MockIntegrationContext.java +++ b/spring-integration-test/src/main/java/org/springframework/integration/test/context/MockIntegrationContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2017-2019 the original author or authors. + * Copyright 2017-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.context.Lifecycle; +import org.springframework.context.SmartLifecycle; import org.springframework.integration.core.MessageProducer; import org.springframework.integration.core.MessageSource; import org.springframework.integration.endpoint.IntegrationConsumer; @@ -97,6 +98,11 @@ public void resetBeans(String... beanNames) { .forEach(e -> { Object endpoint = this.beanFactory.getBean(e.getKey()); DirectFieldAccessor directFieldAccessor = new DirectFieldAccessor(endpoint); + SmartLifecycle lifecycle = null; + if (endpoint instanceof SmartLifecycle && ((SmartLifecycle) endpoint).isRunning()) { + lifecycle = (SmartLifecycle) endpoint; + lifecycle.stop(); + } if (endpoint instanceof SourcePollingChannelAdapter) { directFieldAccessor.setPropertyValue("source", e.getValue()); } @@ -108,9 +114,19 @@ else if (endpoint instanceof ReactiveStreamsConsumer) { else if (endpoint instanceof IntegrationConsumer) { directFieldAccessor.setPropertyValue(HANDLER, e.getValue()); } + if (lifecycle != null && lifecycle.isAutoStartup()) { + lifecycle.start(); + } }); - this.beans.clear(); + if (!ObjectUtils.isEmpty(beanNames)) { + for (String name : beanNames) { + this.beans.remove(name); + } + } + else { + this.beans.clear(); + } } /** diff --git a/spring-integration-test/src/test/java/org/springframework/integration/test/mock/MockMessageHandlerTests.java b/spring-integration-test/src/test/java/org/springframework/integration/test/mock/MockMessageHandlerTests.java index 5fec40efed0..633dc78f7ce 100644 --- a/spring-integration-test/src/test/java/org/springframework/integration/test/mock/MockMessageHandlerTests.java +++ b/spring-integration-test/src/test/java/org/springframework/integration/test/mock/MockMessageHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2017-2019 the original author or authors. + * Copyright 2017-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,16 +17,18 @@ package org.springframework.integration.test.mock; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.springframework.integration.test.mock.MockIntegration.mockMessageHandler; import java.util.List; +import java.util.Map; -import org.junit.After; -import org.junit.Test; -import org.junit.runner.RunWith; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; import org.reactivestreams.Subscriber; @@ -34,6 +36,7 @@ import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.integration.annotation.EndpointId; import org.springframework.integration.annotation.Poller; import org.springframework.integration.annotation.ServiceActivator; import org.springframework.integration.channel.DirectChannel; @@ -42,6 +45,7 @@ import org.springframework.integration.endpoint.ReactiveStreamsConsumer; import org.springframework.integration.expression.ValueExpression; import org.springframework.integration.handler.ExpressionEvaluatingMessageHandler; +import org.springframework.integration.handler.LoggingHandler; import org.springframework.integration.support.MessageBuilder; import org.springframework.integration.test.context.MockIntegrationContext; import org.springframework.integration.test.context.SpringIntegrationTest; @@ -56,7 +60,7 @@ import org.springframework.messaging.support.GenericMessage; import org.springframework.test.annotation.DirtiesContext; import org.springframework.test.context.ContextConfiguration; -import org.springframework.test.context.junit4.SpringRunner; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; /** * @author Artem Bilan @@ -64,7 +68,7 @@ * * @since 5.0 */ -@RunWith(SpringRunner.class) +@SpringJUnitConfig @ContextConfiguration(classes = MockMessageHandlerTests.Config.class) @SpringIntegrationTest @DirtiesContext @@ -97,10 +101,10 @@ public class MockMessageHandlerTests { @Autowired private ArgumentCaptor> messageArgumentCaptor; - @After + @AfterEach public void tearDown() { this.mockIntegrationContext.resetBeans(); - results.purge(null); + this.results.purge(null); } @Test @@ -181,8 +185,7 @@ public void testMockRawHandler() { ArgumentCaptor> messageArgumentCaptor = MockIntegration.messageArgumentCaptor(); MessageHandler mockMessageHandler = spy(mockMessageHandler(messageArgumentCaptor)) - .handleNext(m -> { - }); + .handleNext(m -> { }); String endpointId = "rawHandlerConsumer"; this.mockIntegrationContext.substituteMessageHandlerFor(endpointId, mockMessageHandler); @@ -201,25 +204,22 @@ public void testMockRawHandler() { this.mockIntegrationContext.resetBeans(endpointId); - mockMessageHandler = + MessageHandler mockMessageHandler2 = mockMessageHandler() .handleNextAndReply(m -> m); - try { - this.mockIntegrationContext.substituteMessageHandlerFor(endpointId, mockMessageHandler); - fail("IllegalStateException expected"); - } - catch (Exception e) { - assertThat(e).isInstanceOf(IllegalStateException.class); - assertThat(e.getMessage()).contains("with replies can't replace simple MessageHandler"); - } + + assertThatIllegalStateException() + .isThrownBy(() -> + this.mockIntegrationContext.substituteMessageHandlerFor(endpointId, mockMessageHandler2)) + .withMessageContaining("with replies can't replace simple MessageHandler"); this.mockIntegrationContext.resetBeans(); assertThat(TestUtils.getPropertyValue(endpoint, "handler", MessageHandler.class)) - .isNotSameAs(mockMessageHandler); + .isNotSameAs(mockMessageHandler2); assertThat(TestUtils.getPropertyValue(endpoint, "subscriber", Subscriber.class)) - .isNotSameAs(mockMessageHandler); + .isNotSameAs(mockMessageHandler2); } /** @@ -238,6 +238,37 @@ public void testHandlerSubstitutionWithOutputChannel() { assertThat(list.size()).isEqualTo(2); } + @Autowired + private MessageChannel logChannel; + + @Test + @SuppressWarnings("unchecked") + public void testMockIntegrationContextReset() { + MockMessageHandler mockMessageHandler = mockMessageHandler(); + mockMessageHandler.handleNext(message -> { }); + + this.mockIntegrationContext.substituteMessageHandlerFor("logEndpoint", mockMessageHandler); + + String endpointId = "mockMessageHandlerTests.Config.myService.serviceActivator"; + this.mockIntegrationContext.substituteMessageHandlerFor(endpointId, mockMessageHandler); + + this.logChannel.send(new GenericMessage<>(1)); + + this.mockIntegrationContext.resetBeans("logEndpoint"); + + this.logChannel.send(new GenericMessage<>(2)); + + verify(mockMessageHandler).handleMessage(any(Message.class)); + + assertThat(TestUtils.getPropertyValue(this.mockIntegrationContext, "beans", Map.class)).hasSize(1); + + assertThat( + TestUtils.getPropertyValue( + this.context.getBean("mockMessageHandlerTests.Config.myService.serviceActivator"), "handler")) + .isSameAs(mockMessageHandler); + } + + @Configuration @EnableIntegration public static class Config { @@ -301,6 +332,13 @@ public MessageHandler handleNextInput() { }); } + @Bean + @EndpointId("logEndpoint") + @ServiceActivator(inputChannel = "logChannel") + public MessageHandler logHandler() { + return new LoggingHandler(LoggingHandler.Level.FATAL); + } + } }