#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Simple script to show statistics of a Neural Network training session.

The stats file is expected to contain one record per line with
comma-separated numeric values (learning rate, training accuracy,
evaluation accuracy). Usage:

    python3 this_script.py [path/to/stats.txt]
"""

# - Imports
import sys

from matplotlib import pyplot as plt

# Stats file used when no path is given on the command line.
DEFAULT_STATS_FILE = 'stats/stats_256_128.txt'


# Read the data from the file
def read_data(file_path: str) -> list:
    """Return the non-blank lines of *file_path*, stripped of whitespace.

    Args:
        file_path: Path of the text file to read (decoded as UTF-8).

    Returns:
        A list of stripped lines; blank lines are dropped.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return [stripped for line in file if (stripped := line.strip())]


# Parse the data from the file
def parse_data(data: list, numEntries: int) -> list[list]:
    """Split comma-separated lines into ``numEntries`` float columns.

    Each line must provide at least ``numEntries`` values convertible to
    ``float``; extra values are ignored. A line that fails to parse is
    reported on stderr and skipped as a whole, so every returned column
    always has the same length (rows stay aligned across columns).

    Args:
        data: Lines of text, e.g. as returned by ``read_data``.
        numEntries: Number of leading columns to extract per line
            (camelCase kept for backward compatibility with callers).

    Returns:
        A list of ``numEntries`` lists, one per column.
    """
    parsed_data = [[] for _ in range(numEntries)]
    for line in data:
        values = line.split(',')
        # Parse the full row before touching the columns; appending value by
        # value would leave columns with different lengths on a bad line.
        try:
            row = [float(values[i]) for i in range(numEntries)]
        except (ValueError, IndexError):
            print(f"Error parsing line: {line}", file=sys.stderr)
            continue
        for column, value in zip(parsed_data, row):
            column.append(value)
    return parsed_data


# Plot the data using matplotlib
def plot_data(
    data: list,
    title: str,
    xlabel: str,
    ylabel: str,
    log_scale: bool = False
) -> None:
    """Plot a single series in its own figure and show it.

    Args:
        data: Y values; the x axis is the sample index.
        title: Figure title, also used as the legend label.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        log_scale: If true, use a logarithmic y axis.
    """
    plt.figure(figsize=(10, 5))
    plt.plot(data, label=title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid()
    # The scale must be set before the (blocking) show call to take effect.
    if log_scale:
        plt.yscale('log')
    plt.show()


def plot_accuracies(
    learn_rates: list,
    train: list,
    eval: list
) -> None:
    """Show learning rate and accuracies per epoch side by side.

    Args:
        learn_rates: Learning rate for each epoch.
        train: Training accuracy for each epoch.
        eval: Evaluation accuracy for each epoch. (The name shadows the
            builtin ``eval``; it is kept for keyword-argument compatibility.)
    """
    # Two subplots: learning rate vs epochs, and train/eval accuracy vs epochs.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 10))

    # Plot the learning rates
    ax1.plot(learn_rates, label='Learning Rate')
    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Learning Rate')
    ax1.set_title('Learning Rate vs Epochs')
    ax1.legend()
    ax1.grid()

    # Plot the accuracies
    ax2.plot(train, label='Training Accuracy')
    ax2.plot(eval, label='Evaluation Accuracy')
    ax2.set_xlabel('Epochs')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Training and Evaluation Accuracy vs Epochs')
    ax2.legend()
    ax2.grid()

    # Show the plots
    plt.tight_layout()
    plt.show()


# Main function to execute the script
def main(file_path: str = DEFAULT_STATS_FILE) -> None:
    """Read a stats file, parse its three columns and plot them.

    Args:
        file_path: Stats file to load; defaults to ``DEFAULT_STATS_FILE``.
    """
    data = read_data(file_path)
    learn_rates, data_accuracy, eval_accuracy = parse_data(data, 3)
    plot_accuracies(learn_rates, data_accuracy, eval_accuracy)


if __name__ == "__main__":
    # Optional first CLI argument overrides the default stats file.
    main(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_STATS_FILE)

# - End of script