/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.druid.segment.nested;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectRBTreeMap;
import it.unimi.dsi.fastutil.ints.IntArrays;
import it.unimi.dsi.fastutil.ints.IntIterator;
import org.apache.druid.collections.bitmap.ImmutableBitmap;
import org.apache.druid.collections.bitmap.MutableBitmap;
import org.apache.druid.io.Channels;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.segment.column.BitmapIndexType;
import org.apache.druid.segment.data.CompressedVSizeColumnarIntsSerializer;
import org.apache.druid.segment.data.CompressionStrategy;
import org.apache.druid.segment.data.FixedIndexedIntWriter;
import org.apache.druid.segment.data.GenericIndexedWriter;
import org.apache.druid.segment.data.SingleValueColumnarIntsSerializer;
import org.apache.druid.segment.data.VSizeColumnarIntsSerializer;
import org.apache.druid.segment.file.SegmentFileBuilder;
import org.apache.druid.segment.file.SegmentFileChannel;
import org.apache.druid.segment.serde.ColumnSerializerUtils;
import org.apache.druid.segment.serde.DictionaryEncodedColumnPartSerde;
import org.apache.druid.segment.serde.Serializer;
import org.apache.druid.segment.writeout.SegmentWriteOutMedium;

import javax.annotation.Nullable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.WritableByteChannel;

/**
 * Base class for writer of global dictionary encoded nested field columns for {@link NestedDataColumnSerializer}.
 * Nested columns are written in multiple passes. The first pass processes the 'raw' nested data with a
 * {@link StructuredDataProcessor} which will call {@link #addValue(int, Object)} for writers of each field which is
 * present. For this type of writer, this entails building a local dictionary ({@link #localDictionary}) to map into
 * the global dictionary ({@link #globalDictionaryIdLookup}) and writes this unsorted localId to an intermediate
 * integer column, {@link #intermediateValueWriter}.
 * <p>
 * When processing the 'raw' value column is complete, the {@link #writeTo(int, SegmentFileBuilder)} method will sort
 * the local ids and write them out to a local sorted dictionary, iterate over {@link #intermediateValueWriter} swapping
 * the unsorted local ids with the sorted ids and writing to the compressed id column writer
 * {@link #encodedValueSerializer}, building the bitmap indexes along the way.
 *
 * @see ScalarDoubleFieldColumnWriter   - single type double columns
 * @see ScalarLongFieldColumnWriter     - single type long columns
 * @see ScalarStringFieldColumnWriter   - single type string columns
 * @see VariantArrayFieldColumnWriter   - single type array columns of double, long, or string
 * @see VariantFieldColumnWriter        - mixed type columns of any combination
 */
public abstract class GlobalDictionaryEncodedFieldColumnWriter<T>
{
  private static final Logger log = new Logger(GlobalDictionaryEncodedFieldColumnWriter.class);

  /**
   * Write-out medium backing {@link #intermediateValueWriter}; {@link #writeTo(int, SegmentFileBuilder)} derives a
   * short-lived child medium from it for everything else.
   */
  protected final SegmentWriteOutMedium segmentWriteOutMedium;
  /** Name of the nested column this field belongs to. */
  protected final String columnName;
  /** Name of the nested field written by this writer. */
  protected final String fieldName;
  /** Supplies the bitmap encoding and the compression strategy of the dictionary id column. */
  protected final NestedCommonFormatColumnFormatSpec columnFormatSpec;
  /** Lookup between values and their ids in the global dictionary. */
  protected final DictionaryIdLookup globalDictionaryIdLookup;
  /** Unsorted local dictionary: global id to local id, in order of first appearance. */
  protected final LocalDimensionDictionary localDictionary = new LocalDimensionDictionary();

  /** Provides the writer used to build the value bitmap indexes in {@link #writeTo(int, SegmentFileBuilder)}. */
  public BitmapIndexType bitmapIndexType = BitmapIndexType.DictionaryEncodedValueIndex.INSTANCE;
  /**
   * Bitmaps keyed by int, written in ascending key order by {@link #writeTo(int, SegmentFileBuilder)} as the array
   * element dictionary and array element index. This class never populates the map.
   * NOTE(review): presumably array-capable subclasses fill it with global ids of array elements - confirm.
   */
  protected final Int2ObjectRBTreeMap<MutableBitmap> arrayElements = new Int2ObjectRBTreeMap<>();

  /**
   * Handed to the compressed column serializer in {@link #openColumnSerializer(SegmentWriteOutMedium, int)} and closed
   * when {@link #writeTo(int, SegmentFileBuilder)} finishes.
   */
  protected final Closer fieldResourceCloser = Closer.create();

