Commit e32313bc authored by Samuel GAIST's avatar Samuel GAIST

[databases][api] Fix paths anonymization

Take into account v1 and v2 databases
parent b8f3c8cd
......@@ -28,6 +28,7 @@
import json
import logging
import os
from pathlib import PurePath
from rest_framework import exceptions as drf_exceptions
from rest_framework import permissions as drf_permissions
......@@ -98,17 +99,32 @@ def database_to_json(database, request_user, fields_to_return):
def clean_paths(declaration):
pseudo_path = "/path_to_db_folder"
def _clean_path(item):
parameters = item.get("parameters", {})
if "annotations" not in parameters:
return
ppath = PurePath(parameters["annotations"])
if not ppath.is_absolute():
return
cleaned_folder = ppath.parts[-2:]
parameters["annotations"] = os.path.join(pseudo_path, *cleaned_folder)
root_folder = declaration["root_folder"]
cleaned_folder = os.path.basename(os.path.normpath(root_folder))
declaration["root_folder"] = os.path.join(pseudo_path, cleaned_folder)
for protocol in declaration["protocols"]:
for set_ in protocol["sets"]:
if "parameters" in set_ and "annotations" in set_["parameters"]:
annotations_folder = set_["parameters"]["annotations"]
cleaned_folder = annotations_folder.split("/")[-2:]
set_["parameters"]["annotations"] = os.path.join(
pseudo_path, *cleaned_folder
)
# sets is a key only available in the V1 version of databases
if "sets" in protocol:
for set_ in protocol["sets"]:
_clean_path(set_)
else:
for view in protocol["views"].values():
_clean_path(view)
return declaration
......
......@@ -36,13 +36,14 @@ from django.urls import reverse
from ..common.testutils import BaseTestCase
from ..common.testutils import tearDownModule # noqa test runner will call it
from ..dataformats.models import DataFormat
from ..protocoltemplates.models import ProtocolTemplate
from .models import Database
TEST_PWD = "1234" # nosec
class DatabaseAPIBase(BaseTestCase):
DATABASE = {
DATABASE_V1 = {
"root_folder": "/path/to/root/folder",
"protocols": [
{
......@@ -53,6 +54,7 @@ class DatabaseAPIBase(BaseTestCase):
"name": "set1",
"template": "set",
"view": "dummy",
"parameters": {"annotations": "/path/to/annotations"},
"outputs": {"out": settings.SYSTEM_ACCOUNT + "/float/1"},
}
],
......@@ -60,6 +62,30 @@ class DatabaseAPIBase(BaseTestCase):
],
}
DATABASE_V2 = {
"schema_version": 2,
"root_folder": "/path/to/root/folder",
"protocols": [
{
"name": "protocol1",
"template": "set1/1",
"views": {
"set1": {
"view": "dummy",
"parameters": {"annotations": "../annotations"},
}
},
}
],
}
PROTOCOL_TEMPLATE = {
"schema_version": 1,
"sets": [
{"name": "set1", "outputs": {"out": settings.SYSTEM_ACCOUNT + "/float/1"}}
],
}
def setUp(self):
# Users
self.system_user = User.objects.create_user(
......@@ -96,11 +122,10 @@ class DatabaseCreationAPI(DatabaseAPIBase):
response = self.client.post(
self.url,
json.dumps(
{"name": self.db_name, "version": 1, "declaration": self.DATABASE}
{"name": self.db_name, "version": 1, "declaration": self.DATABASE_V1}
),
content_type="application/json",
)
self.checkResponse(response, 400)
def test_create_database(self):
......@@ -111,24 +136,56 @@ class DatabaseCreationAPI(DatabaseAPIBase):
self.assertIsNotNone(dataformat, errors)
dataformat.share()
(protocol_template, errors) = ProtocolTemplate.objects.create_protocoltemplate(
"set1", declaration=self.PROTOCOL_TEMPLATE
)
self.assertIsNotNone(protocol_template, errors)
protocol_template.share()
self.client.login(username=settings.SYSTEM_ACCOUNT, password=TEST_PWD)
response = self.client.post(
self.url,
json.dumps(
{"name": self.db_name, "version": 1, "declaration": self.DATABASE}
),
content_type="application/json",
)
with self.subTest("Create a v1 database"):
response = self.client.post(
self.url,
json.dumps(
{
"name": self.db_name,
"version": 1,
"declaration": self.DATABASE_V1,
}
),
content_type="application/json",
)
data = self.checkResponse(response, 201, content_type="application/json")
data = self.checkResponse(response, 201, content_type="application/json")
self.assertTrue(data["name"] == self.db_name)
self.assertTrue(data["name"] == self.db_name)
databases = Database.objects.all()
self.assertEqual(databases.count(), 1)
databases.delete()
dataformat.delete()
databases = Database.objects.all()
self.assertEqual(databases.count(), 1)
with self.subTest("Create a v2 database"):
response = self.client.post(
self.url,
json.dumps(
{
"name": self.db_name,
"version": 2,
"declaration": self.DATABASE_V2,
}
),
content_type="application/json",
)
data = self.checkResponse(response, 201, content_type="application/json")
self.assertTrue(data["name"] == self.db_name)
databases = Database.objects.all()
self.assertEqual(databases.count(), 2)
Database.objects.all().delete()
DataFormat.objects.all().delete()
class DatabaseRetrievalAPI(DatabaseAPIBase):
......@@ -141,13 +198,25 @@ class DatabaseRetrievalAPI(DatabaseAPIBase):
self.assertIsNotNone(dataformat, errors)
dataformat.share()
(protocol_template, errors) = ProtocolTemplate.objects.create_protocoltemplate(
"set1", declaration=self.PROTOCOL_TEMPLATE
)
self.assertIsNotNone(protocol_template, errors)
protocol_template.share()
(self.database, errors) = Database.objects.create_database(
self.db_name, declaration=self.DATABASE
self.db_name, version=1, declaration=self.DATABASE_V1
)
self.assertIsNotNone(self.database, errors)
(self.database_v2, errors) = Database.objects.create_database(
self.db_name, version=2, declaration=self.DATABASE_V2
)
self.assertIsNotNone(self.database_v2, errors)
def tearDown(self):
self.database.delete()
self.database_v2.delete()
def __accessibility_test(self):
self.database.accessibility_start_date = datetime.now()
......@@ -182,19 +251,35 @@ class DatabaseRetrievalAPI(DatabaseAPIBase):
self.assertEqual(len(response.json()), 0)
def test_retrieve_database(self):
self.database.share()
self.client.login(username=settings.SYSTEM_ACCOUNT, password=TEST_PWD)
url = reverse(
"api_databases:object", kwargs={"database_name": self.db_name, "version": 1}
)
response = self.client.get(url, format="json")
data = self.checkResponse(response, 200, content_type="application/json")
declaration = json.loads(data["declaration"])
self.assertTrue(declaration["root_folder"].startswith("/path_to_db_folder"))
def _check_database(kwargs):
url = reverse("api_databases:object", kwargs=kwargs)
response = self.client.get(url, format="json")
data = self.checkResponse(response, 200, content_type="application/json")
declaration = json.loads(data["declaration"])
self.assertTrue(declaration["root_folder"].startswith("/path_to_db_folder"))
return declaration
with self.subTest("Retrieve a v1 database"):
self.database.share()
declaration = _check_database({"database_name": self.db_name, "version": 1})
self.assertEqual(
declaration["protocols"][0]["sets"][0]["parameters"]["annotations"],
"/path_to_db_folder/to/annotations",
)
with self.subTest("Retrieve a v2 database"):
self.database_v2.share()
declaration = _check_database({"database_name": self.db_name, "version": 2})
self.assertEqual(
declaration["protocols"][0]["views"]["set1"]["parameters"][
"annotations"
],
"../annotations",
)
def test_dated_database_for_user(self):
self.database.share(users=[settings.SYSTEM_ACCOUNT])
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment