Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions src/smolagents/default_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,13 @@ class DuckDuckGoSearchTool(Tool):
inputs = {"query": {"type": "string", "description": "The search query to perform."}}
output_type = "string"

def __init__(self, max_results: int = 10, rate_limit: float | None = 1.0, **kwargs):
def __init__(
self, max_results: int = 10, rate_limit: float | None = 1.0, site_denylist: list[str] | None = None, **kwargs
):
super().__init__()
self.max_results = max_results
self.rate_limit = rate_limit
self.site_denylist = site_denylist or []
self._min_interval = 1.0 / rate_limit if rate_limit else 0.0
self._last_request_time = 0.0
try:
Expand All @@ -136,6 +139,11 @@ def __init__(self, max_results: int = 10, rate_limit: float | None = 1.0, **kwar

def forward(self, query: str) -> str:
self._enforce_rate_limit()

if self.site_denylist:
exclusion_terms = " ".join([f"-site:{pattern}" for pattern in self.site_denylist])
query = f"{query} {exclusion_terms}"

results = self.ddgs.text(query, max_results=self.max_results)
if len(results) == 0:
raise Exception("No results found! Try a less restrictive/shorter query.")
Expand Down Expand Up @@ -169,7 +177,7 @@ class GoogleSearchTool(Tool):
}
output_type = "string"

def __init__(self, provider: str = "serpapi"):
def __init__(self, provider: str = "serpapi", site_denylist: list[str] | None = None):
super().__init__()
import os

Expand All @@ -183,10 +191,15 @@ def __init__(self, provider: str = "serpapi"):
self.api_key = os.getenv(api_key_env_name)
if self.api_key is None:
raise ValueError(f"Missing API key. Make sure you have '{api_key_env_name}' in your env variables.")
self.site_denylist = site_denylist or []

def forward(self, query: str, filter_year: int | None = None) -> str:
import requests

if self.site_denylist:
exclusion_terms = " ".join([f"-site:{pattern}" for pattern in self.site_denylist])
query = f"{query} {exclusion_terms}"

