import boto3
import re
import logging
from datetime import datetime
from lib.common.db import get_instance_memory_bytes
# Configure logging
logger = logging.getLogger(__name__)

# Maximum connection count, defined as a library-level constant. It is the
# 100% reference when a connection alarm threshold is derived from a percentage.
# NOTE(review): presumably the engine-side max_connections ceiling — confirm.
MAX_CONNECTIONS = 16000

# Mapping from AWS RDS instance class to baseline network bandwidth (Gbps).
# Note: this list may be incomplete; the figures are based on the official AWS
# documentation. For T-series instances the baseline performance is used, not
# the burstable bandwidth.
RDS_INSTANCE_BANDWIDTH_GBPS = {
    # T4g (Baseline bandwidth)
    "db.t4g.micro": 0.064, "db.t4g.small": 0.128, "db.t4g.medium": 0.256, "db.t4g.large": 0.512,
    "db.t4g.xlarge": 1.024, "db.t4g.2xlarge": 2.048,
    # T3 (Baseline bandwidth)
    "db.t3.micro": 0.064, "db.t3.small": 0.128, "db.t3.medium": 0.256, "db.t3.large": 0.512,
    "db.t3.xlarge": 1.024, "db.t3.2xlarge": 2.048,
    # R8g / R7g
    "db.r8g.large": 0.937, "db.r8g.xlarge": 1.875, "db.r8g.2xlarge": 3.75, "db.r8g.4xlarge": 7.5, "db.r8g.8xlarge": 15, "db.r8g.12xlarge": 22.5, "db.r8g.16xlarge": 30, "db.r8g.24xlarge": 50,
    "db.r7g.large": 0.937, "db.r7g.xlarge": 1.875, "db.r7g.2xlarge": 3.75, "db.r7g.4xlarge": 7.5, "db.r7g.8xlarge": 15, "db.r7g.12xlarge": 22.5, "db.r7g.16xlarge": 30
}
class DbAlarmManager:
    METRIC_NAME_SLOW_QUERY = 'AuroraSlowQueryCount'
    METRIC_NAME_NO_INDEX = 'AuroraNoIndexSlowQueryCount'
    METRIC_NAMESPACE = 'RDS/AuroraSlowLog'
    DEFAULT_LAMBDA_FUNCTION_NAME = 'aurora-slowlog-processor'
    DEFAULT_LAMBDA_ROLE_NAME = 'aurora-slowlog-processor-role'
    def __init__(self, credential):
        """
        Build the AWS clients used by the alarm manager.

        Args:
            credential (dict): Keyword arguments forwarded unchanged to
                ``boto3.client`` for both the RDS and the CloudWatch client.
        """
        client_factory = boto3.client
        self.rds = client_factory('rds', **credential)
        self.cloudwatch = client_factory('cloudwatch', **credential)
        logger.info("DbAlarmManager initialized.")

    def _get_clusters_by_pattern(self, cluster_pattern):
        """Return every RDS cluster description whose identifier matches a regex.

        The pattern is applied with ``re.match`` semantics, so it is anchored
        at the start of ``DBClusterIdentifier`` only.

        Args:
            cluster_pattern (str): Regular expression for cluster identifiers.

        Returns:
            list: Matching entries from the paginated ``describe_db_clusters``
            results (possibly empty).
        """
        pages = self.rds.get_paginator('describe_db_clusters')
        is_match = re.compile(cluster_pattern).match

        logger.info(f"正在查找匹配模式 '{cluster_pattern}' 的数据库集群...")
        found = [
            candidate
            for page in pages.paginate()
            for candidate in page['DBClusters']
            if is_match(candidate['DBClusterIdentifier'])
        ]

        if found:
            logger.info(f"找到 {len(found)} 个匹配的集群。")
        else:
            logger.warning("未找到任何匹配的数据库集群。")

        return found

    def _create_cluster_cpu_alarm(self, cluster_info, cpu_utilization, alarm_topic):
        """Create or overwrite the cluster-level CPU utilization alarm.

        Every cluster member contributes one hidden ``CPUUtilization`` series
        (60-second period, ``Maximum`` statistic). The alarm watches
        ``MAX(METRICS())`` across those series and requires 3 out of 3
        evaluation periods above the threshold.

        Args:
            cluster_info (dict): Cluster description; ``DBClusterIdentifier``
                and ``DBClusterMembers`` are read.
            cpu_utilization: Threshold in percent (converted with ``float``).
            alarm_topic (str): ARN used for both the ALARM and the OK action.

        Returns:
            The ``put_metric_alarm`` response, or ``None`` when the cluster
            lists no members.
        """
        cluster_id = cluster_info['DBClusterIdentifier']
        alarm_name = f"{cluster_id}-MaxCPUUtilization"

        # One hidden per-instance series for each cluster member.
        queries = [
            {
                'Id': f'm{position}',
                'MetricStat': {
                    'Metric': {
                        'Namespace': 'AWS/RDS',
                        'MetricName': 'CPUUtilization',
                        'Dimensions': [
                            {'Name': 'DBInstanceIdentifier', 'Value': node['DBInstanceIdentifier']},
                        ],
                    },
                    'Period': 60,
                    'Stat': 'Maximum',
                },
                'ReturnData': False,
            }
            for position, node in enumerate(cluster_info.get('DBClusterMembers', []))
        ]

        if not queries:
            logger.warning(f"集群 {cluster_id} 没有任何成员实例，无法创建告警。")
            return None

        # The only returned series: the highest CPU value among all members.
        queries.append({
            'Id': 'e1',
            'Expression': 'MAX(METRICS())',
            'Label': 'MaxClusterCPUUtilization',
            'ReturnData': True,
        })

        logger.info(f"为集群 {cluster_id} 创建/更新 CPU 告警...")

        alarm_settings = {
            'AlarmName': alarm_name,
            'AlarmDescription': f'The maximum CPU utilization for any instance in cluster {cluster_id} is over {cpu_utilization}%',
            'ActionsEnabled': True,
            'AlarmActions': [alarm_topic],
            'OKActions': [alarm_topic],
            'EvaluationPeriods': 3,
            'DatapointsToAlarm': 3,
            'Threshold': float(cpu_utilization),
            'ComparisonOperator': 'GreaterThanThreshold',
            'TreatMissingData': 'missing',
            'Metrics': queries,
        }
        return self.cloudwatch.put_metric_alarm(**alarm_settings)

    def _check_cluster_cpu_alarm(self, cluster_id, cpu_utilization, alarm_topic):
        """检查集群 CPU 告警是否存在且配置正确。"""
        alarm_name = f"{cluster_id}-MaxCPUUtilization"
        try:
            alarms = self.cloudwatch.describe_alarms(AlarmNamePrefix=alarm_name, AlarmTypes=['MetricAlarm'])
            for alarm in alarms['MetricAlarms']:
                if alarm.get('Metrics') and any(m.get('Expression') == 'MAX(METRICS())' for m in alarm['Metrics']):
                    needs_update = (
                        abs(float(alarm['Threshold']) - float(cpu_utilization)) > 0.1 or
                        alarm['AlarmActions'] != [alarm_topic]
                    )
                    return True, alarm_name, needs_update
            return False, alarm_name, False
        except Exception as e:
            logger.error(f"检查集群 {cluster_id} 的 CPU 告警时出错: {str(e)}")
            return False, None, False

    def set_cpu_alarms_by_pattern(self, cluster_pattern, cpu_utilization, alarm_topic):
        """为匹配的数据库集群创建或更新 CPU 利用率告警。"""
        clusters = self._get_clusters_by_pattern(cluster_pattern)
        if not alarm_topic or not alarm_topic.startswith('arn:aws:sns:'):
            raise ValueError("无效或未配置 SNS Topic ARN。")

        results = []
        for cluster in clusters:
            cluster_id = cluster['DBClusterIdentifier']
            try:
                exists, alarm_name, needs_update = self._check_cluster_cpu_alarm(cluster_id, cpu_utilization, alarm_topic)
                
                if not exists or needs_update:
                    if not exists:
                        logger.info(f"\n集群 {cluster_id} 未配置 CPU 告警，开始创建...")
                    else:
                        logger.info(f"\n集群 {cluster_id} 的 CPU 告警配置需要更新...")
                    
                    result = self._create_cluster_cpu_alarm(cluster, cpu_utilization, alarm_topic)
                    if result:
                        status = 'created' if not exists else 'updated'
                        results.append({'clusterId': cluster_id, 'status': status})
                    else:
                        results.append({'clusterId': cluster_id, 'status': 'skipped_no_members'})
                else:
                    logger.info(f"\n集群 {cluster_id} 的 CPU 告警已正确配置。")
                    results.append({'clusterId': cluster_id, 'status': 'unchanged'})
                    
            except Exception as e:
                logger.error(f"为集群 {cluster_id} 处理 CPU 告警失败: {str(e)}")
                results.append({'clusterId': cluster_id, 'status': 'failed', 'error': str(e)})
        
        return results

    def _create_cluster_memory_alarm(self, cluster_info, memory_utilization, alarm_topic):
        """Create or overwrite the cluster-level freeable-memory alarm.

        Each member whose details can be fetched contributes one hidden
        ``FreeableMemory`` series (60-second period, ``Minimum`` statistic).
        The alarm watches ``MIN(METRICS())`` and fires when it drops below the
        smallest per-instance threshold, where a threshold is
        ``total_memory * (1 - memory_utilization / 100)``.

        Args:
            cluster_info (dict): Cluster description; ``DBClusterIdentifier``
                and ``DBClusterMembers`` are read.
            memory_utilization: Utilization percentage the threshold is derived from.
            alarm_topic (str): ARN used for both the ALARM and the OK action.

        Returns:
            The ``put_metric_alarm`` response, or ``None`` when no member could
            be described or no threshold could be computed.
        """
        cluster_id = cluster_info['DBClusterIdentifier']
        alarm_name = f"{cluster_id}-MinFreeableMemory"

        queries = []
        class_by_instance = {}

        # Resolve each member's instance class and register its series.
        for position, node in enumerate(cluster_info.get('DBClusterMembers', [])):
            instance_id = node['DBInstanceIdentifier']
            try:
                described = self.rds.describe_db_instances(DBInstanceIdentifier=instance_id)
                class_by_instance[instance_id] = described['DBInstances'][0]['DBInstanceClass']
            except Exception as e:
                logger.warning(f"无法获取实例 {instance_id} 的详细信息: {str(e)}")
                continue

            queries.append({
                'Id': f'm{position}',
                'MetricStat': {
                    'Metric': {
                        'Namespace': 'AWS/RDS',
                        'MetricName': 'FreeableMemory',
                        'Dimensions': [{'Name': 'DBInstanceIdentifier', 'Value': instance_id}],
                    },
                    'Period': 60,
                    'Stat': 'Minimum',
                },
                'ReturnData': False,
            })

        if not queries:
            logger.warning(f"集群 {cluster_id} 没有任何可用的成员实例，无法创建告警。")
            return None

        # Smallest free-memory threshold over all instances of the cluster.
        lowest = float('inf')
        for instance_id, instance_class in class_by_instance.items():
            try:
                capacity = get_instance_memory_bytes(instance_class)
                candidate = capacity * (1 - float(memory_utilization) / 100.0)
                if candidate < lowest:
                    lowest = candidate
            except Exception as e:
                logger.warning(f"无法计算实例 {instance_id} ({instance_class}) 的内存阈值: {str(e)}")

        if lowest == float('inf'):
            logger.error(f"无法计算集群 {cluster_id} 的内存阈值。")
            return None

        # The only returned series: the lowest freeable memory among all members.
        queries.append({
            'Id': 'e1',
            'Expression': 'MIN(METRICS())',
            'Label': 'MinClusterFreeableMemory',
            'ReturnData': True,
        })

        logger.info(f"为集群 {cluster_id} 创建/更新内存告警...")

        alarm_settings = {
            'AlarmName': alarm_name,
            'AlarmDescription': f'The minimum freeable memory for any instance in cluster {cluster_id} is below the {memory_utilization}% utilization threshold',
            'ActionsEnabled': True,
            'AlarmActions': [alarm_topic],
            'OKActions': [alarm_topic],
            'EvaluationPeriods': 3,
            'DatapointsToAlarm': 3,
            'Threshold': lowest,
            'ComparisonOperator': 'LessThanThreshold',
            'TreatMissingData': 'missing',
            'Metrics': queries,
        }
        return self.cloudwatch.put_metric_alarm(**alarm_settings)

    def _check_cluster_memory_alarm(self, cluster_id, alarm_topic, expected_threshold):
        """检查集群内存告警是否存在且配置正确。"""
        alarm_name = f"{cluster_id}-MinFreeableMemory"
        try:
            alarms = self.cloudwatch.describe_alarms(AlarmNamePrefix=alarm_name, AlarmTypes=['MetricAlarm'])
            for alarm in alarms['MetricAlarms']:
                if alarm.get('Metrics') and any(m.get('Expression') == 'MIN(METRICS())' for m in alarm['Metrics']):
                    # 既检查 SNS Topic，也检查阈值
                    needs_update = (
                        alarm['AlarmActions'] != [alarm_topic] or
                        abs(float(alarm['Threshold']) - float(expected_threshold)) > 0.1
                    )
                    return True, alarm_name, needs_update
            return False, alarm_name, False
        except Exception as e:
            logger.error(f"检查集群 {cluster_id} 的内存告警时出错: {str(e)}")
            return False, None, False

    def set_memory_alarms_by_pattern(self, cluster_pattern, memory_utilization, alarm_topic):
        """为匹配的数据库集群创建或更新内存利用率告警。"""
        clusters = self._get_clusters_by_pattern(cluster_pattern)
        if not alarm_topic or not alarm_topic.startswith('arn:aws:sns:'):
            raise ValueError("无效或未配置 SNS Topic ARN。")

        results = []
        for cluster in clusters:
            cluster_id = cluster['DBClusterIdentifier']
            try:
                # 计算 expected_threshold
                instance_types = {}
                for member in cluster.get('DBClusterMembers', []):
                    instance_id = member['DBInstanceIdentifier']
                    try:
                        instance_details = self.rds.describe_db_instances(DBInstanceIdentifier=instance_id)
                        instance_class = instance_details['DBInstances'][0]['DBInstanceClass']
                        instance_types[instance_id] = instance_class
                    except Exception as e:
                        logger.warning(f"无法获取实例 {instance_id} 的详细信息: {str(e)}")
                        continue

                min_threshold = float('inf')
                for instance_id, instance_class in instance_types.items():
                    try:
                        total_memory_bytes = get_instance_memory_bytes(instance_class)
                        threshold = total_memory_bytes * (1 - float(memory_utilization) / 100.0)
                        min_threshold = min(min_threshold, threshold)
                    except Exception as e:
                        logger.warning(f"无法计算实例 {instance_id} ({instance_class}) 的内存阈值: {str(e)}")

                if min_threshold == float('inf'):
                    logger.error(f"无法计算集群 {cluster_id} 的内存阈值。")
                    results.append({'clusterId': cluster_id, 'status': 'skipped_no_members'})
                    continue

                exists, alarm_name, needs_update = self._check_cluster_memory_alarm(cluster_id, alarm_topic, min_threshold)
                
                if not exists or needs_update:
                    if not exists:
                        logger.info(f"\n集群 {cluster_id} 未配置内存告警，开始创建...")
                    else:
                        logger.info(f"\n集群 {cluster_id} 的内存告警配置需要更新...")
                    
                    result = self._create_cluster_memory_alarm(cluster, memory_utilization, alarm_topic)
                    if result:
                        status = 'created' if not exists else 'updated'
                        results.append({'clusterId': cluster_id, 'status': status})
                    else:
                        results.append({'clusterId': cluster_id, 'status': 'skipped_no_members'})
                else:
                    logger.info(f"\n集群 {cluster_id} 的内存告警已正确配置。")
                    results.append({'clusterId': cluster_id, 'status': 'unchanged'})
                    
            except Exception as e:
                logger.error(f"为集群 {cluster_id} 处理内存告警失败: {str(e)}")
                results.append({'clusterId': cluster_id, 'status': 'failed', 'error': str(e)})
        
        return results

    def _create_cluster_connection_alarm(self, cluster_info, connection_percentage, alarm_topic):
        """Create or overwrite the cluster-level database-connections alarm.

        Every cluster member contributes one hidden ``DatabaseConnections``
        series (60-second period, ``Maximum`` statistic). The alarm watches
        ``MAX(METRICS())`` and fires above
        ``connection_percentage`` percent of ``MAX_CONNECTIONS``.

        Args:
            cluster_info (dict): Cluster description; ``DBClusterIdentifier``
                and ``DBClusterMembers`` are read.
            connection_percentage: Percentage of ``MAX_CONNECTIONS`` used as threshold.
            alarm_topic (str): ARN used for both the ALARM and the OK action.

        Returns:
            The ``put_metric_alarm`` response, or ``None`` when the cluster
            lists no members.
        """
        cluster_id = cluster_info['DBClusterIdentifier']
        alarm_name = f"{cluster_id}-MaxDatabaseConnections"

        # One hidden per-instance series for each cluster member.
        queries = [
            {
                'Id': f'm{position}',
                'MetricStat': {
                    'Metric': {
                        'Namespace': 'AWS/RDS',
                        'MetricName': 'DatabaseConnections',
                        'Dimensions': [
                            {'Name': 'DBInstanceIdentifier', 'Value': node['DBInstanceIdentifier']},
                        ],
                    },
                    'Period': 60,
                    'Stat': 'Maximum',
                },
                'ReturnData': False,
            }
            for position, node in enumerate(cluster_info.get('DBClusterMembers', []))
        ]

        if not queries:
            logger.warning(f"集群 {cluster_id} 没有任何成员实例，无法创建告警。")
            return None

        # Convert the percentage into an absolute connection count.
        connection_limit = (float(connection_percentage) / 100.0) * MAX_CONNECTIONS

        # The only returned series: the highest connection count among all members.
        queries.append({
            'Id': 'e1',
            'Expression': 'MAX(METRICS())',
            'Label': 'MaxClusterDatabaseConnections',
            'ReturnData': True,
        })

        logger.info(f"为集群 {cluster_id} 创建/更新连接数告警...")
        logger.info(f"  - 监控配置: {connection_percentage}% of {MAX_CONNECTIONS} -> 触发阈值: {int(connection_limit)} connections")

        alarm_settings = {
            'AlarmName': alarm_name,
            'AlarmDescription': f'The maximum database connections for any instance in cluster {cluster_id} exceeds {connection_percentage}% of the limit ({int(connection_limit)} connections)',
            'ActionsEnabled': True,
            'AlarmActions': [alarm_topic],
            'OKActions': [alarm_topic],
            'EvaluationPeriods': 3,
            'DatapointsToAlarm': 3,
            'Threshold': connection_limit,
            'ComparisonOperator': 'GreaterThanThreshold',
            'TreatMissingData': 'missing',
            'Metrics': queries,
        }
        return self.cloudwatch.put_metric_alarm(**alarm_settings)

    def _check_cluster_connection_alarm(self, cluster_id, connection_percentage, alarm_topic):
        """检查集群连接数告警是否存在且配置正确。"""
        alarm_name = f"{cluster_id}-MaxDatabaseConnections"
        try:
            alarms = self.cloudwatch.describe_alarms(AlarmNamePrefix=alarm_name, AlarmTypes=['MetricAlarm'])
            for alarm in alarms['MetricAlarms']:
                if alarm.get('Metrics') and any(m.get('Expression') == 'MAX(METRICS())' for m in alarm['Metrics']):
                    threshold_absolute = (float(connection_percentage) / 100.0) * MAX_CONNECTIONS
                    needs_update = (
                        abs(float(alarm['Threshold']) - threshold_absolute) > 0.1 or
                        alarm['AlarmActions'] != [alarm_topic]
                    )
                    return True, alarm_name, needs_update
            return False, alarm_name, False
        except Exception as e:
            logger.error(f"检查集群 {cluster_id} 的连接数告警时出错: {str(e)}")
            return False, None, False
   

    def set_connection_alarms_by_pattern(self, cluster_pattern, connection_percentage, alarm_topic):
        """为匹配的数据库集群创建或更新连接数告警。"""
        clusters = self._get_clusters_by_pattern(cluster_pattern)
        if not alarm_topic or not alarm_topic.startswith('arn:aws:sns:'):
            raise ValueError("无效或未配置 SNS Topic ARN。")

        results = []
        for cluster in clusters:
            cluster_id = cluster['DBClusterIdentifier']
            try:
                exists, alarm_name, needs_update = self._check_cluster_connection_alarm(cluster_id, connection_percentage, alarm_topic)
                
                if not exists or needs_update:
                    if not exists:
                        logger.info(f"\n集群 {cluster_id} 未配置连接数告警，开始创建...")
                    else:
                        logger.info(f"\n集群 {cluster_id} 的连接数告警配置需要更新...")
                    
                    result = self._create_cluster_connection_alarm(cluster, connection_percentage, alarm_topic)
                    if result:
                        status = 'created' if not exists else 'updated'
                        results.append({'clusterId': cluster_id, 'status': status})
                    else:
                        results.append({'clusterId': cluster_id, 'status': 'skipped_no_members'})
                else:
                    logger.info(f"\n集群 {cluster_id} 的连接数告警已正确配置。")
                    results.append({'clusterId': cluster_id, 'status': 'unchanged'})
                    
            except Exception as e:
                logger.error(f"为集群 {cluster_id} 处理连接数告警失败: {str(e)}")
                results.append({'clusterId': cluster_id, 'status': 'failed', 'error': str(e)})
        
        return results

    def _create_cluster_throughput_alarm(self, cluster_info, throughput_percentage, alarm_topic):
        """Create or overwrite the cluster-level network-utilization alarm.

        For every member whose instance class and bandwidth can be resolved,
        three hidden queries are registered: ``NetworkReceiveThroughput`` and
        ``NetworkTransmitThroughput`` (60-second period, ``Average``) and a
        percentage expression ``(rx + tx) / bandwidth * 100``. The alarm
        watches the maximum of those percentages.

        Args:
            cluster_info (dict): Cluster description; ``DBClusterIdentifier``
                and ``DBClusterMembers`` are read.
            throughput_percentage: Threshold in percent (converted with ``float``).
            alarm_topic (str): ARN used for both the ALARM and the OK action.

        Returns:
            The ``put_metric_alarm`` response, or ``None`` when no member
            produced a usable percentage series.
        """
        cluster_id = cluster_info['DBClusterIdentifier']
        alarm_name = f"{cluster_id}-MaxNetworkThroughputPercentage"

        metrics_data = []
        # Ids of the percentage expressions that were really emitted. The final
        # MAX() must reference exactly these: members that are skipped leave
        # gaps in the index sequence, so the ids cannot be rebuilt from a count.
        pct_ids = []

        for i, member in enumerate(cluster_info.get('DBClusterMembers', [])):
            instance_id = member['DBInstanceIdentifier']
            try:
                instance_details = self.rds.describe_db_instances(DBInstanceIdentifier=instance_id)
                instance_class = instance_details['DBInstances'][0]['DBInstanceClass']
                # Bandwidth of the instance class in bytes per second; raises
                # for classes missing from the bandwidth table.
                baseline_throughput = self.get_instance_max_throughput_bytes(instance_class)
            except Exception as e:
                logger.warning(f"无法获取实例 {instance_id} 的详细信息: {str(e)}")
                continue

            for query_id, metric_name in ((f'rx{i}', 'NetworkReceiveThroughput'),
                                          (f'tx{i}', 'NetworkTransmitThroughput')):
                metrics_data.append({
                    'Id': query_id,
                    'MetricStat': {
                        'Metric': {
                            'Namespace': 'AWS/RDS',
                            'MetricName': metric_name,
                            'Dimensions': [{'Name': 'DBInstanceIdentifier', 'Value': instance_id}]
                        },
                        'Period': 60,
                        'Stat': 'Average'
                    },
                    'ReturnData': False,
                })

            # Utilization of this instance: (rx + tx) / bandwidth * 100.
            metrics_data.append({
                'Id': f'pct{i}',
                'Expression': f'(rx{i} + tx{i}) / {baseline_throughput} * 100',
                'Label': f'NetworkUtilizationPercentage_{instance_id}',
                'ReturnData': False,
            })
            pct_ids.append(f'pct{i}')

        if not pct_ids:
            logger.warning(f"集群 {cluster_id} 没有任何可用的成员实例，无法创建告警。")
            return None

        # NOTE(review): three queries per instance plus the final expression;
        # CloudWatch caps the number of metrics per alarm — confirm that the
        # largest clusters stay within that limit.
        metrics_data.append({
            'Id': 'e1',
            'Expression': f'MAX([{", ".join(pct_ids)}])',
            'Label': 'MaxClusterNetworkUtilizationPercentage',
            'ReturnData': True,
        })

        logger.info(f"为集群 {cluster_id} 创建/更新网络吞吐量百分比告警...")
        logger.info(f"  - 监控配置: 当任一实例网络利用率超过 {throughput_percentage}% 时触发告警")

        return self.cloudwatch.put_metric_alarm(
            AlarmName=alarm_name,
            AlarmDescription=f'The maximum network utilization percentage for any instance in cluster {cluster_id} exceeds {throughput_percentage}%',
            ActionsEnabled=True,
            AlarmActions=[alarm_topic],
            OKActions=[alarm_topic],
            EvaluationPeriods=3,
            DatapointsToAlarm=3,
            Threshold=float(throughput_percentage),
            ComparisonOperator='GreaterThanThreshold',
            TreatMissingData='missing',
            Metrics=metrics_data
        )

    def _check_cluster_throughput_alarm(self, cluster_id, throughput_percentage, alarm_topic):
        """检查集群网络吞吐量百分比告警是否存在且配置正确。"""
        alarm_name = f"{cluster_id}-MaxNetworkThroughputPercentage"
        try:
            alarms = self.cloudwatch.describe_alarms(AlarmNamePrefix=alarm_name, AlarmTypes=['MetricAlarm'])
            for alarm in alarms['MetricAlarms']:
                if alarm.get('Metrics') and any('MAX([' in str(m.get('Expression', '')) for m in alarm['Metrics']):
                    needs_update = (
                        abs(float(alarm['Threshold']) - float(throughput_percentage)) > 0.1 or
                        alarm['AlarmActions'] != [alarm_topic]
                    )
                    return True, alarm_name, needs_update
            return False, alarm_name, False
        except Exception as e:
            logger.error(f"检查集群 {cluster_id} 的网络吞吐量告警时出错: {str(e)}")
            return False, None, False

    def set_throughput_alarms_by_pattern(self, cluster_pattern, throughput_percentage, alarm_topic):
        """为匹配的数据库集群创建或更新网络吞吐量百分比告警。"""
        clusters = self._get_clusters_by_pattern(cluster_pattern)
        if not alarm_topic or not alarm_topic.startswith('arn:aws:sns:'):
            raise ValueError("无效或未配置 SNS Topic ARN。")

        results = []
        for cluster in clusters:
            cluster_id = cluster['DBClusterIdentifier']
            try:
                exists, alarm_name, needs_update = self._check_cluster_throughput_alarm(cluster_id, throughput_percentage, alarm_topic)
                
                if not exists or needs_update:
                    if not exists:
                        logger.info(f"\n集群 {cluster_id} 未配置网络吞吐量告警，开始创建...")
                    else:
                        logger.info(f"\n集群 {cluster_id} 的网络吞吐量告警配置需要更新...")
                    
                    result = self._create_cluster_throughput_alarm(cluster, throughput_percentage, alarm_topic)
                    if result:
                        status = 'created' if not exists else 'updated'
                        results.append({'clusterId': cluster_id, 'status': status})
                    else:
                        results.append({'clusterId': cluster_id, 'status': 'skipped_no_members'})
                else:
                    logger.info(f"\n集群 {cluster_id} 的网络吞吐量告警已正确配置。")
                    results.append({'clusterId': cluster_id, 'status': 'unchanged'})
                    
            except Exception as e:
                logger.error(f"为集群 {cluster_id} 处理吞吐量告警失败: {str(e)}")
                results.append({'clusterId': cluster_id, 'status': 'failed', 'error': str(e)})
        
        return results

    def get_instance_max_throughput_bytes(self, instance_type):
        """Return the tabulated network bandwidth of an instance class in bytes per second.

        The Gbps figure is read from ``RDS_INSTANCE_BANDWIDTH_GBPS``.

        Args:
            instance_type (str): RDS instance class used as lookup key.

        Returns:
            The bandwidth converted to bytes per second.

        Raises:
            ValueError: If the class is missing from the table (or its entry is falsy).
        """
        bandwidth_gbps = RDS_INSTANCE_BANDWIDTH_GBPS.get(instance_type)
        if bandwidth_gbps:
            # 1 Gbps = 125,000,000 bytes/s. The factors are applied one by one,
            # left to right, so the floating-point result stays unchanged.
            return bandwidth_gbps * 125 * 1000 * 1000
        raise ValueError(f"未找到实例类型 '{instance_type}' 的带宽信息。请在 RDS_INSTANCE_BANDWIDTH_GBPS 字典中添加它。")

    def set_slowquery_alarms_by_pattern(self, pattern, alarm_topic, threshold, period=60):
        """Install the slow-query-count alarm on every RDS cluster matching a pattern.

        Args:
            pattern: Regular expression matched against cluster identifiers.
            alarm_topic: SNS topic ARN used as alarm and OK action.
            threshold: Alarm threshold.
            period: Statistic period passed to the alarm (default 60).
        """
        logger.info(f"Setting slow queries alarms for rds clusters matching pattern: {pattern}")
        alarm_factory = self._create_slow_queries_alarm
        self._alarm_by_pattern(alarm_factory, pattern, alarm_topic, threshold, period)

    def _create_slow_queries_alarm(self, cluster_info, alarm_topic, threshold, period):
        """Create or overwrite the slow-query-count alarm of one cluster.

        The alarm evaluates the ``Sum`` of the custom slow-query metric per
        ``period`` and requires 3 out of 3 evaluation periods above the
        threshold. Missing data is ignored.

        Args:
            cluster_info (dict): Cluster description; ``DBClusterIdentifier`` is read.
            alarm_topic (str): ARN used for both the ALARM and the OK action.
            threshold: Slow-query count per period above which the alarm
                fires (converted with ``float``, like the other alarms).
            period: Statistic period passed to the alarm.
        """
        cluster_id = cluster_info['DBClusterIdentifier']
        self.cloudwatch.put_metric_alarm(
            AlarmName=f"{cluster_id}-SlowQueries",
            # The statistic is a Sum per period, so the description states a
            # per-period count rather than a per-second rate.
            AlarmDescription=(
                f"Slow query count for cluster {cluster_id} exceeds {threshold} "
                f"per {period}s period for 3 consecutive periods"
            ),
            ActionsEnabled=True,
            AlarmActions=[alarm_topic],
            OKActions=[alarm_topic],
            EvaluationPeriods=3,
            DatapointsToAlarm=3,
            Threshold=float(threshold),
            ComparisonOperator='GreaterThanThreshold',
            MetricName=DbAlarmManager.METRIC_NAME_SLOW_QUERY,
            Namespace=DbAlarmManager.METRIC_NAMESPACE,
            Statistic='Sum',
            Period=period,
            Dimensions=[
                {
                    'Name': 'DBClusterIdentifier',
                    'Value': cluster_id
                },
            ],
            TreatMissingData="ignore",
        )
        logger.info(f"Slow queries alarm created for rds cluster: {cluster_id}")
        
    def _create_no_index_slow_queries_alarm(self, aurora_info, sns_topic_arn, threshold, period):
        """Create or overwrite the no-index slow-query-count alarm of one Aurora cluster.

        The alarm evaluates the ``Sum`` of the custom no-index slow-query
        metric per ``period`` and requires 3 out of 3 evaluation periods above
        the threshold. Missing data is ignored.

        Args:
            aurora_info (dict): Cluster description; ``DBClusterIdentifier`` is read.
            sns_topic_arn (str): ARN used for both the ALARM and the OK action.
            threshold: Alarm threshold.
            period: Statistic period passed to the alarm.
        """
        cluster_id = aurora_info['DBClusterIdentifier']
        logger.info(f"Creating no-index slow queries alarm for Aurora cluster: {cluster_id}")

        notify = [sns_topic_arn]
        alarm_settings = {
            'AlarmName': f"{cluster_id}_Aurora-NoIndexSlowQueries",
            'AlarmDescription': f"当Aurora集群 {cluster_id} 的无索引慢查询数量连续3个周期超过{threshold}个时触发告警",
            'ActionsEnabled': True,
            'AlarmActions': notify,
            'OKActions': list(notify),
            'EvaluationPeriods': 3,
            'DatapointsToAlarm': 3,
            'Threshold': threshold,
            'ComparisonOperator': "GreaterThanThreshold",
            'MetricName': DbAlarmManager.METRIC_NAME_NO_INDEX,
            'Namespace': DbAlarmManager.METRIC_NAMESPACE,
            'Statistic': 'Sum',
            'Period': period,
            'Dimensions': [{'Name': 'DBClusterIdentifier', 'Value': cluster_id}],
            'TreatMissingData': "ignore",
        }
        self.cloudwatch.put_metric_alarm(**alarm_settings)
        logger.info(f"No-index slow queries alarm created for Aurora cluster: {cluster_id}")

    def set_no_index_slow_queries_alarm_by_pattern(self, pattern, sns_arn, threshold, period=60):
        """Install the no-index slow-query-count alarm on every Aurora cluster matching a pattern.

        Args:
            pattern: Regular expression matched against cluster identifiers.
            sns_arn: SNS topic ARN used as alarm and OK action.
            threshold: Alarm threshold.
            period: Statistic period passed to the alarm (default 60).
        """
        logger.info(f"Setting no-index slow queries alarms for Aurora clusters matching pattern: {pattern}")
        alarm_factory = self._create_no_index_slow_queries_alarm
        self._alarm_by_pattern(alarm_factory, pattern, sns_arn, threshold, period)

    
    def _alarm_by_pattern(self, func, pattern, alarm_topic, threshold, period):
        """
        Apply an alarm-creation callable to every RDS cluster matching a pattern.

        Args:
            func: Callable invoked as ``func(cluster, alarm_topic, threshold, period)``
                for each matching cluster description.
            pattern: Regular expression applied with ``re.match`` semantics to
                ``DBClusterIdentifier``.
            alarm_topic: SNS topic ARN forwarded to ``func``.
            threshold: Alarm threshold forwarded to ``func``.
            period: Statistic period forwarded to ``func``.
        """
        logger.info(f"Searching for rds clusters matching pattern: {pattern}")
        # Compile once instead of resolving the pattern for every cluster.
        matcher = re.compile(pattern)
        paginator = self.rds.get_paginator('describe_db_clusters')
        found = 0
        for page in paginator.paginate():
            for cluster in page['DBClusters']:
                cluster_id = cluster['DBClusterIdentifier']
                if not matcher.match(cluster_id):
                    continue
                logger.info(f"Found matching rds cluster: {cluster_id}")
                func(cluster, alarm_topic, threshold, period)
                found += 1
        # These messages used to say "Redis" although RDS clusters are scanned.
        if found == 0:
            logger.warning(f"No rds clusters found matching pattern: {pattern}")
        else:
            logger.info(f"Found {found} rds cluster(s) matching pattern: {pattern}")

    def get_connection_summary_by_pattern(self, cluster_pattern, max_workers=10, connection_threshold=None):
        """
        根据模式获取数据库集群的最新连接数信息。
        
        Args:
            cluster_pattern (str): 集群ID的正则表达式模式
            max_workers (int): 最大线程数，默认10
            connection_threshold (int): 连接数阈值，当连接数小于此值时生成告警
            
        Returns:
            dict: 包含连接数信息和告警信息的字典
        """
        import concurrent.futures
        import threading
        
        clusters = self._get_clusters_by_pattern(cluster_pattern)
        results = []
        alerts = []
        
        # 为每个集群创建一个线程锁，确保日志输出的线程安全
        cluster_locks = {}
        
        for cluster in clusters:
            cluster_id = cluster['DBClusterIdentifier']
            cluster_info = {
                'cluster_id': cluster_id,
                'status': cluster['Status'],
                'engine': cluster['Engine'],
                'engine_version': cluster['EngineVersion'],
                'instances': []
            }
            
            cluster_locks[cluster_id] = threading.Lock()
            
            logger.info(f"Getting latest connection metrics for cluster '{cluster_id}'...")
            
            # 收集所有需要获取的实例ID和成员信息
            cluster_members = cluster.get('DBClusterMembers', [])
            instance_member_map = {member['DBInstanceIdentifier']: member for member in cluster_members}
            instance_ids = list(instance_member_map.keys())
            
            if not instance_ids:
                results.append(cluster_info)
                continue
            
            # 使用线程池并发获取实例连接数信息
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                # 提交所有任务
                future_to_instance = {
                    executor.submit(self._get_instance_connection_metrics, instance_id): instance_id 
                    for instance_id in instance_ids
                }
                
                # 收集结果
                for future in concurrent.futures.as_completed(future_to_instance):
                    instance_id = future_to_instance[future]
                    try:
                        instance_metrics = future.result()
                        cluster_info['instances'].append(instance_metrics)
                        
                        # 统计集群级别的连接数
                        if instance_metrics['status'] == 'success':
                            current_connections = instance_metrics['current']
                            # 检查是否是主实例（写实例）
                            member = instance_member_map.get(instance_id, {})
                            is_writer = member.get('IsClusterWriter', False)
                            
                            # 累加到集群总连接数
                            if 'total_connections' not in cluster_info:
                                cluster_info['total_connections'] = 0
                                cluster_info['writer_connections'] = 0
                                cluster_info['reader_connections'] = 0
                                cluster_info['valid_instances'] = 0
                            
                            cluster_info['total_connections'] += current_connections
                            cluster_info['valid_instances'] += 1
                            
                            # 分别统计主实例和从实例连接数
                            if is_writer:
                                cluster_info['writer_connections'] += current_connections
                            else:
                                cluster_info['reader_connections'] += current_connections
                        
                        # 线程安全的日志输出
                        with cluster_locks[cluster_id]:
                            if instance_metrics['status'] == 'success':
                                logger.info(f"Successfully got metrics for instance '{instance_id}': {instance_metrics['current']} connections")
                            else:
                                logger.warning(f"Failed to get metrics for instance '{instance_id}': {instance_metrics.get('error', 'Unknown error')}")
                                
                    except Exception as e:
                        logger.error(f"Exception occurred while getting metrics for instance '{instance_id}': {str(e)}")
                        # 添加错误信息到结果中
                        cluster_info['instances'].append({
                            'instance_id': instance_id,
                            'current': 'N/A',
                            'timestamp': 'N/A',
                            'status': 'error',
                            'error': str(e)
                        })
            
            # 检查集群级别的连接数阈值告警
            if connection_threshold is not None and cluster_info.get('total_connections', 0) < connection_threshold:
                alert_info = {
                    'cluster_id': cluster_id,
                    'total_connections': cluster_info.get('total_connections', 0),
                    'threshold': connection_threshold,
                    'engine': cluster_info['engine'],
                    'engine_version': cluster_info['engine_version'],
                    'writer_connections': cluster_info.get('writer_connections', 0),
                    'reader_connections': cluster_info.get('reader_connections', 0),
                    'valid_instances': cluster_info.get('valid_instances', 0)
                }
                alerts.append(alert_info)
            
            results.append(cluster_info)
        
        return {
            'clusters': results,
            'alerts': alerts
        }

    def generate_feishu_alert_message(self, alerts):
        """
        生成飞书告警消息。
        
        Args:
            alerts (list): 告警信息列表
            
        Returns:
            str: 格式化的飞书消息
        """
        if not alerts:
            return None
        
        # 按集群分组告警
        cluster_alerts = {}
        for alert in alerts:
            cluster_id = alert['cluster_id']
            if cluster_id not in cluster_alerts:
                cluster_alerts[cluster_id] = []
            cluster_alerts[cluster_id].append(alert)
        
        # 生成消息内容
        message_parts = []
        message_parts.append("🚨 **数据库集群连接数告警**")
        message_parts.append("")
        message_parts.append(f"**告警时间**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        message_parts.append(f"**告警集群数**: {len(alerts)}")
        message_parts.append("")
        
        for alert in alerts:
            message_parts.append(f"**集群**: {alert['cluster_id']}")
            message_parts.append(f"  - 集群总连接数: {alert['total_connections']}")
            message_parts.append(f"  - 主实例连接数: {alert['writer_connections']}")
            message_parts.append(f"  - 从实例连接数: {alert['reader_connections']}")
            message_parts.append(f"  - 有效实例数: {alert['valid_instances']}")
            message_parts.append(f"  - 阈值: {alert['threshold']}")
            message_parts.append(f"  - 引擎: {alert['engine']} {alert['engine_version']}")
            message_parts.append("")
        
        message_parts.append("请及时检查数据库集群连接情况！")
        
        return "\n".join(message_parts)

    def _get_instance_connection_metrics(self, instance_id):
        """
        获取指定实例的最新连接数指标数据。
        
        Args:
            instance_id (str): 数据库实例ID
            
        Returns:
            dict: 包含最新连接数信息的字典
        """
        try:
            from datetime import timedelta
            
            # 获取最近30分钟的数据，确保能获取到最新数据点
            end_time = datetime.utcnow()
            start_time = end_time - timedelta(minutes=5)
            
            response = self.cloudwatch.get_metric_statistics(
                Namespace='AWS/RDS',
                MetricName='DatabaseConnections',
                Dimensions=[
                    {
                        'Name': 'DBInstanceIdentifier',
                        'Value': instance_id
                    }
                ],
                StartTime=start_time,
                EndTime=end_time,
                Period=60,  # 1分钟间隔
                Statistics=['Average']
            )
            
            if not response['Datapoints']:
                return {
                    'instance_id': instance_id,
                    'current': 'N/A',
                    'timestamp': 'N/A',
                    'status': 'no_data'
                }
            
            # 获取最新的数据点
            latest_datapoint = max(response['Datapoints'], key=lambda x: x['Timestamp'])
            
            return {
                'instance_id': instance_id,
                'current': latest_datapoint['Average'],
                'timestamp': latest_datapoint['Timestamp'].strftime('%Y-%m-%d %H:%M:%S'),
                'status': 'success'
            }
            
        except Exception as e:
            logger.error(f"Error getting connection metrics for instance {instance_id}: {str(e)}")
            return {
                'instance_id': instance_id,
                'current': 'N/A',
                'timestamp': 'N/A',
                'status': 'error',
                'error': str(e)
            }

    