When I first encountered neural networks at university, the Hopfield network became one of my favorites. I was surprised that it came last in the list of lab assignments, because its operation can be demonstrated clearly with images and it is not that difficult to implement.
This article demonstrates how to solve the problem of restoring distorted images using a Hopfield neural network, previously trained on reference images.
I have tried to describe, step by step and as simply as possible, how to implement a program that lets you play with a neural network right in the browser: train the network on your own hand-drawn images and test how it works on distorted ones.
For implementation you will need:
Browser
Basic understanding of neural networks
Basic knowledge of JavaScript / HTML
A bit of theory
The Hopfield neural network is a fully connected neural network with a symmetric matrix of connections. Such a network can be used to organize associative memory, as a filter, and also to solve some optimization problems.
- . , . , , .
, , . (, ), . , , ยซยป ( ).
:
:
โ
โ - - - .
. :
( ):
โ ;
โ ;
โ .
. โ 3, , , . , .
.
Canvas , ( ) . , Canvas ( ยซยป ).
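The drawing area is a grid of 10×10 squares. The network therefore has 100 inputs (and 100 neurons). The input signal is a vector of values −1 and 1, where −1 stands for a white square and 1 for a black one.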
- , .
// Number of squares along one side of the drawing grid
const gridSize = 10;
// Side length of a single square, in pixels
const squareSize = 45;
// Number of network inputs: one per grid square (100)
const inputNodes = gridSize * gridSize;
// Picture currently drawn on the user canvas, one entry per square.
// Every entry starts at -1 and is switched to 1 when the square is painted.
let userImageState = new Array(inputNodes).fill(-1);
// True while the mouse button is held down over the user canvas
let isDrawing = false;
// Canvas the user draws on, and the canvas that shows the network's answer
const userCanvas = document.getElementById('userCanvas');
const netCanvas = document.getElementById('netCanvas');
// 2D drawing contexts of both canvases
const userContext = userCanvas.getContext('2d');
const netContext = netCanvas.getContext('2d');
, .
/**
 * Draws an empty grid of gridSize x gridSize white squares with black
 * borders, starting at the top-left corner of the canvas.
 *
 * All squares are added to a single path that is filled and stroked once.
 * Filling and stroking inside the loop would repaint every previously added
 * square on each iteration, which is quadratic work for no visual gain.
 *
 * @param {CanvasRenderingContext2D} ctx - 2D context of the canvas to draw on
 */
const drawGrid = (ctx) => {
  ctx.fillStyle = 'white';
  ctx.lineWidth = 3;
  ctx.strokeStyle = 'black';
  ctx.beginPath();
  for (let row = 0; row < gridSize; row += 1) {
    for (let column = 0; column < gridSize; column += 1) {
      ctx.rect(column * squareSize, row * squareSize, squareSize, squareSize);
    }
  }
  // Paint the whole grid in one go: white interiors first, borders on top
  ctx.fill();
  ctx.stroke();
};
ยซยป , .
/**
 * Paints the grid square under the mouse pointer black and marks it as
 * filled (1) in userImageState.
 *
 * The painted square and the stored square are both derived from the same
 * coordinates, so the picture on the canvas always matches the network input.
 * Pointer positions that map outside the grid are ignored: checking only the
 * flat index is not enough, because e.g. x = -1, y = 1 yields index 9, which
 * is a valid index but belongs to a square in a different row.
 *
 * @param {MouseEvent} e - mouse event raised on the user canvas
 */
const paintSquareUnderPointer = (e) => {
  const { clientX, clientY } = e;
  const { x, y } = getNewSquareCoords(userCanvas, clientX, clientY, squareSize);
  const isInsideGrid = x >= 0 && x < gridSize && y >= 0 && y < gridSize;
  if (!isInsideGrid) return;
  const index = calcIndex(x, y, gridSize);
  if (!isValidIndex(index, inputNodes)) return;
  userContext.fillStyle = 'black';
  userContext.fillRect(x * squareSize, y * squareSize, squareSize, squareSize);
  userImageState[index] = 1;
};

/**
 * Mouse button pressed: paints the square under the pointer and starts a
 * drawing stroke, so that subsequent mouse moves keep painting.
 *
 * @param {MouseEvent} e - mousedown event raised on the user canvas
 */
const handleMouseDown = (e) => {
  paintSquareUnderPointer(e);
  isDrawing = true;
};

/**
 * Mouse moved: paints the square under the pointer, but only while a drawing
 * stroke is in progress (the button was pressed in handleMouseDown).
 *
 * @param {MouseEvent} e - mousemove event raised on the user canvas
 */
const handleMouseMove = (e) => {
  if (!isDrawing) return;
  paintSquareUnderPointer(e);
};
As you may have noticed, the handlers rely on the helper functions getNewSquareCoords, calcIndex and isValidIndex. Let's define them.
// Converts a square's column (x) and row (y) into its position in a flat,
// row-by-row array in which every row holds `size` squares
const calcIndex = (x, y, size) => {
  const rowStart = y * size;
  return rowStart + x;
};
// Tells whether `index` points inside an array of length `len`,
// i.e. 0 <= index < len
const isValidIndex = (index, len) => {
  if (index < 0) return false;
  return index < len;
};
/**
 * Converts a mouse position in viewport coordinates into the column/row of
 * the grid square under it.
 *
 * The position is first made relative to the canvas' bounding rectangle and
 * then divided by the square size. Math.floor is used so that offsets
 * 0 .. size - 1 fall into square 0, size .. 2 * size - 1 into square 1, and
 * so on. The earlier `Math.ceil(...) - 1` form returned -1 at offset 0 and
 * assigned every exact multiple of `size` to the previous square.
 *
 * @param {HTMLCanvasElement} canvas - canvas the position is measured against
 * @param {number} clientX - horizontal mouse position in the viewport
 * @param {number} clientY - vertical mouse position in the viewport
 * @param {number} size - side length of one grid square, in pixels
 * @returns {{x: number, y: number}} zero-based column (x) and row (y)
 */
const getNewSquareCoords = (canvas, clientX, clientY, size) => {
  const rect = canvas.getBoundingClientRect();
  const x = Math.floor((clientX - rect.left) / size);
  const y = Math.floor((clientY - rect.top) / size);
  return { x, y };
};
. .
// Resets the drawing: both canvases are repainted with an empty grid and
// the stored user picture is replaced by a fresh array of -1 values
const clearCurrentImage = () => {
  [userContext, netContext].forEach((ctx) => drawGrid(ctx));
  userImageState = Array.from({ length: gridSize * gridSize }, () => -1);
};
ยซยป .
โ . ( ).
...
// Weight matrix of the network: inputNodes rows of inputNodes weights,
// all zero until an image is memorized
const weights = Array.from({ length: inputNodes }, () => new Array(inputNodes).fill(0));
// Start with an empty picture: every square is -1
for (const node of weights.keys()) {
  userImageState[node] = -1;
}
...
As you can see, every neuron has inputNodes weights: the network has 100 neurons, and each of them has 100 connections.
( ) . . .
// Stores the picture currently held in userImageState in the network:
// the product of every pair of distinct squares is added to the weight
// that connects them, while the diagonal (a neuron's link to itself) is
// kept at 0. userImageState holds only -1 and 1 values.
const memorizeImage = () => {
  for (let row = 0; row < inputNodes; row += 1) {
    const rowWeights = weights[row];
    const rowPixel = userImageState[row];
    for (let col = 0; col < inputNodes; col += 1) {
      rowWeights[col] = row === col
        ? 0
        : rowWeights[col] + rowPixel * userImageState[col];
    }
  }
};
, , , . :
// At the top of the HTML file, include the lodash library:
<script src="https://cdnjs.cloudflare.com/ajax/libs/lodash.js/4.17.21/lodash.min.js"></script>
...
// Upper bound on the number of update passes performed by recognizeSignal.
// All neurons are updated at once from the previous state, and such updates
// can end up flipping between two states forever instead of settling
// (e.g. memorize [1, -1], then present [1, 1]); without a bound the loop
// would never finish and the page would freeze.
const MAX_RECOGNITION_PASSES = 100;

/**
 * Runs the network on the picture currently drawn by the user and shows the
 * result on the network canvas.
 *
 * Starting from a copy of userImageState, every pass recomputes all neurons
 * from the previous state: a neuron becomes 1 when the weighted sum of its
 * inputs is >= 0 and -1 otherwise. Passes are repeated until the state stops
 * changing or MAX_RECOGNITION_PASSES is reached; the last state is then drawn
 * via drawImageFromArray. userImageState itself is not modified.
 */
const recognizeSignal = () => {
  // Work on a copy so the user's drawing stays intact
  const currNetState = [...userImageState];
  let prevNetState;
  let passes = 0;
  do {
    // Remember the state before this pass to detect when nothing changes
    prevNetState = [...currNetState];
    for (let i = 0; i < inputNodes; i += 1) {
      let sum = 0;
      for (let j = 0; j < inputNodes; j += 1) {
        sum += weights[i][j] * prevNetState[j];
      }
      // Threshold activation: the sign of the weighted sum
      currNetState[i] = sum >= 0 ? 1 : -1;
    }
    passes += 1;
    // lodash isEqual compares the two arrays element by element
  } while (!_.isEqual(currNetState, prevNetState) && passes < MAX_RECOGNITION_PASSES);
  // Show the state the network ended in
  drawImageFromArray(currNetState, netContext);
};
The arrays are compared with the isEqual function from the lodash library.
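The result is rendered by the drawImageFromArray function. It treats the flat network state as a two-dimensional image and draws it on the canvas.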
/**
 * Draws a picture stored as a flat, row-by-row array onto a canvas.
 *
 * The canvas is first reset to an empty grid; every entry equal to 1 is then
 * painted as a black square, all other entries stay white.
 *
 * The array is only read. Cutting it into rows with splice would empty the
 * caller's array, and an array shorter than gridSize * gridSize would make
 * the row lookup throw; missing entries are simply treated as not painted.
 *
 * @param {number[]} data - flat picture, gridSize entries per row
 * @param {CanvasRenderingContext2D} ctx - 2D context of the canvas to draw on
 */
const drawImageFromArray = (data, ctx) => {
  // Start from a clean grid
  drawGrid(ctx);
  for (let row = 0; row < gridSize; row += 1) {
    for (let column = 0; column < gridSize; column += 1) {
      if (data[row * gridSize + column] === 1) {
        ctx.fillStyle = 'black';
        ctx.fillRect(column * squareSize, row * squareSize, squareSize, squareSize);
      }
    }
  }
};
HTML .
HTML
// Control buttons declared in the HTML page
const resetButton = document.getElementById('resetButton');
const memoryButton = document.getElementById('memoryButton');
const recognizeButton = document.getElementById('recognizeButton');
// Ends the current drawing stroke
const stopDrawing = () => {
  isDrawing = false;
};
// Button clicks: clear the picture, memorize it, or recognize it
resetButton.addEventListener('click', () => {
  clearCurrentImage();
});
memoryButton.addEventListener('click', () => {
  memorizeImage();
});
recognizeButton.addEventListener('click', () => {
  recognizeSignal();
});
// Drawing on the user canvas with the mouse
userCanvas.addEventListener('mousedown', (event) => {
  handleMouseDown(event);
});
userCanvas.addEventListener('mousemove', (event) => {
  handleMouseMove(event);
});
// Releasing the button or leaving the canvas stops the stroke
['mouseup', 'mouseleave'].forEach((type) => {
  userCanvas.addEventListener(type, stopDrawing);
});
// Show empty grids on both canvases at start-up
[userContext, netContext].forEach((ctx) => {
  drawGrid(ctx);
});
, :
:
! .
, , ( โ ). , , , , , .
In place of a reference list: this article is based on the lectures of an excellent teacher of neural networks, Sergey Mikhailovich Roshchin, to whom many thanks.