Inference

PyTorch modules for training models on sequential data


InferenceWrapper

InferenceWrapper(learner, device='cpu')

A wrapper class that simplifies inference with a trained tsfast/fastai Learner on NumPy data. It handles normalization and resets the model state automatically. The device argument sets the device the model runs on (CPU by default).
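To make "handles normalization and state reset" concrete, here is a minimal NumPy-only sketch of the idea. It is not tsfast's implementation. The class SketchInferenceWrapper, the toy RunningSumModel, and the predict()/reset() interface are all hypothetical. The sketch also assumes that inputs are standardized with a stored mean and standard deviation and that any hidden state is cleared before every call.

```python
import numpy as np


class SketchInferenceWrapper:
    """Hypothetical normalize-then-predict wrapper (illustration only, not tsfast code)."""

    def __init__(self, model, mean, std):
        self.model = model  # assumed to expose predict(x) and reset()
        self.mean = np.asarray(mean, dtype=float)
        self.std = np.asarray(std, dtype=float)

    def __call__(self, x):
        x = np.asarray(x, dtype=float)
        # Standardize with statistics assumed to come from the training data
        x_norm = (x - self.mean) / self.std
        # Clear hidden state so each call starts from a fresh sequence
        self.model.reset()
        return self.model.predict(x_norm)


class RunningSumModel:
    """Toy stateful model: cumulative sum over time, continued across calls unless reset."""

    def __init__(self):
        self.state = 0.0

    def reset(self):
        self.state = 0.0

    def predict(self, x):
        out = self.state + np.cumsum(x, axis=0)
        self.state = out[-1]  # carry the last value into the next call
        return out


wrapper = SketchInferenceWrapper(RunningSumModel(), mean=1.0, std=2.0)
x = np.array([[1.0], [3.0], [5.0]])  # shape (seq_len, features)
first = wrapper(x)   # normalized input [0, 1, 2] -> [[0], [1], [3]]
second = wrapper(x)  # identical, because reset() ran before the call
```

Without the reset() call, the second call would continue from the state left by the first and return [[3], [4], [6]] instead.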

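import numpy as np
from tsfast.datasets.core import create_dls_test
from tsfast.learner import RNNLearner
from tsfast.prediction import FranSysLearner

# Small test DataLoaders and a freshly created (untrained) RNN learner, enough to check shapes
dls = create_dls_test()
lrn = RNNLearner(dls)
model = InferenceWrapper(lrn)

# A single sequence can be passed as (seq_len, features), (seq_len,) or (1, seq_len, features)
model(np.random.randn(100, 1)).shape     # (100, 1)
model(np.random.randn(100)).shape        # (100, 1)
model(np.random.randn(1, 100, 1)).shape  # (100, 1)

# FranSys learner with attach_output=True: the wrapper takes a second array alongside the input
lrn = FranSysLearner(dls, 10, attach_output=True)
model = InferenceWrapper(lrn)
model(np.random.randn(100, 1), np.random.randn(100, 1)).shape  # (100, 1)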
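The example shows that a 1-D array, a 2-D array, and a 3-D batch of one all give a (100, 1) result. The plain NumPy sketch below shows one way such shape handling could work. The helpers to_batch, from_batch, predict, and toy_model are hypothetical. Joining several inputs along the feature axis, as would be needed for the two-argument FranSys call, is an assumption and not necessarily what InferenceWrapper does.

```python
import numpy as np


def to_batch(x):
    """Promote a (seq,), (seq, feat) or (1, seq, feat) array to (1, seq, feat)."""
    x = np.asarray(x, dtype=float)
    if x.ndim == 1:
        x = x[:, None]  # (seq,) -> (seq, 1)
    if x.ndim == 2:
        x = x[None]     # (seq, feat) -> (1, seq, feat)
    if x.ndim != 3 or x.shape[0] != 1:
        raise ValueError(f"expected a single sequence, got shape {x.shape}")
    return x


def from_batch(y):
    """Drop the batch axis so the caller gets (seq, out) back."""
    return np.asarray(y)[0]


def predict(model_fn, *inputs):
    # Hypothetical: promote each input, join along the feature axis, run, then unbatch
    batch = np.concatenate([to_batch(x) for x in inputs], axis=-1)
    return from_batch(model_fn(batch))


def toy_model(batch):
    # Stand-in model: a single output channel equal to the sum of the input channels
    return batch.sum(axis=-1, keepdims=True)


print(predict(toy_model, np.random.randn(100, 1)).shape)     # (100, 1)
print(predict(toy_model, np.random.randn(100)).shape)        # (100, 1)
print(predict(toy_model, np.random.randn(1, 100, 1)).shape)  # (100, 1)
print(predict(toy_model, np.random.randn(100, 1), np.random.randn(100, 1)).shape)  # (100, 1)
```

Rejecting batches larger than one keeps the sketch's return shape unambiguous, so the caller always gets (seq_len, outputs) back.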