package com.mario6.common.db;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.*;

/**
 * Data source manager.
 *
 * <p>Mainly provides connection lookup and transaction management. At most one
 * {@link Connection} per {@link DataSource} is bound to the current thread, and
 * {@link #begin()}, {@link #commit()} and {@link #rollback()} act on every connection bound to
 * that thread.
 *
 * <p>All state lives in {@link ThreadLocal}s, so each thread only sees its own connections and
 * its own transaction flag. This class never closes or unbinds the connections it hands out.
 */
public class DataSourceManager {
    /* Connections bound to the current thread, keyed by the data source they came from. */
    private static final ThreadLocal<Map<DataSource, Connection>> LOCAL_CONNECTION = new ThreadLocal<>();
    /* Flag: whether a transaction is currently open on this thread. */
    private static final ThreadLocal<Boolean> TRANSACTION = new ThreadLocal<>();
    /*
     * Connections whose auto-commit mode was switched off by this class for the current
     * transaction. Auto-commit is switched back on for exactly these when the transaction ends,
     * so connections that were handed out in manual-commit mode are left untouched.
     */
    private static final ThreadLocal<Set<Connection>> AUTO_COMMIT_SUSPENDED = new ThreadLocal<>();

    /**
     * Returns the connection bound to the current thread for {@code ds}.
     *
     * <p>A new connection is obtained from {@code ds} and bound to the thread if none is bound
     * yet, or if the bound one has been closed. If a transaction is open on this thread, the
     * returned connection has auto-commit disabled.
     *
     * @param ds the data source to obtain a connection from
     * @return an open connection for {@code ds}
     * @throws SQLException if obtaining or configuring the connection fails
     */
    public static Connection getConnection(DataSource ds) throws SQLException {
        Map<DataSource, Connection> dsToConnMap = LOCAL_CONNECTION.get();
        if (dsToConnMap != null) {
            Connection conn = dsToConnMap.get(ds);
            if (conn != null && !conn.isClosed()) {
                // The connection already bound to this thread is still usable.
                if (isTransactionFlag()) {
                    suspendAutoCommit(conn);
                }
                return conn;
            }
        }
        // No usable connection is bound yet: create one and bind it to this thread.
        Connection conn = ds.getConnection();
        if (isTransactionFlag()) {
            suspendAutoCommit(conn);
        }
        if (dsToConnMap == null) {
            dsToConnMap = new HashMap<>();
            LOCAL_CONNECTION.set(dsToConnMap);
        }
        dsToConnMap.put(ds, conn);
        return conn;
    }

    /**
     * Returns the connection currently bound to this thread for {@code ds}, without creating one.
     *
     * @param ds the data source whose bound connection is wanted
     * @return the bound connection (which may already be closed), or {@code null} if none is bound
     */
    public static Connection getCurrentConnection(DataSource ds) {
        Map<DataSource, Connection> dsToConnMap = LOCAL_CONNECTION.get();
        return dsToConnMap == null ? null : dsToConnMap.get(ds);
    }

    /**
     * Starts a transaction on the current thread.
     *
     * <p>Auto-commit is disabled on every open connection already bound to this thread. Every
     * connection later returned by {@link #getConnection(DataSource)} also has auto-commit
     * disabled, until {@link #commit()} or {@link #rollback()} is called. Calling this before any
     * connection has been bound is allowed.
     *
     * @throws RuntimeException wrapping the {@link SQLException} if a connection cannot be
     *     switched to manual-commit mode
     */
    public static void begin() {
        try {
            for (Connection conn : getAllValidConnection()) {
                suspendAutoCommit(conn);
            }
            // Makes getConnection() disable auto-commit on connections obtained from now on.
            onTransactionFlag();
        } catch (SQLException e) {
            throw new RuntimeException("开始事物失败", e);
        }
    }

