#! /usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Code generation for bulk operations.

Running this script writes ``BulkOperation.java`` plus one
``BulkOperationPacked<N>.java`` file for every specialized bits-per-value
count into the current working directory.
"""

try:
    # python 3.5+ (fractions.gcd was removed in python 3.9)
    from math import gcd
except ImportError:
    # old python
    from fractions import gcd

# Bits-per-value counts up to and including this one get a dedicated, fully
# unrolled BulkOperationPacked<N> class; larger counts use the generic
# BulkOperationPacked(N).
MAX_SPECIALIZED_BITS_PER_VALUE = 24
# Bits-per-value counts for which a BulkOperationPackedSingleBlock is emitted.
PACKED_64_SINGLE_BLOCK_BPV = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 16, 21, 32]
OUTPUT_FILE = "BulkOperation.java"

HEADER = """// This file has been automatically generated, DO NOT EDIT

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.util.packed;

"""

FOOTER = """
  protected int writeLong(long block, byte[] blocks, int blocksOffset) {
    for (int j = 1; j <= 8; ++j) {
      blocks[blocksOffset++] = (byte) (block >>> (64 - (j << 3)));
    }
    return blocksOffset;
  }

  /**
   * For every number of bits per value, there is a minimum number of
   * blocks (b) / values (v) you need to write in order to reach the next block
   * boundary:
   * <pre>
   *  - 16 bits per value -&gt; b=2, v=1
   *  - 24 bits per value -&gt; b=3, v=1
   *  - 50 bits per value -&gt; b=25, v=4
   *  - 63 bits per value -&gt; b=63, v=8
   *  - ...
   * </pre>
   *
   * A bulk read consists in copying iterations*v values that are
   * contained in iterations*b blocks into a long[]
   * (higher values of iterations are likely to yield a better
   * throughput): this requires n * (b + 8v) bytes of memory.
   *
   * This method computes iterations as
   * ramBudget / (b + 8v) (since a long is 8 bytes).
   */
  public final int computeIterations(int valueCount, int ramBudget) {
    final int iterations = ramBudget / (byteBlockCount() + 8 * byteValueCount());
    if (iterations == 0) {
      // at least 1
      return 1;
    } else if ((iterations - 1) * byteValueCount() >= valueCount) {
      // don't allocate for more than the size of the reader
      return (int) Math.ceil((double) valueCount / byteValueCount());
    } else {
      return iterations;
    }
  }
}
"""

# Javadoc placed above every generated class.
_CLASS_JAVADOC = '''/**
 * Efficient sequential read/write of packed integers.
 */\n'''

# Java integral primitive type for each bit width.
_JAVA_TYPES = {8: "byte", 16: "short", 32: "int", 64: "long"}


def is_power_of_two(n):
    """Return True if ``n`` is a power of two (note: also True for 0)."""
    return n & (n - 1) == 0


def casts(typ):
    """Return the ``(prefix, suffix)`` pair that narrows a Java long expression to ``typ``.

    No cast is needed when ``typ`` is ``long``, so both parts are empty then.
    """
    if typ == "long":
        return "", ""
    return "(%s) (" % typ, ")"


def hexNoLSuffix(n):
    """Return ``hex(n)`` without the trailing ``L`` that Python 2 appends to longs."""
    # On 32 bit Python values > (1 << 31)-1 will have L appended by hex function:
    s = hex(n)
    if s.endswith('L'):
        s = s[:-1]
    return s


def masks(bits):
    """Return the ``(prefix, suffix)`` pair that masks a Java expression to its low ``bits`` bits."""
    if bits == 64:
        return "", ""
    return "(", " & %sL)" % hexNoLSuffix((1 << bits) - 1)


def get_type(bits):
    """Return the name of the Java integral primitive that is ``bits`` wide.

    Raises:
      ValueError: if no Java integral primitive has that width. (An explicit
        exception rather than ``assert`` so it still fires under ``python -O``.)
    """
    typ = _JAVA_TYPES.get(bits)
    if typ is None:
        raise ValueError("no Java integral primitive type is %d bits wide" % bits)
    return typ


def block_value_count(bpv, bits=64):
    """Return ``(blocks, values)``: the smallest number of ``bits``-wide blocks
    that holds a whole number of ``bpv``-bit values, and that number of values.
    """
    blocks = bpv
    values = blocks * bits // bpv
    while blocks % 2 == 0 and values % 2 == 0:
        blocks //= 2
        values //= 2
    assert values * bpv == bits * blocks, "%d values, %d blocks, %d bits per value" % (values, blocks, bpv)
    return (blocks, values)


def _java_literal(n):
    """Return non-negative ``n`` as a Java integer literal.

    The ``L`` suffix is added only when ``n`` does not fit in a Java int,
    because a bare literal above 2^31-1 does not compile.
    """
    if n > 0x7FFFFFFF:
        return "%dL" % n
    return "%d" % n


def _byte_shift(b, byte_end, bit_end):
    """Left shift that moves byte ``b`` into place for a value whose last bit is bit ``bit_end`` of byte ``byte_end``."""
    return 8 * (byte_end - b - 1) + 1 + bit_end


def packed64(bpv, f):
    """Write the body of class ``BulkOperationPacked<bpv>`` (constructor and decode methods) to ``f``."""
    f.write("\n")
    f.write("  public BulkOperationPacked%d() {\n" % bpv)
    f.write("    super(%d);\n" % bpv)
    f.write("  }\n\n")

    if bpv == 64:
        # One value per long block: decoding into long[] is a plain copy and
        # narrowing to int[] is impossible.
        f.write("""  @Override
  public void decode(long[] blocks, int blocksOffset, long[] values, int valuesOffset, int iterations) {
    System.arraycopy(blocks, blocksOffset, values, valuesOffset, valueCount() * iterations);
  }

  @Override
  public void decode(long[] blocks, int blocksOffset, int[] values, int valuesOffset, int iterations) {
    throw new UnsupportedOperationException();
  }

  @Override
  public void decode(byte[] blocks, int blocksOffset, int[] values, int valuesOffset, int iterations) {
    throw new UnsupportedOperationException();
  }

  @Override
  public void decode(byte[] blocks, int blocksOffset, long[] values, int valuesOffset, int iterations) {
    LongBuffer.wrap(values, valuesOffset, iterations * valueCount()).put(ByteBuffer.wrap(blocks, blocksOffset, 8 * iterations * blockCount()).asLongBuffer());
  }
