Comments (7)

Wang-jun-yu commented on June 18, 2024


Setting lr to 0 in optimizer_stdp = SGD(params_stdp, lr=0., momentum=0.) gives the following output:
epoch : 0 ; train_loss : 0.10000000149011612 ; train_acc : 0.10852148579752367
epoch : 1 ; train_loss : 0.10000000149011612 ; train_acc : 0.1088856518572469
epoch : 2 ; train_loss : 0.10000000149011612 ; train_acc : 0.1088856518572469
epoch : 3 ; train_loss : 0.10000000149011612 ; train_acc : 0.1088856518572469
epoch : 4 ; train_loss : 0.10000000149011612 ; train_acc : 0.1088856518572469
epoch : 5 ; train_loss : 0.10000000149011612 ; train_acc : 0.1088856518572469
epoch : 6 ; train_loss : 0.10000000149011612 ; train_acc : 0.1088856518572469
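For reference, this is a minimal sketch of the kind of setup behind params_stdp, optimizer_stdp, and stdp_learners, in the style of the SpikingJelly STDP tutorial. The toy network, layer indices, time constants, and learning rates below are purely illustrative, and STDPLearner's exact signature may differ across SpikingJelly versions:

import torch
from torch.optim import SGD
from spikingjelly.activation_based import layer, neuron, learning

# toy network: the first conv synapse is trained by STDP, everything else by GD
net = torch.nn.Sequential(
    layer.Conv2d(1, 8, kernel_size=3, padding=1, bias=False),
    neuron.IFNode(),
    layer.Flatten(),
    layer.Linear(8 * 28 * 28, 10, bias=False),
    neuron.IFNode(),
)

# split the parameters: STDP updates the conv weight, GD updates the rest
params_stdp = list(net[0].parameters())
stdp_ids = {id(p) for p in params_stdp}
params_gd = [p for p in net.parameters() if id(p) not in stdp_ids]

# one STDP learner per (synapse, spiking neuron) pair trained by STDP
stdp_learners = [
    learning.STDPLearner(step_mode='s', synapse=net[0], sn=net[1],
                         tau_pre=2., tau_post=2.,
                         f_pre=lambda w: w, f_post=lambda w: w)
]

# with lr=0. the STDP updates written into .grad are never actually applied
optimizer_stdp = SGD(params_stdp, lr=0., momentum=0.)
optimizer_gd = SGD(params_gd, lr=0.1, momentum=0.9)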

After switching to training with the gradient-descent optimizer, classification works normally. The modified code is as follows:
import time
import torch
import torch.nn.functional as F
from spikingjelly.activation_based import functional

optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
for epoch in range(start_epoch, args.epochs):
    start_time = time.time()
    net.train()
    for i in range(len(stdp_learners)):
        stdp_learners[i].enable()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for img, label in train_loader:
        optimizer.zero_grad()
        img = img.to(args.device)
        label = label.to(args.device)
        label_onehot = F.one_hot(label, 10).float()
        out_fr = net(img)
        loss = F.mse_loss(out_fr, label_onehot)
        loss.backward()
        optimizer.step()
        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()
        torch.cuda.empty_cache()
        functional.reset_net(net)

    train_loss /= train_samples
    train_acc /= train_samples
    print('epoch : ', epoch, '  ;  ', 'train_loss : ', train_loss, '  ;  ', 'train_acc : ', train_acc)

The output is as follows:
epoch : 0 ; train_loss : 0.12450837841269663 ; train_acc : 0.13000728332119446
epoch : 1 ; train_loss : 0.13798252292362687 ; train_acc : 0.1540422432629279
epoch : 2 ; train_loss : 0.12840495524196754 ; train_acc : 0.1540422432629279

The CSNN used in this experiment reaches up to 95% accuracy on my dataset, so the network itself should be fine. I don't know what goes wrong when training with STDP that makes it unable to classify.
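For comparison, here is a rough sketch of a combined GD + STDP update step in the style of the SpikingJelly STDP tutorial. It reuses net, train_loader, and args from the code above and optimizer_gd, optimizer_stdp, and stdp_learners from the setup sketch, and it assumes that STDPLearner.step(on_grad=True) writes the STDP update into the synapse weights' .grad so that an optimizer step applies it; exact API details may differ by version:

import torch.nn.functional as F
from spikingjelly.activation_based import functional

for img, label in train_loader:
    img = img.to(args.device)
    label = label.to(args.device)

    optimizer_gd.zero_grad()
    optimizer_stdp.zero_grad()

    out_fr = net(img)
    loss = F.mse_loss(out_fr, F.one_hot(label, 10).float())
    loss.backward()

    # discard the surrogate gradients on the STDP-trained weights,
    # then let the learners write the STDP update into their .grad
    optimizer_stdp.zero_grad()
    for learner in stdp_learners:
        learner.step(on_grad=True)

    optimizer_gd.step()
    optimizer_stdp.step()   # applies the STDP update; with lr=0. this is a no-op

    functional.reset_net(net)
    for learner in stdp_learners:
        learner.reset()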

fangwei123456 commented on June 18, 2024

I suggest setting the STDP learning rate to 0 and first checking whether the network converges with GD alone, to verify that the network is being trained correctly.
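For this check, one way to switch STDP off entirely, assuming the tutorial-style names above and that STDPLearner exposes disable() as the counterpart of the enable() call used in the posted code:

# keep lr=0. so optimizer_stdp never applies an update, and pause the learners
optimizer_stdp = SGD(params_stdp, lr=0., momentum=0.)
for learner in stdp_learners:
    learner.disable()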

fangwei123456 commented on June 18, 2024

STDP is an unsupervised learner, so there is no guarantee that performance improves when it is used. If pure GD training works fine, you will have to tune the STDP parameters patiently.
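In practice the knobs to tune are the trace time constants, the weight-dependence functions, and the effective STDP step size. A sketch with purely illustrative values, reusing net, params_stdp, learning, and SGD from the setup above:

def f_weight(w):
    # bound the weight-dependent factor; the clipping range is itself a tunable choice
    return torch.clamp(w, -1., 1.)

stdp_learners = [
    learning.STDPLearner(step_mode='s', synapse=net[0], sn=net[1],
                         tau_pre=2., tau_post=100.,       # pre-/post-synaptic trace time constants
                         f_pre=f_weight, f_post=f_weight)
]
optimizer_stdp = SGD(params_stdp, lr=0.004, momentum=0.)  # effective STDP step size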

EECSPeanuts commented on June 18, 2024

Hello, I've run into the same problem. Have you solved it?

fangwei123456 commented on June 18, 2024

I want to stress again that this is a feature of STDP, not a bug 😂

Yanqi-Chen commented on June 18, 2024

STDP is an algorithm that simply does not work well on deep SNNs, even when the STDP implementation itself is correct; the tutorial only demonstrates how to use it.

thebug-dev commented on June 18, 2024

I just ran into this problem as well. It looks like I'll have to skip STDP for now.
