Why does this function call raise a TypeError?
TypeError Traceback (most recent call last)
in
1 # Initialize training
----> 2 iterative_process = tff.learning.build_federated_averaging_process(model_fn)
3 state = iterative_process.initialize()
~/anaconda3/envs/text_pred/lib/python3.6/site-packages/tensorflow_federated/python/learning/federated_averaging.py in build_federated_averaging_process(model_fn, server_optimizer_fn, client_weight_fn, stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
162 return optimizer_utils.build_model_delta_optimizer_process(
163 model_fn, client_fed_avg, server_optimizer_fn,
--> 164 stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
~/anaconda3/envs/text_pred/lib/python3.6/site-packages/tensorflow_federated/python/learning/framework/optimizer_utils.py in build_model_delta_optimizer_process(model_fn, model_to_client_delta_fn, server_optimizer_fn, stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
359 server_state_type = tf_init_fn.type_signature.result
360
--> 361 @tff.tf_computation(tf_dataset_type, server_state_type.model)
362 def tf_client_delta(tf_dataset, initial_model_weights):
363 """Performs client local model optimization.
~/anaconda3/envs/text_pred/lib/python3.6/site-packages/tensorflow_federated/python/core/impl/computation_wrapper.py in <lambda>(fn)
413 args = (args,)
414 arg_type = computation_types.to_type(args[0])
--> 415 return lambda fn: _wrap(fn, arg_type, self._wrapper_fn)
~/anaconda3/envs/text_pred/lib/python3.6/site-packages/tensorflow_federated/python/core/impl/computation_wrapper.py in _wrap(fn, parameter_type, wrapper_fn)
101
102 # Either we have a concrete parameter type, or this is no-arg function.
--> 103 concrete_fn = wrapper_fn(fn, parameter_type, unpack=None)
104 py_typecheck.check_type(concrete_fn, function_utils.ConcreteFunction,
105 'value returned by the wrapper')
~/anaconda3/envs/text_pred/lib/python3.6/site-packages/tensorflow_federated/python/core/impl/computation_wrapper_instances.py in _tf_wrapper_fn(failed resolving arguments)
42 ctx_stack = context_stack_impl.context_stack
43 comp_pb, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
---> 44 target_fn, parameter_type, ctx_stack)
45 return computation_impl.ComputationImpl(comp_pb, ctx_stack, extra_type_spec)
46
~/anaconda3/envs/text_pred/lib/python3.6/site-packages/tensorflow_federated/python/core/impl/tensorflow_serialization.py in serialize_py_fn_as_tf_computation(target, parameter_type, context_stack)
266 context = tf_computation_context.TensorFlowComputationContext(graph)
267 with context_stack.install(context):
--> 268 result = target(*args)
269
270 # TODO(b/122081673): This needs to change for TF 2.0. We may also
~/anaconda3/envs/text_pred/lib/python3.6/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py in <lambda>(arg)
582 except NameError:
583 raise AssertionError('Args to be bound must be in scope.')
--> 584 return lambda arg: _unpack_and_call(fn, arg_types, kwarg_types, arg)
585 else:
586 # An interceptor function that verifies the actual parameter before it
~/anaconda3/envs/text_pred/lib/python3.6/site-packages/tensorflow_federated/python/core/impl/utils/function_utils.py in _unpack_and_call(fn, arg_types, kwarg_types, arg)
553 for idx, expected_type in enumerate(arg_types):
554 element_value = arg[idx]
--> 555 actual_type = type_utils.infer_type(element_value)
556 if not type_utils.is_assignable_from(expected_type, actual_type):
557 raise TypeError(
~/anaconda3/envs/text_pred/lib/python3.6/site-packages/tensorflow_federated/python/core/impl/type_utils.py in infer_type(arg)
63 return computation_types.SequenceType(
64 tf_dtypes_and_shapes_to_type(
---> 65 tf.compat.v1.data.get_output_types(arg),
66 tf.compat.v1.data.get_output_shapes(arg)))
67 elif isinstance(arg, anonymous_tuple.AnonymousTuple):
~/anaconda3/envs/text_pred/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in get_legacy_output_types(dataset_or_iterator)
2037 of an element of this dataset.
2038 """
-> 2039 return get_structure(dataset_or_iterator)._to_legacy_output_types() # pylint: disable=protected-access
2040
2041
~/anaconda3/envs/text_pred/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in get_structure(dataset_or_iterator)
2001 pass
2002 raise TypeError("`dataset_or_iterator` must be a Dataset or Iterator object, "
-> 2003 "but got %s." % type(dataset_or_iterator))
2004
2005
TypeError: `dataset_or_iterator` must be a Dataset or Iterator object, but got <class 'tensorflow.python.data.ops.dataset_ops._VariantDataset'>.