import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from ouprocess import CustomOUProcess

T = 1.0
n_steps = 500
n_paths_per_setting = 5000

# Four (θ, σ²) combinations: fast vs. slow mean reversion, high vs. low diffusion variance
parameter_sets = [
    (2.0, 1.0),   # fast mean reversion
    (0.2, 1.0),   # slow mean reversion
    (0.5, 4.0),   # high diffusion variance
    (0.5, 0.25),  # low diffusion variance
]
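
# The labels assume CustomOUProcess simulates the standard zero-mean
# Ornstein-Uhlenbeck SDE dX_t = -θ X_t dt + σ dW_t, so each path is fully
# characterized by the (θ, σ²) pair the network learns to recover.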

all_paths = []
all_labels = []

for theta, sigma_squared in parameter_sets:
    sigma = np.sqrt(sigma_squared)  # the simulator takes σ; the labels store σ²
    ou_process = CustomOUProcess(theta=theta, sigma=sigma, T=T)
    paths, _ = ou_process.simulate(n_steps=n_steps, n_paths=n_paths_per_setting)
    all_paths.append(paths)
    # One (θ, σ²) target per simulated path
    labels = np.array([[theta, sigma_squared]] * n_paths_per_setting, dtype=np.float32)
    all_labels.append(labels)

# Stack into the (n_samples, seq_len, features) layout that a batch_first LSTM expects
X_data = np.vstack(all_paths).astype(np.float32).reshape(-1, n_steps, 1)
y_data = np.vstack(all_labels).astype(np.float32)
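
# Resulting shapes: X_data is (4 × 5000, 500, 1) = (20000, 500, 1),
# y_data is (20000, 2)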

# Normalize inputs with global mean/std; the statistics are saved with the
# checkpoint below so inference can apply the identical transform
X_mean = X_data.mean()
X_std = X_data.std()
X_data = (X_data - X_mean) / X_std

X_tensor = torch.tensor(X_data)
y_tensor = torch.tensor(y_data)

# 80/20 training/validation split (fixed seed for reproducibility)
X_train, X_val, y_train, y_val = train_test_split(X_tensor, y_tensor, test_size=0.2, random_state=42)
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128)


# Model definition
class OUParameterLSTM(nn.Module):
    """LSTM regressor mapping a simulated OU path to its (θ, σ²) parameters."""

    def __init__(self, input_size=1, hidden_size=64, num_layers=2, output_size=2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.elu = nn.ELU()
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # hn[-1] is the final hidden state of the top LSTM layer: a fixed-size
        # summary of the whole path, passed through ELU before the linear head
        _, (hn, _) = self.lstm(x)
        x = self.elu(hn[-1])
        return self.fc(x)
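
# Quick shape check (a minimal sketch, not part of training): a dummy batch of
# four flat paths should map to one (θ, σ²) prediction per path.
with torch.no_grad():
    assert OUParameterLSTM()(torch.zeros(4, n_steps, 1)).shape == (4, 2)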


model = OUParameterLSTM()

# Loss function: per-target Huber losses, with θ weighted more heavily than σ²
theta_weight = 1.0
sigma2_weight = 0.5
criterion = nn.HuberLoss(delta=1.0)


def weighted_huber_loss(pred, target):
    loss_theta = criterion(pred[:, 0], target[:, 0])  # θ
    loss_sigma2 = criterion(pred[:, 1], target[:, 1])  # σ²
    return theta_weight * loss_theta + sigma2_weight * loss_sigma2
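
# With these weights, an equal-sized error in θ contributes twice as much loss
# as the same error in σ², steering the network toward recovering the
# mean-reversion rate accurately.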


optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 100

with open("ou_parameter_lstm_loss_log.txt", "w") as log_file:
    print("Training started...")
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        for batch_X, batch_y in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", leave=False):
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = weighted_huber_loss(outputs, batch_y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # guard against exploding gradients
            optimizer.step()
            epoch_loss += loss.item()
        avg_train_loss = epoch_loss / len(train_loader)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch_X, batch_y in val_loader:
                outputs = model(batch_X)
                loss = weighted_huber_loss(outputs, batch_y)
                val_loss += loss.item()
        avg_val_loss = val_loss / len(val_loader)

        log_line = f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}\n"
        print(log_line.strip())
        log_file.write(log_line)

print("Training complete.")

# Save the weights together with the normalization statistics needed at
# inference; plain Python floats keep the checkpoint portable
torch.save({
    'model_state_dict': model.state_dict(),
    'x_mean': float(X_mean),
    'x_std': float(X_std)
}, 'ou_parameter_lstm.pth')
print("Model saved to 'ou_parameter_lstm.pth'")
