PyTorch arrive en version 1.5 avec une API frontale C++ stable

Par:
admin

lun, 27/04/2020 - 15:52

PyTorch est une bibliothèque de machine learning dévelopée par Facebook AI Research. La sortie de la version 1.5 de la bibliothèque vient d'être annoncée.

La nouveauté majeure est que l'API frontale C++, juste qu'ici expérimentale, est désormais stable, et est à parité avec Python. Concrètement, cela signifie les utilisateurs peuvent traduire complètement leurs modèles de l'API Python en API C ++.

Cette version ajoute une nouvelle API, expérimentale cette fois, permettant de lier simultanément des classes C ++ personnalisées à TorchScript et Python. Cela permet aux développeurs d'exposer une classe C ++ et ses méthodes au système de type TorchScript et au système d'exécution de manière à pouvoir instancier et manipuler des objets C ++ arbitraires à partir de TorchScript et Python. Par exemple :

template <class T>
struct MyStackClass : torch::CustomClassHolder {
  std::vector<T> stack_;
  MyStackClass(std::vector<T> init) : stack_(std::move(init)) {}
 
 void push(T x) {
    stack_.push_back(x);
  }

  T pop() {
    auto val = stack_.back();
    stack_.pop_back();
    return val;
  }
};

static auto testStack =
  torch::class_<MyStackClass<std::string>>("myclasses", "MyStackClass")
      .def(torch::init<std::vector<std::string>>())
      .def("push", &MyStackClass<std::string>::push)
      .def("pop", &MyStackClass<std::string>::pop)
      .def("size", [](const c10::intrusive_ptr<MyStackClass>& self) {
       return self->stack_.size();
});

expose une classe utilisable avec Python et TorchScript comme ceci :

@torch.jit.script
def do_stacks(s : torch.classes.myclasses.MyStackClass):
    s2 = torch.classes.myclasses.MyStackClass(["hi", "mom"])
    print(s2.pop()) # "mom"
    s2.push("foobar")
    return s2 # ["hi", "foobar"]

A remarquer encore:  les API RPC ont elles aussi acquis le statut stable dans PyTorch 1.5.

Enfin, à partir de Python 1.5, la bibliothèque ne supporte plus Python 2.

Toutes les nouveautés de Python 1.5 peuvent être retrouvées dans sa note de version.