mirror of https://github.com/apache/lucene.git
338 lines
12 KiB
Python
338 lines
12 KiB
Python
#! /usr/bin/env python
|
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
# this work for additional information regarding copyright ownership.
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
# (the "License"); you may not use this file except in compliance with
|
|
# the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
try:
|
|
# python 3.9+
|
|
from math import gcd
|
|
except ImportError:
|
|
# old python
|
|
from fractions import gcd
|
|
|
|
# Code generation for bulk operations: writes BulkOperation.java and the
# specialized BulkOperationPacked<N>.java classes into the current directory.
# (A bare string here would not be the module docstring, because it comes
# after the gcd import above, so a comment is used instead.)

# Bits-per-value widths up to this limit get a dedicated, generated
# BulkOperationPacked<N> class; wider widths use the generic
# BulkOperationPacked implementation.
MAX_SPECIALIZED_BITS_PER_VALUE = 24

# Widths for which a BulkOperationPackedSingleBlock entry is generated; the
# generated lookup table holds null for every other width.
PACKED_64_SINGLE_BLOCK_BPV = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 16, 21, 32]

# Name of the main generated Java source file.
OUTPUT_FILE = "BulkOperation.java"
|
|
# Prologue written first into every generated Java file (BulkOperation.java
# and each BulkOperationPacked<N>.java): the "automatically generated" banner,
# the Apache license header and the package declaration.
HEADER = """// This file has been automatically generated, DO NOT EDIT
|
|
|
|
/*
|
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
* contributor license agreements. See the NOTICE file distributed with
|
|
* this work for additional information regarding copyright ownership.
|
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
* (the "License"); you may not use this file except in compliance with
|
|
* the License. You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
package org.apache.lucene.util.packed;
|
|
|
|
"""
|
|
|
|
# Java members written at the very end of BulkOperation.java:
#   - writeLong: serializes one long into 8 bytes, most significant byte first
#     (j = 1 shifts by 56), and returns the advanced offset;
#   - computeIterations: sizes a bulk read from a RAM budget, never returning
#     less than 1 and capping at the iterations needed to cover valueCount;
# followed by the closing brace of the BulkOperation class.
# NOTE(review): the Javadoc's "n * (b + 8v)" presumably means
# "iterations * (b + 8v)"; left unchanged because it is emitted text.
FOOTER = """
|
|
protected int writeLong(long block, byte[] blocks, int blocksOffset) {
|
|
for (int j = 1; j <= 8; ++j) {
|
|
blocks[blocksOffset++] = (byte) (block >>> (64 - (j << 3)));
|
|
}
|
|
return blocksOffset;
|
|
}
|
|
|
|
/**
|
|
* For every number of bits per value, there is a minimum number of
|
|
* blocks (b) / values (v) you need to write in order to reach the next block
|
|
* boundary:
|
|
* <pre>
|
|
* - 16 bits per value -> b=2, v=1
|
|
* - 24 bits per value -> b=3, v=1
|
|
* - 50 bits per value -> b=25, v=4
|
|
* - 63 bits per value -> b=63, v=8
|
|
* - ...
|
|
* </pre>
|
|
*
|
|
* A bulk read consists in copying <code>iterations*v</code> values that are
|
|
* contained in <code>iterations*b</code> blocks into a <code>long[]</code>
|
|
* (higher values of <code>iterations</code> are likely to yield a better
|
|
* throughput): this requires n * (b + 8v) bytes of memory.
|
|
*
|
|
* This method computes <code>iterations</code> as
|
|
* <code>ramBudget / (b + 8v)</code> (since a long is 8 bytes).
|
|
*/
|
|
public final int computeIterations(int valueCount, int ramBudget) {
|
|
final int iterations = ramBudget / (byteBlockCount() + 8 * byteValueCount());
|
|
if (iterations == 0) {
|
|
// at least 1
|
|
return 1;
|
|
} else if ((iterations - 1) * byteValueCount() >= valueCount) {
|
|
// don't allocate for more than the size of the reader
|
|
return (int) Math.ceil((double) valueCount / byteValueCount());
|
|
} else {
|
|
return iterations;
|
|
}
|
|
}
|
|
}
|
|
"""
|
|
|
|
def is_power_of_two(n):
    """Return True when ``n & (n - 1)`` has no bits set.

    For non-negative integers this means ``n`` is a power of two; note that
    0 also passes the check.
    """
    return not n & (n - 1)
