Files
Perceptron_Pizz2/src/main/java/fr/perceptron/MnistWindow.java
2026-03-29 20:21:25 +02:00

330 lines
10 KiB
Java

package fr.perceptron;
import javax.swing.BorderFactory;
import javax.swing.BoxLayout;
import javax.swing.JButton;
import javax.swing.JComponent;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;
import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.Graphics;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.awt.event.MouseMotionAdapter;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
/**
 * Swing viewer for a digit-classification {@link NeuralNetwork} evaluated on a {@link DataSet}.
 *
 * <p>The window shows one 28x28 sample on a drawing canvas, one bar per digit holding the
 * network's normalized output, the expected and predicted digit, and the global accuracy over
 * the whole data set. The user can draw on the canvas to classify custom input, clear the
 * canvas, jump to a random sample, or jump to a random misclassified sample.
 *
 * <p>Like any Swing UI, instances should be created on the event dispatch thread (as
 * {@link #main(String[])} does).
 */
public class MnistWindow {
    /** Width and height of a sample image, in image pixels. */
    private static final int IMAGE_SIZE = 28;
    /** Number of output classes (digits 0-9). */
    private static final int DIGIT_COUNT = 10;
    /** On-screen size of one image pixel, in screen pixels. */
    private static final int PXSIZE = 20;

    private final DataSet data;
    private final NeuralNetwork nn;
    private final JFrame frame;
    private final DrawingPanel panel;
    private final Random random = new Random();
    private final BarComponent[] bars = new BarComponent[DIGIT_COUNT];
    private final JLabel expectedResult = new JLabel();
    private final JLabel predictedResult = new JLabel();
    /** Indices into {@code data.points} that the network misclassifies; filled by the constructor. */
    private final List<Integer> misclassified = new ArrayList<>();
    private int currentPointIndex = 0;

    /**
     * Builds and shows the viewer window, initially displaying the first data point.
     *
     * <p>The network is run once on every point of {@code data} to compute the global accuracy
     * and the list of misclassified points, so construction time grows with the data set size.
     *
     * @param data the samples to browse; must contain at least one point
     * @param nn the network whose predictions are displayed
     * @throws IllegalArgumentException if {@code data} contains no points
     */
    public MnistWindow(DataSet data, NeuralNetwork nn) {
        if (data.points.length == 0) {
            throw new IllegalArgumentException("MnistWindow needs a data set with at least one point");
        }
        this.data = data;
        this.nn = nn;

        DataPoint first = data.points[0];
        panel = new DrawingPanel(Arrays.copyOf(first.inputs, first.inputs.length));

        // Must run before createButtonPanel(), which reads the misclassified list.
        double accuracy = evaluateAndCollectErrors();

        JPanel rightPanel = new JPanel();
        rightPanel.setLayout(new BoxLayout(rightPanel, BoxLayout.Y_AXIS));
        addDigitBars(rightPanel);
        rightPanel.add(predictedResult);
        rightPanel.add(expectedResult);
        rightPanel.add(createButtonPanel());
        rightPanel.add(new JLabel(String.format("Global accuracy : %.2f%%", accuracy * 100)));

        JPanel mainPanel = new JPanel(new BorderLayout(10, 10));
        mainPanel.setBorder(BorderFactory.createEmptyBorder(10, 10, 10, 10));
        mainPanel.add(panel, BorderLayout.WEST);
        mainPanel.add(rightPanel, BorderLayout.CENTER);

        frame = new JFrame("MNIST Viewer");
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.add(mainPanel);
        frame.pack();
        frame.setVisible(true);

        showPoint(currentPointIndex);
    }

    /** Adds one labelled bar row per digit to {@code container} and stores the bars. */
    private void addDigitBars(JPanel container) {
        for (int digit = 0; digit < DIGIT_COUNT; digit++) {
            JPanel row = new JPanel(new FlowLayout(FlowLayout.LEFT, 5, 0));
            row.add(new JLabel(Integer.toString(digit)));
            bars[digit] = new BarComponent();
            row.add(bars[digit]);
            container.add(row);
        }
    }

    /** Creates the Reset / Next / Next error buttons; requires the misclassified list to be filled. */
    private JPanel createButtonPanel() {
        JButton resetButton = new JButton("Reset");
        resetButton.addActionListener(e -> panel.reset());
        JButton nextButton = new JButton("Next");
        nextButton.addActionListener(e -> showRandomPoint());
        JButton nextErrorButton = new JButton("Next error");
        nextErrorButton.addActionListener(e -> showRandomError());
        // Nothing to jump to when the network makes no mistakes on this data set.
        nextErrorButton.setEnabled(!misclassified.isEmpty());

        JPanel buttonPanel = new JPanel();
        buttonPanel.add(resetButton);
        buttonPanel.add(nextButton);
        buttonPanel.add(nextErrorButton);
        return buttonPanel;
    }

