fyancy / metafd Goto Github PK
View Code? Open in Web Editor NEWThe source codes of Meta-learning for few-shot cross-domain fault diagnosis.
License: MIT License
The source codes of Meta-learning for few-shot cross-domain fault diagnosis.
License: MIT License
The reported accuracy in our paper and README file in github repository is 71.8%.
We post the evaluation process in detail as follows. (2022/4/6, by Yong Feng)
Test? y/n
y
Load Model successfully from [G:\model_save\meta_learning\CNN\cnn_ft\5shot\cnn_ft_C30_ep50]...
x shape: (10, 200, 1, 1024), y shape: (10, 200)
Acc: 0.7180, Loss: 1.2492
*** Testing time: 1.2068 (s) ***
Load Model successfully from [G:\model_save\meta_learning\CNN\cnn_ft\5shot\cnn_ft_C30_ep72]...
4-ways, 5-shots for testing ...
x shape: (4, 5, 1, 1024), y shape: (4, 5)
x shape: (4, 200, 1, 1024), y shape: (4, 200)
Acc: 0.8238, Loss: 0.5628
*** Testing time: 3.6451 (s) ***
Load Model successfully from [G:\model_save\meta_learning\CNN\cnn_ft\5shot\cnn_ft_C30_ep72]...
4-ways, 1-shots for testing ...
x shape: (4, 1, 1, 1024), y shape: (4, 1)
x shape: (4, 200, 1, 1024), y shape: (4, 200)
Acc: 0.5950, Loss: 1.0441
*** Testing time: 3.6962 (s) ***
For better validation, we post All model weights here. We stated this in README file.
使用训练集训练的时候报错
File "E:\GeekTools\anaconda\lib\site-packages\keras\utils\traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "<__array_function__ internals>", line 5, in transpose
File "E:\GeekTools\anaconda\lib\site-packages\numpy\core\fromnumeric.py", line 660, in transpose
return _wrapfunc(a, 'transpose', axes)
File "E:\GeekTools\anaconda\lib\site-packages\numpy\core\fromnumeric.py", line 57, in _wrapfunc
return bound(*args, **kwds)
ValueError: axes don't match array
用自己训练集保存的模型和百度网盘下载的模型都会报错
你好,谢谢你之前的回信,很有帮助,又有新的问题麻烦你.原型网络训练代码是这样
我有点不明白元学习的训练模式,想请教你.CNN的训练就是一个epoch中分多个batch,batch之间的数据互不相同.但是每个epoch的数据都是一样的,是重复的训练.到了元任务的模式,我看是分成epoch和eposide,其他论文也是这么写的,但是我看代码,他是根据一个个task来训练模型的, train_tasks.sample() 生成的样本是随机的.那么是否不单eposide之间的数据不一样,其实不同epoch训练数据也不相同?那么分成epoch和eposide有什么意义,就是单纯规定一个合适的周期来看训练的精度吗?
我看到在训练和测试时,遍历task是不同的,训练时是这样 task = train_tasks.sample() ,测试时
for i, batch in enumerate(train_tasks)
但是根据我的测试 train_tasks.sample() 可能会返回2个数据相同的task,后面那种遍历没有这个情况.同时如果2个task的数据完全相同.我的意思是如果task的data的tensor的数据完全一样,那么是否认为这2个task就是重复的,对提高模型精度就没有效果?还有就是task的重复的可能概率是如何估计,如果重复率过高是否对训练有负面影响?期盼你能简答我的疑惑,不谢感谢!
Q:
我对跨域的T2任务有些不理解,T1就是跨工况,病害类别没有变化,只是工况从3到0,这是常见的.但是T2比较少见,他是相同工况下,目标域和源域的病害(故障)类别完全不相交,让我困惑,能否详细说明一下
T1: 10 ways, load 3 ==> 10 ways, load 0
T2: 6 ways, load 0 ==> 4 ways, load 0
Tasks | Source categories | Target categories | Source load | Target load |
---|---|---|---|---|
T1 | {NC, IF1, IF2, IF3, ..., RoF3 } | {NC, IF1, IF2, IF3, ..., RoF3} | 3 | 0 |
T2 | {IF1, IF2, IF3, OF1, OF2, OF3} | {NC, RoF1, RoF2, RoF3 } | 0 | 0 |
A:
在我们的文章中,构造了两个故障诊断任务用以展示各个元学习模型的潜力。T1是一个常见的迁移任务,仅工况变化不同(负载变化)。T2有完全不同的源域和目标域类别,属于识别“新类别”的范畴。这应当是元学习故障诊断之后的探索方向,比如在轴承上训练、在齿轮上测试。这样,他们拥有不同的类别,但可以持续用于新任务。在文章中,跨度没这么大,仅仅使用了轴承数据,所以让类别不相交来模拟这一情景。在原文中我们也探讨了这类情景,比如 DCDB-DF (Seen in Sectiuon 4.1)
Furthermore,we clarify three concepts of SCDB,DCSB and DCDB in Fig.17.In this illustration,we only emphasize that the types of components for source and target domain are different in DCDB scenario, thus it is the most intractable.
元学习算法通过情景式训练(Episodic paradigm)学习任务之间的元知识,在新任务上迅速学习。至于元学习模型具体如何实现的,可以参考代码,原理可以参考我们的元学习综述文章或者一些方法的原文。
This may be caused by the Version incompatibility of learn2learn. Our codes support learn2learn with version >= 0.1.5
, and the original codes are:
tasks = l2l.data.TaskDataset(dataset, task_transforms=[
l2l.data.transforms.FusedNWaysKShots(dataset, new_ways, 2 * shots, filter_labels=filter_labels),
l2l.data.transforms.LoadData(dataset),
# l2l.data.transforms.RemapLabels(dataset, shuffle=label_shuffle_per_task),
l2l.data.transforms.RemapLabels(dataset, shuffle=True),
# do not keep the original labels, use (0 ,..., n-1);
# if shuffle=True, to shuffle labels at each task.
l2l.data.transforms.ConsecutiveLabels(dataset),
# re-order samples and make their original labels as (0 ,..., n-1).
], num_tasks=num_tasks)
FusedNWaysKShots
is just an Efficient implementation of FilterLabels
, NWays
, and KShots
. To avoid the above error, you can update your learn2learn package
(>=0.1.5) or directly use the codes as follows.
tasks = l2l.data.TaskDataset(dataset, task_transforms=[
l2l.data.transforms.NWays(dataset, new_ways),
l2l.data.transforms.KShots(dataset, 2 * shots),
l2l.data.transforms.FilterLabels(dataset, filter_labels),
l2l.data.transforms.LoadData(dataset),
l2l.data.transforms.RemapLabels(dataset),
l2l.data.transforms.ConsecutiveLabels(dataset),
], num_tasks=num_tasks)
For more details, you can refer to my blog: 初探元学习库learn2learn or the official documents of learn2learn.
GOOD LUCK to you guys.
A declarative, efficient, and flexible JavaScript library for building user interfaces.
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
An Open Source Machine Learning Framework for Everyone
The Web framework for perfectionists with deadlines.
A PHP framework for web artisans
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
Some thing interesting about web. New door for the world.
A server is a program made to process requests and deliver data to clients.
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
Some thing interesting about visualization, use data art
Some thing interesting about game, make everyone happy.
We are working to build community through open source technology. NB: members must have two-factor auth.
Open source projects and samples from Microsoft.
Google ❤️ Open Source for everyone.
Alibaba Open Source for everyone
Data-Driven Documents codes.
China tencent open source team.