diff --git a/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/SpringBootLambdaContainerHandler.java b/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/SpringBootLambdaContainerHandler.java index 6b3dd5ca2..45bf28efe 100644 --- a/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/SpringBootLambdaContainerHandler.java +++ b/aws-serverless-java-container-springboot3/src/main/java/com/amazonaws/serverless/proxy/spring/SpringBootLambdaContainerHandler.java @@ -18,7 +18,7 @@ import org.slf4j.LoggerFactory; import org.springframework.boot.WebApplicationType; import org.springframework.boot.builder.SpringApplicationBuilder; -import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext; +import org.springframework.web.context.ConfigurableWebApplicationContext; import org.springframework.context.ConfigurableApplicationContext; import com.amazonaws.serverless.exceptions.ContainerInitializationException; @@ -184,20 +184,26 @@ protected void handleRequest(HttpServletRequest containerRequest, AwsHttpServlet Timer.stop("SPRINGBOOT2_HANDLE_REQUEST"); } + SpringApplicationBuilder getSpringApplicationBuilder(Class... sources) { + return new SpringApplicationBuilder(sources); + } @Override public void initialize() throws ContainerInitializationException { Timer.start("SPRINGBOOT2_COLD_START"); - SpringApplicationBuilder builder = new SpringApplicationBuilder(getEmbeddedContainerClasses()) + SpringApplicationBuilder builder = getSpringApplicationBuilder(getEmbeddedContainerClasses()) .web(springWebApplicationType); // .REACTIVE, .SERVLET + if(springBootInitializer != null) { + builder.main(springBootInitializer); + } if (springProfiles != null) { builder.profiles(springProfiles); } applicationContext = builder.run(); if (springWebApplicationType == WebApplicationType.SERVLET) { - ((AnnotationConfigServletWebServerApplicationContext)applicationContext).setServletContext(getServletContext()); + ((ConfigurableWebApplicationContext)applicationContext).setServletContext(getServletContext()); AwsServletRegistration reg = (AwsServletRegistration)getServletContext().getServletRegistration(DISPATCHER_SERVLET_REGISTRATION_NAME); if (reg != null) { reg.setLoadOnStartup(1); diff --git a/aws-serverless-java-container-springboot3/src/test/java/com/amazonaws/serverless/proxy/spring/SpringBootLambdaContainerHandlerTest.java b/aws-serverless-java-container-springboot3/src/test/java/com/amazonaws/serverless/proxy/spring/SpringBootLambdaContainerHandlerTest.java new file mode 100644 index 000000000..871a751bb --- /dev/null +++ b/aws-serverless-java-container-springboot3/src/test/java/com/amazonaws/serverless/proxy/spring/SpringBootLambdaContainerHandlerTest.java @@ -0,0 +1,76 @@ +package com.amazonaws.serverless.proxy.spring; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.amazonaws.serverless.exceptions.ContainerInitializationException; +import com.amazonaws.serverless.proxy.AwsProxyExceptionHandler; +import com.amazonaws.serverless.proxy.AwsProxySecurityContextWriter; +import com.amazonaws.serverless.proxy.InitializationWrapper; +import com.amazonaws.serverless.proxy.internal.servlet.AwsProxyHttpServletRequestReader; +import com.amazonaws.serverless.proxy.internal.servlet.AwsProxyHttpServletResponseWriter; +import com.amazonaws.serverless.proxy.model.AwsProxyRequest; +import com.amazonaws.serverless.proxy.model.AwsProxyResponse; +import com.amazonaws.serverless.proxy.spring.servletapp.ServletApplication; +import com.amazonaws.serverless.proxy.spring.webfluxapp.WebFluxTestApplication; +import java.util.Collection; +import java.util.List; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import org.springframework.boot.WebApplicationType; +import org.springframework.boot.builder.SpringApplicationBuilder; + +class SpringBootLambdaContainerHandlerTest { + + SpringBootLambdaContainerHandler handler; + SpringApplicationBuilder springApplicationBuilder; + + public static Collection data() { + return List.of(new TestData(WebApplicationType.SERVLET, ServletApplication.class), + new TestData(WebApplicationType.REACTIVE, WebFluxTestApplication.class)); + } + + private void initSpringBootLambdaContainerHandlerTest(Class springBootInitializer, + WebApplicationType applicationType) { + handler = Mockito.spy(new SpringBootLambdaContainerHandler<>(AwsProxyRequest.class, + AwsProxyResponse.class, + new AwsProxyHttpServletRequestReader(), + new AwsProxyHttpServletResponseWriter(), + new AwsProxySecurityContextWriter(), + new AwsProxyExceptionHandler(), + springBootInitializer, + new InitializationWrapper(), + applicationType)); + + doAnswer(d -> { + springApplicationBuilder = ((SpringApplicationBuilder) Mockito.spy(d.callRealMethod())); + return springApplicationBuilder; + }).when(handler).getSpringApplicationBuilder(any(Class[].class)); + } + + @ParameterizedTest + @MethodSource("data") + void initialize_withSpringBootInitializer(TestData data) throws ContainerInitializationException { + initSpringBootLambdaContainerHandlerTest(data.springBootApplication(), data.applicationType()); + handler.initialize(); + + verify(springApplicationBuilder, times(1)).main(data.springBootApplication()); + } + + @ParameterizedTest + @EnumSource(WebApplicationType.class) + void initialize_withoutSpringBootInitializer(WebApplicationType webApplicationType) { + initSpringBootLambdaContainerHandlerTest(null, webApplicationType); + assertThrows(IllegalArgumentException.class, handler::initialize, "Source must not be null"); + + verify(springApplicationBuilder, never()).main(any()); + } + + record TestData(WebApplicationType applicationType, Class springBootApplication) {} +} \ No newline at end of file