  /** Intermediate column of unsorted local ids, one entry per row; created by {@link #open()}. */
  protected FixedIndexedIntWriter intermediateValueWriter;
  // maybe someday we allow no bitmap indexes or multi-value columns
  protected int flags = DictionaryEncodedColumnPartSerde.NO_FLAGS;
  /** Set by {@link #openColumnSerializer(SegmentWriteOutMedium, int)} according to the compression strategy. */
  protected DictionaryEncodedColumnPartSerde.VERSION version = null;
  /** Serializer of the sorted local id column; created by {@link #openColumnSerializer(SegmentWriteOutMedium, int)}. */
  protected SingleValueColumnarIntsSerializer encodedValueSerializer;

  /** Number of rows written to {@link #intermediateValueWriter} so far, i.e. the next row expected. */
  protected int cursorPosition;

  protected GlobalDictionaryEncodedFieldColumnWriter(
      String columnName,
      String fieldName,
      SegmentWriteOutMedium segmentWriteOutMedium,
      NestedCommonFormatColumnFormatSpec columnFormatSpec,
      DictionaryIdLookup globalDictionaryIdLookup
  )
  {
    this.columnName = columnName;
    this.fieldName = fieldName;
    this.columnFormatSpec = columnFormatSpec;
    this.segmentWriteOutMedium = segmentWriteOutMedium;
    this.globalDictionaryIdLookup = globalDictionaryIdLookup;
  }

  /**
   * Perform any value conversion needed before looking up the global id in the value dictionary (such as null handling
   * stuff or array processing to add the elements to the dictionary before adding the int[] to the dictionary).
   * The default implementation only casts the value to {@code T}.
   *
   * @param row   row number the value belongs to
   * @param value raw value, may be null
   * @return the value to look up in the global dictionary, or null to record a null for this row
   */
  @SuppressWarnings("unchecked") // callers are expected to supply values of the writer's type
  T processValue(int row, Object value)
  {
    return (T) value;
  }

  /**
   * Hook to allow implementors the chance to do additional operations during {@link #writeTo(int, SegmentFileBuilder)},
   * such as writing an additional value column. Called once per row, in row order.
   *
   * @param value dictionary value of the current row, may be null
   */
  void writeValue(@Nullable T value) throws IOException
  {
    // do nothing, if a value column is present this method should be overridden to write the value to the serializer
  }

  /**
   * Find a value in {@link #globalDictionaryIdLookup} as is most appropriate to the writer type.
   *
   * @param value non-null value returned by {@link #processValue(int, Object)}
   * @return the global dictionary id; {@link #addValue(int, Object)} rejects negative results
   */
  abstract int lookupGlobalId(T value);

  /**
   * Open the writer so that {@link #addValue(int, Object)} can be called. Creates the intermediate column and resets
   * the row cursor to 0.
   */
  public void open() throws IOException
  {
    intermediateValueWriter = new FixedIndexedIntWriter(segmentWriteOutMedium, false);
    intermediateValueWriter.open();
    cursorPosition = 0;
  }

  /**
   * Add a value to the unsorted local dictionary and write its local id to the intermediate column. Rows skipped since
   * the previous call are backfilled with null first.
   *
   * @param row row number of the value
   * @param val raw value, converted with {@link #processValue(int, Object)} before lookup
   * @throws IllegalArgumentException if the converted value is not present in the global dictionary
   */
  public void addValue(int row, Object val) throws IOException
  {
    if (row > cursorPosition) {
      fillNull(row);
    }
    final T value = processValue(row, val);
    final int localId;
    // null is always 0
    if (value == null) {
      localId = localDictionary.add(0);
    } else {
      final int globalId = lookupGlobalId(value);
      Preconditions.checkArgument(globalId >= 0, "Value [%s] is not present in global dictionary", value);
      localId = localDictionary.add(globalId);
    }
    intermediateValueWriter.write(localId);
    cursorPosition++;
  }

  /**
   * Backfill intermediate column with null values (global id 0) until {@link #cursorPosition} reaches {@code row}.
   */
  private void fillNull(int row) throws IOException
  {
    final int localId = localDictionary.add(0);
    while (cursorPosition < row) {
      intermediateValueWriter.write(localId);
      cursorPosition++;
    }
  }


  /**
   * How many bytes {@link #writeColumnTo(WritableByteChannel, SegmentFileBuilder)} is expected to write to the segment:
   * the two int lengths written by {@link #writeLongAndDoubleColumnLength(WritableByteChannel, int, int)} plus the
   * dictionary id column.
   */
  long getSerializedColumnSize() throws IOException
  {
    return Integer.BYTES + Integer.BYTES + encodedValueSerializer.getSerializedSize();
  }

