from input_samples import rand_low_volatility
from scipy.optimize import basinhopping
from tqdm import tqdm
from scalene import scalene_profiler
import numpy as np

# Time grid: horizon T split into n_steps increments (presumably how rand_low_volatility was simulated — confirm).
T, n_steps = 1.0, 500
dt = T / n_steps  # spacing between consecutive observations, passed to every estimator


# Moment-based (GMM-style) initial estimate from the lag-1 autocorrelation
def gmm_initial_estimate(X, dt):
    """Return a moment-based starting point ``(theta, sigma_sq)`` for an OU fit.

    For a zero-mean Ornstein-Uhlenbeck process sampled every ``dt``, the lag-1
    autocorrelation is ``rho = exp(-theta * dt)`` and the stationary variance is
    ``sigma_sq / (2 * theta)``. Inverting these gives the two estimates below.

    Args:
        X: 1-D sequence of observations of a single path.
        dt: Sampling interval between consecutive observations.

    Returns:
        Tuple ``(theta_hat, sigma_sq_hat)``. ``theta_hat`` is floored at 0.5.
        A degenerate path (constant or too short) has an undefined
        autocorrelation; it also falls back to that floor instead of NaN.
    """
    X = np.asarray(X, dtype=float)
    # A constant path has zero lag variance, so corrcoef divides 0/0 and yields NaN.
    # Silence those warnings; the NaN is resolved by np.fmax below.
    with np.errstate(divide="ignore", invalid="ignore"):
        rho_hat = np.corrcoef(X[:-1], X[1:])[0, 1]
    # Keep log() finite: rho <= 0 would be undefined, rho -> 1 would give theta -> 0.
    rho_hat = np.clip(rho_hat, 1e-4, 0.999)
    # np.fmax ignores NaN; the built-in max(nan, 0.5) would return nan.
    # NOTE(review): the 0.5 floor equals the script's true_theta; confirm this is intended.
    theta_hat = np.fmax(-np.log(rho_hat) / dt, 0.5)
    sigma_sq_hat = 2 * theta_hat * np.var(X)
    return theta_hat, sigma_sq_hat


# Negative log-likelihood of the exact zero-mean OU transition density (minimized by the optimizer)
def log_likelihood(params, X, dt):
    """Return the *negative* log-likelihood of a zero-mean OU path.

    Uses the exact Gaussian transition density::

        X[k+1] | X[k] ~ N(X[k] * exp(-theta*dt),
                          sigma_sq / (2*theta) * (1 - exp(-2*theta*dt)))

    Args:
        params: Pair ``(theta, sigma_sq)``. ``theta`` must be non-zero. The
            caller's optimizer bounds keep it positive.
        X: 1-D sequence of observations of a single path.
        dt: Sampling interval between consecutive observations.

    Returns:
        The negated sum of the ``len(X) - 1`` transition log-densities. This is
        0.0 for a path with fewer than two observations.
    """
    theta, sigma_sq = params
    X = np.asarray(X, dtype=float)
    decay = np.exp(-theta * dt)
    # -expm1(-x) equals 1 - exp(-x) but stays accurate when theta*dt is small.
    # The 1e-6 floor keeps log() and the division finite.
    variance = np.maximum((sigma_sq / (2 * theta)) * -np.expm1(-2 * theta * dt), 1e-6)
    residuals = X[1:] - X[:-1] * decay
    # The log-variance term is the same for every transition.
    # Broadcasting adds it once per residual.
    loglik = -0.5 * np.sum(np.log(2 * np.pi * variance) + residuals ** 2 / variance)
    return -loglik


# MLE estimation: basin-hopping global search with L-BFGS-B local minimization
def estimate_parameters(X, dt, *, bounds=((0.01, 10), (0.1, 10)), niter=20):
    """Estimate OU parameters ``(theta, sigma_sq)`` for one path by maximum likelihood.

    The search starts from ``gmm_initial_estimate``. It runs SciPy basin-hopping,
    using bounded L-BFGS-B for each local minimization of ``log_likelihood``.
    Basin-hopping takes random steps, so repeated calls can return different
    results.

    Args:
        X: 1-D sequence of observations of a single path.
        dt: Sampling interval between consecutive observations.
        bounds: ``((theta_lo, theta_hi), (sigma_sq_lo, sigma_sq_hi))`` for L-BFGS-B.
        niter: Number of basin-hopping iterations.

    Returns:
        ``result.x`` (array ``[theta, sigma_sq]``) if the local minimization at
        the lowest minimum succeeded. Otherwise returns the tuple from
        ``gmm_initial_estimate``. That fallback is not clipped to ``bounds``.
    """
    X = np.asarray(X)
    theta_init, sigma_sq_init = gmm_initial_estimate(X, dt)
    init_params = np.array([theta_init, sigma_sq_init])
    result = basinhopping(
        log_likelihood,
        init_params,
        minimizer_kwargs={"args": (X, dt), "method": "L-BFGS-B", "bounds": list(bounds)},
        niter=niter,
    )
    # Read success from the documented local result; not every SciPy version
    # sets `success` on the basinhopping wrapper itself.
    if result.lowest_optimization_result.success:
        return result.x
    return theta_init, sigma_sq_init


theta_estimates = []
sigma_sq_estimates = []

# Profile only the per-path estimation loop with Scalene.
scalene_profiler.start()
try:
    # Iterating a 2-D array yields its rows, one path per iteration.
    # (Assumes rand_low_volatility is shaped (n_paths, n_obs); TODO confirm.)
    for X_path in tqdm(rand_low_volatility, desc="Processing Paths", unit="path"):
        theta_hat, sigma_sq_hat = estimate_parameters(X_path, dt)
        theta_estimates.append(theta_hat)
        sigma_sq_estimates.append(sigma_sq_hat)
finally:
    # Stop the profiler even if an estimation raises.
    scalene_profiler.stop()

print("\nProcessing complete.")

# NOTE(review): presumably the parameters rand_low_volatility was simulated with; confirm.
true_theta = 0.5
true_sigma_sq = .25


def _summary_stats(estimates, true_value):
    """Return ``(mean, median, sample std (ddof=1), RMSE vs true_value)`` of estimates.

    With fewer than two estimates the sample std is NaN (NumPy warns).
    """
    values = np.asarray(estimates, dtype=float)
    return (
        np.mean(values),
        np.median(values),
        np.std(values, ddof=1),
        np.sqrt(np.mean((values - true_value) ** 2)),
    )


# Compute statistics
theta_mean, theta_median, theta_std, theta_rmse = _summary_stats(theta_estimates, true_theta)
sigma_sq_mean, sigma_sq_median, sigma_sq_std, sigma_sq_rmse = _summary_stats(
    sigma_sq_estimates, true_sigma_sq
)

print("\nFinal Statistics:")
print(f"Theta - Mean: {theta_mean:.4f}, Median: {theta_median:.4f}, Std: {theta_std:.4f}, RMSE: {theta_rmse:.4f}")
print(
    f"Sigma^2 - Mean: {sigma_sq_mean:.4f}, Median: {sigma_sq_median:.4f}, Std: {sigma_sq_std:.4f}, RMSE: {sigma_sq_rmse:.4f}")