package com.mario6.common.db;

import com.mario6.common.db.mapper.MapRowMapper;
import com.mario6.common.db.mapper.RowMapper;
import com.mario6.common.db.mapper.SingleColumnRowMapper;
import com.mario6.common.db.util.JdbcUtils;

import javax.sql.DataSource;
import java.sql.*;
import java.util.*;

/**
 * Thin JDBC access helper.
 *
 * <p>Every access method obtains a connection through {@code DataSourceManager}, runs one SQL
 * statement or stored-procedure call, and releases the statement and result set. It then closes
 * the connection unless {@code DataSourceManager.isTransactionFlag()} is set or auto-close was
 * disabled for the current thread with {@link #setAutoClose(boolean)}. Every {@link SQLException}
 * is wrapped in a {@code DataAccessException}, and the current connection is closed when that
 * happens.
 *
 * <p>Thread-safety: an instance only holds a {@code final} {@link DataSource} reference, and the
 * auto-close flag is kept per thread. NOTE(review): the overall guarantee also depends on how
 * {@code DataSourceManager} binds connections to callers, which is not visible here — confirm.
 */
public class JdbcTemplate {

    /** Source of all JDBC connections used by this template. */
    private final DataSource dataSource;

    /**
     * Per-thread flag telling whether the connection is closed after each access method.
     * Unset ({@code null}) or {@code true} means "close automatically". The field is static, so
     * the flag applies to every JdbcTemplate instance used on the same thread.
     */
    private static final ThreadLocal<Boolean> AUTO_CLOSE = new ThreadLocal<>();

    /**
     * Creates a template that obtains its connections from {@code dataSource}.
     *
     * @param dataSource the data source to draw connections from
     */
    public JdbcTemplate(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    //----------------------------------------------------------------------------
    // Queries returning any number of rows
    //----------------------------------------------------------------------------

    /** Runs {@code sql} without parameters and maps every row with {@code rowMapper}. */
    public <T> List<T> query(String sql, RowMapper<T> rowMapper) throws DataAccessException {
        return query(sql, null, rowMapper);
    }

    /** Runs {@code sql} with positional {@code params} and maps every row with {@code rowMapper}. */
    public <T> List<T> query(String sql, RowMapper<T> rowMapper, Object... params) throws DataAccessException {
        return query(sql, paramList(params), rowMapper);
    }

    /**
     * Runs a query and maps every row of the result set with {@code rowMapper}.
     *
     * @param sql       SQL with {@code ?} placeholders
     * @param params    values bound to the placeholders in order; {@code null} or empty binds none
     * @param rowMapper converts the current row into one element
     * @return the mapped rows in result-set order; empty when there are none
     * @throws DataAccessException wrapping any {@link SQLException}, or when no connection can be obtained
     */
    public <T> List<T> query(String sql, List params, RowMapper<T> rowMapper) throws DataAccessException {
        Connection conn = getConnection();
        PreparedStatement stat = null;
        ResultSet rs = null;
        try {
            stat = conn.prepareStatement(sql);
            setPreparedStatement(stat, params);
            rs = stat.executeQuery();
            return resultToList(rs, rowMapper);
        } catch (SQLException e) {
            // On any JDBC failure the current connection is closed unconditionally.
            closeConnection();
            throw new DataAccessException(e);
        } finally {
            closeResultSet(rs);
            closeStatement(stat);
            tryToCloseConnection();
        }
    }

    /** Runs a query and returns each row as a map built by {@code MapRowMapper}. */
    public List<Map<String, Object>> query(String sql, List params) throws DataAccessException {
        return query(sql, params, MapRowMapper.newInstance());
    }

    /** Runs a parameterless query and returns each row as a map. */
    public List<Map<String, Object>> query(String sql) throws DataAccessException {
        return query(sql, Collections.emptyList());
    }

    /** Runs a query with positional parameters and returns each row as a map. */
    public List<Map<String, Object>> query(String sql, Object... params) throws DataAccessException {
        return query(sql, paramList(params));
    }


    //----------------------------------------------------------------------------
    // Queries expected to return at most one row
    //----------------------------------------------------------------------------

    /**
     * Runs a query that is expected to return at most one row.
     *
     * @param sql       SQL with {@code ?} placeholders
     * @param params    values bound to the placeholders in order; {@code null} or empty binds none
     * @param rowMapper converts the row into the result
     * @return the mapped row, or {@code null} when the result set is empty
     * @throws DataAccessException when the result set has more than one row, wrapping any
     *         {@link SQLException}, or when no connection can be obtained
     */
    public <T> T queryForObject(String sql, List params, RowMapper<T> rowMapper) throws DataAccessException {
        Connection conn = getConnection();
        PreparedStatement stat = null;
        ResultSet rs = null;
        try {
            stat = conn.prepareStatement(sql);
            setPreparedStatement(stat, params);
            rs = stat.executeQuery();
            // Return before touching the cursor again: calling next() after it returned false
            // may throw on forward-only result sets (vendor-specific per the JDBC spec).
            if (!rs.next()) {
                return null;
            }
            T result = rowMapper.rowMap(rs);
            if (rs.next()) {
                throw new DataAccessException("结果集数据不止一行");
            }
            return result;
        } catch (SQLException e) {
            closeConnection();
            throw new DataAccessException(e);
        } finally {
            closeResultSet(rs);
            closeStatement(stat);
            tryToCloseConnection();
        }
    }

    /** Parameterless variant of {@link #queryForObject(String, List, RowMapper)}. */
    public <T> T queryForObject(String sql, RowMapper<T> rowMapper) throws DataAccessException {
        return queryForObject(sql, null, rowMapper);
    }

    /** Varargs variant of {@link #queryForObject(String, List, RowMapper)}. */
    public <T> T queryForObject(String sql, RowMapper<T> rowMapper, Object... params) throws DataAccessException {
        return queryForObject(sql, paramList(params), rowMapper);
    }

    /** Parameterless single-row query whose row is returned as a map ({@code null} if no row). */
    public Map<String, Object> queryForObject(String sql) throws DataAccessException {
        return queryForObject(sql, Collections.emptyList());
    }

    /** Single-row query with positional parameters whose row is returned as a map. */
    public Map<String, Object> queryForObject(String sql, Object... params) throws DataAccessException {
        return queryForObject(sql, paramList(params));
    }

    /** Single-row query whose row is mapped by {@code MapRowMapper} ({@code null} if no row). */
    public Map<String, Object> queryForObject(String sql, List params) throws DataAccessException {
        return queryForObject(sql, params, MapRowMapper.newInstance());
    }


    //----------------------------------------------------------------------------
    // Queries expected to return a single column
    //----------------------------------------------------------------------------

    /** Parameterless variant of {@link #queryForList(String, List, Class)}. */
    public <T> List<T> queryForList(String sql, Class<T> elementType) throws DataAccessException {
        return queryForList(sql, null, elementType);
    }

    /** Varargs variant of {@link #queryForList(String, List, Class)}. */
    public <T> List<T> queryForList(String sql, Class<T> elementType, Object... params) throws DataAccessException {
        return queryForList(sql, paramList(params), elementType);
    }

    /**
     * Runs a query and maps every row with {@code SingleColumnRowMapper} for {@code elementType}.
     * NOTE(review): which column is read and how values are converted is decided by
     * {@code SingleColumnRowMapper} (not visible here).
     */
    public <T> List<T> queryForList(String sql, List params, Class<T> elementType) throws DataAccessException {
        RowMapper<T> mapper = SingleColumnRowMapper.newInstance(elementType);
        return query(sql, params, mapper);
    }


    //----------------------------------------------------------------------------
    // Queries expected to return a single row with a single column
    //----------------------------------------------------------------------------

    /** Parameterless variant of {@link #queryForObject(String, List, Class)}. */
    public <T> T queryForObject(String sql, Class<T> requireType) throws DataAccessException {
        return queryForObject(sql, null, requireType);
    }

    /** Varargs variant of {@link #queryForObject(String, List, Class)}. */
    public <T> T queryForObject(String sql, Class<T> requireType, Object... params) throws DataAccessException {
        return queryForObject(sql, paramList(params), requireType);
    }

    /**
     * Runs a single-row query and maps that row with {@code SingleColumnRowMapper} for
     * {@code requireType}.
     *
     * @return the mapped value, or {@code null} when the result set is empty
     */
    public <T> T queryForObject(String sql, List params, Class<T> requireType) throws DataAccessException {
        RowMapper<T> mapper = SingleColumnRowMapper.newInstance(requireType);
        return queryForObject(sql, params, mapper);
    }


    //----------------------------------------------------------------------------
    // Updates (INSERT / UPDATE / DELETE / DDL)
    //----------------------------------------------------------------------------

    /** Executes a parameterless update and returns the affected row count. */
    public int update(String sql) throws DataAccessException {
        return update(sql, Collections.emptyList());
    }

    /** Executes an update with positional parameters and returns the affected row count. */
    public int update(String sql, Object... params) throws DataAccessException {
        return update(sql, paramList(params));
    }

    /** Executes an update without retrieving generated keys. */
    public int update(String sql, List params) throws DataAccessException {
        return update(sql, params, null);
    }

    /** Executes a parameterless update and stores the first generated key into {@code key}. */
    public int update(String sql, AutoKey key) throws DataAccessException {
        return update(sql, Collections.emptyList(), key);
    }

    /** Varargs variant of {@link #update(String, List, AutoKey)}. */
    public int update(String sql, AutoKey key, Object... params) throws DataAccessException {
        return update(sql, paramList(params), key);
    }

    /**
     * Executes an update statement.
     *
     * @param sql    SQL with {@code ?} placeholders
     * @param params values bound to the placeholders in order; {@code null} or empty binds none
     * @param key    when non-null, generated keys are requested and, if at least one row was
     *               affected, the first column of the first generated-key row is stored in it
     * @return the number of affected rows
     * @throws DataAccessException wrapping any {@link SQLException}, or when no connection can be obtained
     */
    public int update(String sql, List params, AutoKey key) throws DataAccessException {
        Connection conn = getConnection();
        PreparedStatement stat = null;
        try {
            stat = (key != null)
                    ? conn.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)
                    : conn.prepareStatement(sql);
            setPreparedStatement(stat, params);
            int rows = stat.executeUpdate();

            if (key != null && rows > 0) {
                ResultSet keys = stat.getGeneratedKeys();
                try {
                    resolveAutoKey(keys, key);
                } finally {
                    // The generated-key ResultSet is a JDBC resource of its own; release it.
                    closeResultSet(keys);
                }
            }
            return rows;
        } catch (SQLException e) {
            closeConnection();
            throw new DataAccessException(e);
        } finally {
            closeStatement(stat);
            tryToCloseConnection();
        }
    }


