Skip to content

Commit ad966a0

Browse files
committed
Add customized C APIs for tf.net correponding to v2.11.
1 parent a3e2c69 commit ad966a0

File tree

5 files changed

+239
-3
lines changed

5 files changed

+239
-3
lines changed

tensorflow/c/c_api.cc

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,14 @@ limitations under the License.
5151
#include "tensorflow/core/framework/partial_tensor_shape.h"
5252
#include "tensorflow/core/framework/tensor.h"
5353
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
54+
#include "tensorflow/core/framework/cpp_shape_inference.pb.h"
55+
#include "tensorflow/core/framework/shape_inference.h"
5456
#include "tensorflow/core/framework/tensor_shape.h"
5557
#include "tensorflow/core/framework/tensor_shape.pb.h"
5658
#include "tensorflow/core/framework/types.h"
5759
#include "tensorflow/core/framework/versions.pb.h"
60+
#include "tensorflow/core/framework/full_type.pb.h"
61+
#include "tensorflow/core/framework/attr_value_util.h"
5862
#include "tensorflow/core/graph/graph.h"
5963
#include "tensorflow/core/graph/node_builder.h"
6064
#include "tensorflow/core/graph/validate.h"
@@ -2556,6 +2560,144 @@ void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
25562560
}
25572561
}
25582562

2563+
// TF Customized C APIs for Tensorflow.NET --------------------------
2564+
2565+
void TFC_AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
2566+
mutex_lock l(graph->mu);
2567+
graph->graph.AddControlEdge(&input->node, &op->node);
2568+
tensorflow::RecordMutation(graph, *op, "adding control input");
2569+
}
2570+
2571+
void TFC_SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
2572+
TF_Buffer* attr_value_proto, TF_Status* status) {
2573+
using tensorflow::RecordMutation;
2574+
tensorflow::AttrValue attr_val;
2575+
if (!attr_val.ParseFromArray(attr_value_proto->data,
2576+
attr_value_proto->length)) {
2577+
status->status =
2578+
tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
2579+
return;
2580+
}
2581+
2582+
mutex_lock l(graph->mu);
2583+
op->node.AddAttr(attr_name, attr_val);
2584+
tensorflow::RecordMutation(graph, *op, "setting attribute");
2585+
}
2586+
2587+
void TFC_ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
2588+
TF_Status* status) {
2589+
mutex_lock l(graph->mu);
2590+
op->node.ClearAttr(attr_name);
2591+
tensorflow::RecordMutation(graph, *op, "clearing attribute");
2592+
}
2593+
2594+
void TFC_SetFullType(TF_Graph* graph, TF_Operation* op,
2595+
const tensorflow::FullTypeDef& full_type) {
2596+
mutex_lock l(graph->mu);
2597+
*op->node.mutable_def()->mutable_experimental_type() = full_type;
2598+
tensorflow::RecordMutation(graph, *op, "setting fulltype");
2599+
}
2600+
2601+
void TFC_SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
2602+
mutex_lock l(graph->mu);
2603+
op->node.set_requested_device(device);
2604+
tensorflow::RecordMutation(graph, *op, "setting device");
2605+
}
2606+
2607+
void TFC_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
2608+
TF_Status* status) {
2609+
TF_UpdateEdge(graph, new_src, dst, status);
2610+
}
2611+
2612+
void TFC_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
2613+
mutex_lock l(graph->mu);
2614+
std::vector<const tensorflow::Edge*> control_edges;
2615+
for (const tensorflow::Edge* edge : op->node.in_edges()) {
2616+
if (!edge->IsControlEdge()) continue;
2617+
control_edges.push_back(edge);
2618+
}
2619+
for (const tensorflow::Edge* edge : control_edges) {
2620+
graph->graph.RemoveControlEdge(edge);
2621+
}
2622+
}
2623+
2624+
void TFC_SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
2625+
mutex_lock l(graph->mu);
2626+
graph->refiner.set_require_shape_inference_fns(require);
2627+
}
2628+
2629+
void TFC_ExtendSession(TF_Session* session, TF_Status* status) {
2630+
ExtendSessionGraphHelper(session, status);
2631+
session->extend_before_run = false;
2632+
}
2633+
2634+
const char* TFC_GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
2635+
Node* node = &output.oper->node;
2636+
tensorflow::CppShapeInferenceResult::HandleData handle_data;
2637+
handle_data.set_is_set(true);
2638+
{
2639+
mutex_lock l(graph->mu);
2640+
tensorflow::shape_inference::InferenceContext* ic =
2641+
graph->refiner.GetContext(node);
2642+
CHECK(ic != nullptr);
2643+
CHECK_LT(output.index, ic->num_outputs());
2644+
const auto* shapes_and_types =
2645+
ic->output_handle_shapes_and_types(output.index);
2646+
if (shapes_and_types == nullptr) return "";
2647+
2648+
for (const auto& p : *shapes_and_types) {
2649+
auto* out_shape_and_type = handle_data.add_shape_and_type();
2650+
ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
2651+
out_shape_and_type->set_dtype(p.dtype);
2652+
*out_shape_and_type->mutable_type() = p.type;
2653+
}
2654+
}
2655+
string result;
2656+
handle_data.SerializeToString(&result);
2657+
return result.c_str();
2658+
}
2659+
2660+
void TFC_SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
2661+
size_t proto_len, TF_Status* status) {
2662+
tensorflow::CppShapeInferenceResult::HandleData handle_data;
2663+
if (!handle_data.ParseFromArray(proto, proto_len)) {
2664+
status->status = tensorflow::errors::InvalidArgument(
2665+
"Couldn't deserialize HandleData proto");
2666+
return;
2667+
}
2668+
DCHECK(handle_data.is_set());
2669+
2670+
tensorflow::mutex_lock l(graph->mu);
2671+
tensorflow::shape_inference::InferenceContext* ic =
2672+
graph->refiner.GetContext(&output.oper->node);
2673+
2674+
std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
2675+
for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
2676+
tensorflow::shape_inference::ShapeHandle shape;
2677+
status->status =
2678+
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
2679+
if (TF_GetCode(status) != TF_OK) return;
2680+
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
2681+
shape_and_type_proto.type());
2682+
}
2683+
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
2684+
}
2685+
2686+
void TFC_AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
2687+
TF_Status* status) {
2688+
mutex_lock l(graph->mu);
2689+
status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
2690+
new_src.index, &dst->node);
2691+
if (TF_GetCode(status) == TF_OK) {
2692+
// This modification only updates the destination node for
2693+
// the purposes of running this graph in a session. Thus, we don't
2694+
// record the source node as being modified.
2695+
tensorflow::RecordMutation(graph, *dst, "adding input tensor");
2696+
}
2697+
}
2698+
2699+
// -------------------------------------------------------------------
2700+
25592701
// TF_Server functions ----------------------------------------------
25602702

