Commit e32313bc authored by Samuel GAIST's avatar Samuel GAIST

[databases][api] Fix paths anonymization

Take into account v1 and v2 databases
parent b8f3c8cd
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
import json import json
import logging import logging
import os import os
from pathlib import PurePath
from rest_framework import exceptions as drf_exceptions from rest_framework import exceptions as drf_exceptions
from rest_framework import permissions as drf_permissions from rest_framework import permissions as drf_permissions
...@@ -98,17 +99,32 @@ def database_to_json(database, request_user, fields_to_return): ...@@ -98,17 +99,32 @@ def database_to_json(database, request_user, fields_to_return):
def clean_paths(declaration): def clean_paths(declaration):
pseudo_path = "/path_to_db_folder" pseudo_path = "/path_to_db_folder"
def _clean_path(item):
parameters = item.get("parameters", {})
if "annotations" not in parameters:
return
ppath = PurePath(parameters["annotations"])
if not ppath.is_absolute():
return
cleaned_folder = ppath.parts[-2:]
parameters["annotations"] = os.path.join(pseudo_path, *cleaned_folder)
root_folder = declaration["root_folder"] root_folder = declaration["root_folder"]
cleaned_folder = os.path.basename(os.path.normpath(root_folder)) cleaned_folder = os.path.basename(os.path.normpath(root_folder))
declaration["root_folder"] = os.path.join(pseudo_path, cleaned_folder) declaration["root_folder"] = os.path.join(pseudo_path, cleaned_folder)
for protocol in declaration["protocols"]: for protocol in declaration["protocols"]:
for set_ in protocol["sets"]: # sets is a key only available in the V1 version of databases
if "parameters" in set_ and "annotations" in set_["parameters"]: if "sets" in protocol:
annotations_folder = set_["parameters"]["annotations"] for set_ in protocol["sets"]:
cleaned_folder = annotations_folder.split("/")[-2:] _clean_path(set_)
set_["parameters"]["annotations"] = os.path.join( else:
pseudo_path, *cleaned_folder for view in protocol["views"].values():
) _clean_path(view)
return declaration return declaration
......
...@@ -36,13 +36,14 @@ from django.urls import reverse ...@@ -36,13 +36,14 @@ from django.urls import reverse
from ..common.testutils import BaseTestCase from ..common.testutils import BaseTestCase
from ..common.testutils import tearDownModule # noqa test runner will call it from ..common.testutils import tearDownModule # noqa test runner will call it
from ..dataformats.models import DataFormat from ..dataformats.models import DataFormat
from ..protocoltemplates.models import ProtocolTemplate
from .models import Database from .models import Database
TEST_PWD = "1234" # nosec TEST_PWD = "1234" # nosec
class DatabaseAPIBase(BaseTestCase): class DatabaseAPIBase(BaseTestCase):
DATABASE = { DATABASE_V1 = {
"root_folder": "/path/to/root/folder", "root_folder": "/path/to/root/folder",
"protocols": [ "protocols": [
{ {
...@@ -53,6 +54,7 @@ class DatabaseAPIBase(BaseTestCase): ...@@ -53,6 +54,7 @@ class DatabaseAPIBase(BaseTestCase):
"name": "set1", "name": "set1",
"template": "set", "template": "set",
"view": "dummy", "view": "dummy",
"parameters": {"annotations": "/path/to/annotations"},
"outputs": {"out": settings.SYSTEM_ACCOUNT + "/float/1"}, "outputs": {"out": settings.SYSTEM_ACCOUNT + "/float/1"},
} }
], ],
...@@ -60,6 +62,30 @@ class DatabaseAPIBase(BaseTestCase): ...@@ -60,6 +62,30 @@ class DatabaseAPIBase(BaseTestCase):
], ],
} }
DATABASE_V2 = {
"schema_version": 2,
"root_folder": "/path/to/root/folder",
"protocols": [
{
"name": "protocol1",
"template": "set1/1",
"views": {
"set1": {
"view": "dummy",
"parameters": {"annotations": "../annotations"},
}
},
}
],
}
PROTOCOL_TEMPLATE = {
"schema_version": 1,
"sets": [
{"name": "set1", "outputs": {"out": settings.SYSTEM_ACCOUNT + "/float/1"}}
],
}
def setUp(self): def setUp(self):
# Users # Users
self.system_user = User.objects.create_user( self.system_user = User.objects.create_user(
...@@ -96,11 +122,10 @@ class DatabaseCreationAPI(DatabaseAPIBase): ...@@ -96,11 +122,10 @@ class DatabaseCreationAPI(DatabaseAPIBase):
response = self.client.post( response = self.client.post(
self.url, self.url,
json.dumps( json.dumps(
{"name": self.db_name, "version": 1, "declaration": self.DATABASE} {"name": self.db_name, "version": 1, "declaration": self.DATABASE_V1}
), ),
content_type="application/json", content_type="application/json",
) )
self.checkResponse(response, 400) self.checkResponse(response, 400)
def test_create_database(self): def test_create_database(self):
...@@ -111,24 +136,56 @@ class DatabaseCreationAPI(DatabaseAPIBase): ...@@ -111,24 +136,56 @@ class DatabaseCreationAPI(DatabaseAPIBase):
self.assertIsNotNone(dataformat, errors) self.assertIsNotNone(dataformat, errors)
dataformat.share() dataformat.share()
(protocol_template, errors) = ProtocolTemplate.objects.create_protocoltemplate(
"set1", declaration=self.PROTOCOL_TEMPLATE
)
self.assertIsNotNone(protocol_template, errors)
protocol_template.share()
self.client.login(username=settings.SYSTEM_ACCOUNT, password=TEST_PWD) self.client.login(username=settings.SYSTEM_ACCOUNT, password=TEST_PWD)
response = self.client.post( with self.subTest("Create a v1 database"):
self.url, response = self.client.post(
json.dumps( self.url,
{"name": self.db_name, "version": 1, "declaration": self.DATABASE} json.dumps(
), {
content_type="application/json", "name": self.db_name,
) "version": 1,
"declaration": self.DATABASE_V1,
}
),
content_type="application/json",
)
data = self.checkResponse(response, 201, content_type="application/json") data = self.checkResponse(response, 201, content_type="application/json")
self.assertTrue(data["name"] == self.db_name) self.assertTrue(data["name"] == self.db_name)
databases = Database.objects.all() databases = Database.objects.all()
self.assertEqual(databases.count(), 1) self.assertEqual(databases.count(), 1)
databases.delete()
dataformat.delete() with self.subTest("Create a v2 database"):
response = self.client.post(
self.url,
json.dumps(
{
"name": self.db_name,
"version": 2,
"declaration": self.DATABASE_V2,
}
),
content_type="application/json",
)
data = self.checkResponse(response, 201, content_type="application/json")
self.assertTrue(data["name"] == self.db_name)
databases = Database.objects.all()
self.assertEqual(databases.count(), 2)
Database.objects.all().delete()
DataFormat.objects.all().delete()
class DatabaseRetrievalAPI(DatabaseAPIBase): class DatabaseRetrievalAPI(DatabaseAPIBase):
...@@ -141,13 +198,25 @@ class DatabaseRetrievalAPI(DatabaseAPIBase): ...@@ -141,13 +198,25 @@ class DatabaseRetrievalAPI(DatabaseAPIBase):
self.assertIsNotNone(dataformat, errors) self.assertIsNotNone(dataformat, errors)
dataformat.share() dataformat.share()
(protocol_template, errors) = ProtocolTemplate.objects.create_protocoltemplate(
"set1", declaration=self.PROTOCOL_TEMPLATE
)
self.assertIsNotNone(protocol_template, errors)
protocol_template.share()
(self.database, errors) = Database.objects.create_database( (self.database, errors) = Database.objects.create_database(
self.db_name, declaration=self.DATABASE self.db_name, version=1, declaration=self.DATABASE_V1
) )
self.assertIsNotNone(self.database, errors) self.assertIsNotNone(self.database, errors)
(self.database_v2, errors) = Database.objects.create_database(
self.db_name, version=2, declaration=self.DATABASE_V2
)
self.assertIsNotNone(self.database_v2, errors)
def tearDown(self): def tearDown(self):
self.database.delete() self.database.delete()
self.database_v2.delete()
def __accessibility_test(self): def __accessibility_test(self):
self.database.accessibility_start_date = datetime.now() self.database.accessibility_start_date = datetime.now()
...@@ -182,19 +251,35 @@ class DatabaseRetrievalAPI(DatabaseAPIBase): ...@@ -182,19 +251,35 @@ class DatabaseRetrievalAPI(DatabaseAPIBase):
self.assertEqual(len(response.json()), 0) self.assertEqual(len(response.json()), 0)
def test_retrieve_database(self): def test_retrieve_database(self):
self.database.share()
self.client.login(username=settings.SYSTEM_ACCOUNT, password=TEST_PWD) self.client.login(username=settings.SYSTEM_ACCOUNT, password=TEST_PWD)
url = reverse( def _check_database(kwargs):
"api_databases:object", kwargs={"database_name": self.db_name, "version": 1} url = reverse("api_databases:object", kwargs=kwargs)
)
response = self.client.get(url, format="json")
response = self.client.get(url, format="json") data = self.checkResponse(response, 200, content_type="application/json")
data = self.checkResponse(response, 200, content_type="application/json")
declaration = json.loads(data["declaration"])
declaration = json.loads(data["declaration"]) self.assertTrue(declaration["root_folder"].startswith("/path_to_db_folder"))
self.assertTrue(declaration["root_folder"].startswith("/path_to_db_folder")) return declaration
with self.subTest("Retrieve a v1 database"):
self.database.share()
declaration = _check_database({"database_name": self.db_name, "version": 1})
self.assertEqual(
declaration["protocols"][0]["sets"][0]["parameters"]["annotations"],
"/path_to_db_folder/to/annotations",
)
with self.subTest("Retrieve a v2 database"):
self.database_v2.share()
declaration = _check_database({"database_name": self.db_name, "version": 2})
self.assertEqual(
declaration["protocols"][0]["views"]["set1"]["parameters"][
"annotations"
],
"../annotations",
)
def test_dated_database_for_user(self): def test_dated_database_for_user(self):
self.database.share(users=[settings.SYSTEM_ACCOUNT]) self.database.share(users=[settings.SYSTEM_ACCOUNT])
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment