mirror of https://github.com/microsoft/autogen.git
396 lines
12 KiB
Plaintext
396 lines
12 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# PyTorch model tuning example on CIFAR10\n",
|
|
"This notebook uses FLAML to tune a PyTorch model on CIFAR10. It is adapted from [this example](https://docs.ray.io/en/master/tune/examples/cifar10_pytorch.html).\n",
|
|
"\n",
|
|
"**Requirements.** This notebook requires `torchvision` and `flaml` with the `blendsearch` and `ray` extras, which the next cell installs:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"%pip install torchvision flaml[blendsearch,ray]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Network Specification"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"import torch.nn.functional as F\n",
|
|
"import torch.optim as optim\n",
|
|
"from torch.utils.data import random_split\n",
|
|
"import torchvision\n",
|
|
"import torchvision.transforms as transforms\n",
|
|
"\n",
|
|
"\n",
|
|
"class Net(nn.Module):\n",
"    \"\"\"Two conv/pool stages followed by three fully connected layers, with 10 output classes.\n",
"\n",
"    `l1` and `l2` are the output widths of the first and second fully connected layers.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, l1=120, l2=84):\n",
"        super(Net, self).__init__()\n",
"        self.conv1 = nn.Conv2d(3, 6, 5)\n",
"        self.pool = nn.MaxPool2d(2, 2)\n",
"        self.conv2 = nn.Conv2d(6, 16, 5)\n",
"        self.fc1 = nn.Linear(16 * 5 * 5, l1)\n",
"        self.fc2 = nn.Linear(l1, l2)\n",
"        self.fc3 = nn.Linear(l2, 10)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return the raw (un-normalized) class scores for a batch of images `x`.\"\"\"\n",
"        x = self.pool(F.relu(self.conv1(x)))\n",
"        x = self.pool(F.relu(self.conv2(x)))\n",
"        # Flatten the 16 feature maps of size 5x5 into one vector per sample for fc1.\n",
"        x = x.view(-1, 16 * 5 * 5)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        x = self.fc3(x)\n",
"        return x"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def load_data(data_dir=\"data\"):\n",
"    \"\"\"Download CIFAR10 into `data_dir` if needed and return `(trainset, testset)`.\n",
"\n",
"    Both datasets convert images to tensors and normalize each channel with\n",
"    mean 0.5 and standard deviation 0.5.\n",
"    \"\"\"\n",
"    transform = transforms.Compose([\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
"    ])\n",
"\n",
"    trainset = torchvision.datasets.CIFAR10(\n",
"        root=data_dir, train=True, download=True, transform=transform)\n",
"\n",
"    testset = torchvision.datasets.CIFAR10(\n",
"        root=data_dir, train=False, download=True, transform=transform)\n",
"\n",
"    return trainset, testset"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Training"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from ray import tune\n",
|
|
"\n",
|
|
"def train_cifar(config, checkpoint_dir=None, data_dir=None):\n",
"    \"\"\"Train a `Net` on CIFAR10 and report validation metrics to Ray Tune after every epoch.\n",
"\n",
"    Args:\n",
"        config: hyperparameters. \"l1\", \"l2\" and \"batch_size\" are base-2 exponents,\n",
"            \"lr\" is the SGD learning rate, and \"num_epochs\" is rounded to an integer.\n",
"        checkpoint_dir: directory holding a \"checkpoint\" file to restore from, if any.\n",
"        data_dir: dataset root passed to `load_data`.\n",
"    \"\"\"\n",
"    # Imported here so the function does not rely on imports made in later cells.\n",
"    import logging\n",
"    import os\n",
"\n",
"    if \"l1\" not in config:\n",
"        # Log the unexpected config before the lookup below fails on the missing key.\n",
"        logging.getLogger(__name__).warning(config)\n",
"    net = Net(2**config[\"l1\"], 2**config[\"l2\"])\n",
"\n",
"    device = \"cpu\"\n",
"    if torch.cuda.is_available():\n",
"        device = \"cuda:0\"\n",
"        if torch.cuda.device_count() > 1:\n",
"            net = nn.DataParallel(net)\n",
"    net.to(device)\n",
"\n",
"    criterion = nn.CrossEntropyLoss()\n",
"    optimizer = optim.SGD(net.parameters(), lr=config[\"lr\"], momentum=0.9)\n",
"\n",
"    # The `checkpoint_dir` parameter gets passed by Ray Tune when a checkpoint\n",
"    # should be restored.\n",
"    if checkpoint_dir:\n",
"        checkpoint = os.path.join(checkpoint_dir, \"checkpoint\")\n",
"        model_state, optimizer_state = torch.load(checkpoint)\n",
"        net.load_state_dict(model_state)\n",
"        optimizer.load_state_dict(optimizer_state)\n",
"\n",
"    trainset, testset = load_data(data_dir)\n",
"\n",
"    # Hold out 20% of the training set for validation.\n",
"    test_abs = int(len(trainset) * 0.8)\n",
"    train_subset, val_subset = random_split(\n",
"        trainset, [test_abs, len(trainset) - test_abs])\n",
"\n",
"    trainloader = torch.utils.data.DataLoader(\n",
"        train_subset,\n",
"        batch_size=int(2**config[\"batch_size\"]),\n",
"        shuffle=True,\n",
"        num_workers=4)\n",
"    valloader = torch.utils.data.DataLoader(\n",
"        val_subset,\n",
"        batch_size=int(2**config[\"batch_size\"]),\n",
"        shuffle=True,\n",
"        num_workers=4)\n",
"\n",
"    for epoch in range(int(round(config[\"num_epochs\"]))):  # loop over the dataset multiple times\n",
"        running_loss = 0.0\n",
"        epoch_steps = 0\n",
"        for i, data in enumerate(trainloader, 0):\n",
"            # get the inputs; data is a list of [inputs, labels]\n",
"            inputs, labels = data\n",
"            inputs, labels = inputs.to(device), labels.to(device)\n",
"\n",
"            # zero the parameter gradients\n",
"            optimizer.zero_grad()\n",
"\n",
"            # forward + backward + optimize\n",
"            outputs = net(inputs)\n",
"            loss = criterion(outputs, labels)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            # print statistics\n",
"            running_loss += loss.item()\n",
"            epoch_steps += 1\n",
"            if i % 2000 == 1999:  # print every 2000 mini-batches\n",
"                print(\"[%d, %5d] loss: %.3f\" % (epoch + 1, i + 1,\n",
"                                                running_loss / epoch_steps))\n",
"                running_loss = 0.0\n",
"                # Reset the step count with the sum so the next printed average\n",
"                # covers only the mini-batches since this print.\n",
"                epoch_steps = 0\n",
"\n",
"        # Validation loss\n",
"        val_loss = 0.0\n",
"        val_steps = 0\n",
"        total = 0\n",
"        correct = 0\n",
"        for i, data in enumerate(valloader, 0):\n",
"            with torch.no_grad():\n",
"                inputs, labels = data\n",
"                inputs, labels = inputs.to(device), labels.to(device)\n",
"\n",
"                outputs = net(inputs)\n",
"                _, predicted = torch.max(outputs.data, 1)\n",
"                total += labels.size(0)\n",
"                correct += (predicted == labels).sum().item()\n",
"\n",
"                loss = criterion(outputs, labels)\n",
"                val_loss += loss.cpu().numpy()\n",
"                val_steps += 1\n",
"\n",
"        # Here we save a checkpoint. It is automatically registered with\n",
"        # Ray Tune and will potentially be passed as the `checkpoint_dir`\n",
"        # parameter in future iterations.\n",
"        with tune.checkpoint_dir(step=epoch) as checkpoint_dir:\n",
"            path = os.path.join(checkpoint_dir, \"checkpoint\")\n",
"            torch.save(\n",
"                (net.state_dict(), optimizer.state_dict()), path)\n",
"\n",
"        tune.report(loss=(val_loss / val_steps), accuracy=correct / total)\n",
"    print(\"Finished Training\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Test Accuracy"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def _test_accuracy(net, device=\"cpu\"):\n",
"    \"\"\"Return the fraction of CIFAR10 test images that `net` classifies correctly.\n",
"\n",
"    `net` is expected to already be on `device`; each batch is moved there before inference.\n",
"    \"\"\"\n",
"    # Only the test split is needed here.\n",
"    _, testset = load_data()\n",
"\n",
"    testloader = torch.utils.data.DataLoader(\n",
"        testset, batch_size=4, shuffle=False, num_workers=2)\n",
"\n",
"    correct = 0\n",
"    total = 0\n",
"    with torch.no_grad():\n",
"        for data in testloader:\n",
"            images, labels = data\n",
"            images, labels = images.to(device), labels.to(device)\n",
"            outputs = net(images)\n",
"            # The predicted class is the index of the largest score in each row.\n",
"            _, predicted = torch.max(outputs.data, 1)\n",
"            total += labels.size(0)\n",
"            correct += (predicted == labels).sum().item()\n",
"\n",
"    return correct / total"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Hyperparameter Optimization"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import flaml\n",
|
|
"import os\n",
|
|
"\n",
|
|
"data_dir = os.path.abspath(\"data\")\n",
|
|
"load_data(data_dir) # Download data for all trials before starting the run"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Search space"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"max_num_epoch = 100\n",
|
|
"config = {\n",
|
|
" \"l1\": tune.randint(2, 9), # log transformed with base 2\n",
|
|
" \"l2\": tune.randint(2, 9), # log transformed with base 2\n",
|
|
" \"lr\": tune.loguniform(1e-4, 1e-1),\n",
|
|
" \"num_epochs\": tune.loguniform(1, max_num_epoch),\n",
|
|
" \"batch_size\": tune.randint(1, 5) # log transformed with base 2\n",
|
|
"}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"time_budget_s = 3600 # time budget in seconds\n",
|
|
"gpus_per_trial = 0.5 # number of gpus for each trial; 0.5 means two training jobs can share one gpu\n",
|
|
"num_samples = 500 # maximal number of trials\n",
|
|
"np.random.seed(7654321)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Launch the tuning"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import time\n",
|
|
"start_time = time.time()\n",
|
|
"result = flaml.tune.run(\n",
|
|
" tune.with_parameters(train_cifar, data_dir=data_dir),\n",
|
|
" config=config,\n",
|
|
" metric=\"loss\",\n",
|
|
" mode=\"min\",\n",
|
|
" low_cost_partial_config={\"num_epochs\": 1},\n",
|
|
" max_resource=max_num_epoch,\n",
|
|
" min_resource=1,\n",
|
|
" scheduler=\"asha\", # need to use tune.report to report intermediate results in train_cifar \n",
|
|
" resources_per_trial={\"cpu\": 1, \"gpu\": gpus_per_trial},\n",
|
|
" local_dir='logs/',\n",
|
|
" num_samples=num_samples,\n",
|
|
" time_budget_s=time_budget_s,\n",
|
|
" use_ray=True)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Summarize the tuning run, then rebuild the best model and score it on the test set.\n",
"print(f\"#trials={len(result.trials)}\")\n",
"print(f\"time={time.time()-start_time}\")\n",
"best_trial = result.get_best_trial(\"loss\", \"min\", \"all\")\n",
"print(\"Best trial config: {}\".format(best_trial.config))\n",
"print(\"Best trial final validation loss: {}\".format(\n",
"    best_trial.metric_analysis[\"loss\"][\"min\"]))\n",
"print(\"Best trial final validation accuracy: {}\".format(\n",
"    best_trial.metric_analysis[\"accuracy\"][\"max\"]))\n",
"\n",
"best_trained_model = Net(2**best_trial.config[\"l1\"],\n",
"                         2**best_trial.config[\"l2\"])\n",
"device = \"cpu\"\n",
"if torch.cuda.is_available():\n",
"    device = \"cuda:0\"\n",
"    if gpus_per_trial > 1:\n",
"        best_trained_model = nn.DataParallel(best_trained_model)\n",
"best_trained_model.to(device)\n",
"\n",
"# The checkpoint location is read from `dir_or_data` when present, else from `value`.\n",
"checkpoint_value = (\n",
"    getattr(best_trial.checkpoint, \"dir_or_data\", None)\n",
"    or best_trial.checkpoint.value\n",
")\n",
"checkpoint_path = os.path.join(checkpoint_value, \"checkpoint\")\n",
"\n",
"# map_location lets a checkpoint saved on a GPU load when only the CPU is available here.\n",
"model_state, optimizer_state = torch.load(checkpoint_path, map_location=device)\n",
"best_trained_model.load_state_dict(model_state)\n",
"\n",
"test_acc = _test_accuracy(best_trained_model, device)\n",
"print(\"Best trial test set accuracy: {}\".format(test_acc))"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3.11.0 64-bit",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.0"
|
|
},
|
|
"metadata": {
|
|
"interpreter": {
|
|
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
|
|
}
|
|
},
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|