    //------------------------------------------------------------
    // Stored procedures
    //------------------------------------------------------------

    /**
     * Invokes a stored procedure.
     *
     * <p>Input arguments are bound to positions {@code 1..inArgs.length}. Output parameters are
     * registered at the positions immediately following them.
     *
     * @param call    the JDBC call string, e.g. {@code "{call proc(?, ?)}"}
     * @param inArgs  input argument values, or {@code null} for none
     * @param outArgs SQL type codes ({@link java.sql.Types}) of the output parameters, or
     *                {@code null} for none
     * @return map from 1-based output index (1 = first output parameter) to its value, or
     *         {@code null} when there are no output parameters
     * @throws DataAccessException wrapping any {@link SQLException}, or when no connection can be obtained
     */
    public Map<Integer, Object> call(String call, Object[] inArgs, int[] outArgs) throws DataAccessException {
        Connection conn = getConnection();
        CallableStatement stat = null;
        try {
            stat = conn.prepareCall(call);
            if (inArgs != null) {
                setPreparedStatement(stat, Arrays.asList(inArgs));
            }
            // Output parameters follow the input parameters.
            final int start = (inArgs == null) ? 1 : inArgs.length + 1;
            final int outCount = (outArgs == null) ? 0 : outArgs.length;
            if (outCount > 0) {
                registerOutParameter(stat, start, outArgs);
            }
            stat.execute();
            if (outCount == 0) {
                return null;
            }
            Map<Integer, Object> result = new HashMap<>();
            for (int i = 0; i < outCount; i++) {
                result.put(i + 1, stat.getObject(start + i));
            }
            return result;
        } catch (SQLException e) {
            closeConnection();
            throw new DataAccessException(e);
        } finally {
            closeStatement(stat);
            tryToCloseConnection();
        }
    }

