rcasero/keras (GitHub)
This project is forked from keras-team/keras.
Deep Learning for humans
Home Page: http://keras.io/
License: Other
This project is forked from keras-team/keras.
Deep Learning for humans
Home Page: http://keras.io/
License: Other
In commit f1c5c1d, running the following code with Python 3.6:
import os

# Environment variables must be in place before `import keras`: Keras chooses
# its backend when it is imported, and TensorFlow (loaded as the backend by
# that import) may fix its C++ log threshold on the first message it logs.
os.environ['KERAS_BACKEND'] = 'tensorflow'
# Hide "Your CPU supports instructions that this TensorFlow binary was not
# compiled to use: AVX2 FMA". This only silences the message; it does not
# enable AVX/FMA.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import keras
import keras.backend as K
import numpy as np
from keras.models import Model, Sequential
from keras.layers import Activation, Conv2D, Input
from keras.layers.normalization import BatchNormalization
from keras.utils import multi_gpu_model

# Tensor layout used by the whole script. Uncomment the 'channels_first' line
# (and comment out the 'channels_last' one) to test the channels-first layout.
# image_data_format = 'channels_first'
image_data_format = 'channels_last'
K.set_image_data_format(image_data_format)
# Simulate the data set. The arrays are:
#   - 10 training images (3 channels, 64x64, all zeros);
#   - the network's target output (1 channel, 64x64, all 2s);
#   - an auxiliary target (1 channel, 22x22, all 5s);
#   - per-element training weights for both targets (all 1s);
#   - 5 validation image/target pairs.
# The channel axis is placed according to image_data_format.
if image_data_format == 'channels_first':
    im = np.zeros(shape=(10, 3, 64, 64), dtype='uint8')
    # simulate network output
    out = 2 * np.ones(shape=(10, 1, 64, 64), dtype='float32')
    aux_out = 5 * np.ones(shape=(10, 1, 22, 22), dtype='float32')
    # simulate training weights for network output
    weight = np.ones(shape=(10, 1, 64, 64), dtype='float32')
    aux_weight = np.ones(shape=(10, 1, 22, 22), dtype='float32')
    # simulate validation data
    im_validation = 3 * np.ones(shape=(5, 3, 64, 64), dtype='uint8')
    out_validation = 4 * np.ones(shape=(5, 1, 64, 64), dtype='float32')
elif image_data_format == 'channels_last':
    im = np.zeros(shape=(10, 64, 64, 3), dtype='uint8')
    # simulate network output
    out = 2 * np.ones(shape=(10, 64, 64, 1), dtype='float32')
    aux_out = 5 * np.ones(shape=(10, 22, 22, 1), dtype='float32')
    # simulate training weights for network output
    weight = np.ones(shape=(10, 64, 64, 1), dtype='float32')
    aux_weight = np.ones(shape=(10, 22, 22, 1), dtype='float32')
    # simulate validation data
    im_validation = 3 * np.ones(shape=(5, 64, 64, 3), dtype='uint8')
    out_validation = 4 * np.ones(shape=(5, 64, 64, 1), dtype='float32')
else:
    raise ValueError('Unrecognised position for channels')
# Validation pairs are passed to fit() without sample weights.
validation_data = (im_validation, out_validation)
# Optimizer used to compile the model below.
optimizer = keras.optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)

# Simple CNN with one output: 3x3 conv -> batch norm -> ReLU -> 1x1 conv.
# With stride 1 and 'same' padding, the output keeps the input's spatial size
# and has a single channel.
# BatchNormalization must normalise over the channel axis, whose position
# depends on the layout chosen above: axis 1 for channels_first, the last axis
# for channels_last. (-1 is axis 3 for these 4-D tensors, as before.)
channel_axis = 1 if image_data_format == 'channels_first' else -1
model = Sequential()
model.add(Conv2D(input_shape=im.shape[1:],
                 filters=32, kernel_size=(3, 3), strides=1, padding='same'))
model.add(BatchNormalization(axis=channel_axis))
model.add(Activation('relu'))
model.add(Conv2D(filters=1, kernel_size=(1, 1), strides=1, padding='same'))

# Compile with per-element sample weighting, requested via the string form
# sample_weight_mode='element'.
model.compile(loss='mae', optimizer=optimizer, metrics=['accuracy'],
              sample_weight_mode='element')

# Train with one weight per output element, plus validation data.
# NOTE(review): this is the call that fails in the report with
# "Incompatible shapes: [3,64,64] vs. [3,64,64,1]". The per-element loss
# (node loss/conv2d_2_loss/Mean) has lost the trailing channel axis, while the
# sample weights passed here still have it.
model.fit(im, out, sample_weight=weight, validation_data=validation_data,
          batch_size=3, epochs=3)
produces the following error:
Traceback (most recent call last):
File "<input>", line 79, in <module>
File "/home/rcasero/.conda/envs/cytometer_tensorflow/lib/python3.6/site-packages/keras/engine/training.py", line 1070, in fit
validation_steps=validation_steps)
File "/home/rcasero/.conda/envs/cytometer_tensorflow/lib/python3.6/site-packages/keras/engine/training_arrays.py", line 199, in fit_loop
outs = f(ins_batch)
File "/home/rcasero/.conda/envs/cytometer_tensorflow/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2661, in __call__
return self._call(inputs)
File "/home/rcasero/.conda/envs/cytometer_tensorflow/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2631, in _call
fetched = self._callable_fn(*array_vals)
File "/home/rcasero/.conda/envs/cytometer_tensorflow/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1454, in __call__
self._session._session, self._handle, args, status, None)
File "/home/rcasero/.conda/envs/cytometer_tensorflow/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 519, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [3,64,64] vs. [3,64,64,1]
[[Node: loss/conv2d_2_loss/mul = Mul[T=DT_FLOAT, _class=["loc:@training/SGD/gradients/loss/conv2d_2_loss/mul_grad/Sum"], _device="/job:localhost/replica:0/task:0/device:GPU:0"](loss/conv2d_2_loss/Mean, _arg_conv2d_2_sample_weights_0_2/_47)]]
[[Node: loss/mul/_99 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_518_loss/mul", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
A declarative, efficient, and flexible JavaScript library for building user interfaces.
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
An Open Source Machine Learning Framework for Everyone
The Web framework for perfectionists with deadlines.
A PHP framework for web artisans
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
Something interesting about the web. New door for the world.
A server is a program made to process requests and deliver data to clients.
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
Something interesting about visualization, use data art
Something interesting about games, make everyone happy.
We are working to build community through open source technology. NB: members must have two-factor auth.
Open source projects and samples from Microsoft.
Google ❤️ Open Source for everyone.
Alibaba Open Source for everyone
Data-Driven Documents codes.
China tencent open source team.