  /**
   * Defines how to write the column, including the dictionary id column, along with any additional columns
   * such as the long or double value column as type appropriate. The default implementation writes zero for both the
   * long and double column lengths followed by the dictionary id column.
   */
  void writeColumnTo(WritableByteChannel channel, SegmentFileBuilder fileBuilder) throws IOException
  {
    writeLongAndDoubleColumnLength(channel, 0, 0);
    encodedValueSerializer.writeTo(channel, fileBuilder);
  }

  /**
   * Sort the local dictionary, rewrite the intermediate column with sorted local ids while building the bitmap
   * indexes, and write the complete field (version, flags, sorted dictionary, column, bitmap index and, when present,
   * the array element dictionary and index) to the segment file.
   * <p>
   * The temporary child write-out medium and {@link #fieldResourceCloser} are closed before this method returns,
   * whether it completes normally or throws.
   *
   * @param finalRowCount total number of rows of the segment; trailing rows without a value are backfilled with null
   * @param fileBuilder   segment file builder which receives the serialized field
   */
  public void writeTo(int finalRowCount, SegmentFileBuilder fileBuilder) throws IOException
  {
    try {
      if (finalRowCount > cursorPosition) {
        fillNull(finalRowCount);
      }
      // use a child writeout medium so that we can close them when we are finished and don't leave temporary files
      // hanging out until the entire segment is done
      final SegmentWriteOutMedium tmpWriteoutMedium = segmentWriteOutMedium.makeChildWriteOutMedium();
      try {
        final FixedIndexedIntWriter sortedDictionaryWriter = new FixedIndexedIntWriter(tmpWriteoutMedium, true);
        sortedDictionaryWriter.open();
        final FixedIndexedIntWriter arrayElementDictionaryWriter = new FixedIndexedIntWriter(tmpWriteoutMedium, true);
        arrayElementDictionaryWriter.open();
        final BitmapIndexType.Writer bitmapIndexWriter = bitmapIndexType.getWriter();
        bitmapIndexWriter.openWriter(
            tmpWriteoutMedium,
            columnName,
            columnFormatSpec.getBitmapEncoding().getObjectStrategy()
        );
        final GenericIndexedWriter<ImmutableBitmap> arrayElementIndexWriter = new GenericIndexedWriter<>(
            tmpWriteoutMedium,
            columnName,
            columnFormatSpec.getBitmapEncoding().getObjectStrategy()
        );
        arrayElementIndexWriter.open();
        arrayElementIndexWriter.setObjectsNotSorted();

        // invert the local dictionary: unsorted local id -> global id
        final Int2IntOpenHashMap globalToUnsorted = localDictionary.getGlobalIdToLocalId();
        final int[] unsortedToGlobal = new int[localDictionary.size()];
        // primitive iteration, a for-each over the key set would box every key
        final IntIterator globalIds = globalToUnsorted.keySet().iterator();
        while (globalIds.hasNext()) {
          final int globalId = globalIds.nextInt();
          unsortedToGlobal[globalToUnsorted.get(globalId)] = globalId;
        }
        final int[] sortedGlobal = unsortedToGlobal.clone();
        IntArrays.unstableSort(sortedGlobal);

        // write the sorted local dictionary and remember where each unsorted local id ended up
        final int[] unsortedToSorted = new int[unsortedToGlobal.length];
        for (int index = 0; index < sortedGlobal.length; index++) {
          final int globalId = sortedGlobal[index];
          sortedDictionaryWriter.write(globalId);
          final int unsortedId = globalToUnsorted.get(globalId);
          unsortedToSorted[unsortedId] = index;
        }

        for (Int2ObjectMap.Entry<MutableBitmap> arrayElement : arrayElements.int2ObjectEntrySet()) {
          arrayElementDictionaryWriter.write(arrayElement.getIntKey());
          arrayElementIndexWriter.write(
              columnFormatSpec.getBitmapEncoding().getBitmapFactory().makeImmutableBitmap(arrayElement.getValue())
          );
        }

        // the largest global id bounds every sorted local id; an empty dictionary (no rows at all) falls back to 0
        final int maxId = sortedGlobal.length == 0 ? 0 : sortedGlobal[sortedGlobal.length - 1];
        openColumnSerializer(tmpWriteoutMedium, maxId);
        bitmapIndexWriter.init(columnFormatSpec.getBitmapEncoding().getBitmapFactory(), sortedGlobal.length);
        final IntIterator rows = intermediateValueWriter.getIterator();
        int rowCount = 0;
        while (rows.hasNext()) {
          final int unsortedLocalId = rows.nextInt();
          final int sortedLocalId = unsortedToSorted[unsortedLocalId];
          encodedValueSerializer.addValue(sortedLocalId);
          @SuppressWarnings("unchecked") // the global dictionary holds the values this writer looked up earlier
          final T value = (T) globalDictionaryIdLookup.getDictionaryValue(unsortedToGlobal[unsortedLocalId]);
          writeValue(value);
          bitmapIndexWriter.add(rowCount, sortedLocalId, value);
          rowCount++;
        }
        bitmapIndexWriter.finalizeWriter(columnFormatSpec.getBitmapEncoding().getBitmapFactory());

        final Serializer fieldSerializer = new Serializer()
        {
          @Override
          public long getSerializedSize() throws IOException
          {
            final long arraySize;
            if (arrayElements.isEmpty()) {
              arraySize = 0;
            } else {
              arraySize = arrayElementDictionaryWriter.getSerializedSize() + arrayElementIndexWriter.getSerializedSize();
            }
            return 1 + Integer.BYTES + // version + feature flags
                   sortedDictionaryWriter.getSerializedSize() +
                   getSerializedColumnSize() +
                   bitmapIndexWriter.getSerializedSize() +
                   arraySize;
          }

          @Override
          public void writeTo(WritableByteChannel channel, SegmentFileBuilder fileBuilder) throws IOException
          {
            Channels.writeFully(channel, ByteBuffer.wrap(new byte[]{version.asByte()}));
            Channels.writeFully(channel, ByteBuffer.wrap(Ints.toByteArray(flags)));
            sortedDictionaryWriter.writeTo(channel, fileBuilder);
            writeColumnTo(channel, fileBuilder);
            bitmapIndexWriter.writeTo(channel, fileBuilder);
            // the array element parts are only present for fields which recorded array elements
            if (!arrayElements.isEmpty()) {
              arrayElementDictionaryWriter.writeTo(channel, fileBuilder);
              arrayElementIndexWriter.writeTo(channel, fileBuilder);
            }
          }
        };
        final String fieldFileName = ColumnSerializerUtils.getInternalFileName(columnName, fieldName);
        final long size = fieldSerializer.getSerializedSize();
        log.debug("Column [%s] serializing [%s] field of size [%d].", columnName, fieldName, size);
        try (SegmentFileChannel channel = fileBuilder.addWithChannel(fieldFileName, size)) {
          fieldSerializer.writeTo(channel, fileBuilder);
        }
      }
      finally {
        tmpWriteoutMedium.close();
      }
    }
    finally {
      // runs even if closing the temporary medium failed, so serializer resources are never left open
      fieldResourceCloser.close();
    }
  }

