Skip to content
Snippets Groups Projects
Commit 7ecfd161 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[tests] Remove traces of signs-to-tb "model"

parent 8e9081ce
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -513,17 +513,17 @@ def test_compare_pasa_montgomery(temporary_basedir): ...@@ -513,17 +513,17 @@ def test_compare_pasa_montgomery(temporary_basedir):
@pytest.mark.skip(reason="Test need to be updated") @pytest.mark.skip(reason="Test need to be updated")
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): def test_train_mlp_montgomery_rs(temporary_basedir, datadir):
from ptbench.scripts.train import train from ptbench.scripts.train import train
runner = CliRunner() runner = CliRunner()
with stdout_logging() as buf: with stdout_logging() as buf:
output_folder = str(temporary_basedir / "results/signstotb") output_folder = str(temporary_basedir / "results/mlp")
result = runner.invoke( result = runner.invoke(
train, train,
[ [
"signs_to_tb", "mlp",
"montgomery_rs", "montgomery_rs",
"-vv", "-vv",
"--epochs=1", "--epochs=1",
...@@ -567,7 +567,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): ...@@ -567,7 +567,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir):
@pytest.mark.skip(reason="Test need to be updated") @pytest.mark.skip(reason="Test need to be updated")
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_predict_signstotb_montgomery_rs(temporary_basedir, datadir): def test_predict_mlp_montgomery_rs(temporary_basedir, datadir):
from ptbench.scripts.predict import predict from ptbench.scripts.predict import predict
runner = CliRunner() runner = CliRunner()
...@@ -577,12 +577,12 @@ def test_predict_signstotb_montgomery_rs(temporary_basedir, datadir): ...@@ -577,12 +577,12 @@ def test_predict_signstotb_montgomery_rs(temporary_basedir, datadir):
result = runner.invoke( result = runner.invoke(
predict, predict,
[ [
"signs_to_tb", "mlp",
"montgomery_rs", "montgomery_rs",
"-vv", "-vv",
"--batch-size=1", "--batch-size=1",
"--relevance-analysis", "--relevance-analysis",
f"--weight={str(datadir / 'lfs' / 'models' / 'signstotb.ckpt')}", f"--weight={str(datadir / 'lfs' / 'models' / 'mlp.ckpt')}",
f"--output-folder={output_folder}", f"--output-folder={output_folder}",
], ],
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment