Segregate advance and advanceUninterruptibly flows in PostJoinCursor to allow for interrupts in advance (#15222)

Currently, the advance() function in PostJoinCursor calls advanceUninterruptibly(), which in turn keeps calling baseCursor.advanceUninterruptibly() until the post-join condition matches, without ever checking for interrupts. This drives the CPU to 100% and never gives the query a chance to be cancelled.

With this change, the call flows of advance() and advanceUninterruptibly() are separated so that they call baseCursor.advance() and baseCursor.advanceUninterruptibly() respectively, giving advance() a chance to honor interrupts between successive calls to baseCursor.advance(), as sketched below.
Vishesh Garg 2023-10-30 14:39:15 +05:30 committed by GitHub
parent 275c1ec64c
commit a27598a487
2 changed files with 292 additions and 5 deletions
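For illustration, here is a minimal sketch of the two flows after this change. It uses simplified stand-in types (SimpleCursor and a BooleanSupplier in place of Druid's Cursor and ValueMatcher) and mirrors the diff below, but it is not the actual implementation.

// Minimal sketch of the separated flows, assuming a simplified cursor contract in which
// advance() checks the thread's interrupt status and advanceUninterruptibly() does not.
import java.util.function.BooleanSupplier;

interface SimpleCursor
{
  void advance();                // assumed to check for thread interrupts internally
  void advanceUninterruptibly(); // assumed to never check for interrupts
  boolean isDone();
}

class PostJoinCursorSketch
{
  private final SimpleCursor baseCursor;
  private final BooleanSupplier postJoinMatcher; // stand-in for the post-join ValueMatcher

  PostJoinCursorSketch(SimpleCursor baseCursor, BooleanSupplier postJoinMatcher)
  {
    this.baseCursor = baseCursor;
    this.postJoinMatcher = postJoinMatcher;
  }

  // Interruptible path: every baseCursor.advance() gives the query a chance to be cancelled.
  public void advance()
  {
    baseCursor.advance();
    while (!baseCursor.isDone() && !postJoinMatcher.getAsBoolean()) {
      baseCursor.advance();
    }
  }

  // Uninterruptible path: preserves the previous behavior and never checks for interrupts.
  public void advanceUninterruptibly()
  {
    baseCursor.advanceUninterruptibly();
    while (!baseCursor.isDone() && !postJoinMatcher.getAsBoolean()) {
      baseCursor.advanceUninterruptibly();
    }
  }
}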

PostJoinCursor.java
@@ -19,7 +19,7 @@
 package org.apache.druid.segment.join;
 
-import org.apache.druid.query.BaseQuery;
+import com.google.common.annotations.VisibleForTesting;
 import org.apache.druid.query.filter.Filter;
 import org.apache.druid.query.filter.ValueMatcher;
 import org.apache.druid.segment.ColumnSelectorFactory;
@@ -39,7 +39,7 @@ public class PostJoinCursor implements Cursor
   private final ColumnSelectorFactory columnSelectorFactory;
 
   @Nullable
-  private final ValueMatcher valueMatcher;
+  private ValueMatcher valueMatcher;
 
   @Nullable
   private final Filter postJoinFilter;
@@ -69,7 +69,28 @@ public class PostJoinCursor implements Cursor
     return postJoinCursor;
   }
 
+  @VisibleForTesting
+  public void setValueMatcher(@Nullable ValueMatcher valueMatcher)
+  {
+    this.valueMatcher = valueMatcher;
+  }
+
+  private void advanceToMatch()
+  {
+    if (valueMatcher != null) {
+      while (!isDone() && !valueMatcher.matches(false)) {
+        baseCursor.advance();
+      }
+    }
+  }
+
+  /**
+   * Matches tuples coming out of a join to a post-join condition uninterruptibly, and hence can be a long-running call.
+   * For this reason, {@link PostJoinCursor#advance()} instead calls {@link PostJoinCursor#advanceToMatch()} (unlike
+   * other cursors) that allows interruptions, thereby resolving issues where the
+   * <a href="https://github.com/apache/druid/issues/14514">CPU thread running PostJoinCursor cannot be terminated</a>
+   */
-  private void advanceToMatch()
+  private void advanceToMatchUninterruptibly()
   {
     if (valueMatcher != null) {
       while (!isDone() && !valueMatcher.matches(false)) {
@@ -99,15 +120,17 @@ public class PostJoinCursor implements Cursor
   @Override
   public void advance()
   {
-    advanceUninterruptibly();
-    BaseQuery.checkInterrupted();
+    baseCursor.advance();
+    // Relies on baseCursor.advance() call inside this for BaseQuery.checkInterrupted() checks -- unlike other cursors
+    // which call advanceInterruptibly() and hence have to explicitly provision for interrupts.
+    advanceToMatch();
   }
 
   @Override
   public void advanceUninterruptibly()
   {
     baseCursor.advanceUninterruptibly();
-    advanceToMatch();
+    advanceToMatchUninterruptibly();
   }
 
   @Override
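For context on how the cancellation actually surfaces: the new advance() leans on the interrupt check performed inside baseCursor.advance(). In Druid that check is BaseQuery.checkInterrupted(), which throws QueryInterruptedException when the thread has been interrupted (the exception the new test below asserts on). A rough sketch of that pattern, with a local stand-in exception type:

// Hedged sketch of the interrupt check assumed to run inside baseCursor.advance().
// SketchQueryInterruptedException is a stand-in for Druid's QueryInterruptedException.
final class InterruptCheckSketch
{
  static class SketchQueryInterruptedException extends RuntimeException
  {
    SketchQueryInterruptedException(Throwable cause)
    {
      super(cause);
    }
  }

  static void checkInterrupted()
  {
    // Thread.interrupted() clears and returns the calling thread's interrupt flag.
    if (Thread.interrupted()) {
      throw new SketchQueryInterruptedException(new InterruptedException());
    }
  }
}

Because this check runs on every interruptible advance of the base cursor, the post-join matching loop can now exit promptly when a query is cancelled, while advanceUninterruptibly() deliberately keeps the old uninterrupted behavior.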

PostJoinCursorTest.java
@@ -0,0 +1,264 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.segment.join;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryMetrics;
import org.apache.druid.query.filter.Filter;
import org.apache.druid.query.filter.ValueMatcher;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.Cursor;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.QueryableIndexSegment;
import org.apache.druid.segment.QueryableIndexStorageAdapter;
import org.apache.druid.segment.StorageAdapter;
import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.join.filter.JoinFilterPreAnalysis;
import org.apache.druid.timeline.SegmentId;
import org.joda.time.DateTime;
import org.joda.time.Interval;
import org.junit.Test;
import javax.annotation.Nullable;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static java.lang.Thread.sleep;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
public class PostJoinCursorTest extends BaseHashJoinSegmentStorageAdapterTest
{
  public QueryableIndexSegment infiniteFactSegment;

  /**
   * Simulates an infinite segment by using a base cursor whose advance() and advanceUninterruptibly()
   * are reduced to no-ops.
   */
  private static class TestInfiniteQueryableIndexSegment extends QueryableIndexSegment
  {
    private static class InfiniteQueryableIndexStorageAdapter extends QueryableIndexStorageAdapter
    {
      CountDownLatch countDownLatch;

      public InfiniteQueryableIndexStorageAdapter(QueryableIndex index, CountDownLatch countDownLatch)
      {
        super(index);
        this.countDownLatch = countDownLatch;
      }

      @Override
      public Sequence<Cursor> makeCursors(
          @Nullable Filter filter,
          Interval interval,
          VirtualColumns virtualColumns,
          Granularity gran,
          boolean descending,
          @Nullable QueryMetrics<?> queryMetrics
      )
      {
        return super.makeCursors(filter, interval, virtualColumns, gran, descending, queryMetrics)
                    .map(cursor -> new CursorNoAdvance(cursor, countDownLatch));
      }

      private static class CursorNoAdvance implements Cursor
      {
        Cursor cursor;
        CountDownLatch countDownLatch;

        public CursorNoAdvance(Cursor cursor, CountDownLatch countDownLatch)
        {
          this.cursor = cursor;
          this.countDownLatch = countDownLatch;
        }

        @Override
        public ColumnSelectorFactory getColumnSelectorFactory()
        {
          return cursor.getColumnSelectorFactory();
        }

        @Override
        public DateTime getTime()
        {
          return cursor.getTime();
        }

        @Override
        public void advance()
        {
          // Do nothing to simulate infinite rows
          countDownLatch.countDown();
        }

        @Override
        public void advanceUninterruptibly()
        {
          // Do nothing to simulate infinite rows
          countDownLatch.countDown();
        }

        @Override
        public boolean isDone()
        {
          return false;
        }

        @Override
        public boolean isDoneOrInterrupted()
        {
          return cursor.isDoneOrInterrupted();
        }

        @Override
        public void reset()
        {
        }
      }
    }

    private final StorageAdapter testStorageAdaptor;

    public TestInfiniteQueryableIndexSegment(QueryableIndex index, SegmentId segmentId, CountDownLatch countDownLatch)
    {
      super(index, segmentId);
      testStorageAdaptor = new InfiniteQueryableIndexStorageAdapter(index, countDownLatch);
    }

    @Override
    public StorageAdapter asStorageAdapter()
    {
      return testStorageAdaptor;
    }
  }

  private static class ExceptionHandler implements Thread.UncaughtExceptionHandler
  {
    Throwable exception;

    @Override
    public void uncaughtException(Thread t, Throwable e)
    {
      exception = e;
    }

    public Throwable getException()
    {
      return exception;
    }
  }

  @Test
  public void testAdvanceWithInterruption() throws IOException, InterruptedException
  {
    final int rowsBeforeInterrupt = 1000;
    CountDownLatch countDownLatch = new CountDownLatch(rowsBeforeInterrupt);

    infiniteFactSegment = new TestInfiniteQueryableIndexSegment(
        JoinTestHelper.createFactIndexBuilder(temporaryFolder.newFolder()).buildMMappedIndex(),
        SegmentId.dummy("facts"),
        countDownLatch
    );

    countriesTable = JoinTestHelper.createCountriesIndexedTable();

    Thread joinCursorThread = new Thread(() -> makeCursorAndAdvance());
    ExceptionHandler exceptionHandler = new ExceptionHandler();
    joinCursorThread.setUncaughtExceptionHandler(exceptionHandler);
    joinCursorThread.start();

    // Wait up to 1 second for the cursor to perform rowsBeforeInterrupt no-op advances before interrupting.
    countDownLatch.await(1, TimeUnit.SECONDS);
    joinCursorThread.interrupt();

    // Wait for a max of 1 sec for the exception to be set.
    for (int i = 0; i < 1000; i++) {
      if (exceptionHandler.getException() == null) {
        sleep(1);
      } else {
        assertTrue(exceptionHandler.getException() instanceof QueryInterruptedException);
        return;
      }
    }
    fail();
  }

  public void makeCursorAndAdvance()
  {
    List<JoinableClause> joinableClauses = ImmutableList.of(
        factToCountryOnIsoCode(JoinType.LEFT)
    );

    JoinFilterPreAnalysis joinFilterPreAnalysis = makeDefaultConfigPreAnalysis(
        null,
        joinableClauses,
        VirtualColumns.EMPTY
    );

    HashJoinSegmentStorageAdapter hashJoinSegmentStorageAdapter = new HashJoinSegmentStorageAdapter(
        infiniteFactSegment.asStorageAdapter(),
        joinableClauses,
        joinFilterPreAnalysis
    );

    Cursor cursor = Iterables.getOnlyElement(hashJoinSegmentStorageAdapter.makeCursors(
        null,
        Intervals.ETERNITY,
        VirtualColumns.EMPTY,
        Granularities.ALL,
        false,
        null
    ).toList());

    ((PostJoinCursor) cursor).setValueMatcher(new ValueMatcher()
    {
      @Override
      public boolean matches(boolean includeUnknown)
      {
        // Never match, so advanceToMatch() keeps advancing the (infinite) base cursor.
        return false;
      }

      @Override
      public void inspectRuntimeShape(RuntimeShapeInspector inspector)
      {
      }
    });

    cursor.advance();
  }
}
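To make the mechanics of the test easier to follow, here is a self-contained illustration (plain Java, no Druid APIs; all names are hypothetical) of the scenario it constructs: a matcher that never matches plus a base cursor that never finishes forces the matching loop to spin until the worker thread is interrupted.

// Self-contained sketch of the test's scenario: an endless matching loop that only a
// thread interrupt can stop. Names below are local stand-ins, not Druid classes.
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public final class EndlessMatchLoopSketch
{
  public static void main(String[] args) throws InterruptedException
  {
    CountDownLatch spinning = new CountDownLatch(1000); // analogous to rowsBeforeInterrupt

    Thread worker = new Thread(() -> {
      // Mirrors the post-join matching loop: "isDone" is always false and the matcher never
      // matches, so the loop can only end when an interrupt check fires.
      while (true) {
        if (Thread.interrupted()) {
          return; // in Druid this surfaces as QueryInterruptedException instead
        }
        spinning.countDown(); // stands in for the no-op advance() in CursorNoAdvance
      }
    });

    worker.start();
    spinning.await(1, TimeUnit.SECONDS); // let the loop demonstrably spin for a while
    worker.interrupt();                  // analogous to joinCursorThread.interrupt()
    worker.join(1000);
    System.out.println("worker terminated: " + !worker.isAlive());
  }
}

Before this change, the analogous loop in PostJoinCursor had no interrupt check at all, which is why the interrupt issued in the test above previously could not terminate it.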