""")
    else:
        p64_decode(bpv, f, 32)
        p64_decode(bpv, f, 64)


def p64_decode(bpv, f, bits):
    """Write the two ``decode`` overrides (from ``long[]`` blocks and from
    ``byte[]`` blocks) that unpack ``bpv``-bit values into a Java array of
    ``bits``-wide integers.

    When ``bits < bpv`` the values cannot fit in the target type, so both
    overrides only throw ``UnsupportedOperationException``.
    """
    typ = get_type(bits)
    _write_long_blocks_decoder(bpv, f, bits, typ)
    _write_byte_blocks_decoder(bpv, f, bits, typ)


def _write_long_blocks_decoder(bpv, f, bits, typ):
    """Emit ``decode(long[] blocks, ..., typ[] values, ...)`` for ``bpv`` bits per value."""
    f.write("  @Override\n")
    f.write("  public void decode(long[] blocks, int blocksOffset, %s[] values, int valuesOffset, int iterations) {\n" % typ)
    if bits < bpv:
        f.write("    throw new UnsupportedOperationException();\n")
    else:
        cast_start, cast_end = casts(typ)
        mask = (1 << bpv) - 1
        f.write("    for (int i = 0; i < iterations; ++i) {\n")
        if is_power_of_two(bpv):
            # Each block holds a whole number of values: loop over the shifts.
            f.write("      final long block = blocks[blocksOffset++];\n")
            f.write("      for (int shift = %d; shift >= 0; shift -= %d) {\n" % (64 - bpv, bpv))
            f.write("        values[valuesOffset++] = %s(block >>> shift) & %s%s;\n" % (cast_start, _java_literal(mask), cast_end))
            f.write("      }\n")
        else:
            _, values = block_value_count(bpv)
            for i in range(values):
                block_offset = i * bpv // 64
                bit_offset = (i * bpv) % 64
                if bit_offset == 0:
                    # start of block: read it and take its top bpv bits
                    f.write("      final long block%d = blocks[blocksOffset++];\n" % block_offset)
                    f.write("      values[valuesOffset++] = %sblock%d >>> %d%s;\n" % (cast_start, block_offset, 64 - bpv, cast_end))
                elif bit_offset + bpv == 64:
                    # end of block
                    f.write("      values[valuesOffset++] = %sblock%d & %dL%s;\n" % (cast_start, block_offset, mask, cast_end))
                elif bit_offset + bpv < 64:
                    # middle of block
                    f.write("      values[valuesOffset++] = %s(block%d >>> %d) & %dL%s;\n" % (cast_start, block_offset, 64 - bit_offset - bpv, mask, cast_end))
                else:
                    # value spans across 2 blocks: low bits of this block
                    # become the high bits of the value
                    mask1 = (1 << (64 - bit_offset)) - 1
                    shift1 = bit_offset + bpv - 64
                    shift2 = 64 - shift1
                    f.write("      final long block%d = blocks[blocksOffset++];\n" % (block_offset + 1))
                    f.write("      values[valuesOffset++] = %s((block%d & %dL) << %d) | (block%d >>> %d)%s;\n" % (cast_start, block_offset, mask1, shift1, block_offset + 1, shift2, cast_end))
        f.write("    }\n")
    f.write("  }\n\n")


def _write_byte_blocks_decoder(bpv, f, bits, typ):
    """Emit ``decode(byte[] blocks, ..., typ[] values, ...)`` for ``bpv`` bits per value."""
    f.write("  @Override\n")
    f.write("  public void decode(byte[] blocks, int blocksOffset, %s[] values, int valuesOffset, int iterations) {\n" % typ)
    if bits < bpv:
        f.write("    throw new UnsupportedOperationException();\n")
    elif is_power_of_two(bpv) and bpv < 8:
        # Several values per byte: peel them off from the most significant end.
        mask = (1 << bpv) - 1
        f.write("    for (int j = 0; j < iterations; ++j) {\n")
        f.write("      final byte block = blocks[blocksOffset++];\n")
        for shift in range(8 - bpv, 0, -bpv):
            f.write("      values[valuesOffset++] = (block >>> %d) & %d;\n" % (shift, mask))
        f.write("      values[valuesOffset++] = block & %d;\n" % mask)
        f.write("    }\n")
    elif bpv == 8:
        f.write("    for (int j = 0; j < iterations; ++j) {\n")
        f.write("      values[valuesOffset++] = blocks[blocksOffset++] & 0xFF;\n")
        f.write("    }\n")
    elif is_power_of_two(bpv) and bpv > 8:
        # Several whole bytes per value, most significant byte first; each
        # successive byte is shifted 8 bits less than the previous one.
        m = "0xFF" if bits <= 32 else "0xFFL"
        f.write("    for (int j = 0; j < iterations; ++j) {\n")
        f.write("      values[valuesOffset++] =")
        for i in range(bpv // 8 - 1):
            f.write(" ((blocks[blocksOffset++] & %s) << %d) |" % (m, bpv - 8 * (i + 1)))
        f.write(" (blocks[blocksOffset++] & %s);\n" % m)
        f.write("    }\n")
    else:
        _write_unrolled_byte_decode(bpv, f, typ)
    f.write("  }\n\n")


def _write_unrolled_byte_decode(bpv, f, typ):
    """Emit a loop whose body decodes one full byte/value cycle, reading each byte exactly once."""
    _, byte_values = block_value_count(bpv, 8)
    f.write("    for (int i = 0; i < iterations; ++i) {\n")
    for i in range(byte_values):
        # value i occupies bits [i * bpv, (i + 1) * bpv - 1] of the byte stream
        byte_start = i * bpv // 8
        bit_start = (i * bpv) % 8
        byte_end = ((i + 1) * bpv - 1) // 8
        bit_end = ((i + 1) * bpv - 1) % 8
        # Unless this value starts on a byte boundary, byte_start was already
        # read into a local by the previous value.
        if bit_start == 0:
            f.write("      final %s byte%d = blocks[blocksOffset++] & 0xFF;\n" % (typ, byte_start))
        for b in range(byte_start + 1, byte_end + 1):
            f.write("      final %s byte%d = blocks[blocksOffset++] & 0xFF;\n" % (typ, b))
        f.write("      values[valuesOffset++] =")
        if byte_start == byte_end:
            # the value lies within a single byte
            if bit_start == 0:
                if bit_end == 7:
                    f.write(" byte%d" % byte_start)
                else:
                    f.write(" byte%d >>> %d" % (byte_start, 7 - bit_end))
            elif bit_end == 7:
                f.write(" byte%d & %d" % (byte_start, 2 ** (8 - bit_start) - 1))
            else:
                f.write(" (byte%d >>> %d) & %d" % (byte_start, 7 - bit_end, 2 ** (bit_end - bit_start + 1) - 1))
        else:
            # the value spans several bytes: OR together the shifted pieces
            if bit_start == 0:
                f.write(" (byte%d << %d)" % (byte_start, _byte_shift(byte_start, byte_end, bit_end)))
            else:
                f.write(" ((byte%d & %d) << %d)" % (byte_start, 2 ** (8 - bit_start) - 1, _byte_shift(byte_start, byte_end, bit_end)))
            for b in range(byte_start + 1, byte_end):
                f.write(" | (byte%d << %d)" % (b, _byte_shift(b, byte_end, bit_end)))
            if bit_end == 7:
                f.write(" | byte%d" % byte_end)
            else:
                f.write(" | (byte%d >>> %d)" % (byte_end, 7 - bit_end))
        f.write(";\n")
    f.write("    }\n")


def _write_specialized_class(bpv):
    """Write ``BulkOperationPacked<bpv>.java`` containing the unrolled decoders for ``bpv``."""
    with open('BulkOperationPacked%d.java' % bpv, 'w') as f2:
        f2.write(HEADER)
        if bpv == 64:
            # Only reachable if MAX_SPECIALIZED_BITS_PER_VALUE is raised to 64;
            # the 64-bit decoders written by packed64 use these NIO buffers.
            f2.write('import java.nio.LongBuffer;\n')
            f2.write('import java.nio.ByteBuffer;\n')
            f2.write('\n')
        f2.write(_CLASS_JAVADOC)
        f2.write('final class BulkOperationPacked%d extends BulkOperationPacked {\n' % bpv)
        packed64(bpv, f2)
        f2.write('}\n')


def main():
    """Generate ``OUTPUT_FILE`` and the specialized ``BulkOperationPacked<N>.java`` files."""
    with open(OUTPUT_FILE, 'w') as f:
        f.write(HEADER)
        f.write('\n')
        f.write(_CLASS_JAVADOC)
        f.write('abstract class BulkOperation implements PackedInts.Decoder, PackedInts.Encoder {\n')
        f.write('  private static final BulkOperation[] packedBulkOps = new BulkOperation[] {\n')

        for bpv in range(1, 65):
            if bpv > MAX_SPECIALIZED_BITS_PER_VALUE:
                f.write('    new BulkOperationPacked(%d),\n' % bpv)
            else:
                _write_specialized_class(bpv)
                f.write('    new BulkOperationPacked%d(),\n' % bpv)

        f.write('  };\n')
        f.write('\n')

        f.write('  // NOTE: this is sparse (some entries are null):\n')
        f.write('  private static final BulkOperation[] packedSingleBlockBulkOps = new BulkOperation[] {\n')
        for bpv in range(1, max(PACKED_64_SINGLE_BLOCK_BPV) + 1):
            if bpv in PACKED_64_SINGLE_BLOCK_BPV:
                f.write('    new BulkOperationPackedSingleBlock(%d),\n' % bpv)
            else:
                f.write('    null,\n')
        f.write('  };\n')
        f.write('\n')

        f.write("\n")
        f.write("  public static BulkOperation of(PackedInts.Format format, int bitsPerValue) {\n")
        f.write("    switch (format) {\n")
        f.write("    case PACKED:\n")
        f.write("      assert packedBulkOps[bitsPerValue - 1] != null;\n")
        f.write("      return packedBulkOps[bitsPerValue - 1];\n")
        f.write("    case PACKED_SINGLE_BLOCK:\n")
        f.write("      assert packedSingleBlockBulkOps[bitsPerValue - 1] != null;\n")
        f.write("      return packedSingleBlockBulkOps[bitsPerValue - 1];\n")
        f.write("    default:\n")
        f.write("      throw new AssertionError();\n")
        f.write("    }\n")
        f.write("  }\n")
        f.write(FOOTER)


if __name__ == '__main__':
    main()