Comments (7)
After setting lr to 0 in optimizer_stdp = SGD(params_stdp, lr=0., momentum=0.), the output is:
epoch : 0 ; train_loss : 0.10000000149011612 ; train_acc : 0.10852148579752367
epoch : 1 ; train_loss : 0.10000000149011612 ; train_acc : 0.1088856518572469
epoch : 2 ; train_loss : 0.10000000149011612 ; train_acc : 0.1088856518572469
epoch : 3 ; train_loss : 0.10000000149011612 ; train_acc : 0.1088856518572469
epoch : 4 ; train_loss : 0.10000000149011612 ; train_acc : 0.1088856518572469
epoch : 5 ; train_loss : 0.10000000149011612 ; train_acc : 0.1088856518572469
epoch : 6 ; train_loss : 0.10000000149011612 ; train_acc : 0.1088856518572469
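A note on the constant loss in the log above: 0.10000000149011612 is exactly what mean squared error gives when the 10-class output is all zeros and the target is one-hot, since (1² + 9·0²)/10 = 0.1, and 0.1 stored as float32 prints as that value. The sketch below checks the arithmetic in plain Python. It assumes the output layer never fires, which the log suggests but does not prove; `mse`, `to_float32`, and `silent_output` are illustrative names, not part of the original script.

```python
import struct


def mse(pred, target):
    """Mean squared error over all elements (same reduction as a plain mean)."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def to_float32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


num_classes = 10
silent_output = [0.0] * num_classes  # assumption: the output neurons never spike
one_hot = [0.0] * num_classes
one_hot[3] = 1.0                     # any label yields the same loss

loss = mse(silent_output, one_hot)
print(loss)              # 0.1
print(to_float32(loss))  # 0.10000000149011612
```

An all-zero output would also make argmax pick the same class for every sample, which would be consistent with the accuracy staying near 0.109, though the log alone cannot confirm this. Printing out_fr.sum() for one batch is a quick way to check whether the last layer emits any spikes.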
After changing the code to train with optimizer, the network can classify normally. The modified code is as follows:
import time

import torch
import torch.nn.functional as F
from spikingjelly.activation_based import functional

# net, args, train_loader, stdp_learners and start_epoch are defined earlier in the script.
optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
for epoch in range(start_epoch, args.epochs):
    start_time = time.time()
    net.train()
    # stdp_learners is assumed to be a Python list, so use len() instead of .len()
    for i in range(len(stdp_learners)):
        stdp_learners[i].enable()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for img, label in train_loader:
        optimizer.zero_grad()
        img = img.to(args.device)
        label = label.to(args.device)
        label_onehot = F.one_hot(label, 10).float()
        out_fr = net(img)
        loss = F.mse_loss(out_fr, label_onehot)
        loss.backward()
        optimizer.step()
        # Accumulate per-sample statistics for the epoch averages
        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()
        torch.cuda.empty_cache()
        # Reset neuron states so the next batch starts from a clean network
        functional.reset_net(net)
    train_loss /= train_samples
    train_acc /= train_samples
    print('epoch : ',epoch,' ; ','train_loss : ',train_loss,' ; ','train_acc : ',train_acc)
The output is:
epoch : 0 ; train_loss : 0.12450837841269663 ; train_acc : 0.13000728332119446
epoch : 1 ; train_loss : 0.13798252292362687 ; train_acc : 0.1540422432629279
epoch : 2 ; train_loss : 0.12840495524196754 ; train_acc : 0.1540422432629279
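The CSNN used in this experiment reaches up to 95% accuracy on my dataset, so the network itself should be fine. However, I don't know what goes wrong when training with STDP that makes it unable to classify.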
from spikingjelly.
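I suggest setting the STDP learning rate to 0 and first checking whether the network converges with GD alone, to verify that the network is being trained correctly.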
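STDP is an unsupervised learner, so there is no guarantee that using it will improve performance. If training with GD alone works fine, then you will have to tune the STDP parameters step by step.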
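To illustrate why STDP carries no such guarantee, here is a minimal sketch of a trace-based pair STDP update for a single synapse. It is a simplified, commonly used form of the rule, not the spikingjelly implementation; the function name `stdp_delta_w` and the values of `tau_pre`, `tau_post`, and `lr` are assumptions chosen for illustration. Note that no label or loss appears anywhere in the update.

```python
def stdp_delta_w(pre_spikes, post_spikes, tau_pre=2.0, tau_post=2.0, lr=0.01):
    """Accumulate the weight change of one synapse over a spike train.

    pre_spikes / post_spikes: sequences of 0/1 values, one entry per time step.
    """
    trace_pre = 0.0   # decaying memory of recent presynaptic spikes
    trace_post = 0.0  # decaying memory of recent postsynaptic spikes
    delta_w = 0.0
    for s_pre, s_post in zip(pre_spikes, post_spikes):
        # Exponential decay of each trace, plus the spike at this step
        trace_pre = trace_pre - trace_pre / tau_pre + s_pre
        trace_post = trace_post - trace_post / tau_post + s_post
        # Potentiate when post fires after pre, depress when pre fires after post
        delta_w += lr * (trace_pre * s_post - trace_post * s_pre)
    return delta_w


pre_then_post = stdp_delta_w([1, 0, 0], [0, 1, 0])  # pre leads: weight grows
post_then_pre = stdp_delta_w([0, 1, 0], [1, 0, 0])  # post leads: weight shrinks
print(pre_then_post > 0, post_then_pre < 0)
```

Because the update depends only on relative spike timing, it can move weights in a direction unrelated to the classification loss, which is why tau_pre, tau_post, and the STDP learning rate usually need tuning against the GD-trained baseline.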
Hello, I ran into the same problem. Have you solved it?
I want to stress again: this is a feature of STDP, not a bug 😂
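STDP is inherently an algorithm that does not work on deep SNNs, even when the STDP implementation itself is correct. The tutorial only demonstrates how to use it.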
I just ran into this problem too. It looks like I will have to go without STDP for now.