package pl.codewise.commons.aws;

import com.amazonaws.services.ec2.model.Instance;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;
import java.util.List;

import static java.util.stream.Collectors.toList;

/**
 * Assigns a name of the form {@code <instancePrefix><number>} to the EC2 instance this code is
 * running on, when that instance does not have a (non-blank) name yet.
 *
 * <p>The number is the lowest positive integer that is not already used as a suffix by another
 * instance, found in the configured auto scaling groups, whose name starts with the prefix.
 * After setting the name the updater re-reads the instances and, if more than one instance
 * carries the chosen name, sleeps and picks a new name, up to {@value #COLLISION_RETRIES} times.
 */
public class InstanceNameUpdater {

    private static final Logger log = LoggerFactory.getLogger(InstanceNameUpdater.class);

    /** Maximum number of times a new name is picked after a name collision has been detected. */
    private static final int COLLISION_RETRIES = 3;

    private final IpProvider ipProvider;
    private final String instancePrefix;
    private final Ec2Wrapper ec2Wrapper;
    private final Collection<String> autoScalingGroups;
    private final Sleeper sleeper;

    /**
     * @param ipProvider        supplies the IP that is matched against the public IP of the
     *                          described instances to find "this" instance
     * @param instancePrefix    prefix of every generated name; the instance number is appended to it
     * @param ec2Wrapper        used to describe instances and to set the instance name
     * @param autoScalingGroups values passed one by one to {@code ec2Wrapper.describeInstances}
     * @param sleeper           invoked before the name is first set and before every collision
     *                          retry (presumably a randomised delay - confirm against the
     *                          {@code Sleeper} implementation)
     */
    public InstanceNameUpdater(IpProvider ipProvider, String instancePrefix, Ec2Wrapper ec2Wrapper,
            Collection<String> autoScalingGroups, Sleeper sleeper) {
        this.ipProvider = ipProvider;
        this.instancePrefix = instancePrefix;
        this.ec2Wrapper = ec2Wrapper;
        this.autoScalingGroups = autoScalingGroups;
        this.sleeper = sleeper;
    }

    /**
     * Ensures this instance has a name and returns it.
     *
     * <p>If the current name is blank, a new name is generated, set via the {@link Ec2Wrapper}
     * and then checked for collisions; otherwise the existing name is returned untouched.
     *
     * @return the existing name, or the last name that was set on this instance
     * @throws RuntimeException if no described instance has a public IP equal to the IP returned
     *                          by the {@link IpProvider}
     */
    public String updateEc2InstanceName() {
        List<Instance> instances = getInstancesFromScalingGroups(autoScalingGroups);
        Instance thisInstance = getThisInstance(instances);
        String name = Ec2Wrapper.getInstanceName(thisInstance);
        if (StringUtils.isBlank(name)) {
            log.info("Setting instance name");
            sleeper.sleep();
            // Note: the number is computed from the snapshot taken before the sleep.
            String newName = updateInstanceName(instances, thisInstance);
            return resolveNewNameIfNameCollides(newName);
        } else {
            log.info("Instance name is {}. No need to update.", name);
            return name;
        }
    }

    /** Describes the instances of every given group and returns them as one flat list. */
    private List<Instance> getInstancesFromScalingGroups(Collection<String> autoScalingGroups) {
        return autoScalingGroups
                .stream()
                .map(ec2Wrapper::describeInstances)
                .flatMap(Collection::stream)
                .collect(toList());
    }

    /**
     * Re-picks the name while it is shared by more than one instance, at most
     * {@link #COLLISION_RETRIES} times.
     *
     * @return the last name set; it is not re-checked after the final retry, so it may still
     *         collide
     */
    private String resolveNewNameIfNameCollides(String name) {
        String latestName = name;
        for (int i = 0; i < COLLISION_RETRIES && countCollidingInstances(latestName) > 1; i++) {
            latestName = updateCollidingInstanceName(latestName);
        }

        return latestName;
    }

    /**
     * Counts instances (from a fresh describe call) whose name equals {@code name}. The caller
     * treats a count above one as a collision, i.e. it expects this instance itself to be counted.
     */
    private long countCollidingInstances(String name) {
        return getInstancesFromScalingGroups(autoScalingGroups)
                .stream()
                .map(Ec2Wrapper::getInstanceName)
                .filter(instanceName -> StringUtils.equals(instanceName, name))
                .count();
    }

    /** Sleeps, re-reads the instances and sets a freshly computed name on this instance. */
    private String updateCollidingInstanceName(String name) {
        log.info("Instance name collision detected for name {}, going to sleep before retry", name);
        sleeper.sleep();
        List<Instance> instancesAfterRandomSleep = getInstancesFromScalingGroups(autoScalingGroups);
        Instance thisInstance = getThisInstance(instancesAfterRandomSleep);
        return updateInstanceName(instancesAfterRandomSleep, thisInstance);
    }

    /**
     * Sets the name of {@code thisInstance} to the prefix followed by the first free number
     * computed from {@code instances}.
     *
     * @return the name that was set
     */
    private String updateInstanceName(List<Instance> instances, Instance thisInstance) {
        int firstAvailableInstanceNumber = getFirstAvailableInstanceNumber(instances);

        String newName = instancePrefix + firstAvailableInstanceNumber;
        ec2Wrapper.setInstanceName(thisInstance, newName);
        log.info("Instance name set to {}", newName);
        return newName;
    }

    /**
     * Returns the smallest positive number that is not used as a numeric suffix by any instance
     * whose name starts with the prefix. Names with a non-numeric or non-positive suffix are
     * ignored.
     */
    private int getFirstAvailableInstanceNumber(List<Instance> instances) {
        int highestNumber = 0;
        List<Integer> instanceNumbers = instances
                .stream()
                .map(Ec2Wrapper::getInstanceName)
                .filter(instanceName -> StringUtils.startsWith(instanceName, instancePrefix))
                .map(this::getInstanceNumber)
                .filter(number -> number > 0)
                .sorted()
                .collect(toList());

        // Walk the used numbers in ascending order; the first gap larger than one is the
        // lowest free slot. Duplicates yield a difference of zero and are skipped naturally.
        for (Integer instanceNumber : instanceNumbers) {
            if (instanceNumber - highestNumber > 1) {
                return highestNumber + 1;
            }
            highestNumber = instanceNumber;
        }

        return highestNumber + 1;
    }

    /**
     * Parses the part of {@code instanceName} that follows the prefix as an integer.
     *
     * @param instanceName a name that starts with the prefix
     * @return the parsed number, or {@code -1} when the suffix is not a valid integer
     */
    private Integer getInstanceNumber(String instanceName) {
        String instanceNumberString =
                instanceName.substring(instancePrefix.length(), instanceName.length());
        try {
            return Integer.parseInt(instanceNumberString);
        } catch (NumberFormatException e) {
            log.warn("Could not get instance number from name: {}", instanceName);
            return -1;
        }
    }

    /**
     * Finds the instance whose public IP equals the IP reported by the {@link IpProvider}.
     *
     * @throws RuntimeException if no instance matches (including when the provider returns
     *                          {@code null})
     */
    private Instance getThisInstance(List<Instance> instances) {
        String ip = ipProvider.getIp();
        if (ip != null) {
            for (Instance instance : instances) {
                // Call equals on the known non-null ip: an instance without a public IP
                // reports null here and must simply be skipped instead of causing an NPE.
                if (ip.equals(instance.getPublicIpAddress())) {
                    return instance;
                }
            }
        }

        throw new RuntimeException("Could not determine Instance for IP: " + ip);
    }
}
