Skip to content

Commit 8071a1c

Browse files
committed
🔨 refactor(ai): modularize ai client creation and file filtering
1 parent 27e178c commit 8071a1c

File tree

2 files changed

+149
-67
lines changed

2 files changed

+149
-67
lines changed

tgit/commit.py

Lines changed: 121 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import importlib
23
import importlib.resources
34
import itertools
45
from dataclasses import dataclass
@@ -22,6 +23,8 @@
2223

2324
MAX_DIFF_LINES = 1000
2425
NUMSTAT_PARTS = 3
26+
NAME_STATUS_PARTS = 2
27+
RENAME_STATUS_PARTS = 3
2528

2629

2730
def define_commit_parser(subparsers: argparse._SubParsersAction) -> None:
@@ -60,79 +63,158 @@ class CommitData(BaseModel):
6063
is_breaking: bool
6164

6265

63-
def get_filtered_diff_files(repo: git.Repo) -> tuple[list[str], list[str]]:
64-
diff_numstat = repo.git.diff("--cached", "--numstat")
65-
files_to_include = []
66-
lock_files = []
66+
def get_changed_files_from_status(repo: git.Repo) -> set[str]:
67+
"""获取所有变更的文件,包括重命名/移动的文件"""
68+
diff_name_status = repo.git.diff("--cached", "--name-status", "-M")
69+
all_changed_files = set()
70+
71+
for line in diff_name_status.splitlines():
72+
parts = line.split("\t")
73+
if len(parts) >= NAME_STATUS_PARTS:
74+
status = parts[0]
75+
if status.startswith("R"): # 重命名/移动
76+
# 重命名格式: R100 old_file new_file
77+
if len(parts) >= RENAME_STATUS_PARTS:
78+
old_file, new_file = parts[1], parts[2]
79+
all_changed_files.add(old_file)
80+
all_changed_files.add(new_file)
81+
else:
82+
# 其他状态: A(添加), M(修改), D(删除)等
83+
filename = parts[1]
84+
all_changed_files.add(filename)
85+
86+
return all_changed_files
87+
88+
89+
def get_file_change_sizes(repo: git.Repo) -> dict[str, int]:
90+
"""获取文件变更的行数统计"""
91+
diff_numstat = repo.git.diff("--cached", "--numstat", "-M")
92+
file_sizes = {}
93+
6794
for line in diff_numstat.splitlines():
6895
parts = line.split("\t")
6996
if len(parts) >= NUMSTAT_PARTS:
7097
added, deleted, filename = parts[0], parts[1], parts[2]
71-
if filename.endswith(".lock"):
72-
lock_files.append(filename)
73-
continue
7498
try:
75-
added = int(added) if added != "-" else 0
76-
deleted = int(deleted) if deleted != "-" else 0
99+
added_int = int(added) if added != "-" else 0
100+
deleted_int = int(deleted) if deleted != "-" else 0
101+
file_sizes[filename] = added_int + deleted_int
77102
except ValueError:
78-
continue
79-
if added + deleted <= MAX_DIFF_LINES:
80-
files_to_include.append(filename)
103+
# 对于二进制文件等特殊情况,设置为0以包含在diff中
104+
file_sizes[filename] = 0
105+
106+
return file_sizes
107+
108+
109+
def get_filtered_diff_files(repo: git.Repo) -> tuple[list[str], list[str]]:
110+
"""获取过滤后的差异文件列表"""
111+
all_changed_files = get_changed_files_from_status(repo)
112+
file_sizes = get_file_change_sizes(repo)
113+
114+
files_to_include = []
115+
lock_files = []
116+
117+
# 过滤文件
118+
for filename in all_changed_files:
119+
if filename.endswith(".lock"):
120+
lock_files.append(filename)
121+
continue
122+
123+
# 检查文件大小(如果有统计信息)
124+
total_changes = file_sizes.get(filename, 0)
125+
if total_changes <= MAX_DIFF_LINES:
126+
files_to_include.append(filename)
127+
81128
return files_to_include, lock_files
82129

83130

