Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion libs/langchain/langchain_classic/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,31 @@ def _init_chat_model_helper(
_check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline

llm = HuggingFacePipeline.from_model_id(model_id=model, **kwargs)
# ------------------------------------------------------
# Add default task for decoder-only HuggingFace models.
# ------------------------------------------------------
if "task" not in kwargs:
kwargs["task"] = "text-generation"

# ------------------------------------------------------
# Filter out non-HuggingFacePipeline parameters.
# ------------------------------------------------------
pipeline_allowed_params = {"task", "model_kwargs", "device"}
pipeline_kwargs = {
k: v for k, v in kwargs.items() if k in pipeline_allowed_params
}

# ------------------------------------------------------
# Construct the HF pipeline from model_id
# ------------------------------------------------------
llm = HuggingFacePipeline.from_model_id(
model_id=model,
task=pipeline_kwargs.get("task"),
model_kwargs=pipeline_kwargs.get("model_kwargs"),
device=pipeline_kwargs.get("device"),
)

# Wrap into ChatHuggingFace
return ChatHuggingFace(llm=llm)

if model_provider == "groq":
Expand Down
14 changes: 12 additions & 2 deletions libs/partners/huggingface/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions tests/test_init_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import os
import sys

# Make the local ``langchain_classic`` package importable when the test
# is run from the repository root.
sys.path.append(os.path.abspath("libs/langchain"))

# Import through the sys.path entry added above.  The previous form
# (``from libs.langchain.langchain_classic...``) bypassed that entry and
# only worked when ``libs`` happened to resolve as a namespace package
# from the current working directory.
from langchain_classic.chat_models import init_chat_model


def test_hf_model_init() -> None:
    """Smoke-test initializing a HuggingFace chat model via ``init_chat_model``.

    Verifies that the ``huggingface`` provider path constructs a model
    object without raising.

    NOTE(review): this downloads ``microsoft/Phi-3-mini-4k-instruct`` from
    the Hub on first run — consider marking it as an integration test or
    gating it behind a network marker.
    """
    llm = init_chat_model(
        model="microsoft/Phi-3-mini-4k-instruct",
        model_provider="huggingface",
        temperature=0,
        max_tokens=1024,
        timeout=None,
        max_retries=2,
    )
    assert llm is not None