@@ -51,10 +51,14 @@ limitations under the License.
5151#include " tensorflow/core/framework/partial_tensor_shape.h"
5252#include " tensorflow/core/framework/tensor.h"
5353#include " tensorflow/core/framework/tensor.pb.h" // NOLINT
54+ #include " tensorflow/core/framework/cpp_shape_inference.pb.h"
55+ #include " tensorflow/core/framework/shape_inference.h"
5456#include " tensorflow/core/framework/tensor_shape.h"
5557#include " tensorflow/core/framework/tensor_shape.pb.h"
5658#include " tensorflow/core/framework/types.h"
5759#include " tensorflow/core/framework/versions.pb.h"
60+ #include " tensorflow/core/framework/full_type.pb.h"
61+ #include " tensorflow/core/framework/attr_value_util.h"
5862#include " tensorflow/core/graph/graph.h"
5963#include " tensorflow/core/graph/node_builder.h"
6064#include " tensorflow/core/graph/validate.h"
@@ -2556,6 +2560,144 @@ void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
25562560 }
25572561}
25582562
2563+ // TF Customized C APIs for Tensorflow.NET --------------------------
2564+
2565+ void TFC_AddControlInput (TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
2566+ mutex_lock l (graph->mu );
2567+ graph->graph .AddControlEdge (&input->node , &op->node );
2568+ tensorflow::RecordMutation (graph, *op, " adding control input" );
2569+ }
2570+
2571+ void TFC_SetAttr (TF_Graph* graph, TF_Operation* op, const char * attr_name,
2572+ TF_Buffer* attr_value_proto, TF_Status* status) {
2573+ using tensorflow::RecordMutation;
2574+ tensorflow::AttrValue attr_val;
2575+ if (!attr_val.ParseFromArray (attr_value_proto->data ,
2576+ attr_value_proto->length )) {
2577+ status->status =
2578+ tensorflow::errors::InvalidArgument (" Invalid AttrValue proto" );
2579+ return ;
2580+ }
2581+
2582+ mutex_lock l (graph->mu );
2583+ op->node .AddAttr (attr_name, attr_val);
2584+ tensorflow::RecordMutation (graph, *op, " setting attribute" );
2585+ }
2586+
2587+ void TFC_ClearAttr (TF_Graph* graph, TF_Operation* op, const char * attr_name,
2588+ TF_Status* status) {
2589+ mutex_lock l (graph->mu );
2590+ op->node .ClearAttr (attr_name);
2591+ tensorflow::RecordMutation (graph, *op, " clearing attribute" );
2592+ }
2593+
2594+ void TFC_SetFullType (TF_Graph* graph, TF_Operation* op,
2595+ const tensorflow::FullTypeDef& full_type) {
2596+ mutex_lock l (graph->mu );
2597+ *op->node .mutable_def ()->mutable_experimental_type () = full_type;
2598+ tensorflow::RecordMutation (graph, *op, " setting fulltype" );
2599+ }
2600+
2601+ void TFC_SetRequestedDevice (TF_Graph* graph, TF_Operation* op, const char * device) {
2602+ mutex_lock l (graph->mu );
2603+ op->node .set_requested_device (device);
2604+ tensorflow::RecordMutation (graph, *op, " setting device" );
2605+ }
2606+
2607+ void TFC_UpdateEdge (TF_Graph* graph, TF_Output new_src, TF_Input dst,
2608+ TF_Status* status) {
2609+ TF_UpdateEdge (graph, new_src, dst, status);
2610+ }
2611+
2612+ void TFC_RemoveAllControlInputs (TF_Graph* graph, TF_Operation* op) {
2613+ mutex_lock l (graph->mu );
2614+ std::vector<const tensorflow::Edge*> control_edges;
2615+ for (const tensorflow::Edge* edge : op->node .in_edges ()) {
2616+ if (!edge->IsControlEdge ()) continue ;
2617+ control_edges.push_back (edge);
2618+ }
2619+ for (const tensorflow::Edge* edge : control_edges) {
2620+ graph->graph .RemoveControlEdge (edge);
2621+ }
2622+ }
2623+
2624+ void TFC_SetRequireShapeInferenceFns (TF_Graph* graph, bool require) {
2625+ mutex_lock l (graph->mu );
2626+ graph->refiner .set_require_shape_inference_fns (require);
2627+ }
2628+
2629+ void TFC_ExtendSession (TF_Session* session, TF_Status* status) {
2630+ ExtendSessionGraphHelper (session, status);
2631+ session->extend_before_run = false ;
2632+ }
2633+
2634+ const char * TFC_GetHandleShapeAndType (TF_Graph* graph, TF_Output output) {
2635+ Node* node = &output.oper ->node ;
2636+ tensorflow::CppShapeInferenceResult::HandleData handle_data;
2637+ handle_data.set_is_set (true );
2638+ {
2639+ mutex_lock l (graph->mu );
2640+ tensorflow::shape_inference::InferenceContext* ic =
2641+ graph->refiner .GetContext (node);
2642+ CHECK (ic != nullptr );
2643+ CHECK_LT (output.index , ic->num_outputs ());
2644+ const auto * shapes_and_types =
2645+ ic->output_handle_shapes_and_types (output.index );
2646+ if (shapes_and_types == nullptr ) return " " ;
2647+
2648+ for (const auto & p : *shapes_and_types) {
2649+ auto * out_shape_and_type = handle_data.add_shape_and_type ();
2650+ ic->ShapeHandleToProto (p.shape , out_shape_and_type->mutable_shape ());
2651+ out_shape_and_type->set_dtype (p.dtype );
2652+ *out_shape_and_type->mutable_type () = p.type ;
2653+ }
2654+ }
2655+ string result;
2656+ handle_data.SerializeToString (&result);
2657+ return result.c_str ();
2658+ }
2659+
2660+ void TFC_SetHandleShapeAndType (TF_Graph* graph, TF_Output output, const void * proto,
2661+ size_t proto_len, TF_Status* status) {
2662+ tensorflow::CppShapeInferenceResult::HandleData handle_data;
2663+ if (!handle_data.ParseFromArray (proto, proto_len)) {
2664+ status->status = tensorflow::errors::InvalidArgument (
2665+ " Couldn't deserialize HandleData proto" );
2666+ return ;
2667+ }
2668+ DCHECK (handle_data.is_set ());
2669+
2670+ tensorflow::mutex_lock l (graph->mu );
2671+ tensorflow::shape_inference::InferenceContext* ic =
2672+ graph->refiner .GetContext (&output.oper ->node );
2673+
2674+ std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
2675+ for (const auto & shape_and_type_proto : handle_data.shape_and_type ()) {
2676+ tensorflow::shape_inference::ShapeHandle shape;
2677+ status->status =
2678+ ic->MakeShapeFromShapeProto (shape_and_type_proto.shape (), &shape);
2679+ if (TF_GetCode (status) != TF_OK) return ;
2680+ shapes_and_types.emplace_back (shape, shape_and_type_proto.dtype (),
2681+ shape_and_type_proto.type ());
2682+ }
2683+ ic->set_output_handle_shapes_and_types (output.index , shapes_and_types);
2684+ }
2685+
2686+ void TFC_AddWhileInputHack (TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
2687+ TF_Status* status) {
2688+ mutex_lock l (graph->mu );
2689+ status->status = graph->graph .AddWhileInputHack (&new_src.oper ->node ,
2690+ new_src.index , &dst->node );
2691+ if (TF_GetCode (status) == TF_OK) {
2692+ // This modification only updates the destination node for
2693+ // the purposes of running this graph in a session. Thus, we don't
2694+ // record the source node as being modified.
2695+ tensorflow::RecordMutation (graph, *dst, " adding input tensor" );
2696+ }
2697+ }
2698+
2699+ // -------------------------------------------------------------------
2700+
25592701// TF_Server functions ----------------------------------------------
25602702
25612703#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
0 commit comments