|
1 | 1 | import argparse |
| 2 | +import importlib |
2 | 3 | import importlib.resources |
3 | 4 | import itertools |
4 | 5 | from dataclasses import dataclass |
|
22 | 23 |
|
23 | 24 | MAX_DIFF_LINES = 1000 |
24 | 25 | NUMSTAT_PARTS = 3 |
| 26 | +NAME_STATUS_PARTS = 2 |
| 27 | +RENAME_STATUS_PARTS = 3 |
25 | 28 |
|
26 | 29 |
|
27 | 30 | def define_commit_parser(subparsers: argparse._SubParsersAction) -> None: |
@@ -60,79 +63,158 @@ class CommitData(BaseModel): |
60 | 63 | is_breaking: bool |
61 | 64 |
|
62 | 65 |
|
63 | | -def get_filtered_diff_files(repo: git.Repo) -> tuple[list[str], list[str]]: |
64 | | - diff_numstat = repo.git.diff("--cached", "--numstat") |
65 | | - files_to_include = [] |
66 | | - lock_files = [] |
| 66 | +def get_changed_files_from_status(repo: git.Repo) -> set[str]: |
| 67 | + """获取所有变更的文件,包括重命名/移动的文件""" |
| 68 | + diff_name_status = repo.git.diff("--cached", "--name-status", "-M") |
| 69 | + all_changed_files = set() |
| 70 | + |
| 71 | + for line in diff_name_status.splitlines(): |
| 72 | + parts = line.split("\t") |
| 73 | + if len(parts) >= NAME_STATUS_PARTS: |
| 74 | + status = parts[0] |
| 75 | + if status.startswith("R"): # 重命名/移动 |
| 76 | + # 重命名格式: R100 old_file new_file |
| 77 | + if len(parts) >= RENAME_STATUS_PARTS: |
| 78 | + old_file, new_file = parts[1], parts[2] |
| 79 | + all_changed_files.add(old_file) |
| 80 | + all_changed_files.add(new_file) |
| 81 | + else: |
| 82 | + # 其他状态: A(添加), M(修改), D(删除)等 |
| 83 | + filename = parts[1] |
| 84 | + all_changed_files.add(filename) |
| 85 | + |
| 86 | + return all_changed_files |
| 87 | + |
| 88 | + |
| 89 | +def get_file_change_sizes(repo: git.Repo) -> dict[str, int]: |
| 90 | + """获取文件变更的行数统计""" |
| 91 | + diff_numstat = repo.git.diff("--cached", "--numstat", "-M") |
| 92 | + file_sizes = {} |
| 93 | + |
67 | 94 | for line in diff_numstat.splitlines(): |
68 | 95 | parts = line.split("\t") |
69 | 96 | if len(parts) >= NUMSTAT_PARTS: |
70 | 97 | added, deleted, filename = parts[0], parts[1], parts[2] |
71 | | - if filename.endswith(".lock"): |
72 | | - lock_files.append(filename) |
73 | | - continue |
74 | 98 | try: |
75 | | - added = int(added) if added != "-" else 0 |
76 | | - deleted = int(deleted) if deleted != "-" else 0 |
| 99 | + added_int = int(added) if added != "-" else 0 |
| 100 | + deleted_int = int(deleted) if deleted != "-" else 0 |
| 101 | + file_sizes[filename] = added_int + deleted_int |
77 | 102 | except ValueError: |
78 | | - continue |
79 | | - if added + deleted <= MAX_DIFF_LINES: |
80 | | - files_to_include.append(filename) |
| 103 | + # 对于二进制文件等特殊情况,设置为0以包含在diff中 |
| 104 | + file_sizes[filename] = 0 |
| 105 | + |
| 106 | + return file_sizes |
| 107 | + |
| 108 | + |
| 109 | +def get_filtered_diff_files(repo: git.Repo) -> tuple[list[str], list[str]]: |
| 110 | + """获取过滤后的差异文件列表""" |
| 111 | + all_changed_files = get_changed_files_from_status(repo) |
| 112 | + file_sizes = get_file_change_sizes(repo) |
| 113 | + |
| 114 | + files_to_include = [] |
| 115 | + lock_files = [] |
| 116 | + |
| 117 | + # 过滤文件 |
| 118 | + for filename in all_changed_files: |
| 119 | + if filename.endswith(".lock"): |
| 120 | + lock_files.append(filename) |
| 121 | + continue |
| 122 | + |
| 123 | + # 检查文件大小(如果有统计信息) |
| 124 | + total_changes = file_sizes.get(filename, 0) |
| 125 | + if total_changes <= MAX_DIFF_LINES: |
| 126 | + files_to_include.append(filename) |
| 127 | + |
81 | 128 | return files_to_include, lock_files |
82 | 129 |
|
83 | 130 |
|
| 131 | +def _import_openai(): # type: ignore[misc] # noqa: ANN202 |
| 132 | + """动态导入 openai 包""" |
| 133 | + try: |
| 134 | + # 动态导入,避免在模块级别导入 |
| 135 | + return importlib.import_module("openai") |
| 136 | + except ImportError as e: |
| 137 | + error_msg = "openai package is not installed" |
| 138 | + raise ImportError(error_msg) from e |
| 139 | + |
| 140 | + |
| 141 | +def _check_openai_availability() -> None: |
| 142 | + """检查 openai 包是否可用""" |
| 143 | + _import_openai() # 这会在包不可用时抛出异常 |
| 144 | + |
| 145 | + |
| 146 | +def _create_openai_client(): # type: ignore[misc] # noqa: ANN202 |
| 147 | + """创建并配置 OpenAI 客户端""" |
| 148 | + openai = _import_openai() |
| 149 | + client = openai.Client() |
| 150 | + api_url = settings.get("apiUrl") |
| 151 | + if api_url: |
| 152 | + client.base_url = api_url |
| 153 | + api_key = settings.get("apiKey") |
| 154 | + if api_key: |
| 155 | + client.api_key = api_key |
| 156 | + return client |
| 157 | + |
| 158 | + |
| 159 | +def _generate_commit_with_ai(diff: str, specified_type: str | None, current_branch: str) -> CommitData | None: |
| 160 | + """使用 AI 生成提交消息""" |
| 161 | + _check_openai_availability() |
| 162 | + client = _create_openai_client() |
| 163 | + |
| 164 | + template_params = {"types": commit_types, "branch": current_branch} |
| 165 | + if specified_type: |
| 166 | + template_params["specified_type"] = specified_type |
| 167 | + |
| 168 | + with console.status("[bold green]Generating commit message...[/bold green]"): |
| 169 | + chat_completion = client.responses.parse( |
| 170 | + input=[ |
| 171 | + { |
| 172 | + "role": "system", |
| 173 | + "content": commit_prompt_template.render(**template_params), |
| 174 | + }, |
| 175 | + {"role": "user", "content": diff}, |
| 176 | + ], |
| 177 | + model=settings.get("model", "gpt-4.1"), |
| 178 | + max_output_tokens=50, |
| 179 | + text_format=CommitData, |
| 180 | + ) |
| 181 | + |
| 182 | + return chat_completion.output_parsed |
| 183 | + |
| 184 | + |
84 | 185 | def get_ai_command(specified_type: str | None = None) -> str | None: |
85 | 186 | current_dir = Path.cwd() |
86 | 187 | try: |
87 | 188 | repo = git.Repo(current_dir, search_parent_directories=True) |
88 | 189 | except git.InvalidGitRepositoryError: |
89 | 190 | print("[yellow]Not a git repository[/yellow]") |
90 | 191 | return None |
| 192 | + |
91 | 193 | files_to_include, lock_files = get_filtered_diff_files(repo) |
92 | 194 | if not files_to_include and not lock_files: |
93 | 195 | print("[yellow]No files to commit, please add some files before using AI[/yellow]") |
94 | 196 | return None |
| 197 | + |
95 | 198 | diff = "" |
96 | 199 | if lock_files: |
97 | 200 | diff += f"[INFO] The following lock files were modified but are not included in the diff: {', '.join(lock_files)}\n" |
98 | 201 | if files_to_include: |
99 | | - diff += repo.git.diff("--cached", "--", *files_to_include) |
| 202 | + diff += repo.git.diff("--cached", "-M", "--", *files_to_include) |
100 | 203 | current_branch = repo.active_branch.name |
101 | 204 |
|
102 | 205 | if not diff: |
103 | 206 | print("[yellow]No changes to commit, please add some changes before using AI[/yellow]") |
104 | 207 | return None |
| 208 | + |
105 | 209 | try: |
106 | | - import openai |
107 | | - |
108 | | - client = openai.Client() |
109 | | - if settings.get("apiUrl", None): |
110 | | - client.api_base = settings.get("apiUrl", None) |
111 | | - if settings.get("apiKey", None): |
112 | | - client.api_key = settings.get("apiKey", None) |
113 | | - # 准备模板渲染参数,如果用户指定了类型,则传递给模板 |
114 | | - template_params = {"types": commit_types, "branch": current_branch} |
115 | | - |
116 | | - if specified_type: |
117 | | - template_params["specified_type"] = specified_type |
118 | | - with console.status("[bold green]Generating commit message...[/bold green]"): |
119 | | - chat_completion = client.responses.parse( |
120 | | - input=[ |
121 | | - { |
122 | | - "role": "system", |
123 | | - "content": commit_prompt_template.render(**template_params), |
124 | | - }, |
125 | | - {"role": "user", "content": diff}, |
126 | | - ], |
127 | | - model=settings.get("model", "gpt-4.1"), |
128 | | - max_output_tokens=50, |
129 | | - text_format=CommitData, |
130 | | - ) |
| 210 | + resp = _generate_commit_with_ai(diff, specified_type, current_branch) |
| 211 | + if resp is None: |
| 212 | + print("[red]Failed to parse AI response[/red]") |
| 213 | + return None |
131 | 214 | except Exception as e: |
132 | 215 | print("[red]Could not connect to AI provider[/red]") |
133 | 216 | print(e) |
134 | 217 | return None |
135 | | - resp = chat_completion.output_parsed |
136 | 218 |
|
137 | 219 | # 如果用户指定了类型,则使用用户指定的类型 |
138 | 220 | commit_type = specified_type or resp.type |
|
0 commit comments