package com.tbyd.data.datasync.core;

import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.druid.pool.DruidPooledConnection;
import com.tbyd.data.datasync.config.DataSyncProperties;
import com.tbyd.data.datasync.config.TableAndKeys;
import io.debezium.data.Envelope;
import io.debezium.engine.ChangeEvent;
import io.debezium.engine.DebeziumEngine;
import lombok.extern.slf4j.Slf4j;
import org.apache.kafka.connect.data.Field;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.connect.source.SourceRecord;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;

/**
 * {@link RecordHandler} that replays Debezium change events against a destination database.
 *
 * <p>Consecutive events that target the same table with the same operation share one
 * {@link PreparedStatement} and are sent as a JDBC batch. A batch is executed and committed when
 * the table or operation changes, when the batch size limit is reached, or when the incoming
 * list of records has been fully consumed.
 *
 * <p>Only CREATE, UPDATE and DELETE operations are written; every other operation is skipped.
 */
@Slf4j
public class DataSyncWriter implements RecordHandler {

    private static final String JDBC_PATTERN = "jdbc:sqlserver://<server_name>:<port>;databaseName=<database_name>";
    /** Value-schema name of the schema change events, which are never replayed. */
    private static final String SCHEMA_CHANGE_SCHEMA_NAME = "io.debezium.connector.sqlserver.SchemaChangeValue";
    /** Number of written events after which the pending batch is flushed regardless of table/operation. */
    private static final int MAX_BATCH_SIZE = 1500;

    private final DruidDataSource dataSource;
    private final Properties props;
    /** Destination table metadata keyed by the table name taken from the sync configuration. */
    private final Map<String, TableMetadata> metadataCache = new HashMap<>();

    /**
     * Creates the destination connection pool from {@code props} and loads the metadata of every
     * configured table.
     *
     * @param props data-sync configuration
     * @throws IllegalStateException if the metadata of a configured table cannot be read
     */
    public DataSyncWriter(Properties props) {
        String host = DataSyncProperties.getDestinationHost(props);
        String port = DataSyncProperties.getDestinationPort(props);
        String user = DataSyncProperties.getDestinationUser(props);
        String password = DataSyncProperties.getDestinationPassword(props);
        String dbname = DataSyncProperties.getDestinationDbname(props);
        String url = buildUrl(host, port, dbname);
        dataSource = new DruidDataSource();
        dataSource.setUrl(url);
        dataSource.setUsername(user);
        dataSource.setPassword(password);
        this.props = props;
        init();
    }

