@@ -50,10 +50,12 @@ limitations under the License.
5050#include " tensorflow/core/framework/partial_tensor_shape.h"
5151#include " tensorflow/core/framework/tensor.h"
5252#include " tensorflow/core/framework/tensor.pb.h" // NOLINT
53+ #include " tensorflow/core/framework/cpp_shape_inference.pb.h"
5354#include " tensorflow/core/framework/tensor_shape.h"
5455#include " tensorflow/core/framework/tensor_shape.pb.h"
5556#include " tensorflow/core/framework/types.h"
5657#include " tensorflow/core/framework/versions.pb.h"
58+ #include " tensorflow/core/framework/shape_inference.h"
5759#include " tensorflow/core/graph/graph.h"
5860#include " tensorflow/core/graph/node_builder.h"
5961#include " tensorflow/core/graph/validate.h"
@@ -71,6 +73,8 @@ limitations under the License.
7173#include " tensorflow/core/platform/types.h"
7274#include " tensorflow/core/public/session.h"
7375#include " tensorflow/core/public/version.h"
76+ #include " tensorflow/core/framework/full_type.pb.h"
77+ #include " tensorflow/core/framework/attr_value_util.h"
7478
7579// The implementation below is at the top level instead of the
7680// brain namespace because we are defining 'extern "C"' functions.
@@ -2614,6 +2618,144 @@ void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
26142618 }
26152619}
26162620
2621+ // TF Customized C APIs for Tensorflow.NET --------------------------
2622+
2623+ void TFC_AddControlInput (TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
2624+ mutex_lock l (graph->mu );
2625+ graph->graph .AddControlEdge (&input->node , &op->node );
2626+ tensorflow::RecordMutation (graph, *op, " adding control input" );
2627+ }
2628+
2629+ void TFC_SetAttr (TF_Graph* graph, TF_Operation* op, const char * attr_name,
2630+ TF_Buffer* attr_value_proto, TF_Status* status) {
2631+ using tensorflow::RecordMutation;
2632+ tensorflow::AttrValue attr_val;
2633+ if (!attr_val.ParseFromArray (attr_value_proto->data ,
2634+ attr_value_proto->length )) {
2635+ status->status =
2636+ tensorflow::errors::InvalidArgument (" Invalid AttrValue proto" );
2637+ return ;
2638+ }
2639+
2640+ mutex_lock l (graph->mu );
2641+ op->node .AddAttr (attr_name, attr_val);
2642+ tensorflow::RecordMutation (graph, *op, " setting attribute" );
2643+ }
2644+
2645+ void TFC_ClearAttr (TF_Graph* graph, TF_Operation* op, const char * attr_name,
2646+ TF_Status* status) {
2647+ mutex_lock l (graph->mu );
2648+ op->node .ClearAttr (attr_name);
2649+ tensorflow::RecordMutation (graph, *op, " clearing attribute" );
2650+ }
2651+
2652+ void TFC_SetFullType (TF_Graph* graph, TF_Operation* op,
2653+ const tensorflow::FullTypeDef& full_type) {
2654+ mutex_lock l (graph->mu );
2655+ *op->node .mutable_def ()->mutable_experimental_type () = full_type;
2656+ tensorflow::RecordMutation (graph, *op, " setting fulltype" );
2657+ }
2658+
2659+ void TFC_SetRequestedDevice (TF_Graph* graph, TF_Operation* op, const char * device) {
2660+ mutex_lock l (graph->mu );
2661+ op->node .set_requested_device (device);
2662+ tensorflow::RecordMutation (graph, *op, " setting device" );
2663+ }
2664+
2665+ void TFC_UpdateEdge (TF_Graph* graph, TF_Output new_src, TF_Input dst,
2666+ TF_Status* status) {
2667+ TF_UpdateEdge (graph, new_src, dst, status);
2668+ }
2669+
2670+ void TFC_RemoveAllControlInputs (TF_Graph* graph, TF_Operation* op) {
2671+ mutex_lock l (graph->mu );
2672+ std::vector<const tensorflow::Edge*> control_edges;
2673+ for (const tensorflow::Edge* edge : op->node .in_edges ()) {
2674+ if (!edge->IsControlEdge ()) continue ;
2675+ control_edges.push_back (edge);
2676+ }
2677+ for (const tensorflow::Edge* edge : control_edges) {
2678+ graph->graph .RemoveControlEdge (edge);
2679+ }
2680+ }
2681+
2682+ void TFC_SetRequireShapeInferenceFns (TF_Graph* graph, bool require) {
2683+ mutex_lock l (graph->mu );
2684+ graph->refiner .set_require_shape_inference_fns (require);
2685+ }
2686+
2687+ void TFC_ExtendSession (TF_Session* session, TF_Status* status) {
2688+ ExtendSessionGraphHelper (session, status);
2689+ session->extend_before_run = false ;
2690+ }
2691+
2692+ const char * TFC_GetHandleShapeAndType (TF_Graph* graph, TF_Output output) {
2693+ Node* node = &output.oper ->node ;
2694+ tensorflow::CppShapeInferenceResult::HandleData handle_data;
2695+ handle_data.set_is_set (true );
2696+ {
2697+ mutex_lock l (graph->mu );
2698+ tensorflow::shape_inference::InferenceContext* ic =
2699+ graph->refiner .GetContext (node);
2700+ CHECK (ic != nullptr );
2701+ CHECK_LT (output.index , ic->num_outputs ());
2702+ const auto * shapes_and_types =
2703+ ic->output_handle_shapes_and_types (output.index );
2704+ if (shapes_and_types == nullptr ) return " " ;
2705+
2706+ for (const auto & p : *shapes_and_types) {
2707+ auto * out_shape_and_type = handle_data.add_shape_and_type ();
2708+ ic->ShapeHandleToProto (p.shape , out_shape_and_type->mutable_shape ());
2709+ out_shape_and_type->set_dtype (p.dtype );
2710+ *out_shape_and_type->mutable_type () = p.type ;
2711+ }
2712+ }
2713+ string result;
2714+ handle_data.SerializeToString (&result);
2715+ return result.c_str ();
2716+ }
2717+
2718+ void TFC_SetHandleShapeAndType (TF_Graph* graph, TF_Output output, const void * proto,
2719+ size_t proto_len, TF_Status* status) {
2720+ tensorflow::CppShapeInferenceResult::HandleData handle_data;
2721+ if (!handle_data.ParseFromArray (proto, proto_len)) {
2722+ status->status = tensorflow::errors::InvalidArgument (
2723+ " Couldn't deserialize HandleData proto" );
2724+ return ;
2725+ }
2726+ DCHECK (handle_data.is_set ());
2727+
2728+ tensorflow::mutex_lock l (graph->mu );
2729+ tensorflow::shape_inference::InferenceContext* ic =
2730+ graph->refiner .GetContext (&output.oper ->node );
2731+
2732+ std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
2733+ for (const auto & shape_and_type_proto : handle_data.shape_and_type ()) {
2734+ tensorflow::shape_inference::ShapeHandle shape;
2735+ status->status =
2736+ ic->MakeShapeFromShapeProto (shape_and_type_proto.shape (), &shape);
2737+ if (TF_GetCode (status) != TF_OK) return ;
2738+ shapes_and_types.emplace_back (shape, shape_and_type_proto.dtype (),
2739+ shape_and_type_proto.type ());
2740+ }
2741+ ic->set_output_handle_shapes_and_types (output.index , shapes_and_types);
2742+ }
2743+
2744+ void TFC_AddWhileInputHack (TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
2745+ TF_Status* status) {
2746+ mutex_lock l (graph->mu );
2747+ status->status = graph->graph .AddWhileInputHack (&new_src.oper ->node ,
2748+ new_src.index , &dst->node );
2749+ if (TF_GetCode (status) == TF_OK) {
2750+ // This modification only updates the destination node for
2751+ // the purposes of running this graph in a session. Thus, we don't
2752+ // record the source node as being modified.
2753+ tensorflow::RecordMutation (graph, *dst, " adding input tensor" );
2754+ }
2755+ }
2756+
2757+ // -------------------------------------------------------------------
2758+
26172759// TF_Server functions ----------------------------------------------
26182760
26192761#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
0 commit comments