Tvoříme model podobný GPT (6.díl): Trénování modelu

V tomto návodu ukážeme, jak trénovat jednoduchý model neuronové sítě pro klasifikaci ručně psaných číslic z datasetu MNIST pomocí Keras (součást TensorFlow 2.0 a vyšší).

V tomto návodu ukážeme, jak trénovat jednoduchý model neuronové sítě pro klasifikaci ručně psaných číslic z datasetu MNIST pomocí Keras (součást TensorFlow 2.0 a vyšší).

Nainstalujte TensorFlow, pokud jej ještě nemáte

pip install tensorflow

Importujte potřebné moduly

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Dropout

Načtěte dataset MNIST

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

Předzpracujte dataset

x_train, x_test = x_train / 255.0, x_test / 255.0

Definujte architekturu modelu (použijeme stejnou architekturu jako v předchozím návodu)

input_shape = (28, 28)  # Rozměr vstupních obrázků
num_classes = 10        # Počet tříd (0-9)

model = Sequential([
    Flatten(input_shape=input_shape),
    Dense(128, activation='relu'),
    Dropout(0.2),
    Dense(num_classes, activation='softmax')
])

Kompilujte model

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

Trénujte model

epochs = 10

history = model.fit(x_train, y_train, epochs=epochs, validation_data=(x_test, y_test))

Vyhodnoťte model

test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f"Test loss: {test_loss}")
print(f"Test accuracy: {test_accuracy}")

Tento návod ukazuje, jak načíst dataset MNIST, předzpracovat data, definovat architekturu modelu, kompilovat model, trénovat model a vyhodnotit model na testovacích datech. Model je trénován pomocí funkce fit(), která bere trénovací data, počet epoch a validační data jako argumenty. Počet epoch můžete upravit podle potřeby.

Trénování modelu je často iterativní proces, který zahrnuje ladění hyperparametrů, architektury modelu a předzpracování dat, aby se dosáhlo co nejlepšího výkonu na testovacích nebo validačních datech.

Napsat komentář