init lfs
This commit is contained in:
92
stats/stats.py
Normal file
92
stats/stats.py
Normal file
@@ -0,0 +1,92 @@
|
||||
#/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Simple script to show statistics a Neural Network training session.
|
||||
|
||||
# - Imports
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
# Read the data from the file
|
||||
def read_data(file_path: str) -> list:
|
||||
with open(file_path, 'r') as file:
|
||||
lines = file.readlines()
|
||||
return [line.strip() for line in lines if line.strip()]
|
||||
|
||||
# Parse the data from the file
|
||||
# This function takes the data read from the file and parses it.
|
||||
# Each line contains multiple values separated by commas.
|
||||
# The function extracts the values and returns them as a list of lists.
|
||||
def parse_data(data: list, numEntries: int) -> list[list]:
|
||||
parsed_data = [[] for _ in range(numEntries)]
|
||||
for line in data:
|
||||
values = line.split(',')
|
||||
for i in range(numEntries):
|
||||
try:
|
||||
parsed_data[i].append(float(values[i]))
|
||||
except (ValueError, IndexError):
|
||||
print(f"Error parsing line: {line}")
|
||||
continue
|
||||
return parsed_data
|
||||
|
||||
# Plot the data using matplotlib
|
||||
def plot_data(
|
||||
data: list,
|
||||
title: str,
|
||||
xlabel: str,
|
||||
ylabel: str,
|
||||
log_scale: bool = False
|
||||
) -> None:
|
||||
# Plot the data
|
||||
plt.figure(figsize=(10, 5))
|
||||
plt.plot(data, label=title)
|
||||
# Add labels and title
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel(ylabel)
|
||||
plt.title(title)
|
||||
plt.legend()
|
||||
plt.grid()
|
||||
# Show the plot
|
||||
if log_scale:
|
||||
plt.yscale('log')
|
||||
plt.show()
|
||||
|
||||
def plot_accuracies(
|
||||
learn_rates: list,
|
||||
train: list,
|
||||
eval: list
|
||||
) -> None:
|
||||
# Create 2 subplots, one for the learning rates in relation to the epochs
|
||||
# and one for the training and evaluation accuracies
|
||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 10))
|
||||
# Plot the learning rates
|
||||
ax1.plot(learn_rates, label='Learning Rate')
|
||||
ax1.set_xlabel('Epochs')
|
||||
ax1.set_ylabel('Learning Rate')
|
||||
ax1.set_title('Learning Rate vs Epochs')
|
||||
ax1.legend()
|
||||
ax1.grid()
|
||||
# Plot the accuracies
|
||||
ax2.plot(train, label='Training Accuracy')
|
||||
ax2.plot(eval, label='Evaluation Accuracy')
|
||||
ax2.set_xlabel('Epochs')
|
||||
ax2.set_ylabel('Accuracy')
|
||||
ax2.set_title('Training and Evaluation Accuracy vs Epochs')
|
||||
ax2.legend()
|
||||
ax2.grid()
|
||||
# Show the plots
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
# Main function to execute the script
|
||||
def main():
|
||||
# Read the data from the file
|
||||
data = read_data('stats/stats_256_128.txt')
|
||||
# Parse the data
|
||||
learn_rates, data_accuracy, eval_accuracy = parse_data(data, 3)
|
||||
# Plot the accuracies
|
||||
plot_accuracies(learn_rates, data_accuracy, eval_accuracy)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
# - End of script
|
||||
Reference in New Issue
Block a user