[BAEL-1381] Add JPA Examples (#6369)

* BAEL-1381

* [BAEL-1381]

* [BAEL-1381] New module name

* [BAEL-1381] software-security module

* [BAEL-1381] Add JPA examples
This commit is contained in:
psevestre 2019-02-18 02:39:25 -03:00 committed by maibin
parent cf25075e22
commit 6c237aaf3a
11 changed files with 246 additions and 51 deletions

View File

@ -16,6 +16,8 @@
</parent> </parent>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>org.springframework.boot</groupId> <groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-jdbc</artifactId> <artifactId>spring-boot-starter-jdbc</artifactId>
@ -42,6 +44,17 @@
<scope>provided</scope> <scope>provided</scope>
</dependency> </dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-jpa</artifactId>
</dependency>
<dependency>
<groupId>org.hibernate</groupId>
<artifactId>hibernate-jpamodelgen</artifactId>
</dependency>
</dependencies> </dependencies>
<build> <build>

View File

@ -0,0 +1,34 @@
/**
*
*/
package com.baeldung.examples.security.sql;
import java.math.BigDecimal;
import javax.persistence.Entity;
import javax.persistence.GeneratedValue;
import javax.persistence.GenerationType;
import javax.persistence.Id;
import javax.persistence.Table;
import lombok.Data;
/**
* @author Philippe
*
*/
@Entity
@Table(name="Accounts")
@Data
public class Account {
@Id
@GeneratedValue(strategy=GenerationType.IDENTITY)
private Long id;
private String customerId;
private String accNumber;
private String branchId;
private BigDecimal balance;
}

View File

@ -7,14 +7,24 @@ import java.sql.Connection;
import java.sql.PreparedStatement; import java.sql.PreparedStatement;
import java.sql.ResultSet; import java.sql.ResultSet;
import java.sql.SQLException; import java.sql.SQLException;
import java.util.AbstractMap;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import java.util.stream.Stream; import java.util.stream.Stream;
import javax.persistence.EntityManager;
import javax.persistence.Query;
import javax.persistence.TypedQuery;
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Order;
import javax.persistence.criteria.Root;
import javax.persistence.metamodel.SingularAttribute;
import javax.sql.DataSource; import javax.sql.DataSource;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
@ -27,9 +37,11 @@ import org.springframework.stereotype.Component;
public class AccountDAO { public class AccountDAO {
private final DataSource dataSource; private final DataSource dataSource;
private final EntityManager em;
public AccountDAO(DataSource dataSource) { public AccountDAO(DataSource dataSource, EntityManager em) {
this.dataSource = dataSource; this.dataSource = dataSource;
this.em = em;
} }
/** /**
@ -63,6 +75,26 @@ public class AccountDAO {
} }
} }
/**
* Return all accounts owned by a given customer,given his/her external id - JPA version
*
* @param customerId
* @return
*/
public List<AccountDTO> unsafeJpaFindAccountsByCustomerId(String customerId) {
String jql = "from Account where customerId = '" + customerId + "'";
TypedQuery<Account> q = em.createQuery(jql, Account.class);
return q.getResultList()
.stream()
.map(a -> AccountDTO.builder()
.accNumber(a.getAccNumber())
.balance(a.getBalance())
.branchId(a.getAccNumber())
.customerId(a.getCustomerId())
.build())
.collect(Collectors.toList());
}
/** /**
* Return all accounts owned by a given customer,given his/her external id * Return all accounts owned by a given customer,given his/her external id
* *
@ -71,7 +103,7 @@ public class AccountDAO {
*/ */
public List<AccountDTO> safeFindAccountsByCustomerId(String customerId) { public List<AccountDTO> safeFindAccountsByCustomerId(String customerId) {
String sql = "select " + "customer_id,acc_number,branch_id,balance from Accounts where customer_id = ?"; String sql = "select customer_id, branch_id,acc_number,balance from Accounts where customer_id = ?";
try (Connection c = dataSource.getConnection(); PreparedStatement p = c.prepareStatement(sql)) { try (Connection c = dataSource.getConnection(); PreparedStatement p = c.prepareStatement(sql)) {
p.setString(1, customerId); p.setString(1, customerId);
@ -93,9 +125,60 @@ public class AccountDAO {
} }
} }
/**
* Return all accounts owned by a given customer,given his/her external id - JPA version
*
* @param customerId
* @return
*/
public List<AccountDTO> safeJpaFindAccountsByCustomerId(String customerId) {
String jql = "from Account where customerId = :customerId";
TypedQuery<Account> q = em.createQuery(jql, Account.class)
.setParameter("customerId", customerId);
return q.getResultList()
.stream()
.map(a -> AccountDTO.builder()
.accNumber(a.getAccNumber())
.balance(a.getBalance())
.branchId(a.getAccNumber())
.customerId(a.getCustomerId())
.build())
.collect(Collectors.toList());
}
/**
* Return all accounts owned by a given customer,given his/her external id - JPA version
*
* @param customerId
* @return
*/
public List<AccountDTO> safeJpaCriteriaFindAccountsByCustomerId(String customerId) {
CriteriaBuilder cb = em.getCriteriaBuilder();
CriteriaQuery<Account> cq = cb.createQuery(Account.class);
Root<Account> root = cq.from(Account.class);
cq.select(root)
.where(cb.equal(root.get(Account_.customerId), customerId));
TypedQuery<Account> q = em.createQuery(cq);
return q.getResultList()
.stream()
.map(a -> AccountDTO.builder()
.accNumber(a.getAccNumber())
.balance(a.getBalance())
.branchId(a.getAccNumber())
.customerId(a.getCustomerId())
.build())
.collect(Collectors.toList());
}
private static final Set<String> VALID_COLUMNS_FOR_ORDER_BY = Stream.of("acc_number", "branch_id", "balance") private static final Set<String> VALID_COLUMNS_FOR_ORDER_BY = Stream.of("acc_number", "branch_id", "balance")
.collect(Collectors.toCollection(HashSet::new)); .collect(Collectors.toCollection(HashSet::new));
/** /**
* Return all accounts owned by a given customer,given his/her external id * Return all accounts owned by a given customer,given his/her external id
* *
@ -108,8 +191,7 @@ public class AccountDAO {
if (VALID_COLUMNS_FOR_ORDER_BY.contains(orderBy)) { if (VALID_COLUMNS_FOR_ORDER_BY.contains(orderBy)) {
sql = sql + " order by " + orderBy; sql = sql + " order by " + orderBy;
} } else {
else {
throw new IllegalArgumentException("Nice try!"); throw new IllegalArgumentException("Nice try!");
} }
@ -135,35 +217,82 @@ public class AccountDAO {
} }
} }
private static final Map<String,SingularAttribute<Account,?>> VALID_JPA_COLUMNS_FOR_ORDER_BY = Stream.of(
new AbstractMap.SimpleEntry<>(Account_.ACC_NUMBER, Account_.accNumber),
new AbstractMap.SimpleEntry<>(Account_.BRANCH_ID, Account_.branchId),
new AbstractMap.SimpleEntry<>(Account_.BALANCE, Account_.balance)
)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
/**
* Return all accounts owned by a given customer,given his/her external id
*
* @param customerId
* @return
*/
public List<AccountDTO> safeJpaFindAccountsByCustomerId(String customerId, String orderBy) {
SingularAttribute<Account,?> orderByAttribute = VALID_JPA_COLUMNS_FOR_ORDER_BY.get(orderBy);
if ( orderByAttribute == null) {
throw new IllegalArgumentException("Nice try!");
}
CriteriaBuilder cb = em.getCriteriaBuilder();
CriteriaQuery<Account> cq = cb.createQuery(Account.class);
Root<Account> root = cq.from(Account.class);
cq.select(root)
.where(cb.equal(root.get(Account_.customerId), customerId))
.orderBy(cb.asc(root.get(orderByAttribute)));
TypedQuery<Account> q = em.createQuery(cq);
return q.getResultList()
.stream()
.map(a -> AccountDTO.builder()
.accNumber(a.getAccNumber())
.balance(a.getBalance())
.branchId(a.getAccNumber())
.customerId(a.getCustomerId())
.build())
.collect(Collectors.toList());
}
/** /**
* Invalid placeholder usage example * Invalid placeholder usage example
* *
* @param tableName * @param tableName
* @return * @return
*/ */
public List<AccountDTO> wrongCountRecordsByTableName(String tableName) { public Long wrongCountRecordsByTableName(String tableName) {
try (Connection c = dataSource.getConnection(); try (Connection c = dataSource.getConnection(); PreparedStatement p = c.prepareStatement("select count(*) from ?")) {
PreparedStatement p = c.prepareStatement("select count(*) from ?")) {
p.setString(1, tableName); p.setString(1, tableName);
ResultSet rs = p.executeQuery(); ResultSet rs = p.executeQuery();
List<AccountDTO> accounts = new ArrayList<>(); rs.next();
while (rs.next()) { return rs.getLong(1);
AccountDTO acc = AccountDTO.builder()
.customerId(rs.getString("customerId"))
.branchId(rs.getString("branch_id"))
.accNumber(rs.getString("acc_number"))
.balance(rs.getBigDecimal("balance"))
.build();
accounts.add(acc);
}
return accounts;
} catch (SQLException ex) { } catch (SQLException ex) {
throw new RuntimeException(ex); throw new RuntimeException(ex);
} }
} }
/**
* Invalid placeholder usage example - JPA
*
* @param tableName
* @return
*/
public Long wrongJpaCountRecordsByTableName(String tableName) {
String jql = "select count(*) from :tableName";
TypedQuery<Long> q = em.createQuery(jql, Long.class)
.setParameter("tableName", tableName);
return q.getSingleResult();
}
} }

View File

@ -1,19 +0,0 @@
<databaseChangeLog
xmlns="http://www.liquibase.org/xml/ns/dbchangelog"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog
http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-3.1.xsd">
<changeSet id="create-tables" author="baeldung">
<createTable tableName="Accounts" >
<column name="id" autoIncrement="true" type="BIGINT" remarks="Internal account PK" >
<constraints primaryKey="true"/>
</column>
<column name="customer_id" type="java.sql.Types.VARCHAR(32)" remarks="External Customer Id"></column>
<column name="acc_number" type="java.sql.Types.VARCHAR(128)" remarks="External Account Number"></column>
<column name="branch_id" type="java.sql.Types.VARCHAR(32)"></column>
<column name="balance" type="CURRENCY"></column>
</createTable>
</changeSet>
</databaseChangeLog>

View File

@ -1,8 +0,0 @@
<databaseChangeLog
xmlns="http://www.liquibase.org/xml/ns/dbchangelog"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog
http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-3.1.xsd">
<include file="changelog/create-tables.xml" relativeToChangelogFile="true"/>
</databaseChangeLog>

View File

@ -40,6 +40,15 @@ public class SqlInjectionSamplesApplicationUnitTest {
assertThat(accounts).hasSize(3); assertThat(accounts).hasSize(3);
} }
@Test
public void givenAVulnerableJpaMethod_whenHackedCustomerId_thenReturnAllAccounts() {
List<AccountDTO> accounts = target.unsafeJpaFindAccountsByCustomerId("C1' or '1'='1");
assertThat(accounts).isNotNull();
assertThat(accounts).isNotEmpty();
assertThat(accounts).hasSize(3);
}
@Test @Test
public void givenASafeMethod_whenHackedCustomerId_thenReturnNoAccounts() { public void givenASafeMethod_whenHackedCustomerId_thenReturnNoAccounts() {
@ -48,13 +57,36 @@ public class SqlInjectionSamplesApplicationUnitTest {
assertThat(accounts).isEmpty(); assertThat(accounts).isEmpty();
} }
@Test
public void givenASafeJpaMethod_whenHackedCustomerId_thenReturnNoAccounts() {
List<AccountDTO> accounts = target.safeJpaFindAccountsByCustomerId("C1' or '1'='1");
assertThat(accounts).isNotNull();
assertThat(accounts).isEmpty();
}
@Test
public void givenASafeJpaCriteriaMethod_whenHackedCustomerId_thenReturnNoAccounts() {
List<AccountDTO> accounts = target.safeJpaCriteriaFindAccountsByCustomerId("C1' or '1'='1");
assertThat(accounts).isNotNull();
assertThat(accounts).isEmpty();
}
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void givenASafeMethod_whenInvalidOrderBy_thenThroweException() { public void givenASafeMethod_whenInvalidOrderBy_thenThroweException() {
target.safeFindAccountsByCustomerId("C1", "INVALID"); target.safeFindAccountsByCustomerId("C1", "INVALID");
} }
@Test(expected = RuntimeException.class) @Test(expected = Exception.class)
public void givenWrongPlaceholderUsageMethod_whenNormalCall_thenThrowsException() { public void givenWrongPlaceholderUsageMethod_whenNormalCall_thenThrowsException() {
target.wrongCountRecordsByTableName("Accounts"); target.wrongCountRecordsByTableName("Accounts");
} }
@Test(expected = Exception.class)
public void givenWrongJpaPlaceholderUsageMethod_whenNormalCall_thenThrowsException() {
target.wrongJpaCountRecordsByTableName("Accounts");
}
} }

View File

@ -2,5 +2,17 @@
# Test profile configuration # Test profile configuration
# #
spring: spring:
liquibase:
change-log: db/changelog/db.changelog-master.xml
jpa:
hibernate:
ddl-auto: none
datasource: datasource:
initialization-mode: always initialization-mode: embedded
logging:
level:
sql: DEBUG

View File

@ -1,4 +1,5 @@
create table Accounts ( create table Accounts (
id BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1),
customer_id varchar(16) not null, customer_id varchar(16) not null,
acc_number varchar(16) not null, acc_number varchar(16) not null,
branch_id decimal(8,0), branch_id decimal(8,0),