330 lines
10 KiB
Java
330 lines
10 KiB
Java
package fr.perceptron;
|
|
|
|
import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.Graphics;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.awt.event.MouseMotionAdapter;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import javax.swing.BorderFactory;
import javax.swing.BoxLayout;
import javax.swing.JButton;
import javax.swing.JComponent;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;
|
|
|
|
public class MnistWindow {
|
|
private int PXSIZE = 20;
|
|
private DataSet data;
|
|
private final JFrame frame;
|
|
private final DrawingPanel panel;
|
|
private final Random random = new Random();
|
|
private BarComponent[] bars;
|
|
private int currentPointIndex = 0;
|
|
private JLabel expectedResult;
|
|
private JLabel predictedResult;
|
|
|
|
private NeuralNetwork nn;
|
|
|
|
public MnistWindow(DataSet data, NeuralNetwork nn) {
|
|
this.data = data;
|
|
this.nn = nn;
|
|
|
|
frame = new JFrame("MNIST Viewer");
|
|
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
|
|
|
|
panel = new DrawingPanel(Arrays.copyOf(data.points[0].inputs, data.points[0].inputs.length));
|
|
bars = new BarComponent[10];
|
|
|
|
JPanel rightPanel = new JPanel();
|
|
rightPanel.setLayout(new BoxLayout(rightPanel, BoxLayout.Y_AXIS));
|
|
|
|
// Create digit bars
|
|
for (int i = 0; i < 10; i++) {
|
|
JPanel barPanel = new JPanel(new FlowLayout(FlowLayout.LEFT, 5, 0));
|
|
barPanel.add(new JLabel(Integer.toString(i)));
|
|
BarComponent bar = new BarComponent();
|
|
bar.setPreferredSize(new Dimension(200, 20));
|
|
bars[i] = bar;
|
|
barPanel.add(bar);
|
|
rightPanel.add(barPanel);
|
|
}
|
|
|
|
expectedResult = new JLabel();
|
|
predictedResult = new JLabel();
|
|
rightPanel.add(predictedResult);
|
|
rightPanel.add(expectedResult);
|
|
|
|
double acc = 0;
|
|
for (DataPoint d : data.points) {
|
|
double expected = getDataValue(d.expectedOutputs);
|
|
double predicted = getDataValue(softmax(nn.calculateOutputs(d.inputs)));
|
|
if (expected == predicted) {
|
|
acc += 1;
|
|
}
|
|
}
|
|
acc /= data.points.length;
|
|
|
|
// Create buttons
|
|
JPanel buttonPanel = new JPanel();
|
|
JButton resetButton = new JButton("Reset");
|
|
resetButton.addActionListener(e -> panel.reset());
|
|
JButton nextButton = new JButton("Next");
|
|
nextButton.addActionListener(e -> showRandomPoint(false));
|
|
JButton nextErrorButton = new JButton("Next error");
|
|
nextErrorButton.addActionListener(e -> showRandomPoint(true));
|
|
buttonPanel.add(resetButton);
|
|
buttonPanel.add(nextButton);
|
|
buttonPanel.add(nextErrorButton);
|
|
rightPanel.add(buttonPanel);
|
|
|
|
rightPanel.add(new JLabel("Global accuracy : " + acc*100 + "%"));
|
|
|
|
// Setup main layout
|
|
JPanel mainPanel = new JPanel(new BorderLayout(10, 10));
|
|
mainPanel.setBorder(BorderFactory.createEmptyBorder(10, 10, 10, 10));
|
|
mainPanel.add(panel, BorderLayout.WEST);
|
|
mainPanel.add(rightPanel, BorderLayout.CENTER);
|
|
|
|
frame.add(mainPanel);
|
|
frame.pack();
|
|
frame.setVisible(true);
|
|
|
|
predict(false);
|
|
//showPoint(currentPointIndex); // Show initial point
|
|
}
|
|
|
|
public double[] softmax(double outputs[]) {
|
|
double[] result = new double[outputs.length];
|
|
double sum = 0;
|
|
for (double v : outputs) {
|
|
sum += v;
|
|
}
|
|
for (int i=0; i<result.length; i++) {
|
|
result[i] = outputs[i]/sum;
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// Update le tout pour le dessin custom
|
|
private void predictCustom() {
|
|
double[] inputs = panel.getData();
|
|
double[] outputs = softmax(nn.calculateOutputs(inputs));
|
|
|
|
String v = String.valueOf(getDataValue(outputs));
|
|
expectedResult.setText("");
|
|
predictedResult.setText("Predicted : " + v);
|
|
|
|
updateBars(outputs);
|
|
}
|
|
|
|
private void predict(boolean findError) {
|
|
DataPoint point = data.points[currentPointIndex];
|
|
panel.setData(Arrays.copyOf(point.inputs, point.inputs.length));
|
|
double[] inputs = panel.getData();
|
|
double[] outputs = softmax(nn.calculateOutputs(inputs));
|
|
double[] expectedOutputs = point.expectedOutputs;
|
|
|
|
expectedResult.setText("Expected : " + getDataValue(expectedOutputs));
|
|
predictedResult.setText("Predicted : " + getDataValue(outputs));
|
|
|
|
updateBars(outputs);
|
|
|
|
if (findError && (getDataValue(expectedOutputs) == getDataValue(outputs))) {
|
|
showRandomPoint(true);
|
|
}
|
|
}
|
|
|
|
private int getDataValue(double[] d) {
|
|
double max = 0;
|
|
int imax = 0;
|
|
for (int i=0; i<d.length; i++) {
|
|
double v = d[i];
|
|
if (v>max) {
|
|
max = v;
|
|
imax = i;
|
|
}
|
|
}
|
|
return imax;
|
|
}
|
|
|
|
private void showRandomPoint(boolean findError) {
|
|
currentPointIndex = random.nextInt(data.points.length);
|
|
//showPoint(currentPointIndex);
|
|
predict(findError);
|
|
}
|
|
|
|
private void updateBars(double[] values) {
|
|
for (int i = 0; i < 10; i++) {
|
|
bars[i].setValue(values[i]);
|
|
}
|
|
}
|
|
|
|
|
|
private BufferedImage createImageFromData(double[] data) {
|
|
BufferedImage image = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
|
|
for (int y = 0; y < 28; y++) {
|
|
for (int x = 0; x < 28; x++) {
|
|
int value = (int) (data[y * 28 + x] * 255);
|
|
value = Math.max(0, Math.min(255, value));
|
|
int rgb = new Color(value, value, value).getRGB();
|
|
image.setRGB(x, y, rgb);
|
|
}
|
|
}
|
|
return image;
|
|
}
|
|
|
|
private class DrawingPanel extends JPanel {
|
|
private double[] data;
|
|
private BufferedImage currentImage;
|
|
|
|
private DrawingPanel(double[] data) {
|
|
this.data = data;
|
|
this.currentImage = createImageFromData(data);
|
|
setupMouseListeners();
|
|
setPreferredSize(new Dimension(28*PXSIZE, 28*PXSIZE));
|
|
}
|
|
|
|
public double[] getData() {
|
|
return data;
|
|
}
|
|
|
|
public void setData(double[] newData) {
|
|
this.data = Arrays.copyOf(newData, newData.length);
|
|
this.currentImage = createImageFromData(this.data);
|
|
repaint();
|
|
}
|
|
|
|
private void setupMouseListeners() {
|
|
addMouseListener(new MouseAdapter() {
|
|
@Override
|
|
public void mousePressed(MouseEvent e) {
|
|
updateCell(e.getX(), e.getY());
|
|
}
|
|
});
|
|
|
|
addMouseMotionListener(new MouseMotionAdapter() {
|
|
@Override
|
|
public void mouseDragged(MouseEvent e) {
|
|
updateCell(e.getX(), e.getY());
|
|
}
|
|
});
|
|
}
|
|
|
|
private void updateCell(int x, int y) {
|
|
int col = x / PXSIZE;
|
|
int row = y / PXSIZE;
|
|
int radius = 2;
|
|
double sigma = radius / 2.8;
|
|
|
|
if (col >= 0 && col < 28 && row >= 0 && row < 28) {
|
|
for (int dr = -radius; dr <= radius; dr++) {
|
|
for (int dc = -radius; dc <= radius; dc++) {
|
|
int c = col + dc;
|
|
int r = row + dr;
|
|
|
|
if (c >= 0 && c < 28 && r >= 0 && r < 28) {
|
|
double distance = Math.sqrt(dc * dc + dr * dr);
|
|
|
|
if (distance <= radius) {
|
|
// gauss
|
|
double exponent = -(distance * distance) / (2 * sigma * sigma);
|
|
double falloff = Math.exp(exponent);
|
|
|
|
int index = r * 28 + c;
|
|
data[index] = Math.min(1.0, Math.max(data[index]+falloff/2, falloff));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
currentImage = createImageFromData(data);
|
|
repaint();
|
|
predictCustom();
|
|
}
|
|
}
|
|
|
|
public void reset() {
|
|
Arrays.fill(data, 0.0);
|
|
currentImage = createImageFromData(data);
|
|
repaint();
|
|
}
|
|
|
|
@Override
|
|
protected void paintComponent(Graphics g) {
|
|
super.paintComponent(g);
|
|
if (currentImage == null) return;
|
|
g.drawImage(currentImage, 0, 0, 28*PXSIZE, 28*PXSIZE, null);
|
|
}
|
|
}
|
|
|
|
private class BarComponent extends JComponent {
|
|
private double value = 0.0;
|
|
|
|
public void setValue(double value) {
|
|
this.value = Math.max(0.0, Math.min(1.0, value));
|
|
repaint();
|
|
}
|
|
|
|
@Override
|
|
protected void paintComponent(Graphics g) {
|
|
super.paintComponent(g);
|
|
int width = getWidth();
|
|
int height = getHeight();
|
|
|
|
// Draw background
|
|
g.setColor(Color.WHITE);
|
|
g.fillRect(0, 0, width, height);
|
|
|
|
// Draw filled bar
|
|
g.setColor(Color.BLACK);
|
|
int fillWidth = (int) (value * (width - 2));
|
|
g.fillRect(1, 1, fillWidth, height - 2);
|
|
|
|
// Draw border
|
|
g.setColor(Color.BLACK);
|
|
g.drawRect(0, 0, width - 1, height - 1);
|
|
}
|
|
|
|
@Override
|
|
public Dimension getPreferredSize() {
|
|
return new Dimension(200, 20);
|
|
}
|
|
}
|
|
|
|
public static void main(String[] args) {
|
|
if (!new File("data.dat").exists()) {
|
|
MNIST.loadTrain();
|
|
MNIST.loadEval();
|
|
MNIST.save();
|
|
}
|
|
else {
|
|
MNIST.load();
|
|
}
|
|
|
|
DataSet eval = new DataSet(1);
|
|
eval.loadFromArray(MNIST.getEvalImages(), MNIST.getEvalLabels());
|
|
|
|
String cachedNN = "nn_256_128.dat";
|
|
NeuralNetwork nn;
|
|
System.out.println("Try loading cached NN...");
|
|
try {
|
|
nn = NeuralNetwork.load(cachedNN);
|
|
System.out.println("Done.");
|
|
System.out.println(nn + "\n");
|
|
new MnistWindow(eval, nn);
|
|
} catch (ClassNotFoundException | IOException e) {
|
|
System.out.println(e);
|
|
System.out.println("Loading failed.");
|
|
}
|
|
|
|
}
|
|
} |