Giter Site home page Giter Site logo

hctian713 / multispectral-rsimg-classification Goto Github PK

View Code? Open in Web Editor NEW
9.0 1.0 0.0 4.44 MB

【武汉大学遥感学院】空间智能感知与服务课设 | 基于Softmax的多波段遥感影像分类

Jupyter Notebook 100.00%
cv machine-learning multispectral-images softmax-classification

multispectral-rsimg-classification's Introduction

MultiSpectral-RSImg-Classification

【武汉大学遥感学院】空间智能感知与服务课设 | 基于Softmax的多波段遥感影像分类

1. 数据描述

原始数据FINAL.tif为tiff格式的多波段影像,共21个波段,其中第18个为Output波段即分类结果。经处理发现,共4种污染程度即4个类别,每个类别像素数量之间差距很大。具体数量如后所示class_num:{1: 1092, 2: 6364, 3: 21599, 4: 112139}

数据存在大量Nan Inf干扰项,对实际预测判断没有用处,需进行去除。

波段1/2/19/18(Output)可视化结果

2. 实验条件

Pytorch Osgeo.gdal sklearn seaborn

3. 网络模型

由于数据结构较为简单,因此不引用任何开源网络架构,独立设计了一个简单的Softmax神经网络分割模型,经过多次失败的测试,最终设计得到的网络结构如下图所示,下面详细解释网络结构和设计的思路:

  • 输入层:采用考虑空间信息的卷积,则会产生大量的权重参数,但考虑到数据集数量较少,则很容易导致过拟合,因此只考率像素的波段信息,则输入层为20维张量。
  • 隐层:为使得网络足够复杂以能够表达关系信息,共设置结点数分别为40/25/10的3层隐层,激活函数分别为Relu/Relu/Sigmoid,前两个Relu可以起到增加训练效率的作用。
  • 输出层和损失函数:为像素级分类,即图像分割,输出层Softmax结构实现多分类,实际采用log_softmax,损失函数为负对数似然损失函数NLLLoss。公式如下: $$log\_ softmax = \frac{e^{xi}}{\sum_{i}e^{xi}}~~~~NLLLoss = - \frac{1}{N}{\sum\limits_{k = 1}^{N}{y_{k}(log\_ softmax)}}$$
  • 训练优化方法:采用动态梯度下降法momentum,将一段时间内的梯度向量进行了加权平均,一定程度上消除了更新过程中的不确定性因素(如摆动现象),增加训练效率。

4.关键步骤

4.1 数据读取与清洗

原始数据tiff无法用OpenCV或者PIL读取,使用GDAL读取原始数据转化为ndarray,并按照波段reshape为(20,141194)数据。利用np.isinfnp.isnan进行数据清洗,将无效数据替换为0或1。

4.2 样本划分与数据增强

数据样本数量少且类别非常不均衡,因此采用简单的9:1随机进行训练集和验证集的划分。

将训练集验证集的数据和标签转换为pytorch.tensor格式,同时注意数据必须为float32,标签必须为long,否则无法进行模型训练。

4.3 网络模型搭建

  • 定义SoftMax网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1=nn.Linear(input_n,40)
        self.l2=nn.Linear(40,25)
        self.l3=nn.Linear(25,10)
        self.l4=nn.Linear(10,output_n)

    def forward(self,x):
        a1=F.relu(self.l1(x))
        a2=F.relu(self.l2(a1))
        a3=F.sigmoid(self.l3(a2))
        output=F.log_softmax(self.l4(a3), dim=1)
        return output
  • 优化器与损失函数
#优化器随机梯度下降 momentum动态梯度下降
optimizer = torch.optim.SGD(model.parameters(), lr=lr,momentum=0.9)
#交叉熵损失
loss=nn.NLLLoss()

4.4 训练与验证

训练超参数设置如下,混淆矩阵可视化验证。

epochs=100#训练次数
lr=0.001#学习率
batch_size=256#批次大小
iteration=train_data.shape[0]//batch_size

5.结果评价

由于训练数据进行了数据增强,因此合理的直接采用全部原始数据进行验证。 最终计算得出的精度为92.768%

损失曲线和混淆矩阵

图像复原

原始OutPut 预测OutPut

multispectral-rsimg-classification's People

Contributors

hctian713 avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.