    //------------------------------------------------------------
    // Count helpers: derive "SELECT count(1) ..." from a query
    //------------------------------------------------------------

    /** Counts the rows of a parameterless query using the derived count query. */
    public int count(String sql) throws DataAccessException {
        return queryForObject(queryToCount(sql), Integer.class);
    }

    /** Counts the rows of {@code sql} with list-form parameters using the derived count query. */
    public int count(String sql, List params) throws DataAccessException {
        // Must use the (sql, List, Class) overload: (sql, Class, Object...) would bind the whole
        // list to the first placeholder as a single value.
        return queryForObject(queryToCount(sql), params, Integer.class);
    }

    /** Counts the rows of {@code sql} with positional parameters using the derived count query. */
    public int count(String sql, Object... params) throws DataAccessException {
        return queryForObject(queryToCount(sql), Integer.class, params);
    }

    /** Like {@link #count(String, List)} but reads the count as a {@code long}. */
    public long countForLong(String sql, List params) throws DataAccessException {
        return queryForObject(queryToCount(sql), params, Long.class);
    }

    /** Like {@link #count(String, Object...)} but reads the count as a {@code long}. */
    public long countForLong(String sql, Object... params) throws DataAccessException {
        return queryForObject(queryToCount(sql), Long.class, params);
    }


    //------------------------------
    // Connection reuse (can save reconnects)
    //------------------------------

    /**
     * Tells whether access methods close the connection when they finish, for the current thread.
     *
     * @return {@code true} unless {@link #setAutoClose(boolean)} was last called with {@code false}
     *         on this thread
     */
    public boolean isAutoClose() {
        return !Boolean.FALSE.equals(AUTO_CLOSE.get());
    }

    /**
     * Sets the auto-close flag for the current thread. The flag is shared by all JdbcTemplate
     * instances on that thread.
     *
     * <p>With {@code false}, access methods leave the connection open and the caller must release
     * it with {@link #closeConnection()}.
     *
     * @param autoClose whether to close the connection after each access method
     */
    public void setAutoClose(boolean autoClose) {
        if (autoClose) {
            // "true" is the default, so clearing the slot is equivalent and keeps no value
            // attached to (possibly pooled) threads.
            AUTO_CLOSE.remove();
        } else {
            AUTO_CLOSE.set(Boolean.FALSE);
        }
    }

    //------------------------------------------------------------
    // Internal helpers
    //------------------------------------------------------------

    /**
     * Derives a count query from {@code sql}. It keeps the text from the first {@code " FROM "}
     * (case-insensitive, space-delimited) up to the first {@code " LIMIT"} that follows it, and
     * prefixes {@code SELECT count(1) total}.
     *
     * <p>NOTE: this is plain text matching. A FROM inside a sub-select in the column list, or
     * keywords separated by tabs/newlines instead of spaces, are not recognised.
     *
     * @throws IllegalArgumentException if no {@code " FROM "} is found
     */
    private String queryToCount(String sql) {
        // Search the original string case-insensitively. Upper-casing a copy can change its
        // length (e.g. 'ß' -> "SS") and depends on the default locale, which would make the
        // indices invalid for the original string.
        int fromIdx = indexOfIgnoreCase(sql, " FROM ", 0);
        if (fromIdx < 0) {
            throw new IllegalArgumentException("无效的统计SQL:" + sql);
        }
        int limitIdx = indexOfIgnoreCase(sql, " LIMIT", fromIdx);

        String snippet = (limitIdx < 0)
                ? sql.substring(fromIdx)
                : sql.substring(fromIdx, limitIdx + 1);
        return "SELECT count(1) total " + snippet;
    }

