Files
gen_game-a0/src/llm_config.py
vuongps38770 7c41ddaa82 init
2026-01-13 09:33:10 +07:00

199 lines
5.4 KiB
Python

"""
llm_config.py - Cấu hình LLM linh hoạt
Hỗ trợ:
- Ollama (local)
- Google Gemini
- OpenAI
Sử dụng:
from llm_config import ModelConfig, get_llm
config = ModelConfig(provider="ollama", model_name="qwen2.5:14b")
llm = get_llm(config)
"""
import os
from typing import Optional
from pydantic import BaseModel, Field
from langchain_core.language_models.chat_models import BaseChatModel
class ModelConfig(BaseModel):
    """Configuration for one LLM instance (provider, model, credentials, sampling)."""
    # Which backend to use; matched case-insensitively in get_llm().
    provider: str = Field(
        default="gemini",
        description="Provider: ollama, gemini, openai"
    )
    # Provider-specific model identifier.
    model_name: str = Field(
        default="gemini-2.0-flash-lite",
        description="Tên model"
    )
    # Credential for gemini/openai; when None, get_llm() falls back to the
    # GOOGLE_API_KEY / OPENAI_API_KEY environment variables.
    api_key: Optional[str] = Field(
        default=None,
        description="API key (nếu None, lấy từ env)"
    )
    # Sampling temperature (creativity), expected range 0.0 - 1.0.
    temperature: float = Field(
        default=0.1,
        description="Độ sáng tạo (0.0 - 1.0)"
    )
    # Endpoint override; used by Ollama (and passed through to OpenAI in get_llm()).
    base_url: Optional[str] = Field(
        default=None,
        description="Base URL cho Ollama"
    )
    class Config:
        # Allow constructing from dicts that carry extra keys
        extra = "allow"
# ============== DEFAULT CONFIGS ==============
# Ready-made configs keyed by short name; "*_light" variants use a smaller/
# deterministic (temperature=0.0) setup for cheaper calls.
DEFAULT_CONFIGS = {
    "ollama": ModelConfig(
        provider="ollama",
        model_name="qwen2.5:14b",
        temperature=0.1,
        base_url=None  # Will be taken from the OLLAMA_BASE_URL env var
    ),
    "ollama_light": ModelConfig(
        provider="ollama",
        model_name="qwen2.5:7b",
        temperature=0.0,
        base_url=None  # Will be taken from the OLLAMA_BASE_URL env var
    ),
    "gemini": ModelConfig(
        provider="gemini",
        model_name="gemini-2.0-flash-lite",
        temperature=0.1
    ),
    "gemini_light": ModelConfig(
        provider="gemini",
        model_name="gemini-2.0-flash-lite",
        temperature=0.0
    ),
    "openai": ModelConfig(
        provider="openai",
        model_name="gpt-4o-mini",
        temperature=0.1,
    ),
    "openai_light": ModelConfig(
        provider="openai",
        model_name="gpt-4o-mini",
        temperature=0.0
    ),
}
def get_default_config(name: str = "gemini") -> ModelConfig:
    """Return the preset config registered under *name*, or the "gemini" preset if unknown."""
    try:
        return DEFAULT_CONFIGS[name]
    except KeyError:
        return DEFAULT_CONFIGS["gemini"]
# ============== LLM FACTORY ==============
def get_llm(config: ModelConfig) -> BaseChatModel:
    """
    Factory that builds a chat-model instance for the configured provider.

    Args:
        config: ModelConfig describing provider, model name, credentials, etc.
    Returns:
        BaseChatModel instance for the selected provider.
    Raises:
        ValueError: if the provider is unsupported, or a required API key is
            missing (not in config.api_key nor in the environment).
    """
    provider = config.provider.lower()
    if provider == "ollama":
        from langchain_ollama import ChatOllama
        # Endpoint precedence: explicit config > OLLAMA_BASE_URL env > local default.
        base_url = config.base_url or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
        return ChatOllama(
            model=config.model_name,
            temperature=config.temperature,
            base_url=base_url
        )
    elif provider == "gemini":
        from langchain_google_genai import ChatGoogleGenerativeAI
        api_key = config.api_key or os.getenv("GOOGLE_API_KEY")
        # SECURITY FIX: previously the raw API key was print()-ed here,
        # leaking the secret into logs/stdout. Never echo credentials.
        if not api_key:
            raise ValueError("GOOGLE_API_KEY required for Gemini. Set via env or config.api_key")
        return ChatGoogleGenerativeAI(
            model=config.model_name,
            temperature=config.temperature,
            google_api_key=api_key,
            version="v1",
            # HACK: spoofed Postman User-Agent; presumably works around a
            # proxy/gateway rejecting the default client UA -- TODO confirm
            # whether this is still required before removing.
            additional_headers={
                "User-Agent": "PostmanRuntime/7.43.0",
                "Accept": "*/*"
            }
        )
    elif provider == "openai":
        from langchain_openai import ChatOpenAI
        api_key = config.api_key or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY required for OpenAI. Set via env or config.api_key")
        return ChatOpenAI(
            model=config.model_name,
            temperature=config.temperature,
            api_key=api_key,
            # base_url lets the OpenAI client target a compatible proxy/endpoint.
            base_url=config.base_url or None
        )
    else:
        raise ValueError(f"Provider '{provider}' không được hỗ trợ. Chọn: ollama, gemini, openai")
def get_completion_model(config: ModelConfig):
    """
    Build a completion-style (non-chat) model where the provider offers one.

    Only Ollama currently has a dedicated completion model; every other
    provider falls back to its chat interface via get_llm().
    """
    if config.provider.lower() != "ollama":
        # Other providers use the Chat interface
        return get_llm(config)
    from langchain_ollama.llms import OllamaLLM
    endpoint = config.base_url or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
    return OllamaLLM(
        model=config.model_name,
        temperature=config.temperature,
        base_url=endpoint
    )
# ============== HELPER ==============
def create_config(
    provider: str = "gemini",
    model_name: Optional[str] = None,
    api_key: Optional[str] = None,
    temperature: float = 0.1,
    base_url: Optional[str] = None
) -> ModelConfig:
    """
    Convenience builder for ModelConfig.

    When model_name is not given, the provider's default model is chosen;
    unknown providers fall back to "gemini-2.0-flash-lite".
    """
    fallback_model = "gemini-2.0-flash-lite"
    per_provider_defaults = {
        "ollama": "qwen2.5:14b",
        "gemini": "gemini-2.0-flash-lite",
        "openai": "gpt-4o-mini"
    }
    chosen_model = model_name if model_name else per_provider_defaults.get(provider, fallback_model)
    return ModelConfig(
        provider=provider,
        model_name=chosen_model,
        api_key=api_key,
        temperature=temperature,
        base_url=base_url
    )