keras's Issues

tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shape

In commit f1c5c1d, running the following code with Python 3.6 and the TensorFlow backend

import os
# remove warning "Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA"
# set before Keras/TensorFlow are imported; it only disables the warning, it doesn't enable AVX/FMA
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['KERAS_BACKEND'] = 'tensorflow'
import keras
import keras.backend as K
import numpy as np

from keras.models import Model, Sequential
from keras.layers import Activation, Conv2D, Input
from keras.layers.normalization import BatchNormalization

from keras.utils import multi_gpu_model

# image_data_format = 'channels_first'
image_data_format = 'channels_last'

K.set_image_data_format(image_data_format)

# simulate input images
if image_data_format == 'channels_first':
    im = np.zeros(shape=(10, 3, 64, 64), dtype='uint8')

    # simulate network output
    out = 2 * np.ones(shape=(10, 1, 64, 64), dtype='float32')
    aux_out = 5 * np.ones(shape=(10, 1, 22, 22), dtype='float32')
    # simulate training weights for network output
    weight = np.ones(shape=(10, 1, 64, 64), dtype='float32')
    aux_weight = np.ones(shape=(10, 1, 22, 22), dtype='float32')

    # simulate validation data
    im_validation = 3 * np.ones(shape=(5, 3, 64, 64), dtype='uint8')
    out_validation = 4 * np.ones(shape=(5, 1, 64, 64), dtype='float32')

elif image_data_format == 'channels_last':
    im = np.zeros(shape=(10, 64, 64, 3), dtype='uint8')

    # simulate network output
    out = 2 * np.ones(shape=(10, 64, 64, 1), dtype='float32')
    aux_out = 5 * np.ones(shape=(10, 22, 22, 1), dtype='float32')
    # simulate training weights for network output
    weight = np.ones(shape=(10, 64, 64, 1), dtype='float32')
    aux_weight = np.ones(shape=(10, 22, 22, 1), dtype='float32')

    # simulate validation data
    im_validation = 3 * np.ones(shape=(5, 64, 64, 3), dtype='uint8')
    out_validation = 4 * np.ones(shape=(5, 64, 64, 1), dtype='float32')

else:
    raise ValueError('Unrecognised position for channels')


validation_data = (im_validation, out_validation)

# optimizer
optimizer = keras.optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)

# simple CNN with one output

# create network model
model = Sequential()
model.add(Conv2D(input_shape=im.shape[1:],
                 filters=32, kernel_size=(3, 3), strides=1, padding='same'))
# normalize over the channel axis, whose position depends on the image data format
model.add(BatchNormalization(axis=1 if image_data_format == 'channels_first' else -1))
model.add(Activation('relu'))
model.add(Conv2D(filters=1, kernel_size=(1, 1), strides=1, padding='same'))

# sample_weight_mode given in string format (sample_weight_mode='element')

# compile model
model.compile(loss='mae', optimizer=optimizer, metrics=['accuracy'], sample_weight_mode='element')

# train with validation_data
model.fit(im, out, sample_weight=weight, validation_data=validation_data, batch_size=3, epochs=3)

produces the following error

Traceback (most recent call last):
  File "<input>", line 79, in <module>
  File "/home/rcasero/.conda/envs/cytometer_tensorflow/lib/python3.6/site-packages/keras/engine/training.py", line 1070, in fit
    validation_steps=validation_steps)
  File "/home/rcasero/.conda/envs/cytometer_tensorflow/lib/python3.6/site-packages/keras/engine/training_arrays.py", line 199, in fit_loop
    outs = f(ins_batch)
  File "/home/rcasero/.conda/envs/cytometer_tensorflow/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2661, in __call__
    return self._call(inputs)
  File "/home/rcasero/.conda/envs/cytometer_tensorflow/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2631, in _call
    fetched = self._callable_fn(*array_vals)
  File "/home/rcasero/.conda/envs/cytometer_tensorflow/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1454, in __call__
    self._session._session, self._handle, args, status, None)
  File "/home/rcasero/.conda/envs/cytometer_tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 519, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [3,64,64] vs. [3,64,64,1]
	 [[Node: loss/conv2d_2_loss/mul = Mul[T=DT_FLOAT, _class=["loc:@training/SGD/gradients/loss/conv2d_2_loss/mul_grad/Sum"], _device="/job:localhost/replica:0/task:0/device:GPU:0"](loss/conv2d_2_loss/Mean, _arg_conv2d_2_sample_weights_0_2/_47)]]
	 [[Node: loss/mul/_99 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_518_loss/mul", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
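The shape mismatch in the traceback can be reproduced with plain NumPy. Keras's built-in `mae` loss averages over the last axis, so for a batch of 3 the per-pixel loss has shape (3, 64, 64), while the element-wise weights keep the trailing channel axis, (3, 64, 64, 1). The sketch below is an illustration under that assumption, not the actual `sample_weight_mode='element'` code path in this commit; the arrays are hypothetical stand-ins for one batch of `out` and `weight` from the script above.

```python
import numpy as np

batch, h, w = 3, 64, 64

# Hypothetical stand-ins for one batch of the arrays above (channels_last).
y_true = 2 * np.ones((batch, h, w, 1), dtype='float32')
y_pred = np.zeros((batch, h, w, 1), dtype='float32')
weight = np.ones((batch, h, w, 1), dtype='float32')

# Like Keras's built-in 'mae', average over the last axis, which drops the channel axis.
score_array = np.mean(np.abs(y_pred - y_true), axis=-1)
print(score_array.shape)  # (3, 64, 64)

# Multiplying by weights that still have the trailing channel axis cannot broadcast,
# mirroring "Incompatible shapes: [3,64,64] vs. [3,64,64,1]".
try:
    score_array * weight
    broadcast_ok = True
except ValueError:
    broadcast_ok = False
print(broadcast_ok)  # False

# Dropping the size-1 channel axis makes the weights line up with the per-pixel loss.
weight_squeezed = np.squeeze(weight, axis=-1)
weighted = score_array * weight_squeezed
print(weighted.shape)  # (3, 64, 64)
```

Passing `np.squeeze(weight, axis=-1)` as `sample_weight` is one workaround to try, assuming the `'element'` mode only requires the weights to match the reduced loss shape; this has not been checked against commit f1c5c1d. With `channels_first`, the same last-axis reduction would average over image width rather than channels, so squeezing alone would not be enough there.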
