import os
import torch
from kitcar_ml.onboarding.model import OnboardingNet
from kitcar_ml.onboarding.script import get_dataloader, test
# Directory one level above the directory containing this file; the
# "dataset" folder and "setup.py" used below are resolved relative to it.
# The path is not normalized or made absolute.
path_onboarding = os.path.join(os.path.dirname(__file__), os.pardir)
def test_get_dataloader():
    """Check the training dataloader's length and per-batch tensor shapes.

    Builds a loader from ``dataset/train.yaml`` (relative to
    ``path_onboarding``) with ``batch_size=1``. It asserts that the loader
    reports at least 500 batches. For every batch, the first element must have
    size ``[1, 1, 28, 28]`` and the second element must have size ``[1]``.
    """
    print("Test Get Dataloader")
    train_config = os.path.join(path_onboarding, "dataset", "train.yaml")
    loader = get_dataloader(train_config, batch_size=1)
    assert len(loader) >= 500

    first_size = torch.Size([1, 1, 28, 28])
    second_size = torch.Size([1])
    # Walk the whole loader so every batch is shape-checked, not just the first.
    for batch in loader:
        assert batch[0].size() == first_size
        assert batch[1].size() == second_size
def test_model():
    """Smoke-test the onboarding ``test`` routine with a newly built model.

    Passes a freshly constructed ``OnboardingNet`` and the path to
    ``dataset/test.yaml`` to ``test``. Only checks that the call finishes
    without raising; its return value is discarded.
    """
    print("Test Onboarding Test Function")
    model = OnboardingNet()
    test_config = os.path.join(path_onboarding, "dataset", "test.yaml")
    test(model, test_config)
def main():
    """Run the onboarding setup script, then both onboarding checks.

    ``setup.py`` (relative to ``path_onboarding``) is executed with the same
    interpreter running this file. After that, ``test_get_dataloader`` and
    ``test_model`` are called in order.

    Raises:
        subprocess.CalledProcessError: if the setup script exits with a
            non-zero status, so the checks never run on a broken setup.
    """
    import subprocess
    import sys

    print("Test Onboarding Script")
    # Run onboarding setup. An argument list (no shell) keeps paths that
    # contain spaces intact. sys.executable avoids depending on whichever
    # "python3" happens to be on PATH. check=True surfaces setup failures
    # instead of silently ignoring the exit status as os.system did.
    subprocess.run(
        [sys.executable, os.path.join(path_onboarding, "setup.py")],
        check=True,
    )
    test_get_dataloader()
    test_model()
if __name__ == "__main__":
    # Only runs when executed directly as a script; importing this module
    # (e.g. during pytest collection) does not trigger the setup step.
    main()