Table des matières
KNN : Exemple des Iris
Il s'agit d'un exemple classique.
Les iris (fleurs) se répartissent en différentes espèces : Setosa, Versicolor et Virginica. On voudrait qu'un algorithme puisse décider, si on lui présente une iris, à quelle espèce elle appartient.
Les données
Un individu est ici une fleur.
Pour chaque individu, on mesure ses caractéristiques :
- longueur et largeur des sépales,
- longueur et largeur des pétales.
On a fait ce relever pour une centaine de fleurs et on obtient un fichier iris.csv dont voici un résumé :
| sepal_length | sepal_width | petal_length | petal_width | species |
|---|---|---|---|---|
| 5.1 | 3.5 | 1.4 | 0.2 | setosa |
| 4.9 | 3.0 | 1.4 | 0.2 | setosa |
| 4.7 | 3.2 | 1.3 | 0.2 | setosa |
| 4.6 | 3.1 | 1.5 | 0.2 | setosa |
| 5.0 | 3.6 | 1.4 | 0.2 | setosa |
| 5.4 | 3.9 | 1.7 | 0.4 | setosa |
| 4.6 | 3.4 | 1.4 | 0.3 | setosa |
| 5.0 | 3.4 | 1.5 | 0.2 | setosa |
| 4.4 | 2.9 | 1.4 | 0.2 | setosa |
| 4.9 | 3.1 | 1.5 | 0.1 | setosa |
L'unité est le centimètre.
Ce que l'IA doit faire
On présente une nouvelle fleur qui ne fait pas partie de la base de données d'entraînement. On veut deviner son espèce.
| longueur sépale | largeur sépale | longueur pétale | largeur pétale | espèce |
|---|---|---|---|---|
| 4.3 | 2.5 | 1.1 | 0.3 | ? |
L'algorithme KNN cherchera les K plus proches fleurs dans la base d'entraînement afin de décider l'espèce de notre inconnue.
Implémentation
Lecture des données
Commencez par récupérer le fichier iris.csv
Je vous propose ensuite d'utiliser Pandas. Nous allons en profiter pour exploiter quelques-unes de ses fonctions.
import pandas
import matplotlib.pyplot as plt # pour le graphique
data = pandas.read_csv("iris.csv", delimiter=",")
Exécutez.
En console, vous pouvez visualiser un aperçu du contenu de data :
>>> data.head()
Pandas s'est chargé de convertir en float les données numériques. On n'a donc rien de plus à faire !
Je vous demande de le faire en console pour ne pas encombrer le programme. On pourrait faire cet affichage depuis le programme.
Visualisation
Pour bien comprendre le fonctionnement de l'algorithme, nous allons faire un graphique.
Pour que le graphique reste lisible, on va supprimer une partie des données. On se contentera des pétales et on enlèvera les sépales (juste pour cette représentation graphique)
def graphique_light(data):
# suppression de deux colonnes. axis = 1 signifie colonne.
data_light = data.drop(['sepal_width', 'sepal_length'], axis = 1)
# choix de couleurs pour les différentes espèces
colormap = { "setosa":'b', "virginica":'g', "versicolor":'k' }
couleurs = [colormap[s] for s in data_light['species']]
# tracé
data.plot.scatter(x = 'petal_length', y = 'petal_width', c = couleurs)
# j'ajoute un point pour l'explication
plt.scatter([2],[0.5], c='k')
plt.show()
Exécutez et lancez la fonction en console.
>>> graphique_light(data)
Le point noir serait une fleur dont on ignore l'espèce. Vous comprenez sur ce graphique que, du fait de sa grande proximité avec le nuage bleu des Setosa, on proposera que l'espèce de cette inconnue est Setosa.
Fonctions utiles
Un individu se présente. Par exemple :
individu = {"sepal_length":4.3, "sepal_width":3.0, "petal_length":1.3, "petal_width":0.3}
On veut une fonction knn(k, data, a_etiqueter) qui décide de l'espèce ce cet individu :
>>> knn(3, data, individu) 'setosa'
Pour arriver à ce résultat nous avons besoin de quelques ingrédients que nous listons ci-dessous.
Distance
Écrire une fonction distance(item1, item2)->float
def distance(item1, item2)->float:
'''
item1, item2: 2 individus, par exemple une ligne de la base
ou l'individu à étiqueter. On accède aux attributs
par exemple par item1['petal_length']
renvoie la distance entre item1 et item2
'''
Test à effectuer
it1 = {"sepal_length":4.3, "sepal_width":3.0, "petal_length":1.3, "petal_width":0.3}
it2 = {"sepal_length":5.2, "sepal_width":2.8, "petal_length":1.7, "petal_width":0.1}
d = distance(it1, it2)
assert round(d,3) == 1.025
Le calcul de distance ne fait pas de difficulté dans le cas des iris car toutes les grandeurs en jeu sont comparables simplement. Il ne s'agit que de mesures en centimètres. Voyez cette page pour voir les problèmes que l'on peut avoir dans d'autres cas.
Liste des distances
L'inconnu à étiqueter se présente, on veut prendre sa distance avec chaque individu de la base en notant l'espèce.
Écrire la fonction liste_distances(data, a_etiqueter)
def liste_distances(data, a_etiqueter):
'''
data: données de la base d'entraînement
a_etiqueter: individu à étiqueter
renvoie un tableau constitué de paires (distance, espèce) pour
chaque individu de la base
'''
Pour parcourir les items de data vous pourrez utiliser for indice, item in data.iterrows()
Extraire k plus proches
Nous disposons d'une liste dont les items sont de la forme (distance, espèce).
Nous voulons extraire les espèces des k individus dont la distance est la plus faible.
Complétez la fonction ci-dessous.
def k_plus_proches(liste, k):
'''
liste: tableau d'items (distance, espèce)
k: nombre de plus proches à trouver
renvoie un tableau donnant les espèce des k plus proches
par exemple ['setosa', 'setosa', 'virginica']
'''
extraire le majoritaire
Nous disposons maintenant de la liste des k plus proches. On veut identifier le majoritaire.
Complétez.
def get_maj(plus_proches):
'''
plus_proches: liste d'espèce, par exemple ['setosa', 'setosa', 'virginica']
renvoie celui dont l'effectif est le plus grand
'''
Test à effectuer :
assert get_maj(['setosa', 'setosa', 'virginica']) == 'setosa' assert get_maj(['versicolor', 'virginica', 'virginica', 'setosa', 'virginica']) == 'virginica' m = get_maj(['versicolor', 'virginica', 'setosa', 'setosa', 'virginica']) assert m == 'setosa' or m == 'virginica'
Tout ensemble
Vous avez tous les ingrédients. Assemblez-les pour implémentez la fonction knn.
Test à effectuer :
assert knn(3, data, it1) == 'setosa'


