
LODTensor-related Variable type deduction with inheritance #3396

@Superjomn

Description


Supporting LODTensor requires Variable to deduce inheritance relationships between types, with the following goal:

  • Premise: LODTensor inherits from Tensor.
  • If a variable has already been set to a LODTensor, both variable.Get<Tensor>() and variable.GetMutable<Tensor>() return the contents of that existing LODTensor.

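This requires recording the inheritance relationships between the relevant types.

Here is a simple scheme:

  • Build a global map that records a PrefixHash for each registered type.
  • The PrefixHash of a base class is a prefix of the PrefixHash of every derived class.
  • To decide whether two types are related by inheritance, Variable looks up both PrefixHashes and checks whether one is a prefix of the other.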
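To make the prefix rule concrete, here is a small sketch with hand-assigned hashes for a hypothetical hierarchy: BaseType at the root, Tensor and a hypothetical sibling `OtherType` as its children, and LODTensor under Tensor. The scheme appends raw bytes (0, 1, 2, ...) to the parent's hash; printable digits are used here only for readability, and `IsPrefixDescendant` is an illustrative helper, not part of the proposal.

```cpp
#include <cassert>
#include <string>

// `child` descends from `father` iff father's hash is a strict prefix of
// child's hash (illustrative helper, not part of the proposal).
bool IsPrefixDescendant(const std::string &child, const std::string &father) {
  return father.size() < child.size() &&
         child.compare(0, father.size(), father) == 0;
}

// Hand-assigned hashes; digits stand in for the raw bytes the scheme appends.
const std::string kBaseType = "0";    // root of all types
const std::string kTensor = "00";     // 1st child of BaseType
const std::string kOtherType = "01";  // 2nd child of BaseType (hypothetical)
const std::string kLODTensor = "000"; // 1st child of Tensor
```

LODTensor ("000") is a descendant of both Tensor ("00") and BaseType ("0"), but not of the sibling OtherType ("01"), and a type is never its own strict descendant.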
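#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <typeindex>
#include <typeinfo>

struct PrefixHash {
  std::string hash;
  size_t num_children;

  // True if `other` is a strict prefix of this hash, i.e. the type owning
  // this hash is a (direct or indirect) descendant of the type owning `other`.
  bool IsDescendentOf(const PrefixHash &other) const {
    return other.hash.size() < hash.size() &&
           hash.compare(0, other.hash.size(), other.hash) == 0;
  }

  // Make this hash a new child of `other`: copy the parent's hash and append
  // one byte that numbers the child (so at most 256 children per parent).
  void SetDescendentOf(PrefixHash &other) {
    hash = other.hash;
    hash.push_back(static_cast<char>(other.num_children++));
    num_children = 0;  // a freshly registered type has no children yet
  }
};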

// base of all types
struct BaseType {};

struct TypeDescendentDeducer {

  TypeDescendentDeducer() {
    // insert base type
    PrefixHash hash{"0", 0};
    type_map[typeid(BaseType)] = hash;
  }

  static TypeDescendentDeducer &Global() {
    static TypeDescendentDeducer x;
    return x;
  }

  template <typename Father, typename Child> void Register() {
    const auto &father_type = std::type_index(typeid(Father));
    const auto &child_type = std::type_index(typeid(Child));

    auto father_it = type_map.find(father_type);
    // insert father record
    if (father_it == type_map.end()) {
      auto &base_type = type_map[typeid(BaseType)];
      PrefixHash hash;
      hash.SetDescendentOf(base_type);
      type_map[std::type_index(father_type)] = hash;
    }

    PrefixHash child_hash;
    child_hash.SetDescendentOf(type_map[father_type]);
    type_map[child_type] = child_hash;
  }

  // True if `child` is a registered (direct or indirect) descendant of T.
  // Takes a std::type_index so that Variable can pass its stored type_.
  template <typename T> bool IsDescendentOf(const std::type_index &child) {
    auto child_it = type_map.find(child);
    auto father_it = type_map.find(typeid(T));
    if (child_it == type_map.end() || father_it == type_map.end())
      return false;
    return child_it->second.IsDescendentOf(father_it->second);
  }

  std::map<std::type_index, PrefixHash> type_map;
};

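// Registers `child` as a direct subclass of `father` during static
// initialization. Use once per (father, child) pair at namespace scope; the
// generated names avoid double underscores, which are reserved in C++.
#define REGISTER_TYPE_DESCENDENCE(father, child)                               \
  struct type_descendence_##father##_##child##_registrar {                     \
    type_descendence_##father##_##child##_registrar() {                        \
      TypeDescendentDeducer::Global().Register<father, child>();               \
    }                                                                          \
  };                                                                           \
  static type_descendence_##father##_##child##_registrar                       \
      type_descendence_##father##_##child##_registrar_instance;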

Usage:

// register LODTensor as a derived class of Tensor
class Tensor {};
class LODTensor : public Tensor {};

REGISTER_TYPE_DESCENDENCE(Tensor, LODTensor)

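class Variable {

public:
  // ...
  // True if the stored type is exactly T or a registered descendant of T,
  // e.g. IsType<Tensor>() holds when the variable stores a LODTensor.
  template <typename T> bool IsType() const {
    if (std::type_index(typeid(T)) == type_) {
      return true;
    }
    return TypeDescendentDeducer::Global().IsDescendentOf<T>(type_);
  }

private:
  // std::type_index has no default constructor; typeid(void) marks "unset".
  std::type_index type_{typeid(void)};
};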
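A condensed, self-contained sketch of how Get<Tensor>() and GetMutable<Tensor>() could reuse an existing LODTensor through this deduction. Assumptions, all hypothetical: the type-erased `Placeholder`/`Holder` storage, the `numel` field, and registration through a static initializer are illustrative rather than Paddle's actual Variable, and the deducer keeps plain strings instead of PrefixHash objects. Handing out the stored LODTensor as a Tensor through a `void*` relies on the Tensor base sitting at the same address, which the `static_assert` checks via `std::is_standard_layout`.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>

struct BaseType {};  // root of the registered hierarchy

// Condensed deducer: each type's hash is its parent's hash plus one byte,
// so an ancestor's hash is a strict prefix of every descendant's hash.
class TypeDescendentDeducer {
 public:
  static TypeDescendentDeducer &Global() {
    static TypeDescendentDeducer x;
    return x;
  }
  template <typename Father, typename Child> void Register() {
    std::type_index father(typeid(Father)), child(typeid(Child));
    if (hash_.count(father) == 0) hash_[father] = NewChildOf(typeid(BaseType));
    hash_[child] = NewChildOf(father);
  }
  template <typename T> bool IsDescendentOf(std::type_index child) const {
    auto c = hash_.find(child);
    auto f = hash_.find(typeid(T));
    if (c == hash_.end() || f == hash_.end()) return false;
    const std::string &ch = c->second, &fh = f->second;
    return fh.size() < ch.size() && ch.compare(0, fh.size(), fh) == 0;
  }

 private:
  TypeDescendentDeducer() { hash_[typeid(BaseType)] = "0"; }
  std::string NewChildOf(std::type_index parent) {
    return hash_[parent] + static_cast<char>(num_children_[parent]++);
  }
  std::map<std::type_index, std::string> hash_;
  std::map<std::type_index, unsigned char> num_children_;  // <= 256 children
};

// Hypothetical minimal tensors: all data lives in Tensor, so LODTensor stays
// standard-layout and its Tensor base shares the object's address.
struct Tensor { int numel = 0; };
struct LODTensor : Tensor {};
static_assert(std::is_standard_layout<LODTensor>::value,
              "the void* -> Tensor* cast in Variable relies on this");

static const bool kLODTensorRegistered =
    (TypeDescendentDeducer::Global().Register<Tensor, LODTensor>(), true);

class Variable {
 public:
  // Exact match, or the stored type is a registered descendant of T.
  template <typename T> bool IsType() const {
    if (!holder_) return false;
    std::type_index stored = holder_->Type();
    return stored == std::type_index(typeid(T)) ||
           TypeDescendentDeducer::Global().IsDescendentOf<T>(stored);
  }
  template <typename T> const T &Get() const {
    assert(IsType<T>());
    return *static_cast<const T *>(holder_->Ptr());
  }
  // Reuses the stored object if it is a T (or a descendant of T);
  // otherwise replaces it with a freshly constructed T.
  template <typename T> T *GetMutable() {
    if (!IsType<T>()) holder_.reset(new Holder<T>());
    return static_cast<T *>(holder_->Ptr());
  }

 private:
  struct Placeholder {
    virtual ~Placeholder() = default;
    virtual std::type_index Type() const = 0;
    virtual void *Ptr() = 0;
  };
  template <typename T> struct Holder : Placeholder {
    std::type_index Type() const override { return typeid(T); }
    void *Ptr() override { return &obj; }
    T obj;
  };
  std::unique_ptr<Placeholder> holder_;
};
```

A variable created with GetMutable<LODTensor>() then reports IsType<Tensor>() as true, and GetMutable<Tensor>() returns the same object instead of replacing it. A variable holding a plain Tensor is not treated as a LODTensor, because the prefix check only runs from descendant to ancestor.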