    /** Case-insensitive {@link String#indexOf(String, int)}; returns -1 when not found. */
    private static int indexOfIgnoreCase(String text, String needle, int fromIndex) {
        int last = text.length() - needle.length();
        for (int i = Math.max(fromIndex, 0); i <= last; i++) {
            if (text.regionMatches(true, i, needle, 0, needle.length())) {
                return i;
            }
        }
        return -1;
    }

    /** Converts a varargs array to list form; a {@code null} array binds no parameters. */
    private static List<Object> paramList(Object[] params) {
        return (params == null) ? Collections.<Object>emptyList() : Arrays.asList(params);
    }

    /**
     * Reads the first column of the first generated-key row, typed according to the concrete
     * {@code AutoKey} subclass, and stores it with {@code key.set}. Does nothing if there is no row.
     */
    private void resolveAutoKey(ResultSet keys, AutoKey key) throws SQLException {
        if (keys.next()) {
            Object value;
            if (key instanceof AutoKey.IntegerAutoKey) {
                value = keys.getInt(1);
            } else if (key instanceof AutoKey.LongAutoKey) {
                value = keys.getLong(1);
            } else if (key instanceof AutoKey.StringAutoKey) {
                value = keys.getString(1);
            } else {
                value = keys.getObject(1);
            }
            key.set(value);
        }
    }

    /** Maps every remaining row of {@code rs} with {@code rowMapper}, in cursor order. */
    private <T> List<T> resultToList(ResultSet rs, RowMapper<T> rowMapper) throws SQLException {
        List<T> results = new ArrayList<>();
        while (rs.next()) {
            results.add(rowMapper.rowMap(rs));
        }
        return results;
    }

    /** Binds {@code params} to placeholders 1..n via {@code setObject}; null or empty binds nothing. */
    private void setPreparedStatement(PreparedStatement stat, List<?> params) throws SQLException {
        if (params == null || params.isEmpty()) {
            return;
        }
        for (int i = 0; i < params.size(); i++) {
            stat.setObject(i + 1, params.get(i));
        }
    }

    /** Registers output parameters with the given SQL type codes, starting at position {@code start}. */
    private void registerOutParameter(CallableStatement stat, int start, int[] types) throws SQLException {
        for (int i = 0; i < types.length; i++) {
            stat.registerOutParameter(start + i, types[i]);
        }
    }

    /** Obtains a connection for this template's data source from {@code DataSourceManager}. */
    private Connection getConnection() {
        try {
            return DataSourceManager.getConnection(dataSource);
        } catch (SQLException e) {
            throw new DataAccessException("获取数据库连接失败", e);
        }
    }

    /**
     * Closes the connection {@code DataSourceManager.getCurrentConnection} reports for this
     * template's data source. Callers that disabled auto-close use this to release it.
     */
    public void closeConnection() {
        Connection conn = DataSourceManager.getCurrentConnection(dataSource);
        JdbcUtils.closeConnection(conn);
    }

    /** Releases a result set via {@code JdbcUtils}. */
    private void closeResultSet(ResultSet rs) {
        JdbcUtils.closeResultSet(rs);
    }

    /** Releases a statement via {@code JdbcUtils}. */
    private void closeStatement(PreparedStatement stat) {
        JdbcUtils.closeStatement(stat);
    }

    /**
     * Closes the current connection unless {@code DataSourceManager.isTransactionFlag()} is set
     * (presumably a transaction is active — confirm) or auto-close is disabled for this thread.
     *
     * @return whether the connection was closed
     */
    private boolean tryToCloseConnection() {
        if (!DataSourceManager.isTransactionFlag() && isAutoClose()) {
            closeConnection();
            return true;
        }
        return false;
    }
}
