package fr.perceptron; import javax.swing.BorderFactory; import javax.swing.BoxLayout; import javax.swing.JButton; import javax.swing.JComponent; import javax.swing.JFrame; import javax.swing.JLabel; import javax.swing.JPanel; import java.awt.BorderLayout; import java.awt.Color; import java.awt.Dimension; import java.awt.FlowLayout; import java.awt.Graphics; import java.awt.event.MouseAdapter; import java.awt.event.MouseEvent; import java.awt.event.MouseMotionAdapter; import java.awt.image.BufferedImage; import java.io.File; import java.io.IOException; import java.util.Arrays; import java.util.Random; public class MnistWindow { private int PXSIZE = 20; private DataSet data; private final JFrame frame; private final DrawingPanel panel; private final Random random = new Random(); private BarComponent[] bars; private int currentPointIndex = 0; private JLabel expectedResult; private JLabel predictedResult; private NeuralNetwork nn; public MnistWindow(DataSet data, NeuralNetwork nn) { this.data = data; this.nn = nn; frame = new JFrame("MNIST Viewer"); frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); panel = new DrawingPanel(Arrays.copyOf(data.points[0].inputs, data.points[0].inputs.length)); bars = new BarComponent[10]; JPanel rightPanel = new JPanel(); rightPanel.setLayout(new BoxLayout(rightPanel, BoxLayout.Y_AXIS)); // Create digit bars for (int i = 0; i < 10; i++) { JPanel barPanel = new JPanel(new FlowLayout(FlowLayout.LEFT, 5, 0)); barPanel.add(new JLabel(Integer.toString(i))); BarComponent bar = new BarComponent(); bar.setPreferredSize(new Dimension(200, 20)); bars[i] = bar; barPanel.add(bar); rightPanel.add(barPanel); } expectedResult = new JLabel(); predictedResult = new JLabel(); rightPanel.add(predictedResult); rightPanel.add(expectedResult); double acc = 0; for (DataPoint d : data.points) { double expected = getDataValue(d.expectedOutputs); double predicted = getDataValue(softmax(nn.calculateOutputs(d.inputs))); if (expected == predicted) { acc += 1; } } acc /= data.points.length; // Create buttons JPanel buttonPanel = new JPanel(); JButton resetButton = new JButton("Reset"); resetButton.addActionListener(e -> panel.reset()); JButton nextButton = new JButton("Next"); nextButton.addActionListener(e -> showRandomPoint(false)); JButton nextErrorButton = new JButton("Next error"); nextErrorButton.addActionListener(e -> showRandomPoint(true)); buttonPanel.add(resetButton); buttonPanel.add(nextButton); buttonPanel.add(nextErrorButton); rightPanel.add(buttonPanel); rightPanel.add(new JLabel("Global accuracy : " + acc*100 + "%")); // Setup main layout JPanel mainPanel = new JPanel(new BorderLayout(10, 10)); mainPanel.setBorder(BorderFactory.createEmptyBorder(10, 10, 10, 10)); mainPanel.add(panel, BorderLayout.WEST); mainPanel.add(rightPanel, BorderLayout.CENTER); frame.add(mainPanel); frame.pack(); frame.setVisible(true); predict(false); //showPoint(currentPointIndex); // Show initial point } public double[] softmax(double outputs[]) { double[] result = new double[outputs.length]; double sum = 0; for (double v : outputs) { sum += v; } for (int i=0; imax) { max = v; imax = i; } } return imax; } private void showRandomPoint(boolean findError) { currentPointIndex = random.nextInt(data.points.length); //showPoint(currentPointIndex); predict(findError); } private void updateBars(double[] values) { for (int i = 0; i < 10; i++) { bars[i].setValue(values[i]); } } private BufferedImage createImageFromData(double[] data) { BufferedImage image = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB); for (int y = 0; y < 28; y++) { for (int x = 0; x < 28; x++) { int value = (int) (data[y * 28 + x] * 255); value = Math.max(0, Math.min(255, value)); int rgb = new Color(value, value, value).getRGB(); image.setRGB(x, y, rgb); } } return image; } private class DrawingPanel extends JPanel { private double[] data; private BufferedImage currentImage; private DrawingPanel(double[] data) { this.data = data; this.currentImage = createImageFromData(data); setupMouseListeners(); setPreferredSize(new Dimension(28*PXSIZE, 28*PXSIZE)); } public double[] getData() { return data; } public void setData(double[] newData) { this.data = Arrays.copyOf(newData, newData.length); this.currentImage = createImageFromData(this.data); repaint(); } private void setupMouseListeners() { addMouseListener(new MouseAdapter() { @Override public void mousePressed(MouseEvent e) { updateCell(e.getX(), e.getY()); } }); addMouseMotionListener(new MouseMotionAdapter() { @Override public void mouseDragged(MouseEvent e) { updateCell(e.getX(), e.getY()); } }); } private void updateCell(int x, int y) { int col = x / PXSIZE; int row = y / PXSIZE; int radius = 2; double sigma = radius / 2.8; if (col >= 0 && col < 28 && row >= 0 && row < 28) { for (int dr = -radius; dr <= radius; dr++) { for (int dc = -radius; dc <= radius; dc++) { int c = col + dc; int r = row + dr; if (c >= 0 && c < 28 && r >= 0 && r < 28) { double distance = Math.sqrt(dc * dc + dr * dr); if (distance <= radius) { // gauss double exponent = -(distance * distance) / (2 * sigma * sigma); double falloff = Math.exp(exponent); int index = r * 28 + c; data[index] = Math.min(1.0, Math.max(data[index]+falloff/2, falloff)); } } } } currentImage = createImageFromData(data); repaint(); predictCustom(); } } public void reset() { Arrays.fill(data, 0.0); currentImage = createImageFromData(data); repaint(); } @Override protected void paintComponent(Graphics g) { super.paintComponent(g); if (currentImage == null) return; g.drawImage(currentImage, 0, 0, 28*PXSIZE, 28*PXSIZE, null); } } private class BarComponent extends JComponent { private double value = 0.0; public void setValue(double value) { this.value = Math.max(0.0, Math.min(1.0, value)); repaint(); } @Override protected void paintComponent(Graphics g) { super.paintComponent(g); int width = getWidth(); int height = getHeight(); // Draw background g.setColor(Color.WHITE); g.fillRect(0, 0, width, height); // Draw filled bar g.setColor(Color.BLACK); int fillWidth = (int) (value * (width - 2)); g.fillRect(1, 1, fillWidth, height - 2); // Draw border g.setColor(Color.BLACK); g.drawRect(0, 0, width - 1, height - 1); } @Override public Dimension getPreferredSize() { return new Dimension(200, 20); } } public static void main(String[] args) { if (!new File("data.dat").exists()) { MNIST.loadTrain(); MNIST.loadEval(); MNIST.save(); } else { MNIST.load(); } DataSet eval = new DataSet(1); eval.loadFromArray(MNIST.getEvalImages(), MNIST.getEvalLabels()); String cachedNN = "nn_256_128.dat"; NeuralNetwork nn; System.out.println("Try loading cached NN..."); try { nn = NeuralNetwork.load(cachedNN); System.out.println("Done."); System.out.println(nn + "\n"); new MnistWindow(eval, nn); } catch (ClassNotFoundException | IOException e) { System.out.println(e); System.out.println("Loading failed."); } } }