import sys
sys.path.append(__file__.split('blue_script')[0] + 'blue_script')

import pymysql
from sshtunnel import SSHTunnelForwarder

from config import PROJECT
from lib.common import *


# Connect to MySQL, tunnelling through SSH when the project config has an `ssh` section.
def connect(host, database, user, password, autocommit=True, port=3306):
    """Open a MySQL connection, optionally through an SSH tunnel.

    host/database/user/password are each passed through
    ``PROJECT.getConfig('db.<field>', <argument>)`` with the argument as the
    fallback. NOTE(review): whether a configured ``db.*`` value wins over an
    explicit argument depends on getConfig's semantics -- confirm.

    When ``PROJECT.has('ssh')`` is true, an ``SSHTunnelForwarder`` is started
    to ``ssh.ip`` on port 22 (user ``ssh.user``, key ``ssh.key``), forwarding
    to ``(host, port)``. MySQL is then reached on 127.0.0.1 at the tunnel's
    local bind port.

    Args:
        host: database host (the tunnel's remote end when SSH is used).
        database: database (schema) name.
        user: database user name.
        password: database password.
        autocommit: forwarded to ``pymysql.connect``.
        port: MySQL port on ``host``; used for both direct and tunnelled
            connections.

    Returns:
        ``(connection, tunnel)``. ``tunnel`` is the started
        ``SSHTunnelForwarder``, or ``None`` for a direct connection. The
        caller must close the connection and stop the tunnel.

    Raises:
        Any exception from sshtunnel or pymysql. If the MySQL connection fails
        after the tunnel has started, the tunnel is stopped before re-raising.
    """
    host = PROJECT.getConfig('db.host', host)
    database = PROJECT.getConfig('db.database', database)
    user = PROJECT.getConfig('db.user', user)
    password = PROJECT.getConfig('db.password', password)

    if PROJECT.has('ssh'):
        sshHost = PROJECT.getConfig('ssh.ip')
        sshUser = PROJECT.getConfig('ssh.user')
        sshKey = PROJECT.getConfig('ssh.key')

        tunnel = SSHTunnelForwarder((sshHost, 22), ssh_username=sshUser, ssh_pkey=sshKey, remote_bind_address=(host, port))
        tunnel.daemon_forward_servers = True
        tunnel.start()
        localPort = tunnel.local_bind_port
        log.debug(f'SSH connect DB tunnel started: {sshHost}, {sshUser}, {sshKey}, ({host}:{port}) -> (localhost:{localPort})')

        try:
            ret = pymysql.connect(
                host='127.0.0.1',
                port=localPort,
                user=user,
                password=password,
                database=database,
                autocommit=autocommit,
                connect_timeout=10,  # connection timeout (seconds)
                read_timeout=300,    # timeout for reading results (seconds)
                write_timeout=30,    # timeout for sending requests (seconds)
            )
        except BaseException:
            # Callers only stop tunnels that connect() returns, so a failed
            # MySQL login here would otherwise leak the running SSH tunnel.
            tunnel.stop()
            raise
        log.debug(f'DB connected: {host}({database}), user: {user}, via SSH (localhost:{localPort}) -> {sshHost} -> ({host}:{port})')
    else:
        tunnel = None
        # Forward `port` so non-default ports also work without SSH.
        ret = pymysql.connect(host=host, port=port, user=user, password=password, database=database, autocommit=autocommit)
        log.debug(f'DB connected: {host}({database}), user: {user}')
    return ret, tunnel


# Run a query on a fresh connection and return all rows.
def query(sql, host=None, database=None, user=None, password=None, isDict=False, withName=False, raiseEx=False, port=3306):
    """Execute ``sql`` and return every fetched row.

    Each call opens its own connection (and SSH tunnel, if configured) via
    ``connect`` and releases it before returning.

    Args:
        sql: statement text, executed verbatim (no parameter binding).
        host, database, user, password, port: forwarded to ``connect``.
        isDict: fetch rows as dicts (``pymysql.cursors.DictCursor``) instead
            of tuples.
        withName: for tuple rows, prepend the list of column names taken from
            ``cursor.description``; ignored when ``isDict`` is true.
        raiseEx: re-raise errors instead of logging and swallowing them.

    Returns:
        The result of ``cursor.fetchall()``, or a list whose first element is
        the column-name list when ``withName`` applies. When an error is
        swallowed, whatever was fetched before the failure (``[]`` if nothing).
    """
    rows = []
    conn = tunnel = cursor = None

    try:
        log.debug(f'Query: {sql}')
        conn, tunnel = connect(host, database, user, password, port=port)

        cursorClass = pymysql.cursors.DictCursor if isDict else pymysql.cursors.Cursor
        cursor = conn.cursor(cursorClass)
        cursor.execute(sql)
        rows = cursor.fetchall()

        if log.isDebug():
            log.debug(f'Query result: {len(rows)} rows')

        if withName and not isDict:
            header = [column[0] for column in cursor.description]
            rows = [header, *rows]

        return rows
    except Exception as e:
        if not raiseEx:
            log.exception(f'Query error: {e}')
            return rows
        raise e
    finally:
        # Release resources in reverse order of acquisition.
        if cursor:
            cursor.close()
        if conn:
            conn.close()
        if tunnel:
            tunnel.stop()


