
Comments (5)

jack-willturner commented on May 30, 2024

Thanks for your interest :)

Could you share the code used to produce those histograms? That will make it easier to see what's going on.

I will also try to package up and include our original histogram plotting code this week.

from nas-without-training.

jack-willturner commented on May 30, 2024

I've added the plotting code in plot_histograms.py.

If you're happy with this I'll close the issue. Otherwise, I'm still happy to take a look at your code and try to spot anything different.


larenzhang commented on May 30, 2024

Thanks for your reply. My code is as follows:

from nas_201_api import NASBench201API as API
import argparse
import numpy as np
import torch
from datasets import get_datasets
from models import get_cell_based_tiny_net
import os
import matplotlib.pyplot as plt


parser = argparse.ArgumentParser(description='NAS Without Training')
parser.add_argument('--data_loc', default='~/dataset/cifar10/', type=str, help='dataset folder')
parser.add_argument('--api_loc', default='NAS-Bench-201-v1_1-096897.pth',
                    type=str, help='path to API')
parser.add_argument('--save_loc', default='jacob_matrix_saved', type=str, help='folder to save results')
parser.add_argument('--batch_size', default=256, type=int)
parser.add_argument('--GPU', default='0', type=str)
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--trainval', action='store_true')
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--n_samples', default=10, type=int)
parser.add_argument('--iters', default=1, type=int)
parser.add_argument('--corrcoef', default='pearson', type=str)

args = parser.parse_args()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
api = API(args.api_loc)
ARCHS_NUM = 15625
os.makedirs(args.save_loc, exist_ok=True)

train_data, valid_data, xshape, class_num = get_datasets(args.dataset, args.data_loc, cutout=0)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True)
print(device)

def get_batch_jacobian(net, x, target, to, device, args=None):
    net.zero_grad()

    x.requires_grad_(True)

    _, y = net(x)

    y.backward(torch.ones_like(y))
    jacob = x.grad.detach()

    return jacob, target.detach()

def get_batch_corrcoef(arch, x, target):
    config = api.get_net_config(arch, args.dataset)
    config['num_classes'] = 1
    network = get_cell_based_tiny_net(config)  # create the network from configuration
    network = network.to(device)
    network.eval()

    jacob, labels = get_batch_jacobian(network, x, target, 1, device, args)
    jacob = jacob.detach().cpu().numpy()
    jacob = np.reshape(jacob, (args.batch_size, -1))

    if args.corrcoef == 'pearson':
        jacob = np.corrcoef(jacob)
    else:
        raise NotImplementedError

    return jacob, config

def statistic_cifar10():
    if args.dataset == 'cifar10':
        acc_type = 'ori-test'
        if args.trainval:
            acc_type = 'x-valid'
    else:
        acc_type = 'x-test'
        val_acc_type = 'x-valid'

    dset = args.dataset if not args.trainval else 'cifar10-valid'

    archs_pre_up_90 = []
    archs_pre_80_to_90 = []
    archs_pre_70_to_80 = []
    archs_pre_60_to_70 = []
    archs_pre_below_60 = []

    for arch in range(ARCHS_NUM):
        info = api.query_by_index(arch)
        acc = info.get_metrics(dset, acc_type)['accuracy']
        if acc >= 90:
            archs_pre_up_90.append(arch)
        elif acc >= 80:
            archs_pre_80_to_90.append(arch)
        elif acc >= 70:
            archs_pre_70_to_80.append(arch)
        elif acc >= 60:
            archs_pre_60_to_70.append(arch)
        else:
            archs_pre_below_60.append(arch)

    archs_pre = [archs_pre_up_90, archs_pre_80_to_90, archs_pre_70_to_80, archs_pre_60_to_70, archs_pre_below_60]
    print('archs_pre_up_90 num:{0}'.format(len(archs_pre_up_90)))
    print('archs_pre_80_to_90 num:{0}'.format(len(archs_pre_80_to_90)))
    print('archs_pre_70_to_80 num:{0}'.format(len(archs_pre_70_to_80)))
    print('archs_pre_60_to_70 num:{0}'.format(len(archs_pre_60_to_70)))
    print('archs_pre_below_60 num:{0}'.format(len(archs_pre_below_60)))

    archs_pre_up_90_samples = np.random.choice(archs_pre_up_90, args.n_samples)
    archs_pre_80_to_90_samples = np.random.choice(archs_pre_80_to_90, args.n_samples)
    archs_pre_70_to_80_samples = np.random.choice(archs_pre_70_to_80, args.n_samples)
    archs_pre_60_to_70_samples = np.random.choice(archs_pre_60_to_70, args.n_samples)
    archs_pre_below_60_samples = np.random.choice(archs_pre_below_60, args.n_samples)

    total_sample_archs = [archs_pre_up_90_samples,archs_pre_80_to_90_samples,archs_pre_70_to_80_samples,archs_pre_60_to_70_samples,archs_pre_below_60_samples]
    jacobs_all = []
    archs_all_saved = []
    for i, archs in enumerate(total_sample_archs):
        jacobs = []
        archs_saved = []
        for arch in archs:
            data_iterator = iter(train_loader)
            x, target = next(data_iterator)
            x, target = x.to(device), target.to(device)

            jacob, config = get_batch_corrcoef(arch, x, target)

            if np.isnan(np.sum(jacob)):
                while True:
                    arch = np.random.choice(archs_pre[i], 1)[0]
                    if arch not in archs:
                        jacob, config = get_batch_corrcoef(arch, x, target)
                        if not np.isnan(np.sum(jacob)):
                            break

            genotype = config['arch_str']
            archs_saved.append(genotype)
            jacobs.append(jacob)

        archs_all_saved.append(archs_saved)
        jacobs_all.append(jacobs)

    plt_histogram(jacobs_all)

def plt_histogram(jacobs_all):

    for i, jacobs in enumerate(jacobs_all):
        for j, jacob in enumerate(jacobs):
            jacob = np.reshape(jacob, (args.batch_size, -1))
            plt.subplot(len(jacobs_all),len(jacobs),i*len(jacobs)+j+1)
            plt.hist(jacob.reshape(-1), bins=100)
            plt.xticks(range(-1,2))
            plt.xlim(-1.0,1.0)
            plt.xticks(fontsize=3)
            plt.yticks([])

    plt.tight_layout()
    plt.savefig('./correlation_matrix.png', dpi=300)
    # plt.show()


def main():
    if args.dataset == 'cifar10':
        statistic_cifar10()
    else:
        raise NotImplementedError

if __name__ == '__main__':
    main()


mellorjc commented on May 30, 2024

Hi,

Sorry for the slow response.

The important difference between our plotting code and yours (and our search.py code too) is that our plotting code does not switch the network to eval mode; it stays in training mode.

The models in NAS-Bench-201 use batch norm. By default, PyTorch batch norm has track_running_stats=True. With this flag set, eval mode does not normalise using the batch statistics; it uses the running statistics learned during training instead. Since we are not training our networks, we want to use the statistics of the batch. The simplest way to do this is to keep the network in train mode.
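To make the difference concrete, here is a minimal pure-Python sketch of a single batch norm channel. It is not the repo's code: `batch_norm_1d` is a hypothetical helper, and the learnable scale and shift are left out. It assumes the running statistics are still at their PyTorch initial values (mean 0, variance 1), as they would be in a network that has never been trained.

```python
import math

def batch_norm_1d(batch, running_mean=0.0, running_var=1.0, training=True, eps=1e-5):
    """Normalise one channel of activations (hypothetical helper, no affine parameters)."""
    if training:
        # Train mode: use the statistics of the current batch (biased variance).
        mean = sum(batch) / len(batch)
        var = sum((v - mean) ** 2 for v in batch) / len(batch)
    else:
        # Eval mode: use the stored running statistics instead of the batch.
        mean, var = running_mean, running_var
    return [(v - mean) / math.sqrt(var + eps) for v in batch]

batch = [10.0, 12.0, 14.0, 16.0]

# Train mode centres and rescales the batch.
train_out = batch_norm_1d(batch, training=True)

# Eval mode with untouched running stats leaves the batch almost unchanged.
eval_out = batch_norm_1d(batch, training=False)

print(train_out)
print(eval_out)
```

In train mode the output has zero mean and roughly unit variance. In eval mode with the initial running statistics, the activations pass through almost unchanged, so the layer performs essentially no normalisation for an untrained network.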

Using eval mode in our search.py code was an oversight on our part. We are in the process of updating our results and will update the arXiv paper when complete. Preliminarily, we see a small improvement in some experiments after removing eval from the code. Thank you very much for bringing this to our attention.


jack-willturner commented on May 30, 2024

I've updated the repo to remove the .eval() call, so I'll close this issue now.