|
|
|
|
def casts(typ):
    """Return the ``(prefix, suffix)`` pair that wraps a Java expression in a
    cast to the primitive type ``typ``.

    The generated expressions operate on ``long`` blocks, so a ``long`` target
    needs no cast and both parts are empty strings.
    """
    if typ == "long":
        return "", ""
    prefix = "(%s) (" % typ
    return prefix, ")"
|
|
|
|
def hexNoLSuffix(n):
    """Return ``hex(n)`` without a trailing ``L``.

    Python 2 appends ``L`` to the hex form of long integers (e.g. values above
    (1 << 31) - 1 on 32-bit builds); the suffix is removed so the text can be
    embedded in generated Java.
    """
    text = hex(n)
    return text[:-1] if text.endswith('L') else text
|
|
|
|
def masks(bits):
    """Return the ``(prefix, suffix)`` pair that ANDs a Java expression with a
    mask of its low ``bits`` bits (written as a hex long literal).

    A 64-bit mask would keep every bit, so both parts are empty in that case.
    """
    if bits != 64:
        low_bits_hex = hexNoLSuffix((1 << bits) - 1)
        return "(", " & %sL)" % low_bits_hex
    return "", ""
|
|
|
|
def get_type(bits):
    """Return the name of the Java primitive type that is exactly ``bits`` wide.

    Args:
      bits: 8, 16, 32 or 64.

    Returns:
      "byte", "short", "int" or "long".

    Raises:
      ValueError: for any other width. An explicit raise is used instead of
        ``assert False`` because assertions are stripped under ``python -O``,
        which would make this return None and emit ``None[]`` into Java code.
    """
    java_types = {8: "byte", 16: "short", 32: "int", 64: "long"}
    if bits not in java_types:
        raise ValueError("no Java primitive type has %r bits" % (bits,))
    return java_types[bits]
|
|
|
|
def block_value_count(bpv, bits=64):
    """Return ``(blocks, values)`` such that ``blocks`` blocks of ``bits`` bits
    hold exactly ``values`` values of ``bpv`` bits.

    Starts from ``(bpv, bits)`` and strips common factors of two. When
    ``bits`` is a power of two (the 64-bit and 8-bit block sizes used in this
    script) the result is the fully reduced ratio, i.e. the minimum
    whole-block / whole-value unit.
    """
    block_count = bpv
    value_count = block_count * bits // bpv
    # Halve both counts for as long as they are both even.
    while not (block_count % 2 or value_count % 2):
        block_count //= 2
        value_count //= 2
    assert value_count * bpv == bits * block_count, "%d values, %d blocks, %d bits per value" % (value_count, block_count, bpv)
    return block_count, value_count
|
|
|
|
def packed64(bpv, f):
    """Write the members of the generated class BulkOperationPacked<bpv> to ``f``.

    Emits the no-arg constructor (delegating to ``super(bpv)``) followed by the
    ``decode`` overrides:
      - bpv == 64: each value is a whole long block, so long[] decoding is a
        plain copy, the int[] variants throw UnsupportedOperationException,
        and byte[] -> long[] goes through java.nio buffers;
      - any other width: p64_decode writes the int[] overrides (bits=32), then
        the long[] overrides (bits=64).

    Args:
      bpv: bits per value of the class being generated.
      f: writable text stream receiving the Java source.
    """
    f.write("\n")
    f.write(" public BulkOperationPacked%d() {\n" % bpv)
    f.write(" super(%d);\n" % bpv)
    f.write(" }\n\n")

    if bpv == 64:
        f.write(""" @Override
|
|
public void decode(long[] blocks, int blocksOffset, long[] values, int valuesOffset, int iterations) {
|
|
System.arraycopy(blocks, blocksOffset, values, valuesOffset, valueCount() * iterations);
|
|
}
|
|
|
|
@Override
|
|
public void decode(long[] blocks, int blocksOffset, int[] values, int valuesOffset, int iterations) {
|
|
throw new UnsupportedOperationException();
|
|
}
|
|
|
|
@Override
|
|
public void decode(byte[] blocks, int blocksOffset, int[] values, int valuesOffset, int iterations) {
|
|
throw new UnsupportedOperationException();
|
|
}
|
|
|
|
@Override
|
|
public void decode(byte[] blocks, int blocksOffset, long[] values, int valuesOffset, int iterations) {
|
|
LongBuffer.wrap(values, valuesOffset, iterations * valueCount()).put(ByteBuffer.wrap(blocks, blocksOffset, 8 * iterations * blockCount()).asLongBuffer());
|
|
}
|
|
""")
    else:
        p64_decode(bpv, f, 32)
        p64_decode(bpv, f, 64)
