#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Simple script to plot statistics from a neural network training session.

# - Imports
from matplotlib import pyplot as plt


def read_data(file_path: str) -> list[str]:
    """Read a text file and return its non-blank lines.

    Each returned line has its leading/trailing whitespace (including the
    line terminator) removed; lines that are empty after stripping are
    dropped.

    Args:
        file_path: Path of the statistics file to read.

    Returns:
        The stripped, non-empty lines, in file order.

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(file_path, "r", encoding="utf-8") as file:
        # Iterate the file lazily rather than materializing readlines(),
        # and strip each line only once.
        stripped = (line.strip() for line in file)
        return [line for line in stripped if line]


def parse_data(data: list, numEntries: int) -> list[list]:
    """Split comma-separated lines into per-column lists of floats.

    Only the first ``numEntries`` comma-separated fields of each line are
    used; any extra fields are ignored. A line is accepted only if all of
    those fields parse as floats, so every returned column has the same
    length and entry ``k`` of one column lines up with entry ``k`` of every
    other column. A line that is too short or contains a non-numeric field
    (e.g. a header row) is reported once and skipped entirely.

    Args:
        data: Lines to parse, e.g. the output of ``read_data``.
        numEntries: Number of leading columns to extract from each line.

    Returns:
        A list of ``numEntries`` lists; list ``i`` holds column ``i`` of
        every accepted line, in input order.
    """
    parsed_data = [[] for _ in range(numEntries)]

    for line in data:
        values = line.split(',')
        try:
            # Convert the whole row before storing anything, so a bad field
            # cannot leave the columns with different lengths.
            row = [float(values[i]) for i in range(numEntries)]
        except (ValueError, IndexError):
            print(f"Error parsing line: {line}")
            continue

        for column, value in zip(parsed_data, row):
            column.append(value)

    return parsed_data


def plot_data(
    data: list,
    title: str,
    xlabel: str,
    ylabel: str,
    log_scale: bool = False
) -> None:
    """Show ``data`` as a single line chart in a new 10x5 inch figure.

    The series is plotted against its index and labelled with ``title``,
    which is also used as the chart title. The chart gets axis labels, a
    legend and a grid, and ``plt.show()`` is called at the end.

    Args:
        data: Sequence of y-values to plot.
        title: Chart title and legend label for the series.
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
        log_scale: If True, use a logarithmic y axis.
    """
    _fig, axes = plt.subplots(figsize=(10, 5))
    axes.plot(data, label=title)

    # Labels, title and decorations on the single Axes.
    axes.set(title=title, xlabel=xlabel, ylabel=ylabel)
    axes.legend()
    axes.grid()

    if log_scale:
        axes.set_yscale('log')

    plt.show()


def plot_accuracies(
    learn_rates: list,
    train: list,
    eval: list
) -> None:
    """Show learning-rate and accuracy curves side by side.

    Draws a 1x2 grid of subplots in a 10x10 inch figure. The left panel
    plots ``learn_rates``; the right panel overlays ``train`` and ``eval``.
    Each series is plotted against its index, labelled "Epochs" on the
    x axis. Both panels get a title, legend and grid, the layout is
    tightened, and ``plt.show()`` is called.

    Args:
        learn_rates: Learning-rate values (presumably one per epoch).
        train: Training accuracy values (presumably one per epoch).
        eval: Evaluation accuracy values (presumably one per epoch).
            NOTE: this name shadows the builtin ``eval`` inside the
            function; it is kept for compatibility with keyword callers.
    """
    _fig, (lr_axes, acc_axes) = plt.subplots(1, 2, figsize=(10, 10))

    # Left panel: learning-rate schedule.
    lr_axes.plot(learn_rates, label='Learning Rate')

    # Right panel: training and evaluation accuracy on shared axes.
    acc_axes.plot(train, label='Training Accuracy')
    acc_axes.plot(eval, label='Evaluation Accuracy')

    # Both panels share the x-axis label and the same decorations.
    panels = (
        (lr_axes, 'Learning Rate', 'Learning Rate vs Epochs'),
        (acc_axes, 'Accuracy', 'Training and Evaluation Accuracy vs Epochs'),
    )
    for axes, y_label, panel_title in panels:
        axes.set_xlabel('Epochs')
        axes.set_ylabel(y_label)
        axes.set_title(panel_title)
        axes.legend()
        axes.grid()

    plt.tight_layout()
    plt.show()


def main(file_path: str = 'stats/stats_256_128.txt') -> None:
    """Plot learning-rate and accuracy curves from a training stats file.

    Args:
        file_path: Stats file to read, relative to the current working
            directory unless absolute. Defaults to
            ``'stats/stats_256_128.txt'``, so calling ``main()`` with no
            arguments reads that file. Each line is parsed into three
            comma-separated columns, used in order as learning rate,
            training accuracy and evaluation accuracy.
    """
    # Load the non-blank lines of the file.
    lines = read_data(file_path)

    # Split them into the three columns that are plotted below.
    learn_rates, data_accuracy, eval_accuracy = parse_data(lines, 3)

    plot_accuracies(learn_rates, data_accuracy, eval_accuracy)


if __name__ == "__main__":
    # Only run the plotting entry point when executed as a script,
    # not when this module is imported.
    main()


# - End of script