Commit c0543f7

Added Manhattan and Minkowski distance metrics to KNN algorithm (#13546)
1 parent bba005a commit c0543f7

File tree: 1 file changed, +36 -19 lines
machine_learning/k_nearest_neighbours.py

Lines changed: 36 additions & 19 deletions
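For reference (not part of the diff itself), the three metrics selectable via the new distance_metric parameter follow the standard definitions; Minkowski generalizes the other two, reducing to Manhattan at p = 1 and Euclidean at p = 2:

d_{\text{euclidean}}(a, b) = \sqrt{\sum_i (a_i - b_i)^2}
d_{\text{manhattan}}(a, b) = \sum_i |a_i - b_i|
d_{\text{minkowski}}(a, b) = \left( \sum_i |a_i - b_i|^p \right)^{1/p}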
@@ -14,7 +14,6 @@
 
 from collections import Counter
 from heapq import nsmallest
-
 import numpy as np
 from sklearn import datasets
 from sklearn.model_selection import train_test_split
@@ -26,23 +25,36 @@ def __init__(
         train_data: np.ndarray[float],
         train_target: np.ndarray[int],
         class_labels: list[str],
+        distance_metric: str = "euclidean",
+        p: int = 2,
     ) -> None:
         """
-        Create a kNN classifier using the given training data and class labels
+        Create a kNN classifier using the given training data and class labels.
+
+        Parameters:
+        -----------
+        distance_metric : str
+            Type of distance metric to use ('euclidean', 'manhattan', 'minkowski')
+        p : int
+            Power parameter for Minkowski distance (default 2)
         """
-        self.data = zip(train_data, train_target)
+        self.data = list(zip(train_data, train_target))
         self.labels = class_labels
+        self.distance_metric = distance_metric
+        self.p = p
 
-    @staticmethod
-    def _euclidean_distance(a: np.ndarray[float], b: np.ndarray[float]) -> float:
+    def _calculate_distance(self, a: np.ndarray[float], b: np.ndarray[float]) -> float:
         """
-        Calculate the Euclidean distance between two points
-        >>> KNN._euclidean_distance(np.array([0, 0]), np.array([3, 4]))
-        5.0
-        >>> KNN._euclidean_distance(np.array([1, 2, 3]), np.array([1, 8, 11]))
-        10.0
+        Calculate distance between two points based on the selected metric.
         """
-        return float(np.linalg.norm(a - b))
+        if self.distance_metric == "euclidean":
+            return float(np.linalg.norm(a - b))
+        elif self.distance_metric == "manhattan":
+            return float(np.sum(np.abs(a - b)))
+        elif self.distance_metric == "minkowski":
+            return float(np.sum(np.abs(a - b) ** self.p) ** (1 / self.p))
+        else:
+            raise ValueError("Invalid distance metric. Choose 'euclidean', 'manhattan', or 'minkowski'.")
 
     def classify(self, pred_point: np.ndarray[float], k: int = 5) -> str:
         """
@@ -57,23 +69,18 @@ def classify(self, pred_point: np.ndarray[float], k: int = 5) -> str:
         >>> knn.classify(point)
         'A'
         """
-        # Distances of all points from the point to be classified
         distances = (
-            (self._euclidean_distance(data_point[0], pred_point), data_point[1])
+            (self._calculate_distance(data_point[0], pred_point), data_point[1])
             for data_point in self.data
         )
 
-        # Choosing k points with the shortest distances
         votes = (i[1] for i in nsmallest(k, distances))
-
-        # Most commonly occurring class is the one into which the point is classified
         result = Counter(votes).most_common(1)[0][0]
         return self.labels[result]
 
 
 if __name__ == "__main__":
     import doctest
-
     doctest.testmod()
 
     iris = datasets.load_iris()
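As a side note, the voting logic that classify() keeps relying on, nsmallest over (distance, class_index) pairs followed by a Counter majority vote, can be seen in isolation on toy data (an illustrative sketch only, not code from the repository):

from collections import Counter
from heapq import nsmallest

# (distance, class_index) pairs, as produced by the generator in classify()
distances = [(0.5, 0), (1.2, 1), (0.7, 0), (3.0, 1), (0.9, 1)]

k = 3
votes = (label for _, label in nsmallest(k, distances))  # three closest carry labels 0, 0, 1
result = Counter(votes).most_common(1)[0][0]             # majority vote -> 0
print(result)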
@@ -84,5 +91,15 @@ def classify(self, pred_point: np.ndarray[float], k: int = 5) -> str:
 
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
     iris_point = np.array([4.4, 3.1, 1.3, 1.4])
-    classifier = KNN(X_train, y_train, iris_classes)
-    print(classifier.classify(iris_point, k=3))
+
+    print("\nUsing Euclidean Distance:")
+    classifier1 = KNN(X_train, y_train, iris_classes, distance_metric="euclidean")
+    print(classifier1.classify(iris_point, k=3))
+
+    print("\nUsing Manhattan Distance:")
+    classifier2 = KNN(X_train, y_train, iris_classes, distance_metric="manhattan")
+    print(classifier2.classify(iris_point, k=3))
+
+    print("\nUsing Minkowski Distance (p=3):")
+    classifier3 = KNN(X_train, y_train, iris_classes, distance_metric="minkowski", p=3)
+    print(classifier3.classify(iris_point, k=3))
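Not part of this commit, but if SciPy is available (it already ships as a scikit-learn dependency), its distance functions can act as an independent sanity check for the three branches; the sample vectors below are only illustrative:

import numpy as np
from scipy.spatial import distance

a = np.array([4.4, 3.1, 1.3, 1.4])
b = np.array([5.0, 3.6, 1.4, 0.2])

print(distance.euclidean(a, b))       # compare with the "euclidean" branch
print(distance.cityblock(a, b))       # compare with the "manhattan" branch
print(distance.minkowski(a, b, p=3))  # compare with the "minkowski" branch with p=3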
