Intern RowSignature in DruidSchema to reduce its memory footprint (#12001)

DruidSchema holds a concurrent HashMap of DataSource -> Segment -> AvailableSegmentMetadata. Each AvailableSegmentMetadata contains the RowSignature of its segment, and a new RowSignature object is created for every segment. Because RowSignature is an immutable class, it can be interned; this can yield large memory savings on the broker, since many segments of a table are likely to share the same RowSignature.
This commit is contained in:
Laksh Singla 2021-12-08 15:11:13 +05:30 committed by GitHub
parent 45be2be368
commit ca260dfef6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 261 additions and 16 deletions

View File

@ -0,0 +1,216 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.benchmark;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.apache.druid.client.BrokerInternalQueryConfig;
import org.apache.druid.client.TimelineServerView;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.query.metadata.metadata.ColumnAnalysis;
import org.apache.druid.query.metadata.metadata.SegmentAnalysis;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.join.JoinableFactory;
import org.apache.druid.server.QueryLifecycleFactory;
import org.apache.druid.server.SegmentManager;
import org.apache.druid.server.coordination.DruidServerMetadata;
import org.apache.druid.server.coordination.ServerType;
import org.apache.druid.server.security.Escalator;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
import org.apache.druid.sql.calcite.schema.DruidSchema;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.SegmentId;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.easymock.EasyMock;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
@State(Scope.Benchmark)
@Fork(value = 1)
@Warmup(iterations = 3)
@Measurement(iterations = 10)
public class DruidSchemaInternRowSignatureBenchmark
{
private DruidSchemaForBenchmark druidSchema;
private static class DruidSchemaForBenchmark extends DruidSchema
{
public DruidSchemaForBenchmark(
final QueryLifecycleFactory queryLifecycleFactory,
final TimelineServerView serverView,
final SegmentManager segmentManager,
final JoinableFactory joinableFactory,
final PlannerConfig config,
final Escalator escalator,
final BrokerInternalQueryConfig brokerInternalQueryConfig
)
{
super(
queryLifecycleFactory,
serverView,
segmentManager,
joinableFactory,
config,
escalator,
brokerInternalQueryConfig
);
}
// Overriding here so that it can be called explicitly to benchmark
@Override
public Set<SegmentId> refreshSegments(final Set<SegmentId> segments) throws IOException
{
return super.refreshSegments(segments);
}
@Override
public void addSegment(final DruidServerMetadata server, final DataSegment segment)
{
super.addSegment(server, segment);
}
@Override
protected Sequence<SegmentAnalysis> runSegmentMetadataQuery(Iterable<SegmentId> segments)
{
final int numColumns = 1000;
Map<String, ColumnAnalysis> columnToAnalysisMap = new HashMap<>();
for (int i = 0; i < numColumns; ++i) {
columnToAnalysisMap.put(
"col" + i,
new ColumnAnalysis(
ColumnType.STRING,
null,
false,
false,
40,
null,
null,
null,
null
)
);
}
return Sequences.simple(
Lists.transform(
Lists.newArrayList(segments),
(segment) -> new SegmentAnalysis(
segment.toString(),
ImmutableList.of(segment.getInterval()),
columnToAnalysisMap,
40,
40,
null,
null,
null,
false
)
)
);
}
}
@State(Scope.Thread)
public static class MyState
{
Set<SegmentId> segmentIds;
@Setup(Level.Iteration)
public void setup()
{
ImmutableSet.Builder<SegmentId> segmentIdsBuilder = ImmutableSet.builder();
for (int i = 0; i < 10000; ++i) {
segmentIdsBuilder.add(SegmentId.of("dummy", Intervals.of(i + "/" + (i + 1)), "1", new LinearShardSpec(0)));
}
segmentIds = segmentIdsBuilder.build();
}
@TearDown(Level.Iteration)
public void teardown()
{
segmentIds = null;
}
}
@Setup
public void setup()
{
druidSchema = new DruidSchemaForBenchmark(
EasyMock.mock(QueryLifecycleFactory.class),
EasyMock.mock(TimelineServerView.class),
null,
null,
EasyMock.mock(PlannerConfig.class),
null,
null
);
DruidServerMetadata serverMetadata = new DruidServerMetadata(
"dummy",
"dummy",
"dummy",
42,
ServerType.HISTORICAL,
"tier-0",
0
);
DataSegment.Builder builder = DataSegment.builder()
.dataSource("dummy")
.shardSpec(new LinearShardSpec(0))
.dimensions(ImmutableList.of("col1", "col2", "col3", "col4"))
.version("1")
.size(0);
for (int i = 0; i < 10000; ++i) {
DataSegment dataSegment = builder.interval(Intervals.of(i + "/" + (i + 1)))
.build();
druidSchema.addSegment(serverMetadata, dataSegment);
}
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
public void addSegments(MyState state, Blackhole blackhole) throws IOException
{
blackhole.consume(druidSchema.refreshSegments(state.segmentIds));
}
}

View File

@ -54,6 +54,13 @@ public class RowSignature implements ColumnInspector
private final Object2IntMap<String> columnPositions = new Object2IntOpenHashMap<>();
private final List<String> columnNames;
/**
* Hash code computed once at construction and stored, since RowSignature instances are interned in
* {@link org.apache.druid.sql.calcite.schema.DruidSchema}.
* It is also compared first in the equals method, before the column types and names.
*/
private final int hashCode;
private RowSignature(final List<ColumnSignature> columnTypeList)
{
this.columnPositions.defaultReturnValue(-1);
@ -76,6 +83,7 @@ public class RowSignature implements ColumnInspector
}
this.columnNames = columnNamesBuilder.build();
this.hashCode = computeHashCode();
}
@JsonCreator
@ -192,14 +200,20 @@ public class RowSignature implements ColumnInspector
return false;
}
RowSignature that = (RowSignature) o;
return columnTypes.equals(that.columnTypes) &&
return hashCode == that.hashCode &&
columnTypes.equals(that.columnTypes) &&
columnNames.equals(that.columnNames);
}
private int computeHashCode()
{
return Objects.hash(columnTypes, columnNames);
}
@Override
public int hashCode()
{
return Objects.hash(columnTypes, columnNames);
return hashCode;
}
@Override

