Expanding Neural Gas

When machine learning models are used for checks, clustering is one of the most frequently solved problems. For example, customer reviews of a mobile application may need to be split into several clusters (a topic modeling task). The k-means model is often used for clustering because it is simple and easy to interpret. However, it has one big drawback: the number of clusters must be set in advance. Expanding neural gas (better known as growing neural gas, GNG) avoids this problem, because the number of clusters follows from the graph it builds.





The expanding neural gas builds a graph that approximates the distribution of the data. The connected components of this graph (subgraphs with no edges between them) are the desired clusters. The graph is built according to the following algorithm:





  1. The first two neurons are generated at random.

  2. At each step of the iterative process, one data item is taken. The two neurons closest to it move in its direction.

  3. A new neuron is created between the most frequently moving neuron and its nearest neighbor.

  4. Connections are removed if the connected neurons do not move together; neurons left without connections are removed as well.
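
Step 2 above can be sketched in plain NumPy. This is only an illustration under stated assumptions, not the neupy implementation: the function name `adapt_step` and the toy arrays are hypothetical, and the two learning rates are borrowed from the `step` and `neighbour_step` values used later in this article.

```python
import numpy as np

def adapt_step(weights, sample, step=0.2, neighbour_step=0.005):
    """Move the two neurons nearest to `sample` toward it, in place.

    `weights` is an (n_neurons, n_features) float array.
    Returns the indices of the nearest and second-nearest neuron.
    """
    # Euclidean distance from the sample to every neuron
    distances = np.linalg.norm(weights - sample, axis=1)
    winner, runner_up = np.argsort(distances)[:2]

    # The nearest neuron takes a large step, the second nearest a small one
    weights[winner] += step * (sample - weights[winner])
    weights[runner_up] += neighbour_step * (sample - weights[runner_up])
    return winner, runner_up

# Toy example: three neurons and one data item (hypothetical values)
weights = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]])
sample = np.array([0.0, 1.0])
winner, runner_up = adapt_step(weights, sample)
print(winner, runner_up)
print(weights)
```

The third neuron is far from the data item, so it stays where it was; only the two nearest neurons are pulled toward the item.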





Let's consider this iterative algorithm using an example with the following data:





At the very beginning of building the graph, the first two neurons s1 and s2 are randomly assigned.





After that, the iterative process begins:





  1. One item of our data v1 is selected.





  2. The distances from v1 to the neurons s1 and s2 are computed: r1 and r2, with r1 > r2.





3. s2 , s1. s2 , s3 s2 s1. s1 s2 .





4. 3 s1 . . s3,





5. 3 , 3 s4. s2-s3-s4,





, . k-means.





.





We generate the data with sklearn:





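import matplotlib.pyplot as plt
from sklearn.datasets import make_moons

# Two interleaving half-circles: 10,000 points with a little Gaussian noise
data, _ = make_moons(10000, noise=0.06, random_state=0)
plt.scatter(*data.T)
plt.show()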
      
      



Helper functions for drawing the graph, creating the network, and extracting the clusters:





import copy

import matplotlib.pyplot as plt
import numpy as np
from neupy import algorithms, utils

def draw_image(graph, show=True):
    # Draw every edge of the graph as a thin black line segment
    for node_1, node_2 in graph.edges:
        weights = np.concatenate([node_1.weight, node_2.weight])
        line, = plt.plot(*weights.T, color='black')
        plt.setp(line, linewidth=0.2, color='black')

    plt.xticks([], [])
    plt.yticks([], [])

    if show:
        plt.show()

def create_gng(max_nodes, step=0.2, n_start_nodes=2, max_edge_age=50):
    return algorithms.GrowingNeuralGas(
        n_inputs=2,
        n_start_nodes=n_start_nodes,

        shuffle_data=True,
        verbose=True,

        step=step,
        neighbour_step=0.005,

        max_edge_age=max_edge_age,
        max_nodes=max_nodes,

        n_iter_before_neuron_added=100,
        after_split_error_decay_rate=0.5,
        error_decay_rate=0.995,
        min_distance_for_update=0.01,
    )

def extract_subgraphs(graph):
    # Collect the connected components of the graph; each one is a cluster
    subgraphs = []
    edges_per_node = copy.deepcopy(graph.edges_per_node)

    while edges_per_node:
        nodes_left = list(edges_per_node.keys())
        nodes_to_check = [nodes_left[0]]
        subgraph = []

        while nodes_to_check:
            node = nodes_to_check.pop()

            # Nodes already removed from the dict were visited earlier,
            # so each node is added to its subgraph only once
            if node in edges_per_node:
                subgraph.append(node)
                nodes_to_check.extend(edges_per_node[node])
                del edges_per_node[node]

        subgraphs.append(subgraph)

    return subgraphs
      
      



The network is limited to 500 neurons for the 10,000 data points.





utils.reproducible()
gng = create_gng(max_nodes=500)

# Train for 20 epochs, one epoch per call
for epoch in range(20):
    gng.train(data, epochs=1)

draw_image(gng.graph)
print("Found {} clusters".format(len(extract_subgraphs(gng.graph))))
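
Counting the subgraphs gives the number of clusters, but not a label for each data point. One possible way to get labels, sketched here as an assumption rather than as part of neupy, is to give every point the cluster of its nearest graph node. The helper `label_points` and the toy arrays below are hypothetical.

```python
import numpy as np

def label_points(points, node_weights, node_cluster_ids):
    """Give each point the cluster id of its nearest graph node.

    points:           (n_points, n_features) array
    node_weights:     (n_nodes, n_features) array of node positions
    node_cluster_ids: (n_nodes,) array, the subgraph index of each node
    """
    # Pairwise distances between points and nodes, shape (n_points, n_nodes)
    diffs = points[:, None, :] - node_weights[None, :, :]
    nearest = np.linalg.norm(diffs, axis=2).argmin(axis=1)
    return node_cluster_ids[nearest]

# Toy example: two nodes in cluster 0 and one node in cluster 1
node_weights = np.array([[0.0, 0.0], [0.5, 0.0], [10.0, 10.0]])
node_cluster_ids = np.array([0, 0, 1])
points = np.array([[0.1, 0.1], [9.0, 9.5], [0.6, -0.1]])
labels = label_points(points, node_weights, node_cluster_ids)
print(labels)
```

With a trained network, the two node arrays could be filled by looping over the result of extract_subgraphs and stacking each node's weight, assuming each weight is a 1x2 array as its use in draw_image suggests.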
      
      



, .





Now let's generate data with 3 clusters:





# Three uniform square blobs of 475, 375 and 50 points
X = -0.7 - 2.5 * np.random.rand(900, 2)
X1 = 0.7 + 2.5 * np.random.rand(375, 2)
X2 = -0.5 + 1.7 * np.random.rand(50, 2)
X[475:850, :] = X1
X[850:900, :] = X2
plt.scatter(X[:, 0], X[:, 1])
plt.show()
      
      



Although this data has no clear structure and the boundaries between the groups are only implicit, the expanding neural gas was able to correctly approximate the distribution and determine the number of clusters here.





utils.reproducible()
gng = create_gng(max_nodes=300)

# Train for 40 epochs, one epoch per call
for epoch in range(40):
    gng.train(X, epochs=1)

draw_image(gng.graph)
print("Found {} clusters".format(len(extract_subgraphs(gng.graph))))
      
      





