package com.artfess.base.interceptor;

import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.util.*;
import java.util.regex.Pattern;

import com.alibaba.fastjson.JSONObject;
import com.artfess.base.aop.DataPermissionAspect;
import com.artfess.base.exception.ApplicationException;
import com.artfess.base.util.AppUtil;
import com.artfess.base.util.AuthenticationUtil;
import com.artfess.base.util.BeanUtils;
import com.artfess.base.util.StringUtil;
import com.baomidou.mybatisplus.core.enums.IEnum;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.extension.handlers.AbstractSqlParserHandler;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.type.TypeHandlerRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.support.rowset.SqlRowSet;
import org.springframework.jdbc.support.rowset.SqlRowSetMetaData;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;


/**
 * Data-permission filter implemented as a MyBatis {@code StatementHandler.prepare} interceptor.
 * <p>
 * UPDATE and DELETE statements are restricted here; SELECT statements are handled in an aspect
 * (presumably {@link DataPermissionAspect}). This could be changed to intercept {@code Executor.update}.
 * Statements executed outside an HTTP request are passed through untouched.
 * <p>
 * It also implements a "field export" mode: when the current request carries {@code export=field},
 * the column metadata of the intercepted query is written to the HTTP response as JSON instead.
 *
 * @author liygui
 */
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class DataPermissionInterceptor extends AbstractSqlParserHandler implements Interceptor {

	private static final Logger logger = LoggerFactory.getLogger(DataPermissionInterceptor.class);

	/** Maximum number of values per {@code IN (...)} list; larger sets are split and OR-ed together. */
	private static final int IN_CLAUSE_CHUNK_SIZE = 500;

	/** Collapses runs of whitespace in the SQL before it is inspected. */
	private static final Pattern WHITESPACE = Pattern.compile("\\s+");

	/** Detects whether a statement already contains a WHERE keyword (anywhere in its text). */
	private static final Pattern WHERE_KEYWORD = Pattern.compile("\\bwhere\\b", Pattern.CASE_INSENSITIVE);

	@Override
	public Object intercept(Invocation invocation) throws Throwable {
		StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
		// MetaObject gives reflective access to statementHandler's properties. The target is a
		// RoutingStatementHandler whose "delegate" (a BaseStatementHandler) holds the mappedStatement.
		MetaObject metaObject = MetaObject.forObject(statementHandler, SystemMetaObject.DEFAULT_OBJECT_FACTORY,
				SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY, new DefaultReflectorFactory());
		MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");

		// Log handling goes through a queue outside any request, so there may be no servlet request
		// bound to this thread; such statements are executed untouched.
		Object requestAttributes = RequestContextHolder.getRequestAttributes();
		if (!(requestAttributes instanceof ServletRequestAttributes)) {
			return invocation.proceed();
		}
		ServletRequestAttributes attributes = (ServletRequestAttributes) requestAttributes;
		HttpServletRequest request = attributes.getRequest();

		if ("field".equals(request.getParameter("export"))) {
			return exportSqlFields(invocation, metaObject, mappedStatement, statementHandler.getBoundSql(),
					attributes.getResponse());
		}

		Map<String, Object> mapThreadLocal = AuthenticationUtil.getMapThreadLocal();
		if (BeanUtils.isEmpty(mapThreadLocal)) {
			// No data permission configured for this request.
			return invocation.proceed();
		}

		SqlCommandType commandType = mappedStatement.getSqlCommandType();
		if (commandType != SqlCommandType.UPDATE && commandType != SqlCommandType.DELETE) {
			// SELECTs are filtered by the aspect; INSERTs and other commands have no WHERE clause to restrict.
			return invocation.proceed();
		}

		String condition = buildPermissionCondition(mapThreadLocal);
		if (StringUtil.isNotEmpty(condition)) {
			String sql = appendCondition(statementHandler.getBoundSql().getSql(), condition);
			logger.debug(" custom sql {}", sql);
			// Replace the SQL that MyBatis is about to prepare with the filtered one.
			metaObject.setValue("delegate.boundSql.sql", sql);
		}
		return invocation.proceed();
	}

	/**
	 * Writes the result-column metadata of the intercepted query to the HTTP response as
	 * {@code {"fields": [...]}} (UTF-8 JSON) instead of executing the statement.
	 * <p>
	 * NOTE(review): returning {@code null} from {@code prepare} hands MyBatis no Statement, so the
	 * surrounding query fails after the response has already been written and flushed. That abort
	 * appears to be what stops the normal result from being written — confirm with the caller.
	 *
	 * @return {@code null} after writing the response, or the normal result when no response is bound
	 */
	private Object exportSqlFields(Invocation invocation, MetaObject metaObject, MappedStatement mappedStatement,
			BoundSql boundSql, HttpServletResponse response) throws Throwable {
		if (response == null) {
			logger.warn("export=field requested but no HttpServletResponse is bound; executing statement normally");
			return invocation.proceed();
		}
		try {
			// Let any configured MyBatis-Plus SQL parsers rewrite boundSql before it is read below.
			this.sqlParser(metaObject);
		} catch (Exception e) {
			logger.error("intercept sql error: [{}]", boundSql.getSql(), e);
		}
		String sql = WHITESPACE.matcher(boundSql.getSql()).replaceAll(" ").toLowerCase(Locale.ROOT);
		// Values are bound to the existing '?' placeholders rather than spliced into the SQL text.
		Object[] args = resolveParameterValues(mappedStatement.getConfiguration(), boundSql);

		JdbcTemplate jdbcTemplate = AppUtil.getBean(JdbcTemplate.class);
		List<TableField> fieldList = getSqlFields(jdbcTemplate, sql, args);

		JSONObject body = new JSONObject();
		body.put("fields", fieldList);
		response.setContentType("application/json;charset=UTF-8");
		ServletOutputStream output = response.getOutputStream();
		output.write(body.toString().getBytes(StandardCharsets.UTF_8));
		output.flush();
		return null;
	}

	/**
	 * Resolves the value for each {@code ?} placeholder of {@code boundSql}, in placeholder order,
	 * using the same lookup order as MyBatis' {@code DefaultParameterHandler}.
	 *
	 * @return one bind value per parameter mapping ({@code null} for OUT parameters and unresolvable properties)
	 */
	private static Object[] resolveParameterValues(Configuration configuration, BoundSql boundSql) {
		List<ParameterMapping> mappings = boundSql.getParameterMappings();
		if (mappings == null || mappings.isEmpty()) {
			return new Object[0];
		}
		TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
		Object parameterObject = boundSql.getParameterObject();
		MetaObject parameterMeta = null;
		Object[] values = new Object[mappings.size()];
		for (int i = 0; i < mappings.size(); i++) {
			ParameterMapping mapping = mappings.get(i);
			if (mapping.getMode() == ParameterMode.OUT) {
				// OUT parameters carry no input value; the slot stays null so later placeholders stay aligned.
				continue;
			}
			String propertyName = mapping.getProperty();
			Object value;
			if (boundSql.hasAdditionalParameter(propertyName)) {
				// e.g. <foreach> items and <bind> variables
				value = boundSql.getAdditionalParameter(propertyName);
			} else if (parameterObject == null) {
				value = null;
			} else if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
				// The whole parameter object is a single simple value.
				value = parameterObject;
			} else {
				if (parameterMeta == null) {
					parameterMeta = configuration.newMetaObject(parameterObject);
				}
				value = parameterMeta.hasGetter(propertyName) ? parameterMeta.getValue(propertyName) : null;
			}
			values[i] = toBindValue(value);
		}
		return values;
	}

	/**
	 * Converts a MyBatis parameter value into a JDBC bind argument.
	 *
	 * @param param parameter value, may be {@code null}
	 * @return {@code null}, numbers, strings and dates unchanged; {@link IEnum#getValue()} for IEnums;
	 *         otherwise {@code param.toString()}
	 */
	private static Object toBindValue(Object param) {
		if (param == null || param instanceof Number || param instanceof String || param instanceof Date) {
			return param;
		}
		if (param instanceof IEnum) {
			return ((IEnum<?>) param).getValue();
		}
		return param.toString();
	}

	/**
	 * Builds the permission condition from the thread-local permission map: the CREATE_BY_ check and
	 * the CREATE_ORG_ID_ IN-lists, OR-ed together when both are present.
	 * <p>
	 * NOTE(review): values are inlined unquoted, as before; this assumes they are numeric-like ids from
	 * the authentication context, never user input — confirm.
	 *
	 * @return the condition, or an empty string when neither key is present
	 */
	@SuppressWarnings("unchecked")
	private String buildPermissionCondition(Map<String, Object> mapThreadLocal) {
		List<String> conditions = new ArrayList<>(2);
		if (mapThreadLocal.containsKey(DataPermissionAspect.CREATE_BY_)) {
			conditions.add(" CREATE_BY_ =  " + mapThreadLocal.get(DataPermissionAspect.CREATE_BY_));
		}
		if (mapThreadLocal.containsKey(DataPermissionAspect.CREATE_ORG_ID_)) {
			Set<String> orgIds = (Set<String>) mapThreadLocal.get(DataPermissionAspect.CREATE_ORG_ID_);
			conditions.add(getInSql(DataPermissionAspect.CREATE_ORG_ID_, orgIds));
		}
		return String.join(" or ", conditions);
	}

	/**
	 * Appends {@code condition} to an UPDATE/DELETE statement. It uses {@code and} when the statement
	 * already contains a WHERE keyword. Otherwise it uses {@code where}, so that an unfiltered
	 * {@code UPDATE t SET a = 1} does not become {@code SET a = 1 and (...)}.
	 * <p>
	 * NOTE(review): this is plain text concatenation. An existing top-level {@code OR} binds looser
	 * than the appended filter (the WHERE clause is not parenthesised). A WHERE that appears only inside
	 * a subquery is still treated as present. A trailing {@code ORDER BY}/{@code LIMIT} breaks the result.
	 * Fixing these needs a SQL parser.
	 */
	private static String appendCondition(String sql, String condition) {
		String keyword = WHERE_KEYWORD.matcher(sql).find() ? " and " : " where ";
		return sql + keyword + "( " + condition + ") ";
	}

	@Override
	public Object plugin(Object target) {
		return Plugin.wrap(target, this);
	}

	@Override
	public void setProperties(Properties properties) {
		logger.debug(" properties {}", properties);
	}

	/**
	 * Builds OR-ed IN-lists for {@code field}. {@code orgIds} is split into chunks of at most
	 * {@value #IN_CLAUSE_CHUNK_SIZE} values, presumably to stay under database limits on IN-list length.
	 *
	 * @param field  column name to test
	 * @param orgIds values to match; inlined verbatim (unquoted), may be {@code null}
	 * @return e.g. {@code " (field in (1,2,3) or field in (4,5,6))"}. Returns {@code " (1 = 0)"} when
	 *         {@code orgIds} is null or empty, so an empty permission set matches no rows instead of
	 *         producing invalid SQL.
	 */
	private String getInSql(String field, Set<String> orgIds) {
		if (orgIds == null || orgIds.isEmpty()) {
			return " (1 = 0)";
		}
		List<String> ids = new ArrayList<>(orgIds);
		List<String> clauses = new ArrayList<>();
		for (int from = 0; from < ids.size(); from += IN_CLAUSE_CHUNK_SIZE) {
			List<String> chunk = ids.subList(from, Math.min(from + IN_CLAUSE_CHUNK_SIZE, ids.size()));
			clauses.add(field + " in (" + String.join(",", chunk) + ")");
		}
		return " (" + String.join(" or ", clauses) + ")";
	}

	public ResultSetExtractor<?> getResultSetExtractor() {
		return new SQLRowSetOracleResultSetExtractor();
	}

	/**
	 * Returns the result columns of a SELECT statement without bind parameters.
	 *
	 * @see #getSqlFields(JdbcTemplate, String, Object...)
	 */
	public List<TableField> getSqlFields(JdbcTemplate jdbcTemplate, String sql) {
		return getSqlFields(jdbcTemplate, sql, new Object[0]);
	}

	/**
	 * Returns the result columns (name, table, type, MySQL column comment) of a SELECT statement.
	 * <p>
	 * The query is run wrapped in {@code select * from (...) t limit 1} (MySQL/PostgreSQL syntax).
	 *
	 * @param jdbcTemplate template used both for the probe query and for the comment lookups
	 * @param sql          the SELECT statement; may contain {@code ?} placeholders
	 * @param args         values for the placeholders, in order; empty for a fully literal statement
	 * @return one {@link TableField} per result column
	 * @throws ApplicationException if {@code sql} is not a plain SELECT or its metadata cannot be read
	 */
	public List<TableField> getSqlFields(JdbcTemplate jdbcTemplate, String sql, Object... args) {
		String statement = sql.trim();
		String lower = statement.toLowerCase(Locale.ROOT);
		// Reject anything that is not a plain query (data-modifying keywords, or not starting with select).
		if (lower.contains("update ") || lower.contains("delete ") || lower.contains("truncate ")
				|| lower.contains("create ") || !lower.startsWith("select")) {
			throw new ApplicationException("SQL语句只能是查询语句。");
		}
		// Works on MySQL and PostgreSQL.
		String probeSql = "select * from (" + statement + ") t limit 1";
		SqlRowSet srs;
		if (args == null || args.length == 0) {
			srs = (SqlRowSet) jdbcTemplate.query(probeSql, getResultSetExtractor());
		} else {
			srs = (SqlRowSet) jdbcTemplate.query(probeSql, getResultSetExtractor(), args);
		}
		List<TableField> columnList = new ArrayList<>();
		// MySQL: look the comment up in the schema of the current connection.
		String queryColumn = "select column_comment from information_schema.columns where table_schema = database() and table_name=? and column_name=?";
		// PostgreSQL equivalent:
		// SELECT col_description(A.attrelid, A.attnum) AS COMMENT FROM pg_class AS C, pg_attribute AS A
		//   WHERE C.relname = ? AND A.attname = ? AND A.attrelid = C.oid AND A.attnum > 0
		try {
			SqlRowSetMetaData rsmd = srs.getMetaData();
			for (int i = 1; i <= rsmd.getColumnCount(); i++) {
				String fieldName = rsmd.getColumnName(i).toLowerCase(Locale.ROOT);
				String tableName = rsmd.getTableName(i).toLowerCase(Locale.ROOT);
				String dataType = rsmd.getColumnTypeName(i).toLowerCase(Locale.ROOT);
				String fieldDesc = "";
				try {
					fieldDesc = jdbcTemplate.queryForObject(queryColumn, String.class, tableName, fieldName);
				} catch (EmptyResultDataAccessException e) {
					// Computed/aliased columns have no backing table column, hence no comment.
					logger.debug("no column comment found for field {}", fieldName);
				}
				TableField field = new TableField();
				field.setFieldDesc(fieldDesc);
				field.setTableName(tableName);
				field.setDataType(dataType);
				field.setFieldName(fieldName);
				columnList.add(field);
			}
		} catch (Exception e) {
			logger.error("解析SQL出错", e);
			throw new ApplicationException("解析SQL出错" + e.getMessage());
		}
		return columnList;
	}
}
