Tensorflow是如何註冊和調用C++ New Op的

寒假回來想起來之前挖的坑,但好像並沒有特別好的主題可以寫,更不用說實習招聘近在眼前了,於是打算先擴展一下之前在知乎上的兩個回答。

本文主要介紹動態鏈接的C++ New Op是如何被註冊進來,又如何被Python代碼調用的,也算是給自己的一個交代,畢竟本人一直不太喜歡high-level的API。本文大致分為三個模塊:註冊Ops,註冊Kernel,調用Ops。

  • Ops的註冊過程

先說一下OpRegistrationData這個東西,這個類的對象由全局註冊器Registry負責分配,作用簡單來說就是保存OpDef和OpShapeInferenceFn函數,前者保存有Op的各種具體信息,會由OpDefBuilder在最後的解析參數時(成員函數Finalize)放進來,後者在SetShapeFn傳進來(由Wrapper轉發),所謂註冊就是將op name和OpRegistrationData關聯起來,具體來說放進hashmap。

mutable std::unordered_map<string, const OpRegistrationData*> registry_;

還得先說一下OpDefBuilder這個類,OpDefBuilder會負責接收Op的各種屬性和參數定義(就是REGISTER_OP時指定的,見下),最後統一解析(注意只是解析並不保證合法性之類的)並轉給OpRegistrationData這個類(包括ShapeFn)。

我們自己註冊op都會通過下面這個宏定義:

REGISTER_OP("YourOp") .Attr("T: {float}") .Input("logits: T") .Input("Labels: T") .Output("loss: T") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { c->set_output(0, c->MakeShape({1})); return Status::OK(); });

細節都在REGISTER_OP那個宏定義裡面,簡化如下:

static OpDefBuilderReceiver register_op = OpDefBuilderWrapper(YourOp)

其中OpDefBuilderWrapper內部保存有一個OpDefBuilder成員變數,你所有對REGISTER_OP宏連續調用的操作包括op的名字最後都會一股腦轉發給前面那個唯一的OpDefBuilder變數,而OpDefBuilderReceiver則拿過來BuilderWrapper交給一個負責管理所有Op註冊的Registry,Registry暴露Register方法給op們註冊,把官方的example摘過來示意一下:

//Example registration: OpRegistry::Global()->Register( [](OpRegistrationData* op_reg_data)->Status { // Populate *op_reg_data here. return Status::OK(); });

(先解釋下:OpRegistry::Global()簡單的單例模式,返回OpRegistry的全局唯一實例,當然這裡必須要感謝下新標準對static線程安全的保證。)

在那個lambda裡面你就可以做任何想做的事情了,比如就像OpDefBuilderReceiver一樣把BuilderWrapper拿進來,然後把wrapper去掉取出OpDefBuilder,看到上面lambda裡面那個op_reg_data沒,對這就是之前提到的將解析好參數及shapefn傳到OpRegistrationData里,最後Register拿到op的name和OpRegistrationData組成pair放進hashmap完成註冊,同時會做一些合法性檢查的事情。如下:

OpRegistry::Global()->Register( [wrapper](OpRegistrationData* op_reg_data) -> Status { return wrapper.builder().Finalize(op_reg_data); });

其實到這裡真正的註冊並不一定會發生,下面會詳細說。

  • Kernel的註冊過程

與Ops的註冊類似,也是有一個叫作KernelDefBuilder的wrapper,內部保存有KernelDef的一個指針,用於設置各種屬性,最後調用Build函數可返回該指針並清空Builder,Kernel的註冊主要是通過下面這個宏來實現的:

REGISTER_KERNEL_BUILDER( Name("PsRoiAlignGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"), PSROIAlignGradOp<GPUDevice, float>);

其中Name是KernelDefBuilder的一個派生類,Name("KernelName")會首先創建一個KernelDefBuilder同時設置設置kernel名稱,每次調用這種setter函數就會返回Builder自身從而支持連續調用,然後是設置Device,最後添加值float到屬性T中。

