"""
|
|
llm_config.py - Cấu hình LLM linh hoạt
|
|
|
|
Hỗ trợ:
|
|
- Ollama (local)
|
|
- Google Gemini
|
|
- OpenAI
|
|
|
|
Sử dụng:
|
|
from llm_config import ModelConfig, get_llm
|
|
|
|
config = ModelConfig(provider="ollama", model_name="qwen2.5:14b")
|
|
llm = get_llm(config)
|
|
"""
import os
from typing import Optional

from langchain_core.language_models.chat_models import BaseChatModel
from pydantic import BaseModel, ConfigDict, Field


class ModelConfig(BaseModel):
    """Configuration for an LLM."""

    # Allow extra keys when a config is built from a dict, and clear the
    # protected namespaces so the `model_name` field does not trip Pydantic
    # v2's "model_" prefix warning.
    model_config = ConfigDict(extra="allow", protected_namespaces=())

    provider: str = Field(
        default="gemini",
        description="Provider: ollama, gemini, openai"
    )
    model_name: str = Field(
        default="gemini-2.0-flash-lite",
        description="Model name"
    )
    api_key: Optional[str] = Field(
        default=None,
        description="API key (if None, read from the environment)"
    )
    temperature: float = Field(
        default=0.1,
        description="Sampling temperature (0.0 - 1.0)"
    )
    base_url: Optional[str] = Field(
        default=None,
        description="Base URL for Ollama"
    )
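
# Example (sketch): because of extra="allow", ModelConfig also keeps keys
# beyond the declared fields ("timeout" here is a hypothetical extra key):
#
#   cfg = ModelConfig(**{"provider": "openai", "timeout": 30})
#   cfg.timeout  # -> 30; with Pydantic's default extra="ignore" it would be dropped

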
# ============== DEFAULT CONFIGS ==============

DEFAULT_CONFIGS = {
    "ollama": ModelConfig(
        provider="ollama",
        model_name="qwen2.5:14b",
        temperature=0.1,
        base_url=None  # falls back to the OLLAMA_BASE_URL env var
    ),
    "ollama_light": ModelConfig(
        provider="ollama",
        model_name="qwen2.5:7b",
        temperature=0.0,
        base_url=None  # falls back to the OLLAMA_BASE_URL env var
    ),
    "gemini": ModelConfig(
        provider="gemini",
        model_name="gemini-2.0-flash-lite",
        temperature=0.1
    ),
    "gemini_light": ModelConfig(
        provider="gemini",
        model_name="gemini-2.0-flash-lite",
        temperature=0.0
    ),
    "openai": ModelConfig(
        provider="openai",
        model_name="gpt-4o-mini",
        temperature=0.1
    ),
    "openai_light": ModelConfig(
        provider="openai",
        model_name="gpt-4o-mini",
        temperature=0.0
    ),
}


def get_default_config(name: str = "gemini") -> ModelConfig:
    """Return a preset config by name, falling back to "gemini"."""
    return DEFAULT_CONFIGS.get(name, DEFAULT_CONFIGS["gemini"])
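
# Example (sketch): start from a preset and adjust it with Pydantic v2's
# copy-with-update (on Pydantic v1 this would be .copy(update=...)):
#
#   cfg = get_default_config("ollama_light")
#   cfg = cfg.model_copy(update={"temperature": 0.3})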


# ============== LLM FACTORY ==============

def get_llm(config: ModelConfig) -> BaseChatModel:
    """
    Factory function that creates an LLM instance.

    Args:
        config: ModelConfig object

    Returns:
        BaseChatModel instance
    """
    provider = config.provider.lower()

    if provider == "ollama":
        from langchain_ollama import ChatOllama

        base_url = config.base_url or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
        return ChatOllama(
            model=config.model_name,
            temperature=config.temperature,
            base_url=base_url
        )

    elif provider == "gemini":
        from langchain_google_genai import ChatGoogleGenerativeAI

        api_key = config.api_key or os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError("GOOGLE_API_KEY required for Gemini. Set via env or config.api_key")

        return ChatGoogleGenerativeAI(
            model=config.model_name,
            temperature=config.temperature,
            google_api_key=api_key,
            additional_headers={  # extra HTTP headers forwarded to the API client
                "User-Agent": "PostmanRuntime/7.43.0",
                "Accept": "*/*"
            }
        )

    elif provider == "openai":
        from langchain_openai import ChatOpenAI

        api_key = config.api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY required for OpenAI. Set via env or config.api_key")

        return ChatOpenAI(
            model=config.model_name,
            temperature=config.temperature,
            api_key=api_key,
            base_url=config.base_url or None
        )

    else:
        raise ValueError(f"Provider '{provider}' is not supported. Choose one of: ollama, gemini, openai")


def get_completion_model(config: ModelConfig):
    """
    Create a completion (non-chat) model when one is needed.

    Currently only Ollama has a dedicated completion model.
    """
    if config.provider.lower() == "ollama":
        from langchain_ollama.llms import OllamaLLM

        base_url = config.base_url or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
        return OllamaLLM(
            model=config.model_name,
            temperature=config.temperature,
            base_url=base_url
        )

    # Other providers use the chat interface
    return get_llm(config)
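
# Example (sketch): the two interfaces return different types. Assuming a local
# Ollama server is running, OllamaLLM.invoke() returns a plain str, while
# ChatOllama.invoke() returns an AIMessage whose text is in .content:
#
#   llm = get_completion_model(get_default_config("ollama_light"))
#   text = llm.invoke("Complete this: The capital of Vietnam is")  # -> str
#
#   chat = get_llm(get_default_config("ollama_light"))
#   msg = chat.invoke("What is the capital of Vietnam?")           # -> AIMessage
#   print(msg.content)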


# ============== HELPER ==============

def create_config(
    provider: str = "gemini",
    model_name: Optional[str] = None,
    api_key: Optional[str] = None,
    temperature: float = 0.1,
    base_url: Optional[str] = None
) -> ModelConfig:
    """
    Helper function to build a ModelConfig.

    If model_name is not specified, the provider's default model is used.
    """
    default_models = {
        "ollama": "qwen2.5:14b",
        "gemini": "gemini-2.0-flash-lite",
        "openai": "gpt-4o-mini"
    }

    return ModelConfig(
        provider=provider,
        model_name=model_name or default_models.get(provider, "gemini-2.0-flash-lite"),
        api_key=api_key,
        temperature=temperature,
        base_url=base_url
    )
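

if __name__ == "__main__":
    # Minimal smoke test (sketch): assumes a local Ollama server reachable at
    # OLLAMA_BASE_URL (default http://localhost:11434) with the model already
    # pulled, e.g. `ollama pull qwen2.5:7b`. Swap in "gemini" or "openai"
    # (plus the matching API key env var) to exercise the other providers.
    config = create_config(provider="ollama", model_name="qwen2.5:7b", temperature=0.0)
    llm = get_llm(config)
    print(llm.invoke("Reply with the single word: hello").content)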