/*
 *
 *  Copyright 2022 the original author or authors.
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *       https://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package org.flowstep.mongo;

import com.mongodb.client.FindIterable;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoCursor;
import com.mongodb.client.MongoIterable;
import com.mongodb.client.model.Aggregates;
import com.mongodb.client.model.BsonField;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.Projections;
import org.bson.Document;
import org.bson.conversions.Bson;
import org.flowstep.core.DataProviderException;
import org.flowstep.core.ExtractedRecord;
import org.flowstep.core.FlowDataProvider;
import org.flowstep.core.RecordFieldExtractor;
import org.flowstep.core.connection.EnvironmentConnection;
import org.flowstep.core.context.FlowPackageContext;
import org.flowstep.core.context.FlowStepContext;
import org.flowstep.core.model.step.FlowGroup;
import org.flowstep.core.model.step.FlowStepData;
import org.flowstep.mongo.model.MongoEnvironment;
import org.flowstep.mongo.model.MongoStep;
import org.flowstep.mongo.model.StepSort;
import org.springframework.dao.DataAccessException;
import org.springframework.data.mongodb.core.MongoTemplate;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static com.mongodb.client.model.Sorts.orderBy;

/**
 * {@link FlowDataProvider} that reads documents from a MongoDB collection and stores each one as an
 * {@link ExtractedRecord} in the package context.
 *
 * <p>Configure it through the fluent setters, then call {@link #build()}. When the step defines a
 * group, data is read through an aggregation pipeline ({@code $match}, {@code $group}, optional
 * {@code $sort}). Otherwise a {@code find} with projection, sort and optional limit is used.
 */
public class MongoDataProvider implements FlowDataProvider {

    /** Field under which {@code $group} places the grouping key in each result document. */
    private static final String GROUP_ID_PREFIX = "_id";

    private EnvironmentConnection connection;
    private FlowPackageContext stepPackageContext;
    private MongoStep step;

    private FlowGroup stepGroup;

    /** Sets the connection used to reach MongoDB; returns {@code this} for chaining. */
    public MongoDataProvider setConnection(EnvironmentConnection connection) {
        this.connection = connection;
        return this;
    }

    /** Sets the package context that records are cached into; returns {@code this} for chaining. */
    public MongoDataProvider setStepPackageContext(FlowPackageContext stepPackageContext) {
        this.stepPackageContext = stepPackageContext;
        return this;
    }

    /** Sets the step describing fields, filters, sorts, grouping and limit; returns {@code this}. */
    public MongoDataProvider setStep(MongoStep step) {
        this.step = step;
        return this;
    }

    /** Sets the flow group passed into the step context; returns {@code this} for chaining. */
    public MongoDataProvider setStepGroup(FlowGroup stepGroup) {
        this.stepGroup = stepGroup;
        return this;
    }

    /**
     * Queries the configured collection and saves every returned document via
     * {@link #saveRecord(FlowStepData, Map, FlowStepContext, FlowPackageContext)}.
     *
     * @throws DataProviderException wrapping any {@link DataAccessException} raised while obtaining
     *     the collection or reading the data
     */
    public void build() throws DataProviderException {

        try {
            MongoEnvironment mongoEnvironment = (MongoEnvironment) connection.getEnvironmentItemSettings();
            MongoTemplate mongoTemplate = connection.getConnectionTemplate();
            MongoCollection<Document> collection = mongoTemplate.getCollection(mongoEnvironment.getCollection());
            FlowStepContext stepContext = new FlowStepContext(stepPackageContext, stepGroup, step);

            Bson filter = buildAggregateFilter(step, stepContext);
            Bson sort = buildFindSorts(step);
            Bson projection = buildProjection(step);

            MongoIterable<Document> result = retrieveData(collection, filter, projection, sort, step.getLimit());

            // Close the cursor explicitly. find() cursors are opened with noCursorTimeout(true), so
            // a cursor abandoned mid-iteration (e.g. saveRecord throws) would otherwise remain open
            // on the server indefinitely.
            try (MongoCursor<Document> cursor = result.iterator()) {
                while (cursor.hasNext()) {
                    saveRecord(step, cursor.next(), stepContext, stepPackageContext);
                }
            }
        } catch (DataAccessException e) {
            throw new DataProviderException("Data Provider Exception", e);
        }
    }

    /**
     * Builds the aggregation pipeline for a grouped step: {@code $match} on {@code filter}, then
     * {@code $group} keyed by every step field, then {@code $sort} if the step declares sorts.
     */
    private List<Bson> buildGroupPipeline(MongoStep step, Bson filter) {
        Document aggregateFields = new Document();
        step.getFields().forEach(field ->
                aggregateFields.append(groupKeyName(field.getId()), "$" + field.getId()));

        BsonField aggregator = step.getGroup().getGroupOperator();

        List<Bson> aggregatedParams = new ArrayList<>();
        aggregatedParams.add(Aggregates.match(filter));
        aggregatedParams.add(Aggregates.group(aggregateFields, aggregator));

        if (!step.getSort().isEmpty()) {
            aggregatedParams.add(Aggregates.sort(getAggregatedSorts(step)));
        }

        // NOTE(review): step.getLimit() is applied only on the find() path; no $limit stage is
        // added here. Confirm whether that is intended.
        return aggregatedParams;
    }

    /** Combines the step's sorts, in declaration order, into one sort document for {@code find}. */
    private Bson buildFindSorts(MongoStep step) {
        List<Bson> sorts = step.getSort().stream()
                .map(StepSort::getSort)
                .collect(Collectors.toList());

        return orderBy(sorts);
    }

