/**
 * Copyright (c) 2013-2022 Contributors to the Eclipse Foundation
 *
 * <p> See the NOTICE file distributed with this work for additional information regarding copyright
 * ownership. All rights reserved. This program and the accompanying materials are made available
 * under the terms of the Apache License, Version 2.0 which accompanies this distribution and is
 * available at http://www.apache.org/licenses/LICENSE-2.0.txt
 */
package org.locationtech.geowave.analytic.kmeans.serial;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.locationtech.geowave.analytic.AnalyticItemWrapper;
import org.locationtech.geowave.analytic.AnalyticItemWrapperFactory;
import org.locationtech.geowave.analytic.clustering.CentroidPairing;
import org.locationtech.geowave.analytic.kmeans.AssociationNotification;
import org.locationtech.geowave.analytic.kmeans.CentroidAssociationFn;
import org.locationtech.geowave.analytic.sample.SampleNotification;
import org.locationtech.geowave.analytic.sample.Sampler;

/**
 * Builds an initial set of candidate centroids from an in-memory point set by repeatedly
 * associating points with the current candidates and sampling new candidates from those
 * associations.
 *
 * <p> NOTE(review): the class name suggests this is the seeding phase of k-means|| ("scalable
 * k-means++"); that is inferred from naming only, confirm against the project documentation.
 *
 * <p> The collaborators ({@link CentroidAssociationFn}, {@link Sampler},
 * {@link AnalyticItemWrapperFactory}) are configured through the accessors below. The centroid
 * factory has no default and is {@code null} until {@link #setCentroidFactory} is called.
 *
 * @param <T> the underlying item type wrapped by {@link AnalyticItemWrapper}
 */
public class KMeansParallelInitialize<T> {
  /** Associates points with centroids and returns the value used as the sampling normalizer. */
  private CentroidAssociationFn<T> centroidAssociationFn = new CentroidAssociationFn<>();
  /** Controls the number of sampling rounds: {@code max(1, floor(log2(psi)))}. */
  private double psi = 5.0;
  /** Picks new centroid candidates out of the current point/centroid pairings. */
  private final Sampler<T> sampler = new Sampler<>();
  /** Wraps sampled items as centroids; must be set before sampling selects any item. */
  private AnalyticItemWrapperFactory<T> centroidFactory;
  /** Receives the value returned by every association pass, reported as {@code COST}. */
  private final AnalyticStats stats = new StatsMap();

  /** @return the function used to associate points with the current centroid candidates */
  public CentroidAssociationFn<T> getCentroidAssociationFn() {
    return centroidAssociationFn;
  }

  /** @param centroidAssociationFn the function used to associate points with centroids */
  public void setCentroidAssociationFn(final CentroidAssociationFn<T> centroidAssociationFn) {
    this.centroidAssociationFn = centroidAssociationFn;
  }

  /** @return the value from which the number of sampling rounds is derived (default 5.0) */
  public double getPsi() {
    return psi;
  }

  /**
   * @param psi the value from which the number of sampling rounds is derived; the round count is
   *        {@code max(1, floor(log2(psi)))}, so at least one round always runs
   */
  public void setPsi(final double psi) {
    this.psi = psi;
  }

  /** @return the sampler used by {@link #runLocal(Iterable)}; exposed so it can be configured */
  public Sampler<T> getSampler() {
    return sampler;
  }

  /** @return the factory used to wrap sampled items, or {@code null} if none has been set */
  public AnalyticItemWrapperFactory<T> getCentroidFactory() {
    return centroidFactory;
  }

  /** @param centroidFactory the factory used to wrap each sampled item as a centroid */
  public void setCentroidFactory(final AnalyticItemWrapperFactory<T> centroidFactory) {
    this.centroidFactory = centroidFactory;
  }

  /** @return the statistics collector; it is reset at the start of every run */
  public AnalyticStats getStats() {
    return stats;
  }

  /**
   * Runs the sampling rounds over {@code pointSet} and returns the resulting pairings and
   * centroid candidates.
   *
   * <p> The candidate list is seeded with the first element of {@code pointSet} (the wrapper
   * instance itself, not a copy). Each round samples new candidates from the current pairings,
   * resets the association count of every candidate, and re-associates all points. The value
   * returned by each association pass is reported to the stats as {@code COST} and used as the
   * normalizing constant of the next sampling round.
   *
   * <p> {@code pointSet} is iterated once per association pass, so it must support repeated
   * iteration.
   *
   * @param pointSet the points to draw centroid candidates from
   * @return a pair of (pairings produced by the final association pass, all centroid candidates);
   *         both lists are empty when {@code pointSet} has no elements
   * @throws NullPointerException if the sampler selects an item while no centroid factory is set
   */
  public Pair<List<CentroidPairing<T>>, List<AnalyticItemWrapper<T>>> runLocal(
      final Iterable<AnalyticItemWrapper<T>> pointSet) {

    stats.reset();

    final List<AnalyticItemWrapper<T>> sampleSet = new ArrayList<>();
    final List<CentroidPairing<T>> pairingSet = new ArrayList<>();

    // Nothing to seed from: return empty results instead of failing with a
    // NoSuchElementException from Iterator.next().
    final Iterator<AnalyticItemWrapper<T>> pointIt = pointSet.iterator();
    if (!pointIt.hasNext()) {
      return Pair.of(pairingSet, sampleSet);
    }
    sampleSet.add(pointIt.next());

    // Collects every reported pairing and counts it against the centroid it was paired with.
    final AssociationNotification<T> assocFn = new AssociationNotification<T>() {
      @Override
      public void notify(final CentroidPairing<T> pairing) {
        pairingSet.add(pairing);
        pairing.getCentroid().incrementAssociationCount(1);
      }
    };

    // Wraps every sampled item as a new centroid candidate; the 'partial' flag is not used.
    // Created once because it does not depend on the round.
    final SampleNotification<T> sampleFn = new SampleNotification<T>() {
      @Override
      public void notify(final T item, final boolean partial) {
        sampleSet.add(centroidFactory.create(item));
      }
    };

    // Initial pass: associate every point with the single seed centroid.
    double normalizingConstant = centroidAssociationFn.compute(pointSet, sampleSet, assocFn);
    stats.notify(AnalyticStats.StatValue.COST, normalizingConstant);

    // Truncated log2(psi); Math.max guarantees at least one round even for psi < 2.
    final int logPsi = Math.max(1, (int) (Math.log(psi) / Math.log(2)));
    for (int i = 0; i < logPsi; i++) {
      // Sample from the pairings of the previous pass; this may grow sampleSet.
      sampler.sample(pairingSet, sampleFn, normalizingConstant);

      // Discard the old pairings and counts before re-associating against the grown set.
      pairingSet.clear();
      for (final AnalyticItemWrapper<T> centroid : sampleSet) {
        centroid.resetAssociatonCount();
      }
      normalizingConstant = centroidAssociationFn.compute(pointSet, sampleSet, assocFn);
      stats.notify(AnalyticStats.StatValue.COST, normalizingConstant);
    }
    return Pair.of(pairingSet, sampleSet);
  }
}
