mirror of
https://github.com/farcasclaudiu/TradingAgents.git
synced 2026-06-28 23:01:28 +03:00
feat: replace per-agent BM25 memory with persistent append-only decision log
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,18 +1,23 @@
|
||||
# TradingAgents/graph/trading_graph.py
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import json
|
||||
from datetime import date
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, Tuple, List, Optional
|
||||
|
||||
import yfinance as yf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
from tradingagents.llm_clients import create_llm_client
|
||||
|
||||
from tradingagents.agents import *
|
||||
from tradingagents.default_config import DEFAULT_CONFIG
|
||||
from tradingagents.agents.utils.memory import FinancialSituationMemory
|
||||
from tradingagents.agents.utils.memory import TradingMemoryLog
|
||||
from tradingagents.agents.utils.agent_states import (
|
||||
AgentState,
|
||||
InvestDebateState,
|
||||
@@ -92,12 +97,7 @@ class TradingAgentsGraph:
|
||||
self.deep_thinking_llm = deep_client.get_llm()
|
||||
self.quick_thinking_llm = quick_client.get_llm()
|
||||
|
||||
# Initialize memories
|
||||
self.bull_memory = FinancialSituationMemory("bull_memory", self.config)
|
||||
self.bear_memory = FinancialSituationMemory("bear_memory", self.config)
|
||||
self.trader_memory = FinancialSituationMemory("trader_memory", self.config)
|
||||
self.invest_judge_memory = FinancialSituationMemory("invest_judge_memory", self.config)
|
||||
self.portfolio_manager_memory = FinancialSituationMemory("portfolio_manager_memory", self.config)
|
||||
self.memory_log = TradingMemoryLog(self.config)
|
||||
|
||||
# Create tool nodes
|
||||
self.tool_nodes = self._create_tool_nodes()
|
||||
@@ -111,11 +111,6 @@ class TradingAgentsGraph:
|
||||
self.quick_thinking_llm,
|
||||
self.deep_thinking_llm,
|
||||
self.tool_nodes,
|
||||
self.bull_memory,
|
||||
self.bear_memory,
|
||||
self.trader_memory,
|
||||
self.invest_judge_memory,
|
||||
self.portfolio_manager_memory,
|
||||
self.conditional_logic,
|
||||
)
|
||||
|
||||
@@ -189,14 +184,90 @@ class TradingAgentsGraph:
|
||||
),
|
||||
}
|
||||
|
||||
def _fetch_returns(
|
||||
self, ticker: str, trade_date: str, holding_days: int = 5
|
||||
) -> Tuple[Optional[float], Optional[float], Optional[int]]:
|
||||
"""Fetch raw and alpha return for ticker over holding_days from trade_date.
|
||||
|
||||
Returns (raw_return, alpha_return, actual_holding_days) or
|
||||
(None, None, None) if price data is unavailable (too recent, delisted,
|
||||
or network error).
|
||||
"""
|
||||
try:
|
||||
start = datetime.strptime(trade_date, "%Y-%m-%d")
|
||||
end = start + timedelta(days=holding_days + 7) # buffer for weekends/holidays
|
||||
end_str = end.strftime("%Y-%m-%d")
|
||||
|
||||
stock = yf.Ticker(ticker).history(start=trade_date, end=end_str)
|
||||
spy = yf.Ticker("SPY").history(start=trade_date, end=end_str)
|
||||
|
||||
if len(stock) < 2 or len(spy) < 2:
|
||||
return None, None, None
|
||||
|
||||
actual_days = min(holding_days, len(stock) - 1, len(spy) - 1)
|
||||
raw = float(
|
||||
(stock["Close"].iloc[actual_days] - stock["Close"].iloc[0])
|
||||
/ stock["Close"].iloc[0]
|
||||
)
|
||||
spy_ret = float(
|
||||
(spy["Close"].iloc[actual_days] - spy["Close"].iloc[0])
|
||||
/ spy["Close"].iloc[0]
|
||||
)
|
||||
alpha = raw - spy_ret
|
||||
return raw, alpha, actual_days
|
||||
except Exception as e:
|
||||
logger.debug("_fetch_returns failed for %s@%s: %s", ticker, trade_date, e)
|
||||
return None, None, None
|
||||
|
||||
def _resolve_pending_entries(self, ticker: str) -> None:
|
||||
"""Resolve pending log entries for ticker at the start of a new run.
|
||||
|
||||
Fetches returns for each same-ticker pending entry, generates reflections,
|
||||
then writes all updates in a single atomic batch write to avoid redundant I/O.
|
||||
Skips entries whose price data is not yet available (too recent or delisted).
|
||||
|
||||
Trade-off: only same-ticker entries are resolved per run. Entries for
|
||||
other tickers accumulate until that ticker is run again.
|
||||
"""
|
||||
pending = [e for e in self.memory_log.get_pending_entries() if e["ticker"] == ticker]
|
||||
if not pending:
|
||||
return
|
||||
|
||||
updates = []
|
||||
for entry in pending:
|
||||
raw, alpha, days = self._fetch_returns(ticker, entry["date"])
|
||||
if raw is None:
|
||||
continue # price not available yet — try again next run
|
||||
reflection = self.reflector.reflect_on_final_decision(
|
||||
final_decision=entry.get("decision", ""),
|
||||
raw_return=raw,
|
||||
alpha_return=alpha,
|
||||
)
|
||||
updates.append({
|
||||
"ticker": ticker,
|
||||
"trade_date": entry["date"],
|
||||
"raw_return": raw,
|
||||
"alpha_return": alpha,
|
||||
"holding_days": days,
|
||||
"reflection": reflection,
|
||||
})
|
||||
|
||||
if updates:
|
||||
self.memory_log.batch_update_with_outcomes(updates)
|
||||
|
||||
def propagate(self, company_name, trade_date):
|
||||
"""Run the trading agents graph for a company on a specific date."""
|
||||
|
||||
self.ticker = company_name
|
||||
|
||||
# Initialize state
|
||||
# Resolve any pending log entries for this ticker before the pipeline runs.
|
||||
# This adds the outcome + reflection from the previous run at zero latency cost.
|
||||
self._resolve_pending_entries(company_name)
|
||||
|
||||
# Initialize state — inject memory log context for PM
|
||||
past_context = self.memory_log.get_past_context(company_name)
|
||||
init_agent_state = self.propagator.create_initial_state(
|
||||
company_name, trade_date
|
||||
company_name, trade_date, past_context=past_context
|
||||
)
|
||||
args = self.propagator.get_graph_args()
|
||||
|
||||
@@ -221,6 +292,13 @@ class TradingAgentsGraph:
|
||||
# Log state
|
||||
self._log_state(trade_date, final_state)
|
||||
|
||||
# Store decision for deferred reflection.
|
||||
self.memory_log.store_decision(
|
||||
ticker=company_name,
|
||||
trade_date=trade_date,
|
||||
final_trade_decision=final_state["final_trade_decision"],
|
||||
)
|
||||
|
||||
# Return decision and processed signal
|
||||
return final_state, self.process_signal(final_state["final_trade_decision"])
|
||||
|
||||
@@ -264,24 +342,6 @@ class TradingAgentsGraph:
|
||||
with open(log_path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.log_states_dict[str(trade_date)], f, indent=4)
|
||||
|
||||
def reflect_and_remember(self, returns_losses):
|
||||
"""Reflect on decisions and update memory based on returns."""
|
||||
self.reflector.reflect_bull_researcher(
|
||||
self.curr_state, returns_losses, self.bull_memory
|
||||
)
|
||||
self.reflector.reflect_bear_researcher(
|
||||
self.curr_state, returns_losses, self.bear_memory
|
||||
)
|
||||
self.reflector.reflect_trader(
|
||||
self.curr_state, returns_losses, self.trader_memory
|
||||
)
|
||||
self.reflector.reflect_invest_judge(
|
||||
self.curr_state, returns_losses, self.invest_judge_memory
|
||||
)
|
||||
self.reflector.reflect_portfolio_manager(
|
||||
self.curr_state, returns_losses, self.portfolio_manager_memory
|
||||
)
|
||||
|
||||
def process_signal(self, full_signal):
|
||||
"""Process a signal to extract the core decision."""
|
||||
return self.signal_processor.process_signal(full_signal)
|
||||
|
||||
Reference in New Issue
Block a user