Giter Site home page Giter Site logo

metafd's People

Contributors

fyancy avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar

metafd's Issues

Performance degradation of CNN

The reported accuracy in our paper and README file in github repository is 71.8%.
We post the evaluation process in detail as follows. (2022/4/6, by Yong Feng)

Test? y/n
y
Load Model successfully from [G:\model_save\meta_learning\CNN\cnn_ft\5shot\cnn_ft_C30_ep50]...
x shape: (10, 200, 1, 1024), y shape: (10, 200)
Acc: 0.7180, Loss: 1.2492
*** Testing time:  1.2068 (s) ***

Load Model successfully from [G:\model_save\meta_learning\CNN\cnn_ft\5shot\cnn_ft_C30_ep72]...
4-ways, 5-shots for testing ...
x shape: (4, 5, 1, 1024), y shape: (4, 5)
x shape: (4, 200, 1, 1024), y shape: (4, 200)
Acc: 0.8238, Loss: 0.5628
*** Testing time:  3.6451 (s) ***

Load Model successfully from [G:\model_save\meta_learning\CNN\cnn_ft\5shot\cnn_ft_C30_ep72]...
4-ways, 1-shots for testing ...
x shape: (4, 1, 1, 1024), y shape: (4, 1)
x shape: (4, 200, 1, 1024), y shape: (4, 200)
Acc: 0.5950, Loss: 1.0441
*** Testing time:  3.6962 (s) ***

For better validation, we post All model weights here. We stated this in README file.

使用测试集test的时候所有模型都会报错-权重张量的维度不匹配

使用训练集训练的时候报错

 File "E:\GeekTools\anaconda\lib\site-packages\keras\utils\traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "<__array_function__ internals>", line 5, in transpose
  File "E:\GeekTools\anaconda\lib\site-packages\numpy\core\fromnumeric.py", line 660, in transpose
    return _wrapfunc(a, 'transpose', axes)
  File "E:\GeekTools\anaconda\lib\site-packages\numpy\core\fromnumeric.py", line 57, in _wrapfunc
    return bound(*args, **kwds)
ValueError: axes don't match array

用自己训练集保存的模型和百度网盘下载的模型都会报错

关于元任务训练模式和task重复

   你好,谢谢你之前的回信,很有帮助,又有新的问题麻烦你.原型网络训练代码是这样

image
我有点不明白元学习的训练模式,想请教你.CNN的训练就是一个epoch中分多个batch,batch之间的数据互不相同.但是每个epoch的数据都是一样的,是重复的训练.到了元任务的模式,我看是分成epoch和eposide,其他论文也是这么写的,但是我看代码,他是根据一个个task来训练模型的, train_tasks.sample() 生成的样本是随机的.那么是否不单eposide之间的数据不一样,其实不同epoch训练数据也不相同?那么分成epoch和eposide有什么意义,就是单纯规定一个合适的周期来看训练的精度吗?
我看到在训练和测试时,遍历task是不同的,训练时是这样 task = train_tasks.sample() ,测试时
for i, batch in enumerate(train_tasks)
但是根据我的测试 train_tasks.sample() 可能会返回2个数据相同的task,后面那种遍历没有这个情况.同时如果2个task的数据完全相同.我的意思是如果task的data的tensor的数据完全一样,那么是否认为这2个task就是重复的,对提高模型精度就没有效果?还有就是task的重复的可能概率是如何估计,如果重复率过高是否对训练有负面影响?期盼你能简答我的疑惑,不谢感谢!

How to understand the cross-domain task T2?

Q:
我对跨域的T2任务有些不理解,T1就是跨工况,病害类别没有变化,只是工况从3到0,这是常见的.但是T2比较少见,他是相同工况下,目标域和源域的病害(故障)类别完全不相交,让我困惑,能否详细说明一下

T1: 10 ways, load 3 ==> 10 ways, load 0  
T2: 6 ways, load 0 ==> 4 ways, load 0  
Tasks Source categories Target categories Source load Target load
T1 {NC, IF1, IF2, IF3, ..., RoF3 } {NC, IF1, IF2, IF3, ..., RoF3} 3 0
T2 {IF1, IF2, IF3, OF1, OF2, OF3} {NC, RoF1, RoF2, RoF3 } 0 0

A:
在我们的文章中,构造了两个故障诊断任务用以展示各个元学习模型的潜力。T1是一个常见的迁移任务,仅工况变化不同(负载变化)。T2有完全不同的源域和目标域类别,属于识别“新类别”的范畴。这应当是元学习故障诊断之后的探索方向,比如在轴承上训练、在齿轮上测试。这样,他们拥有不同的类别,但可以持续用于新任务。在文章中,跨度没这么大,仅仅使用了轴承数据,所以让类别不相交来模拟这一情景。在原文中我们也探讨了这类情景,比如 DCDB-DF (Seen in Sectiuon 4.1)

Furthermore,we clarify three concepts of SCDB,DCSB and DCDB in Fig.17.In this illustration,we only emphasize that the types of components for source and target domain are different in DCDB scenario, thus it is the most intractable.

元学习算法通过情景式训练(Episodic paradigm)学习任务之间的元知识,在新任务上迅速学习。至于元学习模型具体如何实现的,可以参考代码,原理可以参考我们的元学习综述文章或者一些方法的原文。

module 'learn2learn has no attribute FusedNWayKShots'

This may be caused by the Version incompatibility of learn2learn. Our codes support learn2learn with version >= 0.1.5, and the original codes are:

tasks = l2l.data.TaskDataset(dataset, task_transforms=[
            l2l.data.transforms.FusedNWaysKShots(dataset, new_ways, 2 * shots, filter_labels=filter_labels),
            l2l.data.transforms.LoadData(dataset),
            # l2l.data.transforms.RemapLabels(dataset, shuffle=label_shuffle_per_task),
            l2l.data.transforms.RemapLabels(dataset, shuffle=True),
            # do not keep the original labels, use (0 ,..., n-1);
            # if shuffle=True, to shuffle labels at each task.
            l2l.data.transforms.ConsecutiveLabels(dataset),
            # re-order samples and make their original labels as (0 ,..., n-1).
        ], num_tasks=num_tasks)

FusedNWaysKShots is just an Efficient implementation of FilterLabels, NWays, and KShots. To avoid the above error, you can update your learn2learn package (>=0.1.5) or directly use the codes as follows.

tasks = l2l.data.TaskDataset(dataset, task_transforms=[
        l2l.data.transforms.NWays(dataset, new_ways),
        l2l.data.transforms.KShots(dataset, 2 * shots),
        l2l.data.transforms.FilterLabels(dataset, filter_labels),
        l2l.data.transforms.LoadData(dataset),
        l2l.data.transforms.RemapLabels(dataset),
        l2l.data.transforms.ConsecutiveLabels(dataset),
    ], num_tasks=num_tasks)

For more details, you can refer to my blog: 初探元学习库learn2learn or the official documents of learn2learn.

GOOD LUCK to you guys.

Relation Network 无法复现

你好,通过 README 的结果图可以看出 Relation Network 在 5-shot 下表现良好,准确率达到了95%以上,然而我拉取了最新的代码只是修改了必须修改的路径后运行,准确率只有 76% 如图,请问是还需要配置哪里或者默认参数需要调整?

Screen Shot 2022-03-26 at 21 10 56

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.