
james77777778 / keras-aug


A library that includes Keras 3 preprocessing and augmentation layers, providing support for various data types such as images, labels, bounding boxes, segmentation masks, and more.

License: Apache License 2.0

Python 99.76% Shell 0.24%
augmentation keras preprocessing tensorflow jax keras3 tensorflow-datasets torch torchvision


keras-aug's Issues

Release roadmap

The current focus of this repository is on 2D augmentations, such as geometry, intensity, mixing, and regularization. Version 0.1.0 will be released once it can reproduce all augmentations in YOLOv8 and existing KerasCV augmentations.



Auto
  • RandAugment
  • RepeatedAugment
  • TrivialAugmentWide

Geometry
  • CenterCrop
  • PadIfNeeded
  • RandomAffine (yolov8: RandomPerspective)
  • RandomCrop
  • RandomCropAndResize (albumentations: RandomResizedCrop)
  • RandomFlip
  • RandomRotate
  • RandomZoomAndCrop
  • Resize (crop_to_aspect_ratio, pad_to_aspect_ratio, yolov8: LetterBox)
  • RandomResize (yolov5: multi-scale training)

Intensity
  • AutoContrast
  • ChannelShuffle
  • Equalize
  • Grayscale
  • Invert
  • Normalize
  • RandomBlur
  • RandomChannelShift
  • RandomCLAHE (isears/tf_clahe)
  • RandomColorJitter (albumentations: ColorJitter)
  • RandomGamma
  • RandomGaussianBlur
  • RandomHSV
  • RandomJpegQuality (albumentations: ImageCompression)
  • RandomPosterize (a random version of keras_cv: Posterization)
  • RandomSharpness
  • RandomSolarize (keras_cv: Solarization)
  • RandomMedianBlur (not going to support)
  • Rescale

Mix
  • AugMix
  • CutMix
  • FourierMix
  • RandomCopyPaste (not going to support)
  • MixUp
  • MosaicYOLOV8 (an improved version of keras_cv: Mosaic)

Regularization
  • ChannelDropout
  • RandomCutout
  • RandomErase
  • RandomGridMask

Utility
  • Identity
  • RandomApply
  • RandomChoice

Base
  • VectorizedBaseRandomLayer
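The traceback in the mixed-precision issue further down shows that TrivialAugmentWide runs `tf.map_fn` over the batch and, for each image, calls `tf.switch_case` on a random index to pick one of its `aug_layers`. A minimal NumPy sketch of that control flow follows. The three ops (`identity`, `brightness`, `invert`) and `trivial_augment_batch` are hypothetical stand-ins, not KerasAug code, and the magnitude is drawn uniformly from [0, 1] for simplicity, whereas the TrivialAugment paper samples from a discrete set of magnitude levels.

```python
import numpy as np


# Hypothetical toy ops with signature (image, magnitude in [0, 1]) -> image.
def identity(image, magnitude):
    return image


def brightness(image, magnitude):
    # Brighten by up to 64 gray levels, staying inside [0, 255].
    return np.clip(image + 64.0 * magnitude, 0.0, 255.0)


def invert(image, magnitude):
    return 255.0 - image


OPS = (identity, brightness, invert)


def trivial_augment_batch(images, rng, ops=OPS):
    """Apply exactly one uniformly chosen op, at a random magnitude, to each image.

    images: float array of shape (N, H, W, C) with values in [0, 255].
    """
    out = np.empty_like(images)
    for i, image in enumerate(images):  # per-image loop, like tf.map_fn
        idx = rng.integers(len(ops))  # random branch index, like tf.switch_case
        magnitude = rng.uniform(0.0, 1.0)
        out[i] = ops[idx](image, magnitude)
    return out
```

Choosing the op independently per image is what makes the batch diverse; a single choice for the whole batch would be cheaper but would apply the same transform to every sample.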


KerasAug will not support RandomMedianBlur because tfa.image.median_filter2d is implemented with tf.image.extract_patches, which cannot handle a dynamic (randomly sampled) filter size.
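The limitation comes from the patch-extraction recipe itself: the patch tensor's shape depends on the filter size, and `tf.image.extract_patches` takes its `sizes` argument as a list of Python ints rather than as a tensor, so a filter size sampled inside the graph cannot be passed in. The NumPy sketch below is a stand-in for that recipe (pad, gather every k x k patch, take the median), not the tfa implementation; `median_blur` is a hypothetical helper. Note how `filter_size` fixes the shape of `patches`.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def median_blur(image: np.ndarray, filter_size: int) -> np.ndarray:
    """Median-filter an (H, W, C) image with a square, odd-sized window."""
    if filter_size % 2 != 1:
        raise ValueError("filter_size must be odd")
    pad = filter_size // 2
    # Reflect-pad height and width so the output keeps the input size.
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="reflect")
    # Shape (H, W, C, k, k): this shape depends on filter_size, which is
    # why a graph-mode patch extractor needs the size to be static.
    patches = sliding_window_view(padded, (filter_size, filter_size), axis=(0, 1))
    return np.median(patches, axis=(-2, -1)).astype(image.dtype)
```

A single "salt" pixel is removed by a 3 x 3 median, since every window contains at most one bright value.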

KerasAug cannot support RandomCopyPaste augmentation until the data structure for instance segmentation is clarified.

Auto Augmentation with mixed precision bug

I receive the following error when applying the TrivialAugmentWide or RandAugment layers with mixed precision enabled. Any idea what causes it?

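import tensorflow as tf
import keras_aug
from tensorflow.keras import mixed_precision

tfk = tf.keras
mixed_precision.set_global_policy("mixed_float16")  # mixed precision, as described above

def make_dataset(X, y, batch_size, autotune=tf.data.AUTOTUNE, augmentation=None, seed=None):
    def preprocess_data(images, labels, augmentation=None):
        # keras_aug layers take a dict of inputs and return a dict of outputs
        inputs = {"images": images, "labels": labels}
        outputs = augmentation(inputs) if augmentation is not None else inputs
        return outputs["images"], outputs["labels"]

    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    # Batch first so the augmentation layers receive rank-4 (NHWC) images
    dataset = (
        dataset.batch(batch_size)
        .map(lambda x, y: preprocess_data(x, y, augmentation=augmentation), num_parallel_calls=autotune)
        .prefetch(autotune)
    )
    return dataset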

augmentation_layer = tfk.Sequential([
    keras_aug.layers.TrivialAugmentWide(value_range=(0,1), interpolation='bilinear', name='trivial_augment'),
    keras_aug.layers.RandomErase(area_factor=(0.02, 0.1), fill_mode='gaussian_noise', name='random_erase')
    ], name='preprocessing')

training_dataset = make_dataset(X_train, y_train, batch_size=batch_size, augmentation=augmentation_layer)


File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/, in, map_func, num_parallel_calls, deterministic, name)
   2264 # Loaded lazily due to a circular dependency (dataset_ops -> map_op ->
   2265 # dataset_ops).
   2266 # pylint: disable=g-import-not-at-top,protected-access
   2267 from import map_op
-> 2268 return map_op._map_v2(
   2269     self,
   2270     map_func,
   2271     num_parallel_calls=num_parallel_calls,
   2272     deterministic=deterministic,
   2273     name=name)

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/, in _map_v2(input_dataset, map_func, num_parallel_calls, deterministic, name)
     37   return _MapDataset(
     38       input_dataset, map_func, preserve_cardinality=True, name=name)
     39 else:
---> 40   return _ParallelMapDataset(
     41       input_dataset,
     42       map_func,
     43       num_parallel_calls=num_parallel_calls,
     44       deterministic=deterministic,
     45       preserve_cardinality=True,
     46       name=name)

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/, in _ParallelMapDataset.__init__(self, input_dataset, map_func, num_parallel_calls, deterministic, use_inter_op_parallelism, preserve_cardinality, use_legacy_function, name)
    146 self._input_dataset = input_dataset
    147 self._use_inter_op_parallelism = use_inter_op_parallelism
--> 148 self._map_func = structured_function.StructuredFunctionWrapper(
    149     map_func,
    150     self._transformation_name(),
    151     dataset=input_dataset,
    152     use_legacy_function=use_legacy_function)
    153 if deterministic is None:
    154   self._deterministic = "default"

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/, in StructuredFunctionWrapper.__init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
    258       warnings.warn(
    259           "Even though the `tf.config.experimental_run_functions_eagerly` "
    260           "option is set, this option does not apply to functions. "
    261           "To force eager execution of functions, please use "
    262           "``.")
    263     fn_factory = trace_tf_function(defun_kwargs)
--> 265 self._function = fn_factory()
    266 # There is no graph to add in eager mode.
    267 add_to_graph &= not context.executing_eagerly()

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/, in Function.get_concrete_function(self, *args, **kwargs)
   1220 def get_concrete_function(self, *args, **kwargs):
   1221   # Implements GenericFunction.get_concrete_function.
-> 1222   concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1223   concrete._garbage_collector.release()  # pylint: disable=protected-access
   1224   return concrete

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/, in Function._get_concrete_function_garbage_collected(self, *args, **kwargs)
   1190   if self._variable_creation_config is None:
   1191     initializers = []
-> 1192     self._initialize(args, kwargs, add_initializers_to=initializers)
   1193     self._initialize_uninitialized_variables(initializers)
   1195 if self._created_variables:
   1196   # In this case we have created variables on the first call, so we run the
   1197   # version which is guaranteed to never create variables.

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/, in Function._initialize(self, args, kwds, add_initializers_to)
    689 self._variable_creation_config = self._generate_scoped_tracing_options(
    690     variable_capturing_scope,
    691     tracing_compilation.ScopeType.VARIABLE_CREATION,
    692 )
    693 # Force the definition of the function for these arguments
--> 694 self._concrete_variable_creation_fn = tracing_compilation.trace_function(
    695     args, kwds, self._variable_creation_config
    696 )
    698 def invalid_creator_scope(*unused_args, **unused_kwds):
    699   """Disables variable creation."""

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/, in trace_function(args, kwargs, tracing_options)
    175     args = tracing_options.input_signature
    176     kwargs = {}
--> 178   concrete_function = _maybe_define_function(
    179       args, kwargs, tracing_options
    180   )
    181   _set_arg_keywords(concrete_function)
    183 if not tracing_options.bind_graph_to_function:

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/, in _maybe_define_function(args, kwargs, tracing_options)
    282 else:
    283   target_func_type = lookup_func_type
--> 284 concrete_function = _create_concrete_function(
    285     target_func_type, lookup_func_context, func_graph, tracing_options
    286 )
    288 if tracing_options.function_cache is not None:
    289   tracing_options.function_cache.add(
    290       concrete_function, current_func_context
    291   )

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/, in _create_concrete_function(function_type, type_context, func_graph, tracing_options)
    303 with func_graph.as_default():
    304   placeholder_bound_args = function_type.placeholder_arguments(
    305       placeholder_context
    306   )
--> 308 traced_func_graph = func_graph_module.func_graph_from_py_func(
    310     tracing_options.python_function,
    311     placeholder_bound_args.args,
    312     placeholder_bound_args.kwargs,
    313     None,
    314     func_graph=func_graph,
    315     arg_names=function_type_utils.to_arg_names(function_type),
    316     create_placeholders=False,
    317 )
    319 transform.apply_func_graph_transforms(traced_func_graph)
    321 graph_capture_container = traced_func_graph.function_captures

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/framework/, in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, create_placeholders)
   1056   return x
   1058 _, original_func = tf_decorator.unwrap(python_func)
-> 1059 func_outputs = python_func(*func_args, **func_kwargs)
   1061 # invariant: `func_outputs` contains only Tensors, CompositeTensors,
   1062 # TensorArrays and `None`s.
   1063 func_outputs = variable_utils.convert_variables_to_tensors(func_outputs)

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/eager/polymorphic_function/, in Function._generate_scoped_tracing_options.<locals>.wrapped_fn(*args, **kwds)
    593 with default_graph._variable_creator_scope(scope, priority=50):  # pylint: disable=protected-access
    594   # __wrapped__ allows AutoGraph to swap in a converted function. We give
    595   # the function a weak reference to itself to avoid a reference cycle.
    596   with OptionalXlaContext(compile_with_xla):
--> 597     out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    598   return out

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/, in StructuredFunctionWrapper.__init__.<locals>.trace_tf_function.<locals>.wrapped_fn(*args)
    230 def wrapped_fn(*args):  # pylint: disable=missing-docstring
--> 231   ret = wrapper_helper(*args)
    232   ret = structure.to_tensor_list(self._output_structure, ret)
    233   return [ops.convert_to_tensor(t) for t in ret]

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/data/ops/, in StructuredFunctionWrapper.__init__.<locals>.wrapper_helper(*args)
    159 if not _should_unpack(nested_args):
    160   nested_args = (nested_args,)
--> 161 ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
    162 ret = variable_utils.convert_variables_to_tensors(ret)
    163 if _should_pack(ret):

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/, in convert.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    691 except Exception as e:  # pylint:disable=broad-except
    692   if hasattr(e, 'ag_error_metadata'):
--> 693     raise e.ag_error_metadata.to_exception(e)
    694   else:
    695     raise

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/, in convert.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    688 try:
    689   with conversion_ctx:
--> 690     return converted_call(f, args, kwargs, options=options)
    691 except Exception as e:  # pylint:disable=broad-except
    692   if hasattr(e, 'ag_error_metadata'):

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/, in converted_call(f, args, kwargs, caller_fn_scope, options)
    437 try:
    438   if kwargs is not None:
--> 439     result = converted_f(*effective_args, **kwargs)
    440   else:
    441     result = converted_f(*effective_args)

File ~tmp/, in outer_factory.<locals>.inner_factory.<locals>.<lambda>(x, y)
      6 def inner_factory(ag__):
----> 7     tf__lam = lambda x, y: ag__.with_function_scope(lambda lscope: ag__.converted_call(preprocess_data, (x, y), dict(augmentation=augmentation), lscope), 'lscope', ag__.STD)
      8     return tf__lam

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/core/, in with_function_scope(thunk, scope_name, options)
    111 """Inline version of the FunctionScope context manager."""
    112 with FunctionScope('lambda_', scope_name, options) as scope:
--> 113   return thunk(scope)

File ~tmp/, in outer_factory.<locals>.inner_factory.<locals>.<lambda>(lscope)
      6 def inner_factory(ag__):
----> 7     tf__lam = lambda x, y: ag__.with_function_scope(lambda lscope: ag__.converted_call(preprocess_data, (x, y), dict(augmentation=augmentation), lscope), 'lscope', ag__.STD)
      8     return tf__lam

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/, in converted_call(f, args, kwargs, caller_fn_scope, options)
    437 try:
    438   if kwargs is not None:
--> 439     result = converted_f(*effective_args, **kwargs)
    440   else:
    441     result = converted_f(*effective_args)

File ~tmp/, in outer_factory.<locals>.inner_factory.<locals>.tf__preprocess_data(images, labels, augmentation)
      9 retval_ = ag__.UndefinedReturnValue()
     10 inputs = {'images': ag__.ld(images), 'labels': ag__.ld(labels)}
---> 11 outputs = ag__.if_exp(ag__.ld(augmentation) != None, lambda: ag__.converted_call(ag__.ld(augmentation), (ag__.ld(inputs),), None, fscope), lambda: ag__.ld(inputs), 'augmentation != None')
     12 try:
     13     do_return = True

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/operators/, in if_exp(cond, if_true, if_false, expr_repr)
     25   return _tf_if_exp(cond, if_true, if_false, expr_repr)
     26 else:
---> 27   return _py_if_exp(cond, if_true, if_false)

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/operators/, in _py_if_exp(cond, if_true, if_false)
     51 def _py_if_exp(cond, if_true, if_false):
---> 52   return if_true() if cond else if_false()

File ~tmp/, in outer_factory.<locals>.inner_factory.<locals>.tf__preprocess_data.<locals>.<lambda>()
      9 retval_ = ag__.UndefinedReturnValue()
     10 inputs = {'images': ag__.ld(images), 'labels': ag__.ld(labels)}
---> 11 outputs = ag__.if_exp(ag__.ld(augmentation) != None, lambda: ag__.converted_call(ag__.ld(augmentation), (ag__.ld(inputs),), None, fscope), lambda: ag__.ld(inputs), 'augmentation != None')
     12 try:
     13     do_return = True

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/, in converted_call(f, args, kwargs, caller_fn_scope, options)
    329 if conversion.is_in_allowlist_cache(f, options):
    330   logging.log(2, 'Allowlisted %s: from cache', f)
--> 331   return _call_unconverted(f, args, kwargs, options, False)
    333 if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    334   logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)

File ~usr/local/lib/python3.11/dist-packages/tensorflow/python/autograph/impl/, in _call_unconverted(f, args, kwargs, options, update_cache)
    458 if kwargs is not None:
    459   return f(*args, **kwargs)
--> 460 return f(*args)

File ~usr/local/lib/python3.11/dist-packages/keras/src/utils/, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File ~tmp/, in outer_factory.<locals>.inner_factory.<locals>.tf__call(self, inputs)
     31     nonlocal do_return, retval_
     32     raise ag__.converted_call(ag__.ld(ValueError), (f'Image augmentation layers are expecting inputs to be rank 3 (HWC) or 4D (NHWC) tensors. Got shape: {ag__.ld(images).shape}',), None, fscope)
---> 33 ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
     34 return fscope.ret(retval_, do_return)

File ~tmp/, in outer_factory.<locals>.inner_factory.<locals>.tf__call.<locals>.if_body()
     23 try:
     24     do_return = True
---> 25     retval_ = ag__.converted_call(ag__.ld(self)._format_output, (ag__.converted_call(ag__.ld(self)._batch_augment, (ag__.ld(inputs),), None, fscope), ag__.ld(metadata)), None, fscope)
     26 except:
     27     do_return = False

File ~tmp/, in outer_factory.<locals>.inner_factory.<locals>.tf___batch_augment(self, inputs)
     33 ag__.if_stmt(ag__.ld(bounding_boxes) is not None, if_body, else_body, get_state, set_state, ('inputs[BOUNDING_BOXES]', 'ori_bbox_info'), 2)
     34 inputs_for_trivial_augment_single_input = {'inputs': ag__.ld(inputs), 'transformations': ag__.ld(transformations)}
---> 35 result = ag__.converted_call(ag__.ld(tf).map_fn, (ag__.ld(self).trivial_augment_single_input, ag__.ld(inputs_for_trivial_augment_single_input)), dict(fn_output_signature=ag__.converted_call(ag__.ld(augmentation_utils).compute_signature, (ag__.ld(inputs), ag__.ld(self).compute_dtype), None, fscope)), fscope)
     36 bounding_boxes = ag__.converted_call(ag__.ld(result).get, (ag__.ld(BOUNDING_BOXES), None), None, fscope)
     38 def get_state_2():

File ~tmp/, in outer_factory.<locals>.inner_factory.<locals>.tf__trivial_augment_single_input(self, inputs)
     24 idx = ag__.Undefined('idx')
     25 ag__.for_stmt(ag__.converted_call(ag__.ld(enumerate), (ag__.ld(self).aug_layers,), None, fscope), None, loop_body, get_state, set_state, (), {'iterate_names': '(idx, layer)'})
---> 26 result = ag__.converted_call(ag__.ld(tf).switch_case, (ag__.ld(random_indice),), dict(branch_fns=ag__.ld(branch_fns)), fscope)
     28 def get_state_1():
     29     return (ag__.ldu(lambda: result[BOUNDING_BOXES], 'result[BOUNDING_BOXES]'),)

File ~tmp/, in outer_factory.<locals>.inner_factory.<locals>.tf__call(self, inputs)
     31     nonlocal do_return, retval_
     32     raise ag__.converted_call(ag__.ld(ValueError), (f'Image augmentation layers are expecting inputs to be rank 3 (HWC) or 4D (NHWC) tensors. Got shape: {ag__.ld(images).shape}',), None, fscope)
---> 33 ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
     34 return fscope.ret(retval_, do_return)

File ~tmp/, in outer_factory.<locals>.inner_factory.<locals>.tf__call.<locals>.if_body()
     23 try:
     24     do_return = True
---> 25     retval_ = ag__.converted_call(ag__.ld(self)._format_output, (ag__.converted_call(ag__.ld(self)._batch_augment, (ag__.ld(inputs),), None, fscope), ag__.ld(metadata)), None, fscope)
     26 except:
     27     do_return = False

File ~tmp/, in outer_factory.<locals>.inner_factory.<locals>.tf___batch_augment(self, inputs)
     34     images = ag__.converted_call(ag__.ld(self).augment_images, (ag__.ld(images),), dict(transformations=ag__.ld(transformations), bounding_boxes=ag__.ld(bounding_boxes), labels=ag__.ld(labels)), fscope)
     35 inputs_for_raggeds = ag__.Undefined('inputs_for_raggeds')
---> 36 ag__.if_stmt(ag__.and_(lambda: ag__.converted_call(ag__.ld(isinstance), (ag__.ld(images), ag__.ld(tf).RaggedTensor), None, fscope), lambda: ag__.not_(ag__.ld(self).force_no_unwrap_ragged_image_call)), if_body, else_body, get_state, set_state, ('images',), 1)
     38 def get_state_1():
     39     return (images,)

File ~tmp/, in outer_factory.<locals>.inner_factory.<locals>.tf___batch_augment.<locals>.else_body()
     32 def else_body():
     33     nonlocal images
---> 34     images = ag__.converted_call(ag__.ld(self).augment_images, (ag__.ld(images),), dict(transformations=ag__.ld(transformations), bounding_boxes=ag__.ld(bounding_boxes), labels=ag__.ld(labels)), fscope)

File ~tmp/, in outer_factory.<locals>.inner_factory.<locals>.tf__augment_images(self, images, transformations, **kwargs)
     13 scales = 255.0 / (ag__.ld(highs) - ag__.ld(lows))
     14 eq_idxs = ag__.converted_call(ag__.ld(tf).math.is_inf, (ag__.ld(scales),), None, fscope)
---> 15 lows = ag__.converted_call(ag__.ld(tf).where, (ag__.ld(eq_idxs), 0.0, ag__.ld(lows)), None, fscope)
     16 scales = ag__.converted_call(ag__.ld(tf).where, (ag__.ld(eq_idxs), 1.0, ag__.ld(scales)), None, fscope)
     17 images = ag__.converted_call(ag__.ld(tf).clip_by_value, ((ag__.ld(images) - ag__.ld(lows)) * ag__.ld(scales), 0, 255), None, fscope)

TypeError: in user code:

    File "/tmp/ipykernel_11/", line 15, in None  *
        lambda x, y: preprocess_data(x, y, augmentation=augmentation)
    File "/tmp/ipykernel_11/", line 11, in preprocess_data  *
        outputs = augmentation(inputs) if augmentation != None else inputs
    File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/", line 70, in error_handler  **
        raise e.with_traceback(filtered_tb) from None
    File "/tmp/", line 33, in tf__call
        ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
    File "/tmp/", line 28, in if_body
    File "/tmp/", line 35, in tf___batch_augment
        result = ag__.converted_call(ag__.ld(tf).map_fn, (ag__.ld(self).trivial_augment_single_input, ag__.ld(inputs_for_trivial_augment_single_input)), dict(fn_output_signature=ag__.converted_call(ag__.ld(augmentation_utils).compute_signature, (ag__.ld(inputs), ag__.ld(self).compute_dtype), None, fscope)), fscope)
    File "/tmp/", line 26, in tf__trivial_augment_single_input
        result = ag__.converted_call(ag__.ld(tf).switch_case, (ag__.ld(random_indice),), dict(branch_fns=ag__.ld(branch_fns)), fscope)
    File "/tmp/", line 33, in tf__call
        ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
    File "/tmp/", line 28, in if_body
    File "/tmp/", line 36, in tf___batch_augment
        ag__.if_stmt(ag__.and_(lambda: ag__.converted_call(ag__.ld(isinstance), (ag__.ld(images), ag__.ld(tf).RaggedTensor), None, fscope), lambda: ag__.not_(ag__.ld(self).force_no_unwrap_ragged_image_call)), if_body, else_body, get_state, set_state, ('images',), 1)
    File "/tmp/", line 34, in else_body
        images = ag__.converted_call(ag__.ld(self).augment_images, (ag__.ld(images),), dict(transformations=ag__.ld(transformations), bounding_boxes=ag__.ld(bounding_boxes), labels=ag__.ld(labels)), fscope)
    File "/tmp/", line 15, in tf__augment_images
        lows = ag__.converted_call(ag__.ld(tf).where, (ag__.ld(eq_idxs), 0.0, ag__.ld(lows)), None, fscope)

    TypeError: Exception encountered when calling layer 'trivial_augment' (type TrivialAugmentWide).
    in user code:
        File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/base/", line 613, in call  *
            if images.shape.rank == 3 or images.shape.rank == 4:
        File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/augmentation/auto/", line 281, in _batch_augment  *
            result = tf.map_fn(
        File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/augmentation/auto/", line 316, in trivial_augment_single_input  *
            result = tf.switch_case(random_indice, branch_fns=branch_fns)
        File "/usr/local/lib/python3.11/dist-packages/keras/src/utils/", line 70, in error_handler
            raise e.with_traceback(filtered_tb) from None
        File "/tmp/", line 33, in tf__call
            ag__.if_stmt(ag__.or_(lambda: ag__.ld(images).shape.rank == 3, lambda: ag__.ld(images).shape.rank == 4), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
        File "/tmp/", line 28, in if_body
        File "/tmp/", line 36, in tf___batch_augment
            ag__.if_stmt(ag__.and_(lambda: ag__.converted_call(ag__.ld(isinstance), (ag__.ld(images), ag__.ld(tf).RaggedTensor), None, fscope), lambda: ag__.not_(ag__.ld(self).force_no_unwrap_ragged_image_call)), if_body, else_body, get_state, set_state, ('images',), 1)
        File "/tmp/", line 34, in else_body
            images = ag__.converted_call(ag__.ld(self).augment_images, (ag__.ld(images),), dict(transformations=ag__.ld(transformations), bounding_boxes=ag__.ld(bounding_boxes), labels=ag__.ld(labels)), fscope)
        File "/tmp/", line 15, in tf__augment_images
            lows = ag__.converted_call(ag__.ld(tf).where, (ag__.ld(eq_idxs), 0.0, ag__.ld(lows)), None, fscope)
        TypeError: Exception encountered when calling layer 'trivial_augment' (type AutoContrast).
        in user code:
            File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/base/", line 614, in call  *
                return self._format_output(self._batch_augment(inputs), metadata)
            File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/base/", line 416, in _batch_augment  *
                images = self.augment_images(
            File "/usr/local/lib/python3.11/dist-packages/keras_aug/layers/preprocessing/intensity/", line 50, in augment_images  *
                lows = tf.where(eq_idxs, 0.0, lows)
            TypeError: Input 'e' of 'SelectV2' Op has type float16 that does not match type float32 of argument 't'.
        Call arguments received by layer 'trivial_augment' (type AutoContrast):
          • inputs={'images': 'tf.Tensor(shape=(1, 32, 32, 3), dtype=float16)', 'labels': 'tf.Tensor(shape=(1, 10), dtype=float16)'}
    Call arguments received by layer 'trivial_augment' (type TrivialAugmentWide):
      • inputs={'images': 'tf.Tensor(shape=(None, 32, 32, 3), dtype=float16)', 'labels': 'tf.Tensor(shape=(None, 10), dtype=float16)'}
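The innermost frame points at `lows = tf.where(eq_idxs, 0.0, lows)` in AutoContrast, which TrivialAugmentWide calls as one of its branches. Under the `mixed_float16` policy `lows` is float16, while the Python literal `0.0` is converted to a float32 tensor, so the `SelectV2` op rejects the mismatched operands. A likely fix, untested here, is to build the constants with the input dtype, e.g. `tf.where(eq_idxs, tf.zeros_like(lows), lows)` and `tf.where(eq_idxs, tf.ones_like(scales), scales)`. As a possible workaround, Keras layers read the global dtype policy when they are constructed, so building the augmentation layers before calling `set_global_policy` may keep them in float32 (also untested). The NumPy sketch below mirrors the logic of the failing lines with every constant created in the input dtype; `auto_contrast` is a hypothetical helper, not the KerasAug layer.

```python
import numpy as np


def auto_contrast(images: np.ndarray) -> np.ndarray:
    """Stretch every channel of every image to [0, 255], keeping the input dtype.

    Assumes `images` is a floating array of shape (N, H, W, C) in [0, 255].
    """
    dtype = images.dtype
    lows = images.min(axis=(1, 2), keepdims=True)
    highs = images.max(axis=(1, 2), keepdims=True)
    with np.errstate(divide="ignore", over="ignore"):
        # The 0-d constant has the input dtype, so float16 stays float16.
        scales = np.asarray(255.0, dtype=dtype) / (highs - lows)
    # Constant channels give inf scales: leave them unchanged instead.
    bad = np.isinf(scales)
    lows = np.where(bad, np.zeros_like(lows), lows)  # not a bare 0.0
    scales = np.where(bad, np.ones_like(scales), scales)  # not a bare 1.0
    out = np.clip((images - lows) * scales, 0, 255)
    return out.astype(dtype, copy=False)
```

NumPy would not raise on a bare `0.0` here, but using `zeros_like`/`ones_like` is the pattern that carries over to `tf.where`, where operand dtypes must match exactly.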

Investigate the possibility of using `jax` as backend (ongoing)

Can KerasAug use JAX as its backend?


Pseudo code:

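import flax.linen as nn


class BaseLayer(nn.Module):
    def __call__(self, inputs):
        # Placeholder: a concrete layer would transform inputs["images"], etc.
        outputs = inputs
        return outputs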
  • Support dense array first

Support ragged `segmentation_masks`

  • RandomAffine
  • RandomCropAndResize
  • RandomCrop
  • RandomFlip
  • RandomResize (polish api)
  • RandomRotate
  • RandomZoomAndCrop (add support)
  • CenterCrop (polish api)
  • PadIfNeeded (polish api)
  • Resize (polish api)
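A ragged batch can be pictured as a list of per-sample arrays whose heights and widths differ. The NumPy sketch below uses a hypothetical helper, `random_flip_ragged` (not part of KerasAug), to show the invariant every geometry layer in the list has to preserve for ragged `segmentation_masks`: each mask receives exactly the same random transform as its image.

```python
import numpy as np


def random_flip_ragged(images, masks, rng, rate=0.5):
    """Horizontally flip each (image, mask) pair with one shared random decision.

    images: list of (H_i, W_i, C) arrays; masks: list of (H_i, W_i) arrays.
    """
    out_images, out_masks = [], []
    for image, mask in zip(images, masks):
        if image.shape[:2] != mask.shape[:2]:
            raise ValueError("each mask must match its image's height and width")
        if rng.random() < rate:  # one draw per sample, shared by image and mask
            image, mask = image[:, ::-1], mask[:, ::-1]
        out_images.append(image)
        out_masks.append(mask)
    return out_images, out_masks
```

Drawing the decision once per sample, rather than separately for the image and the mask, is what keeps the labels aligned with the pixels.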

Add benchmark


  • KerasCV
  • Torchvision
  • Albumentations
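A benchmark against these libraries needs a common measurement. The sketch below is a hypothetical, framework-agnostic harness (`images_per_second` is not an existing API) that times any batch-in, batch-out callable after a few warm-up calls, so one-off costs such as graph tracing or compilation are excluded.

```python
import time
from typing import Callable

import numpy as np


def images_per_second(augment: Callable[[np.ndarray], np.ndarray],
                      batch: np.ndarray, warmup: int = 2, repeats: int = 10) -> float:
    """Average throughput of `augment` on `batch`, in images per second."""
    for _ in range(warmup):  # exclude tracing/compilation and cache warm-up
        augment(batch)
    start = time.perf_counter()
    for _ in range(repeats):
        augment(batch)
    elapsed = max(time.perf_counter() - start, 1e-12)
    return repeats * len(batch) / elapsed
```

For frameworks with asynchronous execution, the wrapped callable should force its result (for example `.numpy()` on a TensorFlow tensor or `.block_until_ready()` on a JAX array) so the timing covers the actual computation.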

Tests should be run in batch mode
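One bug that only batch-mode tests catch is a reduction that accidentally spans the batch axis. A minimal sketch of such a check, assuming a deterministic op that maps an (N, H, W, C) array to an array of the same shape: augmenting the whole batch must give the same result as augmenting each image on its own. `check_batch_matches_single` and `per_image_standardize` are hypothetical helpers for illustration.

```python
import numpy as np


def check_batch_matches_single(augment, images, atol=1e-6):
    """Assert that augment(batch) equals augmenting each image separately."""
    batched = augment(images)
    singles = np.stack([augment(img[None])[0] for img in images])
    np.testing.assert_allclose(batched, singles, atol=atol)
    return batched


def per_image_standardize(images):
    # Statistics over (H, W, C) only, never over the batch axis.
    mean = images.mean(axis=(1, 2, 3), keepdims=True)
    std = images.std(axis=(1, 2, 3), keepdims=True)
    return (images - mean) / np.maximum(std, 1e-6)
```

Random layers can be checked the same way once their random parameters are fixed, for example by seeding them identically for the batched and per-image calls.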

Fix docstrings


  • add RandomResize


  • RandomFlip
    References: add KerasCV
  • RandomZoomAndCrop
    References: add KerasCV
  • Resize
    References: add KerasCV
  • RandomColorJitter
    Fix References
  • RandomGaussianBlur
    References: add KerasCV
  • RandomSharpness
    References: Tensorflow Model -> Tensorflow Model augment
  • MixUp
    References: mixup: Beyond Empirical Risk Minimization -> MixUp
  • RandomGridMask:
    References: GridMask repo -> GridMask Official Repo

Refactor layers



  1. preprocessing
    CenterCrop, PadIfNeeded, Resize, ResizeAndCrop, ResizeAndPad, ResizeByLongestSide, ResizeBySmallestSide, AutoContrast, Equalize, Grayscale, Invert, Normalize, Rescale
  2. augmentation
    • auto
    • geometry
    • intensity
    • mix
    • regularization
    • utility
  3. base



Unit tests for all augmentation layers

This issue keeps track of the unit tests still to be written:



  • CenterCrop
  • PadIfNeeded
  • RandomAffine
  • RandomCropAndResize
  • ResizeAndPad
  • ResizeByLongestSide
  • ResizeBySmallestSide


  • Normalize
  • RandomBlur
  • RandomBrightnessContrast
  • RandomGamma
  • RandomColorJitter
  • RandomHSV
  • RandomJpegQuality


  • MosaicYOLOV8
  • MixUp


  • ChannelDropout


  • RandomApply

Release v0.5.7

Bump the version to v0.5.7 to support TensorFlow 2.13 and Keras 2.13.
