`
haoningabc
  • 浏览: 1449324 次
  • 性别: Icon_minigender_1
  • 来自: 北京
社区版块
存档分类
最新评论

kaidi-wasm学习笔记(一)的两个重要文件

阅读更多
最近在看kaldi-wasm
两个重要文件备份一下,总结以后再写:

src/workers/asrWorker.js
import JSZip from 'jszip';

import kaldiJS from '../computations/kaldiJS';
import kaldiWasm from '../computations/kaldiJS.wasm';
import KaldiConfigParser from '../utils/kaldiConfigParser';

const kaldiModule = kaldiJS({
  locateFile(path) {
    if (path.endsWith('.wasm')){
        console.log("hao-asrWorer.js---kaldiJS:"+path);
        return kaldiWasm;
    }
   // if (path.endsWith('.wasm')) return kaldiJS;
    return path;
  },
});

const MODEL_STORE = {
  NAME: 'models',
  KEY_PATH: 'language',
};

let asr = null;
let parser = null;

function mkdirExistOK(fileSystem, path) {
  console.log("hao:asrWorker.js---mkdirExistOK,path:" +path+",fileSystem:")
  console.log(fileSystem);
  try {
    //fileSystem.mkdir(path);
    fileSystem.mkdir(path);
  } catch (e) {
    console.log("hao--mkdirExistOK--error..$$$$$$$$$$$$$$$$$$$$");
    if (e.code !== 'EEXIST') throw e;
  }
}

function initEMFS(fileSystem, modelName) {
  console.log("hao-asrWorker---initEMFS--fileSystem:");
  console.log(fileSystem);
  mkdirExistOK(fileSystem, MODEL_STORE.NAME);
  console.log("hao-asrWorker---initEMFS--MODEL_STORE.NAME:"+MODEL_STORE.NAME);
  fileSystem.mount(fileSystem.filesystems.IDBFS, {},
    MODEL_STORE.NAME);
  fileSystem.chdir(MODEL_STORE.NAME);
  fileSystem.mkdir(modelName);
  fileSystem.chdir(modelName);
  console.log("hao-asrWorker---initEMFS--over.");
}

async function unzip(zipfile) {
//  console.log("hao-asrWorker---unzip--:");
//  console.log(zipfile);
  const zip = new JSZip();

  const unzipped = await zip.loadAsync(zipfile);
  return unzipped;
}

function dirname(path) {
  const dirs = path.match(/.*\//);
  if (dirs === null) return '';
  // without trailing '/'
  return dirs[0].slice(0, dirs[0].length - 1);
}

function mkdirp(fileSystem, path) {
  console.log("hao-asrWorker-----mkdirp--path:"+path);
  console.log(fileSystem);
  const dirBoundary = '/';
  const startIndex = path[0] === dirBoundary ? 1 : 0;
  for (let i = startIndex; i < path.length; i += 1) {
    if (path[i] === dirBoundary) mkdirExistOK(fileSystem, path.slice(0, i));
  }
  mkdirExistOK(fileSystem, path);
}

async function writeToFileSystem(fileSystem, path, fileObj) {
  console.log("asrWorker.js---writeToFileSystem---fileSystem:");
  console.log(fileSystem);
  const content = await fileObj.async('arraybuffer');
  console.log("content:"+content);
  try {
    //fileSystem.writeFile(path, new Uint8Array(content));
    fileSystem.writeFile(path, new Uint8Array(content),function(err){
        console.log("writefile-----error:"+err);
    });
    console.log("asrWorker.js---writeToFileSystem---writeFile----over.path:"+fileSystem.cwd()+"/"+path);
    console.log("asrWorker.js---writeToFileSystem---isDir----models:"+fileSystem.isDir("/models"));
    console.log("asrWorker.js---writeToFileSystem---isDir:::"+fileSystem.isDir(fileSystem.cwd()));
    console.log("asrWorker.js---writeToFileSystem---final.mdl---isFile:"+fileSystem.isFile("final.mdl"));
    console.log("asrWorker.js---writeToFileSystem---end----------");
    return;
  } catch (e) {
    console.log("hao---error:-->>>>>>>>>>>>>>>>>>>>>>>writeToFileSystem......>");
    if (e.code === 'ENOENT') {
      const dirName = dirname(path);
      mkdirp(fileSystem, dirName);
      // eslint-disable-next-line consistent-return
      return writeToFileSystem(fileSystem, path, fileObj);
    }
    throw e;
  }
}

var thisModule;
async function loadToFS(modelName, zip) {
//  console.log("hao-asrWorker---loadToFS--begin----kaldiModule:");
//  console.log(kaldiModule);
  console.log("hao-asrWorker---loadToFS---unzip begin")
  const unzipped = await unzip(zip);
  console.log("hao-asrWorker---loadToFS--unzip over....");

  await  kaldiModule.then(
    function(result){
       console.log("hao---hao-asrWorker---loadToFS----kaldiModule.then:")
       console.log(result.FS);
       thisModule=result;
       initEMFS(thisModule.FS, modelName);
   });
  //initEMFS(kaldiModule.FS, modelName);

  //const unzipped = await unzip(zip);
  //const unzipped = unzip(zip);
  // hack to wait for model saving on Emscripten fileSystem
  // unzipped.forEach does not allow to wait for end of async calls
  const files = Object.keys(unzipped.files);
  await Promise.all(
      files.map(async (file) => {
        console.log("asrWorker----loadToFS---Promise.all...files.map--->"+file);
        const content = unzipped.file(file);
        if (content !== null) {
          //await writeToFileSystem(kaldiModule.FS, content.name, content);
          //await writeToFileSystem(thisModule.FS, content.name, content);
          console.log(" hao -----asrWorker----content.name--->"+content.name);
          const cwd = thisModule.FS.cwd();
          console.log(" hao------asrWorker-----cwd------>"+cwd);
          await writeToFileSystem(thisModule.FS, content.name, content);
        }
      })
  );

  //  }
 // );
   // .then(
   // function(endResult){
   //     asr = startASR(endResult);
   //    // asr = startASR();
   // }
   // );
//  asr = startASR(thisModule);
  console.log("asrWorker----loadToFS---end-----------------------<");
  return true;
}

/*
 * Assumes that we are in the directory with the requested model
 */
//function startASR() {
//  console.log("hao-asrWorker---startASR");
//  parser = new KaldiConfigParser(kaldiModule.FS, kaldiModule.FS.cwd());
//  const args = parser.createArgs();
//  const cppArgs = args.reduce((wasmArgs, arg) => {
//    wasmArgs.push_back(arg);
//    return wasmArgs;
//  }, new kaldiModule.StringList());
//  return new kaldiModule.OnlineASR(cppArgs);
//}
//function startASR(asrModule) {
function startASR() {
  console.log("hao-asrWorker---startASR---------->thisModule.FS:");
  console.log(thisModule.FS);
  //parser = new KaldiConfigParser(kaldiModule.FS, kaldiModule.FS.cwd());
  parser = new KaldiConfigParser(thisModule.FS, thisModule.FS.cwd());
  console.log("hao-asrWorker---startASR--------------------------> cwd");
  const args = parser.createArgs();
  const cppArgs = args.reduce((wasmArgs, arg) => {
    wasmArgs.push_back(arg);
    return wasmArgs;
  },
    //new kaldiModule.StringList());
    new thisModule.StringList());

  //return new kaldiModule.OnlineASR(cppArgs);
  return new thisModule.OnlineASR(cppArgs);
}

const helper = {
  async init(msg) {
    console.log("hao-asrWoker---init:"+msg+",msg:");
    console.log(msg);
    await loadToFS(msg.data.modelName, msg.data.zip);
    asr = startASR();
    //asr = startASR(thisModule);
  },
  async process(msg) {
    if (asr === null) throw new Error('ASR not ready');
    const asrOutput = asr.processBuffer(msg.data.pcm);
    if (asrOutput === '') return null;
    return {
      isFinal: asrOutput.endsWith('\n'),
      text: asrOutput.trim(),
    };
  },
  async samplerate() {
    if (parser === null) throw new Error('ASR not ready');
    return parser.getSampleRate();
  },
  async reset() {
    if (asr === null) throw new Error('ASR not ready');
    const asrOutput = asr.reset();
    const result = {
      isFinal: asrOutput.endsWith('\n'),
      text: asrOutput.trim(),
    };
    return result;
  },
  async terminate() {
    if (asr !== null) asr.delete();
    asr = null;
  },
};

onmessage = (msg) => {
  const { command } = msg.data;
  const response = { command, ok: true };

  if (command in helper) {
    helper[command](msg)
      .then((value) => { response.value = value; })
      .catch((e) => {
        response.ok = false;
        response.value = e;
      })
      .finally(() => { postMessage(response); });
  } else {
    response.ok = false;
    response.value = new Error(`Unknown command '${command}'`);
    postMessage(response);
  }
};

./kaldi/src/online2bin/online2-tcp-nnet3-decode-faster-reorganized.cc

// online2bin/online2-tcp-nnet3-decode-faster.cc

// Copyright 2014  Johns Hopkins University (author: Daniel Povey)
//           2016  Api.ai (Author: Ilya Platonov)
//           2018  Polish-Japanese Academy of Information Technology (Author: Danijel Korzinek)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <poll.h>
#include <signal.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <string>

#include "feat/wave-reader.h"
#include "online2/online-nnet3-decoding.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "online2/onlinebin-util.h"
#include "online2/online-timing.h"
#include "online2/online-endpoint.h"
#include "fstext/fstext-lib.h"
#include "lat/lattice-functions.h"
#include "util/kaldi-thread.h"
#include "nnet3/nnet-utils.h"

namespace kaldi {

class TcpServer {
 public:
  explicit TcpServer(int read_timeout);
  ~TcpServer();

  bool Listen(int32 port);  // start listening on a given port
  int32 Accept();  // accept a client and return its descriptor

  // get more data and return false if end-of-stream
  bool ReadChunk(size_t len);

  Vector<BaseFloat> GetChunk();  // get the data read by above method

  // write to accepted client
  bool Write(const std::string &msg);
  bool WriteLn(const std::string &msg, const std::string &eol = "\n");

  void Disconnect();

 private:
  struct ::sockaddr_in h_addr_;
  int32 server_desc_, client_desc_;
  int16 *samp_buf_;
  size_t buf_len_, has_read_;
  pollfd client_set_[1];
  int read_timeout_;
};

std::string LatticeToString(
    const Lattice &lat,
    const fst::SymbolTable &word_syms
) {
  LatticeWeight weight;
  std::vector<int32> alignment;
  std::vector<int32> words;
  GetLinearSymbolSequence(lat, &alignment, &words, &weight);

  std::ostringstream msg;
  for (size_t i = 0; i < words.size(); i++) {
    std::string s = word_syms.Find(words[i]);
    if (s.empty()) {
      KALDI_WARN << "Word-id " << words[i] << " not in symbol table.";
      msg << "<#" << std::to_string(i) << "> ";
    } else {
        msg << s << " ";
    }
  }
  return msg.str();
}

std::string GetTimeString(int32 t_beg, int32 t_end, BaseFloat time_unit) {
  constexpr size_t kBufferLen { 100 };
  char buffer[kBufferLen];
  double t_beg2 = t_beg * time_unit;
  double t_end2 = t_end * time_unit;
  snprintf(buffer, kBufferLen, "%.2f %.2f", t_beg2, t_end2);
  return std::string(buffer);
}

int32 GetLatticeTimeSpan(const Lattice& lat) {
  std::vector<int32> times;
  LatticeStateTimes(lat, &times);
  return times.back();
}

std::string LatticeToString(
  const CompactLattice &clat,
  const fst::SymbolTable &word_syms
) {
  if (clat.NumStates() == 0) {
    KALDI_WARN << "Empty lattice.";
    return "";
  }
  CompactLattice best_path_clat;
  CompactLatticeShortestPath(clat, &best_path_clat);

  Lattice best_path_lat;
  ConvertLattice(best_path_clat, &best_path_lat);
  return LatticeToString(best_path_lat, word_syms);
}

struct OnlineASROptionParser: public ParseOptions {
  OnlineASROptionParser();
  explicit OnlineASROptionParser(int argc, const char* const* argv);
  int Read(int, const char* const*);

  // Members
  static constexpr const char *usage =
    "Reads in audio from a network socket and performs online\n"
    "decoding with neural nets (nnet3 setup), with iVector-based\n"
    "speaker adaptation and endpointing.\n"
    "Note: some configuration values and inputs are set via config\n"
    "files whose filenames are passed as options\n"
    "\n"
    "Usage: online2-tcp-nnet3-decode-faster [options] <nnet3-in> "
    "<fst-in> <word-symbol-table>\n";
  // ASR stuff
  BaseFloat output_period = 1;
  bool produce_time = false;
  BaseFloat samp_freq = 16000.0;
  OnlineEndpointConfig endpoint_opts;
  OnlineNnet2FeaturePipelineConfig feature_opts;
  nnet3::NnetSimpleLoopedComputationOptions decodable_opts;
  LatticeFasterDecoderConfig decoder_opts;
  std::string nnet3_rxfilename;
  std::string fst_rxfilename;
  std::string word_syms_filename;
  // TCP stuff
  BaseFloat chunk_length_secs = 0.18;
  int port_num = 5050;
  int read_timeout = 3;
};

class OnlineASR {
 public:
  static constexpr const char eou {'\n'};
  static constexpr const char tmp_eou {'\r'};

  explicit OnlineASR(int argc, const char *const argv[]);
  explicit OnlineASR(const std::vector<std::string> &args);
  explicit OnlineASR(const OnlineASROptionParser& po);
  std::string ProcessBuffer(int16 *, size_t);
  std::string ProcessSTLVector(const std::vector<int16>&);
  std::string ProcessVector(const Vector<BaseFloat>&);
  std::string Reset();
  ~OnlineASR();

 private:
  BaseFloat samp_freq;
  int32 frame_offset {0};
  int32 check_period;
  int32 samp_count {0};
  bool produce_time {false};
  // Model related members
  nnet3::DecodableNnetSimpleLoopedInfo *decodable_info = nullptr;
  OnlineNnet2FeaturePipelineInfo *feature_info = nullptr;
  LatticeFasterDecoderConfig decoder_opts;
  nnet3::NnetSimpleLoopedComputationOptions decodable_opts;
  fst::Fst<fst::StdArc> *decode_fst = nullptr;
  TransitionModel trans_model;
  nnet3::AmNnetSimple am_nnet;
  fst::SymbolTable *word_syms = nullptr;
  OnlineEndpointConfig endpoint_opts;
  // Stream parameters
  OnlineNnet2FeaturePipeline *feature_pipeline = nullptr;
  SingleUtteranceNnet3Decoder *decoder = nullptr;
  // Utterance parameters
  OnlineSilenceWeighting *silence_weighting = nullptr;
  std::vector<std::pair<int32, BaseFloat> > delta_weights;

  // private methods
  void InitClass(const OnlineASROptionParser& parser);
  void InitWords(const std::string& filename);
  void UpdateDecoder(const Vector<BaseFloat>&);
  std::string CheckDecoderOutput();
  std::string PrependTimestamps(const std::string&);
  void ResetStreamDecoder();
  void ResetUtteranceDecoder();
};
}  // end kaldi namespace

#ifndef __EMSCRIPTEN__

int main(int argc, const char* const* argv) {
  using kaldi::int32;
  using kaldi::int64;
  using kaldi::OnlineASR;
  using kaldi::OnlineASROptionParser;
  using kaldi::Vector;
  using kaldi::BaseFloat;
  using kaldi::TcpServer;

  OnlineASROptionParser po;

  try {
    po.Read(argc, argv);
    OnlineASR onlineASR(po);

    // ignore SIGPIPE to avoid crashing when socket forcefully disconnected
    signal(SIGPIPE, SIG_IGN);

    size_t chunk_len = static_cast<size_t>(po.chunk_length_secs * po.samp_freq);
    TcpServer server(po.read_timeout);
    server.Listen(po.port_num);

    while (true) {
      server.Accept();
      bool eos {false};

      while (!eos) {
        while (true) {
          eos = !server.ReadChunk(chunk_len);
          if (eos) {
            std::string msg { onlineASR.Reset() };
            KALDI_VLOG(1) << "EndOfAudio, sending message: " << msg;
            server.Write(msg);
            server.Disconnect();
            break;
          }
          Vector<BaseFloat> wave_part = server.GetChunk();
          std::string msg { onlineASR.ProcessVector(wave_part) };
          if (msg != "") {
            server.Write(msg);
            if (msg[msg.length() - 1] == onlineASR.tmp_eou) {
              KALDI_VLOG(1) << "Temporary transcript: " << msg;
            } else {
              KALDI_VLOG(1) << "Endpoint, sending message: " << msg;
              break;
            }
          }
        }
      }
    }
  } catch (const std::invalid_argument& e) {
    po.PrintUsage();
    return 1;
  } catch (const std::exception &e) {
    std::cerr << e.what();
    return -1;
  }
}
#else

#include <emscripten/bind.h>
#include <emscripten/val.h>
#include <iterator>

using std::vector;
using kaldi::int16;
using emscripten::val;
using emscripten::class_;
using emscripten::optional_override;
using emscripten::register_vector;

/* Convert JS Int16Array to C++ std::vector<kaldi::int16> without copy of data
*/
vector<int16> typed_array_to_vector(const val &int16_array) {
  unsigned int length = int16_array["length"].as<unsigned int>();
  vector<int16> vec(length);

  val memory = val::module_property("HEAP16")["buffer"];
  val memoryView = val::global("Int16Array").new_(memory,
      reinterpret_cast<std::uintptr_t>(vec.data()), length);

  memoryView.call<void>("set", int16_array);

  return vec;
}

EMSCRIPTEN_BINDINGS(asr) {
  class_<kaldi::OnlineASR>("OnlineASR")
    .constructor<const std::vector<std::string>& >()
    // Inject lambda before class method call to adapt I/O types
    .function("processBuffer", optional_override(
      [](kaldi::OnlineASR& self, const val& int16_array) {
        vector<int16> vect_array = typed_array_to_vector(int16_array);
        return self.ProcessSTLVector(vect_array);
      })
    )
    .function("reset", &kaldi::OnlineASR::Reset)
  ;
  // Define JS class StringList to be understood as vector<string> in C++
  register_vector<std::string>("StringList");
}

#endif

namespace kaldi {
TcpServer::TcpServer(int read_timeout) {
  server_desc_ = -1;
  client_desc_ = -1;
  samp_buf_ = NULL;
  buf_len_ = 0;
  read_timeout_ = 1000 * read_timeout;
}

bool TcpServer::Listen(int32 port) {
  h_addr_.sin_addr.s_addr = INADDR_ANY;
  h_addr_.sin_port = htons(port);
  h_addr_.sin_family = AF_INET;

  server_desc_ = socket(AF_INET, SOCK_STREAM, 0);

  if (server_desc_ == -1) {
    KALDI_ERR << "Cannot create TCP socket!";
    return false;
  }

  int32 flag = 1;
  int32 len = sizeof(int32);
  if (setsockopt(server_desc_, SOL_SOCKET, SO_REUSEADDR, &flag, len) == -1) {
    KALDI_ERR << "Cannot set socket options!";
    return false;
  }

  if (bind(server_desc_, (struct sockaddr *) &h_addr_, sizeof(h_addr_)) == -1) {
    KALDI_ERR << "Cannot bind to port: " << port << " (is it taken?)";
    return false;
  }

  if (listen(server_desc_, 1) == -1) {
    KALDI_ERR << "Cannot listen on port!";
    return false;
  }

  KALDI_LOG << "TcpServer: Listening on port: " << port;

  return true;
}

TcpServer::~TcpServer() {
  Disconnect();
  if (server_desc_ != -1)
    close(server_desc_);
  delete[] samp_buf_;
}

int32 TcpServer::Accept() {
  KALDI_LOG << "Waiting for client...";

  socklen_t len;

  len = sizeof(struct sockaddr);
  client_desc_ = accept(server_desc_, (struct sockaddr *) &h_addr_, &len);

  struct sockaddr_storage addr;
  char ipstr[20];

  len = sizeof addr;
  getpeername(client_desc_, (struct sockaddr *) &addr, &len);

  struct sockaddr_in *s = (struct sockaddr_in *) &addr;
  inet_ntop(AF_INET, &s->sin_addr, ipstr, sizeof ipstr);

  client_set_[0].fd = client_desc_;
  client_set_[0].events = POLLIN;

  KALDI_LOG << "Accepted connection from: " << ipstr;

  return client_desc_;
}

bool TcpServer::ReadChunk(size_t len) {
  if (buf_len_ != len) {
    buf_len_ = len;
    delete[] samp_buf_;
    samp_buf_ = new int16[len];
  }

  ssize_t ret;
  int poll_ret;
  size_t to_read = len;
  has_read_ = 0;
  while (to_read > 0) {
    poll_ret = poll(client_set_, 1, read_timeout_);
    if (poll_ret == 0) {
      KALDI_WARN << "Socket timeout! Disconnecting...";
      break;
    }
    if (poll_ret < 0) {
      KALDI_WARN << "Socket error! Disconnecting...";
      break;
    }
    ret = read(client_desc_, static_cast<void *>(samp_buf_ + has_read_),
               to_read * sizeof(int16));
    if (ret <= 0) {
      KALDI_WARN << "Stream over...";
      break;
    }
    to_read -= ret / sizeof(int16);
    has_read_ += ret / sizeof(int16);
  }

  return has_read_ > 0;
}

Vector<BaseFloat> TcpServer::GetChunk() {
  Vector<BaseFloat> buf;

  buf.Resize(static_cast<MatrixIndexT>(has_read_));

  for (int i = 0; i < has_read_; i++)
    buf(i) = static_cast<BaseFloat>(samp_buf_[i]);

  return buf;
}

bool TcpServer::Write(const std::string &msg) {
  const char *p = msg.c_str();
  size_t to_write = msg.size();
  size_t wrote = 0;
  while (to_write > 0) {
    ssize_t ret = write(client_desc_, static_cast<const void *>(p + wrote),
                        to_write);
    if (ret <= 0)
      return false;

    to_write -= ret;
    wrote += ret;
  }

  return true;
}

bool TcpServer::WriteLn(const std::string &msg, const std::string &eol) {
  if (Write(msg))
    return Write(eol);
  else
    return false;
}

void TcpServer::Disconnect() {
  if (client_desc_ != -1) {
    close(client_desc_);
    client_desc_ = -1;
  }
}

OnlineASROptionParser::OnlineASROptionParser(): ParseOptions{usage} {
  g_num_threads = 0;

  Register("samp-freq", &samp_freq,
           "Sampling frequency of the input signal (coded as 16-bit slinear).");
  Register("chunk-length", &chunk_length_secs,
           "Length of chunk size in seconds, that we process.");
  Register("output-period", &output_period,
           "How often in seconds, do we check for changes in output.");
  Register("num-threads-startup", &g_num_threads,
           "Number of threads used when initializing iVector extractor.");
  Register("read-timeout", &read_timeout,
           "Number of seconds of timout for TCP audio data to appear on the "
           "stream. Use -1 for blocking.");
  Register("port-num", &port_num,
           "Portnumber the server will listen on.");
  Register("produce-time", &produce_time,
           "Prepend begin/end times between endpoints (e.g. '5.46 6.81"
           " <text_output>', in seconds)");

  endpoint_opts.Register(this);
  feature_opts.Register(this);
  decodable_opts.Register(this);
  decoder_opts.Register(this);
}

OnlineASROptionParser::OnlineASROptionParser(int argc,
                                             const char* const * argv):
  OnlineASROptionParser() {
    Read(argc, argv);
  }

int OnlineASROptionParser::Read(int argc, const char* const* argv) {
  int read_value = ParseOptions::Read(argc, argv);
  if (NumArgs() != 3)
    throw std::invalid_argument("Wrong number of arguments\n");

  nnet3_rxfilename = GetArg(1);
  fst_rxfilename = GetArg(2);
  word_syms_filename = GetArg(3);

  return read_value;
}

OnlineASR::OnlineASR(int argc, const char *const argv[]):
  OnlineASR(OnlineASROptionParser(argc, argv)) {
  }

OnlineASR::OnlineASR(const std::vector<std::string> &args) {
  // Convert args to const char* const *
  std::vector<const char*> char_array;
  char_array.reserve(args.size());
  for (int i = 0; i < args.size(); ++i)
    char_array.push_back(const_cast<char*>(args[i].c_str()));

  int argc { static_cast<int>(char_array.size()) };
  OnlineASROptionParser parser {argc, &char_array[0]};
  InitClass(parser);
}

OnlineASR::OnlineASR(const OnlineASROptionParser& parser) {
  InitClass(parser);
}

void OnlineASR::InitClass(const OnlineASROptionParser& parser) {
  decodable_opts = parser.decodable_opts;
  decoder_opts = parser.decoder_opts;
  endpoint_opts = parser.endpoint_opts;
  samp_freq = parser.samp_freq;
  check_period = static_cast<int32>(samp_freq * parser.output_period);
  produce_time = parser.produce_time;

  feature_info = new OnlineNnet2FeaturePipelineInfo(parser.feature_opts);
  InitWords(parser.word_syms_filename);

  KALDI_VLOG(1) << "Loading AM...";
  {
    bool binary;
    Input ki(parser.nnet3_rxfilename, &binary);
    trans_model.Read(ki.Stream(), binary);
    am_nnet.Read(ki.Stream(), binary);
    SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
    SetDropoutTestMode(true, &(am_nnet.GetNnet()));
    nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(am_nnet.GetNnet()));
  }

  KALDI_VLOG(1) << "Loading FST...";
  decode_fst = fst::ReadFstKaldiGeneric(parser.fst_rxfilename);
  // this object contains precomputed stuff that is used by all decodable
  // objects.  It takes a pointer to am_nnet because if it has iVectors it has
  // to modify the nnet to accept iVectors at intervals.
  decodable_info = new nnet3::DecodableNnetSimpleLoopedInfo(decodable_opts,
      &am_nnet);

  ResetStreamDecoder();
}

void OnlineASR::ResetStreamDecoder() {
  frame_offset = 0;

  delete feature_pipeline;
  feature_pipeline = new OnlineNnet2FeaturePipeline(*feature_info);

  delete decoder;
  decoder = new SingleUtteranceNnet3Decoder(decoder_opts, trans_model,
                                            *decodable_info, *decode_fst,
                                            feature_pipeline);
  ResetUtteranceDecoder();
}

void OnlineASR::InitWords(const std::string &filename) {
  if (!filename.empty())
    if (!(word_syms = fst::SymbolTable::ReadText(filename)))
      KALDI_ERR << "Could not read symbol table from file "
                << filename;
}

void OnlineASR::ResetUtteranceDecoder() {
  decoder->InitDecoding(frame_offset);
  delete silence_weighting;
  silence_weighting = new OnlineSilenceWeighting(
      trans_model,
      feature_info->silence_weighting_config,
      decodable_opts.frame_subsampling_factor);
  delta_weights = std::vector<std::pair<int32, BaseFloat> >();
}

std::string OnlineASR::ProcessBuffer(int16 *samp_buf, size_t buf_len) {
  Vector<BaseFloat> buf;
  buf.Resize(static_cast<MatrixIndexT>(buf_len));
  for (int i = 0; i < buf_len; ++i)
    buf(i) = static_cast<BaseFloat>(samp_buf[i]);

  return ProcessVector(buf);
}

std::string OnlineASR::ProcessVector(const Vector<BaseFloat>& buf) {
  UpdateDecoder(buf);
  return CheckDecoderOutput();
}

void OnlineASR::UpdateDecoder(const Vector<BaseFloat>& buf) {
  feature_pipeline->AcceptWaveform(samp_freq, buf);
  samp_count += buf.Dim();

  if (silence_weighting->Active() &&
      feature_pipeline->IvectorFeature() != NULL) {
    silence_weighting->ComputeCurrentTraceback(decoder->Decoder());
    silence_weighting->GetDeltaWeights(feature_pipeline->NumFramesReady(),
                                       frame_offset * decodable_opts.frame_subsampling_factor,
                                       &delta_weights);
    feature_pipeline->UpdateFrameWeights(delta_weights);
  }

  decoder->AdvanceDecoding();
}

std::string OnlineASR::CheckDecoderOutput() {
  if (decoder->EndpointDetected(endpoint_opts)) {
    samp_count %= check_period;
    decoder->FinalizeDecoding();
    frame_offset += decoder->NumFramesDecoded();
    CompactLattice lat;
    decoder->GetLattice(true, &lat);
    std::string msg = LatticeToString(lat, *word_syms);
    if (produce_time) msg = PrependTimestamps(msg);

    ResetUtteranceDecoder();
    return msg + eou;
  }

  // Force temporary result
  if (samp_count > check_period) {
    samp_count %= check_period;
    if (decoder->NumFramesDecoded() > 0) {
      Lattice lat;
      decoder->GetBestPath(false, &lat);
      TopSort(&lat);  // for LatticeStateTimes(),
      std::string msg = LatticeToString(lat, *word_syms);

      if (produce_time) {
        int32 frame_subsampling { decodable_opts.frame_subsampling_factor };
        BaseFloat frame_shift { feature_info->FrameShiftInSeconds() };
        int32 t_beg = frame_offset;
        int32 t_end = frame_offset + GetLatticeTimeSpan(lat);
        msg = GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " "
              + msg;
      }
      return msg + tmp_eou;
    }
  }

  return "";
}

std::string OnlineASR::PrependTimestamps(const std::string& msg) {
  int32 frame_subsampling { decodable_opts.frame_subsampling_factor };
  BaseFloat frame_shift { feature_info->FrameShiftInSeconds() };
  int32 t_beg = frame_offset - decoder->NumFramesDecoded();
  int32 t_end = frame_offset;
  return GetTimeString(t_beg, t_end, frame_shift * frame_subsampling) + " "
         + msg;
}

std::string OnlineASR::ProcessSTLVector(const std::vector<int16>& samp_buf) {
  // cast input to float
  Vector<BaseFloat> buf;
  size_t buf_len = samp_buf.size();
  buf.Resize(static_cast<MatrixIndexT>(buf_len));
  for (int i = 0; i < buf_len; ++i)
    buf(i) = static_cast<BaseFloat>(samp_buf[i]);

  return ProcessVector(buf);
}

std::string OnlineASR::Reset() {
  feature_pipeline->InputFinished();
  decoder->AdvanceDecoding();
  decoder->FinalizeDecoding();

  std::string msg {""};
  frame_offset += decoder->NumFramesDecoded();
  if (decoder->NumFramesDecoded() > 0) {
    CompactLattice lat;
    decoder->GetLattice(true, &lat);
    msg = LatticeToString(lat, *word_syms);
    if (produce_time) msg = PrependTimestamps(msg);
  }

  ResetStreamDecoder();
  return msg + eou;
}

OnlineASR::~OnlineASR() {
  delete feature_info;
  delete feature_pipeline;
  delete decoder;
  delete decodable_info;
  delete word_syms;
  delete silence_weighting;
  delete decode_fst;
}
}  // namespace kaldi
分享到:
评论

相关推荐

    comet:[ICLR 2021]少量学习的概念学习者

    少量学习的概念学习者 凯迪(Kaidi Cao)*,玛丽亚(MariaBrbić)*,尤里(Jure Leskovec) 此存储库包含COMET算法的...我们在这里提供一个示例: 运行python ./train.py --dataset CUB --model Conv4NP --method

    kaldi详细资料_kadi语音识别工具_

    语音识别工具kaldi及其应用,kaidi全部资料,适合新手使用

    GCN_ADV_Train:图神经网络的对抗训练

    此外,利用我们基于梯度的攻击,我们提出了针对GNN的第一个基于优化的对抗训练。 引用这项工作: 徐凯迪*,陈洪格*,刘思佳,陈品宇,翁翠薇,洪明义和林雪, ,IJCAI 2019。(*平等贡献) @inproceedings{xu2019...

    node_tasted:Node.js 浅尝辄止——分享至我仍未知道名字的十楼公司

    By Kaidi, ZHU, R&D Engineer of and . 正确打开姿势 预备工作:请确保已安装 Node.js 在你的电脑。 安装 依赖。执行 $ npm install。 启动它,执行 $ npm start 。 在弹出的浏览器窗口中点击 tasted.md 即可。 若非...

    python源码基于YOLOV5安全帽检测系统及危险区域入侵检测告警系统源码.rar

    本资源提供了一个基于YOLOv5的安全帽检测系统及危险区域入侵检测告警系统的Python源码 该系统主要利用深度学习和计算机视觉技术,实现了安全帽和危险区域入侵的实时检测与告警。具体功能如下: 1. 安全帽检测:系统能够识别并检测工人是否佩戴安全帽,对于未佩戴安全帽的工人,系统会发出告警信号,提醒工人佩戴安全帽。 2. 危险区域入侵检测:系统能够实时监测危险区域,如高空作业、机械设备等,对于未经授权的人员或车辆进入危险区域,系统会立即发出告警信号,阻止入侵行为,确保安全。 本资源采用了YOLOv5作为目标检测算法,该算法基于深度学习和卷积神经网络,具有较高的检测精度和实时性能。同时,本资源还提供了详细的使用说明和示例代码,便于用户快速上手和实现二次开发。 运行测试ok,课程设计高分资源,放心下载使用!该资源适合计算机相关专业(如人工智能、通信工程、自动化、软件工程等)的在校学生、老师或者企业员工下载,适合小白学习或者实际项目借鉴参考! 当然也可作为毕业设计、课程设计、课程作业、项目初期立项演示等。如果基础还行,可以在此代码基础之上做改动以实现更多功能,如增加多种安全帽和危险区域的识别、支持多种传感器数据输入、实现远程监控等。

    基于SpringBoot的响应式技术博客的设计和实现(源码+文档)

    本课题将许多当前比较热门的技术框架有机的集合起来,比如Spring boot、Spring data、Elasticsearch等。同时采用Java8作为主要开发语言,利用新型API,改善传统的开发模式和代码结构,实现了具有实时全文搜索、博客编辑、分布式文件存贮和能够在浏览器中适配移动端等功能的响应式技术博客。 本毕业设计选用SpringBoot框架,结合Thymeleaf,SpringData,SpringSecurity,Elasticsearch等技术,旨在为技术人员设计并实现一款用于记录并分享技术文档的技术博客。通过该技术博客,方便技术人员记录自己工作和学习过程中的点滴,不断地进行技术的总结和积累,从而提升自己的综合能力,并通过博客这一平台,把自己的知识、经验、教训分享给大家,为志同道合者提供一个相互交流、共同学习的平台,促使更多的人共同进步[9]。学习到别人的一些良好的设计思路、编码风格和优秀的技术能力,使笔者的设计初衷。本系统主要面向web端的用户,希望能给用户更多的学习和交流的选择。

    javalab 3.zip

    javalab 3.zip

    J0001基于javaWeb的健身房管理系统设计与实现

    该系统基于javaweb整合,数据层为MyBatis,mysql数据库,具有完整的业务逻辑,适合选题:健身、健身房、健身房管理等 健身房管理系统开发使用JSP技术和MySQL数据库,该系统所使用的是Java语言,Java是目前最为优秀的面相对象的程序设计语言,只需要开发者对概念有一些了解就可以编写出程序,因此,开发该系统总体上不会有很大的难度,同时在开发系统时,所使用的数据库也是必不可少的。开发此系统所使用的技术都是通过在大学期间学习的,对每科课程都有很好的掌握,对系统的开发具有很好的判断性。因此,在完成该系统的开发建设时所使用的技术是完全可行的。 学员主要实现的功能有:网站信息、课程信息、教练列表、我的信息、登录 员工主要实现的功能有:工资查询、会员管理、器材借还、健身卡管理、个人中心、登录 教练主要实现的功能有:工资查询、学员列表、个人中心 管理员是系统的核心,可以对系统信息进行更新和维护,主要实现的功能有:个人中心、学员管理、教练管理、网站信息管理、器械信息管理、课程信息管理。

    架构.cpp

    架构.cpp

    利用Python实现中文文本关键词抽取(三种方法)

    文本关键词抽取,是对文本信息进行高度凝练的一种有效手段,通过3-5个词语准确概括文本的主题,帮助读者快速理解文本信息。目前,用于文本关键词提取的主要方法有四种:基于TF-IDF的关键词抽取、基于TextRank的关键词抽取、基于Word2Vec词聚类的关键词抽取,以及多种算法相融合的关键词抽取。笔者在使用前三种算法进行关键词抽取的学习过程中,发现采用TF-IDF和TextRank方法进行关键词抽取在网上有很多的例子,代码和步骤也比较简单,但是采用Word2Vec词聚类方法时网上的资料并未把过程和步骤表达的很清晰。因此,本文分别采用 1. TF-IDF方法 2. TextRank方 3. Word2Vec词聚类方法 实现对专利文本(同样适用于其它类型文本)的关键词抽取,通过理论与实践相结合的方式,一步步了解、学习、实现中文文本关键词抽取。

    演示Asm字节码插桩asmd-demo-master.zip

    演示Asm字节码插桩asmd-demo-master.zip

    VB+access干部档案管理系统(源代码+系统).zip

    档案是国家机构、社会组织在干部管理活动中形成的、记述和反映个人经历和德才表现等情况、以人头为单位集中保存以备查考的原始记录。 档案管理的目的是为了档案的利用。如果放松管理,无论对单位和对个人都会影响档案的利用。举个例子,如果应该进入档案的材料没及时归档,则对个人资料的记载就是不完整的,缺乏了这一部分的凭证,就无法出具相关证明。如果发生了损坏或丢失档案的情况,后果就更加严重,有的档案材料是难以重新建立的。档案的管理是与干部、流动手续的衔接密切相关的。以北京市人才服务中心为例,拥有着全市最大的档案管理中心,共保管了档案12万份。这些档案的利用率相对很高,表现在出具干部证明、婚育证明、出国政审、职称评定、工龄认定以及各种保险的相关手续等方面。档案中心的工作人员每天都要接待大量的企业用人中的查询、查阅。 档案好像是计划经济的产物,在市场经济条件下,随着人才流动潮流的涌现,人们思想观念上的放开,档案越来越被人们所冷落和忽视。到底档案对个人以及人力资源部意味着什么,放松对档案的管理会带来哪些后果呢? 目前我国的档案管理社会化趋势日益明显。非公有制单位,国有企业事业单位发展干部代理使流动人员档案管理

    本算法是结合“时间遗忘曲线”和“物品类….zip

    协同过滤算法(Collaborative Filtering)是一种经典的推荐算法,其基本原理是“协同大家的反馈、评价和意见,一起对海量的信息进行过滤,从中筛选出用户可能感兴趣的信息”。它主要依赖于用户和物品之间的行为关系进行推荐。 协同过滤算法主要分为两类: 基于物品的协同过滤算法:给用户推荐与他之前喜欢的物品相似的物品。 基于用户的协同过滤算法:给用户推荐与他兴趣相似的用户喜欢的物品。 协同过滤算法的优点包括: 无需事先对商品或用户进行分类或标注,适用于各种类型的数据。 算法简单易懂,容易实现和部署。 推荐结果准确性较高,能够为用户提供个性化的推荐服务。 然而,协同过滤算法也存在一些缺点: 对数据量和数据质量要求较高,需要大量的历史数据和较高的数据质量。 容易受到“冷启动”问题的影响,即对新用户或新商品的推荐效果较差。 存在“同质化”问题,即推荐结果容易出现重复或相似的情况。 协同过滤算法在多个场景中有广泛的应用,如电商推荐系统、社交网络推荐和视频推荐系统等。在这些场景中,协同过滤算法可以根据用户的历史行为数据,推荐与用户兴趣相似的商品、用户或内容,从而提高用户的购买转化率、活跃度和社交体验。 未来,协同过滤算法的发展方向可能是结合其他推荐算法形成混合推荐系统,以充分发挥各算法的优势。

    JAVAWEB校园二手平台项目.zip

    JAVAWEB校园二手平台项目,基本功能包括:个人信息、商品管理;交易商品板块管理等。本系统结构如下: (1)本月推荐交易板块: 电脑及配件:实现对该类商品的查询、用户留言功能 通讯器材:实现对该类商品的查询、用户留言功能 视听设备:实现对该类商品的查询、用户留言功能 书籍报刊:实现对该类商品的查询、用户留言功能 生活服务:实现对该类商品的查询、用户留言功能 房屋信息:实现对该类商品的查询、用户留言功能 交通工具:实现对该类商品的查询、用户留言功能 其他商品:实现对该类商品的查询、用户留言功能 (2)载入个人用户: 用户登陆 用户注册 (3)个人平台: 信息管理:实现对商品的删除、修改、查询功能 添加二手信息:实现对新商品的添加 修改个人资料:实现对用户个人信息的修改 注销

    文档+程序动态四足机器人的自由模型预测控制

    Representation-Free Model Predictive Control for Dynamic Quadruped 专注于动态四足机器人的控制问题,特别强调了自由模型预测控制(Free-MPC)在该领域的应用。内容涉及自由MPC的原理、算法构建和在四足机器人动态平衡与运动控制中的实践案例。通过案例分析,揭示了自由MPC如何提升四足机器人在复杂地形下的适应性和稳定性。适合机器人工程师、控制理论研究者和相关专业学生阅读。使用场景包括机器人设计与开发、控制算法研究以及高等教育课程。目标是推动四足机器人控制技术的发展,增强其在多变环境中的表现。 关键词标签: 四足机器人 动态控制 自由模型预测控制 Free-MPC 机器人工程

    Building Android Apps in Python

    Building Android Apps in Python

    基于PHP实现的WEB图片共享系统(源代码+论文)

    基于PHP实现的WEB图片共享系统(源代码+论文)

    soxr-0.3.5-cp37-cp37m-win_amd64.whl.zip

    soxr-0.3.5-cp37-cp37m-win_amd64.whl.zip

    HeartShapeArtist

    这个 HeartShapeArtist 脚本使用 Python 的 turtle 图形库来绘制一个美丽的红色心形。它首先设置画布和画笔的属性,包括速度、颜色和背景,然后通过连续的小步骤绘制心形的上半部分和尖端,完成后进行颜色填充。该脚本适合用作学习如何使用 turtle 模块进行基本图形绘制的入门示例,也可以作为情人节或其他特殊场合的小项目。只需运行脚本,即可在屏幕上看到一个完美的心形图案展现出来。

    MySQL文件,省市区sql数据,微调

    参考了https://github.com/nhjclxc/District-SQL; 调整了北京优先级,删除了level = 4的数据;

Global site tag (gtag.js) - Google Analytics