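# Регистрация | Вход ("Sign up | Log in") — web-page navigation text captured with this snippet; not part of the program.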
"""Train a small multilayer perceptron on MNIST and report its test accuracy.

Running this file as a script downloads MNIST into ``./data``, trains the
network with mini-batch SGD and a per-epoch exponential learning-rate decay,
then prints the wall-clock training time and the test-set accuracy.
"""
import time

import torch
import torch.nn as nn
import torch.optim as optim

EPOCHS = 15
BATCH_SIZE = 64
LEARNING_RATE = 0.1
LR_DECAY = 0.95  # the learning rate is multiplied by this factor after every epoch


# === 1. Loading MNIST ===
def load_mnist(root: str = "./data"):
    """Download (if needed) and load MNIST as flat tensors.

    Returns ``(train_x, train_y, test_x, test_y)``. The ``*_x`` tensors are
    float32 of shape (N, 784) with pixel values scaled to [0, 1] -- the same
    values ``transforms.ToTensor()`` followed by ``view(-1)`` produces -- and
    the ``*_y`` tensors are int64 class indices of shape (N,).
    """
    # Imported lazily so the model and training helpers below can be used
    # (and tested) without torchvision installed.
    import torchvision

    def _load_split(is_train: bool):
        """Return (images, labels) for the train or test split."""
        dataset = torchvision.datasets.MNIST(root=root, train=is_train, download=True)
        # Read the raw uint8 image tensor directly instead of pushing every
        # sample through PIL + ToTensor one by one: identical values
        # (uint8 -> float / 255), far faster. Labels stay as class indices;
        # there is no need to one-hot encode them and argmax them back.
        images = dataset.data.flatten(start_dim=1).float().div(255)
        return images, dataset.targets.long()

    train_x, train_y = _load_split(True)
    test_x, test_y = _load_split(False)
    return train_x, train_y, test_x, test_y


# === 2. Network definition ===
class MLP(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10 with ReLU activations.

    ``forward`` returns raw logits (no softmax); ``nn.CrossEntropyLoss``
    applies log-softmax internally.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        return self.net(x)


# === 4. Training ===
def train_model(
    model: nn.Module,
    inputs: torch.Tensor,
    labels: torch.Tensor,
    *,
    epochs: int = EPOCHS,
    batch_size: int = BATCH_SIZE,
    lr: float = LEARNING_RATE,
    lr_decay: float = LR_DECAY,
) -> list[float]:
    """Train *model* in place with mini-batch SGD and cross-entropy loss.

    *inputs* is a float tensor of shape (N, features) and *labels* an int64
    tensor of class indices of shape (N,). Batches are moved to the device of
    the model's parameters. After every epoch the learning rate is multiplied
    by *lr_decay* and one progress line is printed.

    Returns the per-epoch mean training loss (weighted by batch size).
    Raises ValueError if *inputs* is empty.
    """
    num_samples = len(inputs)
    if num_samples == 0:
        raise ValueError("cannot train on an empty dataset")

    device = next(model.parameters()).device
    optimizer = optim.SGD(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_decay)
    criterion = nn.CrossEntropyLoss()

    history = []
    model.train()
    for epoch in range(epochs):
        # Gather each batch through a fresh permutation instead of copying
        # the whole dataset into shuffled order every epoch.
        perm = torch.randperm(num_samples, device=inputs.device)
        # Accumulate on-device so there is no host sync after every step.
        total_loss = torch.zeros((), device=device)
        for first in range(0, num_samples, batch_size):
            idx = perm[first:first + batch_size]
            xb = inputs[idx].to(device)
            yb = labels[idx].to(device)

            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the smaller final batch doesn't skew the mean.
            total_loss += loss.detach() * len(idx)

        scheduler.step()  # learning-rate decay, once per epoch
        epoch_loss = total_loss.item() / num_samples
        history.append(epoch_loss)
        current_lr = scheduler.get_last_lr()[0]
        print(
            f"Epoch {epoch + 1}/{epochs} | Loss: {epoch_loss:.4f} | LR: {current_lr:.5f}"
        )
    return history


# === 5. Testing ===
@torch.no_grad()
def evaluate_accuracy(
    model: nn.Module,
    inputs: torch.Tensor,
    labels: torch.Tensor,
    *,
    batch_size: int = 1024,
) -> float:
    """Return the fraction of samples whose arg-max logit equals the label.

    Evaluates in batches of *batch_size* on the model's device, with the model
    switched to eval mode. Raises ValueError if *inputs* is empty.
    """
    num_samples = len(inputs)
    if num_samples == 0:
        raise ValueError("cannot evaluate on an empty dataset")

    device = next(model.parameters()).device
    model.eval()
    correct = torch.zeros((), dtype=torch.long, device=device)
    for first in range(0, num_samples, batch_size):
        xb = inputs[first:first + batch_size].to(device)
        yb = labels[first:first + batch_size].to(device)
        correct += (model(xb).argmax(dim=1) == yb).sum()
    return correct.item() / num_samples


def main():
    """Load MNIST, train the MLP, and print training time and test accuracy."""
    train_x, train_y, test_x, test_y = load_mnist()

    # === 3. Setup ===
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MLP().to(device)

    print("Training started...")
    start = time.perf_counter()
    # The per-epoch loss .item() inside train_model already waits for queued
    # GPU work, so the timing below covers all of training.
    train_model(model, train_x, train_y)
    elapsed = time.perf_counter() - start

    acc = evaluate_accuracy(model, test_x, test_y)
    print(f"\nTraining time: {elapsed:.2f} seconds")
    print(f"Final Accuracy: {acc:.4f}")
    print("Done!")


if __name__ == "__main__":
    main()