# Execute a data-modifying statement and return the affected row count.
def update(sql, host=None, database=None, user=None, password=None, raiseEx=False, port=3306):
    """Execute ``sql`` on a fresh connection and commit it.

    Args:
        sql: statement text, executed verbatim (no parameter binding).
        host, database, user, password, port: forwarded to ``connect``.
        raiseEx: re-raise errors after logging them instead of swallowing them.

    Returns:
        ``cursor.rowcount`` for the statement on success. Returns ``None``
        when an error was logged and swallowed (``raiseEx`` false).

    Raises:
        Any exception from connecting, executing or committing, if ``raiseEx``
        is true.
    """
    conn = None
    tunnel = None
    cursor = None

    try:
        log.debug(f'Update: {sql}')
        conn, tunnel = connect(host, database, user, password, port=port)

        cursor = conn.cursor()
        cursor.execute(sql)
        ret = cursor.rowcount
        # connect() is called with its default autocommit=True, so this commit
        # is normally redundant; it keeps the write durable if that changes.
        conn.commit()

        return ret
    except Exception as e:
        # Log with traceback (as query() does) so failures remain diagnosable.
        log.exception(f'Update error: {e}')
        if raiseEx:
            raise
        return None
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
        if tunnel:
            tunnel.stop()


# Build a class exposing query/update bound to one configured database.
def createClass(name):
    """Create a class whose static ``query``/``update`` target ``db.<name>_db``.

    Connection settings come from ``db.<name>_db.{host,database,user,password}``.
    ``user``/``password`` are then passed through
    ``PROJECT.getConfig('db.user', ...)`` / ``PROJECT.getConfig('db.password', ...)``
    with the per-database value as the fallback argument.

    For ``name == 'game'`` the host and database values are used as
    ``str.format`` templates filled with a server id. The generated methods
    then take ``sid`` first:
        ``query(sid, sql, isDict=False, withName=False, raiseEx=False)``
        ``update(sid, sql, raiseEx=False)``
    For every other name they are:
        ``query(sql, isDict=False, withName=False, raiseEx=False)``
        ``update(sql, raiseEx=False)``

    Args:
        name: database alias; the config key prefix is ``db.<name>_db``.

    Returns:
        A new class named ``name`` with ``query`` and ``update`` static methods.
    """
    prefix = f'db.{name}_db'
    host = PROJECT.getConfig(f'{prefix}.host')
    database = PROJECT.getConfig(f'{prefix}.database')
    user = PROJECT.getConfig(f'{prefix}.user')
    password = PROJECT.getConfig(f'{prefix}.password')

    # Per-database credentials are the getConfig fallback for the shared
    # db.user / db.password entries (presumably the shared value wins when set).
    user = PROJECT.getConfig('db.user', user)
    password = PROJECT.getConfig('db.password', password)

    if name == 'game':
        # host/database were already read above; here they are str.format
        # templates filled in with the server id on every call.
        hostFormat = host
        databaseFormat = database

        def queryMethod(sid, sql, isDict=False, withName=False, raiseEx=False):
            return query(
                sql,
                host=hostFormat.format(sid),
                database=databaseFormat.format(sid),
                user=user,
                password=password,
                isDict=isDict,
                withName=withName,
                raiseEx=raiseEx,
            )

        def updateMethod(sid, sql, raiseEx=False):
            return update(
                sql,
                host=hostFormat.format(sid),
                database=databaseFormat.format(sid),
                user=user,
                password=password,
                raiseEx=raiseEx,
            )
    else:
        def queryMethod(sql, isDict=False, withName=False, raiseEx=False):
            return query(sql, host=host, database=database, user=user, password=password,
                         isDict=isDict, withName=withName, raiseEx=raiseEx)

        def updateMethod(sql, raiseEx=False):
            return update(sql, host=host, database=database, user=user, password=password,
                          raiseEx=raiseEx)

    methods = {
        'query': staticmethod(queryMethod),
        'update': staticmethod(updateMethod),
    }

    return type(name, (object,), methods)




def initClass():
    """Publish a database class for every ``<name>_db`` entry in the ``db`` config.

    Each class is built by ``createClass(<name>``) and stored as a module-level
    global named ``<name>``. Does nothing when the config has no ``db`` section.
    """
    if 'db' not in PROJECT:
        return

    suffix = '_db'
    for entry in PROJECT['db']:
        if not entry.endswith(suffix):
            continue
        alias = entry[:-len(suffix)]
        globals()[alias] = createClass(alias)