  /**
   * Create and open {@link #encodedValueSerializer} and set {@link #version} to match: a compressed serializer unless
   * the format spec asks for {@link CompressionStrategy#UNCOMPRESSED}, in which case a plain variable-size serializer
   * is used.
   *
   * @param medium write-out medium the serializer buffers into
   * @param maxId  upper bound of the ids which will be added to the serializer
   */
  public void openColumnSerializer(SegmentWriteOutMedium medium, int maxId) throws IOException
  {
    final CompressionStrategy compression = columnFormatSpec.getDictionaryEncodedColumnCompression();
    if (compression != CompressionStrategy.UNCOMPRESSED) {
      this.version = DictionaryEncodedColumnPartSerde.VERSION.COMPRESSED;
      encodedValueSerializer = CompressedVSizeColumnarIntsSerializer.create(
          fieldName,
          medium,
          columnName,
          maxId,
          compression,
          fieldResourceCloser
      );
    } else {
      encodedValueSerializer = new VSizeColumnarIntsSerializer(medium, maxId);
      this.version = DictionaryEncodedColumnPartSerde.VERSION.UNCOMPRESSED_SINGLE_VALUE;
    }
    encodedValueSerializer.open();
  }

  /**
   * Write the lengths of the long and double value columns as two consecutive 4 byte ints in native byte order.
   *
   * @param channel      channel to write to
   * @param longLength   length of the long value column, 0 if there is none
   * @param doubleLength length of the double value column, 0 if there is none
   */
  public void writeLongAndDoubleColumnLength(WritableByteChannel channel, int longLength, int doubleLength)
      throws IOException
  {
    final ByteBuffer lengths = ByteBuffer.allocate(2 * Integer.BYTES).order(ByteOrder.nativeOrder());
    lengths.putInt(longLength);
    lengths.putInt(doubleLength);
    lengths.flip();
    Channels.writeFully(channel, lengths);
  }
}
