import torch
import numpy as np

import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Define a simple feed-forward neural network
class FeedForwardNet(nn.Module):
    """Small fully connected regressor with two ReLU hidden layers.

    Architecture: ``Linear(in, hidden) -> ReLU -> Linear(hidden, hidden)
    -> ReLU -> Linear(hidden, out)``. The output layer has no activation, so
    predictions are unbounded real values. The defaults reproduce the
    original fixed 1 -> 16 -> 16 -> 1 network.

    Args:
        in_features: Number of input features per sample (default 1).
        hidden_features: Width of both hidden layers (default 16).
        out_features: Number of output values per sample (default 1).
    """

    def __init__(self, in_features: int = 1, hidden_features: int = 16,
                 out_features: int = 1) -> None:
        super().__init__()
        # Keep the layers in one Sequential under ``self.net`` so the
        # state_dict keys (net.0.*, net.2.*, net.4.*) match the original
        # definition and previously saved checkpoints still load.
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network; the last dimension of ``x`` must equal ``in_features``."""
        return self.net(x)

def visualize(net=None, x_points=None, y_points=None):
    """Plot a model's fit against the true function f(x) = x² on [0, 1].

    Draws the true curve and the model's prediction on 500 evenly spaced
    points, scatters the training samples, then calls ``plt.show()``.

    Args:
        net: Model to evaluate; called on a tensor of shape (500, 1).
            Defaults to the module-level ``model``.
        x_points: Training inputs to scatter. Defaults to module-level ``x_train``.
        y_points: Training targets to scatter. Defaults to module-level ``y_train``.
    """
    # Fall back to the script's globals so the original zero-argument call
    # keeps working. They are looked up at call time, after the script has
    # defined them, not when this function is defined.
    if net is None:
        net = model
    if x_points is None:
        x_points = x_train
    if y_points is None:
        y_points = y_train

    # Dense evaluation grid, shaped (500, 1) to match the single-feature input.
    x_test = torch.linspace(0, 1, 500).view(-1, 1)
    with torch.no_grad():  # inference only: no autograd graph needed
        y_pred = net(x_test)
    x_np = x_test.numpy()

    # Plot results
    plt.figure(figsize=(10, 6))
    plt.plot(x_np, x_np**2, 'b-', label='True function: f(x) = x²')
    plt.plot(x_np, y_pred.numpy(), 'r--', label='Neural network approximation')
    plt.scatter(x_points.numpy(), y_points.numpy(), c='g', alpha=0.4, label='Training points')
    plt.title('Neural Network Approximation of f(x) = x²')
    plt.xlabel('x')
    plt.ylabel('f(x)')
    plt.legend()
    plt.grid(True)
    plt.show()

# Generate training data: 100 evenly spaced samples of f(x) = x² on [0, 1],
# shaped (100, 1) so each row is one single-feature sample.
x_train = torch.linspace(0, 1, 100).view(-1, 1)
y_train = x_train ** 2

# Initialize model, loss function, and optimizer
model = FeedForwardNet()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Alternative optimizer for comparison:
#optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model. Each step is full-batch: all training points at once.
num_epochs = 10000
log_every = 50       # report and record the loss every `log_every` epochs
loss_history = []    # loss values recorded at each logging step
logged_epochs = []   # epoch numbers (1-based) matching loss_history entries
for epoch in range(num_epochs):
    # Forward pass and loss calculation
    y_pred = model(x_train)
    loss = criterion(y_pred, y_train)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % log_every == 0:
        # The reported loss is the one computed before this step's update.
        loss_value = loss.item()
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss_value:.6f}')
        loss_history.append(loss_value)
        logged_epochs.append(epoch + 1)
        #visualize()  # uncomment to plot the current fit at every log step

# Plot the loss history on log-log axes. The x values are the recorded
# epochs, so they always match loss_history in length.
plt.figure(figsize=(10, 6))
plt.loglog(logged_epochs, loss_history, 'b-')
plt.title('Training Loss over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.show()