
keras_multi_gpu's Issues

Incompatible shapes

Hi,
it seems the wrapper cannot always split the tensors: sometimes I get an "Incompatible shapes" error.
Perhaps it cannot handle odd lengths?

InvalidArgumentError: Incompatible shapes: [17,258] vs. [35,258]
	 [[Node: sequential_4/gru_3/while/add = Add[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](sequential_4/gru_3/while/strided_slice, sequential_4/gru_3/while/MatMul)]]

Caused by op 'sequential_4/gru_3/while/add', defined at:
  File "/usr/lib/python3.5/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/pawel/venv/lib/python3.5/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/home/pawel/venv/lib/python3.5/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()

(...)

  File "/home/pawel/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2630, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/pawel/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1204, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Incompatible shapes: [17,258] vs. [35,258]
	 [[Node: sequential_4/gru_3/while/add = Add[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](sequential_4/gru_3/while/strided_slice, sequential_4/gru_3/while/MatMul)]]
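On the odd-lengths question: if a batch is split across GPUs with plain floor division, an odd batch cannot be divided evenly and one row is left over. Below is a minimal sketch, assuming the batch is sliced along axis 0 into contiguous parts; `split_bounds` is a hypothetical helper, not keras_multi_gpu's own code. It hands the leftover rows to the first slices so every sample is used exactly once. In the error above, 35 looks like an unsplit batch and 17 like a floor-divided half of it, so one tensor in the GRU may not have been sliced at all; the sketch only illustrates the splitting arithmetic.

```python
def split_bounds(batch_size, num_parts):
    """Return (start, stop) pairs that cover range(batch_size) with
    num_parts contiguous slices. The first batch_size % num_parts slices
    get one extra row, so odd batch sizes lose no samples."""
    if num_parts <= 0:
        raise ValueError("num_parts must be positive")
    base, extra = divmod(batch_size, num_parts)
    bounds = []
    start = 0
    for i in range(num_parts):
        stop = start + base + (1 if i < extra else 0)
        bounds.append((start, stop))
        start = stop
    return bounds


# Naive split: every part gets batch_size // num_parts rows.
naive = [(i * (35 // 2), (i + 1) * (35 // 2)) for i in range(2)]
print(naive)                # [(0, 17), (17, 34)]  -> row 34 is never used
print(split_bounds(35, 2))  # [(0, 18), (18, 35)]
```

The slices may differ in size by one row, which is harmless as long as every tensor inside a tower is derived from the same slice rather than from the full batch.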

Multi-GPU does not work

Hi, I am trying to use this method to parallelize my model, but it does not seem to work.
For example, I use 3 GPUs, but in practice only one GPU appears to be used.
The code is:

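import tensorflow as tf

# slice_batch(x, num_gpu, part) is the batch-slicing helper used below; it is
# not shown in the issue and is assumed to be defined elsewhere.

def multi_gpu_wrapper(single_model, num_gpu):
    inputs = single_model.inputs
    towers = []
    concat_layer = tf.keras.layers.Concatenate(axis=0)
    for gpu_id in range(num_gpu):
        print('cur gpu is', gpu_id)
        with tf.device('/gpu:' + str(gpu_id)):
            # Bind gpu_id as a default argument so each Lambda keeps its own
            # slice index even if the function is re-run after the loop ends.
            split_layer = tf.keras.layers.Lambda(
                lambda x, part=gpu_id: slice_batch(x, num_gpu, part))
            cur_inputs = [split_layer(inp) for inp in inputs]
            tower_out = single_model(cur_inputs)
            # A single-output model returns a tensor, not a list of tensors.
            if not isinstance(tower_out, list):
                tower_out = [tower_out]
            towers.append(tower_out)
            print(towers[-1])
    outputs = []
    num_output = len(towers[-1])
    with tf.device('/cpu:0'):
        # Concatenate the i-th output of every tower back along the batch axis.
        for i in range(num_output):
            tmp_outputs = [towers[j][i] for j in range(num_gpu)]
            outputs.append(concat_layer(tmp_outputs))
    multi_gpu_model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    return multi_gpu_model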

The output of nvidia-smi is:
[screenshot of nvidia-smi output]

Do you know how to fix it?
Thank you!
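The wrapper follows the usual data-parallel pattern: slice the batch along axis 0, run one copy of the model per slice, and concatenate the outputs along axis 0. A CPU-only sketch with NumPy, where `tower` is a hypothetical stand-in for `single_model` and `split_batch` a hypothetical slicing helper, can check that this round trip reproduces the single-model output for an odd batch split into 3 parts. It verifies only the slicing and concatenation logic, not GPU placement, so it cannot by itself tell whether all three GPUs are busy.

```python
import numpy as np


def split_batch(x, num_parts):
    """Split x along axis 0 into num_parts contiguous slices.

    Earlier slices get one extra row when the batch does not divide evenly,
    so no sample is dropped.
    """
    base, extra = divmod(x.shape[0], num_parts)
    slices, start = [], 0
    for i in range(num_parts):
        stop = start + base + (1 if i < extra else 0)
        slices.append(x[start:stop])
        start = stop
    return slices


def tower(x, w):
    # Hypothetical stand-in for single_model: one dense layer with tanh.
    return np.tanh(x @ w)


rng = np.random.default_rng(0)
x = rng.standard_normal((35, 8))  # odd batch size on purpose
w = rng.standard_normal((8, 4))

# One "tower" per slice, then concatenate along the batch axis (axis=0),
# mirroring Concatenate(axis=0) in multi_gpu_wrapper.
parts = split_batch(x, 3)
merged = np.concatenate([tower(p, w) for p in parts], axis=0)

print([p.shape[0] for p in parts])       # [12, 12, 11]
print(np.allclose(merged, tower(x, w)))  # True
```

`np.array_split(x, 3)` yields the same slice sizes ([12, 12, 11]) and could replace `split_batch` in this sketch.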
