diff --git a/persistence-modules/spring-persistence-simple-2/src/main/java/com/baeldung/jdbc/Employee.java b/persistence-modules/spring-persistence-simple-2/src/main/java/com/baeldung/jdbc/Employee.java index a43eb265c7..d5f87ca3df 100644 --- a/persistence-modules/spring-persistence-simple-2/src/main/java/com/baeldung/jdbc/Employee.java +++ b/persistence-modules/spring-persistence-simple-2/src/main/java/com/baeldung/jdbc/Employee.java @@ -9,6 +9,13 @@ public class Employee { private String address; + public Employee(int id, String firstName, String lastName, String address) { + setId(id); + setFirstName(firstName); + setLastName(lastName); + setAddress(address); + } + public int getId() { return id; } diff --git a/persistence-modules/spring-persistence-simple-2/src/main/java/com/baeldung/jdbc/EmployeeDAO.java b/persistence-modules/spring-persistence-simple-2/src/main/java/com/baeldung/jdbc/EmployeeDAO.java index b5bf9452ed..dec88ee1f6 100644 --- a/persistence-modules/spring-persistence-simple-2/src/main/java/com/baeldung/jdbc/EmployeeDAO.java +++ b/persistence-modules/spring-persistence-simple-2/src/main/java/com/baeldung/jdbc/EmployeeDAO.java @@ -1,21 +1,62 @@ package com.baeldung.jdbc; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + import javax.sql.DataSource; import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; +import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate; +import org.springframework.jdbc.core.namedparam.SqlParameterSource; import org.springframework.stereotype.Repository; @Repository public class EmployeeDAO { private JdbcTemplate jdbcTemplate; + private NamedParameterJdbcTemplate namedJdbcTemplate; public void setDataSource(DataSource dataSource) { jdbcTemplate = new JdbcTemplate(dataSource); + namedJdbcTemplate = new NamedParameterJdbcTemplate(dataSource); } public int getCountOfEmployees() { return jdbcTemplate.queryForObject("SELECT COUNT(*) FROM EMPLOYEE", Integer.class); } - + public List getEmployeesFromIdListNamed(List ids) { + SqlParameterSource parameters = new MapSqlParameterSource("ids", ids); + List employees = namedJdbcTemplate.query("SELECT * FROM EMPLOYEE WHERE id IN (:ids)", + parameters, + (rs, rowNum) -> new Employee(rs.getInt("id"), rs.getString("first_name"), rs.getString("last_name"), rs.getString("address"))); + + return employees; + } + + public List getEmployeesFromIdList(List ids) { + String inSql = String.join(",", Collections.nCopies(ids.size(), "?")); + List employees = jdbcTemplate.query("SELECT * FROM EMPLOYEE WHERE id IN (" + inSql +")", + ids.toArray(), + (rs, rowNum) -> new Employee(rs.getInt("id"), rs.getString("first_name"), rs.getString("last_name"), rs.getString("address"))); + + return employees; + } + + public List getEmployeesFromLargeIdList(List ids) { + jdbcTemplate.execute("CREATE TEMPORARY TABLE employee_tmp (id INT NOT NULL)"); + + List employeeIds = new ArrayList<>(); + for (Integer id : ids) { + employeeIds.add(new Object[] { id }); + } + jdbcTemplate.batchUpdate("INSERT INTO employee_tmp VALUES(?)", employeeIds); + + List employees = jdbcTemplate.query("SELECT * FROM EMPLOYEE WHERE id IN (SELECT id FROM employee_tmp)", + (rs, rowNum) -> new Employee(rs.getInt("id"), rs.getString("first_name"), rs.getString("last_name"), rs.getString("address"))); + + return employees; + } + } diff --git a/persistence-modules/spring-persistence-simple-2/src/test/java/com/baeldung/jdbc/EmployeeDAOUnitTest.java b/persistence-modules/spring-persistence-simple-2/src/test/java/com/baeldung/jdbc/EmployeeDAOUnitTest.java index 71e8fb4263..f21704221b 100644 --- a/persistence-modules/spring-persistence-simple-2/src/test/java/com/baeldung/jdbc/EmployeeDAOUnitTest.java +++ b/persistence-modules/spring-persistence-simple-2/src/test/java/com/baeldung/jdbc/EmployeeDAOUnitTest.java @@ -2,8 +2,12 @@ package com.baeldung.jdbc; import static org.junit.jupiter.api.Assertions.assertEquals; +import java.util.ArrayList; +import java.util.List; + import javax.sql.DataSource; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; @@ -14,13 +18,24 @@ import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; import org.springframework.test.util.ReflectionTestUtils; - @RunWith(MockitoJUnitRunner.class) public class EmployeeDAOUnitTest { @Mock JdbcTemplate jdbcTemplate; + DataSource dataSource; + + @Before + public void setup() { + dataSource = new EmbeddedDatabaseBuilder().setType(EmbeddedDatabaseType.H2) + .generateUniqueName(true) + .addScript("classpath:jdbc/schema.sql") + .addScript("classpath:jdbc/test-data.sql") + .build(); + + } + @Test public void whenMockJdbcTemplate_thenReturnCorrectEmployeeCount() { EmployeeDAO employeeDAO = new EmployeeDAO(); @@ -38,14 +53,50 @@ public class EmployeeDAOUnitTest { @Test public void whenInjectInMemoryDataSource_thenReturnCorrectEmployeeCount() { - DataSource dataSource = new EmbeddedDatabaseBuilder().setType(EmbeddedDatabaseType.H2) - .addScript("classpath:jdbc/schema.sql") - .addScript("classpath:jdbc/test-data.sql") - .build(); - EmployeeDAO employeeDAO = new EmployeeDAO(); employeeDAO.setDataSource(dataSource); assertEquals(4, employeeDAO.getCountOfEmployees()); } + + @Test + public void givenSmallIdList_whenGetEmployeesFromIdList_thenReturnCorrectEmployees() { + List ids = new ArrayList<>(); + ids.add(1); + ids.add(3); + ids.add(4); + EmployeeDAO employeeDAO = new EmployeeDAO(); + employeeDAO.setDataSource(dataSource); + + List employees = employeeDAO.getEmployeesFromIdList(ids); + + assertEquals(3, employees.size()); + assertEquals(1, employees.get(0).getId()); + assertEquals(3, employees.get(1).getId()); + assertEquals(4, employees.get(2).getId()); + + employees = employeeDAO.getEmployeesFromIdListNamed(ids); + + assertEquals(3, employees.size()); + assertEquals(1, employees.get(0).getId()); + assertEquals(3, employees.get(1).getId()); + assertEquals(4, employees.get(2).getId()); + } + + @Test + public void givenLargeIdList_whenGetEmployeesFromIdList_thenReturnCorrectEmployees() { + List ids = new ArrayList<>(); + ids.add(1); + ids.add(3); + ids.add(4); + EmployeeDAO employeeDAO = new EmployeeDAO(); + employeeDAO.setDataSource(dataSource); + + List employees = employeeDAO.getEmployeesFromLargeIdList(ids); + + assertEquals(3, employees.size()); + assertEquals(1, employees.get(0).getId()); + assertEquals(3, employees.get(1).getId()); + assertEquals(4, employees.get(2).getId()); + } }