#  Copyright 2020 Google LLC
#  Copyright 2020 EPAM Systems
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import io
import os
import sys

# Put the Airflow source tree's plugins/ and dags/ directories on the import
# path so the `libs.*` imports below can be resolved (presumably `libs` lives
# under one of these directories -- confirm against the repo layout).
# NOTE(review): if AIRFLOW_SRC_DIR is unset, os.getenv returns None and the
# appended entries become the literal strings "None/plugins" / "None/dags",
# so the later imports fail with a less obvious ModuleNotFoundError.
sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/plugins")
sys.path.append(f"{os.getenv('AIRFLOW_SRC_DIR')}/dags")

from libs.exceptions import GCSObjectURIError
import pytest
from libs.context import Context
from libs.refresh_token import AirflowTokenRefresher
from libs.upload_file import GCSFileUploader


class TestSourceFileChecker:
    """Tests for ``GCSFileUploader`` with its request helpers stubbed out."""

    @pytest.fixture()
    def file_uploader(self, monkeypatch):
        """Build a ``GCSFileUploader`` whose request helpers return canned data.

        ``_get_signed_url_request``, ``_upload_file_request`` and
        ``_get_file_location_request`` are replaced on the instance by
        stubs that ignore their arguments and return fixed values.
        """
        ctx = Context(data_partition_id="test", app_key="")
        uploader = GCSFileUploader("http://test", AirflowTokenRefresher(), ctx)

        # Attribute name -> stub; applied in insertion order below.
        stubs = {
            "_get_signed_url_request":
                lambda *args, **kwargs: ("test", "test"),
            "_upload_file_request":
                lambda *args, **kwargs: None,
            "_get_file_location_request":
                lambda *args, **kwargs: "test",
        }
        for attr_name, stub in stubs.items():
            monkeypatch.setattr(uploader, attr_name, stub)
        return uploader

    def test_get_file_from_bucket(
        self,
        monkeypatch,
        file_uploader: GCSFileUploader,
    ):
        """Run ``upload_file`` on a gs:// URI with bucket reads stubbed.

        ``get_file_from_bucket`` is replaced by a stub returning a bare
        ``io.RawIOBase`` placeholder and the name ``"test"``.

        NOTE(review): there is no assertion on the result; this only
        checks that ``upload_file`` completes without raising.
        """
        placeholder_stream = io.RawIOBase()
        monkeypatch.setattr(
            file_uploader,
            "get_file_from_bucket",
            lambda *args, **kwargs: (placeholder_stream, "test"),
        )
        file_uploader.upload_file("gs://test/test")

    @pytest.mark.parametrize(
        "file_path",
        [
            "gs://test",  # gs:// scheme but nothing after the first segment
            "://test",    # empty scheme
            "test",       # no scheme separator at all
        ],
    )
    def test_invalid_gcs_object_uri(
        self,
        file_uploader: GCSFileUploader,
        file_path: str,
    ):
        """Each malformed URI must make ``_parse_object_uri`` raise
        ``GCSObjectURIError``."""
        with pytest.raises(GCSObjectURIError):
            file_uploader._parse_object_uri(file_path)