import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')


class CustomOUProcess:
    """Zero-mean Ornstein–Uhlenbeck process simulated with Euler–Maruyama.

    Paths follow ``dX_t = -theta * X_t dt + sigma * dW_t`` on ``[0, T]``,
    i.e. they revert towards 0 at rate ``theta``.

    Args:
        theta: Mean-reversion rate.
        sigma: Diffusion coefficient multiplying the Brownian increments.
        initial: Common starting value, used by :meth:`simulate` only when
            ``sigma_multiple`` is ``None`` (by default starting values are
            drawn at random).
        T: Time horizon; simulated paths end at ``t = T``.
        rng: ``numpy.random.Generator`` used for every random draw; a fresh
            ``np.random.default_rng()`` is created when omitted.
    """

    def __init__(self, theta=1.0, sigma=1.0, initial=1.0, T=1.0, rng=None):
        self.theta = theta
        self.sigma = sigma
        self.initial = initial
        self.T = T
        self.rng = rng if rng is not None else np.random.default_rng()

    def simulate(self, n_steps, n_paths, sigma_multiple=30):
        """Simulate ``n_paths`` independent paths on an ``n_steps``-point grid.

        Args:
            n_steps: Number of time points, including ``t = 0`` and ``t = T``.
            n_paths: Number of independent paths.
            sigma_multiple: Starting values are drawn uniformly from
                ``[-sigma_multiple * sigma, sigma_multiple * sigma)``; the
                wide default spread makes the mean reversion clearly
                visible. Pass ``None`` to start every path at
                ``self.initial`` instead.

        Returns:
            ``(paths, times)``: ``paths`` has shape ``(n_paths, n_steps)``
            and ``times`` has shape ``(n_steps,)``, running from 0 to ``T``.

        Raises:
            ValueError: If ``n_steps < 1``.
        """
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")

        times = np.linspace(0, self.T, n_steps)
        # n_steps grid points span n_steps - 1 intervals. Using T / n_steps
        # would make the final sample sit at T - dt while plotted at T.
        dt = self.T / (n_steps - 1) if n_steps > 1 else 0.0
        sqrt_dt = np.sqrt(dt)

        if sigma_multiple is None:
            initial_values = np.full(n_paths, self.initial, dtype=float)
        else:
            half_width = sigma_multiple * self.sigma
            initial_values = self.rng.uniform(
                low=-half_width,
                high=half_width,
                size=n_paths,
            )

        # Every column is written below, so no zero-initialisation is needed.
        paths = np.empty((n_paths, n_steps))
        paths[:, 0] = initial_values

        # Euler–Maruyama: X_{i} = X_{i-1} - theta * X_{i-1} * dt + sigma * dW.
        for i in range(1, n_steps):
            dW = self.rng.normal(scale=sqrt_dt, size=n_paths)
            prev = paths[:, i - 1]
            paths[:, i] = prev - self.theta * prev * dt + self.sigma * dW

        return paths, times

    def draw(self, n_steps, n_paths, style=None, title="Ornstein–Uhlenbeck Process",
             filename="output.png"):
        """Simulate paths and save a plot of them to ``filename``.

        Each path is coloured by its terminal value ``X_T`` (plasma colormap,
        with a matching colorbar). The dashed line is the empirical mean
        across the simulated paths at each time point.

        Args:
            n_steps: Number of time points per path (see :meth:`simulate`).
            n_paths: Number of paths to simulate and plot; must be >= 1.
            style: Optional matplotlib style name/spec, applied only while
                this figure is built (global rcParams are left untouched).
            title: Figure title.
            filename: Path the figure is written to.

        Raises:
            ValueError: If ``n_paths < 1`` or ``n_steps < 1``.
        """
        import contextlib

        if n_paths < 1:
            raise ValueError(f"n_paths must be >= 1, got {n_paths}")

        paths, times = self.simulate(n_steps, n_paths)

        # Empirical mean over paths at each time step.
        expectations = np.mean(paths, axis=0)
        terminal_values = paths[:, -1]

        # Map each path's terminal value to a colour.
        norm = plt.Normalize(terminal_values.min(), terminal_values.max())
        colors = matplotlib.cm.plasma(norm(terminal_values))

        # Scope the style to this figure instead of mutating global state.
        style_ctx = plt.style.context(style) if style else contextlib.nullcontext()
        with style_ctx:
            fig, ax = plt.subplots(figsize=(12, 7))
            try:
                for path, color in zip(paths, colors):
                    ax.plot(times, path, color=color, alpha=0.6)

                ax.plot(times, expectations, 'b--', label='$E[X_t]$', linewidth=2)

                ax.set_title(title, fontsize=22)
                ax.set_xlabel('$t$', fontsize=18)
                ax.set_ylabel(r'$X_t$', fontsize=18)
                ax.tick_params(axis='both', which='major', labelsize=14)
                ax.legend(fontsize=16)

                # Colorbar keyed to the terminal values used for path colours.
                sm = plt.cm.ScalarMappable(cmap='plasma', norm=norm)
                sm.set_array([])
                cbar = fig.colorbar(sm, ax=ax, pad=0.02)
                cbar.set_label('$X_T$', rotation=270, labelpad=20, fontsize=18)
                cbar.ax.tick_params(labelsize=14)

                fig.savefig(filename)
            finally:
                # Always release the figure, even if plotting/saving fails.
                plt.close(fig)