    /**
     * Classifies every data point, records the indices of misclassified ones and returns the
     * fraction that was classified correctly.
     */
    private double evaluateAndCollectErrors() {
        int total = data.points.length;
        for (int i = 0; i < total; i++) {
            DataPoint point = data.points[i];
            if (argMax(classify(point.inputs)) != argMax(point.expectedOutputs)) {
                misclassified.add(i);
            }
        }
        return (double) (total - misclassified.size()) / total;
    }

    /** Runs the network on {@code inputs} and returns its normalized outputs. */
    private double[] classify(double[] inputs) {
        return softmax(nn.calculateOutputs(inputs));
    }

    /**
     * Scales {@code outputs} so that they sum to 1.
     *
     * <p>Despite its name this is a linear normalization ({@code v / sum}), not the exponential
     * softmax; it yields a probability distribution only when the outputs are non-negative. When
     * the sum is not strictly positive or not finite, an unchanged copy is returned: dividing
     * would otherwise produce NaN (zero sum) or reverse the ranking of the values (negative sum).
     *
     * @param outputs raw network outputs; not modified
     * @return a new array holding the normalized values
     */
    public double[] softmax(double[] outputs) {
        double sum = 0;
        for (double v : outputs) {
            sum += v;
        }
        if (!(sum > 0) || Double.isInfinite(sum)) {
            return Arrays.copyOf(outputs, outputs.length);
        }
        double[] result = new double[outputs.length];
        for (int i = 0; i < outputs.length; i++) {
            result[i] = outputs[i] / sum;
        }
        return result;
    }

    /** Classifies the user's drawing on the canvas and refreshes the prediction widgets. */
    private void predictCustom() {
        double[] outputs = classify(panel.getData());
        // A hand-drawn digit has no ground truth, so the expected label is cleared.
        expectedResult.setText("");
        predictedResult.setText("Predicted : " + argMax(outputs));
        updateBars(outputs);
    }

    /** Displays data point {@code index} on the canvas together with the network's prediction. */
    private void showPoint(int index) {
        currentPointIndex = index;
        DataPoint point = data.points[index];
        panel.setData(point.inputs); // setData copies, so the data set is never mutated by drawing
        double[] outputs = classify(panel.getData());
        expectedResult.setText("Expected : " + argMax(point.expectedOutputs));
        predictedResult.setText("Predicted : " + argMax(outputs));
        updateBars(outputs);
    }

    /** Displays a uniformly random data point. */
    private void showRandomPoint() {
        showPoint(random.nextInt(data.points.length));
    }

    /** Displays a uniformly random misclassified data point; does nothing if there is none. */
    private void showRandomError() {
        if (misclassified.isEmpty()) {
            return;
        }
        showPoint(misclassified.get(random.nextInt(misclassified.size())));
    }

