Skip to content

Commit

Permalink
Add support for arrays of scalars as parameters.
Browse files Browse the repository at this point in the history
Reviewed By: jfix71

Differential Revision: D49376799

fbshipit-source-id: 7b74737ff520c84b01e8cea0ad35f3e5fcad2273
  • Loading branch information
Jay Banerjee authored and facebook-github-bot committed Sep 26, 2023
1 parent b06a601 commit 0a53b9f
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions include/glow/Graph/FXIRUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <folly/Conv.h>
#include <folly/String.h>
#include <folly/dynamic.h>
#include <type_traits>

namespace glow {

Expand Down Expand Up @@ -68,21 +69,29 @@ std::vector<T> toIntegerArray(std::string intArrayStr,
}

template <class T>
std::vector<T> toIntegerArray(const folly::dynamic &dyn,
const uint32_t &length = 0) {
std::vector<T> toArray(const folly::dynamic &dyn, const uint32_t &length = 0) {
static_assert(std::is_floating_point<T>() || std::is_integral<T>(),
"Currently only support float and int types");
auto isType = [](const auto &a) {
return std::is_floating_point<T>() ? a.isDouble() : a.isInt();
};
auto getType = [](const auto &a) {
return std::is_floating_point<T>() ? a.getDouble() : a.getInt();
};
std::vector<T> vec;
if (dyn.isInt()) {
vec.emplace_back(dyn.getInt());
if (isType(dyn)) {
vec.emplace_back(getType(dyn));
} else if (dyn.isArray()) {
for (auto &e : dyn) {
if (e.isInt()) {
vec.emplace_back(e.getInt());
if (isType(e)) {
vec.emplace_back(getType(e));
} else {
LOG(FATAL) << "Non-integer vector unhandled";
LOG(FATAL) << "Mismatch between specified type for toArray and found "
"type in the vector in json";
}
}
} else {
LOG(FATAL) << "Only supporting integer/vec<integer>";
LOG(FATAL) << "Expected single element or vector of specified isArray type";
}

CHECK(!vec.empty()) << "Empty dimension size!";
Expand Down Expand Up @@ -167,21 +176,21 @@ template <class T> std::vector<T> getConvStride(const folly::dynamic &node) {
const auto &inputs = getNodeKwargs(node);
CHECK(inputs.find("stride") != inputs.items().end())
<< "stride field doesn't exist in Conv Inputs " << node;
return toIntegerArray<uint32_t>(inputs["stride"], 2);
return toArray<uint32_t>(inputs["stride"], 2);
}

template <class T> std::vector<T> getConvPads(const folly::dynamic &node) {
const auto &inputs = getNodeKwargs(node);
CHECK(inputs.find("padding") != inputs.items().end())
<< "padding field doesn't exist in Conv Inputs " << node;
return toIntegerArray<uint32_t>(inputs["padding"], 2);
return toArray<uint32_t>(inputs["padding"], 2);
}

template <class T> std::vector<T> getConvKernels(const folly::dynamic &node) {
const auto &inputs = getNodeKwargs(node);
CHECK(inputs.find("kernel_size") != inputs.items().end())
<< "kernel_size field doesn't exist in Conv Inputs " << node;
return toIntegerArray<uint32_t>(inputs["kernel_size"], 2);
return toArray<uint32_t>(inputs["kernel_size"], 2);
}

template <class T>
Expand All @@ -197,14 +206,14 @@ std::vector<T> getTransposeShuffle(const folly::dynamic &node) {
const auto &inputs = getNodeKwargs(node);
CHECK(inputs.find("permutation") != inputs.items().end())
<< "field transposed_dims doesn't exist in Conv Inputs " << node;
return toIntegerArray<uint32_t>(inputs["permutation"], 2);
return toArray<uint32_t>(inputs["permutation"], 2);
}

template <class T> std::vector<T> getMeanDims(const folly::dynamic &node) {
auto &inputs = getNodeKwargs(node);
CHECK(inputs.find("dim") != inputs.items().end())
<< "field dims doesn't exist in Mean Inputs " << node;
return toIntegerArray<uint32_t>(inputs["dim"]);
return toArray<uint32_t>(inputs["dim"]);
}

/// Search \p storageNodeNameToDest and \p nonStorageNodeNameToDest for
Expand Down

0 comments on commit 0a53b9f

Please sign in to comment.