import os
from pptx import Presentation
from pptx.oxml.xmlchemy import OxmlElement
from pptx.oxml.ns import qn

def SubElement(parent, tagname, **kwargs):
    element = OxmlElement(tagname)
    element.attrib.update(kwargs)
    parent.append(element)
    return element

def set_cell_border(cell, side, border_color="000000", border_width='12700', visible=True):
    """
    修改单元格指定一侧的边框。
    """
    tc = cell._tc
    tcPr = tc.get_or_add_tcPr()

    # 1. 映射标签名
    tag_map = {
        'left': 'a:lnL',
        'right': 'a:lnR',
        'top': 'a:lnT',
        'bottom': 'a:lnB'
    }
    tag_name = tag_map.get(side)
    if not tag_name:
        return

    # 2. 【查找时必须用 qn】检查并删除现有的边框定义
    existing_ln = tcPr.find(qn(tag_name))
    if existing_ln is not None:
        tcPr.remove(existing_ln)

    # 3. 创建新的边框元素
    # 【创建时不能用 qn，直接用字符串】
    ln = OxmlElement(tag_name)
    
    if visible:
        # --- 画实线 ---
        ln.set('w', str(border_width))
        ln.set('cap', 'flat')
        ln.set('cmpd', 'sng')
        ln.set('algn', 'ctr')

        # 创建 solidFill (创建时直接用字符串)
        solidFill = OxmlElement('a:solidFill')
        ln.append(solidFill)

        # 创建 srgbClr (创建时直接用字符串)
        srgbClr = OxmlElement('a:srgbClr')
        srgbClr.set('val', border_color)
        solidFill.append(srgbClr)

        # 创建 prstDash (创建时直接用字符串)
        prstDash = OxmlElement('a:prstDash')
        prstDash.set('val', 'solid')
        ln.append(prstDash)
        
    else:
        # --- 无边框 ---
        # 创建 noFill (创建时直接用字符串)
        noFill = OxmlElement('a:noFill')
        ln.append(noFill)

    # 4. 将新定义的边框附加到单元格属性中
    tcPr.append(ln)


def apply_three_line_style(table):
    """
    将表格转换为标准的学术三线表
    """
    # 宽度设置 (EMU单位: 1 pt = 12700 EMU)
    THICK_WIDTH = 28575  # ~2.25 pt
    THIN_WIDTH = 9525    # ~0.75 pt
    COLOR = "000000"     # 黑色

    rows = table.rows
    if len(rows) == 0:
        return

    # 遍历所有单元格
    for row_idx, row in enumerate(rows):
        for cell in row.cells:
            
            # --- 步骤 1: 清除左右边框 ---
            set_cell_border(cell, 'left', visible=False)
            set_cell_border(cell, 'right', visible=False)

            # --- 步骤 2: 处理上边框 ---
            if row_idx == 0:
                # 第一行顶部：粗线
                set_cell_border(cell, 'top', COLOR, THICK_WIDTH, visible=True)
            else:
                # 其他行顶部：清除（依赖上一行的bottom通常就够了，但为了保险清除top）
                set_cell_border(cell, 'top', visible=False)

            # --- 步骤 3: 处理下边框 ---
            if row_idx == 0:
                # 第一行底部（栏目线）：细线 (如果只有一行则粗线)
                width = THICK_WIDTH if len(rows) == 1 else THIN_WIDTH
                set_cell_border(cell, 'bottom', COLOR, width, visible=True)
            
            elif row_idx == len(rows) - 1:
                # 最后一行底部：粗线
                set_cell_border(cell, 'bottom', COLOR, THICK_WIDTH, visible=True)
            
            else:
                # 中间行底部：清除
                set_cell_border(cell, 'bottom', visible=False)


def process_presentations(folder_path):
    files = [f for f in os.listdir(folder_path) if f.endswith(".pptx") and not f.startswith("~$")]
    
    if not files:
        print(f"在 {folder_path} 中未找到 PPTX 文件。")
        return

    output_folder = os.path.join(folder_path, "modified_tables")
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    for filename in files:
        print(f"正在处理: {filename}...")
        try:
            file_path = os.path.join(folder_path, filename)
            prs = Presentation(file_path)
            
            table_count = 0
            for slide in prs.slides:
                for shape in slide.shapes:
                    if shape.has_table:
                        apply_three_line_style(shape.table)
                        table_count += 1
            
            new_path = os.path.join(output_folder, f"三线表_{filename}")
            prs.save(new_path)
            print(f"  - 成功! 修改了 {table_count} 个表格 -> {new_path}")
            
        except Exception as e:
            print(f"  - 失败: {filename} 出现错误: {e}")
            import traceback
            traceback.print_exc()

if __name__ == "__main__":
    process_presentations('.')