25612703
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

tensorflow/c/c_api.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License.
2525
#include "tensorflow/c/tf_status.h"
2626
#include "tensorflow/c/tf_tensor.h"
2727
#include "tensorflow/c/tf_tstring.h"
28+
#include "tensorflow/core/framework/full_type.pb.h"
2829

2930
// --------------------------------------------------------------------------
3031
// C API for TensorFlow.
@@ -1595,6 +1596,48 @@ TF_CAPI_EXPORT extern void TF_RegisterLogListener(
15951596
TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin(
15961597
const char* plugin_filename, TF_Status* status);
15971598

1599+
// Customized C APIs for Tensorflow.NET ---------------------------
1600+
1601+
TF_CAPI_EXPORT extern void TFC_AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
1602+
1603+
TF_CAPI_EXPORT extern void TFC_SetAttr(TF_Graph* graph, TF_Operation* op,
1604+
const char* attr_name,
1605+
TF_Buffer* attr_value_proto,
1606+
TF_Status* status);
1607+
1608+
TF_CAPI_EXPORT extern void TFC_ClearAttr(TF_Graph* graph, TF_Operation* op,
1609+
const char* attr_name,
1610+
TF_Status* status);
1611+
1612+
TF_CAPI_EXPORT extern void TFC_SetFullType(TF_Graph* graph, TF_Operation* op,
1613+
const tensorflow::FullTypeDef& full_type);
1614+
1615+
TF_CAPI_EXPORT extern void TFC_SetRequestedDevice(TF_Graph* graph,
1616+
TF_Operation* op,
1617+
const char* device);
1618+
1619+
TF_CAPI_EXPORT extern void TFC_UpdateEdge(TF_Graph* graph, TF_Output new_src,
1620+
TF_Input dst, TF_Status* status);
1621+
1622+
TF_CAPI_EXPORT extern void TFC_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op);
1623+
1624+
TF_CAPI_EXPORT extern void TFC_SetRequireShapeInferenceFns(TF_Graph* graph, bool require);
1625+
1626+
TF_CAPI_EXPORT extern void TFC_ExtendSession(TF_Session* session, TF_Status* status);
1627+
1628+
TF_CAPI_EXPORT extern const char* TFC_GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
1629+
1630+
TF_CAPI_EXPORT extern void TFC_SetHandleShapeAndType(TF_Graph* graph,
1631+
TF_Output output,
1632+
const void* proto,
1633+
size_t proto_len,
1634+
TF_Status* status);
1635+
1636+
void TFC_AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
1637+
TF_Status* status);
1638+
1639+
// ----------------------------------------------------------------
1640+
15981641
#ifdef __cplusplus
15991642
} /* end extern "C" */
16001643
#endif

