Notice: This page requires JavaScript to function properly.
Please enable JavaScript in your browser settings or update your browser.
Oppiskele Opetus-Testausjako ja Ristiinvalidointi | K-NN-luokitin
Luokittelu Pythonilla

Opetus-Testausjako ja Ristiinvalidointi

Pyyhkäise näyttääksesi valikon

Edellisissä luvuissa rakensimme mallit ja ennustimme uusia arvoja. Mutta meillä ei ole tietoa siitä, kuinka hyvin malli suoriutuu ja ovatko nämä ennusteet luotettavia.

Train-Test Split

Mallin suorituskyvyn mittaamiseksi tarvitsemme osan merkittyä dataa, jota malli ei ole nähnyt. Siksi jaamme kaikki merkityt tiedot satunnaisesti koulutusjoukkoon ja testijoukkoon.

traintestset

Tämä on mahdollista käyttämällä train_test_split()-kirjaston sklearn-funktiota.

TrainTestFunc

Yleensä malli jaetaan noin 70–90 % koulutusjoukolle ja 10–30 % testijoukolle.

Note
Huomio

Kun tietoaineistossa on miljoonia havaintoja, muutaman tuhannen käyttäminen testaukseen on yleensä enemmän kuin riittävästi. Tällaisissa tapauksissa voit varata testaukseen jopa alle 10 % aineistosta.

Nyt voidaan kouluttaa malli opetusaineistolla ja arvioida sen tarkkuus testiaineistolla.

123456789101112131415161718192021
from sklearn.neighbors import KNeighborsClassifier from sklearn.preprocessing import StandardScaler import pandas as pd from sklearn.model_selection import train_test_split df = pd.read_csv('https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b71ff7ac-3932-41d2-a4d8-060e24b00129/starwars_binary.csv') X = df.drop('StarWars6', axis=1) y = df['StarWars6'] # Splitting the data X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) scaler = StandardScaler() X_train = scaler.fit_transform(X_train) X_test = scaler.transform(X_test) knn = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train) # Printing the accuracy on the test set print(knn.score(X_test, y_test))

Mutta tässä lähestymistavassa on joitakin puutteita:

  • Emme käytä kaikkea saatavilla olevaa dataa mallin kouluttamiseen, mikä voisi parantaa malliamme;
  • Koska arvioimme mallin tarkkuutta pienellä osalla dataa (testijoukko), tämä tarkkuus voi olla epäluotettava pienillä aineistoilla. Voit suorittaa yllä olevan koodin useita kertoja ja havaita, kuinka tarkkuus muuttuu joka kerta, kun uusi testijoukko arvotaan.

Ristiinvalidointi

Ristiinvalidointi on suunniteltu ratkaisemaan ylisovittamisen ongelmaa ja varmistamaan, että malli yleistyy hyvin uuteen, aiemmin näkemättömään dataan. Voit ajatella sitä kuin luokkahuonekoulutuksena mallillesi — se auttaa mallia oppimaan tasapainoisemmin ennen varsinaista lopputestiä.

Ideana on sekoittaa koko aineisto ja jakaa se n yhtä suureen osaan, joita kutsutaan taitoksiksi (folds). Tämän jälkeen malli käy läpi n iteraatiota. Jokaisessa iteraatiossa n-1 taitosta käytetään koulutukseen ja 1 taitos validointiin. Näin jokainen osa datasta toimii kerran validointina, ja saamme luotettavamman arvion mallin suorituskyvystä.

Huomioi, että ristiinvalidointi ei korvaa testijoukkoa. Kun olet käyttänyt ristiinvalidointia mallin valintaan ja hienosäätöön, arvioi se erillisellä testijoukolla saadaksesi puolueettoman arvion mallin todellisesta suorituskyvystä.

Note
Huomio

Yleinen valinta taitosten määräksi on 5. Tällöin yksi taitos toimii testijoukkona ja loput 4 taitosta käytetään koulutukseen.

risti

Koulutetaan viisi mallia hieman erilaisilla osajoukoilla. Jokaiselle mallille lasketaan testijoukon tarkkuus:

accuracy=predicted correctlypredicted correctly+predicted incorrectly\text{accuracy} = \frac{\text{predicted correctly}}{\text{predicted correctly} + \text{predicted incorrectly}}

Kun tämä on tehty, voidaan laskea näiden viiden tarkkuuden keskiarvo, joka toimii ristiinvalidoinnin tarkkuutena:

accuracyavg=accuracy1+accuracy2+...+accuracy55\text{accuracy}_{avg} = \frac{\text{accuracy}_1+\text{accuracy}_2+...+\text{accuracy}_5}{5}

Luotettavampi tulos, koska laskimme tarkkuusluvun käyttäen kaikkia tietojamme – vain jaettuna eri tavoin viidessä iteraatiossa.

Nyt kun tiedämme, kuinka hyvin malli suoriutuu, voimme kouluttaa sen uudelleen koko aineistolla.

Onneksi sklearn tarjoaa cross_val_score()-funktion mallin arviointiin ristivalidoinnilla, joten sinun ei tarvitse toteuttaa sitä itse:

CrossValFunc

Tässä esimerkki siitä, miten ristivalidointia käytetään k-NN-mallin kanssa, joka on koulutettu Star Wars -arvosanadatan avulla:

12345678910111213141516171819
from sklearn.neighbors import KNeighborsClassifier from sklearn.preprocessing import StandardScaler import pandas as pd from sklearn.model_selection import cross_val_score df = pd.read_csv('https://codefinity-content-media.s3.eu-west-1.amazonaws.com/b71ff7ac-3932-41d2-a4d8-060e24b00129/starwars_binary.csv') X = df.drop('StarWars6', axis=1) y = df['StarWars6'] scaler = StandardScaler() X = scaler.fit_transform(X) knn = KNeighborsClassifier(n_neighbors=3) # Calculating the accuracy for each split scores = cross_val_score(knn, X, y, cv=5) print('Scores: ', scores) print('Average score:', scores.mean())

Luokittelussa oletusarvoisesti käytetty mittari on tarkkuus.

question mark

Valitse kaikki oikeat väittämät.

Valitse kaikki oikeat vastaukset

Oliko kaikki selvää?

Miten voimme parantaa sitä?

Kiitos palautteestasi!

Osio 1. Luku 6

Kysy tekoälyä

expand

Kysy tekoälyä

ChatGPT

Kysy mitä tahansa tai kokeile jotakin ehdotetuista kysymyksistä aloittaaksesi keskustelumme

Osio 1. Luku 6
some-alt