|
3 | 3 | import itertools |
4 | 4 | from dataclasses import dataclass |
5 | 5 | from pathlib import Path |
6 | | -from typing import TYPE_CHECKING |
| 6 | +from typing import TYPE_CHECKING, Any |
7 | 7 |
|
8 | 8 | import click |
9 | 9 | import git |
10 | 10 | from jinja2 import Environment, FileSystemLoader |
11 | 11 | from pydantic import BaseModel |
12 | 12 | from rich import get_console, print |
13 | 13 |
|
14 | | -from tgit.constants import DEFAULT_MODEL |
| 14 | +from tgit.constants import DEFAULT_MODEL, REASONING_MODEL_HINTS |
15 | 15 | from tgit.shared import settings |
16 | 16 | from tgit.utils import get_commit_command, run_command, type_emojis |
17 | 17 |
|
@@ -72,6 +72,14 @@ class CommitData(BaseModel): |
72 | 72 | is_breaking: bool |
73 | 73 |
|
74 | 74 |
|
| 75 | +def _supports_reasoning(model: str) -> bool: |
| 76 | + """Return True when the selected model supports reasoning parameters.""" |
| 77 | + if not model: |
| 78 | + return False |
| 79 | + model_lower = model.lower() |
| 80 | + return any(hint in model_lower for hint in REASONING_MODEL_HINTS) |
| 81 | + |
| 82 | + |
75 | 83 | def get_changed_files_from_status(repo: git.Repo) -> set[str]: |
76 | 84 | """获取所有变更的文件,包括重命名/移动的文件""" |
77 | 85 | diff_name_status = repo.git.diff("--cached", "--name-status", "-M") |
@@ -178,18 +186,24 @@ def _generate_commit_with_ai(diff: str, specified_type: str | None, current_bran |
178 | 186 | ) |
179 | 187 |
|
180 | 188 | with console.status("[bold green]Generating commit message...[/bold green]"): |
181 | | - chat_completion = client.responses.parse( |
182 | | - input=[ |
| 189 | + model_name = settings.model or DEFAULT_MODEL |
| 190 | + request_kwargs: dict[str, Any] = { |
| 191 | + "input": [ |
183 | 192 | { |
184 | 193 | "role": "system", |
185 | 194 | "content": commit_prompt_template.render(**template_params.__dict__), |
186 | 195 | }, |
187 | 196 | {"role": "user", "content": diff}, |
188 | 197 | ], |
189 | | - model=settings.model or DEFAULT_MODEL, |
190 | | - reasoning={"effort": "minimal"}, |
191 | | - max_output_tokens=50, |
192 | | - text_format=CommitData, |
| 198 | + "model": model_name, |
| 199 | + "max_output_tokens": 50, |
| 200 | + "text_format": CommitData, |
| 201 | + } |
| 202 | + if _supports_reasoning(model_name): |
| 203 | + request_kwargs["reasoning"] = {"effort": "minimal"} |
| 204 | + |
| 205 | + chat_completion = client.responses.parse( |
| 206 | + **request_kwargs, |
193 | 207 | ) |
194 | 208 |
|
195 | 209 | return chat_completion.output_parsed |
|
0 commit comments