|
|
|
|
def p64_decode(bpv, f, bits):
    """Write the two ``decode`` overrides of BulkOperationPacked<bpv> whose
    output array has ``bits``-wide elements (32 -> int[], 64 -> long[]).

    The first override reads packed long[] blocks, the second packed byte[]
    blocks. When a value is wider than the output element (``bits < bpv``)
    both overrides just throw UnsupportedOperationException.

    Args:
      bpv: bits per packed value.
      f: writable text stream receiving the Java source.
      bits: output element width; must be accepted by get_type.
    """
    typ = get_type(bits)
    _p64_decode_long_blocks(bpv, f, bits, typ)
    _p64_decode_byte_blocks(bpv, f, bits, typ)


def _p64_decode_long_blocks(bpv, f, bits, typ):
    """Write ``decode(long[] blocks, ..., typ[] values, ...)`` for ``bpv``."""
    f.write(" @Override\n")
    f.write(" public void decode(long[] blocks, int blocksOffset, %s[] values, int valuesOffset, int iterations) {\n" % typ)
    if bits < bpv:
        f.write(" throw new UnsupportedOperationException();\n")
    else:
        cast_start, cast_end = casts(typ)
        mask = (1 << bpv) - 1
        f.write(" for (int i = 0; i < iterations; ++i) {\n")
        if is_power_of_two(bpv):
            # bpv divides 64, so no value straddles two longs: load one block
            # and peel values off it from the most significant end.
            # Java int literals stop at 2^31 - 1, so wider masks (bpv == 32,
            # reachable only if MAX_SPECIALIZED_BITS_PER_VALUE is raised) need
            # an L suffix; narrower masks keep the historical plain form.
            mask_literal = ("%d" if mask <= 0x7FFFFFFF else "%dL") % mask
            f.write(" final long block = blocks[blocksOffset++];\n")
            f.write(" for (int shift = %d; shift >= 0; shift -= %d) {\n" % (64 - bpv, bpv))
            f.write(" values[valuesOffset++] = %s(block >>> shift) & %s%s;\n" % (cast_start, mask_literal, cast_end))
            f.write(" }\n")
        else:
            # Fully unrolled: one loop iteration consumes the smallest number
            # of longs that hold a whole number of values.
            _, value_count = block_value_count(bpv)
            for i in range(value_count):
                block_offset, bit_offset = divmod(i * bpv, 64)
                if bit_offset == 0:
                    # First value of a new long: load it, keep its top bits.
                    f.write(" final long block%d = blocks[blocksOffset++];\n" % block_offset)
                    f.write(" values[valuesOffset++] = %sblock%d >>> %d%s;\n" % (cast_start, block_offset, 64 - bpv, cast_end))
                elif bit_offset + bpv == 64:
                    # Last value of the current long: its low bits.
                    f.write(" values[valuesOffset++] = %sblock%d & %dL%s;\n" % (cast_start, block_offset, mask, cast_end))
                elif bit_offset + bpv < 64:
                    # Value lies strictly inside the current long.
                    f.write(" values[valuesOffset++] = %s(block%d >>> %d) & %dL%s;\n" % (cast_start, block_offset, 64 - bit_offset - bpv, mask, cast_end))
                else:
                    # Value spans two longs: low bits of this block become the
                    # high part, the top bits of the next block the low part.
                    mask1 = (1 << (64 - bit_offset)) - 1
                    shift1 = bit_offset + bpv - 64
                    shift2 = 64 - shift1
                    f.write(" final long block%d = blocks[blocksOffset++];\n" % (block_offset + 1))
                    f.write(" values[valuesOffset++] = %s((block%d & %dL) << %d) | (block%d >>> %d)%s;\n" % (cast_start, block_offset, mask1, shift1, block_offset + 1, shift2, cast_end))
        f.write(" }\n")
    f.write(" }\n\n")