initClass()

# System schema names excluded by getDatabaseList() and getDatabaseSize().
SYSTEM_DB_LIST = ['Database', 'information_schema', 'performance_schema', 'mysql', 'sys', 'test', '__recycle_bin__']

# List non-system database names on a server.
def getDatabaseList(host, user=None, password=None, port=3306):
    """Return the names of all non-system databases on ``host``, sorted by name.

    Schemas listed in ``SYSTEM_DB_LIST`` are excluded. The connection is made
    with an empty database name.

    Args:
        host: database host, forwarded to ``query``.
        user: database user name, or ``None``.
        password: database password, or ``None``.
        port: MySQL port, forwarded to ``query`` (default 3306, as before).

    Returns:
        A list of database names. ``query`` logs and swallows errors, so a
        failed query typically yields an empty list.
    """
    systemDBStr = ', '.join(f"'{db}'" for db in SYSTEM_DB_LIST)
    sql = f"""
        SELECT 
            schema_name AS dbname 
        FROM 
            information_schema.schemata 
        WHERE 
            schema_name NOT IN ({systemDBStr})
        ORDER BY 
            schema_name;
    """
    data = query(sql, host, '', user, password, port=port)
    # query() uses a tuple cursor by default, so each row is (schema_name,).
    return [row[0] for row in data]


# Get the total data + index size (in bytes) of each non-system database.
def getDatabaseSize(host, user=None, password=None, port=3306):
    """Return per-database storage size for every non-system database on ``host``.

    Args:
        host: database host, forwarded to ``query``.
        user: database user name, or ``None``.
        password: database password, or ``None``.
        port: MySQL port, forwarded to ``query`` (default 3306, as before).

    Returns:
        Rows of ``(table_schema, size)`` as returned by ``query``. ``size`` is
        ``SUM(data_length + index_length)`` over ``information_schema.tables``;
        it is a byte count, not MB. Schemas in ``SYSTEM_DB_LIST`` are excluded,
        and schemas with no tables do not appear.
        NOTE(review): MySQL returns SUM() over integers as DECIMAL, so ``size``
        is probably a ``decimal.Decimal`` -- confirm before float arithmetic.
    """
    systemDBStr = ', '.join(f"'{db}'" for db in SYSTEM_DB_LIST)
    sql = f"""
        SELECT 
            table_schema,
            SUM(data_length + index_length) AS size 
        FROM 
            information_schema.tables 
        WHERE 
            table_schema NOT IN ({systemDBStr})
        GROUP BY 
            table_schema;
    """
    return query(sql, host, '', user, password, port=port)


# Summarise current connections grouped by client host and database.
def getConnectStatus(host, user=None, password=None, port=3306):
    """Count open server connections per client host and default database.

    ``information_schema.PROCESSLIST.host`` is reported as
    ``host_name:client_port`` for TCP/IP clients. Grouping on it directly would
    give one group per connection, so only the part before the first ``':'``
    is used. NOTE(review): this truncates IPv6 client addresses -- confirm
    that only IPv4/hostnames are expected.

    Args:
        host: database host, forwarded to ``query``.
        user: database user name, or ``None``.
        password: database password, or ``None``.
        port: MySQL port, forwarded to ``query``.

    Returns:
        Rows of ``(ip, dbname, connection_count)``, most connections first.
        ``dbname`` is the string ``'NULL'`` for connections without a default
        database. ``query`` logs and swallows errors, so a failure typically
        yields ``[]``.
    """
    sql = """
                SELECT 
                    SUBSTRING_INDEX(host, ':', 1) AS ip,
                    IFNULL(db, 'NULL') AS dbname,
                    COUNT(*) AS connection_count
                FROM 
                    information_schema.PROCESSLIST
                GROUP BY 
                    ip, db
                ORDER BY 
                    connection_count DESC;
            """
    # `port` was previously accepted but never forwarded, so it had no effect.
    return query(sql, host, '', user, password, port=port)

