Building a Hopfield Neural Network in JavaScript

When I studied neural networks at university, the Hopfield network became one of my favorites. I was surprised that it came last in the list of lab assignments, because its behavior can be demonstrated very clearly with images and it is not hard to implement.





This article demonstrates how to solve the problem of restoring distorted images using a Hopfield neural network, previously trained on reference images.





I have tried to describe, step by step and as simply as possible, how to implement a program that lets you play with the neural network right in the browser: train it on your own hand-drawn images and test how it handles distorted versions of them.





The source code is available on GitHub, along with a live demo.





To follow along, you will need:





  • Browser





  • Basic understanding of neural networks





  • Basic knowledge of JavaScript / HTML





A bit of theory

The Hopfield neural network is a fully connected neural network with a symmetric matrix of connections. Such a network can be used to organize associative memory, as a filter, and also to solve some optimization problems.





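The Hopfield network is single-layer: it has as many neurons as inputs, the output of each neuron is fed to the inputs of all the other neurons, and no neuron is connected to itself (the diagonal of the weight matrix is zero).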





Block diagram of the Hopfield neural network

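The network works iteratively: the neuron outputs produced on one step become its inputs on the next step. The process repeats until the state stops changing, at which point the network has “recalled” one of the stored (reference) images.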





The algorithm is as follows:







  1. Training: compute the weight matrix from the m reference images:



    w_{ij} = \begin{cases} \sum_{k=1}^{m} x_i^k \, x_j^k, & i \neq j \\ 0, & i = j \end{cases}



    m — the number of reference images;

    x_i^k, x_j^k — the i-th and j-th components of the k-th reference image.





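  2. Presentation: feed an unknown (possibly distorted) image to the network. Its components become the initial neuron states: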

    y_j(0) = x_j





  3. Update: compute the new neuron states (repeated on every iteration):



    y_j(t+1) = f\left(\sum_{i=1}^{n} w_{ij} \, y_i(t)\right)



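    f — the activation function, a threshold function taking the values −1 and 1 (the code below uses f(s) = 1 for s ≥ 0 and −1 otherwise);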

    t — the iteration number;

    j = 1 … n, where n is the number of neurons (equal to the number of inputs).





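  4. Convergence check: if the neuron states changed during the last iteration, go back to step 3; otherwise the network has stabilized, and its current state is the recalled image.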





.





Demonstration of the program

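The program uses two HTML Canvas elements: one for the user's drawing and one for the network's output. The HTML and CSS are straightforward and are not covered here (see the source code).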





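The user draws on the first canvas (userCanvas), and the network's answer is shown on the second one (netCanvas). Three buttons clear the canvases, memorize the current drawing (train the network on it), and run recognition on it.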





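The drawing area is a 10×10 grid of squares. Each square is one input of the network, so the network has 100 inputs (and 100 neurons). Each input takes one of two values, −1 or 1: −1 for a white (empty) square and 1 for a black (filled) square.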





First, let's declare the constants and variables we will need.





// Number of cells along each side of the square drawing grid.
const gridSize = 10;
// Side length of a single grid cell on the canvas, in pixels.
const squareSize = 45;
// One network input (and one neuron) per grid cell: 10 * 10 = 100.
const inputNodes = gridSize * gridSize;

// True while the mouse button is held down on the drawing canvas.
let isDrawing = false;
// The user's drawing as the network sees it, one value per cell in row-major
// order: 1 for a filled (black) cell, -1 for an empty one. Every cell starts
// empty. Declared with `let` because clearing the image replaces the array.
let userImageState = new Array(inputNodes).fill(-1);

// The canvas the user draws on and the canvas showing the network's output.
const userCanvas = document.getElementById('userCanvas');
const userContext = userCanvas.getContext('2d');
const netCanvas = document.getElementById('netCanvas');
const netContext = netCanvas.getContext('2d');
      
      



Next, a function that draws the grid.





// Draws an empty grid on the 2D canvas context `ctx`: gridSize × gridSize
// white squares of squareSize px, each outlined with a 3 px black border.
// Anything previously drawn inside the grid is covered by the white fill.
const drawGrid = (ctx) => {
  // Start a fresh path so rectangles from earlier calls are not re-rendered.
  ctx.beginPath();
  for (let row = 0; row < gridSize; row += 1) {
    for (let column = 0; column < gridSize; column += 1) {
      ctx.rect(column * squareSize, row * squareSize, squareSize, squareSize);
    }
  }

  // Render the whole path once, after every cell has been added. Filling and
  // stroking inside the loop would re-rasterize the ever-growing path on each
  // iteration (O(cells²) work) and end with the same picture.
  ctx.fillStyle = 'white';
  ctx.lineWidth = 3;
  ctx.strokeStyle = 'black';
  ctx.fill();
  ctx.stroke();
};
      
      



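Now let's make the canvas “drawable”: handle mouse events so that clicking or dragging fills the square under the cursor.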





// Fills the grid cell under the mouse pointer on the user canvas and marks the
// same cell as filled (1) in userImageState.
// The cell is computed once, from offsetX/offsetY, and used for both the
// drawing and the stored state. The picture and the network input therefore
// always agree, including on cell borders. Pointer positions outside the grid
// are ignored instead of wrapping into a neighbouring row.
const fillCellUnderPointer = (e) => {
  const column = Math.floor(e.offsetX / squareSize);
  const row = Math.floor(e.offsetY / squareSize);
  if (column < 0 || column >= gridSize || row < 0 || row >= gridSize) return;

  userContext.fillStyle = 'black';
  userContext.fillRect(column * squareSize, row * squareSize, squareSize, squareSize);
  userImageState[calcIndex(column, row, gridSize)] = 1;
};

// mousedown: paint the cell under the pointer and start a drag, so that
// subsequent mousemove events keep painting.
const handleMouseDown = (e) => {
  fillCellUnderPointer(e);
  isDrawing = true;
};

// mousemove: paint the cell under the pointer, but only while a drag started
// by handleMouseDown is in progress.
const handleMouseMove = (e) => {
  if (!isDrawing) return;
  fillCellUnderPointer(e);
};
      
      



The handlers rely on three helper functions: getNewSquareCoords, calcIndex and isValidIndex. Here they are.





// Converts zero-based cell coordinates (column x, row y) on a grid that is
// `size` cells wide into a flat row-major index.
const calcIndex = (x, y, size) => {
  const rowStart = y * size;
  return x + rowStart;
};

// True when `index` addresses an element of an array of length `len`,
// i.e. 0 <= index < len.
const isValidIndex = (index, len) => {
  const notNegative = index >= 0;
  return notNegative && index < len;
};

// Converts a pointer position in viewport coordinates (clientX/clientY) into
// zero-based grid cell coordinates on `canvas`, for square cells of `size` px.
// Math.floor makes a pixel on a cell's left/top edge belong to that cell. This
// matches how cells are drawn at (column * size, row * size), and keeps results
// in 0..gridSize-1 for positions inside the grid. (Math.ceil(v / size) - 1
// returned -1 at the very edge and the previous cell at exact multiples.)
// Positions outside the canvas yield out-of-range values; callers must check.
const getNewSquareCoords = (canvas, clientX, clientY, size) => {
  const { left, top } = canvas.getBoundingClientRect();
  return {
    x: Math.floor((clientX - left) / size),
    y: Math.floor((clientY - top) / size),
  };
};
      
      



Next, the function that clears the current image. It is straightforward.





// Redraws an empty grid on both canvases and forgets the user's drawing by
// resetting every cell of userImageState to -1. The weight matrix is not
// touched, so memorized images survive a clear.
const clearCurrentImage = () => {
  [userContext, netContext].forEach((context) => drawGrid(context));
  userImageState = Array.from({ length: gridSize * gridSize }, () => -1);
};
      
      



Now let's move on to the network itself.





The network's memory is its weight matrix. Let's declare it and initialize it with zeros.





...
// Weight matrix of the network: inputNodes rows × inputNodes columns, where
// weights[i][j] is the connection weight between neurons i and j. It starts
// all zeros (nothing memorized yet). While building each row, the matching
// cell of the user's drawing is also reset to empty (-1).
const weights = Array.from({ length: inputNodes }, (_, row) => {
  userImageState[row] = -1;
  return new Array(inputNodes).fill(0);
});
...
      
      



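The weight matrix is square, with inputNodes rows and inputNodes columns: the network has 100 inputs, so each of the 100 neurons has 100 weights, one per input.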





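Training (memorizing a reference image) adds that image's contribution to the weights, following the formula in step 1 of the algorithm above.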





// Training step (Hebbian rule): adds the current drawing to the network's
// memory by accumulating userImageState[i] * userImageState[j] into every
// off-diagonal weight, and forces the diagonal to 0 (no self-connections).
// Cell values are -1 (empty) or 1 (filled), so each product is +1 when the two
// cells agree and -1 when they differ.
const memorizeImage = () => {
  for (let row = 0; row < inputNodes; row += 1) {
    const rowWeights = weights[row];
    const rowValue = userImageState[row];
    for (let col = 0; col < inputNodes; col += 1) {
      rowWeights[col] = row === col
        ? 0
        : rowWeights[col] + rowValue * userImageState[col];
    }
  }
};
      
      



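Once the network has memorized the reference images, we can try to recognize a distorted image. Here is the recognition function: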





// -  html   lodash:
<script src="https://cdnjs.cloudflare.com/ajax/libs/lodash.js/4.17.21/lodash.min.js"></script>
...
// One synchronous update (step 3 of the algorithm): every neuron i is
// recomputed from the same previous `state` as f(sum_j weights[i][j] * state[j]),
// where f(s) = 1 for s >= 0 and -1 otherwise. Returns a new array; `state` is
// left untouched.
const computeNextNetState = (state) => Array.from({ length: inputNodes }, (_, i) => {
  let sum = 0;
  for (let j = 0; j < inputNodes; j += 1) {
    sum += weights[i][j] * state[j];
  }
  return sum >= 0 ? 1 : -1;
});

// Element-wise equality of two network states (arrays of -1/1 values).
const isSameNetState = (a, b) => a.length === b.length && a.every((value, i) => value === b[i]);

// Recall: starting from the user's drawing, repeat synchronous updates until
// the state stops changing, then draw the result on the network canvas.
//
// A synchronous update does not always reach a fixed point. It can also
// alternate between two states forever. For example, with weights
// [[0, -1], [-1, 0]] (memorized [1, -1]), the input [1, 1] flips to [-1, -1]
// and back. Looping "until nothing changes" would then never finish and would
// freeze the page. The loop therefore also stops on such a 2-cycle, and is
// capped at maxIterations as a last resort.
const recognizeSignal = () => {
  const maxIterations = 100;
  // Copy so that recognition never modifies the user's drawing.
  let prevNetState = [...userImageState];
  let currNetState = prevNetState;
  let stateBeforePrev = null;

  for (let iteration = 0; iteration < maxIterations; iteration += 1) {
    currNetState = computeNextNetState(prevNetState);
    const reachedFixedPoint = isSameNetState(currNetState, prevNetState);
    const reachedTwoCycle = stateBeforePrev !== null
      && isSameNetState(currNetState, stateBeforePrev);
    if (reachedFixedPoint || reachedTwoCycle) break;
    stateBeforePrev = prevNetState;
    prevNetState = currNetState;
  }

  // Show the final state (the recalled image) on the network canvas.
  drawImageFromArray(currNetState, netContext);
};
      
      



The arrays are compared with the isEqual function from the lodash library.





drawImageFromArray draws the network's output (an array of −1 and 1 values) on the canvas.





// Draws a network state on the 2D context `ctx`. It first redraws the empty
// grid, then fills every cell whose value is 1 in black.
// `data` is a flat row-major array of gridSize * gridSize values (-1 or 1).
// It is read by index and NOT modified. The previous splice-based version
// emptied the caller's array, so drawing the same state twice threw.
// Missing trailing values are treated as empty cells.
const drawImageFromArray = (data, ctx) => {
  drawGrid(ctx);
  for (let row = 0; row < gridSize; row += 1) {
    for (let column = 0; column < gridSize; column += 1) {
      if (data[row * gridSize + column] === 1) {
        ctx.fillStyle = 'black';
        ctx.fillRect(column * squareSize, row * squareSize, squareSize, squareSize);
      }
    }
  }
};
      
      



Finally, we wire the functions up to the HTML buttons and to the canvas mouse events.





HTML
// Control buttons: clear the drawing, memorize it, recognize it.
const resetButton = document.getElementById('resetButton');
const memoryButton = document.getElementById('memoryButton');
const recognizeButton = document.getElementById('recognizeButton');

// The handlers are wrapped in arrow functions, so the click event object is
// not passed on to them as an argument.
resetButton.addEventListener('click', () => { clearCurrentImage(); });
memoryButton.addEventListener('click', () => { memorizeImage(); });
recognizeButton.addEventListener('click', () => { recognizeSignal(); });

// Drawing with the mouse on the user canvas.
userCanvas.addEventListener('mousedown', (event) => { handleMouseDown(event); });
userCanvas.addEventListener('mousemove', (event) => { handleMouseMove(event); });
// Releasing the button or leaving the canvas ends the current stroke.
['mouseup', 'mouseleave'].forEach((type) => {
  userCanvas.addEventListener(type, () => { isDrawing = false; });
});

// Show empty grids on both canvases at start-up.
[userContext, netContext].forEach((context) => drawGrid(context));
      
      



Now let's test the network. First, we train it on reference images:





Reference images for network training

Then we try to recognize distorted versions of them:





Trying to recognize the distorted image of the letter H
Trying to recognize a distorted image of the letter T

It works!





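Keep in mind that the number of images m the network can reliably memorize is limited to roughly 0.15 · n (where n is the number of neurons). The reference images should also differ sufficiently from each other; otherwise the network may settle on the wrong image or on a spurious state that matches none of them.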





The source code is available on GitHub, along with a live demo.





Instead of textbooks, I relied on the lectures of Sergey Mikhailovich Roshchin, an excellent teacher of neural networks — many thanks to him.








All Articles