diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidConnection.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidConnection.java index 63ef7e9f518..6c84ccb46ea 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidConnection.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidConnection.java @@ -30,9 +30,10 @@ import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.sql.SqlLifecycleFactory; import javax.annotation.concurrent.GuardedBy; -import java.util.HashMap; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -53,10 +54,14 @@ public class DruidConnection private final AtomicInteger statementCounter = new AtomicInteger(); private final AtomicReference> timeoutFuture = new AtomicReference<>(); - @GuardedBy("statements") - private final Map statements; + // Typically synchronized by connectionLock, except in one case: the onClose function passed + // into DruidStatements contained by the map. + private final ConcurrentMap statements; - @GuardedBy("statements") + @GuardedBy("connectionLock") + private final Object connectionLock = new Object(); + + @GuardedBy("connectionLock") private boolean open = true; public DruidConnection(final String connectionId, final int maxStatements, final Map context) @@ -64,14 +69,14 @@ public class DruidConnection this.connectionId = Preconditions.checkNotNull(connectionId); this.maxStatements = maxStatements; this.context = ImmutableMap.copyOf(context); - this.statements = new HashMap<>(); + this.statements = new ConcurrentHashMap<>(); } public DruidStatement createStatement(SqlLifecycleFactory sqlLifecycleFactory) { final int statementId = statementCounter.incrementAndGet(); - synchronized (statements) { + synchronized (connectionLock) { if (statements.containsKey(statementId)) { // Will only happen if statementCounter rolls over before old statements are cleaned up. If this // ever happens then something fishy is going on, because we shouldn't have billions of statements. @@ -96,10 +101,9 @@ public class DruidConnection sqlLifecycleFactory.factorize(), () -> { // onClose function for the statement - synchronized (statements) { - log.debug("Connection[%s] closed statement[%s].", connectionId, statementId); - statements.remove(statementId); - } + log.debug("Connection[%s] closed statement[%s].", connectionId, statementId); + // statements will be accessed unsynchronized to avoid deadlock + statements.remove(statementId); } ); @@ -111,7 +115,7 @@ public class DruidConnection public DruidStatement getStatement(final int statementId) { - synchronized (statements) { + synchronized (connectionLock) { return statements.get(statementId); } } @@ -123,7 +127,7 @@ public class DruidConnection */ public boolean closeIfEmpty() { - synchronized (statements) { + synchronized (connectionLock) { if (statements.isEmpty()) { close(); return true; @@ -135,7 +139,7 @@ public class DruidConnection public void close() { - synchronized (statements) { + synchronized (connectionLock) { // Copy statements before iterating because statement.close() modifies it. for (DruidStatement statement : ImmutableList.copyOf(statements.values())) { try { diff --git a/sql/src/main/java/org/apache/druid/sql/avatica/DruidStatement.java b/sql/src/main/java/org/apache/druid/sql/avatica/DruidStatement.java index 432eb4045db..20cdd45359d 100644 --- a/sql/src/main/java/org/apache/druid/sql/avatica/DruidStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/avatica/DruidStatement.java @@ -300,11 +300,11 @@ public class DruidStatement implements Closeable @Override public void close() { - synchronized (lock) { - final State oldState = state; - state = State.DONE; - - try { + State oldState = null; + try { + synchronized (lock) { + oldState = state; + state = State.DONE; if (yielder != null) { Yielder theYielder = this.yielder; this.yielder = null; @@ -321,33 +321,33 @@ public class DruidStatement implements Closeable yielderOpenCloseExecutor.shutdownNow(); } } - catch (Throwable t) { - if (oldState != State.DONE) { - // First close. Run the onClose function. - try { - onClose.run(); - sqlLifecycle.emitLogsAndMetrics(t, null, -1); - } - catch (Throwable t1) { - t.addSuppressed(t1); - } - } - - throw Throwables.propagate(t); - } - + } + catch (Throwable t) { if (oldState != State.DONE) { // First close. Run the onClose function. try { - if (!(this.throwable instanceof ForbiddenException)) { - sqlLifecycle.emitLogsAndMetrics(this.throwable, null, -1); - } onClose.run(); + sqlLifecycle.emitLogsAndMetrics(t, null, -1); } - catch (Throwable t) { - throw Throwables.propagate(t); + catch (Throwable t1) { + t.addSuppressed(t1); } } + + throw Throwables.propagate(t); + } + + if (oldState != State.DONE) { + // First close. Run the onClose function. + try { + if (!(this.throwable instanceof ForbiddenException)) { + sqlLifecycle.emitLogsAndMetrics(this.throwable, null, -1); + } + onClose.run(); + } + catch (Throwable t) { + throw Throwables.propagate(t); + } } }