KNN principle and Python code implementation


I. Principle

1. Overview

The nearest neighbor algorithm is arguably one of the most basic machine learning algorithms: it relies on few assumptions, yet it often achieves a good learning effect.

1.1 Core idea

Birds of a feather flock together. As the saying goes, you can judge whether a man is good by the friends around him. The same holds for the samples we want to learn from and predict: we can judge which category a sample belongs to from the samples around it. The K samples nearest to it in the feature space should belong to the same group as the sample to be predicted. Therefore, if most of those K samples belong to category A, the sample to be predicted is also assigned to category A.

1.2 Purpose

KNN can be used for classification (binary and multi-class, with no change to the algorithm) and also for regression. Which of the two it performs depends mainly on how the output is computed. For classification, the K samples nearest to the sample to be predicted usually vote on the category, and the minority yields to the majority (majority voting). For regression, the outputs of the K nearest samples are averaged and the mean is taken as the output for the sample to be predicted (the average method).
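The two output rules can be sketched in a few lines. This is a minimal illustration that assumes the K nearest neighbors have already been found; the function names `vote` and `average` are hypothetical and not part of the article's code.

```python
from collections import Counter

def vote(neighbor_labels):
    """Classification: return the most common label among the K neighbors."""
    return Counter(neighbor_labels).most_common(1)[0][0]

def average(neighbor_values):
    """Regression: return the mean output of the K neighbors."""
    return sum(neighbor_values) / len(neighbor_values)

print(vote(["A", "B", "A"]))      # A
print(average([1.0, 2.0, 6.0]))   # 3.0
```

Only the last step differs between the two uses; the neighbor search itself is identical.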

1.3 Advantages and disadvantages of KNN


Advantages:

  • KNN makes no assumption about the data distribution, so it can be applied to all kinds of data sets. In general, though, dense data works better; with very sparse data the value of K is harder to control and the prediction is easily misled;
  • The algorithm is simple, easy to understand and implement, and handles both classification and regression.

Disadvantages:

  • KNN has no training or learning stage; it must search the data again for every sample it classifies, so prediction is slow;
  • It is sensitive to class imbalance and easily swamped by categories with many samples (weighting the votes by distance, so that the nearest data counts more, can help);
  • A KD tree implementation consumes extra memory.
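The distance weighting mentioned above could look like the following sketch. The inverse-distance weight `1 / (distance + eps)` is one common choice and is an assumption here, as are the function name and the sample data.

```python
from collections import defaultdict

def weighted_vote(neighbors, eps=1e-9):
    """neighbors: list of (distance, label) pairs for the K nearest samples.

    Each neighbor votes with weight 1 / (distance + eps), so closer
    samples count more than distant ones.
    """
    scores = defaultdict(float)
    for distance, label in neighbors:
        scores[label] += 1.0 / (distance + eps)
    return max(scores, key=scores.get)

# One very close sample of a rare class against two farther samples of a common class
neighbors = [(0.5, "rare"), (2.0, "common"), (2.5, "common")]
print(weighted_vote(neighbors))  # rare
```

A plain majority vote would return "common" for the same neighbors, which is the imbalance effect the weighting is meant to soften.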

2. Algorithm flow and implementation

2.1 KNN algorithm flow (taking KNN classification algorithm as an example)

Input: the training set \(T=\{(x_{1}, y_{1}),(x_{2}, y_{2}),\ldots,(x_{n}, y_{n})\}\), where \(x_{i}\) is the feature vector of a sample and \(y_{i}\) is its category, together with the sample \(x\) to be predicted
Output: the category \(y\) of sample \(x\)
(1) Choose a distance metric, find the K points in the training set \(T\) closest to the sample \(x\), and denote the neighborhood of \(x\) that contains these K points by \(N_{K}(x)\);
(2) Determine the category \(y\) of \(x\) from \(N_{K}(x)\) according to the classification decision rule (such as majority voting).
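Steps (1) and (2) can be sketched with a direct scan over the training set. This is an illustrative brute-force version that assumes Euclidean distance and majority voting; `knn_predict` and the toy data are hypothetical names and values.

```python
import math
from collections import Counter

def knn_predict(train_x, train_y, x, k):
    # Step (1): rank training points by distance to x and keep the k closest
    order = sorted(range(len(train_x)), key=lambda i: math.dist(train_x[i], x))
    neighborhood = order[:k]
    # Step (2): majority vote over the labels in the neighborhood
    votes = Counter(train_y[i] for i in neighborhood)
    return votes.most_common(1)[0][0]

train_x = [(1, 1), (1, 2), (2, 1), (6, 6), (7, 6), (6, 7)]
train_y = ["A", "A", "A", "B", "B", "B"]
print(knn_predict(train_x, train_y, (2, 2), k=3))  # A
print(knn_predict(train_x, train_y, (6, 5), k=3))  # B
```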

From the flow above, four points are key to running the algorithm:

  • The choice of distance measurement method;
  • The selection of K value;
  • How to find the sample points of the k nearest neighbors;
  • The selection of classification decision rules.

Among these, the distance metric, the value of K, and the classification decision rule are called the three elements of the KNN algorithm; they determine how well the algorithm performs. How the K nearest neighbors are found determines the concrete computation in a KNN implementation and therefore its complexity. Each of the four points is analyzed below.

(1) Distance measurement. Common distance metrics include Euclidean distance, Manhattan distance, and Minkowski distance. Euclidean distance is the one most commonly used; it is basic material and is not covered further here.
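For reference, the three metrics are related: the Minkowski distance with exponent p reduces to the Manhattan distance for p = 1 and to the Euclidean distance for p = 2. A small sketch (the function name is illustrative):

```python
def minkowski_distance(a, b, p):
    """Minkowski distance between two equal-length points a and b."""
    return sum(abs(ai - bi) ** p for ai, bi in zip(a, b)) ** (1.0 / p)

a, b = (0, 0), (3, 4)
print(minkowski_distance(a, b, 1))  # Manhattan distance: 7.0
print(minkowski_distance(a, b, 2))  # Euclidean distance: 5.0
```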

(2) Selection of the K value. There is no universally good way to choose K. Usually cross-validation is used to pick a suitable value, for example with scikit-learn's GridSearchCV. The choice of K strongly affects the predictions. The smaller K is, the more complex the model: bias decreases, variance increases, and the model tends to overfit. The larger K is, the simpler the model; in the extreme case K = n, every sample to be predicted is assigned to the same category. K should therefore be chosen with care.
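To make the cross-validation idea concrete without depending on any library, here is a small leave-one-out sketch. It is not GridSearchCV; the helper names and the toy data are assumptions made for illustration.

```python
import math
from collections import Counter

def loo_accuracy(points, labels, k):
    """Leave-one-out accuracy of a brute-force KNN classifier for a given k."""
    correct = 0
    for i, x in enumerate(points):
        # Rank every other point by its distance to the held-out point x
        others = [j for j in range(len(points)) if j != i]
        others.sort(key=lambda j: math.dist(points[j], x))
        votes = Counter(labels[j] for j in others[:k])
        if votes.most_common(1)[0][0] == labels[i]:
            correct += 1
    return correct / len(points)

def choose_k(points, labels, candidates):
    """Return the candidate k with the highest leave-one-out accuracy."""
    return max(candidates, key=lambda k: loo_accuracy(points, labels, k))

points = [(1, 1), (1, 2), (2, 1), (2, 2), (6, 6), (6, 7), (7, 6), (7, 7)]
labels = ["A"] * 4 + ["B"] * 4
for k in (1, 3, 7):
    print(k, loo_accuracy(points, labels, k))
print(choose_k(points, labels, [1, 3, 7]))
```

On this toy data k = 7 uses all remaining points as neighbors, so the other cluster always outvotes the held-out point's own cluster, which mirrors the K = n extreme described above.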

(3) Selection of the classification decision rule. Classification problems generally use majority voting.

(4) Finding the K nearest neighbors. There are three common methods:

a. Brute force. Compute the distance from the prediction point to every sample in the training set by direct traversal, then sort and keep the K nearest points. Computing the distances takes O(n) time per query for n training samples, and a full sort adds O(n log n);

b. KD tree. In practical engineering applications, brute force becomes hard to afford when there are hundreds or thousands of features and millions of samples. The search can be sped up by partitioning the search space with an index tree, which brings the typical time complexity of a query down to O(log n). The details are discussed below;

c. Ball tree. The ball tree is similar to the KD tree but avoids some of the redundant computation the KD tree performs. The main difference is that a KD tree node covers the hyperrectangle containing its samples, while a ball tree node covers the smallest hypersphere containing its samples. This hypersphere is smaller than the corresponding hyperrectangle of the KD tree, so some unnecessary searching can be avoided during nearest neighbor search.

Next, the KD tree based implementation of KNN is described in detail. Why the KD tree rather than the ball tree? First, the two are almost the same. Second, the KD tree seems to have some inexplicable connection with Kevin Durant.

2.2 KD tree

KD tree is short for k-dimensional tree. Its job is to help us quickly find the data points close to a target point. Doesn't that sound like the number guessing game? One person secretly picks a number (the target data point), and the others keep guessing to narrow down the range. Isn't that just like finding the K points closest to a target point?

This brings to mind binary search, and the core idea of the KD tree is the same: divide and conquer. A KD tree is therefore a binary tree, and each node splits the data points into two parts along one dimension. Imagine the data space cut into many small cells along its dimensions; to find the data points near a new point, we mainly need to find the cell that contains it and search from there.

But along which dimension should each node split, and where? Naturally, we pick the dimension along which the data is most spread out. For example, a large variance means the data points are scattered along that coordinate axis, so splitting in that direction gives the best resolution. Isn't this similar to the idea behind decision trees? There are differences: a decision tree can be used directly for classification, its feature selection is more complex (information gain ratio, for example), each feature is used only once, and features need not be numerical. Still, the basic idea is very similar. Leaving decision trees aside, let's see how to build a KD tree.

2.2.1 Construction of the KD tree

Let's start with a common simple example. Suppose there are six two-dimensional data points \(\{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)\}\) located in a two-dimensional space. As described above, we need to partition the space into many small parts according to these points. The KD tree generated from the six points partitions the plane as follows:

How is this done?

  • Choose the splitting dimension (split). The variances of the six points in the x and y dimensions are 6.97 and 5.37 respectively. The variance along the x-axis is larger, so the first dimension is used for the first split;
  • Determine the splitting point (7,2) (node-data). Sort the data by their x values. The median of the six x values (the value in the middle) is 6, so the splitting point is (7,2) (choosing (5,4) would also work). The splitting hyperplane of this node passes through (7,2) and is perpendicular to the split dimension, i.e. it is the line x = 7;
  • Determine the left subspace (left) and the right subspace (right). The hyperplane x = 7 divides the whole space into two parts: the part with x <= 7 is the left subspace, containing the three nodes {(2,3), (5,4), (4,7)}; the other part is the right subspace, containing the two nodes {(9,6), (8,1)};
  • Recurse. The left subtree's nodes {(2,3), (5,4), (4,7)} and the right subtree's nodes {(9,6), (8,1)} are split in the same way.

The result is shown in the following figure: (7,2) is the root node, (5,4) and (9,6) are its left and right children, (2,3) and (4,7) are the left and right children of (5,4), and (8,1) is the left child of (9,6). This completes the k-d tree.
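The numbers in the first split can be reproduced with the standard library. The sketch below assumes the sample variance (dividing by n - 1), which is what yields 6.97 and 5.37, and takes the upper of the two middle points as the splitting point; `choose_split` is a hypothetical helper name.

```python
import statistics

points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]

def choose_split(points):
    """Return the split dimension, splitting point, both halves, and the variances."""
    dims = len(points[0])
    # Sample variance of the coordinates in each dimension
    variances = [statistics.variance([p[d] for p in points]) for d in range(dims)]
    split = variances.index(max(variances))
    ordered = sorted(points, key=lambda p: p[split])
    mid = len(ordered) // 2  # upper middle element when the count is even
    return split, ordered[mid], ordered[:mid], ordered[mid + 1:], variances

split, node, left, right, variances = choose_split(points)
print([round(v, 2) for v in variances])  # [6.97, 5.37]
print(split, node)                       # 0 (7, 2)
print(left, right)
```

Calling `choose_split` again on `left` and on `right` gives the next level of the tree, with (5,4) and (9,6) as the splitting points.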

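From the process and result above, we can summarize the basic data structure and construction process of the KD tree.
Each node of the k-d tree stores the following fields: node-data (the data point held by the node), split (the index of the splitting dimension), left and right (the KD trees built from the points in the left and right subspaces), and Range (the region of space the node covers).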

Pseudo code of KD tree construction process:

Input: the data point set DataSet and the space Range it occupies
Output: Kd, a KD tree

1. If DataSet is empty, return an empty KD tree;
2. Otherwise call the node-data generation procedure:
   a. Compute split. For all the data (high-dimensional vectors), compute the variance in each dimension; the dimension with the largest variance is the value of the split field;
   b. Compute node-data. Sort DataSet by its values in the split dimension; the data point closest to the median is chosen as node-data;
3. dataleft = {d in DataSet, d is not node-data, and d[split] <= node-data[split]};
   Left-Range = {Range && dataleft};
   dataright = {d in DataSet and d[split] > node-data[split]};
   Right-Range = {Range && dataright};
4. left = KD tree built from (dataleft, Left-Range);
   set left as the left child of Kd;
   right = KD tree built from (dataright, Right-Range);
   set right as the right child of Kd;
5. Building left and right in step 4 repeats steps 1 to 4 on each subset.

According to the above pseudo code, we can write a program to build KD tree. Python implementation is as follows:

Python code of KD tree construction process:

import numpy as np

class KDNode(object):
	def __init__(self, node_data, split, left, right):
		self.node_data = node_data  # the data point stored at this node
		self.split = split          # index of the splitting dimension
		self.left = left            # subtree for the left subspace
		self.right = right          # subtree for the right subspace

class KDTree(object):
	def __init__(self, dataset):
		self.dim = len(dataset[0])
		self.tree = self.generate_kdtree(dataset)

	def generate_kdtree(self, dataset):
		"""Recursively generate a tree."""
		if not dataset:
			return None
		split, ordered, indx = self.cal_node_data(dataset)
		# Slice the sorted list instead of filtering with < and >, so that
		# points sharing the node's split value are not dropped
		left = ordered[:indx]
		right = ordered[indx + 1:]
		return KDNode(ordered[indx], split, self.generate_kdtree(left), self.generate_kdtree(right))

	def cal_node_data(self, dataset):
		"""Choose the split dimension and sort the data along it."""
		# Standard deviation replaces variance; the ranking is the same
		std_lst = [np.std([d[i] for d in dataset]) for i in range(self.dim)]
		split = std_lst.index(max(std_lst))
		# sorted() leaves the caller's list untouched
		ordered = sorted(dataset, key=lambda x: x[split])
		indx = len(ordered) // 2  # same index as int((n + 2) / 2) - 1
		return split, ordered, indx

2.2.2 Using the KD tree to find the K nearest neighbors

(1) Using KD tree to get nearest neighbor

1) Search down the binary tree until a leaf node is reached, and take that leaf as the current nearest neighbor. This process determines a search path;
2) Because each split of the tree uses only one feature, the leaf reached this way is not necessarily the nearest neighbor once all features are considered. The true nearest neighbor must, however, lie inside the circle centered on the query point and passing through that leaf (otherwise it would be farther away than the leaf). We therefore backtrack along the search path to find the nearest neighbor.

For example, search for the nearest neighbor of (3,4.5) in the KD tree built above:

  • A binary tree search for (3,4.5) gives the search path (7,2), (5,4), (4,7);
  • First, (4,7) is taken as the current nearest point; its distance to the query point is 2.69;
  • Then backtrack to its parent node (5,4) and check whether the subspace of the parent's other child could contain data points closer to the query point. Draw a circle centered at (3,4.5) with radius 2.69, as shown in the figure below. The circle intersects the hyperplane y = 4, so the other side must be searched. The distance between (5,4) and (3,4.5) is 2.06, less than 2.69, so the nearest neighbor is updated to (5,4), the distance is updated to 2.06, and a green circle is drawn;
  • Because the circle extends into the other part of the (5,4) split, we search that region as well. The distance between node (2,3) and the target point is 1.8, closer than (5,4), so the nearest neighbor is updated to (2,3) and a blue circle is drawn centered at (3,4.5);
  • Backtrack to (7,2). The blue circle does not intersect the hyperplane x = 7, so there is no need to search the right subspace of (7,2). All nodes on the search path have now been backtracked and the search ends, returning the nearest neighbor (2,3) with distance 1.8.
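The walk-through above can be replayed with a short recursive search. This sketch hardcodes the example tree as nested tuples rather than using the article's classes; the tuple layout, the `nearest` function, and the split axis given to leaf nodes (which is never decisive) are assumptions made for illustration.

```python
import math

# Example tree as (point, split_axis, left, right); None marks an empty subtree
TREE = ((7, 2), 0,
        ((5, 4), 1,
         ((2, 3), 0, None, None),
         ((4, 7), 0, None, None)),
        ((9, 6), 1,
         ((8, 1), 0, None, None),
         None))

def nearest(node, target, best=None):
    """Return (distance, point) of the nearest neighbor of target."""
    if node is None:
        return best
    point, axis, left, right = node
    # Descend first into the side of the split that contains the target
    near, far = (left, right) if target[axis] <= point[axis] else (right, left)
    best = nearest(near, target, best)
    # On the way back up, compare this node with the best found so far
    d = math.dist(point, target)
    if best is None or d < best[0]:
        best = (d, point)
    # Search the other side only if the circle crosses the splitting line
    if abs(target[axis] - point[axis]) < best[0]:
        best = nearest(far, target, best)
    return best

dist, point = nearest(TREE, (3, 4.5))
print(point, round(dist, 2))  # (2, 3) 1.8
```

The candidates are visited in the same order as in the example: (4,7), then (5,4), then (2,3), and the right subspace of (7,2) is skipped.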

That is how the nearest neighbor is found, but it yields only the single nearest point. How do we find the K nearest neighbors?

(2) Using KD tree to get k-nearest neighbor

Li Hang's "Statistical Learning Methods" covers this in a single sentence: the K nearest neighbors can be found along the same lines. So let's adapt the idea above. Before we looked for one point and now we need K, so we can maintain an ordered collection of length K. Whenever the search finds a distance smaller than the largest one in the collection, the new point replaces that largest entry, and when the search ends the collection holds the K nearest neighbors. A priority queue seems a very reasonable way to implement this collection. OK, let's try:

from queue import PriorityQueue
import numpy as np

class NodeDis(object):
	"""Bind a point to its distance from the target.

	The comparison is reversed (self.distance > other.distance), so the
	PriorityQueue keeps the farthest candidate at the front.
	"""

	def __init__(self, node, dis):
		self.node = node
		self.distance = dis

	def __lt__(self, other):
		return self.distance > other.distance

def search_k_neighbour(kdtree, target, k):
	k_queue = PriorityQueue()
	return search_path(kdtree, target, k, k_queue)

def search_path(kdtree, target, k, k_queue):
	"""Recursively find the k nearest neighbors in the whole tree."""
	if kdtree is None:
		return k_queue
	# Descend to a leaf, recording each node and the subtree not taken
	path = []
	while kdtree:
		if target[kdtree.split] <= kdtree.node_data[kdtree.split]:
			path.append((kdtree.node_data, kdtree.split, kdtree.right))
			kdtree = kdtree.left
		else:
			path.append((kdtree.node_data, kdtree.split, kdtree.left))
			kdtree = kdtree.right
	# Backtrack from the leaf towards the root
	for node_data, split, opposite_tree in reversed(path):
		distance = cal_Euclidean_dis(node_data, target)
		k_queue.put(NodeDis(node_data, distance))
		if k_queue.qsize() > k:
			k_queue.get()  # discard the farthest candidate
		# Radius = distance of the current k-th neighbor (infinite until k are found)
		radius = k_queue.queue[0].distance if k_queue.qsize() == k else np.inf
		# Search the other side only if the circle reaches the splitting axis
		distance_axis = abs(node_data[split] - target[split])
		if distance_axis <= radius:
			search_path(opposite_tree, target, k, k_queue)
	return k_queue

def cal_Euclidean_dis(point1, point2):
	return np.sqrt(np.sum((np.array(point1) - np.array(point2)) ** 2))

Run the search on the example data and output the 3 nearest neighbors of (3,4.5), as shown in the following figure:
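The expected neighbors can be cross-checked without any tree. The sketch below applies the same bounded-collection idea to a plain scan: a heap of at most k entries whose front is always the farthest candidate kept so far. It uses `heapq` with negated distances rather than the article's `PriorityQueue` wrapper, and `k_nearest` is a hypothetical name.

```python
import heapq
import math

def k_nearest(points, target, k):
    """Return the k nearest points as (distance, point), closest first."""
    heap = []  # entries are (-distance, point); heap[0] is the farthest one kept
    for p in points:
        d = math.dist(p, target)
        if len(heap) < k:
            heapq.heappush(heap, (-d, p))
        elif d < -heap[0][0]:
            # Closer than the current farthest candidate: replace it
            heapq.heapreplace(heap, (-d, p))
    return sorted((-neg_d, p) for neg_d, p in heap)

points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
for d, p in k_nearest(points, (3, 4.5), 3):
    print(p, round(d, 2))
# (2, 3) 1.8
# (5, 4) 2.06
# (4, 7) 2.69
```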


2.3 KNN classification based on brute-force search and the KD tree

Using the brute-force search and the KD tree search for the K nearest neighbors from the previous section, we implement a KNN classifier.

Implementation of KNN in Python

import numpy as np
from collections import Counter
import kd_tree  # KD tree implemented in the previous section

class KNNClassfier(object):

	def __init__(self, k, distance='Euclidean', kdtree=True):
		"""Set the distance metric, K, and whether to use the KD tree."""
		self.distance = distance
		self.k = k
		self.kdtree = kdtree

	def get_k_neighb(self, new_point, train_data, labels):
		"""Labels of the K nearest neighbors, found by brute-force search."""
		distance_lst = [(self.cal_Euclidean_dis(new_point, point), label) for point, label in zip(train_data, labels)]
		distance_lst.sort(key=lambda x: x[0], reverse=False)
		k_labels = [i[1] for i in distance_lst[:self.k]]
		return k_labels

	def fit(self, train_data):
		"""Build the KD tree in advance, which is a little like a learning step.

		To improve efficiency, the constructed tree can be saved and reused.
		"""
		# Pass a copy so that building the tree cannot reorder train_data
		kdtree = kd_tree.KDTree(list(train_data))
		return kdtree

	def get_k_neighb_kdtree(self, new_point, train_data, labels, kdtree):
		"""Labels of the K nearest neighbors, found by KD tree search."""
		result = kd_tree.search_k_neighbour(kdtree.tree, new_point, self.k)
		# Points must be hashable (e.g. tuples) to serve as dictionary keys
		data_dict = {data: label for data, label in zip(train_data, labels)}
		k_labels = [data_dict[data.node] for data in result.queue]
		return k_labels

	def predict(self, new_point, train_data, labels, kdtree=None):
		if self.kdtree:
			k_labels = self.get_k_neighb_kdtree(new_point, train_data, labels, kdtree)
		else:
			k_labels = self.get_k_neighb(new_point, train_data, labels)
		# print(k_labels)
		return self.decision_rule(k_labels)

	def decision_rule(self, k_labels):
		"""Classification decision rule: majority voting."""
		label_count = Counter(k_labels)
		new_label = None
		max_count = 0  # renamed from max to avoid shadowing the built-in
		for label, count in label_count.items():
			if count > max_count:
				new_label = label
				max_count = count
		return new_label

	def cal_Euclidean_dis(self, point1, point2):
		return np.sqrt(np.sum((np.array(point1) - np.array(point2)) ** 2))

The classification results are as follows. The two implementations of KNN give the same results (the two square points in the figure are the points to be classified).