from multiprocessing import Pool
import multiprocessing


DATA_DIR = 'data/'
DATASET = 'fmnist'
TEST_DATA_FILENAME = DATA_DIR + DATASET + '/t10k-images-idx3-ubyte'
TEST_LABELS_FILENAME = DATA_DIR + DATASET + '/t10k-labels-idx1-ubyte'
TRAIN_DATA_FILENAME = DATA_DIR + DATASET + '/train-images-idx3-ubyte'
TRAIN_LABELS_FILENAME = DATA_DIR + DATASET + '/train-labels-idx1-ubyte'


def bytes_to_int(byte_data):
    return int.from_bytes(byte_data, 'big')


def read_images(filename, n_max_images=None):
    """Read images from an IDX3-ubyte file; each pixel is kept as a raw byte."""
    images = []
    with open(filename, 'rb') as f:
        _ = f.read(4)  # magic number
        n_images = bytes_to_int(f.read(4))
        if n_max_images:
            n_images = n_max_images
        n_rows = bytes_to_int(f.read(4))
        n_cols = bytes_to_int(f.read(4))
        for image_idx in range(n_images):
            image = []
            for row_idx in range(n_rows):
                row = []
                for col_idx in range(n_cols):
                    pixel = f.read(1)
                    row.append(pixel)
                image.append(row)
            images.append(image)
    return images


def read_labels(filename, n_max_labels=None):
    """Read labels from an IDX1-ubyte file."""
    labels = []
    with open(filename, 'rb') as f:
        _ = f.read(4)  # magic number
        n_labels = bytes_to_int(f.read(4))
        if n_max_labels:
            n_labels = n_max_labels
        for label_idx in range(n_labels):
            label = bytes_to_int(f.read(1))
            labels.append(label)
    return labels


def flatten_list(subl):
    return [pixel for sublist in subl for pixel in sublist]


def extract_features(X):
    # Flatten each 2D image into a 1D list of pixel bytes.
    return [flatten_list(sample) for sample in X]


def dist(x, y):
    # Euclidean distance between two flattened images.
    return sum(
        (bytes_to_int(x_i) - bytes_to_int(y_i)) ** 2
        for x_i, y_i in zip(x, y)
    ) ** 0.5


def classify_one(args):
    test_sample_idx, test_sample, X_train, y_train, k = args  # unpack args
    print(test_sample_idx)
    training_distances = [dist(train_sample, test_sample) for train_sample in X_train]
    sorted_distance_indices = [
        pair[0]
        for pair in sorted(enumerate(training_distances), key=lambda x: x[1])
    ]
    # Majority vote among the k nearest neighbours.
    candidates = [y_train[idx] for idx in sorted_distance_indices[:k]]
    top_candidate = max(candidates, key=candidates.count)
    return top_candidate


def knn(X_train, y_train, X_test, k=3):
    # Classify each test sample in parallel across all available CPU cores.
    with Pool(processes=multiprocessing.cpu_count()) as pool:
        work_items = [
            (test_sample_idx, test_sample, X_train, y_train, k)
            for test_sample_idx, test_sample in enumerate(X_test)
        ]
        y_pred = pool.map(classify_one, work_items)
    return y_pred


def get_garment_from_label(label):
    return [
        'T-shirt/top',
        'Trouser',
        'Pullover',
        'Dress',
        'Coat',
        'Sandal',
        'Shirt',
        'Sneaker',
        'Bag',
        'Ankle boot',
    ][label]


def main():
    n_train = 6000
    n_test = 100
    k = 7
    print(f"Dataset: {DATASET}")
    print(f"n_train: {n_train}")
    print(f"n_test: {n_test}")
    print(f"k: {k}")

    X_train = read_images(TRAIN_DATA_FILENAME, n_train)
    y_train = read_labels(TRAIN_LABELS_FILENAME, n_train)
    X_test = read_images(TEST_DATA_FILENAME, n_test)
    y_test = read_labels(TEST_LABELS_FILENAME, n_test)

    X_train = extract_features(X_train)
    X_test = extract_features(X_test)

    y_pred = knn(X_train, y_train, X_test, k)

    accuracy = sum([
        int(y_pred_i == y_test_i)
        for y_pred_i, y_test_i in zip(y_pred, y_test)
    ]) / len(y_test)

    if DATASET == 'fmnist':
        garments_pred = [
            get_garment_from_label(label)
            for label in y_pred
        ]
        print(f"Predicted garments: {garments_pred}")
    else:
        print(f"Predicted labels: {y_pred}")

    print(f"Accuracy: {accuracy * 100}%")


if __name__ == '__main__':
    main()