131+
def _import_openai(): # type: ignore[misc] # noqa: ANN202
132+
"""动态导入 openai 包"""
133+
try:
134+
# 动态导入,避免在模块级别导入
135+
return importlib.import_module("openai")
136+
except ImportError as e:
137+
error_msg = "openai package is not installed"
138+
raise ImportError(error_msg) from e
139+
140+
141+
def _check_openai_availability() -> None:
142+
"""检查 openai 包是否可用"""
143+
_import_openai() # 这会在包不可用时抛出异常
144+
145+
146+
def _create_openai_client(): # type: ignore[misc] # noqa: ANN202
147+
"""创建并配置 OpenAI 客户端"""
148+
openai = _import_openai()
149+
client = openai.Client()
150+
api_url = settings.get("apiUrl")
151+
if api_url:
152+
client.base_url = api_url
153+
api_key = settings.get("apiKey")
154+
if api_key:
155+
client.api_key = api_key
156+
return client
157+
158+
159+
def _generate_commit_with_ai(diff: str, specified_type: str | None, current_branch: str) -> CommitData | None:
160+
"""使用 AI 生成提交消息"""
161+
_check_openai_availability()
162+
client = _create_openai_client()
163+
164+
template_params = {"types": commit_types, "branch": current_branch}
165+
if specified_type:
166+
template_params["specified_type"] = specified_type
167+
168+
with console.status("[bold green]Generating commit message...[/bold green]"):
169+
chat_completion = client.responses.parse(
170+
input=[
171+
{
172+
"role": "system",
173+
"content": commit_prompt_template.render(**template_params),
174+
},
175+
{"role": "user", "content": diff},
176+
],
177+
model=settings.get("model", "gpt-4.1"),
178+
max_output_tokens=50,
179+
text_format=CommitData,
180+
)
181+
182+
return chat_completion.output_parsed
183+
184+
84185
def get_ai_command(specified_type: str | None = None) -> str | None:
85186
current_dir = Path.cwd()
86187
try:
87188
repo = git.Repo(current_dir, search_parent_directories=True)
88189
except git.InvalidGitRepositoryError:
89190
print("[yellow]Not a git repository[/yellow]")
90191
return None
192+
91193
files_to_include, lock_files = get_filtered_diff_files(repo)
92194
if not files_to_include and not lock_files:
93195
print("[yellow]No files to commit, please add some files before using AI[/yellow]")
94196
return None
197+
95198
diff = ""
96199
if lock_files:
97200
diff += f"[INFO] The following lock files were modified but are not included in the diff: {', '.join(lock_files)}\n"
98201
if files_to_include:
99-
diff += repo.git.diff("--cached", "--", *files_to_include)
202+
diff += repo.git.diff("--cached", "-M", "--", *files_to_include)
100203
current_branch = repo.active_branch.name
101204

102205
if not diff:
103206
print("[yellow]No changes to commit, please add some changes before using AI[/yellow]")
104207
return None
208+
105209
try:
106-
import openai
107-
108-
client = openai.Client()
109-
if settings.get("apiUrl", None):
110-
client.api_base = settings.get("apiUrl", None)
111-
if settings.get("apiKey", None):
112-
client.api_key = settings.get("apiKey", None)
113-
# 准备模板渲染参数,如果用户指定了类型,则传递给模板
114-
template_params = {"types": commit_types, "branch": current_branch}
115-
116-
if specified_type:
117-
template_params["specified_type"] = specified_type
118-
with console.status("[bold green]Generating commit message...[/bold green]"):
119-
chat_completion = client.responses.parse(
120-
input=[
121-
{
122-
"role": "system",
123-
"content": commit_prompt_template.render(**template_params),
124-
},
125-
{"role": "user", "content": diff},
126-
],
127-
model=settings.get("model", "gpt-4.1"),
128-
max_output_tokens=50,
129-
text_format=CommitData,
130-
)
210+
resp = _generate_commit_with_ai(diff, specified_type, current_branch)
211+
if resp is None:
212+
print("[red]Failed to parse AI response[/red]")
213+
return None
131214
except Exception as e:
132215
print("[red]Could not connect to AI provider[/red]")
133216
print(e)
134217
return None
135-
resp = chat_completion.output_parsed
136218

137219
# 如果用户指定了类型,则使用用户指定的类型
138220
commit_type = specified_type or resp.type

0 commit comments

Comments
 (0)