class Name : public KernelDefBuilder {public: // For system kernels, we ignore selective registration and // unconditionally register the kernel. explicit Name(const char* op) : KernelDefBuilder(op) {}};

REGISTER_KERNEL_BUILDER宏裡面就是一些trick,實質是創建一個名稱唯一的類型為OpKernelRegistrar的全局靜態變數,如果你有興趣可以看一下:

#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) constexpr bool should_register_##ctr##__flag = SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__); static ::tensorflow::kernel_factory::OpKernelRegistrar registrar__body__##ctr##__object( should_register_##ctr##__flag ? ::tensorflow::register_kernel::kernel_builder.Build() : nullptr, #__VA_ARGS__, [](::tensorflow::OpKernelConstruction* context) -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); });

OpKernelRegistrar靜態變數的構造需要三個參數,如下所示,第一個是KernelDef,第二個是定義Kernel的類名,第三個是創建kernel對象的函數,其實後面就可以知道這三個參數都會被包裝到KernelRegistration這個結構體里,然後作為Kernel註冊表的值。因此這個宏會首先調用KernelDefBuilder的Build函數獲得對應的KernelDef;然後獲取用於創建這個Kernel的C++類名稱(這個類是繼承自OpKernel的);最後包裝一個factory函數用來接收傳進來的OpKernelConstruction*,創建對應的Kernel類對象,並返回其指針。

class OpKernelRegistrar {public: typedef OpKernel* (*Factory)(OpKernelConstruction*); OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name, Factory factory) { if (kernel_def != nullptr) { InitInternal(kernel_def, kernel_class_name, factory); } }};

這裡是InitInternal的細節