tensorflow/c/version_script.lds

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ VERS_1.0 {
33
global:
44
*TF_*;
55
*TFE_*;
6+
*TFC_*;
67

78
# Hide everything else.
89
local:

tensorflow/core/framework/BUILD

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ package(
2626
default_visibility = [
2727
"//tensorflow/core:__subpackages__",
2828
"//tensorflow/security/fuzzing:__subpackages__",
29-
# TODO(pedaveeraiah): to be removed when summary.proto.h deps moves to TSL
29+
# TODO(pedaveeraiah): to be removed when summary.proto.h deps moves to TSL
3030
"//tensorflow/tsl/lib:__subpackages__",
3131
],
3232
licenses = ["notice"],
@@ -107,6 +107,7 @@ exports_files(
107107
srcs = [
108108
"allocation_description.proto",
109109
"api_def.proto",
110+
"cpp_shape_inference.proto",
110111
"attr_value.proto",
111112
"cost_graph.proto",
112113
"dataset_metadata.proto",
@@ -1406,7 +1407,7 @@ cc_library(
14061407
# protos from the same package, so we can build the protos here and then
14071408
# link them from core:protos_all without circular dependencies.
14081409

1409-
# Generate the C++ sources for some of the protos.
1410+
#Generate the C++ sources for some of the protos.
14101411
tf_generate_proto_text_sources(
14111412
name = "attr_value_proto_text",
14121413
srcs = [
@@ -1697,6 +1698,18 @@ tf_proto_library(
16971698
],
16981699
)
16991700

1701+
tf_proto_library(
1702+
name = "cpp_shape_inference_proto",
1703+
srcs = ["cpp_shape_inference.proto"],
1704+
cc_api_version = 2,
1705+
make_default_target_header_only = True,
1706+
protodeps = [
1707+
":full_type_proto",
1708+
":tensor_shape_proto",
1709+
":types_proto",
1710+
],
1711+
)
1712+
17001713
tf_proto_library(
17011714
name = "variable_proto",
17021715
srcs = ["variable.proto"],
@@ -1764,7 +1777,7 @@ tf_proto_library(
17641777
# ":function_proto",
17651778
# ],
17661779
# )
1767-
# copybara:uncomment_end
1780+
#copybara : uncomment_end
17681781

17691782
tf_proto_library(
17701783
name = "summary_proto",
@@ -1812,6 +1825,7 @@ tf_proto_library(
18121825
protodeps = [
18131826
":allocation_description_proto",
18141827
":api_def_proto",
1828+
":cpp_shape_inference_proto",
18151829
":attr_value_proto",
18161830
":cost_graph_proto",
18171831
":dataset_metadata_proto",
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
syntax = "proto3";
2+
3+
package tensorflow;
4+
5+
import "tensorflow/core/framework/full_type.proto";
6+
import "tensorflow/core/framework/tensor_shape.proto";
7+
import "tensorflow/core/framework/types.proto";
8+
9+
option cc_enable_arenas = true;
10+
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/python/framework/cpp_shape_inference_go_proto";
11+
12+
message CppShapeInferenceResult {
13+
message HandleShapeAndType {
14+
reserved 3;
15+
16+
TensorShapeProto shape = 1;
17+
DataType dtype = 2;
18+
FullTypeDef type = 4;
19+
}
20+
message HandleData {
21+
bool is_set = 1;
22+
23+
// Only valid if <is_set>.
24+
repeated HandleShapeAndType shape_and_type = 2;
25+
}
26+
TensorShapeProto shape = 1;
27+
28+
reserved 2; // was handle_shape
29+
reserved 3; // was handle_dtype
30+
HandleData handle_data = 4;
31+
}
32+
33+
message CppShapeInferenceInputsNeeded {
34+
repeated int32 input_tensors_needed = 1;
35+
repeated int32 input_tensors_as_shapes_needed = 2;
36+
}

0 commit comments

Comments
 (0)