# -*- coding: utf-8 -*-
"""
解析民政部《2023年中华人民共和国县以上行政区划代码》HTML
生成 admin_division.db

表结构：
    administrative_division (
        code         TEXT PRIMARY KEY,   -- 6 位行政区划代码
        name         TEXT NOT NULL,      -- 单位名称（已去除全部空白及末尾脚注星号）
        parent_code  TEXT,               -- 上级行政区代码（省级为空）
        level        INTEGER NOT NULL    -- 1=省级, 2=市级, 3=县级
    )

判定规则：
    - 省级（含直辖市）: 末 4 位均为 0（XX0000）
    - 市级（含地级市/自治州/盟/地区）: 末 2 位为 00 且中间 2 位不为 00（XXYY00）
    - 县级（含区/县/旗/自治县/县级市）: 末 2 位不为 00（XXYYZZ）

parent 规则：
    - 省级：None
    - 市级：前 2 位 + '0000'
    - 县级：先查前 4 位 + '00' 是否存在；
            存在 → 指向该市
            不存在 → 指向省级（兼容直辖市下辖区、无市级占位的省直辖县级市）
"""

import html
import os
import re
import sqlite3

# Paths of the input HTML and the output database. Each can be overridden via
# an environment variable (ADMIN_DIVISION_HTML / ADMIN_DIVISION_DB) so the
# script can run on a machine without this exact directory layout; when the
# variables are unset the original absolute paths are used.
HTML_PATH = os.environ.get(
    'ADMIN_DIVISION_HTML',
    r'd:\BYP\Project\BYP\BYPsNotes\extern\attachments\topic\县以上行政区划代码\202301xzqh.html',
)
DB_PATH = os.environ.get(
    'ADMIN_DIVISION_DB',
    r'd:\BYP\Project\BYP\BYPsNotes\db\admin_division.db',
)


def parse_html(html_path):
    """Extract every ``(code, name)`` pair from the division-code HTML table.

    The file is read as UTF-8. A match is a ``<td>`` holding exactly six
    digits (surrounding whitespace allowed) immediately followed by another
    ``<td>``; that second cell's inner HTML is passed through ``clean_name``.
    Pairs whose cleaned name is empty are skipped.

    Args:
        html_path: path of the HTML file to read.

    Returns:
        A list of ``(code, name)`` tuples in document order.
    """
    with open(html_path, 'r', encoding='utf-8') as fh:
        html_text = fh.read()

    # Code cell followed by name cell. The name cell may wrap leading padding
    # in <span> tags, so its body is captured lazily (up to the first </td>)
    # and DOTALL lets it span line breaks.
    row_re = re.compile(
        r'<td[^>]*>\s*(\d{6})\s*</td>\s*<td[^>]*>(.*?)</td>',
        re.DOTALL,
    )

    cleaned = (
        (match.group(1), clean_name(match.group(2)))
        for match in row_re.finditer(html_text)
    )
    return [(code, name) for code, name in cleaned if name]


def clean_name(raw):
    """Reduce the inner HTML of a name cell to the bare name.

    Args:
        raw: inner HTML of a name ``<td>``; may contain tags (e.g. ``<span>``
            padding) and HTML entities such as ``&nbsp;`` or ``&#160;``.

    Returns:
        The name with all tags removed, all HTML entities decoded, ALL
        whitespace deleted (not merely collapsed), and trailing footnote
        asterisks (``*`` or full-width ``＊``) stripped. Returns ``''`` when
        the cell held only markup/whitespace.
    """
    # Strip tags first and decode entities afterwards, so an escaped "&lt;"
    # in the text can never be mistaken for a tag.
    text = html.unescape(re.sub(r'<[^>]+>', '', raw))
    # html.unescape maps &nbsp; / &#160; to U+00A0. For str patterns \s also
    # matches U+00A0 and U+3000, so this one substitution removes every kind
    # of padding. Whitespace is deleted outright rather than collapsed.
    text = re.sub(r'\s+', '', text)
    # Drop MCA footnote markers, e.g. the asterisk in "五指山市*".
    return text.rstrip('*＊')


def determine_level(code):
    """Classify a division code by its digit pattern.

    Returns 3 (county) whenever the last two characters are not ``'00'``;
    otherwise 1 (province, ``XX0000``) if the middle two are also ``'00'``,
    else 2 (prefecture, ``XXYY00``).
    """
    prefecture_part, county_part = code[2:4], code[4:6]
    if county_part != '00':
        return 3
    return 1 if prefecture_part == '00' else 2


def determine_parent(code, level, code_set):
    """Derive the parent division code from *code* and its *level*.

    Args:
        code: 6-digit division code.
        level: 1, 2 or 3 as produced by ``determine_level``; any value other
            than 1 or 2 is treated like a county.
        code_set: collection of all codes present in the data set, used to
            check whether a county's prefecture-level parent exists.

    Returns:
        ``None`` for level 1; the province code (first two digits + '0000')
        for level 2; for other levels the prefecture code (first four digits
        + '00') when it is in *code_set* and differs from *code*, otherwise
        the province code (municipal districts, province-administered
        county-level cities).
    """
    province = code[:2] + '0000'
    if level == 1:
        return None
    if level == 2:
        return province
    prefecture = code[:4] + '00'
    has_prefecture = prefecture in code_set and prefecture != code
    return prefecture if has_prefecture else province