    /**
     * Returns the index of the largest value in {@code values}, the first one on ties. NaN
     * entries are ignored. Returns 0 if the array is empty or holds no value greater than
     * negative infinity.
     */
    private static int argMax(double[] values) {
        int best = 0;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.length; i++) {
            if (values[i] > bestValue) {
                bestValue = values[i];
                best = i;
            }
        }
        return best;
    }

    /** Pushes {@code values} into the digit bars (extra values or bars are ignored). */
    private void updateBars(double[] values) {
        int count = Math.min(bars.length, values.length);
        for (int i = 0; i < count; i++) {
            bars[i].setValue(values[i]);
        }
    }

    /**
     * Renders a row-major grayscale buffer of IMAGE_SIZE x IMAGE_SIZE values as an image.
     * Each value is scaled by 255 and clamped to [0, 255], so 0 is black and 1 is white.
     */
    private static BufferedImage createImageFromData(double[] pixels) {
        BufferedImage image = new BufferedImage(IMAGE_SIZE, IMAGE_SIZE, BufferedImage.TYPE_INT_RGB);
        for (int y = 0; y < IMAGE_SIZE; y++) {
            for (int x = 0; x < IMAGE_SIZE; x++) {
                int gray = (int) (pixels[y * IMAGE_SIZE + x] * 255);
                gray = Math.max(0, Math.min(255, gray));
                image.setRGB(x, y, new Color(gray, gray, gray).getRGB());
            }
        }
        return image;
    }

    /**
     * Canvas showing the current input image, scaled up by PXSIZE. Pressing or dragging the
     * mouse paints with a soft round brush and re-classifies the drawing.
     */
    private class DrawingPanel extends JPanel {
        /** Brush radius, in image pixels. */
        private static final int BRUSH_RADIUS = 2;

        private double[] pixels;
        private BufferedImage currentImage;

        /** Wraps {@code pixels} directly (no copy); callers pass a private copy. */
        private DrawingPanel(double[] pixels) {
            this.pixels = pixels;
            this.currentImage = createImageFromData(pixels);
            setupMouseListeners();
            setPreferredSize(new Dimension(IMAGE_SIZE * PXSIZE, IMAGE_SIZE * PXSIZE));
        }

        /** Returns the live pixel buffer (not a copy). */
        public double[] getData() {
            return pixels;
        }

        /** Replaces the canvas content with a copy of {@code newData} and repaints. */
        public void setData(double[] newData) {
            this.pixels = Arrays.copyOf(newData, newData.length);
            this.currentImage = createImageFromData(this.pixels);
            repaint();
        }

        private void setupMouseListeners() {
            addMouseListener(new MouseAdapter() {
                @Override
                public void mousePressed(MouseEvent e) {
                    updateCell(e.getX(), e.getY());
                }
            });
            addMouseMotionListener(new MouseMotionAdapter() {
                @Override
                public void mouseDragged(MouseEvent e) {
                    updateCell(e.getX(), e.getY());
                }
            });
        }

        /**
         * Paints the brush centred on the image pixel under screen point ({@code x}, {@code y}),
         * then repaints and re-classifies. Points outside the canvas are ignored.
         */
        private void updateCell(int x, int y) {
            // floorDiv, not '/': drag events can report negative coordinates, and truncation
            // toward zero would map e.g. x = -5 onto column 0.
            int col = Math.floorDiv(x, PXSIZE);
            int row = Math.floorDiv(y, PXSIZE);
            if (col < 0 || col >= IMAGE_SIZE || row < 0 || row >= IMAGE_SIZE) {
                return;
            }
            double sigma = BRUSH_RADIUS / 2.8;
            for (int dr = -BRUSH_RADIUS; dr <= BRUSH_RADIUS; dr++) {
                for (int dc = -BRUSH_RADIUS; dc <= BRUSH_RADIUS; dc++) {
                    int r = row + dr;
                    int c = col + dc;
                    double distance = Math.sqrt(dc * dc + dr * dr);
                    if (r < 0 || r >= IMAGE_SIZE || c < 0 || c >= IMAGE_SIZE || distance > BRUSH_RADIUS) {
                        continue;
                    }
                    // Gaussian falloff: 1 under the cursor, decreasing with distance.
                    double falloff = Math.exp(-(distance * distance) / (2 * sigma * sigma));
                    int index = r * IMAGE_SIZE + c;
                    // Accumulate half the falloff per stroke, never below the falloff itself,
                    // capped at full intensity.
                    pixels[index] = Math.min(1.0, Math.max(pixels[index] + falloff / 2, falloff));
                }
            }
            currentImage = createImageFromData(pixels);
            repaint();
            predictCustom();
        }

        /** Clears the canvas to black. The displayed prediction is left unchanged. */
        public void reset() {
            Arrays.fill(pixels, 0.0);
            currentImage = createImageFromData(pixels);
            repaint();
        }

        @Override
        protected void paintComponent(Graphics g) {
            super.paintComponent(g);
            if (currentImage == null) {
                return;
            }
            g.drawImage(currentImage, 0, 0, IMAGE_SIZE * PXSIZE, IMAGE_SIZE * PXSIZE, null);
        }
    }

    /** Horizontal bar drawing a value in [0, 1] as a black fill on a white background. */
    private static final class BarComponent extends JComponent {
        private double value = 0.0;

        /** Sets the displayed value, clamped to [0, 1], and repaints. */
        public void setValue(double value) {
            this.value = Math.max(0.0, Math.min(1.0, value));
            repaint();
        }

        @Override
        protected void paintComponent(Graphics g) {
            super.paintComponent(g);
            int width = getWidth();
            int height = getHeight();
            g.setColor(Color.WHITE);
            g.fillRect(0, 0, width, height);
            // Fill and border share the same colour; the 1px inset keeps the fill inside the border.
            g.setColor(Color.BLACK);
            g.fillRect(1, 1, (int) (value * (width - 2)), height - 2);
            g.drawRect(0, 0, width - 1, height - 1);
        }

        @Override
        public Dimension getPreferredSize() {
            return new Dimension(200, 20);
        }
    }

    /**
     * Loads the MNIST data (from the "data.dat" cache when present, otherwise from the raw
     * files, then caches it), loads the cached network and opens the viewer on the EDT.
     */
    public static void main(String[] args) {
        if (!new File("data.dat").exists()) {
            MNIST.loadTrain();
            MNIST.loadEval();
            MNIST.save();
        } else {
            MNIST.load();
        }
        DataSet eval = new DataSet(1);
        eval.loadFromArray(MNIST.getEvalImages(), MNIST.getEvalLabels());

        String cachedNN = "nn_256_128.dat";
        System.out.println("Try loading cached NN...");
        NeuralNetwork nn;
        try {
            nn = NeuralNetwork.load(cachedNN);
        } catch (ClassNotFoundException | IOException e) {
            System.err.println(e);
            System.err.println("Loading failed.");
            return;
        }
        System.out.println("Done.");
        System.out.println(nn + "\n");
        // Swing components must be created and touched only on the event dispatch thread.
        SwingUtilities.invokeLater(() -> new MnistWindow(eval, nn));
    }
}