KNN: Python Code
-Priya R. Bachan, Rishika
Code
from csv import reader
from math import sqrt
# Load a CSV file
def load_csv(filename):
dataset = list()
with open(filename, 'r') as file:
csv_reader = reader(file)
for row in csv_reader:
if not row:
continue
dataset.append(row)
return dataset
# Convert string column to float
def str_column_to_float(dataset, column):
for row in dataset:
row[column] = float(row[column].strip())
# Convert string column to integer
def str_column_to_int(dataset, column):
class_values = [row[column] for row in dataset]
unique = set(class_values)
lookup = dict()
for i, value in enumerate(unique):
lookup[value] = i
print('[%s] => %d' % (value, i))
for row in dataset:
row[column] = lookup[row[column]]
return lookup
# Find the min and max values for each column
def dataset_minmax(dataset):
minmax = list()
for i in range(len(dataset[0])):
col_values = [row[i] for row in dataset]
value_min = min(col_values)
value_max = max(col_values)
minmax.append([value_min, value_max])
return minmax
# Rescale dataset columns to the range 0-1
def normalize_dataset(dataset, minmax):
for row in dataset:
for i in range(len(row)):
row[i] = (row[i] - minmax[i][0]) / (minmax[i][1] - minmax[i][0])
# Calculate the Euclidean distance between two vectors
def euclidean_distance(row1, row2):
distance = 0.0
for i in range(len(row1)-1):
distance += (row1[i] - row2[i])**2
return sqrt(distance)
# Locate the most similar neighbors
def get_neighbors(train, test_row, num_neighbors):
distances = list()
for train_row in train:
dist = euclidean_distance(test_row, train_row)
distances.append((train_row, dist))
distances.sort(key=lambda tup: tup[1])
neighbors = list()
for i in range(num_neighbors):
neighbors.append(distances[i][0])
return neighbors
# Make a prediction with neighbors
def predict_classification(train, test_row, num_neighbors):
neighbors = get_neighbors(train, test_row, num_neighbors)
output_values = [row[-1] for row in neighbors]
prediction = max(set(output_values), key=output_values.count)
return prediction
# Make a prediction with KNN on Iris Dataset
filename = 'iris.csv'
dataset = load_csv(filename)
for i in range(len(dataset[0])-1):
str_column_to_int(dataset, i)
# convert class column to integers
str_column_to_int(dataset, len(dataset[0])-1)
# define model parameter
num_neighbors = 5
# define a new record
row = [5.7,2.9,4.2,1.3]
# predict the label
label = predict_classification(dataset, row, num_neighbors)
print('Data=%s, Predicted: %s' % (row, label))
output
[5.3] => 0
[4.8] => 1
[7.6] => 2
[7.4] => 3
[6.8] => 4
[5.7] => 5
[6.1] => 6
[7.2] => 7
[4.6] => 8
[5.4] => 9
[4.7] => 10
[7] => 11
[4.3] => 12
[6.4] => 13
[5.6] => 14
[6.6] => 15
[7.9] => 16
[5.1] => 17
[sepal.length] => 18
[7.3] => 19
[6] => 20
[4.5] => 21
[6.9] => 22
[6.2] => 23
[4.4] => 24
[7.7] => 25
[6.7] => 26
[5.2] => 27
[5.9] => 28
[4.9] => 29
[6.3] => 30
[7.1] => 31
[6.5] => 32
[5] => 33
[5.8] => 34
[5.5] => 35
[3.4] => 0
[3.8] => 1
[2.2] => 2
[2.9] => 3
[2.8] => 4
[3.9] => 5
[3] => 6
[2] => 7
[2.3] => 8
[sepal.width] => 9
[3.1] => 10
[4] => 11
[4.1] => 12
[3.6] => 13
[4.4] => 14
[2.7] => 15
[2.6] => 16
[3.5] => 17
[4.2] => 18
[2.5] => 19
[3.2] => 20
[3.3] => 21
[3.7] => 22
[2.4] => 23
[3.8] => 0
[5.3] => 1
[1.9] => 2
[1] => 3
[4.8] => 4
[3.9] => 5
[5.7] => 6
[6.1] => 7
[3] => 8
[4.6] => 9
[1.1] => 10
[5.4] => 11
[4.7] => 12
[4.3] => 13
[5.6] => 14
[6.4] => 15
[6.6] => 16
[1.4] => 17
[1.6] => 18
[5.1] => 19
[1.7] => 20
[4] => 21
[6] => 22
[4.1] => 23
[4.5] => 24
[6.9] => 25
[1.5] => 26
[3.6] => 27
[4.4] => 28
[1.3] => 29
[6.7] => 30
[5.2] => 31
[3.5] => 32
[4.2] => 33
[3.3] => 34
[5.9] => 35
[4.9] => 36
[1.2] => 37
[petal.length] => 38
[3.7] => 39
[6.3] => 40
[5] => 41
[5.8] => 42
[5.5] => 43
[petal.width] => 0
[2.2] => 1
[2.1] => 2
[1.9] => 3
[0.5] => 4
[1] => 5
[0.6] => 6
[0.4] => 7
[1.1] => 8
[1.4] => 9
[1.6] => 10
[2] => 11
[2.3] => 12
[1.7] => 13
[0.1] => 14
[0.2] => 15
[1.5] => 16
[1.3] => 17
[2.5] => 18
[1.8] => 19
[0.3] => 20
[1.2] => 21
[2.4] => 22
[Setosa] => 0
[Versicolor] => 1
[Virginica] => 2
[variety] => 3
[] => 0
[] => 0
[] => 0
[] => 0
[] => 0
[] => 0
[] => 0
[] => 0
[] => 0
[] => 0
[] => 0
[] => 0
[] => 0
[] => 0
[5.4] => 1
[sepal.length] => 2
[5.8] => 3
[5.2] => 4
[3.5] => 0
[] => 1
[3] => 2
[2.7] => 3
[sepal.width] => 4
[] => 0
[4.5] => 1
[petal.length] => 2
[1.5] => 3
[5.1] => 4
[] => 0
[petal.width] => 1
[1.9] => 2
[0.2] => 3
[1.5] => 4
[] => 0
[?] => 1
[variety] => 2
Data=[5.7, 2.9, 4.2, 1.3], Predicted: 0