Using ConvLSTM

2024/7/24 13:18:47  Tags: python, deep learning, lstm, numpy

[Figure: simple RNN vs. LSTM]
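This sketch makes the RNN vs. LSTM comparison concrete. It uses PyTorch's built-in fully connected `nn.RNNCell` and `nn.LSTMCell`, which are not part of the original post. A simple RNN carries one hidden state. An LSTM carries a pair (hidden state, cell state) and learns four gates, so it has four times as many weights. The sizes are arbitrary and chosen only for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

batch, input_size, hidden_size = 2, 3, 4   # arbitrary sizes for illustration
x = torch.randn(batch, input_size)

# Simple RNN: a single state tensor h, updated as h' = tanh(W_ih x + b_ih + W_hh h + b_hh)
rnn = nn.RNNCell(input_size, hidden_size)
h = torch.zeros(batch, hidden_size)
h_next = rnn(x, h)

# LSTM: two state tensors (h, c); gates control what the cell state c keeps
lstm = nn.LSTMCell(input_size, hidden_size)
state = (torch.zeros(batch, hidden_size), torch.zeros(batch, hidden_size))
h_next_lstm, c_next = lstm(x, state)

print(h_next.shape)                     # torch.Size([2, 4])
print(h_next_lstm.shape, c_next.shape)  # torch.Size([2, 4]) torch.Size([2, 4])

# Parameter count: the LSTM stacks four gates, so it has 4x the weights
print(sum(p.numel() for p in rnn.parameters()),
      sum(p.numel() for p in lstm.parameters()))  # 36 144
```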

[Figure: LSTM computation]

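The `ConvLSTMCell` code below computes the update written out here. The notation is as follows:
- $*$ is 2D convolution (the `Gates` layer).
- $\odot$ is element-wise multiplication.
- $\sigma$ is the sigmoid.
- $[\,\cdot\,,\,\cdot\,]$ is concatenation along the channel dimension.

These equations are read off from the code. The code omits the peephole terms that some ConvLSTM formulations include.

```latex
\begin{aligned}
i_t &= \sigma\left(W_i * [X_t, H_{t-1}] + b_i\right) \\
f_t &= \sigma\left(W_f * [X_t, H_{t-1}] + b_f\right) \\
o_t &= \sigma\left(W_o * [X_t, H_{t-1}] + b_o\right) \\
g_t &= \tanh\left(W_g * [X_t, H_{t-1}] + b_g\right) \\
C_t &= f_t \odot C_{t-1} + i_t \odot g_t \\
H_t &= o_t \odot \tanh(C_t)
\end{aligned}
```

In the code, these symbols map to variables as follows:
- $i_t$, $f_t$, $o_t$ and $g_t$ are `in_gate`, `remember_gate`, `out_gate` and `cell_gate`.
- $C_t$ and $H_t$ are `cell` and `hidden`.
- $W_i, W_f, W_o, W_g$ are four channel slices of the single `Gates` convolution.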

import torch
from torch import nn


# Define some constants
KERNEL_SIZE = 3
PADDING = KERNEL_SIZE // 2  # "same" padding keeps height and width unchanged


class ConvLSTMCell(nn.Module):
    """
    Generate a convolutional LSTM cell
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # one convolution computes all four gates at once from [input, hidden]
        self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, KERNEL_SIZE, padding=PADDING)

    def forward(self, input_, prev_state):

        # get batch and spatial sizes
        batch_size = input_.size(0)
        spatial_size = input_.size()[2:]

        # generate empty prev_state, if None is provided
        if prev_state is None:
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            # create the zero state on the same device and with the same dtype as the input
            prev_state = (
                torch.zeros(state_size, dtype=input_.dtype, device=input_.device),
                torch.zeros(state_size, dtype=input_.dtype, device=input_.device)
            )

        prev_hidden, prev_cell = prev_state

        # data size is [batch, channel, height, width]
        stacked_inputs = torch.cat((input_, prev_hidden), 1)
        gates = self.Gates(stacked_inputs)

        # chunk across channel dimension
        in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)

        # apply sigmoid non linearity
        in_gate = torch.sigmoid(in_gate)
        remember_gate = torch.sigmoid(remember_gate)
        out_gate = torch.sigmoid(out_gate)

        # apply tanh non linearity
        cell_gate = torch.tanh(cell_gate)

        # compute current cell and hidden state
        cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
        hidden = out_gate * torch.tanh(cell)

        return hidden, cell


def _main():
    """
    Run some basic tests on the API
    """

    # define batch_size, channels, height, width
    b, c, h, w = 1, 3, 4, 8
    d = 5           # hidden state size
    lr = 1e-1       # learning rate
    T = 6           # sequence length
    max_epoch = 20  # number of epochs

    # set manual seed
    torch.manual_seed(0)

    print('Instantiate model')
    model = ConvLSTMCell(c, d)
    print(repr(model))

    print('Create input and target tensors')
    x = torch.rand(T, b, c, h, w)
    y = torch.randn(T, b, d, h, w)

    print('Create a MSE criterion')
    loss_fn = nn.MSELoss()

    print('Run for', max_epoch, 'iterations')
    for epoch in range(max_epoch):
        state = None
        loss = 0
        for t in range(T):
            state = model(x[t], state)
            loss += loss_fn(state[0], y[t])

        # loss is a 0-dim tensor, so read it with .item() (indexing with [0] fails)
        print(' > Epoch {:2d} loss: {:.3f}'.format(epoch + 1, loss.item()))

        # zero grad parameters
        model.zero_grad()

        # compute new grad parameters through time!
        loss.backward()

        # learning_rate step against the gradient (plain SGD, done by hand)
        with torch.no_grad():
            for p in model.parameters():
                p -= lr * p.grad

    print('Input size:', list(x.size()))
    print('Target size:', list(y.size()))
    print('Last hidden state size:', list(state[0].size()))


if __name__ == '__main__':
    _main()
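The core trick in `forward` is a three-step gate computation:
1. Concatenate the input and the previous hidden state along the channel dimension.
2. Run one convolution that produces all four gates at once.
3. Split the result with `chunk`.

This sketch traces the tensor shapes using the same sizes as `_main` (b=1, c=3, h=4, w=8, d=5). It builds a standalone `nn.Conv2d`, so it runs without the class above.

```python
import torch
from torch import nn

b, c, h, w = 1, 3, 4, 8   # batch, input channels, height, width (as in _main)
d = 5                     # hidden channels

x_t = torch.rand(b, c, h, w)
h_prev = torch.zeros(b, d, h, w)

# 1. Stack input and previous hidden state along the channel dimension (dim 1)
stacked = torch.cat((x_t, h_prev), 1)
print(stacked.shape)   # torch.Size([1, 8, 4, 8])

# 2. One 3x3 convolution with padding 1 yields 4*d channels; height and width are unchanged
gates_conv = nn.Conv2d(c + d, 4 * d, 3, padding=1)
gates = gates_conv(stacked)
print(gates.shape)     # torch.Size([1, 20, 4, 8])

# 3. Split the 4*d channels into four d-channel gate tensors
in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)
print([g.shape[1] for g in (in_gate, remember_gate, out_gate, cell_gate)])  # [5, 5, 5, 5]
```

`chunk` only slices channels. Which slice serves as the input, forget, output or candidate gate is a naming convention fixed by the order of the unpacking. Training then shapes each slice to that role.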

 


http://www.niftyadmin.cn/n/957522.html
