mirror of
https://github.com/farcasclaudiu/TradingAgents.git
synced 2026-06-28 17:01:20 +03:00
feat: add multi-provider LLM support with thinking configurations
Models added: - OpenAI: GPT-5.2, GPT-5.1, GPT-5, GPT-5 Mini, GPT-5 Nano, GPT-4.1 - Anthropic: Claude Opus 4.5/4.1, Claude Sonnet 4.5/4, Claude Haiku 4.5 - Google: Gemini 3 Pro/Flash, Gemini 2.5 Flash/Flash Lite - xAI: Grok 4, Grok 4.1 Fast (Reasoning/Non-Reasoning) Configs updated: - Add unified thinking_level for Gemini (maps to thinking_level for Gemini 3, thinking_budget for Gemini 2.5; handles Pro's lack of "minimal" support) - Add OpenAI reasoning_effort configuration - Add NormalizedChatGoogleGenerativeAI for consistent response handling Fixes: - Fix Bull/Bear researcher display truncation - Replace ChromaDB with BM25 for memory retrieval
This commit is contained in:
@@ -76,7 +76,7 @@ Volume-Based Indicators:
|
||||
|
||||
if len(result.tool_calls) == 0:
|
||||
report = result.content
|
||||
|
||||
|
||||
return {
|
||||
"messages": [result],
|
||||
"market_report": report,
|
||||
|
||||
@@ -24,15 +24,15 @@ def create_msg_delete():
|
||||
def delete_messages(state):
|
||||
"""Clear messages and add placeholder for Anthropic compatibility"""
|
||||
messages = state["messages"]
|
||||
|
||||
|
||||
# Remove all messages
|
||||
removal_operations = [RemoveMessage(id=m.id) for m in messages]
|
||||
|
||||
|
||||
# Add a minimal placeholder message
|
||||
placeholder = HumanMessage(content="Continue")
|
||||
|
||||
|
||||
return {"messages": removal_operations + [placeholder]}
|
||||
|
||||
|
||||
return delete_messages
|
||||
|
||||
|
||||
|
||||
@@ -1,75 +1,106 @@
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from openai import OpenAI
|
||||
"""Financial situation memory using BM25 for lexical similarity matching.
|
||||
|
||||
Uses BM25 (Best Matching 25) algorithm for retrieval - no API calls,
|
||||
no token limits, works offline with any LLM provider.
|
||||
"""
|
||||
|
||||
from rank_bm25 import BM25Okapi
|
||||
from typing import List, Tuple
|
||||
import re
|
||||
|
||||
|
||||
class FinancialSituationMemory:
|
||||
def __init__(self, name, config):
|
||||
if config["backend_url"] == "http://localhost:11434/v1":
|
||||
self.embedding = "nomic-embed-text"
|
||||
"""Memory system for storing and retrieving financial situations using BM25."""
|
||||
|
||||
def __init__(self, name: str, config: dict = None):
|
||||
"""Initialize the memory system.
|
||||
|
||||
Args:
|
||||
name: Name identifier for this memory instance
|
||||
config: Configuration dict (kept for API compatibility, not used for BM25)
|
||||
"""
|
||||
self.name = name
|
||||
self.documents: List[str] = []
|
||||
self.recommendations: List[str] = []
|
||||
self.bm25 = None
|
||||
|
||||
def _tokenize(self, text: str) -> List[str]:
|
||||
"""Tokenize text for BM25 indexing.
|
||||
|
||||
Simple whitespace + punctuation tokenization with lowercasing.
|
||||
"""
|
||||
# Lowercase and split on non-alphanumeric characters
|
||||
tokens = re.findall(r'\b\w+\b', text.lower())
|
||||
return tokens
|
||||
|
||||
def _rebuild_index(self):
|
||||
"""Rebuild the BM25 index after adding documents."""
|
||||
if self.documents:
|
||||
tokenized_docs = [self._tokenize(doc) for doc in self.documents]
|
||||
self.bm25 = BM25Okapi(tokenized_docs)
|
||||
else:
|
||||
self.embedding = "text-embedding-3-small"
|
||||
self.client = OpenAI(base_url=config["backend_url"])
|
||||
self.chroma_client = chromadb.Client(Settings(allow_reset=True))
|
||||
self.situation_collection = self.chroma_client.create_collection(name=name)
|
||||
self.bm25 = None
|
||||
|
||||
def get_embedding(self, text):
|
||||
"""Get OpenAI embedding for a text"""
|
||||
|
||||
response = self.client.embeddings.create(
|
||||
model=self.embedding, input=text
|
||||
)
|
||||
return response.data[0].embedding
|
||||
def add_situations(self, situations_and_advice: List[Tuple[str, str]]):
|
||||
"""Add financial situations and their corresponding advice.
|
||||
|
||||
def add_situations(self, situations_and_advice):
|
||||
"""Add financial situations and their corresponding advice. Parameter is a list of tuples (situation, rec)"""
|
||||
Args:
|
||||
situations_and_advice: List of tuples (situation, recommendation)
|
||||
"""
|
||||
for situation, recommendation in situations_and_advice:
|
||||
self.documents.append(situation)
|
||||
self.recommendations.append(recommendation)
|
||||
|
||||
situations = []
|
||||
advice = []
|
||||
ids = []
|
||||
embeddings = []
|
||||
# Rebuild BM25 index with new documents
|
||||
self._rebuild_index()
|
||||
|
||||
offset = self.situation_collection.count()
|
||||
def get_memories(self, current_situation: str, n_matches: int = 1) -> List[dict]:
|
||||
"""Find matching recommendations using BM25 similarity.
|
||||
|
||||
for i, (situation, recommendation) in enumerate(situations_and_advice):
|
||||
situations.append(situation)
|
||||
advice.append(recommendation)
|
||||
ids.append(str(offset + i))
|
||||
embeddings.append(self.get_embedding(situation))
|
||||
Args:
|
||||
current_situation: The current financial situation to match against
|
||||
n_matches: Number of top matches to return
|
||||
|
||||
self.situation_collection.add(
|
||||
documents=situations,
|
||||
metadatas=[{"recommendation": rec} for rec in advice],
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
)
|
||||
Returns:
|
||||
List of dicts with matched_situation, recommendation, and similarity_score
|
||||
"""
|
||||
if not self.documents or self.bm25 is None:
|
||||
return []
|
||||
|
||||
def get_memories(self, current_situation, n_matches=1):
|
||||
"""Find matching recommendations using OpenAI embeddings"""
|
||||
query_embedding = self.get_embedding(current_situation)
|
||||
# Tokenize query
|
||||
query_tokens = self._tokenize(current_situation)
|
||||
|
||||
results = self.situation_collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=n_matches,
|
||||
include=["metadatas", "documents", "distances"],
|
||||
)
|
||||
# Get BM25 scores for all documents
|
||||
scores = self.bm25.get_scores(query_tokens)
|
||||
|
||||
matched_results = []
|
||||
for i in range(len(results["documents"][0])):
|
||||
matched_results.append(
|
||||
{
|
||||
"matched_situation": results["documents"][0][i],
|
||||
"recommendation": results["metadatas"][0][i]["recommendation"],
|
||||
"similarity_score": 1 - results["distances"][0][i],
|
||||
}
|
||||
)
|
||||
# Get top-n indices sorted by score (descending)
|
||||
top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n_matches]
|
||||
|
||||
return matched_results
|
||||
# Build results
|
||||
results = []
|
||||
max_score = max(scores) if max(scores) > 0 else 1 # Normalize scores
|
||||
|
||||
for idx in top_indices:
|
||||
# Normalize score to 0-1 range for consistency
|
||||
normalized_score = scores[idx] / max_score if max_score > 0 else 0
|
||||
results.append({
|
||||
"matched_situation": self.documents[idx],
|
||||
"recommendation": self.recommendations[idx],
|
||||
"similarity_score": normalized_score,
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
def clear(self):
|
||||
"""Clear all stored memories."""
|
||||
self.documents = []
|
||||
self.recommendations = []
|
||||
self.bm25 = None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage
|
||||
matcher = FinancialSituationMemory()
|
||||
matcher = FinancialSituationMemory("test_memory")
|
||||
|
||||
# Example data
|
||||
example_data = [
|
||||
@@ -96,7 +127,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Example query
|
||||
current_situation = """
|
||||
Market showing increased volatility in tech sector, with institutional investors
|
||||
Market showing increased volatility in tech sector, with institutional investors
|
||||
reducing positions and rising interest rates affecting growth stock valuations
|
||||
"""
|
||||
|
||||
|
||||
@@ -3,24 +3,21 @@ from typing import Dict, Optional
|
||||
|
||||
# Use default config but allow it to be overridden
|
||||
_config: Optional[Dict] = None
|
||||
DATA_DIR: Optional[str] = None
|
||||
|
||||
|
||||
def initialize_config():
|
||||
"""Initialize the configuration with default values."""
|
||||
global _config, DATA_DIR
|
||||
global _config
|
||||
if _config is None:
|
||||
_config = default_config.DEFAULT_CONFIG.copy()
|
||||
DATA_DIR = _config["data_dir"]
|
||||
|
||||
|
||||
def set_config(config: Dict):
|
||||
"""Update the configuration with custom values."""
|
||||
global _config, DATA_DIR
|
||||
global _config
|
||||
if _config is None:
|
||||
_config = default_config.DEFAULT_CONFIG.copy()
|
||||
_config.update(config)
|
||||
DATA_DIR = _config["data_dir"]
|
||||
|
||||
|
||||
def get_config() -> Dict:
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
from typing import Annotated
|
||||
|
||||
# Import from vendor-specific modules
|
||||
from .local import get_YFin_data, get_finnhub_news, get_finnhub_company_insider_sentiment, get_finnhub_company_insider_transactions, get_simfin_balance_sheet, get_simfin_cashflow, get_simfin_income_statements, get_reddit_global_news, get_reddit_company_news
|
||||
from .y_finance import get_YFin_data_online, get_stock_stats_indicators_window, get_balance_sheet as get_yfinance_balance_sheet, get_cashflow as get_yfinance_cashflow, get_income_statement as get_yfinance_income_statement, get_insider_transactions as get_yfinance_insider_transactions
|
||||
from .google import get_google_news
|
||||
from .openai import get_stock_news_openai, get_global_news_openai, get_fundamentals_openai
|
||||
from .y_finance import (
|
||||
get_YFin_data_online,
|
||||
get_stock_stats_indicators_window,
|
||||
get_balance_sheet as get_yfinance_balance_sheet,
|
||||
get_cashflow as get_yfinance_cashflow,
|
||||
get_income_statement as get_yfinance_income_statement,
|
||||
get_insider_transactions as get_yfinance_insider_transactions,
|
||||
)
|
||||
from .yfinance_news import get_news_yfinance, get_global_news_yfinance
|
||||
from .alpha_vantage import (
|
||||
get_stock as get_alpha_vantage_stock,
|
||||
get_indicator as get_alpha_vantage_indicator,
|
||||
@@ -13,7 +18,7 @@ from .alpha_vantage import (
|
||||
get_cashflow as get_alpha_vantage_cashflow,
|
||||
get_income_statement as get_alpha_vantage_income_statement,
|
||||
get_insider_transactions as get_alpha_vantage_insider_transactions,
|
||||
get_news as get_alpha_vantage_news
|
||||
get_news as get_alpha_vantage_news,
|
||||
)
|
||||
from .alpha_vantage_common import AlphaVantageRateLimitError
|
||||
|
||||
@@ -44,21 +49,18 @@ TOOLS_CATEGORIES = {
|
||||
]
|
||||
},
|
||||
"news_data": {
|
||||
"description": "News (public/insiders, original/processed)",
|
||||
"description": "News and insider data",
|
||||
"tools": [
|
||||
"get_news",
|
||||
"get_global_news",
|
||||
"get_insider_sentiment",
|
||||
"get_insider_transactions",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
VENDOR_LIST = [
|
||||
"local",
|
||||
"yfinance",
|
||||
"openai",
|
||||
"google"
|
||||
"alpha_vantage",
|
||||
]
|
||||
|
||||
# Mapping of methods to their vendor-specific implementations
|
||||
@@ -67,52 +69,39 @@ VENDOR_METHODS = {
|
||||
"get_stock_data": {
|
||||
"alpha_vantage": get_alpha_vantage_stock,
|
||||
"yfinance": get_YFin_data_online,
|
||||
"local": get_YFin_data,
|
||||
},
|
||||
# technical_indicators
|
||||
"get_indicators": {
|
||||
"alpha_vantage": get_alpha_vantage_indicator,
|
||||
"yfinance": get_stock_stats_indicators_window,
|
||||
"local": get_stock_stats_indicators_window
|
||||
},
|
||||
# fundamental_data
|
||||
"get_fundamentals": {
|
||||
"alpha_vantage": get_alpha_vantage_fundamentals,
|
||||
"openai": get_fundamentals_openai,
|
||||
},
|
||||
"get_balance_sheet": {
|
||||
"alpha_vantage": get_alpha_vantage_balance_sheet,
|
||||
"yfinance": get_yfinance_balance_sheet,
|
||||
"local": get_simfin_balance_sheet,
|
||||
},
|
||||
"get_cashflow": {
|
||||
"alpha_vantage": get_alpha_vantage_cashflow,
|
||||
"yfinance": get_yfinance_cashflow,
|
||||
"local": get_simfin_cashflow,
|
||||
},
|
||||
"get_income_statement": {
|
||||
"alpha_vantage": get_alpha_vantage_income_statement,
|
||||
"yfinance": get_yfinance_income_statement,
|
||||
"local": get_simfin_income_statements,
|
||||
},
|
||||
# news_data
|
||||
"get_news": {
|
||||
"alpha_vantage": get_alpha_vantage_news,
|
||||
"openai": get_stock_news_openai,
|
||||
"google": get_google_news,
|
||||
"local": [get_finnhub_news, get_reddit_company_news, get_google_news],
|
||||
"yfinance": get_news_yfinance,
|
||||
},
|
||||
"get_global_news": {
|
||||
"openai": get_global_news_openai,
|
||||
"local": get_reddit_global_news
|
||||
},
|
||||
"get_insider_sentiment": {
|
||||
"local": get_finnhub_company_insider_sentiment
|
||||
"yfinance": get_global_news_yfinance,
|
||||
},
|
||||
"get_insider_transactions": {
|
||||
"alpha_vantage": get_alpha_vantage_insider_transactions,
|
||||
"yfinance": get_yfinance_insider_transactions,
|
||||
"local": get_finnhub_company_insider_transactions,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import yfinance as yf
|
||||
from stockstats import wrap
|
||||
from typing import Annotated
|
||||
import os
|
||||
from .config import get_config, DATA_DIR
|
||||
from .config import get_config
|
||||
|
||||
|
||||
class StockstatsUtils:
|
||||
@@ -17,63 +17,45 @@ class StockstatsUtils:
|
||||
str, "curr date for retrieving stock price data, YYYY-mm-dd"
|
||||
],
|
||||
):
|
||||
# Get config and set up data directory path
|
||||
config = get_config()
|
||||
online = config["data_vendors"]["technical_indicators"] != "local"
|
||||
|
||||
df = None
|
||||
data = None
|
||||
today_date = pd.Timestamp.today()
|
||||
curr_date_dt = pd.to_datetime(curr_date)
|
||||
|
||||
if not online:
|
||||
try:
|
||||
data = pd.read_csv(
|
||||
os.path.join(
|
||||
DATA_DIR,
|
||||
f"{symbol}-YFin-data-2015-01-01-2025-03-25.csv",
|
||||
)
|
||||
)
|
||||
df = wrap(data)
|
||||
except FileNotFoundError:
|
||||
raise Exception("Stockstats fail: Yahoo Finance data not fetched yet!")
|
||||
end_date = today_date
|
||||
start_date = today_date - pd.DateOffset(years=15)
|
||||
start_date_str = start_date.strftime("%Y-%m-%d")
|
||||
end_date_str = end_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Ensure cache directory exists
|
||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
||||
|
||||
data_file = os.path.join(
|
||||
config["data_cache_dir"],
|
||||
f"{symbol}-YFin-data-{start_date_str}-{end_date_str}.csv",
|
||||
)
|
||||
|
||||
if os.path.exists(data_file):
|
||||
data = pd.read_csv(data_file)
|
||||
data["Date"] = pd.to_datetime(data["Date"])
|
||||
else:
|
||||
# Get today's date as YYYY-mm-dd to add to cache
|
||||
today_date = pd.Timestamp.today()
|
||||
curr_date = pd.to_datetime(curr_date)
|
||||
|
||||
end_date = today_date
|
||||
start_date = today_date - pd.DateOffset(years=15)
|
||||
start_date = start_date.strftime("%Y-%m-%d")
|
||||
end_date = end_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Get config and ensure cache directory exists
|
||||
os.makedirs(config["data_cache_dir"], exist_ok=True)
|
||||
|
||||
data_file = os.path.join(
|
||||
config["data_cache_dir"],
|
||||
f"{symbol}-YFin-data-{start_date}-{end_date}.csv",
|
||||
data = yf.download(
|
||||
symbol,
|
||||
start=start_date_str,
|
||||
end=end_date_str,
|
||||
multi_level_index=False,
|
||||
progress=False,
|
||||
auto_adjust=True,
|
||||
)
|
||||
data = data.reset_index()
|
||||
data.to_csv(data_file, index=False)
|
||||
|
||||
if os.path.exists(data_file):
|
||||
data = pd.read_csv(data_file)
|
||||
data["Date"] = pd.to_datetime(data["Date"])
|
||||
else:
|
||||
data = yf.download(
|
||||
symbol,
|
||||
start=start_date,
|
||||
end=end_date,
|
||||
multi_level_index=False,
|
||||
progress=False,
|
||||
auto_adjust=True,
|
||||
)
|
||||
data = data.reset_index()
|
||||
data.to_csv(data_file, index=False)
|
||||
|
||||
df = wrap(data)
|
||||
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
||||
curr_date = curr_date.strftime("%Y-%m-%d")
|
||||
df = wrap(data)
|
||||
df["Date"] = df["Date"].dt.strftime("%Y-%m-%d")
|
||||
curr_date_str = curr_date_dt.strftime("%Y-%m-%d")
|
||||
|
||||
df[indicator] # trigger stockstats to calculate the indicator
|
||||
matching_rows = df[df["Date"].str.startswith(curr_date)]
|
||||
matching_rows = df[df["Date"].str.startswith(curr_date_str)]
|
||||
|
||||
if not matching_rows.empty:
|
||||
indicator_value = matching_rows[indicator].values[0]
|
||||
|
||||
@@ -0,0 +1,190 @@
|
||||
"""yfinance-based news data fetching functions."""
|
||||
|
||||
import yfinance as yf
|
||||
from datetime import datetime
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
|
||||
def _extract_article_data(article: dict) -> dict:
|
||||
"""Extract article data from yfinance news format (handles nested 'content' structure)."""
|
||||
# Handle nested content structure
|
||||
if "content" in article:
|
||||
content = article["content"]
|
||||
title = content.get("title", "No title")
|
||||
summary = content.get("summary", "")
|
||||
provider = content.get("provider", {})
|
||||
publisher = provider.get("displayName", "Unknown")
|
||||
|
||||
# Get URL from canonicalUrl or clickThroughUrl
|
||||
url_obj = content.get("canonicalUrl") or content.get("clickThroughUrl") or {}
|
||||
link = url_obj.get("url", "")
|
||||
|
||||
# Get publish date
|
||||
pub_date_str = content.get("pubDate", "")
|
||||
pub_date = None
|
||||
if pub_date_str:
|
||||
try:
|
||||
pub_date = datetime.fromisoformat(pub_date_str.replace("Z", "+00:00"))
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
return {
|
||||
"title": title,
|
||||
"summary": summary,
|
||||
"publisher": publisher,
|
||||
"link": link,
|
||||
"pub_date": pub_date,
|
||||
}
|
||||
else:
|
||||
# Fallback for flat structure
|
||||
return {
|
||||
"title": article.get("title", "No title"),
|
||||
"summary": article.get("summary", ""),
|
||||
"publisher": article.get("publisher", "Unknown"),
|
||||
"link": article.get("link", ""),
|
||||
"pub_date": None,
|
||||
}
|
||||
|
||||
|
||||
def get_news_yfinance(
|
||||
ticker: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve news for a specific stock ticker using yfinance.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol (e.g., "AAPL")
|
||||
start_date: Start date in yyyy-mm-dd format
|
||||
end_date: End date in yyyy-mm-dd format
|
||||
|
||||
Returns:
|
||||
Formatted string containing news articles
|
||||
"""
|
||||
try:
|
||||
stock = yf.Ticker(ticker)
|
||||
news = stock.get_news(count=20, tab="news")
|
||||
|
||||
if not news:
|
||||
return f"No news found for {ticker}"
|
||||
|
||||
# Parse date range for filtering
|
||||
start_dt = datetime.strptime(start_date, "%Y-%m-%d")
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
|
||||
news_str = ""
|
||||
filtered_count = 0
|
||||
|
||||
for article in news:
|
||||
data = _extract_article_data(article)
|
||||
|
||||
# Filter by date if publish time is available
|
||||
if data["pub_date"]:
|
||||
pub_date_naive = data["pub_date"].replace(tzinfo=None)
|
||||
if not (start_dt <= pub_date_naive <= end_dt + relativedelta(days=1)):
|
||||
continue
|
||||
|
||||
news_str += f"### {data['title']} (source: {data['publisher']})\n"
|
||||
if data["summary"]:
|
||||
news_str += f"{data['summary']}\n"
|
||||
if data["link"]:
|
||||
news_str += f"Link: {data['link']}\n"
|
||||
news_str += "\n"
|
||||
filtered_count += 1
|
||||
|
||||
if filtered_count == 0:
|
||||
return f"No news found for {ticker} between {start_date} and {end_date}"
|
||||
|
||||
return f"## {ticker} News, from {start_date} to {end_date}:\n\n{news_str}"
|
||||
|
||||
except Exception as e:
|
||||
return f"Error fetching news for {ticker}: {str(e)}"
|
||||
|
||||
|
||||
def get_global_news_yfinance(
|
||||
curr_date: str,
|
||||
look_back_days: int = 7,
|
||||
limit: int = 10,
|
||||
) -> str:
|
||||
"""
|
||||
Retrieve global/macro economic news using yfinance Search.
|
||||
|
||||
Args:
|
||||
curr_date: Current date in yyyy-mm-dd format
|
||||
look_back_days: Number of days to look back
|
||||
limit: Maximum number of articles to return
|
||||
|
||||
Returns:
|
||||
Formatted string containing global news articles
|
||||
"""
|
||||
# Search queries for macro/global news
|
||||
search_queries = [
|
||||
"stock market economy",
|
||||
"Federal Reserve interest rates",
|
||||
"inflation economic outlook",
|
||||
"global markets trading",
|
||||
]
|
||||
|
||||
all_news = []
|
||||
seen_titles = set()
|
||||
|
||||
try:
|
||||
for query in search_queries:
|
||||
search = yf.Search(
|
||||
query=query,
|
||||
news_count=limit,
|
||||
enable_fuzzy_query=True,
|
||||
)
|
||||
|
||||
if search.news:
|
||||
for article in search.news:
|
||||
# Handle both flat and nested structures
|
||||
if "content" in article:
|
||||
data = _extract_article_data(article)
|
||||
title = data["title"]
|
||||
else:
|
||||
title = article.get("title", "")
|
||||
|
||||
# Deduplicate by title
|
||||
if title and title not in seen_titles:
|
||||
seen_titles.add(title)
|
||||
all_news.append(article)
|
||||
|
||||
if len(all_news) >= limit:
|
||||
break
|
||||
|
||||
if not all_news:
|
||||
return f"No global news found for {curr_date}"
|
||||
|
||||
# Calculate date range
|
||||
curr_dt = datetime.strptime(curr_date, "%Y-%m-%d")
|
||||
start_dt = curr_dt - relativedelta(days=look_back_days)
|
||||
start_date = start_dt.strftime("%Y-%m-%d")
|
||||
|
||||
news_str = ""
|
||||
for article in all_news[:limit]:
|
||||
# Handle both flat and nested structures
|
||||
if "content" in article:
|
||||
data = _extract_article_data(article)
|
||||
title = data["title"]
|
||||
publisher = data["publisher"]
|
||||
link = data["link"]
|
||||
summary = data["summary"]
|
||||
else:
|
||||
title = article.get("title", "No title")
|
||||
publisher = article.get("publisher", "Unknown")
|
||||
link = article.get("link", "")
|
||||
summary = ""
|
||||
|
||||
news_str += f"### {title} (source: {publisher})\n"
|
||||
if summary:
|
||||
news_str += f"{summary}\n"
|
||||
if link:
|
||||
news_str += f"Link: {link}\n"
|
||||
news_str += "\n"
|
||||
|
||||
return f"## Global Market News, from {start_date} to {curr_date}:\n\n{news_str}"
|
||||
|
||||
except Exception as e:
|
||||
return f"Error fetching global news: {str(e)}"
|
||||
@@ -3,16 +3,18 @@ import os
|
||||
DEFAULT_CONFIG = {
|
||||
"project_dir": os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"results_dir": os.getenv("TRADINGAGENTS_RESULTS_DIR", "./results"),
|
||||
"data_dir": "/Users/yluo/Documents/Code/ScAI/FR1-data",
|
||||
"data_cache_dir": os.path.join(
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), ".")),
|
||||
"dataflows/data_cache",
|
||||
),
|
||||
# LLM settings
|
||||
"llm_provider": "openai",
|
||||
"deep_think_llm": "o4-mini",
|
||||
"quick_think_llm": "gpt-4o-mini",
|
||||
"deep_think_llm": "gpt-5.2",
|
||||
"quick_think_llm": "gpt-5-mini",
|
||||
"backend_url": "https://api.openai.com/v1",
|
||||
# Provider-specific thinking configuration
|
||||
"google_thinking_level": None, # "high", "minimal", etc.
|
||||
"openai_reasoning_effort": None, # "medium", "high", "low"
|
||||
# Debate and discussion settings
|
||||
"max_debate_rounds": 1,
|
||||
"max_risk_discuss_rounds": 1,
|
||||
@@ -20,14 +22,13 @@ DEFAULT_CONFIG = {
|
||||
# Data vendor configuration
|
||||
# Category-level configuration (default for all tools in category)
|
||||
"data_vendors": {
|
||||
"core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage, local
|
||||
"technical_indicators": "yfinance", # Options: yfinance, alpha_vantage, local
|
||||
"fundamental_data": "alpha_vantage", # Options: openai, alpha_vantage, local
|
||||
"news_data": "alpha_vantage", # Options: openai, alpha_vantage, google, local
|
||||
"core_stock_apis": "yfinance", # Options: yfinance, alpha_vantage
|
||||
"technical_indicators": "yfinance", # Options: yfinance, alpha_vantage
|
||||
"fundamental_data": "alpha_vantage", # Options: alpha_vantage, yfinance
|
||||
"news_data": "yfinance", # Options: yfinance, alpha_vantage
|
||||
},
|
||||
# Tool-level configuration (takes precedence over category-level)
|
||||
"tool_vendors": {
|
||||
# Example: "get_stock_data": "alpha_vantage", # Override category default
|
||||
# Example: "get_news": "openai", # Override category default
|
||||
},
|
||||
}
|
||||
|
||||
@@ -69,16 +69,20 @@ class TradingAgentsGraph:
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
# Initialize LLMs
|
||||
# Initialize LLMs with provider-specific thinking configuration
|
||||
llm_kwargs = self._get_provider_kwargs()
|
||||
|
||||
deep_client = create_llm_client(
|
||||
provider=self.config["llm_provider"],
|
||||
model=self.config["deep_think_llm"],
|
||||
base_url=self.config.get("backend_url"),
|
||||
**llm_kwargs,
|
||||
)
|
||||
quick_client = create_llm_client(
|
||||
provider=self.config["llm_provider"],
|
||||
model=self.config["quick_think_llm"],
|
||||
base_url=self.config.get("backend_url"),
|
||||
**llm_kwargs,
|
||||
)
|
||||
self.deep_thinking_llm = deep_client.get_llm()
|
||||
self.quick_thinking_llm = quick_client.get_llm()
|
||||
@@ -119,6 +123,23 @@ class TradingAgentsGraph:
|
||||
# Set up the graph
|
||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
||||
|
||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
||||
"""Get provider-specific kwargs for LLM client creation."""
|
||||
kwargs = {}
|
||||
provider = self.config.get("llm_provider", "").lower()
|
||||
|
||||
if provider == "google":
|
||||
thinking_level = self.config.get("google_thinking_level")
|
||||
if thinking_level:
|
||||
kwargs["thinking_level"] = thinking_level
|
||||
|
||||
elif provider == "openai":
|
||||
reasoning_effort = self.config.get("openai_reasoning_effort")
|
||||
if reasoning_effort:
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
|
||||
return kwargs
|
||||
|
||||
def _create_tool_nodes(self) -> Dict[str, ToolNode]:
|
||||
"""Create tool nodes for different data sources using abstract methods."""
|
||||
return {
|
||||
|
||||
@@ -14,18 +14,12 @@ class AnthropicClient(BaseLLMClient):
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured ChatAnthropic instance."""
|
||||
llm_kwargs = {
|
||||
"model": self.model,
|
||||
"max_tokens": self.kwargs.get("max_tokens", 4096),
|
||||
}
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
for key in ("timeout", "max_retries", "api_key"):
|
||||
for key in ("timeout", "max_retries", "api_key", "max_tokens"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
if "thinking_config" in self.kwargs:
|
||||
llm_kwargs["thinking"] = self.kwargs["thinking_config"]
|
||||
|
||||
return ChatAnthropic(**llm_kwargs)
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
|
||||
@@ -6,6 +6,28 @@ from .base_client import BaseLLMClient
|
||||
from .validators import validate_model
|
||||
|
||||
|
||||
class NormalizedChatGoogleGenerativeAI(ChatGoogleGenerativeAI):
|
||||
"""ChatGoogleGenerativeAI with normalized content output.
|
||||
|
||||
Gemini 3 models return content as list: [{'type': 'text', 'text': '...'}]
|
||||
This normalizes to string for consistent downstream handling.
|
||||
"""
|
||||
|
||||
def _normalize_content(self, response):
|
||||
content = response.content
|
||||
if isinstance(content, list):
|
||||
texts = [
|
||||
item.get("text", "") if isinstance(item, dict) and item.get("type") == "text"
|
||||
else item if isinstance(item, str) else ""
|
||||
for item in content
|
||||
]
|
||||
response.content = "\n".join(t for t in texts if t)
|
||||
return response
|
||||
|
||||
def invoke(self, input, config=None, **kwargs):
|
||||
return self._normalize_content(super().invoke(input, config, **kwargs))
|
||||
|
||||
|
||||
class GoogleClient(BaseLLMClient):
|
||||
"""Client for Google Gemini models."""
|
||||
|
||||
@@ -20,14 +42,23 @@ class GoogleClient(BaseLLMClient):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
if "thinking_budget" in self.kwargs and self._is_preview_model():
|
||||
llm_kwargs["thinking_budget"] = self.kwargs["thinking_budget"]
|
||||
# Map thinking_level to appropriate API param based on model
|
||||
# Gemini 3 Pro: low, high
|
||||
# Gemini 3 Flash: minimal, low, medium, high
|
||||
# Gemini 2.5: thinking_budget (0=disable, -1=dynamic)
|
||||
thinking_level = self.kwargs.get("thinking_level")
|
||||
if thinking_level:
|
||||
model_lower = self.model.lower()
|
||||
if "gemini-3" in model_lower:
|
||||
# Gemini 3 Pro doesn't support "minimal", use "low" instead
|
||||
if "pro" in model_lower and thinking_level == "minimal":
|
||||
thinking_level = "low"
|
||||
llm_kwargs["thinking_level"] = thinking_level
|
||||
else:
|
||||
# Gemini 2.5: map to thinking_budget
|
||||
llm_kwargs["thinking_budget"] = -1 if thinking_level == "high" else 0
|
||||
|
||||
return ChatGoogleGenerativeAI(**llm_kwargs)
|
||||
|
||||
def _is_preview_model(self) -> bool:
|
||||
"""Check if this is a preview model that supports thinking budget."""
|
||||
return "preview" in self.model.lower()
|
||||
return NormalizedChatGoogleGenerativeAI(**llm_kwargs)
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
"""Validate model for Google."""
|
||||
|
||||
@@ -1,58 +1,75 @@
|
||||
from typing import Dict, List
|
||||
"""Model name validators for each provider.
|
||||
|
||||
VALID_MODELS: Dict[str, List[str]] = {
|
||||
Only validates model names - does NOT enforce limits.
|
||||
Let LLM providers use their own defaults for unspecified params.
|
||||
"""
|
||||
|
||||
VALID_MODELS = {
|
||||
"openai": [
|
||||
# GPT-5 series (2025)
|
||||
"gpt-5.2",
|
||||
"gpt-5.1",
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"gpt-5-nano",
|
||||
# GPT-4.1 series (2025)
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4.1-nano",
|
||||
# o-series reasoning models
|
||||
"o4-mini",
|
||||
"o3",
|
||||
"o3-mini",
|
||||
"o1",
|
||||
"o1-preview",
|
||||
# GPT-4o series (legacy but still supported)
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4",
|
||||
"gpt-3.5-turbo",
|
||||
"o1",
|
||||
"o1-mini",
|
||||
"o1-preview",
|
||||
"o3-mini",
|
||||
"gpt-5-nano",
|
||||
"gpt-5-mini",
|
||||
"gpt-5",
|
||||
],
|
||||
"anthropic": [
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-5-haiku-20241022",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
# Claude 4.5 series (2025)
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-haiku-4-5",
|
||||
# Claude 4.x series
|
||||
"claude-opus-4-1-20250805",
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-haiku-4-5-20251001",
|
||||
"claude-opus-4-5-20251101",
|
||||
# Claude 3.7 series
|
||||
"claude-3-7-sonnet-20250219",
|
||||
# Claude 3.5 series (legacy)
|
||||
"claude-3-5-haiku-20241022",
|
||||
"claude-3-5-sonnet-20241022",
|
||||
],
|
||||
"google": [
|
||||
"gemini-1.5-pro",
|
||||
"gemini-1.5-flash",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-lite",
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
"gemini-2.5-flash-preview-05-20",
|
||||
# Gemini 3 series (preview)
|
||||
"gemini-3-pro-preview",
|
||||
"gemini-3-flash-preview",
|
||||
# Gemini 2.5 series
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
# Gemini 2.0 series
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.0-flash-lite",
|
||||
],
|
||||
"xai": [
|
||||
"grok-beta",
|
||||
"grok-2",
|
||||
"grok-2-mini",
|
||||
"grok-3",
|
||||
"grok-3-mini",
|
||||
# Grok 4.1 series
|
||||
"grok-4-1-fast",
|
||||
"grok-4-1-fast-reasoning",
|
||||
"grok-4-1-fast-non-reasoning",
|
||||
# Grok 4 series
|
||||
"grok-4",
|
||||
"grok-4-0709",
|
||||
"grok-4-fast-reasoning",
|
||||
"grok-4-fast-non-reasoning",
|
||||
],
|
||||
"ollama": [],
|
||||
"openrouter": [],
|
||||
"vllm": [],
|
||||
}
|
||||
|
||||
|
||||
def validate_model(provider: str, model: str) -> bool:
|
||||
"""Validate that a model is supported by the provider.
|
||||
"""Check if model name is valid for the given provider.
|
||||
|
||||
For ollama, openrouter, and vllm, any model is accepted.
|
||||
For other providers, checks against VALID_MODELS.
|
||||
For ollama, openrouter, vllm - any model is accepted.
|
||||
"""
|
||||
provider_lower = provider.lower()
|
||||
|
||||
@@ -60,10 +77,6 @@ def validate_model(provider: str, model: str) -> bool:
|
||||
return True
|
||||
|
||||
if provider_lower not in VALID_MODELS:
|
||||
return False
|
||||
|
||||
valid = VALID_MODELS[provider_lower]
|
||||
if not valid:
|
||||
return True
|
||||
|
||||
return model in valid
|
||||
return model in VALID_MODELS[provider_lower]
|
||||
|
||||
Reference in New Issue
Block a user