feat: add multi-provider LLM support with thinking configurations

Models added:
- OpenAI: GPT-5.2, GPT-5.1, GPT-5, GPT-5 Mini, GPT-5 Nano, GPT-4.1
- Anthropic: Claude Opus 4.5/4.1, Claude Sonnet 4.5/4, Claude Haiku 4.5
- Google: Gemini 3 Pro/Flash, Gemini 2.5 Flash/Flash Lite
- xAI: Grok 4, Grok 4.1 Fast (Reasoning/Non-Reasoning)

Configs updated:
- Add unified thinking_level for Gemini (maps to thinking_level for Gemini 3,
  thinking_budget for Gemini 2.5; handles Pro's lack of "minimal" support)
- Add OpenAI reasoning_effort configuration
- Add NormalizedChatGoogleGenerativeAI for consistent response handling

Fixes:
- Fix Bull/Bear researcher display truncation
- Replace ChromaDB with BM25 for memory retrieval
This commit is contained in:
Yijia Xiao
2026-01-26 16:48:28 +00:00
parent 79051580b8
commit d4dadb82fc
17 changed files with 639 additions and 958 deletions
@@ -76,7 +76,7 @@ Volume-Based Indicators:
if len(result.tool_calls) == 0:
report = result.content
return {
"messages": [result],
"market_report": report,
+4 -4
View File
@@ -24,15 +24,15 @@ def create_msg_delete():
def delete_messages(state):
    """Wipe the message history and leave a single placeholder message.

    Builds one ``RemoveMessage`` per entry in ``state["messages"]`` and
    appends a ``HumanMessage("Continue")`` placeholder after them (the
    placeholder exists for Anthropic compatibility).

    Returns:
        Dict with a ``"messages"`` list: the removal operations followed by
        the placeholder.
    """
    # One removal operation per message currently in the state.
    updates = [RemoveMessage(id=msg.id) for msg in state["messages"]]
    # Minimal placeholder message goes last.
    updates.append(HumanMessage(content="Continue"))
    return {"messages": updates}
return delete_messages
+86 -55
View File
@@ -1,75 +1,106 @@
import chromadb
from chromadb.config import Settings
from openai import OpenAI
"""Financial situation memory using BM25 for lexical similarity matching.
Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls,
no token limits, works offline with any LLM provider.
"""
from rank_bm25 import BM25Okapi
from typing import List, Tuple
import re
class FinancialSituationMemory:
def __init__(self, name, config):
if config["backend_url"] == "http://localhost:11434/v1":
self.embedding = "nomic-embed-text"
"""Memory system for storing and retrieving financial situations using BM25."""
def __init__(self, name: str, config: dict = None):
"""Initialize the memory system.
Args:
name: Name identifier for this memory instance
config: Configuration dict (kept for API compatibility, not used for BM25)
"""
self.name = name
self.documents: List[str] = []
self.recommendations: List[str] = []
self.bm25 = None
def _tokenize(self, text: str) -> List[str]:
"""Tokenize text for BM25 indexing.
Simple whitespace + punctuation tokenization with lowercasing.
"""
# Lowercase and split on non-alphanumeric characters
tokens = re.findall(r'\b\w+\b', text.lower())
return tokens
def _rebuild_index(self):
"""Rebuild the BM25 index after adding documents."""
if self.documents:
tokenized_docs = [self._tokenize(doc) for doc in self.documents]
self.bm25 = BM25Okapi(tokenized_docs)
else:
self.embedding = "text-embedding-3-small"
self.client = OpenAI(base_url=config["backend_url"])
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
self.situation_collection = self.chroma_client.create_collection(name=name)
self.bm25 = None
def get_embedding(self, text):
"""Get OpenAI embedding for a text"""
response = self.client.embeddings.create(
model=self.embedding, input=text
)
return response.data[0].embedding
def add_situations(self, situations_and_advice: List[Tuple[str, str]]):
"""Add financial situations and their corresponding advice.
def add_situations(self, situations_and_advice):
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
Args:
situations_and_advice: List of tuples (situation, recommendation)
"""
for situation, recommendation in situations_and_advice:
self.documents.append(situation)
self.recommendations.append(recommendation)
situations = []
advice = []
ids = []
embeddings = []
# Rebuild BM25 index with new documents
self._rebuild_index()
offset = self.situation_collection.count()
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
"""Find matching recommendations using BM25 similarity.
for i, (situation, recommendation) in enumerate(situations_and_advice):
situations.append(situation)
advice.append(recommendation)
ids.append(str(offset + i))
embeddings.append(self.get_embedding(situation))
Args:
current_situation: The current financial situation to match against
n_matches: Number of top matches to return
self.situation_collection.add(
documents=situations,
metadatas=[{"recommendation": rec} for rec in advice],
embeddings=embeddings,
ids=ids,
)
Returns:
List of dicts with matched_situation, recommendation, and similarity_score
"""
if not self.documents or self.bm25 is None:
return []
def get_memories(self, current_situation, n_matches=1):
"""Find matching recommendations using OpenAI embeddings"""
query_embedding = self.get_embedding(current_situation)
# Tokenize query
query_tokens = self._tokenize(current_situation)
results = self.situation_collection.query(
query_embeddings=[query_embedding],
n_results=n_matches,
include=["metadatas", "documents", "distances"],
)
# Get BM25 scores for all documents
scores = self.bm25.get_scores(query_tokens)
matched_results = []
for i in range(len(results["documents"][0])):
matched_results.append(
{
"matched_situation": results["documents"][0][i],
"recommendation": results["metadatas"][0][i]["recommendation"],
"similarity_score": 1 - results["distances"][0][i],
}
)
# Get top-n indices sorted by score (descending)
top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n_matches]
return matched_results
# Build results
results = []
max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores
for idx in top_indices:
# Normalize score to 0-1 range for consistency
normalized_score = scores[idx] / max_score if max_score > 0 else 0
results.append({
"matched_situation": self.documents[idx],
"recommendation": self.recommendations[idx],
"similarity_score": normalized_score,
})
return results
def clear(self):
    """Forget every stored situation and recommendation, and drop the index."""
    # Fresh, independent lists for both parallel stores.
    self.documents, self.recommendations = [], []
    # No documents left, so there is no BM25 index either.
    self.bm25 = None
if __name__ == "__main__":
# Example usage
matcher = FinancialSituationMemory()
matcher = FinancialSituationMemory("test_memory")
# Example data
example_data = [
@@ -96,7 +127,7 @@ if __name__ == "__main__":
# Example query
current_situation = """
Market showing increased volatility in tech sector, with institutional investors
Market showing increased volatility in tech sector, with institutional investors
reducing positions and rising interest rates affecting growth stock valuations
"""
+2 -5
View File
@@ -3,24 +3,21 @@ from typing import Dict, Optional
# Use default config but allow it to be overridden
_config: Optional[Dict] = None
DATA_DIR: Optional[str] = None
def initialize_config():
"""Initialize the configuration with default values."""
global _config, DATA_DIR
global _config
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
DATA_DIR = _config["data_dir"]
def set_config(config: Dict):
"""Update the configuration with custom values."""
global _config, DATA_DIR
global _config
if _config is None:
_config = default_config.DEFAULT_CONFIG.copy()
_config.update(config)
DATA_DIR = _config["data_dir"]
def get_config() -> Dict:
+14 -25
View File
@@ -1,10 +1,15 @@
from typing import Annotated
# Import from vendor-specific modules
from .local import get_YFin_data, get_finnhub_news, get_finnhub_company_insider_sentiment, get_finnhub_company_insider_transactions, get_simfin_balance_sheet, get_simfin_cashflow, get_simfin_income_statements, get_reddit_global_news, get_reddit_company_news
from .y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
from .google import get_google_news
from .openai import get_stock_news_openai, get_global_news_openai, get_fundamentals_openai
from .y_finance import (
get_YFin_data_online,
get_stock_stats_indicators_window,
get_balance_sheet as get_yfinance_balance_sheet,
get_cashflow as get_yfinance_cashflow,
get_income_statement as get_yfinance_income_statement,
get_insider_transactions as get_yfinance_insider_transactions,
)
from .yfinance_news import get_news_yfinance, get_global_news_yfinance
from .alpha_vantage import (
get_stock as get_alpha_vantage_stock,
get_indicator as get_alpha_vantage_indicator,
@@ -13,7 +18,7 @@ from .alpha_vantage import (
get_cashflow as get_alpha_vantage_cashflow,
get_income_statement as get_alpha_vantage_income_statement,
get_insider_transactions as get_alpha_vantage_insider_transactions,
get_news as get_alpha_vantage_news
get_news as get_alpha_vantage_news,
)
from .alpha_vantage_common import AlphaVantageRateLimitError
@@ -44,21 +49,18 @@ TOOLS_CATEGORIES = {
]
},
"news_data": {
"description": "News (public/insiders, original/processed)",
"description": "News and insider data",
"tools": [
"get_news",
"get_global_news",
"get_insider_sentiment",
"get_insider_transactions",
]
}
}
VENDOR_LIST = [
"local",
"yfinance",
"openai",
"google"
"alpha_vantage",
]
# Mapping of methods to their vendor-specific implementations
@@ -67,52 +69,39 @@ VENDOR_METHODS = {
"get_stock_data": {
"alpha_vantage": get_alpha_vantage_stock,
"yfinance": get_YFin_data_online,
"local": get_YFin_data,
},
# technical_indicators
"get_indicators": {
"alpha_vantage": get_alpha_vantage_indicator,
"yfinance": get_stock_stats_indicators_window,
"local": get_stock_stats_indicators_window
},
# fundamental_data
"get_fundamentals": {
"alpha_vantage": get_alpha_vantage_fundamentals,
"openai": get_fundamentals_openai,
},
"get_balance_sheet": {
"alpha_vantage": get_alpha_vantage_balance_sheet,
"yfinance": get_yfinance_balance_sheet,
"local": get_simfin_balance_sheet,
},
"get_cashflow": {
"alpha_vantage": get_alpha_vantage_cashflow,
"yfinance": get_yfinance_cashflow,
"local": get_simfin_cashflow,
},
"get_income_statement": {
"alpha_vantage": get_alpha_vantage_income_statement,
"yfinance": get_yfinance_income_statement,
"local": get_simfin_income_statements,
},
# news_data
"get_news": {
"alpha_vantage": get_alpha_vantage_news,
"openai": get_stock_news_openai,
"google": get_google_news,
"local": [get_finnhub_news, get_reddit_company_news, get_google_news],
"yfinance": get_news_yfinance,
},
"get_global_news": {
"openai": get_global_news_openai,
"local": get_reddit_global_news
},
"get_insider_sentiment": {
"local": get_finnhub_company_insider_sentiment
"yfinance": get_global_news_yfinance,
},
"get_insider_transactions": {
"alpha_vantage": get_alpha_vantage_insider_transactions,
"yfinance": get_yfinance_insider_transactions,
"local": get_finnhub_company_insider_transactions,
},
}
+32 -50
View File
@@ -3,7 +3,7 @@ import yfinance as yf
from stockstats import wrap
from typing import Annotated
import os
from .config import get_config, DATA_DIR
from .config import get_config
class StockstatsUtils:
@@ -17,63 +17,45 @@ class StockstatsUtils:
str, "curr date for retrieving stock price data, YYYY-mm-dd"
],
):
# Get config and set up data directory path
config = get_config()
online = config["data_vendors"]["technical_indicators"] != "local"
df = None
data = None
today_date = pd.Timestamp.today()
curr_date_dt = pd.to_datetime(curr_date)
if not online:
try:
data = pd.read_csv(
os.path.join(
DATA_DIR,
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
)
)
df = wrap(data)
except FileNotFoundError:
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date_str = start_date.strftime("%Y-%m-%d")
end_date_str = end_date.strftime("%Y-%m-%d")
# Ensure cache directory exists
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
)
if os.path.exists(data_file):
data = pd.read_csv(data_file)
data["Date"] = pd.to_datetime(data["Date"])
else:
# Get today's date as YYYY-mm-dd to add to cache
today_date = pd.Timestamp.today()
curr_date = pd.to_datetime(curr_date)
end_date = today_date
start_date = today_date - pd.DateOffset(years=15)
start_date = start_date.strftime("%Y-%m-%d")
end_date = end_date.strftime("%Y-%m-%d")
# Get config and ensure cache directory exists
os.makedirs(config["data_cache_dir"], exist_ok=True)
data_file = os.path.join(
config["data_cache_dir"],
f"{symbol}-YFin-data-{start_date}-{end_date}.csv",
data = yf.download(
symbol,
start=start_date_str,
end=end_date_str,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
data = data.reset_index()
data.to_csv(data_file, index=False)
if os.path.exists(data_file):
data = pd.read_csv(data_file)
data["Date"] = pd.to_datetime(data["Date"])
else:
data = yf.download(
symbol,
start=start_date,
end=end_date,
multi_level_index=False,
progress=False,
auto_adjust=True,
)
data = data.reset_index()
data.to_csv(data_file, index=False)
df = wrap(data)
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
curr_date = curr_date.strftime("%Y-%m-%d")
df = wrap(data)
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
curr_date_str = curr_date_dt.strftime("%Y-%m-%d")
df[indicator] # trigger stockstats to calculate the indicator
matching_rows = df[df["Date"].str.startswith(curr_date)]
matching_rows = df[df["Date"].str.startswith(curr_date_str)]
if not matching_rows.empty:
indicator_value = matching_rows[indicator].values[0]
+190
View File
@@ -0,0 +1,190 @@
"""yfinance-based news data fetching functions."""
import yfinance as yf
from datetime import datetime
from dateutil.relativedelta import relativedelta
def _extract_article_data(article: dict) -> dict:
"""Extract article data from yfinance news format (handles nested 'content' structure)."""
# Handle nested content structure
if "content" in article:
content = article["content"]
title = content.get("title", "No title")
summary = content.get("summary", "")
provider = content.get("provider", {})
publisher = provider.get("displayName", "Unknown")
# Get URL from canonicalUrl or clickThroughUrl
url_obj = content.get("canonicalUrl") or content.get("clickThroughUrl") or {}
link = url_obj.get("url", "")
# Get publish date
pub_date_str = content.get("pubDate", "")
pub_date = None
if pub_date_str:
try:
pub_date = datetime.fromisoformat(pub_date_str.replace("Z", "+00:00"))
except (ValueError, AttributeError):
pass
return {
"title": title,
"summary": summary,
"publisher": publisher,
"link": link,
"pub_date": pub_date,
}
else:
# Fallback for flat structure
return {
"title": article.get("title", "No title"),
"summary": article.get("summary", ""),
"publisher": article.get("publisher", "Unknown"),
"link": article.get("link", ""),
"pub_date": None,
}
def get_news_yfinance(
    ticker: str,
    start_date: str,
    end_date: str,
) -> str:
    """
    Fetch recent news for one ticker via yfinance and render it as markdown.

    Up to 20 items are requested. Items that carry a publish date are kept
    only when that date lies between ``start_date`` and one day past
    ``end_date`` (both bounds inclusive); items without a publish date are
    always kept.

    Args:
        ticker: Stock ticker symbol (e.g., "AAPL")
        start_date: Start date in yyyy-mm-dd format
        end_date: End date in yyyy-mm-dd format

    Returns:
        Formatted string containing the news articles, a "No news found"
        message, or an "Error fetching news" message if anything raised.
    """
    try:
        articles = yf.Ticker(ticker).get_news(count=20, tab="news")

        if not articles:
            return f"No news found for {ticker}"

        # Window bounds are loop-invariant, so compute them once.
        window_start = datetime.strptime(start_date, "%Y-%m-%d")
        window_end = datetime.strptime(end_date, "%Y-%m-%d") + relativedelta(days=1)

        sections = []
        for raw in articles:
            item = _extract_article_data(raw)

            published = item["pub_date"]
            if published:
                # Compare as naive datetimes: the window bounds carry no tz.
                published = published.replace(tzinfo=None)
                if published < window_start or published > window_end:
                    continue

            pieces = [f"### {item['title']} (source: {item['publisher']})\n"]
            if item["summary"]:
                pieces.append(f"{item['summary']}\n")
            if item["link"]:
                pieces.append(f"Link: {item['link']}\n")
            pieces.append("\n")
            sections.append("".join(pieces))

        if not sections:
            return f"No news found for {ticker} between {start_date} and {end_date}"

        body = "".join(sections)
        return f"## {ticker} News, from {start_date} to {end_date}:\n\n{body}"

    except Exception as e:
        return f"Error fetching news for {ticker}: {str(e)}"
def get_global_news_yfinance(
    curr_date: str,
    look_back_days: int = 7,
    limit: int = 10,
) -> str:
    """
    Collect macro/global market news through yfinance Search.

    A fixed list of macro-themed queries is run in order; articles are
    de-duplicated by title and collection stops once ``limit`` articles
    have been gathered. ``curr_date`` and ``look_back_days`` only shape the
    date range printed in the header -- articles are not filtered by date.

    Args:
        curr_date: Current date in yyyy-mm-dd format
        look_back_days: Number of days to look back
        limit: Maximum number of articles to return

    Returns:
        Formatted string containing the global news articles, a "No global
        news found" message, or an "Error fetching global news" message if
        anything raised.
    """
    # Search queries for macro/global news
    search_queries = [
        "stock market economy",
        "Federal Reserve interest rates",
        "inflation economic outlook",
        "global markets trading",
    ]

    collected = []
    known_titles = set()

    try:
        for query in search_queries:
            search = yf.Search(
                query=query,
                news_count=limit,
                enable_fuzzy_query=True,
            )

            if search.news:
                for article in search.news:
                    # Nested payloads go through the shared extractor;
                    # flat ones expose the title directly.
                    if "content" in article:
                        headline = _extract_article_data(article)["title"]
                    else:
                        headline = article.get("title", "")

                    # Keep only the first article seen for each title.
                    if not headline or headline in known_titles:
                        continue
                    known_titles.add(headline)
                    collected.append(article)

            if len(collected) >= limit:
                break

        if not collected:
            return f"No global news found for {curr_date}"

        # Start of the look-back window, used for the header only.
        window_start = datetime.strptime(curr_date, "%Y-%m-%d") - relativedelta(
            days=look_back_days
        )
        start_date = window_start.strftime("%Y-%m-%d")

        sections = []
        for article in collected[:limit]:
            if "content" in article:
                data = _extract_article_data(article)
                title, publisher = data["title"], data["publisher"]
                link, summary = data["link"], data["summary"]
            else:
                title = article.get("title", "No title")
                publisher = article.get("publisher", "Unknown")
                link = article.get("link", "")
                summary = ""

            pieces = [f"### {title} (source: {publisher})\n"]
            if summary:
                pieces.append(f"{summary}\n")
            if link:
                pieces.append(f"Link: {link}\n")
            pieces.append("\n")
            sections.append("".join(pieces))

        body = "".join(sections)
        return f"## Global Market News, from {start_date} to {curr_date}:\n\n{body}"

    except Exception as e:
        return f"Error fetching global news: {str(e)}"
+9 -8
View File
@@ -3,16 +3,18 @@ import os
DEFAULT_CONFIG = {
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
"data_cache_dir": os.path.join(
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
"dataflows/data_cache",
),
# LLM settings
"llm_provider": "openai",
"deep_think_llm": "o4-mini",
"quick_think_llm": "gpt-4o-mini",
"deep_think_llm": "gpt-5.2",
"quick_think_llm": "gpt-5-mini",
"backend_url": "https://api.openai.com/v1",
# Provider-specific thinking configuration
"google_thinking_level": None, # "high", "minimal", etc.
"openai_reasoning_effort": None, # "medium", "high", "low"
# Debate and discussion settings
"max_debate_rounds": 1,
"max_risk_discuss_rounds": 1,
@@ -20,14 +22,13 @@ DEFAULT_CONFIG = {
# Data vendor configuration
# Category-level configuration (default for all tools in category)
"data_vendors": {
"core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local
"technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local
"fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local
"news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
"core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage
"technical_indicators": "yfinance", # Options: yfinance, alpha_vantage
"fundamental_data": "alpha_vantage", # Options: alpha_vantage, yfinance
"news_data": "yfinance", # Options: yfinance, alpha_vantage
},
# Tool-level configuration (takes precedence over category-level)
"tool_vendors": {
# Example: "get_stock_data": "alpha_vantage", # Override category default
# Example: "get_news": "openai", # Override category default
},
}
+22 -1
View File
@@ -69,16 +69,20 @@ class TradingAgentsGraph:
exist_ok=True,
)
# Initialize LLMs
# Initialize LLMs with provider-specific thinking configuration
llm_kwargs = self._get_provider_kwargs()
deep_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["deep_think_llm"],
base_url=self.config.get("backend_url"),
**llm_kwargs,
)
quick_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["quick_think_llm"],
base_url=self.config.get("backend_url"),
**llm_kwargs,
)
self.deep_thinking_llm = deep_client.get_llm()
self.quick_thinking_llm = quick_client.get_llm()
@@ -119,6 +123,23 @@ class TradingAgentsGraph:
# Set up the graph
self.graph = self.graph_setup.setup_graph(selected_analysts)
def _get_provider_kwargs(self) -> Dict[str, Any]:
"""Get provider-specific kwargs for LLM client creation."""
kwargs = {}
provider = self.config.get("llm_provider", "").lower()
if provider == "google":
thinking_level = self.config.get("google_thinking_level")
if thinking_level:
kwargs["thinking_level"] = thinking_level
elif provider == "openai":
reasoning_effort = self.config.get("openai_reasoning_effort")
if reasoning_effort:
kwargs["reasoning_effort"] = reasoning_effort
return kwargs
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
"""Create tool nodes for different data sources using abstract methods."""
return {
@@ -14,18 +14,12 @@ class AnthropicClient(BaseLLMClient):
def get_llm(self) -> Any:
"""Return configured ChatAnthropic instance."""
llm_kwargs = {
"model": self.model,
"max_tokens": self.kwargs.get("max_tokens", 4096),
}
llm_kwargs = {"model": self.model}
for key in ("timeout", "max_retries", "api_key"):
for key in ("timeout", "max_retries", "api_key", "max_tokens"):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
if "thinking_config" in self.kwargs:
llm_kwargs["thinking"] = self.kwargs["thinking_config"]
return ChatAnthropic(**llm_kwargs)
def validate_model(self) -> bool:
+38 -7
View File
@@ -6,6 +6,28 @@ from .base_client import BaseLLMClient
from .validators import validate_model
class NormalizedChatGoogleGenerativeAI(ChatGoogleGenerativeAI):
    """ChatGoogleGenerativeAI whose ``invoke`` always yields string content.

    Gemini 3 models return content as list: [{'type': 'text', 'text': '...'}]
    This collapses such lists into one string so downstream code can treat
    every response the same way.
    """

    def _normalize_content(self, response):
        """Flatten list-shaped ``response.content`` into a string, in place.

        Text is taken from dict items whose ``type`` is ``"text"`` and from
        plain string items; anything else is ignored. Non-empty pieces are
        joined with newlines. Content that is not a list is left untouched.

        Returns:
            The same ``response`` object.
        """
        parts = response.content
        if not isinstance(parts, list):
            return response

        texts = []
        for part in parts:
            if isinstance(part, dict) and part.get("type") == "text":
                text = part.get("text", "")
            elif isinstance(part, str):
                text = part
            else:
                text = ""
            # Skip empty pieces so they do not add blank lines.
            if text:
                texts.append(text)

        response.content = "\n".join(texts)
        return response

    def invoke(self, input, config=None, **kwargs):
        """Call the parent ``invoke`` and normalize the response content."""
        raw_response = super().invoke(input, config, **kwargs)
        return self._normalize_content(raw_response)
class GoogleClient(BaseLLMClient):
"""Client for Google Gemini models."""
@@ -20,14 +42,23 @@ class GoogleClient(BaseLLMClient):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
if "thinking_budget" in self.kwargs and self._is_preview_model():
llm_kwargs["thinking_budget"] = self.kwargs["thinking_budget"]
# Map thinking_level to appropriate API param based on model
# Gemini 3 Pro: low, high
# Gemini 3 Flash: minimal, low, medium, high
# Gemini 2.5: thinking_budget (0=disable, -1=dynamic)
thinking_level = self.kwargs.get("thinking_level")
if thinking_level:
model_lower = self.model.lower()
if "gemini-3" in model_lower:
# Gemini 3 Pro doesn't support "minimal", use "low" instead
if "pro" in model_lower and thinking_level == "minimal":
thinking_level = "low"
llm_kwargs["thinking_level"] = thinking_level
else:
# Gemini 2.5: map to thinking_budget
llm_kwargs["thinking_budget"] = -1 if thinking_level == "high" else 0
return ChatGoogleGenerativeAI(**llm_kwargs)
def _is_preview_model(self) -> bool:
"""Check if this is a preview model that supports thinking budget."""
return "preview" in self.model.lower()
return NormalizedChatGoogleGenerativeAI(**llm_kwargs)
def validate_model(self) -> bool:
"""Validate model for Google."""
+54 -41
View File
@@ -1,58 +1,75 @@
from typing import Dict, List
"""Model name validators for each provider.
VALID_MODELS: Dict[str, List[str]] = {
Only validates model names - does NOT enforce limits.
Let LLM providers use their own defaults for unspecified params.
"""
VALID_MODELS = {
"openai": [
# GPT-5 series (2025)
"gpt-5.2",
"gpt-5.1",
"gpt-5",
"gpt-5-mini",
"gpt-5-nano",
# GPT-4.1 series (2025)
"gpt-4.1",
"gpt-4.1-mini",
"gpt-4.1-nano",
# o-series reasoning models
"o4-mini",
"o3",
"o3-mini",
"o1",
"o1-preview",
# GPT-4o series (legacy but still supported)
"gpt-4o",
"gpt-4o-mini",
"gpt-4-turbo",
"gpt-4",
"gpt-3.5-turbo",
"o1",
"o1-mini",
"o1-preview",
"o3-mini",
"gpt-5-nano",
"gpt-5-mini",
"gpt-5",
],
"anthropic": [
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-20241022",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
# Claude 4.5 series (2025)
"claude-opus-4-5",
"claude-sonnet-4-5",
"claude-haiku-4-5",
# Claude 4.x series
"claude-opus-4-1-20250805",
"claude-sonnet-4-20250514",
"claude-haiku-4-5-20251001",
"claude-opus-4-5-20251101",
# Claude 3.7 series
"claude-3-7-sonnet-20250219",
# Claude 3.5 series (legacy)
"claude-3-5-haiku-20241022",
"claude-3-5-sonnet-20241022",
],
"google": [
"gemini-1.5-pro",
"gemini-1.5-flash",
"gemini-2.0-flash",
"gemini-2.0-flash-lite",
"gemini-2.5-pro-preview-05-06",
"gemini-2.5-flash-preview-05-20",
# Gemini 3 series (preview)
"gemini-3-pro-preview",
"gemini-3-flash-preview",
# Gemini 2.5 series
"gemini-2.5-pro",
"gemini-2.5-flash",
"gemini-2.5-flash-lite",
# Gemini 2.0 series
"gemini-2.0-flash",
"gemini-2.0-flash-lite",
],
"xai": [
"grok-beta",
"grok-2",
"grok-2-mini",
"grok-3",
"grok-3-mini",
# Grok 4.1 series
"grok-4-1-fast",
"grok-4-1-fast-reasoning",
"grok-4-1-fast-non-reasoning",
# Grok 4 series
"grok-4",
"grok-4-0709",
"grok-4-fast-reasoning",
"grok-4-fast-non-reasoning",
],
"ollama": [],
"openrouter": [],
"vllm": [],
}
def validate_model(provider: str, model: str) -> bool:
"""Validate that a model is supported by the provider.
"""Check if model name is valid for the given provider.
For ollama, openrouter, and vllm, any model is accepted.
For other providers, checks against VALID_MODELS.
For ollama, openrouter, vllm - any model is accepted.
"""
provider_lower = provider.lower()
@@ -60,10 +77,6 @@ def validate_model(provider: str, model: str) -> bool:
return True
if provider_lower not in VALID_MODELS:
return False
valid = VALID_MODELS[provider_lower]
if not valid:
return True
return model in valid
return model in VALID_MODELS[provider_lower]