    /**
     * Commits the current thread's transaction on every open, manual-commit connection bound to
     * this thread.
     *
     * <p>After each connection is committed, its auto-commit mode is restored if it was switched
     * off by this class. If a commit fails, the remaining connections stay in manual-commit mode,
     * so a following {@link #rollback()} can still undo their work.
     *
     * @throws RuntimeException wrapping the {@link SQLException} if any commit fails
     */
    public static void commit() {
        try {
            finishTransaction(true);
        } catch (SQLException e) {
            throw new RuntimeException("提交事物失败", e);
        }
    }

    /**
     * Rolls back the current thread's transaction on every open, manual-commit connection bound to
     * this thread.
     *
     * <p>After each rollback, the connection's auto-commit mode is restored if it was switched off
     * by this class.
     *
     * @throws RuntimeException wrapping the {@link SQLException} if any rollback fails
     */
    public static void rollback() {
        try {
            finishTransaction(false);
        } catch (SQLException e) {
            throw new RuntimeException("回滚事物失败", e);
        }
    }



    //----------------------
    // Transaction helpers
    //----------------------

    /**
     * Ends the current thread's transaction by committing or rolling back each bound connection.
     *
     * @param commit {@code true} to commit, {@code false} to roll back
     * @throws SQLException on the first connection that fails; later connections are not touched
     */
    private static void finishTransaction(boolean commit) throws SQLException {
        // Clear the flag first so later getConnection() calls stop disabling auto-commit.
        offTransactionFlag();
        for (Connection conn : getAllValidConnection()) {
            // JDBC forbids commit()/rollback() on a connection in auto-commit mode, and such a
            // connection has no pending work anyway.
            if (conn.getAutoCommit()) {
                continue;
            }
            if (commit) {
                conn.commit();
            } else {
                conn.rollback();
            }
            resumeAutoCommit(conn);
        }
        // Every connection has been finished; drop bookkeeping (including closed connections).
        AUTO_COMMIT_SUSPENDED.remove();
    }

    /**
     * Returns the connections bound to the current thread that are non-null and not closed.
     *
     * @return a new list, empty if no connection has been bound on this thread
     * @throws SQLException if checking a connection's closed state fails
     */
    private static List<Connection> getAllValidConnection() throws SQLException {
        Map<DataSource, Connection> dsToConMap = LOCAL_CONNECTION.get();
        if (dsToConMap == null) {
            // Nothing bound yet (e.g. begin() called before the first getConnection()).
            return new ArrayList<>();
        }
        List<Connection> results = new ArrayList<>();
        for (Connection conn : dsToConMap.values()) {
            if (conn != null && !conn.isClosed()) {
                results.add(conn);
            }
        }
        return results;
    }

    /**
     * Switches {@code conn} to manual-commit mode and remembers it, but only if it is currently
     * in auto-commit mode.
     */
    private static void suspendAutoCommit(Connection conn) throws SQLException {
        if (conn.getAutoCommit()) {
            conn.setAutoCommit(false);
            Set<Connection> suspended = AUTO_COMMIT_SUSPENDED.get();
            if (suspended == null) {
                // Identity semantics: track the exact connection objects, whatever equals() does.
                suspended = Collections.newSetFromMap(new IdentityHashMap<Connection, Boolean>());
                AUTO_COMMIT_SUSPENDED.set(suspended);
            }
            suspended.add(conn);
        }
    }

    /** Re-enables auto-commit on {@code conn} if this class previously switched it off. */
    private static void resumeAutoCommit(Connection conn) throws SQLException {
        Set<Connection> suspended = AUTO_COMMIT_SUSPENDED.get();
        if (suspended != null && suspended.contains(conn)) {
            conn.setAutoCommit(true);
            suspended.remove(conn);
        }
    }


    private static void onTransactionFlag() {
        TRANSACTION.set(true);
    }


    private static void offTransactionFlag() {
        TRANSACTION.remove();
    }

    /**
     * Returns whether a transaction is currently open on this thread.
     *
     * @return {@code true} between {@link #begin()} and the next {@link #commit()} or
     *     {@link #rollback()}, otherwise {@code false}
     */
    public static boolean isTransactionFlag() {
        return Boolean.TRUE.equals(TRANSACTION.get());
    }

}
