Thanks for pointing out the problem in our code!
I've added the evaluate_cls and predgt2doa_cls functions to the learner. Use them for the classification task.
The error occurs because the evaluate function of the regression task is being used.
Once I'm done with some other work, I'll refactor parts of the code to make it easier to read and use.
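For orientation, a classification head outputs one score per candidate direction, so turning a prediction into a DOA means picking the best-scoring class and looking up its angle. The sketch below is hypothetical: `cls_logits_to_doa` and the 45-degree grid are illustrative assumptions, not the repository's predgt2doa_cls, and it only handles one source per frame (two sources would need peak picking over the top classes).

```python
import numpy as np

def cls_logits_to_doa(logits, doa_grid_deg):
    """Hypothetical helper: map per-frame class scores to DOA estimates.

    logits: shape (num_frames, num_classes), one score per candidate direction.
    doa_grid_deg: shape (num_classes,), the direction in degrees of each class.
    Returns shape (num_frames,) with the DOA of the highest-scoring class.
    """
    logits = np.asarray(logits)
    # argmax gives the same class whether or not softmax was applied to the scores
    best_class = np.argmax(logits, axis=-1)
    return np.asarray(doa_grid_deg)[best_class]

# Usage: 3 frames, assumed candidate directions every 45 degrees from 0 to 180
grid = np.arange(0, 181, 45)  # [0, 45, 90, 135, 180]
scores = np.array([
    [0.1, 2.0, 0.3, 0.0, -1.0],
    [0.0, 0.1, 0.2, 3.0, 0.5],
    [5.0, 0.0, 0.0, 0.0, 0.0],
])
print(cls_logits_to_doa(scores, grid))  # DOAs of 45, 135 and 0 degrees
```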
Oh, thank you very much for your reply. It is an honor to see your newly revised code.
Thank you for recognizing our work,
Q1: When you change the target to classification, you should change the loss function to ce_loss. The ce_loss function uses torch.nn.functional.cross_entropy, which already applies the (log-)softmax internally, so there is no need to apply softmax in the model.
Q2: As you say, you need to switch to predgt2doa_cls. Meanwhile, you can send me the detailed error information via an issue or email, and I'll use it to track down the problem.
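To see why the model should output raw logits, the NumPy sketch below re-implements the formula that torch.nn.functional.cross_entropy computes with its default mean reduction (log_softmax followed by negative log-likelihood). It is an illustration, not the repository's ce_loss. Feeding already-softmaxed scores into it applies softmax twice: the inputs are squashed into [0, 1], so the loss can no longer approach zero and its gradients shrink.

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Numerically stable log-softmax: subtract the max before exponentiating
    x = np.asarray(x, dtype=np.float64)
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def softmax(x, axis=-1):
    return np.exp(log_softmax(x, axis=axis))

def cross_entropy(logits, target):
    # Mean over the batch of -log_softmax(logits)[target]
    logp = log_softmax(logits, axis=-1)
    return float(-logp[np.arange(len(target)), target].mean())

logits = np.array([[4.0, 0.0, 0.0], [0.0, 3.0, 0.0]])
target = np.array([0, 1])

loss_raw = cross_entropy(logits, target)               # model outputs raw logits
loss_double = cross_entropy(softmax(logits), target)   # softmax in model AND in loss
print(loss_raw, loss_double)  # the double-softmax loss is much larger
```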
Thank you very much for your reply. I will generate the single-source data first; yesterday I generated two-source data because I wanted to see whether this code runs on two-source data. After I generate the data, I will run it immediately and let you know.
OK, I'll check the code with your information!
Hi, here is my error.
It seems that pred is a list, not a dict:
Traceback (most recent call last):
File "D:\code\FN-SSL-main\Train.py", line 103, in
loss_val, metric_val = learner.test_epoch(dataloader_val, return_metric=True)
File "D:\code\FN-SSL-main\Learner.py", line 176, in test_epoch
metric_batch = self.evaluate(pred=pred_batch, gt=gt_batch)
File "D:\code\FN-SSL-main\Learner.py", line 573, in evaluate
doa_pred = pred['doa'] * 180 / np.pi
TypeError: list indices must be integers or slices, not str
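The traceback shows the regression evaluate indexing pred['doa'], which only works on a dict; a classification model that returns a list fails at that line. The fix in this thread is to call evaluate_cls and predgt2doa_cls instead. As a hypothetical extra safeguard (`pick_evaluator` is not part of the repository), the learner could choose the evaluator from the prediction's type and fail with a clearer message.

```python
def pick_evaluator(pred, evaluate_reg, evaluate_cls):
    """Hypothetical guard: choose the evaluation routine from the prediction's type.

    A regression evaluator that does pred['doa'] needs a dict; indexing a list
    with a string raises "list indices must be integers or slices, not str".
    """
    if isinstance(pred, dict):
        return evaluate_reg
    if isinstance(pred, (list, tuple)):
        return evaluate_cls
    raise TypeError(f"unexpected prediction type: {type(pred).__name__}")

# Reproduce the error from the traceback
try:
    [0.1, 0.2]['doa']
except TypeError as err:
    print(err)

# Dispatch on the prediction type (strings stand in for the real functions)
print(pick_evaluator({'doa': 0.0}, 'evaluate', 'evaluate_cls'))   # evaluate
print(pick_evaluator([0.1, 0.2], 'evaluate', 'evaluate_cls'))     # evaluate_cls
```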
Sorry, while reading Train.py, I got confused about this code:
win_len = 512
nfft = 512
# win_shift_ratio = 0.5
win_shift_ratio = 160/512
fre_used_ratio = 1
seg_fra_ratio = 1 # one estimate per segment (namely seg_fra_ratio frames)
# seg_fra_ratio = 12 # one estimate per segment (namely seg_fra_ratio frames)
seg_len = int(win_len*win_shift_ratio*(seg_fra_ratio+1))
seg_shift = int(win_len*win_shift_ratio*seg_fra_ratio)
- Because my win_len = 512 and win_shift = 160, I changed win_shift_ratio to 160/512. Is this OK?
- For real-time use: the model pools along the time dimension, so it seems it would not be real time. I set seg_fra_ratio = 1 because I want a DOA estimate for every frame; then the model does no mean pooling and keeps the same size along the time dimension. With int(win_len*win_shift_ratio*(seg_fra_ratio+1)), seg_len becomes 320. Should this value be equal to win_len?
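The two segment expressions from Train.py can be evaluated directly to see where 320 comes from. The small wrapper `segment_params` is only for illustration. Note that with the commented-out default win_shift_ratio = 0.5 and seg_fra_ratio = 1, seg_len comes out equal to win_len (512), whereas the 160-sample hop gives 320.

```python
def segment_params(win_len, win_shift_ratio, seg_fra_ratio):
    # Same expressions as seg_len and seg_shift in Train.py
    seg_len = int(win_len * win_shift_ratio * (seg_fra_ratio + 1))
    seg_shift = int(win_len * win_shift_ratio * seg_fra_ratio)
    return seg_len, seg_shift

print(segment_params(512, 160 / 512, 1))   # (320, 160): the user's config
print(segment_params(512, 160 / 512, 12))  # (2080, 1920)
print(segment_params(512, 0.5, 1))         # (512, 256): default ratio, seg_len == win_len
```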
Hello,
For Q1, you can set it like this. See Module.stft and Dataset.Segmenting_SRPDNN for more detail. I think setting int(win_len*win_shift_ratio*(seg_fra_ratio+1)) to win_len in your current configuration is OK.
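One way to see why seg_len should not drop below win_len: assuming a non-centered STFT without padding (Module.stft may pad differently, for example torch.stft pads when center=True), a segment of seg_len samples yields 1 + (seg_len - win_len) // hop frames, and none at all if it is shorter than one window. The helper `num_stft_frames` is hypothetical.

```python
def num_stft_frames(seg_len, win_len, hop):
    # Assumes a non-centered STFT with no padding; centered STFTs pad the
    # signal and therefore produce more frames than this count.
    if seg_len < win_len:
        return 0
    return 1 + (seg_len - win_len) // hop

print(num_stft_frames(320, 512, 160))   # 0: a 320-sample segment cannot hold a 512 window
print(num_stft_frames(512, 512, 160))   # 1: seg_len == win_len gives one frame per segment
print(num_stft_frames(3328, 512, 256))  # 12: default ratio 0.5 with seg_fra_ratio = 12
```

Under this assumption, the default configuration (win_shift_ratio = 0.5, seg_fra_ratio = 12) yields exactly 12 STFT frames per segment, and one frame per segment with a 160-sample hop requires seg_len = win_len.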
For Q2, in my opinion, pooling along the time dimension does not affect the causality of the method. In practice, 12 STFT frames map to 1 DOA frame. From the perspective of the DOA frames, the current DOA frame does not use information from future DOA frames, so the method is theoretically online. You can also measure the real-time performance of the algorithm by calculating the real-time factor (RTF).
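A minimal sketch of both points, assuming non-overlapping mean pooling (the model's actual pooling layer may differ): DOA frame k averages only STFT frames k*n to (k+1)*n - 1, so later frames never change earlier DOA outputs, although each estimate is only available once its last frame has arrived (a latency of n hops). The RTF helper uses the usual definition, processing time divided by audio duration. `segment_mean_pool`, `real_time_factor` and the 16 kHz rate are illustrative assumptions.

```python
import time
import numpy as np

def segment_mean_pool(frame_feats, frames_per_doa):
    """Non-overlapping mean pooling along time: (T, F) -> (T // n, F).

    Trailing frames that do not fill a whole segment are dropped.
    """
    frame_feats = np.asarray(frame_feats, dtype=float)
    n = frames_per_doa
    num_doa = frame_feats.shape[0] // n
    trimmed = frame_feats[: num_doa * n]
    return trimmed.reshape(num_doa, n, -1).mean(axis=1)

def real_time_factor(process_fn, audio, fs):
    """RTF = processing time / audio duration; RTF < 1 is faster than real time."""
    start = time.perf_counter()
    process_fn(audio)
    elapsed = time.perf_counter() - start
    return elapsed / (len(audio) / fs)

feats = np.arange(24 * 2).reshape(24, 2)   # 24 STFT frames, 2 features each
pooled = segment_mean_pool(feats, 12)      # 2 DOA frames

# Changing "future" frames 12..23 leaves the first DOA frame untouched
future_changed = feats.astype(float)
future_changed[12:] = 999.0
assert np.allclose(segment_mean_pool(future_changed, 12)[0], pooled[0])

rtf = real_time_factor(np.fft.rfft, np.zeros(16000), fs=16000)
print(pooled.shape, rtf)
```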
Hope it can help you.
Thanks, I will try as you advise.