from datetime import datetime

import matplotlib.pyplot as plt
import pandas_ta as ta

from boatwright import Backtest, Strategy
from boatwright.Brokers import BacktestBroker
from boatwright.Data.OHLCV import CSVdatabase
from boatwright.Indicators import crossover, s_max
from boatwright.Orders import MarketOrder
from boatwright.Visualization import plot_backtest


class PandasTaExample(Strategy):
    """
    Example strategy using indicators from the pandas_ta library.

    Three independent signal families are computed:

    * Bollinger Bands: buy when the close crosses over the lower band,
      sell when the close crosses over the upper band.
    * RSI: buy when the RSI crosses over 30, sell when it crosses over 70.
    * Stochastic oscillator: buy when %K crosses over %D, sell when %D
      crosses over %K.

    A buy (sell) trigger is set when at least two of the three families
    produced a buy (sell) signal within the last ``trigger_period`` bars.
    The strategy tracks a single open/closed position flag.

    Parameters
    ----------
    period : int
        Look-back length shared by the Bollinger Bands, the RSI and the
        stochastic %K window; the stochastic smoothings are derived from it.
    std_devs : float
        Bollinger Band width in standard deviations (fractional values are
        honoured).
    trigger_period : int
        Number of bars over which individual signals are combined into a
        trigger.
    symbol : str
        Traded symbol, forwarded to ``Strategy``.
    strategy_id : str
        Identifier forwarded to ``Strategy``.
    """

    def __init__(self, period: int = 7, std_devs: float = 2, trigger_period: int = 10, symbol: str = "BTC", strategy_id="pandas_ta"):
        super().__init__(symbol, strategy_id)
        self.p["period"] = period
        self.p["std_devs"] = std_devs
        self.p["trigger_period"] = trigger_period
        # True after step() has placed a BUY and not yet the matching SELL.
        self.position = False

    def _stoch_d(self):
        """Return the stochastic %D smoothing length: half of ``period``, at least 1."""
        # A zero-length window is meaningless; clamp so the value passed to
        # pandas_ta (and used for the warm-up length) is always valid.
        return max(1, int(self.p["period"] / 2))

    def _stoch_smooth_k(self):
        """Return the stochastic %K smoothing length: a quarter of ``period``, at least 1."""
        return max(1, int(self.p["period"] / 4))

    def calculate_signals(self, df):
        """
        Add indicator, signal and trigger columns to ``df`` and return it.

        ``df`` is modified in place and must contain ``high``, ``low`` and
        ``close`` columns. Columns added:

        * ``lower``, ``avg``, ``upper``, ``bband_buy``, ``bband_sell``
        * ``rsi``, ``rsi_buy``, ``rsi_sell``
        * ``stoch_k``, ``stoch_d``, ``stoch_buy``, ``stoch_sell``
        * ``buy_trigger``, ``sell_trigger``
        """
        period = self.p["period"]

        # Bollinger Bands. std is passed as float so a width such as 2.5 is
        # not silently truncated to 2.
        bbands = ta.bbands(df["close"], length=period, std=float(self.p["std_devs"]))
        # The first three result columns are taken as lower / middle / upper band.
        df["lower"], df["avg"], df["upper"] = bbands.iloc[:, 0], bbands.iloc[:, 1], bbands.iloc[:, 2]
        df["bband_buy"] = crossover(df["close"], df["lower"])
        df["bband_sell"] = crossover(df["close"], df["upper"])

        # RSI: the 30/70 thresholds apply to the RSI value, not to the price.
        df["rsi"] = ta.rsi(df["close"], length=period)
        df["rsi_buy"] = crossover(df["rsi"], 30)
        df["rsi_sell"] = crossover(df["rsi"], 70)

        # Stochastic oscillator; the first two result columns are used as %K and %D.
        stoch = ta.stoch(
            df["high"],
            df["low"],
            df["close"],
            k=period,
            d=self._stoch_d(),
            smooth_k=self._stoch_smooth_k(),
        )
        df["stoch_k"], df["stoch_d"] = stoch.iloc[:, 0], stoch.iloc[:, 1]
        df["stoch_buy"] = crossover(df["stoch_k"], df["stoch_d"])
        df["stoch_sell"] = crossover(df["stoch_d"], df["stoch_k"])

        # s_max is presumably a rolling max over `window` bars, so each term is
        # 1 if that family signalled recently; require agreement of `min_votes`.
        window = self.p["trigger_period"]
        min_votes = 2
        buy_votes = s_max(df["bband_buy"], window) + s_max(df["rsi_buy"], window) + s_max(df["stoch_buy"], window)
        sell_votes = s_max(df["bband_sell"], window) + s_max(df["rsi_sell"], window) + s_max(df["stoch_sell"], window)
        df["buy_trigger"] = buy_votes >= min_votes
        df["sell_trigger"] = sell_votes >= min_votes

        return df

    def calc_prerequisite_data_length(self):
        """
        Return how many bars of history are needed before the backtest start.

        This is a conservative upper bound. It covers:

        * the stochastic warm-up: the %K window, then the %K and %D
          smoothing windows;
        * the extra bar the crossover comparison looks back at;
        * the ``trigger_period`` window used to combine signals.

        It is never smaller than ``period``, so the Bollinger Band and RSI
        windows are covered as well.
        """
        return self.p["period"] + self._stoch_smooth_k() + self._stoch_d() + self.p["trigger_period"]

    def step(self, row):
        """
        Act on the ``buy_trigger`` / ``sell_trigger`` values of one row.

        While flat, a buy trigger places a BUY market order. While holding, a
        sell trigger places a SELL market order. At most one order is placed
        per row, so a bar where both triggers are set cannot open and
        immediately close a position.
        """
        buy_trigger = row["buy_trigger"]
        sell_trigger = row["sell_trigger"]

        if buy_trigger and not self.position:
            # NOTE(review): frac is presumably the fraction of available
            # quote balance to spend; 0.95 leaves headroom — confirm in MarketOrder.
            buy_order = MarketOrder(symbol=self.symbol, side="BUY", frac=0.95)
            self.broker.place_order(buy_order)
            self.position = True
        elif sell_trigger and self.position:
            sell_order = MarketOrder(symbol=self.symbol, side="SELL", frac=1)
            self.broker.place_order(sell_order)
            self.position = False

    def plot_info(self):
        """
        Return plotting options keyed by the column names added in ``calculate_signals``.

        Each value holds the matplotlib style (color, linestyle, label) for
        that column. ``ax_id`` appears to select which figure/axes the column
        is drawn on; these ids match the ``figs`` keys used in the
        ``__main__`` block.
        """
        return {
            "avg": {"ax_id": "price", "color": "C0", "linestyle": "-", "label": "bbands"},
            "lower": {"ax_id": "price", "color": "C0", "linestyle": "--", "label": None},
            "upper": {"ax_id": "price", "color": "C0", "linestyle": "--", "label": None},
            "rsi": {"ax_id": "rsi", "color": "black"},
            "stoch_k": {"ax_id": "stoch", "color": "C0"},
            "stoch_d": {"ax_id": "stoch", "color": "C1"},
            "buy_trigger": {"ax_id": "triggers", "color": "lime"},
            "sell_trigger": {"ax_id": "triggers", "color": "red"},
        }