# Mapping from AWS RDS instance class to total memory in GiB.
# NOTE: not exhaustive -- add instance classes as needed.
# Source: AWS RDS DB instance class documentation.
RDS_INSTANCE_MEMORY_GIB = {
    # T4g (xlarge/2xlarge added so every class in RDS_INSTANCE_BANDWIDTH_GBPS is covered)
    "db.t4g.micro": 1, "db.t4g.small": 2, "db.t4g.medium": 4, "db.t4g.large": 8,
    "db.t4g.xlarge": 16, "db.t4g.2xlarge": 32,
    # T3
    "db.t3.micro": 1, "db.t3.small": 2, "db.t3.medium": 4, "db.t3.large": 8,
    "db.t3.xlarge": 16, "db.t3.2xlarge": 32,
    # R8g
    "db.r8g.large": 16, "db.r8g.xlarge": 32, "db.r8g.2xlarge": 64, "db.r8g.4xlarge": 128,
    "db.r8g.8xlarge": 256, "db.r8g.12xlarge": 384, "db.r8g.16xlarge": 512, "db.r8g.24xlarge": 768,
    # R7g
    "db.r7g.large": 16, "db.r7g.xlarge": 32, "db.r7g.2xlarge": 64, "db.r7g.4xlarge": 128,
    "db.r7g.8xlarge": 256, "db.r7g.12xlarge": 384, "db.r7g.16xlarge": 512,
    # R6g
    "db.r6g.large": 16, "db.r6g.xlarge": 32, "db.r6g.2xlarge": 64, "db.r6g.4xlarge": 128,
    "db.r6g.8xlarge": 256, "db.r6g.12xlarge": 384, "db.r6g.16xlarge": 512,
    # R5
    "db.r5.large": 16, "db.r5.xlarge": 32, "db.r5.2xlarge": 64, "db.r5.4xlarge": 128,
    "db.r5.8xlarge": 256, "db.r5.12xlarge": 384, "db.r5.16xlarge": 512, "db.r5.24xlarge": 768,
}


def get_instance_memory_bytes(instance_type):
    """Return the total memory of an RDS instance class, in bytes.

    Args:
        instance_type: RDS instance class name, e.g. ``"db.r6g.large"``.

    Returns:
        The GiB value from ``RDS_INSTANCE_MEMORY_GIB`` multiplied by 1024**3.

    Raises:
        ValueError: if ``instance_type`` is not in ``RDS_INSTANCE_MEMORY_GIB``.
    """
    memory_gib = RDS_INSTANCE_MEMORY_GIB.get(instance_type)
    if memory_gib is None:
        raise ValueError(f"未找到实例类型 '{instance_type}' 的内存信息。请在 RDS_INSTANCE_MEMORY_GIB 字典中添加它。")
    # 1 GiB = 1024 * 1024 * 1024 bytes
    return memory_gib * (1024**3)


# Mapping from AWS RDS instance class to network bandwidth in Gbps.
# NOTE: not exhaustive; values are based on AWS documentation. For the T
# families the baseline bandwidth is used, not the burstable peak.
# NOTE(review): several R-family sizes share one figure (large through
# 4xlarge), which looks like an "up to" peak rather than a sustained baseline --
# confirm before treating these values as guaranteed capacity.
RDS_INSTANCE_BANDWIDTH_GBPS = {
    # T4g (baseline bandwidth)
    "db.t4g.micro": 0.046,
    "db.t4g.small": 0.092,
    "db.t4g.medium": 0.184,
    "db.t4g.large": 0.276,
    "db.t4g.xlarge": 0.368,
    "db.t4g.2xlarge": 0.736,
    # T3 (baseline bandwidth)
    "db.t3.micro": 0.046,
    "db.t3.small": 0.092,
    "db.t3.medium": 0.184,
    "db.t3.large": 0.276,
    "db.t3.xlarge": 0.368,
    "db.t3.2xlarge": 0.736,
    # R8g
    "db.r8g.large": 12.5,
    "db.r8g.xlarge": 12.5,
    "db.r8g.2xlarge": 12.5,
    "db.r8g.4xlarge": 12.5,
    "db.r8g.8xlarge": 15,
    "db.r8g.12xlarge": 22.5,
    "db.r8g.16xlarge": 30,
    "db.r8g.24xlarge": 50,
    # R7g
    "db.r7g.large": 12.5,
    "db.r7g.xlarge": 12.5,
    "db.r7g.2xlarge": 12.5,
    "db.r7g.4xlarge": 12.5,
    "db.r7g.8xlarge": 15,
    "db.r7g.12xlarge": 22.5,
    "db.r7g.16xlarge": 30,
    # R6g
    "db.r6g.large": 10,
    "db.r6g.xlarge": 10,
    "db.r6g.2xlarge": 10,
    "db.r6g.4xlarge": 10,
    "db.r6g.8xlarge": 12.5,
    "db.r6g.12xlarge": 20,
    "db.r6g.16xlarge": 25,
    # R5
    "db.r5.large": 10,
    "db.r5.xlarge": 10,
    "db.r5.2xlarge": 10,
    "db.r5.4xlarge": 10,
    "db.r5.8xlarge": 10,
    "db.r5.12xlarge": 12.5,
    "db.r5.16xlarge": 20,
    "db.r5.24xlarge": 25,
}


if __name__ == '__main__':
    # Manual smoke test: load a project config and fetch one changelog row.
    PROJECT.readConfig('lastwar/aws/us')
    rows = query('SELECT * FROM 000_changelog limit 1')
    print(rows)