diff --git a/zotify/__init__.py b/zotify/__init__.py index 52a06df..7355f1d 100644 --- a/zotify/__init__.py +++ b/zotify/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from enum import IntEnum from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path from threading import Thread @@ -69,7 +70,7 @@ class Session(LibrespotSession): def __init__( self, session_builder: LibrespotSession.Builder, - token: TokenProvider.StoredToken, + oauth: OAuth, language: str = "en", ) -> None: """ @@ -89,7 +90,7 @@ class Session(LibrespotSession): ), ApResolver.get_random_accesspoint(), ) - self.__token = token + self.__oauth = oauth self.__language = language self.connect() self.authenticate(session_builder.login_credentials) @@ -112,8 +113,7 @@ class Session(LibrespotSession): .build() ) session = LibrespotSession.Builder(conf).stored_file(str(cred_file)) - token = session.login_credentials.auth_data # TODO: this is wrong - return Session(session, token, language) + return Session(session, OAuth(), language) # TODO @staticmethod def from_oauth( @@ -148,7 +148,7 @@ class Session(LibrespotSession): typ=Authentication.AuthenticationType.values()[3], auth_data=token.access_token.encode(), ) - return Session(session, token, language) + return Session(session, auth, language) def __get_playable( self, playable_id: PlayableId, quality: Quality @@ -188,9 +188,9 @@ class Session(LibrespotSession): self.api(), ) - def token(self) -> TokenProvider.StoredToken: - """Returns API token""" - return self.__token + def oauth(self) -> OAuth: + """Returns OAuth service""" + return self.__oauth def language(self) -> str: """Returns session language""" @@ -288,7 +288,7 @@ class TokenProvider(LibrespotTokenProvider): self._session = session def get_token(self, *scopes) -> TokenProvider.StoredToken: - return self._session.token() + return self._session.oauth().get_token() class StoredToken(LibrespotTokenProvider.StoredToken): def __init__(self, obj): @@ -309,6 +309,11 @@ class OAuth: self.__server_thread.start() def get_authorization_url(self) -> str: + """ + Generate OAuth URL + Returns: + OAuth URL + """ self.__code_verifier = generate_code_verifier() code_challenge = get_code_challenge(self.__code_verifier) params = { @@ -322,19 +327,48 @@ class OAuth: return f"{AUTH_URL}authorize?{urlencode(params)}" def await_token(self) -> TokenProvider.StoredToken: + """ + Blocks until server thread gets token + Returns: + StoredToken + """ self.__server_thread.join() return self.__token - def set_token(self, code: str) -> None: + def get_token(self) -> TokenProvider.StoredToken: + """ + Gets a valid token + Returns: + StoredToken + """ + if self.__token is None: + raise RuntimeError("Session isn't authenticated!") + elif self.__token.expired(): + self.set_token(self.__token.refresh_token, OAuth.RequestType.REFRESH) + return self.__token + + def set_token(self, code: str, request_type: RequestType) -> None: + """ + Fetches and sets stored token + Returns: + StoredToken + """ token_url = f"{AUTH_URL}api/token" headers = {"Content-Type": "application/x-www-form-urlencoded"} - body = { - "grant_type": "authorization_code", - "code": code, - "redirect_uri": REDIRECT_URI, - "client_id": CLIENT_ID, - "code_verifier": self.__code_verifier, - } + if request_type == OAuth.RequestType.LOGIN: + body = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": REDIRECT_URI, + "client_id": CLIENT_ID, + "code_verifier": self.__code_verifier, + } + elif request_type == OAuth.RequestType.REFRESH: + body = { + "grant_type": "refresh_token", + "refresh_token": code, + "client_id": CLIENT_ID, + } response = post(token_url, headers=headers, data=body) if response.status_code != 200: raise IOError( @@ -348,6 +382,10 @@ class OAuth: httpd.authenticator = self httpd.serve_forever() + class RequestType(IntEnum): + LOGIN = 0 + REFRESH = 1 + class OAuthHTTPServer(HTTPServer): authenticator: OAuth @@ -371,7 +409,9 @@ class OAuth: if code: if isinstance(self.server, OAuth.OAuthHTTPServer): - self.server.authenticator.set_token(code[0]) + self.server.authenticator.set_token( + code[0], OAuth.RequestType.LOGIN + ) self.send_response(200) self.send_header("Content-type", "text/html") self.end_headers() diff --git a/zotify/app.py b/zotify/app.py index 0f4780b..57cfb04 100644 --- a/zotify/app.py +++ b/zotify/app.py @@ -106,7 +106,7 @@ class Selection: def __print(self, count: int, items: list[dict[str, Any]], *args: str) -> None: arg_range = range(len(args)) - category_str = " # " + " ".join("{:<38}" for _ in arg_range) + category_str = "# " + " ".join("{:<38}" for _ in arg_range) print(category_str.format(*[s.upper() for s in list(args)])) for item in items: count += 1