I am receiving the following error when trying to apply TrivialAugmentWide or RandAugment layers with mixed precision. Any guess?
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
def make_dataset(X,y,batch_size,autotune = tf.data.AUTOTUNE,augmentation = None,seed = seed):
def preprocess_data(images, labels, augmentation=None):
inputs = {"images": images, "labels": labels}
outputs = augmentation(inputs) if augmentation != None else inputs
return outputs["images"], outputs["labels"]
dataset = tf.data.Dataset.from_tensor_slices((X, y))
dataset = dataset.batch(batch_size).map(lambda x, y: preprocess_data(x, y, augmentation=augmentation), num_parallel_calls=autotune).prefetch(autotune)
return dataset
augmentation_layer = tfk.Sequential([
keras_aug.layers.TrivialAugmentWide(value_range=(0,1), interpolation='bilinear', name='trivial_augment'),
keras_aug.layers.RandomErase(area_factor=(0.02, 0.1), fill_mode='gaussian_noise', name='random_erase')
], name='preprocessing')
training_dataset = make_dataset(X_train, y_train, batch_size=batch_size, augmentation=augmentation_layer)
--------------------------------------------------------------------------------------------------------------------------------
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/dataset_ops.py:2268, in DatasetV2.map(self, map_func, num_parallel_calls, deterministic, name)
2264 # Loaded lazily due to a circular dependency (dataset_ops -> map_op ->
2265 # dataset_ops).
2266 # pylint: disable=g-import-not-at-top,protected-access
2267 from tensorflow.python.data.ops import map_op
-> 2268 return map_op._map_v2(
2269 self,
2270 map_func,
2271 num_parallel_calls=num_parallel_calls,
2272 deterministic=deterministic,
2273 name=name)
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/map_op.py:40, in _map_v2(input_dataset, map_func, num_parallel_calls, deterministic, name)
37 return _MapDataset(
38 input_dataset, map_func, preserve_cardinality=True, name=name)
39 else:
---> 40 return _ParallelMapDataset(
41 input_dataset,
42 map_func,
43 num_parallel_calls=num_parallel_calls,
44 deterministic=deterministic,
45 preserve_cardinality=True,
46 name=name)
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/map_op.py:148, in _ParallelMapDataset.__init__(self, input_dataset, map_func, num_parallel_calls, deterministic, use_inter_op_parallelism, preserve_cardinality, use_legacy_function, name)
146 self._input_dataset = input_dataset
147 self._use_inter_op_parallelism = use_inter_op_parallelism
--> 148 self._map_func = structured_function.StructuredFunctionWrapper(
149 map_func,
150 self._transformation_name(),
151 dataset=input_dataset,
152 use_legacy_function=use_legacy_function)
153 if deterministic is None:
154 self._deterministic = "default"
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/structured_function.py:265, in StructuredFunctionWrapper.__init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
258 warnings.warn(
259 "Even though the `tf.config.experimental_run_functions_eagerly` "
260 "option is set, this option does not apply to tf.data functions. "
261 "To force eager execution of tf.data functions, please use "
262 "`tf.data.experimental.enable_debug_mode()`.")
263 fn_factory = trace_tf_function(defun_kwargs)
--> 265 self._function = fn_factory()
266 # There is no graph to add in eager mode.
267 add_to_graph &= not context.executing_eagerly()
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:1222, in Function.get_concrete_function(self, *args, **kwargs)
1220 def get_concrete_function(self, *args, **kwargs):
1221 # Implements GenericFunction.get_concrete_function.
-> 1222 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
1223 concrete._garbage_collector.release() # pylint: disable=protected-access
1224 return concrete
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:1192, in Function._get_concrete_function_garbage_collected(self, *args, **kwargs)
1190 if self._variable_creation_config is None:
1191 initializers = []
-> 1192 self._initialize(args, kwargs, add_initializers_to=initializers)
1193 self._initialize_uninitialized_variables(initializers)
1195 if self._created_variables:
1196 # In this case we have created variables on the first call, so we run the
1197 # version which is guaranteed to never create variables.
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:694, in Function._initialize(self, args, kwds, add_initializers_to)
689 self._variable_creation_config = self._generate_scoped_tracing_options(
690 variable_capturing_scope,
691 tracing_compilation.ScopeType.VARIABLE_CREATION,
692 )
693 # Force the definition of the function for these arguments
--> 694 self._concrete_variable_creation_fn = tracing_compilation.trace_function(
695 args, kwds, self._variable_creation_config
696 )
698 def invalid_creator_scope(*unused_args, **unused_kwds):
699 """Disables variable creation."""
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:178, in trace_function(args, kwargs, tracing_options)
175 args = tracing_options.input_signature
176 kwargs = {}
--> 178 concrete_function = _maybe_define_function(
179 args, kwargs, tracing_options
180 )
181 _set_arg_keywords(concrete_function)
183 if not tracing_options.bind_graph_to_function:
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:284, in _maybe_define_function(args, kwargs, tracing_options)
282 else:
283 target_func_type = lookup_func_type
--> 284 concrete_function = _create_concrete_function(
285 target_func_type, lookup_func_context, func_graph, tracing_options
286 )
288 if tracing_options.function_cache is not None:
289 tracing_options.function_cache.add(
290 concrete_function, current_func_context
291 )
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:308, in _create_concrete_function(function_type, type_context, func_graph, tracing_options)
303 with func_graph.as_default():
304 placeholder_bound_args = function_type.placeholder_arguments(
305 placeholder_context
306 )
--> 308 traced_func_graph = func_graph_module.func_graph_from_py_func(
309 tracing_options.name,
310 tracing_options.python_function,
311 placeholder_bound_args.args,
312 placeholder_bound_args.kwargs,
313 None,
314 func_graph=func_graph,
315 arg_names=function_type_utils.to_arg_names(function_type),
316 create_placeholders=False,
317 )
319 transform.apply_func_graph_transforms(traced_func_graph)
321 graph_capture_container = traced_func_graph.function_captures
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/framework/func_graph.py:1059, in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, create_placeholders)
1056 return x
1058 _, original_func = tf_decorator.unwrap(python_func)
-> 1059 func_outputs = python_func(*func_args, **func_kwargs)
1061 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
1062 # TensorArrays and `None`s.
1063 func_outputs = variable_utils.convert_variables_to_tensors(func_outputs)
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:597, in Function._generate_scoped_tracing_options.<locals>.wrapped_fn(*args, **kwds)
593 with default_graph._variable_creator_scope(scope, priority=50): # pylint: disable=protected-access
594 # __wrapped__ allows AutoGraph to swap in a converted function. We give
595 # the function a weak reference to itself to avoid a reference cycle.
596 with OptionalXlaContext(compile_with_xla):
--> 597 out = weak_wrapped_fn().__wrapped__(*args, **kwds)
598 return out
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/structured_function.py:231, in StructuredFunctionWrapper.__init__.<locals>.trace_tf_function.<locals>.wrapped_fn(*args)
230 def wrapped_fn(*args): # pylint: disable=missing-docstring
--> 231 ret = wrapper_helper(*args)
232 ret = structure.to_tensor_list(self._output_structure, ret)
233 return [ops.convert_to_tensor(t) for t in ret]
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/structured_function.py:161, in StructuredFunctionWrapper.__init__.<locals>.wrapper_helper(*args)
159 if not _should_unpack(nested_args):
160 nested_args = (nested_args,)
--> 161 ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
162 ret = variable_utils.convert_variables_to_tensors(ret)
163 if _should_pack(ret):
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/api.py:693, in convert.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
691 except Exception as e: # pylint:disable=broad-except
692 if hasattr(e, 'ag_error_metadata'):
--> 693 raise e.ag_error_metadata.to_exception(e)
694 else:
695 raise
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/api.py:690, in convert.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
688 try:
689 with conversion_ctx:
--> 690 return converted_call(f, args, kwargs, options=options)
691 except Exception as e: # pylint:disable=broad-except
692 if hasattr(e, 'ag_error_metadata'):
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/api.py:439, in converted_call(f, args, kwargs, caller_fn_scope, options)
437 try:
438 if kwargs is not None:
--> 439 result = converted_f(*effective_args, **kwargs)
440 else:
441 result = converted_f(*effective_args)
File ~tmp/__autograph_generated_filebrqr6ub4.py:7, in outer_factory.<locals>.inner_factory.<locals>.<lambda>(x, y)
6 def inner_factory(ag__):
----> 7 tf__lam = lambda x, y: ag__.with_function_scope(lambda lscope: ag__.converted_call(preprocess_data, (x, y), dict(augmentation=augmentation), lscope), 'lscope', ag__.STD)
8 return tf__lam
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/core/function_wrappers.py:113, in with_function_scope(thunk, scope_name, options)
111 """Inline version of the FunctionScope context manager."""
112 with FunctionScope('lambda_', scope_name, options) as scope:
--> 113 return thunk(scope)
File ~tmp/__autograph_generated_filebrqr6ub4.py:7, in outer_factory.<locals>.inner_factory.<locals>.<lambda>(lscope)
6 def inner_factory(ag__):
----> 7 tf__lam = lambda x, y: ag__.with_function_scope(lambda lscope: ag__.converted_call(preprocess_data, (x, y), dict(augmentation=augmentation), lscope), 'lscope', ag__.STD)
8 return tf__lam
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/api.py:439, in converted_call(f, args, kwargs, caller_fn_scope, options)
437 try:
438 if kwargs is not None:
--> 439 result = converted_f(*effective_args, **kwargs)
440 else:
441 result = converted_f(*effective_args)
File ~tmp/__autograph_generated_filelvb17_ff.py:11, in outer_factory.<locals>.inner_factory.<locals>.tf__preprocess_data(images, labels, augmentation)
9 retval_ = ag__.UndefinedReturnValue()
10 inputs = {'images': ag__.ld(images), 'labels': ag__.ld(labels)}
---> 11 outputs = ag__.if_exp(ag__.ld(augmentation) != None, lambda: ag__.converted_call(ag__.ld(augmentation), (ag__.ld(inputs),), None, fscope), lambda: ag__.ld(inputs), 'augmentation != None')
12 try:
13 do_return = True
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/operators/conditional_expressions.py:27, in if_exp(cond, if_true, if_false, expr_repr)
25 return _tf_if_exp(cond, if_true, if_false, expr_repr)
26 else:
---> 27 return _py_if_exp(cond, if_true, if_false)
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/operators/conditional_expressions.py:52, in _py_if_exp(cond, if_true, if_false)
51 def _py_if_exp(cond, if_true, if_false):
---> 52 return if_true() if cond else if_false()
File ~tmp/__autograph_generated_filelvb17_ff.py:11, in outer_factory.<locals>.inner_factory.<locals>.tf__preprocess_data.<locals>.<lambda>()
9 retval_ = ag__.UndefinedReturnValue()
10 inputs = {'images': ag__.ld(images), 'labels': ag__.ld(labels)}
---> 11 outputs = ag__.if_exp(ag__.ld(augmentation) != None, lambda: ag__.converted_call(ag__.ld(augmentation), (ag__.ld(inputs),), None, fscope), lambda: ag__.ld(inputs), 'augmentation != None')
12 try:
13 do_return = True
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/api.py:331, in converted_call(f, args, kwargs, caller_fn_scope, options)
329 if conversion.is_in_allowlist_cache(f, options):
330 logging.log(2, 'Allowlisted %s: from cache', f)
--> 331 return _call_unconverted(f, args, kwargs, options, False)
333 if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
334 logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)
File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/api.py:460, in _call_unconverted(f, args, kwargs, options, update_cache)
458 if kwargs is not None:
459 return f(*args, **kwargs)
--> 460 return f(*args)
File ~usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
67 filtered_tb = _process_traceback_frames(e.__traceback__)
68 # To get the full stack trace, call:
69 # `tf.debugging.disable_traceback_filtering()`
---> 70 raise e.with_traceback(filtered_tb) from None
71 finally:
72 del filtered_tb
File ~tmp/__autograph_generated_file12c62saa.py:33, in outer_factory.<locals>.inner_factory.<locals>.tf__call(self, inputs)
31 nonlocal do_return, retval_
32 raise ag__.converted_call(ag__.ld(ValueError), (f'Image augmentation layers are expecting inputs to be rank 3 (HWC) or 4D (NHWC) tensors. Got shape: {ag__.ld(images).shape}',), None, fscope)
---> 33 ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
34 return fscope.ret(retval_, do_return)
File ~tmp/__autograph_generated_file12c62saa.py:25, in outer_factory.<locals>.inner_factory.<locals>.tf__call.<locals>.if_body()
23 try:
24 do_return = True
---> 25 retval_ = ag__.converted_call(ag__.ld(self)._format_output, (ag__.converted_call(ag__.ld(self)._batch_augment, (ag__.ld(inputs),), None, fscope), ag__.ld(metadata)), None, fscope)
26 except:
27 do_return = False
File ~tmp/__autograph_generated_filep1xqu7wy.py:35, in outer_factory.<locals>.inner_factory.<locals>.tf___batch_augment(self, inputs)
33 ag__.if_stmt(ag__.ld(bounding_boxes) is not None, if_body, else_body, get_state, set_state, ('inputs[BOUNDING_BOXES]', 'ori_bbox_info'), 2)
34 inputs_for_trivial_augment_single_input = {'inputs': ag__.ld(inputs), 'transformations': ag__.ld(transformations)}
---> 35 result = ag__.converted_call(ag__.ld(tf).map_fn, (ag__.ld(self).trivial_augment_single_input, ag__.ld(inputs_for_trivial_augment_single_input)), dict(fn_output_signature=ag__.converted_call(ag__.ld(augmentation_utils).compute_signature, (ag__.ld(inputs), ag__.ld(self).compute_dtype), None, fscope)), fscope)
36 bounding_boxes = ag__.converted_call(ag__.ld(result).get, (ag__.ld(BOUNDING_BOXES), None), None, fscope)
38 def get_state_2():
File ~tmp/__autograph_generated_filev6ti3m33.py:26, in outer_factory.<locals>.inner_factory.<locals>.tf__trivial_augment_single_input(self, inputs)
24 idx = ag__.Undefined('idx')
25 ag__.for_stmt(ag__.converted_call(ag__.ld(enumerate), (ag__.ld(self).aug_layers,), None, fscope), None, loop_body, get_state, set_state, (), {'iterate_names': '(idx, layer)'})
---> 26 result = ag__.converted_call(ag__.ld(tf).switch_case, (ag__.ld(random_indice),), dict(branch_fns=ag__.ld(branch_fns)), fscope)
28 def get_state_1():
29 return (ag__.ldu(lambda: result[BOUNDING_BOXES], 'result[BOUNDING_BOXES]'),)
File ~tmp/__autograph_generated_file12c62saa.py:33, in outer_factory.<locals>.inner_factory.<locals>.tf__call(self, inputs)
31 nonlocal do_return, retval_
32 raise ag__.converted_call(ag__.ld(ValueError), (f'Image augmentation layers are expecting inputs to be rank 3 (HWC) or 4D (NHWC) tensors. Got shape: {ag__.ld(images).shape}',), None, fscope)
---> 33 ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
34 return fscope.ret(retval_, do_return)
File ~tmp/__autograph_generated_file12c62saa.py:25, in outer_factory.<locals>.inner_factory.<locals>.tf__call.<locals>.if_body()
23 try:
24 do_return = True
---> 25 retval_ = ag__.converted_call(ag__.ld(self)._format_output, (ag__.converted_call(ag__.ld(self)._batch_augment, (ag__.ld(inputs),), None, fscope), ag__.ld(metadata)), None, fscope)
26 except:
27 do_return = False
File ~tmp/__autograph_generated_filevkq0dz7z.py:36, in outer_factory.<locals>.inner_factory.<locals>.tf___batch_augment(self, inputs)
34 images = ag__.converted_call(ag__.ld(self).augment_images, (ag__.ld(images),), dict(transformations=ag__.ld(transformations), bounding_boxes=ag__.ld(bounding_boxes), labels=ag__.ld(labels)), fscope)
35 inputs_for_raggeds = ag__.Undefined('inputs_for_raggeds')
---> 36 ag__.if_stmt(ag__.and_(lambda: ag__.converted_call(ag__.ld(isinstance), (ag__.ld(images), ag__.ld(tf).RaggedTensor), None, fscope), lambda: ag__.not_(ag__.ld(self).force_no_unwrap_ragged_image_call)), if_body, else_body, get_state, set_state, ('images',), 1)
38 def get_state_1():
39 return (images,)
File ~tmp/__autograph_generated_filevkq0dz7z.py:34, in outer_factory.<locals>.inner_factory.<locals>.tf___batch_augment.<locals>.else_body()
32 def else_body():
33 nonlocal images
---> 34 images = ag__.converted_call(ag__.ld(self).augment_images, (ag__.ld(images),), dict(transformations=ag__.ld(transformations), bounding_boxes=ag__.ld(bounding_boxes), labels=ag__.ld(labels)), fscope)
File ~tmp/__autograph_generated_filegmle6ti9.py:15, in outer_factory.<locals>.inner_factory.<locals>.tf__augment_images(self, images, transformations, **kwargs)
13 scales = 255.0 / (ag__.ld(highs) - ag__.ld(lows))
14 eq_idxs = ag__.converted_call(ag__.ld(tf).math.is_inf, (ag__.ld(scales),), None, fscope)
---> 15 lows = ag__.converted_call(ag__.ld(tf).where, (ag__.ld(eq_idxs), 0.0, ag__.ld(lows)), None, fscope)
16 scales = ag__.converted_call(ag__.ld(tf).where, (ag__.ld(eq_idxs), 1.0, ag__.ld(scales)), None, fscope)
17 images = ag__.converted_call(ag__.ld(tf).clip_by_value, ((ag__.ld(images) - ag__.ld(lows)) * ag__.ld(scales), 0, 255), None, fscope)
TypeError: in user code:
File "/tmp/ipykernel_11/1206971647.py", line 15, in None *
lambda x, y: preprocess_data(x, y, augmentation=augmentation)
File "/tmp/ipykernel_11/1206971647.py", line 11, in preprocess_data *
outputs = augmentation(inputs) if augmentation != None else inputs
File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler **
raise e.with_traceback(filtered_tb) from None
File "/tmp/__autograph_generated_file12c62saa.py", line 33, in tf__call
ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
File "/tmp/__autograph_generated_file12c62saa.py", line 28, in if_body
raise
File "/tmp/__autograph_generated_filep1xqu7wy.py", line 35, in tf___batch_augment
result = ag__.converted_call(ag__.ld(tf).map_fn, (ag__.ld(self).trivial_augment_single_input, ag__.ld(inputs_for_trivial_augment_single_input)), dict(fn_output_signature=ag__.converted_call(ag__.ld(augmentation_utils).compute_signature, (ag__.ld(inputs), ag__.ld(self).compute_dtype), None, fscope)), fscope)
File "/tmp/__autograph_generated_filev6ti3m33.py", line 26, in tf__trivial_augment_single_input
result = ag__.converted_call(ag__.ld(tf).switch_case, (ag__.ld(random_indice),), dict(branch_fns=ag__.ld(branch_fns)), fscope)
File "/tmp/__autograph_generated_file12c62saa.py", line 33, in tf__call
ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
File "/tmp/__autograph_generated_file12c62saa.py", line 28, in if_body
raise
File "/tmp/__autograph_generated_filevkq0dz7z.py", line 36, in tf___batch_augment
ag__.if_stmt(ag__.and_(lambda: ag__.converted_call(ag__.ld(isinstance), (ag__.ld(images), ag__.ld(tf).RaggedTensor), None, fscope), lambda: ag__.not_(ag__.ld(self).force_no_unwrap_ragged_image_call)), if_body, else_body, get_state, set_state, ('images',), 1)
File "/tmp/__autograph_generated_filevkq0dz7z.py", line 34, in else_body
images = ag__.converted_call(ag__.ld(self).augment_images, (ag__.ld(images),), dict(transformations=ag__.ld(transformations), bounding_boxes=ag__.ld(bounding_boxes), labels=ag__.ld(labels)), fscope)
File "/tmp/__autograph_generated_filegmle6ti9.py", line 15, in tf__augment_images
lows = ag__.converted_call(ag__.ld(tf).where, (ag__.ld(eq_idxs), 0.0, ag__.ld(lows)), None, fscope)
TypeError: Exception encountered when calling layer 'trivial_augment' (type TrivialAugmentWide).
in user code:
File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/base/vectorized_base_random_layer.py", line 613, in call *
if images.shape.rank == 3 or images.shape.rank == 4:
File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/augmentation/auto/trivial_augment_wide.py", line 281, in _batch_augment *
result = tf.map_fn(
File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/augmentation/auto/trivial_augment_wide.py", line 316, in trivial_augment_single_input *
result = tf.switch_case(random_indice, branch_fns=branch_fns)
File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/tmp/__autograph_generated_file12c62saa.py", line 33, in tf__call
ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
File "/tmp/__autograph_generated_file12c62saa.py", line 28, in if_body
raise
File "/tmp/__autograph_generated_filevkq0dz7z.py", line 36, in tf___batch_augment
ag__.if_stmt(ag__.and_(lambda: ag__.converted_call(ag__.ld(isinstance), (ag__.ld(images), ag__.ld(tf).RaggedTensor), None, fscope), lambda: ag__.not_(ag__.ld(self).force_no_unwrap_ragged_image_call)), if_body, else_body, get_state, set_state, ('images',), 1)
File "/tmp/__autograph_generated_filevkq0dz7z.py", line 34, in else_body
images = ag__.converted_call(ag__.ld(self).augment_images, (ag__.ld(images),), dict(transformations=ag__.ld(transformations), bounding_boxes=ag__.ld(bounding_boxes), labels=ag__.ld(labels)), fscope)
File "/tmp/__autograph_generated_filegmle6ti9.py", line 15, in tf__augment_images
lows = ag__.converted_call(ag__.ld(tf).where, (ag__.ld(eq_idxs), 0.0, ag__.ld(lows)), None, fscope)
TypeError: Exception encountered when calling layer 'trivial_augment' (type AutoContrast).
in user code:
File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/base/vectorized_base_random_layer.py", line 614, in call *
return self._format_output(self._batch_augment(inputs), metadata)
File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/base/vectorized_base_random_layer.py", line 416, in _batch_augment *
images = self.augment_images(
File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/preprocessing/intensity/auto_contrast.py", line 50, in augment_images *
lows = tf.where(eq_idxs, 0.0, lows)
TypeError: Input 'e' of 'SelectV2' Op has type float16 that does not match type float32 of argument 't'.
Call arguments received by layer 'trivial_augment' (type AutoContrast):
• inputs={'images': 'tf.Tensor(shape=(1, 32, 32, 3), dtype=float16)', 'labels': 'tf.Tensor(shape=(1, 10), dtype=float16)'}
Call arguments received by layer 'trivial_augment' (type TrivialAugmentWide):
• inputs={'images': 'tf.Tensor(shape=(None, 32, 32, 3), dtype=float16)', 'labels': 'tf.Tensor(shape=(None, 10), dtype=float16)'}