if self.provider == "serpapi":
params = {
"q": query,
Expand Down Expand Up @@ -279,6 +292,7 @@ def __init__(
api_key_name: str = "",
headers: dict = None,
params: dict = None,
site_denylist: list[str] | None = None,
rate_limit: float | None = 1.0,
):
import os
Expand All @@ -290,6 +304,7 @@ def __init__(
self.headers = headers or {"X-Subscription-Token": self.api_key}
self.params = params or {"count": 10}
self.rate_limit = rate_limit
self.site_denylist = site_denylist or []
self._min_interval = 1.0 / rate_limit if rate_limit else 0.0
self._last_request_time = 0.0

Expand All @@ -310,6 +325,11 @@ def forward(self, query: str) -> str:
import requests

self._enforce_rate_limit()

if self.site_denylist:
exclusion_terms = " ".join([f"-site:{pattern}" for pattern in self.site_denylist])
query = f"{query} {exclusion_terms}"

params = {**self.params, "q": query}
response = requests.get(self.endpoint, headers=self.headers, params=params)
response.raise_for_status()
Expand Down Expand Up @@ -342,12 +362,17 @@ class WebSearchTool(Tool):
inputs = {"query": {"type": "string", "description": "The search query to perform."}}
output_type = "string"

def __init__(self, max_results: int = 10, engine: str = "duckduckgo"):
def __init__(self, max_results: int = 10, engine: str = "duckduckgo", site_denylist: list[str] | None = None):
super().__init__()
self.max_results = max_results
self.engine = engine
self.site_denylist = site_denylist or []

def forward(self, query: str) -> str:
if self.site_denylist:
exclusion_terms = " ".join([f"-site:{pattern}" for pattern in self.site_denylist])
query = f"{query} {exclusion_terms}"

results = self.search(query)
if len(results) == 0:
raise Exception("No results found! Try a less restrictive/shorter query.")
Expand Down
95 changes: 95 additions & 0 deletions tests/test_default_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import MagicMock, patch

import pytest

from smolagents.agent_types import _AGENT_TYPE_MAPPING
from smolagents.default_tools import (
ApiWebSearchTool,
DuckDuckGoSearchTool,
GoogleSearchTool,
PythonInterpreterTool,
SpeechToTextTool,
VisitWebpageTool,
WebSearchTool,
WikipediaSearchTool,
)

Expand All @@ -41,6 +45,97 @@ def test_ddgs_with_kwargs(self):
result = DuckDuckGoSearchTool(timeout=20)("DeepSeek parent company")
assert isinstance(result, str)

@patch("ddgs.DDGS")
def test_ddgs_with_denylist(self, MockDDGS):
mock_ddgs_instance = MockDDGS.return_value
mock_ddgs_instance.text.return_value = [{"title": "Test", "href": "http://test.com", "body": "Test body"}]

tool = DuckDuckGoSearchTool(site_denylist=["example.com", "*.badsite.org"])
base_query = "test query"
expected_query = "test query -site:example.com -site:*.badsite.org"

tool.forward(base_query)
mock_ddgs_instance.text.assert_called_once_with(expected_query, max_results=10)

@patch("requests.get")
def test_google_search_with_denylist(self, mock_get):
serpapi_response = MagicMock()
serpapi_response.status_code = 200
serpapi_response.json.return_value = {
"organic_results": [{"title": "Test", "link": "http://test.com", "snippet": "Test snippet"}]
}
mock_get.return_value = serpapi_response

with patch("os.getenv", return_value="fake_api_key"):
tool_serpapi = GoogleSearchTool(provider="serpapi", site_denylist=["google.com"])

base_query_1 = "search for something"
expected_query_1 = "search for something -site:google.com"
tool_serpapi.forward(base_query_1)

mock_get.assert_called_once()
_, called_kwargs_1 = mock_get.call_args
self.assertEqual(called_kwargs_1["params"]["q"], expected_query_1)

mock_get.reset_mock()

serper_response = MagicMock()
serper_response.status_code = 200
serper_response.json.return_value = {
"organic": [{"title": "Test Serper", "link": "http://test.com", "snippet": "Test snippet"}]
}
mock_get.return_value = serper_response

with patch("os.getenv", return_value="fake_api_key"):
tool_serper = GoogleSearchTool(provider="serper", site_denylist=["serper.dev"])

base_query_2 = "search serper"
expected_query_2 = "search serper -site:serper.dev"

tool_serper.forward(base_query_2)
mock_get.assert_called_once()
_, called_kwargs_2 = mock_get.call_args
self.assertEqual(called_kwargs_2["params"]["q"], expected_query_2)

@patch("requests.get")
def test_api_web_search_with_denylist(self, mock_get):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"web": {
"results": [
{
"title": "Test",
"url": "http://test.com",
"description": "Test snippet",
}
]
}
}
mock_get.return_value = mock_response

with patch("os.getenv", return_value="fake_api_key"):
tool = ApiWebSearchTool(site_denylist=["brave.com"])

base_query = "search brave"
expected_query = "search brave -site:brave.com"

tool.forward(base_query)

mock_get.assert_called_once()
_, called_kwargs = mock_get.call_args
self.assertEqual(called_kwargs["params"]["q"], expected_query)

@patch("smolagents.default_tools.WebSearchTool.search")
def test_web_search_with_denylist(self, mock_search):
mock_search.return_value = [{"title": "Test", "link": "http://test.com", "description": "Test snippet"}]
tool = WebSearchTool(site_denylist=["ddg.com", "bing.com"])
base_query = "search engines"
expected_query = "search engines -site:ddg.com -site:bing.com"

tool.forward(base_query)
mock_search.assert_called_once_with(expected_query)


class TestPythonInterpreterTool(ToolTesterMixin):
def setup_method(self):
Expand Down