def build_db(records, db_path):
    """Write parsed records into a freshly created SQLite database.

    Each record gets its level from ``determine_level`` and its parent from
    ``determine_parent``. The database is built in a temporary sibling file
    (``db_path + '.tmp'``) and moved over *db_path* only once it is complete,
    so a failed run leaves any previous database at *db_path* untouched.

    Args:
        records: iterable of ``(code, name)`` pairs (as returned by
            ``parse_html``).
        db_path: path of the SQLite file to (re)create. A bare file name is
            created in the current directory.

    Returns:
        ``(row_count, level_stats)`` where ``level_stats`` is a list of
        ``(level, count)`` tuples ordered by level.

    Raises:
        sqlite3.IntegrityError: if *records* contains a duplicate code
            (``code`` is the table's primary key).
    """
    # Materialize once: records is walked twice below.
    records = list(records)
    # All known codes, so county rows can check whether their prefecture exists.
    code_set = {code for code, _ in records}

    rows = []
    for code, name in records:
        level = determine_level(code)
        parent = determine_parent(code, level, code_set)
        rows.append((code, name, parent, level))

    # os.path.dirname() is '' for a bare file name and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when there is one.
    out_dir = os.path.dirname(db_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    tmp_path = db_path + '.tmp'
    if os.path.exists(tmp_path):
        os.remove(tmp_path)

    try:
        conn = sqlite3.connect(tmp_path)
        try:
            cur = conn.cursor()
            cur.execute('''
                CREATE TABLE administrative_division (
                    code        TEXT PRIMARY KEY,
                    name        TEXT NOT NULL,
                    parent_code TEXT,
                    level       INTEGER NOT NULL
                )
            ''')
            cur.execute('CREATE INDEX idx_parent ON administrative_division(parent_code)')
            cur.execute('CREATE INDEX idx_level ON administrative_division(level)')
            cur.execute('CREATE INDEX idx_name ON administrative_division(name)')

            cur.executemany(
                'INSERT INTO administrative_division (code, name, parent_code, level) VALUES (?, ?, ?, ?)',
                rows
            )
            conn.commit()

            # Per-level summary, read back from the table just written.
            cur.execute('SELECT level, COUNT(*) FROM administrative_division GROUP BY level ORDER BY level')
            level_stats = cur.fetchall()
        finally:
            # Close on every path; an open handle would also block the
            # replace/remove below on Windows.
            conn.close()
        # Swap the finished database into place, overwriting any old file.
        os.replace(tmp_path, db_path)
    finally:
        # No-op after a successful replace; discards a partial build otherwise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    return len(rows), level_stats


def main():
    """Parse HTML_PATH, rebuild DB_PATH, and print a summary plus spot checks.

    The spot check looks up a fixed list of sample codes and prints each one
    with its level and parent (code and name). Sample codes that are absent
    from the database are skipped silently.
    """
    print(f'读取：{HTML_PATH}')
    records = parse_html(HTML_PATH)
    print(f'  解析到 {len(records)} 条记录\n')

    print(f'写入：{DB_PATH}')
    total, level_stats = build_db(records, DB_PATH)
    print(f'  共写入 {total} 条\n')

    level_name = {1: '省级', 2: '市级', 3: '县级'}
    print('分级统计：')
    for lvl, cnt in level_stats:
        print(f'  {level_name.get(lvl, lvl)}: {cnt}')

    # Spot-check representative cases: municipality + district, province +
    # prefecture + county, a county-level city, a province-administered
    # county-level city, and a league (盟) with one of its members.
    sample_codes = ('110000', '110101', '130000', '130100', '130121',
                    '450000', '450881', '469001', '152200', '152201')
    print('\n抽样校验：')
    conn = sqlite3.connect(DB_PATH)
    try:
        cur = conn.cursor()
        for code in sample_codes:
            # One query per sample: the self LEFT JOIN yields the parent's
            # name, or NULL when there is no parent (province level).
            cur.execute(
                'SELECT d.code, d.name, d.parent_code, d.level, p.name '
                'FROM administrative_division AS d '
                'LEFT JOIN administrative_division AS p ON p.code = d.parent_code '
                'WHERE d.code = ?',
                (code,),
            )
            row = cur.fetchone()
            if row is None:
                continue
            row_code, name, parent_code, level, parent_name = row
            print(f'  {row_code} {name:<12} level={level}  parent={parent_code or "—":<6} ({parent_name or ""})')
    finally:
        # Release the connection even if a query fails.
        conn.close()

    print('\n完成。')


# Run the parse-and-build pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