def _p64_decode_byte_blocks(bpv, f, bits, typ):
    """Write ``decode(byte[] blocks, ..., typ[] values, ...)`` for ``bpv``.

    Bits are consumed most-significant first: the first byte read supplies the
    high-order bits of the first value.
    """
    f.write(" @Override\n")
    f.write(" public void decode(byte[] blocks, int blocksOffset, %s[] values, int valuesOffset, int iterations) {\n" % typ)
    if bits < bpv:
        f.write(" throw new UnsupportedOperationException();\n")
    elif is_power_of_two(bpv) and bpv < 8:
        # Several whole values per byte.
        mask = (1 << bpv) - 1
        f.write(" for (int j = 0; j < iterations; ++j) {\n")
        f.write(" final byte block = blocks[blocksOffset++];\n")
        for shift in range(8 - bpv, 0, -bpv):
            f.write(" values[valuesOffset++] = (block >>> %d) & %d;\n" % (shift, mask))
        f.write(" values[valuesOffset++] = block & %d;\n" % mask)
        f.write(" }\n")
    elif bpv == 8:
        f.write(" for (int j = 0; j < iterations; ++j) {\n")
        f.write(" values[valuesOffset++] = blocks[blocksOffset++] & 0xFF;\n")
        f.write(" }\n")
    elif is_power_of_two(bpv) and bpv > 8:
        # Several whole bytes per value, OR-ed together. The k-th byte read
        # (1-based) lands bpv - 8k bits up; the last one is not shifted.
        byte_mask = "0xFF" if bits <= 32 else "0xFFL"
        f.write(" for (int j = 0; j < iterations; ++j) {\n")
        f.write(" values[valuesOffset++] =")
        for k in range(1, bpv // 8):
            f.write(" ((blocks[blocksOffset++] & %s) << %d) |" % (byte_mask, bpv - 8 * k))
        f.write(" (blocks[blocksOffset++] & %s);\n" % byte_mask)
        f.write(" }\n")
    else:
        # Fully unrolled over the smallest whole-byte / whole-value unit.
        _, value_count = block_value_count(bpv, 8)
        f.write(" for (int i = 0; i < iterations; ++i) {\n")
        for i in range(value_count):
            _write_unaligned_byte_value(f, typ, i * bpv, (i + 1) * bpv - 1)
        f.write(" }\n")
    f.write(" }\n\n")


def _write_unaligned_byte_value(f, typ, first_bit, last_bit):
    """Write the Java statements decoding one value stored in bits
    ``first_bit``..``last_bit`` (inclusive) of the current run of bytes."""
    byte_start, bit_start = divmod(first_bit, 8)
    byte_end, bit_end = divmod(last_bit, 8)

    def shift(b):
        # Left shift that places byte b above the bits taken from later bytes.
        return 8 * (byte_end - b - 1) + 1 + bit_end

    # A value starting mid-byte reuses the byte already declared for the
    # previous value; otherwise its first byte is loaded here. Every later
    # byte it touches is loaded here.
    if bit_start == 0:
        f.write(" final %s byte%d = blocks[blocksOffset++] & 0xFF;\n" % (typ, byte_start))
    for b in range(byte_start + 1, byte_end + 1):
        f.write(" final %s byte%d = blocks[blocksOffset++] & 0xFF;\n" % (typ, b))

    f.write(" values[valuesOffset++] =")
    if byte_start == byte_end:
        # The value fits inside a single byte.
        if bit_start == 0:
            if bit_end == 7:
                f.write(" byte%d" % byte_start)
            else:
                f.write(" byte%d >>> %d" % (byte_start, 7 - bit_end))
        elif bit_end == 7:
            f.write(" byte%d & %d" % (byte_start, 2 ** (8 - bit_start) - 1))
        else:
            f.write(" (byte%d >>> %d) & %d" % (byte_start, 7 - bit_end, 2 ** (bit_end - bit_start + 1) - 1))
    else:
        # Leading (possibly partial) byte, whole middle bytes, then the
        # trailing (possibly partial) byte.
        if bit_start == 0:
            f.write(" (byte%d << %d)" % (byte_start, shift(byte_start)))
        else:
            f.write(" ((byte%d & %d) << %d)" % (byte_start, 2 ** (8 - bit_start) - 1, shift(byte_start)))
        for b in range(byte_start + 1, byte_end):
            f.write(" | (byte%d << %d)" % (b, shift(b)))
        if bit_end == 7:
            f.write(" | byte%d" % byte_end)
        else:
            f.write(" | (byte%d >>> %d)" % (byte_end, 7 - bit_end))
    f.write(";\n")
|
|
|
|
def _write_specialized_class(bpv):
    """Write BulkOperationPacked<bpv>.java (in the current directory): a final
    subclass of BulkOperationPacked whose members are generated by packed64."""
    # ``with`` closes (and flushes) the file even if generation raises.
    with open('BulkOperationPacked%d.java' % bpv, 'w') as f2:
        f2.write(HEADER)
        if bpv == 64:
            # The 64-bit decode overrides use java.nio LongBuffer/ByteBuffer.
            f2.write('import java.nio.LongBuffer;\n')
            f2.write('import java.nio.ByteBuffer;\n')
            f2.write('\n')
        f2.write('''/**
|
|
* Efficient sequential read/write of packed integers.
|
|
*/\n''')
        f2.write('final class BulkOperationPacked%d extends BulkOperationPacked {\n' % bpv)
        packed64(bpv, f2)
        f2.write('}\n')


def main():
    """Generate OUTPUT_FILE (the abstract BulkOperation class with its lookup
    tables and ``of`` factory) plus one BulkOperationPacked<N>.java per width
    up to MAX_SPECIALIZED_BITS_PER_VALUE, all in the current directory."""
    with open(OUTPUT_FILE, 'w') as f:
        f.write(HEADER)
        f.write('\n')
        f.write('''/**
|
|
* Efficient sequential read/write of packed integers.
|
|
*/\n''')

        f.write('abstract class BulkOperation implements PackedInts.Decoder, PackedInts.Encoder {\n')
        f.write(' private static final BulkOperation[] packedBulkOps = new BulkOperation[] {\n')

        # One entry per width 1..64 (indexed by bitsPerValue - 1 in ``of``):
        # narrow widths get their own generated class, wider ones the generic one.
        for bpv in range(1, 65):
            if bpv > MAX_SPECIALIZED_BITS_PER_VALUE:
                f.write(' new BulkOperationPacked(%d),\n' % bpv)
            else:
                _write_specialized_class(bpv)
                f.write(' new BulkOperationPacked%d(),\n' % bpv)

        f.write(' };\n')
        f.write('\n')

        f.write(' // NOTE: this is sparse (some entries are null):\n')
        f.write(' private static final BulkOperation[] packedSingleBlockBulkOps = new BulkOperation[] {\n')
        for bpv in range(1, max(PACKED_64_SINGLE_BLOCK_BPV) + 1):
            if bpv in PACKED_64_SINGLE_BLOCK_BPV:
                f.write(' new BulkOperationPackedSingleBlock(%d),\n' % bpv)
            else:
                f.write(' null,\n')
        f.write(' };\n')
        f.write('\n')

        # Factory selecting the table entry for a format and width.
        f.write("\n")
        f.write(" public static BulkOperation of(PackedInts.Format format, int bitsPerValue) {\n")
        f.write(" switch (format) {\n")

        f.write(" case PACKED:\n")
        f.write(" assert packedBulkOps[bitsPerValue - 1] != null;\n")
        f.write(" return packedBulkOps[bitsPerValue - 1];\n")
        f.write(" case PACKED_SINGLE_BLOCK:\n")
        f.write(" assert packedSingleBlockBulkOps[bitsPerValue - 1] != null;\n")
        f.write(" return packedSingleBlockBulkOps[bitsPerValue - 1];\n")
        f.write(" default:\n")
        f.write(" throw new AssertionError();\n")
        f.write(" }\n")
        f.write(" }\n")

        f.write(FOOTER)


if __name__ == '__main__':
    main()
|