    /** Loads the metadata of every configured table into {@link #metadataCache}. */
    private void init() {
        for (TableAndKeys tableAndKeys : DataSyncProperties.getSyncTables(props)) {
            try (DruidPooledConnection conn = dataSource.getConnection()) {
                String tableName = tableAndKeys.getTable();
                TableMetadata tableMetadata = DBUtils.getTableMetadata(conn, tableName);
                tableMetadata.setKeys(tableAndKeys.getKeys());
                metadataCache.put(tableName, tableMetadata);
            } catch (SQLException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    /** Fills the placeholders of {@link #JDBC_PATTERN} with the given connection coordinates. */
    private static String buildUrl(String host, String port, String dbName) {
        return JDBC_PATTERN.replace("<server_name>", host)
                .replace("<port>", port)
                .replace("<database_name>", dbName);
    }

    /**
     * Tells whether the event carries a row change this writer can process.
     *
     * @param record the change event
     * @return {@code false} for events without a value (or without a value schema) and for schema
     *         change events, {@code true} otherwise
     */
    @Override
    public boolean supports(ChangeEvent<SourceRecord, SourceRecord> record) {
        SourceRecord r = record.value();
        if (r == null) {
            return false;
        }
        Struct value = (Struct) r.value();
        if (value == null) {
            return false;
        }
        Schema valueSchema = r.valueSchema();
        if (valueSchema == null) {
            return false;
        }
        // Constant-first comparison: a value schema is allowed to have no name.
        return !SCHEMA_CHANGE_SCHEMA_NAME.equals(valueSchema.name());
    }

    /**
     * Single-record handling is not supported; use {@link #handleBatch}.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void handle(ChangeEvent<SourceRecord, SourceRecord> record) {
        throw new UnsupportedOperationException();
    }

    /**
     * Writes the given change events to the destination and marks each of them as processed.
     *
     * <p>A failure while writing a single event is logged and the event is still marked as
     * processed. A failure while decoding an event propagates to the caller; the batch that was
     * pending at that point is flushed first so that its statement and connection are released.
     *
     * @param records   the change events, in source order
     * @param committer used to acknowledge each record and the whole batch
     * @throws Exception if an event cannot be decoded or the committer fails
     */
    @Override
    public void handleBatch(List<ChangeEvent<SourceRecord, SourceRecord>> records, DebeziumEngine.RecordCommitter<ChangeEvent<SourceRecord, SourceRecord>> committer) throws Exception {
        RecordChangeEvent prevEvent = null;
        int cnt = 0;
        try {
            for (ChangeEvent<SourceRecord, SourceRecord> record : records) {
                if (supports(record)) {
                    RecordChangeEvent curEvent = toChangeEvent(record);
                    if (isWritable(curEvent.op)) {
                        cnt++;
                        if (cnt >= MAX_BATCH_SIZE) {
                            cnt = 0;
                            if (prevEvent != null) {
                                commitPrepareStatement((PreparedStatement) prevEvent.attachment);
                                prevEvent.attachment = null;
                            }
                            prevEvent = null;
                        }

                        try {
                            doHandle(curEvent, prevEvent);
                        } catch (Exception e) {
                            log.error("处理记录变更时发生错误", e);
                            log.error("table: {}, before: {}, after: {}", curEvent.tableName, curEvent.before, curEvent.after);
                        }
                        // The statement, if still open, is now owned by curEvent.
                        if (prevEvent != null) {
                            prevEvent.attachment = null;
                        }
                        prevEvent = curEvent;
                    } else {
                        // Leave prevEvent untouched so the pending batch is neither lost nor leaked.
                        log.debug("Skipping change event with unsupported operation {} for table {}",
                                curEvent.op, curEvent.tableName);
                    }
                }
                committer.markProcessed(record);
            }
        } finally {
            // Flush whatever is still pending, also when an exception is on its way out.
            if (prevEvent != null) {
                commitPrepareStatement((PreparedStatement) prevEvent.attachment);
                prevEvent.attachment = null;
            }
        }
        committer.markBatchFinished();
    }

    /** Decodes the Debezium envelope of {@code record} into a {@link RecordChangeEvent}. */
    private static RecordChangeEvent toChangeEvent(ChangeEvent<SourceRecord, SourceRecord> record) {
        // The table name is everything after the first '.' of the destination.
        String tableName = record.destination().substring(
                record.destination().indexOf(".") + 1);
        Struct value = (Struct) record.value().value();
        Schema valueSchema = record.value().valueSchema();
        Struct before = (Struct) value.get(valueSchema.field(Envelope.FieldName.BEFORE));
        Struct after = (Struct) value.get(valueSchema.field(Envelope.FieldName.AFTER));
        String opCode = (String) value.get(valueSchema.field(Envelope.FieldName.OPERATION));
        Envelope.Operation op = Envelope.Operation.forCode(opCode);
        return new RecordChangeEvent(tableName, op,
                resolveColumns(before),
                resolveColumns(after));
    }

    /** Tells whether {@code op} is one of the operations this writer replays. */
    private static boolean isWritable(Envelope.Operation op) {
        return op == Envelope.Operation.CREATE
                || op == Envelope.Operation.UPDATE
                || op == Envelope.Operation.DELETE;
    }


    /** A decoded row change plus the statement (if any) that currently batches it. */
    static class RecordChangeEvent {
        String tableName;
        Envelope.Operation op;
        /** Column values before the change, or {@code null} when the envelope has none. */
        Object[] before;
        /** Column values after the change, or {@code null} when the envelope has none. */
        Object[] after;
        /** The open {@link PreparedStatement} this event was added to, or {@code null}. */
        Object attachment;

        public RecordChangeEvent(String tableName, Envelope.Operation op, Object[] before, Object[] after) {
            this.tableName = tableName;
            this.op = op;
            this.before = before;
            this.after = after;
        }

    }

    /** Dispatches {@code curEvent} to the handler of its operation; other operations are ignored. */
    private void doHandle(RecordChangeEvent curEvent, RecordChangeEvent prevEvent) throws SQLException {
        switch (curEvent.op) {
            case CREATE:
                doHandleCreate(curEvent, prevEvent);
                return;
            case UPDATE:
                doHandleUpdate(curEvent, prevEvent);
                return;
            case DELETE:
                doHandleDelete(curEvent, prevEvent);
                return;
            default:
                // Nothing to write for any other operation.
        }
    }

    /** Adds a DELETE for {@code curEvent} to the batch, binding the key values of the old row. */
    private void doHandleDelete(RecordChangeEvent curEvent, RecordChangeEvent prevEvent) throws SQLException {
        PreparedStatement stmt = obtainStatement(curEvent, prevEvent);
        TableMetadata tableMetadata = metadataCache.get(curEvent.tableName);
        String[] keys = tableMetadata.getKeysArray();
        for (int i = 0; i < keys.length; i++) {
            stmt.setObject(i + 1, curEvent.before[tableMetadata.getIndex(keys[i])]);
        }
        stmt.addBatch();
    }

    /**
     * Adds an UPDATE for {@code curEvent} to the batch: the new column values come first, followed
     * by the key values of the old row for the WHERE clause.
     */
    private void doHandleUpdate(RecordChangeEvent curEvent, RecordChangeEvent prevEvent) throws SQLException {
        PreparedStatement stmt = obtainStatement(curEvent, prevEvent);
        for (int i = 0; i < curEvent.after.length; i++) {
            stmt.setObject(i + 1, curEvent.after[i]);
        }
        TableMetadata tableMetadata = metadataCache.get(curEvent.tableName);
        String[] keys = tableMetadata.getKeysArray();
        for (int i = 0; i < keys.length; i++) {
            stmt.setObject(curEvent.after.length + i + 1, curEvent.before[tableMetadata.getIndex(keys[i])]);
        }
        stmt.addBatch();
    }

    /** Adds an INSERT for {@code curEvent} to the batch, binding the new column values. */
    private void doHandleCreate(RecordChangeEvent curEvent, RecordChangeEvent prevEvent) throws SQLException {
        PreparedStatement stmt = obtainStatement(curEvent, prevEvent);
        for (int i = 0; i < curEvent.after.length; i++) {
            stmt.setObject(i + 1, curEvent.after[i]);
        }
        stmt.addBatch();
    }

    /**
     * Returns the statement {@code curEvent} must be added to and records it as the event's
     * attachment.
     *
     * <p>The statement of {@code prevEvent} is reused when both events can share it. Otherwise the
     * previous batch is committed and a new statement is opened on a fresh connection.
     *
     * @throws IllegalStateException if no metadata is cached for the event's table
     * @throws SQLException          if the new statement cannot be created
     */
    private PreparedStatement obtainStatement(RecordChangeEvent curEvent, RecordChangeEvent prevEvent) throws SQLException {
        PreparedStatement stmt;
        if (determineUseSamePreparedStatement(curEvent, prevEvent)) {
            stmt = (PreparedStatement) prevEvent.attachment;
        } else {
            if (prevEvent != null) {
                commitPrepareStatement((PreparedStatement) prevEvent.attachment);
                // The statement is closed now; make sure nobody touches it again.
                prevEvent.attachment = null;
            }
            // Resolve the metadata before borrowing a connection so a failure cannot leak one.
            TableMetadata tableMetadata = metadataCache.get(curEvent.tableName);
            if (tableMetadata == null) {
                throw new IllegalStateException("No metadata cached for table: " + curEvent.tableName);
            }
            stmt = openStatement(buildSql(curEvent.op, tableMetadata));
        }
        curEvent.attachment = stmt;
        return stmt;
    }

    /** Builds the SQL text matching {@code op} for the given table. */
    private static String buildSql(Envelope.Operation op, TableMetadata tableMetadata) {
        switch (op) {
            case CREATE:
                return buildInsertSql(tableMetadata);
            case UPDATE:
                return buildUpdateSql(tableMetadata);
            case DELETE:
                return buildDeleteSql(tableMetadata);
            default:
                throw new IllegalStateException("Unsupported operation: " + op);
        }
    }

    /** Prepares {@code sql} on a fresh connection, releasing the connection if that fails. */
    private PreparedStatement openStatement(String sql) throws SQLException {
        Connection conn = getConn();
        try {
            return conn.prepareStatement(sql);
        } catch (SQLException | RuntimeException e) {
            closeQuietly(conn);
            throw e;
        }
    }

    /** Borrows a connection from the pool and switches it to manual commit. */
    private Connection getConn() throws SQLException {
        Connection conn = dataSource.getConnection();
        try {
            conn.setAutoCommit(false);
        } catch (SQLException | RuntimeException e) {
            closeQuietly(conn);
            throw e;
        }
        return conn;
    }

    /** Closes {@code conn}, logging instead of propagating a failure. */
    private static void closeQuietly(Connection conn) {
        try {
            conn.close();
        } catch (SQLException e) {
            log.warn("Failed to close connection", e);
        }
    }

    /**
     * Executes the batch of {@code stmt}, commits it and releases the statement and its connection.
     * A failure is logged and the transaction is rolled back; nothing is thrown to the caller.
     *
     * @param stmt the statement to flush; {@code null} is a no-op
     */
    private static void commitPrepareStatement(PreparedStatement stmt) {
        if (stmt == null) {
            return;
        }
        Connection conn = null;
        try {
            conn = stmt.getConnection();
            stmt.executeBatch();
            conn.commit();
        } catch (SQLException e) {
            log.error(e.getMessage(), e);
            // conn is still null when getConnection() itself failed.
            if (conn != null) {
                try {
                    conn.rollback();
                } catch (SQLException ex) {
                    log.error("Failed to roll back after batch failure", ex);
                }
            }
        } finally {
            DBUtils.closeResource(null, stmt, conn);
        }
    }

    /**
     * Tells whether {@code curEvent} can be added to the statement of {@code prevEvent}: the
     * previous event must still hold an open statement and both events must target the same table
     * with the same operation.
     */
    private boolean determineUseSamePreparedStatement(RecordChangeEvent curEvent, RecordChangeEvent prevEvent) {
        if (prevEvent == null || prevEvent.attachment == null) {
            return false;
        }
        return curEvent.tableName.equals(prevEvent.tableName) && curEvent.op == prevEvent.op;
    }

    /**
     * Extracts the column values of a row struct in schema field order.
     *
     * @param colsStruct the row struct, may be {@code null}
     * @return the converted values, or {@code null} when {@code colsStruct} is {@code null}
     */
    private static Object[] resolveColumns(Struct colsStruct) {
        if (colsStruct == null) {
            return null;
        }
        List<Field> fields = colsStruct.schema().fields();
        Object[] cols = new Object[fields.size()];
        for (int i = 0; i < cols.length; i++) {
            Field field = fields.get(i);
            cols[i] = convert(colsStruct.get(field), field.schema());
        }
        return cols;
    }

    /**
     * Converts a field value into the object handed to JDBC.
     *
     * @param value       the raw value, may be {@code null}
     * @param fieldSchema the schema of the field
     * @return the value to bind, {@code null} for a {@code null} input
     * @throws IllegalStateException if the field's type is not supported
     */
    private static Object convert(Object value, Schema fieldSchema) {
        if (value == null) {
            return null;
        }
        String logicalName = fieldSchema.name();
        if (logicalName == null) {
            // Plain (unnamed) types: only these three are passed through as-is.
            Schema.Type type = fieldSchema.type();
            if (type == Schema.Type.STRING || type == Schema.Type.INT32 || type == Schema.Type.BYTES) {
                return value;
            }
            // Switching on a null name would raise a NullPointerException; report the type instead.
            throw new IllegalStateException("暂不支持这种数据类型的字段：" + type);
        }
        switch (logicalName) {
            case "org.apache.kafka.connect.data.Decimal":
                return value;
            case "io.debezium.time.Timestamp":
                return new Timestamp((Long) value);
            default:
                throw new IllegalStateException("暂不支持这种数据类型的字段：" + logicalName);
        }
    }

    /** Builds {@code INSERT INTO t(c1,c2,...) VALUES (?,?,...)} covering every column of the table. */
    private static String buildInsertSql(TableMetadata tableMetadata) {
        StringBuilder sql = new StringBuilder();
        sql.append("INSERT INTO ");
        sql.append(tableMetadata.getTableName());
        sql.append("(");
        sql.append(String.join(",", tableMetadata.getColumnNames()));
        sql.append(") VALUES (");
        for (int i = 0; i < tableMetadata.getColumnCount(); i++) {
            sql.append("?");
            if (i < tableMetadata.getColumnCount() - 1) {
                sql.append(",");
            }
        }
        sql.append(")");
        return sql.toString();
    }


    /** Builds {@code UPDATE t SET c1 = ?, ... WHERE k1 = ? AND ...} setting every column of the table. */
    private static String buildUpdateSql(TableMetadata tableMetadata) {
        StringBuilder sql = new StringBuilder();
        sql.append("UPDATE ");
        sql.append(tableMetadata.getTableName());
        sql.append(" SET ");
        for (int i = 0; i < tableMetadata.getColumnCount(); i++) {
            sql.append(tableMetadata.getColumnNames()[i]);
            sql.append(" = ");
            sql.append("?");
            if (i < tableMetadata.getColumnCount() - 1) {
                sql.append(", ");
            }
        }
        sql.append(" WHERE ");
        appendKeyPredicate(sql, tableMetadata.getKeysArray());
        return sql.toString();
    }

    /** Builds {@code DELETE FROM t WHERE k1 = ? AND ...} using the configured key columns. */
    private static String buildDeleteSql(TableMetadata tableMetadata) {
        StringBuilder sql = new StringBuilder();
        sql.append("DELETE FROM ");
        sql.append(tableMetadata.getTableName());
        sql.append(" WHERE ");
        appendKeyPredicate(sql, tableMetadata.getKeysArray());
        return sql.toString();
    }

    /** Appends {@code k1 = ? AND k2 = ? ...} for the given key columns. */
    private static void appendKeyPredicate(StringBuilder sql, String[] keys) {
        for (int i = 0; i < keys.length; i++) {
            sql.append(keys[i]);
            sql.append(" = ");
            sql.append("?");
            if (i < keys.length - 1) {
                sql.append(" AND ");
            }
        }
    }
}
