diff --git a/.pylintrc b/.pylintrc index 4498033..03e3384 100644 --- a/.pylintrc +++ b/.pylintrc @@ -171,7 +171,7 @@ class-rgx=[A-Z_][a-zA-Z0-9]+$ function-rgx=[a-z_][a-z0-9_]{2,30}$ # Regular expression which should only match correct method names -method-rgx=(test[A-Za-z0-9_]{2,30})|([a-z_][a-z0-9_]{2,30})$ +method-rgx=[a-z_][a-z0-9_]{2,30}$ # Regular expression which should only match correct instance attribute names attr-rgx=[a-z_][a-z0-9_]{2,30}$ diff --git a/dfdewey/datastore/postgresql.py b/dfdewey/datastore/postgresql.py index 921d1e8..d104092 100644 --- a/dfdewey/datastore/postgresql.py +++ b/dfdewey/datastore/postgresql.py @@ -125,7 +125,7 @@ class PostgresqlDataStore(): WHERE table_schema = '{0:s}' AND table_name = '{1:s}'""".format( table_schema, table_name)) - return self.cursor.fetchone() + return self.cursor.fetchone() is not None def value_exists(self, table_name, column_name, value): """Check if a value exists in a table. @@ -143,4 +143,4 @@ class PostgresqlDataStore(): SELECT 1 from {0:s} WHERE {1:s} = '{2:s}'""".format(table_name, column_name, value)) - return self.cursor.fetchone() + return self.cursor.fetchone() is not None diff --git a/dfdewey/datastore/postgresql_test.py b/dfdewey/datastore/postgresql_test.py new file mode 100644 index 0000000..7ec56a9 --- /dev/null +++ b/dfdewey/datastore/postgresql_test.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for PostgreSQL datastore.""" + +import unittest +import mock + +from dfdewey.datastore.postgresql import PostgresqlDataStore + + +class PostgresqlTest(unittest.TestCase): + """Tests for PostgreSQL datastore.""" + + def _get_datastore(self): + """Get a mock postgresql datastore. + + Returns: + Mock postgresql datastore. + """ + with mock.patch('psycopg2.connect') as _: + db = PostgresqlDataStore(autocommit=True) + return db + + @mock.patch('psycopg2.extras.execute_values') + def test_bulk_insert(self, _): + """Test bulk insert method.""" + db = self._get_datastore() + rows = [(1, 1), (2, 2), (3, 3)] + db.bulk_insert('blocks (block, inum)', rows) + + def test_execute(self): + """Test execute method.""" + db = self._get_datastore() + command = ( + 'CREATE TABLE images (image_path TEXT, image_hash TEXT PRIMARY KEY)') + db.execute(command) + + def test_query(self): + """Test query method.""" + db = self._get_datastore() + query = 'SELECT filename FROM files WHERE inum = 0' + with mock.patch.object(db.cursor, 'fetchall', return_value=[('$MFT',)]): + results = db.query(query) + + self.assertEqual(results, [('$MFT',)]) + + def test_query_single_row(self): + """Test query single row method.""" + db = self._get_datastore() + query = ( + 'SELECT 1 from image_case WHERE image_hash = ' + '\'d41d8cd98f00b204e9800998ecf8427e\'') + with mock.patch.object(db.cursor, 'fetchone', return_value=(1,)): + results = db.query_single_row(query) + + self.assertEqual(results, (1,)) + + def test_switch_database(self): + """Test switch database method.""" + db = self._get_datastore() + db.switch_database(db_name='dfdewey', autocommit=True) + + def test_table_exists(self): + """Test table exists method.""" + db = self._get_datastore() + + with mock.patch.object(db.cursor, 'fetchone', return_value=(1,)): + result = db.table_exists('images') + self.assertEqual(result, True) + + with mock.patch.object(db.cursor, 'fetchone', return_value=None): + result = db.table_exists('images') + self.assertEqual(result, False) + + def test_value_exists(self): + """Test value exists method.""" + db = self._get_datastore() + + with mock.patch.object(db.cursor, 'fetchone', return_value=(1,)): + result = db.value_exists( + 'images', 'image_hash', 'd41d8cd98f00b204e9800998ecf8427e') + self.assertEqual(result, True) + + with mock.patch.object(db.cursor, 'fetchone', return_value=None): + result = db.value_exists( + 'images', 'image_hash', 'd41d8cd98f00b204e9800998ecf8427e') + self.assertEqual(result, False) + + +if __name__ == '__main__': + unittest.main() diff --git a/dfdewey/utils/image.py b/dfdewey/utils/image.py index a070d68..5f630ac 100644 --- a/dfdewey/utils/image.py +++ b/dfdewey/utils/image.py @@ -72,7 +72,7 @@ def check_tracking_database(tracking_db, image_path, image_hash, case): tracking_db.execute( """ CREATE TABLE image_case ( - case_id TEXT, image_hash TEXT REFERENCES images(image_hash), + case_id TEXT, image_hash TEXT REFERENCES images(image_hash), PRIMARY KEY (case_id, image_hash))""") else: image_exists = tracking_db.value_exists('images', 'image_hash', image_hash)