Files
Perceptron_Pizz2/stats/stats.py
2026-03-29 20:21:25 +02:00

92 lines
2.6 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Simple script to plot statistics from a neural network training session.
# - Imports
from matplotlib import pyplot as plt
# Read the data from the file
def read_data(file_path: str) -> list:
    """Return the non-blank lines of *file_path*, stripped of surrounding whitespace.

    Args:
        file_path: Path to a text file, decoded as UTF-8.

    Returns:
        A list of strings, one per non-empty line, in file order.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        return [stripped for line in file if (stripped := line.strip())]
# Parse the data from the file
def parse_data(data: list, numEntries: int) -> list[list]:
    """Split comma-separated lines into per-column lists of floats.

    Each line must provide at least *numEntries* comma-separated numeric
    fields; fields beyond the first *numEntries* are ignored.

    Args:
        data: Lines of text, e.g. as returned by ``read_data``.
        numEntries: Number of leading columns to extract.

    Returns:
        A list of *numEntries* lists. Element ``i`` holds column ``i`` from
        every successfully parsed line, so all columns have equal length and
        stay row-aligned.

    A line with too few fields or a non-numeric field is reported once via
    ``print`` and skipped entirely.
    """
    parsed_data = [[] for _ in range(numEntries)]
    for line in data:
        values = line.split(',')
        # Parse the whole row before appending anything. Appending field by
        # field would leave a malformed line partially recorded and shift
        # the columns out of alignment.
        try:
            row = [float(values[i]) for i in range(numEntries)]
        except (ValueError, IndexError):
            print(f"Error parsing line: {line}")
            continue
        for column, value in zip(parsed_data, row):
            column.append(value)
    return parsed_data
# Plot the data using matplotlib
def plot_data(
    data: list,
    title: str,
    xlabel: str,
    ylabel: str,
    log_scale: bool = False
) -> None:
    """Plot a single series in a new figure and display it.

    Args:
        data: Y values, plotted against their index on the x-axis.
        title: Figure title, also used as the legend label.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        log_scale: If True, use a logarithmic y-axis.
    """
    plt.figure(figsize=(10, 5))
    plt.plot(data, label=title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    # Pass True explicitly. A bare grid() toggles visibility, so it would
    # hide the grid under styles that already enable it (axes.grid=True).
    plt.grid(True)
    if log_scale:
        plt.yscale('log')
    plt.show()
def plot_accuracies(
    learn_rates: list,
    train: list,
    eval: list
) -> None:
    """Show the learning rate and the accuracies side by side and display them.

    Each series is plotted against its index, which the axes label as the
    epoch number.

    Args:
        learn_rates: Learning rate per epoch (left subplot).
        train: Training accuracy per epoch (right subplot).
        eval: Evaluation accuracy per epoch (right subplot). The name shadows
            the builtin ``eval`` but is kept for keyword-argument callers.
    """
    # One row of two subplots: learning rate on the left, accuracies on the right.
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 10))

    ax1.plot(learn_rates, label='Learning Rate')
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Learning Rate')
    ax1.set_title('Learning Rate vs Epochs')
    ax1.legend()
    # grid(True) rather than grid(): the no-argument form toggles visibility.
    ax1.grid(True)

    ax2.plot(train, label='Training Accuracy')
    ax2.plot(eval, label='Evaluation Accuracy')
    ax2.set_xlabel('Epochs')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Training and Evaluation Accuracy vs Epochs')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.show()
# Main function to execute the script
def main(file_path: str = 'stats/stats_256_128.txt') -> None:
    """Load a training-stats file and plot its learning rate and accuracies.

    The file is expected to contain three comma-separated columns per line:
    learning rate, training accuracy and evaluation accuracy.

    Args:
        file_path: Path to the stats file. A relative path is resolved against
            the current working directory, so the default presumably assumes
            the script is launched from the project root (the directory
            containing ``stats/``).
    """
    data = read_data(file_path)
    learn_rates, data_accuracy, eval_accuracy = parse_data(data, 3)
    plot_accuracies(learn_rates, data_accuracy, eval_accuracy)
# Entry point: run main() only when executed directly, not when imported.
if __name__ == '__main__':
    main()
# - End of script