View File

@ -21,6 +21,7 @@ package org.apache.druid.segment.column;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.segment.TestHelper;
import org.junit.Assert;
import org.junit.Test;
@ -29,6 +30,16 @@ import java.io.IOException;
public class RowSignatureTest
{
@Test
public void testEqualsAndHashCode()
{
EqualsVerifier.forClass(RowSignature.class)
.usingGetClass()
.withCachedHashCode("hashCode", "computeHashCode", RowSignature.builder().build())
.withIgnoredFields("columnPositions")
.verify();
}
@Test
public void test_add_withConflict()
{

View File

@ -25,6 +25,8 @@ import com.google.common.base.Predicates;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Interner;
import com.google.common.collect.Interners;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
@ -112,6 +114,8 @@ public class DruidSchema extends AbstractSchema
*/
private final ConcurrentMap<String, DruidTable> tables = new ConcurrentHashMap<>();
private static final Interner<RowSignature> ROW_SIGNATURE_INTERNER = Interners.newWeakInterner();
/**
* DataSource -> Segment -> AvailableSegmentMetadata(contains RowSignature) for that segment.
* Use SortedMap for segments so they are merged in deterministic order, from older to newer.
@ -414,7 +418,7 @@ public class DruidSchema extends AbstractSchema
}
@VisibleForTesting
void addSegment(final DruidServerMetadata server, final DataSegment segment)
protected void addSegment(final DruidServerMetadata server, final DataSegment segment)
{
// Get lock first so that we won't wait in ConcurrentMap.compute().
synchronized (lock) {
@ -440,7 +444,7 @@ public class DruidSchema extends AbstractSchema
// segmentReplicatable is used to determine if segments are served by historical or realtime servers
long isRealtime = server.isSegmentReplicationTarget() ? 0 : 1;
segmentMetadata = AvailableSegmentMetadata
.builder(segment, isRealtime, ImmutableSet.of(server), null, DEFAULT_NUM_ROWS)
.builder(segment, isRealtime, ImmutableSet.of(server), null, DEFAULT_NUM_ROWS) // Added without needing a refresh
.build();
markSegmentAsNeedRefresh(segment.getId());
if (!server.isSegmentReplicationTarget()) {
@ -620,7 +624,7 @@ public class DruidSchema extends AbstractSchema
* which may be a subset of the asked-for set.
*/
@VisibleForTesting
Set<SegmentId> refreshSegments(final Set<SegmentId> segments) throws IOException
protected Set<SegmentId> refreshSegments(final Set<SegmentId> segments) throws IOException
{
final Set<SegmentId> retVal = new HashSet<>();
@ -844,7 +848,7 @@ public class DruidSchema extends AbstractSchema
* @return {@link Sequence} of {@link SegmentAnalysis} objects
*/
@VisibleForTesting
Sequence<SegmentAnalysis> runSegmentMetadataQuery(
protected Sequence<SegmentAnalysis> runSegmentMetadataQuery(
final Iterable<SegmentId> segments
)
{
@ -903,7 +907,7 @@ public class DruidSchema extends AbstractSchema
rowSignatureBuilder.add(entry.getKey(), valueType);
}
return rowSignatureBuilder.build();
return ROW_SIGNATURE_INTERNER.intern(rowSignatureBuilder.build());
}
/**

View File

@ -222,7 +222,7 @@ public class DruidSchemaTest extends DruidSchemaTestCommon
}
@Override
Set<SegmentId> refreshSegments(final Set<SegmentId> segments) throws IOException
protected Set<SegmentId> refreshSegments(final Set<SegmentId> segments) throws IOException
{
if (throwException) {
throwException = false;
@ -489,7 +489,7 @@ public class DruidSchemaTest extends DruidSchemaTestCommon
)
{
@Override
void addSegment(final DruidServerMetadata server, final DataSegment segment)
protected void addSegment(final DruidServerMetadata server, final DataSegment segment)
{
super.addSegment(server, segment);
if (datasource.equals(segment.getDataSource())) {
@ -531,7 +531,7 @@ public class DruidSchemaTest extends DruidSchemaTestCommon
)
{
@Override
void addSegment(final DruidServerMetadata server, final DataSegment segment)
protected void addSegment(final DruidServerMetadata server, final DataSegment segment)
{
super.addSegment(server, segment);
if (datasource.equals(segment.getDataSource())) {
@ -577,7 +577,7 @@ public class DruidSchemaTest extends DruidSchemaTestCommon
)
{
@Override
void addSegment(final DruidServerMetadata server, final DataSegment segment)
protected void addSegment(final DruidServerMetadata server, final DataSegment segment)
{
super.addSegment(server, segment);
if (datasource.equals(segment.getDataSource())) {
@ -620,7 +620,7 @@ public class DruidSchemaTest extends DruidSchemaTestCommon
)
{
@Override
void addSegment(final DruidServerMetadata server, final DataSegment segment)
protected void addSegment(final DruidServerMetadata server, final DataSegment segment)
{
super.addSegment(server, segment);
if (datasource.equals(segment.getDataSource())) {
@ -660,7 +660,7 @@ public class DruidSchemaTest extends DruidSchemaTestCommon
)
{
@Override
void addSegment(final DruidServerMetadata server, final DataSegment segment)
protected void addSegment(final DruidServerMetadata server, final DataSegment segment)
{
super.addSegment(server, segment);
if (datasource.equals(segment.getDataSource())) {
@ -717,7 +717,7 @@ public class DruidSchemaTest extends DruidSchemaTestCommon
)
{
@Override
void addSegment(final DruidServerMetadata server, final DataSegment segment)
protected void addSegment(final DruidServerMetadata server, final DataSegment segment)
{
super.addSegment(server, segment);
if (datasource.equals(segment.getDataSource())) {
@ -811,7 +811,7 @@ public class DruidSchemaTest extends DruidSchemaTestCommon
)
{
@Override
void addSegment(final DruidServerMetadata server, final DataSegment segment)
protected void addSegment(final DruidServerMetadata server, final DataSegment segment)
{
super.addSegment(server, segment);
if (datasource.equals(segment.getDataSource())) {
@ -858,7 +858,7 @@ public class DruidSchemaTest extends DruidSchemaTestCommon
)
{
@Override
void addSegment(final DruidServerMetadata server, final DataSegment segment)
protected void addSegment(final DruidServerMetadata server, final DataSegment segment)
{
super.addSegment(server, segment);
if (datasource.equals(segment.getDataSource())) {