Commit 3ae3ba7d authored by Flavio TARSETTI's avatar Flavio TARSETTI

Merge branch 'implement_db_v2' into 'master'

Implement db v2 handling

Closes #586

See merge request !417
parents ef854288 230e1999
Pipeline #47715 passed with stages
in 19 minutes and 54 seconds
......@@ -105,19 +105,22 @@ class BackendUtilitiesMixin(object):
environment=dict(name=env.name, version=env.version),
)
raw_access_db_name = "simple_rawdata_access/1"
source_prefix = os.path.join(settings.BASE_DIR, "src", "beat.examples")
db_root_file_path = os.path.join(settings.PREFIX, "db_root.json")
db_path = os.path.join(
settings.PREFIX, "data", raw_access_db_name.replace("/", "_")
)
db_root_data = {raw_access_db_name: db_path}
db_root_data = {}
os.makedirs(db_path, exist_ok=True)
for version in range(1, 3):
raw_access_db_name = f"simple_rawdata_access/{version}"
source_prefix = os.path.join(settings.BASE_DIR, "src", "beat.examples")
db_path = os.path.join(
settings.PREFIX, "data", raw_access_db_name.replace("/", "_")
)
db_root_data[raw_access_db_name] = db_path
with open(os.path.join(db_path, "datafile.txt"), "wt") as datafile:
datafile.write("1")
os.makedirs(db_path, exist_ok=True)
with open(os.path.join(db_path, "datafile.txt"), "wt") as datafile:
datafile.write("1")
db_root_file_path = os.path.join(settings.PREFIX, "db_root.json")
with open(db_root_file_path, "wt") as db_root_file:
db_root_file.write(json.dumps(db_root_data))
......
......@@ -121,25 +121,28 @@ class ScheduleExperimentTest(BaseBackendTestCase):
self.assertEqual(experiment.status, Experiment.DONE)
def test_success(self):
fullname = "user/user/single/1/single"
for name in ["single", "single_db_v2"]:
fullname = f"user/user/single/1/{name}"
with self.subTest(experiment_name=fullname):
xp = Experiment.objects.get(name=fullname.split("/")[-1])
xp = Experiment.objects.get(name=fullname.split("/")[-1])
b0 = xp.blocks.all()[0]
b1 = xp.blocks.all()[1]
b0 = xp.blocks.all()[0]
b1 = xp.blocks.all()[1]
self.check_pending_block_of_pending_experiment(b0)
self.check_pending_block_of_pending_experiment(b1)
self.check_pending_block_of_pending_experiment(b0)
self.check_pending_block_of_pending_experiment(b1)
self.assertEqual(Job.objects.count(), 0)
self.assertEqual(Job.objects.count(), 0)
schedule_experiment(xp)
schedule_experiment(xp)
self.check_pending_block_of_scheduled_experiment(b0)
self.check_pending_block_of_scheduled_experiment(b1, runnable=False)
self.check_pending_block_of_scheduled_experiment(b0)
self.check_pending_block_of_scheduled_experiment(b1, runnable=False)
self.assertEqual(Job.objects.count(), 2)
self.assertEqual(JobSplit.objects.count(), 0)
self.assertEqual(Job.objects.count(), 2)
Job.objects.all().delete()
self.assertEqual(JobSplit.objects.count(), 0)
def test_first_block_in_cache(self):
fullname = "user/user/single/1/single"
......@@ -623,25 +626,28 @@ class SplitNewJobsTest(BaseBackendTestCase):
self.assertTrue(split.end_date is None)
def test_one_experiment_one_slot(self):
fullname = "user/user/single/1/single"
for name in ["single", "single_db_v2"]:
fullname = f"user/user/single/1/{name}"
with self.subTest(experiment_name=fullname):
xp = Experiment.objects.get(name=fullname.split("/")[-1])
xp = Experiment.objects.get(name=fullname.split("/")[-1])
schedule_experiment(xp)
schedule_experiment(xp)
self.assertEqual(Job.objects.count(), 2)
self.assertEqual(JobSplit.objects.count(), 0)
self.assertEqual(Job.objects.count(), 2)
self.assertEqual(JobSplit.objects.count(), 0)
split_new_jobs()
split_new_jobs()
self.assertEqual(JobSplit.objects.count(), 1)
self.assertEqual(JobSplit.objects.count(), 1)
xp.refresh_from_db()
xp.refresh_from_db()
b0 = xp.blocks.all()[0]
split = b0.job.splits.all()[0]
b0 = xp.blocks.all()[0]
split = b0.job.splits.all()[0]
self.check_split(split, split_index=0)
self.check_split(split, split_index=0)
Job.objects.all().delete()
JobSplit.objects.all().delete()
def test_one_experiment_two_slots(self):
fullname = "user/user/single/1/single_split_2"
......@@ -776,9 +782,6 @@ class SplitNewJobsTest(BaseBackendTestCase):
split1 = b0.job.splits.all()[0]
split2 = b0.job.splits.all()[1]
print(split2.start_index)
print(split2.end_index)
self.check_split(split1, split_index=0, start_index=0, end_index=5)
self.check_split(split2, split_index=1, start_index=6, end_index=8)
......
......@@ -145,6 +145,7 @@ class ListCreateBaseView(
def get(self, request, *args, **kwargs):
fields_to_return = self.get_serializer_fields(request)
limit_to_latest_versions = is_true(
request.query_params.get("latest_versions", False)
)
......@@ -243,16 +244,16 @@ class RetrieveUpdateDestroyContributionView(
return super().get_serializer(*args, **kwargs)
def get_object(self):
version = self.kwargs["version"]
author_name = self.kwargs["author_name"]
object_name = self.kwargs["object_name"]
kwargs = dict(
version=self.kwargs["version"], name__iexact=self.kwargs["object_name"]
)
if hasattr(self.model, "author"):
kwargs["author__username__iexact"] = self.kwargs["author_name"]
user = self.request.user
try:
obj = self.model.objects.for_user(user, True).get(
author__username__iexact=author_name,
name__iexact=object_name,
version=version,
)
obj = self.model.objects.for_user(user, True).get(**kwargs)
except self.model.DoesNotExist:
raise drf_exceptions.NotFound()
return obj
......@@ -262,7 +263,10 @@ class RetrieveUpdateDestroyContributionView(
self.check_object_permissions(request, db_object)
# Process the query string
allow_sharing = request.user == db_object.author
allow_sharing = False
if hasattr(db_object, "author"):
allow_sharing = request.user == db_object.author
fields_to_return = self.get_serializer_fields(
request, allow_sharing=allow_sharing
......
......@@ -390,6 +390,29 @@ class VersionableManager(ShareableManager):
def is_last_version(self, object):
return not self.filter(name=object.name, version__gt=object.version).exists()
def create_object(
self,
name,
short_description="",
description="",
declaration=None,
version=1,
previous_version=None,
fork_of=None,
):
create = getattr(self, "create_{}".format(self.model.__name__.lower()))
return create(
name=name,
short_description=short_description,
description=description,
declaration=declaration,
version=version,
previous_version=previous_version,
fork_of=fork_of,
)
# ----------------------------------------------------------
......
......@@ -400,15 +400,20 @@ class ContributionCreationSerializer(ContributionModSerializer):
data["name"] = name
version = data.get("version")
if self.Meta.model.objects.filter(
author=user, name=name, version=version
).exists():
kwargs = {
"name": name,
"version": version,
}
if hasattr(self.Meta.model, "author"):
kwargs["author"] = user
if self.Meta.model.objects.filter(**kwargs).exists():
raise serializers.ValidationError(
"{} {} version {} already exists on this account".format(
"{} {} version {} already exists".format(
self.Meta.model.__name__.lower(), name, version
)
)
previous_version = data.get("previous_version")
fork_of = data.get("fork_of")
......
......@@ -154,13 +154,9 @@ def annotate_full_name(query):
filtered.
"""
return query.annotate(
full_name=Concat(
"author__username",
V("/"),
"name",
V("/"),
"version",
output_field=CharField(),
)
)
args = ["name", V("/"), "version"]
if hasattr(query.model, "author"):
args = ["author__username", V("/")] + args
return query.annotate(full_name=Concat(*args, output_field=CharField()))
......@@ -28,6 +28,7 @@
import json
import logging
import os
from pathlib import PurePath
from rest_framework import exceptions as drf_exceptions
from rest_framework import permissions as drf_permissions
......@@ -98,17 +99,32 @@ def database_to_json(database, request_user, fields_to_return):
def clean_paths(declaration):
pseudo_path = "/path_to_db_folder"
def _clean_path(item):
parameters = item.get("parameters", {})
if "annotations" not in parameters:
return
ppath = PurePath(parameters["annotations"])
if not ppath.is_absolute():
return
cleaned_folder = ppath.parts[-2:]
parameters["annotations"] = os.path.join(pseudo_path, *cleaned_folder)
root_folder = declaration["root_folder"]
cleaned_folder = os.path.basename(os.path.normpath(root_folder))
declaration["root_folder"] = os.path.join(pseudo_path, cleaned_folder)
for protocol in declaration["protocols"]:
for set_ in protocol["sets"]:
if "parameters" in set_ and "annotations" in set_["parameters"]:
annotations_folder = set_["parameters"]["annotations"]
cleaned_folder = annotations_folder.split("/")[-2:]
set_["parameters"]["annotations"] = os.path.join(
pseudo_path, *cleaned_folder
)
# sets is a key only available in the V1 version of databases
if "sets" in protocol:
for set_ in protocol["sets"]:
_clean_path(set_)
else:
for view in protocol["views"].values():
_clean_path(view)
return declaration
......
......@@ -415,18 +415,15 @@ class DatabaseSetTemplate(models.Model):
# ----------------------------------------------------------
class DatabaseSetManager(models.Manager):
def create(self, protocol, template, name):
dataset = DatabaseSet(
name=name,
template=template,
protocol=protocol,
hash=hashDataset(protocol.database.fullname(), protocol.name, name),
)
class DatabaseSetQuerySet(models.query.QuerySet):
def create(self, **kwargs):
protocol = kwargs["protocol"]
name = kwargs["name"]
kwargs["hash"] = hashDataset(protocol.database.fullname(), protocol.name, name)
return super().create(**kwargs)
dataset.save()
return dataset
class DatabaseSetManager(models.manager.BaseManager.from_queryset(DatabaseSetQuerySet)):
def get_by_natural_key(
self, database_name, database_version, protocol_name, name, template_name
):
......
......@@ -29,7 +29,9 @@
from django.db import models
from django.dispatch import receiver
from ..common.utils import annotate_full_name
from ..dataformats.models import DataFormat
from ..protocoltemplates.models import ProtocolTemplate
from .models import Database
from .models import DatabaseProtocol
from .models import DatabaseSet
......@@ -41,8 +43,7 @@ from .models import validate_database
@receiver(models.signals.post_delete, sender=Database)
def auto_delete_file_on_delete(sender, instance, **kwargs):
"""Deletes file from filesystem when ``Database`` object is deleted.
"""
"""Deletes file from filesystem when ``Database`` object is deleted."""
if instance.declaration_file:
instance.declaration_file.delete(save=False)
......@@ -90,7 +91,6 @@ def refresh_protocols(sender, instance, **kwargs):
"""Refreshes changed protocols"""
try:
core = validate_database(instance.declaration)
core.name = instance.fullname()
......@@ -113,30 +113,41 @@ def refresh_protocols(sender, instance, **kwargs):
json_protocol = json_protocols[protocol_name]
# creates all the template sets, outputs, etc for the first time
for set_attr in json_protocol["sets"]:
set_data_list = []
for set_name, set_attr in core.sets(protocol_name).items():
set_data = {
# V1 has a dedicated template name while V2 has one unique
# name
"template": set_attr.get("template", set_name),
"name": set_attr["name"],
"outputs": set_attr["outputs"],
}
set_data_list.append(set_data)
for set_attr in set_data_list:
template_name = json_protocol["template"]
try:
protocol_template = annotate_full_name(
ProtocolTemplate.objects
).get(full_name=template_name)
except ProtocolTemplate.DoesNotExist:
pass
else:
protocol_template.databases.add(instance)
tset_name = json_protocol["template"] + "__" + set_attr["template"]
tset_name = template_name + "__" + set_attr["template"]
dataset_template = DatabaseSetTemplate.objects.filter(name=tset_name)
if not dataset_template: # create
dataset_template = DatabaseSetTemplate(name=tset_name)
dataset_template.save()
else:
dataset_template = dataset_template[0]
dataset_template, _ = DatabaseSetTemplate.objects.get_or_create(
name=tset_name
)
# Create the database set
dataset = DatabaseSet.objects.filter(
name=set_attr["name"], template=dataset_template, protocol=protocol,
dataset, _ = DatabaseSet.objects.get_or_create(
name=set_attr["name"], template=dataset_template, protocol=protocol
)
if not dataset: # create
dataset = DatabaseSet.objects.create(
name=set_attr["name"],
template=dataset_template,
protocol=protocol,
)
# Create the database set template output
for output_name, format_name in set_attr["outputs"].items():
if len(format_name.split("/")) != 3:
......@@ -146,54 +157,24 @@ def refresh_protocols(sender, instance, **kwargs):
"value `%s' is not valid" % (format_name,)
)
(author, name, version) = format_name.split("/")
dataformats = DataFormat.objects.filter(
dataformat = DataFormat.objects.get(
author__username=author, name=name, version=version,
)
# TODO: Remove this when validation works (see comments)
if len(dataformats) != 1:
raise SyntaxError(
"Could not find dataformat named `%s' to set"
"output `%s' of template `%s' for protocol"
"`%s' of database `%s'",
(
format_name,
output_name,
dataset_template.name,
protocol_name,
instance.name,
),
)
return
database_template_output = DatabaseSetTemplateOutput.objects.filter(
(
database_template_output,
_,
) = DatabaseSetTemplateOutput.objects.get_or_create(
name=output_name,
template=dataset_template,
dataformat=dataformats[0],
dataformat=dataformat,
)
if not database_template_output: # create
database_template_output = DatabaseSetTemplateOutput(
name=output_name,
template=dataset_template,
dataformat=dataformats[0],
)
database_template_output.save()
else:
database_template_output = database_template_output[0]
# Create the database set output
dataset_output = DatabaseSetOutput.objects.filter(
DatabaseSetOutput.objects.get_or_create(
template=database_template_output, set=dataset,
)
if not dataset_output: # create
dataset_output = DatabaseSetOutput(
template=database_template_output, set=dataset,
)
dataset_output.save()
except Exception:
instance.delete()
raise
......
......@@ -36,13 +36,14 @@ from django.urls import reverse
from ..common.testutils import BaseTestCase
from ..common.testutils import tearDownModule # noqa test runner will call it
from ..dataformats.models import DataFormat
from ..protocoltemplates.models import ProtocolTemplate
from .models import Database
TEST_PWD = "1234" # nosec
class DatabaseAPIBase(BaseTestCase):
DATABASE = {
DATABASE_V1 = {
"root_folder": "/path/to/root/folder",
"protocols": [
{
......@@ -53,6 +54,7 @@ class DatabaseAPIBase(BaseTestCase):
"name": "set1",
"template": "set",
"view": "dummy",
"parameters": {"annotations": "/path/to/annotations"},
"outputs": {"out": settings.SYSTEM_ACCOUNT + "/float/1"},
}
],
......@@ -60,6 +62,30 @@ class DatabaseAPIBase(BaseTestCase):
],
}
DATABASE_V2 = {
"schema_version": 2,
"root_folder": "/path/to/root/folder",
"protocols": [
{
"name": "protocol1",
"template": "set1/1",
"views": {
"set1": {
"view": "dummy",
"parameters": {"annotations": "../annotations"},
}
},
}
],
}
PROTOCOL_TEMPLATE = {
"schema_version": 1,
"sets": [
{"name": "set1", "outputs": {"out": settings.SYSTEM_ACCOUNT + "/float/1"}}
],
}
def setUp(self):
# Users
self.system_user = User.objects.create_user(
......@@ -96,11 +122,10 @@ class DatabaseCreationAPI(DatabaseAPIBase):
response = self.client.post(
self.url,
json.dumps(
{"name": self.db_name, "version": 1, "declaration": self.DATABASE}
{"name": self.db_name, "version": 1, "declaration": self.DATABASE_V1}
),
content_type="application/json",
)
self.checkResponse(response, 400)
def test_create_database(self):
......@@ -111,24 +136,56 @@ class DatabaseCreationAPI(DatabaseAPIBase):
self.assertIsNotNone(dataformat, errors)
dataformat.share()
(protocol_template, errors) = ProtocolTemplate.objects.create_protocoltemplate(
"set1", declaration=self.PROTOCOL_TEMPLATE
)
self.assertIsNotNone(protocol_template, errors)
protocol_template.share()
self.client.login(username=settings.SYSTEM_ACCOUNT, password=TEST_PWD)
response = self.client.post(
self.url,
json.dumps(
{"name": self.db_name, "version": 1, "declaration": self.DATABASE}
),
content_type="application/json",
)
with self.subTest("Create a v1 database"):
response = self.client.post(
self.url,
json.dumps(
{
"name": self.db_name,
"version": 1,
"declaration": self.DATABASE_V1,
}
),
content_type="application/json",
)
data = self.checkResponse(response, 201, content_type="application/json")
data = self.checkResponse(response, 201, content_type="application/json")
self.assertTrue(data["name"] == self.db_name)
self.assertTrue(data["name"] == self.db_name)
databases = Database.objects.all()
self.assertEqual(databases.count(), 1)
databases.delete()
dataformat.delete()
databases = Database.objects.all()
self.assertEqual(databases.count(), 1)
with self.subTest("Create a v2 database"):
response = self.client.post(
self.url,
json.dumps(
{
"name": self.db_name,
"version": 2,
"declaration": self.DATABASE_V2,
}
),
content_type="application/json",
)
data = self.checkResponse(response, 201, content_type="application/json")
self.assertTrue(data["name"] == self.db_name)
databases = Database.objects.all()
self.assertEqual(databases.count(), 2)
Database.objects.all().delete()
DataFormat.objects.all().delete()
class DatabaseRetrievalAPI(DatabaseAPIBase):
......@@ -141,13 +198,25 @@ class DatabaseRetrievalAPI(DatabaseAPIBase):
self.assertIsNotNone(dataformat, errors)
dataformat.share()
(protocol_template, errors) = ProtocolTemplate.objects.create_protocoltemplate(
"set1", declaration=self.PROTOCOL_TEMPLATE
)
self.assertIsNotNone(protocol_template, errors)
protocol_template.share()
(self.database, errors) = Database.objects.create_database(
self.db_name, declaration=self.DATABASE
self.db_name, version=1, declaration=self.DATABASE_V1
)
self.assertIsNotNone(self.database, errors)
(self.database_v2, errors) = Database.objects.create_database(
self.db_name, version=2, declaration=self.DATABASE_V2
)
self.assertIsNotNone(self.database_v2, errors)
def tearDown(self):
self.database.delete()
self.database_v2.delete()
def __accessibility_test(self):
self.database.accessibility_start_date = datetime.now()
......@@ -182,19 +251,35 @@ class DatabaseRetrievalAPI(DatabaseAPIBase):
self.assertEqual(len(response.json()), 0)
def test_retrieve_database(self):
self.database.share()
self.client.login(username=settings.SYSTEM_ACCOUNT, password=TEST_PWD)
url = reverse(
"api_databases:object", kwargs={"database_name": self.db_name, "version": 1}
)
response = self.client.get(url, format="json")
data = self.checkResponse(response, 200, content_type="application/json")
declaration = json.loads(data["declaration"])
self.assertTrue(declaration["root_folder"].startswith("/path_to_db_folder"))
def _check_database(kwargs):
url = reverse("api_databases:object", kwargs=kwargs)
response = self.client.get(url, format="json")
data = self.checkResponse(response, 200, content_type="application/json")
declaration = json.loads(data["declaration"])
self.assertTrue(declaration["root_folder"].startswith("/path_to_db_folder"))
return declaration
with self.subTest("Retrieve a v1 database"):
self.database.share()
declaration = _check_database({"database_name": self.db_name, "version": 1})
self.assertEqual(
declaration["protocols"][0]["sets"][0]["parameters"]["annotations"],
"/path_to_db_folder/to/annotations",
)
with self.subTest("Retrieve a v2 database"):
self.database_v2.share()
declaration = _check_database({"database_name": self.db_name, "version": 2})
self.assertEqual(
declaration["protocols"][0]["views"]["set1"]["parameters"][
"annotations"
],
"../annotations",
)
def test_dated_database_for_user(self):
self.database.share(users=[settings.SYSTEM_ACCOUNT])
......
......@@ -153,7 +153,7 @@ class PlotterParameterCreationTestCase(PlotterParameterTestCase):
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(