
🤓🚀 How to use HyperOpt to optimize the hyperparameters of our models.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchtext.datasets import DBpedia
from torchtext.data.functional import to_map_style_dataset
from hyperopt import fmin, tpe, hp, Trials, STATUS_OK
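
# Note: `model`, `device`, `collate_batch`, `train` and `evaluate` are assumed
# to be defined earlier in the lesson (the model, the device selection, the
# batching function, and the per-epoch train/evaluate loops).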

# Objective function that HyperOpt minimizes: it trains with the sampled
# hyperparameters and reports the best validation loss reached.
def objective(params):
    # hp.quniform samples floats, so cast the integer-valued params
    EPOCHS = int(params['epochs'])
    TASA_APRENDIZAJE = params['lr']
    BATCH_TAMANO = int(params['batch_size'])

    criterion = nn.CrossEntropyLoss().to(device)
    # `model` is reused across trials; re-instantiate it here if each trial
    # should start from freshly initialized weights.
    optimizer = optim.SGD(model.parameters(), lr=TASA_APRENDIZAJE)

    train_iter, test_iter = DBpedia()
    train_dataset = to_map_style_dataset(train_iter)
    test_dataset = to_map_style_dataset(test_iter)

    # Hold out 5% of the training split for validation
    num_train = int(len(train_dataset) * 0.95)
    split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

    train_dataloader = DataLoader(split_train_, batch_size=BATCH_TAMANO, shuffle=True, collate_fn=collate_batch)
    valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_TAMANO, shuffle=True, collate_fn=collate_batch)

    best_valid_loss = float('inf')
    for epoch in range(1, EPOCHS + 1):
        train_loss, train_acc = train(train_dataloader, model, optimizer, criterion)
        validation_loss, validation_acc = evaluate(valid_dataloader, model, criterion)
        # Checkpoint the weights whenever validation loss improves
        if validation_loss < best_valid_loss:
            best_valid_loss = validation_loss
            torch.save(model.state_dict(), 'best-model.pt')
    
    # HyperOpt minimizes the value returned under 'loss'
    return {'loss': best_valid_loss, 'status': STATUS_OK}

# Define the search space
space = {
    'epochs': hp.quniform('epochs', 1, 10, 1),             # integers 1 to 10
    'lr': hp.loguniform('lr', -5, 0),                      # log-uniform from e^-5 (~0.0067) to 1.0
    'batch_size': hp.quniform('batch_size', 32, 256, 16),  # multiples of 16 in [32, 256]
}

trials = Trials()  # keeps a record of every evaluation

best = fmin(
    fn=objective,
    space=space,
    algo=tpe.suggest,  # Tree-structured Parzen Estimator
    max_evals=50,
    trials=trials
)

print("Best hyperparameters found were:", best)

With this simple snippet we can train the network while HyperOpt searches the hyperparameter space for the best possible result.
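
One caveat: `fmin` returns the raw sampled values, which are floats even for the integer-valued parameters. A minimal sketch of turning `best` into a usable configuration (the `final_config` name here is just illustrative) could look like this:

from hyperopt import space_eval

# space_eval maps the raw assignment back through the search space
# (mainly useful when the space contains hp.choice entries, whose raw
# values are just indices).
best_params = space_eval(space, best)

# quniform samples come back as floats, so cast before reusing them
final_config = {
    'epochs': int(best_params['epochs']),
    'lr': best_params['lr'],
    'batch_size': int(best_params['batch_size']),
}
print(final_config)

From there you can load 'best-model.pt' or rerun the training once with the final configuration.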
