`
haoningabc
  • 浏览: 1446836 次
  • 性别: Icon_minigender_1
  • 来自: 北京
社区版块
存档分类
最新评论

kaidi-wasm学习笔记(一)的两个重要文件

阅读更多
最近在看kaldi-wasm
两个重要文件备份一下,总结以后再写:

src/workers/asrWorker.js
import JSZip from 'jszip';

import kaldiJS from '../computations/kaldiJS';
import kaldiWasm from '../computations/kaldiJS.wasm';
import KaldiConfigParser from '../utils/kaldiConfigParser';

const kaldiModule = kaldiJS({
  locateFile(path) {
    if (path.endsWith('.wasm')){
        console.log("hao-asrWorer.js---kaldiJS:"+path);
        return kaldiWasm;
    }
   // if (path.endsWith('.wasm')) return kaldiJS;
    return path;
  },
});

const MODEL_STORE = {
  NAME: 'models',
  KEY_PATH: 'language',
};

let asr = null;
let parser = null;

function mkdirExistOK(fileSystem, path) {
  console.log("hao:asrWorker.js---mkdirExistOK,path:" +path+",fileSystem:")
  console.log(fileSystem);
  try {
    //fileSystem.mkdir(path);
    fileSystem.mkdir(path);
  } catch (e) {
    console.log("hao--mkdirExistOK--error..$$$$$$$$$$$$$$$$$$$$");
    if (e.code !== 'EEXIST') throw e;
  }
}

function initEMFS(fileSystem, modelName) {
  console.log("hao-asrWorker---initEMFS--fileSystem:");
  console.log(fileSystem);
  mkdirExistOK(fileSystem, MODEL_STORE.NAME);
  console.log("hao-asrWorker---initEMFS--MODEL_STORE.NAME:"+MODEL_STORE.NAME);
  fileSystem.mount(fileSystem.filesystems.IDBFS, {},
    MODEL_STORE.NAME);
  fileSystem.chdir(MODEL_STORE.NAME);
  fileSystem.mkdir(modelName);
  fileSystem.chdir(modelName);
  console.log("hao-asrWorker---initEMFS--over.");
}

async function unzip(zipfile) {
//  console.log("hao-asrWorker---unzip--:");
//  console.log(zipfile);
  const zip = new JSZip();

  const unzipped = await zip.loadAsync(zipfile);
  return unzipped;
}

function dirname(path) {
  const dirs = path.match(/.*\//);
  if (dirs === null) return '';
  // without trailing '/'
  return dirs[0].slice(0, dirs[0].length - 1);
}

function mkdirp(fileSystem, path) {
  console.log("hao-asrWorker-----mkdirp--path:"+path);
  console.log(fileSystem);
  const dirBoundary = '/';
  const startIndex = path[0] === dirBoundary ? 1 : 0;
  for (let i = startIndex; i < path.length; i += 1) {
    if (path[i] === dirBoundary) mkdirExistOK(fileSystem, path.slice(0, i));
  }
  mkdirExistOK(fileSystem, path);
}

async function writeToFileSystem(fileSystem, path, fileObj) {
  console.log("asrWorker.js---writeToFileSystem---fileSystem:");
  console.log(fileSystem);
  const content = await fileObj.async('arraybuffer');
  console.log("content:"+content);
  try {
    //fileSystem.writeFile(path, new Uint8Array(content));
    fileSystem.writeFile(path, new Uint8Array(content),function(err){
        console.log("writefile-----error:"+err);
    });
    console.log("asrWorker.js---writeToFileSystem---writeFile----over.path:"+fileSystem.cwd()+"/"+path);
    console.log("asrWorker.js---writeToFileSystem---isDir----models:"+fileSystem.isDir("/models"));
    console.log("asrWorker.js---writeToFileSystem---isDir:::"+fileSystem.isDir(fileSystem.cwd()));
    console.log("asrWorker.js---writeToFileSystem---final.mdl---isFile:"+fileSystem.isFile("final.mdl"));
    console.log("asrWorker.js---writeToFileSystem---end----------");
    return;
  } catch (e) {
    console.log("hao---error:-->>>>>>>>>>>>>>>>>>>>>>>writeToFileSystem......>");
    if (e.code === 'ENOENT') {
      const dirName = dirname(path);
      mkdirp(fileSystem, dirName);
      // eslint-disable-next-line consistent-return
      return writeToFileSystem(fileSystem, path, fileObj);
    }
    throw e;
  }
}

var thisModule;
async function loadToFS(modelName, zip) {
//  console.log("hao-asrWorker---loadToFS--begin----kaldiModule:");
//  console.log(kaldiModule);
  console.log("hao-asrWorker---loadToFS---unzip begin")
  const unzipped = await unzip(zip);
  console.log("hao-asrWorker---loadToFS--unzip over....");

  await  kaldiModule.then(
    function(result){
       console.log("hao---hao-asrWorker---loadToFS----kaldiModule.then:")
       console.log(result.FS);
       thisModule=result;
       initEMFS(thisModule.FS, modelName);
   });
  //initEMFS(kaldiModule.FS, modelName);

  //const unzipped = await unzip(zip);
  //const unzipped = unzip(zip);
  // hack to wait for model saving on Emscripten fileSystem
  // unzipped.forEach does not allow to wait for end of async calls
  const files = Object.keys(unzipped.files);
  await Promise.all(
      files.map(async (file) => {
        console.log("asrWorker----loadToFS---Promise.all...files.map--->"+file);
        const content = unzipped.file(file);
        if (content !== null) {
          //await writeToFileSystem(kaldiModule.FS, content.name, content);
          //await writeToFileSystem(thisModule.FS, content.name, content);
          console.log(" hao -----asrWorker----content.name--->"+content.name);
          const cwd = thisModule.FS.cwd();
          console.log(" hao------asrWorker-----cwd------>"+cwd);
          await writeToFileSystem(thisModule.FS, content.name, content);
        }
      })
  );

  //  }
 // );
   // .then(
   // function(endResult){
   //     asr = startASR(endResult);
   //    // asr = startASR();
   // }
   // );
//  asr = startASR(thisModule);
  console.log("asrWorker----loadToFS---end-----------------------<");
  return true;
}

/*
 * Assumes that we are in the directory with the requested model
 */
//function startASR() {
//  console.log("hao-asrWorker---startASR");
//  parser = new KaldiConfigParser(kaldiModule.FS, kaldiModule.FS.cwd());
//  const args = parser.createArgs();
//  const cppArgs = args.reduce((wasmArgs, arg) => {
//    wasmArgs.push_back(arg);
//    return wasmArgs;
//  }, new kaldiModule.StringList());
//  return new kaldiModule.OnlineASR(cppArgs);
//}
//function startASR(asrModule) {
function startASR() {
  console.log("hao-asrWorker---startASR---------->thisModule.FS:");
  console.log(thisModule.FS);
  //parser = new KaldiConfigParser(kaldiModule.FS, kaldiModule.FS.cwd());
  parser = new KaldiConfigParser(thisModule.FS, thisModule.FS.cwd());
  console.log("hao-asrWorker---startASR--------------------------> cwd");
  const args = parser.createArgs();
  const cppArgs = args.reduce((wasmArgs, arg) => {
    wasmArgs.push_back(arg);
    return wasmArgs;
  },
    //new kaldiModule.StringList());
    new thisModule.StringList());

  //return new kaldiModule.OnlineASR(cppArgs);
  return new thisModule.OnlineASR(cppArgs);
}

const helper = {
  async init(msg) {
    console.log("hao-asrWoker---init:"+msg+",msg:");
    console.log(msg);
    await loadToFS(msg.data.modelName, msg.data.zip);
    asr = startASR();
    //asr = startASR(thisModule);
  },
  async process(msg) {
    if (asr === null) throw new Error('ASR not ready');
    const asrOutput = asr.processBuffer(msg.data.pcm);
    if (asrOutput === '') return null;
    return {
      isFinal: asrOutput.endsWith('\n'),
      text: asrOutput.trim(),
    };
  },
  async samplerate() {
    if (parser === null) throw new Error('ASR not ready');
    return parser.getSampleRate();
  },
  async reset() {
    if (asr === null) throw new Error('ASR not ready');
    const asrOutput = asr.reset();
    const result = {
      isFinal: asrOutput.endsWith('\n'),
      text: asrOutput.trim(),
    };
    return result;
  },
  async terminate() {
    if (asr !== null) asr.delete();
    asr = null;
  },
};

onmessage = (msg) => {
  const { command } = msg.data;
  const response = { command, ok: true };

  if (command in helper) {
    helper[command](msg)
      .then((value) => { response.value = value; })
      .catch((e) => {
        response.ok = false;
        response.value = e;
      })
      .finally(() => { postMessage(response); });
  } else {
    response.ok = false;
    response.value = new Error(`Unknown command '${command}'`);
    postMessage(response);
  }
};

./kaldi/src/online2bin/online2-tcp-nnet3-decode-faster-reorganized.cc

// online2bin/online2-tcp-nnet3-decode-faster.cc

// Copyright 2014  Johns Hopkins University (author: Daniel Povey)
//           2016  Api.ai (Author: Ilya Platonov)
//           2018  Polish-Japanese Academy of Information Technology (Author: Danijel Korzinek)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <poll.h>
#include <signal.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <string>

#include "feat/wave-reader.h"
#include "online2/online-nnet3-decoding.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "online2/onlinebin-util.h"
#include "online2/online-timing.h"
#include "online2/online-endpoint.h"
#include "fstext/fstext-lib.h"
#include "lat/lattice-functions.h"
#include "util/kaldi-thread.h"
#include "nnet3/nnet-utils.h"

namespace kaldi {

class TcpServer {
 public:
  explicit TcpServer(int read_timeout);
  ~TcpServer();

  bool Listen(int32 port);  // start listening on a given port
  int32 Accept();  // accept a client and return its descriptor

  // get more data and return false if end-of-stream
  bool ReadChunk(size_t len);

  Vector<BaseFloat> GetChunk();  // get the data read by above method

  // write to accepted client
  bool Write(const std::string &msg);
  bool WriteLn(const std::string &msg, const std::string &eol = "\n");

  void Disconnect();

 private:
  struct ::sockaddr_in h_addr_;
  int32 server_desc_, client_desc_;
  int16 *samp_buf_;
  size_t buf_len_, has_read_;
  pollfd client_set_[1];
  int read_timeout_;
};

std::string LatticeToString(
    const Lattice &lat,
    const fst::SymbolTable &word_syms
) {
  LatticeWeight weight;
  std::vector<int32> alignment;
  std::vector<int32> words;
  GetLinearSymbolSequence(lat, &alignment, &words, &weight);

  std::ostringstream msg;
  for (size_t i = 0; i < words.size(); i++) {
    std::string s = word_syms.Find(words[i]);
    if (s.empty()) {
      KALDI_WARN << "Word-id " << words[i] << " not in symbol table.";
      msg << "<#" << std::to_string(i) << "> ";
    } else {
        msg << s << " ";
    }
  }
  return msg.str();
}

std::string GetTimeString(int32 t_beg, int32 t_end, BaseFloat time_unit) {
  constexpr size_t kBufferLen { 100 };
  char buffer[kBufferLen];
  double t_beg2 = t_beg * time_unit;
  double t_end2 = t_end * time_unit;
  snprintf(buffer, kBufferLen, "%.2f %.2f", t_beg2, t_end2);
  return std::string(buffer);
}

int32 GetLatticeTimeSpan(const Lattice& lat) {
  std::vector<int32> times;
  LatticeStateTimes(lat, &times);
  return times.back();
}

std::string LatticeToString(
  const CompactLattice &clat,
  const fst::SymbolTable &word_syms
) {
  if (clat.NumStates() == 0) {
    KALDI_WARN << "Empty lattice.";
    return "";
  }
  CompactLattice best_path_clat;
  CompactLatticeShortestPath(clat, &best_path_clat);

  Lattice best_path_lat;
  ConvertLattice(best_path_clat, &best_path_lat);
  return LatticeToString(best_path_lat, word_syms);
}

struct OnlineASROptionParser: public ParseOptions {
  OnlineASROptionParser();
  explicit OnlineASROptionParser(int argc, const char* const* argv);
  int Read(int, const char* const*);

  // Members
  static constexpr const char *usage =
    "Reads in audio from a network socket and performs online\n"
    "decoding with neural nets (nnet3 setup), with iVector-based\n"
    "speaker adaptation and endpointing.\n"
    "Note: some configuration values and inputs are set via config\n"
    "files whose filenames are passed as options\n"
    "\n"
    "Usage: online2-tcp-nnet3-decode-faster [options] <nnet3-in> "
    "<fst-in> <word-symbol-table>\n";
  // ASR stuff
  BaseFloat output_period = 1;
  bool produce_time = false;
  BaseFloat samp_freq = 16000.0;
  OnlineEndpointConfig endpoint_opts;
  OnlineNnet2FeaturePipelineConfig feature_opts;
  nnet3::NnetSimpleLoopedComputationOptions decodable_opts;
  LatticeFasterDecoderConfig decoder_opts;
  std::string nnet3_rxfilename;
  std::string fst_rxfilename;
  std::string word_syms_filename;
  // TCP stuff
  BaseFloat chunk_length_secs = 0.18;
  int port_num = 5050;
  int read_timeout = 3;
};

class OnlineASR {
 public:
  static constexpr const char eou {'\n'};
  static constexpr const char tmp_eou {'\r'};

  explicit OnlineASR(int argc, const char *const argv[]);
  explicit OnlineASR(const std::vector<std::string> &args);
  explicit OnlineASR(const OnlineASROptionParser& po);
  std::string ProcessBuffer(int16 *, size_t);
  std::string ProcessSTLVector(const std::vector<int16>&);
  std::string ProcessVector(const Vector<BaseFloat>&);
  std::string Reset();
  ~OnlineASR();

 private:
  BaseFloat samp_freq;
  int32 frame_offset {0};
  int32 check_period;
  int32 samp_count {0};
  bool produce_time {false};
  // Model related members
  nnet3::DecodableNnetSimpleLoopedInfo *decodable_info = nullptr;
  OnlineNnet2FeaturePipelineInfo *feature_info = nullptr;
  LatticeFasterDecoderConfig decoder_opts;
  nnet3::NnetSimpleLoopedComputationOptions decodable_opts;
  fst::Fst<fst::StdArc> *decode_fst = nullptr;
  TransitionModel trans_model;
  nnet3::AmNnetSimple am_nnet;
  fst::SymbolTable *word_syms = nullptr;
  OnlineEndpointConfig endpoint_opts;
  // Stream parameters
  OnlineNnet2FeaturePipeline *feature_pipeline = nullptr;
  SingleUtteranceNnet3Decoder *decoder = nullptr;
  // Utterance parameters
  OnlineSilenceWeighting *silence_weighting = nullptr;
  std::vector<std::pair<int32, BaseFloat> > delta_weights;

  // private methods
  void InitClass(const OnlineASROptionParser& parser);
  void InitWords(const std::string& filename);
  void UpdateDecoder(const Vector<BaseFloat>&);
  std::string CheckDecoderOutput();
  std::string PrependTimestamps(const std::string&);
  void ResetStreamDecoder();
  void ResetUtteranceDecoder();
};
}  // end kaldi namespace

#ifndef __EMSCRIPTEN__

int main(int argc, const char* const* argv) {
  using kaldi::int32;
  using kaldi::int64;
  using kaldi::OnlineASR;
  using kaldi::OnlineASROptionParser;
  using kaldi::Vector;
  using kaldi::BaseFloat;
  using kaldi::TcpServer;

  OnlineASROptionParser po;

  try {
    po.Read(argc, argv);
    OnlineASR onlineASR(po);

    // ignore SIGPIPE to avoid crashing when socket forcefully disconnected
    signal(SIGPIPE, SIG_IGN);

    size_t chunk_len = static_cast<size_t>(po.chunk_length_secs * po.samp_freq);
    TcpServer server(po.read_timeout);
    server.Listen(po.port_num);

    while (true) {
      server.Accept();
      bool eos {false};

      while (!eos) {
        while (true) {
          eos = !server.ReadChunk(chunk_len);
          if (eos) {
            std::string msg { onlineASR.Reset() };
            KALDI_VLOG(1) << "EndOfAudio, sending message: " << msg;
            server.Write(msg);
            server.Disconnect();
            break;
          }
          Vector<BaseFloat> wave_part = server.GetChunk();
          std::string msg { onlineASR.ProcessVector(wave_part) };
          if (msg != "") {
            server.Write(msg);
            if (msg[msg.length() - 1] == onlineASR.tmp_eou) {
              KALDI_VLOG(1) << "Temporary transcript: " << msg;
            } else {
              KALDI_VLOG(1) << "Endpoint, sending message: " << msg;
              break;
            }
          }
        }
      }
    }
  } catch (const std::invalid_argument& e) {
    po.PrintUsage();
    return 1;
  } catch (const std::exception &e) {
    std::cerr << e.what();
    return -1;
  }
}
#else

#include <emscripten/bind.h>
#include <emscripten/val.h>
#include <iterator>

using std::vector;
using kaldi::int16;
using emscripten::val;
using emscripten::class_;
using emscripten::optional_override;
using emscripten::register_vector;

/* Convert JS Int16Array to C++ std::vector<kaldi::int16> without copy of data
*/
vector<int16> typed_array_to_vector(const val &int16_array) {
  unsigned int length = int16_array["length"].as<unsigned int>();
  vector<int16> vec(length);

  val memory = val::module_property("HEAP16")["buffer"];
  val memoryView = val::global("Int16Array").new_(memory,
      reinterpret_cast<std::uintptr_t>(vec.data()), length);

  memoryView.call<void>("set", int16_array);

  return vec;
}

EMSCRIPTEN_BINDINGS(asr) {
  class_<kaldi::OnlineASR>("OnlineASR")
    .constructor<const std::vector<std::string>& >()
    // Inject lambda before class method call to adapt I/O types
    .function("processBuffer", optional_override(
      [](kaldi::OnlineASR& self, const val& int16_array) {
        vector<int16> vect_array = typed_array_to_vector(int16_array);
        return self.ProcessSTLVector(vect_array);
      })
    )
    .function("reset", &kaldi::OnlineASR::Reset)
  ;
  // Define JS class StringList to be understood as vector<string> in C++
  register_vector<std::string>("StringList");
}

#endif

namespace kaldi {
TcpServer::TcpServer(int read_timeout) {
  server_desc_ = -1;
  client_desc_ = -1;
  samp_buf_ = NULL;
  buf_len_ = 0;
  read_timeout_ = 1000 * read_timeout;
}

bool TcpServer::Listen(int32 port) {
  h_addr_.sin_addr.s_addr = INADDR_ANY;
  h_addr_.sin_port = htons(port);
  h_addr_.sin_family = AF_INET;

  server_desc_ = socket(AF_INET, SOCK_STREAM, 0);

  if (server_desc_ == -1) {
    KALDI_ERR << "Cannot create TCP socket!";
    return false;
  }

  int32 flag = 1;
  int32 len = sizeof(int32);
  if (setsockopt(server_desc_, SOL_SOCKET, SO_REUSEADDR, &flag, len) == -1) {
    KALDI_ERR << "Cannot set socket options!";
    return false;
  }

  if (bind(server_desc_, (struct sockaddr *) &h_addr_, sizeof(h_addr_)) == -1) {
    KALDI_ERR << "Cannot bind to port: " << port << " (is it taken?)";
    return false;
  }

  if (listen(server_desc_, 1) == -1) {
    KALDI_ERR << "Cannot listen on port!";
    return false;
  }

  KALDI_LOG << "TcpServer: Listening on port: " << port;

  return true;
}

TcpServer::~TcpServer() {
  Disconnect();
  if (server_desc_ != -1)
    close(server_desc_);
  delete[] samp_buf_;
}

int32 TcpServer::Accept() {
  KALDI_LOG << "Waiting for client...";

  socklen_t len;

  len = sizeof(struct sockaddr);
  client_desc_ = accept(server_desc_, (struct sockaddr *) &h_addr_, &len);

  struct sockaddr_storage addr;
  char ipstr[20];

  len = sizeof addr;
  getpeername(client_desc_, (struct sockaddr *) &addr, &len);

  struct sockaddr_in *s = (struct sockaddr_in *) &addr;
  inet_ntop(AF_INET, &s->sin_addr, ipstr, sizeof ipstr);

  client_set_[0].fd = client_desc_;
  client_set_[0].events = POLLIN;

  KALDI_LOG << "Accepted connection from: " << ipstr;

  return client_desc_;
}

bool TcpServer::ReadChunk(size_t len) {
  if (buf_len_ != len) {
    buf_len_ = len;
    delete[] samp_buf_;
    samp_buf_ = new int16[len];
  }

  ssize_t ret;
  int poll_ret;
  size_t to_read = len;
  has_read_ = 0;
  while (to_read > 0) {
    poll_ret = poll(client_set_, 1, read_timeout_);
    if (poll_ret == 0) {
      KALDI_WARN << "Socket timeout! Disconnecting...";
      break;
    }
    if (poll_ret < 0) {
      KALDI_WARN << "Socket error! Disconnecting...";
      break;
    }
    ret = read(client_desc_, static_cast<void *>(samp_buf_ + has_read_),
               to_read * sizeof(int16));
    if (ret <= 0) {
      KALDI_WARN << "Stream over...";
      break;
    }
    to_read -= ret / sizeof(int16);
    has_read_ += ret / sizeof(int16);
  }

  return has_read_ > 0;
}

Vector<BaseFloat> TcpServer::GetChunk() {
  Vector<BaseFloat> buf;

  buf.Resize(static_cast<MatrixIndexT>(has_read_));

  for (int i = 0; i < has_read_; i++)
    buf(i) = static_cast<BaseFloat>(samp_buf_[i]);

  return buf;
}

bool TcpServer::Write(const std::string &msg) {
  const char *p = msg.c_str();
  size_t to_write = msg.size();
  size_t wrote = 0;
  while (to_write > 0) {
    ssize_t ret = write(client_desc_, static_cast<const void *>(p + wrote),
                        to_write);
    if (ret <= 0)
      return false;

    to_write -= ret;
    wrote += ret;
  }

  return true;
}

bool TcpServer::WriteLn(const std::string &msg, const std::string &eol) {
  if (Write(msg))
    return Write(eol);
  else
    return false;
}

void TcpServer::Disconnect() {
  if (client_desc_ != -1) {
    close(client_desc_);
    client_desc_ = -1;
  }
}

OnlineASROptionParser::OnlineASROptionParser(): ParseOptions{usage} {
  g_num_threads = 0;

  Register("samp-freq", &samp_freq,
           "Sampling frequency of the input signal (coded as 16-bit slinear).");
  Register("chunk-length", &chunk_length_secs,
           "Length of chunk size in seconds, that we process.");
  Register("output-period", &output_period,
           "How often in seconds, do we check for changes in output.");
  Register("num-threads-startup", &g_num_threads,
           "Number of threads used when initializing iVector extractor.");
  Register("read-timeout", &read_timeout,
           "Number of seconds of timout for TCP audio data to appear on the "
           "stream. Use -1 for blocking.");
  Register("port-num", &port_num,
           "Portnumber the server will listen on.");
  Register("produce-time", &produce_time,
           "Prepend begin/end times between endpoints (e.g. '5.46 6.81"
           " <text_output>', in seconds)");

  endpoint_opts.Register(this);
  feature_opts.Register(this);
  decodable_opts.Register(this);
  decoder_opts.Register(this);
}

OnlineASROptionParser::OnlineASROptionParser(int argc,
                                             const char* const * argv):
  OnlineASROptionParser() {
    Read(argc, argv);
  }

int OnlineASROptionParser::Read(int argc, const char* const* argv) {
  int read_value = ParseOptions::Read(argc, argv);
  if (NumArgs() != 3)
    throw std::invalid_argument("Wrong number of arguments\n");

  nnet3_rxfilename = GetArg(1);
  fst_rxfilename = GetArg(2);
  word_syms_filename = GetArg(3);

  return read_value;
}

OnlineASR::OnlineASR(int argc, const char *const argv[]):
  OnlineASR(OnlineASROptionParser(argc, argv)) {
  }

OnlineASR::OnlineASR(const std::vector<std::string> &args) {
  // Convert args to const char* const *
  std::vector<const char*> char_array;
  char_array.reserve(args.size());
  for (int i = 0; i < args.size(); ++i)
    char_array.push_back(const_cast<char*>(args[i].c_str()));

  int argc { static_cast<int>(char_array.size()) };
  OnlineASROptionParser parser {argc, &char_array[0]};
  InitClass(parser);
}

OnlineASR::OnlineASR(const OnlineASROptionParser& parser) {
  InitClass(parser);
}

void OnlineASR::InitClass(const OnlineASROptionParser& parser) {
  decodable_opts = parser.decodable_opts;
  decoder_opts = parser.decoder_opts;
  endpoint_opts = parser.endpoint_opts;
  samp_freq = parser.samp_freq;
  check_period = static_cast<int32>(samp_freq * parser.output_period);
  produce_time = parser.produce_time;

  feature_info = new OnlineNnet2FeaturePipelineInfo(parser.feature_opts);
  InitWords(parser.word_syms_filename);

  KALDI_VLOG(1) << "Loading AM...";
  {
    bool binary;
    Input ki(parser.nnet3_rxfilename, &binary);
    trans_model.Read(ki.Stream(), binary);
    am_nnet.Read(ki.Stream(), binary);
    SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
    SetDropoutTestMode(true, &(am_nnet.GetNnet()));
    nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(am_nnet.GetNnet()));
  }

  KALDI_VLOG(1) << "Loading FST...";
  decode_fst = fst::ReadFstKaldiGeneric(parser.fst_rxfilename);
  // this object contains precomputed stuff that is used by all decodable
  // objects.  It takes a pointer to am_nnet because if it has iVectors it has
  // to modify the nnet to accept iVectors at intervals.
  decodable_info = new nnet3::DecodableNnetSimpleLoopedInfo(decodable_opts,
      &am_nnet);

  ResetStreamDecoder();
}

void OnlineASR::ResetStreamDecoder() {
  frame_offset = 0;

  delete feature_pipeline;
  feature_pipeline = new OnlineNnet2FeaturePipeline(*feature_info);

  delete decoder;
  decoder = new SingleUtteranceNnet3Decoder(decoder_opts, trans_model,
                                            *decodable_info, *decode_fst,
                                            feature_pipeline);
  ResetUtteranceDecoder();
}

void OnlineASR::InitWords(const std::string &filename) {
  if (!filename.empty())
    if (!(word_syms = fst::SymbolTable::ReadText(filename)))
      KALDI_ERR << "Could not read symbol table from file "
                << filename;
}

void OnlineASR::ResetUtteranceDecoder() {
  decoder->InitDecoding(frame_offset);
  delete silence_weighting;
  silence_weighting = new OnlineSilenceWeighting(
      trans_model,
      feature_info->silence_weighting_config,
      decodable_opts.frame_subsampling_factor);
  delta_weights = std::vector<std::pair<int32, BaseFloat> >();
}

std::string OnlineASR::ProcessBuffer(int16 *samp_buf, size_t buf_len) {
  Vector<BaseFloat> buf;
  buf.Resize(static_cast<MatrixIndexT>(buf_len));
  for (int i = 0; i < buf_len; ++i)
    buf(i) = static_cast<BaseFloat>(samp_buf[i]);

  return ProcessVector(buf);
}

std::string OnlineASR::ProcessVector(const Vector<BaseFloat>& buf) {
  UpdateDecoder(buf);
  return CheckDecoderOutput();
}

void OnlineASR::UpdateDecoder(const Vector<BaseFloat>& buf) {
  feature_pipeline->AcceptWaveform(samp_freq, buf);
  samp_count += buf.Dim();

  if (silence_weighting->Active() &&
      feature_pipeline->IvectorFeature() != NULL) {
    silence_weighting->ComputeCurrentTraceback(decoder->Decoder());
    silence_weighting->GetDeltaWeights(feature_pipeline->NumFramesReady(),
                                       frame_offset * decodable_opts.frame_subsampling_factor,
                                       &delta_weights);
    feature_pipeline->UpdateFrameWeights(delta_weights);
  }

  decoder->AdvanceDecoding();
}

std::string OnlineASR::CheckDecoderOutput() {
  if (decoder->EndpointDetected(endpoint_opts)) {
    samp_count %= check_period;
    decoder->FinalizeDecoding();
    frame_offset += decoder->NumFramesDecoded();
    CompactLattice lat;
    decoder->GetLattice(true, &lat);
    std::string msg = LatticeToString(lat, *word_syms);
    if (produce_time) msg = PrependTimestamps(msg);

    ResetUtteranceDecoder();
    return msg + eou;
  }

  // Force temporary result
  if (samp_count > check_period) {
    samp_count %= check_period;
    if (decoder->NumFramesDecoded() > 0) {
      Lattice lat;
      decoder->GetBestPath(false, &lat);
      TopSort(&lat);  // for LatticeStateTimes(),
      std::string msg = LatticeToString(lat, *word_syms);

      if (produce_time) {
        int32 frame_subsampling { decodable_opts.frame_subsampling_factor };
        BaseFloat frame_shift { feature_info->FrameShiftInSeconds() };
        int32 t_beg = frame_offset;
        int32 t_end = frame_offset + GetLatticeTimeSpan(lat);
        msg = GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " "
              + msg;
      }
      return msg + tmp_eou;
    }
  }

  return "";
}

std::string OnlineASR::PrependTimestamps(const std::string& msg) {
  int32 frame_subsampling { decodable_opts.frame_subsampling_factor };
  BaseFloat frame_shift { feature_info->FrameShiftInSeconds() };
  int32 t_beg = frame_offset - decoder->NumFramesDecoded();
  int32 t_end = frame_offset;
  return GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " "
         + msg;
}

std::string OnlineASR::ProcessSTLVector(const std::vector<int16>& samp_buf) {
  // cast input to float
  Vector<BaseFloat> buf;
  size_t buf_len = samp_buf.size();
  buf.Resize(static_cast<MatrixIndexT>(buf_len));
  for (int i = 0; i < buf_len; ++i)
    buf(i) = static_cast<BaseFloat>(samp_buf[i]);

  return ProcessVector(buf);
}

std::string OnlineASR::Reset() {
  feature_pipeline->InputFinished();
  decoder->AdvanceDecoding();
  decoder->FinalizeDecoding();

  std::string msg {""};
  frame_offset += decoder->NumFramesDecoded();
  if (decoder->NumFramesDecoded() > 0) {
    CompactLattice lat;
    decoder->GetLattice(true, &lat);
    msg = LatticeToString(lat, *word_syms);
    if (produce_time) msg = PrependTimestamps(msg);
  }

  ResetStreamDecoder();
  return msg + eou;
}

OnlineASR::~OnlineASR() {
  delete feature_info;
  delete feature_pipeline;
  delete decoder;
  delete decodable_info;
  delete word_syms;
  delete silence_weighting;
  delete decode_fst;
}
}  // namespace kaldi
分享到:
评论

相关推荐

    kaldi详细资料_kadi语音识别工具_

    语音识别工具kaldi及其应用,kaidi全部资料,适合新手使用

    comet:[ICLR 2021]少量学习的概念学习者

    少量学习的概念学习者 凯迪(Kaidi Cao)*,玛丽亚(MariaBrbić)*,尤里(Jure Leskovec) 此存储库包含COMET算法的...我们在这里提供一个示例: 运行python ./train.py --dataset CUB --model Conv4NP --method

    GCN_ADV_Train:图神经网络的对抗训练

    此外,利用我们基于梯度的攻击,我们提出了针对GNN的第一个基于优化的对抗训练。 引用这项工作: 徐凯迪*,陈洪格*,刘思佳,陈品宇,翁翠薇,洪明义和林雪, ,IJCAI 2019。(*平等贡献) @inproceedings{xu2019...

    node_tasted:Node.js 浅尝辄止——分享至我仍未知道名字的十楼公司

    By Kaidi, ZHU, R&D Engineer of and . 正确打开姿势 预备工作:请确保已安装 Node.js 在你的电脑。 安装 依赖。执行 $ npm install。 启动它,执行 $ npm start 。 在弹出的浏览器窗口中点击 tasted.md 即可。 若非...

    node-v7.7.2-linux-x86.tar.xz

    Node.js,简称Node,是一个开源且跨平台的JavaScript运行时环境,它允许在浏览器外运行JavaScript代码。Node.js于2009年由Ryan Dahl创立,旨在创建高性能的Web服务器和网络应用程序。它基于Google Chrome的V8 JavaScript引擎,可以在Windows、Linux、Unix、Mac OS X等操作系统上运行。 Node.js的特点之一是事件驱动和非阻塞I/O模型,这使得它非常适合处理大量并发连接,从而在构建实时应用程序如在线游戏、聊天应用以及实时通讯服务时表现卓越。此外,Node.js使用了模块化的架构,通过npm(Node package manager,Node包管理器),社区成员可以共享和复用代码,极大地促进了Node.js生态系统的发展和扩张。 Node.js不仅用于服务器端开发。随着技术的发展,它也被用于构建工具链、开发桌面应用程序、物联网设备等。Node.js能够处理文件系统、操作数据库、处理网络请求等,因此,开发者可以用JavaScript编写全栈应用程序,这一点大大提高了开发效率和便捷性。 在实践中,许多大型企业和组织已经采用Node.js作为其Web应用程序的开发平台,如Netflix、PayPal和Walmart等。它们利用Node.js提高了应用性能,简化了开发流程,并且能更快地响应市场需求。

    前后端分离的毕业论文(设计)管理系统 (SpringBoot+Vue)

    关于基于SpringBoot和Vue的毕业论文(设计)管理系统,到了一些相关的资源和示例项目,这些资源可能对您的毕业设计有所帮助。 1. **SpringBoot+Vue的三只松鼠商城**: 这个项目是一个基于SpringBoot和Vue的在线购物系统,采用了前后端分离的架构模式。系统实现了管理员模块和用户模块,包括用户管理、地址管理、订单管理、商品管理、支付功能等。这个项目是一个B2C电商平台,使用了MySQL和Redis数据库。 2. **大学生校园社团管理系统**: 这是一个基于SpringBoot和Vue的校园社团管理系统,旨在简化社团报名和组织活动的流程。系统包括用户管理、社团管理、活动信息管理等功能。该项目展示了如何使用前后端分离架构来构建一个校园社团管理系统。 3. **智慧宿舍管理系统**: 这个项目是基于SpringBoot和Vue的智慧宿舍管理系统,旨在提高宿舍管理的效率和便利性。系统包括学生宿舍信息管理、设备监控、安全管理和生活服务等功能。该项目展示了如何使用前后端分离架构来构建一个智能宿舍管理系统。 这些项目可以为您的毕业设计提供灵感和实际的技术指导。您可以

    238.html

    238.html

    基于tensorflow深度学习的地理位置的命名实体识别.zip

    基于tensorflow深度学习的地理位置的命名实体识别.zip

    优秀项目 基于STM32单片机+Python+OpenCV的二自由度人脸跟踪舵机云台源码+详细文档+全部数据资料.zip

    【资源说明】 优秀项目 基于STM32单片机+Python+OpenCV的二自由度人脸跟踪舵机云台源码+详细文档+全部数据资料.zip 【备注】 1、该项目是个人高分项目源码,已获导师指导认可通过,答辩评审分达到95分 2、该资源内项目代码都经过测试运行成功,功能ok的情况下才上传的,请放心下载使用! 3、本项目适合计算机相关专业(如软件工程、计科、人工智能、通信工程、自动化、电子信息等)的在校学生、老师或者企业员工下载使用,也可作为毕业设计、课程设计、作业、项目初期立项演示等,当然也适合小白学习进阶。 4、如果基础还行,可以在此代码基础上进行修改,以实现其他功能,也可直接用于毕设、课设、作业等。 欢迎下载,沟通交流,互相学习,共同进步!

    文件I/O基础-I.MX6U嵌入式Linux C应用编程学习笔记基于正点原子阿尔法开发板

    文件I/O基础-I.MX6U嵌入式Linux C应用编程学习笔记基于正点原子阿尔法开发板

    基于深度神经网络的图像分类任务.zip

    基于深度神经网络的图像分类任务.zip

    强化学习基准代码,已经针对Tensoflow2.x版本修改,可以直接使用

    强化学习基准代码,已经针对Tensoflow2.x版本修改,可以直接使用

    node-v7.7.4-linux-ppc64.tar.xz

    Node.js,简称Node,是一个开源且跨平台的JavaScript运行时环境,它允许在浏览器外运行JavaScript代码。Node.js于2009年由Ryan Dahl创立,旨在创建高性能的Web服务器和网络应用程序。它基于Google Chrome的V8 JavaScript引擎,可以在Windows、Linux、Unix、Mac OS X等操作系统上运行。 Node.js的特点之一是事件驱动和非阻塞I/O模型,这使得它非常适合处理大量并发连接,从而在构建实时应用程序如在线游戏、聊天应用以及实时通讯服务时表现卓越。此外,Node.js使用了模块化的架构,通过npm(Node package manager,Node包管理器),社区成员可以共享和复用代码,极大地促进了Node.js生态系统的发展和扩张。 Node.js不仅用于服务器端开发。随着技术的发展,它也被用于构建工具链、开发桌面应用程序、物联网设备等。Node.js能够处理文件系统、操作数据库、处理网络请求等,因此,开发者可以用JavaScript编写全栈应用程序,这一点大大提高了开发效率和便捷性。 在实践中,许多大型企业和组织已经采用Node.js作为其Web应用程序的开发平台,如Netflix、PayPal和Walmart等。它们利用Node.js提高了应用性能,简化了开发流程,并且能更快地响应市场需求。

    DZT0227-2010 地质岩心钻探规程.pdf

    DZT0227-2010 地质岩心钻探规程.pdf

    自动泊车之AVM环视系统算法及其框架.pdf

    自动泊车之AVM环视系统算法及其框架

    HTML+CSS+JS精品网页模板H70.rar

    HTML+CSS+JS精品网页模板,设置导航条、轮翻效果,鼠标滑动效果,自动弹窗,点击事件、链接等功能;适用于大学生期末大作业或公司网页的设计制作。响应式网页,可以根据不同的设备屏幕大小自动调整页面布局; 支持如Dreamweaver、HBuilder、Text 、Vscode 等任意html编辑软件进行编辑修改; 支持包括IE、Firefox、Chrome、Safari主流浏览器浏览; 下载文件解压缩,用Dreamweaver、HBuilder、Text 、Vscode 等任意html编辑软件打开,只需更改源代码中的文字和图片可直接使用。图片的命名和格式需要与原图片的名字和格式一致,其他的无需更改。如碰到HTML5+CSS+JS等专业技术问题,以及需要对应行业的模板等相关源码、模板、资料、教程等,随时联系博主咨询。 网页设计和制作、大学生网页课程设计、期末大作业、毕业设计、网页模板,网页成品源代码等,5000+套Web案例源码,主题涵盖各行各业,关注作者联系获取更多源码; 更多优质网页博文、网页模板移步查阅我的CSDN主页:angella.blog.csdn.net。

    高分项目 基于STM32单片机的窗户控制系统APP源代码+项目资料齐全+教程文档.zip

    【资源概览】 高分项目 基于STM32的窗户控制系统APP源代码+项目资料齐全+教程文档.zip高分项目 基于STM32的窗户控制系统APP源代码+项目资料齐全+教程文档.zip高分项目 基于STM32的窗户控制系统APP源代码+项目资料齐全+教程文档.zip 【资源说明】 高分项目源码:此资源是在校高分项目的完整源代码,经过导师的悉心指导与认可,答辩评审得分高达95分,项目的质量与深度有保障。 测试运行成功:所有的项目代码在上传前都经过了严格的测试,确保在功能上完全符合预期,您可以放心下载并使用。 适用人群广泛:该项目不仅适合计算机相关专业(如电子信息、物联网、通信工程、自动化等)的在校学生和老师,还可以作为毕业设计、课程设计、作业或项目初期立项的演示材料。对于希望进阶学习的小白来说,同样是一个极佳的学习资源。 代码灵活性高:如果您具备一定的编程基础,可以在此代码基础上进行个性化的修改,以实现更多功能。当然,直接用于毕业设计、课程设计或作业也是完全可行的。 欢迎下载,与我一起交流学习,共同进步!

    Windows 系统下 Xshell 安装使用教程

    附件是Windows 系统下 Xshell 安装使用教程,仅供交流学习使用,无任何商业目的! Xshell 是一款功能丰富的 SSH 客户端,除了基本的远程命令行访问,还提供了许多高级功能,如标签式界面、强大的脚本功能等。通过实践和探索,你可以更深入地了解 Xshell 的各种功能。

    node-v0.12.17-x86.msi

    Node.js,简称Node,是一个开源且跨平台的JavaScript运行时环境,它允许在浏览器外运行JavaScript代码。Node.js于2009年由Ryan Dahl创立,旨在创建高性能的Web服务器和网络应用程序。它基于Google Chrome的V8 JavaScript引擎,可以在Windows、Linux、Unix、Mac OS X等操作系统上运行。 Node.js的特点之一是事件驱动和非阻塞I/O模型,这使得它非常适合处理大量并发连接,从而在构建实时应用程序如在线游戏、聊天应用以及实时通讯服务时表现卓越。此外,Node.js使用了模块化的架构,通过npm(Node package manager,Node包管理器),社区成员可以共享和复用代码,极大地促进了Node.js生态系统的发展和扩张。 Node.js不仅用于服务器端开发。随着技术的发展,它也被用于构建工具链、开发桌面应用程序、物联网设备等。Node.js能够处理文件系统、操作数据库、处理网络请求等,因此,开发者可以用JavaScript编写全栈应用程序,这一点大大提高了开发效率和便捷性。 在实践中,许多大型企业和组织已经采用Node.js作为其Web应用程序的开发平台,如Netflix、PayPal和Walmart等。它们利用Node.js提高了应用性能,简化了开发流程,并且能更快地响应市场需求。

    node-v7.1.0-linux-arm64.tar.xz

    Node.js,简称Node,是一个开源且跨平台的JavaScript运行时环境,它允许在浏览器外运行JavaScript代码。Node.js于2009年由Ryan Dahl创立,旨在创建高性能的Web服务器和网络应用程序。它基于Google Chrome的V8 JavaScript引擎,可以在Windows、Linux、Unix、Mac OS X等操作系统上运行。 Node.js的特点之一是事件驱动和非阻塞I/O模型,这使得它非常适合处理大量并发连接,从而在构建实时应用程序如在线游戏、聊天应用以及实时通讯服务时表现卓越。此外,Node.js使用了模块化的架构,通过npm(Node package manager,Node包管理器),社区成员可以共享和复用代码,极大地促进了Node.js生态系统的发展和扩张。 Node.js不仅用于服务器端开发。随着技术的发展,它也被用于构建工具链、开发桌面应用程序、物联网设备等。Node.js能够处理文件系统、操作数据库、处理网络请求等,因此,开发者可以用JavaScript编写全栈应用程序,这一点大大提高了开发效率和便捷性。 在实践中,许多大型企业和组织已经采用Node.js作为其Web应用程序的开发平台,如Netflix、PayPal和Walmart等。它们利用Node.js提高了应用性能,简化了开发流程,并且能更快地响应市场需求。

Global site tag (gtag.js) - Google Analytics