mirror of
https://github.com/farcasclaudiu/TradingAgents.git
synced 2026-06-29 07:01:17 +03:00
feat: add LangGraph checkpoint resume for crash recovery (#594)
Long analyses can take many minutes; a crash or interruption forced users to re-run from scratch and re-pay every LLM call. This adds an opt-in checkpoint layer backed by per-ticker SQLite databases so the graph resumes from the last successful node. How to use: - CLI: tradingagents analyze --checkpoint - CLI: tradingagents analyze --clear-checkpoints - Python: config["checkpoint_enabled"] = True Lifecycle: - propagate() recompiles the graph with a SqliteSaver when enabled and injects a deterministic thread_id derived from ticker+date so the same ticker+date resumes while a different date starts fresh. - On successful completion the per-thread checkpoint rows are cleared. - The context manager is closed in a try/finally so a crash never leaks the SQLite connection or leaves the graph in checkpoint mode. Storage: ~/.tradingagents/cache/checkpoints/<TICKER>.db (override via TRADINGAGENTS_CACHE_DIR). The checkpointer module is new (tradingagents/graph/checkpointer.py) and the GraphSetup now returns the uncompiled workflow so it can be recompiled with a saver when needed. Adds langgraph-checkpoint-sqlite>=2.0.0 dependency. 3 new tests verify the crash/resume cycle and that a different date starts fresh.
This commit is contained in:
@@ -38,6 +38,7 @@ from tradingagents.agents.utils.agent_utils import (
|
||||
get_global_news
|
||||
)
|
||||
|
||||
from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
|
||||
from .conditional_logic import ConditionalLogic
|
||||
from .setup import GraphSetup
|
||||
from .propagation import Propagator
|
||||
@@ -123,8 +124,10 @@ class TradingAgentsGraph:
|
||||
self.ticker = None
|
||||
self.log_states_dict = {} # date to full state dict
|
||||
|
||||
# Set up the graph
|
||||
self.graph = self.graph_setup.setup_graph(selected_analysts)
|
||||
# Set up the graph: keep the workflow for recompilation with a checkpointer.
|
||||
self.workflow = self.graph_setup.setup_graph(selected_analysts)
|
||||
self.graph = self.workflow.compile()
|
||||
self._checkpointer_ctx = None
|
||||
|
||||
def _get_provider_kwargs(self) -> Dict[str, Any]:
|
||||
"""Get provider-specific kwargs for LLM client creation."""
|
||||
@@ -259,23 +262,58 @@ class TradingAgentsGraph:
|
||||
self.memory_log.batch_update_with_outcomes(updates)
|
||||
|
||||
def propagate(self, company_name, trade_date):
|
||||
"""Run the trading agents graph for a company on a specific date."""
|
||||
"""Run the trading agents graph for a company on a specific date.
|
||||
|
||||
When ``checkpoint_enabled`` is set in config, the graph is recompiled
|
||||
with a per-ticker SqliteSaver so a crashed run can resume from the last
|
||||
successful node on a subsequent invocation with the same ticker+date.
|
||||
"""
|
||||
self.ticker = company_name
|
||||
|
||||
# Resolve any pending log entries for this ticker before the pipeline runs.
|
||||
# This adds the outcome + reflection from the previous run at zero latency cost.
|
||||
# Resolve any pending memory-log entries for this ticker before the pipeline runs.
|
||||
self._resolve_pending_entries(company_name)
|
||||
|
||||
# Initialize state — inject memory log context for PM
|
||||
# Recompile with a checkpointer if the user opted in.
|
||||
if self.config.get("checkpoint_enabled"):
|
||||
self._checkpointer_ctx = get_checkpointer(
|
||||
self.config["data_cache_dir"], company_name
|
||||
)
|
||||
saver = self._checkpointer_ctx.__enter__()
|
||||
self.graph = self.workflow.compile(checkpointer=saver)
|
||||
|
||||
step = checkpoint_step(
|
||||
self.config["data_cache_dir"], company_name, str(trade_date)
|
||||
)
|
||||
if step is not None:
|
||||
logger.info(
|
||||
"Resuming from step %d for %s on %s", step, company_name, trade_date
|
||||
)
|
||||
else:
|
||||
logger.info("Starting fresh for %s on %s", company_name, trade_date)
|
||||
|
||||
try:
|
||||
return self._run_graph(company_name, trade_date)
|
||||
finally:
|
||||
if self._checkpointer_ctx is not None:
|
||||
self._checkpointer_ctx.__exit__(None, None, None)
|
||||
self._checkpointer_ctx = None
|
||||
self.graph = self.workflow.compile()
|
||||
|
||||
def _run_graph(self, company_name, trade_date):
|
||||
"""Execute the graph and write the resulting state to disk and memory log."""
|
||||
# Initialize state — inject memory log context for PM.
|
||||
past_context = self.memory_log.get_past_context(company_name)
|
||||
init_agent_state = self.propagator.create_initial_state(
|
||||
company_name, trade_date, past_context=past_context
|
||||
)
|
||||
args = self.propagator.get_graph_args()
|
||||
|
||||
# Inject thread_id so same ticker+date resumes, different date starts fresh.
|
||||
if self.config.get("checkpoint_enabled"):
|
||||
tid = thread_id(company_name, str(trade_date))
|
||||
args.setdefault("config", {}).setdefault("configurable", {})["thread_id"] = tid
|
||||
|
||||
if self.debug:
|
||||
# Debug mode with tracing
|
||||
trace = []
|
||||
for chunk in self.graph.stream(init_agent_state, **args):
|
||||
if len(chunk["messages"]) == 0:
|
||||
@@ -283,26 +321,29 @@ class TradingAgentsGraph:
|
||||
else:
|
||||
chunk["messages"][-1].pretty_print()
|
||||
trace.append(chunk)
|
||||
|
||||
final_state = trace[-1]
|
||||
else:
|
||||
# Standard mode without tracing
|
||||
final_state = self.graph.invoke(init_agent_state, **args)
|
||||
|
||||
# Store current state for reflection
|
||||
# Store current state for reflection.
|
||||
self.curr_state = final_state
|
||||
|
||||
# Log state
|
||||
# Log state to disk.
|
||||
self._log_state(trade_date, final_state)
|
||||
|
||||
# Store decision for deferred reflection.
|
||||
# Store decision for deferred reflection on the next same-ticker run.
|
||||
self.memory_log.store_decision(
|
||||
ticker=company_name,
|
||||
trade_date=trade_date,
|
||||
final_trade_decision=final_state["final_trade_decision"],
|
||||
)
|
||||
|
||||
# Return decision and processed signal
|
||||
# Clear checkpoint on successful completion to avoid stale state.
|
||||
if self.config.get("checkpoint_enabled"):
|
||||
clear_checkpoint(
|
||||
self.config["data_cache_dir"], company_name, str(trade_date)
|
||||
)
|
||||
|
||||
return final_state, self.process_signal(final_state["final_trade_decision"])
|
||||
|
||||
def _log_state(self, trade_date, final_state):
|
||||
|
||||
Reference in New Issue
Block a user