void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name, Factory factory) { // See comments in register_kernel::Name in header for info on _no_register. if (kernel_def->op() != "_no_register") { const string key = Key(kernel_def->op(), DeviceType(kernel_def->device_type()), kernel_def->label()); GlobalKernelRegistryTyped()->insert(std::make_pair( key, KernelRegistration(*kernel_def, kernel_class_name, factory))); } delete kernel_def;}

可以看到OpKernelRegistrar這個類主要是負責根據傳進來的KernelDef和KernelFactory,首先依據一定規則生成一個適當的key,並插入到一個全局唯一的Kernel註冊表裡,註冊表當然是一個map但是值得注意的是它是multimap因此支持一個鍵對應多個kernel副本。

typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;

  • OpKernel的創建與調用

如果你還記得的話,前面還有一個全局的OpRegistry,這樣根據NodeDef里的Op名稱就可以獲得Op對應的信息,再結合設備類型也就可以獲得Kernel對應的信息了,而NodeDef是在Python創建Operation之前創建的,可以看這裡create_op,後面會提到調用這個函數的地方。

然後就可以根據一個NodeDef和當前的設備類型在運行時創建一個OpKernel了,每個被創建的OpKernel都會被自動地管理生命周期。在Device類中會有一個OpSegment對象,OpSegment會管理一個sessions中用到的kernel,根據情況來決定是創建新的還是復用之前的OpKernel,具體來說是有兩個嵌套的hashmap,第一個將session handle映射到一個KernelMap,然後在KernelMap就可以去查找是否有對應Op名的OpKernel,如果沒有就調用一個create_fn函數進行創建。

那麼問題來了,這背後的原動力在哪?事實上Session在第一次為某個Node創建Executor的時候這一切就發生了(後面會再說到Executor的):DirectSession::GetOrCreateExecutors,更直接地可以看查找失敗後第一次創建Executor的地方,代碼片段如下:

LocalExecutorParams params;params.device = device;params.function_library = lib;auto opseg = device->op_segment();params.create_kernel = [this, lib, opseg](const NodeDef& ndef, OpKernel** kernel) { // We do not share the kernel via the OpSegment if the node is // stateless, or a function. // NOTE(mrry): We must not share function kernels (implemented // using `CallOp`) between subgraphs, because `CallOp::handle_` // is tied to a particular subgraph. Even if the function itself // is stateful, the `CallOp` that invokes it is not. if (!lib->IsStateful(ndef.op()) || lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) { return lib->CreateKernel(ndef, kernel); } auto create_fn = [lib, &ndef](OpKernel** kernel) { return lib->CreateKernel(ndef, kernel); }; // Kernels created for subgraph nodes need to be cached. On // cache miss, create_fn() is invoked to create a kernel based // on the function library here + global op registry. return opseg->FindOrCreate(session_handle_, ndef.name(), kernel, create_fn);};

可以看到取出OpSegment,構造create_fn並調用FindOrCreate的過程。其中create_fn內部調用的FunctionLibraryRuntime的CreateKernel函數可以看這裡:FunctionLibraryRuntimeImpl::CreateKernel,再往下CreateNonCachedKernel:

Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, const NodeDef& ndef, int graph_def_version, OpKernel** kernel) { const auto device_type = DeviceType(device->attributes().device_type()); auto allocator = device->GetAllocator(AllocatorAttributes()); return CreateOpKernel(device_type, device, allocator, flib, ndef, graph_def_version, kernel);}

看到了CreateOpKernel的調用,這下總算回到了我們最開始的地方CreateOpKernel:

Status CreateOpKernel(DeviceType device_type, DeviceBase* device, Allocator* allocator, FunctionLibraryRuntime* flib, const NodeDef& node_def, int graph_def_version, OpKernel** kernel)

這個核心函數主要是做一下以下幾件事情:根據node_def取出op名,去查OpRegistry,並與node_def的信息進行校驗,比如介面是否一致,node_def中是否包含所有op_def中的信息等,然後根據device_type和op名去查KernelRegistry獲取KernelRegistration,就是map中的值,包含之前提到的三項。接著是確定輸入輸出類型及其存儲位置,最後是創建一個OpKernelConstruction對象,並傳給Kernel的factory函數函數,這就到了用戶自己寫的函數這邊了:

// Everything needed for OpKernel construction.OpKernelConstruction context( device_type, device, allocator, &node_def, op_def, flib, inputs, input_memory_types, outputs, output_memory_types, graph_def_version, &s);*kernel = (*registration->factory)(&context);

Kernel創建完了,那麼它什麼時候被執行呢?前面說到第一次創建executor的時候會創建OpKernel,其實每次Session調用Run的時候最終也是轉到executor這邊來執行的,包括根據當前的運行時環境創建OpKernelContext以及OpKernel::Compute的調用:

// Synchronous computes.OpKernelContext ctx(&params, item.num_outputs);nodestats::SetOpStart(stats);device->Compute(CHECK_NOTNULL(op_kernel), &ctx);nodestats::SetOpEnd(stats);

其中device->Compute這一步通過查看基類的實現就大概能知道所有細節了Device::Compute:

// Performs the actual compute function.//// Subclasses may override this function if they wish to perform// some initialization before each compute.virtual void Compute(OpKernel* op_kernel, OpKernelContext* context) { op_kernel->Compute(context);}

可以發現,我們寫的Compute方法在這裡就被調用了。至此故事好像可以告一段落了,不過說了半天好像一直在C++這邊啊,那Python代碼怎麼調用的呢?

  • 註冊Ops和Kernel後傳

根據上面REGISTER_KERNEL_BUILDER所展開的兩段程序很容易就判斷出如果動態庫被載入進來的話,Kernel就會自動完成註冊,這跟Ops的註冊基本是一樣的,不同之處在於動態鏈接進來的Ops會在載入庫之前設置延遲註冊的標記,並添加一個Watcher,然後手動調用註冊,這主要是為了通過Watcher獲取註冊過程中從OpRegistrationData(就是註冊表的值)中取出的OpDef,這一點可以在後面的LoadLibrary中看到。這個過程很重要,通過獲得的OpDef組成的OpList並序列化後,Python端就可以解析出這些OpDef,同時調用C++這邊利用這些OpDef生成對應的ApiDef,二者結合就可以動態生成定義這個Op的Python代碼,然後返回到Python端執行這些代碼,注意這些代碼的執行並不包括創建Op並添加到Graph這個過程,只包括定義相關代碼段的函數,下面是從Python端load_op_library一直到生成Python代碼的過程:load_op_library->GetPythonWrappers->GetPythonOps->GetPythonOp->GenPythonOp::Code()。還有從OpList生成ApiDef的地方ApiDefMap::ApiDefMap(const OpList& op_list)。如果你有興趣的話可以去看一下我之前寫的一個Op自動生成的代碼,我附在了本文最後,生成代碼中的apply_op就是添加Op到Graph的代碼,可以看這裡apply_op,這個函數的最後面就是前面提到的調用Graph的create_op。

下面是LoadLibrary的代碼段,可以對照一下:

Status LoadLibrary(const char* library_filename, void** result, const void** buf, size_t* len) { static mutex mu; static std::unordered_map<string, Library> loaded_libs; Env* env = Env::Default(); Library library; std::unordered_set<string> seen_op_names; { mutex_lock lock(mu); if (loaded_libs.find(library_filename) != loaded_libs.end()) { library = loaded_libs[library_filename]; } else { Status s = OpRegistry::Global()->ProcessRegistrations(); if (!s.ok()) { return s; } TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher( [&library, &seen_op_names](const Status& s, const OpDef& opdef) -> Status { if (errors::IsAlreadyExists(s)) { if (seen_op_names.find(opdef.name()) == seen_op_names.end()) { // Over writing a registration of an op not in this custom op // library. Treat this as not an error. return Status::OK(); } } if (s.ok()) { *library.op_list.add_op() = opdef; seen_op_names.insert(opdef.name()); } return s; })); OpRegistry::Global()->DeferRegistrations(); s = env->LoadLibrary(library_filename, &library.handle); if (s.ok()) { s = OpRegistry::Global()->ProcessRegistrations(); } if (!s.ok()) { OpRegistry::Global()->ClearDeferredRegistrations(); TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr)); return s; } TF_RETURN_IF_ERROR(OpRegistry::Global()->SetWatcher(nullptr)); loaded_libs[library_filename] = library; } } string str; library.op_list.SerializeToString(&str); char* str_buf = reinterpret_cast<char*>(port::Malloc(str.length())); memcpy(str_buf, str.data(), str.length()); *buf = str_buf; *len = str.length(); *result = library.handle; return Status::OK();}

自動生成的Python代碼,這裡是對應的C++ Op:

"""Python wrappers around TensorFlow ops.This file is MACHINE GENERATED! Do not edit."""import collections as _collectionsfrom tensorflow.core.framework import op_def_pb2 as _op_def_pb2# Needed to trigger the call to _set_call_cpp_shape_fn.from tensorflow.python.framework import common_shapes as _common_shapesfrom tensorflow.python.framework import op_def_registry as _op_def_registryfrom tensorflow.python.framework import ops as _opsfrom tensorflow.python.framework import op_def_library as _op_def_libraryfrom tensorflow.python.util.tf_export import tf_export_ps_roi_align_outputs = ["pooled_features", "pooled_index"]_PsRoiAlignOutput = _collections.namedtuple( "PsRoiAlign", _ps_roi_align_outputs)@tf_export(ps_roi_align)def ps_roi_align(inputs, rois, grid_dim_width, grid_dim_height, name=None): r""" PsRoiAlign is a new PsRoiPooling method without align problems. The input rois to be pooled must in format [center_y, center_x, h, w] and each element must be in range [0, 1.]. The caller must make sure that all rois is valid (has a intersect region (one pixel at least) with the window [0.5, 0.5, 1., 1.]). Args: inputs: A `Tensor`. Must be one of the following types: `float32`. rois: A `Tensor`. Must have the same type as `inputs`. grid_dim_width: An `int`. grid_dim_height: An `int`. name: A name for the operation (optional). Returns: A tuple of `Tensor` objects (pooled_features, pooled_index). pooled_features: A `Tensor`. Has the same type as `inputs`. pooled_index: A `Tensor` of type `int32`. """ _result = _op_def_lib.apply_op("PsRoiAlign", inputs=inputs, rois=rois, grid_dim_width=grid_dim_width, grid_dim_height=grid_dim_height, name=name) _result = _PsRoiAlignOutput._make(_result) return _result_ops.RegisterShape("PsRoiAlign")(None)@tf_export(ps_roi_align_grad)def ps_roi_align_grad(inputs, rois, pooled_features_grad, pooled_index, grid_dim_width, grid_dim_height, name=None): r""" PsRoiAlignGrad is the Gradient op of PsRoiAlign. The input rois to be pooled must in format [center_y, center_x, h, w] and each element must be in range [0, 1.]. The caller must make sure that all rois is valid (has a intersect region (one pixel at least) with the window [0.5, 0.5, 1., 1.]). Args: inputs: A `Tensor`. Must be one of the following types: `float32`. rois: A `Tensor`. Must have the same type as `inputs`. pooled_features_grad: A `Tensor`. Must have the same type as `inputs`. pooled_index: A `Tensor` of type `int32`. grid_dim_width: An `int`. grid_dim_height: An `int`. name: A name for the operation (optional). Returns: A `Tensor`. Has the same type as `inputs`. """ _result = _op_def_lib.apply_op("PsRoiAlignGrad", inputs=inputs, rois=rois, pooled_features_grad=pooled_features_grad, pooled_index=pooled_index, grid_dim_width=grid_dim_width, grid_dim_height=grid_dim_height, name=name) return _result_ops.RegisterShape("PsRoiAlignGrad")(None)def _InitOpDefLibrary(op_list_proto_bytes): op_list = _op_def_pb2.OpList() op_list.ParseFromString(op_list_proto_bytes) _op_def_registry.register_op_list(op_list) op_def_lib = _op_def_library.OpDefLibrary() op_def_lib.add_op_list(op_list) return op_def_lib# op {# name: "PsRoiAlign"# input_arg {# name: "inputs"# type_attr: "T"# }# input_arg {# name: "rois"# type_attr: "T"# }# output_arg {# name: "pooled_features"# type_attr: "T"# }# output_arg {# name: "pooled_index"# type: DT_INT32# }# attr {# name: "T"# type: "type"# allowed_values {# list {# type: DT_FLOAT# }# }# }# attr {# name: "grid_dim_width"# type: "int"# }# attr {# name: "grid_dim_height"# type: "int"# }# }# op {# name: "PsRoiAlignGrad"# input_arg {# name: "inputs"# type_attr: "T"# }# input_arg {# name: "rois"# type_attr: "T"# }# input_arg {# name: "pooled_features_grad"# type_attr: "T"# }# input_arg {# name: "pooled_index"# type: DT_INT32# }# output_arg {# name: "grad_output"# type_attr: "T"# }# attr {# name: "T"# type: "type"# allowed_values {# list {# type: DT_FLOAT# }# }# }# attr {# name: "grid_dim_width"# type: "int"# }# attr {# name: "grid_dim_height"# type: "int"# }# }_op_def_lib = _InitOpDefLibrary(b"\n\215\001\n\nPsRoiAlign\022\013\n\006inputs\"\001T\022\t\n\004rois\"\001T\032\024\n\017pooled_features\"\001T\032\020\n\014pooled_index\030\003\"\020\n\001T\022\004type:\005\n\0032\001\001\"\025\n\016grid_dim_width\022\003int\"\026\n\017grid_dim_height\022\003int\n\250\001\n\016PsRoiAlignGrad\022\013\n\006inputs\"\001T\022\t\n\004rois\"\001T\022\031\n\024pooled_features_grad\"\001T\022\020\n\014pooled_index\030\003\032\020\n\013grad_output\"\001T\"\020\n\001T\022\004type:\005\n\0032\001\001\"\025\n\016grid_dim_width\022\003int\"\026\n\017grid_dim_height\022\003int")

推薦閱讀:

tensorflow的共享變數,tf.Variable(),tf.get_variable(),tf.Variable_scope(),tf.name_scope()聯繫與區別
Tensorflow on Spark爬坑指南
【博客存檔】TensoFlow之深入理解GoogLeNet
TensorFlow官方教程翻譯:導入數據
在TensorFlow中使用pipeline載入數據

TAG:機器學習 | TensorFlow | 深度學習DeepLearning |