if __name__ == "__main__":
    from pathlib import Path

    # Example run: minute bars of BTC for most of February 2025.
    strategy = PandasTaExample(period=200, std_devs=2, trigger_period=20, symbol="BTC")

    database = CSVdatabase(source="ALPACA", debug=False, dir="quickstart_data/")
    start = datetime(year=2025, month=2, day=1, hour=1, minute=0)
    end = datetime(year=2025, month=2, day=20, hour=12, minute=0)
    data = database.load(
        symbol=strategy.symbol,
        start=start,
        end=end,
        prerequisite_data_length=strategy.calc_prerequisite_data_length(),
        granularity=1,
        granularity_unit="MINUTE",
        verbose=True,
    )

    # Frictionless broker: no fees and no slippage.
    broker_model = BacktestBroker(taker_fee=0, maker_fee=0, slippage=0, quote_symbol="USD")
    backtest = Backtest(strategy=strategy, data=data, broker=broker_model, debug=False)
    backtest.run(verbose=True)

    figs, _ = plot_backtest(backtest, candles=False)

    # Save the figures used in the documentation. savefig does not create
    # missing directories, so ensure the output folder exists first.
    image_dir = Path("../docs/examples/images/pandas_ta")
    image_dir.mkdir(parents=True, exist_ok=True)
    for fig_name in ("aum", "price", "rsi", "stoch", "triggers"):
        figs[fig_name].savefig(image_dir / f"{fig_name}.png")

    plt.show()
