
Comments (4)

sharvil commented on May 27, 2024

Thanks @w1d2s! This is an excellent issue. I'll take a look.


w1d2s commented on May 27, 2024

I ran several experiments and found that the discrepancy may be caused by floating-point error:

  1. Weights initialized in (-1, 1), float32: the RNN outputs after several time steps are distributed around zero, so even though the maximum absolute error is only about 1e-3 or 1e-4, the relative error is large.
  2. Weights initialized in (1, 2), float32: the RNN outputs after several time steps are larger than 1; the absolute error is 1e-3 or 1e-4, but the relative error is 1e-6 or 1e-7, which seems acceptable.
  3. Weights initialized in (-1, 1), float64: the large relative error appears at a later time step than with float32, but after 50 time steps it still grows large.
  4. Weights initialized in (1, 2), float64: after 50 time steps the relative error is still 0.0, which I think means the implementation is correct.
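The effect described in cases 1 and 2 can be reproduced without an RNN at all. In this sketch (hypothetical data, not taken from the experiments above), the same absolute perturbation of 1e-4 is added to values clustered near zero and to values in [1, 2], and the pointwise relative error is computed as |a - b| / |b| with a small epsilon, the same shape of check as the `cal_err_pointwise` helper later in the thread.

```python
import numpy as np

def max_relative_error(ref, other, eps=1e-10):
    # Pointwise |other - ref| / |ref|; eps guards against division by zero.
    return np.max(np.abs(other - ref) / (np.abs(ref) + eps))

rng = np.random.default_rng(0)
near_zero = rng.uniform(-1e-3, 1e-3, size=1000)  # outputs clustered around 0
large = rng.uniform(1.0, 2.0, size=1000)         # outputs larger than 1
abs_err = 1e-4                                   # identical absolute error for both

rel_near = max_relative_error(near_zero, near_zero + abs_err)
rel_large = max_relative_error(large, large + abs_err)
print(f"near zero: {rel_near:.3g}")   # at least about 0.1, usually far larger
print(f"in [1, 2]: {rel_large:.3g}")  # at most about 1e-4
```

The absolute error is identical in both cases; only the size of the denominator changes, which is why outputs near zero make a maximum relative error look alarming.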

So is there any way to deal with this floating-point error?


sharvil commented on May 27, 2024

Thanks so much for running those experiments!

Maximum relative error doesn't seem like a good measure here. The CPU and GPU implementations of standard operations (e.g. matrix multiply) don't produce bitwise identical results. There's really no avoiding loss of precision for many such commonly-used operations. Here's a sample script that shows that the maximum relative error can grow to large values even with standard PyTorch operations:

import numpy as np
import torch
import torch.nn as nn


batch_size = 32
time_steps = 250
input_size = 128
hidden_size = 256


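def IndRNNScript(input, h0, kernel, recurrent_scale, bias):
  # IndRNN built from standard PyTorch ops:
  # h_t = tanh(x_t @ kernel + bias + recurrent_scale * h_{t-1})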
  time_steps = input.shape[0]
  h = [h0]
  Wx = input @ kernel + bias
  for t in range(time_steps):
    h.append(torch.tanh(Wx[t] + h[-1] * recurrent_scale))
  h = torch.stack(h)
  return h


def cal_err_pointwise(a, b):
  # Elementwise relative error |a - b| / |b|; the 1e-10 avoids division by zero.
  if a is None or b is None:
    return None
  a = a.cpu().detach().numpy()
  b = b.cpu().detach().numpy()
  if np.all(np.equal(a, b)):
    # Return a pair so callers can always unpack (err_mean, err_max).
    return 0.0, 0.0
  diff = a - b
  denom = np.abs(b) + np.ones_like(b) * 1e-10
  ratio = np.abs(diff) / denom
  err_mean = np.mean(ratio)
  err_max = np.max(ratio)
  return err_mean, err_max


kernel = torch.empty(input_size, hidden_size)
recurrent_scale = torch.empty(hidden_size)
bias = torch.empty(hidden_size)

worst_err_max = 0
for _ in range(100):
  nn.init.xavier_uniform_(kernel)
  nn.init.uniform_(recurrent_scale, -0.5, 0.5)
  nn.init.zeros_(bias)

  x = torch.rand(time_steps, batch_size, input_size)
  h0 = torch.zeros(1, batch_size, hidden_size)

  y_cpu = IndRNNScript(x, h0, kernel, recurrent_scale, bias)
  y_gpu = IndRNNScript(x.cuda(), h0.cuda(), kernel.cuda(), recurrent_scale.cuda(), bias.cuda())

  err_mean, err_max = cal_err_pointwise(y_cpu, y_gpu.cpu())
  worst_err_max = max(err_max, worst_err_max)
  print(err_mean, err_max)
print(f'Largest maximum relative error: {worst_err_max}')

Here are the last few lines of output on my machine:

1.0851818e-06 0.12782718
9.656759e-07 0.14665353
1.1386784e-06 0.49666694
3.2556713e-06 4.9337797
Largest maximum relative error: 111.75870513916016

In fact, you could even return Wx from IndRNNScript (performing only a single matrix multiply) and see that the maximum relative error is huge.
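One reason a single matrix multiply already differs between devices is that floating-point addition is not associative: kernels that accumulate the same products in a different order (as CPU and GPU matmul implementations may) can round differently. In this minimal NumPy sketch, a sequential loop and a recursive pairwise sum are hypothetical stand-ins for two kernels' reduction orders; neither is claimed to be what PyTorch actually does internally.

```python
import numpy as np

# Addition is not associative even for three float64 values.
a, b, c = 0.1, 0.2, 0.3
left = (a + b) + c
right = a + (b + c)
print(left, right, left == right)  # 0.6000000000000001 0.6 False

# The same dot product in float32, reduced in two different orders.
rng = np.random.default_rng(0)
x = rng.standard_normal(4096).astype(np.float32)
y = rng.standard_normal(4096).astype(np.float32)
prod = x * y

# Order 1: sequential accumulation, one product at a time.
seq = np.float32(0.0)
for p in prod:
    seq = np.float32(seq + p)

def pairwise(v):
    # Order 2: recursively sum each half, then add the two partial sums
    # (a tree reduction, loosely like a parallel kernel might use).
    if len(v) == 1:
        return v[0]
    mid = len(v) // 2
    return np.float32(pairwise(v[:mid]) + pairwise(v[mid:]))

tree = pairwise(prod)
print(seq, tree, abs(float(seq) - float(tree)))
```

Both results are valid float32 approximations of the same sum; they typically agree to roughly float32 precision but are not guaranteed to be bitwise equal.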

Are you running into an issue in practice with this loss of precision?
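Since pointwise maximum relative error is dominated by elements near zero, a common alternative (sketched here as an assumption, not something Haste or this thread prescribes) is to combine an absolute and a relative tolerance, using the criterion behind `numpy.allclose` (|a - b| <= atol + rtol * |b|, elementwise), and to look at a normwise relative error ||a - b|| / ||b||. The helper name `compare`, the toy arrays, and the tolerance values are all hypothetical.

```python
import numpy as np

def compare(actual, expected, rtol=1e-5, atol=1e-6):
    """Report several error measures between two arrays.

    The rtol/atol defaults are illustrative, not recommended values.
    """
    actual = np.asarray(actual, dtype=np.float64)
    expected = np.asarray(expected, dtype=np.float64)
    diff = np.abs(actual - expected)
    pointwise_rel = diff / (np.abs(expected) + 1e-10)
    normwise_rel = np.linalg.norm(actual - expected) / np.linalg.norm(expected)
    # Elementwise |a - b| <= atol + rtol * |b|, as in np.allclose.
    within_tol = np.allclose(actual, expected, rtol=rtol, atol=atol)
    return {
        "max_abs": diff.max(),
        "max_pointwise_rel": pointwise_rel.max(),
        "normwise_rel": normwise_rel,
        "allclose": within_tol,
    }

# One reference element is essentially zero; the perturbation is tiny everywhere.
expected = np.array([0.5, -0.25, 1e-9, 0.75])
actual = expected + np.array([1e-7, -1e-7, 1e-7, 1e-7])
report = compare(actual, expected)
print(report)
```

Here the maximum pointwise relative error is large (around 90, from the near-zero element), while the normwise relative error stays around 2e-7 and the tolerance check passes, which matches the intuition that the two results are practically the same.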


w1d2s commented on May 27, 2024

I think you are right. I'm implementing an LSTM with projection (LSTMP) based on your LSTM implementation, and before running a real training task I wanted to compare the forward and backward results of my LSTMP against my existing LSTMP built from PyTorch ops. Now I'm confident the implementation is correct, so I can move on to a real training task and check the performance on the test sets.
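A complementary check for a hand-written backward pass, one that does not depend on a second implementation, is to compare it against central finite differences in float64. The sketch below uses the IndRNN recurrence from the earlier script rather than LSTMP, and the loss, shapes, and helper names (`forward`, `loss`, `grad_u`, `numeric_grad_u`) are hypothetical; it is a NumPy illustration, not Haste's API.

```python
import numpy as np

def forward(wx, u, h0):
    """IndRNN-style recurrence: h_t = tanh(wx_t + u * h_{t-1})."""
    hs = [h0]
    for t in range(wx.shape[0]):
        hs.append(np.tanh(wx[t] + u * hs[-1]))
    return np.stack(hs)  # shape (T + 1, B, H); hs[0] is h0

def loss(wx, u, h0, g):
    # Hypothetical scalar loss: weighted sum of all outputs after h0.
    return np.sum(forward(wx, u, h0)[1:] * g)

def grad_u(wx, u, h0, g):
    """Hand-written backward pass (BPTT) for dL/du."""
    hs = forward(wx, u, h0)
    du = np.zeros_like(u)
    carry = np.zeros_like(h0)  # dL/dh flowing back from the next step
    for t in range(wx.shape[0] - 1, -1, -1):
        dh = g[t] + carry
        dz = dh * (1.0 - hs[t + 1] ** 2)  # derivative of tanh
        du += np.sum(dz * hs[t], axis=0)  # hs[t] is the previous hidden state
        carry = dz * u
    return du

def numeric_grad_u(wx, u, h0, g, eps=1e-6):
    """Central finite differences, one element of u at a time."""
    grad = np.zeros_like(u)
    for i in range(u.size):
        step = np.zeros_like(u)
        step[i] = eps
        grad[i] = (loss(wx, u + step, h0, g) - loss(wx, u - step, h0, g)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
T, B, H = 20, 4, 8
wx = rng.standard_normal((T, B, H))
u = rng.uniform(-0.5, 0.5, H)
h0 = rng.standard_normal((B, H)) * 0.1
g = rng.standard_normal((T, B, H))

analytic = grad_u(wx, u, h0, g)
numeric = numeric_grad_u(wx, u, h0, g)
print(np.max(np.abs(analytic - numeric)))
print(np.allclose(analytic, numeric, rtol=1e-5, atol=1e-6))
```

Doing this in float64 keeps finite-difference round-off small, so a mismatch well above the tolerance points to a bug in the backward pass rather than to precision loss.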