    /**
     * Builds the {@code $sort} for grouped results. Sorts on the group's output field (matched
     * case-insensitively) use no prefix; every other sort is prefixed with {@code _id}.
     */
    private Bson getAggregatedSorts(MongoStep step) {
        String groupName = step.getGroup().getFieldName();

        // NOTE(review): group keys are stored with dots stripped (see groupKeyName). Confirm that
        // StepSort.getSort(prefix) produces matching names for dotted field ids.
        List<Bson> sorts = step.getSort().stream()
                .map(sort -> {
                    String currentPrefix = sort.getFieldName().equalsIgnoreCase(groupName) ? "" : GROUP_ID_PREFIX;
                    return sort.getSort(currentPrefix);
                })
                .collect(Collectors.toList());

        return orderBy(sorts);
    }

    /**
     * Projects only the step's fields. For ids containing {@code ".["}, the part from that marker
     * onward is dropped, so the parent path is projected.
     */
    private Bson buildProjection(MongoStep step) {
        List<String> ids = step.getFields()
                .stream()
                .map(field -> (field.getId().contains(".[")) ? field.getId().substring(0, field.getId().indexOf(".[")) : field.getId())
                .collect(Collectors.toList());

        return Projections.fields(Projections.include(ids));
    }

    /**
     * Builds the query filter from the step's filter groups.
     *
     * <p>Each group's conditions are combined with the group's own operator. Groups are then
     * OR-ed together. With no groups, an empty document (match all) is returned, because the server
     * rejects an empty {@code $and} array.
     */
    private Bson buildAggregateFilter(MongoStep step, FlowStepContext stepContext) {
        List<Bson> filterGroups = new ArrayList<>();

        step.getFilterGroups().forEach(group -> {
            Bson groupedFilter = group.getType()
                    .getOperatorFunction()
                    .apply(group.getFilters().stream()
                            .map(filter -> {
                                // Resolve the filter's value (via transforms) before rendering it.
                                filter.processValueFromTransform(stepContext);
                                return filter.getCondition()
                                        .getFilter(filter);
                            })
                            .collect(Collectors.toList()));

            filterGroups.add(groupedFilter);
        });

        if (filterGroups.isEmpty()) {
            return new Document();
        }
        return filterGroups.size() == 1 ? filterGroups.get(0) : Filters.or(filterGroups);
    }

    /**
     * Runs the query. Grouped steps use the aggregation pipeline; other steps use {@code find} with
     * projection, sort, disk use allowed and no cursor timeout.
     *
     * @param limit maximum number of documents for {@code find}; {@code null} or a non-positive
     *     value means no limit
     */
    private MongoIterable<Document> retrieveData(MongoCollection<Document> collection, Bson filter, Bson projection, Bson sort, Integer limit) {

        if (step.getGroup() != null) {
            List<Bson> groupPipeline = buildGroupPipeline(step, filter);
            return collection.aggregate(groupPipeline).allowDiskUse(true);
        }

        FindIterable<Document> findResults = collection.find(filter)
                .projection(projection)
                .sort(sort)
                .allowDiskUse(true)
                .noCursorTimeout(true);

        if (limit != null && limit > 0) {
            findResults.limit(limit);
        }

        return findResults;
    }

    /**
     * Extracts the step's fields from one result document and caches the resulting record.
     *
     * <p>Each field is read from its source path, converted with {@code field.getValue}, and stored
     * under {@code field.getName()}. The full record is cached when this step's id is listed in the
     * package's step dependencies or the step is not flagged ignore-if-never-used. Primary-key
     * values are cached individually when the step has a primary key.
     */
    @Override
    public void saveRecord(FlowStepData step, Map<String, Object> recordData, FlowStepContext stepContext, FlowPackageContext stepPackageContext) {
        MongoStep mongoStep = (MongoStep) step;
        RecordFieldExtractor recordFieldExtractor = new RecordFieldExtractor();
        ExtractedRecord extractedRecord = new ExtractedRecord();

        Map<String, Object> primaryKeyValue = new HashMap<>();

        step.getFields().forEach(field -> {
            Object rawValue = recordFieldExtractor.getValue(recordData, sourceFieldPath(mongoStep, field.getId()));
            Object value = field.getValue(rawValue, stepContext);

            extractedRecord.put(field.getName(), value);

            if (field.isPrimaryKey()) {
                primaryKeyValue.put(field.getName(), value);
            }
        });

        if (stepPackageContext.getStepDependencies().contains(step.getId()) || !step.isIgnoreIfNeverUsed()) {
            stepPackageContext.addRecordInCache(step, extractedRecord);
        }

        if (step.hasPrimaryKey()) {
            primaryKeyValue.forEach((key, value) -> stepPackageContext.addRecordInCache(step, key, value));
        }

    }

    /**
     * Returns the path at which {@code fieldId} is read from a result document.
     *
     * <p>For ungrouped steps this is the id itself. For grouped steps, the group's output field
     * (matched case-insensitively) sits at the top level. Every other field sits under {@code _id},
     * using the same key name written by {@link #buildGroupPipeline}.
     */
    private static String sourceFieldPath(MongoStep mongoStep, String fieldId) {
        if (mongoStep.getGroup() == null
                || fieldId.equalsIgnoreCase(mongoStep.getGroup().getFieldName())) {
            return fieldId;
        }
        return GROUP_ID_PREFIX + "." + groupKeyName(fieldId);
    }

    /**
     * Returns the key used for {@code fieldId} inside the {@code $group} {@code _id} document.
     * Dots are removed, presumably because MongoDB rejects dotted field names there.
     */
    private static String groupKeyName(String fieldId) {
        return fieldId.replace(".", "");
    }
}
